diff --git a/.bazelrc b/.bazelrc
index 590a87f..01b416c 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -30,6 +30,10 @@
 # opts in to modular op registration support by default.
 build --define framework_shared_object=true
 
+# Flags for open source build, always set to be true.
+build --define open_source_build=true
+test --define open_source_build=true
+
 # Please note that MKL on MacOS or windows is still not supported.
 # If you would like to use a local MKL instead of downloading, please set the
 # environment variable "TF_MKL_ROOT" every time before build.
diff --git a/CODEOWNERS b/CODEOWNERS
index f498440..e8bef10 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,6 +1,6 @@
 # Where component owners are known, add them here.
 
-/tensorflow/c/eager @jaingurav @alextp
+/tensorflow/c/eager @jaingaurav @alextp
 /tensorflow/core/common_runtime/eager @jaingaurav @alextp
 /tenosrflow/core/debug @caisq
 /tensorflow/core/nccl/ @azaks2 @chsigg
@@ -8,7 +8,7 @@
 /tensorflow/core/platform/s3 @yongtang
 /tensorflow/python/autograph/ @mdanatg @kkimdev
 /tensorflow/python/debug @caisq
-/tensorflow/python/eager @jaingurav @alextp
+/tensorflow/python/eager @jaingaurav @alextp
 /tensorflow/python/tools/api/generator/ @annarev
 /tensorflow/tensorboard/ @jart
 /tensorflow/tools/docs/ @markdaoust
@@ -28,7 +28,7 @@
 /tensorflow/contrib/data/ @mrry
 /tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
 /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
-/tensorflow/contrib/eager @jaingurav @alextp
+/tensorflow/contrib/eager @jaingaurav @alextp
 /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
 /tensorflow/contrib/ffmpeg/ @fredbertsch
 /tensorflow/contrib/framework/ @ebrevdo
diff --git a/RELEASE.md b/RELEASE.md
index 6a4c2d6..debbba7 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -43,6 +43,11 @@
 *   Transitive dependencies on :pooling_ops were removed. Some users may need to
     add explicit dependencies on :pooling_ops if they reference the operators
     from that library.
+*   tf.keras.optimizers default learning rate changes:
+    *   Adadelta: 1.000 to 0.001
+    *   Adagrad: 0.01 to 0.001
+    *   Adamax: 0.002 to 0.001
+    *   NAdam: 0.002 to 0.001
 
 ## Bug Fixes and Other Changes
 
diff --git a/WORKSPACE b/WORKSPACE
index d5bd495..74ea14d 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -49,10 +49,15 @@
 # Apple and Swift rules.
 http_archive(
     name = "build_bazel_rules_apple",
-    sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e",
-    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"],
+    sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
+    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
 )  # https://github.com/bazelbuild/rules_apple/releases
 http_archive(
+    name = "build_bazel_rules_swift",
+    sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
+    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
+)  # https://github.com/bazelbuild/rules_swift/releases
+http_archive(
     name = "build_bazel_apple_support",
     sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
     urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
@@ -63,11 +68,6 @@
     urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
 )  # https://github.com/bazelbuild/bazel-skylib/releases
 http_archive(
-    name = "build_bazel_rules_swift",
-    sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e",
-    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"],
-)  # https://github.com/bazelbuild/rules_swift/releases
-http_archive(
     name = "com_github_apple_swift_swift_protobuf",
     type = "zip",
     strip_prefix = "swift-protobuf-1.5.0/",
@@ -104,8 +104,7 @@
     build_file = "//:models.BUILD",
     sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
     urls = [
-        "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
-        "http://download.tensorflow.org/models/inception_v1.zip",
+        "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
     ],
 )
 
@@ -114,8 +113,7 @@
     build_file = "//:models.BUILD",
     sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
     urls = [
-        "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
-        "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
+        "https://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
     ],
 )
 
@@ -124,8 +122,7 @@
     build_file = "//:models.BUILD",
     sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
     urls = [
-        "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
-        "http://download.tensorflow.org/models/mobile_multibox_v1a.zip",
+        "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
     ],
 )
 
@@ -134,8 +131,7 @@
     build_file = "//:models.BUILD",
     sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
     urls = [
-        "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
-        "http://download.tensorflow.org/models/stylize_v1.zip",
+        "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
     ],
 )
 
@@ -144,7 +140,6 @@
     build_file = "//:models.BUILD",
     sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
     urls = [
-        "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
-        "http://download.tensorflow.org/models/speech_commands_v0.01.zip",
+        "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
     ],
 )
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 61539c5..6b86445 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -7,7 +7,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
 load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_binary_deps",
 )
 load(
@@ -356,6 +356,15 @@
     },
 )
 
+# Flag to indicate open source build, .bazelrc always has it set to be true
+config_setting(
+    name = "oss",
+    define_values = {
+        "open_source_build": "true",
+    },
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "using_cuda_clang_with_dynamic_build",
     define_values = {
@@ -364,11 +373,20 @@
     },
 )
 
+config_setting(
+    name = "build_oss_using_cuda_clang",
+    define_values = {
+        "using_cuda_clang": "true",
+        "open_source_build": "true",
+    },
+)
+
 # Setting to use when loading kernels dynamically
 config_setting(
     name = "dynamic_loaded_kernels",
     define_values = {
         "dynamic_loaded_kernels": "true",
+        "framework_shared_object": "true",
     },
     visibility = ["//visibility:public"],
 )
@@ -389,6 +407,14 @@
 )
 
 config_setting(
+    name = "build_oss_using_cuda_nvcc",
+    define_values = {
+        "using_cuda_nvcc": "true",
+        "open_source_build": "true",
+    },
+)
+
+config_setting(
     name = "using_rocm_hipcc",
     define_values = {
         "using_rocm_hipcc": "true",
@@ -607,6 +633,7 @@
         "//tensorflow/c:version_script.lds",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/core:tensorflow",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
     ],
 )
 
@@ -750,8 +777,8 @@
     mkdir $@
     for f in $(SRCS); do
       d="$${f%/*}"
-      d="$${d#bazel-out*genfiles/}"
-      d="$${d#*external/eigen_archive/}"
+      d="$${d#bazel-out/*/genfiles/}"
+      d="$${d#bazel-out/*/bin/}"
 
       if [[ $${d} == *local_config_* ]]; then
         continue
@@ -763,6 +790,9 @@
         if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
           continue
         fi
+
+        d="$${d#*external/farmhash_archive/src}"
+        d="$${d#*external/$${extname}/}"
       fi
 
       mkdir -p "$@/$${d}"
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index 6d1c40a..2962a7a 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -27,11 +27,27 @@
 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 from tensorflow.python.tools import module_util as _module_util
+from tensorflow.python.platform import tf_logging as _logging
 
 # API IMPORTS PLACEHOLDER
 
 # WRAPPER_PLACEHOLDER
 
+if "dev" in __version__:   # pylint: disable=undefined-variable
+  _logging.warning("""
+
+  TensorFlow's `tf-nightly` package will soon be updated to TensorFlow 2.0.
+
+  Please upgrade your code to TensorFlow 2.0:
+    * https://www.tensorflow.org/beta/guide/migration_guide
+
+  Or install the latest stable TensorFlow 1.X release:
+    * `pip install -U "tensorflow==1.*"`
+
+  Otherwise your code may be broken by the change.
+
+  """)
+
 # Make sure directory containing top level submodules is in
 # the __path__ so that "from tensorflow.foo import bar" works.
 # We're using bitwise, but there's nothing special about that.
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index dd5a3a0..f740ba6 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -73,7 +73,7 @@
             "//tensorflow/core:core_cpu",
             "//tensorflow/core:framework",
             "//tensorflow/core:lib",
-            "//tensorflow/core:lib_platform",
+            "//tensorflow/core/platform:platform",
             "//tensorflow/core:op_gen_lib",
             "//tensorflow/core/distributed_runtime:server_lib",
         ],
@@ -264,10 +264,10 @@
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core:lib_platform",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime/eager:attr_builder",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/platform",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -355,6 +355,7 @@
     deps = [
         ":tf_status",
         ":tf_status_helper",
+        ":tf_tensor_internal",
     ] + select({
         "//tensorflow:android": [
             ":c_api_internal",
@@ -503,6 +504,7 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -579,7 +581,7 @@
         "//tensorflow:macos": ["-headerpad_max_install_names"],
         "//conditions:default": [],
     }),
-    tags = ["noasan"],
+    tags = ["no_cuda_on_cpu_tap"],
     # We must ensure that the dependencies can be dynamically linked since
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),
@@ -592,6 +594,8 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/kernels:ops_testutil",
+        "//third_party/eigen3",
     ],
 )
 
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 62b2504..ed4f10e 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1024,7 +1024,7 @@
       desc->colocation_constraints.insert(location);
     }
   } else {
-    desc->node_builder.Attr(attr_name, attr_value);
+    desc->node_builder.Attr(attr_name, std::move(attr_value));
   }
 
   status->status = Status::OK();
@@ -1045,7 +1045,8 @@
           std::vector<string>(desc->colocation_constraints.begin(),
                               desc->colocation_constraints.end()));
     }
-    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
+    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
+                                                 /*consume=*/true);
 
     if (TF_GetCode(status) == TF_OK) {
       // Run shape inference function for newly added node.
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index b37d2e7..f04f017 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -598,7 +598,10 @@
 TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                             TF_Status* status) {
   TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
-  if (!status->status.ok()) return nullptr;
+  if (!status->status.ok()) {
+    TF_DeleteCheckpointReader(reader);
+    return nullptr;
+  }
   const auto& m = reader->GetVariableToDataTypeMap();
   for (auto it = m.begin(); it != m.end(); ++it)
     reader->variable_list.push_back(it->first);
@@ -1050,8 +1053,12 @@
   delete[] shape_list_array;
 }
 
+namespace tensorflow {
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+}  // namespace tensorflow
+
 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
-                     TF_Tensor** input_tensors, int num_input_tensors,
+                     TF_Tensor** input_tensors,
                      TF_ShapeAndTypeList* input_tensors_as_shapes,
                      TF_ShapeAndTypeList** input_resource_shapes_and_types,
                      TF_ShapeAndTypeList** output_shapes,
@@ -1079,10 +1086,30 @@
       tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
   if (!status->status.ok()) return;
 
+  // Initialize a input_tensor vector with `nullptr` values.
+  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
+  // A vector to keep track of newly created `tf::Tensor` objects.
+  std::vector<Tensor> all_input_tensors;
+  // Update the vector with information from `input_tensors` if provided.
+  if (input_tensors != nullptr) {
+    // Note that we take the address of the elements in `all_input_tensors`
+    // below. Allocate enough space so that no reallocation happens, which will
+    // make the pointers invalid.
+    all_input_tensors.reserve(num_inputs);
+    for (int i = 0; i < num_inputs; ++i) {
+      if (input_tensors[i] == nullptr) continue;
+      all_input_tensors.emplace_back();
+      Tensor& input_tensor = all_input_tensors.back();
+      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
+      if (!status->status.ok()) return;
+      input_tensors_vector[i] = &input_tensor;
+    }
+  }
+
   // Create an inference context with dummy values, which will be updated later.
   InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
-                     std::vector<ShapeHandle>(num_inputs),
-                     std::vector<const Tensor*>(num_inputs, nullptr), {},
+                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
+                     {},
                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
 
   // Set input_shapes.
@@ -1099,7 +1126,6 @@
     c.SetInput(i, c.MakeShape(dims));
   }
 
-  // TODO(bgogul): Handle input_tensors.
   // TODO(bgogul): Handle input_tensors_as_shapes.
   // TODO(bgogul): Handle input_resource_shapes_and_types.
 
@@ -1136,3 +1162,8 @@
 
   // TODO(bgogul): Set output_resource_shapes_and_types.
 }
+
+void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
+    TF_ImportGraphDefOptions* opts, unsigned char enable) {
+  opts->opts.validate_colocation_constraints = enable;
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 36028fd..126db26 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -378,19 +378,30 @@
 TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray(
     TF_ShapeAndTypeList** shape_list_array, int num_items);
 
-// Infer shapes for the given `node_def`. The arguments mimic the arguments of
-// the `shape_inference::InferenceContext` constructor. The types need not be
-// set in `input_shapes` as it is not used for shape inference.
+// Infer shapes for the given `op`. The arguments mimic the arguments of the
+// `shape_inference::InferenceContext` constructor. Note the following:
+//   - The inputs of the `op` are not used for shape inference. So, it is
+//     OK to not have the inputs properly set in `op`. See `input_tensors`
+//     if you want shape inference to consider the input tensors of the
+//     op for shape inference.
+//   - The types need not be set in `input_shapes` as it is not used.
+//   - The number of `input_tensors` should be the same as the number of items
+//     in `input_shapes`.
 //
 // The results are returned in `output_shapes` and
 // `output_resource_shapes_and_types`. The caller is responsible for freeing the
 // memory in these buffers by calling `TF_DeleteShapeAndTypeList`.
 TF_CAPI_EXPORT extern void TFE_InferShapes(
     TFE_Op* op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors,
-    int num_input_tensors, TF_ShapeAndTypeList* input_tensor_as_shapes,
+    TF_ShapeAndTypeList* input_tensor_as_shapes,
     TF_ShapeAndTypeList** input_resource_shapes_and_types,
     TF_ShapeAndTypeList** output_shapes,
     TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status);
+
+TF_CAPI_EXPORT extern void
+TF_ImportGraphDefOptionsSetValidateColocationConstraints(
+    TF_ImportGraphDefOptions* opts, unsigned char enable);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index f4f6753..ed0ab7c 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include "tensorflow/c/c_api_experimental.h"
+
+#include "absl/types/optional.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/c_test_util.h"
 #include "tensorflow/c/eager/c_api.h"
@@ -437,86 +439,149 @@
       : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
     tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
     CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
-    matmul_op_ = TFE_NewOp(tfe_context_, "MatMul", status_);
-    CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
   }
 
   ~ShapeInferenceTest() override {
-    TFE_DeleteOp(matmul_op_);
     TFE_DeleteContextOptions(tfe_context_options_);
     TFE_DeleteContext(tfe_context_);
     TF_DeleteStatus(status_);
   }
 
-  void infer_matmul_shapes(TF_ShapeAndTypeList* input_shapes,
-                           int64_t expected_rank, int64_t expected_first_dim,
-                           int64_t expected_second_dim) {
+  // Checks the expected result of shape inference for the given `op`.
+  void CheckOutputShapes(
+      TFE_Op* op,
+      const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
+      const std::vector<TF_Tensor*>& input_tensors,
+      const absl::optional<std::vector<int64_t>>& expected_shape) {
+    // Create input_shapes.
+    TF_ShapeAndTypeList* input_shapes =
+        TF_NewShapeAndTypeList(input_shapes_vec.size());
+    for (size_t i = 0; i < input_shapes_vec.size(); ++i) {
+      const auto& input_shape = input_shapes_vec[i];
+      if (input_shape.has_value()) {
+        TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(),
+                                    input_shape->size());
+      } else {
+        TF_ShapeAndTypeListSetUnknownShape(input_shapes, i);
+      }
+    }
     TF_ShapeAndTypeList* output_shapes;
-    TFE_InferShapes(matmul_op_, input_shapes,
-                    /*input_tensors*/ nullptr, /*num_input_tensors*/ 0,
+    TFE_InferShapes(op, input_shapes,
+                    input_tensors.empty()
+                        ? nullptr
+                        : const_cast<TF_Tensor**>(input_tensors.data()),
                     /*input_tensors_as_shapes*/ nullptr,
                     /*input_resource_shapes_and_types*/ nullptr, &output_shapes,
                     /*output_resource_shapes_and_types*/ nullptr, status_);
     CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
     CHECK_EQ(output_shapes->num_items, 1);
-    EXPECT_EQ(output_shapes->items[0].num_dims, expected_rank);
-    if (expected_rank == 2) {
-      EXPECT_EQ(output_shapes->items[0].dims[0], expected_first_dim);
-      EXPECT_EQ(output_shapes->items[0].dims[1], expected_second_dim);
+
+    int num_dims = output_shapes->items[0].num_dims;
+    int64_t* dims = output_shapes->items[0].dims;
+
+    if (!expected_shape.has_value()) {
+      EXPECT_EQ(num_dims, -1);
+      EXPECT_EQ(dims, nullptr);
+      return;
+    }
+
+    EXPECT_EQ(num_dims, expected_shape->size());
+    for (size_t i = 0; i < num_dims; ++i) {
+      EXPECT_EQ(dims[i], (*expected_shape)[i]);
     }
     TF_DeleteShapeAndTypeList(input_shapes);
     TF_DeleteShapeAndTypeList(output_shapes);
   }
 
+  absl::optional<std::vector<int64_t>> make_shape(
+      std::vector<int64_t>&& dims) const {
+    return absl::make_optional(dims);
+  }
+
+  absl::optional<std::vector<int64_t>> unknown_shape() const {
+    return absl::nullopt;
+  }
+
+  static constexpr int64_t kUnknownDim =
+      shape_inference::InferenceContext::kUnknownDim;
   TF_Status* status_;
   TFE_ContextOptions* tfe_context_options_;
   TFE_Context* tfe_context_;
-  TFE_Op* matmul_op_;
 };
 
-TEST_F(ShapeInferenceTest, InfersShapes) {
+TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
+  TFE_Op* matmul_op;
+  matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+
   // Infer shape when everything is known.
-  int64_t _3by2[] = {3, 2};
-  int64_t _2by4[] = {2, 4};
-  TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_shapes*/ 2);
-  TF_ShapeAndTypeListSetShape(input_shapes, 0, _3by2, 2);
-  TF_ShapeAndTypeListSetShape(input_shapes, 1, _2by4, 2);
-  infer_matmul_shapes(input_shapes, /*expected_rank*/ 2,
-                      /*expected_first_dim*/ 3, /*expected_second_dim*/ 4);
+  CheckOutputShapes(matmul_op,
+                    /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
+                    /*input_tensors*/ {},
+                    /*expected_shape*/ make_shape({3, 4}));
 
   // Infer shape when second operand has unknown shape.
-  TF_ShapeAndTypeList* input_shapes_unknown_second =
-      TF_NewShapeAndTypeList(/*num_shapes*/ 2);
-  TF_ShapeAndTypeListSetShape(input_shapes_unknown_second, 0, _3by2, 2);
-  TF_ShapeAndTypeListSetUnknownShape(input_shapes_unknown_second, 1);
-  infer_matmul_shapes(
-      input_shapes_unknown_second, /*expected_rank*/ 2,
-      /*expected_first_dim*/ 3,
-      /*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
+  CheckOutputShapes(matmul_op,
+                    /*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
+                    /*input_tensors*/ {},
+                    /*expected_shape*/ make_shape({3, kUnknownDim}));
 
   // Infer shape when some dimensions are unknown.
-  int64_t _unknownby2[] = {-1, 2};
-  TF_ShapeAndTypeList* input_shapes_unknown_dims =
-      TF_NewShapeAndTypeList(/*num_shapes*/ 2);
-  TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 0, _unknownby2, 2);
-  TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 1, _2by4, 2);
-  infer_matmul_shapes(
-      input_shapes_unknown_dims, /*expected_rank*/ 2,
-      /*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
-      /*expected_second_dim*/ 4);
+  CheckOutputShapes(
+      matmul_op,
+      /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
+      /*input_tensors*/ {},
+      /*expected_shape*/ make_shape({kUnknownDim, 4}));
 
   // Infer shape when everything is unknown.
-  TF_ShapeAndTypeList* unknown_shapes =
-      TF_NewShapeAndTypeList(/*num_shapes*/ 2);
-  TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 0);
-  TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 1);
-  infer_matmul_shapes(
-      unknown_shapes, /*expected_rank*/ 2,
-      /*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
-      /*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
+  CheckOutputShapes(matmul_op,
+                    /*input_shapes*/ {unknown_shape(), unknown_shape()},
+                    /*input_tensors*/ {},
+                    /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
 
+  TFE_DeleteOp(matmul_op);
   // TODO(bgogul): Add some death tests where status is not OK.
 }
 
+TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
+  // Prepare some tensors for shape.
+  TF_Tensor* tensor_1X6 = Int32Tensor({1, 6});
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6});
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+
+  TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT);
+  TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32);
+  CheckOutputShapes(reshape_op,
+                    /* input_shapes*/ {unknown_shape(), unknown_shape()},
+                    /* input_tensors*/ {nullptr, tensor_1X6},
+                    /*expected_shape*/ make_shape({1, 6}));
+  TFE_DeleteOp(reshape_op);
+  reshape_op = nullptr;
+
+  TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
+  TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
+
+  float five = 5.0;
+  TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
+  TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  CheckOutputShapes(fill_op,
+                    /* input_shapes*/ {unknown_shape(), unknown_shape()},
+                    /* input_tensors*/ {tensor_1X1X6, scalarTensor},
+                    /*expected_shape*/ make_shape({1, 1, 6}));
+  TFE_DeleteOp(fill_op);
+  fill_op = nullptr;
+
+  TFE_DeleteTensorHandle(scalar);
+  TF_DeleteTensor(scalarTensor);
+  TF_DeleteTensor(tensor_1X1X6);
+  TF_DeleteTensor(tensor_1X6);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 4907603..ddf1f46 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/tf_status.h"
 #include "tensorflow/cc/saved_model/signature_constants.h"
 #include "tensorflow/cc/saved_model/tag_constants.h"
 #include "tensorflow/core/example/example.pb.h"
@@ -233,7 +234,7 @@
     // Create C++ Tensor
     Tensor src(tensorflow::DT_STRING, TensorShape(dims));
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
-      src.flat<string>()(i) = data[i];
+      src.flat<tstring>()(i) = data[i];
     }
     TF_Tensor* dst = TF_TensorFromTensor(src, status);
     ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -243,7 +244,7 @@
     ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
     ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
-      ASSERT_EQ(data[i], output.flat<string>()(i)) << line;
+      ASSERT_EQ(data[i], output.flat<tstring>()(i)) << line;
     }
 
     TF_DeleteTensor(dst);
@@ -1385,7 +1386,7 @@
     tensorflow::Example example;
     auto* feature_map = example.mutable_features()->mutable_feature();
     (*feature_map)["x"].mutable_float_list()->add_value(i);
-    input.flat<string>()(i) = example.SerializeAsString();
+    input.flat<tstring>()(i) = example.SerializeAsString();
   }
 
   const tensorflow::string input_op_name(
@@ -2498,6 +2499,38 @@
 
 #undef EXPECT_TF_META
 
+TEST(CAPI, TestTensorAligned) {
+  int64_t dim = 7;
+  size_t tensor_size_bytes = dim * TF_DataTypeSize(TF_FLOAT);
+  TF_Tensor* a = TF_AllocateTensor(
+      /*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1,
+      /*len=*/tensor_size_bytes);
+  float* data = reinterpret_cast<float*>(TF_TensorData(a));
+  for (int i = 0; i < dim; ++i) {
+    data[i] = 0;
+  }
+  if (EIGEN_MAX_ALIGN_BYTES > 0) {
+    EXPECT_TRUE(TF_TensorIsAligned(a));
+  }
+  TF_DeleteTensor(a);
+}
+
+TEST(CAPI, TestTensorIsNotAligned) {
+  // Test unaligned access via a Slice.
+  Tensor x(DT_FLOAT, TensorShape({30}));
+  x.flat<float>().setConstant(0.0);
+
+  // Take an unaligned slice.
+  Tensor y = x.Slice(1, 13);
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* a = TF_TensorFromTensor(y, status);
+  if (EIGEN_MAX_ALIGN_BYTES > 0) {
+    EXPECT_FALSE(TF_TensorIsAligned(a));
+  }
+  TF_DeleteStatus(status);
+  TF_DeleteTensor(a);
+}
+
 }  // namespace
 }  // namespace tensorflow
 
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 7eddc17..5c42e50 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -8,12 +8,12 @@
     "tfe_xla_copts",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_device_tracer_test_flags",
     "tf_kernel_tests_linkstatic",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
@@ -156,6 +156,7 @@
     ],
     deps = [
         ":c_api",
+        ":c_api_experimental",
         ":c_api_internal",
         ":c_api_test_util",
         "//tensorflow/c:c_test_util",
@@ -235,9 +236,11 @@
     ],
     args =
         ["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
+    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
+        ":c_api",
         ":c_api_experimental",
         ":c_api_test_util",
         "//tensorflow/c:c_test_util",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 22c1f21..49d2891 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -202,9 +202,11 @@
         "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
   }
 
-  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
-
-  tensorflow::uint64 context_id = tensorflow::random::New64();
+  tensorflow::uint64 context_id = tensorflow::EagerContext::NewContextId();
+  // Make master eager context accessible by local eager service, which might
+  // receive send tensor requests from remote workers.
+  LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
+      context_id, ctx->context));
 
   std::vector<string> remote_workers;
   grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
@@ -240,9 +242,11 @@
           &remote_eager_workers));
 
   // Initialize remote eager workers.
-  LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-      remote_workers, context_id, keep_alive_secs, server_def,
-      remote_eager_workers.get(), ctx->context->Async(), base_request));
+  // TODO(b/138847548) Create remote eager contexts in async mode by default.
+  LOG_AND_RETURN_IF_ERROR(
+      CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
+                           server_def, remote_eager_workers.get(),
+                           ctx->context->Executor()->Async(), base_request));
 
   tensorflow::RemoteRendezvous* r =
       grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
@@ -261,15 +265,21 @@
   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
 
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
-  auto remote_mgr =
-      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true);
+  auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
+      /*is_master=*/true, ctx->context);
 
-  return ctx->context->InitializeRemoteMaster(
+  LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
       std::move(server), grpc_server->worker_env(), worker_session,
       std::move(remote_eager_workers), std::move(remote_device_mgr),
       remote_workers, context_id, r, device_mgr, keep_alive_secs,
-      worker_session->cluster_flr.get(), std::move(remote_mgr));
+      worker_session->cluster_flr.get(), std::move(remote_mgr)));
+
+  // NOTE: We start the server after all other initialization, because the
+  // GrpcServer cannot be destroyed after it is started.
+  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
 #undef LOG_AND_RETURN_IF_ERROR
+
+  return tensorflow::Status::OK();
 }
 #endif  // !IS_MOBILE_PLATFORM
 
@@ -365,12 +375,6 @@
   options->device_placement_policy = policy;
 }
 
-TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
-                                                        unsigned char enable,
-                                                        TF_Status* status) {
-  status->status = ctx->context->SetAsyncForThread(enable);
-}
-
 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
@@ -455,18 +459,6 @@
       ctx->context->GetDevicePlacementPolicy());
 }
 
-void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
-  status->status = ctx->context->AsyncWait();
-}
-
-void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
-  status->status = ctx->context->GetStatus();
-}
-
-void TFE_ContextAsyncClearError(TFE_Context* ctx) {
-  ctx->context->ClearAsyncError();
-}
-
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -571,7 +563,8 @@
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* h_cpu = nullptr;
     status->status = EagerCopyToDevice(
-        handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
+        handle, handle->Context(), handle->Context()->Executor(),
+        handle->Context()->HostCPU(), false, &h_cpu);
     if (!status->status.ok()) {
       return nullptr;
     }
@@ -671,7 +664,7 @@
 
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                               unsigned char* is_list, TF_Status* status) {
-  TF_AttrType ret;
+  TF_AttrType ret = TF_ATTR_INT;
   status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
                                               attr_name, &ret, is_list);
   return ret;
@@ -683,10 +676,11 @@
                                   TF_Status* status) {
   TF_AttrType ret;
   TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
-  if (!status->status.ok()) {
-    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
+  if (status->status.ok()) {
+    ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
+  } else {
+    ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
   }
-  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
   TFE_DeleteOp(op);
   return ret;
 }
@@ -922,6 +916,7 @@
     return nullptr;
   }
   status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
+                                                 ctx->context->Executor(),
                                                  device, false, &handle);
   if (status->status.ok()) {
     return new TFE_TensorHandle(handle);
@@ -974,7 +969,7 @@
 
 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
-  TFE_ContextAsyncWait(ctx, status);
+  status->status = ctx->context->Executor()->WaitForAllPendingNodes();
   if (!status->status.ok()) return;
   tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
   status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index f685011..cf534c0 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -77,7 +77,7 @@
 // LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
 
 // Sets the default execution mode (sync/async). Note that this can be
-// overridden per thread using TFE_ContextSetAsyncForThread.
+// overridden per thread using TFE_ContextSetExecutorForThread.
 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
                                                       unsigned char enable);
 
@@ -89,6 +89,9 @@
 
 // "Context" under which operations/functions are executed. It encapsulates
 // things like the available devices, resource manager etc.
+// TFE_Context must outlive all tensor handles created using it. In other
+// words, TFE_DeleteContext() must be called after all tensor handles have
+// been deleted (with TFE_DeleteTensorHandle).
 //
 // TODO(ashankar): Merge with TF_Session?
 typedef struct TFE_Context TFE_Context;
@@ -115,11 +118,6 @@
 TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
 TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx);
 
-// Overrides the execution mode (sync/async) for the current thread.
-TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
-                                                        unsigned char enable,
-                                                        TF_Status* status);
-
 // A tensorflow.ServerDef specifies remote workers (in addition to the current
 // workers name). Operations created on this context can then be executed on
 // any of these remote workers by setting an appropriate device.
@@ -132,24 +130,6 @@
                                                    size_t proto_len,
                                                    TF_Status* status);
 
-// Causes the calling thread to block till all ops dispatched in async mode
-// have been executed. Note that "execution" here refers to kernel execution /
-// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
-// that lower level device queues (like GPU streams) have been flushed.
-//
-// This call may not block for execution of ops enqueued concurrently with this
-// call.
-TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*,
-                                                TF_Status* status);
-
-// When an error happens, any pending operations are discarded and newly issued
-// ops return an error. This call clears the error state and re-enables
-// execution of newly issued ops.
-//
-// Note that outputs of discarded ops remain in a corrupt state and should not
-// be used for future calls.
-// TODO(agarwal): mark the affected handles and raise errors if they are used.
-TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*);
 
 // A handle to a tensor on a device.
 //
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 32f28a0..a9ad771 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -32,9 +32,7 @@
   op->operation.ConsumeInput(h->handle);
 }
 
-TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) {
-  return new TFE_Profiler(ctx);
-}
+TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
 
 bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
   return profiler->profiler->Status().ok();
@@ -55,23 +53,10 @@
   };
 }
 
-TFE_ProfilerContext* TFE_NewProfilerContext() {
-  return new TFE_ProfilerContext;
-}
-
-void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
-                                        TFE_Context* eager_context) {
-  profiler_context->profiler_context.eager_context = eager_context->context;
-}
-
-void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
-  delete profiler_context;
-}
-
-void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
-  // Release child thread intentionally. The child thread can be terminate by
+void TFE_StartProfilerServer(int port) {
+  // Release child thread intentionally. The child thread can be terminated by
   // terminating the main thread.
-  tensorflow::StartProfilerServer(&context->profiler_context, port).release();
+  tensorflow::StartProfilerServer(port).release();
 }
 
 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
@@ -587,3 +572,30 @@
   op->operation.SetCancellationManager(
       &cancellation_manager->cancellation_manager);
 }
+
+TFE_Executor* TFE_NewExecutor(bool is_async) {
+  return new TFE_Executor(is_async);
+}
+
+void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
+
+bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
+  return executor->executor()->Async();
+}
+
+void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
+                                        TF_Status* status) {
+  status->status = executor->executor()->WaitForAllPendingNodes();
+}
+
+void TFE_ExecutorClearError(TFE_Executor* executor) {
+  executor->executor()->ClearError();
+}
+
+void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
+  ctx->context->SetExecutorForThread(executor->executor());
+}
+
+TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
+  return new TFE_Executor(ctx->context->Executor());
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index cdf1492..e5a9459 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -25,8 +25,6 @@
 TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
                                               TF_Status* status);
 
-typedef struct TFE_ProfilerContext TFE_ProfilerContext;
-
 // A profiler which will start profiling when creating the object and will stop
 // when the object is destroyed. It will profile all operations run under the
 // given TFE_Context. Multiple instance of it can be created, but at most one
@@ -34,7 +32,7 @@
 // Thread-safety: TFE_Profiler is thread-safe.
 typedef struct TFE_Profiler TFE_Profiler;
 
-TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx);
+TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
 TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
 TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
 
@@ -44,27 +42,14 @@
                                                          TF_Buffer* buf,
                                                          TF_Status* status);
 
-// Return a new profiler context object.
-TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void);
-
-// Set the eager context in TFE_ProfilerServerOptions
-TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext(
-    TFE_ProfilerContext* profiler_context, TFE_Context* eager_context);
-
-// Destroy a profiler context object.
-TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext(
-    TFE_ProfilerContext* profiler_context);
-
 // Start a profiler grpc server which listens to specified port. It will start
 // the server on its own thread. It can be shutdown by terminating tensorflow.
 // It can be used in both Eager mode and graph mode. Creating multiple profiler
 // server is allowed. The service defined in
 // tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
-// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable
-// file following
-// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
-TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context,
-                                                   int port);
+// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
+// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
+TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
 
 // Enables only graph collection in RunMetadata on the functions executed from
 // this context.
@@ -367,6 +352,51 @@
     TFE_Op* op, TFE_CancellationManager* cancellation_manager,
     TF_Status* status);
 
+// -----------------------------------------------------------------------------
+// Eager Executor APIs.
+typedef struct TFE_Executor TFE_Executor;
+
+// Creates a new eager Executor. Nodes in one executor are guaranteed to be
+// executed in sequence. Assigning nodes to different executors allows executing
+// nodes in parallel.
+TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async);
+
+// Deletes the eager Executor without waiting for enqueued nodes. Please call
+// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
+// make sure all nodes are finished.
+TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*);
+
+// Returns true if the executor is in async mode.
+TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*);
+
+// Causes the calling thread to block till all ops dispatched in this executor
+// have been executed. Note that "execution" here refers to kernel execution /
+// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+// that lower level device queues (like GPU streams) have been flushed.
+//
+// This call may not block for execution of ops enqueued concurrently with this
+// call.
+TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes(
+    TFE_Executor*, TF_Status* status);
+
+// When an error happens, any pending operations are discarded and newly issued
+// ops return an error. This call clears the error state and re-enables
+// execution of newly issued ops.
+//
+// Note that outputs of discarded ops remain in a corrupt state and should not
+// be used for future calls.
+// TODO(agarwal): mark the affected handles and raise errors if they are used.
+TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*);
+
+// Sets a custom Executor for current thread. All nodes created by this thread
+// will be added to this Executor. It will override current executor.
+TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*,
+                                                           TFE_Executor*);
+
+// Returns the Executor for current thread.
+TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread(
+    TFE_Context*);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index 249d6c8..bbe7cae 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -17,6 +17,7 @@
 
 #include <string.h>
 
+#include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/cc/profiler/profiler.h"
 #include "tensorflow/core/lib/monitoring/collection_registry.h"
@@ -43,12 +44,9 @@
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
-  TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext();
-  TFE_ProfilerContextSetEagerContext(profiler_context, ctx);
-  TFE_Profiler* profiler = TFE_NewProfiler(profiler_context);
+  TFE_Profiler* profiler = TFE_NewProfiler();
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
-  TFE_DeleteProfilerContext(profiler_context);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -71,8 +69,10 @@
   ASSERT_EQ(1, num_retvals);
   TF_Buffer* profiler_result = TF_NewBuffer();
   if (async) {
-    TFE_ContextAsyncWait(ctx, status);
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
     ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
   }
   TFE_ProfilerSerializeToString(profiler, profiler_result, status);
   TFE_DeleteProfiler(profiler);
@@ -110,27 +110,14 @@
 TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
 
 TEST(CAPI, MultipleProfilerSession) {
-  TF_Status* status = TF_NewStatus();
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-
-  TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext();
-  TFE_ProfilerContextSetEagerContext(profiler_context, ctx);
-
-  TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context);
+  TFE_Profiler* profiler1 = TFE_NewProfiler();
   EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
 
-  TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context);
+  TFE_Profiler* profiler2 = TFE_NewProfiler();
   EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
 
   TFE_DeleteProfiler(profiler1);
   TFE_DeleteProfiler(profiler2);
-  TFE_DeleteProfilerContext(profiler_context);
-  TFE_DeleteContext(ctx);
-  TF_DeleteStatus(status);
 }
 
 TEST(CAPI, MonitoringCounter0) {
@@ -307,5 +294,205 @@
   TFE_DeleteCancellationManager(c_mgr);
 }
 
+TEST(CAPI, Function_ident_CPU) {
+  // First create a simple identity function.
+  TF_Graph* function_graph = TF_NewGraph();
+  TF_OperationDescription* arg_descr =
+      TF_NewOperation(function_graph, "Placeholder", "arg");
+  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
+  TF_Status* status = TF_NewStatus();
+  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_OperationDescription* id_descr =
+      TF_NewOperation(function_graph, "Identity", "id");
+  TF_SetAttrType(id_descr, "T", TF_INT32);
+  TF_AddInput(id_descr, {arg, 0});
+  TF_Operation* id = TF_FinishOperation(id_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_Output input{arg, 0};
+  TF_Output output{id, 0};
+  TF_Function* fn =
+      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
+                         &output, nullptr, nullptr, "test", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteGraph(function_graph);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+  TFE_ContextAddFunction(ctx, fn, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteFunction(fn);
+
+  for (bool async : {false, true, false}) {
+    TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_Executor* executor = TFE_NewExecutor(async);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    TF_Tensor* t =
+        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TF_DeleteTensor(t);
+
+    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_OpAddInput(op, h, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+
+    std::vector<TFE_TensorHandle*> result;
+    result.push_back(nullptr);
+    int num_retvals = 1;
+    TFE_Execute(op, result.data(), &num_retvals, status);
+    TFE_DeleteOp(op);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    ASSERT_EQ(num_retvals, 1);
+
+    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+    TFE_ContextSetExecutorForThread(ctx, old_executor);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
+    TFE_DeleteExecutor(old_executor);
+    TFE_DeleteTensorHandle(h);
+    TF_DeleteTensor(r);
+    TFE_DeleteTensorHandle(result[0]);
+  }
+  TFE_ContextRemoveFunction(ctx, "ident", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContext(ctx);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST(CAPI, Function_ident_XLA_CPU) {
+  // First create a simple identity function.
+  TF_Graph* function_graph = TF_NewGraph();
+  TF_OperationDescription* arg_descr =
+      TF_NewOperation(function_graph, "Placeholder", "arg");
+  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
+  TF_Status* status = TF_NewStatus();
+  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_OperationDescription* id_descr =
+      TF_NewOperation(function_graph, "Identity", "id");
+  TF_SetAttrType(id_descr, "T", TF_INT32);
+  TF_AddInput(id_descr, {arg, 0});
+  TF_Operation* id = TF_FinishOperation(id_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_Output input{arg, 0};
+  TF_Output output{id, 0};
+  TF_Function* fn =
+      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
+                         &output, nullptr, nullptr, "test", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteGraph(function_graph);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+  TFE_ContextAddFunction(ctx, fn, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteFunction(fn);
+
+  for (bool async : {false, true, false}) {
+    TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_Executor* executor = TFE_NewExecutor(async);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
+    TF_Tensor* t =
+        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TF_DeleteTensor(t);
+
+    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_OpAddInput(op, h, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+
+    // Now run it via XLA.
+    TFE_OpSetXLACompilation(op, true);
+
+    std::vector<TFE_TensorHandle*> result;
+    result.push_back(nullptr);
+    int num_retvals = 1;
+    TFE_Execute(op, result.data(), &num_retvals, status);
+    TFE_DeleteOp(op);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    ASSERT_EQ(num_retvals, 1);
+
+    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+    TFE_ContextSetExecutorForThread(ctx, old_executor);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
+    TFE_DeleteExecutor(old_executor);
+    TFE_DeleteTensorHandle(h);
+    TF_DeleteTensor(r);
+    TFE_DeleteTensorHandle(result[0]);
+  }
+  TFE_ContextRemoveFunction(ctx, "ident", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContext(ctx);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+#endif  // TENSORFLOW_EAGER_USE_XLA
+
+void Executor_MatMul_CPU(bool async) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_Executor* executor = TFE_NewExecutor(async);
+  TFE_ContextSetExecutorForThread(ctx, executor);
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+  TFE_Op* matmul = MatMulOp(ctx, m, m);
+  TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
+  int num_retvals = 2;
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(1, num_retvals);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(matmul);
+  TFE_DeleteTensorHandle(m);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  TFE_ContextSetExecutorForThread(ctx, old_executor);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
+  TFE_DeleteExecutor(old_executor);
+  TFE_DeleteContext(ctx);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(7, product[0]);
+  EXPECT_EQ(10, product[1]);
+  EXPECT_EQ(15, product[2]);
+  EXPECT_EQ(22, product[3]);
+  TF_DeleteStatus(status);
+}
+TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
+TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index fe0c952..5efed2c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -76,7 +76,14 @@
             async, device_mgr, device_mgr_owned, rendezvous,
             custom_kernel_creator)) {}
 
-  ~TFE_Context() { context->Unref(); }
+  ~TFE_Context() {
+    // TODO(iga): Add a separate API method to shutdown TFE_Context so that we
+    // don't send RPCs and block in destructor.
+    context->WaitForAndCloseRemoteContexts();
+    // context->RefCountIsOne() should be true here.
+    // TODO(iga): Remove EagerContext refcounting.
+    context->Unref();
+  }
 
   tensorflow::EagerContext* context;
 };
@@ -130,14 +137,8 @@
   std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
 };
 
-struct TFE_ProfilerContext {
-  tensorflow::ProfilerContext profiler_context;
-};
-
 struct TFE_Profiler {
-  explicit TFE_Profiler(TFE_ProfilerContext* ctx) {
-    profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context);
-  }
+  explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
 
   std::unique_ptr<tensorflow::ProfilerSession> profiler;
 };
@@ -291,4 +292,19 @@
   tensorflow::CancellationManager cancellation_manager;
 };
 
+struct TFE_Executor {
+  explicit TFE_Executor(bool async)
+      : owned_executor(new tensorflow::EagerExecutor(async)) {}
+
+  explicit TFE_Executor(tensorflow::EagerExecutor* executor)
+      : owned_executor(nullptr), unowned_executor(executor) {}
+
+  tensorflow::EagerExecutor* executor() {
+    return owned_executor == nullptr ? unowned_executor : owned_executor.get();
+  }
+
+  std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
+  tensorflow::EagerExecutor* unowned_executor;
+};
+
 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 17df7bb..d3b755f 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -18,6 +18,7 @@
 #include <string.h>
 
 #include "absl/strings/match.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
@@ -78,7 +79,10 @@
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   if (async) {
-    TFE_ContextAsyncWait(ctx, status);
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(matmul);
@@ -110,7 +114,10 @@
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   if (async) {
-    TFE_ContextAsyncWait(ctx, status);
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(identity);
@@ -228,8 +235,10 @@
 
   TFE_DeleteOp(matmul);
 
-  TFE_ContextAsyncWait(ctx, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
   TFE_DeleteContext(ctx);
 
   TF_DeleteStatus(status);
@@ -314,9 +323,11 @@
 
   TFE_DeleteOp(matmul);
 
-  TFE_ContextAsyncWait(ctx, status);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
   TFE_DeleteContext(ctx);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   TF_DeleteStatus(status);
 
@@ -330,7 +341,7 @@
   TestRemoteExecuteSilentCopies(true);
 }
 
-void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
+void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
   // This server def has the task index set to 0.
@@ -356,33 +367,49 @@
   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
-  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+  // Use large matrices so that RPCs don't return before we get a chance
+  // to call TFE_DeleteContext.
+  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
+  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
   const char remote_device_name[] =
       "/job:localhost/replica:0/task:1/device:CPU:0";
   auto* h0_task1 =
       TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  auto* h1_task1 =
+      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
+  TFE_OpSetDevice(matmul, remote_device_name, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_TensorHandle* retvals[1];
+  int num_retvals = 1;
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
 
   TFE_DeleteTensorHandle(h0_task0);
-
-  TFE_ContextAsyncWait(ctx, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TFE_DeleteContext(ctx);
-
-  // Delete tensors after context is deleted.
+  TFE_DeleteTensorHandle(h1_task0);
   TFE_DeleteTensorHandle(h0_task1);
+  TFE_DeleteTensorHandle(h1_task1);
+  TFE_DeleteTensorHandle(retvals[0]);
 
-  TF_DeleteStatus(status);
+  TFE_DeleteOp(matmul);
+
+  TFE_DeleteContext(ctx);
 
   // TODO(b/136478427): Figure out how to correctly shut the server down.
   worker_server.release();
 }
 
-TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
-  TestRemoteExecuteDeleteTensorAfterContext(false);
+TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
+  TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
 }
-TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
-  TestRemoteExecuteDeleteTensorAfterContext(true);
+
+TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
+  TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
 }
 
 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
@@ -429,8 +456,10 @@
 
   TFE_DeleteOp(matmul);
 
-  TFE_ContextAsyncWait(ctx, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
   TF_DeleteStatus(status);
 }
 
@@ -465,8 +494,9 @@
       "/job:localhost/replica:0/task:0/device:CPU:0";
   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
 
-  TFE_ContextAsyncWait(ctx, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   // TODO(b/136478427): Figure out how to correctly shut the server down.
   worker_server.release();
@@ -508,8 +538,9 @@
   CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
                               new_local_device_name);
 
-  TFE_ContextAsyncWait(ctx, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
 
   TF_DeleteStatus(status);
 
@@ -642,8 +673,11 @@
   TFE_TensorHandle* hcopy =
       TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TFE_ContextAsyncWait(ctx, status.get());
-  EXPECT_EQ(TF_OK, TF_GetCode(status.get()));
+
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
   TFE_DeleteTensorHandle(hcopy);
   TFE_DeleteTensorHandle(hcpu);
   if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
@@ -772,8 +806,10 @@
 
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
-  TFE_ContextAsyncWait(ctx, status.get());
-  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
   TFE_DeleteContext(ctx);
 }
 
@@ -818,8 +854,10 @@
 
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
-  TFE_ContextAsyncWait(ctx, status.get());
-  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
   TFE_DeleteContext(ctx);
 }
 TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
@@ -953,8 +991,10 @@
   }
 
   TFE_DeleteTensorHandle(hcpu);
-  TFE_ContextAsyncWait(ctx, status.get());
-  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
   TFE_DeleteContext(ctx);
 }
 
@@ -1032,9 +1072,11 @@
     retvals[0] = nullptr;
     TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
     EXPECT_NE(TF_OK, TF_GetCode(status));
-    TFE_ContextAsyncClearError(ctx);
-    TFE_ContextAsyncWait(ctx, status);
-    EXPECT_EQ(TF_OK, TF_GetCode(status));
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorClearError(executor);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
   }
   // Following works in async mode since TFE_ContextAsyncClearError was called.
   TF_SetStatus(status, TF_OK, "");
@@ -1252,147 +1294,6 @@
 TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); }
 TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); }
 
-TEST(CAPI, Function_ident_CPU) {
-  // First create a simple identity function.
-  TF_Graph* function_graph = TF_NewGraph();
-  TF_OperationDescription* arg_descr =
-      TF_NewOperation(function_graph, "Placeholder", "arg");
-  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
-  TF_Status* status = TF_NewStatus();
-  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_OperationDescription* id_descr =
-      TF_NewOperation(function_graph, "Identity", "id");
-  TF_SetAttrType(id_descr, "T", TF_INT32);
-  TF_AddInput(id_descr, {arg, 0});
-  TF_Operation* id = TF_FinishOperation(id_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_Output input{arg, 0};
-  TF_Output output{id, 0};
-  TF_Function* fn =
-      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
-                         &output, nullptr, nullptr, "test", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteGraph(function_graph);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-  TFE_ContextAddFunction(ctx, fn, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteFunction(fn);
-
-  for (bool async : {false, true, false}) {
-    TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
-                                 status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
-    TF_Tensor* t =
-        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
-    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
-    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TF_DeleteTensor(t);
-
-    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TFE_OpAddInput(op, h, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-
-    std::vector<TFE_TensorHandle*> result;
-    result.push_back(nullptr);
-    int num_retvals = 1;
-    TFE_Execute(op, result.data(), &num_retvals, status);
-    TFE_DeleteOp(op);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    ASSERT_EQ(num_retvals, 1);
-
-    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
-    TFE_DeleteTensorHandle(h);
-    TF_DeleteTensor(r);
-    TFE_DeleteTensorHandle(result[0]);
-  }
-  TFE_ContextRemoveFunction(ctx, "ident", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContext(ctx);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteStatus(status);
-}
-
-#ifdef TENSORFLOW_EAGER_USE_XLA
-TEST(CAPI, Function_ident_XLA_CPU) {
-  // First create a simple identity function.
-  TF_Graph* function_graph = TF_NewGraph();
-  TF_OperationDescription* arg_descr =
-      TF_NewOperation(function_graph, "Placeholder", "arg");
-  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
-  TF_Status* status = TF_NewStatus();
-  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_OperationDescription* id_descr =
-      TF_NewOperation(function_graph, "Identity", "id");
-  TF_SetAttrType(id_descr, "T", TF_INT32);
-  TF_AddInput(id_descr, {arg, 0});
-  TF_Operation* id = TF_FinishOperation(id_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_Output input{arg, 0};
-  TF_Output output{id, 0};
-  TF_Function* fn =
-      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
-                         &output, nullptr, nullptr, "test", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteGraph(function_graph);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-  TFE_ContextAddFunction(ctx, fn, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteFunction(fn);
-
-  for (bool async : {false, true, false}) {
-    TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
-                                 status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
-    TF_Tensor* t =
-        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
-    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
-    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TF_DeleteTensor(t);
-
-    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TFE_OpAddInput(op, h, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-
-    // Now run it via XLA.
-    TFE_OpSetXLACompilation(op, true);
-
-    std::vector<TFE_TensorHandle*> result;
-    result.push_back(nullptr);
-    int num_retvals = 1;
-    TFE_Execute(op, result.data(), &num_retvals, status);
-    TFE_DeleteOp(op);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    ASSERT_EQ(num_retvals, 1);
-
-    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
-    TFE_DeleteTensorHandle(h);
-    TF_DeleteTensor(r);
-    TFE_DeleteTensorHandle(result[0]);
-  }
-  TFE_ContextRemoveFunction(ctx, "ident", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContext(ctx);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteStatus(status);
-}
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
 string MatMulFunction() {
   tensorflow::FunctionDef def;
   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
@@ -1506,7 +1407,10 @@
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   if (async) {
-    TFE_ContextAsyncWait(ctx, status);
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_DeleteExecutor(executor);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteTensorHandle(m);
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 10d95e6..51566b3 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -85,6 +85,24 @@
   return th;
 }
 
+TFE_TensorHandle* TestMatrixTensorHandle100x100() {
+  constexpr int64_t dims[] = {100, 100};
+  constexpr int num_elements = dims[0] * dims[1];
+  float data[num_elements];
+  for (int i = 0; i < num_elements; ++i) {
+    data[i] = 1.0f;
+  }
+  TF_Tensor* t = TF_AllocateTensor(
+      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TF_Status* status = TF_NewStatus();
+  TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
   int64_t dims[] = {3, 2};
   double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index d0c20ac..2806222 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -16,7 +16,6 @@
 #define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
 
 #include "tensorflow/c/eager/c_api.h"
-
 #include "tensorflow/core/platform/types.h"
 
 // Return a tensor handle containing a float scalar
@@ -34,6 +33,9 @@
 // Return a tensor handle containing a 2x2 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle();
 
+// Return a tensor handle containing a 100x100 matrix of floats
+TFE_TensorHandle* TestMatrixTensorHandle100x100();
+
 // Return a tensor handle containing a 3x2 matrix of doubles
 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
 
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 0545e3f..d87781d 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -262,6 +262,12 @@
       const std::function<BackwardFunction*()>& backward_function_getter,
       const std::function<void(BackwardFunction*)>& backward_function_deleter);
 
+  // Returns true if `Accumulate` is active somewhere above on the stack. This
+  // is useful for ordering ForwardAccumulators, where more deeply nested
+  // accumulators should not see computations from less deeply nested
+  // accumulators.
+  bool BusyAccumulating() const { return this->accumulating_; }
+
   // Fetches the current Jacobian-vector product associated with `tensor_id`, or
   // a nullptr if none is available.
   //
diff --git a/tensorflow/c/experimental/rendezvous.cc b/tensorflow/c/experimental/rendezvous.cc
index 0ee4907..7a90bde 100644
--- a/tensorflow/c/experimental/rendezvous.cc
+++ b/tensorflow/c/experimental/rendezvous.cc
@@ -45,6 +45,9 @@
 void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                             const Rendezvous::Args& args,
                                             DoneCallback done) {
+  if (args.cancellation_manager != nullptr) {
+    VLOG(1) << "WARNING: CRemoteRendezvous does not support cancellation.";
+  }
   TF_ParsedKey key;
   key.src_device = parsed.src_device.data();
   key.src_device_len = parsed.src_device.size();
diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh
index 7184ad6..a4d51a1 100755
--- a/tensorflow/c/generate-pc.sh
+++ b/tensorflow/c/generate-pc.sh
@@ -63,12 +63,26 @@
 prefix=${TF_PREFIX}
 exec_prefix=\${prefix}
 libdir=\${exec_prefix}/${LIBDIR}
-includedir=\${prefix}/include
+includedir=\${prefix}/include/tensorflow
 
 Name: TensorFlow
 Version: ${TF_VERSION}
 Description: Library for computation using data flow graphs for scalable machine learning
 Requires:
-Libs: -L\${libdir} -ltensorflow
+Libs: -L\${libdir} -ltensorflow -ltensorflow_framework
+Cflags: -I\${includedir}
+EOF
+
+cat << EOF > tensorflow_cc.pc
+prefix=${TF_PREFIX}
+exec_prefix=\${prefix}
+libdir=\${exec_prefix}/${LIBDIR}
+includedir=\${prefix}/include/tensorflow
+
+Name: TensorFlow
+Version: ${TF_VERSION}
+Description: Library for computation using data flow graphs for scalable machine learning
+Requires:
+Libs: -L\${libdir} -ltensorflow_cc -ltensorflow_framework
 Cflags: -I\${includedir}
 EOF
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index 94685c8..b067176f 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -19,6 +19,7 @@
 
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_tensor_internal.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -189,8 +190,8 @@
 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
                   TF_Status* status) {
   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
-  if (i < 0 || i >= cc_ctx->num_inputs()) {
-    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
+  if (i < 0 || i >= cc_ctx->num_outputs()) {
+    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
     return;
   }
   ::tensorflow::Tensor cc_tensor;
@@ -240,3 +241,14 @@
 int64_t TF_StepId(TF_OpKernelContext* ctx) {
   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
 }
+
+TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
+                             TF_DataType dtype, int64_t* dims, int num_dims,
+                             size_t len) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
+  tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
+  auto* allocator = cc_ctx->get_allocator(attr);
+  void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
+  return TF_NewTensor(dtype, dims, num_dims, data, len,
+                      tensorflow::deallocate_buffer, allocator);
+}
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index a192437..8d0518a 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -180,6 +180,16 @@
     TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
     TF_Status* status);
 
+// Allocates Tensor for output at given index. Caller takes ownership of
+// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
+//
+// This function should be used to allocate outputs inside kernel
+// compute function.
+TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
+                                            int index, TF_DataType dtype,
+                                            int64_t* dims, int num_dims,
+                                            size_t len);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index 0e65d18..9d300ed 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -12,17 +12,23 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif
 
 #include "tensorflow/c/kernels.h"
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb_text.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -309,4 +315,144 @@
   TF_DeleteKernelBuilder(builder);
   ASSERT_TRUE(delete_called);
 }
+
+class DeviceKernelOpTest : public OpsTestBase {
+ protected:
+  void SetupOp(const char* op_name, const char* kernel_name,
+               void (*compute_func)(void*, TF_OpKernelContext*)) {
+    TF_KernelBuilder* builder = TF_NewKernelBuilder(
+        op_name, device_name_, nullptr, compute_func, nullptr);
+    TF_Status* status = TF_NewStatus();
+    TF_RegisterKernelBuilder(kernel_name, builder, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status));
+    TF_DeleteStatus(status);
+
+#if GOOGLE_CUDA
+    std::unique_ptr<Device> device(
+        DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
+    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
+#endif
+    TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+#if GOOGLE_CUDA
+  const char* device_name_ = tensorflow::DEVICE_GPU;
+#else
+  const char* device_name_ = tensorflow::DEVICE_CPU;
+#endif
+};
+
+REGISTER_OP("AllocateOutputOp1").Output("output1: float");
+
+TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
+  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+    // Allocate output
+    int64_t dim = 1;
+    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
+    TF_Tensor* output = TF_AllocateOutput(
+        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
+        /*num_dims=*/1, /*len=*/tensor_size_bytes);
+    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
+    EXPECT_EQ(1, TF_NumDims(output));
+    EXPECT_EQ(1, TF_Dim(output, 0));
+
+    // Set output to 3
+    float* data = reinterpret_cast<float*>(TF_TensorData(output));
+    float value = 3.0f;
+#if GOOGLE_CUDA
+    OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
+    cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
+                                                  tensor_size_bytes);
+#else
+    *data = value;
+#endif
+
+    TF_Status* s = TF_NewStatus();
+    TF_SetOutput(ctx, 0, output, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s));
+
+    TF_DeleteStatus(s);
+    TF_DeleteTensor(output);
+  };
+
+  SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func);
+
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor* output = GetOutput(0);
+  EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
+            output->DebugString(100));
+}
+
+REGISTER_OP("AllocateOutputOp0").Output("output1: float");
+
+TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
+  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+    // Allocate empty output
+    int64_t dim = 0;
+    TF_Tensor* output = TF_AllocateOutput(
+        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
+        /*num_dims=*/1, /*len=*/0);
+
+    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
+    EXPECT_EQ(1, TF_NumDims(output));
+    EXPECT_EQ(0, TF_Dim(output, 0));
+
+    TF_Status* s = TF_NewStatus();
+    TF_SetOutput(ctx, 0, output, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s));
+
+    TF_DeleteStatus(s);
+    TF_DeleteTensor(output);
+  };
+
+  SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func);
+
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor* output = GetOutput(0);
+  EXPECT_EQ("Tensor<type: float shape: [0] values: >",
+            output->DebugString(100));
+}
+
+REGISTER_OP("AllocateOutputOp2x3").Output("output1: float");
+
+TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
+  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+    // Allocate 2x3 output
+    int64_t dim[2] = {2, 3};
+    size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
+    TF_Tensor* output = TF_AllocateOutput(
+        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
+        /*num_dims=*/2, /*len=*/tensor_size_bytes);
+    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
+    EXPECT_EQ(2, TF_NumDims(output));
+    EXPECT_EQ(2, TF_Dim(output, 0));
+    EXPECT_EQ(3, TF_Dim(output, 1));
+
+    // Set output to [1 2 3 4 5 6]
+    void* data = TF_TensorData(output);
+    float value[6] = {1, 2, 3, 4, 5, 6};
+#if GOOGLE_CUDA
+    OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
+    cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
+                                                  tensor_size_bytes);
+#else
+    memcpy(data, value, tensor_size_bytes);
+#endif
+
+    TF_Status* s = TF_NewStatus();
+    TF_SetOutput(ctx, 0, output, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s));
+
+    TF_DeleteStatus(s);
+    TF_DeleteTensor(output);
+  };
+
+  SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func);
+
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor* output = GetOutput(0);
+  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
+            output->DebugString(100));
+}
 }  // namespace tensorflow
diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc
index deb3616..44efcba 100644
--- a/tensorflow/c/tf_tensor.cc
+++ b/tensorflow/c/tf_tensor.cc
@@ -31,6 +31,37 @@
 using tensorflow::errors::FailedPrecondition;
 using tensorflow::errors::InvalidArgument;
 
+namespace tensorflow {
+void* allocate_tensor(const char* operation, size_t len, Allocator* allocator) {
+  void* data = allocator->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
+  if (LogMemory::IsEnabled() && data != nullptr) {
+    LogMemory::RecordRawAllocation(
+        operation, LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len, data,
+        allocator);
+  }
+  return data;
+}
+
+void* allocate_tensor(const char* operation, size_t len) {
+  return allocate_tensor(operation, len, cpu_allocator());
+}
+
+void deallocate_buffer(void* data, size_t len, void* arg) {
+  Allocator* allocator = nullptr;
+  if (arg == nullptr) {
+    allocator = cpu_allocator();
+  } else {
+    allocator = reinterpret_cast<Allocator*>(arg);
+  }
+  if (LogMemory::IsEnabled() && data != nullptr) {
+    LogMemory::RecordRawDeallocation(
+        "TensorFlow C Api", LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
+        allocator, false);
+  }
+  allocator->DeallocateRaw(data);
+}
+}  // namespace tensorflow
+
 namespace {
 class TF_ManagedBuffer : public TensorBuffer {
  public:
@@ -63,36 +94,17 @@
   bool OwnsMemory() const override { return false; }
 };
 
-void* allocate_tensor(const char* operation, size_t len) {
-  void* data =
-      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
-  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
-    tensorflow::LogMemory::RecordRawAllocation(
-        operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
-        len, data, tensorflow::cpu_allocator());
-  }
-  return data;
-}
-
-void deallocate_buffer(void* data, size_t len, void* arg) {
-  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
-    tensorflow::LogMemory::RecordRawDeallocation(
-        "TensorFlow C Api",
-        tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
-        tensorflow::cpu_allocator(), false);
-  }
-  tensorflow::cpu_allocator()->DeallocateRaw(data);
-}
-
 }  // namespace
 
 TF_Tensor::~TF_Tensor() { buffer->Unref(); }
 
 TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
                              int num_dims, size_t len) {
-  void* data = allocate_tensor("TF_AllocateTensor", len);
-  return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
-                      nullptr);
+  void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
+                                           tensorflow::cpu_allocator());
+  return TF_NewTensor(dtype, dims, num_dims, data, len,
+                      tensorflow::deallocate_buffer,
+                      tensorflow::cpu_allocator());
 }
 
 TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
@@ -117,8 +129,8 @@
     //
     // Other types have the same representation, so copy only if it is safe to
     // do so.
-    buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
-                               deallocate_buffer, nullptr);
+    buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
+                               len, tensorflow::deallocate_buffer, nullptr);
     std::memcpy(buf->data(), data, len);
     // Free the original buffer.
     deallocator(data, len, deallocator_arg);
@@ -342,7 +354,7 @@
 
   // Compute bytes needed for encoding.
   size_t size = 0;
-  const auto& srcarray = src.flat<string>();
+  const auto& srcarray = src.flat<tstring>();
   for (int i = 0; i < srcarray.size(); ++i) {
     const string& s = srcarray(i);
     // uint64 starting_offset, TF_StringEncode-d string.
@@ -428,7 +440,7 @@
   const char* limit = input + src_size;
 
   *dst = Tensor(static_cast<tensorflow::DataType>(src->dtype), src->shape);
-  auto dstarray = dst->flat<string>();
+  auto dstarray = dst->flat<tstring>();
   for (tensorflow::int64 i = 0; i < num_elements; ++i) {
     tensorflow::uint64 offset =
         reinterpret_cast<const tensorflow::uint64*>(input)[i];
@@ -447,3 +459,12 @@
 }
 
 }  // namespace tensorflow
+
+bool TF_TensorIsAligned(const TF_Tensor* tensor) {
+  if (EIGEN_MAX_ALIGN_BYTES == 0) {
+    return true;
+  }
+  void* ptr = TF_TensorData(tensor);
+  return tensor->dtype == TF_STRING ||
+         (reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0);
+}
diff --git a/tensorflow/c/tf_tensor.h b/tensorflow/c/tf_tensor.h
index 5d4f70c..462fdc8 100644
--- a/tensorflow/c/tf_tensor.h
+++ b/tensorflow/c/tf_tensor.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_C_TF_TENSOR_H_
 #define TENSORFLOW_C_TF_TENSOR_H_
 
+#include <stdbool.h>
 #include <stdint.h>
 
 #include "tensorflow/c/tf_datatype.h"
@@ -175,6 +176,9 @@
 // TF_STRING tensor.
 TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
 
+// Returns bool iff this tensor is aligned.
+TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h
index 6def66c..60a2ec8 100644
--- a/tensorflow/c/tf_tensor_internal.h
+++ b/tensorflow/c/tf_tensor_internal.h
@@ -42,5 +42,13 @@
   }
 };
 
+// Allocates tensor data buffer using specified allocator.
+// `operation` is a name for this operation.
+void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
+
+// Deallocates tensor data buffer.
+// Defaults to deallocating using CPU allocator. You can pass pointer to
+// a different Allocator as `arg`.
+void deallocate_buffer(void* data, size_t len, void* arg);
 }  // namespace tensorflow
 #endif  // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a0353bf..86f503c 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -193,7 +193,7 @@
       string ret;
       for (int64 i = 0; i < num_elts; ++i) {
         if (i > 0) strings::StrAppend(&ret, " ");
-        strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
+        strings::StrAppend(&ret, absl::CEscape(t.flat<tstring>()(i)));
       }
       return ret;
     }
diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc
index 920a8e7..8516dfd 100644
--- a/tensorflow/cc/framework/ops.cc
+++ b/tensorflow/cc/framework/ops.cc
@@ -97,7 +97,7 @@
     Tensor elem = e.tensor;
     if (first.tensor.dtype() == DT_STRING) {
       for (int i = 0; i < elem.NumElements(); ++i) {
-        t.flat<string>()(offset + i) = elem.flat<string>()(i);
+        t.flat<tstring>()(offset + i) = elem.flat<tstring>()(i);
       }
       offset += elem.NumElements();
     } else {
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index e93ca86..b5cac5f 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -272,7 +272,7 @@
   std::unordered_set<string> current_constraints(colocation_constraints_);
   const AttrSlice attrs = colocate_with_op.node()->attrs();
   std::vector<string> node_constraints;
-  if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
+  if (TryGetNodeAttr(attrs, kColocationAttrName, &node_constraints)) {
     for (const string& entry : node_constraints) {
       StringPiece s(entry);
       if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
@@ -299,7 +299,7 @@
   return impl()->control_deps_;
 }
 
-void Scope::UpdateStatus(const Status s) const {
+void Scope::UpdateStatus(const Status& s) const {
   impl()->status_->Update(s);
   if (impl()->exit_on_error_ && !ok()) {
     LOG(FATAL) << *impl()->status_;
@@ -318,7 +318,7 @@
   if (ok()) {
     GraphDef graph_def;
     graph()->ToGraphDef(&graph_def);
-    UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
+    UpdateStatus(ConvertGraphDefToGraph(opts, std::move(graph_def), g));
   }
   return *impl()->status_;
 }
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index ef2daff..63a555b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -177,7 +177,7 @@
   /// Note: The status object is shared between all children of this scope.
   /// If the resulting status is not Status::OK() and exit_on_error_ is set on
   /// this scope, this function exits by calling LOG(FATAL).
-  void UpdateStatus(const Status s) const;
+  void UpdateStatus(const Status& s) const;
 
   // START_SKIP_DOXYGEN
 
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 01752b6..39b8492 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -10,7 +10,7 @@
     "tf_cc_test",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_static",
     "if_static_and_not_mobile",
 )
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index dfc7ccd..a3b80fb 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -75,7 +75,7 @@
 
 Tensor CreateStringTensor(const string& value) {
   Tensor tensor(DT_STRING, TensorShape({}));
-  tensor.scalar<string>()() = value;
+  tensor.scalar<tstring>()() = value;
   return tensor;
 }
 
@@ -219,7 +219,7 @@
 
   // Add variables to the graph.
   Tensor variables_path_tensor(DT_STRING, TensorShape({}));
-  variables_path_tensor.scalar<string>()() = variables_path;
+  variables_path_tensor.scalar<tstring>()() = variables_path;
 
   std::vector<std::pair<string, Tensor>> inputs = {
       {string(variable_filename_const_op_name), variables_path_tensor}};
diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD
index fca45c8..b144065 100644
--- a/tensorflow/cc/saved_model/python/BUILD
+++ b/tensorflow/cc/saved_model/python/BUILD
@@ -1,7 +1,7 @@
 # Description:
 # CLIF wrappers for TensorFlow SavedModels.
 
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_py_clif_cc")
 
 package(
     default_visibility = ["//visibility:public"],
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index eeb9101..0ec48ec 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -42,6 +42,10 @@
     tensor_names->insert(coo_sparse.values_tensor_name());
     tensor_names->insert(coo_sparse.indices_tensor_name());
     tensor_names->insert(coo_sparse.dense_shape_tensor_name());
+  } else if (tensor_info.has_composite_tensor()) {
+    for (const auto& component : tensor_info.composite_tensor().components()) {
+      tensor_names->insert(component.name());
+    }
   } else {
     tensor_names->insert(tensor_info.name());
   }
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index 979b23c..274a163 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -425,5 +425,63 @@
   TestFreezeGraphWithAndWithoutDependentVariables(true);
 }
 
+TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
+  // Test that inputs and outputs get correctly populated for a
+  // SignatureDef containing composite tensor inputs and outputs.
+  SavedModelBundle saved_model_bundle;
+  SignatureDef signature_def;
+
+  TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
+  in.mutable_composite_tensor()->add_components()->set_name("input1:0");
+  in.mutable_composite_tensor()->add_components()->set_name("input2:0");
+
+  TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
+  out.mutable_composite_tensor()->add_components()->set_name("output2:0");
+  out.mutable_composite_tensor()->add_components()->set_name("output1:0");
+
+  AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+                                    &saved_model_bundle);
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+  std::unordered_set<string> expected_inputs = {"input1:0", "input2:0"};
+  std::unordered_set<string> expected_outputs = {"output1:0", "output2:0"};
+  EXPECT_EQ(expected_inputs, inputs);
+  EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
+  // Test that inputs and outputs get correctly populated for a
+  // SignatureDef containing composite tensor inputs and outputs.
+  SavedModelBundle saved_model_bundle;
+  SignatureDef signature_def;
+
+  TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
+  in.mutable_coo_sparse()->set_values_tensor_name("input1:0");
+  in.mutable_coo_sparse()->set_indices_tensor_name("input2:0");
+  in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0");
+
+  TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
+  out.mutable_coo_sparse()->set_values_tensor_name("output1:0");
+  out.mutable_coo_sparse()->set_indices_tensor_name("output2:0");
+  out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0");
+
+  AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+                                    &saved_model_bundle);
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+  std::unordered_set<string> expected_inputs = {"input1:0", "input2:0",
+                                                "input3:0"};
+  std::unordered_set<string> expected_outputs = {"output1:0", "output2:0",
+                                                 "output3:0"};
+  EXPECT_EQ(expected_inputs, inputs);
+  EXPECT_EQ(expected_outputs, outputs);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 88b00cb..cbbb436 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -1,7 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 
 package(
     default_visibility = [
@@ -281,6 +281,7 @@
     hdrs = ["xla_compilation_cache.h"],
     deps = [
         ":xla_activity_listener",
+        ":xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:statusor",
@@ -800,6 +801,8 @@
         ":flags",
         ":resource_operation_safety_analysis",
         ":union_find",
+        ":xla_activity_listener",
+        ":xla_activity_proto_cc",
         ":xla_cluster_util",
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/tf2xla:resource_operation_table",
@@ -901,6 +904,7 @@
     srcs = ["xla_activity_logging_listener.cc"],
     deps = [
         ":xla_activity_listener",
+        ":xla_activity_proto_cc",
         "//tensorflow/core:logger",
         "@com_google_absl//absl/memory",
     ],
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 1265ff9..61695d5 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -48,6 +48,19 @@
 
 namespace tensorflow {
 namespace {
+struct DebuggingOpts {
+  // If true, insert Print nodes to print every output from an XLA cluster.
+  bool print_outputs;
+
+  // If true, insert CheckNumerics nodes for every floating point typed input to
+  // an XLA cluster.
+  bool check_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool check_output_numerics;
+};
+
 void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
   std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                      old_node->out_edges().end());
@@ -78,7 +91,8 @@
 // Replaces each outgoing edge from `old_node` with a merge node that merges in
 // the corresponding output from `new_node`.
 void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
-                            bool insert_print_nodes) {
+                            absl::string_view cluster_name,
+                            const DebuggingOpts& debugging_opts) {
   if (!s.status().ok()) {
     return;
   }
@@ -93,23 +107,36 @@
     int oidx = e->src_output();
     Output merged_output = merged_outputs[oidx];
     if (merged_output.node() == nullptr) {
-      ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
-                          {Output(old_node, oidx), Output(new_node, oidx)});
-      if (insert_print_nodes) {
+      Output new_output(new_node, oidx);
+      if (debugging_opts.print_outputs) {
         string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
-        ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
+        ops::Print print_op(s.WithOpName("print_", oidx)
                                 .WithDevice(cpu_device)
                                 .WithAssignedDevice(cpu_device),
-                            merge_op.output, {merge_op.output},
+                            new_output, {new_output},
                             ops::Print::Attrs{}
                                 .Message(absl::StrCat("output ", oidx, " from ",
                                                       old_node->name(), " is "))
                                 .FirstN(1000)
                                 .Summarize(-1));
-        merged_output = merged_outputs[oidx] = print_op;
-      } else {
-        merged_output = merged_outputs[oidx] = merge_op.output;
+        new_output = print_op;
       }
+
+      if (debugging_opts.check_output_numerics &&
+          DataTypeIsFloating(new_output.type())) {
+        ops::CheckNumerics check_numerics_op(
+            s.WithOpName("check_output_", oidx)
+                .WithDevice(new_node->requested_device())
+                .WithAssignedDevice(new_node->assigned_device_name()),
+            new_output,
+            absl::StrCat("CheckNumerics failed for output ", oidx, "(",
+                         new_output.name(), ") from cluster ", cluster_name));
+        new_output = check_numerics_op;
+      }
+
+      ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
+                          {Output(old_node, oidx), new_output});
+      merged_output = merged_outputs[oidx] = merge_op.output;
     }
 
     Node* dst = e->dst();
@@ -324,11 +351,34 @@
   return result;
 }
 
+std::vector<Output> GetXlaRunArgs(const Scope& s,
+                                  const XlaClusterInfo& cluster_info,
+                                  const DebuggingOpts& debugging_opts) {
+  std::vector<Output> xla_run_args;
+  xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
+                       cluster_info.resource_inputs.size());
+  int input_idx = 0;
+  for (const Output& o : cluster_info.non_constant_inputs) {
+    if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
+      ops::CheckNumerics check_numerics_op(
+          s.WithOpName("check_input_", input_idx), o,
+          absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
+                       o.name(), ") into ", cluster_info.function.name()));
+      xla_run_args.push_back(check_numerics_op);
+    } else {
+      xla_run_args.push_back(o);
+    }
+    input_idx++;
+  }
+  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+  return xla_run_args;
+}
+
 Status ReplaceNodeWithXlaCompileAndXlaRun(
     jit::DeviceInfoCache* device_info_cache,
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
-    bool insert_print_nodes, Graph* g, Node* n) {
+    const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
   XlaClusterInfo cluster_info;
   TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
 
@@ -361,12 +411,12 @@
   TF_RETURN_IF_ERROR(
       CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
 
+  std::vector<Output> xla_run_args =
+      GetXlaRunArgs(root, cluster_info, debugging_opts);
+
   if (requires_compilation) {
     // "Strict" compilation:  every _XlaCompile invocation must compile the
     // cluster.
-    std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-    absl::c_copy(cluster_info.resource_inputs,
-                 std::back_inserter(xla_run_args));
     ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                          xla_compile.key, n->output_types());
 
@@ -391,9 +441,6 @@
     Output predicated_compilation_key = s.output_true;
     Output inverse_predicated_compilation_key = s.output_false;
 
-    std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-    absl::c_copy(cluster_info.resource_inputs,
-                 std::back_inserter(xla_run_args));
     ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                          predicated_compilation_key, n->output_types());
 
@@ -402,7 +449,7 @@
 
     MergeOutgoingDataEdges(root, /*old_node=*/n,
                            /*new_node=*/xla_run.operation.node(),
-                           insert_print_nodes);
+                           cluster_info.function.name(), debugging_opts);
 
     TF_RETURN_IF_ERROR(root.status());
 
@@ -443,15 +490,25 @@
       enable_lazy_compilation_
           ? *enable_lazy_compilation_
           : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
-  bool insert_print_nodes =
-      GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
 
   jit::DeviceInfoCache device_info_cache;
+  const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
+
+  DebuggingOpts debugging_opts;
+  debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
+  debugging_opts.check_input_numerics =
+      flags.tf_xla_check_cluster_input_numerics;
+  debugging_opts.check_output_numerics =
+      flags.tf_xla_check_cluster_output_numerics;
+
+  VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
+  VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
+  VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;
 
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
         &device_info_cache, options, *options.flib_def,
-        lazy_compilation_enabled, insert_print_nodes, graph, n));
+        lazy_compilation_enabled, debugging_opts, graph, n));
   }
 
   if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 5e3b93d..7b5c26f 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -37,6 +37,8 @@
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
 #include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
+#include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
@@ -263,9 +265,15 @@
 bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
   // b/128001705: SelfAdjointEigV2 and Svd performance issues.
   // b/135640736: MatrixInverse performance issues.
+  // https://github.com/tensorflow/tensorflow/pull/31012:
+  //    ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
+  //    create convolutions too large for CuDNN to handle.
   return node.type_string() == "SelfAdjointEigV2" ||
          node.type_string() == "Svd" || node.type_string() == "Qr" ||
-         node.type_string() == "MatrixInverse";
+         node.type_string() == "MatrixInverse" ||
+         node.type_string() == "ResizeNearestNeighbor" ||
+         node.type_string() == "ResizeBilinear" ||
+         node.type_string() == "ResizeBilinearGrad";
 }
 
 bool RecursiveCompilabilityChecker::IsCompilableNode(
@@ -318,7 +326,7 @@
     return false;
   }
 
-  if (node.type_string() == "While" &&
+  if (node.IsWhileNode() &&
       !IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) {
     LogNotCompilable(node, "unsupported while");
     return false;
@@ -394,6 +402,9 @@
   if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
     absl::string_view uncompilable_reason =
         "operation with numerical accuracy issues";
+    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
+                                node.DebugString())
+        .IgnoreError();
     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                               uncompilable_nodes);
     LogNotCompilable(node, uncompilable_reason);
@@ -402,6 +413,9 @@
 
   if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
     absl::string_view uncompilable_reason = "slow operation";
+    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
+                                node.DebugString())
+        .IgnoreError();
     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                               uncompilable_nodes);
     LogNotCompilable(node, uncompilable_reason);
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 6992a01..e0c0c0b 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -1317,7 +1317,7 @@
 bool IsXlaCompiledKernel(const Node& node) {
   bool is_compiled = false;
   bool has_compilation_attr =
-      GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
+      TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
       is_compiled;
   return has_compilation_attr ? is_compiled : false;
 }
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 2c2cd09..b988998 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -245,8 +245,8 @@
   // while iterating.
   std::vector<Node*> launch_nodes;
   for (Node* n : graph->nodes()) {
-    string name;
-    if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+    const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
+    if (!name.empty()) {
       launch_nodes.push_back(n);
     }
   }
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index 85fb69b..b35e08fb1 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -24,7 +24,6 @@
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -33,6 +32,7 @@
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
@@ -369,7 +369,8 @@
   return new_def;
 }
 
-Status ValidateOutsideCompilationCallNode(Node* call_node) {
+TF_ATTRIBUTE_NOINLINE Status
+ValidateOutsideCompilationCallNode(Node* call_node) {
   // DT_INT64 as input/output for outside compilation is not supported yet:
   // b/120809951.
   for (const Edge* e : call_node->in_edges()) {
@@ -402,7 +403,7 @@
 }
 
 // Replace outside compilation function call node with XlaHostCompute node.
-xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
+TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
     Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
   // Build XlaHostCompute NodeDef.
@@ -440,7 +441,7 @@
         n->ClearAttr(attr_name);
         n->AddAttr(attr_name, branch_func);
       }
-    } else if (n->type_string() == "While") {
+    } else if (n->IsWhileNode()) {
       for (const string& attr_name : std::vector<string>{"cond", "body"}) {
         NameAttrList branch_func;
         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
@@ -595,7 +596,7 @@
 Status PostprocessLiftedArgsForWhile(
     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
-  TF_RET_CHECK(n->type_string() == "While");
+  TF_RET_CHECK(n->IsWhileNode());
 
   // Check if there is any lifted args in body function.
   NameAttrList body_func;
@@ -695,7 +696,7 @@
 Status PostprocessLiftedArgsForIf(
     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
-  TF_RET_CHECK(n->type_string() == "If");
+  TF_RET_CHECK(n->IsIfNode());
 
   NameAttrList then_branch_func;
   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
@@ -913,10 +914,9 @@
   for (Node* n : g.op_nodes()) {
     bool is_lifted_arg;
     string outside_compilation_attr;
-    if (GetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg).ok() &&
-        GetNodeAttr(n->def(), "_xla_outside_compilation",
-                    &outside_compilation_attr)
-            .ok()) {
+    if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
+        TryGetNodeAttr(n->def(), "_xla_outside_compilation",
+                       &outside_compilation_attr)) {
       TF_RET_CHECK(is_lifted_arg);
       TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
       outside_compilation_attr_to_node[outside_compilation_attr] = n;
@@ -936,12 +936,12 @@
       continue;
     }
 
-    if (n->type_string() == "While") {
+    if (n->IsWhileNode()) {
       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
           outside_compilation_attr_to_node, g, n, fld));
     }
 
-    if (n->type_string() == "If") {
+    if (n->IsIfNode()) {
       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
           outside_compilation_attr_to_node, g, n, fld));
     }
@@ -1307,9 +1307,9 @@
 }
 
 // Builds XlaSendToHost node which sends cond predicate to host.
-xla::StatusOr<Node*> BuildSendIfPredNode(const string& name,
-                                         const string& host_transfer_key,
-                                         Node* pred_node, Graph* g) {
+TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> BuildSendIfPredNode(
+    const string& name, const string& host_transfer_key, Node* pred_node,
+    Graph* g) {
   NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
   send_pred_builder.Attr("Tinput", DT_BOOL);
   send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
@@ -1372,15 +1372,13 @@
 }
 
 // Builds host side graph for If node.
-Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
-                               const string& outside_compilation_attr_name,
-                               const string& xla_cluster_name,
-                               const string& if_node_name,
-                               const string& host_transfer_key,
-                               const string& host_graph_func_name,
-                               FunctionLibraryDefinition* fld,
-                               const string& then_branch_host_func_name,
-                               const string& else_branch_host_func_name) {
+TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const string& if_node_name, const string& host_transfer_key,
+    const string& host_graph_func_name, FunctionLibraryDefinition* fld,
+    const string& then_branch_host_func_name,
+    const string& else_branch_host_func_name) {
   Graph host_graph(fld);
   string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
   AttrValue device_ordinal_value;
@@ -1457,10 +1455,9 @@
 }
 
 // Rewrites loop cond to add a node which sends loop cond to host.
-Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
-                                 const NameAttrList& loop_cond_func,
-                                 const string& while_node_name,
-                                 const string& host_transfer_key) {
+TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
+    FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func,
+    const string& while_node_name, const string& host_transfer_key) {
   // Instantiate the loop cond function.
   std::unique_ptr<FunctionBody> fbody;
   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()),
@@ -1648,7 +1645,7 @@
 }
 
 // Builds host side graph for while node.
-Status BuildHostGraphForWhileNode(
+TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
     const string& xla_cluster_attr_name,
     const string& outside_compilation_attr_name, const string& xla_cluster_name,
     const string& while_node_name, const string& host_transfer_key,
@@ -1745,10 +1742,6 @@
   call_builder.Attr(kXlaHasHostTransferAttrName, true);
   call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
   call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
-  // Make sure control outputs of this function call node will be respected when
-  // this node is lowered.
-  call_builder.Attr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
-                    true);
   NodeDef call_def;
   TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
   Status s;
@@ -1771,6 +1764,221 @@
   return Status::OK();
 }
 
+TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core, Graph* g, Node* n,
+    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
+    std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  bool func_has_outside_compilation = false;
+  NameAttrList func;
+  if (fld->Contains(n->type_string())) {
+    func.set_name(n->type_string());
+    typedef protobuf::Map<string, AttrValue> AttrMap;
+    *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
+  } else if (n->IsPartitionedCall()) {
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
+  } else {
+    TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
+    func.set_name(FunctionLibraryDefinition::kGradientOp);
+    *func.mutable_attr() = n->def().attr();
+  }
+  string new_func_name = absl::StrCat(n->name(), "_oc");
+  string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
+  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      func, new_func_name, host_func_name, host_compute_core, flr, fld,
+      shape_inference_graphs, &func_has_outside_compilation));
+
+  // If the function call does not have outside compilation, nothing to do.
+  if (!func_has_outside_compilation) {
+    return Status::OK();
+  }
+
+  *has_outside_compilation = true;
+
+  // Change `n` to call the new function directly.
+  auto replace_builder =
+      absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
+  std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
+  for (const Edge* e : n->in_edges()) {
+    if (e->IsControlEdge()) {
+      continue;
+    }
+
+    TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size());
+    inputs[e->dst_input()] =
+        NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
+                                e->src()->output_type(e->src_output())};
+  }
+  for (const auto& input : inputs) {
+    replace_builder->Input(input);
+  }
+  for (const auto& attr : n->attrs()) {
+    replace_builder->Attr(attr.first, attr.second);
+  }
+  auto replace_def = absl::make_unique<NodeDef>();
+  TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
+  TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
+  replace->AddAttr(kXlaTokenInputNodesAttrName,
+                   std::vector<string>{kXlaTokenArgNodeName});
+
+  // Build host side graph for the function call.
+  string oc_host_graph_name =
+      absl::StrCat("oc_func_host_graph_", replace->name());
+  TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      replace->name(), host_func_name, oc_host_graph_name, fld));
+
+  // Record the host graph.
+  host_graphs->push_back(oc_host_graph_name);
+
+  return Status::OK();
+}
+
+Status ExtractOutsideCompilationForIfNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core, Graph* g, Node* n,
+    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
+    std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  // Instantiate "then_branch" and "else_branch".
+  NameAttrList then_branch, else_branch;
+  TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
+  TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
+
+  // Extract outside compilation for then_branch and else_branch.
+  bool then_branch_has_outside_compilation = false;
+  bool else_branch_has_outside_compilation = false;
+  string then_branch_host_func_name =
+             absl::StrCat("oc_then_branch_host_if_", n->name()),
+         else_branch_host_func_name =
+             absl::StrCat("oc_else_branch_host_if_", n->name());
+  string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
+         else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
+  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      then_branch, then_branch_xla_func_name, then_branch_host_func_name,
+      host_compute_core, flr, fld, shape_inference_graphs,
+      &then_branch_has_outside_compilation));
+  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      else_branch, else_branch_xla_func_name, else_branch_host_func_name,
+      host_compute_core, flr, fld, shape_inference_graphs,
+      &else_branch_has_outside_compilation));
+
+  // If then/else branch do not have outside compilation, nothing to do.
+  if (!then_branch_has_outside_compilation &&
+      !else_branch_has_outside_compilation) {
+    return Status::OK();
+  }
+
+  *has_outside_compilation = true;
+
+  // Change If node to call the new functions.
+  then_branch.set_name(then_branch_xla_func_name);
+  n->ClearAttr("then_branch");
+  n->AddAttr("then_branch", then_branch);
+  else_branch.set_name(else_branch_xla_func_name);
+  n->ClearAttr("else_branch");
+  n->AddAttr("else_branch", else_branch);
+
+  string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
+
+  // XLA computation: add a SendToHost node to send cond predicate.
+  Node* pred_node;
+  TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
+  TF_ASSIGN_OR_RETURN(
+      Node * send_pred_node,
+      BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
+                          host_transfer_key, pred_node, g));
+  n->AddAttr(kXlaTokenInputNodesAttrName,
+             std::vector<string>{send_pred_node->name()});
+
+  // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
+  // visit If node after `send_pred_node`, thus the token output for
+  // `send_pred_node` has been generated.
+  g->AddControlEdge(send_pred_node, n);
+
+  // Build host side graph for the "If" node.
+  string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
+  TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      n->name(), host_transfer_key, oc_host_graph_name, fld,
+      then_branch_host_func_name, else_branch_host_func_name));
+  host_graphs->push_back(oc_host_graph_name);
+
+  return Status::OK();
+}
+
+Status ExtractOutsideCompilationForWhileNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core, Graph* g, Node* n,
+    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
+    std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  // Instantiate "cond" and "body".
+  NameAttrList cond, body;
+  TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
+  TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
+
+  // Extract outside compilation for cond and body.
+  bool cond_has_outside_compilation = false;
+  bool body_has_outside_compilation = false;
+  string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
+         body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
+  string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
+         body_xla_func_name = absl::StrCat(body.name(), "_oc");
+  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
+      fld, shape_inference_graphs, &cond_has_outside_compilation));
+  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
+      fld, shape_inference_graphs, &body_has_outside_compilation));
+
+  // If cond/body do not have outside compilation, nothing to do.
+  if (!cond_has_outside_compilation && !body_has_outside_compilation) {
+    return Status::OK();
+  }
+
+  *has_outside_compilation = true;
+
+  // Change While node to call the new functions.
+  cond.set_name(cond_xla_func_name);
+  n->ClearAttr("cond");
+  n->AddAttr("cond", cond);
+  body.set_name(body_xla_func_name);
+  n->ClearAttr("body");
+  n->AddAttr("body", body);
+
+  string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
+
+  // XLA computation: rewrite cond function to add a SendToHost node to send
+  // loop predicate.
+  TF_RETURN_IF_ERROR(
+      AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
+  n->AddAttr(kXlaTokenInputNodesAttrName,
+             std::vector<string>{kXlaTokenArgNodeName});
+
+  // Build host side graph for the "While" node.
+  string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
+  TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      n->name(), host_transfer_key, oc_host_graph_name, fld,
+      cond_host_func_name, body_host_func_name));
+  host_graphs->push_back(oc_host_graph_name);
+
+  return Status::OK();
+}
+
 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
     Graph* g, const string& xla_cluster_attr_name,
     const string& outside_compilation_attr_name, const string& xla_cluster_name,
@@ -1782,7 +1990,7 @@
   for (Node* n : g->nodes()) {
     if (n->IsIfNode()) {
       if_nodes.push_back(n);
-    } else if (n->type_string() == "While") {
+    } else if (n->IsWhileNode()) {
       while_nodes.push_back(n);
     } else if (IsFunctionCall(*fld, *n)) {
       func_call_nodes.push_back(n);
@@ -1790,184 +1998,24 @@
   }
 
   for (Node* n : func_call_nodes) {
-    // Extract outside compilation for the function call.
-    bool func_has_outside_compilation = false;
-    NameAttrList func;
-    if (fld->Contains(n->type_string())) {
-      func.set_name(n->type_string());
-      typedef protobuf::Map<string, AttrValue> AttrMap;
-      *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
-    } else if (n->IsPartitionedCall()) {
-      TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
-    } else {
-      TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
-      func.set_name(FunctionLibraryDefinition::kGradientOp);
-      *func.mutable_attr() = n->def().attr();
-    }
-    string new_func_name = absl::StrCat(n->name(), "_oc");
-    string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        func, new_func_name, host_func_name, host_compute_core, flr, fld,
-        shape_inference_graphs, &func_has_outside_compilation));
-
-    // If the function call does not have outside compilation, nothing to do.
-    if (!func_has_outside_compilation) {
-      continue;
-    }
-
-    *has_outside_compilation = true;
-
-    // Change `n` to call the new function directly.
-    NodeDefBuilder replace_builder(n->name(), new_func_name, fld);
-    for (const Edge* e : n->in_edges()) {
-      if (e->IsControlEdge()) {
-        continue;
-      }
-      replace_builder.Input(e->src()->name(), e->src_output(),
-                            e->src()->output_type(e->src_output()));
-    }
-    for (const auto& attr : n->attrs()) {
-      replace_builder.Attr(attr.first, attr.second);
-    }
-    NodeDef replace_def;
-    TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def));
-    TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def));
-    replace->AddAttr(kXlaTokenInputNodesAttrName,
-                     std::vector<string>{kXlaTokenArgNodeName});
-
-    // Build host side graph for the function call.
-    string oc_host_graph_name =
-        absl::StrCat("oc_func_host_graph_", replace->name());
-    TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
-        xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
-        replace->name(), host_func_name, oc_host_graph_name, fld));
-
-    // Record the host graph.
-    host_graphs->push_back(oc_host_graph_name);
+        host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
+        has_outside_compilation));
   }
 
   for (Node* n : if_nodes) {
-    // Instantiate "then_branch" and "else_branch".
-    NameAttrList then_branch, else_branch;
-    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
-    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
-
-    // Extract outside compilation for then_branch and else_branch.
-    bool then_branch_has_outside_compilation = false;
-    bool else_branch_has_outside_compilation = false;
-    string then_branch_host_func_name =
-               absl::StrCat("oc_then_branch_host_if_", n->name()),
-           else_branch_host_func_name =
-               absl::StrCat("oc_else_branch_host_if_", n->name());
-    string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
-           else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        then_branch, then_branch_xla_func_name, then_branch_host_func_name,
-        host_compute_core, flr, fld, shape_inference_graphs,
-        &then_branch_has_outside_compilation));
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
-        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        else_branch, else_branch_xla_func_name, else_branch_host_func_name,
-        host_compute_core, flr, fld, shape_inference_graphs,
-        &else_branch_has_outside_compilation));
-
-    // If then/else branch do not have outside compilation, nothing to do.
-    if (!then_branch_has_outside_compilation &&
-        !else_branch_has_outside_compilation) {
-      continue;
-    }
-
-    *has_outside_compilation = true;
-
-    // Change If node to call the new functions.
-    then_branch.set_name(then_branch_xla_func_name);
-    n->ClearAttr("then_branch");
-    n->AddAttr("then_branch", then_branch);
-    else_branch.set_name(else_branch_xla_func_name);
-    n->ClearAttr("else_branch");
-    n->AddAttr("else_branch", else_branch);
-
-    string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
-
-    // XLA computation: add a SendToHost node to send cond predicate.
-    Node* pred_node;
-    TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
-    TF_ASSIGN_OR_RETURN(
-        Node * send_pred_node,
-        BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
-                            host_transfer_key, pred_node, g));
-    n->AddAttr(kXlaTokenInputNodesAttrName,
-               std::vector<string>{send_pred_node->name()});
-
-    // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
-    // visit If node after `send_pred_node`, thus the token output for
-    // `send_pred_node` has been generated.
-    g->AddControlEdge(send_pred_node, n);
-
-    // Build host side graph for the "If" node.
-    string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
-    TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
-        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        n->name(), host_transfer_key, oc_host_graph_name, fld,
-        then_branch_host_func_name, else_branch_host_func_name));
-    host_graphs->push_back(oc_host_graph_name);
+        host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
+        has_outside_compilation));
   }
 
   for (Node* n : while_nodes) {
-    // Instantiate "cond" and "body".
-    NameAttrList cond, body;
-    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
-    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
-
-    // Extract outside compilation for cond and body.
-    bool cond_has_outside_compilation = false;
-    bool body_has_outside_compilation = false;
-    string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
-           body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
-    string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
-           body_xla_func_name = absl::StrCat(body.name(), "_oc");
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
-        fld, shape_inference_graphs, &cond_has_outside_compilation));
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
-        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
-        fld, shape_inference_graphs, &body_has_outside_compilation));
-
-    // If cond/body do not have outside compilation, nothing to do.
-    if (!cond_has_outside_compilation && !body_has_outside_compilation) {
-      continue;
-    }
-
-    *has_outside_compilation = true;
-
-    // Change While node to call the new functions.
-    cond.set_name(cond_xla_func_name);
-    n->ClearAttr("cond");
-    n->AddAttr("cond", cond);
-    body.set_name(body_xla_func_name);
-    n->ClearAttr("body");
-    n->AddAttr("body", body);
-
-    string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
-
-    // XLA computation: rewrite cond function to add a SendToHost node to send
-    // loop predicate.
-    TF_RETURN_IF_ERROR(
-        AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
-    n->AddAttr(kXlaTokenInputNodesAttrName,
-               std::vector<string>{kXlaTokenArgNodeName});
-
-    // Build host side graph for the "While" node.
-    string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
-    TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
-        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        n->name(), host_transfer_key, oc_host_graph_name, fld,
-        cond_host_func_name, body_host_func_name));
-    host_graphs->push_back(oc_host_graph_name);
+        host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
+        has_outside_compilation));
   }
 
   return Status::OK();
@@ -2130,11 +2178,11 @@
 
   // Encapsulate outside_compilation cluster into function call node.
   std::unique_ptr<Graph> graph_out;
-  RewriteOutsideCompilationSubgraphFn rewrite_fn(
+  auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
       new_func_name);
   TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
-      outside_compilation_attr_name, *fbody->graph, rewrite_fn,
+      outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
       /*reuse_existing_functions=*/true, &graph_out, fld));
 
   // Replace outside_compilation function nodes with HostCompute ops.
@@ -2149,26 +2197,26 @@
       // If we could not infer shapes for XlaSendFromHost inputs statically, we
       // will set the "shape_inference_graph" attribute. In that case, copy
       // outside compilation subgraph as shape inference graph in `fld`.
-      NameAttrList shape_inference_graph;
+      auto shape_inference_graph = absl::make_unique<NameAttrList>();
       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
-                                     &shape_inference_graph));
-      if (!shape_inference_graph.name().empty()) {
-        shape_inference_graphs->push_back(shape_inference_graph.name());
+                                     shape_inference_graph.get()));
+      if (!shape_inference_graph->name().empty()) {
+        shape_inference_graphs->push_back(shape_inference_graph->name());
         shape_inference_graphs_to_rewrite.push_back(
-            shape_inference_graph.name());
+            shape_inference_graph->name());
 
         const FunctionDef* xla_fdef = fld->Find(n->name());
         if (!xla_fdef) {
           return errors::Internal("Cannot find XLA function ", n->name());
         }
-        FunctionDef shape_inference_fdef = *xla_fdef;
-        shape_inference_fdef.mutable_signature()->set_name(
-            shape_inference_graph.name());
-        if (fld->Find(shape_inference_graph.name())) {
-          TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(),
-                                                  shape_inference_fdef));
+        auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
+        shape_inference_fdef->mutable_signature()->set_name(
+            shape_inference_graph->name());
+        if (fld->Find(shape_inference_graph->name())) {
+          TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(),
+                                                  *shape_inference_fdef));
         } else {
-          TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
+          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
         }
       }
     }
@@ -2213,15 +2261,15 @@
   TF_RETURN_IF_ERROR(
       ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
                          outside_compilation_host_graphs, fld, &host_graph));
-  FunctionDef host_graph_fdef;
+  auto host_graph_fdef = absl::make_unique<FunctionDef>();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
                                         HostGraphControlRetMapping,
-                                        &host_graph_fdef));
+                                        host_graph_fdef.get()));
   if (fld->Find(host_graph_func_name)) {
     TF_RETURN_IF_ERROR(
-        fld->ReplaceFunction(host_graph_func_name, host_graph_fdef));
+        fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
   } else {
-    TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef));
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
   }
 
   // Shape inference graphs might contain Placeholder nodes for outside
@@ -2240,19 +2288,19 @@
   }
 
   // Replace original function.
-  FunctionDef updated_fdef;
+  auto updated_fdef = absl::make_unique<FunctionDef>();
   TF_RETURN_IF_ERROR(
-      GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef));
+      GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get()));
   const FunctionDef* original_fdef = fld->Find(func_name);
   if (original_fdef) {
     for (const auto& attr : original_fdef->attr()) {
-      (*updated_fdef.mutable_attr())[attr.first] = attr.second;
+      (*updated_fdef->mutable_attr())[attr.first] = attr.second;
     }
   }
   if (fld->Find(new_func_name)) {
-    TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef));
+    TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
   } else {
-    TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef));
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
   }
   if (VLOG_IS_ON(4)) {
     DumpGraphToFile(
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index f69a28b..53f9b70 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -105,6 +105,8 @@
   build_ops_flags = new BuildXlaOpsPassFlags;
   build_ops_flags->tf_xla_enable_lazy_compilation = true;
   build_ops_flags->tf_xla_print_cluster_outputs = false;
+  build_ops_flags->tf_xla_check_cluster_input_numerics = false;
+  build_ops_flags->tf_xla_check_cluster_output_numerics = false;
   build_ops_flags->tf_xla_disable_constant_folding = false;
 
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
@@ -144,6 +146,14 @@
             &build_ops_flags->tf_xla_print_cluster_outputs,
             "If true then insert Print nodes to print out values produced by "
             "XLA clusters."),
+       Flag("tf_xla_check_cluster_input_numerics",
+            &build_ops_flags->tf_xla_check_cluster_input_numerics,
+            "If true then insert CheckNumerics nodes to to check all cluster "
+            "inputs."),
+       Flag("tf_xla_check_cluster_output_numerics",
+            &build_ops_flags->tf_xla_check_cluster_output_numerics,
+            "If true then insert CheckNumerics nodes to to check all cluster "
+            "outputs."),
 
        Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
             "Switch a device into 'on-demand' mode, where instead of "
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 91e93f3..9307874 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -103,6 +103,14 @@
   // clusters.  Useful for debugging.
   bool tf_xla_print_cluster_outputs;
 
+  // If true, insert CheckNumerics nodes for every floating point typed input to
+  // an XLA cluster.
+  bool tf_xla_check_cluster_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool tf_xla_check_cluster_output_numerics;
+
   // Disables all constant folding. The primary use for this is for testing to
   // guarantee that tests are run on XLA and not on TF's CPU implementation.
   bool tf_xla_disable_constant_folding;
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 49b8731..3fbd977 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -12,6 +12,8 @@
     deps = [
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
+        "//tensorflow/compiler/jit:xla_activity_listener",
+        "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/jit:xla_compilation_cache",
         "//tensorflow/compiler/jit:xla_device",
         "//tensorflow/compiler/jit:xla_launch_util",
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 788e90f..87d6548 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -19,6 +19,7 @@
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
+#include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -366,7 +367,12 @@
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
 
-  auto run_result = executable->Run(launch_context.arguments(), run_options);
+  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
+    run_result = executable->Run(launch_context.arguments(), run_options);
+  } else {
+    run_result = executable->RunAsync(launch_context.arguments(), run_options);
+  }
   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
 
   auto elapsed = env->NowMicros() - start_time;
@@ -467,6 +473,10 @@
     if (status.code() == error::UNIMPLEMENTED) {
       LOG(WARNING) << "Compilation failed:" << status.ToString()
                    << ".  Falling back to TF function call.";
+
+      BroadcastOptimizationRemark(
+          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
+          .IgnoreError();
       executable = nullptr;
       mutex_lock guard(cannot_compile_cluster_mu_);
       cannot_compile_cluster_ = true;
@@ -498,7 +508,7 @@
           client, executable, kernel, std::move(variables), constants_.size()));
 
   Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
-  compilation_key.flat<string>()(0) = key;
+  compilation_key.flat<tstring>()(0) = key;
 
   Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
   compilation_successful.flat<bool>()(0) = true;
@@ -513,7 +523,7 @@
 void XlaRunOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaRunOp " << def().name();
   Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
-  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);
 
   XlaExecutableClosure closure =
       XlaExecutableClosureStore::Global()->Consume(key);
@@ -550,8 +560,14 @@
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
 
-  auto run_result =
-      closure.executable()->Run(launch_context.arguments(), run_options);
+  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
+    run_result =
+        closure.executable()->Run(launch_context.arguments(), run_options);
+  } else {
+    run_result =
+        closure.executable()->RunAsync(launch_context.arguments(), run_options);
+  }
   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
 
   auto elapsed = env->NowMicros() - start_time;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 91423f6..b86ef93 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -677,8 +677,7 @@
   }
 
   DataType dtype;
-  if (!GetNodeAttr(n->def(), "dtype", &dtype).ok() ||
-      !DataTypeIsInteger(dtype)) {
+  if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
     return false;
   }
 
@@ -695,7 +694,7 @@
   }
 
   const TensorProto* proto = nullptr;
-  if (!GetNodeAttr(const_input->def(), "value", &proto).ok()) {
+  if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
     return false;
   }
 
@@ -935,8 +934,8 @@
     return absl::nullopt;
   }
 
-  string scope;
-  if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
+  const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
+  if (!scope.empty()) {
     return scope;
   }
 
@@ -970,8 +969,7 @@
     int effective_cluster_size =
         (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
 
-    bool has_functional_control_flow =
-        node->type_string() == "While" || node->IsIfNode();
+    bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
 
     absl::optional<DeadnessPredicate> deadness_predicate;
     if (deadness_analysis_) {
@@ -1000,7 +998,7 @@
     bool is_xla_compile_attr_true = false;
 
     bool xla_compile_attr;
-    if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) {
+    if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
       is_xla_compile_attr_true |= xla_compile_attr;
     }
 
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index cbe60b0..e056ecd 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -52,7 +52,7 @@
   std::unordered_map<string, string> ids;
   for (Node* node : graph.nodes()) {
     string cluster;
-    if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
+    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
       CHECK(!cluster.empty());
       ids[node->name()] = cluster;
     }
diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc
index b878f05..932e076 100644
--- a/tensorflow/compiler/jit/node_matchers.cc
+++ b/tensorflow/compiler/jit/node_matchers.cc
@@ -135,7 +135,7 @@
 
     if (constant_value) {
       const TensorProto* proto = nullptr;
-      if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
+      if (!TryGetNodeAttr(node->def(), "value", &proto)) {
         if (listener->IsInterested()) {
           *listener << "\ncould not find \"value\" attribute in node";
         }
diff --git a/tensorflow/compiler/jit/xla_activity.proto b/tensorflow/compiler/jit/xla_activity.proto
index 1edde32..50bfb29 100644
--- a/tensorflow/compiler/jit/xla_activity.proto
+++ b/tensorflow/compiler/jit/xla_activity.proto
@@ -94,3 +94,27 @@
   // Total microseconds spent in (re-)compiling this cluster so far.
   int64 cumulative_compile_time_us = 4;
 }
+
+// LINT.IfChange
+//
+// Used for logging situations seen in Tensorflow models being optimized that
+// are known to not perform well with XLA.
+//
+// Next ID: 3
+message XlaOptimizationRemark {
+  // Next ID: 6
+  enum Warning {
+    NONE = 0;
+    INACCURATE_OPERATION = 1;
+    SLOW_OPERATION = 2;
+    UNIMPLEMENTED_OPERATION = 3;
+    SLOW_IMAGE_RESIZE_DIMENSIONS = 4;
+    MEGAMORPHIC_FUNCTION = 5;
+  }
+
+  Warning warning = 1;
+
+  // Information such as which node was the problem.
+  string debug_information = 2;
+}
+// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/compiler/jit/xla_activity_listener.h)
diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc
index 1f14cc9..a1ea6a6 100644
--- a/tensorflow/compiler/jit/xla_activity_listener.cc
+++ b/tensorflow/compiler/jit/xla_activity_listener.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 
 #include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
@@ -71,6 +72,21 @@
   });
 }
 
+Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark) {
+  VLOG(2) << "OptimizationRemark: " << optimization_remark.DebugString();
+  return ForEachListener([&](XlaActivityListener* listener) {
+    return listener->Listen(optimization_remark);
+  });
+}
+
+Status BroadcastOptimizationRemark(
+    XlaOptimizationRemark::Warning optimization_warning,
+    string debug_information) {
+  XlaOptimizationRemark remark;
+  remark.set_warning(optimization_warning);
+  remark.set_debug_information(std::move(debug_information));
+  return BroadcastOptimizationRemark(std::move(remark));
+}
 void RegisterXlaActivityListener(
     std::unique_ptr<XlaActivityListener> listener) {
   XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
diff --git a/tensorflow/compiler/jit/xla_activity_listener.h b/tensorflow/compiler/jit/xla_activity_listener.h
index 547181d..05328c8 100644
--- a/tensorflow/compiler/jit/xla_activity_listener.h
+++ b/tensorflow/compiler/jit/xla_activity_listener.h
@@ -27,6 +27,18 @@
 // Broadcast `jit_compilation_activity` to all the registered listeners.
 Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);
 
+// Broadcast `jit_compilation_activity` to all the registered listeners.
+Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark);
+
+// LINT.IfChange
+// Called after TensorFlow realizes possible lost performance. The parameters in
+// this should match all of the values in the XlaOptimizationRemark proto.
+Status BroadcastOptimizationRemark(
+    XlaOptimizationRemark::Warning optimization_warning,
+    string debug_information);
+
+// LINT.ThenChange(//tensorflow/compiler/jit/xla_activity.proto)
+
 // Various components of the system can subclass XlaActivityListener to
 // notifications on auto-clustering and JIT compilation events.
 //
@@ -41,6 +53,9 @@
   virtual Status Listen(
       const XlaJitCompilationActivity& jit_compilation_activity) = 0;
 
+  // Called after TensorFlow realizes possible lost performance.
+  virtual Status Listen(const XlaOptimizationRemark& optimization_remark) = 0;
+
   // Called at program exit in best-effort manner to give listeners a chance to
   // flush their state.
   //
diff --git a/tensorflow/compiler/jit/xla_activity_listener_test.cc b/tensorflow/compiler/jit/xla_activity_listener_test.cc
index 4d087e2..034adbf 100644
--- a/tensorflow/compiler/jit/xla_activity_listener_test.cc
+++ b/tensorflow/compiler/jit/xla_activity_listener_test.cc
@@ -43,6 +43,10 @@
     return Status::OK();
   }
 
+  Status Listen(const XlaOptimizationRemark& optimization_remark) override {
+    return Status::OK();
+  }
+
   ~TestListener() override {}
 
   const XlaAutoClusteringActivity& auto_clustering_activity() const {
diff --git a/tensorflow/compiler/jit/xla_activity_logging_listener.cc b/tensorflow/compiler/jit/xla_activity_logging_listener.cc
index a36bd3b..87e39a5 100644
--- a/tensorflow/compiler/jit/xla_activity_logging_listener.cc
+++ b/tensorflow/compiler/jit/xla_activity_logging_listener.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/core/platform/logger.h"
 
@@ -59,6 +60,23 @@
     return Status::OK();
   }
 
+  Status Listen(const XlaOptimizationRemark& optimization_remark) override {
+    if (!IsEnabled()) {
+      VLOG(3) << "Logging XlaJitCompilationActivity disabled";
+      return Status::OK();
+    }
+
+    if (Logger* logger = Logger::GetSingletonAsync()) {
+      VLOG(2) << "Logging XlaJitCompilationActivity";
+      VLOG(3) << optimization_remark.DebugString();
+      logger->LogProto(optimization_remark);
+    } else {
+      VLOG(2) << "Not logging: logger not ready yet.";
+    }
+
+    return Status::OK();
+  }
+
  private:
   bool IsEnabled() {
     static bool result = ComputeIsEnabled();
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 035a50e..093c356 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -19,6 +19,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -27,6 +28,7 @@
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
+#include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -301,6 +303,9 @@
       }
 
       if (is_megamorphic) {
+        BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
+                                    function.name())
+            .IgnoreError();
         VLOG(3) << "Not compiling cluster " << function.name()
                 << " because it is megamorphic.";
         return false;
@@ -346,6 +351,7 @@
 
     const uint64 compile_end_us = env->NowMicros();
     const uint64 compile_time_us = compile_end_us - compile_start_us;
+    metrics::UpdateXlaCompilationTime(compile_time_us);
     {
       mutex_lock lock(cluster_compile_stats_mu_);
       auto it = cluster_compile_stats_.find(function.name());
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 1d8b4be..be2038a 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -203,6 +203,8 @@
       device_ordinal_(options.device_ordinal),
       jit_device_name_(options.compilation_device_name),
       platform_(options.platform),
+      intra_op_parallelism_threads_(
+          session_options.config.intra_op_parallelism_threads()),
       use_multiple_streams_(options.use_multiple_streams),
       shape_representation_fn_(options.shape_representation_fn),
       allowed_devices_(options.allowed_devices) {
@@ -233,10 +235,13 @@
   // don't want to do it until we get a chance to hook the platform up
   // to a simulator.
 
+  xla::LocalClientOptions options;
+  options.set_platform(platform_)
+      .set_allowed_devices(allowed_devices_)
+      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
   // TODO(b/78468222): This can fail, at least when the backend is GPU and
   // there is no GPU on the host.
-  return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_)
-      .ValueOrDie();
+  return xla::ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie();
 }
 
 Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 51910c6..877580e 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -202,6 +202,8 @@
   const DeviceType jit_device_name_;
   // The platform for this device.
   se::Platform* const platform_;  // Not owned.
+  // Intra-op threads to spawn (from SessionOptions).
+  const int intra_op_parallelism_threads_;
   // Memory allocator associated with this device.
   Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr;  // Not owned.
 
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 247bb83..45cc779 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -53,11 +53,13 @@
     name = "tf-opt",
     deps = [
         ":tf_mlir_opt_main",
+        "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
     ],
 )
 
 tf_cc_binary(
     name = "tf-mlir-translate",
+    srcs = ["tf_mlir_translate_main.cc"],
     deps = [
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
@@ -66,12 +68,14 @@
         "//tensorflow/compiler/mlir/tensorflow:translate_registration",
         "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
         "//tensorflow/compiler/mlir/xla:xla_mlir_translate",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_proto_cc",
         "//tensorflow/stream_executor/lib",
         "@llvm//:support",
         "@local_config_mlir//:IR",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TranslateClParser",
         "@local_config_mlir//:Translation",
-        "@local_config_mlir//:tools/mlir-translate/mlir-translate",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 8aa78a2..260654a 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -162,11 +162,12 @@
         "utils/attribute_utils.cc",
     ],
     hdrs = [
+        "ir/quantization_traits.h",
         "ir/tfl_ops.h",
         "ir/tfl_traits.h",
+        "quantization/quantization_utils.h",
         "transforms/passes.h",
         "utils/attribute_utils.h",
-        "utils/quantization_utils.h",
     ],
     deps = [
         ":tensorflow_lite_ops_inc_gen",
@@ -188,12 +189,11 @@
 cc_library(
     name = "tensorflow_lite_quantization_utils",
     srcs = [
-        "utils/generated_op_quant_spec_getters.inc",
-        "utils/quantization_driver.cc",
-        "utils/quantization_utils.cc",
+        "quantization/quantization_driver.cc",
+        "quantization/quantization_utils.cc",
     ],
     hdrs = [
-        "utils/quantization_utils.h",
+        "quantization/quantization_utils.h",
     ],
     deps = [
         ":tensorflow_lite",
@@ -216,6 +216,7 @@
         "transforms/legalize_tf.cc",
         "transforms/lower_static_tensor_list.cc",
         "transforms/prepare_tf.cc",
+        "transforms/trim_functions_tf.cc",
     ],
     hdrs = [
         "transforms/passes.h",
@@ -249,10 +250,12 @@
     deps = [
         ":tensorflow_lite",
         ":validators",
+        "//tensorflow/compiler/mlir/tensorflow",
         "@llvm//:support",
         "@local_config_mlir//:Analysis",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
     ],
     alwayslink = 1,
@@ -265,6 +268,7 @@
         "transforms/post_quantize.cc",
         "transforms/prepare_quantize.cc",
         "transforms/quantize.cc",
+        "utils/generated_op_quant_spec_getters.inc",
     ],
     hdrs = [
         "transforms/passes.h",
@@ -297,6 +301,13 @@
     ],
 )
 
+filegroup(
+    name = "generated_op_quant_spec_getters",
+    srcs = [
+        "utils/generated_op_quant_spec_getters.inc",
+    ],
+)
+
 genrule(
     name = "op_quant_spec_getters_inc",
     srcs = [
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 5f460b4..b58e2b6 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -206,11 +206,13 @@
          op->getName().getStringRef() == "tf.Placeholder.input";
 }
 
-static bool IsConstOrInput(Operation* op) {
-  return (isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
-          isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op) || IsInput(op));
+static bool IsConst(Operation* op) {
+  return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
+         isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op);
 }
 
+static bool IsConstOrInput(Operation* op) { return IsConst(op) || IsInput(op); }
+
 template <typename T>
 static bool HasValidTFLiteType(Value* value, T& error_handler) {
   // None type is allowed to represent unspecified operands.
@@ -228,7 +230,7 @@
     return false;
   }
   if (auto* inst = value->getDefiningOp()) {
-    if (IsConstOrInput(inst) && !type.hasStaticShape()) {
+    if (IsInput(inst) && !type.hasStaticShape()) {
       return error_handler.emitError("should have static shape, got ")
                  << type.getShape(),
              false;
@@ -414,6 +416,10 @@
   // mapping.
   void InitializeNamesFromAttribute(FuncOp fn);
 
+  // Determines if the specified operation op's operand at operand_index
+  // is marked as a stateful operand.
+  bool IsStatefulOperand(mlir::Operation* op, int operand_index);
+
   // Returns a unique name for `op`.
   std::string UniqueName(mlir::Operation* op);
 
@@ -531,8 +537,18 @@
   // However, we output all known shapes for better round-tripping
   std::vector<int32_t> shape;
   if (auto* inst = value->getDefiningOp()) {
-    if (type.hasStaticShape()) {
-      auto shape_ref = type.getShape();
+    if (type.hasStaticShape() || IsConst(inst)) {
+      // Const op can have a result of dynamic shaped type (e.g. due to constant
+      // folding), but we can still derive the shape of a constant tensor
+      // for its attribute type.
+      llvm::ArrayRef<int64_t> shape_ref;
+      if (type.hasStaticShape()) {
+        shape_ref = type.getShape();
+      } else {
+        mlir::Attribute tensor_attr = inst->getAttr("value");
+        shape_ref = tensor_attr.getType().cast<TensorType>().getShape();
+      }
+
       auto is_out_of_range = [](int64_t dim) {
         return dim > std::numeric_limits<int32_t>::max();
       };
@@ -559,10 +575,20 @@
   } else {
     q_params = tflite::CreateQuantizationParameters(builder_);
   }
-
+  // Check if the value's uses includes an op and usage at an operand index
+  // marked as a stateful. If so, set the tensor's is_variable as true
+  // This is v1 ref variable semantics in the TFLite runtime.
+  bool is_variable = false;
+  for (auto& use : value->getUses()) {
+    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
+    if (is_variable) {
+      break;
+    }
+  }
   return tflite::CreateTensor(
-      builder_, builder_.CreateVector(shape), tflite_element_type, buffer_idx,
-      builder_.CreateString(name), q_params, /*is_variable=*/false);
+      builder_, builder_.CreateVector(shape), tflite_element_type,
+      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+      /*is_variable=*/is_variable);
 }
 
 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@@ -859,6 +885,25 @@
   }
 }
 
+bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
+  std::vector<int> operand_indices;
+  // TODO(b/138254427): When the bug is addressed, we'll be able to inspect
+  // for the presence of a specific OpTrait using mlir::Operation, without
+  // having to cast it to specific ops like below.
+  // Until then, when a new RNN/LSTM op is added to TFLite and has stateful
+  // tensors as operands, they will need to be added here as well.
+  if (auto tfl = llvm::dyn_cast<mlir::TFL::LSTMOp>(op)) {
+    operand_indices = tfl.GetStatefulOperands();
+  } else if (auto tfl =
+                 llvm::dyn_cast<mlir::TFL::UnidirectionalSequenceLSTMOp>(op)) {
+    operand_indices = tfl.GetStatefulOperands();
+  } else if (auto tfl =
+                 llvm::dyn_cast<mlir::TFL::UnidirectionalSequenceRNNOp>(op)) {
+    operand_indices = tfl.GetStatefulOperands();
+  }
+  return absl::c_find(operand_indices, operand_index) != operand_indices.end();
+}
+
 Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   InitializeNamesFromAttribute(fn);
   std::vector<BufferOffset<tflite::Tensor>> tensors;
@@ -878,6 +923,10 @@
     if (!tensor_or) return false;
     tensors.push_back(*tensor_or);
 
+    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
+    // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
+    // This does not seem to affect runtime behavior for RNN/LSTM, but would be
+    // good for reducing memory footprint.
     if (auto* inst = value->getDefiningOp()) {
       auto buffer_or = BuildBuffer(inst);
       if (!buffer_or) return false;
diff --git a/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md b/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md
deleted file mode 100755
index 74e4fc4..0000000
--- a/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md
+++ /dev/null
@@ -1,1606 +0,0 @@
-<!-- Autogenerated by mlir-tblgen; don't manually edit -->
-# Operation definition
-## tfl.abs (TFL::AbsOp)
-Absolute value operator
-
-### Description:
-
-Given a tensor `x`, this operation returns a tensor containing the absolute
-value of each element in `x`. For example, if x is an input element and y is
-an output element, this operation computes \\(y = |x|\\).
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.add_n (TFL::AddNOp)
-add_n operator
-
-### Description:
-
-Adds all input tensors element-wise.
-
-### Operands:
-1. `inputs`: tensor of 32-bit float or 32-bit integer values
-
-### Attributes:
-
-### Results:
-1. `sum`: tensor of 32-bit float or 32-bit integer values
-
-## tfl.add (TFL::AddOp)
-Addition operator
-
-### Description:
-
-Element-wise addition operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.average_pool_2d (TFL::AveragePool2DOp)
-Average_pool_2d operator
-
-### Description:
-
-Performs average-pooling operation on input.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `padding` | `StringAttr` | padding enum attribute |
-| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.batch_to_space_nd (TFL::BatchToSpaceNdOp)
-BatchToSpaceNd operator
-
-### Description:
-
-This operation reshapes the "batch" dimension 0 into space dimensions.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `block_shape`: tensor of 32-bit integer values
-1. `indices`: tensor of 32-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.ceil (TFL::CeilOp)
-Ceil operator
-
-### Description:
-
-Returns element-wise ceil value of the input.
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tfl.concatenation (TFL::ConcatenationOp)
-Concatenation operator
-
-### Description:
-
-Concatenates tensors along one dimension
-
-### Operands:
-1. `values`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `axis` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values
-
-## tfl.pseudo_const (TFL::ConstOp)
-Constant pseudo op.
-
-### Description:
-
-Represents a constant value in TensorFlow Lite dialect. This is not an
-actual operation and it will be lowered to buffer instead.
-
-The op is allowed to have all the same type of attributes as tf.Const does
-(e.g., opaque TF attributes are allowed).
-
-### Operands:
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `value` | `ElementsAttr` | constant vector/tensor attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.conv_2d (TFL::Conv2DOp)
-Convolution operator
-
-### Description:
-
-Performs convolution operation on inputs.
-
-Inputs:
-  `inputs[0]`: required: the input activation tensor
-  `inputs[1]`: required: the filter weight tensor
-  `inputs[2]`: optional: the bias tensor
-
-### Operands:
-1. `input`: tensor of any type values
-1. `filter`: tensor of any type values
-1. `bias`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-| `padding` | `StringAttr` | padding enum attribute |
-| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.cos (TFL::CosOp)
-Cosine operator
-
-### Description:
-
-Computes element-wise Cosine of input
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tfl.depthwise_conv_2d (TFL::DepthwiseConv2DOp)
-Depthwise-separable convolution operator
-
-### Description:
-
-Performs convolution operation on inputs.
-
-Inputs:
-  `inputs[0]`: required: the input activation tensor
-  `inputs[1]`: required: the filter weight tensor
-  `inputs[2]`: optional: the bias tensor
-
-### Operands:
-1. `input`: tensor of any type values
-1. `filter`: tensor of any type values
-1. `bias`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-| `padding` | `StringAttr` | padding enum attribute |
-| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `depth_multiplier` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.dequantize (TFL::DequantizeOp)
-Dequantize operator
-
-### Description:
-
-Converts quantized array of integers to floating-points according to the
-quantization parameters.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.div (TFL::DivOp)
-Division operator
-
-### Description:
-
-Element-wise division operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.elu (TFL::EluOp)
-Exponential Linear Unit operator
-
-### Description:
-
-Computes the exponential linear
-  f(x) -> exp(x) - 1 for x < 0, x for x >= 0.
-element-wise.
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.equal (TFL::EqualOp)
-Equal operator
-
-### Description:
-
-Returns the truth element of x == y element-wise
-
-### Operands:
-1. `x`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-1. `y`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.exp (TFL::ExpOp)
-Natural exponentiation operator
-
-### Description:
-
-Performs element-wise natural exponentiation operation on input.
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.expand_dims (TFL::ExpandDimsOp)
-Inserts a dimension of 1 into a tensor's shape.
-
-### Description:
-
-Given a tensor `input`, this operation inserts a dimension of 1 at the
-dimension index `axis` of `input`'s shape. The dimension index `axis` starts at
-zero; if you specify a negative number for `axis` it is counted backward from
-the end.
-
-This operation is useful if you want to add a batch dimension to a single
-element. For example, if you have a single image of shape `[height, width,
-channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
-which will make the shape `[1, height, width, channels]`.
-
-Other examples:
-
-```
-# 't' is a tensor of shape [2]
-shape(expand_dims(t, 0)) ==> [1, 2]
-shape(expand_dims(t, 1)) ==> [2, 1]
-shape(expand_dims(t, -1)) ==> [2, 1]
-
-# 't2' is a tensor of shape [2, 3, 5]
-shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
-shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
-shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
-```
-
-This operation requires that:
-
-`-1-input.dims() <= dim <= input.dims()`
-
-This operation is related to `squeeze()`, which removes dimensions of
-size 1.
-
-### Operands:
-1. `input`: tensor of any type values
-1. `dim`: tensor of any integer type
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.fake_quant (TFL::FakeQuantOp)
-FakeQuant operator
-
-### Description:
-
-Fake-quantize the 'inputs' tensor of type float via float scalars min and
-max to 'outputs' tensor of same shape as inputs.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `minmax` | `ArrayAttr` | min-max range pair attribute |
-| `num_bits` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `narrow_range` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.fill (TFL::FillOp)
-Fill the tensor with given value.
-
-### Description:
-
-Fill the tensor with given value.
-
-### Operands:
-1. `dims`: tensor of 32/64-bit integer values
-1. `value`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `res`: tensor of any type values
-
-## tfl.floor_div (TFL::FloorDivOp)
-Floor div operator
-
-### Description:
-
-Element-wise floor div operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.floor_mod (TFL::FloorModOp)
-Division reminder
-
-### Description:
-
-Element-wise division reminder operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.floor (TFL::FloorOp)
-Floor operator
-
-### Description:
-
-Returns element-wise floor value of the input.
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tfl.fully_connected (TFL::FullyConnectedOp)
-Fully connected op
-
-### Description:
-
-
-### Operands:
-1. `input`: tensor of 32-bit float values
-1. `filter`: tensor of 32-bit float values
-1. `bias`: tensor of 32-bit float values or none type
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-| `weights_format` | `StringAttr` | fully connected options weights format attribute |
-| `keep_num_dims` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float values
-
-## tfl.gather (TFL::GatherOp)
-Gather operator
-
-### Description:
-
-Gather slices from `params` axis `axis` according to `indices`.
-
-### Operands:
-1. `params`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer or TFLite string type values
-1. `indices`: tensor of 32-bit integer or 64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `axis` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or TFLite string type values
-
-## tfl.greater_equal (TFL::GreaterEqualOp)
-Greater_equal operator
-
-### Description:
-
-Element-wise greater_equal operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.greater (TFL::GreaterOp)
-Greater operator
-
-### Description:
-
-Element-wise greater operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.pseudo_input (TFL::InputOp)
-Input pseudo operator
-
-### Description:
-
-Takes one of the function arguments as input and returns it as result.  This
-is a NOP and is used to attach attributes such as tensor name to function
-arguments.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.leaky_relu (TFL::LeakyReluOp)
-Leaky Relu operator
-
-### Description:
-
-Element-wise Leaky ReLU operator
-  x -> x >= 0 ? x : (alpha * x)
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.less_equal (TFL::LessEqualOp)
-Less_equal operator
-
-### Description:
-
-Element-wise less_equal operation.
-
-### Operands:
-1. `lhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-1. `rhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.less (TFL::LessOp)
-Less operator
-
-### Description:
-
-Element-wise less operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.log (TFL::LogOp)
-Natural logarithm operator
-
-### Description:
-
-Performs element-wise natural logarithm operation on input.
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.log_softmax (TFL::LogSoftmaxOp)
-Log softmax operator
-
-### Description:
-
-Computes element-wise log softmax activations with the following formula
-
-  input - log(reduce_sum(exp(input), dim))
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.logical_and (TFL::LogicalAndOp)
-Logical AND operator
-
-### Description:
-
-Element-wise logical AND operation.
-
-### Operands:
-1. `lhs`: tensor of 1-bit integer values
-1. `rhs`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.logical_not (TFL::LogicalNotOp)
-Logical NOT operator
-
-### Description:
-
-Element-wise logical NOT operation.
-
-### Operands:
-1. `lhs`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.logical_or (TFL::LogicalOrOp)
-Logical OR operator
-
-### Description:
-
-Element-wise logical OR operation.
-
-### Operands:
-1. `lhs`: tensor of 1-bit integer values
-1. `rhs`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.logistic (TFL::LogisticOp)
-Logistic operator
-
-### Description:
-
-Computes element-wise Sigmoid of input
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tfl.max_pool_2d (TFL::MaxPool2DOp)
-Max Pool 2D op
-
-### Description:
-
-Performs max pool 2D on input.
-
-Inputs:
-  `inputs[0]`: required: the input tensor
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `padding` | `StringAttr` | padding enum attribute |
-| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.maximum (TFL::MaximumOp)
-Max operator
-
-### Description:
-
-Element-wise max operation.
-
-### Operands:
-1. `lhs`: tensor of floating-point or 32/64-bit integer values
-1. `rhs`: tensor of floating-point or 32/64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `max`: tensor of floating-point or 32/64-bit integer values
-
-## tfl.mean (TFL::MeanOp)
-Mean operator
-
-### Description:
-
-Computes the mean of elements across dimensions of a tensor.
-Reduces input_tensor along the dimensions given in axis.
-Unless keepdims is true, the rank of the tensor is reduced by 1 for
-each entry in axis. If keepdims is true, the reduced dimensions are retained
-with length 1.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `axis`: tensor of 32-bit integer or 64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-
-## tfl.minimum (TFL::MinimumOp)
-Min operator
-
-### Description:
-
-Element-wise min operation.
-
-### Operands:
-1. `lhs`: tensor of floating-point or 32/64-bit integer values
-1. `rhs`: tensor of floating-point or 32/64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `min`: tensor of floating-point or 32/64-bit integer values
-
-## tfl.mul (TFL::MulOp)
-Multiplication operator
-
-### Description:
-
-Element-wise multiplication operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.neg (TFL::NegOp)
-Negation operator
-
-### Description:
-
-Computes element-wise negation of input
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.not_equal (TFL::NotEqualOp)
-Not_equal operator
-
-### Description:
-
-Element-wise not_equal operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 1-bit integer values
-
-## tfl.pack (TFL::PackOp)
-Packs a list of tensors along a dimension into one tensor
-
-### Description:
-
-Packs a list of `values_count` rank-`R` tensors into one rank-`(R+1)`
-tensor.
-
-Packs the `values_count` tensors in `values` into a tensor with rank one
-higher than each tensor in `values`, by packing them along the `axis`
-dimension.
-
-Given a list of tensors of shape `(A, B, C)`;
-
-if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
-if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
-Etc.
-
-For example:
-
-```
-# 'x' is [1, 4]
-# 'y' is [2, 5]
-# 'z' is [3, 6]
-pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]]  # Pack along first dim.
-pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
-```
-
-This is the opposite of `unpack`.
-
-### Operands:
-1. `values`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `values_count` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `axis` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.pad (TFL::PadOp)
-Padding operator
-
-### Description:
-
-This operation pads a `input` with zeros according to the `paddings` you
-specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
-the rank of `input`. For each dimension D of `input`, `paddings[D, 0]`
-indicates how many zeros to add before the contents of `input` in that
-dimension, and `paddings[D, 1]` indicates how many zeros to add after the
-contents of `input` in that dimension.
-
-The padded size of each dimension D of the output is:
-
-  `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-
-For example:
-
-```
-# 't' is [[1, 1], [2, 2]]
-# 'paddings' is [[1, 1], [2, 2]]
-# rank of 't' is 2
-pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-                      [0, 0, 1, 1, 0, 0]
-                      [0, 0, 2, 2, 0, 0]
-                      [0, 0, 0, 0, 0, 0]]
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `padding`: tensor of 32/64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.padv2 (TFL::PadV2Op)
-Padding operator v2
-
-### Description:
-
-This operation pads a `input` according to the `paddings` and
-`constant_values` you specify. `paddings` is an integer tensor with shape
-`[Dn, 2]`, where n is the rank of `input`. For each dimension D of `input`,
-`paddings[D, 0]` indicates how many zeros to add before the contents of
-`input` in that dimension, and `paddings[D, 1]` indicates how many zeros to
-add after the contents of `input` in that dimension. `constant_values` is a
-scalar tensor of the same type as `input` that indicates the value to use
-for padding `input`.
-
-The padded size of each dimension D of the output is:
-
-  `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-
-For example:
-
-```
-# 't' is [[1, 1], [2, 2]]
-# 'paddings' is [[1, 1], [2, 2]]
-# rank of 't' is 2
-pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-                      [0, 0, 1, 1, 0, 0]
-                      [0, 0, 2, 2, 0, 0]
-                      [0, 0, 0, 0, 0, 0]]
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `padding`: tensor of 32/64-bit integer values
-1. `constant_values`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.pow (TFL::PowOp)
-Power operator
-
-### Description:
-
-Element-wise power operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.pseudo_qconst (TFL::QConstOp)
-Quantized constant pseudo op
-
-### Description:
-
-Represents a quantized constant value in TensorFlow Lite dialect. This is
-not an actual operation and it will be lowered to buffer instead. The
-quantization parameters are stored as a type attribute in this constant.
-
-### Operands:
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `qtype` | `TypeAttr` | Tensor type attribute attribute |
-| `value` | `ElementsAttr` | constant vector/tensor attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.quantize (TFL::QuantizeOp)
-Quantize operator
-
-### Description:
-
-Converts floating point tensors to quantized integer tensors according to
-the quantization parameters defined in the type attribute.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `qtype` | `TypeAttr` | Tensor type attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.range (TFL::RangeOp)
-Range operator
-
-### Description:
-
-Returns a 1D tensor defined by a sequence from `start` to `limit` with
-a given `delta`.
-
-### Operands:
-1. `start`: tensor of any type values
-1. `limit`: tensor of any type values
-1. `delta`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `result`: tensor of any type values
-
-## tfl.rank (TFL::RankOp)
-Rank operator.
-
-### Description:
-
-Returns the rank of a tensor.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any integer type
-
-## tfl.reduce_max (TFL::ReduceMaxOp)
-Max-reduction operator
-
-### Description:
-
-Computes the max reduction along the specified axes
-
-### Operands:
-1. `input`: tensor of any type values
-1. `axes`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. &laquo;unnamed&raquo;: tensor of any type values
-
-## tfl.reduce_min (TFL::ReduceMinOp)
-Min-reduction operator
-
-### Description:
-
-Computes the min reduction along the specified axes
-
-### Operands:
-1. `input`: tensor of any type values
-1. `axes`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. &laquo;unnamed&raquo;: tensor of any type values
-
-## tfl.relu6 (TFL::Relu6Op)
-Relu6 operator
-
-### Description:
-
-Element-wise Relu6 operator
-  x -> max(0, min(6, x))
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.relu (TFL::ReluOp)
-Relu operator
-
-### Description:
-
-Element-wise Relu operator
-  x -> max(0, x)
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.reshape (TFL::ReshapeOp)
-Reshape operator
-
-### Description:
-
-Produces a tensor with the same values but different static shape defined
-by the output type.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `new_shape` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.resize_bilinear (TFL::ResizeBilinearOp)
-ResizeBilinear Op
-
-### Description:
-
-Resize `images` to `size` using bilinear interpolation.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 32-bit integer values
-1. `size`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `align_corners` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float values
-
-## tfl.reverse_v2 (TFL::ReverseV2Op)
-ReverseV2 Operator
-
-### Description:
-
-Reverses specific dimensions of a tensor.
-
-Given a tensor, and a int32/int64 tensor axis representing the set
-of dimensions of tensor to reverse.
-This operation reverses each dimension i for
-which there exists j s.t. axis[j] == i.
-
-Args:
-  tensor: A Tensor. Must be one of the following types:
-  int16, int32, int64, float32 Up to 8-D.
-
-  axis: A Tensor. Must be one of the following types: int32, int64.
-  with only 1 element which is the axis index.
-  TODO: Add support for multiple elements.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-1. `axis`: tensor of 32-bit integer or 64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer values
-
-## tfl.rsqrt (TFL::RsqrtOp)
-Reciprocal of square root operator
-
-### Description:
-
-Computes element-wise reverse square root of input
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.select (TFL::SelectOp)
-Select operator
-
-### Description:
-
-Select values of 'x' if the corresponding value of 'condition' is true or
-the value of 'y' if false. There are valid condition input sizes:
-
-1. Either the same shape (in which case the select is elementwise), or
-2. condition must be Rank 1 and match over the first dimension.
-
-### Operands:
-1. `condition`: tensor of 1-bit integer values
-1. `x`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values
-1. `y`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.shape (TFL::ShapeOp)
-Shape operator
-
-### Description:
-
-Returns the shape of a tensor.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `out_type` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.sin (TFL::SinOp)
-Sine operator
-
-### Description:
-
-Computes element-wise Sine of input
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tfl.softmax (TFL::SoftmaxOp)
-Softmax operator
-
-### Description:
-
-Computes element-wise softmax activiations with the following formula
-
-  exp(input) / tf.reduce_sum(exp(input * beta), dim)
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `beta` | `FloatAttr` | 32-bit float attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.space_to_batch_nd (TFL::SpaceToBatchNdOp)
-SpaceToBatchNd operator
-
-### Description:
-
-This operation reshapes space dimensions into the "batch" dimension 0
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `block_shape`: tensor of 32-bit integer values
-1. `paddings`: tensor of 32-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.split (TFL::SplitOp)
-Splits a tensor into `num_split` tensors along one dimension.
-
-### Description:
-
-Splits the `value` tensor along `split_dim` into a number of sub-tensors
-with same shape as the original one, except for `split_dim`. Same as
-tf.Split.
-
-### Operands:
-1. `split_dim`: tensor of 32-bit integer values
-1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.split_v (TFL::SplitVOp)
-Splits a tensor into `num_split` tensors along one dimension.
-
-### Description:
-
-Splits the `value` tensor along `split_dim` into a number of sub-tensors
-with same shape as the original one, except for `split_dim`. The grouping
-of the resultant sub-tensors is decided by `size-splits`. Same as tf.SplitV.
-
-### Operands:
-1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-1. `size_splits`: tensor of 32-bit integer values
-1. `split_dim`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values
-
-## tfl.sqrt (TFL::SqrtOp)
-Square root operator
-
-### Description:
-
-Computes element-wise Square root of input
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.square (TFL::SquareOp)
-Square operator
-
-### Description:
-
-Computes element-wise Square of input
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.squared_difference (TFL::SquaredDifferenceOp)
-Squared difference operator
-
-### Description:
-
-Element-wise squared difference operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.squeeze (TFL::SqueezeOp)
-Removes dimensions of size 1 from the shape of a tensor.
-
-### Description:
-
-Given a tensor `input`, this operation returns a tensor of the same type with
-all dimensions of size 1 removed. If you don't want to remove all size 1
-dimensions, you can remove specific size 1 dimensions by specifying
-`axis`.
-
-For example:
-
-```
-# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
-shape(squeeze(t)) ==> [2, 3]
-```
-
-Or, to remove specific size 1 dimensions:
-
-```
-# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
-shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
-```
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.strided_slice (TFL::StridedSliceOp)
-StridedSlice Op
-
-### Description:
-
-Return a strided slice from `input`.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-1. `begin`: tensor of 32-bit integer values
-1. `end`: tensor of 32-bit integer values
-1. `strides`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `begin_mask` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `end_mask` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `ellipsis_mask` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `new_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `shrink_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values
-
-## tfl.sub (TFL::SubOp)
-Subtraction operator
-
-### Description:
-
-Element-wise subtraction operation.
-
-### Operands:
-1. `lhs`: tensor of any type values
-1. `rhs`: tensor of any type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.sum (TFL::SumOp)
-Sum operator
-
-### Description:
-
-Computes the sum reduction along the specified axes
-
-### Operands:
-1. `input`: tensor of any type values
-1. `axes`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. &laquo;unnamed&raquo;: tensor of any type values
-
-## tfl.tanh (TFL::TanhOp)
-Hyperbolic tangent operator
-
-### Description:
-
-Computes element-wise Hyperbolic tangent of input
-
-### Operands:
-1. `x`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.tile (TFL::TileOp)
-Tile operator.
-
-### Description:
-
- Constructs a tensor by tiling a given tensor.
-
-This operation creates a new tensor by replicating input
-multiples times. The output tensor's i'th dimension has
-input.dims(i) * multiples[i] elements, and the values of input
-are replicated multiples[i] times along the 'i'th dimension.
-For example, tiling [a b c d] by [2] produces [a b c d a b c d].
-
-### Operands:
-1. `input`: tensor of any type values
-1. `multiples`: tensor of 32/64-bit integer values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.topk_v2 (TFL::TopKV2Op)
-TopK operator
-
-### Description:
-
-Returns the top `k` largest element along each last dimensional slice of
-`input` and the indices of values within the last dimension of the input
-tensor.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values
-1. `k`: tensor of 32-bit integer values
-
-### Attributes:
-
-### Results:
-1. `values`: tensor of any type values
-1. `indices`: tensor of 32-bit integer values
-
-## tfl.transpose (TFL::TransposeOp)
-Transpose operator
-
-### Description:
-
-Returns the Transpose of x
-
-### Operands:
-1. `x`: tensor of any type values
-1. `perm`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of any type values
-
-## tfl.unidirectional_sequence_lstm (TFL::UnidirectionalSequenceLSTMOp)
-Unidirectional sequence lstm operator
-
-### Description:
-
-A recurrent neural network specified by an LSTM cell. This Op supports
-unrolling the input along the time or batch dimensions, and
-implements the following operation for
-each element in the sequence s = 1...sequence_length:
-  outputs[s] = state = activation(LSTMOp(inputs[s]))
-
-where LSTMOp is LSTM TF Lite Op and the “activation” is the function passed
-as the “fused_activation_function” argument (if not “NONE”).
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer values
-1. `input_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `input_to_forget_weights`: tensor of 32-bit float or 8-bit integer values
-1. `input_to_cell_weights`: tensor of 32-bit float or 8-bit integer values
-1. `input_to_output_weights`: tensor of 32-bit float or 8-bit integer values
-1. `recurrent_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `recurrent_to_forget_weights`: tensor of 32-bit float or 8-bit integer values
-1. `recurrent_to_cell_weights`: tensor of 32-bit float or 8-bit integer values
-1. `recurrent_to_output_weights`: tensor of 32-bit float or 8-bit integer values
-1. `cell_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `cell_to_forget_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `cell_to_output_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `input_gate_bias`: tensor of 32-bit float values or none type
-1. `forget_gate_bias`: tensor of 32-bit float values
-1. `cell_bias`: tensor of 32-bit float values
-1. `output_gate_bias`: tensor of 32-bit float values
-1. `projection_weights`: tensor of 32-bit float or 8-bit integer values or none type
-1. `projection_bias`: tensor of 32-bit float values or none type
-1. `input_activation_state`: stateful tensor
-1. `input_cell_state`: stateful tensor
-1. `input_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type
-1. `forget_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type
-1. `cell_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type
-1. `output_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `fused_activation_function` | `StringAttr` | fused activation enum attribute |
-| `cell_clip` | `FloatAttr` | 32-bit float attribute attribute |
-| `proj_clip` | `FloatAttr` | 32-bit float attribute attribute |
-| `time_major` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `output`: tensor of any type values
-
-## tfl.unpack (TFL::UnpackOp)
-Unpacks a tensor along a dimension into multiple tensors
-
-### Description:
-
-Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
-
-Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
-For example, given a tensor of shape `(A, B, C, D)`;
-
-If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]`
-  and each tensor in `output` will have shape `(B, C, D)`. (Note that the
-  dimension unpacked along is gone, unlike `split`).
-
-If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]`
-  and each tensor in `output` will have shape `(A, C, D)`.
-Etc.
-
-This is the opposite of `pack`.
-
-### Operands:
-1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num` | `IntegerAttr` | 32-bit integer attribute attribute |
-| `axis` | `IntegerAttr` | 32-bit integer attribute attribute |
-
-### Results:
-1. `outputs`: tensor of 32-bit float or 8-bit integer or 32-bit integer values
-
-## tfl.zeros_like (TFL::ZerosLikeOp)
-ZerosLike operator
-
-### Description:
-
-Returns a tensor of zeros with the same shape and type as the input tensor.
-
-### Operands:
-1. `input`: tensor of any type values
-
-### Attributes:
-
-### Results:
-1. `output`: tensor of any type values
-
diff --git a/tensorflow/compiler/mlir/lite/ir/quantization_traits.h b/tensorflow/compiler/mlir/lite/ir/quantization_traits.h
new file mode 100644
index 0000000..a9cd13b
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/ir/quantization_traits.h
@@ -0,0 +1,127 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_QUANTIZATION_TRAITS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_IR_QUANTIZATION_TRAITS_H_
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace TFL {
+
+using QuantizedType = mlir::quant::QuantizedType;
+using UniformQuantizedType = mlir::quant::UniformQuantizedType;
+
+// The base class that all the quantization related OpTrait implements.
+template <typename ConcreteType, template <typename> class TraitType>
+struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
+  static bool IsBias(int index) { return false; }
+  static bool IsQuantizable() { return true; }
+};
+
+// This class provides the API for TFL ops that requires same input and output
+// scale as the quantization results. This is used as a trait like this:
+//
+//   class TransposeOp
+//       : public Op<TransposeOp, OpTrait::TFL::SameOperandsAndResultsScale> {
+//
+template <typename ConcreteType>
+class SameOperandsAndResultsScale
+    : public QuantizationSpecTraitBase<ConcreteType,
+                                       SameOperandsAndResultsScale> {};
+
+// This class provides the API for TFL ops that has a fixed output value range.
+// This is used as a trait like this:
+//
+//   class SoftmaxOp
+//       : public Op<SoftmaxOp,
+//           OpTrait::TFL::FixedResultUniformScale<
+//               8, -128, 390625, -8, 0, 255, false>::Impl> {
+//
+// TODO(fengliuai): create a better way to epxress floating point scale in the
+// template argument list.
+template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
+          int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
+class FixedResultUniformScale {
+ public:
+  template <typename ConcreteType>
+  class Impl
+      : public QuantizationSpecTraitBase<
+            ConcreteType, FixedResultUniformScale<
+                              BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
+                              StorageTypeMin, StorageTypeMax, Sign>::Impl> {
+   public:
+    QuantizedType GetResultQuantizedType(int index) {
+      auto op = this->getOperation();
+      auto result_type =
+          op->getResult(index)->getType().template cast<TensorType>();
+      Builder builder(op->getContext());
+      IntegerType storage_type = builder.getIntegerType(BitWidth);
+      const double scale = static_cast<double>(ScaleMantissa) *
+                           ::pow(10.0, static_cast<double>(ScaleExp));
+      return UniformQuantizedType::getChecked(
+          Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
+          StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
+    }
+  };
+};
+
+// This class provides the API for TFL ops that has input as bias. This is used
+// as a trait like this:
+//
+//   class Conv2DOp
+//       : public Op<Conv2DOp, OpTrait::TFL::AccumulatorScale<2, 0, 1>::Impl> {
+//
+// TODO(fengliuai): supports a configurable accumulator bit width.
+template <int Bias, int... Operands>
+class AccumulatorUniformScale {
+ public:
+  template <typename ConcreteType>
+  class Impl
+      : public QuantizationSpecTraitBase<
+            ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
+   public:
+    // Whether the index-th operand is a bias.
+    static bool IsBias(int index) { return index == Bias; }
+
+    // Returns the indexes of all the non-bias operands.
+    static std::vector<int> GetAllNonBiasOperands() {
+      return std::vector<int>({Operands...});
+    }
+  };
+};
+
+// This class provides the API for TFL ops that shouldn't be quantized. This is
+// used as a trait like this:
+//
+//   class LessOp : public Op<LessOp, OpTrait::TFL::NoQuantizableResult> {
+//
+template <typename ConcreteType>
+class NoQuantizableResult
+    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
+ public:
+  static bool IsQuantizable() { return false; }
+};
+
+}  // namespace TFL
+}  // namespace OpTrait
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_IR_QUANTIZATION_TRAITS_H_
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 23d1388..d3254bc 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -15,6 +15,11 @@
 
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 
+#include <cstdint>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
 #include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
@@ -22,6 +27,7 @@
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
@@ -54,13 +60,21 @@
   return std::equal(a.rbegin(), a.rend(), b.rbegin());
 }
 
+// Returns true if it is a shaped type of f32 elements.
+inline bool IsF32ShapedType(Type t) {
+  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
+    return shaped_type.getElementType().isF32();
+  }
+  return false;
+}
+
 // Performs const folding `calculate` with broadcast behavior on the two
 // attributes `operand1` and `operand2` and returns the result if possible.
 // The two operands are expected to both be scalar values.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
-              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
                                         Attribute operand2,
                                         const CalculationT &calculate) {
@@ -84,7 +98,7 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
-              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute ConstFoldBinaryOpSplatSplat(Type result_type, Attribute operand1,
                                       Attribute operand2,
                                       const CalculationT &calculate) {
@@ -106,13 +120,13 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
-              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute ConstFoldBinaryOpDenseSplat(Type result_type, Attribute operand1,
                                       Attribute operand2,
                                       const CalculationT &calculate) {
   auto lhs = operand1.cast<DenseElementsAttr>();
 
-  // TODO: Support broadcast behavior
+  // TODO(b/139192933): Support broadcast behavior
   if (lhs.getType() != result_type || operand2.getType() != result_type)
     return {};
 
@@ -139,7 +153,7 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
-              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1,
                                       Attribute operand2,
                                       const CalculationT &calculate) {
@@ -203,7 +217,7 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
-              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                             Attribute operand2, const CalculationT &calculate,
                             bool is_commutative) {
@@ -249,8 +263,9 @@
 /// `intCalculate` is chosen to conduct the calculate.
 Attribute ConstFoldBinaryOp(
     Type result_type, ArrayRef<Attribute> operands,
-    std::function<APFloat(APFloat, APFloat)> float_calculate,
-    std::function<APInt(APInt, APInt)> int_calculate, bool is_commutative) {
+    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
+    llvm::function_ref<APInt(APInt, APInt)> int_calculate,
+    bool is_commutative) {
   // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is
   // represented as tensor<f32>. So we are only handling tensor types here.
   auto type = result_type.dyn_cast<ShapedType>();
@@ -269,6 +284,32 @@
   return {};
 }
 
+/// Performs const folding a attributes `operand` and returns the result if
+/// possible.
+/// The function currently asserts that the `result_type` to be a f32 tensor
+/// type.
+/// TODO: Extend this function to handle integral tensor for ops like
+/// "tfl.logical_not".
+Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
+                           llvm::function_ref<APFloat(APFloat)> calculate) {
+  assert(IsF32ShapedType(result_type));
+  auto result_shape_type = result_type.cast<ShapedType>();
+
+  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
+    SmallVector<APFloat, 16> new_values;
+    const int num_elements = result_shape_type.getNumElements();
+    new_values.reserve(num_elements);
+
+    for (APFloat old_value : dense_elements.getValues<APFloat>()) {
+      new_values.push_back(calculate(old_value));
+    }
+
+    return DenseElementsAttr::get(result_shape_type, new_values);
+  }
+
+  return {};
+}
+
 void buildComparisonBinOp(Builder *builder, OperationState *result, Value *lhs,
                           Value *rhs) {
   auto result_type =
@@ -453,12 +494,19 @@
   // Remove identity reshape.
   if (getType() == getOperand()->getType()) return getOperand();
 
+  // Constant folding
+  assert(operands.size() == 1);
+  if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+    auto result_shape_type = getType().cast<ShapedType>();
+    return dense_elements.reshape(result_shape_type);
+  }
+
   return nullptr;
 }
 
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context));
+  results.insert<RemoveAdjacentReshape>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -543,7 +591,7 @@
 
 void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.push_back(llvm::make_unique<DropFakeQuant>(context));
+  results.insert<DropFakeQuant>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -594,11 +642,337 @@
 }
 
 //===----------------------------------------------------------------------===//
+// UnidirectionalSequenceRNNOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
+  auto operands = op.GetStatefulOperands();
+  if (operands.size() == 1 && operands[0] == 4) {
+    return success();
+  }
+  return op.emitError(
+      "UnidirectionalSequenceRNNOp expected to have one stateful operand");
+}
+
+//===----------------------------------------------------------------------===//
+// AbsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// SinOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat {
+    float f = value.convertToFloat();
+    float result = std::sin(f);
+    return APFloat(result);
+  };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// CosOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat {
+    float f = value.convertToFloat();
+    float result = std::cos(f);
+    return APFloat(result);
+  };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// LogOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat {
+    float f = value.convertToFloat();
+    float result = std::log(f);
+    return APFloat(result);
+  };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// SqrtOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat {
+    float f = value.convertToFloat();
+    float result = std::sqrt(f);
+    return APFloat(result);
+  };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// RsqrtOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat {
+    float f = value.convertToFloat();
+    float result = 1.f / std::sqrt(f);
+    return APFloat(result);
+  };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// SquareOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat { return value * value; };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1);
+  auto result_type = getType().cast<ShapedType>();
+  if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
+    auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
+    return DenseElementsAttr::get(result_type, {rank});
+  }
+
+  // Also fold if `input` has a known rank.
+  auto input_type = input()->getType().cast<ShapedType>();
+  // Do not fold if rank is zero because the TFLite converter doesn't
+  // distinguish between unranked input and scalar input due to b/138865275.
+  // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
+  // predicate and fold the op when rank is zero.
+  if (input_type.hasRank() && input_type.getRank() != 0) {
+    auto rank = static_cast<int32_t>(input_type.getRank());
+    return DenseElementsAttr::get(result_type, {rank});
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// ConstOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+
+  // Return the held attribute value.
+  return value();
+}
+
+//===----------------------------------------------------------------------===//
+// RangeOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
+// Template parameter `FloatOrInt` must be standard C integer or floating-point
+// types.
+template <typename FloatOrInt>
+int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
+  // Refer to the implementation in
+  // tensorflow/lite/kernels/range.cc.
+  return std::is_integral<FloatOrInt>::value
+             ? ((std::abs(limit - start) + std::abs(delta) - 1) /
+                std::abs(delta))
+             : std::ceil(std::abs((limit - start) / delta));
+}
+
+// Builds a constant range tensor of `result_elem_type` elements.
+// Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
+// mlir::FloatAttr.
+template <typename FloatOrIntAtrr>
+DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
+                                        FloatOrIntAtrr start_attr,
+                                        FloatOrIntAtrr delta_attr) {
+  using ValueType = typename FloatOrIntAtrr::ValueType;  // APInt or APFloat
+  ValueType start = start_attr.getValue();
+  ValueType delta = delta_attr.getValue();
+
+  SmallVector<ValueType, 16> new_values;
+  new_values.reserve(num_elements);
+  ValueType new_value = start;
+  for (int i = 0; i < num_elements; ++i) {
+    new_values.push_back(new_value);
+    new_value = new_value + delta;
+  }
+  // Result is always a 1-D tensor.
+  auto new_result_type =
+      RankedTensorType::get({num_elements}, result_elem_type);
+  return DenseElementsAttr::get(new_result_type, new_values);
+}
+}  // namespace
+
+OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 3);
+  auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
+  auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
+  auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
+  if (start_tensor && limit_tensor && delta_tensor) {
+    // Operands should all be scalars
+    assert(start_tensor.getType().getRank() == 0 &&
+           limit_tensor.getType().getRank() == 0 &&
+           delta_tensor.getType().getRank() == 0);
+    Type elem_type = getType().cast<ShapedType>().getElementType();
+    if (elem_type.isa<IntegerType>()) {
+      auto start_attr = start_tensor.getValue({}).cast<IntegerAttr>();
+      auto limit_attr = limit_tensor.getValue({}).cast<IntegerAttr>();
+      auto delta_attr = delta_tensor.getValue({}).cast<IntegerAttr>();
+      const int num_elements = GetLengthOfRange(
+          start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
+      return BuildConstRangeTensor(elem_type, num_elements, start_attr,
+                                   delta_attr);
+    } else if (elem_type.isa<FloatType>()) {
+      auto start_attr = start_tensor.getValue({}).cast<FloatAttr>();
+      auto limit_attr = limit_tensor.getValue({}).cast<FloatAttr>();
+      auto delta_attr = delta_tensor.getValue({}).cast<FloatAttr>();
+      const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
+                                                limit_attr.getValueAsDouble(),
+                                                delta_attr.getValueAsDouble());
+      return BuildConstRangeTensor(elem_type, num_elements, start_attr,
+                                   delta_attr);
+    }
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Computes the permutation of a constant `input_tensor` according to `perm`.
+// The function recursively traverses the dimensions of the output tensor in
+// a row-major order and writes the value in the output tensor into
+// `new_values`.
+void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
+                        ArrayRef<int64_t> output_shape, int num_dimensions,
+                        int output_axis, std::vector<uint64_t> *input_indices,
+                        std::vector<Attribute> *new_values) {
+  // Refer to the implementation of `Transpose` function in
+  // tensorflow/lite/kernels/internal/reference/reference_ops.h
+  assert(output_axis < num_dimensions);
+  const int input_axis = perm[output_axis];
+  for (int i = 0; i < output_shape[output_axis]; ++i) {
+    // Update the input indices on `input_axis`.
+    input_indices->at(input_axis) = i;
+    // Write the value from `input_tensor` if it is the last axis or
+    // recurse into the next axis.
+    const bool is_last_axis = output_axis == num_dimensions - 1;
+    if (is_last_axis) {
+      new_values->push_back(input_tensor.getValue(*input_indices));
+    } else {
+      ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+                         output_axis + 1, input_indices, new_values);
+    }
+  }
+}
+
+}  // namespace
+
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2);
+  auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
+  auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
+  if (!input_tensor || !perm_tensor) return nullptr;
+
+  // Do not try to fold elements attr of a quant type because
+  // DenseElementsAttr does not support it.
+  if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
+    return nullptr;
+
+  assert(perm_tensor.getType().getRank() == 1);
+  const int num_dimensions = input_tensor.getType().getRank();
+  assert(perm_tensor.getType().getNumElements() == num_dimensions);
+
+  ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
+  auto output_type = getType().cast<ShapedType>();
+
+  SmallVector<int32_t, 4> perm;
+  SmallVector<int64_t, 4> output_shape;
+  for (int i = 0; i < num_dimensions; ++i) {
+    perm.push_back(perm_tensor.getValue({static_cast<uint64_t>(i)})
+                       .cast<IntegerAttr>()
+                       .getInt());
+    output_shape.push_back(input_shape[perm[i]]);
+
+    // Check that the derived output shape matches the static shape.
+    assert(!output_type.hasStaticShape() ||
+           output_type.getShape()[i] == output_shape[i]);
+  }
+
+  std::vector<Attribute> new_values;
+  new_values.reserve(input_tensor.getType().getNumElements());
+  std::vector<uint64_t> input_indices(num_dimensions);
+  ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+                     /*output_axis=*/0, &input_indices, &new_values);
+  auto result_type =
+      RankedTensorType::get(output_shape, output_type.getElementType());
+  return DenseElementsAttr::get(result_type, new_values);
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
 
+Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
+                                                      Attribute value,
+                                                      Type type, Location loc) {
+  // If this is an opaque elements attribute or the result type doesn't match
+  // the attribute type, then generate a tfl.pseudo_const.
+  if (value.isa<OpaqueElementsAttr>() ||
+      (value.isa<ElementsAttr>() && value.getType() != type))
+    return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
+  return nullptr;
+}
+
 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
index 5eac051..4782896 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h
@@ -27,6 +27,7 @@
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/ir/quantization_traits.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
@@ -36,6 +37,11 @@
 class TensorFlowLiteDialect : public Dialect {
  public:
   explicit TensorFlowLiteDialect(MLIRContext *context);
+
+  // Registered hook to materialize a constant operation from a given attribute
+  // value with the desired resultant type.
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
 };
 
 #define GET_OP_CLASSES
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 298f962..5a31d3f 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -134,9 +134,14 @@
     : TFL_QuantizedType<"Uniform",
                         [8, zero_pt, smantissa, sexp, -128, 127], 1>;
 
-// 8-bits quantized types. The definitions can be used to specify tensor types.
+// General uniform quantized types. The definitions can be used to specify
+// operand's tensor types.
 def TFL_QUI8 : TFL_QuantizedType<"Uniform", [8], 0>;
 def TFL_QI8 : TFL_QuantizedType<"Uniform", [8], 1>;
+def TFL_QUI16 : TFL_QuantizedType<"Uniform", [16], 0>;
+def TFL_QI16 : TFL_QuantizedType<"Uniform", [16], 1>;
+def TFL_QUI32 : TFL_QuantizedType<"Uniform", [32], 0>;
+def TFL_QI32 : TFL_QuantizedType<"Uniform", [32], 1>;
 
 //===----------------------------------------------------------------------===//
 // TensorType attribute definitions.
@@ -331,7 +336,7 @@
   let arguments = (
     ins AnyTensor:$input,
     AnyTensor:$filter,
-    AnyTensor:$bias,
+    TFL_TensorOfOrNone<[AnyType]>:$bias,
     I32Attr:$dilation_h_factor,
     I32Attr:$dilation_w_factor,
     TFL_AFAttr:$fused_activation_function,
@@ -361,6 +366,8 @@
   let arguments = (ins AnyTensor:$x);
 
   let results = (outs AnyTensor:$y);
+
+  let hasFolder = 1;
 }
 
 def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
@@ -406,6 +413,33 @@
   );
 }
 
+def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [NoSideEffect]> {
+  let summary = [{
+Computes the "logical or" of elements across dimensions of a tensor.
+  }];
+
+  let description = [{
+Reduces `input` along the dimensions given in `axis`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`axis`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input,
+    I32Tensor:$reduction_indices,
+
+    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
+  );
+
+  let results = (outs
+    I1Tensor:$output
+  );
+
+  let hasOptions = 1;
+  let customOption = "ReducerOptions";
+}
+
 def TFL_AveragePool2DOp:
     TFL_Op<"average_pool_2d", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
   let summary = "Average_pool_2d operator";
@@ -540,6 +574,8 @@
   let arguments = (ins ElementsAttr:$value);
 
   let results = (outs AnyTensor:$output);
+
+  let hasFolder = 1;
 }
 
 def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
@@ -554,6 +590,8 @@
   let arguments = (ins TFL_FpTensor:$x);
 
   let results = (outs TFL_FpTensor:$y);
+
+  let hasFolder = 1;
 }
 
 def TFL_DepthwiseConv2DOp :
@@ -577,9 +615,9 @@
   let summary = "Fully connected op";
 
   let arguments = (ins
-    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$input,
-    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$filter,
-    TFL_TensorOfOrNone<[F32, TFL_QI8, TFL_QUI8]>:$bias,
+    TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$input,
+    TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$filter,
+    TFL_TensorOfOrNone<[F32, TFL_QI32, TFL_QUI32]>:$bias,
 
     TFL_AFAttr:$fused_activation_function,
     TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format,
@@ -588,7 +626,7 @@
 
   // Depending on the weights format, this op can have one or two outputs.
   let results = (outs
-    Variadic<TensorOf<[F32, TFL_QI8, TFL_QUI8]>>:$output
+    Variadic<TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>>:$output
   );
 
   let hasOptions = 1;
@@ -634,14 +672,13 @@
     Gather slices from `params` into a Tensor with shape specified by `indices`.
   }];
 
-  // TODO: missing Uint8.
   let arguments = (ins
-    TensorOf<[F32, I8, I64, I32]>:$params,
+    TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params,
     TFL_I32OrI64Tensor:$indices
   );
 
   let results = (outs
-    TensorOf<[F32, I8, I64, I32]>:$output
+    TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
   );
 }
 
@@ -818,9 +855,9 @@
     Performs element-wise natural exponentiation operation on input.
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  let arguments = (ins TFL_FpTensor:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TFL_FpTensor:$y);
 
   let hasOptions = 0b1;
 }
@@ -962,14 +999,12 @@
   }];
 
   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TensorOf<[I32, I64, F32]>:$lhs,
+    TensorOf<[I32, I64, F32]>:$rhs);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TensorOf<[I32, I64, F32]>:$output);
 
-  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
-
-  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
+  let builders = [TFL_BroadcastableBinaryBuilder];
 }
 
 def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, TFL_NoQuantizableResult]> {
@@ -1096,6 +1131,24 @@
   let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
 }
 
+def TFL_LogisticOp: TFL_Op<"logistic", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    // zero_point = 0
+    // scale = 1. / (max_value + 1)
+    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
+    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
+  let summary = "Logistic operator";
+
+  let description = [{
+    Computes element-wise Sigmoid of input
+  }];
+
+  let arguments = (ins TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$x);
+
+  let results = (outs TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$y);
+}
+
 def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Natural logarithm operator";
 
@@ -1106,6 +1159,8 @@
   let arguments = (ins AnyTensor:$x);
 
   let results = (outs AnyTensor:$y);
+
+  let hasFolder = 1;
 }
 
 // TODO(b/130643170): Adds some constraint for the input/output element types.
@@ -1182,11 +1237,13 @@
   }];
 
   let arguments = (
-    ins TFL_FpOrI32OrI64Tensor:$lhs,
-    TFL_FpOrI32OrI64Tensor:$rhs
+    ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
+    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
   );
 
-  let results = (outs TFL_FpOrI32OrI64Tensor:$max);
+  let results = (outs
+    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$max
+  );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
 
@@ -1245,6 +1302,22 @@
   let hasOptions = 1;
 }
 
+def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Round operator";
+
+  let description = [{
+Rounds the values of a tensor to the nearest integer, element-wise.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F32]>:$x
+  );
+
+  let results = (outs
+    TensorOf<[F32]>:$y
+  );
+}
+
 def TFL_SliceOp : TFL_Op<"slice", [
     NoSideEffect, TFL_SameOperandsAndResultsScale]> {
   let summary = "Return a slice from 'input'.";
@@ -1278,7 +1351,7 @@
 
   let arguments = (ins
     AnyTensor:$input,
-    TFL_I32OrI64Tensor:$axes,
+    I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
@@ -1297,7 +1370,7 @@
 
   let arguments = (ins
     AnyTensor:$input,
-    TFL_I32OrI64Tensor:$axes,
+    I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
@@ -1316,7 +1389,7 @@
 
   let arguments = (ins
     AnyTensor:$input,
-    TFL_I32OrI64Tensor:$axes,
+    I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
@@ -1335,7 +1408,7 @@
 
   let arguments = (ins
     TensorOf<[F32, I8, I32, I64]>:$input,
-    TFL_I32OrI64Tensor:$axes,
+    I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
@@ -1353,11 +1426,13 @@
   }];
 
   let arguments = (
-    ins TFL_FpOrI32OrI64Tensor:$lhs,
-    TFL_FpOrI32OrI64Tensor:$rhs
+    ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
+    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
   );
 
-  let results = (outs TFL_FpOrI32OrI64Tensor:$min);
+  let results = (outs
+    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$min
+  );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
 
@@ -1562,9 +1637,13 @@
   let arguments = (ins AnyTensor:$input);
 
   let results = (outs TFL_IntTensor:$output);
+
+  let hasFolder = 1;
 }
 
-def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
+                                SameOperandsAndResultShape,
+                                TFL_SameOperandsAndResultsScale]> {
   let summary = "Relu operator";
 
   let description = [{
@@ -1577,7 +1656,9 @@
   let results = (outs AnyTensor:$y);
 }
 
-def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
+                                  SameOperandsAndResultShape,
+                                  TFL_SameOperandsAndResultsScale]> {
   let summary = "Relu6 operator";
 
   let description = [{
@@ -1628,9 +1709,8 @@
 `seq_dim` reversed.
   }];
 
-  // Missing Uint8.
   let arguments = (ins
-    TensorOf<[F32, I16, I32, I64]>:$input,
+    TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input,
     TFL_I32OrI64Tensor:$seq_lengths,
 
     I32Attr:$seq_dim,
@@ -1638,7 +1718,7 @@
   );
 
   let results = (outs
-    TensorOf<[F32, I16, I32, I64]>:$output
+    TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output
   );
 
   let hasOptions = 1;
@@ -1654,6 +1734,8 @@
   let arguments = (ins AnyTensor:$x);
 
   let results = (outs AnyTensor:$y);
+
+  let hasFolder = 1;
 }
 
 def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> {
@@ -1674,26 +1756,12 @@
   let hasOptions = 1;
 }
 
-def TFL_LogisticOp: TFL_Op<"logistic", [
-    NoSideEffect,
-    SameOperandsAndResultType,
-    // zero_point = 0
-    // scale = 1. / (max_value + 1)
-    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
-    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
-  let summary = "Logistic operator";
-
-  let description = [{
-    Computes element-wise Sigmoid of input
-  }];
-
-  let arguments = (ins TFL_FpTensor:$x);
-
-  let results = (outs TFL_FpTensor:$y);
-}
-
 // TODO(jpienaar): Flesh this out.
-def TFL_RangeOp: TFL_Op<"range", [NoSideEffect]> {
+def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>,
+    TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>,
+    PredOpTrait<"operands and output must have same element type",
+      And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>,
+           TCresVTEtIsSameAsOp<0, 2>]>>]> {
   let summary = "Range operator";
 
   let description = [{
@@ -1707,6 +1775,8 @@
     AnyTensor:$delta);
 
   let results = (outs AnyTensor:$result);
+
+  let hasFolder = 1;
 }
 
 def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
@@ -1760,9 +1830,8 @@
 
   let arguments = (ins
     TFL_BoolTensor:$condition,
-    // TODO: Missing uint8.
-    TensorOf<[F32, I1, I8, I16, I32, I64]>:$x,
-    TensorOf<[F32, I1, I8, I16, I32, I64]>:$y);
+    TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
+    TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
   let results = (outs AnyTensor:$output);
 
   // TODO(jpienaar): autogenerate this.
@@ -1787,6 +1856,8 @@
   let arguments = (ins TFL_FpTensor:$x);
 
   let results = (outs TFL_FpTensor:$y);
+
+  let hasFolder = 1;
 }
 
 // TODO(b/130643170): Adds some constraint for the input/output element types.
@@ -1822,9 +1893,11 @@
     Computes element-wise Square root of input
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  let arguments = (ins TFL_FpTensor:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TFL_FpTensor:$y);
+
+  let hasFolder = 1;
 }
 
 def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> {
@@ -1834,11 +1907,13 @@
     Computes element-wise Square of input
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  let arguments = (ins TFL_FpTensor:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TFL_FpTensor:$y);
 
   let hasOptions = 0b1;
+
+  let hasFolder = 1;
 }
 
 def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
@@ -1902,10 +1977,9 @@
     Computes element-wise Hyperbolic tangent of input
   }];
 
-  // TODO(haoliang): missing Uint8.
-  let arguments = (ins TensorOf<[F32, I16, I8]>:$x);
+  let arguments = (ins TensorOf<[F32, I16, I8, TFL_Uint8]>:$x);
 
-  let results = (outs TensorOf<[F32, I16, I8]>:$y);
+  let results = (outs TensorOf<[F32, I16, I8, TFL_Uint8]>:$y);
 }
 
 def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@@ -1922,9 +1996,11 @@
    For example, tiling [a b c d] by [2] produces [a b c d a b c d].
   }];
 
-  let arguments = (ins AnyTensor:$input, TFL_I32OrI64Tensor:$multiples);
+  let arguments = (ins
+    TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
+    TFL_I32OrI64Tensor:$multiples);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
 
   let hasOptions = 0;
 }
@@ -1944,8 +2020,7 @@
   }];
 
   let arguments = (ins
-    // TODO: Missing uint8
-    TensorOf<[F32, I8, I32, I64]>:$input,
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
     I32Tensor:$k);
 
   let results = (outs
@@ -1963,10 +2038,12 @@
 // dimensions.
 def TFL_TransposeOp : TFL_Op<"transpose",
   [NoSideEffect,
+   TFL_OperandHasRank<1,1>,
    // TODO(jpienaar): these are only true dynamically, change so that it works
    // with unknowns.
-   // TFL_OperandHasRank<1,1>,
    // TFL_OperandRankEquals1DimOfOperand<0, 1>,
+   PredOpTrait<"input and output must have same element type",
+   TCresVTEtIsSameAsOp<0, 0>>,
    TFL_SameOperandsAndResultsScale]> {
   let summary = "Transpose operator";
 
@@ -1976,12 +2053,14 @@
 
   let arguments = (
     ins AnyTensor:$x,
-    AnyTensor:$perm
+    TensorOf<[I32]>:$perm
   );
 
   let results = (outs
     AnyTensor:$y
   );
+
+  let hasFolder = 1;
 }
 
 def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
@@ -2080,7 +2159,34 @@
   );
 }
 
-def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
+def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
+    NoSideEffect,
+    TFL_SameOperandsAndResultsScale,
+    PredOpTrait<"input and output must have same element type",
+      TCresVTEtIsSameAsOp<0, 0>>
+  ]> {
+  let summary = "SpaceToDepth operator";
+
+  let description = [{
+    Rearranges blocks of spatial data, into depth. More specifically,
+    this op outputs a copy of the input tensor where values from the `height`
+    and `width` dimensions are moved to the `depth` dimension.
+    `block_size` indicates the input block size.
+   }];
+
+  let arguments = (ins
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$input,
+    I32Attr:$block_size
+  );
+
+  let results = (outs
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$output
+  );
+
+  let hasOptions = 1;
+}
+
+def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
   let summary = "Splits a tensor into `num_split` tensors along one dimension.";
 
   let description = [{
@@ -2091,18 +2197,18 @@
 
   let arguments = (ins
     I32Tensor:$split_dim,
-    TensorOf<[F32, I16, I32, I64]>:$value,
+    TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
     I32Attr:$num_splits
   );
 
   let results = (outs
-    Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
+    Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
   );
 
   let hasOptions = 1;
 }
 
-def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
+def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
   let summary = "Splits a tensor into `num_split` tensors along one dimension.";
 
   let description = [{
@@ -2112,14 +2218,14 @@
   }];
 
   let arguments = (ins
-    TensorOf<[F32, I16, I32, I64]>:$value,
+    TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
     I32Tensor:$size_splits,
     I32Tensor:$split_dim,
     I32Attr:$num_splits
   );
 
   let results = (outs
-    Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
+    Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
   );
 
   let hasOptions = 1;
@@ -2146,6 +2252,64 @@
   let hasOptions = 1;
 }
 
+def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
+                                [NoSideEffect]> {
+  let summary = "ResizeNearestNeighbor Op";
+
+  let description = [{
+    Resize `images` to `size` using nearest neighbor interpolation.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F32, I8, TFL_Uint8]>:$input,
+    TensorOf<[I32]>:$size,
+    BoolAttr:$align_corners
+  );
+
+  let results = (outs
+    TensorOf<[F32, I8, TFL_Uint8]>:$output
+  );
+
+  let hasOptions = 1;
+}
+
+def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [NoSideEffect]> {
+  let summary = "Converts a sparse representation into a dense tensor.";
+
+  let description = [{
+Builds an array `dense` with shape `output_shape` such that
+
+```
+# If sparse_indices is scalar
+dense[i] = (i == sparse_indices ? sparse_values : default_value)
+
+# If sparse_indices is a vector, then for each i
+dense[sparse_indices[i]] = sparse_values[i]
+
+# If sparse_indices is an n by d matrix, then for each i in [0, n)
+dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
+```
+
+All other values in `dense` are set to `default_value`.  If `sparse_values` is a
+scalar, all sparse indices are set to this single value.
+
+Indices should be sorted in lexicographic order, and indices must not
+contain any repeats. If `validate_indices` is true, these properties
+are checked during execution.
+  }];
+
+  let arguments = (ins
+    TFL_I32OrI64Tensor:$sparse_indices,
+    TFL_I32OrI64Tensor:$output_shape,
+    TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values,
+    TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value
+  );
+
+  let results = (outs
+    TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense
+  );
+}
+
 def TFL_StridedSliceOp: TFL_Op<"strided_slice",
   [
     NoSideEffect,
@@ -2470,7 +2634,7 @@
   let verifier = [{ return Verify(*this); }];
 }
 
-// UnidirectionalSequenceLstm op .
+// UnidirectionalSequenceLstm op.
 // TODO(ashwinm): Add constraint to validate the combination of operands
 // that are valid for hybrid vs fully quantized vs float only semantics
 def TFL_UnidirectionalSequenceLSTMOp :
@@ -2550,4 +2714,80 @@
   let verifier = [{ return Verify(*this); }];
 }
 
+def RnnResultConstraint : PredOpTrait<
+  "the input and result tensor elemental types must be same",
+  TCresVTEtIsSameAsOp<0, 0>>;
+
+// UnidirectionalSequenceRNN op.
+def TFL_UnidirectionalSequenceRNNOp :
+  TFL_Op<"unidirectional_sequence_rnn",
+         [RnnResultConstraint, StatefulOperands<[4]>]> {
+
+  let summary = "Unidirectional sequence rnn operator";
+
+  let description = [{
+    A recurrent neural network specified by an RNN cell. This Op takes in input
+    in a format {batch_size, seq_len, input_size} or
+    {seq_len, batch_size, input_size} if it's time-majored.
+
+    It implements the following operation for
+    each element in the sequence s = 1...sequence_length:
+      outputs[s] = state = activation(RNNOp(inputs[s]))
+
+    where RNNOp is RNNOp TF Lite Op and the “activation” is the function passed
+    as the “fused_activation_function” argument (if not “NONE”).
+  }];
+
+  let arguments = (
+    ins TensorOf<[F32, I8]>:$input,
+
+    // Weights
+    TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
+
+    // Recurrent weights
+    TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
+
+    // Bias
+    TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
+
+    // Hidden state.
+    TFL_StatefulTensor:$hidden_state,
+
+    // Attributes
+    BoolAttr:$time_major,
+    TFL_AFAttr:$fused_activation_function
+  );
+
+  let results = (outs AnyTensor:$output);
+
+  let hasOptions = 1;
+
+  let customOption = "SequenceRNNOptions";
+
+  let verifier = [{ return Verify(*this); }];
+}
+
+def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
+  let summary = "Returns locations of nonzero / true values in a tensor.";
+
+  let description = [{
+This operation returns the coordinates of true elements in `condition`. The
+coordinates are returned in a 2-D tensor where the first dimension (rows)
+represents the number of true elements, and the second dimension (columns)
+represents the coordinates of the true elements. Keep in mind, the shape of
+the output tensor can vary depending on how many true values there are in
+`condition`. Indices are output in row-major order.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input
+  );
+
+  // TODO(haoliang): TF Lite only support I32 output right now, need to fix
+  // either here or in the kernel.
+  let results = (outs
+    TFL_I32OrI64Tensor:$index
+  );
+}
+
 #endif // TFL_OPS
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
index 97fc87a..2e119f4 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
@@ -18,108 +18,14 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
 
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 
 namespace mlir {
 namespace OpTrait {
 namespace TFL {
 
-using QuantizedType = mlir::quant::QuantizedType;
-using UniformQuantizedType = mlir::quant::UniformQuantizedType;
-
-// The base class that all the quantization related OpTrait implements.
-template <typename ConcreteType, template <typename> class TraitType>
-struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
-  static bool IsBias(int index) { return false; }
-  static bool IsQuantizable() { return true; }
-};
-
-// This class provides the API for TFL ops that requires same input and output
-// scale as the quantization results. This is used as a trait like this:
-//
-//   class TransposeOp
-//       : public Op<TransposeOp, OpTrait::TFL::SameOperandsAndResultsScale> {
-//
-template <typename ConcreteType>
-class SameOperandsAndResultsScale
-    : public QuantizationSpecTraitBase<ConcreteType,
-                                       SameOperandsAndResultsScale> {};
-
-// This class provides the API for TFL ops that has a fixed output value range.
-// This is used as a trait like this:
-//
-//   class SoftmaxOp
-//       : public Op<SoftmaxOp,
-//           OpTrait::TFL::FixedResultUniformScale<
-//               8, -128, 390625, -8, 0, 255, false>::Impl> {
-//
-// TODO(fengliuai): create a better way to epxress floating point scale in the
-// template argument list.
-template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
-          int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
-class FixedResultUniformScale {
- public:
-  template <typename ConcreteType>
-  class Impl
-      : public QuantizationSpecTraitBase<
-            ConcreteType, FixedResultUniformScale<
-                              BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
-                              StorageTypeMin, StorageTypeMax, Sign>::Impl> {
-   public:
-    QuantizedType GetResultQuantizedType(int index) {
-      auto op = this->getOperation();
-      auto result_type =
-          op->getResult(index)->getType().template cast<TensorType>();
-      Builder builder(op->getContext());
-      IntegerType storage_type = builder.getIntegerType(BitWidth);
-      const double scale = static_cast<double>(ScaleMantissa) *
-                           ::pow(10.0, static_cast<double>(ScaleExp));
-      return UniformQuantizedType::getChecked(
-          Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
-          StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
-    }
-  };
-};
-
-// This class provides the API for TFL ops that has input as bias. This is used
-// as a trait like this:
-//
-//   class Conv2DOp
-//       : public Op<Conv2DOp, OpTrait::TFL::AccumulatorScale<2, 0, 1>::Impl> {
-//
-// TODO(fengliuai): supports a configurable accumulator bit width.
-template <int Bias, int... Operands>
-class AccumulatorUniformScale {
- public:
-  template <typename ConcreteType>
-  class Impl
-      : public QuantizationSpecTraitBase<
-            ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
-   public:
-    // Whether the index-th operand is a bias.
-    static bool IsBias(int index) { return index == Bias; }
-
-    // Returns the indexes of all the non-bias operands.
-    static std::vector<int> GetAllNonBiasOperands() {
-      return std::vector<int>({Operands...});
-    }
-  };
-};
-
-// This class provides the API for TFL ops that shouldn't be quantized. This is
-// used as a trait like this:
-//
-//   class LessOp : public Op<LessOp, OpTrait::TFL::NoQuantizableResult> {
-//
-template <typename ConcreteType>
-class NoQuantizableResult
-    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
- public:
-  static bool IsQuantizable() { return false; }
-};
-
 // The trait to specify that the specified operands of the TFL op are stateful.
 // This is used as a trait like this:
 //
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 2a60715..ef8b669 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -18,7 +18,7 @@
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -129,13 +129,14 @@
   bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
   bool emit_custom_ops = toco_flags.allow_custom_ops();
   specs.prune_unused_nodes = true;
+  specs.convert_legacy_fed_inputs = true;
   WarningUnusedFlags(model_flags, toco_flags);
 
   bool emit_quant_adaptor_ops = false;
   bool lower_tensor_list_ops = true;
   TF_ASSIGN_OR_RETURN(
       auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
-  return ConvertTFControlFlowToTFLOrFlatbuffer(
+  return ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
       emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
       lower_tensor_list_ops, result);
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
new file mode 100644
index 0000000..196b7e3
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -0,0 +1,702 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/Value.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace mlir {
+namespace TFL {
+namespace {
+static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
+
+// The state for each op result during the quantization parameters propagation.
+struct QuantState {
+  // Quantization parameters propagated to an op result.
+  QuantParams params;
+  // A flag indicates this state (the params) shouldn't be changed after it is
+  // initialized. This flag will be set to true if the quantization parameters
+  // are from the quantization-aware training.
+  const bool immutable;
+
+  bool IsEmpty() { return EmptyParams(params); }
+};
+
+// The state for rescaling the propagated quantization parameters. This can be
+// on the input side to satisfy the constraint of previous operation, or on the
+// output side to satisfy the constraint of the next operation.
+struct RequantizeState {
+  // Sometimes, we have to "requantize" the quantization result to satisfy all
+  // the constraints. The "requantize" can happen either on the input or output
+  // of the quantization result.
+  enum RequantizePosition {
+    NO_REQUANTIZE,
+    ON_INPUT,
+    ON_OUTPUT
+  } pos = NO_REQUANTIZE;
+
+  // Quantization parameters will be used to add the requantize ops.
+  QuantParams params;
+};
+
+// This is a worklist-driven driver for propagating quantization parameters
+// across operations.
+//
+// The initial quantization parameters are extracted from the quantized type
+// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
+// parameters are marked as immutable because they are from quantization-aware
+// training.
+//
+// The algorithm traverses each op and sets the quantization parameters of its
+// operands and results, according to its quantization specification, and then
+// adds the operands and results to the worklist. If there are any conflicts
+// (for example, there are quantization parameters propagated from the previous
+// iteration), this process stops if the existing parameters are the immutable,
+// or adding `requantize` op to resolve the conflicts.
+//
+// After the algorithm is converaged, pairs of tfl.quantize and tfl.dequantize
+// are inserted to the right position to materialize the propagation and
+// requantize results.
+//
+class QuantizationDriver {
+ public:
+  explicit QuantizationDriver(FuncOp fn, bool is_signed,
+                              OpQuantSpecGetter op_quant_spec_getter)
+      : fn_(fn),
+        builder_(fn.getBody()),
+        is_signed_(is_signed),
+        op_quant_spec_getter_(op_quant_spec_getter) {}
+
+  // The entry point of the quantization parameters propagation.
+  void Run();
+
+ private:
+  // This is used to identify an operand or result of an op. The second element
+  // of this pair is the index of the operand or result.
+  using OpValue = std::pair<mlir::Operation *, int>;
+
+  // Sets up the states for all the op results in the function.
+  void Initialize();
+
+  // Propagates the quantization parameters across all the ops.
+  bool PropagateParams();
+
+  // Inserts the Quantize and Dequantize ops according to the propagation
+  // result.
+  void Finalize();
+
+  // Whether the constant is used as a bias input of another op. Here we assume
+  // bias is used immediately by the user. This assumption is always correct
+  // after constant folding.
+  bool UsedAsBias(ConstantOp cst) {
+    Value *value = cst.getResult();
+    for (auto &use : value->getUses()) {
+      auto biases = GetQuantSpec(use.getOwner())->biases_params;
+      if (biases.find(use.getOperandNumber()) != biases.end()) return true;
+    }
+    return false;
+  }
+
+  // Returns all the related quantization constraints of the op.
+  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
+
+  // Whether Quantization parameters have been propagated to the results of this
+  // op.
+  bool IsQuantized(Operation *op);
+
+  // Adds all the users of index-th result of op to the work list.
+  void AddUserToList(Operation *op, int index) {
+    for (auto *user : op->getResult(index)->getUsers()) {
+      work_list_.push_back(user);
+    }
+  }
+
+  // Adds the defining op of index-th operand of op to the work list.
+  void AddOperandToList(Operation *op, int index) {
+    if (auto *inst = op->getOperand(index)->getDefiningOp()) {
+      work_list_.push_back(inst);
+    }
+  }
+
+  // Returns the quantization params for the bias input from the non-bias
+  // operands which have their indexes in the `non_biases` vector. The returned
+  // parameters are calculated by `func`.
+  QuantParams GetBiasParams(Operation *op, int bias,
+                            const std::vector<int> &non_biases,
+                            AccumulatorScaleFunc func);
+
+  // Sets the quantization parameters of the result to a fixed value. If any
+  // quantization parameters have been propagated, a `requantize` will happen on
+  // the input of propagated quantization.
+  bool SetResultParams(Operation *op, int index, QuantParams params);
+
+  // Sets the quantization parameters of the operand to a fixed value. If any
+  // quantization parameters have been propagated, a `requantize` will happen on
+  // the output of propagated quantization.
+  bool SetOperandParams(Operation *op, int index, QuantParams params);
+
+  // Sets the quantization parameters of the constant result according to its
+  // content.
+  bool SetConstantResultParams(Operation *op);
+
+  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
+  // of the op.
+  void QuantizeOpResult(Operation *op, int index, QuantParams params);
+
+  void QuantizeArg(BlockArgument *arg, QuantParams params);
+
+  // Inserts the Quantize and Dequantize ops to quantize the value and returns
+  // the Quantize op.
+  void QuantizeValue(Value *value, QuantParams params, Location loc);
+
+  // Inserts the Quantize ops for requantizing the index-th result of the op.
+  void RequantizeOpResult(Operation *op, int index, RequantizeState *state);
+
+  void RequantizeArg(BlockArgument *arg, RequantizeState *state);
+
+  // Inserts the Quantize and Dequantize ops to quantize the value and returns
+  // the Quantize op.
+  void RequantizeValue(Value *value, RequantizeState *state, Location loc);
+
+  // A heuristic to get the quantization parameter satisfies the same scale
+  // constraints for the op. Returns an empty option if this quantization
+  // parameter doesn't exist.
+  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
+
+  // Returns the state of the index-th operand of the op.
+  QuantState &GetOperandQuantState(Operation *op, int index) {
+    return states_[operand_states_[{op, index}]];
+  }
+
+  // Returns the state of the index-th result of the op.
+  QuantState &GetResultQuantState(Operation *op, int index) {
+    return states_[result_states_[{op, index}]];
+  }
+
+  QuantState &GetArgQuantState(BlockArgument *arg) {
+    return states_[arg_states_[arg]];
+  }
+
+  // Returns the state of the index-th operand of the op.
+  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
+    return rescale_states_[operand_states_[{op, index}]];
+  }
+
+  // Returns the state of the index-th result of the op.
+  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
+    return rescale_states_[result_states_[{op, index}]];
+  }
+
+  RequantizeState &GetArgRequantizeState(BlockArgument *arg) {
+    return rescale_states_[arg_states_[arg]];
+  }
+
+  // Uses the type of `val` to set the initial state of the index-th result if
+  // `as_result` is true or index-th operand if `as_result` is false. The state
+  // is immutable if the type is a quantized type. Returns the index of this
+  // new state in the state vector.
+  int InitializeState(Operation *op, int index, Value *val, bool as_result);
+
+  // Sets the state of the index-th operand of the op. If this operand is
+  // cached, uses the cached result without creating new entry in the state
+  // vector. Otherwise, allocate a new entry in the state vector.
+  void InitializeOperandState(Operation *op, int index, Value *in,
+                              llvm::DenseMap<Value *, int> *cache,
+                              bool is_argument) {
+    auto cached = cache->insert({in, 0});
+    if (!cached.second) {
+      operand_states_.insert({{op, index}, cached.first->second});
+      return;
+    }
+    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
+    if (is_argument) {
+      auto *arg = llvm::cast<BlockArgument>(in);
+      arg_states_[arg] = cached.first->second;
+      args_.push_back(arg);
+    }
+  }
+
+  // Sets the state of the index-th result of the op. If this result is cached,
+  // uses the cached result without creating new entry in the state vector.
+  // Otherwise, allocate a new entry in the state vector.
+  void InitializeResultState(Operation *op, int index, Value *res,
+                             llvm::DenseMap<Value *, int> *cache) {
+    auto cached = cache->insert({res, 0});
+    if (!cached.second) {
+      result_states_.insert({{op, index}, cached.first->second});
+      return;
+    }
+    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
+  }
+
+  FuncOp fn_;
+  OpBuilder builder_;
+  bool is_signed_;
+
+  // All the ops needs to propagate the quantization parameters to.
+  std::vector<Operation *> work_list_;
+  std::unordered_set<Operation *> quantized_;
+
+  // The vector contains all the quantization parameters propagated from the
+  // defining operations of the value, or from the quantization aware training.
+  std::vector<QuantState> states_;
+
+  // The map contains all the quantization parameters which are required to
+  // satisfy the same operands and results constraint. The keys of this map are
+  // the values from `operand_states_` and `result_state_`.
+  std::unordered_map<int, RequantizeState> rescale_states_;
+
+  // Maps of indexes to the propagation state vector from the ops operands,
+  // results and arguments.
+  llvm::DenseMap<OpValue, int> operand_states_;
+  llvm::DenseMap<OpValue, int> result_states_;
+  llvm::DenseMap<BlockArgument *, int> arg_states_;
+
+  // This vector is to preserve the arguments order, so the newly inserted
+  // quantized ops for the arguments are deterministically ordered.
+  llvm::SmallVector<BlockArgument *, 4> args_;
+
+  OpQuantSpecGetter op_quant_spec_getter_;
+};
+}  // namespace
+
+std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
+  return op_quant_spec_getter_(op);
+}
+
+bool QuantizationDriver::IsQuantized(Operation *op) {
+  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
+    if (GetResultQuantState(op, i).IsEmpty()) return false;
+  }
+  return true;
+}
+
+int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
+                                        bool as_result) {
+  QuantParams params =
+      quant::QuantizedType::getQuantizedElementType(val->getType());
+  bool immutable = !EmptyParams(params);
+  int next_state_index = states_.size();
+  states_.push_back({params, immutable});
+  if (as_result)
+    result_states_.insert({{op, index}, next_state_index});
+  else
+    operand_states_.insert({{op, index}, next_state_index});
+
+  return next_state_index;
+}
+
+bool QuantizationDriver::SetConstantResultParams(Operation *op) {
+  ElementsAttr attr;
+  Value *res = op->getResult(0);
+  if (!matchPattern(res, m_Constant(&attr))) {
+    return false;
+  }
+  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
+  auto final_type =
+      GetUniformQuantizedTypeForElementsAttr(attr, /*storage_type_width=*/8,
+                                             is_signed_, /*narrow_range_=*/true)
+          .dyn_cast_or_null<quant::QuantizedType>();
+  if (!final_type) return false;
+  return SetResultParams(op, 0, final_type);
+}
+
+bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
+                                         QuantParams params) {
+  auto &state = GetResultQuantState(op, res_index);
+  if (state.params == params) {
+    return false;
+  }
+  if (!state.IsEmpty()) {
+    auto &rescale = GetResultRequantizeState(op, res_index);
+    rescale.params = params;
+    rescale.pos = RequantizeState::ON_INPUT;
+    return true;
+  }
+  state.params = params;
+  AddUserToList(op, res_index);
+  return true;
+}
+
+QuantParams QuantizationDriver::GetBiasParams(
+    Operation *op, int bias, const std::vector<int> &non_biases,
+    AccumulatorScaleFunc func) {
+  auto &bias_state = GetOperandQuantState(op, bias);
+  if (!bias_state.IsEmpty()) {
+    return bias_state.params;
+  }
+  std::vector<QuantParams> op_types;
+  op_types.reserve(non_biases.size());
+  for (auto non_bias : non_biases) {
+    auto &non_bias_type = GetOperandQuantState(op, non_bias);
+    op_types.push_back(non_bias_type.params);
+  }
+  if (op_types.empty()) return {};
+  return func(op_types);
+}
+
+bool QuantizationDriver::SetOperandParams(Operation *op, int index,
+                                          QuantParams params) {
+  auto &state = GetOperandQuantState(op, index);
+  if (state.params == params) {
+    return false;
+  }
+
+  if (!state.IsEmpty()) {
+    auto &rescale = GetOperandRequantizeState(op, index);
+    rescale.params = params;
+    rescale.pos = RequantizeState::ON_OUTPUT;
+    return true;
+  }
+
+  state.params = params;
+  AddOperandToList(op, index);
+  return true;
+}
+
+void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
+                                          QuantParams params) {
+  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
+  Value *original_result = op->getResult(index);
+  QuantizeValue(original_result, params, op->getLoc());
+}
+
+void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) {
+  builder_.setInsertionPointToStart(arg->getOwner());
+  QuantizeValue(arg, params, builder_.getUnknownLoc());
+}
+
+void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
+                                       Location loc) {
+  Type expressed_type = value->getType();
+  Type new_type = params.castFromExpressedType(expressed_type);
+  // This value isn't an expressed type (float), skip.
+  if (!new_type) return;
+
+  TypeAttr type_attr = builder_.getTypeAttr(new_type);
+  auto quantize =
+      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
+  auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
+                                                       quantize.output());
+  // `original_result` has a use to `quantize`, so this will replace that use
+  // by the result of `dequantize`. Remember to reset that use afterwards
+  value->replaceAllUsesWith(dequantize);
+  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
+}
+
+void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
+                                            RequantizeState *state) {
+  if (state->pos == RequantizeState::NO_REQUANTIZE) return;
+  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
+  Value *value = op->getResult(index);
+  if (state->pos == RequantizeState::ON_OUTPUT) {
+    Operation *op = value->getUses().begin().getUser();  // `quantize` op
+    // The requantize op is inserted between `quantize` and `dequantize` ops.
+    value = op->getResult(0);
+    builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
+  }
+  RequantizeValue(value, state, op->getLoc());
+}
+
+void QuantizationDriver::RequantizeArg(BlockArgument *arg,
+                                       RequantizeState *state) {
+  Value *value = arg;
+  builder_.setInsertionPointToStart(arg->getOwner());
+  if (value->hasOneUse()) {
+    auto user = value->use_begin().getUser();
+    if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
+      value = q.output();
+      builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
+    }
+  }
+  RequantizeValue(value, state, builder_.getUnknownLoc());
+}
+
+void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
+                                         Location loc) {
+  Type new_type;
+  if (state->pos == RequantizeState::ON_INPUT) {
+    Type expressed_type = value->getType();
+    // The value needs to be requantized. A Quantize op will be created to use
+    // it as the operand and replace its uses.
+    new_type = state->params.castFromExpressedType(expressed_type);
+  } else {
+    Type expressed_type =
+        quant::QuantizedType::castToExpressedType(value->getType());
+    if (!expressed_type) return;
+
+    // The value needs to be requantized. A Quantize op will be created to use
+    // it as the operand and replace its uses.
+    new_type = state->params.castFromExpressedType(expressed_type);
+  }
+  // This value isn't an expressed type (float), skip.
+  if (!new_type) return;
+
+  TypeAttr type_attr = builder_.getTypeAttr(new_type);
+  auto requantize_op =
+      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
+  value->replaceAllUsesWith(requantize_op);
+  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
+}
+
+// A heuristic to get quantization parameters satisfies the same scale
+// constraints:
+// - If there are immutable states,
+//   - use the single input, or,
+//   - use the single output, or,
+//   - use the first one in the collection,
+// - use the single input if it is ready, or,
+// - use the single output if it is ready, or,
+// - use use the first ready one in the collection.
+QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
+    Operation *op) {
+  // Two vector to collect Non-empty operands and results states.
+  std::vector<QuantState *> mutable_states, immutable_states;
+  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+    auto &state = GetOperandQuantState(op, i);
+    if (state.immutable) {
+      immutable_states.push_back(&state);
+    } else if (!state.IsEmpty()) {
+      mutable_states.push_back(&state);
+    }
+  }
+
+  int immutable_operands_num = immutable_states.size();
+  int mutable_operands_num = mutable_states.size();
+  // Use the operand's state if it is immutable and it is the only one operand.
+  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
+    return immutable_states.front()->params;
+  }
+
+  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
+    auto &state = GetResultQuantState(op, i);
+    if (state.immutable) {
+      immutable_states.push_back(&state);
+    } else if (!state.IsEmpty()) {
+      mutable_states.push_back(&state);
+    }
+  }
+
+  int immutable_results_num = immutable_states.size() - immutable_operands_num;
+  int mutable_results_num = mutable_states.size() - mutable_operands_num;
+  // Use the result's state if it is immutable and it is the only one result.
+  if (op->getNumResults() == 1 && immutable_results_num == 1) {
+    return immutable_states.back()->params;
+  }
+
+  // Use the first immutable state to quantize the rest operands and results.
+  if (!immutable_states.empty()) return immutable_states.front()->params;
+
+  // If there are no immutable states, use the operand's state if it is the only
+  // one operand and has parameters propagated.
+  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
+    return mutable_states.front()->params;
+  }
+
+  // If there are no immutable states, use the result's state if it is the only
+  // one result and has parameters propagated.
+  if (op->getNumResults() == 1 && mutable_results_num == 1) {
+    return mutable_states.back()->params;
+  }
+
+  // Use the first propagated state to quantize the rest operands and results.
+  if (!mutable_states.empty()) return mutable_states.front()->params;
+
+  // None operands/results have parameters propagated, skip this node for now.
+  return {};
+}
+
+// This method scans the operations in the function to setup the initial
+// states for quantization parameter propagation.
+// TODO(fengliuai): This algorithm assumes there are only one pair of
+// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
+// check should be applied.
+void QuantizationDriver::Initialize() {
+  llvm::DenseMap<Value *, int> value_to_state;
+
+  fn_.walk([&](Operation *op) {
+    if (op->isKnownTerminator()) return;
+    if (!GetQuantSpec(op)->is_quantizable) return;
+    work_list_.push_back(op);
+
+    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+      auto *operand = op->getOperand(i);
+      bool is_argument = true;
+      if (auto *inst = operand->getDefiningOp()) {
+        // If the operand comes from a tfl.dequantize op, we use the quantized
+        // input of this tfl.dequantize op to set the state.
+        if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
+          operand = dq.input();
+        }
+        is_argument = false;
+      }
+      InitializeOperandState(op, i, operand, &value_to_state, is_argument);
+    }
+
+    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
+      auto *result = op->getResult(res);
+      // If the result has been quantized, it should only be used by a
+      // tfl.quantize op. For this case, we uses the quantized result to create
+      // the state and mark it immutable.
+      if (result->hasOneUse()) {
+        auto user = result->use_begin().getUser();
+        if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
+          result = q.output();
+        }
+      }
+      InitializeResultState(op, res, result, &value_to_state);
+    }
+  });
+}
+
+bool QuantizationDriver::PropagateParams() {
+  // TODO(fengliuai): uses a typed indicator instead of a bool value.
+  bool changed = false;
+  while (!work_list_.empty()) {
+    Operation *op = work_list_.back();
+    work_list_.pop_back();
+
+    // This op has been quantized, so we should not consider it again.
+    if (quantized_.find(op) != quantized_.end()) continue;
+    quantized_.insert(op);
+
+    auto spec = GetQuantSpec(op);
+
+    // If the op has no quantizable result, the quantization parameters will not
+    // be propagated to the results.
+    if (!spec->is_quantizable) continue;
+
+    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
+      // This constant is used as a bias in another op, then the quantization
+      // parameters are determined by that op.
+      if (UsedAsBias(cst) || IsQuantized(op)) continue;
+
+      // The quantization parameters are determined by the content of the
+      // constant.
+      changed |= SetConstantResultParams(op);
+      continue;
+    }
+
+    if (spec->requires_same_scale) {
+      auto params = GetQuantParamsForSameScaleConstraint(op);
+      // The quantization parameters haven't been propagated to any operands or
+      // results. Skip this node for now.
+      if (!params) {
+        quantized_.erase(op);
+        continue;
+      }
+
+      // Use the final state to set all the operands' parameters.
+      for (int i = 0, e = op->getNumOperands(); i != e; ++i)
+        changed |= SetOperandParams(op, i, params);
+
+      // Use the final state to set all the results' parameters.
+      for (int res = 0, e = op->getNumResults(); res != e; ++res)
+        changed |= SetResultParams(op, res, params);
+    }
+
+    // TODO(fengliuai): make the bit width configurable.
+    auto key = std::make_pair(8, is_signed_);
+    auto &restricted_outputs = spec->restricted_output_params[key];
+    for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
+      changed |= SetResultParams(op, i, restricted_outputs[i]);
+    }
+
+    for (auto &it : spec->biases_params) {
+      auto params =
+          GetBiasParams(op, it.first, it.second.first, it.second.second);
+      if (!params) {
+        quantized_.erase(op);
+        continue;
+      }
+      changed |= SetOperandParams(op, it.first, params);
+    }
+  }
+  return changed;
+}
+
+void QuantizationDriver::Finalize() {
+  for (auto *arg : args_) {
+    auto &state = GetArgQuantState(arg);
+    auto &requantize = GetArgRequantizeState(arg);
+    if (state.IsEmpty() ||
+        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
+      continue;
+    }
+
+    if (!state.immutable) {
+      QuantizeArg(arg, state.params);
+    }
+
+    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
+      RequantizeArg(arg, &requantize);
+    }
+  }
+
+  for (auto it : result_states_) {
+    Operation *op = it.first.first;
+    int res_index = it.first.second;
+    auto &state = GetResultQuantState(op, res_index);
+    auto &requantize = GetResultRequantizeState(op, res_index);
+    if (state.IsEmpty() ||
+        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
+      continue;
+    }
+
+    if (!state.immutable) {
+      QuantizeOpResult(op, res_index, state.params);
+    }
+
+    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
+      RequantizeOpResult(op, res_index, &requantize);
+    }
+  }
+}
+
+void QuantizationDriver::Run() {
+  Initialize();
+  if (PropagateParams()) {
+    Finalize();
+  }
+}
+
+void ApplyQuantizationParamsPropagation(
+    mlir::FuncOp func, bool is_signed, OpQuantSpecGetter op_quant_spec_getter) {
+  QuantizationDriver(func, is_signed, op_quant_spec_getter).Run();
+}
+
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
new file mode 100644
index 0000000..31a7a18
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -0,0 +1,134 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"  // TF:local_config_mlir
+#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
+
+namespace mlir {
+namespace TFL {
+
+// Returns the quantized type for the
+// input_type/min/max/storag_type_width/narrow_range.
+static Type GetQuantizedType(Builder builder, Type input_type, double min,
+                             double max, int storage_type_width,
+                             bool narrow_range, bool is_signed) {
+  auto converter =
+      quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
+
+  quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
+      builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
+      converter.expressedType, is_signed);
+  return converter.convert(quantizedEleType);
+}
+
+TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
+                              FloatAttr max, Type storage_type,
+                              bool narrow_range, bool is_signed) {
+  int storage_type_width = storage_type.cast<IntegerType>().getWidth();
+  Type final_type = GetQuantizedType(
+      builder, input_type, min.getValueAsDouble(), max.getValueAsDouble(),
+      storage_type_width, narrow_range, is_signed);
+  return builder.getTypeAttr(final_type);
+}
+
+TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
+                              Attribute max, IntegerAttr num_bits,
+                              BoolAttr narrow_range) {
+  FloatAttr min_value = GetSingleElementAsFloatOrSelf(min);
+  FloatAttr max_value = GetSingleElementAsFloatOrSelf(max);
+  if (!min_value || !max_value) return {};
+  return GetQuantizedTypeAttr(builder, input_type, min_value, max_value,
+                              builder.getIntegerType(num_bits.getInt()),
+                              narrow_range.getValue(), /*is_signed=*/false);
+}
+
+TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
+                                                TypeAttr source, Type target) {
+  if (!source || !source.getValue().isa<TensorType>()) return {};
+  auto ele_type = source.getValue().cast<TensorType>().getElementType();
+  if (auto quantized_type = ele_type.dyn_cast<quant::QuantizedType>()) {
+    Type final_type = quantized_type.castFromExpressedType(target);
+    if (final_type) return builder.getTypeAttr(final_type);
+  }
+  return {};
+}
+
+Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
+                                            unsigned storage_type_width,
+                                            bool is_signed, bool narrow_range) {
+  Builder builder(attr.getContext());
+  double min = std::numeric_limits<double>::max();
+  double max = std::numeric_limits<double>::min();
+  if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
+    for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
+      double ele_value = FloatAttr::getValueAsDouble(*it);
+      min = std::min(min, ele_value);
+      max = std::max(max, ele_value);
+    }
+    // The range must straddle zero.
+    if (min > 0.0 || max < 0.0) return {};
+    auto type = GetQuantizedType(builder, attr.getType(), min, max,
+                                 storage_type_width, narrow_range, is_signed);
+    if (auto ele_type = type.dyn_cast_or_null<TensorType>())
+      return ele_type.getElementType();
+  }
+
+  // The range from SplatElementAttr and other element attribute types  couldn't
+  // straddle zero, so the quantization parameters couldn't be derived from its
+  // range.
+  return {};
+}
+
+quant::QuantizedType GetUniformQuantizedTypeForBias(
+    const std::vector<quant::QuantizedType>& op_types) {
+  if (op_types.empty()) return {};
+
+  double scale = 1.0;
+  for (unsigned i = 0, e = op_types.size(); i != e; ++i) {
+    auto qtype = op_types[i].dyn_cast_or_null<quant::UniformQuantizedType>();
+    if (!qtype) return {};
+    scale *= qtype.getScale();
+  }
+  auto type = op_types.back().cast<quant::UniformQuantizedType>();
+  Builder builder(type.getContext());
+  // TODO(fengliuai): make the bit width configurable.
+  IntegerType storageType = builder.getIntegerType(32);
+  return quant::UniformQuantizedType::getChecked(
+      /*flags=*/true, storageType, type.getExpressedType(), scale,
+      /*zeroPoint=*/0,
+      quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32),
+      quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32),
+      builder.getUnknownLoc());
+}
+
+ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
+  if (auto q_type =
+          quant::QuantizedType::getQuantizedElementType(tensor_type)) {
+    Type converted_type;
+    return quant::quantizeAttr(real_value, q_type, converted_type)
+        .cast<ElementsAttr>();
+  }
+  return {};
+}
+
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
new file mode 100644
index 0000000..41e21ca
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -0,0 +1,209 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header file defines common utils used by TFLite transformation
+// passes to work with op attributes.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
+
+#include <unordered_map>
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+
+namespace mlir {
+namespace TFL {
+
+using QuantParams = quant::QuantizedType;
+using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
+using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
+using AccumulatorScaleFunc =
+    std::function<QuantParams(const std::vector<QuantParams>&)>;
+
+// Quantization spec of an op, driving the quantization algorithm.
+struct OpQuantSpec {
+  // Whether the op has quantizable result. This flag is set to false if the op
+  // has "TFL::NoQuantizableResult" trait.
+  bool is_quantizable = true;
+
+  // Whether it requires same inputs and result scale. This flag is set to true
+  // if the op has "TFL::SameOperandsAndResultScale" trait.
+  bool requires_same_scale = false;
+
+  // Maps the operand index of a bias input to its quantization specifications,
+  // including the non-bias operand indexes and the method retrieving
+  // quantization parameters from list of parameters of the non-bias operands.
+  // This map is empty if the op doesn't havea bias operand.
+  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
+      biases_params;
+
+  // Quantization parameters for value restricted outputs. This is the
+  // "hard-coded" parameters and should be used unconditionally for the
+  // quantized op. This vector is empty if the op doesn't have value resctricted
+  // outputs.
+  llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;
+};
+
+// A function signature for getting the particular OpQuantSpec for the provided
+// op.
+typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);
+
+// A generic rewrite pattern which matches any N-in-1-out operations with
+// quantization parameters propagated to all the operands and results values.
+// The quantization parameters are annotated by the Q/DQ op pairs. Each matched
+// pattern are rewritten by its quantized alternatives.
+//
+// This pattern assumes all the matched ops are quantizable. This assumption is
+// always right, except when a "Q" op is used as a requantize op. For non-"Q"
+// ops, quantization parameters should be propagated to their result.
+//
+// This pattern only matches ops which only have one result.
+template <typename Q, typename DQ>
+struct GenericFullQuantizationPattern : public RewritePattern {
+  explicit GenericFullQuantizationPattern(MLIRContext* context)
+      : RewritePattern(Q::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op,
+                                     PatternRewriter& rewriter) const override {
+    if (op->getNumResults() != 1) {
+      return matchFailure();
+    }
+    auto quantize_op = cast<Q>(op);
+    Operation* quantized_op = quantize_op.input()->getDefiningOp();
+    // If it is a block argument, requantize op, or has more than one result, we
+    // shouldn't rewrite this op.
+    if (!quantized_op || llvm::isa<Q>(quantized_op) ||
+        llvm::isa<DQ>(quantized_op)) {
+      return matchFailure();
+    }
+
+    // Collect all the quantized inputs and "clone" the matched op by these
+    // inputs.
+    SmallVector<Value*, 4> inputs;
+    inputs.reserve(quantized_op->getNumOperands());
+    for (auto operand : quantized_op->getOperands()) {
+      auto tensor_type = operand->getType().dyn_cast<TensorType>();
+      if (!tensor_type) {
+        // There are none type values.
+        return matchFailure();
+      }
+      auto operand_ele_type = tensor_type.getElementType();
+      if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
+        inputs.push_back(op_inst.input());
+      } else if (operand_ele_type.isa<IntegerType>()) {
+        // If the operand is an integer tensor, then it doesn't require the
+        // DQ op in the pattern.
+        inputs.push_back(operand);
+      } else {
+        return matchFailure();
+      }
+    }
+
+    // Collect all the quantized outputs and replace them by the results of the
+    // new quantized op.
+    llvm::SmallDenseMap<Value*, int> outputs_replaced;
+    SmallVector<Type, 4> output_types;
+    output_types.reserve(quantized_op->getNumResults());
+    for (auto result : llvm::enumerate(quantized_op->getResults())) {
+      if (!result.value()->hasOneUse()) return matchFailure();
+      auto result_ele_type =
+          result.value()->getType().cast<TensorType>().getElementType();
+      if (auto user = dyn_cast_or_null<Q>(*result.value()->user_begin())) {
+        outputs_replaced.insert({user.output(), result.index()});
+        output_types.push_back(user.getType());
+      } else if (result_ele_type.template isa<IntegerType>()) {
+        // If the result is an integer tensor, then it doesn't require the
+        // D op in the pattern.
+        outputs_replaced.insert({result.value(), result.index()});
+        output_types.push_back(result_ele_type);
+      } else {
+        return matchFailure();
+      }
+    }
+
+    // Use OpBuilder so we can use op name to create the new op.
+    OpBuilder builder(quantized_op);
+    OperationState new_state(quantized_op->getLoc(),
+                             quantized_op->getName().getStringRef(), inputs,
+                             output_types, quantized_op->getAttrs());
+    Operation* new_op = builder.createOperation(new_state);
+    for (auto output : outputs_replaced) {
+      output.getFirst()->replaceAllUsesWith(
+          new_op->getResult(output.getSecond()));
+    }
+    return matchSuccess();
+  }
+};
+
+// Converts the min/max/storage_type/narrow_range information to a
+// QuantizedType, and then returns the attribute containing the QuantizedType.
+TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
+                              FloatAttr max, Type storage_type,
+                              bool narrow_range = false,
+                              bool is_signed = false);
+
+// Converts the min/max/num_bits/narrow_range information to a
+// QuantizedType, and then returns the attribute containing the QuantizedType.
+// Note that this method assumes an unsigned quantization type, which is
+// implicitly defined by FakeQuant* ops in TensorFlow.
+TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
+                              Attribute max, IntegerAttr num_bits,
+                              BoolAttr narrow_range);
+
+// Casts the `target` type to a quantized type by using the quantization
+// parameters from the type in the `source` type attribute.
+// Examples:
+//   f32 -> !quant.uniform<i8:f32, 1.0>
+//   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
+// The result is wrapped by a type attribute. Returns nullptr if the cast isn't
+// valid.
+TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
+                                                TypeAttr source, Type target);
+
+// Quantizes the elements in the attribute `real_value` by the quantization
+// parameters in `tensor_type`. Returns empty Attribute if the
+// `tensor_type` is not a QuantizedType or the quantization fails.
+ElementsAttr Quantize(Attribute real_value, Type tensor_type);
+
+// Returns the quantized type for an element attribute. The quantization
+// parameters in this type is based on the min and max element of the attribute.
+// When the elements in the `attr` are not in floating-point, or the value range
+// isn't straddling zero, an empty type is returned.
+Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
+                                            unsigned storage_type_width,
+                                            bool is_sign, bool narrow_range);
+
+// Returns the quantized type of a bias input, given the quantized types of
+// other operands which are multiply-accumulated (the bias is added to the
+// accumulated value).
+quant::QuantizedType GetUniformQuantizedTypeForBias(
+    const std::vector<quant::QuantizedType>& op_types);
+
+// Propagates quantization parameters across ops in this function and satisfy
+// the quantization specification of the ops. This methods assumes the initial
+// quantization parameters are stored as adjacent quantize and dequantize ops
+// and the propagation results are materialized by inserting pairs of quantize
+// and dequantize ops to this function.
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
+                                        OpQuantSpecGetter op_quant_spec_getter);
+
+}  // end namespace TFL
+}  // end namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index da779c1..448830b 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -test-constant-fold | FileCheck %s
+// RUN: tf-opt %s -test-constant-fold | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: @add_float
 func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
@@ -109,6 +109,36 @@
   return %5, %6, %7, %8 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
 }
 
+// CHECK-LABEL: @elementwise_unary_ops
+func @elementwise_unary_ops() -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) {
+  %0 = constant dense<-1.0> : tensor<f32>
+  %1 = constant dense<1.0> : tensor<f32>
+  %2 = constant dense<1.0> : tensor<f32>
+  %3 = constant dense<1.0> : tensor<f32>
+  %4 = constant dense<4.0> : tensor<f32>
+  %5 = constant dense<4.0> : tensor<f32>
+  %6 = constant dense<2.0> : tensor<f32>
+
+  // CHECK-DAG: [[cst0:%.*]] = constant dense<1.000000e+00> : tensor<f32>
+  // CHECK-DAG: [[cst1:%.*]] = constant dense<0.841470957> : tensor<f32>
+  // CHECK-DAG: [[cst2:%.*]] = constant dense<0.540302277> : tensor<f32>
+  // CHECK-DAG: [[cst3:%.*]] = constant dense<0.000000e+00> : tensor<f32>
+  // CHECK-DAG: [[cst4:%.*]] = constant dense<2.000000e+00> : tensor<f32>
+  // CHECK-DAG: [[cst5:%.*]] = constant dense<5.000000e-01> : tensor<f32>
+  // CHECK-DAG: [[cst6:%.*]] = constant dense<4.000000e+00> : tensor<f32>
+  // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]]
+
+  %7 = "tfl.abs"(%0) : (tensor<f32>) -> tensor<f32>
+  %8 = "tfl.sin"(%1) : (tensor<f32>) -> tensor<f32>
+  %9 = "tfl.cos"(%2) : (tensor<f32>) -> tensor<f32>
+  %10 = "tfl.log"(%3) : (tensor<f32>) -> tensor<f32>
+  %11 = "tfl.sqrt"(%4) : (tensor<f32>) -> tensor<f32>
+  %12 = "tfl.rsqrt"(%5) : (tensor<f32>) -> tensor<f32>
+  %13 = "tfl.square"(%6) : (tensor<f32>) -> tensor<f32>
+
+  return %7, %8, %9, %10, %11, %12, %13 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
+}
+
 // CHECK-LABEL: @mul_int
 func @mul_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
   %0 = constant dense<8> : tensor<i32>
@@ -273,3 +303,155 @@
 // CHECK:  %0 = "tfl.add"
 // CHECK:  return %0
 }
+
+// CHECK-LABEL: @rank
+func @rank() -> tensor<1xi32> {
+  %cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// CHECK-LABEL: @rank_input_known_rank
+func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
+  // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// CHECK-LABEL: @reshape
+func @reshape() -> tensor<1x2xi32> {
+  %cst = constant dense<[1, 2]> : tensor<2xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}1, 2]]> : tensor<1x2xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32>
+  return %0 : tensor<1x2xi32>
+}
+// CHECK-LABEL: @pseudo_const
+func @pseudo_const() -> tensor<i32> {
+  // CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+
+// CHECK-LABEL: @range_int
+func @range_int() -> tensor<?xi32> {
+  %cst = constant dense<0> : tensor<i32>
+  %cst_1 = constant dense<4> : tensor<i32>
+  %cst_2 = constant dense<1> : tensor<i32>
+
+  // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// CHECK-LABEL: @range_float
+func @range_float() -> tensor<?xf32> {
+  %cst = constant dense<0.0> : tensor<f32>
+  %cst_1 = constant dense<4.0> : tensor<f32>
+  %cst_2 = constant dense<1.0> : tensor<f32>
+
+  // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+
+// CHECK-LABEL: @range_float_neg_delta
+func @range_float_neg_delta() -> tensor<?xf32> {
+  %cst = constant dense<0.0> : tensor<f32>
+  %cst_1 = constant dense<-4.0> : tensor<f32>
+  %cst_2 = constant dense<-1.0> : tensor<f32>
+
+  // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @range_float_nonzero_base
+func @range_float_nonzero_base() -> tensor<?xf32> {
+  %cst = constant dense<2.0> : tensor<f32>
+  %cst_1 = constant dense<7.0> : tensor<f32>
+  %cst_2 = constant dense<1.5> : tensor<f32>
+
+  // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @transpose_no_fold
+func @transpose_no_fold(%arg0 : tensor<2xi32>) -> tensor<2x2xi32> {
+  %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+
+  // CHECK: tfl.transpose
+  %0 = "tfl.transpose"(%cst, %arg0) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: @transpose_1d
+// Basic 1D identity
+func @transpose_1d() -> tensor<3xi32> {
+  %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
+  %cst_perm = constant dense<0> : tensor<1xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @transpose_dynamic
+func @transpose_dynamic() -> tensor<?xi32> {
+  %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
+  %cst_perm = constant dense<0> : tensor<1xi32>
+
+  // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// CHECK-LABEL: @transpose_2d
+func @transpose_2d() -> tensor<2x2xi32> {
+  %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+  %cst_perm = constant dense<[1, 0]> : tensor<2xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: @transpose_2d_identity
+func @transpose_2d_identity() -> tensor<2x2xi32> {
+  %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+  %cst_perm = constant dense<[0, 1]> : tensor<2xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: @transpose_3d
+// A test case adopted from TransposeTest.Test3DInputConstTensor in
+// tensorflow/lite/kernels/transpose_test.cc
+func @transpose_3d() -> tensor<4x2x3xi32> {
+  %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
+  %cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
+
+  // CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
+  // CHECK: return [[cst]]
+  %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
+  return %0 : tensor<4x2x3xi32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt
index 1bf0b07..c1bb797 100644
--- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 # CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
 
diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt
index edad75c..d3dcbc6 100644
--- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 # CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
 # CHECK: fake/user/code/file_D.py:28:1: note: called from
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 539cf8f..7413b19 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -142,7 +142,7 @@
   return %0: tensor<2xi32>
 
 // CHECK-LABEL: @const
-// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
 }
 
 func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
@@ -213,6 +213,20 @@
 // CHECK:  %0 = "tfl.logistic"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
 }
 
+func @sqrt(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = "tf.Sqrt"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+// CHECK-LABEL: sqrt
+// CHECK:  %0 = "tfl.sqrt"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+}
+
+func @square(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = "tf.Square"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+// CHECK-LABEL: square
+// CHECK:  %0 = "tfl.square"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+}
+
 func @log_softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.LogSoftmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
@@ -289,6 +303,14 @@
 // CHECK:  %0 = "tfl.abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
 }
 
+func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
+  %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
+  return %0 : tensor<i1>
+
+// CHECK-LABEL:any
+// CHECK:  %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
+}
+
 func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
@@ -442,12 +464,12 @@
 // CHECK:  return %0 : tensor<8x16xi1>
 }
 
-func @rank(%arg0: tensor<11x16xf32>) -> tensor<1xi32> {
-  %0 = "tf.Rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
+func @rank(%arg0: tensor<*xf32>) -> tensor<1xi32> {
+  %0 = "tf.Rank"(%arg0) : (tensor<*xf32>) -> tensor<1xi32>
   return %0 : tensor<1xi32>
 
 // CHECK-LABEL:rank
-// CHECK:  %0 = "tfl.rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
+// CHECK:  %0 = "tfl.rank"(%arg0) : (tensor<*xf32>) -> tensor<1xi32>
 }
 
 func @floor(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
@@ -487,6 +509,15 @@
 // CHECK:  return %0 : tensor<8xf32>
 }
 
+func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+  return %0: tensor<8xf32>
+
+// CHECK-LABEL: select_v2
+// CHECK:  %0 = "tfl.select"(%arg0, %arg1, %arg2)
+// CHECK:  return %0 : tensor<8xf32>
+}
+
 func @sin(%arg0: tensor<f32>) -> tensor<f32> {
   %0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
@@ -629,6 +660,17 @@
   // CHECK:  return %0 : tensor<?xf32>
 }
 
+func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
+^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>):
+  %cst = constant dense<[1, 2]> : tensor<2xi32>
+  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
+  return %0 : tensor<2x6xf32>
+
+  // CHECK-LABEL: tile
+  // CHECK:  %0 = "tfl.tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
+  // CHECK:  return %0 : tensor<2x6xf32>
+}
+
 func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
 ^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
   %cst = constant dense<2.0> : tensor<f32>
@@ -956,4 +998,80 @@
 
 // CHECK-LABEL: argmax64
 // CHECK:  %0 = "tfl.arg_max"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<i64>
-}
\ No newline at end of file
+}
+
+func @space_to_depth(%arg0: tensor<1x2x2x1xf32>) -> tensor<?xf32> {
+  %0 = "tf.SpaceToDepth"(%arg0) {block_size = 2: i64,  data_format = "NHWC"}: (tensor<1x2x2x1xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: space_to_depth
+  // CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32>
+  // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor<?xf32>
+}
+
+func @round(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = "tf.Round"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+
+  // CHECK-LABEL: round
+  // CHECK: %[[ARG:.*]]: tensor<8x16xf32>
+  // CHECK: %[[RESULT:.*]] = "tfl.round"(%[[ARG]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK: return %[[RESULT]] : tensor<8x16xf32>
+}
+
+func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
+  %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+  // CHECK-LABEL: resize_nearest_neighbor
+  // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
+}
+
+// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized.
+func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
+  %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+  // CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers
+  // CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true}
+}
+
+func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor<i32>, %arg1: tensor<3xi32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x?x?xf32> {
+  %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<i32>, tensor<3xi32>, tensor<f32>, tensor<f32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+  // CHECK-LABEL: sparse_to_dense_with_scalar_sparse_indices
+  // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor<i32>, tensor<3xi32>, tensor<f32>, tensor<f32>) -> tensor<?x?x?xf32>
+}
+
+func @sparse_to_dense_with_vector_sparse_indices(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xf32>, %arg3: tensor<f32>) -> tensor<?x?x?xf32> {
+  %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<3xi32>, tensor<3xi32>, tensor<3xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+  // CHECK-LABEL: sparse_to_dense_with_vector_sparse_indices
+  // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor<3xi32>, tensor<3xi32>, tensor<3xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+}
+
+func @sparse_to_dense_with_2d_sparse_indices(%arg0: tensor<3x2xi32>, %arg1: tensor<3xi32>, %arg2: tensor<2xf32>, %arg3: tensor<f32>) -> tensor<?x?x?xf32> {
+  %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<3x2xi32>, tensor<3xi32>, tensor<2xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+  // CHECK-LABEL: sparse_to_dense_with_2d_sparse_indices
+  // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor<3x2xi32>, tensor<3xi32>, tensor<2xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+}
+
+func @where(%arg0: tensor<3x5xi1>) -> tensor<?x2xi64> {
+  %0 = "tf.Where"(%arg0) : (tensor<3x5xi1>) -> tensor<?x2xi64>
+  return %0 : tensor<?x2xi64>
+  // CHECK-LABEL: where
+  // CHECK: "tfl.where"(%arg0) : (tensor<3x5xi1>) -> tensor<?x2xi64>
+}
+
+func @floor_mod(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+  // CHECK-LABEL: floor_mod
+  // CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+func @exp(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = "tf.Exp"(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+  // CHECK-LABEL: exp
+  // CHECK: "tfl.exp"(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index 1fe6757..817ced7 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -1,6 +1,6 @@
 // RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
-func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
-^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
+
+func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
   %2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
@@ -11,8 +11,7 @@
 // CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
 }
 
-func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
-^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
+func @tensorlistGetItemWithUnknownRank(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<*xf32>>>
   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<*xf32>
   %2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<1xi32>) -> tensor<*xf32>
@@ -23,8 +22,7 @@
 // CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
 }
 
-func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10xf32>) -> tensor<3x10xf32> {
-^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>):
+func @tensorlistSetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>) -> tensor<3x10xf32> {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
   %1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
   %2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
@@ -56,8 +54,7 @@
 // CHECK:  return %15 : tensor<3x10xf32>
 }
 
-func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i32>, tensor<f32>) -> tensor<5xf32> {
-^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
+func @tensorlistSetItemWithScalarElements(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> tensor<5xf32> {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
   %1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
   %2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>) -> tensor<5xf32>
@@ -89,24 +86,23 @@
 // CHECK:  return %15 : tensor<5xf32>
 }
 
-func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32> {
-^bb0(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
+func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
   %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 
 // CHECK-LABEL: tensorlistReserve
-// CHECK:  %cst = constant dense<0> : tensor<i32>
-// CHECK:  %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
-// CHECK:  %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
-// CHECK:  %cst_0 = constant dense<0.000000e+00> : tensor<f32>
-// CHECK:  %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
-// CHECK:  %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
-// CHECK:  return %3 : tensor<?x?x?xf32>
+// CHECK-DAG:  [[ZERO1:%cst.*]] = constant dense<0> : tensor<i32>
+// CHECK-DAG:  [[ZERO2:%cst.*]] = constant dense<0> : tensor<i32>
+// CHECK-DAG:  [[DIM0:%.*]] = "tf.ExpandDims"(%arg1, [[ZERO1]]) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
+// CHECK-DAG:  [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
+// CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK:      return [[RESULT]] : tensor<?x?x?xf32>
 }
 
-func @tensorlistReserveUnrankedElements(tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<*xf32> {
-^bb0(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
+func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
   %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<?xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<*xf32>>>
   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<?xi32>) -> tensor<*xf32>
   return %1 : tensor<*xf32>
@@ -117,13 +113,42 @@
 // CHECK:  return [[RESULT2]] : tensor<*xf32>
 }
 
-func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
-^bb0(%arg0: tensor<2x3xf32>):
+func @EmptyTensorList(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
+  %0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
+  %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+
+// CHECK-LABEL: EmptyTensorList
+// CHECK-SAME:  ([[ELEM_SHAPE:%.*]]: tensor<3xi32>, [[MAX_ELEMS:%.*]]: tensor<i32>, [[IDX:%.*]]: tensor<i32>)
+// CHECK-DAG:  [[DIM0:%cst.*]] = constant dense<0> : tensor<1xi32>
+// CHECK-DAG:  [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
+// CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
+// CHECK-DAG:  [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
+// CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK:      return [[RESULT]] : tensor<?x?x?xf32>
+}
+
+func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<10xf32>) -> tensor<?x10xf32> {
+  %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
+  %1 = "tf.TensorListPushBack"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
+  %2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<?x10xf32>
+  return %2 : tensor<?x10xf32>
+
+// CHECK-LABEL: tensorlistPushBack
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>, [[ITEM:%.*]]: tensor<10xf32>)
+// CHECK:   [[ZERO:%.*]] = constant dense<0> : tensor<i32>
+// CHECK:   [[EXP_ITEM:%.*]] = "tf.ExpandDims"([[ITEM]], [[ZERO]]) {{.*}} -> tensor<1x10xf32>
+// CHECK:   [[RESULT:%.*]] = "tf.Concat"(%cst, [[INPUT]], [[EXP_ITEM]]) {N = 2 : i64} : {{.*}} -> tensor<?x10xf32>
+// CHECK:   return [[RESULT]] : tensor<?x10xf32>
+}
+
+func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
   %cst = constant dense<3> : tensor<1xi32>
   %cst_0 = constant dense<0> : tensor<i32>
   %cst_1 = constant dense<-1> : tensor<i32>
   %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<3xf32>>>
-  %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
+  %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond, is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
   %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
   return %2 : tensor<*xf32>
 
@@ -136,8 +161,7 @@
 // CHECK:  return %0#1 : tensor<*xf32>
 }
 
-func @tensorlistWhileBody(tensor<*xi32>, tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
-^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
+func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
   %cst = constant dense<1> : tensor<i32>
   %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
   %1 = "tf.Identity"(%arg1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
@@ -151,8 +175,7 @@
 // CHECK:  return %0, %1 : tensor<*xi32>, tensor<*xf32>
 }
 
-func @tensorlistWhileCond(tensor<*xi32>, tensor<!tf.variant>) -> tensor<*xi1> {
-^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
+func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> tensor<*xi1> {
   %cst = constant dense<2> : tensor<i32>
   %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
   return %0 : tensor<*xi1>
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir
index 6f0882f..408fb51 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir
@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 # CHECK: loc("disable_builtin.mlir":2:1): is a TFLite builtin op but builtin emission is not enabled
 # CHECK-NEXT: Verification failed.
 
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
index be62118..c4dd8b5 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir
@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 # CHECK:  loc("disable_flex.mlir":96:8): error: 'tf.div' op is a Flex op but Flex ops are not enabled for emission
 # CHECK-NEXT:  Verification failed.
 
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir
new file mode 100644
index 0000000..1eae962
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir
@@ -0,0 +1,25 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
+
+func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+  %cst = "tfl.pseudo_const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
+  %0 = "tfl.pseudo_input" (%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+  %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<?xi32>) -> tensor<2xi32>
+  return %1 : tensor<2xi32>
+}
+
+
+// CHECK:    tensors: [ {
+// CHECK-NEXT:      shape: [ 2 ],
+// CHECK-NEXT:      type: INT32,
+// CHECK-NEXT:      buffer: 1,
+// CHECK-NEXT:      name: "tfl.pseudo_const",
+// CHECK-NEXT:      quantization: {
+// CHECK-NEXT:
+// CHECK-NEXT:      }
+
+// CHECK:   buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     data: [ 1, 0, 0, 0, 2, 0, 0, 0 ]
+// CHECK-NEXT:   }, {
+
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir
new file mode 100644
index 0000000..ddb122f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir
@@ -0,0 +1,283 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
+
+func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
+// CHECK: {
+// CHECK-NEXT:   version: 3,
+// CHECK-NEXT:   operator_codes: [ {
+// CHECK-NEXT:     builtin_code: LSTM
+// CHECK-NEXT:   } ],
+// CHECK-NEXT:   subgraphs: [ {
+// CHECK-NEXT:     tensors: [ {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 1,
+// CHECK-NEXT:       name: "tfl.pseudo_input",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 2,
+// CHECK-NEXT:       name: "tfl.pseudo_input1",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 3,
+// CHECK-NEXT:       name: "tfl.pseudo_input2",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 4,
+// CHECK-NEXT:       name: "tfl.pseudo_input3",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 5,
+// CHECK-NEXT:       name: "tfl.pseudo_input4",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 6,
+// CHECK-NEXT:       name: "tfl.pseudo_input5",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 7,
+// CHECK-NEXT:       name: "tfl.pseudo_input6",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 8,
+// CHECK-NEXT:       name: "tfl.pseudo_input7",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 9,
+// CHECK-NEXT:       name: "tfl.pseudo_input8",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 10,
+// CHECK-NEXT:       name: "tfl.pseudo_input9",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 11,
+// CHECK-NEXT:       name: "tfl.pseudo_input10",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 12,
+// CHECK-NEXT:       name: "tfl.pseudo_input11",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 13,
+// CHECK-NEXT:       name: "tfl.pseudo_input12",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 14,
+// CHECK-NEXT:       name: "tfl.pseudo_input13",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 15,
+// CHECK-NEXT:       name: "tfl.pseudo_input14",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 16,
+// CHECK-NEXT:       name: "tfl.pseudo_input15",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 17,
+// CHECK-NEXT:       name: "tfl.pseudo_input16",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 18,
+// CHECK-NEXT:       name: "tfl.pseudo_input17",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       name: "Const",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       },
+// CHECK-NEXT:       is_variable: true
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       name: "Const1",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       },
+// CHECK-NEXT:       is_variable: true
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 21,
+// CHECK-NEXT:       name: "tfl.pseudo_input18",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 22,
+// CHECK-NEXT:       name: "tfl.pseudo_input19",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 23,
+// CHECK-NEXT:       name: "tfl.pseudo_input20",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 24,
+// CHECK-NEXT:       name: "tfl.pseudo_input21",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 25,
+// CHECK-NEXT:       name: "tfl.lstm",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23 ],
+// CHECK-NEXT:     outputs: [ 24 ],
+// CHECK-NEXT:     operators: [ {
+// CHECK-NEXT:       inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 ],
+// CHECK-NEXT:       outputs: [ 24 ],
+// CHECK-NEXT:       builtin_options_type: LSTMOptions,
+// CHECK-NEXT:       builtin_options: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     name: "main"
+// CHECK-NEXT:   } ],
+// CHECK-NEXT:   description: "MLIR Converted.",
+// CHECK-NEXT:   buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   } ]
+// CHECK-NEXT: }
+// CHECK-EMPTY:
+
+
+^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>, %arg22: tensor<4 x f32>, %arg23: tensor<4 x f32>):
+  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %4 = "tfl.pseudo_input" (%arg4) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %5 = "tfl.pseudo_input" (%arg5) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %6 = "tfl.pseudo_input" (%arg6) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %7 = "tfl.pseudo_input" (%arg7) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %8 = "tfl.pseudo_input" (%arg8) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %9 = "tfl.pseudo_input" (%arg9) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %10 = "tfl.pseudo_input" (%arg10) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %11 = "tfl.pseudo_input" (%arg11) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %12 = "tfl.pseudo_input" (%arg12) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %13 = "tfl.pseudo_input" (%arg13) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %14 = "tfl.pseudo_input" (%arg14) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %15 = "tfl.pseudo_input" (%arg15) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %16 = "tfl.pseudo_input" (%arg16) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %17 = "tfl.pseudo_input" (%arg17) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %18 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %19 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %20 = "tfl.pseudo_input" (%arg20) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %24 : tensor<4xf32>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir
new file mode 100644
index 0000000..e2ffb24
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir
@@ -0,0 +1,282 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
+
+func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
+// CHECK: {
+// CHECK-NEXT:   version: 3,
+// CHECK-NEXT:   operator_codes: [ {
+// CHECK-NEXT:     builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM
+// CHECK-NEXT:   } ],
+// CHECK-NEXT:   subgraphs: [ {
+// CHECK-NEXT:     tensors: [ {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 1,
+// CHECK-NEXT:       name: "tfl.pseudo_input",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 2,
+// CHECK-NEXT:       name: "tfl.pseudo_input1",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 3,
+// CHECK-NEXT:       name: "tfl.pseudo_input2",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 4,
+// CHECK-NEXT:       name: "tfl.pseudo_input3",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 5,
+// CHECK-NEXT:       name: "tfl.pseudo_input4",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 6,
+// CHECK-NEXT:       name: "tfl.pseudo_input5",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 7,
+// CHECK-NEXT:       name: "tfl.pseudo_input6",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 8,
+// CHECK-NEXT:       name: "tfl.pseudo_input7",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 9,
+// CHECK-NEXT:       name: "tfl.pseudo_input8",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 10,
+// CHECK-NEXT:       name: "tfl.pseudo_input9",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 11,
+// CHECK-NEXT:       name: "tfl.pseudo_input10",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 12,
+// CHECK-NEXT:       name: "tfl.pseudo_input11",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 13,
+// CHECK-NEXT:       name: "tfl.pseudo_input12",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 14,
+// CHECK-NEXT:       name: "tfl.pseudo_input13",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 15,
+// CHECK-NEXT:       name: "tfl.pseudo_input14",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 16,
+// CHECK-NEXT:       name: "tfl.pseudo_input15",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 17,
+// CHECK-NEXT:       name: "tfl.pseudo_input16",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 18,
+// CHECK-NEXT:       name: "tfl.pseudo_input17",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       name: "Const",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       },
+// CHECK-NEXT:       is_variable: true
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       name: "Const1",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       },
+// CHECK-NEXT:       is_variable: true
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 21,
+// CHECK-NEXT:       name: "tfl.pseudo_input18",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 22,
+// CHECK-NEXT:       name: "tfl.pseudo_input19",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 23,
+// CHECK-NEXT:       name: "tfl.pseudo_input20",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 24,
+// CHECK-NEXT:       name: "tfl.pseudo_input21",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:       shape: [ 4 ],
+// CHECK-NEXT:       buffer: 25,
+// CHECK-NEXT:       name: "tfl.unidirectional_sequence_lstm",
+// CHECK-NEXT:       quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:       }
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23 ],
+// CHECK-NEXT:     outputs: [ 24 ],
+// CHECK-NEXT:     operators: [ {
+// CHECK-NEXT:       inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 ],
+// CHECK-NEXT:       outputs: [ 24 ],
+// CHECK-NEXT:         builtin_options_type: UnidirectionalSequenceLSTMOptions,
+// CHECK-NEXT:         builtin_options: {
+// CHECK-NEXT:           time_major: true
+// CHECK-NEXT:         }
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     name: "main"
+// CHECK-NEXT:   } ],
+// CHECK-NEXT:   description: "MLIR Converted.",
+// CHECK-NEXT:   buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   } ]
+// CHECK-NEXT: }
+// CHECK-EMPTY:
+
+^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>, %arg22: tensor<4 x f32>, %arg23: tensor<4 x f32>):
+  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %4 = "tfl.pseudo_input" (%arg4) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %5 = "tfl.pseudo_input" (%arg5) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %6 = "tfl.pseudo_input" (%arg6) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %7 = "tfl.pseudo_input" (%arg7) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %8 = "tfl.pseudo_input" (%arg8) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %9 = "tfl.pseudo_input" (%arg9) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %10 = "tfl.pseudo_input" (%arg10) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %11 = "tfl.pseudo_input" (%arg11) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %12 = "tfl.pseudo_input" (%arg12) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %13 = "tfl.pseudo_input" (%arg13) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %14 = "tfl.pseudo_input" (%arg14) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %15 = "tfl.pseudo_input" (%arg15) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %16 = "tfl.pseudo_input" (%arg16) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %17 = "tfl.pseudo_input" (%arg17) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %18 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %19 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %20 = "tfl.pseudo_input" (%arg20) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %24 = "tfl.unidirectional_sequence_lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %24 : tensor<4xf32>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir
new file mode 100644
index 0000000..3d91f66
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir
@@ -0,0 +1,93 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
+
+func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
+// CHECK:      {
+// CHECK-NEXT:     version: 3,
+// CHECK-NEXT:     operator_codes: [ {
+// CHECK-NEXT:       builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     subgraphs: [ {
+// CHECK-NEXT:       tensors: [ {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         buffer: 1,
+// CHECK-NEXT:         name: "tfl.pseudo_input",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }, {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         buffer: 2,
+// CHECK-NEXT:         name: "tfl.pseudo_input1",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }, {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         buffer: 3,
+// CHECK-NEXT:         name: "tfl.pseudo_input2",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }, {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         buffer: 4,
+// CHECK-NEXT:         name: "tfl.pseudo_input3",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         }
+// CHECK-NEXT:       }, {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         name: "Const",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         },
+// CHECK-NEXT:         is_variable: true
+// CHECK-NEXT:       }, {
+// CHECK-NEXT:         shape: [ 4 ],
+// CHECK-NEXT:         buffer: 6,
+// CHECK-NEXT:         name: "tfl.unidirectional_sequence_rnn",
+// CHECK-NEXT:         quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:         }
+// CHECK-NEXT:       } ],
+// CHECK-NEXT:       inputs: [ 0, 1, 2, 3 ],
+// CHECK-NEXT:       outputs: [ 5 ],
+// CHECK-NEXT:       operators: [ {
+// CHECK-NEXT:         inputs: [ 0, 1, 2, 3, 4 ],
+// CHECK-NEXT:         outputs: [ 5 ],
+// CHECK-NEXT:         builtin_options_type: SequenceRNNOptions,
+// CHECK-NEXT:         builtin_options: {
+// CHECK-NEXT:           time_major: true,
+// CHECK-NEXT:           fused_activation_function: TANH
+// CHECK-NEXT:         }
+// CHECK-NEXT:       } ],
+// CHECK-NEXT:       name: "main"
+// CHECK-NEXT:     } ],
+// CHECK-NEXT:     description: "MLIR Converted.",
+// CHECK-NEXT:     buffers: [ {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }, {
+// CHECK-NEXT:      data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
+// CHECK-NEXT:     }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:     } ]
+// CHECK-NEXT:   }
+// CHECK-EMPTY:
+
+^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>):
+  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32>
+  %4 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %5 = "tfl.unidirectional_sequence_rnn"(%0, %1, %2, %3, %4) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %5 : tensor<4xf32>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir
index 14f8174..eb20f37 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir
@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>):
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir
index 117f974..bf76f4f 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir
@@ -195,7 +195,7 @@
 
   // While %0 is greater than zero, element wise add %1 with itself.
   %2:2 = "tf.While"(%0, %1) {
-    cond = @cond, body = @body
+    cond = @cond, body = @body, is_stateless = false
   } : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
   return %2#1 : tensor<1xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index c627b9e..0de38cb 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -155,6 +155,26 @@
 
 // -----
 
+// test invalid Sqrt input
+func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
+^bb0(%arg0: tensor<? x i32>):
+  // expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}}
+  %0 = "tfl.sqrt"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
+  return %0#0 : tensor<? x i32>
+}
+
+// -----
+
+// test invalid Square input
+func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
+^bb0(%arg0: tensor<? x i32>):
+  // expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point values}}
+  %0 = "tfl.square"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
+  return %0#0 : tensor<? x i32>
+}
+
+// -----
+
 // CHECK-LABEL: testSqrt
 func @testSqrt(tensor<? x f32>) -> tensor<? x f32> {
 ^bb0(%arg0: tensor<? x f32>):
@@ -287,11 +307,9 @@
 // -----
 
 // CHECK-LABEL: testFloorMod
-func @testFloorMod(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
-^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
-  // CHECK: tfl.floor_mod %arg0, %arg1
-  %0 = tfl.floor_mod %arg0, %arg1 : tensor<? x i32>
-  return %0#0 : tensor<? x i32>
+func @testFloorMod(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>) -> tensor<? x i32> {
+  %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
+  return %0 : tensor<? x i32>
 }
 
 // CHECK-LABEL: testPow
@@ -310,6 +328,13 @@
   return %0 : tensor<256x30x30x16xf32>
 }
 
+
+func @testConv2DNoBias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: none) -> tensor<256x30x30x16xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2)
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, none) -> tensor<256x30x30x16xf32>
+  return %0 : tensor<256x30x30x16xf32>
+}
+
 // CHECK-LABEL: testFakeQuant
 func @testFakeQuant(tensor<? x f32>, f32, f32) -> tensor<? x f32> {
 ^bb0(%arg0: tensor<? x f32>, %arg1: f32, %arg2: f32):
@@ -489,13 +514,22 @@
 // test invalid Logistic input
 func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
 ^bb0(%arg0: tensor<?xi32>):
-  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point values}}
+  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
   %0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
   return %0#0 : tensor<?xi32>
 }
 
 // -----
 
+// CHECK-LABEL: testUnidirectionalSequenceRnn
+func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
+  // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: testUnidirectionalSequenceLstm
 func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
   // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
@@ -933,3 +967,144 @@
   %0 = "tfl.arg_min"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
   return %0 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: testSpaceToDepth
+func @testSpaceToDepthF32(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> {
+  // CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32>
+  // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32>
+  %0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32>
+  return %0 : tensor<1x1x1x4xf32>
+}
+
+// -----
+
+func @testSpaceToDepthInvalidOutputType(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32> {
+  // expected-error @+1 {{'tfl.space_to_depth' op failed to verify that input and output must have same element type}}
+  %0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32>
+  return %0 : tensor<1x1x1x4xi32>
+}
+
+// -----
+
+func @testRange(%arg0 : tensor<i32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xi32> {
+  %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+func @testRangeNonScalarTensorInput(%arg0 : tensor<1xi32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xi32> {
+  // expected-error @+1 {{op failed to verify that operand 0 is 0-D}}
+  %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<1xi32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+func @testRangeOutputTypeMismatch(%arg0 : tensor<i32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xf32> {
+  // expected-error @+1 {{op failed to verify that operands and output must have same element type}}
+  %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @transpose(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
+  %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+
+// -----
+
+func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) -> tensor<2x2xi32> {
+  // expected-error @+1 {{op operand #1 must be tensor of 32-bit integer values}}
+  %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xf32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+
+// -----
+
+func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
+  // expected-error @+1 {{input and output must have same element type}}
+  %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+
+// -----
+
+func @transpose_1d_perm(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // expected-error @+1 {{op failed to verify that operand 1 is 1-D}}
+  %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
+func @anyWithI64Axis(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
+  // expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit integer values}}
+  %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// -----
+
+func @testRoundInvalidInputType(%arg: tensor<?xi32>) -> tensor<?xi32> {
+  // expected-error @+1 {{'tfl.round' op operand #0 must be tensor of 32-bit float values}}
+  %0 = "tfl.round"(%arg) : (tensor<?xi32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+func @testSplitWithQuantizedTypes(%arg0 : tensor<i32>, %arg1 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.split"(%arg0, %arg1) {num_splits = 1 : i32} : (tensor<i32>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
+
+// -----
+
+func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<i32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
+
+// -----
+
+func @whereWithI32Input(%arg0: tensor<3x5xi32>) -> tensor<?x2xi64> {
+  // expected-error @+1 {{'tfl.where' op operand #0 must be tensor of 1-bit integer values}}
+  %0 = "tfl.where"(%arg0) : (tensor<3x5xi32>) -> tensor<?x2xi64>
+  return %0 : tensor<?x2xi64>
+}
+
+// -----
+
+func @testMinimumWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.minimum"(%arg0, %arg1) : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
+
+// -----
+
+func @testMaximumWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.maximum"(%arg0, %arg1) : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
+
+// -----
+
+func @testReluWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.relu"(%arg0) : (tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
+
+// -----
+
+func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
+  %0 = "tfl.relu6"(%arg0) : (tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
+  return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index e7ebace..78afbc8 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -96,6 +96,54 @@
 
 }
 
+// CHECK-LABEL: @fuseMulIntoFullyConnected
+func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst1 = constant dense<2.0> : tensor<2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK:  %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
+// CHECK:  %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
+// CHECK:  return %0 : tensor<4x2xf32>
+}
+
+// CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast
+func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>
+  %cst1 = constant dense<2.0> : tensor<2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<1x2xf32>
+  // %cst2 isn't broadcast-compatible to %cst0, but tf.Mul is able to fold them.
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
+  return %1 : tensor<1x2xf32>
+
+// CHECK:  %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
+// CHECK:  %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
+// CHECK:  %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
+// CHECK:  return %0 : tensor<1x2xf32>
+}
+
+// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
+func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK:  %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+// CHECK:  return %0 : tensor<4x2xf32>
+}
+
 // CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
 func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
   %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
@@ -148,6 +196,17 @@
   // CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst)
 }
 
+// CHECK-LABEL: @FuseFullyConnectedRelu
+func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU"
+  // CHECK: return %[[RES]]
+}
+
 // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
 func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
   %cst = constant dense<0> : tensor<4xi32>
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index f2ca713..16ac6b5 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -343,7 +343,7 @@
   return %cst : tensor<2x3xf32>
 
 // CHECK: %cst = constant dense{{.*}}tensor<2x3xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.023529411764705882:128>>}
+// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8<1:255>:f32, 0.023622047244094488:128>>}
 // CHECK: %1 = "tfl.dequantize"(%0)
 // CHECK: return %1 : tensor<2x3xf32>
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 0edb4f4..fd35ed8 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -63,93 +63,218 @@
   return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
 
 // CHECK-LABEL: fusedBatchNorm
-// CHECK:%cst = constant dense<1.000000e-03> : tensor<f32>
+// CHECK:  %[[CONSTANT:.*]] = constant dense<1.000000e-03>
 //              variance + epsilon
-// CHECK:  %0 = "tf.Add"(%arg4, %cst) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+// CHECK:  %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
 //              rsqrt(variance + epsilon)
-// CHECK:  %1 = "tf.Rsqrt"(%0) : (tensor<8xf32>) -> tensor<8xf32>
+// CHECK:  %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
 //              scale * rsqrt(variance + epsilon)
-// CHECK:  %2 = "tf.Mul"(%arg1, %1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK:  %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
 //              x * scale * rsqrt(variance + epsilon)
-// CHECK:  %3 = "tf.Mul"(%arg0, %2) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+// CHECK:  %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
 //              mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %4 = "tf.Mul"(%arg3, %2) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK:  %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
 //              offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %5 = "tf.Sub"(%arg2, %4) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK:  %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
 //              x * scale * rsqrt(variance + epsilon) +
 //              offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK:  %6 = "tf.Add"(%3, %5) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+// CHECK:  %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
 
-// CHECK:  %7:5 = "tf.FusedBatchNorm"(%6, %arg1, %arg2, %arg3, %arg4)
-// CHECK:  %8:5 = "tf.FusedBatchNorm"(%7#0, %arg1, %arg2, %arg3, %arg4)
+// CHECK:  %[[BATCHNORM1:.*]]:5 = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK:  {{.*}} = "tf.FusedBatchNorm"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
 }
 
-func @fakeQuantNotFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
-^bb0(%arg0: tensor<8x8x8x8xf32>):
-  %arg1 = constant dense<-0.1> : tensor<f32>
-  %arg2 = constant dense<0.2> : tensor<f32>
-  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
-  return %0 : tensor<8x8x8x8xf32>
+func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
+^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
+  // OK
+  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // Unsupported training
+  %1:6 = "tf.FusedBatchNormV3"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true}  : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // Use other output
+  %2:6 = "tf.FusedBatchNormV3"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
 
-// CHECK-LABEL: fakeQuantNotFollowedByQuant
+  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
+
+// CHECK-LABEL: fusedBatchNormV3
+// CHECK:  %[[CONSTANT:.*]] = constant dense<1.000000e-03>
+//              variance + epsilon
+// CHECK:  %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
+//              rsqrt(variance + epsilon)
+// CHECK:  %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
+//              scale * rsqrt(variance + epsilon)
+// CHECK:  %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
+//              x * scale * rsqrt(variance + epsilon)
+// CHECK:  %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
+//              mean * scale * rsqrt(variance + epsilon)
+// CHECK:  %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
+//              offset - mean * scale * rsqrt(variance + epsilon)
+// CHECK:  %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
+//              x * scale * rsqrt(variance + epsilon) +
+//              offset - mean * scale * rsqrt(variance + epsilon)
+// CHECK:  %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
+
+// CHECK:  %[[BATCHNORM1:.*]]:6 = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK:  %[[BATCHNORM2:.*]]:6 = "tf.FusedBatchNormV3"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+}
+
+// CHECK-LABEL: fakeQuantForActivation
+func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>):
+  %arg1 = constant dense<0.0> : tensor<f32>
+  %arg2 = constant dense<255.0> : tensor<f32>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
+// CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK:  %2 = "tfl.dequantize"(%1)
+// CHECK:  return %2
+}
+
+// CHECK-LABEL: fakeQuantForActivationNoDuplication
+func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>) {
+^bb0(%arg0: tensor<8xf32>):
+  %arg1 = constant dense<0.0> : tensor<f32>
+  %arg2 = constant dense<255.0> : tensor<f32>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<8xf32>) -> tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>
+  return %1 : tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>
+
 // CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
-// CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
-// CHECK:  %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
-// CHECK:  return %2 : tensor<8x8x8x8xf32>
+// CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK:  return %1
 }
 
-func @fakeQuantFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
-^bb0(%arg0: tensor<8x8x8x8xf32>):
-  %arg1 = constant dense<-0.1> : tensor<f32>
-  %arg2 = constant dense<0.2> : tensor<f32>
-  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
-  %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>
-  %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
-  return %2 : tensor<8x8x8x8xf32>
+// CHECK-LABEL: fakeQuantFolded
+func @fakeQuantFolded() -> (tensor<8xf32>) {
+  %in = constant dense<0.0> : tensor<8xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %rst : tensor<8xf32>
 
-// CHECK-LABEL: fakeQuantFollowedByQuant
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
-// CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
-// CHECK:  %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
-// CHECK:  return %2 : tensor<8x8x8x8xf32>
+// CHECK: %cst = constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK: %1 = "tfl.dequantize"(%0)
+// CHECK: return %1 : tensor<8xf32>
 }
 
-func @fakeQuantVarsNotConst(tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8x8x8x8xf32>) {
-^bb0(%arg0: tensor<8x8x8x8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
-  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
-  return %1 : tensor<8x8x8x8xf32>
+// CHECK-LABEL: fakeQuantNotFolded
+func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
+  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %1 : tensor<8xf32>
 
-// CHECK-LABEL: fakeQuantVarsNotConst
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
-// CHECK:  return %0 : tensor<8x8x8x8xf32>
+// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
+// CHECK: return %0 : tensor<8xf32>
 }
 
-func @fakeQuantFollowedByTranspose(tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> (tensor<16x3x3x3xf32>) {
-^bb0(%arg0: tensor<3x3x3x16xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
-  %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
-  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
-  %1 = "tf.Transpose"(%0, %cst_0): (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
-  return %1 : tensor<16x3x3x3xf32>
-
 // CHECK-LABEL: fakeQuantFollowedByTranspose
-// CHECK:  %cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
-// CHECK:  %0 = "tf.Transpose"(%arg0, %cst) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
-// CHECK:  %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
-// CHECK:  return %1 : tensor<16x3x3x3xf32>
+func @fakeQuantFollowedByTranspose(tensor<1x2xf32>, tensor<f32>, tensor<f32>) -> (tensor<2x1xf32>) {
+^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
+  %cst_0 = constant dense<[1, 0]> : tensor<2xi32>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<1x2xf32>, tensor<f32>, tensor<f32>) -> tensor<1x2xf32>
+  %1 = "tf.Transpose"(%0, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+
+// CHECK:  %cst = constant
+// CHECK:  %0 = "tf.Transpose"(%arg0, %cst)
+// CHECK:  %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2)
+// CHECK:  return %1
 }
 
-func @fakeQuantFollowedByReshape(tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x3x3x12xf32>) {
-^bb0(%arg0: tensor<3x3x3x4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
-  %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
-  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x4xf32>
-  %1 = "tf.Reshape"(%0, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
-  return %1 : tensor<1x3x3x12xf32>
-
 // CHECK-LABEL: fakeQuantFollowedByReshape
-// CHECK:  %cst = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
-// CHECK:  %0 = "tf.Reshape"(%arg0, %cst) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
-// CHECK:  %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
-// CHECK:  return %1 : tensor<1x3x3x12xf32>
+func @fakeQuantFollowedByReshape(tensor<1x2xf32>, tensor<f32>, tensor<f32>) -> (tensor<2x1xf32>) {
+^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
+  %cst_0 = constant dense<[2, 1]> : tensor<2xi64>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<1x2xf32>, tensor<f32>, tensor<f32>) -> tensor<1x2xf32>
+  %1 = "tf.Reshape"(%0, %cst_0) : (tensor<1x2xf32>, tensor<2xi64>) -> tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+
+// CHECK:  %cst = constant
+// CHECK:  %0 = "tf.Reshape"(%arg0, %cst)
+// CHECK:  %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2)
+// CHECK:  return %1
+}
+
+// CHECK-LABEL: QDQsFollowedByTranspose
+func @QDQsFollowedByTranspose(tensor<1x2xf32>) -> (tensor<2x1xf32>) {
+^bb0(%arg0: tensor<1x2xf32>):
+  %cst_0 = constant dense<[1, 0]> : tensor<2xi32>
+  %0 = "tfl.quantize"(%arg0){qtype = tensor<1x2x!quant.uniform<u8:f32, 1.0>>}: (tensor<1x2xf32>) -> (tensor<1x2x!quant.uniform<u8:f32, 1.0>>)
+  %1 = "tfl.dequantize"(%0): (tensor<1x2x!quant.uniform<u8:f32, 1.0>>) -> (tensor<1x2xf32>)
+  %2 = "tf.Transpose"(%1, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
+  return %2 : tensor<2x1xf32>
+
+// CHECK: %cst = constant
+// CHECK: %0 = "tf.Transpose"
+// CHECK-SAME: -> tensor<2x1xf32>
+// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK-SAME: -> tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: %2 = "tfl.dequantize"(%1)
+// CHECK-SAME: -> tensor<2x1xf32>
+// CHECK: return %2
+}
+
+// CHECK-LABEL: QDQFollowedByReshape
+func @QDQFollowedByReshape(tensor<1x2xf32>) -> (tensor<2x1xf32>) {
+^bb0(%arg0: tensor<1x2xf32>):
+  %cst_0 = constant dense<[2, 1]> : tensor<2xi32>
+  %0 = "tfl.quantize"(%arg0){qtype = tensor<1x2x!quant.uniform<u8:f32, 1.0>>}: (tensor<1x2xf32>) -> (tensor<1x2x!quant.uniform<u8:f32, 1.0>>)
+  %1 = "tfl.dequantize"(%0): (tensor<1x2x!quant.uniform<u8:f32, 1.0>>) -> (tensor<1x2xf32>)
+  %2 = "tf.Reshape"(%1, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
+  return %2 : tensor<2x1xf32>
+
+// CHECK: %cst = constant
+// CHECK: %0 = "tf.Reshape"
+// CHECK-SAME: -> tensor<2x1xf32>
+// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK-SAME: -> tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: %2 = "tfl.dequantize"(%1)
+// CHECK-SAME: -> tensor<2x1xf32>
+// CHECK: return %2
+}
+
+// CHECK-LABEL: fakeQuantWithConv2D
+func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
+// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
+// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK: %1 = "tfl.dequantize"(%0)
+// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %cst)
+// CHECK: return %2
+}
+
+// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
+func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %cst = constant dense<0.000000e+00> : tensor<48xf32>
+// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
+// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
+// CHECK: %1 = "tfl.dequantize"(%0)
+// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %cst)
+// CHECK: return %2
 }
 
 func @identity(tensor<10xi32>) -> tensor<10xi32> {
@@ -195,3 +320,11 @@
   // CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
   // CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
 }
+
+func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
+  %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+  // Should be converted to Identity and then from Identity to value
+  // CHECK-LABEL: stop_gradient
+  // CHECK:  return %arg0 : tensor<3xi32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir
index b3b439b..24d3887 100644
--- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir
@@ -82,6 +82,23 @@
 // CHECK: return %2
 }
 
+// CHECK-LABEL: QuantizeFullyConnected
+func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
+^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
+  %cst = constant dense<-1.23697901> : tensor<32xf32>
+  %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
+  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
+  %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
+  %5 = "tfl.fully_connected"(%2, %4, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
+  %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
+  return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
+
+// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
+// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
+// CHECK: %2 = "tfl.fully_connected"(%arg0, %1, %0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}
+// CHECK: return %2
+}
+
 // CHECK-LABEL: QuantizeAveragePool2D
 func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32> {
 ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@@ -118,6 +135,18 @@
 // CHECK: return %1 : tensor<1x6x6x16xf32>
 }
 
+// CHECK-LABEL: QuantizeLogistic
+func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32> {
+^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
+  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
+  %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
+  return %1 : tensor<1x6x6x16xf32>
+
+// CHECK: %0 = "tfl.logistic"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
+// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>)
+// CHECK: return %1
+}
+
 // CHECK-LABEL: QuantizeAdd
 func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>) -> tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>> {
 ^bb0(%arg0: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, %arg1: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>):
@@ -167,4 +196,16 @@
 // CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16x!quant.uniform<u8:f32, 7.812500e-03:128>>
 // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x1x1x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32>
 // CHECK: return %1 : tensor<1x1x1x16xf32>
+}
+
+// CHECK-LABEL: QuantizeSplit
+func @QuantizeSplit(%arg: tensor<4x!quant.uniform<u8:f32, 1.0>>, %cst: tensor<i32>) -> (tensor<2x!quant.uniform<u8:f32, 1.0>>,tensor<2x!quant.uniform<u8:f32, 1.0>>) {
+  %0 = "tfl.dequantize"(%arg) : (tensor<4x!quant.uniform<u8:f32, 1.0>>) -> tensor<4xf32>
+  %1:2 = "tfl.split"(%cst, %0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  %2 = "tfl.quantize"(%1#0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.0>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.0>>
+  %3 = "tfl.quantize"(%1#1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.0>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.0>>
+  return %2, %3 : tensor<2x!quant.uniform<u8:f32, 1.0>>, tensor<2x!quant.uniform<u8:f32, 1.0>>
+
+// CHECK: %0:2 = "tfl.split"(%arg1, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.000000e+00>>)
+// CHECK: return %0#0, %0#1
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir
new file mode 100644
index 0000000..95844cc
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir
@@ -0,0 +1,21 @@
+// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s --dump-input-on-failure
+
+func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  return %0 : tensor<1x4xf32>
+}
+
+func @bar(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
+  %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+
+func @foobar(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
+  return %0 : tensor<1x4xf32>
+}
+
+// CHECK-DAG: func @main
+// CHECK-DAG: func @foobar
+// CHECK-NOT: func @foo
+// CHECK-NOT: func @bar
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 9656abb..7d3f260 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -136,7 +136,7 @@
   if (!module.ok()) return kTrFailure;
 
   std::string result;
-  auto status = tensorflow::ConvertTFControlFlowToTFLOrFlatbuffer(
+  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
       module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
       emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
       lower_tensor_list_ops, &result);
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index bc2f36b..150b4a9 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -31,6 +31,11 @@
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 
+namespace mlir {
+/// Create a pass to convert from the TFExecutor to the TF control dialect.
+FunctionPassBase *CreateTFExecutorToControlDialectConversion();
+}  // namespace mlir
+
 namespace tensorflow {
 
 using mlir::MLIRContext;
@@ -79,12 +84,12 @@
     return tensorflow::GraphdefToSplattedMlirTranslateFunction(
         input_filename, debug_info_file, input_arrays, input_dtypes,
         input_shapes, output_arrays, inference_type, min_values, max_values,
-        prune_unused_nodes, context);
+        prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, context);
   }
   return tensorflow::GraphdefToMlirTranslateFunction(
       input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes,
       output_arrays, inference_type, min_values, max_values, prune_unused_nodes,
-      context);
+      /*convert_legacy_fed_inputs=*/true, context);
 }
 
 bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
@@ -99,6 +104,7 @@
                                 bool emit_quant_adaptor_ops,
                                 bool lower_tensor_list_ops,
                                 mlir::PassManager *pass_manager) {
+  pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
   pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
 
   if (lower_tensor_list_ops) {
@@ -138,7 +144,7 @@
   }
 }
 
-Status ConvertTFControlFlowToTFLOrFlatbuffer(
+Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
     bool lower_tensor_list_ops, std::string *result) {
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h
index 68ab674..46b9bec 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h
@@ -63,12 +63,12 @@
                                 bool lower_tensor_list_ops,
                                 mlir::PassManager* pass_manager);
 
-// Taking a MLIR module in TF control flow dialect and a set of parameters,
+// Taking a MLIR module in TF executor dialect and a set of parameters,
 // applies a set of passes to convert the module to TF Lite dialect and
 // serializes the result to a string. Depending on an attribute in the module
 // main function, Quantization is applied. If `export_to_mlir` is true, the
 // result is exported in MLIR text format, otherwise exported in flat buffer.
-Status ConvertTFControlFlowToTFLOrFlatbuffer(
+Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
     bool lower_tensor_list_ops, std::string* result);
diff --git a/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc
index 9be4a0b..ca55716 100644
--- a/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc
+++ b/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc
@@ -38,7 +38,7 @@
   llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"};
   llvm::Regex fixed_uniform_trait_regex{
       "FixedResultUniformScale<([0-9]+).*(true|false)>"};
-  emitSourceFileHeader("TensorFlow Lite Ops Quant Spec Getters", os);
+  emitSourceFileHeader("Generated Ops Quant Spec Getters", os);
 
   // Retrieve all the definitions derived from TFL_Op and sort by record name.
   std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 19ea5aa..0cde5c1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -29,7 +29,6 @@
     "$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
     "].cast<IntegerAttr>().getInt())">;
 
-
 // Merge the two Attributes to a ArrayAttr;
 def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
 
@@ -135,10 +134,14 @@
           (TFL_ReverseSequenceOp $input, $seq_lengths,
            (convertIntAttrTo32Bit $seq_dim),
            (convertIntAttrTo32Bit $batch_dim))>;
+def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
 def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
+def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
+def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
 // TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
 // condition
 def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
+def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
 def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
 def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
 def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
@@ -147,6 +150,7 @@
 def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
 def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
 def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
+def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
 def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
 
 // The following two rules can both match an tf.Placeholder.input node with
@@ -251,6 +255,8 @@
 
 def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
 
+def : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>;
+
 def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>;
 
 def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>;
@@ -266,16 +272,26 @@
 
 def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
 
+def : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
+          (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
+
 def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
 
 def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
 
 def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
 
+def : Pat<(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
+          (TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>;
+
 def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>;
+def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners)>;
 
 def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>;
 
+def : Pat<(TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value, $validate_indices),
+          (TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value)>;
+
 def : Pat<
   (TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask),
   (TFL_StridedSliceOp $input, $begin, $end, $strides,
@@ -284,4 +300,7 @@
 
 def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>;
 
+def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
+def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
+
 def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index faf80f3..cd7fc34 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -35,7 +35,6 @@
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 
@@ -205,10 +204,10 @@
 
   // Add the generated patterns to the list.
   populateWithGenerated(ctx, &patterns);
-  RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
-                     ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
-                     ConvertTFUnpackOp>::build(patterns, ctx);
-  applyPatternsGreedily(func, std::move(patterns));
+  patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
+                  ConvertTFPackOp, ConvertTFSplitOp, ConvertTFSplitVOp,
+                  ConvertTFUnpackOp>(ctx);
+  applyPatternsGreedily(func, patterns);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index ad54a36..6c0f06f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -82,7 +82,7 @@
 
   // Changes the function type of `cond_func` and `body_func`, and the result
   // type of the `WhileOp`.
-  LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
+  LogicalResult UpdateWhileFunctionType(TF::WhileOp op);
 };
 
 Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
@@ -100,10 +100,10 @@
       shape_tensor, scalar_val);
 }
 
-struct ConvertTFTensorListSetItem : public RewritePattern {
+struct ConvertTFTensorListSetItem
+    : public OpRewritePattern<TF::TensorListSetItemOp> {
   explicit ConvertTFTensorListSetItem(MLIRContext *context)
-      : RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
-                       context) {}
+      : OpRewritePattern<TF::TensorListSetItemOp>(context, 1) {}
   // This function rewrites the original op into a series of slice and concat op
   // to produce the same result. It first slices the first `$index` rows. Then
   // expands the dimension of the `$item`, followed by another slice of the
@@ -116,23 +116,21 @@
   //        (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
   //        0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
   //        $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op,
                                      PatternRewriter &rewriter) const override {
-    TF::TensorListSetItemOp tf_op = cast<TF::TensorListSetItemOp>(op);
-
-    auto input = tf_op.input_handle();
+    auto input = op.input_handle();
     auto shape_dtype = rewriter.getIntegerType(32);
     auto input_rank = rewriter.create<TF::RankOp>(
-        op->getLoc(), rewriter.getTensorType({}, shape_dtype), input);
-    auto item = tf_op.item();
+        op.getLoc(), rewriter.getTensorType({}, shape_dtype), input);
+    auto item = op.item();
     auto item_rank = rewriter.create<TF::RankOp>(
-        op->getLoc(), rewriter.getTensorType({}, shape_dtype), item);
+        op.getLoc(), rewriter.getTensorType({}, shape_dtype), item);
 
     // Prepare the start position for the first slice op, which is [0, 0, ..,
     // 0].
     auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
     auto position_shape = rewriter.create<TF::ExpandDimsOp>(
-        op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
+        op.getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
         scalar_zero);
     // Fill all 0s into the first position tensor.
     auto first_start_position =
@@ -141,33 +139,33 @@
     // Prepare the start position for the second slice op, which is
     // [index + 1, 0, 0 .. 0].
     // Calculate the first dimension, which is index + 1.
-    auto index = tf_op.index();
+    auto index = op.index();
     auto vector_type = rewriter.getTensorType({1}, shape_dtype);
     auto begin = rewriter.create<TF::AddOp>(
-        op->getLoc(), rewriter.getTensorType(shape_dtype), index,
+        op.getLoc(), rewriter.getTensorType(shape_dtype), index,
         CreateI32SplatConst(op, &rewriter, {1}, 1));
 
     // Followed by the first dimension `begin`, are `item_rank` of 0s.
     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
-        op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
+        op.getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
         scalar_zero);
     auto partial_second_start_position =
         CreateI32SplatTensor(op, &rewriter, item_position_shape, 0);
     auto position_type = first_start_position->getType();
     // Concatenate `begin` with the remaining 0s.
     auto second_start_position = rewriter.create<TF::ConcatOp>(
-        op->getLoc(), position_type, scalar_zero,
+        op.getLoc(), position_type, scalar_zero,
         ArrayRef<Value *>({begin, partial_second_start_position}),
         rewriter.getI64IntegerAttr(2));
 
     // Create the size parameter for the first slice op, which is [index, -1,
     // -1, .., -1].
     auto size1_leading_dim = rewriter.create<TF::ExpandDimsOp>(
-        op->getLoc(), vector_type, index, scalar_zero);
+        op.getLoc(), vector_type, index, scalar_zero);
     auto partial_size1 =
         CreateI32SplatTensor(op, &rewriter, item_position_shape, -1);
     auto size1 = rewriter.create<TF::ConcatOp>(
-        op->getLoc(), position_type, scalar_zero,
+        op.getLoc(), position_type, scalar_zero,
         ArrayRef<Value *>({size1_leading_dim, partial_size1}),
         rewriter.getI64IntegerAttr(2));
 
@@ -179,14 +177,14 @@
     auto element_type = input->getType().cast<TensorType>().getElementType();
     auto unranked_tensor = rewriter.getTensorType(element_type);
     auto slice1 = rewriter.create<TF::SliceOp>(
-        op->getLoc(), unranked_tensor, input, first_start_position, size1);
+        op.getLoc(), unranked_tensor, input, first_start_position, size1);
     auto slice2 = rewriter.create<TF::SliceOp>(
-        op->getLoc(), unranked_tensor, input, second_start_position, size2);
+        op.getLoc(), unranked_tensor, input, second_start_position, size2);
 
     // Expand the dimension of item so that it will have the same rank with
     // input.
     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
-        op->getLoc(), unranked_tensor, item, scalar_zero);
+        op.getLoc(), unranked_tensor, item, scalar_zero);
 
     // Concatenate three parts together to generate the final result.
     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
@@ -198,52 +196,54 @@
   }
 };
 
-struct ConvertTFTensorListReserve : public RewritePattern {
-  explicit ConvertTFTensorListReserve(MLIRContext *context)
-      : RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
-                       context) {}
+// Rewrites op of the template type initializing a TensorList with a list of ops
+// to generate an equivalent raw tensor. Derived classes are required to
+// override GetNumElements method.
+template <typename OpT>
+struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
+  explicit ConvertTFTensorListInitOp(MLIRContext *context)
+      : OpRewritePattern<OpT>(context, 1) {}
+
+  // Create and return a 1-d tensor with exactly one element equal to the number
+  // of list elements to initialize the output tensor list with.
+  virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0;
 
   // Rewrites the original op into `tf.fill`. The result tensor shape is
   // [num_element, element_shape]. All the values in the result tensor will be
   // initialized to 0.
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(OpT op,
                                      PatternRewriter &rewriter) const override {
-    TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);
-
-    auto element_shape = tf_op.element_shape();
+    auto element_shape = op.element_shape();
     auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
-    auto num_elements = tf_op.num_elements();
-    Type element_dtype = tf_op.element_dtype();
+    Type element_dtype = op.element_dtype();
 
     int64_t result_rank = -1;  // -1 means unknown result rank.
     Type result_type = rewriter.getTensorType(element_dtype);
-    if (auto element_type = tf_op.element_type().dyn_cast<RankedTensorType>()) {
+    if (auto element_type =
+            op.element_type().template dyn_cast<RankedTensorType>()) {
       result_rank = element_type.getRank() + 1;
       // If element type is ranked, then result type will have unknown leading
       // dimension and element shape for the following dimensions.
       //
-      // Note: leading dim is not inferred here even if num_elements input is a
-      // constant.
+      // Note: leading dim is not inferred here even when it is a constant.
       SmallVector<int64_t, 4> result_shape = {-1};
       ArrayRef<int64_t> shape = element_type.getShape();
       result_shape.append(shape.begin(), shape.end());
       result_type = rewriter.getTensorType(result_shape, element_dtype);
     }
 
-    // The output shape of the result tensor should be [num_elements +
-    // element_shape].
-    auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
-    auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
-        op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
-        scalar_zero);
-
     // Create a 1-D RankedTensorType for result's shape. Number of elements in
     // it is equal to the rank of the result, if known. Otherwise, the number of
     // elements are unknown and represented with -1. In both cases, we can
     // specify dimension using rank of the result.
     Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
+
+    // Add number of elements as the prefix to the element shape to get shape of
+    // the output tensor.
+    auto leading_dim = GetNumElements(op, &rewriter);
+    auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
     auto list_shape = rewriter.create<TF::ConcatOp>(
-        op->getLoc(), shape_type, scalar_zero,
+        op.getLoc(), shape_type, scalar_zero,
         ArrayRef<Value *>({leading_dim, element_shape}),
         rewriter.getI64IntegerAttr(2));
 
@@ -251,9 +251,92 @@
     // as specified by element_dtype.
     auto zero_type = rewriter.getTensorType({}, element_dtype);
     auto zero_attr = rewriter.getZeroAttr(zero_type);
-    auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);
+    auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr);
 
     rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
+    return Pattern::matchSuccess();
+  }
+};
+
+struct ConvertTFTensorListReserve
+    : public ConvertTFTensorListInitOp<TF::TensorListReserveOp> {
+  explicit ConvertTFTensorListReserve(MLIRContext *context)
+      : ConvertTFTensorListInitOp(context) {}
+
+  Value *GetNumElements(TF::TensorListReserveOp op,
+                        PatternRewriter *rewriter) const override {
+    auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0);
+    auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
+    return rewriter->create<TF::ExpandDimsOp>(
+        op.getLoc(), rewriter->getTensorType({1}, shape_dtype),
+        op.num_elements(), scalar_zero);
+  }
+};
+
+// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
+// supports patterns involving variadic operand ops.
+//
+// Note that we ignore the second operand `max_num_elements` as we don't have
+// any restrictions on the number of elements we can support. So this may
+// have a different behavior compared to TensorFlow in case of errors.
+struct ConvertTFEmptyTensorList
+    : public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> {
+  explicit ConvertTFEmptyTensorList(MLIRContext *context)
+      : ConvertTFTensorListInitOp(context) {}
+
+  Value *GetNumElements(TF::EmptyTensorListOp op,
+                        PatternRewriter *rewriter) const override {
+    return CreateI32SplatConst(op, rewriter, {1}, 0);
+  }
+};
+
+struct ConvertTFTensorListPushBack : public RewritePattern {
+  explicit ConvertTFTensorListPushBack(MLIRContext *context)
+      : RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1,
+                       context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
+    Value *item = push_back_op.tensor();
+    Type dtype = getElementTypeOrSelf(*item);
+
+    // Returns a new type by prepending the specified dimension to the shape of
+    // the given type if it is a ranked type.
+    auto with_leading_dim = [&](int64_t dim, Type type) -> Type {
+      if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
+        llvm::SmallVector<int64_t, 4> shape = {dim};
+        shape.append(ty.getShape().begin(), ty.getShape().end());
+        return rewriter.getTensorType(shape, dtype);
+      }
+
+      return rewriter.getTensorType(dtype);
+    };
+
+    // Expand the shape of the item so that it will have rank same as the input
+    // tensor and it is compatible for the Concat Op.
+    Type expanded_item_type = with_leading_dim(1, item->getType());
+    auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
+    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
+        op->getLoc(), expanded_item_type, item, scalar_zero);
+
+    // If the variant type in the output handle has item shape available, use it
+    // to derive the output shape by setting unknown leading dimension.
+    // Otherwise, result type will be of unranked type.
+    Type handle_type = push_back_op.output_handle()->getType();
+    TF::VariantType handle_dtype =
+        getElementTypeOrSelf(handle_type).cast<TF::VariantType>();
+    Type result_type = rewriter.getTensorType(dtype);
+    if (!handle_dtype.getSubtypes().empty()) {
+      result_type = with_leading_dim(-1, handle_dtype.getSubtypes()[0]);
+    }
+
+    // Concatenate tensor stored in the input handle with the expanded item to
+    // get a tensor equivalent to the TensorList generated by this op.
+    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
+        op, result_type, scalar_zero,
+        ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}),
+        rewriter.getI64IntegerAttr(2));
     return matchSuccess();
   }
 };
@@ -267,17 +350,17 @@
 }  // namespace TFL
 
 LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
-    TF::WhileOp *while_op) {
+    TF::WhileOp op) {
   SmallVector<Type, 8> unranked_argument_types;
-  for (const auto &operand : while_op->getOperands()) {
+  for (const auto &operand : op.getOperands()) {
     unranked_argument_types.push_back(
         UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
   }
 
   auto *context = &getContext();
   auto module = getModule();
-  FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->cond());
-  FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->body());
+  FuncOp cond_func = module.lookupSymbol<FuncOp>(op.cond());
+  FuncOp body_func = module.lookupSymbol<FuncOp>(op.body());
 
   if (cond_func) {
     // Change `cond_func`'s argument types to `unranked_argument_types`.
@@ -313,9 +396,9 @@
     }
   }
 
-  for (int i = 0; i < while_op->getNumOperands(); ++i) {
-    auto operand = while_op->getOperand(i);
-    auto result = while_op->getResult(i);
+  for (int i = 0; i < op.getNumOperands(); ++i) {
+    auto operand = op.getOperand(i);
+    auto result = op.getResult(i);
     if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
       // If we notice the result type is a DT_VARIANT, we change the
       // corresponding result type to unranked tensor type.
@@ -357,7 +440,11 @@
         }
         auto c = ConvertTFTensorListReserve(context);
         rewriter->setInsertionPoint(op);
-        c.matchAndRewrite(op, *rewriter);
+        c.matchAndRewrite(tf_op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) {
+        auto c = ConvertTFEmptyTensorList(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(tf_op, *rewriter);
       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
         auto c = TFL::ConvertTFTensorListGetItem(context);
         rewriter->setInsertionPoint(op);
@@ -365,14 +452,18 @@
       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
         auto c = ConvertTFTensorListSetItem(context);
         rewriter->setInsertionPoint(op);
-        c.matchAndRewrite(op, *rewriter);
+        c.matchAndRewrite(tf_op, *rewriter);
       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
         auto c = TFL::ConvertTFTensorListStack(context);
         rewriter->setInsertionPoint(op);
         c.matchAndRewrite(op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) {
+        auto c = ConvertTFTensorListPushBack(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
       } else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
         if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
-        UpdateWhileFunctionType(&tf_op);
+        UpdateWhileFunctionType(tf_op);
       } else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
         if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
         tf_op.getResult()->setType(tf_op.getOperand()->getType());
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 8e3d969..1d7cece 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -21,14 +21,22 @@
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 
 namespace mlir {
 namespace TFL {
@@ -37,19 +45,28 @@
 // The actual Optimize Pass.
 namespace {
 
+using ::llvm::cast;
+using ::llvm::isa;
+
 // Optimize TFLite operations in functions.
 struct Optimize : public FunctionPass<Optimize> {
   void runOnFunction() override;
 };
 
+// Returns whether the given type `a` is broadcast-compatible with `b`.
+bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
+  return OpTrait::util::getBroadcastedType(a, b) != Type();
+}
+
 // Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
 // types.
 bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
-  return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
+  return IsBroadcastableElementsAttrAndType(a.getType(), b.getType());
 }
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
-// Fuse Add with FullyConnected.
+
+// Fuse Add with proceeding FullyConnected.
 // Note that this assumes that the bias in the fullyConnected
 // is always None.
 // TODO(b/136285429): Move to tablegen when variadic is supported
@@ -121,6 +138,109 @@
   }
 };
 
+// TODO(b/136285429): Move to tablegen when variadic is supported.
+struct FuseFullyConnectedAndRelu : public RewritePattern {
+  explicit FuseFullyConnectedAndRelu(MLIRContext *context)
+      : RewritePattern(TFL::ReluOp::getOperationName(), {"tfl.fully_connected"},
+                       4, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto relu_op = cast<ReluOp>(op);
+    Operation *input = relu_op.getOperand()->getDefiningOp();
+    if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
+    auto fully_connected_op = cast<FullyConnectedOp>(input);
+    if (fully_connected_op.fused_activation_function() != "NONE")
+      return matchFailure();
+
+    auto new_activation_func = rewriter.getStringAttr("RELU");
+    auto new_weights_format =
+        rewriter.getStringAttr(fully_connected_op.weights_format());
+    auto new_keep_num_dims =
+        rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
+    rewriter.replaceOpWithNewOp<FullyConnectedOp>(
+        relu_op, relu_op.getType(), fully_connected_op.input(),
+        fully_connected_op.filter(), fully_connected_op.bias(),
+        new_activation_func, new_weights_format, new_keep_num_dims);
+
+    return matchSuccess();
+  }
+};
+
+// Fuse Mul with proceeding FullyConnected.
+// TODO(b/136285429): Move to tablegen when variadic is supported
+struct FuseFullyConnectedAndMul : public RewritePattern {
+  explicit FuseFullyConnectedAndMul(MLIRContext *context)
+      : RewritePattern(TFL::MulOp::getOperationName(),
+                       {"tfl.fully_connected", "tfl.mul", "std.constant"}, 4,
+                       context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    // Mul.
+    auto mul_op = cast<MulOp>(op);
+    DenseElementsAttr cst;
+    Value *constant_val = mul_op.rhs();
+    if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
+
+    // Fully Connected.
+    auto fc_op =
+        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
+    if (!fc_op) return matchFailure();
+    Value *filter = fc_op.filter();
+    Value *bias = fc_op.bias();
+    ElementsAttr cst_tmp;
+    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
+    if (!bias->getType().isa<NoneType>() &&
+        !matchPattern(bias, m_Constant(&cst_tmp)))
+      return matchFailure();
+    if (fc_op.fused_activation_function().equals("None")) return matchFailure();
+
+    // Broadcast the constant operand of Mul if it isn't compatible to the
+    // filter input. We only support broadcasting the operand along the depth
+    // dimension, when the operand's depth is 1.
+    Value *new_const_val = constant_val;
+    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
+      auto original_shape = cst.getType().getShape();
+      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
+                                                     original_shape.end());
+      normalized_shape.push_back(1);
+      auto new_cst = cst.reshape(rewriter.getTensorType(
+          normalized_shape, cst.getType().getElementType()));
+      Type new_type = new_cst.getType();
+      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
+        return matchFailure();
+      }
+      auto new_op =
+          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+      new_const_val = new_op.getResult();
+    }
+
+    // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
+    // TF::MulOp is used to fold the constant.
+    // TODO(b/139192933): switch to the TFL constant folding
+    Location loc = fc_op.getLoc();
+    auto new_filter =
+        rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
+    // If bias isn't None, it needs to be multiplied as well.
+    if (!bias->getType().isa<NoneType>()) {
+      bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
+    }
+
+    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
+        mul_op, mul_op.getType(),
+        /*input=*/fc_op.input(),
+        /*filter=*/new_filter,
+        /*bias=*/bias,
+        /*fused_activation_function=*/
+        rewriter.getStringAttr(mul_op.fused_activation_function()),
+        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
+        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
+
+    return matchSuccess();
+  }
+};
+
 // StridedSlice can have complicated atributes like begin_axis_mask,
 // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
 // masks will complicate the strided_slice computation logic, we can simplify
@@ -204,13 +324,14 @@
 
 void Optimize::runOnFunction() {
   OwningRewritePatternList patterns;
+  auto *ctx = &getContext();
   auto func = getFunction();
+
   // Add the generated patterns to the list.
-  TFL::populateWithGenerated(&getContext(), &patterns);
-  patterns.push_back(
-      llvm::make_unique<FuseFullyConnectedAndAdd>(&getContext()));
-  patterns.push_back(llvm::make_unique<PadStridedSliceDims>(&getContext()));
-  applyPatternsGreedily(func, std::move(patterns));
+  TFL::populateWithGenerated(ctx, &patterns);
+  patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
+                  FuseFullyConnectedAndMul, PadStridedSliceDims>(ctx);
+  applyPatternsGreedily(func, patterns);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index 561c0de..8dccae6 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -16,6 +16,9 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+
 namespace mlir {
 class FunctionPassBase;
 class ModulePassBase;
@@ -46,6 +49,12 @@
 
 // Creates a instance of the TensorFlow Lite dialect PostQuantize pass.
 FunctionPassBase *CreatePostQuantizePass(bool emit_quant_adaptor_ops);
+
+// Creates an instance of the TensorFlow Lite dialect PruneUnexportedFunctions
+// pass.
+ModulePassBase *CreateTrimFunctionsPass(
+    llvm::ArrayRef<std::string> trim_funcs_whitelist);
+
 }  // namespace TFL
 
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 94c19d2..e39789a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -18,8 +18,8 @@
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
 
 //===----------------------------------------------------------------------===//
 // The post-quantize Pass.
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
index 62c3de8..6b5b754 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -18,7 +18,10 @@
 
 def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
 
-// Converts tf.FusedBatchNorm into a sequence of more primitive arithmetic
+def HasNoUse: Constraint<
+    CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
+
+// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
 // operations. Specifically, performs the following calculation:
 //
 //   (x - mean) * scale / sqrt(variance + epsilon) + offset
@@ -29,9 +32,9 @@
 // is then to compute
 //   (x * multiplier) + (offset - mean * multiplier).
 def : Pattern<
-    (TF_FusedBatchNormOp $x, $scale, $offset, $mean, $variance,
-                         F32Attr:$epsilon, $data_format,
-                         FalseBoolAttr:$is_training),
+    (TF_FusedBatchNormOp:$root
+        $x, $scale, $offset, $mean, $variance,
+        F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training),
     [(TF_AddOp
         (TF_MulOp
             $x,
@@ -41,11 +44,38 @@
                     (TF_AddOp $variance,
                               (TF_ConstOp $epsilon))))),
         (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
-     /*batch_mean=*/(verifyUnusedValue),
-     /*batch_variance=*/(verifyUnusedValue),
-     /*reserve_space_1=*/(verifyUnusedValue),
-     /*reserve_space_2=*/(verifyUnusedValue)
-    ]>;
+     // We already guaranteed that the last four results has no use so it does
+     // not matter what value we provide here for replacement.
+     /*batch_mean=*/(replaceWithValue $x),
+     /*batch_variance=*/(replaceWithValue $x),
+     /*reserve_space_1=*/(replaceWithValue $x),
+     /*reserve_space_2=*/(replaceWithValue $x)],
+    [(HasNoUse $root__1), (HasNoUse $root__2),
+     (HasNoUse $root__3), (HasNoUse $root__4)]>;
+
+def : Pattern<
+    (TF_FusedBatchNormV3Op:$root
+        $x, $scale, $offset, $mean, $variance,
+        F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training),
+    [(TF_AddOp
+        (TF_MulOp
+            $x,
+            (TF_MulOp:$multiplier
+                $scale,
+                (TF_RsqrtOp
+                    (TF_AddOp $variance,
+                              (TF_ConstOp $epsilon))))),
+        (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
+     // We already guaranteed that the last five results have no use so it does
+     // not matter what value we provide here for replacement.
+     /*batch_mean=*/(replaceWithValue $x),
+     /*batch_variance=*/(replaceWithValue $x),
+     /*reserve_space_1=*/(replaceWithValue $x),
+     /*reserve_space_2=*/(replaceWithValue $x),
+     /*reserve_space_3=*/(replaceWithValue $x)],
+    [(HasNoUse $root__1), (HasNoUse $root__2),
+     (HasNoUse $root__3), (HasNoUse $root__4),
+     (HasNoUse $root__5)]>;
 
 // TODO(jpienaar): Move to opbase something more general.
 def TFi32ElementsAttr : Attr<CPred<"$_self.isa<DenseIntElementsAttr>">,
@@ -75,6 +105,8 @@
              /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
            ConstBoolAttrFalse, $bt)>;
 
+def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;
+
 //===----------------------------------------------------------------------===//
 // Op removal patterns.
 //===----------------------------------------------------------------------===//
@@ -98,3 +130,22 @@
               $shape),
           (TF_FakeQuantWithMinMaxVarsOp (TF_ReshapeOp $input, $shape),
               $min, $max, $num_bits, $narrow_range)>;
+
+// Casts result type of $1 to a quantized type by using the quantization
+// parameters from the type in $0.
+def UpdateShape : NativeCodeCall<
+  "CastQuantizedTypeAttrFromExpressedType($_builder, $0, GetFirstResultType($1))">;
+
+// When the op is passing-through, the output types of the quantized ops need
+// to be updated as well. Since the quantize op manages its own type by the
+// "qtype" attribute, we should update the type shape in this attribute.
+def : Pat<(TF_TransposeOp:$op
+              (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)), $perm),
+          (TFL_DequantizeOp (TFL_QuantizeOp (TF_TransposeOp $input, $perm),
+                                            (UpdateShape $qtype, $op)))>;
+
+def : Pat<(TF_ReshapeOp:$op
+              (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)), $shape),
+          (TFL_DequantizeOp
+              (TFL_QuantizeOp (TF_ReshapeOp $input, $shape),
+              (UpdateShape $qtype, $op)))>;
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index c91cdb3..895ecbb 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -15,10 +15,12 @@
 
 // This transformation pass applies quantization propagation on TFLite dialect.
 
+#include "absl/memory/memory.h"
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
 
 //===----------------------------------------------------------------------===//
 // The prepare-quantize Pass.
@@ -27,6 +29,7 @@
 namespace TFL {
 
 namespace {
+
 // Applies prepare quantization on the model in TFL dialect. This pass runs
 // before the quantization pass and propagate the quantization parameters
 // across ops. This step is necessary for post-training quantization and also
@@ -47,8 +50,11 @@
   bool quantize_sign_;
 };
 
+#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
+
 void PrepareQuantizePass::runOnFunction() {
-  ApplyQuantizationParamsPropagation(getFunction(), quantize_sign_);
+  ApplyQuantizationParamsPropagation(getFunction(), quantize_sign_,
+                                     GetOpQuantSpec);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 6f2e9e6..b39373c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -48,9 +48,9 @@
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 
@@ -65,60 +65,97 @@
 // pass.
 namespace {
 
+// Returns the first result type of the given `op`.
+Type GetFirstResultType(Operation *op) { return *op->result_type_begin(); }
+// TODO(antiagainst): We need overload functions of the above to facilitate
+// changes brought by declarative rewrite rules. Remove this post variadic
+// operand support is improved.
+// NOLINTNEXTLINE
+Type GetFirstResultType(TF::TransposeOp op) { return op.getType(); }
+// NOLINTNEXTLINE
+Type GetFirstResultType(TF::ReshapeOp op) { return op.getType(); }
+// NOLINTNEXTLINE
+Type GetFirstResultType(Value *val) { return val->getType(); }
+
 // Prepare TF operations in functions for subsequent legalization.
 struct PrepareTFPass : public FunctionPass<PrepareTFPass> {
   void runOnFunction() override;
 };
 
 // TODO(fengliuai): move this rule to PreparePatterns.td
-// Inserts a "tfl.quantize" and "tfl.dequantize" op pair after the
+// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
 // "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
 // folding logic will use a "std.constant" op to replace the
 // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
 // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
-// convert the output type to the next op.
+// convert the output type to the next op. Here are the transformations:
+//
+// input   min cst       max cst          input   min cst       max cst
+//  \       |             |                \       |             |
+//   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
+//    \     |             |                  \     |             |
+//       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
+//                   |                                 |
+//                                                tf.quantize
+//                                                     |
+//                                                tf.dequantize
+//                                                     |
+// If the input is a constant, the result pattern will eventually converted to
+
+//            quant-emulated input
+//                   |
+//               tf.quantize
+//                   |
+//              tf.dequantize
+//                   |
 struct InsertTFLQuantOpsAfterTFFakeQuantOp : public RewritePattern {
   InsertTFLQuantOpsAfterTFFakeQuantOp(MLIRContext *context)
-      : RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 1,
+      : RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 3,
                        context) {}
-  struct MatchedState : public PatternState {
-    FloatAttr min;
-    FloatAttr max;
-    APInt num_bits;
-    bool narrow_range;
-  };
-
-  PatternMatchResult match(Operation *op) const override {
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
     auto tf_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
+    // We don't want to insert quantize/dequantize if the quantize op exists.
     auto res = tf_op.outputs();
     if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
       return matchFailure();
-    auto state = absl::make_unique<MatchedState>();
-    ElementsAttr min_value, max_value;
-    if (!matchPattern(tf_op.min(), m_Constant(&min_value)))
-      return matchFailure();
-    if (!matchPattern(tf_op.max(), m_Constant(&max_value)))
-      return matchFailure();
-    state->min = ExtractSingleElementAsFloat(min_value);
-    state->max = ExtractSingleElementAsFloat(max_value);
-    if (!state->min || !state->max) return matchFailure();
-    state->num_bits = tf_op.num_bits();
-    state->narrow_range = tf_op.narrow_range();
-    return matchSuccess(std::move(state));
-  }
 
-  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
-               PatternRewriter &rewriter) const override {
-    auto &s = *static_cast<MatchedState *>(state.get());
-    Location loc = op->getLoc();
-    Value *copied = OpBuilder(op).clone(*op)->getResult(0);
-    Type res_type = copied->getType();
-    Type storage_type = rewriter.getIntegerType(s.num_bits.getSExtValue());
-    TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, s.min, s.max,
-                                          storage_type, s.narrow_range);
-    Value *quantize_op =
-        rewriter.create<TFL::QuantizeOp>(loc, qtype.getValue(), copied, qtype);
-    rewriter.replaceOpWithNewOp<TFL::DequantizeOp>(op, res_type, quantize_op);
+    // Extract the min/max constant values from the operands. We also consider
+    // a special case that there are tf.Identity ops between the min/max
+    // constants and the tf.FakeQuantWithMinMaxVarsOp.
+    Value *min = tf_op.min(), *max = tf_op.max();
+    ElementsAttr min_value, max_value;
+    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
+      min = id1.input();
+    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
+      max = id2.input();
+    if (!matchPattern(min, m_Constant(&min_value))) return matchFailure();
+    if (!matchPattern(max, m_Constant(&max_value))) return matchFailure();
+    FloatAttr min_attr = ExtractSingleElementAsFloat(min_value);
+    FloatAttr max_attr = ExtractSingleElementAsFloat(max_value);
+    if (!min_attr || !max_attr) return matchFailure();
+
+    // Use the min/max from the operands and the num_bits and narrow_range
+    // attribute to create the quantization parameter for the new quantize op.
+    rewriter.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
+    Type num_bits = rewriter.getIntegerType(tf_op.num_bits().getSExtValue());
+    bool narrow_range = tf_op.narrow_range();
+    Type res_type = tf_op.getType();
+    TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_attr,
+                                          max_attr, num_bits, narrow_range);
+
+    // Finally, use the quantization parameter to create the quantize and
+    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
+    // and its users.
+    Value *value = tf_op.outputs();
+    auto quantize = rewriter.create<TFL::QuantizeOp>(
+        op->getLoc(), qtype.getValue(), value, qtype);
+    auto dequantize = rewriter.create<TFL::DequantizeOp>(op->getLoc(), res_type,
+                                                         quantize.output());
+    value->replaceAllUsesWith(dequantize);
+    quantize.getOperation()->replaceUsesOfWith(dequantize, value);
+
+    return matchSuccess();
   }
 };
 
@@ -352,19 +389,26 @@
 void PrepareTFPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
+  // This pattern was intented to uses TFL QDQs to preserve the quantization
+  // parameters from the TF Quant ops, thus this pattern should run with the
+  // first `applyPatternsGreedily` method, which would otherwise removes the
+  // TF FakeQuant ops by the constant folding.
+  patterns.insert<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext());
   TFL::populateWithGenerated(&getContext(), &patterns);
   // TODO(karimnosseir): Split to separate pass probably after
   // deciding on long term plan for this optimization.
   // This will allow optimizing any TF_Mul->TF_Conv in the graph
   // and any expanded from FusedBatchNorm. We need to do this
   // before converting TF_Conv to TFL_Conv
-  applyPatternsGreedily(func, std::move(patterns));
-  patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
-  patterns.push_back(
-      llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
-  patterns.push_back(
-      llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
-  applyPatternsGreedily(func, std::move(patterns));
+  applyPatternsGreedily(func, patterns);
+
+  // Load the generated pattern again, so new quantization pass-through
+  // will be applied.
+  patterns.clear();
+  TFL::populateWithGenerated(&getContext(), &patterns);
+  patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
+      &getContext());
+  applyPatternsGreedily(func, patterns);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
index 91bb26a..0959531 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
@@ -31,8 +31,8 @@
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 
 namespace mlir {
@@ -55,9 +55,9 @@
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
-  mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
-      mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
-  applyPatternsGreedily(func, std::move(patterns));
+  patterns.insert<mlir::TFL::GenericFullQuantizationPattern<
+      mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>(ctx);
+  applyPatternsGreedily(func, patterns);
 }
 }  // namespace
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
new file mode 100644
index 0000000..dbd7288
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc
@@ -0,0 +1,133 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <queue>
+#include <string>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+
+// The cmd line flag to specify the whitelist of functions. Rest are trimmed
+// after this pass is run.
+// NOLINTNEXTLINE
+static llvm::cl::list<std::string> trim_funcs_whitelist(
+    "tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"),
+    llvm::cl::desc("comma seprarated list of whitelisted functions. The first "
+                   "function specified will be used as main."),
+    llvm::cl::CommaSeparated);
+
+namespace mlir {
+namespace TFL {
+namespace {
+
+// The pass to trim functions before we legalize to TFL
+// dialect using the specified whitelist.
+class TrimFunctionsPass : public mlir::ModulePass<TrimFunctionsPass> {
+ public:
+  explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
+  explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
+      : trim_funcs_whitelist_(trim_funcs_whitelist) {}
+
+ private:
+  void runOnModule() override;
+  bool TrimModule();
+  void Verify();
+
+  llvm::ArrayRef<std::string> trim_funcs_whitelist_;
+};
+
+void TrimFunctionsPass::runOnModule() {
+  // trim the functions in the module using the trim_funcs_whitelist_
+  // by removing functions not in the whitelist.
+  if (TrimModule()) {
+    // verify the updated module is still valid, if not signal the
+    // pass as failed.
+    Verify();
+  }
+}
+
+bool TrimFunctionsPass::TrimModule() {
+  // if no trim_funcs_whitelist_ is specified, this pass is a no-op.
+  if (trim_funcs_whitelist_.empty()) return false;
+
+  llvm::SmallVector<FuncOp, 4> funcs_to_trim;
+  for (auto func : getModule().getOps<FuncOp>()) {
+    if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
+      // If no main is specified in the whitelist, use the 1st func
+      // in trim_funcs_whitelist as the main.
+      // TODO(ashwinm): Currently tflite flatbuffer export assumes there is
+      // always a main. This is strictly not required for TFlite. We need to
+      // remove that restriction once we have support to attribute the main
+      // tensorflow function in MLIR TF import using an entry_point attr.
+      if (!llvm::is_contained(trim_funcs_whitelist_, "main") &&
+          func.getName() == trim_funcs_whitelist_[0]) {
+        func.setName("main");
+      }
+    } else {
+      funcs_to_trim.push_back(func);
+    }
+  }
+
+  // remove all unexported functions from the module.
+  for (auto func : funcs_to_trim) {
+    func.erase();
+  }
+  return true;
+}
+
+// validate that all reachable functions from the remaining functions are
+// also in the whitelist.
+void TrimFunctionsPass::Verify() {
+  // TODO(ashwinm): Instead, we should make sure that references to all
+  // SymbolRefAttrs of all ops are present.
+  SymbolTable symbol_table = SymbolTable(getModule());
+  llvm::SetVector<FuncOp> reachable_funcs;
+  for (auto func : getModule().getOps<FuncOp>()) {
+    func.walk<CallOp>([&](CallOp op) {
+      if (!symbol_table.lookup<FuncOp>(op.getCallee())) {
+        getModule().emitError()
+            << func.getName() << " is not in the funcs whitelist";
+        return signalPassFailure();
+      }
+    });
+  }
+}
+
+}  // namespace
+
+// Creates an instance of the TensorFlow Lite dialect TrimFunctions
+/// pass.
+ModulePassBase *CreateTrimFunctionsPass(
+    llvm::ArrayRef<std::string> trim_funcs_whitelist) {
+  return new TrimFunctionsPass(trim_funcs_whitelist);
+}
+
+static PassRegistration<TrimFunctionsPass> pass(
+    "tfl-trim-funcs-tf",
+    "Trim functions to restrict them to a specified whitelist prior to "
+    "legalization to TensorFlow lite dialect");
+
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc b/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc
deleted file mode 100644
index 956c1f1..0000000
--- a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc
+++ /dev/null
@@ -1,730 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Function.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
-#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace mlir {
-namespace TFL {
-namespace {
-
-using QuantParams = quant::QuantizedType;
-using AccumulatorScaleFunc =
-    std::function<QuantParams(const std::vector<QuantParams> &)>;
-using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
-using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
-
-// Quantization specs of ops, driving the TF Lite quantization algorithm.
-struct OpQuantSpec {
-  // Whether the op has quantizable result. This flag is set to false if the op
-  // has "TFL::NoQuantizableResult" trait.
-  bool is_quantizable = true;
-
-  // Whether it requires same inputs and result scale. This flag is set to true
-  // if the op has "TFL::SameOperandsAndResultScale" trait.
-  bool requires_same_scale = false;
-
-  // Maps the operand index of a bias input to its quantization specifications,
-  // including the non-bias operand indexes and the method retrieving
-  // quantization parameters from list of parameters of the non-bias operands.
-  // This map is empty if the op doesn't havea bias operand.
-  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
-      biases_params;
-
-  // Quantization parameters for value restricted outputs. This is the
-  // "hard-coded" parameters and should be used unconditionally for the
-  // quantized op. This vector is empty if the op doesn't have value resctricted
-  // outputs.
-  llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;
-};
-
-static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
-
-// The state for each op result during the quantization parameters propagation.
-struct QuantState {
-  // Quantization parameters propagated to an op result.
-  QuantParams params;
-  // A flag indicates this state (the params) shouldn't be changed after it is
-  // initialized. This flag will be set to true if the quantization parameters
-  // are from the quantization-aware training.
-  const bool immutable;
-
-  bool IsEmpty() { return EmptyParams(params); }
-};
-
-// The state for rescaling the propagated quantization parameters. This can be
-// on the input side to satisfy the constraint of previous operation, or on the
-// output side to satisfy the constraint of the next operation.
-struct RequantizeState {
-  // Sometimes, we have to "requantize" the quantization result to satisfy all
-  // the constraints. The "requantize" can happen either on the input or output
-  // of the quantization result.
-  enum RequantizePosition {
-    NO_REQUANTIZE,
-    ON_INPUT,
-    ON_OUTPUT
-  } pos = NO_REQUANTIZE;
-
-  // Quantization parameters will be used to add the requantize ops.
-  QuantParams params;
-};
-
-// This is a worklist-driven driver for propagating quantization parameters
-// across operations.
-//
-// The initial quantization parameters are extracted from the quantized type
-// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
-// parameters are marked as immutable because they are from quantization-aware
-// training.
-//
-// The algorithm traverses each op and sets the quantization parameters of its
-// operands and results, according to its quantization specification, and then
-// adds the operands and results to the worklist. If there are any conflicts
-// (for example, there are quantization parameters propagated from the previous
-// iteration), this process stops if the existing parameters are the immutable,
-// or adding `requantize` op to resolve the conflicts.
-//
-// After the algorithm is converaged, pairs of tfl.quantize and tfl.dequantize
-// are inserted to the right position to materialize the propagation and
-// requantize results.
-//
-class QuantizationDriver {
- public:
-  explicit QuantizationDriver(FuncOp fn, bool is_signed)
-      : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed) {}
-
-  // The entry point of the quantization parameters propagation.
-  void Run();
-
- private:
-  // This is used to identify an operand or result of an op. The second element
-  // of this pair is the index of the operand or result.
-  using OpValue = std::pair<mlir::Operation *, int>;
-
-  // Sets up the states for all the op results in the function.
-  void Initialize();
-
-  // Propagates the quantization parameters across all the ops.
-  bool PropagateParams();
-
-  // Inserts the Quantize and Dequantize ops according to the propagation
-  // result.
-  void Finalize();
-
-  // Whether the constant is used as a bias input of another op. Here we assume
-  // bias is used immediately by the user. This assumption is always correct
-  // after constant folding.
-  bool UsedAsBias(ConstantOp cst) {
-    Value *value = cst.getResult();
-    for (auto &use : value->getUses()) {
-      auto biases = GetQuantSpec(use.getOwner())->biases_params;
-      if (biases.find(use.getOperandNumber()) != biases.end()) return true;
-    }
-    return false;
-  }
-
-  // Returns all the related quantization constraints of the op.
-  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
-
-  // Whether Quantization parameters have been propagated to the results of this
-  // op.
-  bool IsQuantized(Operation *op);
-
-  // Adds all the users of index-th result of op to the work list.
-  void AddUserToList(Operation *op, int index) {
-    for (auto *user : op->getResult(index)->getUsers()) {
-      work_list_.push_back(user);
-    }
-  }
-
-  // Adds the defining op of index-th operand of op to the work list.
-  void AddOperandToList(Operation *op, int index) {
-    if (auto *inst = op->getOperand(index)->getDefiningOp()) {
-      work_list_.push_back(inst);
-    }
-  }
-
-  // Returns the quantization params for the bias input from the non-bias
-  // operands which have their indexes in the `non_biases` vector. The returned
-  // parameters are calculated by `func`.
-  QuantParams GetBiasParams(Operation *op, int bias,
-                            const std::vector<int> &non_biases,
-                            AccumulatorScaleFunc func);
-
-  // Sets the quantization parameters of the result to a fixed value. If any
-  // quantization parameters have been propagated, a `requantize` will happen on
-  // the input of propagated quantization.
-  bool SetResultParams(Operation *op, int index, QuantParams params);
-
-  // Sets the quantization parameters of the operand to a fixed value. If any
-  // quantization parameters have been propagated, a `requantize` will happen on
-  // the output of propagated quantization.
-  bool SetOperandParams(Operation *op, int index, QuantParams params);
-
-  // Sets the quantization parameters of the constant result according to its
-  // content.
-  bool SetConstantResultParams(Operation *op);
-
-  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
-  // of the op.
-  void QuantizeOpResult(Operation *op, int index, QuantParams params);
-
-  void QuantizeArg(BlockArgument *arg, QuantParams params);
-
-  // Inserts the Quantize and Dequantize ops to quantize the value and returns
-  // the Quantize op.
-  void QuantizeValue(Value *value, QuantParams params, Location loc);
-
-  // Inserts the Quantize ops for requantizing the index-th result of the op.
-  void RequantizeOpResult(Operation *op, int index, RequantizeState *state);
-
-  void RequantizeArg(BlockArgument *arg, RequantizeState *state);
-
-  // Inserts the Quantize and Dequantize ops to quantize the value and returns
-  // the Quantize op.
-  void RequantizeValue(Value *value, RequantizeState *state, Location loc);
-
-  // A heuristic to get the quantization parameter satisfies the same scale
-  // constraints for the op. Returns an empty option if this quantization
-  // parameter doesn't exist.
-  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
-
-  // Returns the state of the index-th operand of the op.
-  QuantState &GetOperandQuantState(Operation *op, int index) {
-    return states_[operand_states_[{op, index}]];
-  }
-
-  // Returns the state of the index-th result of the op.
-  QuantState &GetResultQuantState(Operation *op, int index) {
-    return states_[result_states_[{op, index}]];
-  }
-
-  QuantState &GetArgQuantState(BlockArgument *arg) {
-    return states_[arg_states_[arg]];
-  }
-
-  // Returns the state of the index-th operand of the op.
-  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
-    return rescale_states_[operand_states_[{op, index}]];
-  }
-
-  // Returns the state of the index-th result of the op.
-  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
-    return rescale_states_[result_states_[{op, index}]];
-  }
-
-  RequantizeState &GetArgRequantizeState(BlockArgument *arg) {
-    return rescale_states_[arg_states_[arg]];
-  }
-
-  // Uses the type of `val` to set the initial state of the index-th result if
-  // `as_result` is true or index-th operand if `as_result` is false. The state
-  // is immutable if the type is a quantized type. Returns the index of this
-  // new state in the state vector.
-  int InitializeState(Operation *op, int index, Value *val, bool as_result);
-
-  // Sets the state of the index-th operand of the op. If this operand is
-  // cached, uses the cached result without creating new entry in the state
-  // vector. Otherwise, allocate a new entry in the state vector.
-  void InitializeOperandState(Operation *op, int index, Value *in,
-                              llvm::DenseMap<Value *, int> *cache,
-                              bool is_argument) {
-    auto cached = cache->insert({in, 0});
-    if (!cached.second) {
-      operand_states_.insert({{op, index}, cached.first->second});
-      return;
-    }
-    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
-    if (is_argument) {
-      auto *arg = llvm::cast<BlockArgument>(in);
-      arg_states_[arg] = cached.first->second;
-      args_.push_back(arg);
-    }
-  }
-
-  // Sets the state of the index-th result of the op. If this result is cached,
-  // uses the cached result without creating new entry in the state vector.
-  // Otherwise, allocate a new entry in the state vector.
-  void InitializeResultState(Operation *op, int index, Value *res,
-                             llvm::DenseMap<Value *, int> *cache) {
-    auto cached = cache->insert({res, 0});
-    if (!cached.second) {
-      result_states_.insert({{op, index}, cached.first->second});
-      return;
-    }
-    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
-  }
-
-  FuncOp fn_;
-  OpBuilder builder_;
-  bool is_signed_;
-
-  // All the ops needs to propagate the quantization parameters to.
-  std::vector<Operation *> work_list_;
-  std::unordered_set<Operation *> quantized_;
-
-  // The vector contains all the quantization parameters propagated from the
-  // defining operations of the value, or from the quantization aware training.
-  std::vector<QuantState> states_;
-
-  // The map contains all the quantization parameters which are required to
-  // satisfy the same operands and results constraint. The keys of this map are
-  // the values from `operand_states_` and `result_state_`.
-  std::unordered_map<int, RequantizeState> rescale_states_;
-
-  // Maps of indexes to the propagation state vector from the ops operands,
-  // results and arguments.
-  llvm::DenseMap<OpValue, int> operand_states_;
-  llvm::DenseMap<OpValue, int> result_states_;
-  llvm::DenseMap<BlockArgument *, int> arg_states_;
-
-  // This vector is to preserve the arguments order, so the newly inserted
-  // quantized ops for the arguments are deterministically ordered.
-  llvm::SmallVector<BlockArgument *, 4> args_;
-};
-
-#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
-}  // namespace
-
-// TODO(fengliuai): cache the quantization parameters.
-std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
-  return GetOpQuantSpec(op);
-}
-
-bool QuantizationDriver::IsQuantized(Operation *op) {
-  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
-    if (GetResultQuantState(op, i).IsEmpty()) return false;
-  }
-  return true;
-}
-
-int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
-                                        bool as_result) {
-  QuantParams params =
-      quant::QuantizedType::getQuantizedElementType(val->getType());
-  bool immutable = !EmptyParams(params);
-  int next_state_index = states_.size();
-  states_.push_back({params, immutable});
-  if (as_result)
-    result_states_.insert({{op, index}, next_state_index});
-  else
-    operand_states_.insert({{op, index}, next_state_index});
-
-  return next_state_index;
-}
-
-bool QuantizationDriver::SetConstantResultParams(Operation *op) {
-  ElementsAttr attr;
-  Value *res = op->getResult(0);
-  if (!matchPattern(res, m_Constant(&attr))) {
-    return false;
-  }
-  // TODO(fengliuai): the bit width should be determined by its user.
-  auto final_type =
-      GetUniformQuantizedTypeForElementsAttr(
-          attr, /*storage_type_width=*/8, is_signed_, /*narrow_range_=*/false)
-          .dyn_cast_or_null<quant::QuantizedType>();
-  if (!final_type) return false;
-  return SetResultParams(op, 0, final_type);
-}
-
-bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
-                                         QuantParams params) {
-  auto &state = GetResultQuantState(op, res_index);
-  if (state.params == params) {
-    return false;
-  }
-  if (!state.IsEmpty()) {
-    auto &rescale = GetResultRequantizeState(op, res_index);
-    rescale.params = params;
-    rescale.pos = RequantizeState::ON_INPUT;
-    return true;
-  }
-  state.params = params;
-  AddUserToList(op, res_index);
-  return true;
-}
-
-QuantParams QuantizationDriver::GetBiasParams(
-    Operation *op, int bias, const std::vector<int> &non_biases,
-    AccumulatorScaleFunc func) {
-  auto &bias_state = GetOperandQuantState(op, bias);
-  if (!bias_state.IsEmpty()) {
-    return bias_state.params;
-  }
-  std::vector<QuantParams> op_types;
-  op_types.reserve(non_biases.size());
-  for (auto non_bias : non_biases) {
-    auto &non_bias_type = GetOperandQuantState(op, non_bias);
-    op_types.push_back(non_bias_type.params);
-  }
-  if (op_types.empty()) return {};
-  return func(op_types);
-}
-
-bool QuantizationDriver::SetOperandParams(Operation *op, int index,
-                                          QuantParams params) {
-  auto &state = GetOperandQuantState(op, index);
-  if (state.params == params) {
-    return false;
-  }
-
-  if (!state.IsEmpty()) {
-    auto &rescale = GetOperandRequantizeState(op, index);
-    rescale.params = params;
-    rescale.pos = RequantizeState::ON_OUTPUT;
-    return true;
-  }
-
-  state.params = params;
-  AddOperandToList(op, index);
-  return true;
-}
-
-void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
-                                          QuantParams params) {
-  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
-  Value *original_result = op->getResult(index);
-  QuantizeValue(original_result, params, op->getLoc());
-}
-
-void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) {
-  builder_.setInsertionPointToStart(arg->getOwner());
-  QuantizeValue(arg, params, builder_.getUnknownLoc());
-}
-
-void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
-                                       Location loc) {
-  Type expressed_type = value->getType();
-  Type new_type = params.castFromExpressedType(expressed_type);
-  // This value isn't an expressed type (float), skip.
-  if (!new_type) return;
-
-  TypeAttr type_attr = builder_.getTypeAttr(new_type);
-  auto quantize =
-      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
-  auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
-                                                       quantize.output());
-  // `original_result` has a use to `quantize`, so this will replace that use
-  // by the result of `dequantize`. Remember to reset that use afterwards
-  value->replaceAllUsesWith(dequantize);
-  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
-}
-
-void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
-                                            RequantizeState *state) {
-  if (state->pos == RequantizeState::NO_REQUANTIZE) return;
-  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
-  Value *value = op->getResult(index);
-  if (state->pos == RequantizeState::ON_OUTPUT) {
-    Operation *op = value->getUses().begin().getUser();  // `quantize` op
-    // The requantize op is inserted between `quantize` and `dequantize` ops.
-    value = op->getResult(0);
-    builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
-  }
-  RequantizeValue(value, state, op->getLoc());
-}
-
-void QuantizationDriver::RequantizeArg(BlockArgument *arg,
-                                       RequantizeState *state) {
-  Value *value = arg;
-  builder_.setInsertionPointToStart(arg->getOwner());
-  if (value->hasOneUse()) {
-    auto user = value->use_begin().getUser();
-    if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
-      value = q.output();
-      builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
-    }
-  }
-  RequantizeValue(value, state, builder_.getUnknownLoc());
-}
-
-void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
-                                         Location loc) {
-  Type new_type;
-  if (state->pos == RequantizeState::ON_INPUT) {
-    Type expressed_type = value->getType();
-    // The value needs to be requantized. A Quantize op will be created to use
-    // it as the operand and replace its uses.
-    new_type = state->params.castFromExpressedType(expressed_type);
-  } else {
-    Type expressed_type =
-        quant::QuantizedType::castToExpressedType(value->getType());
-    if (!expressed_type) return;
-
-    // The value needs to be requantized. A Quantize op will be created to use
-    // it as the operand and replace its uses.
-    new_type = state->params.castFromExpressedType(expressed_type);
-  }
-  // This value isn't an expressed type (float), skip.
-  if (!new_type) return;
-
-  TypeAttr type_attr = builder_.getTypeAttr(new_type);
-  auto requantize_op =
-      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
-  value->replaceAllUsesWith(requantize_op);
-  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
-}
-
-// A heuristic to get quantization parameters satisfies the same scale
-// constraints:
-// - If there are immutable states,
-//   - use the single input, or,
-//   - use the single output, or,
-//   - use the first one in the collection,
-// - use the single input if it is ready, or,
-// - use the single output if it is ready, or,
-// - use use the first ready one in the collection.
-QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
-    Operation *op) {
-  // Two vector to collect Non-empty operands and results states.
-  std::vector<QuantState *> mutable_states, immutable_states;
-  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
-    auto &state = GetOperandQuantState(op, i);
-    if (state.immutable) {
-      immutable_states.push_back(&state);
-    } else if (!state.IsEmpty()) {
-      mutable_states.push_back(&state);
-    }
-  }
-
-  int immutable_operands_num = immutable_states.size();
-  int mutable_operands_num = mutable_states.size();
-  // Use the operand's state if it is immutable and it is the only one operand.
-  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
-    return immutable_states.front()->params;
-  }
-
-  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
-    auto &state = GetResultQuantState(op, i);
-    if (state.immutable) {
-      immutable_states.push_back(&state);
-    } else if (!state.IsEmpty()) {
-      mutable_states.push_back(&state);
-    }
-  }
-
-  int immutable_results_num = immutable_states.size() - immutable_operands_num;
-  int mutable_results_num = mutable_states.size() - mutable_operands_num;
-  // Use the result's state if it is immutable and it is the only one result.
-  if (op->getNumResults() == 1 && immutable_results_num == 1) {
-    return immutable_states.back()->params;
-  }
-
-  // Use the first immutable state to quantize the rest operands and results.
-  if (!immutable_states.empty()) return immutable_states.front()->params;
-
-  // If there are no immutable states, use the operand's state if it is the only
-  // one operand and has parameters propagated.
-  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
-    return mutable_states.front()->params;
-  }
-
-  // If there are no immutable states, use the result's state if it is the only
-  // one result and has parameters propagated.
-  if (op->getNumResults() == 1 && mutable_results_num == 1) {
-    return mutable_states.back()->params;
-  }
-
-  // Use the first propagated state to quantize the rest operands and results.
-  if (!mutable_states.empty()) return mutable_states.front()->params;
-
-  // None operands/results have parameters propagated, skip this node for now.
-  return {};
-}
-
-// This method scans the operations in the function to setup the initial
-// states for quantization parameter propagation.
-// TODO(fengliuai): This algorithm assumes there are only one pair of
-// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
-// check should be applied.
-void QuantizationDriver::Initialize() {
-  llvm::DenseMap<Value *, int> value_to_state;
-
-  fn_.walk([&](Operation *op) {
-    if (op->isKnownTerminator()) return;
-    if (!GetQuantSpec(op)->is_quantizable) return;
-    work_list_.push_back(op);
-
-    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
-      auto *operand = op->getOperand(i);
-      bool is_argument = true;
-      if (auto *inst = operand->getDefiningOp()) {
-        // If the operand comes from a tfl.dequantize op, we use the quantized
-        // input of this tfl.dequantize op to set the state.
-        if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
-          operand = dq.input();
-        }
-        is_argument = false;
-      }
-      InitializeOperandState(op, i, operand, &value_to_state, is_argument);
-    }
-
-    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
-      auto *result = op->getResult(res);
-      // If the result has been quantized, it should only be used by a
-      // tfl.quantize op. For this case, we uses the quantized result to create
-      // the state and mark it immutable.
-      if (result->hasOneUse()) {
-        auto user = result->use_begin().getUser();
-        if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
-          result = q.output();
-        }
-      }
-      InitializeResultState(op, res, result, &value_to_state);
-    }
-  });
-}
-
-bool QuantizationDriver::PropagateParams() {
-  // TODO(fengliuai): uses a typed indicator instead of a bool value.
-  bool changed = false;
-  while (!work_list_.empty()) {
-    Operation *op = work_list_.back();
-    work_list_.pop_back();
-
-    // This op has been quantized, so we should consider it again.
-    if (quantized_.find(op) != quantized_.end()) continue;
-    quantized_.insert(op);
-
-    auto spec = GetQuantSpec(op);
-
-    // If the op has no quantizable result, the quantization parameters will not
-    // be propagated to the results.
-    if (!spec->is_quantizable) continue;
-
-    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
-      // This constant is used as a bias in another op, then the quantization
-      // parameters are determined by that op.
-      if (UsedAsBias(cst) || IsQuantized(op)) continue;
-
-      // The quantization parameters are determined by the content of the
-      // constant.
-      changed |= SetConstantResultParams(op);
-      continue;
-    }
-
-    if (spec->requires_same_scale) {
-      auto params = GetQuantParamsForSameScaleConstraint(op);
-      // The quantization parameters haven't been propagated to any operands or
-      // results. Skip this node for now.
-      if (!params) {
-        quantized_.erase(op);
-        continue;
-      }
-
-      // Use the final state to set all the operands' parameters.
-      for (int i = 0, e = op->getNumOperands(); i != e; ++i)
-        changed |= SetOperandParams(op, i, params);
-
-      // Use the final state to set all the results' parameters.
-      for (int res = 0, e = op->getNumResults(); res != e; ++res)
-        changed |= SetResultParams(op, res, params);
-    }
-
-    // TODO(fengliuai): make the bit width configurable.
-    auto key = std::make_pair(8, is_signed_);
-    auto &restricted_outputs = spec->restricted_output_params[key];
-    for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
-      changed |= SetResultParams(op, i, restricted_outputs[i]);
-    }
-
-    for (auto &it : spec->biases_params) {
-      auto params =
-          GetBiasParams(op, it.first, it.second.first, it.second.second);
-      if (!params) {
-        quantized_.erase(op);
-        continue;
-      }
-      changed |= SetOperandParams(op, it.first, params);
-    }
-  }
-  return changed;
-}
-
-void QuantizationDriver::Finalize() {
-  for (auto *arg : args_) {
-    auto &state = GetArgQuantState(arg);
-    auto &requantize = GetArgRequantizeState(arg);
-    if (state.IsEmpty() ||
-        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
-      continue;
-    }
-
-    if (!state.immutable) {
-      QuantizeArg(arg, state.params);
-    }
-
-    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
-      RequantizeArg(arg, &requantize);
-    }
-  }
-
-  for (auto it : result_states_) {
-    Operation *op = it.first.first;
-    int res_index = it.first.second;
-    auto &state = GetResultQuantState(op, res_index);
-    auto &requantize = GetResultRequantizeState(op, res_index);
-    if (state.IsEmpty() ||
-        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
-      continue;
-    }
-
-    if (!state.immutable) {
-      QuantizeOpResult(op, res_index, state.params);
-    }
-
-    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
-      RequantizeOpResult(op, res_index, &requantize);
-    }
-  }
-}
-
-void QuantizationDriver::Run() {
-  Initialize();
-  if (PropagateParams()) {
-    Finalize();
-  }
-}
-
-void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed) {
-  QuantizationDriver(func, is_signed).Run();
-}
-
-}  // namespace TFL
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_utils.cc b/tensorflow/compiler/mlir/lite/utils/quantization_utils.cc
deleted file mode 100644
index da797db..0000000
--- a/tensorflow/compiler/mlir/lite/utils/quantization_utils.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
-
-#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:local_config_mlir
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
-#include "mlir/Dialect/QuantOps/QuantizeUtils.h"  // TF:local_config_mlir
-#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
-
-namespace mlir {
-namespace TFL {
-
-// Returns the quantized type for the
-// input_type/min/max/storag_type_width/narrow_range.
-static Type GetQuantizedType(Builder builder, Type input_type, double min,
-                             double max, int storage_type_width,
-                             bool narrow_range, bool is_signed) {
-  auto converter =
-      quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
-
-  quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
-      builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
-      converter.expressedType, is_signed);
-  return converter.convert(quantizedEleType);
-}
-
-TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
-                              FloatAttr max, Type storage_type,
-                              bool narrow_range, bool is_signed) {
-  int storage_type_width = storage_type.cast<IntegerType>().getWidth();
-  Type final_type = GetQuantizedType(
-      builder, input_type, min.getValueAsDouble(), max.getValueAsDouble(),
-      storage_type_width, narrow_range, is_signed);
-  return builder.getTypeAttr(final_type);
-}
-
-TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
-                              Attribute max, IntegerAttr num_bits,
-                              BoolAttr narrow_range) {
-  FloatAttr min_value = GetSingleElementAsFloatOrSelf(min);
-  FloatAttr max_value = GetSingleElementAsFloatOrSelf(max);
-  if (!min_value || !max_value) return {};
-  return GetQuantizedTypeAttr(builder, input_type, min_value, max_value,
-                              builder.getIntegerType(num_bits.getInt()),
-                              narrow_range.getValue(), /*is_signed=*/false);
-}
-
-Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
-                                            unsigned storage_type_width,
-                                            bool is_signed, bool narrow_range) {
-  Builder builder(attr.getContext());
-  double min = std::numeric_limits<double>::max();
-  double max = std::numeric_limits<double>::min();
-  if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
-    for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
-      double ele_value = FloatAttr::getValueAsDouble(*it);
-      min = std::min(min, ele_value);
-      max = std::max(max, ele_value);
-    }
-    // The range must straddle zero.
-    if (min > 0.0 || max < 0.0) return {};
-    auto type = GetQuantizedType(builder, attr.getType(), min, max,
-                                 storage_type_width, narrow_range, is_signed);
-    if (auto ele_type = type.dyn_cast_or_null<TensorType>())
-      return ele_type.getElementType();
-  }
-
-  // The range from SplatElementAttr and other element attribute types  couldn't
-  // straddle zero, so the quantization parameters couldn't be derived from its
-  // range.
-  return {};
-}
-
-quant::QuantizedType GetUniformQuantizedTypeForBias(
-    const std::vector<quant::QuantizedType>& op_types) {
-  if (op_types.empty()) return {};
-
-  double scale = 1.0;
-  for (unsigned i = 0, e = op_types.size(); i != e; ++i) {
-    auto qtype = op_types[i].dyn_cast_or_null<quant::UniformQuantizedType>();
-    if (!qtype) return {};
-    scale *= qtype.getScale();
-  }
-  auto type = op_types.back().cast<quant::UniformQuantizedType>();
-  Builder builder(type.getContext());
-  // TODO(fengliuai): make the bit width configurable.
-  IntegerType storageType = builder.getIntegerType(32);
-  return quant::UniformQuantizedType::getChecked(
-      /*flags=*/true, storageType, type.getExpressedType(), scale,
-      /*zeroPoint=*/0,
-      quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32),
-      quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32),
-      builder.getUnknownLoc());
-}
-
-ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
-  if (auto q_type =
-          quant::QuantizedType::getQuantizedElementType(tensor_type)) {
-    Type converted_type;
-    return quant::quantizeAttr(real_value, q_type, converted_type)
-        .cast<ElementsAttr>();
-  }
-  return {};
-}
-
-}  // namespace TFL
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h b/tensorflow/compiler/mlir/lite/utils/quantization_utils.h
deleted file mode 100644
index d2c5808..0000000
--- a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This header file defines common utils used by TFLite transformation
-// passes to work with op attributes.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
-
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
-
-namespace mlir {
-namespace TFL {
-
-// A generic rewrite pattern which matches any N-in-1-out operations with
-// quantization parameters propagated to all the operands and results values.
-// The quantization parameters are annotated by the Q/DQ op pairs. Each matched
-// pattern are rewritten by its quantized alternatives.
-//
-// This pattern assumes all the matched ops are quantizable. This assumption is
-// always right, except when a "Q" op is used as a requantize op. For non-"Q"
-// ops, quantization parameters should be propagated to their result.
-//
-// This pattern only matches ops which only have one result.
-template <typename Q, typename DQ>
-struct GenericFullQuantizationPattern : public RewritePattern {
-  explicit GenericFullQuantizationPattern(MLIRContext* context)
-      : RewritePattern(Q::getOperationName(), 1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation* op,
-                                     PatternRewriter& rewriter) const override {
-    if (op->getNumResults() != 1) {
-      return matchFailure();
-    }
-    auto quantize_op = cast<Q>(op);
-    auto quantized_op = quantize_op.input()->getDefiningOp();
-    // If it is a block argument, requantize op, or has more than one result, we
-    // shouldn't rewrite this op.
-    if (!quantized_op || llvm::isa<Q>(quantized_op) ||
-        llvm::isa<DQ>(quantized_op) || quantized_op->getNumResults() != 1) {
-      return matchFailure();
-    }
-
-    // Collect all the quantized inputs and "clone" the matched op by these
-    // inputs.
-    SmallVector<Value*, 4> inputs;
-    inputs.reserve(quantized_op->getNumOperands());
-    for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) {
-      auto* operand = quantized_op->getOperand(i);
-      auto operand_ele_type =
-          operand->getType().template cast<TensorType>().getElementType();
-      if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
-        inputs.push_back(op_inst.input());
-      } else if (operand_ele_type.template isa<IntegerType>()) {
-        // If the operand is an integer tensor, then it doesn't require the
-        // DQ op in the pattern.
-        inputs.push_back(operand);
-      } else {
-        return matchFailure();
-      }
-    }
-    // Use OpBuilder so we can use op name to create the new op.
-    OpBuilder builder(quantized_op);
-    OperationState new_state(
-        quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs,
-        op->getResult(0)->getType(), quantized_op->getAttrs());
-    Operation* new_op = builder.createOperation(new_state);
-    rewriter.replaceOp(op, {new_op->getResult(0)});
-    return matchSuccess();
-  }
-};
-
-// Converts the min/max/storage_type/narrow_range information to a
-// QuantizedType, and then returns the attribute containing the QuantizedType.
-TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
-                              FloatAttr max, Type storage_type,
-                              bool narrow_range = false,
-                              bool is_signed = false);
-
-// Converts the min/max/num_bits/narrow_range information to a
-// QuantizedType, and then returns the attribute containing the QuantizedType.
-// Note that this method assumes an unsigned quantization type, which is
-// implicitly defined by FakeQuant* ops in TensorFlow.
-TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
-                              Attribute max, IntegerAttr num_bits,
-                              BoolAttr narrow_range);
-
-// Quantizes the elements in the attribute `real_value` by the quantization
-// parameters in `tensor_type`. Returns empty Attribute if the
-// `tensor_type` is not a QuantizedType or the quantization fails.
-ElementsAttr Quantize(Attribute real_value, Type tensor_type);
-
-// Returns the quantized type for an element attribute. The quantization
-// parameters in this type is based on the min and max element of the attribute.
-// When the elements in the `attr` are not in floating-point, or the value range
-// isn't straddling zero, an empty type is returned.
-Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
-                                            unsigned storage_type_width,
-                                            bool is_sign, bool narrow_range);
-
-// Returns the quantized type of a bias input, given the quantized types of
-// other operands which are multiply-accumulated (the bias is added to the
-// accumulated value).
-quant::QuantizedType GetUniformQuantizedTypeForBias(
-    const std::vector<quant::QuantizedType>& op_types);
-
-// Propagates quantization parameters across ops in this function and satisfy
-// the quantization specification of the ops. This methods assumes the initial
-// quantization parameters are stored as adjacent quantize and dequantize ops
-// and the propagation results are materialized by inserting pairs of quantize
-// and dequantize ops to this function.
-void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed);
-
-}  // end namespace TFL
-}  // end namespace mlir
-
-#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index d096831..f75e752 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -99,6 +99,8 @@
         "ir/tf_ops.cc",
         "ir/tf_ops.cc.inc",
         "ir/tf_ops.h.inc",
+        "transforms/cluster_outlining.cc",
+        "transforms/executor_island_coarsening.cc",
         "transforms/functional_control_flow_to_cfg.cc",
         "transforms/generated_canonicalize.inc",
         "transforms/generated_optimize.inc",
@@ -111,6 +113,7 @@
         "ir/control_flow_ops.h",
         "ir/tf_executor.h",
         "ir/tf_ops.h",
+        "ir/tf_traits.h",
         "ir/tf_types.def",
         "ir/tf_types.h",
         "transforms/passes.h",
@@ -152,11 +155,11 @@
     name = "convert_graphdef",
     srcs = [
         "translate/export_graphdef.cc",
-        "translate/import_graphdef.cc",
+        "translate/import_model.cc",
     ],
     hdrs = [
         "translate/export_graphdef.h",
-        "translate/import_graphdef.h",
+        "translate/import_model.h",
     ],
     deps = [
         ":convert_tensor",
@@ -173,6 +176,7 @@
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_proto_cc",
+        "//tensorflow/core/platform:types",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -182,6 +186,7 @@
         "@com_google_absl//absl/types:optional",
         "@llvm//:support",
         "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
         "@local_config_mlir//:StandardDialectRegistration",
         "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
@@ -285,6 +290,7 @@
         "@local_config_mlir//:IR",
         "@local_config_mlir//:StandardOps",
     ],
+    alwayslink = 1,
 )
 
 cc_library(
@@ -380,7 +386,6 @@
         ":convert_tensor",
         ":eval_util",
         ":tensorflow",
-        ":tf_graph_optimization_pass",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/core:framework",
@@ -410,6 +415,7 @@
     deps = [
         ":convert_graphdef",
         ":mlir_roundtrip_flags",
+        ":mlir_roundtrip_pass",
         "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md b/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md
deleted file mode 100755
index cedeba5..0000000
--- a/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md
+++ /dev/null
@@ -1,2761 +0,0 @@
-<!-- Autogenerated by mlir-tblgen; don't manually edit -->
-# Operation definition
-## tf.Abs (TF::AbsOp)
-Computes the absolute value of a tensor.
-
-### Description:
-
-Given a tensor `x`, this operation returns a tensor containing the absolute
-value of each element in `x`. For example, if x is an input element and y is
-an output element, this operation computes \\(y = |x|\\).
-
-### Operands:
-1. `x`: tensor of floating-point or 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 32/64-bit integer values
-
-## tf.AddN (TF::AddNOp)
-Add all input tensors element wise.
-
-### Description:
-
-
-### Operands:
-1. `inputs`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `sum`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values
-
-## tf.Add (TF::AddOp)
-Returns x + y element-wise.
-
-### Description:
-
-*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number or TensorFlow string type values
-1. `y`: tensor of number or TensorFlow string type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number or TensorFlow string type values
-
-## tf.AddV2 (TF::AddV2Op)
-Returns x + y element-wise.
-
-### Description:
-
-*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.AvgPool (TF::AvgPoolOp)
-Performs average pooling on the input.
-
-### Description:
-
-Each entry in `output` is the mean of the corresponding size `ksize`
-window in `value`.
-
-### Operands:
-1. `value`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute |
-| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute |
-| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.BatchToSpaceND (TF::BatchToSpaceNDOp)
-BatchToSpace for N-D tensors of type T.
-
-### Description:
-
-This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape
-`block_shape + [batch]`, interleaves these blocks back into the grid defined by
-the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as
-the input.  The spatial dimensions of this intermediate result are then
-optionally cropped according to `crops` to produce the output.  This is the
-reverse of SpaceToBatch.  See below for a precise description.
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `block_shape`: tensor of 32/64-bit integer values
-1. `crops`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tcrops` | `Attribute` | derived attribute attribute |
-| `Tblock_shape` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.BiasAdd (TF::BiasAddOp)
-Adds `bias` to `value`.
-
-### Description:
-
-This is a special case of `tf.add` where `bias` is restricted to be 1-D.
-Broadcasting is supported, so `value` may have any number of dimensions.
-
-### Operands:
-1. `value`: tensor of number values
-1. `bias`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.Bitcast (TF::BitcastOp)
-
-Bitcasts a tensor from one type to another without copying data.
-  
-
-### Description:
-
-Given a tensor `input`, this operation returns a tensor that has the same buffer
-data as `input` with datatype `type`.
-
-If the input datatype `T` is larger than the output datatype `type` then the
-shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)].
-
-If `T` is smaller than `type`, the operator requires that the rightmost
-dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from
-[..., sizeof(`type`)/sizeof(`T`)] to [...].
-
-tf.bitcast() and tf.cast() work differently when real dtype is casted as a complex dtype
-(e.g. tf.complex64 or tf.complex128) as tf.cast() make imaginary part 0 while tf.bitcast()
-gives module error.
-For example,
-
-Example 1:
-```python
->>> a = [1., 2., 3.]
->>> equality_bitcast = tf.bitcast(a,tf.complex128)
-tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast]
->>> equality_cast = tf.cast(a,tf.complex128)
->>> print(equality_cast)
-tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128)
-```
-Example 2:
-```python
->>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8)
-<tf.Tensor: ... shape=(4,), dtype=uint8, numpy=array([255, 255, 255, 255], dtype=uint8)>
-```
-Example 3:
-```python
->>> x = [1., 2., 3.]
->>> y = [0., 2., 3.]
->>> equality= tf.equal(x,y)
->>> equality_cast = tf.cast(equality,tf.float32)
->>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8)
->>> print(equality)
-tf.Tensor([False True True], shape=(3,), dtype=bool)
->>> print(equality_cast)
-tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32)
->>> print(equality_bitcast)
-tf.Tensor(
-[[ 0 0 0 0]
- [ 0 0 128 63]
- [ 0 0 128 63]], shape=(3, 4), dtype=uint8)
-```
-
-*NOTE*: Bitcast is implemented as a low-level cast, so machines with different
-endian orderings will give different results.
-
-### Operands:
-1. `input`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `type` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.BroadcastTo (TF::BroadcastToOp)
-Broadcast an array for a compatible shape.
-
-### Description:
-
-Broadcasting is the process of making arrays to have compatible shapes
-for arithmetic operations. Two shapes are compatible if for each
-dimension pair they are either equal or one of them is one. When trying
-to broadcast a Tensor to a shape, it starts with the trailing dimensions,
-and works its way forward.
-
-For example,
-
-```python
->>> x = tf.constant([1, 2, 3])
->>> y = tf.broadcast_to(x, [3, 3])
->>> sess.run(y)
-array([[1, 2, 3],
-       [1, 2, 3],
-       [1, 2, 3]], dtype=int32)
-```
-
-In the above example, the input Tensor with the shape of `[1, 3]`
-is broadcasted to output Tensor with shape of `[3, 3]`.
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `shape`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Cast (TF::CastOp)
-Cast x of type SrcT to y of DstT.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `Truncate` | `BoolAttr` | bool attribute attribute |
-| `SrcT` | `Attribute` | derived attribute attribute |
-| `DstT` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of tf.dtype values
-
-## tf.Ceil (TF::CeilOp)
-Returns element-wise smallest integer not less than x.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tf.Concat (TF::ConcatOp)
-Concatenates tensors along one dimension.
-
-### Description:
-
-
-### Operands:
-1. `concat_dim`: tensor of 32-bit integer values
-1. `values`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.ConcatV2 (TF::ConcatV2Op)
-Concatenates tensors along one dimension.
-
-### Description:
-
-
-### Operands:
-1. `values`: tensor of tf.dtype values
-1. `axis`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Conj (TF::ConjOp)
-Returns the complex conjugate of a complex number.
-
-### Description:
-
-Given a tensor `input` of complex numbers, this operation returns a tensor of
-complex numbers that are the complex conjugate of each element in `input`. The
-complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
-real part and *b* is the imaginary part.
-
-The complex conjugate returned by this operation is of the form \\(a - bj\\).
-
-For example:
-
-```
-# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
-```
-
-### Operands:
-1. `input`: tensor of complex128 type or complex64 type or TensorFlow variant type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of complex128 type or complex64 type or TensorFlow variant type values
-
-## tf.Const (TF::ConstOp)
-Constant tensor op
-
-### Description:
-
-
-### Operands:
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `value` | `ElementsAttr` | constant vector/tensor attribute attribute |
-| `dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Conv2D (TF::Conv2DOp)
-
-Computes a 2-D convolution given 4-D `input` and `filter` tensors.
-  
-
-### Description:
-
-Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
-and a filter / kernel tensor of shape
-`[filter_height, filter_width, in_channels, out_channels]`, this op
-performs the following:
-
-1. Flattens the filter to a 2-D matrix with shape
-   `[filter_height * filter_width * in_channels, output_channels]`.
-2. Extracts image patches from the input tensor to form a *virtual*
-   tensor of shape `[batch, out_height, out_width,
-   filter_height * filter_width * in_channels]`.
-3. For each patch, right-multiplies the filter matrix and the image patch
-   vector.
-
-In detail, with the default NHWC format,
-
-    output[b, i, j, k] =
-        sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
-                        filter[di, dj, q, k]
-
-Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
-horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
-
-### Operands:
-1. `input`: tensor of floating-point values
-1. `filter`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `use_cudnn_on_gpu` | `BoolAttr` | bool attribute attribute |
-| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID, or EXPLICIT attribute |
-| `explicit_paddings` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute |
-| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.Cos (TF::CosOp)
-Computes cos of x element-wise.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.DepthwiseConv2dNative (TF::DepthwiseConv2dNativeOp)
-
-Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.
-  
-
-### Description:
-
-Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
-and a filter / kernel tensor of shape
-`[filter_height, filter_width, in_channels, channel_multiplier]`, containing
-`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
-a different filter to each input channel (expanding from 1 channel to
-`channel_multiplier` channels for each), then concatenates the results
-together. Thus, the output has `in_channels * channel_multiplier` channels.
-
-```
-for k in 0..in_channels-1
-  for q in 0..channel_multiplier-1
-    output[b, i, j, k * channel_multiplier + q] =
-      sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
-                        filter[di, dj, k, q]
-```
-
-Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
-horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
-
-### Operands:
-1. `input`: tensor of floating-point values
-1. `filter`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute |
-| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute |
-| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.Div (TF::DivOp)
-Returns x / y element-wise.
-
-### Description:
-
-*NOTE*: `Div` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Elu (TF::EluOp)
-
-Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
-  
-
-### Description:
-
-See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
-](http://arxiv.org/abs/1511.07289)
-
-### Operands:
-1. `features`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `activations`: tensor of floating-point values
-
-## tf.Equal (TF::EqualOp)
-Returns the truth value of (x == y) element-wise.
-
-### Description:
-
-*NOTE*: `Equal` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-```python
-x = tf.constant([2, 4])
-y = tf.constant(2)
-tf.math.equal(x, y) ==> array([True, False])
-
-x = tf.constant([2, 4])
-y = tf.constant([2, 4])
-tf.math.equal(x, y) ==> array([True,  True])
-```
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.ExpandDims (TF::ExpandDimsOp)
-Inserts a dimension of 1 into a tensor's shape.
-
-### Description:
-
-Given a tensor `input`, this operation inserts a dimension of 1 at the
-dimension index `axis` of `input`'s shape. The dimension index `axis` starts at
-zero; if you specify a negative number for `axis` it is counted backward from
-the end.
-
-This operation is useful if you want to add a batch dimension to a single
-element. For example, if you have a single image of shape `[height, width,
-channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
-which will make the shape `[1, height, width, channels]`.
-
-Other examples:
-
-```
-# 't' is a tensor of shape [2]
-shape(expand_dims(t, 0)) ==> [1, 2]
-shape(expand_dims(t, 1)) ==> [2, 1]
-shape(expand_dims(t, -1)) ==> [2, 1]
-
-# 't2' is a tensor of shape [2, 3, 5]
-shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
-shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
-shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
-```
-
-This operation requires that:
-
-`-1-input.dims() <= dim <= input.dims()`
-
-This operation is related to `squeeze()`, which removes dimensions of
-size 1.
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `dim`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tdim` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.FakeQuantWithMinMaxArgs (TF::FakeQuantWithMinMaxArgsOp)
-
-Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
-  
-
-### Description:
-
-Attributes `[min; max]` define the clamping range for the `inputs` data.
-`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
-when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
-then de-quantized and output as floats in `[min; max]` interval.
-`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
-
-Before quantization, `min` and `max` values are adjusted with the following
-logic.
-It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
-the behavior can be unexpected:
-If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
-If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
-If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
-`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
-
-Quantization is called fake since the output is still in floating point.
-
-### Operands:
-1. `inputs`: tensor of 32-bit float values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `min` | `FloatAttr` | 32-bit float attribute attribute |
-| `max` | `FloatAttr` | 32-bit float attribute attribute |
-| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `narrow_range` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `outputs`: tensor of 32-bit float values
-
-## tf.FakeQuantWithMinMaxVars (TF::FakeQuantWithMinMaxVarsOp)
-
-Fake-quantize the 'inputs' tensor of type float via global float scalars `min`
-  
-
-### Description:
-
-and `max` to 'outputs' tensor of same shape as `inputs`.
-
-`[min; max]` define the clamping range for the `inputs` data.
-`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
-when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
-then de-quantized and output as floats in `[min; max]` interval.
-`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
-
-Before quantization, `min` and `max` values are adjusted with the following
-logic.
-It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
-the behavior can be unexpected:
-If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
-If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
-If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
-`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
-
-This operation has a gradient and thus allows for training `min` and `max`
-values.
-
-### Operands:
-1. `inputs`: tensor of 32-bit float values
-1. `min`: tensor of 32-bit float values
-1. `max`: tensor of 32-bit float values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `narrow_range` | `BoolAttr` | bool attribute attribute |
-
-### Results:
-1. `outputs`: tensor of 32-bit float values
-
-## tf.Fill (TF::FillOp)
-Creates a tensor filled with a scalar value.
-
-### Description:
-
-This operation creates a tensor of shape `dims` and fills it with `value`.
-
-For example:
-
-```
-# Output tensor has shape [2, 3].
-fill([2, 3], 9) ==> [[9, 9, 9]
-                     [9, 9, 9]]
-```
-
-`tf.fill` differs from `tf.constant` in a few ways:
-
-*   `tf.fill` only supports scalar contents, whereas `tf.constant` supports
-    Tensor values.
-*   `tf.fill` creates an Op in the computation graph that constructs the actual
-    Tensor value at runtime. This is in contrast to `tf.constant` which embeds
-    the entire Tensor into the graph with a `Const` node.
-*   Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
-    based on other runtime Tensors, unlike `tf.constant`.
-
-### Operands:
-1. `dims`: tensor of 32/64-bit integer values
-1. `value`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `index_type` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.FloorDiv (TF::FloorDivOp)
-Returns x // y element-wise.
-
-### Description:
-
-*NOTE*: `FloorDiv` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Floor (TF::FloorOp)
-Returns element-wise largest integer not greater than x.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point values
-
-## tf.FusedBatchNorm (TF::FusedBatchNormOp)
-Batch normalization.
-
-### Description:
-
-Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-The size of 1D Tensors matches the dimension C of the 4D Tensors.
-
-### Operands:
-1. `x`: tensor of 32-bit float values
-1. `scale`: tensor of 32-bit float values
-1. `offset`: tensor of 32-bit float values
-1. `mean`: tensor of 32-bit float values
-1. `variance`: tensor of 32-bit float values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `epsilon` | `FloatAttr` | 32-bit float attribute attribute |
-| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute |
-| `is_training` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of 32-bit float values
-1. `batch_mean`: tensor of 32-bit float values
-1. `batch_variance`: tensor of 32-bit float values
-1. `reserve_space_1`: tensor of 32-bit float values
-1. `reserve_space_2`: tensor of 32-bit float values
-
-## tf.Gather (TF::GatherOp)
-Gather slices from `params` according to `indices`.
-
-### Description:
-
-`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
-
-```python
-    # Scalar indices
-    output[:, ..., :] = params[indices, :, ... :]
-
-    # Vector indices
-    output[i, :, ..., :] = params[indices[i], :, ... :]
-
-    # Higher rank indices
-    output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
-```
-
-If `indices` is a permutation and `len(indices) == params.shape[0]` then
-this operation will permute `params` accordingly.
-
-`validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in
-`indices` are always validated to be within range. If assigned to GPU,
-out-of-bound indices result in safe but unspecified behavior, which may include
-raising an error.
-
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt>
-</div>
-
-### Operands:
-1. `params`: tensor of tf.dtype values
-1. `indices`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `validate_indices` | `BoolAttr` | bool attribute attribute |
-| `Tindices` | `Attribute` | derived attribute attribute |
-| `Tparams` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.GatherV2 (TF::GatherV2Op)
-
-Gather slices from `params` axis `axis` according to `indices`.
-  
-
-### Description:
-
-`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-Produces an output tensor with shape `params.shape[:axis] + indices.shape +
-params.shape[axis + 1:]` where:
-
-```python
-    # Scalar indices (output is rank(params) - 1).
-    output[a_0, ..., a_n, b_0, ..., b_n] =
-      params[a_0, ..., a_n, indices, b_0, ..., b_n]
-
-    # Vector indices (output is rank(params)).
-    output[a_0, ..., a_n, i, b_0, ..., b_n] =
-      params[a_0, ..., a_n, indices[i], b_0, ..., b_n]
-
-    # Higher rank indices (output is rank(params) + rank(indices) - 1).
-    output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
-      params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]
-```
-
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt>
-</div>
-
-Note that on CPU, if an out of bound index is found, an error is returned.
-On GPU, if an out of bound index is found, a 0 is stored in the
-corresponding output value.
-
-See also `tf.batch_gather` and `tf.gather_nd`.
-
-### Operands:
-1. `params`: tensor of tf.dtype values
-1. `indices`: tensor of 32/64-bit integer values
-1. `axis`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `batch_dims` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `Tindices` | `Attribute` | derived attribute attribute |
-| `Tparams` | `Attribute` | derived attribute attribute |
-| `Taxis` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.GreaterEqual (TF::GreaterEqualOp)
-Returns the truth value of (x >= y) element-wise.
-
-### Description:
-
-*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `y`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.Greater (TF::GreaterOp)
-Returns the truth value of (x > y) element-wise.
-
-### Description:
-
-*NOTE*: `Greater` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `y`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.IdentityN (TF::IdentityNOp)
-
-Returns a list of tensors with the same shapes and contents as the input
-  
-
-### Description:
-
-tensors.
-
-This op can be used to override the gradient for complicated functions. For
-example, suppose y = f(x) and we wish to apply a custom function g for backprop
-such that dx = g(dy). In Python,
-
-```python
-with tf.get_default_graph().gradient_override_map(
-    {'IdentityN': 'OverrideGradientWithG'}):
-  y, _ = identity_n([f(x), x])
-
-@tf.RegisterGradient('OverrideGradientWithG')
-def ApplyG(op, dy, _):
-  return [None, g(dy)]  # Do not backprop to f(x).
-```
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Identity (TF::IdentityOp)
-Identity op
-
-### Description:
-
-Returns a tensor with the same shape and contents as input.
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Invert (TF::InvertOp)
-
-Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010.
-  
-
-### Description:
-
-Flip each bit of supported types.  For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101.
-This operation is performed on each element of the tensor argument `x`.
-
-### Operands:
-1. `x`: tensor of 8/16/32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of 8/16/32/64-bit integer values
-
-## tf.LeakyRelu (TF::LeakyReluOp)
-Computes rectified linear: `max(features, features * alpha)`.
-
-### Description:
-
-
-### Operands:
-1. `features`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `activations`: tensor of floating-point values
-
-## tf.LessEqual (TF::LessEqualOp)
-Returns the truth value of (x <= y) element-wise.
-
-### Description:
-
-*NOTE*: `LessEqual` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `y`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.Less (TF::LessOp)
-Returns the truth value of (x < y) element-wise.
-
-### Description:
-
-*NOTE*: `Less` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `y`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.Log (TF::LogOp)
-Computes natural logarithm of x element-wise.
-
-### Description:
-
-I.e., \\(y = \log_e x\\).
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.LogSoftmax (TF::LogSoftmaxOp)
-Computes log softmax activations.
-
-### Description:
-
-For each batch `i` and class `j` we have
-
-    logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i])))
-
-### Operands:
-1. `logits`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `logsoftmax`: tensor of floating-point values
-
-## tf.LogicalAnd (TF::LogicalAndOp)
-Returns the truth value of x AND y element-wise.
-
-### Description:
-
-*NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 1-bit integer values
-1. `y`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.LogicalNot (TF::LogicalNotOp)
-Returns the truth value of NOT x element-wise.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `y`: tensor of 1-bit integer values
-
-## tf.LogicalOr (TF::LogicalOrOp)
-Returns the truth value of x OR y element-wise.
-
-### Description:
-
-*NOTE*: `LogicalOr` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 1-bit integer values
-1. `y`: tensor of 1-bit integer values
-
-### Attributes:
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.MatMul (TF::MatMulOp)
-
-Multiply the matrix "a" by the matrix "b".
-  
-
-### Description:
-
-The inputs must be two-dimensional matrices and the inner dimension of
-"a" (after being transposed if transpose_a is true) must match the
-outer dimension of "b" (after being transposed if transposed_b is
-true).
-
-*Note*: The default kernel implementation for MatMul on GPUs uses
-cublas.
-
-### Operands:
-1. `a`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-1. `b`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `transpose_a` | `BoolAttr` | bool attribute attribute |
-| `transpose_b` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `product`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-## tf.Max (TF::MaxOp)
-
-Computes the maximum of elements across dimensions of a tensor.
-  
-
-### Description:
-
-Reduces `input` along the dimensions given in `axis`. Unless
-`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-`axis`. If `keep_dims` is true, the reduced dimensions are
-retained with length 1.
-
-### Operands:
-1. `input`: tensor of number values
-1. `reduction_indices`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.MaxPool (TF::MaxPoolOp)
-Performs max pooling on the input.
-
-### Description:
-
-
-### Operands:
-1. `input`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute |
-| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute |
-| `data_format` | `StringAttr` | string attribute whose value is NHWC, or NCHW, or NCHW_VECT_C attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of 8/16/32/64-bit integer or floating-point values
-
-## tf.Maximum (TF::MaximumOp)
-Returns the max of x and y (i.e. x > y ? x : y) element-wise.
-
-### Description:
-
-*NOTE*: `Maximum` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of floating-point or 32/64-bit integer values
-1. `y`: tensor of floating-point or 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of floating-point or 32/64-bit integer values
-
-## tf.Mean (TF::MeanOp)
-Computes the mean of elements across dimensions of a tensor.
-
-### Description:
-
-Reduces `input` along the dimensions given in `axis`. Unless
-`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-`axis`. If `keep_dims` is true, the reduced dimensions are
-retained with length 1.
-
-### Operands:
-1. `input`: tensor of number values
-1. `reduction_indices`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.Min (TF::MinOp)
-
-Computes the minimum of elements across dimensions of a tensor.
-  
-
-### Description:
-
-Reduces `input` along the dimensions given in `axis`. Unless
-`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-`axis`. If `keep_dims` is true, the reduced dimensions are
-retained with length 1.
-
-### Operands:
-1. `input`: tensor of number values
-1. `reduction_indices`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.Minimum (TF::MinimumOp)
-Returns the min of x and y (i.e. x < y ? x : y) element-wise.
-
-### Description:
-
-*NOTE*: `Minimum` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of floating-point or 32/64-bit integer values
-1. `y`: tensor of floating-point or 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of floating-point or 32/64-bit integer values
-
-## tf.MulNoNan (TF::MulNoNanOp)
-
-Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or NaN.
-  
-
-### Description:
-
-*NOTE*: `MulNoNan` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-
-## tf.Mul (TF::MulOp)
-Returns x * y element-wise.
-
-### Description:
-
-*NOTE*: `Multiply` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Neg (TF::NegOp)
-Computes numerical negative value element-wise.
-
-### Description:
-
-I.e., \\(y = -x\\).
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-## tf.NoOp (TF::NoOp)
-Does nothing. Only useful as a placeholder for control edges.
-
-### Description:
-
-
-### Operands:
-
-### Attributes:
-
-### Results:
-
-## tf.NotEqual (TF::NotEqualOp)
-Returns the truth value of (x != y) element-wise.
-
-### Description:
-
-*NOTE*: `NotEqual` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 1-bit integer values
-
-## tf.Pack (TF::PackOp)
-
-Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
-  
-
-### Description:
-
-Packs the `N` tensors in `values` into a tensor with rank one higher than each
-tensor in `values`, by packing them along the `axis` dimension.
-Given a list of tensors of shape `(A, B, C)`;
-
-if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
-if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
-Etc.
-
-For example:
-
-```
-# 'x' is [1, 4]
-# 'y' is [2, 5]
-# 'z' is [3, 6]
-pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]]  # Pack along first dim.
-pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
-```
-
-This is the opposite of `unpack`.
-
-### Operands:
-1. `values`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute |
-| `axis` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Pad (TF::PadOp)
-Pads a tensor with zeros.
-
-### Description:
-
-This operation pads a `input` with zeros according to the `paddings` you
-specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
-rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-how many zeros to add before the contents of `input` in that dimension, and
-`paddings[D, 1]` indicates how many zeros to add after the contents of `input`
-in that dimension.
-
-The padded size of each dimension D of the output is:
-
-`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-
-For example:
-
-```
-# 't' is [[1, 1], [2, 2]]
-# 'paddings' is [[1, 1], [2, 2]]
-# rank of 't' is 2
-pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-                      [0, 0, 1, 1, 0, 0]
-                      [0, 0, 2, 2, 0, 0]
-                      [0, 0, 0, 0, 0, 0]]
-```
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `paddings`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tpaddings` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.PadV2 (TF::PadV2Op)
-Pads a tensor.
-
-### Description:
-
-This operation pads `input` according to the `paddings` and `constant_values`
-you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
-the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-how many padding values to add before the contents of `input` in that dimension,
-and `paddings[D, 1]` indicates how many padding values to add after the contents
-of `input` in that dimension. `constant_values` is a scalar tensor of the same
-type as `input` that indicates the value to use for padding `input`.
-
-The padded size of each dimension D of the output is:
-
-`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-
-For example:
-
-```
-# 't' is [[1, 1], [2, 2]]
-# 'paddings' is [[1, 1], [2, 2]]
-# 'constant_values' is 0
-# rank of 't' is 2
-pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-                      [0, 0, 1, 1, 0, 0]
-                      [0, 0, 2, 2, 0, 0]
-                      [0, 0, 0, 0, 0, 0]]
-```
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `paddings`: tensor of 32/64-bit integer values
-1. `constant_values`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tpaddings` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Placeholder.input (TF::PlaceholderInputOp)
-PlaceholderInput op
-
-### Description:
-
-Inserts a placeholder for a tensor that will be always fed.
-
-### Operands:
-1. `arg`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `min` | `FloatAttr` | 32-bit float attribute attribute |
-| `max` | `FloatAttr` | 32-bit float attribute attribute |
-| `type` | `TypeAttr` | integer type attribute |
-| `dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Placeholder (TF::PlaceholderOp)
-Placeholder op
-
-### Description:
-
-Inserts a placeholder for a tensor that will be always fed.
-
-### Operands:
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.QuantizeAndDequantize (TF::QuantizeAndDequantizeOp)
-Use QuantizeAndDequantizeV2 instead.
-
-### Description:
-
-
-### Operands:
-1. `input`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `signed_input` | `BoolAttr` | bool attribute attribute |
-| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `range_given` | `BoolAttr` | bool attribute attribute |
-| `input_min` | `FloatAttr` | 32-bit float attribute attribute |
-| `input_max` | `FloatAttr` | 32-bit float attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.QuantizeAndDequantizeV2 (TF::QuantizeAndDequantizeV2Op)
-Quantizes then dequantizes a tensor.
-
-### Description:
-
-This op simulates the precision loss from the quantized forward pass by:
-
-1. Quantizing the tensor to fixed point numbers, which should match the target
-   quantization method when it is used in inference.
-2. Dequantizing it back to floating point numbers for the following ops, most
-   likely matmul.
-
-There are different ways to quantize. This version uses only scaling, so 0.0
-maps to 0.
-
-From the specified 'num_bits' in the quantized output type, it determines
-minimum and maximum representable quantized values.
-
-e.g.
-
-*   [-128, 127] for signed, num_bits = 8, or
-*   [0, 255] for unsigned, num_bits = 8.
-
-If range_given == False, the initial input_min, input_max will be determined
-automatically as the minimum and maximum values in the input tensor, otherwise
-the specified values of input_min, input_max are used.
-
-Note: If the input_min, input_max are specified, they do not need to equal the
-actual minimum and maximum values in the tensor. e.g. in some cases it may be
-beneficial to specify these values such that the low probability extremes of the
-input distribution are clipped.
-
-This op determines the maximum scale_factor that would map the initial
-[input_min, input_max] range to a range that lies within the representable
-quantized range.
-
-It determines the scale from one of input_min and input_max, then updates the
-other one to maximize the respresentable range.
-
-e.g.
-
-*   if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0,
-    5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it
-    would update input_max to be 127 / 12.8 = 9.921875
-*   if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0,
-    10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it
-    would update input_min to be 128.0 / 12.7 = -10.07874
-*   if the output is unsigned, input_min is forced to be 0, and only the
-    specified input_max is used.
-
-After determining the scale_factor and updating the input range, it applies the
-following to each value in the 'input' tensor.
-
-output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor.
-
-The above round function rounds the value based on the given round_mode.
-
-### Operands:
-1. `input`: tensor of floating-point values
-1. `input_min`: tensor of floating-point values
-1. `input_max`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `signed_input` | `BoolAttr` | bool attribute attribute |
-| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `range_given` | `BoolAttr` | bool attribute attribute |
-| `round_mode` | `StringAttr` | string attribute whose value is HALF_TO_EVEN, or HALF_UP attribute |
-| `narrow_range` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.QuantizeAndDequantizeV3 (TF::QuantizeAndDequantizeV3Op)
-Quantizes then dequantizes a tensor.
-
-### Description:
-
-This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a
-tensor, so its value can change during training.
-
-### Operands:
-1. `input`: tensor of floating-point values
-1. `input_min`: tensor of floating-point values
-1. `input_max`: tensor of floating-point values
-1. `num_bits`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `signed_input` | `BoolAttr` | bool attribute attribute |
-| `range_given` | `BoolAttr` | bool attribute attribute |
-| `narrow_range` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.RandomUniform (TF::RandomUniformOp)
-Outputs random values from a uniform distribution.
-
-### Description:
-
-The generated values follow a uniform distribution in the range `[0, 1)`. The
-lower bound 0 is included in the range, while the upper bound 1 is excluded.
-
-### Operands:
-1. `shape`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `seed` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `seed2` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of floating-point values
-
-## tf.Range (TF::RangeOp)
-Creates a sequence of numbers.
-
-### Description:
-
-This operation creates a sequence of numbers that begins at `start` and
-extends by increments of `delta` up to but not including `limit`.
-
-For example:
-
-```
-# 'start' is 3
-# 'limit' is 18
-# 'delta' is 3
-tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
-```
-
-### Operands:
-1. `start`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values
-1. `limit`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values
-1. `delta`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values
-
-## tf.Rank (TF::RankOp)
-Returns the rank of a tensor.
-
-### Description:
-
-This operation returns an integer representing the rank of `input`.
-
-For example:
-
-```
-# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-# shape of tensor 't' is [2, 2, 3]
-rank(t) ==> 3
-```
-
-**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
-of a tensor is the number of indices required to uniquely select each element
-of the tensor. Rank is also known as "order", "degree", or "ndims."
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of 32-bit integer values
-
-## tf.RealDiv (TF::RealDivOp)
-Returns x / y element-wise for real types.
-
-### Description:
-
-If `x` and `y` are reals, this will return the floating-point division.
-
-*NOTE*: `Div` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Reciprocal (TF::ReciprocalOp)
-Computes the reciprocal of x element-wise.
-
-### Description:
-
-I.e., \\(y = 1 / x\\).
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-## tf.Relu6 (TF::Relu6Op)
-Computes rectified linear 6: `min(max(features, 0), 6)`.
-
-### Description:
-
-
-### Operands:
-1. `features`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values
-
-## tf.Relu (TF::ReluOp)
-Computes rectified linear: `max(features, 0)`.
-
-### Description:
-
-
-### Operands:
-1. `features`: tensor of 8/16/32/64-bit integer or floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values
-
-## tf.Reshape (TF::ReshapeOp)
-Reshapes a tensor.
-
-### Description:
-
-Given `tensor`, this operation returns a tensor that has the same values
-as `tensor` with shape `shape`.
-
-If one component of `shape` is the special value -1, the size of that dimension
-is computed so that the total size remains constant.  In particular, a `shape`
-of `[-1]` flattens into 1-D.  At most one component of `shape` can be -1.
-
-If `shape` is 1-D or higher, then the operation returns a tensor with shape
-`shape` filled with the values of `tensor`. In this case, the number of elements
-implied by `shape` must be the same as the number of elements in `tensor`.
-
-For example:
-
-```
-# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
-# tensor 't' has shape [9]
-reshape(t, [3, 3]) ==> [[1, 2, 3],
-                        [4, 5, 6],
-                        [7, 8, 9]]
-
-# tensor 't' is [[[1, 1], [2, 2]],
-#                [[3, 3], [4, 4]]]
-# tensor 't' has shape [2, 2, 2]
-reshape(t, [2, 4]) ==> [[1, 1, 2, 2],
-                        [3, 3, 4, 4]]
-
-# tensor 't' is [[[1, 1, 1],
-#                 [2, 2, 2]],
-#                [[3, 3, 3],
-#                 [4, 4, 4]],
-#                [[5, 5, 5],
-#                 [6, 6, 6]]]
-# tensor 't' has shape [3, 2, 3]
-# pass '[-1]' to flatten 't'
-reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
-
-# -1 can also be used to infer the shape
-
-# -1 is inferred to be 9:
-reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
-                         [4, 4, 4, 5, 5, 5, 6, 6, 6]]
-# -1 is inferred to be 2:
-reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
-                         [4, 4, 4, 5, 5, 5, 6, 6, 6]]
-# -1 is inferred to be 3:
-reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1],
-                              [2, 2, 2],
-                              [3, 3, 3]],
-                             [[4, 4, 4],
-                              [5, 5, 5],
-                              [6, 6, 6]]]
-
-# tensor 't' is [7]
-# shape `[]` reshapes to a scalar
-reshape(t, []) ==> 7
-```
-
-### Operands:
-1. `tensor`: tensor of tf.dtype values
-1. `shape`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tshape` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.ResizeBilinear (TF::ResizeBilinearOp)
-Resize `images` to `size` using bilinear interpolation.
-
-### Description:
-
-Input images can be of different types but output images are always float.
-
-### Operands:
-1. `images`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `size`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `align_corners` | `BoolAttr` | bool attribute attribute |
-| `half_pixel_centers` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `resized_images`: tensor of 32-bit float values
-
-## tf.ReverseV2 (TF::ReverseV2Op)
-Reverses specific dimensions of a tensor.
-
-### Description:
-
-NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
-`tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
-
-Given a `tensor`, and a `int32` tensor `axis` representing the set of
-dimensions of `tensor` to reverse. This operation reverses each dimension
-`i` for which there exists `j` s.t. `axis[j] == i`.
-
-`tensor` can have up to 8 dimensions. The number of dimensions specified
-in `axis` may be 0 or more entries. If an index is specified more than
-once, a InvalidArgument error is raised.
-
-For example:
-
-```
-# tensor 't' is [[[[ 0,  1,  2,  3],
-#                  [ 4,  5,  6,  7],
-#                  [ 8,  9, 10, 11]],
-#                 [[12, 13, 14, 15],
-#                  [16, 17, 18, 19],
-#                  [20, 21, 22, 23]]]]
-# tensor 't' shape is [1, 2, 3, 4]
-
-# 'dims' is [3] or 'dims' is [-1]
-reverse(t, dims) ==> [[[[ 3,  2,  1,  0],
-                        [ 7,  6,  5,  4],
-                        [ 11, 10, 9, 8]],
-                       [[15, 14, 13, 12],
-                        [19, 18, 17, 16],
-                        [23, 22, 21, 20]]]]
-
-# 'dims' is '[1]' (or 'dims' is '[-3]')
-reverse(t, dims) ==> [[[[12, 13, 14, 15],
-                        [16, 17, 18, 19],
-                        [20, 21, 22, 23]
-                       [[ 0,  1,  2,  3],
-                        [ 4,  5,  6,  7],
-                        [ 8,  9, 10, 11]]]]
-
-# 'dims' is '[2]' (or 'dims' is '[-2]')
-reverse(t, dims) ==> [[[[8, 9, 10, 11],
-                        [4, 5, 6, 7],
-                        [0, 1, 2, 3]]
-                       [[20, 21, 22, 23],
-                        [16, 17, 18, 19],
-                        [12, 13, 14, 15]]]]
-```
-
-### Operands:
-1. `tensor`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-1. `axis`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values
-
-## tf.Rsqrt (TF::RsqrtOp)
-Computes reciprocal of square root of x element-wise.
-
-### Description:
-
-I.e., \\(y = 1 / \sqrt{x}\\).
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.Select (TF::SelectOp)
-Selects elements from `x` or `y`, depending on `condition`.
-
-### Description:
-
-The `x`, and `y` tensors must all have the same shape, and the
-output will also have that shape.
-
-The `condition` tensor must be a scalar if `x` and `y` are scalars.
-If `x` and `y` are vectors or higher rank, then `condition` must be either a
-scalar, a vector with size matching the first dimension of `x`, or must have
-the same shape as `x`.
-
-The `condition` tensor acts as a mask that chooses, based on the value at each
-element, whether the corresponding element / row in the output should be
-taken from `x` (if true) or `y` (if false).
-
-If `condition` is a vector and `x` and `y` are higher rank matrices, then
-it chooses which row (outer dimension) to copy from `x` and `y`.
-If `condition` has the same shape as `x` and `y`, then it chooses which
-element to copy from `x` and `y`.
-
-For example:
-
-```python
-# 'condition' tensor is [[True,  False]
-#                        [False, True]]
-# 't' is [[1, 2],
-#         [3, 4]]
-# 'e' is [[5, 6],
-#         [7, 8]]
-select(condition, t, e)  # => [[1, 6], [7, 4]]
-
-
-# 'condition' tensor is [True, False]
-# 't' is [[1, 2],
-#         [3, 4]]
-# 'e' is [[5, 6],
-#         [7, 8]]
-select(condition, t, e) ==> [[1, 2],
-                             [7, 8]]
-
-```
-
-### Operands:
-1. `condition`: tensor of 1-bit integer values
-1. `t`: tensor of tf.dtype values
-1. `e`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Shape (TF::ShapeOp)
-Returns the shape of a tensor.
-
-### Description:
-
-This operation returns a 1-D integer tensor representing the shape of `input`.
-
-For example:
-
-```
-# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-shape(t) ==> [2, 2, 3]
-```
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `out_type` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of 32/64-bit integer values
-
-## tf.Sigmoid (TF::SigmoidOp)
-Computes sigmoid of `x` element-wise.
-
-### Description:
-
-Specifically, `y = 1 / (1 + exp(-x))`.
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.Sin (TF::SinOp)
-Computes sin of x element-wise.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.Slice (TF::SliceOp)
-Return a slice from 'input'.
-
-### Description:
-
-The output tensor is a tensor with dimensions described by 'size'
-whose values are extracted from 'input' starting at the offsets in
-'begin'.
-
-*Requirements*:
-  0 <= begin[i] <= begin[i] + size[i] <= Di  for i in [0, n)
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `begin`: tensor of 32/64-bit integer values
-1. `size`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Index` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Softmax (TF::SoftmaxOp)
-Computes softmax activations.
-
-### Description:
-
-For each batch `i` and class `j` we have
-
-    $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$
-
-### Operands:
-1. `logits`: tensor of floating-point values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `softmax`: tensor of floating-point values
-
-## tf.SpaceToBatchND (TF::SpaceToBatchNDOp)
-SpaceToBatch for N-D tensors of type T.
-
-### Description:
-
-This operation divides "spatial" dimensions `[1, ..., M]` of the input into a
-grid of blocks of shape `block_shape`, and interleaves these blocks with the
-"batch" dimension (0) such that in the output, the spatial dimensions
-`[1, ..., M]` correspond to the position within the grid, and the batch
-dimension combines both the position within a spatial block and the original
-batch position.  Prior to division into blocks, the spatial dimensions of the
-input are optionally zero padded according to `paddings`.  See below for a
-precise description.
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `block_shape`: tensor of 32/64-bit integer values
-1. `paddings`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tpaddings` | `Attribute` | derived attribute attribute |
-| `Tblock_shape` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Split (TF::SplitOp)
-Splits a tensor into `num_split` tensors along one dimension.
-
-### Description:
-
-
-### Operands:
-1. `split_dim`: tensor of 32-bit integer values
-1. `value`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.SplitV (TF::SplitVOp)
-Splits a tensor into `num_split` tensors along one dimension.
-
-### Description:
-
-
-### Operands:
-1. `value`: tensor of tf.dtype values
-1. `size_splits`: tensor of 32/64-bit integer values
-1. `split_dim`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute |
-| `Tlen` | `Attribute` | derived attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Sqrt (TF::SqrtOp)
-Computes square root of x element-wise.
-
-### Description:
-
-I.e., \\(y = \sqrt{x} = x^{1/2}\\).
-
-### Operands:
-1. `x`: tensor of floating-point or 64/128-bit complex type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of floating-point or 64/128-bit complex type values
-
-## tf.Square (TF::SquareOp)
-Computes square of x element-wise.
-
-### Description:
-
-I.e., \\(y = x * x = x^2\\).
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-## tf.SquaredDifference (TF::SquaredDifferenceOp)
-Returns (x - y)(x - y) element-wise.
-
-### Description:
-
-*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values
-
-## tf.Squeeze (TF::SqueezeOp)
-Removes dimensions of size 1 from the shape of a tensor.
-
-### Description:
-
-Given a tensor `input`, this operation returns a tensor of the same type with
-all dimensions of size 1 removed. If you don't want to remove all size 1
-dimensions, you can remove specific size 1 dimensions by specifying
-`axis`.
-
-For example:
-
-```
-# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
-shape(squeeze(t)) ==> [2, 3]
-```
-
-Or, to remove specific size 1 dimensions:
-
-```
-# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
-shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
-```
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.StridedSlice (TF::StridedSliceOp)
-Return a strided slice from `input`.
-
-### Description:
-
-Note, most python users will want to use the Python `Tensor.__getitem__`
-or `Variable.__getitem__` rather than this op directly.
-
-The goal of this op is to produce a new tensor with a subset of
-the elements from the `n` dimensional `input` tensor. The subset is chosen using
-a sequence of `m` sparse range specifications encoded into the arguments
-of this function. Note, in some cases
-`m` could be equal to `n`, but this need not be the case. Each
-range specification entry can be one of the following:
-
-- An ellipsis (...). Ellipses are used to imply zero or more
-  dimensions of full-dimension selection and are produced using
-  `ellipsis_mask`. For example, `foo[...]` is the identity slice.
-
-- A new axis. This is used to insert a new shape=1 dimension and is
-  produced using `new_axis_mask`. For example, `foo[:, ...]` where
-  `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor.
-
-
-- A range `begin:end:stride`. This is used to specify how much to choose from
-  a given dimension. `stride` can be any integer but 0.  `begin` is an integer
-  which represents the index of the first value to select while `end` represents
-  the index of the last value to select. The number of values selected in each
-  dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`.
-  `begin` and `end` can be negative where `-1` is the last element, `-2` is
-  the second to last. `begin_mask` controls whether to replace the explicitly
-  given `begin` with an implicit effective value of `0` if `stride > 0` and
-  `-1` if `stride < 0`. `end_mask` is analogous but produces the number
-  required to create the largest open interval. For example, given a shape
-  `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do
-  not assume this is equivalent to `foo[0:-1]` which has an effective `begin`
-  and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the
-  first dimension of a tensor while dropping the last two (in the original
-  order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`.
-
-- A single index. This is used to keep only elements that have a given
-  index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a
-  shape `(6,)` tensor. This is encoded in `begin` and `end` and
-  `shrink_axis_mask`.
-
-Each conceptual range specification is encoded in the op's argument. This
-encoding is best understand by considering a non-trivial example. In
-particular,
-`foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as
-
-```
-begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0)
-end = [2, 4, x, x, -3, x]
-strides = [1, 1, x, x, -1, 1]
-begin_mask = 1<<4 | 1 << 5 = 48
-end_mask = 1<<5 = 32
-ellipsis_mask = 1<<3 = 8
-new_axis_mask = 1<<2 4
-shrink_axis_mask = 1<<0
-```
-
-In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of
-the slice becomes (2, 1, 5, 5, 2, 5).
-Let us walk step by step through each argument specification.
-
-1.  The first argument in the example slice is turned into `begin = 1` and
-`end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we
-also set the appropriate bit in `shrink_axis_mask`.
-
-2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have
-zero bits contributed.
-
-3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1
-dimension in the final shape. Dummy values are contributed to begin,
-end and stride, while the new_axis_mask bit is set.
-
-4. `...` grab the full ranges from as many dimensions as needed to
-fully specify a slice for every dimension of the input shape.
-
-5. `:-3:-1` shows the use of negative indices. A negative index `i` associated
-with a dimension that has shape `s` is converted to a positive index
-`s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion
-is done internally so begin, end and strides receive x, -3, and -1.
-The appropriate begin_mask bit is set to indicate the start range is the
-full range (ignoring the x).
-
-6. `:` indicates that the entire contents of the corresponding dimension
-is selected. This is equivalent to `::` or `0::1`. begin, end, and strides
-receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
-`end_mask` are also set.
-
-*Requirements*:
-  `0 != strides[i] for i in [0, m)`
-  `ellipsis_mask must be a power of two (only one ellipsis)`
-
-### Operands:
-1. `input`: tensor of tf.dtype values
-1. `begin`: tensor of 32/64-bit integer values
-1. `end`: tensor of 32/64-bit integer values
-1. `strides`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `begin_mask` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `end_mask` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `ellipsis_mask` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `new_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `shrink_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Index` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Sub (TF::SubOp)
-Returns x - y element-wise.
-
-### Description:
-
-*NOTE*: `Subtract` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Sum (TF::SumOp)
-Computes the sum of elements across dimensions of a tensor.
-
-### Description:
-
-Reduces `input` along the dimensions given in `axis`. Unless
-`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-`axis`. If `keep_dims` is true, the reduced dimensions are
-retained with length 1.
-
-### Operands:
-1. `input`: tensor of number values
-1. `reduction_indices`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `keep_dims` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tidx` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of number values
-
-## tf.TensorListFromTensor (TF::TensorListFromTensorOp)
-
-Creates a TensorList which, when stacked, has the value of `tensor`.
-  
-
-### Description:
-
-Each tensor in the result list corresponds to one row of the input tensor.
-
-tensor: The input tensor.
-output_handle: The list.
-
-### Operands:
-1. `tensor`: tensor of tf.dtype values
-1. `element_shape`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `shape_type` | `Attribute` | derived attribute attribute |
-| `element_dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output_handle`: tensor of TensorFlow variant type values
-
-## tf.TensorListGetItem (TF::TensorListGetItemOp)
-
-
-### Description:
-
-
-### Operands:
-1. `input_handle`: tensor of TensorFlow variant type values
-1. `index`: tensor of 32-bit integer values
-1. `element_shape`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `element_dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `item`: tensor of tf.dtype values
-
-## tf.TensorListReserve (TF::TensorListReserveOp)
-List of the given size with empty elements.
-
-### Description:
-
-element_shape: the shape of the future elements of the list
-num_elements: the number of elements to reserve
-handle: the output list
-element_dtype: the desired type of elements in the list.
-
-### Operands:
-1. `element_shape`: tensor of 32/64-bit integer values
-1. `num_elements`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `element_dtype` | `TypeAttr` | any type attribute attribute |
-| `shape_type` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `handle`: tensor of TensorFlow variant type values
-
-## tf.TensorListSetItem (TF::TensorListSetItemOp)
-
-
-### Description:
-
-
-### Operands:
-1. `input_handle`: tensor of TensorFlow variant type values
-1. `index`: tensor of 32-bit integer values
-1. `item`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `element_dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output_handle`: tensor of TensorFlow variant type values
-
-## tf.TensorListStack (TF::TensorListStackOp)
-Stacks all tensors in the list.
-
-### Description:
-
-Requires that all tensors have the same shape.
-
-input_handle: the input list
-tensor: the gathered result
-num_elements: optional. If not -1, the number of elements in the list.
-
-### Operands:
-1. `input_handle`: tensor of TensorFlow variant type values
-1. `element_shape`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num_elements` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `element_dtype` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `tensor`: tensor of tf.dtype values
-
-## tf.TopKV2 (TF::TopKV2Op)
-
-Finds values and indices of the `k` largest elements for the last dimension.
-  
-
-### Description:
-
-If the input is a vector (rank-1), finds the `k` largest entries in the vector
-and outputs their values and indices as vectors.  Thus `values[j]` is the
-`j`-th largest entry in `input`, and its index is `indices[j]`.
-
-For matrices (resp. higher rank input), computes the top `k` entries in each
-row (resp. vector along the last dimension).  Thus,
-
-    values.shape = indices.shape = input.shape[:-1] + [k]
-
-If two elements are equal, the lower-index element appears first.
-
-### Operands:
-1. `input`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `k`: tensor of 32-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `sorted` | `BoolAttr` | bool attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `values`: tensor of 8/16/32/64-bit integer or floating-point values
-1. `indices`: tensor of 32-bit integer values
-
-## tf.Transpose (TF::TransposeOp)
-Shuffle dimensions of x according to a permutation.
-
-### Description:
-
-The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
-  `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
-
-### Operands:
-1. `x`: tensor of tf.dtype values
-1. `perm`: tensor of 32/64-bit integer values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-| `Tperm` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of tf.dtype values
-
-## tf.TruncateDiv (TF::TruncateDivOp)
-Returns x / y element-wise for integer types.
-
-### Description:
-
-Truncation designates that negative numbers will round fractional quantities
-toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different
-than Python semantics. See `FloorDiv` for a division function that matches
-Python Semantics.
-
-*NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-
-### Operands:
-1. `x`: tensor of number values
-1. `y`: tensor of number values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of number values
-
-## tf.Unpack (TF::UnpackOp)
-
-Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
-  
-
-### Description:
-
-Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
-For example, given a tensor of shape `(A, B, C, D)`;
-
-If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]`
-  and each tensor in `output` will have shape `(B, C, D)`. (Note that the
-  dimension unpacked along is gone, unlike `split`).
-
-If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]`
-  and each tensor in `output` will have shape `(A, C, D)`.
-Etc.
-
-This is the opposite of `pack`.
-
-### Operands:
-1. `value`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `num` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 0 attribute |
-| `axis` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `output`: tensor of tf.dtype values
-
-## tf.Xdivy (TF::XdivyOp)
-Returns 0 if x == 0, and x / y otherwise, elementwise.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values
-
-## tf.ZerosLike (TF::ZerosLikeOp)
-Returns a tensor of zeros with the same shape and type as x.
-
-### Description:
-
-
-### Operands:
-1. `x`: tensor of tf.dtype values
-
-### Attributes:
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `T` | `Attribute` | derived attribute attribute |
-
-### Results:
-1. `y`: tensor of tf.dtype values
-
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
index 2756b4c..4bf7029 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
@@ -65,7 +65,7 @@
 // tensor needs its own _tf.Enter to be made available inside the while loop.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // This is defined in Tensorflow as:
 //
@@ -100,7 +100,7 @@
 // of the operand type along with the index of the first match encountered.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // This is defined in TensorFlow as:
 //
@@ -130,7 +130,7 @@
 // of a while loop. Each loop variable needs its own NextIteration op.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // NextIteration op is broken into _tf.NextIteration.sink and
 // _tf.NextIteration.source because NextIteration is a back-edge in Tensorflow
@@ -182,7 +182,7 @@
 // Tensorflow while loops.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // This is defined in Tensorflow as:
 //
@@ -212,7 +212,7 @@
 // condition, and returns two values matching the type of the data predicate.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // This is defined in TensorFlow as:
 //
@@ -246,7 +246,7 @@
 // outside of loop. Each returned tensor needs its own _tf.Exit.
 //
 // More details can be found in Tensorflow Controlflow white paper:
-// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 //
 // This is defined in Tensorflow as:
 //
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 29d73a7..810332e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -16,11 +16,14 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 
 #include <algorithm>
+#include <iterator>
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Traits.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
@@ -34,9 +37,28 @@
 #include "mlir/IR/Types.h"  // TF:local_config_mlir
 #include "mlir/IR/Value.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
 namespace tf_executor {
+namespace {
+
+// If the given tensor has elements of type variant, then returns a new type
+// after dropping subtypes info. Otherwise, returns the original type as is.
+Type DropVariantSubTypes(Type ty) {
+  ShapedType shaped_ty = ty.cast<ShapedType>();
+  Type element_ty = shaped_ty.getElementType();
+  if (!element_ty.isa<TF::VariantType>()) return ty;
+
+  Type variant_ty = TF::VariantType::get(ty.getContext());
+  if (shaped_ty.hasRank()) {
+    return RankedTensorType::get(shaped_ty.getShape(), variant_ty);
+  }
+
+  return UnrankedTensorType::get(variant_ty);
+}
+
+}  // namespace
 
 //===----------------------------------------------------------------------===//
 // TF Executor Dialect
@@ -77,21 +99,6 @@
 
 namespace {
 
-// Inserts `tf_executor.Terminator` at the end of the region's only block if it
-// does not have a terminator already. If the region is empty, insert a new
-// block first.
-template <typename Terminator>
-void EnsureExecutorTerminator(Region *region, Builder *builder, Location loc) {
-  if (region->empty()) region->push_back(new Block);
-
-  Block &block = region->back();
-  if (!block.empty() && block.back().isKnownTerminator()) return;
-
-  OperationState terminator_state(loc, Terminator::getOperationName());
-  Terminator::build(builder, &terminator_state, {});
-  block.push_back(Operation::create(terminator_state));
-}
-
 // Verifies that every control operands are at the end of the list.
 // Used by the constraint `ControlOperandsAfterAllData` in ODS.
 LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
@@ -123,6 +130,9 @@
   for (Operation &op : graph.GetBody()) {
     if (op.getDialect() != executorDialect)
       return op.emitOpError() << "unallowed inside a tf_executor.graph region";
+    if (isa<GraphOp>(op))
+      return op.emitOpError()
+             << "unallowed directly inside another tf_executor.graph";
   }
 
   Operation &fetch = graph.GetBody().back();
@@ -174,8 +184,7 @@
 
   // Ensure that the region is well formed: it contains at least a block with
   // a FetchOp terminator.
-  EnsureExecutorTerminator<FetchOp>(&body, &parser->getBuilder(),
-                                    result->location);
+  GraphOp::ensureTerminator(body, parser->getBuilder(), result->location);
 
   // Get the results type from the terminator type inside the graph.
   Operation &fetch = body.back().back();
@@ -281,8 +290,7 @@
   if (parser->parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
     return failure();
   if (!op_infos.empty()) {
-    SmallVector<Type, 2> types;
-    types.push_back(control_type);
+    SmallVector<Type, 2> types(op_infos.size(), control_type);
     parser->resolveOperands(op_infos, types, loc, result->operands);
   }
 
@@ -301,8 +309,7 @@
 
   if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();
 
-  EnsureExecutorTerminator<YieldOp>(&body, &parser->getBuilder(),
-                                    result->location);
+  IslandOp::ensureTerminator(body, parser->getBuilder(), result->location);
 
   // Get the results type for the island from the terminator operands.
   Operation &yield = body.back().back();
@@ -498,8 +505,17 @@
   Type broadcasted_type = merge.output()->getType();
   for (Type operand_type : merge.getOperandTypes()) {
     if (operand_type.isa<ControlType>()) break;
+
+    // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
+    // constraint.
+    if (!operand_type.isa<TensorType>())
+      return merge.emitOpError("expects data operands to have tensor type");
+
+    // Variant types may have opaque subtypes information that need not match
+    // between the two types so drop them before computing the broadcasted type.
     Type new_broadcasted_type =
-        OpTrait::util::getBroadcastedType(broadcasted_type, operand_type);
+        OpTrait::util::getBroadcastedType(DropVariantSubTypes(broadcasted_type),
+                                          DropVariantSubTypes(operand_type));
     if (!new_broadcasted_type)
       return merge.emitOpError()
              << "expects all operands to be broadcastable"
@@ -508,10 +524,8 @@
     // This is because for example starting with a result of tensor<4xf32>, if
     // the first operand is unranked, the broadcasted type will be unranked.
     // Then any tensor operand will be broadcastable to this unranked type.
-    if ((broadcasted_type.isa<TensorType>() &&
-         !broadcasted_type.cast<TensorType>().hasRank()) ||
-        (new_broadcasted_type.isa<TensorType>() &&
-         new_broadcasted_type.cast<TensorType>().hasRank()))
+    if (!broadcasted_type.cast<TensorType>().hasRank() ||
+        new_broadcasted_type.cast<TensorType>().hasRank())
       broadcasted_type = new_broadcasted_type;
   }
 
@@ -519,11 +533,33 @@
 }
 
 void Print(MergeOp merge, OpAsmPrinter *p) {
+  // Use short form only when there are exactly two data operands and their
+  // type matches the output type. Otherwise, use the generic printer.
+  bool use_short_form = true;
+  int num_data_operands = 0;
+
+  Type output_type = merge.output()->getType();
+  for (Type operand_type : merge.getOperandTypes()) {
+    if (operand_type.isa<ControlType>()) break;
+    num_data_operands++;
+
+    if (operand_type != output_type) {
+      use_short_form = false;
+      break;
+    }
+  }
+
   *p << merge.getOperationName() << ' ';
   p->printOperands(merge.getOperands());
 
   // Print the type signature of the operation.
-  *p << " : " << merge.getType(0);
+  *p << " : ";
+  if (!use_short_form || num_data_operands != 2) {
+    p->printFunctionalType(merge.getOperation());
+  } else {
+    *p << output_type;
+  }
+
   p->printOptionalAttrDict(merge.getAttrs());
 }
 
@@ -537,17 +573,26 @@
     return parser->emitError(parser->getNameLoc())
            << " expects only a single data type";
 
-  // Expect the type once, but use it for both operands.
-  types.push_back(types.front());
-  // Extra operands are expected to be control inputs.
-  Type control_type = ControlType::get(parser->getBuilder().getContext());
-  types.append(op_infos.size() - 2, control_type);
+  // Support parsing either a functional type (in which case all the types are
+  // fully qualified) or a short form with a single type (in which case the data
+  // inputs and the output are all using this type).
+  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
+    result->types.assign(type.getResults().begin(), type.getResults().end());
+    types.assign(type.getInputs().begin(), type.getInputs().end());
+  } else {
+    // In case of the short form, use the parsed type for both the operands and
+    // the remaining operands are expected to be control inputs.
+    types.push_back(types.front());
+    Type control_type = ControlType::get(parser->getBuilder().getContext());
+    types.append(op_infos.size() - 2, control_type);
+
+    RankedTensorType i32_tensor =
+        RankedTensorType::get({}, parser->getBuilder().getIntegerType(32));
+    result->types = {types.front(), i32_tensor, control_type};
+  }
 
   if (parser->resolveOperands(op_infos, types, loc, result->operands))
     return failure();
-  RankedTensorType i32_tensor =
-      RankedTensorType::get({}, parser->getBuilder().getIntegerType(32));
-  result->types = {types.front(), i32_tensor, control_type};
 
   return parser->parseOptionalAttributeDict(result->attributes);
 }
@@ -833,6 +878,96 @@
 }  // namespace
 
 //===----------------------------------------------------------------------===//
+// Canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+// TODO(lyandy): Add canonicalization for dedupping control inputs.
+
+//===----------------------------------------------------------------------===//
+// tf_executor.graph
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Finds in a block if the op of type `InnerOpT` is the first operation and
+// optionally followed by a terminator.
+template <typename InnerOpT>
+bool HasSingleOpInBlock(Block *block) {
+  if (block->empty()) return false;
+  if (!llvm::isa<InnerOpT>(block->front())) return false;
+  // Either InnerOpT is the only instruction in the block, or there is a
+  // possible terminator.
+  return std::next(block->begin()) == block->end() ||
+         std::next(block->begin(), 2) == block->end();
+}
+
+// This pattern matches GraphOps with only one FetchOp (empty) and remaps the
+// results of the GraphOp to the operands of the FetchOp.
+struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
+  using OpRewritePattern<GraphOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(GraphOp op,
+                                     PatternRewriter &rewriter) const override {
+    Block &block = op.GetBody();
+    // Check if graph only has one fetch.
+    auto fetch_op = llvm::dyn_cast<FetchOp>(block.front());
+    if (!fetch_op) return matchFailure();
+
+    // Map graph results to fetch operands.
+    llvm::SmallVector<Value *, 8> new_rets(fetch_op.fetches());
+    rewriter.replaceOp(op, new_rets);
+
+    return matchSuccess();
+  }
+};
+
+// This pattern matches GraphOps with only one island, pulls out all inner ops
+// of the island to the block containing the GraphOp, and then removes the
+// GraphOp.
+struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
+  using OpRewritePattern<GraphOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(GraphOp op,
+                                     PatternRewriter &rewriter) const override {
+    Block &block = op.GetBody();
+    // Check if graph only has one island.
+    if (!HasSingleOpInBlock<IslandOp>(&block)) return matchFailure();
+
+    auto fetch_op = llvm::cast<FetchOp>(block.back());
+    auto island_op = llvm::cast<IslandOp>(block.front());
+    Operation &yield_op = island_op.GetBody().back();
+
+    // Map graph results to inner ops results of single island.
+    llvm::SmallVector<Value *, 8> new_rets;
+    for (Value *operand : fetch_op.fetches()) {
+      if (operand->getDefiningOp() != island_op) {
+        // Operand is not from island, simply propagate it out.
+        new_rets.push_back(operand);
+      } else {
+        // Lookup yield operand in island for inner op result.
+        auto result = llvm::cast<OpResult>(operand);
+        new_rets.push_back(yield_op.getOperand(result->getResultNumber()));
+      }
+    }
+
+    // Move inner ops from island to block containing graph.
+    auto &island_body = island_op.GetBody().getOperations();
+    Operation *operation = op.getOperation();
+    operation->getBlock()->getOperations().splice(
+        operation->getIterator(), island_body, island_body.begin(),
+        std::prev(island_body.end()));
+    rewriter.replaceOp(op, new_rets);
+
+    return matchSuccess();
+  }
+};
+}  // anonymous namespace
+
+void GraphOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
index 748416a..6e827e9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
@@ -55,12 +55,14 @@
 // Token type.
 def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">;
 
+// TODO(hinsu): Define and use TensorType instead of AnyType for data operands
+// and results. For example, MergeOp output type.
+
 //===----------------------------------------------------------------------===//
 // TensorFlow Executor Type Constraint
 //===----------------------------------------------------------------------===//
 
-// Predicate to verify that the opId'th operand can be broadcasted to the type
-// of the  resId'th result.
+// Predicate to verify all control inputs appear after any non-control inputs.
 def ControlOperandsAfterAllData :
     PredOpTrait<"all control inputs must appear after any non-control input",
                 CPred<"succeeded(VerifyControlOperandsAfterAllData(&$_op))">>;
@@ -79,7 +81,8 @@
   let parser = [{ return Parse$cppClass(parser, result); }];
 }
 
-def TfExecutor_GraphOp : TfExecutor_Op<"graph", []> {
+def TfExecutor_GraphOp : TfExecutor_Op<"graph",
+    [SingleBlockImplicitTerminator<"FetchOp">]> {
   let summary = [{The `tf_executor.graph` operation contains a region with a
     single block that lists the operations in a TensorFlow graph.}];
 
@@ -118,12 +121,15 @@
 
   let regions = (region SizedRegion<1>:$body);
 
+  let hasCanonicalizer = 1;
+
   let extraClassDeclaration = [{
     Block &GetBody() { return getOperation()->getRegion(0).front(); }
   }];
 }
 
-def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfterAllData]> {
+def TfExecutor_FetchOp : TfExecutor_Op<"fetch",
+    [Terminator, ControlOperandsAfterAllData, HasParent<"GraphOp">]> {
   let summary = [{
     The `tf_executor.fetch` operation terminates the graph and returns values";
   }];
@@ -137,10 +143,18 @@
     Variadic<AnyType>:$fetches
   );
 
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result",
+    [{
+      build(builder, result, {});
+    }]>
+   ];
+
   let verifier = ?;
 }
 
-def TfExecutor_IslandOp : TfExecutor_Op<"island", []> {
+def TfExecutor_IslandOp : TfExecutor_Op<"island",
+    [HasParent<"GraphOp">, SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = [{
     The `tf_executor.island` operation is a wrapper for operations in other
     dialects to be nested in a `tf_executor.graph`.
@@ -193,8 +207,8 @@
   }];
 }
 
-def TfExecutor_YieldOp :
-    TfExecutor_Op<"yield", [Terminator, ControlOperandsAfterAllData]> {
+def TfExecutor_YieldOp : TfExecutor_Op<"yield",
+    [Terminator, ControlOperandsAfterAllData, HasParent<"IslandOp">]> {
   let summary = [{
     The `tf_executor.yield` operation terminates and returns values for the
     `tf_executor.island` operation.
@@ -204,11 +218,18 @@
     Variadic<AnyType>:$fetches
   );
 
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result",
+    [{
+      build(builder, result, {});
+    }]>
+   ];
+
   let verifier = ?;
 }
 
 def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
-    [NoSideEffect, ControlOperandsAfterAllData,
+    [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">,
      PredOpTrait<"data operand must be broadcastable to true result",
                  TCOpIsBroadcastableToRes<0, 0>>,
      PredOpTrait<"data operand must be broadcastable to false result",
@@ -221,7 +242,7 @@
 
   let description = [{
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     This is defined in TensorFlow as:
 
@@ -253,8 +274,8 @@
    let verifier = ?;
 }
 
-def TfExecutor_SwitchNOp :
-    TfExecutor_Op<"SwitchN", [NoSideEffect, ControlOperandsAfterAllData]> {
+def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN",
+    [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">]> {
   let summary = [{
     The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and
     an integer attribute `num_outs` indicating the number of outputs. The `data`
@@ -294,7 +315,8 @@
   );
 }
 
-def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> {
+def TfExecutor_MergeOp : TfExecutor_Op<"Merge",
+    [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">]> {
   let summary = [{
     The "tf_executor.Merge" operation takes a list of input operands and returns
     a value of the operand type along with the index of the first match encountered.
@@ -302,7 +324,7 @@
 
   let description = [{
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     This is defined in TensorFlow as:
 
@@ -322,14 +344,14 @@
   );
 
   let results = (outs
-    AnyType:$output,
+    AnyTensor:$output,
     TensorOf<[I32]>:$valueIndex,
     TfeControlType:$control
   );
 }
 
 def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
-    [NoSideEffect, ControlOperandsAfterAllData,
+    [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">,
      PredOpTrait<"data operand must be broadcastable to result",
                  TCOpIsBroadcastableToRes<0, 0>>]>{
   let summary = [{
@@ -339,7 +361,7 @@
 
   let description = [{
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     Each tensor needs its own tf_executor.Enter to be made available inside a
     while loop.
@@ -378,7 +400,8 @@
   let verifier = ?;
 }
 
-def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [NoSideEffect]> {
+def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
+    [NoSideEffect, HasParent<"GraphOp">]> {
   let summary = [{
     The "tf_executor.NextIteration.Source" is paired with a
     "tf_executor.NextIteration.sink" to represent NextIteration op in
@@ -390,7 +413,7 @@
     of a while loop. Each loop variable needs its own NextIteration op.
 
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     In the TF executor dialect, the NextIteration op is broken into
     tf_executor.NextIteration.sink and tf_executor.NextIteration.source because
@@ -435,7 +458,8 @@
 }
 
 
-def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> {
+def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink",
+    [HasParent<"GraphOp">]> {
   let summary = [{
     The "tf_executor.NextIteration.Sink" is paired with a
     "tf_executor.NextIteration.source" to represent NextIteration op in
@@ -447,7 +471,7 @@
     of a while loop. Each loop variable needs its own NextIteration op.
 
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     In the TF executor dialect, the NextIteration op is broken into
     tf_executor.NextIteration.sink and tf_executor.NextIteration.source because
@@ -495,7 +519,7 @@
 }
 
 def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
-    [NoSideEffect,
+    [NoSideEffect, HasParent<"GraphOp">,
      PredOpTrait<"data operand must be broadcastable to result",
                  TCOpIsBroadcastableToRes<0, 0>>]>{
 
@@ -507,7 +531,7 @@
 
   let description = [{
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     This is defined in Tensorflow as:
 
@@ -535,7 +559,8 @@
   let verifier = ?;
 }
 
-def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", [NoSideEffect]> {
+def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",
+    [NoSideEffect, HasParent<"GraphOp">]> {
   let summary = [{
     The `tf_executor.ControlTrigger` operation is similar to a no-op except that
     it always produces a valid output even when inputs are dead.
@@ -571,7 +596,8 @@
    ];
 }
 
-def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> {
+def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond",
+    [NoSideEffect, HasParent<"GraphOp">]> {
   let summary = [{
     The "tf_executor.LoopCond" operation forwards a boolean value as loop
     condition of Tensorflow while loops.
@@ -579,7 +605,7 @@
 
   let description = [{
     More details can be found in Tensorflow Control Flow white paper:
-    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
+    https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
 
     This is defined in Tensorflow as:
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index a748e29..e7d039d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -123,6 +123,32 @@
   let hasCanonicalizer = 1;
 }
 
+def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
+  let summary = [{
+Computes the "logical or" of elements across dimensions of a tensor.
+  }];
+
+  let description = [{
+Reduces `input` along the dimensions given in `axis`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`axis`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input,
+    TF_I32OrI64Tensor:$reduction_indices,
+
+    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
+  );
+
+  let results = (outs
+    I1Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> {
   let summary = [{
 Returns the index with the largest value across dimensions of a tensor.
@@ -136,7 +162,7 @@
   import tensorflow as tf
   a = [1, 10, 26.9, 2.8, 166.32, 62.3]
   b = tf.math.argmax(input = a)
-  c = tf.keras.backend.eval(b)  
+  c = tf.keras.backend.eval(b)
   # c = 4
   # here a[4] = 166.32 which is the largest element of a across axis 0
   ```
@@ -169,7 +195,7 @@
   import tensorflow as tf
   a = [1, 10, 26.9, 2.8, 166.32, 62.3]
   b = tf.math.argmin(input = a)
-  c = tf.keras.backend.eval(b)  
+  c = tf.keras.backend.eval(b)
   # c = 0
   # here a[0] = 1 which is the smallest element of a across axis 0
   ```
@@ -189,6 +215,28 @@
   TF_DerivedResultTypeAttr output_type = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_AssertOp : TF_Op<"Assert", []> {
+  let summary = "Asserts that the given condition is true.";
+
+  let description = [{
+If `condition` evaluates to false, print the list of tensors in `data`.
+`summarize` determines how many entries of the tensors to print.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$condition,
+    Variadic<TF_Tensor>:$data,
+
+    DefaultValuedAttr<I64Attr, "3">:$summarize
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<1>;
+
+  let hasCanonicalizer = 1;
+}
+
 def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> {
   let summary = "Performs average pooling on the input.";
 
@@ -679,6 +727,51 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = [{
+Computes exponential of x element-wise.  \\(y = e^x\\).
+  }];
+
+  let description = [{
+This function computes the exponential of every element in the input tensor.
+  i.e. `exp(x)` or `e^(x)`, where `x` is the input tensor.
+  `e` denotes Euler's number and is approximately equal to 2.718281.
+  Output is positive for any real input.
+
+  ```python
+  x = tf.constant(2.0)
+  tf.math.exp(x) ==> 7.389056
+
+  x = tf.constant([2.0, 8.0])
+  tf.math.exp(x) ==> array([7.389056, 2980.958], dtype=float32)
+  ```
+
+  For complex numbers, the exponential value is calculated as follows:
+
+  ```
+  e^(x+iy) = e^x * e^iy = e^x * (cos y + i sin y)
+  ```
+
+  Let's consider complex number 1+1j as an example.
+  e^1 * (cos 1 + i sin 1) = 2.7182818284590 * (0.54030230586+0.8414709848j)
+
+  ```python
+  x = tf.constant(1 + 1j)
+  tf.math.exp(x) ==> 1.4686939399158851+2.2873552871788423j
+  ```
+  }];
+
+  let arguments = (ins
+    TF_FpOrComplexTensor:$x
+  );
+
+  let results = (outs
+    TF_FpOrComplexTensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_ExpandDimsOp : TF_Op<"ExpandDims", [NoSideEffect]> {
   let summary = "Inserts a dimension of 1 into a tensor's shape.";
 
@@ -891,6 +984,32 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_FloorModOp : TF_Op<"FloorMod", [Broadcastable, NoSideEffect]>,
+                    WithBroadcastableBinOpBuilder {
+  let summary = [{
+Returns element-wise remainder of division. When `x < 0` xor `y < 0` is
+  }];
+
+  let description = [{
+true, this follows Python semantics in that the result here is consistent
+with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`.
+
+*NOTE*: `FloorMod` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+  }];
+
+  let arguments = (ins
+    TF_FpOrI32OrI64Tensor:$x,
+    TF_FpOrI32OrI64Tensor:$y
+  );
+
+  let results = (outs
+    TF_FpOrI32OrI64Tensor:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_FusedBatchNormOp : TF_Op<"FusedBatchNorm", [NoSideEffect]> {
   let summary = "Batch normalization.";
 
@@ -926,6 +1045,39 @@
   }];
 }
 
+def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
+  let summary = "Batch normalization.";
+
+  let description = [{
+Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2,
+    F32Tensor:$reserve_space_3
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
   let summary = "Gather slices from `params` according to `indices`.";
 
@@ -978,13 +1130,13 @@
   }];
 
   let description = [{
-`indices` is an K-dimensional integer tensor, best thought of as a
+`indices` is a K-dimensional integer tensor, best thought of as a
 (K-1)-dimensional tensor of indices into `params`, where each element defines a
 slice of `params`:
 
     output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
 
-Whereas in `tf.gather` `indices` defines slices into the first
+Whereas in `tf.gather` `indices` defines slices into the `axis`
 dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
 first `N` dimensions of `params`, where `N = indices.shape[-1]`.
 
@@ -2435,6 +2587,29 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> {
+  let summary = [{
+Resize `images` to `size` using nearest neighbor interpolation.
+  }];
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$images,
+    I32Tensor:$size,
+
+    DefaultValuedAttr<BoolAttr, "false">:$align_corners,
+    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$resized_images
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> {
   let summary = "Reverses variable length slices.";
 
@@ -2576,6 +2751,27 @@
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = [{
+Rounds the values of a tensor to the nearest integer, element-wise.
+  }];
+
+  let description = [{
+Rounds half to even.  Also known as bankers rounding. If you want to round
+according to the current system rounding mode use std::cint.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_RsqrtOp : TF_Op<"Rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes reciprocal of square root of x element-wise.";
 
@@ -2651,6 +2847,63 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
+  let summary = "Selects elements from `x` or `y`, depending on `condition`.";
+
+  let description = [{
+The `x`, and `y` tensors must all have the same shape, and the
+output will also have that shape.
+
+The `condition` tensor must be a scalar if `x` and `y` are scalars.
+If `x` and `y` are vectors or higher rank, then `condition` must be either a
+scalar, a vector with size matching the first dimension of `x`, or must have
+the same shape as `x`.
+
+The `condition` tensor acts as a mask that chooses, based on the value at each
+element, whether the corresponding element / row in the output should be
+taken from `x` (if true) or `y` (if false).
+
+If `condition` is a vector and `x` and `y` are higher rank matrices, then
+it chooses which row (outer dimension) to copy from `x` and `y`.
+If `condition` has the same shape as `x` and `y`, then it chooses which
+element to copy from `x` and `y`.
+
+For example:
+
+```python
+# 'condition' tensor is [[True,  False]
+#                        [False, True]]
+# 't' is [[1, 2],
+#         [3, 4]]
+# 'e' is [[5, 6],
+#         [7, 8]]
+select(condition, t, e)  # => [[1, 6], [7, 4]]
+
+
+# 'condition' tensor is [True, False]
+# 't' is [[1, 2],
+#         [3, 4]]
+# 'e' is [[5, 6],
+#         [7, 8]]
+select(condition, t, e) ==> [[1, 2],
+                             [7, 8]]
+
+```
+  }];
+
+  let arguments = (ins
+    I1Tensor:$condition,
+    TF_Tensor:$t,
+    TF_Tensor:$e
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> {
   let summary = "Returns the shape of a tensor.";
 
@@ -2805,6 +3058,151 @@
   TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_SpaceToDepthOp : TF_Op<"SpaceToDepth", [NoSideEffect]> {
+  let summary = "SpaceToDepth for tensors of type T.";
+
+  let description = [{
+Rearranges blocks of spatial data, into depth. More specifically,
+this op outputs a copy of the input tensor where values from the `height`
+and `width` dimensions are moved to the `depth` dimension.
+The attr `block_size` indicates the input block size.
+
+  * Non-overlapping blocks of size `block_size x block size` are rearranged
+    into depth at each location.
+  * The depth of the output tensor is `block_size * block_size * input_depth`.
+  * The Y, X coordinates within each block of the input become the high order
+    component of the output channel index.
+  * The input tensor's height and width must be divisible by block_size.
+
+The `data_format` attr specifies the layout of the input and output tensors
+with the following options:
+  "NHWC": `[ batch, height, width, channels ]`
+  "NCHW": `[ batch, channels, height, width ]`
+  "NCHW_VECT_C":
+      `qint8 [ batch, channels / 4, height, width, 4 ]`
+
+It is useful to consider the operation as transforming a 6-D Tensor.
+e.g. for data_format = NHWC,
+     Each element in the input tensor can be specified via 6 coordinates,
+     ordered by decreasing memory layout significance as:
+     n,oY,bY,oX,bX,iC  (where n=batch index, oX, oY means X or Y coordinates
+                        within the output image, bX, bY means coordinates
+                        within the input block, iC means input channels).
+     The output would be a transpose to the following layout:
+     n,oY,oX,bY,bX,iC
+
+This operation is useful for resizing the activations between convolutions
+(but keeping all data), e.g. instead of pooling. It is also useful for training
+purely convolutional models.
+
+For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and
+block_size = 2:
+
+```
+x = [[[[1], [2]],
+      [[3], [4]]]]
+```
+
+This operation will output a tensor of shape `[1, 1, 1, 4]`:
+
+```
+[[[[1, 2, 3, 4]]]]
+```
+
+Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`,
+the corresponding output will have a single element (i.e. width and height are
+both 1) and will have a depth of 4 channels (1 * block_size * block_size).
+The output element shape is `[1, 1, 4]`.
+
+For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g.
+
+```
+x = [[[[1, 2, 3], [4, 5, 6]],
+      [[7, 8, 9], [10, 11, 12]]]]
+```
+
+This operation, for block_size of 2, will return the following tensor of shape
+`[1, 1, 1, 12]`
+
+```
+[[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]
+```
+
+Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2:
+
+```
+x = [[[[1],   [2],  [5],  [6]],
+      [[3],   [4],  [7],  [8]],
+      [[9],  [10], [13],  [14]],
+      [[11], [12], [15],  [16]]]]
+```
+
+the operator will return the following tensor of shape `[1 2 2 4]`:
+
+```
+x = [[[[1, 2, 3, 4],
+       [5, 6, 7, 8]],
+      [[9, 10, 11, 12],
+       [13, 14, 15, 16]]]]
+```
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+
+    Confined<I64Attr, [IntMinValue<2>]>:$block_size,
+    DefaultValuedAttr<TF_AnyStrAttrOf<["NHWC", "NCHW", "NCHW_VECT_C"]>, "NHWC">:$data_format
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_SparseToDenseOp : TF_Op<"SparseToDense", [NoSideEffect]> {
+  let summary = "Converts a sparse representation into a dense tensor.";
+
+  let description = [{
+Builds an array `dense` with shape `output_shape` such that
+
+```
+# If sparse_indices is scalar
+dense[i] = (i == sparse_indices ? sparse_values : default_value)
+
+# If sparse_indices is a vector, then for each i
+dense[sparse_indices[i]] = sparse_values[i]
+
+# If sparse_indices is an n by d matrix, then for each i in [0, n)
+dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
+```
+
+All other values in `dense` are set to `default_value`.  If `sparse_values` is a
+scalar, all sparse indices are set to this single value.
+
+Indices should be sorted in lexicographic order, and indices must not
+contain any repeats. If `validate_indices` is true, these properties
+are checked during execution.
+  }];
+
+  let arguments = (ins
+    TF_I32OrI64Tensor:$sparse_indices,
+    TF_I32OrI64Tensor:$output_shape,
+    TF_Tensor:$sparse_values,
+    TF_Tensor:$default_value,
+
+    DefaultValuedAttr<BoolAttr, "true">:$validate_indices
+  );
+
+  let results = (outs
+    TF_Tensor:$dense
+  );
+
+  TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
+}
+
 def TF_SplitOp : TF_Op<"Split", [NoSideEffect]> {
   let summary = "Splits a tensor into `num_split` tensors along one dimension.";
 
@@ -2943,6 +3341,42 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Stops gradient computation.";
+
+  let description = [{
+When executed in a graph, this op outputs its input tensor as-is.
+
+When building ops to compute gradients, this op prevents the contribution of
+its inputs to be taken into account.  Normally, the gradient generator adds ops
+to a graph to compute the derivatives of a specified 'loss' by recursively
+finding out inputs that contributed to its computation.  If you insert this op
+in the graph it inputs are masked from the gradient generator.  They are not
+taken into account for computing gradients.
+
+This is useful any time you want to compute a value with TensorFlow but need
+to pretend that the value was a constant. Some examples include:
+
+*  The *EM* algorithm where the *M-step* should not involve backpropagation
+   through the output of the *E-step*.
+*  Contrastive divergence training of Boltzmann machines where, when
+   differentiating the energy function, the training must not backpropagate
+   through the graph that generated the samples from the model.
+*  Adversarial training, where no backprop should happen through the adversarial
+   example generation process.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> {
   let summary = "Return a strided slice from `input`.";
 
@@ -3176,6 +3610,31 @@
   TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> {
+  let summary = [{
+Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`.
+  }];
+
+  let description = [{
+tensor: The tensor to put on the list.
+input_handle: The old list.
+output_handle: A list with the elements of the old list followed by tensor.
+element_dtype: the type of elements in the list.
+element_shape: a shape compatible with that of elements in the list.
+  }];
+
+  let arguments = (ins
+    TF_VariantTensor:$input_handle,
+    TF_Tensor:$tensor
+  );
+
+  let results = (outs
+    TF_VariantTensor:$output_handle
+  );
+
+  TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_TensorListSetItemOp : TF_Op<"TensorListSetItem", [NoSideEffect]> {
   let summary = "";
 
@@ -3220,6 +3679,30 @@
   TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
+  let summary = "Constructs a tensor by tiling a given tensor.";
+
+  let description = [{
+This operation creates a new tensor by replicating `input` `multiples` times.
+The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
+and the values of `input` are replicated `multiples[i]` times along the 'i'th
+dimension. For example, tiling `[a b c d]` by `[2]` produces
+`[a b c d a b c d]`.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+    TF_I32OrI64Tensor:$multiples
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
   let summary = [{
 Finds values and indices of the `k` largest elements for the last dimension.
@@ -3379,6 +3862,82 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> {
+  let summary = "Returns locations of nonzero / true values in a tensor.";
+
+  let description = [{
+This operation returns the coordinates of true elements in `condition`. The
+coordinates are returned in a 2-D tensor where the first dimension (rows)
+represents the number of true elements, and the second dimension (columns)
+represents the coordinates of the true elements. Keep in mind, the shape of
+the output tensor can vary depending on how many true values there are in
+`condition`. Indices are output in row-major order.
+
+For example:
+
+```
+# 'input' tensor is [[True, False]
+#                    [True, False]]
+# 'input' has two true values, so output has two coordinates.
+# 'input' has rank of 2, so coordinates have two indices.
+where(input) ==> [[0, 0],
+                  [1, 0]]
+
+# `condition` tensor is [[[True, False]
+#                     [True, False]]
+#                    [[False, True]
+#                     [False, True]]
+#                    [[False, False]
+#                     [False, True]]]
+# 'input' has 5 true values, so output has 5 coordinates.
+# 'input' has rank of 3, so coordinates have three indices.
+where(input) ==> [[0, 0, 0],
+                  [0, 1, 0],
+                  [1, 0, 1],
+                  [1, 1, 1],
+                  [2, 1, 1]]
+
+# `condition` tensor is [[[1.5,  0.0]
+#                     [-0.5, 0.0]]
+#                    [[0.0,  0.25]
+#                     [0.0,  0.75]]
+#                    [[0.0,  0.0]
+#                     [0.0,  0.01]]]
+# 'input' has 5 nonzero values, so output has 5 coordinates.
+# 'input' has rank of 3, so coordinates have three indices.
+where(input) ==> [[0, 0, 0],
+                  [0, 1, 0],
+                  [1, 0, 1],
+                  [1, 1, 1],
+                  [2, 1, 1]]
+
+# `condition` tensor is [[[1.5 + 0.0j, 0.0  + 0.0j]
+#                     [0.0 + 0.5j, 0.0  + 0.0j]]
+#                    [[0.0 + 0.0j, 0.25 + 1.5j]
+#                     [0.0 + 0.0j, 0.75 + 0.0j]]
+#                    [[0.0 + 0.0j, 0.0  + 0.0j]
+#                     [0.0 + 0.0j, 0.01 + 0.0j]]]
+# 'input' has 5 nonzero magnitude values, so output has 5 coordinates.
+# 'input' has rank of 3, so coordinates have three indices.
+where(input) ==> [[0, 0, 0],
+                  [0, 1, 0],
+                  [1, 0, 1],
+                  [1, 1, 1],
+                  [2, 1, 1]]
+```
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input
+  );
+
+  let results = (outs
+    I64Tensor:$index
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_XdivyOp : TF_Op<"Xdivy", [Broadcastable, NoSideEffect]>,
                  WithBroadcastableBinOpBuilder {
   let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise.";
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index f374b6b..ca6e181 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -16,7 +16,8 @@
 // This is the base operation definition file for TensorFlow.
 //
 // This file includes the definition for the TensorFlow dialect, base TensorFlow
-// op, and various commonly used TensorFlow types, attributes, and builders.
+// op, and various commonly used TensorFlow traits, types, attributes, and
+// builders.
 
 #ifdef TF_OP_BASE
 #else
@@ -51,6 +52,16 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TensorFlow traits
+//===----------------------------------------------------------------------===//
+
+// Specify this trait if the op requires all outputs to have the same type and
+// the inputs either have the same type as result or a ref type corresponding to
+// the result type.
+def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
+  "TF::OperandsSameAsResultsTypeOrRef">;
+
+//===----------------------------------------------------------------------===//
 // TensorFlow op definitions
 //===----------------------------------------------------------------------===//
 
@@ -80,7 +91,30 @@
 
 def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
 
-def TF_Int : IntOfWidths<[8, 16, 32, 64]>;
+def TF_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
+                    "TensorFlow uint8 type">,
+               BuildableType<"getType<mlir::TF::Uint8Type>()">;
+
+def TF_Uint16 : Type<CPred<"$_self.isa<mlir::TF::Uint16Type>()">,
+                     "TensorFlow uint16 type">,
+               BuildableType<"getType<mlir::TF::Uint16Type>()">;
+
+def TF_Uint32 : Type<CPred<"$_self.isa<mlir::TF::Uint32Type>()">,
+                     "TensorFlow uint32 type">,
+               BuildableType<"getType<mlir::TF::Uint32Type>()">;
+
+def TF_Uint64 : Type<CPred<"$_self.isa<mlir::TF::Uint64Type>()">,
+                     "TensorFlow uint64 type">,
+                BuildableType<"getType<mlir::TF::Uint64Type>()">;
+
+// Any unsigned integer type
+def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>;
+
+// Any signed integer type
+def TF_SInt : IntOfWidths<[8, 16, 32, 64]>;
+
+// Any integer type
+def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>;
 
 // Any integer tensor types
 def TF_IntTensor : TensorOf<[TF_Int]>;
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 41e168b..f3308a5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -42,10 +42,6 @@
 namespace mlir {
 namespace TF {
 
-namespace {
-#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
-}  // namespace
-
 //===----------------------------------------------------------------------===//
 // TF op helper functions
 //===----------------------------------------------------------------------===//
@@ -75,10 +71,11 @@
     return ranked_type.getRank() >= rank;
   return type.isa<UnrankedTensorType>();
 }
+
 // Returns true if the given pair of TensorFlow types can be cast to one
 // another. In other words, a single run-time value is legal for both the types.
 // For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
-bool AreCastCompatible(Type a, Type b) {
+static bool AreCastCompatible(Type a, Type b) {
   if (TensorCastOp::areCastCompatible(a, b)) return true;
 
   // Variant types may optionally contain subtypes information that need not
@@ -89,13 +86,27 @@
          getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
 }
 
+// Returns either the element type or type of the result of a single result
+// operation.
+// TODO(antiagainst): We need an overload function, which mandates function
+// name. This is temporary. Remove this post variadic operand support is
+// improved.
+static Type getElementTypeOrSelf(Operation *op) {
+  if (op->getNumResults() != 1) return {};
+  return getElementTypeOrSelf(op->getResult(0));
+}
+
+namespace {
+#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
+}  // namespace
+
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//
 
 void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  RewriteListBuilder<AddToAddV2>::build(results, context);
+  results.insert<AddToAddV2>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -104,7 +115,36 @@
 
 void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context);
+  results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AssertOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Removes Assert with constant true predicate.
+struct AssertWithTrue : public OpRewritePattern<AssertOp> {
+  using OpRewritePattern<AssertOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AssertOp op,
+                                     PatternRewriter &rewriter) const override {
+    ElementsAttr cst;
+    if (matchPattern(op.condition(), m_Constant(&cst))) {
+      if (cst.getValue({}).cast<BoolAttr>().getValue()) {
+        rewriter.replaceOp(op, llvm::None);
+        return matchSuccess();
+      }
+    }
+    return matchFailure();
+  }
+};
+}  // namespace
+
+void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<AssertWithTrue>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +153,7 @@
 
 void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context);
+  results.insert<BitcastSameType, BitcastNested>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -134,7 +174,7 @@
 
 void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
-  RewriteListBuilder<CastSameType>::build(results, context);
+  results.insert<CastSameType>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -143,7 +183,7 @@
 
 void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
-  RewriteListBuilder<ConjNested>::build(results, context);
+  results.insert<ConjNested>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -199,7 +239,23 @@
 
 void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
+  results.insert<DivWithSqrtDivisor>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorListOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(EmptyTensorListOp op) {
+  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
+      !IsOfRankOrUnranked(op.element_shape(), 1)) {
+    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
+  }
+
+  if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
+    return op.emitOpError("requires max_num_elements operand to be 0D tensor");
+  }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,7 +415,7 @@
 
 void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  RewriteListBuilder<InvertNested>::build(results, context);
+  results.insert<InvertNested>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -393,7 +449,7 @@
 
 void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  RewriteListBuilder<LogOfSoftmax>::build(results, context);
+  results.insert<LogOfSoftmax>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -402,10 +458,9 @@
 
 void LogicalNotOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
-                     LogicalNotOfGreater, LogicalNotOfGreaterEqual,
-                     LogicalNotOfLess, LogicalNotOfLessEqual>::build(results,
-                                                                     context);
+  results.insert<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
+                 LogicalNotOfGreater, LogicalNotOfGreaterEqual,
+                 LogicalNotOfLess, LogicalNotOfLessEqual>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -414,7 +469,7 @@
 
 void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  RewriteListBuilder<NegNested>::build(results, context);
+  results.insert<NegNested>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -423,7 +478,7 @@
 
 void ReciprocalOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  RewriteListBuilder<ReciprocalNested>::build(results, context);
+  results.insert<ReciprocalNested>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -482,7 +537,7 @@
 
 void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context);
+  results.insert<RealDivWithSqrtDivisor>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -493,6 +548,7 @@
 // m_Constant.
 static LogicalResult Verify(ReshapeOp op) {
   auto shapeType = op.shape()->getType().cast<TensorType>();
+  if (!shapeType.hasRank()) return success();
   if (shapeType.getRank() != 1)
     return op.emitOpError("shape must be 1D tensor");
   auto rankByShape = shapeType.getShape()[0];
@@ -641,7 +697,7 @@
 
 void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  RewriteListBuilder<SquareOfSub>::build(results, context);
+  results.insert<SquareOfSub>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -650,7 +706,7 @@
 
 void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  RewriteListBuilder<SubOfNeg>::build(results, context);
+  results.insert<SubOfNeg>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -721,7 +777,7 @@
 
 void TruncateDivOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context);
+  results.insert<TruncateDivWithSqrtDivisor>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -731,15 +787,23 @@
 static LogicalResult Verify(WhileOp op) {
   auto module = op.getParentOfType<ModuleOp>();
   auto condFn = module.lookupSymbol<FuncOp>(op.cond());
+  auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
+  if (!condFn) {
+    return op.emitOpError("cond refers to an undefined function : ")
+           << op.cond();
+  }
+  if (!bodyFn) {
+    return op.emitOpError("body refers to an undefined function : ")
+           << op.body();
+  }
+
   auto condFuncType = condFn.getType();
+  auto bodyFuncType = bodyFn.getType();
 
   // Verify that the cond function has exactly one result.
   if (condFuncType.getNumResults() != 1)
     return op.emitOpError("requires cond function to have exactly one result");
 
-  auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
-  auto bodyFuncType = bodyFn.getType();
-
   SmallVector<Type, 4> operands(op.getOperandTypes());
   SmallVector<Type, 4> results(op.getResultTypes());
 
@@ -810,7 +874,7 @@
 
 void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context);
+  results.insert<XdivyWithSqrtDivisor>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
index fff2ffa..8a2fa9d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
@@ -28,6 +28,7 @@
 #include "mlir/IR/OpDefinition.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index a803826..d889a5d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -30,6 +30,37 @@
 
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
 
+class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
+  let results = (outs
+    TF_VariantTensor:$handle
+  );
+
+  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
+
+  let verifier = [{
+    if (handle_dtype().getSubtypes().size() != 1) {
+      return emitOpError(
+          "must have exactly one subtype in the result variant type");
+    }
+
+    return Verify(*this);
+  }];
+
+  DerivedTypeAttr element_dtype = DerivedTypeAttr<
+      "return getElementTypeOrSelf(element_type());">;
+
+  let extraClassDeclaration = [{
+    // Returns type of the TensorList element produced by this op.
+    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
+
+    // Returns data type of the result handle. Returned type contains type of
+    // the TensorList element as a subtype.
+    VariantType handle_dtype() {
+      return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
+    }
+  }];
+}
+
 // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
 // its type encoding the tensor's shape and data type.
 def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> {
@@ -55,12 +86,30 @@
   let hasFolder = 1;
 }
 
+def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
+  let summary = "Creates and returns an empty tensor list.";
+
+  let description = [{
+All list elements must be tensors of dtype element_dtype and shape compatible
+with element_shape.
+
+handle: an empty tensor list.
+element_dtype: the type of elements in the list.
+element_shape: a shape compatible with that of elements in the list.
+  }];
+
+  let arguments = (ins
+    TF_I32OrI64Tensor:$element_shape,
+    I32Tensor:$max_num_elements
+  );
+}
+
 // TODO(fengliuai): The tf.Identity is side-effect free and it doesn't change
 // the status of the system during the execution. However it shouldn't be folded
 // in general if it used to serve for caching and some other invariant checks,
 // so we removed the side-effect free property in the op definition. This is a
 // hack, and we should fix it if we have a better way to model it.
-def TF_IdentityOp : TF_Op<"Identity", [SameOperandsAndResultType]> {
+def TF_IdentityOp : TF_Op<"Identity", [TF_OperandsSameAsResultsTypeOrRef]> {
   let summary = "Identity op";
 
   let description = [{
@@ -191,51 +240,6 @@
   TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
 }
 
-def TF_TensorListReserveOp : TF_Op<"TensorListReserve", [NoSideEffect]> {
-  let summary = "List of the given size with empty elements.";
-
-  let description = [{
-element_shape: the shape of the future elements of the list
-num_elements: the number of elements to reserve
-handle: the output list
-element_dtype: the desired type of elements in the list.
-  }];
-
-  let arguments = (ins
-    TF_I32OrI64Tensor:$element_shape,
-    I32Tensor:$num_elements
-  );
-
-  let results = (outs
-    TF_VariantTensor:$handle
-  );
-
-  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
-
-  let verifier = [{
-    if (handle_dtype().getSubtypes().size() != 1) {
-      return emitOpError(
-          "must have exactly one subtype in the result variant type");
-    }
-
-    return Verify(*this);
-  }];
-
-  DerivedTypeAttr element_dtype = DerivedTypeAttr<
-      "return getElementTypeOrSelf(element_type());">;
-
-  let extraClassDeclaration = [{
-    // Returns type of the TensorList element produced by this op.
-    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
-
-    // Returns data type of the result handle. Returned type contains type of
-    // the TensorList element as a subtype.
-    VariantType handle_dtype() {
-      return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
-    }
-  }];
-}
-
 def TF_WhileOp : TF_Op<"While", []> {
   let summary = [{
 output = input; While (Cond(output)) { output = Body(output) }
@@ -264,7 +268,11 @@
     SymbolRefAttr:$cond,
     SymbolRefAttr:$body,
     DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes,
-    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
+    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
+
+    // Used to map StatelessWhile and While op defined in TensorFlow to a common
+    // op.
+    BoolAttr:$is_stateless
   );
 
   let results = (outs
@@ -278,4 +286,20 @@
   }];
 }
 
+def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
+  let summary = "List of the given size with empty elements.";
+
+  let description = [{
+element_shape: the shape of the future elements of the list
+num_elements: the number of elements to reserve
+handle: the output list
+element_dtype: the desired type of elements in the list.
+  }];
+
+  let arguments = (ins
+    TF_I32OrI64Tensor:$element_shape,
+    I32Tensor:$num_elements
+  );
+}
+
 #endif // TF_OPS
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
new file mode 100644
index 0000000..b96026c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
@@ -0,0 +1,109 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the op traits used in the MLIR TensorFlow dialect.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace TF {
+
+// Verifies if 'ref_type' is a REF type corresponding to 'type'.
+static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
+                                               mlir::Type ref_type) {
+  auto ref_type_kind = ref_type.getKind();
+  switch (type.getKind()) {
+    case mlir::StandardTypes::F16:
+      return success(ref_type_kind == mlir::TF::TensorFlowTypes::HALF_REF);
+    case mlir::StandardTypes::F32:
+      return success(ref_type_kind == mlir::TF::TensorFlowTypes::FLOAT_REF);
+    case mlir::StandardTypes::F64:
+      return success(ref_type_kind == mlir::TF::TensorFlowTypes::DOUBLE_REF);
+    case mlir::StandardTypes::BF16:
+      return success(ref_type_kind == mlir::TF::TensorFlowTypes::BFLOAT16_REF);
+    case mlir::StandardTypes::Integer: {
+      const auto& itype = type.cast<mlir::IntegerType>();
+      switch (itype.getWidth()) {
+        case 1:
+          return success(ref_type_kind == mlir::TF::TensorFlowTypes::BOOL_REF);
+        case 8:
+          return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT8_REF);
+        case 16:
+          return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT16_REF);
+        case 32:
+          return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT32_REF);
+        case 64:
+          return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT64_REF);
+        default:
+          return failure();
+      }
+    }
+#define HANDLE_TF_TYPE(tftype, enumerant, name) \
+  case mlir::TF::TensorFlowTypes::enumerant:    \
+    return success(ref_type_kind == mlir::TF::TensorFlowTypes::enumerant##_REF);
+
+#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
+// NOLINTNEXTLINE
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
+    default:
+      return failure();
+  }
+}
+
+// This class provides verification for ops that are known to have the same
+// result types and all operands are either of the same type as result or a REF
+// type corresponding to the result type.
+template <typename ConcreteType>
+class OperandsSameAsResultsTypeOrRef
+    : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef> {
+ public:
+  static LogicalResult verifyTrait(Operation* op) {
+    LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
+    if (failed(shapeMatch)) return shapeMatch;
+
+    auto type = getElementTypeOrSelf(op->getResult(0)->getType());
+
+    // Verify that the first result type is same as the rest of the results.
+    // We skip the comparison against itself.
+    for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+      resultType = getElementTypeOrSelf(resultType);
+      if (resultType != type)
+        return op->emitOpError() << "requires the same type for all results";
+    }
+
+    for (auto opType : op->getOperandTypes()) {
+      opType = getElementTypeOrSelf(opType);
+      if (opType != type && failed(VerifyRefTypeMatch(type, opType))) {
+        return op->emitError() << "requires all operands to be either same "
+                                  "as or ref type of results";
+      }
+    }
+    return success();
+  }
+};
+
+}  // namespace TF
+}  // namespace OpTrait
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def
index 9f1154b..e5041d0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def
@@ -32,28 +32,33 @@
 HANDLE_TF_TYPE(Resource, RESOURCE, "resource")
 HANDLE_TF_TYPE(Complex64, COMPLEX64, "complex64")
 HANDLE_TF_TYPE(Complex128, COMPLEX128, "complex128")
-HANDLE_TF_TYPE(FloatRef, FLOAT_REF, "f32ref")
-HANDLE_TF_TYPE(DoubleRef, DOUBLE_REF, "f64ref")
-HANDLE_TF_TYPE(Uint8Ref, UINT8_REF, "uint8ref")
-HANDLE_TF_TYPE(Int8Ref, INT8_REF, "int8ref")
-HANDLE_TF_TYPE(Uint16Ref, UINT16_REF, "uint16ref")
-HANDLE_TF_TYPE(Int16Ref, INT16_REF, "int16ref")
-HANDLE_TF_TYPE(Uint32Ref, UINT32_REF, "uint32ref")
-HANDLE_TF_TYPE(Int32Ref, INT32_REF, "int32ref")
-HANDLE_TF_TYPE(Uint64Ref, UINT64_REF, "uint64ref")
-HANDLE_TF_TYPE(Int64Ref, INT64_REF, "int64ref")
-HANDLE_TF_TYPE(StringRef, STRING_REF, "stringref")
-HANDLE_TF_TYPE(BoolRef, BOOL_REF, "boolref")
-HANDLE_TF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref")
-HANDLE_TF_TYPE(Qint8Ref, QINT8_REF, "qint8ref")
-HANDLE_TF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref")
-HANDLE_TF_TYPE(Qint16Ref, QINT16_REF, "qint16ref")
-HANDLE_TF_TYPE(Qint32Ref, QINT32_REF, "qint32ref")
-HANDLE_TF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref")
-HANDLE_TF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref")
-HANDLE_TF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref")
-HANDLE_TF_TYPE(HalfRef, HALF_REF, "halfref")
-HANDLE_TF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")
+
+#ifndef HANDLE_TF_REF_TYPE
+#define HANDLE_TF_REF_TYPE(class, enumerant, name) \
+  HANDLE_TF_TYPE(class, enumerant, name)
+#endif
+HANDLE_TF_REF_TYPE(FloatRef, FLOAT_REF, "f32ref")
+HANDLE_TF_REF_TYPE(DoubleRef, DOUBLE_REF, "f64ref")
+HANDLE_TF_REF_TYPE(Uint8Ref, UINT8_REF, "uint8ref")
+HANDLE_TF_REF_TYPE(Int8Ref, INT8_REF, "int8ref")
+HANDLE_TF_REF_TYPE(Uint16Ref, UINT16_REF, "uint16ref")
+HANDLE_TF_REF_TYPE(Int16Ref, INT16_REF, "int16ref")
+HANDLE_TF_REF_TYPE(Uint32Ref, UINT32_REF, "uint32ref")
+HANDLE_TF_REF_TYPE(Int32Ref, INT32_REF, "int32ref")
+HANDLE_TF_REF_TYPE(Uint64Ref, UINT64_REF, "uint64ref")
+HANDLE_TF_REF_TYPE(Int64Ref, INT64_REF, "int64ref")
+HANDLE_TF_REF_TYPE(StringRef, STRING_REF, "stringref")
+HANDLE_TF_REF_TYPE(BoolRef, BOOL_REF, "boolref")
+HANDLE_TF_REF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref")
+HANDLE_TF_REF_TYPE(Qint8Ref, QINT8_REF, "qint8ref")
+HANDLE_TF_REF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref")
+HANDLE_TF_REF_TYPE(Qint16Ref, QINT16_REF, "qint16ref")
+HANDLE_TF_REF_TYPE(Qint32Ref, QINT32_REF, "qint32ref")
+HANDLE_TF_REF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref")
+HANDLE_TF_REF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref")
+HANDLE_TF_REF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref")
+HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref")
+HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")
 
 #ifndef HANDLE_CUSTOM_TF_TYPE
 #define HANDLE_CUSTOM_TF_TYPE(class, enumerant, name) \
@@ -64,10 +69,11 @@
 
 #ifndef HANDLE_LAST_TF_TYPE
 #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \
-  HANDLE_TF_TYPE(class, enumerant, name)
+  HANDLE_TF_REF_TYPE(class, enumerant, name)
 #endif
 HANDLE_LAST_TF_TYPE(VariantRef, VARIANT_REF, "variantref")
 #undef HANDLE_LAST_TF_TYPE
 
+#undef HANDLE_TF_REF_TYPE
 #undef HANDLE_TF_TYPE
 #endif
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index ffd6bee..65feaa8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -1,5 +1,21 @@
 // RUN: tf-opt %s -canonicalize | FileCheck %s
 
+// CHECK-LABEL: func @tfAssertTrue
+func @tfAssertTrue(%arg0: tensor<1x1x6x2xf32>) {
+  %t = constant dense<true> : tensor<i1>
+  // CHECK-NOT: tf.Assert
+  "tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<1x1x6x2xf32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @tfAssertFalse
+func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) {
+  %f = constant dense<false> : tensor<i1>
+  // CHECK: tf.Assert
+  "tf.Assert"(%f, %arg0) {summarize = 3} : (tensor<i1>, tensor<1x1x6x2xf32>) -> ()
+  return
+}
+
 // CHECK-LABEL: func @testLeakyRelu
 func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>) {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 1.0 : f32} : (tensor<16xf32>) -> tensor<16xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
new file mode 100644
index 0000000..f879767
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
@@ -0,0 +1,112 @@
+// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s
+
+// Tests simple case of a single `tf_device.launch`.
+
+module {
+  // CHECK-LABEL: func @multiplelaunches
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
+  func @multiplelaunches(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = tf_executor.graph {
+      %1:2 = tf_executor.island {
+        // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
+        %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+        // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func}
+        %3 = "tf_device.launch"() ( {
+          %4 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
+          "tf_device.return"(%4) : (tensor<?xi32>) -> ()
+        }) {device = "tpu0"} : () -> tensor<?xi32>
+
+        // CHECK: tf_executor.yield %[[C_OUTPUT]]
+        tf_executor.yield %3 : tensor<?xi32>
+      }
+      tf_executor.fetch %1#0 : tensor<?xi32>
+    }
+    return %0 : tensor<?xi32>
+  }
+
+// CHECK-LABEL: func @tpu0_func
+// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]])
+// CHECK: return %[[TPU0_FUNC_B_OUTPUT]]
+}
+
+// -----
+
+// Tests that multiple `tf_device.launch` that depend on each other are
+// correctly handled.
+
+module {
+  // CHECK-LABEL: func @multiplelaunches
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
+  func @multiplelaunches(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = tf_executor.graph {
+      %1:2 = tf_executor.island {
+        // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
+        %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+        // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func}
+        %3 = "tf_device.launch"() ( {
+          %6 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
+          "tf_device.return"(%6) : (tensor<?xi32>) -> ()
+        }) {device = "tpu0"} : () -> tensor<?xi32>
+
+        // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
+        %4 = "tf.D"(%3) : (tensor<?xi32>) -> tensor<?xi32>
+
+        // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[C_OUTPUT]], %[[D_OUTPUT]]) {device = "gpu0", func = @gpu0_func}
+        %5 = "tf_device.launch"() ( {
+          %6 = "tf.E"(%3) : (tensor<?xi32>) -> tensor<?xi32>
+          %7 = "tf.F"(%4, %6) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+          "tf_device.return"(%7) : (tensor<?xi32>) -> ()
+        }) {device = "gpu0"} : () -> tensor<?xi32>
+
+        // CHECK: tf_executor.yield %[[E_OUTPUT]]
+        tf_executor.yield %5 : tensor<?xi32>
+      }
+      tf_executor.fetch %1#0 : tensor<?xi32>
+    }
+    return %0 : tensor<?xi32>
+  }
+
+// CHECK-LABEL: func @tpu0_func
+// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]])
+// CHECK: return %[[TPU0_FUNC_B_OUTPUT]]
+
+// CHECK-LABEL: func @gpu0_func
+// CHECK-SAME: (%[[GPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor<?xi32>, %[[GPU0_FUNC_ARG_1:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %[[GPU0_FUNC_E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_FUNC_ARG_0]])
+// CHECK: %[[GPU0_FUNC_F_OUTPUT:[0-9]*]] = "tf.F"(%[[GPU0_FUNC_ARG_1]], %[[GPU0_FUNC_E_OUTPUT]])
+// CHECK: return %[[GPU0_FUNC_F_OUTPUT]]
+}
+
+// -----
+
+// Tests outlining launches with no live-in values.
+
+module {
+  // CHECK-LABEL: func @multiplelaunches
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
+  func @multiplelaunches(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = tf_executor.graph {
+      %1:2 = tf_executor.island {
+        // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func}
+        %2 = "tf_device.launch"() ( {
+          %3 = "tf.A"() : () -> tensor<?xi32>
+          "tf_device.return"(%3) : (tensor<?xi32>) -> ()
+        }) {device = "tpu0"} : () -> tensor<?xi32>
+
+        // CHECK: tf_executor.yield %[[A_OUTPUT]]
+        tf_executor.yield %2 : tensor<?xi32>
+      }
+      tf_executor.fetch %1#0 : tensor<?xi32>
+    }
+    return %0 : tensor<?xi32>
+  }
+
+// CHECK-LABEL: func @tpu0_func
+// CHECK-SAME: () -> tensor<?xi32>
+// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
+// CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 51aaf6e..d8a1ce6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -42,3 +42,24 @@
   // CHECK-DAG: constant dense<0.242886767> : tensor<1x1x6x2xf32>
   return %0, %21 : tensor<4xf32>, tensor<1x1x6x2xf32>
 }
+
+// CHECK-LABEL: func @testAdd() -> tensor<2x2xi32>
+func @testAdd() -> tensor<2x2xi32> {
+^bb0:
+  %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+  %1 = constant dense<1> : tensor<2xi32>
+  %2 = "tf.Add"(%0, %1) {device = "", name = "add"} : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  // CHECK:         [[cst:%.*]] = constant dense<{{\[\[}}1, 2], {{\[}}3, 4]]> : tensor<2x2xi32>
+  // CHECK-NEXT:    return [[cst]] : tensor<2x2xi32>
+  return %2: tensor<2x2xi32>
+}
+
+// Ops with side effects should not get constant folded.
+// CHECK-LABEL: func @testSideEffectOp() -> tensor<3xf32>
+func @testSideEffectOp() -> tensor<3xf32> {
+  %0 = constant dense<[3]> : tensor<1xi32>
+  %1 = "tf.RandomUniform"(%0) {device = "", seed = 3 : i64, seed2 = 5 : i64} : (tensor<1xi32>) -> tensor<3xf32>
+  // CHECK: %[[random:.*]] = "tf.RandomUniform"
+  // CHECK: return %[[random]]
+  return %1: tensor<3xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir
index b1a9dd7..48f4c8f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir
@@ -79,7 +79,7 @@
 // CHECK-NEXT:       %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T =  "tfdtype$DT_INT32", device =  "", name =  "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
 // CHECK-NEXT:       tf_executor.yield %{{[0-9]*}} : tensor<*xi32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     %[[CT:[0-9]*]] = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
+// CHECK-NEXT:     %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
 // CHECK-NEXT:     tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T =  "tfdtype$DT_INT32", device =  "", id = 0 : i64, name =  "while/NextIteration"}
 // CHECK-NEXT:     tf_executor.fetch
 // CHECK-NEXT:   }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir
new file mode 100644
index 0000000..4a4aa27
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir
@@ -0,0 +1,15 @@
+// RUN: tf-opt -tf-executor-to-control-conversion %s  | FileCheck %s --check-prefix=CONTROL --dump-input=fail
+// RUN: tf-opt -tf-control-to-executor-conversion %s  | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail
+
+// CONTROL-LABEL: func @main
+// CONTROL-NEXT:    return
+
+// EXECUTOR-LABEL: func @main
+// EXECUTOR-NEXT:    tf_executor.graph {
+// EXECUTOR-NEXT:      tf_executor.fetch
+// EXECUTOR-NEXT:    }
+// EXECUTOR-NEXT:    return
+
+func @main() {
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir
new file mode 100644
index 0000000..ba1dfd3
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir
@@ -0,0 +1,248 @@
+// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input=fail
+
+
+// Test single graph with no outputs and one island is folded away.
+// CHECK-LABEL: func @graph_with_no_outputs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @graph_with_no_outputs(%arg0 : tensor<i1>) {
+  tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3 : tensor<i1>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT: "tf.opB"(%[[OP_A]])
+// CHECK-NEXT: return
+
+
+// Test single graph with some outputs and one island is folded away.
+// CHECK-LABEL: func @graph_with_outputs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @graph_with_outputs(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:3 = tf_executor.graph {
+    %1:4 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      %5 = "tf.opC"(%4) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3, %5, %4 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %1#1, %1#0, %1#2 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  return %0#2, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT: "tf.opC"(%[[OP_B]])
+// CHECK-NEXT: return %[[OP_B]], %[[OP_A]] : tensor<i1>, tensor<i1>
+
+
+// Test nested graphs and islands.
+// CHECK-LABEL: func @nested_graph
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @nested_graph(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:3 = tf_executor.graph {
+    %1:4 = tf_executor.island {
+      %2:3 = tf_executor.graph {
+        %3:4 = tf_executor.island {
+          %4 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+          %5 = "tf.opB"(%4) : (tensor<i1>) -> tensor<i1>
+          %6 = "tf.opC"(%5) : (tensor<i1>) -> tensor<i1>
+          tf_executor.yield %4, %6, %5 : tensor<i1>, tensor<i1>, tensor<i1>
+        }
+        tf_executor.fetch %3#2, %3#0, %3#1 : tensor<i1>, tensor<i1>, tensor<i1>
+      }
+      tf_executor.yield %2#1, %2#1, %2#0 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %1#1, %1#0, %1#2 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  return %0#2, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT: "tf.opC"(%[[OP_B]])
+// CHECK-NEXT: return %[[OP_B]], %[[OP_A]] : tensor<i1>, tensor<i1>
+
+
+// Test single graph with multiple islands is unmodified.
+// CHECK-LABEL: func @graph_with_multiple_islands
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @graph_with_multiple_islands(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:3 = tf_executor.graph {
+    %1:4 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      %5 = "tf.opC"(%4) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3, %5, %4 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    %6:3 = tf_executor.island {
+      %7 = "tf.opD"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %8 = "tf.opE"(%7) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %8, %7 : tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %1#1, %1#0, %6#0 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  return %0#2, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[GRAPH:[0-9]*]]:3 = tf_executor.graph {
+// CHECK-NEXT:   %[[ISLAND_0:[0-9]*]]:4 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]], %[[OP_C]], %[[OP_B]] : tensor<i1>, tensor<i1>, tensor<i1>
+// CHECK:        %[[ISLAND_1:[0-9]*]]:3 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_E]], %[[OP_D]] : tensor<i1>, tensor<i1>
+// CHECK:        tf_executor.fetch %[[ISLAND_0]]#1, %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i1>, tensor<i1>, tensor<i1>
+// CHECK:      return %[[GRAPH]]#2, %[[GRAPH]]#1 : tensor<i1>, tensor<i1>
+
+
+// Test single graph with an island and executor ops is unmodified.
+// CHECK-LABEL: func @graph_with_island_and_executor_op
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @graph_with_island_and_executor_op(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:3 = tf_executor.graph {
+    %1:4 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      %5 = "tf.opC"(%4) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3, %5, %4 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    %6:2 = tf_executor.LoopCond %1#0 : tensor<i1>
+    tf_executor.fetch %1#1, %1#0, %6#0 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  return %0#2, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[GRAPH:[0-9]*]]:3 = tf_executor.graph {
+// CHECK-NEXT:   %[[ISLAND:[0-9]*]]:4 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]], %[[OP_C]], %[[OP_B]] : tensor<i1>, tensor<i1>, tensor<i1>
+// CHECK:        %[[LOOP_COND:[0-9]*]]:2 = tf_executor.LoopCond %[[ISLAND]]#0
+// CHECK-NEXT:   tf_executor.fetch %[[ISLAND]]#1, %[[ISLAND]]#0, %[[LOOP_COND]]#0 : tensor<i1>, tensor<i1>, tensor<i1>
+// CHECK:      return %[[GRAPH]]#2, %[[GRAPH]]#1 : tensor<i1>, tensor<i1>
+
+
+// Test multiple graphs collapsed.
+// CHECK-LABEL: func @multiple_graphs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @multiple_graphs(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
+  %0:4 = tf_executor.graph {
+    %2:4 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      %5 = "tf.opC"(%4) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3, %5, %4 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %arg0, %2#0, %2#1, %2#2 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  %1:3 = tf_executor.graph {
+    %6:3 = tf_executor.island {
+      %7 = "tf.opD"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %8 = "tf.opE"(%7) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %8, %7 : tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %arg0, %6#0, %6#1 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  return %1#1, %1#0, %1#2, %0#1, %0#0, %0#3 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT: "tf.opC"(%[[OP_B]])
+// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]])
+// CHECK-NEXT: return %[[OP_E]], %[[ARG_0]], %[[OP_D]], %[[OP_A]], %[[ARG_0]], %[[OP_B]] : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+
+
+// Test empty graph with no outputs.
+// CHECK-LABEL: func @empty_graph_with_no_outputs
+func @empty_graph_with_no_outputs() {
+  tf_executor.graph {
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK-NEXT: return
+
+
+// Test empty graph with some outputs.
+// CHECK-LABEL: func @empty_graph_with_outputs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
+func @empty_graph_with_outputs(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:2 = tf_executor.graph {
+    tf_executor.fetch %arg1, %arg0 : tensor<i1>, tensor<i1>
+  }
+  return %0#0, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: return %[[ARG_1]], %[[ARG_0]] : tensor<i1>, tensor<i1>
+
+
+// Test multiple empty graphs.
+// CHECK-LABEL: func @empty_graphs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
+func @empty_graphs(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0 = tf_executor.graph {
+    tf_executor.fetch %arg1 : tensor<i1>
+  }
+  tf_executor.graph {
+    tf_executor.fetch
+  }
+  %1 = tf_executor.graph {
+    tf_executor.fetch %arg0 : tensor<i1>
+  }
+  return %0, %1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: return %[[ARG_1]], %[[ARG_0]] : tensor<i1>, tensor<i1>
+
+
+// Test empty graphs and graphs with a single island.
+// CHECK-LABEL: func @empty_and_filled_graphs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @empty_and_filled_graphs(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
+  %0:4 = tf_executor.graph {
+    %2:4 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %4 = "tf.opB"(%3) : (tensor<i1>) -> tensor<i1>
+      %5 = "tf.opC"(%4) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3, %5, %4 : tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %arg0, %2#0, %2#1, %2#2 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  tf_executor.graph {
+    tf_executor.fetch
+  }
+  %1:3 = tf_executor.graph {
+    %6:3 = tf_executor.island {
+      %7 = "tf.opD"(%arg0) : (tensor<i1>) -> tensor<i1>
+      %8 = "tf.opE"(%7) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %8, %7 : tensor<i1>, tensor<i1>
+    }
+    tf_executor.fetch %arg0, %6#0, %6#1 : tensor<i1>, tensor<i1>, tensor<i1>
+  }
+  %9 = tf_executor.graph {
+    tf_executor.fetch %arg0 : tensor<i1>
+  }
+  return %1#1, %1#0, %9, %0#1, %0#0, %0#3 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+}
+
+// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT: "tf.opC"(%[[OP_B]])
+// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]])
+// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]])
+// CHECK-NEXT: return %[[OP_E]], %[[ARG_0]], %[[ARG_0]], %[[OP_A]], %[[ARG_0]], %[[OP_B]] : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir
new file mode 100644
index 0000000..a9e83dd
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir
@@ -0,0 +1,460 @@
+// RUN: tf-opt %s -tf-executor-island-coarsening | FileCheck %s --dump-input=fail
+
+
+// Test that islands linked by a control dependency are merged.
+// CHECK-LABEL: func @control_input
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @control_input(%arg0 : tensor<i1>) -> tensor<f32> {
+  %0 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3 : tensor<i1>
+    }
+    %2:2 = tf_executor.island(%1#1) {
+      %4 = "tf.opB"() : () -> tensor<f32>
+      tf_executor.yield %4 : tensor<f32>
+    }
+    tf_executor.fetch %2#0 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"
+// CHECK-NEXT:     tf_executor.yield %[[OP_B]] : tensor<f32>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0 : tensor<f32>
+
+
+// Test that islands linked by a data dependency are merged.
+// CHECK-LABEL: func @data_input
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
+func @data_input(%arg0 : tensor<i1>) -> tensor<i1> {
+  %0 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3 : tensor<i1>
+    }
+    %2:2 = tf_executor.island {
+      %4 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %4 : tensor<i1>
+    }
+    tf_executor.fetch %2#0 : tensor<i1>
+  }
+  return %0 : tensor<i1>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_B]] : tensor<i1>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0 : tensor<i1>
+
+
+// Test empty/trivial islands are merged.
+// CHECK-LABEL: func @empty_islands
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
+func @empty_islands(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:2 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      tf_executor.yield %arg1 : tensor<i1>
+    }
+    %2:2 = tf_executor.island {
+      tf_executor.yield %arg0 : tensor<i1>
+    }
+    %3:2 = tf_executor.island {
+      tf_executor.yield %1#0 : tensor<i1>
+    }
+    %4:2 = tf_executor.island {
+      tf_executor.yield %2#0 : tensor<i1>
+    }
+    %5:3 = tf_executor.island {
+      %10:2 = "tf.opA"(%3#0, %4#0) : (tensor<i1>, tensor<i1>) -> (tensor<i1>, tensor<i1>)
+      tf_executor.yield %10#0, %10#1 : tensor<i1>, tensor<i1>
+    }
+    %6:2 = tf_executor.island {
+      tf_executor.yield %5#0 : tensor<i1>
+    }
+    %7:2 = tf_executor.island {
+      tf_executor.yield %5#1 : tensor<i1>
+    }
+    %8:3 = tf_executor.island {
+      tf_executor.yield %6#0, %7#0 : tensor<i1>, tensor<i1>
+    }
+    %9 = tf_executor.island(%8#2) {
+      tf_executor.yield
+    }
+    tf_executor.fetch %8#0, %8#1 : tensor<i1>, tensor<i1>
+  }
+  return %0#0, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:3 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]]:2 = "tf.opA"(%[[ARG_1]], %[[ARG_0]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]]#0, %[[OP_A]]#1 : tensor<i1>, tensor<i1>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<i1>, tensor<i1>
+
+
+// Test merging islands handle merging results.
+// CHECK-LABEL: func @multiple_outputs
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
+func @multiple_outputs(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
+  %0:2 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %3 : tensor<i1>
+    }
+    %2:2 = tf_executor.island(%1#1) {
+      %4 = "tf.opB"(%arg1) : (tensor<i1>) -> tensor<i1>
+      tf_executor.yield %4 : tensor<i1>
+    }
+    tf_executor.fetch %1#0, %2#0 : tensor<i1>, tensor<i1>
+  }
+  return %0#0, %0#1 : tensor<i1>, tensor<i1>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:3 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_1]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]], %[[OP_B]] : tensor<i1>, tensor<i1>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<i1>, tensor<i1>
+
+
+// Test merging islands with multiple inner ops.
+// CHECK-LABEL: func @multi_op_regions
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i32>, %[[ARG_1:[a-z0-9]*]]: tensor<i32>)
+func @multi_op_regions(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> tensor<i32> {
+  %0 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %2 = "tf.opA"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %3 = "tf.opB"(%2, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tf_executor.yield %3 : tensor<i32>
+    }
+    %4:2 = tf_executor.island {
+      %5 = "tf.opC"(%1#0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %6 = "tf.opD"(%5, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tf_executor.yield %6 : tensor<i32>
+    }
+    tf_executor.fetch %4#0 : tensor<i32>
+  }
+  return %0 : tensor<i32>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]], %[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]], %[[ARG_0]])
+// CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]], %[[ARG_1]])
+// CHECK-NEXT:     %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]], %[[ARG_0]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_D]] : tensor<i32>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0 : tensor<i32>
+
+
+// Test merging multiple islands with multiple inner ops preserves order.
+// CHECK-LABEL: func @transitive_preserve_order
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i32>, %[[ARG_1:[a-z0-9]*]]: tensor<i32>)
+func @transitive_preserve_order(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> tensor<i32> {
+  %0 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %2 = "tf.opA"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %3 = "tf.opB"(%2, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tf_executor.yield %3 : tensor<i32>
+    }
+    %4:2 = tf_executor.island {
+      %5 = "tf.opC"(%1#0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %6 = "tf.opD"(%5, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tf_executor.yield %6 : tensor<i32>
+    }
+    %7:2 = tf_executor.island {
+      %8 = "tf.opE"(%4#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %9 = "tf.opF"(%8, %8) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tf_executor.yield %9 : tensor<i32>
+    }
+    tf_executor.fetch %7#0 : tensor<i32>
+  }
+  return %0 : tensor<i32>
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]], %[[ARG_0]])
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]], %[[ARG_0]])
+// CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]], %[[ARG_1]])
+// CHECK-NEXT:     %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]], %[[ARG_0]])
+// CHECK-NEXT:     %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[OP_B]])
+// CHECK-NEXT:     %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]], %[[OP_E]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_F]] : tensor<i32>
+// CHECK:        tf_executor.fetch %[[ISLAND]]#0 : tensor<i32>
+
+
+// Test if islands can be merged when non dependent islands are interleaved.
+// CHECK-LABEL: func @islands_interleaved
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i32>, %[[ARG_1:[a-z0-9]*]]: tensor<i32>)
+func @islands_interleaved(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  %0:2 = tf_executor.graph {
+    %1:2 = tf_executor.island {
+      %7 = "tf.opA"(%arg0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %7 : tensor<i32>
+    }
+    %2:2 = tf_executor.island {
+      %8 = "tf.opB"(%arg1) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %8 : tensor<i32>
+    }
+    %3:2 = tf_executor.island {
+      %9 = "tf.opC"(%1#0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %9 : tensor<i32>
+    }
+    %4:2 = tf_executor.island {
+      %10 = "tf.opD"(%2#0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %10 : tensor<i32>
+    }
+    %5:2 = tf_executor.island(%3#1) {
+      %11 = "tf.opE"(%arg0) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %11 : tensor<i32>
+    }
+    %6:2 = tf_executor.island {
+      %12 = "tf.opF"(%arg1) : (tensor<i32>) -> tensor<i32>
+      tf_executor.yield %12 : tensor<i32>
+    }
+    tf_executor.fetch %4#0, %3#0 : tensor<i32>, tensor<i32>
+  }
+  return %0#0, %0#1 : tensor<i32>, tensor<i32>
+}
+
+// CHECK:        %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_1]])
+// CHECK-NEXT:     %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_D]] : tensor<i32>
+// CHECK:        %[[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
+// CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
+// CHECK-NEXT:     %{{[0-9]*}} = "tf.opE"(%[[ARG_0]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_C]] : tensor<i32>
+// CHECK:        tf_executor.island {
+// CHECK-NEXT:     %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_F]] : tensor<i32>
+// CHECK:        tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i32>, tensor<i32>
+
+
+// Test only islands are merged when other tf_executor ops are interleaved.
+// CHECK-LABEL: func @merge_islands_only
+func @merge_islands_only() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island {
+      %14 = "tf.opA"() : () -> tensor<i32>
+      tf_executor.yield %14 : tensor<i32>
+    }
+    %1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control)
+    %2 = tf_executor.island {
+      "tf.opB"() : () -> ()
+      tf_executor.yield
+    }
+    %3:3 = tf_executor.NextIteration.Source : tensor<*xi32>
+    %4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32>
+    %5:2 = tf_executor.island(%4#2) {
+      %15 = "tf.opC"() : () -> tensor<i32>
+      tf_executor.yield %15 : tensor<i32>
+    }
+    %6:2 = tf_executor.island {
+      %16 = "tf.opD"(%4#0, %5#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
+      tf_executor.yield %16 : tensor<*xi1>
+    }
+    %7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control)
+    %8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32>
+    %9:2 = tf_executor.Exit %8#0 : tensor<*xi32>
+    %10:2 = tf_executor.island {
+      %17 = "tf.opE"(%8#1) : (tensor<*xi32>) -> tensor<*xi32>
+      tf_executor.yield %17 : tensor<*xi32>
+    }
+    %11:2 = tf_executor.island(%10#1) {
+      %18 = "tf.opF"() : () -> tensor<i32>
+      tf_executor.yield %18 : tensor<i32>
+    }
+    %12:2 = tf_executor.island {
+      %19 = "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+      tf_executor.yield %19 : tensor<*xi32>
+    }
+    %13 = tf_executor.ControlTrigger %2, %12#1, %9#1
+    tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32>
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:        %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:.*]] = "tf.opA"
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]] : tensor<i32>
+// CHECK:        %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[ISLAND_0]]#0
+// CHECK-NEXT:   %[[ISLAND_1:[0-9]*]] = tf_executor.island {
+// CHECK-NEXT:     "tf.opB"()
+// CHECK-NEXT:     tf_executor.yield
+// CHECK:        %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source
+// CHECK-NEXT:   %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0
+// CHECK-NEXT:   %[[ISLAND_2:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
+// CHECK-NEXT:     %[[OP_C:.*]] = "tf.opC"
+// CHECK-NEXT:     %[[OP_D:[0-9]*]] = "tf.opD"(%[[MERGE]]#0, %[[OP_C]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_D]] : tensor<*xi1>
+// CHECK:        %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[ISLAND_2:[0-9]*]]#0
+// CHECK-NEXT:   %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0
+// CHECK-NEXT:   %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0
+// CHECK-NEXT:   %[[ISLAND_3:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_E:[0-9]*]] = "tf.opE"(%[[SWITCH]]#1)
+// CHECK-NEXT:     %[[OP_F:.*]] = "tf.opF"
+// CHECK-NEXT:     %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]])
+// CHECK-NEXT:     tf_executor.yield %[[OP_G]] : tensor<*xi32>
+// CHECK:        %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3]]#1, %[[EXIT]]#1
+// CHECK-NEXT:   tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ISLAND_3]]#0, %[[CT]]
+
+
+// Test no merging took place as cycle would be formed otherwise.
+// CHECK-LABEL: func @simple_potential_cycle
+func @simple_potential_cycle() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island {
+      %3 = "tf.opA"() : () -> tensor<1xf32>
+      tf_executor.yield %3 : tensor<1xf32>
+    }
+    %1 = tf_executor.ControlTrigger %0#1
+    %2:3 = tf_executor.island(%1) {
+      %4 = "tf.opB"() : () -> tensor<1xf32>
+      tf_executor.yield %0#0, %4 : tensor<1xf32>, tensor<1xf32>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:        %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:     %[[OP_A:[0-9]*]] = "tf.opA"
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]] : tensor<1xf32>
+// CHECK:        %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1
+// CHECK-NEXT:   tf_executor.island(%[[CT]]) {
+// CHECK-NEXT:     %[[OP_B:[0-9]*]] = "tf.opB"
+// CHECK-NEXT:     tf_executor.yield %[[ISLAND]]#0, %[[OP_B]] : tensor<1xf32>, tensor<1xf32>
+
+
+// Test if island was merged into its result.
+// CHECK-LABEL: func @merge_into_result
+func @merge_into_result() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island {
+      %3 = "tf.opA"() : () -> tensor<1xf32>
+      tf_executor.yield %3 : tensor<1xf32>
+    }
+    %1 = tf_executor.ControlTrigger {}
+    %2:3 = tf_executor.island(%1) {
+      %4 = "tf.opB"() : () -> tensor<1xf32>
+      tf_executor.yield %0#0, %4 : tensor<1xf32>, tensor<1xf32>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:        %[[CT:[0-9]*]] = tf_executor.ControlTrigger
+// CHECK-NEXT:   tf_executor.island(%[[CT]]) {
+// CHECK-NEXT:     "tf.opA"
+// CHECK-NEXT:     "tf.opB"
+// CHECK-NEXT:     tf_executor.yield
+
+
+// Test merging island into data result nested in a graph of another island.
+// CHECK-LABEL: func @merge_into_nested_data_result
+func @merge_into_nested_data_result() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island {
+      %1 = "tf.opA"() : () -> tensor<1xf32>
+      tf_executor.yield %1 : tensor<1xf32>
+    }
+    %2:2 = tf_executor.island {
+      %3 = tf_executor.graph {
+        %4 = tf_executor.ControlTrigger {}
+        %5:2 = tf_executor.island(%4) {
+          %6 = "tf.opB"(%0#0) : (tensor<1xf32>) -> tensor<1xf32>
+          tf_executor.yield %6 : tensor<1xf32>
+        }
+        tf_executor.fetch %5#0 : tensor<1xf32>
+      }
+      tf_executor.yield %3 : tensor<1xf32>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:        tf_executor.island {
+// CHECK-NEXT:     [[OP_A:[0-9*]]] = "tf.opA"
+// CHECK-NEXT:     [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
+// CHECK-NEXT:       [[CT:[0-9]*]] = tf_executor.ControlTrigger
+// CHECK-NEXT:       [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) {
+// CHECK-NEXT:         [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
+// CHECK-NEXT:         tf_executor.yield %[[OP_B]] : tensor<1xf32>
+// CHECK:            tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
+// CHECK:          tf_executor.yield
+
+
+// Test merging islands in a nested graph.
+// CHECK-LABEL: func @merge_islands_inner_graph
+func @merge_islands_inner_graph() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island {
+      %1 = "tf.opA"() : () -> tensor<1xf32>
+      tf_executor.yield %1 : tensor<1xf32>
+    }
+    %2:2 = tf_executor.island {
+      %3 = tf_executor.graph {
+        %4:2 = tf_executor.island {
+          %5 = "tf.opB"() : () -> tensor<1xf32>
+          tf_executor.yield %5 : tensor<1xf32>
+        }
+        %6:2 = tf_executor.island {
+          %7 = "tf.opC"() : () -> tensor<1xf32>
+          tf_executor.yield %7 : tensor<1xf32>
+        }
+        %8:2 = tf_executor.island(%4#1) {
+          %9 = "tf.opD"(%6#0) : (tensor<1xf32>) -> tensor<1xf32>
+          tf_executor.yield %9 : tensor<1xf32>
+        }
+        tf_executor.fetch %8#0 : tensor<1xf32>
+      }
+      tf_executor.yield %3 : tensor<1xf32>
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK:        tf_executor.island {
+// CHECK-NEXT:     [[OP_A:[0-9*]]] = "tf.opA"
+// CHECK-NEXT:     tf_executor.yield %[[OP_A]] : tensor<1xf32>
+// CHECK:        tf_executor.island {
+// CHECK-NEXT:     [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
+// CHECK-NEXT:       [[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT:         "tf.opB"
+// CHECK-NEXT:         [[OP_C:[0-9]*]] = "tf.opC"
+// CHECK-NEXT:         [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
+// CHECK-NEXT:         tf_executor.yield %[[OP_D]] : tensor<1xf32>
+// CHECK:            tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
+// CHECK:          tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32>
+
+
+// Test merging islands with control island operands and island results only if
+// they are the closest ones.
+// CHECK-LABEL: func @merge_islands_closest_control
+func @merge_islands_closest_control() {
+  tf_executor.graph {
+    %0 = tf_executor.island {
+      tf_executor.yield
+    }
+    %1 = tf_executor.ControlTrigger %0
+    %2 = tf_executor.ControlTrigger {}
+    %3 = tf_executor.island(%0, %2) {
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island {
+// CHECK: tf_executor.ControlTrigger %[[ISLAND]]
+// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
+// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
index 79f471b..2a0434b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
@@ -113,7 +113,7 @@
 func @testWhile2Result(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
 ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>):
   %1:2 = "tf.While"(%arg0, %arg1) {
-    cond = @testWhile2Cond, body = @testWhile2Body
+    cond = @testWhile2Cond, body = @testWhile2Body, is_stateless = false
   } : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
 
 // CHECK:   br ^bb1(%arg0, %arg1 : tensor<*xf32>, tensor<*xf32>)
@@ -138,7 +138,7 @@
 func @testWhile0Result() {
 
 ^bb0:
-  "tf.While"() { cond = @testWhile0Cond, body = @testWhile0Body } : () -> ()
+  "tf.While"() { cond = @testWhile0Cond, body = @testWhile0Body, is_stateless = false } : () -> ()
 // CHECK:   br ^bb1
 // CHECK: ^bb1:
 // CHECK:   %0 = call @testWhile0Cond() : () -> tensor<i1>
@@ -162,7 +162,7 @@
 ^bb1(%0: tensor<*xf32>, %1: tensor<*xf32>):
   %2 = addf %0, %1 : tensor<*xf32>
   %3:2 = "tf.While"(%0, %2) {
-    cond = @testWhile2Cond, body = @testWhile2Body
+    cond = @testWhile2Cond, body = @testWhile2Body, is_stateless = false
   } : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
 
 // CHECK:   br ^bb2(%0, %2 : tensor<*xf32>, tensor<*xf32>)
@@ -194,7 +194,7 @@
 // CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>)
 func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor<?x?xf32>) {
   %0 = "tf.While"(%arg0) {
-    cond = @testWhileCond, body = @testWhileBody
+    cond = @testWhileCond, body = @testWhileBody, is_stateless = false
   } : (tensor<1x3xf32>) -> (tensor<?x?xf32>)
 
 // CHECK:   %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor<?x3xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
index 779fe90..e13d558 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 // CHECK:       FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
 // CHECK-NEXT:  for node {{[{][{]node Add[}][}]}}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir
index d3b2d83..0d40a4d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s
+// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s --dump-input-on-failure
 
 func @main() {
   %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
@@ -17,18 +17,18 @@
 
 // Match the name of the cloned function with functionalized control-flow at call site
 // CHECK: func @main()
-// CHECK-NEXT: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]
+// CHECK: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]
 
 
 // In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
 // CHECK: func @[[FUNCTIONALIZE_FUNC]]
-// CHECK: "_tf.If"
+// CHECK: "tf.If"
 // CHECK-SAME:  else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
 // CHECK-SAME:  then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
 
 // We expect the _tf.Add in the else func and the _tf.Mul in the then func
 
 // CHECK: func @[[ELSE_FUNC]]
-// CHECK: "_tf.Add"
+// CHECK: "tf.Add"
 // CHECK: func @[[THEN_FUNC]]
-// CHECK: "_tf.Mul"
+// CHECK: "tf.Mul"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt
index c9df1f2..a2b9eff 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt
@@ -38,8 +38,14 @@
 
 # CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
 # CHECK: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "Add"}} {
-# CHECK:   %0:2 = "_tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", name = "input0", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control)
-# CHECK:   %1:2 = "_tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", name = "input1", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control)
-# CHECK:   %2:2 = "_tf.Add"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<10xi32>, tensor<10xi32>) -> (tensor<10xi32>, !_tf.control)
-# CHECK:   return %2#0 : tensor<10xi32>
-# CHECK: }
+
+# CHECK:   %[[INPUT0:[0-9]+]]:2 = tf_executor.island
+# CHECK-NEXT: "tf.Placeholder.input"(%arg0)
+
+# CHECK:   %[[INPUT1:[0-9]+]]:2 = tf_executor.island
+# CHECK-NEXT: "tf.Placeholder.input"(%arg1)
+
+# CHECK:   %[[add:[0-9]+]]:2 = tf_executor.island
+# CHECK-NEXT: "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0)
+
+# CHECK:   fetch %[[add]]#0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt
index da77c16..74adc38 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt
@@ -40,7 +40,9 @@
       }
     }
     # Drop the control dependency on arg for the node "test"
-    # CHECK:   %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "test", value = dense<0> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
+    # CHECK-LABEL: func @foo
+    # CHECK: tf_executor.island {
+    # CHECK-NEXT:   "tf.Const"()
     node_def {
       name: "test"
       op: "Const"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt
index 81466e6..93a2f60 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt
@@ -75,8 +75,8 @@
 }
 
 # Match partitioned call in main and capture the callee name.
-# CHECK: func @main
-# CHECK-NEXT: _tf.PartitionedCall
+# CHECK-LABEL: func @main
+# CHECK: tf.PartitionedCall
 # CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]
 
 # Verify that callee has the unit attribute tf._input_shapes.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt
new file mode 100644
index 0000000..cbfa973
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt
@@ -0,0 +1,256 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s
+
+# Verify that TensorFlow If and StatelessIf ops are mapped to the
+# composite If op in MLIR with is_stateless attribute set accordingly to
+# distinguish between them.
+
+# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf"
+# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf"
+
+node {
+  name: "tf.Less"
+  op: "Less"
+  input: "a"
+  input: "b"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "StatefulIf"
+  op: "If"
+  input: "tf.Less"
+  input: "a"
+  input: "b"
+  attr {
+    key: "Tcond"
+    value {
+      type: DT_BOOL
+    }
+  }
+  attr {
+    key: "Tin"
+    value {
+      list {
+        type: DT_FLOAT
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "Tout"
+    value {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "else_branch"
+    value {
+      func {
+        name: "cond_false"
+      }
+    }
+  }
+  attr {
+    key: "then_branch"
+    value {
+      func {
+        name: "cond_true"
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "StatelessIf"
+  op: "StatelessIf"
+  input: "tf.Less"
+  input: "a"
+  input: "b"
+  attr {
+    key: "Tcond"
+    value {
+      type: DT_BOOL
+    }
+  }
+  attr {
+    key: "Tin"
+    value {
+      list {
+        type: DT_FLOAT
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "Tout"
+    value {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "else_branch"
+    value {
+      func {
+        name: "cond_false"
+      }
+    }
+  }
+  attr {
+    key: "then_branch"
+    value {
+      func {
+        name: "cond_true"
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "main"
+  op: "_Retval"
+  input: "StatefulIf"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+}
+node {
+  name: "main1"
+  op: "_Retval"
+  input: "StatelessIf"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 1
+    }
+  }
+}
+node {
+  name: "a"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "b"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+library {
+  function {
+    signature {
+      name: "cond_true"
+      input_arg {
+        name: "cond_true"
+        type: DT_FLOAT
+      }
+      input_arg {
+        name: "cond_true1"
+        type: DT_FLOAT
+      }
+      output_arg {
+        name: "cond_true2"
+        type: DT_FLOAT
+      }
+    }
+    node_def {
+      name: "tf.Add"
+      op: "Add"
+      input: "cond_true"
+      input: "cond_true1"
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "tf.Add"
+      }
+    }
+    ret {
+      key: "cond_true2"
+      value: "tf.Add:z:0"
+    }
+  }
+  function {
+    signature {
+      name: "cond_false"
+      input_arg {
+        name: "cond_false"
+        type: DT_FLOAT
+      }
+      input_arg {
+        name: "cond_false1"
+        type: DT_FLOAT
+      }
+      output_arg {
+        name: "cond_false2"
+        type: DT_FLOAT
+      }
+    }
+    node_def {
+      name: "tf.Mul"
+      op: "Mul"
+      input: "cond_false"
+      input: "cond_false1"
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "tf.Mul"
+      }
+    }
+    ret {
+      key: "cond_false2"
+      value: "tf.Mul:z:0"
+    }
+  }
+}
+versions {
+  producer: 115
+  min_consumer: 12
+}
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
new file mode 100644
index 0000000..953f83a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
@@ -0,0 +1,283 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s
+
+# Verify that TensorFlow While and StatelessWhile ops are mapped to the
+# composite While op in MLIR with is_stateless attribute set accordingly to
+# distinguish between them.
+
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile"
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile"
+
+node {
+  name: "StatefulWhile"
+  op: "While"
+  input: "iter"
+  input: "val"
+  attr {
+    key: "T"
+    value {
+      list {
+        type: DT_INT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "body"
+    value {
+      func {
+        name: "body"
+      }
+    }
+  }
+  attr {
+    key: "cond"
+    value {
+      func {
+        name: "cond"
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "StatelessWhile"
+  op: "StatelessWhile"
+  input: "iter"
+  input: "val"
+  attr {
+    key: "T"
+    value {
+      list {
+        type: DT_INT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "body"
+    value {
+      func {
+        name: "body"
+      }
+    }
+  }
+  attr {
+    key: "cond"
+    value {
+      func {
+        name: "cond"
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "main"
+  op: "_Retval"
+  input: "StatefulWhile:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+}
+node {
+  name: "main1"
+  op: "_Retval"
+  input: "StatelessWhile:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 1
+    }
+  }
+}
+node {
+  name: "iter"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "val"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+library {
+  function {
+    signature {
+      name: "cond"
+      input_arg {
+        name: "cond"
+        type: DT_INT32
+      }
+      input_arg {
+        name: "cond1"
+        type: DT_FLOAT
+      }
+      output_arg {
+        name: "cond2"
+        type: DT_BOOL
+      }
+    }
+    node_def {
+      name: "Const"
+      op: "Const"
+      attr {
+        key: "dtype"
+        value {
+          type: DT_INT32
+        }
+      }
+      attr {
+        key: "value"
+        value {
+          tensor {
+            dtype: DT_INT32
+            tensor_shape {
+            }
+            int_val: 0
+          }
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "Const"
+      }
+    }
+    node_def {
+      name: "tf.Greater"
+      op: "Greater"
+      input: "cond"
+      input: "Const:output:0"
+      attr {
+        key: "T"
+        value {
+          type: DT_INT32
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "tf.Greater"
+      }
+    }
+    ret {
+      key: "cond2"
+      value: "tf.Greater:z:0"
+    }
+  }
+  function {
+    signature {
+      name: "body"
+      input_arg {
+        name: "body"
+        type: DT_INT32
+      }
+      input_arg {
+        name: "body1"
+        type: DT_FLOAT
+      }
+      output_arg {
+        name: "body2"
+        type: DT_INT32
+      }
+      output_arg {
+        name: "body3"
+        type: DT_FLOAT
+      }
+    }
+    node_def {
+      name: "Const"
+      op: "Const"
+      attr {
+        key: "dtype"
+        value {
+          type: DT_INT32
+        }
+      }
+      attr {
+        key: "value"
+        value {
+          tensor {
+            dtype: DT_INT32
+            tensor_shape {
+            }
+            int_val: 1
+          }
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "Const"
+      }
+    }
+    node_def {
+      name: "tf.Sub"
+      op: "Sub"
+      input: "body"
+      input: "Const:output:0"
+      attr {
+        key: "T"
+        value {
+          type: DT_INT32
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "tf.Sub"
+      }
+    }
+    node_def {
+      name: "tf.Add"
+      op: "Add"
+      input: "body1"
+      input: "body1"
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "tf.Add"
+      }
+    }
+    ret {
+      key: "body2"
+      value: "tf.Sub:z:0"
+    }
+    ret {
+      key: "body3"
+      value: "tf.Add:z:0"
+    }
+  }
+}
+versions {
+  producer: 115
+  min_consumer: 12
+}
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
index 83c1d2d..9ce1531 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
@@ -54,5 +54,5 @@
 # the names are matching between the function definition and the uses / call
 # site (a numerical suffix may be appended).
 
-# CHECK: "_tf.foo0"(
+# CHECK: "tf.foo0"(
 # CHECK: func @foo0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt
index c023c7e..12d05c1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt
@@ -3,7 +3,7 @@
 # This test is intended to verify the tensor_content field on import of an empty
 # tensor.
 # CHECK:  tf.Const
-# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F464C4F41542074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D">
+# CHECK-SAME: value = dense<0.000000e+00>
 
 node {
   name: "Const"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-functional-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt
similarity index 100%
rename from tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-functional-while-loop.pbtxt
rename to tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
index 760dffd..17b2655 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
@@ -39,10 +39,10 @@
 # Verify that functions from the library are properly imported.
 
 # CHECK-LABEL:  func @main() {
-# CHECK:    "_tf.foo0"()
-# CHECK:    "_tf.bar0"()
+# CHECK:    "tf.foo0"()
+# CHECK:    "tf.bar0"()
 
 # CHECK-LABEL:  func @foo0() {
-# CHECK: "_tf.bar0"()
+# CHECK: "tf.bar0"()
 
 # CHECK-LABEL:  func @bar0() {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt
index 97e2225..0a5aba2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 this is not a valid graph def
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt
index 01a8a11..37f7a87 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt
@@ -4,10 +4,12 @@
 
 # CHECK: func @main(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>)
 # CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "out"}} {
-# CHECK:  "_tf.Placeholder.input"(%arg0)
+# CHECK:  "tf.Placeholder.input"(%arg0)
 
-# CHECK:  %[[IDENTITY:[0-9]+]]:3 = "_tf.IdentityN"
-# CHECK:  return %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
+# CHECK: tf.Relu
+# CHECK:  %[[IDENTITY:[0-9]+]]:3 = tf_executor.island
+# CHECK-NEXT: tf.Identity
+# CHECK:  fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
 
 node {
   name: "input"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt
index 32b816f..9ae5601 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt
@@ -104,8 +104,8 @@
 }
 
 # CHECK: func @main
-# CHECK: "_tf.PartitionedCall"()
+# CHECK: "tf.PartitionedCall"()
 # CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
 # CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
 # CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
-# CHECK: return {{%[0-9]*#[0-9]*}} : tensor<!tf.uint8>
+# CHECK: return {{.*}} : tensor<!tf.uint8>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt
index 4fa8407..6816088 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 # CHECK: Graph import failed: Invalid argument: Output NotANodeInTheGraph was not found in graph
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt
index 5f8e785..20bf33d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt
@@ -29,7 +29,6 @@
             size: 2
           }
         }
-        tensor_content: "\350\251\242>\276\335r?"
       }
     }
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
index f60fb46..4ada2f6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
@@ -4,11 +4,10 @@
 # to break the cycle.
 
 # CHECK-LABEL: func @main()
-# CHECK:    %[[NEXTITERATION:[0-9]+]]:2 = "_tf.NextIteration.source"
-# CHECK:    tf.Merge"({{.*}} %[[NEXTITERATION]]#0)
+# CHECK:    %[[NEXTITERATION:[0-9]+]]:3 = tf_executor.NextIteration.Source
+# CHECK:    tf_executor.Merge {{.*}} %[[NEXTITERATION]]#0
 
-# CHECK:    %[[ADD:[0-9]+]]:2 = "_tf.Add"
-# CHECK:    "_tf.NextIteration.sink"(%[[ADD]]#0)
+# CHECK:    tf_executor.NextIteration.Sink [%[[NEXTITERATION]]#1]
 
 node {
   name: "Const"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt
new file mode 100644
index 0000000..6fec080
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt
@@ -0,0 +1,14 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
+
+# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input
+
+node {
+  name: "input"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt
new file mode 100644
index 0000000..c6d00a6
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt
@@ -0,0 +1,30 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck %s
+
+# Verify that invalid LegacyFedInput ops without any inputs are replaced with
+# Placeholder ops.
+
+# CHECK-NOT: LegacyFedInput
+# CHECK: tf.Placeholder.input{{.*}}(tensor<f32>) -> tensor<f32>
+# CHECK-NOT: LegacyFedInput
+
+node {
+  name: "input"
+  op: "LegacyFedInput"
+  attr {
+    key: "input_def"
+    value {
+      s: "name: \"batch_1\"\n[dist_belief.ImageInputDef.ext] {\n  num_rows: 128\n  num_cols: 128\n  mean_value: 128\n  std_value: 128\n  colorspace: RGB\n}\n"
+    }
+  }
+  attr {
+    key: "output_types"
+    value {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+versions {
+  producer: 27
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt
index b8d7cfe..09a900e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt
@@ -4,9 +4,9 @@
 # Imported.
 
 # CHECK-LABEL: func @main()
-# CHECK:         %[[NEXTITERATION:[0-9]+]]:2 = "_tf.NextIteration.source"
-# CHECK:         "_tf.Merge"({{.*}}, %[[NEXTITERATION]]#0)
-# CHECK:         "_tf.Merge"({{.*}}, %[[NEXTITERATION]]#0)
+# CHECK:         %[[NEXTITERATION:[0-9]+]]:3 = tf_executor.NextIteration.Source
+# CHECK:         tf_executor.Merge {{.*}}, %[[NEXTITERATION]]#0
+# CHECK:         tf_executor.Merge {{.*}}, %[[NEXTITERATION]]#0
 
 node {
   name: "Const"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
index 0962647..748bc99 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
@@ -27,6 +27,6 @@
   producer: 70
 }
 
-# CHECK: "_tf.Const"()
+# CHECK: tf.Const
 # CHECK-SAME: name = "Quantized_Constant"
 # CHECK-SAME: value = opaque<"tf", "{{0[xX][0-9a-fA-F]*}}"> : tensor<!tf.quint8>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt
index 3200715..54877e8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt
@@ -82,7 +82,7 @@
 
 # Find PartitionedCall ops in main and match the callee name.
 # CHECK: func @main
-# CHECK: "_tf.PartitionedCall"
+# CHECK: "tf.PartitionedCall"
 # CHECK-SAME: f = @[[FUNCTION_FOO:[a-zA-Z0-9_]*]]
 
 # Find callee and verify it has the stateful attribute set.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt
index c6f0730..707b044 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt
@@ -1,4 +1,9 @@
 # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s
+
+# CHECK: tf.Const
+# CHECK-SAME: _output_shapes = ["tfshape$dim { size: 3 }"]
+# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string>
+
 node {
   name: "save/SaveV2/shape_and_slices"
   op: "Const"
@@ -40,8 +45,3 @@
 versions {
   producer: 74
 }
-
-# CHECK: func @main() {
-# CHECK-NEXT: %0:2 = "_tf.Const"() {_output_shapes = ["tfshape$dim { size: 3 }"], device = "", dtype = "tfdtype$DT_STRING", name = "save/SaveV2/shape_and_slices", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D20737472696E675F76616C3A20222220737472696E675F76616C3A20222220737472696E675F76616C3A202222"> : tensor<3x!tf.string>} : () -> (tensor<3x!tf.string>, !_tf.control)
-# CHECK-NEXT: return
-# CHECK-NEXT: }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt
index a8802a9..cc24caa 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt
@@ -209,10 +209,10 @@
 }
 
 # Verify that list element shape and dtype are expected.
-# CHECK:  _tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor<i32>) -> (tensor<!tf.variant<tensor<2x2xf32>>>, !_tf.control)
+# CHECK:  tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<2x2xf32>>>
 
 # Nested variant type.
-# CHECK:  _tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor<i32>) -> (tensor<!tf.variant<tensor<2x2x!tf.variant>>>, !_tf.control)
+# CHECK:  tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<2x2x!tf.variant>>>
 
-# CHECK:  _tf.TensorListSetItem{{.*}}(tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>) -> (tensor<!tf.variant<tensor<2x2xf32>>>, !_tf.control)
-# CHECK:  _tf.TensorListStack{{.*}}(tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>) -> (tensor<?x2x2xf32>, !_tf.control)
+# CHECK:  tf.TensorListSetItem{{.*}}(tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>) -> tensor<!tf.variant<tensor<2x2xf32>>>
+# CHECK:  tf.TensorListStack{{.*}}(tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>) -> tensor<?x2x2xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir
index 2259d30..4566ffb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir
@@ -13,15 +13,19 @@
 // The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops.
 
 // Capture the result of input to function call.
-// CHECK:      [[VARIABLE_REG:%[0-9]*]]:2 = "_tf.VarHandleOp"()
+// CHECK:      [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island
+// CHECK-NEXT:      "tf.VarHandleOp"()
 
 // Test for the presence of Identity op between input and function call.
-// CHECK-NEXT: [[IDENTITY_REG:%[0-9]*]]:2 = "_tf.Identity"([[VARIABLE_REG]]#0)
-// CHECK-NEXT: [[CALL_RESULT_REG:%[0-9]*]]:2 = "_tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
+// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island
+// CHECK-NEXT: "tf.Identity"([[VARIABLE_REG]]#0)
+
+// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island
+// CHECK-NEXT: "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
 // CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]
 
 // Match the inserted Identity op for call output.
-// CHECK-NEXT: "_tf.Identity"([[CALL_RESULT_REG]]#0)
+// CHECK: "tf.Identity"([[CALL_RESULT_REG]]#0)
 
 // Match the function name
 // CHECK: func @[[FUNCTION]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir
new file mode 100644
index 0000000..ccd0588
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir
@@ -0,0 +1,34 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
+  %0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
+  %2 = "tf.Less"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatefulIf")
+  %4 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatelessIf")
+  return %3, %4 : tensor<f32>, tensor<f32>
+}
+
+func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless
+// attribute is present and otherwise it is mapped to TensorFlow If op. In both
+// cases, the additional attribute should be dropped.
+
+// CHECK: name: "StatefulIf"
+// CHECK-NOT: name:
+// CHECK: op: "If"
+// CHECK-NOT: is_stateless
+
+// CHECK: name: "StatelessIf"
+// CHECK-NOT: name:
+// CHECK: op: "StatelessIf"
+// CHECK-NOT: is_stateless
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir
new file mode 100644
index 0000000..0009c7a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+func @main(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
+  %iter = "tf.Placeholder.input"(%arg0) : (tensor<i32>) -> tensor<i32> loc("iter")
+  %val = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32> loc("val")
+
+  // Element wise add `val` with itself for `iter` number of times.
+  %2:2 = "tf.While"(%iter, %val) {
+    cond = @cond, body = @body, is_stateless = false
+  } : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatefulWhile")
+  %3:2 = "tf.While"(%iter, %val) {
+    cond = @cond, body = @body, is_stateless = true
+  } : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatelessWhile")
+
+  return %2#1, %3#1 : tensor<f32>, tensor<f32>
+}
+
+func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
+  %0 = "tf.Const" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
+  %1 = "tf.Greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
+  return %1 : tensor<i1>
+}
+
+func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
+  %0 = "tf.Const" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
+  %1 = "tf.Sub"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+  %2 = "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1, %2 : tensor<*xi32>, tensor<*xf32>
+}
+
+// Verify that While op is mapped to TensorFlow StatelessWhile op if the
+// is_stateless attribute is present and otherwise it is mapped to TensorFlow
+// While op. In both cases, the additional attribute should be dropped.
+
+// CHECK: name: "StatefulWhile"
+// CHECK-NOT: name:
+// CHECK: op: "While"
+// CHECK-NOT: is_stateless
+
+// CHECK: name: "StatelessWhile"
+// CHECK-NOT: name:
+// CHECK: op: "StatelessWhile"
+// CHECK-NOT: is_stateless
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir
index 041be4b..f73e933 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
 
 // CHECK: Graph export failed: Failed precondition: entry function `main` must be present
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir
new file mode 100644
index 0000000..271b6ec
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir
@@ -0,0 +1,12 @@
+// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=MlirRoundtripPass | FileCheck %s --dump-input-on-failure
+
+// The test uses the tf_graph_optimization_pass to run the MlirRoundtripPass.
+// We convert mlir -> Graph -> mlir -> Graph -> mlir
+
+func @main() {
+  "_tf.NoOp"() {} : () -> () loc("X")
+  return
+}
+
+// Check for the presence of tf.NoOp in the final output.
+// CHECK: tf.NoOp
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir
new file mode 100644
index 0000000..6b24523
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir
@@ -0,0 +1,19 @@
+// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=MlirRoundtripPass | FileCheck %s --dump-input-on-failure
+
+module {
+  func @main() {
+    tf_executor.graph {
+      %0 = tf_executor.island {
+        "tf.NoOp"() {} : () -> () loc("X")
+        tf_executor.yield
+      }
+      tf_executor.fetch
+    }
+    return
+  }
+}
+
+// The test uses the tf_graph_optimization_pass to run the MlirRoundtripPass.
+// We convert mlir -> Graph -> mlir -> Graph -> mlir
+// Check for the presence of tf.NoOp in the final output.
+// CHECK: tf.NoOp
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index d37892d..a0dba3f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -65,6 +65,23 @@
 
 // -----
 
+// CHECK-LABEL: func @testIdentity
+func @testIdentity(%arg0: tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> {
+  // CHECK: tf.Identity
+  %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string>
+  return %0 : tensor<4x2x!tf.string>
+}
+
+// -----
+
+func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> {
+  // expected-error @+1 {{requires all operands to be either same as or ref type of results}}
+  %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref>
+  return %0 : tensor<4x2x!tf.stringref>
+}
+
+// -----
+
 // TODO(hinsu): Move this to MLIR core once the test dialect have a custom type.
 
 // Check that broadcastable trait accepts TF specific element type
@@ -133,9 +150,18 @@
 }
 
 // -----
-// CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>)
-func @testReshape(tensor<*xf32>, tensor<*xf32>, tensor<10000xf32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>) {
-^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>):
+
+// CHECK-LABEL: func @testMul
+func @testMul(%arg0: tensor<2x!tf.uint16>) -> (tensor<2x!tf.uint16>) {
+  // CHECK: tf.Mul
+  %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2x!tf.uint16>, tensor<2x!tf.uint16>) -> tensor<2x!tf.uint16>
+  return %0 : tensor<2x!tf.uint16>
+}
+
+// -----
+
+// CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>)
+func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) {
   // CHECK: %cst = constant dense<100> : tensor<2xi32>
   %shape1 = constant dense<100> : tensor<2xi32>
   // CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<*xf32>, tensor<2xi32>) -> tensor<100x100xf32>
@@ -150,7 +176,11 @@
   %shape3 = constant dense<[-1, 100]> : tensor<2xi32>
   // CHECK: %4 = "tf.Reshape"(%arg2, %cst_0) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32", device = "", name = "Reshape_1"} : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32>
   %r4 = "tf.Reshape"(%arg2, %shape3) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<100x100xf32>)
-  return %r1, %r2, %r3, %r4: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>
+  // CHECK: "tf.Reshape"(%arg0, %arg3)
+  %r5 = "tf.Reshape"(%arg0, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<*xf32>)
+  // CHECK: "tf.Reshape"(%arg2, %arg3)
+  %r6 = "tf.Reshape"(%arg2, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<*xi32>) -> (tensor<*xf32>)
+  return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>
 }
 
 // -----
@@ -628,13 +658,32 @@
 ^bb0(%arg0: tensor<*xf32>):
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
 }
 
 // -----
+func @testWhileUndefinedCond(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<f32> {
+  // expected-error @+1 {{cond refers to an undefined function : undefined_func}}
+  %0 = "tf.While"(%arg0, %arg1) {cond = @undefined_func, body = @body, is_stateless = false} : (tensor<i1>, tensor<f32>) -> (tensor<f32>)
+  return %0 : tensor<f32>
+}
+
+func @body(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<f32>
+
+// -----
+func @testWhileUndefinedBody(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<f32> {
+  // expected-error @+1 {{body refers to an undefined function : undefined_func}}
+  %0 = "tf.While"(%arg0, %arg1) {cond = @cond, body = @undefined_func, is_stateless = false} : (tensor<i1>, tensor<f32>) -> (tensor<f32>)
+  return %0 : tensor<f32>
+}
+
+func @cond(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<i1>
+
+// -----
 
 func @testWhileCond(tensor<*xf32>) -> ()
 func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>)
@@ -645,7 +694,8 @@
   // expected-error @+1 {{requires cond function to have exactly one result}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
@@ -662,7 +712,8 @@
   // expected-error @+1 {{operand type tensor<*xf32> is incompatible with result type}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xi32>)
 
   return %1 : tensor<*xi32>
@@ -679,7 +730,8 @@
   // expected-error @+1 {{operand type tensor<*xf32> is incompatible with cond function input type}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
@@ -696,7 +748,8 @@
   // expected-error @+1 {{requires the number of operands to be equal to the number of body function inputs. Found 1 and 2, respectively}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
@@ -713,7 +766,8 @@
   // expected-error @+1 {{body function result type tensor<*xi32> is incompatible with result type}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
@@ -730,7 +784,8 @@
   // expected-error @+1 {{cond function input type tensor<3xf32> is incompatible with body function input type}}
   %1 = "tf.While"(%arg0) {
     cond = @testWhileCond,
-    body = @testWhileBody
+    body = @testWhileBody,
+    is_stateless = false
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   return %1 : tensor<*xf32>
@@ -850,3 +905,4 @@
   %0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
index 510aacc..6eda784 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
@@ -68,6 +68,30 @@
   return %0 : tensor<*xf32>
 }
 
+// CHECK-LABEL: func @simpleIsland_with_multiple_control_inputs(%arg0: tensor<*xf32>)
+func @simpleIsland_with_multiple_control_inputs(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = tf_executor.graph {
+    %1 = tf_executor.island {
+      tf_executor.yield
+    }
+    %2 = tf_executor.island {
+      tf_executor.yield
+    }
+    %3:2 = tf_executor.island(%1, %2) {
+      tf_executor.yield %arg0 : tensor<*xf32>
+    }
+    tf_executor.fetch %3#0 : tensor<*xf32>
+  }
+// CHECK:      %[[ISLAND0:[0-9]*]] = tf_executor.island {
+// CHECK-NEXT:   tf_executor.yield
+// CHECK:      %[[ISLAND1:[0-9]*]] = tf_executor.island {
+// CHECK-NEXT:   tf_executor.yield
+// CHECK:      %[[ISLAND2:[0-9]*]]:2 = tf_executor.island(%[[ISLAND0]], %[[ISLAND1]]) {
+// CHECK:      tf_executor.fetch %[[ISLAND2]]#0 : tensor<*xf32>
+
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @fetchWithControlDep(%arg0: tensor<*xf32>)
 func @fetchWithControlDep(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %result = tf_executor.graph {
@@ -210,6 +234,43 @@
   return %result : tensor<*xf32>
 }
 
+// Verify that long form printing is used when operand types do not match the
+// result type and then it can be parsed again correctly.
+// CHECK-LABEL: func @merge_different_operand_types
+func @merge_different_operand_types(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  %result = tf_executor.graph {
+
+// CHECK: tf_executor.Merge{{.*}}(tensor<*xf32>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
+    %value, %idx, %ctlMerge = tf_executor.Merge %arg0, %arg1  : (tensor<*xf32>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
+    tf_executor.fetch %value : tensor<4xf32>
+  }
+  return %result : tensor<4xf32>
+}
+
+// Verify that long form printing is used when there is only one data operand
+// and then it can be parsed again correctly.
+// CHECK-LABEL: func @merge_one_data_operand
+func @merge_one_data_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %result = tf_executor.graph {
+
+// CHECK: tf_executor.Merge{{.*}}(tensor<*xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
+    %value, %idx, %ctlMerge = tf_executor.Merge %arg0  : (tensor<*xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
+    tf_executor.fetch %value : tensor<*xf32>
+  }
+  return %result : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @merge_with_variant_type
+func @merge_with_variant_type(%arg0: tensor<!tf.variant>, %arg1: tensor<!tf.variant<tensor<4xi32>>>) -> tensor<!tf.variant<tensor<8xf32>>> {
+  %result = tf_executor.graph {
+
+// CHECK: tf_executor.Merge{{.*}}(tensor<!tf.variant>, tensor<!tf.variant<tensor<4xi32>>>) -> (tensor<!tf.variant<tensor<8xf32>>>, tensor<i32>, !tf_executor.control)
+    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<!tf.variant>, tensor<!tf.variant<tensor<4xi32>>>) -> (tensor<!tf.variant<tensor<8xf32>>>, tensor<i32>, !tf_executor.control)
+    tf_executor.fetch %value : tensor<!tf.variant<tensor<8xf32>>>
+  }
+  return %result : tensor<!tf.variant<tensor<8xf32>>>
+}
+
 // CHECK-LABEL: func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
 func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
   %result = tf_executor.graph {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
index 366cd82..90b245d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
@@ -27,7 +27,7 @@
 // Check that an empty graph is invalid (it needs a region).
 func @empty_graph() {
  "tf_executor.graph" () ({
-// expected-error@-1 {{'tf_executor.graph' op expects a non-empty body}}
+// expected-error@-1 {{'tf_executor.graph' op expects a non-empty block}}
  ^entry:
   }) : () -> ()
   return
@@ -47,6 +47,17 @@
 
 // -----
 
+// Check that tf_executor.graph can't be nested directly in a tf_executor.graph.
+func @nested_graph() {
+  tf_executor.graph {
+    tf_executor.graph {}
+// expected-error@-1 {{'tf_executor.graph' op unallowed directly inside another tf_executor.graph}}
+  }
+  return
+}
+
+// -----
+
 // Check that a tf_executor.fetch is terminating a tf_executor.graph (custom parser)
 func @graph_with_invalid_terminator(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   tf_executor.graph {
@@ -58,11 +69,23 @@
 
 // -----
 
+// Check that a tf_executor.fetch parent is a graph.
+func @parent_is_graph() {
+  "some.op"() ({
+    tf_executor.fetch
+// expected-error@-1 {{'tf_executor.fetch' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that a tf_executor.fetch is terminating a tf_executor.graph (verifier)
 func @graph_with_invalid_terminator(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+// expected-error@+2 {{'tf_executor.graph' op expects regions to end with 'tf_executor.fetch', found 'tf_executor.yield'}}
+// expected-note@+1 {{in custom textual format, the absence of terminator implies 'tf_executor.fetch'}}
   "tf_executor.graph" () ({
     tf_executor.yield
-// expected-error@-1 {{'tf_executor.yield' op invalid tf_executor.graph terminator, fetch expected}}
   }) : () -> ()
   return %arg0 : tensor<*xf32>
 }
@@ -149,6 +172,17 @@
 
 // -----
 
+// Check that a tf_executor.island parent is a graph.
+func @parent_is_graph() {
+  "some.op"() ({
+    %ctl = tf_executor.island {}
+// expected-error@-1 {{'tf_executor.island' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that an island can't have other operands than controls.
 func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
   tf_executor.graph {
@@ -189,7 +223,7 @@
 func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
   tf_executor.graph {
     "tf_executor.island"() ({
-// expected-error@-1 {{'tf_executor.island' op expects a non-empty body}}
+// expected-error@-1 {{'tf_executor.island' op expects a non-empty block}}
  ^entry:
     }) : () -> (!tf_executor.control)
   }
@@ -202,8 +236,9 @@
 func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
   tf_executor.graph {
     "tf_executor.island"() ({
+// expected-error@-1 {{'tf_executor.island' op expects regions to end with 'tf_executor.yield', found 'std.return'}}
+// expected-note@-2 {{in custom textual format, the absence of terminator implies 'tf_executor.yield'}}
       return
-// expected-error@-1 {{'std.return' op invalid tf_executor.island terminator, yield expected}}
     }) : () -> (!tf_executor.control)
   }
   return
@@ -211,6 +246,17 @@
 
 // -----
 
+// Check that a tf_executor.yield parent is a tf_executor.island.
+func @parent_is_island() {
+  "some.op"() ({
+    tf_executor.yield
+// expected-error@-1 {{'tf_executor.yield' op expects parent op 'tf_executor.island'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that an island yield matches the island results.
 func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
   tf_executor.graph {
@@ -276,6 +322,17 @@
 
 // -----
 
+// Check that a tf_executor.Switch parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i1>) {
+  "some.op"() ({
+    %true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.Switch' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that a switch always takes two arguments.
 func @invalid_switch(%arg0: tensor<*xf32>) {
   tf_executor.graph {
@@ -335,6 +392,17 @@
 
 // -----
 
+// Check that a tf_executor.SwitchN parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: i32) {
+  "some.op"() ({
+     %1:6 = tf_executor.SwitchN %arg0, %arg1 of 5 : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.SwitchN' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that switchN result numbers matches the num_out attribute.
 func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   %fetches = tf_executor.graph {
@@ -377,6 +445,17 @@
 
 // -----
 
+// Check that a tf_executor.Merge parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>) {
+  "some.op"() ({
+    %value, %idx, %ctlMerge = tf_executor.Merge %arg0, %arg0 : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.Merge' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that merge has at least one operand.
 func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
   %result = tf_executor.graph {
@@ -431,6 +510,18 @@
 
 // -----
 
+// Check that merge data inputs of variant type are broadcastable to the output
+func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) -> tensor<8x!tf.variant> {
+  %result = tf_executor.graph {
+    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*x!tf.variant>, tensor<4x!tf.variant>) -> (tensor<8x!tf.variant>, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8x!tf.variant>' vs 'tensor<4x!tf.variant>'}}
+    tf_executor.fetch %value : tensor<8x!tf.variant>
+  }
+  return %result : tensor<8x!tf.variant>
+}
+
+// -----
+
 // Check that merge data inputs can't appear after control input.
 func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
   %result = tf_executor.graph {
@@ -446,6 +537,17 @@
 
 // -----
 
+// Check that a tf_executor.Enter parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>) {
+  "some.op"() ({
+    %res:2 = tf_executor.Enter %arg0 frame "some/fra\"me" : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.Enter' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // Check that Enter return value is the same type as the input.
 func @invalid_enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
   %result = tf_executor.graph {
@@ -458,6 +560,28 @@
 
 // -----
 
+// Check that a tf_executor.NextIteration.Sink parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: !tf_executor.token) {
+  "some.op"() ({
+    tf_executor.NextIteration.Sink[%arg1] %arg0 : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.NextIteration.Sink' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// Check that a tf_executor.NextIteration.Source parent is a graph.
+func @parent_is_graph() {
+  "some.op"() ({
+    %1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.NextIteration.Source' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 func @invalid_nextiteration(%arg0: tensor<*xf32>, %arg1: !tf_executor.token) -> tensor<*xf32> {
   %0 = tf_executor.graph {
     %1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
@@ -521,6 +645,17 @@
 
 // -----
 
+// Check that a tf_executor.Exit parent is a graph.
+func @parent_is_graph(%arg0: tensor<*xf32>) {
+  "some.op"() ({
+    %1:2 = tf_executor.Exit %arg0 : tensor<*xf32>
+// expected-error@-1 {{'tf_executor.Exit' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
 func @exit(%arg0: tensor<*xi32>) -> tensor<*xf32> {
   %0 = tf_executor.graph {
     %1:2 = "tf_executor.Exit"(%arg0) : (tensor<*xi32>) -> (tensor<*xf32>, !tf_executor.control)
@@ -529,3 +664,25 @@
   }
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+// Check that a tf_executor.ControlTrigger parent is a graph.
+func @parent_is_graph(%arg0: !tf_executor.control, %arg1: !tf_executor.control) {
+  "some.op"() ({
+    %0 = tf_executor.ControlTrigger %arg0, %arg1
+// expected-error@-1 {{'tf_executor.ControlTrigger' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// Check that a tf_executor.LoopCond parent is a graph.
+func @parent_is_graph(%arg0: tensor<i1>, %arg1: !tf_executor.control) {
+  "some.op"() ({
+    %1:2 = tf_executor.LoopCond %arg0, %arg1 : tensor<i1>
+// expected-error@-1 {{'tf_executor.LoopCond' op expects parent op 'tf_executor.graph'}}
+  }) : () -> ()
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index 473f69f..0653c1d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -20,9 +20,7 @@
 
 /// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
 def SingleResultAndOperandHaveSameElementType : Constraint<
-  CPred<"$0->getResult(0)->getType().cast<ShapedType>()"
-        ".getElementType() == "
-        "$1->getType().cast<ShapedType>().getElementType()">>;
+  CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
 
 //===----------------------------------------------------------------------===//
 // Add op patterns.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
new file mode 100644
index 0000000..7e2405b
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -0,0 +1,155 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This pass outlines regions of `tf_device.launch` into functions and replaces
+// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Block.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/Transforms/RegionUtils.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+
+namespace mlir {
+namespace TFDevice {
+
+namespace {
+
+struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
+  void runOnModule() override;
+};
+
+void ReplaceLaunchReturnWithReturn(Operation* launch_return_op,
+                                   OpBuilder* builder) {
+  llvm::SmallVector<Value*, 4> operands(launch_return_op->getOperands());
+  builder->create<ReturnOp>(launch_return_op->getLoc(), operands);
+  launch_return_op->erase();
+}
+
+// Builds a function that outlines region attached to launch_op and inserts
+// built function into given module.
+FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
+                     Operation* launch_op, ModuleManager* module_manager,
+                     OpBuilder* builder) {
+  llvm::SmallVector<Type, 4> operand_types;
+  operand_types.reserve(live_ins.size());
+  for (Value* v : live_ins) operand_types.emplace_back(v->getType());
+
+  llvm::SmallVector<Type, 4> result_types(launch_op->getResultTypes());
+
+  auto func_type =
+      FunctionType::get(operand_types, result_types, builder->getContext());
+
+  std::string func_name_prefix = Twine(device, "_func").str();
+  FuncOp outlined_func =
+      FuncOp::create(launch_op->getLoc(), func_name_prefix, func_type);
+
+  // Create function body.
+  Block* outlined_func_block = outlined_func.addEntryBlock();
+
+  // Replace uses of live-in values within launch_op region with function
+  // arguments.
+  Region& launch_op_region = launch_op->getRegion(0);
+  for (const auto& p :
+       llvm::zip(live_ins, outlined_func_block->getArguments())) {
+    replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p),
+                               launch_op_region);
+  }
+
+  // Move all instructions in launch_op into outlined_function's only block.
+  auto& launch_op_body = launch_op_region.front().getOperations();
+  outlined_func_block->getOperations().splice(
+      outlined_func_block->end(), launch_op_body, launch_op_body.begin(),
+      launch_op_body.end());
+
+  // Replace `tf_device.launch_return` terminator with `std.return` in function
+  // body.
+  Operation* launch_return_op = &outlined_func_block->back();
+  builder->setInsertionPoint(launch_return_op);
+  ReplaceLaunchReturnWithReturn(launch_return_op, builder);
+
+  module_manager->insert(outlined_func);
+  return outlined_func;
+}
+
+Operation* BuildLaunchFunc(const Location& loc, StringRef device, FuncOp func,
+                           llvm::ArrayRef<Value*> live_ins,
+                           OpBuilder* builder) {
+  // TODO(b/138909768): Define `tf_device.launch_func` and use its build method
+  // instead.
+  OperationState launch_func_op(loc, "tf_device.launch_func");
+  launch_func_op.addAttribute("device", builder->getStringAttr(device));
+  launch_func_op.addAttribute("func",
+                              builder->getSymbolRefAttr(func.getName()));
+  launch_func_op.addTypes(func.getType().getResults());
+  llvm::SmallVector<Value*, 4> operands(live_ins.begin(), live_ins.end());
+  launch_func_op.addOperands(operands);
+  return builder->createOperation(launch_func_op);
+}
+
+// Outlines body of `tf_device.launch` into a function and create a
+// `tf_device.launch_func` to invoke that function. `tf_device.launch` is
+// removed afterwards.`
+void OutlineLaunch(Operation* launch_op, ModuleManager* module_manager,
+                   OpBuilder* builder) {
+  llvm::SetVector<Value*> live_ins;
+  getUsedValuesDefinedAbove(launch_op->getRegion(0), launch_op->getRegion(0),
+                            live_ins);
+
+  StringRef device = launch_op->getAttrOfType<StringAttr>("device").getValue();
+
+  FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
+                                       launch_op, module_manager, builder);
+  builder->setInsertionPoint(launch_op);
+  Operation* launch_func_op =
+      BuildLaunchFunc(launch_op->getLoc(), device, outlined_func,
+                      live_ins.getArrayRef(), builder);
+
+  launch_op->replaceAllUsesWith(launch_func_op);
+  launch_op->erase();
+}
+
+void ClusterOutliningPass::runOnModule() {
+  ModuleOp m = getModule();
+  ModuleManager module_manager(m);
+  OpBuilder builder(m.getContext());
+  m.walk([&](Operation* op) {
+    // TODO(b/138909768): Use templated Walk method instead of skipping
+    // operations according to their type string.
+    if (op->getName().getStringRef() != "tf_device.launch") return;
+
+    OutlineLaunch(op, &module_manager, &builder);
+  });
+}
+
+}  // namespace
+
+ModulePassBase* CreateClusterOutliningPass() {
+  return new ClusterOutliningPass();
+}
+
+static PassRegistration<ClusterOutliningPass> pass(
+    "tf-device-cluster-outlining",
+    "Outline regions of tf_device.launch operations.");
+
+}  // namespace TFDevice
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
new file mode 100644
index 0000000..12bd709
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
@@ -0,0 +1,348 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This transformation pass takes TFExecutor dialect IslandOps and merges them.
+// Note, this currently does not handle TensorFlow V1 style control flow/frames
+// or side effecting ops yet.
+
+#include <iterator>
+#include <tuple>
+
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/IR/Block.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace mlir {
+namespace TFExecutor {
+
+namespace {
+
+// IslandType is an enum representing if an island is the island (parent)
+// merging another island or is the island (child) being being merged.
+enum IslandType { kParentIsland, kChildIsland };
+
+// Output is a helper struct holding a result index and island type (parent or
+// child).
+struct Output {
+  Output(IslandType island_type, int result_index)
+      : island_type(island_type), result_index(result_index) {}
+
+  IslandType island_type;
+  int result_index;
+};
+
+struct ExecutorIslandCoarsening
+    : public FunctionPass<ExecutorIslandCoarsening> {
+  void runOnFunction() override;
+
+ private:
+  void MergeIslands(OpBuilder* builder, tf_executor::IslandOp* parent,
+                    tf_executor::IslandOp* child, IslandType insert_position);
+  bool MergeIslandWithOperand(OpBuilder* builder, tf_executor::IslandOp* child);
+  bool MergeIslandWithResult(OpBuilder* builder, tf_executor::IslandOp* parent);
+};
+
+// Finds the operation leading to an island that the island can be merged with.
+// This looks for the operation, either control input or data input to an op,
+// that is closest to the island in the graph. If no candidate can be found or
+// the op found is not an island, an empty optional is returned.
+llvm::Optional<tf_executor::IslandOp> GetOperandCandidateToMergeWith(
+    tf_executor::IslandOp* island) {
+  Operation* graph_op = island->getParentOp();
+  Operation* candidate = nullptr;
+
+  // Check island control operands.
+  for (Value* input : island->controlInputs()) {
+    Operation* def = input->getDefiningOp();
+    DCHECK_EQ(def->getParentOp(), graph_op);
+    if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
+  }
+
+  // Check island data operands.
+  island->walk([graph_op, &candidate](Operation* op) {
+    for (Value* input : op->getOperands()) {
+      Operation* def = input->getDefiningOp();
+      if (!def || def->getParentOp() != graph_op) continue;
+      if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
+    }
+  });
+
+  if (!candidate || !llvm::isa<tf_executor::IslandOp>(candidate))
+    return llvm::None;
+
+  return llvm::Optional<tf_executor::IslandOp>(
+      llvm::cast<tf_executor::IslandOp>(candidate));
+}
+
+// Finds the operation leading from an island that the island can be merged
+// with. This looks for the operation, either control output or data output to
+// an op, that is closest to the island in the graph. If no candidate can be
+// found or the op found is not an island, an empty optional is returned.
+llvm::Optional<tf_executor::IslandOp> GetResultCandidateToMergeWith(
+    tf_executor::IslandOp* island) {
+  Operation* graph_op = island->getParentOp();
+  Operation* candidate = nullptr;
+
+  // Check island control results.
+  for (Operation* user : island->control()->getUsers()) {
+    DCHECK_EQ(user->getParentOp(), graph_op);
+    if (!candidate || user->isBeforeInBlock(candidate)) candidate = user;
+  }
+
+  // Check island data results.
+  Block& graph_body = llvm::cast<tf_executor::GraphOp>(graph_op).GetBody();
+  for (Value* result : island->outputs()) {
+    for (Operation* user : result->getUsers()) {
+      Operation* def = graph_body.findAncestorInstInBlock(*user);
+      DCHECK_NE(def, nullptr);
+      if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
+    }
+  }
+
+  if (!candidate || !llvm::isa<tf_executor::IslandOp>(candidate))
+    return llvm::None;
+
+  return llvm::Optional<tf_executor::IslandOp>(
+      llvm::cast<tf_executor::IslandOp>(candidate));
+}
+
+// Collects the operands for the new island by collecting all control inputs of
+// the islands being merged.
+llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(
+    tf_executor::IslandOp* parent, tf_executor::IslandOp* child) {
+  llvm::SmallSetVector<Value*, 8> operands;
+  operands.insert(parent->getOperands().begin(), parent->getOperands().end());
+  operands.insert(child->getOperands().begin(), child->getOperands().end());
+  operands.remove(parent->control());
+  return operands;
+}
+
+// Collects the results for the new island by going through each data output of
+// the islands being merged. Unused results outside of the merged island to be
+// formed are pruned. If the child island inner ops consume the parent island
+// control output, the child island inner ops will have that respective control
+// input pruned. Results of the parent island that are consumed by the child
+// island are replaced by the respective inner ops output from the parent
+// island.
+llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
+    mlir::MLIRContext* context, tf_executor::IslandOp* parent,
+    tf_executor::IslandOp* child, llvm::SmallVector<Type, 8>* result_types) {
+  llvm::SmallVector<Output, 8> results;
+
+  Operation& last_op = parent->GetBody().back();
+  auto yield_op = cast<tf_executor::YieldOp>(last_op);
+  Block& child_body = child->GetBody();
+  for (auto& ret_and_idx : llvm::enumerate(parent->outputs())) {
+    bool output_captured = false;
+    Value* yield_input = yield_op.getOperand(ret_and_idx.index());
+    for (auto& use :
+         llvm::make_early_inc_range(ret_and_idx.value()->getUses())) {
+      if (child_body.findAncestorInstInBlock(*use.getOwner())) {
+        // Forward output from inner op.
+        use.set(yield_input);
+      } else if (!output_captured) {
+        results.push_back(
+            Output(IslandType::kParentIsland, ret_and_idx.index()));
+        result_types->push_back(ret_and_idx.value()->getType());
+        output_captured = true;
+      }
+    }
+  }
+
+  for (auto& ret_and_idx : llvm::enumerate(child->outputs())) {
+    if (!ret_and_idx.value()->use_empty()) {
+      results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index()));
+      result_types->push_back(ret_and_idx.value()->getType());
+    }
+  }
+
+  // IslandOps always have a control output.
+  result_types->push_back(tf_executor::ControlType::get(context));
+
+  return results;
+}
+
+// Creates the new merged island.
+tf_executor::IslandOp CreateNewIsland(
+    OpBuilder* builder, Operation* old_island,
+    const llvm::SmallVector<Type, 8>& result_types,
+    const llvm::SmallSetVector<Value*, 8>& operands) {
+  builder->setInsertionPoint(old_island);
+  auto new_island = builder->create<tf_executor::IslandOp>(
+      old_island->getLoc(), result_types, operands.getArrayRef(),
+      ArrayRef<NamedAttribute>{});
+  new_island.body().push_back(new Block);
+  return new_island;
+}
+
+// Creates respective YieldOp for the new merged island.
+tf_executor::YieldOp CreateNewIslandYieldOp(
+    OpBuilder* builder, tf_executor::IslandOp* new_island,
+    const llvm::SmallVector<Output, 8>& results, tf_executor::IslandOp* parent,
+    tf_executor::IslandOp* child) {
+  llvm::SmallVector<Value*, 8> yield_operands;
+  yield_operands.reserve(results.size());
+  for (auto ret_vals : llvm::zip(results, new_island->outputs())) {
+    // Get consumed output (island type and result index).
+    const auto& output = std::get<0>(ret_vals);
+    tf_executor::IslandOp* output_island =
+        output.island_type == IslandType::kParentIsland ? parent : child;
+    Value* result = output_island->getResult(output.result_index);
+    // Replace original result with new island result.
+    result->replaceAllUsesWith(std::get<1>(ret_vals));
+    // Find YieldOp in original island, grab the associated operand (inner op
+    // output) and add it as a operand to the YieldOp of the merged island.
+    yield_operands.push_back(
+        output_island->GetBody().back().getOperand(output.result_index));
+  }
+
+  // Create YieldOp for the new island.
+  builder->setInsertionPoint(&new_island->GetBody(),
+                             new_island->GetBody().end());
+  return builder->create<tf_executor::YieldOp>(new_island->getLoc(),
+                                               yield_operands);
+}
+
+// Moves inner ops (excluding last op/YieldOp) from islands being merged into
+// the new merged island.
+void MoveInnerOpsToNewIsland(tf_executor::IslandOp* parent,
+                             tf_executor::IslandOp* child,
+                             Operation* new_yield_op) {
+  Block* block = new_yield_op->getBlock();
+
+  auto move_inner_ops = [block, new_yield_op](tf_executor::IslandOp* island) {
+    auto& island_body = island->GetBody().getOperations();
+    block->getOperations().splice(new_yield_op->getIterator(), island_body,
+                                  island_body.begin(),
+                                  std::prev(island_body.end()));
+  };
+
+  move_inner_ops(parent);
+  move_inner_ops(child);
+}
+
+// Merges two islands and places new merged island before parent or child.
+void ExecutorIslandCoarsening::MergeIslands(OpBuilder* builder,
+                                            tf_executor::IslandOp* parent,
+                                            tf_executor::IslandOp* child,
+                                            IslandType insert_position) {
+  // Collect operands for the new merged island.
+  llvm::SmallSetVector<Value*, 8> operands =
+      GetNewIslandOperands(parent, child);
+
+  // Collect results and result types for the new merged island.
+  llvm::SmallVector<Type, 8> result_types;
+  llvm::SmallVector<Output, 8> results = GetNewIslandResultsAndForwardOutputs(
+      &getContext(), parent, child, &result_types);
+
+  // Create the new merged island.
+  tf_executor::IslandOp new_island = CreateNewIsland(
+      builder, insert_position == IslandType::kParentIsland ? *parent : *child,
+      result_types, operands);
+
+  // Create associated YieldOp for the new merged island.
+  tf_executor::YieldOp new_yield_op =
+      CreateNewIslandYieldOp(builder, &new_island, results, parent, child);
+
+  // Move inner ops from original islands into the new island.
+  MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
+
+  // Update control inputs to point to the new merged island.
+  child->control()->replaceAllUsesWith(new_island.control());
+  parent->control()->replaceAllUsesWith(new_island.control());
+
+  // Remove merged islands.
+  child->erase();
+  parent->erase();
+}
+
+// Merges island with the operand closest to the island in the graph. The
+// operand must be another IslandOp for merging to take place. A new island is
+// created and the islands being merged are removed if a merge took place.
+// Returns true if the island was merged with its operand.
+bool ExecutorIslandCoarsening::MergeIslandWithOperand(
+    OpBuilder* builder, tf_executor::IslandOp* child) {
+  // Find candidate operand to merge island with.
+  llvm::Optional<tf_executor::IslandOp> candidate =
+      GetOperandCandidateToMergeWith(child);
+  if (!candidate.hasValue()) return false;
+  auto& parent = candidate.getValue();
+  MergeIslands(builder, &parent, child, IslandType::kParentIsland);
+  return true;
+}
+
+// Merges island with the result closest to the island in the graph. The result
+// must be another IslandOp for merging to take place. A new island is created
+// and the islands being merged are removed if a merge took place. Returns true
+// if the island was merged with its result.
+bool ExecutorIslandCoarsening::MergeIslandWithResult(
+    OpBuilder* builder, tf_executor::IslandOp* parent) {
+  // Find candidate result to merge island with.
+  llvm::Optional<tf_executor::IslandOp> candidate =
+      GetResultCandidateToMergeWith(parent);
+  if (!candidate.hasValue()) return false;
+  auto& child = candidate.getValue();
+  MergeIslands(builder, parent, &child, IslandType::kChildIsland);
+  return false;
+}
+
+void ExecutorIslandCoarsening::runOnFunction() {
+  getFunction().walk<tf_executor::GraphOp>([this](tf_executor::GraphOp graph) {
+    Block& graph_body = graph.GetBody();
+    OpBuilder builder(&graph_body);
+
+    bool updated = false;
+    do {
+      updated = false;
+
+      auto reversed = llvm::reverse(graph_body);
+      for (Operation& operation : llvm::make_early_inc_range(reversed)) {
+        auto island = llvm::dyn_cast<tf_executor::IslandOp>(operation);
+        if (!island) continue;
+        updated |= MergeIslandWithResult(&builder, &island);
+      }
+
+      for (Operation& operation : llvm::make_early_inc_range(graph_body)) {
+        auto island = llvm::dyn_cast<tf_executor::IslandOp>(operation);
+        if (!island) continue;
+        updated |= MergeIslandWithOperand(&builder, &island);
+      }
+    } while (updated);
+  });
+}
+
+}  // namespace
+
+FunctionPassBase* CreateTFExecutorIslandCoarseningPass() {
+  return new ExecutorIslandCoarsening();
+}
+
+static PassRegistration<ExecutorIslandCoarsening> pass(
+    "tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps");
+
+}  // namespace TFExecutor
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
index 72775d0..38bc602 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
@@ -35,7 +35,7 @@
     OwningRewritePatternList patterns;
     auto func = getFunction();
     populateWithGenerated(&getContext(), &patterns);
-    applyPatternsGreedily(func, std::move(patterns));
+    applyPatternsGreedily(func, patterns);
   }
 };
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 1202d4d..1a7848f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -22,19 +22,31 @@
 namespace TF {
 // Transforms functional control flow operations in the standard TensorFlow
 // dialect to MLIR Control Flow Graph (CFG) form.
-FunctionPassBase *CreateTFFunctionalControlFlowToCFG();
+FunctionPassBase* CreateTFFunctionalControlFlowToCFG();
 
 // Optimizes Tensorflow graph.
-FunctionPassBase *CreateTFOptimizePass();
+FunctionPassBase* CreateTFOptimizePass();
 
 }  // namespace TF
 
 namespace TFControlFlow {
 // Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
 // dialect.
-FunctionPassBase *CreateRaiseTFControlFlowPass();
+FunctionPassBase* CreateRaiseTFControlFlowPass();
 
 }  // namespace TFControlFlow
+
+namespace TFExecutor {
+// Create a pass to merge IslandOps from TFExecutor dialect.
+FunctionPassBase* CreateTFExecutorIslandCoarseningPass();
+
+}  // namespace TFExecutor
+
+namespace TFDevice {
+// Creates a pass that outlines regions of tf_device.launch operations.
+ModulePassBase* CreateClusterOutliningPass();
+}  // namespace TFDevice
+
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
index 60f7ed3..c5f21fa 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
@@ -19,7 +19,7 @@
 #include "mlir/IR/Location.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/framework/function.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
index 546898f..e80dcef 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
@@ -45,9 +45,15 @@
 }  // end anonymous namespace
 
 static bool HasSingleGraph(FuncOp function) {
+  // We expect the function has only one region with one block,
   if (function.getBlocks().size() != 1) return false;
-  if (!std::next(function.begin()->begin())->isKnownTerminator()) return false;
-  if (!isa<tf_executor::GraphOp>(function.begin()->begin())) return false;
+  auto &block = function.front();
+  // and the block contains two ops,
+  if (std::next(block.begin()) == block.end()) return false;
+  // one GraphOp,
+  if (!isa<tf_executor::GraphOp>(block.begin())) return false;
+  // followed by a terminator.
+  if (!std::next(block.begin())->isKnownTerminator()) return false;
   return true;
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 3d98cdf..a3b1cc9 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -34,8 +34,10 @@
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/IR/Operation.h"  // TF:local_config_mlir
 #include "mlir/IR/Types.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/Support/DebugStringHelper.h"  // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
@@ -55,6 +57,11 @@
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
+namespace mlir {
+/// Create a pass to convert from the TFExecutor to the TF control dialect.
+FunctionPassBase* CreateTFExecutorToControlDialectConversion();
+}  // namespace mlir
+
 namespace tensorflow {
 using llvm::cast;
 using llvm::dyn_cast;
@@ -201,10 +208,8 @@
 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
     mlir::BlockArgument* arg, unsigned index) {
   auto node_def = absl::make_unique<NodeDef>();
-  node_def->set_name(UniqueName(arg->getContainingRegion()
-                                    ->getParentOfType<mlir::FuncOp>()
-                                    .getName()
-                                    .str()));
+  node_def->set_name(UniqueName(
+      arg->getParentRegion()->getParentOfType<mlir::FuncOp>().getName().str()));
   node_def->set_op(FunctionLibraryDefinition::kArgOp);
   DataType dtype;
   TF_RETURN_IF_ERROR(ConvertToDataType(
@@ -326,7 +331,7 @@
   // is an input node. We recover the original input node and skip adding the
   // argument node. The new input node will be handled as normal in the
   // following steps.
-  if (arg->getContainingRegion()->getParentOfType<mlir::FuncOp>().getName() ==
+  if (arg->getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
       "main") {
     if (!arg->hasOneUse()) {
       return errors::FailedPrecondition(
@@ -604,6 +609,12 @@
 Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs,
                           std::unique_ptr<Graph>* graph,
                           FunctionLibraryDefinition* flib_def) {
+  mlir::PassManager pass_manager;
+  pass_manager.addPass(mlir::CreateTFExecutorToControlDialectConversion());
+  if (mlir::failed(pass_manager.run(module))) {
+    return errors::FailedPrecondition(
+        "Failed to convert TFExecutor Dialect to Control Dialect.");
+  }
   return Exporter::Convert(module, confs, graph, flib_def);
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc
deleted file mode 100644
index 0b9012d..0000000
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc
+++ /dev/null
@@ -1,1411 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/escaping.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/strip.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Function.h"  // TF:local_config_mlir
-#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
-#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/jit/shape_inference_helpers.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/shape_refiner.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/versions.pb.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
-
-namespace tensorflow {
-using stream_executor::port::StatusOr;
-
-namespace {
-
-// Stateful helper class to import a GraphDef into an MLIR Module. The nodes
-// defined in the graph is converted to a function called "main". All the
-// library function definitions are converted to MLIR functions in the module.
-class Importer {
- public:
-  // Main entry point: converts the given graph to an MLIR Module.
-  static StatusOr<mlir::OwningModuleRef> Convert(
-      mlir::MLIRContext* context, const Graph& graph,
-      const GraphDebugInfo& debug_info,
-      const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs);
-
- private:
-  // Most types with subtypes have only one subtype.
-  using ElementSubtypes = llvm::SmallVector<mlir::TensorType, 1>;
-
-  explicit Importer(
-      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-      const NodeSpecs& specs, mlir::ModuleOp module,
-      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
-      : module_(module),
-        context_(module.getContext()),
-        tf_name_to_mlir_name_(tf_name_to_mlir_name),
-        graph_flib_(flib),
-        specs_(specs),
-        debug_info_(debug_info) {}
-
-  // Prepares converting the graph to an MLIR module. This step removes the
-  // backedges of the graph, orders the nodes and infers the shapes.
-  Status PrepareConvert(const Graph& graph);
-
-  // Returns the function signature of the main function of converted MLIR
-  // module, the input nodes and output nodes. The type and shape information
-  // for the function arguments are read from the specs_, but the type and shape
-  // information for the function returns are inferred by the shape_refiner_.
-  StatusOr<mlir::FunctionType> InferMainFunctionType(
-      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
-      absl::InlinedVector<OutputTensor, 4>* ret_nodes);
-
-  // Returns the inferred function signature of the given function body. Input
-  // types are unranked tensor of the respective datatype in the function and
-  // result types are inferred by the shape_refiner_. Result types need not be
-  // unranked tensors and could be ranked tensors in cases where result type
-  // depends on an op with static output shape like tf.Const.
-  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);
-
-  // Converts the prepared graph to a Function and adds it to the module. A set
-  // of nodes from the graph are given to converted to the arguments and returns
-  // of the function.
-  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
-                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
-                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
-                 llvm::ArrayRef<mlir::NamedAttribute> attrs);
-
-  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
-  // data type and shape information is maintained by the shape_refiner_.
-  Status AddNodesToShapeRefiner();
-
-  // Returns the inferred input type at index `idx` of the node in the context.
-  StatusOr<mlir::TensorType> InferInputType(
-      ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder);
-
-  // Returns the inferred output type at index `idx` of the node in the context.
-  StatusOr<mlir::TensorType> InferOutputType(
-      ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder);
-
-  // Converts the inferred shape referred to by 'handle' in 'context', with
-  // given element type, and returns an MLIR tensor type.
-  StatusOr<mlir::TensorType> ConvertDataTypeAndShape(
-      DataType dtype, const shape_inference::ShapeHandle& handle,
-      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
-      shape_inference::InferenceContext* context, mlir::Builder builder);
-
-  // Converts the inferred shape referred to by 'handle' in 'context', with
-  // given element type, and returns an MLIR tensor type.
-  StatusOr<mlir::TensorType> ConvertElementTypeAndShape(
-      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
-      shape_inference::InferenceContext* context, mlir::Builder builder);
-
-  // Converts the inferred subtypes for an element type to corresponding MLIR
-  // types in 'context'.
-  StatusOr<ElementSubtypes> ConvertSubtypes(
-      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
-      shape_inference::InferenceContext* context, mlir::Builder builder);
-
-  // Converts the tensor proto into an MLIR elements attribute.
-  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
-    return ::tensorflow::ConvertTensorProto(value, builder_.get());
-  }
-
-  // Converts func name in graphdef to mlir::SymbolRefAttribute.
-  StatusOr<mlir::SymbolRefAttr> ConvertFunctionCallName(
-      const std::string& func_name);
-
-  // Converts the given non-function-call AttrValue to an MLIR Attribute.
-  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);
-
-  // Converts the given function-call AttrValue to MLIR Attributes and pushes
-  // them to the given attributes list. For example, if there is a kFunc
-  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
-  // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
-  // {base_name.k2 : rfc}}.
-  Status ConvertFunctionCallAttribute(
-      const std::string& base_name, const AttrValue& value,
-      llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
-
-  // Converts one NodeDef from the input GraphDef into an Operation and
-  // inserts it into the MLIR module using builder_.
-  Status ConvertNode(const Node& node);
-
-  // If the input graph represents a while-loop, the edges pointing from a
-  // "NextIteration" node to a "Merge" node add cyclic dependencies and make the
-  // topological sorting impossible. We need to remove these edges from the
-  // input graph to infer shapes and construct a Function. For each
-  // "NextIteration" node, there are two operations, "NextIteration.source"
-  // and "NextIteration.sink" are added to the MLIR module.
-  using BackEdge = BackEdgeHelper::BackEdge;
-
-  // Removes backedges from the input graph. The removed edges are added back to
-  // to OpBuilder after the remaining graph is converted to the Function.
-  Status RemoveBackedges(const Graph& graph);
-
-  // Restores backedges removed during shape inference to the final Function.
-  Status AddBackedges();
-
-  // Restores a single backedge in the Function by adding a replicated
-  // operation before the dst operation.
-  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
-                     int dst_input);
-
-  // Gets the "source" of a NextIteration operation. If it doesn't exist,
-  // creates and inserts it to the front of the basic block.
-  mlir::Operation* GetOrCreateNextIterationSource(mlir::Operation* sink,
-                                                  mlir::Operation* dst);
-
-  // Finds out the function definition for the given function name from the
-  // graph and converts it to a function of the module. This method is called
-  // on demand because the graph flib_def does not provide an iterator
-  // interface. The consequence is that only the referred functions are added to
-  // the MLIR module.
-  Status ConvertLibFunction(const std::string& func_name);
-
-  // Adds the input arguments and return operation to the function. The
-  // arguments are added as basic block argument. Also the argument types and
-  // the id of the nodes from the input graph needs to be specified.
-  Status ConvertFunctionArgAndRets(
-      mlir::Block* bb, llvm::ArrayRef<mlir::Type> arg_types,
-      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
-      const absl::InlinedVector<OutputTensor, 4>& ret_nodes);
-
-  // Gets the location information of the given node. It uses the
-  // "original_node_name" in the NodeDef to get the corresponding file location
-  // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If
-  // there are multiple "original_node_names", a FusedLoc is returned. If the
-  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
-  // the location.
-  mlir::Location GetLocation(const NodeDef& node);
-
-  // Gets the location information string for the given node.
-  std::string GetLocationStr(const Node& node, bool includeNodeName = false);
-
-  // Inserts a placeholder node in the graph to replace the input node. Replaces
-  // all the output edges of the input_node with the placeholder node, and
-  // removes the input_node from the graph. The new node has the same name as
-  // the input_node, so Nodespecs do not need any modification.
-  // Note: This modifies the graph, and so any list of ordered nodes needs to be
-  // reconstructed.
-  StatusOr<Node*> ReplaceWithPlaceholderNode(const TensorShapeProto& shape,
-                                             DataType dtype, Node* input_node);
-
-  // Gets the input and output nodes corresponding to the specified input and
-  // output nodes in specs_. If there are no input or output nodes specified,
-  // nodes will be empty
-  Status GetInputOutputNodes(std::unordered_set<const Node*>* nodes);
-
-  // The input graph with backedges removed. The removed backedges are stored
-  // in the back_edge_helper.
-  BackEdgeHelper back_edge_helper_;
-  // A map between node and output index, for each backedge.
-  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
-  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
-  // A map between sink and source operation of NextIteration
-  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
-      next_iteration_sink_source_;
-
-  // All nodes and version information about the (copied) imported graph.
-  std::unique_ptr<Graph> graph_;
-  const VersionDef* graph_versions_;
-  std::vector<Node*> ordered_nodes_;
-
-  // Maps from a Node ID to a MLIR value.
-  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;
-
-  std::unique_ptr<mlir::OpBuilder> builder_;
-  mlir::ModuleOp module_;
-  mlir::MLIRContext* context_;
-  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
-  const FunctionLibraryDefinition& graph_flib_;
-  const NodeSpecs& specs_;
-  const GraphDebugInfo& debug_info_;
-  NodeValueMap node_values_;
-  std::unique_ptr<ShapeRefiner> shape_refiner_;
-};
-
-// Adds the default attributes to each node def if they are missing from the
-// GraphDef.
-Status AddDefaultsToNodeDef(GraphDef* graph_def) {
-  const tensorflow::OpRegistrationData* op_reg_data;
-  for (auto& node_def : *graph_def->mutable_node()) {
-    auto status =
-        tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
-    if (!status.ok()) {
-      // This is likely a function call node, so we should continue.
-      VLOG(1) << status.ToString();
-      continue;
-    }
-    ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
-  }
-  return Status::OK();
-}
-
-Status Importer::RemoveBackedges(const Graph& graph) {
-  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
-  // clone a graph.
-  // TODO(fengliuai): clone the graph without going to graph_def first.
-  GraphDef graph_def;
-  graph.ToGraphDef(&graph_def);
-  graph_ = absl::make_unique<Graph>(graph.flib_def());
-  GraphConstructorOptions opts;
-  opts.allow_internal_ops = true;
-  TF_RETURN_IF_ERROR(
-      ::tensorflow::ConvertGraphDefToGraph(opts, graph_def, graph_.get()));
-
-  // Remove all the backedges. So the nodes can be added to the shape refiner.
-  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
-  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
-          << " backedges.";
-
-  // Creates a map for quickly identifying whether a node output is a backedge.
-  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
-    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
-        back_edge_node_output_[edge.src] != edge.src_output) {
-      return errors::FailedPrecondition(
-          "More than one of the src node outputs are backedges!");
-    }
-    back_edge_node_output_[edge.src] = edge.src_output;
-    // We expect a merge to receive a single backedge (multiple NextIteration
-    // nodes feeding into the same merge is unexpected here).
-    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
-    back_edge_dst_inputs_[edge.dst] = edge;
-  }
-
-  // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
-  GetReversePostOrder(
-      *graph_, &ordered_nodes_,
-      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
-
-  return Status::OK();
-}
-
-StatusOr<Node*> Importer::ReplaceWithPlaceholderNode(
-    const TensorShapeProto& shape, DataType dtype, Node* input_node) {
-  Node* placeholder_node;
-  NodeBuilder builder(input_node->name(), "Placeholder");
-  builder.Attr("shape", shape);
-  builder.Attr("dtype", dtype);
-  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));
-
-  while (!input_node->out_edges().empty()) {
-    const Edge* oe = *input_node->out_edges().begin();
-    TF_RETURN_IF_ERROR(graph_->UpdateEdge(
-        placeholder_node,
-        oe->src_output() == Graph::kControlSlot ? Graph::kControlSlot : 0,
-        oe->dst(), oe->dst_input()));
-  }
-
-  graph_->RemoveNode(input_node);
-
-  return placeholder_node;
-}
-
-Status Importer::GetInputOutputNodes(std::unordered_set<const Node*>* nodes) {
-  auto node_name_map = graph_->BuildNodeNameIndex();
-  auto add_node = [&](const string& name) {
-    auto it = node_name_map.find(name);
-    if (it == node_name_map.end()) {
-      return errors::FailedPrecondition(
-          absl::StrCat("Graph does not contain node :", name));
-    }
-    nodes->insert(it->second);
-    return Status::OK();
-  };
-
-  for (const auto& input : specs_.inputs) {
-    TF_RETURN_IF_ERROR(add_node(input.first));
-  }
-
-  for (const auto& output_node_name : specs_.output_arrays) {
-    TF_RETURN_IF_ERROR(add_node(output_node_name));
-  }
-
-  return Status::OK();
-}
-
-// TODO(fengliuai): Replace the iterative algorithm by an one pass propagation
-Status Importer::AddNodesToShapeRefiner() {
-  shape_refiner_ =
-      absl::make_unique<ShapeRefiner>(*graph_versions_, graph_->op_registry());
-  // Some operations (for example "TPUExecute") don't have shape inference
-  // function defined, so we should set this to false for adding nodes with
-  // these types of operations.
-  shape_refiner_->set_require_shape_inference_fns(false);
-  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);
-
-  // First add all nodes to the refiner.
-  for (Node* node : ordered_nodes_) {
-    // We need to use a TensorFlow node to teach the shape refiner that user
-    // specifies certain data type and shape for the inputs in the `specs_`.
-    // This node shouldn't have any inputs, only have one output and its
-    // output type/shape is only determined by its "named" attributes. (The
-    // attributes should have fixed names so we can use the info from `specs_`
-    // to set the value of them.) `Placeholder` satisfies these constraints.
-    //
-    // Therefore, if the input node isn't a `Placeholder`, we create one and use
-    // it to replace the original input node, so the shape refiner can
-    // successfully propagate the user's input type and shape to the rest of the
-    // graph.
-    auto it = specs_.inputs.find(node->name());
-    if (it != specs_.inputs.end()) {
-      auto node_name = node->op_def().name();
-      if (node_name != "Placeholder" && node_name != "LegacyFedInput") {
-        // We do not handle the case where the input node has multple outputs
-        if (node->num_outputs() > 1) {
-          return errors::FailedPrecondition(absl::StrCat(
-              "Input arrays can only have op with single output. Node op:",
-              node_name));
-        }
-        // For single output nodes, replace them with Placeholder node
-        TF_ASSIGN_OR_RETURN(
-            node, ReplaceWithPlaceholderNode(it->second.shape,
-                                             it->second.imported_dtype, node));
-      } else {
-        node->AddAttr("shape", it->second.shape);
-        node->AddAttr("dtype", it->second.imported_dtype);
-      }
-    }
-    // Adds the node to the shape refiner.
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
-                                    GetLocationStr(*node));
-
-    // If it is the argument node, the shape handle is set explicitly, so it
-    // can be propagated to the body nodes of the function.
-    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
-      auto* node_context = shape_refiner_->GetContext(node);
-      DCHECK(node_context != nullptr);
-      auto it = node->def().attr().find("shape");
-      if (it != node->def().attr().end()) {
-        shape_inference::ShapeHandle handle;
-        TF_RETURN_WITH_CONTEXT_IF_ERROR(
-            node_context->MakeShapeFromShapeProto(it->second.shape(), &handle),
-            GetLocationStr(*node));
-        node_context->set_output(0, handle);
-      } else {
-        node_context->set_output(0, node_context->UnknownShape());
-      }
-    }
-  }
-
-  // Since we might have inserted and removed nodes from the graph, fix
-  // source/sink edges and reconstruct the RPO ordering of nodes
-  FixupSourceAndSinkEdges(graph_.get());
-
-  // Prune nodes in the graph that are not reachable from the output.
-  if (specs_.prune_unused_nodes) {
-    std::unordered_set<const Node*> prune_start;
-    TF_RETURN_IF_ERROR(GetInputOutputNodes(&prune_start));
-    if (!prune_start.empty()) {
-      if (PruneForReverseReachability(graph_.get(), prune_start)) {
-        VLOG(1) << "Pruned unused nodes in graphdef";
-      } else {
-        VLOG(1) << "No unused nodes in graphdef to prune";
-      }
-    } else {
-      VLOG(1) << "No output nodes specified, skipping pruning";
-    }
-  } else {
-    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
-  }
-
-  // Re-initialize ordered_nodes_ since we might have modified the graph.
-  GetReversePostOrder(
-      *graph_, &ordered_nodes_,
-      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
-
-  VLOG(1) << "Inferring graph shapes to fixpoint";
-
-  // The "changed" information from UpdateNode can give false positives, so we
-  // create a dedicated method to verify the shapes are not changed before and
-  // after the shape refine.
-  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
-                                shape_inference::ShapeHandle s0,
-                                shape_inference::ShapeHandle s1) -> bool {
-    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
-      return true;
-    }
-    if (c->Rank(s0) != c->Rank(s1)) {
-      return false;
-    }
-    for (int i = 0; i < c->Rank(s0); ++i) {
-      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
-        int64 val0 = c->Value(c->Dim(s0, i));
-        int64 val1 = c->Value(c->Dim(s1, i));
-        // Negative value is treated as unknown so all negative values indicate
-        // the same dimension.
-        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
-      }
-    }
-    return true;
-  };
-
-  bool changed = true;
-  int i = 0;
-  const int kMaxIterationCount = 2;
-  while (changed && i != kMaxIterationCount) {
-    changed = false;
-    for (const Node* node : ordered_nodes_) {
-      auto* shape_context = shape_refiner_->GetContext(node);
-      DCHECK(shape_context != nullptr);
-      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
-      existing.reserve(shape_context->num_outputs());
-      for (int o = 0; o < shape_context->num_outputs(); ++o) {
-        existing.push_back(shape_context->output(o));
-      }
-      bool inferred = false;
-      TF_RETURN_WITH_CONTEXT_IF_ERROR(
-          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred),
-          GetLocationStr(*node));
-      for (int o = 0; o < shape_context->num_outputs(); ++o) {
-        if (!same_inferred_shape(shape_context, shape_context->output(o),
-                                 existing[o])) {
-          changed = true;
-          break;
-        }
-      }
-    }
-    ++i;
-  }
-  if (i >= kMaxIterationCount) {
-    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
-                 << kMaxIterationCount
-                 << " iterations. Graph shapes may be conservative.";
-  }
-  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
-          << " extra rounds of analysis to reach a fixpoint.";
-  return Status::OK();
-}
-
-StatusOr<mlir::TensorType> Importer::InferInputType(
-    ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder) {
-  DataType dtype = shape_context->input_type(idx);
-  auto* context = shape_context->get_context();
-  return ConvertDataTypeAndShape(dtype, context->input(idx),
-                                 context->input_handle_shapes_and_types(idx),
-                                 context, builder);
-}
-
-StatusOr<mlir::TensorType> Importer::InferOutputType(
-    ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder) {
-  DataType dtype = shape_context->output_type(idx);
-  auto* context = shape_context->get_context();
-  return ConvertDataTypeAndShape(dtype, context->output(idx),
-                                 context->output_handle_shapes_and_types(idx),
-                                 context, builder);
-}
-
-StatusOr<mlir::TensorType> Importer::ConvertDataTypeAndShape(
-    DataType dtype, const shape_inference::ShapeHandle& handle,
-    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
-    shape_inference::InferenceContext* context, mlir::Builder builder) {
-  TF_ASSIGN_OR_RETURN(auto subtypes,
-                      ConvertSubtypes(handle_subtypes, context, builder));
-
-  // TODO(hinsu): Store subtypes information for DT_RESOURCE element type as
-  // well.
-  mlir::Type element_type;
-  if (dtype == DT_VARIANT) {
-    element_type = mlir::TF::VariantType::get(subtypes, context_);
-  } else {
-    TF_RETURN_IF_ERROR(
-        ::tensorflow::ConvertDataType(dtype, builder, &element_type));
-  }
-  return ConvertElementTypeAndShape(element_type, handle, context, builder);
-}
-
-StatusOr<mlir::TensorType> Importer::ConvertElementTypeAndShape(
-    mlir::Type element_type, const shape_inference::ShapeHandle& handle,
-    shape_inference::InferenceContext* context, mlir::Builder builder) {
-  if (!context->RankKnown(handle)) {
-    return builder.getTensorType(element_type);
-  }
-
-  // Sentinel for an unknown dimension size. getTensorType interprets any
-  // negative value as an unknown dimension.
-  // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
-  const int64_t kUnknownDim = -1;
-
-  absl::InlinedVector<int64_t, 4> dimensions;
-  int32 rank = context->Rank(handle);
-  dimensions.reserve(rank);
-  for (int i = 0; i < rank; ++i) {
-    auto dim_handle = context->Dim(handle, i);
-    if (!context->ValueKnown(dim_handle))
-      dimensions.push_back(kUnknownDim);
-    else
-      dimensions.push_back(context->Value(dim_handle));
-  }
-
-  return builder.getTensorType(
-      llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
-}
-
-StatusOr<Importer::ElementSubtypes> Importer::ConvertSubtypes(
-    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
-    shape_inference::InferenceContext* context, mlir::Builder builder) {
-  ElementSubtypes subtypes;
-  if (!handle_subtypes) return subtypes;
-
-  subtypes.reserve(handle_subtypes->size());
-  for (const auto& subtype : *handle_subtypes) {
-    mlir::Type element_type;
-    TF_RETURN_IF_ERROR(
-        ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
-    TF_ASSIGN_OR_RETURN(mlir::TensorType type,
-                        ConvertElementTypeAndShape(element_type, subtype.shape,
-                                                   context, builder));
-    subtypes.push_back(type);
-  }
-  return subtypes;
-}
-
-Status Importer::ConvertFunctionCallAttribute(
-    const std::string& base_name, const AttrValue& value,
-    llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) {
-  TF_ASSIGN_OR_RETURN(auto func_attr,
-                      ConvertFunctionCallName(value.func().name()));
-  attributes->push_back(builder_->getNamedAttr(base_name, func_attr));
-
-  for (const auto& it : value.func().attr()) {
-    auto name = absl::StrCat(base_name, ".", it.first);
-    TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
-    attributes->push_back(builder_->getNamedAttr(name, value));
-  }
-  return Status::OK();
-}
-
-StatusOr<mlir::SymbolRefAttr> Importer::ConvertFunctionCallName(
-    const std::string& func_name) {
-  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
-  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
-  auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
-  return builder_->getSymbolRefAttr(func);
-}
-
-StatusOr<mlir::Attribute> Importer::ConvertAttributeValue(
-    const AttrValue& value) {
-  switch (value.value_case()) {
-    case AttrValue::kI:
-      return builder_->getI64IntegerAttr(value.i());
-    case AttrValue::kS:
-      return builder_->getStringAttr(value.s());
-    case AttrValue::kF:
-      return builder_->getFloatAttr(builder_->getF32Type(), value.f());
-    case AttrValue::kB:
-      return builder_->getBoolAttr(value.b());
-    case AttrValue::kType:
-      return builder_->getStringAttr(
-          mangling_util::MangleDataType(value.type()));
-    case AttrValue::kShape:
-      return builder_->getStringAttr(mangling_util::MangleShape(value.shape()));
-    case AttrValue::kTensor:
-      return ConvertTensorProto(value.tensor());
-    case AttrValue::kList: {
-      absl::InlinedVector<mlir::Attribute, 8> attrs;
-      for (const auto& item : value.list().i())
-        attrs.push_back(builder_->getI64IntegerAttr(item));
-      for (const auto& item : value.list().s())
-        attrs.push_back(builder_->getStringAttr(item));
-      for (const auto& item : value.list().f())
-        attrs.push_back(builder_->getFloatAttr(builder_->getF32Type(), item));
-      for (const auto& item : value.list().b())
-        attrs.push_back(builder_->getBoolAttr(item));
-      for (const auto& item : value.list().type()) {
-        attrs.push_back(builder_->getStringAttr(
-            mangling_util::MangleDataType(static_cast<DataType>(item))));
-      }
-      for (const auto& item : value.list().shape()) {
-        attrs.push_back(
-            builder_->getStringAttr(mangling_util::MangleShape(item)));
-      }
-      for (const auto& item : value.list().tensor()) {
-        TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item));
-        attrs.push_back(attr);
-      }
-      for (const auto& item : value.list().func()) {
-        TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
-        if (item.attr_size() != 0)
-          return errors::Unimplemented(
-              "func attributes with non-zero attr.size()");
-        attrs.push_back(attr);
-      }
-      return builder_->getArrayAttr(
-          llvm::makeArrayRef(attrs.begin(), attrs.end()));
-    }
-    case AttrValue::kFunc:
-      return errors::Unknown("kFunc type should be handled separately!");
-    case AttrValue::VALUE_NOT_SET:
-      return builder_->getUnitAttr();
-    // kPlaceholder is not implemented.
-    default:
-      return errors::Unimplemented(
-          absl::StrCat("Attribute ", value.DebugString()));
-  }
-}
-
-Status Importer::ConvertLibFunction(const std::string& func_name) {
-  // If the library function has been converted already, nothing needs to be
-  // done.
-  if (tf_name_to_mlir_name_->find(func_name) != tf_name_to_mlir_name_->end())
-    return Status::OK();
-
-  std::string mlir_func_name = graph_flib_.UniqueFunctionName(func_name);
-  (*tf_name_to_mlir_name_)[func_name] = mlir_func_name;
-
-  const auto& func_lib = graph_flib_;
-  const auto* func_def = func_lib.Find(func_name);
-  if (func_def == nullptr) {
-    return errors::FailedPrecondition(
-        absl::StrCat("Failed to find function '", func_name,
-                     "'. The imported TensorFlow GraphDef is ill-formed."));
-  }
-
-  // Converts the function definition to a graph.
-  std::unique_ptr<FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(
-      FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody));
-
-  // Converts the argument and return types to mlir types.
-  absl::InlinedVector<mlir::NamedAttribute, 8> attributes;
-  attributes.reserve(func_def->attr_size());
-  for (const auto& name_and_value : func_def->attr()) {
-    // This is a function definition attribute, so it shouldn't contain
-    // kFunc attribute and it is treated as normal one.
-    TF_ASSIGN_OR_RETURN(auto attr,
-                        ConvertAttributeValue(name_and_value.second));
-    std::string attr_name =
-        mangling_util::MangleAttributeName(name_and_value.first);
-    attributes.push_back(builder_->getNamedAttr(attr_name, attr));
-  }
-
-  // Checks opdef stateful attribute and import that as Function Attribute
-  if (func_def->signature().is_stateful()) {
-    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
-    attributes.push_back(
-        builder_->getNamedAttr(stateful_str, builder_->getUnitAttr()));
-  }
-
-  // Checks for an associated custom gradient function. Adds it to the attribute
-  // list of this function.
-  auto grad_func_name = func_lib.FindGradient(func_name);
-  if (!grad_func_name.empty()) {
-    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
-    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
-    auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
-    auto gradient_attr = builder_->getSymbolRefAttr(grad_func);
-    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
-    attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr));
-  }
-
-  // Converts the graph to a MLIR function and adds it to the module. Uses the
-  // default node spec without any inputs or outputs as the function graph has
-  // special '_Arg' and '_Retval' ops for argument and return values.
-  NodeSpecs specs;
-  Importer child_importer(graph_flib_, debug_info_, specs, module_,
-                          tf_name_to_mlir_name_);
-  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
-
-  TF_ASSIGN_OR_RETURN(auto func_type,
-                      child_importer.InferLibFunctionType(*fbody));
-
-  absl::InlinedVector<OutputTensor, 4> arg_nodes;
-  arg_nodes.reserve(fbody->arg_nodes.size());
-  absl::InlinedVector<OutputTensor, 4> ret_nodes;
-  ret_nodes.reserve(fbody->ret_nodes.size());
-  for (auto arg : fbody->arg_nodes) {
-    arg_nodes.emplace_back(arg, 0);
-  }
-  for (auto ret : fbody->ret_nodes) {
-    ret_nodes.emplace_back(ret, 0);
-  }
-
-  TF_RETURN_IF_ERROR(child_importer.Convert(
-      mlir_func_name, func_type, arg_nodes, ret_nodes,
-      llvm::makeArrayRef(attributes.begin(), attributes.end())));
-  return Status::OK();
-}
-
-Status Importer::PrepareConvert(const Graph& graph) {
-  graph_versions_ = &graph.versions();
-  TF_RETURN_IF_ERROR(RemoveBackedges(graph));
-  TF_RETURN_IF_ERROR(AddNodesToShapeRefiner());
-  return Status::OK();
-}
-
-Status Importer::ConvertFunctionArgAndRets(
-    mlir::Block* bb, llvm::ArrayRef<mlir::Type> arg_types,
-    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
-    const absl::InlinedVector<OutputTensor, 4>& ret_nodes) {
-  for (int i = 0, e = arg_types.size(); i < e; ++i) {
-    auto* inst = node_values_[arg_nodes[i].node->id()];
-    auto* bb_arg = bb->addArgument(arg_types[i]);
-    mlir::Value* arg_def = bb_arg;
-
-    // If this is an input node add argument to the operation operands by
-    // creating a new input operation.
-    if (StringPiece(arg_nodes[i].node->type_string()) !=
-        FunctionLibraryDefinition::kArgOp) {
-      auto inst_name = inst->getName().getStringRef();
-      mlir::OperationState state(inst->getLoc(),
-                                 inst_name.str().append(".input"));
-      state.attributes.append(inst->getAttrs().begin(), inst->getAttrs().end());
-
-      // If there are quantization specifications, add them as the attributes
-      auto name = inst->getAttrOfType<mlir::StringAttr>("name").getValue();
-      auto input_spec_it = specs_.inputs.find(name.str());
-      if (input_spec_it != specs_.inputs.end()) {
-        auto input_spec = input_spec_it->second;
-        if (IsQuantizationType(input_spec.final_dtype)) {
-          // Uses the MLIR built-in type so it can be handled easily later.
-          auto final_type = mlir::IntegerType::get(
-              GetQuantizationTypeWidth(input_spec.final_dtype), context_);
-          state.attributes.push_back(builder_->getNamedAttr(
-              "min", builder_->getF32FloatAttr(input_spec.min_value)));
-          state.attributes.push_back(builder_->getNamedAttr(
-              "max", builder_->getF32FloatAttr(input_spec.max_value)));
-          state.attributes.push_back(builder_->getNamedAttr(
-              "type", builder_->getTypeAttr(final_type)));
-          inst->getParentOfType<mlir::FuncOp>().setAttr(
-              "tf.quantize", builder_->getUnitAttr());
-        }
-      }
-
-      for (auto* r : inst->getResults()) state.types.push_back(r->getType());
-
-      state.operands.append(inst->getOperands().begin(),
-                            inst->getOperands().end());
-      state.operands.push_back(bb_arg);
-      builder_->setInsertionPoint(inst);
-      auto* input = builder_->createOperation(state);
-      arg_def = input->getResult(arg_nodes[i].index);
-      // Verify on the equivalent TF op would have failed, but catching this
-      // earlier for now as this exposed a bug. TODO(jpienaar): remove post
-      // dialect refactoring.
-      DCHECK(input->getResult(0)->getType() == input->getOperand(0)->getType())
-          << "invalid placeholder_input constructed";
-    }
-
-    for (auto index = 0; index < inst->getNumResults(); index++) {
-      inst->getResult(index)->replaceAllUsesWith(arg_def);
-    }
-    inst->dropAllReferences();
-    inst->erase();
-  }
-
-  absl::InlinedVector<mlir::Value*, 8> inst_to_returned;
-  for (const auto& ret : ret_nodes) {
-    auto* inst = node_values_[ret.node->id()];
-    auto op = absl::string_view(ret.node->type_string());
-    if (op == FunctionLibraryDefinition::kRetOp ||
-        op == FunctionLibraryDefinition::kDeviceRetOp) {
-      // Remove kRetOp or kDeviceRetOp operation and return its operand.
-      // kRetOp and kDeviceRetOp should have just one operand unless they have
-      // control dependencies.
-      if (inst->getNumOperands() != 1)
-        return errors::Unimplemented("Return node with multiple inputs.");
-      inst_to_returned.push_back(inst->getOperand(0));
-      node_values_[ret.node->id()]->dropAllReferences();
-      node_values_[ret.node->id()]->erase();
-    } else {
-      inst_to_returned.push_back(inst->getResult(ret.index));
-    }
-  }
-  builder_->setInsertionPointToEnd(bb);
-  builder_->create<mlir::ReturnOp>(
-      mlir::UnknownLoc::get(context_),
-      llvm::makeArrayRef(inst_to_returned.begin(), inst_to_returned.end()));
-  return Status::OK();
-}
-
-mlir::Location Importer::GetLocation(const NodeDef& node_def) {
-  const auto& debug_info = debug_info_.traces();
-
-  // Get the CallSiteLoc for a node name.
-  // - If the debug info of the node couldn't be found, the caller of the
-  //   returned CallSiteLoc is set to an UnknownLoc;
-  // - If the debug info of the node is found, the caller of the returned
-  //   CallSiteLoc is set to a call stack which is formed by the debug info.
-  auto node_name_to_call_site = [&](const std::string& name) -> mlir::Location {
-    auto name_id = mlir::Identifier::get(name, context_);
-    const auto& location_it = debug_info.find(name);
-    if (location_it == debug_info.end()) {
-      // Only the node name is stored if the location is unknown.
-      return mlir::NameLoc::get(name_id, context_);
-    }
-
-    // Convert the stack trace to a chain of mlir::CallSiteLocs.
-    const auto& trace = location_it->second;
-    llvm::SmallVector<mlir::Location, 4> locations;
-    locations.reserve(trace.file_line_cols_size());
-    for (const auto& location : trace.file_line_cols()) {
-      const auto& file = debug_info_.files(location.file_index());
-      auto file_name = mlir::Identifier::get(file, context_);
-      auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
-                                                     location.col(), context_);
-      locations.push_back(file_line_loc);
-    }
-    // Handle empty location vector.
-    if (locations.empty()) return mlir::NameLoc::get(name_id, context_);
-
-    // Use the front FileLineColLoc to generate a NameLoc.
-    mlir::Location node_name_loc =
-        mlir::NameLoc::get(name_id, locations.front(), context_);
-
-    // If there are more locations then generate a stack trace, otherwise just
-    // return the name loc.
-    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
-    return callsite_locs.empty()
-               ? node_name_loc
-               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs, context_);
-  };
-
-  // For NextIteration nodes, location is used to pair source and sink nodes.
-  // Hence, we use node name as location to keep it unique.
-  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
-  // nodes. Then NextIteration nodes would not need to be handled seprately.
-  if (node_def.op() == "NextIteration")
-    return node_name_to_call_site(node_def.name());
-
-  auto original_nodes =
-      node_def.experimental_debug_info().original_node_names();
-  auto original_funcs =
-      node_def.experimental_debug_info().original_func_names();
-
-  if (original_nodes.empty()) {
-    // If the original nodes are not defined in the node def, but the current
-    // node name is contained in the debug info file, then we fall back to use
-    // the current node name to get the location info. Otherwise, use a
-    // NameLoc with node name as in a TensorFlow graph the node name is unique.
-    auto& curr_node_name = node_def.name();
-    if (debug_info.find(curr_node_name) == debug_info.end()) {
-      return mlir::NameLoc::get(mlir::Identifier::get(curr_node_name, context_),
-                                context_);
-    } else {
-      return node_name_to_call_site(curr_node_name);
-    }
-  } else {
-    // If the original nodes are defined, then we use them to get a list of
-    // call sites, and then fuse them to a single fused location.
-    llvm::SmallVector<mlir::Location, 4> node_call_sites;
-    node_call_sites.reserve(original_nodes.size());
-    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
-      auto node_name = original_nodes[i];
-      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
-      // Use the catenation of function and node names as the lookup key. This
-      // is to match the utility of generating the GraphDebugInfo.
-      node_call_sites.push_back(node_name_to_call_site(func_name + node_name));
-    }
-    return mlir::FusedLoc::get(node_call_sites, context_);
-  }
-}
-
-std::string Importer::GetLocationStr(const Node& node, bool includeNodeName) {
-  const auto location = GetLocation(node.def());
-  std::string s;
-  llvm::raw_string_ostream ss(s);
-  location.print(ss);
-  ss.flush();
-  // Removes the node name prefix if it exists.
-  if (!s.empty() && s[0] == '\"' && s.find_first_of(node.name()) == 1) {
-    return s.replace(0, node.name().size() + 3, "");
-  }
-  return s;
-}
-
-Status Importer::ConvertNode(const Node& node) {
-  if (!node.IsOp()) {
-    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
-    // Graph and don't exist in GraphDef.
-    return Status::OK();
-  }
-
-  // If it is a custom OP, its definition should be found in the library. We
-  // create the MLIR function and insert it to the module if it doesn't exist.
-  std::string node_type_name = node.type_string();
-  const auto* func_def = graph_flib_.Find(node_type_name);
-  if (func_def) {
-    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
-    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
-  }
-
-  auto get_full_op_name = [&](const std::string& op_name) {
-    const char* kTfControlFlowFormPrefix = "_tf.";
-    return kTfControlFlowFormPrefix + op_name;
-  };
-
-  std::string op_name = get_full_op_name(node_type_name);
-  if (back_edge_node_output_.contains(&node)) {
-    op_name = op_name + ".sink";
-  }
-
-  const auto& node_def = node.def();
-  mlir::OperationState result(GetLocation(node_def), op_name);
-
-  ExtendedInferenceContext* context = shape_refiner_->GetExtendedContext(&node);
-  for (int i = 0; i < node.num_outputs(); ++i) {
-    // The backedge has been removed, so we shouldn't count the corresponding
-    // output from the src node when converting to an operation.
-    if (back_edge_node_output_.contains(&node) &&
-        back_edge_node_output_[&node] == i) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(context, i, *builder_));
-    result.types.push_back(type);
-  }
-  result.types.push_back(
-      builder_->getType<mlir::TFControlFlow::TFControlType>());
-
-  // Surprisingly input edges can be nondeterministically ordered. This
-  // particularly seems to be the case for the control edges between _SOURCE
-  // and _SINK that the Graph constructor inserts. Copy the input edges and
-  // sort the edges, but only the control edges, not data edges!
-  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
-  // They'll break roundtripping anyway unless we strip them when converting
-  // back to graphdef.
-  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
-  absl::c_copy(node.in_edges(), in_edges.begin());
-  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
-    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
-    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
-    return e1->dst_input() < e2->dst_input();
-  });
-
-  result.operands.reserve(in_edges.size());
-  for (const auto* input_edge : in_edges) {
-    const Node& input_node = *input_edge->src();
-    if (input_node.IsSource()) {
-      if (in_edges.size() != 1) {
-        return errors::FailedPrecondition(
-            "The node has other inputs besides the _Source node");
-      }
-      // We don't import the _SOURCE node.
-      continue;
-    }
-    if (input_node.IsArg() && input_edge->IsControlEdge()) {
-      // Currently we have not reached consensus as to what TF function
-      // semantics are (b/133509504). Here we assume that all arguments to a
-      // function should be available before we start execution of any internal
-      // node. This makes the control dependencies between function arguments
-      // and internal nodes redundant, and so we do not import them. The TF
-      // inliner however assumes no such dependency between function args and
-      // internal nodes exists, unless explicitly stated. Since we drop control
-      // dependencies here, it leads to loss of information. If the function is
-      // inlined later, the inliner would not know of these explicit control
-      // dependencies present in the original graph.
-      continue;
-    }
-    if (node_values_.find(input_node.id()) == node_values_.end())
-      return errors::FailedPrecondition(
-          "Graph not traversed in reverse post order; use seen before def!");
-    mlir::Operation* inst = node_values_[input_node.id()];
-    result.operands.push_back(inst->getResult(input_edge->IsControlEdge()
-                                                  ? inst->getNumResults() - 1
-                                                  : input_edge->src_output()));
-  }
-
-  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
-  std::vector<FuncPairType> funcs;
-  result.attributes.reserve(node.attrs().size() + 2);
-  for (const auto& name_and_value : node.attrs()) {
-    const auto& attr_name = name_and_value.first;
-    const AttrValue& attr_value = name_and_value.second;
-    if (attr_value.value_case() == AttrValue::kFunc) {
-      // Attribute iteration order is not defined for protocol buffer Map.
-      // Process function attributes separately in the lexicographical order to
-      // have deterministic order of functions in the constructed IR.
-      funcs.emplace_back(&attr_name, &attr_value);
-    } else {
-      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
-      result.attributes.push_back(builder_->getNamedAttr(attr_name, attr));
-    }
-  }
-
-  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
-    return *a.first < *b.first;
-  };
-  std::sort(funcs.begin(), funcs.end(), comparator);
-  for (const auto& func : funcs) {
-    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
-                                                    &result.attributes));
-  }
-
-  result.attributes.push_back(builder_->getNamedAttr(
-      "name", builder_->getStringAttr(std::string(node.name()))));
-  result.attributes.push_back(builder_->getNamedAttr(
-      "device", builder_->getStringAttr(std::string(node_def.device()))));
-
-  // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
-  // the differentiating attribute.
-  if (node.IsIfNode()) {
-    result.name = mlir::OperationName(get_full_op_name("If"), context_);
-    mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
-    result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
-  }
-
-  node_values_[node.id()] = builder_->createOperation(result);
-  return Status::OK();
-}
-
-// Add the backedges to the CFG. Given a backedge, we replace the original
-// source and destination operations by two new operations. Most of the
-// fields of the replacements are copied from the original operations.
-// However,
-// - for the src operation, one output is inserted to the front of the output
-//   list. The type of the output is set to the type of the non-control result
-//   of the dst operation, and
-// - for the dst operation, one operand is inserted to the front of the
-//   operand list. This operand is using the first result of the src
-//   operation.
-// TODO(fengliuai): Preserve the order of the results and operands if
-// necessary.
-Status Importer::AddBackedges() {
-  for (auto it : back_edge_dst_inputs_) {
-    BackEdge& edge = it.second;
-    if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
-      return errors::FailedPrecondition(
-          "Invalid backedge; should be from NextIteration to Merge!");
-    }
-    auto* sink = node_values_[edge.src->id()];
-    auto* dst = node_values_[edge.dst->id()];
-    TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
-  }
-  return Status::OK();
-}
-
-Status Importer::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
-                             int dst_input) {
-  mlir::Operation* source = GetOrCreateNextIterationSource(sink, dst);
-
-  // Adds the "source" to the operands of the dst by creating a new dst
-  // operation.
-  mlir::OperationState state(dst->getLoc(), dst->getName());
-  auto num_operands = dst->getNumOperands();
-  state.operands.reserve(num_operands + 1);
-  for (int input = 0, e = num_operands + 1; input != e; ++input) {
-    if (input < dst_input) {
-      state.operands.push_back(dst->getOperand(input));
-    } else if (input == dst_input) {
-      state.operands.push_back(source->getResult(0));
-    } else {
-      state.operands.push_back(dst->getOperand(input - 1));
-    }
-  }
-  state.attributes.append(dst->getAttrs().begin(), dst->getAttrs().end());
-  for (auto* result : dst->getResults()) {
-    state.types.push_back(result->getType());
-  }
-  builder_->setInsertionPoint(dst);
-  auto* new_dst = builder_->createOperation(state);
-
-  // Replaces the output uses of the old operation by the corresponding
-  // result of the new operation, and deletes the old operation.
-  for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
-    auto* new_output = new_dst->getResult(i);
-    dst->getResult(i)->replaceAllUsesWith(new_output);
-  }
-  dst->dropAllReferences();
-  dst->erase();
-  return Status::OK();
-}
-
-mlir::Operation* Importer::GetOrCreateNextIterationSource(
-    mlir::Operation* sink, mlir::Operation* dst) {
-  auto iter = next_iteration_sink_source_.find(sink);
-  if (iter != next_iteration_sink_source_.end()) return iter->second;
-
-  auto inst_name = sink->getName().getStringRef();
-  inst_name.consume_back(".sink");
-  mlir::OperationState src_state(sink->getLoc(),
-                                 inst_name.str().append(".source"));
-  src_state.attributes.append(sink->getAttrs().begin(), sink->getAttrs().end());
-  src_state.types.push_back(dst->getResult(0)->getType());
-  src_state.types.push_back(
-      builder_->getType<mlir::TFControlFlow::TFControlType>());
-  builder_->setInsertionPoint(dst->getBlock(), dst->getBlock()->begin());
-  mlir::Operation* source = builder_->createOperation(src_state);
-  next_iteration_sink_source_[sink] = source;
-  return source;
-}
-
-Status Importer::Convert(llvm::StringRef func_name,
-                         mlir::FunctionType func_type,
-                         const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
-                         const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
-                         llvm::ArrayRef<mlir::NamedAttribute> attrs) {
-  // TODO(b/122040776): Uses debug info for FunctionDef.
-  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
-                                       func_name, func_type, attrs);
-
-  module_.push_back(function);
-  builder_ = absl::make_unique<mlir::OpBuilder>(function.getBody());
-  // Seeds the builder with an initial block.
-  auto* bb = builder_->createBlock(&function.getBody());
-
-  for (const Node* node : ordered_nodes_) {
-    TF_RETURN_IF_ERROR(ConvertNode(*node));
-  }
-
-  // Adds the backedges back to the function by creating the source and sink
-  // pairs.
-  TF_RETURN_IF_ERROR(AddBackedges());
-
-  return ConvertFunctionArgAndRets(bb, func_type.getInputs(), arg_nodes,
-                                   ret_nodes);
-}
-
-StatusOr<mlir::FunctionType> Importer::InferMainFunctionType(
-    absl::InlinedVector<OutputTensor, 4>* arg_nodes,
-    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
-  // Finds out all the input nodes and output nodes.
-  if (!specs_.inputs.empty() || !specs_.output_arrays.empty()) {
-    arg_nodes->resize(specs_.inputs.size());
-    ret_nodes->resize(specs_.output_arrays_order.size());
-
-    for (Node* n : ordered_nodes_) {
-      // Handle inputs/arguments.
-      auto input_it = specs_.inputs.find(n->name());
-      if (input_it != specs_.inputs.end()) {
-        (*arg_nodes)[std::distance(specs_.inputs.begin(), input_it)] = {n, 0};
-      }
-
-      // Handle outputs/returns.
-      if (specs_.output_arrays.find(n->name()) != specs_.output_arrays.end()) {
-        for (int i = 0, e = specs_.output_arrays_order.size(); i != e; ++i) {
-          std::pair<std::string, std::string> name_and_port =
-              absl::StrSplit(specs_.output_arrays_order[i], ':');
-          auto name = name_and_port.first;
-          if (name != n->name()) continue;
-          int port = 0;
-          if (!name_and_port.second.empty() &&
-              !absl::SimpleAtoi(name_and_port.second, &port)) {
-            return errors::InvalidArgument("Invalid port specification: ",
-                                           specs_.output_arrays_order[i]);
-          }
-          (*ret_nodes)[i] = {n, port};
-        }
-      }
-    }
-  }
-
-  int i = 0;
-  for (auto it : specs_.inputs) {
-    if (arg_nodes->at(i++).node == nullptr) {
-      return errors::InvalidArgument("Input ", it.first,
-                                     " was not found in graph");
-    }
-  }
-  for (int i = 0, e = specs_.output_arrays_order.size(); i != e; ++i) {
-    if (ret_nodes->at(i).node == nullptr) {
-      return errors::InvalidArgument("Output ", specs_.output_arrays_order[i],
-                                     " was not found in graph");
-    }
-  }
-
-  // Starts to construct the function type.
-  llvm::SmallVector<mlir::Type, 4> arg_types;
-  llvm::SmallVector<mlir::Type, 4> ret_types;
-  arg_types.reserve(specs_.inputs.size());
-  ret_types.reserve(specs_.output_arrays.size());
-  mlir::Builder builder(context_);
-
-  // Input nodes as function arguments.
-  for (const auto& input : specs_.inputs) {
-    mlir::Type element_type;
-    const auto& node_info = input.second;
-    TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype,
-                                                     builder, &element_type));
-    llvm::SmallVector<int64_t, 4> shape;
-    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
-    arg_types.push_back(builder.getTensorType(shape, element_type));
-  }
-
-  // Output nodes as function returns.
-  for (const auto& ret : *ret_nodes) {
-    if (ret.node->num_outputs() < 1) {
-      return errors::FailedPrecondition(
-          "Invalid output node; should have at least 1 output: " +
-          ret.node->name());
-    }
-    auto* shape_context = shape_refiner_->GetExtendedContext(ret.node);
-    TF_ASSIGN_OR_RETURN(auto type,
-                        InferOutputType(shape_context, ret.index, builder));
-    ret_types.push_back(type);
-  }
-
-  return builder.getFunctionType(arg_types, ret_types);
-}
-
-StatusOr<mlir::FunctionType> Importer::InferLibFunctionType(
-    const FunctionBody& fbody) {
-  mlir::Builder builder(context_);
-
-  llvm::SmallVector<mlir::Type, 4> arg_types;
-  arg_types.reserve(fbody.arg_types.size());
-  for (auto dataType : fbody.arg_types) {
-    mlir::Type element_type;
-    TF_RETURN_IF_ERROR(
-        ::tensorflow::ConvertDataType(dataType, builder, &element_type));
-    // TODO(hinsu): Derive shape of function arguments based on shapes available
-    // at call sites of this function. That way it is possible to have a
-    // partially known shape in some cases instead of unranked tensor types.
-    arg_types.push_back(builder.getTensorType(element_type));
-  }
-
-  llvm::SmallVector<mlir::Type, 4> ret_types;
-  ret_types.reserve(fbody.ret_types.size());
-  for (auto ret : fbody.ret_nodes) {
-    // Find node in the graph using the node id instead of using `ret` directly
-    // because the graph has been cloned.
-    auto* node = graph_->FindNodeId(ret->id());
-    auto* shape_context = shape_refiner_->GetExtendedContext(node);
-
-    // Return type of the function is type of the only input of the respective
-    // return node in the function.
-    TF_ASSIGN_OR_RETURN(auto type,
-                        InferInputType(shape_context, /*idx=*/0, builder));
-    ret_types.push_back(type);
-  }
-
-  return builder.getFunctionType(arg_types, ret_types);
-}
-
-StatusOr<mlir::OwningModuleRef> Importer::Convert(
-    mlir::MLIRContext* context, const Graph& graph,
-    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
-    const NodeSpecs& specs) {
-  mlir::OwningModuleRef module =
-      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
-  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
-  Importer importer(flib_def, debug_info, specs, module.get(),
-                    &tf_name_to_mlir_name);
-  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
-
-  // Collects the argument and return nodes by looking up the node names
-  // specified by the user.
-  absl::InlinedVector<OutputTensor, 4> arg_nodes;
-  absl::InlinedVector<OutputTensor, 4> ret_nodes;
-  TF_ASSIGN_OR_RETURN(auto func_type,
-                      importer.InferMainFunctionType(&arg_nodes, &ret_nodes));
-
-  // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
-  // tf.versions) shared by importer and exporter in a centralized place.
-  // Record the input and output mapping.
-  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
-  if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
-    mlir::Builder b(context);
-    std::string s;
-    llvm::raw_string_ostream ss(s);
-    mlir::interleaveComma(
-        specs.inputs, ss,
-        [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; });
-    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
-    s.clear();
-    mlir::interleaveComma(specs.output_arrays, ss,
-                          [&](const std::string& v) { ss << v; });
-    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
-
-    attrs.push_back(b.getNamedAttr("tf.entry_function",
-                                   b.getDictionaryAttr({inputs, outputs})));
-  }
-
-  // Record version info.
-  if (importer.graph_versions_) {
-    mlir::Builder b(context);
-    auto producer = b.getNamedAttr(
-        "producer", b.getI32IntegerAttr(importer.graph_versions_->producer()));
-    auto min_consumer = b.getNamedAttr(
-        "min_consumer",
-        b.getI32IntegerAttr(importer.graph_versions_->min_consumer()));
-    auto bad_consumers = b.getNamedAttr(
-        "bad_consumers", b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
-                             importer.graph_versions_->bad_consumers().begin(),
-                             importer.graph_versions_->bad_consumers().end())));
-    module->setAttr("tf.versions",
-                    b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
-                        {producer, min_consumer, bad_consumers})));
-  }
-
-  TF_RETURN_IF_ERROR(
-      importer.Convert("main", func_type, arg_nodes, ret_nodes, attrs));
-  return module;
-}
-}  // namespace
-
-StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
-    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
-    const NodeSpecs& specs, mlir::MLIRContext* context,
-    bool add_default_attributes) {
-  GraphConstructorOptions options;
-  options.allow_internal_ops = true;
-  Graph graph(OpRegistry::Global());
-
-  GraphDef preprocessed_graphdef(graphdef);
-  if (add_default_attributes) {
-    TF_RETURN_IF_ERROR(AddDefaultsToNodeDef(&preprocessed_graphdef));
-  }
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));
-
-  return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
-                            context);
-}
-
-StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
-    const Graph& graph, const GraphDebugInfo& debug_info,
-    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
-    mlir::MLIRContext* context) {
-  return Importer::Convert(context, graph, debug_info, flib_def, specs);
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h
deleted file mode 100644
index c494526..0000000
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_
-#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_
-
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
-#include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace tensorflow {
-
-// Given a GraphDef, returns a MLIR module containing the graph in control-flow
-// form.
-stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
-    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
-    const NodeSpecs& specs, mlir::MLIRContext* context,
-    bool add_default_attributes = true);
-
-// Given a Graph, returns a MLIR module containing the graph in control-flow
-// form.
-stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
-    const Graph& graph, const GraphDebugInfo& debug_info,
-    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
-    mlir::MLIRContext* context);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
new file mode 100644
index 0000000..33d696d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -0,0 +1,1594 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
+#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "mlir/IR/Types.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/jit/shape_inference_helpers.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
+
+namespace tensorflow {
+using stream_executor::port::StatusOr;
+
+namespace {
+
+// Stateful helper class to import a TensorFlow model into an MLIR Module.
+//
+// This is the base class that contains common utilties shared between the
+// GraphDef importer and SavedModel importer.
+//
+// A subclass is expected to call `PrepareConvert` first to perform necessary
+// preparation over the graph and also certain internal bookkeeping data.
+// Afterwards the other protected methods can be called.
+class ImporterBase {
+ protected:
+  explicit ImporterBase(
+      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
+      const NodeSpecs& specs, mlir::ModuleOp module,
+      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
+      : module_(module),
+        context_(module.getContext()),
+        tf_name_to_mlir_name_(tf_name_to_mlir_name),
+        graph_flib_(flib),
+        specs_(specs),
+        debug_info_(debug_info) {}
+
+  // Prepares converting the graph to an MLIR module. This step removes the
+  // backedges of the graph, orders the nodes and infers the shapes.
+  Status PrepareConvert(const Graph& graph);
+
+  // Converts the prepared graph to a Function and adds it to the module. A set
+  // of nodes from the graph are given to converted to the arguments and returns
+  // of the function.
+  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
+                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
+                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
+                 llvm::ArrayRef<mlir::NamedAttribute> attrs);
+
+  // Returns the list of nodes in the graph. Nodes are presented in the reverse
+  // order of a post-order depth-first visit starting from the graph's source
+  // nodes.
+  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }
+
+  // Returns the inferred output type at index `idx` of the `node` in the
+  // context.
+  StatusOr<mlir::TensorType> InferOutputType(const Node& node, int idx,
+                                             mlir::Builder builder);
+
+ private:
+  // Most types with subtypes have only one subtype.
+  using ElementSubtypes = llvm::SmallVector<mlir::TensorType, 1>;
+
+  // Returns the inferred function signature of the given function body. Input
+  // types are unranked tensor of the respective datatype in the function and
+  // result types are inferred by the shape_refiner_. Result types need not be
+  // unranked tensors and could be ranked tensors in cases where result type
+  // depends on an op with static output shape like tf.Const.
+  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);
+
+  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
+  // data type and shape information is maintained by the shape_refiner_.
+  Status AddNodesToShapeRefiner();
+
+  // Returns the inferred input type at index `idx` of the `node` in the
+  // context.
+  StatusOr<mlir::TensorType> InferInputType(const Node& node, int idx,
+                                            mlir::Builder builder);
+
+  // Converts the inferred shape referred to by 'handle' in 'context', with
+  // given element type, and returns an MLIR tensor type.
+  StatusOr<mlir::TensorType> ConvertDataTypeAndShape(
+      DataType dtype, const shape_inference::ShapeHandle& handle,
+      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
+      shape_inference::InferenceContext* context, mlir::Builder builder);
+
+  // Converts the inferred shape referred to by 'handle' in 'context', with
+  // given element type, and returns an MLIR tensor type.
+  StatusOr<mlir::TensorType> ConvertElementTypeAndShape(
+      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
+      shape_inference::InferenceContext* context, mlir::Builder builder);
+
+  // Converts the inferred subtypes for an element type to corresponding MLIR
+  // types in 'context'.
+  StatusOr<ElementSubtypes> ConvertSubtypes(
+      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
+      shape_inference::InferenceContext* context, mlir::Builder builder);
+
+  // Converts the tensor proto into an MLIR elements attribute.
+  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
+    return ::tensorflow::ConvertTensorProto(value, builder_.get());
+  }
+
+  // Converts func name in graphdef to mlir::SymbolRefAttribute.
+  StatusOr<mlir::SymbolRefAttr> ConvertFunctionCallName(
+      const std::string& func_name);
+
+  // Converts the given non-function-call AttrValue to an MLIR Attribute.
+  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);
+
+  // Converts the given function-call AttrValue to MLIR Attributes and pushes
+  // them to the given attributes list. For example, if there is a kFunc
+  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
+  // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
+  // {base_name.k2 : rfc}}.
+  Status ConvertFunctionCallAttribute(
+      const std::string& base_name, const AttrValue& value,
+      llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
+
+  // Helper to create either a tf_executor operation or a TF operation wrapped
+  // in an island.
+  mlir::Operation* createOperation(
+      const Node& node, llvm::StringRef op_name,
+      const mlir::OperationState& result,
+      const llvm::SmallVectorImpl<mlir::Value*>& control_operands);
+
+  // Converts one NodeDef from the input GraphDef into an Operation and
+  // inserts it into the MLIR module using builder_.
+  Status ConvertNode(const Node& node);
+
+  // If the input graph represents a while-loop, the edges pointing from a
+  // "NextIteration" node to a "Merge" node add cyclic dependencies and make the
+  // topological sorting impossible. We need to remove these edges from the
+  // input graph to infer shapes and construct a Function. For each
+  // "NextIteration" node, there are two operations, "NextIteration.source"
+  // and "NextIteration.sink" are added to the MLIR module.
+  using BackEdge = BackEdgeHelper::BackEdge;
+
+  // Removes backedges from the input graph. The removed edges are added back to
+  // to OpBuilder after the remaining graph is converted to the Function.
+  Status RemoveBackedges(const Graph& graph);
+
+  // Restores backedges removed during shape inference to the final Function.
+  Status AddBackedges();
+
+  // Restores a single backedge in the Function by adding a replicated
+  // operation before the dst operation.
+  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
+                     int dst_input);
+
+  // Finds out the function definition for the given function name from the
+  // graph and converts it to a function of the module. This method is called
+  // on demand because the graph flib_def does not provide an iterator
+  // interface. The consequence is that only the referred functions are added to
+  // the MLIR module.
+  Status ConvertLibFunction(const std::string& func_name);
+
+  // Adds the input arguments and return operation to the function. The
+  // arguments are added as basic block argument. Also the argument types and
+  // the id of the nodes from the input graph needs to be specified.
+  Status ConvertFunctionArgAndRets(
+      mlir::Block* bb, mlir::tf_executor::GraphOp graph_op,
+      llvm::ArrayRef<mlir::Type> arg_types,
+      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
+      const absl::InlinedVector<OutputTensor, 4>& ret_nodes);
+
+  // Gets the location information of the given node. It uses the
+  // "original_node_name" in the NodeDef to get the corresponding file location
+  // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If
+  // there are multiple "original_node_names", a FusedLoc is returned. If the
+  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
+  // the location.
+  mlir::Location GetLocation(const NodeDef& node);
+
+  // Gets the location information string for the given node.
+  std::string GetLocationStr(const Node& node, bool includeNodeName = false);
+
+  // Inserts a placeholder node in the graph to replace the input node. Replaces
+  // all the output edges of the input_node with the placeholder node, and
+  // removes the input_node from the graph. The new node has the same name as
+  // the input_node, so Nodespecs do not need any modification.
+  // Note: This modifies the graph, and so any list of ordered nodes needs to be
+  // reconstructed.
+  StatusOr<Node*> ReplaceWithPlaceholderNode(const TensorShapeProto& shape,
+                                             DataType dtype, Node* input_node);
+
+  // Gets the input and output nodes corresponding to the specified input and
+  // output nodes in specs_. If there are no input or output nodes specified,
+  // nodes will be empty
+  Status GetInputOutputNodes(std::unordered_set<const Node*>* nodes);
+
+  // The input graph with backedges removed. The removed backedges are stored
+  // in the back_edge_helper.
+  BackEdgeHelper back_edge_helper_;
+  // A map between node and output index, for each backedge.
+  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
+  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
+  // A map between sink and source operation of NextIteration
+  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
+      next_iteration_sink_source_;
+
+  // All nodes and version information about the (copied) imported graph.
+  std::unique_ptr<Graph> graph_;
+  std::vector<Node*> ordered_nodes_;
+
+  // Maps from a Node ID to a MLIR value.
+  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;
+
+  std::unique_ptr<mlir::OpBuilder> builder_;
+  mlir::ModuleOp module_;
+  mlir::MLIRContext* context_;
+  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
+  const FunctionLibraryDefinition& graph_flib_;
+  const NodeSpecs& specs_;
+  const GraphDebugInfo& debug_info_;
+  NodeValueMap node_values_;
+  std::unique_ptr<ShapeRefiner> shape_refiner_;
+};
+
+// Returns true if the node with given name has a non primary output that is
+// used by some other node as an input. Returns false if no outputs are in use
+// or only the first output is in use.
+bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
+                              const std::string& node) {
+  for (const auto& node_def : graph_def.node()) {
+    for (const auto& input : node_def.input()) {
+      if (absl::StartsWith(input, node + ":") && input != node + ":0") {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Updates the given LegacyFedInput node with Placeholder node if it is one of
+// the inputs. Returns an error if non primary output of the LegacyFedInput node
+// is in use and therefore can not be replaced by the Placeholder node that only
+// has a single output.
+Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
+                                const NodeSpecs::InputArrays& inputs,
+                                NodeDef* node) {
+  const std::string& node_name = node->name();
+  auto it = inputs.find(node_name);
+
+  // Node is not an input.
+  if (it == inputs.end()) return Status::OK();
+
+  if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
+    return errors::InvalidArgument(
+        "LegacyFedInput node ", node->name(),
+        " has non primary output in use and can not be replaced with "
+        "Placeholder node");
+  }
+
+  // Update op name, drop inputs and set attributes required by the Placeholder
+  // op.
+  *node->mutable_op() = "Placeholder";
+  node->clear_attr();
+  node->clear_input();
+  AddNodeAttr("dtype", it->second.imported_dtype, node);
+  AddNodeAttr("shape", it->second.shape, node);
+  return Status::OK();
+}
+
+// Preprocesses GraphDef before it can be converted to Graph by,
+// - Adding the default attributes to each node def if they are missing from
+//   the GraphDef.
+// - Replacing LegacyFedInput nodes with Placeholder nodes if
+//   convert_legacy_fed_inputs option is enabled.
+Status PreprocessGraphDef(const NodeSpecs& specs, GraphDef* graph_def) {
+  const tensorflow::OpRegistrationData* op_reg_data;
+  for (auto& node_def : *graph_def->mutable_node()) {
+    // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
+    // solution could be have a tool to let users upgrade old serialized graphs.
+    if (specs.convert_legacy_fed_inputs && node_def.op() == "LegacyFedInput") {
+      TF_RETURN_IF_ERROR(
+          UpdateLegacyFedInputNode(*graph_def, specs.inputs, &node_def));
+    }
+
+    auto status =
+        tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
+    if (!status.ok()) {
+      // This is likely a function call node, so we should continue.
+      VLOG(1) << status.ToString();
+      continue;
+    }
+    ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
+  }
+  return Status::OK();
+}
+
+Status ImporterBase::RemoveBackedges(const Graph& graph) {
+  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
+  // clone a graph.
+  // TODO(fengliuai): clone the graph without going to graph_def first.
+  GraphDef graph_def;
+  graph.ToGraphDef(&graph_def);
+  graph_ = absl::make_unique<Graph>(graph.flib_def());
+  GraphConstructorOptions opts;
+  opts.allow_internal_ops = true;
+  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
+      opts, std::move(graph_def), graph_.get()));
+
+  // Remove all the backedges. So the nodes can be added to the shape refiner.
+  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
+  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
+          << " backedges.";
+
+  // Creates a map for quickly identifying whether a node output is a backedge.
+  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
+    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
+        back_edge_node_output_[edge.src] != edge.src_output) {
+      return errors::FailedPrecondition(
+          "More than one of the src node outputs are backedges!");
+    }
+    back_edge_node_output_[edge.src] = edge.src_output;
+    // We expect a merge to receive a single backedge (multiple NextIteration
+    // nodes feeding into the same merge is unexpected here).
+    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
+    back_edge_dst_inputs_[edge.dst] = edge;
+  }
+
+  // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
+  GetReversePostOrder(
+      *graph_, &ordered_nodes_,
+      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
+
+  return Status::OK();
+}
+
+StatusOr<Node*> ImporterBase::ReplaceWithPlaceholderNode(
+    const TensorShapeProto& shape, DataType dtype, Node* input_node) {
+  Node* placeholder_node;
+  NodeBuilder builder(input_node->name(), "Placeholder");
+  builder.Attr("shape", shape);
+  builder.Attr("dtype", dtype);
+  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));
+
+  while (!input_node->out_edges().empty()) {
+    const Edge* oe = *input_node->out_edges().begin();
+    TF_RETURN_IF_ERROR(graph_->UpdateEdge(
+        placeholder_node,
+        oe->src_output() == Graph::kControlSlot ? Graph::kControlSlot : 0,
+        oe->dst(), oe->dst_input()));
+  }
+
+  graph_->RemoveNode(input_node);
+
+  return placeholder_node;
+}
+
+Status ImporterBase::GetInputOutputNodes(
+    std::unordered_set<const Node*>* nodes) {
+  auto node_name_map = graph_->BuildNodeNameIndex();
+  auto add_node = [&](const string& name) {
+    auto it = node_name_map.find(name);
+    if (it == node_name_map.end()) {
+      return errors::FailedPrecondition(
+          absl::StrCat("Graph does not contain node :", name));
+    }
+    nodes->insert(it->second);
+    return Status::OK();
+  };
+
+  for (const auto& input : specs_.inputs) {
+    TF_RETURN_IF_ERROR(add_node(input.first));
+  }
+
+  for (const auto& output_node_name : specs_.output_arrays) {
+    TF_RETURN_IF_ERROR(add_node(output_node_name));
+  }
+
+  return Status::OK();
+}
+
+// TODO(fengliuai): Replace the iterative algorithm by an one pass propagation
+Status ImporterBase::AddNodesToShapeRefiner() {
+  shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(),
+                                                   graph_->op_registry());
+  // Some operations (for example "TPUExecute") don't have shape inference
+  // function defined, so we should set this to false for adding nodes with
+  // these types of operations.
+  shape_refiner_->set_require_shape_inference_fns(false);
+  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);
+
+  // First add all nodes to the refiner.
+  for (Node* node : ordered_nodes_) {
+    // We need to use a TensorFlow node to teach the shape refiner that user
+    // specifies certain data type and shape for the inputs in the `specs_`.
+    // This node shouldn't have any inputs, only have one output and its
+    // output type/shape is only determined by its "named" attributes. (The
+    // attributes should have fixed names so we can use the info from `specs_`
+    // to set the value of them.) `Placeholder` satisfies these constraints.
+    //
+    // Therefore, if the input node isn't a `Placeholder`, we create one and use
+    // it to replace the original input node, so the shape refiner can
+    // successfully propagate the user's input type and shape to the rest of the
+    // graph.
+    auto it = specs_.inputs.find(node->name());
+    if (it != specs_.inputs.end()) {
+      auto node_name = node->op_def().name();
+      if (node_name != "Placeholder" && node_name != "LegacyFedInput") {
+        // We do not handle the case where the input node has multple outputs
+        if (node->num_outputs() > 1) {
+          return errors::FailedPrecondition(absl::StrCat(
+              "Input arrays can only have op with single output. Node op:",
+              node_name));
+        }
+        // For single output nodes, replace them with Placeholder node
+        TF_ASSIGN_OR_RETURN(
+            node, ReplaceWithPlaceholderNode(it->second.shape,
+                                             it->second.imported_dtype, node));
+      } else {
+        node->AddAttr("shape", it->second.shape);
+        node->AddAttr("dtype", it->second.imported_dtype);
+      }
+    }
+    // Adds the node to the shape refiner.
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
+                                    GetLocationStr(*node));
+
+    // If it is the argument node, the shape handle is set explicitly, so it
+    // can be propagated to the body nodes of the function.
+    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
+      auto* node_context = shape_refiner_->GetContext(node);
+      DCHECK(node_context != nullptr);
+      auto it = node->def().attr().find("shape");
+      if (it != node->def().attr().end()) {
+        shape_inference::ShapeHandle handle;
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(
+            node_context->MakeShapeFromShapeProto(it->second.shape(), &handle),
+            GetLocationStr(*node));
+        node_context->set_output(0, handle);
+      } else {
+        node_context->set_output(0, node_context->UnknownShape());
+      }
+    }
+  }
+
+  // Since we might have inserted and removed nodes from the graph, fix
+  // source/sink edges and reconstruct the RPO ordering of nodes
+  FixupSourceAndSinkEdges(graph_.get());
+
+  // Prune nodes in the graph that are not reachable from the output.
+  if (specs_.prune_unused_nodes) {
+    std::unordered_set<const Node*> prune_start;
+    TF_RETURN_IF_ERROR(GetInputOutputNodes(&prune_start));
+    if (!prune_start.empty()) {
+      if (PruneForReverseReachability(graph_.get(), prune_start)) {
+        VLOG(1) << "Pruned unused nodes in graphdef";
+      } else {
+        VLOG(1) << "No unused nodes in graphdef to prune";
+      }
+    } else {
+      VLOG(1) << "No output nodes specified, skipping pruning";
+    }
+  } else {
+    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
+  }
+
+  // Re-initialize ordered_nodes_ since we might have modified the graph.
+  GetReversePostOrder(
+      *graph_, &ordered_nodes_,
+      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
+
+  VLOG(1) << "Inferring graph shapes to fixpoint";
+
+  // The "changed" information from UpdateNode can give false positives, so we
+  // create a dedicated method to verify the shapes are not changed before and
+  // after the shape refine.
+  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
+                                shape_inference::ShapeHandle s0,
+                                shape_inference::ShapeHandle s1) -> bool {
+    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
+      return true;
+    }
+    if (c->Rank(s0) != c->Rank(s1)) {
+      return false;
+    }
+    for (int i = 0; i < c->Rank(s0); ++i) {
+      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
+        int64 val0 = c->Value(c->Dim(s0, i));
+        int64 val1 = c->Value(c->Dim(s1, i));
+        // Negative value is treated as unknown so all negative values indicate
+        // the same dimension.
+        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
+      }
+    }
+    return true;
+  };
+
+  bool changed = true;
+  int i = 0;
+  const int kMaxIterationCount = 2;
+  while (changed && i != kMaxIterationCount) {
+    changed = false;
+    for (const Node* node : ordered_nodes_) {
+      auto* shape_context = shape_refiner_->GetContext(node);
+      DCHECK(shape_context != nullptr);
+      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
+      existing.reserve(shape_context->num_outputs());
+      for (int o = 0; o < shape_context->num_outputs(); ++o) {
+        existing.push_back(shape_context->output(o));
+      }
+      bool inferred = false;
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(
+          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred),
+          GetLocationStr(*node));
+      for (int o = 0; o < shape_context->num_outputs(); ++o) {
+        if (!same_inferred_shape(shape_context, shape_context->output(o),
+                                 existing[o])) {
+          changed = true;
+          break;
+        }
+      }
+    }
+    ++i;
+  }
+  if (i >= kMaxIterationCount) {
+    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
+                 << kMaxIterationCount
+                 << " iterations. Graph shapes may be conservative.";
+  }
+  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
+          << " extra rounds of analysis to reach a fixpoint.";
+  return Status::OK();
+}
+
+StatusOr<mlir::TensorType> ImporterBase::InferInputType(const Node& node,
+                                                        int idx,
+                                                        mlir::Builder builder) {
+  ExtendedInferenceContext* shape_context =
+      shape_refiner_->GetExtendedContext(&node);
+  DataType dtype = shape_context->input_type(idx);
+  auto* context = shape_context->get_context();
+  return ConvertDataTypeAndShape(dtype, context->input(idx),
+                                 context->input_handle_shapes_and_types(idx),
+                                 context, builder);
+}
+
+StatusOr<mlir::TensorType> ImporterBase::InferOutputType(
+    const Node& node, int idx, mlir::Builder builder) {
+  ExtendedInferenceContext* shape_context =
+      shape_refiner_->GetExtendedContext(&node);
+  DataType dtype = shape_context->output_type(idx);
+  auto* context = shape_context->get_context();
+  return ConvertDataTypeAndShape(dtype, context->output(idx),
+                                 context->output_handle_shapes_and_types(idx),
+                                 context, builder);
+}
+
+StatusOr<mlir::TensorType> ImporterBase::ConvertDataTypeAndShape(
+    DataType dtype, const shape_inference::ShapeHandle& handle,
+    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
+    shape_inference::InferenceContext* context, mlir::Builder builder) {
+  TF_ASSIGN_OR_RETURN(auto subtypes,
+                      ConvertSubtypes(handle_subtypes, context, builder));
+
+  // TODO(hinsu): Store subtypes information for DT_RESOURCE element type as
+  // well.
+  mlir::Type element_type;
+  if (dtype == DT_VARIANT) {
+    element_type = mlir::TF::VariantType::get(subtypes, context_);
+  } else {
+    TF_RETURN_IF_ERROR(
+        ::tensorflow::ConvertDataType(dtype, builder, &element_type));
+  }
+  return ConvertElementTypeAndShape(element_type, handle, context, builder);
+}
+
+StatusOr<mlir::TensorType> ImporterBase::ConvertElementTypeAndShape(
+    mlir::Type element_type, const shape_inference::ShapeHandle& handle,
+    shape_inference::InferenceContext* context, mlir::Builder builder) {
+  if (!context->RankKnown(handle)) {
+    return builder.getTensorType(element_type);
+  }
+
+  // Sentinel for an unknown dimension size. getTensorType interprets any
+  // negative value as an unknown dimension.
+  // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
+  const int64_t kUnknownDim = -1;
+
+  absl::InlinedVector<int64_t, 4> dimensions;
+  int32 rank = context->Rank(handle);
+  dimensions.reserve(rank);
+  for (int i = 0; i < rank; ++i) {
+    auto dim_handle = context->Dim(handle, i);
+    if (!context->ValueKnown(dim_handle))
+      dimensions.push_back(kUnknownDim);
+    else
+      dimensions.push_back(context->Value(dim_handle));
+  }
+
+  return builder.getTensorType(
+      llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
+}
+
+StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
+    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
+    shape_inference::InferenceContext* context, mlir::Builder builder) {
+  ElementSubtypes subtypes;
+  if (!handle_subtypes) return subtypes;
+
+  subtypes.reserve(handle_subtypes->size());
+  for (const auto& subtype : *handle_subtypes) {
+    mlir::Type element_type;
+    TF_RETURN_IF_ERROR(
+        ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
+    TF_ASSIGN_OR_RETURN(mlir::TensorType type,
+                        ConvertElementTypeAndShape(element_type, subtype.shape,
+                                                   context, builder));
+    subtypes.push_back(type);
+  }
+  return subtypes;
+}
+
+Status ImporterBase::ConvertFunctionCallAttribute(
+    const std::string& base_name, const AttrValue& value,
+    llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) {
+  TF_ASSIGN_OR_RETURN(auto func_attr,
+                      ConvertFunctionCallName(value.func().name()));
+  attributes->push_back(builder_->getNamedAttr(base_name, func_attr));
+
+  for (const auto& it : value.func().attr()) {
+    auto name = absl::StrCat(base_name, ".", it.first);
+    TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
+    attributes->push_back(builder_->getNamedAttr(name, value));
+  }
+  return Status::OK();
+}
+
+StatusOr<mlir::SymbolRefAttr> ImporterBase::ConvertFunctionCallName(
+    const std::string& func_name) {
+  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
+  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
+  auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
+  return builder_->getSymbolRefAttr(func);
+}
+
+StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
+    const AttrValue& value) {
+  switch (value.value_case()) {
+    case AttrValue::kI:
+      return builder_->getI64IntegerAttr(value.i());
+    case AttrValue::kS:
+      return builder_->getStringAttr(value.s());
+    case AttrValue::kF:
+      return builder_->getFloatAttr(builder_->getF32Type(), value.f());
+    case AttrValue::kB:
+      return builder_->getBoolAttr(value.b());
+    case AttrValue::kType:
+      return builder_->getStringAttr(
+          mangling_util::MangleDataType(value.type()));
+    case AttrValue::kShape:
+      return builder_->getStringAttr(mangling_util::MangleShape(value.shape()));
+    case AttrValue::kTensor:
+      return ConvertTensorProto(value.tensor());
+    case AttrValue::kList: {
+      absl::InlinedVector<mlir::Attribute, 8> attrs;
+      for (const auto& item : value.list().i())
+        attrs.push_back(builder_->getI64IntegerAttr(item));
+      for (const auto& item : value.list().s())
+        attrs.push_back(builder_->getStringAttr(item));
+      for (const auto& item : value.list().f())
+        attrs.push_back(builder_->getFloatAttr(builder_->getF32Type(), item));
+      for (const auto& item : value.list().b())
+        attrs.push_back(builder_->getBoolAttr(item));
+      for (const auto& item : value.list().type()) {
+        attrs.push_back(builder_->getStringAttr(
+            mangling_util::MangleDataType(static_cast<DataType>(item))));
+      }
+      for (const auto& item : value.list().shape()) {
+        attrs.push_back(
+            builder_->getStringAttr(mangling_util::MangleShape(item)));
+      }
+      for (const auto& item : value.list().tensor()) {
+        TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item));
+        attrs.push_back(attr);
+      }
+      for (const auto& item : value.list().func()) {
+        TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
+        if (item.attr_size() != 0)
+          return errors::Unimplemented(
+              "func attributes with non-zero attr.size()");
+        attrs.push_back(attr);
+      }
+      return builder_->getArrayAttr(
+          llvm::makeArrayRef(attrs.begin(), attrs.end()));
+    }
+    case AttrValue::kFunc:
+      return errors::Unknown("kFunc type should be handled separately!");
+    case AttrValue::VALUE_NOT_SET:
+      return builder_->getUnitAttr();
+    // kPlaceholder is not implemented.
+    default:
+      return errors::Unimplemented(
+          absl::StrCat("Attribute ", value.DebugString()));
+  }
+}
+
+Status ImporterBase::ConvertLibFunction(const std::string& func_name) {
+  // If the library function has been converted already, nothing needs to be
+  // done.
+  if (tf_name_to_mlir_name_->find(func_name) != tf_name_to_mlir_name_->end())
+    return Status::OK();
+
+  std::string mlir_func_name = graph_flib_.UniqueFunctionName(func_name);
+  (*tf_name_to_mlir_name_)[func_name] = mlir_func_name;
+
+  const auto& func_lib = graph_flib_;
+  const auto* func_def = func_lib.Find(func_name);
+  if (func_def == nullptr) {
+    return errors::FailedPrecondition(
+        absl::StrCat("Failed to find function '", func_name,
+                     "'. The imported TensorFlow GraphDef is ill-formed."));
+  }
+
+  // Converts the function definition to a graph.
+  std::unique_ptr<FunctionBody> fbody;
+  TF_RETURN_IF_ERROR(
+      FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody));
+
+  // Converts the argument and return types to mlir types.
+  absl::InlinedVector<mlir::NamedAttribute, 8> attributes;
+  attributes.reserve(func_def->attr_size());
+  for (const auto& name_and_value : func_def->attr()) {
+    // This is a function definition attribute, so it shouldn't contain
+    // kFunc attribute and it is treated as normal one.
+    TF_ASSIGN_OR_RETURN(auto attr,
+                        ConvertAttributeValue(name_and_value.second));
+    std::string attr_name =
+        mangling_util::MangleAttributeName(name_and_value.first);
+    attributes.push_back(builder_->getNamedAttr(attr_name, attr));
+  }
+
+  // Checks opdef stateful attribute and import that as Function Attribute
+  if (func_def->signature().is_stateful()) {
+    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
+    attributes.push_back(
+        builder_->getNamedAttr(stateful_str, builder_->getUnitAttr()));
+  }
+
+  // Checks for an associated custom gradient function. Adds it to the attribute
+  // list of this function.
+  auto grad_func_name = func_lib.FindGradient(func_name);
+  if (!grad_func_name.empty()) {
+    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
+    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
+    auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
+    auto gradient_attr = builder_->getSymbolRefAttr(grad_func);
+    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
+    attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr));
+  }
+
+  // Converts the graph to a MLIR function and adds it to the module. Uses the
+  // default node spec without any inputs or outputs as the function graph has
+  // special '_Arg' and '_Retval' ops for argument and return values.
+  NodeSpecs specs;
+  ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
+                              tf_name_to_mlir_name_);
+  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
+
+  TF_ASSIGN_OR_RETURN(auto func_type,
+                      child_importer.InferLibFunctionType(*fbody));
+
+  absl::InlinedVector<OutputTensor, 4> arg_nodes;
+  arg_nodes.reserve(fbody->arg_nodes.size());
+  absl::InlinedVector<OutputTensor, 4> ret_nodes;
+  ret_nodes.reserve(fbody->ret_nodes.size());
+  for (auto arg : fbody->arg_nodes) {
+    arg_nodes.emplace_back(arg, 0);
+  }
+  for (auto ret : fbody->ret_nodes) {
+    ret_nodes.emplace_back(ret, 0);
+  }
+
+  TF_RETURN_IF_ERROR(child_importer.Convert(
+      mlir_func_name, func_type, arg_nodes, ret_nodes,
+      llvm::makeArrayRef(attributes.begin(), attributes.end())));
+  return Status::OK();
+}
+
+Status ImporterBase::PrepareConvert(const Graph& graph) {
+  TF_RETURN_IF_ERROR(RemoveBackedges(graph));
+  TF_RETURN_IF_ERROR(AddNodesToShapeRefiner());
+  return Status::OK();
+}
+
+Status ImporterBase::Convert(
+    llvm::StringRef func_name, mlir::FunctionType func_type,
+    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
+    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
+    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // TODO(b/122040776): Uses debug info for FunctionDef.
+  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
+                                       func_name, func_type, attrs);
+
+  module_.push_back(function);
+  // Seeds the builder with an initial block.
+  function.addEntryBlock();
+  builder_ = absl::make_unique<mlir::OpBuilder>(function.getBody());
+  auto* bb = &function.front();
+
+  // Create the graph operation in which we will convert the individual nodes.
+  auto graph = builder_->create<mlir::tf_executor::GraphOp>(
+      function.getLoc(), func_type.getResults());
+  builder_->createBlock(&graph.body());
+
+  for (const Node* node : ordered_nodes_) {
+    TF_RETURN_IF_ERROR(ConvertNode(*node));
+  }
+
+  // Adds the backedges back to the function by creating the source and sink
+  // pairs.
+  TF_RETURN_IF_ERROR(AddBackedges());
+
+  return ConvertFunctionArgAndRets(bb, graph, func_type.getInputs(), arg_nodes,
+                                   ret_nodes);
+}
+
+Status ImporterBase::ConvertFunctionArgAndRets(
+    mlir::Block* bb, mlir::tf_executor::GraphOp graph_op,
+    llvm::ArrayRef<mlir::Type> arg_types,
+    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
+    const absl::InlinedVector<OutputTensor, 4>& ret_nodes) {
+  for (int i = 0, e = arg_types.size(); i < e; ++i) {
+    // The lookup can't fail here: otherwise some nodes in the function haven't
+    // be converted to mlir operations and don't have a mapping.
+    mlir::Operation* island =
+        node_values_.find(arg_nodes[i].node->id())->second;
+    // We are looking for the instruction inside the island
+    mlir::Block& body = island->getRegion(0).front();
+    mlir::Operation* inst = &body.front();
+
+    auto* bb_arg = bb->getArgument(i);
+    mlir::Value* arg_def = bb_arg;
+
+    // If this is an arg node, just forward the entry block argument
+    if (arg_nodes[i].node->IsArg()) {
+      island->getResult(0)->replaceAllUsesWith(arg_def);
+      island->dropAllReferences();
+      island->erase();
+      continue;
+    }
+
+    // This is an input node, we'll create a new input operation by suffixing
+    // the existing one with .input.
+    auto inst_name = inst->getName().getStringRef();
+    mlir::OperationState state(inst->getLoc(),
+                               inst_name.str().append(".input"));
+    state.attributes.append(inst->getAttrs().begin(), inst->getAttrs().end());
+
+    // If there are quantization specifications, add them as the attributes
+    auto name = inst->getAttrOfType<mlir::StringAttr>("name").getValue();
+    auto input_spec_it = specs_.inputs.find(name.str());
+    if (input_spec_it != specs_.inputs.end()) {
+      auto input_spec = input_spec_it->second;
+      if (IsQuantizationType(input_spec.final_dtype)) {
+        // Uses the MLIR built-in type so it can be handled easily later.
+        auto final_type = mlir::IntegerType::get(
+            GetQuantizationTypeWidth(input_spec.final_dtype), context_);
+        state.attributes.push_back(builder_->getNamedAttr(
+            "min", builder_->getF32FloatAttr(input_spec.min_value)));
+        state.attributes.push_back(builder_->getNamedAttr(
+            "max", builder_->getF32FloatAttr(input_spec.max_value)));
+        state.attributes.push_back(
+            builder_->getNamedAttr("type", builder_->getTypeAttr(final_type)));
+        inst->getParentOfType<mlir::FuncOp>().setAttr("tf.quantize",
+                                                      builder_->getUnitAttr());
+      }
+    }
+
+    for (auto* r : inst->getResults()) state.types.push_back(r->getType());
+
+    state.operands.append(inst->getOperands().begin(),
+                          inst->getOperands().end());
+    state.operands.push_back(bb_arg);
+    builder_->setInsertionPoint(inst);
+    auto* input = builder_->createOperation(state);
+    arg_def = input->getResult(arg_nodes[i].index);
+
+    for (auto index = 0; index < inst->getNumResults(); index++) {
+      inst->getResult(index)->replaceAllUsesWith(arg_def);
+    }
+    inst->dropAllReferences();
+    inst->erase();
+  }
+
+  llvm::SmallVector<mlir::Value*, 8> inst_to_returned;
+  for (const auto& ret : ret_nodes) {
+    auto* inst = node_values_[ret.node->id()];
+    auto op = absl::string_view(ret.node->type_string());
+    if (op == FunctionLibraryDefinition::kRetOp ||
+        op == FunctionLibraryDefinition::kDeviceRetOp) {
+      // Lookup the instruction inside the island
+      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
+      mlir::Operation* inner_op = &island_op.GetBody().front();
+      // Remove kRetOp or kDeviceRetOp operation and return its operand.
+      // kRetOp and kDeviceRetOp should have just one operand unless they have
+      // control dependencies.
+      if (inner_op->getNumOperands() != 1)
+        return errors::Unimplemented("Return node with multiple inputs.");
+      inst_to_returned.push_back(inner_op->getOperand(0));
+      inst->dropAllReferences();
+      inst->erase();
+    } else {
+      inst_to_returned.push_back(inst->getResult(ret.index));
+    }
+  }
+
+  // Terminate the function by adding a Fetch operation to terminate the graph
+  // and a return operation to return the Graph results.
+  builder_->setInsertionPointToEnd(&graph_op.body().front());
+  builder_->create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
+                                               inst_to_returned);
+  inst_to_returned.assign(graph_op.getResults().begin(),
+                          graph_op.getResults().end());
+  builder_->setInsertionPointToEnd(bb);
+  builder_->create<mlir::ReturnOp>(
+      mlir::UnknownLoc::get(context_),
+      llvm::makeArrayRef(inst_to_returned.begin(), inst_to_returned.end()));
+  return Status::OK();
+}
+
+mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
+  const auto& debug_info = debug_info_.traces();
+
+  // Get the CallSiteLoc for a node name.
+  // - If the debug info of the node couldn't be found, the caller of the
+  //   returned CallSiteLoc is set to an UnknownLoc;
+  // - If the debug info of the node is found, the caller of the returned
+  //   CallSiteLoc is set to a call stack which is formed by the debug info.
+  auto node_name_to_call_site = [&](const std::string& name) -> mlir::Location {
+    auto name_id = mlir::Identifier::get(name, context_);
+    const auto& location_it = debug_info.find(name);
+    if (location_it == debug_info.end()) {
+      // Only the node name is stored if the location is unknown.
+      return mlir::NameLoc::get(name_id, context_);
+    }
+
+    // Convert the stack trace to a chain of mlir::CallSiteLocs.
+    const auto& trace = location_it->second;
+    llvm::SmallVector<mlir::Location, 4> locations;
+    locations.reserve(trace.file_line_cols_size());
+    for (const auto& location : trace.file_line_cols()) {
+      const auto& file = debug_info_.files(location.file_index());
+      auto file_name = mlir::Identifier::get(file, context_);
+      auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
+                                                     location.col(), context_);
+      locations.push_back(file_line_loc);
+    }
+    // Handle empty location vector.
+    if (locations.empty()) return mlir::NameLoc::get(name_id, context_);
+
+    // Use the front FileLineColLoc to generate a NameLoc.
+    mlir::Location node_name_loc =
+        mlir::NameLoc::get(name_id, locations.front(), context_);
+
+    // If there are more locations then generate a stack trace, otherwise just
+    // return the name loc.
+    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
+    return callsite_locs.empty()
+               ? node_name_loc
+               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs, context_);
+  };
+
+  // For NextIteration nodes, location is used to pair source and sink nodes.
+  // Hence, we use node name as location to keep it unique.
+  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
+  // nodes. Then NextIteration nodes would not need to be handled seprately.
+  if (node_def.op() == "NextIteration")
+    return node_name_to_call_site(node_def.name());
+
+  auto original_nodes =
+      node_def.experimental_debug_info().original_node_names();
+  auto original_funcs =
+      node_def.experimental_debug_info().original_func_names();
+
+  if (original_nodes.empty()) {
+    // If the original nodes are not defined in the node def, but the current
+    // node name is contained in the debug info file, then we fall back to use
+    // the current node name to get the location info. Otherwise, use a
+    // NameLoc with node name as in a TensorFlow graph the node name is unique.
+    auto& curr_node_name = node_def.name();
+    if (debug_info.find(curr_node_name) == debug_info.end()) {
+      return mlir::NameLoc::get(mlir::Identifier::get(curr_node_name, context_),
+                                context_);
+    } else {
+      return node_name_to_call_site(curr_node_name);
+    }
+  } else {
+    // If the original nodes are defined, then we use them to get a list of
+    // call sites, and then fuse them to a single fused location.
+    llvm::SmallVector<mlir::Location, 4> node_call_sites;
+    node_call_sites.reserve(original_nodes.size());
+    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
+      auto node_name = original_nodes[i];
+      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
+      // Use the catenation of function and node names as the lookup key. This
+      // is to match the utility of generating the GraphDebugInfo.
+      node_call_sites.push_back(node_name_to_call_site(func_name + node_name));
+    }
+    return mlir::FusedLoc::get(node_call_sites, context_);
+  }
+}
+
+std::string ImporterBase::GetLocationStr(const Node& node,
+                                         bool includeNodeName) {
+  const auto location = GetLocation(node.def());
+  std::string s;
+  llvm::raw_string_ostream ss(s);
+  location.print(ss);
+  ss.flush();
+  // Removes the node name prefix if it exists.
+  if (!s.empty() && s[0] == '\"' && s.find_first_of(node.name()) == 1) {
+    return s.replace(0, node.name().size() + 3, "");
+  }
+  return s;
+}
+
+mlir::Operation* ImporterBase::createOperation(
+    const Node& node, llvm::StringRef op_name,
+    const mlir::OperationState& result,
+    const llvm::SmallVectorImpl<mlir::Value*>& control_operands) {
+  // For the tf.executor specific operations (not wrapped in an island), we
+  // have an extra returned value for the control result, and we concatenate
+  // control and non-control operands.
+  mlir::SmallVector<mlir::Type, 4> types(result.types);
+  types.push_back(mlir::tf_executor::ControlType::get(builder_->getContext()));
+  mlir::SmallVector<mlir::Value*, 4> operands(result.operands);
+  operands.append(control_operands.begin(), control_operands.end());
+
+  auto loc = result.location;
+  // Dispatch based on the name and create the appropriate operation.
+  if (node.IsSwitch()) {
+    return builder_->create<mlir::tf_executor::SwitchOp>(loc, types, operands,
+                                                         result.attributes);
+  }
+  if (op_name == "tf.SwitchN") {
+    return builder_->create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
+                                                          result.attributes);
+  }
+  if (node.IsMerge()) {
+    return builder_->create<mlir::tf_executor::MergeOp>(loc, types, operands,
+                                                        result.attributes);
+  }
+  if (node.IsNextIteration()) {
+    // NextIteration is a bit special, we create a pair of operations that are
+    // linked together through a token returned by the source.
+    // We make use of a separate builder to insert the source at the top of
+    // the block.
+    mlir::OpBuilder builder_at_begin(builder_->getBlock(),
+                                     builder_->getBlock()->begin());
+    auto source_op =
+        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
+            loc, operands[0]->getType(), result.attributes);
+    return builder_->create<mlir::tf_executor::NextIterationSinkOp>(
+        loc, source_op.token(), operands, result.attributes);
+  }
+  if (node.IsLoopCond()) {
+    return builder_->create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
+                                                           result.attributes);
+  }
+  if (node.IsEnter()) {
+    return builder_->create<mlir::tf_executor::EnterOp>(loc, types, operands,
+                                                        result.attributes);
+  }
+  if (node.IsExit()) {
+    return builder_->create<mlir::tf_executor::ExitOp>(loc, types, operands,
+                                                       result.attributes);
+  }
+  if (node.IsControlTrigger()) {
+    return builder_->create<mlir::tf_executor::ControlTriggerOp>(
+        loc, operands, result.attributes);
+  }
+  // Regular TensorFlow operation are wrapped in a tf_executor.island.
+  auto island = builder_->create<mlir::tf_executor::IslandOp>(
+      result.location, types, control_operands,
+      mlir::ArrayRef<mlir::NamedAttribute>{});
+  island.body().push_back(new mlir::Block);
+  mlir::OpBuilder island_builder(&island.GetBody());
+
+  // Create the operation inside the island now.
+  mlir::Operation* inner_op = island_builder.createOperation(result);
+
+  // Add the terminator for the island
+  mlir::SmallVector<mlir::Value*, 8> ret_vals(inner_op->getResults());
+  island_builder.create<mlir::tf_executor::YieldOp>(result.location, ret_vals);
+  return island.getOperation();
+}
+
+Status ImporterBase::ConvertNode(const Node& node) {
+  if (!node.IsOp()) {
+    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
+    // Graph and don't exist in GraphDef.
+    return Status::OK();
+  }
+
+  // If it is a custom OP, its definition should be found in the library. We
+  // create the MLIR function and insert it to the module if it doesn't exist.
+  std::string node_type_name = node.type_string();
+  const auto* func_def = graph_flib_.Find(node_type_name);
+  if (func_def) {
+    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
+    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
+  }
+
+  auto get_full_op_name = [&](const std::string& op_name) {
+    const char* kTfPrefix = "tf.";
+    return kTfPrefix + op_name;
+  };
+
+  std::string op_name = get_full_op_name(node_type_name);
+  if (back_edge_node_output_.contains(&node)) {
+    op_name = op_name + ".sink";
+  }
+
+  const auto& node_def = node.def();
+  mlir::OperationState result(GetLocation(node_def), op_name);
+
+  for (int i = 0; i < node.num_outputs(); ++i) {
+    // The backedge has been removed, so we shouldn't count the corresponding
+    // output from the src node when converting to an operation.
+    if (back_edge_node_output_.contains(&node) &&
+        back_edge_node_output_[&node] == i) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, *builder_));
+    result.types.push_back(type);
+  }
+
+  // Surprisingly input edges can be nondeterministically ordered. This
+  // particularly seems to be the case for the control edges between _SOURCE
+  // and _SINK that the Graph constructor inserts. Copy the input edges and
+  // sort the edges, but only the control edges, not data edges!
+  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
+  // They'll break roundtripping anyway unless we strip them when converting
+  // back to graphdef.
+  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
+  absl::c_copy(node.in_edges(), in_edges.begin());
+  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
+    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
+    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
+    return e1->dst_input() < e2->dst_input();
+  });
+
+  result.operands.reserve(in_edges.size());
+
+  // Collect the control operands separately, they will be held by the island.
+  mlir::SmallVector<mlir::Value*, 8> control_operands;
+
+  for (const auto* input_edge : in_edges) {
+    const Node& input_node = *input_edge->src();
+    if (input_node.IsSource()) {
+      if (in_edges.size() != 1) {
+        return errors::FailedPrecondition(
+            "The node has other inputs besides the _Source node");
+      }
+      // We don't import the _SOURCE node.
+      continue;
+    }
+    if (input_node.IsArg() && input_edge->IsControlEdge()) {
+      // Currently we have not reached consensus as to what TF function
+      // semantics are (b/133509504). Here we assume that all arguments to a
+      // function should be available before we start execution of any internal
+      // node. This makes the control dependencies between function arguments
+      // and internal nodes redundant, and so we do not import them. The TF
+      // inliner however assumes no such dependency between function args and
+      // internal nodes exists, unless explicitly stated. Since we drop control
+      // dependencies here, it leads to loss of information. If the function is
+      // inlined later, the inliner would not know of these explicit control
+      // dependencies present in the original graph.
+      continue;
+    }
+    if (node_values_.find(input_node.id()) == node_values_.end())
+      return errors::FailedPrecondition(
+          "Graph not traversed in reverse post order; use seen before def!");
+    mlir::Operation* inst = node_values_[input_node.id()];
+    if (input_edge->IsControlEdge())
+      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
+    else
+      result.operands.push_back(inst->getResult(input_edge->src_output()));
+  }
+
+  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
+  std::vector<FuncPairType> funcs;
+  result.attributes.reserve(node.attrs().size() + 2);
+  for (const auto& name_and_value : node.attrs()) {
+    const auto& attr_name = name_and_value.first;
+    const AttrValue& attr_value = name_and_value.second;
+    if (attr_value.value_case() == AttrValue::kFunc) {
+      // Attribute iteration order is not defined for protocol buffer Map.
+      // Process function attributes separately in the lexicographical order to
+      // have deterministic order of functions in the constructed IR.
+      funcs.emplace_back(&attr_name, &attr_value);
+    } else {
+      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
+      result.attributes.push_back(builder_->getNamedAttr(attr_name, attr));
+    }
+  }
+
+  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
+    return *a.first < *b.first;
+  };
+  std::sort(funcs.begin(), funcs.end(), comparator);
+  for (const auto& func : funcs) {
+    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
+                                                    &result.attributes));
+  }
+
+  result.attributes.push_back(builder_->getNamedAttr(
+      "name", builder_->getStringAttr(std::string(node.name()))));
+  result.attributes.push_back(builder_->getNamedAttr(
+      "device", builder_->getStringAttr(std::string(node_def.device()))));
+
+  // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
+  // the differentiating attribute.
+  if (node.IsIfNode()) {
+    result.name = mlir::OperationName(get_full_op_name("If"), context_);
+    mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
+    result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
+  }
+
+  // Map While and StatelessWhile op in TensorFlow to the common While op in
+  // MLIR and add the differentiating attribute.
+  if (node.IsWhileNode()) {
+    result.name = mlir::OperationName(get_full_op_name("While"), context_);
+    mlir::BoolAttr val =
+        builder_->getBoolAttr(node_type_name == "StatelessWhile");
+    result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
+  }
+
+  // Register the mapping between the TF node and the newly created operation.
+  node_values_[node.id()] =
+      createOperation(node, op_name, result, control_operands);
+
+  return Status::OK();
+}
+
+// Add the backedges to the CFG. Given a backedge, we replace the original
+// source and destination operations by two new operations. Most of the
+// fields of the replacements are copied from the original operations.
+// However,
+// - for the src operation, one output is inserted to the front of the output
+//   list. The type of the output is set to the type of the non-control result
+//   of the dst operation, and
+// - for the dst operation, one operand is inserted to the front of the
+//   operand list. This operand is using the first result of the src
+//   operation.
+// TODO(fengliuai): Preserve the order of the results and operands if
+// necessary.
+Status ImporterBase::AddBackedges() {
+  for (auto it : back_edge_dst_inputs_) {
+    BackEdge& edge = it.second;
+    if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
+      return errors::FailedPrecondition(
+          "Invalid backedge; should be from NextIteration to Merge!");
+    }
+    auto* sink = node_values_[edge.src->id()];
+    auto* dst = node_values_[edge.dst->id()];
+    TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
+  }
+  return Status::OK();
+}
+
+Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
+                                 int dst_input) {
+  // Get the NextIteration.Source operation from the token operand of the sink.
+  mlir::Operation* source = sink->getOperand(0)->getDefiningOp();
+
+  // Adds the "source" to the operands of the dst by creating a new dst
+  // operation.
+  mlir::OperationState state(dst->getLoc(), dst->getName());
+  auto num_operands = dst->getNumOperands();
+  state.operands.reserve(num_operands + 1);
+  for (int input = 0, e = num_operands + 1; input != e; ++input) {
+    if (input < dst_input) {
+      state.operands.push_back(dst->getOperand(input));
+    } else if (input == dst_input) {
+      state.operands.push_back(source->getResult(0));
+    } else {
+      state.operands.push_back(dst->getOperand(input - 1));
+    }
+  }
+  state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
+  state.types.assign(dst->getResultTypes().begin(),
+                     dst->getResultTypes().end());
+  builder_->setInsertionPoint(dst);
+  auto* new_dst = builder_->createOperation(state);
+
+  // Replaces the output uses of the old operation by the corresponding
+  // result of the new operation, and deletes the old operation.
+  for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
+    auto* new_output = new_dst->getResult(i);
+    dst->getResult(i)->replaceAllUsesWith(new_output);
+  }
+  dst->dropAllReferences();
+  dst->erase();
+  return Status::OK();
+}
+
+StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
+    const FunctionBody& fbody) {
+  mlir::Builder builder(context_);
+
+  llvm::SmallVector<mlir::Type, 4> arg_types;
+  arg_types.reserve(fbody.arg_types.size());
+  for (auto dataType : fbody.arg_types) {
+    mlir::Type element_type;
+    TF_RETURN_IF_ERROR(
+        ::tensorflow::ConvertDataType(dataType, builder, &element_type));
+    // TODO(hinsu): Derive shape of function arguments based on shapes available
+    // at call sites of this function. That way it is possible to have a
+    // partially known shape in some cases instead of unranked tensor types.
+    arg_types.push_back(builder.getTensorType(element_type));
+  }
+
+  llvm::SmallVector<mlir::Type, 4> ret_types;
+  ret_types.reserve(fbody.ret_types.size());
+  for (auto ret : fbody.ret_nodes) {
+    // Find node in the graph using the node id instead of using `ret` directly
+    // because the graph has been cloned.
+    auto* node = graph_->FindNodeId(ret->id());
+
+    // Return type of the function is type of the only input of the respective
+    // return node in the function.
+    TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
+    ret_types.push_back(type);
+  }
+
+  return builder.getFunctionType(arg_types, ret_types);
+}
+
+// Stateful helper class to import a TensorFlow model expressed in GraphDef into
+// an MLIR Module.
+//
+// The nodes defined in the graph is converted to a function called "main". All
+// the library function definitions are converted to MLIR functions in the
+// module.
+class GraphDefImporter : public ImporterBase {
+ public:
+  // Main entry point: converts the given graph to an MLIR Module.
+  static StatusOr<mlir::OwningModuleRef> Convert(
+      mlir::MLIRContext* context, const Graph& graph,
+      const GraphDebugInfo& debug_info,
+      const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs);
+
+ private:
+  explicit GraphDefImporter(
+      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
+      const NodeSpecs& specs, mlir::ModuleOp module,
+      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
+      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
+
+  // Returns the function signature of the main function of converted MLIR
+  // module, the input nodes and output nodes. The type and shape information
+  // for the function arguments are read from `specs`, but the type and shape
+  // information for the function returns are inferred by the shape refiner in
+  // ImporterBase.
+  StatusOr<mlir::FunctionType> InferMainFunctionType(
+      const NodeSpecs& specs, mlir::MLIRContext* context,
+      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
+      absl::InlinedVector<OutputTensor, 4>* ret_nodes);
+};
+
+StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
+    mlir::MLIRContext* context, const Graph& graph,
+    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
+    const NodeSpecs& specs) {
+  mlir::OwningModuleRef module =
+      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
+  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
+
+  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
+                            &tf_name_to_mlir_name);
+  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
+
+  // Collects the argument and return nodes by looking up the node names
+  // specified by the user.
+  absl::InlinedVector<OutputTensor, 4> arg_nodes;
+  absl::InlinedVector<OutputTensor, 4> ret_nodes;
+  TF_ASSIGN_OR_RETURN(
+      auto func_type,
+      importer.InferMainFunctionType(specs, context, &arg_nodes, &ret_nodes));
+
+  // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
+  // tf.versions) shared by importer and exporter in a centralized place.
+  // Record the input and output mapping.
+  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
+  if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
+    mlir::Builder b(context);
+    std::string s;
+    llvm::raw_string_ostream ss(s);
+    mlir::interleaveComma(
+        specs.inputs, ss,
+        [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; });
+    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
+    s.clear();
+    mlir::interleaveComma(specs.output_arrays, ss);
+    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
+
+    attrs.push_back(b.getNamedAttr("tf.entry_function",
+                                   b.getDictionaryAttr({inputs, outputs})));
+  }
+
+  // Record version info.
+  const auto& graph_versions = graph.versions();
+  mlir::Builder b(context);
+  auto producer = b.getNamedAttr(
+      "producer", b.getI32IntegerAttr(graph_versions.producer()));
+  auto min_consumer = b.getNamedAttr(
+      "min_consumer", b.getI32IntegerAttr(graph_versions.min_consumer()));
+  auto bad_consumers = b.getNamedAttr(
+      "bad_consumers", b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
+                           graph_versions.bad_consumers().begin(),
+                           graph_versions.bad_consumers().end())));
+  module->setAttr("tf.versions",
+                  b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
+                      {producer, min_consumer, bad_consumers})));
+
+  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
+      "main", func_type, arg_nodes, ret_nodes, attrs));
+  return module;
+}
+
+StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
+    const NodeSpecs& specs, mlir::MLIRContext* context,
+    absl::InlinedVector<OutputTensor, 4>* arg_nodes,
+    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
+  // Finds out all the input nodes and output nodes.
+  if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
+    arg_nodes->resize(specs.inputs.size());
+    ret_nodes->resize(specs.output_arrays_order.size());
+
+    for (Node* n : GetOrderedNodes()) {
+      // Handle inputs/arguments.
+      auto input_it = specs.inputs.find(n->name());
+      if (input_it != specs.inputs.end()) {
+        (*arg_nodes)[std::distance(specs.inputs.begin(), input_it)] = {n, 0};
+      }
+
+      // Handle outputs/returns.
+      if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) {
+        for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
+          std::pair<std::string, std::string> name_and_port =
+              absl::StrSplit(specs.output_arrays_order[i], ':');
+          auto name = name_and_port.first;
+          if (name != n->name()) continue;
+          int port = 0;
+          if (!name_and_port.second.empty() &&
+              !absl::SimpleAtoi(name_and_port.second, &port)) {
+            return errors::InvalidArgument("Invalid port specification: ",
+                                           specs.output_arrays_order[i]);
+          }
+          (*ret_nodes)[i] = {n, port};
+        }
+      }
+    }
+  }
+
+  int i = 0;
+  for (auto it : specs.inputs) {
+    if (arg_nodes->at(i++).node == nullptr) {
+      return errors::InvalidArgument("Input ", it.first,
+                                     " was not found in graph");
+    }
+  }
+  for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
+    if (ret_nodes->at(i).node == nullptr) {
+      return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
+                                     " was not found in graph");
+    }
+  }
+
+  // Starts to construct the function type.
+  llvm::SmallVector<mlir::Type, 4> arg_types;
+  llvm::SmallVector<mlir::Type, 4> ret_types;
+  arg_types.reserve(specs.inputs.size());
+  ret_types.reserve(specs.output_arrays.size());
+  mlir::Builder builder(context);
+
+  // Input nodes as function arguments.
+  for (const auto& input : specs.inputs) {
+    mlir::Type element_type;
+    const auto& node_info = input.second;
+    TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype,
+                                                     builder, &element_type));
+    llvm::SmallVector<int64_t, 4> shape;
+    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
+    arg_types.push_back(builder.getTensorType(shape, element_type));
+  }
+
+  // Output nodes as function returns.
+  for (const auto& ret : *ret_nodes) {
+    if (ret.node->num_outputs() <= ret.index) {
+      return errors::InvalidArgument("Invalid output index ", ret.index,
+                                     " specified for node: ", ret.node->name());
+    }
+    TF_ASSIGN_OR_RETURN(auto type,
+                        InferOutputType(*ret.node, ret.index, builder));
+    ret_types.push_back(type);
+  }
+
+  return builder.getFunctionType(arg_types, ret_types);
+}
+
+}  // namespace
+
+StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
+    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
+    const NodeSpecs& specs, mlir::MLIRContext* context,
+    bool add_default_attributes) {
+  GraphConstructorOptions options;
+  options.allow_internal_ops = true;
+  Graph graph(OpRegistry::Global());
+
+  GraphDef preprocessed_graphdef(graphdef);
+  if (add_default_attributes) {
+    TF_RETURN_IF_ERROR(PreprocessGraphDef(specs, &preprocessed_graphdef));
+  }
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      options, std::move(preprocessed_graphdef), &graph));
+
+  return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
+                            context);
+}
+
+StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
+    const Graph& graph, const GraphDebugInfo& debug_info,
+    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
+    mlir::MLIRContext* context) {
+  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
new file mode 100644
index 0000000..a996ca6
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
@@ -0,0 +1,46 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
+
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+// Given a GraphDef, returns a MLIR module containing the graph in control-flow
+// form.
+stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
+    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
+    const NodeSpecs& specs, mlir::MLIRContext* context,
+    bool add_default_attributes = true);
+
+// Given a Graph, returns a MLIR module containing the graph in control-flow
+// form.
+stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
+    const Graph& graph, const GraphDebugInfo& debug_info,
+    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
+    mlir::MLIRContext* context);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
index 3fc7ee5..dcd8008 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
@@ -56,6 +56,11 @@
   // setting prune_unused_nodes to true, would prune unreachable nodes if
   // output_arrays is specified.
   bool prune_unused_nodes = false;
+  // If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
+  // LegacyFedInput ops have two outputs unlike Placeholder which has only one
+  // output, so if both outputs of the LegacyFedInput ops are used then returns
+  // an error.
+  bool convert_legacy_fed_inputs = false;
 };
 
 struct ExporterConfigs {
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
index 231a734..3ebd722 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
@@ -19,7 +19,7 @@
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/graph/graph_constructor.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index cd487811..f9a6e24 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -26,8 +26,7 @@
 #include "mlir/IR/Operation.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Parser.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
@@ -36,14 +35,6 @@
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 
-namespace mlir {
-/// Create a pass to convert from the TF control to the TFExecutor dialect.
-FunctionPassBase* CreateTFControlToExecutorDialectConversion();
-
-/// Create a pass to convert from the TFExecutor to the TF control dialect.
-FunctionPassBase* CreateTFExecutorToControlDialectConversion();
-}  // namespace mlir
-
 namespace tensorflow {
 
 using stream_executor::port::Status;
@@ -55,7 +46,7 @@
     absl::string_view input_shapes, absl::string_view output_arrays,
     absl::string_view inference_type, absl::string_view min_values,
     absl::string_view max_values, bool prune_unused_nodes,
-    mlir::MLIRContext* context) {
+    bool convert_legacy_fed_inputs, mlir::MLIRContext* context) {
   GraphDef graphdef;
   TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromFile(input_filename, &graphdef));
 
@@ -66,6 +57,7 @@
 
   NodeSpecs specs;
   specs.prune_unused_nodes = prune_unused_nodes;
+  specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
   TF_RETURN_IF_ERROR(ParseInputArrayInfo(
       input_arrays, input_dtypes, input_shapes, inference_type, min_values,
       max_values, &specs.inputs));
@@ -80,29 +72,16 @@
     absl::string_view input_shapes, absl::string_view output_arrays,
     absl::string_view inference_type, absl::string_view min_values,
     absl::string_view max_values, bool prune_unused_nodes,
-    mlir::MLIRContext* context) {
+    bool convert_legacy_fed_inputs, mlir::MLIRContext* context) {
   auto module_or = GraphdefToMlirImport(
       input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes,
       output_arrays, inference_type, min_values, max_values, prune_unused_nodes,
-      context);
+      convert_legacy_fed_inputs, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "Graph import failed: " << module_or.status();
     return nullptr;
   }
 
-  // Round-trip to the tf_executor dialect, this is temporary while bringing up
-  // the new dialect.
-  {
-    mlir::PassManager pm;
-    pm.addPass(mlir::CreateTFControlToExecutorDialectConversion());
-    pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
-    if (failed(pm.run(module_or.ValueOrDie().get()))) {
-      module_or.ValueOrDie()->emitOpError()
-          << "Round-trip to tf_executor dialect failed";
-      return nullptr;
-    }
-  }
-
   return module_or.ConsumeValueOrDie();
 }
 
@@ -112,11 +91,11 @@
     absl::string_view input_shapes, absl::string_view output_arrays,
     absl::string_view inference_type, absl::string_view min_values,
     absl::string_view max_values, bool prune_unused_nodes,
-    mlir::MLIRContext* context) {
+    bool convert_legacy_fed_inputs, mlir::MLIRContext* context) {
   auto module_or = GraphdefToMlirImport(
       input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes,
       output_arrays, inference_type, min_values, max_values, prune_unused_nodes,
-      context);
+      convert_legacy_fed_inputs, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "Graph import failed: " << module_or.status();
     return nullptr;
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
index 794a2ef..7696f5a 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
@@ -33,7 +33,7 @@
     absl::string_view input_shapes, absl::string_view output_arrays,
     absl::string_view inference_type, absl::string_view min_values,
     absl::string_view max_values, bool prune_unused_nodes,
-    mlir::MLIRContext* context);
+    bool convert_legacy_fed_inputs, mlir::MLIRContext* context);
 
 // Similar as the above function, but replaces all constant tensors
 // with randomly generated splat values.
@@ -43,7 +43,7 @@
     absl::string_view input_shapes, absl::string_view output_arrays,
     absl::string_view inference_type, absl::string_view min_values,
     absl::string_view max_values, bool prune_unused_nodes,
-    mlir::MLIRContext* context);
+    bool convert_legacy_fed_inputs, mlir::MLIRContext* context);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc
index 8e74296..65bccb4 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc
@@ -84,3 +84,10 @@
     "tf-prune-unused-nodes",
     llvm::cl::desc("Prune unused nodes in the input graphdef "),
     llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+opt<bool> convert_legacy_fed_inputs(
+    "tf-convert-legacy-fed-inputs",
+    llvm::cl::desc(
+        "Eliminate LegacyFedInput nodes by replacing them with Placeholder "),
+    llvm::cl::init(false));
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h
index 8cf17e3..f5126c4 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h
@@ -35,5 +35,6 @@
 extern llvm::cl::opt<std::string> max_values;
 extern llvm::cl::opt<std::string> debug_info_file;
 extern llvm::cl::opt<bool> prune_unused_nodes;
+extern llvm::cl::opt<bool> convert_legacy_fed_inputs;
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
index 7d7632d..b4413d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
@@ -45,7 +45,7 @@
   return tensorflow::GraphdefToMlirTranslateFunction(
       StringRefToView(input_filename), debug_info_file, input_arrays,
       input_dtypes, input_shapes, output_arrays, inference_type, min_values,
-      max_values, prune_unused_nodes, context);
+      max_values, prune_unused_nodes, convert_legacy_fed_inputs, context);
 }
 
 static TranslateToMLIRRegistration GraphdefToMlirTranslate(
@@ -56,7 +56,7 @@
   return tensorflow::GraphdefToSplattedMlirTranslateFunction(
       StringRefToView(input_filename), debug_info_file, input_arrays,
       input_dtypes, input_shapes, output_arrays, inference_type, min_values,
-      max_values, prune_unused_nodes, context);
+      max_values, prune_unused_nodes, convert_legacy_fed_inputs, context);
 }
 
 static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index 380d125..0e5b46a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -35,7 +35,6 @@
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
@@ -43,20 +42,23 @@
 
 using llvm::ArrayRef;
 using llvm::SmallVector;
-using mlir::Attribute;
-using mlir::BoolAttr;
 using mlir::Builder;
 using mlir::DenseFPElementsAttr;
 using mlir::DenseIntElementsAttr;
 using mlir::ElementsAttr;
-using mlir::FloatAttr;
-using mlir::IntegerAttr;
 using mlir::OpaqueElementsAttr;
 using mlir::ShapedType;
-using mlir::SplatElementsAttr;
 using mlir::Type;
 using tensorflow::errors::InvalidArgument;
 
+void ConvertToMlirShape(const TensorShape& input_shape,
+                        llvm::SmallVectorImpl<int64_t>* shape) {
+  shape->reserve(input_shape.dims());
+  for (const auto& d : input_shape) {
+    shape->push_back(d.size);
+  }
+}
+
 Status ConvertToMlirShape(const TensorShapeProto& input_shape,
                           llvm::SmallVectorImpl<int64_t>* shape) {
   shape->reserve(input_shape.dim_size());
@@ -70,206 +72,72 @@
   return Status::OK();
 }
 
-// Converts an TensorFlow tensor proto to an MLIR opaque elements attribute.
-StatusOr<ElementsAttr> ConvertToOpaqueElementsAttr(
-    const TensorProto& input_tensor, ShapedType type, Builder* builder) {
-  // TODO(shpeisman): restructure code to reuse dialect pointer across calls.
-  auto* dialect = builder->getContext()->getRegisteredDialect("tf");
-  return builder->getOpaqueElementsAttr(
-      dialect, type, mangling_util::MangleTensor(input_tensor));
+static TensorProto ConvertToProto(const Tensor& input_tensor,
+                                  bool use_tensor_content = true) {
+  TensorProto tensor_proto;
+  // Using tensor content (mostly*) reduces serialization overhead during RPC
+  // calls, but is less human reader friendly. People reading protobufs are less
+  // frequent than serialization, so default to using tensor content
+  // representation.
+  // * For scalars and short strings it may be marginally worse and a more
+  //   intelligent decision could be made by caller.
+  if (use_tensor_content)
+    input_tensor.AsProtoTensorContent(&tensor_proto);
+  else
+    input_tensor.AsProtoField(&tensor_proto);
+  return tensor_proto;
 }
 
-// Template predicate that provides a constant member `value` equal to true if
-// a sequence of `From` values can be copied wholesale to locations for `To`
-// values.
-
-// Primary template declaration
-template <typename From, typename To, typename Enable = void>
-struct IsBatchCopyable;
-
-// Partial template specialization: allow wholesale copy for the same type
-template <typename Self>
-struct IsBatchCopyable<Self, Self> : std::true_type {};
-
-// SFINAE: integral types depend on the bitwidth
-template <typename From, typename To>
-struct IsBatchCopyable<
-    From, To,
-    typename std::enable_if<std::is_integral<From>::value &&
-                            std::is_integral<To>::value>::type> {
-  static constexpr bool value =
-      std::numeric_limits<From>::digits == std::numeric_limits<To>::digits;
-};
-
-// Converts an TensorFlow tensor proto to an MLIR dense elements attribute.
-// To save the memory held by the attribute, the value is casted to the
-// specified type.
-template <typename ProtoT, typename MlirT>
-typename std::enable_if<IsBatchCopyable<ProtoT, MlirT>::value,
-                        StatusOr<ElementsAttr>>::type
-ConvertToDenseElementsAttr(
-    const tensorflow::protobuf::RepeatedField<ProtoT>& values, ShapedType type,
-    Builder* builder) {
-  return mlir::DenseElementsAttr::get(
-      type, llvm::makeArrayRef(values.data(), values.size()));
+static std::string MangleTensor(const Tensor& tensor) {
+  return mangling_util::MangleTensor(ConvertToProto(tensor));
 }
 
-template <typename ProtoT, typename MlirT>
-typename std::enable_if<!IsBatchCopyable<ProtoT, MlirT>::value,
-                        StatusOr<ElementsAttr>>::type
-ConvertToDenseElementsAttr(
-    const tensorflow::protobuf::RepeatedField<ProtoT>& values, ShapedType type,
-    Builder* builder) {
-  std::vector<MlirT> buff;
-  buff.reserve(values.size());
-  for (auto value : values) {
-    buff.push_back(value);
-  }
-  return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(buff));
-}
-
-// Convert a TensorFlow tensor from its raw serialization into a
-// DenseElementAttr. This is a wrapper around mlir::DenseElementsAttr that
-// creates a temporary copy of the data for satisfying strict aliasing
-// defensively. TODO(aminim): this extra copy should not be needed,
-// DenseElementAttr will perform a similar copy internally.
-// Template parameter `T` must match the element type of the `type` argument
-// (this is checked in DenseElementsAttr::get()).
+// Converts a TensorFlow tensor into an MLIR elements attribute.
 template <typename T>
-mlir::DenseElementsAttr ConvertToDenseElementsAttr(const absl::Cord& values,
-                                                   ShapedType type,
-                                                   Builder* builder) {
-  DCHECK_EQ((values.size() % sizeof(T)), 0)
-      << "unexpected size vs elt type mismatch";
-  int n_elements = values.size() / sizeof(T);
-  auto data = absl::make_unique<T[]>(n_elements);
-  // This assumes that the endianess conversion was handled when loading the
-  // tensor in memory.
-  values.CopyToArray(reinterpret_cast<char*>(data.get()));
-  return mlir::DenseElementsAttr::get(
-      type, llvm::makeArrayRef(data.get(), n_elements));
-}
-
-// Converts an TensorFlow tensor proto with DT_FLOAT data type into an MLIR
-// elements attribute.
-StatusOr<ElementsAttr> ConvertFloatTensor(const TensorProto& input_tensor,
-                                          ShapedType type, Builder* builder) {
-  // When the repeated "float_val" field only has one element, it is converted
-  // to a splat elements attribute; When it has more than one element, it is
-  // converted to a dense elements attribute; otherwise, convert the whole
-  // tensor to an opaque elements attribute if the "tensor_content" field is
-  // set.
-  auto repeated_val_size = input_tensor.float_val_size();
-  if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) {
-    return ConvertToDenseElementsAttr<float, float>(input_tensor.float_val(),
-                                                    type, builder);
-  }
-  auto raw_data = input_tensor.tensor_content();
-  if (raw_data.size() == type.getSizeInBits() / 8)
-    return ConvertToDenseElementsAttr<float>(raw_data, type, builder);
-  return ConvertToOpaqueElementsAttr(input_tensor, type, builder);
-}
-
-// Converts an TensorFlow tensor proto with DT_INT32, DT_INT16, DT_INT8,
-// DT_UINT8, DT_QUINT8 data type into an MLIR elements attribute.
-template <typename T>
-StatusOr<ElementsAttr> ConvertIntTensor(const TensorProto& input_tensor,
-                                        ShapedType type, Builder* builder) {
-  // When the repeated "int_val" field only has one element, it is converted to
-  // a splat elements attribute; When it has more than one element, it is
-  // converted to a dense elements attribute; otherwise, convert the whole
-  // tensor to an opaque elements attribute if the "tensor_content" field is
-  // set.
-  auto repeated_val_size = input_tensor.int_val_size();
-  if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) {
-    return ConvertToDenseElementsAttr<int32_t, T>(input_tensor.int_val(), type,
-                                                  builder);
-  }
-  auto raw_data = input_tensor.tensor_content();
-  if (raw_data.size() == type.getSizeInBits() / 8)
-    return ConvertToDenseElementsAttr<int32_t>(raw_data, type, builder);
-
-  return ConvertToOpaqueElementsAttr(input_tensor, type, builder);
-}
-
-// Converts an TensorFlow tensor proto with DT_INT64 data type into an MLIR
-// elements attribute.
-StatusOr<ElementsAttr> ConvertInt64Tensor(const TensorProto& input_tensor,
-                                          ShapedType type, Builder* builder) {
-  // When the repeated "int64_val" field only has one element, it is converted
-  // to a splat elements attribute; When it has more than one element, it is
-  // converted to a dense elements attribute; otherwise, convert the whole
-  // tensor to an opaque elements attribute if the "tensor_content" field is
-  // set.
-  auto repeated_val_size = input_tensor.int64_val_size();
-  if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) {
-    return ConvertToDenseElementsAttr<decltype(input_tensor.int64_val(0)),
-                                      uint64_t>(input_tensor.int64_val(), type,
-                                                builder);
-  }
-  auto raw_data = input_tensor.tensor_content();
-  if (raw_data.size() == type.getSizeInBits() / 8)
-    return ConvertToDenseElementsAttr<int64_t>(raw_data, type, builder);
-  return ConvertToOpaqueElementsAttr(input_tensor, type, builder);
-}
-
-// Converts an TensorFlow tensor proto with DT_BOOL data type into an MLIR
-// elements attribute.
-StatusOr<ElementsAttr> ConvertBoolTensor(const TensorProto& input_tensor,
+StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
                                          ShapedType type, Builder* builder) {
-  // When the repeated "bool_val" field only has one element, it is converted to
-  // a splat elements attribute; When it has more than one element, it is
-  // converted to a dense elements attribute; otherwise, convert the whole
-  // tensor to an opaque elements attribute if the "tensor_content" field is
-  // set.
-  auto repeated_val_size = input_tensor.bool_val_size();
-  if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) {
-    const auto& proto = input_tensor.bool_val();
-    return mlir::DenseElementsAttr::get(
-        type, llvm::makeArrayRef(proto.data(), proto.size()));
+  auto arr = input_tensor.flat<T>();
+  return mlir::DenseElementsAttr::get(
+      type, llvm::makeArrayRef(arr.data(), arr.size()));
+}
+
+StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
+                                     Builder* builder) {
+  const auto& input_dtype = input_tensor.dtype();
+  const auto& input_shape = input_tensor.shape();
+  Type elt_type;
+  TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
+  SmallVector<int64_t, 4> shape;
+  ConvertToMlirShape(input_shape, &shape);
+  auto type = builder->getTensorType(shape, elt_type);
+
+#define CONVERT_FLAT(DTYPE, CTYPE) \
+  case DTYPE:                      \
+    return ConvertFlatTensor<CTYPE>(input_tensor, type, builder);
+
+  // TODO(fengliuai): customize the conversions for more types.
+  switch (input_dtype) {
+    CONVERT_FLAT(DT_BOOL, bool)
+    CONVERT_FLAT(DT_FLOAT, float)
+    CONVERT_FLAT(DT_INT32, int32)
+    CONVERT_FLAT(DT_INT64, int64)
+    default:
+      // TODO(shpeisman): restructure code to reuse dialect pointer across
+      // calls.
+      auto* dialect = builder->getContext()->getRegisteredDialect("tf");
+      return builder->getOpaqueElementsAttr(dialect, type,
+                                            MangleTensor(input_tensor));
   }
-  return ConvertToOpaqueElementsAttr(input_tensor, type, builder);
+
+#undef CONVERT_FLAT
 }
 
 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
                                           Builder* builder) {
-  const auto& input_dtype = input_tensor.dtype();
-  const auto& input_shape = input_tensor.tensor_shape();
-  Type elt_type;
-  TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
-  SmallVector<int64_t, 4> shape;
-  TF_RETURN_IF_ERROR(ConvertToMlirShape(input_shape, &shape));
-  auto type = builder->getTensorType(shape, elt_type);
-
-  // TODO(fengliuai): customize the conversions for more types.
-  switch (input_dtype) {
-    case DT_FLOAT:
-      return ConvertFloatTensor(input_tensor, type, builder);
-    case DT_INT32:
-      return ConvertIntTensor<uint32_t>(input_tensor, type, builder);
-    case DT_INT64:
-      return ConvertInt64Tensor(input_tensor, type, builder);
-    case DT_BOOL:
-      return ConvertBoolTensor(input_tensor, type, builder);
-    default:
-      // The value of the opaque elements attribute contains the whole tensor
-      // proto, not just the tensor content.
-
-      // TODO(shpeisman): restructure code to reuse dialect pointer across
-      // calls.
-      auto* dialect = builder->getContext()->getRegisteredDialect("tf");
-
-      return builder->getOpaqueElementsAttr(
-          dialect, type, mangling_util::MangleTensor(input_tensor));
-  }
-}
-
-StatusOr<mlir::ElementsAttr> ConvertTensor(const Tensor& input_tensor,
-                                           mlir::Builder* builder) {
-  TensorProto input_proto;
-  // This decodes the tensor content into a proper proto field.
-  input_tensor.AsProtoField(&input_proto);
-  return ConvertTensorProto(input_proto, builder);
+  Tensor t;
+  if (!t.FromProto(input_tensor))
+    return InvalidArgument("Failed to parse input_tensor.");
+  return ConvertTensor(t, builder);
 }
 
 Status ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
@@ -280,7 +148,7 @@
   return Status::OK();
 }
 
-// Converts an MLIR opaque elements attribute to an TensorFlow tensor proto.
+// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
                                  TensorProto* output_tensor) {
   if (attr.isa<OpaqueElementsAttr>()) {
@@ -291,7 +159,7 @@
   return InvalidArgument("Unexpected elements attribute type from MLIR.");
 }
 
-// Converts an MLIR elements attribute to an TensorFlow tensor proto
+// Converts an MLIR elements attribute to a TensorFlow tensor proto
 // with the float_val field updated.
 Status ConvertFloatElementsAttr(const ElementsAttr attr,
                                 TensorProto* output_tensor) {
@@ -299,13 +167,12 @@
     for (auto value : elts.getValues<float>()) {
       output_tensor->add_float_val(value);
     }
-  } else {
-    return ConvertOpaqueElementsAttr(attr, output_tensor);
+    return Status::OK();
   }
-  return Status::OK();
+  return ConvertOpaqueElementsAttr(attr, output_tensor);
 }
 
-// Converts an MLIR elements attribute to an TensorFlow tensor proto
+// Converts an MLIR elements attribute to a TensorFlow tensor proto
 // with the int_val field updated.
 Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
                               TensorProto* output_tensor) {
@@ -313,13 +180,12 @@
     for (auto val : elts) {
       output_tensor->add_int_val(val.getSExtValue());
     }
-  } else {
-    return ConvertOpaqueElementsAttr(attr, output_tensor);
+    return Status::OK();
   }
-  return Status::OK();
+  return ConvertOpaqueElementsAttr(attr, output_tensor);
 }
 
-// Converts an MLIR elements attribute to an TensorFlow tensor proto
+// Converts an MLIR elements attribute to a TensorFlow tensor proto
 // with the int64_val field updated.
 Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
                                 TensorProto* output_tensor) {
@@ -327,13 +193,12 @@
     for (auto val : elts) {
       output_tensor->add_int64_val(val.getSExtValue());
     }
-  } else {
-    return ConvertOpaqueElementsAttr(attr, output_tensor);
+    return Status::OK();
   }
-  return Status::OK();
+  return ConvertOpaqueElementsAttr(attr, output_tensor);
 }
 
-// Converts an MLIR elements attribute to an TensorFlow tensor proto
+// Converts an MLIR elements attribute to a TensorFlow tensor proto
 // with bool_val field updated.
 Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
                                TensorProto* output_tensor) {
@@ -341,10 +206,9 @@
     for (auto val : elts) {
       output_tensor->add_bool_val(val.getBoolValue());
     }
-  } else {
-    return ConvertOpaqueElementsAttr(attr, output_tensor);
+    return Status::OK();
   }
-  return Status::OK();
+  return ConvertOpaqueElementsAttr(attr, output_tensor);
 }
 
 Status ConvertToTensorProto(const ElementsAttr attr,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc
index 5d6cd1bb..4e59cec 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc
@@ -29,12 +29,8 @@
 
 TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
   MLIRContext context;
-
-  auto emit_error = [&](const std::string& msg) {
-    emitError(FileLineColLoc::get(Identifier::get("test.cc", &context), 10, 32,
-                                  &context),
-              msg);
-  };
+  auto id = Identifier::get("test.cc", &context);
+  auto loc = FileLineColLoc::get(id, 0, 0, &context);
 
   // Test OK without diagnostic gets passed through.
   {
@@ -44,7 +40,7 @@
   // Verify diagnostics are captured as Unknown status.
   {
     StatusScopedDiagnosticHandler handler(&context);
-    emit_error("Diagnostic message");
+    emitError(loc) << "Diagnostic message";
     ASSERT_TRUE(tensorflow::errors::IsUnknown(handler.ConsumeStatus()));
   }
 
@@ -58,8 +54,8 @@
   // Verify diagnostic reported are append to passed in error.
   {
     auto function = [&]() {
-      emit_error("Diagnostic message reported");
-      emit_error("Second diagnostic message reported");
+      emitError(loc) << "Diagnostic message reported";
+      emitError(loc) << "Second diagnostic message reported";
       return tensorflow::errors::Internal("Passed in error");
     };
     Status s = StatusScopedDiagnosticHandler(&context).Combine(function());
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index 7befa9a..dae5aa8 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -172,6 +172,18 @@
   }
 }
 
+// Updates NodeDef constructed out of an MLIR While op to map it to either
+// TensorFlow StatelessWhile or While op depending on the additional attribute.
+void UpdateCompositeWhileOp(NodeDef* node_def) {
+  auto it = node_def->mutable_attr()->find("is_stateless");
+  if (it != node_def->attr().end()) {
+    if (it->second.b()) {
+      *node_def->mutable_op() = "StatelessWhile";
+    }
+    node_def->mutable_attr()->erase(it);
+  }
+}
+
 }  // anonymous namespace
 
 StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
@@ -207,6 +219,7 @@
       inst->getLoc(), node_def->mutable_experimental_debug_info()));
 
   if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
+  if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get());
 
   return node_def;
 }
diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
new file mode 100644
index 0000000..b10f432
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
@@ -0,0 +1,50 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Support/TranslateClParser.h"  // TF:local_config_mlir
+#include "tensorflow/core/platform/init_main.h"
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
+                                                 llvm::cl::desc("<input file>"),
+                                                 llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<std::string> output_filename(
+    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+    llvm::cl::init("-"));
+
+int main(int argc, char** argv) {
+  llvm::PrettyStackTraceProgram x(argc, argv);
+  llvm::InitLLVM y(argc, argv);
+
+  // Add flags for all the registered translations.
+  llvm::cl::opt<const mlir::TranslateFunction*, false, mlir::TranslationParser>
+      requested_translation("", llvm::cl::desc("Translation to perform"),
+                            llvm::cl::Required);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
+
+  // TODO(jpienaar): Enable command line parsing for both sides.
+  int fake_argc = 1;
+  tensorflow::port::InitMain(argv[0], &fake_argc, &argv);
+
+  mlir::MLIRContext context;
+  return failed(
+      (*requested_translation)(input_filename, output_filename, &context));
+}
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index b9ba5fc..55a11db 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -374,7 +374,7 @@
       // is not mentioned in xla client anywhere or in the hlo of our sample
       // models.
     default: {
-      mlir::OperationState result(loc, "xla.unknown");
+      mlir::OperationState result(loc, "xla_hlo.unknown");
       result.addOperands(operands);
       result.addTypes(result_type);
       for (auto attr : attributes) {
diff --git a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
index 79eda9c..57f8733 100644
--- a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
+++ b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc
@@ -14,7 +14,6 @@
 ==============================================================================*/
 
 #include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
-using namespace mlir;
 
 // Static initialization for XLA dialect registration.
-static DialectRegistration<XLA::XLADialect> XlaOps;
+static mlir::DialectRegistration<mlir::XLA::XlaHloDialect> xla_hlo_ops;
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
index f47d4a0..13194e3 100644
--- a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc
@@ -26,7 +26,7 @@
 using namespace mlir;
 using namespace mlir::XLA;
 
-XLADialect::XLADialect(MLIRContext* context)
+XlaHloDialect::XlaHloDialect(MLIRContext* context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
 #define GET_OP_LIST
@@ -37,9 +37,10 @@
   allowUnknownOperations();
 }
 
-Operation* XLADialect::materializeConstant(OpBuilder& builder, Attribute value,
-                                           Type type, Location loc) {
-  // If this is an opaque elements attribute, then generate an xla.constant.
+Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
+                                              Attribute value, Type type,
+                                              Location loc) {
+  // If this is an opaque elements attribute, then generate an xla_hlo.constant.
   if (value.isa<OpaqueElementsAttr>())
     return builder.create<XLA::ConstOp>(loc, type, value.cast<ElementsAttr>());
   return nullptr;
@@ -74,7 +75,7 @@
   }
 
   // TODO: support other XLA specific types.
-  assert(type && "unsupported attribute type for building xla.constant");
+  assert(type && "unsupported attribute type for building xla_hlo.constant");
   result->types.push_back(type);
   result->addAttribute("value", value);
 }
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.h b/tensorflow/compiler/mlir/xla/ir/xla_ops.h
index 2be8160..9f82392 100644
--- a/tensorflow/compiler/mlir/xla/ir/xla_ops.h
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.h
@@ -29,10 +29,10 @@
 
 namespace XLA {
 
-class XLADialect : public Dialect {
+class XlaHloDialect : public Dialect {
  public:
-  XLADialect(MLIRContext *context);
-  static StringRef getDialectNamespace() { return "xla"; }
+  explicit XlaHloDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "xla_hlo"; }
 
   // Registered hook to materialize a constant operation from a given attribute
   // value with the desired resultant type.
diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.td b/tensorflow/compiler/mlir/xla/ir/xla_ops.td
index a05dd9b..1a3ce77 100644
--- a/tensorflow/compiler/mlir/xla/ir/xla_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.td
@@ -25,7 +25,7 @@
 #endif // OP_BASE
 
 def XLA_Dialect : Dialect {
-  let name = "xla";
+  let name = "xla_hlo";
   let cppNamespace = "XLA";
 }
 
@@ -111,8 +111,11 @@
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
 class XLA_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits>:
-    XLA_Op<mnemonic, traits>, Arguments<(ins XLA_Tensor:$operand)>,
-    Results<(outs XLA_Tensor:$res)>;
+    XLA_Op<mnemonic, traits> {
+
+    let arguments = (ins XLA_Tensor);
+    let results = (outs XLA_Tensor);
+}
 
 def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Absolute value operator";
@@ -196,15 +199,14 @@
 def BroadcastDimAttr : OptionalAttr<ElementsAttr>;
 
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
-class XLA_BinaryElementwiseOp<string mnemonic,
-    list<OpTrait> traits, dag args = (ins)> :
-        XLA_Op<mnemonic, traits>,
-        Arguments<(
-            ins XLA_Tensor:$lhs,
-            XLA_Tensor:$rhs,
-            BroadcastDimAttr:$broadcast_dimensions
-        )>,
-        Results<(outs XLA_Tensor:$res)> {
+class XLA_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
+        XLA_Op<mnemonic, traits> {
+  let arguments = (ins
+      XLA_Tensor:$lhs,
+      XLA_Tensor:$rhs,
+      BroadcastDimAttr:$broadcast_dimensions
+  );
+  let results = (outs XLA_Tensor);
   let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
   let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
 }
@@ -302,7 +304,7 @@
     SymbolRefAttr:$body
   );
 
-  let results = (outs Variadic<XLA_TensorOrTuple>:$res);
+  let results = (outs Variadic<XLA_TensorOrTuple>);
 
   // TODO(b/129422361): WhileOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
@@ -324,7 +326,7 @@
     ElementsAttr:$dimensions
   );
 
-  let results = (outs Variadic<XLA_Tensor>:$res);
+  let results = (outs Variadic<XLA_Tensor>);
 
   // TODO(b/129422361): ReduceOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
@@ -363,7 +365,7 @@
    }];
 
    let arguments = (ins Variadic<XLA_TensorOrTuple>:$val);
-   let results = (outs XLA_Tuple:$res);
+   let results = (outs XLA_Tuple);
 
   // TupleOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
@@ -418,7 +420,7 @@
       BroadcastDimAttr:$broadcast_dimensions,
       XLA_ComparisonDirectionAttr:$comparison_direction
   );
-  let results = (outs I1Tensor:$res);
+  let results = (outs XLA_PredTensor);
   let summary = "Comparison operator";
 
   let description = [{
@@ -433,15 +435,17 @@
 // XLA Slice definitions.
 //===----------------------------------------------------------------------===//
 
-def XLA_SliceOp: XLA_UnaryElementwiseOp<"slice",
-      [NoSideEffect, SameOperandsAndResultElementType]> {
+def XLA_SliceOp: XLA_Op<
+      "slice",
+      [NoSideEffect, SameOperandsAndResultElementType,
+       AllTypesMatch<["start_indices", "limit_indices"]>]> {
   let arguments = (
     ins XLA_Tensor:$operand,
     ElementsAttr:$start_indices,
     ElementsAttr:$limit_indices
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   let summary = "Slice operator";
 
@@ -456,15 +460,15 @@
   let hasCustomHLOConverter = 1;
 }
 
-def XLA_DynamicUpdateSliceOp: XLA_UnaryElementwiseOp<"dynamic-update-slice",
-      [NoSideEffect, AllElementTypesMatch<["operand", "res"]>]> {
+def XLA_DynamicUpdateSliceOp: XLA_Op<"dynamic-update-slice",
+      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
   let arguments = (ins
     XLA_Tensor:$operand,
     XLA_Tensor:$update,
     Variadic<XLA_Tensor>:$start_indices
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor:$result);
 
   let summary = "Dynamic Update Slice operator";
 
@@ -503,9 +507,7 @@
     I64Attr:$feature_index
   );
 
-  let results = (outs
-    XLA_Tensor:$res
-  );
+  let results = (outs XLA_Tensor);
 }
 
 def XLA_BroadcastOp : XLA_Op<"broadcast",
@@ -529,7 +531,7 @@
     ElementsAttr:$broadcast_sizes
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129012527) These should be expressed as type constraints.
   let verifier = [{
@@ -546,7 +548,7 @@
           "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
     }
 
-    auto resultType = res()->getType().cast<RankedTensorType>();
+    auto resultType = getResult()->getType().cast<RankedTensorType>();
     auto resultRank = resultType.getRank();
     auto operandType = operand()->getType().cast<RankedTensorType>();
     auto operandRank = operandType.getRank();
@@ -611,7 +613,7 @@
     BroadcastDimAttr:$broadcast_dimensions
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129012527) These should be expressed as type constraints.
   let verifier = [{
@@ -649,7 +651,7 @@
           dimensionsSize, operandRank));
     }
 
-    auto resultType = res()->getType().cast<RankedTensorType>();
+    auto resultType = getResult()->getType().cast<RankedTensorType>();
     auto resultRank = resultType.getRank();
     if (resultRank < operandRank) {
       return emitOpError(
@@ -704,9 +706,7 @@
     XLA_Tensor:$max
   );
 
-  let results = (outs
-    XLA_Tensor:$res
-  );
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129012527) These should be expressed as type constraints.
   let verifier = [{
@@ -781,7 +781,7 @@
      return success();
    }];
 
-   let results = (outs XLA_Tensor:$res);
+   let results = (outs XLA_Tensor);
 
   // TODO(b/129422361) ConcatOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
@@ -801,20 +801,23 @@
     XLA_Tensor:$rhs
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129422361) Needs additional work to handle attributes.
   // Conv has custom handling because its other args are passed as attributes
   let hasCustomHLOConverter = 1;
 }
 
-def XLA_CopyOp: XLA_UnaryElementwiseOp<"copy", [NoSideEffect, SameOperandsAndResultType]> {
+def XLA_CopyOp: XLA_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Copy operator";
 
   let description = [{
     Returns a copy of `operand`.
   }];
 
+  let arguments = (ins XLA_Tensor);
+  let results = (outs XLA_Tensor);
+
   // TODO(b/129422361) Implement special handling.
   // Copy has an HloOpcode, but is not one of the ops defined in xla_builder.
   let hasCustomHLOConverter = 1;
@@ -826,7 +829,7 @@
         XLA_Tensor:$rhs,
         XLA_PrecisionConfigAttr:$precision_config
     );
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   let description = [{
     Performs dot products between vectors, vector/matrix and matrix/matrix
@@ -841,13 +844,13 @@
       ins XLA_Tensor:$operand,
           XLA_IntTensor:$start_indices,
           I64Attr: $index_vector_dim,
-          ElementsAttr: $offsets_dim,
+          ElementsAttr: $offset_dims,
           ElementsAttr: $slice_sizes,
-          ElementsAttr: $collapsed_slice_sizes,
+          ElementsAttr: $collapsed_slice_dims,
           ElementsAttr: $start_index_map
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   let summary = "Gather operator";
 
@@ -866,7 +869,7 @@
       [NoSideEffect, SameOperandsAndResultElementType]> {
   let arguments = (ins XLA_Tensor:$operand);
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   let summary = "Reshape operator";
 
@@ -907,7 +910,7 @@
     XLA_Tensor:$on_false
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129012527) These should be expressed as type constraints.
   let verifier = [{
@@ -954,7 +957,7 @@
     ElementsAttr:$dimensions
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129422361): ReverseOp has a custom constructor for HLO.
   let hasCustomHLOConverter = 1;
@@ -979,7 +982,7 @@
     ElementsAttr: $interior_padding
   );
 
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   let description = [{
     Pads the `operand` according to TBD.
@@ -1050,7 +1053,7 @@
     XLA_Tensor:$operand,
     ElementsAttr:$permutation
   );
-  let results = (outs XLA_Tensor:$res);
+  let results = (outs XLA_Tensor);
 
   // TODO(b/129012527) These should be expressed as type constraints.
   let verifier = [{
@@ -1076,7 +1079,7 @@
           permutationSize, operandRank));
     }
 
-    auto resultType = res()->getType().cast<RankedTensorType>();
+    auto resultType = getResult()->getType().cast<RankedTensorType>();
     auto resultRank = resultType.getRank();
     if (resultRank != operandRank) {
       return emitOpError(
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 2ec1324..fe6e08c 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -37,6 +37,8 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
+using tensorflow::int64;
+
 static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
   llvm::ArrayRef<int64> raw_data = attr.getValues<int64>();
   if (attr.isSplat())
@@ -177,7 +179,7 @@
 
   // TODO(riverriddle) We currently don't support lowering constant operations.
   if (isa<mlir::XLA::ConstOp>(inst)) {
-    inst->emitError("unable to lower 'xla.constant' operation");
+    inst->emitError("unable to lower 'xla_hlo.constant' operation");
     return failure();
   }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/convert.mlir b/tensorflow/compiler/mlir/xla/tests/convert.mlir
index 93de3b3..c87ac35 100644
--- a/tensorflow/compiler/mlir/xla/tests/convert.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/convert.mlir
@@ -4,8 +4,8 @@
 
 // CHECK-LABEL: func @convert.1(%arg0: tensor<f32>) -> tensor<f32> {
 func @convert.1(%arg0: tensor<f32>) -> tensor<f32> {
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor<f32>) -> tensor<f32>
-  %0 = "xla.convert"(%arg0) : (tensor<f32>) -> tensor<f32>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) : (tensor<f32>) -> tensor<f32>
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<f32>) -> tensor<f32>
   // CHECK-NEXT: return %0 : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -14,8 +14,8 @@
 
 // CHECK-LABEL: func @convert.2(%arg0: tensor<i32>) -> tensor<i32> {
 func @convert.2(%arg0: tensor<i32>) -> tensor<i32> {
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<i32>
-  %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<i32>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<i32>
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<i32>
   // CHECK-NEXT: return %0 : tensor<i32>
   return %0 : tensor<i32>
 }
@@ -24,8 +24,8 @@
 
 // CHECK-LABEL: func @convert.3(%arg0: tensor<i32>) -> tensor<i64> {
 func @convert.3(%arg0: tensor<i32>) -> tensor<i64> {
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<i64>
-  %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<i64>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<i64>
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   return %0 : tensor<i64>
 }
@@ -34,8 +34,8 @@
 
 // CHECK-LABEL: func @convert.4(%arg0: tensor<f32>) -> tensor<i32> {
 func @convert.4(%arg0: tensor<f32>) -> tensor<i32> {
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor<f32>) -> tensor<i32>
-  %0 = "xla.convert"(%arg0) : (tensor<f32>) -> tensor<i32>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) : (tensor<f32>) -> tensor<i32>
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<f32>) -> tensor<i32>
   // CHECK-NEXT: return %0 : tensor<i32>
   return %0 : tensor<i32>
 }
@@ -44,8 +44,8 @@
 
 // CHECK-LABEL: func @convert.5(%arg0: tensor<i32>) -> tensor<f32> {
 func @convert.5(%arg0: tensor<i32>) -> tensor<f32> {
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<f32>
-  %0 = "xla.convert"(%arg0) : (tensor<i32>) -> tensor<f32>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<f32>
+  %0 = "xla_hlo.convert"(%arg0) : (tensor<i32>) -> tensor<f32>
   // CHECK-NEXT: return %0 : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -57,7 +57,7 @@
 func @convert.const.1() -> tensor<f32> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f32>
   %cst = constant  dense<42.0> : tensor<f32>
-  %0 = "xla.convert"(%cst) : (tensor<f32>) -> tensor<f32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<f32>
   // CHECK-NEXT: return %cst : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -68,7 +68,7 @@
 func @convert.const.2() -> tensor<i32> {
   // check-next: %cst = constant dense<42> : tensor<i32>
   %cst = constant  dense<42> : tensor<i32>
-  %0 = "xla.convert"(%cst) : (tensor<i32>) -> tensor<i32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
   // check-next: return %cst : tensor<i32>
   return %0 : tensor<i32>
 }
@@ -79,7 +79,7 @@
 func @convert.const.3() -> tensor<i32> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i32>
   %cst = constant  dense<42.0> : tensor<f32>
-  %0 = "xla.convert"(%cst) : (tensor<f32>) -> tensor<i32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
   // CHECK-NEXT: return %cst : tensor<i32>
   return %0 : tensor<i32>
 }
@@ -90,7 +90,7 @@
 func @convert.const.4() -> tensor<f32> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f32>
   %cst = constant  dense<42> : tensor<i32>
-  %0 = "xla.convert"(%cst) : (tensor<i32>) -> tensor<f32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
   // CHECK-NEXT: return %cst : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -101,7 +101,7 @@
 func @convert.const.5() -> tensor<bf16> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<bf16>
   %cst = constant  dense<42> : tensor<i32>
-  %0 = "xla.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
   // CHECK-NEXT: return %cst : tensor<bf16>
   return %0 : tensor<bf16>
 }
@@ -112,7 +112,7 @@
 func @convert.const.6() -> tensor<i16> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i16>
   %cst = constant  dense<42.0> : tensor<bf16>
-  %0 = "xla.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
   // CHECK-NEXT: return %cst : tensor<i16>
   return %0 : tensor<i16>
 }
@@ -123,7 +123,7 @@
 func @convert.const.7() -> tensor<i32> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i32>
   %cst = constant  dense<42> : tensor<i64>
-  %0 = "xla.convert"(%cst) : (tensor<i64>) -> tensor<i32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
   // CHECK-NEXT: return %cst : tensor<i32>
   return %0 : tensor<i32>
 }
@@ -134,7 +134,7 @@
 func @convert.const.8() -> tensor<i64> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i64>
   %cst = constant  dense<42> : tensor<i32>
-  %0 = "xla.convert"(%cst) : (tensor<i32>) -> tensor<i64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
   // CHECK-NEXT: return %cst : tensor<i64>
   return %0 : tensor<i64>
 }
@@ -145,7 +145,7 @@
 func @convert.const.9() -> tensor<f32> {
   // CHECK-NEXT: %cst = constant  dense<4.200000e+01> : tensor<f32>
   %cst = constant  dense<42.0> : tensor<f64>
-  %0 = "xla.convert"(%cst) : (tensor<f64>) -> tensor<f32>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
   // CHECK-NEXT: return %cst : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -156,7 +156,7 @@
 func @convert.const.9() -> tensor<bf16> {
   // CHECK-NEXT: %cst = constant  dense<4.200000e+01> : tensor<bf16>
   %cst = constant  dense<42.0> : tensor<f32>
-  %0 = "xla.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
   // CHECK-NEXT: return %cst : tensor<bf16>
   return %0 : tensor<bf16>
 }
@@ -167,7 +167,7 @@
 func @convert.const.10() -> tensor<f64> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f64>
   %cst = constant  dense<42.0> : tensor<bf16>
-  %0 = "xla.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
   // CHECK-NEXT: return %cst : tensor<f64>
   return %0 : tensor<f64>
 }
@@ -178,7 +178,7 @@
 func @convert.const.11() -> tensor<f64> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f64>
   %cst = constant  dense<42.0> : tensor<bf16>
-  %0 = "xla.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
   // CHECK-NEXT: return %cst : tensor<f64>
   return %0 : tensor<f64>
 }
@@ -190,7 +190,7 @@
 func @convert.const.12() -> tensor<i64> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i64>
   %cst = constant  dense<42.0> : tensor<bf16>
-  %0 = "xla.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
   // CHECK-NEXT: return %cst : tensor<i64>
   return %0 : tensor<i64>
 }
@@ -201,7 +201,7 @@
 func @convert.const.13() -> tensor<i64> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<i64>
   %cst = constant  dense<42> : tensor<i16>
-  %0 = "xla.convert"(%cst) : (tensor<i16>) -> tensor<i64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i16>) -> tensor<i64>
   // CHECK-NEXT: return %cst : tensor<i64>
   return %0 : tensor<i64>
 }
@@ -212,7 +212,7 @@
 func @convert.const.14() -> tensor<f64> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f64>
   %cst = constant  dense<42> : tensor<i16>
-  %0 = "xla.convert"(%cst) : (tensor<i16>) -> tensor<f64>
+  %0 = "xla_hlo.convert"(%cst) : (tensor<i16>) -> tensor<f64>
   // CHECK-NEXT: return %cst : tensor<f64>
   return %0 : tensor<f64>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/iota.mlir b/tensorflow/compiler/mlir/xla/tests/iota.mlir
index 10559a4..46e0984 100644
--- a/tensorflow/compiler/mlir/xla/tests/iota.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/iota.mlir
@@ -5,7 +5,7 @@
 // CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
 func @iota.const.1() -> tensor<4xi32> {
   // CHECK-NEXT: %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
-  %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
   // CHECK-NEXT: return %cst : tensor<4xi32>
   return %0 : tensor<4xi32>
 }
@@ -15,7 +15,7 @@
 // CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
 func @iota.const.2() -> tensor<2x4xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
-  %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
   // CHECK-NEXT: return %cst : tensor<2x4xi32>
   return %0 : tensor<2x4xi32>
 }
@@ -25,7 +25,7 @@
 // CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
 func @iota.const.3() -> tensor<2x4xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
-  %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
   // CHECK-NEXT: return %cst : tensor<2x4xi32>
   return %0 : tensor<2x4xi32>
 }
@@ -35,7 +35,7 @@
 // CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
 func @iota.const.4() -> tensor<2x3x4xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
-  %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
   // CHECK-NEXT: return %cst : tensor<2x3x4xi32>
   return %0 : tensor<2x3x4xi32>
 }
@@ -45,7 +45,7 @@
 // CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
 func @iota.const.5() -> tensor<2x3x4xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
-  %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
   // CHECK-NEXT: return %cst : tensor<2x3x4xi32>
   return %0 : tensor<2x3x4xi32>
 }
@@ -55,7 +55,7 @@
 // CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
 func @iota.const.6() -> tensor<2x3x4xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
-  %0 = "xla.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
   // CHECK-NEXT: return %cst : tensor<2x3x4xi32>
   return %0 : tensor<2x3x4xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
index 74dd003..92d9c35 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
@@ -2,16 +2,16 @@
 
 // CHECK-LABEL: func @cond(%arg0: tensor<i64>) -> tensor<i1> {
 func @cond(%arg0: tensor<i64>) -> tensor<i1> {
-  // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-  %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
   // CHECK-NEXT: return %0 : tensor<i1>
   return %0 : tensor<i1>
 }
 
 // CHECK-LABEL: func @loop(%arg0: tensor<i64>) -> tensor<i64> {
 func @loop(%arg0: tensor<i64>) -> tensor<i64> {
-  // CHECK-NEXT: %0 = xla.add %arg0, %arg0 {name = "compare.0"} : tensor<i64>
-  %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 {name = "compare.0"} : tensor<i64>
+  %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   return %0 : tensor<i64>
 }
@@ -27,7 +27,7 @@
   // CHECK-NEXT:   %4 = call @loop(%3) : (tensor<i64>) -> tensor<i64>
   // CHECK-NEXT:   br ^bb1(%4 : tensor<i64>)
   // CHECK-NEXT: b3(%5: tensor<i64>):	// pred: ^bb1
-  %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
+  %0 = "xla_hlo.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
   // CHECK-NEXT:   return %5 : tensor<i64>
   return %0 : tensor<i64>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 69be978..b2a52c7 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -6,7 +6,7 @@
 
 // CHECK-LABEL: fusedBatchNorm_notraining
 func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
-  // CHECK-NEXT: "xla.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+  // CHECK-NEXT: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
   %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
   return %0#0 : tensor<8x8x8x8xf32>
 }
@@ -25,14 +25,14 @@
 
 // CHECK-LABEL: func @biasAdd_NHWC
 func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
   return %0 : tensor<1x32x10x32xi32>
 }
 
 // CHECK-LABEL: func @biasAdd_NCHW
 func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
   return %0 : tensor<1x32x10x32xi32>
 }
@@ -42,14 +42,14 @@
 
 // CHECK-LABEL: func @biasAdd_NHWC_invalid
 func @biasAdd_NHWC_invalid(%arg0: tensor<1x32x10x2xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x2xi32> {
-  // CHECK-NOT: xla.add
+  // CHECK-NOT: xla_hlo.add
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x2xi32>, tensor<32xi32>) -> tensor<1x32x10x2xi32>
   return %0 : tensor<1x32x10x2xi32>
 }
 
 // CHECK-LABEL: func @biasAdd_NCHW_invalid
 func @biasAdd_NCHW_invalid(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> {
-  // CHECK-NOT: xla.add
+  // CHECK-NOT: xla_hlo.add
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x10x10x32xi32>, tensor<32xi32>) -> tensor<1x10x10x32xi32>
   return %0 : tensor<1x10x10x32xi32>
 }
@@ -60,7 +60,7 @@
 
 // CHECK-LABEL: func @add
 func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT:  %0 = xla.add %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT:  %0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
   // CHECK-NEXT:  return %0 : tensor<2xi32>
   %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
@@ -68,21 +68,21 @@
 
 // CHECK-LABEL: func @broadcast_add
 func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  // CHECK-NEXT: "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
 
 // CHECK-LABEL: func @broadcast_multi_dim_add
 func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
-  // CHECK-NEXT: "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
+  // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
   %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
   return %0: tensor<4x4x4x4xi32>
 }
 
 // CHECK-LABEL: func @div
 func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT:  %0 = xla.div %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT:  %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32>
   // CHECK-NEXT:  return %0 : tensor<2xi32>
   %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
@@ -90,14 +90,14 @@
 
 // CHECK-LABEL: func @broadcast_div
 func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  // CHECK-NEXT: "xla.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
 
 // CHECK-LABEL: func @mul
 func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT:  %0 = xla.mul %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT:  %0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32>
   // CHECK-NEXT:  return %0 : tensor<2xi32>
   %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
@@ -105,28 +105,28 @@
 
 // CHECK-LABEL: func @broadcast_mul
 func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  // CHECK-NEXT: "xla.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
 
 // CHECK-LABEL: func @real_div
 func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT:  %0 = xla.div %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT:  %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32>
   %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
 }
 
 // CHECK-LABEL: func @broadcast_real_div
 func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  // CHECK-NEXT: "xla.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
 
 // CHECK-LABEL: func @sub
 func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT:  %0 = xla.sub %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT:  %0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32>
   // CHECK-NEXT:  return %0 : tensor<2xi32>
   %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
@@ -134,7 +134,7 @@
 
 // CHECK-LABEL: func @broadcast_sub
 func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  // CHECK-NEXT: "xla.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // CHECK-NEXT: "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }
@@ -156,7 +156,7 @@
 
 // CHECK-LABEL: @const
 func @const() -> tensor<2xi32> {
-  // tf.Const is legalized into xla.constant, which is folded into constant.
+  // tf.Const is legalized into xla_hlo.constant, which is folded into constant.
 
   // CHECK-NEXT: constant dense<0> : tensor<2xi32>
   %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
@@ -170,7 +170,7 @@
 // CHECK-LABEL: func @relu
 func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32>
-  // CHECK-NEXT: %0 = xla.max %arg0, %cst : tensor<1xi32>
+  // CHECK-NEXT: %0 = xla_hlo.max %arg0, %cst : tensor<1xi32>
   %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
@@ -179,7 +179,7 @@
 func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32>
   // CHECK-NEXT: %cst_0 = constant dense<6> : tensor<1xi32>
-  // CHECK-NEXT: %0 = "xla.clamp"(%cst, %arg0, %cst_0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  // CHECK-NEXT: %0 = "xla_hlo.clamp"(%cst, %arg0, %cst_0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
@@ -190,7 +190,7 @@
 
 // CHECK-LABEL: reshape
 func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
-  // CHECK:  %0 = "xla.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
+  // CHECK:  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
   %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x1xf32>
   return %0 : tensor<1x1xf32>
 }
@@ -204,7 +204,7 @@
 
 // CHECK-LABEL: squeeze
 func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
-  // CHECK-NEXT: %0 = "xla.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
   %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
   return %0 : tensor<1x10xf32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
index d75b283..6dad191 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
@@ -3,16 +3,16 @@
 // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK-NEXT:   %0 = addf %arg0, %arg1 : tensor<4xf32>
-  %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
   // CHECK-NEXT:   %1 = mulf %0, %arg1 : tensor<4xf32>
-  %1 = "xla.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
   // CHECK-NEXT:   %2 = subf %1, %arg1 : tensor<4xf32>
-  %2 = "xla.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
   // CHECK-NEXT:   %3 = divf %2, %arg1 : tensor<4xf32>
-  %3 = "xla.div"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
   // CHECK-NEXT:   return %3 : tensor<4xf32>
   return %3 : tensor<4xf32>
@@ -21,16 +21,16 @@
 // CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
 func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK-NEXT:   %0 = addi %arg0, %arg1 : tensor<4xi32>
-  %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
 
   // CHECK-NEXT:   %1 = muli %0, %arg1 : tensor<4xi32>
-  %1 = "xla.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
 
   // CHECK-NEXT:   %2 = subi %1, %arg1 : tensor<4xi32>
-  %2 = "xla.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
 
   // CHECK-NEXT:   %3 = divis %2, %arg1 : tensor<4xi32>
-  %3 = "xla.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
 
   // CHECK-NEXT:   return %3 : tensor<4xi32>
   return %3 : tensor<4xi32>
@@ -41,23 +41,23 @@
 // them to separate broadcast and binary op.
 // CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
 func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %0 = "xla.add"(%arg0, %arg1) {
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) {
       name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} :
           (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
 
-  // CHECK-NEXT: %1 = "xla.mul"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %1 = "xla.mul"(%0, %arg1) {
+  // CHECK-NEXT: %1 = "xla_hlo.mul"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
+  %1 = "xla_hlo.mul"(%0, %arg1) {
       name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} :
           (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
 
-  // CHECK-NEXT: %2 = "xla.sub"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %2 = "xla.sub"(%1, %arg1) {
+  // CHECK-NEXT: %2 = "xla_hlo.sub"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
+  %2 = "xla_hlo.sub"(%1, %arg1) {
       name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} :
           (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
 
-  // CHECK-NEXT: %3 = "xla.div"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %3 = "xla.div"(%2, %arg1) {
+  // CHECK-NEXT: %3 = "xla_hlo.div"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
+  %3 = "xla_hlo.div"(%2, %arg1) {
       name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} :
           (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
 
@@ -68,17 +68,17 @@
 // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
 func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
   // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
-  %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
-  %1 = "xla.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
-  %2 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
-  %3 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
-  %4 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
-  %5 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
   return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
 }
@@ -86,17 +86,17 @@
 // CHECK-LABEL: func @compare_float
 func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
   // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
-  %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
-  %1 = "xla.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
-  %2 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
-  %3 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
-  %4 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
-  %5 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index fcd93bb..11dd3db 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -4,7 +4,7 @@
 
 func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // expected-error@+1 {{op operand #0 must be statically shaped tensor}}
-  %0 = "xla.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0: tensor<*xf32>
 }
 
@@ -12,7 +12,7 @@
 
 // CHECK-LABEL: func @add_tensors
 func @add_tensors(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -20,7 +20,7 @@
 
 // CHECK-LABEL: func @add_scalars
 func @add_scalars(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
   return %0: tensor<i32>
 }
 
@@ -28,7 +28,7 @@
 
 // CHECK-LABEL: func @add_scalar_tensor
 func @add_scalar_tensor(%arg0: tensor<1xi32>, %arg1: tensor<i32>) -> tensor<1xi32> {
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -36,7 +36,7 @@
 
 // CHECK-LABEL: func @batch_norm_inference
 func @batch_norm_inference(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xf32> {
-  %0 = "xla.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+  %0 = "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
   return %0 : tensor<8x8x8x8xf32>
 }
 
@@ -44,7 +44,7 @@
 
 // CHECK-LABEL: func @broadcast
 func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -52,7 +52,7 @@
 
 func @broadcast_nonint_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -60,7 +60,7 @@
 
 func @broadcast_splat_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<2.0> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<2.0> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -68,7 +68,7 @@
 
 func @broadcast_sparse_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -76,7 +76,7 @@
 
 func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -84,7 +84,7 @@
 
 func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{result rank (3) does not match operand rank (1) plus size of broadcast_sizes (3)}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -92,7 +92,7 @@
 
 func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{result has shape [1, 3] instead of [2, 3]}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32>
   return %0 : tensor<1x3xi32>
 }
 
@@ -100,7 +100,7 @@
 
 func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{result has shape [2, 1] instead of [2, 3]}}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32>
   return %0 : tensor<2x1xi32>
 }
 
@@ -108,7 +108,7 @@
 
 // CHECK-LABEL: func @broadcast_in_dim
 func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
   return %0 : tensor<1x2x2xi32>
 }
 
@@ -116,7 +116,7 @@
 
 // CHECK-LABEL: func @broadcast_in_dim_zero_rank
 func @broadcast_in_dim_zero_rank(%arg0: tensor<i32>) -> tensor<1x2x3xi32> {
-  %0 = "xla.broadcast_in_dim"(%arg0) : (tensor<i32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) : (tensor<i32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -124,7 +124,7 @@
 
 func @broadcast_in_dim_bad_nonint_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -132,7 +132,7 @@
 
 func @broadcast_in_dim_bad_splat_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2.0> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2.0> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -140,7 +140,7 @@
 
 func @broadcast_in_dim_bad_sparse_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -148,7 +148,7 @@
 
 func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -156,7 +156,7 @@
 
 func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -164,7 +164,7 @@
 
 func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> {
   // expected-error@+1 {{result rank (1) is less than operand rank (3)}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32>
   return %0 : tensor<3xi32>
 }
 
@@ -172,7 +172,7 @@
 
 func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{broadcast_dimensions contains invalid value 9 for result result with rank 3}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -180,7 +180,7 @@
 
 func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}}
-  %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
+  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
 
@@ -188,7 +188,7 @@
 
 // CHECK-LABEL: func @comp_eq
 func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
-  %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
   return %0 : tensor<3xi1>
 }
 
@@ -196,7 +196,7 @@
 
 func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
   // expected-error@+1 {{'comparison_direction' failed to satisfy constraint}}
-  %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
   return %0 : tensor<3xi1>
 }
 
@@ -204,7 +204,7 @@
 
 func @comp_no_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
   // expected-error@+1 {{op requires attribute 'comparison_direction'}}
-  %0 = "xla.compare"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %0 = "xla_hlo.compare"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
   return %0 : tensor<3xi1>
 }
 
@@ -212,7 +212,7 @@
 
 // CHECK-LABEL: func @conv
 func @conv(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> {
-  %0 = "xla.conv"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
+  %0 = "xla_hlo.conv"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
   return %0: tensor<3xi32>
 }
 
@@ -220,7 +220,7 @@
 
 // CHECK-LABEL: func @copy
 func @copy(%arg0: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "xla.copy"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.copy"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -228,7 +228,7 @@
 
 // CHECK-LABEL: func @clamp
 func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "xla.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -236,15 +236,15 @@
 
 // CHECK-LABEL: func @clamp_scalar
 func @clamp_scalar(%arg0: tensor<1xi32>, %arg1: tensor<i32>) -> tensor<1xi32> {
-  %0 = "xla.clamp"(%arg1, %arg0, %arg1) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
+  %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg1) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
 // -----
 
 func @clamp_invalid_min_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
-  // expected-error@+1 {{'xla.clamp' op requires the same element type for all operands and results}}
-  %0 = "xla.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}}
+  %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -252,15 +252,15 @@
 
 func @clamp_invalid_min_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> {
   // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}}
-  %0 = "xla.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
 // -----
 
 func @clamp_invalid_max_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
-  // expected-error@+1 {{'xla.clamp' op requires the same element type for all operands and results}}
-  %0 = "xla.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<1xf32>) -> tensor<1xi32>
+  // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}}
+  %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<1xf32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -268,7 +268,7 @@
 
 func @clamp_invalid_max_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> {
   // expected-error@+1 {{max shape [2] is not scalar and does not match operand shape [1]}}
-  %0 = "xla.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -276,7 +276,7 @@
 
 // CHECK-LABEL: func @dot_vector
 func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor<i32> {
-  %0 = "xla.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<i32>
+  %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<i32>
   return %0: tensor<i32>
 }
 
@@ -284,7 +284,7 @@
 
 // CHECK-LABEL: func @dot_matrix
 func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
-  %0 = "xla.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0: tensor<2x2xi32>
 }
 
@@ -292,7 +292,7 @@
 
 // CHECK-LABEL: func @dot_precision_config
 func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
-  %0 = "xla.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0: tensor<2x2xi32>
 }
 
@@ -300,7 +300,7 @@
 
 func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
   // expected-error@+1 {{'precision_config' failed to satisfy constraint}}
-  %0 = "xla.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0: tensor<2x2xi32>
 }
 
@@ -308,7 +308,7 @@
 
 // CHECK-LABEL: func @tanh
 func @tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> {
-  %0 = "xla.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %0 = "xla_hlo.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
   return %0: tensor<1xf32>
 }
 
@@ -316,7 +316,7 @@
 
 // CHECK-LABEL: func @reshape_same_shape
 func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> {
-  %0 = "xla.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -324,7 +324,7 @@
 
 // CHECK-LABEL: func @reshape_different_shape
 func @reshape_different_shape(%arg0: tensor<1x16xi32>) -> tensor<4x4xi32> {
-  %0 = "xla.reshape"(%arg0) : (tensor<1x16xi32>) -> tensor<4x4xi32>
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x16xi32>) -> tensor<4x4xi32>
   return %0: tensor<4x4xi32>
 }
 
@@ -332,7 +332,7 @@
 
 // CHECK-LABEL: func @reshape_from_scalar
 func @reshape_from_scalar(%arg0: tensor<i32>) -> tensor<1xi32> {
-  %0 = "xla.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
 
@@ -340,7 +340,7 @@
 
 // CHECK-LABEL: func @reshape_to_scalar
 func @reshape_to_scalar(%arg0: tensor<1xi32>) -> tensor<i32> {
-  %0 = "xla.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
   return %0: tensor<i32>
 }
 
@@ -348,7 +348,7 @@
 
 // CHECK-LABEL: func @select
 func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
@@ -356,7 +356,7 @@
 
 // CHECK-LABEL: func @select_scalar_pred
 func @select_scalar_pred(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
@@ -364,7 +364,7 @@
 
 func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}}
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
@@ -372,7 +372,7 @@
 
 func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // expected-error@+1 {{on_true type (tensor<2x4xi32>) does not match on_false type (tensor<2x3xi32>)}}
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
@@ -380,7 +380,7 @@
 
 func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // expected-error@+1 {{on_true type (tensor<2x3xf32>) does not match on_false type (tensor<2x3xi32>)}}
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
@@ -388,15 +388,39 @@
 
 func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // expected-error@+1 {{red shape ([3]) is not scalar and does not match operand shapes ([2, 3])}}
-  %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
 // -----
 
+// CHECK-LABEL: func @slice
+func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
+  %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
+  return %0 : tensor<1x4xi32>
+}
+
+// -----
+
+func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
+  // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}}
+  %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
+  return %0 : tensor<1x4xi32>
+}
+
+// -----
+
+func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
+  // expected-error@+1 {{requires the same element type for all operands and results}}
+  %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32>
+  return %0 : tensor<1x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transpose
-func @transpose(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -404,7 +428,7 @@
 
 func @transpose_bad_permutations_float(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
   // expected-error@+1 {{permutation must be a DenseIntElementsAttr}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1.0, 0.0, 3.0, 2.0]> : tensor<4xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1.0, 0.0, 3.0, 2.0]> : tensor<4xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -412,7 +436,7 @@
 
 func @transpose_bad_permutations_splat(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
   // expected-error@+1 {{permutation must be a DenseIntElementsAttr}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<2.0> : tensor<2xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<2.0> : tensor<2xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -420,7 +444,7 @@
 
 func @transpose_bad_permutations_sparse(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
   // expected-error@+1 {{permutation must be a DenseIntElementsAttr}}
-  %0 = "xla.transpose"(%arg0) {permutation = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -428,7 +452,7 @@
 
 func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
   // expected-error@+1 {{permutation has rank 2 instead of rank 1}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -436,7 +460,7 @@
 
 func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2x1x4x3xi32> {
   // expected-error@+1 {{permutation size (1) does not match operand rank (4)}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0: tensor<2x1x4x3xi32>
 }
 
@@ -444,7 +468,7 @@
 
 func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) ->  tensor<2xi32> {
   // expected-error@+1 {{result rank (1) does not match operand rank (4)}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
 }
 
@@ -452,7 +476,7 @@
 
 func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x2x3x4xi32>) ->  tensor<1x2x3x4xi32> {
   // expected-error@+1 {{result shape is [1, 2, 3, 4] instead of [2, 1, 4, 3]}}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
   return %0: tensor<1x2x3x4xi32>
 }
 
@@ -460,6 +484,6 @@
 
 // CHECK-LABEL: func @tuple
 func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
-  %0 = "xla.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
+  %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
   return %0: tuple<tensor<1xi32>, tensor<1x2xf32>>
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/reshape.mlir
index ee29a71..5987e03 100644
--- a/tensorflow/compiler/mlir/xla/tests/reshape.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/reshape.mlir
@@ -6,7 +6,7 @@
 func @reshape.const.1() -> tensor<f32> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f32>
   %cst = constant  {name = "constant.1"} dense<42.0> : tensor<1x1xf32>
-  %0 = "xla.reshape"(%cst) : (tensor<1x1xf32>) -> tensor<f32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xf32>) -> tensor<f32>
   // CHECK-NEXT: return %cst : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -17,7 +17,7 @@
 func @reshape.const.2() -> tensor<2xf32> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<2xf32>
   %cst = constant  {name = "constant.1"} dense<42.0> : tensor<1x2xf32>
-  %0 = "xla.reshape"(%cst) : (tensor<1x2xf32>) -> tensor<2xf32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xf32>) -> tensor<2xf32>
   // CHECK-NEXT: return %cst : tensor<2xf32>
   return %0 : tensor<2xf32>
 }
@@ -28,7 +28,7 @@
 func @reshape.const.3() -> tensor<1xf32> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<1xf32>
   %cst = constant  {name = "constant.1"} dense<42.0> : tensor<f32>
-  %0 = "xla.reshape"(%cst) : (tensor<f32>) -> tensor<1xf32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<f32>) -> tensor<1xf32>
   // CHECK-NEXT: return %cst : tensor<1xf32>
   return %0 : tensor<1xf32>
 }
@@ -39,7 +39,7 @@
 func @reshape.const.4() -> tensor<16xi64> {
   // CHECK-NEXT: %cst = constant dense<42> : tensor<16xi64>
   %cst = constant  dense<42> : tensor<4x4xi64>
-  %0 = "xla.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
   // CHECK-NEXT: return %cst : tensor<16xi64>
   return %0 : tensor<16xi64>
 }
@@ -50,7 +50,7 @@
 func @reshape.const.5() -> tensor<16xf64> {
   // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<16xf64>
   %cst = constant  dense<4.200000e+01> : tensor<4x4xf64>
-  %0 = "xla.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
   // CHECK-NEXT: return %cst : tensor<16xf64>
   return %0 : tensor<16xf64>
 }
@@ -62,7 +62,7 @@
 func @reshape.const.6() -> tensor<6xi32> {
   // CHECK-NEXT: %cst = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
   %cst = constant  {name = "constant.1"} dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
-  %0 = "xla.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
   // CHECK-NEXT: return %cst : tensor<6xi32>
   return %0 : tensor<6xi32>
 }
@@ -74,7 +74,7 @@
 func @reshape.const.7() -> tensor<2x3xi32> {
   // CHECK-NEXT: %cst = constant dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
   %cst = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
-  %0 = "xla.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
   // CHECK-NEXT: return %cst : tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt
index d285df1..96423e0 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt
@@ -13,15 +13,15 @@
   %Arg_3.4 = f32[] parameter(3)
 
   // Add two tensors
-  // CHECK-NEXT:   %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 
   // Add two scalars
-  // CHECK-NEXT: %1 = "xla.add"(%arg2, %arg3) {name = "add.4"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK-NEXT: %1 = "xla_hlo.add"(%arg2, %arg3) {name = "add.4"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4)
 
   // Add a tensor and scalar
-  // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  // CHECK-NEXT: %2 = "xla_hlo.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   // CHECK-NEXT: return %2 : tensor<4xf32>
   ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir
index 4009759..a77b90c 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir
@@ -6,9 +6,9 @@
   // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
 
   // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
   // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2)
-  %1 = "xla.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %1 = "xla_hlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %1 : tensor<4xf32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt
index 1826809..25cf3ec 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %and.3 = f32[4] and(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
index 9aff639..38aa4f0 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir
@@ -8,19 +8,19 @@
   // CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
   // CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
   // CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
 
   // Broadcast up rank
   // CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
   // CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
   // CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
-  %1 = "xla.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+  %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
 
   // Broadcast up rank + degenerate broadcast
   // CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
   // CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
   // CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
   // CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
-  %2 = "xla.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+  %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
   return %2 : tensor<2x3x4xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
index 1d23153..0b64ab2 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir
@@ -4,6 +4,6 @@
 func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
   // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
   // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
-  %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
   return %0 : tensor<1x2x3x4xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
index d9c2e9f..3d520fc 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
@@ -6,14 +6,14 @@
 ENTRY %main {
   %Arg_0.1 = f32[1, 2] parameter(0)
 
-  // CHECK-NEXT: %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
   %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
 
   // Degenerate broadcast
-  // CHECK-NEXT: %1 = "xla.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
+  // CHECK-NEXT: %1 = "xla_hlo.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
   broadcast.3 = f32[3,2] broadcast(%Arg_0.1), dimensions={}
 
-  // CHECK-NEXT: %2 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
+  // CHECK-NEXT: %2 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
   // CHECK-NEXT: return %2 : tensor<3x1x2xf32>
   ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
index c7ea0f9..c6cd58f 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
@@ -5,7 +5,7 @@
 // CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
 %call (arg_1: s64[]) -> s64[] {
   %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt
index ed3019b..637629d 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt
@@ -8,14 +8,14 @@
   %Arg_1.2 = f32[3] parameter(1)
   %Arg_2.3 = f32[1] parameter(2)
 
-  // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
+  // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
   %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
 
-  // CHECK-NEXT: %1 = "xla.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
+  // CHECK-NEXT: %1 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
   %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
 
   // Requires broadcast of compatible tensors.
-  // CHECK-NEXT: %2 = "xla.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
+  // CHECK-NEXT: %2 = "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
   // CHECK-NEXT: return %2 : tensor<3xi1>
   ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt
index e73447d..b23c22b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4, 1] parameter(0)
   %Arg_1.2 = f32[4, 2] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
   // CHECK-NEXT: return %0 : tensor<4x3xf32>
   ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt
index 0de3ac6..6c5989b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt
@@ -8,10 +8,10 @@
 ENTRY %tfcompile.7 {
   %arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
 
-  // CHECK-NEXT:   %0 = "xla.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+  // CHECK-NEXT:   %0 = "xla_hlo.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="XLA_Args"}
 
-  // CHECK-NEXT:   %1 = "xla.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+  // CHECK-NEXT:   %1 = "xla_hlo.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %reshape.2 = f32[1,16,16,1]{2,1,3,0} reshape(%copy.1)
 
   // Note that double brackets "[[" have to be escaped as they denote variables
@@ -19,13 +19,13 @@
   // CHECK-NEXT:   %cst = constant  {name = "constant.3"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
   %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 
-  // CHECK-NEXT:   %2 = "xla.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32>
+  // CHECK-NEXT:   %2 = "xla_hlo.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32>
   %convolution.4 = f32[1,16,16,1]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 
-  // CHECK-NEXT:   %3 = "xla.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
+  // CHECK-NEXT:   %3 = "xla_hlo.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
   %reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="XLA_Retvals"}
 
-  // CHECK-NEXT:   %4 = "xla.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>>
+  // CHECK-NEXT:   %4 = "xla_hlo.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>>
   // CHECK-NEXT:   return %4 : tuple<tensor<1x16x16x1xf32>>
   ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="XLA_Retvals"}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt
index 3c0c7a9..f22646f 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt
@@ -7,13 +7,13 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64>
+  // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64>
   %convert.3 = f64[4] convert(f32[4] %Arg_0.1)
 
-  // CHECK-NEXT: %1 = "xla.convert"(%arg1) {name = "convert.4"} : (tensor<f32>) -> tensor<f64>
+  // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "convert.4"} : (tensor<f32>) -> tensor<f64>
   %convert.4 = f64[] convert(f32[] %Arg_1.2)
 
-  // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor<f64>) -> tensor<4xf64>
+  // CHECK-NEXT: %2 = "xla_hlo.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor<f64>) -> tensor<4xf64>
   // CHECK-NEXT: return %2 : tensor<4xf64>
   ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt
index 602ad96..772e47a 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt
index 5b7d0c6..88beb2f 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt
@@ -7,17 +7,17 @@
   %Arg_0.1 = f32[1, 4] parameter(0)
   %Arg_1.2 = f32[4, 1] parameter(1)
 
-  // CHECK-NEXT:   %0 = "xla.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
+  // CHECK-NEXT:   %0 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
   dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
 
-  // CHECK-NEXT:   %1 = "xla.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
+  // CHECK-NEXT:   %1 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
   dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
 
-  // CHECK-NEXT:   %2 = "xla.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
+  // CHECK-NEXT:   %2 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
   %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
 
   // TODO(b/129709049) consider making this default precision config inferred.
-  // CHECK-NEXT:   %3 = "xla.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
+  // CHECK-NEXT:   %3 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
   // CHECK-NEXT:   return %3 : tensor<f32>
   ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt
index d31160c..8536945 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt
@@ -9,7 +9,7 @@
   %Arg_2.3 = f32[] parameter(2)
   %Arg_3.4 = f32[] parameter(3)
 
-  // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
   // CHECK-NEXT: return %0 : tensor<4x4xf32>
   ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
 }
@@ -20,7 +20,7 @@
   %Arg_1.2 = f32[2] parameter(1)
   %Arg_2.3 = f32[] parameter(2)
 
-  // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt
index a4e5b19..fca13d7 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt
@@ -9,95 +9,95 @@
   %arg0.1 = f32[1,300] parameter(0)
   %arg1.2 = f32[1,300,3,1] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
   %reshape.3 = f32[1,300] reshape(%arg0.1)
 
-  // CHECK-NEXT: %1 = "xla.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
+  // CHECK-NEXT: %1 = "xla_hlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
   %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
 
-  // CHECK-NEXT: %2 = "xla.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
+  // CHECK-NEXT: %2 = "xla_hlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
   %reshape.28 = f32[300,1,1] reshape(%transpose.27)
 
-  // CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
+  // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
   %reshape.29 = f32[300,1] reshape(%reshape.28)
 
-  // CHECK-NEXT: %4 = "xla.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %4 = "xla_hlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
   %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
 
   // CHECK-NEXT: %cst = constant  {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
   %constant.8 = f32[] constant(1)
 
-  // CHECK-NEXT: %5 = "xla.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
   %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
 
-  // CHECK-NEXT: %6 = "xla.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %6 = "xla_hlo.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
   %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
 
   // CHECK-NEXT: %cst_0 = constant  {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
   %constant.32 = f32[] constant(0)
 
-  // CHECK-NEXT: %7 = "xla.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
   %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
 
-  // CHECK-NEXT: %8 = "xla.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
+  // CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
   %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
 
   // CHECK-NEXT: %cst_1 = constant  {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
   %constant.10 = f32[] constant(0)
 
-  // CHECK-NEXT: %9 = "xla.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
   %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
 
   // CHECK-NEXT: %cst_2 = constant  {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
   %constant.40 = f32[] constant(0)
 
-  // CHECK-NEXT: %10 = "xla.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
+  // CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
   %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
 
-  // CHECK-NEXT: %11 = "xla.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
+  // CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
   %copy.1 = f32[1,300,3,1] copy(%arg1.2)
 
-  // CHECK-NEXT: %12 = "xla.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
+  // CHECK-NEXT: %12 = "xla_hlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
   %reshape.4 = f32[1,300,3,1] reshape(%copy.1)
 
-  // CHECK-NEXT: %13 = "xla.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
+  // CHECK-NEXT: %13 = "xla_hlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
   %reshape.24 = f32[1,300,3] reshape(%reshape.4)
 
-  // CHECK-NEXT: %14 = "xla.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
+  // CHECK-NEXT: %14 = "xla_hlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
   %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
 
-  // CHECK-NEXT: %15 = "xla.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
+  // CHECK-NEXT: %15 = "xla_hlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
   %reshape.26 = f32[300,3] reshape(%transpose.25)
 
   // CHECK-NEXT: %cst_3 = constant  {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
   %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
 
   // TODO(b/129709049) consider making this default precision config implied.
-  // CHECK-NEXT: %16 = "xla.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+  // CHECK-NEXT: %16 = "xla_hlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
   %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 
   // CHECK-NEXT: %cst_4 = constant  {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
   %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
 
-  // CHECK-NEXT: %17 = "xla.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
+  // CHECK-NEXT: %17 = "xla_hlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
   %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
 
-  // CHECK-NEXT: %18 = "xla.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
+  // CHECK-NEXT: %18 = "xla_hlo.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
   %add.39 = f32[300,5] add(%dot.36, %broadcast.38)
 
-  // CHECK-NEXT: %19 = "xla.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
+  // CHECK-NEXT: %19 = "xla_hlo.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
   %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
 
-  // CHECK-NEXT: %20 = "xla.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
   %reshape.44 = f32[300,1,5] reshape(%maximum.42)
 
-  // CHECK-NEXT: %21 = "xla.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %21 = "xla_hlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
   %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
 
-  // CHECK-NEXT: %22 = "xla.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
+  // CHECK-NEXT: %22 = "xla_hlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
   %reshape.46 = f32[300,1,5] reshape(%select.45)
 
-  // CHECK-NEXT: %23 = "xla.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
+  // CHECK-NEXT: %23 = "xla_hlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
   // CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
   ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt
index 9a4944d..35c762c 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt
@@ -4,14 +4,14 @@
 
 // CHECK-LABEL: func @main() -> tensor<4xf32> {
 ENTRY %iota.1 () -> f32[4] {
-  // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %iota.0 = f32[4] iota(), iota_dimension=0
 }
 
 // CHECK-LABEL: func @iota.2() -> tensor<4x5xf32> {
 %iota.2 () -> f32[4, 5] {
-  // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
   // CHECK-NEXT: return %0 : tensor<4x5xf32>
   ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
 }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt
index dd6c0f5..f4ba76b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt
index 5efe44a..880fc0f 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt
index 1bfb666..ad7feef 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt
index 412f267..84e1fbc 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
 }
@@ -17,7 +17,7 @@
   %Arg_0.1 = f32[4, 4, 4] parameter(0)
   %Arg_1.2 = f32[] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
   // CHECK-NEXT: return %0 : tensor<7x11x15xf32>
   ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt
index 37e638e..e4dc4d5 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt
@@ -33,19 +33,19 @@
   %Arg_1.2 = f32[4] parameter(1)
   %Arg_2.3 = f32[] parameter(2)
 
-  // CHECK-NEXT: %0 = "xla.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
+  // CHECK-NEXT: %0 = "xla_hlo.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
   %reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1
 
-  // CHECK-NEXT: %1 = "xla.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %1 = "xla_hlo.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
   %reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2
 
-  // CHECK-NEXT: %2 = "xla.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK-NEXT: %2 = "xla_hlo.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
   %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2
 
-  // CHECK-NEXT: %3 = "xla.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK-NEXT: %3 = "xla_hlo.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
   %reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
 
-  // CHECK-NEXT: %4 = "xla.sub"(%2, %3) {name = "sub.5"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK-NEXT: %4 = "xla_hlo.sub"(%2, %3) {name = "sub.5"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   %sub.5 = f32[] subtract(%reduce.2, %reduce.3)
 
   ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5)
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt
index 7c8303d..f89f3eb 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt
@@ -6,7 +6,7 @@
 ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] {
   %Arg_0.1 = f32[4] parameter(0)
 
-  // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
 }
@@ -15,7 +15,7 @@
 %reverse.2 (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
   %Arg_0.1 = f32[4, 4] parameter(0)
 
-  // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
   // CHECK-NEXT: return %0 : tensor<4x4xf32>
   ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt
index b9ae08d..d3fe6a5 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt
@@ -8,7 +8,7 @@
   %Arg_1.2 = s32[2,3] parameter(1)
   %Arg_2.3 = s32[2,3] parameter(2)
 
-  // CHECK-NEXT: %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NEXT: %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   // CHECK-NEXT: return %0 : tensor<2x3xi32>
   ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir
index 4990ae7..f00aa0a 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir
@@ -7,7 +7,7 @@
   // CHECK-NEXT: %Arg_2.3 = s32[2,3] parameter(2)
 
   // CHECK-NEXT: ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
-  %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %0 : tensor<2x3xi32>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo
index 83d85f7..5d35859 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo
+++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo
@@ -139,8 +139,8 @@
 }
 
 # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
-# CHECK-NEXT:   %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+# CHECK-NEXT:   %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 # TODO(b/129709049) consider making this default precision config inferred.
-# CHECK-NEXT:   %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
+# CHECK-NEXT:   %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
 # CHECK-NEXT:   return %1 : tensor<f32>
 # CHECK-NEXT: }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt
index 0946262..b3f8e97 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt
@@ -7,11 +7,11 @@
   %Arg_0.1 = f32[4]{0} parameter(0)
   %Arg_1.2 = f32[4]{0} parameter(1)
 
-  // CHECK-NEXT:   %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
 
   // TODO(b/129709049) consider making this default precision config inferred.
-  // CHECK-NEXT:   %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
+  // CHECK-NEXT:   %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
   // CHECK-NEXT:   return %1 : tensor<f32>
   ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir
index f6e277c..e68262b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir
@@ -2,8 +2,8 @@
 
 func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
-  %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  %1 = "xla.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %1 = "xla_hlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %1 : tensor<4xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt
index 6fc493a..24d4dff 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt
@@ -7,7 +7,7 @@
   %Arg_0.1 = f32[4] parameter(0)
   %Arg_1.2 = f32[4] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt
index 54dc0fa..806ab79 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt
@@ -6,7 +6,7 @@
 ENTRY %foo (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
   %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
 
-  // CHECK-NEXT: %0 = "xla.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
   // CHECK-NEXT: return %0 : tensor<1x16x16x3xf32>
   ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt
index 335e546..203152d 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt
@@ -6,7 +6,7 @@
 ENTRY %main {
   %Arg_0.1 = s32[1,2,3,4] parameter(0)
 
-  // CHECK-NEXT: %0 = "xla.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  // CHECK-NEXT: %0 = "xla_hlo.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   // CHECK-NEXT: return %0 : tensor<2x1x4x3xi32>
   ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
index e28d0a3..77048e6 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
@@ -5,7 +5,7 @@
   // CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0)
 
   // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
-  %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
   return %0 : tensor<2x1x4x3xi32>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt
index c98fa93..bcaf1c8 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt
@@ -7,10 +7,10 @@
   %Arg_0.1 = s32[1] parameter(0)
   %Arg_1.2 = f32[1, 2] parameter(1)
 
-  // CHECK-NEXT: %0 = "xla.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
+  // CHECK-NEXT: %0 = "xla_hlo.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
   %tuple.3 = (s32[1]) tuple(%Arg_0.1)
 
-  // CHECK-NEXT: %1 = "xla.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
+  // CHECK-NEXT: %1 = "xla_hlo.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
   // CHECK-NEXT: return %1 : tuple<tensor<1xi32>, tensor<1x2xf32>>
   ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt
index 42d52fd..daf7dd8 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt
@@ -6,6 +6,6 @@
 ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[1] {
   %Arg_0.1 = f32[1] parameter(0)
 
-  // CHECK-NEXT: %0 = "xla.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   ROOT add-dependency.2 = f32[1] add-dependency(Arg_0.1, Arg_0.1)
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt
index a6d2a48..f7ab195 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt
@@ -5,7 +5,7 @@
 // CHECK-LABEL: func @cond(%arg0: tensor<i64>) -> tensor<i1> {
 %cond (arg_1: s64[]) -> pred[] {
   %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
-  // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
   // CHECK-NEXT: return %0 : tensor<i1>
   ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"}
 }
@@ -13,7 +13,7 @@
 // CHECK-LABEL: func @loop(%arg0: tensor<i64>) -> tensor<i64> {
 %loop (arg_1: s64[]) -> s64[] {
   %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
 }
@@ -21,7 +21,7 @@
 // CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
 ENTRY %foo (arg0.1: s64[]) -> s64[] {
   %arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
-  // CHECK-NEXT: %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
+  // CHECK-NEXT: %0 = "xla_hlo.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
index cf271f4..1974645 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
@@ -44,7 +44,7 @@
   // to below:
   //
   //   <prior operations>
-  //   %0 = "xla.while"(%arg0) {body: @loop, cond: @cond}
+  //   %0 = "xla_hlo.while"(%arg0) {body: @loop, cond: @cond}
   //   <post operations>
   auto* opInst = while_op.getOperation();
   mlir::OpBuilder builder(while_op);
@@ -150,6 +150,10 @@
 }  // namespace XLA
 }  // namespace mlir
 
+mlir::FunctionPassBase* mlir::XLA::createLegalizeControlFlowPass() {
+  return new LegalizeControlFlow();
+}
+
 static PassRegistration<mlir::XLA::LegalizeControlFlow> legalize_cf_pass(
     "xla-legalize-control-flow",
     "Legalize from XLA control flow to MLIR control flow");
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index a10329c..d0452cd 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -20,6 +20,7 @@
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 
@@ -141,7 +142,7 @@
 
   // Add the generated patterns to the list.
   XLA::populateWithGenerated(func.getContext(), &patterns);
-  applyPatternsGreedily(func, std::move(patterns));
+  applyPatternsGreedily(func, patterns);
 }
 
 static PassRegistration<LegalizeTF> pass(
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 7835fcf..c1518bc 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -22,6 +22,9 @@
 
 def NullElementsAttr : NativeCodeCall<"ElementsAttr()">;
 
+def HasNoUse: Constraint<
+    CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
+
 //===----------------------------------------------------------------------===//
 // BatchNorm op patterns.
 //===----------------------------------------------------------------------===//
@@ -30,17 +33,19 @@
     "getFeatureDimensionAttr($_builder, $0, $1)">;
 def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
 
-def : Pattern<(TF_FusedBatchNormOp F32Tensor:$x, F32Tensor:$scale,
-                               F32Tensor:$offset, F32Tensor:$mean,
-                               F32Tensor:$variance, F32Attr:$epsilon,
+def : Pattern<
+    (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
                                $data_format, FalseBoolAttr:$is_training),
-           [(XLA_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance,
-               $epsilon, (FeatureDimension $data_format, $x)),
-            /*batch_mean=*/(verifyUnusedValue),
-            /*batch_variance=*/(verifyUnusedValue),
-            /*reserve_space_1=*/(verifyUnusedValue),
-            /*reserve_space_2=*/(verifyUnusedValue)
-           ]>;
+    [(XLA_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance,
+                               $epsilon, (FeatureDimension $data_format, $x)),
+     // We already guaranteed that the last four results has no use so it
+     // does not matter what value we provide here for replacement.
+     /*batch_mean=*/(replaceWithValue $x),
+     /*batch_variance=*/(replaceWithValue $x),
+     /*reserve_space_1=*/(replaceWithValue $x),
+     /*reserve_space_2=*/(replaceWithValue $x)],
+    [(HasNoUse $root__1), (HasNoUse $root__2),
+     (HasNoUse $root__3), (HasNoUse $root__4)]>;
 
 //===----------------------------------------------------------------------===//
 // Bias op patterns.
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
index 4ac42d3..0db05b3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
@@ -36,7 +36,7 @@
 
 struct CompareIConvert : public RewritePattern {
   explicit CompareIConvert(MLIRContext *context)
-      : RewritePattern("xla.compare", 1, context) {}
+      : RewritePattern("xla_hlo.compare", 1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
@@ -75,7 +75,7 @@
 
 struct CompareFConvert : public RewritePattern {
   explicit CompareFConvert(MLIRContext *context)
-      : RewritePattern("xla.compare", 1, context) {}
+      : RewritePattern("xla_hlo.compare", 1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
@@ -133,11 +133,9 @@
   auto func = getFunction();
 
   mlir::XLA::populateWithGenerated(func.getContext(), &patterns);
-  patterns.push_back(
-      llvm::make_unique<mlir::XLA::CompareFConvert>(&getContext()));
-  patterns.push_back(
-      llvm::make_unique<mlir::XLA::CompareIConvert>(&getContext()));
-  applyPatternsGreedily(func, std::move(patterns));
+  patterns.insert<mlir::XLA::CompareFConvert, mlir::XLA::CompareIConvert>(
+      &getContext());
+  applyPatternsGreedily(func, patterns);
 }
 
 static PassRegistration<LegalizeToStandard> legalize_pass(
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index 2ed0453..00cb50d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -24,7 +24,10 @@
 /// Lowers from TF dialect to XLA dialect.
 FunctionPassBase *createLegalizeTFPass();
 
-// Lowers from XLA dialect to Standard dialect.
+/// Lowers XLA control flow ops to the Standard dialect.
+FunctionPassBase *createLegalizeControlFlowPass();
+
+/// Lowers from XLA dialect to Standard dialect.
 FunctionPassBase *createLegalizeToStdPass();
 
 }  // end namespace XLA
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index d39d159..e73a461 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -4,7 +4,7 @@
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
@@ -1192,6 +1192,7 @@
         "//tensorflow/python:framework",
         "//tensorflow/python:math_ops",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 cuda_py_test(
@@ -1217,11 +1218,12 @@
         "nogpu",
         "no_cuda_on_cpu_tap",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 cuda_py_test(
     name = "dense_layer_test",
-    size = "small",
+    size = "medium",
     srcs = ["dense_layer_test.py"],
     additional_deps = [
         ":test_utils",
@@ -1232,6 +1234,7 @@
         "//tensorflow/python:layers",
         "//tensorflow/python:variables",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 cc_library(
@@ -1317,6 +1320,7 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:variables",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 # An example of ahead-of-time compilation using tfcompile.  The
@@ -1395,3 +1399,19 @@
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+tf_xla_py_test(
+    name = "conv_node_name_test",
+    size = "medium",
+    srcs = ["conv_node_name_test.py"],
+    shard_count = 5,
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:nn",
+        "//tensorflow/python:nn_ops",
+        "//tensorflow/python:nn_ops_gen",
+        "//tensorflow/python:platform_test",
+    ],
+)
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
index 369d009..e08435b 100644
--- a/tensorflow/compiler/tests/adagrad_da_test.py
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -56,9 +56,9 @@
         # Run a step of AdagradDA
         update.run()
 
-        # Let g to be gradient accumulator, gg to be gradient squared
-        # accumulator, T be the global step, lr is the learning rate, and k the
-        # initial gradient squared accumulator value.
+        # Let g be the gradient accumulator, gg be the gradient squared
+        # accumulator, T be the global step, lr be the learning rate,
+        # and k the initial gradient squared accumulator value.
         # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})}
         # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
         # similarly for others.
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 96d389a..a3b17e4 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -3,7 +3,7 @@
 load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
 load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
     "tf_exec_compatible_with",
 )
diff --git a/tensorflow/compiler/tests/conv_node_name_test.py b/tensorflow/compiler/tests/conv_node_name_test.py
new file mode 100644
index 0000000..85e8bce
--- /dev/null
+++ b/tensorflow/compiler/tests/conv_node_name_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Convolution node name match via the XLA JIT.
+
+The canned results in these tests are created by running each test using the
+Tensorflow CPU device and saving the output.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import googletest
+
+
+class ConvolutionNodeNameTest(xla_test.XLATestCase):
+  """Verify convolution node name match.
+
+  Verify convolution node names on TPU and CPU match with dilation > 1.
+  """
+
+  def _verifyNodeNameMatch(self, layer, input_sizes, filter_sizes, strides,
+                           dilations):
+
+    def _GetNodeNames(use_xla):
+      with self.session():
+        input_tensor = array_ops.placeholder(np.float32, shape=input_sizes)
+
+        if use_xla:
+          with self.test_scope():
+            # pylint: disable=protected-access
+            graph = ops.get_default_graph()
+            graph._set_control_flow_context(
+                control_flow_ops.XLAControlFlowContext())
+            # pylint: enable=protected-access
+            conv2d_op = layer(
+                filters=64,
+                kernel_size=filter_sizes,
+                dilation_rate=dilations,
+                padding="same")
+            _ = conv2d_op(input_tensor)
+            return [n.name for n in ops.get_default_graph().as_graph_def().node]
+        else:
+          with ops.device("CPU"):
+            conv2d_op = layer(
+                filters=64,
+                kernel_size=filter_sizes,
+                dilation_rate=dilations,
+                padding="same")
+            _ = conv2d_op(input_tensor)
+            names = [
+                n.name for n in ops.get_default_graph().as_graph_def().node
+            ]
+            # filter out space to depth ops.
+            return [
+                name for name in names
+                if "space" not in name and "Space" not in name
+            ]
+
+    xla_names = _GetNodeNames(use_xla=True)
+    no_xla_names = _GetNodeNames(use_xla=False)
+    self.assertListEqual(
+        xla_names,
+        no_xla_names,
+    )
+
+  def testConv1DNodeNameMatch(self):
+    input_sizes = [8, 16, 3]
+    filter_sizes = [7]
+    strides = 1
+    dilations = [2]
+    layer = layers.Conv1D
+    self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides,
+                              dilations)
+
+  def testConv2DNodeNameMatch(self):
+    input_sizes = [8, 16, 16, 3]
+    filter_sizes = [7, 7]
+    strides = 1
+    dilations = [2, 2]
+    layer = layers.Conv2D
+    self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides,
+                              dilations)
+
+  def testConv3DNodeNameMatch(self):
+    input_sizes = [8, 16, 16, 16, 3]
+    filter_sizes = [7, 7, 7]
+    strides = 1
+    dilations = [2, 2, 2]
+    layer = layers.Conv3D
+    self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides,
+                              dilations)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index c55bc23..a49985f 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -25,6 +25,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -87,6 +88,32 @@
     yield i, f, o, s, p
 
 
+def ConfigsWithDilationsToTest():
+  """Iterator for different convolution shapes, strides and paddings.
+
+  Yields:
+    Tuple (input_size, filter_size, out_size, stride, dilation, padding), the
+    depthwise
+    convolution parameters.
+  """
+  input_sizes = [[4, 6, 6, 48], [4, 8, 8, 84], [4, 36, 36, 2], [4, 148, 148, 2],
+                 [3, 300, 300, 3]]
+  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [5, 5, 2, 1], [4, 4, 2, 8],
+                  [2, 2, 3, 8]]
+  out_sizes = [[4, 6, 6, 96], [4, 8, 8, 84], [4, 36, 36, 2], [4, 74, 74, 16],
+               [3, 296, 296, 24]]
+  strides = [1, 1, 2, 2, 1]
+  dilations = [2, 2, 4, 2, 4]
+  # pylint: disable=invalid-name
+  VALID = "VALID"
+  SAME = "SAME"
+  # pylint: enable=invalid-name
+  paddings = [SAME, SAME, SAME, SAME, VALID]
+  for i, f, o, s, d, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+                              dilations, paddings):
+    yield i, f, o, s, d, p
+
+
 def CheckGradConfigsToTest():
   """Iterator for different convolution shapes, strides and paddings.
 
@@ -315,6 +342,118 @@
         padding="VALID",
         expected=expected_output)
 
+  # This is testing that depthwise_conv2d with dilation produces
+  # the same results between CPU and TPU. It also tests that NCHW
+  # and NWHC formats agree.
+  def _VerifyValuesWithDilation(self,
+                                tensor_in_sizes,
+                                filter_in_sizes,
+                                stride,
+                                dilation,
+                                padding,
+                                data_type,
+                                data_format="NHWC"):
+    """Verifies the output values of the convolution function.
+
+    Args:
+      tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
+        input_cols, input_depth].
+      filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
+        input_depth, depth_multiplier].
+      stride: Stride.
+      dilation: Dilation.
+      padding: Padding type.
+      data_type: The data type to use.
+      data_format: The data_format of the input. "NHWC" or "NCHW".
+    """
+    total_size_1 = 1
+    total_size_2 = 1
+    for s in tensor_in_sizes:
+      total_size_1 *= s
+    for s in filter_in_sizes:
+      total_size_2 *= s
+    # Initializes the input and filter tensor with numbers incrementing from 1.
+    x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
+                  dtype=data_type).reshape(tensor_in_sizes)
+    x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
+                  dtype=data_type).reshape(filter_in_sizes)
+    with self.session() as sess:
+      if data_type == np.float32:
+        # TODO(b/64210055): Tolerance for TPU is high.
+        tolerance = 1e-2
+      else:
+        self.assertEqual(data_type, np.float64)
+        tolerance = 1e-8
+
+      t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
+      t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)
+
+      native_t1 = t1
+      strides = [1, stride, stride, 1]
+      dilations = [dilation, dilation]
+      if data_format == "NCHW":
+        # Transpose from NWHC input to NCHW
+        # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+        native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
+        strides = [1, 1, stride, stride]
+
+      with self.test_scope():
+        conv_native = nn_impl.depthwise_conv2d(
+            native_t1,
+            t2,
+            strides=strides,
+            rate=dilations,
+            data_format=data_format,
+            padding=padding)
+
+      if data_format == "NCHW":
+        # Transpose back from NCHW to NHWC
+        conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
+
+      with ops.device("CPU"):
+        # CPU only support NHWC format
+        strides = [1, stride, stride, 1]
+        conv_interface = nn_impl.depthwise_conv2d(
+            t1, t2, strides=strides, rate=dilations, padding=padding)
+
+      native_result = sess.run(conv_native, {t1: x1, t2: x2})
+      interface_result = sess.run(conv_interface, {t1: x1, t2: x2})
+
+    print("data_type:", data_type, "max diff = ",
+          np.amax(np.absolute(native_result - interface_result)))
+    self.assertAllClose(
+        np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)
+
+  def testDilationDepthwiseConv2DWith(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2D,", index, "th config:", input_size,
+            "*", filter_size, "stride:", stride, "dilation: ", dilation,
+            "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(input_size, filter_size, stride,
+                                         dilation, padding, data_type)
+
+  def testDilationDepthwiseConv2DWithFormat(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFormat,", index, "th config:",
+            input_size, "*", filter_size, "stride:", stride, "dilation:",
+            dilation, "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(
+              input_size,
+              filter_size,
+              stride,
+              dilation,
+              padding,
+              data_type,
+              data_format="NCHW")
+
   def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
                             stride, padding):
     x1 = np.random.rand(*filter_sizes).astype(np.float32)
@@ -420,5 +559,139 @@
           padding,
           data_format="NCHW")
 
+  def _CompareBackpropInputWithDilation(self, input_sizes, filter_sizes,
+                                        output_sizes, stride, dilation,
+                                        padding):
+    x1 = np.random.rand(*filter_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        if use_xla:
+          with self.test_scope():
+            t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
+            backprop = nn_ops.depthwise_conv2d_native_backprop_input(
+                t0,
+                t1,
+                t2,
+                strides=[1, stride, stride, 1],
+                dilations=[1, dilation, dilation, 1],
+                padding=padding)
+        else:
+          # TODO(wangtao): figure out gradient with stride > 1.
+          # depthwise_conv2d_native_backprop_input on CPU doesn't support
+          # dilation.
+          t3 = array_ops.space_to_batch(
+              t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          input_sizes_transform = [
+              input_sizes[0] * dilation * dilation, input_sizes[1] // dilation,
+              input_sizes[2] // dilation, input_sizes[3]
+          ]
+          t0 = constant_op.constant(
+              input_sizes_transform, shape=[len(input_sizes)])
+          backprop_naive = nn_ops.depthwise_conv2d_native_backprop_input(
+              t0, t1, t3, strides=[1, stride, stride, 1], padding=padding)
+          backprop = array_ops.batch_to_space(
+              backprop_naive, [[0, 0], [0, 0]], block_size=dilation)
+
+        ret = backprop.eval({t1: x1, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+
+    # TODO (b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-2, atol=1e-3)
+
+  def testDilationDepthwiseConv2DInputGradWithCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DInputGradWithDilationCompare,",
+            index, "th config:", input_size, "*", filter_size, "stride:",
+            stride, "dilation:", dilation, "padding:", padding)
+      # TODO(wangtao): implement CPU grad computation with stride > 1.
+      if stride == 1:
+        self._CompareBackpropInputWithDilation(input_size, filter_size,
+                                               output_size, stride, dilation,
+                                               padding)
+
+  def _CompareBackpropFilterWithDilation(self,
+                                         input_sizes,
+                                         filter_sizes,
+                                         output_sizes,
+                                         stride,
+                                         dilation,
+                                         padding,
+                                         data_format="NHWC"):
+    x0 = np.random.rand(*input_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t0 = array_ops.placeholder(np.float32, shape=input_sizes)
+        t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        native_t0 = t0
+        native_t2 = t2
+        strides = [1, stride, stride, 1]
+        dilations = [1, dilation, dilation, 1]
+
+        if use_xla:
+          if data_format == "NCHW":
+            # Transpose from NWHC input to NCHW
+            # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+            native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
+            native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
+            strides = [1, 1, stride, stride]
+            dilations = [1, 1, dilation, dilation]
+          with self.test_scope():
+            backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+                native_t0,
+                t1,
+                native_t2,
+                strides=strides,
+                padding=padding,
+                dilations=dilations,
+                data_format=data_format)
+        else:
+          # For CPU, the format NCHW is not supported. Therefore we always use
+          # NHWC here.
+          # depthwise_conv2d_native_backprop_filter on CPU doesn't support
+          # dilation.
+          native_t3 = array_ops.space_to_batch(
+              native_t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          native_t0_transform = array_ops.space_to_batch(
+              native_t0, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+              native_t0_transform,
+              t1,
+              native_t3,
+              strides=strides,
+              padding=padding)
+        ret = backprop.eval({t0: x0, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+    # TODO(b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-4)
+
+  def testDilationDepthwiseConv2DFilterGradCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFilterGradCompare,", index,
+            "th config:", input_size, "*", filter_size, "producing output",
+            output_size, "stride:", stride, "dilation:", dilation, "padding:",
+            padding)
+      if stride == 1:
+        # TODO(wangtao): implement CPU grad computation with stride > 1.
+        self._CompareBackpropFilterWithDilation(input_size, filter_size,
+                                                output_size, stride, dilation,
+                                                padding)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index d2c459b..a03980f 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -693,8 +693,7 @@
         return x, y
 
       wholly_compiled_f = def_function.function(f)
-      op_by_op_f = function.defun_with_attributes(
-          f, attributes={'_XlaCompile': False})
+      op_by_op_f = def_function.function(f, experimental_compile=False)
 
       x = constant_op.constant([0.0, 2.0], name='data')
 
diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py
index a994be8..6437c27 100644
--- a/tensorflow/compiler/tests/matrix_diag_ops_test.py
+++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py
@@ -328,7 +328,7 @@
   # From here onwards are v2-only tests.
   def testSquare(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for _, tests in [square_cases()]:
         for diag_index, (vecs, solution) in tests.items():
@@ -340,7 +340,7 @@
 
   def testSquareBatch(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for _, tests in [square_cases()]:
         for diag_index, (vecs, solution) in tests.items():
@@ -352,7 +352,7 @@
 
   def testRectangularBatch(self):
     # LINT.IfChange
-    if not compat.forward_compatible(2019, 7, 31):
+    if not compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       return
 
@@ -422,7 +422,7 @@
 
   def testPadding(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for padding_value in [555, -11]:
         for _, tests in [square_cases(), tall_cases(), fat_cases()]:
@@ -543,7 +543,7 @@
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for _, tests in [square_cases(), tall_cases(), fat_cases()]:
         for diag_index, (vecs, banded_mat) in tests.items():
@@ -559,7 +559,7 @@
 
   def testBatch(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for _, tests in [square_cases(), tall_cases(), fat_cases()]:
         for diag_index, (vecs, banded_mat) in tests.items():
@@ -614,7 +614,7 @@
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
         for diag_index, (solution, _) in tests.items():
@@ -625,7 +625,7 @@
 
   def testBatch(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
         for diag_index, (solution, _) in tests.items():
@@ -636,7 +636,7 @@
 
   def testPadding(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       for padding_value in [555, -11]:
         for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py
index a54cd60..343969c 100644
--- a/tensorflow/compiler/tests/stateful_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateful_random_ops_test.py
@@ -278,10 +278,11 @@
       maxval = 1
       if dtype.is_integer:
         maxval = 100
-      x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
+      t = gen.uniform(shape=[n], maxval=maxval, dtype=dtype)
+      x = t.numpy().astype(float)
       if maxval > 1:
         # Normalize y to range [0, 1).
-        x = x.astype(float) / maxval
+        x = x / maxval
       # Tests that the values are distributed amongst 10 bins with equal
       # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
       # p=0.05. This test is probabilistic and would be flaky if the random
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 64af33c..349dabb 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -297,11 +297,12 @@
 
       self._assertOpOutputMatchesExpected(
           math_ops.tanh,
-          np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
-          expected=np.array(
-              [[0.76159418, 0.76159418, 0.76159418, 0.76159418],
-               [0.76159418, 0.96402758, 0.99505478, 0.99932933]],
-              dtype=dtype))
+          np.array(
+              [[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], [19, -19, 22, -22]],
+              dtype=dtype),
+          expected=np.array([[0.76159418, 0.96402758, 0.99505478, 0.99932933],
+                             [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
+                            dtype=dtype))
 
       self._assertOpOutputMatchesExpected(
           nn_ops.log_softmax,
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index fee4d8a..3de09b2 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -17,7 +17,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_all_protos",
     "tf_proto_library",
 )
@@ -26,7 +26,7 @@
 
 # NOTE: we always assume that if_static returns "otherwise" list in open source.
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_static",
 )
 
@@ -97,10 +97,17 @@
         ":utils",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@local_config_cuda//cuda:cuda_headers",
+        "//tensorflow/core:core_cpu_lib_no_ops",
+        "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core:stream_executor",
         "//tensorflow/core:stream_executor_headers_lib",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/stream_executor/lib",
     ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
     alwayslink = 1,
 )
@@ -168,8 +175,12 @@
         ":trt_op_kernels",
         ":trt_op_libs",
         ":trt_resources",
+        ":trt_conversion",
+        ":utils",
         "@com_google_googletest//:gtest",
+        "@com_google_absl//absl/strings",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core:framework",
@@ -248,6 +259,9 @@
         ":utils",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:framework_lite",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib_proto_parsing",
     ] + if_tensorrt([":tensorrt_lib"]),
 )
@@ -318,11 +332,13 @@
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:devices",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
index d5004af..e523992 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
@@ -324,8 +324,6 @@
                      nvinfer1::IGpuAllocator* alloc,
                      std::vector<Node*>* engine_nodes) {
   const auto& info = infos.at(pos);
-  std::vector<TensorShapeProto> output_shape_protos;
-  std::vector<TensorShapeProto> input_shape_protos;
   std::vector<PartialTensorShape> input_shapes;
   std::vector<NodeDefBuilder::NodeOut> inputs;
   std::vector<Node*> input_nodes;
@@ -359,25 +357,16 @@
     } else {
       // Data edges
       if (!conn.is_input_edge) {
-        // Set the shapes and data types of output edge.
-        TensorShapeProto out_shape;
-        // shape of the output node inside segment
-        conn.inside_shape.AsProto(&out_shape);
-        if (output_shape_protos.size() <= conn.port_number) {
-          output_shape_protos.resize(conn.port_number + 1);
+        // Set the data types of output edge.
+        if (out_types.size() <= conn.port_number) {
           out_types.resize(conn.port_number + 1);
         }
-        output_shape_protos.at(conn.port_number) = out_shape;
         out_types.at(conn.port_number) = conn.connection_type;
       } else {
         // Set the shapes and data types of input edge.
-        TensorShapeProto in_shape;
-        conn.outside_shape.AsProto(&in_shape);
-        if (input_shape_protos.size() <= conn.port_number) {
-          input_shape_protos.resize(conn.port_number + 1);
+        if (input_shapes.size() <= conn.port_number) {
           input_shapes.resize(conn.port_number + 1);
         }
-        input_shape_protos.at(conn.port_number) = in_shape;
         input_shapes.at(conn.port_number) = conn.outside_shape;
         // Shape must be fully defined (excluding batch dimension) for static
         // mode.
@@ -439,8 +428,6 @@
     TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
     segment_string = string(static_cast<const char*>(engine_data->data()),
                             engine_data->size());
-  } else {
-    segment_string = info.segment_graph_def.SerializeAsString();
   }
 
   string prec_string;
@@ -460,15 +447,13 @@
   }
 
   NodeDef trt_node;
+  NameAttrList function;
+  function.set_name(StrCat(info.engine_name, "_native_segment"));
   Status status =
-      node_builder.Attr("input_shapes", input_shape_protos)
-          .Attr("output_shapes", output_shape_protos)
+      node_builder
           .Attr("static_engine",
                 info.engine_type == EngineInfo::EngineType::TRTStatic)
-          .Attr("segment_funcdef_name",
-                params.use_function_backup
-                    ? StrCat(info.engine_name, "_native_segment")
-                    : "")
+          .Attr("segment_func", function)
           .Attr("serialized_segment", segment_string)
           .Attr("calibration_data", "")
           .Attr("max_cached_engines_count", info.maximum_cached_engines)
@@ -537,103 +522,27 @@
   return Status::OK();
 }
 
-// Function to construct a funcdef from the segment and add it to the graph.
-Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
-                                                const GraphDef& segment,
-                                                const string& engine_name) {
-  Graph sgraph(graph->flib_def());
-  GraphConstructorOptions gcopts;
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
-  std::map<string, Node*> io_nodes;
-  int num_inputs = 0;
-  for (auto n : sgraph.op_nodes()) {
-    if (absl::StartsWith(n->name(), kInputPHName)) {
-      num_inputs++;
-      io_nodes.insert({n->name(), n});
-    } else if (absl::StartsWith(n->name(), kOutputPHName)) {
-      io_nodes.insert({n->name(), n});
-    }
-  }
-
-  for (int i = 0; i < num_inputs; ++i) {
-    auto name = StrCat(kInputPHName, i);
-    auto node = io_nodes[name];
-    NodeDef nd;
-    NodeDefBuilder node_builder(StrCat(name, "_Arg"),
-                                FunctionLibraryDefinition::kArgOp);
-    VLOG(1) << "Adding " << StrCat(name, "_Arg");
-    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
-                           .Attr("index", i)
-                           .Finalize(&nd));
-    Status s;
-    auto node_arg = sgraph.AddNode(nd, &s);
-    if (!s.ok()) {
-      LOG(ERROR) << "Couldn't add _Arg node for " << name;
-    }
-    for (auto edge : node->out_edges()) {
-      sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
-      VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
-              << " - > " << edge->dst()->name() << ":" << edge->dst_input();
-      if (!s.ok()) {
-        LOG(ERROR) << "Failed to update edge from " << node_arg->name()
-                   << " to " << edge->dst()->name() << ":" << edge->dst_input();
-      }
-    }
-    sgraph.RemoveNode(node);
-  }
-
-  for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
-    auto name = StrCat(kOutputPHName, i);
-    auto node = io_nodes[name];
-    NodeDef nd;
-    NodeDefBuilder node_builder(StrCat(name, "_Ret"),
-                                FunctionLibraryDefinition::kRetOp);
-    auto edge = *(node->in_edges().begin());
-    NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
-                                 edge->src()->output_type(edge->src_output()));
-    VLOG(1) << " input " << nout.node << ":" << nout.index
-            << " dtype=" << DataTypeString(nout.data_type);
-    // nvcc complains that Input(<brace-enclosed initializer list>) is
-    // ambiguous, so do not use Input({nout}).
-    node_builder.Input(nout);
-    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
-                           .Attr("index", i)
-                           .Finalize(&nd));
-    if (VLOG_IS_ON(3)) {
-      VLOG(3) << nd.DebugString();
-    }
-    Status s;
-    auto node_ret = sgraph.AddNode(nd, &s);
-    if (!s.ok()) {
-      LOG(ERROR) << "Couldn't add _Ret node for " << name;
-    }
-    VLOG(1) << "Update edge from " << edge->src()->name() << ":"
-            << edge->src_output() << " - > " << node_ret->name() << ":" << 0;
-    sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
-    s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
-    if (!s.ok()) {
-      LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
-                 << edge->src_output() << " - > " << node_ret->name() << ":"
-                 << 0;
-    }
-    sgraph.RemoveNode(node);
-  }
-  FunctionDefLibrary fdeflib;
-  auto native_segment = fdeflib.add_function();
+Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
+                                      Graph* graph, const string& engine_name) {
+  Graph segment_graph(graph->flib_def());
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
+                                            segment_graph_def, &segment_graph));
+  FunctionDefLibrary library;
+  auto segment_func = library.add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
-      sgraph, StrCat(engine_name, "_native_segment"), native_segment));
+      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
   // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
   // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
   // would be on host if the op generating the tensor has host memory tag set.
-  (*native_segment
-        ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
+  (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
       .set_b(true);
   if (VLOG_IS_ON(7)) {
     VLOG(7) << engine_name << " Function_Def ";
-    VLOG(7) << native_segment->DebugString();
+    VLOG(7) << segment_func->DebugString();
   }
-  VLOG(1) << "Adding funcdef to graphlib";
-  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
+  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
+          << " to graphlib";
+  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
   return Status::OK();
 }
 
@@ -690,16 +599,10 @@
 // Entry function from optimization pass.
 Status ConvertAfterShapes(const ConversionParams& params) {
   // Sanity checks.
-  if (params.precision_mode == TrtPrecisionMode::INT8) {
-    if (params.use_calibration && !params.use_function_backup) {
-      return errors::InvalidArgument(
-          "Calibration requires enabling fallback to TF function execution.");
-    }
-  } else {
-    if (params.use_calibration) {
-      return errors::InvalidArgument(
-          "Calibration with FP32 or FP16 is not supported.");
-    }
+  if (params.precision_mode != TrtPrecisionMode::INT8 &&
+      params.use_calibration) {
+    return errors::InvalidArgument(
+        "Calibration with FP32 or FP16 is not supported.");
   }
 
   // Convert graphdef to graph.
@@ -760,14 +663,14 @@
                                    : EngineInfo::EngineType::TRTStatic);
     curr_engine.use_calibration = params.use_calibration;
     curr_engine.maximum_cached_engines = params.max_cached_engines;
-    if (params.use_function_backup) {
-      status = RegisterSegmentFunctionToFunctionLibrary(
-          &graph, curr_engine.segment_graph_def, curr_engine.engine_name);
-      if (!status.ok()) {
-        LOG(WARNING) << "Failed to register segment graphdef as a function "
-                     << t << ": " << status;
-        continue;
-      }
+
+    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
+                                            &graph, curr_engine.engine_name);
+
+    if (!status.ok()) {
+      LOG(WARNING) << "Failed to register segment graphdef to the library " << t
+                   << ": " << status;
+      continue;
     }
 
     engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
index d7f1df5..9288829 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -46,8 +47,6 @@
   // maximum number of cached engines
   int max_cached_engines = 1;
   bool use_calibration = true;
-  // Whether to use function fallback for TRTEngineOp
-  bool use_function_backup = true;
 };
 
 // Method to call from optimization pass
@@ -57,6 +56,11 @@
 std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                  const EngineInfo& engine);
 
+// Helper method that registers `segment_graph` as a function to the function
+// library in `graph`.
+Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
+                                      Graph* graph, const string& engine_name);
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 3d223d7..851c3df 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -39,6 +39,7 @@
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -75,18 +76,15 @@
 
 namespace tensorflow {
 namespace tensorrt {
-// TODO(aaroey): put these constants into some class.
-const char* const kInputPHName = "TensorRTInputPH_";
-const char* const kOutputPHName = "TensorRTOutputPH_";
+namespace convert {
 
 bool IsEngineInput(absl::string_view name) {
-  return absl::StartsWith(name, kInputPHName);
+  return absl::StartsWith(name, IONamePrefixes::kInputPHName);
 }
 bool IsEngineOutput(absl::string_view name) {
-  return absl::StartsWith(name, kOutputPHName);
+  return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
 }
 
-namespace convert {
 using absl::StrAppend;
 using absl::StrCat;
 
@@ -3907,6 +3905,7 @@
       *tensor, pre_padding, post_padding);
   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+  params->converter->MarkQuantizationRangesAsInferrable(tensor, output_tensor);
 
   if (!legit_pad) {
     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
@@ -5200,19 +5199,33 @@
   for (const auto& node_def : gdef.node()) {
     string node_name = node_def.name();
     VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
-    if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
+    if (IsEngineInput(node_name)) {
       int32 slot_number = -1;
-      if (!strings::safe_strto32(  // non-absl ok
-              node_name.c_str() + strlen(kInputPHName), &slot_number)) {
-        return errors::InvalidArgument("Failed to parse slot number from ",
-                                       node_name);
+      string type_key;
+      if (node_def.op() == "Placeholder") {
+        if (!strings::safe_strto32(  // non-absl ok
+                node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
+                &slot_number)) {
+          return errors::InvalidArgument("Failed to parse slot number from ",
+                                         node_name);
+        }
+        type_key = "dtype";
+      } else if (tensorflow::grappler::IsArg(node_def)) {
+        // Maybe remove the dependence on grappler and re-implement IsArg,
+        // which is pretty simple (but could change if new Arg nodes are added)
+        slot_number = node_def.attr().at("index").i();
+        type_key = "T";
+      } else {
+        return errors::InvalidArgument(
+            "Node ", node_name,
+            " with is neither Placeholder nor Arg, instead ", node_def.op());
       }
       nvinfer1::DataType trt_dtype;
       nvinfer1::Dims trt_dims;
       int batch_size = -1;
       auto shape = input_shapes.at(slot_number);
       auto status = ValidateTensorProperties(
-          node_def.op(), node_def.attr().at("dtype").type(), shape,
+          node_def.op(), node_def.attr().at(type_key).type(), shape,
           /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
       if (!status.ok()) {
         const string error_message =
@@ -5228,12 +5241,23 @@
       // engines offline, by calling sess.run() and cache/serialize the engines.
       TF_RETURN_IF_ERROR(
           converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
-    } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
+    } else if (IsEngineOutput(node_name)) {
       int32 slot_number = -1;
-      if (!strings::safe_strto32(  // non-absl ok
-              node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
-        return errors::InvalidArgument("Failed to parse slot number from ",
-                                       node_name);
+      if (node_def.op() == "Identity") {
+        if (!strings::safe_strto32(  // non-absl ok
+                node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
+                &slot_number)) {
+          return errors::InvalidArgument("Failed to parse slot number from ",
+                                         node_name);
+        }
+      } else if (tensorflow::grappler::IsRetval(node_def)) {
+        slot_number = node_def.attr().at("index").i();
+      } else {
+        return errors::InvalidArgument(
+            "Node with name ", node_name,
+            " starting with IONamePrefixes::kOutputPHName is "
+            "neither Identity nor Retval, instead ",
+            node_def.op());
       }
       // Get output type that TensorFlow expects
       TFAttrs attrs(node_def);
@@ -5302,7 +5326,8 @@
 
     // Add dummy input/output nodes to the segment graphdef.
     if (connection.is_input_edge) {
-      const string node_name = StrCat(kInputPHName, connection.port_number);
+      const string node_name =
+          StrCat(IONamePrefixes::kInputPHName, connection.port_number);
       if (marker_nodes.count(node_name)) {
         VLOG(1) << "Reusing input " << node_name << " for the edge "
                 << connection.outside_node_name << ":"
@@ -5312,16 +5337,18 @@
       }
       marker_nodes.insert(node_name);
       auto seg_node = segment_def->add_node();
-      NodeDefBuilder builder(node_name, "Placeholder");
+      NodeDefBuilder builder(node_name, "_Arg");
       auto status = builder.Attr("shape", partial_shape)
-                        .Attr("dtype", dtype)
+                        .Attr("T", dtype)
+                        .Attr("index", connection.port_number)
                         .Finalize(seg_node);
       VLOG(1) << "Constructing input " << node_name << " for the edge "
               << connection.outside_node_name << ":" << connection.outside_port
               << " -> " << connection.inside_node_name << ":"
               << connection.inside_port;
     } else {
-      const string node_name = StrCat(kOutputPHName, connection.port_number);
+      const string node_name =
+          StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
       if (marker_nodes.count(node_name)) {
         VLOG(1) << "Reusing output " << node_name << " for the edge "
                 << connection.inside_node_name << ":" << connection.inside_port
@@ -5331,9 +5358,10 @@
       }
       marker_nodes.insert(node_name);
       auto seg_node = segment_def->add_node();
-      NodeDefBuilder builder(node_name, "Identity");
+      NodeDefBuilder builder(node_name, "_Retval");
       auto status =
-          builder
+          builder.Attr("T", dtype)
+              .Attr("index", connection.port_number)
               .Input(connection.inside_node_name, connection.inside_port, dtype)
               .Finalize(seg_node);
       VLOG(1) << "Constructing output " << node_name << " for the edge "
@@ -5359,12 +5387,12 @@
     if (connection.is_control_edge() || !connection.is_input_edge) continue;
     auto snode =
         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
-    const string placeholder_name =
-        StrCat(kInputPHName, connection.port_number);
+    const string arg_name =
+        StrCat(IONamePrefixes::kInputPHName, connection.port_number);
     VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
             << " from " << snode->input(connection.inside_port) << " to "
-            << placeholder_name;
-    snode->set_input(connection.inside_port, placeholder_name);
+            << arg_name;
+    snode->set_input(connection.inside_port, arg_name);
   }
   std::set<string> subgraph_node_names;
   for (const Node* node : subgraph_nodes) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index c4249ff..9d475e2 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -37,8 +37,6 @@
 
 namespace tensorflow {
 namespace tensorrt {
-extern const char* const kInputPHName;
-extern const char* const kOutputPHName;
 
 namespace convert {
 
@@ -119,8 +117,8 @@
   bool use_calibration;
 };
 
-// Constructs a graphdef from the segment in the given graph. Adds placeholder
-// nodes for input edges (InputPH_*) and identity nodes for output edges
+// Constructs a graphdef from the segment in the given graph. Adds _Arg
+// nodes for input edges (InputPH_*) and _Retval nodes for output edges
 // (OutputPH_*). This function needs to be called before TensorRT nodes
 // inserted in order to correctly get sizes from the original graph.
 //
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index b6a3587..373e6d8 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -1158,7 +1158,7 @@
     int batch_size = -1;
     for (const NodeDef& node : gdef.node()) {
       absl::string_view node_name(node.name());
-      if (absl::ConsumePrefix(&node_name, kInputPHName)) {
+      if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) {
         int port = -1;
         EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
         if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
@@ -1188,11 +1188,13 @@
 
 TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
-                                ops::Placeholder::Shape({1, 1}));
+  auto input =
+      ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)),
+                       DT_FLOAT, ops::Placeholder::Shape({1, 1}));
   auto output = ops::Identity(s.WithOpName("identity1"), input);
   output = ops::Identity(s.WithOpName("identity2"), output);
-  output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
+  output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)),
+                         output);
   // If the converter marks the input tensor as output tensor, the conversion
   // below will fail with:
   // > TensorRTOutputPH_0 cannot be both input and output
@@ -1453,6 +1455,9 @@
     return converter_->quantization_ranges_;
   }
 
+  void PropagateQuantizationRanges() {
+    converter_->PropagateQuantizationRanges();
+  }
   std::unique_ptr<Converter> converter_;
 
  protected:
@@ -5847,6 +5852,111 @@
 }
 #endif  // IS_TRT_VERSION_GE(6, 0, 0, 0)
 
+NodeDef MakePadNodeDef(std::string name, DataType dtype) {
+  Scope s = Scope::NewRootScope();
+  auto input = ops::Placeholder(s.WithOpName("input"), dtype);
+  auto padding = ops::Placeholder(s.WithOpName("padding"), DT_INT32);
+  auto pad = ops::Pad(s.WithOpName(name), input, padding);
+  return pad.operation.node()->def();
+}
+
+template <typename CType>
+struct PadTestParams {
+  std::vector<int> input_dims;
+  std::vector<int> pad_dims;
+  std::vector<CType> input_values;
+  std::vector<int> expected_output_dims;
+  std::vector<CType> expected_output_values;
+};
+
+template <DataType dtype>
+void TestConvertPad(OpConverterTest* test) {
+  typedef typename EnumToDataType<dtype>::Type CType;
+
+  std::vector<PadTestParams<CType>> params{
+      {
+          /*input_dims=*/{1, 2, 1},  // H, W, C
+          /*pad_dims=*/{4, 2},       // #dims, {pad_before, pad_after}
+          /*input_values=*/CastTestVector<float, CType>({2.0f, -1.0f}),
+          /*expected_output_dims=*/{2, 3, 1},  // H, W, C
+          /*expected_output_values=*/
+          CastTestVector<float, CType>({0.0, 0.0, 0.0, 2.0f, -1.0f, 0.0}),
+      },
+  };
+
+  for (int i = 0; i < params.size(); ++i) {
+    test->Reset();
+    // Create pad node.
+    NodeDef node_def = MakePadNodeDef("my_pad", dtype);
+    // Create input tensor
+    test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
+                        /*trt_dtype=*/TfDataTypeToTrt(dtype));
+    // Create output size.
+    test->AddTestWeights<int32>("padding", params[i].pad_dims,
+                                {0, 0, 1, 0, 0, 1, 0, 0});
+    test->RunValidationAndConversion(node_def);
+
+    TRT_TensorOrWeights output;
+    TF_EXPECT_OK(test->GetTensorOrWeights("padding", &output));
+
+    // Create input data for tensors.
+    const DataVec input_data{
+        {"input", test::AsTensor<CType>(params[i].input_values)}};
+    DataVec output_data{
+        {"my_pad",
+         ConstructTensor<CType>(params[i].expected_output_values.size())}};
+
+    test->BuildAndRun(
+        input_data, &output_data,
+        dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
+    ExpectArrayAlmostEqual(params[i].expected_output_values,
+                           GetSpanForData<CType>(output_data[0]), CType(1e-5));
+  }
+}
+
+TEST_F(OpConverterTest, ConvertPad) {
+  {
+    // First input is weight, should fail.
+    Reset();
+    NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT);
+    AddTestWeights<float>("input", {1, 2}, {1, 2});
+    AddTestWeights<int>("padding", {1, 2}, {1, 2});
+    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+                               "The input \"tensor\" for Pad must be a "
+                               "tensor");
+  }
+  {
+    // padding is a tensor, should fail.
+    Reset();
+    NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT);
+    AddTestTensor("input", {1, 2});
+    AddTestTensor("padding", {1, 2});
+    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+                               "The input \"paddings\" for Pad must be a "
+                               "constant");
+  }
+  TestConvertPad<DT_FLOAT>(this);
+  TestConvertPad<DT_HALF>(this);
+  {
+    // Make sure that ranges are inferred across a Pad.
+    Reset();
+    NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT);
+    AddTestTensor("input", {1, 2, 1});
+    AddTestWeights<int>("padding", {4, 2}, {0, 0, 1, 0, 0, 1, 0, 0});
+    TRT_TensorOrWeights input;
+    TRT_TensorOrWeights output;
+    RunValidationAndConversion(node_def);
+    TF_EXPECT_OK(GetTensorOrWeights("input", &input));
+    TF_EXPECT_OK(GetTensorOrWeights("my_pad", &output));
+    converter_->ProvideQuantizationRange(input.tensor(), -5.0f, 5.0f);
+    // Input range should be inferred across pad.
+    PropagateQuantizationRanges();
+    auto ranges = quantization_ranges();
+    EXPECT_EQ(5.0f, ranges[input.tensor()]);
+    EXPECT_EQ(5.0f, ranges[output.tensor()]);
+  }
+}
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
index 6af483d..35a8c63 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
@@ -67,9 +67,6 @@
   if (params.count("use_calibration")) {
     use_calibration_ = params.at("use_calibration").b();
   }
-  if (params.count("use_function_backup")) {
-    use_function_backup_ = params.at("use_function_backup").b();
-  }
   return Status::OK();
 }
 
@@ -193,31 +190,30 @@
     LOG(INFO) << CurrentStackTrace();
     PrintDebugInfo(cluster, item);
   }
-  int max_dim = -1;
-  if (!item.feed.empty()) {
-    for (const auto& f : item.feed) {
-      const auto& shape = f.second.shape();
-      if (shape.dims() > 0) {
-        if (shape.dim_size(0) > max_dim) max_dim = shape.dim_size(0);
+  if (!is_dynamic_op_) {
+    int max_batch_dim = -1;
+    if (!item.feed.empty()) {
+      for (const auto& f : item.feed) {
+        const auto& shape = f.second.shape();
+        if (shape.dims() > 0) {
+          if (shape.dim_size(0) > max_batch_dim)
+            max_batch_dim = shape.dim_size(0);
+          VLOG(2) << "Setting max_batch_dim to " << max_batch_dim
+                  << " using batch dimension of " << f.first << " with shape "
+                  << shape;
+        }
       }
     }
-  }
-  if (maximum_batch_size_ < 0) {  // automatic batch size from input
-    if (max_dim > 0) {
-      maximum_batch_size_ = max_dim;
-      VLOG(1) << "Setting maximum batch size to " << max_dim;
-    } else {
-      maximum_batch_size_ = 128;
-      LOG(WARNING) << "Maximum batch size is not set"
-                      " and can't be deduced from inputs setting it to"
-                   << maximum_batch_size_
-                   << ". Suggest configuring it from configuration parameters";
-    }
-  } else {
-    if (max_dim > maximum_batch_size_) {
-      LOG(WARNING) << "Configured batch size " << maximum_batch_size_
-                   << " is less than input batch size " << max_dim
-                   << " adjusting maximum batch size to match input batch size";
+    if (max_batch_dim > maximum_batch_size_) {
+      return errors::InvalidArgument(
+          "Specified max_batch_size=", maximum_batch_size_,
+          " is less than maximum batch dimension of inputs (", max_batch_dim,
+          "). ", "To continue, set max_batch_size to >= ", max_batch_dim);
+    } else if (max_batch_dim < maximum_batch_size_) {
+      LOG(INFO) << "Specified max_batch_size=" << maximum_batch_size_
+                << " is larger than maximum batch dimension of inputs ("
+                << max_batch_dim << "). "
+                << "This can result in poor performance.";
     }
   }
   grappler::GraphProperties static_graph_properties(item);
@@ -259,7 +255,6 @@
   cp.is_dyn_op = is_dynamic_op_;
   cp.max_cached_engines = max_cached_batches_;
   cp.use_calibration = use_calibration_;
-  cp.use_function_backup = use_function_backup_;
   auto status = ConvertAfterShapes(cp);
   VLOG(1) << "Returning from " << name_;
   return status;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
index d3fd914..dbed535 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
@@ -40,8 +40,7 @@
         is_dynamic_op_(false),
         max_cached_batches_(1),
         max_workspace_size_bytes_(256LL << 20),
-        use_calibration_(true),
-        use_function_backup_(true) {
+        use_calibration_(true) {
     VLOG(1) << "Constructing " << name_;
   }
 
@@ -71,8 +70,6 @@
   int64_t max_workspace_size_bytes_;
   bool use_calibration_;
 
-  // Whether to allow TF function fallback path in TRTEngineOp.
-  bool use_function_backup_;
 };
 
 }  // namespace convert
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h
index 91c8c66..eb60829 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h
@@ -23,6 +23,12 @@
 namespace tensorflow {
 namespace tensorrt {
 
+class IONamePrefixes {
+ public:
+  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
+  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
+};
+
 template <typename T>
 struct TrtDestroyer {
   void operator()(T* t) {
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
index 7af6052..374f75c 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc
@@ -40,11 +40,11 @@
     // serialized string to that tensor, and later sess.run() will copy it back
     // to host. We need to optimize this.
 
-    const string& resource_name = context->input(0).scalar<string>()();
+    const string& resource_name = context->input(0).scalar<tstring>()();
     // Get the resource.
     TRTEngineCacheResource* resource = nullptr;
     OP_REQUIRES_OK(context, context->resource_manager()->Lookup(
-                                std::string(kCacheContainerName), resource_name,
+                                std::string(kTfTrtContainerName), resource_name,
                                 &resource));
     core::ScopedUnref sc(resource);
 
@@ -59,7 +59,7 @@
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({}), &output));
 
-    output->scalar<string>()() = serialized_resource;
+    output->scalar<tstring>()() = serialized_resource;
   }
 };
 
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 2494e03..7e7592c 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
@@ -24,10 +25,15 @@
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -53,6 +59,7 @@
 
 // A helper class to call done() when destructed for asynchronous execution.
 // Helps simultaneous execution of native and TRT engines.
+
 class AsyncHelper : public core::RefCounted {
  public:
   AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
@@ -86,10 +93,15 @@
                VectorTensorShapeHasher>;
 
   // Execute calibration
-  void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
+  void ExecuteCalibration(OpKernelContext* ctx,
+                          TRTEngineCacheResource* cache_res,
+                          AsyncHelper* helper);
 
   // Construct a function handle for executing native funcdef graph
-  Status ConstructFunctionHandle(OpKernelContext* ctx);
+  // These are the exact same function.
+
+  Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
+                                 const string& device_name);
 
   // Execute replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -107,7 +119,8 @@
 
   // Get engine for the input shape
   StatusOr<EngineContext*> GetEngine(
-      const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx);
+      const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx,
+      TRTEngineCacheResource* cache_res);
 
   // Verify that the input shapes are consistent and can be handled by this op.
   Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
@@ -125,10 +138,8 @@
   // serialized protobuf segment or trt engine depending on static_engine_ flag.
   string serialized_segment_;
 
-  // Name of the function for TF native execution of the segment. If empty, it
-  // means TF native execution is not allowed, and if TRT engine fails to run
-  // an error will be returned.
-  string funcdef_name_;
+  // The function for TF native execution of the segment.
+  NameAttrList func_;
 
   // GraphDef representation of the segment.
   GraphDef segment_graph_;
@@ -148,7 +159,7 @@
 
   int64 workspace_size_;
   mutex engine_mutex_;
-  FunctionLibraryRuntime::Handle native_func_;
+  FunctionLibraryRuntime::Handle func_handle_;
 
   // The finalized calibrator for inference.
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
@@ -177,23 +188,61 @@
   }
 }
 
-Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
+static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
+                                    FunctionLibraryRuntime* flib_runtime,
+                                    GraphDef* graph_def) {
+  const FunctionLibraryDefinition* flib_def =
+      flib_runtime->GetFunctionLibraryDefinition();
+  const FunctionBody* fbody;
+  fbody = flib_runtime->GetFunctionBody(handle);
+  if (!fbody) {
+    return errors::Internal(
+        "Function body is null when converting from FuncDef to GraphDef.");
+  }
+  std::unique_ptr<Graph> graph(new Graph(flib_def));
+  CopyGraph(*fbody->graph, graph.get());
+
+  auto replace_name = [](const char* const prefix, string* name) {
+    if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
+      name->replace(0, strlen(prefix), prefix);
+      return true;
+    }
+    return false;
+  };
+  graph->ToGraphDef(graph_def);
+  // GraphToFunctionDef() will convert all the node names to lowercase.
+  for (auto& node : *graph_def->mutable_node()) {
+    if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
+      if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
+        // Instantiation of the function will append _RetVal to the node name,
+        // need to remove it for backward compatibility.
+        const char* const suffix_to_remove = "_RetVal";
+        if (absl::EndsWith(node.name(), suffix_to_remove)) {
+          node.mutable_name()->erase(node.name().size() -
+                                     strlen(suffix_to_remove));
+        }
+      }
+    }
+    for (auto& input : *node.mutable_input()) {
+      if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
+        replace_name(IONamePrefixes::kOutputPHName, &input);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
+                                            const string& device_name) {
   VLOG(1) << "Constructing function handle";
-  auto lib = ctx->function_library();
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
   }
-  auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
-  if (fdef == nullptr) {
-    return errors::Internal("Native FunctionDef ", funcdef_name_,
-                            " can't be found in function library");
-  }
   FunctionLibraryRuntime::InstantiateOptions inst_ops;
   inst_ops.state_handle = "";
-  inst_ops.target = ctx->device()->name();
-  native_func_ = 0;
-  return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops,
-                          &native_func_);
+  inst_ops.target = device_name;
+  return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
+                          &func_handle_);
 }
 
 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
@@ -204,15 +253,7 @@
   OP_REQUIRES_OK(context,
                  context->GetAttr("workspace_size_bytes", &workspace_size_));
   OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
-  if (!static_engine_) {
-    OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
-                errors::InvalidArgument("Failed to parse segment graphdef!"));
-    VLOG(1) << "Size of serialized GraphDef: "
-            << serialized_segment_.capacity();
-    string tmp;
-    // Swap with temporary empty string to deallocate the CPU memory.
-    serialized_segment_.swap(tmp);
-  }
+
   VLOG(1) << "Constructing " << name();
   string precision_string;
   OP_REQUIRES_OK(context,
@@ -220,12 +261,25 @@
   string calibration_data;
   OP_REQUIRES_OK(context,
                  context->GetAttr("calibration_data", &calibration_data));
-  OP_REQUIRES_OK(context,
-                 context->GetAttr("segment_funcdef_name", &funcdef_name_));
+  OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
+  OP_REQUIRES(context, !func_.name().empty(),
+              errors::InvalidArgument(
+                  "The TF function for the TRT segment could not be empty"));
   OP_REQUIRES_OK(context,
                  TrtPrecisionModeFromName(precision_string, &precision_mode_));
   OP_REQUIRES_OK(context,
                  context->GetAttr("use_calibration", &use_calibration_));
+  func_handle_ = kInvalidHandle;
+  if (!static_engine_) {
+    FunctionLibraryRuntime* lib = context->function_library();
+    OP_REQUIRES_OK(context,
+                   ConstructFunctionHandle(lib, context->device()->name()));
+    OP_REQUIRES_OK(context,
+                   FunctionDefToGraphDef(func_handle_, lib, &segment_graph_));
+  }
+  // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
+  // backward compatibility reasons. Remove it once all known users switch to
+  // 2.0.
   calibration_mode_ =
       (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
        calibration_data.empty());
@@ -233,20 +287,19 @@
     calibrator_.reset(new TRTInt8Calibrator(calibration_data));
     calibration_data.resize(0);
   }
-  native_func_ = kInvalidHandle;
   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                            &max_cached_engines_));
 }
 
 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                        AsyncHelper* helper) {
-  OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(),
-                    errors::Internal("Fallback path is disabled, for ", name()),
-                    *helper);
   std::vector<Tensor> inputs;
   std::vector<Tensor>* outputs = new std::vector<Tensor>();
-  if (native_func_ == kInvalidHandle) {
-    OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper);
+  if (func_handle_ == kInvalidHandle) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
+        *helper);
   }
   auto lib = ctx->function_library();
   FunctionLibraryRuntime::Options opts;
@@ -259,7 +312,7 @@
   }
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment: " << name();
-  lib->Run(opts, native_func_, inputs, outputs,
+  lib->Run(opts, func_handle_, inputs, outputs,
            [this, ctx, outputs, helper](const Status& s) {
              core::ScopedUnref sc(helper);
              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
@@ -272,18 +325,14 @@
 }
 
 void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
+                                     TRTEngineCacheResource* cache_res,
                                      AsyncHelper* helper) {
   VLOG(1) << "Executing TRT calibration: " << name();
   helper->Ref();
   core::ScopedUnref sc(helper);
 
-  TRTEngineCacheResource* cache_res = nullptr;
-  OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper);
-  core::ScopedUnref unref_cache_res(cache_res);
-
   CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
-
-  int num_inputs = ctx->num_inputs();
+  const int num_inputs = ctx->num_inputs();
   // TODO(laigd): need to check that input shape matches.
   // Pass input data to calibrator
   std::unordered_map<string, void*> input_data;
@@ -298,7 +347,7 @@
     const auto device_tensor =
         calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
     CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
-    input_data.emplace(StrCat(kInputPHName, i), data_address);
+    input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
   }
   VLOG(2) << "Filled map for sending";
   // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
@@ -396,10 +445,44 @@
                                AsyncOpKernel::DoneCallback done) {
   auto helper = new AsyncHelper(done);
   core::ScopedUnref sc(helper);
-  if (calibration_mode_) {
-    ExecuteCalibration(ctx, helper);
+
+  // Get TRT resource.
+  TRTEngineCacheResource* cache_res = nullptr;
+  OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper);
+  core::ScopedUnref unref_cache_res(cache_res);
+
+  // Run calibration if in int8+calibration mode.
+  // * Logic in TF 1.x:
+  //   - During conversion: calibration_mode_ is true and cache size is 0, so it
+  //     will run calibration.
+  //   - During inference: calibration_data will be set, so calibration_mode_ is
+  //     false and it won't trigger calibration.
+  // * Logic in TF 2.0:
+  //   - During conversion: similar to 1.x.
+  //   - During inference: calibration_data will still be empty, but cache will
+  //     contain the the calibrated engine, so it won't trigger calibration.
+  //
+  // TODO(laigd): consider the following alternatives:
+  // 1. Serialize the state (calibration or inference) using
+  //    TRTEngineInstance proto (or a new proto), so we know which mode we're
+  //    in and don't run calibration during inference (which is invalid).
+  // 2. Reuse the calibration_data attribute or use a new attribute in the
+  //    NodeDef to indicate whether it's in calibration mode.
+  if (calibration_mode_ && cache_res->cache_.size() == 0) {
+    if (!cache_res->calib_ctx_) {
+      // TODO(laigd): better encapsulation.
+      mutex_lock lock(engine_mutex_);
+      if (!cache_res->calib_ctx_) {
+        OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
+                             *helper);
+      }
+    }
+    // TODO(laigd): check that the input shapes match the shapes of the
+    // persistent tensor in the calibration resource.
+    ExecuteCalibration(ctx, cache_res, helper);
     return;
   }
+
   // Get shapes of inputs to engine.
   std::vector<TensorShape> input_shapes;
   input_shapes.reserve(ctx->num_inputs());
@@ -407,8 +490,9 @@
     input_shapes.push_back(ctx->input(i).shape());
   }
   OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_shapes), *helper);
-  StatusOr<EngineContext*> status = GetEngine(input_shapes, ctx);
+  StatusOr<EngineContext*> status = GetEngine(input_shapes, ctx, cache_res);
   OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
+
   EngineContext* engine_context = status.ValueOrDie();
   if (!engine_context->cuda_engine) {
     VLOG(1) << "Engine retrieval for input shapes: "
@@ -435,9 +519,11 @@
   // input.
   const int num_batch = ctx->input(0).shape().dim_size(0);
   const int num_binding = ctx->num_inputs() + ctx->num_outputs();
+
   std::vector<void*> buffers(num_binding);
+
   for (int i = 0; i < ctx->num_inputs(); i++) {
-    const string input_name = StrCat(kInputPHName, i);
+    const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
     const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
     if (binding_index == -1) {
       const string msg =
@@ -479,7 +565,7 @@
 
   for (int i = 0; i < ctx->num_outputs(); i++) {
     // Create an output tensor
-    const string output_name = StrCat(kOutputPHName, i);
+    const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
     const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
     Tensor* output_tensor = nullptr;
 
@@ -569,22 +655,17 @@
 
   // Get engine cache.
   return ctx->resource_manager()->LookupOrCreate(
-      std::string(kCacheContainerName), std::string(resource_name), cache_res,
+      std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
-        if (calibration_mode_) {
-          TF_RETURN_IF_ERROR(AllocateCalibrationResources(ctx, *cr));
-        }
         return Status::OK();
       }});
 }
 
 StatusOr<EngineContext*> TRTEngineOp::GetEngine(
-    const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
+    const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx,
+    TRTEngineCacheResource* cache_res) {
   static EngineContext empty_context;
-  TRTEngineCacheResource* cache_res = nullptr;
-  TF_RETURN_IF_ERROR(GetEngineCacheResource(ctx, &cache_res));
-  core::ScopedUnref sc(cache_res);
 
   mutex_lock lock(engine_mutex_);
   // TODO(tmorris): using first input to get batch size - is this reliable?
@@ -597,6 +678,9 @@
 
   // Handle the static engine case. For static engines, the cache will have a
   // single element containing the only engine.
+  //
+  // TODO(laigd): This is legacy mode for TF v1.x, need to remove when all known
+  // users switch to 2.0.
   if (static_engine_) {
     if (cache.size()) {
       // Batch size of engine must be >= the input batch size
@@ -698,7 +782,7 @@
   const int num_inputs = ctx->num_inputs();
   std::vector<TensorShape> shapes;
   cres->device_tensors_.resize(num_inputs);
-  VLOG(1) << " Constructing calibrator";
+  VLOG(1) << "Constructing calibrator";
   for (int i = 0; i < num_inputs; i++) {
     // allocate workspace on device for inputs
     const Tensor& t = ctx->input(i);
@@ -713,7 +797,7 @@
           "Unsupported data type encountered in input ", i);
     }
     cres->device_buffers_.emplace(
-        StrCat(kInputPHName, i),
+        StrCat(IONamePrefixes::kInputPHName, i),
         std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
   }
   cres->calibrator_.reset(
@@ -727,56 +811,53 @@
   }
 
   cache_res->Ref();
-  cres->thr_.reset(
-      new std::thread([this, cres, shapes, platform_gpu_id, cache_res]() {
-        core::ScopedUnref sc(cache_res);
+  cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
+                                    cache_res]() {
+    core::ScopedUnref sc(cache_res);
 
-        LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
-                  << ", Calibration Resource @ " << cres;
-        auto err = cudaSetDevice(platform_gpu_id);
-        if (err != cudaSuccess) {
-          // TODO(aaroey): should return error here.
-          LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
-                     << " in calibration thread";
-        }
-        std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
-                                                       shapes.end());
-        // ConvertGraphDefToEngine() will try to build the engine. This thread
-        // will loop inside buildCudaEngine() consuming the calibration data
-        // that is set by the TF op, and drive the builder until calibrator
-        // returns false. Engine is discarded after calibration table is
-        // generated
-        //
-        // TODO(aaroey): maybe setting the max batch size using the python
-        // calibration wrapper class.
-        auto s = convert::ConvertGraphDefToEngine(
-            this->segment_graph_, TrtPrecisionMode::INT8,
-            cres->calibrator_->getBatchSize(), this->workspace_size_,
-            partial_shapes, &cache_res->GetLogger(),
-            cache_res->allocator_.get(), cres->calibrator_.get(),
-            &cres->engine_,
-            /*use_calibration=*/true,
-            /*convert_successfully=*/nullptr);
-        if (!s.ok()) {
-          LOG(ERROR) << "Calibration failed: " << s;
-          cres->calibrator_->setDone();  // Ignore further pushes
-        }
+    LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
+              << ", Calibration Resource @ " << cres;
+    auto err = cudaSetDevice(platform_gpu_id);
+    if (err != cudaSuccess) {
+      // TODO(aaroey): should return error here.
+      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
+                 << " in calibration thread";
+    }
+    std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
+                                                   shapes.end());
+    // ConvertGraphDefToEngine() will try to build the engine. This thread
+    // will loop inside buildCudaEngine() consuming the calibration data
+    // that is set by the TF op, and drive the builder until calibrator
+    // returns false. Engine is discarded after calibration table is
+    // generated
+    //
+    // TODO(aaroey): maybe setting the max batch size using the python
+    // calibration wrapper class.
+    auto s = convert::ConvertGraphDefToEngine(
+        this->segment_graph_, TrtPrecisionMode::INT8,
+        cres->calibrator_->getBatchSize(), this->workspace_size_,
+        partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
+        cres->calibrator_.get(), &cres->engine_,
+        /*use_calibration=*/true,
+        /*convert_successfully=*/nullptr);
+    if (!s.ok()) {
+      LOG(ERROR) << "Calibration failed: " << s;
+      cres->calibrator_->setDone();  // Ignore further pushes
+    }
 
-        // Transfer the ownership of the engine to the engine cache, so we can
-        // dump it out during conversion for TF 2.0.
-        if (cache_res) {
-          mutex_lock lock(this->engine_mutex_);
-          cres->SetCalibrationTable();
-          this->calibrator_ = std::move(cres->calibrator_);
-          TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
-              cres->engine_->createExecutionContext());
-          cache_res->cache_.emplace(
-              shapes, absl::make_unique<EngineContext>(
-                          std::move(cres->engine_), std::move(exec_context)));
-        }
+    // Transfer the ownership of the engine to the engine cache, so we can
+    // dump it out during conversion for TF 2.0.
+    mutex_lock lock(this->engine_mutex_);
+    cres->SetCalibrationTable();
+    this->calibrator_ = std::move(cres->calibrator_);
+    TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
+        cres->engine_->createExecutionContext());
+    cache_res->cache_.emplace(
+        shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
+                                                 std::move(exec_context)));
 
-        VLOG(1) << "Calibration loop terminated " << this->name();
-      }));
+    VLOG(1) << "Calibration loop terminated " << this->name();
+  }));
   VLOG(1) << "initialized calibrator resource";
   return Status::OK();
 }
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
index 1c08061..4228136 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc
@@ -22,10 +22,17 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
 #include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
@@ -38,6 +45,7 @@
 
 namespace tensorflow {
 namespace tensorrt {
+using ::absl::StrCat;
 using ::testing::ElementsAre;
 
 class TRTEngineOpTestBase : public OpsTestBase {
@@ -49,25 +57,32 @@
 
     // Create simple TF graph.
     Scope s = Scope::NewRootScope();
-    auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
-                                 ops::Placeholder::Shape({-1, -1}));
+    auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0);
     auto add = ops::Add(s.WithOpName("add"), feed, feed);
-    ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
+    ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0);
 
     // Serialize the graph. TRTEngineOp will convert it using dynamic mode.
     GraphDef graph_def;
     TF_ASSERT_OK(s.ToGraphDef(&graph_def));
+    Graph* graph = s.graph();
+    const char* op_name = "myop";
+    TF_ASSERT_OK(
+        convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name));
+    TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));
+
     PartialTensorShape shape({-1, -1});
 
     // Create the op.
     OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
-    TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
+    NameAttrList function;
+    function.set_name(StrCat(op_name, "_native_segment"));
+    TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
                      .Input(FakeInput(1, dtype))
                      .Attr("input_shapes", {shape})
                      .Attr("output_shapes", {shape})
                      .Attr("static_engine", false)
-                     .Attr("segment_funcdef_name", "")  // no native fallback
-                     .Attr("serialized_segment", graph_def.SerializeAsString())
+                     .Attr("segment_func", function)
+                     .Attr("serialized_segment", "")
                      .Attr("calibration_data", "")
                      .Attr("max_cached_engines_count", max_cached_engines_count)
                      .Attr("workspace_size_bytes", 1 << 20)
@@ -75,7 +90,7 @@
                      .Attr("use_calibration", false)
                      .Attr("OutT", {dtype})
                      .Finalize(OpsTestBase::node_def()));
-    TF_ASSERT_OK(OpsTestBase::InitOp());
+    TF_ASSERT_OK(InitOpWithFunctionLibrary());
   }
 
   template <typename T>
@@ -89,9 +104,20 @@
     inputs_.clear();
     gtl::STLDeleteElements(&tensors_);
   }
+
+ private:
+  Status InitOpWithFunctionLibrary() {
+    OpKernel* kernel = nullptr;
+    Status status = CreateOpKernel(device_type_, device_, allocator(),
+                                   pflr_->GetFLR(device_->name()), node_def_,
+                                   TF_GRAPH_DEF_VERSION, &kernel);
+    kernel_ = std::unique_ptr<OpKernel>(kernel);
+    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
+    return status;
+  }
 };
 
-TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
+TEST_F(TRTEngineOpTestBase, DynamicShapes) {
   TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);
 
   // Execute the op with batch size > 1.
@@ -100,8 +126,8 @@
 
   // Get the engine cache.
   TRTEngineCacheResource* cache_resource = nullptr;
-  TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache",
-                                                   "myop", &cache_resource));
+  TF_ASSERT_OK(
+      device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
   core::ScopedUnref sc(cache_resource);
 
   // It should contain only one engine.
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
index 8f6f087..51f7e3a 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/io/record_reader.h"
 #include "tensorflow/core/lib/io/record_writer.h"
@@ -40,11 +41,9 @@
 namespace tensorrt {
 using ::nvinfer1::IRuntime;
 
-class CreateTRTEngineCacheHandle : public OpKernel {
+class CreateTRTResourceHandle : public OpKernel {
  public:
-  explicit CreateTRTEngineCacheHandle(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+  explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
   }
 
@@ -57,12 +56,11 @@
         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
                                                &handle_, attr));
 
-        VLOG(1) << "Creating TRT engine cache resource handle for container "
-                << container_ << " and op " << resource_name_ << " on device "
-                << ctx->device()->name();
+        VLOG(1) << "Creating TRT engine cache resource handle for op "
+                << resource_name_ << " on device " << ctx->device()->name();
         handle_.scalar<ResourceHandle>()() =
-            MakeResourceHandle<TRTEngineCacheResource>(ctx, container_,
-                                                       resource_name_);
+            MakeResourceHandle<TRTEngineCacheResource>(
+                ctx, std::string(kTfTrtContainerName), resource_name_);
         initialized_ = true;
       }
     }
@@ -70,23 +68,22 @@
   }
 
  private:
-  string container_;
   string resource_name_;
   Tensor handle_;
   mutex mutex_;
   bool initialized_ GUARDED_BY(mutex_) = false;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCacheHandle);
+  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCacheHandle")
+REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
                             .Device(DEVICE_GPU)
-                            .HostMemory("engine_cache_handle"),
-                        CreateTRTEngineCacheHandle);
+                            .HostMemory("resource_handle"),
+                        CreateTRTResourceHandle);
 
-class PopulateTRTEngineCache : public OpKernel {
+class InitializeTRTResource : public OpKernel {
  public:
-  explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
+  explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(
         ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
   }
@@ -112,7 +109,7 @@
                                  resource->cache_.size(), " entries."));
 
     // Get the file name.
-    const string& filename = ctx->input(1).scalar<string>()();
+    const string& filename = ctx->input(1).scalar<tstring>()();
     OP_REQUIRES(ctx, !filename.empty(),
                 errors::InvalidArgument("filename cannot be empty."));
 
@@ -150,48 +147,57 @@
                                      raw_engine->createExecutionContext())));
       ++num_loaded_engine;
     } while (1);
-    VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container "
-            << handle.container() << " for op " << handle.name()
-            << " on device " << ctx->device()->name() << " from file "
-            << filename;
+    VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "
+            << handle.name() << " on device " << ctx->device()->name()
+            << " from file " << filename;
   }
 
  private:
   // Maximum number of cached engines
   int max_cached_engines_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache);
+  TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
 };
 
-REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache")
+REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
                             .Device(DEVICE_GPU)
-                            .HostMemory("engine_cache_handle"),
-                        PopulateTRTEngineCache);
+                            .HostMemory("resource_handle"),
+                        InitializeTRTResource);
 
-class DumpTRTEngineCache : public OpKernel {
+class SerializeTRTResource : public OpKernel {
  public:
-  explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
-                                     &delete_cache_after_dump_));
+  explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
   }
 
   void Compute(OpKernelContext* ctx) override {
-    const string& container = ctx->input(0).scalar<string>()();
-    const string& resource_name = ctx->input(1).scalar<string>()();
-    const string& filename = ctx->input(2).scalar<string>()();
+    const string& resource_name = ctx->input(0).scalar<tstring>()();
+    const string& filename = ctx->input(1).scalar<tstring>()();
     OP_REQUIRES(ctx, !filename.empty(),
                 errors::InvalidArgument("filename cannot be empty."));
 
+    // Lookup engine cache resource.
     TRTEngineCacheResource* resource = nullptr;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
-                            container, resource_name, &resource));
+    OP_REQUIRES_OK(
+        ctx, ctx->resource_manager()->Lookup(std::string(kTfTrtContainerName),
+                                             resource_name, &resource));
     core::ScopedUnref unref_me(resource);
 
+    // Terminate the calibration if any.
+    if (resource->calib_ctx_) {
+      // We don't save the calibration_table for TF 2.0 at the moment, it's used
+      // in 1.x environment.
+      string calibration_table;
+      OP_REQUIRES_OK(
+          ctx, resource->calib_ctx_->SerializeToString(&calibration_table));
+    }
+
     // Serialize the engines and write them to file.
     std::unique_ptr<WritableFile> file;
     OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
     auto writer = absl::make_unique<io::RecordWriter>(file.get());
 
+    int num_serialized_engines = 0;
     for (const auto& pair : resource->cache_) {
       // Ignore engines that failed to build.
       const std::unique_ptr<EngineContext>& engine = pair.second;
@@ -211,30 +217,29 @@
 
       OP_REQUIRES_OK(ctx,
                      writer->WriteRecord(engine_instance.SerializeAsString()));
+      ++num_serialized_engines;
     }
-    VLOG(1) << "Serialized " << resource->cache_.size()
-            << " TRT engines in container " << container << " for op "
+    VLOG(1) << "Serialized " << num_serialized_engines << " TRT engines for op "
             << resource_name << " on device " << ctx->device()->name()
             << " to file " << filename;
 
-    if (delete_cache_after_dump_) {
-      VLOG(1) << "Destroying TRT engine cache resource in container "
-              << container << " for op " << resource_name << " on device "
-              << ctx->device()->name();
+    if (delete_resource_) {
+      VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name
+              << " on device " << ctx->device()->name();
       OP_REQUIRES_OK(ctx,
                      ctx->resource_manager()->Delete<TRTEngineCacheResource>(
-                         container, resource_name));
+                         std::string(kTfTrtContainerName), resource_name));
     }
   }
 
  private:
-  bool delete_cache_after_dump_ = false;
+  bool delete_resource_ = false;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache);
+  TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
 };
 
-REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU),
-                        DumpTRTEngineCache);
+REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
+                        SerializeTRTResource);
 
 }  // namespace tensorrt
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc
index b3e541a..8492c51 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc
@@ -92,11 +92,10 @@
   SetDevice(DEVICE_GPU, std::move(device));
 
   // Create the resource handle.
-  const string container = "mycontainer";
+  const string container(kTfTrtContainerName);
   const string resource_name = "myresource";
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCacheHandle")
-                   .Attr("container", container)
+  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTResourceHandle")
                    .Attr("resource_name", resource_name)
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
@@ -108,7 +107,7 @@
   EXPECT_TRUE(
       errors::IsNotFound(rm->Lookup(container, resource_name, &resource)));
 
-  // Create the resouce using an empty file with PopulateTRTEngineCache.
+  // Create the resouce using an empty file with InitializeTRTResource.
   Reset();
   Env* env = Env::Default();
   const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
@@ -116,7 +115,7 @@
     std::unique_ptr<WritableFile> file;
     TF_ASSERT_OK(env->NewWritableFile(filename, &file));
   }
-  TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
+  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
                    .Input(FakeInput(DT_RESOURCE))
                    .Input(FakeInput(DT_STRING))
                    .Attr("max_cached_engines_count", 1)
@@ -137,16 +136,14 @@
       absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
   resource->Unref();
 
-  // Serialize the engine using DumpTRTEngineCache op.
+  // Serialize the engine using SerializeTRTResource op.
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache")
-                   .Attr("delete_cache_after_dump", true)
-                   .Input(FakeInput(DT_STRING))
+  TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTResource")
+                   .Attr("delete_resource", true)
                    .Input(FakeInput(DT_STRING))
                    .Input(FakeInput(DT_STRING))
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<string>(TensorShape({}), {container});
   AddInputFromArray<string>(TensorShape({}), {resource_name});
   AddInputFromArray<string>(TensorShape({}), {filename});
   TF_ASSERT_OK(RunOpKernel());
@@ -178,7 +175,7 @@
 
   // Recreate the cache resource.
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
+  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
                    .Input(FakeInput(DT_RESOURCE))
                    .Input(FakeInput(DT_STRING))
                    .Attr("max_cached_engines_count", 1)
diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc
index b8f9058..7d8ff6d 100644
--- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc
@@ -33,7 +33,7 @@
 // key to cache the instantiated functions for different executor subgraphs.
 REGISTER_OP("TRTEngineOp")
     .Attr("serialized_segment: string")
-    .Attr("segment_funcdef_name: string")
+    .Attr("segment_func: func = {}")
     .Attr("InT: list({int8,float16,float32,int32})")
     .Attr("OutT: list({int8,float16,float32,int32})")
     .Attr("max_cached_engines_count: int = 1")
@@ -51,10 +51,11 @@
     // inference function as a workaround.
     .SetShapeFn(shape_inference::UnknownShape)
     // Deprecated attributes.
+    .Attr("segment_funcdef_name: string = ''")
     .Attr("cached_engine_batches: list(int) >= 0 = []")
     .Attr("fixed_input_size: bool = true")
-    .Attr("input_shapes: list(shape)")
-    .Attr("output_shapes: list(shape)")
+    .Attr("input_shapes: list(shape) = []")
+    .Attr("output_shapes: list(shape) = []")
     .Attr("static_engine: bool = true");
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc
index 67177ef..01911de 100644
--- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc
+++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc
@@ -24,23 +24,21 @@
 
 namespace tensorflow {
 
-REGISTER_OP("CreateTRTEngineCacheHandle")
-    .Attr("container: string")
+REGISTER_OP("CreateTRTResourceHandle")
     .Attr("resource_name: string")
-    .Output("engine_cache_handle: resource")
+    .Output("resource_handle: resource")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ScalarShape);
 
-REGISTER_OP("PopulateTRTEngineCache")
+REGISTER_OP("InitializeTRTResource")
     .Attr("max_cached_engines_count: int = 1")
-    .Input("engine_cache_handle: resource")
+    .Input("resource_handle: resource")
     .Input("filename: string")
     .SetIsStateful()
     .SetShapeFn(shape_inference::NoOutputs);
 
-REGISTER_OP("DumpTRTEngineCache")
-    .Attr("delete_cache_after_dump: bool = false")
-    .Input("container: string")
+REGISTER_OP("SerializeTRTResource")
+    .Attr("delete_resource: bool = false")
     .Input("resource_name: string")
     .Input("filename: string")
     .SetIsStateful()
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
index d518a37..f9306d5 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
@@ -30,7 +30,7 @@
 namespace tensorflow {
 namespace tensorrt {
 
-const absl::string_view kCacheContainerName = "TF-TRT-Engine-Cache";
+const absl::string_view kTfTrtContainerName = "TF-TRT";
 
 Logger& TRTEngineCacheResource::GetLogger() {
   static Logger* logger = new Logger();
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index df25ee0..9c29d56 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -170,7 +170,7 @@
   std::unique_ptr<std::thread> thr_;
 };
 
-ABSL_CONST_INIT extern const absl::string_view kCacheContainerName;
+ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
 
 class TRTEngineCacheResource : public ResourceBase {
  public:
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 6a28a5a..1f8df23 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1,6 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
@@ -29,6 +29,7 @@
     packages = [
         "//learning/brain/tools/tf_replay/...",
         "//tensorflow/...",
+        "//tensorflow_models/...",
     ],
 )
 
@@ -207,6 +208,7 @@
         ":side_effect_util",
         ":tf2xla_util",
         "//tensorflow/compiler/jit:flags",
+        "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
         "//tensorflow/compiler/tf2xla:rearrange_function_argument",
         "//tensorflow/compiler/tf2xla/lib:util",
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index ad2cc7b..48513a4 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -91,7 +91,7 @@
                                  FunctionLibraryRuntime* flib_runtime) {
   DCHECK(op_def != nullptr || op_kernel != nullptr);
   // TODO(b/124403063): Implement similar functionality for function call nodes.
-  if (node.op() == "While") {
+  if (node.op() == "While" || node.op() == "StatelessWhile") {
     // For While nodes, recurse into the body and cond graphs.
     const FunctionBody* fcond = nullptr;
     const FunctionBody* fbody = nullptr;
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index ef2202c..d60b4ca 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -132,6 +132,8 @@
         ":if_op",
         ":tensor_list_utils",
         ":while_op",
+        "//tensorflow/compiler/jit:xla_activity_listener",
+        "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
@@ -202,6 +204,7 @@
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index b309541..8e53ca1 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -14,7 +14,10 @@
 ==============================================================================*/
 #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"
 
+#include "absl/strings/str_format.h"
 #include "absl/types/span.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
+#include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -255,6 +258,15 @@
 
   ResizeConvolutionDims dims =
       ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
+
+  if (dims.kernel_size[0] * dims.kernel_size[1] >
+      kMax2DKernelSize * kMax2DKernelSize) {
+    BroadcastOptimizationRemark(
+        XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
+        absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
+        .IgnoreError();
+  }
+
   xla::XlaOp output;
 
   // Concatenation and padding below currently assumes num_spatial_dims is 2 to
diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc
index a6cc596..99f4a5f 100644
--- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc
@@ -47,11 +47,8 @@
     xla::PrimitiveType shift_type = ctx->input_xla_type(1);
     int64 num_axes = axis_shape.dims() == 0 ? 1 : axis_shape.dim_size(0);
     for (int64 i = 0; i != num_axes; ++i) {
-      auto cur_axis_status = axis_shape.dims() == 0
-                                 ? axis.GetIntegralAsS64({})
-                                 : axis.GetIntegralAsS64({i});
-      OP_REQUIRES_OK(ctx, cur_axis_status.status());
-      int64 cur_axis = cur_axis_status.ValueOrDie();
+      int64 cur_axis = axis_shape.dims() == 0 ? *axis.GetIntegralAsS64({})
+                                              : *axis.GetIntegralAsS64({i});
 
       xla::XlaOp offset =
           shift_shape.dims() == 0
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 247db8d..191ce9d 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -270,6 +270,53 @@
 REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                 ResourceApplyAdagrad);
 
+class ResourceApplyAdagradV2 : public XlaOpKernel {
+ public:
+  explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(2);
+
+    TensorShape var_shape, accum_shape;
+    xla::XlaOp var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    TensorShape lr_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    TensorShape epsilon_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon_shape.DebugString()));
+
+    TensorShape grad_shape = ctx->InputShape(4);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp epsilon = ctx->Input(3);
+    xla::XlaOp grad = ctx->Input(4);
+
+    accum = accum + xla::Square(grad);
+    var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
+  }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes),
+                ResourceApplyAdagradV2);
+
 class ResourceApplyProximalAdagrad : public XlaOpKernel {
  public:
   explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD
index 3cc551e..eaba5d3 100644
--- a/tensorflow/compiler/tf2xla/python/BUILD
+++ b/tensorflow/compiler/tf2xla/python/BUILD
@@ -1,5 +1,5 @@
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_py_clif_cc",
 )
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc
index b376fe9..b6f8928 100644
--- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc
+++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc
@@ -527,7 +527,7 @@
 
   // Rewrite If/While nodes.
   for (Node* n : g->nodes()) {
-    if (n->type_string() == "While") {
+    if (n->IsWhileNode()) {
       bool node_rewritten;
       TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
                                                &node_rewritten));
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 1243e31..2db431c 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -57,6 +57,7 @@
   add("ResourceApplyAdaMax"                  , kReadWrite, kVariable);
   add("ResourceApplyAdadelta"                , kReadWrite, kVariable);
   add("ResourceApplyAdagrad"                 , kReadWrite, kVariable);
+  add("ResourceApplyAdagradV2"               , kReadWrite, kVariable),
   add("ResourceApplyAdagradDA"               , kReadWrite, kVariable);
   add("ResourceApplyAdam"                    , kReadWrite, kVariable);
   add("ResourceApplyAddSign"                 , kReadWrite, kVariable);
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
index eebeec8..fb8b481 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.cc
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -50,7 +50,7 @@
       node->ClearAttr(attr_name);
       node->AddAttr(attr_name, branch_func);
     }
-  } else if (node->type_string() == "While") {
+  } else if (node->IsWhileNode()) {
     AttrValue device_ordinal_value;
     device_ordinal_value.set_i(device_ordinal);
     for (const string& attr_name : std::vector<string>{"cond", "body"}) {
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 3e4188f..3c2b256 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -384,8 +384,8 @@
   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
       &second_copy_def, *g->op_registry(), /*node_offset=*/0));
 
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
-                                            second_copy_def, g.get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      GraphConstructorOptions(), std::move(second_copy_def), g.get()));
   TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
 
   // Functionalize control flow.
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 3e8b9eb..e82546d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -765,7 +765,7 @@
   for (Node* n : g->op_nodes()) {
     if (n->IsIfNode()) {
       TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
-    } else if (n->type_string() == "While") {
+    } else if (n->IsWhileNode()) {
       TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
     }
   }
@@ -796,7 +796,7 @@
     // Find the forward While op.
     std::vector<const Edge*> fwd_while_edges;
     for (const Edge* e : n->out_edges()) {
-      if (!e->IsControlEdge() && e->dst()->type_string() == "While") {
+      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
         fwd_while_edges.push_back(e);
       }
     }
@@ -810,8 +810,7 @@
     int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
     std::vector<const Edge*> bwd_while_edges;
     for (const Edge* e : fwd_while->out_edges()) {
-      if (e->src_output() == fwd_while_dst_input &&
-          e->dst()->type_string() == "While") {
+      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
         bwd_while_edges.push_back(e);
       }
     }
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 3959f13..0121f83 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -21,6 +21,7 @@
 #include "absl/memory/memory.h"
 #include "absl/types/variant.h"
 #include "tensorflow/compiler/jit/flags.h"
+#include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -529,6 +530,11 @@
 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
   CopyGraph(*fbody->graph, graph.get());
+
+  // Performs a first function inlining pass before shape inference, since
+  // otherwise shape inference can't see inside functions and a comprehensive
+  // shape_map, including function ops, is needed to constant-propagate Shape
+  // Ops below.
   auto flags = GetBuildXlaOpsPassFlags();
   OptimizerOptions opts;
   opts.set_opt_level(OptimizerOptions::L0);
@@ -567,6 +573,28 @@
   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                      /*device=*/nullptr, &graph, graph_optimizer_options);
 
+  // Run shape inference on the graph and optimize the graph again.
+  GraphShapeInfo shape_info;
+  InferShapes(graph.get(), /*arg_shapes=*/{},
+              flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
+      .IgnoreError();
+  auto node_name_index = graph->BuildNodeNameIndex();
+  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  for (const auto& node_shape_info : shape_info) {
+    const string& node_name = node_shape_info.first;
+    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
+    const auto& node_iter = node_name_index.find(node_name);
+    if (node_iter != node_name_index.end()) {
+      auto& partial_shapes = shape_map[node_name];
+      for (const auto& inferred_shape : output_shapes) {
+        partial_shapes.push_back(inferred_shape.shape);
+      }
+    }
+  }
+  graph_optimizer_options.shape_map = &shape_map;
+  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
+                     /*device=*/nullptr, &graph, graph_optimizer_options);
+
   return graph;
 }
 
@@ -593,6 +621,33 @@
       CheckSignature(fbody->arg_types, args),
       "Signature check failure while compiling: ", fn_name_attrs.name());
 
+  // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
+  // Xla op requires a compile-time constant input, and that input is shape of
+  // an _Arg node.
+  for (int i = 0; i < args.size(); i++) {
+    // Skip resource variables and tensor lists.
+    DataType dtype;
+    TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
+    if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
+      continue;
+    }
+
+    if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
+      xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
+      TensorShape tensor_shape;
+      if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok()) {
+        fbody->arg_nodes[i]->ClearAttr("_output_shapes");
+        fbody->arg_nodes[i]->AddAttr("_output_shapes",
+                                     std::vector<TensorShape>{tensor_shape});
+      }
+    } else {
+      TensorShape tensor_shape = absl::get<TensorShape>(args[i].shape);
+      fbody->arg_nodes[i]->ClearAttr("_output_shapes");
+      fbody->arg_nodes[i]->AddAttr("_output_shapes",
+                                   std::vector<TensorShape>{tensor_shape});
+    }
+  }
+
   std::unique_ptr<Graph> graph = GetGraph(fbody);
 
   // Clear the "_kernel" attribute if it is set to "host". This is used to
@@ -601,7 +656,7 @@
   const char* const kKernelAttr = "_kernel";
   for (Node* n : graph->nodes()) {
     string value;
-    if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
+    if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") {
       n->ClearAttr(kKernelAttr);
     }
   }
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 2bafc74..0a4448b 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -1,7 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_py",
 )
 
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index acf59c4..b46d04d 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -296,6 +296,8 @@
     srcs = ["slicing.cc"],
     hdrs = ["slicing.h"],
     deps = [
+        ":arithmetic",
+        ":constants",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/client:xla_builder",
diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc
index d4bc560..f10342a 100644
--- a/tensorflow/compiler/xla/client/lib/slicing.cc
+++ b/tensorflow/compiler/xla/client/lib/slicing.cc
@@ -17,6 +17,8 @@
 
 #include <limits>
 
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/util.h"
 
@@ -138,18 +140,54 @@
   });
 }
 
-XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) {
+XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
   XlaBuilder* builder = input.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
-    ShapeUtil::AppendMajorDimension(1, &index_shape);
-    std::vector<XlaOp> to_concat;
     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
     if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
         input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
       index = ConvertElementType(index, U32);
       index_shape.set_element_type(U32);
     }
+    if (index_shape.rank() == 1) {
+      return TorchIndexSelect(input, index, 0);
+    }
+    if (!sparse) {
+      std::vector<int64> index_broacast_dims;
+      std::vector<int64> input_broacast_dims;
+      std::vector<int64> sizes;
+      for (int64 i = 0; i < index_shape.rank(); ++i) {
+        if (i < dim) {
+          input_broacast_dims.push_back(i);
+          index_broacast_dims.push_back(i);
+        } else if (i == dim) {
+          sizes.push_back(input_shape.dimensions(i));
+          input_broacast_dims.push_back(i);
+          index_broacast_dims.push_back(i + 1);
+        } else {
+          input_broacast_dims.push_back(i + 1);
+          index_broacast_dims.push_back(i + 1);
+        }
+        sizes.push_back(index_shape.dimensions(i));
+      }
+      auto mask = Eq(
+          BroadcastInDim(index, sizes, index_broacast_dims),
+          Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
+               dim));
+      auto masked_input = Select(
+          mask, BroadcastInDim(input, sizes, input_broacast_dims),
+          Zeros(builder,
+                ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
+      return Reduce(masked_input, Zero(builder, input_shape.element_type()),
+                    CreateScalarIdentityWithZeroComputation(
+                        input_shape.element_type(), builder),
+                    {dim});
+    }
+
+    ShapeUtil::AppendMajorDimension(1, &index_shape);
+    std::vector<XlaOp> to_concat;
+
     to_concat.reserve(input_shape.rank());
     for (int64 i = 0; i < input_shape.rank(); ++i) {
       if (i == dim) {
diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h
index 89ec1fe..9a59a04 100644
--- a/tensorflow/compiler/xla/client/lib/slicing.h
+++ b/tensorflow/compiler/xla/client/lib/slicing.h
@@ -55,7 +55,7 @@
 // [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size
 // [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as
 // `index`.
-XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim);
+XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
 
 // Returns a new tensor which indexes the input tensor along dimension dim using
 // the entries in index.
diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc
index 04d3f96..107cbae 100644
--- a/tensorflow/compiler/xla/client/lib/slicing_test.cc
+++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc
@@ -102,7 +102,7 @@
       {a_data.get(), b_data.get(), x_data.get(), y_data.get()});
 }
 
-XLA_TEST_F(SlicingTest, TorchGather) {
+XLA_TEST_F(SlicingTest, TorchGatherSparse) {
   xla::XlaBuilder builder(TestName());
 
   xla::XlaOp input, index;
@@ -116,6 +116,20 @@
                            {input_data.get(), index_data.get()});
 }
 
+XLA_TEST_F(SlicingTest, TorchGatherDense) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::XlaOp input, index;
+  auto input_data =
+      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 0, "input", &builder, &input);
+  auto index_data =
+      CreateR2Parameter<int>({{0, 0}, {1, 0}}, 1, "index", &builder, &index);
+  TorchGather(input, index, 1, false);
+
+  ComputeAndCompareR2<int>(&builder, {{1, 1}, {4, 3}},
+                           {input_data.get(), index_data.get()});
+}
+
 XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
   xla::XlaBuilder builder(TestName());
 
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1bd9d7b..153cb9f 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -176,12 +176,13 @@
     ExecutableRunOptions run_options) {
   TF_ASSIGN_OR_RETURN(auto options_and_stream,
                       RunHelper(arguments, run_options));
-
-  if (executable_->dumping_snapshot()) {
-    return ExecuteAndDump(&options_and_stream.first, arguments);
-  }
-  return executable_->ExecuteOnStreamWrapper(
-      &options_and_stream.first, run_options.execution_profile(), arguments);
+  ExecutableRunOptions options = options_and_stream.first.run_options();
+  options.set_device_ordinal(-1);
+  auto result = RunAsync(arguments, options);
+  Status block_status = options.stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(block_status);
+  return result;
 }
 
 StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
@@ -189,50 +190,49 @@
     ExecutableRunOptions run_options) {
   TF_ASSIGN_OR_RETURN(auto options_and_stream,
                       RunHelper(arguments, run_options));
-  return executable_->ExecuteAsyncOnStream(&options_and_stream.first,
-                                           arguments);
-}
+  se::Stream* stream = run_options.stream();
 
-StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
-    const ServiceExecutableRunOptions* run_options,
-    const absl::Span<const ShapedBuffer* const> arguments) {
-  executable_->hlo_snapshot()->set_execution_platform(
-      backend_->platform()->Name());
-  TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer result,
-      executable_->ExecuteOnStream(run_options, arguments,
-                                   /*hlo_execution_profile=*/nullptr));
-  TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
-  DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot());
-  return std::move(result);
-}
-
-Status LocalExecutable::RecordArguments(
-    const absl::Span<const ShapedBuffer* const> arguments,
-    HloSnapshot* hlo_snapshot) {
-  hlo_snapshot->clear_arguments();
-  for (const ShapedBuffer* argument : arguments) {
-    TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
-    *hlo_snapshot->add_arguments() = literal.ToProto();
+  std::shared_ptr<HloSnapshot> snapshot;
+  if (executable_->dumping_snapshot()) {
+    snapshot = std::make_shared<HloSnapshot>();
+    snapshot->set_execution_platform(backend_->platform()->Name());
+    *snapshot->mutable_hlo() = *executable_->hlo_proto();
+    for (const ShapedBuffer* arg : arguments) {
+      auto literal = std::make_shared<Literal>(arg->on_host_shape());
+      backend_->transfer_manager()->TransferLiteralFromDevice(
+          stream, *arg, literal.get(), [snapshot, literal](Status status) {
+            if (!status.ok()) {
+              LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
+                            "failed: "
+                         << status;
+              return;
+            }
+            *snapshot->add_arguments() = literal->ToProto();
+          });
+    }
   }
-  return Status::OK();
-}
 
-Status LocalExecutable::RecordResult(const ShapedBuffer* result,
-                                     HloSnapshot* hlo_snapshot) {
-  hlo_snapshot->clear_result();
-  TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
-  *hlo_snapshot->mutable_result() = literal.ToProto();
-  return Status::OK();
-}
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
+                      executable_->ExecuteAsyncOnStreamWrapper(
+                          &options_and_stream.first, arguments));
 
-StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
-    const ShapedBuffer& shaped_buffer) {
-  TF_ASSIGN_OR_RETURN(auto stream,
-                      backend_->BorrowStream(shaped_buffer.device_ordinal()));
-  return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(),
-                                                                 shaped_buffer);
+  // Transfer the outputs and save the snapshot to disk.
+  if (snapshot) {
+    auto literal = std::make_shared<Literal>(outputs.on_host_shape());
+    backend_->transfer_manager()->TransferLiteralFromDevice(
+        stream, outputs, literal.get(), [snapshot, literal](Status status) {
+          if (status.ok()) {
+            *snapshot->mutable_result() = literal->ToProto();
+          } else {
+            LOG(ERROR)
+                << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
+                << status;
+          }
+          DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
+        });
+  }
+
+  return std::move(outputs);
 }
 
 se::Platform* LocalClient::platform() const {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 1e7c97d..b697fb0 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -72,23 +72,6 @@
       const absl::Span<const ShapedBuffer* const> arguments,
       const ExecutableRunOptions& run_options, const Backend& backend);
 
-  // Records the computation in a SessionModule proto with the arguments used to
-  // invoke it, and the result. Enabled by flag: --xla_dump_hlo_snapshots.
-  //
-  // The given ServiceExecutableRunOptions override any values from the
-  // XLA_FLAGS environment variable.
-  StatusOr<ScopedShapedBuffer> ExecuteAndDump(
-      const ServiceExecutableRunOptions* run_options,
-      const absl::Span<const ShapedBuffer* const> arguments);
-
-  // Records the arguments used to invoke the computation in a SessionModule
-  // proto.
-  Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
-                         HloSnapshot* hlo_snapshot);
-
-  // Records the result of the computation in a SessionModule proto.
-  Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
-
   // Returns a literal containing the contents of the given ShapedBuffer.
   StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
 
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 318d5f3..dd20fd3 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1028,6 +1028,11 @@
           "Operand to GetTupleElement() is not a tuple; got %s",
           ShapeUtil::HumanString(tuple_shape));
     }
+    if (index < 0 || index >= ShapeUtil::TupleElementCount(tuple_shape)) {
+      return InvalidArgument(
+          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
+          ShapeUtil::HumanString(tuple_shape));
+    }
     *instr.mutable_shape() =
         ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto();
 
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index 45f9cbe..93ae3d2 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -149,6 +149,12 @@
         return true;
       };
 
+  // Custom "sub-parser" lambda for xla_gpu_ptx_file.
+  auto setter_for_xla_gpu_ptx_file = [](string value) {
+    flag_values->add_xla_gpu_ptx_file(value);
+    return true;
+  };
+
   // Custom "sub-parser" lambda for xla_backend_extra_options.
   auto setter_for_xla_backend_extra_options =
       [](string comma_separated_values) {
@@ -342,6 +348,13 @@
           int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
           flag_values->xla_gpu_max_kernel_unroll_factor(),
           "Specify the maximum kernel unroll factor for the GPU backend."),
+      tensorflow::Flag("xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "",
+                       "If non-empty, speficies a file containing ptx to use. "
+                       "The filename prefix must have the same pattern as PTX "
+                       "dumped by XLA. This allows to match one specific "
+                       "module. General workflow. Get the generated module "
+                       "ptx from XLA. Modify it. Then pass it back via this "
+                       "option."),
       tensorflow::Flag(
           "xla_test_all_output_layouts",
           bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
@@ -508,6 +521,13 @@
           bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
           flag_values->xla_gpu_force_conv_nchw(),
           "For cuDNN convolutions, always NCHW layouts."),
+      tensorflow::Flag(
+          "xla_gpu_cudnn_conv_blacklist_path",
+          string_setter_for(
+              &DebugOptions::set_xla_gpu_cudnn_conv_blacklist_path),
+          flag_values->xla_gpu_cudnn_conv_blacklist_path(),
+          "A CudnnConvolutionList text proto file as a blacklist of "
+          "convolutions to avoid to use."),
   });
   ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
 }
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 7bf48d5..054fccf 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -1385,7 +1385,7 @@
 
 The output is an array of rank `batch_dims.size` + `offset_dims.size`.
 
-The `operand.rank` must equal the sume of `offset_dims.size` and
+The `operand.rank` must equal the sum of `offset_dims.size` and
 `collapsed_slice_dims`. Also, `slice_sizes.size` has to be equal to
 `operand.rank`.
 
@@ -2532,6 +2532,11 @@
 :                 :                     : respective `start_indices` value for :
 :                 :                     : the dimension and less than or equal :
 :                 :                     : to the size of the dimension.        :
+| `strides`      | `ArraySlice<int64>` | List of N integers that decides the   |
+:                 :                     : input stride of the slice.  The slice :
+:                 :                     : picks every `strides[d]` element in  :
+:                 :                     : dimension `d`.                       :
+
 
 1-dimensional example:
 
diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h
index f216bd6..36e1ece 100644
--- a/tensorflow/compiler/xla/layout.h
+++ b/tensorflow/compiler/xla/layout.h
@@ -136,6 +136,7 @@
     Equal& MinorToMajorOnly() {
       ignore_tiles_ = true;
       ignore_element_size_ = true;
+      ignore_memory_space_ = true;
       return *this;
     }
 
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 63d9a1e..03b47ba 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -891,7 +891,7 @@
   }
 }
 
-StatusOr<int64> LiteralBase::GetIntegralAsS64(
+absl::optional<int64> LiteralBase::GetIntegralAsS64(
     absl::Span<const int64> multi_index) const {
   CHECK(LayoutUtil::IsDenseArray(shape()));
   switch (shape().element_type()) {
@@ -908,12 +908,11 @@
     case U64:
       return Get<uint64>(multi_index);
     default:
-      return FailedPrecondition("Array element type is not integral: %s",
-                                PrimitiveType_Name(shape().element_type()));
+      return absl::nullopt;
   }
 }
 
-StatusOr<double> LiteralBase::GetAsDouble(
+absl::optional<double> LiteralBase::GetAsDouble(
     absl::Span<const int64> multi_index) const {
   CHECK(LayoutUtil::IsDenseArray(shape()));
   switch (shape().element_type()) {
@@ -926,8 +925,27 @@
     case BF16:
       return static_cast<double>(Get<bfloat16>(multi_index));
     default:
-      return FailedPrecondition("Array element type is not floating: %s",
-                                PrimitiveType_Name(shape().element_type()));
+      return absl::nullopt;
+  }
+}
+
+absl::optional<complex128> LiteralBase::GetAsComplex128(
+    absl::Span<const int64> multi_index) const {
+  switch (shape().element_type()) {
+    case BF16:
+      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
+    case F16:
+      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
+    case F32:
+      return {{Get<float>(multi_index), 0}};
+    case F64:
+      return {{Get<double>(multi_index), 0}};
+    case C64:
+      return {Get<complex64>(multi_index)};
+    case C128:
+      return {Get<complex128>(multi_index)};
+    default:
+      return absl::nullopt;
   }
 }
 
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index ffd5a88..af15cab 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -130,13 +130,47 @@
   // value into text.
   string GetSparseElementAsString(int64 sparse_element_number,
                                   const ShapeIndex& shape_index = {}) const;
+
+  // Return whether the value at the specified index is equal to the provided
+  // generic `value` (T must be an arithmetic type).
+  //
+  // Precondition: must be an array.
+  template <typename T>
+  typename std::enable_if<(std::is_arithmetic<T>::value ||
+                           std::is_same<T, Eigen::half>::value ||
+                           std::is_same<T, bfloat16>::value),
+                          bool>::type
+  IsEqualAt(absl::Span<const int64> multi_index, T value) const {
+    if (auto as_s64 = GetIntegralAsS64(multi_index)) {
+      return *as_s64 == value;
+    }
+    complex128 as_complex128 = *GetAsComplex128(multi_index);
+    return as_complex128.imag() == 0 && as_complex128.real() == value;
+  }
+
+  bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
+    if (auto as_s64 = GetIntegralAsS64(multi_index)) {
+      return *as_s64 == value.real() && value.imag() == 0;
+    }
+    auto as_complex128 = GetAsComplex128(multi_index);
+    return *as_complex128 == value;
+  }
+
   // As Get(), but determines the correct type and converts the value into
   // int64.  This literal must be an array.
-  StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;
+  absl::optional<int64> GetIntegralAsS64(
+      absl::Span<const int64> multi_index) const;
 
   // As Get(), but determines the correct type, and converts the value into
   // double. This literal must be an array.
-  StatusOr<double> GetAsDouble(absl::Span<const int64> multi_index) const;
+  absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
+
+  // As Get(), but determines the correct type, and converts the value into
+  // complex128. All floating point types can be converted into complex128.
+  //
+  // This literal must be an array.
+  absl::optional<complex128> GetAsComplex128(
+      absl::Span<const int64> multi_index) const;
 
   // Returns the multi-index of the element in a sparse literal at the given
   // sparse element number.  The sparse element number is the position with in
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 8d46d30..885d18d 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -2021,5 +2021,46 @@
             LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
 }
 
+TEST_F(LiteralUtilTest, GetAsComplex128) {
+  complex128 value = {1, 0};
+  Literal c1 = LiteralUtil::CreateR0<complex128>(value);
+  EXPECT_EQ(*c1.GetAsComplex128({}), value);
+  Literal c2 = LiteralUtil::CreateR0<double>(1);
+  EXPECT_EQ(*c2.GetAsComplex128({}), value);
+  complex64 float_value = {1, 0};
+  Literal c4 = LiteralUtil::CreateR0<complex64>(float_value);
+  EXPECT_EQ(*c4.GetAsComplex128({}), value);
+  complex128 other_value = {1, 2};
+  Literal c5 = LiteralUtil::CreateR0<complex128>(other_value);
+  EXPECT_EQ(*c5.GetAsComplex128({}), other_value);
+  Literal c6 = LiteralUtil::CreateR0<int64>(1);
+  EXPECT_FALSE(c6.GetAsComplex128({}).has_value());
+}
+
+TEST_F(LiteralUtilTest, IsEqualAt) {
+  double val_double = 10.0;
+  int val_integral = 10;
+  Literal c1 = LiteralUtil::CreateR0<int>(10);
+  EXPECT_TRUE(c1.IsEqualAt({}, val_double));
+  EXPECT_TRUE(c1.IsEqualAt({}, val_integral));
+  Literal c2 = LiteralUtil::CreateR0<double>(10);
+  EXPECT_TRUE(c2.IsEqualAt({}, val_double));
+  EXPECT_TRUE(c2.IsEqualAt({}, val_integral));
+  complex128 val_complex = {10, 0};
+  EXPECT_TRUE(c2.IsEqualAt({}, val_complex));
+  EXPECT_TRUE(c1.IsEqualAt({}, val_complex));
+  Literal c3 = LiteralUtil::CreateR0<complex128>(val_complex);
+  EXPECT_TRUE(c3.IsEqualAt({}, val_double));
+  EXPECT_TRUE(c3.IsEqualAt({}, val_integral));
+  EXPECT_TRUE(c3.IsEqualAt({}, val_complex));
+  double val_inf = 1. / 0;
+  EXPECT_FALSE(c3.IsEqualAt({}, val_inf));
+  complex128 val_true_complex = {10, 3};
+  complex64 val_smaller_complex = {10, 3};
+  Literal c4 = LiteralUtil::CreateR0<complex128>(val_true_complex);
+  EXPECT_TRUE(c4.IsEqualAt({}, val_true_complex));
+  EXPECT_TRUE(c4.IsEqualAt({}, val_smaller_complex));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index e476015..b7c3053 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -39,12 +39,17 @@
 }
 
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
-                            const string& directory, const string& file_name) {
+                            const string& directory, const string& file_name,
+                            string* full_path) {
   tensorflow::Env* env = tensorflow::Env::Default();
   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
   string safe_file_name = SanitizeFileName(file_name) + ".pb";
-  const string path = tensorflow::io::JoinPath(directory, safe_file_name);
-  return tensorflow::WriteBinaryProto(env, path, message);
+  string full_path_impl;
+  if (!full_path) {
+    full_path = &full_path_impl;
+  }
+  *full_path = tensorflow::io::JoinPath(directory, safe_file_name);
+  return tensorflow::WriteBinaryProto(env, *full_path, message);
 }
 
 }  // namespace protobuf_util
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index e20a7e9..7db0209 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -37,8 +37,12 @@
 // 'directory/file_name.pb'. The 'directory' is recursively created if it
 // doesn't already exist, and the 'file_name' is sanitized by replacing
 // illegal characters with underscore '_'.
+//
+// If 'full_name' is not null then it is set to the name of the file the
+// protobuf was written to.
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
-                            const string& directory, const string& file_name);
+                            const string& directory, const string& file_name,
+                            string* full_path = nullptr);
 
 // Registers a function that may either expand a dirpath or forward the original
 // dirpath along as-is.
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index b287767..696aa98 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -1,7 +1,7 @@
-load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library")
-load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins")
-load("//tensorflow:tensorflow.bzl", "tf_pybind_extension")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core/platform:default/build_config.bzl", "pyx_library")
+load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps", "xla_python_default_plugins")
+load("//tensorflow:tensorflow.bzl", "pybind_extension")
+load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test")
 
 package(
     default_visibility = ["//tensorflow:internal"],
@@ -29,15 +29,14 @@
     name = "xla_client_test",
     srcs = ["xla_client_test.py"],
     main = "xla_client_test.py",
-    python_version = "PY2",
     srcs_version = "PY2AND3",
-    tags = ["no_oss"],
+    tags = ["no_oss"],  # TODO(phawkins): This test passes, but requires --config=monolithic.
     deps = [
         ":custom_call_for_test",
         ":xla_client",
-        "//tensorflow/compiler/xla:xla_data_proto_py",
-        "//tensorflow/python:platform_test",
-    ],
+        ":xla_extension",
+        "@absl_py//absl/testing:absltest",
+    ] + xla_py_test_deps(),
 )
 
 cc_library(
@@ -240,7 +239,7 @@
     ],
 )
 
-tf_pybind_extension(
+pybind_extension(
     name = "xla_extension",
     srcs = [
         "xla.cc",
diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx
index 530dffd..4f7c4c3 100644
--- a/tensorflow/compiler/xla/python/custom_call_for_test.pyx
+++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx
@@ -15,7 +15,7 @@
 cpu_custom_call_targets = {}
 
 cdef register_custom_call_target(fn_name, void* fn):
-  cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET"
+  cdef const char* name = "xla._CUSTOM_CALL_TARGET"
   cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
 
 register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
diff --git a/tensorflow/compiler/xla/python/device.cc b/tensorflow/compiler/xla/python/device.cc
index 73df698..27af9ad 100644
--- a/tensorflow/compiler/xla/python/device.cc
+++ b/tensorflow/compiler/xla/python/device.cc
@@ -64,6 +64,7 @@
   // stopped, also block on the compute stream. If SynchronizeAllActivity is
   // fixed, we could remove the BlockHostUntilDone call.
   status.Update(compute_stream_->BlockHostUntilDone());
+  status.Update(callback_stream_->BlockHostUntilDone());
   bool ok = compute_stream_->parent()->SynchronizeAllActivity();
   if (!ok) {
     status.Update(Unknown("SynchronizeAllActivity failed."));
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index 7496d53..a0ca85e 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -212,6 +212,10 @@
     return executable_->build_options().num_replicas();
   }
 
+  int64 SizeOfGeneratedCodeInBytes() const {
+    return executable_->executable()->SizeOfGeneratedCodeInBytes();
+  }
+
   // Returns the device ordinals to which each replica is assigned.
   std::vector<int> DeviceOrdinals() const;
 
diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h
index bc0ee2b..1873249 100644
--- a/tensorflow/compiler/xla/python/types.h
+++ b/tensorflow/compiler/xla/python/types.h
@@ -104,7 +104,7 @@
   using value_conv = make_caster<T>;
 
   PYBIND11_TYPE_CASTER(absl::Span<const T>,
-                       _("Span[") + value_conv::name() + _("]"));
+                       _("Span[") + value_conv::name + _("]"));
 
   // absl::Span doesn't hold ownership. We therefore need a temporary array.
   // Pybind appears to keep type_casters alive until the callee has run.
@@ -151,7 +151,7 @@
   using value_conv = make_caster<T>;
 
   PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
-                       _("StatusOr[") + value_conv::name() + _("]"));
+                       _("StatusOr[") + value_conv::name + _("]"));
 
   static handle cast(xla::StatusOr<T> src, return_value_policy policy,
                      handle parent) {
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index d8a4aaa..9f9209f 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -110,18 +110,23 @@
 }
 
 // Registers a 'fn_capsule' as a CPU custom call target.
-// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
-// "xla._CPU_CUSTOM_CALL_TARGET".
-Status RegisterCpuCustomCallTarget(const std::string& fn_name,
-                                   py::capsule capsule) {
-  static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
-  if (absl::string_view(capsule.name()) != kName) {
+// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
+// with name "xla._CUSTOM_CALL_TARGET".
+// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
+Status PyRegisterCustomCallTarget(const std::string& fn_name,
+                                  py::capsule capsule,
+                                  const std::string& platform) {
+  static const char* const kName = "xla._CUSTOM_CALL_TARGET";
+  // TODO(phawkins): remove old name after fixing users.
+  static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
+  if (absl::string_view(capsule.name()) != kName &&
+      absl::string_view(capsule.name()) != kOldCpuName) {
     return InvalidArgument(
-        "Argument to RegisterCpuCustomCallTargetRegistry was not a "
-        "xla._CPU_CUSTOM_CALL_TARGET capsule.");
+        "Argument to RegisterCustomCallTargetRegistry was not a "
+        "xla._CUSTOM_CALL_TARGET capsule.");
   }
   CustomCallTargetRegistry::Global()->Register(
-      fn_name, static_cast<void*>(capsule), "Host");
+      fn_name, static_cast<void*>(capsule), platform);
   return Status::OK();
 }
 
@@ -295,8 +300,8 @@
 
   // Local XLA client methods.
 
-  // CPU custom-call targets.
-  m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
+  // Custom-call targets.
+  m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
 
   py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
   alloc_config.def(py::init<>())
@@ -407,6 +412,8 @@
       .def_static("Compile", &PyLocalExecutable::Compile,
                   py::call_guard<py::gil_scoped_release>())
       .def("DeviceOrdinals", &PyLocalExecutable::DeviceOrdinals)
+      .def("SizeOfGeneratedCodeInBytes",
+           &PyLocalExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyLocalExecutable::Delete)
       .def("Execute", &PyLocalExecutable::Execute,
            py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
@@ -425,7 +432,10 @@
                     &DebugOptions::set_xla_cpu_fast_math_honor_nans)
       .def_property("xla_cpu_fast_math_honor_division",
                     &DebugOptions::xla_cpu_fast_math_honor_division,
-                    &DebugOptions::set_xla_cpu_fast_math_honor_division);
+                    &DebugOptions::set_xla_cpu_fast_math_honor_division)
+      .def_property("xla_gpu_enable_fast_min_max",
+                    &DebugOptions::xla_gpu_enable_fast_min_max,
+                    &DebugOptions::set_xla_gpu_enable_fast_min_max);
 
   py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
       .def(py::init<>())
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 7e5692f..3ef28b6 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -109,15 +109,24 @@
     options.debug_options.xla_cpu_fast_math_honor_infs = True
     options.debug_options.xla_cpu_fast_math_honor_nans = True
     options.debug_options.xla_cpu_fast_math_honor_division = True
+    options.debug_options.xla_gpu_enable_fast_min_max = False
     return _xla.LocalExecutable.Compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
                                         compile_options.device_assignment)
 
 
+xla_platform_names = {
+    'cpu': 'Host',
+    'gpu': 'CUDA',
+}
+
+
 def _cpu_backend_factory():
   client = _xla.LocalClient.Get(
-      platform='cpu', xla_platform_id='Host', asynchronous=True)
+      platform='cpu',
+      xla_platform_id=xla_platform_names['cpu'],
+      asynchronous=True)
   return LocalBackend(platform='cpu', client=client)
 
 
@@ -142,7 +151,9 @@
   config.preallocate = preallocate not in ('0', 'false', 'False')
 
   client = _xla.LocalClient.Get(
-      platform='gpu', xla_platform_id='CUDA', asynchronous=True,
+      platform='gpu',
+      xla_platform_id=xla_platform_names['gpu'],
+      asynchronous=True,
       allocator_config=config)
   return LocalBackend(platform='gpu', client=client)
 
@@ -544,6 +555,9 @@
 #   def Execute(self, arguments : [Buffer]) -> Buffer:
 #     """Execute on one replica with Buffer arguments and return value."""
 #
+#   def SizeOfGeneratedCodeInBytes(self) -> int:
+#     """Return generated binary size, or -1 if not known."""
+#
 #   def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]:
 #     """Execute on many replicas with Buffer arguments and return value.
 #
@@ -1592,14 +1606,18 @@
 _forward_methods_to_local_builder()
 
 
-def register_cpu_custom_call_target(name, fn):
-  """Registers a CPU custom call target.
+def register_custom_call_target(name, fn, platform='cpu'):
+  """Registers a custom call target.
 
   Args:
     name: bytes containing the name of the function.
     fn: a PyCapsule object containing the function pointer.
+    platform: the target platform.
   """
-  _xla.RegisterCpuCustomCallTarget(name, fn)
+  _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
+
+# Deprecated. Use register_custom_call_target instead.
+register_cpu_custom_call_target = register_custom_call_target
 
 
 class PaddingConfigDimension(object):
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 16c1d42..ac15bc8 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -22,14 +22,14 @@
 import itertools
 import threading
 
+from absl.testing import absltest
 import numpy as np
 
 from tensorflow.compiler.xla.python import custom_call_for_test
 from tensorflow.compiler.xla.python import xla_client
-import unittest
 
 
-class ComputationTest(unittest.TestCase):
+class ComputationTest(absltest.TestCase):
   """Base class for running an XLA Computation through the local client."""
 
   def _NewComputation(self, name=None):
@@ -89,7 +89,7 @@
   return np.array(*args, dtype=np.bool, **kwargs)
 
 
-class ComputationPrinting(unittest.TestCase):
+class ComputationPrinting(absltest.TestCase):
 
   def ExampleComputation(self):
     builder = xla_client.ComputationBuilder("acomputation")
@@ -311,7 +311,7 @@
   def testCustomCall(self):
     c = self._NewComputation()
     for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
-      xla_client.register_cpu_custom_call_target(name, fn)
+      xla_client.register_custom_call_target(name, fn, platform="cpu")
     c.CustomCall(
         b"test_subtract_f32",
         operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
@@ -448,14 +448,14 @@
     local_buffer = xla_client.Buffer.from_pyval(t)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
-    self.assertEqual(len(pieces), 0)
+    self.assertEmpty(pieces)
 
   def testDestructureTupleOneArrayElement(self):
     t = (np.array([1, 2, 3, 4], dtype=np.int32),)
     local_buffer = xla_client.Buffer.from_pyval(t)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
-    self.assertEqual(len(pieces), 1)
+    self.assertLen(pieces, 1)
     array = pieces[0]
     got = array.to_py()
     want = NumpyArrayS32([1, 2, 3, 4])
@@ -472,7 +472,7 @@
     for _ in range(2):
       pieces = local_buffer.destructure()
       self.assertFalse(local_buffer.is_deleted())
-      self.assertEqual(len(pieces), 2)
+      self.assertLen(pieces, 2)
       array0, array1 = pieces
       got = array0.to_py()
       want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0])
@@ -486,14 +486,14 @@
     local_buffer = xla_client.Buffer.from_pyval(t)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
-    self.assertEqual(len(pieces), 2)
+    self.assertLen(pieces, 2)
     tuple0, array1 = pieces
     got = array1.to_py()
     want = NumpyArrayS32([5])
     np.testing.assert_equal(want, got)
     got = tuple0.to_py()
     self.assertEqual(type(got), tuple)
-    self.assertEqual(len(got), 2)
+    self.assertLen(got, 2)
     np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
     np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
 
@@ -506,7 +506,7 @@
     b1 = xla_client.Buffer.from_pyval(t[1])
     btup = xla_client.Buffer.make_tuple([b0, b1], device=0)
     pieces = btup.destructure()
-    self.assertEqual(len(pieces), 2)
+    self.assertLen(pieces, 2)
     array0, array1 = pieces
     np.testing.assert_equal(
         np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py())
@@ -699,7 +699,7 @@
     rhs = NumpyArrayF32(rng.randn(10, 4, 5))
     dimension_numbers = (([2], [1]), ([0], [0]))
     c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
-    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
 
   def testDotGeneralWithDotDimensionNumbersProto(self):
     c = self._NewComputation()
@@ -714,7 +714,7 @@
     dimension_numbers.rhs_batch_dimensions.append(0)
 
     c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
-    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
 
   def testDotGeneralWithPrecisionConfig(self):
     c = self._NewComputation()
@@ -730,7 +730,7 @@
         c.Constant(rhs),
         dimension_numbers,
         precision_config=config)
-    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
 
   def testConvF32Same(self):
     c = self._NewComputation()
@@ -1222,7 +1222,7 @@
     result = xla_client.execute_with_python_values(c.Build().Compile())
     # since the result is random, we just check shape and uniqueness
     self.assertEqual(result.shape, shape)
-    self.assertEqual(len(np.unique(result)), np.prod(shape))
+    self.assertLen(np.unique(result), np.prod(shape))
 
   def testRngUniformF32(self):
     lo, hi = 2., 4.
@@ -1235,7 +1235,7 @@
     result = xla_client.execute_with_python_values(c.Build().Compile())
     # since the result is random, we just check shape, uniqueness, and range
     self.assertEqual(result.shape, shape)
-    self.assertEqual(len(np.unique(result)), np.prod(shape))
+    self.assertLen(np.unique(result), np.prod(shape))
     self.assertTrue(np.all(lo <= result))
     self.assertTrue(np.all(result < hi))
 
@@ -1923,4 +1923,4 @@
 
 
 if __name__ == "__main__":
-  unittest.main()
+  absltest.main()
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f34572b..ac67497 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4,10 +4,18 @@
 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_py",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
+    "if_cuda_is_configured",
+)
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm_is_configured",
+)
 
 package(
     default_visibility = [":friends"],
@@ -290,6 +298,64 @@
     ],
 )
 
+cc_library(
+    name = "hlo_live_range",
+    srcs = [
+        "hlo_live_range.cc",
+    ],
+    hdrs = [
+        "hlo_live_range.h",
+    ],
+    deps = [
+        ":hlo",
+        ":hlo_alias_analysis",
+        ":hlo_buffer",
+        ":hlo_dataflow_analysis",
+        ":hlo_ordering",
+        ":logical_buffer",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+)
+
+tf_cc_test(
+    name = "hlo_live_range_test",
+    srcs = ["hlo_live_range_test.cc"],
+    deps = [
+        ":call_graph",
+        ":hlo",
+        ":hlo_alias_analysis",
+        ":hlo_live_range",
+        ":hlo_memory_scheduler",
+        ":hlo_ordering",
+        ":hlo_parser",
+        ":hlo_value",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+)
+
 tf_cc_test(
     name = "hlo_evaluator_test",
     srcs = ["hlo_evaluator_test.cc"],
@@ -862,11 +928,14 @@
     name = "gpu_plugin",
     deps = [
         ":service",
+        "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
         "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
-        "//tensorflow/compiler/xla/service/gpu:nvptx_compiler",
         "//tensorflow/core:stream_executor_no_cuda",
+    ] + if_cuda_is_configured([
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
-    ],
+    ]) + if_rocm_is_configured([
+        "//tensorflow/core/platform/default/build_config:stream_executor_rocm",
+    ]),
 )
 
 cc_library(
@@ -952,6 +1021,7 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor:device_description",
         "//tensorflow/stream_executor:device_memory_allocator",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings:str_format",
@@ -1113,8 +1183,10 @@
         ":hlo_alias_analysis",
         ":hlo_buffer",
         ":hlo_dataflow_analysis",
+        ":hlo_live_range",
         ":hlo_proto",
         ":logical_buffer",
+        ":memory_space_assignment",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -1215,6 +1287,7 @@
         ":hlo_alias_analysis",
         ":hlo_buffer",
         ":hlo_dataflow_analysis",
+        ":hlo_live_range",
         ":hlo_ordering",
         ":hlo_proto",
         ":tuple_points_to_analysis",
@@ -1681,6 +1754,7 @@
         ":hlo",
         ":hlo_casting_utils",
         ":hlo_creation_utils",
+        ":hlo_evaluator",
         ":hlo_pass",
         ":hlo_query",
         ":pattern_matcher",
@@ -1694,6 +1768,7 @@
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -4247,3 +4322,18 @@
         "//tensorflow/compiler/xla/client/lib:prng",
     ],
 )
+
+cc_library(
+    name = "slow_operation_alarm",
+    srcs = ["slow_operation_alarm.cc"],
+    hdrs = ["slow_operation_alarm.h"],
+    deps = [
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 2025cb0..a3e107e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -60,6 +61,7 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -170,6 +172,10 @@
 // more general case a worklist based approach would be needed.
 class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
  public:
+  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
+                                      AlgebraicSimplifier* simplifier)
+      : options_(options), simplifier_(simplifier) {}
+
   Status HandleAdd(HloInstruction* add) override;
 
   Status HandleAnd(HloInstruction* logical_and) override;
@@ -208,6 +214,10 @@
 
   Status HandleLog(HloInstruction* log) override;
 
+  Status HandleMaximum(HloInstruction* maximum) override;
+
+  Status HandleMinimum(HloInstruction* minimum) override;
+
   Status HandleMultiply(HloInstruction* multiply) override;
 
   Status HandleNegate(HloInstruction* negate) override;
@@ -224,7 +234,7 @@
 
   Status HandleReshape(HloInstruction* reshape) override;
 
-  Status HandleReduce(HloInstruction* reduce) override;
+  Status HandleReduce(HloInstruction* hlo) override;
 
   Status HandleReduceWindow(HloInstruction* reduce_window) override;
 
@@ -246,16 +256,11 @@
   Status HandleMap(HloInstruction* map) override;
 
   // Runs the visitor on a computation.
-  static bool Run(HloComputation* computation,
-                  const AlgebraicSimplifierOptions& options,
-                  AlgebraicSimplifier* simplifier);
+  bool Run(HloComputation* computation,
+           const AlgebraicSimplifierOptions& options,
+           AlgebraicSimplifier* simplifier);
 
  private:
-  explicit AlgebraicSimplifierVisitor(HloComputation* computation,
-                                      const AlgebraicSimplifierOptions& options,
-                                      AlgebraicSimplifier* simplifier)
-      : computation_(computation), options_(options), simplifier_(simplifier) {}
-
   // Removes degenerate dimension from dot.
   StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);
 
@@ -385,6 +390,9 @@
   // Tries to convert slice(reshape(X)) into reshape(slice(X))
   StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
 
+  // Useful when we want to use the same visitor over multiple computations.
+  void ResetState(HloComputation* computation);
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -403,12 +411,18 @@
 
 }  // namespace
 
+void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
+  changed_ = false;
+  ResetVisitStates();
+  computation_ = computation;
+}
+
 bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                      const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier) {
-  AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
-  TF_CHECK_OK(computation->Accept(&visitor));
-  return visitor.changed_ || visitor.changed();
+  ResetState(computation);
+  TF_CHECK_OK(computation->Accept(this));
+  return changed_ || changed();
 }
 
 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
@@ -1874,6 +1888,98 @@
   return Status::OK();
 }
 
+namespace {
+StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
+    HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
+    HloInstruction* clamp_upper_bound_bcast) {
+  HloInstruction* clamp_lower_bound;
+  CHECK(Match(clamp_lower_bound_bcast,
+              m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
+      << clamp_lower_bound_bcast->ToString();
+
+  HloInstruction* clamp_upper_bound;
+  CHECK(Match(clamp_upper_bound_bcast,
+              m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
+      << clamp_upper_bound_bcast->ToString();
+
+  const Literal& lower_bound =
+      Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
+  const Literal& upper_bound =
+      Cast<HloConstantInstruction>(clamp_upper_bound)->literal();
+
+  std::unique_ptr<HloInstruction> lower_bound_instr =
+      HloInstruction::CreateConstant(lower_bound.Clone());
+  std::unique_ptr<HloInstruction> upper_bound_instr =
+      HloInstruction::CreateConstant(upper_bound.Clone());
+
+  std::unique_ptr<HloInstruction> cloned_instruction =
+      HloInstruction::CreateCompare(
+          ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED),
+          lower_bound_instr.get(), upper_bound_instr.get(),
+          ComparisonDirection::kLt);
+
+  HloEvaluator evaluator;
+  TF_ASSIGN_OR_RETURN(auto result,
+                      evaluator.Evaluate(cloned_instruction.get()));
+  if (result.IsAll(true)) {
+    return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
+                                         clamp_lower_bound_bcast, to_clamp,
+                                         clamp_upper_bound_bcast);
+  }
+  return std::unique_ptr<HloInstruction>();
+}
+}  // namespace
+
+Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
+  HloInstruction *lhs, *rhs;
+  CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));
+
+  HloInstruction* clamp_upper_bound_bcast;
+  HloInstruction* clamp_lower_bound_bcast;
+  HloInstruction* to_clamp;
+  if (Match(maximum, m::MaximumAnyOrder(
+                         m::Broadcast(&clamp_lower_bound_bcast,
+                                      m::ConstantEffectiveScalar()),
+                         m::MinimumAnyOrder(
+                             m::Op(&to_clamp),
+                             m::Broadcast(&clamp_upper_bound_bcast,
+                                          m::ConstantEffectiveScalar()))))) {
+    TF_ASSIGN_OR_RETURN(auto clamp,
+                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
+                                      clamp_upper_bound_bcast));
+    if (clamp) {
+      return ReplaceWithNewInstruction(maximum, std::move(clamp));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
+  HloInstruction *lhs, *rhs;
+  CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));
+
+  HloInstruction* clamp_upper_bound_bcast;
+  HloInstruction* clamp_lower_bound_bcast;
+  HloInstruction* to_clamp;
+  if (Match(minimum, m::MinimumAnyOrder(
+                         m::Broadcast(&clamp_upper_bound_bcast,
+                                      m::ConstantEffectiveScalar()),
+                         m::MaximumAnyOrder(
+                             m::Op(&to_clamp),
+                             m::Broadcast(&clamp_lower_bound_bcast,
+                                          m::ConstantEffectiveScalar()))))) {
+    TF_ASSIGN_OR_RETURN(auto clamp,
+                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
+                                      clamp_upper_bound_bcast));
+    if (clamp) {
+      return ReplaceWithNewInstruction(minimum, std::move(clamp));
+    }
+  }
+
+  return Status::OK();
+}
+
 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
   HloInstruction *lhs, *rhs;
   CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
@@ -2384,9 +2490,11 @@
     TF_ASSIGN_OR_RETURN(
         HloInstruction * slice,
         MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        pad->shape(), slice->mutable_shape()));
 
     // Verify that the slice shape matches the pad shape.
-    TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape()));
+    TF_RET_CHECK(ShapeUtil::Equal(slice->shape(), pad->shape()));
 
     return ReplaceInstruction(pad, slice);
   }
@@ -2698,9 +2806,9 @@
     // this.  But that's OK for our purposes here.)
     int64 iota_upper_bound = iota->shape().dimensions(
         Cast<HloIotaInstruction>(iota)->iota_dimension());
-    StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
+    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
         std::vector<int64>(0, divisor->shape().dimensions_size()));
-    if (divisor_val.ok() && divisor_val.ValueOrDie() >= iota_upper_bound) {
+    if (divisor_val && *divisor_val >= iota_upper_bound) {
       return ReplaceInstruction(remainder, iota);
     }
   }
@@ -2726,12 +2834,12 @@
     // smaller.
     int64 iota_upper_bound = iota->shape().dimensions(
         Cast<HloIotaInstruction>(iota)->iota_dimension());
-    StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
+    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
         std::vector<int64>(0, divisor->shape().dimensions_size()));
-    if (divisor_val.ok()) {
+    if (divisor_val) {
       // Check whether divisor_val + iota_upper_bound - 1 overflows.
       absl::optional<int64> max_val =
-          OverflowSafeAdd(divisor_val.ValueOrDie(), iota_upper_bound);
+          OverflowSafeAdd(*divisor_val, iota_upper_bound);
       if (max_val.has_value() &&
           FitsInIntegralType(*max_val, iota->shape().element_type())) {
         return ReplaceWithNewInstruction(
@@ -3945,8 +4053,9 @@
   XLA_VLOG_LINES(2,
                  "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
   bool changed = false;
+  AlgebraicSimplifierVisitor visitor(options_, this);
   for (auto* comp : module->MakeNonfusionComputations()) {
-    if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) {
+    if (visitor.Run(comp, options_, this)) {
       changed = true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 4c5e5ef..a3282b9 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -5543,5 +5543,67 @@
               GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
 }
 
+TEST_F(AlgebraicSimplifierTest, SlicePadLayout) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      %param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0)
+      %param.1 = f32[] parameter(1)
+      %slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0),
+        slice={[0:128], [0:9], [0:9], [0:1024]}
+      ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1),
+        padding=0_0x-1_0x0_0x0_0
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  const Shape root_shape = m->entry_computation()->root_instruction()->shape();
+  AlgebraicSimplifierOptions options;
+  options.set_is_layout_sensitive(true);
+  ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Slice().WithShapeEqualTo(&root_shape)));
+}
+
+TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = f32[4] parameter(0)
+      c0 = f32[] constant(3.0)
+      c1 = f32[] constant(4.0)
+      b0 = f32[4] broadcast(c0), dimensions={}
+      b1 = f32[4] broadcast(c1), dimensions={}
+      m0 = f32[4] maximum(b0, p0)
+      ROOT m1 = f32[4] minimum(m0, b1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)), m::Parameter(0),
+                          m::Broadcast(m::ConstantScalar(4.0)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = f32[4] parameter(0)
+      c0 = f32[] constant(3.0)
+      c1 = f32[] constant(4.0)
+      b0 = f32[4] broadcast(c0), dimensions={}
+      b1 = f32[4] broadcast(c1), dimensions={}
+      m0 = f32[4] minimum(p0, b1)
+      ROOT m1 = f32[4] maximum(b0, m0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)), m::Parameter(0),
+                          m::Broadcast(m::ConstantScalar(4.0)))));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 147f3ae..9c19308 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -29,9 +29,9 @@
 class BatchNormExpander : public HloModulePass {
  public:
   // When use_fusion is set, a multi-output fusion node is created.
-  BatchNormExpander(bool rewrite_training_op = false,
-                    bool rewrite_inference_op = false,
-                    bool rewrite_grad_op = false)
+  explicit BatchNormExpander(bool rewrite_training_op = false,
+                             bool rewrite_inference_op = false,
+                             bool rewrite_grad_op = false)
       : rewrite_training_op_(rewrite_training_op),
         rewrite_inference_op_(rewrite_inference_op),
         rewrite_grad_op_(rewrite_grad_op) {}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 3ae7235..d72a91f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -33,6 +33,7 @@
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -233,8 +234,8 @@
 
 void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
                                      int64 size) {
-  VLOG(4) << "Adding the following buffer to allocation #" << index() << ": "
-          << buffer;
+  VLOG(4) << "Adding the following buffer to allocation #" << index() << " ["
+          << offset << ", " << size << "]: " << buffer;
   CHECK(!assigned_buffers_.contains(&buffer))
       << "LogicalBuffer " << buffer << " already assigned to allocation "
       << index_;
@@ -250,6 +251,13 @@
   offset_size.offset = offset;
   offset_size.size = size;
   assigned_buffers_.emplace(&buffer, offset_size);
+  // For debugging purposes, store the assigned memory space in the
+  // instruction's layout.
+  HloInstruction* defining_instruction = buffer.defining_instruction();
+  if (defining_instruction->shape().has_layout()) {
+    defining_instruction->mutable_shape()->mutable_layout()->set_memory_space(
+        buffer.color().value());
+  }
 }
 
 BufferAllocationProto BufferAllocation::ToProto() const {
@@ -758,14 +766,69 @@
     LogicalBuffer::AlignmentFunction color_alignment,
     bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
     const absl::flat_hash_set<HloOpcode>& reuse_checker,
-    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
+    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
+    std::unique_ptr<PresetAssignments> preset_assignments) {
   BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
-                          reuse_checker);
+                          reuse_checker, std::move(preset_assignments));
   return assigner.CreateAssignment(
       module, std::move(hlo_ordering), std::move(buffer_size),
       std::move(color_alignment), std::move(can_share_buffer));
 }
 
+bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
+                                         const HloValue* buffer2,
+                                         BufferAssignment* assignment) {
+  CHECK((assignment->hlo_live_range().total_order_scheduled()));
+  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();
+
+  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();
+
+  CHECK(buffer_live_ranges.contains(buffer1))
+      << "Buffer doesn't have a proper live range:" << buffer1;
+
+  CHECK(buffer_live_ranges.contains(buffer2))
+      << "Buffer doesn't have a proper live range:" << buffer2;
+
+  // Check if a user value can share the same buffer as its operand.
+  auto can_share_as_operand = [&assignment](const HloValue* user_value,
+                                            const HloValue* operand_value) {
+    return user_value->instruction()->IsUserOf(operand_value->instruction()) &&
+           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
+               operand_value->instruction(), operand_value->index(),
+               user_value->instruction(), user_value->index()) &&
+           user_value->instruction()->opcode() != HloOpcode::kCopy;
+  };
+
+  auto live_range_1 = buffer_live_ranges.at(buffer1);
+  auto live_range_2 = buffer_live_ranges.at(buffer2);
+
+  if (!(live_range_1.start > live_range_2.end ||
+        live_range_2.start > live_range_1.end)) {
+    if (live_range_1.end == live_range_2.start) {
+      auto operand_value = buffer1;
+      auto user_value = buffer2;
+      if (!can_share_as_operand(user_value, operand_value)) {
+        return true;
+      }
+    } else if (live_range_2.end == live_range_1.start) {
+      auto operand_value = buffer2;
+      auto user_value = buffer1;
+      if (!can_share_as_operand(user_value, operand_value)) {
+        return true;
+      }
+    } else {
+      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
+              << *buffer2;
+      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
+      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
+      VLOG(4) << "live_range_2.start" << live_range_2.start;
+      VLOG(4) << "live_range_2.end" << live_range_2.end;
+      return true;
+    }
+  }
+  return false;
+}
+
 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                        const HloBuffer& hlo_buffer,
                                        BufferAssignment* assignment) {
@@ -777,7 +840,7 @@
           << " to allocation: " << *allocation;
 
   if (hlo_buffer.color() != allocation->color()) {
-    VLOG(4) << "Can't assign: buffer has color" << hlo_buffer.color()
+    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
             << " and allocation has color " << allocation->color() << ".";
     return false;
   }
@@ -833,10 +896,17 @@
     const HloValue& assigned_buffer =
         *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
     for (const HloValue* new_value : hlo_buffer.values()) {
-      if (assignment->hlo_ordering().MayInterfere(
-              assigned_buffer, *new_value, assignment->dataflow_analysis())) {
+      if (assignment->hlo_live_range().total_order_scheduled()) {
+        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
+          return false;
+        }
+      } else if (assignment->hlo_ordering().MayInterfere(
+                     assigned_buffer, *new_value,
+                     assignment->dataflow_analysis())) {
+        // Fallback to partial order based interference detection (slower) when
+        // we don't have a total order scheduled module.
         VLOG(4) << "Can't assign: assignee " << assigned_buffer
-                << " may interfere with " << new_value;
+                << " may interfere with " << new_value->ToShortString();
         return false;
       }
 
@@ -847,7 +917,8 @@
                 assigned_buffer_position.instruction) &&
             new_value->instruction()->opcode() == HloOpcode::kCopy) {
           VLOG(4) << "Can't assign: assignee " << assigned_buffer
-                  << " is used at copy instruction " << new_value;
+                  << " is used at copy instruction "
+                  << new_value->ToShortString();
           return false;
         }
       }
@@ -1094,8 +1165,20 @@
   }
   std::vector<const HloBuffer*> sorted_buffers;
 
+  // First assign the preset allocations.
+  absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;
+
+  TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));
+
   const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
+
   for (const HloBuffer& buffer : alias_analysis.buffers()) {
+    // Skip if the buffer is already assigned since it had a preset allocation.
+    if (preset_assigned_buffers.find(&buffer) !=
+        preset_assigned_buffers.end()) {
+      VLOG(3) << "Skip allocation for buffer: " << buffer;
+      continue;
+    }
     TF_RET_CHECK(!buffer.values().empty());
     const HloComputation* comp = buffer.values()[0]->instruction()->parent();
     if (absl::c_linear_search(computations, comp)) {
@@ -1124,9 +1207,12 @@
     }
   }
 
+  HloSchedule schedule(&assignment->module());
+
   for (const HloComputation* computation : computations) {
-    const bool has_sequential_order =
-        assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
+    const HloInstructionSequence* instruction_sequence =
+        assignment->hlo_ordering().SequentialOrder(*computation);
+    const bool has_sequential_order = instruction_sequence != nullptr;
     if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
       // Every sequential computation must get an entry in the
       // buffers_to_assign_sequentially map, even if we end up with an empty
@@ -1134,6 +1220,8 @@
       // run whole-module heap simulation.
       buffers_to_assign_sequentially->emplace(computation,
                                               flat_hash_set<const HloValue*>());
+
+      schedule.set_sequence(computation, *instruction_sequence);
     }
   }
 
@@ -1188,6 +1276,54 @@
   return color_map;
 }
 
+Status BufferAssigner::AssignPresetBuffers(
+    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
+    BufferAssignment* assignment) {
+  if (!preset_assignments_) {
+    return Status::OK();
+  }
+
+  // Create an allocation for each preset color.
+  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
+                      LogicalBuffer::Color::Hasher>
+      preset_allocations;
+  for (auto& color_and_size : preset_assignments_->sizes()) {
+    LogicalBuffer::Color color(color_and_size.first);
+    auto inserted = preset_allocations.emplace(
+        color, assignment->NewEmptyAllocation(color_and_size.second, color));
+    BufferAllocation* inserted_allocation = inserted.first->second;
+    VLOG(3) << "Created preset buffer allocation "
+            << inserted_allocation->index()
+            << ", color: " << inserted_allocation->color()
+            << ", size: " << inserted_allocation->size();
+  }
+
+  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
+
+  for (auto& position_and_chunk : preset_assignments_->chunks()) {
+    const HloPosition& position = position_and_chunk.first;
+    const HloBuffer& buffer =
+        alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
+    VLOG(3) << "Preset allocation for buffer: " << buffer;
+    const HeapSimulator::Chunk& chunk = position_and_chunk.second;
+    auto preset_allocations_iter = preset_allocations.find(buffer.color());
+    CHECK(preset_allocations_iter != preset_allocations.end())
+        << "No preset buffer allocation for color " << buffer.color()
+        << " found.";
+    preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(),
+                                                   chunk.offset, chunk.size);
+    // Ensure that there is at most one preset allocation for each buffer.
+    CHECK_EQ(assigned_buffers->count(&buffer), 0);
+    assigned_buffers->emplace(&buffer);
+  }
+
+  // Upon consumption of the preset assignments, delete it so that if this
+  // method is called again, it does not assign the same buffers multiple times.
+  preset_assignments_ = {};
+
+  return Status::OK();
+}
+
 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
     const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
         buffers_to_assign_sequentially,
@@ -1393,6 +1529,21 @@
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                       HloAliasAnalysis::Run(module, can_share_buffer));
 
+  // Set up a schedule for each computation.
+  HloSchedule schedule(module);
+  for (const HloComputation* computation : module->computations()) {
+    const HloInstructionSequence* instruction_sequence =
+        hlo_ordering->SequentialOrder(*computation);
+    const bool has_sequential_order = instruction_sequence != nullptr;
+    if (has_sequential_order) {
+      schedule.set_sequence(computation, *instruction_sequence);
+    }
+  }
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
+                      HloLiveRange::Run(schedule, *alias_analysis,
+                                        module->entry_computation(), true));
+
   VLOG(1) << "Assigning buffers to module " << module->name();
   XLA_VLOG_LINES(3, module->ToString());
   XLA_VLOG_LINES(3, alias_analysis->ToString());
@@ -1404,7 +1555,8 @@
   // private.
   std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
       module, std::move(hlo_ordering), std::move(buffer_size),
-      std::move(color_alignment), std::move(alias_analysis)));
+      std::move(color_alignment), std::move(alias_analysis),
+      std::move(hlo_live_range)));
 
   TF_RETURN_IF_ERROR(
       colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
@@ -1432,7 +1584,7 @@
   // module, which reduces memory usage.
   const bool run_whole_module_heap_simulation =
       buffers_to_assign_sequentially.size() == global_computations.size();
-  VLOG(2) << "Running whole module heap simulation"
+  VLOG(2) << "Running whole module heap simulation: "
           << run_whole_module_heap_simulation;
   TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
       buffers_to_assign_sequentially, run_whole_module_heap_simulation,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index f60ad22..9caf4be 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -31,8 +31,10 @@
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -445,9 +447,11 @@
 
   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
 
-  // Returns the BufferLiveness object used to construct this assignment.
   const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
 
+  // Returns the HloLiveRange object used to construct this assignment.
+  const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
+
   string ToString() const;
   BufferAssignmentProto ToProto() const;
 
@@ -480,12 +484,14 @@
                    std::unique_ptr<HloOrdering> hlo_ordering,
                    BufferValue::SizeFunction buffer_size,
                    LogicalBuffer::AlignmentFunction color_alignment,
-                   std::unique_ptr<HloAliasAnalysis> alias_analysis)
+                   std::unique_ptr<HloAliasAnalysis> alias_analysis,
+                   std::unique_ptr<HloLiveRange> hlo_live_range)
       : module_(module),
         hlo_ordering_(std::move(hlo_ordering)),
         buffer_size_(std::move(buffer_size)),
         color_alignment_(std::move(color_alignment)),
-        alias_analysis_(std::move(alias_analysis)) {}
+        alias_analysis_(std::move(alias_analysis)),
+        hlo_live_range_(std::move(hlo_live_range)) {}
 
   // Creates and returns a new BufferAllocation, with no assigned
   // LogicalBuffers. Ownership is maintained internally.
@@ -545,6 +551,8 @@
 
   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
 
+  std::unique_ptr<HloLiveRange> hlo_live_range_;
+
   Stats stats_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
@@ -558,7 +566,13 @@
   static Colorer DefaultColorer() {
     return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
-        value->set_color(BufferValue::Color(0));
+        HloInstruction* defining_instruction = value->defining_instruction();
+        if (defining_instruction->shape().has_layout()) {
+          value->set_color(BufferValue::Color(
+              defining_instruction->shape().layout().memory_space()));
+        } else {
+          value->set_color(BufferValue::Color(0));
+        }
       }
       return Status::OK();
     };
@@ -569,7 +583,9 @@
   // Build and return a BufferAssignment for the given module. The given
   // HloOrdering is used to determine buffer liveness. buffer_size and
   // color_alignment are functions which returns the size and alignment of a
-  // LogicalBuffer.
+  // LogicalBuffer. If preset_assignments is provided, those pre-set assignment
+  // offsets will be used. The caller guarantees that those assignments are
+  // valid and they do not overwrite each other.
   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
       BufferValue::SizeFunction buffer_size,
@@ -577,14 +593,17 @@
       bool allocate_buffers_for_constants = false,
       Colorer colorer = DefaultColorer(),
       const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
-      HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr);
+      HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr,
+      std::unique_ptr<PresetAssignments> preset_assignments = {});
 
  private:
   BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
-                 const absl::flat_hash_set<HloOpcode>& must_not_live_out)
+                 const absl::flat_hash_set<HloOpcode>& must_not_live_out,
+                 std::unique_ptr<PresetAssignments> preset_assignments)
       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
         colorer_(colorer),
-        must_not_live_out_(must_not_live_out) {}
+        must_not_live_out_(must_not_live_out),
+        preset_assignments_(std::move(preset_assignments)) {}
   virtual ~BufferAssigner() = default;
 
   // Create a buffer assignment.
@@ -606,6 +625,16 @@
           buffers_to_assign_sequentially,
       BufferAssignment* assignment);
 
+  // Returns true if buffer's live range interferences with buffer2's.
+  bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2,
+                           BufferAssignment* assignment);
+
+  // Assigns pre-set assignments, if provided. These assignments will be added
+  // to assigned_buffers and skip buffer allocation.
+  Status AssignPresetBuffers(
+      absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
+      BufferAssignment* assignment);
+
   // Promotes operations (DUS, scatter) to be done in place: If an operation can
   // be done in place, merge its buffer with its operand buffer.
   Status MergeInplaceOpBuffers(BufferAssignment* assignment);
@@ -657,6 +686,9 @@
   // A set of hlo opcodes that can't live out of a computation.
   absl::flat_hash_set<HloOpcode> must_not_live_out_;
 
+  // Description of any buffer offsets that are already set by an earlier pass.
+  std::unique_ptr<PresetAssignments> preset_assignments_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
 };
 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 1ca20b6..1c98548 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -143,6 +143,20 @@
         .ConsumeValueOrDie();
   }
 
+  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
+      HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
+      int64 alignment = 1) {
+    return BufferAssigner::Run(
+               module, absl::make_unique<DependencyHloOrdering>(module),
+               backend().compiler()->BufferSizeBytesFunction(),
+               [alignment](LogicalBuffer::Color) { return alignment; },
+               /*allocate_buffers_for_constants=*/true,
+               BufferAssigner::DefaultColorer(),
+               /*must_not_live_out=*/{},
+               /*can_share_buffer=*/nullptr, std::move(preset_assignments))
+        .ConsumeValueOrDie();
+  }
+
   // Builds an x+1.0 computation to use in a Map.
   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
     auto builder = HloComputation::Builder(name);
@@ -599,6 +613,13 @@
 
   // The sub node has a valid output buffer assigned.
   GetAssignedOutputAllocation(*buffers, sub);
+
+  // Check if the HLO instructions have the correct colors in the layout.
+  EXPECT_EQ(param0->shape().layout().memory_space(), 2);
+  EXPECT_EQ(param1->shape().layout().memory_space(), 3);
+  EXPECT_EQ(mul->shape().layout().memory_space(), 4);
+  EXPECT_EQ(add->shape().layout().memory_space(), 5);
+  EXPECT_EQ(sub->shape().layout().memory_space(), 6);
 }
 
 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
@@ -666,6 +687,86 @@
 
   // The sub node has a valid output buffer assigned.
   GetAssignedOutputAllocation(*buffers, sub);
+
+  // Check if the HLO instructions have the correct colors in the layout.
+  EXPECT_EQ(mul->shape().layout().memory_space(), 1);
+  EXPECT_EQ(add->shape().layout().memory_space(), 1);
+  EXPECT_EQ(sub->shape().layout().memory_space(), 0);
+  EXPECT_EQ(param0->shape().layout().memory_space(), 0);
+  EXPECT_EQ(param1->shape().layout().memory_space(), 0);
+}
+
+TEST_F(BufferAssignmentTest, PresetAssignments) {
+  // paramscalar ------- (mul) -- (add) -- (sub)
+  //                     /        /        /
+  // param0[100] -------/        /        /
+  //                            /        /
+  // param1[100] --------------/--------/
+  // Similar to BasicPartiallyColored, but the color is set in the layout.
+  // The output of the mul and the add have the color 1 and have preset
+  // assignments, and the other buffers have the color 0, which allows the mul
+  // and add to share buffers.
+  auto builder = HloComputation::Builder(TestName());
+  auto paramscalar =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+  auto broadcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
+  Shape f32vec100_color1 =
+      ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1);
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
+  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_color1, HloOpcode::kAdd, mul, param1));
+  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kSubtract, add, param1));
+  auto module = CreateNewVerifiedModule();
+  module->AddEntryComputation(builder.Build());
+
+  auto preset_assignments = absl::make_unique<PresetAssignments>();
+  preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
+  preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
+  preset_assignments->add_size(/*memory_space=*/1, /*size=*/950);
+
+  auto buffers = RunBufferAssignmentWithPresetAssignments(
+      module.get(), std::move(preset_assignments));
+
+  // Distinct input buffers were assigned for parameters.
+  BufferAllocation paramscalar_buffer =
+      GetAssignedInputAllocation(*buffers, paramscalar);
+  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
+  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
+  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
+  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
+  EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
+  EXPECT_NE(param0_buffer.index(), param1_buffer.index());
+  EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));
+
+  // The mul and add use the same preset buffer. Ensure it has the correct color
+  // and offsets.
+  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
+  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
+  EXPECT_EQ(mul_buffer, add_buffer);
+  EXPECT_NE(mul_buffer.index(), param0_buffer.index());
+  EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));
+
+  EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
+  for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
+    if (value_and_offsetsize.first->instruction() == mul) {
+      EXPECT_EQ(value_and_offsetsize.second.offset, 100);
+      EXPECT_EQ(value_and_offsetsize.second.size, 400);
+    } else {
+      EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
+      EXPECT_EQ(value_and_offsetsize.second.offset, 550);
+      EXPECT_EQ(value_and_offsetsize.second.size, 400);
+    }
+  }
+
+  // The sub node has a valid output buffer assigned.
+  GetAssignedOutputAllocation(*buffers, sub);
 }
 
 TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index f193603..985603b 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -253,6 +253,31 @@
   }
   return true;
 }
+
+// Replaces the roots of all branches with an empty tuple if the conditional op
+// has no users. Returns if anything is changed.
+bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) {
+  const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
+  if (conditional_op->user_count() == 0 &&
+      conditional_op != conditional_op->parent()->root_instruction() &&
+      !ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) {
+    for (int64 branch_id = 0; branch_id < conditional_op->branch_count();
+         ++branch_id) {
+      auto branch_computation =
+          conditional_op->GetModule()->AddEmbeddedComputation(
+              conditional_op->branch_computation(branch_id)->Clone());
+      conditional_op->set_branch_computation(branch_id, branch_computation);
+      auto new_empty_root =
+          branch_computation->AddInstruction(HloInstruction::CreateTuple({}));
+      branch_computation->set_root_instruction(new_empty_root,
+                                               /*accept_different_shape=*/true);
+    }
+    *conditional_op->mutable_shape() = empty_tuple;
+    return true;
+  }
+  return false;
+}
+
 }  // namespace
 
 StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
@@ -274,6 +299,7 @@
 
   std::map<HloComputation*, std::set<int64>> changed_computations;
   for (HloInstruction* conditional_op : conditional_ops) {
+    changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
     TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
     if (!result) {
       TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands(
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 5865915..d409e22 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -285,6 +285,49 @@
   EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
 }
 
+TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveDeadRoots
+on_false {
+  t = (f32[20,40], f32[40,40]) parameter(0)
+  lhs = f32[20,40] get-tuple-element(t), index=0
+  rhs = f32[40,40] get-tuple-element(t), index=1
+  dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  after-all = token[] after-all()
+  outfeed = token[] outfeed(dot, after-all)
+  ROOT result = (f32[20,40]) tuple(dot)
+}
+
+on_true {
+  t = (f32[20,40], f32[40,40]) parameter(0)
+  lhs = f32[20,40] get-tuple-element(t), index=0
+  add = f32[20,40] add(lhs, lhs)
+  ROOT result = (f32[20,40]) tuple(add)
+}
+
+ENTRY main {
+  c0_0 = f32[20,40] parameter(0)
+  c0_1 = f32[40,40] parameter(1)
+  p = pred[] parameter(2)
+  t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1)
+  conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
+  ROOT result = () tuple()
+}
+)";
+  auto status = ParseAndReturnUnverifiedModule(hlo_string);
+  TF_ASSERT_OK(status.status());
+  HloVerifier v(false, false);
+  TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
+  EXPECT_TRUE(
+      ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
+  TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
+  HloInstruction* conditional =
+      FindInstruction(status.ValueOrDie().get(), "conditional");
+  // The conditional root should be replaced with an empty tuple.
+  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0);
+}
+
 }  // namespace
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index ff75f0f..20ebafc 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -355,7 +355,7 @@
     }
     // We want to repeat 'filter' in the 'input_feature_dim' dimension
     // 'group_count' times.
-    if (filter_expansion_) {
+    if (!is_cost_viable_(convolution) || filter_expansion_) {
       Shape reshaped_filter_shape =
           ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
       auto reshaped_filter =
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
index d2eea14..85c54d3 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
@@ -49,7 +49,8 @@
   auto computation = module->entry_computation();
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
-  ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/
+  auto cost_model = [](HloInstruction* conv) { return true; };
+  ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/
                                       false);
   ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
@@ -80,7 +81,8 @@
   auto computation = module->entry_computation();
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
-  ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/
+  auto cost_model = [](HloInstruction* conv) { return true; };
+  ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/
                                       false);
   ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 37baf0e..e39ee46 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -35,6 +35,7 @@
     srcs = ["cpu_transfer_manager.cc"],
     hdrs = ["cpu_transfer_manager.h"],
     deps = [
+        ":cpu_runtime",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
@@ -45,7 +46,6 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:generic_transfer_manager",
         "//tensorflow/compiler/xla/service:transfer_manager",
-        "//tensorflow/compiler/xla/service/cpu:cpu_runtime",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor",
@@ -97,6 +97,7 @@
         "//tensorflow/compiler/xla/service:map_inliner",
         "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
         "//tensorflow/compiler/xla/service:conditional_to_select",
+        "//tensorflow/compiler/xla/service:slow_operation_alarm",
         "//tensorflow/compiler/xla/service:scatter_expander",
         "//tensorflow/compiler/xla/service:slice_sinker",
         "//tensorflow/compiler/xla:cpu_function_runtime",
@@ -1012,3 +1013,19 @@
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
+
+tf_cc_test(
+    name = "vectorized_reduce_with_no_vector_registers_test",
+    size = "small",
+    srcs = ["vectorized_reduce_with_no_vector_registers_test.cc"],
+    deps = [
+        ":cpu_compiler",
+        ":cpu_transfer_manager",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@llvm//:core",
+        "@llvm//:support",
+        "@llvm//:target",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 9f8f743..acafa2c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -99,6 +99,7 @@
 #include "tensorflow/compiler/xla/service/rng_expander.h"
 #include "tensorflow/compiler/xla/service/scatter_expander.h"
 #include "tensorflow/compiler/xla/service/slice_sinker.h"
+#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
@@ -606,6 +607,7 @@
   VLOG(1) << "Compiling: " << module->name();
   XLA_SCOPED_LOGGING_TIMER(
       absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
+  auto slow_compile_alarm = SlowCompilationAlarm();
 
   TF_RET_CHECK(stream_exec != nullptr);
   std::call_once(llvm_command_line_options_initialized,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 4765798..9b79e8c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -194,13 +194,13 @@
 
   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 
-  {
-    tensorflow::mutex_lock lock(mutex_);
+  if (run_options->execution_profile()) {
     const double nanoseconds = (end_micros - start_micros) * 1000.0;
-    execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+    run_options->execution_profile()->set_compute_time_ns(
+        std::max(nanoseconds, 1.0));
     // If hlo profiling was disabled then the cycle count is left empty.
     if (hlo_execution_profile) {
-      execution_profile_.set_compute_cycle_count(
+      run_options->execution_profile()->set_compute_cycle_count(
           hlo_execution_profile->total_cycles_executed(
               *module().entry_computation()));
     }
@@ -268,30 +268,8 @@
   return std::move(result_buffer);
 }
 
-StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span<const ShapedBuffer* const> arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  TF_ASSIGN_OR_RETURN(
-      auto result,
-      ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
-  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
-  return std::move(result);
-}
-
 StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    absl::Span<const ShapedBuffer* const> arguments) {
-  if (hlo_profiling_enabled()) {
-    return Unimplemented(
-        "Asynchronous execution on stream with hlo profiling is not yet "
-        "supported on CPU.");
-  }
-  return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr);
-}
-
-StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
-    const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   if (GetRootValueSet().IsAmbiguous()) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 169acde..37af630 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -55,15 +55,11 @@
                 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
   ~CpuExecutable() override {}
 
-  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
+  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
-  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span<const ShapedBuffer* const> arguments) override;
-
   // This should be called after set_ir_module_string.
   const string& ir_module_string() const { return ir_module_string_; }
 
@@ -86,16 +82,6 @@
   const BufferAssignment& buffer_assignment() const { return *assignment_; }
 
  private:
-  // This is for sharing the code between ExecuteOnStream and
-  // ExecuteAsyncOnStream.
-  //
-  // Notice that it's tricky to use correctly, as the profile object (when it
-  // exists) must out-live the task.
-  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span<const ShapedBuffer* const> arguments,
-      HloExecutionProfile* hlo_execution_profile);
-
   // Creates an array suitable for passing as the "buffer_table" argument to the
   // JIT compiled function pointer.
   //
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 6620a96..a6f960a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -40,10 +40,11 @@
          hlo.opcode() == HloOpcode::kTranspose;
 }
 
-bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
+bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) {
   const Shape& hlo_shape = hlo->shape();
   return !ShapeUtil::ElementIsComplex(hlo_shape) &&
-         hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1;
+         hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1 &&
+         hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0;
 }
 
 bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
@@ -54,7 +55,7 @@
 bool CanBeOutputFused(const HloInstruction* producer,
                       const HloInstruction* consumer) {
   return consumer->opcode() == HloOpcode::kAdd &&
-         IsNonComplexMatrixVectorDot(producer) &&
+         IsNonComplexNonBatchedMatrixVectorDot(producer) &&
          HasExactlyOneUse(*producer) == 1;
 }
 
@@ -74,10 +75,13 @@
   constexpr int kFusionThresholdBytes = 16 * 1024;
 
   if (CanBeOutputFused(producer, consumer)) {
+    VLOG(2) << "Fusion OK: Can create output fusion.";
     return true;
   }
 
   if (CanBeOutputFusedIntoSomeOperand(producer)) {
+    VLOG(2)
+        << "Bailing because producer can be output-fused into some operand.";
     return false;
   }
 
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index ceaeacb..f0d7461 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1739,6 +1739,16 @@
     return false;
   }
 
+  int vector_register_size_in_elements =
+      target_machine_features_.vector_register_byte_size(
+          *compute_function_->function()) /
+      ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
+  if (vector_register_size_in_elements == 0) {
+    // Either we don't know the vector register width for the target or the
+    // vector register is smaller than the size of the primitive type.
+    return false;
+  }
+
   int vectorization_factor_in_bytes =
       target_machine_features_.vectorization_factor_in_bytes();
 
diff --git a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc
new file mode 100644
index 0000000..2918c88
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc
@@ -0,0 +1,106 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+class CodegenReduceOnArchWithNoVectorRegisters : public HloTestBase {};
+
+StatusOr<unsigned> GetTargetVectorRegisterByteSize(std::string triple) {
+  // Unfortunately we need a lot of boilerplate to get to an
+  // llvm::TargetMachine.
+
+  std::string error;
+  const llvm::Target* target =
+      llvm::TargetRegistry::lookupTarget(triple, error);
+  if (target == nullptr) {
+    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
+  }
+
+  llvm::LLVMContext context;
+  std::unique_ptr<llvm::Function> function =
+      absl::WrapUnique(llvm::Function::Create(
+          llvm::FunctionType::get(llvm::Type::getVoidTy(context), {}),
+          llvm::GlobalValue::ExternalLinkage, "test"));
+
+  std::unique_ptr<llvm::TargetMachine> target_machine =
+      absl::WrapUnique(target->createTargetMachine(
+          /*TT=*/triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions{},
+          /*RM=*/llvm::None));
+  cpu::LLVMTargetMachineFeatures target_machine_features(target_machine.get());
+  return target_machine_features.vector_register_byte_size(*function);
+}
+
+TEST_F(CodegenReduceOnArchWithNoVectorRegisters, Test) {
+  absl::string_view text = R"(
+HloModule Reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  input = f32[1000,1000] parameter(0)
+  constant = f32[] constant(0)
+  ROOT reduce = f32[1000] reduce(input, constant), dimensions={0}, to_apply=add
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(text));
+  cpu::CpuCompiler cpu_compiler;
+  auto module_group = absl::make_unique<HloModuleGroup>("group");
+  module_group->push_back(std::move(hlo_module));
+
+  // Check that the GetTargetVectorRegisterByteSize is itself working.
+  TF_ASSERT_OK_AND_ASSIGN(unsigned vector_register_byte_size_for_x86_64,
+                          GetTargetVectorRegisterByteSize("x86_64-pc-linux"));
+  ASSERT_EQ(vector_register_byte_size_for_x86_64, 16);
+
+  std::string triple = "i686-none-android";
+
+  TF_ASSERT_OK_AND_ASSIGN(unsigned vector_register_byte_size,
+                          GetTargetVectorRegisterByteSize(triple));
+
+  // This test is supposed to check whether the XLA CPU vectorized reduction
+  // codegen works correctly for architectures that do not have vector
+  // registers.  So first ASSERT that `triple` is actually a target with no
+  // vector registers, as otherwise the test isn't actually testing anything
+  // interesting.
+
+  ASSERT_EQ(vector_register_byte_size, 0);
+
+  cpu::CpuAotCompilationOptions aot_compilation_options(
+      /*triple=*/triple, /*cpu_name=*/"", /*features=*/"",
+      /*entry_point_name=*/"main",
+      cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::unique_ptr<AotCompilationResult>> aot_compilation_result,
+      cpu_compiler.CompileAheadOfTime(std::move(module_group),
+                                      aot_compilation_options));
+  EXPECT_EQ(aot_compilation_result.size(), 1);
+}
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 1341535..86bed87 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -300,7 +300,12 @@
 
   // Useful when we want to visit the same computation more than once with the
   // same visitor.
-  void ResetVisitStates() { visit_state_.clear(); }
+  void ResetVisitStates() {
+    // Clear the map, but don't resize the capacity across uses -- Calculating
+    // and reserving space could be expensive, and we always use the same
+    // module->instruction_count() as the capacity.
+    visit_state_.erase(visit_state_.begin(), visit_state_.end());
+  }
 
   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
 
diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc
index 6a48372..331c935 100644
--- a/tensorflow/compiler/xla/service/dump.cc
+++ b/tensorflow/compiler/xla/service/dump.cc
@@ -136,10 +136,6 @@
   bool dump_snapshots;
 };
 
-string FilenameFor(const HloModule& module, string_view suffix) {
-  return StrFormat("module_%04d.%s", module.unique_id(), suffix);
-}
-
 void DumpToFileInDirImpl(string_view filename, string_view contents,
                          const CanonicalDebugOptions& opts) {
   if (opts.dumping_to_stdout()) {
@@ -263,6 +259,10 @@
 
 }  // namespace
 
+string FilenameFor(const HloModule& module, string_view suffix) {
+  return StrFormat("module_%04d.%s", module.unique_id(), suffix);
+}
+
 void DumpToFileInDir(const HloModule& module, string_view suffix,
                      string_view contents) {
   DumpToFileInDirImpl(FilenameFor(module, suffix), contents,
diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h
index 6edc9b2..d245ad5 100644
--- a/tensorflow/compiler/xla/service/dump.h
+++ b/tensorflow/compiler/xla/service/dump.h
@@ -33,6 +33,9 @@
 class HloExecutionProfile;
 class HloSnapshot;
 
+// Create the filename we will use to dump in DumpToFileInDir.
+string FilenameFor(const HloModule& module, absl::string_view suffix);
+
 // Writes the given string to a file in the xla_dump_to directory specified by
 // module's DebugOptions.
 //
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 517d15f..2b90d77 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -515,15 +515,14 @@
           : input_type;
   switch (op->opcode()) {
     case HloOpcode::kLog: {
-      // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
+      // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
       auto a = EmitExtractReal(operand_value);
       auto b = EmitExtractImag(operand_value);
-      llvm::Type* llvm_ty = a->getType();
-      auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
-      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
-      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
-      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
-      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
+      TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
+      TF_ASSIGN_OR_RETURN(llvm::Value * abs,
+                          EmitComplexAbs(component_type, operand_value));
+      TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
+      return EmitComposeComplex(op, log_abs, angle);
     }
     case HloOpcode::kLog1p: {
       // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
@@ -639,32 +638,128 @@
              =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
                i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
+             =(e^(2a)-e^(-2a) +
+               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(2a)))]
+               / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(2a))
+             =(e^(2a)-e^(-2a) +
+               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
+               ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)])+2*[cos(b)^2 - sin(b)^2])
+             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
+              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
+             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
+              (e^(2a)+e^(-2a)+2*[cos(2b)])
+             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
       */
-      auto a = EmitExtractReal(operand_value);
-      auto b = EmitExtractImag(operand_value);
-      TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
-      TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
-      TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
-      auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
-      auto exp_2a_minus_exp_neg_2a =
-          FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
-      auto cos_b_sq = FMul(cos_b, cos_b);
-      auto sin_b_sq = FMul(sin_b, sin_b);
-      auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
-                           FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
-      auto cos_b_sin_b = FMul(cos_b, sin_b);
-      auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
-      auto exp_a_plus_exp_neg_a_sq =
-          FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
-      auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
-      auto exp_a_minus_exp_neg_a_sq =
-          FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
-      auto imag_num = FMul(
-          cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
-      auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
-                        FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
-      return EmitComposeComplex(op, FDiv(real_num, denom),
-                                FDiv(imag_num, denom));
+      llvm::Value* a = EmitExtractReal(operand_value);
+      llvm::Value* b = EmitExtractImag(operand_value);
+
+      llvm::Type* type = a->getType();
+
+      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
+      llvm::Value* two_a = FAdd(a, a);
+      llvm::Value* neg_2a = FMul(neg_one, two_a);
+
+      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
+      // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
+      // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
+      // ULP to be arbitrarily small. For larger values of `a`, calculating the
+      // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually
+      // identical results.
+      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
+                          EmitExpm1(component_type, two_a));
+      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
+                          EmitExpm1(component_type, neg_2a));
+      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);
+
+      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
+      // = 2cos(b)^2. This gives us the ability to be more precise when the
+      // denominator is close to zero.
+      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
+      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
+      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
+      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);
+
+      // Similarly we can compute sin(2b) with the formula sin(2b) =
+      // 2*sin(b)*cos(b).
+      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
+      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));
+
+      // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
+      // for small value of x. As a result, due to floating point precission
+      // issues, x^2 is a better approximation than Expm1(x) + Expm1(x) for
+      // small values of x.
+      llvm::Value* a_sqr = FMul(a, a);
+      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
+      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);
+
+      llvm::Value* exp_sum_m2 =
+          Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
+      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);
+
+      // As `a` grows toward +inf and -inf, the real numerator will grow towards
+      // +inf and -inf respectively, while the denominator will always grow
+      // towards +inf. The result is real_numerator/denom = NaN, when it should
+      // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
+      // we just hardcode the limits for the real numbers.
+      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
+      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
+          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);
+
+      llvm::Value* real =
+          Select(is_inf, real_limit, FDiv(real_numerator, denom));
+      llvm::Value* imag = FDiv(imag_numerator, denom);
+
+      // The complex tanh functions have a few corner cases:
+      // 1. (+0, +0) => (+0, +0)        - Handled normally
+      // 2. (x, +Inf) => (NaN, NaN)     - See below
+      // 3. (x, NaN) => (NaN, NaN)      - See below
+      // 4. (+inf, y) => (1, +0)        - Handled normally
+      // 5. (+Inf, +Inf) => (1, +/-0)   - See below
+      // 6. (+Inf, NaN) => (1, +/-0)    - See below
+      // 7. (NaN, +0) => (NaN, +0)      - See below
+      // 8. (NaN, y) => (NaN, NaN)      - Handled normally
+      // 9. (NaN, NaN) => (NaN, NaN)    - Handled normally
+      //
+      // For the cases that aren't handled normally:
+      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
+      //      then we return (+/-1, +/-0). However, this is only true if we
+      //      assume that a is infinity or b is finite. In the event that both a
+      //      is finite and b is either +/-Inf or NaN, then our normal
+      //      calculation would end up returing (+/-1, NaN), as opposed to (NaN,
+      //      NaN).
+      // 5/6) We always calculate the imagninary value as sin(2b)/denominator.
+      //      When the denominator is infinity, this assures us that the zero is
+      //      the correct sign. However if our imaginary input results in
+      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
+      // 7)   In the event that a is NaN, the denominator will be NaN.
+      //      Therefore, the normal calculation gives (NaN, NaN) while we need
+      //      (NaN, +0).
+      if (!(b_->getFastMathFlags().noNaNs() &&
+            b_->getFastMathFlags().noInfs())) {
+        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+                                                          {a}, {type}, b_);
+        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
+        llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+
+        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
+        llvm::Value* b_is_zero = FCmpOEQ(b, zero);
+
+        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
+        // imag_numerator is NaN.
+        llvm::Value* sin_2b_is_nan =
+            b_->CreateFCmpUNO(imag_numerator, imag_numerator);
+
+        llvm::Value* real_is_nan =
+            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
+        llvm::Value* imag_is_zero =
+            b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));
+
+        real = Select(real_is_nan, nan, real);
+        imag = Select(imag_is_zero, zero, imag);
+      }
+
+      return EmitComposeComplex(op, real, imag);
     }
     case HloOpcode::kAbs: {
       return EmitComplexAbs(component_type, operand_value);
@@ -1100,7 +1195,10 @@
   auto x_squared = FMul(x, x);
   auto x_squared_over_two = FMul(x_squared, half);
   auto for_small_x = FAdd(x, x_squared_over_two);
-  const auto kExponentIsSmallThreshold = 1e-5;
+  // At this point, the relative errors due to floating point precision loss of
+  // calculating exp(x) - 1 and the polynomial exp(x)-1 = x + x^2/2 are about
+  // equal, with a value of approximetely 2^-16.
+  const auto kExponentIsSmallThreshold = 0.009;
   auto abs_x =
       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
   auto x_is_small =
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 7b60c98..9f5b764 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -26,9 +26,42 @@
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/stream_executor/device_description.h"
 
 namespace xla {
 
+StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span<const ShapedBuffer* const> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  StatusOr<ScopedShapedBuffer> result =
+      ExecuteAsyncOnStream(run_options, arguments, hlo_execution_profile);
+  Status blocking_status = run_options->stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(blocking_status);
+  return result;
+}
+
+StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
+      run_options, std::move(arguments), hlo_execution_profile);
+  Status blocking_status = run_options->stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(blocking_status);
+  return result;
+}
+
+StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* /*run_options*/,
+    std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> /*arguments*/,
+    HloExecutionProfile* /*hlo_execution_profile*/) {
+  return Unimplemented(
+      "MaybeOwningDeviceMemory version of overload is not implemented ");
+}
+
 StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
     absl::Span<const ServiceExecutableRunOptions> run_options,
     absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
@@ -49,8 +82,9 @@
     // We cannot BlockHostUntilDone() on the already-launched executions in case
     // of error, since if the executions communicate, the initially launched
     // executions may never complete if not all executions are running.
-    TF_ASSIGN_OR_RETURN(auto rv,
-                        ExecuteAsyncOnStream(&run_options[i], arguments[i]));
+    TF_ASSIGN_OR_RETURN(
+        auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i],
+                                      /*hlo_execution_profile=*/nullptr));
     return_values.push_back(std::move(rv));
   }
   for (const auto& options : run_options) {
@@ -61,27 +95,37 @@
 }
 
 StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
-    const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span<const ShapedBuffer* const> arguments) {
+  StatusOr<ScopedShapedBuffer> result =
+      ExecuteAsyncOnStreamWrapper(run_options, arguments);
+  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
+  return result;
+}
+
+StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
+    const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments) {
   se::Stream* stream = run_options->stream();
-  std::unique_ptr<se::Timer> timer;
+  std::shared_ptr<se::Timer> timer;
+  ExecutionProfile* profile = run_options->run_options().execution_profile();
   if (profile != nullptr) {
-    timer.reset(new se::Timer(stream->parent()));
+    timer = std::make_shared<se::Timer>(stream->parent());
     stream->InitTimer(timer.get()).ThenStartTimer(timer.get());
   }
 
   VLOG(1) << "enqueueing executable on stream...";
   // If the profiling flag isn't enabled, we pass nullptr as the profile to
   // indicate profiling is not requested.
-  std::unique_ptr<HloExecutionProfile> profile_ptr =
+  std::shared_ptr<HloExecutionProfile> profile_ptr =
       module_config().debug_options().xla_hlo_profile() &&
               hlo_profiling_enabled()
-          ? absl::make_unique<HloExecutionProfile>(&hlo_profile_printer_data(),
-                                                   &hlo_profile_index_map())
+          ? std::make_shared<HloExecutionProfile>(&hlo_profile_printer_data(),
+                                                  &hlo_profile_index_map())
           : nullptr;
 
   StatusOr<ScopedShapedBuffer> return_value =
-      ExecuteOnStream(run_options, arguments, profile_ptr.get());
+      ExecuteAsyncOnStream(run_options, arguments, profile_ptr.get());
   if (!return_value.status().ok()) {
     if (profile != nullptr) {
       // Ensure the ThenStartTimer call has completed before we destroy timer.
@@ -96,30 +140,19 @@
   }
 
   if (profile != nullptr) {
-    VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
+    VLOG(1) << "enqueueing 'stop timer' and profiling callback...";
     stream->ThenStopTimer(timer.get());
-    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-    VLOG(1) << "done with block-host-until-done";
 
+    // We block instead of using an async callback because reading the timer
+    // value may call back into the driver on GPU, which is not allowed.
+    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+
+    const int64 executable_size_in_bytes = SizeOfGeneratedCodeInBytes();
     // Merge in run-time profile information from execution_profile.
-    //
-    // TODO(b/71713097): This is buggy -- even though the mutex takes care of
-    // C++ level races, some other concurrent ExecuteOnStreamWrapper call could
-    // have rewritten the execution_profile before we get to it.
-    profile->MergeFrom(execution_profile());
 
     // Overall execution time (in nanoseconds) from the executor timer.
-    if (stream->ok()) {
-      // Don't read timer->Nanoseconds() if the stream isn't OK -- that's
-      // illegal.
-      profile->set_compute_and_transfer_time_ns(timer->Nanoseconds());
-    }
+    profile->set_compute_and_transfer_time_ns(timer->Nanoseconds());
 
-    // TODO(b/28123297): On GPU we end up including transfer time in
-    // the compute time this way. Instead, we should get the correct
-    // value by measuring it. Setting the field here at least lets
-    // benchmarks provide *some* value for GPU computations.
-    //
     // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
     // the compute time without the transfer time, so this way we get the
     // correct compute time. We should instead have the correct value for
@@ -128,21 +161,23 @@
       profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
     }
 
-    const int64 executable_size_in_bytes = SizeInBytes();
     if (executable_size_in_bytes != 0) {
       profile->set_executable_size_in_bytes(executable_size_in_bytes);
     }
   }
 
   if (profile_ptr != nullptr) {
-    XLA_LOG_LINES(
-        tensorflow::INFO,
-        profile_ptr->ToString(stream->parent()->GetDeviceDescription()));
+    const se::DeviceDescription* device_description =
+        &stream->parent()->GetDeviceDescription();
+    stream->ThenDoHostCallback([profile_ptr, device_description]() {
+      XLA_LOG_LINES(tensorflow::INFO,
+                    profile_ptr->ToString(*device_description));
+    });
   }
 
   return return_value;
 }
 
-int64 Executable::SizeInBytes() { return -1; }
+int64 Executable::SizeOfGeneratedCodeInBytes() { return -1; }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 492ea72..2238322 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -123,16 +123,10 @@
   // enabled.
   //
   // Returns a shaped buffer containing the result of the computation.
-  virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
+  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments,
-      HloExecutionProfile* hlo_execution_profile) = 0;
-
-  // Same as ExecuteOnStream(), but this call is non-blocking and returns as
-  // soon as all of the operations are enqueued for launch on the stream.
-  virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span<const ShapedBuffer* const> arguments) = 0;
+      HloExecutionProfile* hlo_execution_profile);
 
   // Starts the given program executing on the given stream/executor.
   //
@@ -143,20 +137,31 @@
   //
   // If an input is donated to XLA but is not reused as output, it is returned
   // as an leftover buffer for the caller to release.
-  virtual StatusOr<ExecutionOutput> ExecuteOnStream(
+  //
+  // This call should be non-blocking and may return as soon as all of the
+  // operations are enqueued for launch on the stream. Note that some
+  // implementations may in fact block or may block in some circumstances (e.g.,
+  // when profiling); i.e., asynchronous is a "may" not a "must".
+  //
+  // If the hlo_execution_profile is provided as non-nullptr, profiling will be
+  // enabled. Note that profiling is tricky to use correctly, as the profiling
+  // objects (when they exist) must out-live the task.
+  virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
+      const ServiceExecutableRunOptions* run_options,
+      absl::Span<const ShapedBuffer* const> arguments,
+      HloExecutionProfile* hlo_execution_profile) = 0;
+
+  // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
+  // complete.
+  StatusOr<ExecutionOutput> ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
       std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
-      HloExecutionProfile* hlo_execution_profile) {
-    return Unimplemented(
-        "MaybeOwningDeviceMemory version of overload is not implemented ");
-  }
+      HloExecutionProfile* hlo_execution_profile);
 
   virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments) {
-    return Unimplemented(
-        "MaybeOwningDeviceMemory version of overload is not implemented ");
-  }
+      std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
+      HloExecutionProfile* hlo_execution_profile);
 
   // Same as ExecuteOnStream(), but runs this executable on multiple
   // streams. arguments[i] contains the arguments to the execution on
@@ -171,6 +176,7 @@
   // called explicitly for other (async, for example) variants after the stream
   // has completed.
   virtual Status PopulateExecutionProfile(
+      ExecutionProfile* execution_profile,
       HloExecutionProfile* hlo_execution_profile, se::Stream* stream) {
     return Status::OK();
   }
@@ -179,15 +185,12 @@
   // timer for the execution, sets up HLO profiling if enabled, and fills in the
   // given ExecutionProfile if non-null.
   StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
-      const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
+      const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);
 
-  // Returns the ExecutionProfile from executing on the device. This includes
-  // the number of cycles taken for the computation or the compilation time.
-  ExecutionProfile execution_profile() const {
-    tensorflow::mutex_lock lock(mutex_);
-    return execution_profile_;
-  }
+  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
+      const ServiceExecutableRunOptions* run_options,
+      absl::Span<const ShapedBuffer* const> arguments);
 
   const HloProfilePrinterData& hlo_profile_printer_data() const {
     CHECK(hlo_profiling_enabled());
@@ -219,30 +222,27 @@
     return hlo_module_->config().entry_computation_layout().result_shape();
   }
 
-  // Returns the size of the executable in bytes. Returns -1 by default if the
-  // method is not overridden to support this kind of query.
-  virtual int64 SizeInBytes();
+  // Returns the size of the executable in bytes. Returns -1 if this query is
+  // not supported by the executable.
+  //
+  // Does not include the size of used libraries (e.g. cuDNN, Eigen, etc.).
+  virtual int64 SizeOfGeneratedCodeInBytes();
 
   // Dumping helpers.
-  void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
-    hlo_snapshot_ = std::move(hlo_snapshot);
+  void set_hlo_proto(std::unique_ptr<xla::HloProto> hlo_proto) {
+    hlo_proto_ = std::move(hlo_proto);
   }
-  bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; }
-  HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }
+  bool dumping_snapshot() const { return hlo_proto_ != nullptr; }
+  HloProto const* hlo_proto() const { return hlo_proto_.get(); }
 
  protected:
-  mutable tensorflow::mutex mutex_;
-
-  // Execution profile data on the device.
-  ExecutionProfile execution_profile_ GUARDED_BY(mutex_);
-
   // HloModule this was compiled from. BufferAssignment keeps pointers to
   // HloInstructions owned by the HloModule so we need to keep the HloModule
   // around.
   const std::shared_ptr<HloModule> hlo_module_;
 
-  // HloSnapshot this was compiled from. Null if not dumping executions.
-  std::unique_ptr<HloSnapshot> hlo_snapshot_;
+  // The serialized HLO proto. Non-null only if dumping snapshots is enabled.
+  std::unique_ptr<HloProto const> hlo_proto_;
 
   // Execution count, used to generate a unique filename for each dumped
   // execution.
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a5fc6e8..3c2dbc0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -3,12 +3,24 @@
 
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
-    "if_static",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+    "tf_copts",
+    "tf_cuda_library",
+)
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load(
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
+    "if_cuda_is_configured",
+)
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm_is_configured",
+)
 
 package(
     default_visibility = [":friends"],
@@ -410,7 +422,6 @@
 cc_library(
     name = "gpu_executable",
     srcs = [
-        "cholesky_thunk.cc",
         "collective_permute_thunk.cc",
         "conditional_thunk.cc",
         "convolution_thunk.cc",
@@ -431,9 +442,10 @@
         "triangular_solve_thunk.cc",
         "tuple_thunk.cc",
         "while_thunk.cc",
-    ],
+    ] + if_cuda_is_configured([
+        "cholesky_thunk.cc",
+    ]),
     hdrs = [
-        "cholesky_thunk.h",
         "collective_permute_thunk.h",
         "conditional_thunk.h",
         "convolution_thunk.h",
@@ -454,12 +466,13 @@
         "triangular_solve_thunk.h",
         "tuple_thunk.h",
         "while_thunk.h",
-    ],
+    ] + if_cuda_is_configured([
+        "cholesky_thunk.h",
+    ]),
     deps = [
         ":backend_configs",
         ":buffer_allocations",
         ":cudnn_conv_runner",
-        ":cusolver_context",
         ":gpu_debug_info_manager",
         ":gpu_types",
         ":hlo_execution_profiler",
@@ -495,17 +508,12 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:stream_executor_no_cuda",
-        "//tensorflow/core/platform/default/build_config:cublas_plugin",
-        "//tensorflow/core/platform/default/build_config:cudnn_plugin",
-        "//tensorflow/core/platform/default/build_config:cufft_plugin",
-        "//tensorflow/core/platform/default/build_config:stream_executor_cuda",  # build_cleaner: keep
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor:blas",
         "//tensorflow/stream_executor:device_memory",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor:kernel",
-        "//tensorflow/stream_executor/cuda:cuda_stream",
         "//tensorflow/stream_executor/gpu:gpu_stream",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
@@ -516,8 +524,18 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
+    ] + if_cuda_is_configured([
+        ":cusolver_context",
+        "//tensorflow/stream_executor/cuda:cuda_stream",
+        "//tensorflow/core/platform/default/build_config:cublas_plugin",
+        "//tensorflow/core/platform/default/build_config:cudnn_plugin",
+        "//tensorflow/core/platform/default/build_config:cufft_plugin",
+        "//tensorflow/core/platform/default/build_config:stream_executor_cuda",  # build_cleaner: keep
         "@local_config_cuda//cuda:cuda_headers",
-    ],
+    ]) + if_rocm_is_configured([
+        "//tensorflow/core/platform/default/build_config:stream_executor_rocm",
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
 )
 
 cc_library(
@@ -593,6 +611,7 @@
     deps = [
         ":backend_configs",
         ":buffer_comparator",
+        ":cudnn_conv_blacklist",
         ":cudnn_conv_runner",
         ":gpu_autotuning_proto",
         ":gpu_executable",
@@ -621,18 +640,6 @@
 )
 
 cc_library(
-    name = "scratch_allocator",
-    srcs = ["scratch_allocator.cc"],
-    hdrs = ["scratch_allocator.h"],
-    deps = [
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/core:stream_executor_no_cuda",
-        "//tensorflow/stream_executor:device_memory_allocator",
-    ],
-)
-
-cc_library(
     name = "cudnn_conv_runner",
     srcs = ["cudnn_conv_runner.cc"],
     hdrs = ["cudnn_conv_runner.h"],
@@ -703,10 +710,8 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor:blas",
-    ] + if_static(
-        ["@local_config_cuda//cuda:cusolver"],
-        ["//tensorflow/stream_executor/cuda:cusolver_stub"],
-    ),
+        "//tensorflow/stream_executor/cuda:cusolver_lib",
+    ],
 )
 
 cc_library(
@@ -972,20 +977,29 @@
 )
 
 cc_library(
-    name = "nvptx_compiler_impl",
-    srcs = ["nvptx_compiler.cc"],
-    hdrs = ["nvptx_compiler.h"],
+    name = "gpu_compiler",
+    deps = if_cuda_is_configured([
+        ":nvptx_compiler",
+    ]) + if_rocm_is_configured([
+        ":amdgpu_compiler",
+    ]),
+    alwayslink = True,  # Contains compiler registration
+)
+
+cc_library(
+    name = "gpu_compiler_impl",
+    srcs = [
+        "gpu_compiler.cc",
+    ],
+    hdrs = [
+        "gpu_compiler.h",
+    ],
     deps = [
         ":cudnn_batchnorm_rewriter",
         ":cudnn_conv_algorithm_picker",
-        ":cudnn_conv_pad_for_tensor_cores",
         ":cudnn_conv_padding_legalization",
         ":cudnn_conv_rewriter",
-        ":cudnn_fused_conv_rewriter",
-        ":cusolver_rewriter",
         ":fusion_merger",
-        ":gemm_algorithm_picker",
-        ":gemm_rewriter",
         ":gpu_constants",
         ":gpu_copy_insertion",
         ":gpu_executable",
@@ -1038,6 +1052,7 @@
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:rng_expander",
         "//tensorflow/compiler/xla/service:slice_sinker",
+        "//tensorflow/compiler/xla/service:slow_operation_alarm",
         "//tensorflow/compiler/xla/service:sort_simplifier",
         "//tensorflow/compiler/xla/service:stable_sort_expander",
         "//tensorflow/compiler/xla/service:transpose_folding",
@@ -1048,15 +1063,12 @@
         "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
         "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
-        "//tensorflow/core:cuda_libdevice_path",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:regexp_internal",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:stream_executor_headers",
-        "//tensorflow/stream_executor/cuda:cuda_diagnostics",
-        "//tensorflow/stream_executor/cuda:ptxas_utils",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -1068,12 +1080,146 @@
 
 cc_library(
     name = "nvptx_compiler",
-    srcs = ["nvptx_compiler_registration.cc"],
-    deps = [":nvptx_compiler_impl"],
+    srcs = if_cuda_is_configured([
+        "nvptx_compiler_registration.cc",
+    ]),
+    deps = if_cuda_is_configured([
+        "nvptx_compiler_impl",
+    ]),
     alwayslink = True,  # Contains compiler registration
 )
 
 cc_library(
+    name = "nvptx_compiler_impl",
+    srcs = if_cuda_is_configured([
+        "nvptx_compiler.cc",
+    ]),
+    hdrs = if_cuda_is_configured([
+        "nvptx_compiler.h",
+    ]),
+    deps = [
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+        "@llvm//:core",
+        "//tensorflow/compiler/xla:protobuf_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_expander",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:call_inliner",
+        "//tensorflow/compiler/xla/service:conditional_simplifier",
+        "//tensorflow/compiler/xla/service:convolution_group_converter",
+        "//tensorflow/compiler/xla/service:dot_decomposer",
+        "//tensorflow/compiler/xla/service:dump",
+        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
+        "//tensorflow/compiler/xla/service:executable",
+        "//tensorflow/compiler/xla/service:flatten_call_graph",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_constant_folding",
+        "//tensorflow/compiler/xla/service:hlo_cse",
+        "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
+        "//tensorflow/compiler/xla/service:hlo_dce",
+        "//tensorflow/compiler/xla/service:hlo_element_type_converter",
+        "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/compiler/xla/service:hlo_proto_util",
+        "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
+        "//tensorflow/compiler/xla/service:hlo_verifier",
+        "//tensorflow/compiler/xla/service:llvm_compiler",
+        "//tensorflow/compiler/xla/service:mem_wasted_on_passthrough_params",
+        "//tensorflow/compiler/xla/service:reduce_precision_insertion",
+        "//tensorflow/compiler/xla/service:reshape_mover",
+        "//tensorflow/compiler/xla/service:rng_expander",
+        "//tensorflow/compiler/xla/service:slice_sinker",
+        "//tensorflow/compiler/xla/service:slow_operation_alarm",
+        "//tensorflow/compiler/xla/service:sort_simplifier",
+        "//tensorflow/compiler/xla/service:stable_sort_expander",
+        "//tensorflow/compiler/xla/service:transpose_folding",
+        "//tensorflow/compiler/xla/service:tuple_simplifier",
+        "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
+        "//tensorflow/compiler/xla/service:while_loop_simplifier",
+        "//tensorflow/compiler/xla/service:while_loop_trip_count_annotator",
+        "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
+        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/stream_executor:stream_executor_headers",
+    ] + if_cuda_is_configured([
+        ":cudnn_batchnorm_rewriter",
+        ":cudnn_conv_algorithm_picker",
+        ":cudnn_conv_padding_legalization",
+        ":cudnn_conv_rewriter",
+        ":cudnn_conv_pad_for_tensor_cores",
+        ":cudnn_fused_conv_rewriter",
+        ":cusolver_rewriter",
+        ":fusion_merger",
+        ":gemm_algorithm_picker",
+        ":gemm_rewriter",
+        ":gpu_compiler_impl",
+        ":gpu_constants",
+        ":gpu_copy_insertion",
+        ":gpu_executable",
+        ":gpu_hlo_schedule",
+        ":gpu_hlo_support_checker",
+        ":gpu_layout_assignment",
+        ":gpu_sanitize_constant_names",
+        ":gpu_scatter_expander",
+        ":instruction_fusion",
+        ":ir_emitter",
+        ":ir_emission_utils",
+        ":multi_output_fusion",
+        ":partition_assignment",
+        ":stream_assignment",
+        ":stream_executor_util",
+        ":target_constants",
+        "//tensorflow/core:cuda_libdevice_path",
+        "//tensorflow/stream_executor/cuda:cuda_diagnostics",
+        "//tensorflow/stream_executor/cuda:ptxas_utils",
+    ]),
+)
+
+cc_library(
+    name = "amdgpu_compiler",
+    srcs = if_rocm_is_configured([
+        "amdgpu_compiler_registration.cc",
+    ]),
+    deps = if_rocm_is_configured([
+        "amdgpu_compiler_impl",
+    ]),
+    alwayslink = True,  # Contains compiler registration
+)
+
+cc_library(
+    name = "amdgpu_compiler_impl",
+    srcs = if_rocm_is_configured([
+        # TODO(whchung@gmail.com) : enable in the subsequent PR.
+        #"amdgpu_compiler.cc",
+    ]),
+    hdrs = if_rocm_is_configured([
+        # TODO(whchung@gmail.com): enable in the subsequent PR.
+        #"amdgpu_compiler.h"
+    ]),
+    deps = if_rocm_is_configured([
+        # TODO(whchung@gmail.com): Enable these after pending PRs get merged.
+        #":gpu_compiler_impl",
+        #":miopen_conv_algorithm_picker",
+        #"//tensorflow/core:rocm_rocdl_path",
+    ]),
+)
+
+cc_library(
     name = "cudnn_batchnorm_rewriter",
     srcs = ["cudnn_batchnorm_rewriter.cc"],
     hdrs = ["cudnn_batchnorm_rewriter.h"],
@@ -1411,3 +1557,30 @@
         "//tensorflow/core:autotuning_proto_cc",
     ],
 )
+
+cc_library(
+    name = "cudnn_conv_blacklist",
+    srcs = ["cudnn_conv_blacklist.cc"],
+    hdrs = ["cudnn_conv_blacklist.h"],
+    deps = [
+        ":gpu_autotuning_proto",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/core:autotuning_proto_cc",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_cc_test(
+    name = "cudnn_conv_blacklist_test",
+    srcs = ["cudnn_conv_blacklist_test.cc"],
+    data = ["data/cudnn_conv_blacklist.pbtxt"],
+    deps = [
+        ":cudnn_conv_blacklist",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor:dnn",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
index ce17e02..136988f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
@@ -142,10 +143,8 @@
   XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
                                  2);
   using RedzoneCheckStatus = se::cuda::RedzoneAllocator::RedzoneCheckStatus;
-
   TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
-                      allocator.CheckRedzones(stream));
-
+                      allocator.CheckRedzones());
   if (redzone_check.ok()) {
     return true;
   }
@@ -235,7 +234,6 @@
   return result_or;
 }
 
-
 StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     const HloCustomCallInstruction* instr) {
   XLA_SCOPED_LOGGING_TIMER(
@@ -253,8 +251,6 @@
   // Create a stream for us to do our work on.
   se::Stream stream{stream_exec_};
   stream.Init();
-  const auto device_ordinal = stream_exec_->device_ordinal();
-
   // allocator either points to this->allocator_ or, if that's null, to a
   // se::StreamExecutorMemoryAllocator for stream_exec_.
   se::DeviceMemoryAllocator* allocator;
@@ -278,18 +274,18 @@
 
   // Allocate space for the input, filter, and output of the convolution.
   se::cuda::RedzoneAllocator input_output_allocator(
-      device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config));
+      &stream, allocator, PtxOptsFromConfig(hlo_module_config));
   std::vector<se::DeviceMemoryBase> operand_buffers;
   for (const auto* operand : instr->operands()) {
     TF_ASSIGN_OR_RETURN(auto buffer,
                         input_output_allocator.AllocateBytes(
-                            &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+                            ShapeUtil::ByteSizeOf(operand->shape())));
     initialize_buffer(buffer);
     operand_buffers.push_back(buffer);
   }
   TF_ASSIGN_OR_RETURN(auto result_buffer,
                       input_output_allocator.AllocateBytes(
-                          &stream, ShapeUtil::ByteSizeOf(result_shape)));
+                          ShapeUtil::ByteSizeOf(result_shape)));
   initialize_buffer(result_buffer);
 
   TF_ASSIGN_OR_RETURN(auto backend_config,
@@ -311,14 +307,27 @@
   const bool crash_on_checking_failure =
       debug_options.xla_gpu_crash_on_verification_failures();
 
+  const auto canonical_hlo =
+      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));
+
+  absl::Span<const AlgorithmDesc> blacklisted_algos =
+      GetBlacklistedAlgorithms(GetComputeCapability(stream_exec_),
+                               GetCudnnVersion(stream_exec_), canonical_hlo);
+
   for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
     XLA_SCOPED_LOGGING_TIMER_LEVEL(
         absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                      AlgorithmToString(alg)),
         2);
 
+    if (absl::c_linear_search(blacklisted_algos, alg)) {
+      LOG(INFO) << "Omitted potentially buggy algorithm "
+                << AlgorithmToString(alg) << " for conv " << instr->ToString();
+      continue;
+    }
+
     se::cuda::RedzoneAllocator scratch_allocator(
-        device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config));
+        &stream, allocator, PtxOptsFromConfig(hlo_module_config));
     se::dnn::ProfileResult profile_result;
     VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
             << instr->ToString();
@@ -361,6 +370,22 @@
 
     if (!input_output_allocator_redzone_clear ||
         !scratch_allocator_redzone_clear) {
+      CudnnConvolutionList proto;
+      auto entry = proto.add_entries();
+      entry->set_hlo(canonical_hlo);
+      *entry->mutable_cc() = GetComputeCapability(stream_exec_);
+      *entry->add_cudnn_versions() = GetCudnnVersion(stream_exec_);
+      auto algo = entry->add_algos();
+      algo->set_id(alg.algo_id());
+      algo->set_tensor_ops(alg.tensor_ops_enabled());
+
+      LOG(ERROR)
+          << "To blacklist this algorithm for this convolution, "
+             "copy-paste the following "
+             "proto to the blacklist file pointed by XLA_FLAGS "
+             "--xla_gpu_cudnn_conv_blacklist_path="
+          << GetDebugOptionsFromFlags().xla_gpu_cudnn_conv_blacklist_path()
+          << " : " << proto.ShortDebugString();
       continue;
     }
 
@@ -402,7 +427,7 @@
       comparator.emplace(result_shape, hlo_module_config);
       TF_ASSIGN_OR_RETURN(
           reference_result_buffer,
-          input_output_allocator.AllocateBytes(&stream, result_buffer.size()));
+          input_output_allocator.AllocateBytes(result_buffer.size()));
       stream.ThenMemcpy(&reference_result_buffer, result_buffer,
                         result_buffer.size());
       first_algorithm = alg;
@@ -431,6 +456,14 @@
     *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
     log.set_device_pci_bus_id(
         stream_exec_->GetDeviceDescription().pci_bus_id());
+    {
+      string blas_version;
+      if (auto* blas = stream_exec_->AsBlas()) {
+        if (blas->GetVersion(&blas_version).ok()) {
+          log.set_blas_version(blas_version);
+        }
+      }
+    }
     VLOG(1) << "Autotuning result: " << log.ShortDebugString();
     // If we crash on checking failure, we are in a testing/benchmark mode, thus
     // omitting logging through the logger.
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.cc
new file mode 100644
index 0000000..4d55ddb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.cc
@@ -0,0 +1,67 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h"
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
+
+namespace xla {
+namespace gpu {
+
+absl::Span<const stream_executor::dnn::AlgorithmDesc> GetBlacklistedAlgorithms(
+    tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version,
+    absl::string_view hlo) {
+  // Key is the tuple of canonicalized hlo, compute capability major/minor,
+  // cudnn version major/minor/patch.
+  using MapType =
+      absl::flat_hash_map<std::tuple<std::string, int, int, int, int, int>,
+                          std::vector<stream_executor::dnn::AlgorithmDesc>>;
+
+  static MapType* blacklist = [] {
+    MapType* list = new MapType();
+    CudnnConvolutionList proto;
+    std::string file_path =
+        GetDebugOptionsFromFlags().xla_gpu_cudnn_conv_blacklist_path();
+    if (!file_path.empty()) {
+      TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
+                                            file_path, &proto));
+    }
+    for (const auto& entry : proto.entries()) {
+      for (const auto& cudnn_version : entry.cudnn_versions()) {
+        for (const auto& algo : entry.algos()) {
+          (*list)[std::make_tuple(std::string(entry.hlo()), entry.cc().major(),
+                                  entry.cc().minor(), cudnn_version.major(),
+                                  cudnn_version.minor(), cudnn_version.patch())]
+              .push_back({algo.id(), algo.tensor_ops()});
+        }
+      }
+    }
+    return list;
+  }();
+
+  auto iter = blacklist->find(std::make_tuple(
+      std::string(hlo), cc.major(), cc.minor(), cudnn_version.major(),
+      cudnn_version.minor(), cudnn_version.patch()));
+  if (iter != blacklist->end()) {
+    return iter->second;
+  }
+  return {};
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h
new file mode 100644
index 0000000..df14955
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h
@@ -0,0 +1,34 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_BLACKLIST_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_BLACKLIST_H_
+
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/protobuf/autotuning.pb.h"
+
+namespace xla {
+namespace gpu {
+
+absl::Span<const stream_executor::dnn::AlgorithmDesc> GetBlacklistedAlgorithms(
+    tensorflow::ComputeCapability, tensorflow::CudnnVersion, absl::string_view);
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_BLACKLIST_H_
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist_test.cc
new file mode 100644
index 0000000..09af973
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist_test.cc
@@ -0,0 +1,72 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_blacklist.h"
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class BlacklistTest : public testing::Test {
+ protected:
+  BlacklistTest() {
+    setenv("XLA_FLAGS",
+           absl::StrCat(
+               "--xla_gpu_cudnn_conv_blacklist_path=",
+               tensorflow::io::JoinPath(
+                   tensorflow::testing::TensorFlowSrcRoot(), "compiler", "xla",
+                   "service", "gpu", "data", "cudnn_conv_blacklist.pbtxt"))
+               .data(),
+           0);
+  }
+};
+
+TEST_F(BlacklistTest, DefaultTest) {
+  tensorflow::ComputeCapability cc;
+  cc.set_major(7);
+  cc.set_minor(0);
+  tensorflow::CudnnVersion cudnn_version;
+  cudnn_version.set_major(7);
+  cudnn_version.set_minor(6);
+  cudnn_version.set_patch(2);
+  auto list = GetBlacklistedAlgorithms(
+      cc, cudnn_version,
+      R"((f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}")");
+  ASSERT_EQ(4, list.size());
+  EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(0, false), list[0]);
+  EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(0, true), list[1]);
+  EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, false), list[2]);
+  EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, true), list[3]);
+}
+
+TEST_F(BlacklistTest, NegativeTest) {
+  tensorflow::ComputeCapability cc;
+  cc.set_major(7);
+  cc.set_minor(0);
+  tensorflow::CudnnVersion cudnn_version;
+  cudnn_version.set_major(7);
+  cudnn_version.set_minor(6);
+  cudnn_version.set_minor(2);
+  auto list = GetBlacklistedAlgorithms(cc, cudnn_version, R"(invalid hlo)");
+  ASSERT_EQ(0, list.size());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc
index 5aa76ac..da5059e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc
@@ -48,12 +48,10 @@
 
   ~ScratchBufAllocator() override = default;
 
-  int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override {
-    return scratch_.size();
-  }
+  int64 GetMemoryLimitInBytes() override { return scratch_.size(); }
 
   se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     if (allocated_) {
       return se::port::InternalError(
           "Can't allocate twice from a ScratchBufAllocator.");
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
index dee257a..aca7307 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
@@ -223,6 +223,7 @@
   }
   auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
       conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+  new_conv->set_feature_group_count(conv->feature_group_count());
   new_conv->set_window(conv->window());
   new_conv->set_convolution_dimension_numbers(
       conv->convolution_dimension_numbers());
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
index 7aa442d..b621880 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
@@ -163,6 +163,26 @@
     })");
 }
 
+TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
+  EXPECT_TRUE(RunAndCompare(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(inf)
+      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+      alpha_conv_scalar = f32[] constant(0.999994934)
+
+      input = f32[1,17,9,9] parameter(0)
+      filter = f32[3,3,17,32] parameter(1)
+
+      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+      alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+      scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
+      ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
+    })",
+                            ErrorSpec{0.01}));
+}
+
 TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) {
   // max(0, conv(x, w) + 0.899994934 * side_input);
   TestMatchWithAllTypes(R"(
@@ -305,6 +325,30 @@
       ::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})"));
 }
 
+TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
+  // The convolution below would crash if feature_count is not preserved.
+  const char* kHloString = R"(
+    HloModule jaxpr_computation__6.19
+
+    primitive_computation__1.4 {
+      parameter.5 = f32[] parameter(0)
+      parameter.6 = f32[] parameter(1)
+      ROOT add.7 = f32[] add(parameter.5, parameter.6)
+    }
+
+    ENTRY jaxpr_computation__7.8 {
+      parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
+      parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
+      convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
+      constant.13 = f32[] constant(0)
+      broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
+      maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
+      ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
+    }
+  )";
+  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/data/cudnn_conv_blacklist.pbtxt b/tensorflow/compiler/xla/service/gpu/data/cudnn_conv_blacklist.pbtxt
new file mode 100644
index 0000000..50cf947
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/data/cudnn_conv_blacklist.pbtxt
@@ -0,0 +1,6 @@
+entries {
+  hlo: '(f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}"'
+  cc: {major: 7, minor: 0}
+  cudnn_versions: [{major: 7, minor: 6, patch: 0}, {major: 7, minor: 6, patch: 2}]
+  algos: [{}, {tensor_ops: true}, {id: 1}, {id:1, tensor_ops: true}]
+}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index a8dae7d..6e72135 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
 
 #include <stddef.h>
+
 #include <unordered_map>
 #include <vector>
 
@@ -269,8 +270,19 @@
   // Upcast F16 to F32 if necessary.
   llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
   llvm::Value* input = FPCast(value, type);
+
+  // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
+  constexpr double kMaxValue = 20.0;
+  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
+  llvm::Value* abs_value =
+      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
+
   llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
-  return FPCast(fast_tanh, value->getType());
+  auto one = llvm::ConstantFP::get(type, 1.0);
+  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
+                                                    {one, input}, {type}, b_);
+  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
+                value->getType());
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index da90ba9..991a463 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -32,20 +32,20 @@
     int device_ordinal, se::DeviceMemoryAllocator* memory_allocator)
     : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
 
-int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
+int64 FftScratchAllocator::GetMemoryLimitInBytes() {
   constexpr int64 kFftScratchSize = 1LL << 32;  // 4GB by default.
   return kFftScratchSize;
 }
 
 StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
-    se::Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
+  if (byte_size > GetMemoryLimitInBytes()) {
     return se::port::Status(
         se::port::error::RESOURCE_EXHAUSTED,
         absl::StrFormat(
             "Allocating %d bytes exceeds the memory limit of %d bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
+            byte_size, GetMemoryLimitInBytes()));
   }
 
   TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index be77df1..95186c7 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -40,12 +40,12 @@
   FftScratchAllocator(int device_ordinal,
                       se::DeviceMemoryAllocator* memory_allocator);
 
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override;
+  int64 GetMemoryLimitInBytes() override;
 
   int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
 
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override;
+      int64 byte_size) override;
 
  private:
   const int device_ordinal_;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index 626bef7..24a2dce 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -110,7 +110,7 @@
 
     TF_ASSIGN_OR_RETURN(
         se::cuda::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
-        allocator.CheckRedzones(stream));
+        allocator.CheckRedzones());
     if (!rz_check_status.ok()) {
       result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
       *result.mutable_failure()->mutable_msg() =
@@ -244,8 +244,7 @@
 
   const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
   se::cuda::RedzoneAllocator input_output_allocator(
-      executor->device_ordinal(), allocator,
-      PtxOptsFromConfig(hlo_module_config));
+      &stream, allocator, PtxOptsFromConfig(hlo_module_config));
 
   BufferComparator comparator(instr->shape(), hlo_module_config);
 
@@ -254,7 +253,7 @@
       [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
     TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
                         input_output_allocator.AllocateBytes(
-                            &stream, ShapeUtil::ByteSizeOf(op->shape())));
+                            ShapeUtil::ByteSizeOf(op->shape())));
     InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state,
                           buffer);
     return buffer;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
index df7ee3c..bdf697a 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -32,23 +32,6 @@
 
 namespace m = match;
 
-static complex128 GetScalarConstantAsComplex(const Literal &literal) {
-  switch (literal.shape().element_type()) {
-    case F16:
-      return {static_cast<double>(literal.Get<Eigen::half>({})), 0};
-    case F32:
-      return {literal.Get<float>({}), 0};
-    case F64:
-      return {literal.Get<double>({}), 0};
-    case C64:
-      return literal.Get<complex64>({});
-    case C128:
-      return literal.Get<complex128>({});
-    default:
-      LOG(FATAL) << "Unexpected type: " << literal.shape();
-  }
-}
-
 // The rewriting proceeds in a bottom-up way:
 //
 // (kDot A B) is rewritten into a (kCustomCall:gemm A B)
@@ -103,7 +86,7 @@
       if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
         complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
         complex128 new_alpha =
-            GetScalarConstantAsComplex(alpha->literal()) * prev_alpha;
+            *alpha->literal().GetAsComplex128({}) * prev_alpha;
         config.set_alpha_real(new_alpha.real());
         config.set_alpha_imag(new_alpha.imag());
         TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto
index 6ed7243..1fada38 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto
+++ b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto
@@ -6,6 +6,7 @@
 
 import "tensorflow/compiler/xla/service/hlo.proto";
 import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/core/protobuf/autotuning.proto";
 
 message ConvInstructionLog {
   xla.HloInstructionProto instruction = 1;
@@ -13,3 +14,19 @@
   uint64 result_address = 3;
   repeated uint64 operand_addresses = 4;
 }
+
+message CudnnConvAlgorithm {
+  int64 id = 1;
+  bool tensor_ops = 2;
+}
+
+message CudnnConvolutionEntry {
+  string hlo = 1;
+  tensorflow.ComputeCapability cc = 2;
+  repeated tensorflow.CudnnVersion cudnn_versions = 3;
+  repeated CudnnConvAlgorithm algos = 4;
+}
+
+message CudnnConvolutionList {
+  repeated CudnnConvolutionEntry entries = 1;
+}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
new file mode 100644
index 0000000..bcf3c54
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -0,0 +1,476 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
+
+#include <stdlib.h>
+
+#include <atomic>
+#include <functional>
+#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
+#include "tensorflow/compiler/xla/service/dot_decomposer.h"
+#include "tensorflow/compiler/xla/service/dump.h"
+#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
+#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/mem_wasted_on_passthrough_params.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/rng_expander.h"
+#include "tensorflow/compiler/xla/service/slice_sinker.h"
+#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
+#include "tensorflow/compiler/xla/service/sort_simplifier.h"
+#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
+#include "tensorflow/compiler/xla/service/transpose_folding.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
+#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace xla {
+namespace gpu {
+
+GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
+                         const char* target_triple, const char* data_layout)
+    : platform_id_(platform_id),
+      target_triple_(target_triple),
+      data_layout_(data_layout),
+      pointer_size_(llvm::DataLayout(data_layout)
+                        .getPointerSize(0 /* default address space */)) {}
+
+// Runs optimization passes on the given HLO module.
+Status GpuCompiler::OptimizeHloModule(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  {
+    HloPassPipeline pipeline("optimization");
+    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+                                              /*allow_mixed_precision=*/false);
+
+    // Expand random number generation.
+    pipeline.AddPass<RngExpander>();
+
+    // Remove zero-sized HLO from the input so that other passes don't have to
+    // handle it.
+    pipeline.AddPass<ZeroSizedHloElimination>();
+
+    pipeline.AddPass<GpuScatterExpander>();
+
+    pipeline.AddPass<DynamicIndexSplitter>();
+    pipeline.AddPass<GpuHloSupportChecker>();
+    ReducePrecisionInsertion::AddPasses(
+        &pipeline, hlo_module->config().debug_options(),
+        ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
+
+    // TODO(b/64094172): make Call work on GPU instead of inlining.
+    pipeline.AddPass<CallInliner>();
+    auto cost_model = [](HloInstruction* conv) {
+      // We need a cost model for GPUs. Currently, do nothing.
+      return false;
+    };
+    pipeline.AddPass<DotDecomposer>();
+    pipeline.AddPass<ConvolutionGroupConverter>(
+        cost_model,
+        /*convert_batch_groups_only=*/true);
+    // Expand the sort op to support stable sorting if required.
+    pipeline.AddPass<StableSortExpander>();
+    // Convert BF16 operations to F32 operations so that the GPU backend can
+    // support BF16 operations without directly implementing a BF16 lowering for
+    // most ops.
+    pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
+    {
+      auto& pass =
+          pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+      pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+                                            /*allow_mixed_precision=*/false);
+
+      // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
+      // where possible.  Not every batchnorm op can be implemented as a call to
+      // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
+      if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
+        pass.AddPass<CudnnBatchNormRewriter>();
+      }
+      pass.AddPass<BatchNormExpander>(
+          /*rewrite_training_op=*/true,
+          /*rewrite_inference_op=*/true,
+          /*rewrite_grad_op=*/true);
+
+      pipeline.AddPass<HloGetDimensionSizeRewriter>();
+
+      // BatchNormExpander can create zero-sized ops, so zero-sized HLO
+      // elimination has to come after that pass.
+      pipeline.AddPass<ZeroSizedHloElimination>();
+
+      AlgebraicSimplifierOptions options;
+      pass.AddPass<AlgebraicSimplifier>(options);
+      pass.AddPass<SortSimplifier>();
+      pass.AddPass<TupleSimplifier>();
+      pass.AddPass<WhileLoopConstantSinking>();
+      pass.AddPass<WhileLoopSimplifier>();
+
+      // TODO(b/134075051): Re-enable after b/134075051 is fixed.
+      // pass.AddPass<SliceSinker>();
+
+      pass.AddPass<HloDCE>();
+      pass.AddPass<ReshapeMover>();
+      pass.AddPass<HloConstantFolding>();
+      pass.AddPass<ConditionalSimplifier>();
+    }
+
+    pipeline.AddPass<TransposeFolding>(
+        [](const HloInstruction& dot,
+           const TransposeFolding::OperandIndices& candidate_operands) {
+          return IsMatrixMultiplication(dot)
+                     ? candidate_operands
+                     : TransposeFolding::OperandIndices{};
+        },
+        TransposeFolding::NeverFoldTranspose);
+    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
+    pipeline.AddPass<HloDCE>();
+
+    // Run WhileLoopTripCountAnnotator at the end of the simplification
+    // pipeline, before layout assignment and fusion.  This pass does some
+    // pattern-matching on while bodies/conditions, and this is where the HLO is
+    // "nicest".
+    //
+    // It's important that we don't make semantic changes (e.g. unrolling) to
+    // any `while` loops after this point, because otherwise the trip-count
+    // annotations added by this pass may not be correct after the
+    // modifications.
+    pipeline.AddPass<WhileLoopTripCountAnnotator>();
+    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+  }
+
+  // Run target-specific HLO optimization passes for convolution
+  // canonicalization.
+  TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
+      hlo_module, stream_exec, device_allocator));
+
+  {
+    // Run layout assignment in a separate pipeline from
+    // "post-layout-assignment" because we want everything after layout
+    // assignment to have a layout-sensitive invariant-checker, but
+    // HloPassPipeline also runs its invariant checker before any passes are
+    // run, meaning, the pipeline that contains layout assignment cannot contain
+    // a layout-sensitive verifier!
+    HloPassPipeline pipeline("layout assignment");
+    pipeline.AddPass<GpuLayoutAssignment>(
+        hlo_module->mutable_entry_computation_layout(),
+        LayoutAssignment::InstructionCanChangeLayout, stream_exec);
+    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+  }
+
+  // Run target-specific HLO optimization passes after layout assignment.
+  TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
+                                                     device_allocator));
+
+  {
+    HloPassFix<HloPassPipeline> fusion("fusion");
+    // We try to split variadic ops with many parameters into several such ops
+    // to avoid exceeding the parameter space.
+    fusion.AddPass<VariadicOpSplitter>();
+    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
+     * fixing the ticket. */
+    fusion.AddInvariantChecker<HloVerifier>(
+        /*layout_sensitive=*/true,
+        /*allow_mixed_precision=*/false,
+        LayoutAssignment::InstructionCanChangeLayout);
+    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
+    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
+    fusion.AddPass<FusionMerger>();
+    fusion.AddPass<GpuMultiOutputFusion>();
+    fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
+                           /*only_fusion_computations=*/true);
+    fusion.AddPass<HloDCE>();
+    TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+
+    HloPassPipeline reduce_pipeline("reduce-precision");
+    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
+     * fixing the ticket. */
+    reduce_pipeline.AddInvariantChecker<HloVerifier>(
+        /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
+        LayoutAssignment::InstructionCanChangeLayout);
+    ReducePrecisionInsertion::AddPasses(
+        &reduce_pipeline, hlo_module->config().debug_options(),
+        ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
+    StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
+    TF_RETURN_IF_ERROR(reduce_result.status());
+
+    if (reduce_result.ValueOrDie()) {
+      // Do another fusion pass, with the expectation that we may be able to
+      // fuse the new ReducePrecision operations.
+      TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+    }
+  }
+
+  return Status::OK();
+}
+
+// Modifies the given HLO module so that it will be accepted by IrEmitter.
+// Unlike optimization passes, the passes are necessary for correctness.
+Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
+  // In some cases, we have to place the result of an instruction in a temporary
+  // buffer. For instance, the buffer that holds an external parameter is
+  // assumed immutable at this point, and should not be reused for output
+  // (b/27180329). Therefore, in that case, we set the output to be a copy of
+  // the parameter.
+  HloPassPipeline pipeline("GPU-ir-emit-prepare");
+  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
+   * fixing the ticket. */
+  pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true,
+      /*allow_mixed_precision=*/false,
+      LayoutAssignment::InstructionCanChangeLayout);
+
+  // Copy insertion should be performed immediately before IR emission to avoid
+  // inserting unnecessary copies (later pass adds an instruction which
+  // materializes the value) or missing a necessary copy (later pass removes an
+  // instruction which materializes a value). DCE must be run immediately before
+  // (and sometime after) copy insertion, to avoid dead code from interfering
+  // with the rewrites.
+  pipeline.AddPass<HloDCE>();
+  pipeline.AddPass<FlattenCallGraph>();
+  // The following pass LOGs memory waste. Add it when VLOGing is enabled only.
+  if (VLOG_IS_ON(2)) {
+    pipeline.AddPass<MemWastedOnPassthroughParams>();
+  }
+  pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
+  pipeline.AddPass<GpuSanitizeConstantNames>();
+  return pipeline.Run(hlo_module).status();
+}
+
+StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
+    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  // We dump the post-optimization HLO in RunBackend so no need to dump it here.
+  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
+  tensorflow::profiler::TraceMe activity(
+      [&] { return absl::StrCat("HLO Transforms:", module->name()); },
+      tensorflow::profiler::TraceMeLevel::kInfo);
+  TF_RETURN_IF_ERROR(
+      OptimizeHloModule(module.get(), stream_exec, device_allocator));
+
+  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
+
+  return std::move(module);
+}
+
+StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
+    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
+  auto slow_compile_alarm = SlowCompilationAlarm();
+
+  TF_RET_CHECK(stream_exec != nullptr);
+
+  llvm::LLVMContext llvm_context;
+  std::string buffer;
+  llvm::raw_string_ostream error(buffer);
+  llvm::DiagnosticPrinterRawOStream printer(error);
+  auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
+                              void* Context) {
+    auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
+    diag_info.print(*printer);
+  };
+  llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
+
+  llvm::Module llvm_module(module->name().c_str(), llvm_context);
+  // Set the target triple and the data layout.
+  llvm_module.setTargetTriple(target_triple_);
+  llvm_module.setDataLayout(data_layout_);
+
+  // Determine the HLO schedule, which is an ordering of HLO instructions.  This
+  // is used by buffer assignment to enable buffer reuse, and the same ordering
+  // must also be used to determine the thunk launch schedule.
+  std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<GpuHloSchedule> hlo_schedule,
+      GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
+
+  // Run buffer analysis on the HLO graph. This analysis figures out which
+  // temporary buffers are required to run the computation.
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<BufferAssignment> buffer_assignment,
+      BufferAssigner::Run(
+          module.get(), hlo_schedule->ConsumeHloOrdering(),
+          BufferSizeBytesFunction(),
+          /*color_alignment=*/
+          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
+          /*allocate_buffers_for_constants=*/true,
+          /*colorer=*/BufferAssigner::DefaultColorer(),
+          /*must_not_live_out=*/{}, GetCanShareBuffer()));
+  DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
+
+  IrEmitterContext ir_emitter_context(
+      module.get(), buffer_assignment.get(), stream_exec->platform(),
+      &stream_exec->GetDeviceDescription(), &llvm_module);
+
+  HloComputation* entry_computation = module->entry_computation();
+  IrEmitterUnnested ir_emitter(module->config(), entry_computation,
+                               &ir_emitter_context);
+
+  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
+
+  {
+    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
+    TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
+  }
+
+  if (user_pre_optimization_hook_) {
+    user_pre_optimization_hook_(llvm_module);
+  }
+  string ir_module_string_before_opt;
+  const bool embed_ir_in_executable =
+      module->config().debug_options().xla_embed_ir_in_executable();
+  if (embed_ir_in_executable) {
+    ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module);
+  }
+
+  llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false);
+
+  {
+    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
+
+    std::string err;
+    llvm::raw_string_ostream err_stream(err);
+
+    // verifyModule() returns true if the module is broken.
+    TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
+        << "Invalid LLVM IR before optimizations:\n"
+        << err_stream.str()
+        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
+           "Rerun with --xla_dump_to to get the IR. ";
+  }
+
+  GpuVersion gpu_version = GetGpuVersion(stream_exec);
+
+  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
+  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+                      CompileTargetBinary(module.get(), &llvm_module,
+                                          gpu_version, stream_exec));
+
+  auto thunk_schedule = absl::make_unique<ThunkSchedule>(
+      ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
+      hlo_schedule->ThunkLaunchOrder());
+  if (DumpingEnabledForHloModule(*module)) {
+    DumpToFileInDirOrStdout(*module, "thunk_schedule",
+                            thunk_schedule->ToString());
+  }
+
+  std::unique_ptr<HloProfileIndexMap> profile_index_map;
+  std::unique_ptr<HloProfilePrinterData> profile_printer;
+
+  if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
+    HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
+    cost_analysis.set_bytes_per_second(
+        stream_exec->GetDeviceDescription().memory_bandwidth());
+    TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
+    VLOG(1) << "HLO memory read+written: "
+            << tensorflow::strings::HumanReadableNumBytes(
+                   cost_analysis.bytes_accessed());
+    if (module->config().hlo_profiling_enabled()) {
+      profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
+      profile_printer = CreateHloProfilePrinterData(
+          *profile_index_map, cost_analysis, entry_computation->name());
+    }
+  }
+
+  auto* gpu_executable = new GpuExecutable(
+      backend_result.first, backend_result.second, gpu_version,
+      std::move(thunk_schedule), std::move(module),
+      std::move(buffer_assignment), std::move(profile_printer),
+      std::move(profile_index_map));
+  if (embed_ir_in_executable) {
+    DCHECK_NE("", ir_module_string_before_opt);
+    gpu_executable->set_ir_module_string(ir_module_string_before_opt);
+  }
+  return std::unique_ptr<Executable>(gpu_executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
+                                const AotCompilationOptions& options) {
+  return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
new file mode 100644
index 0000000..901d994
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace xla {
+namespace gpu {
+
+// The GPU compiler generates efficient GPU executables.
+class GpuCompiler : public LLVMCompiler {
+ public:
+  GpuCompiler(se::Platform::Id platform_id, const char* target_triple,
+              const char* data_layout);
+  ~GpuCompiler() override {}
+
+  // Bring in
+  // StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+  //     std::vector<std::unique_ptr<HloModule>> modules,
+  //     std::vector<std::vector<se::StreamExecutor*>>
+  //        stream_execs)
+  using LLVMCompiler::Compile;
+
+  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
+      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) override;
+
+  Status OptimizeHloModule(HloModule* hlo_module,
+                           se::StreamExecutor* stream_exec,
+                           se::DeviceMemoryAllocator* device_allocator);
+
+  virtual Status OptimizeHloConvolutionCanonicalization(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) = 0;
+
+  virtual Status OptimizeHloPostLayoutAssignment(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) = 0;
+
+  virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() {
+    return
+        [](const HloInstruction*, const HloInstruction*,
+           const ShapeIndex&) -> absl::optional<bool> { return absl::nullopt; };
+  }
+
+  virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;
+
+  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
+  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
+                      GpuVersion gpu_version,
+                      se::StreamExecutor* stream_exec) = 0;
+
+  Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
+
+  StatusOr<std::unique_ptr<Executable>> RunBackend(
+      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) override;
+
+  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
+                     AotCompilationOptions const& options) override;
+
+  se::Platform::Id PlatformId() const override { return platform_id_; }
+
+  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
+    // Capture just the pointer size, not the entire GpuCompiler object.
+    int64 pointer_size = pointer_size_;
+    return [pointer_size](const Shape& shape) {
+      return ShapeUtil::ByteSizeOf(shape, pointer_size);
+    };
+  }
+
+ private:
+  se::Platform::Id platform_id_;
+
+  // The triple that represents our target.
+  const char* target_triple_;
+
+  // The data layout of the emitted module.
+  const char* data_layout_;
+
+  // The size in bytes of a pointer. Used by ShapeSizeBytesFunction.
+  const int64 pointer_size_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index e4942bd..2706b4f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -195,10 +195,11 @@
   }
 
   main_stream->ThenWaitFor(&sub_streams);
-  // Make sure kernels are completed before deallocating temporary buffers.
+  // Make sure kernels are completed before deallocating temporary buffers or
+  // the profiler state.
   // TODO(b/30100571): we could potentially postpone deallocating the temp
   // buffers until a different computation is executed.
-  if (block_host_until_done) {
+  if (do_profile || block_host_until_done) {
     Status block_status = main_stream->BlockHostUntilDone();
     if (!block_status.ok()) {
       return InternalError(
@@ -207,17 +208,20 @@
     }
   }
 
+  // FinishExecution() blocks until main_stream has completed if profiling is
+  // enabled; we therefore do not need to defer profile collection onto a
+  // stream.
   profiler.FinishExecution();
   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 
-  {
-    tensorflow::mutex_lock lock(mutex_);
+  if (run_options->run_options().execution_profile()) {
+    ExecutionProfile* profile = run_options->run_options().execution_profile();
     const double nanoseconds = (end_micros - start_micros) * 1000.0;
-    execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+    profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
 
     // If hlo profiling was disabled then the cycle count is left empty.
     if (do_profile) {
-      execution_profile_.set_compute_cycle_count(
+      profile->set_compute_cycle_count(
           hlo_execution_profile->total_cycles_executed(
               *module().entry_computation()));
     }
@@ -241,8 +245,13 @@
   module_spec.AddCudaPtxInMemory(text().c_str());
 
   absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
+  if (module_spec.cuda_ptx_in_memory() == nullptr) {
+    // No custom PTX => no globals.
+    return &module_globals_.emplace(executor, std::move(globals)).first->second;
+  }
+
   se::ModuleHandle module_handle;
-  executor->LoadModule(module_spec, &module_handle);
+  TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
 
   for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
        ++i) {
@@ -402,25 +411,16 @@
   return std::move(shaped_buffer);
 }
 
-StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
+StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  // TODO(b/134086343): ExecuteOnStream should not be async according to the
-  // documentation, instead ExecuteAsyncOnStream should be used.
-  return Execute(run_options, arguments, hlo_execution_profile,
-                 /*block_host_until_done=*/
-                 !run_options->allocator()->AllowsAsynchronousDeallocation());
-}
-
-StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span<const ShapedBuffer* const> arguments) {
   se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
   // Force synchronous execution if the allocator requires it.
   bool block_host_until_done =
       !memory_allocator->AllowsAsynchronousDeallocation();
-  return Execute(run_options, arguments, nullptr, block_host_until_done);
+  return Execute(run_options, arguments, hlo_execution_profile,
+                 block_host_until_done);
 }
 
 const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
@@ -428,5 +428,14 @@
       module().entry_computation()->root_instruction());
 }
 
+int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
+  // Non-empty PTX but empty cubin: compilation must have failed, return
+  // "unknown".
+  if (binary().empty() && !text_.empty()) {
+    return -1;
+  }
+  return binary().size();
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 5f9fe3e..0175e31 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -61,6 +61,8 @@
                 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
   ~GpuExecutable() override;
 
+  int64 SizeOfGeneratedCodeInBytes() override;
+
   // This should be called after set_ir_module_string.
   const string& ir_module_string() const { return ir_module_string_; }
 
@@ -78,17 +80,13 @@
   // compilation is left up to the GPU driver.
   const std::vector<uint8>& binary() const { return binary_; }
 
-  // ExecuteOnStream will fail if the compute capability of the stream doesn't
-  // match the compute capability passed to this object's constructor.
-  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
+  // ExecuteAsyncOnStream will fail if the compute capability of the stream
+  // doesn't match the compute capability passed to this object's constructor.
+  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
-  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span<const ShapedBuffer* const> arguments) override;
-
   std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
     return assignment_;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 2d266b9..97fa275 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -16,9 +16,11 @@
 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
 
 #include <iterator>
+#include <stack>
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape.h"
@@ -26,8 +28,8 @@
 
 namespace xla {
 namespace gpu {
-
 namespace {
+
 void AppendParams(const HloInstruction& instr,
                   std::vector<HloInstruction*>* params) {
   if (instr.opcode() == HloOpcode::kFusion) {
@@ -39,6 +41,25 @@
     }
   }
 }
+
+bool CodegensIntoLoop(const HloInstruction& instr) {
+  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
+  if (instr.opcode() == HloOpcode::kReduce &&
+      !IsReductionFromOrToContiguousDimensions(instr)) {
+    return true;
+  }
+  // Reduce window codegens into loop only when windows overlap, i.e. stride is
+  // less than window size.
+  if (instr.opcode() == HloOpcode::kReduceWindow) {
+    for (const auto& dim : instr.window().dimensions()) {
+      if (dim.size() > dim.stride()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace
 
 bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
@@ -202,19 +223,19 @@
   if (!IsLoopFusible(producer) || !IsFusible(consumer)) {
     return false;
   }
-
   // Skip multiple output fusion. It's not yet supported.
   if (producer.IsMultiOutputFusion()) {
     return false;
   }
-
+  if (CreatesNestedLoop(producer, consumer)) {
+    return false;
+  }
   // Do not fuse into reduce input fusions if the resulting kernel would suffer
   // from poor data locality (due to unfriendly input layouts).
   if (IsInputFusibleReduction(consumer) &&
       !LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
     return false;
   }
-
   // We can't fuse library calls, so if a user of such an op could become a
   // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
   // further rationale.
@@ -222,7 +243,6 @@
       ImplementedAsLibraryCall(*producer.operand(0))) {
     return false;
   }
-
   // Fuse scalar constants into loop fusion nodes. This reduces the number of
   // parameters and makes matching scalar broadcasts easier.
   //
@@ -235,7 +255,6 @@
     return ShapeUtil::IsEffectiveScalar(producer.shape()) &&
            consumer.opcode() == HloOpcode::kFusion;
   }
-
   return true;
 }
 
@@ -249,15 +268,15 @@
   if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
     return false;
   }
-
+  if (CreatesNestedLoop(producer, consumer)) {
+    return false;
+  }
   if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
     return false;
   }
-
   if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
     return false;
   }
-
   return true;
 }
 
@@ -323,6 +342,71 @@
   return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
 }
 
+bool CreatesNestedLoop(const HloInstruction& producer,
+                       const HloInstruction& consumer) {
+  // If producer does not have an instruction that codegens a loop then there is
+  // nothing to do.
+  auto producer_has_loop_codegen = [&](const HloInstruction& instr) {
+    if (producer.opcode() != HloOpcode::kFusion) {
+      return CodegensIntoLoop(producer);
+    }
+    for (const auto& instr : producer.fused_instructions()) {
+      if (CodegensIntoLoop(*instr)) {
+        return true;
+      }
+    }
+    return false;
+  };
+  if (!producer_has_loop_codegen(producer)) {
+    return false;
+  }
+
+  // If consumer is a non-fusion instruction then we have to check if it
+  // generates a loop.
+  if (consumer.opcode() != HloOpcode::kFusion) {
+    return CodegensIntoLoop(consumer);
+  }
+
+  // If consumer is a fusion then we have to check if the output of producer is
+  // used directly or indirectly as an input to an HLO instruction that
+  // generates a loop, i.e. there is a path in the graph from an operand
+  // corresponding to the producer to an HLO instruction generating a loop in
+  // the consumer.
+  for (const HloInstruction* operand : consumer.operands()) {
+    if (operand != &producer) {
+      continue;
+    }
+
+    const HloInstruction* root =
+        consumer.fused_instructions_computation()->parameter_instruction(
+            consumer.operand_index(operand));
+
+    std::stack<const HloInstruction*> dfs;
+    dfs.push(root);
+    absl::flat_hash_set<const HloInstruction*> visited;
+    while (!dfs.empty()) {
+      const HloInstruction* cur = dfs.top();
+      dfs.pop();
+
+      if (visited.contains(cur)) {
+        continue;
+      }
+      visited.insert(cur);
+
+      if (CodegensIntoLoop(*cur)) {
+        return true;
+      }
+      for (const auto& user : cur->users()) {
+        if (visited.contains(user)) {
+          continue;
+        }
+        dfs.push(user);
+      }
+    }
+  }
+  return false;
+}
+
 bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
   // We can fuse reduces and loop fusions. Elementwise instructions can be fused
   // with any other instruction.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
index 4956bf0..145975e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -67,6 +67,11 @@
 bool FusionWouldBeTooLarge(const HloInstruction& instr1,
                            const HloInstruction& instr2);
 
+// Check if fusing producer and consumer will generate a nested loop, e.g. both
+// producer and consumer are `reduce-window` HLO instructions.
+bool CreatesNestedLoop(const HloInstruction& producer,
+                       const HloInstruction& consumer);
+
 // Whether instruction shapes are compatible for multi-output fusion, i.e.
 // whether the emitters support lowering the resulting fusion.
 // This function works for both, sibling and producer-consumer multi-output
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
index 388723f..dc4e54c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -906,5 +906,136 @@
   EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer));
 }
 
+TEST_F(GpuFusibleTest, CreatesNestedLoop_NonfusionInstr) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY entry {
+      p_0 = f32[2,5] parameter(0)
+
+      constant_1 = f32[] constant(1)
+      reduce-window_1 = f32[3,5] reduce-window(p_0, constant_1),
+        window={size=2x1 pad=0_2x0_0}, to_apply=scalar_add
+
+      constant_2 = f32[] constant(2)
+      reduce-window_2 = f32[3,5] reduce-window(p_0, constant_2),
+        window={size=2x1 pad=0_2x0_0}, to_apply=scalar_add
+
+      ROOT root = (f32[32,32], f32[32,32,32]) tuple(reduce-window_1, reduce-window_2)
+    })"))
+                    .ValueOrDie();
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* producer = root->operand(0);
+  const HloInstruction* consumer = root->operand(1);
+  EXPECT_TRUE(CreatesNestedLoop(*producer, *consumer));
+}
+
+TEST_F(GpuFusibleTest, DoesNotCreateNestedLoop_NonfusionInstr) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY entry {
+      p_0 = f32[3,5] parameter(0)
+      constant = f32[] constant(1)
+      broadcast = f32[3, 5] broadcast(f32[] constant), dimensions={}
+      scaled_p_0 = f32[3,5] multiply(f32[3, 5] broadcast, f32[3,5]{1, 0} p_0)
+
+      p_1 = f32[2,5] parameter(1)
+      reduce-window = f32[3,5] reduce-window(p_1, constant),
+        window={size=2x1 pad=0_2x0_0}, to_apply=scalar_add
+
+      ROOT root = (f32[32,32], f32[32,32,32]) tuple(reduce-window, scaled_p_0)
+    })"))
+                    .ValueOrDie();
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* producer = root->operand(0);
+  const HloInstruction* consumer = root->operand(1);
+  EXPECT_FALSE(CreatesNestedLoop(*producer, *consumer));
+}
+
+TEST_F(GpuFusibleTest, DoesNotCreateNestedLoop_NonoverlappingReduceWindows) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY entry {
+      p_0 = f32[2,5] parameter(0)
+
+      constant_1 = f32[] constant(1)
+      reduce-window_1 = f32[3,5] reduce-window(p_0, constant_1),
+        window={size=2x1 pad=0_2x0_0}, to_apply=scalar_add
+
+      constant_2 = f32[] constant(2)
+      reduce-window_2 = f32[2,3] reduce-window(p_0, constant_2),
+        window={size=2x1 pad=0_2x0_0 stride=2x2}, to_apply=scalar_add
+
+      ROOT root = (f32[32,32], f32[32,32,32]) tuple(reduce-window_1, reduce-window_2)
+    })"))
+                    .ValueOrDie();
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* producer = root->operand(0);
+  const HloInstruction* consumer = root->operand(1);
+  EXPECT_FALSE(CreatesNestedLoop(*producer, *consumer));
+}
+
+TEST_F(GpuFusibleTest, CreatesNestedLoop_FusionInstr) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_producer {
+      operand = f32[2,2] parameter(0)
+      constant = f32[] constant(1)
+      ROOT reduce-window = f32[2,2] reduce-window(operand, constant),
+        window={size=2x2 pad=0_1x0_1}, to_apply=scalar_add
+    }
+
+    fused_consumer {
+      operand_0 = f32[2,2] parameter(0)
+
+      operand_1 = f32[2,2] parameter(1)
+      constant = f32[] constant(1)
+      reduce-window = f32[2,2] reduce-window(operand_1, constant),
+        window={size=2x2 pad=0_1x0_1}, to_apply=scalar_add
+
+      ROOT scaled_operand_1 = f32[2,2] multiply(f32[2, 2] operand_0, f32[2,2] reduce-window)
+    }
+
+    ENTRY entry {
+      p0 = f32[2,2] parameter(0)
+      producer = f32[2,2] fusion(p0), kind=kLoop, calls=fused_producer
+      consumer = f32[2,2] fusion(p0, producer), kind=kLoop, calls=fused_consumer
+      ROOT root = (f32[2,2], f32[2,2]) tuple(producer, consumer)
+    })"))
+                    .ValueOrDie();
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* producer = root->operand(0);
+  const HloInstruction* consumer = root->operand(1);
+  EXPECT_TRUE(CreatesNestedLoop(*producer, *consumer));
+}
+
+TEST_F(GpuFusibleTest, DoesNotCreateNestedLoop_FusionInstr) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_producer {
+      p_0 = f32[2,2] parameter(0)
+      constant = f32[] constant(1)
+      ROOT reduce-window = f32[2,2] reduce-window(p_0, constant),
+        window={size=2x2 pad=0_1x0_1}, to_apply=scalar_add
+    }
+
+    fused_consumer {
+      p_0 = f32[2,2] parameter(0)
+
+      p_1 = f32[2,2] parameter(1)
+      constant = f32[] constant(1)
+      reduce-window = f32[2,2] reduce-window(p_1, constant),
+        window={size=2x2 pad=0_1x0_1}, to_apply=scalar_add
+
+      ROOT scaled_p_1 = f32[2,2] multiply(f32[2, 2] p_0, f32[2,2] reduce-window)
+    }
+
+    ENTRY entry {
+      p_0 = f32[2,2] parameter(0)
+      producer = f32[2,2] fusion(p_0), kind=kLoop, calls=fused_producer
+      consumer = f32[2,2] fusion(producer, p_0), kind=kLoop, calls=fused_consumer
+      ROOT root = (f32[2,2], f32[2,2]) tuple(producer, consumer)
+    })"))
+                    .ValueOrDie();
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* producer = root->operand(0);
+  const HloInstruction* consumer = root->operand(1);
+  EXPECT_FALSE(CreatesNestedLoop(*producer, *consumer));
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c10f5b9..f299c49 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1044,7 +1044,8 @@
     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
         /*source_address=*/operand_buffer,
         /*destination_buffer=*/destination_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
+        /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
+        /*hlo_instruction=*/nullptr));
   }
 
   thunks.push_back(
@@ -2357,13 +2358,11 @@
 void IrEmitterUnnested::EmitTileElementForCopy(
     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
-    llvm::Value* x_loc, int64 /*x_iter_num*/) {
-  llvm_ir::TiledParameterInfo* tiled_param_info =
-      kernel_info->GetTiledParameterInfo();
+    llvm::Value* x_loc, int64 /*x_iter_num*/,
+    absl::Span<llvm::Value* const> param_shmem_buffers) {
   // TODO(jlebar): Add AA metadata to this load.
   llvm::Instruction* load_from_shmem_buffer =
-      Load(GEP(tiled_param_info->GetBufferForParameter(0),
-               {b_.getInt64(0), x_loc, y_loc}),
+      Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
            "output_element");
   llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
   Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
@@ -2387,17 +2386,15 @@
 void IrEmitterUnnested::EmitTileElementForFusion(
     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
-    llvm::Value* x_loc, int64 /*x_iter_num*/) {
-  llvm_ir::TiledParameterInfo* tiled_param_info =
-      kernel_info->GetTiledParameterInfo();
+    llvm::Value* x_loc, int64 /*x_iter_num*/,
+    absl::Span<llvm::Value* const> param_shmem_buffers) {
   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                      GetNestedComputer());
   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
-                               &elem_emitter);
-  tiled_param_info->set_y(y_loc);
-  tiled_param_info->set_x(x_loc);
-  fused_emitter.SetTiledParameterInfo(tiled_param_info);
+                               &elem_emitter, x_loc, y_loc,
+                               param_shmem_buffers);
+
   TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
   IrArray::Index untiled_index =
       kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
@@ -2530,32 +2527,6 @@
   bool is_row_reduction_;
 };
 
-namespace {
-// Returns a group of instructions that generate the output for the kernel
-// containing the given HLO instruction. The result may be an unnested kReduce
-// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple
-// for a multiple output fusion.
-absl::Span<HloInstruction* const> GetOutputInstructions(
-    HloInstruction* const* reduce_or_tuple_pointer) {
-  HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode();
-  CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple);
-  return opcode == HloOpcode::kTuple
-             ? (*reduce_or_tuple_pointer)->operands()
-             : absl::Span<HloInstruction* const>(reduce_or_tuple_pointer, 1);
-}
-
-const HloInstruction* GetFirstReduceInstruction(
-    absl::Span<HloInstruction* const> instructions) {
-  auto first_reduce_iter =
-      absl::c_find_if(instructions, [](const HloInstruction* inst) {
-        return IsReductionFromOrToContiguousDimensions(*inst);
-      });
-  CHECK_NE(first_reduce_iter, instructions.end());
-  return *first_reduce_iter;
-}
-
-};  // namespace
-
 void IrEmitterUnnested::EmitPrologueForOneReduction(
     HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx,
     KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter,
@@ -2611,14 +2582,13 @@
 }
 
 void IrEmitterUnnested::EmitPrologueForReduction(
-    HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
+    HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info,
+    absl::Span<HloInstruction* const> output_instructions) {
   VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString();
   // Find the unnested kReduce or the tuple that contains a list of kReduce.
   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
                                         ? unnested_hlo->fused_expression_root()
                                         : unnested_hlo;
-  absl::Span<HloInstruction* const> output_instructions =
-      GetOutputInstructions(&reduce_or_tuple);
   auto reduction_info = static_cast<ReductionCodegenInfo*>(kernel_info);
   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
                                           ir_emitter_context_->llvm_module(),
@@ -2691,7 +2661,8 @@
 }
 
 void IrEmitterUnnested::EmitEpilogueForReduction(
-    HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
+    HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info,
+    absl::Span<const HloInstruction* const> reduce_instructions) {
   auto reduction_info = static_cast<ReductionCodegenInfo*>(kernel_info);
   int num_reduces = reduction_info->GetNumberOfReduces();
   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
@@ -2720,16 +2691,6 @@
     llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_);
   }
 
-  HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
-                                        ? unnested_hlo->fused_expression_root()
-                                        : unnested_hlo;
-  std::vector<const HloInstruction*> reduce_instructions;
-  absl::c_for_each(GetOutputInstructions(&reduce_or_tuple),
-                   [&](const HloInstruction* instr) {
-                     if (IsReductionFromOrToContiguousDimensions(*instr)) {
-                       reduce_instructions.push_back(instr);
-                     }
-                   });
   int num_partial_results = reduction_info->GetNumberOfPartialResults();
 
   // Emit an atomic operation that accumulates the partial reduction to the
@@ -2794,18 +2755,14 @@
 }
 
 void IrEmitterUnnested::EmitTileElementForReduction(
-    HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
-    const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
-    llvm::Value* x_loc, int64 x_iter_num) {
+    HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
+    absl::Span<HloInstruction* const> output_instructions,
+    const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info,
+    int64 x_iter_num) {
   VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
                                         ? unnested_hlo->fused_expression_root()
                                         : unnested_hlo;
-  llvm_ir::TiledParameterInfo* tiled_param_info =
-      kernel_info->GetTiledParameterInfo();
-  tiled_param_info->set_y(y_loc);
-  tiled_param_info->set_x(x_loc);
-
   // Record the untransposed output linear address for the reduction.
   auto reduction_info = dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
   int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num;
@@ -2827,12 +2784,9 @@
                                      GetNestedComputer());
   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
                                &elem_emitter);
-  absl::Span<HloInstruction* const> output_instructions =
-      GetOutputInstructions(&reduce_or_tuple);
   // Construct the ElementGenerator for each reduction and extra output in the
   // the group of output instructions.
   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
-    fused_emitter.SetTiledParameterInfo(tiled_param_info);
     TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
 
     for (int i = 0, e = output_instructions.size(); i != e; ++i) {
@@ -2855,8 +2809,6 @@
     });
   }
 
-  Shape reduction_operand_shape =
-      GetFirstReduceInstruction(output_instructions)->operand(0)->shape();
   IrArray::Index input_index =
       reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
           index, reduction_operand_shape);
@@ -3050,9 +3002,6 @@
                 absl::Span<llvm::Value* const> output_tile_bounds) {
               std::vector<llvm::Value*> param_shmem_buffers(
                   unnested_hlo->operand_count(), nullptr);
-              llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers,
-                                                           y, x);
-              kernel_info->SetTiledParamInfo(&tiled_param_info);
               kernel_generator.GetTileElementGenerator()(
                   y, x, output_tile_origin, "output", output_tile_bounds[1],
                   output_tile_bounds[2], &ksl);
@@ -3118,7 +3067,6 @@
       Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
           param->shape().element_type(),
           Permute({0, 2, 1}, reduced_output_dims));
-      LOG(ERROR) << "Generated shape: " << reduced_shape.ToString(true);
       param_in_reduced_shape_arrays.push_back(
           param_arrays[id].CastToShape(reduced_shape, &b_));
     } else {
@@ -3131,11 +3079,11 @@
           llvm::Value* x_loc, int64 x_iter_num) {
         if (hlo->opcode() == HloOpcode::kCopy) {
           EmitTileElementForCopy(hlo, index, &kernel_info, y_loc, x_loc,
-                                 x_iter_num);
+                                 x_iter_num, param_shmem_buffers);
         } else {
           CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
           EmitTileElementForFusion(hlo, index, &kernel_info, y_loc, x_loc,
-                                   x_iter_num);
+                                   x_iter_num, param_shmem_buffers);
         }
       };
 
@@ -3143,9 +3091,6 @@
       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
-        llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
-        kernel_info.SetTiledParamInfo(&tiled_param_info);
-
         // If shared memory transpose is needed, wait for all threads to reach
         // this point, lest we copy a value from tile to output before the other
         // thread copies it from input to tile. This is `__syncthreads` in CUDA.
@@ -3379,6 +3324,10 @@
     }
   }
 
+  if (params_012.empty()) {
+    return false;
+  }
+
   VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
   std::unique_ptr<KernelThunk> kernel_thunk =
       BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
@@ -3585,11 +3534,26 @@
   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
                                         ? unnested_hlo->fused_expression_root()
                                         : unnested_hlo;
-  absl::Span<HloInstruction* const> output_instructions =
-      GetOutputInstructions(&reduce_or_tuple);
-  const HloInstruction* first_reduce =
-      GetFirstReduceInstruction(output_instructions);
+  // A group of instructions that generate the output for the kernel
+  // containing the given HLO instruction. The result may be an unnested kReduce
+  // HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple
+  // for a multiple output fusion.
+  auto output_instructions = ([&]() -> absl::Span<HloInstruction* const> {
+    if (reduce_or_tuple->opcode() == HloOpcode::kReduce) {
+      return absl::Span<HloInstruction* const>(&reduce_or_tuple, 1);
+    }
+    CHECK(reduce_or_tuple->opcode() == HloOpcode::kTuple);
+    return reduce_or_tuple->operands();
+  })();
 
+  std::vector<const HloInstruction*> reduce_instructions;
+  absl::c_for_each(output_instructions, [&](const HloInstruction* instr) {
+    if (IsReductionFromOrToContiguousDimensions(*instr)) {
+      reduce_instructions.push_back(instr);
+    }
+  });
+
+  const HloInstruction* first_reduce = reduce_instructions.at(0);
   if (output_instructions.size() > 1) {
     TF_RETURN_IF_ERROR(
         AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
@@ -3629,8 +3593,9 @@
   EmitElementFunction emit_reduction_tile =
       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
           llvm::Value* x_loc, int64 x_iter_num) {
-        EmitTileElementForReduction(unnested_hlo, index, &reduction_info, y_loc,
-                                    x_loc, x_iter_num);
+        EmitTileElementForReduction(unnested_hlo, input_shape,
+                                    output_instructions, index, &reduction_info,
+                                    x_iter_num);
       };
 
   KernelCodeGenerator kernel_generator(
@@ -3644,11 +3609,11 @@
       },
       /*block_prologue_generator=*/
       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
-        EmitPrologueForReduction(hlo, kernel_info);
+        EmitPrologueForReduction(hlo, kernel_info, output_instructions);
       },
       /*block_epilogue_generator*/
       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
-        EmitEpilogueForReduction(hlo, kernel_info);
+        EmitEpilogueForReduction(hlo, kernel_info, reduce_instructions);
       });
 
   LaunchDimensions launch_dimensions = EmitKernel(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 514de5a..6804918 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -67,29 +67,21 @@
    public:
     explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme)
         : mapping_scheme_(mapping_scheme),
-          tiled_param_info_(nullptr),
           lane_id_(nullptr),
           index_ty_(nullptr) {}
     virtual ~KernelCodegenInfo() {}
 
     void SetLaneId(llvm::Value* v) { lane_id_ = v; }
     void SetIndexType(llvm::Type* t) { index_ty_ = t; }
-    void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) {
-      tiled_param_info_ = tiled_param_info;
-    }
 
     llvm::Value* GetLaneId() const { return lane_id_; }
     llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const {
       return mapping_scheme_;
     }
-    llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const {
-      return tiled_param_info_;
-    }
     llvm::Type* GetIndexType() const { return index_ty_; }
 
    protected:
     llvm_ir::KernelMappingScheme* mapping_scheme_;
-    llvm_ir::TiledParameterInfo* tiled_param_info_;
     llvm::Value* lane_id_;
     llvm::Type* index_ty_;
   };
@@ -265,36 +257,40 @@
 
   // Emits code to process a tensor element in a tile for the given kCopy HLO
   // that performs a 0-2-1 transpose.
-  void EmitTileElementForCopy(HloInstruction* hlo,
-                              const llvm_ir::IrArray::Index& index,
-                              const KernelCodegenInfo* kernel_info,
-                              llvm::Value* y_loc, llvm::Value* x_loc,
-                              int64 x_iter_num);
+  void EmitTileElementForCopy(
+      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
+      const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
+      llvm::Value* x_loc, int64 x_iter_num,
+      absl::Span<llvm::Value* const> param_shmem_buffers);
+
   // Emits code to process a tensor element in a tile for the given kLoop fusion
   // HLO containing parameters that are 0-2-1 transpose of its outputs.
-  void EmitTileElementForFusion(HloInstruction* hlo,
-                                const llvm_ir::IrArray::Index& index,
-                                const KernelCodegenInfo* kernel_info,
-                                llvm::Value* y_loc, llvm::Value* x_loc,
-                                int64 x_iter_num);
+  void EmitTileElementForFusion(
+      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
+      const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
+      llvm::Value* x_loc, int64 x_iter_num,
+      absl::Span<llvm::Value* const> param_shmem_buffers);
+
   // Emits code to process a tensor element in a tile for the given input hlo
   // that is either a unnested kReduce or a kInput fusion.
-  void EmitTileElementForReduction(HloInstruction* unnested_hlo,
-                                   const llvm_ir::IrArray::Index& index,
-                                   const KernelCodegenInfo* kernel_info,
-                                   llvm::Value* y_loc, llvm::Value* x_loc,
-                                   int64 x_iter_num);
+  void EmitTileElementForReduction(
+      HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
+      absl::Span<HloInstruction* const> output_instructions,
+      const llvm_ir::IrArray::Index& index,
+      const KernelCodegenInfo* kernel_info, int64 x_iter_num);
   // Prepares for the code generation for a tile block of a reduction kernel.
-  void EmitPrologueForReduction(HloInstruction* unnested_hlo,
-                                KernelCodegenInfo* kernel_info);
+  void EmitPrologueForReduction(
+      HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info,
+      absl::Span<HloInstruction* const> output_instructions);
   void EmitPrologueForOneReduction(HloInstruction* unnested_hlo,
                                    HloInstruction* reduce_inst, int reduce_idx,
                                    KernelCodegenInfo* kernel_info,
                                    GpuElementalIrEmitter* elemental_emitter,
                                    ShapeIndex output_shape_index);
   // Wraps up the code generation for a tile block of a reduction kernel.
-  void EmitEpilogueForReduction(HloInstruction* unnested_hlo,
-                                KernelCodegenInfo* kernel_info);
+  void EmitEpilogueForReduction(
+      HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info,
+      absl::Span<const HloInstruction* const> reduce_instructions);
   // For each reducer, emits the shuffle-down loop to accumulate the partial
   // result to the global result.
   void EmitFullWarpShuffleDownLoopForAllReduces(
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index 2f73fd0..db26d36 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -16,12 +16,12 @@
     name = "llvm_gpu_backend",
     srcs = [
         "dump_ir_pass.cc",
-        "nvptx_backend_lib.cc",
+        "gpu_backend_lib.cc",
         "utils.cc",
     ],
     hdrs = [
         "dump_ir_pass.h",
-        "nvptx_backend_lib.h",
+        "gpu_backend_lib.h",
         "utils.h",
     ],
     deps = [
@@ -30,6 +30,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service/gpu:gpu_types",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
new file mode 100644
index 0000000..84616f3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -0,0 +1,752 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+
+#include <fstream>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/CodeGen/CommandFlags.inc"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Linker/Linker.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/AlwaysInliner.h"
+#include "llvm/Transforms/IPO/Internalize.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Scalar.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Inline threshold value to use in LLVM AMDGPU backend.
+const int kAMDGPUInlineThreshold = 0x100000;
+
+// Default inline threshold value to use in llvm.
+const int kDefaultInlineThreshold = 1100;
+
+// Gets the GPU name as it's known to LLVM for a given compute capability.  If
+// we see an unrecognized compute capability, we return "sm_35".
+static string GetSmName(std::pair<int, int> compute_capability) {
+  static auto* m = new std::map<std::pair<int, int>, int>({
+      {{3, 5}, 35},
+      {{3, 7}, 37},
+      {{5, 0}, 50},
+      {{5, 2}, 52},
+      {{5, 3}, 53},
+      {{6, 0}, 60},
+      {{6, 1}, 61},
+      {{6, 2}, 62},
+      {{7, 0}, 70},
+      {{7, 2}, 72},
+      {{7, 5}, 75},
+  });
+  int sm_version = 35;
+  auto it = m->find(compute_capability);
+  if (it != m->end()) {
+    sm_version = it->second;
+  } else {
+    LOG(WARNING) << "Unknown compute capability (" << compute_capability.first
+                 << ", " << compute_capability.second << ") ."
+                 << "Defaulting to telling LLVM that we're compiling for sm_"
+                 << sm_version;
+  }
+  return absl::StrCat("sm_", sm_version);
+}
+
+// Convenience function for producing a name of a temporary compilation product
+// from the input filename.
+string MakeNameForTempProduct(absl::string_view input_filename,
+                              absl::string_view extension) {
+  return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename),
+                                  extension);
+}
+
+// Initializes LLVM passes. Uses the PassRegistry mechanism.
+void InitializePasses(llvm::PassRegistry* pass_registry) {
+  llvm::initializeCore(*pass_registry);
+  llvm::initializeCodeGen(*pass_registry);
+  llvm::initializeScalarOpts(*pass_registry);
+  llvm::initializeObjCARCOpts(*pass_registry);
+  llvm::initializeVectorization(*pass_registry);
+  llvm::initializeIPO(*pass_registry);
+  llvm::initializeAnalysis(*pass_registry);
+  llvm::initializeTransformUtils(*pass_registry);
+  llvm::initializeInstCombine(*pass_registry);
+  llvm::initializeInstrumentation(*pass_registry);
+  llvm::initializeTarget(*pass_registry);
+  llvm::initializeCodeGenPreparePass(*pass_registry);
+}
+
+// Returns the TargetMachine, given a triple.
+std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
+    llvm::Triple triple, absl::string_view cpu_name,
+    const HloModuleConfig& hlo_module_config, absl::string_view feature_str) {
+  std::string error;
+  const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
+  if (target == nullptr) {
+    LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'"
+               << " -- " << error;
+    return nullptr;
+  }
+
+  TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
+
+  // Set the verbose assembly options.
+  target_options.MCOptions.AsmVerbose = false;
+
+  // The selection of codegen optimization level is copied from function
+  // GetCodeGenOptLevel in //third_party/llvm/llvm/tools/opt/opt.cpp.
+  CodeGenOpt::Level codegen_opt_level;
+  switch (hlo_module_config.debug_options().xla_backend_optimization_level()) {
+    case 1:
+      codegen_opt_level = CodeGenOpt::Less;
+      break;
+    case 2:
+      codegen_opt_level = CodeGenOpt::Default;
+      break;
+    case 3:
+      codegen_opt_level = CodeGenOpt::Aggressive;
+      break;
+    default:
+      codegen_opt_level = CodeGenOpt::None;
+  }
+  return absl::WrapUnique(target->createTargetMachine(
+      triple.str(), llvm_ir::AsStringRef(cpu_name),
+      llvm_ir::AsStringRef(feature_str), target_options, getRelocModel(),
+      getCodeModel(), codegen_opt_level));
+}
+
+// Adds the standard LLVM optimization passes, based on the speed optimization
+// level (opt_level) and size optimization level (size_level). Both module
+// and function-level passes are added, so two pass managers are passed in and
+// modified by this function.
+void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
+                           llvm::TargetMachine* target_machine,
+                           llvm::legacy::PassManagerBase* module_passes,
+                           llvm::legacy::FunctionPassManager* function_passes,
+                           int inline_threshold) {
+  PassManagerBuilder builder;
+  builder.OptLevel = opt_level;
+  builder.SizeLevel = size_level;
+
+  if (opt_level > 1) {
+    builder.Inliner = llvm::createFunctionInliningPass(inline_threshold);
+  } else {
+    // Only inline functions marked with "alwaysinline".
+    builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
+  }
+
+  builder.DisableUnrollLoops = opt_level == 0;
+  builder.LoopVectorize = opt_level > 0;
+  builder.SLPVectorize = opt_level > 1 && size_level < 2;
+
+  // NVPTX's early-as-possible passes include NVVM reflect.
+  target_machine->adjustPassManager(builder);
+
+  builder.populateFunctionPassManager(*function_passes);
+  builder.populateModulePassManager(*module_passes);
+}
+
+// Emits the given module to a bit code file.
+void EmitBitcodeToFile(const Module& module, absl::string_view filename) {
+  std::error_code error_code;
+  llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
+                               llvm::sys::fs::F_None);
+  if (error_code) {
+    LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
+  }
+
+  llvm::WriteBitcodeToFile(module, outfile.os());
+  outfile.keep();
+}
+
+// Emits the given module to PTX. target_machine is an initialized TargetMachine
+// for the NVPTX target.
+string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
+  std::string ptx;  // need a std::string instead of a ::string.
+  {
+    llvm::raw_string_ostream stream(ptx);
+    llvm::buffer_ostream pstream(stream);
+    // The extension is stripped by IrDumpingPassManager, so we need to
+    // get creative to add a suffix.
+    IrDumpingPassManager codegen_passes(
+        MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"),
+        "", false);
+    codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
+        llvm::Triple(module->getTargetTriple())));
+
+    target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr,
+                                        llvm::TargetMachine::CGFT_AssemblyFile);
+    codegen_passes.run(*module);
+  }
+
+  return ptx;
+}
+
+// LLVM has an extensive flags mechanism of its own, which is only accessible
+// through the command line. Internal libraries within LLVM register parsers for
+// flags, with no other way to configure them except pass these flags.
+// To do this programmatically, we invoke ParseCommandLineOptions manually with
+// a "fake argv".
+// Note: setting flags with this method is stateful, since flags are just
+// static globals within LLVM libraries.
+void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
+  std::vector<const char*> fake_argv = {""};
+  for (const string& cl_opt : cl_opts) {
+    fake_argv.push_back(cl_opt.c_str());
+  }
+  llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
+}
+
+// Returns whether the module could use any device bitcode library functions.
+// This function may have false positives -- the module might not use libdevice
+// on NVPTX or ROCm-Device-Libs on AMDGPU even if this function returns true.
+bool CouldNeedDeviceBitcode(const llvm::Module& module) {
+  for (const llvm::Function& function : module.functions()) {
+    // This is a conservative approximation -- not all such functions are in
+    // libdevice or ROCm-Device-Libs.
+    if (!function.isIntrinsic() && function.isDeclaration()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Links the module with a vector of path to bitcode modules.
+// The caller must guarantee that the paths exist.
+Status LinkWithBitcodeVector(llvm::Module* module,
+                             const std::vector<string>& bitcode_path_vector) {
+  llvm::Linker linker(*module);
+
+  for (auto& bitcode_path : bitcode_path_vector) {
+    if (!tensorflow::Env::Default()->FileExists(bitcode_path).ok()) {
+      LOG(ERROR) << "bitcode module is required by this HLO module but was "
+                    "not found at "
+                 << bitcode_path;
+      return xla::InternalError("bitcode module not found at %s", bitcode_path);
+    }
+
+    std::unique_ptr<llvm::Module> bitcode_module =
+        LoadIRModule(bitcode_path, &module->getContext());
+    if (linker.linkInModule(
+            std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded,
+            [](Module& M, const StringSet<>& GVS) {
+              internalizeModule(M, [&GVS](const GlobalValue& GV) {
+                return !GV.hasName() || (GVS.count(GV.getName()) == 0);
+              });
+            })) {
+      return xla::InternalError("Error linking bitcode module from %s",
+                                bitcode_path);
+    }
+  }
+  return Status::OK();
+}
+
+// Links libdevice into the given module if the module needs libdevice.
+Status LinkLibdeviceIfNecessary(llvm::Module* module,
+                                std::pair<int, int> compute_capability,
+                                const string& libdevice_dir_path) {
+  if (!CouldNeedDeviceBitcode(*module)) {
+    return Status::OK();
+  }
+
+  // CUDA 9+ uses a single libdevice file for all devices, and we don't support
+  // older CUDAs.
+  string libdevice_path =
+      tensorflow::io::JoinPath(libdevice_dir_path, "libdevice.10.bc");
+  if (!tensorflow::Env::Default()->FileExists(libdevice_path).ok()) {
+    LOG(WARNING)
+        << "libdevice is required by this HLO module but was not found at "
+        << libdevice_path;
+    return xla::InternalError("libdevice not found at %s", libdevice_path);
+  }
+
+  VLOG(1) << "Linking with libdevice from: " << libdevice_path;
+  return LinkWithBitcodeVector(module, {libdevice_path});
+}
+
+Status NVPTXTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
+                               const HloModuleConfig& hlo_module_config,
+                               const string& device_bitcode_dir_path) {
+  // Link the input module with libdevice, to pull in implementations of some
+  // builtins.
+  auto compute_capability = absl::get_if<std::pair<int, int>>(&gpu_version);
+  if (!compute_capability) {
+    return xla::InternalError("Incompatible compute capability was specified.");
+  }
+  TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, *compute_capability,
+                                              device_bitcode_dir_path));
+
+  // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass
+  // can access it.
+  module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz",
+                        hlo_module_config.debug_options().xla_gpu_ftz());
+
+  // If ftz is enabled, set it as an attribute on every function in the module.
+  if (hlo_module_config.debug_options().xla_gpu_ftz()) {
+    for (llvm::Function& fn : *module) {
+      fn.addFnAttr("nvptx-f32ftz", "true");
+    }
+  }
+
+  return Status::OK();
+}
+
+std::unique_ptr<llvm::TargetMachine> NVPTXGetTargetMachine(
+    llvm::Triple target_triple, std::pair<int, int> compute_capability,
+    const HloModuleConfig& hlo_module_config) {
+  // Figure out the exact name of the processor as known to the NVPTX backend
+  // from the gpu_architecture flag.
+  return GetTargetMachine(target_triple, GetSmName(compute_capability),
+                          hlo_module_config, "+ptx60");
+}
+
+using TargetModuleLinker = std::function<Status(
+    llvm::Module*, GpuVersion, const HloModuleConfig&, const string&)>;
+
+Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
+                             const HloModuleConfig& hlo_module_config,
+                             const string& device_bitcode_dir_path,
+                             TargetModuleLinker module_linker,
+                             llvm::Triple default_target_triple,
+                             llvm::TargetMachine* target_machine,
+                             int inline_threshold) {
+  TF_RETURN_IF_ERROR(module_linker(module, gpu_version, hlo_module_config,
+                                   device_bitcode_dir_path));
+
+  IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false);
+
+  // Add an appropriate TargetLibraryInfo pass for the module's triple.
+  llvm::TargetLibraryInfoWrapperPass* tliwp =
+      new llvm::TargetLibraryInfoWrapperPass(
+          llvm::Triple(module->getTargetTriple()));
+  module_passes.add(tliwp);
+
+  // Try to fetch the target triple from the module. If not present, set a
+  // default target triple.
+  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
+  if (target_triple.getArch() == llvm::Triple::UnknownArch) {
+    LOG(WARNING) << "target triple not found in the module";
+    target_triple = default_target_triple;
+  }
+
+  module_passes.add(llvm::createTargetTransformInfoWrapperPass(
+      target_machine->getTargetIRAnalysis()));
+
+  // The LLVM IR verifier performs sanity checking on the IR. This helps
+  // discover problems and report them in a meaningful manner, rather than let
+  // later passes report obscure assertions because of unfulfilled invariants.
+  module_passes.add(llvm::createVerifierPass());
+
+  // Create the function-level pass manager. It needs data layout information
+  // too.
+  llvm::legacy::FunctionPassManager function_passes(module);
+
+  int32 opt_level =
+      hlo_module_config.debug_options().xla_backend_optimization_level();
+
+  if (opt_level < 2) {
+    LOG(ERROR) << std::string(80, '*');
+    LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code "
+                  "generation but ";
+    LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level
+               << "!";
+    LOG(ERROR) << "(Supported configuration is "
+                  "--xla_backend_optimization_level >= 2.)";
+    LOG(ERROR) << std::string(80, '*');
+  }
+
+  // Add optimization passes, and set inliner threshold.
+  AddOptimizationPasses(opt_level,
+                        /*size_level=*/0, target_machine, &module_passes,
+                        &function_passes, inline_threshold);
+
+  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
+  // again after the standard optimization passes [http://b/13329423].
+  // TODO(jingyue): SROA may further expose more optimization opportunities such
+  // as more precise alias analysis and more function inlining (SROA may change
+  // the inlining cost of a function). For now, running SROA already emits good
+  // enough code for the evaluated benchmarks. We may want to run more
+  // optimizations later.
+  if (opt_level > 0) {
+    // LLVM's optimizer turns on SROA when the optimization level is greater
+    // than 0. We mimic this behavior here.
+    module_passes.add(llvm::createSROAPass());
+  }
+
+  // Verify that the module is well formed after optimizations ran.
+  module_passes.add(llvm::createVerifierPass());
+
+  // Done populating the pass managers. Now run them.
+
+  function_passes.doInitialization();
+  for (auto func = module->begin(); func != module->end(); ++func) {
+    function_passes.run(*func);
+  }
+  function_passes.doFinalization();
+  module_passes.run(*module);
+
+  return Status::OK();
+}
+
+// One-time module initializer.
+// Must be called only once -- DO NOT CALL DIRECTLY.
+void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) {
+  // Feed all customized flags here, so we can override them with llvm_cl_opts
+  // without redeploy the compiler for development purpose.
+
+  // This flag tunes a threshold in branch folding. The default threshold, which
+  // is one, is not suitable for CUDA programs where branches are more expensive
+  // than for CPU programs. Setting the threshold to 2 improves the latency of
+  // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the
+  // latency of other benchmarks so far.
+  //
+  // I also tried setting this threshold to other values:
+  // * 3-6 gives similar results as 2;
+  // * >6 start hurting the performance of at least dot product kernels.
+  //
+  // TODO(jingyue): The current threshold only considers the number of IR
+  // instructions which do not accurately reflect the true cost. We need a
+  // better cost model.
+  FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
+  // Increase limit when scanning memory dependencies.  This helps to reduce
+  // more redundant load instructions.
+  //
+  // The specific value is currently large enough for s3d in shoc benchmark,
+  // which contains a lot of load instructions and many arithmetic instructions
+  // between those loads.
+  FeedLLVMWithFlags({"-memdep-block-scan-limit=500"});
+
+  // Use div.full -- it matters for some float-division heavy benchmarks.
+  // Using div.approx produces incorrect result for float32(max)/float32(max).
+  FeedLLVMWithFlags({"-nvptx-prec-divf32=1"});
+
+  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);
+
+  // Initialize the NVPTX target; it's the only target we link with, so call its
+  // specific initialization functions instead of the catch-all InitializeAll*.
+  LLVMInitializeNVPTXTarget();
+  LLVMInitializeNVPTXTargetInfo();
+  LLVMInitializeNVPTXTargetMC();
+  LLVMInitializeNVPTXAsmPrinter();
+
+  // Initialize the LLVM optimization passes.
+  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
+  InitializePasses(registry);
+}
+
+}  // namespace
+
+namespace nvptx {
+
+StatusOr<string> CompileToPtx(llvm::Module* module, GpuVersion gpu_version,
+                              const HloModuleConfig& hlo_module_config,
+                              const string& libdevice_dir_path) {
+  static std::once_flag backend_init_flag;
+  std::call_once(backend_init_flag, NVPTXBackendInit, hlo_module_config);
+
+  string ptx;
+  std::unique_ptr<llvm::TargetMachine> target_machine;
+  {
+    tensorflow::profiler::TraceMe activity(
+        [&] { return absl::StrCat("Compiling IR:", module->getName().str()); },
+        tensorflow::profiler::TraceMeLevel::kInfo);
+    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
+
+    // If the module has no functions or globals, there's nothing to compile.
+    // Just return an empty string.
+    if (module->empty() && module->global_empty()) {
+      VLOG(2) << "Module '" << module->getName().str()
+              << "' is empty. Skipping compilation.";
+      return string();
+    }
+
+    auto compute_capability = absl::get_if<std::pair<int, int>>(&gpu_version);
+    if (!compute_capability) {
+      return xla::InternalError(
+          "Incompatible compute capability was specified.");
+    }
+
+    llvm::Triple default_target_triple("nvptx64-unknown-unknown");
+    // Construct LLVM TargetMachine for NVPTX.
+    std::unique_ptr<llvm::TargetMachine> target_machine = NVPTXGetTargetMachine(
+        default_target_triple, *compute_capability, hlo_module_config);
+
+    // Link with libdeivce, and optimize the LLVM module.
+    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
+        module, gpu_version, hlo_module_config, libdevice_dir_path,
+        NVPTXTargetModuleLinker, default_target_triple, target_machine.get(),
+        kDefaultInlineThreshold));
+
+    // Lower optimized LLVM module to PTX.
+    ptx = EmitModuleToPTX(module, target_machine.get());
+  }
+  return ptx;
+}
+
+}  // namespace nvptx
+
+namespace {
+
+// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version.
+static std::vector<string> GetROCDLPaths(int amdgpu_version,
+                                         const string& rocdl_dir_path) {
+  // AMDGPU version-neutral bitcodes.
+  static std::vector<string>* rocdl_filenames = new std::vector<string>(
+      {"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
+       "oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
+       "oclc_correctly_rounded_sqrt_on.amdgcn.bc",
+       "oclc_unsafe_math_off.amdgcn.bc"});
+
+  // Construct full path to ROCDL bitcode libraries.
+  std::vector<string> result;
+  for (auto& filename : *rocdl_filenames) {
+    result.push_back(tensorflow::io::JoinPath(rocdl_dir_path, filename));
+  }
+
+  // Add AMDGPU version-specific bitcodes.
+  result.push_back(tensorflow::io::JoinPath(
+      rocdl_dir_path,
+      absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
+  return result;
+}
+
+// Emits the given module to HSA Code Object. target_machine is an initialized
+// TargetMachine for the AMDGPU target.
+StatusOr<std::vector<uint8>> EmitModuleToHsaco(
+    Module* module, llvm::TargetMachine* target_machine) {
+  auto* env = tensorflow::Env::Default();
+  std::vector<std::string> tempdir_vector;
+  env->GetLocalTempDirectories(&tempdir_vector);
+  if (tempdir_vector.empty()) {
+    return xla::InternalError(
+        "Unable to locate a temporary directory for compile-time artifacts.");
+  }
+  std::string tempdir_name = tempdir_vector.front();
+  VLOG(1) << "Compile-time artifacts located at: " << tempdir_name;
+
+  // Prepare filenames for all stages of compilation:
+  // IR, binary ISA, and HSACO.
+  std::string ir_filename = absl::StrCat(module->getModuleIdentifier(), ".ll");
+  std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename);
+
+  std::string isabin_filename =
+      absl::StrCat(module->getModuleIdentifier(), ".o");
+  std::string isabin_path =
+      tensorflow::io::JoinPath(tempdir_name, isabin_filename);
+
+  std::string hsaco_filename =
+      absl::StrCat(module->getModuleIdentifier(), ".hsaco");
+  std::string hsaco_path =
+      tensorflow::io::JoinPath(tempdir_name, hsaco_filename);
+
+  std::error_code ec;
+
+  // Dump LLVM IR.
+  std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
+      new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::F_None));
+  module->print(*ir_fs, nullptr);
+  ir_fs->flush();
+
+  // Emit GCN ISA binary.
+  // The extension is stripped by IrDumpingPassManager, so we need to
+  // get creative to add a suffix.
+  std::string module_id = module->getModuleIdentifier();
+  IrDumpingPassManager codegen_passes(
+      ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
+                               "-amdgpu.dummy"),
+      "", false);
+  codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
+      llvm::Triple(module->getTargetTriple())));
+  llvm::SmallVector<char, 0> stream;
+  llvm::raw_svector_ostream pstream(stream);
+  std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
+      new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::F_Text));
+  module->setDataLayout(target_machine->createDataLayout());
+  target_machine->addPassesToEmitFile(codegen_passes, *isabin_fs, nullptr,
+                                      llvm::TargetMachine::CGFT_ObjectFile);
+  codegen_passes.run(*module);
+  isabin_fs->flush();
+
+  // Locate lld.
+  // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after
+  // ROCm-Device-Libs PR.
+  std::string lld_path = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin");
+  auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path});
+  if (!lld_program) {
+    return xla::InternalError("unable to find ld.lld in PATH: %s",
+                              lld_program.getError().message());
+  }
+  std::vector<llvm::StringRef> lld_args{
+      llvm_ir::AsStringRef("ld.lld"),
+      llvm_ir::AsStringRef("-flavor"),
+      llvm_ir::AsStringRef("gnu"),
+      llvm_ir::AsStringRef("-shared"),
+      llvm_ir::AsStringRef(isabin_path),
+      llvm_ir::AsStringRef("-o"),
+      llvm_ir::AsStringRef(hsaco_path),
+  };
+
+  std::string error_message;
+  int lld_result =
+      llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args),
+                                llvm::None, {}, 0, 0, &error_message);
+
+  if (lld_result) {
+    return xla::InternalError("ld.lld execute fail: %s", error_message);
+  }
+
+  // Read HSACO.
+  std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
+  std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
+
+  std::vector<uint8> hsaco(hsaco_file_size);
+  hsaco_file.seekg(0, std::ios::beg);
+  hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
+  return hsaco;
+}
+
+// Links ROCm-Device-Libs into the given module if the module needs it.
+Status LinkROCDLIfNecessary(llvm::Module* module, int amdgpu_version,
+                            const string& rocdl_dir_path) {
+  if (!CouldNeedDeviceBitcode(*module)) {
+    return Status::OK();
+  }
+
+  return LinkWithBitcodeVector(module,
+                               GetROCDLPaths(amdgpu_version, rocdl_dir_path));
+}
+
+Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
+                                const HloModuleConfig& hlo_module_config,
+                                const string& device_bitcode_dir_path) {
+  // Link the input module with ROCDL.
+  auto amdgpu_version = absl::get_if<int>(&gpu_version);
+  if (!amdgpu_version) {
+    return xla::InternalError(
+        "Incompatible AMD GCN ISA version was specified.");
+  }
+  TF_RETURN_IF_ERROR(
+      LinkROCDLIfNecessary(module, *amdgpu_version, device_bitcode_dir_path));
+
+  return Status::OK();
+}
+
+std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
+    llvm::Triple target_triple, int amdgpu_version,
+    const HloModuleConfig& hlo_module_config) {
+  return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
+                          hlo_module_config, "-code-object-v3");
+}
+
+void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
+  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);
+
+  // Initialize the AMDGPU target; it's the only target we link with, so call
+  // its specific initialization functions instead of the catch-all
+  // InitializeAll*.
+#if TENSORFLOW_USE_ROCM
+  LLVMInitializeAMDGPUTarget();
+  LLVMInitializeAMDGPUTargetInfo();
+  LLVMInitializeAMDGPUTargetMC();
+  LLVMInitializeAMDGPUAsmPrinter();
+#endif
+
+  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
+  InitializePasses(registry);
+}
+
+}  // namespace
+
+namespace amdgpu {
+StatusOr<std::vector<uint8>> CompileToHsaco(
+    llvm::Module* module, GpuVersion gpu_version,
+    const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path) {
+  static std::once_flag backend_init_flag;
+  std::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config);
+
+  std::vector<uint8> hsaco;
+  std::unique_ptr<llvm::TargetMachine> target_machine;
+  {
+    tensorflow::profiler::TraceMe activity(
+        [&] { return absl::StrCat("Compiling IR", module->getName().str()); },
+        tensorflow::profiler::TraceMeLevel::kInfo);
+    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
+
+    auto amdgpu_version = absl::get_if<int>(&gpu_version);
+    if (!amdgpu_version) {
+      return xla::InternalError(
+          "Incompatible AMD GCN ISA version was specified.");
+    }
+
+    llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
+    // Construct LLVM TargetMachine for AMDGPU.
+    std::unique_ptr<llvm::TargetMachine> target_machine =
+        AMDGPUGetTargetMachine(default_target_triple, *amdgpu_version,
+                               hlo_module_config);
+
+    // Link with ROCm-Device-Libs, and optimize the LLVM module.
+    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
+        module, gpu_version, hlo_module_config, rocdl_dir_path,
+        AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(),
+        kAMDGPUInlineThreshold));
+
+    // Lower optimized LLVM module to HSA code object.
+    TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
+  }
+  return hsaco;
+}
+
+}  // namespace amdgpu
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
new file mode 100644
index 0000000..526621d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LLVM-based compiler backend.
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+namespace nvptx {
+// Compiles the argument module and returns it. libdevice_dir_path is the parent
+// directory of the libdevice bitcode libraries. The contents of the module may
+// be changed.
+//
+// The Compile.* interfaces each create their own llvm::LLVMContext objects for
+// thread safety, but note that LLVM's multithreaded support is very
+// preliminary; multithreaded use is not recommended at this time.
+StatusOr<string> CompileToPtx(llvm::Module* module, GpuVersion gpu_version,
+                              const HloModuleConfig& hlo_module_config,
+                              const string& libdevice_dir_path);
+}  // namespace nvptx
+
+namespace amdgpu {
+// Compiles the argument module and returns it with LLVM AMDGPU backend.
+// rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries.
+// The contents of the module may be changed.
+StatusOr<std::vector<uint8>> CompileToHsaco(
+    llvm::Module* module, GpuVersion gpu_version,
+    const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path);
+}  // namespace amdgpu
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
deleted file mode 100644
index 9f52f09..0000000
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ /dev/null
@@ -1,470 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
-
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/Bitcode/BitcodeReader.h"
-#include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/CodeGen/CommandFlags.inc"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/Verifier.h"
-#include "llvm/Linker/Linker.h"
-#include "llvm/PassRegistry.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FileSystem.h"
-#include "llvm/Support/FormattedStream.h"
-#include "llvm/Support/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/Transforms/IPO.h"
-#include "llvm/Transforms/IPO/AlwaysInliner.h"
-#include "llvm/Transforms/IPO/Internalize.h"
-#include "llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "llvm/Transforms/Scalar.h"
-#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
-#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
-#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/tracing.h"
-#include "tensorflow/core/profiler/lib/traceme.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Default inline threshold value to use in llvm.
-const int kDefaultInlineThreshold = 1100;
-
-// Gets the GPU name as it's known to LLVM for a given compute capability.  If
-// we see an unrecognized compute capability, we return "sm_35".
-static string GetSmName(std::pair<int, int> compute_capability) {
-  static auto* m = new std::map<std::pair<int, int>, int>({
-      {{3, 5}, 35},
-      {{3, 7}, 37},
-      {{5, 0}, 50},
-      {{5, 2}, 52},
-      {{5, 3}, 53},
-      {{6, 0}, 60},
-      {{6, 1}, 61},
-      {{6, 2}, 62},
-      {{7, 0}, 70},
-      {{7, 2}, 72},
-      {{7, 5}, 75},
-  });
-  int sm_version = 35;
-  auto it = m->find(compute_capability);
-  if (it != m->end()) {
-    sm_version = it->second;
-  } else {
-    LOG(WARNING) << "Unknown compute capability (" << compute_capability.first
-                 << ", " << compute_capability.second << ") ."
-                 << "Defaulting to telling LLVM that we're compiling for sm_"
-                 << sm_version;
-  }
-  return absl::StrCat("sm_", sm_version);
-}
-
-// Convenience function for producing a name of a temporary compilation product
-// from the input filename.
-string MakeNameForTempProduct(absl::string_view input_filename,
-                              absl::string_view extension) {
-  return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename),
-                                  extension);
-}
-
-// Initializes LLVM passes. Uses the PassRegistry mechanism.
-void InitializePasses(llvm::PassRegistry* pass_registry) {
-  llvm::initializeCore(*pass_registry);
-  llvm::initializeCodeGen(*pass_registry);
-  llvm::initializeScalarOpts(*pass_registry);
-  llvm::initializeObjCARCOpts(*pass_registry);
-  llvm::initializeVectorization(*pass_registry);
-  llvm::initializeIPO(*pass_registry);
-  llvm::initializeAnalysis(*pass_registry);
-  llvm::initializeTransformUtils(*pass_registry);
-  llvm::initializeInstCombine(*pass_registry);
-  llvm::initializeInstrumentation(*pass_registry);
-  llvm::initializeTarget(*pass_registry);
-  llvm::initializeCodeGenPreparePass(*pass_registry);
-}
-
-// Returns the TargetMachine, given a triple.
-std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
-    llvm::Triple triple, absl::string_view cpu_name,
-    const HloModuleConfig& hlo_module_config) {
-  std::string error;
-  const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
-  if (target == nullptr) {
-    LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'"
-               << " -- " << error;
-    return nullptr;
-  }
-
-  TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
-
-  // Set the verbose assembly options.
-  target_options.MCOptions.AsmVerbose = false;
-
-  // The selection of codegen optimization level is copied from function
-  // GetCodeGenOptLevel in //third_party/llvm/llvm/tools/opt/opt.cpp.
-  CodeGenOpt::Level codegen_opt_level;
-  switch (hlo_module_config.debug_options().xla_backend_optimization_level()) {
-    case 1:
-      codegen_opt_level = CodeGenOpt::Less;
-      break;
-    case 2:
-      codegen_opt_level = CodeGenOpt::Default;
-      break;
-    case 3:
-      codegen_opt_level = CodeGenOpt::Aggressive;
-      break;
-    default:
-      codegen_opt_level = CodeGenOpt::None;
-  }
-  return absl::WrapUnique(target->createTargetMachine(
-      triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options,
-      getRelocModel(), getCodeModel(), codegen_opt_level));
-}
-
-// Adds the standard LLVM optimization passes, based on the speed optimization
-// level (opt_level) and size optimization level (size_level). Both module
-// and function-level passes are added, so two pass managers are passed in and
-// modified by this function.
-void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
-                           llvm::TargetMachine* target_machine,
-                           llvm::legacy::PassManagerBase* module_passes,
-                           llvm::legacy::FunctionPassManager* function_passes) {
-  PassManagerBuilder builder;
-  builder.OptLevel = opt_level;
-  builder.SizeLevel = size_level;
-
-  if (opt_level > 1) {
-    builder.Inliner = llvm::createFunctionInliningPass(kDefaultInlineThreshold);
-  } else {
-    // Only inline functions marked with "alwaysinline".
-    builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
-  }
-
-  builder.DisableUnrollLoops = opt_level == 0;
-  builder.LoopVectorize = opt_level > 0;
-  builder.SLPVectorize = opt_level > 1 && size_level < 2;
-
-  // NVPTX's early-as-possible passes include NVVM reflect.
-  target_machine->adjustPassManager(builder);
-
-  builder.populateFunctionPassManager(*function_passes);
-  builder.populateModulePassManager(*module_passes);
-}
-
-// Emits the given module to a bit code file.
-void EmitBitcodeToFile(const Module& module, absl::string_view filename) {
-  std::error_code error_code;
-  llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
-                               llvm::sys::fs::F_None);
-  if (error_code) {
-    LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
-  }
-
-  llvm::WriteBitcodeToFile(module, outfile.os());
-  outfile.keep();
-}
-
-// Emits the given module to PTX. target_machine is an initialized TargetMachine
-// for the NVPTX target.
-string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
-  std::string ptx;  // need a std::string instead of a ::string.
-  {
-    llvm::raw_string_ostream stream(ptx);
-    llvm::buffer_ostream pstream(stream);
-    // The extension is stripped by IrDumpingPassManager, so we need to
-    // get creative to add a suffix.
-    IrDumpingPassManager codegen_passes(
-        MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"),
-        "", false);
-    codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
-        llvm::Triple(module->getTargetTriple())));
-
-    target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr,
-                                        llvm::TargetMachine::CGFT_AssemblyFile);
-    codegen_passes.run(*module);
-  }
-
-  return ptx;
-}
-
-// LLVM has an extensive flags mechanism of its own, which is only accessible
-// through the command line. Internal libraries within LLVM register parsers for
-// flags, with no other way to configure them except pass these flags.
-// To do this programmatically, we invoke ParseCommandLineOptions manually with
-// a "fake argv".
-// Note: setting flags with this method is stateful, since flags are just
-// static globals within LLVM libraries.
-void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
-  std::vector<const char*> fake_argv = {""};
-  for (const string& cl_opt : cl_opts) {
-    fake_argv.push_back(cl_opt.c_str());
-  }
-  llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
-}
-
-// Returns whether the module could use any libdevice functions. This function
-// may have false positives -- the module might not use libdevice even if this
-// function returns true.
-bool CouldNeedLibdevice(const llvm::Module& module) {
-  for (const llvm::Function& function : module.functions()) {
-    // This is a conservative approximation -- not all such functions are in
-    // libdevice.
-    if (!function.isIntrinsic() && function.isDeclaration()) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Links libdevice into the given module if the module needs libdevice.
-Status LinkLibdeviceIfNecessary(llvm::Module* module,
-                                std::pair<int, int> compute_capability,
-                                const string& libdevice_dir_path) {
-  if (!CouldNeedLibdevice(*module)) {
-    return Status::OK();
-  }
-
-  // CUDA 9+ uses a single libdevice file for all devices, and we don't support
-  // older CUDAs.
-  string libdevice_path =
-      tensorflow::io::JoinPath(libdevice_dir_path, "libdevice.10.bc");
-  if (!tensorflow::Env::Default()->FileExists(libdevice_path).ok()) {
-    LOG(WARNING)
-        << "libdevice is required by this HLO module but was not found at "
-        << libdevice_path;
-    return xla::InternalError("libdevice not found at %s", libdevice_path);
-  }
-
-  VLOG(1) << "Linking with libdevice from: " << libdevice_path;
-  std::unique_ptr<llvm::Module> libdevice_module =
-      LoadIRModule(libdevice_path, &module->getContext());
-
-  llvm::Linker linker(*module);
-  if (linker.linkInModule(
-          std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded,
-          [](Module& M, const StringSet<>& GVS) {
-            internalizeModule(M, [&GVS](const GlobalValue& GV) {
-              return !GV.hasName() || (GVS.count(GV.getName()) == 0);
-            });
-          })) {
-    return xla::InternalError("Error linking libdevice from %s",
-                              libdevice_path);
-  }
-  return Status::OK();
-}
-
-StatusOr<string> CompileModuleToPtx(llvm::Module* module,
-                                    std::pair<int, int> compute_capability,
-                                    const HloModuleConfig& hlo_module_config,
-                                    const string& libdevice_dir_path) {
-  // If the module has no functions or globals, there's nothing to compile. Just
-  // return an empty string.
-  if (module->empty() && module->global_empty()) {
-    VLOG(2) << "Module '" << module->getName().str()
-            << "' is empty. Skipping compilation.";
-    return string();
-  }
-  // Link the input module with libdevice, to pull in implementations of some
-  // builtins.
-  TF_RETURN_IF_ERROR(
-      LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path));
-
-  // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass
-  // can access it.
-  module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz",
-                        hlo_module_config.debug_options().xla_gpu_ftz());
-
-  // If ftz is enabled, set it as an attribute on every function in the module.
-  if (hlo_module_config.debug_options().xla_gpu_ftz()) {
-    for (llvm::Function& fn : *module) {
-      fn.addFnAttr("nvptx-f32ftz", "true");
-    }
-  }
-
-  IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false);
-
-  // Add an appropriate TargetLibraryInfo pass for the module's triple.
-  llvm::TargetLibraryInfoWrapperPass* tliwp =
-      new llvm::TargetLibraryInfoWrapperPass(
-          llvm::Triple(module->getTargetTriple()));
-  module_passes.add(tliwp);
-
-  // Try to fetch the target triple from the module. If not present, set a
-  // default target triple.
-  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
-  if (target_triple.getArch() == llvm::Triple::UnknownArch) {
-    LOG(WARNING) << "target triple not found in the module";
-    target_triple = llvm::Triple("nvptx64-unknown-unknown");
-  }
-
-  // Figure out the exact name of the processor as known to the NVPTX backend
-  // from the gpu_architecture flag.
-  std::unique_ptr<llvm::TargetMachine> target_machine = GetTargetMachine(
-      target_triple, GetSmName(compute_capability), hlo_module_config);
-  module_passes.add(llvm::createTargetTransformInfoWrapperPass(
-      target_machine->getTargetIRAnalysis()));
-
-  // The LLVM IR verifier performs sanity checking on the IR. This helps
-  // discover problems and report them in a meaningful manner, rather than let
-  // later passes report obscure assertions because of unfulfilled invariants.
-  module_passes.add(llvm::createVerifierPass());
-
-  // Create the function-level pass manager. It needs data layout information
-  // too.
-  llvm::legacy::FunctionPassManager function_passes(module);
-
-  int32 opt_level =
-      hlo_module_config.debug_options().xla_backend_optimization_level();
-
-  if (opt_level < 2) {
-    LOG(ERROR) << std::string(80, '*');
-    LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code "
-                  "generation but ";
-    LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level
-               << "!";
-    LOG(ERROR) << "(Supported configuration is "
-                  "--xla_backend_optimization_level >= 2.)";
-    LOG(ERROR) << std::string(80, '*');
-  }
-
-  AddOptimizationPasses(opt_level,
-                        /*size_level=*/0, target_machine.get(), &module_passes,
-                        &function_passes);
-
-  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
-  // again after the standard optimization passes [http://b/13329423].
-  // TODO(jingyue): SROA may further expose more optimization opportunities such
-  // as more precise alias analysis and more function inlining (SROA may change
-  // the inlining cost of a function). For now, running SROA already emits good
-  // enough code for the evaluated benchmarks. We may want to run more
-  // optimizations later.
-  if (opt_level > 0) {
-    // LLVM's optimizer turns on SROA when the optimization level is greater
-    // than 0. We mimic this behavior here.
-    module_passes.add(llvm::createSROAPass());
-  }
-
-  // Verify that the module is well formed after optimizations ran.
-  module_passes.add(llvm::createVerifierPass());
-
-  // Done populating the pass managers. Now run them.
-
-  function_passes.doInitialization();
-  for (auto func = module->begin(); func != module->end(); ++func) {
-    function_passes.run(*func);
-  }
-  function_passes.doFinalization();
-  module_passes.run(*module);
-
-  // Finally, produce PTX.
-  return EmitModuleToPTX(module, target_machine.get());
-}
-
-// One-time module initializer.
-// Must be called only once -- DO NOT CALL DIRECTLY.
-void GPUBackendInit(const HloModuleConfig& hlo_module_config) {
-  // Feed all customized flags here, so we can override them with llvm_cl_opts
-  // without redeploy the compiler for development purpose.
-
-  // This flag tunes a threshold in branch folding. The default threshold, which
-  // is one, is not suitable for CUDA programs where branches are more expensive
-  // than for CPU programs. Setting the threshold to 2 improves the latency of
-  // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the
-  // latency of other benchmarks so far.
-  //
-  // I also tried setting this threshold to other values:
-  // * 3-6 gives similar results as 2;
-  // * >6 start hurting the performance of at least dot product kernels.
-  //
-  // TODO(jingyue): The current threshold only considers the number of IR
-  // instructions which do not accurately reflect the true cost. We need a
-  // better cost model.
-  FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
-  // Increase limit when scanning memory dependencies.  This helps to reduce
-  // more redundant load instructions.
-  //
-  // The specific value is currently large enough for s3d in shoc benchmark,
-  // which contains a lot of load instructions and many arithmetic instructions
-  // between those loads.
-  FeedLLVMWithFlags({"-memdep-block-scan-limit=500"});
-
-  // Use div.full -- it matters for some float-division heavy benchmarks.
-  // Using div.approx produces incorrect result for float32(max)/float32(max).
-  FeedLLVMWithFlags({"-nvptx-prec-divf32=1"});
-
-  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);
-
-  // Initialize the NVPTX target; it's the only target we link with, so call its
-  // specific initialization functions instead of the catch-all InitializeAll*.
-  LLVMInitializeNVPTXTarget();
-  LLVMInitializeNVPTXTargetInfo();
-  LLVMInitializeNVPTXTargetMC();
-  LLVMInitializeNVPTXAsmPrinter();
-
-  // Initialize the LLVM optimization passes.
-  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
-  InitializePasses(registry);
-}
-
-}  // namespace
-
-StatusOr<string> CompileToPtx(llvm::Module* module,
-                              std::pair<int, int> compute_capability,
-                              const HloModuleConfig& hlo_module_config,
-                              const string& libdevice_dir_path) {
-  static std::once_flag backend_init_flag;
-  std::call_once(backend_init_flag, GPUBackendInit, hlo_module_config);
-
-  string ptx;
-  {
-    tensorflow::profiler::TraceMe activity(
-        [&] { return absl::StrCat("Compiling IR:", module->getName().str()); },
-        tensorflow::profiler::TraceMeLevel::kInfo);
-    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
-    TF_ASSIGN_OR_RETURN(
-        ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config,
-                                libdevice_dir_path));
-  }
-  return ptx;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
deleted file mode 100644
index 9654175..0000000
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// LLVM-based compiler backend.
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_
-
-#include <string>
-#include <utility>
-
-#include "absl/strings/string_view.h"
-#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-
-namespace xla {
-namespace gpu {
-
-// Compiles the argument module and returns it. libdevice_dir_path is the parent
-// directory of the libdevice bitcode libraries. The contents of the module may
-// be changed.
-//
-// The Compile.* interfaces each create their own llvm::LLVMContext objects for
-// thread safety, but note that LLVM's multithreaded support is very
-// preliminary; multithreaded use is not recommended at this time.
-StatusOr<string> CompileToPtx(llvm::Module* module,
-                              std::pair<int, int> compute_capability,
-                              const HloModuleConfig& hlo_module_config,
-                              const string& libdevice_dir_path);
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_
diff --git a/tensorflow/compiler/xla/service/gpu/mlir/BUILD b/tensorflow/compiler/xla/service/gpu/mlir/BUILD
new file mode 100644
index 0000000..0c4a3a4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/mlir/BUILD
@@ -0,0 +1,14 @@
+# Description:
+#   Conversion of late HLO to XLA-HLO MLIR dialect.
+
+package(
+    default_visibility = [":friends"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//tensorflow/compiler/xla:friends",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 536b11a..d4e3d34 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -109,6 +109,7 @@
 }
 
 namespace {
+
 // We prefer multi-output fusions over other fusions over unfused ops, because
 // we want to preserve fusion opportunities if possible.
 HloInstruction* GetPreferredFusionCandidate(
@@ -125,6 +126,7 @@
   }
   return candidates.empty() ? nullptr : candidates[0];
 }
+
 }  // namespace
 
 bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
@@ -148,11 +150,9 @@
     for (HloInstruction* consumer : producer->users()) {
       VLOG(3) << "Looking at producer " << producer->name()
               << " and its consumer " << consumer->name();
-      // TODO(b/136623068): Use IsFusibleAsMultiOutputFusionRoot(...) to lift
-      // the restriction to input-fusible reductions.
-      if (!IsInputFusibleReduction(*consumer)) {
+      if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
         VLOG(3) << "Consumer " << consumer->name()
-                << " is not an input-fusible reduction.";
+                << " is not eligible as multi-output fusion root.";
         continue;
       }
       if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 20b3d64..083f578 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -18,6 +18,7 @@
 #include <stdlib.h>
 
 #include <atomic>
+#include <fstream>
 #include <functional>
 #include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
 #include <utility>
@@ -51,6 +52,7 @@
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
@@ -63,14 +65,13 @@
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
-#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
-#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@@ -165,6 +166,95 @@
   return ".";
 }
 
+}  // namespace
+
+Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  // Convert convolutions into CustomCalls to cudnn, then canonicalize them
+  // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
+  HloPassPipeline pipeline("conv_canonicalization");
+  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+                                            /*allow_mixed_precision=*/false);
+  pipeline.AddPass<CusolverRewriter>();
+  pipeline.AddPass<CudnnConvRewriter>();
+  pipeline.AddPass<CudnnFusedConvRewriter>();
+  pipeline.AddPass<CudnnConvPaddingLegalization>();
+  if (IsVoltaOrLater(*stream_exec)) {
+    pipeline.AddPass<CudnnConvPadForTensorCores>();
+    // CudnnConvPadForTensorCores leaves behind unnecessary
+    // tuple/get-tuple-element pairs that TupleSimplifier fixes.
+    pipeline.AddPass<TupleSimplifier>();
+  }
+  // CudnnConvRewriter, CudnnConvPaddingLegalization and
+  // CudnnConvPadForTensorCores may add instructions which can be simplified
+  // by constant folding.
+  pipeline.AddPass<HloConstantFolding>();
+  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+
+  return Status::OK();
+}
+
+Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  HloPassPipeline pipeline("post-layout_assignment");
+  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
+   * fixing the ticket. */
+  pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true,
+      /*allow_mixed_precision=*/false,
+      LayoutAssignment::InstructionCanChangeLayout);
+
+  // The LayoutAssignment pass may leave behind kCopy instructions which are
+  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+  AlgebraicSimplifierOptions options;
+  options.set_is_layout_sensitive(true);
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
+
+  // Rewrite GEMMs into custom calls.
+  pipeline.AddPass<GemmRewriter>();
+
+  // Choose the fastest algorithm for each conv.
+  //
+  // We pick the algorithm before fusion so we can generate better HLO. After
+  // CudnnConvRewriter, our convolutions are CustomCalls which return a
+  // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
+  // scratch:
+  //
+  //   customcall = (f32[...], f32[0])
+  //   return gte(customcall, 0)
+  //
+  // The algorithm picker then chooses the best algorithm, and potentially
+  // increases the scratch space.  It replaces customcall with new_tuple,
+  // giving us the following:
+  //
+  //   new_customcall = (f32[...], f32[N])
+  //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
+  //   return gte(new_tuple, 0)
+  //
+  // The new tuple and gte instructions then be simplified away, because
+  // nobody is expected to use the scratch value.
+  //
+  // However, if we were to run CudnnConvAlgorithmPicker after fusion
+  // the gte(customcall, 0) would probably already be into a fusion node.  We
+  // can't simplify across HloComputation boundaries, so in this case we
+  // wouldn't be able to simplify away the new_tuple bits.
+  pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator);
+
+  // Find the fastest algorithm for GEMMs.
+  pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
+
+  // Clean up new_tuple described above.
+  pipeline.AddPass<TupleSimplifier>();
+
+  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+
+  return Status::OK();
+}
+
+namespace {
 absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                         const HloInstruction* operand,
                                         const ShapeIndex& user_index) {
@@ -222,387 +312,71 @@
   });
 }
 
+// Try to load ptx from files defined in the FLAGS. If successful, return true.
+bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) {
+  // If the xla_gpu_ptx_file options is set, be explicit when a file is used
+  // and warn when a file is not used to ease catching typo in filename.
+  std::string prefix = xla::FilenameFor(*module, *ptx);
+  std::string matched_filename;
+  for (const string filename :
+       module->config().debug_options().xla_gpu_ptx_file()) {
+    // To ease comparing many PTX versions, accept different suffixes then
+    // the original filename.
+    if (absl::StartsWith(filename, prefix)) {
+      matched_filename = filename;
+      VLOG(0) << "RunBackend() - Will load PTX from file: " << filename;
+      break;
+    }
+  }
+  if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 &&
+      matched_filename.empty()) {
+    VLOG(0) << "RunBackend() - For module with prefix '" << prefix
+            << "', we did not found a PTX file to load.";
+  }
+
+  if (!matched_filename.empty()) {
+    std::ifstream ifs(matched_filename, std::ifstream::in);
+    *ptx = std::string(std::istreambuf_iterator<char>(ifs),
+                       std::istreambuf_iterator<char>());
+    CHECK(!ptx->empty()) << "Empty or non existing PTX file: "
+                         << matched_filename;
+    return true;
+  }
+  return false;
+}
+
 }  // namespace
 
-// Runs optimization passes on the given HLO module.
-Status impl::OptimizeHloModule(HloModule* hlo_module,
-                               se::StreamExecutor* stream_exec,
-                               se::DeviceMemoryAllocator* device_allocator) {
-  {
-    HloPassPipeline pipeline("optimization");
-    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
-                                              /*allow_mixed_precision=*/false);
-
-    // Expand random number generation.
-    pipeline.AddPass<RngExpander>();
-
-    // Remove zero-sized HLO from the input so that other passes don't have to
-    // handle it.
-    pipeline.AddPass<ZeroSizedHloElimination>();
-
-    pipeline.AddPass<GpuScatterExpander>();
-
-    pipeline.AddPass<DynamicIndexSplitter>();
-    pipeline.AddPass<GpuHloSupportChecker>();
-    ReducePrecisionInsertion::AddPasses(
-        &pipeline, hlo_module->config().debug_options(),
-        ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
-
-    // TODO(b/64094172): make Call work on GPU instead of inlining.
-    pipeline.AddPass<CallInliner>();
-    auto cost_model = [](HloInstruction* conv) {
-      // We need a cost model for GPUs. Currently, do nothing.
-      return false;
-    };
-    pipeline.AddPass<DotDecomposer>();
-    pipeline.AddPass<ConvolutionGroupConverter>(
-        cost_model,
-        /*convert_batch_groups_only=*/true);
-    // Expand the sort op to support stable sorting if required.
-    pipeline.AddPass<StableSortExpander>();
-    // Convert BF16 operations to F32 operations so that the GPU backend can
-    // support BF16 operations without directly implementing a BF16 lowering for
-    // most ops.
-    pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
-
-    {
-      auto& pass =
-          pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
-      pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
-                                            /*allow_mixed_precision=*/false);
-
-      // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
-      // where possible.  Not every batchnorm op can be implemented as a call to
-      // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
-      if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
-        pass.AddPass<CudnnBatchNormRewriter>();
-      }
-      pass.AddPass<BatchNormExpander>(
-          /*rewrite_training_op=*/true,
-          /*rewrite_inference_op=*/true,
-          /*rewrite_grad_op=*/true);
-
-      pipeline.AddPass<HloGetDimensionSizeRewriter>();
-
-      // BatchNormExpander can create zero-sized ops, so zero-sized HLO
-      // elimination has to come after that pass.
-      pipeline.AddPass<ZeroSizedHloElimination>();
-
-      AlgebraicSimplifierOptions options;
-      pass.AddPass<AlgebraicSimplifier>(options);
-      pass.AddPass<SortSimplifier>();
-      pass.AddPass<TupleSimplifier>();
-      pass.AddPass<WhileLoopConstantSinking>();
-      pass.AddPass<WhileLoopSimplifier>();
-
-      // TODO(b/134075051): Re-enable after b/134075051 is fixed.
-      // pass.AddPass<SliceSinker>();
-
-      pass.AddPass<HloDCE>();
-      pass.AddPass<ReshapeMover>();
-      pass.AddPass<HloConstantFolding>();
-      pass.AddPass<ConditionalSimplifier>();
-    }
-
-    pipeline.AddPass<TransposeFolding>(
-        [](const HloInstruction& dot,
-           const TransposeFolding::OperandIndices& candidate_operands) {
-          return IsMatrixMultiplication(dot)
-                     ? candidate_operands
-                     : TransposeFolding::OperandIndices{};
-        },
-        TransposeFolding::NeverFoldTranspose);
-    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
-    pipeline.AddPass<HloDCE>();
-
-    // Run WhileLoopTripCountAnnotator at the end of the simplification
-    // pipeline, before layout assignment and fusion.  This pass does some
-    // pattern-matching on while bodies/conditions, and this is where the HLO is
-    // "nicest".
-    //
-    // It's important that we don't make semantic changes (e.g. unrolling) to
-    // any `while` loops after this point, because otherwise the trip-count
-    // annotations added by this pass may not be correct after the
-    // modifications.
-    pipeline.AddPass<WhileLoopTripCountAnnotator>();
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
-  {
-    // Convert convolutions into CustomCalls to cudnn, then canonicalize them
-    // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
-    HloPassPipeline pipeline("conv_canonicalization");
-    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
-                                              /*allow_mixed_precision=*/false);
-    pipeline.AddPass<CusolverRewriter>();
-    pipeline.AddPass<CudnnConvRewriter>();
-    pipeline.AddPass<CudnnFusedConvRewriter>();
-    pipeline.AddPass<CudnnConvPaddingLegalization>();
-    if (IsVoltaOrLater(*stream_exec)) {
-      pipeline.AddPass<CudnnConvPadForTensorCores>();
-      // CudnnConvPadForTensorCores leaves behind unnecessary
-      // tuple/get-tuple-element pairs that TupleSimplifier fixes.
-      pipeline.AddPass<TupleSimplifier>();
-    }
-    // CudnnConvRewriter, CudnnConvPaddingLegalization and
-    // CudnnConvPadForTensorCores may add instructions which can be simplified
-    // by constant folding.
-    pipeline.AddPass<HloConstantFolding>();
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
-  {
-    // Run layout assignment in a separate pipeline from
-    // "post-layout-assignment" because we want everything after layout
-    // assignment to have a layout-sensitive invariant-checker, but
-    // HloPassPipeline also runs its invariant checker before any passes are
-    // run, meaning, the pipeline that contains layout assignment cannot contain
-    // a layout-sensitive verifier!
-    HloPassPipeline pipeline("layout assignment");
-    pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_entry_computation_layout(),
-        LayoutAssignment::InstructionCanChangeLayout, stream_exec);
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
-  {
-    HloPassPipeline pipeline("post-layout_assignment");
-    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-     * fixing the ticket. */
-    pipeline.AddInvariantChecker<HloVerifier>(
-        /*layout_sensitive=*/true,
-        /*allow_mixed_precision=*/false,
-        LayoutAssignment::InstructionCanChangeLayout);
-
-    // The LayoutAssignment pass may leave behind kCopy instructions which are
-    // duplicate or NOPs, so remove them with algebraic simplification and CSE.
-    AlgebraicSimplifierOptions options;
-    options.set_is_layout_sensitive(true);
-    pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
-
-    // Rewrite GEMMs into custom calls.
-    pipeline.AddPass<GemmRewriter>();
-
-    // Choose the fastest algorithm for each conv.
-    //
-    // We pick the algorithm before fusion so we can generate better HLO. After
-    // CudnnConvRewriter, our convolutions are CustomCalls which return a
-    // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
-    // scratch:
-    //
-    //   customcall = (f32[...], f32[0])
-    //   return gte(customcall, 0)
-    //
-    // The algorithm picker then chooses the best algorithm, and potentially
-    // increases the scratch space.  It replaces customcall with new_tuple,
-    // giving us the following:
-    //
-    //   new_customcall = (f32[...], f32[N])
-    //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
-    //   return gte(new_tuple, 0)
-    //
-    // The new tuple and gte instructions then be simplified away, because
-    // nobody is expected to use the scratch value.
-    //
-    // However, if we were to run CudnnConvAlgorithmPicker after fusion
-    // the gte(customcall, 0) would probably already be into a fusion node.  We
-    // can't simplify across HloComputation boundaries, so in this case we
-    // wouldn't be able to simplify away the new_tuple bits.
-    pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator);
-
-    // Find the fastest algorithm for GEMMs.
-    pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
-
-    // Clean up new_tuple described above.
-    pipeline.AddPass<TupleSimplifier>();
-
-    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
-  {
-    HloPassFix<HloPassPipeline> fusion("fusion");
-    // We try to split variadic ops with many parameters into several such ops
-    // to avoid exceeding the parameter space.
-    fusion.AddPass<VariadicOpSplitter>();
-    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-     * fixing the ticket. */
-    fusion.AddInvariantChecker<HloVerifier>(
-        /*layout_sensitive=*/true,
-        /*allow_mixed_precision=*/false,
-        LayoutAssignment::InstructionCanChangeLayout);
-    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
-    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
-    fusion.AddPass<FusionMerger>();
-    fusion.AddPass<GpuMultiOutputFusion>();
-    fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
-                           /*only_fusion_computations=*/true);
-    fusion.AddPass<HloDCE>();
-    TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
-
-    HloPassPipeline reduce_pipeline("reduce-precision");
-    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-     * fixing the ticket. */
-    reduce_pipeline.AddInvariantChecker<HloVerifier>(
-        /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
-        LayoutAssignment::InstructionCanChangeLayout);
-    ReducePrecisionInsertion::AddPasses(
-        &reduce_pipeline, hlo_module->config().debug_options(),
-        ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
-    StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
-    TF_RETURN_IF_ERROR(reduce_result.status());
-
-    if (reduce_result.ValueOrDie()) {
-      // Do another fusion pass, with the expectation that we may be able to
-      // fuse the new ReducePrecision operations.
-      TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
-    }
-  }
-
-  return Status::OK();
-}
-
-// Modifies the given HLO module so that it will be accepted by IrEmitter.
-// Unlike optimization passes, the passes are necessary for correctness.
-Status impl::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
-  // In some cases, we have to place the result of an instruction in a temporary
-  // buffer. For instance, the buffer that holds an external parameter is
-  // assumed immutable at this point, and should not be reused for output
-  // (b/27180329). Therefore, in that case, we set the output to be a copy of
-  // the parameter.
-  HloPassPipeline pipeline("GPU-ir-emit-prepare");
-  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-   * fixing the ticket. */
-  pipeline.AddInvariantChecker<HloVerifier>(
-      /*layout_sensitive=*/true,
-      /*allow_mixed_precision=*/false,
-      LayoutAssignment::InstructionCanChangeLayout);
-
-  // Copy insertion should be performed immediately before IR emission to avoid
-  // inserting unnecessary copies (later pass adds an instruction which
-  // materializes the value) or missing a necessary copy (later pass removes an
-  // instruction which materializes a value). DCE must be run immediately before
-  // (and sometime after) copy insertion, to avoid dead code from interfering
-  // with the rewrites.
-  pipeline.AddPass<HloDCE>();
-  pipeline.AddPass<FlattenCallGraph>();
-  // The following pass LOGs memory waste. Add it when VLOGing is enabled only.
-  if (VLOG_IS_ON(2)) {
-    pipeline.AddPass<MemWastedOnPassthroughParams>();
-  }
-  pipeline.AddPass<GpuCopyInsertion>(&CanShareBufferHint);
-  pipeline.AddPass<GpuSanitizeConstantNames>();
-  return pipeline.Run(hlo_module).status();
-}
-
 NVPTXCompiler::NVPTXCompiler()
-    : pointer_size_(llvm::DataLayout(nvptx::kDataLayout)
-                        .getPointerSize(0 /* default address space */)) {}
+    : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::kTargetTriple,
+                  nvptx::kDataLayout) {}
 
-StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  // We dump the post-optimization HLO in RunBackend so no need to dump it here.
-  XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses");
-  tensorflow::profiler::TraceMe activity(
-      [&] { return absl::StrCat("HLO Transforms:", module->name()); },
-      tensorflow::profiler::TraceMeLevel::kInfo);
-  TF_RETURN_IF_ERROR(
-      impl::OptimizeHloModule(module.get(), stream_exec, device_allocator));
-
-  TF_RETURN_IF_ERROR(impl::PrepareHloModuleForIrEmitting(module.get()));
-
-  return std::move(module);
+HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() {
+  return &CanShareBufferHint;
 }
 
-StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend");
-
-  TF_RET_CHECK(stream_exec != nullptr);
-
-  llvm::LLVMContext llvm_context;
-  std::string buffer;
-  llvm::raw_string_ostream error(buffer);
-  llvm::DiagnosticPrinterRawOStream printer(error);
-  auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
-                              void* Context) {
-    auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
-    diag_info.print(*printer);
-  };
-  llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
-
-  llvm::Module llvm_module(module->name().c_str(), llvm_context);
-  // Set the target triple and the data layout.
-  llvm_module.setTargetTriple(nvptx::kTargetTriple);
-  llvm_module.setDataLayout(nvptx::kDataLayout);
-
-  // Determine the HLO schedule, which is an ordering of HLO instructions.  This
-  // is used by buffer assignment to enable buffer reuse, and the same ordering
-  // must also be used to determine the thunk launch schedule.
-  std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<GpuHloSchedule> hlo_schedule,
-      GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
-
-  // Run buffer analysis on the HLO graph. This analysis figures out which
-  // temporary buffers are required to run the computation.
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<BufferAssignment> buffer_assignment,
-      BufferAssigner::Run(
-          module.get(), hlo_schedule->ConsumeHloOrdering(),
-          BufferSizeBytesFunction(),
-          /*color_alignment=*/
-          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
-          /*allocate_buffers_for_constants=*/true,
-          /*colorer=*/BufferAssigner::DefaultColorer(),
-          /*must_not_live_out=*/{}, &CanShareBufferHint));
-  DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
-
-  IrEmitterContext ir_emitter_context(
-      module.get(), buffer_assignment.get(), stream_exec->platform(),
-      &stream_exec->GetDeviceDescription(), &llvm_module);
-
-  HloComputation* entry_computation = module->entry_computation();
-  IrEmitterUnnested ir_emitter(module->config(), entry_computation,
-                               &ir_emitter_context);
-
-  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
-
-  {
-    XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission");
-    TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
+GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
+  int cc_major, cc_minor;
+  if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                                   &cc_minor)) {
+    LOG(WARNING)
+        << "Couldn't get compute capability for device; assuming sm_20.";
+    cc_major = 2;
+    cc_minor = 0;
   }
 
-  if (user_pre_optimization_hook_) {
-    user_pre_optimization_hook_(llvm_module);
-  }
-  string ir_module_string_before_opt;
-  const bool embed_ir_in_executable =
-      module->config().debug_options().xla_embed_ir_in_executable();
-  if (embed_ir_in_executable) {
-    ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module);
-  }
+  return std::make_pair(cc_major, cc_minor);
+}
 
-  llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false);
+StatusOr<std::pair<std::string, std::vector<uint8>>>
+NVPTXCompiler::CompileTargetBinary(const HloModule* module,
+                                   llvm::Module* llvm_module,
+                                   GpuVersion gpu_version,
+                                   se::StreamExecutor* stream_exec) {
+  std::pair<int, int> compute_capability =
+      absl::get<std::pair<int, int>>(gpu_version);
 
-  {
-    XLA_SCOPED_LOGGING_TIMER(
-        "NVPTXCompiler::RunBackend - Running LLVM verifier");
-
-    std::string err;
-    llvm::raw_string_ostream err_stream(err);
-
-    // verifyModule() returns true if the module is broken.
-    TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
-        << "Invalid LLVM IR before optimizations:\n"
-        << err_stream.str()
-        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
-           "Rerun with --xla_dump_to to get the IR. ";
-  }
-
-  string libdevice_dir;
+  std::string libdevice_dir;
   {
     tensorflow::mutex_lock lock(mutex_);
 
@@ -616,70 +390,31 @@
   }
   VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
 
-  int cc_major, cc_minor;
-  if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major,
-                                                                   &cc_minor)) {
-    LOG(WARNING)
-        << "Couldn't get compute capability for device; assuming sm_20.";
-    cc_major = 2;
-    cc_minor = 0;
-  }
-
   string ptx;
-  {
-    XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - CompileToPtx");
-    TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor},
-                                          module->config(), libdevice_dir));
+  if (!MaybeLoadPtxFromFile(module, &ptx)) {
+    XLA_SCOPED_LOGGING_TIMER(
+        "NVPTXCompiler::CompileTargetBinary - CompileToPtx");
+    TF_ASSIGN_OR_RETURN(
+        ptx, nvptx::CompileToPtx(llvm_module, gpu_version, module->config(),
+                                 libdevice_dir));
   }
 
-  llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/true);
+  llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/true);
 
   if (user_post_optimization_hook_) {
-    user_post_optimization_hook_(llvm_module);
+    user_post_optimization_hook_(*llvm_module);
   }
   // Write PTX to IR dump directory, if IR dumping was requested.
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "ptx", ptx);
   }
 
-  const std::vector<uint8> cubin = CompilePtxOrGetCachedResult(
-      stream_exec, ptx, cc_major, cc_minor, module->config());
+  std::vector<uint8> cubin =
+      CompilePtxOrGetCachedResult(stream_exec, ptx, compute_capability.first,
+                                  compute_capability.second, module->config());
 
-  auto thunk_schedule = absl::make_unique<ThunkSchedule>(
-      ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
-      hlo_schedule->ThunkLaunchOrder());
-  if (DumpingEnabledForHloModule(*module)) {
-    DumpToFileInDirOrStdout(*module, "thunk_schedule",
-                            thunk_schedule->ToString());
-  }
-
-  std::unique_ptr<HloProfileIndexMap> profile_index_map;
-  std::unique_ptr<HloProfilePrinterData> profile_printer;
-
-  if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
-    HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
-    cost_analysis.set_bytes_per_second(
-        stream_exec->GetDeviceDescription().memory_bandwidth());
-    TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
-    VLOG(1) << "HLO memory read+written: "
-            << tensorflow::strings::HumanReadableNumBytes(
-                   cost_analysis.bytes_accessed());
-    if (module->config().hlo_profiling_enabled()) {
-      profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
-      profile_printer = CreateHloProfilePrinterData(
-          *profile_index_map, cost_analysis, entry_computation->name());
-    }
-  }
-
-  auto* gpu_executable = new GpuExecutable(
-      ptx, cubin, std::make_pair(cc_major, cc_minor), std::move(thunk_schedule),
-      std::move(module), std::move(buffer_assignment),
-      std::move(profile_printer), std::move(profile_index_map));
-  if (embed_ir_in_executable) {
-    DCHECK_NE("", ir_module_string_before_opt);
-    gpu_executable->set_ir_module_string(ir_module_string_before_opt);
-  }
-  return std::unique_ptr<Executable>(gpu_executable);
+  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
+                                                    std::move(cubin));
 }
 
 std::vector<uint8> NVPTXCompiler::CompilePtxOrGetCachedResult(
@@ -761,16 +496,5 @@
   return cache_value->cubin_data;
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-NVPTXCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
-                                  const AotCompilationOptions& options) {
-  return Unimplemented(
-      "not yet implemented: NVPTXCompiler::CompileAheadOfTime");
-}
-
-se::Platform::Id NVPTXCompiler::PlatformId() const {
-  return se::cuda::kCudaPlatformId;
-}
-
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 980c00a..60a07f1 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -23,71 +23,37 @@
 #include "absl/container/node_hash_map.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
 #include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/stream_executor/stream_executor_pimpl.h"
 
 namespace xla {
 namespace gpu {
 
-// Temporarily expose the optimization pipeline for the GPU backend for reuse
-// in the MLIR GPU backend.
-// TODO(b/137624192): Remove once MLIR backend uses tailored optimizations.
-namespace impl {
-
-Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
-                         se::DeviceMemoryAllocator* device_allocator);
-Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
-
-}  // namespace impl
-
-// The GPU compiler generates efficient GPU executables.
-class NVPTXCompiler : public LLVMCompiler {
+// NVPTXCompiler generates efficient GPU executables for NVPTX target.
+class NVPTXCompiler : public GpuCompiler {
  public:
   NVPTXCompiler();
   ~NVPTXCompiler() override {}
 
-  // Bring in
-  // StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-  //     std::vector<std::unique_ptr<HloModule>> modules,
-  //     std::vector<std::vector<se::StreamExecutor*>>
-  //        stream_execs)
-  using LLVMCompiler::Compile;
-
-  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+  Status OptimizeHloConvolutionCanonicalization(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
       se::DeviceMemoryAllocator* device_allocator) override;
 
-  StatusOr<std::unique_ptr<Executable>> RunBackend(
-      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+  Status OptimizeHloPostLayoutAssignment(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
       se::DeviceMemoryAllocator* device_allocator) override;
 
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
-                     AotCompilationOptions const& options) override;
+  HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() override;
 
-  se::Platform::Id PlatformId() const override;
+  GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
-    // Capture just the pointer size, not the entire NVPTXCompiler object.
-    int64 pointer_size = pointer_size_;
-    return [pointer_size](const Shape& shape) {
-      return ShapeUtil::ByteSizeOf(shape, pointer_size);
-    };
-  }
+  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 
  private:
-  // The size in bytes of a pointer. Used by ShapeSizeBytesFunction.
-  const int64 pointer_size_;
-
   tensorflow::mutex mutex_;
 
   // When compiling an HLO module, we need to find a path to the nvvm libdevice
diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc
deleted file mode 100644
index 5793051..0000000
--- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
-
-namespace xla {
-namespace gpu {
-
-StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
-    se::Stream* stream, int64 byte_size) {
-  CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
-    return se::port::Status(
-        se::port::error::RESOURCE_EXHAUSTED,
-        absl::StrFormat(
-            "Allocating %d bytes exceeds the memory limit of %d bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
-  }
-
-  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
-                      memory_allocator_->Allocate(device_ordinal_, byte_size,
-                                                  /*retry_on_failure=*/false));
-  total_allocated_bytes_ += byte_size;
-
-  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
-  allocated_buffers_.push_back(std::move(allocated_buffer));
-  return se::DeviceMemory<uint8>(buffer_addr);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h
deleted file mode 100644
index 9654237..0000000
--- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
-
-#include <vector>
-
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/stream_executor/device_memory_allocator.h"
-
-namespace xla {
-namespace gpu {
-
-class ScratchAllocator : public se::ScratchAllocator {
- public:
-  ScratchAllocator(int device_ordinal,
-                   se::DeviceMemoryAllocator* memory_allocator)
-      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
-    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
-  }
-  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
-
-  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
-                                                  int64 byte_size) override;
-
-  template <typename T>
-  StatusOr<se::DeviceMemory<T>> Allocate(se::Stream* stream,
-                                         int64 num_elements) {
-    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
-                        AllocateBytes(stream, num_elements * sizeof(T)));
-    return se::DeviceMemory<T>(bytes);
-  }
-
- private:
-  const int device_ordinal_;
-  se::DeviceMemoryAllocator* memory_allocator_;
-  std::vector<se::OwningDeviceMemory> allocated_buffers_;
-  int64 total_allocated_bytes_ = 0;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index 1cdf975..117931e 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/cuda_libdevice_path.h"
@@ -201,10 +202,7 @@
   }
 
   auto kernel_base = absl::make_unique<se::KernelBase>(stream_exec);
-  if (!stream_exec->GetKernel(loader_spec, kernel_base.get())) {
-    return InternalError("Unable to load kernel '%s'", kernel_name);
-  }
-
+  TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
   return std::move(kernel_base);
 }
 
@@ -217,13 +215,9 @@
   for (const se::DeviceMemoryBase& buf : args) {
     kernel_args->add_device_memory_argument(buf);
   }
-
-  if (!stream->parent()->Launch(stream, se::ThreadDim(threads_per_block),
-                                se::BlockDim(block_count), kernel,
-                                *kernel_args)) {
-    return InternalError("Unable to launch kernel");
-  }
-  return Status::OK();
+  return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block),
+                                  se::BlockDim(block_count), kernel,
+                                  *kernel_args);
 }
 
 se::cuda::PtxCompilationOptions PtxOptsFromConfig(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index a9b52d9..67051b1 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -7,7 +7,7 @@
 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index a12932f..d722973 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -99,6 +99,22 @@
                      /*match_optimized_ir=*/true);
 }
 
+TEST_F(GpuKernelTilingTest, UnnestedTransposeC128TypeRun) {
+  const char *const kHloString = R"(
+    HloModule unnested_transpose_3
+
+    ENTRY unnested_transpose_3 {
+      para0 = c128[65,65]{1,0} parameter(0)
+      ROOT copy1 = c128[65,65]{0,1} copy(para0)
+    })";
+
+  // With the current implementation for the available hardwares, we bail out
+  // from the tiled transpose implementation at the last minute. Instead of
+  // checking the transpose is not tiled, we only check the module compiled and
+  // run in this test.
+  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
+}
+
 TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
   const char *const kHloString = R"(
     HloModule multiple_output_fusion_1
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 8cc891f..48b5975 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -22,6 +22,8 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
@@ -29,205 +31,6 @@
 using absl::flat_hash_map;
 using absl::flat_hash_set;
 
-namespace {
-// FlattenSchedule walks through the instruction, and recurse into each called
-// computations. As it walks it also tracks down the ordinal number of each
-// instruction in the schedule and store it in the `instruction_schedule` and
-// 'flattened_instruction_sequence`. The end of each computation is tracked in
-// `computation_schedule`.
-int64 FlattenSchedule(
-    const HloComputation& computation,
-    const HloInstructionSequence& instruction_sequence,
-    const HloSchedule* schedule, int64 start_time,
-    HloInstructionSequence* flattened_instruction_sequence,
-    absl::flat_hash_map<const HloInstruction*, int64>* instruction_schedule,
-    absl::flat_hash_map<const HloComputation*, int64>* computation_schedule) {
-  int64 time = start_time;
-  for (HloInstruction* instruction : instruction_sequence.instructions()) {
-    if (schedule != nullptr) {
-      // Recurse into sub computations if we have a module-scoped schedule.
-      if (instruction->opcode() == HloOpcode::kCall ||
-          instruction->opcode() == HloOpcode::kConditional) {
-        for (const HloComputation* called_computation :
-             instruction->called_computations()) {
-          const HloInstructionSequence& called_sequence =
-              schedule->sequence(called_computation);
-          time = FlattenSchedule(*called_computation, called_sequence, schedule,
-                                 time, flattened_instruction_sequence,
-                                 instruction_schedule, computation_schedule);
-          computation_schedule->insert({called_computation, time});
-        }
-      }
-      if (instruction->opcode() == HloOpcode::kWhile) {
-        const HloInstructionSequence& condition_sequence =
-            schedule->sequence(instruction->while_condition());
-        time =
-            FlattenSchedule(*instruction->while_condition(), condition_sequence,
-                            schedule, time, flattened_instruction_sequence,
-                            instruction_schedule, computation_schedule);
-        computation_schedule->insert({instruction->while_condition(), time});
-        const HloInstructionSequence& body_sequence =
-            schedule->sequence(instruction->while_body());
-        time = FlattenSchedule(*instruction->while_body(), body_sequence,
-                               schedule, time, flattened_instruction_sequence,
-                               instruction_schedule, computation_schedule);
-      }
-    }
-    if (instruction_schedule->count(instruction) != 0) {
-      continue;
-    }
-    instruction_schedule->insert({instruction, time++});
-    flattened_instruction_sequence->push_back(instruction);
-  }
-  computation_schedule->insert({&computation, time});
-  DCHECK_EQ(instruction_schedule->size(),
-            flattened_instruction_sequence->size());
-  DCHECK_EQ(instruction_schedule->size(), time);
-  return time;
-}
-
-// The aliased buffers could have overlapping live ranges.
-// NormalizeAliasedBuffers normalizes the buffer such that each alias buffer has
-// disjoint live range while keeping the live range union the same. This avoid
-// double counting aliased buffer sizes.
-//
-// Before(buffer1 and 2 are aliased):
-//
-//           +----+          live range of buffer1
-//   +------------------+    live range of buffer2
-//
-// After:
-//
-//           +----------+    live range of buffer1
-//   +------+                live range of buffer2
-//
-// Before(buffer1 and 2 are aliased):
-//
-//           +----------+    live range of buffer1
-//   +------------+          live range of buffer2
-//
-// After:
-//
-//           +----------+    live range of buffer1
-//   +------+                live range of buffer2
-//
-// Before(buffer1 and 2 are aliased):
-//
-//           +----------+    live range of buffer1
-//   +---+                   live range of buffer2
-//
-// After(unchanged):
-//
-//           +----------+    live range of buffer1
-//   +---+                   live range of buffer2
-//
-// As another example, imagine we have the following code sequence with live
-// ranges of each while-aliased buffers:
-//
-//                     a      p1    p2    e     b
-// a = ...             +
-//                     |
-// {                   |
-//   p1 = param        |       +
-//   ROOT true         |       |
-// }                   |       +
-// { // body           |
-//   p2 = param        +             +
-//   c = p2 + 1                      +
-//   d = c + 1
-//   ROOT e = d + 1                       +
-// }                                      |
-//                                        |
-// b = while (a)                          +     +
-//                                              |
-// f = b + 1                                    +
-//
-// After normalization it becomes:
-//
-//                     a      p1    p2    e     b
-// a = ...             +
-//                     |
-// {                   +
-//   p1 = param                +
-//   ROOT true                 |
-// }                           +
-// { // body
-//   p2 = param                      +
-//   c = p2 + 1                      +
-//   d = c + 1
-//   ROOT e = d + 1                       +
-// }                                      |
-//                                        |
-// b = while (a)                          +
-//                                              +
-// f = b + 1                                    +
-//
-// Note there is no overlap of live ranges after normalization.
-void NormalizeAliasedBuffers(
-    absl::flat_hash_map<const HloValue*, int64>* buffer_start_map,
-    absl::flat_hash_map<const HloValue*, int64>* buffer_end_map,
-    const std::vector<const HloValue*>& values_to_assign,
-    const HloAliasAnalysis& alias_analysis) {
-  absl::flat_hash_set<const HloValue*> values_to_assign_set(
-      values_to_assign.begin(), values_to_assign.end());
-  for (const HloBuffer& hlo_buffer : alias_analysis.buffers()) {
-    std::vector<const HloValue*> aliased_buffers;
-    for (const HloValue* hlo_value : hlo_buffer.values()) {
-      if (values_to_assign_set.count(hlo_value) != 0) {
-        aliased_buffers.push_back(hlo_value);
-        CHECK_NE(buffer_start_map->count(hlo_value), 0);
-        CHECK_NE(buffer_end_map->count(hlo_value), 0);
-      }
-    }
-    absl::c_sort(
-        aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
-          if ((*buffer_start_map)[value1] != (*buffer_start_map)[value2]) {
-            return (*buffer_start_map)[value1] < (*buffer_start_map)[value2];
-          }
-          return (*buffer_end_map)[value1] < (*buffer_end_map)[value2];
-        });
-
-    for (int64 i = 0; i < aliased_buffers.size(); ++i) {
-      // We can't use aliased_buffers.size() - 1 since aliased_buffers.size() is
-      // an unsigned integer and can be 0.
-      if (i + 1 == aliased_buffers.size()) {
-        break;
-      }
-
-      const HloValue* value1 = aliased_buffers[i];
-      const HloValue* value2 = aliased_buffers[i + 1];
-      if ((*buffer_start_map)[value1] == (*buffer_start_map)[value2]) {
-        // If value1 has the same start time as value2, make value1 disappear by
-        // setting the end time same as start time:
-        //
-        // Before:
-        // +----+           value1
-        // +----------+     value2
-        //
-        // After:
-        // +                value1
-        // +----------+     value2
-        //
-        // Note that only when heap simulator runs before copy insertion can
-        // this happen where one instruction defines multiple aliased buffers --
-        // This is illegle to execute and can be fixed by copy insertion later.
-        (*buffer_end_map)[value1] = (*buffer_start_map)[value1];
-        continue;
-      }
-
-      if ((*buffer_end_map)[value1] < (*buffer_start_map)[value2]) {
-        continue;
-      }
-
-      if ((*buffer_end_map)[value1] > (*buffer_end_map)[value2]) {
-        (*buffer_end_map)[value2] = (*buffer_end_map)[value1];
-      }
-      (*buffer_end_map)[value1] = (*buffer_start_map)[value2] - 1;
-    }
-  }
-}
-}  // namespace
-
 /*static*/
 StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
     const HloSchedule& schedule,
@@ -289,8 +92,12 @@
   const HloComputation* entry_computation = module.entry_computation();
   const HloInstructionSequence& instruction_sequence =
       schedule.sequence(entry_computation);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloLiveRange> hlo_live_range,
+      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
   TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
-                                         instruction_sequence, alias_analysis));
+                                         instruction_sequence, alias_analysis,
+                                         hlo_live_range.get()));
   return heap.Finish();
 }
 
@@ -304,8 +111,13 @@
         memory_by_computation) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
                      /*schedule=*/nullptr, memory_by_computation);
-  TF_RETURN_IF_ERROR(
-      heap.RunComputation(computation, instruction_sequence, alias_analysis));
+  HloSchedule schedule(computation.parent());
+  schedule.set_sequence(&computation, instruction_sequence);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
+                      HloLiveRange::Run(schedule, alias_analysis, &computation,
+                                        /*module_scoped_analysis=*/false));
+  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
+                                         alias_analysis, hlo_live_range.get()));
   return heap.Finish();
 }
 
@@ -318,8 +130,11 @@
     const Options& options) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
                      /*schedule=*/schedule, nullptr);
-  TF_RETURN_IF_ERROR(
-      heap.RunComputation(computation, instruction_sequence, alias_analysis));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloLiveRange> hlo_live_range,
+      HloLiveRange::Run(*schedule, alias_analysis, &computation));
+  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
+                                         alias_analysis, hlo_live_range.get()));
   return heap.Finish();
 }
 
@@ -328,35 +143,24 @@
 Status HeapSimulator::RunComputation(
     const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
-    const HloAliasAnalysis& alias_analysis) {
+    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
   XLA_VLOG_LINES(1, computation.parent()->ToString());
   XLA_VLOG_LINES(2, computation.ToString());
 
+  VLOG(1) << hlo_live_range->ToString();
+
   HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();
 
-  // program_end_time is the time of the last instruction scheduled. It is equal
-  // to the number of instructions in a computation.
-  int64 program_end_time =
-      FlattenSchedule(computation, instruction_sequence, schedule_, 0,
-                      &flattened_instruction_sequence_, &instruction_schedule_,
-                      &computation_schedule_);
-
-  VLOG(1) << "Program end time: " << program_end_time;
-
-  algorithm_->SetSchedules(&flattened_instruction_sequence_,
-                           &instruction_schedule_, &computation_schedule_);
-
-  // We track the definition and free events for each buffer, then we go through
-  // each step and reply those events in program order.
-  absl::flat_hash_map<const HloValue*, int64> buffer_start_map;
-  absl::flat_hash_map<const HloValue*, int64> buffer_end_map;
+  algorithm_->SetSchedules(&hlo_live_range->flattened_instruction_sequence(),
+                           &hlo_live_range->instruction_schedule());
 
   // Record the buffer define/free event for each time step. We free all
   // remaining buffers (entry parameter, etc) after the program has finished
   // running, so we set the size of to program_end_time + 1.
-  std::vector<std::vector<const HloValue*>> buffers_defined(program_end_time +
-                                                            1);
-  std::vector<std::vector<const HloValue*>> buffers_freed(program_end_time + 1);
+  std::vector<std::vector<const HloValue*>> buffers_defined(
+      hlo_live_range->schedule_end_time() + 1);
+  std::vector<std::vector<const HloValue*>> buffers_freed(
+      hlo_live_range->schedule_end_time() + 1);
 
   // values_to_assign tracks the HloValues that we need to assign a buffer to.
   // Note that we only need to assign a buffer to a value when both of the
@@ -369,106 +173,49 @@
   // - If the instruction is in a nested call of the current computation, only
   // assign a buffer if we are doing global heap simulation.
   std::vector<const HloValue*> values_to_assign;
+  values_to_assign.reserve(dataflow_analysis.values().size());
 
-  // Keeps track of buffer start time and buffer end time.
   for (const HloValue* value : dataflow_analysis.values()) {
-    // Ignore buffers that are not defined.
-    if (instruction_schedule_.count(value->defining_instruction()) == 0) {
+    // Ignore buffers that are not tracked.
+    if (hlo_live_range->instruction_schedule().count(
+            value->defining_instruction()) == 0) {
       continue;
     }
     if (IgnoreBuffer(value)) {
       continue;
     }
     values_to_assign.push_back(value);
-    int64 buffer_start_time = instruction_schedule_[value->instruction()];
-
-    int64 buffer_end_time = -1;
-    // A buffer's live range ends when the last user finishes executing.
-    for (const HloUse& use : value->uses()) {
-      const HloInstruction* used = use.instruction;
-      // As an optimization, we deem a while's init value's live range ends as
-      // soon as the loop body starts. This optimization is only applicable to
-      // the whole module simulation.
-      if (schedule_ != nullptr && used->opcode() == HloOpcode::kWhile) {
-        // The current live range is at the end of the while, move it to the
-        // beginning of the body.
-        used = used->while_body()->parameter_instruction(0);
-        VLOG(1) << "Moved value " << value->ToShortString()
-                << " to while param: " << used->ToString();
-      }
-      if (instruction_schedule_.count(used) == 0) {
-        // We didn't track the instruction `used`. This happens when we do
-        // computation scope (versus module scope) heap simulation and when the
-        // used instruction is outside of the computation being simulated.
-        continue;
-      }
-      buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
-    }
-
-    if (buffer_end_time == -1) {
-      buffer_end_time = buffer_start_time;
-    }
-
-    for (const HloPosition& position : value->positions()) {
-      const HloComputation* position_comp = position.instruction->parent();
-      // If this instruction lives out, the live range of the instruction should
-      // be extended to the end of the computation.
-      if (position.instruction == position_comp->root_instruction()) {
-        if (schedule_ == nullptr && &computation != position_comp) {
-          continue;
-        }
-        if (computation_schedule_.count(position_comp) == 0) {
-          continue;
-        }
-        buffer_end_time =
-            std::max(buffer_end_time, computation_schedule_[position_comp]);
-      }
-    }
-
-    // Entry parameters live across whole computation.
-    if (value->instruction()->opcode() == HloOpcode::kParameter &&
-        value->instruction()->parent() ==
-            computation.parent()->entry_computation()) {
-      buffer_end_time = program_end_time;
-    }
-
-    CHECK(buffer_start_time <= buffer_end_time);
-
-    buffer_start_map[value] = buffer_start_time;
-    buffer_end_map[value] = buffer_end_time;
   }
 
-  NormalizeAliasedBuffers(&buffer_start_map, &buffer_end_map, values_to_assign,
-                          alias_analysis);
+  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();
 
   absl::c_sort(values_to_assign,
                [&](const HloValue* value1, const HloValue* value2) {
-                 if (buffer_start_map[value1] != buffer_start_map[value2]) {
-                   return buffer_start_map[value1] < buffer_start_map[value2];
-                 }
-
-                 if (buffer_end_map[value1] != buffer_end_map[value2]) {
-                   return buffer_end_map[value1] < buffer_end_map[value2];
-                 }
-                 return value1->id() < value2->id();
+                 const auto& live_range1 = buffer_live_ranges.at(value1);
+                 const auto& live_range2 = buffer_live_ranges.at(value2);
+                 return std::forward_as_tuple(live_range1.start,
+                                              live_range1.end, value1->id()) <
+                        std::forward_as_tuple(live_range2.start,
+                                              live_range2.end, value2->id());
                });
 
   // For each value that we need to assign a buffer to, add the define and free
   // events.
   for (const HloValue* value : values_to_assign) {
-    buffers_defined[buffer_start_map[value]].push_back(value);
-    buffers_freed[buffer_end_map[value]].push_back(value);
+    auto live_range = buffer_live_ranges.at(value);
+    buffers_defined[live_range.start].push_back(value);
+    buffers_freed[live_range.end].push_back(value);
   }
 
   // All HloValues in a hlo buffer should be allocated to the same address. This
   // map tracks the first value that got allocated in a buffer.
   absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;
 
-  VLOG(1) << "Program time" << program_end_time;
+  VLOG(1) << "Program time" << hlo_live_range->schedule_end_time();
 
   // Go through each step in the program and replay each buffer define and free
   // events.
-  for (int64 i = 0; i < program_end_time + 1; ++i) {
+  for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
     VLOG(1) << "Time step: " << i;
 
     for (const HloValue* value : buffers_defined[i]) {
@@ -500,11 +247,21 @@
             if (operand_buffer->values().size() > 1) {
               continue;
             }
-            if (buffer_end_map.count(operand_value) == 0) {
+            auto it = buffer_live_ranges.find(operand_value);
+            if (it == buffer_live_ranges.end()) {
               continue;
             }
+
+            auto& operand_live_range = it->second;
+
+            auto& user_live_range = buffer_live_ranges[value];
+
             // Can only share buffers that are about to be freed.
-            if (buffer_end_map[operand_value] != i) {
+            if (operand_live_range.end != i) {
+              continue;
+            }
+
+            if (IgnoreBuffer(operand_value)) {
               continue;
             }
 
@@ -527,7 +284,7 @@
               ShareBuffer(value, operand_value, value->instruction());
               // The live range of the operand buffer is now extended to the end
               // of the current instruction.
-              buffer_end_map[operand_value] = buffer_end_map[value];
+              operand_live_range.end = user_live_range.end;
               VLOG(1) << "Sharing " << value->ToShortString() << " with "
                       << operand_value->ToShortString()
                       << ", size:" << size_fn_(*value);
@@ -871,29 +628,27 @@
     // start of the first buffer and the end of the last co-located
     // buffer. There could be "holes" in the live ranges of each co-located
     // buffers, but in this heuristics we think they are contiguous.
-    absl::c_sort(sorted_buffer_intervals,
-                 [&](const BufferInterval& x, const BufferInterval& y) {
-                   int64 x_end = x.end;
-                   for (auto colocation : GetTransitiveColocations(x)) {
-                     x_end =
-                         std::max(x_end, buffer_intervals_.at(colocation).end);
-                   }
+    absl::c_sort(sorted_buffer_intervals, [&](const BufferInterval& x,
+                                              const BufferInterval& y) {
+      int64 x_end = x.end;
+      for (auto colocation : GetTransitiveColocations(x)) {
+        x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
+      }
 
-                   int64 y_end = y.end;
-                   for (auto colocation : GetTransitiveColocations(y)) {
-                     y_end =
-                         std::max(y_end, buffer_intervals_.at(colocation).end);
-                   }
+      int64 y_end = y.end;
+      for (auto colocation : GetTransitiveColocations(y)) {
+        y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
+      }
 
-                   if (x_end - x.start != y_end - y.start) {
-                     return x_end - x.start > y_end - y.start;
-                   }
+      if (x_end - x.start != y_end - y.start) {
+        return x_end - x.start > y_end - y.start;
+      }
 
-                   if (x.size != y.size) {
-                     return x.size > y.size;
-                   }
-                   return x.buffer->id() < y.buffer->id();
-                 });
+      if (x.size != y.size) {
+        return x.size > y.size;
+      }
+      return x.buffer->id() < y.buffer->id();
+    });
   } else {
     // Sort by spatial size. We don't look at co-locates as they should have the
     // same size.
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index f70f6c2..00a748f 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -31,6 +31,7 @@
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -165,7 +166,8 @@
 
   Status RunComputation(const HloComputation& computation,
                         const HloInstructionSequence& instruction_sequence,
-                        const HloAliasAnalysis& alias_analysis);
+                        const HloAliasAnalysis& alias_analysis,
+                        HloLiveRange* live_range);
 
   bool IgnoreBuffer(const HloValue* buffer) const;
   void Alloc(const HloValue* buffer, const HloInstruction* instruction);
@@ -204,15 +206,6 @@
   absl::flat_hash_set<const HloValue*> allocated_buffers_;
   absl::flat_hash_set<const HloValue*> freed_buffers_;
 
-  // The flattened sequence of all instructions in the module. It contains the
-  // same information as instruction_schedule_, but allows fast indexing using
-  // the schedule index.
-  HloInstructionSequence flattened_instruction_sequence_;
-  // instruction_schedule and computation_schedule are the maps that track each
-  // instruction/computation and their ordinal in the schedule.
-  absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_;
-  absl::flat_hash_map<const HloComputation*, int64> computation_schedule_;
-
   // Debugging information filled in while the heap simulator runs.
   HeapSimulatorTrace debug_trace_;
 };
@@ -271,20 +264,15 @@
   virtual void SetSchedules(
       const HloInstructionSequence* flattened_instruction_sequence,
       const absl::flat_hash_map<const HloInstruction*, int64>*
-          instruction_schedule,
-      const absl::flat_hash_map<const HloComputation*, int64>*
-          computation_schedule) {
+          instruction_schedule) {
     flattened_instruction_sequence_ = flattened_instruction_sequence;
     instruction_schedule_ = instruction_schedule;
-    computation_schedule_ = computation_schedule;
   }
 
  protected:
   const HloInstructionSequence* flattened_instruction_sequence_;
   const absl::flat_hash_map<const HloInstruction*, int64>*
       instruction_schedule_;
-  const absl::flat_hash_map<const HloComputation*, int64>*
-      computation_schedule_;
 };
 
 // NoFragmentationStatsHeap computes the heap size assuming no fragmentation;
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 4f7daa8..80a0471 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -442,8 +442,8 @@
   tracker.ExpectCallSequence({
       {kAlloc, tracker.BufferAt(paramA, {})},
       {kAlloc, tracker.BufferAt(paramX, {})},
-      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
       {kFree, tracker.BufferAt(mul, {})},
       {kShare, tracker.BufferAt(add, {})},
       // All params and outputs are freed at the end.
@@ -516,8 +516,8 @@
   tracker.ExpectCallSequence({
       {kAlloc, tracker.BufferAt(paramA, {})},
       {kAlloc, tracker.BufferAt(paramX, {})},
-      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(dot, {})},
       // All params and outputs are freed at the end.
       {kFree, tracker.BufferAt(mul, {})},
@@ -554,8 +554,8 @@
   tracker.ExpectCallSequence({
       {kAlloc, tracker.BufferAt(paramA, {})},
       {kAlloc, tracker.BufferAt(paramX, {})},
-      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(dot, {})},
       {kFree, tracker.BufferAt(mul, {})},
       {kFree, tracker.BufferAt(dot, {})},
@@ -596,8 +596,8 @@
   tracker.ExpectCallSequence({
       {kAlloc, tracker.BufferAt(paramA, {})},
       {kAlloc, tracker.BufferAt(paramX, {})},
-      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(dot0, {})},
       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
       {kAlloc, tracker.BufferAt(dot1, {})},
@@ -640,8 +640,8 @@
   tracker.ExpectCallSequence({
       {kAlloc, tracker.BufferAt(paramA, {})},
       {kAlloc, tracker.BufferAt(paramX, {})},
-      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
       {kAlloc, tracker.BufferAt(dot0, {})},
       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
       {kAlloc, tracker.BufferAt(dot1, {})},
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 639e853..6fe91e4 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -532,11 +532,12 @@
     if (options.print_percent()) {
       s << "%";
     }
-    s << name() << " ";
+    s << PrintName(name(), options.print_ids()) << " ";
   }
 
   if (options.print_program_shape()) {
-    s << ShapeUtil::HumanString(ComputeProgramShape()) << " ";
+    s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
+      << " ";
   }
   s << "{\n";
   {
@@ -753,12 +754,13 @@
   return DeepCopyHelper(instruction, &index, copy_leaf);
 }
 
-ProgramShape HloComputation::ComputeProgramShape() const {
+ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
   ProgramShape program_shape;
 
   for (auto* param_instruction : param_instructions_) {
     *program_shape.add_parameters() = param_instruction->shape();
-    *program_shape.add_parameter_names() = param_instruction->name();
+    *program_shape.add_parameter_names() =
+        PrintName(param_instruction->name(), include_ids);
   }
   *program_shape.mutable_result() = root_instruction_->shape();
 
@@ -835,6 +837,14 @@
   if (new_instruction->metadata().op_name().empty()) {
     new_instruction->set_metadata(old_instruction->metadata());
   }
+
+  // Like the metadata above, if the user didn't specify any sharding
+  // information on the new instruction we should copy the old sharding
+  // information (if any).
+  if (!new_instruction->has_sharding()) {
+    new_instruction->set_sharding(old_instruction->sharding_ptr());
+  }
+
   TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
   return RemoveInstructionAndUnusedOperands(old_instruction);
 }
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 111b28a..264bb66 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -288,7 +288,7 @@
 
   // Computes and returns the ProgramShape of this computation (shape of
   // parameters and result with layout).
-  ProgramShape ComputeProgramShape() const;
+  ProgramShape ComputeProgramShape(bool include_ids = true) const;
 
   // Return whether `*this` and `other` are functionally equivalent.
   bool Equal(const HloComputation& other, bool is_layout_sensitive) const;
@@ -314,6 +314,8 @@
   // Replace old instruction with new instruction.  Updates uses and root
   // instruction. Removes old instruction from computation. Precondition:
   // old_instruction and new_instruction must have the compatible shapes.
+  // If |new_instruction| doesn't have any sharding information it will
+  // recieve the sharding information of |old_instruction|.
   Status ReplaceInstruction(HloInstruction* old_instruction,
                             HloInstruction* new_instruction);
 
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index a7e1d3a..9a9898f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1543,8 +1543,9 @@
     int64 index_vector_dim = dim_numbers_.index_vector_dim();
     for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
       index_vector_index_[index_vector_dim] = i;
-      TF_ASSIGN_OR_RETURN(index_vector_[i],
-                          start_indices_.GetIntegralAsS64(index_vector_index_));
+      // TODO(george): OK what should happen here?
+      // seems OK to crash though.
+      index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_);
     }
     return Status::OK();
   }
@@ -2295,12 +2296,10 @@
   }
 
   if (use_fast_add) {
-    TF_ASSIGN_OR_RETURN(double computed_result,
-                        init_values[0]->GetAsDouble({}));
+    double computed_result = *init_values[0]->GetAsDouble({});
     auto reduction_step =
         [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
-      TF_ASSIGN_OR_RETURN(double argument,
-                          input_args[0]->GetAsDouble(input_index));
+      double argument = *input_args[0]->GetAsDouble(input_index);
       computed_result += argument;
       return true;
     };
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 9fcc627..9487d95 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2035,8 +2035,8 @@
       int64 index_vector_dim = dim_numbers_.index_vector_dim();
       for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
         index_vector_index_[index_vector_dim] = i;
-        TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64(
-                                                  index_vector_index_));
+        index_vector_[i] =
+            *scatter_indices_.GetIntegralAsS64(index_vector_index_);
       }
       return Status::OK();
     }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ddfcdcf..7e646f3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2179,10 +2179,20 @@
   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
 }
 
+string PrintName(const string& name, bool print_ids) {
+  if (print_ids) {
+    return name;
+  } else {
+    auto dot_position = name.find_first_of(".");
+    return name.substr(0, dot_position);
+  }
+}
+
 namespace {
 
-string PrintName(const string& name, const HloPrintOptions& options) {
-  return StrCat(options.print_percent() ? "%" : "", name);
+string PrintNameInternal(const string& name, const HloPrintOptions& options) {
+  return StrCat(options.print_percent() ? "%" : "",
+                PrintName(name, options.print_ids()));
 }
 
 }  // namespace
@@ -2277,11 +2287,12 @@
       // If we are canonicalizing instruction names and this is a top-level
       // HloInstruction::ToString() call, don't print an instruction name.
       StrAppend(&result,
-                PrintName(canonical_name_map->LookupOrInsert(name()), options),
+                PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
+                                  options),
                 " = ");
     }
   } else {
-    StrAppend(&result, PrintName(name(), options), " = ");
+    StrAppend(&result, PrintNameInternal(name(), options), " = ");
   }
 
   // Print shape.
@@ -2347,10 +2358,10 @@
     // part of the canonical string.
     if (options.canonicalize_instruction_names() &&
         options.is_in_nested_computation()) {
-      str.push_back(PrintName(
+      str.push_back(PrintNameInternal(
           canonical_name_map->LookupOrInsert(operand->name()), options));
     } else if (options.print_operand_names()) {
-      str.push_back(PrintName(operand->name(), options));
+      str.push_back(PrintNameInternal(operand->name(), options));
     }
     StrAppend(out, StrJoin(str, " "));
   });
@@ -2368,27 +2379,30 @@
   if (options.print_subcomputation_mode() ==
       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
     if (opcode() == HloOpcode::kWhile) {
+      extra.push_back(StrCat(
+          "condition=", PrintNameInternal(while_condition()->name(), options)));
       extra.push_back(
-          StrCat("condition=", PrintName(while_condition()->name(), options)));
-      extra.push_back(
-          StrCat("body=", PrintName(while_body()->name(), options)));
+          StrCat("body=", PrintNameInternal(while_body()->name(), options)));
     } else if (opcode() == HloOpcode::kSelectAndScatter) {
-      extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
       extra.push_back(
-          StrCat("scatter=", PrintName(scatter()->name(), options)));
+          StrCat("select=", PrintNameInternal(select()->name(), options)));
+      extra.push_back(
+          StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
     } else if (opcode() == HloOpcode::kConditional) {
       if (operand(0)->shape().element_type() == PRED) {
-        extra.push_back(StrCat("true_computation=",
-                               PrintName(true_computation()->name(), options)));
+        extra.push_back(
+            StrCat("true_computation=",
+                   PrintNameInternal(true_computation()->name(), options)));
         extra.push_back(
             StrCat("false_computation=",
-                   PrintName(false_computation()->name(), options)));
+                   PrintNameInternal(false_computation()->name(), options)));
       } else {
         extra.push_back(StrCat(
             "branch_computations={",
             StrJoin(branch_computations(), ", ",
                     [&](string* out, const HloComputation* computation) {
-                      StrAppend(out, PrintName(computation->name(), options));
+                      StrAppend(
+                          out, PrintNameInternal(computation->name(), options));
                     }),
             "}"));
       }
@@ -2399,13 +2413,14 @@
                opcode() == HloOpcode::kScatter ||
                opcode() == HloOpcode::kSort) {
       extra.push_back(
-          StrCat("to_apply=", PrintName(to_apply()->name(), options)));
+          StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
     } else if (!called_computations().empty()) {
       extra.push_back(StrCat(
           "calls=",
           StrJoin(called_computations(), ", ",
                   [&](string* out, const HloComputation* computation) {
-                    StrAppend(out, PrintName(computation->name(), options));
+                    StrAppend(out,
+                              PrintNameInternal(computation->name(), options));
                   })));
     }
   } else if (options.print_subcomputation_mode() ==
@@ -2473,8 +2488,8 @@
     extra.push_back(StrCat("control-predecessors={",
                            StrJoin(control_predecessors_, ", ",
                                    [&](string* out, HloInstruction* pre) {
-                                     StrAppend(out,
-                                               PrintName(pre->name(), options));
+                                     StrAppend(out, PrintNameInternal(
+                                                        pre->name(), options));
                                    }),
                            "}"));
   }
@@ -2573,6 +2588,9 @@
   switch (opcode_) {
     case HloOpcode::kDomain:
     case HloOpcode::kParameter:
+    case HloOpcode::kWhile:
+    case HloOpcode::kConditional:
+    case HloOpcode::kCall:
       return false;
     // Side effecting instrutions cannot be fused.
     default:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index fbaeb5d..78128a7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -63,6 +63,8 @@
 class HloComputation;
 class HloModule;
 
+string PrintName(const string& name, bool print_ids);
+
 // A bunch of switches that control how the hlo text should be printed.
 class HloPrintOptions {
  public:
@@ -88,7 +90,8 @@
         print_control_dependencies_(true),
         canonicalize_instruction_names_(false),
         indent_amount_(0),
-        is_in_nested_computation_(false) {}
+        is_in_nested_computation_(false),
+        print_ids_(true) {}
 
   static HloPrintOptions ShortParsable() {
     return HloPrintOptions()
@@ -118,6 +121,22 @@
         .set_canonicalize_instruction_names(true);
   }
 
+  // Options to produce a fingerprint of an HLO.
+  static HloPrintOptions Fingerprint() {
+    return HloPrintOptions()
+        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
+        .set_print_metadata(false)
+        .set_print_backend_config(false)
+        .set_compact_operands(true)
+        .set_print_operand_names(false)
+        .set_print_operand_shape(true)
+        .set_print_program_shape(false)
+        .set_print_percent(false)
+        .set_print_control_dependencies(false)
+        .set_canonicalize_instruction_names(true)
+        .set_print_ids(false);
+  }
+
   // If true, large constants will be printed out.
   HloPrintOptions& set_print_large_constants(bool value) {
     print_large_constants_ = value;
@@ -154,6 +173,12 @@
     return *this;
   }
 
+  // If true, all printed names include unique identifiers.
+  HloPrintOptions& set_print_ids(bool value) {
+    print_ids_ = value;
+    return *this;
+  }
+
   // If true, program shape of hlo computations will be printed.
   HloPrintOptions& set_print_program_shape(bool value) {
     print_program_shape_ = value;
@@ -216,6 +241,7 @@
   bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
   bool print_operand_shape() const { return print_operand_shape_; }
   bool print_operand_names() const { return print_operand_names_; }
+  bool print_ids() const { return print_ids_; }
   bool print_program_shape() const { return print_program_shape_; }
   bool print_percent() const { return print_percent_; }
   bool print_control_dependencies() const {
@@ -242,6 +268,7 @@
   bool canonicalize_instruction_names_;
   int indent_amount_;
   bool is_in_nested_computation_;
+  bool print_ids_;
 };
 
 // For canonical string output, we need to have a canonical way to rename
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 52d8c7a..312dc1b 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1737,7 +1737,7 @@
 }
 
 std::vector<string> HloParameterInstruction::ExtraAttributesToStringImpl(
-    const HloPrintOptions& /*options*/) const {
+    const HloPrintOptions& options) const {
   std::vector<string> result;
   if (!parameter_replicated_at_leaf_buffers_) {
     return result;
@@ -1746,8 +1746,10 @@
   for (bool replicated : *parameter_replicated_at_leaf_buffers_) {
     buffers_replicated_strs.push_back(replicated ? "true" : "false");
   }
-  result.push_back(StrCat("parameter_replication={",
-                          StrJoin(buffers_replicated_strs, ","), "}"));
+  if (options.print_ids()) {
+    result.push_back(StrCat("parameter_replication={",
+                            StrJoin(buffers_replicated_strs, ","), "}"));
+  }
   return result;
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_live_range.cc b/tensorflow/compiler/xla/service/hlo_live_range.cc
new file mode 100644
index 0000000..8ec437e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_live_range.cc
@@ -0,0 +1,235 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
+
+#include "absl/strings/str_format.h"
+
+namespace xla {
+/*static*/
+StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
+    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
+    const HloComputation* computation, bool module_scoped_analysis) {
+  std::unique_ptr<HloLiveRange> hlo_live_range(
+      new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
+  hlo_live_range->schedule_end_time_ =
+      hlo_live_range->FlattenSchedule(*computation, 0);
+  hlo_live_range->CalculateBufferStartEndMap();
+  hlo_live_range->NormalizeAliasedBuffers();
+  return std::move(hlo_live_range);
+}
+
+void HloLiveRange::NormalizeAliasedBuffers() {
+  for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) {
+    std::vector<const HloValue*> aliased_buffers;
+    for (const HloValue* hlo_value : hlo_buffer.values()) {
+      if (buffer_live_ranges_.contains(hlo_value)) {
+        aliased_buffers.push_back(hlo_value);
+      }
+    }
+    absl::c_sort(
+        aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
+          const TimeBound& live_range1 = buffer_live_ranges_.at(value1);
+          const TimeBound& live_range2 = buffer_live_ranges_.at(value2);
+
+          return std::forward_as_tuple(live_range1.start, live_range1.end) <
+                 std::forward_as_tuple(live_range2.start, live_range2.end);
+        });
+
+    for (int64 i = 0; i + 1 < aliased_buffers.size(); ++i) {
+      const HloValue* value1 = aliased_buffers[i];
+      const HloValue* value2 = aliased_buffers[i + 1];
+      TimeBound& live_range1 = buffer_live_ranges_[value1];
+      TimeBound& live_range2 = buffer_live_ranges_[value2];
+      if (live_range1.start == live_range2.start) {
+        // If value1 has the same start time as value2, make value1 disappear
+        // by setting the end time same as start time:
+        //
+        // Before:
+        // +----+           value1
+        // +----------+     value2
+        //
+        // After:
+        // +                value1
+        // +----------+     value2
+        //
+        // Note that only when heap simulator runs before copy insertion can
+        // this happen where one instruction defines multiple aliased buffers
+        // -- This is illegle to execute and can be fixed by copy insertion
+        // later.
+        live_range1.end = live_range2.end;
+        continue;
+      }
+
+      if (live_range1.end < live_range2.start) {
+        continue;
+      }
+
+      if (live_range1.end > live_range2.end) {
+        live_range2.end = live_range1.end;
+      }
+      live_range1.end = live_range2.start - 1;
+    }
+  }
+}
+
+// FlattenSchedule walks through the computation and tracks down the ordinal
+// number of each instruction in the schedule.
+int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
+                                    int64 start_time) {
+  if (!schedule_.is_computation_scheduled(&computation)) {
+    total_order_scheduled_ = false;
+    return start_time;
+  }
+
+  const HloInstructionSequence& instruction_sequence =
+      schedule_.sequence(&computation);
+  int64 time = start_time;
+  for (HloInstruction* instruction : instruction_sequence.instructions()) {
+    if (module_scoped_analysis_) {
+      // Recurse into sub computations if running with module scoped analysis
+      // mode.
+      if (instruction->opcode() == HloOpcode::kCall ||
+          instruction->opcode() == HloOpcode::kConditional) {
+        for (const HloComputation* called_computation :
+             instruction->called_computations()) {
+          time = FlattenSchedule(*called_computation, time);
+        }
+      }
+      if (instruction->opcode() == HloOpcode::kWhile) {
+        time = FlattenSchedule(*instruction->while_condition(), time);
+        time++;
+        time = FlattenSchedule(*instruction->while_body(), time);
+      }
+    }
+    if (instruction_schedule_.count(instruction) != 0) {
+      continue;
+    }
+    instruction_schedule_.insert({instruction, time++});
+    flattened_instruction_sequence_.push_back(instruction);
+  }
+  computation_span_times_.try_emplace(&computation,
+                                      TimeBound{start_time, time});
+  DCHECK_EQ(instruction_schedule_.size(),
+            flattened_instruction_sequence_.size());
+  DCHECK_LE(instruction_schedule_.size(), time);
+  return time;
+}
+
+void HloLiveRange::CalculateBufferStartEndMap() {
+  for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
+    // Ignore buffers that are not defined.
+    if (instruction_schedule_.count(value->defining_instruction()) == 0) {
+      continue;
+    }
+
+    int64 buffer_start_time = instruction_schedule_[value->instruction()];
+
+    int64 buffer_end_time = -1;
+    for (const HloUse& use : value->uses()) {
+      const HloInstruction* used = use.instruction;
+      // As an optimization, we deem a while's init value's live range ends as
+      // soon as the loop body starts. This optimization is only applicable in
+      // module scoped mode.
+      if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
+        // The current live range is at the end of the while, move it to the
+        // beginning of the body.
+        used = used->while_body()->parameter_instruction(0);
+        VLOG(1) << "Moved value " << value->ToShortString()
+                << " to while param: " << used->ToString();
+      }
+      if (instruction_schedule_.count(used) == 0) {
+        // We didn't track the instruction `used`. This happens when we do
+        // computation scope (versus module scope) heap simulation and when
+        // the used instruction is outside of the computation being simulated.
+        continue;
+      }
+      buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
+    }
+
+    // Parameters are defined at the beginning of the computation. This prevents
+    // any instruction that's scheduled before the parameter clobbers the
+    // parameter's buffer.
+    if (value->instruction()->opcode() == HloOpcode::kParameter) {
+      const HloComputation* computation = value->instruction()->parent();
+      auto it = computation_span_times_.find(computation);
+      if (it != computation_span_times_.end()) {
+        buffer_start_time = std::min(buffer_start_time, it->second.start);
+      }
+    }
+
+    if (buffer_end_time == -1) {
+      buffer_end_time = buffer_start_time;
+    }
+
+    for (const HloPosition& position : value->positions()) {
+      const HloComputation* position_comp = position.instruction->parent();
+      // If this instruction lives out, the live range of the instruction
+      // should be extended to the end of the computation.
+      if (position.instruction == position_comp->root_instruction()) {
+        auto it = computation_span_times_.find(position_comp);
+        if (it == computation_span_times_.end()) {
+          continue;
+        }
+        buffer_end_time = std::max(buffer_end_time, it->second.end);
+      }
+    }
+
+    const HloModule* module = value->instruction()->parent()->parent();
+
+    // Readonly entry parameters (parameters that don't alias) live across whole
+    // computation.
+    if (value->instruction()->opcode() == HloOpcode::kParameter &&
+        value->instruction()->parent() == module->entry_computation() &&
+        !module->input_output_alias_config().ParameterHasAlias(
+            value->instruction()->parameter_number(), value->index())) {
+      buffer_end_time = schedule_end_time_;
+    }
+
+    CHECK(buffer_start_time <= buffer_end_time)
+        << buffer_start_time << ", " << buffer_end_time
+        << value->instruction()->ToString();
+
+    auto& live_range = buffer_live_ranges_[value];
+    live_range.start = buffer_start_time;
+    live_range.end = buffer_end_time;
+  }
+}
+
+std::string HloLiveRange::ToString() const {
+  std::string output;
+  absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
+                        schedule_end_time_);
+  absl::StrAppendFormat(&output, "  InstructionSequence:\n");
+  auto& instructions = flattened_instruction_sequence().instructions();
+  for (int64 i = 0; i < instructions.size(); ++i) {
+    absl::StrAppendFormat(&output, "    %d:%s\n", i, instructions[i]->name());
+  }
+
+  absl::StrAppendFormat(&output, "  BufferLiveRange:\n");
+
+  for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
+    auto it = buffer_live_ranges_.find(value);
+    if (it != buffer_live_ranges_.end()) {
+      absl::StrAppendFormat(
+          &output, "    %s%s:%d-%d\n", value->instruction()->name(),
+          value->index().ToString(), it->second.start, it->second.end);
+    }
+  }
+
+  return output;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_live_range.h b/tensorflow/compiler/xla/service/hlo_live_range.h
new file mode 100644
index 0000000..cc0445a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_live_range.h
@@ -0,0 +1,206 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+License for the specific language governing permissions and limitations under
+the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_buffer.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// Class which computes live range of the output buffers of HLOs and their
+// interference by flattening all computations. The live range is only available
+// when all global computations (while, if, call, etc) have total order
+// sequential orders.
+class HloLiveRange {
+ public:
+  // Constructs a hlo live range object for the given module and computation
+  // assuming the given HLO instruction ordering.
+  static StatusOr<std::unique_ptr<HloLiveRange>> Run(
+      const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
+      const HloComputation* computation, bool module_scoped_analysis = true);
+
+  // LogicalTime represents the time in a virtual clock. Each instruction has
+  // one monotonically increasing logical time assigned according to the
+  // schedule.
+  using LogicalTime = int64;
+
+  struct TimeBound {
+    LogicalTime start;
+    LogicalTime end;
+
+    bool friend operator==(const TimeBound& a, const TimeBound& b) {
+      return a.start == b.start && a.end == b.end;
+    }
+    bool friend operator!=(const TimeBound& a, const TimeBound& b) {
+      return !(a == b);
+    }
+  };
+
+  std::string ToString() const;
+
+  const HloInstructionSequence& flattened_instruction_sequence() const {
+    return flattened_instruction_sequence_;
+  }
+
+  // Returns the map from instruction to the end time of that instruction.
+  const absl::flat_hash_map<const HloInstruction*, LogicalTime>&
+  instruction_schedule() const {
+    return instruction_schedule_;
+  }
+
+  // Returns the map from a hlo value to the definition time of that hlo value.
+  const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges()
+      const {
+    return buffer_live_ranges_;
+  }
+
+  absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() {
+    return buffer_live_ranges_;
+  }
+
+  // Returns the time stamp of the end of the program.
+  LogicalTime schedule_end_time() const { return schedule_end_time_; }
+
+  // Returns whether hlo live range is available on this entire module. Hlo live
+  // range is not available if the module is partially ordered.
+  bool total_order_scheduled() const { return total_order_scheduled_; }
+
+ private:
+  explicit HloLiveRange(const HloSchedule& schedule,
+                        const HloAliasAnalysis& alias_analysis,
+                        bool module_scoped_analysis)
+      : schedule_(schedule),
+        alias_analysis_(alias_analysis),
+        module_scoped_analysis_(module_scoped_analysis) {}
+
+  // FlattenSchedule walks through the instructions in `computation`, and
+  // recurse into each called computations in module_scoped_analysis mode. As it
+  // walks it also tracks down the ordinal number of each instruction in the
+  // schedule and store it in the `instruction_schedule` and
+  // 'flattened_instruction_sequence`. The end of each computation is tracked in
+  // `computation_end_time`.
+  int64 FlattenSchedule(const HloComputation& computation, int64 start_time);
+
+  // Based on the flattened schedule, calculate the start and end of each
+  // buffer.
+  void CalculateBufferStartEndMap();
+
+  // The aliased buffers could have overlapping live ranges.
+  // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer
+  // has disjoint live range while keeping the live range union the same. This
+  // avoid double counting aliased buffer sizes.
+  //
+  // Before(buffer1 and 2 are aliased):
+  //
+  //           +----+          live range of buffer1
+  //   +------------------+    live range of buffer2
+  //
+  // After:
+  //
+  //           +----------+    live range of buffer1
+  //   +------+                live range of buffer2
+  //
+  // Before(buffer1 and 2 are aliased):
+  //
+  //           +----------+    live range of buffer1
+  //   +------------+          live range of buffer2
+  //
+  // After:
+  //
+  //           +----------+    live range of buffer1
+  //   +------+                live range of buffer2
+  //
+  // Before(buffer1 and 2 are aliased):
+  //
+  //           +----------+    live range of buffer1
+  //   +---+                   live range of buffer2
+  //
+  // After(unchanged):
+  //
+  //           +----------+    live range of buffer1
+  //   +---+                   live range of buffer2
+  //
+  // As another example, imagine we have the following code sequence with live
+  // ranges of each while-aliased buffers:
+  //
+  //                     a      p1    p2    e     b
+  // a = ...             +
+  //                     |
+  // {                   |
+  //   p1 = param        |       +
+  //   ROOT true         |       |
+  // }                   |       +
+  // { // body           |
+  //   p2 = param        +             +
+  //   c = p2 + 1                      +
+  //   d = c + 1
+  //   ROOT e = d + 1                       +
+  // }                                      |
+  //                                        |
+  // b = while (a)                          +     +
+  //                                              |
+  // f = b + 1                                    +
+  //
+  // After normalization it becomes:
+  //
+  //                     a      p1    p2    e     b
+  // a = ...             +
+  //                     |
+  // {                   +
+  //   p1 = param                +
+  //   ROOT true                 |
+  // }                           +
+  // { // body
+  //   p2 = param                      +
+  //   c = p2 + 1                      +
+  //   d = c + 1
+  //   ROOT e = d + 1                       +
+  // }                                      |
+  //                                        |
+  // b = while (a)                          +
+  //                                              +
+  // f = b + 1                                    +
+  //
+  // Note there is no overlap of live ranges after normalization.
+  void NormalizeAliasedBuffers();
+
+  const HloSchedule& schedule_;
+  const HloAliasAnalysis& alias_analysis_;
+  bool module_scoped_analysis_;
+  bool total_order_scheduled_ = true;
+
+  HloInstructionSequence flattened_instruction_sequence_;
+  absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_;
+  absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_;
+  absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_;
+  LogicalTime schedule_end_time_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_live_range_test.cc b/tensorflow/compiler/xla/service/hlo_live_range_test.cc
new file mode 100644
index 0000000..d524d9f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_live_range_test.cc
@@ -0,0 +1,239 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/hlo_live_range.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_value.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using TimeBound = HloLiveRange::TimeBound;
+class HloLiveRangeTest : public HloTestBase {
+ protected:
+  HloLiveRangeTest() : module_(CreateNewVerifiedModule()) {}
+  ~HloLiveRangeTest() override {}
+
+  void Analyze(const HloSchedule& schedule) {
+    alias_analysis_ = HloAliasAnalysis::Run(module_.get()).ValueOrDie();
+    hlo_live_range_ = HloLiveRange::Run(schedule, *alias_analysis_,
+                                        module_->entry_computation())
+                          .ValueOrDie();
+  }
+
+  std::unique_ptr<HloModule> module_;
+  std::unique_ptr<HloLiveRange> hlo_live_range_;
+  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
+  // Shapes for use in the examples.
+  Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
+  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
+
+  // Returns the buffer defined at the given instruction and index.
+  const HloValue* BufferAt(const HloInstruction* instruction,
+                           const ShapeIndex& index) const {
+    return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction,
+                                                                  index);
+  }
+
+  HloLiveRange::TimeBound LiveRangeAt(const HloInstruction* instruction,
+                                      const ShapeIndex& index = {}) const {
+    auto* value = BufferAt(instruction, index);
+    return hlo_live_range_->buffer_live_ranges().at(value);
+  }
+};
+
+TEST_F(HloLiveRangeTest, Multiply) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  module_->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module_.get());
+
+  schedule.set_sequence(module_->entry_computation(), {paramA, paramX, mul});
+
+  Analyze(schedule);
+
+  // Parameters live from beginning to end.
+  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3}));
+  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3}));
+  // Mul lives after parameters are defined to the end.
+  EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 3}));
+}
+
+TEST_F(HloLiveRangeTest, MultiplyAdd) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+  module_->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module_.get());
+
+  schedule.set_sequence(module_->entry_computation(),
+                        {paramA, paramX, mul, paramY, add});
+
+  Analyze(schedule);
+
+  // Parameters live from beginning to end.
+  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5}));
+  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5}));
+  EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 5}));
+  // Mul starts after parameter are defined (Note: all parameters are defined at
+  // 0, mul starts at 2 which is an arbitrary number).
+  EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 4}));
+  // Add lives after mul is defined to the end of the program.
+  EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5}));
+}
+
+TEST_F(HloLiveRangeTest, LiveOutBuffers) {
+  // If a buffer is live out, its life range is extened to the end of
+  // computation.
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
+  module_->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module_.get());
+
+  schedule.set_sequence(module_->entry_computation(),
+                        {paramA, paramX, mul, paramY, add, tuple});
+
+  Analyze(schedule);
+
+  // Parameters live from beginning to end.
+  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6}));
+  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6}));
+  EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 6}));
+  // Mul starts after parameter are defined (Note: all parameters are defined at
+  // 0, mul starts at 2 which is an arbitrary number).
+  EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 6}));
+  // Add lives after mul is defined to the end of the program.
+  EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 6}));
+}
+
+TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) {
+  // If a buffer is live out, its life range is extened to the end of
+  // computation.
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+  auto add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
+  module_->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module_.get());
+
+  // Schedule another instruction after root.
+  schedule.set_sequence(module_->entry_computation(),
+                        {paramA, paramX, mul, paramY, add, tuple, add2});
+
+  Analyze(schedule);
+
+  // Parameters live from beginning to end.
+  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7}));
+  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7}));
+  EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 7}));
+  // Live out buffers live through the computation.
+
+  EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 7}));
+  EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 7}));
+  EXPECT_EQ(LiveRangeAt(tuple), TimeBound({5, 7}));
+  EXPECT_EQ(LiveRangeAt(add2), TimeBound({6, 6}));
+}
+
+TEST_F(HloLiveRangeTest, AliasedParameter) {
+  // If a parameter is non-readonly(non-aliased), its live range can end in the
+  // middle of the program.
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+  module_->AddEntryComputation(builder.Build());
+  // Set up alias of the first parameter.
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      {}, 0, {}, HloInputOutputAliasConfig::kUserAlias));
+
+  HloSchedule schedule(module_.get());
+
+  schedule.set_sequence(module_->entry_computation(),
+                        {paramA, paramX, mul, paramY, add});
+
+  Analyze(schedule);
+
+  // Non-readonly parameter live like other normal buffers.
+  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2}));
+
+  // Readonly parameters live from beginning to end.
+  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5}));
+  EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 5}));
+  // Mul starts after parameter are defined (Note: all parameters are defined at
+  // 0, mul starts at 2 which is an arbitrary number).
+  EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 4}));
+  // Add lives after mul is defined to the end of the program.
+  EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5}));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index fbef51c..508c7a1 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -215,7 +215,7 @@
 
 string HloModule::ToString(const HloPrintOptions& options) const {
   std::ostringstream s;
-  s << "HloModule " << name();
+  s << "HloModule " << PrintName(name(), options.print_ids());
   if (has_schedule()) {
     TF_CHECK_OK(schedule().Verify());
     s << ", is_scheduled=true";
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 950c7a7..ef91284 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -300,6 +300,38 @@
     return &fusion_config_;
   }
 
+  // Checks if this config has a list of entry parameters' HLO shardings for
+  // SPMD.
+  bool has_spmd_parameters_shardings() const {
+    return spmd_parameters_shardings_.has_value();
+  }
+
+  // Getter and setter for the list of entry parameters' HLO shardings for SPMD.
+  const std::vector<HloSharding>& spmd_parameters_shardings() const {
+    CHECK(spmd_parameters_shardings_.has_value());
+    return *spmd_parameters_shardings_;
+  }
+  void set_spmd_parameters_shardings(
+      const std::vector<HloSharding>& shardings) {
+    spmd_parameters_shardings_ = shardings;
+  }
+
+  // Checks if this config has the entry computation output's HLO sharding for
+  // SPMD.
+  bool has_spmd_output_sharding() const {
+    return spmd_output_sharding_.has_value();
+  }
+
+  // Getter and setter for the entry computation output's HLO shardings for
+  // SPMD.
+  const HloSharding& spmd_output_sharding() const {
+    CHECK(spmd_output_sharding_.has_value());
+    return *spmd_output_sharding_;
+  }
+  void set_spmd_output_sharding(const HloSharding& sharding) {
+    spmd_output_sharding_ = sharding;
+  }
+
  private:
   HloComputation* AddComputationInternal(
       std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -342,6 +374,14 @@
 
   // Fusion configuration.
   std::vector<std::vector<bool>> fusion_config_;
+
+  // The HLO shardings of the entry computation's parameters for
+  // SPMD-partitioned programs.
+  absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;
+
+  // The HLO sharding of the entry computation's output (root) for
+  // SPMD-partitioned programs.
+  absl::optional<HloSharding> spmd_output_sharding_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 154cf7f..daeb594 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -208,13 +208,13 @@
   ServiceExecutableRunOptions service_run_options =
       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                     nullptr, RunId());
+  service_run_options.mutable_run_options()->set_execution_profile(profile);
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options,
-                                         /*profile=*/profile, arguments));
+      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
   return std::move(retval);
 }
@@ -244,11 +244,11 @@
   ServiceExecutableRunOptions service_run_options =
       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                     nullptr, RunId());
+  service_run_options.mutable_run_options()->set_execution_profile(profile);
 
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options,
-                                         /*profile=*/profile, arguments));
+      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
   return std::move(retval);
 }
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index ae7ccad..1551870 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -1,5 +1,5 @@
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_static",
 )
 
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 167a013..0dab86d 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -45,7 +45,7 @@
 
 InterpreterExecutable::~InterpreterExecutable() {}
 
-StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
+StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
@@ -113,22 +113,15 @@
 
   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 
-  {
-    tensorflow::mutex_lock lock(mutex_);
+  ExecutionProfile* profile = run_options->run_options().execution_profile();
+  if (profile) {
     const double nanoseconds = (end_micros - start_micros) * 1000.0;
-    execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+    profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
   }
 
   return std::move(result);
 }
 
-StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    absl::Span<const ShapedBuffer* const> arguments) {
-  return tensorflow::errors::Unimplemented(
-      "ExecuteAsyncOnStream is not yet supported on Interpreter.");
-}
-
 /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
   if (shape.IsOpaque()) {
     return sizeof(void*);
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index bda13d3..ba010de 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -46,16 +46,12 @@
                         std::unique_ptr<HloEvaluator> evaluator);
   ~InterpreterExecutable() override;
 
-  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
+  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments,
       HloExecutionProfile* hlo_execution_profile) override
       LOCKS_EXCLUDED(evaluator_lock_);
 
-  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
-      const ServiceExecutableRunOptions* run_options,
-      absl::Span<const ShapedBuffer* const> arguments) override;
-
   static int64 ShapeSizeBytes(const Shape& shape);
 
  protected:
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 6d33768..43493b6 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -58,14 +58,14 @@
     return port::Status::OK();
   }
 
-  bool GetKernel(const MultiKernelLoaderSpec &spec,
-                 KernelBase *kernel) override {
-    return false;
+  port::Status GetKernel(const MultiKernelLoaderSpec &spec,
+                         KernelBase *kernel) override {
+    return port::UnimplementedError("Not Implemented");
   }
-  bool Launch(Stream *stream, const ThreadDim &thread_dims,
-              const BlockDim &block_dims, const KernelBase &kernel,
-              const KernelArgsArrayBase &args) override {
-    return false;
+  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
+                      const BlockDim &block_dims, const KernelBase &kernel,
+                      const KernelArgsArrayBase &args) override {
+    return port::UnimplementedError("Not Implemented");
   }
 
   void *Allocate(uint64 size) override;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 72ffcd2..ddb049b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -619,8 +619,9 @@
         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
         ComputationLayout& branch_computation_layout =
             FindOrDie(computation_layouts_, instruction->branch_computation(k));
-        if (branch_computation_layout.result_layout() !=
-            best_branch_computation_layout.result_layout()) {
+        if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
+                best_branch_computation_layout.result_layout().shape(),
+                /*minor_to_major_only=*/true)) {
           computation_layouts_.erase(instruction->branch_computation(k));
           InsertOrDie(&conditional_mismatch_,
                       instruction->branch_computation(k),
@@ -715,8 +716,10 @@
     absl::Span<const ComputationLayout> branch_computation_layouts) {
   for (int j = 0; j < instruction->branch_count(); ++j) {
     const HloInstruction* branch_operand = instruction->operand(j + 1);
-    TF_RET_CHECK(branch_computation_layouts[0].result_layout() ==
-                 branch_computation_layouts[j].result_layout());
+    TF_RET_CHECK(
+        branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
+            branch_computation_layouts[j].result_layout().shape(),
+            /*minor_to_major_only=*/true));
     TF_RET_CHECK(
         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
             instruction->shape(), /*minor_to_major_only=*/true));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index ffb2df9..9ffb120 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -151,10 +151,9 @@
 Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
   indexed_generators_[parameter] =
       [=](const IrArray::Index& index) -> llvm::Value* {
-    if (tiled_parameter_info_) {
-      if (llvm::Value* param_tile_buffer =
-              tiled_parameter_info_->GetBufferForParameter(
-                  parameter->parameter_number())) {
+    int64 param_num = parameter->parameter_number();
+    if (param_shmem_buffers_.size() > param_num) {
+      if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) {
         // TODO(jlebar): Add AA metadata to this load.  Tile buffers are global
         // variables, so LLVM's points-to analysis doesn't help us much.  And we
         // want the AA info to be present before address spaces are inferred
@@ -162,13 +161,12 @@
         // address-space-based AA in LLVM, it wouldn't help us much here.
         return b_->CreateLoad(
             b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
-                                              tiled_parameter_info_->x(),
-                                              tiled_parameter_info_->y()}),
+                                              tile_param_x_, tile_param_y_}),
             "tiled_buffer");
       }
     }
-    return GetIrArrayForFusedParameter(parameter->parameter_number())
-        .EmitReadArrayElement(index, b_);
+    return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index,
+                                                                       b_);
   };
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index b1aa6d5..9b02714 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -60,10 +60,16 @@
       std::function<std::vector<llvm_ir::IrArray>()>;
 
   FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
-                 ElementalIrEmitter* elemental_emitter)
+                 ElementalIrEmitter* elemental_emitter,
+                 llvm::Value* tile_param_x = nullptr,
+                 llvm::Value* tile_param_y = nullptr,
+                 absl::Span<llvm::Value* const> param_shmem_buffers = {})
       : operand_arrays_(),
         operand_arrays_generator_(std::move(operand_arrays_generator)),
-        tiled_parameter_info_(nullptr),
+        tile_param_x_(tile_param_x),
+        tile_param_y_(tile_param_y),
+        param_shmem_buffers_(param_shmem_buffers.begin(),
+                             param_shmem_buffers.end()),
         elemental_emitter_(elemental_emitter),
         b_(elemental_emitter->b()),
         module_(elemental_emitter->module()) {}
@@ -87,10 +93,6 @@
   // Returns the generator function for the given instruction.
   IndexedGenerator GetGenerator(const HloInstruction* instruction) const;
 
-  void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) {
-    tiled_parameter_info_ = info;
-  }
-
   // Evaluates whether fusing 'producer' into 'consumer' might cause exponential
   // behavior in FusedIrEmitter. We currently can have exponential time/memory
   // requirements for emitting certain fusion kernels, in which case we don't
@@ -118,7 +120,15 @@
   absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_;
   GeneratorForOperandIrArrays operand_arrays_generator_;
 
-  const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
+  // The x coordinate within a tile.
+  llvm::Value* tile_param_x_;
+
+  // The y coordinate within a tile.
+  llvm::Value* tile_param_y_;
+
+  // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
+  // if the parameter is not tiled.
+  std::vector<llvm::Value*> param_shmem_buffers_;
 
   ElementalIrEmitter* elemental_emitter_;
 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
index 02c7195..fe54768 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -255,6 +255,11 @@
   }
 
   template <class... Args>
+  llvm::Value* FCmpULT(Args&&... args) {
+    return mixin_builder()->CreateFCmpULT(std::forward<Args>(args)...);
+  }
+
+  template <class... Args>
   llvm::Value* FCmpOLE(Args&&... args) {
     return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
   }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index 80f4221..6321594 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -204,34 +204,6 @@
   bool dilated_x_;
 };
 
-// A class to represent information for tiled parameters to support IR emission
-// for 021 transpose.
-class TiledParameterInfo {
- public:
-  TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
-                     llvm::Value* y, llvm::Value* x)
-      : param_buffers_(param_buffers), y_(y), x_(x) {}
-
-  llvm::Value* x() const { return x_; }
-  llvm::Value* y() const { return y_; }
-
-  void set_x(llvm::Value* x) { x_ = x; }
-  void set_y(llvm::Value* y) { y_ = y; }
-
-  llvm::Value* GetBufferForParameter(int64 index) const {
-    return param_buffers_[index];
-  }
-
- private:
-  // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
-  // if the parameter is not tiled.
-  absl::Span<llvm::Value* const> param_buffers_;
-  // The y coordinate within a tile.
-  llvm::Value* y_;
-  // The x coordinate within a tile.
-  llvm::Value* x_;
-};
-
 }  // namespace llvm_ir
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index f08cf01..9c5fe0c 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -65,6 +65,16 @@
       continue;
     }
 
+    // If the buffer is a tuple, don't use this algorithm for now. The buffers
+    // that are pointed to by the tuple will still use this algorithm.
+    // TODO(berkin): Because tuples are cheap to place in the alternate memory
+    // (they are just pointers) we don't need to use prefetch/evict logic.
+    if (buffer.values()[0]->shape().IsTuple()) {
+      VLOG(4) << "Keeping buffer " << buffer.ToString()
+              << " in default mem because it is a tuple.";
+      continue;
+    }
+
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
     bool keep_in_default_memory = false;
     for (const BufferInterval* colocated_interval : colocated_intervals) {
@@ -90,12 +100,18 @@
       const HloValue* value = colocated_interval->buffer;
       int64 definition_time =
           instruction_schedule_->at(value->defining_instruction());
+      // Sort the uses by the use time.
+      std::vector<HloUse> uses = value->uses();
+      absl::c_sort(uses, [&](HloUse use1, HloUse use2) {
+        return instruction_schedule_->at(use1.instruction) <
+               instruction_schedule_->at(use2.instruction);
+      });
       // Iterate over the uses.
-      for (HloUse use : value->uses()) {
+      for (HloUse use : uses) {
         int64 use_time = instruction_schedule_->at(use.instruction);
 
-        FindAllocation(definition_time, use_time, use, *colocated_interval,
-                       allocation_sequence);
+        FindAllocation(definition_time, use_time, value->defining_position(),
+                       use, *colocated_interval, allocation_sequence);
         // If there are multiple uses, they can try using the memory allocation
         // already at the alternate memory.
         definition_time = use_time;
@@ -126,10 +142,10 @@
 }
 
 void AlternateMemoryBestFitHeap::FindAllocation(
-    int64 start_time, int64 end_time, HloUse use,
+    int64 start_time, int64 end_time, HloPosition defining_position, HloUse use,
     const BufferInterval& interval,
     MemorySpaceAssignment::AllocationSequence* allocations) {
-  HloInstruction* def_instruction =
+  HloInstruction* operand =
       use.instruction->mutable_operand(use.operand_number);
   // Create an alternate memory interval that starts at the earliest
   // possible position, given by max_prefetch_interval.
@@ -143,6 +159,7 @@
   VLOG(2) << "Finding allocation for " << interval.buffer->ToShortString()
           << " (" << start_time << ", " << end_time
           << "). Size = " << interval.size;
+  CHECK_LT(start_time, end_time);
 
   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
   bool can_eliminate_copy = false;
@@ -186,13 +203,13 @@
       // If there was a previous allocation, the buffer location is the
       // same as the previous. Otherwise, it is the operand.
       if (prev_allocation != nullptr &&
-          prev_allocation->defining_instruction() == def_instruction) {
+          prev_allocation->instruction() == operand) {
         prev_allocation->Extend(end_time);
       } else {
         allocations->push_back(
             absl::make_unique<MemorySpaceAssignment::Allocation>(
-                def_instruction, MemorySpace::kAlternate, chunk_candidate.chunk,
-                start_time, end_time));
+                operand, defining_position, MemorySpace::kAlternate,
+                chunk_candidate.chunk, start_time, end_time));
       }
       allocations->back()->AddUse(use);
       return;
@@ -203,7 +220,7 @@
   // memory space.
   if (prev_allocation != nullptr &&
       prev_allocation->memory_space() == MemorySpace::kAlternate &&
-      prev_allocation->defining_instruction() == def_instruction) {
+      prev_allocation->instruction() == operand) {
     // If there was an allocation for this HloValue that was in the alternate
     // memory space, we also need to perform an eviction.
     // TODO(berkin): For now evictions happen relative to the most recent
@@ -231,15 +248,15 @@
             end_time, earliest_instruction, latest_instruction));
   } else if (prev_allocation != nullptr &&
              prev_allocation->memory_space() == MemorySpace::kDefault &&
-             prev_allocation->defining_instruction() == def_instruction) {
+             prev_allocation->instruction() == operand) {
     // If the previous allocation was in the default memory space and was
     // defined by the same instruction, extend that.  Otherwise, create a new
     // allocation.
     prev_allocation->Extend(end_time);
   } else {
     allocations->push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
-        def_instruction, MemorySpace::kDefault, kDefaultMemorySpaceDummyChunk,
-        start_time, end_time));
+        operand, defining_position, MemorySpace::kDefault,
+        kDefaultMemorySpaceDummyChunk, start_time, end_time));
   }
 
   // Try partially placing the buffer in the alternate space. The time that is
@@ -293,7 +310,8 @@
   allocations->back()->AddUse(use);
 }
 
-/*static*/ StatusOr<bool> MemorySpaceAssignment::Run(
+/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
+MemorySpaceAssignment::Run(
     HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes,
     int64 min_prefetch_interval, int64 max_prefetch_interval,
     int64 alternate_memory_space_alignment_in_bytes,
@@ -301,7 +319,8 @@
     AlternateMemoryBestFitHeap::IsAllowedInAlternateMemoryFunction
         is_allowed_in_alternate_mem) {
   CHECK(module->has_schedule());
-  VLOG(4) << "Module before memory space assignment: " << module->ToString();
+  VLOG(4) << "Module before memory space assignment: ";
+  XLA_VLOG_LINES(4, module->ToString());
   VLOG(4) << "Schedule: " << module->schedule().ToString();
   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
 
@@ -322,11 +341,11 @@
   TF_RETURN_IF_ERROR(memory_space_assignment.Process());
   TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule());
 
-  VLOG(4) << "Module after memory space assignment: " << module->ToString();
-  VLOG(4) << "Schedule: " << module->schedule().ToString();
+  VLOG(4) << "Module after memory space assignment: ";
+  XLA_VLOG_LINES(4, module->ToString());
   TF_CHECK_OK(module->schedule().Verify());
 
-  return true;
+  return std::move(memory_space_assignment.preset_assignments_);
 }
 
 Status MemorySpaceAssignment::Allocation::Process(
@@ -334,7 +353,7 @@
   // For non-copy allocations, all we need to do is to update the output memory
   // space if placed in the alternate memory.
   if (memory_space_ == MemorySpace::kAlternate) {
-    Layout* layout = defining_instruction_->mutable_shape()->mutable_layout();
+    Layout* layout = instruction_->mutable_shape()->mutable_layout();
     layout->set_memory_space(memory_space_assignment->alternate_memory_space_);
   }
   return Status::OK();
@@ -343,11 +362,11 @@
 Status MemorySpaceAssignment::CopyAllocation::Process(
     MemorySpaceAssignment* memory_space_assignment) {
   // Copy allocations need to insert asynchronous copy nodes.
-  HloInstruction* def_instruction = defining_instruction();
-  CHECK_NE(def_instruction, nullptr);
+  HloInstruction* producing_instruction = instruction();
+  CHECK_NE(producing_instruction, nullptr);
 
-  Shape shape = def_instruction->shape();
-  HloComputation* computation = def_instruction->parent();
+  Shape shape = producing_instruction->shape();
+  HloComputation* computation = producing_instruction->parent();
 
   // Set the layout to include the memory space.
   Layout* layout = shape.mutable_layout();
@@ -360,12 +379,15 @@
   HloInstruction* copy_start =
       computation->AddInstruction(HloInstruction::CreateUnary(
           ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
-          HloOpcode::kCopyStart, def_instruction));
+          HloOpcode::kCopyStart, producing_instruction));
   HloInstruction* copy_done = computation->AddInstruction(
       HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start));
-  // Update the allocation with the defining instruction so that if there
+  // Update the allocation with the copy done instruction so that if there
   // are further copies from it, it can find the correct instruction.
-  defining_instruction_ = copy_done;
+  instruction_ = copy_done;
+  // Also update the defining position. Note that the output of CopyDone is
+  // actually defined in the item {0} of CopyStart.
+  defining_position_ = HloPosition{copy_start, {0}};
 
   // Replace all the uses with the new copy instruction.
   for (HloUse use : uses_) {
@@ -383,9 +405,39 @@
 
 Status MemorySpaceAssignment::Process() {
   // Insert CopyStart/CopyDone pairs.
+  int64 alternate_memory_size = 0;
+  HloPosition prev_defining_position{nullptr, {}};
   for (auto& buffer_and_sequence : allocation_map_) {
     for (auto& allocation : buffer_and_sequence.second) {
       TF_RETURN_IF_ERROR(allocation->Process(this));
+      // Add the offset and size of the allocation in the alternate memory to
+      // the output map. Ensure there is one entry for each position in the
+      // preset assignments.
+      if (allocation->memory_space() == MemorySpace::kAlternate &&
+          prev_defining_position != allocation->defining_position()) {
+        preset_assignments_->add_chunk(allocation->defining_position(),
+                                       allocation->chunk());
+        alternate_memory_size =
+            std::max(alternate_memory_size, allocation->chunk().chunk_end());
+        prev_defining_position = allocation->defining_position();
+      }
+    }
+  }
+
+  if (!preset_assignments_->chunks().empty()) {
+    preset_assignments_->add_size(alternate_memory_space_,
+                                  alternate_memory_size);
+  }
+
+  if (VLOG_IS_ON(3)) {
+    VLOG(3) << "Exported alternate memory allocations:";
+    for (auto& pair : preset_assignments_->chunks()) {
+      VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size
+              << "] : " << pair.first.ToString();
+    }
+    VLOG(3) << "Exported alternate memory sizes:";
+    for (auto& pair : preset_assignments_->sizes()) {
+      VLOG(3) << "  space: " << pair.first << ", size: " << pair.second;
     }
   }
   return Status::OK();
@@ -398,28 +450,51 @@
   schedule_before_[copy_done_schedule_before].push_back(copy_done);
 }
 
+void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
+    HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
+    absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
+  if (inserted_instructions->contains(new_instruction)) {
+    return;
+  }
+  for (HloInstruction* operand : new_instruction->operands()) {
+    EnsureInstructionAndOperandsInserted(operand, new_sequence,
+                                         inserted_instructions);
+  }
+  VLOG(4) << "inserting: " << new_instruction->ToString();
+  new_sequence->push_back(new_instruction);
+  inserted_instructions->insert(new_instruction);
+}
+
 Status MemorySpaceAssignment::FixSchedule() {
   CHECK(module_->has_schedule());
   HloSchedule& schedule = module_->schedule();
-  for (const HloComputation* computation : module_->computations()) {
+  for (const HloComputation* computation :
+       module_->MakeNonfusionComputations()) {
+    CHECK(schedule.is_computation_scheduled(computation));
     const HloInstructionSequence& sequence = schedule.sequence(computation);
     HloInstructionSequence new_sequence;
 
+    absl::flat_hash_set<HloInstruction*> inserted_instructions;
+
     for (HloInstruction* instruction : sequence.instructions()) {
       auto insts_before_iter = schedule_before_.find(instruction);
       if (insts_before_iter != schedule_before_.end()) {
         for (HloInstruction* new_instruction : insts_before_iter->second) {
-          new_sequence.push_back(new_instruction);
-          VLOG(4) << "before: " << new_instruction->ToString();
+          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                               &inserted_instructions);
         }
       }
-      new_sequence.push_back(instruction);
-      VLOG(4) << instruction->ToString();
+      // Insert only if not previously inserted.
+      if (!inserted_instructions.contains(instruction)) {
+        new_sequence.push_back(instruction);
+        inserted_instructions.insert(instruction);
+        VLOG(4) << instruction->ToString();
+      }
       auto insts_after_iter = schedule_after_.find(instruction);
       if (insts_after_iter != schedule_after_.end()) {
         for (HloInstruction* new_instruction : insts_after_iter->second) {
-          new_sequence.push_back(new_instruction);
-          VLOG(4) << "after: " << new_instruction->ToString();
+          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                               &inserted_instructions);
         }
       }
     }
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 5560130..9ddd661 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -21,6 +21,36 @@
 
 namespace xla {
 
+// This class contains pre-set assignments determined by memory space
+// assignment. It contains two data structures: (1) a chunks vector that maps a
+// defining HloPosition to a Chunk (offset and size), and (2) a sizes vector
+// that maps the memory space to its size. If there is only one alternate memory
+// space like there is currently, there will be one entry in sizes.
+class PresetAssignments {
+ public:
+  PresetAssignments() = default;
+
+  void add_chunk(const HloPosition& position,
+                 const HeapSimulator::Chunk& chunk) {
+    chunks_.emplace_back(position, chunk);
+  }
+
+  void add_size(int64 memory_space, int64 size) {
+    sizes_.emplace_back(memory_space, size);
+  }
+
+  absl::Span<const std::pair<const HloPosition, const HeapSimulator::Chunk>>
+  chunks() const {
+    return chunks_;
+  }
+
+  absl::Span<const std::pair<int64, int64>> sizes() const { return sizes_; }
+
+ private:
+  std::vector<std::pair<const HloPosition, const HeapSimulator::Chunk>> chunks_;
+  std::vector<std::pair<int64, int64>> sizes_;
+};
+
 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each
 // instruction in the module. It will greedily try placing as as many values in
 // the alternate memory space as possible. It uses the heap simulator to
@@ -69,9 +99,11 @@
   //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
   class Allocation {
    public:
-    Allocation(HloInstruction* defining_instruction, MemorySpace memory_space,
-               Chunk chunk, int64 start_time, int64 end_time)
-        : defining_instruction_(defining_instruction),
+    Allocation(HloInstruction* instruction, HloPosition defining_position,
+               MemorySpace memory_space, Chunk chunk, int64 start_time,
+               int64 end_time)
+        : instruction_(instruction),
+          defining_position_(defining_position),
           memory_space_(memory_space),
           chunk_(chunk),
           start_time_(start_time),
@@ -89,10 +121,13 @@
     // insert asynchronous copy instructions if necessary.
     virtual Status Process(MemorySpaceAssignment* memory_space_assignment);
 
-    // Returns the defining instruction for this allocation.
-    virtual HloInstruction* defining_instruction() const {
-      return defining_instruction_;
-    }
+    // Returns the instruction that produces this allocation. It might be
+    // different than the instruction in defining_position (e.g., a
+    // GetTupleElement instruction does not define the buffer).
+    virtual HloInstruction* instruction() const { return instruction_; }
+
+    // Returns the defining position for this allocation.
+    HloPosition defining_position() const { return defining_position_; }
 
     const std::vector<HloUse>& uses() const { return uses_; }
     MemorySpace memory_space() const { return memory_space_; }
@@ -101,7 +136,8 @@
     int64 end_time() const { return end_time_; }
 
    protected:
-    HloInstruction* defining_instruction_;
+    HloInstruction* instruction_;
+    HloPosition defining_position_;
     std::vector<HloUse> uses_;
     MemorySpace memory_space_;
     Chunk chunk_;
@@ -116,7 +152,8 @@
                    Chunk chunk, int64 start_time, int64 end_time,
                    HloInstruction* copy_start_schedule_after,
                    HloInstruction* copy_done_schedule_before)
-        : Allocation(/*defining_instruction=*/nullptr, memory_space, chunk,
+        : Allocation(/*instruction=*/nullptr,
+                     /*defining_position=*/{nullptr, {}}, memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
           copy_start_schedule_after_(copy_start_schedule_after),
@@ -124,13 +161,13 @@
 
     Status Process(MemorySpaceAssignment* memory_space_assignment) override;
 
-    HloInstruction* defining_instruction() const override {
-      // Unless explicitly set, the defining instruction of a copy allocation in
+    HloInstruction* instruction() const override {
+      // Unless explicitly set, the instruction of a copy allocation in
       // retrieved from the previous allocation.
-      if (defining_instruction_ != nullptr) {
-        return defining_instruction_;
+      if (instruction_ != nullptr) {
+        return instruction_;
       } else {
-        return prev_allocation_.defining_instruction();
+        return prev_allocation_.instruction();
       }
     }
 
@@ -159,7 +196,7 @@
   // HloValues (e.g., based on the opcode) to be placed on the alternate memory.
   // TODO(berkin): Use the cost model instead of using number of instructions to
   // decide how early to prefetch.
-  static StatusOr<bool> Run(
+  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
       HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes,
       int64 min_prefetch_interval, int64 max_prefetch_interval,
       int64 alternate_memory_space_alignment_in_bytes,
@@ -168,7 +205,9 @@
 
  private:
   MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space)
-      : module_(module), alternate_memory_space_(alternate_memory_space) {}
+      : module_(module),
+        alternate_memory_space_(alternate_memory_space),
+        preset_assignments_(absl::make_unique<PresetAssignments>()) {}
 
   // Process calls Process methods of the allocations after the allocations have
   // been finalized.
@@ -177,6 +216,13 @@
   // FixSchedule inserts asynchronous copies in the schedule.
   Status FixSchedule();
 
+  // Insert an instruction to the schedule, and make sure its dependencies
+  // (operands) are already in the schedule. If not, insert these operands
+  // before the instruction.
+  void EnsureInstructionAndOperandsInserted(
+      HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
+      absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;
+
   // Schedules a pair of asynchronous copy instructions (copy_start and
   // copy_done) where copy_start will be scheduled after the instruction in
   // copy_start_schedule_after and copy_done will be scheduled before the
@@ -189,6 +235,7 @@
   HloModule* module_;
   int64 alternate_memory_space_;
   AllocationMap allocation_map_;
+  std::unique_ptr<PresetAssignments> preset_assignments_;
 
   // These maps hold vectors of new instructions that need to be scheduled after
   // (or before) the instruction in the key. FixSchedule uses these maps to
@@ -229,7 +276,8 @@
   // limits, and append the new allocation(s) to allocations. The new
   // allocations can be in default or alternate memory spaces, or can be
   // prefetches or evictions.
-  void FindAllocation(int64 start_time, int64 end_time, HloUse use,
+  void FindAllocation(int64 start_time, int64 end_time,
+                      HloPosition defining_position, HloUse use,
                       const BufferInterval& interval,
                       MemorySpaceAssignment::AllocationSequence* allocations);
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 5d6d0c8..b5d8cb4 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -31,7 +31,7 @@
   const int64 kDefaultMemorySpace = 0;
   const int64 kAlternateMemorySpace = 1;
 
-  void AssignMemorySpace(HloModule* module) {
+  std::unique_ptr<PresetAssignments> AssignMemorySpace(HloModule* module) {
     auto size_fn = [](const BufferValue& buffer) {
       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
     };
@@ -49,13 +49,14 @@
       return true;
     };
 
-    ASSERT_IS_OK(MemorySpaceAssignment::Run(
-                     module, kAlternateMemorySpace, /*max_size_in_bytes=*/128,
-                     /*min_prefetch_interval=*/2,
-                     /*max_prefetch_interval=*/10,
-                     /*alternate_memory_space_alignment_in_bytes=*/8, size_fn,
-                     is_allowed_in_alternate_mem)
-                     .status());
+    return std::move(MemorySpaceAssignment::Run(
+                         module, kAlternateMemorySpace,
+                         /*max_size_in_bytes=*/128,
+                         /*min_prefetch_interval=*/2,
+                         /*max_prefetch_interval=*/10,
+                         /*alternate_memory_space_alignment_in_bytes=*/8,
+                         size_fn, is_allowed_in_alternate_mem)
+                         .ValueOrDie());
   }
 };
 
@@ -103,7 +104,7 @@
   schedule.set_sequence(computation, {p0, p1, add, sub, mul});
   TF_CHECK_OK(module->set_schedule(schedule));
 
-  AssignMemorySpace(module.get());
+  auto preset_assignments = AssignMemorySpace(module.get());
 
   // Inputs and outputs are currently placed in the default memory. Everything
   // else should be in the alternate memory.
@@ -116,6 +117,10 @@
   EXPECT_THAT(mul, op::ShapeWithLayout(shape));
   EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
   EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));
+
+  // Make sure the preset assignments is sane.
+  EXPECT_THAT(preset_assignments->chunks().size(), 2);
+  EXPECT_THAT(preset_assignments->sizes().size(), 1);
 }
 
 TEST_F(MemorySpaceAssignmentTest, NegateChain) {
@@ -338,5 +343,98 @@
   EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
+TEST_F(MemorySpaceAssignmentTest, Tuple) {
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
+  Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});
+  HloInstruction* p = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p"));
+  HloInstruction* p0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 0));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* p1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 1));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
+  HloInstruction* p2 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2));
+  HloInstruction* p2_0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p2, 0));
+  HloInstruction* mul = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
+                    negate6, p1, add, p2, p2_0, mul});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get());
+
+  EXPECT_THAT(
+      mul,
+      op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
+                                                       kDefaultMemorySpace,
+                                                       op::GetTupleElement())),
+                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
+                                 op::GetTupleElement(op::GetTupleElement()))));
+}
+
+TEST_F(MemorySpaceAssignmentTest, Bitcast) {
+  // Bitcasts can cause the position in the alternate memory to appear multiple
+  // times in the preset assignments. This test ensure the preset assignments
+  // refer to unique positions.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* p1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
+  HloInstruction* negate = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* bitcast =
+      builder.AddInstruction(HloInstruction::CreateBitcast(shape, negate));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, p1));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation, {p0, p1, negate, bitcast, add});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  auto preset_assignments = AssignMemorySpace(module.get());
+
+  // Ensure the positions are unique. Note that we're using a std::set instead
+  // of absl::flat_hash_set because we can make use of HloPosition's comparator
+  // logic instead of providing a hasher.
+  std::set<HloPosition> positions_in_preset_assignments;
+  for (auto& position_and_chunk : preset_assignments->chunks()) {
+    HloPosition position = position_and_chunk.first;
+    EXPECT_EQ(positions_in_preset_assignments.find(position),
+              positions_in_preset_assignments.end());
+    positions_in_preset_assignments.insert(position);
+  }
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index 72ca402..ef9cf37 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -49,3 +49,16 @@
     ],
     alwayslink = True,  # Contains compiler registration
 )
+
+cc_library(
+    name = "hlo_dialect_emitter",
+    srcs = ["hlo_dialect_emitter.cc"],
+    hdrs = ["hlo_dialect_emitter.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:hlo",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@local_config_mlir//:IR",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
new file mode 100644
index 0000000..b2a2bd2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
@@ -0,0 +1,42 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+HloDialectEmitter::HloDialectEmitter(const HloModule& hlo_module,
+                                     const BufferAssignment& assignment,
+                                     ::mlir::ModuleOp mlir_module)
+    : mlir_module_(mlir_module), builder_(mlir_module_.getContext()) {}
+
+Status DefaultAction(HloInstruction* hlo) {
+  LOG(FATAL) << "Not implemented yet.";
+}
+
+Status HandleFusion(HloInstruction* fusion) {
+  LOG(FATAL) << "Not implemented yet.";
+}
+
+Status HandleCustomCall(HloInstruction* custom_call) {
+  LOG(FATAL) << "Not implemented yet.";
+}
+
+Status FinishVisit(HloInstruction* root) {
+  LOG(FATAL) << "Not implemented yet.";
+}
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
new file mode 100644
index 0000000..622b931
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
@@ -0,0 +1,64 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+namespace gpu {
+
+// This class is the top-level API for the HLO --> HLO dialect compiler. It
+// implements the DfsHloVisitor interface and emits HLO computations as MLIR IR
+// functions.
+class HloDialectEmitter : public DfsHloVisitorWithDefault {
+ public:
+  HloDialectEmitter(const HloModule& hlo_module,
+                    const BufferAssignment& assignment,
+                    ::mlir::ModuleOp mlir_module);
+  ~HloDialectEmitter() override = default;
+
+  // The following methods implement the DfsHloVisitor interface.
+  //
+  // Default action which emits code for most operations. Operations which are
+  // special in some way are handled explicitly in HandleFoo methods.
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleCustomCall(HloInstruction* custom_call) override;
+
+  Status FinishVisit(HloInstruction* root) override;
+
+ private:
+  ::mlir::ModuleOp mlir_module_;
+  ::mlir::Builder builder_;
+  absl::flat_hash_map<const xla::HloComputation*, ::mlir::FuncOp>
+      computation_to_mlir_function_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HloDialectEmitter);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_HLO_DIALECT_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 582e593..6c31f6b 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -123,7 +123,6 @@
   if (fused->IsMultiOutputFusion()) {
     std::swap(remaining, fused);
   }
-
   if (fused->opcode() == HloOpcode::kFusion) {
     remaining->MergeFusionInstructionIntoMultiOutput(fused);
   } else {
@@ -249,14 +248,12 @@
       multioutput_user_is_not_gte(instr2)) {
     return false;
   }
-
   if (is_connected(instr1, instr2)) {
     return false;
   }
   if (!ShapesCompatibleForFusion(instr1, instr2)) {
     return false;
   }
-
   return true;
 }
 
@@ -339,4 +336,5 @@
 }
 
 bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; }
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 3d129c4..9000370 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -40,8 +40,8 @@
 //      fused and their fusion profit scores.
 //
 //  Function Perform() applies the optimization. It picks up the most profitable
-//  pair in the worklist_, check if it's legal to fuse and fuse the pair.
-//  After fusion, it updates the associated structure such as reachability_,
+//  pair in the worklist_, checks if it's legal to fuse and fuses the pair.
+//  After fusion, it updates the associated structures such as reachability_,
 //  candidates_ and worklist_.
 //  Note that the reachability map is updated based on the original computation.
 //  This works because the reachability is monotonically increasing with
@@ -105,13 +105,6 @@
   virtual bool DoProducerConsumerMultiOutputFusion();
 
  private:
-  // Update the internal data structures after instr1 and instr2 are fused into
-  // one fusion instruction.
-  void Update(HloInstruction* instr1, HloInstruction* instr2);
-
-  // Computation for the pass.
-  HloComputation* computation_;
-
   // An internal data structure for each instruction in current computation.
   // When an instruction is removed, member 'hlo' is set to nullptr.
   struct FusionCandidate {
@@ -119,16 +112,6 @@
     std::list<std::pair<HloInstruction*, int64>> fusibles;
     explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {}
   };
-  std::vector<FusionCandidate> candidates_;
-
-  // A map that maps an instruction to the index_.
-  absl::flat_hash_map<HloInstruction*, int> candidates_index_;
-
-  // The reachability map of current computation.
-  std::unique_ptr<HloReachabilityMap> reachability_;
-
-  // This stores all the candidate instructions in current computation.
-  std::vector<HloInstruction*> all_fusion_candidates_;
 
   // The pair of candidates to be fused and the profit score.
   struct ToBeFused {
@@ -139,7 +122,10 @@
         : instr1(instr1), instr2(instr2), score(score) {}
     bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
   };
-  std::priority_queue<ToBeFused> worklist_;
+
+  // Update the internal data structures after instr1 and instr2 are fused into
+  // one fusion instruction.
+  void Update(HloInstruction* instr1, HloInstruction* instr2);
 
   int64 get_candidate_id(HloInstruction* instr) {
     return FindOrDie(candidates_index_, instr);
@@ -156,6 +142,21 @@
   bool is_connected(HloInstruction* instr1, HloInstruction* instr2) {
     return reachability_->IsConnected(instr1, instr2);
   }
+
+  std::vector<FusionCandidate> candidates_;
+  std::priority_queue<ToBeFused> worklist_;
+
+  // A map that maps an instruction to the index_.
+  absl::flat_hash_map<HloInstruction*, int> candidates_index_;
+
+  // The reachability map of current computation.
+  std::unique_ptr<HloReachabilityMap> reachability_;
+
+  // This stores all the candidate instructions in current computation.
+  std::vector<HloInstruction*> all_fusion_candidates_;
+
+  // Computation for the pass.
+  HloComputation* computation_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index db2cd28..741e07c 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -1779,30 +1779,19 @@
       return true;
     }
 
-    // Check that literal == static_cast<LitearlTy>(val) and
-    // val == static_cast<ValTy>(literal).  This is sufficient to ensure that
-    // the two constant scalars are actually "equal".
-    auto val_literal = LiteralUtil::CreateR0(*val_);
-    auto literal_r0_or = const_inst->literal().Reshape({});
-    auto val_as_literal_ty_or =
-        val_literal.Convert(const_inst->shape().element_type());
-    if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) {
-      EXPLAIN << "could not construct relevant Literals (how did this happen?)";
+    auto const_inst_scalar_or = const_inst->literal().Reshape({});
+    if (!const_inst_scalar_or.ok()) {
+      EXPLAIN << "could not convert matched literal to effective scalar";
       return false;
     }
-    auto literal_r0 = std::move(literal_r0_or).ValueOrDie();
-    auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie();
-    auto literal_r0_as_val_ty_or =
-        literal_r0.Convert(val_literal.shape().element_type());
-    bool rv = literal_r0_as_val_ty_or.ok() &&  //
-              literal_r0_as_val_ty_or.ValueOrDie() == val_literal &&
-              literal_r0 == val_as_literal_ty;
-    if (!rv) {
+    Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
+    if (!const_inst_scalar.IsEqualAt({}, *val_)) {
       EXPLAIN << "HloInstruction's constant value "
-              << literal_r0.ToStringWithoutShape()
+              << const_inst_scalar.ToStringWithoutShape()
               << " did not match expected value " << *val_;
+      return false;
     }
-    return rv;
+    return true;
   }
 
   absl::optional<ScalarTy> val_;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 5ec45eb..1353c00 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -351,11 +351,11 @@
   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
 
   // Dump computation proto state if flag is set.
-  std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
+  std::vector<std::unique_ptr<HloProto>> hlo_protos;
   for (int64 i = 0; i < module_protos.size(); ++i) {
-    auto hlo_snapshot = absl::make_unique<HloSnapshot>();
-    *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
-    hlo_snapshots.push_back(std::move(hlo_snapshot));
+    auto hlo_proto = absl::make_unique<HloProto>();
+    *hlo_proto->mutable_hlo_module() = *module_protos[i];
+    hlo_protos.push_back(std::move(hlo_proto));
   }
 
   VLOG(1) << "Computations:";
@@ -383,7 +383,7 @@
     const auto& debug_opts = module_configs[i]->debug_options();
     if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
         debug_opts.xla_dump_hlo_snapshots()) {
-      executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i]));
+      executables[i]->set_hlo_proto(std::move(hlo_protos[i]));
     }
   }
 
@@ -451,13 +451,19 @@
       options.set_intra_op_thread_pool(
           backend->eigen_intra_op_thread_pool_device());
       options.set_device_assignment(&device_assignment);
+      // Use run-time profile information from execution_profile on the 0th
+      // device.
+      if (i == 0) {
+        options.set_execution_profile(profile);
+      }
       ServiceExecutableRunOptions run_options(options,
                                               backend->StreamBorrower());
 
       // Asynchronously launch the computation.
       TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                           executables[i]->ExecuteAsyncOnStream(
-                              &run_options, arguments[i][replica]));
+                              &run_options, arguments[i][replica],
+                              /*hlo_execution_profile=*/nullptr));
 
       if (replica == 0 && profile != nullptr) {
         streams.back()->ThenStopTimer(timers.back().get());
@@ -490,10 +496,6 @@
     uint64 nanoseconds =
         *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
 
-    // Merge in run-time profile information from execution_profile on the
-    // zeroth device.
-    profile->MergeFrom(executables[0]->execution_profile());
-
     // Overall execution time (in nanoseconds) from the executor timer.
     profile->set_compute_and_transfer_time_ns(nanoseconds);
 
@@ -546,13 +548,13 @@
     options.set_intra_op_thread_pool(
         backend->eigen_intra_op_thread_pool_device());
     options.set_device_assignment(&device_assignment);
+    options.set_execution_profile(profile);
     run_options.emplace_back(options, backend->StreamBorrower());
   }
 
   if (options_.number_of_replicas() == 1) {
-    TF_ASSIGN_OR_RETURN(
-        auto result, executable->ExecuteOnStreamWrapper(&run_options[0],
-                                                        profile, arguments[0]));
+    TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
+                                         &run_options[0], arguments[0]));
     return allocation_tracker_.Register(std::move(result), result_tag);
   }
 
@@ -692,14 +694,17 @@
     executable_ptrs.push_back(executable.get());
   }
 
+  std::vector<HloSnapshot> snapshots;
+  snapshots.resize(executable_ptrs.size());
   for (int i = 0; i < executable_ptrs.size(); i++) {
     if (executable_ptrs[i]->dumping_snapshot()) {
+      *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
       TF_ASSIGN_OR_RETURN(auto stream,
                           execute_backend_->BorrowStream(
                               all_executors[i][0]->device_ordinal()));
       TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                          execute_backend_->transfer_manager(),
-                                         executable_ptrs[i]->hlo_snapshot()));
+                                         &snapshots[i]));
     }
   }
 
@@ -746,9 +751,8 @@
                           execute_backend_->BorrowStream(all_executors[i][0]));
       TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                       execute_backend_->transfer_manager(),
-                                      executable->hlo_snapshot()));
-      DumpHloSnapshotIfEnabled(executable->module(),
-                               *executable->hlo_snapshot());
+                                      &snapshots[i]));
+      DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
     }
   }
 
@@ -803,9 +807,9 @@
   const auto& debug_opts = module_config->debug_options();
   if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
       debug_opts.xla_dump_hlo_snapshots()) {
-    auto hlo_snapshot = absl::make_unique<HloSnapshot>();
-    *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
-    executable->set_hlo_snapshot(std::move(hlo_snapshot));
+    auto hlo_proto = absl::make_unique<HloProto>();
+    *hlo_proto->mutable_hlo_module() = module_proto;
+    executable->set_hlo_proto(std::move(hlo_proto));
   }
 
   return std::move(executable);
@@ -891,12 +895,13 @@
   TF_ASSIGN_OR_RETURN(auto stream,
                       execute_backend_->BorrowStream(
                           execute_backend_->default_stream_executor()));
+  HloSnapshot snapshot;
   if (executable->dumping_snapshot()) {
-    executable->hlo_snapshot()->set_execution_platform(
-        execute_backend_->platform()->Name());
-    TF_RETURN_IF_ERROR(RecordArguments(
-        replicated_arguments.front(), stream.get(),
-        execute_backend_->transfer_manager(), executable->hlo_snapshot()));
+    *snapshot.mutable_hlo() = *executable->hlo_proto();
+    snapshot.set_execution_platform(execute_backend_->platform()->Name());
+    TF_RETURN_IF_ERROR(
+        RecordArguments(replicated_arguments.front(), stream.get(),
+                        execute_backend_->transfer_manager(), &snapshot));
   }
 
   TF_ASSIGN_OR_RETURN(
@@ -913,8 +918,8 @@
         allocation_tracker_.ResolveForReplica(result->output(), 0));
     TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                     execute_backend_->transfer_manager(),
-                                    executable->hlo_snapshot()));
-    DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot());
+                                    &snapshot));
+    DumpHloSnapshotIfEnabled(executable->module(), snapshot);
   }
 
   VLOG(1) << "successfully completed 'execute' request";
diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.cc b/tensorflow/compiler/xla/service/slow_operation_alarm.cc
new file mode 100644
index 0000000..3a0bd83
--- /dev/null
+++ b/tensorflow/compiler/xla/service/slow_operation_alarm.cc
@@ -0,0 +1,136 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
+
+#include <list>
+#include <mutex>  // NOLINT (for std::call_once, not std::mutex)
+
+#include "absl/algorithm/container.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace xla {
+namespace {
+
+absl::Mutex mu(absl::kConstInit);
+absl::CondVar* ready;
+std::once_flag init_flag;
+std::list<SlowOperationAlarm*>* outstanding_alarms ABSL_PT_GUARDED_BY(mu) =
+    nullptr;
+
+void AlarmLoop() {
+  while (true) {
+    absl::MutexLock lock(&mu);
+
+    // Fire any alarms which are ready.
+    absl::Time now = absl::Now();
+    for (auto it = outstanding_alarms->begin();
+         it != outstanding_alarms->end();) {
+      auto next = std::next(it);
+      auto* alarm = *it;
+      // Fire the alarm if applicable.
+      if (alarm->deadline() <= now) {
+        outstanding_alarms->erase(it);
+        int64 count =
+            alarm->counter() == nullptr ? 0 : alarm->counter()->fetch_add(1);
+        // If the alarm has a counter, only fire if the count is a power of 2.
+        if (count == 0 || (count & (count - 1)) == 0) {
+          // We fire alarms with LOG(ERROR) because otherwise it might not show
+          // up without --logtostderr.
+          LOG(ERROR) << alarm->msg();
+        }
+      }
+      it = next;
+    }
+
+    if (outstanding_alarms->empty()) {
+      ready->Wait(&mu);
+      continue;
+    }
+
+    SlowOperationAlarm* next_alarm = *absl::c_min_element(
+        *outstanding_alarms,
+        [](const SlowOperationAlarm* a, const SlowOperationAlarm* b) {
+          return a->deadline() < b->deadline();
+        });
+    ready->WaitWithDeadline(&mu, next_alarm->deadline());
+  }
+}
+
+void ScheduleAlarm(SlowOperationAlarm* alarm) {
+  std::call_once(init_flag, [] {
+    ready = new absl::CondVar();
+    outstanding_alarms = new std::list<SlowOperationAlarm*>();
+    (void)tensorflow::Env::Default()->StartThread(
+        tensorflow::ThreadOptions(), "SlowOperationAlarm", [] { AlarmLoop(); });
+  });
+
+  absl::MutexLock lock(&mu);
+  outstanding_alarms->push_back(alarm);
+  ready->Signal();
+}
+
+void UnscheduleAlarm(const SlowOperationAlarm* alarm) {
+  absl::MutexLock lock(&mu);
+  CHECK(outstanding_alarms != nullptr);
+  auto it = absl::c_find(*outstanding_alarms, alarm);
+  if (it != outstanding_alarms->end()) {
+    outstanding_alarms->erase(it);
+  }
+}
+
+}  // namespace
+
+SlowOperationAlarm::SlowOperationAlarm(absl::Duration timeout, string msg,
+                                       std::atomic<int64>* counter /*=nullptr*/)
+    : deadline_(absl::Now() + timeout),
+      msg_(std::move(msg)),
+      counter_(counter) {
+  ScheduleAlarm(this);
+}
+
+SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); }
+
+std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() {
+  // Pass a counter to these alarms so they only log once every power-of-two
+  // occurrences.
+  static auto* counter = new std::atomic<int64>(0);
+
+  const char* separator = "\n********************************";
+#if NDEBUG
+  return absl::make_unique<SlowOperationAlarm>(
+      absl::Duration(absl::Minutes(2)),
+      absl::StrCat(
+          separator,
+          "\nVery slow compile?  If you want to file a bug, run with envvar "
+          "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.",
+          separator),
+      counter);
+#else
+  return absl::make_unique<SlowOperationAlarm>(
+      absl::Duration(absl::Seconds(10)),
+      absl::StrCat(
+          separator,
+          "\nSlow compile?  XLA was built without compiler optimizations, "
+          "which can be slow.  Try rebuilding with -c opt.",
+          separator),
+      counter);
+#endif
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.h b/tensorflow/compiler/xla/service/slow_operation_alarm.h
new file mode 100644
index 0000000..014fc77
--- /dev/null
+++ b/tensorflow/compiler/xla/service/slow_operation_alarm.h
@@ -0,0 +1,70 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <tuple>
+
+#include "absl/time/time.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// This RAII object asynchronously prints a warning if it's alive for more than
+// a certain amount of time.
+class SlowOperationAlarm {
+ public:
+  // If `counter` is not null, this alarm will throttle itself to logging
+  // once-every-power-of-two occurrences. The counter must outlive this object.
+  SlowOperationAlarm(absl::Duration timeout, std::string msg,
+                     std::atomic<int64>* counter = nullptr);
+  ~SlowOperationAlarm();
+
+  // Not copyable or movable, because the constructor stores a pointer to `this`
+  // into a global variable.
+  SlowOperationAlarm(const SlowOperationAlarm&) = delete;
+  SlowOperationAlarm(const SlowOperationAlarm&&) = delete;
+  SlowOperationAlarm& operator=(const SlowOperationAlarm&) = delete;
+  SlowOperationAlarm& operator=(const SlowOperationAlarm&&) = delete;
+
+  absl::Time deadline() const { return deadline_; }
+  absl::string_view msg() const { return msg_; }
+  std::atomic<int64>* counter() { return counter_; }
+
+ private:
+  absl::Time deadline_;
+  std::string msg_;
+  // counter_ may be null.  If it's not, this alarm prints something only once
+  // every power of two occurrences.
+  std::atomic<int64>* counter_;
+};
+
+// Returns an object which prints a warning about slow compilation after a
+// certain amount of time.
+//
+// In debug builds, recommends building with -c opt.
+//
+// In opt builds, recommends filing a bug.
+//
+// This is throttled to once-every-power-of-two occurrences, globally.
+ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm();
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index f670508..7f0eb42 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -3,7 +3,7 @@
 
 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
@@ -296,9 +296,12 @@
 xla_test(
     name = "conv_depthwise_test",
     timeout = "long",
-    srcs = ["conv_depthwise_test.cc"],
+    srcs = [
+        "conv_depthwise_test.cc",
+    ],
     shard_count = 50,
     deps = [
+        ":conv_depthwise_common",
         ":test_macros_header",
         "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -709,9 +712,151 @@
     ],
 )
 
+cc_library(
+    name = "conv_depthwise_common",
+    testonly = True,
+    srcs = ["conv_depthwise_common.cc"],
+    hdrs = ["conv_depthwise_common.h"],
+    deps = [
+        ":test_macros_header",
+        "//tensorflow/compiler/xla:execution_options_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/service:bfloat16_normalization",
+        "//tensorflow/compiler/xla/service:despecializer",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 xla_test(
-    name = "exhaustive_unary_test",
+    name = "exhaustive_unary_test_f32_or_smaller",
     srcs = ["exhaustive_unary_test.cc"],
+    copts = ["-DUNARY_TEST_TARGET_F32_OR_SMALLER"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_unary_test_f64",
+    srcs = ["exhaustive_unary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DUNARY_TEST_TARGET_F64"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_unary_test_complex",
+    srcs = ["exhaustive_unary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_binary_test_f16",
+    srcs = ["exhaustive_binary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DBINARY_TEST_TARGET_F16"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_binary_test_bf16",
+    srcs = ["exhaustive_binary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DBINARY_TEST_TARGET_BF16"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_binary_test_f32",
+    srcs = ["exhaustive_binary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DBINARY_TEST_TARGET_F32"],
+    real_hardware_only = True,  # Very slow on the interpreter.
+    shard_count = 48,
+    tags = [
+        "optonly",
+        # This is a big test that we skip for capacity reasons in OSS testing.
+        "no_oss",
+    ],
+    deps = [
+        ":exhaustive_op_test_utils",
+    ],
+)
+
+xla_test(
+    name = "exhaustive_binary_test_f64",
+    srcs = ["exhaustive_binary_test.cc"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
+    copts = ["-DBINARY_TEST_TARGET_F64"],
     real_hardware_only = True,  # Very slow on the interpreter.
     shard_count = 48,
     tags = [
@@ -1954,7 +2099,6 @@
         "//tensorflow/compiler/xla/service:llvm_compiler",
         "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
-        "//tensorflow/compiler/xla/service/gpu:nvptx_compiler",
         "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
@@ -1987,8 +2131,8 @@
 )
 
 xla_test(
-    name = "fusion_test",
-    srcs = ["fusion_test.cc"],
+    name = "cpu_gpu_fusion_test",
+    srcs = ["cpu_gpu_fusion_test.cc"],
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:array2d",
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 48719c6..7153ace 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -4,7 +4,7 @@
 load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.cc b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc
new file mode 100644
index 0000000..e11ec33
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/conv_depthwise_common.h"
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "tensorflow/compiler/xla/service/despecializer.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+string GetFloatDataType(bool use_bfloat16) {
+  return use_bfloat16 ? "bf16" : "f32";
+}
+
+string DepthwiseConvolution2DTestDataToString(
+    const ::testing::TestParamInfo<
+        ::testing::tuple<DepthwiseConvolution2DSpec, bool>>& data) {
+  const auto& spec = ::testing::get<0>(data.param);
+  const string data_type = GetFloatDataType(::testing::get<1>(data.param));
+  string str = absl::StrCat(
+      "activation_dims_", absl::StrJoin(spec.activation_dims, "x"),
+      "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"),
+      "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_",
+      absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_",
+      absl::StrJoin(spec.output_dims, "x"), "_output_layout_",
+      absl::StrJoin(spec.output_layout, "_"), data_type);
+  // -1 indicates non-existence.
+  if (spec.stride != -1) {
+    absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1");
+  }
+
+  // Test names are not allowed to contain the '-' character.
+  absl::c_replace(str, '-', 'n');
+  return str;
+}
+
+string BuildHloTextDepthwiseConvolution2D(
+    const DepthwiseConvolution2DSpec& spec, bool use_bfloat16,
+    bool is_scheduled) {
+  const string data_type = GetFloatDataType(use_bfloat16);
+  const string sched_tag = is_scheduled ? ", is_scheduled=true " : "";
+  if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) {
+    return absl::StrFormat(
+        R"(
+    HloModule TensorFlowDepthwiseConv %s
+    ENTRY main {
+      activation = %s[%s]{%s} parameter(0)
+      kernel = %s[%s]{%s} parameter(1)
+      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
+          window={size=%dx%d  pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f,
+          feature_group_count=%d
+    }
+    )",
+        sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), data_type,
+        absl::StrJoin(spec.output_dims, ","),
+        absl::StrJoin(spec.output_layout, ","), data_type,
+        absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
+        spec.window, spec.window, spec.window, spec.output_feature);
+
+  } else if (spec.stride == -1) {
+    return absl::StrFormat(
+        R"(
+      HloModule TensorFlowDepthwiseConv %s
+      ENTRY main {
+        activation = %s[%s]{%s} parameter(0)
+        kernel = %s[%s]{%s} parameter(1)
+        ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
+            window={size=%dx%d}, dim_labels=b01f_01io->b01f,
+            feature_group_count=%d
+      }
+      )",
+        sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), data_type,
+        absl::StrJoin(spec.output_dims, ","),
+        absl::StrJoin(spec.output_layout, ","), data_type,
+        absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
+        spec.output_feature);
+  } else {
+    return absl::StrFormat(
+        R"(
+    HloModule TensorFlowDepthwiseConv %s
+
+    ENTRY main {
+      activation = %s[%s]{%s} parameter(0)
+      kernel = %s[%s]{%s} parameter(1)
+      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
+          window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, 
+          dim_labels=b01f_01io->b01f, feature_group_count=%d
+    }
+    )",
+        sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), data_type,
+        absl::StrJoin(spec.output_dims, ","),
+        absl::StrJoin(spec.output_layout, ","), data_type,
+        absl::StrJoin(spec.activation_dims, ","),
+        absl::StrJoin(spec.activation_layout, ","), data_type,
+        absl::StrJoin(spec.kernel_dims, ","),
+        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
+        spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature);
+  }
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.h b/tensorflow/compiler/xla/tests/conv_depthwise_common.h
new file mode 100644
index 0000000..0c00f8d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "tensorflow/compiler/xla/service/despecializer.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+string GetFloatDataType(bool use_bfloat16);
+
+struct DepthwiseConvolution2DSpec {
+  int64 output_feature, window, stride, pad, lhs_dilate;
+  std::vector<int64> activation_dims;
+  std::vector<int64> activation_layout;
+  std::vector<int64> kernel_dims;
+  std::vector<int64> kernel_layout;
+  std::vector<int64> output_dims;
+  std::vector<int64> output_layout;
+};
+
+string DepthwiseConvolution2DTestDataToString(
+    const ::testing::TestParamInfo<
+        ::testing::tuple<DepthwiseConvolution2DSpec, bool>>& data);
+
+string BuildHloTextDepthwiseConvolution2D(
+    const DepthwiseConvolution2DSpec& spec, bool use_bfloat16,
+    bool is_scheduled = false);
+
+}  // namespace xla
+#endif  // TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_
diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc
index fe95824..98f6b5b 100644
--- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc
+++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc
@@ -22,26 +22,13 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/conv_depthwise_common.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 
 namespace xla {
 namespace {
 
-string GetFloatDataType(bool use_bfloat16) {
-  return use_bfloat16 ? "bf16" : "f32";
-}
-
-struct DepthwiseConvolution2DSpec {
-  int64 output_feature, window, stride, pad, lhs_dilate;
-  std::vector<int64> activation_dims;
-  std::vector<int64> activation_layout;
-  std::vector<int64> kernel_dims;
-  std::vector<int64> kernel_layout;
-  std::vector<int64> output_dims;
-  std::vector<int64> output_layout;
-};
-
 class DepthwiseConvolution2DTest
     : public HloTestBase,
       public ::testing::WithParamInterface<
@@ -70,6 +57,7 @@
 
     config.kernel_dims = {kernel_size, kernel_size, 1, feature};
     config.kernel_layout = {3, 2, 1, 0};
+    config.output_layout = {3, 0, 2, 1};
 
     if (activation_size == 1 && kernel_size == 2) {
       // Test for outer dim.
@@ -87,127 +75,12 @@
       config.output_dims = {batch, activation_size - kernel_size + 1,
                             activation_size - kernel_size + 1, feature};
     }
-
-    // Try this layout for all kernel shapes.
-    config.output_layout = {3, 0, 2, 1};
     config_set.push_back(config);
-
-    // Try other layouts only for certain kernel shapes.
-    if (kernel_size % 2 == 0) {
-      config.activation_layout = {0, 3, 2, 1};
-      config_set.push_back(config);
-
-      config.output_layout = {0, 3, 2, 1};
-      config_set.push_back(config);
-
-      config.activation_layout = {3, 0, 2, 1};
-      config_set.push_back(config);
-    }
   }
 
   return config_set;
 }
 
-string DepthwiseConvolution2DTestDataToString(
-    const ::testing::TestParamInfo<
-        ::testing::tuple<DepthwiseConvolution2DSpec, bool>>& data) {
-  const auto& spec = ::testing::get<0>(data.param);
-  const string data_type = GetFloatDataType(::testing::get<1>(data.param));
-  string str = absl::StrCat(
-      "activation_dims_", absl::StrJoin(spec.activation_dims, "x"),
-      "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"),
-      "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_",
-      absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_",
-      absl::StrJoin(spec.output_dims, "x"), "_output_layout_",
-      absl::StrJoin(spec.output_layout, "_"), data_type);
-  // -1 indicates non-existence.
-  if (spec.stride != -1) {
-    absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1");
-  }
-
-  // Test names are not allowed to contain the '-' character.
-  absl::c_replace(str, '-', 'n');
-  return str;
-}
-
-string BuildHloTextDepthwiseConvolution2D(
-    const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) {
-  const string data_type = GetFloatDataType(use_bfloat16);
-  if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) {
-    return absl::StrFormat(
-        R"(
-    HloModule TensorFlowDepthwiseConv
-
-    ENTRY main {
-      activation = %s[%s]{%s} parameter(0)
-      kernel = %s[%s]{%s} parameter(1)
-      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
-          window={size=%dx%d  pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f,
-          feature_group_count=%d
-    }
-    )",
-        data_type, absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), data_type,
-        absl::StrJoin(spec.output_dims, ","),
-        absl::StrJoin(spec.output_layout, ","), data_type,
-        absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
-        spec.window, spec.window, spec.window, spec.output_feature);
-
-  } else if (spec.stride == -1) {
-    return absl::StrFormat(
-        R"(
-      HloModule TensorFlowDepthwiseConv
-
-      ENTRY main {
-        activation = %s[%s]{%s} parameter(0)
-        kernel = %s[%s]{%s} parameter(1)
-        ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
-            window={size=%dx%d}, dim_labels=b01f_01io->b01f,
-            feature_group_count=%d
-      }
-      )",
-        data_type, absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), data_type,
-        absl::StrJoin(spec.output_dims, ","),
-        absl::StrJoin(spec.output_layout, ","), data_type,
-        absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
-        spec.output_feature);
-  } else {
-    return absl::StrFormat(
-        R"(
-    HloModule TensorFlowDepthwiseConv
-
-    ENTRY main {
-      activation = %s[%s]{%s} parameter(0)
-      kernel = %s[%s]{%s} parameter(1)
-      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
-          window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, 
-          dim_labels=b01f_01io->b01f, feature_group_count=%d
-    }
-    )",
-        data_type, absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), data_type,
-        absl::StrJoin(spec.output_dims, ","),
-        absl::StrJoin(spec.output_layout, ","), data_type,
-        absl::StrJoin(spec.activation_dims, ","),
-        absl::StrJoin(spec.activation_layout, ","), data_type,
-        absl::StrJoin(spec.kernel_dims, ","),
-        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
-        spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature);
-  }
-}
 
 XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) {
   const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 0fae5d9..e8e8229 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -1942,7 +1942,8 @@
 
 class ConvolutionHloTest : public HloTestBase {};
 
-XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) {
+// double datatype is not yet supported in ROCm
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64Forward)) {
   constexpr char kHlo[] = R"(
 HloModule TestModule
 
@@ -1966,7 +1967,9 @@
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
 }
 
-XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) {
+// double datatype is not yet supported in ROCm
+XLA_TEST_F(ConvolutionHloTest,
+           DISABLED_ON_GPU_ROCM(ConvolveF64BackwardFilter)) {
   constexpr char kHlo[] = R"(
 HloModule TestModule
 
@@ -1978,7 +1981,8 @@
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
 }
 
-XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) {
+// double datatype is not yet supported in ROCm
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64BackwardInput)) {
   constexpr char kHlo[] = R"(
 HloModule TestModule
 
diff --git a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc
new file mode 100644
index 0000000..7719e89
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc
@@ -0,0 +1,931 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include <algorithm>
+#include <memory>
+#include <new>
+#include <random>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+const int test_width = 2, test_height = 3;
+
+const float test_float_vals[3][test_width][test_height] = {
+    {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
+    {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
+    {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};
+
+// Test whether fusion operations are emitted with no errors and compute
+// accurate outputs.
+class CpuGpuFusionTest : public HloTestBase {
+ protected:
+  template <typename T, int Arity>
+  void TestElementwise2D(
+      HloOpcode opcode,
+      absl::optional<ComparisonDirection> direction = absl::nullopt) {
+    // Create a variable for comparisons since they require the direction.
+    bool is_compare = std::is_same<T, bool>::value;
+    Array2D<float> operand_data[Arity];
+    for (int i = 0; i < Arity; ++i) {
+      new (&operand_data[i]) Array2D<float>(test_width, test_height);
+    }
+    Array2D<T> answer_data(test_width, test_height);
+    for (int i = 0; i < test_width; ++i) {
+      for (int j = 0; j < test_height; ++j) {
+        float xs[Arity];
+        for (int k = 0; k < Arity; ++k) {
+          xs[k] = test_float_vals[k][i][j];
+          operand_data[k](i, j) = xs[k];
+        }
+        if (is_compare) {
+          answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
+        } else {
+          answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
+        }
+      }
+    }
+
+    auto builder = HloComputation::Builder(TestName());
+    auto hlo_module = CreateNewVerifiedModule();
+
+    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
+
+    HloInstruction* hlos[4];
+    for (int i = 0; i < Arity; ++i) {
+      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2FromArray2D(operand_data[i])));
+    }
+    auto answer_shape =
+        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
+    std::unique_ptr<HloInstruction> root_hlo;
+    switch (Arity) {
+      case 1:
+        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
+        break;
+      case 2:
+        if (is_compare) {
+          root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
+                                                   hlos[2], *direction);
+        } else {
+          root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
+                                                  hlos[2]);
+        }
+        break;
+      case 3:
+        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
+                                                 hlos[2], hlos[3]);
+        break;
+      default:
+        LOG(FATAL) << "Bad arity: " << Arity;
+    }
+    hlos[0] = builder.AddInstruction(std::move(root_hlo));
+    hlo_module->AddEntryComputation(builder.Build())
+        ->CreateFusionInstruction(
+            absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
+            HloInstruction::FusionKind::kLoop);
+
+    auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
+    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
+    if (primitive_util::IsFloatingPointType(prim_type)) {
+      EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
+    } else {
+      EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
+    }
+  }
+
+ private:
+  float ComputeElementwiseAnswerFloat(HloOpcode opcode,
+                                      absl::Span<const float> xs);
+  bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
+                                       absl::Span<const float> xs);
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.add_xla_disable_hlo_passes("layout-assignment");
+    return debug_options;
+  }
+};
+
+float CpuGpuFusionTest::ComputeElementwiseAnswerFloat(
+    HloOpcode opcode, absl::Span<const float> xs) {
+  switch (opcode) {
+    case HloOpcode::kAdd:
+      return xs[0] + xs[1];
+    case HloOpcode::kSubtract:
+      return xs[0] - xs[1];
+    case HloOpcode::kMultiply:
+      return xs[0] * xs[1];
+    case HloOpcode::kDivide:
+      return xs[0] / xs[1];
+    case HloOpcode::kPower:
+      return powf(xs[0], xs[1]);
+    case HloOpcode::kMinimum:
+      return std::min(xs[0], xs[1]);
+    case HloOpcode::kMaximum:
+      return std::max(xs[0], xs[1]);
+    case HloOpcode::kClamp:
+      return std::min(xs[2], std::max(xs[1], xs[0]));
+    default:
+      LOG(FATAL) << "No elementwise opcode: " << opcode;
+  }
+}
+
+bool CpuGpuFusionTest::ComputeElementwiseAnswerCompare(
+    ComparisonDirection direction, absl::Span<const float> xs) {
+  switch (direction) {
+    case ComparisonDirection::kEq:
+      return xs[0] == xs[1];
+    case ComparisonDirection::kNe:
+      return xs[0] != xs[1];
+    case ComparisonDirection::kGt:
+      return xs[0] > xs[1];
+    case ComparisonDirection::kLt:
+      return xs[0] < xs[1];
+    case ComparisonDirection::kGe:
+      return xs[0] >= xs[1];
+    case ComparisonDirection::kLe:
+      return xs[0] <= xs[1];
+  }
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Test) {
+  // test expression:
+  // slice(select({{T, F, T}, {F, T, F}},
+  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
+  //                               {{-1.0}, {-1.0}, {-1.0}}),
+  //                     {{1.62, 2.72, 3.14}}) +
+  //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
+  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
+  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
+  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
+  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
+      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
+  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
+  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
+  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+  auto const10 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+          {{true, false, true}, {false, true, false}})));
+  auto select11 = builder.AddInstruction(
+      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
+                                    HloOpcode::kSelect, const10, add8, const9));
+  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
+      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
+  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
+  // reverse topological order, so the first element in `instructions_to_fuse`
+  // must be the root.
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(
+          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
+           const4, reshape3, add2, const1, const0},
+          HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+}
+
+// Test whether we emit appropriate code for parameters of fusion instructions.
+XLA_TEST_F(CpuGpuFusionTest, Parameter) {
+  // Build a computation and fuse part of it so the fusion instruction has an
+  // operand parameter.
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
+  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
+  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
+  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
+  // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
+  // order.
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, RandomizedParallelPartition) {
+  // Tests parallel partitioning of a fusion instruction.
+  // Create shape with random outer dimension size to generate random parallel
+  // partition counts for each test run.
+  const int seed = tensorflow::testing::RandomSeed();
+  LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
+  std::mt19937 generator(seed);
+  std::uniform_int_distribution<int> distribution(128, 1024);
+  const int64 rand_dim0_size = distribution(generator);
+  const int64 dim1_size = 1024;
+  Shape shape =
+      ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
+  // Build simple fusion computation: y = x^2 (elementwise).
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+
+  auto two = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto x =
+      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
+  auto y = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
+
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
+                                HloInstruction::FusionKind::kLoop);
+  // Compute result.
+  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+  // Every element of result should be y = x^2 = 4.0.
+  for (int i = 0; i < rand_dim0_size; ++i) {
+    for (int j = 0; j < dim1_size; ++j) {
+      EXPECT_EQ(4.0, result.Get<float>({i, j}));
+    }
+  }
+}
+
+XLA_TEST_F(CpuGpuFusionTest, BroadcastIntoBinaryOp) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+  auto broadcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
+  // add2 = broadcast(const_vector) + const_array
+  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+  auto add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
+                                   HloOpcode::kAdd, broadcast, const_array));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, ReshapeToScalar) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto single_element_array = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(S32, {}), single_element_array));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape_3by2_1by2by3) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape_1by2by3_3by2) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+  auto reshape1 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape_1by1by1_) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
+  auto reshape1 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape__1by1by1) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape__) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+  auto reshape1 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reshape_3by3_3by3) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+  auto reshape1 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Transpose_2by3) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Transpose_3by3) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+                                HloInstruction::FusionKind::kLoop);
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Reverse) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
+  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
+      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, ReverseNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
+  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
+      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
+  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, BroadcastNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+      ShapeUtil::MakeShape(S32, {2}), const0, {}));
+  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, SliceNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
+  auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
+      ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
+  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, DynamicSliceNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
+  auto const1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+  auto dynamic_slice2 =
+      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+          ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2}));
+  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(
+          /*instructions_to_fuse=*/{negate3, dynamic_slice2},
+          HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, ReshapeNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
+  auto reshape1 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
+  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, TransposeNegate) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
+  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
+  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+std::unique_ptr<HloComputation> MakeReduceTestComputation() {
+  auto builder = HloComputation::Builder("add");
+  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
+  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
+  builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
+  return builder.Build();
+}
+
+XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(Reduce)) {
+  auto hlo_module = CreateNewVerifiedModule();
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {32}), 0));
+  auto const1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
+                                HloInstruction::FusionKind::kInput);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(496),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, ReduceImplicitBroadcast) {
+  auto hlo_module = CreateNewVerifiedModule();
+
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
+  auto const1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
+  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(ReduceWindow)) {
+  auto builder = HloComputation::Builder(TestName());
+  auto hlo_module = CreateNewVerifiedModule();
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+  auto const1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+  Window window;
+  ASSERT_TRUE(
+      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
+                                                        "size:2\n"
+                                                        "stride:1\n"
+                                                        "padding_low:0\n"
+                                                        "padding_high:0\n"
+                                                        "window_dilation:1\n"
+                                                        "base_dilation:1\n"
+                                                        "}\n"
+                                                        "dimensions:{\n"
+                                                        "size:2\n"
+                                                        "stride:1\n"
+                                                        "padding_low:0\n"
+                                                        "padding_high:0\n"
+                                                        "window_dilation:1\n"
+                                                        "base_dilation:1\n"
+                                                        "}\n",
+                                                        &window));
+  auto nested_builder = HloComputation::Builder("mul");
+  {
+    auto x = nested_builder.AddInstruction(
+        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
+    auto y = nested_builder.AddInstruction(
+        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
+    nested_builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
+  }
+  auto nested_computation =
+      hlo_module->AddEmbeddedComputation(nested_builder.Build());
+  auto reduce_window2 =
+      builder.AddInstruction(HloInstruction::CreateReduceWindow(
+          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
+          nested_computation));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
+                                HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+// When a constant (or other op) which has multiple users is imported
+// into a fusion, it should remain shared, rather than being duplicated
+// within the fusion.
+XLA_TEST_F(CpuGpuFusionTest, SharedConstant) {
+  auto hlo_module = CreateNewVerifiedModule();
+
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
+  auto const1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
+  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
+  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
+  auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
+  hlo_module->AddEntryComputation(builder.Build())
+      ->CreateFusionInstruction({add4, add3, add2, add1, const1},
+                                HloInstruction::FusionKind::kLoop);
+
+  HloComputation* entry_comp = hlo_module->entry_computation();
+
+  // entry computation contains the constant(0) and the fusion
+  EXPECT_EQ(entry_comp->instruction_count(), 2);
+
+  // fused instruction contains the constant(2), the parameter, and 4 adds
+  EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Add2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kAdd);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Subtract2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Multiply2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Divide2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kDivide);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Power2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kPower);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Minimum2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Maximum2D) {
+  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Equal2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Inequal2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Greater2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Lesser2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, GreaterOrEqual2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, LesserOrEqual2D) {
+  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
+}
+
+XLA_TEST_F(CpuGpuFusionTest, Clamp2D) {
+  TestElementwise2D<float, 3>(HloOpcode::kClamp);
+}
+
+class FusionClientLibraryTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
+  // On the GPU backend, it's possible to have too many transposes within one
+  // fusion, causing the kernel to run out shared memory and thus not compile.
+  // We want to check that doesn't happen.
+  //
+  // To do this, we create a computation that computes
+  //
+  //   P0 + P0*P1*P1 + P0*P2*P2 ...
+  //
+  // where even parameters have layout 1 and odd parameters have layout 2.
+  //
+  // Our goal is to tempt the backend into creating one giant multi-output
+  // fusion for the whole computation, including the transposes.  Currently
+  // multi-output fusion only fuses fusions, so each of the terms in the sum
+  // needs to be a fusion itself, thus the contortions above.
+  constexpr int kNumParams = 25;
+  XlaBuilder b("ManyLayoutTransformations");
+
+  // This test produces values that overflow int32, which is UB, so use uint32,
+  // where overflow is OK.
+  Array2D<uint32> arr(32, 32);
+  arr.FillUnique();
+  Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
+      LayoutUtil::MakeLayout({0, 1}));
+
+  Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
+      LayoutUtil::MakeLayout({1, 0}));
+
+  XlaOp p0 = AddParam(l1, &b);
+  XlaOp sum = p0;
+  for (int i = 1; i < kNumParams; ++i) {
+    auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
+    sum = sum + p0 * pN * pN;
+  }
+
+  ComputeAndCompare(&b, {});
+}
+
+void BM_ParallelFusion(int num_iters) {
+  // Simple element-wise computation to benchmark parallel task partitioning.
+  tensorflow::testing::StopTiming();
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+  se::StreamExecutorMemoryAllocator allocator(platform, executors);
+
+  const int64 intra_op_parallelism_threads = 24;
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform);
+  client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
+  auto client =
+      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
+
+  int device_ordinal = client->default_device_ordinal();
+
+  // Computation shape parameters.
+  const int64 param0_dim0 = 1024;
+  const int64 param0_dim1 = 1024;
+  const int64 param1_dim0 = 1024;
+  const int64 param1_dim1 = 1024;
+  const int64 param2_dim0 = 1024;
+  const int64 param2_dim1 = 1024;
+
+  // Create computation.
+  XlaBuilder builder("ParallelFusion");
+  Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
+  auto param0 = Parameter(&builder, 0, shape0, "param0");
+  Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
+  auto param1 = Parameter(&builder, 1, shape1, "param1");
+  Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
+  auto param2 = Parameter(&builder, 2, shape2, "param2");
+
+  auto x = Mul(param0, param1);
+  Add(x, param2);
+  auto computation = builder.Build().ConsumeValueOrDie();
+
+  // Transfer literals to device.
+  auto param0_literal =
+      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
+  ScopedShapedBuffer buffer0 =
+      client->LiteralToShapedBuffer(param0_literal, device_ordinal)
+          .ConsumeValueOrDie();
+
+  auto param1_literal =
+      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
+  ScopedShapedBuffer buffer1 =
+      client->LiteralToShapedBuffer(param1_literal, device_ordinal)
+          .ConsumeValueOrDie();
+
+  auto param2_literal =
+      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
+  ScopedShapedBuffer buffer2 =
+      client->LiteralToShapedBuffer(param2_literal, device_ordinal)
+          .ConsumeValueOrDie();
+
+  // Build executable.
+  std::unique_ptr<LocalExecutable> executable =
+      client
+          ->Compile(computation,
+                    {&buffer0.on_host_shape(), &buffer1.on_host_shape(),
+                     &buffer2.on_host_shape()},
+                    ExecutableBuildOptions())
+          .ConsumeValueOrDie();
+
+  se::Stream stream(executors[device_ordinal]);
+  stream.Init();
+
+  // Initialize thread pool.
+  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
+                                      intra_op_parallelism_threads);
+  Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
+
+  // Initialize ExecutableRunOptions.
+  ExecutableRunOptions options;
+  options.set_allocator(&allocator).set_stream(&stream);
+  options.set_intra_op_thread_pool(&device);
+
+  // Run some warm-up executions.
+  const int kWarmups = 2;
+  for (int i = 0; i < kWarmups; ++i) {
+    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
+    ASSERT_TRUE(result.ok());
+  }
+
+  // Run benchmark.
+  const int64 total_bytes = param0_dim0 * param0_dim0 +
+                            param1_dim0 * param1_dim0 +
+                            param2_dim0 * param2_dim0;
+  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
+                                      total_bytes * sizeof(float));
+  tensorflow::testing::UseRealTime();
+  tensorflow::testing::StartTiming();
+  for (int i = 0; i < num_iters; ++i) {
+    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
+    ASSERT_TRUE(result.ok());
+  }
+}
+
+BENCHMARK(BM_ParallelFusion);
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 25e8284..ff2fd7e 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -1409,6 +1409,54 @@
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
 }
 
+// Regression test for b/138155357, where we were incorrectly creating a dot-add
+// fusion where the dot had a batch dimension.  This isn't supported on the CPU
+// backend.
+XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) {
+  absl::string_view module_string = R"(
+HloModule jaxpr_computation__5.33
+
+jaxpr_computation__6.8 {
+  tuple.9 = () tuple()
+  parameter.14 = () parameter(4)
+  parameter.13 = (f32[2]{0}) parameter(3)
+  get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0
+  reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15)
+  parameter.10 = f32[2,2]{1,0} parameter(0)
+  reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15)
+  dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  reshape.19 = f32[2]{0} reshape(dot.18)
+  reshape.20 = f32[2,1]{1,0} reshape(reshape.19)
+  dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  reshape.22 = f32[] reshape(dot.21)
+  parameter.11 = f32[2,1,2]{2,1,0} parameter(1)
+  broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2}
+  dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2}
+  parameter.12 = f32[2,2,1]{2,1,0} parameter(2)
+  dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26)
+  reshape.28 = f32[2]{0} reshape(add.27)
+  ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28)
+}
+
+ENTRY jaxpr_computation__5.33 {
+  constant.2 = f32[] constant(1)
+  broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={}
+  constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } })
+  constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } })
+  parameter.6 = f32[2]{0} parameter(0)
+  tuple.7 = (f32[2]{0}) tuple(parameter.6)
+  tuple.1 = () tuple()
+  call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8
+  get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0
+  ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/absl::nullopt));
+}
+
 XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
   Array3D<float> input_arr(2, 3, 2);
   Array2D<float> const_arr(2, 6);
diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
new file mode 100644
index 0000000..c0f8a0d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
@@ -0,0 +1,392 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
+
+#ifdef __FAST_MATH__
+#error("Can't be compiled with fast math on");
+#endif
+
+namespace xla {
+namespace {
+
+template <PrimitiveType T>
+using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
+
+// Exhaustive test for binary operations for 16 bit floating point types,
+// including float16 and bfloat.
+//
+// Test parameter is a pair of (begin, end) for range under test.
+template <
+    PrimitiveType T,
+    typename std::enable_if<
+        std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
+                     half>::value ||
+        std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
+                     bfloat16>::value>::type* = nullptr>
+class Exhaustive16BitBinaryTest
+    : public ExhaustiveBinaryTest<T>,
+      public ::testing::WithParamInterface<std::pair<int64, int64>> {
+ public:
+  int64 GetInputSize() override {
+    int64 begin, end;
+    std::tie(begin, end) = GetParam();
+    return end - begin;
+  }
+
+  // Given a range of uint64 representation, uses bits 0..15 and bits 16..31 for
+  // the values of src0 and src1 for a 16 bit binary operation being tested,
+  // and generates the cartesian product of the two sets as the two inputs for
+  // the test.
+  void FillInput(std::array<Literal, 2>* input_literals) override {
+    int64 input_size = GetInputSize();
+    CHECK_EQ(input_size, (*input_literals)[0].element_count());
+    CHECK_EQ(input_size, (*input_literals)[1].element_count());
+
+    int64 begin, end;
+    std::tie(begin, end) = GetParam();
+    VLOG(2) << "Checking range [" << begin << ", " << end << "]";
+
+    absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
+    absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
+    for (int64 i = 0; i < input_size; i++) {
+      uint32 input_val = i + begin;
+      // Convert the lower 16 bits to the NativeT and replaced known incorrect
+      // input values with 0.
+      input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
+      input_arr_1[i] =
+          ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0);
+    }
+  }
+
+ protected:
+  using typename ExhaustiveBinaryTest<T>::NativeT;
+  using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
+};
+
+using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
+using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
+
+// Returns a wrapper of the given build method, which build an HLO operation
+// with an empty broadcast dimension.
+inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
+    std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
+  return [&](XlaOp src0, XlaOp src1) -> XlaOp {
+    return build_method(src0, src1, {});
+  };
+}
+
+#define XLA_TEST_16BIT(test_name, ...)            \
+  XLA_TEST_P(ExhaustiveF16BinaryTest, test_name)  \
+  __VA_ARGS__                                     \
+  XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
+  __VA_ARGS__
+
+XLA_TEST_16BIT(Add, {
+  auto host_add = [](float x, float y) { return x + y; };
+  Run(AddEmptyBroadcastDimension(Add), host_add);
+})
+
+XLA_TEST_16BIT(Sub, {
+  auto host_sub = [](float x, float y) { return x - y; };
+  Run(AddEmptyBroadcastDimension(Sub), host_sub);
+})
+
+// TODO(bixia): Mul fails with bfloat16 on CPU.
+XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), {
+  auto host_mul = [](float x, float y) { return x * y; };
+  Run(AddEmptyBroadcastDimension(Mul), host_mul);
+})
+
+// TODO(bixia): Div fails with bfloat16 on CPU.
+XLA_TEST_16BIT(DISABLED_ON_CPU(Div), {
+  auto host_div = [](float x, float y) { return x / y; };
+  Run(AddEmptyBroadcastDimension(Div), host_div);
+})
+
+template <typename T, typename std::enable_if<
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value>::type* = nullptr>
+T ReferenceMax(T x, T y) {
+  // We need to propagate NAN here becasue std::max may not propagate NAN.
+  if (std::fpclassify(x) == FP_NAN) {
+    return x;
+  }
+  if (std::fpclassify(y) == FP_NAN) {
+    return y;
+  }
+
+  return std::max<T>(x, y);
+}
+
+template <typename T, typename std::enable_if<
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value>::type* = nullptr>
+T ReferenceMin(T x, T y) {
+  // We need to propagate NAN here becasue std::max may not propagate NAN.
+  if (std::fpclassify(x) == FP_NAN) {
+    return x;
+  }
+  if (std::fpclassify(y) == FP_NAN) {
+    return y;
+  }
+
+  return std::min<T>(x, y);
+}
+
+XLA_TEST_16BIT(Max,
+               { Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); })
+
+XLA_TEST_16BIT(Min,
+               { Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); })
+
+// TODO(bixia): Pow fails with bfloat16 on CPU.
+XLA_TEST_16BIT(DISABLED_ON_CPU(Pow),
+               { Run(AddEmptyBroadcastDimension(Pow), std::powf); })
+
+// TODO(bixia): Atan2 fails with bfloat16 on CPU.
+XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2),
+               { Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
+
+#if defined(BINARY_TEST_TARGET_F16)
+#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
+INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
+                         ::testing::ValuesIn(CreateExhaustiveF32Ranges()));
+#endif
+#endif
+
+#if defined(BINARY_TEST_TARGET_BF16)
+#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
+INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
+                         ::testing::ValuesIn(CreateExhaustiveF32Ranges()));
+#endif
+#endif
+
+// Exhaustive test for binary operations for float and double.
+//
+// Test parameter is a tuple of (FpValues, FpValues) describing the possible
+// values for each operand. The inputs for the test are the Cartesian product
+// of the possible values for the two operands.
+template <PrimitiveType T>
+class Exhaustive32BitOrMoreBinaryTest
+    : public ExhaustiveBinaryTest<T>,
+      public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
+ protected:
+  using typename ExhaustiveBinaryTest<T>::NativeT;
+  using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
+
+ private:
+  int64 GetInputSize() override {
+    FpValues values_0;
+    FpValues values_1;
+    std::tie(values_0, values_1) = GetParam();
+    return values_0.GetTotalNumValues() * values_1.GetTotalNumValues();
+  }
+
+  void FillInput(std::array<Literal, 2>* input_literals) override {
+    int64 input_size = GetInputSize();
+    FpValues values_0;
+    FpValues values_1;
+    std::tie(values_0, values_1) = GetParam();
+
+    VLOG(2) << " testing " << values_0.ToString() << " " << values_1.ToString()
+            << "total values " << input_size;
+    CHECK(input_size == (*input_literals)[0].element_count() &&
+          input_size == (*input_literals)[1].element_count());
+
+    absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
+    absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
+
+    uint64 i = 0;
+    for (auto src0 : values_0) {
+      for (auto src1 : values_1) {
+        input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(src0, 1);
+        input_arr_1[i] = ConvertAndReplaceKnownIncorrectValueWith(src1, 1);
+        ++i;
+      }
+    }
+    CHECK_EQ(i, input_size);
+  }
+};
+
+using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
+using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
+
+XLA_TEST_P(ExhaustiveF32BinaryTest, Add) {
+  auto host_add = [](float x, float y) { return x + y; };
+  Run(AddEmptyBroadcastDimension(Add), host_add);
+}
+
+XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) {
+  auto host_sub = [](float x, float y) { return x - y; };
+  Run(AddEmptyBroadcastDimension(Sub), host_sub);
+}
+
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) {
+  auto host_mul = [](float x, float y) { return x * y; };
+  Run(AddEmptyBroadcastDimension(Mul), host_mul);
+}
+
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) {
+  auto host_div = [](float x, float y) { return x / y; };
+  Run(AddEmptyBroadcastDimension(Div), host_div);
+}
+
+XLA_TEST_P(ExhaustiveF32BinaryTest, Max) {
+  Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
+}
+
+XLA_TEST_P(ExhaustiveF32BinaryTest, Min) {
+  Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
+}
+
+// It is more convenient to implement Abs(complex) as a binary op than a unary
+// op, as the operations we currently support all have the same data type for
+// the source operands and the results.
+// TODO(bixia): May want to move this test to unary test if we will be able to
+// implement Abs(complex) as unary conveniently.
+//
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
+  auto host_abs_complex = [](float x, float y) {
+    return std::abs(std::complex<float>(x, y));
+  };
+  auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
+
+  Run(device_abs_complex, host_abs_complex);
+}
+
+#if defined(BINARY_TEST_TARGET_F32)
+
+INSTANTIATE_TEST_SUITE_P(
+    SpecialValues, ExhaustiveF32BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpecialAndNormalValues, ExhaustiveF32BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
+        ::testing::Values(GetNormals<float>(2000))));
+
+INSTANTIATE_TEST_SUITE_P(
+    NormalAndSpecialValues, ExhaustiveF32BinaryTest,
+    ::testing::Combine(
+        ::testing::Values(GetNormals<float>(2000)),
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
+
+INSTANTIATE_TEST_SUITE_P(
+    NormalAndNormalValues, ExhaustiveF32BinaryTest,
+    ::testing::Combine(::testing::Values(GetNormals<float>(2000)),
+                       ::testing::Values(GetNormals<float>(2000))));
+
+// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test.
+// Comparing with the unary tests, the binary tests use a smaller set of inputs
+// for each sub-test to avoid timeout because the implementation of ExpectNear
+// more than 2x slower for binary test.
+INSTANTIATE_TEST_SUITE_P(
+    LargeAndSmallMagnituedNormalValues, ExhaustiveF32BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
+                                                                         2000)),
+        ::testing::ValuesIn(
+            GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
+
+#endif
+
+XLA_TEST_P(ExhaustiveF64BinaryTest, Add) {
+  auto host_add = [](double x, double y) { return x + y; };
+  Run(AddEmptyBroadcastDimension(Add), host_add);
+}
+
+XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) {
+  auto host_sub = [](double x, double y) { return x - y; };
+  Run(AddEmptyBroadcastDimension(Sub), host_sub);
+}
+
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) {
+  auto host_mul = [](double x, double y) { return x * y; };
+  Run(AddEmptyBroadcastDimension(Mul), host_mul);
+}
+
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) {
+  auto host_div = [](double x, double y) { return x / y; };
+  Run(AddEmptyBroadcastDimension(Div), host_div);
+}
+
+XLA_TEST_P(ExhaustiveF64BinaryTest, Max) {
+  Run(AddEmptyBroadcastDimension(Max), ReferenceMax<double>);
+}
+
+XLA_TEST_P(ExhaustiveF64BinaryTest, Min) {
+  Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>);
+}
+
+// TODO(bixia): Need to investigate the failure on CPU and file bugs.
+XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
+  auto host_abs_complex = [](double x, double y) {
+    return std::abs(std::complex<double>(x, y));
+  };
+  auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
+
+  Run(device_abs_complex, host_abs_complex);
+}
+
+#if defined(BINARY_TEST_TARGET_F64)
+
+#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
+INSTANTIATE_TEST_SUITE_P(
+    SpecialValues, ExhaustiveF64BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpecialAndNormalValues, ExhaustiveF64BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
+        ::testing::Values(GetNormals<double>(1000))));
+
+INSTANTIATE_TEST_SUITE_P(
+    NormalAndSpecialValues, ExhaustiveF64BinaryTest,
+    ::testing::Combine(
+        ::testing::Values(GetNormals<double>(1000)),
+        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
+
+INSTANTIATE_TEST_SUITE_P(
+    NormalAndNormalValues, ExhaustiveF64BinaryTest,
+    ::testing::Combine(::testing::Values(GetNormals<double>(1000)),
+                       ::testing::Values(GetNormals<double>(1000))));
+
+// Tests a total of 40000 ^ 2 inputs, with 1000 ^ 2 inputs in each sub-test.
+// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each
+// for each sub-test comparing with the unary test to avoid timeout.
+INSTANTIATE_TEST_SUITE_P(
+    LargeAndSmallMagnituedNormalValues, ExhaustiveF64BinaryTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(
+            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
+        ::testing::ValuesIn(
+            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
+#endif
+
+#endif
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
index 02273d7..1d3248f 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
@@ -17,8 +17,8 @@
 
 namespace xla {
 
-// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
-// guaranteed that we're printing the full number.
+// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
+// precision to be guaranteed that we're printing the full number.
 //
 // (The general formula is, given a floating-point number with S significand
 // bits, the number of decimal digits needed to print it to full precision is
@@ -26,71 +26,237 @@
 //   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
 //
 // See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
-/*static*/
-string ExhaustiveOpTestBase::StringifyNum(float x) {
-  return absl::StrFormat("%0.9g (0x%08x)", x, BitCast<uint32>(x));
-}
+namespace {
+template <typename T>
+struct ComponentStringifyFormat {};
+
+template <>
+struct ComponentStringifyFormat<double> {
+  static constexpr absl::string_view value = "%0.17g (0x%16x)";
+};
+
+template <>
+struct ComponentStringifyFormat<float> {
+  static constexpr absl::string_view value = "%0.8g (0x%08x)";
+};
+
+template <>
+struct ComponentStringifyFormat<Eigen::half> {
+  static constexpr absl::string_view value = "%0.5g (0x%04x)";
+};
+
+template <>
+struct ComponentStringifyFormat<bfloat16> {
+  static constexpr absl::string_view value = "%0.4g (0x%04x)";
+};
+}  // namespace
 
 /*static*/
-string ExhaustiveOpTestBase::StringifyNum(half x) {
-  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
-                         BitCast<uint16>(x));
+template <PrimitiveType T, size_t N>
+string ExhaustiveOpTestBase<T, N>::StringifyNum(
+    typename ExhaustiveOpTestBase<T, N>::ComponentNativeT x) {
+  typedef typename ExhaustiveOpTestBase<T, N>::ComponentNativeT ComponentType;
+  typedef typename ExhaustiveOpTestBase<T, N>::ComponentIntegralNativeT
+      IntegralType;
+  return absl::StrFormat(ComponentStringifyFormat<ComponentType>::value,
+                         static_cast<double>(x), BitCast<IntegralType>(x));
 }
 
-/*static*/
-string ExhaustiveOpTestBase::StringifyNum(bfloat16 x) {
-  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
-                         BitCast<uint16>(x));
-}
-
-/*static*/
-std::vector<std::pair<int64, int64>>
-ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() {
-  // We break up the 2^32-element space into small'ish chunks to keep peak
-  // memory usage low.
-  std::vector<std::pair<int64, int64>> result;
-  const int64 step = 1 << 25;
-  for (int64 i = 0; i < (1l << 32); i += step) {
-    result.push_back({i, i + step});
+template <PrimitiveType T, size_t N>
+void ExhaustiveOpTestBase<T, N>::ExpectNear(const InputLiterals& input_literals,
+                                            const Literal& result_literal,
+                                            EvaluateOp evaluate_op,
+                                            ErrorSpecGen error_spec_gen) {
+  // Cache for when all components are subnormal testing values.
+  std::vector<NativeRefT> pure_subnormal_cache;
+  pure_subnormal_cache.reserve(GetMaxCacheSize());
+  for (int i = 0; i < GetMaxCacheSize(); ++i) {
+    pure_subnormal_cache.push_back(
+        CallOperation(evaluate_op, FromCacheLocation(i)));
   }
-  return result;
+
+  NativeInputsList inputs_arr;
+  for (int i = 0; i < N; ++i) {
+    const Literal& literal = input_literals[i];
+    inputs_arr[i] = literal.data<NativeT>();
+  }
+
+  absl::Span<const NativeT> result_arr = result_literal.data<NativeT>();
+
+  int64 mismatches = 0;
+
+  for (int64 i = 0; i < result_arr.size(); ++i) {
+    NativeInputs inputs;
+    NativeRefInputs inputs_ref_ty;
+
+    for (int j = 0; j < N; ++j) {
+      inputs[j] = inputs_arr[j][i];
+      inputs_ref_ty[j] = static_cast<NativeRefT>(inputs[j]);
+    }
+
+    NativeT actual = result_arr[i];
+    NativeT expected =
+        static_cast<NativeT>(CallOperation(evaluate_op, inputs_ref_ty));
+    ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);
+
+    if (IsClose(static_cast<NativeRefT>(expected),
+                static_cast<NativeRefT>(actual), error_spec)) {
+      continue;
+    }
+
+    std::vector<NativeRefInputs> subnormal_test_inputs =
+        GetTestValuesWithSubnormalSubstitutions(inputs_ref_ty);
+
+    // Easy case: If `input` is not subnormal and !IsClose(expected, actual,
+    // error_spec), print an error.
+    if (subnormal_test_inputs.size() == 1) {
+      PrintMismatch(&mismatches, [&] {
+        return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
+                               StringifyNum(inputs), StringifyNum(expected),
+                               StringifyNum(actual));
+      });
+      continue;
+    }
+
+    // Otherwise, we need to test the additional subnormal test values.
+    std::vector<NativeRefT> subnormal_test_results;
+    subnormal_test_results.reserve(subnormal_test_inputs.size());
+    bool passed_subnormal_test = false;
+
+    for (NativeRefInputs test_value : subnormal_test_inputs) {
+      NativeRefT result;
+      int cache_loc = GetCacheLocation(test_value);
+      if (cache_loc == kInvalidCacheIndex) {
+        result = CallOperation(evaluate_op, test_value);
+      } else {
+        result = pure_subnormal_cache[cache_loc];
+      }
+
+      if (IsClose(result, static_cast<NativeRefT>(actual), error_spec)) {
+        passed_subnormal_test = true;
+        break;
+      }
+      subnormal_test_results.push_back(std::move(result));
+    }
+
+    if (passed_subnormal_test) {
+      continue;
+    }
+
+    std::string mismatch = absl::StrFormat(
+        "Mismatch on subnormal value %s.  Expected one of:\n"
+        "  %10s (evaluated at full-precision value)\n",
+        StringifyNum(inputs), StringifyNum(expected));
+
+    CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size());
+    for (int i = 0; i < subnormal_test_inputs.size(); ++i) {
+      absl::StrAppend(
+          &mismatch,
+          absl::StrFormat("  %10s (evaluated at %s)\n",
+                          StringifyNum(subnormal_test_results[i]),
+                          GetSubnormalDescription(subnormal_test_inputs[i],
+                                                  inputs_ref_ty)));
+    }
+    absl::StrAppend(&mismatch,
+                    absl::StrFormat("but got %s", StringifyNum(actual)));
+
+    PrintMismatch(&mismatches, [mismatch] { return mismatch; });
+  }
+  EXPECT_EQ(mismatches, 0);
 }
 
 namespace {
-ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
-  return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
+template <PrimitiveType T, size_t N>
+inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
+    typename ExhaustiveOpTestBase<T, N>::NativeT) {
+  LOG(FATAL) << "Unhandled Type";
 }
 
-ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
-  return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
+template <PrimitiveType T, size_t N>
+inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
+    typename ExhaustiveOpTestBase<T, N>::NativeT,
+    typename ExhaustiveOpTestBase<T, N>::NativeT) {
+  LOG(FATAL) << "Unhandled Type";
 }
 
-ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
-  return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001);
+template <>
+inline ExhaustiveOpTestBase<C128, 1>::ErrorSpec DefaultSpecGenerator<C128, 1>(
+    complex128) {
+  return ExhaustiveOpTestBase<C128, 1>::ErrorSpec{0.0001, 0.0001};
 }
 
-ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
-  return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02);
+template <>
+inline ExhaustiveOpTestBase<C64, 1>::ErrorSpec DefaultSpecGenerator<C64, 1>(
+    complex64) {
+  return ExhaustiveOpTestBase<C64, 1>::ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F64, 1>::ErrorSpec DefaultSpecGenerator<F64, 1>(
+    double) {
+  return ExhaustiveOpTestBase<F64, 1>::ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F32, 1>::ErrorSpec DefaultSpecGenerator<F32, 1>(
+    float) {
+  return ExhaustiveOpTestBase<F32, 1>::ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F16, 1>::ErrorSpec DefaultSpecGenerator<F16, 1>(
+    Eigen::half) {
+  return ExhaustiveOpTestBase<F16, 1>::ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<BF16, 1>::ErrorSpec DefaultSpecGenerator<BF16, 1>(
+    bfloat16) {
+  return ExhaustiveOpTestBase<BF16, 1>::ErrorSpec{0.002, 0.02};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F64, 2>::ErrorSpec DefaultSpecGenerator<F64, 2>(
+    double, double) {
+  return ExhaustiveOpTestBase<F64, 2>::ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F32, 2>::ErrorSpec DefaultSpecGenerator<F32, 2>(
+    float, float) {
+  return ExhaustiveOpTestBase<F32, 2>::ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<F16, 2>::ErrorSpec DefaultSpecGenerator<F16, 2>(
+    Eigen::half, Eigen::half) {
+  return ExhaustiveOpTestBase<F16, 2>::ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ExhaustiveOpTestBase<BF16, 2>::ErrorSpec DefaultSpecGenerator<BF16, 2>(
+    bfloat16, bfloat16) {
+  return ExhaustiveOpTestBase<BF16, 2>::ErrorSpec{0.002, 0.02};
 }
 }  // namespace
 
 /*static*/
-std::function<ExhaustiveOpTestBase::ErrorSpec(float)>
-ExhaustiveOpTestBase::GetDefaultSpecGenerator(PrimitiveType ty) {
-  switch (ty) {
-    case C128:
-    case F64:
-      return DefaultF64SpecGenerator;
-    case C64:
-    case F32:
-      return DefaultF32SpecGenerator;
-    case F16:
-      return DefaultF16SpecGenerator;
-    case BF16:
-      return DefaultBF16SpecGenerator;
-    default:
-      LOG(FATAL) << "Unhandled Type";
-  }
+template <PrimitiveType T, size_t N>
+typename ExhaustiveOpTestBase<T, N>::ErrorSpecGen
+ExhaustiveOpTestBase<T, N>::GetDefaultSpecGenerator() {
+  return DefaultSpecGenerator<T, N>;
 }
 
+template class ExhaustiveOpTestBase<C128, 1>;
+template class ExhaustiveOpTestBase<C64, 1>;
+template class ExhaustiveOpTestBase<F64, 1>;
+template class ExhaustiveOpTestBase<F32, 1>;
+template class ExhaustiveOpTestBase<F16, 1>;
+template class ExhaustiveOpTestBase<BF16, 1>;
+
+template class ExhaustiveOpTestBase<F64, 2>;
+template class ExhaustiveOpTestBase<F32, 2>;
+template class ExhaustiveOpTestBase<F16, 2>;
+template class ExhaustiveOpTestBase<BF16, 2>;
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
index be16fdd..d66da60 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
@@ -28,28 +28,10 @@
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 
 namespace xla {
-using Eigen::half;
 
-namespace test_util {
-template <int N>
-struct IntegralTypeWithByteWidth {};
-
-template <>
-struct IntegralTypeWithByteWidth<2> {
-  using type = uint16;
-};
-
-template <>
-struct IntegralTypeWithByteWidth<4> {
-  using type = uint32;
-};
-
-template <>
-struct IntegralTypeWithByteWidth<8> {
-  using type = uint64;
-};
-}  // namespace test_util
-
+// T: The primitive type being tested.
+// N: The number of operands that the function being tested takes.
+template <PrimitiveType T, size_t N>
 class ExhaustiveOpTestBase : public ClientLibraryTestBase {
  public:
   struct ErrorSpec {
@@ -65,9 +47,120 @@
     ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
   };
 
-  // `ty` is the primitive type being tested.
-  explicit ExhaustiveOpTestBase(PrimitiveType ty)
-      : ty_(ty), platform_(client_->platform()->Name()) {
+  // Definitions depending on the primitive type T.
+
+  static constexpr bool kIsComplex = (T == C128 || T == C64);
+
+  // The primitive type used to compute the reference output.
+  struct RefT {
+    static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T;
+  };
+
+  // The primitive type of the component of T. If T is not complex, then
+  // ComponentT = T.
+  struct ComponentT {
+    static constexpr PrimitiveType value =
+        !kIsComplex ? T
+                    : T == C128 ? F64 : T == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
+  };
+
+  // Same as ComponentT, but for the RefT primitive type.
+  struct ComponentRefT {
+    static constexpr PrimitiveType value =
+        !kIsComplex ? RefT::value
+                    : RefT::value == C128
+                          ? F64
+                          : RefT::value == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
+  };
+
+  // The primitive type of an unsigned integer that can be bitcasted to and from
+  // ComponentT.
+  struct ComponentIntegralT {
+    static constexpr PrimitiveType value =
+        (T == C128 || T == F64)
+            ? U64
+            : (T == C64 || T == F32)
+                  ? U32
+                  : (T == F16 || T == BF16) ? U16 : PRIMITIVE_TYPE_INVALID;
+  };
+
+  // Native types that correspond to the primtive types above.
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
+  using NativeRefT =
+      typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
+  using ComponentNativeT =
+      typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
+  using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
+      ComponentRefT::value>::type;
+  using ComponentIntegralNativeT =
+      typename primitive_util::PrimitiveTypeToNative<
+          ComponentIntegralT::value>::type;
+
+  using InputLiterals = std::array<Literal, N>;
+
+ private:
+  // N spans corresponding to the list of literal data values.
+  using NativeInputsList = std::array<absl::Span<const NativeT>, N>;
+
+  // N data items representing a single input to an XLA function.
+  using NativeInputs = std::array<NativeT, N>;
+
+  // N data items representing a single input to an interpreter backend
+  // function.
+  using NativeRefInputs = std::array<NativeRefT, N>;
+
+  // N data items representing a single input to an XLA function.
+  using XlaInputs = std::array<XlaOp, N>;
+
+  // Representations of the reference function passed in by the user.
+  template <size_t K>
+  struct EvaluateOpWrapper {};
+  template <>
+  struct EvaluateOpWrapper<1> {
+    using type = NativeRefT (*)(NativeRefT);
+  };
+  template <>
+  struct EvaluateOpWrapper<2> {
+    using type = NativeRefT (*)(NativeRefT, NativeRefT);
+  };
+
+  // Representations of the reference function passed in by the user.
+  template <size_t K>
+  struct EnqueueOpWrapper {};
+  template <>
+  struct EnqueueOpWrapper<1> {
+    using type = std::function<XlaOp(XlaOp)>;
+    static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
+      return ty(inputs[0]);
+    }
+  };
+  template <>
+  struct EnqueueOpWrapper<2> {
+    using type = std::function<XlaOp(XlaOp, XlaOp)>;
+    static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
+      return ty(inputs[0], inputs[1]);
+    }
+  };
+
+  // Representations of the ErrorSpecGen function passed in by the user.
+  template <size_t K>
+  struct ErrorSpecGenWrapper {};
+  template <>
+  struct ErrorSpecGenWrapper<1> {
+    using type = ErrorSpec (*)(NativeT);
+  };
+  template <>
+  struct ErrorSpecGenWrapper<2> {
+    using type = ErrorSpec (*)(NativeT, NativeT);
+  };
+
+ public:
+  using ErrorSpecGen = typename ErrorSpecGenWrapper<N>::type;
+  using EvaluateOp = typename EvaluateOpWrapper<N>::type;
+  using EnqueueOp = typename EnqueueOpWrapper<N>::type;
+
+  explicit ExhaustiveOpTestBase()
+      : ty_(T), platform_(client_->platform()->Name()) {
     SetFastMathDisabled(true);
 
     // Run all HLO passes.  In particular, constant folding is disabled by
@@ -75,6 +168,62 @@
     mutable_debug_options()->clear_xla_disable_hlo_passes();
   }
 
+  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
+    Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator());
+  }
+
+  // A helper for implementing the Run method for exhaustive op tests. It
+  // constructs the HLO module, compiles and runs the module and checks the
+  // result.
+  //
+  // We use a function pointer for evaluate_op for performance because it is
+  // called each time an output element is compared inside a loop in routine
+  // ExpectNear.
+  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
+           ErrorSpecGen error_spec_gen) {
+    InputLiterals input_literals = CreateInputLiterals();
+    FillInput(&input_literals);
+
+    XlaBuilder builder(TestName());
+    XlaInputs xla_inputs;
+    for (int i = 0; i < N; ++i) {
+      xla_inputs[i] =
+          Parameter(&builder, i, input_literals[i].shape(), "input");
+    }
+    EnqueueOpWrapper<N>::BuildFromInputs(xla_inputs, enqueue_op);
+
+    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
+    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+                            RunComputationHelper(comp, input_literals));
+    ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen);
+  }
+
+  StatusOr<Literal> RunComputationHelper(const XlaComputation& comp,
+                                         const Literal& literal) {
+    return RunComputation(comp, {&literal});
+  }
+
+  StatusOr<Literal> RunComputationHelper(
+      const XlaComputation& comp, const std::array<Literal, N>& literals) {
+    std::array<const Literal*, N> lit_ptrs;
+    for (int i = 0; i < N; ++i) {
+      lit_ptrs[i] = &literals[i];
+    }
+    return RunComputation(comp, lit_ptrs);
+  }
+
+  // We essentially reimplement LiteralTestUtil::Near here because
+  //  a) this streamlined implementation is much faster, and
+  //  b) we can print out better error messages (namely, we can print out
+  //     which floating-point value input failed, while LiteralTestUtil::Near
+  //     can only print out the input index that failed).
+  //  c) we need special handling of certain inputs.  For example, we say that
+  //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
+  //     and just needs to be close to one of them.
+  void ExpectNear(const InputLiterals& input_literals,
+                  const Literal& result_literal, EvaluateOp evaluate_op,
+                  ErrorSpecGen error_spec_gen);
+
   // Builds and runs the computation using the LocalClient API, rather than the
   // plain Client API, which is used by ClientLibraryTestBase.  This is because
   // the plain Client API results does more memcpys to/from Literals, and that's
@@ -122,34 +271,395 @@
     return std::move(result_literal);
   }
 
+  const string& Platform() { return platform_; }
+
   // Returns the number of elements in each input literal.
   virtual int64 GetInputSize() = 0;
 
-  Literal CreateInputLiteral() {
-    return LiteralUtil::CreateFromDimensions(ty_, {GetInputSize()});
+  // Fills the literals with values to test for.
+  virtual void FillInput(InputLiterals* literals) = 0;
+
+  // Replace infinites with max value to help compute errors.
+  static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
+    if (std::isinf(value)) {
+      return std::copysign(std::numeric_limits<ComponentNativeRefT>::max(),
+                           value);
+    }
+    return value;
   }
 
-  // `T` is the type of the value being compared, which is float if ty_ is of 32
-  // bits or less, and double otherwise.
-  template <typename T>
-  bool IsClose(T expected, T actual, ErrorSpec spec) {
-    static_assert(
-        std::is_same<T, float>::value || std::is_same<T, double>::value,
-        "Only supports float and double.");
+  // Returns true if both components are 0, but their sign bits differ.
+  static bool CheckSignedZeroError(ComponentNativeRefT expected,
+                                   ComponentNativeRefT actual) {
+    return expected == 0 && actual == 0 &&
+           std::signbit(expected) != std::signbit(actual);
+  }
+
+  // Sets the components to 0 if both are NaNs.
+  static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected,
+                                      ComponentNativeRefT* actual) {
+    if (std::isnan(*expected) && std::isnan(*actual)) {
+      *expected = 0;
+      *actual = 0;
+    }
+  }
+
+  // The Implementation of the functions above, except for complex inputs.
+
+  static std::complex<ComponentNativeRefT> ReplaceInfWithMax(
+      std::complex<ComponentNativeRefT> value) {
+    value.real(ReplaceInfWithMax(value.real()));
+    value.imag(ReplaceInfWithMax(value.imag()));
+    return value;
+  }
+
+  static bool CheckSignedZeroError(std::complex<ComponentNativeRefT> expected,
+                                   std::complex<ComponentNativeRefT> actual) {
+    return CheckSignedZeroError(expected.real(), actual.real()) ||
+           CheckSignedZeroError(expected.imag(), actual.imag());
+  }
+
+  static void RemoveCorrespondingNaNs(
+      std::complex<ComponentNativeRefT>* expected,
+      std::complex<ComponentNativeRefT>* actual) {
+    ComponentNativeRefT expected_real = expected->real();
+    ComponentNativeRefT expected_imag = expected->imag();
+    ComponentNativeRefT actual_real = actual->real();
+    ComponentNativeRefT actual_imag = actual->imag();
+    RemoveCorrespondingNaNs(&expected_real, &actual_real);
+    RemoveCorrespondingNaNs(&expected_imag, &actual_imag);
+    expected->real(expected_real);
+    expected->imag(expected_imag);
+    actual->real(actual_real);
+    actual->imag(actual_imag);
+  }
+
+  // Returns a list of inputs that should be tested for closeness given some
+  // original input values.
+  //
+  // For denormal component inputs, we accept answers that are close to any of:
+  //
+  //   - evaluate_op(input)
+  //   - evaluate_op(+/-0), where the sign of 0 equal to the sign of
+  //     `input`,
+  //   - evaluate_op(+/-min_normal_float), where the sign of
+  //     min_normal_float matches `input`.
+  //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
+  //     0 is the opposite of `input`.
+  //
+  // (In particular, the XLA:CPU implementation of log flushes positive
+  // denormals to min-normal-float.  This seems kind of reasonable if our
+  // goal is to avoid infinities because they cause nans?)
+  std::vector<ComponentNativeRefT> GetTestValuesWithSubnormalSubstitutions(
+      ComponentNativeRefT value) {
+    std::vector<ComponentNativeRefT> test_values;
+    if (std::fpclassify(value) == FP_SUBNORMAL) {
+      test_values.reserve(relaxed_denormal_signs_ ? 3 : 2);
+      test_values.push_back(std::copysign(0, value));
+      test_values.push_back(std::copysign(
+          std::numeric_limits<ComponentNativeRefT>::min(), value));
+      if (relaxed_denormal_signs_) {
+        test_values.push_back(std::copysign(0, -value));
+      }
+    } else {
+      test_values.push_back(value);
+    }
+    return test_values;
+  }
+
+  // Similar to complex numbers, we only need to test the components that are
+  // subnormal. We can find the subnormal testing values for each component,
+  // then take the Cartesian product of each set of component values.
+  std::vector<std::complex<ComponentNativeRefT>>
+  GetTestValuesWithSubnormalSubstitutions(
+      std::complex<ComponentNativeRefT> value) {
+    using complex = std::complex<ComponentNativeRefT>;
+
+    auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
+    auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());
+
+    std::vector<complex> test_values;
+    test_values.reserve(real_values.size() * imag_values.size());
+    for (auto real : real_values) {
+      for (auto imag : imag_values) {
+        test_values.push_back(complex(real, imag));
+      }
+    }
+
+    return test_values;
+  }
+
+  // The test values for an XLA function with N operands are the Cartesian
+  // product of the test values for each of the N operands.
+  std::vector<std::array<NativeRefT, N>>
+  GetTestValuesWithSubnormalSubstitutions(
+      const std::array<NativeRefT, N>& value) {
+    std::vector<std::array<NativeRefT, N>> test_values;
+
+    std::array<std::vector<NativeRefT>, N> component_test_values;
+    int total = 1;
+    for (int i = 0; i < N; ++i) {
+      component_test_values[i] =
+          GetTestValuesWithSubnormalSubstitutions(value[i]);
+      if (!component_test_values.empty()) {
+        total *= component_test_values[i].size();
+      }
+    }
+
+    // If total == 1, then value has no subnormal components, so we can just
+    // return a vector with value in it.
+    if (total == 1) {
+      test_values.push_back(value);
+      return test_values;
+    }
+
+    test_values.reserve(total);
+
+    // Perform a Cartesian product of the vectors in component_test_values.
+    // We can calculate this by uniquely mapping each integer from 0 to
+    // (total - 1) to a list of component indices. The function that maps an
+    // integer z to the index of component j is:
+    //    component_index(j) =  (i / NumValues(0, j-1)) % NumValues(j, j)
+    // and NumIndices(x, y) is the number of values in the Cartesian product of
+    // component_test_values[x], component_test_values[x+1], ...
+    // component_test_values[y].
+    for (int i = 0; i < total; ++i) {
+      int accumulated_num_values = 1;
+      std::array<NativeRefT, N> test_value;
+      for (int j = 0; j < N; ++j) {
+        int num_indices = component_test_values[j].size();
+        int component_index = (i / accumulated_num_values) % num_indices;
+        test_value[j] = component_test_values[j][component_index];
+        accumulated_num_values *= num_indices;
+      }
+      test_values.push_back(std::move(test_value));
+    }
+    return test_values;
+  }
+
+  // The number of values that can be substituted for subnormal inputs.
+  static constexpr int kNumSubnormalSubstitutionValues = 4;
+
+  // Encodings used to determine where subnormal test values are cached.
+  static constexpr int kPositiveMin = 0;
+  static constexpr int kNegativeMin = 1;
+  static constexpr int kPositiveZero = 2;
+  static constexpr int kNegativeZero = 3;
+  static constexpr int kNonSubnormal = -1;
+  static constexpr int kInvalidCacheIndex = -1;
+
+  // Since we take the cross product of all possible test values, and each
+  // component has kNumSubnormalSubstitutionValues possible test values, then
+  // the total number of different cache locations are
+  // kNumSubnormalSubstitutionValues raised to the num_components.
+  // num_components = N for the reals, and 2*N for the complex.
+  static constexpr int GetMaxCacheSize() {
+    return pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
+  }
+
+  // When we are testing a value such that all of its components are subnormal,
+  // we also need to test inputs made up of the Cartesian product of values
+  // replaced for each subnormal component. These additional test inputs are
+  // common enough where it will be efficient to just cache the results of these
+  // Cartesian products. In order to cache these values, we need a one to one
+  // mapping between these Cartesian products and cache locations.
+  //
+  // Our mapping works by assigning each component an integer in
+  // [0, kNumSubnormalSubstitutionValues) based on its test value. By lining
+  // these integers up with the n'th component corresponding to the n'th digit,
+  // then for each Cartesian product element we essentially create a unique base
+  // kNumSubnormalSubstitutionValues number. This number represents our cache
+  // index.
+  //
+  // In the event that there a component is not a subnormal, the value should
+  // not be cached, so we return a kNonSubnormal value.
+
+  static int GetCacheLocation(ComponentNativeRefT value) {
+    bool positive = !std::signbit(value);
+    if (std::abs(value) == std::numeric_limits<ComponentNativeRefT>::min()) {
+      if (positive) {
+        return kPositiveMin;
+      } else {
+        return kNegativeMin;
+      }
+    } else if (value != 0) {
+      CHECK(std::fpclassify(value) != FP_SUBNORMAL);
+      return kNonSubnormal;
+    } else if (positive) {
+      return kPositiveZero;
+    } else {
+      return kNegativeZero;
+    }
+  }
+
+  static int GetCacheLocation(std::complex<ComponentNativeRefT> value) {
+    int real_loc = GetCacheLocation(value.real());
+    int imag_loc = GetCacheLocation(value.imag());
+    if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) {
+      return kNonSubnormal;
+    } else {
+      return real_loc * kNumSubnormalSubstitutionValues + imag_loc;
+    }
+  }
+
+  static int GetCacheLocation(const NativeRefInputs& input) {
+    int location = 0;
+    int cache_size_per_element =
+        (kIsComplex
+             ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues
+             : kNumSubnormalSubstitutionValues);
+    for (int i = 0; i < N; ++i) {
+      int comp_loc = GetCacheLocation(input[i]);
+      if (i == kNonSubnormal) {
+        return kNonSubnormal;
+      }
+      location *= cache_size_per_element;
+      location += comp_loc;
+    }
+    return location;
+  }
+
+  // The inverse function of GetCacheLocation.
+
+  template <bool complex, typename RetT>
+  static RetT FromCacheLocationComponent(int cache_loc) {
+    LOG(FATAL) << "Not implemented.";
+  }
+
+  template <>
+  static ComponentNativeRefT
+  FromCacheLocationComponent<false, ComponentNativeRefT>(int cache_loc) {
+    switch (cache_loc) {
+      case kPositiveMin:
+        return std::numeric_limits<ComponentNativeRefT>::min();
+      case kNegativeMin:
+        return -std::numeric_limits<ComponentNativeRefT>::min();
+      case kPositiveZero:
+        return static_cast<ComponentNativeRefT>(0.0);
+      case kNegativeZero:
+        return static_cast<ComponentNativeRefT>(-0.0);
+      default:
+        LOG(FATAL) << "Invalid cache_loc value of " << cache_loc;
+    }
+  }
+
+  template <>
+  static std::complex<ComponentNativeRefT>
+  FromCacheLocationComponent<true, std::complex<ComponentNativeRefT>>(
+      int cache_loc) {
+    CHECK_LT(cache_loc,
+             kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues);
+    CHECK_GE(cache_loc, 0);
+
+    std::complex<ComponentNativeRefT> value;
+    value.real(FromCacheLocationComponent<false, ComponentNativeRefT>(
+        cache_loc / kNumSubnormalSubstitutionValues));
+    value.imag(FromCacheLocationComponent<false, ComponentNativeRefT>(
+        cache_loc % kNumSubnormalSubstitutionValues));
+    return std::move(value);
+  }
+
+  static NativeRefInputs FromCacheLocation(int cache_loc) {
+    NativeRefInputs input;
+    int cache_size_per_element =
+        (kIsComplex
+             ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues
+             : kNumSubnormalSubstitutionValues);
+    for (int i = N - 1; i >= 0; --i) {
+      input[i] = FromCacheLocationComponent<kIsComplex, NativeRefT>(
+          cache_loc % cache_size_per_element);
+      cache_loc /= cache_size_per_element;
+    }
+
+    return input;
+  }
+
+  // Returns a string that describes the test value for the actual value.
+  std::string GetSubnormalDescription(ComponentNativeRefT test_val,
+                                      ComponentNativeRefT actual_val) {
+    const string sp_min_normal = "sign-preserving min-normal-float";
+    const string sp_zero = "sign-preserving zero";
+    const string nsp_zero = "non-sign-preserving zero";
+
+    switch (GetCacheLocation(test_val)) {
+      case kNegativeMin:
+      case kPositiveMin:
+        return sp_min_normal;
+      case kNegativeZero:
+      case kPositiveZero:
+        return (std::signbit(test_val) == std::signbit(actual_val)) ? sp_zero
+                                                                    : nsp_zero;
+      default:
+        return "";
+    }
+  }
+
+  std::string GetSubnormalDescription(
+      std::complex<ComponentNativeRefT> test_val,
+      std::complex<ComponentNativeRefT> actual_val) {
+    std::string real =
+        GetSubnormalDescription(test_val.real(), actual_val.real());
+    std::string imag =
+        GetSubnormalDescription(test_val.imag(), actual_val.imag());
+
+    if (real.empty()) {
+      if (imag.empty()) {
+        return "";
+      }
+      real = "real";
+    } else if (imag.empty()) {
+      imag = "imag";
+    }
+
+    return absl::StrCat("(", real, ", ", imag, ")");
+  }
+
+  std::string GetSubnormalDescription(std::array<NativeRefT, N> test_vals,
+                                      std::array<NativeRefT, N> actual_vals) {
+    if (N == 1) {
+      return GetSubnormalDescription(test_vals[0], actual_vals[0]);
+    }
+
+    std::array<std::string, N> str_vals;
+    for (int i = 0; i < N; ++i) {
+      str_vals[i] = GetSubnormalDescription(test_vals[i], actual_vals[i]);
+      if (str_vals[i].empty()) {
+        str_vals[i] = "original";
+      }
+    }
+
+    return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
+  }
+
+  InputLiterals CreateInputLiterals() {
+    InputLiterals literals;
+    for (int i = 0; i < N; ++i) {
+      literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
+    }
+    return std::move(literals);
+  }
+
+  // Determines if two output values are sufficiently close to each other based
+  // on an error spec.
+  bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) {
+    // When two corresponding values are a NaN, they can be considered to have
+    // the same value, so the values are just set to 0.
+    RemoveCorrespondingNaNs(&expected, &actual);
+
+    if (spec.strict_signed_zeros) {
+      if (CheckSignedZeroError(expected, actual)) {
+        return false;
+      }
+    }
+
     // Replace Inf with Max when calculating absolute or relative errors. This
     // allows the test to pass when another value are close to Inf and the
     // specified absolute or relative errors are not zero.
-    T abs_err =
+    double abs_err =
         std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
-    T rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));
-    if (spec.strict_signed_zeros && actual == T{0} && expected == T{0}) {
-      // Check sign of zero.
-      return std::signbit(actual) == std::signbit(expected);
-    }
-    return abs_err <= spec.abs_err || rel_err <= spec.rel_err ||
-           (std::isnan(expected) && std::isnan(actual)) ||
-           (std::isinf(expected) && std::isinf(actual) &&
-            (expected > 0) == (actual > 0));
+    double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));
+
+    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
   }
 
   template <typename ErrorGenerator>
@@ -180,57 +690,57 @@
   // bit patterns for T. This bit pattern is zero extended and stored as uint64.
   // This function is used to convert such a bit pattern stored as uint64 to
   // the input value for T.
-  //
-  // T is the type of the floating value represented by the `bits`.
-  template <typename T>
-  T ConvertValue(uint64 bits) {
-    using I = typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
+  static ComponentNativeT ConvertValue(uint64 bits) {
+    using I = ComponentIntegralNativeT;
     I used_bits = static_cast<I>(bits);
-    return BitCast<T>(used_bits);
+    return BitCast<ComponentNativeT>(used_bits);
   }
 
-  template <typename T>
-  T ConvertAndReplaceKnownIncorrectValueWith(uint64 bits,
-                                             int replacement_value = 0) {
+  ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith(
+      uint64 bits, int replacement_value = 0) {
     if (known_incorrect_fn_ && known_incorrect_fn_(bits)) {
-      return static_cast<T>(replacement_value);
+      return static_cast<ComponentNativeT>(replacement_value);
     }
-    return ConvertValue<T>(bits);
+    return ConvertValue(bits);
   }
 
-  static string StringifyNum(float x);
+  static string StringifyNum(ComponentNativeT x);
 
-  static string StringifyNum(half x);
-
-  static string StringifyNum(bfloat16 x);
-
-  template <typename T>
-  static string StringifyNum(std::complex<T> x) {
-    return absl::StrCat(StringifyNum(x.real()), " ", StringifyNum(x.imag()));
+  static string StringifyNum(std::complex<ComponentNativeT> x) {
+    return absl::StrCat("(", StringifyNum(x.real()), ", ",
+                        StringifyNum(x.imag()), ")");
   }
 
-  template <typename T>
-  static void AppendStringifyNum(std::string* s, T x) {
+  // We also stringify the NativeRefT, so we need to generate an additional
+  // version of this function when NativeRefT != NativeT.
+  template <
+      typename T1 = NativeRefT,
+      class = typename std::enable_if<!std::is_same<NativeT, T1>::value>::type>
+  static string StringifyNum(NativeRefT x) {
+    return ExhaustiveOpTestBase<RefT::value, N>::StringifyNum(x);
+  }
+
+  static string StringifyNum(const NativeInputs& inputs) {
+    if (N == 1) {
+      return StringifyNum(inputs[0]);
+    }
+
+    std::array<std::string, N> str_vals;
+    for (int i = 0; i < N; ++i) {
+      str_vals[i] = StringifyNum(inputs[i]);
+    }
+
+    return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
+  }
+
+  static void AppendStringifyNum(std::string* s, NativeT x) {
     absl::StrAppend(s, StringifyNum(x));
   }
 
-  static std::function<ErrorSpec(float)> GetDefaultSpecGenerator(
-      PrimitiveType ty);
-
-  static std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges();
-
- private:
-  template <typename T>
-  T ReplaceInfWithMax(T value) {
-    if (std::isinf(value)) {
-      return std::copysign(std::numeric_limits<T>::max(), value);
-    }
-
-    return value;
-  }
+  static ErrorSpecGen GetDefaultSpecGenerator();
 
  protected:
-  // The primitive type under test.
+  // The primitive type being tested.
   const PrimitiveType ty_;
 
   // The platform under test.
@@ -249,6 +759,30 @@
   //
   // XLA:GPU preserves denormal signs, but other backends don't.
   bool relaxed_denormal_signs_ = platform_ != "CUDA";
+
+ private:
+  using EvaluateOpInternal = NativeRefT (*)(NativeRefInputs);
+  using ErrorSpecGenInternal = ErrorSpec (*)(NativeInputs);
+
+  template <typename Type, typename FuncPtr>
+  ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
+    return func(in[0]);
+  }
+
+  template <typename Type, typename FuncPtr>
+  ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 2>& in) {
+    return func(in[0], in[1]);
+  }
+
+  template <typename Type, typename FuncPtr>
+  Type CallOperation(FuncPtr* func, const std::array<Type, 1>& in) {
+    return func(in[0]);
+  }
+
+  template <typename Type, typename FuncPtr>
+  Type CallOperation(FuncPtr* func, const std::array<Type, 2>& in) {
+    return func(in[0], in[1]);
+  }
 };
 
 // Represents a set of 64 bit chunks by representing the starting bit chunk,
@@ -467,6 +1001,7 @@
     const FpValues* fp_values_;
   };
 
+  FpValues() : bit_chunks_(), offsets_() {}
   FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
     CHECK_EQ(chunks.size(), offsets.size() - 1);
     CHECK_EQ(chunks.size(), kTotalBitChunks);
@@ -515,10 +1050,10 @@
   std::array<int, kTotalBitChunks + 1> offsets_;
 };
 
-template <typename T>
+template <typename T, typename std::enable_if<
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value>::type* = nullptr>
 int GetMantissaTotalBits() {
-  static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
-                "Only supports float and double.");
   return std::numeric_limits<T>::digits - 1;
 }
 
@@ -542,10 +1077,10 @@
   return (1ull << GetExponentTotalBits<T>()) - 1ull;
 }
 
-template <typename T>
+template <typename T, typename std::enable_if<
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value>::type* = nullptr>
 FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
-  static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
-                "Only supports float and double.");
   int total_bits = GetFpTotalBits<T>();
   return FpValues({mantissa, exponent, sign},
                   {0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
@@ -656,5 +1191,16 @@
           GetNans<T>(1000)};
 }
 
+inline std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
+  // We break up the 2^32-element space into small'ish chunks to keep peak
+  // memory usage low.
+  std::vector<std::pair<int64, int64>> result;
+  const int64 step = 1 << 25;
+  for (int64 i = 0; i < (1l << 32); i += step) {
+    result.push_back({i, i + step});
+  }
+  return result;
+}
+
 }  // namespace xla
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
index 4019a5f..f8eb738 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
@@ -155,154 +155,8 @@
   return result - reflection;
 }
 
-class ExhaustiveRealUnaryTestBase : public ExhaustiveOpTestBase {
- public:
-  explicit ExhaustiveRealUnaryTestBase(PrimitiveType ty)
-      : ExhaustiveOpTestBase(ty) {}
-
-  // A helper for implementing the Run method for unary op test. It constructs
-  // the HLO module, compiles and runs the module and checks the result.
-  //
-  // T: is the input and output data type.
-  // RefT: is the type used for the host function to get the reference result.
-  //  RefT is different from T when T is of less than 32 bits, that is half and
-  //  bfloat16.
-  //
-  // We use a function pointer for evaluate_op for performance because it is
-  // called each time an output element is compared inside a loop in routine
-  // ExpectNear.
-  template <typename T, typename RefT>
-  void RunImpl(std::function<XlaOp(XlaOp)> enqueue_op,
-               RefT (*evaluate_op)(RefT), const Literal& input_literal,
-               std::function<ErrorSpec(float)> error_spec_gen) {
-    XlaBuilder builder(TestName());
-    XlaOp input = Parameter(&builder, 0, input_literal.shape(), "input");
-    enqueue_op(input);
-    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
-    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
-                            RunComputation(comp, {&input_literal}));
-    ExpectNear<T, RefT>(input_literal, result_literal, evaluate_op,
-                        error_spec_gen);
-  }
-
-  // We essentially reimplement LiteralTestUtil::Near here because
-  //  a) this streamlined implementation is much faster, and
-  //  b) we can print out better error messages (namely, we can print out
-  //     which floating-point value input failed, while LiteralTestUtil::Near
-  //     can only print out the input index that failed).
-  //  c) we need special handling of certain inputs.  For example, we say that
-  //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
-  //     and just needs to be close to one of them.
-  template <typename T, typename RefT>
-  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
-                  RefT (*evaluate_op)(RefT),
-                  std::function<ErrorSpec(float)> error_spec_gen) {
-    absl::Span<const T> input_arr = input_literal.data<T>();
-    absl::Span<const T> result_arr = result_literal.data<T>();
-    ASSERT_EQ(result_arr.size(), input_arr.size());
-    int64 mismatches = 0;
-    // Hoisting these out of the loop is a nice speedup on shards that have many
-    // denormals.
-    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
-    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
-    const T expected_at_pos_min_normal_float =
-        static_cast<T>(evaluate_op(std::numeric_limits<RefT>::min()));
-    const T expected_at_neg_min_normal_float =
-        static_cast<T>(evaluate_op(-std::numeric_limits<RefT>::min()));
-
-    for (int64 i = 0; i < input_arr.size(); ++i) {
-      T input = input_arr[i];
-      RefT input_ref_ty = static_cast<RefT>(input);
-      T actual = result_arr[i];
-      T expected = static_cast<T>(evaluate_op(input_ref_ty));
-
-      ErrorSpec error_spec = error_spec_gen(input_ref_ty);
-
-      // We only implement fpclassify for float and double, so we call
-      // IsClose<float> for half and bfloat16.
-      if (IsClose(static_cast<RefT>(expected), static_cast<RefT>(actual),
-                  error_spec)) {
-        continue;
-      }
-
-      // Easy case: If `input` is not denormal and !IsClose(expected, actual,
-      // error_spec), print an error.
-      if (std::fpclassify(input_ref_ty) != FP_SUBNORMAL) {
-        PrintMismatch(&mismatches, [&] {
-          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
-                                 StringifyNum(input), StringifyNum(expected),
-                                 StringifyNum(actual));
-        });
-        continue;
-      }
-
-      // Otherwise, `input` is denormal.  For denormal inputs, we accept answers
-      // that are close to any of:
-      //
-      //   - evaluate_op(input)
-      //   - evaluate_op(+/-0), where the sign of 0 equal to the sign of
-      //     `input`,
-      //   - evaluate_op(+/-min_normal_float), where the sign of
-      //     min_normal_float matches `input`.
-      //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
-      //     0 is the opposite of `input`.
-      //
-      // (In particular, the XLA:CPU implementation of log flushes positive
-      // denormals to min-normal-float.  This seems kind of reasonable if our
-      // goal is to avoid infinities because they cause nans?)
-      T sign_preserving_ftz_expected = std::signbit(input_ref_ty)
-                                           ? expected_at_neg_zero
-                                           : expected_at_pos_zero;
-      T flush_to_normal_expected = std::signbit(input_ref_ty)
-                                       ? expected_at_neg_min_normal_float
-                                       : expected_at_pos_min_normal_float;
-      T sign_nonpreserving_ftz_expected = std::signbit(input_ref_ty)
-                                              ? expected_at_pos_zero
-                                              : expected_at_neg_zero;
-      if (IsClose(static_cast<RefT>(sign_preserving_ftz_expected),
-                  static_cast<RefT>(actual), error_spec) ||
-          IsClose(static_cast<RefT>(flush_to_normal_expected),
-                  static_cast<RefT>(actual), error_spec) ||
-          (relaxed_denormal_signs_ &&
-           IsClose(static_cast<RefT>(sign_nonpreserving_ftz_expected),
-                   static_cast<RefT>(actual), error_spec))) {
-        continue;
-      }
-
-      if (relaxed_denormal_signs_) {
-        PrintMismatch(&mismatches, [&] {
-          return absl::StrFormat(
-              "Mismatch on denormal value %s.  Expected one of:\n"
-              "  %10s (evaluated at full-precision value)\n"
-              "  %10s (evaluated at sign-preserving min-normal-float)\n"
-              "  %10s (evaluated after flushing to sign-preserving zero)\n"
-              "  %10s (evaluated after flushing to non-sign-preserving "
-              "zero)\n"
-              "but got %s.",
-              StringifyNum(input),  //
-              StringifyNum(expected), StringifyNum(flush_to_normal_expected),
-              StringifyNum(sign_preserving_ftz_expected),
-              StringifyNum(sign_nonpreserving_ftz_expected),
-              StringifyNum(actual));
-        });
-      } else {
-        PrintMismatch(&mismatches, [&] {
-          return absl::StrFormat(
-              "Mismatch on denormal value %s.  Expected one of:\n"
-              "  %10s (evaluated at full-precision value)\n"
-              "  %10s (evaluated at sign-preserving min-normal-float)\n"
-              "  %10s (evaluated after flushing to sign-preserving zero)\n"
-              "but got %s.",
-              StringifyNum(input),  //
-              StringifyNum(expected), StringifyNum(flush_to_normal_expected),
-              StringifyNum(sign_preserving_ftz_expected),  //
-              StringifyNum(actual));
-        });
-      }
-    }
-    EXPECT_EQ(mismatches, 0);
-  }
-};
+template <PrimitiveType T>
+using ExhaustiveUnaryTest = ExhaustiveOpTestBase<T, 1>;
 
 // Exhaustive test for unary operations for <= 32bit floating point types.
 //
@@ -310,48 +164,21 @@
 //   - primitive type under test,
 //   - (begin, end) range under test, as zero-extended int64s bitcast to the
 //     primtive type under test.
+template <PrimitiveType T>
 class Exhaustive32BitOrLessUnaryTest
-    : public ExhaustiveRealUnaryTestBase,
-      public ::testing::WithParamInterface<
-          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
+    : public ExhaustiveUnaryTest<T>,
+      public ::testing::WithParamInterface<std::pair<int64, int64>> {
  public:
-  typedef float (*F32EvaluateOp)(float);
-
-  Exhaustive32BitOrLessUnaryTest()
-      : ExhaustiveRealUnaryTestBase(std::get<0>(GetParam())) {}
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, F32EvaluateOp evaluate_op) {
-    return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
-  }
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, F32EvaluateOp evaluate_op,
-           std::function<ErrorSpec(float)> error_spec_gen) {
-    Literal input_literal = CreateInputLiteral();
-    switch (ty_) {
-      case F32:
-        FillInput<float>(&input_literal);
-        return RunImpl<float, float>(enqueue_op, evaluate_op, input_literal,
-                                     error_spec_gen);
-      case F16:
-        FillInput<half>(&input_literal);
-        return RunImpl<half, float>(enqueue_op, evaluate_op, input_literal,
-                                    error_spec_gen);
-      case BF16:
-        FillInput<bfloat16>(&input_literal);
-        return RunImpl<bfloat16, float>(enqueue_op, evaluate_op, input_literal,
-                                        error_spec_gen);
-      default:
-        LOG(FATAL) << "Unhandled type.";
-    }
-  }
-
   // Sets error parameters appropriately for testing sin/cos/tan.
   void SetParamsForSinCosTan();
 
+ protected:
+  using typename ExhaustiveUnaryTest<T>::NativeT;
+
  private:
   int64 GetInputSize() override {
     int64 begin, end;
-    std::tie(begin, end) = std::get<1>(GetParam());
+    std::tie(begin, end) = GetParam();
     VLOG(2) << "Checking range [" << begin << ", " << end << ")";
     return end - begin;
   }
@@ -362,55 +189,64 @@
   // pattern. Each bit representation is first truncated to the integral type of
   // the same bit as the type being tested, if needed, and then bitcasted to the
   // type being tested.
-  template <typename T>
-  void FillInput(Literal* input_literal) {
+  void FillInput(std::array<Literal, 1>* input_literal) override {
     using IntegralT =
-        typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
-    int64 input_size = input_literal->element_count();
+        typename ExhaustiveOpTestBase<T, 1>::ComponentIntegralNativeT;
+    int64 input_size = (*input_literal)[0].element_count();
     int64 begin, end;
-    std::tie(begin, end) = std::get<1>(GetParam());
+    std::tie(begin, end) = GetParam();
     VLOG(2) << "Checking range [" << begin << ", " << end << ")";
     CHECK_EQ(input_size, end - begin);
 
-    absl::Span<T> input_arr = input_literal->data<T>();
+    absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();
     for (int64 i = 0; i < input_size; i++) {
       IntegralT input_val = i + begin;
-      input_arr[i] = ConvertAndReplaceKnownIncorrectValueWith<T>(input_val, 0);
+      input_arr[i] =
+          this->ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
     }
   }
 };
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Log) {
-  auto error_spec_gen = GetDefaultSpecGenerator(ty_);
-  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
-    error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; };
-  }
+typedef Exhaustive32BitOrLessUnaryTest<F32> ExhaustiveF32UnaryTest;
+typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
+typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
 
+#define XLA_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \
+  XLA_TEST_P(ExhaustiveF32UnaryTest, test_name)        \
+  __VA_ARGS__                                          \
+  XLA_TEST_P(ExhaustiveF16UnaryTest, test_name)        \
+  __VA_ARGS__                                          \
+  XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name)       \
+  __VA_ARGS__
+
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
+    error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
+  }
   Run(Log, std::log, error_spec_gen);
-}
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Log1p) {
-  auto error_spec_gen = GetDefaultSpecGenerator(ty_);
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
   if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
-    error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; };
+    error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
   }
-
   Run(Log1p, std::log1p, error_spec_gen);
-}
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Exp) {
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
   // When x < -105, the true value of exp(x) is smaller than the smallest F32,
   // so exp(x) should return exactly 0. We want our implementation of exp to
   // return exactly 0 as well, as not doing so implies either that our
   // implementation of exp is not following the asymptotic behavior that exp(x)
   // approaches 0 as x approaches -inf, or that our implementation is not
   // approaching 0 fast enough.
-  auto default_spec_gen = GetDefaultSpecGenerator(ty_);
-  auto error_spec_gen = [default_spec_gen](float x) {
-    if (x < -105) {
+  ErrorSpecGen error_spec_gen = +[](NativeT x) {
+    if (x < static_cast<NativeT>(-105)) {
       return ErrorSpec{0, 0};
     }
-    return default_spec_gen(x);
+    return GetDefaultSpecGenerator()(x);
   };
 
   // Our CPU implementation of exp returns one incorrect value: says
@@ -428,20 +264,13 @@
   } else {
     Run(Exp, std::exp, error_spec_gen);
   }
-}
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Expm1) {
-  auto default_spec_gen = GetDefaultSpecGenerator(ty_);
-  auto error_spec_gen = [default_spec_gen](float x) {
-    if (x < -105) {
-      return ErrorSpec{0, 0};
-    } else if (std::abs(x) < 5e-6) {
-      // For points around x=0, we should make sure that the result is accurate
-      // within 1 ULP of the value.
-      return ErrorSpec{0, 1.1921e-7};
-    }
-    return default_spec_gen(x);
-  };
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (ty_ == F32) {
+    error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
+  }
 
   // Our CPU implementation of expm1 returns one incorrect value: says
   // exp(88.7228394) = max-float, but the correct answer is inf.  We deem this
@@ -458,65 +287,73 @@
   } else {
     Run(Expm1, std::expm1, error_spec_gen);
   }
-}
+})
 
 // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
 // this *did* find a bug, namely that some backends were assuming sqrt(x) ==
 // pow(x, 0.5), but this is not true for x == -inf.
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, PowOneHalf) {
-  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
-      +[](float x) { return std::pow(x, 0.5f); });
-}
+XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
+  EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); };
+  // TODO(b/123837116): Enable the test for all values after fixing the bug.
+  if (platform_ != "Host" && platform_ != "CUDA") {
+    fn = +[](float x) {
+      if (x == -std::numeric_limits<float>::infinity()) {
+        return std::nanf("");
+      }
+      return std::pow(x, 0.5f);
+    };
+  }
+  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn);
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Rsqrt) {
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, {
   Run(
       Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
-}
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sqrt) {
-  auto default_spec_gen = GetDefaultSpecGenerator(ty_);
-  std::function<ErrorSpec(float)> error_spec_gen;
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
   if (platform_ == "Host" || platform_ == "CUDA") {
-    error_spec_gen = [default_spec_gen](float x) {
-      ErrorSpec spec = default_spec_gen(x);
+    error_spec_gen = +[](NativeT x) {
+      auto spec = GetDefaultSpecGenerator()(x);
       spec.strict_signed_zeros = true;
       return spec;
     };
-  } else {
-    error_spec_gen = default_spec_gen;
   }
 
   Run(Sqrt, std::sqrt, error_spec_gen);
-}
+})
 
 // TODO(jlebar): Test trig functions over complex inputs.
-
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Acosh) {
+XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) {
   // Error inherited from Log, which our implementation of Acosh uses.
-  std::function<ErrorSpec(float)> error_spec_gen;
-  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
-    error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ != "Host" && platform_ != "CUDA") {
+    error_spec_gen = +[](float x) { return ErrorSpec{0.001, 0.001}; };
   }
 
   Run(Acosh, std::acosh, error_spec_gen);
 }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Asinh) {
-  // Error inherited from Log, which our implementation of Asinh uses.
-  std::function<ErrorSpec(float)> error_spec_gen;
-  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
-    error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
+XLA_TEST_P(ExhaustiveF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
+XLA_TEST_P(ExhaustiveBF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
+
+// Tests for Asinh
+XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ != "Host" && platform_ != "CUDA") {
+    error_spec_gen = +[](float x) { return ErrorSpec{0.001, 0.001}; };
   }
+
   Run(Asinh, std::asinh, error_spec_gen);
 }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atanh) { Run(Atanh, std::atanh); }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Acos) { Run(Acos, std::acos); }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Asin) { Run(Asin, std::asin); }
+XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
+XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Cosh) {
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); })
+
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
   // Our cosh implementation incorrectly overflows to inf for +/-89.4159851.
   // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
   // max-float, so we deem this acceptable.
@@ -535,8 +372,9 @@
     };
   }
   Run(Cosh, host_cosh);
-}
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sinh) {
+})
+
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
   // Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851.
   // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
   // max-float, so we deem this acceptable.
@@ -555,76 +393,103 @@
     };
   }
   Run(Sinh, host_sinh);
-}
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Tanh) { Run(Tanh, std::tanh); }
+})
 
-void Exhaustive32BitOrLessUnaryTest::SetParamsForSinCosTan() {
-  if (platform_ == "Host" || platform_ == "CUDA") {
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Tanh, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ == "CUDA") {
+    error_spec_gen = +[](NativeT x) {
+      return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
+                 ? ErrorSpec{0, 0}
+                 : GetDefaultSpecGenerator()(x);
+    };
+  }
+  Run(Tanh, std::tanh, error_spec_gen);
+})
+
+template <PrimitiveType T>
+void Exhaustive32BitOrLessUnaryTest<T>::SetParamsForSinCosTan() {
+  if (this->platform_ == "Host" || this->platform_ == "CUDA") {
     return;
   }
 
   // Non CPU/GPU targets may have used the Cody-Waite range reduction technique
   // and will not provide meaningful results for sin/cos/tan if magnitudes
   // exceed 2**p.
-  if (ty_ == F32) {
-    known_incorrect_fn_ = [](int64 v) {
+  if (T == F32) {
+    this->known_incorrect_fn_ = [](int64 v) {
       float f = BitCast<float>(static_cast<uint32>(v));
       return std::abs(f) > (1 << 13);
     };
-  } else if (ty_ == BF16) {
-    known_incorrect_fn_ = [](int64 v) {
+  } else if (T == BF16) {
+    this->known_incorrect_fn_ = [](int64 v) {
       float f = static_cast<float>(BitCast<bfloat16>(static_cast<uint16>(v)));
       return std::abs(f) > (1 << 13);
     };
   }
 }
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Cos) {
+XLA_TEST_P(ExhaustiveF32UnaryTest, Cos) {
   SetParamsForSinCosTan();
-  std::function<ErrorSpec(float)> error_spec_gen;
-  if (ty_ == F32) {
-    error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
-  }
-  Run(Cos, std::cos, error_spec_gen);
+  Run(
+      Cos, std::cos, +[](NativeT) {
+        return ErrorSpec{0.001, 0.001};
+      });
 }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sin) {
+XLA_TEST_P(ExhaustiveF16UnaryTest, Cos) {
   SetParamsForSinCosTan();
-  std::function<ErrorSpec(float)> error_spec_gen;
-  if (ty_ == F32) {
-    error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
-  }
-  Run(Sin, std::sin, error_spec_gen);
+  Run(Cos, std::cos);
 }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Tan) {
+XLA_TEST_P(ExhaustiveBF16UnaryTest, Cos) {
   SetParamsForSinCosTan();
-  std::function<ErrorSpec(float)> error_spec_gen;
-  if (ty_ == F32) {
-    error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
-  }
-  Run(Tan, std::tan, error_spec_gen);
+  Run(Cos, std::cos);
+}
+
+XLA_TEST_P(ExhaustiveF32UnaryTest, Sin) {
+  SetParamsForSinCosTan();
+  Run(
+      Sin, std::sin, +[](NativeT) {
+        return ErrorSpec{0.001, 0.001};
+      });
+}
+XLA_TEST_P(ExhaustiveF16UnaryTest, Sin) {
+  SetParamsForSinCosTan();
+  Run(Sin, std::sin);
+}
+XLA_TEST_P(ExhaustiveBF16UnaryTest, Sin) {
+  SetParamsForSinCosTan();
+  Run(Sin, std::sin);
+}
+
+XLA_TEST_P(ExhaustiveF32UnaryTest, Tan) {
+  SetParamsForSinCosTan();
+  Run(
+      Tan, std::tan, +[](NativeT) {
+        return ErrorSpec{0.001, 0.001};
+      });
+}
+XLA_TEST_P(ExhaustiveF16UnaryTest, Tan) {
+  SetParamsForSinCosTan();
+  Run(Tan, std::tan);
+}
+XLA_TEST_P(ExhaustiveBF16UnaryTest, Tan) {
+  SetParamsForSinCosTan();
+  Run(Tan, std::tan);
 }
 
 // TODO(jlebar): Enable these.
-// XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atan) { Run(Atan, std::atan); }
-// XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atan2) { Run(Atan2, std::atan2); }
+// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); }
+// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); }
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Erf) { Run(Erf, std::erf); }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Erfc) { Run(Erfc, std::erfc); }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, ErfInv) { Run(ErfInv, HostErfInv); }
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Digamma) {
-  std::function<ErrorSpec(float)> error_spec_gen;
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); })
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); })
+XLA_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); })
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
   if (platform_ != "Host" && platform_ != "CUDA") {
     // TODO(b/123956399): This is a fairly high error, significantly higher than
     // we see on CPU/GPU.
-    error_spec_gen = [](float) { return ErrorSpec{0.01, 0.01}; };
-  } else {
-    error_spec_gen = GetDefaultSpecGenerator(ty_);
+    error_spec_gen = +[](NativeT) { return ErrorSpec{0.01, 0.01}; };
   }
 
   if (platform_ == "CUDA") {
@@ -647,27 +512,25 @@
   } else {
     Run(Digamma, HostDigamma, error_spec_gen);
   }
-}
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Lgamma) {
+})
+
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
   // Our implementation gets within 0.0001 rel error except for ~20 denormal
   // inputs on GPU.  Anyway 0.001 rel error should be good enough for lgamma.
-  auto default_spec_gen = GetDefaultSpecGenerator(ty_);
-  std::function<ErrorSpec(float)> error_spec_gen;
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
   if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
-    error_spec_gen = [default_spec_gen](float x) {
-      ErrorSpec spec = default_spec_gen(x);
+    error_spec_gen = +[](NativeT x) {
+      auto spec = GetDefaultSpecGenerator()(x);
       spec.rel_err = 0.001;
       return spec;
     };
-  } else {
-    error_spec_gen = default_spec_gen;
   }
 
   float (*host_lgamma)(float) = std::lgamma;
   if (platform_ != "Host" && platform_ != "CUDA") {
     // TODO(b/123956399): This is a fairly high error, significantly higher than
     // we see on CPU/GPU.
-    error_spec_gen = [](float) { return ErrorSpec{0.01, 0.01}; };
+    error_spec_gen = +[](NativeT) { return ErrorSpec{0.01, 0.01}; };
 
     // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
     if (ty_ == F32) {
@@ -680,28 +543,25 @@
     }
   }
   Run(Lgamma, host_lgamma, error_spec_gen);
-}
+})
 
-XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Round) { Run(Round, std::round); }
+XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
 
-INSTANTIATE_TEST_SUITE_P(
-    F32, Exhaustive32BitOrLessUnaryTest,
-    ::testing::Combine(::testing::Values(F32),
-                       ::testing::ValuesIn(
-                           ExhaustiveOpTestBase::CreateExhaustiveF32Ranges())));
+#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
+
+INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
+                         ::testing::ValuesIn(CreateExhaustiveF32Ranges()));
 
 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
-INSTANTIATE_TEST_SUITE_P(
-    F16, Exhaustive32BitOrLessUnaryTest,
-    ::testing::Combine(::testing::Values(F16),
-                       ::testing::Values(std::make_pair(0, 1 << 16))));
+INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest,
+                         ::testing::Values(std::make_pair(0, 1 << 16)));
 #endif
 
 #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
-INSTANTIATE_TEST_SUITE_P(
-    BF16, Exhaustive32BitOrLessUnaryTest,
-    ::testing::Combine(::testing::Values(BF16),
-                       ::testing::Values(std::make_pair(0, 1 << 16))));
+INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest,
+                         ::testing::Values(std::make_pair(0, 1 << 16)));
+#endif
+
 #endif
 
 // Exhaustive test for unary operations for double.
@@ -709,44 +569,25 @@
 // Test parameter is a tuple containing
 //   - primitive type under test,
 //   - FpValues representing a set of double values.
-class ExhaustiveF64UnaryTest : public ExhaustiveRealUnaryTestBase,
-                               public ::testing::WithParamInterface<
-                                   std::tuple<PrimitiveType, FpValues>> {
- public:
-  typedef double (*F64EvaluateOp)(double);
 
-  ExhaustiveF64UnaryTest()
-      : ExhaustiveRealUnaryTestBase(std::get<0>(GetParam())) {}
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, F64EvaluateOp evaluate_op) {
-    return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
-  }
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, F64EvaluateOp evaluate_op,
-           std::function<ErrorSpec(float)> error_spec_gen) {
-    CHECK_EQ(ty_, F64);
-    Literal input_literal = CreateInputLiteral();
-    FillInputF64(&input_literal);
-    RunImpl<double, double>(enqueue_op, evaluate_op, input_literal,
-                            error_spec_gen);
-  }
-
+class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
+                               public ::testing::WithParamInterface<FpValues> {
  private:
   int64 GetInputSize() override {
-    FpValues values = std::get<1>(GetParam());
+    FpValues values = GetParam();
     return values.GetTotalNumValues();
   }
 
-  void FillInputF64(Literal* input_literal) {
-    FpValues fp_values = std::get<1>(GetParam());
-    int64 input_size = input_literal->element_count();
+  void FillInput(std::array<Literal, 1>* input_literal) override {
+    FpValues fp_values = GetParam();
+    int64 input_size = (*input_literal)[0].element_count();
     LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
               << input_size;
-    absl::Span<double> input_arr = input_literal->data<double>();
+    absl::Span<double> input_arr = (*input_literal)[0].data<double>();
 
     uint64 i = 0;
     for (auto bits : fp_values) {
-      input_arr[i] = ConvertAndReplaceKnownIncorrectValueWith<double>(bits, 1);
+      input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1);
       ++i;
     }
     CHECK_EQ(i, input_size);
@@ -755,192 +596,193 @@
 
 XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); }
 
-// TODO(bixia): add other unary ops for double
+XLA_TEST_P(ExhaustiveF64UnaryTest, Log1p) { Run(Log1p, std::log1p); }
 
+XLA_TEST_P(ExhaustiveF64UnaryTest, Exp) { Run(Exp, std::exp); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Expm1) { Run(Expm1, std::expm1); }
+
+// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
+XLA_TEST_P(ExhaustiveF64UnaryTest, DISABLED_ON_GPU(PowOneHalf)) {
+  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
+      +[](double x) { return std::pow(x, 0.5); });
+}
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Rsqrt) {
+  Run(
+      Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
+}
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Sqrt) { Run(Sqrt, std::sqrt); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Acosh) { Run(Acosh, std::acosh); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Asinh) { Run(Asinh, std::asinh); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Atanh) { Run(Atanh, std::atanh); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Acos) { Run(Acos, std::acos); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Asin) { Run(Asin, std::asin); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Cosh) { Run(Cosh, std::cosh); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Sinh) { Run(Sinh, std::sinh); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ == "CUDA") {
+    error_spec_gen = +[](NativeT x) {
+      return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
+                 ? ErrorSpec{0, 0}
+                 : GetDefaultSpecGenerator()(x);
+    };
+  }
+  Run(Tanh, std::tanh, error_spec_gen);
+}
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Cos) { Run(Cos, std::cos); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Sin) { Run(Sin, std::sin); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Tan) { Run(Tan, std::tan); }
+
+XLA_TEST_P(ExhaustiveF64UnaryTest, Round) { Run(Round, std::round); }
+
+#if defined(UNARY_TEST_TARGET_F64)
 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
 INSTANTIATE_TEST_SUITE_P(
     SpecialValues, ExhaustiveF64UnaryTest,
-    ::testing::Combine(
-        ::testing::Values(F64),
-        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
+    ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
 
-INSTANTIATE_TEST_SUITE_P(
-    NormalValues, ExhaustiveF64UnaryTest,
-    ::testing::Combine(::testing::Values(F64),
-                       ::testing::Values(GetNormals<double>(1000))));
+INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest,
+                         ::testing::Values(GetNormals<double>(1000)));
 
 // Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
 // keep the peak memory usage low.
 INSTANTIATE_TEST_SUITE_P(
     LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest,
-    ::testing::Combine(
-        ::testing::Values(F64),
-        ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
-            4000000000ull, 16000000))));
+    ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
+        4000000000ull, 16000000)));
+#endif
 #endif
 
-class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
- public:
-  explicit ExhaustiveComplexUnaryTestBase(PrimitiveType ty)
-      : ExhaustiveOpTestBase(ty) {}
+// T is the Primitive Type of the complex number
+// Test parameter is a tuple containing
+//   - primitive type under test,
+//   - two FpValues representing the values for the real and imaginary
+//     components. The complex numbers for the test input is the cartesian
+//     product of the values represented by the two FpValues.
+template <PrimitiveType T>
+class ExhaustiveComplexUnaryTestBase
+    : public ExhaustiveUnaryTest<T>,
+      public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
+ protected:
+  using typename ExhaustiveUnaryTest<T>::NativeT;
 
-  // A helper for implementing the Run method for unary op test of complex
-  // numbers.
-  //
-  // T is the component type of the complex number.
-  template <typename T>
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op,
-           std::complex<T> (*evaluate_op)(std::complex<T>),
-           FpValues* values_real, FpValues* values_imag,
-           std::function<ErrorSpec(float)> error_spec_gen) {
-    Literal input_literal = CreateInputLiteral();
-
-    FillInput<T>(&input_literal, values_real, values_imag);
-
-    XlaBuilder builder(TestName());
-    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
-    enqueue_op(input);
-    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
-    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
-                            RunComputation(comp, {&input_literal}));
-    ExpectNearComplex<T>(input_literal, result_literal, evaluate_op,
-                         error_spec_gen);
+  void SetParamsForTanh() {
+    // TODO(b/138126045): Current libc++ implementation of the complex tanh
+    //                    function returns (NaN, NaN) when the imaginary
+    //                    component is more than half of the max value.
+    // TODO(b/138750327): Current libc++ implementation of the complex tanh
+    //                    function returns (1, 0) when the real component is
+    //                    negative infinity, when it should return (-1, 0).
+    // We only need to set the former as incorrect values for C128 because when
+    // testing with C64, we first cast our input to a C128 value.
+    this->known_incorrect_fn_ = [&](int64 v) {
+      double f = this->ConvertValue(v);
+      return (T == C128 &&
+              std::abs(f) > std::numeric_limits<float>::max() / 2) ||
+             f == -std::numeric_limits<double>::infinity();
+    };
   }
 
+ private:
   // Generates the input complex literal given the FpValues representation for
   // the real and imaginary components.
-  //
-  // T is the component type of the complex number.
-  template <typename T>
-  void FillInput(Literal* input_literal, FpValues* real_values,
-                 FpValues* imag_values) {
-    VLOG(2) << " testing input total "
-            << real_values->GetTotalNumValues() *
-                   imag_values->GetTotalNumValues()
-            << ", range " << real_values->ToString() << " "
-            << imag_values->ToString();
+  void FillInput(std::array<Literal, 1>* input_literal) override {
+    FpValues real_values = std::get<0>(GetParam());
+    FpValues imag_values = std::get<1>(GetParam());
 
-    absl::Span<std::complex<T>> input_arr =
-        input_literal->data<std::complex<T>>();
+    VLOG(2) << " testing input total "
+            << real_values.GetTotalNumValues() * imag_values.GetTotalNumValues()
+            << ", range " << real_values.ToString() << " "
+            << imag_values.ToString();
+
+    absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();
 
     uint64 i = 0;
-    for (auto real : *real_values) {
-      for (auto imag : *imag_values) {
-        input_arr[i] = std::complex<T>(
-            ConvertAndReplaceKnownIncorrectValueWith<T>(real, 1),
-            ConvertAndReplaceKnownIncorrectValueWith<T>(imag, 1));
+    for (auto real : real_values) {
+      for (auto imag : imag_values) {
+        input_arr[i] =
+            NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1),
+                    this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1));
 
         ++i;
       }
     }
   }
 
-  template <typename T>
-  void ExpectNearComplex(const Literal& input_literal,
-                         const Literal& result_literal,
-                         std::complex<T> (*evaluate_op)(std::complex<T>),
-                         std::function<ErrorSpec(float)> error_spec_gen) {
-    absl::Span<const std::complex<T>> input_arr =
-        input_literal.data<std::complex<T>>();
-    absl::Span<const std::complex<T>> result_arr =
-        result_literal.data<std::complex<T>>();
-    ASSERT_EQ(result_arr.size(), input_arr.size());
-    int64 mismatches = 0;
+  int64 GetInputSize() override {
+    FpValues real_values = std::get<0>(GetParam());
+    FpValues imag_values = std::get<1>(GetParam());
+    return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues();
+  }
+};
 
-    for (int64 i = 0; i < input_arr.size(); ++i) {
-      std::complex<T> input = input_arr[i];
-      std::complex<T> actual = result_arr[i];
-      std::complex<T> expected = evaluate_op(input);
+typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
+typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
 
-      // TODO(bixia): Need to fix error_spec_gen to consider both components.
-      // This only affects the value specific error_spec, and before we fix
-      // this, it means complex operation testing doesn't support value
-      // specific error_spec yet. We delay the fix to this partially because
-      // we don't know whether it is enough for the error_spec to only take
-      // the absolute value of the complex number.
-      ErrorSpec error_spec = error_spec_gen(input.real());
+// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
+XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) {
+  Run(Log, [](complex64 x) { return std::log<float>(x); });
+}
 
-      if (IsClose(expected.real(), actual.real(), error_spec) &&
-          IsClose(expected.imag(), actual.imag(), error_spec)) {
-        continue;
-      }
-
-      // TODO(bixia): Need to handle complex operands with subnormals in
-      // real and/or imaginary components.
-      VLOG(2) << "calculate " << StringifyNum(input) << " ;"
-              << StringifyNum(actual) << "; " << StringifyNum(expected);
-
-      PrintMismatch(&mismatches, [&] {
-        return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
-                               StringifyNum(input), StringifyNum(expected),
-                               StringifyNum(actual));
-      });
+// The current libc++ implementation of the complex tanh function provides
+// less accurate results when the denomenator of a complex tanh is small, due
+// to floating point precision loss. To avoid this issue for complex64 numbers,
+// we cast it to and from a complex128 when computing tanh.
+XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) {
+  SetParamsForTanh();
+  ErrorSpecGen error_spec_gen = +[](complex64 x) {
+    // This implementation of Tanh becomes less accurate when the denominator
+    // is small.
+    if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) {
+      return ErrorSpec{5e-2, 5e-2};
     }
 
-    EXPECT_EQ(mismatches, 0);
-  }
-};
+    return GetDefaultSpecGenerator()(x);
+  };
+  Run(
+      Tanh,
+      +[](complex64 x) {
+        return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
+      },
+      error_spec_gen);
+}
 
-// Unary op test for complex<float>.
-//
-// Test parameter is a tuple containing
-//   - primitive type under test,
-//   - two FpValues representing the values for the real and imaginary
-//     components. The complex numbers for the test input is the cartesian
-//     product of the values represented by the two FpValues.
-class ExhaustiveC64UnaryTest
-    : public ExhaustiveComplexUnaryTestBase,
-      public ::testing::WithParamInterface<
-          std::tuple<PrimitiveType, FpValues, FpValues>> {
- public:
-  typedef complex64 (*C64EvaluateOp)(complex64);
-
-  ExhaustiveC64UnaryTest()
-      : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, C64EvaluateOp evaluate_op) {
-    return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
-  }
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, C64EvaluateOp evaluate_op,
-           std::function<ErrorSpec(float)> error_spec_gen) {
-    FpValues values_real = std::get<1>(GetParam());
-    FpValues values_imag = std::get<2>(GetParam());
-    ExhaustiveComplexUnaryTestBase::Run<float>(
-        enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen);
-  }
-
-  int64 GetInputSize() override {
-    FpValues values_real = std::get<1>(GetParam());
-    FpValues values_imag = std::get<2>(GetParam());
-    return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues();
-  }
-};
-
+#if defined(UNARY_TEST_TARGET_COMPLEX)
 INSTANTIATE_TEST_SUITE_P(
     F32SpecialValues, ExhaustiveC64UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C64),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
 
 INSTANTIATE_TEST_SUITE_P(
     F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C64),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
         ::testing::Values(GetNormals<float>(10000))));
 
 INSTANTIATE_TEST_SUITE_P(
     F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C64), ::testing::Values(GetNormals<float>(10000)),
+        ::testing::Values(GetNormals<float>(10000)),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
 
 INSTANTIATE_TEST_SUITE_P(
     F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
-    ::testing::Combine(::testing::Values(C64),
-                       ::testing::Values(GetNormals<float>(10000)),
+    ::testing::Combine(::testing::Values(GetNormals<float>(10000)),
                        ::testing::Values(GetNormals<float>(10000))));
 
 // Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
@@ -948,84 +790,52 @@
 INSTANTIATE_TEST_SUITE_P(
     F32LargeAndSmallMagnituedNormalValues, ExhaustiveC64UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C64),
         ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
                                                                          4000)),
         ::testing::ValuesIn(
             GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
+#endif
 
-// Unary op test for complex<double>.
-//
-// Test parameter is a tuple containing
-//   - primitive type under test,
-//   - two FpValues representing the values for the real and imaginary
-//     components. The complex numbers for the test input is the cartesian
-//     product of the values represented by the two FpValues.
-class ExhaustiveC128UnaryTest
-    : public ExhaustiveComplexUnaryTestBase,
-      public ::testing::WithParamInterface<
-          std::tuple<PrimitiveType, FpValues, FpValues>> {
- public:
-  typedef complex128 (*C128EvaluateOp)(complex128);
-
-  ExhaustiveC128UnaryTest()
-      : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, C128EvaluateOp evaluate_op) {
-    return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
-  }
-
-  void Run(std::function<XlaOp(XlaOp)> enqueue_op, C128EvaluateOp evaluate_op,
-           std::function<ErrorSpec(float)> error_spec_gen) {
-    FpValues values_real = std::get<1>(GetParam());
-    FpValues values_imag = std::get<2>(GetParam());
-    ExhaustiveComplexUnaryTestBase::Run<double>(
-        enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen);
-  }
-
-  int64 GetInputSize() override {
-    FpValues values_real = std::get<1>(GetParam());
-    FpValues values_imag = std::get<2>(GetParam());
-    return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues();
-  }
-};
 
 XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
-  // TODO(bixia): only test values that are not too big and not too small
-  //             for now and will work on fixing the implementation of XLA
-  //             operations to enable test for other values.
+  // TODO(b/138578313): Enable the test for all values after fixing the bug.
   known_incorrect_fn_ = [&](int64 v) {
-    double f = ConvertValue<double>(v);
-    return std::fpclassify(f) == FP_NAN || std::abs(f) > 5 || std::abs(f) < 1;
+    double f = this->ConvertValue(v);
+    return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
+           std::abs(f) < 1.0e-300;
   };
-  Run(Log, [](complex128 x) { return std::log(x); });
+  Run(Log, [](complex128 x) { return std::log<double>(x); });
 }
 
+XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) {
+  SetParamsForTanh();
+  Run(
+      Tanh, +[](complex128 x) { return std::tanh(x); });
+}
+
+#if defined(UNARY_TEST_TARGET_COMPLEX)
 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
 INSTANTIATE_TEST_SUITE_P(
     SpecialValues, ExhaustiveC128UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C128),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
 
 INSTANTIATE_TEST_SUITE_P(
     SpecialAndNormalValues, ExhaustiveC128UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C128),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
         ::testing::Values(GetNormals<double>(10000))));
 
 INSTANTIATE_TEST_SUITE_P(
     NormalAndSpecialValues, ExhaustiveC128UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C128), ::testing::Values(GetNormals<double>(10000)),
+        ::testing::Values(GetNormals<double>(10000)),
         ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
 
 INSTANTIATE_TEST_SUITE_P(
     F32NormalAndNormalValues, ExhaustiveC128UnaryTest,
-    ::testing::Combine(::testing::Values(C128),
-                       ::testing::Values(GetNormals<double>(10000)),
+    ::testing::Combine(::testing::Values(GetNormals<double>(10000)),
                        ::testing::Values(GetNormals<double>(10000))));
 
 // Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
@@ -1033,11 +843,11 @@
 INSTANTIATE_TEST_SUITE_P(
     LargeAndSmallMagnituedNormalValues, ExhaustiveC128UnaryTest,
     ::testing::Combine(
-        ::testing::Values(C128),
         ::testing::ValuesIn(
             GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
         ::testing::ValuesIn(
             GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
 #endif
+#endif
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
deleted file mode 100644
index 2d0805c..0000000
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ /dev/null
@@ -1,929 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <math.h>
-
-#include <algorithm>
-#include <memory>
-#include <new>
-#include <random>
-#include <utility>
-
-#define EIGEN_USE_THREADS
-
-#include "absl/memory/memory.h"
-#include "absl/types/span.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-namespace {
-
-const int test_width = 2, test_height = 3;
-
-const float test_float_vals[3][test_width][test_height] = {
-    {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
-    {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
-    {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};
-
-// Test whether fusion operations are emitted with no errors and compute
-// accurate outputs.
-class FusionTest : public HloTestBase {
- protected:
-  template <typename T, int Arity>
-  void TestElementwise2D(
-      HloOpcode opcode,
-      absl::optional<ComparisonDirection> direction = absl::nullopt) {
-    // Create a variable for comparisons since they require the direction.
-    bool is_compare = std::is_same<T, bool>::value;
-    Array2D<float> operand_data[Arity];
-    for (int i = 0; i < Arity; ++i) {
-      new (&operand_data[i]) Array2D<float>(test_width, test_height);
-    }
-    Array2D<T> answer_data(test_width, test_height);
-    for (int i = 0; i < test_width; ++i) {
-      for (int j = 0; j < test_height; ++j) {
-        float xs[Arity];
-        for (int k = 0; k < Arity; ++k) {
-          xs[k] = test_float_vals[k][i][j];
-          operand_data[k](i, j) = xs[k];
-        }
-        if (is_compare) {
-          answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
-        } else {
-          answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
-        }
-      }
-    }
-
-    auto builder = HloComputation::Builder(TestName());
-    auto hlo_module = CreateNewVerifiedModule();
-
-    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
-
-    HloInstruction* hlos[4];
-    for (int i = 0; i < Arity; ++i) {
-      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2FromArray2D(operand_data[i])));
-    }
-    auto answer_shape =
-        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
-    std::unique_ptr<HloInstruction> root_hlo;
-    switch (Arity) {
-      case 1:
-        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
-        break;
-      case 2:
-        if (is_compare) {
-          root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
-                                                   hlos[2], *direction);
-        } else {
-          root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
-                                                  hlos[2]);
-        }
-        break;
-      case 3:
-        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
-                                                 hlos[2], hlos[3]);
-        break;
-      default:
-        LOG(FATAL) << "Bad arity: " << Arity;
-    }
-    hlos[0] = builder.AddInstruction(std::move(root_hlo));
-    hlo_module->AddEntryComputation(builder.Build())
-        ->CreateFusionInstruction(
-            absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
-            HloInstruction::FusionKind::kLoop);
-
-    auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
-    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
-    if (primitive_util::IsFloatingPointType(prim_type)) {
-      EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
-    } else {
-      EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
-    }
-  }
-
- private:
-  float ComputeElementwiseAnswerFloat(HloOpcode opcode,
-                                      absl::Span<const float> xs);
-  bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
-                                       absl::Span<const float> xs);
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.add_xla_disable_hlo_passes("layout-assignment");
-    return debug_options;
-  }
-};
-
-float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
-                                                absl::Span<const float> xs) {
-  switch (opcode) {
-    case HloOpcode::kAdd:
-      return xs[0] + xs[1];
-    case HloOpcode::kSubtract:
-      return xs[0] - xs[1];
-    case HloOpcode::kMultiply:
-      return xs[0] * xs[1];
-    case HloOpcode::kDivide:
-      return xs[0] / xs[1];
-    case HloOpcode::kPower:
-      return powf(xs[0], xs[1]);
-    case HloOpcode::kMinimum:
-      return std::min(xs[0], xs[1]);
-    case HloOpcode::kMaximum:
-      return std::max(xs[0], xs[1]);
-    case HloOpcode::kClamp:
-      return std::min(xs[2], std::max(xs[1], xs[0]));
-    default:
-      LOG(FATAL) << "No elementwise opcode: " << opcode;
-  }
-}
-
-bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
-                                                 absl::Span<const float> xs) {
-  switch (direction) {
-    case ComparisonDirection::kEq:
-      return xs[0] == xs[1];
-    case ComparisonDirection::kNe:
-      return xs[0] != xs[1];
-    case ComparisonDirection::kGt:
-      return xs[0] > xs[1];
-    case ComparisonDirection::kLt:
-      return xs[0] < xs[1];
-    case ComparisonDirection::kGe:
-      return xs[0] >= xs[1];
-    case ComparisonDirection::kLe:
-      return xs[0] <= xs[1];
-  }
-}
-
-XLA_TEST_F(FusionTest, Test) {
-  // test expression:
-  // slice(select({{T, F, T}, {F, T, F}},
-  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
-  //                               {{-1.0}, {-1.0}, {-1.0}}),
-  //                     {{1.62, 2.72, 3.14}}) +
-  //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
-  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
-  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
-  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
-  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
-  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
-  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
-      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
-  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
-  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
-  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
-  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
-  auto const10 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
-          {{true, false, true}, {false, true, false}})));
-  auto select11 = builder.AddInstruction(
-      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
-                                    HloOpcode::kSelect, const10, add8, const9));
-  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
-      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
-  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
-  // reverse topological order, so the first element in `instructions_to_fuse`
-  // must be the root.
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(
-          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
-           const4, reshape3, add2, const1, const0},
-          HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(LiteralTestUtil::Near(
-      LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
-      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
-}
-
-// Test whether we emit appropriate code for parameters of fusion instructions.
-XLA_TEST_F(FusionTest, Parameter) {
-  // Build a computation and fuse part of it so the fusion instruction has an
-  // operand parameter.
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
-  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
-  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
-  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
-  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
-  // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
-  // order.
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(LiteralTestUtil::Near(
-      LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
-      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
-}
-
-XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
-  // Tests parallel partitioning of a fusion instruction.
-  // Create shape with random outer dimension size to generate random parallel
-  // partition counts for each test run.
-  const int seed = tensorflow::testing::RandomSeed();
-  LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
-  std::mt19937 generator(seed);
-  std::uniform_int_distribution<int> distribution(128, 1024);
-  const int64 rand_dim0_size = distribution(generator);
-  const int64 dim1_size = 1024;
-  Shape shape =
-      ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
-  // Build simple fusion computation: y = x^2 (elementwise).
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-
-  auto two = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
-  auto x =
-      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
-  auto y = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
-
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
-                                HloInstruction::FusionKind::kLoop);
-  // Compute result.
-  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
-  // Every element of result should be y = x^2 = 4.0.
-  for (int i = 0; i < rand_dim0_size; ++i) {
-    for (int j = 0; j < dim1_size; ++j) {
-      EXPECT_EQ(4.0, result.Get<float>({i, j}));
-    }
-  }
-}
-
-XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
-  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
-  auto broadcast = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
-  // add2 = broadcast(const_vector) + const_array
-  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
-  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
-  auto add2 = builder.AddInstruction(
-      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
-                                   HloOpcode::kAdd, broadcast, const_array));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(LiteralTestUtil::Near(
-      LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
-      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
-}
-
-XLA_TEST_F(FusionTest, ReshapeToScalar) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto single_element_array = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
-  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(S32, {}), single_element_array));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
-  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
-  auto reshape1 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
-  auto reshape1 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape__1by1by1) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
-  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape__) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
-  auto reshape1 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
-  auto reshape1 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Transpose_2by3) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
-  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Transpose_3by3) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
-  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
-                                HloInstruction::FusionKind::kLoop);
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Reverse) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
-  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
-      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, ReverseNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
-  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
-      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
-  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, BroadcastNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
-  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
-      ShapeUtil::MakeShape(S32, {2}), const0, {}));
-  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, SliceNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
-  auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
-      ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
-  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, DynamicSliceNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
-  auto const1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
-  auto dynamic_slice2 =
-      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
-          ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2}));
-  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(
-          /*instructions_to_fuse=*/{negate3, dynamic_slice2},
-          HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, ReshapeNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
-  auto reshape1 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
-  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, TransposeNegate) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
-  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
-  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-std::unique_ptr<HloComputation> MakeReduceTestComputation() {
-  auto builder = HloComputation::Builder("add");
-  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
-      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
-  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
-      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
-  builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
-  return builder.Build();
-}
-
-XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
-  auto hlo_module = CreateNewVerifiedModule();
-  auto builder = HloComputation::Builder(TestName());
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {32}), 0));
-  auto const1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
-  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
-      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
-      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
-                                HloInstruction::FusionKind::kInput);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(496),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) {
-  auto hlo_module = CreateNewVerifiedModule();
-
-  auto builder = HloComputation::Builder(TestName());
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
-  auto const1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
-  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
-      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
-      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
-  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
-  auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = CreateNewVerifiedModule();
-  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
-  auto const1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
-  Window window;
-  ASSERT_TRUE(
-      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
-                                                        "size:2\n"
-                                                        "stride:1\n"
-                                                        "padding_low:0\n"
-                                                        "padding_high:0\n"
-                                                        "window_dilation:1\n"
-                                                        "base_dilation:1\n"
-                                                        "}\n"
-                                                        "dimensions:{\n"
-                                                        "size:2\n"
-                                                        "stride:1\n"
-                                                        "padding_low:0\n"
-                                                        "padding_high:0\n"
-                                                        "window_dilation:1\n"
-                                                        "base_dilation:1\n"
-                                                        "}\n",
-                                                        &window));
-  auto nested_builder = HloComputation::Builder("mul");
-  {
-    auto x = nested_builder.AddInstruction(
-        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
-    auto y = nested_builder.AddInstruction(
-        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
-    nested_builder.AddInstruction(HloInstruction::CreateBinary(
-        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
-  }
-  auto nested_computation =
-      hlo_module->AddEmbeddedComputation(nested_builder.Build());
-  auto reduce_window2 =
-      builder.AddInstruction(HloInstruction::CreateReduceWindow(
-          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
-          nested_computation));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
-                                HloInstruction::FusionKind::kLoop);
-
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
-      ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-// When a constant (or other op) which has multiple users is imported
-// into a fusion, it should remain shared, rather than being duplicated
-// within the fusion.
-XLA_TEST_F(FusionTest, SharedConstant) {
-  auto hlo_module = CreateNewVerifiedModule();
-
-  auto builder = HloComputation::Builder(TestName());
-  auto const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
-  auto const1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
-  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
-  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
-  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
-  auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
-  hlo_module->AddEntryComputation(builder.Build())
-      ->CreateFusionInstruction({add4, add3, add2, add1, const1},
-                                HloInstruction::FusionKind::kLoop);
-
-  HloComputation* entry_comp = hlo_module->entry_computation();
-
-  // entry computation contains the constant(0) and the fusion
-  EXPECT_EQ(entry_comp->instruction_count(), 2);
-
-  // fused instruction contains the constant(2), the parameter, and 4 adds
-  EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
-
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
-                             ExecuteAndTransfer(std::move(hlo_module), {})));
-}
-
-XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
-
-XLA_TEST_F(FusionTest, Subtract2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
-}
-
-XLA_TEST_F(FusionTest, Multiply2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
-}
-
-XLA_TEST_F(FusionTest, Divide2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kDivide);
-}
-
-XLA_TEST_F(FusionTest, Power2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kPower);
-}
-
-XLA_TEST_F(FusionTest, Minimum2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
-}
-
-XLA_TEST_F(FusionTest, Maximum2D) {
-  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
-}
-
-XLA_TEST_F(FusionTest, Equal2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
-}
-
-XLA_TEST_F(FusionTest, Inequal2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
-}
-
-XLA_TEST_F(FusionTest, Greater2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
-}
-
-XLA_TEST_F(FusionTest, Lesser2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
-}
-
-XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
-}
-
-XLA_TEST_F(FusionTest, LesserOrEqual2D) {
-  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
-}
-
-XLA_TEST_F(FusionTest, Clamp2D) {
-  TestElementwise2D<float, 3>(HloOpcode::kClamp);
-}
-
-class FusionClientLibraryTest : public ClientLibraryTestBase {};
-
-XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
-  // On the GPU backend, it's possible to have too many transposes within one
-  // fusion, causing the kernel to run out shared memory and thus not compile.
-  // We want to check that doesn't happen.
-  //
-  // To do this, we create a computation that computes
-  //
-  //   P0 + P0*P1*P1 + P0*P2*P2 ...
-  //
-  // where even parameters have layout 1 and odd parameters have layout 2.
-  //
-  // Our goal is to tempt the backend into creating one giant multi-output
-  // fusion for the whole computation, including the transposes.  Currently
-  // multi-output fusion only fuses fusions, so each of the terms in the sum
-  // needs to be a fusion itself, thus the contortions above.
-  constexpr int kNumParams = 25;
-  XlaBuilder b("ManyLayoutTransformations");
-
-  // This test produces values that overflow int32, which is UB, so use uint32,
-  // where overflow is OK.
-  Array2D<uint32> arr(32, 32);
-  arr.FillUnique();
-  Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
-      LayoutUtil::MakeLayout({0, 1}));
-
-  Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
-      LayoutUtil::MakeLayout({1, 0}));
-
-  XlaOp p0 = AddParam(l1, &b);
-  XlaOp sum = p0;
-  for (int i = 1; i < kNumParams; ++i) {
-    auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
-    sum = sum + p0 * pN * pN;
-  }
-
-  ComputeAndCompare(&b, {});
-}
-
-void BM_ParallelFusion(int num_iters) {
-  // Simple element-wise computation to benchmark parallel task partitioning.
-  tensorflow::testing::StopTiming();
-
-  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
-  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
-  se::StreamExecutorMemoryAllocator allocator(platform, executors);
-
-  const int64 intra_op_parallelism_threads = 24;
-  xla::LocalClientOptions client_options;
-  client_options.set_platform(platform);
-  client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
-  auto client =
-      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
-
-  int device_ordinal = client->default_device_ordinal();
-
-  // Computation shape parameters.
-  const int64 param0_dim0 = 1024;
-  const int64 param0_dim1 = 1024;
-  const int64 param1_dim0 = 1024;
-  const int64 param1_dim1 = 1024;
-  const int64 param2_dim0 = 1024;
-  const int64 param2_dim1 = 1024;
-
-  // Create computation.
-  XlaBuilder builder("ParallelFusion");
-  Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
-  auto param0 = Parameter(&builder, 0, shape0, "param0");
-  Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
-  auto param1 = Parameter(&builder, 1, shape1, "param1");
-  Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
-  auto param2 = Parameter(&builder, 2, shape2, "param2");
-
-  auto x = Mul(param0, param1);
-  Add(x, param2);
-  auto computation = builder.Build().ConsumeValueOrDie();
-
-  // Transfer literals to device.
-  auto param0_literal =
-      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
-  ScopedShapedBuffer buffer0 =
-      client->LiteralToShapedBuffer(param0_literal, device_ordinal)
-          .ConsumeValueOrDie();
-
-  auto param1_literal =
-      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
-  ScopedShapedBuffer buffer1 =
-      client->LiteralToShapedBuffer(param1_literal, device_ordinal)
-          .ConsumeValueOrDie();
-
-  auto param2_literal =
-      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
-  ScopedShapedBuffer buffer2 =
-      client->LiteralToShapedBuffer(param2_literal, device_ordinal)
-          .ConsumeValueOrDie();
-
-  // Build executable.
-  std::unique_ptr<LocalExecutable> executable =
-      client
-          ->Compile(computation,
-                    {&buffer0.on_host_shape(), &buffer1.on_host_shape(),
-                     &buffer2.on_host_shape()},
-                    ExecutableBuildOptions())
-          .ConsumeValueOrDie();
-
-  se::Stream stream(executors[device_ordinal]);
-  stream.Init();
-
-  // Initialize thread pool.
-  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
-                                      intra_op_parallelism_threads);
-  Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
-
-  // Initialize ExecutableRunOptions.
-  ExecutableRunOptions options;
-  options.set_allocator(&allocator).set_stream(&stream);
-  options.set_intra_op_thread_pool(&device);
-
-  // Run some warm-up executions.
-  const int kWarmups = 2;
-  for (int i = 0; i < kWarmups; ++i) {
-    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
-    ASSERT_TRUE(result.ok());
-  }
-
-  // Run benchmark.
-  const int64 total_bytes = param0_dim0 * param0_dim0 +
-                            param1_dim0 * param1_dim0 +
-                            param2_dim0 * param2_dim0;
-  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
-                                      total_bytes * sizeof(float));
-  tensorflow::testing::UseRealTime();
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i < num_iters; ++i) {
-    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
-    ASSERT_TRUE(result.ok());
-  }
-}
-
-BENCHMARK(BM_ParallelFusion);
-
-}  // namespace
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
index 9636df2..c9c2cb7 100644
--- a/tensorflow/compiler/xla/tests/test_macros.h
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -36,6 +36,7 @@
 
 #define DISABLED_ON_CPU(X) X
 #define DISABLED_ON_GPU(X) X
+#define DISABLED_ON_GPU_ROCM(X) X
 #define DISABLED_ON_INTERPRETER(X) X
 
 // We need this macro instead of pasting directly to support nesting
@@ -54,6 +55,12 @@
 #ifdef XLA_TEST_BACKEND_GPU
 # undef DISABLED_ON_GPU
 # define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X)
+
+#if TENSORFLOW_USE_ROCM
+# undef DISABLED_ON_GPU_ROCM
+# define DISABLED_ON_GPU_ROCM(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif  // TENSORFLOW_USE_ROCM
+
 #endif  // XLA_TEST_BACKEND_GPU
 
 #ifdef XLA_TEST_BACKEND_INTERPRETER
diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl
index d91bc72..bfd79b5 100644
--- a/tensorflow/compiler/xla/xla.bzl
+++ b/tensorflow/compiler/xla/xla.bzl
@@ -1,15 +1,15 @@
 """Wrapper around cc_proto_library used inside the XLA codebase."""
 
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "cc_proto_library",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_static",
 )
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 
@@ -48,3 +48,6 @@
 # We link the GPU plugin into the XLA Python extension if CUDA is enabled.
 def xla_python_default_plugins():
     return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"])
+
+def xla_py_test_deps():
+    return []
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 7a40e40..f20ff9a 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -282,7 +282,13 @@
 
   bool xla_gpu_force_conv_nchw = 125;
 
-  // Next id: 127
+  // Paths to files with ptx code.
+  repeated string xla_gpu_ptx_file = 127;
+
+  // Blacklist for cuDNN convolutions.
+  string xla_gpu_cudnn_conv_blacklist_path = 128;
+
+  // Next id: 129
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 1bd6db2..120be3d 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -294,6 +294,10 @@
 
   // The size of the binary code in the executable.
   int64 executable_size_in_bytes = 6;
+
+  // Whether this profile was drawn from a cache of profiles instead of from
+  // execution on the hardware.
+  bool profile_cache_hit = 7;
 }
 
 // Handle given to a user that represents an execution that the user launched
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
index 67402c1..ce61490 100644
--- a/tensorflow/compiler/xrt/BUILD
+++ b/tensorflow/compiler/xrt/BUILD
@@ -8,7 +8,7 @@
 )
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_py",
 )
 
diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.cc b/tensorflow/compiler/xrt/client/xrt_tf_client.cc
index 88d0d25..3c6f54c 100644
--- a/tensorflow/compiler/xrt/client/xrt_tf_client.cc
+++ b/tensorflow/compiler/xrt/client/xrt_tf_client.cc
@@ -440,6 +440,7 @@
 void XrtTensorHandle::Serialize(eager::RemoteTensorHandle* proto) const {
   proto->set_op_id(tensor_id_.first);
   proto->set_output_num(tensor_id_.second);
+  proto->set_device(context_->devices_.at(device_id_).name());
 }
 
 AttrValue MakeAttrValue(std::string s) {
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index b791519..89daa98 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -151,7 +151,7 @@
   xrt::XLAComputation computation_proto;
   OP_REQUIRES(
       ctx,
-      computation_proto.ParseFromString(computation_input.scalar<string>()()),
+      computation_proto.ParseFromString(computation_input.scalar<tstring>()()),
       errors::InvalidArgument(
           "Unable to parse computation input to XLAComputation"));
 
@@ -191,7 +191,7 @@
                                              .ComputeProgramShape()
                                              .ToProto();
   Tensor program_shape_output(DT_STRING, TensorShape({1}));
-  program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
+  program_shape_output.vec<tstring>()(0) = program_shape.SerializeAsString();
   ctx->set_output(1, program_shape_output);
 }
 
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
index 231387e..1c4e1f7 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -260,7 +260,7 @@
   TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
   xrt::XRTExecutionConfig config_proto;
   TF_RET_CHECK(
-      config_proto.ParseFromString(execution_config.scalar<string>()()));
+      config_proto.ParseFromString(execution_config.scalar<tstring>()()));
 
   int core_index_in_replica = config_proto.core_index_in_replica();
   TF_RET_CHECK(core_index_in_replica == 0);
@@ -343,12 +343,12 @@
   const Tensor& execution_plan = context->input(0);
   TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
   xrt::XRTChainedExecutePlan plan;
-  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<string>()()));
+  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));
 
   const Tensor& execution_config = context->input(1);
   TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
   xrt::XRTChainedExecuteConfig config;
-  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<string>()()));
+  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));
 
   XRTCompilationCache* cache;
   TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 2ffde52..8afd205 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -177,7 +177,7 @@
     xrt::XLAAllocation allocation_proto;
     OP_REQUIRES(
         ctx,
-        allocation_proto.ParseFromString(allocation_info.scalar<string>()()),
+        allocation_proto.ParseFromString(allocation_info.scalar<tstring>()()),
         errors::InvalidArgument(
             "Unable to parse allocation input to XLAAllocation"));
 
@@ -419,7 +419,7 @@
         errors::Internal("tuple description input should be a string scalar"));
     xrt::XLATupleNode tuple_proto;
     OP_REQUIRES(
-        ctx, tuple_proto.ParseFromString(tuple_info.scalar<string>()()),
+        ctx, tuple_proto.ParseFromString(tuple_info.scalar<tstring>()()),
         errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));
 
     OpInputList arg_list;
@@ -627,7 +627,7 @@
                 errors::Internal("literal input should be a string scalar"));
     xla::LiteralProto literal_proto;
     OP_REQUIRES(ctx,
-                literal_proto.ParseFromString(literal_info.scalar<string>()()),
+                literal_proto.ParseFromString(literal_info.scalar<tstring>()()),
                 errors::InvalidArgument(
                     "Unable to parse allocation input to LiteralProto"));
     xla::Literal literal;
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index cc6ab9a..701125f 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -1,6 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index f072925..427a631 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -127,7 +127,7 @@
 
 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
   xla::LiteralProto response;
-  CHECK(response.ParseFromString(outputs[idx].scalar<string>()()));
+  CHECK(response.ParseFromString(outputs[idx].scalar<tstring>()()));
   return xla::Literal::CreateFromProto(response).ValueOrDie();
 }
 
@@ -316,7 +316,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
 }
 
@@ -351,7 +351,7 @@
     EXPECT_EQ(outputs.size(), 1);
     xla::LiteralProto read_back_literal;
     EXPECT_TRUE(
-        read_back_literal.ParseFromString(outputs[0].scalar<string>()()));
+        read_back_literal.ParseFromString(outputs[0].scalar<tstring>()()));
     Tensor read_back_tensor;
     TF_ASSERT_OK(LiteralToHostTensor(
         xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
@@ -381,7 +381,7 @@
     EXPECT_EQ(outputs.size(), 1);
 
     xla::LiteralProto response;
-    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
     EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
   }
 }
@@ -413,7 +413,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
 }
 
@@ -439,7 +439,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
 }
 
@@ -465,7 +465,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   // We have sent literal's data (in array layout) with a attribute layout
   // {0,1}, so the expected literal read from device needs to be changed
   // accordingly.
@@ -493,7 +493,7 @@
 
   int64 allocation_handle = outputs[1].scalar<int64>()();
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
 
   xla::LiteralProto new_literal =
@@ -512,7 +512,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto new_response;
-  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));
 
   Tensor release_tensor(DT_INT64, TensorShape({1}));
@@ -652,7 +652,7 @@
       session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
 }
@@ -673,7 +673,7 @@
   TF_EXPECT_OK(session.Run({read_back}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
 }
 
@@ -707,13 +707,13 @@
   auto base_elements = base_literal.DecomposeTuple();
   auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
   xla::LiteralProto response_0;
-  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
   xla::LiteralProto response_1;
-  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
   xla::LiteralProto response_00;
-  EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
+  EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<tstring>()()));
   EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
 }
 
@@ -779,9 +779,9 @@
   std::vector<Tensor> outputs;
   TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
   xla::LiteralProto response_0;
-  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<tstring>()()));
   xla::LiteralProto response_1;
-  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<tstring>()()));
 
   auto expected_0 = MakeTuple0();
   EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
@@ -853,7 +853,7 @@
   TF_EXPECT_OK(session.Run({read_back}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
@@ -973,7 +973,7 @@
   EXPECT_EQ(outputs.size(), 1);
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
@@ -1022,13 +1022,13 @@
   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
   xla::ProgramShapeProto program_shape;
-  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<tstring>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
 
@@ -1077,13 +1077,13 @@
   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
   xla::ProgramShapeProto program_shape;
-  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<tstring>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
 
@@ -1128,7 +1128,8 @@
                            {release}, &outputs));
 
   xla::ProgramShapeProto program_shape_proto;
-  EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec<string>()(0)));
+  EXPECT_TRUE(
+      program_shape_proto.ParseFromString(outputs[0].vec<tstring>()(0)));
   xla::ProgramShape program_shape(program_shape_proto);
   EXPECT_EQ(program_shape.parameters_size(), 1);
 
@@ -1196,7 +1197,7 @@
   TF_EXPECT_OK(session.Run({read_back}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected =
       xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);
@@ -1231,7 +1232,7 @@
   TF_EXPECT_OK(session.Run({read_back}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
@@ -1281,7 +1282,7 @@
   TF_EXPECT_OK(session.Run({read_back}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
   auto expected = xla::LiteralUtil::MakeTuple({&sum});
@@ -1343,7 +1344,7 @@
     EXPECT_EQ(voutputs.size(), 1);
 
     xla::LiteralProto response;
-    EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar<string>()()));
+    EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar<tstring>()()));
 
     auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
     EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
@@ -1514,13 +1515,13 @@
   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
 
   xla::LiteralProto response;
-  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
 
   auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
   xla::ProgramShapeProto program_shape;
-  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<tstring>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
   EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
       xla::Shape(program_shape.result()), xla::S64));
@@ -1580,7 +1581,7 @@
   // we have on record.
   for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) {
     xla::LiteralProto response;
-    EXPECT_TRUE(response.ParseFromString(outputs[j].scalar<string>()()));
+    EXPECT_TRUE(response.ParseFromString(outputs[j].scalar<tstring>()()));
     EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response));
   }
 }
@@ -1668,7 +1669,7 @@
     EXPECT_EQ(outputs.size(), 1);
 
     xla::LiteralProto response;
-    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
     auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
     EXPECT_EQ(literal, zero_literal);
   }
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index 51b27ea..1e6de7e 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -214,8 +214,8 @@
       std::vector<string> columns;
       columns.reserve(column_families_tensor->NumElements());
       for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
-        column_families.push_back(column_families_tensor->flat<string>()(i));
-        columns.push_back(columns_tensor->flat<string>()(i));
+        column_families.push_back(column_families_tensor->flat<tstring>()(i));
+        columns.push_back(columns_tensor->flat<tstring>()(i));
       }
 
       DatasetBase* dataset;
@@ -317,7 +317,7 @@
           "Iterator produced a set of Tensors shorter than expected");
     }
     ::google::cloud::bigtable::SingleRowMutation mutation(
-        std::move(tensors[0].scalar<string>()()));
+        std::move(tensors[0].scalar<tstring>()()));
     std::chrono::milliseconds timestamp(timestamp_int);
     for (size_t i = 1; i < tensors.size(); ++i) {
       if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
@@ -326,11 +326,11 @@
       if (timestamp_int == -1) {
         mutation.emplace_back(::google::cloud::bigtable::SetCell(
             column_families[i - 1], columns[i - 1],
-            std::move(tensors[i].scalar<string>()())));
+            std::move(tensors[i].scalar<tstring>()())));
       } else {
         mutation.emplace_back(::google::cloud::bigtable::SetCell(
             column_families[i - 1], columns[i - 1], timestamp,
-            std::move(tensors[i].scalar<string>()())));
+            std::move(tensors[i].scalar<tstring>()())));
       }
     }
     bulk_mutation->emplace_back(std::move(mutation));
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index 1325560..085dc75 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -115,6 +115,15 @@
                           const ::google::cloud::bigtable::Row& row,
                           std::vector<Tensor>* out_tensors) = 0;
 
+  Status SaveInternal(IteratorStateWriter* writer) override {
+    return errors::Unimplemented("SaveInternal is currently not supported");
+  }
+
+  Status RestoreInternal(IteratorContext* ctx,
+                         IteratorStateReader* reader) override {
+    return errors::Unimplemented("RestoreInternal is currently not supported");
+  }
+
  private:
   Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (reader_) {
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index 8039ef8..6f1f880 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -97,12 +97,17 @@
       return "BigtableLookupDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -154,13 +159,13 @@
           ::google::cloud::StatusOr<
               std::pair<bool, ::google::cloud::bigtable::Row>>
               row = dataset()->table_->table().ReadRow(
-                  input_tensors[0].scalar<string>()(), dataset()->filter_);
+                  input_tensors[0].scalar<tstring>()(), dataset()->filter_);
           if (!row.ok()) {
             return GcpStatusToTfStatus(row.status());
           }
           if (!row->first) {
             return errors::DataLoss("Row key '",
-                                    input_tensors[0].scalar<string>()(),
+                                    input_tensors[0].scalar<tstring>()(),
                                     "' not found.");
           }
           TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors));
@@ -172,13 +177,24 @@
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return errors::Unimplemented("SaveInternal is currently not supported");
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        return errors::Unimplemented(
+            "RestoreInternal is currently not supported");
+      }
+
      private:
       Status ParseRow(IteratorContext* ctx,
                       const ::google::cloud::bigtable::Row& row,
                       std::vector<Tensor>* out_tensors) {
         out_tensors->reserve(dataset()->columns_.size() + 1);
         Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
-        row_key_tensor.scalar<string>()() = string(row.row_key());
+        row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
         out_tensors->emplace_back(std::move(row_key_tensor));
 
         if (row.cells().size() > 2 * dataset()->columns_.size()) {
@@ -196,7 +212,7 @@
             if (cell_itr->family_name() == dataset()->column_families_[i] &&
                 string(cell_itr->column_qualifier()) ==
                     dataset()->columns_[i]) {
-              col_tensor.scalar<string>()() = string(cell_itr->value());
+              col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
               found_column = true;
             }
           }
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index e9d4a1e..51ccd83 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -71,12 +71,17 @@
 
     BigtableTableResource* table() const { return table_; }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -97,7 +102,7 @@
                       const ::google::cloud::bigtable::Row& row,
                       std::vector<Tensor>* out_tensors) override {
         Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
-        output_tensor.scalar<string>()() = string(row.row_key());
+        output_tensor.scalar<tstring>()() = tstring(row.row_key());
         out_tensors->emplace_back(std::move(output_tensor));
         return Status::OK();
       }
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index be3c7cc..2bc642f 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -76,12 +76,17 @@
 
     BigtableTableResource* table() const { return table_; }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -103,7 +108,7 @@
                       const ::google::cloud::bigtable::Row& row,
                       std::vector<Tensor>* out_tensors) override {
         Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
-        output_tensor.scalar<string>()() = string(row.row_key());
+        output_tensor.scalar<tstring>()() = string(row.row_key());
         out_tensors->emplace_back(std::move(output_tensor));
         return Status::OK();
       }
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 880f5e4..2659097 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -89,12 +89,17 @@
       return "BigtableSampleKeyPairsDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -175,16 +180,27 @@
         *end_of_sequence = false;
         out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                   TensorShape({}));
-        out_tensors->back().scalar<string>()() = keys_[index_];
+        out_tensors->back().scalar<tstring>()() = keys_[index_];
 
         out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                   TensorShape({}));
-        out_tensors->back().scalar<string>()() = keys_[index_ + 1];
+        out_tensors->back().scalar<tstring>()() = keys_[index_ + 1];
         ++index_;
 
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return errors::Unimplemented("SaveInternal is currently not supported");
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        return errors::Unimplemented(
+            "RestoreInternal is currently not supported");
+      }
+
      private:
       mutex mu_;
       size_t index_ GUARDED_BY(mu_) = 0;
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index 53be3b5..1118caf 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -64,12 +64,17 @@
 
     BigtableTableResource* table() const { return table_; }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -97,7 +102,7 @@
         if (index_ < row_keys_.size()) {
           out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                     TensorShape({}));
-          out_tensors->back().scalar<string>()() =
+          out_tensors->back().scalar<tstring>()() =
               string(row_keys_[index_].row_key);
           *end_of_sequence = false;
           index_++;
@@ -107,6 +112,17 @@
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return errors::Unimplemented("SaveInternal is currently not supported");
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        return errors::Unimplemented(
+            "RestoreInternal is currently not supported");
+      }
+
      private:
       mutex mu_;
       size_t index_ = 0;
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index e68c83e..b6beaf3 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -131,12 +131,17 @@
 
     BigtableTableResource* table() const { return table_; }
 
+    Status CheckExternalState() const override {
+      return errors::FailedPrecondition(DebugString(),
+                                        " depends on external state.");
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented("%s does not support serialization",
-                                   DebugString());
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
     }
 
    private:
@@ -175,7 +180,7 @@
                       std::vector<Tensor>* out_tensors) override {
         out_tensors->reserve(dataset()->columns_.size() + 1);
         Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
-        row_key_tensor.scalar<string>()() = string(row.row_key());
+        row_key_tensor.scalar<tstring>()() = string(row.row_key());
         out_tensors->emplace_back(std::move(row_key_tensor));
 
         if (row.cells().size() > 2 * dataset()->columns_.size()) {
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 5a8b2ba..60f92a0 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -20,8 +20,10 @@
 import tempfile
 import numpy as np
 
+from google.protobuf import text_format
 from tensorflow.contrib.boosted_trees.estimator_batch import estimator
 from tensorflow.contrib.boosted_trees.proto import learner_pb2
+from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column
 from tensorflow.contrib.learn.python.learn.estimators import run_config
 from tensorflow.python.estimator.canned import head as head_lib
@@ -137,6 +139,15 @@
     self._export_dir_base = tempfile.mkdtemp() + "export/"
     gfile.MkDir(self._export_dir_base)
 
+  def _assert_checkpoint_and_return_model(self, model_dir, global_step):
+    reader = checkpoint_utils.load_checkpoint(model_dir)
+    self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+    serialized = reader.get_tensor("ensemble_model:0_config")
+    ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
+    ensemble_proto.ParseFromString(serialized)
+
+    return ensemble_proto
+
   def _assert_checkpoint(self, model_dir, global_step):
     reader = checkpoint_utils.load_checkpoint(model_dir)
     self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
@@ -404,8 +415,8 @@
     learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE
-    learner_config.regularization.tree_complexity = (
-        1.0 / _QUANTILE_REGRESSION_SIZE)
+    learner_config.regularization.tree_complexity = (1.0 /
+                                                     _QUANTILE_REGRESSION_SIZE)
 
     train_input_fn, test_input_fn, y = _quantile_regression_input_fns()
 
@@ -437,8 +448,8 @@
     learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE
-    learner_config.regularization.tree_complexity = (
-        1.0 / _QUANTILE_REGRESSION_SIZE)
+    learner_config.regularization.tree_complexity = (1.0 /
+                                                     _QUANTILE_REGRESSION_SIZE)
 
     train_input_fn, test_input_fn, y = _quantile_regression_input_fns(
         two_dimension=True)
@@ -471,6 +482,329 @@
     self.assertTrue(frac_both_below_upper >= 0.91)
     self.assertTrue(frac_both_below_upper <= 0.99)
 
+  def testForcedInitialSplits(self):
+    learner_config = learner_pb2.LearnerConfig()
+    learner_config.num_classes = 2
+    learner_config.constraints.max_tree_depth = 3
+
+    initial_subtree = """
+            nodes {
+              dense_float_binary_split {
+                feature_column: 0
+                threshold: -0.5
+                left_id: 1
+                right_id: 2
+              }
+              node_metadata {
+                gain: 0
+              }
+            }
+            nodes {
+              dense_float_binary_split {
+                feature_column: 0
+                threshold: 0.52
+                left_id: 3
+                right_id: 4
+              }
+              node_metadata {
+                gain: 0
+              }
+            }
+            nodes {
+              dense_float_binary_split {
+                feature_column: 0
+                threshold: 0.554
+                left_id: 5
+                right_id: 6
+              }
+              node_metadata {
+                gain: 0
+              }
+            }
+            nodes {
+              leaf {
+                vector {
+                  value: 0.0
+                }
+              }
+            }
+            nodes {
+              leaf {
+                vector {
+                  value: 0.0
+                }
+              }
+            }
+            nodes {
+              leaf {
+                vector {
+                  value: 0.0
+                }
+              }
+            }
+            nodes {
+              leaf {
+                vector {
+                  value: 0.0
+                }
+              }
+            }
+    """
+    tree_proto = tree_config_pb2.DecisionTreeConfig()
+    text_format.Merge(initial_subtree, tree_proto)
+
+    # Set initial subtree info.
+    learner_config.each_tree_start.CopyFrom(tree_proto)
+    learner_config.each_tree_start_num_layers = 2
+
+    model_dir = tempfile.mkdtemp()
+    config = run_config.RunConfig()
+
+    classifier = estimator.GradientBoostedDecisionTreeClassifier(
+        learner_config=learner_config,
+        num_trees=2,
+        examples_per_layer=6,
+        model_dir=model_dir,
+        config=config,
+        center_bias=False,
+        feature_columns=[contrib_feature_column.real_valued_column("x")],
+        output_leaf_index=False)
+
+    classifier.fit(input_fn=_train_input_fn, steps=100)
+    # When no override of global steps, 5 steps were used.
+    ensemble = self._assert_checkpoint_and_return_model(
+        classifier.model_dir, global_step=6)
+
+    # TODO(nponomareva): find a better way to test this.
+    expected_ensemble = """
+      trees {
+        nodes {
+          dense_float_binary_split {
+            threshold: -0.5
+            left_id: 1
+            right_id: 2
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.519999980927
+            left_id: 3
+            right_id: 4
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.554000020027
+            left_id: 5
+            right_id: 6
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 1.0
+            left_id: 7
+            right_id: 8
+          }
+          node_metadata {
+            gain: 0.888888895512
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: -2.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 2.00000023842
+            }
+          }
+        }
+      }
+      trees {
+        nodes {
+          dense_float_binary_split {
+            threshold: -0.5
+            left_id: 1
+            right_id: 2
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.519999980927
+            left_id: 3
+            right_id: 4
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.554000020027
+            left_id: 5
+            right_id: 6
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 1.0
+            left_id: 7
+            right_id: 8
+          }
+          node_metadata {
+            gain: 0.727760672569
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: -1.81873059273
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 1.81873047352
+            }
+          }
+        }
+      }
+      trees {
+        nodes {
+          dense_float_binary_split {
+            threshold: -0.5
+            left_id: 1
+            right_id: 2
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.519999980927
+            left_id: 3
+            right_id: 4
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          dense_float_binary_split {
+            threshold: 0.554000020027
+            left_id: 5
+            right_id: 6
+          }
+          node_metadata {
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+        nodes {
+          leaf {
+            vector {
+              value: 0.0
+            }
+          }
+        }
+      }
+      tree_weights: 0.10000000149
+      tree_weights: 0.10000000149
+      tree_weights: 0.10000000149
+      tree_metadata {
+        num_tree_weight_updates: 1
+        num_layers_grown: 3
+        is_finalized: true
+      }
+      tree_metadata {
+        num_tree_weight_updates: 1
+        num_layers_grown: 3
+        is_finalized: true
+      }
+      tree_metadata {
+        num_tree_weight_updates: 1
+        num_layers_grown: 2
+      }
+      growing_metadata {
+        num_layers_attempted: 3
+      }
+    """
+    self.assertProtoEquals(expected_ensemble, ensemble)
+
 
 class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
 
@@ -674,8 +1008,8 @@
     learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE
     learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE
-    learner_config.regularization.tree_complexity = (
-        1.0 / _QUANTILE_REGRESSION_SIZE)
+    learner_config.regularization.tree_complexity = (1.0 /
+                                                     _QUANTILE_REGRESSION_SIZE)
 
     train_input_fn, test_input_fn, y = _quantile_regression_input_fns()
     y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 07fa4ca..477b191 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -29,6 +29,9 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import training_util
+from google.protobuf import text_format
+from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
+
 
 class ModelBuilderOutputType(object):
   MODEL_FN_OPS = 0
@@ -106,10 +109,30 @@
   training_features = copy.copy(features)
   training_features.pop(weight_column_name, None)
   global_step = training_util.get_global_step()
+
+  initial_ensemble = ""
+  if learner_config.each_tree_start.nodes:
+    if learner_config.each_tree_start_num_layers <= 0:
+      raise ValueError("You must provide each_tree_start_num_layers.")
+    num_layers = learner_config.each_tree_start_num_layers
+    initial_ensemble = """
+             trees { %s }
+             tree_weights: 0.1
+             tree_metadata {
+              num_tree_weight_updates: 1
+              num_layers_grown: %d
+              is_finalized: false
+             }
+             """ % (text_format.MessageToString(
+                 learner_config.each_tree_start), num_layers)
+    tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
+    text_format.Merge(initial_ensemble, tree_ensemble_proto)
+    initial_ensemble = tree_ensemble_proto.SerializeToString()
+
   with ops.device(global_step.device):
     ensemble_handle = model_ops.tree_ensemble_variable(
         stamp_token=0,
-        tree_ensemble_config="",  # Initialize an empty ensemble.
+        tree_ensemble_config=initial_ensemble,  # Initialize the ensemble.
         name="ensemble_model")
 
   # Create GBDT model.
diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
index 9655e49..5f9976a 100644
--- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
@@ -46,7 +46,7 @@
     OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
                                            &tree_ensemble_config_t));
     auto* result = new DecisionTreeEnsembleResource();
-    if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
+    if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<tstring>()(),
                                     stamp_token)) {
       result->Unref();
       OP_REQUIRES(
@@ -99,7 +99,7 @@
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(1, TensorShape(), &output_config_t));
-    output_config_t->scalar<string>()() =
+    output_config_t->scalar<tstring>()() =
         ensemble_resource->SerializeAsString();
   }
 };
@@ -130,7 +130,7 @@
     OP_REQUIRES(
         context,
         ensemble_resource->InitFromSerialized(
-            tree_ensemble_config_t->scalar<string>()(), stamp_token),
+            tree_ensemble_config_t->scalar<tstring>()(), stamp_token),
         errors::InvalidArgument("Unable to parse tree ensemble config."));
   }
 };
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 431dc68..bea5c2a 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -324,7 +324,7 @@
                 context,
                 ParseProtoUnlimited(
                     summary_proto,
-                    summary_list[resource_handle_idx].scalar<string>()()),
+                    summary_list[resource_handle_idx].scalar<tstring>()()),
                 errors::InvalidArgument("Unable to parse quantile summary."));
             std::vector<QuantileSummaryEntry> entries;
             entries.reserve(summary_proto->entries_size());
@@ -543,7 +543,7 @@
     ::boosted_trees::QuantileStreamState state_proto;
     OP_REQUIRES(
         context,
-        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()),
+        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<tstring>()()),
         errors::InvalidArgument("Unabnle to parse quantile stream state."));
     std::vector<QuantileSummary> summaries;
     summaries.reserve(state_proto.summaries_size());
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 6527624..4e96957 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -213,8 +213,8 @@
     OP_REQUIRES_OK(context, context->allocate_output("split_infos",
                                                      TensorShape({size_output}),
                                                      &output_splits_t));
-    tensorflow::TTypes<string>::Vec output_splits =
-        output_splits_t->vec<string>();
+    tensorflow::TTypes<tstring>::Vec output_splits =
+        output_splits_t->vec<tstring>();
 
     if (num_elements == 0) {
       return;
@@ -529,8 +529,8 @@
     OP_REQUIRES_OK(context, context->allocate_output(
                                 "split_infos", TensorShape({num_elements}),
                                 &output_splits_t));
-    tensorflow::TTypes<string>::Vec output_splits =
-        output_splits_t->vec<string>();
+    tensorflow::TTypes<tstring>::Vec output_splits =
+        output_splits_t->vec<tstring>();
     SplitBuilderState state(context);
     // For each tree node that needs to be split.
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
@@ -780,8 +780,8 @@
     OP_REQUIRES_OK(context, context->allocate_output("split_infos",
                                                      TensorShape({size_output}),
                                                      &output_splits_t));
-    tensorflow::TTypes<string>::Vec output_splits =
-        output_splits_t->vec<string>();
+    tensorflow::TTypes<tstring>::Vec output_splits =
+        output_splits_t->vec<tstring>();
     if (num_elements == 0) {
       return;
     }
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index 91c0178..bf5f5d3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -432,6 +432,27 @@
       if (tree_config->nodes_size() <= 0) {
         ensemble_resource->RemoveLastTree();
       }
+
+      if ((ensemble_resource->num_trees() == 0 ||
+           ensemble_resource->LastTreeMetadata()->is_finalized()) &&
+          learner_config_.has_each_tree_start() &&
+          learner_config_.each_tree_start().nodes_size() > 0) {
+        DCHECK_GT(learner_config_.each_tree_start_num_layers(), 0);
+        // Add new dummy tree
+        boosted_trees::trees::DecisionTreeConfig* const tree_config =
+            ensemble_resource->AddNewTree(learning_rate);
+        VLOG(1) << "Adding a new forced tree";
+
+        *tree_config = learner_config_.each_tree_start();
+
+        boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
+            ensemble_resource->LastTreeMetadata();
+
+        tree_metadata->set_is_finalized(max_tree_depth <= 1);
+        tree_metadata->set_num_tree_weight_updates(1);
+        tree_metadata->set_num_layers_grown(
+            learner_config_.each_tree_start_num_layers());
+      }
     }
   }
 
@@ -447,7 +468,7 @@
     for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
       const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
       const auto& gains = gains_list[handler_id].vec<float>();
-      const auto& splits = splits_list[handler_id].vec<string>();
+      const auto& splits = splits_list[handler_id].vec<tstring>();
       OP_REQUIRES(context, partition_ids.size() == gains.size(),
                   errors::InvalidArgument(
                       "Inconsistent partition Ids and gains tensors: ",
@@ -481,7 +502,7 @@
     // Find best split per partition going through every feature candidate.
     for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
       const auto& gains = gains_list[handler_id].vec<float>();
-      const auto& splits = splits_list[handler_id].vec<string>();
+      const auto& splits = splits_list[handler_id].vec<tstring>();
       OP_REQUIRES(context, gains.size() == 1,
                   errors::InvalidArgument(
                       "Gains size must be one for oblivious weak learner: ",
diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD
index edddc59..ca3dd54 100644
--- a/tensorflow/contrib/boosted_trees/proto/BUILD
+++ b/tensorflow/contrib/boosted_trees/proto/BUILD
@@ -1,4 +1,4 @@
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 
 package(
     licenses = ["notice"],  # Apache 2.0
@@ -12,6 +12,9 @@
         "learner.proto",
     ],
     cc_api_version = 2,
+    protodeps = [
+        ":tree_config_proto",
+    ],
     visibility = ["//visibility:public"],
 )
 
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
index c49cb48..fc5f158 100644
--- a/tensorflow/contrib/boosted_trees/proto/learner.proto
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -1,9 +1,11 @@
 syntax = "proto3";
 
-option cc_enable_arenas = true;
-
 package tensorflow.boosted_trees.learner;
 
+import "tensorflow/contrib/boosted_trees/proto/tree_config.proto";
+
+option cc_enable_arenas = true;
+
 // Tree regularization config.
 message TreeRegularizationConfig {
   // Classic L1/L2.
@@ -149,4 +151,11 @@
 
   // By default we use NORMAL_DECISION_TREE as weak learner.
   WeakLearnerType weak_learner_type = 12;
+
+  // If you want to enforce some splits and allow boosting to figure out the
+  // rest, you can provide a tree that represents the starting splits for each
+  // tree in the ensemble.
+  // Set both each_tree_start and each_tree_start_num_layers.
+  tensorflow.boosted_trees.trees.DecisionTreeConfig each_tree_start = 13;
+  int32 each_tree_start_num_layers = 14;
 }
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index 86fd577..74a51f4 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -142,7 +142,8 @@
 
 
 def _get_bias_update(grads, hess):
-  return array_ops.where(hess > 0, -grads / hess, array_ops.zeros_like(grads))
+  return array_ops.where_v2(hess > 0, -grads / hess,
+                            array_ops.zeros_like(grads))
 
 
 class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
index 152d883..d7bbbc1 100644
--- a/tensorflow/contrib/cloud/kernels/BUILD
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -10,7 +10,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library",
 )
 
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc
index b0f9237..7a19a1c 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc
@@ -153,7 +153,7 @@
                    context->allocate_output(0, TensorShape({num_partitions_}),
                                             &output_tensor));
 
-    auto output = output_tensor->template flat<string>();
+    auto output = output_tensor->template flat<tstring>();
     for (int64 i = 0; i < num_partitions_; ++i) {
       BigQueryTablePartition partition;
       partition.set_start_index(i * partition_size);
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index be66fac..5831781 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -18,7 +18,6 @@
 from __future__ import print_function
 
 import argparse
-import collections
 import functools
 import itertools
 import os
@@ -59,6 +58,7 @@
 from tensorflow.python.training import rmsprop
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training.tracking import util as trackable_utils
+from tensorflow.python.util.compat import collections_abc
 
 
 CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
@@ -1131,7 +1131,7 @@
     return numeric_grad.reshape(x_shape)
 
   def _GetShape(self, sess, inputs):
-    if not isinstance(inputs, collections.Iterable):
+    if not isinstance(inputs, collections_abc.Iterable):
       return sess.run(array_ops.shape(inputs))
     else:
       return sess.run([array_ops.shape(x) for x in inputs])
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index a0b2ca5..ebbb9b3 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -1,5 +1,5 @@
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library",
     "tf_pyclif_proto_library",
 )
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 8730dd4..d95ace6 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -1,7 +1,7 @@
 # Implementation of a prototype TF distributed computation library.
 
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
-load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
+load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -273,6 +273,7 @@
         "no_windows_gpu",
         "notsan",
     ],
+    xla_enable_strict_auto_jit = False,  # Ignoring due to in contrib.
     deps = [
         ":mirrored_strategy",
         "//tensorflow/python/distribute:tpu_strategy",
diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
index d6929de..98195cc 100644
--- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
+++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
@@ -374,7 +374,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       inputs = np.zeros((64, 3), dtype=np.float32)
       targets = np.zeros((64, 4), dtype=np.float32)
@@ -405,7 +405,10 @@
       optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
       loss = 'mse'
       model.compile(
-          optimizer, loss, distribute=distribution, run_distributed=False)
+          optimizer,
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
       input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32)
@@ -439,7 +442,10 @@
     optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
     loss = 'mse'
     model.compile(
-        optimizer, loss, distribute=distribution, run_distributed=False)
+        optimizer,
+        loss,
+        distribute=distribution,
+        experimental_run_tf_function=False)
 
     inputs = np.zeros((20, 3), np.float32)
     targets = np.zeros((20, 4), np.float32)
@@ -456,7 +462,10 @@
       optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
       loss = 'mse'
       model.compile(
-          optimizer, loss, distribute=distribution, run_distributed=False)
+          optimizer,
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       # We take 6 input samples with each input having a dimension of 3 or 5.
       input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32)
@@ -491,7 +500,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -511,7 +520,7 @@
           loss='mse',
           metrics=['mae', keras.metrics.CategoricalAccuracy()],
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       interleaved_model = get_model()
       interleaved_model.set_weights(user_controlled_model.get_weights())
@@ -520,7 +529,7 @@
           loss='mse',
           metrics=['mae', keras.metrics.CategoricalAccuracy()],
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -566,7 +575,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       input_a_np = np.random.random((10, 3))
       input_b_np = np.random.random((10, 5))
@@ -603,7 +612,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -618,7 +627,10 @@
 
       loss = 'mse'
       model.compile(
-          optimizer(), loss, distribute=distribution, run_distributed=False)
+          optimizer(),
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -632,7 +644,10 @@
     optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
     loss = 'mse'
     model.compile(
-        optimizer, loss, distribute=distribution, run_distributed=False)
+        optimizer,
+        loss,
+        distribute=distribution,
+        experimental_run_tf_function=False)
 
     inputs = np.zeros((10, 3), np.float32)
     targets = np.zeros((10, 4), np.float32)
@@ -661,7 +676,10 @@
       optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
       model.compile(
-          optimizer, loss, distribute=distribution, run_distributed=False)
+          optimizer,
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       # Wrong input shape
       inputs = np.zeros((10, 5), dtype=np.float32)
@@ -689,7 +707,10 @@
       optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
       model.compile(
-          optimizer, loss, distribute=distribution, run_distributed=False)
+          optimizer,
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       # User forgets to batch the dataset
       inputs = np.zeros((10, 3), dtype=np.float32)
@@ -726,7 +747,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       batch_size = 8
       if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy):
@@ -762,7 +783,10 @@
       optimizer = gradient_descent_keras.SGD(0.01)
       loss = 'mse'
       model.compile(
-          optimizer, loss, distribute=distribution, run_distributed=False)
+          optimizer,
+          loss,
+          distribute=distribution,
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -801,7 +825,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -861,7 +885,7 @@
           loss,
           metrics=metrics,
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       dataset = get_dataset(distribution)
 
@@ -905,7 +929,7 @@
           loss='mse',
           optimizer=gradient_descent.GradientDescentOptimizer(0.01),
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
       y = np.array([[[1], [1]], [[1], [1]]])
       dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
       dataset = dataset.repeat(100)
@@ -928,7 +952,7 @@
           loss='mse',
           optimizer=gradient_descent.GradientDescentOptimizer(0.01),
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       # centered on 5.0, variance 10.0
       x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
@@ -974,7 +998,7 @@
           optimizer=gradient_descent.GradientDescentOptimizer(0.5),
           metrics=[keras.metrics.BinaryAccuracy()],
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       batch_size = 64
       if not distributed_training_utils.global_batch_size_supported(
@@ -1001,7 +1025,7 @@
           metrics=['accuracy', keras.metrics.BinaryAccuracy()],
           optimizer=gradient_descent.GradientDescentOptimizer(0.001),
           distribute=distribution,
-          run_distributed=False)
+          experimental_run_tf_function=False)
 
       # verify correctness of stateful and stateless metrics.
       x = np.ones((100, 4)).astype('float32')
@@ -1078,7 +1102,7 @@
             optimizer=gradient_descent_keras.SGD(0.5),
             metrics=['mse'],
             distribute=with_distribution,
-            run_distributed=False)
+            experimental_run_tf_function=False)
 
         training_inputs, eval_inputs, predict_inputs = (
             get_correctness_test_inputs(use_numpy, use_validation_data,
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index f502a0b..87c920e 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -513,6 +513,7 @@
         "//tensorflow/python:platform_test",
     ],
     tags = ["nomsan"],  # disable to avoid false positives from scipy.
+    xla_enable_strict_auto_jit = False,
 )
 
 cuda_py_test(
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 4fe4650..e174596 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -379,7 +379,7 @@
     size_implicit_dim = (
         original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
     new_ndims = array_ops.shape(new_shape)
-    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
+    expanded_new_shape = array_ops.where_v2(  # Assumes exactly one `-1`.
         implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
     validations = [] if not validate else [
         check_ops.assert_rank(
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
index 241fba2..aee3a60 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
@@ -43,7 +43,7 @@
     warn_once=True)
 def _sqrtx2p1(x):
   """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`."""
-  return array_ops.where(
+  return array_ops.where_v2(
       math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1.,
       math_ops.sqrt(x**2. + 1.),
       # For large x, calculating x**2 can overflow. This can be alleviated by
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index cc9e29f..38505c1 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -68,9 +68,9 @@
   #   where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
   ones = array_ops.ones_like(n - k)
   k_eq_n = math_ops.equal(k, n)
-  safe_dn = array_ops.where(k_eq_n, ones, n - k)
+  safe_dn = array_ops.where_v2(k_eq_n, ones, n - k)
   dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p)
-  return array_ops.where(k_eq_n, ones, dk)
+  return array_ops.where_v2(k_eq_n, ones, dk)
 
 
 class Binomial(distribution.Distribution):
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 85692d2..75e5ca4 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -475,10 +475,9 @@
       return array_ops.shape(d.batch_shape_tensor())[0]
     dist_batch_ndims = _get_ndims(mixture_distribution)
     cat_batch_ndims = _get_ndims(categorical_distribution)
-    pad_ndims = array_ops.where(
-        categorical_distribution.is_scalar_batch(),
-        dist_batch_ndims,
-        dist_batch_ndims - cat_batch_ndims)
+    pad_ndims = array_ops.where_v2(categorical_distribution.is_scalar_batch(),
+                                   dist_batch_ndims,
+                                   dist_batch_ndims - cat_batch_ndims)
     s = array_ops.shape(x)
     x = array_ops.reshape(x, shape=array_ops.concat([
         s[:-1],
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 343a7f5..e55b4a1 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -236,7 +236,7 @@
           self.batch_shape_tensor(),
           np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
           name="nan")
-      return array_ops.where(self.concentration > 1., mean, nan)
+      return array_ops.where_v2(self.concentration > 1., mean, nan)
     else:
       return control_flow_ops.with_dependencies([
           check_ops.assert_less(
@@ -257,7 +257,7 @@
           self.batch_shape_tensor(),
           np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
           name="nan")
-      return array_ops.where(self.concentration > 2., var, nan)
+      return array_ops.where_v2(self.concentration > 2., var, nan)
     else:
       return control_flow_ops.with_dependencies([
           check_ops.assert_less(
diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
index e3712dd..56f35c2 100644
--- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
+++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
@@ -235,7 +235,7 @@
           np.array(np.nan, dtype=self.dtype.as_numpy_dtype),
           name="nan")
       is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.)
-      return array_ops.where(is_defined, mode, nan)
+      return array_ops.where_v2(is_defined, mode, nan)
 
     return control_flow_ops.with_dependencies([
         check_ops.assert_less(
diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
index 9ab98d1..faf9827 100644
--- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
@@ -190,10 +190,9 @@
     return self.total_count * math_ops.exp(self.logits)
 
   def _mode(self):
-    adjusted_count = array_ops.where(
-        1. < self.total_count,
-        self.total_count - 1.,
-        array_ops.zeros_like(self.total_count))
+    adjusted_count = array_ops.where_v2(1. < self.total_count,
+                                        self.total_count - 1.,
+                                        array_ops.zeros_like(self.total_count))
     return math_ops.floor(adjusted_count * math_ops.exp(self.logits))
 
   def _variance(self):
diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py
index 19d88d5..1be2dd1 100644
--- a/tensorflow/contrib/distributions/python/ops/shape.py
+++ b/tensorflow/contrib/distributions/python/ops/shape.py
@@ -457,9 +457,9 @@
         batch_shape = s[1:1+self.batch_ndims]
         # Since sample_dims=1 and is left-most, we add 1 to the number of
         # batch_ndims to get the event start dim.
-        event_start = array_ops.where(
-            math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0),
-            2, 1 + self.batch_ndims)
+        event_start = array_ops.where_v2(
+            math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0), 2,
+            1 + self.batch_ndims)
         event_shape = s[event_start:event_start+self.event_ndims]
       new_shape = array_ops.concat([sample_shape, batch_shape, event_shape], 0)
       x = array_ops.reshape(x, shape=new_shape)
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index f974846..b39dba7 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -1060,5 +1060,5 @@
     if axis_ is not None:
       axis = np.int(ndims + axis_ if axis_ < 0 else axis_)
     else:
-      axis = array_ops.where(axis < 0, ndims + axis, axis)
+      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
   return nn_ops.softmax(x, axis=axis)
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index a5bb880..cf51a5b 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -400,10 +400,9 @@
 
   def _mode(self):
     s = self.df - self.dimension - 1.
-    s = array_ops.where(
+    s = array_ops.where_v2(
         math_ops.less(s, 0.),
-        constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"),
-        s)
+        constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"), s)
     if self.cholesky_input_output_matrices:
       return math_ops.sqrt(s) * self.scale_operator.to_dense()
     return s * self._square_scale_operator()
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index 48925b1..0bbece7 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -25,9 +25,9 @@
 
 from tensorflow.contrib import lookup
 from tensorflow.contrib.eager.python import datasets
-from tensorflow.python.data import Dataset
 from tensorflow.python.data.experimental.ops import threadpool
 from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -44,24 +44,24 @@
 
   def testBasic(self):
     got = []
-    for t in datasets.Iterator(Dataset.range(4)):
+    for t in datasets.Iterator(dataset_ops.Dataset.range(4)):
       got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)
 
   def testBasicOneShotIterator(self):
     got = []
-    for t in Dataset.range(4).make_one_shot_iterator():
+    for t in dataset_ops.Dataset.range(4).make_one_shot_iterator():
       got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)
 
   def testBasicImplicitIterator(self):
     got = []
-    for t in Dataset.range(4):
+    for t in dataset_ops.Dataset.range(4):
       got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)
 
   def testGetNext(self):
-    iterator = datasets.Iterator(Dataset.range(4))
+    iterator = datasets.Iterator(dataset_ops.Dataset.range(4))
     self.assertEqual(0, iterator.get_next().numpy())
     self.assertEqual(1, iterator.get_next().numpy())
     self.assertEqual(2, iterator.get_next().numpy())
@@ -70,7 +70,7 @@
       iterator.get_next()
 
   def testGetNextOneShotIterator(self):
-    iterator = Dataset.range(4).make_one_shot_iterator()
+    iterator = dataset_ops.Dataset.range(4).make_one_shot_iterator()
     self.assertEqual(0, iterator.get_next().numpy())
     self.assertEqual(1, iterator.get_next().numpy())
     self.assertEqual(2, iterator.get_next().numpy())
@@ -79,7 +79,7 @@
       iterator.get_next()
 
   def testMultipleIteratorsOnTheSameDataset(self):
-    ds = Dataset.range(4)
+    ds = dataset_ops.Dataset.range(4)
     it1 = datasets.Iterator(ds)
     it2 = datasets.Iterator(ds)
     got = [x.numpy() for x in it1]
@@ -89,8 +89,10 @@
     self.assertAllEqual([0, 1, 2, 3], got)
 
   def testNestedOutputs(self):
-    ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
-                                                     Dataset.range(4)))))
+    ds = dataset_ops.Dataset.zip(
+        (dataset_ops.Dataset.range(4),
+         dataset_ops.Dataset.zip(
+             (dataset_ops.Dataset.range(4), dataset_ops.Dataset.range(4)))))
     total = 0
     # The Iterator will return a nested structure of Tensor objects.
     # Some funkiness to compare against simple integers.
@@ -102,10 +104,12 @@
     self.assertEqual(4, total)
 
   def testMapAndFilter(self):
+
     def even(x):
       return math_ops.equal(math_ops.mod(x, 2), 0)
 
-    it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
+    it = datasets.Iterator(
+        dataset_ops.Dataset.range(8).map(math_ops.square).filter(even))
     got = [x.numpy() for x in it]
     self.assertAllEqual([0, 4, 16, 36], got)
 
@@ -115,14 +119,16 @@
     values = constant_op.constant([0, 1, 2], dtypes.int64)
     table = lookup.HashTable(
         lookup.KeyValueTensorInitializer(keys, values), default_val)
-    dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery'])
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        ['brain', 'salad', 'surgery'])
     dataset = dataset.map(table.lookup)
     it = datasets.Iterator(dataset)
     got = [x.numpy() for x in it]
     self.assertAllEqual([0, 1, 2], got)
 
   def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
-    ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)
+    ds = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5,
+                                                 6]).map(math_ops.square)
 
     got1 = [x.numpy() for x in datasets.Iterator(ds)]
     self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
@@ -172,7 +178,7 @@
     ]
 
     for i, result in enumerate(
-        datasets.Iterator(Dataset.from_tensor_slices(components))):
+        datasets.Iterator(dataset_ops.Dataset.from_tensor_slices(components))):
       self.assertSparseValuesEqual(expected[i][0], result[0])
       self.assertSparseValuesEqual(expected[i][1], result[1])
 
@@ -181,20 +187,20 @@
     def my_map(inp):
       return [[x + 1 for x in inp]]
 
-    ds = Dataset.range(4).map(
+    ds = dataset_ops.Dataset.range(4).map(
         lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
     got = [x.numpy() for x in datasets.Iterator(ds)]
     self.assertAllEqual([[1], [2], [3], [4]], got)
 
   def testTensorsPlacedOnDevice(self):
-    ds = Dataset.from_tensors([0., 1.])
+    ds = dataset_ops.Dataset.from_tensors([0., 1.])
     with ops.device(test.gpu_device_name()):
       x = datasets.Iterator(ds).next()
       x = math_ops.add(x, x)
     self.assertAllEqual([0., 2.], x.numpy())
 
   def testGpuTensor(self):
-    ds = Dataset.from_tensors([0., 1.])
+    ds = dataset_ops.Dataset.from_tensors([0., 1.])
     with ops.device(test.gpu_device_name()):
       for x in ds:
         y = math_ops.add(x, x)
@@ -213,7 +219,7 @@
     for num_threads in [1, 2, 4, 8, 16]:
 
       dataset = (
-          Dataset.range(1000).map(
+          dataset_ops.Dataset.range(1000).map(
               lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
               num_parallel_calls=32).apply(unique.unique()))
 
@@ -235,8 +241,13 @@
   def testSaveRestore(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
-    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
     dataset = dataset.map(math_ops.square).batch(2)
+    # TODO(b/138399725): Re-enable default optimizations.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
     iterator = datasets.Iterator(dataset)
     checkpoint = trackable_utils.Checkpoint(iterator=iterator)
     self.assertAllEqual([1, 4], iterator.get_next().numpy())
@@ -250,11 +261,16 @@
   def testSaveRestoreMultipleIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
-    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
     dataset = dataset.map(math_ops.square).batch(2)
+    # TODO(b/138399725): Re-enable default optimizations.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
     iterator_1 = datasets.Iterator(dataset)
     iterator_2 = datasets.Iterator(dataset)
-    dataset_2 = Dataset.range(10)
+    dataset_2 = dataset_ops.Dataset.range(10)
     iterator_3 = datasets.Iterator(dataset_2)
 
     checkpoint = trackable_utils.Checkpoint(
@@ -276,7 +292,7 @@
   def testRestoreExhaustedIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
-    dataset = Dataset.range(3)
+    dataset = dataset_ops.Dataset.range(3)
     iterator = datasets.Iterator(dataset)
 
     checkpoint = trackable_utils.Checkpoint(iterator=iterator)
@@ -290,12 +306,12 @@
   def testRestoreInReconstructedIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
-    dataset = Dataset.range(10)
+    dataset = dataset_ops.Dataset.range(10)
     for i in range(5):
       iterator = datasets.Iterator(dataset)
       checkpoint = trackable_utils.Checkpoint(iterator=iterator)
-      checkpoint.restore(checkpoint_management.latest_checkpoint(
-          checkpoint_directory))
+      checkpoint.restore(
+          checkpoint_management.latest_checkpoint(checkpoint_directory))
       for j in range(2):
         self.assertEqual(i * 2 + j, iterator.get_next().numpy())
       checkpoint.save(file_prefix=checkpoint_prefix)
@@ -311,8 +327,8 @@
     input_data = np.random.randn(input_size)
 
     dataset = (
-        Dataset.from_tensor_slices(input_data).repeat(num_epochs)
-        .batch(batch_size))
+        dataset_ops.Dataset.from_tensor_slices(input_data).repeat(
+            num_epochs).batch(batch_size))
     iterator = datasets.Iterator(dataset)
 
     ends = [time.time()]
@@ -321,10 +337,8 @@
 
     deltas = np.ediff1d(ends)
     median_wall_time = np.median(deltas)
-    print(
-        'Slice/repeat/batch eager input size: %d batch size: %d Median wall '
-        'time per element: %f'
-        % (input_size, batch_size, median_wall_time))
+    print('Slice/repeat/batch eager input size: %d batch size: %d Median wall '
+          'time per element: %f' % (input_size, batch_size, median_wall_time))
     self.report_benchmark(
         iters=len(deltas),
         wall_time=median_wall_time,
@@ -339,8 +353,8 @@
     input_data = np.random.randn(input_size)
 
     dataset = (
-        Dataset.from_tensor_slices(input_data).batch(batch_size).cache()
-        .repeat(num_epochs))
+        dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+            batch_size).cache().repeat(num_epochs))
     iterator = datasets.Iterator(dataset)
 
     ends = [time.time()]
@@ -349,10 +363,9 @@
 
     deltas = np.ediff1d(ends)
     median_wall_time = np.median(deltas)
-    print(
-        'Slice/batch/cache/repeat eager input size: %d batch size: %d Median '
-        'wall time per element: %f'
-        % (input_size, batch_size, median_wall_time))
+    print('Slice/batch/cache/repeat eager input size: %d batch size: %d Median '
+          'wall time per element: %f' %
+          (input_size, batch_size, median_wall_time))
     self.report_benchmark(
         iters=len(deltas),
         wall_time=median_wall_time,
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 512605a..cabc71c 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -117,7 +117,7 @@
       "source": [
         "# Download the file\n",
         "path_to_zip = tf.keras.utils.get_file(\n",
-        "    'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n",
+        "    'spa-eng.zip', origin='https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip', \n",
         "    extract=True)\n",
         "\n",
         "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index f61354b..221b076 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -61,7 +61,7 @@
       fused: use fused batch normalization if True
       dtype: float16, float32, or float64
     """
-    super(RevBlock, self).__init__()
+    super(RevBlock, self).__init__(dtype=dtype)
     self.blocks = tf.contrib.checkpoint.List()
     for i in range(n_res):
       curr_batch_norm_first = batch_norm_first and i == 0
@@ -135,7 +135,7 @@
       fused: use fused batch normalization if True
       dtype: float16, float32, or float64
     """
-    super(_Residual, self).__init__()
+    super(_Residual, self).__init__(dtype=dtype)
 
     self.filters = filters
     self.strides = strides
@@ -283,7 +283,7 @@
       fused: use fused batch normalization if True
       dtype: float16, float32, or float64
     """
-    super(_BottleneckResidualInner, self).__init__()
+    super(_BottleneckResidualInner, self).__init__(dtype=dtype)
     axis = 1 if data_format == "channels_first" else 3
     if batch_norm_first:
       self.batch_norm_0 = tf.keras.layers.BatchNormalization(
@@ -365,7 +365,7 @@
       fused: use fused batch normalization if True
       dtype: float16, float32, or float64
     """
-    super(_ResidualInner, self).__init__()
+    super(_ResidualInner, self).__init__(dtype=dtype)
     axis = 1 if data_format == "channels_first" else 3
     if batch_norm_first:
       self.batch_norm_0 = tf.keras.layers.BatchNormalization(
@@ -416,7 +416,7 @@
     Args:
       config: tf.contrib.training.HParams object; specifies hyperparameters
     """
-    super(InitBlock, self).__init__()
+    super(InitBlock, self).__init__(config.dtype)
     self.config = config
     self.axis = 1 if self.config.data_format == "channels_first" else 3
     self.conv2d = tf.keras.layers.Conv2D(
@@ -430,7 +430,8 @@
         dtype=self.config.dtype)
     self.batch_norm = tf.keras.layers.BatchNormalization(
         axis=self.axis, fused=self.config.fused, dtype=self.config.dtype)
-    self.activation = tf.keras.layers.Activation("relu")
+    self.activation = tf.keras.layers.Activation("relu",
+                                                 dtype=self.config.dtype)
 
     if self.config.init_max_pool:
       self.max_pool = tf.keras.layers.MaxPooling2D(
@@ -464,7 +465,7 @@
     Raises:
       ValueError: Unsupported data format
     """
-    super(FinalBlock, self).__init__()
+    super(FinalBlock, self).__init__(dtype=config.dtype)
     self.config = config
     self.axis = 1 if self.config.data_format == "channels_first" else 3
 
@@ -488,7 +489,8 @@
         input_shape=input_shape,
         fused=self.config.fused,
         dtype=self.config.dtype)
-    self.activation = tf.keras.layers.Activation("relu")
+    self.activation = tf.keras.layers.Activation("relu",
+                                                 dtype=self.config.dtype)
     self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D(
         data_format=self.config.data_format, dtype=self.config.dtype)
     self.dense = tf.keras.layers.Dense(
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index 7406787..08f2d8d 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -37,7 +37,7 @@
     Args:
       config: tf.contrib.training.HParams object; specifies hyperparameters
     """
-    super(RevNet, self).__init__()
+    super(RevNet, self).__init__(dtype=config.dtype)
     self.axis = 1 if config.data_format == "channels_first" else 3
     self.config = config
 
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 5c55f7f..e04de05 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import numbers
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -42,6 +41,7 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import resource_loader
+from tensorflow.python.util.compat import collections_abc
 
 _factorization_ops = loader.load_op_library(
     resource_loader.get_path_to_datafile("_factorization_ops.so"))
@@ -388,7 +388,7 @@
       return None
 
     init_mode = "list"
-    if isinstance(wt_init, collections.Iterable):
+    if isinstance(wt_init, collections_abc.Iterable):
       if num_shards == 1 and len(wt_init) == num_wts:
         wt_init = [wt_init]
       assert len(wt_init) == num_shards
@@ -641,9 +641,9 @@
         extras = size % num_shards
         assignments = math_ops.maximum(ids // (ids_per_shard + 1),
                                        (ids - extras) // ids_per_shard)
-        new_ids = array_ops.where(assignments < extras,
-                                  ids % (ids_per_shard + 1),
-                                  (ids - extras) % ids_per_shard)
+        new_ids = array_ops.where_v2(assignments < extras,
+                                     ids % (ids_per_shard + 1),
+                                     (ids - extras) % ids_per_shard)
         return assignments, new_ids
 
     return func
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
index ca65ad4..32e62a6 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
@@ -135,9 +135,10 @@
                     "channel_count must be a rank-0 tensor but got shape ",
                     channel_count_tensor.shape().DebugString()));
 
-    const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
+    const tensorflow::StringPiece contents =
+        contents_tensor.scalar<tstring>()();
     const string file_format =
-        absl::AsciiStrToLower(file_format_tensor.scalar<string>()());
+        absl::AsciiStrToLower(file_format_tensor.scalar<tstring>()());
     const int32 samples_per_second =
         samples_per_second_tensor.scalar<int32>()();
     const int32 channel_count = channel_count_tensor.scalar<int32>()();
@@ -243,7 +244,7 @@
         errors::InvalidArgument("contents must be scalar but got shape ",
                                 contents.shape().DebugString()));
 
-    const tensorflow::StringPiece file_contents = contents.scalar<string>()();
+    const tensorflow::StringPiece file_contents = contents.scalar<tstring>()();
     Decode(context, file_contents, file_format_, samples_per_second_,
            channel_count_, "");
   }
diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc
index 6f8ad48..0bfdc27 100644
--- a/tensorflow/contrib/ffmpeg/decode_video_op.cc
+++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc
@@ -45,7 +45,8 @@
                 errors::InvalidArgument(
                     "contents must be a rank-0 tensor but got shape ",
                     contents_tensor.shape().DebugString()));
-    const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
+    const tensorflow::StringPiece contents =
+        contents_tensor.scalar<tstring>()();
 
     // Write the input data to a temp file.
     string extension;
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
index 7de09e0..ee418fb 100644
--- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc
@@ -45,7 +45,7 @@
   // Copy the encoded audio file to the output tensor.
   Tensor* output = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output));
-  output->scalar<string>()() = encoded_audio;
+  output->scalar<tstring>()() = encoded_audio;
 }
 
 }  // namespace
@@ -95,7 +95,7 @@
                     bits_per_second_tensor.shape().DebugString()));
 
     const string file_format =
-        absl::AsciiStrToLower(file_format_tensor.scalar<string>()());
+        absl::AsciiStrToLower(file_format_tensor.scalar<tstring>()());
     const int32 samples_per_second =
         samples_per_second_tensor.scalar<int32>()();
     const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()();
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index c26fdb1..d4ad46e 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -561,6 +561,14 @@
   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
   log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
+  {
+    string blas_version;
+    if (auto* blas = stream_exec->AsBlas()) {
+      if (blas->GetVersion(&blas_version).ok()) {
+        log.set_blas_version(blas_version);
+      }
+    }
+  }
   for (const auto& result : results) {
     *log.add_results() = result;
   }
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index ddd0494..9231d5b 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -63,7 +63,10 @@
     python_version = "PY2",
     shard_count = 50,
     srcs_version = "PY2AND3",
-    tags = ["notsan"],
+    tags = [
+        "no_oss",
+        "notsan",
+    ],
     deps = [
         ":namedtuples",
         ":random_tensor_pool",
@@ -528,7 +531,10 @@
     python_version = "PY2",
     shard_count = 1,
     srcs_version = "PY2AND3",
-    tags = ["notsan"],
+    tags = [
+        "no_oss",
+        "notsan",
+    ],
     deps = [
         ":gan_estimator",
         ":namedtuples",
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index 2c30126..4d7328d 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -74,7 +74,7 @@
     'INCEPTION_DEFAULT_IMAGE_SIZE',
 ]
 
-INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
+INCEPTION_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
 INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
 INCEPTION_INPUT = 'Mul:0'
 INCEPTION_OUTPUT = 'logits:0'
@@ -108,7 +108,7 @@
   # Unlike numpy, tensorflow's return order is (s, u, v)
   s, u, v = linalg_ops.svd(mat)
   # sqrt is unstable around 0, just use 0 in such case
-  si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
+  si = array_ops.where_v2(math_ops.less(s, eps), s, math_ops.sqrt(s))
   # Note that the v returned by Tensorflow is v = V
   # (when referencing the equation A = U S V^T)
   # This is unlike Numpy which returns v = V^T
@@ -123,7 +123,7 @@
   """Prepare a batch of images for evaluation.
 
   This is the preprocessing portion of the graph from
-  http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.
+  https://storage.googleapis.com/download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.
 
   Note that it expects Tensors in [0, 255]. This function maps pixel values to
   [-1, 1] and resizes to match the InceptionV1 network.
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
index bc7c105..56319eb 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -366,7 +366,7 @@
     incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)
 
     with self.cached_session(use_gpu=True) as sess:
-      incscore_np = sess.run(incscore, {'concat:0': logits})
+      incscore_np = sess.run(incscore, {'concat/concat:0': logits})
 
     self.assertAllClose(_expected_inception_score(logits), incscore_np)
 
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index 0e8a493..1eead8b 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -3,7 +3,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_cc",
 )
 
diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
index c0b4019..4988ce6 100644
--- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 
@@ -65,12 +66,12 @@
 
 class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
  public:
-  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
-                                    DeviceResolverInterface* dev_resolver,
-                                    WorkerCacheInterface* worker_cache,
-                                    int64 step_id,
-                                    RemoteMemoryManager* remote_memory_manager)
-      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+  CollectiveRemoteAccessDistributed(
+      const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+      std::shared_ptr<UnboundedWorkQueue> work_queue,
+      WorkerCacheInterface* worker_cache, int64 step_id,
+      RemoteMemoryManager* remote_memory_manager)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         worker_cache_(worker_cache),
         remote_memory_manager_(remote_memory_manager) {}
 
@@ -152,7 +153,7 @@
 CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
       new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
-                                            worker_cache_, step_id,
+                                            work_queue_, worker_cache_, step_id,
                                             remote_memory_manager_);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
index 4744a9e..51f6201 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
@@ -163,7 +163,7 @@
                               recv_args, step_id_, parsed.FullKey());
 
     // Record "call" in active_ so that it can be aborted cleanly.
-    RegisterCall(call);
+    RegisterCall(call, recv_args);
 
     // RendezvousMgr already aborted, shouldn't send RPC call any more
     if (!call->status().ok()) {
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py
index 4b53d18..543c1da 100644
--- a/tensorflow/contrib/graph_editor/util.py
+++ b/tensorflow/contrib/graph_editor/util.py
@@ -19,11 +19,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import re
 from six import iteritems
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops as tf_array_ops
+from tensorflow.python.util.compat import collections_abc
 
 __all__ = [
     "make_list_of_op",
@@ -157,7 +157,7 @@
         res = tree.__new__(type(tree),
                            (transform_tree(child, fn) for child in tree))
       return res
-    elif isinstance(tree, collections.Sequence):
+    elif isinstance(tree, collections_abc.Sequence):
       res = tree.__new__(type(tree))
       res.__init__(transform_tree(child, fn) for child in tree)
       return res
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index 2bf6097..99d9319 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -198,7 +198,7 @@
     std::vector<string> filenames;
     filenames.reserve(filenames_tensor->NumElements());
     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-      filenames.push_back(filenames_tensor->flat<string>()(i));
+      filenames.push_back(filenames_tensor->flat<tstring>()(i));
     }
 
     *output = new Dataset(ctx, filenames, output_types_);
@@ -233,6 +233,8 @@
       return "SequenceFileDatasetOp::Dataset";
     }
 
+    bool IsStateful() const override { return false; }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -262,11 +264,11 @@
               TF_RETURN_IF_ERROR(status);
 
               Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
-              key_tensor.scalar<string>()() = key;
+              key_tensor.scalar<tstring>()() = std::move(key);
               out_tensors->emplace_back(std::move(key_tensor));
 
               Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
-              value_tensor.scalar<string>()() = value;
+              value_tensor.scalar<tstring>()() = std::move(value);
               out_tensors->emplace_back(std::move(value_tensor));
 
               *end_of_sequence = false;
diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc
index 4218ec0..41c9a8b 100644
--- a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc
+++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc
@@ -73,7 +73,7 @@
     }
     case STRING: {
       out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({}));
-      out_tensors->back().scalar<string>()() = ParseString(ptr);
+      out_tensors->back().scalar<tstring>()() = ParseString(ptr);
       break;
     }
     case DATE: {
@@ -150,7 +150,7 @@
       out_tensors->emplace_back(cpu_allocator(), DT_STRING,
                                 TensorShape({length}));
       for (int32_t i = 0; i < length; i++)
-        out_tensors->back().vec<string>()(i) = ParseString(ptr);
+        out_tensors->back().vec<tstring>()(i) = ParseString(ptr);
       break;
     }
     case DATE_ARR: {
diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc
index ce8972f..67a84b9 100644
--- a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc
+++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc
@@ -379,7 +379,7 @@
 
 Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
   remainder_ = page_size;
-  page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
+  page_ = std::unique_ptr<uint8_t[]>(new uint8_t[remainder_]);
   ptr_ = page_.get();
 
   uint64 start = Env::Default()->NowMicros();
diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h
index 5868c2c..2e50511 100644
--- a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h
+++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h
@@ -74,7 +74,7 @@
 
   mutex mutex_;
 
-  std::unique_ptr<uint8_t> page_;
+  std::unique_ptr<uint8_t[]> page_;
   uint8_t* ptr_;
 };
 
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 05ba915..96f6af2 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -506,7 +506,7 @@
     # constructing multiple additional large tensors.
     components_flat = array_ops.reshape(components, [-1])
     unique_ids, id_index = array_ops.unique(components_flat)
-    id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0]
+    id_is_zero = array_ops.where_v2(math_ops.equal(unique_ids, 0))[:, 0]
     # Map each nonzero id to consecutive values.
     nonzero_consecutive_ids = math_ops.range(
         array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1
diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py
index 2b0bcf6..dfc6af3 100755
--- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py
+++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py
@@ -48,7 +48,7 @@
   corrupt the encode 3-D data within the image.
 
   Based upon [this
-  paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper).
+  paper](https://www.cs.waikato.ac.nz/~ihw/papers/94-HWT-SI-IHW-SIRDS-paper.pdf).
 
   This outputs a SIRDS image as picture_out.png:
 
diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD
index 7773991..4fd9e2c 100644
--- a/tensorflow/contrib/input_pipeline/BUILD
+++ b/tensorflow/contrib/input_pipeline/BUILD
@@ -12,7 +12,7 @@
     "tf_kernel_library",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_kernel_tests_linkstatic",
 )
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
index 886f679..d5da76a 100644
--- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
+++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
@@ -30,7 +30,7 @@
     const Tensor* list;
     OP_REQUIRES_OK(ctx, ctx->input("list", &list));
     int64 num_elements = list->NumElements();
-    auto list_flat = list->flat<string>();
+    auto list_flat = list->flat<tstring>();
 
     // Allocate output.
     Tensor* output_tensor = nullptr;
@@ -48,7 +48,7 @@
     *pos = (*pos + 1) % num_elements;
 
     // Assign value to output.
-    output_tensor->scalar<string>()() = list_flat(*pos);
+    output_tensor->scalar<tstring>()() = list_flat(*pos);
   }
 };
 
diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
index bb0d4c1..34ae684 100644
--- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
+++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
@@ -33,7 +33,7 @@
     std::vector<string> topics;
     topics.reserve(topics_tensor->NumElements());
     for (int i = 0; i < topics_tensor->NumElements(); ++i) {
-      topics.push_back(topics_tensor->flat<string>()(i));
+      topics.push_back(topics_tensor->flat<tstring>()(i));
     }
 
     std::string servers = "";
@@ -128,9 +128,9 @@
               if (message->err() == RdKafka::ERR_NO_ERROR) {
                 // Produce the line as output.
                 Tensor line_tensor(cpu_allocator(), DT_STRING, {});
-                line_tensor.scalar<string>()() =
-                    std::string(static_cast<const char*>(message->payload()),
-                                message->len());
+                line_tensor.scalar<tstring>()().assign(
+                    static_cast<const char*>(message->payload()),
+                    message->len());
                 out_tensors->emplace_back(std::move(line_tensor));
                 *end_of_sequence = false;
                 // Sync offset
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
index 1783a07..3a257d8 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
@@ -21,11 +21,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import functools
 import re
 
 from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.compat import collections_abc
 
 # used for register_type_abbreviation and _type_repr below.
 _TYPE_ABBREVIATIONS = {}
@@ -114,7 +114,7 @@
   """
 
   def __instancecheck__(self, instance):
-    return (isinstance(instance, collections.Sequence) and
+    return (isinstance(instance, collections_abc.Sequence) and
             all(isinstance(x, self._type) for x in instance))
 
 
@@ -130,9 +130,9 @@
   """
 
   def __instancecheck__(self, instance):
-    return (isinstance(instance, collections.Iterable) and
-            isinstance(instance, collections.Sized) and
-            isinstance(instance, collections.Container) and
+    return (isinstance(instance, collections_abc.Iterable) and
+            isinstance(instance, collections_abc.Sized) and
+            isinstance(instance, collections_abc.Container) and
             all(isinstance(x, self._type) for x in instance))
 
 
@@ -157,7 +157,7 @@
 
   def __instancecheck__(self, instance):
     key_type, value_type = self._types  # pylint: disable=unbalanced-tuple-unpacking
-    return (isinstance(instance, collections.Mapping) and
+    return (isinstance(instance, collections_abc.Mapping) and
             all(isinstance(k, key_type) for k in instance.keys()) and
             all(isinstance(k, value_type) for k in instance.values()))
 
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
index b0961e5..394254c 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -41,11 +41,12 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util.compat import collections_abc
 
 # pylint: disable=invalid-name
 
 # Types coercible to Axis.labels
-# We use this instead of collections.Sequence to exclude strings.
+# We use this instead of collections_abc.Sequence to exclude strings.
 LabelsLike = tc.Union(np.ndarray, range, list, tuple)
 
 # Types coercible to a tf.compat.v1.Dimension
@@ -195,7 +196,7 @@
   return axis
 
 
-class Axes(collections.Mapping):
+class Axes(collections_abc.Mapping):
   """Axis names and indices for a tensor.
 
   It is an ordered mapping, with keys given by axis name and values given
@@ -719,7 +720,7 @@
 @tc.accepts(LabeledTensorLike,
             tc.Collection(
                 tc.Union(string_types,
-                         tc.Tuple(string_types, collections.Hashable))),
+                         tc.Tuple(string_types, collections_abc.Hashable))),
             tc.Optional(string_types))
 def expand_dims(labeled_tensor, axes, name=None):
   """Insert dimensions of size 1.
@@ -1055,7 +1056,7 @@
 
 
 @tc.returns(types.FunctionType)
-@tc.accepts(string_types, collections.Callable)
+@tc.accepts(string_types, collections_abc.Callable)
 def define_unary_op(op_name, elementwise_function):
   """Define a unary operation for labeled tensors.
 
@@ -1124,7 +1125,7 @@
 
 
 @tc.returns(types.FunctionType)
-@tc.accepts(string_types, collections.Callable)
+@tc.accepts(string_types, collections_abc.Callable)
 def define_binary_op(op_name, elementwise_function):
   """Define a binary operation that broadcasts labeled tensors.
 
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index a04e377..35ab141 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import types
 
 import numpy as np
@@ -34,6 +33,7 @@
 from tensorflow.python.ops import numerics
 from tensorflow.python.ops import random_ops
 from tensorflow.python.training import input  # pylint: disable=redefined-builtin
+from tensorflow.python.util.compat import collections_abc
 
 
 @tc.returns(core.LabeledTensor)
@@ -52,7 +52,7 @@
 @tc.returns(core.LabeledTensor)
 @tc.accepts(core.LabeledTensorLike,
             tc.Mapping(string_types,
-                       tc.Union(slice, collections.Hashable, list)),
+                       tc.Union(slice, collections_abc.Hashable, list)),
             tc.Optional(string_types))
 def select(labeled_tensor, selection, name=None):
   """Slice out a subset of the tensor.
@@ -111,8 +111,8 @@
         slices[axis_name] = slice(start, stop)
 
       # Needs to be after checking for slices, since slice objects claim to be
-      # instances of collections.Hashable but hash() on them fails.
-      elif isinstance(value, collections.Hashable):
+      # instances of collections_abc.Hashable but hash() on them fails.
+      elif isinstance(value, collections_abc.Hashable):
         slices[axis_name] = axis.index(value)
 
       elif isinstance(value, list):
@@ -400,7 +400,7 @@
 
 
 @tc.returns(tc.List(core.LabeledTensor))
-@tc.accepts(string_types, collections.Callable, int, bool,
+@tc.accepts(string_types, collections_abc.Callable, int, bool,
             tc.Collection(core.LabeledTensorLike), bool,
             tc.Optional(string_types))
 def _batch_helper(default_name,
@@ -606,7 +606,7 @@
 
 # TODO(shoyer): Allow the user to select the axis over which to map.
 @tc.returns(core.LabeledTensor)
-@tc.accepts(collections.Callable, core.LabeledTensorLike,
+@tc.accepts(collections_abc.Callable, core.LabeledTensorLike,
             tc.Optional(string_types))
 def map_fn(fn, labeled_tensor, name=None):
   """Map on the list of tensors unpacked from labeled_tensor.
@@ -661,7 +661,7 @@
 
 
 @tc.returns(core.LabeledTensor)
-@tc.accepts(collections.Callable, core.LabeledTensorLike,
+@tc.accepts(collections_abc.Callable, core.LabeledTensorLike,
             core.LabeledTensorLike, tc.Optional(string_types))
 def foldl(fn, labeled_tensor, initial_value, name=None):
   """Left fold on the list of tensors unpacked from labeled_tensor.
@@ -754,7 +754,7 @@
 
 # pylint: disable=invalid-name
 ReduceAxis = tc.Union(string_types,
-                      tc.Tuple(string_types, collections.Hashable))
+                      tc.Tuple(string_types, collections_abc.Hashable))
 ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis)))
 # pylint: enable=invalid-name
 
@@ -876,7 +876,7 @@
 
 
 @tc.returns(types.FunctionType)
-@tc.accepts(string_types, collections.Callable)
+@tc.accepts(string_types, collections_abc.Callable)
 def define_reduce_op(op_name, reduce_fn):
   """Define a reduction op for labeled tensors.
 
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 8e41000..6010b07 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -77,6 +77,8 @@
     srcs_version = "PY2AND3",
     visibility = [
         "//learning/brain:__subpackages__",
+        "//learning/lib/ami/simple_ml/link_other_ml_tools/tensorflow:__subpackages__",
+        "//storage/d/analysis/prefetch:__pkg__",
         "//tensorflow:__subpackages__",
         "//tensorflow_model_optimization:__subpackages__",
         "//third_party/py/tf_slim:__subpackages__",
@@ -154,6 +156,7 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/ops/losses:losses",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 py_test(
diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
index ee4b037..0923bdd 100644
--- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
+++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
@@ -78,7 +78,7 @@
 int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
   const int64 start = feature_start_indices_[batch];
   if (DT_STRING == values_.dtype())
-    return Fingerprint64(values_.vec<string>().data()[start + n]);
+    return Fingerprint64(values_.vec<tstring>().data()[start + n]);
   return values_.vec<int64>().data()[start + n];
 }
 
@@ -87,7 +87,7 @@
 string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
   const int64 start = feature_start_indices_[batch];
   if (DT_STRING == values_.dtype())
-    return values_.vec<string>().data()[start + n];
+    return values_.vec<tstring>().data()[start + n];
   return std::to_string(values_.vec<int64>().data()[start + n]);
 }
 
@@ -95,7 +95,7 @@
 StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
                                                      int64 n) const {
   const int64 start = feature_start_indices_[batch];
-  return values_.vec<string>().data()[start + n];
+  return values_.vec<tstring>().data()[start + n];
 }
 
 // A column that is backed by a dense tensor.
@@ -118,21 +118,21 @@
 template <>
 int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
   if (DT_STRING == tensor_.dtype())
-    return Fingerprint64(tensor_.matrix<string>()(batch, n));
+    return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
   return tensor_.matrix<int64>()(batch, n);
 }
 
 // Internal type is string or StringPiece when using StringCrosser.
 template <>
 string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
-  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
+  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
   return std::to_string(tensor_.matrix<int64>()(batch, n));
 }
 
 template <>
 StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
                                                     int64 n) const {
-  return tensor_.matrix<string>()(batch, n);
+  return tensor_.matrix<tstring>()(batch, n);
 }
 
 // Updates Output tensors with sparse crosses.
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index e47a52a..385dcc0 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -155,6 +155,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 # Imports the core `InputLayer` symbol in contrib during development.
 InputLayer = fc_core.InputLayer  # pylint: disable=invalid-name
@@ -1403,7 +1404,7 @@
       least one element of `sparse_id_columns` is not a `SparseColumn` or a
       `WeightedSparseColumn`.
   """
-  if (not isinstance(sparse_id_columns, collections.Sequence) or
+  if (not isinstance(sparse_id_columns, collections_abc.Sequence) or
       isinstance(sparse_id_columns, six.string_types)):
     raise TypeError(
         "sparse_id_columns must be a non-string sequence (ex: list or tuple) "
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 65e8d75..d48edc0 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -25,6 +25,7 @@
     srcs_version = "PY2AND3",
     visibility = [
         "//learning/brain:__subpackages__",
+        "//storage/d/analysis/prefetch:__pkg__",
         "//tensorflow:__subpackages__",
         "//video/youtube/personalization:__subpackages__",
     ],
diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
index 99f22d1..a15bbce 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
@@ -19,12 +19,13 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import os
 
 import numpy as np
 import six
 
+from tensorflow.python.util.compat import collections_abc
+
 
 def _pprint(d):
   return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@@ -55,7 +56,7 @@
     for key in param_names:
       value = getattr(self, key, None)
 
-      if isinstance(value, collections.Callable):
+      if isinstance(value, collections_abc.Callable):
         continue
 
       # XXX: should we rather test if instance of estimator?
diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc
index 720c74e..f35453f 100644
--- a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc
+++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc
@@ -36,7 +36,7 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* label_tensor;
     OP_REQUIRES_OK(
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index b6e82cb..fa8dad9 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -133,8 +133,6 @@
 $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
 $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
 $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
-$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/random/*.cc) \
-$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/random/internal/*.cc) \
 tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \
 tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc
 
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index 1293e59..7ace5d9 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -87,9 +87,11 @@
 Assign your NDK location to $NDK_ROOT:
 
 ```bash
-export NDK_ROOT=/absolute/path/to/NDK/android-ndk-rxxx/
+export NDK_ROOT=/absolute/path/to/NDK/android-ndk-r14b
 ```
 
+Note : libtensorflow-core.a cannot be compiled with any ndk version above r14b.
+
 Download the graph if you haven't already:
 
 ```bash
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh
index efa122b..6cf1145 100755
--- a/tensorflow/contrib/makefile/download_dependencies.sh
+++ b/tensorflow/contrib/makefile/download_dependencies.sh
@@ -140,7 +140,7 @@
 replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \
   "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
 # TODO(satok): Remove this once protobuf/autogen.sh is fixed.
-replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \
+replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#https://storage.googleapis.com/download.tensorflow.org/deps/gmock-1.7.0.zip#' \
   "${DOWNLOADS_DIR}/protobuf/autogen.sh"
 cat "third_party/eigen3/gebp_neon.patch" | patch "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h"
 
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index d7ad266..f5b157d 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -36,7 +36,6 @@
 tensorflow/core/lib/strings/scanner.cc
 tensorflow/core/lib/strings/str_util.cc
 tensorflow/core/lib/strings/strcat.cc
-tensorflow/core/lib/strings/stringprintf.cc
 tensorflow/core/lib/wav/wav_io.cc
 tensorflow/core/platform/cpu_info.cc
 tensorflow/core/platform/default/logging.cc
@@ -56,6 +55,7 @@
 tensorflow/core/platform/protobuf.cc
 tensorflow/core/platform/protobuf_util.cc
 tensorflow/core/platform/setround.cc
+tensorflow/core/platform/stringprintf.cc
 tensorflow/core/platform/tensor_coding.cc
 tensorflow/core/platform/tracing.cc
 tensorflow/tools/proto_text/gen_proto_text_functions.cc
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index e284353..d233fe6 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -129,6 +129,7 @@
 tensorflow/core/kernels/fused_batch_norm_op.cc
 tensorflow/core/kernels/fused_eigen_output_kernels.cc
 tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_functor_batched.cc
 tensorflow/core/kernels/gather_nd_op.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD
index 352b2d6..765c93b 100644
--- a/tensorflow/contrib/memory_stats/BUILD
+++ b/tensorflow/contrib/memory_stats/BUILD
@@ -102,4 +102,5 @@
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_ops",
     ],
+    xla_enable_strict_auto_jit = False,
 )
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index b3f4d8c..e46263b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3641,7 +3641,8 @@
       next_shape = array_ops.stack([next_size] + fixed_shape)
       new_value = array_ops.zeros(next_shape, dtype=values.dtype)
       old_value = array.value()
-      assign_op = state_ops.assign(array, new_value, validate_shape=False)
+      with ops.control_dependencies([old_value]):
+        assign_op = state_ops.assign(array, new_value, validate_shape=False)
       with ops.control_dependencies([assign_op]):
         copy_op = array[:size].assign(old_value[:size])
       # return value needs to be the same dtype as no_op() for cond
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index aec0724..e647f61 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -6718,6 +6719,7 @@
 
   def setUp(self):
     ops.reset_default_graph()
+    variable_scope.enable_resource_variables()
 
   def testVars(self):
     metrics.streaming_concat(values=array_ops.ones((10,)))
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 388384a..30375c7 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -172,9 +172,11 @@
     nbins: integer
       number of bins to use for histogram computation
     block_height: integer
-      number of rows in a block (defaults to 1)
+      number of rows in a block (defaults to 1), can be -1 in which
+      case it is set to the size of the corresponding weight tensor.
     block_width: integer
-      number of cols in a block (defaults to 1)
+      number of cols in a block (defaults to 1), can be -1 in which
+      case it is set to the size of the corresponding weight tensor.
     block_pooling_function: string
       Whether to perform average (AVG) or max (MAX) pooling in the block
       (default: AVG)
@@ -489,6 +491,10 @@
     if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
       return self._update_mask(weights, threshold)
 
+    for i in range(2):
+      if block_dims[i] == -1:
+        block_dims[i] = squeezed_weights.get_shape()[i]
+
     if self._block_pooling_function not in ['AVG', 'MAX']:
       raise ValueError('Unknown pooling function for block sparsity: %s' %
                        self._block_pooling_function)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 58080ad..1a925ca 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -129,7 +129,7 @@
       mask_val = new_mask.eval()
       self.assertAllEqual(mask_val, expected_mask)
 
-  def testBlockMasking(self):
+  def testBlockMaskingWithNonnegativeBlockDimensions(self):
     param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
 
     weights_avg = constant_op.constant(
@@ -146,6 +146,25 @@
     self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
                        expected_mask)
 
+  def testBlockMaskingWithNegativeBlockDimensions(self):
+    param_list = ["block_height=1", "block_width=-1", "threshold_decay=0"]
+
+    weights_avg = constant_op.constant([[0.1, 0.1, 0.1, 0.1],
+                                        [0.2, 0.2, 0.2, 0.2],
+                                        [0.3, 0.3, 0.3, 0.3],
+                                        [0.3, 0.3, 0.4, 0.4]])
+    weights_max = constant_op.constant([[0.1, 0.0, 0.1, 0.0],
+                                        [0.0, 0.1, 0.0, 0.2],
+                                        [0.3, 0.0, 0.3, 0.0],
+                                        [0.0, -0.3, 0.0, 0.4]])
+    expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+                     [1., 1., 1., 1.], [1., 1., 1., 1.]]
+
+    self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
+                       expected_mask)
+    self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
+                       expected_mask)
+
   def testBlockMaskingWithHigherDimensions(self):
     param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
 
diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD
index 23f90cf..7522e88 100644
--- a/tensorflow/contrib/mpi/BUILD
+++ b/tensorflow/contrib/mpi/BUILD
@@ -31,7 +31,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_cc",
 )
 
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
index 5e848c9..f8072ac 100644
--- a/tensorflow/contrib/mpi_collectives/BUILD
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -7,7 +7,7 @@
 licenses(["notice"])  # Apache 2.0
 
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_mpi_lib_defines",
     "tf_proto_library_cc",
 )
diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD
index c98ae64..aeb2c67 100644
--- a/tensorflow/contrib/reduce_slice_ops/BUILD
+++ b/tensorflow/contrib/reduce_slice_ops/BUILD
@@ -1,7 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_kernel_tests_linkstatic")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_kernel_tests_linkstatic")
 
 package(
     licenses = ["notice"],  # Apache 2.0
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 4f8186c..78ea637 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -227,9 +227,6 @@
   # pylint: enable=invalid-name
 
 
-_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]
-
-
 @ops.RegisterGradient("LSTMBlockCell")
 def _LSTMBlockCellGrad(op, *grad):
   """Gradient for LSTMBlockCell."""
@@ -247,7 +244,7 @@
   if cell_size is None:
     raise ValueError("cell_size from `cs_prev` should not be None.")
 
-  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
+  (cs_prev_grad, dgates, wci_grad, wcf_grad,
    wco_grad) = gen_rnn_ops.lstm_block_cell_grad(
        x=x,
        cs_prev=cs_prev,
@@ -267,8 +264,8 @@
        h_grad=h_grad,
        use_peephole=op.get_attr("use_peephole"))
 
-  # Backprop from dicfo to xh.
-  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)
+  # Backprop from dgates to xh.
+  xh_grad = math_ops.matmul(dgates, w, transpose_b=True)
 
   x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
   x_grad.get_shape().merge_with(x.get_shape())
@@ -277,13 +274,13 @@
                                 (batch_size, cell_size))
   h_prev_grad.get_shape().merge_with(h_prev.get_shape())
 
-  # Backprop from dicfo to w.
+  # Backprop from dgates to w.
   xh = array_ops.concat([x, h_prev], 1)
-  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
+  w_grad = math_ops.matmul(xh, dgates, transpose_a=True)
   w_grad.get_shape().merge_with(w.get_shape())
 
-  # Backprop from dicfo to b.
-  b_grad = nn_ops.bias_add_grad(dicfo)
+  # Backprop from dgates to b.
+  b_grad = nn_ops.bias_add_grad(dgates)
   b_grad.get_shape().merge_with(b.get_shape())
 
   return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 75710ea..c0939c8 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1948,7 +1948,9 @@
         in an existing scope. If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
-    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
+    # We pass autocast=False because this layer can accept inputs of different
+    # dtypes, so we do not want to automatically cast them to the same dtype.
+    super(PhasedLSTMCell, self).__init__(_reuse=reuse, autocast=False)
     self._num_units = num_units
     self._use_peepholes = use_peepholes
     self._leak = leak
diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD
index a037be7..f092af1 100644
--- a/tensorflow/contrib/rpc/BUILD
+++ b/tensorflow/contrib/rpc/BUILD
@@ -1,4 +1,4 @@
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 
 package(
     default_visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index 47413aa..db197d1 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -1,7 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 # Placeholder for loading internal BUILD rule.
 
 package(
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 6d8c501..3f9400a 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -251,6 +251,7 @@
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
+    xla_enable_strict_auto_jit = False,
 )
 
 cuda_py_test(
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
index 66a464d..824c8da 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
@@ -149,7 +149,8 @@
     x_test = np.random.randint(vocab, size=(self.batch, self.timestep))
     y = np.random.randn(self.batch, self.timestep)
     model = keras.models.Model([inputs, query, state], score)
-    model.compile("rmsprop", "mse")
+    # TODO(b/138592586): Run with single-execution-path
+    model.compile("rmsprop", "mse", experimental_run_tf_function=False)
     model.fit([x, self.query, self.state], (y, y))
     y_ref = model.predict_on_batch([x_test, self.query, self.state])
 
@@ -159,6 +160,9 @@
         config, custom_objects={attention_cls.__name__: attention_cls})
     loaded_model.set_weights(weights)
 
+    # TODO(b/138592586): Run with single-execution-path
+    loaded_model.compile("rmsprop", "mse", experimental_run_tf_function=False)
+
     y = loaded_model.predict_on_batch([x_test, self.query, self.state])
 
     self.assertAllClose(y_ref, y)
@@ -405,11 +409,13 @@
         memory_sequence_length=self.encoder_sequence_length,
         normalize=True,
         dtype=dtype)
-    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
-    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid",
+                                 dtype=dtype)
+    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
     sampler = sampler_py.TrainingSampler()
-    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
+    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler,
+                                              dtype=dtype)
 
     final_outputs, final_state, _ = my_decoder(
         decoder_inputs,
@@ -432,11 +438,13 @@
         scale=True,
         dtype=dtype,
     )
-    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
-    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid",
+                                 dtype=dtype)
+    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
     sampler = sampler_py.TrainingSampler()
-    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
+    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler,
+                                              dtype=dtype)
 
     final_outputs, final_state, _ = my_decoder(
         decoder_inputs,
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 6360d1c..343e5f4 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -407,8 +407,8 @@
       log_prob_neg_inf = array_ops.ones(
           [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf
 
-      log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
-                                  log_prob_neg_inf)
+      log_probs = array_ops.where_v2(log_prob_mask, log_prob_zeros,
+                                     log_prob_neg_inf)
       return log_probs
 
     log_probs = get_probs()
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index a9215e88..0e19d1e 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -2147,7 +2147,8 @@
                initial_cell_state=None,
                name=None,
                attention_layer=None,
-               attention_fn=None):
+               attention_fn=None,
+               dtype=None):
     """Construct the `AttentionWrapper`.
 
     **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
@@ -2224,6 +2225,7 @@
         (attention_mechanism, cell_output, attention_state, attention_layer) and
         outputs (attention, alignments, next_attention_state). If provided, the
         attention_layer_size should be the size of the outputs of attention_fn.
+      dtype: The cell dtype
 
     Raises:
       TypeError: `attention_layer_size` is not None and (`attention_mechanism`
@@ -2232,7 +2234,7 @@
         is a list, and its length does not match that of `attention_layer_size`;
         if `attention_layer_size` and `attention_layer` are set simultaneously.
     """
-    super(AttentionWrapper, self).__init__(name=name)
+    super(AttentionWrapper, self).__init__(name=name, dtype=dtype)
     rnn_cell_impl.assert_like_rnncell("cell", cell)
     if isinstance(attention_mechanism, (list, tuple)):
       self._is_multi = True
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 5e4f5f5..737d686 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -10,7 +10,7 @@
     "py_test",
     "tf_cc_test",
 )
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib")
 
 package(
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
index a690d9b..996e4ce 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -72,7 +72,7 @@
 // Creates a string tensor.
 Tensor CreateStringTensor(const string& value) {
   Tensor tensor(DT_STRING, TensorShape({}));
-  tensor.scalar<string>()() = value;
+  tensor.scalar<tstring>()() = value;
   return tensor;
 }
 
diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py
index 5db4fe0..aefc076 100644
--- a/tensorflow/contrib/slim/python/slim/learning_test.py
+++ b/tensorflow/contrib/slim/python/slim/learning_test.py
@@ -197,7 +197,8 @@
     gradient = constant_op.constant(self._grad_vec, dtype=dtypes.float32)
     variable = variables_lib.Variable(array_ops.zeros_like(gradient))
     multiplier_flag = variables_lib.Variable(True)
-    tensor_multiplier = array_ops.where(multiplier_flag, self._multiplier, 1.0)
+    tensor_multiplier = array_ops.where_v2(multiplier_flag, self._multiplier,
+                                           1.0)
     grad_to_var = (gradient, variable)
     gradient_multipliers = {variable: tensor_multiplier}
 
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
index 69cbb12..7bb73f5 100644
--- a/tensorflow/contrib/sparsemax/BUILD
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -9,7 +9,7 @@
     "tf_py_test",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_kernel_tests_linkstatic",
 )
 
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index fdd7e1e..ca246f9 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -8,7 +8,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 
 package(
     default_visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
index 926e4dd..a8a5b57 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
+++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
@@ -17,8 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 from tensorflow.contrib import layers
 from tensorflow.contrib.framework.python.ops import variables as framework_variables
 
@@ -29,6 +27,7 @@
 from tensorflow.python.ops import variables
 
 from tensorflow.python.training import adagrad
+from tensorflow.python.util.compat import collections_abc
 
 
 class HybridModel(object):
@@ -66,7 +65,7 @@
 
     # If this is a collection of layers, return the mean of their inference
     # results.
-    if isinstance(layer, collections.Iterable):
+    if isinstance(layer, collections_abc.Iterable):
       return math_ops.reduce_mean(
           array_ops.stack([l.inference_graph(data) for l in layer]), 0)
     # If this is a single layer, return its inference result.
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 94650fe..5f997c2 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -52,7 +52,7 @@
 
     auto* result = new DecisionTreeResource(param_proto_);
     if (!ParseProtoUnlimited(result->mutable_decision_tree(),
-                             tree_config_t->scalar<string>()())) {
+                             tree_config_t->scalar<tstring>()())) {
       result->Unref();
       OP_REQUIRES(context, false,
                   errors::InvalidArgument("Unable to parse tree  config."));
@@ -85,7 +85,7 @@
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(0, TensorShape(), &output_config_t));
-    output_config_t->scalar<string>()() =
+    output_config_t->scalar<tstring>()() =
         decision_tree_resource->decision_tree().SerializeAsString();
   }
 };
@@ -116,7 +116,7 @@
     decision_trees::Model* config =
         decision_tree_resource->mutable_decision_tree();
     OP_REQUIRES(context,
-                ParseProtoUnlimited(config, tree_config_t->scalar<string>()()),
+                ParseProtoUnlimited(config, tree_config_t->scalar<tstring>()()),
                 errors::InvalidArgument("Unable to parse tree  config."));
     decision_tree_resource->MaybeInitialize();
   }
@@ -224,7 +224,7 @@
                                                                   : 0);
     OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape,
                                                      &output_tree_paths));
-    auto out_paths = output_tree_paths->unaligned_flat<string>();
+    auto out_paths = output_tree_paths->unaligned_flat<tstring>();
 
     // TODO(gilberth): If this slows down inference too much, consider having
     // a filter that only serializes paths for the predicted label that we're
diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
index b21a917..fcea240 100644
--- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
@@ -38,7 +38,7 @@
 void Evaluate(const Tensor& input_data, Tensor output_data, int32 start,
               int32 end) {
   auto out_data = output_data.unaligned_flat<float>();
-  const auto in_data = input_data.unaligned_flat<string>();
+  const auto in_data = input_data.unaligned_flat<tstring>();
 
   for (int32 i = start; i < end; ++i) {
     out_data(i) = Convert(in_data(i));
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index ede6e1a..e4693cf 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -56,7 +56,7 @@
                 errors::InvalidArgument("Stats config must be a scalar."));
     auto* result = new FertileStatsResource(param_proto_);
     FertileStats stats;
-    if (!ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()())) {
+    if (!ParseProtoUnlimited(&stats, stats_config_t->scalar<tstring>()())) {
       result->Unref();
       OP_REQUIRES(context, false,
                   errors::InvalidArgument("Unable to parse stats config."));
@@ -98,7 +98,7 @@
 
     FertileStats stats;
     fertile_stats_resource->PackToProto(&stats);
-    output_config_t->scalar<string>()() = stats.SerializeAsString();
+    output_config_t->scalar<tstring>()() = stats.SerializeAsString();
   }
 
  private:
@@ -128,9 +128,10 @@
     // Deallocate all the previous objects on the resource.
     fertile_stats_resource->Reset();
     FertileStats stats;
-    OP_REQUIRES(context,
-                ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()()),
-                errors::InvalidArgument("Unable to parse stats config."));
+    OP_REQUIRES(
+        context,
+        ParseProtoUnlimited(&stats, stats_config_t->scalar<tstring>()()),
+        errors::InvalidArgument("Unable to parse stats config."));
 
     fertile_stats_resource->ExtractFromProto(stats);
     fertile_stats_resource->MaybeInitialize();
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
index d205b25..71bfa5b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
@@ -1,7 +1,7 @@
 # TensorFlow code for training random forests.
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 
 package(
     default_visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
index f4a7058..417cb6f 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
@@ -103,7 +103,7 @@
 void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) {
   std::vector<Tensor> outputs;
   RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs);
-  ParseProtoUnlimited(node, outputs[0].unaligned_flat<string>()(0));
+  ParseProtoUnlimited(node, outputs[0].unaligned_flat<tstring>()(0));
   const auto& oblique = split_.inequality_left_child_test().oblique();
   auto* new_split =
       node->mutable_inequality_left_child_test()->mutable_oblique();
diff --git a/tensorflow/contrib/tensor_forest/proto/BUILD b/tensorflow/contrib/tensor_forest/proto/BUILD
index efa696f..702dbed 100644
--- a/tensorflow/contrib/tensor_forest/proto/BUILD
+++ b/tensorflow/contrib/tensor_forest/proto/BUILD
@@ -1,4 +1,4 @@
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 
 package(
     default_visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index df10997..623e52c 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -461,7 +461,7 @@
           mask = math_ops.less(
               r,
               array_ops.ones_like(r) * self.params.bagging_fraction)
-          gather_indices = array_ops.squeeze(array_ops.where(mask), axis=[1])
+          gather_indices = array_ops.squeeze(array_ops.where_v2(mask), axis=[1])
           # TODO(thomaswc): Calculate out-of-bag data and labels, and store
           # them for use in calculating statistics later.
           tree_data = array_ops.gather(processed_dense_features, gather_indices)
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index e5efe4b..801fe67 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -2,7 +2,7 @@
 # TensorBoard module containing volatile or experimental code.
 
 # For platform specific build config
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 package(
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index 94a51ab..017d08f 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -1,7 +1,6 @@
 # Description:
 #   contains parts of TensorFlow that are experimental or unstable and which are not supported.
-
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 package(
@@ -174,7 +173,7 @@
 
 py_test(
     name = "sampling_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["python/training/sampling_ops_test.py"],
     python_version = "PY2",
     srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py
index 7a4abc4..fddcf1e 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops.py
@@ -399,7 +399,7 @@
     conditions_c = math_ops.logical_and(
         math_ops.less_equal(buckets_min, input_length),
         math_ops.less(input_length, buckets_max))
-    which_bucket = math_ops.reduce_min(array_ops.where(conditions_c))
+    which_bucket = math_ops.reduce_min(array_ops.where_v2(conditions_c))
     which_bucket = math_ops.cast(which_bucket, dtypes.int32)
 
     if shapes is not None:
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py
index 849b77d..257cc4f 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops.py
@@ -417,7 +417,7 @@
   ratio_l = target_probs / init_probs
 
   # Replace NaNs with 0s.
-  ratio_l = array_ops.where(
+  ratio_l = array_ops.where_v2(
       math_ops.is_nan(ratio_l), array_ops.zeros_like(ratio_l), ratio_l)
 
   # Calculate list of acceptance probabilities.
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
index e44c4f8..02baf4e 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
@@ -594,7 +594,7 @@
       # unless we explicitly tie them to CPU.
       with ops.colocate_with(self._state_saver._capacity_queue.queue_ref):
         indices_where_not_done = array_ops.reshape(
-            array_ops.where(
+            array_ops.where_v2(
                 math_ops.logical_not(self._state_saver._sequence_is_done)),
             [-1])
         keeping_next_key = array_ops.gather(
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
index fac783b..b0035269 100644
--- a/tensorflow/contrib/verbs/BUILD
+++ b/tensorflow/contrib/verbs/BUILD
@@ -5,7 +5,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_cc",
 )
 
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index fd89109..7fe0765 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -7,7 +7,7 @@
 # ":protos_all_cc" - exports all core TensorFlow protos
 #     ":protos_all_py" - py_proto_library version (Google-internal)
 # ":lib" - exports the public non-test headers for:
-#     platform/: Platform-specific code and external dependencies
+#     //third_party/tensorflow/core/platform:: Platform-specific code and external dependencies
 #     lib/: Low-level libraries that are not TensorFlow-specific
 # ":test" - test equivalent of ":lib".
 #     This is currently public, but may be made internal in the
@@ -104,7 +104,7 @@
 
 # For platform specific build config
 load(
-    ":platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_all_protos",
     "tf_additional_cloud_kernel_deps",
     "tf_additional_cloud_op_deps",
@@ -112,36 +112,26 @@
     "tf_additional_cupti_wrapper_deps",
     "tf_additional_device_tracer_cuda_deps",
     "tf_additional_device_tracer_deps",
-    "tf_additional_device_tracer_srcs",
     "tf_additional_device_tracer_test_flags",
     "tf_additional_gdr_lib_defines",
     "tf_additional_human_readable_json_deps",
     "tf_additional_lib_defines",
     "tf_additional_lib_deps",
-    "tf_additional_lib_hdrs",
-    "tf_additional_lib_srcs",
     "tf_additional_libdevice_data",
     "tf_additional_libdevice_deps",
-    "tf_additional_libdevice_srcs",
     "tf_additional_minimal_lib_srcs",
     "tf_additional_monitoring_hdrs",
-    "tf_additional_monitoring_srcs",
     "tf_additional_mpi_lib_defines",
     "tf_additional_numa_copts",
     "tf_additional_numa_deps",
     "tf_additional_numa_lib_defines",
-    "tf_additional_proto_hdrs",
-    "tf_additional_proto_srcs",
     "tf_additional_test_deps",
-    "tf_additional_test_srcs",
     "tf_additional_verbs_lib_defines",
     "tf_grpc_service_all",
     "tf_jspb_proto_library",
     "tf_kernel_tests_linkstatic",
     "tf_lib_proto_compiler_deps",
     "tf_lib_proto_parsing_deps",
-    "tf_platform_hdrs",
-    "tf_platform_srcs",
     "tf_proto_library",
     "tf_proto_library_cc",
     "tf_protos_all",
@@ -151,10 +141,11 @@
     "tf_pyclif_proto_library",
 )
 load(
-    ":platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_dynamic_kernels",
     "if_static",
     "tf_cuda_tests_tags",
+    "tf_gpu_tests_tags",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
@@ -322,48 +313,36 @@
     visibility = ["//visibility:public"],
 )
 
-# Minimal lib to detect platform
-cc_library(
-    name = "lib_platform",
-    hdrs = [
-        "platform/platform.h",
-    ],
-)
-
 filegroup(
     name = "platform_base_hdrs",
     srcs = [
-        "platform/byte_order.h",
-        "platform/cord.h",
-        "platform/env_time.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/platform_strings.h",
-        "platform/types.h",
+        "//tensorflow/core/platform:byte_order.h",
+        "//tensorflow/core/platform:cord.h",
+        "//tensorflow/core/platform:env_time.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:platform_strings.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_base",
-    srcs = tf_platform_hdrs([
-        "integral_types.h",
-        "logging.h",
-    ]) + tf_platform_srcs([
-        "logging.cc",
-        "env_time.cc",
-    ]) + [
-        "platform/env_time.cc",
-    ],
     hdrs = [":platform_base_hdrs"],
     copts = tf_copts(),
     tags = ["avoid_dep"],
     visibility = [":__subpackages__"],
     deps = [
-        ":lib_platform",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:byte_order",
+        "//tensorflow/core/platform:env_time",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
         "//tensorflow/core/platform/default/build_config:base",
         "@com_google_absl//absl/base",
-        "@com_google_absl//absl/base:log_severity",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -381,13 +360,13 @@
 filegroup(
     name = "platform_port_hdrs",
     srcs = [
-        "platform/cpu_info.h",
-        "platform/dynamic_annotations.h",
-        "platform/init_main.h",
-        "platform/mem.h",
-        "platform/mutex.h",
-        "platform/numa.h",
-        "platform/thread_annotations.h",
+        "//tensorflow/core/platform:cpu_info.h",
+        "//tensorflow/core/platform:dynamic_annotations.h",
+        "//tensorflow/core/platform:init_main.h",
+        "//tensorflow/core/platform:mem.h",
+        "//tensorflow/core/platform:mutex.h",
+        "//tensorflow/core/platform:numa.h",
+        "//tensorflow/core/platform:thread_annotations.h",
     ],
     visibility = ["//visibility:private"],
 )
@@ -396,24 +375,18 @@
 filegroup(
     name = "platform_port_internal_hdrs",
     srcs = [
-        "platform/demangle.h",
-        "platform/host_info.h",
-        "platform/snappy.h",
+        "//tensorflow/core/platform:demangle.h",
+        "//tensorflow/core/platform:host_info.h",
+        "//tensorflow/core/platform:snappy.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_port",
-    srcs = tf_platform_hdrs([
-        "cpu_info.h",
-        "dynamic_annotations.h",
-        "thread_annotations.h",
-        "mutex.h",
-    ]) + tf_platform_srcs([
-        "port.cc",
-    ]) + [
-        "platform/cpu_info.cc",
+    srcs = [
+        "//tensorflow/core/platform:cpu_info.cc",
+        "//tensorflow/core/platform:legacy_platform_port_srcs",
     ],
     hdrs = [
         ":platform_port_hdrs",
@@ -422,7 +395,7 @@
     copts = tf_copts() + tf_additional_numa_copts(),
     visibility = [":__subpackages__"],
     deps = [
-        ":lib_platform",
+        "//tensorflow/core/platform:platform",
         ":platform_base",
         "@com_google_absl//absl/base",
         "//tensorflow/core/platform/default/build_config:port",
@@ -433,7 +406,7 @@
 filegroup(
     name = "platform_protobuf_hdrs",
     srcs = [
-        "platform/protobuf.h",
+        "//tensorflow/core/platform:protobuf.h",
     ],
     visibility = ["//visibility:private"],
 )
@@ -442,19 +415,18 @@
 filegroup(
     name = "platform_protobuf_internal_hdrs",
     srcs = [
-        "platform/protobuf_internal.h",
+        "//tensorflow/core/platform:protobuf_internal.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_protobuf",
-    srcs = tf_platform_hdrs([
-        "protobuf.h",
-    ]) + [
-        "platform/protobuf.cc",
-        "platform/protobuf_util.cc",
+    srcs = [
         "lib/core/status.h",
+        "//tensorflow/core/platform:protobuf.cc",
+        "//tensorflow/core/platform:protobuf.h",
+        "//tensorflow/core/platform:protobuf_util.cc",
     ],
     hdrs = [
         ":platform_protobuf_hdrs",
@@ -463,9 +435,9 @@
     copts = tf_copts(),
     visibility = [":__subpackages__"],
     deps = [
-        ":lib_platform",
         ":platform_base",
         ":platform_port",
+        "//tensorflow/core/platform",
         "//tensorflow/core/platform/default/build_config:protobuf",
         "@com_google_protobuf//:protobuf",
     ],
@@ -475,7 +447,7 @@
     name = "grpc_services",
     srcs = [],
     hdrs = [
-        "platform/grpc_services.h",
+        "//tensorflow/core/platform:grpc_services.h",
     ],
     copts = tf_copts(),
     visibility = ["//visibility:public"],
@@ -484,8 +456,8 @@
 
 cc_library(
     name = "human_readable_json",
-    srcs = tf_platform_srcs(["human_readable_json.cc"]),
-    hdrs = ["platform/human_readable_json.h"],
+    srcs = ["//tensorflow/core/platform:legacy_human_readable_json_src"],
+    hdrs = ["//tensorflow/core/platform:human_readable_json.h"],
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
@@ -496,8 +468,8 @@
 
 cc_library(
     name = "logger",
-    srcs = ["platform/logger.cc"],
-    hdrs = ["platform/logger.h"],
+    srcs = ["//tensorflow/core/platform:logger.cc"],
+    hdrs = ["//tensorflow/core/platform:logger.h"],
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
@@ -511,9 +483,9 @@
 filegroup(
     name = "platform_env_hdrs",
     srcs = [
-        "platform/env.h",
-        "platform/file_statistics.h",
-        "platform/file_system.h",
+        "//tensorflow/core/platform:env.h",
+        "//tensorflow/core/platform:file_statistics.h",
+        "//tensorflow/core/platform:file_system.h",
     ],
     visibility = ["//visibility:private"],
 )
@@ -522,21 +494,17 @@
 filegroup(
     name = "platform_env_internal_hdrs",
     srcs = [
-        "platform/load_library.h",
+        "//tensorflow/core/platform:load_library.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_env",
-    srcs = tf_platform_srcs([
-        "env.cc",
-        "load_library.cc",
-    ]) + tf_platform_hdrs([
-        "wide_char.h",
-    ]) + [
-        "platform/env.cc",
-        "platform/file_system.cc",
+    srcs = [
+        "//tensorflow/core/platform:env.cc",
+        "//tensorflow/core/platform:file_system.cc",
+        "//tensorflow/core/platform:legacy_platform_env_srcs",
     ],
     hdrs = [
         ":platform_env_hdrs",
@@ -551,10 +519,10 @@
         ":error_codes_proto_cc",
         ":lib",
         ":lib_internal",
-        ":lib_platform",
         ":platform_base",
         ":platform_port",
         ":platform_protobuf",
+        "//tensorflow/core/platform",
         "//tensorflow/core/platform/default/build_config:env",
         "//tensorflow/core/platform/default/build_config:port",
     ],
@@ -563,19 +531,17 @@
 filegroup(
     name = "platform_file_system_hdrs",
     srcs = [
-        "platform/file_system_helper.h",
-        "platform/null_file_system.h",
+        "//tensorflow/core/platform:file_system_helper.h",
+        "//tensorflow/core/platform:null_file_system.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_file_system",
-    srcs = tf_platform_srcs([
-    ]) + tf_platform_hdrs([
-        "windows_file_system.h",
-    ]) + [
-        "platform/file_system_helper.cc",
+    srcs = [
+        "//tensorflow/core/platform:file_system_helper.cc",
+        "//tensorflow/core/platform:legacy_file_system_hdrs",
     ],
     hdrs = [
         ":platform_file_system_hdrs",
@@ -584,83 +550,84 @@
     visibility = [":__subpackages__"],
     deps = [
         ":lib",
-        ":lib_platform",
         ":platform_env",
+        "//tensorflow/core/platform",
     ],
 )
 
 cc_library(
     name = "platform_strings",
-    srcs = tf_platform_srcs([
-        "platform/platform_strings.cc",
-        "platform/platform_strings_computed.h",
-    ]),
+    srcs = [
+        "//tensorflow/core/platform:platform_strings.cc",
+        "//tensorflow/core/platform:platform_strings_computed.h",
+    ],
     hdrs = [
-        "platform/platform_strings.h",
+        "//tensorflow/core/platform:platform_strings.h",
     ],
     visibility = [":__subpackages__"],
-    deps = [":lib"],
+    deps = [],
 )
 
 filegroup(
     name = "platform_other_hdrs",
     srcs = [
-        "platform/abi.h",
-        "platform/context.h",
-        "platform/cpu_feature_guard.h",
-        "platform/error.h",
-        "platform/fingerprint.h",
-        "platform/monitoring.h",
-        "platform/net.h",
-        "platform/notification.h",
-        "platform/prefetch.h",
-        "platform/profile_utils/android_armv7a_cpu_utils_helper.h",
-        "platform/profile_utils/clock_cycle_profiler.h",
-        "platform/profile_utils/cpu_utils.h",
-        "platform/profile_utils/i_cpu_utils_helper.h",
-        "platform/stacktrace.h",
-        "platform/stacktrace_handler.h",
-        "platform/strong_hash.h",
-        "platform/subprocess.h",
+        "//tensorflow/core/platform:abi.h",
+        "//tensorflow/core/platform:context.h",
+        "//tensorflow/core/platform:cpu_feature_guard.h",
+        "//tensorflow/core/platform:error.h",
+        "//tensorflow/core/platform:fingerprint.h",
+        "//tensorflow/core/platform:monitoring.h",
+        "//tensorflow/core/platform:net.h",
+        "//tensorflow/core/platform:notification.h",
+        "//tensorflow/core/platform:prefetch.h",
+        "//tensorflow/core/platform:profile_utils/android_armv7a_cpu_utils_helper.h",
+        "//tensorflow/core/platform:profile_utils/clock_cycle_profiler.h",
+        "//tensorflow/core/platform:profile_utils/cpu_utils.h",
+        "//tensorflow/core/platform:profile_utils/i_cpu_utils_helper.h",
+        "//tensorflow/core/platform:stacktrace.h",
+        "//tensorflow/core/platform:stacktrace_handler.h",
+        "//tensorflow/core/platform:strong_hash.h",
+        "//tensorflow/core/platform:subprocess.h",
     ] + tf_additional_monitoring_hdrs(),
     visibility = ["//visibility:private"],
 )
 
+tf_cc_test(
+    name = "platform_unbounded_work_queue_test",
+    srcs = ["//tensorflow/core/platform:unbounded_work_queue_test.cc"],
+    deps = [
+        ":framework",
+        ":lib",
+        ":lib_internal",
+        ":lib_test_internal",
+        ":test",
+        ":test_main",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 # Headers that are not exported as part of ":lib".
 filegroup(
     name = "platform_other_internal_hdrs",
     srcs = [
-        "platform/denormal.h",
-        "platform/setround.h",
-        "platform/tracing.h",
+        "//tensorflow/core/platform:denormal.h",
+        "//tensorflow/core/platform:setround.h",
+        "//tensorflow/core/platform:tracing.h",
     ],
     visibility = ["//visibility:private"],
 )
 
 cc_library(
     name = "platform_other",
-    srcs = tf_platform_srcs([
-        "subprocess.cc",
-        "net.cc",
-        "tracing.cc",
-    ]) + tf_platform_hdrs([
-        "tracing.h",
-        "error.h",
-        "context.h",
-        "fingerprint.h",
-        "notification.h",
-        "stacktrace.h",
-        "strong_hash.h",
-        "subprocess.h",
-        "tracing_impl.h",
-    ]) + [
-        "platform/cpu_feature_guard.cc",
-        "platform/setround.cc",
-        "platform/tracing.cc",
-        "platform/denormal.cc",
-        "platform/profile_utils/android_armv7a_cpu_utils_helper.cc",
-        "platform/profile_utils/clock_cycle_profiler.cc",
-        "platform/profile_utils/cpu_utils.cc",
+    srcs = [
+        "//tensorflow/core/platform:cpu_feature_guard.cc",
+        "//tensorflow/core/platform:denormal.cc",
+        "//tensorflow/core/platform:legacy_platform_other_srcs",
+        "//tensorflow/core/platform:profile_utils/android_armv7a_cpu_utils_helper.cc",
+        "//tensorflow/core/platform:profile_utils/clock_cycle_profiler.cc",
+        "//tensorflow/core/platform:profile_utils/cpu_utils.cc",
+        "//tensorflow/core/platform:setround.cc",
+        "//tensorflow/core/platform:tracing.cc",
     ],
     hdrs = [
         ":platform_other_hdrs",
@@ -670,11 +637,13 @@
     visibility = [":__subpackages__"],
     deps = [
         ":lib",
-        ":lib_platform",
         ":platform_base",
         ":platform_env",
         ":platform_port",
         ":platform_protobuf",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:abi",
+        "//tensorflow/core/platform:stacktrace",
         "//tensorflow/core/platform/default/build_config:other",
         "//tensorflow/core/platform/default/build_config:platformlib",
         "//tensorflow/core/platform/default/build_config:port",
@@ -686,34 +655,45 @@
 # don't have to depend on lib/platformlib.
 cc_library(
     name = "lib_proto_parsing",
-    srcs = glob(tf_additional_proto_srcs()),
+    srcs = [
+        "//tensorflow/core/platform:protobuf.cc",
+    ],
     hdrs = [
         "lib/core/errors.h",
         "lib/core/status.h",
         "lib/core/stringpiece.h",
         "lib/strings/numbers.h",
         "lib/strings/strcat.h",
-        "platform/init_main.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/platform.h",
-        "platform/protobuf.h",
-        "platform/types.h",
-        "platform/windows/cpu_info.h",
-        "lib/bfloat16/bfloat16.h",
-    ] + tf_additional_proto_hdrs(),
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:init_main.h",
+        "//tensorflow/core/platform:legacy_proto_hdrs",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:protobuf.h",
+        "//tensorflow/core/platform:stringpiece.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
+    ],
     copts = tf_copts(),
     deps = tf_lib_proto_parsing_deps() + [
         ":platform_base",
         "@com_google_absl//absl/strings",
         "@double_conversion//:double-conversion",
+        "//tensorflow/core/lib/bfloat16",
+        "//tensorflow/core/platform:cpu_info",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:platform",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/platform:types",
     ],
 )
 
 cc_library(
     name = "lib_proto_compiler",
     hdrs = [
-        "platform/protobuf_compiler.h",
+        "//tensorflow/core/platform:protobuf_compiler.h",
     ],
     copts = tf_copts(),
     deps = tf_lib_proto_compiler_deps() + [
@@ -728,7 +708,6 @@
 cc_library(
     name = "lib",
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
         "lib/core/arena.h",
         "lib/core/bitmap.h",
         "lib/core/bits.h",
@@ -740,14 +719,6 @@
         "lib/core/stringpiece.h",
         "lib/core/threadpool.h",
         "lib/core/threadpool_interface.h",
-        "lib/gtl/array_slice.h",
-        "lib/gtl/cleanup.h",
-        "lib/gtl/compactptrset.h",
-        "lib/gtl/flatmap.h",
-        "lib/gtl/flatset.h",
-        "lib/gtl/inlined_vector.h",
-        "lib/gtl/optional.h",
-        "lib/gtl/priority_queue_util.h",
         "lib/hash/crc32c.h",
         "lib/hash/hash.h",
         "lib/histogram/histogram.h",
@@ -784,10 +755,14 @@
         ":platform_other_hdrs",
         ":platform_port_hdrs",
         ":platform_protobuf_hdrs",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/lib/gtl:legacy_lib_gtl_headers",
     ],
     visibility = ["//visibility:public"],
     deps = [
         ":lib_internal",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/platform:stringprintf",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -823,44 +798,23 @@
 )
 
 cc_library(
-    name = "abi",
-    srcs = ["platform/abi.cc"],
-    hdrs = ["platform/abi.h"],
-    deps = [":platform_base"],
-)
-
-cc_library(
-    name = "stacktrace",
-    srcs = glob(["platform/*/stacktrace.h"]),
-    hdrs = ["platform/stacktrace.h"],
-    deps = [
-        ":abi",
-        ":lib_platform",
-        "//tensorflow/core/platform/default/build_config:stacktrace",
-    ],
-)
-
-cc_library(
     name = "stacktrace_handler",
-    srcs = ["platform/stacktrace_handler.cc"],
-    hdrs = ["platform/stacktrace_handler.h"],
+    srcs = ["//tensorflow/core/platform:stacktrace_handler.cc"],
+    hdrs = ["//tensorflow/core/platform:stacktrace_handler.h"],
     deps = [
-        ":abi",
-        ":lib_platform",
-        ":stacktrace",
+        "//tensorflow/core/platform",
+        "//tensorflow/core/platform:abi",
+        "//tensorflow/core/platform:stacktrace",
     ],
 )
 
-# Libraries that will eventually be moved into lib/core
-# Note that stringpiece_test can't be place here yet, because we are
-# required to use tf_cc_test, and that rule will change / into _
+# DEPRECATED: use platform:stringpiece instead.
 cc_library(
     name = "core_stringpiece",
     hdrs = ["lib/core/stringpiece.h"],
     copts = tf_copts(),
     deps = [
-        ":platform_base",
-        "@com_google_absl//absl/strings",
+        "//tensorflow/core/platform:stringpiece",
     ],
 )
 
@@ -871,14 +825,15 @@
     name = "test",
     testonly = 1,
     srcs = [
-        "platform/test.cc",
         "util/reporter.cc",
-    ] + tf_additional_test_srcs(),
+        "//tensorflow/core/platform:legacy_test_srcs",
+        "//tensorflow/core/platform:test.cc",
+    ],
     hdrs = [
         "lib/core/status_test_util.h",
-        "platform/test.h",
-        "platform/test_benchmark.h",
         "util/reporter.h",
+        "//tensorflow/core/platform:test.h",
+        "//tensorflow/core/platform:test_benchmark.h",
     ],
     copts = tf_copts(),
     linkopts = select({
@@ -904,16 +859,16 @@
     name = "test_lite",
     testonly = 1,
     srcs = [
-        "platform/test.cc",
+        "//tensorflow/core/platform:test.cc",
     ],
     hdrs = [
-        "platform/test.h",
-        "platform/test_benchmark.h",
+        "//tensorflow/core/platform:test.h",
+        "//tensorflow/core/platform:test_benchmark.h",
     ],
     copts = tf_copts(),
     deps = [
-        ":lib_platform",
         ":platform_base",
+        "//tensorflow/core/platform",
         "//tensorflow/core/platform/default/build_config:gtest",
     ],
 )
@@ -1164,36 +1119,40 @@
 
 cc_library(
     name = "framework_lite",
-    srcs = tf_additional_minimal_lib_srcs(),
+    srcs = [
+        "//tensorflow/core/platform:legacy_minimal_lib_srcs",
+    ],
     hdrs = [
         "framework/numeric_types.h",
         "framework/tensor_types.h",
         "framework/type_traits.h",
-        "lib/bfloat16/bfloat16.h",
-        "platform/byte_order.h",
-        "platform/default/dynamic_annotations.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/default/mutex.h",
-        "platform/default/thread_annotations.h",
-        "platform/dynamic_annotations.h",
-        "platform/macros.h",
-        "platform/mutex.h",
-        "platform/platform.h",
-        "platform/prefetch.h",
-        "platform/protobuf.h",
-        "platform/thread_annotations.h",
-        "platform/types.h",
-        "platform/cpu_info.h",
-    ] + if_windows(["platform/windows/integral_types.h"]),
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:byte_order.h",
+        "//tensorflow/core/platform:default/dynamic_annotations.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:default/mutex.h",
+        "//tensorflow/core/platform:default/thread_annotations.h",
+        "//tensorflow/core/platform:dynamic_annotations.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:mutex.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:prefetch.h",
+        "//tensorflow/core/platform:protobuf.h",
+        "//tensorflow/core/platform:thread_annotations.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
+        "//tensorflow/core/platform:cpu_info.h",
+    ] + if_windows(["//tensorflow/core/platform:windows/integral_types.h"]),
     visibility = ["//visibility:public"],
     deps =
         [
             "@nsync//:nsync_cpp",
         ] + [
-            "@com_google_absl//absl/base:log_severity",
             "//third_party/eigen3",
+            "//tensorflow/core/lib/bfloat16",
             "//tensorflow/core/platform/default/build_config:minimal",
+            "//tensorflow/core/platform:types",
         ],
 )
 
@@ -1387,6 +1346,36 @@
         "ragged_conversion_ops",
         "ragged_math_ops",
     ],
+    deps = [":ragged_to_dense_util"],
+)
+
+cc_library(
+    name = "ragged_to_dense_util",
+    srcs = [
+        "ops/ragged_to_dense_util.cc",
+    ],
+    hdrs = [
+        "ops/ragged_to_dense_util.h",
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+tf_cc_test(
+    name = "ragged_to_dense_util_test",
+    srcs = [
+        "ops/ragged_to_dense_util_test.cc",
+    ],
+    deps = [
+        ":ragged_to_dense_util",
+        ":test",
+        ":testlib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_googletest//:gtest_main",
+    ],
 )
 
 cc_library(
@@ -1812,7 +1801,11 @@
         ":error_codes_proto_text_srcs",
         "//tensorflow/core/platform/default/build_config:android_srcs",
         "//tensorflow/core/util/ctc:android_srcs",
+        "//tensorflow/core/platform:legacy_srcs_no_runtime",
         "//tensorflow/core/profiler:mobile_srcs",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.cc",
+        "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers",
     ] + glob(
         [
             "client/**/*.cc",
@@ -1820,8 +1813,6 @@
             "framework/**/*.cc",
             "lib/**/*.h",
             "lib/**/*.cc",
-            "platform/**/*.h",
-            "platform/**/*.cc",
             "public/**/*.h",
             "util/**/*.h",
             "util/**/*.cc",
@@ -1842,22 +1833,6 @@
             "util/events_writer.*",
             "util/stats_calculator.*",
             "util/reporter.*",
-            "platform/**/cuda_libdevice_path.*",
-            "platform/**/logger.cc",
-            # Exclude env_time and logging to avoid collisions with
-            # :platform_base, a common dependency for downstream targets.
-            "platform/**/env_time.cc",
-            "platform/**/logging.cc",
-            "platform/default/test_benchmark.*",
-            "platform/cuda.h",
-            "platform/rocm.h",
-            "platform/google/**/*",
-            "platform/hadoop/**/*",
-            "platform/gif.h",
-            "platform/jpeg.h",
-            "platform/png.h",
-            "platform/stream_executor.*",
-            "platform/windows/**/*",
             "user_ops/**/*.cu.cc",
             "util/ctc/*.h",
             "util/ctc/*.cc",
@@ -2112,7 +2087,7 @@
 filegroup(
     name = "android_test_srcs",
     # TODO(andrewharp/nhua):
-    # make more test-related sources portable e.g. "platform/test.cc",
+    # make more test-related sources portable e.g. "//tensorflow/core/platform:test.cc",
     srcs = [
         ":framework/fake_input.cc",
         ":framework/fake_input.h",
@@ -2120,10 +2095,10 @@
         ":framework/shape_inference_testutil.h",
         ":framework/tensor_testutil.cc",
         ":framework/tensor_testutil.h",
-        ":platform/test.cc",
-        ":platform/test.h",
         ":util/reporter.cc",
         ":util/reporter.h",
+        "//tensorflow/core/platform:test.cc",
+        "//tensorflow/core/platform:test.h",
     ],
     visibility = ["//visibility:public"],
 )
@@ -2136,9 +2111,9 @@
         ":framework/shape_inference_testutil.h",
         ":framework/tensor_testutil.cc",
         ":framework/tensor_testutil.h",
-        ":platform/test.h",
         ":util/reporter.cc",
         ":util/reporter.h",
+        "//tensorflow/core/platform:test.h",
     ],
     visibility = ["//visibility:public"],
 )
@@ -2403,36 +2378,27 @@
     ],
 )
 
-LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob(
+LIB_INTERNAL_PRIVATE_HEADERS = [
+    "framework/resource_handle.h",
+    "//tensorflow/core/platform:legacy_lib_internal_headers",
+    "//tensorflow/core/lib/bfloat16:bfloat16.h",
+    "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers",
+] + glob(
     [
         "lib/**/*.h",
-        "platform/*.h",
-        "platform/profile_utils/**/*.h",
     ],
     exclude = [
         "**/*test*",
         "lib/gif/**/*",
         "lib/jpeg/**/*",
         "lib/png/**/*",
-        "platform/gif.h",
-        "platform/jpeg.h",
-        "platform/png.h",
-        "platform/**/cuda.h",
-        "platform/**/rocm.h",
-        "platform/**/stream_executor.h",
     ],
 )
 
-LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
+LIB_INTERNAL_PUBLIC_HEADERS = [
     "lib/core/blocking_counter.h",
     "lib/core/refcount.h",
-    "lib/gtl/edit_distance.h",
-    "lib/gtl/int_type.h",
-    "lib/gtl/iterator_range.h",
-    "lib/gtl/manual_constructor.h",
-    "lib/gtl/map_util.h",
-    "lib/gtl/stl_util.h",
-    "lib/gtl/top_n.h",
+    "//tensorflow/core/lib/gtl:legacy_lib_internal_public_gtl_headers",
     "lib/hash/hash.h",
     "lib/io/inputbuffer.h",
     "lib/io/iterator.h",
@@ -2454,17 +2420,19 @@
     "lib/strings/proto_serialization.h",
     "lib/strings/scanner.h",
     "lib/wav/wav_io.h",
-    "platform/annotation.h",
-    "platform/demangle.h",
-    "platform/denormal.h",
-    "platform/host_info.h",
-    "platform/platform.h",
-    "platform/monitoring.h",
-    "platform/protobuf_internal.h",
-    "platform/setround.h",
-    "platform/snappy.h",
-    "platform/tensor_coding.h",
-    "platform/tracing.h",
+    "//tensorflow/core/platform:annotation.h",
+    "//tensorflow/core/platform:demangle.h",
+    "//tensorflow/core/platform:denormal.h",
+    "//tensorflow/core/platform:host_info.h",
+    "//tensorflow/core/platform:platform.h",
+    "//tensorflow/core/platform:monitoring.h",
+    "//tensorflow/core/platform:protobuf_internal.h",
+    "//tensorflow/core/platform:setround.h",
+    "//tensorflow/core/platform:snappy.h",
+    "//tensorflow/core/platform:tensor_coding.h",
+    "//tensorflow/core/platform:tracing.h",
+    "//tensorflow/core/platform:unbounded_work_queue.h",
+    "//tensorflow/core/platform:legacy_platform_lib_hdrs",
     "util/env_var.h",
 ]
 
@@ -2472,11 +2440,12 @@
     name = "annotation",
     srcs = [],
     hdrs = [
-        "platform/annotation.h",
+        "//tensorflow/core/platform:annotation.h",
     ],
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/core/platform:macros",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2520,8 +2489,6 @@
     srcs = LIB_INTERNAL_PRIVATE_HEADERS + glob(
         [
             "lib/**/*.cc",
-            "platform/*.cc",
-            "platform/profile_utils/**/*.cc",
             "util/env_var.cc",
         ],
         exclude = [
@@ -2531,46 +2498,27 @@
             "lib/gif/**/*",
             "lib/jpeg/**/*",
             "lib/png/**/*",
-            "platform/**/env_time.cc",
-            "platform/**/monitoring.cc",
-            "platform/**/cuda_libdevice_path.cc",
-            "platform/**/device_tracer.cc",
-            "platform/**/logger.cc",
-            "platform/**/logging.cc",
-            "platform/**/human_readable_json.cc",
-            "platform/abi.cc",
-            "platform/protobuf.cc",
         ],
-    ) + tf_additional_lib_srcs(
-        exclude = [
-            "**/*test*",
-            "platform/**/cuda.h",
-            "platform/**/cuda_libdevice_path.cc",
-            "platform/**/rocm.h",
-            "platform/**/monitoring.cc",
-            "platform/**/stream_executor.h",
-            "platform/**/env_time.cc",
-            "platform/**/device_tracer.cc",
-            "platform/**/logger.cc",
-            "platform/**/logging.cc",
-            "platform/**/human_readable_json.cc",
-            "platform/abi.cc",
-        ] +
-        # Protobuf deps already included through the ":lib_proto_parsing"
-        # dependency.
-        tf_additional_proto_srcs(),
-    ) + tf_additional_monitoring_srcs(),
+    ) + [
+        "//tensorflow/core/platform:legacy_monitoring_srcs",
+        "//tensorflow/core/platform:legacy_platform_lib_srcs",
+        "//tensorflow/core/platform:legacy_lib_internal_srcs",
+    ],
     hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
     copts = tf_copts(),
     defines = LIB_INTERNAL_DEFINES,
     deps = tf_additional_lib_deps() + [
+               ":core_stringpiece",
                ":lib_hash_crc32c_accelerate_internal",
                ":lib_proto_parsing",
-               ":abi",
-               ":core_stringpiece",
+               ":platform_strings",
                "@com_google_absl//absl/memory",
                "@com_google_absl//absl/strings",
                "//third_party/eigen3",
+               "//tensorflow/core/lib/bfloat16",
+               "//tensorflow/core/platform:abi",
+               "//tensorflow/core/platform:cpu_info",
+               "//tensorflow/core/platform:stringprintf",
                "//tensorflow/core/platform/default/build_config:platformlib",
                "@snappy",
                "@zlib_archive//:zlib",
@@ -2592,7 +2540,7 @@
     name = "gif_internal",
     srcs = [
         "lib/gif/gif_io.cc",
-        "platform/gif.h",
+        "//tensorflow/core/platform:gif.h",
     ],
     hdrs = ["lib/gif/gif_io.h"],
     copts = tf_copts(),
@@ -2613,7 +2561,7 @@
     srcs = [
         "lib/jpeg/jpeg_handle.cc",
         "lib/jpeg/jpeg_mem.cc",
-        "platform/jpeg.h",
+        "//tensorflow/core/platform:jpeg.h",
     ],
     hdrs = [
         "lib/jpeg/jpeg_handle.h",
@@ -2636,18 +2584,19 @@
     name = "png_internal",
     srcs = ["lib/png/png_io.cc"],
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
         "lib/core/stringpiece.h",
         "lib/png/png_io.h",
-        "platform/byte_order.h",
-        "platform/cpu_info.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/platform.h",
-        "platform/png.h",
-        "platform/types.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:byte_order.h",
+        "//tensorflow/core/platform:cpu_info.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:png.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     copts = tf_copts(),
     linkopts = select({
@@ -2660,7 +2609,6 @@
         ":lib_internal",
         "//tensorflow/core/platform/default/build_config:png",
         "@com_google_absl//absl/base",
-        "@com_google_absl//absl/base:log_severity",
         "@com_google_absl//absl/strings",
         "@zlib_archive//:zlib",
     ],
@@ -2669,20 +2617,20 @@
 cc_library(
     name = "tflite_portable_logging",
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/platform.h",
-        "platform/types.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     copts = tf_copts(),
     linkopts = ["-ldl"],
     deps = [
         ":platform_base",
         "//tensorflow/core/platform/default/build_config:logging",
-        "@com_google_absl//absl/base:log_severity",
     ],
 )
 
@@ -2691,30 +2639,33 @@
     srcs = if_android([
         "lib/jpeg/jpeg_handle.cc",
         "lib/jpeg/jpeg_mem.cc",
-        "platform/jpeg.h",
+        "//tensorflow/core/platform:jpeg.h",
     ]),
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
         "lib/core/stringpiece.h",
         "lib/jpeg/jpeg_handle.h",
         "lib/jpeg/jpeg_mem.h",
-        "platform/default/dynamic_annotations.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/dynamic_annotations.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/mem.h",
-        "platform/platform.h",
-        "platform/types.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:default/dynamic_annotations.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:dynamic_annotations.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:mem.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:stringpiece.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     copts = tf_copts(),
     linkopts = ["-ldl"],
     deps = [
+        ":core_stringpiece",
+        "//tensorflow/core/platform:stringpiece",
         "//tensorflow/core/platform/default/build_config:jpeg",
         "//tensorflow/core/platform/default/build_config:logging",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/base:log_severity",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2723,24 +2674,25 @@
     name = "android_gif_internal",
     srcs = if_android([
         "lib/gif/gif_io.cc",
-        "platform/gif.h",
+        "//tensorflow/core/platform:gif.h",
         "lib/strings/strcat.h",
         "lib/strings/numbers.h",
     ]),
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
         "lib/core/stringpiece.h",
         "lib/gif/gif_io.h",
-        "lib/gtl/cleanup.h",
-        "platform/default/dynamic_annotations.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/dynamic_annotations.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/mem.h",
-        "platform/platform.h",
-        "platform/types.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/lib/gtl:legacy_android_gif_internal_headers",
+        "//tensorflow/core/platform:default/dynamic_annotations.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:dynamic_annotations.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:mem.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     copts = tf_copts(),
     linkopts = ["-ldl"],
@@ -2748,7 +2700,6 @@
         "//tensorflow/core/platform/default/build_config:gif",
         "//tensorflow/core/platform/default/build_config:logging",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/base:log_severity",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2757,26 +2708,26 @@
     name = "android_png_internal",
     srcs = if_android([
         "lib/png/png_io.cc",
-        "platform/png.h",
+        "//tensorflow/core/platform:png.h",
     ]),
     hdrs = [
-        "lib/bfloat16/bfloat16.h",
         "lib/core/stringpiece.h",
         "lib/png/png_io.h",
-        "platform/byte_order.h",
-        "platform/cpu_info.h",
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
-        "platform/logging.h",
-        "platform/macros.h",
-        "platform/platform.h",
-        "platform/types.h",
+        "//tensorflow/core/lib/bfloat16:bfloat16.h",
+        "//tensorflow/core/platform:byte_order.h",
+        "//tensorflow/core/platform:cpu_info.h",
+        "//tensorflow/core/platform:default/integral_types.h",
+        "//tensorflow/core/platform:default/logging.h",
+        "//tensorflow/core/platform:logging.h",
+        "//tensorflow/core/platform:macros.h",
+        "//tensorflow/core/platform:platform.h",
+        "//tensorflow/core/platform:tstring.h",
+        "//tensorflow/core/platform:types.h",
     ],
     copts = tf_copts(),
     linkopts = ["-ldl"],
     deps = [
         "//tensorflow/core/platform/default/build_config:logging",
-        "@com_google_absl//absl/base:log_severity",
         "@com_google_absl//absl/strings",
         "@png_archive//:png",
     ],
@@ -3052,11 +3003,11 @@
 
 tf_cuda_library(
     name = "stream_executor",
-    srcs = ["platform/stream_executor.h"],
+    srcs = ["//tensorflow/core/platform:stream_executor.h"],
     hdrs = [
-        "platform/cuda.h",
-        "platform/rocm.h",
-        "platform/stream_executor.h",
+        "//tensorflow/core/platform:cuda.h",
+        "//tensorflow/core/platform:rocm.h",
+        "//tensorflow/core/platform:stream_executor.h",
     ],
     deps = [
         "//tensorflow/core/platform/default/build_config:stream_executor",
@@ -3067,9 +3018,9 @@
 # and does not include any cuda dependencies.
 cc_library(
     name = "stream_executor_no_cuda",
-    srcs = ["platform/stream_executor.h"],
+    srcs = ["//tensorflow/core/platform:stream_executor.h"],
     hdrs = [
-        "platform/stream_executor_no_cuda.h",
+        "//tensorflow/core/platform:stream_executor_no_cuda.h",
     ],
     visibility = ["//visibility:public"],
     deps = [
@@ -3450,7 +3401,7 @@
 cc_library(
     name = "regexp_internal",
     hdrs = [
-        "platform/regexp.h",
+        "//tensorflow/core/platform:regexp.h",
     ],
     visibility = [
         "//tensorflow/compiler:__subpackages__",
@@ -3510,7 +3461,9 @@
 
 tf_cuda_library(
     name = "device_tracer",
-    srcs = tf_additional_device_tracer_srcs(),
+    srcs = [
+        "//tensorflow/core/platform:legacy_device_tracer_srcs",
+    ],
     copts = tf_copts(),
     cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
     visibility = [
@@ -3742,11 +3695,11 @@
     name = "lib_test_internal",
     testonly = 1,
     hdrs = [
-        "lib/gtl/manual_constructor.h",
         "lib/io/block.h",
         "lib/io/block_builder.h",
         "lib/io/format.h",
         "lib/random/philox_random_test_utils.h",
+        "//tensorflow/core/lib/gtl:legacy_lib_test_internal_headers",
     ],
     deps = [
         ":lib",
@@ -3785,7 +3738,7 @@
 cc_library(
     name = "test_main",
     testonly = 1,
-    srcs = ["platform/test_main.cc"],
+    srcs = ["//tensorflow/core/platform:test_main.cc"],
     copts = tf_copts(),
     linkopts = select({
         "//tensorflow:windows": [],
@@ -3806,16 +3759,16 @@
 cc_library(
     name = "test_lite_main",
     testonly = 1,
-    srcs = ["platform/test_main.cc"],
+    srcs = ["//tensorflow/core/platform:test_main.cc"],
     copts = tf_copts(),
     deps = [
         # TODO(ahentz): we don't want to depend on "lib" here. It used to be
         # that "core_stringpiece" was enough but that recently changed and
         # we now need at least "str_util".
         ":lib",
-        ":lib_platform",
         ":stacktrace_handler",
         ":test_lite",
+        "//tensorflow/core/platform",
         "//tensorflow/core/platform/default/build_config:test_lite_main",
     ],
     alwayslink = 1,
@@ -3832,18 +3785,7 @@
         "lib/core/notification_test.cc",
         "lib/core/refcount_test.cc",
         "lib/core/status_test.cc",
-        "lib/core/stringpiece_test.cc",
         "lib/core/threadpool_test.cc",
-        "lib/gtl/cleanup_test.cc",
-        "lib/gtl/compactptrset_test.cc",
-        "lib/gtl/edit_distance_test.cc",
-        "lib/gtl/flatmap_test.cc",
-        "lib/gtl/flatset_test.cc",
-        "lib/gtl/int_type_test.cc",
-        "lib/gtl/iterator_range_test.cc",
-        "lib/gtl/manual_constructor_test.cc",
-        "lib/gtl/map_util_test.cc",
-        "lib/gtl/top_n_test.cc",
         "lib/hash/crc32c_test.cc",
         "lib/hash/hash_test.cc",
         "lib/histogram/histogram_test.cc",
@@ -3872,18 +3814,20 @@
         "lib/strings/scanner_test.cc",
         "lib/strings/str_util_test.cc",
         "lib/strings/strcat_test.cc",
-        "lib/strings/stringprintf_test.cc",
         "lib/wav/wav_io_test.cc",
-        "platform/fingerprint_test.cc",
-        "platform/integral_types_test.cc",
-        "platform/logging_test.cc",
-        "platform/mutex_test.cc",
-        "platform/net_test.cc",
-        "platform/port_test.cc",
-        "platform/profile_utils/cpu_utils_test.cc",
-        "platform/stacktrace_handler_test.cc",
-        "platform/subprocess_test.cc",
-        "platform/vmodule_benchmark_test.cc",
+        "//tensorflow/core/lib/gtl:legacy_lib_gtl_tests",
+        "//tensorflow/core/platform:fingerprint_test.cc",
+        "//tensorflow/core/platform:integral_types_test.cc",
+        "//tensorflow/core/platform:logging_test.cc",
+        "//tensorflow/core/platform:mutex_test.cc",
+        "//tensorflow/core/platform:net_test.cc",
+        "//tensorflow/core/platform:port_test.cc",
+        "//tensorflow/core/platform:profile_utils/cpu_utils_test.cc",
+        "//tensorflow/core/platform:stacktrace_handler_test.cc",
+        "//tensorflow/core/platform:stringpiece_test.cc",
+        "//tensorflow/core/platform:stringprintf_test.cc",
+        "//tensorflow/core/platform:subprocess_test.cc",
+        "//tensorflow/core/platform:vmodule_benchmark_test.cc",
     ],
     deps = [
         ":core_cpu_internal",
@@ -3893,6 +3837,8 @@
         ":protos_all_cc",
         ":test",
         ":test_main",
+        "//tensorflow/core/platform:stringpiece",
+        "//tensorflow/core/platform:stringprintf",
         "//third_party/eigen3",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
@@ -3902,7 +3848,7 @@
 
 tf_cc_test(
     name = "vmodule_test",
-    srcs = ["platform/vmodule_test.cc"],
+    srcs = ["//tensorflow/core/platform:vmodule_test.cc"],
     tags = ["optonly"],
     deps = [
         ":lib",
@@ -3933,7 +3879,7 @@
 tf_cc_test(
     name = "platform_strings_test",
     size = "small",
-    srcs = ["platform/platform_strings_test.cc"],
+    srcs = ["//tensorflow/core/platform:platform_strings_test.cc"],
     features = ["-dynamic_link_test_srcs"],  # see go/dynamic_link_test_srcs
     deps = [
         ":lib",
@@ -3944,7 +3890,7 @@
 tf_cc_test(
     name = "platform_env_test",
     size = "small",
-    srcs = ["platform/env_test.cc"],
+    srcs = ["//tensorflow/core/platform:env_test.cc"],
     deps = [
         ":lib",
         ":lib_internal",
@@ -3959,7 +3905,7 @@
 tf_cc_test(
     name = "platform_fake_python_env_test",
     size = "small",
-    srcs = ["platform/fake_python_env_test.cc"],
+    srcs = ["//tensorflow/core/platform:fake_python_env_test.cc"],
     args = [
         "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py",
     ],
@@ -3982,7 +3928,7 @@
 tf_cc_test(
     name = "platform_abi_test",
     size = "small",
-    srcs = ["platform/abi_test.cc"],
+    srcs = ["//tensorflow/core/platform:abi_test.cc"],
     deps = [
         ":framework",
         ":lib",
@@ -3998,7 +3944,7 @@
 tf_cc_test(
     name = "platform_numa_test",
     size = "small",
-    srcs = ["platform/numa_test.cc"],
+    srcs = ["//tensorflow/core/platform:numa_test.cc"],
     tags = [
         # This test will not pass unless it has access to all NUMA nodes
         # on the executing machine.
@@ -4020,7 +3966,7 @@
 tf_cc_test(
     name = "platform_setround_test",
     size = "small",
-    srcs = ["platform/setround_test.cc"],
+    srcs = ["//tensorflow/core/platform:setround_test.cc"],
     tags = [
         "noasan",
         "noclang",
@@ -4039,7 +3985,7 @@
 tf_cc_test(
     name = "platform_file_system_test",
     size = "small",
-    srcs = ["platform/file_system_test.cc"],
+    srcs = ["//tensorflow/core/platform:file_system_test.cc"],
     deps = [
         ":lib",
         ":lib_internal",
@@ -4609,6 +4555,20 @@
     ],
 )
 
+tf_cc_test_gpu(
+    name = "rocm_rocdl_path_test",
+    size = "small",
+    srcs = ["//tensorflow/core/platform:rocm_rocdl_path_test.cc"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_gpu_tests_tags(),
+    deps = [
+        ":lib",
+        ":test",
+        ":test_main",
+        "//tensorflow/core/platform:rocm_rocdl_path",
+    ],
+)
+
 tf_cuda_only_cc_test(
     name = "util_gpu_kernel_helper_test",
     srcs = [
@@ -5351,7 +5311,7 @@
 tf_cc_test_gpu(
     name = "device_tracer_test",
     size = "small",
-    srcs = ["platform/device_tracer_test.cc"],
+    srcs = ["//tensorflow/core/platform:device_tracer_test.cc"],
     args =
         ["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
     linkstatic = tf_kernel_tests_linkstatic(),
@@ -5575,10 +5535,12 @@
 
 cc_library(
     name = "cuda_libdevice_path",
-    srcs = tf_additional_libdevice_srcs(),
-    hdrs = ["platform/cuda_libdevice_path.h"],
+    srcs = [
+        "//tensorflow/core/platform:legacy_libdevice_srcs",
+    ],
     copts = tf_copts(),
     data = tf_additional_libdevice_data(),
+    textual_hdrs = ["//tensorflow/core/platform:cuda_libdevice_path.h"],
     visibility = ["//visibility:public"],
     deps = [
         ":lib",
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt
new file mode 100644
index 0000000..7c6d161
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "AnonymousMemoryCache"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt
new file mode 100644
index 0000000..327a068
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "AnonymousRandomSeedGenerator"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..07366bf
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt
@@ -0,0 +1,53 @@
+op {
+  graph_op_name: "ApplyAdagradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "var"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "accum"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "lr"
+    description: <<END
+Scaling factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+Constant factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "grad"
+    description: <<END
+The gradient.
+END
+  }
+  out_arg {
+    name: "out"
+    description: <<END
+Same as "var".
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If `True`, updating of the var and accum tensors will be protected
+by a lock; otherwise the behavior is undefined, but may exhibit less
+contention.
+END
+  }
+  summary: "Update \'*var\' according to the adagrad scheme."
+  description: <<END
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt
index 8a213aa..47956bd 100644
--- a/tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt
@@ -17,7 +17,7 @@
   import tensorflow as tf
   a = [1, 10, 26.9, 2.8, 166.32, 62.3]
   b = tf.math.argmax(input = a)
-  c = tf.keras.backend.eval(b)  
+  c = tf.keras.backend.eval(b)
   # c = 4
   # here a[4] = 166.32 which is the largest element of a across axis 0
   ```
diff --git a/tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt
index 6a5f2fa..5ebee5c 100644
--- a/tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt
@@ -17,7 +17,7 @@
   import tensorflow as tf
   a = [1, 10, 26.9, 2.8, 166.32, 62.3]
   b = tf.math.argmin(input = a)
-  c = tf.keras.backend.eval(b)  
+  c = tf.keras.backend.eval(b)
   # c = 0
   # here a[0] = 1 which is the smallest element of a across axis 0
   ```
diff --git a/tensorflow/core/api_def/base_api/api_def_CacheDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CacheDatasetV2.pbtxt
new file mode 100644
index 0000000..665d7ce
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CacheDatasetV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CacheDatasetV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ConfigureTPUEmbedding.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConfigureTPUEmbedding.pbtxt
new file mode 100644
index 0000000..4198734
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ConfigureTPUEmbedding.pbtxt
@@ -0,0 +1,12 @@
+op {
+  graph_op_name: "ConfigureTPUEmbedding"
+  visibility: HIDDEN
+  attr {
+    name: "config"
+    description: <<END
+Serialized tensorflow.tpu.TPUEmbeddingConfiguration that
+describes the embedding lookups of the program.
+END
+  }
+  summary: "Sets up TPUEmbedding in a distributed TPU system."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DeleteMemoryCache.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeleteMemoryCache.pbtxt
new file mode 100644
index 0000000..791e6080
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DeleteMemoryCache.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DeleteMemoryCache"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DeleteRandomSeedGenerator.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeleteRandomSeedGenerator.pbtxt
new file mode 100644
index 0000000..3197405
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DeleteRandomSeedGenerator.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DeleteRandomSeedGenerator"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt
index b845530..d45abf5 100644
--- a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt
@@ -8,9 +8,9 @@
 END
   }
   in_arg {
-  name: "num_workers"
+  name: "num_replicas"
   description: <<END
-A scalar representing the number of workers to distribute this batch across. As
+A scalar representing the number of replicas to distribute this batch across. As
 a result of this transformation the current batch size would end up being
 divided  by this parameter.
 END
@@ -18,6 +18,6 @@
   summary: "Creates a dataset that changes the batch size."
   description: <<END
 Creates a dataset that changes the batch size of the dataset to current batch
-size // num_workers.
+size // num_replicas.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
index 9f3f9b2..68b78be 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
@@ -21,13 +21,13 @@
   }
   summary: "Gather slices from `params` into a Tensor with shape specified by `indices`."
   description: <<END
-`indices` is an K-dimensional integer tensor, best thought of as a
+`indices` is a K-dimensional integer tensor, best thought of as a
 (K-1)-dimensional tensor of indices into `params`, where each element defines a
 slice of `params`:
 
     output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
 
-Whereas in `tf.gather` `indices` defines slices into the first
+Whereas in `tf.gather` `indices` defines slices into the `axis`
 dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
 first `N` dimensions of `params`, where `N = indices.shape[-1]`.
 
diff --git a/tensorflow/core/api_def/base_api/api_def_Invert.pbtxt b/tensorflow/core/api_def/base_api/api_def_Invert.pbtxt
index c6cb1c1..44c67e6 100644
--- a/tensorflow/core/api_def/base_api/api_def_Invert.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Invert.pbtxt
@@ -28,10 +28,10 @@
                                       input_tensor, bitwise_ops.invert(input_tensor)),
                                     bitwise_ops.invert(
                                       tf.constant(0, dtype=dtype))]
-  
+
   expected = tf.constant([0, 0, 0, 0], dtype=tf.float32)
   tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected)
-  
+
   expected = tf.cast([not_0] * 4, tf.float32)
   tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected)
 
diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedTensorToTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToTensor.pbtxt
new file mode 100644
index 0000000..1746221
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToTensor.pbtxt
@@ -0,0 +1,81 @@
+op {
+  graph_op_name: "RaggedTensorToTensor"
+  visibility: HIDDEN
+  attr {
+    name: "row_partition_types"
+    description: <<END
+The types of the row partition tensors. At present, these can be:
+* "ROW_SPLITS": the row_splits tensor from the ragged tensor.
+* "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor.
+* "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it
+  is preceeded by "FIRST_DIM_SIZE".
+The tensors are in the order of the dimensions.
+END
+  }
+  in_arg {
+    name: "shape"
+    description: <<END
+The desired shape of the the output tensor. If left unspecified (empty),
+the minimal shape required to contain all the elements in the ragged tensor
+(the natural shape) will be used. If some dimensions are left unspecified, then
+the size of the natural shape is used in that dimension.
+
+Note that dense dimensions cannot be modified by the shape argument. Trying to
+change the size of a dense dimension will cause the op to fail.
+Examples:
+natural shape: [4, 5, 6]
+shape: -1
+output shape: [4, 5, 6]
+
+natural shape: [4, 5, 6]
+shape: [3, -1, 2]
+output shape: [3, 5, 2]
+
+natural shape: [4, 5, 6]
+shape: [3, 7, 2]
+output shape: [3, 7, 2]
+
+END
+  }
+in_arg {
+    name: "values"
+    description: <<END
+A 1D tensor representing the values of the ragged tensor.
+END
+  }
+  in_arg {
+    name: "default_value"
+    description: <<END
+The default_value when the shape is larger than the ragged tensor. The
+default_value is broadcast until it is the shape of the output tensor, and
+then overwritten by values in the ragged tensor. The default value must be
+compatible with this broadcast operation, and must have fewer dimensions than
+the value tensor.
+END
+  }
+  out_arg {
+    name: "result"
+    description: "The resulting dense tensor."
+  }
+  summary: <<END
+Create a dense tensor from a ragged tensor, possibly altering its shape.
+END
+  description: <<END
+The `ragged_to_dense` op creates a dense tensor from a list of row partition
+tensors, a value vector, and default values. If the shape is unspecified, the
+minimal shape required to contain all the elements in the ragged tensor (the
+natural shape) will be used. If some dimensions are left unspecified, then the
+size of the natural shape is used in that dimension.
+
+The default_value will be broadcast to the output shape. After that, the values
+from the ragged tensor overwrite the default values. Note that the default_value
+must have less dimensions than the value.
+
+The row partition tensors are in the order of the dimensions.
+At present, the types can be:
+* "ROW_SPLITS": the row_splits tensor from the ragged tensor.
+* "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor.
+* "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it
+  is preceded by "FIRST_DIM_SIZE".
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RebatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_RebatchDataset.pbtxt
index 7017e37..9375f3e 100644
--- a/tensorflow/core/api_def/base_api/api_def_RebatchDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RebatchDataset.pbtxt
@@ -8,9 +8,9 @@
 END
   }
   in_arg {
-  name: "num_workers"
+  name: "num_replicas"
   description: <<END
-A scalar representing the number of workers to distribute this batch across. As
+A scalar representing the number of replicas to distribute this batch across. As
 a result of this transformation the current batch size would end up being
 divided  by this parameter.
 END
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..d99d418
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdagradV2.pbtxt
@@ -0,0 +1,47 @@
+op {
+  graph_op_name: "ResourceApplyAdagradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "var"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "accum"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "lr"
+    description: <<END
+Scaling factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+Constant factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "grad"
+    description: <<END
+The gradient.
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If `True`, updating of the var and accum tensors will be protected
+by a lock; otherwise the behavior is undefined, but may exhibit less
+contention.
+END
+  }
+  summary: "Update \'*var\' according to the adagrad scheme."
+  description: <<END
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceSparseApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceSparseApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..5c98df6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceSparseApplyAdagradV2.pbtxt
@@ -0,0 +1,54 @@
+op {
+  graph_op_name: "ResourceSparseApplyAdagradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "var"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "accum"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "lr"
+    description: <<END
+Learning rate. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+Constant factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "grad"
+    description: <<END
+The gradient.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A vector of indices into the first dimension of var and accum.
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If `True`, updating of the var and accum tensors will be protected
+by a lock; otherwise the behavior is undefined, but may exhibit less
+contention.
+END
+  }
+  summary: "Update relevant entries in \'*var\' and \'*accum\' according to the adagrad scheme."
+  description: <<END
+That is for rows we have grad for, we update var and accum as follows:
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SamplingDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_SamplingDataset.pbtxt
index 0c6fccd..48c01e9 100644
--- a/tensorflow/core/api_def/base_api/api_def_SamplingDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SamplingDataset.pbtxt
@@ -4,8 +4,8 @@
   in_arg {
     name: "rate"
     description: <<END
-A scalar representing the sample rate of elements from the `input_dataset`
-that should be taken.
+A scalar representing the sample rate. Each element of `input_dataset` is 
+retained with this probability, independent of all other elements.
 END
   }
   in_arg {
@@ -20,5 +20,12 @@
 A scalar representing seed2 of random number generator.
 END
   }
-  summary: "Creates a dataset that contains `rate` elements from the `input_dataset`."
+  summary: "Creates a dataset that takes a Bernoulli sample of the contents of another dataset."
+  description: <<END
+There is no transformation in the `tf.data` Python API for creating this dataset.
+Instead, it is created as a result of the `filter_with_random_uniform_fusion`
+static optimization. Whether this optimization is performed is determined by the
+`experimental_optimization.filter_with_random_uniform_fusion` option of
+`tf.data.Options`.
+END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_ShuffleDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ShuffleDatasetV2.pbtxt
new file mode 100644
index 0000000..b5a6e2d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ShuffleDatasetV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ShuffleDatasetV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..e44d329
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagradV2.pbtxt
@@ -0,0 +1,60 @@
+op {
+  graph_op_name: "SparseApplyAdagradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "var"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "accum"
+    description: <<END
+Should be from a Variable().
+END
+  }
+  in_arg {
+    name: "lr"
+    description: <<END
+Learning rate. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+Constant factor. Must be a scalar.
+END
+  }
+  in_arg {
+    name: "grad"
+    description: <<END
+The gradient.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A vector of indices into the first dimension of var and accum.
+END
+  }
+  out_arg {
+    name: "out"
+    description: <<END
+Same as "var".
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If `True`, updating of the var and accum tensors will be protected
+by a lock; otherwise the behavior is undefined, but may exhibit less
+contention.
+END
+  }
+  summary: "Update relevant entries in \'*var\' and \'*accum\' according to the adagrad scheme."
+  description: <<END
+That is for rows we have grad for, we update var and accum as follows:
+$$accum += grad * grad$$
+$$var -= lr * grad * (1 / sqrt(accum))$$
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt
new file mode 100644
index 0000000..d3d1a01
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt
@@ -0,0 +1,69 @@
+op {
+  graph_op_name: "StringNGrams"
+  in_arg {
+    name: "data"
+    description: <<END
+The values tensor of the ragged string tensor to make ngrams out of. Must be a
+1D string tensor.
+END
+  }
+  in_arg {
+    name: "data_splits"
+    description: <<END
+The splits tensor of the ragged string tensor to make ngrams out of.
+END
+  }
+  out_arg {
+    name: "ngrams"
+    description: <<END
+The values tensor of the output ngrams ragged tensor.
+END
+  }
+  out_arg {
+    name: "ngrams_splits"
+    description: <<END
+The splits tensor of the output ngrams ragged tensor.
+END
+  }
+  attr {
+    name: "separator"
+    description: <<END
+The string to append between elements of the token. Use "" for no separator.
+END
+  }
+  attr {
+    name: "ngram_widths"
+    description: <<END
+The sizes of the ngrams to create.
+END
+  }
+  attr {
+    name: "left_pad"
+    description: <<END
+The string to use to pad the left side of the ngram sequence. Only used if
+pad_width != 0.
+END
+  }
+  attr {
+    name: "right_pad"
+    description: <<END
+The string to use to pad the right side of the ngram sequence. Only used if
+pad_width != 0.
+END
+}
+  attr {
+    name: "pad_width"
+    description: <<END
+The number of padding elements to add to each side of each
+sequence. Note that padding will never be greater than 'ngram_widths'-1
+regardless of this value. If `pad_width=-1`, then add `max(ngram_widths)-1`
+elements.
+END
+  }
+  summary: "Creates ngrams from ragged string data."
+  description: <<END
+This op accepts a ragged tensor with 1 ragged dimension containing only
+strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
+of that string, joined along the innermost axis.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListPushBack.pbtxt
index 73297c0..988d8b6 100644
--- a/tensorflow/core/api_def/base_api/api_def_TensorListPushBack.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListPushBack.pbtxt
@@ -1,6 +1,6 @@
 op {
   graph_op_name: "TensorListPushBack"
-  summary: "Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`."
+  summary: "Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`."
   description: <<END
 tensor: The tensor to put on the list.
 input_handle: The old list.
diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc
index ddac98d..2b3a8f6 100644
--- a/tensorflow/core/api_def/excluded_ops.cc
+++ b/tensorflow/core/api_def/excluded_ops.cc
@@ -42,9 +42,9 @@
           "QuantizedMatMulWithBiasAndReluAndRequantize",
 #endif  // INTEL_MKL
 #ifdef GOOGLE_TENSORRT
-          "CreateTRTEngineCacheHandle",
-          "PopulateTRTEngineCache",
-          "DumpTRTEngineCache",
+          "CreateTRTResourceHandle",
+          "InitializeTRTResource",
+          "SerializeTRTResource",
           "GetCalibrationDataOp",
           "TRTEngineOp",
 #endif  // GOOGLE_TENSORRT
diff --git a/tensorflow/core/api_def/python_api/api_def_Fill.pbtxt b/tensorflow/core/api_def/python_api/api_def_Fill.pbtxt
new file mode 100644
index 0000000..ba9b5dc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Fill.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Fill"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index ee202490..d3694dc 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -1,10 +1,4 @@
 op {
   graph_op_name: "Reshape"
-  endpoint {
-    name: "reshape"
-  }
-  endpoint {
-    name: "manip.reshape"
-    deprecation_version: 2
-  }
+  visibility: HIDDEN
 }
diff --git a/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt
new file mode 100644
index 0000000..acefd9b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StringNGrams"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
index 62424eb..15e58ce 100644
--- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
+++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
@@ -85,9 +85,9 @@
       // With `parallel_iterations == 1` it's safe to use TemporaryVariable.
       if (is_in_while_loop) {
         int parallel_iterations;
-        Status s = GetNodeAttr(frame->attrs(), kParallelIterationsAttrName,
-                               &parallel_iterations);
-        if (s.ok() && parallel_iterations == 1) {
+        bool found = TryGetNodeAttr(frame->attrs(), kParallelIterationsAttrName,
+                                    &parallel_iterations);
+        if (found && parallel_iterations == 1) {
           is_in_while_loop = false;
         }
       }
@@ -112,8 +112,8 @@
 
       // The pieces of AccumulateNV2 should all be on the same node.
       node_builder.Device(n->requested_device());
-      string colo;
-      if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
+      const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
+      if (!colo.empty()) {
         node_builder.Attr(kColocationAttrName, colo);
       }
       return node_builder;
@@ -261,8 +261,8 @@
             .Attr("T", dtype)
             .Input(data_inputs)
             .ControlInputs(control_inputs);
-    string colo;
-    if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
+    const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
+    if (!colo.empty()) {
       builder.Attr(kColocationAttrName, colo);
     }
     TF_RETURN_IF_ERROR(builder.Finalize(g, &add_n_node));
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 0734e53..fa3dde9 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -262,11 +262,9 @@
     delete col_impl;
     return;
   }
-  // Run in an I/O thread, so as not to starve the executor threads.
-  // TODO(b/80529858): Instead of forking every per-device Collective
-  // Op off into its own thread, consider queuing them on a
-  // fixed-size thread-pool dedicated to running CollectiveOps.
-  SchedClosure([col_impl, col_ctx, done_safe, ctx]() {
+  // Run on an unbounded work queue that can handle blocking work so as to not
+  // starve executor threads.
+  remote_access_->RunClosure([col_impl, col_ctx, done_safe, ctx]() {
     profiler::TraceMe activity(
         [&] {
           return strings::StrCat(ctx->op_kernel().name(), ":",
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 6ecfca2..1f1c809 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -142,6 +142,10 @@
                                client_locality, done);
   }
 
+  void RunClosure(std::function<void()> closure) override {
+    remote_access_->RunClosure(std::move(closure));
+  }
+
   // If we need to enforce an ordering on any portion of collective
   // implementation, and the ordering is encoded via attribute on the collective
   // op, this function will block until all dependencies for this collective
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index 7bbc7ca..e9e0082 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -31,7 +31,9 @@
       dev_resolver_(std::move(dev_resolver)),
       param_resolver_(std::move(param_resolver)),
       gpu_ring_order_(
-          config.gpu_options().experimental().collective_ring_order()) {}
+          config.gpu_options().experimental().collective_ring_order()),
+      work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
+                                                       "collective_ops")) {}
 
 CollectiveExecutorMgr::~CollectiveExecutorMgr() {
   for (auto iter : executor_table_) {
@@ -56,8 +58,8 @@
 }
 
 CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
-  CollectiveRemoteAccessLocal* rma =
-      new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+  CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
+      dev_mgr_, dev_resolver_.get(), work_queue_, step_id);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
 }
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 4db121a..d4cef14 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -17,6 +17,7 @@
 
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 class ConfigProto;
@@ -63,6 +64,10 @@
   std::unique_ptr<DeviceResolverInterface> dev_resolver_;
   std::unique_ptr<ParamResolverInterface> param_resolver_;
   string gpu_ring_order_;
+  // Unbounded work queue for scheduling potentially-blocking work during
+  // collective op execution.  Ownership is shared between `this` and
+  // `CollectiveRemoteAccessLocal`.
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 
  private:
   mutex exec_mu_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 2be3f62..97523e3 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -56,18 +56,13 @@
 }
 
 namespace {
-string GetCollectiveName(const CollectiveParams* cp, bool nccl) {
+const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
   switch (cp->instance.type) {
     case BROADCAST_COLLECTIVE:
       return "HierarchicalTreeBroadcast";
 
-    case REDUCTION_COLLECTIVE: {
-      if (nccl) {
-        return "NcclReduce";
-      } else {
-        return "RingReduce";
-      }
-    }
+    case REDUCTION_COLLECTIVE:
+      return nccl ? "NcclReduce" : "RingReduce";
 
     case GATHER_COLLECTIVE:
       return "RingGather";
@@ -96,15 +91,22 @@
 
       // Initialize group runtime details.
       CollectiveImplementationInterface* col_impl;
-      // TODO(b/128853131,b/132707282): Remove NCCL special case when we have
-      // NCCL implementations for all collectives.
-      status = CollectiveRegistry::LookupParamResolverInstance(
-          nccl_ ? "NcclReduce" : GetCollectiveName(cp, /*nccl=*/false),
-          &col_impl);
+      // Try to lookup a NCCL collective kernel.  This will return error status
+      // if `NcclReduce` kernel is not present in the registry, e.g. on an
+      // environment that does not support NCCL.
+      status = CollectiveRegistry::LookupParamResolverInstance("NcclReduce",
+                                                               &col_impl);
+      if (!status.ok()) {
+        // Fallback to non-NCCL collective.
+        status = CollectiveRegistry::LookupParamResolverInstance(
+            GetCollectiveName(cp, /*nccl=*/false), &col_impl);
+      }
       if (status.ok()) {
         status = col_impl->InitializeCollectiveGroupRuntimeDetails(
             &gr->group.runtime_details);
-      } else {
+      }
+
+      if (!status.ok()) {
         done(status, gr);
         return;
       }
@@ -702,6 +704,7 @@
 void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
     const string& device, const GroupRec* gr, CollectiveParams* cp,
     InstanceRec* ir, bool is_source, const StatusCallback& done) {
+  auto expected_shape = cp->instance.shape;
   // Populate the fields common across instance.
   {
     mutex_lock l(ir->out_mu);
@@ -709,6 +712,16 @@
     // custom operator= does a deep copy.
     cp->instance = ir->shared.instance;
   }
+  if (expected_shape != cp->instance.shape) {
+    done(errors::InvalidArgument(
+        "Shape mismatch in the collective instance ", cp->instance.instance_key,
+        ". Op at device ", device, " expected shape ",
+        expected_shape.DebugString(), " but another member in the group ",
+        "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
+        " due to different input shapes at different members of the collective",
+        " op."));
+    return;
+  }
   // Populate the fields common across task.
   AssignCollectiveType(cp);
   SetDefaultRank(device, cp);
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 160161f..b5d02f4 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -14,10 +14,12 @@
 ==============================================================================*/
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
+
 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 
@@ -26,13 +28,15 @@
  public:
   CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
                               DeviceResolverInterface* dev_resolver,
+                              std::shared_ptr<UnboundedWorkQueue> work_queue,
                               int64 step_id)
       : dev_mgr_(dev_mgr),
         dev_resolver_(dev_resolver),
+        work_queue_(std::move(work_queue)),
         buf_rendezvous_(step_id, dev_mgr),
         step_id_(step_id) {}
 
-  virtual ~CollectiveRemoteAccessLocal() {}
+  ~CollectiveRemoteAccessLocal() override = default;
 
   void StartAbort(const Status& s) override;
 
@@ -52,6 +56,10 @@
                   const DeviceLocality& client_locality,
                   const StatusCallback& done) override;
 
+  void RunClosure(std::function<void()> closure) override {
+    work_queue_->Schedule(std::move(closure));
+  }
+
   void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
                                    const std::vector<string>& tasks,
                                    std::vector<DeviceAttributes>* attributes,
@@ -88,6 +96,9 @@
  protected:
   const DeviceMgr* dev_mgr_;               // not owned
   DeviceResolverInterface* dev_resolver_;  // not owned
+  // Ownership of `work_queue_` is shared between `this` and
+  // `CollectiveExecutorMgr`.
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   BufRendezvous buf_rendezvous_;
   int64 step_id_;
 };
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index 2e9d8cd..6024359 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 #include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
@@ -38,20 +39,24 @@
   const string kTaskName = "/job:localhost/replica:0/task:0";
 
   CollectiveRemoteAccessLocalTest() {
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
     ConfigProto cp;
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", NUM_DEVS});
     std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
-    device_mgr_.reset(new DeviceMgr(std::move(devices)));
-    drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
-    prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
-                                                drl_.get(), kTaskName));
-    rma_.reset(new CollectiveRemoteAccessLocal(device_mgr_.get(), drl_.get(),
-                                               kStepId));
+    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
+    drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
+    prl_ = absl::make_unique<CollectiveParamResolverLocal>(
+        cp, device_mgr_.get(), drl_.get(), kTaskName);
+    rma_ = absl::make_unique<CollectiveRemoteAccessLocal>(
+        device_mgr_.get(), drl_.get(), work_queue_, kStepId);
   }
 
+  ~CollectiveRemoteAccessLocalTest() override = default;
+
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
   std::unique_ptr<CollectiveParamResolverLocal> prl_;
diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc
index 4fd40a1..4f70680 100644
--- a/tensorflow/core/common_runtime/colocation_graph.cc
+++ b/tensorflow/core/common_runtime/colocation_graph.cc
@@ -438,7 +438,7 @@
   return true;
 }
 
-Status Member::AssignDevice(const Node& node, bool allow_soft_placement) {
+Status Member::AssignDevice(const Node& node) {
   if (node.assigned_device_name_index() == assigned_device_name_index_) {
     return Status::OK();
   }
@@ -539,7 +539,7 @@
     : graph_(*graph),
       stack_(stack),
       flib_def_(*flib_def),
-      inspecting_placer_(graph, stack, flib_def, device_set, default_device,
+      inspecting_placer_(stack, flib_def, device_set, default_device,
                          allow_soft_placement, log_device_placement),
       inspection_required_checker_(graph, flib_def),
       device_set_(*device_set),
@@ -914,7 +914,7 @@
   }
   int root = FindAndUpdateRoot(node.id());
   Member& root_member = members_[root];
-  return root_member.AssignDevice(node, allow_soft_placement_);
+  return root_member.AssignDevice(node);
 }
 
 void ColocationGraph::GetSoftDeviceCandidates(
diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h
index 410b943..1d71a90 100644
--- a/tensorflow/core/common_runtime/colocation_graph.h
+++ b/tensorflow/core/common_runtime/colocation_graph.h
@@ -80,7 +80,7 @@
   // not update this. Else returns true and updates this.
   bool MergeSupportedDevices(const Member& other);
 
-  Status AssignDevice(const Node& node, bool allow_soft_placement);
+  Status AssignDevice(const Node& node);
 
   // Limit the possible devices of this (should be a root) to the device
   // specifications in `devices`.
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index beeca57..5c7d3ef 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -419,7 +419,7 @@
     const Graph* orig_graph, const std::vector<Node*>& nodes,
     const std::unordered_map<const Node*, std::vector<Tensor>>&
         shape_replacement_map,
-    std::map<NodeAndOutput, Node*>* tensors_to_fetch,
+    std::map<NodeAndOutput, NodeAndOutput>* tensors_to_fetch,
     const ConstantFoldNameGenerator& generate_new_name) {
   Graph* constant_graph = new Graph(orig_graph->op_registry());
   std::unordered_map<Node*, std::vector<Node*>> node_map;
@@ -441,7 +441,7 @@
         if (added_nodes.second.size() == 1) {
           tensors_to_fetch->insert(
               {{added_nodes.second[0], out_edge->src_output()},
-               added_nodes.first});
+               {added_nodes.first, out_edge->src_output()}});
         } else {
           // The node had multiple outputs and was replaced by a
           // vector of constants, so the NodeAndOutput is the 0th
@@ -449,7 +449,7 @@
           // output of the added node as in the standard case above.
           tensors_to_fetch->insert(
               {{added_nodes.second[out_edge->src_output()], 0},
-               added_nodes.first});
+               {added_nodes.first, out_edge->src_output()}});
         }
       }
     }
@@ -590,7 +590,7 @@
     return Status::OK();
   }
 
-  std::map<NodeAndOutput, Node*> tensors_to_fetch;
+  std::map<NodeAndOutput, NodeAndOutput> tensors_to_fetch;
   std::unique_ptr<Graph> constant_graph(
       GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
                        &tensors_to_fetch, generate_new_name));
@@ -609,17 +609,18 @@
   std::vector<NodeAndOutput> tensors_to_replace;
   // Sorting the nodes based on the name gives us a stable ordering between runs
   // for the same graph.
-  std::vector<std::pair<NodeAndOutput, Node*>> tensors_to_fetch_sorted(
+  std::vector<std::pair<NodeAndOutput, NodeAndOutput>> tensors_to_fetch_sorted(
       tensors_to_fetch.begin(), tensors_to_fetch.end());
   std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
-            [](const std::pair<NodeAndOutput, Node*>& n1,
-               const std::pair<NodeAndOutput, Node*>& n2) {
-              return n1.first.first->name() < n2.first.first->name();
+            [](const std::pair<NodeAndOutput, NodeAndOutput>& n1,
+               const std::pair<NodeAndOutput, NodeAndOutput>& n2) {
+              return std::tie(n1.first.first->name(), n1.first.second) <
+                     std::tie(n2.first.first->name(), n2.first.second);
             });
   for (auto n : tensors_to_fetch_sorted) {
     tensors_to_fetch_names.push_back(
         strings::StrCat(n.first.first->name(), ":", n.first.second));
-    tensors_to_replace.push_back({n.second, n.first.second});
+    tensors_to_replace.push_back(n.second);
   }
 
   auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
diff --git a/tensorflow/core/common_runtime/data/BUILD b/tensorflow/core/common_runtime/data/BUILD
index 1909018..e5102d0 100644
--- a/tensorflow/core/common_runtime/data/BUILD
+++ b/tensorflow/core/common_runtime/data/BUILD
@@ -3,7 +3,7 @@
 )
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_protos_all")
 
 cc_library(
     name = "standalone",
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index 1cd8931..56ac71d 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -49,6 +49,14 @@
         return vector;
       }()) {}
 
+DeviceMgr::~DeviceMgr() {
+  // Release resources ahead of destroying the device manager as the resource
+  // destructors (e.g. ~IteratorResource) assume devices still exist.
+  for (auto& device : devices_) {
+    device->ClearResourceMgr();
+  }
+}
+
 StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
   size_t n = s.size();
   char* space = name_backing_store_.Alloc(n);
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index bf86946..3cef631 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -42,6 +42,8 @@
   // Constructs a DeviceMgr managing a single device.
   explicit DeviceMgr(std::unique_ptr<Device> device);
 
+  ~DeviceMgr();
+
   // Returns attributes of all devices.
   void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
 
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 3661367..f6bd957 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -42,6 +42,7 @@
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/logging.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/run_handler.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -353,7 +354,10 @@
     } else {
       printf("Device mapping:\n%s", mapping_str.c_str());
     }
-    LOG(INFO) << "Device mapping:\n" << mapping_str;
+    string msg = strings::StrCat("Device mapping:\n", mapping_str);
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
   }
   for (auto d : device_mgr_->ListDevices()) {
     devices_.push_back(d);
@@ -381,9 +385,6 @@
   for (auto d : device_mgr_->ListDevices()) {
     d->op_segment()->RemoveHold(session_handle_);
   }
-  for (auto d : device_mgr_->ListDevices()) {
-    d->ClearResourceMgr();
-  }
   functions_.clear();
   delete cancellation_manager_;
   for (const auto& p_and_owned : thread_pools_) {
@@ -496,7 +497,17 @@
   RunState run_state(step_id, &devices_);
 
   profiler::TraceMe activity(
-      [&] { return strings::StrCat("SessionRun #id=", step_id, "#"); },
+      [&] {
+        if (options_.config.experimental().has_session_metadata()) {
+          const auto& model_metadata =
+              options_.config.experimental().session_metadata();
+          return strings::StrCat("SessionRun #id=", step_id,
+                                 ",model_id=", model_metadata.name(), ":",
+                                 model_metadata.version(), "#");
+        } else {
+          return strings::StrCat("SessionRun #id=", step_id, "#");
+        }
+      },
       profiler::TraceMeLevel::kInfo);
 
   std::unique_ptr<DebuggerStateInterface> debugger_state;
@@ -590,7 +601,7 @@
 
   std::unique_ptr<ProfilerSession> profiler_session;
   if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
-    profiler_session = ProfilerSession::Create(/*ProfilerContext*/ nullptr);
+    profiler_session = ProfilerSession::Create();
   }
 
   if (run_options.inter_op_thread_pool() < -1 ||
@@ -1614,15 +1625,15 @@
     }
   }
 
-  for (const auto& partition : partitions) {
+  for (auto& partition : partitions) {
     std::unique_ptr<Graph> device_graph(
         new Graph(client_graph->flib_def.get()));
     GraphConstructorOptions device_opts;
     // There are internal operations (e.g., send/recv) that we now allow.
     device_opts.allow_internal_ops = true;
     device_opts.expect_device_spec = true;
-    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
-                                              device_graph.get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+        device_opts, std::move(partition.second), device_graph.get()));
     outputs->emplace(partition.first, std::move(device_graph));
   }
 
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 8da13aa..b073d1a 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1055,9 +1055,9 @@
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output("y", TensorShape({}), &out_tensor));
     if (ctx->session_metadata() != nullptr) {
-      out_tensor->scalar<string>()() = ctx->session_metadata()->DebugString();
+      out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
     } else {
-      out_tensor->scalar<string>()() = "";
+      out_tensor->scalar<tstring>()() = "";
     }
   }
 };
@@ -1079,7 +1079,7 @@
   run_opts.set_inter_op_thread_pool(-1);
   auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
 
-  EXPECT_EQ("", outputs[0].scalar<string>()());
+  EXPECT_EQ("", outputs[0].scalar<tstring>()());
 }
 
 TEST(DirectSessionTest, SessionMetadataPresent) {
@@ -1104,7 +1104,7 @@
 
   SessionMetadata read_metadata;
   ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
-      outputs[0].scalar<string>()(), &read_metadata));
+      outputs[0].scalar<tstring>()(), &read_metadata));
   EXPECT_EQ("name", read_metadata.name());
   EXPECT_EQ(1, read_metadata.version());
 }
@@ -1468,7 +1468,7 @@
 
   const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
   Tensor string_handle(DT_STRING, {});
-  string_handle.flat<string>().setConstant(resource_handle.name());
+  string_handle.flat<tstring>().setConstant(resource_handle.name());
 
   // Second run call: Use a handle.
   std::vector<Tensor> outputs1;
@@ -1521,7 +1521,7 @@
 
   const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
   Tensor string_handle(DT_STRING, {});
-  string_handle.flat<string>().setConstant(resource_handle.name());
+  string_handle.flat<tstring>().setConstant(resource_handle.name());
 
   // Second run call: Use a handle.
   std::vector<Tensor> outputs1;
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 5d77170..abc93b5 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -3,6 +3,10 @@
     "tf_cc_test",
     "tf_cuda_library",
 )
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "if_mkl",
+)
 
 package(
     default_visibility = [
@@ -264,8 +268,21 @@
             "//tensorflow/core:protos_all_cc",
             "//tensorflow/core/distributed_runtime/eager:eager_client",
             "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
+            "//tensorflow/core/distributed_runtime/eager:remote_copy_node",
         ],
-    }),
+    }) + if_mkl([":mkl_eager_op_rewrite"]),
+)
+
+cc_library(
+    name = "mkl_eager_op_rewrite",
+    srcs = ["mkl_eager_op_rewrite.cc"],
+    deps = [
+        ":eager_op_rewrite_registry",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:mkl_graph_util",
+    ],
 )
 
 cc_library(
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 9c7ad99..eb8e912 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -80,7 +80,7 @@
       log_device_placement_(opts.config.log_device_placement()),
       allow_soft_placement_(opts.config.allow_soft_placement()),
       num_active_steps_(0),
-      async_default_(async),
+      default_executor_(async),
       log_memory_(LogMemory::IsEnabled()),
       env_(opts.env),
       use_send_tensor_rpc_(false),
@@ -102,6 +102,10 @@
     this->thread_pool_->Schedule(std::move(closure));
   };
 
+#if !defined(IS_MOBILE_PLATFORM)
+  context_id_ = kInvalidContextId;
+#endif  // IS_MOBILE_PLATFORM
+
   std::unique_ptr<DeviceResolverInterface> drl(
       new DeviceResolverLocal(local_device_mgr()));
   std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
@@ -112,10 +116,6 @@
 }
 
 void EagerContext::InitDeviceMapAndAsync() {
-  if (async_default_) {
-    executor_.EnableAsync();
-  }
-
   for (auto* device : devices_) {
     devices_map_[device->name()] = device;
   }
@@ -136,38 +136,38 @@
   prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
 }
 
-bool EagerContext::Async() const {
-  mutex_lock l(async_map_mu_);
-  return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
-                              async_default_);
+EagerExecutor* EagerContext::Executor() {
+  tf_shared_lock l(executor_map_mu_);
+  return gtl::FindWithDefault(thread_local_executor_,
+                              std::this_thread::get_id(), &default_executor_);
 }
 
-Status EagerContext::SetAsyncForThread(bool async) {
-  {
-    tensorflow::mutex_lock l(async_map_mu_);
-    thread_local_async_[std::this_thread::get_id()] = async;
-  }
-  if (async) {
-    executor_.EnableAsync();
+void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
+  tensorflow::mutex_lock l(executor_map_mu_);
+  if (executor == &default_executor_) {
+    thread_local_executor_.erase(std::this_thread::get_id());
   } else {
-    // TODO(agarwal): Currently we add a wait here to handle cases where a
-    // sync op has a control dependency on an async op, and the latter has not
-    // executed yet. This wait can be removed by storing all the control
-    // inputs and waiting for them when executing ops.
-    return executor_.WaitForAllPendingNodes();
+    thread_local_executor_[std::this_thread::get_id()] = executor;
   }
-  return Status::OK();
 }
 
 void EagerContext::ClearCaches() {
-  // The executor stores pointers to kernels, so we need to make sure that no
-  // async eager ops are still executing. We lock the cache during this time as
-  // well.
-  mutex_lock ml(cache_mu_);
-  executor_.WaitForAllPendingNodes().IgnoreError();
-  kernel_cache_.clear();
-  for (auto& entry : registered_functions_) {
-    entry.second->cached_kernel_keys->clear();
+  {
+    mutex_lock ml(executor_map_mu_);
+    for (auto& entry : thread_local_executor_) {
+      entry.second->WaitForAllPendingNodes().IgnoreError();
+    }
+  }
+  {
+    // The executor stores pointers to kernels, so we need to make sure that no
+    // async eager ops are still executing. We lock the cache during this time
+    // as well.
+    mutex_lock ml(cache_mu_);
+    default_executor_.WaitForAllPendingNodes().IgnoreError();
+    kernel_cache_.clear();
+    for (auto& entry : registered_functions_) {
+      entry.second->cached_kernel_keys->clear();
+    }
   }
 }
 
@@ -211,6 +211,9 @@
   // Close all remote contexts.
   eager::CloseContextRequest request;
   request.set_context_id(context_id_);
+  // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
+  // request to closed remote workers.
+  context_id_ = kInvalidContextId;
   std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
   BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
 
@@ -232,9 +235,49 @@
   }
 
   counter.Wait();
+
+  remote_contexts_.clear();
 }
+
 #endif  // !IS_MOBILE_PLATFORM
 
+void EagerContext::WaitForAndCloseRemoteContexts() {
+  ClearCaches();
+
+#if !defined(IS_MOBILE_PLATFORM)
+  {
+    mutex_lock l(keep_alive_thread_shutdown_mu_);
+    shutting_down_ = true;
+    keep_alive_thread_cv_.notify_all();
+  }
+  keep_alive_thread_.reset();
+
+  mutex_lock l(remote_state_mu_);
+  if (!remote_contexts_.empty() && is_master_) {
+    CloseRemoteContexts();
+  }
+
+  default_executor_.ShutDown().IgnoreError();
+  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+  {
+    mutex_lock l(executor_map_mu_);
+    executors_copy = thread_local_executor_;
+  }
+  for (const auto& it : executors_copy) {
+    it.second->ShutDown().IgnoreError();
+  }
+
+  // This shuts down the completion queue and joins the thread polling it.
+  // The thread exits only after the completion queue has been drained of all
+  // the events. These events' completion should invoke all remaining RPC
+  // callbacks.
+  // This also deletes all EagerClient instances. There should not be any
+  // references to EagerClients left after all RPCs and async ops have been
+  // finished.
+  remote_eager_workers_ = nullptr;
+#endif  // !IS_MOBILE_PLATFORM
+}
+
 EagerContext::~EagerContext() {
   ClearCaches();
   for (auto& entry : registered_functions_) {
@@ -263,14 +306,7 @@
   }
 #endif  // !IS_MOBILE_PLATFORM
 
-  executor_.WaitForAllPendingNodes().IgnoreError();
   rendezvous_->Unref();
-
-  // Release resources ahead of destroying the device manager as the resource
-  // destructors (e.g. ~IteratorResource) assume devices still exist.
-  for (auto device : local_device_mgr()->ListDevices()) {
-    device->ClearResourceMgr();
-  }
 }
 
 bool EagerContext::FindFunctionByName(const string& name) {
@@ -598,6 +634,13 @@
 
   InitDeviceMapAndAsync();
   ClearCaches();
+  default_executor_.ClearError();
+  {
+    tensorflow::mutex_lock l(executor_map_mu_);
+    for (auto& entry : thread_local_executor_) {
+      entry.second->ClearError();
+    }
+  }
 
   pflr_.reset(new ProcessFunctionLibraryRuntime(
       local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
@@ -624,6 +667,11 @@
     DistributedFunctionLibraryRuntime* cluster_flr,
     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
         remote_mgr) {
+  if (context_id == kInvalidContextId) {
+    return errors::InvalidArgument(
+        "Failed to initialize remote for master context due to invalid ",
+        "context id");
+  }
   mutex_lock l(remote_state_mu_);
   is_master_ = true;
 
@@ -666,7 +714,13 @@
   InitDeviceMapAndAsync();
 
   ClearCaches();
-  executor_.ClearError();
+  default_executor_.ClearError();
+  {
+    tensorflow::mutex_lock l(executor_map_mu_);
+    for (auto& entry : thread_local_executor_) {
+      entry.second->ClearError();
+    }
+  }
 
   keep_alive_secs_ = keep_alive_secs;
   sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
@@ -736,6 +790,11 @@
     std::function<Rendezvous*(const int64)> rendezvous_creator,
     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
         remote_mgr) {
+  if (context_id == kInvalidContextId) {
+    return errors::InvalidArgument(
+        "Failed to initialize remote for worker context due to invalid ",
+        "context id");
+  }
   mutex_lock l(remote_state_mu_);
 
   if (remote_device_manager_ != nullptr || server_ != nullptr ||
@@ -757,7 +816,13 @@
   InitDeviceMapAndAsync();
 
   ClearCaches();
-  executor_.ClearError();
+  default_executor_.ClearError();
+  {
+    tensorflow::mutex_lock l(executor_map_mu_);
+    for (auto& entry : thread_local_executor_) {
+      entry.second->ClearError();
+    }
+  }
 
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 5b9a08b..445f77e 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -107,6 +107,16 @@
 
 class EagerContext : public core::RefCounted {
  public:
+  static const uint64 kInvalidContextId = 0;
+
+  static uint64 NewContextId() {
+    uint64 context_id = random::New64();
+    while (context_id == kInvalidContextId) {
+      context_id = random::New64();
+    }
+    return context_id;
+  }
+
   EagerContext(const SessionOptions& opts,
                ContextDevicePlacementPolicy default_device_placement_policy,
                ContextMirroringPolicy default_mirroring_policy, bool async,
@@ -124,15 +134,10 @@
 
   ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
 
-  // True if running in asynchronous mode.
-  bool Async() const;
-
-  EagerExecutor* Executor() { return &executor_; }
-
   std::function<void(std::function<void()>)>* runner() { return &runner_; }
 
-  // Sets whether this thread should run in synchronous or asynchronous mode.
-  Status SetAsyncForThread(bool async);
+  // Specify a executor for this thread.
+  void SetExecutorForThread(EagerExecutor* executor);
 
   // TODO(apassos) make this return a constant reference
   gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
@@ -162,12 +167,6 @@
 
   bool MirrorTensors() const;
 
-  Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }
-
-  Status GetStatus() { return executor_.status(); }
-
-  void ClearAsyncError() { executor_.ClearError(); }
-
   bool FindFunctionByName(const string& name);
 
   Status FindFunctionOpData(const string& name,
@@ -184,9 +183,7 @@
 
   GraphCollector* GetGraphCollector() { return &graph_collector_; }
 
-  Status ExecutorAdd(std::unique_ptr<EagerNode> node) {
-    return executor_.Add(std::move(node));
-  }
+  EagerExecutor* Executor();
 
   Status AddFunctionDef(const FunctionDef& fdef);
 
@@ -320,8 +317,20 @@
   // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
   // instead (which in-turn use WorkerService.RecvTensor RPCs).
   bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
+
 #endif  // IS_MOBILE_PLATFORM
 
+  // Closes remote eager contexts, waits for all RPCs to finish, and
+  // destroys the EagerClientCache. No RPCs can be made through this context
+  // after this method has been called.
+  // This method exists to aid a clean shutdown. It causes all RPCs to finish
+  // and remote TensorHandles to release their references to this context.
+  // To avoid deadlocks, this method must not be called on the thread
+  // processing RPCs because it makes RPCs and waits for their completion.
+  //
+  // On mobile, it just cleans the caches.
+  void WaitForAndCloseRemoteContexts();
+
   bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
 
   tensorflow::Env* TFEnv() const { return env_; }
@@ -399,19 +408,16 @@
   // TODO(fishx): Allow update following two bool after context creation.
   const bool log_device_placement_;
   const bool allow_soft_placement_;
-  // EagerExecutor for async execution.
-  EagerExecutor executor_;
 
   // Information related to step containers.
   std::atomic<int> num_active_steps_;
   std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);
 
-  // True if the default value for execution mode is async. Note that this value
-  // can be overridden per thread based on `thread_local_async` overrides.
-  const bool async_default_;
-  mutable mutex async_map_mu_;
-  std::unordered_map<std::thread::id, bool> thread_local_async_
-      GUARDED_BY(async_map_mu_);
+  EagerExecutor default_executor_;
+  mutable mutex executor_map_mu_;
+  // Not owned.
+  std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
+      GUARDED_BY(executor_map_mu_);
 
   const bool log_memory_;
 
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index ae3369d..7f91251 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -15,20 +15,81 @@
 
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
 namespace tensorflow {
 
+EagerExecutor::EagerExecutor(bool async)
+    : thread_(async ? tensorflow::Env::Default()->StartThread(
+                          tensorflow::ThreadOptions(), "eager_async_executor",
+                          std::bind(&EagerExecutor::Run, this))
+                    : nullptr) {}
+
 EagerExecutor::~EagerExecutor() {
   tensorflow::mutex_lock l(node_queue_mutex_);
-  thread_done_ = true;
+  state_ = ExecutorState::kShutDown;
   nodes_pending_.notify_all();
 }
 
-void EagerExecutor::EnableAsync() {
+Status EagerExecutor::ShutDown() {
+  {
+    tensorflow::mutex_lock l(node_queue_mutex_);
+    if (state_ != ExecutorState::kShutDown) {
+      // if the state is kShutDown, we don't return here because we want to
+      // make sure the executor thread has ended (if there is one).
+      // So, we fall through to
+      // thread_exited_notification_.WaitForNotification() below.
+      state_ = ExecutorState::kShuttingDown;
+    }
+    WaitForOrDestroyAllPendingNodes(&l);
+    state_ = ExecutorState::kShutDown;
+    if (thread_ == nullptr) {
+      return status_;
+    }
+    nodes_pending_.notify_all();
+  }
+
+  thread_exited_notification_.WaitForNotification();
   tensorflow::mutex_lock l(node_queue_mutex_);
+  return status_;
+}
+
+void EagerExecutor::WaitForOrDestroyAllPendingNodes(mutex_lock* lock) {
+  if (state_ == ExecutorState::kShutDown) {
+    return;
+  }
   if (thread_ == nullptr) {
-    thread_.reset(tensorflow::Env::Default()->StartThread(
-        tensorflow::ThreadOptions(), "eager_async_executor",
-        std::bind(&EagerExecutor::Run, this)));
+    Status status = status_;
+    if (status.ok()) {
+      status = errors::FailedPrecondition(
+          "Aborting eager nodes because EagerExecutor is being shut down "
+          "before it got a thread to run the nodes");
+      status_ = status;
+    }
+    while (!node_queue_.empty()) {
+      node_queue_.front()->Abort(status);
+      node_queue_.pop();
+    }
+    return;
+  }
+
+  // It is OK to ignore the returned status here because it will be saved
+  // as the final status_.
+  WaitForAllPendingNodesLocked(lock).IgnoreError();
+}
+
+bool EagerExecutor::Async() const {
+  return thread_ != nullptr;
+}
+
+const char* EagerExecutor::StateStringLocked() {
+  switch (state_) {
+    case ExecutorState::kActive:
+      return "Active";
+    case ExecutorState::kShuttingDown:
+      return "ShuttingDown";
+    case ExecutorState::kShutDown:
+      return "ShutDown";
   }
 }
 
@@ -40,18 +101,25 @@
   // try to call EagerExecutor::Add()
   {
     tensorflow::mutex_lock l(node_queue_mutex_);
-    DCHECK(thread_) << "EnableAsync should have been called before Add";
-    status = status_;
-    if (status.ok()) {
-      node_queue_.push(std::move(node));
+    if (state_ != ExecutorState::kActive) {
+      status = errors::FailedPrecondition(
+          "EagerExecutor accepts new EagerNodes to run only in Active state. "
+          "Current state is '",
+          StateStringLocked(), "'");
+    } else {
+      DCHECK(thread_) << "EnableAsync should have been called before Add";
+      status = status_;
+      if (status.ok()) {
+        node_queue_.push(std::move(node));
 
-      // If there were no previous nodes pending, wake the run thread to start
-      // processing requests again.
-      if (node_queue_.size() == 1) {
-        nodes_pending_.notify_all();
+        // If there were no previous nodes pending, wake the run thread to start
+        // processing requests again.
+        if (node_queue_.size() == 1) {
+          nodes_pending_.notify_all();
+        }
+
+        return Status::OK();
       }
-
-      return Status::OK();
     }
   }
 
@@ -61,14 +129,19 @@
 }
 
 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
-  tensorflow::condition_variable cond;
   tensorflow::mutex_lock l(node_queue_mutex_);
+  return WaitForAllPendingNodesLocked(&l);
+}
+
+tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
+    mutex_lock* lock) {
+  tensorflow::condition_variable cond;
   // Don't wait if an error is already set.
   if (!status_.ok()) return status_;
   if (node_queue_.empty()) return tensorflow::Status::OK();
   EagerNode* last_node = node_queue_.back().get();
   node_done_notifications_.insert(std::make_pair(last_node, &cond));
-  cond.wait(l);
+  cond.wait(*lock);
   // Note that we could be woken up if an error occurs, even though the node has
   // not actually executed.
   return status_;
@@ -76,6 +149,7 @@
 
 void EagerExecutor::ClearError() {
   tensorflow::mutex_lock l(node_queue_mutex_);
+  // TODO(iga): Check state_ and return an error if it is not kActive.
   if (status_.ok()) return;
   // If an error was set, node_done_notifications_ and node_queue_ should have
   // been cleared, and no new entries should have been added since.
@@ -91,48 +165,67 @@
 }
 
 void EagerExecutor::Run() {
+  auto thread_exited_notifier =
+      gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
   while (true) {
-    EagerNode* curr_node;
+    EagerNode* curr_node_raw;
     {
       tensorflow::mutex_lock l(node_queue_mutex_);
       while (node_queue_.empty() || !status_.ok()) {
-        if (thread_done_) return;
+        if (state_ == ExecutorState::kShutDown) return;
         nodes_pending_.wait(l);
       }
       // Obtain raw pointer since we don't want to remove from the queue until
-      // the node has been run.
-      curr_node = node_queue_.front().get();
+      // the node has been run. Otherwise, WaitForAllPendingNodes can return
+      // too early.
+      // Note, we don't std::move from the here because the front of the queue
+      // will then contain a nullptr. This can be a problem in
+      // WaitForAllPendingNodes where we get the top EagerNode pointer
+      // and register a notification for its completion.
+      curr_node_raw = node_queue_.front().get();
     }
-    tensorflow::Status status = curr_node->Run();
+    tensorflow::Status status = curr_node_raw->Run();
     const bool ok = status.ok();
-    tensorflow::mutex_lock l(node_queue_mutex_);
-    node_queue_.pop();
-    if (!ok) {
-      status_ = status;
-      // We remove any pending ops so that we don't try to execute them if
-      // ClearError is called.
-      errors::AppendToMessage(&status,
-                              ". Encountered when executing an operation using "
-                              "EagerExecutor. This error cancels all future "
-                              "operations and poisons their output tensors.");
-      for (int i = 0; i < node_queue_.size(); ++i) {
-        node_queue_.front()->Abort(status);
-        // Dequeue and delete nodes
-        node_queue_.pop();
+
+    std::unique_ptr<EagerNode> curr_node;
+    std::vector<std::unique_ptr<EagerNode>> nodes_to_destroy;
+    {
+      tensorflow::mutex_lock l(node_queue_mutex_);
+      curr_node = std::move(node_queue_.front());
+      node_queue_.pop();
+      if (!ok) {
+        status_ = status;
+        // We remove any pending ops so that we don't try to execute them if
+        // ClearError is called.
+        errors::AppendToMessage(
+            &status,
+            ". Encountered when executing an operation using "
+            "EagerExecutor. This error cancels all future "
+            "operations and poisons their output tensors.");
+        while (!node_queue_.empty()) {
+          node_queue_.front()->Abort(status);
+          nodes_to_destroy.push_back(std::move(node_queue_.front()));
+          node_queue_.pop();
+        }
+      }
+      if (!node_done_notifications_.empty()) {
+        // Note that we notify all waiting threads in case an error has
+        // occurred. These calling threads are responsible for checking status_
+        // before proceeding.
+        const auto range =
+            ok ? node_done_notifications_.equal_range(curr_node_raw)
+               : make_pair(node_done_notifications_.begin(),
+                           node_done_notifications_.end());
+        for (auto it = range.first; it != range.second; ++it) {
+          it->second->notify_all();
+        }
+        node_done_notifications_.erase(range.first, range.second);
       }
     }
-    if (!node_done_notifications_.empty()) {
-      // Note that we notify all waiting threads in case an error has occurred.
-      // These calling threads are responsible for checking status_ before
-      // proceeding.
-      const auto range = ok ? node_done_notifications_.equal_range(curr_node)
-                            : make_pair(node_done_notifications_.begin(),
-                                        node_done_notifications_.end());
-      for (auto it = range.first; it != range.second; ++it) {
-        it->second->notify_all();
-      }
-      node_done_notifications_.erase(range.first, range.second);
-    }
+    // curr_node and nodes_to_destroy will be destructed here, while not holding
+    // node_queue_mutex_. This is important because, unfortunately, some nodes'
+    // destructors can enqueue more operations onto this executor and cause
+    // a deadlock.
   }
 }
 
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index 9a5aee3..9e3d092 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -70,17 +70,29 @@
 // TODO(agarwal): Implement optimizations over EagerNode traces.
 class EagerExecutor {
  public:
+  explicit EagerExecutor(bool async);
+
   ~EagerExecutor();
 
-  // This is called whenever async mode is enabled. Note that it may be called
-  // multiple times as different calling threads may switch async mode on or off
-  // independently.
-  void EnableAsync();
+  // Puts this in a shutdown state. In this state, Add() will return an error
+  // and not add new EagerNodes. After putting this in the shutdown state,
+  // blocks until all pendings nodes have finished running.
+  // Returns the status of executing pending nodes.
+  // If async was not enabled, aborts and destroys all pending nodes.
+  Status ShutDown();
 
-  // Schedules `node` for execution.
+  bool Async() const;
+
+  // Schedules `node` for execution. If an error occurs (e.g. EagerExecutor
+  // has already been shut down), the `node` is not added to this executor
+  // and its Abort() method is called.
   Status Add(std::unique_ptr<EagerNode> node);
 
   // Blocks till all currently pending ops are done.
+  // In particular, if EnableAsync() has not beed called, it will not return
+  // until that happens (and pendings, at the time of call, nodes finish
+  // running). If this executor has already been shut down, its final status is
+  // returned.
   Status WaitForAllPendingNodes();
 
   // Clears all currently set errors which re-enables async execution.
@@ -90,12 +102,43 @@
   Status status() const;
 
  private:
+  // Possible states for this executor.
+  // Executor starts in kActive state. When Shutdown() is called, Executor
+  // is put in the kShuttingDown state. In this state, the executor thread
+  // continues to run, but no new nodes are accepted. Finally, when all nodes
+  // are drained, the executor is put in the kShutDown state, which causes the
+  // thread to exit.
+  // If this executor is destroyed without calling shutdown first, it
+  // transitions to kShutDown state immediately which causes the thread to exit
+  // without running pending nodes.
+  enum class ExecutorState {
+    kActive,
+    kShuttingDown,
+    kShutDown,
+  };
+
+  const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
+
   // Starts execution of pending EagerNodes. This function loops till
   // thread_done_ is set to true. If any errors are encontered, these are set
   // inside `status_`. The loop blocks anytime there are no pending nodes, or if
   // `status_` is not ok.
   void Run();
 
+  // The impl of WaitForAllPendingNodes
+  // `lock` is the lock that holds node_queue_mutex_.
+  Status WaitForAllPendingNodesLocked(mutex_lock* lock)
+      EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
+
+  // If async has been enabled on this executor, just calls
+  // WaitForAllPendingNodes. Else:
+  //  - Aborts and destroys all pending nodes
+  //  - sets the status_ to an error if it does not already contain one
+  // `lock` is the lock that holds node_queue_mutex_.
+  // Precondition: state_ != kActive.
+  void WaitForOrDestroyAllPendingNodes(mutex_lock* lock)
+      EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
+
   Status WaitImpl(bool wait_all, uint64 node_id);
 
   mutable mutex node_queue_mutex_;
@@ -117,13 +160,17 @@
   std::multimap<EagerNode*, condition_variable*> node_done_notifications_
       GUARDED_BY(node_queue_mutex_);
 
-  // Thread object that calls the `Run` method. Currently we use only one thread
-  // for executing the EagerNodes one-by-one.
-  std::unique_ptr<Thread> thread_ GUARDED_BY(node_queue_mutex_);
+  // Thread object that calls the `Run` method in async mode.This thread runs
+  // till thread_done_ is set to true. It is `nullptr` in sync mode.
+  const std::unique_ptr<Thread> thread_;
+
+  // thread_exited_notification_ is notified by the `thread_` right before it
+  // exits.
+  Notification thread_exited_notification_;
 
   // Indicates that `thread_` should stop as soon as it is done executing the
   // current EagerNode.
-  bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
+  ExecutorState state_ GUARDED_BY(node_queue_mutex_) = ExecutorState::kActive;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 7191cbb..fb7eb41 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -31,7 +31,8 @@
         attrs_(op),
         attr_types_(t),
         device_(nullptr),
-        is_function_(is_function) {}
+        is_function_(is_function),
+        executor_(ctx ? ctx->Executor() : nullptr) {}
 
   ~EagerOperation() {
     for (tensorflow::TensorHandle* h : inputs_) {
@@ -81,6 +82,8 @@
     cancellation_manager_ = cancellation_manager;
   }
 
+  EagerExecutor* Executor() { return executor_; }
+
   string DebugString() const;
 
  private:
@@ -94,6 +97,7 @@
   bool use_xla_ = false;
   const bool is_function_;
   CancellationManager* cancellation_manager_ = nullptr;  // Not owned.
+  EagerExecutor* const executor_;                        // Not owned.
 };
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 19e79a9..3dae534 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -44,8 +44,9 @@
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #if !defined(IS_MOBILE_PLATFORM)
-#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
 #endif  // IS_MOBILE_PLATFORM
 #include "tensorflow/core/framework/step_stats.pb.h"
@@ -176,8 +177,9 @@
   // trigger a copy.
   auto pre_time_nanos = Env::Default()->NowNanos();
   TensorHandle* result_handle = nullptr;
-  Status status = EagerCopyToDevice(handle, ctx, expected_input_device,
-                                    ctx->MirrorTensors(), &result_handle);
+  Status status =
+      EagerCopyToDevice(handle, ctx, op->Executor(), expected_input_device,
+                        ctx->MirrorTensors(), &result_handle);
   if (run_metadata != nullptr) {
     auto* step_stats = run_metadata->mutable_step_stats();
     MaybeInitializeStepStats(step_stats, ctx);
@@ -408,8 +410,7 @@
 
 Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
                             bool* compile_with_xla) {
-  if (!op->is_function() ||
-      !DeviceNameUtils::HasSomeDetails(op->GetDeviceName())) {
+  if (!op->is_function()) {
     *compile_with_xla = false;
     return Status::OK();
   }
@@ -473,7 +474,8 @@
       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
   EagerContext* ctx = op->EagerContext();
-  TF_RETURN_IF_ERROR(ctx->GetStatus());
+  auto* executor = op->Executor();
+  TF_RETURN_IF_ERROR(executor->status());
   Device* device = op->Device();
 
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -506,7 +508,7 @@
       if (input->IsRemote()) {
         TensorHandle* handle = nullptr;
         TF_RETURN_IF_ERROR(EagerCopyToDevice(
-            input, ctx, device == nullptr ? ctx->HostCPU() : device,
+            input, ctx, executor, device == nullptr ? ctx->HostCPU() : device,
             ctx->MirrorTensors(), &handle));
         op->UpdateInput(i, handle);
         // Unref handle since it has a ref as an input now
@@ -680,7 +682,7 @@
   // input handles are ready before executing them.
   // TODO(b/137118203): Consider executing "cheap" kernels inline for
   // performance.
-  Status s = ctx->Async() ? ctx->ExecutorAdd(std::move(node)) : node->Run();
+  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -789,6 +791,7 @@
     for (int i = 0; i < op->Inputs().size(); i++) {
       tensorflow::TensorHandle* input = op->Inputs()[i];
       tensorflow::Device* input_device = input->device();
+      const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
       if (op->Device() != input_device &&
           // If the expected and actual devices are on the same task, don't
           // explicitly copy, and instead depend on the copy to happen locally
@@ -809,12 +812,13 @@
         op->UpdateInput(i, handle);
         input = handle;
         input_device = remote_cpu_device;
+        input_device_name = &remote_cpu_device->name();
         // Unref handle since it has a ref as an input now
         handle->Unref();
       }
 
       TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-          input, remote_op->add_inputs(), input_device));
+          input, remote_op->add_inputs(), input_device, *input_device_name));
     }
   }
 
@@ -831,11 +835,6 @@
   *num_retvals = num_outputs;
 
   tensorflow::Device* op_device = op->Device();
-
-  bool is_async = ctx->Async();
-  VLOG(4) << "Execute remote eager op: " << op->Name()
-          << " (is async?: " << is_async << ").";
-
   const tensorflow::uint64 id = remote_op->id();
   for (int i = 0; i < num_outputs; ++i) {
     // TODO(nareshmodi): Change the callback to instead add the decref to a
@@ -849,15 +848,28 @@
     // remote device here. We just need to know that it is remote. If we need
     // to copy this tensor to this process, the remote end will know the
     // correct device of this handle.
-    TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
+    Status status = TensorHandle::CreateUnshapedRemoteHandle(
         id, i, eager_client, context_id, output_dtypes[i], op_device, ctx,
-        &retvals[i]));
+        &retvals[i]);
+    if (!status.ok()) {
+      for (int j = 0; j < i; ++j) {
+        retvals[j]->Poison(errors::Internal(
+            "Failed to construct unshaped remote tensor handle at index ", i,
+            " for op ", op->Name()));
+      }
+      return status;
+    }
   }
 
+  auto* executor = op->Executor();
+  bool is_async = executor->Async();
+  VLOG(4) << "Execute remote eager op: " << op->Name()
+          << " (is async?: " << is_async << ").";
+
   std::unique_ptr<EagerNode> node(
       new eager::RemoteExecuteNode(std::move(request), op_device, eager_client,
                                    op->Inputs(), {retvals, num_outputs}));
-  Status s = is_async ? ctx->ExecutorAdd(std::move(node)) : node->Run();
+  Status s = is_async ? executor->Add(std::move(node)) : node->Run();
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -898,7 +910,7 @@
 // (int32/int64). This can be disabled by setting the environment variable
 // "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
 Status MaybeUpdateOpDevice(EagerOperation* op) {
-  auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
+  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
   if (op->is_function() || exempt_ops.find(op->Name()) != exempt_ops.end()) {
     // Don't update the device of direct function calls.
     // Particularly, if the user did not explicitly request any device for this
@@ -1146,9 +1158,10 @@
 
 namespace {
 
-Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
+Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
+                              EagerExecutor* executor, Device* dstd,
                               TensorHandle** result) {
-  TF_RETURN_IF_ERROR(ctx->GetStatus());
+  TF_RETURN_IF_ERROR(executor->status());
   Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr;
   TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
       ctx->CanonicalDevice(dstd), dstd, resource_device, h->dtype, ctx,
@@ -1157,7 +1170,7 @@
   // Note that `h` may not be currently ready. However execution order will
   // make sure that `h` is ready before the copy is actually done.
   std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, ctx));
-  Status s = ctx->Async() ? ctx->ExecutorAdd(std::move(node)) : node->Run();
+  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -1167,185 +1180,11 @@
   return s;
 }
 
-#if !defined(IS_MOBILE_PLATFORM)
-Status CreateUncachedKernelAndDeviceOp(
-    EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
-  EagerContext* ctx = op->EagerContext();
-  Device* device = op->Device();
-
-  FunctionLibraryRuntime* flr = ctx->func_lib(device);
-  if (flr == nullptr) {
-    return errors::Unavailable(
-        "Unable to find a FunctionLibraryRuntime corresponding to device ",
-        device->name());
-  }
-
-  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx->runner();
-  kernel->reset(new KernelAndDeviceOp(
-      ctx->GetRendezvous(), ctx->LogMemory(), flr, runner,
-      ctx->GetCollectiveExecutorHandle(), ctx->HostCPU()));
-
-  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
-  return kernel->get()->Init(ndef, nullptr);
-}
-
-Status ExecuteSend(EagerContext* ctx, Device* device, TensorHandle* h,
-                   StringPiece wire_id, Device* recv_device) {
-  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
-  // functionality here instead of constructing an Op.
-  const AttrTypeMap* types;
-  bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Send", &types, &is_function));
-  DCHECK(!is_function);
-  EagerOperation op(ctx, "_Send", /*is_function=*/false, types);
-
-  op.SetDevice(device);
-
-  op.MutableAttrs()->Set("tensor_name", wire_id);
-  op.MutableAttrs()->Set("send_device", device->name());
-  op.MutableAttrs()->Set(
-      "send_device_incarnation",
-      static_cast<int64>(device->attributes().incarnation()));
-  op.MutableAttrs()->Set("recv_device", recv_device->name());
-  op.MutableAttrs()->Set("client_terminated", false);
-
-  op.MutableAttrs()->Set("T", h->dtype);
-
-  DCHECK(device != nullptr);
-
-  if (device->IsLocal()) {
-    TF_RETURN_IF_ERROR(ctx->GetStatus());
-
-    op.AddInput(h);
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    gtl::InlinedVector<TensorValue, 4> input_vector(1);
-    TF_RETURN_IF_ERROR(h->TensorValue(&input_vector[0]));
-
-    TF_RETURN_IF_ERROR(
-        kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr));
-  } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx->GetContextId();
-    TF_RETURN_IF_ERROR(ctx->GetClient(device, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-        h, remote_op->add_inputs(), h->device()));
-
-    PrepareRemoteOp(remote_op, &op);
-
-    std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
-        std::move(request), nullptr, eager_client, op.Inputs(), {nullptr, 0}));
-    if (ctx->Async()) {
-      TF_RETURN_IF_ERROR(ctx->ExecutorAdd(std::move(node)));
-    } else {
-      TF_RETURN_IF_ERROR(node->Run());
-    }
-  }
-
-  return Status::OK();
-}
-
-// Execute a Recv to transfer a tensor handle to a specific device. The received
-// tensor handle will be returned in result. If mirror_dst is provided, the
-// tensor handle will be added as a mirror.
-Status ExecuteRecv(EagerContext* ctx, Device* device, DataType dtype,
-                   StringPiece wire_id, Device* send_device,
-                   TensorHandle* mirror_dst, TensorHandle** result) {
-  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
-  // functionality here instead of constructing an Op.
-  const AttrTypeMap* types;
-  bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Recv", &types, &is_function));
-  DCHECK(!is_function);
-  EagerOperation op(ctx, "_Recv", /*is_function=*/false, types);
-
-  op.SetDevice(device);
-
-  op.MutableAttrs()->Set("tensor_name", wire_id);
-  op.MutableAttrs()->Set("send_device", send_device->name());
-  op.MutableAttrs()->Set(
-      "send_device_incarnation",
-      static_cast<int64>(send_device->attributes().incarnation()));
-  op.MutableAttrs()->Set("recv_device", device->name());
-  op.MutableAttrs()->Set("client_terminated", false);
-
-  op.MutableAttrs()->Set("tensor_type", dtype);
-
-  if (device->IsLocal()) {
-    TF_RETURN_IF_ERROR(ctx->GetStatus());
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    std::vector<Tensor> outputs;
-    gtl::InlinedVector<TensorValue, 4> input_vector;
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, nullptr, nullptr,
-                                   nullptr, nullptr));
-
-    // TODO(gjn): Add support for async mode
-    TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-        outputs[0], /* d= */ kernel->OutputDevice(0),
-        /* op_device= */ kernel->device(), ctx, result));
-  } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx->GetContextId();
-    TF_RETURN_IF_ERROR(ctx->GetClient(device, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-    eager::EnqueueResponse response;
-
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    PrepareRemoteOp(remote_op, &op);
-
-    const uint64 id = remote_op->id();
-    auto tensor_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
-        id, 0, eager_client, context_id, ctx);
-    if (mirror_dst != nullptr) {
-      TF_RETURN_IF_ERROR(mirror_dst->AddUnshapedRemoteMirror(
-          std::move(tensor_handle_data), device));
-      mirror_dst->Ref();
-      *result = mirror_dst;
-    } else {
-      TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-          std::move(tensor_handle_data), dtype, device, ctx, result));
-    }
-
-    std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
-        std::move(request), device, eager_client, op.Inputs(), {result, 1}));
-    if (ctx->Async()) {
-      TF_RETURN_IF_ERROR(ctx->ExecutorAdd(std::move(node)));
-    } else {
-      TF_RETURN_IF_ERROR(node->Run());
-    }
-  }
-
-  return Status::OK();
-}
-
-// This gets a unique wire ID. We add a random identifier so that if the
-// worker has other clients that it is servicing, we don't have any collision.
-string GetUniqueWireID() {
-  static tensorflow::uint64 random_seed = random::New64();
-  static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
-  static tensorflow::int64 wireid GUARDED_BY(wireid_mutex) = 0;
-  tensorflow::mutex_lock l(wireid_mutex);
-  return strings::StrCat(random_seed, "_", wireid++);
-}
-#endif  // !IS_MOBILE_PLATFORM
-
 }  // namespace
 
-Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
-                         bool mirror, TensorHandle** result) {
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
+                         EagerExecutor* executor, Device* device, bool mirror,
+                         TensorHandle** result) {
   Device* send_device = h->DeviceOrHostCPU(ctx);
 
   bool sender_is_local = send_device->IsLocal();
@@ -1353,7 +1192,7 @@
   bool recver_is_local = device->IsLocal();
 
   if (sender_is_local && recver_is_local) {
-    return LocalEagerCopyToDevice(h, ctx, device, result);
+    return LocalEagerCopyToDevice(h, ctx, executor, device, result);
   } else {
 #if defined(IS_MOBILE_PLATFORM)
     return errors::Unimplemented(
@@ -1370,11 +1209,38 @@
     if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
       return EagerRemoteSendTensor(ctx, h, device, mirror, result);
     } else {
-      string wire_id = GetUniqueWireID();
-      TF_RETURN_IF_ERROR(ExecuteSend(ctx, send_device, h, wire_id, device));
-
-      return ExecuteRecv(ctx, device, h->dtype, wire_id, send_device,
-                         mirror ? h : nullptr, result);
+      uint64 recv_op_id = 0;
+      if (recver_is_local) {
+        TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
+            /* d= */ device,
+            /* op_device= */ device, /*resource_device=*/nullptr, h->dtype, ctx,
+            result));
+      } else {
+        eager::EagerClient* eager_client;
+        uint64 context_id = ctx->GetContextId();
+        TF_RETURN_IF_ERROR(ctx->GetClient(device, &eager_client));
+        recv_op_id = ctx->RemoteMgr()->NextOpId();
+        auto tensor_handle_data =
+            absl::make_unique<UnshapedRemoteTensorHandleData>(
+                recv_op_id, 0, eager_client, context_id, ctx);
+        if (mirror) {
+          TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(
+              std::move(tensor_handle_data), device));
+          h->Ref();
+          *result = h;
+        } else {
+          TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
+              std::move(tensor_handle_data), h->dtype, device, ctx, result));
+        }
+      }
+      auto node = absl::make_unique<eager::RemoteCopyNode>(
+          ctx, executor, h, result[0], device, recv_op_id);
+      Status s =
+          executor->Async() ? executor->Add(std::move(node)) : node->Run();
+      if (!s.ok()) {
+        result[0]->Unref();
+      }
+      return s;
     }
 #endif  // !IS_MOBILE_PLATFORM
   }
diff --git a/tensorflow/core/common_runtime/eager/execute.h b/tensorflow/core/common_runtime/eager/execute.h
index a7b0a2c..f2dc579 100644
--- a/tensorflow/core/common_runtime/eager/execute.h
+++ b/tensorflow/core/common_runtime/eager/execute.h
@@ -58,8 +58,9 @@
 // the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
 // original handle and update *result to point to h. Since this is not
 // guaranteed, callers should always use the value in *result.
-Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
-                         bool mirror, TensorHandle** result);
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
+                         EagerExecutor* executor, Device* device, bool mirror,
+                         TensorHandle** result);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 07c7ef2..59c5875 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -207,7 +207,7 @@
 
     absl::optional<AllocatorStats> allocator_stats =
         allocator_pair.first->GetStats();
-    if (stats) {
+    if (allocator_stats) {
       memory->set_allocator_bytes_in_use(allocator_stats->bytes_in_use);
     }
     allocator_pair.second->GetRecordsAndUnRef();
@@ -259,6 +259,7 @@
   }
 
   OpKernelContext::Params params;
+  params.is_eager = true;
   params.device = device_;
   params.frame_iter = FrameAndIter(0, 0);
   params.inputs = &inputs;
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
new file mode 100644
index 0000000..c487aa9
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -0,0 +1,196 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifdef INTEL_MKL
+#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_layout_pass.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+class MklEagerOpRewrite : public EagerOpRewrite {
+ public:
+  MklEagerOpRewrite(string name, string file, string line);
+  typedef struct {
+    string op_name;
+    std::function<bool(EagerOperation*)> RewriteRule;
+    std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
+        CreateMklOp;
+  } MklEagerOp;
+
+ private:
+  // TODO(intel-tf): refactor with unordered_map;
+  // especially when adding more ops/rewrite rules in future.
+  std::vector<MklEagerOp> mkl_eager_ops_;
+
+  // The entry point to execute the op rewrite.
+  Status Run(EagerOperation* orig_op,
+             std::unique_ptr<tensorflow::EagerOperation>* out_op);
+
+  // Initializes the new op and sets up its inputs and attributes
+  static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
+                           std::unique_ptr<EagerOperation>* new_mkl_op);
+
+  // Creates new MKL op for MatMul
+  static Status CreateMklMatMul(EagerOperation* orig_op,
+                                std::unique_ptr<EagerOperation>* mkl_matmul_op);
+
+  // Creates new MKL op for Conv2D, Conv2DBackpropInput and
+  // Conv2DBackpropFilter.
+  static Status CreateMklConv2DOp(
+      EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op);
+
+  // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
+  static bool RewriteConv2D(EagerOperation* op);
+
+  // Calls op-specific rewrite function to create new MKL op.
+  Status RewriteToMklOp(EagerOperation* orig_op,
+                        std::unique_ptr<EagerOperation>* mkl_op,
+                        const int op_idx);
+
+  // Checks whether we can rewrite the op to MKL one or not.
+  bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
+
+  // Default rewrite rule to be used when rewrite should happen without any
+  // restriction.
+  static bool AlwaysRewrite(EagerOperation* op) { return true; }
+};
+
+REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
+
+// Constructor
+MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
+    : EagerOpRewrite(name, file, line) {
+  mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
+  mkl_eager_ops_.push_back(
+      {"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
+  mkl_eager_ops_.push_back(
+      {"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
+  mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
+}
+
+Status MklEagerOpRewrite::Run(
+    EagerOperation* orig_op,
+    std::unique_ptr<tensorflow::EagerOperation>* out_op) {
+  int found_op_idx = -1;
+  if (ShouldRewriteOp(orig_op, &found_op_idx)) {
+    TF_CHECK_OK(RewriteToMklOp(orig_op, out_op, found_op_idx));
+  }
+  return Status::OK();
+}
+
+Status MklEagerOpRewrite::SetupNewOp(
+    EagerOperation* orig_op, const string mkl_op_name,
+    std::unique_ptr<EagerOperation>* new_mkl_op) {
+  const tensorflow::AttrTypeMap* types;
+  bool is_function = false;
+  TF_RETURN_IF_ERROR(
+      tensorflow::AttrTypeMapForOp(mkl_op_name.c_str(), &types, &is_function));
+  EagerContext* ctx = orig_op->EagerContext();
+  new_mkl_op->reset(new tensorflow::EagerOperation(ctx, mkl_op_name.c_str(),
+                                                   is_function, types));
+
+  int num_inputs = orig_op->Inputs().size();
+  // Add all inputs to the new op.
+  for (int i = 0; i < num_inputs; ++i) {
+    (*new_mkl_op)->AddInput(orig_op->Inputs()[i]);
+  }
+
+  // Copy all attributes to the new op.
+  string name;
+  const NodeDef& orig_ndef = orig_op->MutableAttrs()->BuildNodeDef();
+
+  AttrSlice attr_list(orig_ndef);
+  for (const auto& attr : attr_list) {
+    (*new_mkl_op)->MutableAttrs()->Set(attr.first, attr.second);
+  }
+
+  (*new_mkl_op)
+      ->MutableAttrs()
+      ->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
+
+  if (orig_op->Device() != nullptr) {
+    (*new_mkl_op)->SetDevice(orig_op->Device());
+  } else {
+    string device_name =
+        DeviceNameUtils::ParsedNameToString(orig_op->GetDeviceName());
+    (*new_mkl_op)->SetDeviceName(device_name.c_str());
+  }
+  return Status::OK();
+}
+
+Status MklEagerOpRewrite::CreateMklMatMul(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
+  const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
+  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
+  return Status::OK();
+}
+
+Status MklEagerOpRewrite::CreateMklConv2DOp(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
+  const string mkl_op_name =
+      mkl_op_registry::GetMklEagerOpName(orig_op->Name());
+  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_conv2d_op));
+  return Status::OK();
+}
+
+bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
+  // Don't rewrite the op if MKL use is disabled at runtime.
+  if (DisableMKL()) {
+    return false;
+  }
+  DataType data_type;
+  if (op->Attrs().Get("T", &data_type) != Status::OK()) {
+    return false;
+  }
+  // Check if we have registered MKL kernel for this op.
+  if (!mkl_op_registry::IsMklNameChangeOp(
+          mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) &&
+      !mkl_op_registry::IsMklNameChangeOp(
+          mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
+    return false;
+  }
+
+  *op_idx = -1;
+  // Find and call the op's rewrite rule that determines whether we need to
+  // rewrite this op or not.
+  for (auto it = mkl_eager_ops_.begin(); it != mkl_eager_ops_.end(); ++it) {
+    if (it->op_name.compare(op->Name()) == 0 && it->RewriteRule(op)) {
+      *op_idx = it - mkl_eager_ops_.begin();
+      return true;
+    }
+  }
+  return false;
+}
+
+Status MklEagerOpRewrite::RewriteToMklOp(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
+    const int op_idx) {
+  mkl_eager_ops_[op_idx].CreateMklOp(orig_op, mkl_op);
+  return Status::OK();
+}
+
+bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
+  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
+  string padding;
+  TF_CHECK_OK(GetNodeAttr(ndef, "padding", &padding));
+  // Right now MKL Conv2D does not support explicit padding.
+  return (padding != "EXPLICIT");
+}
+
+}  // namespace tensorflow
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 8f68ee4..f451cbb 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -53,6 +53,13 @@
 
 namespace tensorflow {
 
+namespace {
+#if !defined(IS_MOBILE_PLATFORM)
+const int64 kInvalidOpId = -1;
+const int32 kInvalidOutputNum = -1;
+#endif
+}  // namespace
+
 Status TensorHandle::GetResourceHandleDtypesAndShapes(
     std::vector<DtypeAndPartialTensorShape>* result) {
   if (IsRemote()) {
@@ -109,8 +116,8 @@
       op_device_(op_device),
       resource_device_(nullptr),
 #if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(-1),
-      remote_output_num_(-1),
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
 #endif
       ctx_(ctx),
       is_remote_(false),
@@ -128,8 +135,8 @@
       op_device_(op_device),
       resource_device_(GetResourceDevice(resource_handle, ctx)),
 #if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(-1),
-      remote_output_num_(-1),
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
 #endif
       ctx_(ctx),
       is_remote_(false),
@@ -159,8 +166,8 @@
       op_device_(op_device),
       resource_device_(resource_device),
 #if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(-1),
-      remote_output_num_(-1),
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
 #endif
       ctx_(ctx),
       is_remote_(false),
@@ -250,8 +257,8 @@
       op_device_(nullptr),
       resource_device_(nullptr),
 #if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(-1),
-      remote_output_num_(-1),
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
 #endif
       ctx_(nullptr),
       is_remote_(false),
@@ -326,11 +333,24 @@
         "Could not find remote mirror for specified device");
   }
 
+  if (remote_op_id_ == kInvalidOpId ||
+      remote_output_num_ == kInvalidOutputNum) {
+    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
+                                   ", output_num:", remote_output_num_,
+                                   ") is not set.");
+  }
   *op_id = remote_op_id_;
   *output_num = remote_output_num_;
   return Status::OK();
 }
 
+void TensorHandle::SetRemoteOpIdAndOutputNumToLocalTensorHandle(
+    const int64 op_id, const int32 output_num) {
+  DCHECK(!is_remote_);
+  remote_op_id_ = op_id;
+  remote_output_num_ = output_num;
+}
+
 bool TensorHandle::HasRemoteMirror(Device* d) {
   tf_shared_lock l(remote_mirrors_mutex_);
   auto mirror = remote_mirrors_.find(d);
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1ecf5bf..95003a9 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -149,6 +149,11 @@
   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
 
+  // Set remote_op_id_ and remote_output_num_ if the handle refers to a local
+  // tensor that needs to be copied to remote workers.
+  void SetRemoteOpIdAndOutputNumToLocalTensorHandle(const int64 op_id,
+                                                    const int32 output_num);
+
   // Called on an async remote tensor once it's shape has been determined. This
   // transitions the tensor handle from a non-ready to a ready state by
   // replacing the backing data abstraction to allow for the shape to be
@@ -238,8 +243,8 @@
       remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_);
 
   // IDs required when this class is representing a remote tensor handle.
-  const int64 remote_op_id_;
-  const int32 remote_output_num_;
+  int64 remote_op_id_;
+  int32 remote_output_num_;
   eager::EagerClient* remote_eager_client_;
   uint64 remote_context_id_;
 #endif
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 0be4394..da2954a 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2548,9 +2548,9 @@
                                            const Node* node,
                                            FrameState** child) {
   // Get the child frame name.
-  string enter_name;
-  Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
-  DCHECK(s.ok()) << s;
+  const string& enter_name = GetNodeAttrString(node->attrs(), "frame_name");
+  DCHECK(!enter_name.empty())
+      << "Could not find \"frame_name\" attr in node " << node->name();
   const string child_name = MakeFrameName(frame, iter, enter_name);
 
   {
@@ -2567,8 +2567,10 @@
   if (vlog_) VLOG(2) << "Create frame: " << child_name;
 
   int parallel_iters;
-  s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
-  DCHECK(s.ok()) << s;
+  bool found_parallel_iters =
+      TryGetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
+  DCHECK(found_parallel_iters)
+      << "Could not find \"parallel_iterations\" attr in node " << node->name();
   FrameState* temp = new FrameState(impl_, parallel_iters);
   temp->frame_name = child_name;
   temp->frame_id = Hash64(child_name);
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 9ca758b..c60ec68 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -844,9 +844,9 @@
   // TODO(zhifengc): Change Graph to record #nodes.
   VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
           << g->num_edges();
-  if (VLOG_IS_ON(4)) {
+  if (VLOG_IS_ON(5)) {
     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
-      VLOG(4) << "|| " << line;
+      VLOG(5) << "|| " << line;
     }
   }
 }
@@ -1246,7 +1246,7 @@
       env_, graph_def_version_, optimizer_.options(), custom_kernel_creator_,
       out_lib_def, out_pflr, skip_flib_def));
   *out_flr = (*out_pflr)->GetFLR(device_->name());
-  if (out_flr != nullptr) {
+  if (*out_flr != nullptr) {
     return Status::OK();
   } else {
     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
@@ -1495,13 +1495,28 @@
 
 std::vector<string> InputDevices(const Node& caller) {
   std::vector<string> input_devices(caller.in_edges().size());
+  std::vector<string> input_tensors(caller.in_edges().size());
+
   for (const Edge* edge : caller.in_edges()) {
     if (edge->IsControlEdge()) continue;
     const string& input_device = edge->src()->has_assigned_device_name()
                                      ? edge->src()->assigned_device_name()
                                      : edge->src()->requested_device();
     input_devices[edge->dst_input()] = input_device;
+    input_tensors[edge->dst_input()] =
+        absl::StrCat(edge->src()->name(), ":", edge->src_output());
   }
+
+  if (VLOG_IS_ON(4)) {
+    VLOG(4) << "Function instantiation input devices:";
+    for (int i = 0; i < input_devices.size(); ++i) {
+      if (input_tensors[i].empty()) continue;  // skip control edges
+      VLOG(4) << "    [index " << i << "]"
+              << " device: " << input_devices[i]
+              << " (input: " << input_tensors[i] << ")";
+    }
+  }
+
   return input_devices;
 }
 
@@ -1616,24 +1631,21 @@
 std::unique_ptr<InlinedFunctionBodyPlacer>
 InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                          const Node& caller) {
-  VLOG(3) << "Create default placer for inlined function body: "
-          << SummarizeNode(caller);
+  VLOG(3) << "Create default placer for inlined function body.";
   return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
 }
 
 std::unique_ptr<InlinedFunctionBodyPlacer>
 InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                               const Node& caller) {
-  VLOG(3) << "Create single device placer for inlined function body: "
-          << SummarizeNode(caller);
+  VLOG(3) << "Create single device placer for inlined function body.";
   return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
 }
 
 std::unique_ptr<InlinedFunctionBodyPlacer>
 InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                              const Node& caller) {
-  VLOG(3) << "Create multi device placer for inlined function body: "
-          << SummarizeNode(caller);
+  VLOG(3) << "Create multi device placer for inlined function body.";
   return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
 }
 
@@ -1642,7 +1654,7 @@
 Status ValidateNoInline(const FunctionBody* fbody) {
   const auto attr = AttrSlice(&fbody->fdef.attr());
   bool noinline = false;
-  if (GetNodeAttr(attr, kNoInlineAttr, &noinline).ok() && noinline) {
+  if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
     return errors::InvalidArgument(
         "Can't inline function marked with '_noinline'");
   }
@@ -1848,7 +1860,7 @@
                           const InlineFunctionBodyOptions& options) {
   VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
           << options.DebugString() << "]";
-  VLOG(4) << "Inlined function definition: " << DebugString(fbody->fdef);
+  VLOG(5) << "Inlined function definition: " << DebugString(fbody->fdef);
 
   Status validation = ValidateInlining(caller, fbody, options);
   if (!validation.ok()) {
@@ -1985,9 +1997,13 @@
   // identity node.
   //
   // The added identity nodes depend on "input_control_node".
+  VLOG(4) << "Add input Identity nodes for each function argument:";
   for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
     Node* arg = node_map[fbody->arg_nodes[i]->id()];
     Node* n = input_identity("input", inputs[i], i);
+    VLOG(4) << "    [index " << i << "] " << n->name()
+            << " (input: " << inputs[i].name() << ")";
+
     if (input_control_node) {
       g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
     }
diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc
index 1720ee6..bbaa94d 100644
--- a/tensorflow/core/common_runtime/function_testlib.cc
+++ b/tensorflow/core/common_runtime/function_testlib.cc
@@ -33,7 +33,7 @@
     Tensor* device_tensor = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("device_name", TensorShape{},
                                              &device_tensor));
-    device_tensor->scalar<string>()() =
+    device_tensor->scalar<tstring>()() =
         ctx->function_library()->device()->name();
   }
 };
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index ea12a66..491ef2a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -61,6 +61,10 @@
 #endif  // GOOGLE_CUDA
 }
 
+absl::optional<AllocatorStats> GPUcudaMallocAllocator::GetStats() {
+  return base_allocator_->GetStats();
+}
+
 bool GPUcudaMallocAllocator::TracksAllocationSizes() const { return false; }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 5025eed..b45d505 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -38,6 +38,7 @@
   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
   void DeallocateRaw(void* ptr) override;
   bool TracksAllocationSizes() const override;
+  absl::optional<AllocatorStats> GetStats() override;
 
  private:
   Allocator* base_allocator_ = nullptr;  // owned
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index 2b40730..84eb841 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -59,7 +59,9 @@
 
   bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
       LOCKS_EXCLUDED(mu_) {
-    mutex_lock lock(mu_);
+    // TODO(mrry): Consider replacing this with an atomic `is_initialized` bit,
+    // to avoid writing to a shared cache line in the tf_shared_lock.
+    tf_shared_lock lock(mu_);
     auto result = id_map_.find(tf_gpu_id.value());
     if (result == id_map_.end()) return false;
     *platform_gpu_id = result->second;
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 4907183..95a3e70 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -466,8 +466,8 @@
 
   // All the node types handled here have their output datatype set in
   // either attribute 'dtype' or 'T'.
-  if (!GetNodeAttr(node, "dtype", type).ok() &&
-      !GetNodeAttr(node, "T", type).ok()) {
+  if (!TryGetNodeAttr(node, "dtype", type) &&
+      !TryGetNodeAttr(node, "T", type)) {
     return errors::InvalidArgument(
         "Could not determine output type for feed node: ", node.name(),
         " of type ", node.op());
@@ -757,8 +757,8 @@
 
     GraphConstructorOptions opts;
     opts.allow_internal_ops = true;
-    TF_RETURN_IF_ERROR(
-        ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
+                                              optimized_graph->get()));
     // The graph conversion sets the requested device names but not the
     // assigned device names. However, since at this point the graph is placed
     // TF expects an assigned device name for every node. Therefore we copy
@@ -848,7 +848,8 @@
           for (const NodeDef& ndef : fdef->node_def()) {
             if (ndef.op() == "CollectiveReduce" ||
                 ndef.op() == "CollectiveBcastSend" ||
-                ndef.op() == "CollectiveBcastRecv") {
+                ndef.op() == "CollectiveBcastRecv" ||
+                ndef.op() == "CollectiveGather") {
               int32 instance_key;
               TF_RETURN_IF_ERROR(
                   GetNodeAttr(ndef, "instance_key", &instance_key));
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index a300ae0..c00645a 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -15,6 +15,7 @@
 #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
 
 #include <algorithm>
+
 #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
@@ -32,6 +33,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 
@@ -136,8 +138,9 @@
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              int64 step_id, int fail_after)
-      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
   bool MaybeFail(const StatusCallback& done) {
@@ -244,12 +247,15 @@
       }
     }
     if (!dev_mgr_ || device_type == DEVICE_CPU) {
-      dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
+      dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
     }
-    if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
-    dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
-                           fail_after);
+    if (!gpu_ring_order_) {
+      gpu_ring_order_ = absl::make_unique<string>();
+    }
+    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -714,6 +720,7 @@
   CollectiveExecutor* col_exec_ = nullptr;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc
index 19cc784..88317bf 100644
--- a/tensorflow/core/common_runtime/inspecting_placer.cc
+++ b/tensorflow/core/common_runtime/inspecting_placer.cc
@@ -108,15 +108,13 @@
   int next_group_id_;
 };
 
-InspectingPlacer::InspectingPlacer(const Graph* graph,
-                                   const FunctionStack& stack,
+InspectingPlacer::InspectingPlacer(const FunctionStack& stack,
                                    const FunctionLibraryDefinition* flib_def,
                                    const DeviceSet* device_set,
                                    const Device* default_device,
                                    bool allow_soft_placement,
                                    bool log_device_placement)
-    : graph_(*graph),
-      stack_(stack),
+    : stack_(stack),
       flib_def_(*flib_def),
       device_set_(*device_set),
       default_device_(default_device),
diff --git a/tensorflow/core/common_runtime/inspecting_placer.h b/tensorflow/core/common_runtime/inspecting_placer.h
index 6cba364..3fe6a1a 100644
--- a/tensorflow/core/common_runtime/inspecting_placer.h
+++ b/tensorflow/core/common_runtime/inspecting_placer.h
@@ -69,7 +69,7 @@
   // TODO(iga): Add a "stack trace" to detect recursion and improve log
   // messages. Currently, we will enter an infinite loop for recursive
   // functions.
-  InspectingPlacer(const Graph* graph, const FunctionStack& stack,
+  InspectingPlacer(const FunctionStack& stack,
                    const FunctionLibraryDefinition* flib_def,
                    const DeviceSet* device_set, const Device* default_device,
                    bool allow_soft_placement, bool log_device_placement);
@@ -80,7 +80,6 @@
                                    IOColocationGroups* groups);
 
  private:
-  const Graph& graph_;
   const FunctionStack stack_;
   const FunctionLibraryDefinition& flib_def_;
   const DeviceSet& device_set_;
diff --git a/tensorflow/core/common_runtime/lower_case_op.cc b/tensorflow/core/common_runtime/lower_case_op.cc
index f85dc14..24ca8a9 100644
--- a/tensorflow/core/common_runtime/lower_case_op.cc
+++ b/tensorflow/core/common_runtime/lower_case_op.cc
@@ -38,11 +38,9 @@
 class CaseBuilder {
  public:
   // Create a CaseBuilder to create the lowered form of `case` with branch
-  // functions identified by `branch_fn_names` in the `graph`. The functions
-  // should be available in `flib`.
+  // functions identified by `branch_fn_names` in the `graph`.
   CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
-              const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
-              Graph* graph);
+              bool keep_node_fetchable, Graph* graph);
 
   // Constructs the basic conditional control flow using switch and merge nodes.
   Status CreatePivotNodes();
@@ -91,7 +89,6 @@
   // for the side effects.
   Node* branch_executed_node_;
   Graph* graph_;
-  const FunctionLibraryDefinition& flib_;
   string name_;
   bool keep_node_fetchable_;
 
@@ -101,12 +98,10 @@
 
 CaseBuilder::CaseBuilder(Node* case_op,
                          const std::vector<string>& branch_fn_names,
-                         const FunctionLibraryDefinition& flib,
                          bool keep_node_fetchable, Graph* graph)
     : case_op_(case_op),
       num_branches_(branch_fn_names.size()),
       graph_(graph),
-      flib_(flib),
       name_(case_op->name()),
       keep_node_fetchable_(keep_node_fetchable),
       debug_info_(*case_op_) {
@@ -273,8 +268,7 @@
 
 }  // namespace
 
-Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
-                       bool keep_node_fetchable) {
+Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) {
   VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
           << "): " << SummarizeNode(*n);
   const AttrValue* branches_attr = n->attrs().Find("branches");
@@ -288,7 +282,7 @@
   for (int b = 0; b < num_branches; b++) {
     branch_fn_names.emplace_back(branches_attr->list().func(b).name());
   }
-  CaseBuilder cb(n, branch_fn_names, flib, keep_node_fetchable, g);
+  CaseBuilder cb(n, branch_fn_names, keep_node_fetchable, g);
   TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
   TF_RETURN_IF_ERROR(cb.AddInputs());
   TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_case_op.h b/tensorflow/core/common_runtime/lower_case_op.h
index fc46a1f..9148f43 100644
--- a/tensorflow/core/common_runtime/lower_case_op.h
+++ b/tensorflow/core/common_runtime/lower_case_op.h
@@ -22,8 +22,7 @@
 namespace tensorflow {
 
 // Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
-Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
-                       bool keep_node_fetchable);
+Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc
index 87b0246..1152619 100644
--- a/tensorflow/core/common_runtime/lower_function_call_op.cc
+++ b/tensorflow/core/common_runtime/lower_function_call_op.cc
@@ -33,8 +33,9 @@
   if (n->IsPartitionedCall()) return true;
 
   bool match;
-  Status s = GetNodeAttr(n->attrs(), kLowerAsMultiDeviceFunctionAttr, &match);
-  return s.ok() && match;
+  bool found =
+      TryGetNodeAttr(n->attrs(), kLowerAsMultiDeviceFunctionAttr, &match);
+  return found && match;
 }
 
 }  // namespace
diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc
index 2b8d941..4254fd1 100644
--- a/tensorflow/core/common_runtime/lower_functional_ops.cc
+++ b/tensorflow/core/common_runtime/lower_functional_ops.cc
@@ -40,15 +40,15 @@
 // Checks if boolean attribute is defined and it's value is 'true'.
 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
   bool match;
-  Status s = GetNodeAttr(n->attrs(), attr_name, &match);
-  return s.ok() && match;
+  bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
+  return found && match;
 }
 
 // Checks if string attribute is defined and it's not empty.
 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
   string match;
-  Status s = GetNodeAttr(n->attrs(), attr_name, &match);
-  return s.ok() && !match.empty();
+  bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
+  return found && !match.empty();
 }
 
 bool LowerUsingSwitchMergeIsOn(const Node* n) {
@@ -138,14 +138,12 @@
 
     if (LowerUsingSwitchMergeIsOn(n)) {
       if (n->IsIfNode()) {
-        TF_RETURN_IF_ERROR(
-            RewriteIfNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
+        TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
       } else if (n->type_string() == "Case") {
+        TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
+      } else if (n->IsWhileNode()) {
         TF_RETURN_IF_ERROR(
-            RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
-      } else if (n->type_string() == "While") {
-        TF_RETURN_IF_ERROR(
-            RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
+            RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
       } else {
         return errors::Internal(
             "Node ", FormatNodeForError(*n), " of type ", n->type_string(),
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index 2cd89ea..9b1d2b8 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -41,8 +41,7 @@
   // else functions `then_fn` and `else_fn` respectively in the `graph`. The
   // functions should be available in `flib`.
   CondBuilder(Node* if_op, const NameAttrList& then_fn,
-              const NameAttrList& else_fn,
-              const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
+              const NameAttrList& else_fn, bool keep_node_fetchable,
               Graph* graph);
 
   // Constructs the basic conditional control flow using switch and merge nodes.
@@ -95,7 +94,6 @@
   // executed for the side effects.
   Node* branch_executed_node_;
   Graph* graph_;
-  const FunctionLibraryDefinition& flib_;
   string name_;
   bool keep_node_fetchable_;
 
@@ -106,11 +104,9 @@
 
 CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn,
                          const NameAttrList& else_fn,
-                         const FunctionLibraryDefinition& flib,
                          bool keep_node_fetchable, Graph* graph)
     : if_op_(if_op),
       graph_(graph),
-      flib_(flib),
       name_(if_op->name()),
       keep_node_fetchable_(keep_node_fetchable),
       debug_info_(*if_op_),
@@ -272,8 +268,7 @@
 
 }  // namespace
 
-Status RewriteIfNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
-                     bool keep_node_fetchable) {
+Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) {
   VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable
           << "): " << SummarizeNode(*n);
 
@@ -286,8 +281,8 @@
     return errors::InvalidArgument("Else branch function missing");
   }
 
-  CondBuilder cb(n, then_attr->func(), else_attr->func(), flib,
-                 keep_node_fetchable, g);
+  CondBuilder cb(n, then_attr->func(), else_attr->func(), keep_node_fetchable,
+                 g);
   TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
   TF_RETURN_IF_ERROR(cb.AddInputs());
   TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index 00e9302..cfaf15e 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -22,8 +22,7 @@
 namespace tensorflow {
 
 // Replaces If node `n` with its lowered form that uses Switch and Merge nodes.
-Status RewriteIfNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
-                     bool keep_node_fetchable);
+Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
index c1c5e51..c28918a 100644
--- a/tensorflow/core/common_runtime/lower_while_op.cc
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -58,10 +58,9 @@
  public:
   static Status Run(Node* while_op, const NameAttrList& cond_fn,
                     const NameAttrList& body_fn, int parallel_iterations,
-                    Graph* graph, const FunctionLibraryDefinition& flib,
-                    bool keep_node_fetchable) {
+                    Graph* graph, bool keep_node_fetchable) {
     LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
-                            graph, flib, keep_node_fetchable);
+                            graph, keep_node_fetchable);
     return helper.RunInternal();
   }
 
@@ -71,8 +70,7 @@
   // the given graph.
   LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                    const NameAttrList& body_fn, int parallel_iterations,
-                   Graph* graph, const FunctionLibraryDefinition& flib,
-                   bool keep_node_fetchable);
+                   Graph* graph, bool keep_node_fetchable);
 
   Status RunInternal();
 
@@ -136,7 +134,6 @@
   // used as a source of outgoing control edges from lowered While node.
   Node* lowered_while_executed_;
   Graph* graph_;
-  const FunctionLibraryDefinition& flib_;
   // Name of the `while_op_`.
   string name_;
   // Max number of parallel_iterations for the while loop.
@@ -159,11 +156,9 @@
 LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                                    const NameAttrList& body_fn,
                                    int parallel_iterations, Graph* graph,
-                                   const FunctionLibraryDefinition& flib,
                                    bool keep_node_fetchable)
     : while_op_(while_op),
       graph_(graph),
-      flib_(flib),
       name_(while_op->name()),
       parallel_iterations_(parallel_iterations),
       keep_node_fetchable_(keep_node_fetchable),
@@ -417,7 +412,6 @@
 }  // namespace
 
 Status RewriteWhileNode(Node* n, Graph* g,
-                        const FunctionLibraryDefinition& flib,
                         bool keep_node_fetchable) {
   VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable
           << "): " << SummarizeNode(*n);
@@ -438,7 +432,7 @@
 
   TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
       n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g,
-      flib, keep_node_fetchable));
+      keep_node_fetchable));
   g->RemoveNode(n);
 
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/lower_while_op.h b/tensorflow/core/common_runtime/lower_while_op.h
index 2090a24..9f016c4 100644
--- a/tensorflow/core/common_runtime/lower_while_op.h
+++ b/tensorflow/core/common_runtime/lower_while_op.h
@@ -23,9 +23,7 @@
 
 // Replaces While node `n` with its lowered form that uses Enter, Exit, Switch,
 // Merge, NextIteration and LoopCond nodes.
-Status RewriteWhileNode(Node* n, Graph* g,
-                        const FunctionLibraryDefinition& flib,
-                        bool keep_node_fetchable);
+Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc
index 5a6a5e9..bb3ea39 100644
--- a/tensorflow/core/common_runtime/metrics.cc
+++ b/tensorflow/core/common_runtime/metrics.cc
@@ -78,6 +78,15 @@
     "spent optimizing the graph with Grappler, and time spent pruning the "
     "sub-graph.");
 
+auto* xla_compilations = monitoring::Counter<0>::New(
+    "/tensorflow/core/xla_compilations",
+    "The number of XLA compilations used to collect "
+    "/tensorflow/core/xla_compilation_time_usecs");
+
+auto* xla_compilation_time_usecs = monitoring::Counter<0>::New(
+    "/tensorflow/core/xla_compilation_time_usecs",
+    "The total time spent on compiling XLA graphs in microseconds.");
+
 }  // namespace
 
 void RecordTFDataAutotune(const string& name) {
@@ -119,5 +128,12 @@
   }
 }
 
+void UpdateXlaCompilationTime(const uint64 compilation_time_usecs) {
+  if (compilation_time_usecs > 0) {
+    xla_compilations->GetCell()->IncrementBy(1);
+    xla_compilation_time_usecs->GetCell()->IncrementBy(compilation_time_usecs);
+  }
+}
+
 }  // namespace metrics
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h
index b756382..1c0f795 100644
--- a/tensorflow/core/common_runtime/metrics.h
+++ b/tensorflow/core/common_runtime/metrics.h
@@ -64,6 +64,9 @@
 // TODO(jtkeeling): Should we record building/optimizing tf.functions?
 void UpdateGraphBuildTime(const uint64 running_time_usecs);
 
+// Updates the metrics stored about time XLA spents compiling graphs.
+void UpdateXlaCompilationTime(const uint64 compilation_time_usecs);
+
 }  // namespace metrics
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index 5b61e66..486b21e 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -55,8 +55,8 @@
         NodeDebugInfo debug_info(*n);
         NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);
         node_builder.Device(n->requested_device());
-        string colo;
-        if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
+        const string& colo = GetNodeAttrString(n_attrs, "_class");
+        if (!colo.empty()) {
           node_builder.Attr("_class", colo);
         }
         return node_builder;
diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc
index d27e9da..8f9583c 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils.cc
@@ -102,8 +102,10 @@
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
-                                              : MTypeFromDType(type);
+    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                        device_type == "XLA_GPU")
+                           ? MTypeFromDTypeIntsOnDevice(type)
+                           : MTypeFromDType(type);
     if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
@@ -115,8 +117,10 @@
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
-                                              : MTypeFromDType(type);
+    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                        device_type == "XLA_GPU")
+                           ? MTypeFromDTypeIntsOnDevice(type)
+                           : MTypeFromDType(type);
     if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc
index 093fa79..cb3fc45 100644
--- a/tensorflow/core/common_runtime/rendezvous_util_test.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc
@@ -33,7 +33,7 @@
 // string -> Tensor<string>
 Tensor V(const string& content) {
   Tensor tensor(DT_STRING, TensorShape({}));
-  tensor.scalar<string>()() = content;
+  tensor.scalar<tstring>()() = content;
   return tensor;
 }
 
@@ -41,7 +41,7 @@
 string V(const Tensor& tensor) {
   CHECK_EQ(tensor.dtype(), DT_STRING);
   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
-  return tensor.scalar<string>()();
+  return tensor.scalar<tstring>()();
 }
 
 string MakeStringKey(const string& name) {
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index f0f2998..a564868 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -15,6 +15,7 @@
 #include "tensorflow/core/common_runtime/ring_gatherer.h"
 
 #include <algorithm>
+
 #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
@@ -34,6 +35,7 @@
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 
@@ -44,8 +46,9 @@
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              int64 step_id, int fail_after)
-      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
   bool MaybeFail(const StatusCallback& done) {
@@ -164,12 +167,15 @@
     if (!dev_mgr_ || device_type == DEVICE_CPU) {
       LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
                  << " devices: ";
-      dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
+      dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
     }
-    if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
-    dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
-                           fail_after);
+    if (!gpu_ring_order_) {
+      gpu_ring_order_ = absl::make_unique<string>();
+    }
+    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -518,6 +524,7 @@
   CollectiveExecutor* col_exec_;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 16dbabd..6141d33 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -15,6 +15,7 @@
 #include "tensorflow/core/common_runtime/ring_reducer.h"
 
 #include <algorithm>
+
 #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
@@ -34,6 +35,7 @@
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 
@@ -44,8 +46,9 @@
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              int64 step_id, int fail_after)
-      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
   bool MaybeFail(const StatusCallback& done) {
@@ -184,14 +187,17 @@
       }
     }
     if (!dev_mgr_ || device_type == DEVICE_CPU) {
-      LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
-                 << " devices: ";
-      dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
+      LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
+                << " devices: ";
+      dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
     }
-    if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
-    dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
-                           fail_after);
+    if (!gpu_ring_order_) {
+      gpu_ring_order_ = absl::make_unique<string>();
+    }
+    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -545,6 +551,7 @@
   CollectiveExecutor* col_exec_;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index eabcb7c..575fafd 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -92,6 +92,7 @@
   // Starts exporting metrics through a platform-specific monitoring API (if
   // provided). For builds using "tensorflow/core/platform/default", this is
   // currently a no-op.
+  session_created->GetCell()->Set(true);
   monitoring::StartExporter();
   s = factory->NewSession(options, out_session);
   if (!s.ok()) {
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index e8ac66e..2333e55 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -83,7 +83,15 @@
           " not in [0, ", outer_context->num_inputs(), ").");
     }
 
-    node_context->set_output(0, outer_context->input(index));
+    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not set
+    // in outer context, set _Arg node output shape to unknown.
+    if (outer_context->input(index).SameHandle(ShapeHandle())) {
+      LOG(WARNING) << "Function instantiation has undefined input shape at "
+                   << "index: " << index << " in the outer inference context.";
+      node_context->set_output(0, node_context->UnknownShape());
+    } else {
+      node_context->set_output(0, outer_context->input(index));
+    }
 
     auto* resource = outer_context->input_handle_shapes_and_types(index);
     if (resource) {
diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
index 8020583..6436dea 100644
--- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
@@ -47,6 +47,10 @@
                   const StatusCallback& done) override {
     done(errors::Internal("Unimplemented"));
   }
+
+  void RunClosure(std::function<void()>) override {
+    LOG(FATAL) << "Unimplemented";
+  }
 };
 
 class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 135f73d..462b447 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -22,7 +22,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_all_protos",
     "tf_kernel_tests_linkstatic",
     "tf_proto_library",
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index d5498ed..038418a 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -56,6 +56,10 @@
     return Status::OK();
   }
 
+  // Debug ops and URLs for wildcard node names (if any).
+  std::vector<string> default_debug_ops;
+  std::vector<string> default_debug_urls;
+
   // A map from tensor name (e.g., "node_a:0") to list of debug op names
   // (e.g., {"DebugIdentity", "DebugNanCount"})
   std::unordered_map<string, std::vector<string>> tensor_watches;
@@ -65,16 +69,39 @@
 
   // Cache the proto content for fast lookup later
   for (const DebugTensorWatch& watch : watches) {
-    if (watch.output_slot() < 0) {
-      // The semantics of output_slot == -1 is that the node is watched only
-      // for completion, but not for output tensor values (see
-      // NodeCompletionCallback in debug_gateway.h).
-      continue;
-    }
     if (watch.debug_ops().empty()) {
       continue;
     }
 
+    if (watch.debug_urls().empty()) {
+      continue;
+    }
+
+    if (watch.node_name() == "*") {
+      if (watch.output_slot() == -1) {
+        default_debug_ops.insert(default_debug_ops.end(),
+                                 watch.debug_ops().begin(),
+                                 watch.debug_ops().end());
+        default_debug_urls.insert(default_debug_urls.end(),
+                                  watch.debug_urls().begin(),
+                                  watch.debug_urls().end());
+      } else {
+        return Status(error::FAILED_PRECONDITION,
+                      strings::StrCat(
+                          "output_slot is expected to be -1 for wildcard ",
+                          "node name (\"*\"), but got ", watch.output_slot()));
+      }
+      continue;
+    } else {
+      if (watch.output_slot() < 0) {
+        return Status(
+            error::FAILED_PRECONDITION,
+            strings::StrCat("A negative output_slot in DebugTensorWatch is ",
+                            "valid only for the wildcard node name (\"*\"), ",
+                            "but got node name ", watch.node_name()));
+      }
+    }
+
     string tensor_name =
         strings::StrCat(watch.node_name(), ":", watch.output_slot());
 
@@ -120,9 +147,9 @@
          ++src_output_slot) {
       const string tensor_name =
           strings::StrCat(src_node->name(), ":", src_output_slot);
-      if (tensor_watches.find(tensor_name) == tensor_watches.end()) {
-        // Add debug nodes only for edges with matching source node and source
-        // output slot.
+      const bool explicit_tensor_match =
+          tensor_watches.find(tensor_name) != tensor_watches.end();
+      if (!explicit_tensor_match && default_debug_ops.empty()) {
         continue;
       }
 
@@ -146,11 +173,17 @@
                                              src_output_slot, &memory_type));
 
       // Create the copy node for the watched tensor.
+      const std::vector<string> debug_ops = explicit_tensor_match
+                                                ? tensor_watches[tensor_name]
+                                                : default_debug_ops;
+      const std::vector<string> debug_urls =
+          explicit_tensor_match ? tensor_watch_urls[tensor_name]
+                                : default_debug_urls;
       Node* copy_node;
-      Status copy_s = CreateCopyNode(
-          graph, device_type, memory_type == HOST_MEMORY, src_node->name(),
-          src_output_slot, src_dt, tensor_name, tensor_watches[tensor_name],
-          tensor_watch_urls[tensor_name], &copy_node);
+      Status copy_s =
+          CreateCopyNode(graph, device_type, memory_type == HOST_MEMORY,
+                         src_node->name(), src_output_slot, src_dt, tensor_name,
+                         debug_ops, debug_urls, &copy_node);
       if (!copy_s.ok()) {
         return Status(
             error::FAILED_PRECONDITION,
@@ -163,13 +196,13 @@
 
       // Create all requested debug nodes and their edges to the Copy node.
       std::vector<Node*> debug_nodes;
-      for (size_t i = 0; i < tensor_watches[tensor_name].size(); ++i) {
-        const string& debug_op_name = tensor_watches[tensor_name][i];
+      for (size_t i = 0; i < debug_ops.size(); ++i) {
+        const string& debug_op_name = debug_ops[i];
 
         Node* debug_node;
-        Status debug_s = CreateDebugNode(
-            graph, *device, copy_node->name(), src_dt, tensor_name,
-            tensor_watch_urls[tensor_name], i, debug_op_name, &debug_node);
+        Status debug_s = CreateDebugNode(graph, *device, copy_node->name(),
+                                         src_dt, tensor_name, debug_urls, i,
+                                         debug_op_name, &debug_node);
         if (debug_s.ok()) {
           graph->AddEdge(copy_node, 0, debug_node, 0);
           debug_nodes.push_back(debug_node);
diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
index c857f12..26fd376c 100644
--- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
@@ -147,7 +147,7 @@
 
 TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) {
   Tensor tensor(DT_STRING, TensorShape({1, 1}));
-  tensor.flat<string>()(0) = string(5000 * 1024, 'A');
+  tensor.flat<tstring>()(0) = string(5000 * 1024, 'A');
   const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
                                    "foo_tensor", 0, "DebugIdentity");
   const Status status = DebugIO::PublishDebugTensor(
@@ -162,8 +162,8 @@
 
 TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) {
   Tensor tensor(DT_STRING, TensorShape({1, 2}));
-  tensor.flat<string>()(0) = "A";
-  tensor.flat<string>()(1) = string(5000 * 1024, 'A');
+  tensor.flat<tstring>()(0) = "A";
+  tensor.flat<tstring>()(1) = string(5000 * 1024, 'A');
   const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
                                    "foo_tensor", 0, "DebugIdentity");
   const Status status = DebugIO::PublishDebugTensor(
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index 928a82b..3eebcb3 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -47,8 +47,8 @@
     tensor_a_->flat<float>()(3) = 0.0;
 
     tensor_b_.reset(new Tensor(DT_STRING, TensorShape{2}));
-    tensor_b_->flat<string>()(0) = "corge";
-    tensor_b_->flat<string>()(1) = "garply";
+    tensor_b_->flat<tstring>()(0) = "corge";
+    tensor_b_->flat<tstring>()(1) = "garply";
   }
 
   Env* env_;
@@ -182,8 +182,8 @@
 
   // Verify tensor shape and value.
   ASSERT_EQ(tensor_b_->shape(), b_prime.shape());
-  for (int i = 0; i < b_prime.flat<string>().size(); ++i) {
-    ASSERT_EQ(tensor_b_->flat<string>()(i), b_prime.flat<string>()(i));
+  for (int i = 0; i < b_prime.flat<tstring>().size(); ++i) {
+    ASSERT_EQ(tensor_b_->flat<tstring>()(i), b_prime.flat<tstring>()(i));
   }
 
   // Tear down temporary file and directories.
diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc
index 642a2a4..65ec1ef 100644
--- a/tensorflow/core/debug/grpc_session_debug_test.cc
+++ b/tensorflow/core/debug/grpc_session_debug_test.cc
@@ -231,7 +231,7 @@
   Graph graph(OpRegistry::Global());
   Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
   for (size_t i = 0; i < 4; ++i) {
-    a_tensor.flat<string>()(i) = "hello, world";
+    a_tensor.flat<tstring>()(i) = "hello, world";
   }
   Node* a = test::graph::Constant(&graph, a_tensor);
   Node* b = test::graph::Identity(&graph, a);
@@ -266,7 +266,7 @@
         ASSERT_EQ(outputs[0].dtype(), DT_STRING);
         ASSERT_EQ(outputs[0].NumElements(), 4);
         for (size_t i = 0; i < outputs[0].NumElements(); ++i) {
-          EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
+          EXPECT_EQ(outputs[0].flat<tstring>()(i), "hello, world");
         }
         TF_CHECK_OK(session->Close());
 
@@ -278,7 +278,7 @@
         ASSERT_EQ(1, dumped_tensors.size());
         ASSERT_EQ(TensorShape({2, 2}), dumped_tensors[0].shape());
         for (size_t i = 0; i < 4; ++i) {
-          ASSERT_EQ("hello, world", dumped_tensors[0].flat<string>()(i));
+          ASSERT_EQ("hello, world", dumped_tensors[0].flat<tstring>()(i));
         }
 
         DeleteDumpDir();
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b33b785..d2c48aa 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -8,11 +8,11 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_kernel_tests_linkstatic",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 21fcd05..2751deb 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -389,26 +390,53 @@
     mutex_lock l(mu_);
     if (status_.ok()) {
       status_ = derived_status;
-      for (BaseRecvTensorCall* call : active_) {
-        call->StartAbort(derived_status);
+      for (auto& entry : active_) {
+        entry.first->StartAbort(derived_status);
+        entry.second();
       }
       active_.clear();
     }
   }
 }
 
-void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
-  mutex_lock l(mu_);
-  if (!status_.ok()) {
-    call->StartAbort(status_);
-  } else {
-    CHECK(active_.insert(call).second);
+void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
+                                        const Rendezvous::Args& args) {
+  CancellationManager* cm = args.cancellation_manager;
+  {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      call->StartAbort(status_);
+      return;
+    }
+    bool already_cancelled = false;
+    InactiveCallback callback = [] {};
+    if (cm != nullptr) {
+      auto token = cm->get_cancellation_token();
+      already_cancelled = !cm->RegisterCallback(token, [this, call] {
+        {
+          mutex_lock l(mu_);
+          if (active_.find(call) == active_.end()) return;
+          call->StartAbort(
+              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+        }
+      });
+      callback = [cm, token] { cm->TryDeregisterCallback(token); };
+    }
+    if (already_cancelled) {
+      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
+    } else {
+      CHECK(active_.emplace(call, callback).second);
+    }
   }
 }
 
 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
   mutex_lock l(mu_);
-  active_.erase(call);
+  auto it = active_.find(call);
+  if (it != active_.end()) {
+    it->second();
+    active_.erase(it);
+  }
 }
 
 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
index 6751fb8..fde589b 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -160,7 +160,7 @@
                             DeviceNameUtils::ParsedName dst);
 
   // If aborted, aborts "call". Otherwise, adds "call" into active_.
-  void RegisterCall(BaseRecvTensorCall* call);
+  void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);
 
   // Removes "call" from active_ if "call" is in active_.
   void DeregisterCall(BaseRecvTensorCall* call);
@@ -192,8 +192,11 @@
   };
   std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
 
+  typedef std::function<void()> InactiveCallback;
+
   // Active outstanding RecvTensor calls.
-  gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
+  std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
+      GUARDED_BY(mu_);
 
   bool is_initialized_locked() SHARED_LOCKS_REQUIRED(mu_) {
     return session_ != nullptr;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index 92b2e4e..b2af3c2 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -137,11 +137,11 @@
                            nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
                            to_device, cpu_attr, to_alloc_attr, cpu_tensor,
                            to_tensor, dev_to_dev_stream_index,
-                           [cpu_tensor, done](const Status& s) {
+                           [this, cpu_tensor, done](const Status& s) {
                              delete cpu_tensor;
                              // This callback must not block, so execute
                              // done in another thread.
-                             SchedClosure([s, done] { done(s); });
+                             RunClosure([s, done] { done(s); });
                            });
         delete state;
         return;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index 9434cac..7d8fcc61 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -14,8 +14,10 @@
 ==============================================================================*/
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 class WorkerCacheInterface;
@@ -23,11 +25,11 @@
 // Extend CollectiveRemoteAccessLocal with access to remote peers.
 class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
  public:
-  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
-                                    DeviceResolverInterface* dev_resolver,
-                                    WorkerCacheInterface* worker_cache,
-                                    int64 step_id)
-      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+  CollectiveRemoteAccessDistributed(
+      const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+      std::shared_ptr<UnboundedWorkQueue> work_queue,
+      WorkerCacheInterface* worker_cache, int64 step_id)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         worker_cache_(worker_cache) {}
 
   ~CollectiveRemoteAccessDistributed() override {}
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 4ed8b31..d554650 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -170,7 +170,9 @@
 
 class CollRMADistTest : public ::testing::Test {
  protected:
-  CollRMADistTest() {}
+  CollRMADistTest()
+      : work_queue_(
+            std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
 
   ~CollRMADistTest() override {
     for (DeviceMgr* dm : device_mgrs_) {
@@ -198,7 +200,8 @@
     }
     // All tests simulate requests from worker 0 to worker 1.
     rma_.reset(new CollectiveRemoteAccessDistributed(
-        device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));
+        device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
+        kStepId));
 
     const int kNumElts = 8;
     expected_value_ = Tensor(DT_FLOAT, {kNumElts});
@@ -257,6 +260,7 @@
   std::vector<DeviceMgr*> device_mgrs_;
   std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
   std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<FakeWorker*> workers_;
   std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
   mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index bffe4dc..922d7c2 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -44,6 +44,7 @@
 
 cc_library(
     name = "remote_execute_node",
+    srcs = ["remote_execute_node.cc"],
     hdrs = ["remote_execute_node.h"],
     deps = [
         ":eager_client",
@@ -130,6 +131,20 @@
     ],
 )
 
+tf_cc_test(
+    name = "remote_mgr_test",
+    size = "small",
+    srcs = ["remote_mgr_test.cc"],
+    deps = [
+        ":remote_mgr",
+        "//tensorflow/core:eager_service_proto_cc",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+    ],
+)
+
 cc_library(
     name = "remote_tensor_handle_data",
     srcs = ["remote_tensor_handle_data.cc"],
@@ -141,3 +156,24 @@
         "//tensorflow/core/common_runtime/eager:tensor_handle_data",
     ],
 )
+
+cc_library(
+    name = "remote_copy_node",
+    srcs = [
+        "remote_copy_node.cc",
+    ],
+    hdrs = [
+        "remote_copy_node.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":remote_mgr",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/common_runtime/eager:attr_builder",
+        "//tensorflow/core/common_runtime/eager:eager_executor",
+        "//tensorflow/core/common_runtime/eager:eager_operation",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+    ],
+)
diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
index 88847a2..84da304 100644
--- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
+++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
@@ -34,18 +34,17 @@
         eager_client_(eager_client) {}
 
   Status Run() override {
-    EnqueueResponse response;
-    Status status;
-    // TODO(b/136025146): Remove wait for notification
-    Notification n;
-    eager_client_->EnqueueAsync(request_.get(), &response,
-                                [&n, &status](const tensorflow::Status& s) {
-                                  status.Update(s);
-                                  n.Notify();
-                                });
-    n.WaitForNotification();
-
-    return status;
+    EnqueueResponse* response = new EnqueueResponse;
+    eager_client_->StreamingEnqueueAsync(
+        request_.get(), response, [response](const tensorflow::Status& s) {
+          if (!s.ok()) {
+            LOG(WARNING) << "Ignoring an error encountered when deleting "
+                            "remote tensors handles: "
+                         << s.ToString();
+          }
+          delete response;
+        });
+    return Status::OK();
   }
 
   void Abort(Status status) override {}
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index ae2fd93..deee5c8 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
@@ -122,14 +123,19 @@
         return r;
       };
 
+  LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
+            << " eager service context with rendezvous_id on host "
+            << port::Hostname();
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
       device_mgr, false, r, GetDefaultCustomKernelCreator(),
       worker_session->cluster_flr.get());
+  // Ownership will be transferred to the ServerContext, or else in an error
+  // case ctx will be deleted by this unref.
+  core::ScopedUnref unref_ctx(ctx);
 
-  Status s;
   std::vector<string> remote_workers;
   worker_session->worker_cache->ListWorkers(&remote_workers);
   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
@@ -137,20 +143,18 @@
                        remote_workers.end());
 
   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
-  s = worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers);
-  if (!s.ok()) {
-    delete ctx;
-    return s;
-  }
+  TF_RETURN_IF_ERROR(
+      worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers));
 
   auto remote_mgr =
-      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false);
-  s = ctx->InitializeRemoteWorker(
+      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
+  Status s = ctx->InitializeRemoteWorker(
       std::move(remote_eager_workers), worker_session->remote_device_mgr(),
       remote_workers, request->context_id(), std::move(rendezvous_creator),
       std::move(remote_mgr));
   if (!s.ok()) {
-    delete ctx;
+    VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
+            << s.ToString();
     return s;
   }
 
@@ -163,7 +167,6 @@
   {
     mutex_lock l(contexts_mu_);
     if (contexts_.find(request->context_id()) != contexts_.end()) {
-      delete ctx;
       return errors::InvalidArgument("EagerService:CreateContext failed. ",
                                      "Context id: <", request->context_id(),
                                      "> already exists.");
@@ -175,6 +178,24 @@
   return Status::OK();
 }
 
+Status EagerServiceImpl::CreateMasterContext(
+    const tensorflow::uint64 context_id, EagerContext* context) {
+  {
+    mutex_lock l(contexts_mu_);
+    auto iter = contexts_.find(context_id);
+    if (iter != contexts_.end()) {
+      return errors::InvalidArgument(
+          "EagerService:CreateMasterContext failed. ", "Context id: <",
+          context_id, "> already exists.");
+    }
+  }
+  ServerContext* server_context =
+      ServerContext::CreateMasterContext(context, env_);
+  mutex_lock l(contexts_mu_);
+  contexts_.emplace(context_id, server_context);
+  return Status::OK();
+}
+
 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
   const tensorflow::Tensor* t = nullptr;
 
@@ -187,14 +208,14 @@
 }
 
 Status EagerServiceImpl::ExecuteOp(const Operation& operation,
-                                   ServerContext* server_context,
+                                   EagerContext* eager_context,
                                    QueueResponse* queue_response) {
   std::unique_ptr<tensorflow::EagerOperation> op;
   const char* name = operation.name().c_str();  // Shorthand
   const tensorflow::AttrTypeMap* types;
   bool is_function = false;
   TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function));
-  if (is_function && !server_context->Context()->FindFunctionByName(name)) {
+  if (is_function && !eager_context->FindFunctionByName(name)) {
     return errors::NotFound(
         "'", name,
         "' is neither a type of a primitive operation nor a name "
@@ -203,8 +224,8 @@
         ". Make sure the operation or function is "
         "registered in the binary running in this process.");
   }
-  op.reset(new tensorflow::EagerOperation(server_context->Context(), name,
-                                          is_function, types));
+  op.reset(
+      new tensorflow::EagerOperation(eager_context, name, is_function, types));
 
   TF_RETURN_IF_ERROR(op->SetDeviceName(operation.device().c_str()));
 
@@ -214,9 +235,11 @@
     for (const auto& remote_handle : operation.inputs()) {
       tensorflow::TensorHandle* handle;
       TF_RETURN_IF_ERROR(
-          server_context->Context()->RemoteMgr()->DeserializeRemoteTensorHandle(
+          eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
               remote_handle, &handle));
       op->AddInput(handle);
+      // Unref handle since it has a ref as an input now.
+      handle->Unref();
     }
   }
 
@@ -226,16 +249,16 @@
 
   int num_retvals = 0;
   // TODO(nareshmodi): Consider caching this.
-  TF_RETURN_IF_ERROR(GetNumRetvals(server_context->Context(), operation.name(),
+  TF_RETURN_IF_ERROR(GetNumRetvals(eager_context, operation.name(),
                                    operation.attrs(), &num_retvals));
 
   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals(
       num_retvals);
+  VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
   TF_RETURN_IF_ERROR(EagerExecute(op.get(), &retvals, &num_retvals));
   retvals.resize(num_retvals);
 
-  server_context->Context()->RemoteMgr()->AddOperationOutputs(retvals,
-                                                              operation.id());
+  eager_context->RemoteMgr()->AddOperationOutputs(retvals, operation.id());
 
   for (auto* handle : retvals) {
     TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
@@ -255,13 +278,21 @@
   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
   core::ScopedUnref context_unref(context);
 
+  auto executor = context->Context()->Executor();
   for (const auto& item : request->queue()) {
     auto* queue_response = response->add_queue_response();
     if (item.has_operation()) {
-      TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
+      TF_RETURN_IF_ERROR(
+          ExecuteOp(item.operation(), context->Context(), queue_response));
     } else {
-      TF_RETURN_IF_ERROR(context->Context()->RemoteMgr()->DeleteTensorHandle(
-          RemoteTensorHandleInternal(item.handle_to_decref())));
+      auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>(
+          item.handle_to_decref());
+      auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
+          context, std::move(handle_to_decref));
+      TF_RETURN_IF_ERROR(
+          executor->Async()
+              ? context->Context()->Executor()->Add(std::move(node))
+              : node->Run());
     }
   }
 
@@ -279,7 +310,7 @@
         "EagerServiceImpl::WaitQueueDone is not "
         "implemented for particular op IDs.");
   }
-  return context->Context()->AsyncWait();
+  return context->Context()->Executor()->WaitForAllPendingNodes();
 }
 
 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
@@ -293,6 +324,8 @@
 
 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
                                       CloseContextResponse* response) {
+  VLOG(1) << "Executing EagerService::CloseContext for context "
+          << request->context_id();
   ServerContext* context = nullptr;
   if (!GetServerContext(request->context_id(), &context).ok()) {
     // Swallow the error here.
@@ -342,8 +375,8 @@
     Device* device;
     TF_RETURN_IF_ERROR(
         ctx->FindDeviceFromName(request->device_name().c_str(), &device));
-    TF_RETURN_IF_ERROR(
-        EagerCopyToDevice(tensor_handle, ctx, device, false, &copied_handle));
+    TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, ctx, ctx->Executor(),
+                                         device, false, &copied_handle));
     tensors.push_back(copied_handle);
     tensor_handle->Unref();
   }
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index b64c0ff..5e75c4b 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -16,9 +16,9 @@
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
 
-
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
 #include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/lib/core/refcount.h"
@@ -80,6 +80,11 @@
   Status CreateContext(const CreateContextRequest* request,
                        CreateContextResponse* response);
 
+  // Create a ServerContext for master eager context.
+  Status CreateMasterContext(const tensorflow::uint64 context_id,
+                             EagerContext* context);
+
+  // Used by both Enqueue and StreamingEnqueue RPCs.
   Status Enqueue(const EnqueueRequest* request, EnqueueResponse* response);
 
   Status WaitQueueDone(const WaitQueueDoneRequest* request,
@@ -103,15 +108,29 @@
   // and the EagerContext).
   class ServerContext : public core::RefCounted {
    public:
+    // Create a ServerContext for local master.
+    static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx,
+                                              const WorkerEnv* env) {
+      return new ServerContext(ctx, -1, env, /* is_master= */ true);
+    }
+
     explicit ServerContext(tensorflow::EagerContext* ctx,
-                           int64 destroy_after_secs, const WorkerEnv* env)
-        : ctx_(ctx), env_(env) {
+                           int64 destroy_after_secs, const WorkerEnv* env,
+                           const bool is_master = false)
+        : ctx_(ctx), env_(env), is_master_(is_master) {
+      ctx->Ref();
       destroy_after_micros_ =
           destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
       RecordAccess();
     }
-    ~ServerContext() {
 
+    ~ServerContext() {
+      // TFE_Context is responsible for shutting down master eager context.
+      if (!is_master_) {
+        ctx_->WaitForAndCloseRemoteContexts();
+      }
+      // ctx_->RefCountIsOne() should be true here when is_master_ = false.
+      // TODO(iga): Remove EagerContext refcounting.
       ctx_->Unref();
     }
 
@@ -138,12 +157,43 @@
     mutex last_accessed_mu_;
     int64 last_accessed_micros_ GUARDED_BY(last_accessed_mu_);
     int64 destroy_after_micros_;
+
+    const bool is_master_;
   };
   // The returned ServerContext will need to be Unrefed.
   tensorflow::Status GetServerContext(uint64, ServerContext**);
 
+  class ClientTensorHandleDeleteNode : public EagerNode {
+   public:
+    ClientTensorHandleDeleteNode(
+        ServerContext* context,
+        std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)
+        : tensorflow::EagerNode(),
+          context_(context),
+          handle_to_delete_(std::move(handle_to_delete)) {
+      context_->Ref();
+    }
+
+    ~ClientTensorHandleDeleteNode() override { context_->Unref(); }
+
+    Status Run() override {
+      VLOG(3) << "ServerContext: Deleting tensor handle "
+              << handle_to_delete_->op_id << ":"
+              << handle_to_delete_->output_num;
+      return context_->Context()->RemoteMgr()->DeleteTensorHandle(
+          *handle_to_delete_);
+    }
+
+    void Abort(Status status) override {}
+
+   private:
+    // Owns one reference.
+    ServerContext* const context_;
+    const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_;
+  };
+
  private:
-  Status ExecuteOp(const Operation& operation, ServerContext* server_context,
+  Status ExecuteOp(const Operation& operation, EagerContext* eager_context,
                    QueueResponse* queue_response);
   const WorkerEnv* const env_;  // Not owned.
 
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index ff07b9f..1c81b76 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -125,6 +125,8 @@
     auto* input = operation->add_inputs();
     input->set_op_id(tensor_handle_pair.first);
     input->set_output_num(tensor_handle_pair.second);
+    input->set_op_device(device);
+    input->set_device(device);
   }
 
   for (const auto& attr_entry : attrs) {
@@ -379,6 +381,49 @@
                                                &close_context_response));
 }
 
+// Test requests sent to the eager service on master.
+TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
+  tensorflow::Rendezvous* rendezvous =
+      new tensorflow::IntraProcessRendezvous(device_mgr_.get());
+  // Create a master eager context.
+  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
+      device_mgr_.get(), false, rendezvous, GetDefaultCustomKernelCreator(),
+      nullptr);
+  const uint64 context_id = random::New64();
+
+  // Set RemoteMgr to ctx.
+  auto remote_mgr =
+      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
+  TF_ASSERT_OK(ctx->InitializeRemoteWorker(nullptr, nullptr, {}, context_id,
+                                           nullptr, std::move(remote_mgr)));
+
+  TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+  SendTensorRequest send_tensor_request;
+  send_tensor_request.set_context_id(context_id);
+  send_tensor_request.set_op_id(1);
+  SetTensorProto(send_tensor_request.add_tensors());
+  SendTensorResponse send_tensor_response;
+
+  // Unable to handle the request since there is no eager context.
+  Status status = eager_service_impl.SendTensor(&send_tensor_request,
+                                                &send_tensor_response);
+  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
+  EXPECT_TRUE(absl::StrContains(
+      status.error_message(),
+      "Unable to find a context_id matching the specified one"));
+
+  // The request can be handled after adding the master eager context to
+  // service.
+  TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
+  TF_ASSERT_OK(eager_service_impl.SendTensor(&send_tensor_request,
+                                             &send_tensor_response));
+  ctx->Unref();
+}
+
 TEST_F(EagerServiceImplTest, KeepAliveTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
 
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
new file mode 100644
index 0000000..6d239ea
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -0,0 +1,305 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
+
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace eager {
+
+namespace {
+
+void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
+  remote_op->set_name(op->Name());
+
+  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
+  remote_op->set_device(op->Device()->name());
+}
+
+Status CreateUncachedKernelAndDeviceOp(
+    EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
+  EagerContext* ctx = op->EagerContext();
+  Device* device = op->Device();
+
+  FunctionLibraryRuntime* flr = ctx->func_lib(device);
+  if (flr == nullptr) {
+    return errors::Unavailable(
+        "Unable to find a FunctionLibraryRuntime corresponding to device ",
+        device->name());
+  }
+
+  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx->runner();
+  kernel->reset(new KernelAndDeviceOp(
+      ctx->GetRendezvous(), ctx->LogMemory(), flr, runner,
+      ctx->GetCollectiveExecutorHandle(), ctx->HostCPU()));
+
+  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
+  return kernel->get()->Init(ndef, nullptr);
+}
+
+// This gets a unique wire ID. We add a random identifier so that if the
+// worker has other clients that it is servicing, we don't have any collision.
+string GetUniqueWireID() {
+  static tensorflow::uint64 random_seed = random::New64();
+  static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
+  static tensorflow::int64 wireid GUARDED_BY(wireid_mutex) = 0;
+  tensorflow::mutex_lock l(wireid_mutex);
+  return strings::StrCat(random_seed, "_", wireid++);
+}
+
+}  // namespace
+
+RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
+                               TensorHandle* src, TensorHandle* dst,
+                               Device* recv_device, uint64 recv_op_id)
+    : EagerNode(),
+      src_(src),
+      ctx_(ctx),
+      executor_(executor),
+      send_device_(src->DeviceOrHostCPU(ctx)),
+      recv_device_(recv_device),
+      wire_id_(GetUniqueWireID()),
+      recv_op_id_(recv_op_id),
+      captured_state_(std::make_shared<CapturedSharedState>(dst)) {
+  DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
+  src_->Ref();
+  ctx_->Ref();
+}
+
+Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  op->AddInput(src_);
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector(1);
+  TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
+
+  return kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr);
+}
+
+Status RemoteCopyNode::StartSend() {
+  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
+  // functionality here instead of constructing an Op.
+  const AttrTypeMap* types;
+  bool is_function = false;
+  Status status = AttrTypeMapForOp("_Send", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->SetSendStatus(status);
+    return status;
+  }
+  DCHECK(!is_function);
+  EagerOperation op(ctx_, "_Send", /*is_function=*/false, types);
+
+  op.SetDevice(send_device_);
+
+  op.MutableAttrs()->Set("tensor_name", wire_id_);
+  op.MutableAttrs()->Set("send_device", send_device_->name());
+  op.MutableAttrs()->Set(
+      "send_device_incarnation",
+      static_cast<int64>(send_device_->attributes().incarnation()));
+  op.MutableAttrs()->Set("recv_device", recv_device_->name());
+  op.MutableAttrs()->Set("client_terminated", false);
+
+  op.MutableAttrs()->Set("T", src_->dtype);
+
+  DCHECK(send_device_ != nullptr);
+
+  if (send_device_->IsLocal()) {
+    status = RunLocalSend(&op);
+    captured_state_->SetSendStatus(status);
+    return status;
+  } else {
+    // Prepare the request
+    EnqueueRequest request;
+    request.set_context_id(ctx_->GetContextId());
+    auto* remote_op = request.add_queue()->mutable_operation();
+    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
+        src_, remote_op->add_inputs(), src_->device(),
+        src_->DeviceOrHostCPU(ctx_)->name());
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
+
+    PrepareRemoteOp(remote_op, &op);
+    remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
+
+    // Issue the RPC
+    eager::EagerClient* eager_client;
+    status = ctx_->GetClient(send_device_, &eager_client);
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
+
+    const std::shared_ptr<CapturedSharedState>& captured_state =
+        captured_state_;
+    EnqueueResponse* response = new EnqueueResponse;
+    // If StartRecv fails very quickly, `this` can be destroyed before the
+    // callback below is executed. So, we can't capture `this`.
+    eager_client->StreamingEnqueueAsync(
+        &request, response, [response, captured_state](const Status& s) {
+          captured_state->SetSendStatus(s);
+          if (!s.ok()) {
+            captured_state->recv_cancellation()->StartCancel();
+          }
+          delete response;
+        });
+    return Status::OK();
+  }
+}
+
+Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
+                                    std::vector<Tensor>* outputs) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector;
+  return kernel->Run(input_vector, outputs, nullptr, nullptr, nullptr,
+                     captured_state_->recv_cancellation());
+}
+
+Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  auto* remote_op = request.add_queue()->mutable_operation();
+  PrepareRemoteOp(remote_op, op);
+  remote_op->set_id(recv_op_id_);
+
+  eager::EagerClient* eager_client;
+  Status status = ctx_->GetClient(recv_device_, &eager_client);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
+
+  // Don't issue the recv until send has completed.
+  //  - local send will complete very quickly.
+  //  - remote send will take some time, but remote->remote copy is
+  //    probably rare enough that we don't care much.
+  // Blocks until send has completed.
+  Status send_status = captured_state_->GetSendStatus();
+  if (!send_status.ok()) {
+    captured_state_->dst()->Poison(send_status);
+    return send_status;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  Device* recv_device = recv_device_;
+  eager_client->StreamingEnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device](const Status& s) {
+        if (s.ok()) {
+          Status status = captured_state->dst()->SetRemoteShape(
+              response->queue_response(0).shape(0), recv_device);
+          if (!status.ok()) {
+            LOG(ERROR) << "Ignoring an error encountered when setting remote "
+                          "shape of tensor received by remote Recv op: "
+                       << status.ToString()
+                       << "\nThis should never happen. "
+                          "Please file an issue with the TensorFlow Team.";
+          }
+        } else {
+          captured_state->dst()->Poison(s);
+        }
+        delete response;
+      });
+
+  return Status::OK();
+}
+
+Status RemoteCopyNode::StartRecv() {
+  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
+  // functionality here instead of constructing an Op.
+  const AttrTypeMap* types;
+  bool is_function = false;
+  Status status = AttrTypeMapForOp("_Recv", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
+  DCHECK(!is_function);
+  EagerOperation op(ctx_, "_Recv", /*is_function=*/false, types);
+
+  op.SetDevice(recv_device_);
+
+  op.MutableAttrs()->Set("tensor_name", wire_id_);
+  op.MutableAttrs()->Set("send_device", send_device_->name());
+  op.MutableAttrs()->Set(
+      "send_device_incarnation",
+      static_cast<int64>(send_device_->attributes().incarnation()));
+  op.MutableAttrs()->Set("recv_device", recv_device_->name());
+  op.MutableAttrs()->Set("client_terminated", false);
+
+  op.MutableAttrs()->Set("tensor_type", src_->dtype);
+
+  if (recv_device_->IsLocal()) {
+    std::vector<Tensor> outputs(1);
+    status = RunLocalRecv(&op, &outputs);
+    if (!status.ok()) {
+      captured_state_->dst()->Poison(status);
+      return status;
+    }
+    return captured_state_->dst()->SetTensor(outputs[0]);
+  } else {
+    // Handles captured_state_->dst_ internally.
+    return RunRemoteRecv(&op);
+  }
+}
+
+Status RemoteCopyNode::Run() {
+  Status s = StartSend();
+  if (!s.ok()) {
+    Abort(s);
+    return s;
+  }
+
+  // StartRecv() takes care of doing the right thing to dst handle.
+  // No need to poison it after this point.
+  s = StartRecv();
+  if (!s.ok() && errors::IsCancelled(s)) {
+    Status send_status = captured_state_->GetSendStatus();
+    if (!send_status.ok()) {
+      // In this case, Recv is cancelled because the Send op failed. Return the
+      // status of the Send op instead.
+      s = send_status;
+    }
+  }
+
+  src_->Unref();
+  ctx_->Unref();
+  return s;
+}
+
+void RemoteCopyNode::Abort(Status status) {
+  captured_state_->dst()->Poison(status);
+  src_->Unref();
+  ctx_->Unref();
+}
+
+}  // namespace eager
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
new file mode 100644
index 0000000..f642901
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
@@ -0,0 +1,147 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace eager {
+
+// This node supports copying a tensor in the following way:
+// - Remote -> Local:
+//   We don't block on the remote _Send op and start executing the local
+//   _Recv immediately after issuing the remote _Send. The local _Recv
+//   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
+//   blocks until the tensor is received. If the remote _Send (or some op
+//   before it) fails, the local callback we give to EnqueueAsync will run
+//   and call CancellationManager.StartCancel(). The blocked local _Recv will
+//   get this notification and return with a cancelled error.
+//
+// - Local -> Remote:
+//   The local _Send op is synchronous and non-blocking, thus it should complete
+//   quickly. We issue remote _Recv RPC only after local _Send completes
+//   successfully. At this point, the tensor to be sent is in the local
+//   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
+//   to appear.
+//
+// - Remote -> Remote:
+//   We could issue both remote ops asynchronously, but if remote _Send (or some
+//   op before it) fails, we don't have a good way of cancelling the remote
+//   _Recv. The remote _Recv will deadlock in this case. The current approach
+//   to deal with this issue is to wait for remote _Send to complete before
+//   issuing remote _Recv RPC. Another option is to close the whole streaming
+//   RPC that contains the deadlocked remote _Recv. This would not unblock the
+//   deadlocked RPC on the remote machine without some extra code. Luckily, the
+//   remote -> remote case seems to be fairly rare at this point. So, the
+//   current partially synchronous approach seems fine.
+//
+// To copy a tensor within a host, please use copy_to_device_node instead.
+class RemoteCopyNode : public EagerNode {
+ public:
+  RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
+                 TensorHandle* dst, Device* recv_device, uint64 recv_op_id);
+
+  ~RemoteCopyNode() override {}
+
+  Status Run() override;
+
+  void Abort(Status status) override;
+
+ private:
+  // Runs the _Send operation locally or remotely.
+  // An error return value indicates that _Send did not run successfully.
+  // An OK return value does NOT necessarily indicate that _Send has completed
+  // successfully. It might still fail after this method returns.
+  // StartSend() makes sure that captured_state_->send_status_ is set to the
+  // final _Send status after captured_state->send_done_.WaitForNotification()
+  // returns.
+  Status StartSend();
+
+  // Synchronously runs local send `op` and returns its status.
+  Status RunLocalSend(EagerOperation* op);
+
+  // Runs the _Recv operation locally or remotely.
+  // An error return value indicates that _Recv did not run successfully. It
+  // does not indicate that _Send op has completed since StartRecv could have
+  // encountered an error before waiting for _Send's completion.
+  // An OK return value does NOT necessarily indicate that _Recv has completed
+  // successfully (it does now, but won't when streaming RPCs are turned on).
+  // StartRecv() makes sure that dst_ tensor handle is handled correctly
+  // (potentially after this methods returns); a tensor is set in the local
+  // case, a remote shape is set in the remote case, the dst_ handle is
+  // poisoned in either case if there is an error.
+  Status StartRecv();
+
+  // Synchronously runs local receive `op` and returns its status.
+  // Does not wait for the send to complete before running receive.
+  Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
+
+  // Waits for send to complete, then issues remote receive `op` and
+  // returns its status.
+  Status RunRemoteRecv(EagerOperation* op);
+
+  // State that is captured by Send and/or Recv callbacks (depending on which
+  // one(s) is remote) and outlives this node in the case of remote->remote
+  // copy.
+  class CapturedSharedState {
+   public:
+    explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
+    ~CapturedSharedState() { dst_->Unref(); }
+
+    void SetSendStatus(Status status) {
+      send_status_.Update(status);
+      send_done_.Notify();
+    }
+
+    Status GetSendStatus() {
+      send_done_.WaitForNotification();
+      return send_status_;
+    }
+
+    TensorHandle* dst() { return dst_; }
+    CancellationManager* recv_cancellation() { return &recv_cancellation_; }
+
+   private:
+    TensorHandle* const dst_;
+    CancellationManager recv_cancellation_;
+    // send_status_ is safe to read only after send_done_.WaitForNotification()
+    // has returned.
+    Status send_status_;
+    Notification send_done_;
+  };
+
+  TensorHandle* const src_;
+  EagerContext* const ctx_;
+  EagerExecutor* const executor_;
+  Device* const send_device_;
+  Device* const recv_device_;
+  const string wire_id_;
+  const uint64 recv_op_id_;
+
+  std::shared_ptr<CapturedSharedState> captured_state_;
+};
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
new file mode 100644
index 0000000..51a95b0
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status RemoteExecuteNode::Run() {
+  EnqueueResponse* response = new EnqueueResponse;
+
+  const gtl::InlinedVector<TensorHandle*, 4>& inputs = inputs_;
+  const gtl::InlinedVector<TensorHandle*, 2>& retvals = retvals_;
+  Device* device = device_;
+
+  // Filled and used only when VLOG(3) is on.
+  string rpc_description;
+  if (VLOG_IS_ON(3)) {
+    std::vector<string> ops;
+    ops.reserve(request_->queue_size());
+    for (const QueueItem& item : request_->queue()) {
+      if (item.has_operation()) {
+        ops.push_back(item.operation().name());
+      } else {
+        ops.push_back(absl::StrCat("DeleteHandle(",
+                                   item.handle_to_decref().op_id(), ":",
+                                   item.handle_to_decref().output_num(), ")"));
+      }
+    }
+    rpc_description =
+        absl::StrCat("RemoteOperation(", absl::StrJoin(ops, ", "), ")");
+  }
+  VLOG(3) << "Issuing: " << rpc_description;
+
+  eager_client_->StreamingEnqueueAsync(
+      request_.get(), response,
+      [inputs, retvals, response, device,
+       rpc_description](const Status& status) {
+        for (auto handle : inputs) {
+          handle->Unref();
+        }
+        if (status.ok()) {
+          VLOG(3) << "Completed successfully: " << rpc_description;
+        } else {
+          VLOG(3) << "Failed: " << rpc_description << " with status "
+                  << status.ToString();
+        }
+        for (size_t i = 0; i < retvals.size(); ++i) {
+          if (status.ok()) {
+            Status s = retvals[i]->SetRemoteShape(
+                response->queue_response(0).shape(i), device);
+            if (!s.ok()) {
+              LOG(ERROR) << "Ignoring an error encountered when setting "
+                            "remote shape of tensor handle: "
+                         << retvals[i] << " with status: " << status.ToString()
+                         << "\nThis should never happen. "
+                            "Please file an issue with the TensorFlow Team.";
+            }
+          } else {
+            retvals[i]->Poison(status);
+          }
+          retvals[i]->Unref();
+        }
+        delete response;
+      });
+  return Status::OK();
+}
+
+}  // namespace eager
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 761efff..9dab1f7 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
 
+#include <cstddef>
+
 #include "absl/types/span.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
@@ -52,37 +54,7 @@
     }
   }
 
-  Status Run() override {
-    EnqueueResponse response;
-    Status status;
-    Notification n;
-    eager_client_->EnqueueAsync(request_.get(), &response,
-                                [&n, &status](const Status& s) {
-                                  status.Update(s);
-                                  n.Notify();
-                                });
-    n.WaitForNotification();
-
-    if (!status.ok()) {
-      Abort(status);
-      return status;
-    }
-
-    for (int i = 0; i < retvals_.size(); i++) {
-      Status s = retvals_[i]->SetRemoteShape(
-          response.queue_response(0).shape(i), device_);
-      if (!s.ok()) {
-        retvals_[i]->Poison(s);
-      }
-      retvals_[i]->Unref();
-    }
-
-    for (auto handle : inputs_) {
-      handle->Unref();
-    }
-
-    return status;
-  }
+  Status Run() override;
 
   void Abort(Status status) override {
     for (auto handle : retvals_) {
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index a7e0027..943c160 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 namespace eager {
@@ -32,10 +33,9 @@
   }
 }
 
-Status RemoteMgr::GetTensorHandle(
+Status RemoteMgr::GetTensorHandleImpl(
     const RemoteTensorHandleInternal& remote_handle,
     tensorflow::TensorHandle** handle) {
-  tf_shared_lock l(remote_tensor_handle_mu_);
   auto iter = remote_tensor_handle_map_.find(remote_handle);
   if (iter == remote_tensor_handle_map_.end()) {
     return errors::InvalidArgument(
@@ -48,6 +48,28 @@
   return Status::OK();
 }
 
+Status RemoteMgr::GetTensorHandle(
+    const RemoteTensorHandleInternal& remote_handle,
+    tensorflow::TensorHandle** handle) {
+  tf_shared_lock l(remote_tensor_handle_mu_);
+  return GetTensorHandleImpl(remote_handle, handle);
+}
+
+Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
+                                        int64* op_id, int32* output_num) {
+  TF_RETURN_IF_ERROR(
+      handle->RemoteAddress(handle->device(), op_id, output_num));
+  tensorflow::TensorHandle* h;
+  TF_RETURN_IF_ERROR(
+      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
+  if (handle != h) {
+    return errors::Internal(
+        "Found two different tensor handles with the same op_id:", *op_id,
+        " and output_num:", *output_num);
+  }
+  return Status::OK();
+}
+
 Status RemoteMgr::DeleteTensorHandle(
     const RemoteTensorHandleInternal& remote_handle) {
   mutex_lock l(remote_tensor_handle_mu_);
@@ -66,22 +88,54 @@
 
 Status RemoteMgr::SerializeRemoteTensorHandle(TensorHandle* in,
                                               RemoteTensorHandle* out,
-                                              Device* device) {
-  // TODO(fishx): support serializing local tensor handle.
+                                              Device* device,
+                                              const string& device_name) {
   int64 op_id;
   int32 output_num;
-  TF_RETURN_IF_ERROR(in->RemoteAddress(device, &op_id, &output_num));
+  if (!in->RemoteAddress(device, &op_id, &output_num).ok()) {
+    mutex_lock l(remote_tensor_handle_mu_);
+    if (!GetRemoteTensorHandle(in, &op_id, &output_num).ok()) {
+      op_id = NextOpId();
+      output_num = 0;
+      in->SetRemoteOpIdAndOutputNumToLocalTensorHandle(op_id, output_num);
+      in->Ref();
+      remote_tensor_handle_map_.emplace(
+          RemoteTensorHandleInternal(op_id, output_num), in);
+    }
+  }
   out->Clear();
   out->set_op_id(op_id);
   out->set_output_num(output_num);
+  out->set_op_device(in->op_device() ? in->op_device()->name() : "");
+  out->set_device(device_name);
+  out->set_dtype(in->dtype);
   return Status::OK();
 }
 
 Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                 TensorHandle** out) {
-  // TODO(fishx): support the case when the remote tensor handle does not exist
-  // in the map.
-  TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
+  Device* device;
+  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
+      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
+    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
+    (*out)->Ref();
+  } else {
+    // Create a remote TensorHandle for remote tensors which have not been
+    // copied to the local worker yet.
+    const string& device_name =
+        in.op_device().empty() ? in.device() : in.op_device();
+    TF_RETURN_IF_ERROR(
+        parent_->FindDeviceFromName(device_name.c_str(), &device));
+    EagerClient* eager_client;
+    TF_RETURN_IF_ERROR(parent_->GetClient(device, &eager_client));
+    auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
+        in.op_id(), in.output_num(), eager_client, parent_->GetContextId(),
+        parent_);
+    remote_handle_data->ReleaseRemoteTensorHandle();
+    TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
+        std::move(remote_handle_data), in.dtype(), device, parent_, out));
+  }
+
   return Status::OK();
 }
 
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
index 7b4a9bf..44be2d4 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
@@ -29,7 +29,8 @@
 // TODO(fishx): Move remote state from context to this class.
 class RemoteMgr {
  public:
-  explicit RemoteMgr(bool is_master) : is_master_(is_master) {}
+  RemoteMgr(bool is_master, EagerContext* ctx)
+      : is_master_(is_master), parent_(ctx) {}
 
   ~RemoteMgr() {
     for (const auto& entry : remote_tensor_handle_map_) {
@@ -56,13 +57,30 @@
     return next_op_id_++;
   }
 
+  // Serialize a TensorHandle(local/remote) to a RemoteTensorHandle.
   Status SerializeRemoteTensorHandle(TensorHandle* in, RemoteTensorHandle* out,
-                                     Device* device);
+                                     Device* device, const string& device_name);
 
+  // Deserialize a RemoteTensorHandle to a TensorHandle(local/remote).
+  // The output holds a reference to the TensorHandle.
   Status DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                        TensorHandle** out);
 
+ protected:
+  mutex next_id_mutex_;
+  uint64 next_op_id_ GUARDED_BY(next_id_mutex_) = 1;
+
  private:
+  // Returns the op_id and output_num if the given local TensorHandle exists in
+  // remote_tensor_handle_map_.
+  Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
+                               int64* op_id, int32* output_num)
+      SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
+
+  Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle,
+                             tensorflow::TensorHandle** handle)
+      SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
+
   bool is_master_;
 
   using RemoteTensorHandleMap =
@@ -70,13 +88,13 @@
                    RemoteTensorHandleInternalHash,
                    RemoteTensorHandleInternalEquals>;
   mutex remote_tensor_handle_mu_;
-  // This map maintains the TensorHandles that is required by remote worker
-  // in the cluster.
+  // This map maintains the TensorHandles that are required by remote workers
+  // in the cluster. Each map key is generated by the master, so it should be
+  // globally unique. This map owns references on the handles it contains.
   RemoteTensorHandleMap remote_tensor_handle_map_
       GUARDED_BY(remote_tensor_handle_mu_);
 
-  mutex next_id_mutex_;
-  uint64 next_op_id_ GUARDED_BY(next_id_mutex_) = 1;
+  EagerContext* parent_;  // not owned.
 };
 
 }  // namespace eager
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
new file mode 100644
index 0000000..f5f0106
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
+
+#include <memory>
+
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/eager_service.pb.h"
+
+namespace tensorflow {
+namespace eager {
+namespace {
+
+class TestRemoteMgr : public RemoteMgr {
+ public:
+  TestRemoteMgr(bool is_master, EagerContext* ctx)
+      : RemoteMgr(is_master, ctx) {}
+
+  uint64 OpId() {
+    tf_shared_lock l(next_id_mutex_);
+    return next_op_id_;
+  }
+};
+
+class RemoteMgrTest : public ::testing::Test {
+ public:
+  RemoteMgrTest() {
+    std::vector<std::unique_ptr<Device>> devices;
+    devices.push_back(
+        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
+    local_device_ = devices.back().get();
+    devices.push_back(
+        DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
+    remote_device_ = devices.back().get();
+    auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
+    context_id_ = random::New64();
+    tensorflow::Rendezvous* rendezvous =
+        new tensorflow::IntraProcessRendezvous(device_mgr.get());
+    ctx_ = new tensorflow::EagerContext(
+        SessionOptions(),
+        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
+        device_mgr.release(), true, rendezvous, GetDefaultCustomKernelCreator(),
+        nullptr);
+  }
+
+  ~RemoteMgrTest() override { ctx_->Unref(); }
+
+  Device* local_device_;
+  Device* remote_device_;
+  uint64 context_id_;
+  EagerContext* ctx_;
+};
+
+TEST_F(RemoteMgrTest, LocalTensorHandle) {
+  TestRemoteMgr remote_mgr(true, ctx_);
+  Tensor t(DT_FLOAT, TensorShape({0}));
+
+  TensorHandle* handle;
+  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
+  EXPECT_EQ(nullptr, handle->device());
+  EXPECT_EQ(local_device_, handle->DeviceOrHostCPU(ctx_));
+  const uint64 op_id = remote_mgr.OpId();
+  EXPECT_EQ(1, op_id);
+  RemoteTensorHandle remote_handle;
+  TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
+      handle, &remote_handle, handle->device(),
+      handle->DeviceOrHostCPU(ctx_)->name()));
+  EXPECT_EQ(2, remote_mgr.OpId());
+  EXPECT_EQ(op_id, remote_handle.op_id());
+  EXPECT_EQ(0, remote_handle.output_num());
+  EXPECT_EQ(local_device_->name(), remote_handle.device());
+
+  TensorHandle* deserialized_handle;
+  TF_ASSERT_OK(remote_mgr.DeserializeRemoteTensorHandle(remote_handle,
+                                                        &deserialized_handle));
+  tensorflow::TensorHandle* h;
+  TF_EXPECT_OK(remote_mgr.GetTensorHandle(
+      RemoteTensorHandleInternal(remote_handle), &h));
+  TF_ASSERT_OK(
+      remote_mgr.DeleteTensorHandle(RemoteTensorHandleInternal(remote_handle)));
+  EXPECT_FALSE(
+      remote_mgr.GetTensorHandle(RemoteTensorHandleInternal(remote_handle), &h)
+          .ok());
+
+  deserialized_handle->Unref();
+  handle->Unref();
+}
+
+TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
+  RemoteMgr remote_mgr(false, ctx_);
+  Tensor t(DT_FLOAT, TensorShape({0}));
+
+  TensorHandle* handle;
+  TF_ASSERT_OK(
+      TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle));
+  const uint64 op_id = 2;
+  const int output_num = 3;
+  auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
+      op_id, output_num, t.shape(), /*eager_client=*/nullptr, context_id_,
+      ctx_);
+  TF_ASSERT_OK(
+      handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_));
+  RemoteTensorHandle remote_handle;
+  TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
+      handle, &remote_handle, remote_device_, remote_device_->name()));
+  EXPECT_EQ(op_id, remote_handle.op_id());
+  EXPECT_EQ(output_num, remote_handle.output_num());
+  EXPECT_EQ(remote_device_->name(), remote_handle.device());
+  handle->Unref();
+}
+
+TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
+  RemoteMgr remote_mgr(false, ctx_);
+  Tensor t(DT_FLOAT, TensorShape({0}));
+
+  const uint64 op_id = 3;
+  const int output_num = 1;
+  TensorHandle* handle;
+  TF_ASSERT_OK(TensorHandle::CreateRemoteHandle(
+      op_id, output_num, t.shape(), /*eager_client=*/nullptr, context_id_,
+      DT_FLOAT, remote_device_,
+      /*resource_device=*/nullptr, ctx_, &handle));
+  RemoteTensorHandle remote_handle;
+  TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
+      handle, &remote_handle, remote_device_, remote_device_->name()));
+  EXPECT_EQ(op_id, remote_handle.op_id());
+  EXPECT_EQ(output_num, remote_handle.output_num());
+  EXPECT_EQ(remote_device_->name(), remote_handle.device());
+  handle->Unref();
+}
+
+}  // namespace
+}  // namespace eager
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
index d3a7c60..85ad20e 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
@@ -41,16 +41,31 @@
   handle_to_decref->set_op_id(op_id);
   handle_to_decref->set_output_num(output_num);
 
+  VLOG(3) << "Sending request to delete " << request->DebugString();
   std::unique_ptr<EagerNode> node(
       absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request),
                                                         eager_client));
-  Status s = ctx->Async() ? ctx->ExecutorAdd(std::move(node)) : node->Run();
-  if (!s.ok()) {
-    LOG(ERROR) << "Unable to destroy remote tensor handles: "
-               << s.error_message();
+  auto* executor = ctx->Executor();
+  if (executor->Async()) {
+    Status status = executor->Add(std::move(node));
+    if (!status.ok()) {
+      LOG(ERROR) << "Unable to destroy remote tensor handles: "
+                 << status.error_message();
+    }
+  } else {
+    // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
+    // to send out the destroy request in a new thread to avoid deadlock.
+    auto* released_node = node.release();
+    (*ctx->runner())([released_node] {
+      Status status = released_node->Run();
+      if (!status.ok()) {
+        LOG(ERROR) << "Unable to destroy remote tensor handles: "
+                   << status.error_message();
+      }
+      delete released_node;
+    });
   }
 }
-
 }  // namespace
 
 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 81d6412..5d06bf9 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -179,14 +179,14 @@
   }
 
   std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
-  for (const auto& partition : partitions) {
+  for (auto& partition : partitions) {
     std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
     GraphConstructorOptions device_opts;
     // There are internal operations (e.g., send/recv) that we now allow.
     device_opts.allow_internal_ops = true;
     device_opts.expect_device_spec = true;
-    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
-                                              device_graph.get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+        device_opts, std::move(partition.second), device_graph.get()));
     partition_graphs.emplace(partition.first, std::move(device_graph));
   }
 
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index bd300d6..50b381b 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -3,20 +3,20 @@
 
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cc_binary",
     "tf_cc_test",
     "tf_cuda_library",
-    "tf_cc_binary",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_kernel_tests_linkstatic",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
@@ -295,6 +295,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core/common_runtime/eager:context",
         "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
         "//tensorflow/core/distributed_runtime:device_resolver_distributed",
         "//tensorflow/core/distributed_runtime:graph_mgr",
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
index 0b18136..1ac8e68 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
@@ -27,6 +27,7 @@
         "//tensorflow:grpc++",
         "//tensorflow/core:eager_service_proto_cc",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core/distributed_runtime/eager:eager_client",
         "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
         "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag",
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index da5d43a..c3764a9 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -23,10 +23,19 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/protobuf/eager_service.pb.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
 namespace eager {
 namespace {
+bool EnableStreaming() {
+  bool result;
+  // TODO(b/139210648): Turn on this flag by default.
+  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
+                                 false, &result));
+  return result;
+}
+
 class GrpcEagerClient : public EagerClient {
  public:
   GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
@@ -40,7 +49,8 @@
       override {                                                          \
     new RPCState<protobuf::Message>(                                      \
         &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
-        response, std::move(done), nullptr, nullptr);                     \
+        response, std::move(done), nullptr, nullptr, /*max_retries=*/10,  \
+        /*fail_fast=*/true);                                              \
   }
 
   CLIENT_METHOD(CreateContext);
@@ -59,8 +69,13 @@
         &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
         response, std::move(done), nullptr, nullptr);
 
-    if (enqueue_dispatchers_.find(request->context_id()) !=
-        enqueue_dispatchers_.end()) {
+    VLOG(1) << "Sending RPC to close remote eager context "
+            << request->DebugString();
+
+    mutex_lock l(mu_);
+    const auto& it = enqueue_dispatchers_.find(request->context_id());
+    if (it != enqueue_dispatchers_.end()) {
+      it->second.CancelCall();
       enqueue_dispatchers_.erase(request->context_id());
     } else {
       LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
@@ -71,25 +86,40 @@
   void StreamingEnqueueAsync(const EnqueueRequest* request,
                              EnqueueResponse* response,
                              StatusCallback done) override {
-    auto it = enqueue_dispatchers_.find(request->context_id());
-    if (enqueue_dispatchers_.find(request->context_id()) ==
-        enqueue_dispatchers_.end()) {
-      auto it_and_bool = enqueue_dispatchers_.emplace(
-          std::piecewise_construct,
-          std::forward_as_tuple(request->context_id()),
-          std::forward_as_tuple(
-              &stub_, cq_, "/tensorflow.eager.EagerService/StreamingEnqueue"));
-      it = it_and_bool.first;
+    if (EnableStreaming()) {
+      tf_shared_lock l(mu_);
+      auto it = enqueue_dispatchers_.find(request->context_id());
+      if (enqueue_dispatchers_.find(request->context_id()) ==
+          enqueue_dispatchers_.end()) {
+        auto it_and_bool = enqueue_dispatchers_.emplace(
+            std::piecewise_construct,
+            std::forward_as_tuple(request->context_id()),
+            std::forward_as_tuple(
+                &stub_, cq_,
+                "/tensorflow.eager.EagerService/StreamingEnqueue"));
+        it = it_and_bool.first;
+      }
+      it->second.SendNextRequest(*request, response, std::move(done));
+    } else {
+      Notification n;
+      Status status;
+      EnqueueAsync(request, response, [&n, &status](const Status& s) {
+        status.Update(s);
+        n.Notify();
+      });
+      n.WaitForNotification();
+      done(status);
     }
-
-    it->second.SendNextRequest(*request, response, std::move(done));
   }
 
  private:
   ::grpc::GenericStub stub_;
   ::grpc::CompletionQueue* cq_;
+
+  mutable mutex mu_;
+
   std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
-      enqueue_dispatchers_;
+      enqueue_dispatchers_ GUARDED_BY(mu_);
 };
 
 class GrpcEagerClientCache : public EagerClientCache {
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index b3c2001..869fe14 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -32,6 +32,11 @@
   cq_ = server_builder->AddCompletionQueue();
 }
 
+Status GrpcEagerServiceImpl::CreateMasterContext(
+    const tensorflow::uint64 context_id, EagerContext* context) {
+  return local_impl_.CreateMasterContext(context_id, context);
+}
+
 void GrpcEagerServiceImpl::HandleRPCsLoop() {
 #define ENQUEUE_REQUEST(method)                                            \
   do {                                                                     \
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index cea7b69..5ee8f33 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -44,6 +44,10 @@
                        ::grpc::ServerBuilder* server_builder);
   virtual ~GrpcEagerServiceImpl() {}
 
+  // Create a master context in eager service.
+  Status CreateMasterContext(const tensorflow::uint64 context_id,
+                             EagerContext* context);
+
   void HandleRPCsLoop() override;
   void Shutdown() override;
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index a313588..f70d608 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -86,6 +86,10 @@
       LOG(ERROR) << "Invalid compression algorithm: "
                  << rpc_options->compression_algorithm();
     }
+    if (rpc_options->disable_session_connection_sharing()) {
+      VLOG(5) << "Disabling TCP connection sharing";
+      args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
+    }
   }
   return args;
 }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index 3635caf..8be6f1d 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -187,8 +187,8 @@
 
 void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
                                GrpcCall* call) {
-  auto address = address_t.flat<string>();
-  auto method = method_t.flat<string>();
+  auto address = address_t.flat<tstring>();
+  auto method = method_t.flat<tstring>();
   // Stubs are maintained by the GrpcRPCFactory class and will be
   // deleted when the class is destroyed.
   ::grpc::GenericStub* singleton_stub = nullptr;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 78751ff..c8eeaa9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -375,6 +375,13 @@
   }
 }
 
+Status GrpcServer::AddMasterEagerContextToEagerService(
+    const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
+  auto* eager_service =
+      static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
+  return eager_service->CreateMasterContext(context_id, context);
+}
+
 Status GrpcServer::Stop() {
   mutex_lock l(mu_);
   switch (state_) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 6f3bdd2..521c8f2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -22,7 +22,7 @@
 
 #include "grpcpp/grpcpp.h"
 #include "grpcpp/security/credentials.h"
-
+#include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/distributed_runtime/master_env.h"
@@ -95,6 +95,11 @@
   WorkerEnv* worker_env() { return &worker_env_; }
   MasterEnv* master_env() { return &master_env_; }
 
+  // Add master eager context to local eager service in order to handle enqueue
+  // requests from remote workers.
+  Status AddMasterEagerContextToEagerService(
+      const tensorflow::uint64 context_id, tensorflow::EagerContext* context);
+
  protected:
   virtual Status GetPort(int* port) const;
   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index c38b89b..7f2906e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -501,7 +501,7 @@
   Graph graph(OpRegistry::Global());
   Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
   for (int i = 0; i < 4; ++i) {
-    a_tensor.flat<string>()(i) = "hello, world";
+    a_tensor.flat<tstring>()(i) = "hello, world";
   }
   Node* a = test::graph::Constant(&graph, a_tensor);
   Node* b = test::graph::Identity(&graph, a);
@@ -525,7 +525,7 @@
         ASSERT_EQ(outputs[0].dtype(), DT_STRING);
         ASSERT_EQ(outputs[0].NumElements(), 4);
         for (int i = 0; i < outputs[0].NumElements(); ++i) {
-          EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
+          EXPECT_EQ(outputs[0].flat<tstring>()(i), "hello, world");
         }
         TF_CHECK_OK(session->Close());
       } else {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.cc b/tensorflow/core/distributed_runtime/rpc/grpc_state.cc
index 75e4153..b05a54c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.cc
@@ -26,6 +26,8 @@
       return "kRequestWriteCompleted";
     case UntypedStreamingRPCState::Tag::TagType::kResponseReadCommpleted:
       return "kResponseReadCommpleted";
+    case UntypedStreamingRPCState::Tag::TagType::kCallFinished:
+      return "kCallFinished";
   }
 }
 
@@ -44,6 +46,9 @@
     case TagType::kResponseReadCommpleted:
       streaming_state_->ResponseReadCompleted(ok);
       break;
+    case TagType::kCallFinished:
+      streaming_state_->CallFinished(ok);
+      break;
   }
   streaming_state_->Unref();  // Ref acquired when tag was handed to grpc.
 }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 10c9af3..1567d89 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -45,9 +45,10 @@
   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
            const ::grpc::string& method, const protobuf::Message& request,
            Response* response, StatusCallback done, CallOptions* call_opts,
-           thread::ThreadPool* threadpool, int32 max_retries = 0)
+           thread::ThreadPool* threadpool, int32 max_retries = 0,
+           bool fail_fast = false)
       : RPCState(stub, cq, method, request, response, std::move(done),
-                 call_opts, threadpool, /*fail_fast=*/false,
+                 call_opts, threadpool, fail_fast,
                  /*timeout_in_ms=*/0, max_retries) {}
 
   template <typename Request>
@@ -80,7 +81,7 @@
 
   void StartCall() {
     context_.reset(new ::grpc::ClientContext());
-    context_->set_fail_fast(fail_fast_);
+    context_->set_wait_for_ready(!fail_fast_);
 
     if (timeout_in_ms_ > 0) {
       context_->set_deadline(
@@ -194,6 +195,7 @@
   virtual void CallStarted(bool ok) = 0;
   virtual void RequestWriteCompleted(bool ok) = 0;
   virtual void ResponseReadCompleted(bool ok) = 0;
+  virtual void CallFinished(bool ok) = 0;
 
   virtual string DebugString() const = 0;
 
@@ -204,6 +206,7 @@
       kCallStarted,
       kRequestWriteCompleted,
       kResponseReadCommpleted,
+      kCallFinished,
     };
 
     Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);
@@ -364,7 +367,7 @@
   // manually.
   StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
                     const std::shared_ptr<::grpc::ClientContext>& context)
-      : context_(context), call_(std::move(call)), call_done_(false) {
+      : context_(context), call_(std::move(call)), call_state_(State::kActive) {
     Ref();
     VLOG(3) << "Created new StreamingRPCState " << this;
     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
@@ -396,7 +399,7 @@
     }
 
     mutex_lock l(mu_);
-    if (call_done_) {
+    if (call_state_ != State::kActive) {
       // `done` is not invoked intentionally.
       return false;
     }
@@ -417,7 +420,7 @@
             << ")";
     mutex_lock l(mu_);
     if (!ok) {
-      call_done_ = true;
+      call_state_ = State::kDone;
       return;
     }
     exchanges_.CallStarted();
@@ -429,13 +432,17 @@
     VLOG(3) << "StreamingRPCState(" << this
             << ")::RequestWriteCompleted(ok=" << ok << ")";
     mu_.lock();
-    if (call_done_) {
+    if (call_state_ != State::kActive) {
       mu_.unlock();
       return;
     }
     if (!ok) {
       // unlocks mu_
-      MarkDoneAndCompleteExchanges();
+      MarkDoneAndCompleteExchanges(errors::Internal(
+          "Unexpected ok value at streaming rpc writing. ",
+          "Probably because the completion queue has been shut ",
+          "down or the connection went down. ",
+          context_->debug_error_string()));
       return;
     }
 
@@ -449,13 +456,13 @@
     VLOG(3) << "StreamingRPCState(" << this
             << ")::ResponseReadCompleted(ok=" << ok << ")";
     mu_.lock();
-    if (call_done_) {
+    if (call_state_ != State::kActive) {
       mu_.unlock();
       return;
     }
     if (!ok) {
-      // unlocks mu_
-      MarkDoneAndCompleteExchanges();
+      IssueCallFinishLocked();
+      mu_.unlock();
       return;
     }
 
@@ -477,17 +484,41 @@
     }
   }
 
+  void CallFinished(bool ok) override {
+    VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
+            << ")";
+    mu_.lock();
+    DCHECK(call_state_ != State::kActive);
+    if (call_state_ != State::kFinishing) {
+      mu_.unlock();
+      return;
+    }
+
+    Status s = FromGrpcStatus(call_status_);
+    if (s.ok() && !ok) {
+      s.Update(
+          errors::Internal("unexpected ok value at streaming rpc completion. ",
+                           context_->debug_error_string()));
+    }
+    // unlocks mu_
+    MarkDoneAndCompleteExchanges(s);
+  }
+
   string DebugString() const override {
     mutex_lock l(mu_);
     return exchanges_.DebugString();
   }
 
  private:
-  void MarkDoneAndCompleteExchanges() EXCLUSIVE_LOCKS_REQUIRED(mu_)
+  enum class State {
+    kActive,
+    kFinishing,
+    kDone,
+  };
+
+  void MarkDoneAndCompleteExchanges(Status status) EXCLUSIVE_LOCKS_REQUIRED(mu_)
       UNLOCK_FUNCTION(mu_) {
-    call_done_ = true;
-    Status status = errors::Unknown("gRPC streaming call has ended: ",
-                                    context_->debug_error_string());
+    call_state_ = State::kDone;
     VLOG(2) << "Ending gRPC stremaing call on the client side due to "
             << status.ToString();
     // Swap the exchanges_ into a temporary ExchangeQueue so that we can
@@ -524,6 +555,17 @@
     call_->Read(exchange->response_buf(), &response_read_completed_tag_);
   }
 
+  void IssueCallFinishLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    call_state_ = State::kFinishing;
+    Ref();
+    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
+    // We call finish in response to completed (with error) response reading tag
+    // on some exchange. We let this exchange hang in ResponseReadIssued state.
+    // ExchangeQueue makes sure that there is at most one exchange in this
+    // state. So, no new reads will be issued.
+    call_->Finish(&call_status_, &finished_tag_);
+  }
+
   // Holds state for a single request/response exchange between the client
   // and the server.
   typedef typename UntypedStreamingRPCState::Tag Tag;
@@ -535,7 +577,8 @@
 
   mutable mutex mu_;
   ExchangeQueue exchanges_ GUARDED_BY(mu_);
-  bool call_done_ GUARDED_BY(mu_);
+  State call_state_ GUARDED_BY(mu_);
+  ::grpc::Status call_status_ GUARDED_BY(mu_);
 
   // We can get away with having single instances of these tags per
   // StreamingRPCState because we make sure (as gRPC requires) that
@@ -545,6 +588,7 @@
   Tag call_started_tag_{this, Tag::TagType::kCallStarted};
   Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
   Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCommpleted};
+  Tag finished_tag_{this, Tag::TagType::kCallFinished};
 };
 
 // Creates streaming calls and dispatches requests to them.
@@ -597,6 +641,16 @@
     }
   }
 
+  // Request to cancel the current streaming call. Non-blocking.
+  void CancelCall() {
+    mutex_lock l(mu_);
+    if (state_ == nullptr) {
+      return;
+    }
+    context_->TryCancel();
+    state_ = nullptr;
+  }
+
  private:
   void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     // ClientContext cannot be reused across calls.
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 0818a05..a267371 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -255,7 +255,7 @@
              recv_args, std::move(done));
 
   // Record "call" in active_ so that it can be aborted cleanly.
-  RegisterCall(call);
+  RegisterCall(call, recv_args);
 
   // RendezvousMgr already aborted, shouldn't send RPC call any more
   if (!call->status().ok()) {
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index f54eace..5021853 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
@@ -29,7 +30,7 @@
 // string -> Tensor<string>
 Tensor V(const string& content) {
   Tensor tensor(DT_STRING, TensorShape({}));
-  tensor.scalar<string>()() = content;
+  tensor.scalar<tstring>()() = content;
   return tensor;
 }
 
@@ -37,7 +38,7 @@
 string V(const Tensor& tensor) {
   CHECK_EQ(tensor.dtype(), DT_STRING);
   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
-  return tensor.scalar<string>()();
+  return tensor.scalar<tstring>()();
 }
 
 Rendezvous::ParsedKey MakeKey(const string& s) {
@@ -142,6 +143,56 @@
   }
 }
 
+TEST_F(RpcRendezvousMgrTest, LocalCancel) {
+  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
+      "/job:mnist/replica:1/task:2/cpu:0", 7890,
+      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
+  auto* cm = new CancellationManager();
+  const int64 step_id = 123;
+  RemoteRendezvous* rendez = rmgr_.Find(step_id);
+  core::ScopedUnref unref(rendez);
+  Notification n;
+  SchedClosure([this, cm, &n]() {
+    env.env->SleepForMicroseconds(100 * 1000);
+    cm->StartCancel();
+    n.Notify();
+  });
+  Tensor val(DT_STRING);
+  bool val_dead = false;
+  Rendezvous::Args args;
+  args.cancellation_manager = cm;
+  TF_ASSERT_OK(rendez->Initialize(&worker_session_));
+  EXPECT_TRUE(errors::IsCancelled(rendez->Recv(key, args, &val, &val_dead)));
+  n.WaitForNotification();
+  delete cm;
+}
+
+TEST_F(RpcRendezvousMgrTest, CancelAfterReceived) {
+  const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
+      "/job:mnist/replica:1/task:2/cpu:0", 7890,
+      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
+  auto* cm = new CancellationManager();
+  const int64 step_id = 123;
+  RemoteRendezvous* rendez = rmgr_.Find(step_id);
+  core::ScopedUnref unref(rendez);
+  Notification n;
+  SchedClosure([this, rendez, key, cm, &n]() {
+    env.env->SleepForMicroseconds(100 * 1000);
+    TF_ASSERT_OK(rendez->Send(key, Rendezvous::Args(), V("peach"), false));
+    cm->StartCancel();
+    n.Notify();
+  });
+  Tensor val(DT_STRING);
+  bool val_dead = false;
+  Rendezvous::Args args;
+  args.cancellation_manager = cm;
+  TF_ASSERT_OK(rendez->Initialize(&worker_session_));
+  TF_ASSERT_OK(rendez->Recv(key, args, &val, &val_dead));
+  EXPECT_EQ(V(val), "peach");
+  n.WaitForNotification();
+  delete cm;
+}
+
 TEST_F(RpcRendezvousMgrTest, CleanupAll) {
   const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
       "/job:mnist/replica:1/task:2/cpu:0", 7890,
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 9157dbe..0c3ef6a 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -47,8 +47,8 @@
 
 CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
-      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
-                                            worker_cache_, step_id);
+      new CollectiveRemoteAccessDistributed(
+          dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
 }
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index cfa6191..96488d4 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -196,8 +196,7 @@
   ProfilerSession* profiler_session = nullptr;
   if (collector && request->exec_opts().record_timeline()) {
     // If timeline was requested, assume we want hardware level tracing.
-    profiler_session =
-        ProfilerSession::Create(/*ProfilerContext*/ nullptr).release();
+    profiler_session = ProfilerSession::Create().release();
   }
   CancellationManager* cm = new CancellationManager;
   opts->SetCancelCallback([this, cm, step_id]() {
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 22f9c2a..9d495ea 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -29,11 +29,11 @@
 
   // Updates *workers with strings naming the remote worker tasks to
   // which open channels have been established.
-  virtual void ListWorkers(std::vector<string>* workers) const {
+  void ListWorkers(std::vector<string>* workers) const override {
     return wrapped_->ListWorkers(workers);
   }
-  virtual void ListWorkersInJob(const string& job_name,
-                                std::vector<string>* workers) const {
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {
     return wrapped_->ListWorkersInJob(job_name, workers);
   }
 
@@ -41,7 +41,7 @@
   // or can be constructed, returns a pointer to a WorkerInterface object
   // wrapping that channel. The returned value must be destroyed by
   // calling `this->ReleaseWorker(target, ret)`
-  virtual WorkerInterface* GetOrCreateWorker(const string& target) {
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
     return wrapped_->GetOrCreateWorker(target);
   }
 
@@ -50,7 +50,7 @@
   // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
   // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
   //                    per-rpc-subsystem WorkerInterface creator.
-  virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
+  void ReleaseWorker(const string& target, WorkerInterface* worker) override {
     return wrapped_->ReleaseWorker(target, worker);
   }
 
@@ -63,29 +63,28 @@
   // within its local environment.  Returns true if *locality
   // was set, using only locally cached data.  Returns false
   // if status data for that device was not available.  Never blocks.
-  virtual bool GetDeviceLocalityNonBlocking(const string& device,
-                                            DeviceLocality* locality) {
+  bool GetDeviceLocalityNonBlocking(const string& device,
+                                    DeviceLocality* locality) override {
     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
   }
 
   // Set *locality with the DeviceLocality of the specified remote device
   // within its local environment.  Callback gets Status::OK if *locality
   // was set.
-  virtual void GetDeviceLocalityAsync(const string& device,
-                                      DeviceLocality* locality,
-                                      StatusCallback done) {
+  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+                              StatusCallback done) override {
     return wrapped_->GetDeviceLocalityAsync(device, locality, std::move(done));
   }
 
   // Start/stop logging activity.
-  virtual void SetLogging(bool active) { wrapped_->SetLogging(active); }
+  void SetLogging(bool active) override { wrapped_->SetLogging(active); }
 
   // Discard any saved log data.
-  virtual void ClearLogs() { wrapped_->ClearLogs(); }
+  void ClearLogs() override { wrapped_->ClearLogs(); }
 
   // Return logs for the identified step in *ss.  Any returned data will no
   // longer be stored.
-  virtual bool RetrieveLogs(int64 step_id, StepStats* ss) {
+  bool RetrieveLogs(int64 step_id, StepStats* ss) override {
     return wrapped_->RetrieveLogs(step_id, ss);
   }
 
diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc
index 5660465..af06c07 100644
--- a/tensorflow/core/example/example_parser_configuration.cc
+++ b/tensorflow/core/example/example_parser_configuration.cc
@@ -114,13 +114,14 @@
 
   for (int i = 0; i < num_sparse; ++i) {
     int input_idx = sparse_keys_start + i;
-    (*var_len_features)[i].key = op_input_tensors[input_idx].scalar<string>()();
+    (*var_len_features)[i].key =
+        op_input_tensors[input_idx].scalar<tstring>()();
   }
 
   for (int i = 0; i < num_dense; ++i) {
     FixedLenFeature& config = (*fixed_len_features)[i];
     int dense_keys_offset = dense_keys_start + i;
-    config.key = op_input_tensors[dense_keys_offset].scalar<string>()();
+    config.key = op_input_tensors[dense_keys_offset].scalar<tstring>()();
 
     int defaults_offset = dense_defaults_start + i;
     config.default_value = op_input_tensors[defaults_offset];
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index 2cb895c..595e040 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -20,11 +20,11 @@
 // So accessing feature values is not very convenient.
 //
 // For example, to read a first value of integer feature "tag":
-//   int id = example.features().feature().at("tag").int64_list().value(0)
+//   int id = example.features().feature().at("tag").int64_list().value(0);
 //
 // to add a value:
 //   auto features = example->mutable_features();
-//   (*features->mutable_feature())["tag"].mutable_int64_list()->add_value(id)
+//   (*features->mutable_feature())["tag"].mutable_int64_list()->add_value(id);
 //
 // For float features you have to use float_list, for string - bytes_list.
 //
@@ -67,7 +67,8 @@
 //         feature { float_list { value: [4.0] } }
 //         feature { float_list { value: [5.0, 3.0] } }
 //       }
-//     } }
+//     }
+//   }
 //
 // Functions exposed by this library:
 //   HasFeature<[FeatureType]>(key, proto) -> bool
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 1eafd29..1480739 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -129,8 +129,6 @@
 }
 
 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
-using TensorProtosEquality =
-    std::function<bool(const TensorProto&, const TensorProto&)>;
 
 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
   if (a.has_tensor()) return tensor_hash(a.tensor());
@@ -150,8 +148,9 @@
   return DeterministicProtoHash64(a);
 }
 
+template <typename TensorProtosEquality>
 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
-                        const TensorProtosEquality& tensor_equality) {
+                        TensorProtosEquality tensor_equality) {
   if (a.type() != b.type()) {
     return false;
   } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
@@ -493,6 +492,13 @@
   }
 }
 
+void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
+  out->mutable_list()->Clear();  // Create list() even if value empty.
+  for (auto& v : value) {
+    out->mutable_list()->add_s(std::move(v));
+  }
+}
+
 void SetAttrValue(const TensorShape& value, AttrValue* out) {
   value.AsProto(out->mutable_shape());
 }
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
index 9fce488..e302e65 100644
--- a/tensorflow/core/framework/attr_value_util.h
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -87,6 +87,8 @@
 
 void SetAttrValue(const AttrValue& value, AttrValue* out);
 
+void MoveAttrValue(std::vector<string>&& value, AttrValue* out);
+
 // Returns true if a and b have the same value.
 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
 
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
index e9e9402..ba5637d 100644
--- a/tensorflow/core/framework/bfloat16.h
+++ b/tensorflow/core/framework/bfloat16.h
@@ -20,10 +20,6 @@
 #include "tensorflow/core/platform/byte_order.h"
 #include "tensorflow/core/platform/types.h"
 
-#if defined(PLATFORM_WINDOWS)
-#include "tensorflow/core/platform/windows/cpu_info.h"
-#endif
-
 // Compact 16-bit encoding of floating point numbers. This representation uses
 // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.  It
 // is assumed that floats are in IEEE 754 format so the representation is just
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index f0511f0..9ec192c 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -259,6 +259,9 @@
                           const Tensor* from_tensor,
                           const DeviceLocality& client_locality,
                           const StatusCallback& done) = 0;
+
+  // Runs the potentially-blocking closure/expensive callback.
+  virtual void RunClosure(std::function<void()> closure) = 0;
 };
 
 class PerStepCollectiveRemoteAccess;
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 0171fe9..507b7aa 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -264,9 +264,6 @@
             << " the graph. It will not be added again.";
     return Status::OK();
   }
-  if (!ctx->optimization_only()) {
-    TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(function_name, lib_def));
-  }
   const FunctionDef* f_def = lib_def.Find(function_name);
   if (f_def == nullptr) {
     return errors::InvalidArgument("Unable to find FunctionDef for ",
@@ -369,29 +366,10 @@
   return Status::OK();
 }
 
-Status DatasetBase::Save(SerializationContext* ctx,
-                         IteratorStateWriter* writer) const {
-  string serialized_graph_def;
-  string output_node;
-  GraphDefBuilder b;
-  DatasetGraphDefBuilder db(&b);
-  Node* node = nullptr;
-  TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
-  output_node = node->name();
-  GraphDef graph_def;
-  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
-  graph_def.SerializeToString(&serialized_graph_def);
-  TF_RETURN_IF_ERROR(
-      writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
-  TF_RETURN_IF_ERROR(
-      writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
-  return Status::OK();
-}
-
 Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
     SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
   Status status = dataset->AsGraphDefInternal(ctx, this, output);
-  if (ctx->optimization_only() && errors::IsUnimplemented(status)) {
+  if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
     Tensor t(DT_VARIANT, TensorShape({}));
     // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
     // increment the refcount of `dataset` here to retain ownership.
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index abca353..251108c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -34,10 +34,13 @@
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/core/threadpool_interface.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/tracing.h"
 
 // Polymorphic datasets should support all primitive TensorFlow
@@ -198,48 +201,6 @@
  private:
   void AddPlaceholderInternal(const Tensor& val, Node** output);
   void AddTensorInternal(const Tensor& val, Node** output);
-
-  Status EnsureFunctionIsStateless(
-      const string& function_name,
-      const FunctionLibraryDefinition& lib_def) const {
-    const FunctionDef* function_def = lib_def.Find(function_name);
-    if (!function_def) {
-      return errors::InvalidArgument("Unable to find FunctionDef for ",
-                                     function_name, " in registry.");
-    }
-    for (const NodeDef& node_def : function_def->node_def()) {
-      const OpDef* op_def;
-      TF_RETURN_IF_ERROR(lib_def.LookUpOpDef(node_def.op(), &op_def));
-      // TODO(b/65524810): Hack to allow functions to capture Dataset op
-      // nodes needed for FlatMap. Currently, source datasets nodes have been
-      // marked stateful to avoid constant folding since we do not have a
-      // good way of serializing them.
-      if (IsOpWhitelisted(op_def)) {
-        continue;
-      }
-      if (op_def->is_stateful()) {
-        return errors::InvalidArgument(
-            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
-            "in function ", function_name, " is stateful. ",
-            "Saving stateful functions is not supported yet.");
-      }
-    }
-    return Status::OK();
-  }
-
-  // Returns whether an op has been whitelisted for use inside map_fns.
-  // Uses a heuristic to whitelist source dataset ops which have been
-  // marked stateful due to b/65524810.
-  // Also looks up the `op_def->name` in the global
-  // `WhitelistedStatefulOpRegistry`.
-  bool IsOpWhitelisted(const OpDef* op_def) const {
-    return ((absl::EndsWith(op_def->name(), "Dataset") ||
-             absl::EndsWith(op_def->name(), "DatasetV2")) &&
-            op_def->output_arg_size() == 1 &&
-            op_def->output_arg(0).type() == DT_VARIANT) ||
-           WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
-  }
-
   bool HasAttr(const string& op_type_name, const string& attr_name) const;
 
   bool HasAttr(const OpDef* op_def, const string& attr_name) const {
@@ -308,7 +269,8 @@
           runner(*(ctx->runner())),
           runner_threadpool_size(ctx->runner_threadpool_size()),
           stats_aggregator(ctx->stats_aggregator()),
-          thread_factory(ctx->thread_factory()) {}
+          thread_factory(ctx->thread_factory()),
+          thread_pool(ctx->thread_pool()) {}
 
     explicit Params(OpKernelContext* ctx)
         : env(ctx->env()), flr(ctx->function_library()) {
@@ -373,9 +335,11 @@
     // The `StatsAggregator` object to record statistics about the iterator.
     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
 
-    // A `ThreadFactory` for creating threads used by iterators to perform
-    // blocking work.
+    // A factory for creating threads to perform blocking work.
     std::shared_ptr<ThreadFactory> thread_factory = nullptr;
+
+    // A shared thread pool to schedule computation into.
+    thread::ThreadPoolInterface* thread_pool = nullptr;
   };
 
   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@@ -412,10 +376,35 @@
     return &params_.runner;
   }
 
+  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
+
+  std::shared_ptr<StatsAggregator> stats_aggregator() {
+    return params_.stats_aggregator;
+  }
+
   const std::shared_ptr<ThreadFactory>& thread_factory() {
     return params_.thread_factory;
   }
 
+  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
+
+  Params params() { return params_; }
+
+  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
+                                                       int num_threads) {
+    if (params_.thread_pool) {
+      // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
+      // is an instance of `thread::ThreadPoolInterface`). Notably, the
+      // ownership of `params_.thread_pool` is *not* transferred onto the newly
+      // created `ThreadPool` instance.
+      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
+    } else {
+      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
+                                                   name, num_threads,
+                                                   /*low_latency_hint=*/false);
+    }
+  }
+
   std::unique_ptr<Thread> StartThread(const string& name,
                                       std::function<void()> fn) {
     if (params_.thread_factory) {
@@ -426,14 +415,6 @@
     }
   }
 
-  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
-
-  std::shared_ptr<StatsAggregator> stats_aggregator() {
-    return params_.stats_aggregator;
-  }
-
-  Params params() { return params_; }
-
  private:
   Params params_;
 };
@@ -443,7 +424,23 @@
  public:
   struct Params {
     std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
-    bool optimization_only = false;
+
+    // Indicates whether serialization should check if the dataset depends on
+    // external state. If the check is enabled and external state is
+    // encountered, then the serialization will fail.
+    bool check_external_state = true;
+
+    // Indicates whether an attempt to serialize a dataset that does not
+    // implement serialization should result in an error. If set to `false`, the
+    // serialized graph will replace the dataset with a placeholder returned in
+    // `input_list`.
+    bool fail_if_unimplemented = true;
+
+    // Indicates whether (potentionally large) data tensors should be
+    // serialized, or replaced with a placeholder returned in `input_list`. The
+    // latter makes sense to do when performing data agnostic graph rewrites to
+    // reduce the memory usage.
+    bool serialize_data_tensors = true;
   };
 
   explicit SerializationContext(Params params) : params_(std::move(params)) {}
@@ -452,7 +449,11 @@
     return params_.input_list;
   }
 
-  bool optimization_only() { return params_.optimization_only; }
+  bool check_external_state() const { return params_.check_external_state; }
+
+  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }
+
+  bool serialize_data_tensors() const { return params_.serialize_data_tensors; }
 
  private:
   Params params_;
@@ -546,12 +547,22 @@
     return input->RestoreInternal(ctx, reader);
   }
 
-  // Saves the state of this iterator recursively.
+  // Saves the state of this iterator.
+  //
+  // This method is used to store the state of the iterator in a checkpoint.
+  //
+  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
+  // implementations have an override.
   virtual Status SaveInternal(IteratorStateWriter* writer) {
     return errors::Unimplemented("SaveInternal");
   }
 
-  // Restores the state of this iterator recursively.
+  // Restores the state of this iterator.
+  //
+  // This method is used to restore the state of the iterator from a checkpoint.
+  //
+  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
+  // implementations have an override.
   virtual Status RestoreInternal(IteratorContext* ctx,
                                  IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
@@ -691,9 +702,25 @@
   // A human-readable debug string for this dataset.
   virtual string DebugString() const = 0;
 
-  // Serializes the dataset and writes it to the `writer`.
-  virtual Status Save(SerializationContext* ctx,
-                      IteratorStateWriter* writer) const;
+  // If the dataset is stateful it will not be possible to save its graph or
+  // checkpoint the state of its iterators.
+  //
+  // TODO(jsimsa): Remove this method once all `DatasetBase` implementations are
+  // migrated over to `CheckExternalState`.
+  virtual bool IsStateful() const { return false; }
+
+  // Indicates whether the dataset depends on any external state. If so, the
+  // method returns `errors::FailedPrecondition` with a message that identifies
+  // the external state. Otherwise, the method returns `Status::OK()`.
+  //
+  // TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
+  // implementations have an override.
+  virtual Status CheckExternalState() const {
+    if (IsStateful()) {
+      return errors::FailedPrecondition("Dataset cannot be serialized.");
+    }
+    return Status::OK();
+  }
 
  protected:
   friend Status AsGraphDef(
@@ -704,11 +731,22 @@
 
   class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
    public:
-    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
+    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
+        : GraphDefBuilderWrapper(b) {}
     Status AddInputDataset(SerializationContext* ctx,
                            const DatasetBase* dataset, Node** output);
   };
 
+  // Serializes the dataset into a `GraphDef`, which has two uses:
+  //
+  // 1) To perform static input pipeline optimizations, tf.data serializes the
+  // dataset graph, applies graph rewrites, and then deserializes the graph.
+  // If a subclass of `DatasetBase` does not implement this method, then it will
+  // be excluded from static optimizations (and so will any upstream datasets).
+  //
+  // 2) To save the dataset so that it can restore at a later point (possibly in
+  // different environment). If a subclass of `DatasetBase` does not implement
+  // this method, then this migration will not be possible.
   virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                     DatasetGraphDefBuilder* b,
                                     Node** node) const = 0;
@@ -767,7 +805,7 @@
                  bool* end_of_sequence) final;
 
   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
-    TF_RETURN_IF_ERROR(params_.dataset->Save(ctx, writer));
+    TF_RETURN_IF_ERROR(params_.dataset->CheckExternalState());
     return IteratorBase::Save(ctx, writer);
   }
 
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 5f9cc9a..fb9c6d3 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -621,7 +621,7 @@
   strings::StrAppend(&out, "\n(");
   auto get_type_and_device = [](const NodeDef& n) {
     DataType dt;
-    if (!GetNodeAttr(n, "T", &dt).ok()) {
+    if (!TryGetNodeAttr(n, "T", &dt)) {
       dt = DT_INVALID;
     }
     if (!n.device().empty()) {
@@ -1226,6 +1226,7 @@
   // the duration of the function could lead to deadlock).
   FunctionLibraryDefinition clone(other);
   mutex_lock l(mu_);
+  mutex_lock l2(clone.mu_);
   // Remember the funcs and grads that we added successfully so that
   // we can roll them back on error.
   std::vector<string> funcs;
@@ -1388,7 +1389,7 @@
   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
   // Foo's attributes.
   const NameAttrList* forward_func_attrs;
-  if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
+  if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
     return nullptr;
   }
   const string& func_name = forward_func_attrs->name();
@@ -1433,7 +1434,7 @@
 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
                                           const string& attr, T* value) const {
   const FunctionDef* fdef = GetAttrImpl(ndef);
-  if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
+  if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
     return Status::OK();
   }
   return errors::InvalidArgument("Attr ", attr, " is not defined.");
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index a106c74..f9c7b8b 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -634,7 +634,8 @@
                              Handle* handle) = 0;
   Status Instantiate(const string& function_name, AttrSlice attrs,
                      Handle* handle) {
-    return Instantiate(function_name, attrs, {}, handle);
+    auto opts = absl::make_unique<InstantiateOptions>();
+    return Instantiate(function_name, attrs, *opts, handle);
   }
 
   // Releases state associated with the handle.
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 1d48218..246f50a 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -156,14 +156,14 @@
   }
 
   std::vector<int32> hostmem_attr;
-  if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
+  if (TryGetNodeAttr(ndef, "_input_hostmem", &hostmem_attr)) {
     for (int32 i : hostmem_attr) {
       if (0 <= i && i < inp_mtypes->size()) {
         (*inp_mtypes)[i] = HOST_MEMORY;
       }
     }
   }
-  if (GetNodeAttr(ndef, "_output_hostmem", &hostmem_attr).ok()) {
+  if (TryGetNodeAttr(ndef, "_output_hostmem", &hostmem_attr)) {
     for (int32 i : hostmem_attr) {
       if (0 <= i && i < out_mtypes->size()) {
         (*out_mtypes)[i] = HOST_MEMORY;
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index a6b6114..c8b3c71 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -49,6 +49,8 @@
                        double buffer_size, double* output_time_derivative,
                        double* input_time_derivative,
                        double* buffer_size_derivative) {
+  // Case 0: either the producer or the consumer are infinitely fast. Wait time
+  // is the time to produce an output.
   if (output_time == 0 || input_time == 0) {
     if (output_time_derivative) {
       *output_time_derivative = 1.0L;
@@ -61,6 +63,22 @@
     }
     return output_time;
   }
+  // Case 1: the consumer is slower than the producer. Wait time is 0 since the
+  // buffer will be full in the long run.
+  if (input_time > output_time) {
+    if (output_time_derivative) {
+      *output_time_derivative = 0.0L;
+    }
+    if (input_time_derivative) {
+      *input_time_derivative = 0.0L;
+    }
+    if (buffer_size_derivative) {
+      *buffer_size_derivative = 0.0L;
+    }
+    return 0;
+  }
+  // Case 2: the consumer and the producer are equally fast. Expected wait time
+  // decreases linearly with the size of the buffer.
   if (input_time == output_time) {
     const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
     if (output_time_derivative) {
@@ -75,6 +93,8 @@
     }
     return p_buffer_empty * output_time;
   }
+  // Case 3: the producer is slower than the consumer and neither is infinitely
+  // fast.
   const double alpha = 1.0L / input_time;
   const double beta = 1.0L / output_time;
   const double ratio_pow = std::pow((beta / alpha), (buffer_size + 1.0L));
@@ -167,14 +187,20 @@
 
   // The processing time is the sum of the self processing time and the average
   // processing time of inputs comprising the interleave "cycle".
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
-    if (num_inputs() <= 1) {
-      return SelfProcessingTimeLocked();
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    double self_processing_time = SelfProcessingTimeLocked();
+    if (processing_times) {
+      (*processing_times)[long_name()] = self_processing_time;
     }
-    double processing_time = (TotalProcessingTimeForInputs() -
-                              inputs_.front()->TotalProcessingTime()) /
-                             static_cast<double>(num_inputs() - 1);
-    return SelfProcessingTimeLocked() + processing_time;
+    if (num_inputs() <= 1) {
+      return self_processing_time;
+    }
+    double processing_time =
+        (TotalProcessingTimeForInputs(processing_times) -
+         inputs_.front()->TotalProcessingTime(/*processing_times=*/nullptr)) /
+        static_cast<double>(num_inputs() - 1);
+    return self_processing_time + processing_time;
   }
 };
 
@@ -282,13 +308,19 @@
 
   // The processing time is the sum of the self processing time and the average
   // processing time of inputs comprising the interleave "cycle".
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    double self_processing_time = SelfProcessingTimeLocked();
+    if (processing_times) {
+      (*processing_times)[long_name()] = self_processing_time;
+    }
     if (num_inputs() <= 1) {
-      return SelfProcessingTimeLocked();
+      return self_processing_time;
     }
     double processing_time =
-        TotalProcessingTimeForInputs() - inputs_.front()->TotalProcessingTime();
-    return SelfProcessingTimeLocked() +
+        TotalProcessingTimeForInputs(processing_times) -
+        inputs_.front()->TotalProcessingTime(/*processing_times=*/nullptr);
+    return self_processing_time +
            processing_time / static_cast<double>(num_inputs() - 1);
   }
 };
@@ -345,8 +377,14 @@
 
   // The processing time is the sum of the self processing time and the product
   // of `ratio_` and the sum of processing times of inputs.
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
-    return SelfProcessingTimeLocked() + ratio_ * TotalProcessingTimeForInputs();
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    double self_processing_time = SelfProcessingTimeLocked();
+    if (processing_times) {
+      (*processing_times)[long_name()] = self_processing_time;
+    }
+    return self_processing_time +
+           ratio_ * TotalProcessingTimeForInputs(processing_times);
   }
 
  private:
@@ -462,8 +500,14 @@
 
   // The processing time is the sum of the self processing time and the product
   // of `ratio_` and the sum of processing times of inputs.
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
-    return SelfProcessingTimeLocked() + ratio_ * TotalProcessingTimeForInputs();
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    double self_processing_time = SelfProcessingTimeLocked();
+    if (processing_times) {
+      (*processing_times)[long_name()] = self_processing_time;
+    }
+    return self_processing_time +
+           ratio_ * TotalProcessingTimeForInputs(processing_times);
   }
 
  private:
@@ -524,16 +568,22 @@
 
   // The processing time is the sum of the self processing time and the product
   // of the ratio estimate and the sum of processing times of inputs.
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    double self_processing_time = SelfProcessingTimeLocked();
+    if (processing_times) {
+      (*processing_times)[long_name()] = self_processing_time;
+    }
     if (inputs_.empty() || num_elements_ == 0) {
-      return SelfProcessingTimeLocked();
+      return self_processing_time;
     }
     // TODO(jsimsa): The current implementation assumes that the number of input
     // elements consumed per output is the same across all inputs.
     std::shared_ptr<Node> input = inputs_.front();
     double ratio = static_cast<double>(input->num_elements()) /
                    static_cast<double>(num_elements_);
-    return SelfProcessingTimeLocked() + ratio * TotalProcessingTimeForInputs();
+    return self_processing_time +
+           ratio * TotalProcessingTimeForInputs(processing_times);
   }
 };
 
@@ -557,8 +607,9 @@
   }
 
   // The processing time is the sum of processing times of inputs.
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
-    return TotalProcessingTimeForInputs();
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
+    return TotalProcessingTimeForInputs(processing_times);
   }
 };
 
@@ -719,6 +770,28 @@
   return parameters;
 }
 
+std::map<string, std::shared_ptr<Parameter>> Model::CollectEssentialParallelism(
+    std::shared_ptr<Node> node) {
+  // Parallelism parameter is considered to be essential if the coressponding
+  // transformations's processing time is greater than essential rate times the
+  // average transformation self processing time.
+  constexpr double kEssentialRate = 0.3L;
+
+  std::map<string, std::shared_ptr<Parameter>> parameters;
+  node->CollectTunableParameters(&parameters);
+  std::map<string, double> processing_times;
+  double processing_time = node->TotalProcessingTime(&processing_times);
+  double uniform_share =
+      processing_time / static_cast<double>(processing_times.size());
+  std::map<string, std::shared_ptr<Parameter>> essential_parameters;
+  for (auto& pair : parameters) {
+    if (processing_times[pair.first] > kEssentialRate * uniform_share) {
+      essential_parameters.insert(pair);
+    }
+  }
+  return essential_parameters;
+}
+
 void Model::OptimizeGradientDescent(int64 cpu_budget) {
   std::shared_ptr<Node> snapshot;
   {
@@ -726,24 +799,20 @@
     snapshot = output_->Snapshot(nullptr);
   }
   VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
-  const double processing_time = TotalProcessingTime(snapshot);
   auto parameters = CollectTunableParameters(snapshot);
+  auto essential_parameters = CollectEssentialParallelism(snapshot);
   for (auto& pair : parameters) {
     pair.second->value = pair.second->min;
   }
   // Gradient descent step size.
-  constexpr double kDescentStep = 0.7L;
+  constexpr double kDescentStep = 0.1L;
 
   // Optimization is stopped once the `OutputTime` improvement is smaller than
   // this value.
   constexpr double kOptimizationPrecision = 100.0L;
 
-  // Penalizing step for the parameters after we overoptimize (output time <
-  // processing time / cpu budget) the objective.
-  constexpr double kParametersPenalty = 0.05L;
-
   // Maximum number of iterations for optimization.
-  constexpr int64 kMaxIterations = 100;
+  constexpr int64 kMaxIterations = 1000;
 
   double output_time = 0;
   double new_output_time;
@@ -751,8 +820,14 @@
   for (int i = 0; i < kMaxIterations; ++i) {
     std::map<string, double> gradient;
     new_output_time = OutputTime(snapshot, &gradient);
+    int64 model_parallelism = 0;
+    for (auto& pair : essential_parameters) {
+      model_parallelism += std::round(pair.second->value);
+    }
+    // We terminate once the improvement of the output latency is too small or
+    // the essential transformations' parallelism reaches the CPU budget.
     if (std::abs(output_time - new_output_time) < kOptimizationPrecision ||
-        new_output_time < processing_time / cpu_budget) {
+        model_parallelism > cpu_budget) {
       break;
     }
     double max_abs_derivative = 1.0;
@@ -762,13 +837,6 @@
             std::max(max_abs_derivative, std::abs(gradient[pair.first]));
       }
     }
-    // Maximizes parameters on early stages of the model.
-    if (max_abs_derivative < kOptimizationPrecision) {
-      for (auto& pair : parameters) {
-        pair.second->value = pair.second->max;
-      }
-      break;
-    }
     for (auto& pair : parameters) {
       new_value = pair.second->value -
                   kDescentStep * gradient[pair.first] / max_abs_derivative;
@@ -783,23 +851,6 @@
     }
     output_time = new_output_time;
   }
-  // Penalize parameters if we overoptimized the objective.
-  for (int i = 0;
-       i < kMaxIterations && new_output_time < processing_time / cpu_budget;
-       ++i) {
-    for (auto& pair : parameters) {
-      new_value = pair.second->value - kParametersPenalty;
-      // Projection on a feasible interval.
-      if (new_value > pair.second->max) {
-        pair.second->value = pair.second->max;
-      } else if (new_value < pair.second->min) {
-        pair.second->value = pair.second->min;
-      } else {
-        pair.second->value = new_value;
-      }
-    }
-    new_output_time = OutputTime(snapshot, /*gradient=*/nullptr);
-  }
   VLOG(2) << "Number of tunable parameters: " << parameters.size();
   for (auto& pair : parameters) {
     pair.second->value = std::round(pair.second->value);
@@ -884,7 +935,7 @@
 }
 
 double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
-  return node->TotalProcessingTime();
+  return node->TotalProcessingTime(/*processing_times=*/nullptr);
 }
 
 }  // namespace model
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 2687cc6..bab816f 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -311,9 +311,12 @@
   }
 
   // Returns the per-element CPU time spent in the subtree rooted in this node.
-  double TotalProcessingTime() LOCKS_EXCLUDED(mu_) {
+  // If `processing_times` is not `nullptr`, collects the per-element CPU time
+  // spent in each node of the subtree.
+  double TotalProcessingTime(std::map<string, double>* processing_times)
+      LOCKS_EXCLUDED(mu_) {
     tf_shared_lock l(mu_);
-    return TotalProcessingTimeLocked();
+    return TotalProcessingTimeLocked(processing_times);
   }
 
  protected:
@@ -360,10 +363,13 @@
   // Processing time for a given input is a weighted combination of a statistic
   // based on history of input processing time and the actual time. This is done
   // to improve accuracy of processing time estimation for newly created inputs.
+  // If `processing_times` is not `nullptr`, collects the per-element CPU time
+  // spent in each input node.
   //
   // Uniform distribution of per-element processing times across different
   // inputs is assumed.
-  double TotalProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+  double TotalProcessingTimeForInputs(
+      std::map<string, double>* processing_times) SHARED_LOCKS_REQUIRED(mu_) {
     // If the number of elements produced by an input is smaller than this
     // constant, then its processing time is estimated using a weighted average
     // of the empirical processing time and processing time history.
@@ -377,7 +383,8 @@
     for (auto& input : inputs_) {
       // Inputs for which autotuning is disabled are excluded.
       if (input->autotune()) {
-        double input_processing_time = input->TotalProcessingTime();
+        double input_processing_time =
+            input->TotalProcessingTime(processing_times);
         int64 num_elements = input->num_elements();
         if (num_elements < kNumElementsThreshold) {
           if (input_processing_time_count_ < kCountThreshold) {
@@ -411,7 +418,11 @@
   }
 
   // Returns the per-element CPU time spent in the subtree rooted in this node.
-  virtual double TotalProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_) = 0;
+  // If `processing_times` is not `nullptr`, collects the per-element CPU time
+  // spent in each node of the subtree.
+  virtual double TotalProcessingTimeLocked(
+      std::map<string, double>* processing_times)
+      SHARED_LOCKS_REQUIRED(mu_) = 0;
 
   mutable mutex mu_;
   const int64 id_;
@@ -536,6 +547,14 @@
   std::map<string, std::shared_ptr<Parameter>> CollectTunableParameters(
       std::shared_ptr<Node> node);
 
+  // Collects "essential" parallelism parameters of transformations in the tree
+  // rooted in the given node. Which parameters are essential is determined by
+  // comparison the processing time spent in the corresponding transformation
+  // relative to other transformations. The collected parameters are returned
+  // as a mapping from a (unique) node name to a parallelism parameter.
+  std::map<string, std::shared_ptr<Parameter>> CollectEssentialParallelism(
+      std::shared_ptr<Node> node);
+
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to 1. It then repeatedly identifies the parameter whose increase
   // in parallelism decreases the output time the most. This process is repeated
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index 1a96ebd..4263617 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -58,26 +58,36 @@
   std::vector<double> input_times(1, input_time);
   async_interleave_many->add_processing_time(100);
   EXPECT_EQ(async_interleave_many->processing_time(), 100);
-  EXPECT_EQ(async_interleave_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(
+      async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+      0);
   EXPECT_EQ(async_interleave_many->OutputTime(&input_times, nullptr), 0);
   async_interleave_many->record_element();
   EXPECT_EQ(async_interleave_many->num_elements(), 1);
-  EXPECT_EQ(async_interleave_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(
+      async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+      100);
   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr), 100);
   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
   source1->add_processing_time(200);
   source2->add_processing_time(300);
-  EXPECT_EQ(async_interleave_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(
+      async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+      100);
   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr), 100);
   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
   source1->record_element();
   source2->record_element();
-  EXPECT_EQ(async_interleave_many->TotalProcessingTime(), 100 + 250);
+  EXPECT_EQ(
+      async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+      100 + 250);
   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr),
             100 + 250 / parallelism);
   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
   async_interleave_many->record_element();
-  EXPECT_EQ(async_interleave_many->TotalProcessingTime(), 50 + 250);
+  EXPECT_EQ(
+      async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+      50 + 250);
   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr),
             50 + 250 / parallelism);
   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
@@ -109,49 +119,51 @@
   async_known_many->add_input(source2);
   std::vector<double> input_times(1, input_time);
   source1->add_processing_time(100);
-  EXPECT_EQ(async_known_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            0);
   EXPECT_EQ(async_known_many->OutputTime(&input_times, nullptr), 0);
   source2->add_processing_time(200);
-  EXPECT_EQ(async_known_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            0);
   EXPECT_EQ(async_known_many->OutputTime(&input_times, nullptr), 0);
   source1->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * 100);
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * 100);
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   source2->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (100 + 200));
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (100 + 200));
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   source1->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 200));
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 200));
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   source2->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   async_known_many->add_processing_time(128);
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   async_known_many->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100) + 128);
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100) + 128 / parallelism);
   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
   async_known_many->record_element();
-  EXPECT_EQ(async_known_many->TotalProcessingTime(),
+  EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100) + 64);
   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100) + 64 / parallelism);
@@ -178,22 +190,27 @@
   std::vector<double> input_times(1, 0);
   interleave_many->add_processing_time(100);
   EXPECT_EQ(interleave_many->processing_time(), 100);
-  EXPECT_EQ(interleave_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            0);
   EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 0);
   interleave_many->record_element();
   EXPECT_EQ(interleave_many->num_elements(), 1);
-  EXPECT_EQ(interleave_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            100);
   EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);
   source1->add_processing_time(200);
   source2->add_processing_time(300);
-  EXPECT_EQ(interleave_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            100);
   EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);
   source1->record_element();
   source2->record_element();
-  EXPECT_EQ(interleave_many->TotalProcessingTime(), 350);
+  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            350);
   EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 350);
   interleave_many->record_element();
-  EXPECT_EQ(interleave_many->TotalProcessingTime(), 300);
+  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            300);
   EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 300);
 }
 
@@ -211,42 +228,43 @@
   known_many->add_input(source2);
   std::vector<double> input_times(1, 0);
   source1->add_processing_time(100);
-  EXPECT_EQ(known_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr), 0);
   source2->add_processing_time(200);
-  EXPECT_EQ(known_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr), 0);
   source1->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(), num_inputs_per_output * 100);
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            num_inputs_per_output * 100);
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * 100);
   source2->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (100 + 200));
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (100 + 200));
   source1->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 200));
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 200));
   source2->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100));
   known_many->add_processing_time(128);
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100));
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100));
   known_many->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100) + 128);
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100) + 128);
   known_many->record_element();
-  EXPECT_EQ(known_many->TotalProcessingTime(),
+  EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             num_inputs_per_output * (50 + 100) + 64);
   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
             num_inputs_per_output * (50 + 100) + 64);
@@ -259,15 +277,15 @@
   std::vector<double> input_times(1, 0);
   source->add_processing_time(100);
   EXPECT_EQ(source->processing_time(), 100);
-  EXPECT_EQ(source->TotalProcessingTime(), 0);
+  EXPECT_EQ(source->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(source->OutputTime(&input_times, nullptr), 0);
   source->record_element();
   EXPECT_EQ(source->num_elements(), 1);
-  EXPECT_EQ(source->TotalProcessingTime(), 100);
+  EXPECT_EQ(source->TotalProcessingTime(/*processing_times=*/nullptr), 100);
   EXPECT_EQ(source->OutputTime(&input_times, nullptr), 100);
   source->record_element();
   EXPECT_EQ(source->num_elements(), 2);
-  EXPECT_EQ(source->TotalProcessingTime(), 50);
+  EXPECT_EQ(source->TotalProcessingTime(/*processing_times=*/nullptr), 50);
   EXPECT_EQ(source->OutputTime(&input_times, nullptr), 50);
 }
 
@@ -283,22 +301,26 @@
   std::vector<double> input_times(1, 0);
   unknown_many->add_processing_time(100);
   EXPECT_EQ(unknown_many->processing_time(), 100);
-  EXPECT_EQ(unknown_many->TotalProcessingTime(), 0);
+  EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(unknown_many->OutputTime(&input_times, nullptr), 0);
   unknown_many->record_element();
   EXPECT_EQ(unknown_many->num_elements(), 1);
-  EXPECT_EQ(unknown_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            100);
   EXPECT_EQ(unknown_many->OutputTime(&input_times, nullptr), 100);
   source1->add_processing_time(100);
   source2->add_processing_time(200);
-  EXPECT_EQ(unknown_many->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            100);
   EXPECT_EQ(unknown_many->OutputTime(&input_times, nullptr), 100);
   source1->record_element();
   source2->record_element();
-  EXPECT_EQ(unknown_many->TotalProcessingTime(), 400);
+  EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            400);
   EXPECT_EQ(unknown_many->OutputTime(&input_times, nullptr), 400);
   unknown_many->record_element();
-  EXPECT_EQ(unknown_many->TotalProcessingTime(), 200);
+  EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            200);
   EXPECT_EQ(unknown_many->OutputTime(&input_times, nullptr), 200);
 }
 
@@ -313,34 +335,34 @@
   unknown->add_input(source2);
   std::vector<double> input_times(1, 0);
   source1->add_processing_time(100);
-  EXPECT_EQ(unknown->TotalProcessingTime(), 0);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 0);
   source2->add_processing_time(100);
-  EXPECT_EQ(unknown->TotalProcessingTime(), 0);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 0);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 0);
   source1->record_element();
-  EXPECT_EQ(unknown->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 100);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
   source2->record_element();
-  EXPECT_EQ(unknown->TotalProcessingTime(), 200);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 200);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 200);
   source1->record_element();
-  EXPECT_EQ(unknown->TotalProcessingTime(), 150);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 150);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 150);
   source2->record_element();
-  EXPECT_EQ(unknown->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 100);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
   // Unknown node processing time should not affect its TotalProcessingTime() or
   // OutputTime().
   unknown->add_processing_time(100);
   EXPECT_EQ(unknown->processing_time(), 100);
-  EXPECT_EQ(unknown->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 100);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
   // Unknown node number of elements should not affect its TotalProcessingTime()
   // or OutputTime().
   unknown->record_element();
   EXPECT_EQ(unknown->num_elements(), 1);
-  EXPECT_EQ(unknown->TotalProcessingTime(), 100);
+  EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 100);
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
 }
 
@@ -362,7 +384,8 @@
     return 0;
   }
 
-  double TotalProcessingTimeLocked() override SHARED_LOCKS_REQUIRED(mu_) {
+  double TotalProcessingTimeLocked(std::map<string, double>* processing_times)
+      override SHARED_LOCKS_REQUIRED(mu_) {
     return 0;
   }
 };
@@ -424,9 +447,10 @@
   for (int i = 0; i < 100; i++) {
     source1->record_element();
   }
-  EXPECT_LE(interleave_many->TotalProcessingTime(),
+  EXPECT_LE(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
             (weighted_processing_time(100, 2, 0)) + 100);
-  EXPECT_GE(interleave_many->TotalProcessingTime(), 0);
+  EXPECT_GE(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
+            0);
 }
 
 // Precision for comparison of the gradient and a relative output time change.
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index 58f79bd..9011b61 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -261,19 +261,33 @@
   }
 }
 
-NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
+bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
+                                             const AttrValue& value) {
   if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
     if (!AreAttrValuesEqual(*found, value)) {
       errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
                                         "' ", SummarizeAttrValue(*found),
                                         " vs. ", SummarizeAttrValue(value)));
     }
-  } else {
+    return true;
+  }
+  return false;
+}
+
+NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
+  if (!AttrValueAlreadyPresent(name, value)) {
     AddNodeAttr(name, value, &node_def_);
   }
   return *this;
 }
 
+NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
+  if (!AttrValueAlreadyPresent(name, value)) {
+    AddNodeAttr(name, std::move(value), &node_def_);
+  }
+  return *this;
+}
+
 #define ATTR(T)                                                     \
   NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
     AttrValue attr_value;                                           \
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index 92d6399..b450966 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -93,6 +93,7 @@
   // Sets the attr, if not already set.  If already set with a different
   // value, an error will be returned from Finalize().
   NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
+  NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
   NodeDefBuilder& Attr(StringPiece name, StringPiece value);
   NodeDefBuilder& Attr(StringPiece name, const char* value);
   NodeDefBuilder& Attr(StringPiece name, int32 value);
@@ -172,6 +173,11 @@
     return input_arg->is_ref() ? MakeRefType(dt) : dt;
   }
 
+  // Returns true if an attr named `name` is already present in the node_def_.
+  // If such an attr is already present and `value` is not equal to the present
+  // value, an error is generated.
+  bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
+
   const OpDef* op_def_;
   NodeDef node_def_;
   int inputs_specified_;
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index a130d26..9484f10 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -21,6 +21,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/graph.pb_text.h"
 #include "tensorflow/core/framework/op.h"
@@ -243,6 +244,7 @@
     const AttrValue* attr_value;                                              \
     TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
     TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
+    value->reserve(attr_value->list().FIELD().size());                        \
     for (const auto& v : attr_value->list().FIELD()) {                        \
       __VA_ARGS__;                                                            \
       value->APPEND_OP(CAST);                                                 \
@@ -250,58 +252,87 @@
     return Status::OK();                                                      \
   }
 
-#define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
-  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
-                         TYPE* value) {                                      \
-    const AttrValue* attr_value = attrs.Find(attr_name);                     \
-    if (attr_value == nullptr) {                                             \
-      return false;                                                          \
-    }                                                                        \
-    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                     \
-    if (!s.ok()) {                                                           \
-      return false;                                                          \
-    }                                                                        \
-    const auto& v = attr_value->FIELD();                                     \
-    __VA_ARGS__;                                                             \
-    *value = CAST;                                                           \
-    return true;                                                             \
-  }                                                                          \
-  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
-                         std::vector<TYPE>* value) {                         \
-    const AttrValue* attr_value = attrs.Find(attr_name);                     \
-    if (attr_value == nullptr) {                                             \
-      return false;                                                          \
-    }                                                                        \
-    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");         \
-    if (!s.ok()) {                                                           \
-      return false;                                                          \
-    }                                                                        \
-    for (const auto& v : attr_value->list().FIELD()) {                       \
-      __VA_ARGS__;                                                           \
-      value->APPEND_OP(CAST);                                                \
-    }                                                                        \
-    return true;                                                             \
+#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
+  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
+                      TYPE* value) {                                      \
+    const AttrValue* attr_value = attrs.Find(attr_name);                  \
+    if (attr_value == nullptr) {                                          \
+      return false;                                                       \
+    }                                                                     \
+    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
+    if (!s.ok()) {                                                        \
+      return false;                                                       \
+    }                                                                     \
+    const auto& v = attr_value->FIELD();                                  \
+    __VA_ARGS__;                                                          \
+    *value = CAST;                                                        \
+    return true;                                                          \
+  }                                                                       \
+  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
+                      std::vector<TYPE>* value) {                         \
+    const AttrValue* attr_value = attrs.Find(attr_name);                  \
+    if (attr_value == nullptr) {                                          \
+      return false;                                                       \
+    }                                                                     \
+    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
+    if (!s.ok()) {                                                        \
+      return false;                                                       \
+    }                                                                     \
+    value->reserve(attr_value->list().FIELD().size());                    \
+    for (const auto& v : attr_value->list().FIELD()) {                    \
+      __VA_ARGS__;                                                        \
+      value->APPEND_OP(CAST);                                             \
+    }                                                                     \
+    return true;                                                          \
   }
 
 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
-DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
+DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
+DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
 DEFINE_GET_ATTR(
     int32, i, "int", emplace_back, static_cast<int32>(v),
     if (static_cast<int64>(static_cast<int32>(v)) != v) {
       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
                                      " out of range for an int32");
     })
+DEFINE_TRY_GET_ATTR(
+    int32, i, "int", emplace_back, static_cast<int32>(v),
+    if (static_cast<int64>(static_cast<int32>(v)) != v) {
+      static int log_counter = 0;
+      if (log_counter < 10) {
+        log_counter++;
+        LOG(WARNING) << "Attr " << attr_name << " has value " << v
+                     << " out of range for an int32";
+      }
+      return false;
+    })
 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
+DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
 // std::vector<bool> specialization does not have emplace_back until
 // c++14, so we have to use push_back (see
 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
+DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
                 ;)
+DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
+                    static_cast<DataType>(v),
+                    ;)
 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
+DEFINE_TRY_GET_ATTR(
+    TensorShape, shape, "shape", emplace_back, TensorShape(v),
+    if (!TensorShape::IsValidShape(v).ok()) {
+      static int log_counter = 0;
+      if (log_counter < 10) {
+        log_counter++;
+        LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
+                     << v.DebugString();
+      }
+      return false;
+    })
 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
                 PartialTensorShape(v),
                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
@@ -332,6 +363,40 @@
   return attr_value->s();
 }
 
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<const string*>* value) {
+  const AttrValue* attr_value = attrs.Find(attr_name);
+  if (attr_value == nullptr) {
+    return false;
+  }
+  Status s = AttrValueHasType(*attr_value, "list(string)");
+  if (!s.ok()) {
+    return false;
+  }
+  value->reserve(attr_value->list().s().size());
+  for (const auto& v : attr_value->list().s()) {
+    value->push_back(&v);
+  }
+  return true;
+}
+
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<const TensorShapeProto*>* value) {
+  const AttrValue* attr_value = attrs.Find(attr_name);
+  if (attr_value == nullptr) {
+    return false;
+  }
+  Status s = AttrValueHasType(*attr_value, "list(shape)");
+  if (!s.ok()) {
+    return false;
+  }
+  value->reserve(attr_value->list().shape().size());
+  for (const auto& v : attr_value->list().shape()) {
+    value->push_back(&v);
+  }
+  return true;
+}
+
 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    DataTypeVector* value) {
   const AttrValue* attr_value;
@@ -352,6 +417,20 @@
   return Status::OK();
 }
 
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    const TensorProto** value) {
+  const AttrValue* attr_value = attrs.Find(attr_name);
+  if (attr_value == nullptr) {
+    return false;
+  }
+  Status s = AttrValueHasType(*attr_value, "tensor");
+  if (!s.ok()) {
+    return false;
+  }
+  *value = &attr_value->tensor();
+  return true;
+}
+
 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    const NameAttrList** value) {
   const AttrValue* attr_value;
@@ -361,6 +440,20 @@
   return Status::OK();
 }
 
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    const NameAttrList** value) {
+  const AttrValue* attr_value = attrs.Find(attr_name);
+  if (attr_value == nullptr) {
+    return false;
+  }
+  Status s = AttrValueHasType(*attr_value, "func");
+  if (!s.ok()) {
+    return false;
+  }
+  *value = &attr_value->func();
+  return true;
+}
+
 namespace {  // Helper for InOutTypesForNode().
 
 template <class NodeDefOrAttrSlice>
@@ -753,6 +846,10 @@
       AttrValueMap::value_type(string(name), value));
 }
 
+void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
+  (*node_def->mutable_attr())[string(name)] = std::move(value);
+}
+
 #define ADD_NODE_ATTR(T)                                           \
   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
     AttrValue attr_value;                                          \
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 1a089b5..8f58607 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -74,6 +74,7 @@
 // Adds an attr with name <name> and value <value> to *node_def.
 // The type of the attr is based on the type of value.
 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
 void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
@@ -234,11 +235,15 @@
 // REQUIRES: Must not use *value beyond the lifetime of node_def.
 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    const TensorProto** value);  // type: "tensor"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    const TensorProto** value);  // type: "tensor"
 
 // This version avoids copying the NameAttrList.
 // REQUIRES: Must not use *value beyond the lifetime of node_def.
 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    const NameAttrList** value);  // type: "func"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    const NameAttrList** value);  // type: "func"
 
 // These versions copies the NameAttrList(s).
 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
@@ -249,10 +254,43 @@
 // Look up the attr with name attr_name and set *value to its value.  If no
 // attr with attr_name is found in node_def, or the attr does not have
 // a matching type, false is returned.
-bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
-                       string* value);  // type: "string"
-bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
-                       std::vector<string>* value);  // type: "string"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    string* value);  // type: "string"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    int64* value);  // type: "int"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<int64>* value);  // type: "int"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    int32* value);  // type: "int"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    float* value);  // type: "float"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    bool* value);  // type: "bool"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    DataType* value);  // type: "type"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    TensorShape* value);  // type: "shape"
+
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<string>* value);  // type: "list(string)"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<int32>* value);  // type: "list(int)"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<float>* value);  // type: "list(float)"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<bool>* value);  // type: "list(bool)"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<DataType>* value);  // type: "list(type)"
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<TensorShape> value);  // type: "shape"
+
+// Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
+// values.
+bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
+                    std::vector<const string*>* value);  // type: "list(string)"
+bool TryGetNodeAttr(
+    const AttrSlice& attrs, StringPiece attr_name,
+    std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"
 
 // Look up the attr with name attr_name and return a reference to its value.
 // If no attr with attr_name is found in node_def, or the attr does not have
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index dc931c3..4edb607 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -35,7 +35,7 @@
     Tensor* out_tensor = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output("ndef", TensorShape({}),
                                                      &out_tensor));
-    out_tensor->scalar<string>()() = SummarizeNodeDef(def());
+    out_tensor->scalar<tstring>()() = SummarizeNodeDef(def());
   }
 };
 
@@ -87,7 +87,7 @@
     TF_ASSERT_OK(RunOpKernel());
   }
 
-  string Result() { return GetOutput(0)->scalar<string>()(); }
+  string Result() { return GetOutput(0)->scalar<tstring>()(); }
 
   void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def,
                           const string& error) {
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 020b3b2..6fe1f4d 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -1477,8 +1477,8 @@
   }
   if (registration == nullptr) {
     s.Update(errors::NotFound("No registered '", node_def.op(),
-                              "' OpKernel for ", DeviceTypeString(device_type),
-                              " devices compatible with node ",
+                              "' OpKernel for '", DeviceTypeString(device_type),
+                              "' devices compatible with node ",
                               FormatNodeDefForError(node_def)));
     if (was_attr_mismatch) {
       errors::AppendToMessage(
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index f05bb90..61f7f9e 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -621,6 +621,9 @@
     // The step being executed.
     int64 step_id = 0;
 
+    // True if the op is created by eager runtime.
+    bool is_eager = false;
+
     // The op kernel being computed.
     OpKernel* op_kernel = nullptr;
 
@@ -738,6 +741,8 @@
 
   int64 step_id() const { return params_->step_id; }
 
+  bool is_eager() const { return params_->is_eager; }
+
   const OpKernel& op_kernel() const { return *params_->op_kernel; }
 
   // Input/output signature.
@@ -1282,8 +1287,9 @@
     return params_->dec_num_deferred_ops_function;
   }
 
- private:
   Allocator* get_allocator(AllocatorAttributes attr);
+
+ private:
   bool record_memory_consumption_ = false;
 
   // Internal method to add a tensor's buffer to the list of buffers
diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc
index 39d83d9..ec27b8b 100644
--- a/tensorflow/core/framework/reader_base.cc
+++ b/tensorflow/core/framework/reader_base.cc
@@ -214,7 +214,7 @@
             context->SetStatus(errors::InvalidArgument(
                 "Expected to dequeue a one-element string tensor"));
           } else {
-            work = tuple[0].flat<string>()(0);
+            work = tuple[0].flat<tstring>()(0);
           }
         }
         n.Notify();
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index 1281b12..90e432a 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -187,7 +187,7 @@
 
     // Delete the queue when the last element has been consumed.
     if (queue->size() == 1) {
-      VLOG(2) << "Clean up Send/Recv queu (key:" << key.FullKey() << "). ";
+      VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
       table_.erase(key_hash);
     } else {
       queue->pop_front();
@@ -220,10 +220,53 @@
     if (queue->empty() || !queue->front()->IsSendValue()) {
       // There is no message to pick up.
       // Only recv-related fields need to be filled.
+      CancellationManager* cm = recv_args.cancellation_manager;
+      CancellationToken token = CancellationManager::kInvalidToken;
+      bool already_cancelled = false;
+      if (cm != nullptr) {
+        token = cm->get_cancellation_token();
+        already_cancelled = !cm->RegisterCallback(token, [this, token,
+                                                          key_hash] {
+          Item* item = nullptr;
+          {
+            mutex_lock l(mu_);
+            ItemQueue* queue = &table_[key_hash];
+            if (!queue->empty() && !queue->front()->IsSendValue()) {
+              for (auto it = queue->begin(); it != queue->end(); it++) {
+                if ((*it)->cancellation_token == token) {
+                  item = *it;
+                  if (queue->size() == 1) {
+                    table_.erase(key_hash);
+                  } else {
+                    queue->erase(it);
+                  }
+                  break;
+                }
+              }
+            }
+          }
+
+          if (item != nullptr) {
+            item->waiter(StatusGroup::MakeDerived(
+                             errors::Cancelled("RecvAsync is cancelled.")),
+                         Args(), item->recv_args, Tensor(), /*is_dead=*/false);
+            delete item;
+          }
+        });
+      }
+      if (already_cancelled) {
+        mu_.unlock();
+        done(StatusGroup::MakeDerived(
+                 errors::Cancelled("RecvAsync is cancelled.")),
+             Args(), recv_args, Tensor(), /*is_dead=*/false);
+        return;
+      }
+
       VLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
       Item* item = new Item;
       item->waiter = std::move(done);
       item->recv_args = recv_args;
+      item->cancellation_token = token;
       if (item->recv_args.device_context) {
         item->recv_args.device_context->Ref();
       }
@@ -239,7 +282,7 @@
 
     // Delete the queue when the last element has been consumed.
     if (queue->size() == 1) {
-      VLOG(2) << "Clean up Send/Recv queu (key:" << key.FullKey() << "). ";
+      VLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
       table_.erase(key_hash);
     } else {
       queue->pop_front();
@@ -280,6 +323,7 @@
     bool is_dead = false;
     Args send_args;
     Args recv_args;
+    CancellationToken cancellation_token;
 
     ~Item() {
       if (send_args.device_context) {
@@ -288,6 +332,11 @@
       if (recv_args.device_context) {
         recv_args.device_context->Unref();
       }
+      auto* cm = recv_args.cancellation_manager;
+      if (cancellation_token != CancellationManager::kInvalidToken &&
+          cm != nullptr) {
+        cm->TryDeregisterCallback(cancellation_token);
+      }
     }
 
     // Returns true iff this item represents a value being sent.
diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h
index 01e43e4..84e2f6a 100644
--- a/tensorflow/core/framework/rendezvous.h
+++ b/tensorflow/core/framework/rendezvous.h
@@ -18,6 +18,7 @@
 
 #include <string>
 
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -48,6 +49,7 @@
   struct Args {
     DeviceContext* device_context = nullptr;
     AllocatorAttributes alloc_attrs;
+    CancellationManager* cancellation_manager = nullptr;  // not owned.
   };
 
   // Constructs a rendezvous key for the tensor of "name" sent from
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
index 8f16c6f..da9a1fb 100644
--- a/tensorflow/core/framework/rendezvous_test.cc
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/framework/rendezvous.h"
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -29,6 +30,7 @@
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/platform/types.h"
@@ -84,7 +86,7 @@
 // string -> Tensor<string>
 Tensor V(const string& content) {
   Tensor tensor(DT_STRING, TensorShape({}));
-  tensor.scalar<string>()() = content;
+  tensor.scalar<tstring>()() = content;
   return tensor;
 }
 
@@ -92,7 +94,7 @@
 string V(const Tensor& tensor) {
   CHECK_EQ(tensor.dtype(), DT_STRING);
   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
-  return tensor.scalar<string>()();
+  return tensor.scalar<tstring>()();
 }
 
 Rendezvous::ParsedKey MakeKey(const string& name) {
@@ -153,6 +155,126 @@
   EXPECT_EQ("secret msg", V(val));
 }
 
+TEST_F(LocalRendezvousTest, CancelBeforeRecv) {
+  auto* cm = new CancellationManager();
+  Tensor val(DT_STRING);
+  bool is_dead = false;
+  Rendezvous::Args args;
+  args.cancellation_manager = cm;
+  cm->StartCancel();
+  auto s = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(errors::IsCancelled(s));
+  EXPECT_EQ("[_Derived_]RecvAsync is cancelled.", s.error_message());
+  delete cm;
+}
+
+TEST_F(LocalRendezvousTest, CancelAfterRecv) {
+  auto* cm = new CancellationManager();
+  Notification n;
+  SchedClosure([cm, &n]() {
+    Env::Default()->SleepForMicroseconds(10000);
+    cm->StartCancel();
+    n.Notify();
+  });
+  Tensor val(DT_STRING);
+  bool is_dead = false;
+  Rendezvous::Args args;
+  args.cancellation_manager = cm;
+  auto s = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(errors::IsCancelled(s));
+  EXPECT_EQ("[_Derived_]RecvAsync is cancelled.", s.error_message());
+  n.WaitForNotification();
+  delete cm;
+}
+
+TEST_F(LocalRendezvousTest, CancelEmptyQueue) {
+  auto* cm = new CancellationManager();
+  Notification n;
+  SchedClosure([this, cm, &n]() {
+    Env::Default()->SleepForMicroseconds(10000);
+    Rendezvous::Args args;
+    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+    cm->StartCancel();
+    n.Notify();
+  });
+  Tensor val(DT_STRING);
+  bool is_dead = false;
+  Rendezvous::Args args;
+  args.cancellation_manager = cm;
+  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
+  EXPECT_EQ("hello", V(val));
+  n.WaitForNotification();
+  delete cm;
+}
+
+TEST_F(LocalRendezvousTest, CancelMultiple) {
+  auto* cm = new CancellationManager();
+  SchedClosure([this, cm]() {
+    Env::Default()->SleepForMicroseconds(10000);
+    Rendezvous::Args args;
+    cm->StartCancel();
+    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+  });
+  Tensor val(DT_STRING);
+  Rendezvous::Args args;
+  Rendezvous::Args args_with_cancellation;
+  args_with_cancellation.cancellation_manager = cm;
+  Notification n0;
+  Notification n1;
+  Notification n2;
+  Notification n3;
+  Status s0;
+  Status s1;
+  Status s2;
+  Status s3;
+
+  rendez_->RecvAsync(
+      KeyFoo(), args,
+      [&n0, &s0](const Status& s, const Rendezvous::Args& send_args,
+                 const Rendezvous::Args& recv_args, const Tensor& v,
+                 const bool dead) {
+        s0.Update(s);
+        n0.Notify();
+      });
+  rendez_->RecvAsync(
+      KeyFoo(), args_with_cancellation,
+      [&n1, &s1](const Status& s, const Rendezvous::Args& send_args,
+                 const Rendezvous::Args& recv_args, const Tensor& v,
+                 const bool dead) {
+        s1.Update(s);
+        n1.Notify();
+      });
+  rendez_->RecvAsync(
+      KeyFoo(), args,
+      [&n2, &s2](const Status& s, const Rendezvous::Args& send_args,
+                 const Rendezvous::Args& recv_args, const Tensor& v,
+                 const bool dead) {
+        s2.Update(s);
+        n2.Notify();
+      });
+  rendez_->RecvAsync(
+      KeyFoo(), args_with_cancellation,
+      [&n3, &s3](const Status& s, const Rendezvous::Args& send_args,
+                 const Rendezvous::Args& recv_args, const Tensor& v,
+                 const bool dead) {
+        s3.Update(s);
+        n3.Notify();
+      });
+  n0.WaitForNotification();
+  n1.WaitForNotification();
+  n2.WaitForNotification();
+  n3.WaitForNotification();
+  TF_ASSERT_OK(s0);
+  TF_ASSERT_OK(s2);
+  EXPECT_FALSE(s1.ok());
+  EXPECT_FALSE(s3.ok());
+
+  delete cm;
+}
+
 // A simple structure that behaves a bit like a blocking counter.  The
 // user that decrements counter to 0 does done.Notify(), and the main
 // thread waits for done to be notified.
@@ -331,6 +453,7 @@
 
 void BM_PingPong(int iters) {
   CHECK_GT(iters, 0);
+  auto* cm = new CancellationManager();
   thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
 
   // The main thread sends "foo" for iters times and receives "bar"
@@ -352,12 +475,14 @@
   Tensor bar(DT_STRING, TensorShape({}));
   bool is_dead = false;
   Rendezvous::Args args;
+  args.cancellation_manager = cm;
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
     TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
   }
   CHECK_EQ("bar", V(bar));
   delete pool;
+  delete cm;
 }
 BENCHMARK(BM_PingPong);
 
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 301fe68..67ea803 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -639,8 +639,8 @@
           "Resource handle must have 2 elements, but had shape: ",
           tensor.shape().DebugString());
     }
-    container = tensor.flat<string>()(0);
-    shared_name = tensor.flat<string>()(1);
+    container = tensor.flat<tstring>()(0);
+    shared_name = tensor.flat<tstring>()(1);
   }
   return ctx->resource_manager()->Lookup(container, shared_name, resource);
 }
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index fbcd439..60e9703 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -96,7 +96,7 @@
       }
 
       if (!has_resource_type_) {
-        auto h = handle_.AccessTensor(context)->template flat<string>();
+        auto h = handle_.AccessTensor(context)->template flat<tstring>();
         h(0) = cinfo_.container();
         h(1) = cinfo_.name();
       }
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index c2b3b7d..5d3cc57 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -515,7 +515,12 @@
   if (in_n <= 0) {
     std::fill_n(data, n, Variant());
   } else {
-    for (int64 i = 0; i < in_n; ++i) {
+    // If tensor shape says we have n < in_n elements in the output tensor
+    // then make sure to only decode the first n out of the in_n elements in the
+    // in tensors. In all other cases, we decode all in_n elements of in and set
+    // the remaining elements up to n to be the default Variant() value.
+    const int64 real_n = n < in_n ? n : in_n;
+    for (int64 i = 0; i < real_n; ++i) {
       data[i] = in.variant_val(i);
       if (!DecodeUnaryVariant(&data[i])) {
         LOG(ERROR) << "Could not decode variant with type_name: \""
@@ -986,7 +991,7 @@
     for (int64 i = 0; i < element_count; i++) {
       if (*data_index >= limit) {
         // If not enough elements has been printed, append "...".
-        if (dim_index != 0 && i < element_count) {
+        if (dim_index != 0) {
           strings::StrAppend(result, "...");
         }
         return;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index d4aed38..dd4ca70 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -480,7 +480,7 @@
 
   Tensor string_tensor{DT_STRING, {10}};
   // Note that the error message compare # of elements, not # of bytes.
-  EXPECT_DEATH((string_tensor.bit_casted_shaped<string, 1>({9})), "9 vs. 10");
+  EXPECT_DEATH((string_tensor.bit_casted_shaped<tstring, 1>({9})), "9 vs. 10");
 }
 
 TEST_F(TensorReshapeTest, Flat) {
@@ -795,27 +795,27 @@
   {
     Tensor t(DT_STRING, TensorShape({}));
     EXPECT_EQ(1, t.NumElements());
-    auto Tt = t.scalar<string>();
+    auto Tt = t.scalar<tstring>();
     EXPECT_EQ(1, Tt.size());
     EXPECT_EQ(0, Tt.rank());
-    t.scalar<string>()() = "foo";
+    t.scalar<tstring>()() = "foo";
     EXPECT_EQ("foo", Tt());
   }
   {
     Tensor t(DT_STRING, TensorShape({1}));
     EXPECT_EQ(1, t.NumElements());
-    auto Tt = t.vec<string>();
+    auto Tt = t.vec<tstring>();
     EXPECT_EQ(1, Tt.size());
-    t.flat<string>()(0) = "foo";
+    t.flat<tstring>()(0) = "foo";
     EXPECT_EQ("foo", Tt(0));
   }
   {
     Tensor t(DT_STRING, TensorShape({1, 1, 1}));
     EXPECT_EQ(1, t.NumElements());
-    auto Tt = t.scalar<string>();
+    auto Tt = t.scalar<tstring>();
     EXPECT_EQ(1, Tt.size());
     EXPECT_EQ(0, Tt.rank());
-    t.flat<string>()(0) = "bar";
+    t.flat<tstring>()(0) = "bar";
     EXPECT_EQ("bar", Tt());
   }
   {
@@ -860,7 +860,7 @@
     Tensor t("fooooooooooooooooooooooooooooooooooooo");
     EXPECT_EQ(DT_STRING, t.dtype());
     EXPECT_EQ(1, t.NumElements());
-    auto Tt = t.scalar<string>();
+    auto Tt = t.scalar<tstring>();
     EXPECT_EQ(1, Tt.size());
     EXPECT_EQ(0, Tt.rank());
     EXPECT_EQ("fooooooooooooooooooooooooooooooooooooo", Tt());
@@ -980,7 +980,7 @@
   Tensor t2(DT_STRING, {2, 3});
   for (int i = 0; i < 2; ++i) {
     for (int j = 0; j < 3; ++j) {
-      t2.matrix<string>()(i, j) = strings::StrCat(i * 3 + j);
+      t2.matrix<tstring>()(i, j) = strings::StrCat(i * 3 + j);
     }
   }
 
@@ -1163,7 +1163,7 @@
   // String
   {
     Tensor t(DT_STRING, TensorShape({1}));
-    t.vec<string>()(0) = "foo";
+    t.vec<tstring>()(0) = "foo";
     TensorProto proto;
     t.AsProtoField(&proto);
 
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
index c87cc95..2e99626 100644
--- a/tensorflow/core/framework/tensor_util.cc
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -48,7 +48,7 @@
              input_data.size());
     }
   } else if (input.dtype() == DT_STRING) {
-    output->unaligned_flat<string>() = input.unaligned_flat<string>();
+    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
   } else {
     CHECK_EQ(DT_VARIANT, input.dtype());
     output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
@@ -103,7 +103,7 @@
 
     int64 offset = 0;
     for (const Tensor& tensor : tensors) {
-      auto from_strings = tensor.flat<string>();
+      auto from_strings = tensor.flat<tstring>();
       CHECK_LE(offset + tensor.NumElements(), result->NumElements());
       for (int i = 0; i < tensor.NumElements(); ++i) {
         to_strings[offset + i] = from_strings(i);
@@ -155,7 +155,7 @@
     if (tensor.dtype() != DT_STRING) {
       return errors::Internal("Unexpected data type");
     }
-    auto from_strings = tensor.flat<string>();
+    auto from_strings = tensor.flat<tstring>();
 
     int64 offset = 0;
     for (int64 size : sizes) {
diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc
index 4470876..fe98801 100644
--- a/tensorflow/core/framework/tensor_util_test.cc
+++ b/tensorflow/core/framework/tensor_util_test.cc
@@ -111,12 +111,12 @@
 
   // Test string deep copy
   Tensor str1(DT_STRING, TensorShape({2}));
-  str1.flat<string>()(0) = "foo1";
-  str1.flat<string>()(1) = "foo2";
+  str1.flat<tstring>()(0) = "foo1";
+  str1.flat<tstring>()(1) = "foo2";
   Tensor str2 = tensor::DeepCopy(str1);
-  str2.flat<string>()(0) = "bar1";
-  str2.flat<string>()(1) = "bar2";
-  EXPECT_NE(str2.flat<string>()(0), str1.flat<string>()(0));
+  str2.flat<tstring>()(0) = "bar1";
+  str2.flat<tstring>()(1) = "bar2";
+  EXPECT_NE(str2.flat<tstring>()(0), str1.flat<tstring>()(0));
 }
 
 TEST(TensorUtil, DeepCopySlice) {
@@ -151,7 +151,7 @@
 
 TEST(TensorUtil, DeepCopySliceString) {
   Tensor x(DT_STRING, TensorShape({10}));
-  x.flat<string>().setConstant("hello");
+  x.flat<tstring>().setConstant("hello");
 
   // Slice 'x' -- y still refers to the same buffer.
   Tensor y = x.Slice(3, 7);
@@ -160,7 +160,7 @@
   Tensor z = tensor::DeepCopy(y);
 
   // Set x to be different.
-  x.flat<string>().setConstant("goodbye");
+  x.flat<tstring>().setConstant("goodbye");
 
   EXPECT_EQ(TensorShape({10}), x.shape());
   EXPECT_EQ(TensorShape({4}), y.shape());
@@ -171,11 +171,11 @@
 
   // x and y should now all be 'goodbye', but z should be 'hello'.
   for (int i = 0; i < 10; ++i) {
-    EXPECT_EQ("goodbye", x.flat<string>()(i));
+    EXPECT_EQ("goodbye", x.flat<tstring>()(i));
   }
   for (int i = 0; i < 4; ++i) {
-    EXPECT_EQ("goodbye", y.unaligned_flat<string>()(i));
-    EXPECT_EQ("hello", z.flat<string>()(i));
+    EXPECT_EQ("goodbye", y.unaligned_flat<tstring>()(i));
+    EXPECT_EQ("hello", z.flat<tstring>()(i));
   }
 }
 
@@ -202,11 +202,12 @@
   // Each element of x and y should now be a DT_STRING Tensor containing "foo",
   // but each element of z should be a DT_FLOAT tensor containing 42.0.
   for (int i = 0; i < 10; ++i) {
-    EXPECT_EQ("foo", x.flat<Variant>()(i).get<Tensor>()->scalar<string>()());
+    EXPECT_EQ("foo", x.flat<Variant>()(i).get<Tensor>()->scalar<tstring>()());
   }
   for (int i = 0; i < 4; ++i) {
-    EXPECT_EQ("foo",
-              y.unaligned_flat<Variant>()(i).get<Tensor>()->scalar<string>()());
+    EXPECT_EQ(
+        "foo",
+        y.unaligned_flat<Variant>()(i).get<Tensor>()->scalar<tstring>()());
     EXPECT_EQ(42.0, z.flat<Variant>()(i).get<Tensor>()->scalar<float>()());
   }
 }
@@ -271,7 +272,7 @@
 TEST(TensorUtil, ConcatSplitStrings) {
   Tensor x(DT_STRING, TensorShape({4, 3}));
   for (int i = 0; i < 4 * 3; ++i) {
-    x.flat<string>()(i) = strings::StrCat("foo_", i);
+    x.flat<tstring>()(i) = strings::StrCat("foo_", i);
   }
 
   std::vector<Tensor> split;
@@ -280,15 +281,15 @@
   TF_ASSERT_OK(tensor::Concat(split, &x_round_tripped));
   ASSERT_EQ(x.shape(), x_round_tripped.shape());
   for (int i = 0; i < 4 * 3; ++i) {
-    EXPECT_EQ(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
+    EXPECT_EQ(x.flat<tstring>()(i), x_round_tripped.flat<tstring>()(i));
   }
 
   // Ensure that no memory is being shared between 'x' and 'x_round_tripped'.
   for (int i = 0; i < 4 * 3; ++i) {
-    x_round_tripped.flat<string>()(i) = strings::StrCat("bar_", i);
+    x_round_tripped.flat<tstring>()(i) = strings::StrCat("bar_", i);
   }
   for (int i = 0; i < 4 * 3; ++i) {
-    EXPECT_NE(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
+    EXPECT_NE(x.flat<tstring>()(i), x_round_tripped.flat<tstring>()(i));
   }
 }
 
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 7a58c10..e09ea26 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -391,7 +391,7 @@
 MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
 MATCH_TYPE_AND_ENUM(int16, DT_INT16);
 MATCH_TYPE_AND_ENUM(int8, DT_INT8);
-MATCH_TYPE_AND_ENUM(string, DT_STRING);
+MATCH_TYPE_AND_ENUM(tstring, DT_STRING);
 MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
 MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
 MATCH_TYPE_AND_ENUM(int64, DT_INT64);
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 25cddc0..19226d2 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -244,7 +244,7 @@
   // Create the input StoredTensorValue and serialize it.
   StoredTensorValue from;
   from.stored = Tensor(DT_STRING, TensorShape({}));
-  from.stored.scalar<string>()() = "hi";
+  from.stored.scalar<tstring>()() = "hi";
   VariantTensorData data;
   data.set_type_name(from.TypeName());
   from.Encode(&data);
@@ -292,7 +292,7 @@
 TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) {
   Scope root = Scope::NewRootScope().WithDevice("/cpu:0");
   Tensor t_str(DT_STRING, TensorShape({}));
-  t_str.scalar<string>()() = "hi";
+  t_str.scalar<tstring>()() = "hi";
   Output create_op = CreateTestVariant(root, t_str);
   Output identity = ops::Identity(root, create_op);
 
@@ -309,7 +309,7 @@
     EXPECT_EQ("StoredTensorValue", r1.TypeName());
     const StoredTensorValue* v1 = r1.get<StoredTensorValue>();
     EXPECT_NE(v1, nullptr);
-    EXPECT_EQ("hi", v1->stored.scalar<string>()());
+    EXPECT_EQ("hi", v1->stored.scalar<tstring>()());
   }
 }
 
@@ -356,7 +356,7 @@
   Scope root = Scope::NewRootScope().WithDevice("/cpu:0");
   Scope with_gpu = root.WithDevice("/gpu:0");
   Tensor t_str(DT_STRING, TensorShape({}));
-  t_str.scalar<string>()() = "hi";
+  t_str.scalar<tstring>()() = "hi";
   Output create_op = CreateTestVariant(root, t_str);
   Output identity = ops::Identity(with_gpu, create_op);
 
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index b5107a0..608f368 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -13,13 +13,15 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/variant_op_registry.h"
+
 #include <string>
 
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/type_index.h"
 #include "tensorflow/core/framework/variant.h"
-#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
@@ -56,6 +58,18 @@
 }
 
 bool DecodeUnaryVariant(Variant* variant) {
+  CHECK_NOTNULL(variant);
+  if (variant->TypeName().empty()) {
+    VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
+    if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
+      // Malformed variant.
+      return false;
+    } else {
+      // Serialization of an empty Variant.
+      variant->clear();
+      return true;
+    }
+  }
   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
   if (decode_fn == nullptr) {
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 0a4874a..6f40cd1 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -118,13 +118,30 @@
   v.Encode(&data);
   VariantTensorDataProto proto;
   data.ToProto(&proto);
-  Variant encoded = proto;
+  Variant encoded = std::move(proto);
   EXPECT_TRUE((*decode_fn)(&encoded));
   VariantValue* decoded = encoded.get<VariantValue>();
   EXPECT_NE(decoded, nullptr);
   EXPECT_EQ(decoded->early_exit, true);
 }
 
+TEST(VariantOpDecodeRegistryTest, TestEmpty) {
+  VariantTensorDataProto empty_proto;
+  Variant empty_encoded = std::move(empty_proto);
+  EXPECT_TRUE(DecodeUnaryVariant(&empty_encoded));
+  EXPECT_TRUE(empty_encoded.is_empty());
+
+  VariantTensorData data;
+  Variant number = 3.0f;
+  number.Encode(&data);
+  VariantTensorDataProto proto;
+  data.ToProto(&proto);
+  proto.set_type_name("");
+  Variant encoded = std::move(proto);
+  // Failure when type name is empty but there's data in the proto.
+  EXPECT_FALSE(DecodeUnaryVariant(&encoded));
+}
+
 TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
   UnaryVariantOpRegistry registry;
   UnaryVariantOpRegistry::VariantDecodeFn f;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index cc8e18a..c24cac5 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/graph/graph.h"
 
 #include <vector>
+
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -85,11 +86,14 @@
         {"CollectiveReduce", NC_COLLECTIVE},
         {"CollectiveBcastSend", NC_COLLECTIVE},
         {"CollectiveBcastRecv", NC_COLLECTIVE},
+        {"CollectiveGather", NC_COLLECTIVE},
         {"FakeParam", NC_FAKE_PARAM},
         {"PartitionedCall", NC_PARTITIONED_CALL},
         {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
         {"If", NC_IF},
         {"StatelessIf", NC_IF},
+        {"While", NC_WHILE},
+        {"StatelessWhile", NC_WHILE},
         // Not using the constants defined in FunctionLibraryDefinition for the
         // 4 ops below because android inference library does not link
         // tf.function related files.
@@ -592,7 +596,7 @@
 }
 
 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
-  if (dst->type_string() != "While") {
+  if (!dst->IsWhileNode()) {
     return errors::Internal(
         "dst argument to AddWhileEdgeHack should be a While op, got: ",
         dst->DebugString());
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 1d9a45b..0fe7f86 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -177,6 +177,7 @@
   bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
   bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
   bool IsIfNode() const { return class_ == NC_IF; }
+  bool IsWhileNode() const { return class_ == NC_WHILE; }
   // Is this node a function input
   bool IsArg() const { return class_ == NC_ARG; }
   // Is this node a function output
@@ -188,6 +189,11 @@
     UpdateProperties();
   }
 
+  void AddAttr(const string& name, std::vector<string>&& val) {
+    MoveAttrValue(std::move(val), AddAttrHelper(name));
+    UpdateProperties();
+  }
+
   void ClearAttr(const string& name);
 
   // Returns into '*e' the edge connecting to the 'idx' input of this Node.
@@ -264,6 +270,7 @@
     NC_FAKE_PARAM,
     NC_PARTITIONED_CALL,
     NC_IF,
+    NC_WHILE,
     NC_ARG,
     NC_RETVAL,
     NC_OTHER  // Not a special kind of node
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 48c6639..b462ab3 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -728,9 +728,9 @@
   if (!opts_.importing || !opts_.validate_shape) return Status::OK();
   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
   // For nodes with the _output_shapes attribute, override the shape.
-  std::vector<TensorShapeProto> shape_attrs;
+  std::vector<const TensorShapeProto*> shape_attrs;
   const char* kAttrName = "_output_shapes";
-  if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
+  if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
     // No _output_shapes attribute, the AddNode call above was sufficient.
     return Status::OK();
   }
@@ -753,7 +753,7 @@
                  << " outputs. Output shapes may be inaccurate.";
   }
   for (int i = 0; i < node->num_outputs(); ++i) {
-    const TensorShapeProto& p = shape_attrs[i];
+    const TensorShapeProto& p = *shape_attrs[i];
     shape_inference::ShapeHandle h;
     Status s = ic->MakeShapeFromShapeProto(p, &h);
     if (!s.ok()) {
@@ -772,7 +772,6 @@
       // This is an escape hatch that allows us to correct shape
       // functions that are not critical to correct execution but
       // would cause graphs to fail if imported after correcting.
-      //
       const string& op = node->type_string();
       const std::vector<string> whitelist = {
           // To be removed after 2017/03/08.
@@ -991,11 +990,10 @@
     Node* node = pair.second.node;
     if (node == nullptr) continue;
     std::vector<string> coloc_values;
-    Status status =
-        GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
-    if (!status.ok()) continue;
+    if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
+      continue;
     bool updated = false;
-    for (int i = 0; i < coloc_values.size(); ++i) {
+    for (size_t i = 0; i < coloc_values.size(); ++i) {
       StringPiece val(coloc_values[i]);
       if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
         auto name_pair = uniquified_names_.find(string(val));
@@ -1006,7 +1004,7 @@
       }
     }
     if (updated) {
-      node->AddAttr(kColocationAttrName, coloc_values);
+      node->AddAttr(kColocationAttrName, std::move(coloc_values));
     }
   }
 }
@@ -1182,10 +1180,19 @@
       }
 
       if (src_node != nullptr && src_index >= src_node->num_outputs()) {
-        return errors::InvalidArgument(
-            "Node '", node_def.name(), "': Connecting to invalid output ",
-            tensor_id.index(), " of source node ", tensor_id.node(),
-            " which has ", src_node->num_outputs(), " outputs");
+        std::ostringstream out;
+        out << "Node '" << node_def.name() << "': Connecting to invalid output "
+            << tensor_id.index() << " of source node " << tensor_id.node()
+            << " which has " << src_node->num_outputs() << " outputs.";
+
+        if (src_node->type_string() == "If" ||
+            src_node->type_string() == "StatelessIf" ||
+            src_node->type_string() == "While" ||
+            src_node->type_string() == "StatelessWhile") {
+          out << " Try using "
+              << "tf.compat.v1.experimental.output_all_intermediates(True).";
+        }
+        return errors::InvalidArgument(out.str());
       }
 
       inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
diff --git a/tensorflow/core/graph/graph_def_builder_util.cc b/tensorflow/core/graph/graph_def_builder_util.cc
index 102c721..3ca9f8a 100644
--- a/tensorflow/core/graph/graph_def_builder_util.cc
+++ b/tensorflow/core/graph/graph_def_builder_util.cc
@@ -22,7 +22,7 @@
   GraphDef graph_def;
   TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def));
   GraphConstructorOptions opts;
-  return ConvertGraphDefToGraph(opts, graph_def, graph);
+  return ConvertGraphDefToGraph(opts, std::move(graph_def), graph);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 1c906a3..b295085 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -947,13 +947,13 @@
     // Not related to send/recv.
     return;
   }
-  string send_device;
-  if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
+  const string& send_device = GetNodeAttrString(*ndef, "send_device");
+  if (send_device.empty()) {
     // No known send_device. The runtime will detect it later.
     return;
   }
   int64 incarnation = PartitionOptions::kIllegalIncarnation;
-  if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
+  if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
       (incarnation == PartitionOptions::kIllegalIncarnation)) {
     incarnation = opts.get_incarnation(send_device);
     SetAttrValue(incarnation,
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index c204dd0..cb4afab 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -104,12 +104,24 @@
 
 // Prefix that we add to Tensorflow op name to construct Mkl op name.
 static const char* const kMklOpPrefix = "_Mkl";
+// TODO(intel-tf): PR review feedback (penpornk)
+// Can we add eager_mode (or is_eager) as an op attribute instead?
+// This way we don't need to rename the op just to pass eager_mode
+// through template parameter.
+static const char* const kMklEagerOpPrefix = "_MklEager";
 
 // Get the name of Mkl op from original TensorFlow op
 // We prefix 'Mkl' to the original op to get Mkl op.
 inline string GetMklOpName(const string& name) {
   return string(kMklOpPrefix) + name;
 }
+
+// Get the name of Mkl Eager op from original TensorFlow op
+// We prefix 'MklEager' to the original op to get Mkl Eager op.
+inline string GetMklEagerOpName(const string& name) {
+  return string(kMklEagerOpPrefix) + name;
+}
+
 // Check whether opname with type T is registered as MKL operator
 // that can accept input tensors in MKL layout.
 //
@@ -177,6 +189,11 @@
   return IsMklLayoutDependentOp(op_name, T) || IsMklNameChangeOp(op_name, T);
 }
 
+static inline bool IsMklOp(const Node* n) {
+  DataType T;
+  return GetNodeAttr(n->def(), "T", &T).ok() && IsMklOp(n->type_string(), T);
+}
+
 // Check whether opname with type T is registered as MKL-compliant and
 // is element-wise.
 //
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 8812ec9..c97cbd8 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -353,47 +353,49 @@
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
     csinfo_.sub = "Sub";
-    // End - element-wise ops. See note above.
+// End - element-wise ops. See note above.
 
-    // NOTE: names are alphabetically sorted.
+// NOTE: names are alphabetically sorted.
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
-                      CopyAttrsAddN, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
-                      CopyAttrsDataType, AlwaysRewrite,
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
-         CopyAttrsPooling, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.avg_pool_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
-                      CopyAttrsPooling, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.avg_pool3d, mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
-         CopyAttrsPooling, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.avg_pool3d_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
-                      CopyAttrsPooling, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.batch_matmul,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
-                      CopyAttrsBatchMatMul, AlwaysRewrite,
-                      kRewriteForOpNameChange});
+                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
     rinfo_.push_back(
         {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
-         CopyAttrsConcat, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.concatv2, mkl_op_registry::GetMklOpName(csinfo_.concatv2),
-         CopyAttrsConcatV2, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.conjugate_transpose,
          mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
-         CopyAttrsTranspose, AlwaysRewrite, kRewriteForOpNameChange});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+#endif  // !ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.conv2d,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
@@ -427,90 +429,87 @@
     rinfo_.push_back(
         {csinfo_.depthwise_conv2d_grad_input,
          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
-         CopyAttrsConv2DDepthwise, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.depthwise_conv2d_grad_filter,
          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
-         CopyAttrsConv2DDepthwise, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
-         CopyAttrsDequantize, DequantizeRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, DequantizeRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.fused_batch_norm,
                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_grad,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
-         CopyAttrsFusedBatchNorm, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_v2,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
-         CopyAttrsFusedBatchNormV2, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_grad_v2,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
-         CopyAttrsFusedBatchNormV2, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
 
-    // Using CopyAttrsFusedBatchNormV2 for V3 on CPU, as there are no additional
+    // Using CopyAttrsAll for V3 on CPU, as there are no additional
     // attributes.
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_v3,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
-         CopyAttrsFusedBatchNormV2, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_grad_v3,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
-         CopyAttrsFusedBatchNormV2, AlwaysRewrite,
-         kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
 
     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                       CopyAttrsFusedConv2D, FusedConv2DRewrite,
                       kRewriteForLayoutPropagation});
-    rinfo_.push_back(
-        {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.identity,
+                      mkl_op_registry::GetMklOpName(csinfo_.identity),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
+
     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
-                      CopyAttrsLRN, LrnRewrite, kRewriteForLayoutPropagation});
+                      CopyAttrsAll, LrnRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
-         CopyAttrsLRN, LrnGradRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.matmul,
                       mkl_op_registry::GetMklOpName(csinfo_.matmul),
-                      CopyAttrsMatMul, AlwaysRewrite, kRewriteForOpNameChange});
+                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
     rinfo_.push_back(
         {csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
-         CopyAttrsLeakyRelu, LeakyReluRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.leakyrelu_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
-                      CopyAttrsLeakyRelu, LeakyReluRewrite,
+                      CopyAttrsAll, LeakyReluRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.max_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool),
-                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite,
+                      CopyAttrsAll, NonDepthBatchWisePoolRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.max_pool_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
-                      CopyAttrsPooling, MaxpoolGradRewrite,
+                      CopyAttrsAll, MaxpoolGradRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.max_pool3d,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
-                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite,
+                      CopyAttrsAll, NonDepthBatchWisePoolRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.max_pool3d_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
-                      CopyAttrsPooling, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
-    rinfo_.push_back(
-        {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.maximum,
+                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
-                      CopyAttrsDataType, AlwaysRewrite,
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
                       CopyAttrsPadWithConv2D, AlwaysRewrite,
@@ -521,11 +520,11 @@
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantized_avg_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
-                      CopyAttrsQuantizedPooling, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantized_concatv2,
                       mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
-                      CopyAttrsConcatV2, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantized_conv2d,
                       mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
@@ -574,7 +573,7 @@
          kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantized_max_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
-                      CopyAttrsQuantizedPooling, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                       mkl_op_registry::GetMklOpName(
@@ -631,55 +630,58 @@
          kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.quantize_v2,
                       mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
-                      CopyAttrsQuantizeV2, QuantizeOpRewrite,
+                      CopyAttrsAll, QuantizeOpRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
-                      CopyAttrsDataType, AlwaysRewrite,
+                      CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.relu6, mkl_op_registry::GetMklOpName(csinfo_.relu6),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
-         CopyAttrsRequantize, AlwaysRewrite, kRewriteForLayoutPropagation});
-    // Disable these two MKL operators for now due to some test failures caused
-    // by these two ops
-    /*
-    rinfo_.push_back({csinfo_.tanh,
-                      mkl_op_registry::GetMklOpName(csinfo_.tanh),
-                      CopyAttrsDataType, AlwaysRewrite,
-                      kRewriteForLayoutPropagation});
-    rinfo_.push_back({csinfo_.tanh_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
-                      CopyAttrsDataType, AlwaysRewrite,
-                      kRewriteForLayoutPropagation});
-    */
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#endif  // !ENABLE_MKLDNN_V1
+// Disable these two MKL operators for now due to some test failures caused
+// by these two ops
+/*
+rinfo_.push_back({csinfo_.tanh,
+                  mkl_op_registry::GetMklOpName(csinfo_.tanh),
+                  CopyAttrsAll, AlwaysRewrite,
+                  kRewriteForLayoutPropagation});
+rinfo_.push_back({csinfo_.tanh_grad,
+                  mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
+                  CopyAttrsAll, AlwaysRewrite,
+                  kRewriteForLayoutPropagation});
+*/
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back(
         {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
-         CopyAttrsReshape, AlwaysRewrite, kRewriteForLayoutPropagation});
-    rinfo_.push_back(
-        {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
-         CopyAttrsSlice, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.slice,
+                      mkl_op_registry::GetMklOpName(csinfo_.slice),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax),
-         CopyAttrsDataType, AlwaysRewrite, kRewriteForLayoutPropagation});
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
 
     rinfo_.push_back({csinfo_.squared_difference,
                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
-                      CopyAttrsDataType, AlwaysRewrite,
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
-                      CopyAttrsDataType, AlwaysRewrite,
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
-    rinfo_.push_back(
-        {csinfo_.transpose, mkl_op_registry::GetMklOpName(csinfo_.transpose),
-         CopyAttrsTranspose, AlwaysRewrite, kRewriteForOpNameChange});
+    rinfo_.push_back({csinfo_.transpose,
+                      mkl_op_registry::GetMklOpName(csinfo_.transpose),
+                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
 
     // Add info about which ops to add workspace edge to and the slots.
     wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -760,6 +762,7 @@
          // CheckForMklOp
          FuseConv3D,
          CopyAttrsConv});
+#endif  // !ENABLE_MKLDNN_V1
   }
 
   // Standard interface to run pass
@@ -1381,6 +1384,38 @@
   // @return - true (since we want to always rewrite)
   static bool AlwaysRewrite(const Node* n) { return true; }
 
+  // Rewrite rule which considers "context" of the current node to decide if we
+  // should rewrite. By "context" we currently mean all the inputs of current
+  // node. The idea is if none of the inputs of current node are not MKL nodes,
+  // then rewriting current node to MKL node _may not_ offer any performance
+  // improvement.
+  //
+  // One such case is element-wise ops. For such ops, we reuse the Eigen
+  // implementation and pass the MKL metadata tensor through so we can avoid
+  // conversions. However, if all incoming edges are in TF format, we don't
+  // need all this overhead, so replace the elementwise node only if at least
+  // one of its parents is a MKL node.
+  //
+  // More generally, all memory- or IO-bound ops (such as Identity) may fall
+  // under this category.
+  //
+  // @input - Input graph node to be rewritten
+  // @return - true if node is to be rewritten as MKL node; false otherwise.
+  static bool RewriteIfAtleastOneMklInput(const Node* n) {
+    DataType T;
+    if (GetNodeAttr(n->def(), "T", &T).ok() &&
+        mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+      for (auto e : n->in_edges()) {
+        if (e->IsControlEdge()) continue;
+        if (mkl_op_registry::IsMklOp(e->src())) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   static bool DequantizeRewrite(const Node* n) {
     DCHECK(n);
     Node* input = nullptr;
@@ -1480,7 +1515,7 @@
     DCHECK(n);
 
     float alpha;
-    bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
+    bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
     DCHECK(has_attr);
 
     // If the alpha of LeakyRelu is less than 1, rewrite the node.
@@ -1543,7 +1578,7 @@
     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
     // it includes those we support.
     DataType T;
-    if (!GetNodeAttr(n->def(), "T", &T).ok() ||
+    if (!TryGetNodeAttr(n->def(), "T", &T) ||
         !mkl_op_registry::IsMklLayoutDependentOp(csinfo_.mkl_fused_conv2d, T)) {
       return false;
     }
@@ -1757,41 +1792,17 @@
   // We need operator-specific function to copy attributes because the framework
   // does not provide any generic function for it.
   // NOTE: names are alphabetically sorted.
-  static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
-                            bool change_format = false);
-  static void CopyAttrsBatchMatMul(const Node* orig_node, NodeBuilder* nb,
-                                   bool change_format = false);
-  static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb,
-                                   bool change_format = false);
-  static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb,
-                              bool change_format = false);
-  static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb,
-                                bool change_format = false);
+  static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
+                           bool change_format = false);
   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
                             bool change_format = false);
-  static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
-                                       bool change_format = false);
   static void CopyAttrsConv2DDepthwiseCheckConstFilter(
       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
                                             NodeBuilder* nb,
                                             bool change_format = false);
-  static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb,
-                                bool change_format = false);
-  static void CopyAttrsDequantize(const Node* orig_node, NodeBuilder* nb,
-                                  bool change_format = false);
-  static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
-                                      bool change_format = false);
-  static void CopyAttrsFusedBatchNormV2(const Node* orig_node, NodeBuilder* nb,
-                                        bool change_format = false);
-  static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
-                                 bool change_format = false);
   static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
                                    bool change_format = false);
-  static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
-                           bool change_format = false);
-  static void CopyAttrsMatMul(const Node* orig_node, NodeBuilder* nb,
-                              bool change_format = false);
   static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
                                      bool change_format = false);
   static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
@@ -1804,26 +1815,8 @@
                                              const Node* orig_node2,
                                              NodeBuilder* nb,
                                              bool change_format = false);
-  static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
-                               bool change_format = false);
-  static void CopyAttrsQuantizedPooling(const Node* orig_node, NodeBuilder* nb,
-                                        bool change_format = false);
   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
                                        bool change_format = false);
-  static void CopyAttrsQuantizedConcat(const Node* orig_node, NodeBuilder* nb,
-                                       bool change_format = false);
-  static void CopyAttrsQuantizeV2(const Node* orig_node, NodeBuilder* nb,
-                                  bool change_format = false);
-  static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb,
-                               bool change_format = false);
-  static void CopyAttrsRequantize(const Node* orig_node, NodeBuilder* nb,
-                                  bool change_format = false);
-  static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb,
-                             bool change_format = false);
-  static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb,
-                             bool change_format = false);
-  static void CopyAttrsTranspose(const Node* orig_node, NodeBuilder* nb,
-                                 bool change_format = false);
   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
                                   const std::vector<int32>& strides,
                                   const std::vector<int32>& dilations,
@@ -1975,7 +1968,7 @@
 
   // If this is an MKL op, then it will create extra output for MKL layout.
   DataType T;
-  if (GetNodeAttr(n->def(), "T", &T).ok() &&
+  if (TryGetNodeAttr(n->def(), "T", &T) &&
       mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
     // If this is an MKL op, then it will generate an edge that will receive
     // Mkl tensor from a node.
@@ -2373,6 +2366,21 @@
 // Op-specific functions to copy attributes from old node to new node
 //////////////////////////////////////////////////////////////////////////
 
+// Generic function to copy all attributes from original node to target.
+void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
+                                        bool change_format) {
+  string name;
+  AttrSlice attr_list(orig_node->def());
+
+  auto iter = attr_list.begin();
+  while (iter != attr_list.end()) {
+    name = iter->first;
+    auto attr = iter->second;
+    nb->Attr(name, attr);
+    ++iter;
+  }
+}
+
 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
                                                          NodeBuilder* nb,
                                                          bool change_format) {
@@ -2399,23 +2407,6 @@
   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
 }
 
-void MklLayoutRewritePass::CopyAttrsQuantizeV2(const Node* orig_node,
-                                               NodeBuilder* nb,
-                                               bool change_format) {
-  DataType T;
-  string mode;
-  string round_mode;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "mode", &mode));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "round_mode", &round_mode));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("mode", mode);
-  nb->Attr("round_mode", round_mode);
-}
 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
                                          bool change_format) {
   DataType T;
@@ -2437,21 +2428,6 @@
   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
 }
 
-void MklLayoutRewritePass::CopyAttrsDequantize(const Node* orig_node,
-                                               NodeBuilder* nb,
-                                               bool change_format) {
-  DataType T;
-  string mode;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "mode", &mode));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("mode", mode);
-}
-
 // Used in rinfo when replacing __MklDummyPadWithConv2D by _MklPadWithConv2D
 void MklLayoutRewritePass::CopyAttrsPadWithConv2D(const Node* orig_node,
                                                   NodeBuilder* nb,
@@ -2576,30 +2552,6 @@
   nb->Attr("fused_ops", fused_ops);
 }
 
-void MklLayoutRewritePass::CopyAttrsConv2DDepthwise(const Node* orig_node,
-                                                    NodeBuilder* nb,
-                                                    bool change_format) {
-  DataType T;
-  string data_format;
-  string padding;
-  std::vector<int32> strides;
-  std::vector<int32> dilations;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("strides", strides);
-  nb->Attr("dilations", dilations);
-  nb->Attr("padding", padding);
-  nb->Attr("data_format", data_format);
-}
-
 void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
     const Node* orig_node, NodeBuilder* nb, bool change_format) {
   DataType T;
@@ -2627,131 +2579,6 @@
   nb->Attr("data_format", data_format);
 }
 
-void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
-                                         bool change_format) {
-  DataType T;
-  int N;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("N", N);
-}
-
-void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
-                                                NodeBuilder* nb,
-                                                bool change_format) {
-  DataType T;
-  string data_format;
-  std::vector<int32> strides;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("strides", strides);
-  nb->Attr("data_format", data_format);
-}
-
-void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
-                                        bool change_format) {
-  DataType T;
-  int depth_radius;
-  float bias;
-  float alpha;
-  float beta;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("depth_radius", depth_radius);
-  nb->Attr("bias", bias);
-  nb->Attr("alpha", alpha);
-  nb->Attr("beta", beta);
-}
-
-void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node,
-                                              NodeBuilder* nb,
-                                              bool change_format) {
-  DataType T;
-  float alpha;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("alpha", alpha);
-}
-
-void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
-                                            NodeBuilder* nb,
-                                            bool change_format) {
-  DataType T;
-  string data_format;
-  string padding;
-  std::vector<int32> ksize, strides;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("ksize", ksize);
-  nb->Attr("strides", strides);
-  nb->Attr("padding", padding);
-  nb->Attr("data_format", data_format);
-}
-
-void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
-                                             NodeBuilder* nb,
-                                             bool change_format) {
-  DataType T;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-}
-
-void MklLayoutRewritePass::CopyAttrsQuantizedPooling(const Node* orig_node,
-                                                     NodeBuilder* nb,
-                                                     bool change_format) {
-  DataType T;
-  string padding;
-  std::vector<int32> ksize, strides;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("ksize", ksize);
-  nb->Attr("strides", strides);
-  nb->Attr("padding", padding);
-}
-
 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
                                                     NodeBuilder* nb,
                                                     bool change_format) {
@@ -2816,66 +2643,6 @@
   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
 }
 
-void MklLayoutRewritePass::CopyAttrsRequantize(const Node* orig_node,
-                                               NodeBuilder* nb,
-                                               bool change_format) {
-  DataType Tinput, out_type;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
-
-  // Add attributes to new node.
-  nb->Attr("Tinput", Tinput);
-  nb->Attr("out_type", out_type);
-}
-
-void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
-                                            NodeBuilder* nb,
-                                            bool change_format) {
-  DataType T;
-  DataType Tshape;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("Tshape", Tshape);
-}
-
-void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
-                                          NodeBuilder* nb, bool change_format) {
-  DataType T;
-  DataType Index;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("Index", Index);
-}
-
-void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
-                                          NodeBuilder* nb, bool change_format) {
-  DataType T;
-  string data_format;
-  int num_split;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("num_split", num_split);
-  nb->Attr("data_format", data_format);
-}
-
 void MklLayoutRewritePass::CopyFormatAttrsConv(
     const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
     const std::vector<int32>& dilations, bool change_format) {
@@ -2915,70 +2682,6 @@
   }
 }
 
-void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
-                                           NodeBuilder* nb,
-                                           bool change_format) {
-  DataType T;
-  int N;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("N", N);
-}
-
-void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
-                                             NodeBuilder* nb,
-                                             bool change_format) {
-  DataType T;
-  int N;
-  DataType tidx;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("N", N);
-  nb->Attr("Tidx", tidx);
-}
-
-void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
-                                                   NodeBuilder* nb,
-                                                   bool change_format) {
-  DataType T;
-  float epsilon;
-  string data_format;
-  bool is_training;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("epsilon", epsilon);
-  nb->Attr("data_format", data_format);
-  nb->Attr("is_training", is_training);
-}
-
-void MklLayoutRewritePass::CopyAttrsFusedBatchNormV2(const Node* orig_node,
-                                                     NodeBuilder* nb,
-                                                     bool change_format) {
-  CopyAttrsFusedBatchNorm(orig_node, nb, change_format);
-
-  DataType U;
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "U", &U));
-  nb->Attr("U", U);
-}
-
 void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
                                                 NodeBuilder* nb,
                                                 bool change_format) {
@@ -3016,54 +2719,6 @@
   nb->Attr("epsilon", epsilon);
 }
 
-void MklLayoutRewritePass::CopyAttrsMatMul(const Node* orig_node,
-                                           NodeBuilder* nb,
-                                           bool change_format) {
-  DataType T;
-  bool transpose_a, transpose_b;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "transpose_a", &transpose_a));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "transpose_b", &transpose_b));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("transpose_a", transpose_a);
-  nb->Attr("transpose_b", transpose_b);
-}
-
-void MklLayoutRewritePass::CopyAttrsTranspose(const Node* orig_node,
-                                              NodeBuilder* nb,
-                                              bool change_format) {
-  DataType T, Tperm;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tperm", &Tperm));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("Tperm", Tperm);
-}
-
-void MklLayoutRewritePass::CopyAttrsBatchMatMul(const Node* orig_node,
-                                                NodeBuilder* nb,
-                                                bool change_format) {
-  DataType T;
-  bool adj_x, adj_y;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "adj_x", &adj_x));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "adj_y", &adj_y));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("adj_x", adj_x);
-  nb->Attr("adj_y", adj_y);
-}
-
 //////////////////////////////////////////////////////////////////////////
 //           Helper functions related to node merge pass
 //////////////////////////////////////////////////////////////////////////
@@ -3809,13 +3464,13 @@
   DataType Tinput, Tfilter;
   bool type_attrs_present = false;
 
-  if (GetNodeAttr(n->def(), "Tinput", &Tinput).ok() &&
-      GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok() &&
+  if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
+      TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
       mkl_op_registry::IsMklLayoutDependentOp(
           mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
     type_attrs_present = true;
-  } else if (GetNodeAttr(n->def(), "T1", &T1).ok() &&
-             GetNodeAttr(n->def(), "T2", &T2).ok() &&
+  } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
+             TryGetNodeAttr(n->def(), "T2", &T2) &&
              mkl_op_registry::IsMklLayoutDependentOp(
                  mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
     type_attrs_present = true;
@@ -3846,7 +3501,7 @@
   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
   // MklRelu if type is INT32.
   DataType T;
-  if (!GetNodeAttr(n->def(), "T", &T).ok()) {
+  if (!TryGetNodeAttr(n->def(), "T", &T)) {
     return nullptr;
   }
 
@@ -3873,47 +3528,6 @@
     return nullptr;
   }
 
-  // For elementwise node, we reuse the Eigen implementation and pass the MKL
-  // metadata tensor through so we can avoid conversions. However, if all
-  // incoming edges are in TF format, we don't need all this overhead, so
-  // replace the elementwise node only if at least one of its parents is a MKL
-  // node.
-  //
-  // Identity nodes can also skip replacement if they are not being served by
-  // any MKL nodes.
-  //
-  // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
-  // eigen code to reduce cross-library dependency.
-  VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
-  if (mkl_op_registry::IsMklElementWiseOp(
-          mkl_op_registry::GetMklOpName(n->type_string()), T) ||
-      n->type_string().find("Identity") != string::npos) {
-    VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
-    bool incoming_mkl_edge = false;
-    int num_parent = 0;
-    for (auto parent : n->in_edges()) {
-      if (mkl_op_registry::IsMklLayoutDependentOp(parent->src()->type_string(),
-                                                  T)) {
-        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
-                << " is MKL op: " << parent->src()->type_string();
-        incoming_mkl_edge = true;
-        break;
-      } else {
-        VLOG(1) << "ELEMENTWISE: parent " << num_parent++
-                << " is NON-MKL op: " << parent->src()->type_string();
-      }
-    }
-    if (incoming_mkl_edge == false) {
-      VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
-                 "has no MKL "
-                 "parents.";
-      return nullptr;
-    } else {
-      VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
-              << " which has MKL parents";
-    }
-  }
-
   // We now check if rewrite rule applies for this op. If rewrite rule passes
   // for this op, then we rewrite it to Mkl op.
   // Find matching RewriteInfo and then check that rewrite rule applies.
@@ -4102,7 +3716,7 @@
 
   // If graph node is not Mkl node, then return.
   DataType T = DT_INVALID;
-  if (!GetNodeAttr(n->def(), "T", &T).ok() ||
+  if (!TryGetNodeAttr(n->def(), "T", &T) ||
       !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
     return result;
   }
@@ -4127,7 +3741,7 @@
     // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
     // node, then we don't need to do anything.
     Node* e_src = e->src();
-    if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
+    if (TryGetNodeAttr(e_src->def(), "T", &T) &&
         mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
       // Source node for edge 'e' is Mkl node.
       // Destination node and destination input slot of e is node 'n' and 'idx'
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 0f1053a..df54c9f 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -16,7 +16,6 @@
 #if defined(INTEL_MKL) && defined(ENABLE_MKL)
 
 #include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/graph/mkl_graph_util.h"
 
 #include <algorithm>
 #include <vector>
@@ -25,6 +24,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
 #include "tensorflow/core/graph/testlib.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/random/simple_philox.h"
@@ -2206,7 +2206,6 @@
       "node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
       " attr { key: 'dilations'        value { list: {i: 1, i:1, i:1, i:1} } }"
@@ -2230,7 +2229,6 @@
       "node { name: 'D' op: 'DepthwiseConv2dNativeBackpropInput'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
       " attr { key: 'dilations'        value { list: {i: 1, i:1, i:1, i:1} } }"
@@ -3005,6 +3003,60 @@
 }
 
 /////////////////////////////////////////////////////////////////////
+//  Unit tests related to context-based node rewrite
+/////////////////////////////////////////////////////////////////////
+
+// If any of the inputs is an MKL op, then rewrite Slice to Mkl op.
+TEST_F(MklLayoutPassTest, NodeRewrite_Ctxbased_Slice_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B', 'M', 'N']}"
+      "node { name: 'D' op: 'Int32Input'}"
+      "node { name: 'E' op: 'Int32Input'}"
+      "node { name: 'F' op: 'Slice'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'Index'        value { type: DT_INT32 } }"
+      " input: ['C', 'D', 'E'] }"
+      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklConv2D);D(Int32Input);"
+            "DMT/_0(Const);DMT/_1(Const);"
+            "E(Int32Input);F(_MklSlice);G(Zeta);M(_MklInput);N(_MklInput)|"
+            "A->C;A->G;B->C:1;C->F;C->G:1;C:2->F:3;"
+            "C:control->DMT/_0:control;C:control->DMT/"
+            "_1:control;"
+            "D->F:1;DMT/_0->F:4;DMT/_1->F:5;"
+            "E->F:2;M->C:2;N->C:3");
+}
+
+// If none of the inputs is an MKL op, then Slice should not be rewritten.
+TEST_F(MklLayoutPassTest, NodeRewrite_Ctxbased_Slice_Negative) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Int32Input'}"
+      "node { name: 'C' op: 'Int32Input'}"
+      "node { name: 'D' op: 'Slice'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'Index'        value { type: DT_INT32 } }"
+      " input: ['A', 'B', 'C'] }"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Int32Input);C(Int32Input);"
+            "D(Slice);E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
+}
+
+/////////////////////////////////////////////////////////////////////
 //  Unit tests related to rewriting node for workspace edges
 /////////////////////////////////////////////////////////////////////
 
@@ -3017,7 +3069,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'MaxPool'"
@@ -3041,7 +3092,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['E', 'F', 'B'] }"
       "node { name: 'H' op: 'Input'}"
@@ -3066,7 +3116,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Input'}"
@@ -3076,7 +3125,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['C', 'D', 'B'] }"
       "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
@@ -3098,7 +3146,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
@@ -3119,7 +3166,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['A', 'B', 'C'] }"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
@@ -3139,7 +3185,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['A'] }"
       "node { name: 'C' op: 'Input'}"
@@ -3149,7 +3194,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['C', 'D', 'B'] }"
       "node { name: 'F' op: 'LRNGrad'"
@@ -3157,7 +3201,6 @@
       " attr { key: 'alpha'        value { f: 0.001 } }"
       " attr { key: 'beta'         value { f: 0.75 } }"
       " attr { key: 'bias'         value { f: 1.0 } }"
-      " attr { key: 'data_format'  value { s: 'NCHW' } }"
       " attr { key: 'depth_radius' value { i: 2 } }"
       " input: ['C', 'B', 'D'] }"
       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
@@ -3471,7 +3514,6 @@
       "node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
       " attr { key: 'dilations'        value { list: {i: 1, i:1, i:1, i:1} } }"
@@ -3727,14 +3769,11 @@
       " attr { key: 'Index'        value { type: DT_INT32 } }"
       " input: ['A', 'B', 'C'] }"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'D'] }");
+      " input: ['A', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Int32Input);C(Int32Input);"
-            "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
-            "_2(Const);E(Zeta)|A->D;A->E;"
-            "A:control->DMT/_0:control;A:control->DMT/"
-            "_1:control;A:control->DMT/_2:control;"
-            "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+            "A(Input);B(Int32Input);C(Int32Input);D(Slice);E(Zeta)|A->D;A->E;"
+            "B->D:1;C->D:2;D->E:1");
 }
 
 /////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index 26bb654..4670e7a 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -172,8 +172,8 @@
 }
 
 void FillStringTensor(Tensor* dst, const Tensor& src) {
-  auto dst_flat = dst->flat<string>();
-  auto src_flat = src.flat<string>();
+  auto dst_flat = dst->flat<tstring>();
+  auto src_flat = src.flat<tstring>();
   for (int i = 0; i < src.NumElements(); i++) {
     dst_flat(i) = src_flat(i);
   }
@@ -220,8 +220,8 @@
   FillStringTensor(&new_shape_and_slices, shape_and_slices);
   for (int i = 0; i < var_size; i++) {
     Node* var = added_variables[i];
-    new_tensor_names.flat<string>()(tn_size + i) = var->name();
-    new_shape_and_slices.flat<string>()(tn_size + i) = "";
+    new_tensor_names.flat<tstring>()(tn_size + i) = var->name();
+    new_shape_and_slices.flat<tstring>()(tn_size + i) = "";
     var_nodeouts.emplace_back(var);
   }
   save_op_builder = save_op_builder.Input(var_nodeouts);
@@ -275,7 +275,7 @@
     // Construct the tensor_names input with the variable name.
     Node* tensor_names;
     Tensor tensor_names_val(DT_STRING, TensorShape({1}));
-    tensor_names_val.flat<string>()(0) = var->name();
+    tensor_names_val.flat<tstring>()(0) = var->name();
     TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
                            .Attr("dtype", DT_STRING)
                            .Attr("value", tensor_names_val)
@@ -284,7 +284,7 @@
     // Construct the shape_and_slices input with empty string.
     Node* shape_and_slices;
     Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
-    shape_and_slices_val.flat<string>()(0) = "";
+    shape_and_slices_val.flat<tstring>()(0) = "";
     TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
                            .Attr("dtype", DT_STRING)
                            .Attr("value", shape_and_slices_val)
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index 5d16e4e..20bed36 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -2,7 +2,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index e764722..3ef6c2a 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -85,9 +85,8 @@
   }
 
   TF_RETURN_IF_ERROR(estimator_->Initialize(item));
-  Costs ignored_costs;
   TF_RETURN_IF_ERROR(
-      estimator_->PredictCosts(item.graph, metadata, &ignored_costs));
+      estimator_->PredictCosts(item.graph, metadata, /*cost=*/nullptr));
 
   const std::unordered_map<string, DeviceProperties>& device = GetDevices();
   std::unordered_map<string, int64> peak_mem_usage =
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index f1746a2..af79d09 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -1,6 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_all_protos",
     "tf_proto_library",
     "tf_protos_grappler",
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index a7e8184..a85e293 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -149,12 +149,24 @@
 Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
                                              RunMetadata* run_metadata,
                                              Costs* costs) const {
-  GraphDef graph_copy = optimized_graph;
-  GrapplerItem item = item_->WithGraph(std::move(graph_copy));
+  std::unique_ptr<GrapplerItem> item_storage;
+  const GrapplerItem* item;
+  // Many callers to PredictCosts() pass the same optimized_graph as was used
+  // to initialize the estimator.
+  if (&optimized_graph == &item_->graph) {
+    item = item_;
+  } else {
+    GraphDef graph_copy = optimized_graph;
+    item_storage = absl::make_unique<GrapplerItem>(
+        item_->WithGraph(std::move(graph_copy)));
+    item = item_storage.get();
+  }
 
-  auto status = scheduler_->Init(&item);
+  auto status = scheduler_->Init(item);
   if (!status.ok()) {
-    costs->execution_time = Costs::Duration::max();
+    if (costs) {
+      costs->execution_time = Costs::Duration::max();
+    }
     return status;
   }
 
@@ -203,7 +215,11 @@
   }
 
   // run_metadata gets step_stats and partition_graphs from Summary.
-  *costs = scheduler_->Summary(run_metadata);
+  if (costs) {
+    *costs = scheduler_->Summary(run_metadata);
+  } else if (run_metadata) {
+    scheduler_->GenerateRunMetadata(run_metadata);
+  }
 
   if (VLOG_IS_ON(1)) {
     bool verbose = VLOG_IS_ON(2);
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 15e6647..e72e613 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -811,16 +811,22 @@
           ctx->input_tensor_protos[dst_input] = tensor_proto;
 
           if (!ic->FullyDefined(input_tensors_as_shapes[dst_input])) {
-            // Shape from a Const is not fully defined when the Const has
-            // value -1 (e.g., Reshape(x, Const(-1)) to reshape an arbitrary
-            // tensor x to a vector).
+            // Tensorflow uses '-1' to encode unknown shape or dimension:
+            //
+            //      -1  : unknown shape
+            //     [-1] : vector of unknown size
+            // [-1, -1] : matrix of unknown size
+            //
+            // For example `tf.reshape(x, [-1])` will reshape an arbitrary
+            // tensor x to a vector.
+            //
             // It's possible that the same Const with -1 is used in many
             // places, but that doesn't mean the resultant shapes are
             // identical. e.g., x1 = Reshape(x, c) and y1 = Reshape(y, c),
-            // where c is -1. In this case, shape inference yields both x1 and
+            // where c is [-1]. In this case, shape inference yields both x1 and
             // y1 as rank 1, size unknown, but still the shapes of x1 and y1
-            // can be different. (even if we use different Const(-1) for x1
-            // and x2, graph optimzier may merge them to single Const through
+            // can be different. (even if we use different Const([-1]) for x1
+            // and x2, graph optimizer may merge them to single Const through
             // duplicate removal.)
             // If we reuse output_tensors_as_shapes to input_tensors_as_shapes
             // by copying ShapeHandle, they share the same Shape object, and
@@ -1755,9 +1761,14 @@
       // Scalar constant.
       int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
                                                : tensor.flat<int64>()(0);
-      // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
-      // It's a limitation as we use ShapeHandle as a means to pass values.
-      if (value >= -1) {
+      if (value == -1) {
+        // Scalar value -1 represents an unknown shape. If we would try to
+        // MakeShape(MakeDim) with it, we would get vector of unknown size.
+        *tensors_as_shapes = ic->UnknownShape();
+        return true;
+      } else if (value >= 0) {
+        // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
+        // It's a limitation as we use ShapeHandle as a means to pass values.
         *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
         return true;
       }
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index d45bb14..2f3d171 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -130,7 +130,7 @@
         if (tensor.NumElements() != 1) {
           continue;
         }
-        const string filename = tensor.scalar<string>()();
+        const string& filename = tensor.scalar<tstring>()();
 
         Env* env = Env::Default();
         FileStatistics stat;
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 5b3e140..7be98dc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -98,7 +98,7 @@
 
 TEST_F(GraphViewTest, ParseSingleExample) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+  Output a = ops::Const<tstring>(s.WithOpName("a"), "", {});
   Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
   ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
                             {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 6916bc8..80d0134 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -120,6 +120,8 @@
     fn_library.emplace(OpRegistry::Global(), graph.library());
   }
   for (const NodeDef& node : graph.node()) {
+    const auto attrs = AttrSlice(&node.attr());
+
     // Tensorflow functions do not prune stateful or dataset-output ops from
     // the function body (see PruneFunctionBody in common_runtime/function.cc).
     if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
@@ -129,8 +131,9 @@
 
     // Do not remove ops with attribute _grappler_do_not_remove. This is useful
     // for debugging.
-    auto iter = node.attr().find("_grappler_do_not_remove");
-    if (iter != node.attr().end() && iter->second.b()) {
+    bool do_not_remove;
+    if (TryGetNodeAttr(attrs, "_grappler_do_not_remove", &do_not_remove) &&
+        do_not_remove) {
       result.insert(node.name());
     }
   }
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 9790915..6d49b2f 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -267,8 +267,8 @@
   graph_ctor_opts.expect_device_spec = false;
   std::unique_ptr<Graph> graphptr(new Graph(function_library));
 
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      graph_ctor_opts, std::move(graph_def), graphptr.get()));
 
   // Optimize the graph.
   ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index c4de79e..0c94801 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -317,6 +317,8 @@
 
 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
 
+bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }
+
 bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
 
 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2b2ea56..4dc8b31 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -99,6 +99,7 @@
 bool IsLogicalAnd(const NodeDef& node);
 bool IsLogicalNot(const NodeDef& node);
 bool IsLogicalOr(const NodeDef& node);
+bool IsLoopCond(const NodeDef& node);
 bool IsMatMul(const NodeDef& node);
 bool IsMax(const NodeDef& node);
 bool IsMaxPoolGrad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 42e7bef..8944062 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -3,7 +3,7 @@
 
 # Platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_static",
 )
 
@@ -625,6 +625,9 @@
 tf_cuda_cc_test(
     name = "meta_optimizer_test",
     srcs = ["meta_optimizer_test.cc"],
+    tags = [
+        "no_gpu",
+    ],
     deps = [
         ":custom_graph_optimizer",
         ":custom_graph_optimizer_registry",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index a8b57ee..3bbd988 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -866,7 +866,7 @@
     *shapes_match = true;
     unique_factors->reserve(node->input_size());
 
-    for (int i = 0; i < node->input_size() && shapes_match; ++i) {
+    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
       const string& input = node->input(i);
       if (IsControlInput(input)) {
         break;
@@ -2248,6 +2248,8 @@
       FlipBooleanAttr(attr_a, new_op);
       new_op->set_input(0, a->input(0));
       ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
+    } else {
+      ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
     }
 
     if (b_is_foldable) {
@@ -2256,6 +2258,8 @@
       FlipBooleanAttr(attr_b, new_op);
       new_op->set_input(1, b->input(0));
       ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
+    } else {
+      ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
     }
 
     std::vector<const NodeDef*> deps_to_forward = {node};
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index ae3da03..82c0016 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -160,12 +160,12 @@
   OptimizeTwice(&optimizer, &item, &output);
   NodeMap node_map(&output);
 
-  EXPECT_EQ(output.node_size(), 5);
+  EXPECT_EQ(output.node_size(), 6);
   const NodeDef* new_div = node_map.GetNode("div");
   ASSERT_NE(new_div, nullptr);
   ASSERT_EQ(new_div->input_size(), 3);
   EXPECT_EQ(new_div->input(0), "check1");
-  EXPECT_EQ(new_div->input(1), "check1");
+  EXPECT_EQ(new_div->input(1), "check2");
   EXPECT_EQ(new_div->input(2), "^assert1");
 
   auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
index f864183..4e25e91 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
@@ -120,6 +120,7 @@
         "FusedBatchNormGradV2",
         "FusedBatchNormV3",
         "FusedBatchNormGradV3",
+        "_FusedBatchNormEx",
         "Inv",
         "LeakyRelu",
         "LeakyReluGrad",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 487e3bc..2ce6377 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -989,11 +989,10 @@
     }
   }
 
-  // No need to (and don't) fold nodes that have no outgoing edges except
-  // whitelisted nodes. Such nodes could be introduced by an earlier constant
-  // folding pass and are preserved in case users want to fetch their values;
-  // re-processing them would lead to an error of adding a duplicated node
-  // to graph.
+  // Don't fold nodes that have no outgoing edges except whitelisted nodes.
+  // Such nodes could be introduced by an earlier constant folding pass and are
+  // preserved in case users want to fetch their values; re-processing them
+  // would lead to an error of adding a duplicated node to graph.
   const auto& outputs = node_map_->GetOutputs(node.name());
   if (outputs.empty() &&
       nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
@@ -1029,6 +1028,7 @@
       return false;
     }
   }
+  if (is_merge && !merge_has_constant_input) return false;
 
   // If we know the output shapes, make sure that the outputs are small enough
   // to materialize.
@@ -1050,7 +1050,7 @@
     }
   }
 
-  return !is_merge || merge_has_constant_input;
+  return true;
 }
 
 namespace {
@@ -1205,11 +1205,11 @@
       case DT_INT64:
         POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
       case DT_UINT64:
-        POPULATE_TENSOR_PROTO(tensor, t, uint64, int64);
+        POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
       case DT_INT32:
         POPULATE_TENSOR_PROTO(tensor, t, int32, int);
       case DT_UINT32:
-        POPULATE_TENSOR_PROTO(tensor, t, uint32, int);
+        POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
       case DT_INT16:
         POPULATE_TENSOR_PROTO(tensor, t, int16, int);
       case DT_UINT16:
@@ -2831,6 +2831,9 @@
 
 bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
                                        NodeDef* node) {
+  // TODO(rmlarsen): Consider enabling for subtractions if we are comfortable
+  // with the potential loss of numerical accuracy due to re-association.
+  //
   // Consider the transformation
   //
   //                      +                +       = parent
@@ -2839,83 +2842,170 @@
   //                       / \              / \
   //                      X   Y            C   Y   = leaves
   //
-  // where C is constant and X is non-constant, and '+' denotes an
-  // associative and commutative operator like addition or multiplication.
-  // This optimization pushes constants down in the tree to canonicalize it.
-  // Moreoever, in cases where the child node has a second constant input Y
-  // we will create a leaf node that can be folded, e.g.
+  // where C is constant, X is non-constant, Y may be constant or non-constant,
+  // and '+' denotes an associative and commutative operator like addition or
+  // multiplication. This optimization pushes constants down in the tree to
+  // canonicalize it. Moreoever, in cases where the child node has a second
+  // constant input Y we will create a leaf node that can be folded, e.g.
   //
   //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
   //
-  // TODO(rmlarsen): Handle non-associative/non-commutative operators like
-  // subtraction and division, as well as mixed subtraction/addition,
-  // division/multiplication.
-  // Don't touch BiasAdd since they can't handle vectors as their first
+  // We also handle the non-commutative cases of subtraction and division
+  // by rotating the tree locally, e.g.
+  //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
+  //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
+  //
+  // Note: Don't touch BiasAdd since they can't handle vectors as their first
   // inputs.
-  if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
-      NumNonControlInputs(*node) == 2) {
-    NodeDef* left_child = node_map_->GetNode(node->input(0));
-    NodeDef* right_child = node_map_->GetNode(node->input(1));
-    // One child must be constant, and the other the same op as the parent.
-    if (node->op() != left_child->op() && node->op() != right_child->op()) {
-      return false;
-    }
-    const bool left_child_is_constant = IsReallyConstant(*left_child);
-    const bool right_child_is_constant = IsReallyConstant(*right_child);
-    if (!left_child_is_constant && !right_child_is_constant) {
-      return false;
-    }
-    if (node->device() != left_child->device() ||
-        node->device() != right_child->device()) {
-      return false;
-    }
-    NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
-    NodeDef* const_child_node =
-        left_child_is_constant ? left_child : right_child;
-    // Make sure that it is safe to change the value of the child node->
-    if (op_child_node->input_size() < 2 ||
-        nodes_to_preserve_.find(op_child_node->name()) !=
-            nodes_to_preserve_.end() ||
-        NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
-      return false;
-    }
 
-    // Identify the nodes to swap.
-    NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
-    NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
-    const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
-    const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
-    if (left_leaf_is_constant && right_leaf_is_constant) {
-      // Child is already foldable, leave it alone.
-      return false;
-    }
-    const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
-    const int parent_const_input = left_child_is_constant ? 0 : 1;
-    const auto& child_output = node_map_->GetOutputs(op_child_node->name());
-    if (child_output.find(const_child_node) != child_output.end()) {
-      // If there is a control edge from the child op to C, the transformation
-      // would create a cycle in the graph. We know that it must be a control
-      // edge. We can replace such a control edge with a control edge from A
-      // to C.
-      CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
-                                    optimized_graph, node_map_.get()));
-      string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
-                                                      : op_child_node->input(1);
-      MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
-                           node_map_.get());
-    }
-
-    // Swap the constant child with a non-constant leaf node.
-    node_map_->UpdateInput(node->name(), node->input(parent_const_input),
-                           op_child_node->input(non_const_leaf_input));
-    node_map_->UpdateInput(op_child_node->name(),
-                           op_child_node->input(non_const_leaf_input),
-                           node->input(parent_const_input));
-    std::swap(*node->mutable_input(parent_const_input),
-              *op_child_node->mutable_input(non_const_leaf_input));
-    return true;
+  // Get parent op type.
+  const bool is_add = IsAdd(*node);
+  const bool is_mul = IsMul(*node);
+  const bool is_sub = IsSub(*node);
+  const bool is_div = IsDiv(*node);
+  const bool is_symmetric = is_add || is_mul;
+  if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
+      NumNonControlInputs(*node) != 2) {
+    return false;
   }
-  return false;
+
+  NodeDef* left_child = node_map_->GetNode(node->input(0));
+  NodeDef* right_child = node_map_->GetNode(node->input(1));
+
+  const bool left_child_is_constant = IsReallyConstant(*left_child);
+  const bool right_child_is_constant = IsReallyConstant(*right_child);
+  if (!left_child_is_constant && !right_child_is_constant) {
+    return false;
+  }
+  // Don't move nodes across devices.
+  if (node->device() != left_child->device() ||
+      node->device() != right_child->device()) {
+    return false;
+  }
+  NodeDef* op_child = left_child_is_constant ? right_child : left_child;
+  NodeDef* const_child = left_child_is_constant ? left_child : right_child;
+  // Don't rewrite the tree if it might create cycles.
+  // TODO(rmlarsen): Add back handling of control dependency from op to C.
+  const auto& child_output = node_map_->GetOutputs(op_child->name());
+  if (child_output.find(const_child) != child_output.end()) {
+    return false;
+  }
+  // Get child op type.
+  const bool is_child_add = IsAdd(*op_child);
+  const bool is_child_mul = IsMul(*op_child);
+  const bool is_child_sub = IsSub(*op_child);
+  const bool is_child_div = IsDiv(*op_child);
+  const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
+  const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
+  if (!is_add_sub && !is_mul_div) {
+    return false;
+  }
+
+  // TODO(rmlarsen): Consider enabling for subtractions if we are comfortable
+  // with the potential loss of numerical accuracy due to re-association.
+  // Notice that subtraction is not really different from addition in this
+  // regard.
+  if (is_sub || is_child_sub) {
+    return false;
+  }
+  const bool is_child_symmetric = is_child_add || is_child_mul;
+  // Make sure that it is safe to change the value of the child node result.
+  if (op_child->input_size() < 2 ||
+      nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
+      NumNonControlOutputs(*op_child, *node_map_) > 1) {
+    return false;
+  }
+  // Do not rewrite integer expressions with subtraction or division.
+  if (!CheckAttrExists(*node, "T").ok()) return false;
+  DataType dtype = node->attr().at("T").type();
+  if (!(is_symmetric && is_child_symmetric) &&
+      !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
+    return false;
+  }
+
+  // Identify the nodes to swap.
+  NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
+  NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
+  const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
+  const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
+  if (left_leaf_is_constant && right_leaf_is_constant) {
+    // Child is already foldable, leave it alone.
+    return false;
+  }
+  // Don't move nodes across devices.
+  if (node->device() != left_leaf->device() ||
+      node->device() != right_leaf->device()) {
+    return false;
+  }
+  // Get the node names corresponding to X, Y, and C.
+  const string input_x =
+      left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
+  const string input_y =
+      input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
+  const string input_c =
+      left_child_is_constant ? node->input(0) : node->input(1);
+  const string input_op =
+      left_child_is_constant ? node->input(1) : node->input(0);
+
+  // Now we have identified the nodes to swap (non_const_leaf_input and
+  // const_child).
+  node_map_->UpdateInput(node->name(), input_c, input_x);
+  node_map_->AddOutput(input_c, op_child->name());
+  if (input_x != input_y) {
+    node_map_->RemoveOutput(input_x, op_child->name());
+  }
+
+  if (is_symmetric && is_child_symmetric) {
+    // Easy case (only commutative ops). We always write this as one of
+    //   +
+    //  / \
+    // X   +
+    //    / \
+    //   C   Y
+    node->set_input(0, input_x);
+    node->set_input(1, input_op);
+    op_child->set_input(0, input_c);
+    op_child->set_input(1, input_y);
+  } else {
+    // More complicated case: When there are non-commutative operations like
+    // subtractions or divisions involved, we may have to rotate the tree
+    // and/or change op types. There are 6 non-trivial cases depending on
+    // the effective generalized "sign" of each of the three terms C, Y, and X.
+    // Here are the final trees we want to generate for those 6 cases:
+    //
+    // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
+    //                 -        -        -      -       +        +
+    //                / \      / \      / \    / \     / \      / \
+    //               +   X    -   X    -   X  X   +   X   -    X   -
+    //              / \      / \      / \        / \     / \      / \
+    //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
+    //
+    NodeDef* non_const_leaf = left_leaf_is_constant ? right_leaf : left_leaf;
+    NodeDef* maybe_const_leaf =
+        non_const_leaf == right_leaf ? left_leaf : right_leaf;
+
+    // First, let's determine the effective sign of each term in the original
+    // expression
+    auto is_leaf_negated = [&](const NodeDef* node) -> bool {
+      bool leaf_negated = !is_child_symmetric && (node == right_leaf);
+      bool child_negated = !is_symmetric && (op_child == right_child);
+      return leaf_negated != child_negated;
+    };
+    const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
+    const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
+    bool neg_c = !is_symmetric && (const_child == right_child);
+    bool neg_x = is_leaf_negated(non_const_leaf);
+    bool neg_y = is_leaf_negated(maybe_const_leaf);
+    // Rewrite the parent node.
+    node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
+    node->set_input(0, neg_x ? input_op : input_x);
+    node->set_input(1, neg_x ? input_x : input_op);
+    // Rewrite the child node.
+    op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
+    op_child->set_input(0, neg_c ? input_y : input_c);
+    op_child->set_input(1, neg_c ? input_c : input_y);
+  }
+  return true;
 }
 
 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
@@ -3221,7 +3311,7 @@
     // child node.
     node->set_input(interval.first, added_node->name());
   }
-  if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
+  if (!inputs_to_delete.empty()) {
     // Fix up the inputs to the original node.
     protobuf::RepeatedPtrField<string> tmp;
     tmp.Swap(node->mutable_input());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 3928fdf..87b9462 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
+
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -255,20 +256,19 @@
 TEST_F(ConstantFoldingTest, AddTree) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
+  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
   Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
   Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
   Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
   Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
-  Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
-                         1.0f, {1});
   Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
 
-  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
-                              ops::Placeholder::Shape(TensorShape({2, 2})));
   Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
   Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
   Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
+  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
   Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
   Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
   Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
@@ -298,16 +298,16 @@
   //     / \              / \
   //   5.0  y           4.0 5.0
 
-  EXPECT_EQ(11, output.node_size());
+  EXPECT_EQ(10, output.node_size());
   for (const auto& node : output.node()) {
     if (node.name() == "add_child") {
       EXPECT_EQ("Const", node.op());
       TensorProto t = node.attr().at("value").tensor();
-      EXPECT_EQ(1, t.tensor_shape().dim_size());
+      ASSERT_EQ(1, t.tensor_shape().dim_size());
       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
     } else if (node.name() == "add_parent") {
       EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(2, node.input_size());
+      ASSERT_EQ(2, node.input_size());
       EXPECT_EQ("x", node.input(0));
       EXPECT_EQ("add_child", node.input(1));
     } else if (node.name() == "mul_child") {
@@ -317,30 +317,112 @@
       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
     } else if (node.name() == "mul_parent") {
       EXPECT_EQ("Mul", node.op());
-      EXPECT_EQ(2, node.input_size());
+      ASSERT_EQ(2, node.input_size());
       EXPECT_EQ("y", node.input(0));
       EXPECT_EQ("mul_child", node.input(1));
     } else if (node.name() == "addmul_child") {
       // Unchanged.
       EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(2, node.input_size());
+      ASSERT_EQ(2, node.input_size());
       EXPECT_EQ("c4", node.input(0));
       EXPECT_EQ("x", node.input(1));
     }
   }
 
   // Check that the result nodes have the expected value.
-  std::vector<string> fetch = {"c3", "c20"};
-  auto tensor_expected = EvaluateNodes(item.graph, fetch);
-  EXPECT_EQ(fetch.size(), tensor_expected.size());
-  fetch = {"add_child", "mul_child"};
-  auto tensors = EvaluateNodes(output, fetch);
-  EXPECT_EQ(fetch.size(), tensors.size());
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+
+  std::vector<string> fetch = {"add_parent", "mul_parent"};
+  auto tensor_expected =
+      EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
+  ASSERT_EQ(fetch.size(), tensor_expected.size());
+  fetch = {"add_parent", "mul_parent"};
+  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
+  ASSERT_EQ(fetch.size(), tensors.size());
   for (int i = 0; i < fetch.size(); i++) {
     test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
   }
 }
 
+TEST_F(ConstantFoldingTest, TreeCanonicalization) {
+  for (int is_add : {true, false}) {
+    for (int is_parent_commutative : {true, false}) {
+      for (int is_child_commutative : {true, false}) {
+        // TODO(rmlarsen): Consider enabling for subtractions if we are
+        // comfortable with the potential loss of numerical accuracy due to
+        // re-association. Notice that subtraction is not really different from
+        // addition in this regard.
+        if (is_add && (!is_parent_commutative || !is_child_commutative))
+          continue;
+        for (int is_left_child_const : {true, false}) {
+          for (int is_left_leaf_const : {true, false}) {
+            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
+            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
+            Output x =
+                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                                 ops::Placeholder::Shape(TensorShape({2, 2})));
+
+            auto get_op = [&](bool is_commutative, bool is_left_arg_cont,
+                              const string& name, const Output& const_arg,
+                              const Output non_const_arg) -> Output {
+              if (is_add) {
+                if (is_commutative) {
+                  return ops::Add(s.WithOpName(name),
+                                  is_left_arg_cont ? const_arg : non_const_arg,
+                                  is_left_arg_cont ? non_const_arg : const_arg);
+                } else {
+                  return ops::Sub(s.WithOpName(name),
+                                  is_left_arg_cont ? const_arg : non_const_arg,
+                                  is_left_arg_cont ? non_const_arg : const_arg);
+                }
+              } else {
+                if (is_commutative) {
+                  return ops::Mul(s.WithOpName(name),
+                                  is_left_arg_cont ? const_arg : non_const_arg,
+                                  is_left_arg_cont ? non_const_arg : const_arg);
+                } else {
+                  return ops::Div(s.WithOpName(name),
+                                  is_left_arg_cont ? const_arg : non_const_arg,
+                                  is_left_arg_cont ? non_const_arg : const_arg);
+                }
+              }
+            };
+
+            Output child = get_op(is_child_commutative, is_left_leaf_const,
+                                  "child", c2, x);
+            Output parent = get_op(is_parent_commutative, is_left_child_const,
+                                   "parent", c3, child);
+            GrapplerItem item;
+            item.fetch = {"parent"};
+            TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+            ConstantFolding optimizer(/*cpu_device=*/nullptr);
+            GraphDef output;
+            Status status =
+                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
+            TF_EXPECT_OK(status);
+
+            // Check that the result nodes have the expected value.
+            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+            std::vector<string> fetch = {"parent"};
+            auto tensor_expected =
+                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
+            ASSERT_EQ(fetch.size(), tensor_expected.size());
+            fetch = {"parent"};
+            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
+            ASSERT_EQ(fetch.size(), tensors.size());
+            for (int i = 0; i < fetch.size(); i++) {
+              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
   for (string data_format : {
          "NHWC",
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 6db3c5a..288d2e9 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_protos_all")
 
 package(
     default_visibility = [
@@ -20,6 +20,7 @@
         ":inject_prefetch",
         ":latency_all_edges",
         ":make_sloppy",
+        ":make_stateless",
         ":map_and_batch_fusion",
         ":map_and_filter_fusion",
         ":map_fusion",
@@ -309,6 +310,21 @@
     alwayslink = 1,
 )
 
+tf_cc_test(
+    name = "inject_prefetch_test",
+    srcs = ["inject_prefetch_test.cc"],
+    deps = [
+        ":graph_test_utils",
+        ":graph_utils",
+        ":inject_prefetch",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
+
 cc_library(
     name = "latency_all_edges",
     srcs = ["latency_all_edges.cc"],
@@ -375,6 +391,37 @@
 )
 
 cc_library(
+    name = "make_stateless",
+    srcs = ["make_stateless.cc"],
+    hdrs = ["make_stateless.h"],
+    deps = [
+        ":graph_utils",
+        ":optimizer_base",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+    ] + tf_protos_all(),
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "make_stateless_test",
+    srcs = ["make_stateless_test.cc"],
+    deps = [
+        ":graph_test_utils",
+        ":graph_utils",
+        ":make_stateless",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
+
+cc_library(
     name = "map_and_batch_fusion",
     srcs = ["map_and_batch_fusion.cc"],
     hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index a82f04e..0b7996a 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -37,6 +37,7 @@
 // clang-format off
 constexpr char kShardDatasetOpName[] = "ShardDataset";
 constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
+constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
 
 constexpr std::array<const char*, 4> kReaderDatasetOps = {
     "FixedLengthRecordDataset",
@@ -50,7 +51,7 @@
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 23> kPassThroughOps = {
+constexpr std::array<const char*, 25> kPassThroughOps = {
     "_Retval",
     "BatchDataset",
     "BatchDatasetV2",
@@ -58,6 +59,7 @@
     "PaddedBatchDataset",
     "PaddedBatchDatasetV2",
     "CacheDataset",
+    "CacheDatasetV2",
     "FilterDataset",
     "Identity",
     "MapAndBatchDataset",
@@ -71,9 +73,10 @@
     "ShardDataset",
     "ShuffleAndRepeatDataset",
     "ShuffleDataset",
+    "ShuffleDatasetV2",
     "SkipDataset",
     "TakeDataset",
-    "WindowDataset"
+    "WindowDataset",
 };
 
 // TODO(frankchn): Process functions within kFuncDatasetOps as well.
@@ -129,8 +132,8 @@
   // Add shapes and other attributes
   NodeDef* add_after = graph->GetNode(add_before.input(0));
 
-  if (str_util::EndsWith(add_after->op(), "Dataset") ||
-      str_util::EndsWith(add_after->op(), "DatasetV2")) {
+  if (absl::EndsWith(add_after->op(), "Dataset") ||
+      absl::EndsWith(add_after->op(), "DatasetV2")) {
     // We still may or may not have the right attributes because Datasets like
     // TFRecordDataset doesn't have a output type or shape, and by default we
     // set them to DT_STRING and an unknown shape.
@@ -174,27 +177,48 @@
 }
 
 Status AddShuffleNode(MutableGraphView* graph, const NodeDef& add_before,
-                      const string& buffer_node) {
+                      const string& buffer_size_node, const string& seed_node,
+                      const string& seed2_node, bool reshuffle_each_iteration) {
   NodeDef* add_after = graph->GetNode(add_before.input(0));
-
   NodeDef new_node;
   new_node.set_op(kShuffleDatasetOpName);
   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
                                       &new_node);
 
-  NodeDef* seed = graph_utils::AddScalarConstNode<int64>(1, graph);
-  NodeDef* seed2 = graph_utils::AddScalarConstNode<int64>(2, graph);
-  AttrValue reshuffle;
-  reshuffle.set_b(false);
-
   new_node.add_input(add_before.input(0));
-  new_node.add_input(buffer_node);
-  new_node.add_input(seed->name());
-  new_node.add_input(seed2->name());
+  new_node.add_input(buffer_size_node);
+  new_node.add_input(seed_node);
+  new_node.add_input(seed2_node);
 
   graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
   graph_utils::CopyAttribute("output_types", *add_after, &new_node);
-  (*new_node.mutable_attr())["reshuffle_each_iteration"] = reshuffle;
+
+  AttrValue reshuffle_attr;
+  reshuffle_attr.set_b(reshuffle_each_iteration);
+  (*new_node.mutable_attr())["reshuffle_each_iteration"] = reshuffle_attr;
+
+  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
+
+  TF_RETURN_IF_ERROR(
+      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
+  return Status::OK();
+}
+
+Status AddShuffleV2Node(MutableGraphView* graph, const NodeDef& add_before,
+                        const string& buffer_size_node,
+                        const string& seed_generator_node) {
+  NodeDef* add_after = graph->GetNode(add_before.input(0));
+  NodeDef new_node;
+  new_node.set_op(kShuffleDatasetV2OpName);
+  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
+                                      &new_node);
+
+  new_node.add_input(add_before.input(0));
+  new_node.add_input(buffer_size_node);
+  new_node.add_input(seed_generator_node);
+
+  graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
+  graph_utils::CopyAttribute("output_types", *add_after, &new_node);
 
   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
 
@@ -223,19 +247,45 @@
 
 Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                             absl::flat_hash_set<string>* nodes_to_delete,
-                            bool* shuffle_removed,
-                            string* buffer_size_node_name) {
+                            string* op_name, string* buffer_size_node,
+                            string* seed_node, string* seed2_node,
+                            bool* reshuffle_each_iteration) {
   if (node.op() == kShuffleDatasetOpName) {
-    *shuffle_removed = true;
-    *buffer_size_node_name = node.input(1);
+    *op_name = node.op();
+    *buffer_size_node = node.input(1);
+    *seed_node = node.input(2);
+    *seed2_node = node.input(3);
+    *reshuffle_each_iteration = node.attr().at("reshuffle_each_iteration").b();
     TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
     nodes_to_delete->insert(node.name());
   }
 
   for (const auto& fanin : graph->GetFanins(node, true)) {
-    TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, *fanin.node, nodes_to_delete,
-                                            shuffle_removed,
-                                            buffer_size_node_name));
+    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
+        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
+        seed_node, seed2_node, reshuffle_each_iteration));
+  }
+
+  // TODO(frankchn): Traverse functions too.
+  return Status::OK();
+}
+
+Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
+                              absl::flat_hash_set<string>* nodes_to_delete,
+                              string* op_name, string* buffer_size_node,
+                              string* seed_generator_node) {
+  if (node.op() == kShuffleDatasetV2OpName) {
+    *op_name = node.op();
+    *buffer_size_node = node.input(1);
+    *seed_generator_node = node.input(2);
+    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
+    nodes_to_delete->insert(node.name());
+  }
+
+  for (const auto& fanin : graph->GetFanins(node, true)) {
+    TF_RETURN_IF_ERROR(
+        RemoveShuffleDatasetV2(graph, *fanin.node, nodes_to_delete, op_name,
+                               buffer_size_node, seed_generator_node));
   }
 
   // TODO(frankchn): Traverse functions too.
@@ -245,15 +295,29 @@
 Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
                                 absl::flat_hash_set<string>* nodes_to_delete,
                                 int64 num_workers, int64 index) {
-  bool shuffle_removed = false;
-  string buffer_size_node_name = "";
+  string shuffle_op_name = "";
+  string buffer_size_node = "";
+  string seed_node = "";
+  string seed2_node = "";
+  string seed_generator_node = "";
+  bool reshuffle_each_iteration;
 
   TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
   TF_RETURN_IF_ERROR(RemoveShuffleDataset(
-      graph, node, nodes_to_delete, &shuffle_removed, &buffer_size_node_name));
+      graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
+      &seed_node, &seed2_node, &reshuffle_each_iteration));
+  if (shuffle_op_name.empty()) {
+    TF_RETURN_IF_ERROR(
+        RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
+                               &buffer_size_node, &seed_generator_node));
+  }
 
-  if (shuffle_removed) {
-    TF_RETURN_IF_ERROR(AddShuffleNode(graph, node, buffer_size_node_name));
+  if (shuffle_op_name == kShuffleDatasetOpName) {
+    TF_RETURN_IF_ERROR(AddShuffleNode(graph, node, buffer_size_node, seed_node,
+                                      seed2_node, reshuffle_each_iteration));
+  } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
+    TF_RETURN_IF_ERROR(
+        AddShuffleV2Node(graph, node, buffer_size_node, seed_generator_node));
   }
 
   return Status::OK();
@@ -383,7 +447,6 @@
                                           GraphDef* output,
                                           OptimizationStats* stats) {
   *output = item.graph;
-
   TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_, output));
   stats->num_changes++;
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
index 2053691..40f4f24 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -88,18 +88,27 @@
 
 void AddFunctionOutputWithUniqueName(StringPiece prefix,
                                      StringPiece output_tensor_name,
-                                     FunctionDef* function, DataType dt) {
+                                     FunctionDef* fdef, DataType dtype) {
   string name = string(prefix);
-  int id = function->signature().output_arg_size();
-  while (ContainsFunctionOutputWithName(name, *function)) {
+  int id = fdef->signature().output_arg_size();
+  while (ContainsFunctionOutputWithName(name, *fdef)) {
     name = strings::StrCat(prefix, "/_", id);
     ++id;
   }
-  auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+  auto* output = fdef->mutable_signature()->mutable_output_arg()->Add();
   output->set_name(name);
-  output->set_type(dt);
+  output->set_type(dtype);
 
-  (*function->mutable_ret())[name] = string(output_tensor_name);
+  (*fdef->mutable_ret())[name] = string(output_tensor_name);
+}
+
+OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
+                               DataType dtype) {
+  auto* input_arg = fdef->mutable_signature()->mutable_input_arg()->Add();
+  input_arg->set_type(dtype);
+  input_arg->set_name(name);
+
+  return input_arg;
 }
 
 NodeDef* AddNode(StringPiece name, StringPiece op,
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
index 79271e8..8941e58 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -61,7 +61,11 @@
 // is unique, and maps to output_tensor_name in the ret dict.
 void AddFunctionOutputWithUniqueName(StringPiece prefix,
                                      StringPiece output_tensor_name,
-                                     FunctionDef* function, DataType dt);
+                                     FunctionDef* fdef, DataType dtype);
+
+// Adds an input to a FunctionDef.
+OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
+                               DataType dtype);
 
 // Adds a node to a FunctionDef.
 NodeDef* AddNode(StringPiece name, StringPiece op,
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
index 8ae0cde..9a53b00 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -60,6 +60,18 @@
   EXPECT_EQ(function.ret().at("y/_1"), "two");
 }
 
+TEST(FunctionUtilsTest, AddFunctionInput) {
+  FunctionDef fdef;
+  auto arg0 = AddFunctionInput("arg0", &fdef, DT_INT32);
+  auto arg1 = AddFunctionInput("arg1", &fdef, DT_BOOL);
+  EXPECT_EQ(fdef.signature().input_arg().data()[0], arg0);
+  EXPECT_EQ(arg0->name(), "arg0");
+  EXPECT_EQ(arg0->type(), DT_INT32);
+  EXPECT_EQ(fdef.signature().input_arg().data()[1], arg1);
+  EXPECT_EQ(arg1->name(), "arg1");
+  EXPECT_EQ(arg1->type(), DT_BOOL);
+}
+
 TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
   FunctionDef function = test::function::XTimesTwo();
   EXPECT_FALSE(ContainsFunctionNodeWithName(
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index 9340267..323e3c2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -25,6 +25,22 @@
 namespace grappler {
 namespace graph_tests_utils {
 
+NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name,
+                        StringPiece filename_node_name,
+                        StringPiece cache_node_name) {
+  return test::function::NDef(
+      name, "CacheDatasetV2",
+      {
+          string(input_node_name),
+          string(filename_node_name),
+          string(cache_node_name),
+      },
+      {
+          {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+          {"output_types", gtl::ArraySlice<DataType>{}},
+      });
+}
+
 NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
                        StringPiece function_name) {
   return test::function::NDef(
@@ -60,12 +76,12 @@
        {"output_types", gtl::ArraySlice<DataType>{}}});
 }
 
-NodeDef MakeParallelInterleaveNode(StringPiece name,
-                                   StringPiece input_node_name,
-                                   StringPiece cycle_length_node_name,
-                                   StringPiece block_length_node_name,
-                                   StringPiece num_parallel_calls_node_name,
-                                   StringPiece function_name, bool sloppy) {
+NodeDef MakeParallelInterleaveV2Node(StringPiece name,
+                                     StringPiece input_node_name,
+                                     StringPiece cycle_length_node_name,
+                                     StringPiece block_length_node_name,
+                                     StringPiece num_parallel_calls_node_name,
+                                     StringPiece function_name, bool sloppy) {
   return test::function::NDef(
       name, "ParallelInterleaveDatasetV2",
       {string(input_node_name), string(cycle_length_node_name),
@@ -107,6 +123,22 @@
       });
 }
 
+NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
+                          StringPiece buffer_size_node_name,
+                          StringPiece seed_generator_node_name) {
+  return test::function::NDef(
+      name, "ShuffleDatasetV2",
+      {
+          string(input_node_name),
+          string(buffer_size_node_name),
+          string(seed_generator_node_name),
+      },
+      {
+          {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+          {"output_types", gtl::ArraySlice<DataType>{}},
+      });
+}
+
 }  // namespace graph_tests_utils
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
index 3750e2d..0dcfe65 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -23,6 +23,11 @@
 namespace grappler {
 namespace graph_tests_utils {
 
+// Creates a test NodeDef for ShuffleDatasetV2.
+NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name,
+                        StringPiece filename_node_name,
+                        StringPiece cache_node_name);
+
 // Creates a test NodeDef for FilterDataset.
 NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
                        StringPiece function_name = "IsZero");
@@ -38,13 +43,13 @@
                             StringPiece drop_remainder_node_name,
                             StringPiece function_name = "XTimesTwo");
 
-// Creates a test NodeDef for ParallelInterleaveDataset.
-NodeDef MakeParallelInterleaveNode(StringPiece name,
-                                   StringPiece input_node_name,
-                                   StringPiece cycle_length_node_name,
-                                   StringPiece block_length_node_name,
-                                   StringPiece num_parallel_calls_node_name,
-                                   StringPiece function_name, bool sloppy);
+// Creates a test NodeDef for ParallelInterleaveDatasetV2.
+NodeDef MakeParallelInterleaveV2Node(StringPiece name,
+                                     StringPiece input_node_name,
+                                     StringPiece cycle_length_node_name,
+                                     StringPiece block_length_node_name,
+                                     StringPiece num_parallel_calls_node_name,
+                                     StringPiece function_name, bool sloppy);
 
 // Creates a test NodeDef for ParallelMapDataset.
 NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name,
@@ -56,6 +61,11 @@
                              StringPiece num_parallel_calls_node_name,
                              bool sloppy);
 
+// Creates a test NodeDef for ShuffleDatasetV2.
+NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
+                          StringPiece buffer_size_node_name,
+                          StringPiece seed_generator_node_name);
+
 }  // namespace graph_tests_utils
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index a11717e..ce56b7c 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -158,6 +158,46 @@
       graph);
 }
 
+Status GetScalarConstNodeValueHelper(
+    const NodeDef& node, DataType dtype,
+    const std::function<void(const Tensor&)>& get_value) {
+  if (node.op() != kConstOpName)
+    return errors::InvalidArgument("Node ", node.name(),
+                                   " is not a Const node. Op: ", node.op());
+
+  Tensor tensor;
+  TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
+  if (!TensorShapeUtils::IsScalar(tensor.shape())) {
+    return errors::InvalidArgument(
+        "Node ", node.name(),
+        " should be a scalar but has shape: ", tensor.shape());
+  }
+
+  if (tensor.dtype() != dtype) {
+    return errors::InvalidArgument(
+        "Node ", node.name(), " should have type ", DataTypeString(dtype),
+        " but has type: ", DataTypeString(tensor.dtype()));
+  }
+
+  get_value(tensor);
+
+  return Status::OK();
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, int64* value) {
+  return GetScalarConstNodeValueHelper(
+      node, DT_INT64,
+      [value](const Tensor& tensor) { *value = tensor.scalar<int64>()(); });
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
+  return GetScalarConstNodeValueHelper(
+      node, DT_BOOL,
+      [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
+}
+
 bool Compare(const GraphDef& g1, const GraphDef& g2) {
   if (g1.node_size() != g2.node_size()) {
     return false;
@@ -240,12 +280,12 @@
   return graph.GetRegularFanin(input_port).node;
 }
 
-Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types) {
+Status GetDatasetOutputTypesAttr(const NodeDef& node,
+                                 DataTypeVector* output_types) {
   // We don't name the output_types attr consistently, so should check for both.
   for (const string& attr_name : {"output_types", "Toutput_types"}) {
     if (node.attr().contains(attr_name)) {
-      *output_types = node.attr().at(attr_name);
-      return Status::OK();
+      return GetNodeAttr(node, attr_name, output_types);
     }
   }
   return errors::InvalidArgument("Could not find output_types attr for node: ",
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 341eec4..87c9831 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -80,6 +80,21 @@
 template <>
 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);
 
+// Retrieves the value of a const node. Returns an error
+// if the node is not const, or its value is of a different type.
+template <typename T>
+Status GetScalarConstNodeValue(const NodeDef& node, T* value) {
+  // is_same is an idiomatic hack for making it compile if not instantiated.
+  // Replacing with false will result in a compile-time error.
+  static_assert(!std::is_same<T, T>::value,
+                "Invalid specialization of this method fo rtype T.");
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, int64* value);
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, bool* value);
+
 // Checks whether the two graphs are the same.
 bool Compare(const GraphDef& g1, const GraphDef& g2);
 
@@ -114,7 +129,8 @@
                       int64 i);
 
 // Gets the attr corresponding to a dataset node's output types, if it exists.
-Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types);
+Status GetDatasetOutputTypesAttr(const NodeDef& node,
+                                 DataTypeVector* output_types);
 
 // Returns the list of indices of all nodes with the given op or empty list if
 // no such node exists.
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 93df72a..125f2e3 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -85,6 +85,64 @@
   EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
 }
 
+TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
+  int64 result;
+  EXPECT_TRUE(GetScalarConstNodeValue<int64>(*int64_node, &result).ok());
+  EXPECT_EQ(result, 128);
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeBool) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
+  bool result;
+  EXPECT_TRUE(GetScalarConstNodeValue<bool>(*bool_node, &result).ok());
+  EXPECT_EQ(result, true);
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* non_const = AddScalarPlaceholder(DT_INT64, &graph);
+  int64 result;
+  Status s = GetScalarConstNodeValue<int64>(*non_const, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Placeholder is not a Const node. Op: Placeholder");
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
+  bool result;
+  Status s = GetScalarConstNodeValue<bool>(*int64_node, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Const should have type bool but has type: int64");
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
+  NodeDef node;
+  node.set_name("Const");
+  node.set_op("Const");
+
+  (*node.mutable_attr())["dtype"].set_type(DT_INT64);
+  auto tensor = (*node.mutable_attr())["value"].mutable_tensor();
+  tensor->set_dtype(DT_INT64);
+  tensor->mutable_tensor_shape()->mutable_dim()->Add()->set_size(1);
+  tensor->add_int64_val(128);
+
+  int64 result;
+  Status s = GetScalarConstNodeValue<int64>(node, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Const should be a scalar but has shape: [1]");
+}
+
 TEST(GraphUtilsTest, Compare) {
   GraphDef graph_def_a;
   MutableGraphView graph_a(&graph_def_a);
diff --git a/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc b/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc
index 479ce1e..bd37bec 100644
--- a/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc
+++ b/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc
@@ -30,6 +30,9 @@
 namespace grappler {
 namespace {
 
+constexpr char kLegacyAutotune[] = "legacy_autotune";
+constexpr char kPrefetchDataset[] = "PrefetchDataset";
+
 constexpr std::array<const char*, 4> kAsyncDatasetOps = {
     "ExperimentalMapAndBatchDataset",
     "ParallelMapDataset",
@@ -65,7 +68,7 @@
   for (const NodeDef* async_dataset_node : async_datasets) {
     NodeDef prefetch_node;
     graph_utils::SetUniqueGraphNodeName(
-        strings::StrCat("autotune/prefetch_", async_dataset_node->name()),
+        strings::StrCat("inject/prefetch_", async_dataset_node->name()),
         graph.graph(), &prefetch_node);
     prefetch_node.set_op("PrefetchDataset");
     // `input_dataset` input
@@ -82,6 +85,14 @@
     TF_RETURN_IF_ERROR(
         graph.UpdateFanouts(async_dataset_node->name(), added_node->name()));
   }
+
+  for (NodeDef& node : *output->mutable_node()) {
+    if (node.op() == kPrefetchDataset) {
+      (*node.mutable_attr())[kLegacyAutotune].set_b(false);
+      stats->num_changes++;
+    }
+  }
+
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/data/inject_prefetch.h b/tensorflow/core/grappler/optimizers/data/inject_prefetch.h
index 8f51dab..3b3a712 100644
--- a/tensorflow/core/grappler/optimizers/data/inject_prefetch.h
+++ b/tensorflow/core/grappler/optimizers/data/inject_prefetch.h
@@ -30,7 +30,7 @@
   InjectPrefetch() = default;
   ~InjectPrefetch() override = default;
 
-  string name() const override { return "autotune_buffers"; };
+  string name() const override { return "inject_prefetch"; };
 
   Status Init(
       const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
diff --git a/tensorflow/core/grappler/optimizers/data/inject_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/inject_prefetch_test.cc
new file mode 100644
index 0000000..9c75867
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/inject_prefetch_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/inject_prefetch.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MakeStateless, ParallelMap) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       NDef("num_parallel_calls", "Const", {},
+            {{"value", 1}, {"dtype", DT_INT32}}),
+       graph_tests_utils::MakeParallelMapNode("map", "range",
+                                              "num_parallel_calls", "XTimesTwo",
+                                              /*sloppy=*/false)},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  InjectPrefetch optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
+  int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
+  EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
+}
+
+TEST(MakeStateless, ParallelInterleave) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("num_parallel_calls", "Const", {},
+            {{"value", 1}, {"dtype", DT_INT32}}),
+       graph_tests_utils::MakeParallelInterleaveV2Node(
+           "interleave", "range", "cycle_length", "block_length",
+           "num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  InjectPrefetch optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
+  int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
+  EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
+}
+
+TEST(MakeStateless, MapAndBatch) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       NDef("batch_size", "Const", {}, {{"value", 32}, {"dtype", DT_INT64}}),
+       NDef("num_parallel_calls", "Const", {},
+            {{"value", 1}, {"dtype", DT_INT64}}),
+       NDef("drop_remainder", "Const", {},
+            {{"value", false}, {"dtype", DT_BOOL}}),
+       graph_tests_utils::MakeMapAndBatchNode(
+           "map_and_batch", "range", "batch_size", "num_parallel_calls",
+           "drop_remainder", "XTimesTwo")},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  InjectPrefetch optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
+  int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
+  EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
index 24431f4..89bb3f3 100644
--- a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
@@ -29,10 +29,6 @@
 namespace grappler {
 namespace {
 
-using graph_tests_utils::MakeParallelInterleaveNode;
-using graph_tests_utils::MakeParallelMapNode;
-using graph_tests_utils::MakeParseExampleNode;
-
 TEST(MakeSloppy, ParallelInterleave) {
   using test::function::NDef;
   GrapplerItem item;
@@ -45,9 +41,9 @@
        NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
        NDef("num_parallel_calls", "Const", {},
             {{"value", 1}, {"dtype", DT_INT32}}),
-       MakeParallelInterleaveNode("interleave", "range", "cycle_length",
-                                  "block_length", "num_parallel_calls",
-                                  "XTimesTwo", /*sloppy=*/false)},
+       graph_tests_utils::MakeParallelInterleaveV2Node(
+           "interleave", "range", "cycle_length", "block_length",
+           "num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
       // FunctionLib
       {
           test::function::XTimesTwo(),
@@ -71,8 +67,9 @@
        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
        NDef("num_parallel_calls", "Const", {},
             {{"value", 1}, {"dtype", DT_INT32}}),
-       MakeParallelMapNode("map", "range", "num_parallel_calls", "XTimesTwo",
-                           /*sloppy=*/false)},
+       graph_tests_utils::MakeParallelMapNode("map", "range",
+                                              "num_parallel_calls", "XTimesTwo",
+                                              /*sloppy=*/false)},
       // FunctionLib
       {
           test::function::XTimesTwo(),
@@ -96,8 +93,9 @@
        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
        NDef("num_parallel_calls", "Const", {},
             {{"value", 1}, {"dtype", DT_INT32}}),
-       MakeParseExampleNode("parse_example", "range", "num_parallel_calls",
-                            /*sloppy=*/false)},
+       graph_tests_utils::MakeParseExampleNode("parse_example", "range",
+                                               "num_parallel_calls",
+                                               /*sloppy=*/false)},
       // FunctionLib
       {});
 
diff --git a/tensorflow/core/grappler/optimizers/data/make_stateless.cc b/tensorflow/core/grappler/optimizers/data/make_stateless.cc
new file mode 100644
index 0000000..a18ca58
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/make_stateless.cc
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char kCacheDataset[] = "CacheDataset";
+constexpr char kCacheDatasetV2[] = "CacheDatasetV2";
+constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
+constexpr char kShuffleDataset[] = "ShuffleDataset";
+constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
+
+}  // namespace
+
+Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
+                                              const GrapplerItem& item,
+                                              GraphDef* output,
+                                              OptimizationStats* stats) {
+  *output = item.graph;
+  MutableGraphView graph(output);
+
+  NodeDef* zero_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
+
+  for (NodeDef& node : *output->mutable_node()) {
+    if (node.op() == kShuffleDatasetV2) {
+      *node.mutable_op() = kShuffleDataset;
+      // remove `seed_generator` input
+      node.mutable_input()->RemoveLast();
+      // add `seed` input
+      node.add_input(zero_node->name());
+      // add `seed2` input
+      node.add_input(zero_node->name());
+      // set `reshuffle_each_iteration` attr
+      (*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
+    } else if (node.op() == kCacheDatasetV2) {
+      *node.mutable_op() = kCacheDataset;
+      // remove `cache` input
+      node.mutable_input()->RemoveLast();
+    }
+  }
+
+  return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MakeStateless, "make_stateless");
+
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/make_stateless.h b/tensorflow/core/grappler/optimizers/data/make_stateless.h
new file mode 100644
index 0000000..702eb4c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/make_stateless.h
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
+
+#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This rewrite replaces transformations that depend on external state (such as
+// `ShuffleDatasetV2`) with a stateless alternative so that the input pipeline
+// graph can be cloned.
+//
+// Note that this rewrites may change observable behavior of the input pipeline
+// (e.g. `reshuffle_each_iteration` will not work) and is a stop gap solution
+// to enable cloning until a better mechanism exists.
+class MakeStateless : public TFDataOptimizerBase {
+ public:
+  MakeStateless() = default;
+  ~MakeStateless() override = default;
+
+  string name() const override { return "make_stateless"; }
+
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return Status::OK();
+  }
+
+  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
+                                 GraphDef* output,
+                                 OptimizationStats* stats) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override {}
+};
+
+}  // namespace grappler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc b/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc
new file mode 100644
index 0000000..a30b7c6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MakeStateless, Cache) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_INT64}}),
+       NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
+       graph_tests_utils::MakeCacheV2Node("cache", "range", "filename",
+                                          "handle")},
+      {});
+
+  MakeStateless optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("cache", output));
+  int index = graph_utils::FindGraphNodeWithName("cache", output);
+  EXPECT_EQ(output.node(index).op(), "CacheDataset");
+  EXPECT_EQ(output.node(index).input_size(), 2);
+}
+
+TEST(MakeStateless, Shuffle) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       NDef("buffer_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT64}}),
+       NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
+       graph_tests_utils::MakeShuffleV2Node("shuffle", "range", "buffer_size",
+                                            "handle")},
+      {});
+
+  MakeStateless optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("shuffle", output));
+  int index = graph_utils::FindGraphNodeWithName("shuffle", output);
+  EXPECT_EQ(output.node(index).op(), "ShuffleDataset");
+  EXPECT_EQ(output.node(index).input_size(), 4);
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
index b364296..7a42fab 100644
--- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
@@ -36,7 +36,8 @@
     std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;
 
 // tf.data optimizations, in the order we want to perform them.
-constexpr std::array<const char*, 15> kTFDataOptimizations = {
+constexpr std::array<const char*, 16> kTFDataOptimizations = {
+    "make_stateless",
     "noop_elimination",
     "shuffle_and_repeat_fusion",
     "map_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc
index bcea9ee..821b486 100644
--- a/tensorflow/core/grappler/optimizers/data/rebatch.cc
+++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -38,7 +39,7 @@
     return errors::InvalidArgument(
         "Cannot initialize RebatchOptimizer without config.");
 
-  num_workers_ = config->parameter_map().at("num_workers").i();
+  num_replicas_ = config->parameter_map().at("num_replicas").i();
   use_fallback_ = config->parameter_map().at("use_fallback").b();
   return Status::OK();
 }
@@ -50,14 +51,19 @@
 constexpr char kIdentityOp[] = "Identity";
 constexpr char kSubOp[] = "Sub";
 constexpr char kTruncateDivOp[] = "TruncateDiv";
+constexpr char kOutputShapesAttr[] = "output_shapes";
+constexpr char kOutputTypesAttr[] = "output_types";
+constexpr char kTOutputTypesAttr[] = "Toutput_types";
+constexpr char kBatchOp[] = "BatchDataset";
+constexpr char kBatchV2Op[] = "BatchDatasetV2";
+constexpr char kPaddedBatchOp[] = "PaddedBatchDataset";
+constexpr char kPaddedBatchV2Op[] = "PaddedBatchDatasetV2";
+constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
+constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 
 constexpr std::array<const char*, 6> kBatchDatasetOps = {
-    "BatchDataset",
-    "BatchDatasetV2",
-    "ExperimentalMapAndBatchDataset",
-    "MapAndBatchDataset",
-    "PaddedBatchDataset",
-    "PaddedBatchDatasetV2"};
+    kBatchOp,       kBatchV2Op,      kMapAndBatchOp, kExperimentalMapAndBatchOp,
+    kPaddedBatchOp, kPaddedBatchV2Op};
 
 constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ConcatenateDataset",
@@ -69,8 +75,9 @@
 // batch dimension. Furthermore, transformations like "Skip" may change
 // the semantics of the dataset (since we'd be skipping N minibatches instead
 // of N batches).
-constexpr std::array<const char*, 20> kPassThroughOps = {
+constexpr std::array<const char*, 22> kPassThroughOps = {
     "CacheDataset",
+    "CacheDatasetV2",
     "ExperimentalScanDataset",
     "ExperimentalParseExampleDataset",
     "FilterDataset",
@@ -87,9 +94,11 @@
     "ShardDataset",
     "ShuffleAndRepeatDataset",
     "ShuffleDataset",
+    "ShuffleDatasetV2",
     "SkipDataset",
     "TakeDataset",
-    "WindowDataset"};
+    "WindowDataset",
+};
 
 constexpr std::array<const char*, 5> kFuncDatasetOps = {
     "ExperimentalGroupByWindowDataset",
@@ -116,17 +125,24 @@
     "TFRecordDataset",
 };
 
-NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
-                       const string& op, DataType type,
-                       MutableGraphView* graph) {
+NodeDef MakeBinaryNode(const string& input_x, const string& input_y,
+                       const string& op, DataType dtype) {
   NodeDef node;
   node.set_op(op);
   node.add_input(input_x);
   node.add_input(input_y);
-  graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node);
-  AddNodeAttr("T", type, &node);
+  AddNodeAttr("T", dtype, &node);
 
-  return graph->AddNode(std::move(node));
+  return node;
+}
+
+NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
+                       const string& op, DataType type, FunctionDef* fdef) {
+  NodeDef* node = fdef->add_node_def();
+  *node = MakeBinaryNode(input_x, input_y, op, type);
+  function_utils::SetUniqueFunctionNodeName(op, fdef, node);
+
+  return node;
 }
 
 // Adds a Const node to the FunctionDef.
@@ -160,6 +176,30 @@
   return Status::OK();
 }
 
+Status AddConstInt64Node(int64 value, FunctionDef* fdef, NodeDef** result) {
+  *result = fdef->add_node_def();
+  Tensor t(value);
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const")
+                         .Attr("dtype", DT_INT64)
+                         .Attr("value", t)
+                         .Finalize(*result));
+  function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result);
+
+  return Status::OK();
+}
+
+Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
+  *result = fdef->add_node_def();
+  Tensor t(value);
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const")
+                         .Attr("dtype", DT_BOOL)
+                         .Attr("value", t)
+                         .Finalize(*result));
+  function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result);
+
+  return Status::OK();
+}
+
 Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
                     NodeDef** result) {
   *result = fdef->add_node_def();
@@ -267,61 +307,72 @@
   return Status::OK();
 }
 
-Status UpdateOutputShapes(const string& node_name, int64 num_workers,
+Status UpdateOutputShapes(const string& node_name, int64 num_replicas,
                           MutableGraphView* graph) {
   NodeDef* node = graph->GetNode(node_name);
-  if (node->attr().contains("output_shapes")) {
-    AttrValue output_shapes = node->attr().at("output_shapes");
+  if (node->attr().contains(kOutputShapesAttr)) {
+    AttrValue output_shapes = node->attr().at(kOutputShapesAttr);
     for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) {
       if (!shape.unknown_rank() && shape.dim(0).size() != -1) {
-        shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_workers);
+        shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_replicas);
       }
     }
-    (*node->mutable_attr())["output_shapes"] = output_shapes;
+    (*node->mutable_attr())[kOutputShapesAttr] = output_shapes;
   }
   return Status::OK();
 }
 
+// Helper function to get the batch_size input node for a give batch node.
+int64 GetBatchSizeArgIndex(const NodeDef& batch_node) {
+  if (batch_node.op() == kExperimentalMapAndBatchOp ||
+      batch_node.op() == kMapAndBatchOp) {
+    // For MapAndBatch we take the 3rd last input.
+    return batch_node.input_size() - 3;
+  }
+  // For all the batching datasets the batch_size is input number 1 except for
+  // MapAndBatchDataset.
+  return 1;
+}
+
+Status MakeNewBatchSizeNode(const string& global_batch_size_name,
+                            int64 num_replicas, FunctionDef* fdef,
+                            NodeDef** result) {
+  NodeDef* one_node;
+  TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node));
+  NodeDef* num_replicas_node;
+  TF_RETURN_IF_ERROR(AddConstInt64Node(num_replicas, fdef, &num_replicas_node));
+
+  NodeDef* numerator_node =
+      AddBinaryNode(global_batch_size_name,
+                    strings::StrCat(num_replicas_node->name(), ":output:0"),
+                    kAddOp, DT_INT64, fdef);
+  numerator_node = AddBinaryNode(
+      strings::StrCat(numerator_node->name(), ":z:0"),
+      strings::StrCat(one_node->name(), ":output:0"), kSubOp, DT_INT64, fdef);
+
+  *result =
+      AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"),
+                    strings::StrCat(num_replicas_node->name(), ":output:0"),
+                    kTruncateDivOp, DT_INT64, fdef);
+  return Status::OK();
+}
+
 // Given a "batch" dataset node, we replace the `batch_size` input with a new
-// input that corresponds to the original input divided by `num_workers`. If
-// `num_workers` does not divide `batch_size` evenly, the value is rounded up.
-Status MutateBatchSize(const NodeDef& node, int64 num_workers,
+// input that corresponds to the original input divided by `num_replicas`.
+Status MutateBatchSize(const NodeDef& node, int64 num_replicas,
                        MutableGraphView* graph) {
   // For all the batching datasets the batch_size is input number 1 except for
   // MapAndBatchDataset.
-  int64 batch_size_arg_index = 1;
-  if (node.op() == "ExperimentalMapAndBatchDataset" ||
-      node.op() == "MapAndBatchDataset") {
-    // For MapAndBatch we take the 3rd last input.
-    batch_size_arg_index = node.input_size() - 3;
-  }
+  int64 batch_size_arg_index = GetBatchSizeArgIndex(node);
   NodeDef* batch_size_node =
       graph_utils::GetInputNode(node, *graph, batch_size_arg_index);
-  NodeDef* new_batch_size_node;
-  if (batch_size_node->op() == kConstOp) {
-    Tensor batch_size_tensor;
-    TF_RETURN_IF_ERROR(
-        GetNodeAttr(*batch_size_node, "value", &batch_size_tensor));
-    if (!TensorShapeUtils::IsScalar(batch_size_tensor.shape())) {
-      return errors::Internal("Batch size node shape should be scalar");
-    }
-    int64 batch_size = batch_size_tensor.scalar<int64>()();
-    batch_size = (batch_size + num_workers - 1) / num_workers;
-    new_batch_size_node =
-        graph_utils::AddScalarConstNode<int64>(batch_size, graph);
-  } else {
-    NodeDef* one_node = graph_utils::AddScalarConstNode<int64>(1, graph);
-    NodeDef* num_workers_node =
-        graph_utils::AddScalarConstNode<int64>(num_workers, graph);
-    NodeDef* numerator_node =
-        AddBinaryNode(batch_size_node->name(), num_workers_node->name(), kAddOp,
-                      DT_INT64, graph);
-    numerator_node = AddBinaryNode(numerator_node->name(), one_node->name(),
-                                   kSubOp, DT_INT64, graph);
-    new_batch_size_node =
-        AddBinaryNode(numerator_node->name(), num_workers_node->name(),
-                      kTruncateDivOp, DT_INT64, graph);
-  }
+  int64 batch_size;
+  TF_RETURN_IF_ERROR(
+      graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size));
+  DCHECK_EQ(batch_size % num_replicas, 0);
+  batch_size = batch_size / num_replicas;
+  NodeDef* new_batch_size_node =
+      graph_utils::AddScalarConstNode<int64>(batch_size, graph);
   // We don't call UpdateFanouts here because CSE elimination might lead to
   // multiple nodes sharing the same batch size constant node. This is also
   // why we don't delete batch_size_node as well.
@@ -330,7 +381,183 @@
   return Status::OK();
 }
 
-Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
+Status AddFlatMapNode(const string& input_dataset,
+                      gtl::ArraySlice<string> other_arguments,
+                      gtl::ArraySlice<DataType> t_arguments,
+                      const FunctionDef& flat_map_fn,
+                      const AttrValue& output_shapes,
+                      const DataTypeVector& output_types,
+                      FunctionLibraryDefinition* flib, MutableGraphView* graph,
+                      NodeDef** result) {
+  TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn));
+  AttrValue f;
+  f.mutable_func()->set_name(flat_map_fn.signature().name());
+
+  NodeDef flat_map_node;
+  flat_map_node.set_op("FlatMapDataset");
+  flat_map_node.add_input(input_dataset);
+  for (const auto& arg : other_arguments) {
+    flat_map_node.add_input(arg);
+  }
+  AddNodeAttr("f", f, &flat_map_node);
+  AddNodeAttr("Targuments", t_arguments, &flat_map_node);
+  AddNodeAttr(kOutputShapesAttr, output_shapes, &flat_map_node);
+  AddNodeAttr(kOutputTypesAttr, output_types, &flat_map_node);
+
+  graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(),
+                                      &flat_map_node);
+  *result = graph->AddNode(std::move(flat_map_node));
+  return Status::OK();
+}
+
+// def flat_map_fn(*batched_components):
+//   ds = tf.data.Dataset.from_tensor_slices(batched_components)
+//   return ds.batch(minibatch_size, drop_remainder=False)
+Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
+                                int64 num_replicas, FunctionDef* result) {
+  NodeDef* tensor_slice_node = result->add_node_def();
+  tensor_slice_node->set_op("TensorSliceDataset");
+  for (int i = 0; i < dtypes.size(); ++i) {
+    auto* input_arg = function_utils::AddFunctionInput(
+        strings::StrCat("args_", i), result, dtypes.at(i));
+    tensor_slice_node->add_input(input_arg->name());
+  }
+  AddNodeAttr(kTOutputTypesAttr, dtypes, tensor_slice_node);
+
+  // The output_shapes attr here doesn't make a difference, since we
+  // set the output_shapes of the external FlatMap node.
+  AttrValue shapes;
+  SetUnknownShapes(dtypes.size(), &shapes);
+  AddNodeAttr(kOutputShapesAttr, shapes, tensor_slice_node);
+  function_utils::SetUniqueFunctionNodeName("rebatch/from_tensor_slices",
+                                            result, tensor_slice_node);
+
+  NodeDef* false_node;
+  TF_RETURN_IF_ERROR(AddConstBoolNode(false, result, &false_node));
+  NodeDef* batch_node = result->add_node_def();
+  batch_node->set_op(kBatchV2Op);
+  batch_node->add_input(
+      strings::StrCat(tensor_slice_node->name(), ":handle:0"));
+
+  // `batch_size` input
+  // Here, we capture the original batch size from outside the flat map fn.
+  auto* original_batch_size =
+      function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
+  NodeDef* new_batch_size;
+  TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
+      original_batch_size->name(), num_replicas, result, &new_batch_size));
+  batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
+
+  // `drop_remainder` input
+  batch_node->add_input(strings::StrCat(false_node->name(), ":output:0"));
+  AddNodeAttr(kOutputTypesAttr, dtypes, batch_node);
+  AddNodeAttr(kOutputShapesAttr, shapes, batch_node);
+  function_utils::SetUniqueFunctionNodeName("rebatch/batch", result,
+                                            batch_node);
+  function_utils::AddFunctionOutputWithUniqueName(
+      "output", strings::StrCat(batch_node->name(), ":handle:0"), result,
+      DT_VARIANT);
+  // Because TensorSliceDataset is stateful, we set the function to stateful.
+  result->mutable_signature()->set_is_stateful(true);
+
+  return Status::OK();
+}
+
+// Rewrite graph to add
+// `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).
+//     batch(minibatch_size, drop_remainder=False))`
+// after the batch node. This ensures that the sum of the minibatch sizes
+// in a step adds up to the global batch size. However, since this adds
+// additional data copies (both from_tensor_slices and batch), we only use
+// this approach when necessary, i.e. when we need to drop remainder on the
+// global batch, or when the global batch size does not divide num_replicas
+// evenly.
+Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
+                     FunctionLibraryDefinition* flib, MutableGraphView* graph) {
+  // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).
+  //     batch(minibatch_size, drop_remainder=False))`
+  FunctionDef flat_map_fn;
+  FunctionDefLibrary lib = flib->ToProto();
+  graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib,
+                                          &flat_map_fn);
+  DataTypeVector dtypes;
+  TF_RETURN_IF_ERROR(
+      graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes));
+  TF_RETURN_IF_ERROR(
+      CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
+
+  int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
+
+  NodeDef* flat_map_node;
+
+  AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
+  for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) {
+    if (!shape.unknown_rank() && shape.dim(0).size() != -1) {
+      // Because the flat map function uses drop_remainder = False,
+      // the shape might be unknown
+      auto old_dim = shape.dim(0).size();
+      auto new_dim = old_dim % num_replicas == 0 ? old_dim / num_replicas : -1;
+      shape.mutable_dim(0)->set_size(new_dim);
+    }
+  }
+
+  TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
+                                    {batch_node.input(batch_size_index)},
+                                    {DT_INT64}, flat_map_fn, output_shapes,
+                                    dtypes, flib, graph, &flat_map_node));
+
+  TF_RETURN_IF_ERROR(
+      graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
+
+  return Status::OK();
+}
+
+// There are several things we do here, depending on the values of
+// batch_size and drop_remainder.
+// (1) If batch size is known and divisible by num_replicas, and drop_remainder
+// is known to be False, we mutate the batch size directly.
+//   .batch(global_batch_size) -> .batch(global_batch_size // num_replicas)
+// (2) Otherwise, we add a flat_map transformation to preserve the global batch
+// size across the replicas and to preserve the drop remainder behavior.
+bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node,
+                                   int64 num_replicas,
+                                   MutableGraphView* graph) {
+  int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node);
+  NodeDef* batch_size_node =
+      graph_utils::GetInputNode(batch_node, *graph, batch_size_arg_index);
+
+  int64 batch_size;
+  Status s =
+      graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size);
+  // If batch size is unknown or indivisible by num replicas, we don't
+  // mutate it directly
+  if (!s.ok() || batch_size % num_replicas != 0) return false;
+
+  if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) {
+    // These ops don't have a `drop_remainder` input, and behave like
+    // drop_remainder is False.
+    return true;
+  }
+
+  // drop_remainder is the final input on the other batch nodes.
+  NodeDef* drop_remainder_node = graph_utils::GetInputNode(
+      batch_node, *graph, batch_node.input_size() - 1);
+  bool drop_remainder;
+  s = graph_utils::GetScalarConstNodeValue(*drop_remainder_node,
+                                           &drop_remainder);
+  return s.ok() && !drop_remainder;
+}
+
+Status RewriteBatchNode(const NodeDef& batch_node, int64 num_replicas,
+                        FunctionLibraryDefinition* flib,
+                        MutableGraphView* graph) {
+  if (ShouldMutateBatchSizeDirectly(batch_node, num_replicas, graph)) {
+    return MutateBatchSize(batch_node, num_replicas, graph);
+  }
+  return AppendFlatMap(batch_node, num_replicas, flib, graph);
+}
+
+Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas,
                      bool use_fallback, GraphDef* output);
 
 // Helper function that starts from a node in the graph and recurses into its
@@ -341,16 +568,16 @@
 //      as they are datasets themselves.
 // 3. Core dataset ops + Identity op: Recurses into first input parameter.
 // 4. FlatMap type mapping dataset ops: Recurses into the function definition.
-Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
+Status RecursivelyHandleOp(const NodeDef& node, int64 num_replicas,
                            bool use_fallback, FunctionLibraryDefinition* flib,
                            MutableGraphView* graph) {
   if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
-    TF_RETURN_IF_ERROR(MutateBatchSize(node, num_workers, graph));
+    TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_replicas, flib, graph));
   } else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
     // For all multiple input datasets, all inputs are datasets themselves.
     for (int i = 0; i < node.input_size(); ++i) {
       NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
-      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers,
+      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas,
                                              use_fallback, flib, graph));
     }
   } else if (IsDatasetNodeOfType(node, kPassThroughOps) || IsRetval(node)) {
@@ -358,7 +585,7 @@
     // function body graph in place of function outputs, the input dataset is
     // input 0.
     NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
-    TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers,
+    TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas,
                                            use_fallback, flib, graph));
   } else if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
     const string func_name =
@@ -368,7 +595,7 @@
     TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
         *fdef, *flib, graph->graph()->versions().producer(), &f_item));
     GraphDef optimized_func_graph;
-    TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_workers, use_fallback,
+    TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_replicas, use_fallback,
                                      &optimized_func_graph));
 
     // Function body optimization might have created new specialized
@@ -397,12 +624,12 @@
   }
   // If we've successfully updated the batch size of this node or any nodes
   // in the dataset tree rooted in this node, we update the output_shapes attr.
-  TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
+  TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_replicas, graph));
   return Status::OK();
 }
 
 // Add nodes to the function to reshape arg to shape (-1, new_batch_dim, ...)
-Status ReshapeComponent(int new_batch_dim, StringPiece arg, DataType dtype,
+Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
                         FunctionDef* fdef, string* result) {
   // Const with value [0]
   NodeDef* const_vec_0;
@@ -452,47 +679,50 @@
   return Status::OK();
 }
 
-Status CreateFlatMapFn(int new_batch_dim, const AttrValue& types,
-                       FunctionDef* result) {
+// def flat_map_fn(*batched_components):
+//   return tf.data.Dataset.from_tensor_slices(
+//     [tf.reshape(c, (-1, new_batch_size, ...))
+//      for c in batched_components])
+Status CreateFlatMapFnWithReshape(int new_batch_dim,
+                                  const DataTypeVector& types,
+                                  FunctionDef* result) {
   std::vector<NodeDefBuilder::NodeOut> tensor_slice_dataset_inputs;
 
   // For each component of the dataset, we reshape it from shape
   // (old_batch_size, ...) to (-1, new_batch_size, ...)
-  // where new_batch_size = (old_batch_size + num_workers - 1) // num_workers
-  for (int i = 0; i < types.list().type_size(); ++i) {
-    string arg = strings::StrCat("args_", i);
-    auto* input_arg = result->mutable_signature()->mutable_input_arg()->Add();
-    input_arg->set_type(types.list().type(i));
-    input_arg->set_name(arg);
+  // where new_batch_size = (old_batch_size + num_replicas - 1) // num_replicas
+  for (int i = 0; i < types.size(); ++i) {
+    auto* input_arg = function_utils::AddFunctionInput(
+        strings::StrCat("args_", i), result, types.at(i));
 
     string reshape_node_name;
-    TF_RETURN_IF_ERROR(ReshapeComponent(
-        new_batch_dim, arg, types.list().type(i), result, &reshape_node_name));
+    TF_RETURN_IF_ERROR(ReshapeComponent(new_batch_dim, input_arg->name(),
+                                        types.at(i), result,
+                                        &reshape_node_name));
 
     tensor_slice_dataset_inputs.emplace_back(
-        strings::StrCat(reshape_node_name, ":output"), 0, types.list().type(i));
+        strings::StrCat(reshape_node_name, ":output"), 0, types.at(i));
   }
 
   // The output_shapes attr here doesn't make a difference, since we
   // set the output_shapes of the external FlatMap node.
   AttrValue shapes;
-  SetUnknownShapes(types.list().type_size(), &shapes);
+  SetUnknownShapes(types.size(), &shapes);
 
   NodeDef* tensor_slice_dataset = result->add_node_def();
   TF_RETURN_IF_ERROR(NodeDefBuilder("", "TensorSliceDataset")
                          .Input(tensor_slice_dataset_inputs)
                          .Attr("Toutput_types", types)
-                         .Attr("output_shapes", shapes)
+                         .Attr(kOutputShapesAttr, shapes)
                          .Finalize(tensor_slice_dataset));
   function_utils::SetUniqueFunctionNodeName("rebatch/tensor_slice_dataset",
                                             result, tensor_slice_dataset);
 
-  auto* output_arg = result->mutable_signature()->mutable_output_arg()->Add();
-  output_arg->set_name("output");
-  output_arg->set_type(DT_VARIANT);
+  function_utils::AddFunctionOutputWithUniqueName(
+      "output", strings::StrCat(tensor_slice_dataset->name(), ":handle:0"),
+      result, DT_VARIANT);
+  // Because TensorSliceDataset is stateful, we set the function to stateful.
   result->mutable_signature()->set_is_stateful(true);
-  (*result->mutable_ret())["output"] =
-      strings::StrCat(tensor_slice_dataset->name(), ":handle:0");
 
   return Status::OK();
 }
@@ -504,13 +734,13 @@
 //     return tf.data.Dataset.from_tensor_slices(
 //       tf.reshape(
 //         x,
-//         tf.concat([[-1, old_batch_dim / num_workers], tf.shape(x)[1:]], 0)
+//         tf.concat([[-1, old_batch_dim / num_replicas], tf.shape(x)[1:]], 0)
 //       )
 //     )
 //
 //   dataset = dataset.flat_map(fn)
 // ```
-Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
+Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph) {
   if (IsRetval(*fetch_node) || fetch_node->op() == kIdentityOp) {
@@ -524,53 +754,43 @@
   // because of the use of the "Reshape" op. This ensures that the error is
   // surfaced correctly.
   AttrValue output_shapes;
-  if (!fetch_node->attr().contains("output_shapes")) {
+  if (!fetch_node->attr().contains(kOutputShapesAttr)) {
     return errors::InvalidArgument(
         "Cannot use rebatching fallback without output_shapes attr. Node: ",
         fetch_node->name(), " Op: ", fetch_node->op());
   } else {
-    output_shapes = fetch_node->attr().at("output_shapes");
+    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
   }
   int batch_dim;
   TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
-  if (batch_dim % num_workers != 0) {
+  if (batch_dim % num_replicas != 0) {
     return errors::InvalidArgument(
         "Cannot use rebatching fallback when batch dimension doesn't divide "
-        "num_workers evenly.");
+        "num_replicas evenly.");
   }
 
   // Create the flat map fn
   FunctionDef flat_map_fn;
   FunctionDefLibrary lib = flib->ToProto();
-  graph_utils::SetUniqueGraphFunctionName("flat_map_fn", &lib, &flat_map_fn);
+  graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib,
+                                          &flat_map_fn);
 
   // Get types of input arguments from the output types of the final dataset.
-  AttrValue output_types;
+  DataTypeVector output_types;
   TF_RETURN_IF_ERROR(
       graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
+  TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
+                                                output_types, &flat_map_fn));
+
+  NodeDef* flat_map_node;
+  TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
+                                    {}, {}, flat_map_fn, output_shapes,
+                                    output_types, flib, graph, &flat_map_node));
   TF_RETURN_IF_ERROR(
-      CreateFlatMapFn(batch_dim / num_workers, output_types, &flat_map_fn));
+      UpdateOutputShapes(flat_map_node->name(), num_replicas, graph));
 
-  TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn));
-  AttrValue fn;
-  fn.mutable_func()->set_name(flat_map_fn.signature().name());
-
-  NodeDef flat_map_node;
   TF_RETURN_IF_ERROR(
-      NodeDefBuilder("", "FlatMapDataset")
-          .Input(fetch_node->name(), 0, DT_VARIANT)
-          .Input(std::vector<NodeDefBuilder::NodeOut>())  // other_arguments
-          .Attr("f", fn)
-          .Attr("Targuments", std::vector<DataType>())
-          .Attr("output_types", output_types)
-          .Attr("output_shapes", output_shapes)
-          .Finalize(&flat_map_node));
-  graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(),
-                                      &flat_map_node);
-  NodeDef* added = graph->AddNode(std::move(flat_map_node));
-  TF_RETURN_IF_ERROR(UpdateOutputShapes(added->name(), num_workers, graph));
-
-  TF_RETURN_IF_ERROR(graph->UpdateFanouts(fetch_node->name(), added->name()));
+      graph->UpdateFanouts(fetch_node->name(), flat_map_node->name()));
 
   return Status::OK();
 }
@@ -578,7 +798,7 @@
 // Helper function that given a GrapplerItem generates a mutated graph def
 // with the batch size changed. The GrapplerItem could be generated from the
 // main graph or could be a function graph.
-Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
+Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas,
                      bool use_fallback, GraphDef* output) {
   *output = item.graph;
   MutableGraphView graph(output);
@@ -588,18 +808,18 @@
   NodeDef* sink_node;
   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
 
-  Status s =
-      RecursivelyHandleOp(*sink_node, num_workers, use_fallback, &flib, &graph);
+  Status s = RecursivelyHandleOp(*sink_node, num_replicas, use_fallback, &flib,
+                                 &graph);
   if (!s.ok()) {
     if (use_fallback) {
-      VLOG(1) << "Couldn't find a batch transformation. Using a fallback method"
-                 " to rebatch dataset.";
+      VLOG(1) << "Failed to rebatch by rewriting the batch transformation ("
+              << s << "). Using a fallback method instead.";
       // If RecursivelyHandleOp fails, we reset `graph` to use the original,
       // graph, since that function may have mutated `graph`.
       *output = item.graph;
       graph = MutableGraphView(output);
       TF_RETURN_IF_ERROR(
-          RebatchWithFallback(sink_node, num_workers, &flib, &graph));
+          RebatchWithFallback(sink_node, num_replicas, &flib, &graph));
     } else {
       // Return the error
       return s;
@@ -618,7 +838,7 @@
   *output = item.graph;
   MutableGraphView graph(output);
 
-  TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, use_fallback_, output));
+  TF_RETURN_IF_ERROR(OptimizeGraph(item, num_replicas_, use_fallback_, output));
   stats->num_changes++;
   return Status::OK();
 }
diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.h b/tensorflow/core/grappler/optimizers/data/rebatch.h
index 75c9658..028e690 100644
--- a/tensorflow/core/grappler/optimizers/data/rebatch.h
+++ b/tensorflow/core/grappler/optimizers/data/rebatch.h
@@ -23,7 +23,7 @@
 namespace grappler {
 
 // This optimizer changes the batch size of the output dataset by dividing the
-// current batch size by parameter `num_workers`. Currently, this works only
+// current batch size by parameter `num_replicas`. Currently, this works only
 // for very simple pipelines with a single BatchDatasetV2 transformation.
 class RebatchOptimizer : public TFDataOptimizerBase {
  public:
@@ -43,7 +43,7 @@
                 const GraphDef& optimize_output, double result) override;
 
  private:
-  int64 num_workers_;
+  int64 num_replicas_;
   bool use_fallback_;
 };
 
diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc
index 1ccc00e..6d1aab0 100644
--- a/tensorflow/core/grappler/optimizers/data/slack.cc
+++ b/tensorflow/core/grappler/optimizers/data/slack.cc
@@ -51,8 +51,9 @@
 constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset", "ConcatenateDataset"};
 
-constexpr std::array<const char*, 19> kPassThroughOps = {
+constexpr std::array<const char*, 21> kPassThroughOps = {
     "CacheDataset",
+    "CacheDatasetV2",
     "ExperimentalMaxIntraOpParallelismDataset",
     "ExperimentalPrivateThreadPoolDataset",
     "FilterDataset",
@@ -68,9 +69,11 @@
     "ShardDataset",
     "ShuffleAndRepeatDataset",
     "ShuffleDataset",
+    "ShuffleDatasetV2",
     "SkipDataset",
     "TakeDataset",
-    "WindowDataset"};
+    "WindowDataset",
+};
 
 }  // namespace
 
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 2247f81..d041797 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_protos_all")
 
 package(
     default_visibility = ["//visibility:private"],
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index b4f5c36..09fc57b 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -146,8 +146,7 @@
 bool MarkedNoSpecialize(const FunctionDef& fdef) {
   const auto attr = AttrSlice(&fdef.attr());
   bool nospecialize = false;
-  return GetNodeAttr(attr, kNoSpecializeAttr, &nospecialize).ok() &&
-         nospecialize;
+  return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
 }
 
 // Specialized function instantiation type parameters, body parameters, and
@@ -784,18 +783,17 @@
 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
 
-// Checks if boolean attribute is defined and it's value is 'true'.
+// Checks if boolean attribute is defined and its value is 'true'.
 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
   bool match;
-  Status s = GetNodeAttr(n->attrs(), attr_name, &match);
-  return s.ok() && match;
+  bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
+  return found && match;
 }
 
 // Checks if string attribute is defined and it's not empty.
 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
-  string match;
-  Status s = GetNodeAttr(n->attrs(), attr_name, &match);
-  return s.ok() && !match.empty();
+  const string& value = GetNodeAttrString(n->attrs(), attr_name);
+  return !value.empty();
 }
 
 bool LowerUsingSwitchMergeIsOn(const Node* n) {
@@ -1216,11 +1214,11 @@
       AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
 
       if (n->IsIfNode()) {
-        TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
+        TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
       } else if (n->type_string() == "Case") {
-        TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false));
-      } else if (n->type_string() == "While") {
-        TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
+        TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
+      } else if (n->IsWhileNode()) {
+        TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), false));
       }
       continue;
     }
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
index 38393e1..a254052 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
@@ -40,6 +40,12 @@
 constexpr float kVoltaGPURatioThreshold = 0.5;
 constexpr float kConv2DGPUFP16Threshold = 0.5;
 
+struct MutableNodeViewFormatter {
+  void operator()(std::string* out, utils::MutableNodeView* node_view) const {
+    absl::StrAppend(out, node_view->node()->name());
+  }
+};
+
 inline std::pair<int, int> GetNumGPUs(const Cluster& cluster) {
   auto devices = cluster.GetDevices();
   int num_gpus = 0;
@@ -93,6 +99,8 @@
     }
   }
 
+  if (num_conv2d_gpu == 0) return false;
+
   return (static_cast<float>(num_conv2d_gpu_fp16) /
           static_cast<float>(num_conv2d_gpu)) >= kConv2DGPUFP16Threshold;
 }
@@ -267,12 +275,17 @@
   utils::MutableGraphView* graph_view = context->graph_view.get();
   utils::Mutation* mutation = graph_view->GetMutationBuilder();
 
+  absl::flat_hash_set<utils::MutableNodeView*> cancelled_transposes;
+
   const int num_nodes = graph_view->NumNodes();
   for (int i = 0; i < num_nodes; ++i) {
     // Transpose node after Pad.
     auto* transpose_after = graph_view->GetNode(i);
     if (!IsTranspose(*transpose_after->node())) continue;
 
+    // This transpose was already cancelled in previous loop iteration.
+    if (cancelled_transposes.contains(transpose_after)) continue;
+
     // Pad node.
     const auto& transpose_after_fanin = transpose_after->GetRegularFanin(0);
     auto* pad = transpose_after_fanin.node_view();
@@ -306,10 +319,34 @@
                                         &permute_t))
       continue;
 
-    VLOG(0) << "Cancel transpose node pair around pad node:"
+    // Pad output might be used multiple times by different Transpose nodes. If
+    // they all have identical permutation, we can cancel all of them.
+    std::vector<utils::MutableNodeView*> pad_fanout_transposes;
+    pad_fanout_transposes.emplace_back(transpose_after);
+
+    bool pad_has_unsupported_fanout = false;
+    for (auto& fanout : pad->GetRegularFanout(0)) {
+      auto* extra_transpose = fanout.node_view();
+      if (extra_transpose == transpose_after) continue;
+
+      // Check that fanout is a Transpose identical to the transpose_after.
+      Tensor extra_permute_t;
+      if (!GetValueAttrFromConstInputNode(*extra_transpose, IsTranspose, 1,
+                                          &extra_permute_t) ||
+          extra_permute_t.tensor_data() != permute_t.tensor_data()) {
+        pad_has_unsupported_fanout = true;
+        break;
+      }
+
+      pad_fanout_transposes.emplace_back(extra_transpose);
+    }
+    if (pad_has_unsupported_fanout) continue;
+
+    VLOG(0) << "Cancel Transpose nodes around Pad:"
             << " transpose_before=" << transpose_before->node()->name()
-            << " pad=" << pad->node()->name()
-            << " transpose_after=" << transpose_after->node()->name();
+            << " pad=" << pad->node()->name() << " transpose_after="
+            << absl::StrJoin(pad_fanout_transposes, ",",
+                             MutableNodeViewFormatter());
 
     // Permute paddings in place according to permutation in second transpose.
     auto permutation_s = absl::Span<int32>(permute_t.flat<int32>().data(),
@@ -325,14 +362,16 @@
 
     // Transform Transpose nodes into Identity nodes.
     const auto transpose_to_identity =
-        [&mutation](utils::MutableNodeView* transpose) -> void {
+        [&cancelled_transposes,
+         &mutation](utils::MutableNodeView* transpose) -> void {
       mutation->UpdateNodeOp(transpose, "Identity");
       mutation->RemoveNodeAttr(transpose, "Tperm");
       mutation->RemoveRegularFanin(transpose, 1);
+      cancelled_transposes.insert(transpose);
     };
 
     transpose_to_identity(transpose_before);
-    transpose_to_identity(transpose_after);
+    absl::c_for_each(pad_fanout_transposes, transpose_to_identity);
   }
 
   return mutation->Apply();
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
index 3a6316e..fd5ae22 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
@@ -552,6 +552,8 @@
            {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
       NDef("transpose_1", "Transpose", {"pad", "perm_nchw_to_nhwc"},
            {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
+      NDef("transpose_2", "Transpose", {"pad", "perm_nchw_to_nhwc"},
+           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
   });
 
   GraphDef output;
@@ -575,17 +577,21 @@
       NDef("pad", "Pad", {"transpose_0", "paddings"},
            {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
       NDef("transpose_1", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
+      NDef("transpose_2", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
   });
 
   CompareGraphs(expected, output);
 
   Tensor x = GenerateRandomTensor<DT_FLOAT>({2, 6, 6, 8});
-  item.fetch = {"transpose_1"};
+  item.fetch = {"transpose_1", "transpose_2"};
   item.feed.emplace_back("x", x);
   auto tensors_expected = EvaluateFetchNodes(item);
   GrapplerItem optimized = item.WithGraph(std::move(output));
   auto tensors = EvaluateFetchNodes(optimized);
+  ASSERT_EQ(tensors.size(), 2);
+  ASSERT_EQ(tensors_expected.size(), 2);
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
 }
 
 // TODO(yanzha): Add more complex Graph for test.
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 3ffc6ad..41e6c9f 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -558,6 +558,9 @@
                                          DeviceBase* cpu_device,
                                          ResourceMgr* resource_mgr,
                                          bool* value) {
+  VLOG(4) << "Evaluate bool op: op_node=" << op_node.name()
+          << " input0=" << constant_operand_0.name()
+          << " input1=" << constant_operand_1.name();
   TensorVector inputs;
 
   const TensorProto& raw_val_0 = constant_operand_0.attr().at("value").tensor();
@@ -604,10 +607,14 @@
 
   // CASE 1: Control is a constant.
   if (IsReallyConstant(*switch_predicate, feed_nodes)) {
+    VLOG(3) << "Found switch node with constant predicate:"
+            << " switch_node=" << switch_node.name()
+            << " switch_predicate=" << switch_predicate->name();
     Tensor selector;
     CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
     *has_dead_fanout = true;
     *dead_fanout = selector.scalar<bool>()() ? 0 : 1;
+    return Status::OK();
   }
 
   GraphView::InputPort switch_input_port(&switch_node, 0);
@@ -617,28 +624,29 @@
   // We check if its a while loop such that the condition is a simple binary
   // operator which returns false for the initialization value.
   // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
-  if (!IsMerge(*switch_input)) {
+  if (!IsMerge(*switch_input) || !IsLoopCond(*switch_predicate)) {
     return Status::OK();
   }
 
-  // Find the boolean Op from predicate node.
-  NodeDef* switch_ctrl_node = nullptr;
-  for (int i = 0; i < switch_predicate->input().size(); ++i) {
-    NodeDef* node = node_map.GetNode(switch_predicate->input(i));
-    if (IsSimpleBinaryOperator(*node)) {
-      switch_ctrl_node = node;
-    }
-  }
-  if (switch_ctrl_node == nullptr) {
+  VLOG(3) << "Try to find a zero iteration while loop:"
+          << " switch_node=" << switch_node.name();
+
+  // Find the boolean predicate from a LoopCond node (e.g. Greater).
+  NodeDef* switch_ctrl_node = view.GetRegularFanin({switch_predicate, 0}).node;
+  if (!switch_ctrl_node || !IsSimpleBinaryOperator(*switch_ctrl_node)) {
     return Status::OK();
   }
+
   // Find the Merge node & the Constant Operand to the condition node, if
   // available.
   NodeDef* merge_node = nullptr;
   NodeDef* constant_ctrl_input = nullptr;
   int constant_index = 0;
   for (int i = 0; i < switch_ctrl_node->input().size(); ++i) {
-    NodeDef* node = node_map.GetNode(switch_ctrl_node->input(i));
+    const string& input = switch_ctrl_node->input(i);
+    if (IsControlInput(input)) continue;
+
+    NodeDef* node = view.GetNode(switch_ctrl_node->input(i));
     if (IsMerge(*node)) {
       merge_node = node;
     }
@@ -650,6 +658,7 @@
   if (merge_node == nullptr || constant_ctrl_input == nullptr) {
     return Status::OK();
   }
+
   // Find the initialization constant (via Enter, if one exists).
   NodeDef* enter_node = nullptr;
   NodeDef* constant_init_node = nullptr;
@@ -675,6 +684,15 @@
     return Status::OK();
   }
 
+  VLOG(4) << "Check if loop will be 0 iterations:"
+          << "\n|  switch_node        : " << switch_node.name()
+          << "\n|  switch_ctrl_node   : " << switch_ctrl_node->name()
+          << "\n|  merge_node         : " << merge_node->name()
+          << "\n|  constant_ctrl_input: " << constant_ctrl_input->name()
+          << "\n|  enter_node         : "
+          << (enter_node ? enter_node->name() : "<n/a>")
+          << "\n|  constant_init_node : " << constant_init_node->name();
+
   // Check if there will be 0 iterations. This will only happen if the condition
   // evaluates to false with respect to the initialization value.
   NodeDef* operand_0 =
@@ -685,9 +703,14 @@
   TF_RETURN_IF_ERROR(EvaluateBoolOpForConstantOperands(
       *switch_ctrl_node, *operand_0, *operand_1, cpu_device, resource_mgr,
       &constant_switch_value));
+
   if (constant_switch_value == false) {
+    VLOG(4) << "Remove 0 iteration while loop:"
+            << " switch_node=" << switch_node.name();
     *has_dead_fanout = true;
     *dead_fanout = 1;
+  } else {
+    VLOG(4) << "Was not able to prove that loop has 0 iterations.";
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 4420f8e..90b4716 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -17,6 +17,7 @@
 
 #include <algorithm>
 #include <queue>
+#include <set>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -442,12 +443,6 @@
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
                                 const string& recomputation_targets_name_scope,
                                 GraphDef* graph, const GrapplerItem& item) {
-  if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
-      optimization_level != RewriterConfig::HEURISTICS &&
-      optimization_level != RewriterConfig::MANUAL) {
-    // Nothing to do
-    return;
-  }
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
   // looks up nodes which were in the original graph, and preserves the graph
@@ -1274,13 +1269,24 @@
           absl::StrContains(device2, DEVICE_CPU));
 }
 
+void RelaxAssignNodes(const std::set<int>& nodes_to_relax,
+                      GraphDef* optimized_graph) {
+  for (int idx : nodes_to_relax) {
+    // Set an attribute telling AssignOp to ignore allocator constraints.
+    NodeDef* assign_node = optimized_graph->mutable_node(idx);
+    (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
+        .set_b(true);
+  }
+}
+
 // TODO(rmlarsen): Add distributed TF test.
-Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
+Status FindAssignNodesToRelax(const GraphDef& graph,
+                              std::set<int>* nodes_to_relax) {
   std::unordered_set<string> devices;
   std::vector<int> assign_nodes;
   bool found_send = false;
-  for (int i = 0; i < optimized_graph->node_size(); ++i) {
-    const NodeDef& node = optimized_graph->node(i);
+  for (int i = 0; i < graph.node_size(); ++i) {
+    const NodeDef& node = graph.node(i);
     devices.insert(node.device());
     if (IsAssign(node)) {
       assign_nodes.push_back(i);
@@ -1291,22 +1297,17 @@
     }
   }
   if (!found_send && devices.size() == 1) {
-    for (int assign_idx : assign_nodes) {
-      // Set an attribute telling AssignOp to ignore allocator constraints.
-      NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
-      (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
-          .set_b(true);
-    }
+    nodes_to_relax->insert(assign_nodes.begin(), assign_nodes.end());
     return Status::OK();
   }
 
   GraphTopologyView graph_view;
-  TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
-      *optimized_graph, /*ignore_control_edges=*/true));
+  TF_RETURN_IF_ERROR(
+      graph_view.InitializeFromGraph(graph, /*ignore_control_edges=*/true));
   std::unordered_set<const NodeDef*> optimized_nodes;
 
   for (int i : assign_nodes) {
-    const NodeDef& assign_node = optimized_graph->node(i);
+    const NodeDef& assign_node = graph.node(i);
 
     if (optimized_nodes.find(&assign_node) == optimized_nodes.end()) {
       std::vector<const NodeDef*> assign_nodes_in_fanout;
@@ -1352,11 +1353,7 @@
           // Set an attribute telling AssignOp to ignore allocator constraints.
           const absl::optional<int> assign_node_idx =
               graph_view.GetNodeIndex(*assign_node_in_fanout);
-          NodeDef* assign_node_to_relax =
-              optimized_graph->mutable_node(assign_node_idx.value());
-          (*assign_node_to_relax
-                ->mutable_attr())["_grappler_relax_allocator_constraints"]
-              .set_b(true);
+          nodes_to_relax->insert(assign_node_idx.value());
         }
       }
     }
@@ -1368,39 +1365,55 @@
 
 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
-  GrapplerItem optimized_item(item);
+  std::set<int> nodes_to_relax;
+  TF_RETURN_IF_ERROR(FindAssignNodesToRelax(item.graph, &nodes_to_relax));
 
-  RecomputationRewritingPass(optimization_level_,
-                             recomputation_targets_name_scope_,
-                             &optimized_item.graph, item);
+  bool run_recomputation_pass =
+      (optimization_level_ == RewriterConfig::RECOMPUTATION_HEURISTICS ||
+       optimization_level_ == RewriterConfig::HEURISTICS ||
+       optimization_level_ == RewriterConfig::MANUAL);
+  if (!run_recomputation_pass && nodes_to_relax.empty() && item.fetch.empty()) {
+    return errors::Aborted("Nothing to do.");
+  }
+
+  GrapplerItem optimized_item(item);
+  RelaxAssignNodes(nodes_to_relax, &optimized_item.graph);
+
+  if (run_recomputation_pass) {
+    RecomputationRewritingPass(optimization_level_,
+                               recomputation_targets_name_scope_,
+                               &optimized_item.graph, item);
+  }
 
   std::unordered_set<string> skip_list;
   // Bound the number of rewrite passes to avoid long processing times on graphs
   // that simply won't fit in memory.
-  bool updated_graph = true;
-  for (int i = 0; i < 25 && updated_graph; ++i) {
-    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
-    updated_graph = false;
-    if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
-         optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
-         optimization_level_ == RewriterConfig::HEURISTICS) &&
-        cluster != nullptr) {
-      updated_graph |= SchedulingPass(cluster, &optimized_item);
-    }
+  // SchedulingPass() and SwappingPass() rely on defined fetches in order to
+  // infer the memory usage, so skip optimization if there are no fetches.
+  if (!item.fetch.empty()) {
+    bool updated_graph = true;
+    for (int i = 0; i < 25 && updated_graph; ++i) {
+      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
+      updated_graph = false;
+      if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
+           optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
+           optimization_level_ == RewriterConfig::HEURISTICS) &&
+          cluster != nullptr) {
+        updated_graph |= SchedulingPass(cluster, &optimized_item);
+      }
 
-    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
-    if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
-         optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
-         optimization_level_ == RewriterConfig::HEURISTICS ||
-         optimization_level_ == RewriterConfig::MANUAL) &&
-        cluster != nullptr) {
-      updated_graph |= SwappingPass(optimization_level_, cluster,
-                                    &optimized_item, &skip_list);
+      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
+      if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
+           optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
+           optimization_level_ == RewriterConfig::HEURISTICS ||
+           optimization_level_ == RewriterConfig::MANUAL) &&
+          cluster != nullptr) {
+        updated_graph |= SwappingPass(optimization_level_, cluster,
+                                      &optimized_item, &skip_list);
+      }
     }
   }
 
-  TF_RETURN_IF_ERROR(RelaxAllocatorConstraints(&optimized_item.graph));
-
   optimized_graph->Swap(&optimized_item.graph);
   return Status::OK();
 }
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index e7aea5f..9f2e0b3 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -240,6 +240,7 @@
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"e"};
 
   EXPECT_EQ(7, item.graph.node_size());
   EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 7f1302d..6e93c4e 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -192,13 +192,13 @@
     optimizers->push_back(
         MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
   }
-  if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
-    optimizers->push_back(MakeUnique<GenericLayoutOptimizer>());
-  }
   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
     optimizers->push_back(
         MakeUnique<AutoMixedPrecision>(cfg_.auto_mixed_precision()));
   }
+  if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
+    optimizers->push_back(MakeUnique<GenericLayoutOptimizer>());
+  }
   if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
     if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
       optimizers->push_back(
@@ -802,8 +802,6 @@
 
   std::unique_ptr<tensorflow::Graph> optimized_graph(
       new tensorflow::Graph(OpRegistry::Global()));
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
-                                            out_graph, optimized_graph.get()));
 
   // Copy optimized functions back to the overlay lib.
   if (flib) {
@@ -817,25 +815,28 @@
     }
   }
 
-  *g = std::move(optimized_graph);
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));
 
   // The graph conversion sets the requested device names but not the
   // assigned device names. However, since at this point the graph is
   // placed TF expects an assigned device name for every node. Therefore
   // we copy the requested device into the assigned device field.
-  for (Node* node : (*g)->nodes()) {
+  for (Node* node : optimized_graph->nodes()) {
     if (node->IsOp() && node->assigned_device_name().empty()) {
       if (node->requested_device().empty()) {
         return errors::Internal(
             "Either placer did not place the node or Grappler did not "
             "copy the assigned device. Contact Grappler team since latter "
             "is more likely. Node=",
-            node->name(), " Graph: ", (*g)->ToGraphDefDebug().DebugString());
+            node->name(),
+            " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
       }
       node->set_assigned_device_name(node->requested_device());
     }
   }
 
+  *g = std::move(optimized_graph);
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 3cfebda..766e8a1 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -456,7 +456,7 @@
 
   // Squeeze must not squeeze output channel dimension.
   std::vector<int32> dims;
-  if (!GetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims).ok()) return false;
+  if (!TryGetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims)) return false;
   for (auto dim : dims) {
     if (dim == 3) return false;
   }
@@ -531,7 +531,7 @@
   // We successfully found a Conv2D+FusedBatchNorm pattern.
   matched->contraction = conv2d_node_view->node_index();
   matched->fused_batch_norm = node_index;
-  if (!GetNodeAttr(*node_def, "epsilon", &matched->epsilon).ok()) return false;
+  if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;
 
   return true;
 }
@@ -684,7 +684,7 @@
 
   // Check that the node is in inference mode.
   bool is_training = true;
-  if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok()) return false;
+  if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
   if (is_training) return false;
 
   const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
@@ -1477,7 +1477,7 @@
     if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
 
     bool is_training = true;
-    if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok()) return false;
+    if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
     if (is_training) return false;
 
     return true;
@@ -1634,7 +1634,6 @@
     // Infer properties lazily in case they are not needed.
     if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
       const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-      // TODO(rmlarsen): Get rid of tensor value copies.
       TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
           assume_valid_feeds,
           /*aggressive_shape_inference=*/false,
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index dfdbc8c..3a1cfb6 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -95,7 +95,7 @@
       }
       const auto& prop =
           properties.GetOutputProperties(reduce_indices.node->name());
-      if (prop.size() < reduce_indices.port_id) {
+      if (prop.size() <= reduce_indices.port_id) {
         continue;
       }
       const TensorShapeProto& reduction_indices_shape =
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index d193907..fef002b 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_protos_grappler",
 )
 
diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc
index 0dccee5..5b3d8e7 100644
--- a/tensorflow/core/grappler/utils/graph_view.cc
+++ b/tensorflow/core/grappler/utils/graph_view.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/core/grappler/utils/graph_view_internal.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -815,7 +816,9 @@
                           attr_to_add.second);
     }
     const string& device = diff.update_device ? diff.device : node->device();
-    if (device.empty()) {
+    DeviceNameUtils::ParsedName name;
+    if (device.empty() || !DeviceNameUtils::ParseFullName(device, &name) ||
+        !name.has_type) {
       continue;
     }
     s = IsKernelRegisteredForNode(diff.update_name ? diff.name : node->name(),
@@ -824,19 +827,20 @@
                                   diff.update_op ? diff.op : node->op(), device,
                                   AttrSlice(&diff.processed_attrs));
     if (!s.ok()) {
-      return errors::InvalidArgument(kMutableGraphViewApplyError,
-                                     s.error_message());
+      LOG(WARNING) << s.error_message();
     }
   }
   for (const auto& new_node_holder : mutation_.new_nodes_) {
     const auto& new_node_def = new_node_holder.node;
-    if (new_node_def.device().empty()) {
+    DeviceNameUtils::ParsedName name;
+    if (new_node_def.device().empty() ||
+        !DeviceNameUtils::ParseFullName(new_node_def.device(), &name) ||
+        !name.has_type) {
       continue;
     }
     s = IsKernelRegisteredForNode(new_node_def);
     if (!s.ok()) {
-      return errors::InvalidArgument(kMutableGraphViewApplyError,
-                                     s.error_message());
+      LOG(WARNING) << s.error_message();
     }
   }
   return Status::OK();
diff --git a/tensorflow/core/grappler/utils/graph_view_test.cc b/tensorflow/core/grappler/utils/graph_view_test.cc
index 3170de0..a8f4b65 100644
--- a/tensorflow/core/grappler/utils/graph_view_test.cc
+++ b/tensorflow/core/grappler/utils/graph_view_test.cc
@@ -1894,6 +1894,7 @@
 constexpr char kMatchingFiles[] = "MatchingFiles";
 
 TEST_F(MutationTest, OpWithUnsupportedDevice) {
+  GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
   auto test_graph = []() {
     return GDef({NDef("a", kMatchingFiles, {}, {}, kDeviceCPU0)},
                 /*funcs=*/{});
@@ -1930,6 +1931,7 @@
 }
 
 TEST_F(MutationTest, OpMissingAttribute) {
+  GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
   auto test_graph = []() {
     return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)},
                 /*funcs=*/{});
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 572afde..d326e3e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -20,16 +20,15 @@
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_kernel_tests_linkstatic",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
-    "if_static",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 load(
@@ -198,6 +197,12 @@
 tf_kernel_library(
     name = "collective_ops",
     srcs = if_nccl([
+        "collective_nccl.h",
+        "collective_nccl.cc",
+        "collective_nccl_broadcaster.h",
+        "collective_nccl_broadcaster.cc",
+        "collective_nccl_gatherer.h",
+        "collective_nccl_gatherer.cc",
         "collective_nccl_reducer.h",
         "collective_nccl_reducer.cc",
     ]),
@@ -214,9 +219,9 @@
 )
 
 tf_cuda_cc_test(
-    name = "collective_nccl_reducer_test",
+    name = "collective_nccl_test",
     size = "small",
-    srcs = ["collective_nccl_reducer_test.cc"],
+    srcs = ["collective_nccl_test.cc"],
     tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"],
     deps = [
         "//tensorflow/core:all_kernels",
@@ -497,15 +502,18 @@
     hdrs = ["gpu_utils.h"],
     deps = [
         ":gpu_util_hdrs",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/types:span",
         "//tensorflow/core:autotuning_proto_cc",
         "//tensorflow/core:conv_autotuning_proto_cc",
         "//tensorflow/core:lib",
         "//tensorflow/core:logger",
         "//tensorflow/core:stream_executor",
         "//tensorflow/core/util/proto:proto_utils",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/types:span",
-    ],
+    ] + if_cuda([
+        "//tensorflow/stream_executor/cuda:redzone_allocator",
+        "//tensorflow/stream_executor/cuda:ptxas_utils",
+    ]),
 )
 
 tf_cc_test(
@@ -813,7 +821,6 @@
         "eigen_backward_spatial_convolutions.h",
         "eigen_cuboid_convolution.h",
         "eigen_pooling.h",
-        "eigen_softmax.h",
         "eigen_spatial_convolutions.h",
         "eigen_volume_patch.h",
     ],
@@ -834,7 +841,6 @@
         "eigen_backward_spatial_convolutions.h",
         "eigen_cuboid_convolution.h",
         "eigen_pooling.h",
-        "eigen_softmax.h",
         "eigen_spatial_convolutions.h",
         "eigen_volume_patch.h",
     ],
@@ -1332,6 +1338,7 @@
         ":ragged_range_op",
         ":ragged_tensor_from_variant_op",
         ":ragged_tensor_to_sparse_kernel",
+        ":ragged_tensor_to_tensor_op",
         ":ragged_tensor_to_variant_op",
     ],
 )
@@ -1388,6 +1395,35 @@
 )
 
 tf_cc_test(
+    name = "ragged_tensor_to_tensor_op_test",
+    size = "small",
+    srcs = ["ragged_tensor_to_tensor_op_test.cc"],
+    deps = [
+        ":ops_testutil",
+        ":ragged_tensor_to_tensor_op",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+tf_kernel_library(
+    name = "ragged_tensor_to_tensor_op",
+    srcs = ["ragged_tensor_to_tensor_op.cc"],
+    deps = [
+        ":broadcast_to_op",
+        ":list_kernels",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_lite",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:ragged_to_dense_util",
+    ],
+)
+
+tf_cc_test(
     name = "ragged_tensor_to_sparse_kernel_test",
     size = "small",
     srcs = ["ragged_tensor_to_sparse_kernel_test.cc"],
@@ -1884,7 +1920,10 @@
 # Unlike gather_functor library, this does not include the CUDA code and deps.
 cc_library(
     name = "gather_functor_hdr",
-    hdrs = ["gather_functor.h"],
+    hdrs = [
+        "gather_functor.h",
+        "gather_functor_batched.h",
+    ],
 )
 
 tf_kernel_library(
@@ -2927,7 +2966,6 @@
         "eigen_attention_test.cc",
         "eigen_backward_spatial_convolutions_test.cc",
         "eigen_pooling_test.cc",
-        "eigen_softmax_test.cc",
         "eigen_spatial_convolutions_test.cc",
     ],
     deps = [
@@ -3294,16 +3332,9 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
-    ] + if_static(
-        [
-            "@local_config_cuda//cuda:cusolver",
-            "@local_config_cuda//cuda:cublas",
-        ],
-        [
-            "//tensorflow/stream_executor/cuda:cusolver_stub",
-            "//tensorflow/stream_executor/cuda:cublas_stub",
-        ],
-    ),
+        "//tensorflow/stream_executor/cuda:cublas_lib",
+        "//tensorflow/stream_executor/cuda:cusolver_lib",
+    ],
 )
 
 tf_kernel_library(
@@ -3313,10 +3344,8 @@
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-    ] + if_static(
-        ["@local_config_cuda//cuda:cusparse"],
-        ["//tensorflow/stream_executor/cuda:cusparse_stub"],
-    ),
+        "//tensorflow/stream_executor/cuda:cusparse_lib",
+    ],
 )
 
 LINALG_DEPS = [
@@ -4364,7 +4393,7 @@
 tf_kernel_library(
     name = "lrn_op",
     prefix = "lrn_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + if_rocm([":conv_ops_gpu_hdrs"]),
 )
 
 tf_kernel_library(
@@ -5317,6 +5346,7 @@
         ":string_join_op",
         ":string_length_op",
         ":string_lower_op",
+        ":string_ngrams_op",
         ":string_split_op",
         ":string_strip_op",
         ":string_to_hash_bucket_op",
@@ -5458,6 +5488,30 @@
 )
 
 tf_kernel_library(
+    name = "string_ngrams_op",
+    srcs = ["string_ngrams_op.cc"],
+    deps = STRING_DEPS + [
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_cc_test(
+    name = "string_ngrams_op_test",
+    srcs = ["string_ngrams_op_test.cc"],
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        ":string_ngrams_op",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+tf_kernel_library(
     name = "string_strip_op",
     prefix = "string_strip_op",
     deps = STRING_DEPS,
@@ -5945,7 +5999,6 @@
         "eigen_convolution_helpers.h",
         "eigen_cuboid_convolution.h",
         "eigen_pooling.h",
-        "eigen_softmax.h",
         "eigen_spatial_convolutions.h",
         "eigen_spatial_convolutions-inl.h",
         "eigen_volume_patch.h",
@@ -6014,6 +6067,7 @@
         "function_ops.cc",
         "function_ops.h",
         "gather_functor.h",
+        "gather_functor_batched.h",
         "gather_nd_op.cc",
         "gather_nd_op.h",
         "gather_nd_op_cpu_impl.h",
diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc
index e6d6c40..8341909 100644
--- a/tensorflow/core/kernels/as_string_op.cc
+++ b/tensorflow/core/kernels/as_string_op.cc
@@ -116,7 +116,7 @@
     OP_REQUIRES_OK(context,
                    context->allocate_output("output", input_tensor->shape(),
                                             &output_tensor));
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
 
 #define ENCODE_TYPE(type, T, enc_str)                                     \
   case (type): {                                                          \
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
index 1cc5a2d..ead0efb 100644
--- a/tensorflow/core/kernels/avgpooling_op.cc
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -36,6 +36,10 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
+#if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
+#endif  // GOOGLE_CUDA
+
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/maxpooling_op_gpu.h"
 #include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
@@ -155,6 +159,12 @@
 
     TensorShape output_shape = params.forward_output_shape();
 
+#if CUDNN_VERSION >= 7300
+    DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
+                             stride_, padding_, data_format_, tensor_in,
+                             output_shape,
+                             /*propagate_nans=*/false);
+#else
     if (data_format_ == FORMAT_NCHW) {
       DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
                                stride_, padding_, data_format_, tensor_in,
@@ -170,6 +180,7 @@
           tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
           params.row_stride, params.col_stride, pt);
     }
+#endif  // CUDNN_VERSION >= 7300
   }
 
  private:
@@ -496,6 +507,12 @@
       output_shape.AddDim(shape_vec(i));
     }
 
+#if CUDNN_VERSION >= 7300
+    DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
+                                 ksize_, stride_, padding_, data_format_,
+                                 nullptr, nullptr, out_backprop, output_shape,
+                                 /*propagate_nans=*/false);
+#else
     if (data_format_ == FORMAT_NHWC) {
       const int64 out_backprop_batch = out_backprop.dim_size(0);
       const int64 out_backprop_rows = out_backprop.dim_size(1);
@@ -552,6 +569,7 @@
                                    nullptr, nullptr, out_backprop, output_shape,
                                    /*propagate_nans=*/false);
     }
+#endif  // CUDNN_VERSION >= 7300
   }
 
  private:
diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc
index 89d742c..adbe370 100644
--- a/tensorflow/core/kernels/barrier_ops.cc
+++ b/tensorflow/core/kernels/barrier_ops.cc
@@ -308,7 +308,7 @@
                          int component_index, int i,
                          std::vector<Tuple>* ready_tuples, bool* new_elements)
       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    auto keys_vec = keys.flat<string>();
+    auto keys_vec = keys.flat<tstring>();
     auto values_matrix = values.flat_outer_dims<T>();
 
     PersistentTuple* element_ptr;
@@ -392,7 +392,7 @@
                                                   &key, &allocated_key));
       ready_tuple.push_back(*element[0].AccessTensor(ctx));  // index
       ready_tuple.push_back(*allocated_key);                 // key
-      ready_tuple[1].scalar<string>()() = keys_vec(i);       // set the key
+      ready_tuple[1].scalar<tstring>()() = keys_vec(i);      // set the key
       for (int j = 1; j < num_components() + 1; ++j) {
         ready_tuple.push_back(*element[j].AccessTensor(ctx));
       }
diff --git a/tensorflow/core/kernels/base64_ops.cc b/tensorflow/core/kernels/base64_ops.cc
index 74e6b39..cb235f5 100644
--- a/tensorflow/core/kernels/base64_ops.cc
+++ b/tensorflow/core/kernels/base64_ops.cc
@@ -36,8 +36,8 @@
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
 
-    auto input = input_tensor.flat<string>();
-    auto output = output_tensor->flat<string>();
+    auto input = input_tensor.flat<tstring>();
+    auto output = output_tensor->flat<tstring>();
 
     for (int64 i = 0; i < input.dimension(0); ++i) {
       OP_REQUIRES_OK(context, Base64Encode(input(i), pad_, &output(i)));
@@ -61,8 +61,8 @@
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
 
-    auto input = input_tensor.flat<string>();
-    auto output = output_tensor->flat<string>();
+    auto input = input_tensor.flat<tstring>();
+    auto output = output_tensor->flat<tstring>();
 
     for (int64 i = 0; i < input.dimension(0); ++i) {
       OP_REQUIRES_OK(context, Base64Decode(input(i), &output(i)));
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 84f7571..1e85dbc 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -265,10 +265,10 @@
 
   BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
 
-  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
+  int64 GetMemoryLimitInBytes() override { return -1; }
 
   se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
-      Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
 
     Status allocation_status(context_->allocate_temp(
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index f6414c8..3c2bc92 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -7,7 +7,7 @@
     "tf_kernel_library",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library",
 )
 
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index 718cf8e..7cd62af 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -324,7 +324,7 @@
         context, context->allocate_output("examples_debug_outputs_serialized",
                                           {batch_size}, &output_debug_info_t));
     // Will contain serialized protos, per example.
-    auto output_debug_info = output_debug_info_t->flat<string>();
+    auto output_debug_info = output_debug_info_t->flat<tstring>();
     const int32 last_tree = resource->num_trees() - 1;
 
     // For each given example, traverse through all trees keeping track of the
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
index 36f52ab..b4d300b 100644
--- a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -264,6 +264,7 @@
         *context->device()->tensorflow_cpu_worker_threads();
     Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
           kCostPerUnit, do_quantile_summary_gen);
+    stream_resource->ResetStreams();
   }
 
  private:
@@ -424,6 +425,7 @@
     Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
           kCostPerUnit, do_quantile_flush);
 
+    stream_resource->ResetStreams();
     stream_resource->set_buckets_ready(true);
   }
 
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
index 965bf2c..10afc9e 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -67,6 +67,14 @@
     are_buckets_ready_ = are_buckets_ready;
   }
 
+  void ResetStreams() {
+    streams_.clear();
+    streams_.reserve(num_streams_);
+    for (int64 idx = 0; idx < num_streams_; ++idx) {
+      streams_.push_back(QuantileStream(epsilon_, max_elements_));
+    }
+  }
+
  private:
   ~BoostedTreesQuantileStreamResource() override {}
 
diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
index 5a9c354..ac1fb56 100644
--- a/tensorflow/core/kernels/boosted_trees/resource_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
@@ -51,7 +51,7 @@
     std::unique_ptr<BoostedTreesEnsembleResource> result(
         new BoostedTreesEnsembleResource());
     if (!result->InitFromSerialized(
-            tree_ensemble_serialized_t->scalar<string>()(), stamp_token)) {
+            tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token)) {
       result->Unref();
       OP_REQUIRES(
           context, false,
@@ -152,7 +152,7 @@
     Tensor* output_proto_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(1, TensorShape(), &output_proto_t));
-    output_proto_t->scalar<string>()() =
+    output_proto_t->scalar<tstring>()() =
         tree_ensemble_resource->SerializeAsString();
   }
 };
@@ -187,7 +187,7 @@
     OP_REQUIRES(
         context,
         tree_ensemble_resource->InitFromSerialized(
-            tree_ensemble_serialized_t->scalar<string>()(), stamp_token),
+            tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token),
         errors::InvalidArgument("Unable to parse tree ensemble proto."));
   }
 };
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index fac5967..c421bff 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -25,8 +25,18 @@
 
 namespace tensorflow {
 
-const char INEQUALITY_DEFAULT_LEFT[] = "inequality_default_left";
-const char INEQUALITY_DEFAULT_RIGHT[] = "inequality_default_right";
+// TODO(tanzheny): Make these const as proto enum.
+const char kInequalityDefaultLeft[] = "inequality_default_left";
+const char kInequalityDefaultRight[] = "inequality_default_right";
+const char kEqualityDefaultLeft[] = "equality_default_left";
+
+using Matrix =
+    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+using ConstMatrixMap = Eigen::Map<const Matrix>;
+using MatrixMap = Eigen::Map<Matrix>;
+
+using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
+using VectorMap = Eigen::Map<Eigen::VectorXf>;
 
 // V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2.
 class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
@@ -228,6 +238,7 @@
       OpKernelConstruction* const context)
       : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
+    OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
   }
 
   void Compute(OpKernelContext* const context) override {
@@ -244,9 +255,10 @@
         stats_summary_t->tensor<float, 4>();
     const int64 feature_dims = stats_summary_t->dim_size(1);
     const int64 num_buckets = stats_summary_t->dim_size(2);
-    const int64 hessian_dim = stats_summary_t->dim_size(3) - logits_dim_;
+    const int64 logits_dim = logits_dim_;
+    const int64 hessian_dim = stats_summary_t->dim_size(3) - logits_dim;
     DCHECK_GT(hessian_dim, 0);
-    DCHECK_LE(hessian_dim, logits_dim_ * logits_dim_);
+    DCHECK_LE(hessian_dim, logits_dim * logits_dim);
 
     const Tensor* l1_t;
     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
@@ -280,76 +292,44 @@
     std::vector<Eigen::VectorXf> output_right_node_contribs;
     std::vector<string> output_split_types;
 
+    // Iterate each node and find the best gain per node.
     for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
-      std::vector<Eigen::VectorXf> cum_grad;
-      std::vector<Eigen::VectorXf> cum_hess;
-      cum_grad.reserve(num_buckets);
-      cum_hess.reserve(num_buckets);
-
       float best_gain = std::numeric_limits<float>::lowest();
-      float best_bucket = 0;
-      float best_f_dim = 0;
-      string best_split_type = INEQUALITY_DEFAULT_LEFT;
-      Eigen::VectorXf best_contrib_for_left(logits_dim_);
-      Eigen::VectorXf best_contrib_for_right(logits_dim_);
+      int32 best_bucket = 0;
+      int32 best_f_dim = 0;
+      string best_split_type;
+      Eigen::VectorXf best_contrib_for_left(logits_dim);
+      Eigen::VectorXf best_contrib_for_right(logits_dim);
       float parent_gain;
-      Eigen::VectorXf unused(logits_dim_);
-      for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
-        cum_grad.clear();
-        cum_hess.clear();
-        Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim_);
-        Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
-        for (int bucket = 0; bucket < num_buckets; ++bucket) {
-          for (int i = 0; i < logits_dim_; ++i) {
-            total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
-            total_hess[i] +=
-                stats_summary(node_id, f_dim, bucket, logits_dim_ + i);
-          }
-          for (int i = logits_dim_; i < hessian_dim; ++i) {
-            // Full hessian.
-            total_hess[i] +=
-                stats_summary(node_id, f_dim, bucket, logits_dim_ + i);
-          }
-          cum_grad.push_back(total_grad);
-          cum_hess.push_back(total_hess);
-        }
 
-        // Only need to check once as total_grad/total_hess will be the same for
-        // all features.
-        if (f_dim == 0) {
-          if (total_hess.norm() < min_node_weight) {
-            break;
-          }
-          CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
-                                   &parent_gain);
-        }
+      ConstMatrixMap stats_mat(&stats_summary(node_id, 0, 0, 0), num_buckets,
+                               logits_dim + hessian_dim);
+      const Eigen::VectorXf total_grad =
+          stats_mat.leftCols(logits_dim).colwise().sum();
+      const Eigen::VectorXf total_hess =
+          stats_mat.rightCols(hessian_dim).colwise().sum();
+      if (total_hess.norm() < min_node_weight) {
+        continue;
+      }
+      Eigen::VectorXf parent_weight(logits_dim);
+      CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &parent_weight,
+                               &parent_gain);
 
-        for (int bucket = 0; bucket < num_buckets; ++bucket) {
-          const Eigen::VectorXf cum_grad_bucket = cum_grad[bucket];
-          const Eigen::VectorXf cum_hess_bucket = cum_hess[bucket];
-          // Left child.
-          Eigen::VectorXf contrib_for_left(logits_dim_);
-          float gain_for_left;
-          CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket, l1, l2,
-                                   &contrib_for_left, &gain_for_left);
-          // Right child.
-          // TODO(crawles): consider accumulating right grad/hessians when doing
-          // cum_grad/hessian (if this becomes a bottleneck).
-          const Eigen::VectorXf grad_for_right = total_grad - cum_grad_bucket;
-          const Eigen::VectorXf hess_for_right = total_hess - cum_hess_bucket;
-          Eigen::VectorXf contrib_for_right(logits_dim_);
-          float gain_for_right;
-          CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
-                                   &contrib_for_right, &gain_for_right);
-          if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
-            best_gain = gain_for_left + gain_for_right;
-            best_bucket = bucket;
-            best_f_dim = f_dim;
-            best_contrib_for_left = contrib_for_left;
-            best_contrib_for_right = contrib_for_right;
-          }
-        }  // for bucket
-      }    // for f_dim
+      if (split_type_ == "inequality") {
+        best_split_type = kInequalityDefaultLeft;
+        CalculateBestInequalitySplit(
+            stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
+            num_buckets, min_node_weight, l1, l2, &best_gain, &best_bucket,
+            &best_f_dim, &best_contrib_for_left, &best_contrib_for_right);
+      } else {
+        best_split_type = kEqualityDefaultLeft;
+        CalculateBestEqualitySplit(
+            stats_summary, total_grad, total_hess, node_id, feature_dims,
+            logits_dim, hessian_dim, num_buckets, l1, l2, &best_gain,
+            &best_bucket, &best_f_dim, &best_contrib_for_left,
+            &best_contrib_for_right);
+      }
+
       if (best_gain == std::numeric_limits<float>::lowest()) {
         // Do not add the node if not split if found.
         continue;
@@ -395,7 +375,7 @@
     // output_left_node_contribs
     Tensor* output_left_node_contribs_t;
     OP_REQUIRES_OK(context, context->allocate_output(
-                                "left_node_contribs", {num_nodes, logits_dim_},
+                                "left_node_contribs", {num_nodes, logits_dim},
                                 &output_left_node_contribs_t));
     auto output_left_node_contribs_matrix =
         output_left_node_contribs_t->matrix<float>();
@@ -403,7 +383,7 @@
     // output_right_node_contribs
     Tensor* output_right_node_contribs_t;
     OP_REQUIRES_OK(context, context->allocate_output(
-                                "right_node_contribs", {num_nodes, logits_dim_},
+                                "right_node_contribs", {num_nodes, logits_dim},
                                 &output_right_node_contribs_t));
     auto output_right_node_contribs_matrix =
         output_right_node_contribs_t->matrix<float>();
@@ -413,7 +393,7 @@
     OP_REQUIRES_OK(
         context, context->allocate_output("split_with_default_directions",
                                           {num_nodes}, &output_split_types_t));
-    auto output_split_types_vec = output_split_types_t->vec<string>();
+    auto output_split_types_vec = output_split_types_t->vec<tstring>();
 
     // Sets output tensors from vectors.
     for (int i = 0; i < num_nodes; ++i) {
@@ -422,7 +402,7 @@
       output_gains_vec(i) = output_gains[i] - tree_complexity;
       output_feature_dimensions_vec(i) = output_feature_dimensions[i];
       output_thresholds_vec(i) = output_thresholds[i];
-      for (int j = 0; j < logits_dim_; ++j) {
+      for (int j = 0; j < logits_dim; ++j) {
         output_left_node_contribs_matrix(i, j) =
             output_left_node_contribs[i][j];
         output_right_node_contribs_matrix(i, j) =
@@ -433,7 +413,110 @@
   }
 
  private:
+  // TODO(crawles): Simplify inequality path just like equality b/138329196
+  // Currently this is not simplify-able due to numerical instability in math
+  // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g)
+  // It caused gain to be Inf when g is approaching 0 but not exactly 0 while
+  // there is no regularization.
+  // Calculate the best inequality split per node.
+  void CalculateBestInequalitySplit(TTypes<float, 4>::ConstTensor stats_summary,
+                                    const int node_id, const int feature_dims,
+                                    const int logits_dim, const int hessian_dim,
+                                    const int num_buckets,
+                                    const float min_node_weight, const float l1,
+                                    const float l2, float* best_gain,
+                                    int* best_bucket, int* best_f_dim,
+                                    Eigen::VectorXf* best_contrib_for_left,
+                                    Eigen::VectorXf* best_contrib_for_right) {
+    std::vector<Eigen::VectorXf> cum_grad;
+    std::vector<Eigen::VectorXf> cum_hess;
+    cum_grad.reserve(num_buckets);
+    cum_hess.reserve(num_buckets);
+
+    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
+      cum_grad.clear();
+      cum_hess.clear();
+      Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
+      Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
+      for (int bucket = 0; bucket < num_buckets; ++bucket) {
+        for (int i = 0; i < logits_dim; ++i) {
+          total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
+        }
+        for (int i = 0; i < hessian_dim; ++i) {
+          // Full hessian.
+          total_hess[i] +=
+              stats_summary(node_id, f_dim, bucket, logits_dim + i);
+        }
+        cum_grad.push_back(total_grad);
+        cum_hess.push_back(total_hess);
+      }
+
+      for (int bucket = 0; bucket < num_buckets; ++bucket) {
+        MaybeUpdateBestSplit(cum_grad[bucket], total_grad, cum_hess[bucket],
+                             total_hess, logits_dim, bucket, f_dim, l1, l2,
+                             best_gain, best_bucket, best_f_dim,
+                             best_contrib_for_left, best_contrib_for_right);
+      }  // for bucket
+    }
+  }
+
+  // Calculate the best equality split per node.
+  void CalculateBestEqualitySplit(TTypes<float, 4>::ConstTensor stats_summary,
+                                  const Eigen::VectorXf& total_grad,
+                                  const Eigen::VectorXf& total_hess,
+                                  const int node_id, const int feature_dims,
+                                  const int logits_dim, const int hessian_dim,
+                                  const int num_buckets, const float l1,
+                                  const float l2, float* best_gain,
+                                  int* best_bucket, int* best_f_dim,
+                                  Eigen::VectorXf* best_contrib_for_left,
+                                  Eigen::VectorXf* best_contrib_for_right) {
+    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
+      for (int bucket = 0; bucket < num_buckets; ++bucket) {
+        ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
+                                 logits_dim + hessian_dim);
+        Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
+        Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
+        MaybeUpdateBestSplit(curr_grad, total_grad, curr_hess, total_hess,
+                             logits_dim, bucket, f_dim, l1, l2, best_gain,
+                             best_bucket, best_f_dim, best_contrib_for_left,
+                             best_contrib_for_right);
+      }
+    }
+  }
+
+  void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
+                            const Eigen::VectorXf& total_grad,
+                            const Eigen::VectorXf& hess_for_left,
+                            const Eigen::VectorXf& total_hess,
+                            const int logits_dim, const int bucket,
+                            const int f_dim, const float l1, const float l2,
+                            float* best_gain, int* best_bucket, int* best_f_dim,
+                            Eigen::VectorXf* best_contrib_for_left,
+                            Eigen::VectorXf* best_contrib_for_right) {
+    // Left child.
+    Eigen::VectorXf contrib_for_left(logits_dim);
+    float gain_for_left;
+    CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
+                             &contrib_for_left, &gain_for_left);
+    // Right child.
+    const auto grad_for_right = total_grad - grad_for_left;
+    const auto hess_for_right = total_hess - hess_for_left;
+    Eigen::VectorXf contrib_for_right(logits_dim);
+    float gain_for_right;
+    CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
+                             &contrib_for_right, &gain_for_right);
+    if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
+      *best_gain = gain_for_left + gain_for_right;
+      *best_bucket = bucket;
+      *best_f_dim = f_dim;
+      *best_contrib_for_left = contrib_for_left;
+      *best_contrib_for_right = contrib_for_right;
+    }
+  }
+
   int logits_dim_;
+  string split_type_;
 };
 
 // v2 op that supports multi-class.
@@ -594,7 +677,7 @@
     OP_REQUIRES_OK(
         context, context->allocate_output("split_with_default_directions",
                                           {num_nodes}, &output_split_types_t));
-    auto output_split_types_vec = output_split_types_t->vec<string>();
+    auto output_split_types_vec = output_split_types_t->vec<tstring>();
 
     // Sets output tensors from vectors.
     for (int i = 0; i < num_nodes; ++i) {
@@ -630,7 +713,7 @@
     float best_gain = std::numeric_limits<float>::lowest();
     float best_bucket = 0;
     float best_f_dim = 0;
-    string best_split_type = INEQUALITY_DEFAULT_LEFT;
+    string best_split_type = kInequalityDefaultLeft;
     float best_contrib_for_left = 0.0;
     float best_contrib_for_right = 0.0;
     // the sum of gradients including default bucket.
@@ -697,7 +780,7 @@
           best_gain = gain_for_left + gain_for_right;
           best_bucket = bucket_id;
           best_f_dim = feature_dim;
-          best_split_type = INEQUALITY_DEFAULT_RIGHT;
+          best_split_type = kInequalityDefaultRight;
           best_contrib_for_left = contrib_for_left[0];
           best_contrib_for_right = contrib_for_right[0];
         }
@@ -714,7 +797,7 @@
           best_gain = gain_for_left + gain_for_right;
           best_bucket = bucket_id;
           best_f_dim = feature_dim;
-          best_split_type = INEQUALITY_DEFAULT_LEFT;
+          best_split_type = kInequalityDefaultLeft;
           best_contrib_for_left = contrib_for_left[0];
           best_contrib_for_right = contrib_for_right[0];
         }
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 744436c..8dfdd8d 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -132,7 +132,7 @@
     // Copy the lower triangular part of the input matrices to the output and
     // set the strictly upper triangular part to zero. We use a pre-existing
     // kernel MatrixBandPart to do this for all matrices in the batch at once,
-    // before we launch each of the Cholesky factorization kernels in paralle.
+    // before we launch each of the Cholesky factorization kernels.
     auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
     auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
     functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
@@ -143,16 +143,47 @@
     // Launch a Cholesky kernel for each matrix in the batch.
     const int64 batch_size = input_reshaped.dimension(0);
     std::vector<DeviceLapackInfo> dev_info;
-    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
-    // TODO(rmlarsen): Use PotrfBatched for factoring many small matrices in
-    // parallel.
-    for (int batch = 0; batch < batch_size; ++batch) {
+
+#if CUDA_VERSION >= 9020
+    // Decide whether to use the batched API.
+    // TODO(rmlarsen): The value 128 was found to be optimal for the equivalent
+    // split in matrix_solve_op. Tune this heuristic.
+    constexpr int kMaxMatrixSizeToBatchSizeRatio = 128;
+    const bool use_batched_solver =
+        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+    if (use_batched_solver) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuSolver.
+      auto output_reshaped_ptrs = solver->GetScratchSpace<uint8>(
+          sizeof(Scalar*) * batch_size, "input_copt_ptrs",
+          /* on_host */ true);
+      const Scalar** output_reshaped_ptrs_base =
+          reinterpret_cast<const Scalar**>(output_reshaped_ptrs.mutable_data());
+      for (int batch = 0; batch < batch_size; ++batch) {
+        output_reshaped_ptrs_base[batch] = &output_reshaped(batch, 0, 0);
+      }
+      dev_info.push_back(
+          solver->GetDeviceLapackInfo(batch_size, "potrfBatched"));
       OP_REQUIRES_OK_ASYNC(context,
-                           solver->Potrf(CUBLAS_FILL_MODE_UPPER, n,
-                                         &output_reshaped(batch, 0, 0), n,
-                                         &dev_info.back()(batch)),
+                           solver->PotrfBatched(CUBLAS_FILL_MODE_UPPER, n,
+                                                output_reshaped_ptrs_base, n,
+                                                &dev_info.back(), batch_size),
                            done);
+    } else {
+#endif
+
+      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
+      for (int batch = 0; batch < batch_size; ++batch) {
+        OP_REQUIRES_OK_ASYNC(context,
+                             solver->Potrf(CUBLAS_FILL_MODE_UPPER, n,
+                                           &output_reshaped(batch, 0, 0), n,
+                                           &dev_info.back()(batch)),
+                             done);
+      }
+
+#if CUDA_VERSION >= 9020
     }
+#endif
 
     // Register callback to check info after kernels finish.
     auto info_checker = [context, done](
diff --git a/tensorflow/core/kernels/collective_nccl.cc b/tensorflow/core/kernels/collective_nccl.cc
new file mode 100644
index 0000000..db07959
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl.cc
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/collective_nccl.h"
+
+#ifdef GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+NcclBase::NcclBase(CollectiveType type, const string& name)
+    : type_(type), name_(name), col_ctx_(nullptr), col_params_(nullptr) {}
+
+Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
+  if (type_ != col_params->instance.type) {
+    return errors::Internal("Expected initialized type ", type_,
+                            " to match type in CollectiveParams ",
+                            col_params->instance.type);
+  }
+
+  const char* expected_name;
+  switch (type_) {
+    case REDUCTION_COLLECTIVE:
+      expected_name = "NcclReduce";
+      break;
+    case BROADCAST_COLLECTIVE:
+      expected_name = "NcclBroadcast";
+      break;
+    case GATHER_COLLECTIVE:
+      expected_name = "NcclGather";
+      break;
+    default:
+      return errors::Internal("Unexpected CollectiveType ", type_);
+  }
+
+  if (expected_name != col_params->instance.impl_details.collective_name) {
+    return errors::Internal("Unexpected combination of collective type ",
+                            col_params->instance.type, " and collective name ",
+                            col_params->instance.impl_details.collective_name,
+                            ", expected name ", expected_name);
+  }
+
+  return Status::OK();
+}
+
+Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+  col_ctx_ = col_ctx;
+  col_params_ = &col_ctx->col_params;
+  return collective_util::InitializeDeviceAndLocality(
+      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+      &col_ctx->device_locality);
+}
+
+Status NcclBase::InitializeCollectiveGroupRuntimeDetails(
+    CollGroupRuntimeDetails* col_group_runtime_details) {
+  col_group_runtime_details->communicator_key =
+      NcclManager::instance()->GenerateCommunicatorKey();
+  return Status::OK();
+}
+
+const string NcclBase::NcclCollectiveKey(const string& exec_key, int step_id) {
+  return strings::StrCat(exec_key, ":", step_id);
+}
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/collective_nccl.h b/tensorflow/core/kernels/collective_nccl.h
new file mode 100644
index 0000000..024d569
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl.h
@@ -0,0 +1,50 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_
+#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_
+
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+#ifdef GOOGLE_CUDA
+
+class NcclBase : public CollectiveImplementationInterface {
+ public:
+  explicit NcclBase(CollectiveType type, const string& name);
+  ~NcclBase() override = default;
+
+  // No-op for this collective implementation.
+  Status InitializeCollectiveParams(CollectiveParams* col_params) override;
+
+  // Initializes the device objects and device localities.
+  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+  // Initialize nccl communicator key.
+  Status InitializeCollectiveGroupRuntimeDetails(
+      CollGroupRuntimeDetails* col_group_runtime_details) override;
+
+ protected:
+  const string NcclCollectiveKey(const string& exec_key, int step_id);
+
+  const CollectiveType type_;
+  const string name_;
+  CollectiveContext* col_ctx_;          // Not owned
+  const CollectiveParams* col_params_;  // Not owned
+};
+
+#endif  // GOOGLE_CUDA
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_
diff --git a/tensorflow/core/kernels/collective_nccl_broadcaster.cc b/tensorflow/core/kernels/collective_nccl_broadcaster.cc
new file mode 100644
index 0000000..27d691e
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl_broadcaster.cc
@@ -0,0 +1,83 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"
+
+#ifdef GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+void NcclBroadcaster::Run(StatusCallback done) {
+  auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
+  auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+  const int num_global_devices = col_params_->group.group_size;
+  const int num_local_devices = col_params_->instance.num_devices_per_task.at(
+      col_params_->instance.task_names[col_params_->default_rank]);
+  string nccl_collective_key =
+      NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
+  auto participant = absl::make_unique<NcclManager::Participant>(
+      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
+      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
+      col_params_->default_rank, std::move(done));
+  VLOG(1)
+      << "NcclBroadcast calling NcclManager::AddBroadcastSend/Recv num_tasks "
+      << col_params_->group.num_tasks << " current task "
+      << col_params_->instance.task_names[col_params_->default_rank]
+      << " num local devices " << num_local_devices << " num global devices "
+      << num_global_devices << " rank " << col_params_->default_rank
+      << " device " << col_ctx_->device_name << " instance "
+      << col_params_->instance.instance_key << " source "
+      << col_params_->is_source;
+  if (col_params_->is_source) {
+    NcclManager::instance()->AddBroadcastSend(
+        std::move(participant),
+        {std::move(nccl_collective_key), num_local_devices, num_global_devices,
+         col_params_->group.runtime_details.communicator_key,
+         col_params_->source_rank});
+  } else {
+    NcclManager::instance()->AddBroadcastRecv(
+        std::move(participant),
+        {std::move(nccl_collective_key), num_local_devices, num_global_devices,
+         col_params_->group.runtime_details.communicator_key,
+         col_params_->source_rank});
+  }
+  {
+    // `WaitForDependencies` may block if the collective instances on which this
+    // op depends have not yet launched.  When this function returns, this op is
+    // ready to go.
+    profiler::TraceMe activity("WaitForDependencies",
+                               profiler::TraceMeLevel::kInfo);
+    col_ctx_->col_exec->WaitForDependencies(*col_params_);
+    NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
+  }
+  {
+    // When all devices at this worker have called `SignalMultiNodeReady`, the
+    // `NcclManager` will enqueue the NCCL kernel on the NCCL stream.  Thus the
+    // implementation of `Launched` keeps track of the number of devices that
+    // have launched.
+    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
+    col_ctx_->col_exec->Launched(*col_params_);
+  }
+}
+
+REGISTER_COLLECTIVE(NcclBroadcast, NcclBroadcaster);
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/collective_nccl_broadcaster.h b/tensorflow/core/kernels/collective_nccl_broadcaster.h
new file mode 100644
index 0000000..630d0bf
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl_broadcaster.h
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_
+#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_
+
+#include "tensorflow/core/kernels/collective_nccl.h"
+
+namespace tensorflow {
+#ifdef GOOGLE_CUDA
+
+class NcclBroadcaster : public NcclBase {
+ public:
+  NcclBroadcaster() : NcclBase(BROADCAST_COLLECTIVE, "NcclBroadcast") {}
+  ~NcclBroadcaster() override = default;
+
+  // Hands off broadcast to NcclManager.
+  void Run(StatusCallback done) override;
+};
+
+#endif  // GOOGLE_CUDA
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_
diff --git a/tensorflow/core/kernels/collective_nccl_gatherer.cc b/tensorflow/core/kernels/collective_nccl_gatherer.cc
new file mode 100644
index 0000000..627fea6
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl_gatherer.cc
@@ -0,0 +1,73 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/collective_nccl_gatherer.h"
+
+#ifdef GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+void NcclGatherer::Run(StatusCallback done) {
+  auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
+  auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+  const int num_global_devices = col_params_->group.group_size;
+  const int num_local_devices = col_params_->instance.num_devices_per_task.at(
+      col_params_->instance.task_names[col_params_->default_rank]);
+  string nccl_collective_key =
+      NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
+  auto participant = absl::make_unique<NcclManager::Participant>(
+      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
+      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
+      col_params_->default_rank, std::move(done));
+  VLOG(1) << "NcclGatherer calling NcclManager::AddToAllGather num_tasks "
+          << col_params_->group.num_tasks << " current task "
+          << col_params_->instance.task_names[col_params_->default_rank]
+          << " num local devices " << num_local_devices
+          << " num global devices " << num_global_devices << " rank "
+          << col_params_->default_rank << " device " << col_ctx_->device_name
+          << " instance " << col_params_->instance.instance_key;
+  NcclManager::instance()->AddToAllGather(
+      std::move(participant),
+      {std::move(nccl_collective_key), num_local_devices, num_global_devices,
+       col_params_->group.runtime_details.communicator_key,
+       /*source_rank=*/-1});
+  {
+    // `WaitForDependencies` may block if the collective instances on which this
+    // op depends have not yet launched.  When this function returns, this op is
+    // ready to go.
+    profiler::TraceMe activity("WaitForDependencies",
+                               profiler::TraceMeLevel::kInfo);
+    col_ctx_->col_exec->WaitForDependencies(*col_params_);
+    NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
+  }
+  {
+    // When all devices at this worker have called `SignalMultiNodeReady`, the
+    // `NcclManager` will enqueue the NCCL kernel on the NCCL stream.  Thus the
+    // implementation of `Launched` keeps track of the number of devices that
+    // have launched.
+    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
+    col_ctx_->col_exec->Launched(*col_params_);
+  }
+}
+
+REGISTER_COLLECTIVE(NcclGather, NcclGatherer);
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/collective_nccl_gatherer.h b/tensorflow/core/kernels/collective_nccl_gatherer.h
new file mode 100644
index 0000000..9113d92
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl_gatherer.h
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_
+#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_
+
+#include "tensorflow/core/kernels/collective_nccl.h"
+
+namespace tensorflow {
+#ifdef GOOGLE_CUDA
+
+class NcclGatherer : public NcclBase {
+ public:
+  NcclGatherer() : NcclBase(GATHER_COLLECTIVE, "NcclGather") {}
+  ~NcclGatherer() override = default;
+
+  // Hands off all-gather to NcclManager.
+  void Run(StatusCallback done) override;
+};
+
+#endif  // GOOGLE_CUDA
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_
diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc
index 8fd6b15..b6c140b 100644
--- a/tensorflow/core/kernels/collective_nccl_reducer.cc
+++ b/tensorflow/core/kernels/collective_nccl_reducer.cc
@@ -22,42 +22,8 @@
 #include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace tensorflow {
+
 namespace {
-string NcclCollectiveKey(const string& exec_key, int step_id) {
-  return strings::StrCat(exec_key, ":", step_id);
-}
-}  // namespace
-
-NcclReducer::NcclReducer() : col_ctx_(nullptr), col_params_(nullptr) {}
-
-Status NcclReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
-  if (col_params->instance.type != REDUCTION_COLLECTIVE ||
-      col_params->instance.impl_details.collective_name != "NcclReduce") {
-    return errors::Internal("Unexpected collective type ",
-                            col_params->instance.type, " expected ",
-                            REDUCTION_COLLECTIVE, "; or collective name ",
-                            col_params->instance.impl_details.collective_name,
-                            " expected NcclReduce");
-  } else {
-    return Status::OK();
-  }
-}
-
-Status NcclReducer::InitializeCollectiveContext(CollectiveContext* col_ctx) {
-  col_ctx_ = col_ctx;
-  col_params_ = &col_ctx->col_params;
-  return collective_util::InitializeDeviceAndLocality(
-      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
-      &col_ctx->device_locality);
-}
-
-Status NcclReducer::InitializeCollectiveGroupRuntimeDetails(
-    CollGroupRuntimeDetails* col_group_runtime_details) {
-  col_group_runtime_details->communicator_key =
-      NcclManager::instance()->GenerateCommunicatorKey();
-  return Status::OK();
-}
-
 Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
   if (merge_op == "Add") {
     *reduction_op = ncclSum;
@@ -70,6 +36,7 @@
                             merge_op);
   }
 }
+}  // namespace
 
 void NcclReducer::Run(StatusCallback done) {
   ncclRedOp_t reduction_op;
@@ -155,7 +122,7 @@
   NcclManager::instance()->AddToAllReduce(
       std::move(participant),
       {nccl_collective_key, num_local_devices, num_global_devices,
-       col_params_->group.runtime_details.communicator_key},
+       col_params_->group.runtime_details.communicator_key, /*source_rank=*/-1},
       reduction_op);
 
   // NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
diff --git a/tensorflow/core/kernels/collective_nccl_reducer.h b/tensorflow/core/kernels/collective_nccl_reducer.h
index f04a5b5..00919cb 100644
--- a/tensorflow/core/kernels/collective_nccl_reducer.h
+++ b/tensorflow/core/kernels/collective_nccl_reducer.h
@@ -15,32 +15,18 @@
 #ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_REDUCER_H_
 #define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_REDUCER_H_
 
-#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/kernels/collective_nccl.h"
 
 namespace tensorflow {
 #ifdef GOOGLE_CUDA
 
-class NcclReducer : public CollectiveImplementationInterface {
+class NcclReducer : public NcclBase {
  public:
-  NcclReducer();
+  NcclReducer() : NcclBase(REDUCTION_COLLECTIVE, "NcclReduce") {}
   ~NcclReducer() override = default;
 
-  // No-op for this collective implementation.
-  Status InitializeCollectiveParams(CollectiveParams* col_params) override;
-
-  // Initializes the device objects and device localities.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
-
-  // Initialize nccl communicator key.
-  Status InitializeCollectiveGroupRuntimeDetails(
-      CollGroupRuntimeDetails* col_group_runtime_details) override;
-
   // Hands off all reduce to NcclManager.
   void Run(StatusCallback done) override;
-
- private:
-  CollectiveContext* col_ctx_;          // Not owned
-  const CollectiveParams* col_params_;  // Not owned
 };
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/collective_nccl_reducer_test.cc b/tensorflow/core/kernels/collective_nccl_reducer_test.cc
deleted file mode 100644
index 00dfa72..0000000
--- a/tensorflow/core/kernels/collective_nccl_reducer_test.cc
+++ /dev/null
@@ -1,333 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef GOOGLE_CUDA
-
-#include "tensorflow/core/kernels/collective_nccl_reducer.h"
-
-#include <algorithm>
-
-#include "absl/memory/memory.h"
-#include "tensorflow/core/common_runtime/base_collective_executor.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/device_resolver_local.h"
-#include "tensorflow/core/common_runtime/process_util.h"
-#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
-#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/fake_input.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-static constexpr int kStepId = 10;
-
-std::unique_ptr<OpKernel> GetKernel(const NodeDef& node, DeviceBase* device) {
-  Status status;
-  std::unique_ptr<OpKernel> k = CreateOpKernel(
-      DEVICE_GPU, device, device->GetAllocator(AllocatorAttributes()), node,
-      TF_GRAPH_DEF_VERSION, &status);
-  if (!status.ok()) LOG(FATAL) << status;
-  return k;
-}
-
-std::unique_ptr<OpKernel> GetAdd(DeviceBase* device) {
-  NodeDef node_def;
-  NodeDefBuilder builder("add_node", "Add");
-  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
-                  .Input(FakeInput(DT_FLOAT))
-                  .Input(FakeInput(DT_FLOAT))
-                  .Finalize(&node_def));
-  return GetKernel(node_def, device);
-}
-
-std::unique_ptr<OpKernel> GetDiv(DeviceBase* device) {
-  NodeDef node_def;
-  NodeDefBuilder builder("add_node", "Div");
-  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
-                  .Input(FakeInput(DT_FLOAT))
-                  .Input(FakeInput(DT_FLOAT))
-                  .Finalize(&node_def));
-  return GetKernel(node_def, device);
-}
-
-class NcclReducerTest : public ::testing::Test {
- protected:
-  ~NcclReducerTest() override {
-    if (col_exec_) col_exec_->Unref();
-  }
-
-  void InitGPUDevices() {
-    std::vector<std::unique_ptr<Device>> all_devices;
-    SessionOptions session_options;
-    session_options.config.mutable_gpu_options()
-        ->set_per_process_gpu_memory_fraction(0.1);
-    session_options.env = Env::Default();
-    Status s = DeviceFactory::GetFactory(DEVICE_GPU)
-                   ->AddDevices(session_options, "", &all_devices);
-    TF_CHECK_OK(s);
-    for (std::unique_ptr<Device>& d : all_devices) {
-      if (d->device_type() == "GPU") {
-        gpus_.emplace_back(std::move(d));
-      }
-    }
-  }
-
-  void Init(int num_ranks) {
-    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
-    setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
-    InitGPUDevices();
-    std::vector<std::unique_ptr<Device>> local_devices;
-    std::vector<string> device_names;
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      if (rank < gpus_.size()) {
-        local_devices.emplace_back(std::move(gpus_[rank]));
-      }
-    }
-    int num_gpus = local_devices.size();
-    for (const auto& device : local_devices) {
-      device_names.push_back(device->name());
-      VLOG(2) << device->name();
-    }
-    if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
-    col_exec_ = new BaseCollectiveExecutor(
-        &col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
-        /*gpu_ring_order=*/nullptr);
-
-    // Initialize collective params.
-    col_params_.name = "test_nccl_collective_op";
-    const int group_key = 5;
-    col_params_.group.group_key = group_key;
-    col_params_.group.device_type = DEVICE_GPU;
-    col_params_.group.group_size = num_ranks;
-    const int instance_key = 23;
-    col_params_.instance.instance_key = instance_key;
-    col_params_.instance.type = REDUCTION_COLLECTIVE;
-    col_params_.instance.data_type = DT_FLOAT;
-    col_params_.instance.impl_details.collective_name = "NcclReduce";
-    const string task_name = "/job:worker/replica:0/task:0";
-    col_params_.instance.num_devices_per_task[task_name] = num_ranks;
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      col_params_.instance.device_names.push_back(
-          device_names[rank % num_gpus]);
-      col_params_.instance.task_names.push_back(task_name);
-    }
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      instances_.push_back(absl::make_unique<DeviceInstance>(
-          rank, col_params_.instance.device_names[rank], this));
-    }
-  }
-
-  void Reduce() {
-    int done = 0;
-    mutex done_mu;
-    condition_variable done_cv;
-    for (const auto& instance : instances_) {
-      DeviceInstance* di = instance.get();
-      SchedClosure([di, &done, &done_mu, &done_cv] {
-        di->DoReduce();
-        mutex_lock l(done_mu);
-        ++done;
-        done_cv.notify_all();
-      });
-    }
-
-    mutex_lock l(done_mu);
-    while (done < instances_.size()) done_cv.wait(l);
-  }
-
-  void RunTest(int num_ranks, int tensor_length) {
-    Init(num_ranks);
-    std::vector<float> expected(tensor_length, 0.0);
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      DeviceInstance* instance = instances_[rank].get();
-      instance->InitTensor(DT_FLOAT, TensorShape({tensor_length}),
-                           [&expected, rank](Tensor* t) {
-                             for (size_t i = 0; i < t->NumElements(); ++i) {
-                               float value = pow(10, rank) * i;
-                               t->flat<float>()(i) = value;
-                               expected[i] += value;
-                             }
-                           });
-    }
-    Reduce();
-    // Confirm that every rank computed the same correct value.
-    for (int i = 0; i < tensor_length; ++i) {
-      expected[i] /= num_ranks;
-    }
-    for (int rank = 0; rank < instances_.size(); ++rank) {
-      TF_ASSERT_OK(instances_[rank]->status_);
-      Tensor* dev_tensor = &instances_[rank]->tensor_;
-      Tensor actual(DT_FLOAT, TensorShape({tensor_length}));
-      Notification note;
-      Device* dev = instances_[rank]->device_;
-      auto* dev_info = dev->tensorflow_gpu_device_info();
-      dev_info->default_context->CopyDeviceTensorToCPU(
-          dev_tensor, /*tensor_name=*/"", dev, &actual,
-          [&note](const Status&) { note.Notify(); });
-      note.WaitForNotification();
-      for (int i = 0; i < tensor_length; ++i) {
-        EXPECT_FLOAT_EQ(expected[i], actual.template flat<float>()(i))
-            << "Mismatch at rank " << rank << " index " << i;
-      }
-    }
-  }
-
-  std::unique_ptr<OpKernel> GetCollectiveReduce(const CollectiveParams& params,
-                                                Tensor* input,
-                                                DeviceBase* device) {
-    mutex_lock l(mu_);
-    NodeDef node_def;
-    NodeDefBuilder builder(
-        strings::StrCat("collective_reduce_", reduce_counter_++),
-        "CollectiveReduce");
-    TF_CHECK_OK(
-        builder.Attr("T", params.instance.data_type)
-            .Attr("merge_op", "Add")
-            .Attr("final_op", "Div")
-            .Attr("group_size", params.group.group_size)
-            .Attr("group_key", params.group.group_key)
-            .Attr("instance_key", params.instance.instance_key)
-            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
-            .Input(FakeInput(params.instance.data_type))
-            .Finalize(&node_def));
-    return GetKernel(node_def, device);
-  }
-
-  class DeviceInstance {
-   public:
-    DeviceInstance(int rank, const string& device_name, NcclReducerTest* parent)
-        : parent_(parent), device_name_(device_name), rank_(rank) {
-      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
-          << "Could not find device " << device_name_ << " existing devices "
-          << parent_->dev_mgr_->DebugString();
-      col_params_.name = parent_->col_params_.name;
-      col_params_.default_rank = rank;
-      col_params_.group.group_key = parent_->col_params_.group.group_key;
-      col_params_.group.device_type = parent_->col_params_.group.device_type;
-      col_params_.group.group_size = parent_->col_params_.group.group_size;
-      col_params_.instance = parent->col_params_.instance;
-    }
-
-    void InitTensor(DataType dtype, const TensorShape& shape,
-                    const std::function<void(Tensor*)>& init_f) {
-      tensor_ =
-          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
-      Tensor cpu_tensor(dtype, shape);
-      init_f(&cpu_tensor);
-      VLOG(2) << "cpu_tensor " << cpu_tensor.DebugString();
-      auto* dev_info = device_->tensorflow_gpu_device_info();
-      Notification note;
-      dev_info->default_context->CopyCPUTensorToDevice(
-          &cpu_tensor, device_, &tensor_,
-          [&note](const Status&) { note.Notify(); });
-      note.WaitForNotification();
-    }
-
-    void DoReduce() {
-      col_params_.merge_op = GetAdd(device_);
-      col_params_.final_op = GetDiv(device_);
-
-      // Prepare an OpKernelContext.
-      OpKernelContext::Params op_params;
-      op_params.step_id = kStepId;
-      op_params.device = device_;
-      gtl::InlinedVector<TensorValue, 4> inputs;
-      inputs.push_back(TensorValue(&tensor_));
-      op_params.inputs = &inputs;
-      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
-          {AllocatorAttributes()});
-      op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
-      DeviceContext* dev_ctx = nullptr;
-      auto* dev_info = device_->tensorflow_gpu_device_info();
-      if (dev_info) {
-        dev_ctx = dev_info->default_context;
-        dev_ctx->Ref();
-      } else {
-        dev_ctx = new DeviceContext;
-      }
-      input_dc.push_back(dev_ctx);
-      op_params.input_device_contexts = &input_dc;
-      op_params.op_device_context = dev_ctx;
-      int forward_from = 0;
-      op_params.forward_from_array = &forward_from;
-      AllocatorAttributes generic_alloc_attr;
-      op_params.output_attr_array = &generic_alloc_attr;
-      std::unique_ptr<OpKernel> op =
-          parent_->GetCollectiveReduce(col_params_, &tensor_, device_);
-      op_params.op_kernel = op.get();
-      OpKernelContext ctx(&op_params, 1);
-
-      // We never actually execute the kernel, so we need to do the output
-      // allocation it would do, ourselves.
-      Tensor* output_tensor_ptr = nullptr;
-      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
-                                                       &output_tensor_ptr));
-      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
-
-      // Prepare a NcclReducer instance.
-      string exec_key =
-          strings::StrCat(col_params_.instance.instance_key, ":0:0");
-      NcclReducer reducer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, &tensor_, &tensor_);
-      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
-
-      // Run the all-reduce.
-      reducer.Run([this](Status s) { status_ = s; });
-      if (status_.ok()) {
-        CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
-      }
-
-      dev_ctx->Unref();
-    }
-
-    NcclReducerTest* parent_;
-    string device_name_;
-    int rank_;
-    Tensor tensor_;
-    Device* device_;
-    CollectiveParams col_params_;
-    Status status_;
-  };
-
-  std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
-  TestCollectiveExecutorMgr col_exec_mgr_;
-  CollectiveExecutor* col_exec_;
-  std::unique_ptr<DeviceMgr> dev_mgr_;
-  std::vector<std::unique_ptr<DeviceInstance>> instances_;
-  CollectiveParams col_params_;
-  mutex mu_;
-  int32 reduce_counter_ GUARDED_BY(mu_) = 0;
-};
-
-TEST_F(NcclReducerTest, Test2Dev16Len) { RunTest(2, 16); }
-TEST_F(NcclReducerTest, Test4Dev16Len) { RunTest(4, 16); }
-TEST_F(NcclReducerTest, Test8Dev16Len) { RunTest(8, 16); }
-TEST_F(NcclReducerTest, Test8Dev128Len) { RunTest(8, 128); }
-TEST_F(NcclReducerTest, Test8Dev1045991Len) { RunTest(8, 1048576); }
-
-}  // namespace tensorflow
-
-#endif
diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc
new file mode 100644
index 0000000..b77fe2b
--- /dev/null
+++ b/tensorflow/core/kernels/collective_nccl_test.cc
@@ -0,0 +1,586 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/collective_nccl.h"
+
+#include <algorithm>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"
+#include "tensorflow/core/kernels/collective_nccl_gatherer.h"
+#include "tensorflow/core/kernels/collective_nccl_reducer.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+static constexpr int kStepId = 10;
+
+std::unique_ptr<OpKernel> GetKernel(const NodeDef& node, DeviceBase* device) {
+  Status status;
+  std::unique_ptr<OpKernel> k = CreateOpKernel(
+      DEVICE_GPU, device, device->GetAllocator(AllocatorAttributes()), node,
+      TF_GRAPH_DEF_VERSION, &status);
+  if (!status.ok()) LOG(FATAL) << status;
+  return k;
+}
+
+std::unique_ptr<OpKernel> GetAdd(DeviceBase* device) {
+  NodeDef node_def;
+  NodeDefBuilder builder("add_node", "Add");
+  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Finalize(&node_def));
+  return GetKernel(node_def, device);
+}
+
+std::unique_ptr<OpKernel> GetDiv(DeviceBase* device) {
+  NodeDef node_def;
+  NodeDefBuilder builder("add_node", "Div");
+  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Finalize(&node_def));
+  return GetKernel(node_def, device);
+}
+
+class NcclTestBase : public ::testing::Test {
+ protected:
+  class DeviceInstance;
+
+  NcclTestBase(CollectiveType collective_type, const string& collective_name)
+      : collective_type_(collective_type), collective_name_(collective_name) {}
+  ~NcclTestBase() override {
+    if (col_exec_) col_exec_->Unref();
+  }
+
+  void InitGPUDevices() {
+    std::vector<std::unique_ptr<Device>> all_devices;
+    SessionOptions session_options;
+    session_options.config.mutable_gpu_options()
+        ->set_per_process_gpu_memory_fraction(0.1);
+    session_options.env = Env::Default();
+    Status s = DeviceFactory::GetFactory(DEVICE_GPU)
+                   ->AddDevices(session_options, "", &all_devices);
+    TF_CHECK_OK(s);
+    for (std::unique_ptr<Device>& d : all_devices) {
+      if (d->device_type() == "GPU") {
+        gpus_.emplace_back(std::move(d));
+      }
+    }
+  }
+
+  void Init(const int num_ranks, const int instance_key) {
+    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
+    setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
+    InitGPUDevices();
+    std::vector<std::unique_ptr<Device>> local_devices;
+    std::vector<string> device_names;
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      if (rank < gpus_.size()) {
+        local_devices.emplace_back(std::move(gpus_[rank]));
+      }
+    }
+    int num_gpus = local_devices.size();
+    for (const auto& device : local_devices) {
+      device_names.push_back(device->name());
+      VLOG(2) << device->name();
+    }
+    if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
+    col_exec_ = new BaseCollectiveExecutor(
+        &col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
+        /*gpu_ring_order=*/nullptr);
+
+    // Initialize collective params.
+    col_params_.name = "test_nccl_collective_op";
+    const int group_key = num_ranks;
+    col_params_.group.group_key = group_key;
+    col_params_.group.device_type = DEVICE_GPU;
+    col_params_.group.group_size = num_ranks;
+    col_params_.instance.instance_key = instance_key;
+    col_params_.instance.type = collective_type_;
+    col_params_.instance.data_type = DT_FLOAT;
+    col_params_.instance.impl_details.collective_name = collective_name_;
+    const string task_name = "/job:worker/replica:0/task:0";
+    col_params_.instance.num_devices_per_task[task_name] = num_ranks;
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      col_params_.instance.device_names.push_back(
+          device_names[rank % num_gpus]);
+      col_params_.instance.task_names.push_back(task_name);
+    }
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      instances_.push_back(absl::make_unique<DeviceInstance>(
+          rank, col_params_.instance.device_names[rank], this));
+    }
+  }
+
+  // Initialize `input` tensor at rank `rank`.
+  virtual void InitInput(Tensor* input, const int rank) = 0;
+
+  // Initialize `expected` output at all `num_ranks` ranks.
+  virtual void InitExpected(std::vector<float>* expected,
+                            const int tensor_length, const int num_ranks) = 0;
+
+  // Initialize device `di` specific to the collective op.
+  virtual void InitDevice(DeviceInstance* di) = 0;
+
+  // Run collective op on device `di`.
+  virtual void RunCollectiveOnDevice(DeviceInstance* di) = 0;
+
+  void RunCollective() {
+    int done = 0;
+    mutex done_mu;
+    condition_variable done_cv;
+    for (const auto& instance : instances_) {
+      DeviceInstance* di = instance.get();
+      InitDevice(di);
+      SchedClosure([this, di, &done, &done_mu, &done_cv] {
+        RunCollectiveOnDevice(di);
+        mutex_lock l(done_mu);
+        ++done;
+        done_cv.notify_all();
+      });
+    }
+
+    mutex_lock l(done_mu);
+    while (done < instances_.size()) done_cv.wait(l);
+  }
+
+  void RunTest(int num_ranks, int input_length, int instance_key) {
+    Init(num_ranks, instance_key);
+    std::vector<float> expected;
+    InitExpected(&expected, input_length, num_ranks);
+    if (VLOG_IS_ON(3)) {
+      string str_buf;
+      for (const auto& x : expected) {
+        strings::StrAppend(&str_buf, " ", x);
+      }
+      VLOG(3) << "Expected output " << str_buf;
+    }
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      DeviceInstance* instance = instances_[rank].get();
+      instance->InitTensor(DT_FLOAT, TensorShape({input_length}),
+                           [this, rank](Tensor* t) { InitInput(t, rank); });
+    }
+    RunCollective();
+    // Confirm that every rank computed the same correct value.
+    for (int rank = 0; rank < instances_.size(); ++rank) {
+      TF_ASSERT_OK(instances_[rank]->status_);
+      Tensor* output = &instances_[rank]->output_;
+      const int output_length = output->NumElements();
+      VLOG(2) << "rank " << rank << " output " << output << " buf "
+              << DMAHelper::base(output);
+      Tensor actual(DT_FLOAT, TensorShape({output_length}));
+      Notification note;
+      Device* dev = instances_[rank]->device_;
+      auto* dev_info = dev->tensorflow_gpu_device_info();
+      dev_info->default_context->CopyDeviceTensorToCPU(
+          output, /*tensor_name=*/"", dev, &actual, [&note](const Status& s) {
+            TF_CHECK_OK(s);
+            note.Notify();
+          });
+      note.WaitForNotification();
+      VLOG(3) << "rank " << rank << " got output tensor "
+              << actual.DebugString(output_length);
+      for (int i = 0; i < output_length; ++i) {
+        EXPECT_FLOAT_EQ(expected[i], actual.template flat<float>()(i))
+            << "Mismatch at rank " << rank << " index " << i;
+      }
+    }
+  }
+
+  std::unique_ptr<OpKernel> GetCollectiveReduceOpKernel(
+      const CollectiveParams& params, Tensor* input, DeviceBase* device) {
+    mutex_lock l(mu_);
+    NodeDef node_def;
+    NodeDefBuilder builder(strings::StrCat("collective_reduce_", op_counter_++),
+                           "CollectiveReduce");
+    TF_CHECK_OK(
+        builder.Attr("T", params.instance.data_type)
+            .Attr("merge_op", "Add")
+            .Attr("final_op", "Div")
+            .Attr("group_size", params.group.group_size)
+            .Attr("group_key", params.group.group_key)
+            .Attr("instance_key", params.instance.instance_key)
+            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
+            .Input(FakeInput(params.instance.data_type))
+            .Finalize(&node_def));
+    return GetKernel(node_def, device);
+  }
+
+  class DeviceInstance {
+   public:
+    DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
+        : parent_(parent), device_name_(device_name), rank_(rank) {
+      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
+          << "Could not find device " << device_name_ << " existing devices "
+          << parent_->dev_mgr_->DebugString();
+      col_params_.name = parent_->col_params_.name;
+      col_params_.default_rank = rank;
+      col_params_.group.group_key = parent_->col_params_.group.group_key;
+      col_params_.group.device_type = parent_->col_params_.group.device_type;
+      col_params_.group.group_size = parent_->col_params_.group.group_size;
+      col_params_.instance = parent->col_params_.instance;
+    }
+
+    void InitTensor(DataType dtype, const TensorShape& shape,
+                    const std::function<void(Tensor*)>& init_f) {
+      input_ =
+          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
+      Tensor cpu_tensor(dtype, shape);
+      init_f(&cpu_tensor);
+      if (VLOG_IS_ON(3)) {
+        VLOG(3) << "input tensor "
+                << cpu_tensor.DebugString(shape.num_elements());
+      } else {
+        VLOG(2) << "input tensor " << cpu_tensor.DebugString();
+      }
+      auto* dev_info = device_->tensorflow_gpu_device_info();
+      Notification note;
+      dev_info->default_context->CopyCPUTensorToDevice(
+          &cpu_tensor, device_, &input_, [&note](const Status& s) {
+            TF_CHECK_OK(s);
+            note.Notify();
+          });
+      note.WaitForNotification();
+    }
+
+    void PrepareDeviceContext(OpKernelContext::Params* params) {
+      params->step_id = kStepId;
+      params->device = device_;
+      DeviceContext* dev_ctx = nullptr;
+      auto* dev_info = device_->tensorflow_gpu_device_info();
+      if (dev_info) {
+        dev_ctx = dev_info->default_context;
+        dev_ctx->Ref();
+      } else {
+        dev_ctx = new DeviceContext;
+      }
+      params->op_device_context = dev_ctx;
+    }
+
+    void RunReduce() {
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      PrepareDeviceContext(&op_params);
+
+      // Prepare inputs and outputs to OpKernel.
+      gtl::InlinedVector<TensorValue, 4> inputs;
+      inputs.push_back(TensorValue(&input_));
+      op_params.inputs = &inputs;
+      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
+          {AllocatorAttributes()});
+      op_params.input_alloc_attrs = &input_aa;
+      gtl::InlinedVector<DeviceContext*, 4> input_dc;
+      input_dc.push_back(op_params.op_device_context);
+      op_params.input_device_contexts = &input_dc;
+      int forward_from = 0;
+      op_params.forward_from_array = &forward_from;
+      AllocatorAttributes generic_alloc_attr;
+      op_params.output_attr_array = &generic_alloc_attr;
+      std::unique_ptr<OpKernel> op =
+          parent_->GetCollectiveReduceOpKernel(col_params_, &input_, device_);
+      op_params.op_kernel = op.get();
+      OpKernelContext ctx(&op_params, 1);
+      // We never actually execute the kernel, so we need to do the output
+      // allocation it would do, ourselves.
+      Tensor* output_tensor_ptr = nullptr;
+      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, input_.shape(),
+                                                       &output_tensor_ptr));
+      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+
+      // Run the all-reduce.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      NcclReducer reducer;
+      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+                                /*OpKernelContext=*/&ctx, &op_params,
+                                col_params_, exec_key, kStepId,
+                                /*input=*/&input_, /*output=*/&input_);
+      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+      Notification note;
+      reducer.Run([this, &note](Status s) {
+        status_ = s;
+        note.Notify();
+      });
+      note.WaitForNotification();
+      if (status_.ok()) {
+        CHECK(output_.CopyFrom(*ctx.mutable_output(0), input_.shape()));
+      }
+
+      op_params.op_device_context->Unref();
+    }
+
+    void RunBroadcast() {
+      VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
+              << col_params_.default_rank;
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      PrepareDeviceContext(&op_params);
+      OpKernelContext ctx(&op_params, 1);
+
+      // Run broadcast.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      NcclBroadcaster broadcaster;
+      CollectiveContext col_ctx(
+          parent_->col_exec_, parent_->dev_mgr_.get(),
+          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+          /*input=*/col_params_.is_source ? &input_ : nullptr,
+          /*output=*/&input_);
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      Notification note;
+      broadcaster.Run([this, &note](Status s) {
+        status_ = s;
+        note.Notify();
+      });
+      note.WaitForNotification();
+      if (status_.ok()) {
+        CHECK(output_.CopyFrom(input_, input_.shape()));
+      }
+
+      op_params.op_device_context->Unref();
+    }
+
+    void RunGather() {
+      VLOG(2) << "RunGather name " << parent_->collective_name_ << " rank "
+              << col_params_.default_rank;
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      PrepareDeviceContext(&op_params);
+      OpKernelContext ctx(&op_params, 1);
+
+      // Allocate output.  We can't reuse the input because output has a
+      // different shape.
+      auto output_shape = input_.shape();
+      output_shape.set_dim(
+          0, output_shape.dim_size(0) * col_params_.group.group_size);
+      output_ = Tensor(device_->GetAllocator(AllocatorAttributes()), DT_FLOAT,
+                       output_shape);
+
+      // Run gather.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      NcclGatherer gatherer;
+      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+                                /*OpKernelContext=*/&ctx, &op_params,
+                                col_params_, exec_key, kStepId,
+                                /*input=*/&input_,
+                                /*output=*/&output_);
+      TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
+      Notification note;
+      gatherer.Run([this, &note](Status s) {
+        status_ = s;
+        note.Notify();
+      });
+      note.WaitForNotification();
+
+      op_params.op_device_context->Unref();
+    }
+
+    NcclTestBase* parent_;
+    string device_name_;
+    int rank_;
+    Tensor input_;
+    Tensor output_;
+    Device* device_;
+    CollectiveParams col_params_;
+    Status status_;
+  };
+
+  CollectiveType collective_type_;
+  const string collective_name_;
+  std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
+  TestCollectiveExecutorMgr col_exec_mgr_;
+  CollectiveExecutor* col_exec_;
+  std::unique_ptr<DeviceMgr> dev_mgr_;
+  std::vector<std::unique_ptr<DeviceInstance>> instances_;
+  CollectiveParams col_params_;
+  mutex mu_;
+  int32 op_counter_ GUARDED_BY(mu_) = 0;
+};
+
+class NcclReducerTest : public NcclTestBase {
+ protected:
+  NcclReducerTest()
+      : NcclTestBase(/*collective_type=*/REDUCTION_COLLECTIVE,
+                     /*collective_name=*/"NcclReduce") {}
+  ~NcclReducerTest() override = default;
+
+  void InitInput(Tensor* input, const int rank) override {
+    for (size_t i = 0; i < input->NumElements(); ++i) {
+      float value = pow(10, rank) * i;
+      input->flat<float>()(i) = value;
+    }
+  }
+
+  void InitExpected(std::vector<float>* expected, const int tensor_length,
+                    const int num_ranks) override {
+    expected->resize(tensor_length);
+    for (int i = 0; i < tensor_length; ++i) {
+      float expected_sum = 0.0;
+      for (int rank = 0; rank < num_ranks; ++rank) {
+        float value = pow(10, rank) * i;
+        expected_sum += value;
+      }
+      (*expected)[i] = expected_sum / num_ranks;
+    }
+  }
+
+  void InitDevice(DeviceInstance* di) override {
+    di->col_params_.merge_op = GetAdd(di->device_);
+    di->col_params_.final_op = GetDiv(di->device_);
+  }
+
+  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
+};
+
+class NcclBroadcasterTest : public NcclTestBase {
+ protected:
+  NcclBroadcasterTest()
+      : NcclTestBase(/*collective_type=*/BROADCAST_COLLECTIVE,
+                     /*collective_name=*/"NcclBroadcast") {}
+  ~NcclBroadcasterTest() override = default;
+
+  void InitInput(Tensor* input, const int rank) override {
+    bool source = rank == source_rank_;
+    for (size_t i = 0; i < input->NumElements(); ++i) {
+      input->flat<float>()(i) = source ? static_cast<float>(i) : -1.0;
+    }
+  }
+
+  void InitExpected(std::vector<float>* expected, const int tensor_length,
+                    const int num_ranks) override {
+    expected->resize(tensor_length);
+    for (int i = 0; i < tensor_length; ++i) {
+      (*expected)[i] = i;
+    }
+  }
+
+  void InitDevice(DeviceInstance* di) override {
+    di->col_params_.source_rank = source_rank_;
+    di->col_params_.is_source = di->col_params_.default_rank == source_rank_;
+  }
+
+  void RunCollectiveOnDevice(DeviceInstance* di) override {
+    di->RunBroadcast();
+  }
+
+  int source_rank_ = 0;
+};
+
+class NcclGathererTest : public NcclTestBase {
+ protected:
+  NcclGathererTest()
+      : NcclTestBase(/*collective_type=*/GATHER_COLLECTIVE,
+                     /*collective_name=*/"NcclGather") {}
+  ~NcclGathererTest() override = default;
+
+  void InitInput(Tensor* input, const int rank) override {
+    for (size_t i = 0; i < input->NumElements(); ++i) {
+      float value = pow(10, rank) * i;
+      input->flat<float>()(i) = value;
+    }
+  }
+
+  void InitExpected(std::vector<float>* expected, const int tensor_length,
+                    const int num_ranks) override {
+    expected->resize(tensor_length * num_ranks, -1);
+    for (int rank = 0, i = 0; rank < num_ranks; ++rank) {
+      for (int j = 0; j < tensor_length; ++j, ++i) {
+        (*expected)[i] = pow(10, rank) * j;
+      }
+    }
+  }
+
+  void InitDevice(DeviceInstance* di) override {}
+
+  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunGather(); }
+
+  int source_rank_ = 0;
+};
+
+TEST_F(NcclReducerTest, Test2Dev16Len) {
+  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test4Dev16Len) {
+  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev16Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev128Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev1045991Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
+}
+
+TEST_F(NcclBroadcasterTest, Test2Dev16LenSrc0) {
+  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test4Dev16LenSrc1) {
+  source_rank_ = 1;
+  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev16LenSrc7) {
+  source_rank_ = 7;
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev128LenSrc0) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/24);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev1045991LenSrc0) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
+}
+
+TEST_F(NcclGathererTest, Test2Dev16Len) {
+  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclGathererTest, Test4Dev16Len) {
+  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclGathererTest, Test8Dev16Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclGathererTest, Test8Dev128Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/24);
+}
+TEST_F(NcclGathererTest, Test8Dev1045991Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
+}
+
+}  // namespace tensorflow
+
+#endif
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index b262dc5..ca26bf4 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -40,9 +40,9 @@
     if (col_params_.group.group_size >
         col_params_.instance.device_names.size()) {
       // This is the first invocation: Finish initializing col_params_.
-      // Call in a blockable thread because it's not guaranteed that
-      // this call cannot block.
-      c->env()->SchedClosure([this, c, done, col_exec]() {
+      // Schedule the `CompleteParamsAsync` call on a work queue that can handle
+      // blocking work because it's not guaranteed that this call cannot block.
+      c->collective_executor()->RunClosure([this, c, done, col_exec]() {
         VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
                 << col_params_.name << " device " << c->device()->name()
                 << " group " << col_params_.group.group_key << " instance "
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index 547a7b4..199bb2a 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -73,6 +73,8 @@
 REGISTER(quint16)
 REGISTER(qint16)
 REGISTER(qint32)
+REGISTER(uint32)
+REGISTER(uint64)
 
 #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
     !defined(__ANDROID_TYPES_FULL__)
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index ea0c486..350f5e7 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -194,6 +194,8 @@
 REGISTER_CONCAT(quint16);
 REGISTER_CONCAT(qint16);
 REGISTER_CONCAT(qint32);
+REGISTER_CONCAT(uint32);
+REGISTER_CONCAT(uint64);
 
 #undef REGISTER_CONCAT
 
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index ab54fc1..a2bfa2c 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -113,7 +113,7 @@
     // Verify that the shared accumulator is compatible
     // with the requested arguments.
     TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def()));
-    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>();
+    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<tstring>();
     h(0) = cinfo_.container();
     h(1) = cinfo_.name();
     accumulator_handle_set_ = true;
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index 2bbd0ec..3c7fbe0 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -85,7 +85,7 @@
 
   void SetHandleToOutput(OpKernelContext* ctx)
       SHARED_LOCKS_REQUIRED(mu_) override {
-    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>();
+    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<tstring>();
     h(0) = cinfo_.container();
     h(1) = cinfo_.name();
     OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index ea934b8..3292160 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -53,6 +53,11 @@
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#endif  // GOOGLE_CUDA
 
 namespace {
 
@@ -408,11 +413,15 @@
                 errors::InvalidArgument(
                     "Current implementation does not yet support "
                     "dilations in the batch and depth dimensions."));
-    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
-    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
-                errors::InvalidArgument(
-                    "Current libxsmm and customized CPU implementations do "
-                    "not yet support dilation rates larger than 1."));
+    if (std::is_same<Device, CPUDevice>::value ||
+        std::is_same<Device, GPUDevice>::value) {
+      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+      OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+                  errors::InvalidArgument(
+                      "Current libxsmm and customized CPU implementations do "
+                      "not yet support dilation rates larger than 1."));
+      dilations_ = {1, 1, 1, 1};
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -434,8 +443,8 @@
         context,
         ConvBackpropComputeDimensionsV2(
             "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
-            filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
-            strides_, padding_, explicit_paddings_, data_format_, &dims));
+            filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
+            explicit_paddings_, data_format_, &dims));
 
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
@@ -929,10 +938,10 @@
     transformed_input = compatible_input;
   }
 
-  auto out_backprop_ptr =
+  se::DeviceMemory<T> out_backprop_ptr =
       AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                      transformed_out_backprop.template flat<T>().size());
-  auto filter_backprop_ptr =
+  se::DeviceMemory<T> filter_backprop_ptr =
       AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
                      pre_transformed_filter_backprop.template flat<T>().size());
   auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
@@ -966,6 +975,15 @@
   if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
                                 conv_parameters, &algorithm_config)) {
 #if GOOGLE_CUDA
+
+    se::TfAllocatorAdapter tf_allocator_adapter(
+        stream->parent()->platform(), ctx->device()->GetAllocator({}));
+    se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
+                                            se::cuda::PtxCompilationOptions());
+
+    se::DeviceMemory<T> filter_backprop_ptr_rz(
+        WrapRedzoneBestEffort(&rz_allocator, filter_backprop_ptr));
+
     std::vector<AlgorithmDesc> algorithms;
     CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
         conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
@@ -976,13 +994,21 @@
       // accuracy.
       DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                             ctx);
+      se::cuda::RedzoneAllocator rz_scratch_allocator(
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+          /*memory_limit=*/ConvolveBackwardFilterScratchSize);
+      se::ScratchAllocator* allocator_used =
+          !RedzoneCheckDisabled()
+              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
+              : static_cast<se::ScratchAllocator*>(&scratch_allocator);
+
       ProfileResult profile_result;
       bool cudnn_launch_status =
           stream
               ->ThenConvolveBackwardFilterWithAlgorithm(
                   input_desc, input_ptr, output_desc, out_backprop_ptr,
-                  conv_desc, filter_desc, &filter_backprop_ptr,
-                  &scratch_allocator, AlgorithmConfig(profile_algorithm),
+                  conv_desc, filter_desc, &filter_backprop_ptr_rz,
+                  allocator_used, AlgorithmConfig(profile_algorithm),
                   &profile_result)
               .ok();
       if (cudnn_launch_status && profile_result.is_valid()) {
@@ -991,14 +1017,21 @@
         result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
         result.mutable_conv()->set_tensor_ops_enabled(
             profile_algorithm.tensor_ops_enabled());
-        result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+
+        result.set_scratch_bytes(
+            !RedzoneCheckDisabled()
+                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
+                : scratch_allocator.TotalByteSize());
         *result.mutable_run_time() = proto_utils::ToDurationProto(
             absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+        CheckRedzones(rz_scratch_allocator, &result);
+        CheckRedzones(rz_allocator, &result);
       }
     }
     LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_FILTER,
                            se::dnn::ToDataType<T>::value, input_ptr,
-                           filter_backprop_ptr, out_backprop_ptr, input_desc,
+                           filter_backprop_ptr_rz, out_backprop_ptr, input_desc,
                            filter_desc, output_desc, conv_desc,
                            stream->parent(), results);
     OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 8974aa1..97e7079 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -53,6 +53,11 @@
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#endif  // GOOGLE_CUDA
 
 namespace {
 
@@ -1096,6 +1101,16 @@
   if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
                                 conv_parameters, &algorithm_config)) {
 #if GOOGLE_CUDA
+
+    se::TfAllocatorAdapter tf_allocator_adapter(
+        stream->parent()->platform(), ctx->device()->GetAllocator({}));
+
+    se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
+                                            se::cuda::PtxCompilationOptions());
+
+    se::DeviceMemory<T> in_backprop_ptr_rz(
+        WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
+
     std::vector<AlgorithmDesc> algorithms;
     CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
         conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
@@ -1106,12 +1121,19 @@
       // accuracy.
       DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                             ctx);
+      se::cuda::RedzoneAllocator rz_scratch_allocator(
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+          /*memory_limit=*/ConvolveBackwardDataScratchSize);
+      se::ScratchAllocator* allocator_used =
+          !RedzoneCheckDisabled()
+              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
+              : static_cast<se::ScratchAllocator*>(&scratch_allocator);
       ProfileResult profile_result;
       bool cudnn_launch_status =
           stream
               ->ThenConvolveBackwardDataWithAlgorithm(
                   filter_desc, filter_ptr, output_desc, out_backprop_ptr,
-                  conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+                  conv_desc, input_desc, &in_backprop_ptr_rz, allocator_used,
                   AlgorithmConfig(profile_algorithm), &profile_result)
               .ok();
       if (cudnn_launch_status && profile_result.is_valid()) {
@@ -1120,9 +1142,15 @@
         result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
         result.mutable_conv()->set_tensor_ops_enabled(
             profile_algorithm.tensor_ops_enabled());
-        result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+        result.set_scratch_bytes(
+            !RedzoneCheckDisabled()
+                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
+                : scratch_allocator.TotalByteSize());
         *result.mutable_run_time() = proto_utils::ToDurationProto(
             absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+        CheckRedzones(rz_scratch_allocator, &result);
+        CheckRedzones(rz_allocator, &result);
       }
     }
     LogConvAutotuneResults(
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 6ab5178..0c237d0 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -41,7 +41,14 @@
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/stream_executor.h"
 using stream_executor::dnn::DimIndex;
-#endif
+#include "tensorflow/core/protobuf/autotuning.pb.h"
+#include "tensorflow/core/util/proto/proto_utils.h"
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#endif  // GOOGLE_CUDA
 
 namespace {
 
@@ -1358,6 +1365,12 @@
     if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                    conv_parameters, &algorithm_config)) {
 #if GOOGLE_CUDA
+      se::TfAllocatorAdapter tf_allocator_adapter(
+          stream->parent()->platform(), context->device()->GetAllocator({}));
+      se::cuda::RedzoneAllocator rz_allocator(
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions());
+      se::DeviceMemory<T> in_backprop_ptr_rz(
+          WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
       std::vector<AlgorithmDesc> algorithms;
       CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
           conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
@@ -1365,21 +1378,42 @@
           &algorithms));
       ProfileResult best_result;
       ProfileResult best_result_no_scratch;
+      std::vector<tensorflow::AutotuneResult> results;
       for (auto profile_algorithm : algorithms) {
         // TODO(zhengxq): profile each algorithm multiple times to better
         // accuracy.
         DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                               context);
+        se::cuda::RedzoneAllocator rz_scratch_allocator(
+            stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+            /*memory_limit=*/ConvolveBackwardDataScratchSize);
+        se::ScratchAllocator* allocator_used =
+            !RedzoneCheckDisabled()
+                ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
+                : static_cast<se::ScratchAllocator*>(&scratch_allocator);
         ProfileResult profile_result;
         bool cudnn_launch_status =
             stream
                 ->ThenConvolveBackwardDataWithAlgorithm(
                     filter_desc, filter_ptr, output_desc, out_backprop_ptr,
-                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+                    conv_desc, input_desc, &in_backprop_ptr_rz, allocator_used,
                     AlgorithmConfig(profile_algorithm), &profile_result)
                 .ok();
         if (cudnn_launch_status) {
           if (profile_result.is_valid()) {
+            results.emplace_back();
+            auto& result = results.back();
+            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
+            result.mutable_conv()->set_tensor_ops_enabled(
+                profile_algorithm.tensor_ops_enabled());
+            result.set_scratch_bytes(
+                !RedzoneCheckDisabled()
+                    ? rz_scratch_allocator
+                          .TotalAllocatedBytesExcludingRedzones()
+                    : scratch_allocator.TotalByteSize());
+            *result.mutable_run_time() = proto_utils::ToDurationProto(
+                absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
             if (profile_result.elapsed_time_in_ms() <
                 best_result.elapsed_time_in_ms()) {
               best_result = profile_result;
@@ -1389,9 +1423,17 @@
                     best_result_no_scratch.elapsed_time_in_ms()) {
               best_result_no_scratch = profile_result;
             }
+            // TODO(george): they don't do results at all??
+            CheckRedzones(rz_scratch_allocator, &result);
+            CheckRedzones(rz_allocator, &result);
           }
         }
       }
+      LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
+                             se::dnn::ToDataType<T>::value, in_backprop_ptr,
+                             filter_ptr, out_backprop_ptr, input_desc,
+                             filter_desc, output_desc, conv_desc,
+                             stream->parent(), results);
       OP_REQUIRES(context,
                   best_result.is_valid() || best_result_no_scratch.is_valid(),
                   errors::NotFound("No algorithm worked!"));
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 4ea3186..55cf29b 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -75,7 +75,6 @@
 typedef Eigen::GpuDevice GPUDevice;
 
 namespace {
-
 template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
@@ -578,10 +577,6 @@
 template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-static bool RedzoneCheckDisabled() {
-  const char* disable_rz_str = std::getenv("TF_DISABLE_RZ_CHECK");
-  return disable_rz_str != nullptr && std::strcmp(disable_rz_str, "1") == 0;
-}
 
 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
                            int64 default_value_in_bytes) {
@@ -608,47 +603,6 @@
                           se::dnn::AlgorithmConfig>
     AutoTuneConv;
 
-#if GOOGLE_CUDA
-// Check the passed allocator for redzone violations.
-// If violations have occurred, mark the corresponding autotune result
-// as a failure.
-static void CheckRedzones(const se::cuda::RedzoneAllocator& rz_allocator,
-                          se::Stream* stream,
-                          tensorflow::AutotuneResult* autotune_result) {
-  se::port::StatusOr<se::cuda::RedzoneAllocator::RedzoneCheckStatus> rz_status =
-      rz_allocator.CheckRedzones(stream);
-  if (!rz_status.ok()) {
-    static std::once_flag failure_logged;
-    std::call_once(failure_logged, [&]() {
-      LOG(WARNING) << "Failed to check cudnn convolutions for out-of-bounds "
-                   << "reads and writes with an error message: '"
-                   << rz_status.status().error_message()
-                   << "'; skipping this check. This only means that we won't "
-                   << "check cudnn for out-of-bounds reads and writes. This "
-                   << "message will only be printed once.";
-    });
-    return;
-  }
-  auto rz_check_status = rz_status.ValueOrDie();
-  if (!rz_check_status.ok()) {
-    auto* fail = autotune_result->mutable_failure();
-    fail->set_msg(rz_check_status.RedzoneFailureMsg());
-    fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
-    fail->set_buffer_address(
-        reinterpret_cast<uint64>(rz_check_status.user_buffer_address));
-    LOG(ERROR)
-        << "Detected cudnn out-of-bounds write in convolution buffer! This is "
-           "likely a cudnn bug. We will skip this algorithm in the future, but "
-           "your GPU state may already be corrupted, leading to incorrect "
-           "results. Within Google, no action is needed on your part. Outside "
-           "of Google, please ensure you're running the latest version of "
-           "cudnn. If that doesn't fix the problem, please file a bug with "
-           "this full error message and we'll contact nvidia.";
-    LOG(ERROR) << rz_check_status.RedzoneFailureMsg();
-  }
-}
-#endif  // GOOGLE_CUDA
-
 template <typename T>
 void LaunchConv2DOp<GPUDevice, T>::operator()(
     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
@@ -1002,39 +956,18 @@
 
     se::TfAllocatorAdapter tf_allocator_adapter(
         stream->parent()->platform(), ctx->device()->GetAllocator({}));
-
-    se::cuda::RedzoneAllocator rz_allocator(stream->parent()->device_ordinal(),
-                                            &tf_allocator_adapter,
+    se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                             se::cuda::PtxCompilationOptions());
-
-    se::DeviceMemory<T> output_tensor;
-
-    if (!RedzoneCheckDisabled()) {
-      auto output_rz_or = rz_allocator.AllocateBytes(stream, output_ptr.size());
-      if (!output_rz_or.ok()) {
-        static std::once_flag rz_allocation_failure_logged;
-        std::call_once(rz_allocation_failure_logged, []() {
-          LOG(WARNING)
-              << "Failed to allocate memory for convolution redzone "
-              << "checking; skipping this check. This is benign and only "
-              << "means that we won't check cudnn for out-of-bounds reads "
-              << "and writes. This message will only be printed once.";
-        });
-        output_tensor = output_ptr;
-      } else {
-        output_tensor = se::DeviceMemory<T>(output_rz_or.ValueOrDie());
-      }
-    } else {
-      output_tensor = output_ptr;
-    }
+    se::DeviceMemory<T> output_tensor(
+        WrapRedzoneBestEffort(&rz_allocator, output_ptr));
 
     std::vector<tensorflow::AutotuneResult> results;
     for (auto profile_algorithm : algorithms) {
       // TODO(zhengxq): profile each algorithm multiple times to better
       // accuracy.
       se::cuda::RedzoneAllocator rz_scratch_allocator(
-          stream->parent()->device_ordinal(), &tf_allocator_adapter,
-          se::cuda::PtxCompilationOptions());
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+          /*memory_limit=*/ConvolveScratchSize);
       DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
       se::ScratchAllocator* allocator_used =
           !RedzoneCheckDisabled()
@@ -1057,12 +990,14 @@
             profile_algorithm.tensor_ops_enabled());
 
         result.set_scratch_bytes(
-            rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones());
+            !RedzoneCheckDisabled()
+                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
+                : scratch_allocator.TotalByteSize());
         *result.mutable_run_time() = proto_utils::ToDurationProto(
             absl::Milliseconds(profile_result.elapsed_time_in_ms()));
 
-        CheckRedzones(rz_scratch_allocator, stream, &result);
-        CheckRedzones(rz_allocator, stream, &result);
+        CheckRedzones(rz_scratch_allocator, &result);
+        CheckRedzones(rz_allocator, &result);
       }
     }
     LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 076db5c..52c5e80 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -36,7 +36,12 @@
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
 using stream_executor::dnn::DimIndex;
-#endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#endif  // GOOGLE_CUDA
 
 namespace tensorflow {
 
@@ -436,6 +441,12 @@
     if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
 #if GOOGLE_CUDA
+      se::TfAllocatorAdapter tf_allocator_adapter(
+          stream->parent()->platform(), ctx->device()->GetAllocator({}));
+      se::cuda::RedzoneAllocator rz_allocator(
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions());
+      se::DeviceMemory<T> output_ptr_rz(
+          WrapRedzoneBestEffort(&rz_allocator, output_ptr));
       std::vector<AlgorithmDesc> algorithms;
       OP_REQUIRES(ctx,
                   stream->parent()->GetConvolveAlgorithms(
@@ -452,12 +463,19 @@
         // TODO(zhengxq): profile each algorithm multiple times to better
         // accuracy.
         DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+        se::cuda::RedzoneAllocator rz_scratch_allocator(
+            stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+            /*memory_limit=*/ConvolveScratchSize);
+        se::ScratchAllocator* allocator_used =
+            !RedzoneCheckDisabled()
+                ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
+                : static_cast<se::ScratchAllocator*>(&scratch_allocator);
         ProfileResult profile_result;
         bool cudnn_launch_status =
             stream
                 ->ThenConvolveWithAlgorithm(
                     input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
-                    output_desc, &output_ptr, &scratch_allocator,
+                    output_desc, &output_ptr_rz, allocator_used,
                     AlgorithmConfig(profile_algorithm), &profile_result)
                 .ok();
         if (cudnn_launch_status) {
@@ -467,9 +485,15 @@
             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
             result.mutable_conv()->set_tensor_ops_enabled(
                 profile_algorithm.tensor_ops_enabled());
-            result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+            result.set_scratch_bytes(
+                !RedzoneCheckDisabled()
+                    ? rz_scratch_allocator
+                          .TotalAllocatedBytesExcludingRedzones()
+                    : scratch_allocator.TotalByteSize());
             *result.mutable_run_time() = proto_utils::ToDurationProto(
                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+            CheckRedzones(rz_scratch_allocator, &result);
+            CheckRedzones(rz_allocator, &result);
           }
         }
       }
diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h
index 8fba8ce..e65d6f6 100644
--- a/tensorflow/core/kernels/conv_ops_fused_impl.h
+++ b/tensorflow/core/kernels/conv_ops_fused_impl.h
@@ -61,6 +61,9 @@
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {
@@ -304,6 +307,7 @@
 Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
                                  const ConvLaunch launch,
                                  OpKernelContext* context, se::Stream* stream,
+                                 se::DeviceMemory<T> output_ptr,
                                  const LogFunc& log,
                                  se::dnn::AlgorithmConfig* algorithm_config) {
   // Check if we already have an algorithm selected for the given parameters.
@@ -322,14 +326,28 @@
         "see if a warning log message was printed above.");
   }
 
+  se::TfAllocatorAdapter tf_allocator_adapter(
+      stream->parent()->platform(), context->device()->GetAllocator({}));
+  se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
+                                          se::cuda::PtxCompilationOptions());
+  se::DeviceMemory<T> output_ptr_rz(
+      WrapRedzoneBestEffort(&rz_allocator, output_ptr));
+
   std::vector<tensorflow::AutotuneResult> results;
   for (auto profile_algorithm : algorithms) {
     DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
+    se::cuda::RedzoneAllocator rz_scratch_allocator(
+        stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
+        /*memory_limit=*/ConvolveScratchSize());
+    se::ScratchAllocator* allocator_used =
+        !RedzoneCheckDisabled()
+            ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
+            : static_cast<se::ScratchAllocator*>(&scratch_allocator);
     se::dnn::ProfileResult profile_result;
 
     bool cudnn_launch_status =
-        launch(se::dnn::AlgorithmConfig(profile_algorithm), &scratch_allocator,
-               &profile_result);
+        launch(se::dnn::AlgorithmConfig(profile_algorithm), allocator_used,
+               output_ptr_rz, &profile_result);
 
     if (cudnn_launch_status && profile_result.is_valid()) {
       results.emplace_back();
@@ -337,9 +355,14 @@
       result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
       result.mutable_conv()->set_tensor_ops_enabled(
           profile_algorithm.tensor_ops_enabled());
-      result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+      result.set_scratch_bytes(
+          !RedzoneCheckDisabled()
+              ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
+              : scratch_allocator.TotalByteSize());
       *result.mutable_run_time() = proto_utils::ToDurationProto(
           absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+      CheckRedzones(rz_scratch_allocator, &result);
+      CheckRedzones(rz_allocator, &result);
     }
   }
   // Only log on an AutoTuneFusedConv cache miss.
@@ -588,7 +611,8 @@
     // Launch fused convolution with given parameters and scratch allocator.
     // Record profile result into `profile_result` if it's not nullptr.
     const auto launch = [&](se::dnn::AlgorithmConfig algorithm_config,
-                            DnnScratchAllocator* scratch_allocator,
+                            se::ScratchAllocator* scratch_allocator,
+                            se::DeviceMemory<T> output_ptr_to_use,
                             se::dnn::ProfileResult* profile_result) -> bool {
       return stream
           ->ThenFusedConvolveWithAlgorithm(
@@ -599,7 +623,7 @@
               side_input_ptr, /*side_input_scale=*/0.0,  // side_input
               bias_desc, bias_ptr,                       // bias
               dnn_activation_mode,                       // activation
-              output_desc, &output_ptr,                  // output
+              output_desc, &output_ptr_to_use,           // output
               scratch_allocator, algorithm_config, profile_result)
           .ok();
     };
@@ -607,7 +631,7 @@
     se::dnn::AlgorithmConfig algorithm_config;
     if (cudnn_use_autotune) {
       auto status = FindBestConvolveAlgorithm<T>(
-          conv_parameters, launch, context, stream,
+          conv_parameters, launch, context, stream, output_ptr,
           [&](absl::Span<const tensorflow::AutotuneResult> results) {
             LogFusedConvForwardAutotuneResults(
                 se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
@@ -621,7 +645,7 @@
 
     DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
     bool cudnn_launch_status = launch(algorithm_config, &scratch_allocator,
-                                      /*profile_result=*/nullptr);
+                                      output_ptr, /*profile_result=*/nullptr);
     OP_REQUIRES(
         context, cudnn_launch_status,
         errors::Internal(absl::Substitute(
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 7906f74..2ccc0f3 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -20,6 +20,7 @@
 
 #include <tuple>
 #include <unordered_map>
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/kernels/gpu_utils.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -50,11 +51,9 @@
   virtual ~DnnScratchAllocator() {}
   DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
       : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
-    return memory_limit_;
-  }
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
     if (byte_size < 0) {
       return se::port::Status{se::port::error::INVALID_ARGUMENT,
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 9679fad..104ee09 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -28,6 +28,7 @@
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cusolverDn.h"
 #endif
 #include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index 09826f5..4a27394 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -363,12 +363,11 @@
 
   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
       : context_(context) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
 
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     Tensor temporary_memory;
     const DataType tf_data_type = ToTFDataType<T>::value;
     int64 allocate_count =
@@ -409,11 +408,10 @@
   ~CudnnRnnAllocatorInOutput() override {}
   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
       : context_(context), output_index_(output_index) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     CHECK(total_byte_size_ == 0)
         << "Reserve space allocator can only be called once";
     int64 allocate_count =
@@ -449,12 +447,11 @@
 
   ~CudnnRNNPersistentSpaceAllocator() override {}
 
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
 
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     if (total_byte_size_ != 0) {
       return Status(error::FAILED_PRECONDITION,
                     "Persistent space allocator can only be called once");
@@ -944,6 +941,20 @@
   }
 }
 
+bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
+                       const CudnnRnnModelShapes& model_shapes,
+                       bool time_major) {
+  auto seq_array = sequence_lengths->template flat<int>().data();
+  bool all_max_seq_length = true;
+  for (int i = 0; i < model_shapes.batch_size; i++) {
+    if (seq_array[i] != model_shapes.max_seq_length) {
+      all_max_seq_length = false;
+      break;
+    }
+  }
+  return !(time_major && all_max_seq_length);
+}
+
 }  // namespace
 
 // Note: all following kernels depend on a RnnDescriptor instance, which
@@ -1027,7 +1038,7 @@
         num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
         /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
         ToDataType<T>::value, algo_config, dropout(), seed(),
-        /* state_allocator=*/nullptr);
+        /* state_allocator=*/nullptr, /*use_padded_io=*/false);
     if (!rnn_desc_s.ok()) {
       return FromExecutorStatus(rnn_desc_s);
     }
@@ -1041,14 +1052,16 @@
                              const RnnInputMode& input_mode,
                              const AlgorithmConfig& algo_config,
                              ScratchAllocator* dropout_state_allocator,
-                             std::unique_ptr<RnnDescriptor>* rnn_desc) {
+                             std::unique_ptr<RnnDescriptor>* rnn_desc,
+                             bool use_padded_io) {
     StreamExecutor* executor = context->op_device_context()->stream()->parent();
     se::dnn::DataType data_type = ToDataType<T>::value;
     auto rnn_desc_s = executor->createRnnDescriptor(
         model_shapes.num_layers, model_shapes.num_units,
         model_shapes.input_size, model_shapes.cell_num_units,
         model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
-        data_type, algo_config, dropout(), seed(), dropout_state_allocator);
+        data_type, algo_config, dropout(), seed(), dropout_state_allocator,
+        use_padded_io);
     TF_RETURN_IF_ERROR(rnn_desc_s.status());
 
     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
@@ -1065,17 +1078,17 @@
                                 const CudnnRnnModelShapes& model_shapes,
                                 const RnnInputMode& input_mode,
                                 const AlgorithmConfig& algo_config,
-                                RnnStateCache* cache,
-                                RnnDescriptor** rnn_desc) {
+                                RnnStateCache* cache, RnnDescriptor** rnn_desc,
+                                bool use_padded_io) {
     auto key = std::make_pair(model_shapes, algo_config.algorithm());
     RnnScratchSpace& rnn_state = (*cache)[key];
     if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
       CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
           new CudnnRNNPersistentSpaceAllocator(context);
       rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
-      Status status =
-          CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
-                                 dropout_state_allocator, &rnn_state.rnn_desc);
+      Status status = CreateRnnDescriptor<T>(
+          context, model_shapes, input_mode, algo_config,
+          dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
       TF_RETURN_IF_ERROR(status);
     }
     *rnn_desc = rnn_state.rnn_desc.get();
@@ -1444,11 +1457,14 @@
     const Tensor* params = nullptr;
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
+    bool use_padded_io = false;
     if (var_seq_lengths) {
       OP_REQUIRES_OK(context, ExtractForwardInput(
                                   context, model_types(), time_major, &input,
                                   &input_h, &input_c, &params,
                                   &sequence_lengths, num_proj, &model_shapes));
+      use_padded_io =
+          ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
     } else {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
@@ -1488,10 +1504,10 @@
     {
       mutex_lock l(mu_);
       RnnDescriptor* rnn_desc_ptr = nullptr;
-      OP_REQUIRES_OK(
-          context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
-                                             *output_algo_config,
-                                             &rnn_state_cache_, &rnn_desc_ptr));
+      OP_REQUIRES_OK(context,
+                     GetCachedRnnDescriptor<T>(
+                         context, model_shapes, input_mode, *output_algo_config,
+                         &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
       launch_status = DoForward<T>(
           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
           input_c, params, is_training_, output, output_h, output_c,
@@ -1690,7 +1706,8 @@
       CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
       if (!this->template CreateRnnDescriptor<T>(
                    context, model_shapes, input_mode, AlgorithmConfig(algo),
-                   &dropout_state_allocator, &rnn_desc)
+                   &dropout_state_allocator, &rnn_desc,
+                   /*use_padded_io=*/false)
                .ok()) {
         continue;
       }
@@ -1840,11 +1857,14 @@
     const Tensor* params = nullptr;
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
+    bool use_padded_io = false;
     if (var_seq_lengths) {
       OP_REQUIRES_OK(context, ExtractForwardInput(
                                   context, model_types(), time_major, &input,
                                   &input_h, &input_c, &params,
                                   &sequence_lengths, num_proj, &model_shapes));
+      use_padded_io =
+          ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
     } else {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
@@ -1890,7 +1910,7 @@
       OP_REQUIRES_OK(
           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
                                              algo_config, &rnn_state_cache_,
-                                             &rnn_desc_ptr));
+                                             &rnn_desc_ptr, use_padded_io));
       launch_status = DoBackward<T>(
           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
           input_c, params, output, output_h, output_c, output_backprop,
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 563bb7d..062a029 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -18,8 +18,7 @@
 namespace tensorflow {
 REGISTER5(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
           bfloat16, int32);
-REGISTER5(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16,
-          bfloat16);
+REGISTER4(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 1998fc0..43af038 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -18,8 +18,8 @@
 namespace tensorflow {
 REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
           bfloat16, double, int32);
-REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, int64, uint8, int8,
-          int16, bfloat16);
+REGISTER4(BinaryOp, CPU, "LessEqual", functor::less_equal, int64, uint8, int8,
+          int16);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index a5f41b6..90ad751 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -180,6 +180,7 @@
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -794,6 +795,7 @@
     hdrs = ["shuffle_dataset_op.h"],
     deps = [
         ":name_utils",
+        ":random_seed_ops",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -1154,6 +1156,7 @@
     srcs = ["cache_dataset_ops.cc"],
     hdrs = ["cache_dataset_ops.h"],
     deps = [
+        ":cache_ops",
         ":name_utils",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
@@ -1318,3 +1321,31 @@
         "//tensorflow/core/kernels:function_ops",
     ],
 )
+
+tf_kernel_library(
+    name = "random_seed_ops",
+    srcs = ["random_seed_ops.cc"],
+    hdrs = ["random_seed_ops.h"],
+    deps = [
+        ":dataset_utils",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:functional_ops_op_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "cache_ops",
+    srcs = ["cache_ops.cc"],
+    hdrs = ["cache_ops.h"],
+    deps = [
+        ":dataset_utils",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:functional_ops_op_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index e527674..3e6f8b8 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -101,6 +101,10 @@
     return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/batch_dataset_op_test.cc b/tensorflow/core/kernels/data/batch_dataset_op_test.cc
index 0addd6e..6baa5d7 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op_test.cc
@@ -72,43 +72,39 @@
 // Test Case 1: test BatchDatasetV2 with `drop_remainder` = false and a batch
 // size that can evenly split the input dataset.
 TestCase TestCase1() {
-  return {
-      /*range_data_param*/ {0, 12, 1},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-      /*parallel_copy*/ true,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}), {0, 1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}), {4, 5, 6, 7}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}),
-                                               {8, 9, 10, 11})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({4})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 1, 5}};
+  return {/*range_data_param*/ {0, 12, 1},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape({}), {4}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {false}),
+          /*parallel_copy*/ true,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape({4}), {0, 1, 2, 3}),
+           CreateTensor<int64>(TensorShape({4}), {4, 5, 6, 7}),
+           CreateTensor<int64>(TensorShape({4}), {8, 9, 10, 11})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({4})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 1, 5}};
 }
 
 // Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
 // size that can evenly split the input dataset.
 TestCase TestCase2() {
-  return {
-      /*range_data_param*/ {0, 12, 1},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*parallel_copy*/ false,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}), {0, 1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}), {4, 5, 6, 7}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({4}),
-                                               {8, 9, 10, 11})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({4})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 1, 5}};
+  return {/*range_data_param*/ {0, 12, 1},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape({}), {4}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*parallel_copy*/ false,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape({4}), {0, 1, 2, 3}),
+           CreateTensor<int64>(TensorShape({4}), {4, 5, 6, 7}),
+           CreateTensor<int64>(TensorShape({4}), {8, 9, 10, 11})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({4})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 1, 5}};
 }
 
 // Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
@@ -116,15 +112,15 @@
 TestCase TestCase3() {
   return {/*range_data_param*/ {0, 10, 1},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+          CreateTensor<int64>(TensorShape({}), {3}),
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
+          CreateTensor<bool>(TensorShape({}), {false}),
           /*parallel_copy*/ false,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1}), {9})},
+          {CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
+           CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
+           CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
+           CreateTensor<int64>(TensorShape({1}), {9})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({-1})},
           /*expected_cardinality*/ 4,
@@ -134,21 +130,20 @@
 // Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
 // size that can not evenly split the input dataset.
 TestCase TestCase4() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*parallel_copy*/ true,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {6, 7, 8})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({3})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 1, 5}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape({}), {3}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*parallel_copy*/ true,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
+           CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
+           CreateTensor<int64>(TensorShape({3}), {6, 7, 8})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({3})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 1, 5}};
 }
 
 // Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
@@ -156,9 +151,9 @@
 TestCase TestCase5() {
   return {/*range_data_param*/ {0, 10, 1},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
+          CreateTensor<int64>(TensorShape({}), {12}),
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
+          CreateTensor<bool>(TensorShape({}), {true}),
           /*parallel_copy*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -170,19 +165,19 @@
 // Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
 // `batch_size` > the cardinality of the input dataset.
 TestCase TestCase6() {
-  return {/*range_data_param*/ {0, 10, 1},
-          /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-          /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-          /*parallel_copy*/ true,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape({10}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({-1})},
-          /*expected_cardinality*/ 1,
-          /*breakpoints*/ {0, 1, 5}};
+  return {
+      /*range_data_param*/ {0, 10, 1},
+      /*batch_size*/
+      CreateTensor<int64>(TensorShape({}), {12}),
+      /*drop_remainder*/
+      CreateTensor<bool>(TensorShape({}), {false}),
+      /*parallel_copy*/ true,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape({10}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({-1})},
+      /*expected_cardinality*/ 1,
+      /*breakpoints*/ {0, 1, 5}};
 }
 
 // Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
@@ -190,9 +185,9 @@
 TestCase TestCase7() {
   return {/*range_data_param*/ {0, 0, 1},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+          CreateTensor<int64>(TensorShape({}), {4}),
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
+          CreateTensor<bool>(TensorShape({}), {false}),
           /*parallel_copy*/ false,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -205,9 +200,9 @@
 TestCase InvalidBatchSizeTestCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
+          CreateTensor<int64>(TensorShape({}), {-1}),
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
+          CreateTensor<bool>(TensorShape({}), {false}),
           /*parallel_copy*/ false,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -454,46 +449,6 @@
   EXPECT_EQ(batch_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedBatchDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  std::unique_ptr<OpKernel> batch_dataset_kernel;
-  TF_ASSERT_OK(CreateBatchDatasetOpKernel(
-      test_case.parallel_copy, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &batch_dataset_kernel));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.range_dataset_param.start, test_case.range_dataset_param.end,
-      test_case.range_dataset_param.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-
-  Tensor batch_size = test_case.batch_size;
-  Tensor drop_remainder = test_case.drop_remainder;
-  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&range_dataset_tensor),
-                                            TensorValue(&batch_size),
-                                            TensorValue(&drop_remainder)};
-  std::unique_ptr<OpKernelContext> batch_dataset_context;
-  TF_ASSERT_OK(CreateBatchDatasetContext(batch_dataset_kernel.get(), &inputs,
-                                         &batch_dataset_context));
-  DatasetBase* batch_dataset;
-  TF_ASSERT_OK(CreateDataset(batch_dataset_kernel.get(),
-                             batch_dataset_context.get(), &batch_dataset));
-  core::ScopedUnref scoped_unref_batch_dataset(batch_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(batch_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedBatchDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 9b1fed9..2fb1c6f 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/cache_ops.h"
 #include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
@@ -54,6 +55,7 @@
 constexpr char kCacheCompleted[] = "cache_completed";
 constexpr char kIndex[] = "index";
 constexpr char kImpl[] = "Impl";
+constexpr char kCacheDataset[] = "CacheDataset";
 
 class CacheDatasetOp::FileDataset : public DatasetBase {
  public:
@@ -99,6 +101,10 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -111,6 +117,9 @@
     return Status::OK();
   }
 
+  const DatasetBase* const input_;
+  const string filename_;
+
  private:
   static size_t StringPaddingSize(size_t num_tensors) {
     return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size();
@@ -215,6 +224,25 @@
             lockfile_created_(false),
             iteration_completed_(false) {}
 
+      ~FileWriterIterator() {
+        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
+          std::vector<string> cache_files;
+          Status s = dataset()->env_->GetMatchingPaths(
+              strings::StrCat(filename_, "*"), &cache_files);
+          if (!s.ok()) {
+            LOG(WARNING) << "Failed to get matching files on " << filename_
+                         << "* : " << s.ToString();
+          }
+          for (const string& path : cache_files) {
+            s = dataset()->env_->DeleteFile(path);
+            if (!s.ok()) {
+              LOG(WARNING) << "Failed to delete " << path << " : "
+                           << s.ToString();
+            }
+          }
+        }
+      }
+
       Status Initialize(IteratorContext* ctx) override {
         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
       }
@@ -275,6 +303,9 @@
 
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name(kCurIndex), cur_index_));
+
         if (iteration_completed_) {
           TF_RETURN_IF_ERROR(
               writer->WriteScalar(full_name(kIterationCompleted), ""));
@@ -301,8 +332,6 @@
           lockfile_created_ = false;
         }
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(kCurIndex), cur_index_));
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
         return Status::OK();
       }
@@ -310,12 +339,6 @@
       Status RestoreInternal(IteratorContext* ctx,
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        if (reader->Contains(full_name(kIterationCompleted))) {
-          iteration_completed_ = true;
-          return Status::OK();
-        }
-
-        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
         int64 temp;
         // TODO(b/78048575): Update this when saving size_t tensors directly
         // is supported.
@@ -326,6 +349,14 @@
             return errors::Internal("Invalid value for cur_index ", temp);
           }
         }
+
+        if (reader->Contains(full_name(kIterationCompleted))) {
+          iteration_completed_ = true;
+          return Status::OK();
+        }
+
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+
         // TODO(b/78048575): Update this when saving size_t tensors directly
         // is supported.
         {
@@ -348,7 +379,9 @@
           *end_of_sequence = true;
           return Status::OK();
         }
-        if (lockfile_created_ && !iteration_completed_) return Status::OK();
+        if (lockfile_created_) {
+          return Status::OK();
+        }
 
         // Perform rudimentary locking to help catch concurrent writes to the
         // same cache files.
@@ -409,7 +442,7 @@
         // Merge all the bundles.
         // Currently there are `shard_id_ + 1` bundles, one for each
         // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
-        // integer starting at 0 an incremented by 1 for each new checkpoint.
+        // integer starting at 0 and incremented by 1 for each new checkpoint.
         // We merge all these bundles into a bundle with prefix <filename> so
         // that the next call to `MakeIterator` can build a
         // `FileReaderIterator`.
@@ -562,8 +595,6 @@
     std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
   };  // FileIterator
 
-  const DatasetBase* const input_;
-  const string filename_;
   Env* const env_;
   const size_t num_tensors_;
   const size_t tensor_index_padding_size_;
@@ -572,21 +603,56 @@
   const string tensor_format_string_;
 };  // FileDataset
 
-class CacheDatasetOp::MemoryDataset : public DatasetBase {
+class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
  public:
-  explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
-      : DatasetBase(DatasetContext(ctx)), input_(input) {
-    input->Ref();
+  explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
+                         string filename, Env* env,
+                         const Tensor& resource_handle)
+      : FileDataset(ctx, input, filename, env),
+        resource_handle_(resource_handle) {}
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+    Node* filename_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename_node));
+    Node* resource_handle_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+    TF_RETURN_IF_ERROR(b->AddDataset(
+        this, {input_node, filename_node, resource_handle_node}, output));
+    return Status::OK();
   }
 
-  ~MemoryDataset() override { input_->Unref(); }
+ private:
+  const Tensor resource_handle_;
+};
+
+class CacheDatasetOp::MemoryDataset : public DatasetBase {
+ public:
+  explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
+                         MemoryCache* cache)
+      : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
+    input_->Ref();
+  }
+
+  ~MemoryDataset() override {
+    input_->Unref();
+    if (cache_) {
+      cache_->Unref();
+    }
+  }
 
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
     name_utils::IteratorPrefixParams params;
     params.dataset_prefix = kMemoryDatasetPrefix;
-    return absl::make_unique<MemoryIterator>(MemoryIterator::Params{
-        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
+    return absl::make_unique<MemoryIterator>(
+        MemoryIterator::Params{
+            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
+        cache_);
   }
 
   const DataTypeVector& output_dtypes() const override {
@@ -605,6 +671,10 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -618,102 +688,32 @@
     return Status::OK();
   }
 
- private:
-  // A thread-safe data structure for caching dataset elements.
-  //
-  // The expected use is that a single `MemoryWriterIterator` populates the
-  // cache with dataset elements. Once all elements are cached, the cache can
-  // be used by one or more `MemoryReaderIterator`s.
-  class MemoryCache : public ResourceBase {
-   public:
-    MemoryCache() = default;
-
-    string DebugString() const override { return "CacheDataset::MemoryCache"; }
-
-    // Marks the cache as completed.
-    void Complete() {
-      mutex_lock l(mu_);
-      completed_ = true;
-    }
-
-    // Returns whether the cache is claimed.
-    bool IsClaimed() {
-      tf_shared_lock l(mu_);
-      return claimed_;
-    }
-
-    // Returns whether the cache is completed.
-    bool IsCompleted() {
-      tf_shared_lock l(mu_);
-      return completed_;
-    }
-
-    // Attempts to claim the cache, returning whether the cache was claimed.
-    bool MaybeClaim() {
-      mutex_lock l(mu_);
-      if (!claimed_) {
-        claimed_ = true;
-        return true;
-      }
-      return false;
-    }
-
-    // Resets the cache.
-    void Reset() {
-      mutex_lock l(mu_);
-      claimed_ = false;
-      completed_ = false;
-      cache_.clear();
-    }
-
-    // Returns the element at the given index.
-    const std::vector<Tensor>& at(int64 index) {
-      tf_shared_lock l(mu_);
-      DCHECK(index < cache_.size());
-      return cache_[index];
-    }
-
-    // Adds the element to the cache.
-    void emplace_back(std::vector<Tensor> element) {
-      mutex_lock l(mu_);
-      cache_.emplace_back(std::move(element));
-    }
-
-    // Returns the size of the cache.
-    size_t size() {
-      tf_shared_lock l(mu_);
-      return cache_.size();
-    }
-
-   private:
-    mutex mu_;
-    // Determines whether a writer has claimed the cache.
-    bool claimed_ GUARDED_BY(mu_) = false;
-    // Determines whether all elements of the dataset have been cached.
-    bool completed_ GUARDED_BY(mu_) = false;
-    std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
-  };
-
   class MemoryIterator : public DatasetIterator<MemoryDataset> {
    public:
-    explicit MemoryIterator(const Params& params)
-        : DatasetIterator<MemoryDataset>(params) {}
+    explicit MemoryIterator(const Params& params, MemoryCache* cache)
+        : DatasetIterator<MemoryDataset>(params), cache_(cache) {}
 
-    ~MemoryIterator() override { cache_->Unref(); }
+    ~MemoryIterator() override {
+      if (dataset()->cache_ == nullptr) {
+        cache_->Unref();
+      }
+    }
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
-      // Use the resource manager in the iterator context to get / create
-      // a cache.
-      ResourceMgr* mgr = ctx->resource_mgr();
-      const string name = strings::StrCat(prefix(), name_utils::kDelimiter,
-                                          dataset()->node_name(),
-                                          name_utils::kDelimiter, kMemoryCache);
-      TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
-          kTFData, name, &cache_, [](MemoryCache** cache) {
-            *cache = new MemoryCache();
-            return Status::OK();
-          }));
+      if (cache_ == nullptr) {
+        // Use the resource manager in the iterator context to get / create
+        // a cache.
+        ResourceMgr* mgr = ctx->resource_mgr();
+        const string name = strings::StrCat(
+            prefix(), name_utils::kDelimiter, dataset()->node_name(),
+            name_utils::kDelimiter, kMemoryCache);
+        TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
+            kTFData, name, &cache_, [](MemoryCache** cache) {
+              *cache = new MemoryCache();
+              return Status::OK();
+            }));
+      }
       mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
       InitializeIterator();
       if (mode_ == Mode::read && !cache_->IsCompleted()) {
@@ -966,10 +966,42 @@
   };  // MemoryIterator
 
   const DatasetBase* const input_;
+  MemoryCache* cache_ = nullptr;
 };  // MemoryDataset
 
+class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
+ public:
+  explicit MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
+                           MemoryCache* cache, const Tensor& resource_handle)
+      : MemoryDataset(ctx, input, cache), resource_handle_(resource_handle) {}
+
+  Status CheckExternalState() const override {
+    return errors::FailedPrecondition(DebugString(),
+                                      " depends on memory cache resource.");
+  }
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+    Node* filename_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node));
+    Node* resource_handle_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+    TF_RETURN_IF_ERROR(b->AddDataset(
+        this, {input_node, filename_node, resource_handle_node}, output));
+    return Status::OK();
+  }
+
+ private:
+  const Tensor resource_handle_;
+};
+
 CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
-    : UnaryDatasetOpKernel(ctx) {}
+    : UnaryDatasetOpKernel(ctx),
+      op_version_(ctx->def().op() == kCacheDataset ? 1 : 2) {}
 
 void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                  DatasetBase** output) {
@@ -978,15 +1010,29 @@
   OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kFileName, &filename));
 
   if (filename.empty()) {
-    *output = new MemoryDataset(ctx, input);
+    if (op_version_ == 2) {
+      MemoryCache* cache = nullptr;
+      OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &cache));
+      // Transferring cache reference ownership onto `MemoryDatasetV2`.
+      *output = new MemoryDatasetV2(ctx, input, cache, ctx->input(2));
+    } else {
+      *output = new MemoryDataset(ctx, input, /*cache=*/nullptr);
+    }
   } else {
-    *output = new FileDataset(ctx, input, filename, ctx->env());
+    if (op_version_ == 2) {
+      *output =
+          new FileDatasetV2(ctx, input, filename, ctx->env(), ctx->input(2));
+    } else {
+      *output = new FileDataset(ctx, input, filename, ctx->env());
+    }
   }
 }
 
 namespace {
 REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                         CacheDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("CacheDatasetV2").Device(DEVICE_CPU),
+                        CacheDatasetOp);
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.h b/tensorflow/core/kernels/data/cache_dataset_ops.h
index af023a6..484d048 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.h
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.h
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OP_H_
-#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
 
 #include "tensorflow/core/framework/dataset.h"
 
@@ -22,6 +22,9 @@
 
 class CacheDatasetOp : public UnaryDatasetOpKernel {
  public:
+  class FileDataset;
+  class MemoryDataset;
+
   static constexpr const char* const kDatasetType = "Cache";
   static constexpr const char* const kInputDataset = "input_dataset";
   static constexpr const char* const kFileName = "filename";
@@ -35,11 +38,13 @@
                    DatasetBase** output) override;
 
  private:
-  class FileDataset;
-  class MemoryDataset;
+  class FileDatasetV2;
+  class MemoryDatasetV2;
+
+  int op_version_;
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OP_H_
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc
index 812d719..53d455b 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc
@@ -23,6 +23,25 @@
 constexpr char kMemoryDatasetPrefix[] = "Memory";
 
 class CacheDatasetOpTest : public DatasetOpsTestBase {
+ public:
+  ~CacheDatasetOpTest() {
+    if (!filename_.empty()) {
+      std::vector<string> cache_files;
+      Status s = device_->env()->GetMatchingPaths(
+          strings::StrCat(filename_, "*"), &cache_files);
+      if (!s.ok()) {
+        LOG(WARNING) << "Failed to get matching files on " << filename_
+                     << "* : " << s.ToString();
+      }
+      for (const string& path : cache_files) {
+        s = device_->env()->DeleteFile(path);
+        if (!s.ok()) {
+          LOG(WARNING) << "Failed to delete " << path << " : " << s.ToString();
+        }
+      }
+    }
+  }
+
  protected:
   // Creates `TensorSliceDataset` variant tensor from the input vector of
   // tensors.
@@ -57,8 +76,13 @@
       std::unique_ptr<OpKernelContext>* context) {
     TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
     TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    TF_RETURN_IF_ERROR(ParseScalarArgument<string>(
+        context->get(), CacheDatasetOp::kFileName, &filename_));
     return Status::OK();
   }
+
+ private:
+  string filename_ = "";
 };
 
 struct TestCase {
@@ -73,58 +97,54 @@
 
 // Test case 1: cache data in file.
 TestCase TestCase1() {
-  return {
-      /*input_tensors*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-      /*file_name*/ absl::StrCat(testing::TmpDir(), "/cache_data"),
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 4, 11}};
+  return {/*input_tensors*/ {CreateTensor<int64>(TensorShape{3, 3, 1},
+                                                 {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+          /*file_name*/ absl::StrCat(testing::TmpDir(), "/cache_data"),
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
+           CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
+           CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 2, 4, 11}};
 }
 
 // Test case 2: cache empty data in file.
 TestCase TestCase2() {
-  return {/*input_tensors*/ {
-              DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
+  return {/*input_tensors*/ {CreateTensor<int64>(TensorShape{0}, {})},
           /*file_name*/ absl::StrCat(testing::TmpDir(), "/empty_cache_data"),
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 4, 11}};
+          /*breakpoints*/ {0, 2, 4, 11}};
 }
 
 // Test case 3: cache data in memory.
 TestCase TestCase3() {
-  return {
-      /*input_tensors*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-      /*file_name*/ "",
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 4, 11}};
+  return {/*input_tensors*/ {CreateTensor<int64>(TensorShape{3, 3, 1},
+                                                 {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+          /*file_name*/ "",
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
+           CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
+           CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 2, 4, 11}};
 }
 
 // Test case 4: cache empty data in memory.
 TestCase TestCase4() {
-  return {/*input_tensors*/ {
-              DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
+  return {/*input_tensors*/ {CreateTensor<int64>(TensorShape{0}, {})},
           /*file_name*/ "",
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 4, 11}};
+          /*breakpoints*/ {0, 2, 4, 11}};
 }
 
 class ParameterizedCacheDatasetOpTest
@@ -333,39 +353,6 @@
   EXPECT_EQ(cache_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedCacheDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  std::unique_ptr<OpKernel> cache_dataset_kernel;
-  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
-                                          test_case.expected_output_shapes,
-                                          &cache_dataset_kernel));
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor file_name = CreateTensor<string>(TensorShape{}, {test_case.file_name});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
-  std::unique_ptr<OpKernelContext> cache_dataset_context;
-  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
-                                         &cache_dataset_context));
-  DatasetBase* cache_dataset;
-  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
-                             cache_dataset_context.get(), &cache_dataset));
-  core::ScopedUnref scoped_unref(cache_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(cache_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedCacheDatasetOpTest, IteratorOutputShapes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc
new file mode 100644
index 0000000..2d77e03
--- /dev/null
+++ b/tensorflow/core/kernels/data/cache_ops.cc
@@ -0,0 +1,125 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/cache_ops.h"
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const char kMemoryCache[] = "MemoryCache";
+
+}  // namespace
+
+string MemoryCache::DebugString() const { return kMemoryCache; }
+
+void MemoryCache::Complete() {
+  mutex_lock l(mu_);
+  completed_ = true;
+}
+
+bool MemoryCache::IsClaimed() {
+  tf_shared_lock l(mu_);
+  return claimed_;
+}
+
+bool MemoryCache::IsCompleted() {
+  tf_shared_lock l(mu_);
+  return completed_;
+}
+
+bool MemoryCache::MaybeClaim() {
+  mutex_lock l(mu_);
+  if (!claimed_) {
+    claimed_ = true;
+    return true;
+  }
+  return false;
+}
+
+void MemoryCache::Reset() {
+  mutex_lock l(mu_);
+  claimed_ = false;
+  completed_ = false;
+  cache_.clear();
+}
+
+const std::vector<Tensor>& MemoryCache::at(int64 index) {
+  tf_shared_lock l(mu_);
+  DCHECK(index < cache_.size());
+  return cache_[index];
+}
+
+void MemoryCache::emplace_back(std::vector<Tensor> element) {
+  mutex_lock l(mu_);
+  cache_.emplace_back(std::move(element));
+}
+
+size_t MemoryCache::size() {
+  tf_shared_lock l(mu_);
+  return cache_.size();
+}
+
+AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
+    OpKernelConstruction* ctx)
+    : AnonymousResourceOp<MemoryCache>(ctx) {}
+
+void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
+  AnonymousResourceOp<MemoryCache>::Compute(ctx);
+}
+
+string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }
+
+Status AnonymousMemoryCacheHandleOp::CreateResource(
+    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+    FunctionLibraryRuntime* lib, MemoryCache** resource) {
+  *resource = new MemoryCache();
+  return Status::OK();
+}
+
+void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) {
+  const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
+  // The resource is guaranteed to exist because the variant tensor wrapping the
+  // deleter is provided as an unused input to this op, which guarantees that it
+  // has not run yet.
+  Status s = ctx->resource_manager()->Delete(handle);
+  if (errors::IsNotFound(s)) {
+    // TODO(b/135948230): Investigate why is the above statement not true and
+    // then get rid of the special case.
+    ctx->SetStatus(Status::OK());
+    return;
+  }
+  ctx->SetStatus(s);
+}
+
+namespace {
+
+REGISTER_KERNEL_BUILDER(Name("AnonymousMemoryCache").Device(DEVICE_CPU),
+                        AnonymousMemoryCacheHandleOp);
+
+REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU),
+                        DeleteMemoryCacheOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_ops.h b/tensorflow/core/kernels/data/cache_ops.h
new file mode 100644
index 0000000..c022c06
--- /dev/null
+++ b/tensorflow/core/kernels/data/cache_ops.h
@@ -0,0 +1,95 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+namespace tensorflow {
+namespace data {
+
+// A thread-safe data structure for caching dataset elements.
+//
+// The expected use is that a single `MemoryWriterIterator` populates the
+// cache with dataset elements. Once all elements are cached, the cache can
+// be used by one or more `MemoryReaderIterator`s.
+class MemoryCache : public ResourceBase {
+ public:
+  MemoryCache() = default;
+
+  string DebugString() const override;
+
+  // Marks the cache as completed.
+  void Complete();
+
+  // Returns whether the cache is claimed.
+  bool IsClaimed();
+
+  // Returns whether the cache is completed.
+  bool IsCompleted();
+
+  // Attempts to claim the cache, returning whether the cache was claimed.
+  bool MaybeClaim();
+
+  // Resets the cache.
+  void Reset();
+
+  // Returns the element at the given index.
+  const std::vector<Tensor>& at(int64 index);
+
+  // Adds the element to the cache.
+  void emplace_back(std::vector<Tensor> element);
+
+  // Returns the size of the cache.
+  size_t size();
+
+ private:
+  mutex mu_;
+  // Determines whether a writer has claimed the cache.
+  bool claimed_ GUARDED_BY(mu_) = false;
+  // Determines whether all elements of the dataset have been cached.
+  bool completed_ GUARDED_BY(mu_) = false;
+  std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
+};
+
+// Creates an instance of cache resource and transfers ownership to the caller.
+class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
+ public:
+  explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  string name() override;
+  Status CreateResource(OpKernelContext* ctx,
+                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
+                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+                        FunctionLibraryRuntime* lib,
+                        MemoryCache** resource) override;
+};
+
+// Deletes an instance of cache resource.
+class DeleteMemoryCacheOp : public OpKernel {
+ public:
+  explicit DeleteMemoryCacheOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 89656b9..b3757fa 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -210,6 +210,76 @@
   return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
 }
 
+Status IsNodeStateful(const FunctionLibraryDefinition& library,
+                      const NodeDef& node);
+
+Status IsFunctionStateful(const FunctionLibraryDefinition& library,
+                          const FunctionDef& function_def) {
+  if (!function_def.signature().is_stateful()) {
+    return Status::OK();
+  }
+
+  for (const NodeDef& node_def : function_def.node_def()) {
+    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
+  }
+  return Status::OK();
+}
+
+// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
+// whitelist source dataset ops which have been marked stateful due to
+// b/65524810. Also looks up the `op_def->name` in the global
+// `WhitelistedStatefulOpRegistry`.
+bool IsOpWhitelisted(const OpDef* op_def) {
+  return (op_def->output_arg_size() == 1 &&
+          op_def->output_arg(0).type() == DT_VARIANT &&
+          (absl::EndsWith(op_def->name(), "Dataset") ||
+           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
+         WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
+}
+
+Status IsNodeStateful(const FunctionLibraryDefinition& library,
+                      const NodeDef& node) {
+  const OpDef* op_def;
+
+  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
+  // `LookUpOpDef` errors here.
+  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
+      IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
+      op_def->name() == "Assert") {
+    return Status::OK();
+  }
+
+  if (op_def->name() == "If") {
+    const FunctionDef* then_func =
+        library.Find(node.attr().at("then_branch").func().name());
+    const FunctionDef* else_func =
+        library.Find(node.attr().at("else_branch").func().name());
+    if (then_func != nullptr) {
+      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
+    }
+    if (else_func != nullptr) {
+      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
+    }
+    return Status::OK();
+  }
+
+  if (op_def->name() == "While") {
+    const FunctionDef* cond_func =
+        library.Find(node.attr().at("cond").func().name());
+    const FunctionDef* body_func =
+        library.Find(node.attr().at("body").func().name());
+    if (cond_func != nullptr) {
+      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
+    }
+    if (body_func != nullptr) {
+      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
+    }
+    return Status::OK();
+  }
+
+  return errors::FailedPrecondition(op_def->name(), " is stateful.");
+}
+
 }  // namespace
 
 Status MakeIteratorFromInputElement(
@@ -372,11 +442,22 @@
     // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
     Device* cpu_device;
     TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
-    for (auto& input : captured_inputs_) {
+    std::unordered_map<int, DtypeAndPartialTensorShape>&
+        input_resource_variable_dtypes_and_shapes =
+            inst_opts.input_resource_dtypes_and_shapes;
+    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
+      const auto& input = captured_inputs_[i];
       DataType dtype = input.dtype();
       if (dtype == DT_RESOURCE) {
         const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
         inst_opts.input_devices.push_back(handle.device());
+        const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
+        // Set dtypes and shapes for resource variable inputs.
+        if (!dtypes_and_shapes.empty()) {
+          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
+                                                    i] =
+              dtypes_and_shapes.at(0);
+        }
       } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
         inst_opts.input_devices.push_back(cpu_device->name());
       } else {
@@ -406,6 +487,16 @@
   return Status::OK();
 }
 
+bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
+
+Status CapturedFunction::CheckExternalState() const {
+  for (const auto& name : lib_def()->ListFunctionNames()) {
+    TF_RETURN_IF_ERROR(
+        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
+  }
+  return Status::OK();
+}
+
 namespace {
 class CallFrameBase : public CallFrameInterface {
  public:
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index b020f53..5fd4633 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -204,6 +204,15 @@
                      std::unique_ptr<InstantiatedCapturedFunction>*
                          instantiated_captured_function);
 
+  // Determines whether the captured function is stateful.
+  //
+  // TODO(jsimsa): Remove this method once all users of `CapturedFunction`
+  // migrate to `CheckExternalState`.
+  bool IsStateful() const;
+
+  // Determines whether the captured function is stateful.
+  Status CheckExternalState() const;
+
   // Returns the additional captured inputs that will be passed to the function.
   const std::vector<Tensor>& captured_inputs() const {
     return captured_inputs_;
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index f1368ae..4d9bce1 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -85,6 +85,11 @@
     return n1 + n2;
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(input_->CheckExternalState());
+    return to_concatenate_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc b/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc
index b5399ad..7f4c55f 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc
@@ -81,23 +81,19 @@
 // Test case 1: same shape.
 TestCase SameShapeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                    {1, 2, 3, 4}),
-            DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                    {5, 6, 7, 8})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                    {11, 12, 13, 14}),
-            DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                    {15, 16, 17, 18})}},
+          {{CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
+            CreateTensor<int64>(TensorShape{2, 2}, {5, 6, 7, 8})},
+           {CreateTensor<int64>(TensorShape{2, 2}, {11, 12, 13, 14}),
+            CreateTensor<int64>(TensorShape{2, 2}, {15, 16, 17, 18})}},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {5, 6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {3, 4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {7, 8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {11, 12}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {15, 16}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {13, 14}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {17, 18})},
+          {CreateTensor<int64>(TensorShape{2}, {1, 2}),
+           CreateTensor<int64>(TensorShape{2}, {5, 6}),
+           CreateTensor<int64>(TensorShape{2}, {3, 4}),
+           CreateTensor<int64>(TensorShape{2}, {7, 8}),
+           CreateTensor<int64>(TensorShape{2}, {11, 12}),
+           CreateTensor<int64>(TensorShape{2}, {15, 16}),
+           CreateTensor<int64>(TensorShape{2}, {13, 14}),
+           CreateTensor<int64>(TensorShape{2}, {17, 18})},
           /*expected_output_dtypes*/ {DT_INT64, DT_INT64},
           /*expected_output_shapes*/
           {PartialTensorShape({2}), PartialTensorShape({2})},
@@ -107,42 +103,38 @@
 
 // Test case 2: different shape.
 TestCase DifferentShapeTestCase() {
-  return {
-      /*input_tensors*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                {1, 2, 3, 4, 5, 6}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                {7, 8, 9, 10})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
-                                                {11, 12, 13, 14}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {15, 16})}},
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3}, {1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {7, 8}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3}, {4, 5, 6}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {9, 10}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {11, 12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {15}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {13, 14}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {16})},
-      /*expected_output_dtypes*/ {DT_INT64, DT_INT64},
-      /*expected_output_shapes*/
-      {PartialTensorShape({-1}), PartialTensorShape({-1})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 2, 5}};
+  return {/*input_tensors*/
+          {{CreateTensor<int64>(TensorShape{2, 3}, {1, 2, 3, 4, 5, 6}),
+            CreateTensor<int64>(TensorShape{2, 2}, {7, 8, 9, 10})},
+           {CreateTensor<int64>(TensorShape{2, 2}, {11, 12, 13, 14}),
+            CreateTensor<int64>(TensorShape{2, 1}, {15, 16})}},
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{3}, {1, 2, 3}),
+           CreateTensor<int64>(TensorShape{2}, {7, 8}),
+           CreateTensor<int64>(TensorShape{3}, {4, 5, 6}),
+           CreateTensor<int64>(TensorShape{2}, {9, 10}),
+           CreateTensor<int64>(TensorShape{2}, {11, 12}),
+           CreateTensor<int64>(TensorShape{1}, {15}),
+           CreateTensor<int64>(TensorShape{2}, {13, 14}),
+           CreateTensor<int64>(TensorShape{1}, {16})},
+          /*expected_output_dtypes*/ {DT_INT64, DT_INT64},
+          /*expected_output_shapes*/
+          {PartialTensorShape({-1}), PartialTensorShape({-1})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 3: different dtypes
 TestCase DifferentDtypeTestCase() {
-  return {/*input_tensors*/ {{DatasetOpsTestBase::CreateTensor<int64>(
-                                 TensorShape({2, 2}), {1, 2, 3, 4})},
-                             {DatasetOpsTestBase::CreateTensor<double>(
-                                 TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}},
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({2})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/ {
+          {CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4})},
+          {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}},
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({2})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {}};
 }
 
 class ParameterizedConcatenateDatasetOpTest
@@ -365,39 +357,6 @@
   EXPECT_EQ(concatenate_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_F(ConcatenateDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = SameShapeTestCase();
-  std::vector<Tensor> tensor_slice_dataset_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
-                                               &tensor_slice_dataset_tensors));
-  gtl::InlinedVector<TensorValue, 4> inputs;
-  for (auto &tensor : tensor_slice_dataset_tensors) {
-    inputs.emplace_back(&tensor);
-  }
-  std::unique_ptr<OpKernel> dataset_kernel;
-  TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
-                                              test_case.expected_output_shapes,
-                                              &dataset_kernel));
-  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
-  TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
-                                               &dataset_kernel_ctx));
-  DatasetBase *concatenate_dataset;
-  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
-                             &concatenate_dataset));
-  core::ScopedUnref scoped_unref(concatenate_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(concatenate_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index 58cd174..e931755 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -40,7 +40,7 @@
       ctx, AsGraphDef(ctx, dataset, SerializationContext({}), &graph_def));
   Tensor* result;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
-  result->scalar<string>()() = graph_def.SerializeAsString();
+  result->scalar<tstring>()() = graph_def.SerializeAsString();
 }
 
 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
index 2854bfd..f41c28b 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.cc
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
@@ -560,5 +560,126 @@
   return Status::OK();
 }
 
+Status DatasetOpsTestBase::CheckDatasetNodeName(
+    const DatasetBase& dataset, const string& expected_dataset_node_name) {
+  EXPECT_EQ(dataset.node_name(), expected_dataset_node_name);
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckDatasetTypeString(
+    const DatasetBase& dataset, const string& expected_dataset_type_string) {
+  EXPECT_EQ(dataset.type_string(), expected_dataset_type_string);
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
+    const DatasetBase& dataset, const DataTypeVector& expected_output_dtypes) {
+  TF_EXPECT_OK(
+      VerifyTypesMatch(dataset.output_dtypes(), expected_output_dtypes));
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckDatasetOutputShapes(
+    const DatasetBase& dataset,
+    const std::vector<PartialTensorShape>& expected_output_shapes) {
+  TF_EXPECT_OK(
+      VerifyShapesCompatible(dataset.output_shapes(), expected_output_shapes));
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckDatasetCardinality(const DatasetBase& dataset,
+                                                   int64 expected_cardinality) {
+  EXPECT_EQ(dataset.Cardinality(), expected_cardinality);
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckDatasetIsStateful(const DatasetBase& dataset,
+                                                  bool expected_stateful) {
+  EXPECT_EQ(dataset.IsStateful(), expected_stateful);
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
+    const IteratorBase& iterator,
+    const DataTypeVector& expected_output_dtypes) {
+  TF_EXPECT_OK(
+      VerifyTypesMatch(iterator.output_dtypes(), expected_output_dtypes));
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckIteratorOutputShapes(
+    const IteratorBase& iterator,
+    const std::vector<PartialTensorShape>& expected_output_shapes) {
+  TF_EXPECT_OK(
+      VerifyShapesCompatible(iterator.output_shapes(), expected_output_shapes));
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckIteratorPrefix(
+    const IteratorBase& iterator, const string& expected_iterator_prefix) {
+  EXPECT_EQ(iterator.prefix(), expected_iterator_prefix);
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckIteratorGetNext(
+    IteratorBase* iterator, IteratorContext* iterator_context,
+    const std::vector<Tensor>& expected_outputs, bool compare_order) {
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  while (!end_of_sequence) {
+    std::vector<Tensor> next;
+    TF_RETURN_IF_ERROR(
+        iterator->GetNext(iterator_context, &next, &end_of_sequence));
+    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
+  }
+
+  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
+                           /*compare_order=*/compare_order));
+  return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
+    const DatasetBase& dataset, IteratorContext* iterator_context,
+    const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
+    const std::vector<int>& breakpoints) {
+  std::unique_ptr<IteratorBase> iterator;
+  TF_RETURN_IF_ERROR(
+      dataset.MakeIterator(iterator_context, iterator_prefix, &iterator));
+  std::unique_ptr<SerializationContext> serialization_ctx;
+  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  int cur_iteration = 0;
+  auto expected_outputs_it = expected_outputs.begin();
+  for (int breakpoint : breakpoints) {
+    VariantTensorData data;
+    VariantTensorDataWriter writer(&data);
+    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
+    TF_RETURN_IF_ERROR(writer.Flush());
+    VariantTensorDataReader reader(&data);
+    TF_EXPECT_OK(RestoreIterator(iterator_context, &reader, iterator_prefix,
+                                 dataset, &iterator));
+
+    while (cur_iteration <= breakpoint) {
+      TF_RETURN_IF_ERROR(
+          iterator->GetNext(iterator_context, &out_tensors, &end_of_sequence));
+      if (!end_of_sequence) {
+        EXPECT_NE(expected_outputs_it, expected_outputs.end());
+        TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
+        expected_outputs_it++;
+      }
+      cur_iteration++;
+    }
+
+    if (breakpoint >= expected_outputs.size()) {
+      EXPECT_TRUE(end_of_sequence);
+      EXPECT_EQ(expected_outputs_it, expected_outputs.end());
+    } else {
+      EXPECT_FALSE(end_of_sequence);
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h
index 427ccca..e149680 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.h
+++ b/tensorflow/core/kernels/data/dataset_test_base.h
@@ -77,6 +77,116 @@
                                const std::vector<absl::string_view>& records,
                                const CompressionParams& params);
 
+class DatasetParams {
+ public:
+  DatasetParams(DataTypeVector output_dtypes,
+                std::vector<PartialTensorShape> output_shapes, string node_name)
+      : output_dtypes(std::move(output_dtypes)),
+        output_shapes(std::move(output_shapes)),
+        node_name(std::move(node_name)) {}
+
+  virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
+
+  virtual ~DatasetParams() {}
+
+  DataTypeVector output_dtypes;
+  std::vector<PartialTensorShape> output_shapes;
+  string node_name;
+};
+
+template <typename T>
+struct GetNextTestCase {
+  T dataset_params;
+  std::vector<Tensor> expected_outputs;
+};
+
+template <typename T>
+struct DatasetNodeNameTestCase {
+  T dataset_params;
+  string expected_node_name;
+};
+
+template <typename T>
+struct DatasetTypeStringTestCase {
+  T dataset_params;
+  string expected_dataset_type_string;
+};
+
+template <typename T>
+struct DatasetOutputDtypesTestCase {
+  T dataset_params;
+  DataTypeVector expected_output_dtypes;
+};
+
+template <typename T>
+struct DatasetOutputShapesTestCase {
+  T dataset_params;
+  std::vector<PartialTensorShape> expected_output_shapes;
+};
+
+template <typename T>
+struct CardinalityTestCase {
+  T dataset_params;
+  int64 expected_cardinality;
+};
+
+template <typename T>
+struct DatasetSaveTestCase {
+  T dataset_params;
+};
+
+template <typename T>
+struct IsStatefulTestCase {
+  T dataset_params;
+  bool expected_stateful;
+};
+
+template <typename T>
+struct IteratorOutputDtypesTestCase {
+  T dataset_params;
+  DataTypeVector expected_output_dtypes;
+};
+
+template <typename T>
+struct IteratorOutputShapesTestCase {
+  T dataset_params;
+  std::vector<PartialTensorShape> expected_output_shapes;
+};
+
+template <typename T>
+struct IteratorOutputPrefixTestCase {
+  T dataset_params;
+  string expected_iterator_prefix;
+};
+
+template <typename T>
+struct IteratorSaveAndRestoreTestCase {
+  T dataset_params;
+  std::vector<int> breakpoints;
+  std::vector<Tensor> expected_outputs;
+};
+
+// Creates a tensor with the specified dtype, shape, and value.
+template <typename T>
+static Tensor CreateTensor(const TensorShape& input_shape,
+                           const gtl::ArraySlice<T>& input_data) {
+  Tensor tensor(DataTypeToEnum<T>::value, input_shape);
+  test::FillValues<T>(&tensor, input_data);
+  return tensor;
+}
+
+// Creates a vector of tensors with the specified dtype, shape, and values.
+template <typename T>
+std::vector<Tensor> CreateTensors(
+    const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
+  std::vector<Tensor> result;
+  result.reserve(values.size());
+  for (auto& value : values) {
+    result.emplace_back(CreateTensor<T>(shape, value));
+  }
+  return result;
+}
+
 // Helpful functions to test Dataset op kernels.
 class DatasetOpsTestBase : public ::testing::Test {
  public:
@@ -99,15 +209,6 @@
                             std::vector<Tensor> expected_tensors,
                             bool compare_order);
 
-  // Creates a tensor with the specified dtype, shape, and value.
-  template <typename T>
-  static Tensor CreateTensor(TensorShape input_shape,
-                             const gtl::ArraySlice<T>& input_data) {
-    Tensor tensor(DataTypeToEnum<T>::value, input_shape);
-    test::FillValues<T>(&tensor, input_data);
-    return tensor;
-  }
-
   // Creates a new op kernel based on the node definition.
   Status CreateOpKernel(const NodeDef& node_def,
                         std::unique_ptr<OpKernel>* op_kernel);
@@ -195,6 +296,58 @@
   Status GetDatasetFromContext(OpKernelContext* context, int output_index,
                                DatasetBase** const dataset);
 
+  // Checks `DatasetBase::node_name()`.
+  Status CheckDatasetNodeName(const DatasetBase& dataset,
+                              const string& expected_dataset_node_name);
+
+  // Checks `DatasetBase::type_string()`.
+  Status CheckDatasetTypeString(const DatasetBase& dataset,
+                                const string& expected_dataset_type_string);
+
+  // Checks `DatasetBase::output_dtypes()`.
+  Status CheckDatasetOutputDtypes(const DatasetBase& dataset,
+                                  const DataTypeVector& expected_output_dtypes);
+
+  // Checks `DatasetBase::output_shapes()`.
+  Status CheckDatasetOutputShapes(
+      const DatasetBase& dataset,
+      const std::vector<PartialTensorShape>& expected_output_shapes);
+
+  // Checks `DatasetBase::Cardinality()`.
+  Status CheckDatasetCardinality(const DatasetBase& dataset,
+                                 int64 expected_cardinality);
+
+  // Checks `DatasetBase::IsStateful()`.
+  Status CheckDatasetIsStateful(const DatasetBase& dataset,
+                                bool expected_stateful);
+
+  // Checks `IteratorBase::output_dtypes()`.
+  Status CheckIteratorOutputDtypes(
+      const IteratorBase& iterator,
+      const DataTypeVector& expected_output_dtypes);
+
+  // Checks `IteratorBase::output_shapes()`.
+  Status CheckIteratorOutputShapes(
+      const IteratorBase& iterator,
+      const std::vector<PartialTensorShape>& expected_output_shapes);
+
+  // Checks `IteratorBase::prefix()`.
+  Status CheckIteratorPrefix(const IteratorBase& iterator,
+                             const string& expected_iterator_prefix);
+
+  // Checks `IteratorBase::GetNext()`.
+  Status CheckIteratorGetNext(IteratorBase* iterator,
+                              IteratorContext* iterator_context,
+                              const std::vector<Tensor>& expected_outputs,
+                              bool compare_order);
+
+  // Checks `IteratorBase::Save()` and `IteratorBase::Restore()`.
+  Status CheckIteratorSaveAndRestore(
+      const DatasetBase& dataset, IteratorContext* iterator_context,
+      const string& iterator_prefix,
+      const std::vector<Tensor>& expected_outputs,
+      const std::vector<int>& breakpoints);
+
  protected:
   // Creates a thread pool for parallel tasks.
   Status InitThreadPool(int thread_num);
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 5c81cb6..e46f524 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -41,6 +41,8 @@
 namespace data {
 namespace {
 
+constexpr char kDelimiter[] = "@@";
+
 void AddFakeSinks(FunctionDef* function_def) {
   int counter = 0;
   for (const auto& output : function_def->signature().output_arg()) {
@@ -136,134 +138,6 @@
   return Status::OK();
 }
 
-}  // anonymous namespace
-
-Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
-                  SerializationContext&& serialization_ctx,
-                  GraphDef* graph_def) {
-  GraphDefBuilder b;
-  DatasetBase::DatasetGraphDefBuilder db(&b);
-  Node* output_node = nullptr;
-  TF_RETURN_IF_ERROR(
-      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
-  // Insert a purely symbolic _Retval node to indicate to consumers which Tensor
-  // represents this Dataset.
-  ops::UnaryOp("_Retval", output_node,
-               b.opts()
-                   .WithName("dataset")
-                   .WithAttr("T", DT_VARIANT)
-                   .WithAttr("index", 0));
-  TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
-  return Status::OK();
-}
-
-Status ConnectCancellationManagers(CancellationManager* parent,
-                                   CancellationManager* child,
-                                   std::function<void()>* deregister_fn) {
-  if (parent) {
-    CancellationToken token = parent->get_cancellation_token();
-    if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
-      return errors::Cancelled("Operation was cancelled");
-    }
-    *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
-  } else {
-    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
-               "not be propagated to the child cancellation manager.";
-    *deregister_fn = []() {};
-  }
-  return Status::OK();
-}
-
-Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
-                      std::function<RewriterConfig(void)> config_factory,
-                      bool optimize_function_library,
-                      DatasetBase** rewritten_input) {
-  SerializationContext::Params params;
-  std::vector<std::pair<string, Tensor>> input_list;
-  params.input_list = &input_list;
-  params.optimization_only = true;
-  SerializationContext serialization_ctx(params);
-  GraphDef graph_def;
-  TF_RETURN_IF_ERROR(
-      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
-
-  string output_node;
-  for (const auto& node : graph_def.node()) {
-    if (node.op() == "_Retval") {
-      output_node = node.input(0);
-    }
-  }
-
-  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
-  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
-                                   optimize_function_library, &graph_def,
-                                   &output_node));
-  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
-
-  // Instantiate the optimized input pipeline by running the optimized graph
-  // using the optimized function library.
-  FunctionLibraryRuntime* flr = nullptr;
-  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
-  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
-  TF_RETURN_IF_ERROR(
-      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));
-
-  // Some functions may have been modified without having their names
-  // changed (for example, nested dataset graphs from FlatMap or
-  // Interleave).
-  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
-
-  Graph graph(OpRegistry::Global());
-  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
-  std::vector<Tensor> outputs;
-  GraphRunner graph_runner(flr->device());
-
-  TF_RETURN_IF_ERROR(
-      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
-  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
-  (*rewritten_input)->Ref();
-  return Status::OK();
-}
-
-Status VerifyTypesMatch(const DataTypeVector& expected,
-                        const DataTypeVector& received) {
-  if (expected.size() != received.size()) {
-    return errors::InvalidArgument(
-        "Number of components does not match: expected ", expected.size(),
-        " types but got ", received.size(), ".");
-  }
-  for (size_t i = 0; i < expected.size(); ++i) {
-    if (expected[i] != received[i]) {
-      return errors::InvalidArgument("Data type mismatch at component ", i,
-                                     ": expected ", DataTypeString(expected[i]),
-                                     " but got ", DataTypeString(received[i]),
-                                     ".");
-    }
-  }
-  return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
-                              const std::vector<PartialTensorShape>& received) {
-  if (expected.size() != received.size()) {
-    return errors::InvalidArgument(
-        "Number of components does not match: expected ", expected.size(),
-        " shapes but got ", received.size(), ".");
-  }
-  for (size_t i = 0; i < expected.size(); ++i) {
-    if (!expected[i].IsCompatibleWith(received[i])) {
-      return errors::InvalidArgument("Incompatible shapes at component ", i,
-                                     ": expected ", expected[i].DebugString(),
-                                     " but got ", received[i].DebugString(),
-                                     ".");
-    }
-  }
-
-  return Status::OK();
-}
-
-namespace {
-
 uint64 DefaultDependencyLoopNodeHash() {
   static const uint64 hash = Hash64("DependencyLoopNode");
   return hash;
@@ -496,7 +370,136 @@
   return final_hash;
 }
 
-}  // namespace
+}  // anonymous namespace
+
+Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
+                  SerializationContext&& serialization_ctx,
+                  GraphDef* graph_def) {
+  if (serialization_ctx.check_external_state()) {
+    TF_RETURN_IF_ERROR(dataset->CheckExternalState());
+  }
+  GraphDefBuilder b;
+  DatasetBase::DatasetGraphDefBuilder db(&b);
+  Node* output_node = nullptr;
+  TF_RETURN_IF_ERROR(
+      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
+  // Insert a purely symbolic _Retval node to indicate to consumers which node
+  // represents `dataset`.
+  ops::UnaryOp("_Retval", output_node,
+               b.opts()
+                   .WithName("dataset")
+                   .WithAttr("T", DT_VARIANT)
+                   .WithAttr("index", 0));
+  TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
+  return Status::OK();
+}
+
+Status ConnectCancellationManagers(CancellationManager* parent,
+                                   CancellationManager* child,
+                                   std::function<void()>* deregister_fn) {
+  if (parent) {
+    CancellationToken token = parent->get_cancellation_token();
+    if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
+      return errors::Cancelled("Operation was cancelled");
+    }
+    *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
+  } else {
+    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
+               "not be propagated to the child cancellation manager.";
+    *deregister_fn = []() {};
+  }
+  return Status::OK();
+}
+
+Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
+                      std::function<RewriterConfig(void)> config_factory,
+                      bool optimize_function_library,
+                      DatasetBase** rewritten_input) {
+  SerializationContext::Params params;
+  std::vector<std::pair<string, Tensor>> input_list;
+  params.input_list = &input_list;
+  params.check_external_state = false;
+  params.fail_if_unimplemented = false;
+  params.serialize_data_tensors = false;
+  SerializationContext serialization_ctx(params);
+  GraphDef graph_def;
+  TF_RETURN_IF_ERROR(
+      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
+
+  string output_node;
+  for (const auto& node : graph_def.node()) {
+    if (node.op() == "_Retval") {
+      output_node = node.input(0);
+    }
+  }
+
+  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
+  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
+                                   optimize_function_library, &graph_def,
+                                   &output_node));
+  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
+
+  // Instantiate the optimized input pipeline by running the optimized graph
+  // using the optimized function library.
+  FunctionLibraryRuntime* flr = nullptr;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
+  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
+  TF_RETURN_IF_ERROR(
+      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));
+
+  // Some functions may have been modified without having their names
+  // changed (for example, nested dataset graphs from FlatMap or
+  // Interleave).
+  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
+
+  Graph graph(OpRegistry::Global());
+  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+  std::vector<Tensor> outputs;
+  GraphRunner graph_runner(flr->device());
+
+  TF_RETURN_IF_ERROR(
+      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
+  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
+  (*rewritten_input)->Ref();
+  return Status::OK();
+}
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+                        const DataTypeVector& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " types but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    if (expected[i] != received[i]) {
+      return errors::InvalidArgument("Data type mismatch at component ", i,
+                                     ": expected ", DataTypeString(expected[i]),
+                                     " but got ", DataTypeString(received[i]),
+                                     ".");
+    }
+  }
+  return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+                              const std::vector<PartialTensorShape>& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " shapes but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    if (!expected[i].IsCompatibleWith(received[i])) {
+      return errors::InvalidArgument("Incompatible shapes at component ", i,
+                                     ": expected ", expected[i].DebugString(),
+                                     " but got ", received[i].DebugString(),
+                                     ".");
+    }
+  }
+
+  return Status::OK();
+}
 
 uint64 HashSubgraphFunction(const FunctionDefLibrary& library,
                             const FunctionDef* f) {
@@ -511,11 +514,6 @@
   return HashSubgraphImpl(grappler::GraphView(&g), node, &visited, &cache);
 }
 
-namespace {
-
-constexpr char kDelimiter[] = "@@";
-
-}  // namespace
 
 VariantTensorDataReader::VariantTensorDataReader(
     const tensorflow::VariantTensorData* data)
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index fbf7f8e..beb9d0c 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -15,13 +15,67 @@
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
 
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tensorflow {
 namespace data {
 
+template <typename T>
+class AnonymousResourceOp : public OpKernel {
+ public:
+  static std::atomic<int64> resource_id_counter_;
+
+  explicit AnonymousResourceOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    FunctionLibraryRuntime* lib;
+    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+    OP_REQUIRES_OK(
+        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
+    T* resource;
+    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
+                                       std::move(pflr), lib, &resource));
+
+    string container_name = name();
+    string unique_name =
+        strings::StrCat(container_name, resource_id_counter_.fetch_add(1));
+    ResourceMgr* mgr = ctx->resource_manager();
+    OP_REQUIRES_OK(ctx, mgr->Create<T>(container_name, unique_name, resource));
+
+    Tensor* handle_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
+    ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
+                                               MakeTypeIndex<T>());
+    handle_t->scalar<ResourceHandle>()() = handle;
+
+    if (create_deleter_) {
+      Tensor* deleter_t;
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
+      deleter_t->scalar<Variant>()() =
+          ResourceDeleter(handle, ctx->resource_manager());
+    }
+  }
+
+ protected:
+  virtual string name() = 0;
+
+  virtual Status CreateResource(
+      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
+      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+      FunctionLibraryRuntime* lib, T** resource) = 0;
+
+  bool create_deleter_ = true;
+};
+
+template <typename T>
+std::atomic<int64> AnonymousResourceOp<T>::resource_id_counter_;
+
 // Returns a GraphDef representation of the given dataset.
 Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                   SerializationContext&& serialization_ctx,
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index e209cdc..65d8a1d 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -229,6 +229,7 @@
 tf_kernel_library(
     name = "parallel_interleave_dataset_op",
     srcs = ["parallel_interleave_dataset_op.cc"],
+    hdrs = ["parallel_interleave_dataset_op.h"],
     deps = [
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:dataset_ops_op_lib",
@@ -237,6 +238,23 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/kernels/data:captured_function",
         "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/kernels/data:name_utils",
+    ],
+)
+
+tf_cc_test(
+    name = "parallel_interleave_dataset_op_test",
+    size = "small",
+    srcs = ["parallel_interleave_dataset_op_test.cc"],
+    deps = [
+        ":parallel_interleave_dataset_op",
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels/data:dataset_test_base",
+        "//tensorflow/core/kernels/data:tensor_slice_dataset_op",
     ],
 )
 
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
index 8171bb6..abe7b7e 100644
--- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -22,6 +22,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 
 /* static */ constexpr const char* const AssertNextDatasetOp::kInputDataset;
 /* static */ constexpr const char* const AssertNextDatasetOp::kDatasetType;
@@ -62,6 +63,10 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -158,5 +163,6 @@
     AssertNextDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h
index aae2e80..6e86b5d 100644
--- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h
@@ -19,9 +19,7 @@
 
 namespace tensorflow {
 namespace data {
-
-// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
-// description of the following op.
+namespace experimental {
 
 class AssertNextDatasetOp : public UnaryDatasetOpKernel {
  public:
@@ -43,6 +41,7 @@
   std::vector<PartialTensorShape> output_shapes_;
 };
 
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc
index e256d5b..52b2e62 100644
--- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc
@@ -15,6 +15,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 constexpr char kNodeName[] = "assert_next_dataset";
@@ -95,12 +96,11 @@
   return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
           /*take_dataset_params*/ {/*count*/ 3},
           /*transformations*/
-          DatasetOpsTestBase::CreateTensor<string>(
-              TensorShape({1}), {TakeDatasetOp::kDatasetType}),
+          CreateTensor<string>(TensorShape({1}), {TakeDatasetOp::kDatasetType}),
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {2})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 3,
@@ -109,49 +109,48 @@
 
 // Test case 2 : assert two transformations.
 TestCase TestCase2() {
-  return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
-          /*take_dataset_params*/ {/*count*/ 3},
-          /*transformations*/
-          DatasetOpsTestBase::CreateTensor<string>(
-              TensorShape({2}),
-              {TakeDatasetOp::kDatasetType, RangeDatasetOp::kDatasetType}),
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ 3,
-          /*breakpoints*/ {0, 2, 5}};
-}
-
-TestCase AssertNextInvalid() {
   return {
       /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
       /*take_dataset_params*/ {/*count*/ 3},
       /*transformations*/
-      DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"Whoops"}),
+      CreateTensor<string>(TensorShape({2}), {TakeDatasetOp::kDatasetType,
+                                              RangeDatasetOp::kDatasetType}),
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {1}),
+       CreateTensor<int64>(TensorShape({}), {2})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 3,
       /*breakpoints*/ {0, 2, 5}};
 }
 
+TestCase AssertNextInvalid() {
+  return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
+          /*take_dataset_params*/ {/*count*/ 3},
+          /*transformations*/
+          CreateTensor<string>(TensorShape({1}), {"Whoops"}),
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {2})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 2, 5}};
+}
+
 TestCase AssertNextShort() {
   return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
           /*take_dataset_params*/ {/*count*/ 3},
           /*transformations*/
-          DatasetOpsTestBase::CreateTensor<string>(
-              TensorShape({3}), {TakeDatasetOp::kDatasetType,
-                                 RangeDatasetOp::kDatasetType, "Whoops"}),
+          CreateTensor<string>(TensorShape({3}),
+                               {TakeDatasetOp::kDatasetType,
+                                RangeDatasetOp::kDatasetType, "Whoops"}),
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {2})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 3,
@@ -377,43 +376,6 @@
   EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
-
-  std::unique_ptr<OpKernel> assert_next_dataset_kernel;
-  TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
-                                               test_case.expected_output_shapes,
-                                               &assert_next_dataset_kernel));
-  Tensor transformations = test_case.transformations;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor),
-       TensorValue(&transformations)});
-  std::unique_ptr<OpKernelContext> assert_next_dataset_context;
-  TF_ASSERT_OK(CreateAssertNextDatasetContext(
-      assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
-
-  DatasetBase* assert_next_dataset;
-  TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
-                             assert_next_dataset_context.get(),
-                             &assert_next_dataset));
-  core::ScopedUnref scoped_unref(assert_next_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
@@ -663,5 +625,6 @@
 }
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
index 79a830a..fb4a8f2 100644
--- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
@@ -19,6 +19,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 
 /* static */ constexpr const char* const AutoShardDatasetOp::kDatasetType;
 /* static */ constexpr const char* const AutoShardDatasetOp::kInputDataset;
@@ -62,14 +63,14 @@
                                                 int64 index) {
   RewriterConfig rewriter_config;
   rewriter_config.set_fail_on_optimizer_errors(true);
-  rewriter_config.add_optimizers(kOptimizerName);
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
+
+  rewriter_config.add_optimizers(kOptimizerName);
   auto custom_optimizer = rewriter_config.add_custom_optimizers();
   custom_optimizer->set_name(kOptimizerName);
   AttrValue num_workers_attr;
   num_workers_attr.set_i(num_workers);
   (*custom_optimizer->mutable_parameter_map())[kNumWorkers] = num_workers_attr;
-
   AttrValue index_attr;
   index_attr.set_i(index);
   (*custom_optimizer->mutable_parameter_map())[kIndex] = index_attr;
@@ -83,5 +84,6 @@
 REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
                         AutoShardDatasetOp);
 }  // anonymous namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h
index 73ab7ad..087337c 100644
--- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h
+++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h
@@ -19,9 +19,7 @@
 
 namespace tensorflow {
 namespace data {
-
-// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
-// description of the following op.
+namespace experimental {
 
 class AutoShardDatasetOp : public UnaryDatasetOpKernel {
  public:
@@ -42,6 +40,7 @@
   static RewriterConfig CreateConfig(int64 num_workers, int64 index);
 };
 
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc
index 828561a..c509be0 100644
--- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc
@@ -16,6 +16,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 constexpr char kNodeName[] = "auto_shard_dataset";
@@ -55,14 +56,11 @@
            DataTypeVector expected_output_dtypes,
            std::vector<PartialTensorShape> expected_output_shapes,
            int64 expected_cardinality, std::vector<int> breakpoints)
-      : start(
-            DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {start})),
-        stop(DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {stop})),
-        step(DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {step})),
-        num_workers(DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}),
-                                                            {num_workers})),
-        index(
-            DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {index})),
+      : start(CreateTensor<int64>(TensorShape({}), {start})),
+        stop(CreateTensor<int64>(TensorShape({}), {stop})),
+        step(CreateTensor<int64>(TensorShape({}), {step})),
+        num_workers(CreateTensor<int64>(TensorShape({}), {num_workers})),
+        index(CreateTensor<int64>(TensorShape({}), {index})),
         expected_outputs(std::move(expected_outputs)),
         expected_output_dtypes(std::move(expected_output_dtypes)),
         expected_output_shapes(std::move(expected_output_shapes)),
@@ -89,8 +87,8 @@
           /*num_workers=*/5,
           /*index=*/2,
           /*expected_outputs=*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7})},
+          {CreateTensor<int64>(TensorShape({}), {2}),
+           CreateTensor<int64>(TensorShape({}), {7})},
           /*expected_output_dtypes=*/{DT_INT64},
           /*expected_output_shapes=*/{PartialTensorShape({})},
           /*expected_cardinality=*/2,
@@ -120,8 +118,8 @@
           /*num_workers=*/4,
           /*index=*/3,
           /*expected_outputs=*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7})},
+          {CreateTensor<int64>(TensorShape({}), {3}),
+           CreateTensor<int64>(TensorShape({}), {7})},
           /*expected_output_dtypes=*/{DT_INT64},
           /*expected_output_shapes=*/{PartialTensorShape({})},
           /*expected_cardinality=*/2,
@@ -279,5 +277,6 @@
 }
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc
index 8b4bafe..cc7577d 100644
--- a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc
@@ -26,6 +26,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 static const double kPercentile = 90.0;
@@ -241,6 +242,13 @@
       return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
     }
 
+    Status CheckExternalState() const override {
+      for (const auto& captured_func : captured_funcs_) {
+        TF_RETURN_IF_ERROR(captured_func->CheckExternalState());
+      }
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -553,5 +561,6 @@
                         ChooseFastestBranchDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
index 2c934eb..1db93a1 100644
--- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
@@ -22,6 +22,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 static const double kPercentile = 90.0;
@@ -157,6 +158,13 @@
 
     int64 Cardinality() const override { return cardinality_; }
 
+    Status CheckExternalState() const override {
+      for (const auto& input : inputs_) {
+        TF_RETURN_IF_ERROR(input->CheckExternalState());
+      }
+      return Status::OK();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -364,5 +372,6 @@
     ChooseFastestDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
index d721e69..a814206 100644
--- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -23,6 +23,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class CSVDatasetOp : public DatasetOpKernel {
@@ -92,7 +93,7 @@
     std::vector<string> filenames;
     filenames.reserve(filenames_tensor->NumElements());
     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-      filenames.push_back(filenames_tensor->flat<string>()(i));
+      filenames.push_back(filenames_tensor->flat<tstring>()(i));
     }
 
     io::ZlibCompressionOptions zlib_compression_options =
@@ -169,6 +170,8 @@
 
     string DebugString() const override { return "CSVDatasetOp::Dataset"; }
 
+    Status CheckExternalState() const override { return Status::OK(); }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -718,10 +721,10 @@
           }
           case DT_STRING: {
             if (field.empty() || field == dataset()->na_value_) {
-              component.scalar<string>()() =
-                  dataset()->record_defaults_[output_idx].flat<string>()(0);
+              component.scalar<tstring>()() =
+                  dataset()->record_defaults_[output_idx].flat<tstring>()(0);
             } else {
-              component.scalar<string>()() = string(field);
+              component.scalar<tstring>()() = string(field);
             }
             break;
           }
@@ -859,5 +862,6 @@
                         CSVDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
index b3003aa..545a966 100644
--- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
@@ -19,11 +19,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
@@ -122,6 +120,10 @@
       return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
     }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -319,5 +321,6 @@
     DenseToSparseBatchDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index 5d94ec0..e9ac761 100644
--- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -19,11 +19,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class DirectedInterleaveDatasetOp : public DatasetOpKernel {
  public:
   explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
@@ -109,6 +107,13 @@
       return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
     }
 
+    Status CheckExternalState() const override {
+      for (const auto& input : data_inputs_) {
+        TF_RETURN_IF_ERROR(input->CheckExternalState());
+      }
+      return selector_input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -196,8 +201,8 @@
             }
           }
 
-          LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
-                       << selected_input;
+          VLOG(2) << "DirectedInterleave selected an exhausted input: "
+                  << selected_input;
         }
       }
 
@@ -284,5 +289,6 @@
     DirectedInterleaveDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc
index 6aa2ae7..fa02e8c 100644
--- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc
@@ -25,10 +25,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
 class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
@@ -113,6 +112,14 @@
       return "GroupByReducerDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override {
+      TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(captured_init_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(captured_finalize_func_->CheckExternalState());
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -422,5 +429,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByReducerDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
index 38a5201..2ccb463 100644
--- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
@@ -26,10 +26,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
 class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
@@ -109,6 +108,13 @@
       return "GroupByWindowDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override {
+      TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState());
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -517,5 +523,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByWindowDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
index 410861d..b9fb85c 100644
--- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -18,11 +18,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
@@ -62,6 +60,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -147,5 +149,6 @@
     IgnoreErrorsDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
index 0c75995..bae373a 100644
--- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -22,6 +22,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class LMDBDatasetOp : public DatasetOpKernel {
@@ -37,7 +38,7 @@
     std::vector<string> filenames;
     filenames.reserve(filenames_tensor->NumElements());
     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-      filenames.push_back(filenames_tensor->flat<string>()(i));
+      filenames.push_back(filenames_tensor->flat<tstring>()(i));
     }
 
     *output = new Dataset(ctx, filenames);
@@ -69,6 +70,8 @@
 
     string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
 
+    Status CheckExternalState() const override { return Status::OK(); }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -94,13 +97,13 @@
             out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                       TensorShape({}));
             Tensor& key_tensor = out_tensors->back();
-            key_tensor.scalar<string>()() = string(
+            key_tensor.scalar<tstring>()() = string(
                 static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
 
             out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                       TensorShape({}));
             Tensor& value_tensor = out_tensors->back();
-            value_tensor.scalar<string>()() =
+            value_tensor.scalar<tstring>()() =
                 string(static_cast<const char*>(mdb_value_.mv_data),
                        mdb_value_.mv_size);
 
@@ -221,5 +224,6 @@
                         LMDBDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index 2eddf4a..40ed96e 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -12,8 +12,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#define EIGEN_USE_THREADS
-
 #include <atomic>
 #include <utility>
 
@@ -38,6 +36,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 constexpr char kDatasetName[] = "MapAndBatch";
@@ -45,8 +44,6 @@
 // Maximum number of batch results to buffer.
 constexpr int64 kMaxBatchResults = 16;
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
 class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
@@ -146,6 +143,11 @@
              (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
     }
 
+    Status CheckExternalState() const override {
+      TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -774,5 +776,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalMapAndBatchDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
index 0c3d06a..663537e 100644
--- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
@@ -32,6 +32,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class MatchingFilesDatasetOp : public DatasetOpKernel {
@@ -41,7 +42,7 @@
   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     const Tensor* patterns_t;
     OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t));
-    const auto patterns = patterns_t->flat<string>();
+    const auto patterns = patterns_t->flat<tstring>();
     size_t num_patterns = static_cast<size_t>(patterns.size());
     std::vector<string> pattern_strs;
     pattern_strs.reserve(num_patterns);
@@ -80,6 +81,8 @@
       return "MatchingFilesDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override { return Status::OK(); }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -125,7 +128,7 @@
                              current_path.first.end(), '/', '\\');
               }
 
-              filepath_tensor.scalar<string>()() =
+              filepath_tensor.scalar<tstring>()() =
                   std::move(current_path.first);
               out_tensors->emplace_back(std::move(filepath_tensor));
               *end_of_sequence = false;
@@ -373,5 +376,6 @@
     MatchingFilesDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc
index 9086e13..c08e8fe 100644
--- a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc
@@ -20,6 +20,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
@@ -68,11 +69,16 @@
       return "NonSerializableDatasetOp::Dataset";
     }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization.");
     }
 
     int64 Cardinality() const override { return input_->Cardinality(); }
@@ -130,5 +136,6 @@
     NonSerializableDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 2ce26f6..8c616aa 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -12,1063 +12,1095 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"
+
 #include <atomic>
 #include <deque>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
-#include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
 namespace data {
-namespace {
+namespace experimental {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kDatasetType;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kInputDataset;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kOtherArguments;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kCycleLength;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kBlockLength;
+/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kBufferOutputElements;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kPrefetchInputElements;
+/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kTarguments;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kOutputTypes;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kOutputShapes;
 
-class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
+constexpr char kInputExhausted[] = "input_exhausted";
+constexpr char kNextIndex[] = "next_index";
+constexpr char kBlockCount[] = "block_count";
+constexpr char kWorkersSize[] = "workers_size";
+constexpr char kInterleaveSize[] = "interleave_size";
+constexpr char kInterleaveIndices[] = "interleave_indices";
+constexpr char kStagingSize[] = "staging_size";
+constexpr char kStagingIndices[] = "staging_indices";
+constexpr char kWorkerThreadsRunning[] = "worker_threads_running";
+constexpr char kTFDataParallelInterleaveWorker[] =
+    "tf_data_parallel_interleave_worker";
+constexpr char kWorker[] = "worker";
+constexpr char kInputSize[] = "input_size";
+constexpr char kInput[] = "input";
+constexpr char kOutputsSize[] = "outputs_size";
+constexpr char kOutputs[] = "outputs";
+constexpr char kIsProducing[] = "is_producing";
+constexpr char kWorkerThread[] = "worker_thread";
+constexpr char kIteratorExhausted[] = "iterator_exhausted";
+constexpr char kIteratorCreationStatus[] = "iterator_creation_status";
+constexpr char kOutput[] = "output";
+constexpr char kEndOfSequence[] = "end_of_sequence";
+constexpr char kStatus[] = "status";
+constexpr char kOutputSize[] = "output_size";
+constexpr char kCode[] = "code";
+constexpr char KMessage[] = "msg";
+
+class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  public:
-  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
-      : UnaryDatasetOpKernel(ctx) {
-    FunctionMetadata::Params params;
-    params.is_multi_device_function = true;
-    OP_REQUIRES_OK(ctx,
-                   FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  Dataset(OpKernelContext* ctx, const DatasetBase* input,
+          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+          int64 block_length, bool sloppy, int64 buffer_output_elements,
+          int64 prefetch_input_elements, const DataTypeVector& output_types,
+          const std::vector<PartialTensorShape>& output_shapes)
+      : DatasetBase(DatasetContext(ctx)),
+        input_(input),
+        captured_func_(std::move(captured_func)),
+        cycle_length_(cycle_length),
+        block_length_(block_length),
+        sloppy_(sloppy),
+        buffer_output_elements_(buffer_output_elements),
+        prefetch_input_elements_(prefetch_input_elements),
+        output_types_(output_types),
+        output_shapes_(output_shapes) {
+    input_->Ref();
   }
 
-  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
-                   DatasetBase** output) override {
-    int64 cycle_length = 0;
-    OP_REQUIRES_OK(ctx,
-                   ParseScalarArgument(ctx, "cycle_length", &cycle_length));
-    OP_REQUIRES(ctx, cycle_length > 0,
-                errors::InvalidArgument("`cycle_length` must be > 0"));
+  ~Dataset() override { input_->Unref(); }
 
-    int64 block_length = 0;
-    OP_REQUIRES_OK(ctx,
-                   ParseScalarArgument(ctx, "block_length", &block_length));
-    OP_REQUIRES(ctx, block_length > 0,
-                errors::InvalidArgument("`block_length` must be > 0"));
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const override {
+    return absl::make_unique<Iterator>(Iterator::Params{
+        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
+  }
 
-    bool sloppy = false;
-    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
+  const DataTypeVector& output_dtypes() const override { return output_types_; }
 
-    int64 buffer_output_elements = 0;
-    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
-                                            &buffer_output_elements));
-    OP_REQUIRES(
-        ctx, buffer_output_elements > 0,
-        errors::InvalidArgument("`buffer_output_elements` must be > 0"));
+  const std::vector<PartialTensorShape>& output_shapes() const override {
+    return output_shapes_;
+  }
 
-    int64 prefetch_input_elements = 0;
-    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
-                                            &prefetch_input_elements));
-    OP_REQUIRES(
-        ctx, prefetch_input_elements >= 0,
-        errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
+  string DebugString() const override {
+    return name_utils::DatasetDebugString(kDatasetType);
+  }
 
-    std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(
-        ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments",
-                                      &captured_func));
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
 
-    *output =
-        new Dataset(ctx, input, std::move(captured_func), cycle_length,
-                    block_length, sloppy, buffer_output_elements,
-                    prefetch_input_elements, output_types_, output_shapes_);
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_node;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+    Node* cycle_length_node;
+    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+    Node* block_length_node;
+    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+    Node* sloppy_node;
+    TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
+    Node* buffer_output_elements_node;
+    TF_RETURN_IF_ERROR(
+        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
+    Node* prefetch_input_elements_node;
+    TF_RETURN_IF_ERROR(
+        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
+    std::vector<Node*> other_arguments;
+    DataTypeVector other_arguments_types;
+    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
+                                                  &other_arguments_types));
+    AttrValue f;
+    b->BuildAttrValue(captured_func_->func(), &f);
+    AttrValue other_arguments_types_attr;
+    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+    TF_RETURN_IF_ERROR(b->AddDataset(
+        this,
+        {{0, input_node},
+         {2, cycle_length_node},
+         {3, block_length_node},
+         {4, sloppy_node},
+         {5, buffer_output_elements_node},
+         {6, prefetch_input_elements_node}},
+        {{1, other_arguments}},
+        {{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
+    return Status::OK();
   }
 
  private:
-  class Dataset : public DatasetBase {
+  int64 num_threads() const { return cycle_length_ + prefetch_input_elements_; }
+
+  // Parallel interleave's implementation is designed around a few principles:
+  //  1. Thread creation is relatively expensive. (Not reusing
+  //     threads causes a number of indirect costs such as poorer tcmalloc
+  //     performance due to thread-local caches, etc.) We allocate a fixed
+  //     number of threads at the start and never change. This is why we've
+  //     fused functionality that is theoretically orthogonal (i.e.
+  //     .prefetch()) into the implementation.
+  //  2. Drop-in replacement for standard interleave. The goal will be to
+  //     auto-opt people into an optimized implementation without any work
+  //     on the customer's part. We thus go through great pains to maintain
+  //     identical iteration orders, full determinism (disabled only via a
+  //     flag, etc.)
+  //  3. Performance across a variety of environments and I/O envelopes.
+  //
+  // The actual implementation centers around a collection of worker threads
+  // and their corresponding worker state (tracked in the `workers_` vector).
+  // Worker threads repeatedly receive a vector of Tensors that are used as
+  // input to the flat-map function (`captured_func_`). The output of this
+  // function must be a dataset. The worker thread then repeatedly calls
+  // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
+  // that a caller will block waiting for an element to be produced.
+  //
+  // Pointers to these worker states are kept in 2 disjoint data structures:
+  //  1. `interleave_indices_` is a vector containing indices of WorkerStates
+  //     in `workers_` that we are interleaving. Worker threads backing these
+  //     WorkerStates should be regularly producing values.
+  //  2. `staging_indices_` is a deque containing indices of WorkerStates in
+  //     `workers_` that we will move to `interleave_indices_` when an
+  //     iterator in `interleave_indices_` is exhausted.
+  //
+  // The client calls `GetNext[Internal]()` to retrieve an output element. The
+  // internal implementation updates the state of `interleave_indices_` and
+  // `staging_indices_` as output iterators (run by the worker threads) are
+  // exhausted.
+  //
+  // `input_impl_` is the input iterator that generates arguments for the
+  // flat-map function (`captured_func_`). It is set to an iterator at
+  // Iterator construction, and is fixed until we consume all input elements.
+  // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
+  // memory.
+  //
+  // A few invariants are maintained:
+  //  1. No element in interleave_indices_ should be a -1 unless
+  //     `staging_indices_` is empty and `input_impl_` is empty.
+  //  2. Every `worker_` element is pointed to by at most one element of the
+  //     union of `interleave_indices_` and `staging_indices_`.
+  //  3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
+  //     an element in `interleave_indices_` or `staging_indices_`.
+  class Iterator : public DatasetIterator<Dataset> {
    public:
-    Dataset(OpKernelContext* ctx, const DatasetBase* input,
-            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
-            int64 block_length, bool sloppy, int64 buffer_output_elements,
-            int64 prefetch_input_elements, const DataTypeVector& output_types,
-            const std::vector<PartialTensorShape>& output_shapes)
-        : DatasetBase(DatasetContext(ctx)),
-          input_(input),
-          captured_func_(std::move(captured_func)),
-          cycle_length_(cycle_length),
-          block_length_(block_length),
-          sloppy_(sloppy),
-          buffer_output_elements_(buffer_output_elements),
-          prefetch_input_elements_(prefetch_input_elements),
-          output_types_(output_types),
-          output_shapes_(output_shapes) {
-      input_->Ref();
+    explicit Iterator(const Params& params)
+        : DatasetIterator<Dataset>(params),
+          workers_(dataset()->num_threads()),
+          worker_thread_states_(dataset()->num_threads()) {}
+
+    ~Iterator() override {
+      mutex_lock l(mu_);
+      cancelled_ = true;
+      // Notify all workers in case they are blocked.
+      for (auto& worker : workers_) {
+        worker.cond_var.notify_all();
+      }
     }
 
-    ~Dataset() override { input_->Unref(); }
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return absl::make_unique<Iterator>(Iterator::Params{
-          this, strings::StrCat(prefix, "::ParallelInterleave")});
+    Status Initialize(IteratorContext* ctx) override {
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_);
     }
 
-    const DataTypeVector& output_dtypes() const override {
-      return output_types_;
-    }
+    // It is implemented so that it matches the deterministic interleave
+    // unless getting the next element would block and we are allowed to be
+    // sloppy.
+    Status GetNextInternal(IteratorContext* ctx,
+                           std::vector<Tensor>* out_tensors,
+                           bool* end_of_sequence) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
+      while (!cancelled_) {
+        // Wait for an item to become available, blocking if necessary. If we
+        // are allowed to be sloppy, we can skip over input datasets that do
+        // not have an item readily available.
+        bool can_produce_elements = false;
+        bool must_wait_for_input = true;
+        for (int64 i = 0; i < interleave_indices_.size(); ++i) {
+          int64 index = (next_index_ + i) % interleave_indices_.size();
+          int64 current_worker_index = interleave_indices_[index];
+          if (current_worker_index < 0) {
+            continue;  // Empty interleave elements.
+          }
+          WorkerState* current_worker = &workers_[current_worker_index];
+          can_produce_elements |= current_worker->MayHaveElements();
+          if (!current_worker->outputs.empty()) {
+            // We have an element!
+            next_index_ = index;
+            const bool element_acquired_sloppily = dataset()->sloppy_ && i > 1;
+            if (!element_acquired_sloppily) {
+              // If the element was acquired in the regular (non-sloppy)
+              // order, then advance the current block and cycle pointers to
+              // the next element in the regular order.
+              block_count_++;
+              if (block_count_ == dataset()->block_length_) {
+                next_index_ = (index + 1) % interleave_indices_.size();
+                block_count_ = 0;
+              }
+            } else {
+              block_count_ = 0;
+            }
+            *end_of_sequence = false;
+            Status s = current_worker->outputs.front().status;
+            current_worker->outputs.front().output.swap(*out_tensors);
+            current_worker->outputs.pop_front();
+            current_worker->cond_var.notify_one();
+            return s;
+          } else if (current_worker->is_producing && !dataset()->sloppy_) {
+            // current_worker.outputs.empty(), and we must wait for this
+            // iterator.
+            if (next_index_ != index) {
+              // We have advanced to a new iterator; reset block counts.
+              next_index_ = index;
+              block_count_ = 0;
+            }
+            break;
+          } else if (!current_worker->is_producing) {
+            // This iterator has reached end of input.
+            interleave_indices_[index] = -1;
+            if (input_impl_) {
+              // Start prefetching a new iterator.
+              std::vector<Tensor> args;
+              bool end_of_input = false;
+              Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
+              if (end_of_input) {
+                input_impl_.reset();
+              } else {
+                current_worker->SetInputs(s, std::move(args));
+                staging_indices_.emplace_back(current_worker_index);
+              }
+            }
 
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      return output_shapes_;
-    }
+            if (!staging_indices_.empty()) {
+              // Move a worker from `staging_indices_` to
+              // `interleave_indices_`.
+              interleave_indices_[index] = staging_indices_.front();
+              staging_indices_.pop_front();
 
-    string DebugString() const override {
-      return "ParallelInterleaveDatasetOp::Dataset";
+              next_index_ = (index + 1) % interleave_indices_.size();
+              block_count_ = 0;
+              // Restart the inner [for] loop
+              can_produce_elements = true;
+              must_wait_for_input = false;
+              break;
+            }
+          }
+        }
+
+        if (!can_produce_elements && !input_impl_) {
+          // No potential for future values.
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+
+        if (must_wait_for_input) {
+          // Wait for elements to become available.
+          RecordStop(ctx);
+          if (dataset()->sloppy_) {
+            sloppy_cond_var_.wait(l);
+          } else {
+            workers_[interleave_indices_[next_index_]].cond_var.wait(l);
+          }
+          RecordStart(ctx);
+        }
+      }
+      return errors::Cancelled(
+          "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
     }
 
    protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* input_node;
-      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
-      Node* cycle_length_node;
-      TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
-      Node* block_length_node;
-      TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
-      Node* sloppy_node;
-      TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
-      Node* buffer_output_elements_node;
-      TF_RETURN_IF_ERROR(
-          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
-      Node* prefetch_input_elements_node;
-      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
-                                      &prefetch_input_elements_node));
-      std::vector<Node*> other_arguments;
-      DataTypeVector other_arguments_types;
-      TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
-                                                    &other_arguments_types));
-      AttrValue f;
-      b->BuildAttrValue(captured_func_->func(), &f);
-      AttrValue other_arguments_types_attr;
-      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+    std::shared_ptr<model::Node> CreateNode(
+        IteratorContext* ctx, model::Node::Args args) const override {
+      return model::MakeAsyncInterleaveManyNode(std::move(args),
+                                                /*parameters=*/{});
+    }
 
-      TF_RETURN_IF_ERROR(b->AddDataset(
-          this,
-          {{0, input_node},
-           {2, cycle_length_node},
-           {3, block_length_node},
-           {4, sloppy_node},
-           {5, buffer_output_elements_node},
-           {6, prefetch_input_elements_node}},
-          {{1, other_arguments}},
-          {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+    Status SaveInternal(IteratorStateWriter* writer) override {
+      // The order of locking is important here to avoid deadlock.
+      mutex_lock l(mu_);
+      mutex_lock ckpt_l(ckpt_mu_);
+      if (input_impl_) {
+        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+      } else {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputExhausted), ""));
+      }
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kNextIndex), next_index_));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kBlockCount), block_count_));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kWorkersSize), workers_.size()));
+      for (int i = 0; i < workers_.size(); ++i) {
+        TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
+      }
+      for (int i = 0; i < worker_thread_states_.size(); ++i) {
+        TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
+      }
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInterleaveSize),
+                                             interleave_indices_.size()));
+      for (int i = 0; i < interleave_indices_.size(); ++i) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(kInterleaveIndices, "_", i)),
+            interleave_indices_[i]));
+      }
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kStagingSize),
+                                             staging_indices_.size()));
+      for (int i = 0; i < staging_indices_.size(); ++i) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(kStagingIndices, "_", i)),
+            staging_indices_[i]));
+      }
+      if (!worker_threads_.empty()) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name(kWorkerThreadsRunning), ""));
+      }
+      return Status::OK();
+    }
+
+    Status RestoreInternal(IteratorContext* ctx,
+                           IteratorStateReader* reader) override {
+      // The order of locking is important here to avoid deadlock.
+      mutex_lock l(mu_);
+      mutex_lock ckpt_l(ckpt_mu_);
+      if (!reader->Contains(full_name(kInputExhausted))) {
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+      } else {
+        input_impl_.reset();
+      }
+      int64 temp;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNextIndex), &temp));
+      next_index_ = size_t(temp);
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBlockCount), &temp));
+      block_count_ = size_t(temp);
+
+      // Restore WorkerStates.
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kWorkersSize), &temp));
+      if (temp != dataset()->num_threads()) {
+        return errors::Internal("Expected ", dataset()->num_threads(),
+                                " worker states but found ", temp, ".");
+      }
+      for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+        TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
+      }
+      for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+        TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
+      }
+
+      // Restore `interleave_indices_`.
+      std::set<int64> all_indices;
+      {
+        int64 interleave_size;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name(kInterleaveSize), &interleave_size));
+        interleave_indices_.reserve(interleave_size);
+        for (int64 i = 0; i < interleave_size; ++i) {
+          int64 temp;
+          TF_RETURN_IF_ERROR(reader->ReadScalar(
+              full_name(strings::StrCat(kInterleaveIndices, "_", i)), &temp));
+          if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
+            return errors::Internal(
+                "Duplicate entry for ", temp,
+                " found when reading interleave and staging indices.");
+          }
+          if (temp >= 0) {
+            all_indices.insert(temp);
+          }
+          interleave_indices_.emplace_back(temp);
+        }
+      }
+
+      // Restore `staging_indices_`.
+      {
+        int64 staging_size;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name(kStagingSize), &staging_size));
+        for (int i = 0; i < staging_size; ++i) {
+          int64 temp;
+          TF_RETURN_IF_ERROR(reader->ReadScalar(
+              full_name(strings::StrCat(kStagingIndices, "_", i)), &temp));
+          if (all_indices.find(temp) != all_indices.end()) {
+            return errors::Internal(
+                "Duplicate entry for ", temp,
+                " found when reading interleave and staging indices.");
+          }
+          if (temp >= 0) {
+            all_indices.insert(temp);
+          }
+          staging_indices_.emplace_back(temp);
+        }
+      }
+
+      // Start Worker threads.
+      if (reader->Contains(full_name(kWorkerThreadsRunning))) {
+        worker_threads_.reserve(dataset()->num_threads());
+        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+          worker_threads_.emplace_back(ctx->StartThread(
+              strings::StrCat(kTFDataParallelInterleaveWorker, "_", i),
+              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+        }
+      }
       return Status::OK();
     }
 
    private:
-    int64 num_threads() const {
-      return cycle_length_ + prefetch_input_elements_;
-    }
+    // OutputElem contains the information from a call to GetNext by an output
+    // iterator.
+    struct OutputElem {
+      // The output iterator sets `status` if getting the output element
+      // fails.
+      Status status;
+      // The buffered data element.
+      std::vector<Tensor> output;
 
-    // Parallel interleave's implementation is designed around a few principles:
-    //  1. Thread creation is relatively expensive. (Not reusing
-    //     threads causes a number of indirect costs such as poorer tcmalloc
-    //     performance due to thread-local caches, etc.) We allocate a fixed
-    //     number of threads at the start and never change. This is why we've
-    //     fused functionality that is theoretically orthogonal (i.e.
-    //     .prefetch()) into the implementation.
-    //  2. Drop-in replacement for standard interleave. The goal will be to
-    //     auto-opt people into an optimized implementation without any work
-    //     on the customer's part. We thus go through great pains to maintain
-    //     identical iteration orders, full determinism (disabled only via a
-    //     flag, etc.)
-    //  3. Performance across a variety of environments and I/O envelopes.
-    //
-    // The actual implementation centers around a collection of worker threads
-    // and their corresponding worker state (tracked in the `workers_` vector).
-    // Worker threads repeatedly receive a vector of Tensors that are used as
-    // input to the flat-map function (`captured_func_`). The output of this
-    // function must be a dataset. The worker thread then repeatedly calls
-    // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
-    // that a caller will block waiting for an element to be produced.
-    //
-    // Pointers to these worker states are kept in 2 disjoint data structures:
-    //  1. `interleave_indices_` is a vector containing indices of WorkerStates
-    //     in `workers_` that we are interleaving. Worker threads backing these
-    //     WorkerStates should be regularly producing values.
-    //  2. `staging_indices_` is a deque containing indices of WorkerStates in
-    //     `workers_` that we will move to `interleave_indices_` when an
-    //     iterator in `interleave_indices_` is exhausted.
-    //
-    // The client calls `GetNext[Internal]()` to retrieve an output element. The
-    // internal implementation updates the state of `interleave_indices_` and
-    // `staging_indices_` as output iterators (run by the worker threads) are
-    // exhausted.
-    //
-    // `input_impl_` is the input iterator that generates arguments for the
-    // flat-map function (`captured_func_`). It is set to an iterator at
-    // Iterator construction, and is fixed until we consume all input elements.
-    // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
-    // memory.
-    //
-    // A few invariants are maintained:
-    //  1. No element in interleave_indices_ should be a -1 unless
-    //     `staging_indices_` is empty and `input_impl_` is empty.
-    //  2. Every `worker_` element is pointed to by at most one element of the
-    //     union of `interleave_indices_` and `staging_indices_`.
-    //  3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
-    //     an element in `interleave_indices_` or `staging_indices_`.
-    class Iterator : public DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            workers_(dataset()->num_threads()),
-            worker_thread_states_(dataset()->num_threads()) {}
+      explicit OutputElem(const Status& s) : status(s) {}
+    };
 
-      ~Iterator() override {
-        mutex_lock l(mu_);
-        cancelled_ = true;
-        // Notify all workers in case they are blocked.
-        for (auto& worker : workers_) {
-          worker.cond_var.notify_all();
+    // Worker threads operate on their relevant WorkerState structs.
+    //
+    // WorkerState's fields are all protected by mu_;
+    struct WorkerState {
+      // The arguments to be used to construct an output iterator.
+      std::vector<Tensor> input;
+      // The buffered output elements.
+      std::deque<OutputElem> outputs;
+      // Set to true iff the worker thread expects to append more elements to
+      // outputs. is_producing can be false despite !outputs.empty().
+      // Concretely, all output elements will have been consumed only when:
+      // is_producing == false && outputs.empty();
+      bool is_producing = false;
+      // Condition variable used to coordinate between threads. The worker
+      // thread waits on this condition variable when it is either (1) waiting
+      // for the main thread to add arguments to `input`, or (2) waiting for
+      // the main thread to consume an element of `outputs`. The main thread
+      // waits on cond_var if it is waiting for the worker thread to produce
+      // an element into `outputs` (this implies sloppy_==false).
+      condition_variable cond_var;
+
+      inline bool MayHaveElements() const {
+        return is_producing || !outputs.empty();
+      }
+
+      // Sets inputs for a worker thread and notifies it to start processing.
+      void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
+        if (s.ok()) {
+          DCHECK(!MayHaveElements())
+              << "Tried to start inputs, despite already producing!";
+          input = std::move(input_arguments);
+          is_producing = true;
+          cond_var.notify_one();
+        } else {
+          outputs.emplace_back(s);
         }
       }
+    };
 
-      Status Initialize(IteratorContext* ctx) override {
-        TF_RETURN_IF_ERROR(
-            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
-        return dataset()->captured_func_->Instantiate(
-            ctx, &instantiated_captured_func_);
-      }
+    // The internal state of a worker thread that is not already captured
+    // in its `WorkerState`.
+    //
+    // This is needed only for checkpointing purposes. We keep this
+    // separate from `WorkerState` and guard its fields using a separate
+    // lock `ckpt_mu_` so as to not affect the performance of main pipeline.
+    struct WorkerThreadState {
+      // The output element that has been produced from the input iterator
+      // and is waiting to be added to `WorkerState.outputs`.
+      OutputElem output_elem;
 
-      // It is implemented so that it matches the deterministic interleave
-      // unless getting the next element would block and we are allowed to be
-      // sloppy.
-      Status GetNextInternal(IteratorContext* ctx,
-                             std::vector<Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
-        while (!cancelled_) {
-          // Wait for an item to become available, blocking if necessary. If we
-          // are allowed to be sloppy, we can skip over input datasets that do
-          // not have an item readily available.
-          bool can_produce_elements = false;
-          bool must_wait_for_input = true;
-          for (int64 i = 0; i < interleave_indices_.size(); ++i) {
-            int64 index = (next_index_ + i) % interleave_indices_.size();
-            int64 current_worker_index = interleave_indices_[index];
-            if (current_worker_index < 0) {
-              continue;  // Empty interleave elements.
-            }
-            WorkerState* current_worker = &workers_[current_worker_index];
-            can_produce_elements |= current_worker->MayHaveElements();
-            if (!current_worker->outputs.empty()) {
-              // We have an element!
-              next_index_ = index;
-              const bool element_acquired_sloppily =
-                  dataset()->sloppy_ && i > 1;
-              if (!element_acquired_sloppily) {
-                // If the element was acquired in the regular (non-sloppy)
-                // order, then advance the current block and cycle pointers to
-                // the next element in the regular order.
-                block_count_++;
-                if (block_count_ == dataset()->block_length_) {
-                  next_index_ = (index + 1) % interleave_indices_.size();
-                  block_count_ = 0;
-                }
-              } else {
-                block_count_ = 0;
-              }
-              *end_of_sequence = false;
-              Status s = current_worker->outputs.front().status;
-              current_worker->outputs.front().output.swap(*out_tensors);
-              current_worker->outputs.pop_front();
-              current_worker->cond_var.notify_one();
-              return s;
-            } else if (current_worker->is_producing && !dataset()->sloppy_) {
-              // current_worker.outputs.empty(), and we must wait for this
-              // iterator.
-              if (next_index_ != index) {
-                // We have advanced to a new iterator; reset block counts.
-                next_index_ = index;
-                block_count_ = 0;
-              }
-              break;
-            } else if (!current_worker->is_producing) {
-              // This iterator has reached end of input.
-              interleave_indices_[index] = -1;
-              if (input_impl_) {
-                // Start prefetching a new iterator.
-                std::vector<Tensor> args;
-                bool end_of_input = false;
-                Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
-                if (end_of_input) {
-                  input_impl_.reset();
-                } else {
-                  current_worker->SetInputs(s, std::move(args));
-                  staging_indices_.emplace_back(current_worker_index);
-                }
-              }
+      // Whether the input iterator returned an `end_of_sequence`.
+      bool end_of_sequence = false;
 
-              if (!staging_indices_.empty()) {
-                // Move a worker from `staging_indices_` to
-                // `interleave_indices_`.
-                interleave_indices_[index] = staging_indices_.front();
-                staging_indices_.pop_front();
+      // Status returned from `MakeIteratorFromInputElement`.
+      Status iterator_creation_status;
 
-                next_index_ = (index + 1) % interleave_indices_.size();
-                block_count_ = 0;
-                // Restart the inner [for] loop
-                can_produce_elements = true;
-                must_wait_for_input = false;
-                break;
-              }
-            }
-          }
+      // The arguments to be used to construct `iterator`.
+      std::vector<Tensor> input;
 
-          if (!can_produce_elements && !input_impl_) {
-            // No potential for future values.
-            *end_of_sequence = true;
+      std::unique_ptr<IteratorBase> iterator;
+
+      WorkerThreadState() : output_elem(Status::OK()) {}
+    };
+
+    Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      if (worker_threads_.empty()) {
+        worker_threads_.reserve(dataset()->num_threads());
+        for (int64 i = 0; i < dataset()->num_threads(); ++i) {
+          std::vector<Tensor> args;
+          bool end_of_input = false;
+          Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
+          if (end_of_input) {
+            input_impl_.reset();
             return Status::OK();
           }
-
-          if (must_wait_for_input) {
-            // Wait for elements to become available.
-            RecordStop(ctx);
-            if (dataset()->sloppy_) {
-              sloppy_cond_var_.wait(l);
-            } else {
-              workers_[interleave_indices_[next_index_]].cond_var.wait(l);
-            }
-            RecordStart(ctx);
-          }
-        }
-        return errors::Cancelled(
-            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
-      }
-
-     protected:
-      std::shared_ptr<model::Node> CreateNode(
-          IteratorContext* ctx, model::Node::Args args) const override {
-        return model::MakeAsyncInterleaveManyNode(std::move(args),
-                                                  /*parameters=*/{});
-      }
-
-      Status SaveInternal(IteratorStateWriter* writer) override {
-        // The order of locking is important here to avoid deadlock.
-        mutex_lock l(mu_);
-        mutex_lock ckpt_l(ckpt_mu_);
-        if (input_impl_) {
-          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
-        } else {
-          TF_RETURN_IF_ERROR(
-              writer->WriteScalar(full_name("input_exhausted"), ""));
-        }
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name("next_index"), next_index_));
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name("block_count"), block_count_));
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name("workers_size"), workers_.size()));
-        for (int i = 0; i < workers_.size(); ++i) {
-          TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
-        }
-        for (int i = 0; i < worker_thread_states_.size(); ++i) {
-          TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
-        }
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
-                                               interleave_indices_.size()));
-        for (int i = 0; i < interleave_indices_.size(); ++i) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat("interleave_indices_", i)),
-              interleave_indices_[i]));
-        }
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
-                                               staging_indices_.size()));
-        for (int i = 0; i < staging_indices_.size(); ++i) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat("staging_indices_", i)),
-              staging_indices_[i]));
-        }
-        if (!worker_threads_.empty()) {
-          TF_RETURN_IF_ERROR(
-              writer->WriteScalar(full_name("worker_threads_running"), ""));
-        }
-        return Status::OK();
-      }
-
-      Status RestoreInternal(IteratorContext* ctx,
-                             IteratorStateReader* reader) override {
-        // The order of locking is important here to avoid deadlock.
-        mutex_lock l(mu_);
-        mutex_lock ckpt_l(ckpt_mu_);
-        if (!reader->Contains(full_name("input_exhausted"))) {
-          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
-        } else {
-          input_impl_.reset();
-        }
-        int64 temp;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
-        next_index_ = size_t(temp);
-        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
-        block_count_ = size_t(temp);
-
-        // Restore WorkerStates.
-        TF_RETURN_IF_ERROR(
-            reader->ReadScalar(full_name("workers_size"), &temp));
-        if (temp != dataset()->num_threads()) {
-          return errors::Internal("Expected ", dataset()->num_threads(),
-                                  " worker states but found ", temp, ".");
-        }
-        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
-          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
-        }
-        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
-          TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
-        }
-
-        // Restore `interleave_indices_`.
-        std::set<int64> all_indices;
-        {
-          int64 interleave_size;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
-                                                &interleave_size));
-          interleave_indices_.reserve(interleave_size);
-          for (int64 i = 0; i < interleave_size; ++i) {
-            int64 temp;
-            TF_RETURN_IF_ERROR(reader->ReadScalar(
-                full_name(strings::StrCat("interleave_indices_", i)), &temp));
-            if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
-              return errors::Internal(
-                  "Duplicate entry for ", temp,
-                  " found when reading interleave and staging indices.");
-            }
-            if (temp >= 0) {
-              all_indices.insert(temp);
-            }
-            interleave_indices_.emplace_back(temp);
-          }
-        }
-
-        // Restore `staging_indices_`.
-        {
-          int64 staging_size;
-          TF_RETURN_IF_ERROR(
-              reader->ReadScalar(full_name("staging_size"), &staging_size));
-          for (int i = 0; i < staging_size; ++i) {
-            int64 temp;
-            TF_RETURN_IF_ERROR(reader->ReadScalar(
-                full_name(strings::StrCat("staging_indices_", i)), &temp));
-            if (all_indices.find(temp) != all_indices.end()) {
-              return errors::Internal(
-                  "Duplicate entry for ", temp,
-                  " found when reading interleave and staging indices.");
-            }
-            if (temp >= 0) {
-              all_indices.insert(temp);
-            }
-            staging_indices_.emplace_back(temp);
-          }
-        }
-
-        // Start Worker threads.
-        if (reader->Contains(full_name("worker_threads_running"))) {
-          worker_threads_.reserve(dataset()->num_threads());
-          for (size_t i = 0; i < dataset()->num_threads(); ++i) {
-            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-            worker_threads_.emplace_back(ctx->StartThread(
-                strings::StrCat("tf_data_parallel_interleave_worker_", i),
-                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
-          }
-        }
-        return Status::OK();
-      }
-
-     private:
-      // OutputElem contains the information from a call to GetNext by an output
-      // iterator.
-      struct OutputElem {
-        // The output iterator sets `status` if getting the output element
-        // fails.
-        Status status;
-        // The buffered data element.
-        std::vector<Tensor> output;
-
-        explicit OutputElem(const Status& s) : status(s) {}
-      };
-
-      // Worker threads operate on their relevant WorkerState structs.
-      //
-      // WorkerState's fields are all protected by mu_;
-      struct WorkerState {
-        // The arguments to be used to construct an output iterator.
-        std::vector<Tensor> input;
-        // The buffered output elements.
-        std::deque<OutputElem> outputs;
-        // Set to true iff the worker thread expects to append more elements to
-        // outputs. is_producing can be false despite !outputs.empty().
-        // Concretely, all output elements will have been consumed only when:
-        // is_producing == false && outputs.empty();
-        bool is_producing = false;
-        // Condition variable used to coordinate between threads. The worker
-        // thread waits on this condition variable when it is either (1) waiting
-        // for the main thread to add arguments to `input`, or (2) waiting for
-        // the main thread to consume an element of `outputs`. The main thread
-        // waits on cond_var if it is waiting for the worker thread to produce
-        // an element into `outputs` (this implies sloppy_==false).
-        condition_variable cond_var;
-
-        inline bool MayHaveElements() const {
-          return is_producing || !outputs.empty();
-        }
-
-        // Sets inputs for a worker thread and notifies it to start processing.
-        void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
-          if (s.ok()) {
-            DCHECK(!MayHaveElements())
-                << "Tried to start inputs, despite already producing!";
-            input = std::move(input_arguments);
-            is_producing = true;
-            cond_var.notify_one();
+          workers_[i].SetInputs(s, std::move(args));
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+          worker_threads_.push_back(ctx->StartThread(
+              strings::StrCat(kTFDataParallelInterleaveWorker, "_", i),
+              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+          if (i < dataset()->cycle_length_) {
+            interleave_indices_.push_back(i);
           } else {
-            outputs.emplace_back(s);
+            staging_indices_.push_back(i);
           }
         }
-      };
+        DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
+        DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_);
+      }
+      return Status::OK();
+    }
 
-      // The internal state of a worker thread that is not already captured
-      // in its `WorkerState`.
+    // Produces elements into the worker's output buffers.
+    void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+                      const int64 thread_index) {
+      // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
       //
-      // This is needed only for checkpointing purposes. We keep this
-      // separate from `WorkerState` and guard its fields using a separate
-      // lock `ckpt_mu_` so as to not affect the performance of main pipeline.
-      struct WorkerThreadState {
-        // The output element that has been produced from the input iterator
-        // and is waiting to be added to `WorkerState.outputs`.
-        OutputElem output_elem;
+      // 1. Any local state that may need to be checkpointed should be kept
+      //    in `worker_thread_states_[thread_index]`.
+      // 2. `WorkerThreadState` should contain state that is needed only for
+      //    checkpointing, i.e., if we were to remove checkpointing support,
+      //    we could keep that state as local variables in this thread.
+      // 3. This thread should only read/write state at `thread_index`
+      //    and should not access other thread states.
+      // 4. When restoring from checkpoint, threads are started only after
+      //    the restore is complete.
+      // 5. Once restored from a checkpoint, the local state is edited only
+      //    by this thread. 3 & 4 allow making assumptions like temporarily
+      //    caching local state in this thread and using it outside a lock
+      //    e.g. `make_new_iterator`.
+      // 6. `ckpt_mu_` should be wisely used to create *consistent*
+      //    checkpoint markers.
 
-        // Whether the input iterator returned an `end_of_sequence`.
-        bool end_of_sequence = false;
-
-        // Status returned from `MakeIteratorFromInputElement`.
+      // std::function arguments are copy-constructable, so we pass raw
+      // pointers, and then immediately wrap them to ensure correct ownership.
+      RecordStart(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
+        mutex_lock l(mu_);
+        workers_[thread_index].cond_var.notify_all();
+        RecordStop(ctx.get());
+      });
+      bool make_new_iterator;
+      {
+        tf_shared_lock l(ckpt_mu_);
+        // Decide whether a new iterator should be built.
+        // 1. If there is an existing iterator, we use it.
+        // 2. If there was an error in iterator creation that could not be
+        //    notified to the client we attempt to send that to the client
+        //    first.
+        make_new_iterator =
+            worker_thread_states_[thread_index].iterator == nullptr &&
+            worker_thread_states_[thread_index].iterator_creation_status.ok();
+      }
+      // Even though `make_new_iterator` has cached values from
+      // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
+      // it is safe to *read* `make_new_iterator`outside of a lock without
+      // worrying about concurrent changes to values in
+      // `worker_thread_states_[thread_index]`. See comment at the start of
+      // this function for details.
+      while (true) {
+        // Whether creation of the iterator succeeded.
         Status iterator_creation_status;
-
-        // The arguments to be used to construct `iterator`.
-        std::vector<Tensor> input;
-
-        std::unique_ptr<IteratorBase> iterator;
-
-        WorkerThreadState() : output_elem(Status::OK()) {}
-      };
-
-      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (worker_threads_.empty()) {
-          worker_threads_.reserve(dataset()->num_threads());
-          for (int64 i = 0; i < dataset()->num_threads(); ++i) {
-            std::vector<Tensor> args;
-            bool end_of_input = false;
-            Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
-            if (end_of_input) {
-              input_impl_.reset();
-              return Status::OK();
-            }
-            workers_[i].SetInputs(s, std::move(args));
-            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-            worker_threads_.push_back(ctx->StartThread(
-                strings::StrCat("tf_data_parallel_interleave_worker_", i),
-                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
-            if (i < dataset()->cycle_length_) {
-              interleave_indices_.push_back(i);
-            } else {
-              staging_indices_.push_back(i);
-            }
-          }
-          DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
-          DCHECK(staging_indices_.size() ==
-                 dataset()->prefetch_input_elements_);
-        }
-        return Status::OK();
-      }
-
-      // Produces elements into the worker's output buffers.
-      void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
-                        const int64 thread_index) {
-        // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
-        //
-        // 1. Any local state that may need to be checkpointed should be kept
-        //    in `worker_thread_states_[thread_index]`.
-        // 2. `WorkerThreadState` should contain state that is needed only for
-        //    checkpointing, i.e., if we were to remove checkpointing support,
-        //    we could keep that state as local variables in this thread.
-        // 3. This thread should only read/write state at `thread_index`
-        //    and should not access other thread states.
-        // 4. When restoring from checkpoint, threads are started only after
-        //    the restore is complete.
-        // 5. Once restored from a checkpoint, the local state is edited only
-        //    by this thread. 3 & 4 allow making assumptions like temporarily
-        //    caching local state in this thread and using it outside a lock
-        //    e.g. `make_new_iterator`.
-        // 6. `ckpt_mu_` should be wisely used to create *consistent*
-        //    checkpoint markers.
-
-        // std::function arguments are copy-constructable, so we pass raw
-        // pointers, and then immediately wrap them to ensure correct ownership.
-        RecordStart(ctx.get());
-        auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
-          mutex_lock l(mu_);
-          workers_[thread_index].cond_var.notify_all();
-          RecordStop(ctx.get());
-        });
-        bool make_new_iterator;
-        {
-          tf_shared_lock l(ckpt_mu_);
-          // Decide whether a new iterator should be built.
-          // 1. If there is an existing iterator, we use it.
-          // 2. If there was an error in iterator creation that could not be
-          //    notified to the client we attempt to send that to the client
-          //    first.
-          make_new_iterator =
-              worker_thread_states_[thread_index].iterator == nullptr &&
-              worker_thread_states_[thread_index].iterator_creation_status.ok();
-        }
-        // Even though `make_new_iterator` has cached values from
-        // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
-        // it is safe to *read* `make_new_iterator`outside of a lock without
-        // worrying about concurrent changes to values in
-        // `worker_thread_states_[thread_index]`. See comment at the start of
-        // this function for details.
-        while (true) {
-          // Whether creation of the iterator succeeded.
-          Status iterator_creation_status;
-          // 1. Build a new iterator or use the existing one.
-          if (make_new_iterator) {
-            // 1a. Get new input tensors or use the exiting ones.
-            bool read_new_input;
-            {
-              tf_shared_lock l(ckpt_mu_);
-              // worker_thread_states_[thread_index].input will be non-empty
-              // if checkpointing happened at CHECKPOINT_MARKER_A.
-              read_new_input =
-                  worker_thread_states_[thread_index].input.empty();
-            }
-
-            if (read_new_input) {
-              mutex_lock l(mu_);
-              while (!cancelled_ && !workers_[thread_index].is_producing) {
-                RecordStop(ctx.get());
-                workers_[thread_index].cond_var.wait(l);
-                RecordStart(ctx.get());
-              }
-              if (cancelled_) return;
-              // Copy the input tensors so that we do not need to block on `mu_`
-              // when building the iterator.
-              // We keep a copy of the input tensors in
-              // `WorkerThreadState.input` till the iterator is in use. This is
-              // used in `RestoreInternal` to re-build the iterator.
-              // TODO(b/78046638): Explore ways to avoid tracking the input
-              // tensors.
-              tf_shared_lock ckpt_l(ckpt_mu_);
-              worker_thread_states_[thread_index].input.swap(
-                  workers_[thread_index].input);
-              // CHECKPOINT_MARKER_A
-              // We have the input tensors but have not built the iterator yet.
-            }
-
-            // 1b. Run the user defined function to produce a new iterator.
-            {
-              tf_shared_lock l(ckpt_mu_);
-              worker_thread_states_[thread_index].iterator_creation_status =
-                  MakeIteratorFromInputElement(
-                      ctx.get(), worker_thread_states_[thread_index].input,
-                      thread_index, *instantiated_captured_func_, prefix(),
-                      &worker_thread_states_[thread_index].iterator);
-              iterator_creation_status =
-                  worker_thread_states_[thread_index].iterator_creation_status;
-              if (!iterator_creation_status.ok()) {
-                worker_thread_states_[thread_index].input.clear();
-              }
-              // CHECKPOINT_MARKER_B
-              // Either an iterator has been successfully built and placed in
-              // `worker_thread_states_[thread_index].iterator` or it failed and
-              // a non-OK status has been put in
-              // `worker_thread_states_[thread_index].iterator_creation_status`.
-            }
-          } else {
+        // 1. Build a new iterator or use the existing one.
+        if (make_new_iterator) {
+          // 1a. Get new input tensors or use the exiting ones.
+          bool read_new_input;
+          {
             tf_shared_lock l(ckpt_mu_);
-            iterator_creation_status =
-                worker_thread_states_[thread_index].iterator_creation_status;
-            // Mark that we have used up the restored iterator.
-            make_new_iterator = true;
+            // worker_thread_states_[thread_index].input will be non-empty
+            // if checkpointing happened at CHECKPOINT_MARKER_A.
+            read_new_input = worker_thread_states_[thread_index].input.empty();
           }
-          // 2. Start producing elements or send error state to client if
-          //    iterator creation failed.
-          if (!iterator_creation_status.ok()) {
+
+          if (read_new_input) {
             mutex_lock l(mu_);
-            // Wait for space in the prefetch queue.
-            while (!cancelled_ && workers_[thread_index].outputs.size() ==
-                                      dataset()->buffer_output_elements_) {
+            while (!cancelled_ && !workers_[thread_index].is_producing) {
               RecordStop(ctx.get());
               workers_[thread_index].cond_var.wait(l);
               RecordStart(ctx.get());
             }
             if (cancelled_) return;
+            // Copy the input tensors so that we do not need to block on `mu_`
+            // when building the iterator.
+            // We keep a copy of the input tensors in
+            // `WorkerThreadState.input` till the iterator is in use. This is
+            // used in `RestoreInternal` to re-build the iterator.
+            // TODO(b/78046638): Explore ways to avoid tracking the input
+            // tensors.
             tf_shared_lock ckpt_l(ckpt_mu_);
-            workers_[thread_index].outputs.emplace_back(
-                iterator_creation_status);
-            workers_[thread_index].is_producing = false;
+            worker_thread_states_[thread_index].input.swap(
+                workers_[thread_index].input);
+            // CHECKPOINT_MARKER_A
+            // We have the input tensors but have not built the iterator yet.
+          }
+
+          // 1b. Run the user defined function to produce a new iterator.
+          {
+            tf_shared_lock l(ckpt_mu_);
             worker_thread_states_[thread_index].iterator_creation_status =
-                Status::OK();
-            // CHECKPOINT_MARKER_C
-            // Non-OK iterator creation status has been notified to the
-            // client.
-            workers_[thread_index].cond_var.notify_one();
-          } else {
-            bool end_of_sequence = false;
-            while (!end_of_sequence) {
-              // 3.a Produce an element!
-              {
-                tf_shared_lock ckpt_l(ckpt_mu_);
-                if (worker_thread_states_[thread_index]
-                        .output_elem.status.ok() &&
-                    worker_thread_states_[thread_index]
-                        .output_elem.output.empty() &&
-                    !worker_thread_states_[thread_index].end_of_sequence) {
-                  worker_thread_states_[thread_index].output_elem.status =
-                      worker_thread_states_[thread_index].iterator->GetNext(
-                          ctx.get(),
-                          &worker_thread_states_[thread_index]
-                               .output_elem.output,
-                          &worker_thread_states_[thread_index].end_of_sequence);
-                  end_of_sequence =
-                      worker_thread_states_[thread_index].end_of_sequence;
-                } else {
-                  end_of_sequence =
-                      worker_thread_states_[thread_index].end_of_sequence;
-                }
-                // CHECKPOINT_MARKER_D
-                // An element has been read or an error or end_of_sequence has
-                // been received from the input iterator and is waiting to be
-                // sent to client.
-              }
-
-              // 3.b Make it available to the client.
-              {
-                mutex_lock l(mu_);
-
-                // Wait for space in the prefetch queue.
-                while (!cancelled_ && workers_[thread_index].outputs.size() ==
-                                          dataset()->buffer_output_elements_) {
-                  RecordStop(ctx.get());
-                  workers_[thread_index].cond_var.wait(l);
-                  RecordStart(ctx.get());
-                }
-                if (cancelled_) return;
-
-                tf_shared_lock ckpt_l(ckpt_mu_);
-                workers_[thread_index].is_producing = !end_of_sequence;
-
-                // Output the element.
-
-                // Move the temporary state in WorkerThreadState to WorkerState
-                // and mark it as used.
-                if (end_of_sequence) {
-                  worker_thread_states_[thread_index].iterator.reset();
-                  worker_thread_states_[thread_index].input.clear();
-                  worker_thread_states_[thread_index].end_of_sequence = false;
-                } else {
-                  workers_[thread_index].outputs.emplace_back(
-                      worker_thread_states_[thread_index].output_elem.status);
-                  workers_[thread_index].outputs.back().output.swap(
-                      worker_thread_states_[thread_index].output_elem.output);
-                }
+                MakeIteratorFromInputElement(
+                    ctx.get(), worker_thread_states_[thread_index].input,
+                    thread_index, *instantiated_captured_func_, prefix(),
+                    &worker_thread_states_[thread_index].iterator);
+            iterator_creation_status =
+                worker_thread_states_[thread_index].iterator_creation_status;
+            if (!iterator_creation_status.ok()) {
+              worker_thread_states_[thread_index].input.clear();
+            }
+            // CHECKPOINT_MARKER_B
+            // Either an iterator has been successfully built and placed in
+            // `worker_thread_states_[thread_index].iterator` or it failed and
+            // a non-OK status has been put in
+            // `worker_thread_states_[thread_index].iterator_creation_status`.
+          }
+        } else {
+          tf_shared_lock l(ckpt_mu_);
+          iterator_creation_status =
+              worker_thread_states_[thread_index].iterator_creation_status;
+          // Mark that we have used up the restored iterator.
+          make_new_iterator = true;
+        }
+        // 2. Start producing elements or send error state to client if
+        //    iterator creation failed.
+        if (!iterator_creation_status.ok()) {
+          mutex_lock l(mu_);
+          // Wait for space in the prefetch queue.
+          while (!cancelled_ && workers_[thread_index].outputs.size() ==
+                                    dataset()->buffer_output_elements_) {
+            RecordStop(ctx.get());
+            workers_[thread_index].cond_var.wait(l);
+            RecordStart(ctx.get());
+          }
+          if (cancelled_) return;
+          tf_shared_lock ckpt_l(ckpt_mu_);
+          workers_[thread_index].outputs.emplace_back(iterator_creation_status);
+          workers_[thread_index].is_producing = false;
+          worker_thread_states_[thread_index].iterator_creation_status =
+              Status::OK();
+          // CHECKPOINT_MARKER_C
+          // Non-OK iterator creation status has been notified to the
+          // client.
+          workers_[thread_index].cond_var.notify_one();
+        } else {
+          bool end_of_sequence = false;
+          while (!end_of_sequence) {
+            // 3.a Produce an element!
+            {
+              tf_shared_lock ckpt_l(ckpt_mu_);
+              if (worker_thread_states_[thread_index].output_elem.status.ok() &&
+                  worker_thread_states_[thread_index]
+                      .output_elem.output.empty() &&
+                  !worker_thread_states_[thread_index].end_of_sequence) {
                 worker_thread_states_[thread_index].output_elem.status =
-                    Status::OK();
-                if (dataset()->sloppy_) {
-                  sloppy_cond_var_.notify_one();
-                } else {
-                  workers_[thread_index].cond_var.notify_one();
-                }
-                // CHECKPOINT_MARKER_E
-                // Output element or iterator status has been sent to the
-                // client.
+                    worker_thread_states_[thread_index].iterator->GetNext(
+                        ctx.get(),
+                        &worker_thread_states_[thread_index].output_elem.output,
+                        &worker_thread_states_[thread_index].end_of_sequence);
+                end_of_sequence =
+                    worker_thread_states_[thread_index].end_of_sequence;
+              } else {
+                end_of_sequence =
+                    worker_thread_states_[thread_index].end_of_sequence;
               }
+              // CHECKPOINT_MARKER_D
+              // An element has been read or an error or end_of_sequence has
+              // been received from the input iterator and is waiting to be
+              // sent to client.
+            }
+
+            // 3.b Make it available to the client.
+            {
+              mutex_lock l(mu_);
+
+              // Wait for space in the prefetch queue.
+              while (!cancelled_ && workers_[thread_index].outputs.size() ==
+                                        dataset()->buffer_output_elements_) {
+                RecordStop(ctx.get());
+                workers_[thread_index].cond_var.wait(l);
+                RecordStart(ctx.get());
+              }
+              if (cancelled_) return;
+
+              tf_shared_lock ckpt_l(ckpt_mu_);
+              workers_[thread_index].is_producing = !end_of_sequence;
+
+              // Output the element.
+
+              // Move the temporary state in WorkerThreadState to WorkerState
+              // and mark it as used.
+              if (end_of_sequence) {
+                worker_thread_states_[thread_index].iterator.reset();
+                worker_thread_states_[thread_index].input.clear();
+                worker_thread_states_[thread_index].end_of_sequence = false;
+              } else {
+                workers_[thread_index].outputs.emplace_back(
+                    worker_thread_states_[thread_index].output_elem.status);
+                workers_[thread_index].outputs.back().output.swap(
+                    worker_thread_states_[thread_index].output_elem.output);
+              }
+              worker_thread_states_[thread_index].output_elem.status =
+                  Status::OK();
+              if (dataset()->sloppy_) {
+                sloppy_cond_var_.notify_one();
+              } else {
+                workers_[thread_index].cond_var.notify_one();
+              }
+              // CHECKPOINT_MARKER_E
+              // Output element or iterator status has been sent to the
+              // client.
             }
           }
         }
       }
+    }
 
-      Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        string prefix = strings::StrCat("worker_", index);
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(strings::StrCat(prefix, "_input_size")),
-            workers_[index].input.size()));
-        for (int i = 0; i < workers_[index].input.size(); ++i) {
-          TF_RETURN_IF_ERROR(writer->WriteTensor(
-              full_name(strings::StrCat(prefix, "_input_", i)),
-              workers_[index].input[i]));
-        }
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(strings::StrCat(prefix, "_outputs_size")),
-            workers_[index].outputs.size()));
-        for (int i = 0; i < workers_[index].outputs.size(); ++i) {
-          TF_RETURN_IF_ERROR(WriteOutputElemLocked(
-              writer, workers_[index].outputs[i],
-              full_name(strings::StrCat(prefix, "_outputs_", i))));
-        }
-        if (workers_[index].is_producing) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat(prefix, "_is_producing")), ""));
-        }
-        return Status::OK();
+    Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      string prefix = strings::StrCat(kWorker, "_", index);
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(strings::StrCat(prefix, "_", kInputSize)),
+          workers_[index].input.size()));
+      for (int i = 0; i < workers_[index].input.size(); ++i) {
+        TF_RETURN_IF_ERROR(writer->WriteTensor(
+            full_name(strings::StrCat(prefix, "_", kInput, "_", i)),
+            workers_[index].input[i]));
       }
-
-      Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
-                                   IteratorContext* ctx)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        string worker_prefix = strings::StrCat("worker_", index);
-        // Restore inputs.
-        int64 input_size;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(
-            full_name(strings::StrCat(worker_prefix, "_input_size")),
-            &input_size));
-        workers_[index].input.reserve(input_size);
-        for (int i = 0; i < input_size; ++i) {
-          workers_[index].input.emplace_back();
-          TF_RETURN_IF_ERROR(reader->ReadTensor(
-              full_name(strings::StrCat(worker_prefix, "_input_", i)),
-              &workers_[index].input.back()));
-        }
-        int64 outputs_size;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(
-            full_name(strings::StrCat(worker_prefix, "_outputs_size")),
-            &outputs_size));
-        for (int i = 0; i < outputs_size; ++i) {
-          workers_[index].outputs.emplace_back(Status::OK());
-          TF_RETURN_IF_ERROR(ReadOutputElemLocked(
-              reader, &workers_[index].outputs.back(),
-              full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
-        }
-        if (reader->Contains(
-                full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
-          workers_[index].is_producing = true;
-        } else {
-          workers_[index].is_producing = false;
-        }
-        return Status::OK();
-      }
-
-      Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
-                                          int index)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        string prefix = strings::StrCat("worker_thread_", index);
-        if (worker_thread_states_[index].iterator != nullptr) {
-          TF_RETURN_IF_ERROR(
-              SaveInput(writer, worker_thread_states_[index].iterator));
-        } else {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
-        }
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(strings::StrCat(prefix, "_input_size")),
-            worker_thread_states_[index].input.size()));
-        for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
-          TF_RETURN_IF_ERROR(writer->WriteTensor(
-              full_name(strings::StrCat(prefix, "_input_", i)),
-              worker_thread_states_[index].input[i]));
-        }
-        TF_RETURN_IF_ERROR(WriteStatusLocked(
-            writer, strings::StrCat(prefix, "_iterator_creation_status"),
-            worker_thread_states_[index].iterator_creation_status));
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(strings::StrCat(prefix, "_", kOutputsSize)),
+          workers_[index].outputs.size()));
+      for (int i = 0; i < workers_[index].outputs.size(); ++i) {
         TF_RETURN_IF_ERROR(WriteOutputElemLocked(
-            writer, worker_thread_states_[index].output_elem,
-            full_name(strings::StrCat(prefix, "_output"))));
-        if (worker_thread_states_[index].end_of_sequence) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
-        }
-        return Status::OK();
+            writer, workers_[index].outputs[i],
+            full_name(strings::StrCat(prefix, "_", kOutputs, "_", i))));
       }
+      if (workers_[index].is_producing) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(prefix, "_", kIsProducing)), ""));
+      }
+      return Status::OK();
+    }
 
-      Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
-                                         IteratorContext* ctx)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        string worker_prefix = strings::StrCat("worker_thread_", index);
-        // Restore inputs.
-        int64 input_size;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(
-            full_name(strings::StrCat(worker_prefix, "_input_size")),
-            &input_size));
-        worker_thread_states_[index].input.reserve(input_size);
-        for (int i = 0; i < input_size; ++i) {
-          worker_thread_states_[index].input.emplace_back();
-          TF_RETURN_IF_ERROR(reader->ReadTensor(
-              full_name(strings::StrCat(worker_prefix, "_input_", i)),
-              &worker_thread_states_[index].input.back()));
-        }
-        // Restore iterator.
-        if (reader->Contains(full_name(
-                strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
-          worker_thread_states_[index].iterator.reset();
-        } else {
-          std::unique_ptr<IteratorBase> iterator;
-          Status s = MakeIteratorFromInputElement(
-              ctx, worker_thread_states_[index].input, index,
-              *instantiated_captured_func_, prefix(), &iterator);
-          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
-          worker_thread_states_[index].iterator.swap(iterator);
-        }
-        TF_RETURN_IF_ERROR(ReadStatusLocked(
-            reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
-            &worker_thread_states_[index].iterator_creation_status));
+    Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
+                                 IteratorContext* ctx)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      string worker_prefix = strings::StrCat(kWorker, "_", index);
+      // Restore inputs.
+      int64 input_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(worker_prefix, "_", kInputSize)),
+          &input_size));
+      workers_[index].input.reserve(input_size);
+      for (int i = 0; i < input_size; ++i) {
+        workers_[index].input.emplace_back();
+        TF_RETURN_IF_ERROR(reader->ReadTensor(
+            full_name(strings::StrCat(worker_prefix, "_", kInput, "_", i)),
+            &workers_[index].input.back()));
+      }
+      int64 outputs_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(worker_prefix, "_", kOutputsSize)),
+          &outputs_size));
+      for (int i = 0; i < outputs_size; ++i) {
+        workers_[index].outputs.emplace_back(Status::OK());
         TF_RETURN_IF_ERROR(ReadOutputElemLocked(
-            reader, &worker_thread_states_[index].output_elem,
-            full_name(strings::StrCat(worker_prefix, "_output"))));
-        if (reader->Contains(full_name(
-                strings::StrCat(worker_prefix, "_end_of_sequence")))) {
-          worker_thread_states_[index].end_of_sequence = true;
-        } else {
-          worker_thread_states_[index].end_of_sequence = false;
-        }
-        return Status::OK();
+            reader, &workers_[index].outputs.back(),
+            full_name(strings::StrCat(worker_prefix, "_", kOutputs, "_", i))));
       }
+      if (reader->Contains(
+              full_name(strings::StrCat(worker_prefix, "_", kIsProducing)))) {
+        workers_[index].is_producing = true;
+      } else {
+        workers_[index].is_producing = false;
+      }
+      return Status::OK();
+    }
 
-      Status WriteOutputElemLocked(IteratorStateWriter* writer,
-                                   const OutputElem& output_elem,
-                                   const string& prefix)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        TF_RETURN_IF_ERROR(WriteStatusLocked(
-            writer, strings::StrCat(prefix, "_status"), output_elem.status));
+    Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer, int index)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      string prefix = strings::StrCat(kWorkerThread, "_", index);
+      if (worker_thread_states_[index].iterator != nullptr) {
         TF_RETURN_IF_ERROR(
-            writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
-                                output_elem.output.size()));
-        for (int i = 0; i < output_elem.output.size(); ++i) {
-          TF_RETURN_IF_ERROR(writer->WriteTensor(
-              strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
-        }
-        return Status::OK();
+            SaveInput(writer, worker_thread_states_[index].iterator));
+      } else {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(prefix, "_", kIteratorExhausted)), ""));
       }
-
-      Status ReadOutputElemLocked(IteratorStateReader* reader,
-                                  OutputElem* output_elem, const string& prefix)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        TF_RETURN_IF_ERROR(ReadStatusLocked(
-            reader, strings::StrCat(prefix, "_status"), &output_elem->status));
-        int64 output_size;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(
-            strings::StrCat(prefix, "_output_size"), &output_size));
-        output_elem->output.reserve(output_size);
-        for (int i = 0; i < output_size; ++i) {
-          output_elem->output.emplace_back();
-          TF_RETURN_IF_ERROR(
-              reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
-                                 &output_elem->output.back()));
-        }
-        return Status::OK();
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(strings::StrCat(prefix, "_", kInputSize)),
+          worker_thread_states_[index].input.size()));
+      for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
+        TF_RETURN_IF_ERROR(writer->WriteTensor(
+            full_name(strings::StrCat(prefix, "_", kInput, "_", i)),
+            worker_thread_states_[index].input[i]));
       }
+      TF_RETURN_IF_ERROR(WriteStatusLocked(
+          writer, strings::StrCat(prefix, "_", kIteratorCreationStatus),
+          worker_thread_states_[index].iterator_creation_status));
+      TF_RETURN_IF_ERROR(WriteOutputElemLocked(
+          writer, worker_thread_states_[index].output_elem,
+          full_name(strings::StrCat(prefix, "_", kOutput))));
+      if (worker_thread_states_[index].end_of_sequence) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(prefix, "_", kEndOfSequence)), ""));
+      }
+      return Status::OK();
+    }
 
-      Status WriteStatusLocked(IteratorStateWriter* writer,
-                               const string& prefix, const Status& status)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+    Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
+                                       IteratorContext* ctx)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      string worker_prefix = strings::StrCat(kWorkerThread, "_", index);
+      // Restore inputs.
+      int64 input_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(worker_prefix, "_", kInputSize)),
+          &input_size));
+      worker_thread_states_[index].input.reserve(input_size);
+      for (int i = 0; i < input_size; ++i) {
+        worker_thread_states_[index].input.emplace_back();
+        TF_RETURN_IF_ERROR(reader->ReadTensor(
+            full_name(strings::StrCat(worker_prefix, "_", kInput, "_", i)),
+            &worker_thread_states_[index].input.back()));
+      }
+      // Restore iterator.
+      if (reader->Contains(full_name(
+              strings::StrCat(worker_prefix, "_", kIteratorExhausted)))) {
+        worker_thread_states_[index].iterator.reset();
+      } else {
+        std::unique_ptr<IteratorBase> iterator;
+        Status s = MakeIteratorFromInputElement(
+            ctx, worker_thread_states_[index].input, index,
+            *instantiated_captured_func_, prefix(), &iterator);
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
+        worker_thread_states_[index].iterator.swap(iterator);
+      }
+      TF_RETURN_IF_ERROR(ReadStatusLocked(
+          reader, strings::StrCat(worker_prefix, "_", kIteratorCreationStatus),
+          &worker_thread_states_[index].iterator_creation_status));
+      TF_RETURN_IF_ERROR(ReadOutputElemLocked(
+          reader, &worker_thread_states_[index].output_elem,
+          full_name(strings::StrCat(worker_prefix, "_", kOutput))));
+      if (reader->Contains(
+              full_name(strings::StrCat(worker_prefix, "_", kEndOfSequence)))) {
+        worker_thread_states_[index].end_of_sequence = true;
+      } else {
+        worker_thread_states_[index].end_of_sequence = false;
+      }
+      return Status::OK();
+    }
+
+    Status WriteOutputElemLocked(IteratorStateWriter* writer,
+                                 const OutputElem& output_elem,
+                                 const string& prefix)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      TF_RETURN_IF_ERROR(WriteStatusLocked(
+          writer, strings::StrCat(prefix, "_", kStatus), output_elem.status));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(strings::StrCat(prefix, "_", kOutputSize),
+                              output_elem.output.size()));
+      for (int i = 0; i < output_elem.output.size(); ++i) {
         TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
-                                static_cast<int64>(status.code())));
-        if (!status.ok()) {
-          TF_RETURN_IF_ERROR(
-              writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
-                                  status.error_message()));
-        }
-        return Status::OK();
+            writer->WriteTensor(strings::StrCat(prefix, "_", kOutput, "_", i),
+                                output_elem.output[i]));
       }
+      return Status::OK();
+    }
 
-      Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
-                              Status* status)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
-        int64 code_int;
+    Status ReadOutputElemLocked(IteratorStateReader* reader,
+                                OutputElem* output_elem, const string& prefix)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      TF_RETURN_IF_ERROR(ReadStatusLocked(
+          reader, strings::StrCat(prefix, "_", kStatus), &output_elem->status));
+      int64 output_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          strings::StrCat(prefix, "_", kOutputSize), &output_size));
+      output_elem->output.reserve(output_size);
+      for (int i = 0; i < output_size; ++i) {
+        output_elem->output.emplace_back();
+        TF_RETURN_IF_ERROR(
+            reader->ReadTensor(strings::StrCat(prefix, "_", kOutput, "_", i),
+                               &output_elem->output.back()));
+      }
+      return Status::OK();
+    }
+
+    Status WriteStatusLocked(IteratorStateWriter* writer, const string& prefix,
+                             const Status& status)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(strings::StrCat(prefix, "_", kCode)),
+                              static_cast<int64>(status.code())));
+      if (!status.ok()) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(prefix, "_", KMessage)),
+            status.error_message()));
+      }
+      return Status::OK();
+    }
+
+    Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
+                            Status* status)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
+      int64 code_int;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(prefix, "_", kCode)), &code_int));
+      error::Code code = static_cast<error::Code>(code_int);
+
+      if (code != error::Code::OK) {
+        string error_message;
         TF_RETURN_IF_ERROR(reader->ReadScalar(
-            full_name(strings::StrCat(prefix, "_code")), &code_int));
-        error::Code code = static_cast<error::Code>(code_int);
-
-        if (code != error::Code::OK) {
-          string error_message;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(
-              full_name(strings::StrCat(prefix, "_msg")), &error_message));
-          *status = Status(code, error_message);
-        } else {
-          *status = Status::OK();
-        }
-        return Status::OK();
+            full_name(strings::StrCat(prefix, "_", KMessage)), &error_message));
+        *status = Status(code, error_message);
+      } else {
+        *status = Status::OK();
       }
+      return Status::OK();
+    }
 
-      // Mutex & condition variable to guard mutable iterator internals and
-      // coordinate among worker threads and client thread[s].
-      mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
-      // The main thread waits on this condition variable if running in sloppy
-      // mode and no values are available.
-      condition_variable sloppy_cond_var_;
-      // Mutex used to wait for a consistent state while checkpointing.
-      // Only Save and Restore require an exclusive lock on this mutex. In
-      // other scenarios we just acquire a shared lock so the pipeline's
-      // performance should not be affected in the absence of checkpointing.
-      // A thread must not wait on any condition variable while holding
-      // `ckpt_mu_` in either shared or exclusive modes.
-      mutex ckpt_mu_;
+    // Mutex & condition variable to guard mutable iterator internals and
+    // coordinate among worker threads and client thread[s].
+    mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
+    // The main thread waits on this condition variable if running in sloppy
+    // mode and no values are available.
+    condition_variable sloppy_cond_var_;
+    // Mutex used to wait for a consistent state while checkpointing.
+    // Only Save and Restore require an exclusive lock on this mutex. In
+    // other scenarios we just acquire a shared lock so the pipeline's
+    // performance should not be affected in the absence of checkpointing.
+    // A thread must not wait on any condition variable while holding
+    // `ckpt_mu_` in either shared or exclusive modes.
+    mutex ckpt_mu_;
 
-      // The iterator producing elements which are converted to datasets by
-      // the dataset()->captured_func_ then interleaved together.
-      // input_impl_ is reset when we have exhausted its input.
-      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    // The iterator producing elements which are converted to datasets by
+    // the dataset()->captured_func_ then interleaved together.
+    // input_impl_ is reset when we have exhausted its input.
+    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
 
-      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
+    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
 
-      // The WorkerState structs the worker threads operate on.
-      // workers_ elements are in at most one of interleave_ and staging_.
-      std::vector<WorkerState> workers_ GUARDED_BY(mu_);
+    // The WorkerState structs the worker threads operate on.
+    // workers_ elements are in at most one of interleave_ and staging_.
+    std::vector<WorkerState> workers_ GUARDED_BY(mu_);
 
-      // Stores the temporary state of WorkerThreads which is not stored in
-      // WorkerState. This is used for checkpointing purposes only.
-      std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);
+    // Stores the temporary state of WorkerThreads which is not stored in
+    // WorkerState. This is used for checkpointing purposes only.
+    std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);
 
-      // Indices in `workers_` of iterators to interleave.
-      std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
-      // Indices in `workers_` of prefetched iterators.
-      std::deque<int64> staging_indices_ GUARDED_BY(mu_);
+    // Indices in `workers_` of iterators to interleave.
+    std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
+    // Indices in `workers_` of prefetched iterators.
+    std::deque<int64> staging_indices_ GUARDED_BY(mu_);
 
-      // The index into output_elements_ for next element to produce.
-      size_t next_index_ GUARDED_BY(mu_) = 0;
-      // The number of items produced so far within the block
-      size_t block_count_ GUARDED_BY(mu_) = 0;
-      // Flag to instruct the worker threads to exit.
-      bool cancelled_ GUARDED_BY(mu_) = false;
-      // The worker threads. This must be last to ensure the
-      // threads have exited before any other members are deallocated.
-      // TODO(b/65178177): Avoid allocating additional threads.
-      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
-    };
-
-    const DatasetBase* const input_;
-    const std::unique_ptr<CapturedFunction> captured_func_;
-    const int64 cycle_length_;
-    const int64 block_length_;
-    const bool sloppy_;
-    const int64 buffer_output_elements_;
-    const int64 prefetch_input_elements_;
-    const DataTypeVector output_types_;
-    const std::vector<PartialTensorShape> output_shapes_;
+    // The index into output_elements_ for next element to produce.
+    size_t next_index_ GUARDED_BY(mu_) = 0;
+    // The number of items produced so far within the block
+    size_t block_count_ GUARDED_BY(mu_) = 0;
+    // Flag to instruct the worker threads to exit.
+    bool cancelled_ GUARDED_BY(mu_) = false;
+    // The worker threads. This must be last to ensure the
+    // threads have exited before any other members are deallocated.
+    // TODO(b/65178177): Avoid allocating additional threads.
+    std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
   };
 
-  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
-  DataTypeVector output_types_;
-  std::vector<PartialTensorShape> output_shapes_;
+  const DatasetBase* const input_;
+  const std::unique_ptr<CapturedFunction> captured_func_;
+  const int64 cycle_length_;
+  const int64 block_length_;
+  const bool sloppy_;
+  const int64 buffer_output_elements_;
+  const int64 prefetch_input_elements_;
+  const DataTypeVector output_types_;
+  const std::vector<PartialTensorShape> output_shapes_;
 };
 
+ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
+    OpKernelConstruction* ctx)
+    : UnaryDatasetOpKernel(ctx) {
+  FunctionMetadata::Params params;
+  params.is_multi_device_function = true;
+  OP_REQUIRES_OK(ctx,
+                 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
+}
+
+void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
+                                              DatasetBase* input,
+                                              DatasetBase** output) {
+  int64 cycle_length = 0;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
+  OP_REQUIRES(ctx, cycle_length > 0,
+              errors::InvalidArgument("`cycle_length` must be > 0"));
+
+  int64 block_length = 0;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
+  OP_REQUIRES(ctx, block_length > 0,
+              errors::InvalidArgument("`block_length` must be > 0"));
+
+  bool sloppy = false;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
+
+  int64 buffer_output_elements = 0;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
+                                          &buffer_output_elements));
+  OP_REQUIRES(ctx, buffer_output_elements > 0,
+              errors::InvalidArgument("`buffer_output_elements` must be > 0"));
+
+  int64 prefetch_input_elements = 0;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
+                                          &prefetch_input_elements));
+  OP_REQUIRES(
+      ctx, prefetch_input_elements >= 0,
+      errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
+
+  std::unique_ptr<CapturedFunction> captured_func;
+  OP_REQUIRES_OK(ctx,
+                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
+                                          &captured_func));
+
+  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
+                        block_length, sloppy, buffer_output_elements,
+                        prefetch_input_elements, output_types_, output_shapes_);
+}
+
+namespace {
 REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                         ParallelInterleaveDatasetOp);
 REGISTER_KERNEL_BUILDER(
@@ -1079,5 +1111,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h
new file mode 100644
index 0000000..6e49679
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h
@@ -0,0 +1,63 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+
+// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
+// description of the following op.
+
+class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  static constexpr const char* const kDatasetType = "ParallelInterleave";
+  static constexpr const char* const kInputDataset = "input_dataset";
+  static constexpr const char* const kOtherArguments = "other_arguments";
+  static constexpr const char* const kCycleLength = "cycle_length";
+  static constexpr const char* const kBlockLength = "block_length";
+  static constexpr const char* const kSloppy = "sloppy";
+  static constexpr const char* const kBufferOutputElements =
+      "buffer_output_elements";
+  static constexpr const char* const kPrefetchInputElements =
+      "prefetch_input_elements";
+  static constexpr const char* const kFunc = "f";
+  static constexpr const char* const kTarguments = "Targuments";
+  static constexpr const char* const kOutputTypes = "output_types";
+  static constexpr const char* const kOutputShapes = "output_shapes";
+
+  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override;
+
+ private:
+  class Dataset;
+
+  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc
new file mode 100644
index 0000000..e7ecab2
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc
@@ -0,0 +1,813 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+namespace {
+
+constexpr char kNodeName[] = "parallel_interleave_dataset";
+constexpr char kIteratorPrefix[] = "Iterator";
+
+class ParallelInterleaveDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+  // Creates `TensorSliceDataset` variant tensor from the input vector of
+  // tensors.
+  Status CreateTensorSliceDatasetTensor(
+      std::vector<Tensor>* const tensor_vector, Tensor* dataset_tensor) {
+    DatasetBase* tensor_slice_dataset;
+    TF_RETURN_IF_ERROR(CreateTensorSliceDataset(
+        "tensor_slice_node", tensor_vector, &tensor_slice_dataset));
+    TF_RETURN_IF_ERROR(
+        StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor));
+    return Status::OK();
+  }
+
+  // Creates a new `ParallelInterleaveDataset` op kernel
+  Status CreateParallelInterleaveDatasetKernel(
+      const FunctionDefHelper::AttrValueWrapper& func,
+      const DataTypeVector& output_types,
+      const std::vector<PartialTensorShape>& output_shapes,
+      std::unique_ptr<OpKernel>* op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName,
+        name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType),
+        {ParallelInterleaveDatasetOp::kInputDataset,
+         ParallelInterleaveDatasetOp::kCycleLength,
+         ParallelInterleaveDatasetOp::kBlockLength,
+         ParallelInterleaveDatasetOp::kSloppy,
+         ParallelInterleaveDatasetOp::kBufferOutputElements,
+         ParallelInterleaveDatasetOp::kPrefetchInputElements},
+        {{ParallelInterleaveDatasetOp::kFunc, func},
+         {ParallelInterleaveDatasetOp::kTarguments, {}},
+         {ParallelInterleaveDatasetOp::kOutputTypes, output_types},
+         {ParallelInterleaveDatasetOp::kOutputShapes, output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
+    return Status::OK();
+  }
+
+  // Creates a new `ParallelInterleaveDataset` op kernel context.
+  Status CreateParallelInterleaveDatasetContext(
+      OpKernel* const op_kernel,
+      gtl::InlinedVector<TensorValue, 4>* const inputs,
+      std::unique_ptr<OpKernelContext>* context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct TestCase {
+  TestCase(std::vector<Tensor> input_tensors, int64 cycle_length,
+           int64 block_length, bool sloppy, int64 buffer_output_elements,
+           int64 prefetch_input_elements,
+           FunctionDefHelper::AttrValueWrapper func,
+           std::vector<FunctionDef> func_lib,
+           std::vector<Tensor> expected_outputs,
+           DataTypeVector expected_output_dtypes,
+           std::vector<PartialTensorShape> expected_output_shapes,
+           int64 expected_cardinality, std::vector<int> breakpoints)
+      : input_tensors(std::move(input_tensors)),
+        cycle_length(CreateTensor<int64>(TensorShape({}), {cycle_length})),
+        block_length(CreateTensor<int64>(TensorShape({}), {block_length})),
+        sloppy(CreateTensor<bool>(TensorShape({}), {sloppy})),
+        buffer_output_elements(
+            CreateTensor<int64>(TensorShape({}), {buffer_output_elements})),
+        prefetch_input_elements(
+            CreateTensor<int64>(TensorShape({}), {prefetch_input_elements})),
+        func(std::move(func)),
+        func_lib(std::move(func_lib)),
+        expected_outputs(std::move(expected_outputs)),
+        expected_output_dtypes(std::move(expected_output_dtypes)),
+        expected_output_shapes(std::move(expected_output_shapes)),
+        expected_cardinality(expected_cardinality),
+        breakpoints(std::move(breakpoints)) {}
+
+  std::vector<Tensor> input_tensors;
+  Tensor cycle_length;
+  Tensor block_length;
+  Tensor sloppy;
+  Tensor buffer_output_elements;
+  Tensor prefetch_input_elements;
+  FunctionDefHelper::AttrValueWrapper func;
+  std::vector<FunctionDef> func_lib;
+  std::vector<Tensor> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+template <typename T>
+std::vector<Tensor> ConvertToTensorVec(std::vector<T> values) {
+  std::vector<Tensor> tensors;
+  tensors.reserve(values.size());
+  for (auto& value : values) {
+    tensors.emplace_back(CreateTensor<T>(TensorShape({1}), {value}));
+  }
+  return tensors;
+}
+
+FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
+    const DataTypeVector& output_types,
+    const std::vector<PartialTensorShape>& output_shapes) {
+  return FunctionDefHelper::FunctionRef(
+      /*name*/ "MakeTensorSliceDataset",
+      /*attrs*/ {{TensorSliceDatasetOp::kToutputTypes, output_types},
+                 {TensorSliceDatasetOp::kOutputShapes, output_shapes}});
+}
+
+// Test case 1: cycle_length = 1, block_length = 1, sloppy = false,
+// buffer_output_elements = 1, prefetch_input_elements = 1
+TestCase TestCase1() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/1,
+      /*block_length=*/1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/1,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 2: cycle_length = 2, block_length = 1, sloppy = false,
+// buffer_output_elements = 1, prefetch_input_elements = 0
+TestCase TestCase2() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/2,
+      /*block_length=*/1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/0,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 1, 4, 2, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 3: cycle_length = 3, block_length = 1, sloppy = true,
+// buffer_output_elements = 3, prefetch_input_elements = 2
+TestCase TestCase3() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/3,
+      /*block_length=*/1,
+      /*sloppy=*/true,
+      /*buffer_output_elements=*/3,
+      /*prefetch_input_elements=*/2,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 4: cycle_length = 5, block_length = 1, sloppy = true
+// buffer_output_elements = 1, prefetch_input_elements = 2
+TestCase TestCase4() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/5,
+      /*block_length=*/1,
+      /*sloppy=*/true,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/2,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 5: cycle_length = 2, block_length = 2, sloppy = false
+// buffer_output_elements = 2, prefetch_input_elements = 2
+TestCase TestCase5() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      /*cycle_length=*/2,
+      /*block_length=*/2,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/2,
+      /*prefetch_input_elements=*/2,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_STRING}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<string>({"a", "b", "d", "e", "c", "f", "g", "h", "i"}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+TestCase InvalidCycleLengthTestCase() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/0,
+      /*block_length=*/1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/1,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+TestCase InvalidBlockLengthTestCase() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/1,
+      /*block_length=*/-1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/1,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+TestCase InvalidBufferOutputElementsTestCase() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/1,
+      /*block_length=*/1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/0,
+      /*prefetch_input_elements=*/1,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+TestCase InvalidPrefetchInputElementsTestCase() {
+  return {
+      /*input_tensors=*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*cycle_length=*/1,
+      /*block_length=*/1,
+      /*sloppy=*/false,
+      /*buffer_output_elements=*/1,
+      /*prefetch_input_elements=*/-1,
+      /*func=*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+class ParameterizedParallelInterleaveDatasetOpTest
+    : public ParallelInterleaveDatasetOpTest,
+      public ::testing::WithParamInterface<TestCase> {};
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(),
+                                     &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator(
+      iterator_ctx.get(), kIteratorPrefix, &iterator));
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  while (!end_of_sequence) {
+    std::vector<Tensor> next;
+    TF_EXPECT_OK(
+        iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
+    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
+  }
+
+  TF_EXPECT_OK(
+      ExpectEqual(out_tensors, test_case.expected_outputs,
+                  /*compare_order=*/!test_case.sloppy.scalar<bool>()()));
+}
+
+TEST_F(ParallelInterleaveDatasetOpTest, DatasetNodeName) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = TestCase1();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  EXPECT_EQ(parallel_interleave_dataset->node_name(), kNodeName);
+}
+
+TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = TestCase1();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  EXPECT_EQ(parallel_interleave_dataset->type_string(),
+            name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType));
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  TF_EXPECT_OK(VerifyTypesMatch(parallel_interleave_dataset->output_dtypes(),
+                                test_case.expected_output_dtypes));
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  TF_EXPECT_OK(
+      VerifyShapesCompatible(parallel_interleave_dataset->output_shapes(),
+                             test_case.expected_output_shapes));
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  EXPECT_EQ(parallel_interleave_dataset->Cardinality(),
+            test_case.expected_cardinality);
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(),
+                                     &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator(
+      iterator_ctx.get(), kIteratorPrefix, &iterator));
+
+  TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
+                                test_case.expected_output_dtypes));
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(),
+                                     &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator(
+      iterator_ctx.get(), kIteratorPrefix, &iterator));
+
+  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
+                                      test_case.expected_output_shapes));
+}
+
+TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputPrefix) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = TestCase1();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(),
+                                     &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator(
+      iterator_ctx.get(), kIteratorPrefix, &iterator));
+  EXPECT_EQ(iterator->prefix(),
+            name_utils::IteratorPrefix(
+                ParallelInterleaveDatasetOp::kDatasetType, kIteratorPrefix));
+}
+
+TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+
+  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+      test_case.func, test_case.expected_output_dtypes,
+      test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor),
+       TensorValue(&test_case.cycle_length),
+       TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+       TensorValue(&test_case.buffer_output_elements),
+       TensorValue(&test_case.prefetch_input_elements)});
+  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+  TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+      parallel_interleave_dataset_kernel.get(), &inputs,
+      &parallel_interleave_dataset_context));
+  DatasetBase* parallel_interleave_dataset;
+  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                             parallel_interleave_dataset_context.get(),
+                             &parallel_interleave_dataset));
+  core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(),
+                                     &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator(
+      iterator_ctx.get(), kIteratorPrefix, &iterator));
+
+  std::unique_ptr<SerializationContext> serialization_ctx;
+  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
+
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  int cur_iteration = 0;
+  const std::vector<int>& breakpoints = test_case.breakpoints;
+  for (int breakpoint : breakpoints) {
+    VariantTensorData data;
+    VariantTensorDataWriter writer(&data);
+    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
+    TF_EXPECT_OK(writer.Flush());
+    VariantTensorDataReader reader(&data);
+    TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, kIteratorPrefix,
+                                 *parallel_interleave_dataset, &iterator));
+    while (cur_iteration <= breakpoint) {
+      std::vector<Tensor> next;
+      TF_EXPECT_OK(
+          iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
+      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
+      cur_iteration++;
+    }
+  }
+
+  TF_EXPECT_OK(
+      ExpectEqual(out_tensors, test_case.expected_outputs,
+                  /*compare_order*/ !test_case.sloppy.scalar<bool>()()));
+}
+
+INSTANTIATE_TEST_SUITE_P(ParallelInterleaveDatasetOpTest,
+                         ParameterizedParallelInterleaveDatasetOpTest,
+                         ::testing::ValuesIn(std::vector<TestCase>(
+                             {TestCase1(), TestCase2(), TestCase3(),
+                              TestCase4(), TestCase5()})));
+
+TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+
+  std::vector<TestCase> test_cases({InvalidCycleLengthTestCase(),
+                                    InvalidBlockLengthTestCase(),
+                                    InvalidBufferOutputElementsTestCase(),
+                                    InvalidPrefetchInputElementsTestCase()});
+  for (auto test_case : test_cases) {
+    TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
+    std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
+    TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
+        test_case.func, test_case.expected_output_dtypes,
+        test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
+
+    Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+    std::vector<Tensor> inputs_for_tensor_slice_dataset =
+        test_case.input_tensors;
+    TF_ASSERT_OK(CreateTensorSliceDatasetTensor(
+        &inputs_for_tensor_slice_dataset, &tensor_slice_dataset_tensor));
+    gtl::InlinedVector<TensorValue, 4> inputs(
+        {TensorValue(&tensor_slice_dataset_tensor),
+         TensorValue(&test_case.cycle_length),
+         TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
+         TensorValue(&test_case.buffer_output_elements),
+         TensorValue(&test_case.prefetch_input_elements)});
+    std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
+    TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
+        parallel_interleave_dataset_kernel.get(), &inputs,
+        &parallel_interleave_dataset_context));
+    DatasetBase* parallel_interleave_dataset;
+    EXPECT_EQ(CreateDataset(parallel_interleave_dataset_kernel.get(),
+                            parallel_interleave_dataset_context.get(),
+                            &parallel_interleave_dataset)
+                  .code(),
+              tensorflow::error::INVALID_ARGUMENT);
+  }
+}
+
+}  // namespace
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
index 97b91d8..b73e226 100644
--- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
@@ -22,10 +22,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
 class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
@@ -207,6 +206,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -404,5 +407,6 @@
     ParseExampleDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
index 0a640c5..3fbb9bd 100644
--- a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -25,6 +25,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class IteratorGetDeviceOp : public OpKernel {
@@ -41,7 +42,7 @@
     // NOTE(mrry): Since the operation's input is a resource, we must be
     // colocated with it, and so we can simply return the current device's
     // name without looking at the input.
-    device_name_t->scalar<string>()() = ctx->device()->name();
+    device_name_t->scalar<tstring>()() = ctx->device()->name();
   }
 };
 
@@ -52,5 +53,6 @@
     IteratorGetDeviceOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc
index 80a3776..a5cc433 100644
--- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc
@@ -22,11 +22,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class RandomDatasetOp : public DatasetOpKernel {
  public:
   explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
@@ -78,6 +76,8 @@
 
     int64 Cardinality() const override { return kInfiniteCardinality; }
 
+    Status CheckExternalState() const override { return Status::OK(); }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -160,5 +160,6 @@
                         RandomDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
index ac351eb..13d0125 100644
--- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
@@ -18,6 +18,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 constexpr char kOptimizerName[] = "tf_data_rebatcher";
@@ -35,14 +36,15 @@
  protected:
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    int64 num_workers;
-    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
+    int64 num_replicas;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument(ctx, "num_replicas", &num_replicas));
     OP_REQUIRES(
-        ctx, num_workers > 0,
-        errors::InvalidArgument("num_workers must be greater than zero."));
+        ctx, num_replicas > 0,
+        errors::InvalidArgument("num_replicas must be greater than zero."));
 
-    auto config_factory = [num_workers, this]() {
-      return CreateConfig(num_workers, this->use_fallback_);
+    auto config_factory = [num_replicas, this]() {
+      return CreateConfig(num_replicas, this->use_fallback_);
     };
 
     // We only want to optimize functions for some particular datasets like
@@ -55,17 +57,17 @@
   }
 
  private:
-  static RewriterConfig CreateConfig(int64 num_workers, bool use_fallback) {
+  static RewriterConfig CreateConfig(int64 num_replicas, bool use_fallback) {
     RewriterConfig rewriter_config;
     rewriter_config.set_fail_on_optimizer_errors(true);
     rewriter_config.add_optimizers(kOptimizerName);
     rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
     auto custom_optimizer = rewriter_config.add_custom_optimizers();
     custom_optimizer->set_name(kOptimizerName);
-    AttrValue num_workers_attr;
-    num_workers_attr.set_i(num_workers);
-    (*custom_optimizer->mutable_parameter_map())["num_workers"] =
-        num_workers_attr;
+    AttrValue num_replicas_attr;
+    num_replicas_attr.set_i(num_replicas);
+    (*custom_optimizer->mutable_parameter_map())["num_replicas"] =
+        num_replicas_attr;
     AttrValue use_fallback_attr;
     use_fallback_attr.set_b(use_fallback);
     (*custom_optimizer->mutable_parameter_map())["use_fallback"] =
@@ -82,5 +84,6 @@
                         RebatchDatasetOp);
 
 }  // anonymous namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc
index a118fd8..65e3a0d 100644
--- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc
@@ -22,11 +22,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class SamplingDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit SamplingDatasetOp(OpKernelConstruction* ctx)
@@ -81,6 +79,10 @@
 
     string DebugString() const override { return "SamplingDatasetOp::Dataset"; }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -220,5 +222,6 @@
                         SamplingDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc
index 31ec086..e7fd1dd 100644
--- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc
@@ -26,11 +26,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class ScanDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ScanDatasetOp(OpKernelConstruction* ctx)
@@ -104,6 +102,11 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -298,5 +301,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalScanDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc
index 64390e72..14fb48f 100644
--- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc
@@ -25,6 +25,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
@@ -137,6 +138,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -218,5 +223,6 @@
     SetStatsAggregatorDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc
index afe44dc..cbb53e1 100644
--- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc
@@ -18,11 +18,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class SleepDatasetOp : public UnaryDatasetOpKernel {
  public:
   using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
@@ -69,6 +67,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -152,5 +154,6 @@
                         SleepDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc
index 154ce7d..82baa01 100644
--- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc
@@ -23,11 +23,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit SlidingWindowDatasetOp(OpKernelConstruction* ctx)
@@ -110,6 +108,10 @@
       return n / window_shift_;
     }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -309,5 +311,6 @@
     SlidingWindowDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index 4e1b3e3..1f76ca2 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/raw_coding.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/io/buffered_inputstream.h"
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
@@ -43,6 +44,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
@@ -54,6 +56,7 @@
 
 const char kSnapshotFilename[] = "snapshot.metadata";
 constexpr char kSnapshotReaderWorkerPool[] = "snapshot_reader_worker_pool";
+constexpr char kSnapshotWriterWorkerPool[] = "snapshot_writer_worker_pool";
 
 class SnapshotWriter {
  public:
@@ -311,7 +314,7 @@
     SerializationContext::Params params;
     std::vector<std::pair<string, Tensor>> input_list;
     params.input_list = &input_list;
-    params.optimization_only = true;
+    params.check_external_state = false;
 
     GraphDef graph_def;
     OP_REQUIRES_OK(
@@ -373,6 +376,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -499,21 +506,21 @@
             const experimental::SnapshotMetadataRecord& metadata)
             : DatasetIterator<Dataset>(params),
               hash_dir_(hash_dir),
-              metadata_(metadata) {
-          thread_pool_ = absl::make_unique<thread::ThreadPool>(
-              Env::Default(), ThreadOptions(), kSnapshotReaderWorkerPool,
-              params.dataset->num_reader_threads_, /*low_latency_hint=*/false);
-        }
+              metadata_(metadata) {}
 
         ~SnapshotReaderIterator() override {
           mutex_lock l(mu_);
           cancelled_ = true;
           cond_var_.notify_all();
+          while (num_active_threads_ > 0) {
+            cond_var_.wait(l);
+          }
         }
 
         Status Initialize(IteratorContext* ctx) override {
           mutex_lock l(mu_);
-
+          thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool,
+                                               dataset()->num_reader_threads_);
           run_id_ = metadata_.run_id();
           run_dir_ = absl::StrCat(hash_dir_, "/", run_id_);
           // Get all the files in the run_dir.
@@ -534,6 +541,7 @@
           mutex_lock l(mu_);
           if (!background_threads_started_) {
             for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
+              ++num_active_threads_;
               thread_pool_->Schedule([this]() { ReadingFilesLoop(); });
             }
             background_threads_started_ = true;
@@ -648,6 +656,11 @@
         // Pulls one file off the filenames_ list and reads it through. When
         // all files are read, terminates.
         void ReadingFilesLoop() {
+          auto cleanup = gtl::MakeCleanup([this]() {
+            mutex_lock l(mu_);
+            --num_active_threads_;
+            cond_var_.notify_all();
+          });
           while (true) {
             string filename = "";
             {
@@ -691,6 +704,9 @@
           std::vector<Tensor> value;
         };
 
+        mutex mu_;
+        condition_variable cond_var_;
+
         const string hash_dir_;
         const experimental::SnapshotMetadataRecord metadata_;
         string run_id_ GUARDED_BY(mu_);
@@ -704,39 +720,36 @@
         int64 num_files_done_ GUARDED_BY(mu_) = 0;
 
         std::unique_ptr<thread::ThreadPool> thread_pool_;
-        condition_variable cond_var_;
+        int64 num_active_threads_ GUARDED_BY(mu_) = 0;
         std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
         bool cancelled_ GUARDED_BY(mu_) = false;
         bool background_threads_started_ GUARDED_BY(mu_) = false;
         bool background_threads_finished_ GUARDED_BY(mu_) = false;
-
-        mutex mu_;
       };
 
       class SnapshotWriterIterator : public DatasetIterator<Dataset> {
        public:
         explicit SnapshotWriterIterator(const Params& params,
                                         const string& hash_dir)
-            : DatasetIterator<Dataset>(params), hash_dir_(hash_dir) {
-          thread_pool_ = absl::make_unique<thread::ThreadPool>(
-              Env::Default(), ThreadOptions(), "snapshot_writer_pool",
-              params.dataset->num_writer_threads_, /*low_latency_hint=*/false);
-        }
+            : DatasetIterator<Dataset>(params), hash_dir_(hash_dir) {}
 
         ~SnapshotWriterIterator() override {
           mutex_lock l(mu_);
           cancelled_ = true;
           cond_var_.notify_all();
+          while (num_active_threads_ > 0) {
+            cond_var_.wait(l);
+          }
         }
 
         Status Initialize(IteratorContext* ctx) override {
           mutex_lock l(mu_);
-
+          thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool,
+                                               dataset()->num_writer_threads_);
           run_id_ = strings::StrCat(
               strings::Hex(random::New64(), strings::kZeroPad4));
           run_dir_ = absl::StrCat(dataset()->writer_path_prefix_, hash_dir_,
                                   "/", run_id_);
-
           TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir_));
 
           experimental::SnapshotMetadataRecord metadata;
@@ -744,7 +757,6 @@
           metadata.set_graph_hash(dataset()->graph_hash_);
           metadata.set_run_id(run_id_);
           metadata.set_finalized(false);
-
           TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
 
           return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
@@ -761,6 +773,7 @@
             first_call = first_call_;
             if (first_call_) {
               for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
+                ++num_active_threads_;
                 thread_pool_->Schedule([this]() { WriterThread(); });
               }
               first_call_ = false;
@@ -836,7 +849,7 @@
             end_of_sequence_ = true;
             cond_var_.notify_all();
             // Now we wait till all background threads finish.
-            while (num_threads_finished_ < dataset()->num_writer_threads_) {
+            while (num_active_threads_ > 0) {
               cond_var_.wait(l);
             }
             return Status::OK();
@@ -957,6 +970,12 @@
 
         // Just pulls off elements from the buffer and writes them.
         void WriterThread() {
+          auto cleanup = gtl::MakeCleanup([this]() {
+            mutex_lock l(mu_);
+            --num_active_threads_;
+            cond_var_.notify_all();
+          });
+
           int64 bytes_written = 0;
           string snapshot_data_filename = GetSnapshotFilename();
           std::unique_ptr<WritableFile> file;
@@ -987,12 +1006,19 @@
               return;
             }
           }
-          mutex_lock l(mu_);
-          num_threads_finished_++;
-          cond_var_.notify_all();
         }
 
         mutex mu_;
+        // This condition variable is notified
+        // 1. By the background writer threads when an element from the buffer
+        //    is consumed.
+        // 2. By the main thread when it puts something into the buffer.
+        // 3. By the main thread when the destructor is called to cancel.
+        // 4. By the background writer threads when any error is encountered
+        //    while writing.
+        // 5. By the background threads when they finish.
+        condition_variable cond_var_;
+
         BufferElement next_elem_ GUARDED_BY(mu_);
         std::unique_ptr<IteratorBase> input_impl_;
 
@@ -1004,15 +1030,6 @@
         int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
         int64 bytes_produced_ GUARDED_BY(mu_) = 0;
 
-        // This condition variable is notified
-        // 1. By the background writer threads when an element from the buffer
-        //    is consumed.
-        // 2. By the main thread when it puts something into the buffer.
-        // 3. By the main thread when the destructor is called to cancel.
-        // 4. By the background writer threads when any error is encountered
-        //    while writing.
-        // 5. By the background threads when they finish.
-        condition_variable cond_var_;
         std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
         bool snapshot_failed_ GUARDED_BY(mu_) = false;
         bool cancelled_ GUARDED_BY(mu_) = false;
@@ -1020,8 +1037,8 @@
         bool end_of_sequence_ GUARDED_BY(mu_) = false;
         bool written_final_metadata_file_ GUARDED_BY(mu_) = false;
         uint64 next_file_index_ GUARDED_BY(mu_) = 0;
-        int64 num_threads_finished_ GUARDED_BY(mu_) = 0;
         std::unique_ptr<thread::ThreadPool> thread_pool_;
+        int64 num_active_threads_ GUARDED_BY(mu_) = 0;
       };
 
       class SnapshotPassthroughIterator : public DatasetIterator<Dataset> {
@@ -1086,5 +1103,6 @@
                         SnapshotDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sql/driver_manager.cc b/tensorflow/core/kernels/data/experimental/sql/driver_manager.cc
index 58174f6..5bb511f 100644
--- a/tensorflow/core/kernels/data/experimental/sql/driver_manager.cc
+++ b/tensorflow/core/kernels/data/experimental/sql/driver_manager.cc
@@ -17,6 +17,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace sql {
 
 std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
@@ -30,5 +31,6 @@
 }
 
 }  // namespace sql
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sql/driver_manager.h b/tensorflow/core/kernels/data/experimental/sql/driver_manager.h
index 6afadf9..7aa307e 100644
--- a/tensorflow/core/kernels/data/experimental/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/experimental/sql/driver_manager.h
@@ -19,6 +19,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace sql {
 
 // A factory class for creating `QueryConnection` instances.
@@ -35,6 +36,7 @@
 };
 
 }  // namespace sql
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/data/experimental/sql/query_connection.h b/tensorflow/core/kernels/data/experimental/sql/query_connection.h
index 10c6643..40f13d5 100644
--- a/tensorflow/core/kernels/data/experimental/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/experimental/sql/query_connection.h
@@ -22,6 +22,8 @@
 
 class IteratorContext;
 
+namespace experimental {
+
 namespace sql {
 // This interface allows a user to connect to a database, execute a query, and
 // iterate over the result set, putting the results into an output tensor.
@@ -64,6 +66,7 @@
 };
 
 }  // namespace sql
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc
index cadceee..37dc6b49 100644
--- a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc
@@ -20,6 +20,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace sql {
 
 SqliteQueryConnection::SqliteQueryConnection() {}
@@ -114,5 +115,6 @@
 }
 
 }  // namespace sql
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h
index 61df290..42526c7 100644
--- a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h
@@ -23,6 +23,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace sql {
 
 class SqliteQueryConnection : public QueryConnection {
@@ -50,6 +51,7 @@
 };
 
 }  // namespace sql
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc
index 8a095d9c..2b7283c 100644
--- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc
@@ -25,11 +25,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following ops.
-
 class SqlDatasetOp : public DatasetOpKernel {
  public:
   explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
@@ -105,6 +103,8 @@
 
     string DebugString() const override { return "SqlDatasetOp::Dataset"; }
 
+    Status CheckExternalState() const override { return Status::OK(); }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -219,5 +219,6 @@
                         SqlDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc
index 9b5a483..05dadf0 100644
--- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc
@@ -30,6 +30,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 static mutex* get_counters_map_lock() {
@@ -266,7 +267,7 @@
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
     Summary summary;
     resource->stats_aggregator()->EncodeToProto(&summary);
-    summary_t->scalar<string>()() = summary.SerializeAsString();
+    summary_t->scalar<tstring>()() = summary.SerializeAsString();
   }
 };
 
@@ -316,5 +317,6 @@
     StatsAggregatorSetSummaryWriterOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
index 70a95fa..8525fa5 100644
--- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
@@ -20,6 +20,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 // This op defines a `Dataset` that passes through its input elements and
@@ -77,6 +78,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -188,6 +193,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -271,5 +280,6 @@
     LatencyStatsDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc
index af7f778..378fa80 100644
--- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc
@@ -26,11 +26,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit TakeWhileDatasetOp(OpKernelConstruction* ctx)
@@ -86,6 +84,11 @@
 
     int64 Cardinality() const override { return kUnknownCardinality; }
 
+    Status CheckExternalState() const override {
+      TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -207,5 +210,6 @@
 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalTakeWhileDataset");
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 8aa26ea..0ece761 100644
--- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -24,6 +24,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class ThreadPoolResource : public ResourceBase {
@@ -171,6 +172,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -278,6 +283,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -378,6 +387,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -455,5 +468,6 @@
     ThreadPoolDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
index 1cc3bc0..f45b493 100644
--- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
+++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
@@ -24,6 +24,7 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
 class ToTFRecordOp : public AsyncOpKernel {
@@ -121,7 +122,7 @@
 
             if (!end_of_sequence) {
               OP_REQUIRES_OK_ASYNC(
-                  ctx, writer->WriteRecord(components[0].scalar<string>()()),
+                  ctx, writer->WriteRecord(components[0].scalar<tstring>()()),
                   done);
             }
             components.clear();
@@ -141,5 +142,6 @@
     Name("ExperimentalDatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
index 3252196..c206982 100644
--- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
@@ -19,11 +19,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class UnbatchDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit UnbatchDatasetOp(OpKernelConstruction* ctx)
@@ -71,6 +69,10 @@
 
     string DebugString() const override { return "UnbatchDatasetOp::Dataset"; }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -227,5 +229,6 @@
                         UnbatchDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
index 613b2fd..9092687 100644
--- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -19,11 +19,9 @@
 
 namespace tensorflow {
 namespace data {
+namespace experimental {
 namespace {
 
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class UniqueDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit UniqueDatasetOp(OpKernelConstruction* ctx)
@@ -74,6 +72,10 @@
       return strings::StrCat("UniqueDatasetOp::Dataset");
     }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
@@ -171,7 +173,7 @@
             return Hash64(t.tensor_data().data(), t.tensor_data().size());
           } else {
             DCHECK_EQ(DT_STRING, t.dtype());
-            auto flat_t = t.flat<string>();
+            auto flat_t = t.flat<tstring>();
             uint64 hash = 0;
             for (int64 i = 0; i < t.NumElements(); ++i) {
               hash = Hash64Combine(hash, Hash64(flat_t(i)));
@@ -227,5 +229,6 @@
                         UniqueDatasetOp);
 
 }  // namespace
+}  // namespace experimental
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index ad92840..9325856 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -75,6 +75,11 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/filter_dataset_op_test.cc b/tensorflow/core/kernels/data/filter_dataset_op_test.cc
index bb4e17d..8634207 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op_test.cc
@@ -77,8 +77,7 @@
   std::vector<Tensor> tensors;
   tensors.reserve(values.size());
   for (auto &value : values) {
-    tensors.emplace_back(
-        DatasetOpsTestBase::CreateTensor<T>(TensorShape({1}), {value}));
+    tensors.emplace_back(CreateTensor<T>(TensorShape({1}), {value}));
   }
   return tensors;
 }
@@ -86,8 +85,7 @@
 // Test case 1: norm case.
 TestCase TestCase1() {
   return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{9, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
+          {CreateTensor<int64>(TensorShape{9, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
           /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::IsZero()},
           /*expected_outputs*/
@@ -101,7 +99,7 @@
 // Test case 2: the input dataset has no outputs.
 TestCase TestCase2() {
   return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
+          {CreateTensor<int64>(TensorShape{0}, {})},
           /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::IsZero()},
           /*expected_outputs*/
@@ -115,8 +113,7 @@
 // Test case 3: the filter function returns two outputs.
 TestCase InvalidFuncTestCase1() {
   return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
+          {CreateTensor<int64>(TensorShape{3, 3}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
           /*func*/
           FunctionDefHelper::FunctionRef(
               "GetUnique", {{"T", DT_INT64}, {"out_idx", DT_INT32}}),
@@ -131,24 +128,23 @@
 
 // Test case 4: the filter function returns a 1-D bool tensor.
 TestCase InvalidFuncTestCase2() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
-          /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}),
-          /*func_lib*/ {test::function::IsZero()},
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
-          /*expected_cardinality*/ kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
+      /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}),
+      /*func_lib*/ {test::function::IsZero()},
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
+      /*expected_cardinality*/ kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 // Test case 5: the filter function returns a scalar int64 tensor.
 TestCase InvalidFuncTestCase3() {
   return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{9}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
+          {CreateTensor<int64>(TensorShape{9}, {0, 0, 0, 3, 4, 5, 6, 7, 8})},
           /*func*/ FunctionDefHelper::FunctionRef("NonZero", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::NonZero()},
           /*expected_outputs*/
@@ -350,39 +346,6 @@
   EXPECT_EQ(filter_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedFilterDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  const TestCase &test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  std::unique_ptr<OpKernel> filter_dataset_kernel;
-  TF_ASSERT_OK(CreateFilterDatasetKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &filter_dataset_kernel));
-
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&tensor_slice_dataset_tensor)});
-  std::unique_ptr<OpKernelContext> filter_dataset_context;
-  TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs,
-                                          &filter_dataset_context));
-  DatasetBase *filter_dataset;
-  TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(),
-                             filter_dataset_context.get(), &filter_dataset));
-  core::ScopedUnref scoped_unref(filter_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(filter_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedFilterDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   const TestCase &test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
index dd147c6..fdfe756 100644
--- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
+++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc
@@ -93,6 +93,8 @@
     return name_utils::DatasetDebugString(kDatasetType, params);
   }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -141,7 +143,7 @@
 
             // Produce the record as output.
             Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
-            record_tensor.scalar<string>()() = record;
+            record_tensor.scalar<tstring>()() = record;
             out_tensors->emplace_back(std::move(record_tensor));
             *end_of_sequence = false;
             return Status::OK();
@@ -264,7 +266,7 @@
 
               // Produce the record as output.
               Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
-              record_tensor.scalar<string>()() = std::move(record);
+              record_tensor.scalar<tstring>()() = std::move(record);
               out_tensors->emplace_back(std::move(record_tensor));
               *end_of_sequence = false;
               return Status::OK();
@@ -282,7 +284,7 @@
                   lookahead_cache_.substr(dataset()->record_bytes_);
               // Produce the record as output.
               Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
-              record_tensor.scalar<string>()() = std::move(record);
+              record_tensor.scalar<tstring>()() = std::move(record);
               out_tensors->emplace_back(std::move(record_tensor));
               *end_of_sequence = false;
               return Status::OK();
@@ -459,7 +461,7 @@
   std::vector<string> filenames;
   filenames.reserve(filenames_tensor->NumElements());
   for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-    filenames.push_back(filenames_tensor->flat<string>()(i));
+    filenames.push_back(filenames_tensor->flat<tstring>()(i));
   }
 
   int64 header_bytes = -1;
diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc
index 0a0f7a2..6636172 100644
--- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc
@@ -105,11 +105,11 @@
           /*buffer_size*/ 10,
           /*compression_type*/ CompressionType::ZLIB,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"111"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"222"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"aaa"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bbb"})},
+          {CreateTensor<string>(TensorShape({}), {"111"}),
+           CreateTensor<string>(TensorShape({}), {"222"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"aaa"}),
+           CreateTensor<string>(TensorShape({}), {"bbb"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -129,11 +129,11 @@
           /*buffer_size*/ 10,
           /*compression_type*/ CompressionType::GZIP,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"111"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"222"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"aaa"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bbb"})},
+          {CreateTensor<string>(TensorShape({}), {"111"}),
+           CreateTensor<string>(TensorShape({}), {"222"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"aaa"}),
+           CreateTensor<string>(TensorShape({}), {"bbb"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -154,11 +154,11 @@
           /*buffer_size*/ 10,
           /*compression_type*/ CompressionType::UNCOMPRESSED,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"111"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"222"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"aaa"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bbb"})},
+          {CreateTensor<string>(TensorShape({}), {"111"}),
+           CreateTensor<string>(TensorShape({}), {"222"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"aaa"}),
+           CreateTensor<string>(TensorShape({}), {"bbb"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -452,56 +452,6 @@
             test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TF_ASSERT_OK(CreateTestFiles(test_case));
-
-  std::unique_ptr<OpKernel> fixed_length_record_dataset_kernel;
-  TF_ASSERT_OK(CreateFixedLengthRecordDatasetOpKernel(
-      &fixed_length_record_dataset_kernel));
-
-  int64 num_files = test_case.filenames.size();
-  Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor header_bytes =
-      CreateTensor<int64>(TensorShape({}), {test_case.header_bytes});
-  Tensor record_bytes =
-      CreateTensor<int64>(TensorShape({}), {test_case.record_bytes});
-  Tensor footer_bytes =
-      CreateTensor<int64>(TensorShape({}), {test_case.footer_bytes});
-  Tensor buffer_size =
-      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
-  Tensor compression_type = CreateTensor<string>(
-      TensorShape({}), {ToString(test_case.compression_type)});
-  gtl::InlinedVector<TensorValue, 4> inputs{
-      TensorValue(&filenames),    TensorValue(&header_bytes),
-      TensorValue(&record_bytes), TensorValue(&footer_bytes),
-      TensorValue(&buffer_size),  TensorValue(&compression_type),
-  };
-  std::unique_ptr<OpKernelContext> fixed_length_record_dataset_context;
-  TF_ASSERT_OK(CreateFixedLengthRecordDatasetContext(
-      fixed_length_record_dataset_kernel.get(), &inputs,
-      &fixed_length_record_dataset_context));
-
-  DatasetBase* fixed_length_record_dataset;
-  TF_ASSERT_OK(CreateDataset(fixed_length_record_dataset_kernel.get(),
-                             fixed_length_record_dataset_context.get(),
-                             &fixed_length_record_dataset));
-  core::ScopedUnref scoped_unref(fixed_length_record_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(
-      fixed_length_record_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 8a0d14b..184d9da 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -75,6 +75,11 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc
index 8beb943..8927a4f 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc
@@ -73,47 +73,47 @@
 };
 
 TestCase MakeTensorSliceDatasetFuncTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          FunctionDefHelper::FunctionRef(
-              /*name*/ "MakeTensorSliceDataset",
-              /*attrs*/ {{"Toutput_types", DataTypeVector({DT_INT64})},
-                         {"output_shapes", std::vector<PartialTensorShape>(
-                                               {PartialTensorShape({1})})}}),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      FunctionDefHelper::FunctionRef(
+          /*name*/ "MakeTensorSliceDataset",
+          /*attrs*/ {{"Toutput_types", DataTypeVector({DT_INT64})},
+                     {"output_shapes", std::vector<PartialTensorShape>(
+                                           {PartialTensorShape({1})})}}),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // Test case 2: test the case if the function does not return a single scalar
 // of dtype DT_VARIANT.
 TestCase InvalidFuncTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          FunctionDefHelper::FunctionRef(/*name*/ "NonZero",
-                                         /*attrs*/ {{"T", DT_INT64}}),
-          /*func_lib*/ {test::function::NonZero()},
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      FunctionDefHelper::FunctionRef(/*name*/ "NonZero",
+                                     /*attrs*/ {{"T", DT_INT64}}),
+      /*func_lib*/ {test::function::NonZero()},
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {}};
 }
 
 class ParameterizedFlatMapDatasetOpTest
@@ -364,41 +364,6 @@
   EXPECT_EQ(flat_map_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_F(FlatMapDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  std::unique_ptr<OpKernel> flat_map_dataset_kernel;
-  TF_ASSERT_OK(CreateFlatMapDatasetKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &flat_map_dataset_kernel));
-
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&tensor_slice_dataset_tensor)});
-  std::unique_ptr<OpKernelContext> flat_map_dataset_context;
-  TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(),
-                                           &inputs, &flat_map_dataset_context));
-  DatasetBase *flat_map_dataset;
-  TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(),
-                             flat_map_dataset_context.get(),
-                             &flat_map_dataset));
-  core::ScopedUnref scoped_unref(flat_map_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(flat_map_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedFlatMapDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   const TestCase &test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 49ee3ed..e57a185 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -74,12 +74,18 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
+    TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
+    return finalize_func_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
                             Node** output) const override {
-    return errors::Unimplemented("%s does not support serialization",
-                                 DebugString());
+    return errors::Unimplemented(DebugString(),
+                                 " does not support serialization");
   }
 
  private:
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 5a6c65c..642092a 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -81,6 +81,11 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc
index 20b55f0..39ed82c 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc
@@ -80,8 +80,7 @@
   std::vector<Tensor> tensors;
   tensors.reserve(values.size());
   for (auto &value : values) {
-    tensors.emplace_back(
-        DatasetOpsTestBase::CreateTensor<T>(TensorShape({1}), {value}));
+    tensors.emplace_back(CreateTensor<T>(TensorShape({1}), {value}));
   }
   return tensors;
 }
@@ -97,107 +96,107 @@
 
 // test case 1: cycle_length = 1, block_length = 1.
 TestCase TestCase1() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 2: cycle_length = 2, block_length = 1.
 TestCase TestCase2() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 1, 4, 2, 5, 6, 7, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 1, 4, 2, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 3: cycle_length = 3, block_length = 1.
 TestCase TestCase3() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {3}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 4: cycle_length = 5, block_length = 1.
 TestCase TestCase4() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {3}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 5: cycle_length = 2, block_length = 2.
 TestCase TestCase5() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "d", "e", "c", "f", "g", "h", "i"}),
       /*expected_output_dtypes*/ {DT_STRING},
@@ -210,17 +209,17 @@
 TestCase TestCase6() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
       /*expected_output_dtypes*/ {DT_STRING},
@@ -233,17 +232,17 @@
 TestCase TestCase7() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+      CreateTensor<int64>(TensorShape({}), {5}),
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
       /*expected_output_dtypes*/ {DT_STRING},
@@ -254,44 +253,42 @@
 
 // test case 8: cycle_length = 0, block_length = 5.
 TestCase InvalidCycleLengthTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
-          /*expected_outputs*/ ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {0}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {5}),
+      /*expected_outputs*/ ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 // test case 9: cycle_length = 1, block_length = -1.
 TestCase InvalidBlockLengthTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
-          /*expected_outputs*/ ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/ CreateTensor<int64>(TensorShape({}), {1}),
+      /*block_length*/ CreateTensor<int64>(TensorShape({}), {-1}),
+      /*expected_outputs*/ ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 class ParameterizedInterleaveDatasetOpTest
@@ -573,43 +570,6 @@
   EXPECT_EQ(interleave_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  const TestCase &test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  std::unique_ptr<OpKernel> interleave_dataset_kernel;
-  TF_ASSERT_OK(CreateInterleaveDatasetKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &interleave_dataset_kernel));
-
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor cycle_length = test_case.cycle_length;
-  Tensor block_length = test_case.block_length;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
-       TensorValue(&block_length)});
-  std::unique_ptr<OpKernelContext> interleave_dataset_context;
-  TF_ASSERT_OK(CreateInterleaveDatasetContext(
-      interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context));
-  DatasetBase *interleave_dataset;
-  TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(),
-                             interleave_dataset_context.get(),
-                             &interleave_dataset));
-  core::ScopedUnref scoped_unref(interleave_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(interleave_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedInterleaveDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   const TestCase &test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 64b7f7c..08d9d93 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -52,7 +52,11 @@
 // See documentation in ../../ops/dataset_ops.cc for a high-level
 // description of the following ops.
 
+const char kAnonymousIterator[] = "AnonymousIterator";
+const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
 const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
+const char kOutputShapes[] = "output_shapes";
+const char kOutputTypes[] = "output_types";
 
 }  // namespace
 
@@ -70,6 +74,7 @@
     params.function_handle_cache = captured_state->function_handle_cache.get();
     params.resource_mgr = &captured_state->resource_mgr;
     params.thread_factory = unbounded_thread_pool_.get_thread_factory();
+    params.thread_pool = &unbounded_thread_pool_;
     params.cancellation_manager = &captured_state->cancellation_manager;
     std::function<void()> deregister_fn;
     TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
@@ -78,12 +83,11 @@
     auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
     return captured_state->iterator->GetNext(IteratorContext(std::move(params)),
                                              out_tensors, end_of_sequence);
-  } else {
-    return errors::FailedPrecondition(
-        "GetNext() failed because the iterator has not been initialized. "
-        "Ensure that you have run the initializer operation for this "
-        "iterator before getting the next element.");
   }
+  return errors::FailedPrecondition(
+      "GetNext() failed because the iterator has not been initialized. Ensure "
+      "that you have run the initializer operation for this iterator before "
+      "getting the next element.");
 }
 
 Status IteratorResource::Save(SerializationContext* ctx,
@@ -95,84 +99,40 @@
   }
   if (captured_state->iterator) {
     return captured_state->iterator->Save(ctx, writer);
-  } else {
-    return errors::FailedPrecondition(
-        "Save() failed because the iterator has not been initialized. "
-        "Ensure that you have run the initializer operation for this "
-        "iterator before saving it.");
   }
+  return errors::FailedPrecondition(
+      "Save() failed because the iterator has not been initialized. Ensure "
+      "that you have run the initializer operation for this iterator before "
+      "saving it.");
 }
 
 Status IteratorResource::Restore(OpKernelContext* ctx,
                                  IteratorStateReader* reader) {
-  string serialized_graph_def;
-  TF_RETURN_IF_ERROR(
-      reader->ReadScalar(DatasetBase::kDatasetGraphKey, &serialized_graph_def));
-  GraphDef graph_def;
-  if (!graph_def.ParseFromString(serialized_graph_def)) {
-    return errors::Internal("Error parsing dataset GraphDef.");
+  std::shared_ptr<State> captured_state;
+  {
+    tf_shared_lock l(mu_);
+    captured_state = iterator_state_;
   }
-  string output_node;
-  TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphOutputNodeKey,
-                                        &output_node));
-  DatasetBase* dataset = nullptr;
-  Graph graph(OpRegistry::Global());
-  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
-  std::vector<Tensor> outputs;
-  GraphRunner graph_runner(ctx->env());
-
-  // Build a new FLR that knows about the functions in the graph, and use
-  // it for all operations on the restored iterator.
-  // NOTE(mrry): We clone the existing FLR and use it in the GraphRunner
-  // because some of the OpKernels in the graph might call functions that are
-  // only defined in the loaded GraphDef.
-  FunctionLibraryRuntime* flr;
-  std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
-  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
-  TF_RETURN_IF_ERROR(
-      ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
-
-  // Some function names may be duplicated (for example, if the serialized
-  // graph has an optimized function that retains its original name). We
-  // override functions in flib_def in the event of conflict. It is
-  // safe to assume that any node in the serialized graph is referring to the
-  // serialized function when there is a conflict.
-  TF_RETURN_IF_ERROR(AddToFunctionLibrary(flib_def.get(), graph_def.library()));
-  auto new_state = absl::make_unique<State>(
-      std::move(flib_def), std::move(pflr), flr, /*iterator=*/nullptr);
-
-  TF_RETURN_IF_ERROR(
-      graph_runner.Run(&graph, new_state->flr, {}, {output_node}, &outputs));
-  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
-
-  IteratorContext::Params params(ctx);
-  params.flr = new_state->flr;
-  params.function_handle_cache = new_state->function_handle_cache.get();
-  params.resource_mgr = &new_state->resource_mgr;
-  DeviceBase* device = new_state->flr->device();
-  params.allocator_getter = [device](AllocatorAttributes attrs) {
-    return device->GetAllocator(attrs);
-  };
-  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
-  params.cancellation_manager = &new_state->cancellation_manager;
-  std::function<void()> deregister_fn;
-  TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
-                                                 params.cancellation_manager,
-                                                 &deregister_fn));
-  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
-  IteratorContext iter_ctx(std::move(params));
-
-  TF_RETURN_IF_ERROR(
-      dataset->MakeIterator(&iter_ctx, "Iterator", &new_state->iterator));
-  TF_RETURN_IF_ERROR(
-      VerifyTypesMatch(output_dtypes_, new_state->iterator->output_dtypes()));
-  TF_RETURN_IF_ERROR(VerifyShapesCompatible(
-      output_shapes_, new_state->iterator->output_shapes()));
-  TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
-
-  mutex_lock l(mu_);
-  iterator_state_ = std::move(new_state);
-  return Status::OK();
+  if (captured_state->iterator) {
+    IteratorContext::Params params(ctx);
+    params.flr = captured_state->flr;
+    params.function_handle_cache = captured_state->function_handle_cache.get();
+    params.resource_mgr = &captured_state->resource_mgr;
+    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
+    params.thread_pool = &unbounded_thread_pool_;
+    params.cancellation_manager = &captured_state->cancellation_manager;
+    std::function<void()> deregister_fn;
+    TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
+                                                   params.cancellation_manager,
+                                                   &deregister_fn));
+    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
+    IteratorContext iter_ctx(std::move(params));
+    return captured_state->iterator->Restore(&iter_ctx, reader);
+  }
+  return errors::FailedPrecondition(
+      "Restore() failed because the iterator has not been initialized. Ensure "
+      "that you have run the initializer operation for this iterator before "
+      "restoring it.");
 }
 
 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
@@ -191,6 +151,7 @@
   params.function_handle_cache = new_state->function_handle_cache.get();
   params.resource_mgr = &new_state->resource_mgr;
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
+  params.thread_pool = &unbounded_thread_pool_;
   params.cancellation_manager = &new_state->cancellation_manager;
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
@@ -305,8 +266,8 @@
 // resource containers with AnonymousIteratorHandleOp instead.
 IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
     : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
 }
 
@@ -413,19 +374,14 @@
 // running them.
 AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
     OpKernelConstruction* context)
-    : AnonymousIteratorResourceOp<IteratorResource>(context),
+    : AnonymousResourceOp<IteratorResource>(context),
       graph_def_version_(context->graph_def_version()) {
-  create_deleter_ = context->def().op() == "AnonymousIteratorV2";
+  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
+  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
+  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
 }
 
-static std::atomic<int64> current_iterator_id_;
-
-void AnonymousIteratorHandleOp::GenerateContainerNames(string* unique_name,
-                                                       string* container_name) {
-  *unique_name =
-      strings::StrCat("AnonymousIterator", current_iterator_id_.fetch_add(1));
-  *container_name = "AnonymousIterator";
-}
+string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
 
 Status AnonymousIteratorHandleOp::CreateResource(
     OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -585,8 +541,8 @@
     params.is_multi_device_function = true;
     OP_REQUIRES_OK(ctx,
                    FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   }
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
@@ -760,8 +716,8 @@
                                         "support the 'shared_name' attr."));
     OP_REQUIRES_OK(ctx,
                    ctx->GetAttr("dataset_factory", &dataset_factory_func_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   }
 
   ~OneShotIteratorOp() override {
@@ -1046,15 +1002,15 @@
   Tensor* string_handle_t;
   OP_REQUIRES_OK(ctx,
                  ctx->allocate_output(0, TensorShape({}), &string_handle_t));
-  string_handle_t->scalar<string>()() =
+  string_handle_t->scalar<tstring>()() =
       resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
 }
 
 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
     OpKernelConstruction* ctx)
     : OpKernel(ctx) {
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   OP_REQUIRES(
       ctx,
       output_dtypes_.empty() || output_shapes_.empty() ||
@@ -1070,7 +1026,7 @@
 
   ResourceHandle resource_handle;
   OP_REQUIRES(
-      ctx, resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+      ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
       errors::InvalidArgument(
           "Could not parse string_handle as a valid ResourceHandle"));
 
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 09c951f..07b88d4 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 #include "tensorflow/core/kernels/ops_util.h"
 
@@ -136,70 +137,16 @@
   string name_;
 };
 
-template <typename T>
-class AnonymousIteratorResourceOp : public OpKernel {
- public:
-  explicit AnonymousIteratorResourceOp(OpKernelConstruction* context)
-      : OpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
-    OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
-  }
-
-  void Compute(OpKernelContext* ctx) override {
-    FunctionLibraryRuntime* lib;
-    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
-    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
-    OP_REQUIRES_OK(
-        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
-    T* resource;
-    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
-                                       std::move(pflr), lib, &resource));
-
-    string unique_name, container_name;
-    GenerateContainerNames(&unique_name, &container_name);
-    ResourceMgr* mgr = ctx->resource_manager();
-    OP_REQUIRES_OK(ctx, mgr->Create<T>(container_name, unique_name, resource));
-
-    Tensor* handle_t;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
-    ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
-                                               MakeTypeIndex<T>());
-    handle_t->scalar<ResourceHandle>()() = handle;
-
-    if (create_deleter_) {
-      Tensor* deleter_t;
-      OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
-      deleter_t->scalar<Variant>()() =
-          ResourceDeleter(handle, ctx->resource_manager());
-    }
-  }
-
- protected:
-  virtual void GenerateContainerNames(string* unique_name,
-                                      string* container_name) = 0;
-
-  virtual Status CreateResource(
-      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
-      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
-      FunctionLibraryRuntime* lib, T** resource) = 0;
-
-  DataTypeVector output_dtypes_;
-  std::vector<PartialTensorShape> output_shapes_;
-  bool create_deleter_ = true;
-};
-
 // Like IteratorHandleOp, but creates handles which are never shared, and does
 // not hold a reference to these handles. The latter is important for eager
 // execution, since OpKernel instances generally live as long as the program
 // running them.
-class AnonymousIteratorHandleOp
-    : public AnonymousIteratorResourceOp<IteratorResource> {
+class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> {
  public:
   explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);
 
  private:
-  void GenerateContainerNames(string* unique_name,
-                              string* container_name) override;
+  string name() override;
 
   Status CreateResource(OpKernelContext* ctx,
                         std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -207,6 +154,8 @@
                         FunctionLibraryRuntime* lib,
                         IteratorResource** resource) override;
 
+  DataTypeVector output_dtypes_;
+  std::vector<PartialTensorShape> output_shapes_;
   const int graph_def_version_;
 };
 
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 26a56b0..0f36c6e 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -74,6 +74,11 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc
index 7dbe345..84b45a9 100644
--- a/tensorflow/core/kernels/data/map_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc
@@ -74,10 +74,10 @@
           FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::XTimesTwo()},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {6}),
+           CreateTensor<int64>(TensorShape({}), {12}),
+           CreateTensor<int64>(TensorShape({}), {18})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 4,
@@ -92,10 +92,10 @@
           FunctionDefHelper::FunctionRef("XAddX", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::XAddX()},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {20}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {14}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
+          {CreateTensor<int64>(TensorShape({}), {20}),
+           CreateTensor<int64>(TensorShape({}), {14}),
+           CreateTensor<int64>(TensorShape({}), {8}),
+           CreateTensor<int64>(TensorShape({}), {2})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 4,
@@ -113,10 +113,10 @@
       FunctionDefHelper::FunctionRef("XTimesFour", {{"T", DT_INT64}}),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -341,43 +341,6 @@
   EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.start, test_case.end, test_case.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  // The ownership of range_dataset is transferred to DatasetVariantWrapper,
-  // which will handle the release of memory.
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-  gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
-  map_dataset_inputs.emplace_back(&range_dataset_tensor);
-
-  std::unique_ptr<OpKernel> map_dataset_kernel;
-  TF_ASSERT_OK(CreateMapDatasetOpKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &map_dataset_kernel));
-  std::unique_ptr<OpKernelContext> map_dataset_context;
-  TF_ASSERT_OK(CreateMapDatasetContext(
-      map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
-  DatasetBase* map_dataset;
-  TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
-                             map_dataset_context.get(), &map_dataset));
-  core::ScopedUnref scoped_unref_map_dataset(map_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(map_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = TestCase1();
diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc
index 6db3376..39561ac 100644
--- a/tensorflow/core/kernels/data/map_defun_op_test.cc
+++ b/tensorflow/core/kernels/data/map_defun_op_test.cc
@@ -79,8 +79,8 @@
 // Test case 1: one input for the map function with no captured inputs.
 TestCase TestCase1() {
   return {
-      /*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
       /*captured_inputs*/ {},
       /*t_arguments*/ {DT_INT64},
       /*t_captured*/ {},
@@ -90,36 +90,34 @@
       /*output_dtypes*/ {DT_INT64},
       /*output_shapes*/ {PartialTensorShape({2})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                               {0, 2, 4, 6, 8, 10})}};
+      {CreateTensor<int64>(TensorShape({3, 2}), {0, 2, 4, 6, 8, 10})}};
 }
 
 // Test case 2: two inputs for the map function with no captured inputs.
 TestCase TestCase2() {
-  return {/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-                             TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
-                         DatasetOpsTestBase::CreateTensor<int64>(
-                             TensorShape({3, 2}), {0, 10, 20, 30, 40, 50})},
-          /*captured_inputs*/ {},
-          /*t_arguments*/ {DT_INT64, DT_INT64},
-          /*t_captured*/ {},
-          /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
-          /*func_lib*/ {test::function::XAddY()},
-          /*max_intra_op_parallelism*/ 2,
-          /*output_dtypes*/ {DT_INT64},
-          /*output_shapes*/ {PartialTensorShape({2})},
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                                   {0, 11, 22, 33, 44, 55})}};
+  return {
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 10, 20, 30, 40, 50})},
+      /*captured_inputs*/ {},
+      /*t_arguments*/ {DT_INT64, DT_INT64},
+      /*t_captured*/ {},
+      /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
+      /*func_lib*/ {test::function::XAddY()},
+      /*max_intra_op_parallelism*/ 2,
+      /*output_dtypes*/ {DT_INT64},
+      /*output_shapes*/ {PartialTensorShape({2})},
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape({3, 2}), {0, 11, 22, 33, 44, 55})}};
 }
 
 // Test case 3: two inputs for the map function with one captured input.
 TestCase TestCase3() {
   return {
-      /*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
       /*captured_inputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
+      {CreateTensor<int64>(TensorShape({2}), {10, 100})},
       /*t_arguments*/ {DT_INT64},
       /*t_captured*/ {DT_INT64},
       /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
@@ -128,16 +126,15 @@
       /*output_dtypes*/ {DT_INT64},
       /*output_shapes*/ {PartialTensorShape({2})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                               {10, 101, 12, 103, 14, 105})}};
+      {CreateTensor<int64>(TensorShape({3, 2}), {10, 101, 12, 103, 14, 105})}};
 }
 
 TestCase InvalidOutputTypes() {
   return {
-      /*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
       /*captured_inputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
+      {CreateTensor<int64>(TensorShape({2}), {10, 100})},
       /*t_arguments*/ {DT_INT64},
       /*t_captured*/ {DT_INT64},
       /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
@@ -146,16 +143,15 @@
       /*output_dtypes*/ {DT_FLOAT},
       /*output_shapes*/ {PartialTensorShape({2})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                               {10, 101, 12, 103, 14, 105})}};
+      {CreateTensor<int64>(TensorShape({3, 2}), {10, 101, 12, 103, 14, 105})}};
 }
 
 TestCase InvalidOutputShapes() {
   return {
-      /*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
       /*captured_inputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
+      {CreateTensor<int64>(TensorShape({2}), {10, 100})},
       /*t_arguments*/ {DT_INT64},
       /*t_captured*/ {DT_INT64},
       /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
@@ -164,18 +160,16 @@
       /*output_dtypes*/ {DT_INT64},
       /*output_shapes*/ {PartialTensorShape({2, 2})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                               {10, 101, 12, 103, 14, 105})}};
+      {CreateTensor<int64>(TensorShape({3, 2}), {10, 101, 12, 103, 14, 105})}};
 }
 
 TestCase InvalidInputs() {
   return {
-      /*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
-                         TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
-                     DatasetOpsTestBase::CreateTensor<int64>(
-                         TensorShape({2, 2}), {0, 1, 2, 3})},
+      /*arguments*/ {
+          CreateTensor<int64>(TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
+          CreateTensor<int64>(TensorShape({2, 2}), {0, 1, 2, 3})},
       /*captured_inputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
+      {CreateTensor<int64>(TensorShape({2}), {10, 100})},
       /*t_arguments*/ {DT_INT64, DT_INT64},
       /*t_captured*/ {DT_INT64},
       /*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
@@ -184,8 +178,7 @@
       /*output_dtypes*/ {DT_INT64},
       /*output_shapes*/ {PartialTensorShape({2})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
-                                               {10, 101, 12, 103, 14, 105})}};
+      {CreateTensor<int64>(TensorShape({3, 2}), {10, 101, 12, 103, 14, 105})}};
 }
 
 class ParameterizedMapDefunOpTest
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 6f347ef..3e25803 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -85,6 +85,10 @@
 
     int64 Cardinality() const override { return input_->Cardinality(); }
 
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
    protected:
     Status AsGraphDefInternal(SerializationContext* ctx,
                               DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 99d6304..7a538d7 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -27,7 +27,6 @@
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/refcount.h"
-#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -36,6 +35,11 @@
 namespace data {
 namespace {
 
+const char kAnonymousMultiDeviceIterator[] = "AnonymousMultiDeviceIterator";
+const char kDevices[] = "devices";
+const char kOutputShapes[] = "output_shapes";
+const char kOutputTypes[] = "output_types";
+
 struct HostBufferElement {
   Status status;
   bool end_of_sequence;
@@ -102,6 +106,7 @@
     params.function_handle_cache = function_handle_cache_.get();
     params.resource_mgr = &resource_mgr_;
     params.thread_factory = unbounded_thread_pool_.get_thread_factory();
+    params.thread_pool = &unbounded_thread_pool_;
     params.cancellation_manager = &cancellation_manager_;
     std::function<void()> deregister_fn;
     OP_REQUIRES_OK_ASYNC(ctx,
@@ -199,7 +204,9 @@
                           MultiDeviceIteratorCallback callback) {
       HostBufferElement elem;
       if (incarnation_id_ != incarnation_id) {
-        elem.status = errors::InvalidArgument("Invalid incarnation id");
+        elem.status = errors::InvalidArgument(
+            "Invalid incarnation id. Provided: ", incarnation_id,
+            "; Expected: ", incarnation_id_);
         callback(elem);
         return;
       }
@@ -385,7 +392,6 @@
   const std::unique_ptr<FunctionHandleCache> function_handle_cache_;
   ResourceMgr resource_mgr_;
   CancellationManager cancellation_manager_;
-  std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
 
   int64 incarnation_id_ GUARDED_BY(mu_) = 0;
   std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
@@ -399,11 +405,11 @@
  public:
   explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
       : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
   }
 
   // The resource is deleted from the resource manager only when it is private
@@ -443,7 +449,7 @@
         if (name_ == ResourceHandle::ANONYMOUS_NAME) {
           unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
                                         current_id_.fetch_add(1));
-          container_name = "AnonymousMultiDeviceIterator";
+          container_name = kAnonymousMultiDeviceIterator;
           resource = new MultiDeviceIterator(
               context->env(), output_types_, output_shapes_, devices_,
               std::move(flib_def), std::move(pflr), flr,
@@ -511,26 +517,18 @@
 REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
                         MultiDeviceIteratorHandleOp);
 
-// This atomic is used to ensure that each new AnonymousMultiDeviceIterator
-// handle is unique.
-static std::atomic<int64> current_multi_device_iterator_id_;
-
 class AnonymousMultiDeviceIteratorOp
-    : public AnonymousIteratorResourceOp<MultiDeviceIterator> {
+    : public AnonymousResourceOp<MultiDeviceIterator> {
  public:
   explicit AnonymousMultiDeviceIteratorOp(OpKernelConstruction* ctx)
-      : AnonymousIteratorResourceOp<MultiDeviceIterator>(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+      : AnonymousResourceOp<MultiDeviceIterator>(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   }
 
  private:
-  void GenerateContainerNames(string* unique_name,
-                              string* container_name) override {
-    *unique_name =
-        strings::StrCat("_AnonymousMultiDeviceIterator",
-                        current_multi_device_iterator_id_.fetch_add(1));
-    *container_name = "AnonymousMultiDeviceIterator";
-  }
+  string name() override { return kAnonymousMultiDeviceIterator; }
 
   Status CreateResource(OpKernelContext* ctx,
                         std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -546,9 +544,11 @@
   }
 
   std::vector<string> devices_;
+  DataTypeVector output_dtypes_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("AnonymousMultiDeviceIterator").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name(kAnonymousMultiDeviceIterator).Device(DEVICE_CPU),
                         AnonymousMultiDeviceIteratorOp);
 
 // Calls init on the MultiDeviceIterator.
@@ -644,7 +644,7 @@
     Tensor* string_handle_t;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output(0, TensorShape({}), &string_handle_t));
-    string_handle_t->scalar<string>()() =
+    string_handle_t->scalar<tstring>()() =
         resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
   }
 };
@@ -657,8 +657,8 @@
  public:
   explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
       : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
     OP_REQUIRES(
         ctx,
         output_types_.empty() || output_shapes_.empty() ||
@@ -675,7 +675,7 @@
     ResourceHandle resource_handle;
     OP_REQUIRES(
         ctx,
-        resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+        resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
         errors::InvalidArgument(
             "Could not parse string_handle as a valid ResourceHandle"));
 
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index bc7234c..5c23273 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -114,6 +114,10 @@
     return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc
index 6e1b06c..c906bfc 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc
@@ -147,70 +147,60 @@
   std::vector<Tensor> tensors;
   tensors.reserve(values.size());
   for (auto &value : values) {
-    tensors.emplace_back(
-        DatasetOpsTestBase::CreateTensor<T>(TensorShape({1}), {value}));
+    tensors.emplace_back(CreateTensor<T>(TensorShape({1}), {value}));
   }
   return tensors;
 }
 
 // Test case 1: input elements with same shapes.
 TestCase TestCase1() {
-  return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(
-               TensorShape{4, 2}, {6, 7, 8, 9, 10, 11, 12, 13})}},
-          /*concatenate_output_dtypes*/ {DT_INT64},
-          /*concatenate_output_shapes*/ {PartialTensorShape({2})},
-          /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-          /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
-          /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
-          /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {true}),
-          /*parallel_copy*/ true,
-          /*n*/ 1,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {0, 1, 1, 2, 3, 1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {4, 5, 1, 6, 7, 1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {8, 9, 1, 10, 11, 1})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({2, 3})},
-          /*expected_cardinality*/ 3,
-          /*breakpoints*/ {0, 2, 5}};
+  return {
+      /*input_tensors*/
+      {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+       {CreateTensor<int64>(TensorShape{4, 2}, {6, 7, 8, 9, 10, 11, 12, 13})}},
+      /*concatenate_output_dtypes*/ {DT_INT64},
+      /*concatenate_output_shapes*/ {PartialTensorShape({2})},
+      /*batch_size*/
+      CreateTensor<int64>(TensorShape{}, {2}),
+      /*padded_shapes*/
+      {CreateTensor<int64>(TensorShape{1}, {3})},
+      /*padding_values*/
+      {CreateTensor<int64>(TensorShape{}, {1})},
+      /*drop_remainder*/
+      CreateTensor<bool>(TensorShape{}, {true}),
+      /*parallel_copy*/ true,
+      /*n*/ 1,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
+       CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 7, 1}),
+       CreateTensor<int64>(TensorShape{2, 3}, {8, 9, 1, 10, 11, 1})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({2, 3})},
+      /*expected_cardinality*/ 3,
+      /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 2: input elements with different shapes.
 TestCase TestCase2() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{4, 1},
-                                                    {6, 7, 8, 9})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{4, 1}, {6, 7, 8, 9})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {true}),
+          CreateTensor<bool>(TensorShape{}, {true}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {0, 1, 1, 2, 3, 1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {4, 5, 1, 6, 1, 1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                                   {7, 1, 1, 8, 1, 1})},
+          {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({2, 3})},
           /*expected_cardinality*/ 3,
@@ -219,149 +209,132 @@
 
 // Test case 3: similar with the test case 2 but drop_remainder = false.
 TestCase TestCase3() {
-  return {
-      /*input_tensors*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                {0, 1, 2, 3, 4, 5})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{4, 1},
-                                                {6, 7, 8, 9})}},
-      /*concatenate_output_dtypes*/ {DT_INT64},
-      /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-      /*padded_shapes*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
-      /*padding_values*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
-      /*parallel_copy*/ false,
-      /*n*/ 1,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {0, 1, 1, 2, 3, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {4, 5, 1, 6, 1, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {7, 1, 1, 8, 1, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1, 3}, {9, 1, 1})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({-1, 3})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 2, 5}};
+  return {/*input_tensors*/
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{4, 1}, {6, 7, 8, 9})}},
+          /*concatenate_output_dtypes*/ {DT_INT64},
+          /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape{}, {2}),
+          /*padded_shapes*/
+          {CreateTensor<int64>(TensorShape{1}, {3})},
+          /*padding_values*/
+          {CreateTensor<int64>(TensorShape{}, {1})},
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape{}, {false}),
+          /*parallel_copy*/ false,
+          /*n*/ 1,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1}),
+           CreateTensor<int64>(TensorShape{1, 3}, {9, 1, 1})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({-1, 3})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 4: similar with the test case 3 but the input elements can be
 // divided by the batch size evenly. As drop_remainder = false, the output
 // shape is still {-1, 3} instead of {2, 3}.
 TestCase TestCase4() {
-  return {
-      /*input_tensors*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                {0, 1, 2, 3, 4, 5})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})}},
-      /*concatenate_output_dtypes*/ {DT_INT64},
-      /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-      /*padded_shapes*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
-      /*padding_values*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
-      /*parallel_copy*/ false,
-      /*n*/ 1,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {0, 1, 1, 2, 3, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {4, 5, 1, 6, 1, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
-                                               {7, 1, 1, 8, 1, 1})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({-1, 3})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 2, 5}};
+  return {/*input_tensors*/
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})}},
+          /*concatenate_output_dtypes*/ {DT_INT64},
+          /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape{}, {2}),
+          /*padded_shapes*/
+          {CreateTensor<int64>(TensorShape{1}, {3})},
+          /*padding_values*/
+          {CreateTensor<int64>(TensorShape{}, {1})},
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape{}, {false}),
+          /*parallel_copy*/ false,
+          /*n*/ 1,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
+           CreateTensor<int64>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({-1, 3})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 5: similar with the test case 3 but padded_shapes = {-1}.
 TestCase TestCase5() {
-  return {
-      /*input_tensors*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                {0, 1, 2, 3, 4, 5})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{4, 1},
-                                                {6, 7, 8, 9})}},
-      /*concatenate_output_dtypes*/ {DT_INT64},
-      /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-      /*padded_shapes*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {-1})},
-      /*padding_values*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
-      /*parallel_copy*/ false,
-      /*n*/ 1,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1, 1}, {9})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({-1, -1})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 2, 5}};
+  return {/*input_tensors*/
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{4, 1}, {6, 7, 8, 9})}},
+          /*concatenate_output_dtypes*/ {DT_INT64},
+          /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape{}, {2}),
+          /*padded_shapes*/
+          {CreateTensor<int64>(TensorShape{1}, {-1})},
+          /*padding_values*/
+          {CreateTensor<int64>(TensorShape{}, {1})},
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape{}, {false}),
+          /*parallel_copy*/ false,
+          /*n*/ 1,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
+           CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
+           CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
+           CreateTensor<int64>(TensorShape{1, 1}, {9})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({-1, -1})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 6: similar with the test case 5 but parallel_copy = true.
 TestCase TestCase6() {
-  return {
-      /*input_tensors*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                {0, 1, 2, 3, 4, 5})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{4, 1},
-                                                {6, 7, 8, 9})}},
-      /*concatenate_output_dtypes*/ {DT_INT64},
-      /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
-      /*batch_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-      /*padded_shapes*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {-1})},
-      /*padding_values*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
-      /*parallel_copy*/ true,
-      /*n*/ 1,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1, 1}, {9})},
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({-1, -1})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 2, 5}};
+  return {/*input_tensors*/
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{4, 1}, {6, 7, 8, 9})}},
+          /*concatenate_output_dtypes*/ {DT_INT64},
+          /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
+          /*batch_size*/
+          CreateTensor<int64>(TensorShape{}, {2}),
+          /*padded_shapes*/
+          {CreateTensor<int64>(TensorShape{1}, {-1})},
+          /*padding_values*/
+          {CreateTensor<int64>(TensorShape{}, {1})},
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape{}, {false}),
+          /*parallel_copy*/ true,
+          /*n*/ 1,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
+           CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
+           CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
+           CreateTensor<int64>(TensorShape{1, 1}, {9})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({-1, -1})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 7: empty input elements.
 TestCase TestCase7() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})}},
+          {{CreateTensor<int64>(TensorShape{0}, {})},
+           {CreateTensor<int64>(TensorShape{0}, {})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({-1})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {-1})},
+          {CreateTensor<int64>(TensorShape{1}, {-1})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -373,20 +346,18 @@
 
 TestCase ShortPaddingTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1})},
+          {CreateTensor<int64>(TensorShape{1}, {1})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -398,20 +369,18 @@
 
 TestCase InvalidPaddingShapesTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2})},
+          {CreateTensor<int64>(TensorShape{2}, {1, 2})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -423,20 +392,18 @@
 
 TestCase InvalidBatchSizeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {-1}),
+          CreateTensor<int64>(TensorShape{}, {-1}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -448,21 +415,19 @@
 
 TestCase InvalidPaddedShapesSizeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3}),
+           CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 2,
           /*expected_outputs*/ {},
@@ -474,21 +439,19 @@
 
 TestCase InvalidPaddedValuesSizeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1})},
+          {CreateTensor<int64>(TensorShape{}, {1}),
+           CreateTensor<int64>(TensorShape{}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -500,20 +463,18 @@
 
 TestCase InvalidPaddedValuesDTypeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape{}, {"a"})},
+          {CreateTensor<string>(TensorShape{}, {"a"})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -525,20 +486,18 @@
 
 TestCase InvalidPaddedValuesShapeTestCase() {
   return {/*input_tensors*/
-          {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {0, 1, 2, 3, 4, 5})},
-           {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 2},
-                                                    {6, 7, 8, 9, 10, 11})}},
+          {{CreateTensor<int64>(TensorShape{3, 2}, {0, 1, 2, 3, 4, 5})},
+           {CreateTensor<int64>(TensorShape{3, 2}, {6, 7, 8, 9, 10, 11})}},
           /*concatenate_output_dtypes*/ {DT_INT64},
           /*concatenate_output_shapes*/ {PartialTensorShape({2})},
           /*batch_size*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
+          CreateTensor<int64>(TensorShape{}, {2}),
           /*padded_shapes*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          {CreateTensor<int64>(TensorShape{1}, {3})},
           /*padding_values*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1})},
+          {CreateTensor<int64>(TensorShape{1}, {1})},
           /*drop_remainder*/
-          DatasetOpsTestBase::CreateTensor<bool>(TensorShape{}, {false}),
+          CreateTensor<bool>(TensorShape{}, {false}),
           /*parallel_copy*/ true,
           /*n*/ 1,
           /*expected_outputs*/ {},
@@ -826,53 +785,6 @@
             test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  const TestCase &test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  std::unique_ptr<OpKernel> padded_batch_dataset_kernel;
-  TF_ASSERT_OK(CreatePaddedBatchDatasetKernel(
-      test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &padded_batch_dataset_kernel));
-
-  Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(CreateConcatenateDatasetTensor(
-      test_case.input_tensors, test_case.concatenate_output_dtypes,
-      test_case.concatenate_output_shapes, &concatenate_dataset_tensor));
-  Tensor batch_size = test_case.batch_size;
-  std::vector<Tensor> padded_shapes = test_case.padded_shapes;
-  std::vector<Tensor> padding_values = test_case.padding_values;
-  Tensor drop_remainder = test_case.drop_remainder;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&concatenate_dataset_tensor), TensorValue(&batch_size)});
-  for (auto &padded_shape : padded_shapes) {
-    inputs.emplace_back(&padded_shape);
-  }
-  for (auto &padding_value : padding_values) {
-    inputs.emplace_back(&padding_value);
-  }
-  inputs.emplace_back(&drop_remainder);
-
-  std::unique_ptr<OpKernelContext> padded_batch_dataset_context;
-  TF_ASSERT_OK(
-      CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(),
-                                      &inputs, &padded_batch_dataset_context));
-  DatasetBase *padded_batch_dataset;
-  TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(),
-                             padded_batch_dataset_context.get(),
-                             &padded_batch_dataset));
-  core::ScopedUnref scoped_unref(padded_batch_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(padded_batch_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedPaddedBatchDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   const TestCase &test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 3e21399..82840d5 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -153,6 +153,11 @@
         ParallelInterleaveDatasetOp::kDatasetType, params);
   }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -200,23 +205,7 @@
           num_parallel_calls_(std::make_shared<model::SharedState>(
               params.dataset->num_parallel_calls_, mu_, cond_var_)),
           sloppy_(sloppy),
-          current_elements_(params.dataset->cycle_length_) {
-      // The size of the threadpool is the smaller of:
-      //
-      // 1) The number of schedulable CPUs multiplied by a constant factor
-      //    factor to account for the fact that some threads may perform I/O.
-      //
-      // 2) The maximum number of iterators instantiated at any given point
-      //    in time (`cycle_length` for the current cycle elements and
-      //    `kPrefetchFactor * cycle_length` for future cycle elements).
-      const int num_threads =
-          std::min(static_cast<int>(kCPUFactor * port::NumSchedulableCPUs()),
-                   static_cast<int>((kPrefetchFactor + 1) *
-                                    params.dataset->cycle_length_));
-      thread_pool_ = absl::make_unique<thread::ThreadPool>(
-          Env::Default(), ThreadOptions(), kDataParallelInterleaveWorkerPool,
-          num_threads, /*low_latency_hint=*/false);
-    }
+          current_elements_(params.dataset->cycle_length_) {}
 
     ~ParallelInterleaveIterator() override {
       mutex_lock l(*mu_);
@@ -238,6 +227,24 @@
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
+      // The size of the threadpool `num_threads` is the smaller of:
+      //
+      // 1) The number of schedulable CPUs multiplied by a constant factor
+      //    factor to account for the fact that some threads may perform I/O.
+      //
+      // 2) The maximum number of iterators instantiated at any given point
+      //    in time (`cycle_length` for the current cycle elements and
+      //    `kPrefetchFactor * cycle_length` for future cycle elements).
+      //
+      // Note that if `ctx->thread_pool()` is non-null, then instead of creating
+      // a dedicated thread pool of size `num_threads`, computation will be
+      // scheduled into the shared threadpool whose size is independent of
+      // `num_threads`.
+      const int num_threads = std::min(
+          static_cast<int>(kCPUFactor * port::NumSchedulableCPUs()),
+          static_cast<int>((kPrefetchFactor + 1) * dataset()->cycle_length_));
+      thread_pool_ =
+          ctx->CreateThreadPool(kDataParallelInterleaveWorkerPool, num_threads);
       if (num_parallel_calls_->value == model::kAutotune) {
         num_parallel_calls_->value = dataset()->cycle_length_;
       }
@@ -560,7 +567,6 @@
                       int64 num_results, std::function<void()> done)
         LOCKS_EXCLUDED(*mu_) {
       RecordStart(ctx.get());
-      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
       bool end_of_input = false;
       for (int64 i = 0; i < num_results; ++i) {
         auto result = std::make_shared<Result>();
@@ -588,6 +594,7 @@
       }
       done();
       cond_var_->notify_all();
+      RecordStop(ctx.get());
     }
 
     // Manages futures cycle elements, creating new iterators as needed and
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc
index c2e66ec..aeadba0 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc
@@ -89,8 +89,7 @@
   std::vector<Tensor> tensors;
   tensors.reserve(values.size());
   for (auto &value : values) {
-    tensors.emplace_back(
-        DatasetOpsTestBase::CreateTensor<T>(TensorShape({1}), {value}));
+    tensors.emplace_back(CreateTensor<T>(TensorShape({1}), {value}));
   }
   return tensors;
 }
@@ -107,105 +106,105 @@
 // test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
 // sloppy = false
 TestCase TestCase1() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*sloppy*/ false,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*sloppy*/ false,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
 // sloppy = false
 TestCase TestCase2() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*sloppy*/ false,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 1, 4, 2, 5, 6, 7, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*sloppy*/ false,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 1, 4, 2, 5, 6, 7, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
 // sloppy = true
 TestCase TestCase3() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*sloppy*/ true,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {3}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*sloppy*/ true,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
 // sloppy = true
 TestCase TestCase4() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-          /*sloppy*/ true,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {5}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {4}),
+      /*sloppy*/ true,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({0, 3, 6, 1, 4, 7, 2, 5, 8}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 // test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
@@ -213,19 +212,19 @@
 TestCase TestCase5() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
+      CreateTensor<int64>(TensorShape({}), {1}),
       /*sloppy*/ false,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "d", "e", "c", "f", "g", "h", "i"}),
@@ -240,19 +239,19 @@
 TestCase TestCase6() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*sloppy*/ true,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
@@ -267,19 +266,19 @@
 TestCase TestCase7() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*sloppy*/ false,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "d", "e", "g", "h", "c", "f", "i"}),
@@ -294,19 +293,19 @@
 TestCase TestCase8() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*sloppy*/ true,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
@@ -321,19 +320,19 @@
 TestCase TestCase9() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*sloppy*/ true,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
@@ -348,20 +347,19 @@
 TestCase TestCase10() {
   return {
       /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<string>(
-          TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
+      {CreateTensor<string>(TensorShape{3, 3, 1},
+                            {"a", "b", "c", "d", "e", "f", "g", "h", "i"})},
       /*func*/
       MakeTensorSliceDatasetFunc(
           DataTypeVector({DT_STRING}),
           std::vector<PartialTensorShape>({PartialTensorShape({1})})),
       /*func_lib*/ {test::function::MakeTensorSliceDataset()},
       /*cycle_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*block_length*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}),
-                                              {model::kAutotune}),
+      CreateTensor<int64>(TensorShape({}), {model::kAutotune}),
       /*sloppy*/ true,
       /*expected_outputs*/
       ConvertToTensorVec<string>({"a", "b", "c", "d", "e", "f", "g", "h", "i"}),
@@ -374,79 +372,79 @@
 // test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
 // sloppy = true
 TestCase InvalidCycleLengthTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*sloppy*/ true,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {0}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*sloppy*/ true,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 // test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
 // sloppy = true
 TestCase InvalidBlockLengthTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-          /*sloppy*/ true,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {-1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {2}),
+      /*sloppy*/ true,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 // test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
 // sloppy = true
 TestCase InvalidNumParallelCallsTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
-          /*func*/
-          MakeTensorSliceDatasetFunc(
-              DataTypeVector({DT_INT64}),
-              std::vector<PartialTensorShape>({PartialTensorShape({1})})),
-          /*func_lib*/ {test::function::MakeTensorSliceDataset()},
-          /*cycle_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*block_length*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-          /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-5}),
-          /*sloppy*/ true,
-          /*expected_outputs*/
-          ConvertToTensorVec<int64>({}),
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
-          /*breakpoints*/ {}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*func*/
+      MakeTensorSliceDatasetFunc(
+          DataTypeVector({DT_INT64}),
+          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
+      /*func_lib*/ {test::function::MakeTensorSliceDataset()},
+      /*cycle_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*block_length*/
+      CreateTensor<int64>(TensorShape({}), {1}),
+      /*num_parallel_calls*/
+      CreateTensor<int64>(TensorShape({}), {-5}),
+      /*sloppy*/ true,
+      /*expected_outputs*/
+      ConvertToTensorVec<int64>({}),
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ tensorflow::data::kUnknownCardinality,
+      /*breakpoints*/ {}};
 }
 
 class ParameterizedParallelInterleaveDatasetOpTest
@@ -726,47 +724,6 @@
             test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  const TestCase &test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
-  TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, test_case.sloppy,
-      &parallel_interleave_dataset_kernel));
-
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor cycle_length = test_case.cycle_length;
-  Tensor block_length = test_case.block_length;
-  Tensor num_parallel_calls = test_case.num_parallel_calls;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
-       TensorValue(&block_length), TensorValue(&num_parallel_calls)});
-  std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
-  TF_ASSERT_OK(CreateInterleaveDatasetContext(
-      parallel_interleave_dataset_kernel.get(), &inputs,
-      &parallel_interleave_dataset_context));
-  DatasetBase *parallel_interleave_dataset;
-  TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
-                             parallel_interleave_dataset_context.get(),
-                             &parallel_interleave_dataset));
-  core::ScopedUnref scoped_unref(parallel_interleave_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(
-      parallel_interleave_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   const TestCase &test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 4ec8711..625d672 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -89,6 +89,11 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
index 34f19e0..4870b7c 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
@@ -86,17 +86,17 @@
 TestCase TestCase1() {
   return {/*range_data_param*/ {0, 10, 3},
           /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
+          CreateTensor<int64>(TensorShape({}), {1}),
           /*func*/ MapFunc("XTimesTwo", DT_INT64),
           /*func_lib*/ {test::function::XTimesTwo()},
           /*use_inter_op_parallelism*/ false,
           /*sloppy*/ false,
           /*preserve_cardinality*/ false,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {6}),
+           CreateTensor<int64>(TensorShape({}), {12}),
+           CreateTensor<int64>(TensorShape({}), {18})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 4,
@@ -108,17 +108,17 @@
 TestCase TestCase2() {
   return {/*range_data_param*/ {0, 10, 3},
           /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+          CreateTensor<int64>(TensorShape({}), {2}),
           /*func*/ MapFunc("XTimesTwo", DT_INT64),
           /*func_lib*/ {test::function::XTimesTwo()},
           /*use_inter_op_parallelism*/ true,
           /*sloppy*/ true,
           /*preserve_cardinality*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {6}),
+           CreateTensor<int64>(TensorShape({}), {12}),
+           CreateTensor<int64>(TensorShape({}), {18})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 4,
@@ -131,17 +131,17 @@
   return {
       /*range_data_param*/ {0, 10, 3},
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+      CreateTensor<int64>(TensorShape({}), {3}),
       /*func*/ MapFunc("XTimesFour", DT_INT64),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*use_inter_op_parallelism*/ true,
       /*sloppy*/ false,
       /*preserve_cardinality*/ false,
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -153,17 +153,17 @@
 TestCase TestCase4() {
   return {/*range_data_param*/ {0, 10, 3},
           /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+          CreateTensor<int64>(TensorShape({}), {4}),
           /*func*/ MapFunc("XTimesTwo", DT_INT64),
           /*func_lib*/ {test::function::XTimesTwo()},
           /*use_inter_op_parallelism*/ false,
           /*sloppy*/ false,
           /*preserve_cardinality*/ false,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {6}),
+           CreateTensor<int64>(TensorShape({}), {12}),
+           CreateTensor<int64>(TensorShape({}), {18})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 4,
@@ -176,18 +176,17 @@
   return {
       /*range_data_param*/ {0, 10, 3},
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}),
-                                              {model::kAutotune}),
+      CreateTensor<int64>(TensorShape({}), {model::kAutotune}),
       /*func*/ MapFunc("XTimesFour", DT_INT64),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*use_inter_op_parallelism*/ true,
       /*sloppy*/ true,
       /*preserve_cardinality*/ true,
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -200,17 +199,17 @@
   return {
       /*range_data_param*/ {0, 10, 3},
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+      CreateTensor<int64>(TensorShape({}), {4}),
       /*func*/ MapFunc("XTimesFour", DT_INT64),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*use_inter_op_parallelism*/ true,
       /*sloppy*/ false,
       /*preserve_cardinality*/ false,
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -224,17 +223,17 @@
   return {
       /*range_data_param*/ {0, 10, 3},
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+      CreateTensor<int64>(TensorShape({}), {2}),
       /*func*/ MapFunc("XTimesFour", DT_INT64),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*use_inter_op_parallelism*/ false,
       /*sloppy*/ false,
       /*preserve_cardinality*/ false,
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -248,18 +247,17 @@
   return {
       /*range_data_param*/ {0, 10, 3},
       /*num_parallel_calls*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}),
-                                              {model::kAutotune}),
+      CreateTensor<int64>(TensorShape({}), {model::kAutotune}),
       /*func*/ MapFunc("XTimesFour", DT_INT64),
       /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
       /*use_inter_op_parallelism*/ false,
       /*sloppy*/ true,
       /*preserve_cardinality*/ true,
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
+      {CreateTensor<int64>(TensorShape({}), {0}),
+       CreateTensor<int64>(TensorShape({}), {12}),
+       CreateTensor<int64>(TensorShape({}), {24}),
+       CreateTensor<int64>(TensorShape({}), {36})},
       /*expected_output_dtypes*/ {DT_INT64},
       /*expected_output_shapes*/ {PartialTensorShape({})},
       /*expected_cardinality*/ 4,
@@ -269,7 +267,7 @@
 TestCase InvalidNumParallelCallsTestCase() {
   return {/*range_data_param*/ {0, 10, 3},
           /*num_parallel_calls*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-4}),
+          CreateTensor<int64>(TensorShape({}), {-4}),
           /*func*/ MapFunc("XTimesTwo", DT_INT64),
           /*func_lib*/ {test::function::XTimesTwo()},
           /*use_inter_op_parallelism*/ true,
@@ -529,49 +527,6 @@
             test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
-
-  std::unique_ptr<OpKernel> parallel_map_dataset_kernel;
-  TF_ASSERT_OK(CreateParallelMapDatasetOpKernel(
-      test_case.func, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, test_case.use_inter_op_parallelism,
-      test_case.sloppy, test_case.preserve_cardinality,
-      &parallel_map_dataset_kernel));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.range_data_param.start, test_case.range_data_param.end,
-      test_case.range_data_param.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-  Tensor num_parallel_calls = test_case.num_parallel_calls;
-  gtl::InlinedVector<TensorValue, 4> parallel_map_dataset_inputs(
-      {TensorValue(&range_dataset_tensor), TensorValue(&num_parallel_calls)});
-
-  std::unique_ptr<OpKernelContext> parallel_map_dataset_context;
-  TF_ASSERT_OK(CreateParallelMapDatasetContext(
-      parallel_map_dataset_kernel.get(), &parallel_map_dataset_inputs,
-      &parallel_map_dataset_context));
-  DatasetBase* parallel_map_dataset;
-  TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(),
-                             parallel_map_dataset_context.get(),
-                             &parallel_map_dataset));
-  core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(
-      parallel_map_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedParallelMapDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index ec6cec0..490c123 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -39,6 +39,7 @@
 /* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
 /* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
+/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
 
 // Determines the fraction of slack time by which to delay prefetching of data.
 constexpr double kSleepFactor = 0.2;
@@ -51,11 +52,12 @@
 class PrefetchDatasetOp::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
-          int64 slack_period)
+          int64 slack_period, bool legacy_autotune)
       : DatasetBase(DatasetContext(ctx)),
         input_(input),
         buffer_size_(buffer_size),
-        slack_period_(slack_period) {
+        slack_period_(slack_period),
+        legacy_autotune_(legacy_autotune) {
     input_->Ref();
   }
 
@@ -81,6 +83,10 @@
 
   int64 Cardinality() const override { return input_->Cardinality(); }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -451,6 +457,9 @@
   // If non-zero, determines the period between injecting "slack" into the
   // execution.
   const int64 slack_period_;
+
+  // Determines whether legacy autotuning should be used.
+  const bool legacy_autotune_ = true;
 };
 
 PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
@@ -458,6 +467,9 @@
   if (ctx->HasAttr(kSlackPeriod)) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
   }
+  if (ctx->HasAttr(kLegacyAutotune)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
+  }
 }
 
 void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@@ -474,7 +486,8 @@
     metrics::RecordTFDataAutotune(kDatasetType);
   }
 
-  *output = new Dataset(ctx, input, buffer_size, slack_period_);
+  *output =
+      new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_);
 }
 
 namespace {
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
index 17df807..999f002 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.h
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -30,6 +30,7 @@
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";
   static constexpr const char* const kSlackPeriod = "slack_period";
+  static constexpr const char* const kLegacyAutotune = "legacy_autotune";
 
   explicit PrefetchDatasetOp(OpKernelConstruction* ctx);
 
@@ -40,6 +41,7 @@
  private:
   class Dataset;
   int64 slack_period_ = 0;
+  bool legacy_autotune_ = true;
 };
 
 }  // namespace data
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc
index 3cd70c8..03c193f 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc
@@ -44,7 +44,8 @@
         {PrefetchDatasetOp::kInputDataset, PrefetchDatasetOp::kBufferSize},
         {{PrefetchDatasetOp::kOutputTypes, output_types},
          {PrefetchDatasetOp::kOutputShapes, output_shapes},
-         {PrefetchDatasetOp::kSlackPeriod, 0}});
+         {PrefetchDatasetOp::kSlackPeriod, 0},
+         {PrefetchDatasetOp::kLegacyAutotune, true}});
     TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
     return Status::OK();
   }
@@ -70,81 +71,81 @@
 };
 
 TestCase PositiveBufferSizeTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*buffer_size*/ 5,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 10,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*buffer_size*/ 5,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 10,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 TestCase ZeroBufferSizeTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*buffer_size*/ 0,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 10,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*buffer_size*/ 0,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 10,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 TestCase AutoTuneTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*buffer_size*/ -1,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 10,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*buffer_size*/ -1,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 10,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 TestCase InvalidBufferSizeTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*buffer_size*/ -2,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 4, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*buffer_size*/ -2,
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 4, 11}};
 }
 
 class ParameterizedPrefetchDatasetOpTest
@@ -397,43 +398,6 @@
   EXPECT_EQ(prefetch_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_F(PrefetchDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = PositiveBufferSizeTestCase();
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor buffer_size =
-      CreateTensor<int64>(TensorShape{}, {test_case.buffer_size});
-  gtl::InlinedVector<TensorValue, 4> inputs_for_prefetch_dataset(
-      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&buffer_size)});
-
-  std::unique_ptr<OpKernel> prefetch_dataset_kernel;
-  TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes,
-                                           test_case.expected_output_shapes,
-                                           &prefetch_dataset_kernel));
-  std::unique_ptr<OpKernelContext> prefetch_dataset_context;
-  TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(),
-                                            &inputs_for_prefetch_dataset,
-                                            &prefetch_dataset_context));
-  DatasetBase *prefetch_dataset;
-  TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(),
-                             prefetch_dataset_context.get(),
-                             &prefetch_dataset));
-  core::ScopedUnref scoped_unref(prefetch_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(prefetch_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_F(PrefetchDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/random_seed_ops.cc b/tensorflow/core/kernels/data/random_seed_ops.cc
new file mode 100644
index 0000000..f9fc975
--- /dev/null
+++ b/tensorflow/core/kernels/data/random_seed_ops.cc
@@ -0,0 +1,126 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/random_seed_ops.h"
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const char kNumRandomSamples[] = "num_random_samples";
+const char kRandomSeedGenerator[] = "RandomSeedGenerator";
+const char kSeed[] = "seed";
+const char kSeed2[] = "seed2";
+
+}  // namespace
+
+string RandomSeedGenerator::DebugString() const { return kRandomSeedGenerator; }
+
+void RandomSeedGenerator::GenerateRandomSeeds(int64* seed1, int64* seed2) {
+  mutex_lock l(mu_);
+  num_random_samples_++;
+  *seed1 = generator_();
+  num_random_samples_++;
+  *seed2 = generator_();
+}
+
+int64 RandomSeedGenerator::num_random_samples() {
+  tf_shared_lock l(mu_);
+  return num_random_samples_;
+}
+
+void RandomSeedGenerator::set_num_random_samples(int64 num_random_samples) {
+  mutex_lock l(mu_);
+  num_random_samples_ = num_random_samples;
+}
+
+void RandomSeedGenerator::Reset() {
+  mutex_lock l(mu_);
+  // Reset the generators based on the current seeds.
+  parent_generator_ = random::PhiloxRandom(seed_, seed2_);
+  generator_ =
+      random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
+  generator_.Skip(num_random_samples_);
+}
+
+void RandomSeedGenerator::Serialize(OpKernelContext* ctx) {
+  mutex_lock l(mu_);
+  Tensor* num_random_samples;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(kNumRandomSamples, TensorShape({}),
+                                           &num_random_samples));
+  num_random_samples->scalar<int64>()() = num_random_samples_;
+  Tensor* seed;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed, TensorShape({}), &seed));
+  seed->scalar<int64>()() = seed_;
+  Tensor* seed2;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed2, TensorShape({}), &seed2));
+  seed2->scalar<int64>()() = seed2_;
+}
+
+AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp(
+    OpKernelConstruction* ctx)
+    : AnonymousResourceOp<RandomSeedGenerator>(ctx) {}
+
+void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
+  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed_));
+  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2_));
+  AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
+}
+
+string AnonymousRandomSeedGeneratorHandleOp::name() {
+  return kRandomSeedGenerator;
+}
+
+Status AnonymousRandomSeedGeneratorHandleOp::CreateResource(
+    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+    FunctionLibraryRuntime* lib, RandomSeedGenerator** resource) {
+  *resource = new RandomSeedGenerator(seed_, seed2_);
+  return Status::OK();
+}
+
+void DeleteRandomSeedGeneratorOp::Compute(OpKernelContext* ctx) {
+  ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
+  // The resource is guaranteed to exist because the variant tensor wrapping the
+  // deleter is provided as an unused input to this op, which guarantees that it
+  // has not run yet.
+  Status s = ctx->resource_manager()->Delete(handle);
+  if (errors::IsNotFound(s)) {
+    // TODO(b/135948230): Investigate why is the above statement not true and
+    // then get rid of the special case.
+    ctx->SetStatus(Status::OK());
+    return;
+  }
+  ctx->SetStatus(s);
+}
+
+namespace {
+
+REGISTER_KERNEL_BUILDER(Name("AnonymousRandomSeedGenerator").Device(DEVICE_CPU),
+                        AnonymousRandomSeedGeneratorHandleOp);
+
+REGISTER_KERNEL_BUILDER(Name("DeleteRandomSeedGenerator").Device(DEVICE_CPU),
+                        DeleteRandomSeedGeneratorOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/random_seed_ops.h b/tensorflow/core/kernels/data/random_seed_ops.h
new file mode 100644
index 0000000..750e6fd
--- /dev/null
+++ b/tensorflow/core/kernels/data/random_seed_ops.h
@@ -0,0 +1,86 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace data {
+
+// A random seed generator resource.
+class RandomSeedGenerator : public ResourceBase {
+ public:
+  RandomSeedGenerator(int64 seed, int64 seed2)
+      : seed_(seed),
+        seed2_(seed2),
+        parent_generator_(seed, seed2),
+        generator_(&parent_generator_) {}
+
+  int64 num_random_samples();
+  void set_num_random_samples(int64 num_random_samples);
+
+  string DebugString() const override;
+  void GenerateRandomSeeds(int64* seed1, int64* seed2);
+  void Reset();
+  void Serialize(OpKernelContext* ctx);
+
+ private:
+  const int64 seed_;
+  const int64 seed2_;
+  mutex mu_;
+  random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
+  random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
+  int64 num_random_samples_ GUARDED_BY(mu_) = 0;
+};
+
+// Creates an instance of random seed generator resource and transfers ownership
+// to the caller.
+class AnonymousRandomSeedGeneratorHandleOp
+    : public AnonymousResourceOp<RandomSeedGenerator> {
+ public:
+  explicit AnonymousRandomSeedGeneratorHandleOp(OpKernelConstruction* ctx);
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  string name() override;
+  Status CreateResource(OpKernelContext* ctx,
+                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
+                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+                        FunctionLibraryRuntime* lib,
+                        RandomSeedGenerator** resource) override;
+
+  int64 seed_;
+  int64 seed2_;
+};
+
+// Deletes an instance of random seed generator resource.
+class DeleteRandomSeedGeneratorOp : public OpKernel {
+ public:
+  explicit DeleteRandomSeedGeneratorOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index 8b37b2c..8e87085 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -73,6 +73,8 @@
     }
   }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc
index 3165ad5..dfa6959 100644
--- a/tensorflow/core/kernels/data/range_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc
@@ -12,7 +12,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "tensorflow/core/kernels/data/range_dataset_op.h"
 
 #include "tensorflow/core/kernels/data/dataset_test_base.h"
@@ -22,6 +21,7 @@
 namespace {
 
 constexpr char kNodeName[] = "range_dataset";
+constexpr char kIteratorPrefix[] = "Iterator";
 
 class RangeDatasetOpTest : public DatasetOpsTestBase {
  protected:
@@ -37,77 +37,83 @@
   }
 };
 
-struct TestCase {
-  int64 start;
-  int64 stop;
-  int64 step;
-  std::vector<Tensor> expected_outputs;
-  DataTypeVector expected_output_dtypes;
-  std::vector<PartialTensorShape> expected_output_shapes;
-  int64 expected_cardinality;
-  std::vector<int> breakpoints;
+class RangeDatasetParams : public DatasetParams {
+ public:
+  RangeDatasetParams(int64 start, int64 stop, int64 step,
+                     DataTypeVector output_dtypes,
+                     std::vector<PartialTensorShape> output_shapes,
+                     string node_name)
+      : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+                      std::move(node_name)),
+        start(CreateTensor<int64>(TensorShape({}), {start})),
+        stop(CreateTensor<int64>(TensorShape({}), {stop})),
+        step(CreateTensor<int64>(TensorShape({}), {step})) {}
+
+  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
+    *inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
+    return Status::OK();
+  }
+
+  Tensor start;
+  Tensor stop;
+  Tensor step;
 };
 
-TestCase PositiveStepTestCase() {
-  return {/*start*/ 0,
-          /*stop*/ 10,
-          /*step*/ 3,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ 4,
-          /*breakpoints*/ {0, 1, 4}};
+RangeDatasetParams PositiveStepRangeDataset() {
+  return {/*start=*/0,
+          /*stop=*/10,
+          /*step=*/3,
+          /*output_dtypes=*/{DT_INT64},
+          /*output_shapes=*/{PartialTensorShape({})},
+          /*node_name=*/kNodeName};
 }
 
-TestCase NegativeStepTestCase() {
-  return {/*start*/ 10,
-          /*stop*/ 0,
-          /*step*/ -3,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ 4,
-          /*breakpoints*/ {0, 1, 4}};
+RangeDatasetParams NegativeStepRangeDataset() {
+  return {/*start=*/10,
+          /*stop=*/0,
+          /*step=*/-3,
+          /*output_dtypes=*/{DT_INT64},
+          /*output_shapes=*/{PartialTensorShape({})},
+          /*node_name=*/kNodeName};
 }
 
-TestCase ZeroStepTestCase() {
-  return {/*start*/ 0,
-          /*stop*/ 10,
-          /*step*/ 0,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {},
-          /*expected_output_shapes*/ {},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {}};
+RangeDatasetParams ZeroStepRangeDataset() {
+  return {/*start=*/10,
+          /*stop=*/0,
+          /*step=*/0,
+          /*output_dtypes=*/{DT_INT64},
+          /*output_shapes=*/{PartialTensorShape({})},
+          /*node_name=*/kNodeName};
 }
 
-class ParameterizedRangeDatasetOpTest
+class ParameterizedGetNextRangeDatasetOpTest
     : public RangeDatasetOpTest,
-      public ::testing::WithParamInterface<TestCase> {};
+      public ::testing::WithParamInterface<
+          GetNextTestCase<RangeDatasetParams>> {};
 
-TEST_P(ParameterizedRangeDatasetOpTest, GetNext) {
+GetNextTestCase<RangeDatasetParams> GetNextTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})};
+}
+
+GetNextTestCase<RangeDatasetParams> GetNextTestCase2() {
+  return {/*dataset_params=*/NegativeStepRangeDataset(),
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})};
+}
+
+TEST_P(ParameterizedGetNextRangeDatasetOpTest, GetNext) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  TestCase test_case = GetParam();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
+  GetNextTestCase<RangeDatasetParams> test_case = GetParam();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
   std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
   std::unique_ptr<OpKernelContext> range_dataset_context;
   TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
                                          &range_dataset_context));
@@ -120,22 +126,375 @@
   TF_ASSERT_OK(
       CreateIteratorContext(range_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
-                                           &iterator));
+  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
+                                           kIteratorPrefix, &iterator));
 
-  bool end_of_sequence = false;
-  auto expected_outputs_it = test_case.expected_outputs.begin();
-  std::vector<Tensor> out_tensors;
-  while (!end_of_sequence) {
-    TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
-                                   &end_of_sequence));
-    if (!end_of_sequence) {
-      EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
-      TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
-      expected_outputs_it++;
-    }
-  }
-  EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
+  TF_ASSERT_OK(CheckIteratorGetNext(iterator.get(), iterator_context.get(),
+                                    test_case.expected_outputs,
+                                    /*compare_order=*/true));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RangeDatasetOpTest, ParameterizedGetNextRangeDatasetOpTest,
+    ::testing::ValuesIn(std::vector<GetNextTestCase<RangeDatasetParams>>(
+        {GetNextTestCase1(), GetNextTestCase2()})));
+
+DatasetNodeNameTestCase<RangeDatasetParams> DatasetNodeNameTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_node_name=*/kNodeName};
+}
+
+TEST_F(RangeDatasetOpTest, DatasetNodeName) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = DatasetNodeNameTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(
+      CheckDatasetNodeName(*range_dataset, test_case.expected_node_name));
+}
+
+DatasetTypeStringTestCase<RangeDatasetParams> DatasetTypeStringTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_dataset_type_string=*/
+          name_utils::OpName(RangeDatasetOp::kDatasetType)};
+}
+
+TEST_F(RangeDatasetOpTest, DatasetTypeString) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = DatasetTypeStringTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(CheckDatasetTypeString(*range_dataset,
+                                      test_case.expected_dataset_type_string));
+}
+
+DatasetOutputDtypesTestCase<RangeDatasetParams> DatasetOutputDtypesTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = DatasetOutputDtypesTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(CheckDatasetOutputDtypes(*range_dataset,
+                                        test_case.expected_output_dtypes));
+}
+
+DatasetOutputShapesTestCase<RangeDatasetParams> DatasetOutputShapesTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = DatasetOutputShapesTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(CheckDatasetOutputShapes(*range_dataset,
+                                        test_case.expected_output_shapes));
+}
+
+class ParameterizedCardinalityRangeDatasetOpTest
+    : public RangeDatasetOpTest,
+      public ::testing::WithParamInterface<
+          CardinalityTestCase<RangeDatasetParams>> {};
+
+CardinalityTestCase<RangeDatasetParams> CardinalityTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_cardinality=*/4};
+}
+
+CardinalityTestCase<RangeDatasetParams> CardinalityTestCase2() {
+  return {/*dataset_params=*/NegativeStepRangeDataset(),
+          /*expected_cardinality=*/4};
+}
+
+TEST_P(ParameterizedCardinalityRangeDatasetOpTest, Cardinality) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = GetParam();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(
+      CheckDatasetCardinality(*range_dataset, test_case.expected_cardinality));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RangeDatasetOpTest, ParameterizedCardinalityRangeDatasetOpTest,
+    ::testing::ValuesIn(std::vector<CardinalityTestCase<RangeDatasetParams>>(
+        {CardinalityTestCase1(), CardinalityTestCase2()})));
+
+DatasetSaveTestCase<RangeDatasetParams> DatasetSaveTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset()};
+}
+
+IsStatefulTestCase<RangeDatasetParams> IsStatefulTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_stateful=*/false};
+}
+
+TEST_F(RangeDatasetOpTest, IsStateful) {
+  int64 thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = IsStatefulTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  TF_ASSERT_OK(
+      CheckDatasetIsStateful(*range_dataset, test_case.expected_stateful));
+}
+
+IteratorOutputDtypesTestCase<RangeDatasetParams>
+IteratorOutputDtypesTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = IteratorOutputDtypesTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_context;
+  TF_ASSERT_OK(
+      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
+                                           kIteratorPrefix, &iterator));
+  TF_ASSERT_OK(
+      CheckIteratorOutputDtypes(*iterator, test_case.expected_output_dtypes));
+}
+
+IteratorOutputShapesTestCase<RangeDatasetParams>
+IteratorOutputShapesTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = IteratorOutputShapesTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_context;
+  TF_ASSERT_OK(
+      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
+                                           kIteratorPrefix, &iterator));
+  TF_ASSERT_OK(
+      CheckIteratorOutputShapes(*iterator, test_case.expected_output_shapes));
+}
+
+IteratorOutputPrefixTestCase<RangeDatasetParams>
+IteratorOutputPrefixTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*expected_iterator_prefix=*/name_utils::IteratorPrefix(
+              RangeDatasetOp::kDatasetType, kIteratorPrefix)};
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = IteratorOutputPrefixTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_context;
+  TF_ASSERT_OK(
+      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
+                                           kIteratorPrefix, &iterator));
+  TF_ASSERT_OK(
+      CheckIteratorPrefix(*iterator, test_case.expected_iterator_prefix));
+}
+
+class ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest
+    : public RangeDatasetOpTest,
+      public ::testing::WithParamInterface<
+          IteratorSaveAndRestoreTestCase<RangeDatasetParams>> {};
+
+IteratorSaveAndRestoreTestCase<RangeDatasetParams>
+IteratorSaveAndRestoreTestCase1() {
+  return {/*dataset_params=*/PositiveStepRangeDataset(),
+          /*breakpoints=*/{0, 1, 4},
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})};
+}
+
+IteratorSaveAndRestoreTestCase<RangeDatasetParams>
+IteratorSaveAndRestoreTestCase2() {
+  return {/*dataset_params=*/NegativeStepRangeDataset(),
+          /*breakpoints=*/{0, 1, 4},
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})};
+}
+
+TEST_P(ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest,
+       IteratorSaveAndRestore) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  auto test_case = GetParam();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+  std::unique_ptr<OpKernel> range_dataset_kernel;
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
+  std::unique_ptr<OpKernelContext> range_dataset_context;
+  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
+                                         &range_dataset_context));
+  DatasetBase* range_dataset;
+  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
+                             range_dataset_context.get(), &range_dataset));
+  core::ScopedUnref scoped_unref(range_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_context;
+  TF_ASSERT_OK(
+      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
+                                           kIteratorPrefix, &iterator));
+  TF_ASSERT_OK(CheckIteratorSaveAndRestore(
+      *range_dataset, iterator_context.get(), kIteratorPrefix,
+      test_case.expected_outputs, test_case.breakpoints));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RangeDatasetOpTest, ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest,
+    ::testing::ValuesIn(
+        std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>(
+            {IteratorSaveAndRestoreTestCase1(),
+             IteratorSaveAndRestoreTestCase2()})));
+
+GetNextTestCase<RangeDatasetParams> ZeroStepTestCase1() {
+  return {/*dataset_params=*/ZeroStepRangeDataset(),
+          /*expected_outputs=*/{}};
 }
 
 TEST_F(RangeDatasetOpTest, ZeroStep) {
@@ -143,16 +502,12 @@
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  TestCase test_case = ZeroStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
+  auto test_case = ZeroStepTestCase1();
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
   std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
+  TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
+      test_case.dataset_params.node_name, &range_dataset_kernel));
   std::unique_ptr<OpKernelContext> range_dataset_context;
   TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
                                          &range_dataset_context));
@@ -163,344 +518,6 @@
             tensorflow::error::INVALID_ARGUMENT);
 }
 
-TEST_F(RangeDatasetOpTest, DatasetNodeName) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  EXPECT_EQ(range_dataset->node_name(), kNodeName);
-}
-
-TEST_F(RangeDatasetOpTest, DatasetTypeString) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  EXPECT_EQ(range_dataset->type_string(),
-            name_utils::OpName(RangeDatasetOp::kDatasetType));
-}
-
-TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(),
-                                test_case.expected_output_dtypes));
-}
-
-TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(),
-                                      test_case.expected_output_shapes));
-}
-
-TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = GetParam();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality);
-}
-
-TEST_F(RangeDatasetOpTest, DatasetSave) {
-  int64 thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(range_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
-TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  std::unique_ptr<IteratorContext> iterator_context;
-  TF_ASSERT_OK(
-      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
-  std::unique_ptr<IteratorBase> iterator;
-  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
-                                           &iterator));
-
-  TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
-                                test_case.expected_output_dtypes));
-}
-
-TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  std::unique_ptr<IteratorContext> iterator_context;
-  TF_ASSERT_OK(
-      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
-  std::unique_ptr<IteratorBase> iterator;
-  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
-                                           &iterator));
-
-  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
-                                      test_case.expected_output_shapes));
-}
-
-TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = PositiveStepTestCase();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  std::unique_ptr<IteratorContext> iterator_context;
-  TF_ASSERT_OK(
-      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
-  std::unique_ptr<IteratorBase> iterator;
-  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
-                                           &iterator));
-
-  EXPECT_EQ(iterator->prefix(), name_utils::IteratorPrefix(
-                                    RangeDatasetOp::kDatasetType, "Iterator"));
-}
-
-TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TestCase test_case = GetParam();
-  Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
-  Tensor stop = CreateTensor<int64>(TensorShape({}), {test_case.stop});
-  Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&start), TensorValue(&stop), TensorValue(&step)});
-
-  std::unique_ptr<OpKernel> range_dataset_kernel;
-  TF_ASSERT_OK(
-      CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
-  std::unique_ptr<OpKernelContext> range_dataset_context;
-  TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
-                                         &range_dataset_context));
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
-                             range_dataset_context.get(), &range_dataset));
-  core::ScopedUnref scoped_unref(range_dataset);
-
-  std::unique_ptr<IteratorContext> iterator_context;
-  TF_ASSERT_OK(
-      CreateIteratorContext(range_dataset_context.get(), &iterator_context));
-  std::unique_ptr<IteratorBase> iterator;
-  TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
-                                           &iterator));
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  bool end_of_sequence = false;
-  std::vector<Tensor> out_tensors;
-  int cur_iteration = 0;
-  auto expected_outputs_it = test_case.expected_outputs.begin();
-  const std::vector<int>& breakpoints = test_case.breakpoints;
-  for (int breakpoint : breakpoints) {
-    VariantTensorData data;
-    VariantTensorDataWriter writer(&data);
-    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
-    TF_EXPECT_OK(writer.Flush());
-    VariantTensorDataReader reader(&data);
-    TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator",
-                                 *range_dataset, &iterator));
-
-    while (cur_iteration <= breakpoint) {
-      TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
-                                     &end_of_sequence));
-      if (!end_of_sequence) {
-        EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
-        TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
-        expected_outputs_it++;
-      }
-      cur_iteration++;
-    }
-
-    if (breakpoint >= test_case.expected_cardinality) {
-      EXPECT_TRUE(end_of_sequence);
-      EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
-    } else {
-      EXPECT_FALSE(end_of_sequence);
-    }
-  }
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    RangeDatasetOpTest, ParameterizedRangeDatasetOpTest,
-    ::testing::ValuesIn(std::vector<TestCase>({PositiveStepTestCase(),
-                                               NegativeStepTestCase()})));
-
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc
index 825168f..2f90e2d 100644
--- a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc
@@ -82,14 +82,14 @@
 TestCase TestCase1() {
   return {/*range_data_param*/ {0, 10, 1},
           /*initial_state*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0})},
+          {CreateTensor<int64>(TensorShape({}), {0})},
           /*func*/
           FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::XAddY()},
           /*t_state*/ {DT_INT64},
           /*use_inter_op_parallelism*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {45})},
+          {CreateTensor<int64>(TensorShape({}), {45})},
           /*output_dtypes*/ {DT_INT64},
           /*output_shapes*/ {PartialTensorShape({})}};
 }
@@ -103,17 +103,17 @@
 TestCase TestCase2() {
   return {/*range_data_param*/ {1, 10, 1},
           /*initial_state*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})},
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {1})},
           /*func*/
           FunctionDefHelper::FunctionRef("XPlusOneXTimesY", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::XPlusOneXTimesY()},
           /*t_state*/ {DT_INT64, DT_INT64},
           /*use_inter_op_parallelism*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-           DatasetOpsTestBase::CreateTensor<int64>(
-               TensorShape({}), {1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9})},
+          {CreateTensor<int64>(TensorShape({}), {10}),
+           CreateTensor<int64>(TensorShape({}),
+                               {1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9})},
           /*output_dtypes*/ {DT_INT64, DT_INT64},
           /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}};
 }
@@ -123,16 +123,16 @@
 TestCase TestCase3() {
   return {/*range_data_param*/ {0, 0, 1},
           /*initial_state*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3})},
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {3})},
           /*func*/
           FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}}),
           /*func_lib*/ {test::function::XAddY()},
           /*t_state*/ {DT_INT64, DT_INT64},
           /*use_inter_op_parallelism*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3})},
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({}), {3})},
           /*output_dtypes*/ {DT_INT64, DT_INT64},
           /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}};
 }
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 8b918e9..6ec0b01 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -89,6 +89,10 @@
     return count_ * n;
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc
index 6f39969..0b55a05 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc
@@ -71,51 +71,49 @@
 };
 
 TestCase FiniteRepeatTestCase() {
-  return {
-      /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
-      /*count*/ 2,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"a"}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {3, 4}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"b"}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"a"}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {3, 4}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"b"})},
-      /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
-      /*expected_output_shapes*/
-      {PartialTensorShape({2}), PartialTensorShape({1})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 1, 3}};
+  return {/*input_tensors*/
+          {CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
+           CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
+          /*count*/ 2,
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape{2}, {1, 2}),
+           CreateTensor<string>(TensorShape{1}, {"a"}),
+           CreateTensor<int64>(TensorShape{2}, {3, 4}),
+           CreateTensor<string>(TensorShape{1}, {"b"}),
+           CreateTensor<int64>(TensorShape{2}, {1, 2}),
+           CreateTensor<string>(TensorShape{1}, {"a"}),
+           CreateTensor<int64>(TensorShape{2}, {3, 4}),
+           CreateTensor<string>(TensorShape{1}, {"b"})},
+          /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
+          /*expected_output_shapes*/
+          {PartialTensorShape({2}), PartialTensorShape({1})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 1, 3}};
 }
 
 TestCase EmptyRepeatTestCase() {
-  return {
-      /*input_tensors*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
-      /*count*/ 0,
-      /*expected_outputs*/
-      {},
-      /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
-      /*expected_output_shapes*/
-      {PartialTensorShape({2}), PartialTensorShape({1})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 3}};
+  return {/*input_tensors*/
+          {CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
+           CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
+          /*count*/ 0,
+          /*expected_outputs*/
+          {},
+          /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
+          /*expected_output_shapes*/
+          {PartialTensorShape({2}), PartialTensorShape({1})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 3}};
 }
 
 TestCase ForeverRepeatTestCase() {
   return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {1, 2})},
+          {CreateTensor<int64>(TensorShape{2, 1}, {1, 2})},
           /*count*/ -1,
           /*expected_outputs*/
           // Use the first group of the repeated tensors to represent the
           // infinite outputs.
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2})},
+          {CreateTensor<int64>(TensorShape{1}, {1}),
+           CreateTensor<int64>(TensorShape{1}, {2})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({1})},
           /*expected_cardinality*/ -1,
@@ -351,41 +349,6 @@
   EXPECT_EQ(repeat_dataset->Cardinality(), GetParam().expected_cardinality);
 }
 
-TEST_F(RepeatDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  const TestCase &test_case = FiniteRepeatTestCase();
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
-  gtl::InlinedVector<TensorValue, 4> inputs_for_repeat_dataset;
-  inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor);
-  inputs_for_repeat_dataset.emplace_back(&count);
-
-  std::unique_ptr<OpKernel> repeat_dataset_kernel;
-  TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes,
-                                         test_case.expected_output_shapes,
-                                         &repeat_dataset_kernel));
-  std::unique_ptr<OpKernelContext> repeat_dataset_context;
-  TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(),
-                                          &inputs_for_repeat_dataset,
-                                          &repeat_dataset_context));
-  DatasetBase *repeat_dataset;
-  TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(),
-                             repeat_dataset_context.get(), &repeat_dataset));
-  core::ScopedUnref scoped_unref(repeat_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(repeat_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc
index d88654f..e79d343 100644
--- a/tensorflow/core/kernels/data/shard_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shard_dataset_op.cc
@@ -79,6 +79,10 @@
     return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/shard_dataset_op_test.cc b/tensorflow/core/kernels/data/shard_dataset_op_test.cc
index b51e296..b101327 100644
--- a/tensorflow/core/kernels/data/shard_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/shard_dataset_op_test.cc
@@ -70,13 +70,13 @@
 TestCase TestCase1() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+          CreateTensor<int64>(TensorShape({}), {2}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7})},
+          {CreateTensor<int64>(TensorShape({}), {2}),
+           CreateTensor<int64>(TensorShape({}), {7})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 2,
@@ -87,13 +87,13 @@
 TestCase TestCase2() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
+          CreateTensor<int64>(TensorShape({}), {0}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5})},
+          {CreateTensor<int64>(TensorShape({}), {0}),
+           CreateTensor<int64>(TensorShape({}), {5})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 2,
@@ -104,9 +104,9 @@
 TestCase TestCase3() {
   return {/*range_data_param*/ {0, 1, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
+          CreateTensor<int64>(TensorShape({}), {2}),
           /*require_non_empty*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -119,12 +119,12 @@
 TestCase TestCase4() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7}),
+          CreateTensor<int64>(TensorShape({}), {7}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5})},
+          {CreateTensor<int64>(TensorShape({}), {5})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 1,
@@ -135,13 +135,13 @@
 TestCase TestCase5() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+          CreateTensor<int64>(TensorShape({}), {4}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9})},
+          {CreateTensor<int64>(TensorShape({}), {4}),
+           CreateTensor<int64>(TensorShape({}), {9})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 2,
@@ -153,13 +153,13 @@
 TestCase TestCase6() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
+          CreateTensor<int64>(TensorShape({}), {4}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
+          CreateTensor<int64>(TensorShape({}), {3}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7})},
+          {CreateTensor<int64>(TensorShape({}), {3}),
+           CreateTensor<int64>(TensorShape({}), {7})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 2,
@@ -171,12 +171,12 @@
 TestCase TestCase7() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {20}),
+          CreateTensor<int64>(TensorShape({}), {20}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*require_non_empty*/ false,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5})},
+          {CreateTensor<int64>(TensorShape({}), {5})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 1,
@@ -187,12 +187,12 @@
 TestCase NoElemForEachShardTestCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {20}),
+          CreateTensor<int64>(TensorShape({}), {20}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*require_non_empty*/ true,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5})},
+          {CreateTensor<int64>(TensorShape({}), {5})},
           /*expected_output_dtypes*/ {DT_INT64},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ 1,
@@ -202,9 +202,9 @@
 TestCase IndexGreaterNumShardsCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7}),
+          CreateTensor<int64>(TensorShape({}), {7}),
           /*require_non_empty*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -216,9 +216,9 @@
 TestCase NegativeIndexTestCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
+          CreateTensor<int64>(TensorShape({}), {5}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-3}),
+          CreateTensor<int64>(TensorShape({}), {-3}),
           /*require_non_empty*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -230,9 +230,9 @@
 TestCase NegativeNumShardsTestCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-3}),
+          CreateTensor<int64>(TensorShape({}), {-3}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
+          CreateTensor<int64>(TensorShape({}), {1}),
           /*require_non_empty*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -244,9 +244,9 @@
 TestCase ZeroNumShardsTestCase() {
   return {/*range_data_param*/ {0, 10, 1},
           /*num_shards*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
+          CreateTensor<int64>(TensorShape({}), {0}),
           /*index*/
-          DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
+          CreateTensor<int64>(TensorShape({}), {1}),
           /*require_non_empty*/ true,
           /*expected_outputs*/ {},
           /*expected_output_dtypes*/ {DT_INT64},
@@ -497,47 +497,6 @@
   EXPECT_EQ(shard_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedShardDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  std::unique_ptr<OpKernel> shard_dataset_kernel;
-  TF_ASSERT_OK(CreateShardDatasetOpKernel(
-      test_case.require_non_empty, test_case.expected_output_dtypes,
-      test_case.expected_output_shapes, &shard_dataset_kernel));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.range_dataset_param.start, test_case.range_dataset_param.end,
-      test_case.range_dataset_param.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-
-  Tensor num_shards = test_case.num_shards;
-  Tensor index = test_case.index;
-  gtl::InlinedVector<TensorValue, 4> inputs({TensorValue(&range_dataset_tensor),
-                                             TensorValue(&num_shards),
-                                             TensorValue(&index)});
-  std::unique_ptr<OpKernelContext> shard_dataset_context;
-  TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs,
-                                         &shard_dataset_context));
-
-  DatasetBase* shard_dataset;
-  TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(),
-                             shard_dataset_context.get(), &shard_dataset));
-  core::ScopedUnref scoped_unref_batch_dataset(shard_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(shard_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedShardDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 0be76f3..d5d7f0a 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -22,6 +22,8 @@
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/data/name_utils.h"
+#include "tensorflow/core/kernels/data/random_seed_ops.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
@@ -64,6 +66,7 @@
 constexpr char kDSNumRandomSamples[] = "ds_num_random_samples";
 constexpr char kFixedSeedDatasetPrefix[] = "FixedSeed";
 constexpr char kReshufflingDatasetPrefix[] = "Reshuffling";
+constexpr char kShuffleDataset[] = "ShuffleDataset";
 
 ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
     : UnaryDatasetOpKernel(ctx) {}
@@ -100,6 +103,10 @@
     }
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   template <class T>
   class Iterator : public DatasetIterator<T> {
@@ -383,12 +390,6 @@
   const int64 count_;
 };
 
-ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
-    : ShuffleDatasetOpBase(ctx) {
-  OP_REQUIRES_OK(
-      ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
-}
-
 // A dataset that uses a pseudorandom sequence of seeds for the iterators
 // created from it. Used when `reshuffle_each_iteration` is true.
 class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
@@ -415,59 +416,9 @@
   }
 
  protected:
-  class RandomSeedGenerator : public ResourceBase {
-   public:
-    RandomSeedGenerator(int64 seed, int64 seed2)
-        : seed_(seed),
-          seed2_(seed2),
-          parent_generator_(seed, seed2),
-          generator_(&parent_generator_) {}
-
-    string DebugString() const override {
-      return strings::StrCat(kReshufflingDatasetPrefix, name_utils::kDelimiter,
-                             kRandomSeedGenerator);
-    }
-
-    void GenerateRandomSeeds(int64* seed1, int64* seed2) {
-      mutex_lock l(mu_);
-      num_random_samples_++;
-      *seed1 = generator_();
-      num_random_samples_++;
-      *seed2 = generator_();
-    }
-
-    int64 num_random_samples() {
-      tf_shared_lock l(mu_);
-      return num_random_samples_;
-    }
-
-    void set_num_random_samples(int64 num_random_samples) {
-      mutex_lock l(mu_);
-      num_random_samples_ = num_random_samples;
-    }
-
-    void Reset() {
-      mutex_lock l(mu_);
-      // Reset the generators based on the current seeds.
-      parent_generator_ = random::PhiloxRandom(seed_, seed2_);
-      generator_ =
-          random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
-      generator_.Skip(num_random_samples_);
-    }
-
-   private:
-    const int64 seed_;
-    const int64 seed2_;
-    mutex mu_;
-    random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
-    random::SingleSampleAdapter<random::PhiloxRandom> generator_
-        GUARDED_BY(mu_);
-    int64 num_random_samples_ GUARDED_BY(mu_) = 0;
-  };
-
   class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> {
    public:
-    explicit Iterator(const Params& params, int64 seed, int64 seed2)
+    Iterator(const Params& params, int64 seed, int64 seed2)
         : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed,
                                                            seed2) {}
 
@@ -500,13 +451,10 @@
                 new RandomSeedGenerator(dataset_seed, dataset_seed2);
             return Status::OK();
           }));
-      // Now use the seed generator to update the base class Iterator seeds
-      // and random number generator with generated seeds for the current
-      // repetition.
-      mutex_lock l(mu_);
-      seed_generator->GenerateRandomSeeds(&seed_, &seed2_);
-      ResetRngs();
       seed_generator_ = seed_generator;
+      seed_generator_->GenerateRandomSeeds(&seed_, &seed2_);
+      mutex_lock l(mu_);
+      ResetRngs();
       return Status::OK();
     }
 
@@ -573,6 +521,111 @@
   const int64 seed2_;
 };
 
+// A dataset that uses a pseudorandom sequence of seeds for the iterators
+// created from it. Used in TF 2.0 when `reshuffle_each_iteration` is true.
+class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase {
+ public:
+  ReshufflingDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
+                       int64 buffer_size, int64 count,
+                       const Tensor& resource_handle,
+                       RandomSeedGenerator* seed_generator)
+      : ShuffleDatasetBase(ctx, input, buffer_size, count),
+        resource_handle_(resource_handle),
+        seed_generator_(seed_generator) {}
+
+  ~ReshufflingDatasetV2() override { seed_generator_->Unref(); }
+
+  string DebugString() const override {
+    name_utils::DatasetDebugStringParams params;
+    params.dataset_prefix = kReshufflingDatasetPrefix;
+    params.set_args(buffer_size_);
+    return name_utils::DatasetDebugString(kDatasetType, params);
+  }
+
+  Status CheckExternalState() const override {
+    return errors::FailedPrecondition(
+        DebugString(), " depends on random seed generator resource.");
+  }
+
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const override {
+    return absl::make_unique<Iterator>(
+        Iterator::Params{this,
+                         name_utils::IteratorPrefix(kDatasetType, prefix)},
+        seed_generator_);
+  }
+
+ protected:
+  class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDatasetV2> {
+   public:
+    Iterator(const Params& params, RandomSeedGenerator* seed_generator)
+        : ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>(params, 0, 0),
+          seed_generator_(seed_generator) {}
+
+    Status Initialize(IteratorContext* ctx) override {
+      mutex_lock l(mu_);
+      seed_generator_->GenerateRandomSeeds(&seed_, &seed2_);
+      ResetRngs();
+      return Status::OK();
+    }
+
+   protected:
+    std::shared_ptr<model::Node> CreateNode(
+        IteratorContext* ctx, model::Node::Args args) const override {
+      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
+    }
+
+    Status SaveInternal(IteratorStateWriter* writer) override {
+      // Save state of the seed generator.
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kDSNumRandomSamples),
+                              seed_generator_->num_random_samples()));
+
+      // Save the tterator state.
+      return ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>::SaveInternal(
+          writer);
+    }
+
+    Status RestoreInternal(IteratorContext* ctx,
+                           IteratorStateReader* reader) override {
+      // Restore state of the seed generator.
+      int64 num_random_samples;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples),
+                                            &num_random_samples));
+      seed_generator_->set_num_random_samples(num_random_samples);
+      seed_generator_->Reset();
+
+      // Restore the iterator state.
+      return ShuffleDatasetBase::Iterator<
+          ReshufflingDatasetV2>::RestoreInternal(ctx, reader);
+    }
+
+   private:
+    RandomSeedGenerator* seed_generator_;
+  };
+
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_graph_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+    Node* buffer_size_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
+    Node* resource_handle_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+    TF_RETURN_IF_ERROR(b->AddDataset(
+        this,
+        {input_graph_node, buffer_size_node, resource_handle_node},  // Inputs
+        {},                                                          // Attrs
+        output));
+    return Status::OK();
+  }
+
+ private:
+  const Tensor resource_handle_;
+  RandomSeedGenerator* seed_generator_ = nullptr;
+};
+
 // A dataset that uses the same fixed seed for all iterators created from it.
 // Used when `reshuffle_each_iteration` is false.
 class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase {
@@ -626,6 +679,15 @@
   const int64 seed2_;
 };
 
+ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
+    : ShuffleDatasetOpBase(ctx),
+      op_version_(ctx->def().op() == kShuffleDataset ? 1 : 2) {
+  if (ctx->HasAttr(kReshuffleEachIteration)) {
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
+  }
+}
+
 void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
   int64 buffer_size = 0;
@@ -635,6 +697,18 @@
       ctx, buffer_size > 0,
       errors::InvalidArgument("buffer_size must be greater than zero."));
 
+  int64 count = 1;
+  if (op_version_ == 2) {
+    RandomSeedGenerator* seed_generator = nullptr;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator));
+    // Transferring ownership of seed generator reference onto
+    // `ReshufflingDatasetV2`.
+    *output = new ReshufflingDatasetV2(ctx, input, buffer_size, count,
+                                       ctx->input(2), seed_generator);
+    return;
+  }
+
   int64 seed;
   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
 
@@ -648,7 +722,6 @@
     seed2 = random::New64();
   }
 
-  int64 count = 1;
   if (reshuffle_each_iteration_) {
     *output =
         new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
@@ -746,6 +819,9 @@
 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                         ShuffleDatasetOp);
 
+REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
+                        ShuffleDatasetOp);
+
 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                         ShuffleAndRepeatDatasetOp);
 }  // namespace
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.h b/tensorflow/core/kernels/data/shuffle_dataset_op.h
index 280221b..33b33f8 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.h
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.h
@@ -49,7 +49,9 @@
 
  private:
   class ReshufflingDataset;
+  class ReshufflingDatasetV2;
   class FixedSeedDataset;
+  int op_version_;
   bool reshuffle_each_iteration_;
 };
 
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc
index b03f7d7..a017c53 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc
@@ -89,231 +89,219 @@
   std::vector<Tensor> tensors;
   tensors.reserve(values.size());
   for (auto& value : values) {
-    tensors.emplace_back(
-        DatasetOpsTestBase::CreateTensor<T>(TensorShape({}), {value}));
+    tensors.emplace_back(CreateTensor<T>(TensorShape({}), {value}));
   }
   return tensors;
 }
 
 // Test case 1: test shuffle_dataset with reshuffle_each_iteration = false.
 TestCase TestCase1() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ false,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 10,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {3}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ false,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 2: test shuffle_dataset with reshuffle_each_iteration = true.
 TestCase TestCase2() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({2, 6, 1, 3, 9, 5, 0, 8, 7, 4}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({1, 6, 0, 5, 2, 7, 4, 3, 9, 8}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 10,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({2, 6, 1, 3, 9, 5, 0, 8, 7, 4}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({1, 6, 0, 5, 2, 7, 4, 3, 9, 8}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 3: similar with the test case 2 but a smaller buffer size than
 // the input dataset.
 TestCase TestCase3() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({0, 2, 1, 3, 5, 6, 4, 7, 8, 9}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({1, 0, 2, 3, 4, 5, 6, 7, 9, 8}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 10,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {2}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({0, 2, 1, 3, 5, 6, 4, 7, 8, 9}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({1, 0, 2, 3, 4, 5, 6, 7, 9, 8}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 4: similar with the test case 2 but has different seeds.
 TestCase TestCase4() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({3, 0, 8, 1, 5, 4, 7, 2, 6, 9}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({4, 6, 9, 0, 1, 8, 2, 7, 3, 5}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 10,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({3, 0, 8, 1, 5, 4, 7, 2, 6, 9}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({4, 6, 9, 0, 1, 8, 2, 7, 3, 5}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 5: test shuffle_dataset with buffer_size = 1 &
 // reshuffle_each_iteration = true.
 TestCase TestCase5() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 10,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 6: test shuffle_dataset with an empty input dataset.
 TestCase TestCase6() {
-  return {
-      /*range_data_param*/ {0, 0, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 0, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 7: test shuffle_and_repeat_dataset with buffer_size = 10 &
 // count = 2.
 TestCase TestCase7() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*reshuffle_each_iteration*/ false,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>(
-          {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>(
-          {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 20,
-      /*breakpoints*/ {0, 5, 22}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*reshuffle_each_iteration*/ false,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>(
+              {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>(
+              {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 20,
+          /*breakpoints*/ {0, 5, 22}};
 }
 
 // Test case 8: test shuffle_and_repeat_dataset with buffer_size = 10 &
 // count = -1
 TestCase TestCase8() {
-  return {
-      /*range_data_param*/ {0, 3, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
-      /*reshuffle_each_iteration*/ false,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>(
-          {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>(
-          {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ kInfiniteCardinality,
-      /*breakpoints*/ {0, 5, 20}};
+  return {/*range_data_param*/ {0, 3, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {-1}),
+          /*reshuffle_each_iteration*/ false,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>(
+              {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>(
+              {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ kInfiniteCardinality,
+          /*breakpoints*/ {0, 5, 20}};
 }
 
 TestCase InvalidBufferSizeTestCaseForShuffleDataset() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
-      /*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {-1}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
+          /*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 TestCase InvalidBufferSizeTestCaseForShuffleAndRepeatDataset() {
-  return {
-      /*range_data_param*/ {0, 10, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*reshuffle_each_iteration*/ true,
-      /*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
-      /*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 10, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {-1}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*reshuffle_each_iteration*/ true,
+          /*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
+          /*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 TestCase InvalidCountTestCaseForShuffleAndRepeatDataset() {
-  return {
-      /*range_data_param*/ {0, 3, 1},
-      /*buffer_size*/
-      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
-      /*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-      /*reshuffle_each_iteration*/ false,
-      /*expected_shuffle_outputs*/
-      ConvertToTensorVec<int64>({}),
-      /*expected_reshuffle_outputs*/
-      ConvertToTensorVec<int64>({}),
-      /*expected_output_dtypes*/ {DT_INT64},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 5, 20}};
+  return {/*range_data_param*/ {0, 3, 1},
+          /*buffer_size*/
+          CreateTensor<int64>(TensorShape({}), {10}),
+          /*seed*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*seed2*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*count*/ CreateTensor<int64>(TensorShape({}), {0}),
+          /*reshuffle_each_iteration*/ false,
+          /*expected_shuffle_outputs*/
+          ConvertToTensorVec<int64>({}),
+          /*expected_reshuffle_outputs*/
+          ConvertToTensorVec<int64>({}),
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 5, 20}};
 }
 
 class ParameterizedShuffleDatasetOpTest
@@ -618,51 +606,6 @@
   EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedShuffleDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  Tensor count = test_case.count;
-  int64 count_value = count.flat<int64>()(0);
-  std::unique_ptr<OpKernel> dataset_kernel;
-  TF_ASSERT_OK(
-      CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
-                            test_case.expected_output_dtypes,
-                            test_case.expected_output_shapes, &dataset_kernel));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.range_data_param.start, test_case.range_data_param.end,
-      test_case.range_data_param.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-  Tensor buffer_size = test_case.buffer_size;
-  Tensor seed = test_case.seed;
-  Tensor seed2 = test_case.seed2;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_dataset_tensor), TensorValue(&buffer_size),
-       TensorValue(&seed), TensorValue(&seed2)});
-  if (count_value != 1) inputs.push_back(TensorValue(&count));
-
-  std::unique_ptr<OpKernelContext> dataset_context;
-  TF_ASSERT_OK(
-      CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
-  DatasetBase* dataset;
-  TF_ASSERT_OK(
-      CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
-  core::ScopedUnref scoped_unref_dataset(dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index 4d378b2..5858c07 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -75,6 +75,10 @@
     return count_ < 0 ? 0 : std::max(0LL, n - count_);
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/skip_dataset_op_test.cc b/tensorflow/core/kernels/data/skip_dataset_op_test.cc
index bc95bf7..8079b60 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op_test.cc
@@ -69,83 +69,83 @@
 
 // Test case 1: skip fewer than input size.
 TestCase SkipLessTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 4,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 6,
-          /*breakpoints*/ {0, 2, 7}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 4,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 6,
+      /*breakpoints*/ {0, 2, 7}};
 }
 
 // Test case 2: skip more than input size.
 TestCase SkipMoreTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 25,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 2, 5}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 25,
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 3: skip exactly the input size.
 TestCase SkipAllTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 10,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 2, 5}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 10,
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 4: skip nothing.
 TestCase SkipNothingTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 0,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 10,
-          /*breakpoints*/ {0, 2, 5, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 0,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 10,
+      /*breakpoints*/ {0, 2, 5, 11}};
 }
 
 // Test case 5: set -1 for `count` to skip the entire dataset.
 TestCase SkipEntireDatasetTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ -1,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 2, 5}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ -1,
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 2, 5}};
 }
 
 class ParameterizedSkipDatasetOpTest
@@ -356,41 +356,6 @@
   EXPECT_EQ(skip_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_F(SkipDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = SkipLessTestCase();
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
-  gtl::InlinedVector<TensorValue, 4> inputs_for_skip_dataset(
-      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&count)});
-
-  std::unique_ptr<OpKernel> skip_dataset_kernel;
-  TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes,
-                                       test_case.expected_output_shapes,
-                                       &skip_dataset_kernel));
-  std::unique_ptr<OpKernelContext> skip_dataset_context;
-  TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(),
-                                        &inputs_for_skip_dataset,
-                                        &skip_dataset_context));
-  DatasetBase *skip_dataset;
-  TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(),
-                             skip_dataset_context.get(), &skip_dataset));
-  core::ScopedUnref scoped_unref(skip_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(skip_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index d8d7cd2..ffc74fc 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -56,6 +56,8 @@
 
   int64 Cardinality() const override { return sparse_tensor_.shape()[0]; }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc
index c8586d9..e38d167 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc
@@ -57,87 +57,78 @@
 };
 
 TestCase TwoDimsTestCase() {
-  return {
-      /*input_sparse_tensor*/
-      {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({2, 2},
-                                                           {0, 0, 1, 1}),
-       /*values*/ DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999}),
-       /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})},
-      /*expected_outputs*/
-      {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {0}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {888}),
-        /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({1}, {2})},
-       {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {1}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {999}),
-        /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({1}, {2})}},
-      /*breakpoints*/ {0, 1, 2}};
+  return {/*input_sparse_tensor*/
+          {/*indices*/ CreateTensor<int64>({2, 2}, {0, 0, 1, 1}),
+           /*values*/ CreateTensor<int32>({2}, {888, 999}),
+           /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
+          /*expected_outputs*/
+          {{/*indices*/ CreateTensor<int64>({1, 1}, {0}),
+            /*values*/ CreateTensor<int32>({1}, {888}),
+            /*dense_shape*/ CreateTensor<int64>({1}, {2})},
+           {/*indices*/ CreateTensor<int64>({1, 1}, {1}),
+            /*values*/ CreateTensor<int32>({1}, {999}),
+            /*dense_shape*/ CreateTensor<int64>({1}, {2})}},
+          /*breakpoints*/ {0, 1, 2}};
 }
 
 TestCase ThreeDimsTestCase() {
-  return {
-      /*input_sparse_tensor*/
-      {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({2, 3},
-                                                           {0, 0, 0, 1, 1, 1}),
-       /*values*/ DatasetOpsTestBase::CreateTensor<double>({2}, {888.0, 999.0}),
-       /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
-      /*expected_outputs*/
-      {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {0, 0}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<double>({1}, {888.0}),
-        /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})},
-       {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {1, 1})},
-        {/*values*/ DatasetOpsTestBase::CreateTensor<double>({1}, {999.0})},
-        {/*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2},
-                                                                 {2, 2})}}},
-      /*breakpoints*/ {0, 1, 2}};
+  return {/*input_sparse_tensor*/
+          {/*indices*/ CreateTensor<int64>({2, 3}, {0, 0, 0, 1, 1, 1}),
+           /*values*/ CreateTensor<double>({2}, {888.0, 999.0}),
+           /*dense_shape*/ CreateTensor<int64>({3}, {2, 2, 2})},
+          /*expected_outputs*/
+          {{/*indices*/ CreateTensor<int64>({1, 2}, {0, 0}),
+            /*values*/ CreateTensor<double>({1}, {888.0}),
+            /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
+           {{/*indices*/ CreateTensor<int64>({1, 2}, {1, 1})},
+            {/*values*/ CreateTensor<double>({1}, {999.0})},
+            {/*dense_shape*/ CreateTensor<int64>({2}, {2, 2})}}},
+          /*breakpoints*/ {0, 1, 2}};
 }
 
 TestCase FourDimsTestCase() {
-  return {
-      /*input_sparse_tensor*/
-      {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>(
-           {2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}),
-       /*values*/ DatasetOpsTestBase::CreateTensor<string>({2}, {"a", "b"}),
-       /*dense_shape*/
-       DatasetOpsTestBase::CreateTensor<int64>({4}, {3, 2, 2, 2})},
-      /*expected_outputs*/
-      {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {0, 0, 0}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<string>({1}, {"a"}),
-        /*dense_shape*/
-        DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
-       {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {1, 1, 1}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<string>({1}, {"b"}),
-        /*dense_shape*/
-        DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
-       {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({0, 3}, {}),
-        /*values*/ DatasetOpsTestBase::CreateTensor<string>({0}, {}),
-        /*dense_shape*/
-        DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}},
-      /*breakpoints*/ {0, 1, 3}};
+  return {/*input_sparse_tensor*/
+          {/*indices*/ CreateTensor<int64>({2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}),
+           /*values*/ CreateTensor<string>({2}, {"a", "b"}),
+           /*dense_shape*/
+           CreateTensor<int64>({4}, {3, 2, 2, 2})},
+          /*expected_outputs*/
+          {{/*indices*/ CreateTensor<int64>({1, 3}, {0, 0, 0}),
+            /*values*/ CreateTensor<string>({1}, {"a"}),
+            /*dense_shape*/
+            CreateTensor<int64>({3}, {2, 2, 2})},
+           {/*indices*/ CreateTensor<int64>({1, 3}, {1, 1, 1}),
+            /*values*/ CreateTensor<string>({1}, {"b"}),
+            /*dense_shape*/
+            CreateTensor<int64>({3}, {2, 2, 2})},
+           {/*indices*/ CreateTensor<int64>({0, 3}, {}),
+            /*values*/ CreateTensor<string>({0}, {}),
+            /*dense_shape*/
+            CreateTensor<int64>({3}, {2, 2, 2})}},
+          /*breakpoints*/ {0, 1, 3}};
 }
 
 TestCase FiveDimsTestCase() {
-  return {/*input_sparse_tensor*/
-          {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>(
-               {2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}),
-           /*values*/ DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999}),
-           /*dense_shape*/
-           DatasetOpsTestBase::CreateTensor<int64>({5}, {3, 2, 2, 2, 2})},
-          /*expected_outputs*/
-          {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 4},
-                                                                {0, 0, 0, 0}),
-            /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {888}),
-            /*dense_shape*/
-            DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})},
-           {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 4},
-                                                                {1, 1, 1, 1}),
-            /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {999}),
-            /*dense_shape*/
-            DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})},
-           {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({0, 4}, {}),
-            /*values*/ DatasetOpsTestBase::CreateTensor<int32>({0}, {}),
-            /*dense_shape*/
-            DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})}},
-          /*breakpoints*/ {0, 1, 3}};
+  return {
+      /*input_sparse_tensor*/
+      {/*indices*/ CreateTensor<int64>({2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}),
+       /*values*/ CreateTensor<int32>({2}, {888, 999}),
+       /*dense_shape*/
+       CreateTensor<int64>({5}, {3, 2, 2, 2, 2})},
+      /*expected_outputs*/
+      {{/*indices*/ CreateTensor<int64>({1, 4}, {0, 0, 0, 0}),
+        /*values*/ CreateTensor<int32>({1}, {888}),
+        /*dense_shape*/
+        CreateTensor<int64>({4}, {2, 2, 2, 2})},
+       {/*indices*/ CreateTensor<int64>({1, 4}, {1, 1, 1, 1}),
+        /*values*/ CreateTensor<int32>({1}, {999}),
+        /*dense_shape*/
+        CreateTensor<int64>({4}, {2, 2, 2, 2})},
+       {/*indices*/ CreateTensor<int64>({0, 4}, {}),
+        /*values*/ CreateTensor<int32>({0}, {}),
+        /*dense_shape*/
+        CreateTensor<int64>({4}, {2, 2, 2, 2})}},
+      /*breakpoints*/ {0, 1, 3}};
 }
 
 class ParameterizedSparseTensorSliceDatasetOpTest
@@ -333,38 +324,6 @@
   EXPECT_EQ(dataset->Cardinality(), expected_outputs.size());
 }
 
-TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = TwoDimsTestCase();
-  SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
-  std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
-  DataType tvalues = input_sparse_tensor.values.dtype();
-  gtl::InlinedVector<TensorValue, 4> inputs = {
-      TensorValue(&input_sparse_tensor.indices),
-      TensorValue(&input_sparse_tensor.values),
-      TensorValue(&input_sparse_tensor.dense_shape)};
-
-  std::unique_ptr<OpKernel> dataset_kernel;
-  TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
-  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
-  TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
-      dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
-  DatasetBase *dataset;
-  TF_ASSERT_OK(
-      CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
-  core::ScopedUnref scoped_unref(dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index 9cee97d..8fc9cde 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -71,6 +71,10 @@
   return std::min(n, count_);
 }
 
+Status TakeDataset::CheckExternalState() const {
+  return input_->CheckExternalState();
+}
+
 class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
  public:
   explicit EmptyIterator(const Params& params)
diff --git a/tensorflow/core/kernels/data/take_dataset_op.h b/tensorflow/core/kernels/data/take_dataset_op.h
index 5d76f6d..03f8ff6 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.h
+++ b/tensorflow/core/kernels/data/take_dataset_op.h
@@ -32,13 +32,15 @@
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override;
 
-  const DataTypeVector& output_dtypes() const;
+  const DataTypeVector& output_dtypes() const override;
 
-  const std::vector<PartialTensorShape>& output_shapes() const;
+  const std::vector<PartialTensorShape>& output_shapes() const override;
 
-  string DebugString() const;
+  string DebugString() const override;
 
-  int64 Cardinality() const;
+  int64 Cardinality() const override;
+
+  Status CheckExternalState() const override;
 
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
diff --git a/tensorflow/core/kernels/data/take_dataset_op_test.cc b/tensorflow/core/kernels/data/take_dataset_op_test.cc
index b482a52..0a75c06 100644
--- a/tensorflow/core/kernels/data/take_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op_test.cc
@@ -69,78 +69,78 @@
 
 // Test case 1: take fewer than input size.
 TestCase TakeLessTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 4,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 4,
-          /*breakpoints*/ {0, 2, 5}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 4,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 4,
+      /*breakpoints*/ {0, 2, 5}};
 }
 
 // Test case 2: take more than input size.
 TestCase TakeMoreTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 25,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 10,
-          /*breakpoints*/ {0, 2, 5, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 25,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 10,
+      /*breakpoints*/ {0, 2, 5, 11}};
 }
 
 // Test case 3: take all of input.
 TestCase TakeAllTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ -1,
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ -1,
-          /*breakpoints*/ {0, 2, 5, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ -1,
+      /*expected_outputs*/
+      {CreateTensor<int64>(TensorShape{1}, {0}),
+       CreateTensor<int64>(TensorShape{1}, {1}),
+       CreateTensor<int64>(TensorShape{1}, {2}),
+       CreateTensor<int64>(TensorShape{1}, {3}),
+       CreateTensor<int64>(TensorShape{1}, {4}),
+       CreateTensor<int64>(TensorShape{1}, {5}),
+       CreateTensor<int64>(TensorShape{1}, {6}),
+       CreateTensor<int64>(TensorShape{1}, {7}),
+       CreateTensor<int64>(TensorShape{1}, {8}),
+       CreateTensor<int64>(TensorShape{1}, {9})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ -1,
+      /*breakpoints*/ {0, 2, 5, 11}};
 }
 
 // Test case 4: take nothing.
 TestCase TakeNothingTestCase() {
-  return {/*input_tensors*/
-          {DatasetOpsTestBase::CreateTensor<int64>(
-              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
-          /*count*/ 0,
-          /*expected_outputs*/ {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({1})},
-          /*expected_cardinality*/ 0,
-          /*breakpoints*/ {0, 2, 5, 11}};
+  return {
+      /*input_tensors*/
+      {CreateTensor<int64>(TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+      /*count*/ 0,
+      /*expected_outputs*/ {},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 2, 5, 11}};
 }
 
 class ParameterizedTakeDatasetOpTest
@@ -351,41 +351,6 @@
   EXPECT_EQ(take_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_F(TakeDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  const TestCase &test_case = TakeLessTestCase();
-  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
-  std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
-  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
-                                              &tensor_slice_dataset_tensor));
-  Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
-  gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
-  inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
-  inputs_for_take_dataset.emplace_back(&count);
-
-  std::unique_ptr<OpKernel> take_dataset_kernel;
-  TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
-                                       test_case.expected_output_shapes,
-                                       &take_dataset_kernel));
-  std::unique_ptr<OpKernelContext> take_dataset_context;
-  TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
-                                        &inputs_for_take_dataset,
-                                        &take_dataset_context));
-  DatasetBase *take_dataset;
-  TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
-                             take_dataset_context.get(), &take_dataset));
-  core::ScopedUnref scoped_unref(take_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(take_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 38acc0a..3a12690 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -62,6 +62,8 @@
 
   int64 Cardinality() const override { return 1LL; }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -70,12 +72,12 @@
     components.reserve(tensors_.size());
     for (const Tensor& t : tensors_) {
       Node* node;
-      if (ctx->optimization_only()) {
+      if (ctx->serialize_data_tensors()) {
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+      } else {
         TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
         DCHECK_NE(ctx->input_list(), nullptr);
         ctx->input_list()->emplace_back(node->name(), t);
-      } else {
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
       }
       components.emplace_back(node);
     }
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc
index 48961a1..d60f9c6 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc
@@ -68,47 +68,44 @@
 
 // Test case 1: test a dataset that represents a single tuple of plain tensors.
 TestCase PlainTensorsTestCase() {
-  return {
-      /*components*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
-                                                {"a", "b"})},
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
-                                                {"a", "b"})},
-      /*expected_output_dtypes*/
-      {DT_INT64, DT_INT64, DT_DOUBLE, DT_STRING},
-      /*expected_output_shapes*/
-      {PartialTensorShape({}), PartialTensorShape({1, 3}),
-       PartialTensorShape({}), PartialTensorShape({1, 2})},
-      /*expected_cardinality*/ 1,
-      /*breakpoints*/ {0, 1, 2}};
+  return {/*components*/
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
+           CreateTensor<double>(TensorShape({}), {37.0}),
+           CreateTensor<string>(TensorShape({1, 2}), {"a", "b"})},
+          /*expected_outputs*/
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
+           CreateTensor<double>(TensorShape({}), {37.0}),
+           CreateTensor<string>(TensorShape({1, 2}), {"a", "b"})},
+          /*expected_output_dtypes*/
+          {DT_INT64, DT_INT64, DT_DOUBLE, DT_STRING},
+          /*expected_output_shapes*/
+          {PartialTensorShape({}), PartialTensorShape({1, 3}),
+           PartialTensorShape({}), PartialTensorShape({1, 2})},
+          /*expected_cardinality*/ 1,
+          /*breakpoints*/ {0, 1, 2}};
 }
 
 // Test case 2: test a dataset that represents a tuple of nested tensors.
 TestCase NestedTensorsTestCase() {
   return {
       /*components*/
-      {DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
-                                TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
-                                TensorShape({1, 2}), {"a", "b"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
+      {CreateTensor<Variant>(
+           TensorShape({}),
+           {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
+       CreateTensor<Variant>(
+           TensorShape({}),
+           {CreateTensor<string>(TensorShape({1, 2}), {"a", "b"})}),
+       CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
-                                TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
-                                TensorShape({1, 2}), {"a", "b"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
+      {CreateTensor<Variant>(
+           TensorShape({}),
+           {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
+       CreateTensor<Variant>(
+           TensorShape({}),
+           {CreateTensor<string>(TensorShape({1, 2}), {"a", "b"})}),
+       CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
       /*expected_output_dtypes*/
       {DT_VARIANT, DT_VARIANT, DT_INT64},
       /*expected_output_shapes*/
@@ -308,37 +305,6 @@
   EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = GetParam();
-  std::vector<Tensor> components = test_case.components;
-  gtl::InlinedVector<TensorValue, 4> inputs;
-  for (auto &component : components) {
-    inputs.push_back(TensorValue(&component));
-  }
-  std::unique_ptr<OpKernel> tensor_dataset_kernel;
-  TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
-                                         test_case.expected_output_shapes,
-                                         &tensor_dataset_kernel));
-  std::unique_ptr<OpKernelContext> tensor_dataset_context;
-  TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
-                                          &tensor_dataset_context));
-  DatasetBase *tensor_dataset;
-  TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
-                             tensor_dataset_context.get(), &tensor_dataset));
-  core::ScopedUnref scoped_unref(tensor_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 308efeb..16f5b36 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -68,6 +68,8 @@
 
   int64 Cardinality() const override { return tensors_[0].dim_size(0); }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -76,12 +78,12 @@
     components.reserve(tensors_.size());
     for (const Tensor& t : tensors_) {
       Node* node;
-      if (ctx->optimization_only()) {
+      if (ctx->serialize_data_tensors()) {
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+      } else {
         TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
         DCHECK_NE(ctx->input_list(), nullptr);
         ctx->input_list()->emplace_back(node->name(), t);
-      } else {
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
       }
       components.emplace_back(node);
     }
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc
index 2ef2807..e04d998 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc
@@ -64,70 +64,61 @@
 
 TestCase PlainTensorTestCase() {
   return {/*components*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
-                                                   {1, 2, 3, 4}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {2, 3}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2, 2}),
-                                                    {2, 3, 4, 5}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {3, 4}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2, 2}),
-                                                    {3, 4, 5, 6}),
-           DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
-                                                    {37.0, 38.0}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
-                                                    {"a", "b"})},
+          {CreateTensor<int64>(TensorShape({2}), {1, 2}),
+           CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}),
+           CreateTensor<uint32>(TensorShape({2}), {2, 3}),
+           CreateTensor<uint32>(TensorShape({2, 2}), {2, 3, 4, 5}),
+           CreateTensor<uint64>(TensorShape({2}), {3, 4}),
+           CreateTensor<uint64>(TensorShape({2, 2}), {3, 4, 5, 6}),
+           CreateTensor<double>(TensorShape({2, 1}), {37.0, 38.0}),
+           CreateTensor<string>(TensorShape({2, 1}), {"a", "b"})},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({}), {2}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {2, 3}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({}), {3}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {3, 4}),
-           DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({}), {3}),
-           DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {4, 5}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({}), {4}),
-           DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {5, 6}),
-           DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
+          {CreateTensor<int64>(TensorShape({}), {1}),
+           CreateTensor<int64>(TensorShape({2}), {1, 2}),
+           CreateTensor<uint32>(TensorShape({}), {2}),
+           CreateTensor<uint32>(TensorShape({2}), {2, 3}),
+           CreateTensor<uint64>(TensorShape({}), {3}),
+           CreateTensor<uint64>(TensorShape({2}), {3, 4}),
+           CreateTensor<double>(TensorShape({1}), {37.0}),
+           CreateTensor<string>(TensorShape({1}), {"a"}),
+           CreateTensor<int64>(TensorShape({}), {2}),
+           CreateTensor<int64>(TensorShape({2}), {3, 4}),
+           CreateTensor<uint32>(TensorShape({}), {3}),
+           CreateTensor<uint32>(TensorShape({2}), {4, 5}),
+           CreateTensor<uint64>(TensorShape({}), {4}),
+           CreateTensor<uint64>(TensorShape({2}), {5, 6}),
+           CreateTensor<double>(TensorShape({1}), {38.0}),
+           CreateTensor<string>(TensorShape({1}), {"b"})},
           /*breakpoints*/ {0, 1, 3}};
 }
 
 TestCase NestedTensorTestCase() {
   return {
       /*components*/
-      {DatasetOpsTestBase::CreateTensor<Variant>(
+      {CreateTensor<Variant>(
            TensorShape({2, 1}),
-           {DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
-                                                     {1.0, 2.0, 3.0, 4.0}),
-            DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
-                                                     {5.0, 6.0, 7.0, 8.0})}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
-                                     TensorShape({1, 2}), {"a", "b"}),
-                                 DatasetOpsTestBase::CreateTensor<string>(
-                                     TensorShape({1, 2}), {"c", "d"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
-                                               {1, 2, 3, 4, 5, 6})},
+           {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0}),
+            CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
+       CreateTensor<Variant>(
+           TensorShape({2, 1}),
+           {CreateTensor<string>(TensorShape({1, 2}), {"a", "b"}),
+            CreateTensor<string>(TensorShape({1, 2}), {"c", "d"})}),
+       CreateTensor<int64>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})},
       /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
-                                 TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
-                                 TensorShape({1, 2}), {"a", "b"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
-                                 TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
-       DatasetOpsTestBase::CreateTensor<Variant>(
-           TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
-                                 TensorShape({1, 2}), {"c", "d"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
+      {CreateTensor<Variant>(
+           TensorShape({1}),
+           {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
+       CreateTensor<Variant>(
+           TensorShape({1}),
+           {CreateTensor<string>(TensorShape({1, 2}), {"a", "b"})}),
+       CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
+       CreateTensor<Variant>(
+           TensorShape({1}),
+           {CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
+       CreateTensor<Variant>(
+           TensorShape({1}),
+           {CreateTensor<string>(TensorShape({1, 2}), {"c", "d"})}),
+       CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
       /*breakpoints*/ {0, 1, 2}};
 }
 
@@ -396,48 +387,6 @@
   EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
 }
 
-TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestCase &test_case = PlainTensorTestCase();
-  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
-  std::vector<Tensor> components = test_case.components;
-  DataTypeVector dtypes;
-  gtl::InlinedVector<TensorValue, 4> inputs;
-  for (auto &component : components) {
-    inputs.emplace_back(&component);
-    dtypes.emplace_back(component.dtype());
-  }
-  size_t num_tensors_per_slice = components.size();
-  std::vector<PartialTensorShape> shapes;
-  shapes.reserve(num_tensors_per_slice);
-  for (int i = 0; i < num_tensors_per_slice; ++i) {
-    shapes.emplace_back(expected_outputs[i].shape());
-  }
-  std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
-  TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
-                                              &tensor_slice_dataset_kernel));
-  std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(
-      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
-                                      &inputs, &tensor_slice_dataset_context));
-  DatasetBase *tensor_slice_dataset;
-  TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
-                             tensor_slice_dataset_context.get(),
-                             &tensor_slice_dataset));
-  core::ScopedUnref scoped_unref(tensor_slice_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(
-      tensor_slice_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc
index b8302b8..e747ad3 100644
--- a/tensorflow/core/kernels/data/text_line_dataset_op.cc
+++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc
@@ -70,6 +70,8 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -108,7 +110,7 @@
                 line_contents.size());
             out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                       TensorShape({}));
-            out_tensors->back().scalar<string>()() = std::move(line_contents);
+            out_tensors->back().scalar<tstring>()() = std::move(line_contents);
             *end_of_sequence = false;
             return Status::OK();
           } else if (!errors::IsOutOfRange(s)) {
@@ -266,7 +268,7 @@
   std::vector<string> filenames;
   filenames.reserve(filenames_tensor->NumElements());
   for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-    filenames.push_back(filenames_tensor->flat<string>()(i));
+    filenames.push_back(filenames_tensor->flat<tstring>()(i));
   }
 
   *output = new Dataset(ctx, std::move(filenames), compression_type,
diff --git a/tensorflow/core/kernels/data/text_line_dataset_op_test.cc b/tensorflow/core/kernels/data/text_line_dataset_op_test.cc
index d5909c8..76c65ff 100644
--- a/tensorflow/core/kernels/data/text_line_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/text_line_dataset_op_test.cc
@@ -82,81 +82,66 @@
 
 // Test case 1: multiple text files with ZLIB compression.
 TestCase TestCase1() {
-  return {
-      /*filenames*/ {absl::StrCat(testing::TmpDir(), "/text_line_ZLIB_1"),
-                     absl::StrCat(testing::TmpDir(), "/text_line_ZLIB_2")},
-      /*texts*/
-      {absl::StrCat("hello world\n", "11223334455\n"),
-       absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
-      /*compression_type*/ CompressionType::ZLIB,
-      /*buffer_size*/ 10,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"hello world"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"11223334455"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"abcd, EFgH"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"           "}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
-      /*expected_output_dtypes*/ {DT_STRING},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ kUnknownCardinality,
-      /*breakpoints*/ {0, 2, 6}};
+  return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/text_line_ZLIB_1"),
+                         absl::StrCat(testing::TmpDir(), "/text_line_ZLIB_2")},
+          /*texts*/
+          {absl::StrCat("hello world\n", "11223334455\n"),
+           absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
+          /*compression_type*/ CompressionType::ZLIB,
+          /*buffer_size*/ 10,
+          /*expected_outputs*/
+          {CreateTensor<string>(TensorShape({}), {"hello world"}),
+           CreateTensor<string>(TensorShape({}), {"11223334455"}),
+           CreateTensor<string>(TensorShape({}), {"abcd, EFgH"}),
+           CreateTensor<string>(TensorShape({}), {"           "}),
+           CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
+          /*expected_output_dtypes*/ {DT_STRING},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ kUnknownCardinality,
+          /*breakpoints*/ {0, 2, 6}};
 }
 
 // Test case 2: multiple text files with GZIP compression.
 TestCase TestCase2() {
-  return {
-      /*filenames*/ {absl::StrCat(testing::TmpDir(), "/text_line_GZIP_1"),
-                     absl::StrCat(testing::TmpDir(), "/text_line_GZIP_2")},
-      /*texts*/
-      {absl::StrCat("hello world\n", "11223334455\n"),
-       absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
-      /*compression_type*/ CompressionType::GZIP,
-      /*buffer_size*/ 10,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"hello world"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"11223334455"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"abcd, EFgH"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"           "}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
-      /*expected_output_dtypes*/ {DT_STRING},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ kUnknownCardinality,
-      /*breakpoints*/ {0, 2, 6}};
+  return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/text_line_GZIP_1"),
+                         absl::StrCat(testing::TmpDir(), "/text_line_GZIP_2")},
+          /*texts*/
+          {absl::StrCat("hello world\n", "11223334455\n"),
+           absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
+          /*compression_type*/ CompressionType::GZIP,
+          /*buffer_size*/ 10,
+          /*expected_outputs*/
+          {CreateTensor<string>(TensorShape({}), {"hello world"}),
+           CreateTensor<string>(TensorShape({}), {"11223334455"}),
+           CreateTensor<string>(TensorShape({}), {"abcd, EFgH"}),
+           CreateTensor<string>(TensorShape({}), {"           "}),
+           CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
+          /*expected_output_dtypes*/ {DT_STRING},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ kUnknownCardinality,
+          /*breakpoints*/ {0, 2, 6}};
 }
 
 // Test case 3: multiple text files without compression.
 TestCase TestCase3() {
-  return {
-      /*filenames*/ {
-          absl::StrCat(testing::TmpDir(), "/text_line_UNCOMPRESSED_1"),
-          absl::StrCat(testing::TmpDir(), "/text_line_UNCOMPRESSED_2")},
-      /*texts*/
-      {absl::StrCat("hello world\n", "11223334455\n"),
-       absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
-      /*compression_type*/ CompressionType::UNCOMPRESSED,
-      /*buffer_size*/ 10,
-      /*expected_outputs*/
-      {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"hello world"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"11223334455"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"abcd, EFgH"}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}),
-                                                {"           "}),
-       DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
-      /*expected_output_dtypes*/ {DT_STRING},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ kUnknownCardinality,
-      /*breakpoints*/ {0, 2, 6}};
+  return {/*filenames*/ {
+              absl::StrCat(testing::TmpDir(), "/text_line_UNCOMPRESSED_1"),
+              absl::StrCat(testing::TmpDir(), "/text_line_UNCOMPRESSED_2")},
+          /*texts*/
+          {absl::StrCat("hello world\n", "11223334455\n"),
+           absl::StrCat("abcd, EFgH\n", "           \n", "$%^&*()\n")},
+          /*compression_type*/ CompressionType::UNCOMPRESSED,
+          /*buffer_size*/ 10,
+          /*expected_outputs*/
+          {CreateTensor<string>(TensorShape({}), {"hello world"}),
+           CreateTensor<string>(TensorShape({}), {"11223334455"}),
+           CreateTensor<string>(TensorShape({}), {"abcd, EFgH"}),
+           CreateTensor<string>(TensorShape({}), {"           "}),
+           CreateTensor<string>(TensorShape({}), {"$%^&*()"})},
+          /*expected_output_dtypes*/ {DT_STRING},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ kUnknownCardinality,
+          /*breakpoints*/ {0, 2, 6}};
 }
 
 class ParameterizedTextLineDatasetOpTest
@@ -381,45 +366,6 @@
   EXPECT_EQ(text_line_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedTextLineDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TF_ASSERT_OK(CreateTestFiles(test_case));
-
-  std::unique_ptr<OpKernel> text_line_dataset_kernel;
-  TF_ASSERT_OK(CreateTextLineDatasetOpKernel(&text_line_dataset_kernel));
-
-  int64 num_files = test_case.filenames.size();
-  Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
-      TensorShape({}), {ToString(test_case.compression_type)});
-  Tensor buffer_size =
-      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
-  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
-                                            TensorValue(&compression_type),
-                                            TensorValue(&buffer_size)};
-  std::unique_ptr<OpKernelContext> text_line_dataset_context;
-  TF_ASSERT_OK(CreateTextLineDatasetContext(
-      text_line_dataset_kernel.get(), &inputs, &text_line_dataset_context));
-
-  DatasetBase* text_line_dataset;
-  TF_ASSERT_OK(CreateDataset(text_line_dataset_kernel.get(),
-                             text_line_dataset_context.get(),
-                             &text_line_dataset));
-  core::ScopedUnref scoped_unref(text_line_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(text_line_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedTextLineDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc
index e35743d..861639b 100644
--- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc
@@ -74,6 +74,8 @@
     return name_utils::DatasetDebugString(kDatasetType);
   }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -108,7 +110,7 @@
               reader_->ReadRecord(&out_tensors->back().scalar<string>()());
           if (s.ok()) {
             metrics::RecordTFDataBytesRead(
-                kDatasetType, out_tensors->back().scalar<string>()().size());
+                kDatasetType, out_tensors->back().scalar<tstring>()().size());
             *end_of_sequence = false;
             return Status::OK();
           }
@@ -224,8 +226,8 @@
   std::vector<string> filenames;
   filenames.reserve(filenames_tensor->NumElements());
   for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-    VLOG(2) << "Reading file: " << filenames_tensor->flat<string>()(i);
-    filenames.push_back(filenames_tensor->flat<string>()(i));
+    VLOG(2) << "Reading file: " << filenames_tensor->flat<tstring>()(i);
+    filenames.push_back(filenames_tensor->flat<tstring>()(i));
   }
 
   string compression_type;
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
index 742b458..936d7e1 100644
--- a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
@@ -84,12 +84,12 @@
           /*compression_type*/ CompressionType::ZLIB,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<string>(TensorShape({}), {"1"}),
+           CreateTensor<string>(TensorShape({}), {"22"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"a"}),
+           CreateTensor<string>(TensorShape({}), {"bb"}),
+           CreateTensor<string>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -105,12 +105,12 @@
           /*compression_type*/ CompressionType::GZIP,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<string>(TensorShape({}), {"1"}),
+           CreateTensor<string>(TensorShape({}), {"22"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"a"}),
+           CreateTensor<string>(TensorShape({}), {"bb"}),
+           CreateTensor<string>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -127,12 +127,12 @@
           /*compression_type*/ CompressionType::UNCOMPRESSED,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
-           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<string>(TensorShape({}), {"1"}),
+           CreateTensor<string>(TensorShape({}), {"22"}),
+           CreateTensor<string>(TensorShape({}), {"333"}),
+           CreateTensor<string>(TensorShape({}), {"a"}),
+           CreateTensor<string>(TensorShape({}), {"bb"}),
+           CreateTensor<string>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,
@@ -361,45 +361,6 @@
   EXPECT_EQ(tf_record_dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  TF_ASSERT_OK(CreateTestFiles(test_case));
-
-  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
-  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
-
-  int64 num_files = test_case.filenames.size();
-  Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
-      TensorShape({}), {ToString(test_case.compression_type)});
-  Tensor buffer_size =
-      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
-  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
-                                            TensorValue(&compression_type),
-                                            TensorValue(&buffer_size)};
-  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
-  TF_ASSERT_OK(CreateTFRecordDatasetContext(
-      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
-
-  DatasetBase* tf_record_dataset;
-  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
-                             tf_record_dataset_context.get(),
-                             &tf_record_dataset));
-  core::ScopedUnref scoped_unref(tf_record_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(tf_record_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.cc b/tensorflow/core/kernels/data/unbounded_thread_pool.cc
index ac12197..9cb4563 100644
--- a/tensorflow/core/kernels/data/unbounded_thread_pool.cc
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool.cc
@@ -16,27 +16,13 @@
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 
 #include "absl/memory/memory.h"
+#include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 namespace data {
 
-// A lightweight wrapper for creating logical threads in a `UnboundedThreadPool`
-// that can be shared (e.g.) in an `IteratorContext`.
-class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
- public:
-  explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
-
-  std::unique_ptr<Thread> StartThread(const string& name,
-                                      std::function<void()> fn) override {
-    return pool_->RunOnPooledThread(std::move(fn));
-  }
-
- private:
-  UnboundedThreadPool* const pool_;  // Not owned.
-};
-
 // A logical implementation of the `tensorflow::Thread` interface that uses
 // physical threads in an `UnboundedThreadPool` to perform the work.
 //
@@ -45,111 +31,64 @@
 // same `UnboundedThreadPool`.
 class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
  public:
-  explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
-      : join_notification_(std::move(join_notification)) {}
+  explicit LogicalThreadWrapper(std::shared_ptr<Notification> done)
+      : done_(std::move(done)) {}
 
   ~LogicalThreadWrapper() override {
     // NOTE: The `Thread` destructor is expected to "join" the created thread,
     // but the physical thread may continue to execute after the work for this
     // thread is complete. We simulate this by waiting on a notification that
-    // the `CachedThreadFunc` will notify when the thread's work function is
-    // complete.
-    join_notification_->WaitForNotification();
+    // the thread's work function will notify when it is complete.
+    done_->WaitForNotification();
   }
 
  private:
-  std::shared_ptr<Notification> join_notification_;
+  std::shared_ptr<Notification> done_;
 };
 
-UnboundedThreadPool::~UnboundedThreadPool() {
-  {
-    mutex_lock l(work_queue_mu_);
-    // Wake up all `CachedThreadFunc` threads and cause them to terminate before
-    // joining them when `threads_` is cleared.
-    cancelled_ = true;
-    work_queue_cv_.notify_all();
-    if (!work_queue_.empty()) {
-      LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
-                 << "deleted with pending work in its queue. This may indicate "
-                 << "a potential use-after-free bug.";
-    }
+// A lightweight wrapper for creating logical threads in a `UnboundedThreadPool`
+// that can be shared (e.g.) in an `IteratorContext`.
+class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
+ public:
+  explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
+
+  std::unique_ptr<Thread> StartThread(const string& name,
+                                      std::function<void()> fn) override {
+    auto done = std::make_shared<Notification>();
+    pool_->ScheduleOnWorkQueue(std::move(fn), done);
+    return absl::make_unique<LogicalThreadWrapper>(std::move(done));
   }
 
-  {
-    mutex_lock l(thread_pool_mu_);
-    // Clear the list of pooled threads, which will eventually terminate due to
-    // the previous notification.
-    //
-    // NOTE: It is safe to do this while holding `pooled_threads_mu_`, because
-    // no subsequent calls to `this->StartThread()` should be issued after the
-    // destructor starts.
-    thread_pool_.clear();
-  }
-}
+ private:
+  UnboundedThreadPool* const pool_;  // Not owned.
+};
 
 std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
   return std::make_shared<LogicalThreadFactory>(this);
 }
 
-size_t UnboundedThreadPool::size() {
-  tf_shared_lock l(thread_pool_mu_);
-  return thread_pool_.size();
+void UnboundedThreadPool::Schedule(std::function<void()> fn) {
+  ScheduleOnWorkQueue(std::move(fn), /*done=*/nullptr);
 }
 
-std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
-    std::function<void()> fn) {
-  auto join_notification = std::make_shared<Notification>();
-  bool all_threads_busy;
-  {
-    // Enqueue a work item for the new thread's function, and wake up a
-    // cached thread to process it.
-    mutex_lock l(work_queue_mu_);
-    work_queue_.push_back({std::move(fn), join_notification});
-    work_queue_cv_.notify_one();
-    // NOTE: The queue may be non-empty, so we must account for queued work when
-    // considering how many threads are free.
-    all_threads_busy = work_queue_.size() > num_idle_threads_;
+int UnboundedThreadPool::NumThreads() const { return -1; }
+
+int UnboundedThreadPool::CurrentThreadId() const { return -1; }
+
+namespace {
+void WorkQueueFunc(const std::function<void()>& fn,
+                   std::shared_ptr<Notification> done) {
+  fn();
+  if (done) {
+    done->Notify();
   }
-
-  if (all_threads_busy) {
-    // Spawn a new physical thread to process the given function.
-    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
-    // at the beginning of its work loop.
-    Thread* new_thread = env_->StartThread(
-        {}, thread_name_,
-        std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
-
-    mutex_lock l(thread_pool_mu_);
-    thread_pool_.emplace_back(new_thread);
-  }
-
-  return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
 }
+}  // namespace
 
-void UnboundedThreadPool::PooledThreadFunc() {
-  while (true) {
-    WorkItem work_item;
-    {
-      mutex_lock l(work_queue_mu_);
-      ++num_idle_threads_;
-      while (!cancelled_ && work_queue_.empty()) {
-        // Wait for a new work function to be submitted, or the cache to be
-        // destroyed.
-        work_queue_cv_.wait(l);
-      }
-      if (cancelled_) {
-        return;
-      }
-      work_item = std::move(work_queue_.front());
-      work_queue_.pop_front();
-      --num_idle_threads_;
-    }
-
-    work_item.work_function();
-
-    // Notify any thread that has "joined" the cached thread for this work item.
-    work_item.done_notification->Notify();
-  }
+void UnboundedThreadPool::ScheduleOnWorkQueue(
+    std::function<void()> fn, std::shared_ptr<Notification> done) {
+  unbounded_work_queue_.Schedule(
+      std::bind(&WorkQueueFunc, std::move(fn), std::move(done)));
 }
 
 }  // namespace data
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.h b/tensorflow/core/kernels/data/unbounded_thread_pool.h
index c84d495..82335d7 100644
--- a/tensorflow/core/kernels/data/unbounded_thread_pool.h
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool.h
@@ -21,54 +21,39 @@
 
 #include "tensorflow/core/framework/thread_factory.h"
 #include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool_interface.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 namespace data {
 
 // An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
 // potentially large number of "logical" threads onto a smaller number of
-// "physical" threads. The multiplexing is achieved by maintaining an internal
-// pool of long-running "physical" threads that are used to execute the
-// "logical" threads.  Like a regular thread, a "logical" thread may block on
-// other threads, and the size of the pool will increase to ensure that progress
-// is made. This mechanism is recommended in situations where short-lived
-// threads are created repeatedly, to avoid the overhead and memory
-// fragmentation that can result from excessive thread creation.
-class UnboundedThreadPool {
+// "physical" threads. The multiplexing is achieved by using an
+// `UnboundedWorkQueue`.
+class UnboundedThreadPool : public thread::ThreadPoolInterface {
  public:
   UnboundedThreadPool(Env* env, const string& thread_name)
-      : env_(env), thread_name_(thread_name) {}
-  ~UnboundedThreadPool();
+      : unbounded_work_queue_(env, thread_name) {}
+  ~UnboundedThreadPool() = default;
 
   // Returns an implementation of `ThreadFactory` that can be used to create
   // logical threads in this pool.
   std::shared_ptr<ThreadFactory> get_thread_factory();
 
-  // Returns the current number of threads in this pool.
-  size_t size();
+  void Schedule(std::function<void()> fn) override;
+  int NumThreads() const override;
+  int CurrentThreadId() const override;
 
  private:
   class LogicalThreadFactory;
   class LogicalThreadWrapper;
-  struct WorkItem {
-    std::function<void()> work_function;
-    std::shared_ptr<Notification> done_notification;
-  };
 
-  std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
-  void PooledThreadFunc();
+  void ScheduleOnWorkQueue(std::function<void()> fn,
+                           std::shared_ptr<Notification> done);
 
-  Env* const env_;  // Not owned.
-  const string thread_name_;
-  mutex work_queue_mu_;
-  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
-  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
-  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
-  std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
-  mutex thread_pool_mu_;
-  std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
+  UnboundedWorkQueue unbounded_work_queue_;
 };
 
 }  // namespace data
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc
index f996b4f..3604be8 100644
--- a/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc
@@ -23,59 +23,6 @@
 namespace data {
 namespace {
 
-TEST(UnboundedThreadPool, SingleThread) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create a thread that updates a variable, and ensure that it runs to
-  // completion.
-  std::atomic<int> i(0);
-  auto thread = thread_factory->StartThread("", [&i]() { ++i; });
-  thread.reset();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(1, i);
-}
-
-TEST(UnboundedThreadPool, MultipleThreads) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create ten threads that update a variable, and ensure that they all run
-  // to completion.
-  std::vector<std::unique_ptr<Thread>> threads;
-  const int kNumThreadsToCreate = 10;
-  std::atomic<int> i(0);
-  for (int j = 0; j < kNumThreadsToCreate; ++j) {
-    threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
-  }
-  threads.clear();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(i, kNumThreadsToCreate);
-}
-
-TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create 1000 threads that sleep for a random period of time then update a
-  // variable, and ensure that they all run to completion.
-  std::vector<std::unique_ptr<Thread>> threads;
-  const int kNumThreadsToCreate = 1000;
-  std::atomic<int> i(0);
-  for (int j = 0; j < kNumThreadsToCreate; ++j) {
-    threads.push_back(thread_factory->StartThread("", [&i]() {
-      Env::Default()->SleepForMicroseconds(random::New64() % 10);
-      ++i;
-    }));
-  }
-  threads.clear();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(i, kNumThreadsToCreate);
-}
-
 TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
   UnboundedThreadPool pool(Env::Default(), "test");
   auto thread_factory = pool.get_thread_factory();
@@ -97,7 +44,6 @@
   }
   threads.clear();
 
-  EXPECT_GE(pool.size(), 1);
   EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
 }
 
@@ -108,9 +54,7 @@
   std::vector<std::unique_ptr<Thread>> threads;
 
   // Create multiple waves (with increasing sizes) of threads that all block
-  // before returning, and
-  // ensure that we create the appropriate number of threads and terminate
-  // correctly.
+  // before returning, and ensure that we terminate correctly.
   std::vector<int> round_sizes = {5, 10, 15, 20};
 
   for (const int round_size : round_sizes) {
@@ -129,10 +73,6 @@
     // wave is increasing, we should have at least that number of threads in the
     // pool.
     bc.Wait();
-    // NOTE: There is a benign race between a new round starting and the
-    // physical threads from the previous round returning to the pool, so we may
-    // create more threads than the round_size.
-    EXPECT_GE(pool.size(), round_size);
     n.Notify();
     threads.clear();
   }
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 3b1e886..b8a7a8a 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -59,6 +59,8 @@
 
   string DebugString() const override { return kWindowDataset; }
 
+  Status CheckExternalState() const override { return Status::OK(); }
+
  protected:
   // TODO(b/110981596): Support checkpointing.
   Status AsGraphDefInternal(SerializationContext* ctx,
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3f2d18d..a767cc5 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -99,6 +99,10 @@
     return cardinality;
   }
 
+  Status CheckExternalState() const override {
+    return input_->CheckExternalState();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc
index e02d5e8..4e01fb3 100644
--- a/tensorflow/core/kernels/data/window_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc
@@ -69,247 +69,234 @@
 
 // Test case 1: size=2, shift=2, stride=1, drop_remainder=false.
 TestCase TestCase1() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {false}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {1})},
+           {CreateTensor<int64>(TensorShape({}), {2}),
+            CreateTensor<int64>(TensorShape({}), {3})},
+           {CreateTensor<int64>(TensorShape({}), {4}),
+            CreateTensor<int64>(TensorShape({}), {5})},
+           {CreateTensor<int64>(TensorShape({}), {6})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 2: size=2, shift=2, stride=2, drop_remainder=true.
 TestCase TestCase2() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {2})},
+           {CreateTensor<int64>(TensorShape({}), {2}),
+            CreateTensor<int64>(TensorShape({}), {4})},
+           {CreateTensor<int64>(TensorShape({}), {4}),
+            CreateTensor<int64>(TensorShape({}), {6})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 3: size=8, shift=3, stride=1, drop_remainder=false.
 TestCase TestCase3() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 3,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {3}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {false}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {1}),
+            CreateTensor<int64>(TensorShape({}), {2}),
+            CreateTensor<int64>(TensorShape({}), {3}),
+            CreateTensor<int64>(TensorShape({}), {4}),
+            CreateTensor<int64>(TensorShape({}), {5}),
+            CreateTensor<int64>(TensorShape({}), {6})},
+           {CreateTensor<int64>(TensorShape({}), {3}),
+            CreateTensor<int64>(TensorShape({}), {4}),
+            CreateTensor<int64>(TensorShape({}), {5}),
+            CreateTensor<int64>(TensorShape({}), {6})},
+           {CreateTensor<int64>(TensorShape({}), {6})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 3,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 4: size=8, shift=3, stride=1, drop_remainder=true.
 TestCase TestCase4() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {3}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 5: size=2, shift=8, stride=1, drop_remainder=false.
 TestCase TestCase5() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 1,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {false}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {1})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 1,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 6: size=2, shift=8, stride=1, drop_remainder=true.
 TestCase TestCase6() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 1,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {1}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {1})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 1,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 7: size=2, shift=2, stride=8, drop_remainder=false.
 TestCase TestCase7() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {false}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4})},
-       {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 4,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {false}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0})},
+           {CreateTensor<int64>(TensorShape({}), {2})},
+           {CreateTensor<int64>(TensorShape({}), {4})},
+           {CreateTensor<int64>(TensorShape({}), {6})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 8: size=2, shift=2, stride=8, drop_remainder=true.
 TestCase TestCase8() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {8}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 9: size=4, shift=2, stride=2, drop_remainder=true.
 TestCase TestCase9() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/
-      {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
-        DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6})}},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 1,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {4}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/
+          {{CreateTensor<int64>(TensorShape({}), {0}),
+            CreateTensor<int64>(TensorShape({}), {2}),
+            CreateTensor<int64>(TensorShape({}), {4}),
+            CreateTensor<int64>(TensorShape({}), {6})}},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 1,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 10: size=5, shift=2, stride=2, drop_remainder=true.
 TestCase TestCase10() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {5}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {5}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 11: size=0, shift=2, stride=2, drop_remainder=true.
 TestCase InvalidWindowSizeTestCase() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {0}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 12: size=2, shift=0, stride=2, drop_remainder=true.
 TestCase InvalidWindowShiftTestCase() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {0}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 // Test case 13: size=2, shift=2, stride=0, drop_remainder=true.
 TestCase InvalidWindowStrideTestCase() {
-  return {
-      /*range_data_param*/ {0, 7, 1},
-      /*size*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*shift*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
-      /*stride*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-      /*drop_remainder*/
-      DatasetOpsTestBase::CreateTensor<bool>(TensorShape({}), {true}),
-      /*expected_outputs*/ {},
-      /*expected_output_dtypes*/ {DT_VARIANT},
-      /*expected_output_shapes*/ {PartialTensorShape({})},
-      /*expected_cardinality*/ 0,
-      /*breakpoints*/ {0, 1, 9}};
+  return {/*range_data_param*/ {0, 7, 1},
+          /*size*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*shift*/ CreateTensor<int64>(TensorShape({}), {2}),
+          /*stride*/ CreateTensor<int64>(TensorShape({}), {0}),
+          /*drop_remainder*/
+          CreateTensor<bool>(TensorShape({}), {true}),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_VARIANT},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 1, 9}};
 }
 
 class ParameterizedWindowDatasetOpTest
@@ -587,49 +574,6 @@
   EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
 }
 
-TEST_P(ParameterizedWindowDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  std::unique_ptr<OpKernel> window_dataset_kernel;
-  TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes,
-                                         test_case.expected_output_shapes,
-                                         &window_dataset_kernel));
-
-  DatasetBase* range_dataset;
-  TF_ASSERT_OK(CreateRangeDataset<int64>(
-      test_case.range_data_param.start, test_case.range_data_param.end,
-      test_case.range_data_param.step, "range", &range_dataset));
-  Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
-  TF_ASSERT_OK(
-      StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
-  Tensor size = test_case.size;
-  Tensor shift = test_case.shift;
-  Tensor stride = test_case.stride;
-  Tensor drop_remainder = test_case.drop_remainder;
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_dataset_tensor), TensorValue(&size),
-       TensorValue(&shift), TensorValue(&stride),
-       TensorValue(&drop_remainder)});
-
-  std::unique_ptr<OpKernelContext> window_dataset_op_ctx;
-  TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs,
-                                          &window_dataset_op_ctx));
-  DatasetBase* dataset;
-  TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(),
-                             window_dataset_op_ctx.get(), &dataset));
-  core::ScopedUnref scoped_unref_dataset(dataset);
-
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedWindowDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TestCase test_case = GetParam();
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index ecc3072..c401b65 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -87,6 +87,13 @@
     return result;
   }
 
+  Status CheckExternalState() const override {
+    for (const auto& input : inputs_) {
+      TF_RETURN_IF_ERROR(input->CheckExternalState());
+    }
+    return Status::OK();
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
index 301f182..9dddb05 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
@@ -91,12 +91,12 @@
   return {/*input_range_dataset_params*/
           {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
+          {CreateTensor<int64>(TensorShape{}, {0}),
+           CreateTensor<int64>(TensorShape{}, {10}),
+           CreateTensor<int64>(TensorShape{}, {1}),
+           CreateTensor<int64>(TensorShape{}, {11}),
+           CreateTensor<int64>(TensorShape{}, {2}),
+           CreateTensor<int64>(TensorShape{}, {12})},
           /*breakpoints*/ {0, 1, 4}};
 }
 
@@ -105,12 +105,12 @@
   return {/*input_range_dataset_params*/
           {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
           /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
+          {CreateTensor<int64>(TensorShape{}, {0}),
+           CreateTensor<int64>(TensorShape{}, {10}),
+           CreateTensor<int64>(TensorShape{}, {1}),
+           CreateTensor<int64>(TensorShape{}, {11}),
+           CreateTensor<int64>(TensorShape{}, {2}),
+           CreateTensor<int64>(TensorShape{}, {12})},
           /*breakpoints*/ {0, 1, 4}};
 }
 
@@ -333,41 +333,6 @@
             test_case.expected_outputs.size() / num_tensors_per_slice);
 }
 
-TEST_F(ZipDatasetOpTest, DatasetSave) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-
-  const TestParam &test_case = TestCase1();
-  std::vector<Tensor> range_dataset_tensors;
-  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
-  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
-                                         &range_dataset_tensors));
-  gtl::InlinedVector<TensorValue, 4> inputs;
-  inputs.reserve(range_dataset_tensors.size());
-  for (auto &tensor : range_dataset_tensors) {
-    inputs.emplace_back(&tensor);
-  }
-  std::unique_ptr<OpKernel> dataset_kernel;
-  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
-                                      inputs.size(), &dataset_kernel));
-  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
-  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
-                                       &dataset_kernel_ctx));
-  DatasetBase *zip_dataset;
-  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
-                             &zip_dataset));
-  core::ScopedUnref scoped_unref(zip_dataset);
-
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
-}
-
 TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index 8a9f7b1..122b7ec 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -54,7 +54,7 @@
                                         contents.shape().DebugString()));
 
     // Start decoding image to get shape details
-    const StringPiece input = contents.scalar<string>()();
+    const StringPiece input = contents.scalar<tstring>()();
 
     OP_REQUIRES(context, (32 <= input.size()),
                 errors::InvalidArgument("Incomplete bmp content, requires at "
diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc
index 3c3d49e..dd44f04 100644
--- a/tensorflow/core/kernels/decode_compressed_op.cc
+++ b/tensorflow/core/kernels/decode_compressed_op.cc
@@ -84,13 +84,13 @@
   void Compute(OpKernelContext* context) override {
     const Tensor* bytes_tensor;
     OP_REQUIRES_OK(context, context->input("bytes", &bytes_tensor));
-    const auto& bytes_flat = bytes_tensor->flat<string>();
+    const auto& bytes_flat = bytes_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output("output", bytes_tensor->shape(),
                                             &output_tensor));
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
     if (compression_type_.empty()) {
       for (int64 i = 0; i < bytes_flat.size(); i++) {
         output_flat(i) = bytes_flat(i);
@@ -109,7 +109,7 @@
         string output_string;
         Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string);
         OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s);
-        output_flat(i) = output_string;
+        output_flat(i) = std::move(output_string);
       }
     }
   }
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index ba63695..9d959a5 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -70,7 +70,7 @@
                       " has ", record_defaults[i].NumElements()));
     }
 
-    auto records_t = records->flat<string>();
+    auto records_t = records->flat<tstring>();
     int64 records_size = records_t.size();
 
     OpOutputList output;
@@ -181,10 +181,10 @@
                           errors::InvalidArgument(
                               "Field ", f,
                               " is required but missing in record ", i, "!"));
-              output[f]->flat<string>()(i) =
-                  record_defaults[f].flat<string>()(0);
+              output[f]->flat<tstring>()(i) =
+                  record_defaults[f].flat<tstring>()(0);
             } else {
-              output[f]->flat<string>()(i) = fields[f];
+              output[f]->flat<tstring>()(i) = std::move(fields[f]);
             }
             break;
           }
diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc
index 052c9f2..f89533d 100644
--- a/tensorflow/core/kernels/decode_image_op.cc
+++ b/tensorflow/core/kernels/decode_image_op.cc
@@ -154,7 +154,7 @@
                                         contents.shape().DebugString()));
 
     // Determine format
-    const StringPiece input = contents.scalar<string>()();
+    const StringPiece input = contents.scalar<tstring>()();
     const auto magic = ClassifyFileFormat(input);
     OP_REQUIRES(
         context,
diff --git a/tensorflow/core/kernels/decode_padded_raw_op.cc b/tensorflow/core/kernels/decode_padded_raw_op.cc
index 1e6a0cb..12e8ec6 100644
--- a/tensorflow/core/kernels/decode_padded_raw_op.cc
+++ b/tensorflow/core/kernels/decode_padded_raw_op.cc
@@ -39,7 +39,7 @@
 
   void Compute(OpKernelContext* context) override {
     const auto& input = context->input(0);
-    auto flat_in = input.flat<string>();
+    auto flat_in = input.flat<tstring>();
 
     int fixed_length;
     const auto& length_input = context->input(1);
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
index 06dc766..5717fa5 100644
--- a/tensorflow/core/kernels/decode_proto_op.cc
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -748,14 +748,14 @@
     if (is_binary_ && !sanitize_) {
       // Fast path.
       for (int mi = 0; mi < message_count; ++mi) {
-        const string* buf = &buf_tensor.flat<string>()(mi);
+        const tstring* buf = &buf_tensor.flat<tstring>()(mi);
         bufs.push_back(buf);
       }
     } else {
       // We will have to allocate a copy, either to convert from text to binary
       // or to sanitize a binary proto.
       for (int mi = 0; mi < message_count; ++mi) {
-        ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
+        ReserializeMessage(ctx, buf_tensor.flat<tstring>()(mi),
                            &tmp_binary_bufs[mi]);
         if (!ctx->status().ok()) {
           return;
@@ -895,8 +895,8 @@
           data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
         } else {
           // DataTypeSize() returns 0 for string types.
-          stride = last_dim_size * sizeof(string);
-          data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
+          stride = last_dim_size * sizeof(tstring);
+          data = reinterpret_cast<uint8*>(tensor->flat<tstring>().data());
         }
       }
 
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index e68fa40..9425896 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -41,7 +41,7 @@
   void Compute(OpKernelContext* context) override {
     const auto& input = context->input(0);
     int64 str_size = -1;
-    auto flat_in = input.flat<string>();
+    auto flat_in = input.flat<tstring>();
     for (int64 i = 0; i < flat_in.size(); ++i) {
       const string& in_str = flat_in(i);
       if (str_size == -1) {
diff --git a/tensorflow/core/kernels/decode_wav_op.cc b/tensorflow/core/kernels/decode_wav_op.cc
index 4bd5d7a..c7edcac 100644
--- a/tensorflow/core/kernels/decode_wav_op.cc
+++ b/tensorflow/core/kernels/decode_wav_op.cc
@@ -40,7 +40,7 @@
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
                 errors::InvalidArgument("contents must be scalar, got shape ",
                                         contents.shape().DebugString()));
-    const string wav_string = contents.scalar<string>()();
+    const string& wav_string = contents.scalar<tstring>()();
     OP_REQUIRES(context, wav_string.size() <= std::numeric_limits<int>::max(),
                 errors::InvalidArgument("WAV contents are too large for int: ",
                                         wav_string.size()));
diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
index d26d818..398df42 100644
--- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -75,7 +75,7 @@
     if (num_sparse_tensors == 1 && ndims == 1) {
       // Special case with a single sparse tensor. We can avoid data
       // motion in the Concat and Reshape.
-      const auto& serialized_sparse_t = serialized_sparse.vec<string>();
+      const auto& serialized_sparse_t = serialized_sparse.vec<tstring>();
 
       Tensor output_indices;
       Tensor output_values;
@@ -98,7 +98,7 @@
     values.reserve(num_sparse_tensors);
 
     const auto& serialized_sparse_t =
-        serialized_sparse.flat_inner_dims<string, 2>();
+        serialized_sparse.flat_inner_dims<tstring, 2>();
     for (int i = 0; i < num_sparse_tensors; ++i) {
       Tensor output_indices;
       Tensor output_values;
diff --git a/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
index cbf5252..1db98d1 100644
--- a/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
+++ b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
@@ -13,6 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
+// Need to #include Eigen's Tensor class first because Eigen/CXX11/FixedPoint
+// depends on the file but doesn't include it. This breaks compilation on
+// clang.
+// clang-format off
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+// clang-format on
 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 #include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/kernels/eigen_softmax.h b/tensorflow/core/kernels/eigen_softmax.h
deleted file mode 100644
index 12148c5..0000000
--- a/tensorflow/core/kernels/eigen_softmax.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
-#define TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-
-namespace Eigen {
-
-/** SoftMax
- * \ingroup CXX11_NeuralNetworks_Module
- *
- * \brief Applies a softmax
- *
- * The input parameter is expected to be a col-major tensor with a rank of 2
- * (depth and other).
- *
- * The result can be assigned to a tensor of rank and dimensions equal to that
- * of the input. The result will be laid out in col-major order.
- *
- */
-
-namespace {
-struct SoftmaxOp {
-  SoftmaxOp(const float beta) : beta_(beta) {}
-
-  template <typename Input>
-  typename Input::Dimensions dimensions(const Input& input) const {
-    return input.dimensions();
-  }
-
-  template <typename Input, typename Output, typename Device>
-  void eval(const Input& input, Output& output, const Device& device) const {
-#if !defined(EIGEN_HAS_INDEX_LIST)
-    // nvcc doesn't support cxx11
-    Eigen::array<typename internal::traits<Input>::Index, 1> depth_dim;
-    depth_dim[0] = 0;
-    Eigen::array<typename internal::traits<Input>::Index, 2> bcast;
-    bcast[0] = dimensions(input)[0];
-    bcast[1] = 1;
-    DSizes<typename internal::traits<Input>::Index, 2> dims2d;
-    dims2d[0] = 1;
-    dims2d[1] = dimensions(input)[1];
-#else
-    // Take advantage of cxx11 to give the compiler information it can use to
-    // optimize the code.
-    Eigen::IndexList<Eigen::type2index<0> > depth_dim;
-    Eigen::IndexList<int, Eigen::type2index<1> > bcast;
-    bcast.set(0, dimensions(input)[0]);
-    Eigen::IndexList<Eigen::type2index<1>,
-                     typename internal::traits<Input>::Index>
-        dims2d;
-    dims2d.set(1, dimensions(input)[1]);
-#endif
-
-    output.device(device) =
-        ((input -
-          input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) *
-         beta_)
-            .exp();
-    output.device(device) =
-        output /
-        (output.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast));
-  }
-
- private:
-  const float beta_;
-};
-}  // namespace
-
-template <typename Input>
-EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp<const SoftmaxOp,
-                                                     const Input>
-SoftMax(const Input& input, const float beta) {
-  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
-                      YOU_MADE_A_PROGRAMMING_MISTAKE);
-  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 2,
-                      YOU_MADE_A_PROGRAMMING_MISTAKE);
-
-  const SoftmaxOp op(beta);
-  return input.customOp(op);
-}
-
-}  // end namespace Eigen
-
-#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
diff --git a/tensorflow/core/kernels/eigen_softmax_test.cc b/tensorflow/core/kernels/eigen_softmax_test.cc
deleted file mode 100644
index 30a1ccc..0000000
--- a/tensorflow/core/kernels/eigen_softmax_test.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/kernels/eigen_softmax.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace Eigen {
-
-namespace {
-void EigenApprox(float a, float b) {
-  ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
-}
-}  // namespace
-
-TEST(EigenSoftmaxTest, Simple) {
-  const int depth = 1024;
-  const int batch = 32;
-  const float beta = 1.2f;
-
-  Tensor<float, 2> input(depth, batch);
-  input = input.constant(11.0f) + input.random();
-
-  Tensor<float, 2> reference(depth, batch);
-  reference.setRandom();
-
-  Eigen::array<int, 1> depth_dim;
-  depth_dim[0] = 0;
-  Eigen::array<int, 2> bcast;
-  bcast[0] = depth;
-  bcast[1] = 1;
-  Tensor<float, 2>::Dimensions dims2d;
-  dims2d[0] = 1;
-  dims2d[1] = batch;
-  reference =
-      ((input -
-        input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) *
-       beta)
-          .exp();
-  reference =
-      reference /
-      (reference.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast));
-
-  Tensor<float, 2> result = SoftMax(input, beta);
-
-  for (int i = 0; i < depth; ++i) {
-    for (int j = 0; j < batch; ++j) {
-      EigenApprox(result(i, j), reference(i, j));
-    }
-  }
-}
-
-}  // namespace Eigen
diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc
index b023f1c..12bbd34 100644
--- a/tensorflow/core/kernels/encode_proto_op.cc
+++ b/tensorflow/core/kernels/encode_proto_op.cc
@@ -303,7 +303,7 @@
 // code it ourselves.
 Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
                   int message_index, int size, CodedOutputStream* output) {
-  auto input_t = input.flat_inner_dims<string>();
+  auto input_t = input.flat_inner_dims<tstring>();
   for (int64 i = 0; i < size; i++) {
     const string& value = input_t(static_cast<int64>(message_index), i);
     WireFormatLite::WriteTag(field_desc.number(),
@@ -587,7 +587,7 @@
     Tensor* output_tensor;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor));
 
-    auto bufs = output_tensor->flat<string>();
+    auto bufs = output_tensor->flat<tstring>();
     for (int message_index = 0; message_index < message_count;
          message_index++) {
       // TODO(nix): possibly optimize allocation here by calling
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index 708b52a..783190b 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -63,10 +63,10 @@
 
     // Copy from OpInputList to std::vector<string>.
     for (int di = 0; di < attrs_.num_dense; ++di) {
-      dense_keys_t[di] = dense_keys[di].scalar<string>()();
+      dense_keys_t[di] = dense_keys[di].scalar<tstring>()();
     }
     for (int di = 0; di < attrs_.num_sparse; ++di) {
-      sparse_keys_t[di] = sparse_keys[di].scalar<string>()();
+      sparse_keys_t[di] = sparse_keys[di].scalar<tstring>()();
     }
 
     if (names->NumElements() > 0) {
@@ -234,7 +234,7 @@
       config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
     }
 
-    const string& serialized_proto = serialized->scalar<string>()();
+    const string& serialized_proto = serialized->scalar<tstring>()();
 
     OP_REQUIRES_OK(ctx,
                    FastParseSingleExample(config, serialized_proto, &result));
@@ -473,7 +473,7 @@
                       "Expected context_dense_keys[", di,
                       "] to be a scalar, got shape: ",
                       context_dense_keys[di].shape().DebugString()));
-      context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()();
+      context_dense_keys_t[di] = context_dense_keys[di].scalar<tstring>()();
     }
     for (int di = 0; di < attrs_.num_context_sparse; ++di) {
       OP_REQUIRES(ctx,
@@ -482,7 +482,7 @@
                       "Expected context_sparse_keys[", di,
                       "] to be a scalar, got shape: ",
                       context_sparse_keys[di].shape().DebugString()));
-      context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()();
+      context_sparse_keys_t[di] = context_sparse_keys[di].scalar<tstring>()();
     }
     for (int di = 0; di < attrs_.num_feature_list_dense; ++di) {
       OP_REQUIRES(
@@ -492,7 +492,7 @@
               "] to be a scalar, got shape: ",
               feature_list_dense_keys[di].shape().DebugString()));
       feature_list_dense_keys_t[di] =
-          feature_list_dense_keys[di].scalar<string>()();
+          feature_list_dense_keys[di].scalar<tstring>()();
     }
     for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) {
       OP_REQUIRES(
@@ -502,7 +502,7 @@
               "] to be a scalar, got shape: ",
               feature_list_sparse_keys[di].shape().DebugString()));
       feature_list_sparse_keys_t[di] =
-          feature_list_sparse_keys[di].scalar<string>()();
+          feature_list_sparse_keys[di].scalar<tstring>()();
     }
     OP_REQUIRES(
         ctx,
@@ -513,7 +513,7 @@
             "to be a vector, got shape: ",
             feature_list_dense_missing_assumed_empty->shape().DebugString()));
     auto feature_list_dense_missing_assumped_empty_t =
-        feature_list_dense_missing_assumed_empty->vec<string>();
+        feature_list_dense_missing_assumed_empty->vec<tstring>();
     for (int de = 0;
          de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) {
       feature_list_dense_missing_assumed_empty_set.insert(
@@ -527,7 +527,7 @@
                       "Expected debug_name to be a scalar, got shape: ",
                       debug_name->shape().DebugString()));
     }
-    auto debug_name_t = debug_name->scalar<string>();
+    auto debug_name_t = debug_name->scalar<tstring>();
 
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
                 errors::InvalidArgument(
@@ -561,7 +561,7 @@
       }
     }
 
-    auto serialized_t = serialized->scalar<string>();
+    auto serialized_t = serialized->scalar<tstring>();
 
     OpOutputList context_sparse_indices;
     OpOutputList context_sparse_values;
diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc
index 4d843ab..db1672e 100644
--- a/tensorflow/core/kernels/example_parsing_ops_test.cc
+++ b/tensorflow/core/kernels/example_parsing_ops_test.cc
@@ -114,7 +114,7 @@
     Example example;
     Filler fill;
     Tensor record_string(DT_STRING, TensorShape({batch_size}));
-    auto string_t = record_string.vec<string>();
+    auto string_t = record_string.vec<tstring>();
     example.Clear();
     for (int b = 0; b < batch_size; ++b) {
       for (int k = 0; k < num_keys; ++k) {
@@ -163,7 +163,7 @@
   Options opt;
   for (int i = 0; i < num_keys; ++i) {
     Tensor key(DT_STRING, TensorShape());
-    key.scalar<string>()() = strings::Printf("feature_%d", i);
+    key.scalar<tstring>()() = strings::Printf("feature_%d", i);
     switch (opt.benchmark_type) {
       case kDense:
         dense_keys.emplace_back(test::graph::Constant(g, key));
@@ -205,7 +205,7 @@
       Options::Store::GetSerializedExample()[std::make_tuple(1, num_keys,
                                                              feature_size)];
   Tensor serialized(DT_STRING, TensorShape());
-  serialized.scalar<string>()() = serialized_batch_1.vec<string>()(0);
+  serialized.scalar<tstring>()() = serialized_batch_1.vec<tstring>()(0);
 
   std::vector<string> sparse_keys;
   std::vector<string> dense_keys;
diff --git a/tensorflow/core/kernels/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/extract_jpeg_shape_op.cc
index ab42459..c74245d 100644
--- a/tensorflow/core/kernels/extract_jpeg_shape_op.cc
+++ b/tensorflow/core/kernels/extract_jpeg_shape_op.cc
@@ -41,7 +41,7 @@
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
                 errors::InvalidArgument("contents must be scalar, got shape ",
                                         contents.shape().DebugString()));
-    const StringPiece input = contents.scalar<string>()();
+    const StringPiece input = contents.scalar<tstring>()();
     OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
                 errors::InvalidArgument("JPEG contents are too large for int: ",
                                         input.size()));
diff --git a/tensorflow/core/kernels/fact_op.cc b/tensorflow/core/kernels/fact_op.cc
index 4a1aa43..6c11ab7 100644
--- a/tensorflow/core/kernels/fact_op.cc
+++ b/tensorflow/core/kernels/fact_op.cc
@@ -85,7 +85,7 @@
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(0, TensorShape({}), &output_tensor));
-    auto output = output_tensor->template scalar<string>();
+    auto output = output_tensor->template scalar<tstring>();
 
     string coded = facts[context->env()->NowMicros() % count];
     E(&coded);
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index e0f326d..fabd8e9 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -315,11 +315,9 @@
   ~CufftScratchAllocator() override {}
   CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
       : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
-    return memory_limit_;
-  }
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
     if (byte_size > memory_limit_) {
       return se::port::StatusOr<se::DeviceMemory<uint8>>();
diff --git a/tensorflow/core/kernels/fingerprint_op.cc b/tensorflow/core/kernels/fingerprint_op.cc
index 2052932..660f900 100644
--- a/tensorflow/core/kernels/fingerprint_op.cc
+++ b/tensorflow/core/kernels/fingerprint_op.cc
@@ -110,14 +110,14 @@
         // and each row contains the fingerprint value of corresponding string.
         // To compute fingerprints of multiple strings, this op fingerprints the
         // buffer containing the string fingerprints.
-        FarmhashFingerprint64(input.flat<string>(), temp.tensor<uint8, 2>());
+        FarmhashFingerprint64(input.flat<tstring>(), temp.tensor<uint8, 2>());
         FarmhashFingerprint64(static_cast<const Tensor&>(temp).shaped<uint8, 2>(
                                   {dim0, dim1 * kFingerprintSize}),
                               output->matrix<uint8>());
       } else {
         // In case dim1 == 1, each string computes into its own fingerprint
         // value. There is no need to fingerprint twice.
-        FarmhashFingerprint64(input.flat<string>(), output->matrix<uint8>());
+        FarmhashFingerprint64(input.flat<tstring>(), output->matrix<uint8>());
       }
     } else {
       auto data = input.bit_casted_shaped<uint8, 2>(
diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc
index 14376cb..d9a9a97 100644
--- a/tensorflow/core/kernels/fingerprint_op_test.cc
+++ b/tensorflow/core/kernels/fingerprint_op_test.cc
@@ -51,7 +51,7 @@
     inputs_.push_back(TensorValue(data));
 
     method_ = Tensor(DT_STRING, TensorShape{});
-    method_.scalar<string>()() = method;
+    method_.scalar<tstring>()() = method;
     inputs_.push_back(TensorValue(&method_));
     return Status::OK();
   }
@@ -77,7 +77,7 @@
 // special-case handling.
 TEST_F(FingerprintOpTest, StringGoldenValue) {
   Tensor data(DT_STRING, {1, 2, 2});
-  auto buffer = data.flat<string>();
+  auto buffer = data.flat<tstring>();
   buffer(0).resize(10);
   buffer(1).resize(7);
   buffer(2).resize(0);
@@ -134,7 +134,7 @@
   constexpr int64 size = 256;
 
   Tensor tensor(DT_STRING, {1});
-  auto& input = tensor.vec<string>()(0);
+  auto& input = tensor.vec<tstring>()(0);
   input.resize(size);
 
   TTypes<uint8>::UnalignedFlat buffer(reinterpret_cast<uint8*>(&*input.begin()),
@@ -163,7 +163,7 @@
   auto pods = pods_tensor.matrix<float>();
   pods.setRandom();
 
-  auto strings = strings_tensor.vec<string>();
+  auto strings = strings_tensor.vec<tstring>();
   for (int64 i = 0; i < strings.size(); ++i) {
     strings(i).assign(reinterpret_cast<const char*>(&pods(i, 0)),
                       pods.dimension(1) * sizeof(pods(i, 0)));
@@ -199,7 +199,7 @@
   ShapeInferenceTestOp op("Fingerprint");
 
   Tensor method(DT_STRING, TensorShape{});
-  method.scalar<string>()() = "farmhash64";
+  method.scalar<tstring>()() = "farmhash64";
   op.input_tensors.assign({nullptr, &method});
 
   TF_ASSERT_OK(MakeNodeDef(DT_UINT8, &op.node_def));
@@ -229,12 +229,12 @@
 
   // When `method` shape is unknown statically.
   Tensor method(DT_STRING, TensorShape{1});
-  method.vec<string>()(0) = "farmhash64";
+  method.vec<tstring>()(0) = "farmhash64";
   op.input_tensors.assign({nullptr, &method});
   INFER_ERROR("must be rank 0", op, "?;?");
 
   method = Tensor(DT_STRING, TensorShape{});
-  method.scalar<string>()() = "unsupported_method";
+  method.scalar<tstring>()() = "unsupported_method";
   op.input_tensors.assign({nullptr, &method});
   INFER_ERROR("unsupported_method", op, "?;?");
 }
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 33bed21..087ff2e 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -318,7 +318,7 @@
   string target_device;
   OP_REQUIRES_OK_ASYNC(
       ctx,
-      DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
+      DeviceNameUtils::CanonicalizeDeviceName(target->scalar<tstring>()(),
                                               source_device, &target_device),
       done);
 
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 920c14b..d7d15d5 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -82,7 +82,7 @@
         *v = t[0].scalar<bool>()();
         break;
       case DT_STRING:
-        *v = !t[0].scalar<string>()().empty();
+        *v = !t[0].scalar<tstring>()().empty();
         break;
       default:
         return errors::InvalidArgument(DataTypeString(t[0].dtype()),
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 70bd659..dd75b37 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -101,12 +101,11 @@
   explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
       : context_(context) {}
 
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
 
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     Tensor temporary_memory;
     const DataType tf_data_type = DataTypeToEnum<T>::v();
     int64 allocate_count =
@@ -155,12 +154,11 @@
   CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
       : context_(context), output_index_(output_index) {}
 
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
 
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     output_allocated = true;
     DCHECK(total_byte_size_ == 0)
         << "Reserve space allocator can only be called once";
diff --git a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc
index f72dfb3..35cd0fb 100644
--- a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc
@@ -51,7 +51,7 @@
   void FuzzImpl(const uint8_t* data, size_t size) final {
     // TODO(dga):  Test the batch case also.
     Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
-    input_tensor.scalar<string>()() =
+    input_tensor.scalar<tstring>()() =
         string(reinterpret_cast<const char*>(data), size);
     RunInputs({{"input", input_tensor}});
   }
diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h
index 4b036b1..dc0435c 100644
--- a/tensorflow/core/kernels/fuzzing/fuzz_session.h
+++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h
@@ -145,7 +145,7 @@
 class FuzzStringInputOp : public FuzzSession {
   void FuzzImpl(const uint8_t* data, size_t size) final {
     Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
-    input_tensor.scalar<string>()() =
+    input_tensor.scalar<tstring>()() =
         string(reinterpret_cast<const char*>(data), size);
     RunInputs({{"input", input_tensor}});
   }
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index 0ce4206..a71f290 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -61,7 +61,7 @@
 
     // Now we can do the actual fuzz implementation
     Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
-    input_tensor.scalar<string>()() = as_string;
+    input_tensor.scalar<tstring>()() = as_string;
     RunInputs({{"input", input_tensor}});
   }
 };
diff --git a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc
index b3b637b..d4e6418 100644
--- a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc
@@ -42,9 +42,9 @@
       if (delim_len > size) {
         delim_len = size - 1;
       }
-      delimiter_tensor.scalar<string>()() =
+      delimiter_tensor.scalar<tstring>()() =
           string(reinterpret_cast<const char*>(data), delim_len);
-      input_tensor.scalar<string>()() = string(
+      input_tensor.scalar<tstring>()() = string(
           reinterpret_cast<const char*>(data + delim_len), size - delim_len);
 
       RunInputs({{"input", input_tensor}, {"delimiter", delimiter_tensor}});
diff --git a/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc
index f7e3da8..367759d 100644
--- a/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc
@@ -46,10 +46,10 @@
       if (sep_len > size) {
         sep_len = size - 1;
       }
-      separator_tensor.scalar<string>()() =
+      separator_tensor.scalar<tstring>()() =
           string(reinterpret_cast<const char*>(data), sep_len);
-      input_tensor.scalar<string>()() = string(
-          reinterpret_cast<const char*>(data + sep_len), size - sep_len);
+      input_tensor.scalar<tstring>()() =
+          string(reinterpret_cast<const char*>(data + sep_len), size - sep_len);
 
       RunInputs({{"input", input_tensor}, {"separator", separator_tensor}});
     }
diff --git a/tensorflow/core/kernels/gather_functor_batched.cc b/tensorflow/core/kernels/gather_functor_batched.cc
new file mode 100644
index 0000000..0960b3a
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched.cc
@@ -0,0 +1,55 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Forward declarations of the functor specializations for GPU.
+#define DECLARE_GPU_SPECS_INDEX(T, Index)                               \
+  template <>                                                           \
+  int64 GatherFunctorBatched<GPUDevice, T, Index>::operator()(          \
+      OpKernelContext* ctx, typename TTypes<T, 4>::ConstTensor Tparams, \
+      typename TTypes<Index>::ConstFlat Tindices,                       \
+      typename TTypes<T, 4>::Tensor Tout);                              \
+  extern template struct GatherFunctorBatched<GPUDevice, T, Index>;
+
+#define DECLARE_GPU_SPECS(T)         \
+  DECLARE_GPU_SPECS_INDEX(T, int32); \
+  DECLARE_GPU_SPECS_INDEX(T, int64)
+
+TF_CALL_int64(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_complex64(DECLARE_GPU_SPECS);
+TF_CALL_complex128(DECLARE_GPU_SPECS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_INDEX
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#else
+
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/gather_functor_batched.h b/tensorflow/core/kernels/gather_functor_batched.h
new file mode 100644
index 0000000..fa9ac72
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched.h
@@ -0,0 +1,197 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/framework/bounds_check.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Helper method to copy using memcpy.
+template <typename T, typename Index, typename SliceIndex,
+          SliceIndex static_slice_elems>
+SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
+                               typename TTypes<T, 4>::ConstTensor params,
+                               typename TTypes<Index>::ConstFlat indices,
+                               SliceIndex slice_elems,
+                               typename TTypes<T, 4>::Tensor out) {
+  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
+  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
+  const SliceIndex indices_size =
+      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;
+
+  const Index limit = static_cast<Index>(params.dimension(2));
+  if (static_slice_elems >= 0) {
+    // Give compiler static knowledge of the number of elements/bytes
+    slice_elems = static_slice_elems;
+  }
+  // Compute slice_bytes here so that static knowledge is available
+  const size_t slice_bytes = slice_elems * sizeof(T);
+  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+  mutex mu;
+  // Store the value of invalidate index for printing error information, it's a
+  // shared variable.
+  SliceIndex result = -1;
+  auto work = [&](int64 start, int64 end) {
+    const int64 r_start = start % (outer_size * indices_size);
+    SliceIndex batch_idx = static_cast<SliceIndex>(
+        start / (outer_size * indices_size));
+    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
+    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);
+
+    SliceIndex batch_offset = batch_idx * indices_size;
+    for (; start < end; ++start) {
+      SliceIndex i_next = indices_idx + 1;
+      SliceIndex o_next = outer_idx;
+      SliceIndex b_next = batch_idx;
+      SliceIndex b_offset_next = batch_offset;
+
+      if (i_next >= indices_size) {
+        i_next = 0;
+        if (++o_next >= outer_size) {
+          o_next = 0;
+          ++b_next;
+          b_offset_next += indices_size;
+        }
+      }
+      if (start + 1 < end) {
+        port::prefetch<port::PREFETCH_HINT_T0>(
+            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
+        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
+      }
+      const Index index = internal::SubtleMustCopy(
+          indices(batch_offset + indices_idx));
+      if (!FastBoundsCheck(index, limit)) {
+        mutex_lock l(mu);
+        result = batch_offset + indices_idx;
+        return;
+      }
+
+      // Copy using memcpy if possible, otherwise an Eigen loop
+      // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
+      // ahead-of-time compilation binary size).
+      if (is_simple_type<T>::value) {
+        // Avoid auto-promotion to Index from SliceIndex by casting.
+        memcpy(
+            &out(batch_idx, outer_idx, indices_idx, 0),
+            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
+            slice_bytes);
+      } else {
+        // For non-"simple" types (e.g. strings).
+        out.template chip<2>(indices_idx) = params.template chip<2>(index);
+      }
+
+      indices_idx = i_next;
+      outer_idx = o_next;
+      batch_idx = b_next;
+      batch_offset = b_offset_next;
+    }
+  };
+
+  Shard(worker_threads->num_threads, worker_threads->workers,
+        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
+  return result;
+}
+
+template <typename T, typename Index>
+struct GatherFunctorBatchedCPU {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    const int64 indices_size = indices.size();  // Includes the batch_size.
+    const int64 slice_size = out.dimension(3);
+    int64 bad_i;
+
+    const int64 batch_size = params.dimension(0);
+    const int64 outer_size = params.dimension(1);
+
+    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
+                      params.size() > std::numeric_limits<int32>::max() ||
+                      indices_size > std::numeric_limits<int32>::max() ||
+                      batch_size * outer_size * indices_size * slice_size >
+                          std::numeric_limits<int32>::max());
+#define CALL(elems)                                                      \
+  do {                                                                   \
+    if (use_large) {                                                     \
+      bad_i = HandleCopiesBatched<T, Index, int64, elems>(               \
+          ctx, params, indices, slice_size, out);                        \
+    } else {                                                             \
+      const int32 small_slice = static_cast<int32>(slice_size);          \
+      bad_i = HandleCopiesBatched<T, Index, int32, elems>(               \
+          ctx, params, indices, small_slice, out);                       \
+    }                                                                    \
+  } while (0)
+
+    // TODO(rmlarsen): Investigate whether these specializations are still
+    // needed and, if yes, whether the slice sizes are apropriate.
+    if (slice_size == 10)
+      CALL(10);
+    else if (slice_size == 20)
+      CALL(20);
+    else
+      CALL(-1);
+#undef CALL
+
+    return bad_i;
+  }
+};
+
+template <typename Device, typename T, typename Index>
+struct GatherFunctorBatched {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out);
+};
+
+template <typename T, typename Index>
+struct GatherFunctorBatched<CPUDevice, T, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
+  }
+};
+
+template <typename Index>
+struct GatherFunctorBatched<GPUDevice, Variant, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<Variant, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<Variant, 4>::Tensor out) {
+    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
diff --git a/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc
new file mode 100644
index 0000000..f118d8d
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc
@@ -0,0 +1,46 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/gather_functor_batched_gpu.cu.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+  template struct functor::GatherFunctorBatched<GPUDevice, T, Index>
+
+#define DEFINE_GPU_SPECS(T)         \
+  DEFINE_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_bool(DEFINE_GPU_SPECS);
+TF_CALL_int32(DEFINE_GPU_SPECS);
+TF_CALL_int64(DEFINE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+TF_CALL_complex64(DEFINE_GPU_SPECS);
+TF_CALL_complex128(DEFINE_GPU_SPECS);
+
+#undef DEFINE_GPU_SPECS
+#undef DEFINE_GPU_SPECS_INDEX
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h
new file mode 100644
index 0000000..24c23f1
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h
@@ -0,0 +1,132 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T, typename Index,
+          bool is_axis_zero, bool is_batch_dims_zero>
+__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
+                               int64 outer_size,
+                               int64 gather_dim_size, int64 indices_size,
+                               int64 slice_size, int64 out_size) {
+  // params is a tensor of shape
+  // [batch_size, outer_size, gather_dim_size, slice_size].
+  GPU_1D_KERNEL_LOOP(i, out_size) {
+    Index batch_i = 0;  // The batch index into params to use for i.
+    Index outer_i = 0;  // The outer index into params to use for i.
+    Index indices_i = 0;  // The index into indices to use for i.
+    Index slice_i = 0;  // Index into the current slice in params to use for i.
+
+    const Index slices_count = i / slice_size;
+    if (is_batch_dims_zero) {
+      if (is_axis_zero) {
+        indices_i = slices_count;
+      } else {
+        outer_i = slices_count / indices_size;
+        indices_i = slices_count - outer_i * indices_size;
+      }
+    } else {
+      const Index entries_count = slices_count / indices_size;
+      if (is_axis_zero) {
+        batch_i = entries_count;
+      } else {
+        batch_i = entries_count / outer_size;
+        outer_i = entries_count - batch_i * outer_size;
+      }
+      indices_i = slices_count - entries_count * indices_size;
+    }
+    slice_i = i - slices_count * slice_size;
+
+    // Index into the gather axis to use for i.
+    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);
+
+    // Check gather_i is in [0, gather_dim_size).
+    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
+      // Set indices out of range to zero
+      // TODO(fpmc): Log an error for transfer back to host.
+      out[i] = T(0);
+    } else {
+      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
+      // i'th position in out.
+      Index params_i = (
+          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
+      ) * slice_size + slice_i;
+      out[i] = ldg(params + params_i);
+    }
+  }
+}
+
+namespace functor {
+template <typename T, typename Index>
+struct GatherFunctorBatched<GPUDevice, T, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    const GPUDevice& d = ctx->eigen_gpu_device();
+    const int64 out_size = out.size();
+    if (out_size == 0) {
+      // We need a check here since the CPU version does useful error checking
+      // work if there are nonempty indices but empty slices, so the kernel is
+      // executed in that case.  In the GPU case we don't know how to do error
+      // checking, so we skip the loop entirely.
+      return -1;
+    }
+    const bool is_batch_dims_zero = params.dimension(0) == 1;
+    const bool is_axis_zero = params.dimension(1) == 1;
+    const int64 outer_size = params.dimension(1);
+    const int64 gather_dim_size = params.dimension(2);
+    const int64 indices_size = indices.size() / params.dimension(0);
+    const int64 slice_size = params.dimension(3);
+
+    GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
+    const auto function = is_axis_zero ?
+          (is_batch_dims_zero ?
+            GatherOpKernel<T, Index, true, true>:
+            GatherOpKernel<T, Index, true, false>) :
+          (is_batch_dims_zero ?
+             GatherOpKernel<T, Index, false, true>:
+             GatherOpKernel<T, Index, false, false>);
+    TF_CHECK_OK(GpuLaunchKernel(
+        function, config.block_count, config.thread_per_block, 0, d.stream(),
+        params.data(), indices.data(), out.data(),
+        outer_size, gather_dim_size, indices_size, slice_size, out_size));
+    // TODO(fpmc): enable indices validation on GPU.
+    // Right now checking for indicies out of bound in the kernel would
+    // require copying code between GPU/CPU, and thus slow.
+    return -1;
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index 68c258d..38e0bab 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/gather_functor_batched.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/util.h"
@@ -123,16 +124,22 @@
     // The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
     // params.shape[axis + 1:].
     TensorShape result_shape;
+    int64 batch_size = 1;
     int64 outer_size = 1;
     int64 inner_size = 1;
-    for (int i = 0; i < axis; i++) {
+
+    for (int i = 0; i < batch_dims_; ++i) {
+      result_shape.AddDim(params.dim_size(i));
+      batch_size *= params.dim_size(i);
+    }
+    for (int i = batch_dims_; i < axis; ++i) {
       result_shape.AddDim(params.dim_size(i));
       outer_size *= params.dim_size(i);
     }
     for (int i = batch_dims_; i < indices.dims(); ++i) {
       result_shape.AddDim(indices.dim_size(i));
     }
-    for (int i = axis + 1; i < params.dims(); i++) {
+    for (int i = axis + 1; i < params.dims(); ++i) {
       result_shape.AddDim(params.dim_size(i));
       inner_size *= params.dim_size(i);
     }
@@ -141,60 +148,29 @@
     OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
     if (N == 0) return;
 
+    int64 bad_i = -1;
+    auto indices_flat = indices.flat<Index>();
     if (batch_dims_ > 0) {
-      // TODO(virimia): Switch to transpose / gather with axis=0 / transpose
-      // on GPU, to avoid launching a lot of small kernels.
+      auto params_flat = params.shaped<T, 4>(
+          {batch_size, outer_size, gather_dim_size, inner_size});
+      auto out_flat = out->shaped<T, 4>(
+          {batch_size, outer_size, N / batch_size, inner_size});
 
-      // To avoid copying params (by transposing), run gather for each batch.
-      int64 batch_size = 1;
-      for (int i = 0; i < batch_dims_; ++i) {
-        batch_size *= params.dim_size(i);
-      }
-      outer_size /= batch_size;
-      auto batched_params =
-          params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
-      auto batched_indices =
-          indices.shaped<Index, 2>({batch_size, N / batch_size});
-      auto batched_out =
-          out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});
-
-      // TODO(virimia): Investigate the best performance, when the number of
-      // batches is large, between parallel vs sequential runs.
-      for (int64 batch = 0; batch < batch_size; ++batch) {
-        auto params_flat = typename TTypes<T, 3>::ConstTensor(
-            &batched_params(batch, 0), static_cast<IndexType>(outer_size),
-            static_cast<IndexType>(gather_dim_size),
-            static_cast<IndexType>(inner_size));
-        auto indices_flat = typename TTypes<Index>::ConstFlat(
-            &batched_indices(batch, 0), batched_indices.dimension(1));
-        auto out_flat = typename TTypes<T, 3>::Tensor(
-            &batched_out(batch, 0), static_cast<IndexType>(outer_size),
-            static_cast<IndexType>(N), static_cast<IndexType>(inner_size));
-
-        functor::GatherFunctor<Device, T, Index> functor;
-        const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-        OP_REQUIRES(
-            c, bad_i < 0,
-            errors::InvalidArgument(
-                "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-                indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
-      }
+      functor::GatherFunctorBatched<Device, T, Index> functor;
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     } else {
       auto params_flat =
           params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
-      auto indices_flat = indices.flat<Index>();
       auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
 
       functor::GatherFunctor<Device, T, Index> functor;
-      const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     }
+    OP_REQUIRES(
+        c, bad_i < 0,
+        errors::InvalidArgument(
+            "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+            indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
   }
 
  private:
diff --git a/tensorflow/core/kernels/generate_vocab_remapping_op.cc b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
index 2b97677..03d9191 100644
--- a/tensorflow/core/kernels/generate_vocab_remapping_op.cc
+++ b/tensorflow/core/kernels/generate_vocab_remapping_op.cc
@@ -57,7 +57,7 @@
 
     // Build a new ID->token lookup table.
     const string& new_vocab_filename =
-        new_vocab_file_tensor->scalar<string>()();
+        new_vocab_file_tensor->scalar<tstring>()();
     OP_REQUIRES(context, !new_vocab_filename.empty(),
                 errors::InvalidArgument("new vocab filename cannot be empty."));
     lookup::HashTable<int64, string>* new_vocab_table =
@@ -88,7 +88,7 @@
                     old_vocab_file_tensor->shape().DebugString()));
     // Build a token->old ID lookup table.
     const string& old_vocab_filename =
-        old_vocab_file_tensor->scalar<string>()();
+        old_vocab_file_tensor->scalar<tstring>()();
     OP_REQUIRES(context, !old_vocab_filename.empty(),
                 errors::InvalidArgument("new vocab filename cannot be empty."));
     lookup::HashTable<string, int64>* old_vocab_table =
@@ -118,7 +118,7 @@
     OP_REQUIRES_OK(
         context, context->allocate_temp(
                      DT_STRING, TensorShape({num_new_vocab_}), &default_token));
-    auto default_token_vec = default_token.vec<string>();
+    auto default_token_vec = default_token.vec<tstring>();
     default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */);
 
     Tensor default_id;
diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc
index 318a917..68b069a 100644
--- a/tensorflow/core/kernels/gpu_utils.cc
+++ b/tensorflow/core/kernels/gpu_utils.cc
@@ -25,8 +25,71 @@
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/protobuf/conv_autotuning.pb.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
+#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
+#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
 
 namespace tensorflow {
+
+bool RedzoneCheckDisabled() {
+  const char* disable_rz_str = std::getenv("TF_DISABLE_RZ_CHECK");
+  return disable_rz_str != nullptr && std::strcmp(disable_rz_str, "1") == 0;
+}
+
+se::DeviceMemoryBase WrapRedzoneBestEffort(
+    se::cuda::RedzoneAllocator* rz_allocator, se::DeviceMemoryBase buffer) {
+  if (RedzoneCheckDisabled()) {
+    return buffer;
+  }
+  se::DeviceMemoryBase output_tensor;
+  auto output_rz_or = rz_allocator->AllocateBytes(buffer.size());
+  if (!output_rz_or.ok()) {
+    static std::once_flag rz_allocation_failure_logged;
+    std::call_once(rz_allocation_failure_logged, []() {
+      LOG(WARNING) << "Failed to allocate memory for convolution redzone "
+                   << "checking; skipping this check. This is benign and only "
+                   << "means that we won't check cudnn for out-of-bounds reads "
+                   << "and writes. This message will only be printed once.";
+    });
+    return buffer;
+  }
+  return se::DeviceMemoryBase(output_rz_or.ValueOrDie());
+}
+
+void CheckRedzones(const se::cuda::RedzoneAllocator& rz_allocator,
+                   tensorflow::AutotuneResult* autotune_result) {
+  se::port::StatusOr<se::cuda::RedzoneAllocator::RedzoneCheckStatus> rz_status =
+      rz_allocator.CheckRedzones();
+  if (!rz_status.ok()) {
+    static std::once_flag failure_logged;
+    std::call_once(failure_logged, [&]() {
+      LOG(WARNING) << "Failed to check cudnn convolutions for out-of-bounds "
+                   << "reads and writes with an error message: '"
+                   << rz_status.status().error_message()
+                   << "'; skipping this check. This only means that we won't "
+                   << "check cudnn for out-of-bounds reads and writes. This "
+                   << "message will only be printed once.";
+    });
+    return;
+  }
+  auto rz_check_status = rz_status.ValueOrDie();
+  if (!rz_check_status.ok()) {
+    auto* fail = autotune_result->mutable_failure();
+    fail->set_msg(rz_check_status.RedzoneFailureMsg());
+    fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
+    fail->set_buffer_address(
+        reinterpret_cast<uint64>(rz_check_status.user_buffer_address));
+    LOG(ERROR)
+        << "Detected cudnn out-of-bounds write in convolution buffer! This is "
+           "likely a cudnn bug. We will skip this algorithm in the future, but "
+           "your GPU state may already be corrupted, leading to incorrect "
+           "results. Within Google, no action is needed on your part. Outside "
+           "of Google, please ensure you're running the latest version of "
+           "cudnn. If that doesn't fix the problem, please file a bug with "
+           "this full error message and we'll contact nvidia.";
+    LOG(ERROR) << rz_check_status.RedzoneFailureMsg();
+  }
+}
+
 namespace {
 
 tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
@@ -85,6 +148,14 @@
   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
   log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
+  {
+    string blas_version;
+    if (auto* blas = stream_exec->AsBlas()) {
+      if (blas->GetVersion(&blas_version).ok()) {
+        log.set_blas_version(blas_version);
+      }
+    }
+  }
   for (const auto& result : results) {
     *log.add_results() = result;
   }
@@ -123,6 +194,14 @@
   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
   log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
+  {
+    string blas_version;
+    if (auto* blas = stream_exec->AsBlas()) {
+      if (blas->GetVersion(&blas_version).ok()) {
+        log.set_blas_version(blas_version);
+      }
+    }
+  }
   for (const auto& result : results) {
     *log.add_results() = result;
   }
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index 67e8963..b3ac953 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -29,11 +29,38 @@
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor.h"
 
+namespace stream_executor {
+namespace cuda {
+class RedzoneAllocator;
+}
+}  // namespace stream_executor
+
 namespace tensorflow {
 
 class NodeDef;
 class AutotuneResult;
 
+// Return whether the redzone check is disabled.
+//
+// Controlled by the TF_DISABLE_RZ_CHECK environment variable.
+bool RedzoneCheckDisabled();
+
+// Return an allocated buffer with redzones the size of `buffer`. Does
+// *not* copy the contents of the `buffer` into the newly allocated buffer:
+// assumes that buffer is a pure out-parameter.
+//
+// Returns `buffer` if RedzoneCheckDisabled() is true.
+//
+// On error, return `buffer`, and log an error message (once).
+se::DeviceMemoryBase WrapRedzoneBestEffort(
+    se::cuda::RedzoneAllocator* rz_allocator, se::DeviceMemoryBase buffer);
+
+// Check the passed allocator for redzone violations.
+// If violations have occurred, mark the corresponding autotune result
+// as a failure.
+void CheckRedzones(const se::cuda::RedzoneAllocator& rz_allocator,
+                   tensorflow::AutotuneResult* autotune_result);
+
 template <typename T>
 inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
   se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index c0d39d9..a6f0261 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -319,8 +319,8 @@
 void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i,
                              const Tensor& v, Tensor* y) {
   auto Ti = i.flat<int32>();
-  auto Tv = v.flat_outer_dims<string>();
-  auto Ty = y->flat_outer_dims<string>();
+  auto Tv = v.flat_outer_dims<tstring>();
+  auto Ty = y->flat_outer_dims<tstring>();
   auto nrows = Ty.dimension(0);
   for (int64 j = 0; j < Ti.size(); ++j) {
     auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
diff --git a/tensorflow/core/kernels/load_and_remap_matrix_op.cc b/tensorflow/core/kernels/load_and_remap_matrix_op.cc
index 9d5a4b2..cb0245a 100644
--- a/tensorflow/core/kernels/load_and_remap_matrix_op.cc
+++ b/tensorflow/core/kernels/load_and_remap_matrix_op.cc
@@ -123,12 +123,11 @@
     // Processes the checkpoint source and the provided Tensor name.
     const Tensor* ckpt_path_t;
     OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
-    const string ckpt_path = *(ckpt_path_t->scalar<string>().data());
+    const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
     const Tensor* old_tensor_name_t;
     OP_REQUIRES_OK(context,
                    context->input("old_tensor_name", &old_tensor_name_t));
-    const string old_tensor_name =
-        *(old_tensor_name_t->scalar<string>().data());
+    const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();
 
     LOG(INFO) << "Processing checkpoint : " << ckpt_path;
     BundleReader reader(context->env(), ckpt_path);
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index f93d324..e4d04c4 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -143,7 +143,7 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* input_;
     OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
-    const string& msg = input_->scalar<string>()();
+    const string& msg = input_->scalar<tstring>()();
 
     string ended_msg = strings::StrCat(msg, end_);
 
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index 6e77e1e..459a7b3 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -130,7 +130,7 @@
         errors::InvalidArgument("filename should be a single string, but got ",
                                 vocab_filename_tensor.shape().DebugString()));
 
-    string vocab_filename = vocab_filename_tensor.scalar<string>()();
+    const string& vocab_filename = vocab_filename_tensor.scalar<tstring>()();
     OP_REQUIRES(ctx, !vocab_filename.empty(),
                 errors::InvalidArgument("filename cannot be empty."));
 
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 28a3d94..28d63cb 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -92,7 +92,7 @@
                                                       cinfo_.name());
     } else {
       if (!table_handle_set_) {
-        auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
+        auto h = table_handle_.AccessTensor(ctx)->template flat<tstring>();
         h(0) = cinfo_.container();
         h(1) = cinfo_.name();
       }
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index c3b80f0..1fe7988 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -238,7 +238,7 @@
         tensor->flat<double>()(0) = value;
       } break;
       case DT_STRING:
-        tensor->flat<string>()(0) = token;
+        tensor->flat<tstring>()(0) = token;
         break;
       default:
         valid_ = false;
@@ -264,7 +264,7 @@
           "Lookup table handle must be scalar, but had shape: ",
           tensor.shape().DebugString());
     }
-    auto h = tensor.flat<string>();
+    auto h = tensor.flat<tstring>();
     *container = h(0);
     *table_handle = h(1);
   }
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index a199757..afe94ed 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -36,9 +36,17 @@
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
+#endif
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
+#if TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#endif
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/stream_executor_util.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
@@ -164,7 +172,7 @@
   T beta_;
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 struct LaunchLRN<GPUDevice, T> {
@@ -173,6 +181,7 @@
 
   void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
               Tensor* output) {
+#if GOOGLE_CUDA
     OP_REQUIRES(
         context, beta_ >= 0.01,
         errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
@@ -217,6 +226,71 @@
             .ok();
     OP_REQUIRES(context, status,
                 errors::Internal("NormalizeWithDimensions launch failed"));
+#elif TENSORFLOW_USE_ROCM
+    // For NHWC input/output tensors, convert to NCHW because it's the only
+    // supported format in MIOpen for now.
+
+    // Cast to platform-specific int to avoid conversion warnings.
+    const int batch = static_cast<int>(in.dim_size(0));
+    const int rows = static_cast<int>(in.dim_size(1));
+    const int cols = static_cast<int>(in.dim_size(2));
+    const int depth = static_cast<int>(in.dim_size(3));
+
+    Tensor transformed_input;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       ShapeFromFormat(FORMAT_NCHW, in.shape(), FORMAT_NHWC),
+                       &transformed_input));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in.tensor<T, 4>(),
+                                           transformed_input.tensor<T, 4>());
+
+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
+                     &transformed_output));
+
+    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+    dimensions_desc.set_count(batch)
+        .set_height(rows)
+        .set_width(cols)
+        .set_feature_map_count(depth)
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+    normalize_desc.set_bias(bias_)
+        .set_range(depth_radius_)
+        .set_alpha(alpha_)
+        .set_beta(beta_);
+
+    auto input_data =
+        AsDeviceMemory(transformed_input.template flat<T>().data(),
+                       transformed_input.template flat<T>().size());
+    auto output_data =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    bool status =
+        stream
+            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
+                                          input_data, &output_data)
+            .ok();
+    OP_REQUIRES(context, status,
+                errors::Internal("NormalizeWithDimensions launch failed"));
+
+    // Need to convert it back to NHWC once MIOpen kernels finishes.
+    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+    functor::NCHWToNHWC<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(),
+        toConstTensor(transformed_output).template tensor<T, 4>(),
+        output->tensor<T, 4>());
+#endif
   }
 
   int depth_radius_;
@@ -225,7 +299,7 @@
   T beta_;
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class LRNOp : public OpKernel {
@@ -292,7 +366,7 @@
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T)                                      \
   REGISTER_KERNEL_BUILDER(                                   \
@@ -302,7 +376,7 @@
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #if !defined(IS_MOBILE_PLATFORM)
 
@@ -390,7 +464,7 @@
   T alpha_beta_2_;
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 struct LaunchLRNGrad<GPUDevice, T> {
@@ -400,6 +474,7 @@
   void launch(OpKernelContext* context, OpKernel* kernel,
               const Tensor& in_grads, const Tensor& in_image,
               const Tensor& out_image, Tensor* output) {
+#if GOOGLE_CUDA
     OP_REQUIRES(
         context, beta_ >= 0.01,
         errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
@@ -447,6 +522,105 @@
     OP_REQUIRES(
         context, status,
         errors::Internal("NormalizeBackwardWithDimensions launch failed"));
+#elif TENSORFLOW_USE_ROCM
+    // For NHWC input/output tensors, convert to NCHW because it's the only
+    // supported format in MIOpen for now.
+    const int64 batch = in_grads.dim_size(0);
+    const int64 rows = in_grads.dim_size(1);
+    const int64 cols = in_grads.dim_size(2);
+    const int64 depth = in_grads.dim_size(3);
+
+    Tensor transformed_in_grads;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, in_grads.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_in_grads));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in_grads.tensor<T, 4>(),
+                                           transformed_in_grads.tensor<T, 4>());
+
+    Tensor transformed_in_image;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, in_image.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_in_image));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in_image.tensor<T, 4>(),
+                                           transformed_in_image.tensor<T, 4>());
+
+    Tensor transformed_out_image;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, out_image.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_out_image));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(), out_image.tensor<T, 4>(),
+        transformed_out_image.tensor<T, 4>());
+
+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
+                     &transformed_output));
+
+    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+    dimensions_desc.set_count(batch)
+        .set_height(rows)
+        .set_width(cols)
+        .set_feature_map_count(depth)
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+    normalize_desc.set_bias(bias_)
+        .set_range(depth_radius_)
+        .set_alpha(alpha_)
+        .set_beta(beta_);
+
+    auto input_grads_data =
+        AsDeviceMemory(transformed_in_grads.template flat<T>().data(),
+                       transformed_in_grads.template flat<T>().size());
+    auto input_image_data =
+        AsDeviceMemory(transformed_in_image.template flat<T>().data(),
+                       transformed_in_image.template flat<T>().size());
+    auto output_image_data =
+        AsDeviceMemory(transformed_out_image.template flat<T>().data(),
+                       transformed_out_image.template flat<T>().size());
+    auto output_grads_data =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    static int64 NormalizeBackwardScratchSize = GetDnnWorkspaceLimit(
+        // default value is in bytes despite the name of the environment
+        // variable
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
+    );
+
+    DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize,
+                                          context);
+    bool status = stream
+                      ->ThenNormalizeBackwardWithDimensions(
+                          normalize_desc, dimensions_desc, input_image_data,
+                          output_image_data, input_grads_data,
+                          &output_grads_data, &scratch_allocator)
+                      .ok();
+    OP_REQUIRES(
+        context, status,
+        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
+
+    // Need to convert it back to NHWC once MIOpen kernels finishes.
+    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+    functor::NCHWToNHWC<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(),
+        toConstTensor(transformed_output).template tensor<T, 4>(),
+        output->tensor<T, 4>());
+#endif
   }
 
   int depth_radius_;
@@ -455,7 +629,7 @@
   T beta_;
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class LRNGradOp : public OpKernel {
@@ -524,7 +698,7 @@
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T)                                          \
   REGISTER_KERNEL_BUILDER(                                       \
@@ -534,7 +708,7 @@
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // !defined(IS_MOBILE_PLATFORM)
 
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc
index 7912ca1..0ba718c 100644
--- a/tensorflow/core/kernels/matching_files_op.cc
+++ b/tensorflow/core/kernels/matching_files_op.cc
@@ -40,7 +40,7 @@
         errors::InvalidArgument(
             "Input patterns tensor must be scalar or vector, but had shape: ",
             patterns_t->shape().DebugString()));
-    const auto patterns = patterns_t->flat<string>();
+    const auto patterns = patterns_t->flat<tstring>();
     int num_patterns = patterns.size();
     int num_files = 0;
     std::vector<std::vector<string>> all_fnames(num_patterns);
@@ -53,7 +53,7 @@
     OP_REQUIRES_OK(
         context, context->allocate_output("filenames", TensorShape({num_files}),
                                           &output_t));
-    auto output = output_t->vec<string>();
+    auto output = output_t->vec<tstring>();
     int index = 0;
     for (int i = 0; i < num_patterns; ++i) {
       for (int j = 0; j < all_fnames[i].size(); j++) {
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index bc7eb49..16fb29f 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -83,7 +83,7 @@
     const ConstMatrixMap& rhs = inputs[1];
     MatrixMap& output = outputs->at(0);
 
-    if (matrix.rows() == 0 || rhs.cols() == 0) {
+    if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
       // To be consistent with the MatrixInverse op, we define the solution for
       // an empty set of equation as the empty matrix.
       return;
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index aa4254d..fa3264d 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -357,7 +357,8 @@
   }
 };
 
-template <typename Device, class T, bool bias_enabled, bool is_depthwise>
+template <typename Device, class T, bool bias_enabled, bool is_depthwise,
+          bool eager_mode>
 class MklConvCustomBackpropFilterOp
     : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
  public:
@@ -382,9 +383,9 @@
       const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
-      GetMklShape(context, kInputIdx, &src_mkl_shape);
-      GetMklShape(context, kFilterIdx, &filter_mkl_shape);
-      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+      GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
+      GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
+      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
       // Allow operator-specific sanity checking of shapes.
       ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
 
@@ -395,7 +396,8 @@
       // allow this class to handle this case.
       TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
       TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
-      TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+      TensorShape diff_dst_tf_shape =
+          GetTfShape(context, kOutbpropIdx, eager_mode);
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* diff_filter_tensor = nullptr;
@@ -408,7 +410,8 @@
             GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
         const int kOutputIdx = 0;
         AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
-                                  diff_filter_tf_shape, diff_filter_mkl_shape);
+                                  diff_filter_tf_shape, diff_filter_mkl_shape,
+                                  eager_mode);
         CHECK_NOTNULL(diff_filter_tensor);
 
         // if output tensor has more than 0 elements, we need to 0 them out.
@@ -493,8 +496,8 @@
                bwd_output_dims[MklDnnDims::Dim_I],
                bwd_output_dims[MklDnnDims::Dim_O]});
           AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
-                                    diff_filter_tf_shape,
-                                    diff_filter_mkl_shape);
+                                    diff_filter_tf_shape, diff_filter_mkl_shape,
+                                    eager_mode);
         } else {
           // Depthwise Conv2d: bwd_output_dims is GOIHW format
           //                  | TensorFlow       | MKLDNN
@@ -620,7 +623,7 @@
   TensorShape MakeInputTfShape(OpKernelContext* context,
                                const Tensor& input_tensor) {
     size_t input_idx = 0;
-    return GetTfShape(context, input_idx);
+    return GetTfShape(context, input_idx, eager_mode);
   }
 
   // Get TensorFlow shape of filter tensor.
@@ -699,37 +702,43 @@
   }
 };
 
-#define REGISTER_MKL_FILTER_KERNELS(T)                            \
-  REGISTER_KERNEL_BUILDER(                                        \
-      Name("_MklConv2DBackpropFilter")                            \
-          .Device(DEVICE_CPU)                                     \
-          .TypeConstraint<T>("T")                                 \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
-      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>); \
-  REGISTER_KERNEL_BUILDER(                                        \
-      Name("_MklConv2DBackpropFilterWithBias")                    \
-          .Device(DEVICE_CPU)                                     \
-          .TypeConstraint<T>("T")                                 \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
-      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false>);  \
-  REGISTER_KERNEL_BUILDER(                                        \
-      Name("_MklDepthwiseConv2dNativeBackpropFilter")             \
-          .Device(DEVICE_CPU)                                     \
-          .TypeConstraint<T>("T")                                 \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
-      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true>);  \
-  REGISTER_KERNEL_BUILDER(                                        \
-      Name("__MklDummyConv2DBackpropFilterWithBias")              \
-          .Device(DEVICE_CPU)                                     \
-          .TypeConstraint<T>("T")                                 \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
-      MklDummyOp<CPUDevice, T>);                                  \
-  REGISTER_KERNEL_BUILDER(                                        \
-      Name("_MklConv3DBackpropFilterV2")                          \
-          .Device(DEVICE_CPU)                                     \
-          .TypeConstraint<T>("T")                                 \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),    \
-      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>);
+#define REGISTER_MKL_FILTER_KERNELS(T)                                   \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("_MklConv2DBackpropFilter")                                   \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
+      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>); \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("_MklEagerConv2DBackpropFilter")                              \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
+      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>);  \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("_MklConv2DBackpropFilterWithBias")                           \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
+      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, false>);  \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("_MklDepthwiseConv2dNativeBackpropFilter")                    \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
+      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, false>);  \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("__MklDummyConv2DBackpropFilterWithBias")                     \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
+      MklDummyOp<CPUDevice, T>);                                         \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("_MklConv3DBackpropFilterV2")                                 \
+          .Device(DEVICE_CPU)                                            \
+          .TypeConstraint<T>("T")                                        \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
+      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>);
 
 TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
 TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index e23e099..943f498 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -295,7 +295,7 @@
   }
 };
 
-template <typename Device, class T, bool is_depthwise>
+template <typename Device, class T, bool is_depthwise, bool eager_mode>
 class MklConvCustomBackpropInputOp
     : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
  public:
@@ -319,9 +319,9 @@
       const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
-      GetMklShape(context, kInputIdx, &src_mkl_shape);
-      GetMklShape(context, kFilterIdx, &filter_mkl_shape);
-      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+      GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
+      GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
+      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
       // Allow operator-specific sanity checking of shapes.
       ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
 
@@ -332,7 +332,8 @@
       // allow this class to handle this case.
       TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
       TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
-      TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+      TensorShape diff_dst_tf_shape =
+          GetTfShape(context, kOutbpropIdx, eager_mode);
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* diff_src_tensor = nullptr;
@@ -345,7 +346,8 @@
             GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
         const int kOutputIdx = 0;
         AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
-                                  diff_src_tf_shape, diff_src_mkl_shape);
+                                  diff_src_tf_shape, diff_src_mkl_shape,
+                                  eager_mode);
         CHECK_NOTNULL(diff_src_tensor);
 
         // if output tensor has more than 0 elements, we need to 0 them out.
@@ -429,9 +431,13 @@
                                      bwd_diff_src_dims, bwd_diff_src_format);
       TensorShape diff_src_tf_shape;
       diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T));
+      Tensor tmp_tensor;
+      if (eager_mode) {
+        AllocTmpBuffer<T>(context, &tmp_tensor, diff_src_tf_shape);
+        diff_src_tf_shape = diff_src_mkl_shape.GetTfShape();
+      }
       AllocateOutputSetMklShape(context, 0, &diff_src_tensor, diff_src_tf_shape,
-                                diff_src_mkl_shape);
-
+                                diff_src_mkl_shape, eager_mode);
       T* diff_src_data =
           static_cast<T*>(const_cast<T*>(diff_src_tensor->flat<T>().data()));
 
@@ -458,7 +464,25 @@
       }
 
       // execute convolution input bwd
-      conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+      if (!eager_mode) {
+        conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+      } else {
+        // In eager mode we first write the output to temporary
+        // buffer in MKL format. Then we convert the data to TF format.
+        T* tmp_data =
+            static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
+        conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
+        auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
+        auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
+        mkldnn::reorder::primitive_desc reorder_pd =
+            mkldnn::reorder::primitive_desc(diff_src_pd, output_tf_pd);
+        std::vector<mkldnn::primitive> net;
+        memory* tmp_data_mem = new memory(diff_src_pd, tmp_data);
+        memory* dst_data_mem = new memory(output_tf_pd, diff_src_data);
+        net.push_back(
+            mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
+        stream(stream::kind::eager).submit(net).wait();
+      }
 
       // delete primitive since it is not cached.
       if (do_not_cache) {
@@ -506,7 +530,7 @@
   // Get TensorFlow shape of filter tensor.
   TensorShape MakeFilterTfShape(OpKernelContext* context,
                                 const Tensor& filter_tensor) {
-    return GetTfShape(context, kInputIndex_Filter);
+    return GetTfShape(context, kInputIndex_Filter, eager_mode);
   }
 
   // Get the Tensorflow shape of Output (diff_src),
@@ -557,26 +581,31 @@
   }
 };
 
-#define REGISTER_MKL_CPU_KERNELS(T)                            \
-  REGISTER_KERNEL_BUILDER(                                     \
-      Name("_MklConv2DBackpropInput")                          \
-          .Device(DEVICE_CPU)                                  \
-          .TypeConstraint<T>("T")                              \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
-      MklConvCustomBackpropInputOp<CPUDevice, T, false>);      \
-  REGISTER_KERNEL_BUILDER(                                     \
-      Name("_MklConv3DBackpropInputV2")                        \
-          .Device(DEVICE_CPU)                                  \
-          .TypeConstraint<T>("T")                              \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
-      MklConvCustomBackpropInputOp<CPUDevice, T, false>);      \
-  REGISTER_KERNEL_BUILDER(                                     \
-      Name("_MklDepthwiseConv2dNativeBackpropInput")           \
-          .Device(DEVICE_CPU)                                  \
-          .TypeConstraint<T>("T")                              \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
-      MklConvCustomBackpropInputOp<CPUDevice, T, true>);
-
+#define REGISTER_MKL_CPU_KERNELS(T)                              \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("_MklConv2DBackpropInput")                            \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<T>("T")                                \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
+      MklConvCustomBackpropInputOp<CPUDevice, T, false, false>); \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("_MklEagerConv2DBackpropInput")                       \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<T>("T")                                \
+          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
+      MklConvCustomBackpropInputOp<CPUDevice, T, false, true>);  \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("_MklConv3DBackpropInputV2")                          \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<T>("T")                                \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
+      MklConvCustomBackpropInputOp<CPUDevice, T, false, false>); \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("_MklDepthwiseConv2dNativeBackpropInput")             \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<T>("T")                                \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
+      MklConvCustomBackpropInputOp<CPUDevice, T, true, false>);
 TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS);
 #undef REGISTER_MKL_CPU_KERNELS
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 14344da..433659a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -47,13 +47,96 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
+using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
-using mkldnn::convolution_forward;
-using mkldnn::convolution_direct;
 
 namespace tensorflow {
 
+#ifdef ENABLE_MKLDNN_V1
+#define ADD_MD add_md
+#define ALGORITHM mkldnn::algorithm
+#define ALGORITHM_UNDEF ALGORITHM::undef
+#define CPU_STREAM(engine) stream(engine)
+#define DATA_WITH_ENGINE(data, engine) data, engine
+#define DST_MD dst_md
+#define ENGINE_CPU engine::kind::cpu
+#define GET_DESC get_desc()
+#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
+  { {dims}, MklDnnType<type>(), memory::format_tag::fm }
+#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc()
+#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
+#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
+  GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
+#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
+  filter_md != op_pd->weights_desc()
+#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
+  src_md != op_pd->src_desc()
+#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \
+  memory(mem_desc, engine, data)
+#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
+  memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data)
+#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \
+  memory(mem_desc, engine)
+#define MEMORY_DESC memory::desc
+#define MEMORY_FORMAT mkldnn::memory::format_tag
+#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
+  memory::desc({dims}, MklDnnType<type>(), memory::format_tag::fm)
+#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
+#define MKL_TENSOR_FORMAT MklTensorFormat
+#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
+#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT
+#define PRIMITIVE_DESC_BIAS bias_desc()
+#define PRIMITIVE_DESC_DST dst_desc()
+#define PRIMITIVE_DESC_SRC src_desc()
+#define PRIMITIVE_DESC_WEIGHTS weights_desc()
+#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \
+  mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md)
+#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \
+  mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md, prim_attr)
+#define SUMMAND_MD summand_md
+#else
+#define ADD_MD add_pd
+#define ALGORITHM mkldnn
+#define ALGORITHM_UNDEF ALGORITHM::algorithm_undef
+#define CPU_STREAM(engine) stream(stream::kind::eager)
+#define DATA_WITH_ENGINE(data, engine) data
+#define DST_MD dst_pd
+#define ENGINE_CPU engine::cpu
+#define GET_DESC get_primitive_desc()
+#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
+  { {dims}, MklDnnType<type>(), memory::format::fm }
+#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc()
+#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
+#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
+  op_primitive->GetFilterMemoryFormat()
+#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
+  filter_md.data.format != op_primitive->GetFilterMemoryFormat()
+#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
+  src_md.data.format != op_primitive->GetSrcMemoryFormat()
+#define MEMORY_CONSTRUCTOR(mem_pd, engine, data) memory(mem_pd, data)
+#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
+  memory({GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine}, data)
+#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_pd, engine) memory(mem_pd)
+#define MEMORY_DESC memory::format
+#define MEMORY_FORMAT mkldnn::memory::format
+#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
+  memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
+#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
+#define MKL_TENSOR_FORMAT memory::format
+#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
+#define MKL_TENSOR_FORMAT_IN_C mkldnn_memory_format_t
+#define PRIMITIVE_DESC_BIAS bias_primitive_desc()
+#define PRIMITIVE_DESC_DST dst_primitive_desc()
+#define PRIMITIVE_DESC_SRC src_primitive_desc()
+#define PRIMITIVE_DESC_WEIGHTS weights_primitive_desc()
+#define REORDER_PD_CONSTRUCTOR(src_pd, dst_pd, engine) \
+  mkldnn::reorder::primitive_desc(src_pd, dst_pd)
+#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_pd, dst_pd, engine, prim_attr) \
+  mkldnn::reorder::primitive_desc(src_pd, dst_pd, prim_attr)
+#define SUMMAND_MD summand_pd
+#endif  // ENABLE_MKLDNN_V1
+
 // This structure aggregates multiple inputs to Conv2DFwd* methods.
 struct MklConvFwdParams {
   memory::dims src_dims;
@@ -94,9 +177,9 @@
 class MklConvFwdPrimitive : public MklPrimitive {
  public:
   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-      : cpu_engine_(engine::cpu, 0) {
-    context_.fwd_stream.reset(new stream(stream::kind::eager));
-    // Create conv primitive
+      : cpu_engine_(ENGINE_CPU, 0) {
+    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+    // Create convolution primitive
     if (context_.conv_fwd == nullptr) {
       Setup(convFwdDims);
     }
@@ -115,19 +198,30 @@
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.filter_mem->set_data_handle(
         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
-    context_.bias_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    if (bias_data != nullptr) {
+      context_.bias_mem->set_data_handle(
+          static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    }
     context_.dst_mem->set_data_handle(
         static_cast<void*>(const_cast<Toutput*>(dst_data)));
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK_EQ(context_.fwd_primitives.size(),
+              context_.fwd_primitives_args.size());
+    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
+      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+                                            context_.fwd_primitives_args.at(i));
+    }
+#else
     context_.fwd_stream->submit(context_.fwd_primitives);
+#endif  // ENABLE_MKLDNN_V1
 
-    // After exec, set data handle back
+    // After execution, set data handle back
     context_.src_mem->set_data_handle(DummyData);
     context_.filter_mem->set_data_handle(DummyData);
-    context_.bias_mem->set_data_handle(DummyData);
+    if (bias_data != nullptr) {
+      context_.bias_mem->set_data_handle(DummyData);
+    }
     context_.dst_mem->set_data_handle(DummyData);
-
-    return;
   }
 
   // Convolution forward execute without bias
@@ -136,23 +230,15 @@
   //   dst_data:    output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
                const Toutput* dst_data) {
-    context_.src_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tinput*>(src_data)));
-    context_.filter_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
-    context_.dst_mem->set_data_handle(
-        static_cast<void*>(const_cast<Toutput*>(dst_data)));
-    context_.fwd_stream->submit(context_.fwd_primitives);
-
-    // After execution, set data handle back
-    context_.src_mem->set_data_handle(DummyData);
-    context_.filter_mem->set_data_handle(DummyData);
-    context_.dst_mem->set_data_handle(DummyData);
+    Execute(src_data, filter_data, nullptr, dst_data);
   }
 
+#ifndef ENABLE_MKLDNN_V1
+  // In MKL-DNN v1.x, memory format tags only provide a partial description
+  // of the memory layout. Hence, these functions are disabled for v1.x.
   memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
-
   memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
+#endif  // !ENABLE_MKLDNN_V1
 
   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
     return context_.fwd_pd;
@@ -161,17 +247,19 @@
  private:
   // Primitive reuse context for Conv2D Fwd op
   struct ConvFwdContext {
+#ifndef ENABLE_MKLDNN_V1
     // Expected memory format for this primitive instance
     memory::format src_fmt;
     memory::format filter_fmt;
+#endif  // !ENABLE_MKLDNN_V1
 
-    // MKLDNN memory
+    // MKL-DNN memory
     std::shared_ptr<mkldnn::memory> src_mem;
     std::shared_ptr<mkldnn::memory> filter_mem;
     std::shared_ptr<mkldnn::memory> bias_mem;
     std::shared_ptr<mkldnn::memory> dst_mem;
 
-    // Desc & prmitive desc
+    // Desc & primitive desc
     std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
 
     // Memory desc
@@ -187,9 +275,16 @@
     std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 
+#ifdef ENABLE_MKLDNN_V1
+    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
+#endif  // ENABLE_MKLDNN_V1
+
     ConvFwdContext()
-        : src_fmt(memory::format::any),
+        :
+#ifndef ENABLE_MKLDNN_V1
+          src_fmt(memory::format::any),
           filter_fmt(memory::format::any),
+#endif  // !ENABLE_MKLDNN_V1
           src_mem(nullptr),
           filter_mem(nullptr),
           bias_mem(nullptr),
@@ -200,34 +295,35 @@
           bias_md(nullptr),
           fwd_pd(nullptr),
           conv_fwd(nullptr),
-          fwd_stream(nullptr) {}
+          fwd_stream(nullptr) {
+    }
   };
 
   void Setup(const MklConvFwdParams& convFwdDims) {
     // Create memory descriptors for convolution data w/ no specified format
     context_.src_md.reset(new memory::desc(
-        {convFwdDims.src_dims}, MklDnnType<Tinput>(), memory::format::any));
+        {convFwdDims.src_dims}, MklDnnType<Tinput>(), MEMORY_FORMAT::any));
 
     context_.filter_md.reset(new memory::desc(
-        {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), memory::format::any));
+        {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
 
     context_.dst_md.reset(new memory::desc(
-        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), memory::format::any));
+        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), MEMORY_FORMAT::any));
 
     if (!convFwdDims.bias_dims.empty())
       context_.bias_md.reset(new memory::desc(
-          {convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::any));
+          {convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
 
-    // Create a convolution
+    // Create a convolution descriptor
     if (!convFwdDims.bias_dims.empty()) {
       context_.fwd_desc.reset(new convolution_forward::desc(
-          prop_kind::forward, convolution_direct, *context_.src_md,
+          prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
           *context_.filter_md, *context_.bias_md, *context_.dst_md,
           convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
           convFwdDims.padding_right, padding_kind::zero));
     } else {
       context_.fwd_desc.reset(new convolution_forward::desc(
-          prop_kind::forward, convolution_direct, *context_.src_md,
+          prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
           *context_.filter_md, *context_.dst_md, convFwdDims.strides,
           convFwdDims.dilations, convFwdDims.padding_left,
           convFwdDims.padding_right, padding_kind::zero));
@@ -246,7 +342,12 @@
           float op_scale = post_op_param.param[0];
           float op_alpha = post_op_param.param[1];
           float op_beta = post_op_param.param[2];
+#ifdef ENABLE_MKLDNN_V1
+          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
+                                  op_alpha,
+#else
           post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
+#endif  // ENABLE_MKLDNN_V1
                                   op_beta);
         } else if (post_op_param.name == "sum") {
           DCHECK_EQ(post_op_param.param.size(), 1);
@@ -271,27 +372,44 @@
       context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
     }
 
+#ifndef ENABLE_MKLDNN_V1
     // Store the expected memory format
     context_.src_fmt = static_cast<mkldnn::memory::format>(
         context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
 
     context_.filter_fmt = static_cast<mkldnn::memory::format>(
         context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
+#endif  // !ENABLE_MKLDNN_V1
 
     // Create memory primitive based on dummy data
-    context_.src_mem.reset(
-        new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
-    context_.filter_mem.reset(
-        new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
-    context_.dst_mem.reset(
-        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+    context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
+    context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
+    context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
 
     // Create convolution primitive and add it to net
     if (!convFwdDims.bias_dims.empty()) {
-      context_.bias_mem.reset(new memory(
-          {{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::x},
-           cpu_engine_},
-          DummyData));
+      context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
+          convFwdDims.bias_dims, Tbias, x, cpu_engine_, DummyData));
+#ifdef ENABLE_MKLDNN_V1
+      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
+      context_.fwd_primitives_args.push_back(
+          {{MKLDNN_ARG_SRC, *context_.src_mem},
+           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
+           {MKLDNN_ARG_BIAS, *context_.bias_mem},
+           { MKLDNN_ARG_DST,
+             *context_.dst_mem }});
+    } else {
+      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
+      context_.fwd_primitives_args.push_back(
+          {{MKLDNN_ARG_SRC, *context_.src_mem},
+           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
+           { MKLDNN_ARG_DST,
+             *context_.dst_mem }});
+    }
+#else
       context_.conv_fwd.reset(new convolution_forward(
           *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
           *context_.bias_mem, *context_.dst_mem));
@@ -300,9 +418,8 @@
           new convolution_forward(*context_.fwd_pd, *context_.src_mem,
                                   *context_.filter_mem, *context_.dst_mem));
     }
-
+#endif  // ENABLE_MKLDNN_V1
     context_.fwd_primitives.push_back(*context_.conv_fwd);
-    return;
   }
 
   struct ConvFwdContext context_;
@@ -401,7 +518,8 @@
 // Base class for convolution forward operations
 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
           typename Toutput, typename Ttemp_output, typename Tpadding,
-          bool bias_enabled, bool pad_enabled, bool is_depthwise>
+          bool bias_enabled, bool pad_enabled, bool is_depthwise,
+          bool eager_mode>
 class MklConvOp : public OpKernel {
  public:
   ~MklConvOp() {}
@@ -428,8 +546,10 @@
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     is_filter_const_ = false;
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("is_filter_const", &is_filter_const_));
+    if (context->HasAttr("is_filter_const")) {
+      OP_REQUIRES_OK(context,
+                     context->GetAttr("is_filter_const", &is_filter_const_));
+    }
 
     if (strides_.size() == 4) {
       OP_REQUIRES(context, dilations_.size() == 4,
@@ -472,8 +592,9 @@
       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape;
-      GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
-      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
+      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, eager_mode);
+      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, eager_mode);
+
       OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
                   errors::InvalidArgument("Filter should not be in "
                                           "Mkl Layout"));
@@ -503,8 +624,9 @@
       // Get shapes of input tensors in MKL-DNN order
       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                               dilations_);
-      auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
-      auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
+      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, eager_mode);
+      auto filter_tf_shape =
+          GetTfShape(context, kInputIndex_Filter, eager_mode);
       conv_utl.GetConvFwdSizesInMklOrder(
           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
@@ -517,15 +639,17 @@
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* dst_tensor = nullptr;
+      Tensor tmp_tensor;
       bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
                                  typeid(Tinput) == typeid(Toutput) &&
                                  (typeid(Tinput) == typeid(float) ||
-                                  typeid(Tinput) == typeid(bfloat16)));
+                                  typeid(Tinput) == typeid(bfloat16))) &&
+                                !eager_mode;
       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
         MklDnnShape dst_mkl_shape;
         dst_mkl_shape.SetMklTensor(false);
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
-                                  src_tf_shape, dst_mkl_shape);
+                                  src_tf_shape, dst_mkl_shape, eager_mode);
 
         // MklConv2D/3D also outputs converted filter as 2nd output.
         filter_mkl_shape.SetMklTensor(false);
@@ -566,6 +690,12 @@
       auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                               : TFDataFormatToMklDnn3DDataFormat(data_format_);
 
+#ifdef ENABLE_MKLDNN_V1
+      auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
+      // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
+      DCHECK_NE(mkl_fmt_tag, memory::format_tag::undef);
+#endif  // ENABLE_MKLDNN_V1
+
       // If input is in MKL layout, then simply grab the layout; otherwise,
       // construct TF layout for input.
       // For constructing TF layout for input, although input shape (src_dims)
@@ -573,18 +703,22 @@
       // TF layout depending on the data format:
       //     Conv2D: NHWC or NCHW
       //     Conv3D: NDHWC or NCDHW
-      auto src_md = src_mkl_shape.IsMklTensor()
-                        ? src_mkl_shape.GetMklLayout()
-                        : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
+      auto src_md =
+          src_mkl_shape.IsMklTensor()
+              ? src_mkl_shape.GetMklLayout()
+#ifdef ENABLE_MKLDNN_V1
+              : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
+#else
+              : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
+#endif  // ENABLE_MKLDNN_V1
       src.SetUsrMem(src_md, &src_tensor);
 
       // Although filter shape (filter_dims) required is in MKL-DNN order,
       // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
       // depthwise/group convolutions.
-
-      auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo
-                                                     : memory::format::hwio)
-                                     : memory::format::dhwio;
+      auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
+                                                     : MEMORY_FORMAT::hwio)
+                                     : MEMORY_FORMAT::dhwio;
 
       DCHECK(!filter_mkl_shape.IsMklTensor());
       auto filter_md =
@@ -593,7 +727,7 @@
               : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
       filter.SetUsrMem(filter_md, &filter_tensor);
 
-      // MKLDNN dilations start from 0.
+      // MKL-DNN dilations start from 0.
       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
 
       // In some cases, primitive descriptor could potentially contain
@@ -627,9 +761,10 @@
               convFwdDims, do_not_cache);
 
       // Allocate output tensors `output_tensor` and `filter_out_tensor`
+      MklDnnShape output_mkl_shape;
       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
-                           &dst_tensor);
+                           &output_mkl_shape, &dst_tensor, &tmp_tensor);
 
       Tensor* filter_out_tensor = nullptr;
       if (emit_filter_output) {
@@ -643,10 +778,11 @@
 
       // Check whether src and filter need to be reordered
       Tinput* src_data = nullptr;
-      if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
+      if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
         // Reorder src
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
+        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+            GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -654,7 +790,7 @@
       }
 
       Tfilter* filter_data = nullptr;
-      if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
+      if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
         bool is_filter_cached = false;
         // If filter is a constant, we can avoid the conversion of filter from
         // Tensorflow format to MKL format by caching the filter when it is
@@ -664,21 +800,26 @@
           if (IsFilterCacheEmpty(context)) {
             // Cache filter if it is not already cached.
             CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
+#ifdef ENABLE_MKLDNN_V1
+                        filter, filter_md, filter_mkl_shape);
+#else
                         filter, filter_md);
+#endif  // ENABLE_MKLDNN_V1
           }
-          filter_data =
-              GetCachedFilter(context, conv_fwd->GetFilterMemoryFormat());
+          filter_data = GetCachedFilter(
+              context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
           is_filter_cached = (filter_data != nullptr);
         }
         if (!is_filter_cached) {
           filter.SetUsrMem(filter_md, &filter_tensor);
           if (filter_out_tensor == nullptr) {
-            filter.CheckReorderToOpMem(
-                conv_fwd_pd.get()->weights_primitive_desc());
+            filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+                GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
           } else {
             filter.CheckReorderToOpMem(
-                conv_fwd_pd.get()->weights_primitive_desc(),
-                filter.GetTensorBuffer(filter_out_tensor));
+                GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
+                DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
+                                 cpu_engine_));
           }
           filter_data =
               static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
@@ -695,7 +836,28 @@
             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
       } else {
-        conv_fwd->Execute(src_data, filter_data, dst_data);
+        if (!eager_mode) {
+          conv_fwd->Execute(src_data, filter_data, dst_data);
+        } else {
+          // In eager mode we first write the output to temporary
+          // buffer in MKL format. Then we convert the data to TF format.
+          Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
+              tmp_tensor.flat<Toutput>().data());
+          conv_fwd->Execute(src_data, filter_data, tmp_data);
+
+          // Now we need to convert the output to TF format.
+          auto output_tf_md = output_mkl_shape.GetTfLayout();
+          auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
+          auto dst_pd = (*conv_fwd_pd).dst_primitive_desc();
+          mkldnn::reorder::primitive_desc reorder_pd =
+              mkldnn::reorder::primitive_desc(dst_pd, output_tf_pd);
+          std::vector<mkldnn::primitive> net;
+          memory* tmp_data_mem = new memory(dst_pd, tmp_data);
+          memory* dst_data_mem = new memory(output_tf_pd, dst_data);
+          net.push_back(
+              mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
+          stream(stream::kind::eager).submit(net).wait();
+        }
       }
 
       // Delete primitive since it is not cached.
@@ -787,7 +949,7 @@
     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
     // checking `fuse_biasadd_` flag.
     if (fuse_add_) {
-      params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}});
+      params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
     }
     if (fuse_activation_) {
       params.post_op_params.push_back(
@@ -808,45 +970,60 @@
   virtual void AllocateOutputTensor(OpKernelContext* context,
                                     const ConvFwdPd& conv_prim_desc,
                                     const memory::dims& output_dims_mkl_order,
-                                    memory::format output_tf_format,
-                                    Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
+                                    MKL_TENSOR_FORMAT output_tf_format,
+                                    MklDnnShape* output_mkl_shape,
+                                    Tensor** output_tensor,
+                                    Tensor* tmp_tensor) {
+    DCHECK(output_tensor);
+#ifdef ENABLE_MKLDNN_V1
+    auto dst_md = conv_prim_desc.dst_desc();
+#else
     auto dst_pd = conv_prim_desc.dst_primitive_desc();
-
     auto dst_md = dst_pd.desc();
+#endif  // ENABLE_MKLDNN_V1
+
     if (!std::is_same<Ttemp_output, Toutput>::value) {
       dst_md.data.data_type =
           static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
+#ifndef ENABLE_MKLDNN_V1
       dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
+#endif  // !ENABLE_MKLDNN_V1
     }
-    // Allocate shape of Mkl tensor.
-    MklDnnShape output_mkl_shape;
-    output_mkl_shape.SetMklTensor(true);
-    output_mkl_shape.SetMklLayout(&dst_pd);
-    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
-    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
-                                 output_dims_mkl_order, output_tf_format);
 
-    // Allocate shape of TF tensor.
+    // Allocate shape of MKL tensor
+    output_mkl_shape->SetMklTensor(true);
+    output_mkl_shape->SetMklLayout(&DST_MD);
+    output_mkl_shape->SetElemType(MklDnnType<Toutput>());
+    output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
+                                  output_dims_mkl_order, output_tf_format);
+
+    // Allocate shape of TF tensor
     TensorShape output_tf_shape;
-    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
+    output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
+    if (eager_mode) {
+      AllocTmpBuffer<Toutput>(context, tmp_tensor, output_tf_shape);
+      output_tf_shape = output_mkl_shape->GetTfShape();
+    }
 
     AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
-                              output_tf_shape, output_mkl_shape);
+                              output_tf_shape, *output_mkl_shape, eager_mode);
+    // TODO(bhavanis): Need to integrate the following Add fusion code with
+    // MKL-DNN v1.x
     if (fuse_add_) {
       const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
       MklDnnShape add_mkl_shape;
       GetMklShape(context, kInputIndex_Add, &add_mkl_shape);
 
       // Check if need reorder
-      if (add_mkl_shape == output_mkl_shape) {
-        CHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
+      if (add_mkl_shape == *output_mkl_shape) {
+        auto result = (*output_tensor)->CopyFrom(add_tensor, output_tf_shape);
+        DCHECK(result);
       } else {
         auto add_md =
             add_mkl_shape.IsMklTensor()
                 ? add_mkl_shape.GetMklLayout()
                 : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
-                               output_mkl_shape.GetTfDataFormat());
+                               output_mkl_shape->GetTfDataFormat());
         auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
         void* add_buf = static_cast<void*>(
             const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
@@ -863,7 +1040,7 @@
     }
   }
 
-  engine cpu_engine_ = engine(engine::cpu, 0);
+  engine cpu_engine_ = engine(ENGINE_CPU, 0);
 
  private:
   std::vector<int32> strides_;
@@ -883,7 +1060,7 @@
   bool fuse_add_ = false;
 
   float relu_up_bound_ = 0.0;
-  mkldnn::algorithm activation_alg_ = mkldnn::algorithm_undef;
+  mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
 
   int input_index_pad_ = 2;
 
@@ -892,15 +1069,27 @@
   const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
   const int kDilationH = 0, kDilationW = 1;
 
+  MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
+      const MklDnnShape* filter_mkl_shape,
+      const ConvFwdPd& conv_prim_desc) const {
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK(filter_mkl_shape);
+    return filter_mkl_shape->GetTfDataFormat();
+#else
+    return conv_prim_desc.weights_primitive_desc().desc().data.format;
+#endif  // ENABLE_MKLDNN_V1
+  }
+
   // Allocate persistent tensors for cached filter data and
   // cached filter memory descriptor (data format)
   void AllocatePersistentTensor(OpKernelContext* context,
                                 const ConvFwdPd& conv_prim_desc,
-                                Tensor** filter_tensor) {
+                                Tensor** filter_tensor,
+                                const MklDnnShape* filter_mkl_shape) {
     DCHECK(filter_tensor);
     TensorShape filter_tf_shape;
     filter_tf_shape.AddDim(
-        (conv_prim_desc.weights_primitive_desc().get_size() / sizeof(Tfilter)));
+        (conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DataTypeToEnum<Tfilter>::value, filter_tf_shape,
                                 &cached_filter_data_ptensor_, filter_tensor));
@@ -908,37 +1097,44 @@
     Tensor* second_tensor = nullptr;
     TensorShape filter_mkl_format;
     filter_mkl_format.AddDim(
-        sizeof(conv_prim_desc.weights_primitive_desc().desc().data.format) /
+        sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
         sizeof(DT_INT32));
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DT_INT32, filter_mkl_format,
                                 &cached_filter_md_ptensor_, &second_tensor));
-    second_tensor->scalar<int32>()() =
-        conv_prim_desc.weights_primitive_desc().desc().data.format;
+    second_tensor->scalar<int32>()() = static_cast<int32>(
+        GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
+  }
+
+  void AllocatePersistentTensor(OpKernelContext* context,
+                                const ConvFwdPd& conv_prim_desc,
+                                Tensor** filter_tensor) {
+    AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr);
   }
 
   void AllocateFilterOutputTensor(OpKernelContext* context,
                                   const ConvFwdPd& conv_prim_desc,
                                   const memory::dims& filter_dims_tf_order,
                                   Tensor** filter_tensor) {
-    CHECK_NOTNULL(filter_tensor);
-    auto filter_pd = conv_prim_desc.weights_primitive_desc();
+    DCHECK(filter_tensor);
+    auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
 
-    // Allocate shape of Mkl tensor.
+    // Allocate shape of MKL tensor
     MklDnnShape filter_mkl_shape;
     filter_mkl_shape.SetMklTensor(true);
-    filter_mkl_shape.SetMklLayout(&filter_pd);
+    filter_mkl_shape.SetMklLayout(&filter_md);
     filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());
 
     // The format of the filter is actually OIhw8i8o, but TF doesn't support
     // this format. Just use format::blocked for now because the layout
     // is stored in the MKL data.
     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
-                                 filter_dims_tf_order, memory::format::blocked);
+                                 filter_dims_tf_order,
+                                 MKL_TENSOR_FORMAT_BLOCKED);
 
     // Allocate the data space for the filter to propagate as TF tensor.
     TensorShape filter_tf_shape;
-    filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(Tfilter)));
+    filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));
 
     AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                               filter_tf_shape, filter_mkl_shape);
@@ -951,20 +1147,46 @@
                             MklDnnData<Tbias>* bias,
                             MklDnnData<Toutput>* output,
                             Tensor* filter_out_tensor) {
-    CHECK_NOTNULL(filter_out_tensor);
+    DCHECK(filter_out_tensor);
 
     // Create reorders between user layout and MKL layout if it is needed and
     // add it to the net before convolution. No need to check for output
     // reorder as we propagate output layout to the next layer.
-    src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());
+    src->CheckReorderToOpMem(
+        MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
 
-    // rather than re-order to a temp buffer, reorder directly to the
+    // Rather than re-ordering to a temp buffer, reorder directly to the
     // filter output tensor
-    filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
+    filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
                                 filter->GetTensorBuffer(filter_out_tensor));
 
     // Create convolution primitive and add it to net.
     std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+    std::vector<std::unordered_map<int, memory>> net_args;
+    if (bias) {
+      DCHECK(fuse_biasadd_);
+      net.push_back(convolution_forward(conv_prim_desc));
+      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
+                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
+                          {MKLDNN_ARG_BIAS, bias->GetOpMem()},
+                          { MKLDNN_ARG_DST,
+                            output->GetOpMem() }});
+    } else {
+      DCHECK(!fuse_biasadd_);
+      net.push_back(convolution_forward(conv_prim_desc));
+      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
+                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
+                          { MKLDNN_ARG_DST,
+                            output->GetOpMem() }});
+    }
+    stream cpu_stream(cpu_engine_);
+    DCHECK_EQ(net.size(), net_args.size());
+    for (size_t i = 0; i < net.size(); ++i) {
+      net.at(i).execute(cpu_stream, net_args.at(i));
+    }
+    cpu_stream.wait();
+#else
     if (bias) {
       DCHECK(fuse_biasadd_);
       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
@@ -976,8 +1198,8 @@
                                         filter->GetOpMem(),
                                         output->GetOpMem()));
     }
-
     stream(stream::kind::eager).submit(net).wait();
+#endif  // ENABLE_MKLDNN_V1
   }
 
   // LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@@ -990,8 +1212,55 @@
     return (cached_filter_data_tensor.NumElements() == 0);
   }
 
-  // Cache the converted filter in a persistent tensor.
-  // Only one thread can execute this method at any given time.
+// Cache the converted filter in a persistent tensor.
+// Only one thread can execute this method at any given time.
+#ifdef ENABLE_MKLDNN_V1
+  void CacheFilter(OpKernelContext* context,
+                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
+                   Tfilter* filter_data, const Tensor& filter_tensor,
+                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
+                   const MklDnnShape& filter_mkl_shape) LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    const Tensor& cached_filter_data_tensor =
+        *cached_filter_data_ptensor_.AccessTensor(context);
+
+    // If filter is already cached, there's nothing to do.
+    if (cached_filter_data_tensor.NumElements() > 0) {
+      return;
+    }
+
+    // Otherwise, cache filter
+    filter.SetUsrMem(filter_md, &filter_tensor);
+    filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
+                               this->cpu_engine_);
+    filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
+
+    Tensor* filter_tensor_ptr = nullptr;
+    AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
+                             &filter_mkl_shape);
+    void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
+    size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
+    memcpy(cached_filter_data, filter_data, cached_filter_data_size);
+  }
+
+  bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
+                                 const Tensor& cached_filter_md) {
+    auto filter_md_data = filter_md.data;
+    const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);
+
+    auto cached_filter_md_data = cached_filter_md.scalar<int64>()();
+    const char* cached_filter_data =
+        reinterpret_cast<const char*>(&cached_filter_md_data);
+
+    for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
+      if (*filter_data++ != *cached_filter_data++) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+#else
   void CacheFilter(OpKernelContext* context,
                    const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                    Tfilter* filter_data, const Tensor& filter_tensor,
@@ -1018,22 +1287,26 @@
         filter.GetOpMem().get_primitive_desc().get_size();
     memcpy(cached_filter_data, filter_data, cached_filter_data_size);
   }
+#endif  // ENABLE_MKLDNN_V1
 
   Tfilter* GetCachedFilter(OpKernelContext* context,
-                           const memory::format& filter_mf)
-      LOCKS_EXCLUDED(mu_) {
+                           const MEMORY_DESC& filter_md) LOCKS_EXCLUDED(mu_) {
     tf_shared_lock lock(mu_);
     const Tensor& cached_filter_data =
         *cached_filter_data_ptensor_.AccessTensor(context);
     const Tensor& cached_filter_md =
         *cached_filter_md_ptensor_.AccessTensor(context);
 
-    // Check if the memory descriptor of the cached weights is same as
-    // filter_mf. If so, we can used the cached weights; otherwise
-    // return NULL.
-    // TODO (bhavanis): Do we need to cast filter_mf before the check?
+// Check if the memory descriptor of the cached weights is same as
+// filter_md. If so, we can used the cached weights; otherwise
+// return NULL.
+#ifdef ENABLE_MKLDNN_V1
+    if (cached_filter_md.scalar<int64>().size() &&
+        AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
+#else
     if (cached_filter_md.scalar<int32>().size() &&
-        cached_filter_md.scalar<int32>()() == filter_mf) {
+        cached_filter_md.scalar<int32>()() == filter_md) {
+#endif  // ENABLE_MKLDNN_V1
       return static_cast<Tfilter*>(
           const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
     }
@@ -1047,11 +1320,11 @@
           bool pad_enabled>
 class MklFusedConvOp
     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
-                       Tpadding, false, false, false> {
+                       Tpadding, false, false, false, false> {
  public:
   explicit MklFusedConvOp(OpKernelConstruction* context)
       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
-                  Tpadding, false, false, false>(context) {
+                  Tpadding, false, false, false, false>(context) {
     // Since we came here through the registration of _MklFusedConv2D, get
     // all information from 'fused_ops' and 'num_args'
     std::vector<string> fused_ops;
@@ -1069,26 +1342,26 @@
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"Relu"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
     } else if (fused_ops == std::vector<string>{"Relu6"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
     } else if (fused_ops == std::vector<string>{"Elu"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
@@ -1102,7 +1375,7 @@
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1110,7 +1383,7 @@
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1118,7 +1391,7 @@
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1143,7 +1416,7 @@
           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
 class MklQuantizedConv2DOp
     : public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
-                       int32, bias_enabled, false, is_depthwise> {
+                       int32, bias_enabled, false, is_depthwise, false> {
  public:
   virtual ~MklQuantizedConv2DOp() {
     if (this->input_bias_ != nullptr) {
@@ -1159,7 +1432,7 @@
 
   explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
       : MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-                  bias_enabled, false, is_depthwise>(context) {
+                  bias_enabled, false, is_depthwise, false>(context) {
     bool is_filter_const;
     OP_REQUIRES_OK(context,
                    context->GetAttr("is_filter_const", &is_filter_const));
@@ -1170,7 +1443,7 @@
   void Compute(OpKernelContext* context) override {
     // Compute int32 output tensor
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false, is_depthwise>::Compute(context);
+              bias_enabled, false, is_depthwise, false>::Compute(context);
 
     // Compute additional outputs: min/max scalars.
     int bias_index_offset;
@@ -1232,8 +1505,8 @@
   void ExtendConvFwdParams(OpKernelContext* context,
                            MklConvFwdParams& params) override {
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false, is_depthwise>::ExtendConvFwdParams(context,
-                                                                      params);
+              bias_enabled, false, is_depthwise,
+              false>::ExtendConvFwdParams(context, params);
 
     // When the output type is quint8, the output data id requantized
     // into quint8. A post_op "output_scale" is added to do the conversion.
@@ -1274,7 +1547,7 @@
                     (255.0f * 127.0f * output_range);
       }
       params.post_op_params.push_back(
-          {"output_scale", mkldnn::algorithm_undef, scales});
+          {"output_scale", ALGORITHM_UNDEF, scales});
     }
   }
 
@@ -1293,7 +1566,6 @@
     const float* min_filter = min_filter_vector.flat<float>().data();
     const float* max_filter = max_filter_vector.flat<float>().data();
 
-    std::vector<mkldnn::primitive> net;
     if (bias_enabled) {
       if (std::is_same<Tbias, qint32>::value) {
         return static_cast<Tbias*>(
@@ -1315,21 +1587,21 @@
       } else {
         bias_attr.set_output_scales(1, scales);
       }
-      auto bias_pd =
-          memory::primitive_desc({{static_cast<int>(bias_tensor.NumElements())},
-                                  MklDnnType<Tbias>(),
-                                  memory::format::x},
-                                 this->cpu_engine_);
 
+      auto bias_md =
+          MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
+                                Tbias, x, this->cpu_engine_);
       void* bias_buf = static_cast<void*>(
           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
-      input_bias_ = new memory(bias_pd, bias_buf);
-      scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc());
-      auto reorder_desc = mkldnn::reorder::primitive_desc(
-          input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(),
+      input_bias_ =
+          new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
+      scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA(
+          conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_);
+      auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+          input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
           bias_attr);
-      net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
-      stream(stream::kind::eager).submit(net).wait();
+      CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
+                              this->cpu_engine_);
       return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
     } else {
       return nullptr;
@@ -1358,7 +1630,7 @@
     MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled,
                          is_depthwise>::ExtendConvFwdParams(context, params);
     params.post_op_params.push_back(
-        {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}});
+        {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
   }
 };
 
@@ -1415,24 +1687,26 @@
       // If it is not then  it is DT_INT8 and is scaled appropriately.
       if (summand_type == DT_QUINT8)
         params.post_op_params.push_back(
-            {"sum", mkldnn::algorithm_undef, {scale_summand / scale_output}});
+            {"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}});
       else
         params.post_op_params.push_back(
             {"sum",
-             mkldnn::algorithm_undef,
+             ALGORITHM_UNDEF,
              {255.0f * scale_summand / (scale_output * 127.0f)}});
     } else {
-      params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}});
+      params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
     }
     params.post_op_params.push_back(
-        {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}});
+        {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
   }
 
   void AllocateOutputTensor(OpKernelContext* context,
                             const ConvFwdPd& conv_prim_desc,
                             const memory::dims& output_dims_mkl_order,
-                            memory::format output_tf_format,
-                            Tensor** output_tensor) override {
+                            MKL_TENSOR_FORMAT output_tf_format,
+                            MklDnnShape* output_mkl_shape,
+                            Tensor** output_tensor,
+                            Tensor* tmp_tensor) override {
     int summand_idx = context->num_inputs() / 2 - 1;
     if (std::is_same<Toutput, quint8>::value) {
       summand_idx -= 2;
@@ -1459,12 +1733,12 @@
       *output_tensor = const_cast<Tensor*>(&summand);
       return;
     }
-
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false,
+              bias_enabled, false, false,
               false>::AllocateOutputTensor(context, conv_prim_desc,
                                            output_dims_mkl_order,
-                                           output_tf_format, output_tensor);
+                                           output_tf_format, output_mkl_shape,
+                                           output_tensor, tmp_tensor);
     const Tensor& summand = MklGetInput(context, summand_idx);
     if (summand.dtype() != DT_FLOAT)
       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
@@ -1503,20 +1777,22 @@
         summand_mkl_shape.IsMklTensor()
             ? summand_mkl_shape.GetMklLayout()
             : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
-                           memory::format::nhwc);
+                           MEMORY_FORMAT::nhwc);
+#ifndef ENABLE_MKLDNN_V1
     auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
+#endif  // !ENABLE_MKLDNN_V1
     void* summand_buf =
         static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
     void* dst_buf =
         static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
-    summand_ = new memory(summand_pd, summand_buf);
-    dst_ = new memory(conv_prim_desc.dst_primitive_desc(), dst_buf);
-    auto reorder_desc = mkldnn::reorder::primitive_desc(
-        summand_pd, conv_prim_desc.dst_primitive_desc(), reorder_attr);
-
-    std::vector<mkldnn::primitive> net;
-    net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_));
-    stream(stream::kind::eager).submit(net).wait();
+    summand_ =
+        new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf);
+    dst_ = new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
+                                  this->cpu_engine_, dst_buf);
+    auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+        SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
+        reorder_attr);
+    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
   }
 
   memory* summand_ = nullptr;
@@ -1870,46 +2146,52 @@
     MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true, true>);
 
 // Register 2D operations
-#define REGISTER_MKL_CPU_2D(T)                                          \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("_MklConv2D")                                                \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>); \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("_MklConv2DWithBias")                                        \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false>);  \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("__MklDummyConv2DWithBias")                                  \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklDummyOp<CPUDevice, T>);                                        \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("_MklPadWithConv2D")                                         \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .TypeConstraint<int32>("Tpaddings")                           \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false>);  \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("_MklPadWithConv2D")                                         \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .TypeConstraint<int64>("Tpaddings")                           \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false>);  \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("__MklDummyPadWithConv2D")                                   \
-          .Device(DEVICE_CPU)                                           \
-          .TypeConstraint<T>("T")                                       \
-          .TypeConstraint<int32>("Tpaddings")                           \
-          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),          \
-      MklDummyOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_2D(T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklConv2D")                                                       \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklConv2DWithBias")                                               \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>);  \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("__MklDummyConv2DWithBias")                                         \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklDummyOp<CPUDevice, T>);                                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklPadWithConv2D")                                                \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .TypeConstraint<int32>("Tpaddings")                                  \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>);  \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklPadWithConv2D")                                                \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .TypeConstraint<int64>("Tpaddings")                                  \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>);  \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("__MklDummyPadWithConv2D")                                          \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .TypeConstraint<int32>("Tpaddings")                                  \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
+      MklDummyOp<CPUDevice, T>);                                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklEagerConv2D")                                                  \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);
 
 TF_CALL_float(REGISTER_MKL_CPU_2D);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
@@ -1920,7 +2202,7 @@
           .Device(DEVICE_CPU)                                  \
           .TypeConstraint<T>("T")                              \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true>);
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>);
 
 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);
@@ -1966,9 +2248,40 @@
           .Device(DEVICE_CPU)                                  \
           .TypeConstraint<T>("T")                              \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>);
 TF_CALL_float(REGISTER_MKL_CPU_3D);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
 
+#undef ADD_MD
+#undef ALGORITHM
+#undef ALGORITHM_UNDEF
+#undef CPU_STREAM
+#undef DATA_WITH_ENGINE
+#undef DST_MD
+#undef ENGINE_CPU
+#undef GET_DESC
+#undef GET_MEMORY_DESC_CONSTRUCTOR
+#undef GET_SRC_DESC_FROM_OP_PD
+#undef GET_WEIGHTS_DESC_FROM_OP_PD
+#undef GET_WEIGHTS_FORMAT_FROM_OP_PD
+#undef IS_FILTER_REORDER_NEEDED
+#undef IS_SRC_REORDER_NEEDED
+#undef MEMORY_CONSTRUCTOR
+#undef MEMORY_CONSTRUCTOR_USING_MEM_PD
+#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA
+#undef MEMORY_DESC
+#undef MEMORY_FORMAT
+#undef MEMORY_PD_CONSTRUCTOR
+#undef MEMORY_PD_WITHOUT_DATA
+#undef MKL_TENSOR_FORMAT
+#undef MKL_TENSOR_FORMAT_BLOCKED
+#undef MKL_TENSOR_FORMAT_IN_C
+#undef PRIMITIVE_DESC_BIAS
+#undef PRIMITIVE_DESC_DST
+#undef PRIMITIVE_DESC_SRC
+#undef PRIMITIVE_DESC_WEIGHTS
+#undef REORDER_PD_CONSTRUCTOR
+#undef REORDER_PD_CONSTRUCTOR_WITH_ATTR
+#undef SUMMAND_MD
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index e9be11a..99e9c9f 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -40,13 +40,21 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
+#ifndef ENABLE_MKLDNN_V1
 using mkldnn::convolution_direct;
+#endif  // !ENABLE_MKLDNN_V1
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
 
 namespace tensorflow {
 
+#ifdef ENABLE_MKLDNN_V1
+#define MKLDNN_SIZE_DTYPE long int
+#else
+#define MKLDNN_SIZE_DTYPE int
+#endif  // ENABLE_MKLDNN_V1
+
 class MklDnnConvUtil {
  protected:
   OpKernelContext* context_;  // We don't own this.
@@ -137,7 +145,7 @@
       int input_cols = static_cast<int>(input_cols_raw);
 
       // MKL-DNN always requires input in NCHW format Conv2D.
-      std::vector<int> mkldnn_sizes(4, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
       mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
       mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
       mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
@@ -161,7 +169,7 @@
       int input_cols = static_cast<int>(input_cols_raw);
 
       // MKL-DNN always requires input in NCDHW format for Conv3D.
-      std::vector<int> mkldnn_sizes(5, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
@@ -225,7 +233,7 @@
       // GOIHW = (group, out_depth, in_depth, rows, cols)
       // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
       if (is_depthwise) {
-        std::vector<int> mkldnn_sizes(5, -1);
+        std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
         mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
         mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
         mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
@@ -234,7 +242,7 @@
 
         *filter_dims = mkldnn_sizes;
       } else {
-        std::vector<int> mkldnn_sizes(4, -1);
+        std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
         mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
         mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
         mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
@@ -262,7 +270,7 @@
 
       // MKL-DNN always needs filter in OIDHW format.
       // OIDHW = (out_depth, in_depth, planes, rows, cols)
-      std::vector<int> mkldnn_sizes(5, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
       mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
       mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
@@ -455,14 +463,14 @@
 
     if (is_conv2d) {
       // For Conv2D, MKL-DNN always needs output in NCHW format.
-      std::vector<int> mkldnn_sizes(4, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
       mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
       mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
       mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
       mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
       *output_dims_mkl_order = mkldnn_sizes;
     } else {
-      std::vector<int> mkldnn_sizes(5, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
@@ -624,6 +632,8 @@
   }
 };
 
+#undef MKLDNN_SIZE_DTYPE
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 7c89199..4fbacc1 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -518,6 +518,8 @@
                 errors::InvalidArgument("Invalid data format"));
     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
     depth_ = 0;
+    mean_values_ = nullptr;
+    variance_values_ = nullptr;
   }
 
   void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index 53de24a..0d203c1 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -75,6 +75,8 @@
 
       // Declare output tensor
       Tensor* output_tensor = nullptr;
+      // Declare output workspace tensor
+      Tensor* output_ws_tensor = nullptr;
       memory::dims output_dims_mkl_order;
       this->GetOutputDims(pool_params, &output_dims_mkl_order);
 
@@ -83,6 +85,19 @@
         const int kOutputIndex = 0;
         this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
                                         output_dims_mkl_order, &output_tensor);
+        bool int8_forward_inference =
+            std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
+
+        // Allocate an empty workspace tensor if not Quantized MaxPooling
+        // Because Quantized MaxPooling does not have backward pass
+        // Therefore no workspace, which is used to help backward pass in MKL
+        if (!int8_forward_inference) {
+          const int kOutputWorkspaceIndex = 1;
+          // output_ws_tensor is not really used, so using output_dims_mkl_order
+          this->AllocateEmptyOutputTensor(context, kOutputWorkspaceIndex,
+                                          &pool_params, output_dims_mkl_order,
+                                          &output_ws_tensor);
+        }
         return;
       }
 
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index ec440a0..c2c33d9 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -548,12 +548,21 @@
     if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
       output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
     } else {
-      memory::dims output_dims_NHWC_order;
-      output_dims_NHWC_order = {pool_params->tensor_in_batch,
-                                static_cast<int>(pool_params->out_height),
-                                static_cast<int>(pool_params->out_width),
-                                pool_params->out_depth};
-      output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
+      memory::dims output_dims_order;
+      // determine Pooling2D (NHWC) or Pooling3D (NDHWC)
+      if (this->ksize_.size() == 4) {
+        output_dims_order = {pool_params->tensor_in_batch,
+                             static_cast<int>(pool_params->out_height),
+                             static_cast<int>(pool_params->out_width),
+                             pool_params->out_depth};
+      } else {
+        output_dims_order = {pool_params->tensor_in_batch,
+                             static_cast<int>(pool_params->out_planes),
+                             static_cast<int>(pool_params->out_height),
+                             static_cast<int>(pool_params->out_width),
+                             pool_params->out_depth};
+      }
+      output_tf_shape = MklDnnDimsToTFShape(output_dims_order);
     }
     AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                               output_tf_shape, output_mkl_shape);
diff --git a/tensorflow/core/kernels/nccl_ops.cc b/tensorflow/core/kernels/nccl_ops.cc
index 666a144..6d34b68 100644
--- a/tensorflow/core/kernels/nccl_ops.cc
+++ b/tensorflow/core/kernels/nccl_ops.cc
@@ -93,9 +93,9 @@
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
     const Tensor* input = &c->input(0);
     Tensor* output;
-    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, input->shape(), &output),
-                         done);
-
+    OP_REQUIRES_OK_ASYNC(
+        c, c->forward_input_or_allocate_output({0}, 0, input->shape(), &output),
+        done);
     auto actual_done = [c, done](Status s) {
       OP_REQUIRES_OK_ASYNC(c, s, done);
       done();
@@ -112,7 +112,7 @@
         {GetCollectiveKey(c),
          /*num_local_devices=*/num_devices(),
          /*num_global_devices=*/num_devices(),
-         /*communicator_key=*/""},
+         /*communicator_key=*/"", /*source_rank=*/-1},
         reduction_op());
   }
 };
@@ -144,7 +144,7 @@
         {GetCollectiveKey(c),
          /*num_local_devices=*/num_devices(),
          /*num_global_devices=*/num_devices(),
-         /*communicator_key=*/""},
+         /*communicator_key=*/"", /*source_rank=*/-1},
         reduction_op());
   }
 };
@@ -181,7 +181,7 @@
         {GetCollectiveKey(c),
          /*num_local_devices=*/num_devices(),
          /*num_global_devices=*/num_devices(),
-         /*communicator_key=*/""},
+         /*communicator_key=*/"", /*source_rank=*/-1},
         reduction_op());
   }
 
@@ -215,7 +215,7 @@
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
                                  /*num_global_devices=*/num_devices(),
-                                 /*communicator_key=*/""});
+                                 /*communicator_key=*/"", /*source_rank=*/-1});
   }
 };
 REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU),
@@ -252,7 +252,7 @@
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
                                  /*num_global_devices=*/num_devices(),
-                                 /*communicator_key=*/""});
+                                 /*communicator_key=*/"", /*source_rank=*/-1});
   }
 };
 REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc
index 8b4d3d9..de21a3c 100644
--- a/tensorflow/core/kernels/nn_ops_test.cc
+++ b/tensorflow/core/kernels/nn_ops_test.cc
@@ -108,6 +108,7 @@
                          CONV_OP op, int num_threads, int stride,
                          Padding padding, bool use_gpu, DataType data_type,
                          const string& label) {
+  testing::StopTiming();
   if (!IsGoogleCudaEnabled() && use_gpu) {
     testing::SetLabel(
         strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
@@ -221,6 +222,7 @@
 
   string device = use_gpu ? "gpu" : "cpu";
   testing::UseRealTime();
+  testing::StartTiming();
   test::Benchmark(device, g, &options).Run(iters);
   testing::ItemsProcessed(num_ops * iters);
 }
@@ -502,6 +504,7 @@
                                   int filter_cols, DEPTHWISE_CONV_OP op,
                                   int num_threads, int stride, Padding padding,
                                   bool use_gpu, const string& label) {
+  testing::StopTiming();
   if (!IsGoogleCudaEnabled() && use_gpu) {
     testing::SetLabel(
         strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
@@ -601,6 +604,7 @@
 
   string device = use_gpu ? "gpu" : "cpu";
   testing::UseRealTime();
+  testing::StartTiming();
   test::Benchmark(device, g, &options).Run(iters);
   testing::ItemsProcessed(num_ops * iters);
 }
diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc
index 8e175fe..d273f67 100644
--- a/tensorflow/core/kernels/parse_tensor_op.cc
+++ b/tensorflow/core/kernels/parse_tensor_op.cc
@@ -39,7 +39,7 @@
                     "Expected `serialized` to be a scalar, got shape: ",
                     serialized.shape().DebugString()));
 
-    auto serialized_t = serialized.scalar<string>();
+    auto serialized_t = serialized.scalar<tstring>();
 
     TensorProto proto;
     OP_REQUIRES(ctx, ParseProtoUnlimited(&proto, serialized_t()),
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index cd37a25..997c2ab 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -317,6 +317,7 @@
     return;
   }
 
+#if CUDNN_VERSION < 7300
   /// For now, cudnn does not support NHWC format, so we need to convert it
   /// to NCHW before calling cudnn. We need to get rid of this once it is done
   Tensor transformed_input;
@@ -382,6 +383,40 @@
         context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
         transformed_output_backprop.tensor<T, 4>());
   }
+  se::dnn::DataLayout data_layout = se::dnn::DataLayout::kBatchDepthYX;
+#else
+  Tensor transformed_input;
+  if (!tensor_in) {
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::value,
+                                          tensor_in_shape, &transformed_input));
+  } else {
+    transformed_input = *tensor_in;
+  }
+  Tensor transformed_output;
+  if (!tensor_out) {
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   out_backprop.shape(),
+                                                   &transformed_output));
+  } else {
+    transformed_output = *tensor_out;
+  }
+  Tensor transformed_input_backprop = *input_backprop;
+  Tensor transformed_output_backprop = out_backprop;
+  se::dnn::DataLayout data_layout;
+  switch (data_format) {
+    case FORMAT_NHWC:
+      data_layout = se::dnn::DataLayout::kBatchYXDepth;
+      break;
+    case FORMAT_NCHW:
+      data_layout = se::dnn::DataLayout::kBatchDepthYX;
+      break;
+    default:
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("Unsupported format: ",
+                                          ToString(data_format)));
+  }
+#endif  // CUDNN_VERSION < 7300
 
   /// Get ready to call cudnn
   se::dnn::PoolingDescriptor pooling_desc;
@@ -399,14 +434,14 @@
       .set_height(params.out_height)
       .set_width(params.out_width)
       .set_feature_map_count(params.depth)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(data_layout);
 
   se::dnn::BatchDescriptor orig_input_desc;
   orig_input_desc.set_count(params.tensor_in_batch)
       .set_height(params.tensor_in_rows)
       .set_width(params.tensor_in_cols)
       .set_feature_map_count(params.depth)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(data_layout);
 
   auto orig_output_data =
       AsDeviceMemory(transformed_output.template flat<T>().data(),
@@ -449,6 +484,7 @@
   OP_REQUIRES(context, status,
               errors::Internal("dnn PoolBackward launch failed"));
 
+#if CUDNN_VERSION < 7300
   if (data_format == FORMAT_NHWC) {
     /// Transform the output data from NCHW back to NHWC.
     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
@@ -457,6 +493,7 @@
         toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
         input_backprop->tensor<T, 4>());
   }
+#endif  // CUDNN_VERSION < 7300
 }
 
 #define DEFINE_DNN_OPS(T)         \
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index 6ed5bb0..67e8c94 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -84,8 +84,8 @@
 
   void Compute(OpKernelContext* context) override {
     const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
-    handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
-    handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
+    handle_.AccessTensor(context)->flat<tstring>()(0) = ref.container();
+    handle_.AccessTensor(context)->flat<tstring>()(1) = ref.name();
     context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
   }
 
diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
new file mode 100644
index 0000000..21b929f
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
@@ -0,0 +1,538 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/broadcast_to_op.h"
+#include "tensorflow/core/kernels/list_kernels.h"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/ops/ragged_to_dense_util.h"
+#include "tensorflow/core/platform/default/integral_types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+namespace {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+using ::std::vector;
+using ::tensorflow::errors::Internal;
+
+const int kShapeInputIndex = 0;
+const int kValueInputIndex = 1;
+const int kDefaultValueInputIndex = 2;
+const int kFirstPartitionInputIndex = 3;
+
+template <typename INDEX_TYPE>
+class RaggedTensorToTensorBaseOp : public OpKernel {
+ public:
+  typedef
+      typename ::tensorflow::TTypes<const INDEX_TYPE>::Flat RowPartitionTensor;
+
+  explicit RaggedTensorToTensorBaseOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, GetRowPartitionTypes<OpKernelConstruction>(
+                                context, &row_partition_types_));
+    ragged_rank_ = GetRaggedRank(row_partition_types_);
+  }
+
+  // Returns the relationship between dimension and dimension + 1.
+  RowPartitionType GetRowPartitionTypeByDimension(int dimension) {
+    if (row_partition_types_[0] == RowPartitionType::FIRST_DIM_SIZE) {
+      return row_partition_types_[dimension + 1];
+    } else {
+      return row_partition_types_[dimension];
+    }
+  }
+
+  // Returns the relationship between dimension and dimension + 1.
+  RowPartitionTensor GetRowPartitionTensor(OpKernelContext* c, int dimension) {
+    if (row_partition_types_[0] == RowPartitionType::FIRST_DIM_SIZE) {
+      return c->input(dimension + 1 + kFirstPartitionInputIndex)
+          .flat<INDEX_TYPE>();
+    } else {
+      return c->input(dimension + kFirstPartitionInputIndex).flat<INDEX_TYPE>();
+    }
+  }
+
+  Status GetMaxWidth(OpKernelContext* c, int dimension, INDEX_TYPE* result) {
+    const RowPartitionTensor row_partition_tensor =
+        GetRowPartitionTensor(c, dimension - 1);
+    switch (GetRowPartitionTypeByDimension(dimension - 1)) {
+      case RowPartitionType::VALUE_ROWIDS:
+        *result = GetMaxWidthValueRowID(row_partition_tensor);
+        return Status::OK();
+      case RowPartitionType::ROW_SPLITS:
+        *result = GetMaxWidthRowSplit(row_partition_tensor);
+        return Status::OK();
+      default:
+        return errors::InvalidArgument(
+            "Cannot handle partition type ",
+            RowPartitionTypeToString(
+                GetRowPartitionTypeByDimension(dimension - 1)));
+    }
+  }
+
+  static INDEX_TYPE GetMaxWidthRowSplit(const RowPartitionTensor& row_split) {
+    const INDEX_TYPE tensor_length = row_split.size();
+    if (tensor_length == 0 || tensor_length == 1) {
+      return 0;
+    }
+    INDEX_TYPE max_width = 0;
+    for (INDEX_TYPE i = 0; i < tensor_length - 1; ++i) {
+      const INDEX_TYPE current_width = row_split(i + 1) - row_split(i);
+      if (current_width > max_width) {
+        max_width = current_width;
+      }
+    }
+    return max_width;
+  }
+
+  static INDEX_TYPE GetMaxWidthValueRowID(
+      const RowPartitionTensor& value_rowids) {
+    const INDEX_TYPE index_length = value_rowids.size();
+    if (index_length == 0) {
+      return 0;
+    }
+    INDEX_TYPE first_equal_index = 0;
+    INDEX_TYPE first_equal_index_value = value_rowids(0);
+    INDEX_TYPE max_width = 0;
+    for (INDEX_TYPE i = 1; i < index_length; ++i) {
+      const INDEX_TYPE value = value_rowids(i);
+      if (value != first_equal_index_value) {
+        first_equal_index_value = value;
+        max_width = std::max(i - first_equal_index, max_width);
+        first_equal_index = i;
+      }
+    }
+    return std::max(index_length - first_equal_index, max_width);
+  }
+
+  Status CalculateOutputSize(INDEX_TYPE first_dim, OpKernelContext* c,
+                             vector<INDEX_TYPE>* result) {
+    TensorShapeProto value_shape_proto;
+    c->input(kValueInputIndex).shape().AsProto(&value_shape_proto);
+
+    TensorShapeProto default_value_shape_proto;
+    c->input(kDefaultValueInputIndex)
+        .shape()
+        .AsProto(&default_value_shape_proto);
+
+    TensorShapeProto output_shape_proto;
+    TF_RETURN_IF_ERROR(ValidateDefaultValueShape(default_value_shape_proto,
+                                                 value_shape_proto));
+
+    TensorShapeProto shape_proto;
+    {
+      PartialTensorShape partial_tensor_shape;
+      TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(kShapeInputIndex),
+                                               &partial_tensor_shape));
+      partial_tensor_shape.AsProto(&shape_proto);
+    }
+
+    TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
+        ragged_rank_, shape_proto, value_shape_proto, &output_shape_proto));
+
+    result->reserve(output_shape_proto.dim_size());
+    for (const TensorShapeProto::Dim& dim : output_shape_proto.dim()) {
+      // Note that this may be -1 (if dimension size is unknown).
+      result->push_back(dim.size());
+    }
+
+    if ((*result)[0] < 0) {
+      (*result)[0] = first_dim;
+    }
+    for (int i = 1; i <= ragged_rank_; ++i) {
+      if ((*result)[i] < 0) {
+        TF_RETURN_IF_ERROR(GetMaxWidth(c, i, &(*result)[i]));
+      }
+    }
+    return Status::OK();
+  }
+
+  /**
+   * The output_index represents the index in the output tensor
+   * where the first element of a particular dimension would be written.
+   * If it is -1, it indicates that the index is out of scope.
+   * Example, given first_dimension = 10, first_dimension_output = 6,
+   * and output_index_multiplier = 100:
+   * result = [0 100 200 300 400 500 -1 -1 -1 -1]
+   * If first_dimension_output = 11 instead, then:
+   * result = [0 100 200 300 400 500 600 700 800 900]
+   */
+  vector<INDEX_TYPE> CalculateFirstParentOutputIndex(
+      INDEX_TYPE first_dimension, INDEX_TYPE output_index_multiplier,
+      INDEX_TYPE first_dimension_output) {
+    const INDEX_TYPE min_dimension =
+        std::min(first_dimension, first_dimension_output);
+    vector<INDEX_TYPE> result;
+    result.reserve(first_dimension);
+    int current_output_index = 0;
+    for (INDEX_TYPE i = 0; i < min_dimension;
+         ++i, current_output_index += output_index_multiplier) {
+      result.push_back(current_output_index);
+    }
+    for (INDEX_TYPE i = min_dimension; i < first_dimension; ++i) {
+      result.push_back(-1);
+    }
+    DCHECK_EQ(result.size(), first_dimension);
+    return result;
+  }
+
+  void CalculateOutputIndexRowSplit(
+      const RowPartitionTensor& row_split,
+      const vector<INDEX_TYPE>& parent_output_index,
+      INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
+      vector<INDEX_TYPE>* result) {
+    INDEX_TYPE row_split_size = row_split.size();
+    if (row_split_size > 0) {
+      result->reserve(row_split(row_split_size - 1));
+    }
+    for (INDEX_TYPE i = 0; i < row_split_size - 1; ++i) {
+      INDEX_TYPE row_length = row_split(i + 1) - row_split(i);
+      INDEX_TYPE real_length = std::min(output_size, row_length);
+      INDEX_TYPE parent_output_index_current = parent_output_index[i];
+
+      if (parent_output_index_current == -1) {
+        real_length = 0;
+      }
+      for (INDEX_TYPE j = 0; j < real_length; ++j) {
+        result->push_back(parent_output_index_current);
+        parent_output_index_current += output_index_multiplier;
+      }
+      for (INDEX_TYPE j = 0; j < row_length - real_length; ++j) {
+        result->push_back(-1);
+      }
+    }
+    if (row_split_size > 0) {
+      DCHECK_EQ(result->size(), row_split(row_split_size - 1));
+    }
+  }
+
+  // Calculate the output index of the first element of a list.
+  // The parent_output_index is the same computation for the previous list.
+  // -1 indicates an element or list that is out of range.
+  // The output_index_multiplier is the number of output indices one moves
+  // forward for each column.
+  // E.g., given:
+  // value_rowids:[0 1 2 2 2 3 5 5 6]
+  // parent_output_index:[1000 1100 2000 2100 -1 3000 4000]
+  // output_index_multiplier: 10
+  // output_size: 2
+  // You get:
+  // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
+  // result[0] = parent_output_index[value_rowids[0]]
+  // result[1] = parent_output_index[value_rowids[1]]
+  // result[2] = parent_output_index[value_rowids[2]]
+  // result[3] = parent_output_index[value_rowids[2] + 10]
+  // result[4] = -1 because it is the third element the size is 2.
+  // result[5] = parent_output_index[value_rowids[3]]
+  // result[6] = -1 because parent_output_index[value_rowids[6]] == -1
+  // result[7] = -1 because parent_output_index[value_rowids[6]] == -1
+  // result[8] = parent_output_index[value_rowids[7]]
+  void CalculateOutputIndexValueRowID(
+      const RowPartitionTensor& value_rowids,
+      const vector<INDEX_TYPE>& parent_output_index,
+      INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
+      vector<INDEX_TYPE>* result) {
+    const INDEX_TYPE index_size = value_rowids.size();
+    result->reserve(index_size);
+    if (index_size == 0) {
+      return;
+    }
+
+    INDEX_TYPE current_output_column = 0;
+    INDEX_TYPE current_value_rowid = value_rowids(0);
+    DCHECK_LT(current_value_rowid, parent_output_index.size());
+    INDEX_TYPE current_output_index = parent_output_index[current_value_rowid];
+    result->push_back(current_output_index);
+    for (INDEX_TYPE i = 1; i < index_size; ++i) {
+      INDEX_TYPE next_value_rowid = value_rowids(i);
+      if (next_value_rowid == current_value_rowid) {
+        if (current_output_index >= 0) {
+          ++current_output_column;
+          if (current_output_column < output_size) {
+            current_output_index += output_index_multiplier;
+          } else {
+            current_output_index = -1;
+          }
+        }
+      } else {
+        current_output_column = 0;
+        current_value_rowid = next_value_rowid;
+        DCHECK_LT(next_value_rowid, parent_output_index.size());
+        current_output_index = parent_output_index[next_value_rowid];
+      }
+      result->push_back(current_output_index);
+    }
+    DCHECK_EQ(result->size(), value_rowids.size());
+  }
+
+  Status CalculateOutputIndex(OpKernelContext* context, int dimension,
+                              const vector<INDEX_TYPE>& parent_output_index,
+                              INDEX_TYPE output_index_multiplier,
+                              INDEX_TYPE output_size,
+                              vector<INDEX_TYPE>* result) {
+    const RowPartitionTensor row_partition_tensor =
+        GetRowPartitionTensor(context, dimension);
+    auto partition_type = GetRowPartitionTypeByDimension(dimension);
+    switch (partition_type) {
+      case RowPartitionType::VALUE_ROWIDS:
+        CalculateOutputIndexValueRowID(
+            row_partition_tensor, parent_output_index, output_index_multiplier,
+            output_size, result);
+        return tensorflow::Status::OK();
+      case RowPartitionType::ROW_SPLITS:
+        CalculateOutputIndexRowSplit(row_partition_tensor, parent_output_index,
+                                     output_index_multiplier, output_size,
+                                     result);
+        return tensorflow::Status::OK();
+      default:
+        return errors::InvalidArgument(
+            "Unsupported partition type:",
+            RowPartitionTypeToString(partition_type));
+    }
+  }
+
+  Status GetFirstDimensionSize(OpKernelContext* context, INDEX_TYPE* result) {
+    const Tensor first_partition_tensor =
+        context->input(kFirstPartitionInputIndex);
+    const RowPartitionType first_partition_type = row_partition_types_[0];
+    switch (first_partition_type) {
+      case RowPartitionType::FIRST_DIM_SIZE:
+        *result = first_partition_tensor.scalar<INDEX_TYPE>()();
+        return Status::OK();
+      case RowPartitionType::VALUE_ROWIDS:
+        return errors::InvalidArgument(
+            "Cannot handle VALUE_ROWIDS in first dimension.");
+      case RowPartitionType::ROW_SPLITS:
+        *result = first_partition_tensor.shape().dim_size(0) - 1;
+        return Status::OK();
+      default:
+        return errors::InvalidArgument(
+            "Cannot handle type ",
+            RowPartitionTypeToString(first_partition_type));
+    }
+  }
+
+  void Compute(OpKernelContext* context) override {
+    INDEX_TYPE first_dimension;
+    OP_REQUIRES_OK(context, GetFirstDimensionSize(context, &first_dimension));
+    vector<INDEX_TYPE> output_size;
+    OP_REQUIRES_OK(context,
+                   CalculateOutputSize(first_dimension, context, &output_size));
+    vector<INDEX_TYPE> multiplier;
+    multiplier.resize(output_size.size());
+
+    multiplier[multiplier.size() - 1] = 1;
+    for (int i = output_size.size() - 2; i >= 0; --i) {
+      multiplier[i] = multiplier[i + 1] * output_size[i + 1];
+    }
+    // Full size of the tensor.
+    TensorShape output_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeUtils::MakeShape(output_size, &output_shape));
+    Tensor* output_tensor = nullptr;
+
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, output_shape, &output_tensor));
+    const INDEX_TYPE full_size = multiplier[0] * output_size[0];
+    if (full_size > 0) {
+      vector<INDEX_TYPE> output_index = CalculateFirstParentOutputIndex(
+          first_dimension, multiplier[0], output_size[0]);
+
+      for (int i = 1; i <= ragged_rank_; ++i) {
+        vector<INDEX_TYPE> new_output_index;
+        OP_REQUIRES_OK(context, CalculateOutputIndex(
+                                    context, i - 1, output_index, multiplier[i],
+                                    output_size[i], &new_output_index));
+        output_index = new_output_index;
+      }
+
+      SetOutput(context, output_index, output_tensor);
+    }
+  }
+  virtual void SetOutput(OpKernelContext* context,
+                         const vector<INDEX_TYPE>& output_index,
+                         Tensor* output_tensor) = 0;
+
+ private:
+  vector<RowPartitionType> row_partition_types_;
+  int ragged_rank_;
+};
+
+template <typename VALUE_TYPE, typename INDEX_TYPE>
+void slow_copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size) {
+  for (INDEX_TYPE index = 0; index < size; ++index) {
+    dst[index] = src[index];
+  }
+}
+
+template <typename VALUE_TYPE, typename INDEX_TYPE>
+void copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size,
+                size_t bytes) {
+  memcpy(dst, src, bytes);
+}
+
+template <>
+void copy_array<string, int64>(string* dst, const string* src, int64 size,
+                               size_t bytes) {
+  slow_copy_array(dst, src, size);
+}
+
+template <>
+void copy_array<string, int32>(string* dst, const string* src, int32 size,
+                               size_t bytes) {
+  slow_copy_array(dst, src, size);
+}
+
+// If we don't specialize for Eigen::half, we get:
+// undefined behavior, destination object type 'Eigen::half'
+// is not TriviallyCopyable
+template <>
+void copy_array<Eigen::half, int64>(Eigen::half* dst, const Eigen::half* src,
+                                    int64 size, size_t bytes) {
+  slow_copy_array(dst, src, size);
+}
+
+template <>
+void copy_array<Eigen::half, int32>(Eigen::half* dst, const Eigen::half* src,
+                                    int32 size, size_t bytes) {
+  slow_copy_array(dst, src, size);
+}
+
+template <typename VALUE_TYPE, typename INDEX_TYPE>
+class RaggedTensorToTensorOp : public RaggedTensorToTensorBaseOp<INDEX_TYPE> {
+ public:
+  explicit RaggedTensorToTensorOp(OpKernelConstruction* context)
+      : RaggedTensorToTensorBaseOp<INDEX_TYPE>(context) {}
+
+  void SetOutput(OpKernelContext* context,
+                 const vector<INDEX_TYPE>& output_index,
+                 Tensor* output_tensor) override {
+    typename tensorflow::TTypes<VALUE_TYPE>::Flat output_flat =
+        output_tensor->flat<VALUE_TYPE>();
+    const auto& value_tensor = context->input(kValueInputIndex);
+    const auto& default_value_tensor = context->input(kDefaultValueInputIndex);
+    if (value_tensor.shape().dims() == 1) {
+      // Initialize tensor to default_value.
+      VALUE_TYPE* base_output = output_flat.data();
+      VALUE_TYPE default_value = default_value_tensor.scalar<VALUE_TYPE>()();
+
+      std::fill(base_output, base_output + output_flat.size(), default_value);
+      auto values = context->input(kValueInputIndex).flat<VALUE_TYPE>();
+      int values_size = values.size();
+      OP_REQUIRES(context, values_size == output_index.size(),
+                  Internal("Values and indices must be equal"));
+      for (int i = 0; i < values_size; ++i) {
+        if (output_index[i] >= 0) {
+          output_flat(output_index[i]) = values(i);
+        }
+      }
+    } else {
+      const auto& output_shape = output_tensor->shape();
+      const auto& default_value_shape = default_value_tensor.shape();
+
+      // Initialize tensor to default_value.
+
+      BCast bcast(BCast::FromShape(default_value_shape),
+                  BCast::FromShape(output_shape),
+                  /*fewer_dims_optimization=*/true);
+      OP_REQUIRES(
+          context, bcast.IsValid(),
+          errors::InvalidArgument(
+              "Incompatible shapes: ", default_value_shape.DebugString(),
+              " vs. ", default_value_shape.DebugString()));
+      OP_REQUIRES(
+          context, BCast::ToShape(bcast.output_shape()) == output_shape,
+          errors::InvalidArgument("Unable to broadcast default_value of shape ",
+                                  default_value_shape, " to tensor of shape ",
+                                  output_shape));
+      const CPUDevice& device = context->eigen_device<CPUDevice>();
+      functor::BroadcastTo<CPUDevice, VALUE_TYPE>()(
+          device, context, *output_tensor, output_shape, default_value_tensor,
+          default_value_shape, bcast);
+
+      VALUE_TYPE* base_output = output_flat.data();
+      auto values = context->input(kValueInputIndex).flat<VALUE_TYPE>();
+      size_t values_size = values.size();
+      size_t output_index_size = output_index.size();
+      //  A value "element" is a group of values that are arranged together.
+      // For example, if the value shape is [3,4,5], then 20 values are in a
+      // value element.
+      int value_element_size = values_size / output_index_size;
+      int value_element_bytesize = value_element_size * sizeof(VALUE_TYPE);
+      const VALUE_TYPE* values_base = values.data();
+
+      OP_REQUIRES(context,
+                  value_tensor.shape().dim_size(0) == output_index_size,
+                  Internal("Values and indices must be equal"));
+
+      OP_REQUIRES(context,
+                  values_size == output_index_size * value_element_size,
+                  Internal("Values and indices must be equal"));
+      INDEX_TYPE value_index = 0;
+      for (int i = 0; i < output_index_size;
+           ++i, value_index += value_element_size) {
+        if (output_index[i] >= 0) {
+          VALUE_TYPE* dst = base_output + output_index[i];
+          const VALUE_TYPE* src = values_base + value_index;
+          copy_array<VALUE_TYPE, INDEX_TYPE>(dst, src, value_element_size,
+                                             value_element_bytesize);
+        }
+      }
+    }
+  }
+};
+
+#define REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, index_type)       \
+  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToTensor")               \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<value_type>("T")       \
+                              .TypeConstraint<index_type>("Tindex"), \
+                          RaggedTensorToTensorOp<value_type, index_type>);
+
+#define REGISTER_CPU_KERNEL(value_type)                          \
+  REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, tensorflow::int64); \
+  REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, tensorflow::int32);
+
+TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
+TF_CALL_string(REGISTER_CPU_KERNEL);
+TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
+TF_CALL_quint16(REGISTER_CPU_KERNEL);
+TF_CALL_qint16(REGISTER_CPU_KERNEL);
+TF_CALL_uint32(REGISTER_CPU_KERNEL);
+TF_CALL_uint64(REGISTER_CPU_KERNEL);
+
+#undef REGISTER_CPU_KERNEL
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc
new file mode 100644
index 0000000..7337ebe
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc
@@ -0,0 +1,553 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+template <typename VALUE_TYPE>
+struct ShapeAndValues {
+  TensorShape shape;
+  std::vector<VALUE_TYPE> values;
+};
+
+template <typename VALUE_TYPE>
+ShapeAndValues<VALUE_TYPE> createVector(const std::vector<VALUE_TYPE>& values) {
+  TensorShape shape({static_cast<int64>(values.size())});
+  return {shape, values};
+}
+
+template <typename VALUE_TYPE>
+ShapeAndValues<VALUE_TYPE> createScalar(const VALUE_TYPE& values) {
+  TensorShape shape({});
+  return {shape, {values}};
+}
+
+class RaggedTensorToTensorOpTest : public ::tensorflow::OpsTestBase {
+ protected:
+  // Builds the tensorflow test graph for RaggedTensorToTensor.
+  template <typename VALUE_TYPE, typename INDEX_TYPE>
+  void BuildRaggedTensorToTensorGraph(
+      const TensorShape& shape, const std::vector<string>& row_partition_types,
+      const ShapeAndValues<VALUE_TYPE>& values,
+      const ShapeAndValues<VALUE_TYPE>& default_value,
+      const std::vector<ShapeAndValues<INDEX_TYPE>>& row_partition_tensors) {
+    const auto& value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+    const auto& index_dtype = DataTypeToEnum<INDEX_TYPE>::v();
+    int num_row_partition_tensors = row_partition_tensors.size();
+    TF_ASSERT_OK(
+        NodeDefBuilder("tested_op", "RaggedTensorToTensor")
+            .Attr("T", value_dtype)
+            .Attr("Tindex", index_dtype)
+            .Attr("num_row_partition_tensors", num_row_partition_tensors)
+            .Attr("row_partition_types", row_partition_types)
+            .Input(FakeInput(index_dtype))
+            .Input(FakeInput(value_dtype))  // values
+            .Input(FakeInput(value_dtype))  // default_value
+            .Input(FakeInput(num_row_partition_tensors,
+                             index_dtype))  // row_partition_tensors
+            .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+    {
+      std::vector<INDEX_TYPE> shape_as_vector;
+      for (const auto& dim : shape.dim_sizes()) {
+        shape_as_vector.push_back(dim);
+      }
+      ShapeAndValues<INDEX_TYPE> shape_as_tensor =
+          createVector(shape_as_vector);
+      AddInputFromArray<INDEX_TYPE>(shape_as_tensor.shape,
+                                    shape_as_tensor.values);
+    }
+    AddInputFromArray<VALUE_TYPE>(values.shape, values.values);
+    AddInputFromArray<VALUE_TYPE>(default_value.shape, default_value.values);
+
+    for (const auto& row_partition_tensor : row_partition_tensors) {
+      AddInputFromArray<INDEX_TYPE>(row_partition_tensor.shape,
+                                    row_partition_tensor.values);
+    }
+  }
+};
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor) {
+  // indices = [2, 1, 0, 3]
+  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+  // params.shape = [4, None]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({4, 4}),                 // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {createScalar<int32>(4), createVector<int32>({0, 0, 0, 2, 2, 2, 2, 3, 3})}
+      // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, .4, .5, .6,
+                             .7, .8, .9, 1.5, 1.5},
+                            TensorShape({4, 4})),
+      0.01);
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensorRowSplits) {
+  // indices = [2, 1, 0, 3]
+  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({4, 4}),  // shape
+      {"ROW_SPLITS"},       // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),               // default_value
+      {createVector<int32>({0, 3, 3, 7, 9})}  // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, .4, .5, .6,
+                             .7, .8, .9, 1.5, 1.5},
+                            TensorShape({4, 4})),
+      0.01);
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParams) {
+  // params = [
+  //           [[]],
+  //           [[.1, .2], [.3]],
+  //           [],
+  //           [[.4, .5], [.6, .7, .8]],
+  //           [[.9]]
+  //          ]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({5, 2, 3}),  // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS",
+       "VALUE_ROWIDS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {
+          createScalar<int32>(5),
+          createVector<int32>({0, 1, 1, 3, 3, 4}),
+          createVector<int32>({1, 1, 2, 3, 3, 4, 4, 4, 5}),
+      }  // row_partition_tensors
+  );
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Expected = [
+  //              [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5]],
+  //              [[.1, .2, 1.5], [.3, 1.5, 1.5]],
+  //              [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5]],
+  //              [[.4, .5, 1.5], [.6, .7, .8]],
+  //              [[.9, 1.5, 1.5], [1.5, 1.5, 1.5]]
+  //            ]
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1,  .2,  1.5, .3,
+                             1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .4,  .5,
+                             1.5, .6,  .7,  .8,  .9,  1.5, 1.5, 1.5, 1.5, 1.5},
+                            TensorShape({5, 2, 3})),
+      0.1);
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParamsRowSplits) {
+  // params = [
+  //           [[]],
+  //           [[.1, .2], [.3]],
+  //           [],
+  //           [[.4, .5], [.6, .7, .8]],
+  //           [[.9]]
+  //          ]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({5, 2, 3}),        // shape
+      {"ROW_SPLITS", "ROW_SPLITS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {
+          createVector<int32>({0, 1, 3, 3, 5, 6}),
+          createVector<int32>({0, 0, 2, 3, 5, 8, 9}),
+      }  // row_partition_tensors
+  );
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Expected = [
+  //              [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5]],
+  //              [[.1, .2, 1.5], [.3, 1.5, 1.5]],
+  //              [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5]],
+  //              [[.4, .5, 1.5], [.6, .7, .8]],
+  //              [[.9, 1.5, 1.5], [1.5, 1.5, 1.5]]
+  //            ]
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1,  .2,  1.5, .3,
+                             1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .4,  .5,
+                             1.5, .6,  .7,  .8,  .9,  1.5, 1.5, 1.5, 1.5, 1.5},
+                            TensorShape({5, 2, 3})),
+      0.1);
+}
+
+// test_three_dimensional_ragged fails, want to try it at a lower level.
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParamsRowSplits2) {
+  // params = [
+  //           [[0, 1, 2], []],
+  //           [],
+  //           [[3]]
+  //          ]
+  BuildRaggedTensorToTensorGraph<int64, int64>(
+      TensorShape({3, 2, 3}),             // shape
+      {"ROW_SPLITS", "ROW_SPLITS"},       // row_partition_types
+      createVector<int64>({0, 1, 2, 3}),  // values
+      createScalar<int64>(5),             // default_value
+      {
+          createVector<int64>({0, 2, 2, 3}),
+          createVector<int64>({0, 3, 3, 4}),
+      }  // row_partition_tensors
+  );
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Expected = [
+  //              [[0, 1, 2], [5, 5, 5]],
+  //              [[5, 5, 5], [5, 5, 5]],
+  //              [[3, 5, 5], [5, 5, 5]]
+  //            ]
+  test::ExpectTensorEqual<int64>(
+      *GetOutput(0), test::AsTensor<int64>(
+                         {0, 1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5, 5, 5},
+                         TensorShape({3, 2, 3})));
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_4DParams) {
+  // Input:    [[],
+  //            [
+  //             [[1, 2], [3, 4], [5, 6]],
+  //             [[7, 8]]
+  //            ],
+  //            [[]],
+  //            []
+  // ]
+  // params.shape = [3, 2, 3, 2]
+  BuildRaggedTensorToTensorGraph<int32, int32>(
+      TensorShape({4, 2, 3, 2}),  // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS", "VALUE_ROWIDS",
+       "VALUE_ROWIDS"},                               // row_partition_types
+      createVector<int32>({1, 2, 3, 4, 5, 6, 7, 8}),  // values
+      createScalar<int32>(15),                        // default_value
+      {createScalar<int32>(5), createVector<int32>({0, 1, 1}),
+       createVector<int32>({1, 1, 1, 2}),
+       createVector<int32>({0, 0, 1, 1, 2, 2, 3, 3})}  // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+  // params = [
+  //           [
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ],
+  //           [
+  //             [[1, 2], [3, 4], [5, 6]],
+  //             [[7, 8], [15, 15], [15,15]],
+  //           ],
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ],
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ]
+  // params.shape = [3, 2, 3, 2]
+  test::ExpectTensorEqual<int32>(
+      *GetOutput(0),
+      test::AsTensor<int32>(
+          {15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 1,  2,  3,  4,
+           5,  6,  7,  8,  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+           15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15},
+          TensorShape({4, 2, 3, 2})));
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_4DParamsRowSplit) {
+  // Input:    [[],
+  //            [
+  //             [[1, 2], [3, 4], [5, 6]],
+  //             [[7, 8]]
+  //            ],
+  //            [[]],
+  //            []
+  // ]
+  // params.shape = [3, 2, 3, 2]
+  BuildRaggedTensorToTensorGraph<int32, int32>(
+      TensorShape({4, 2, 3, 2}),  // shape
+      {"ROW_SPLITS", "ROW_SPLITS", "ROW_SPLITS"},
+      // row_partition_types
+      createVector<int32>({1, 2, 3, 4, 5, 6, 7, 8}),  // values
+      createScalar<int32>(15),                        // default_value
+      {createVector<int32>({0, 1, 3}), createVector<int32>({0, 0, 3, 4}),
+       createVector<int32>({0, 2, 4, 6, 8})}  // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+  // params = [
+  //           [
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ],
+  //           [
+  //             [[1, 2], [3, 4], [5, 6]],
+  //             [[7, 8], [15, 15], [15,15]],
+  //           ],
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ],
+  //             [[15,15],[15,15],[15,15]],
+  //             [[15,15],[15,15],[15,15]],
+  //           ]
+  // params.shape = [3, 2, 3, 2]
+  test::ExpectTensorEqual<int32>(
+      *GetOutput(0),
+      test::AsTensor<int32>(
+          {15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 1,  2,  3,  4,
+           5,  6,  7,  8,  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+           15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15},
+          TensorShape({4, 2, 3, 2})));
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensorContractExpanded) {
+  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({3, 5}),                 // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {createScalar<int32>(4), createVector<int32>({0, 0, 0, 2, 2, 2, 2, 3, 3})}
+      // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({.1, .2, .3, 1.5, 1.5,     //
+                             1.5, 1.5, 1.5, 1.5, 1.5,  //
+                             .4, .5, .6, .7, 1.5},     //
+                            TensorShape({3, 5})),
+      0.01);
+}
+
+// Adds a dense dimension.
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensorContractExpandedDense) {
+  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({3, 5, 2}),              // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS"},  // row_partition_types
+      ShapeAndValues<float>{TensorShape({9, 2}),
+                            {.1, 1.1, .2, 1.2, .3, 1.3, .4, 1.4, .5, 1.5, .6,
+                             1.6, .7, 1.7, .8, 1.8, .9, 1.9}},  // values
+      createScalar<float>(1.5),                                 // default_value
+      {createScalar<int32>(4), createVector<int32>({0, 0, 0, 2, 2, 2, 2, 3, 3})}
+      // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>(
+          {.1,  1.1, .2,  1.2, .3,  1.3, 1.5, 1.5, 1.5, 1.5,   //
+           1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,   //
+           .4,  1.4, .5,  1.5, .6,  1.6, .7,  1.7, 1.5, 1.5},  //
+          TensorShape({3, 5, 2})),
+      0.01);
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensorConstrained) {
+  // params = [[.1, .2, .3],
+  //           [],
+  //           [.4, .5, .6, .7],
+  //           [.8, .9]]
+  // constrained to (3, 3)
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({3, 3}),                 // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {createScalar<int32>(4), createVector<int32>({0, 0, 0, 2, 2, 2, 2, 3, 3})}
+      // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  test::ExpectTensorNear<float>(*GetOutput(0),
+                                test::AsTensor<float>(
+                                    {
+                                        //
+                                        .1, .2, .3,     //
+                                        1.5, 1.5, 1.5,  //
+                                        .4, .5, .6      //
+                                    },
+                                    TensorShape({3, 3})),
+                                0.01);
+}
+
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParamsConstrained) {
+  // params = [
+  //           [[]],
+  //           [[.1, .2], [.3]],
+  //           [],
+  //           [[.4, .5], [.6, .7, .8]],
+  //           [[.9]]
+  //          ]
+  // params.shape = [5, None, None]
+  BuildRaggedTensorToTensorGraph<float, int32>(
+      TensorShape({4, 1, 2}),  // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS",
+       "VALUE_ROWIDS"},  // row_partition_types
+      createVector<float>({.1, .2, .3, .4, .5, .6, .7, .8, .9}),  // values
+      createScalar<float>(1.5),  // default_value
+      {
+          createScalar<int32>(5),
+          createVector<int32>({0, 1, 1, 3, 3, 4}),
+          createVector<int32>({1, 1, 2, 3, 3, 4, 4, 4, 5}),
+      }  // row_partition_tensors
+  );
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Expected = [
+  //              [[1.5, 1.5]],
+  //              [[.1, .2]],
+  //              [[1.5, 1.5]],
+  //              [[.4, .5]],
+  //            ]
+  test::ExpectTensorNear<float>(
+      *GetOutput(0),
+      test::AsTensor<float>({1.5, 1.5, .1, .2, 1.5, 1.5, .4, .5},
+                            TensorShape({4, 1, 2})),
+      0.01);
+}
+
+// Seg fault but removing this does not make the problem go away.
+// This tests is labeled as flaky. Removing it to find out.
+TEST_F(RaggedTensorToTensorOpTest, RaggedTensorToTensor_4DParamsConstrained) {
+  // Input:    [[],
+  //            [
+  //             [[1, 2], [3, 4], [5, 6]],
+  //             [[7, 8]]
+  //            ],
+  //            [[]],
+  //            []
+  // ]
+  // params.shape = [3, 2, 3, 2]
+  BuildRaggedTensorToTensorGraph<int32, int32>(
+      TensorShape({2, 2, 2, 2}),  // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS", "VALUE_ROWIDS",
+       "VALUE_ROWIDS"},                               // row_partition_types
+      createVector<int32>({1, 2, 3, 4, 5, 6, 7, 8}),  // values
+      createScalar<int32>(15),                        // default_value
+      {createScalar<int32>(5), createVector<int32>({0, 1, 1}),
+       createVector<int32>({1, 1, 1, 2}),
+       createVector<int32>({0, 0, 1, 1, 2, 2, 3, 3})}  // row_partition_tensors
+  );
+
+  TF_ASSERT_OK(RunOpKernel());
+  // params = [
+  //           [
+  //             [[15,15],[15,15]],
+  //             [[15,15],[15,15]],
+  //           ],
+  //           [
+  //             [[1, 2], [3, 4]],
+  //             [[7, 8], [15, 15]],
+  //           ],
+  //          ]
+  // params.shape = [3, 2, 3, 2]
+  test::ExpectTensorEqual<int32>(*GetOutput(0), test::AsTensor<int32>(
+                                                    {
+                                                        15, 15, 15, 15,  //
+                                                        15, 15, 15, 15,  //
+                                                        1, 2, 3, 4,      //
+                                                        7, 8, 15, 15,    //
+                                                    },
+                                                    TensorShape({2, 2, 2, 2})));
+}
+
+TEST_F(RaggedTensorToTensorOpTest, ShapeWrongDimensions) {
+  BuildRaggedTensorToTensorGraph<int32, int32>(
+      TensorShape({10, 7, 10, 20}),  // shape
+      {"FIRST_DIM_SIZE", "VALUE_ROWIDS",
+       "VALUE_ROWIDS"},                   // row_partition_types
+      createVector<int32>({1, 2, 3, 4}),  // values
+      createScalar<int32>(15),            // default_value
+      {createScalar<int32>(5), createVector<int32>({0, 1, 1}),
+       createVector<int32>({1, 1, 1, 2})}  // row_partition_tensors
+  );
+  // Fails with an invalid argument.
+  EXPECT_EQ(RunOpKernel().code(), errors::Code::INVALID_ARGUMENT);
+}
+
+class RaggedTensorToTensorOpUnknownShapeTest
+    : public ::tensorflow::OpsTestBase {
+ protected:
+  std::unique_ptr<ShapeInferenceTestOp> op_;
+  void SetAttributes(const gtl::ArraySlice<string> row_partition_types,
+                     int num_row_partition_tensors) {
+    op_ = absl::make_unique<ShapeInferenceTestOp>("RaggedTensorToTensor");
+    SetAttrValue(row_partition_types,
+                 &((*op_->node_def.mutable_attr())["row_partition_types"]));
+    (*op_->node_def.mutable_attr())["num_row_partition_tensors"].set_i(
+        num_row_partition_tensors);
+  }
+};
+
+TEST_F(RaggedTensorToTensorOpUnknownShapeTest, ValueRowIDs) {
+  SetAttributes(gtl::ArraySlice<string>{"FIRST_DIM_SIZE", "VALUE_ROWIDS"}, 2);
+
+  INFER_OK(*op_, "?;?;?;?;?", "?");
+  INFER_OK(*op_, "?;[6];[];[];[6]", "[?,?]");
+  INFER_OK(*op_, "?;[6];?;[];[6]", "[?,?]");
+  INFER_OK(*op_, "?;?;[];[];[6]", "?");
+  INFER_OK(*op_, "?;[6];?;[];[6]", "[?,?]");
+  INFER_OK(*op_, "?;[6,2];?;[];[6]", "[?,?,2]");
+  INFER_OK(*op_, "?;[6,2];[2];[];[6]", "[?,?,2]");
+  INFER_OK(*op_, "?;[6,2,7];[2,7];[];[6]", "[?,?,2,7]");
+  INFER_ERROR("default_value_shape and value_shape do not match", *op_,
+              "?;[6,2];[3];[];[6]");
+  INFER_ERROR("default_value_shape and value_shape do not match", *op_,
+              "?;[6,2,1,2];[2,2];[];[6]");
+  INFER_ERROR("must be a vector", *op_, "?;[6];[];[];[3,6]");
+  INFER_ERROR("must be a scalar", *op_, "?;[6];[];[7];[3]");
+}
+
+TEST_F(RaggedTensorToTensorOpUnknownShapeTest, RowSplits) {
+  // RaggedTensorToTensor(param_splits+, param_values, indices) -> [splits+,
+  // values]
+  SetAttributes(gtl::ArraySlice<string>{"ROW_SPLITS"}, 1);
+
+  // value, default_value, ROW_SPLITS
+  INFER_OK(*op_, "?;?;?;?", "?");
+  INFER_OK(*op_, "?;[3];[];[6]", "[?,?]");
+  INFER_OK(*op_, "?;?;?;?", "?");
+  INFER_OK(*op_, "?;[3,2];[2];[6]", "[?,?,2]");
+  INFER_OK(*op_, "?;[3,2,7];[2,7];[6]", "[?,?,2,7]");
+  INFER_OK(*op_, "?;[3,2,7];[2,7];[6]", "[?,?,2,7]");
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc
index a002549..df27541 100644
--- a/tensorflow/core/kernels/random_binomial_op.cc
+++ b/tensorflow/core/kernels/random_binomial_op.cc
@@ -96,7 +96,7 @@
     return kTailValues[static_cast<int>(k)];
   }
   double kp1sq = (k + 1) * (k + 1);
-  return (1 / 12 - (1 / 360 + 1 / 1260 / kp1sq) / kp1sq) / (k + 1);
+  return (1.0 / 12 - (1.0 / 360 + 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
 }
 
 // We use a transformation-rejection algorithm from
diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc
index abd16de..d93197c 100644
--- a/tensorflow/core/kernels/reader_ops.cc
+++ b/tensorflow/core/kernels/reader_ops.cc
@@ -139,8 +139,8 @@
                    context->allocate_output(
                        "values", TensorShape({num_actually_read}), &values));
 
-    auto keys_t = keys->vec<string>();
-    auto values_t = values->vec<string>();
+    auto keys_t = keys->vec<tstring>();
+    auto values_t = values->vec<tstring>();
     for (int i = 0; i < num_actually_read; ++i) {
       keys_t(i) = std::move(keys_vec[i]);
       values_t(i) = std::move(values_vec[i]);
@@ -221,7 +221,7 @@
         context, TensorShapeUtils::IsScalar(tensor->shape()),
         errors::InvalidArgument("Reader state must be scalar, but had shape: ",
                                 tensor->shape().DebugString()));
-    OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<string>()()));
+    OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<tstring>()()));
   }
 };
 
diff --git a/tensorflow/core/kernels/reduce_join_op.cc b/tensorflow/core/kernels/reduce_join_op.cc
index 7a81dfd..562281e 100644
--- a/tensorflow/core/kernels/reduce_join_op.cc
+++ b/tensorflow/core/kernels/reduce_join_op.cc
@@ -122,7 +122,7 @@
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
-    const auto input_flat = input.flat<string>();
+    const auto input_flat = input.flat<tstring>();
     const TensorShape& input_shape = input.shape();
     const int32 input_dims = input_shape.dims();
 
@@ -156,7 +156,7 @@
         GetOutputShape(index_is_reduced, input_shape, keep_dims_);
     OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                      &output_tensor));
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
 
     const int64 reduction_iter_size =
         GetReductionIterSize(reduced_indices, input_shape);
diff --git a/tensorflow/core/kernels/reduction_ops_all.cc b/tensorflow/core/kernels/reduction_ops_all.cc
index 4a34c4ef..70ea87a 100644
--- a/tensorflow/core/kernels/reduction_ops_all.cc
+++ b/tensorflow/core/kernels/reduction_ops_all.cc
@@ -30,7 +30,7 @@
         .HostMemory("reduction_indices"),
     ReductionOp<CPUDevice, bool, int64, Eigen::internal::AndReducer>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER_KERNEL_BUILDER(
     Name("All")
         .TypeConstraint<int32>("Tidx")
diff --git a/tensorflow/core/kernels/reduction_ops_any.cc b/tensorflow/core/kernels/reduction_ops_any.cc
index 6c0519d..cd0ce28 100644
--- a/tensorflow/core/kernels/reduction_ops_any.cc
+++ b/tensorflow/core/kernels/reduction_ops_any.cc
@@ -30,7 +30,7 @@
         .HostMemory("reduction_indices"),
     ReductionOp<CPUDevice, bool, int64, Eigen::internal::OrReducer>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER_KERNEL_BUILDER(
     Name("Any")
         .TypeConstraint<int32>("Tidx")
diff --git a/tensorflow/core/kernels/reduction_ops_common_gpu.h b/tensorflow/core/kernels/reduction_ops_common_gpu.h
index 9af43f8..2415f1d 100644
--- a/tensorflow/core/kernels/reduction_ops_common_gpu.h
+++ b/tensorflow/core/kernels/reduction_ops_common_gpu.h
@@ -15,8 +15,8 @@
 #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
 
-#if !GOOGLE_CUDA
-#error This file must only be included when building with Cuda support
+#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
+#error This file must only be included when building with GPU support
 #endif
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
diff --git a/tensorflow/core/kernels/reduction_ops_euclidean.cc b/tensorflow/core/kernels/reduction_ops_euclidean.cc
index 9f4bf50..cf719e7 100644
--- a/tensorflow/core/kernels/reduction_ops_euclidean.cc
+++ b/tensorflow/core/kernels/reduction_ops_euclidean.cc
@@ -33,7 +33,7 @@
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                           \
   REGISTER_KERNEL_BUILDER(Name("EuclideanNorm")                              \
@@ -51,8 +51,10 @@
                           ReductionOp<GPUDevice, type, int64,                \
                                       functor::EuclideanNormReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#if GOOGLE_CUDA
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
+#endif
 #undef REGISTER_GPU_KERNELS
 
 #endif
diff --git a/tensorflow/core/kernels/reduction_ops_gpu_bool.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_bool.cu.cc
index 79ec1d5..89bcf1d 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu_bool.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu_bool.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -59,4 +59,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc
index c492308..c952c4c 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -67,4 +67,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc
index b006311..92f4b9d 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -67,4 +67,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc
index 91a33b9..c35d8c2 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -68,4 +68,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc b/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc
index f33d504..bbb34c9 100644
--- a/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -64,4 +64,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc b/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc
index 84fd389..d2a180b 100644
--- a/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -64,4 +64,4 @@
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 8bfa44b..fe9775f 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -33,7 +33,7 @@
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                             \
   REGISTER_KERNEL_BUILDER(                                                     \
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
index 67c974e..d314f19 100644
--- a/tensorflow/core/kernels/reduction_ops_mean.cc
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -33,7 +33,7 @@
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                      \
   REGISTER_KERNEL_BUILDER(                                              \
@@ -51,8 +51,10 @@
           .HostMemory("reduction_indices"),                             \
       ReductionOp<GPUDevice, type, int64, functor::MeanReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#if GOOGLE_CUDA
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
+#endif
 #undef REGISTER_GPU_KERNELS
 
 #endif
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index 5c537c5..9f1feae 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -33,7 +33,7 @@
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                             \
   REGISTER_KERNEL_BUILDER(                                                     \
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
index e9b23df..0642bad 100644
--- a/tensorflow/core/kernels/reduction_ops_prod.cc
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -33,7 +33,7 @@
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                          \
   REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
@@ -52,8 +52,10 @@
                                       Eigen::internal::ProdReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int32(REGISTER_GPU_KERNELS);
+#if GOOGLE_CUDA
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
+#endif
 #undef REGISTER_GPU_KERNELS
 
 #endif
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index cf0d0f5..d79684d 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -33,7 +33,7 @@
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNELS(type)                                             \
   REGISTER_KERNEL_BUILDER(                                                     \
@@ -52,8 +52,10 @@
       ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int64(REGISTER_GPU_KERNELS);
+#if GOOGLE_CUDA
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
+#endif
 #undef REGISTER_GPU_KERNELS
 
 // A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 7edaaad..04da969 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -31,14 +31,14 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     const Tensor* pattern_tensor;
     OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
                 errors::InvalidArgument("Pattern must be scalar, but received ",
                                         pattern_tensor->shape().DebugString()));
-    const string pattern = pattern_tensor->flat<string>()(0);
+    const string pattern = pattern_tensor->flat<tstring>()(0);
     const RE2 match(pattern);
     OP_REQUIRES(ctx, match.ok(),
                 errors::InvalidArgument("Invalid pattern: ", pattern,
@@ -71,7 +71,7 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
index a1b9488..187a4f9 100644
--- a/tensorflow/core/kernels/regex_replace_op.cc
+++ b/tensorflow/core/kernels/regex_replace_op.cc
@@ -44,9 +44,9 @@
   } else {
     TF_RETURN_IF_ERROR(
         ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
-    output_tensor->flat<string>() = input_tensor->flat<string>();
+    output_tensor->flat<tstring>() = input_tensor->flat<tstring>();
   }
-  auto output_flat = output_tensor->flat<string>();
+  auto output_flat = output_tensor->flat<tstring>();
   for (size_t i = 0; i < output_flat.size(); ++i) {
     if (replace_global) {
       RE2::GlobalReplace(&output_flat(i), match, rewrite);
@@ -70,7 +70,7 @@
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
                 errors::InvalidArgument("Pattern must be scalar, but received ",
                                         pattern_tensor->shape().DebugString()));
-    const string pattern = pattern_tensor->flat<string>()(0);
+    const string& pattern = pattern_tensor->scalar<tstring>()();
     const RE2 match(pattern);
     OP_REQUIRES(ctx, match.ok(),
                 errors::InvalidArgument("Invalid pattern: ", pattern,
@@ -81,7 +81,7 @@
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
                 errors::InvalidArgument("Rewrite must be scalar, but received ",
                                         rewrite_tensor->shape().DebugString()));
-    const string rewrite = rewrite_tensor->flat<string>()(0);
+    const string& rewrite = rewrite_tensor->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
   }
 
diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc
index 9691d4a..bfc45e8 100644
--- a/tensorflow/core/kernels/regex_replace_op_test.cc
+++ b/tensorflow/core/kernels/regex_replace_op_test.cc
@@ -60,7 +60,7 @@
 Tensor GetTestTensor(int batch) {
   const int sz = TF_ARRAYSIZE(lines);
   Tensor t(DT_STRING, {batch});
-  auto s = t.flat<string>();
+  auto s = t.flat<tstring>();
   for (int i = 0; i < batch; ++i) {
     s(i) = lines[i % sz];
   }
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index e67695d..83ef50a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -74,7 +74,7 @@
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
 #undef REGISTER_ELU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 #define DECLARE_GPU_SPEC(T)                                                    \
@@ -143,11 +143,14 @@
       typename TTypes<T>::Tensor backprops);                                   \
   extern template struct SeluGrad<GPUDevice, T>;
 
+#if GOOGLE_CUDA
+// TODO(rocm) : qint8 datatype currently not supported on the ROCm platform
 template <>
 void Relu<GPUDevice, qint8>::operator()(
     const GPUDevice& d, typename TTypes<qint8>::ConstTensor features,
     typename TTypes<qint8>::Tensor activations);
 extern template struct Relu<GPUDevice, qint8>;
+#endif  // GOOGLE_CUDA
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor
@@ -188,6 +191,7 @@
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 
+#if GOOGLE_CUDA
 template <typename Device>
 class ReluOp<Device, qint8>
     : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
@@ -210,6 +214,7 @@
     ReluOp<GPUDevice, qint8>);
 
 #endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 // Registration of the GPU implementations.
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 385565b..0dc14c5 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -31,6 +31,11 @@
 typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
+
+#if GOOGLE_CUDA
+// TODO(rocm): disabling this code on the ROCm platform since the references
+// to `half2` are leading to compile errors.
+
 // This kernel computes ReluGrad by processing one half2, two fp16, at a time.
 // It effectively does: backdrops = (feature > 0) ? gradient : 0
 // It also tries to use native half2 primitives as much as possible.
@@ -111,10 +116,12 @@
         d.stream(), gradient.data(), feature.data(), backprop.data(), count));
   }
 };
+#endif  // GOOGLE_CUDA
 
+#if GOOGLE_CUDA
 __global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
                                    int32* output) {
-  CUDA_1D_KERNEL_LOOP(index, vect_count) {
+  GPU_1D_KERNEL_LOOP(index, vect_count) {
     output[index] = __vmaxs4(input[index], 0);
   }
 }
@@ -141,6 +148,7 @@
         reinterpret_cast<int32*>(output.data())));
   }
 };
+#endif  // GOOGLE_CUDA
 
 }  // namespace functor
 
@@ -158,9 +166,10 @@
   template struct functor::SeluGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
-
+#if GOOGLE_CUDA
 template struct functor::Relu<GPUDevice, qint8>;
+#endif  // GOOGLE_CUDA
 
 }  // end namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 26f107f..5e01f4d 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -1356,7 +1356,7 @@
       dst_ptr = tensor->flat<int8>().data();
       break;
     case DT_STRING:
-      dst_ptr = tensor->flat<string>().data();
+      dst_ptr = tensor->flat<tstring>().data();
       break;
     case DT_INT64:
       dst_ptr = tensor->flat<int64>().data();
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 967d4a4..21d4b2a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -188,11 +188,11 @@
     }
   }
 
-  OP_REQUIRES(
-      ctx, uninitialized_vars.empty(),
-      errors::InvalidArgument("In ReadVariableOp the following variables were "
-                              "found uninitialized: ",
-                              absl::StrJoin(uninitialized_vars, ", ")));
+  OP_REQUIRES(ctx, uninitialized_vars.empty(),
+              errors::FailedPrecondition(
+                  "In ReadVariablesOp the following variables were "
+                  "found uninitialized: ",
+                  absl::StrJoin(uninitialized_vars, ", ")));
 
   for (size_t i = 0; i < dtypes_.size(); ++i) {
     // We're acquiring a reference to the underlying buffer while
diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc
index b6f15a9..1e6ca10 100644
--- a/tensorflow/core/kernels/restore_op_test.cc
+++ b/tensorflow/core/kernels/restore_op_test.cc
@@ -94,7 +94,7 @@
 
     // Input #0 is the file name
     Tensor input_0(DT_STRING, TensorShape({}));
-    input_0.scalar<string>()() = filename;
+    input_0.scalar<tstring>()() = filename;
     inputs.push_back({nullptr, &input_0});
 
     // Input #1 is the tensor names
@@ -203,7 +203,7 @@
   // The 1-d integer tensor
   {
     MakeRestoreOp(DT_INT32);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[1];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[1];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({10});
@@ -215,7 +215,7 @@
   // The 2-d float tensor
   {
     MakeRestoreOp(DT_FLOAT);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[2];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[2];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 4});
@@ -227,7 +227,7 @@
   // The 2-d double tensor
   {
     MakeRestoreOp(DT_DOUBLE);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[3];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[3];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 4});
@@ -239,7 +239,7 @@
   // The 2-d qint8 tensor
   {
     MakeRestoreOp(DT_QINT8);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[4];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[4];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({3, 2});
@@ -251,7 +251,7 @@
   // The 2-d qint32 tensor
   {
     MakeRestoreOp(DT_QINT32);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[5];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[5];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 3});
@@ -264,7 +264,7 @@
   // The 1-d uint8 tensor
   {
     MakeRestoreOp(DT_UINT8);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[6];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[6];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({11});
@@ -276,7 +276,7 @@
   // The 1-d int8 tensor
   {
     MakeRestoreOp(DT_INT8);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[7];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[7];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({7});
@@ -288,7 +288,7 @@
   // The 1-d int16 tensor
   {
     MakeRestoreOp(DT_INT16);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[8];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[8];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({7});
@@ -300,7 +300,7 @@
   // The 1-d int64 tensor
   {
     MakeRestoreOp(DT_INT64);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[9];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[9];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({9});
@@ -312,18 +312,18 @@
   // The 1-d string tensor
   {
     MakeRestoreOp(DT_STRING);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[10];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[10];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2});
     EXPECT_TRUE(output->shape().IsSameSize(expected));
-    EXPECT_EQ("no", output->flat<string>()(0));
-    EXPECT_EQ("yes", output->flat<string>()(1));
+    EXPECT_EQ("no", output->flat<tstring>()(0));
+    EXPECT_EQ("yes", output->flat<tstring>()(1));
   }
   // The 2-d complex64 tensor
   {
     MakeRestoreOp(DT_COMPLEX64);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[11];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[11];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 3});
@@ -335,7 +335,7 @@
   // The 2-d half tensor
   {
     MakeRestoreOp(DT_HALF);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[12];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[12];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 4});
@@ -348,7 +348,7 @@
   // The 2-d empty float tensor
   {
     MakeRestoreOp(DT_FLOAT);
-    (*mutable_input(1).tensor).scalar<string>()() = tensor_names[13];
+    (*mutable_input(1).tensor).scalar<tstring>()() = tensor_names[13];
     TF_ASSERT_OK(RunOpKernel());
     Tensor* output = GetOutput(0);
     TensorShape expected({2, 0});
@@ -398,12 +398,12 @@
 
     // Input #0 is the file name
     Tensor input_0(DT_STRING, TensorShape({}));
-    input_0.scalar<string>()() = filename;
+    input_0.scalar<tstring>()() = filename;
     inputs.push_back({nullptr, &input_0});
 
     // Input #1 is the tensor name
     Tensor input_1(DT_STRING, TensorShape({}));
-    input_1.scalar<string>()() = tensor_name;
+    input_1.scalar<tstring>()() = tensor_name;
     inputs.push_back({nullptr, &input_1});
 
     // Input #2 is a 4x16 integer tensor.
diff --git a/tensorflow/core/kernels/restore_v2_op_test.cc b/tensorflow/core/kernels/restore_v2_op_test.cc
index 3663157..22eb99d 100644
--- a/tensorflow/core/kernels/restore_v2_op_test.cc
+++ b/tensorflow/core/kernels/restore_v2_op_test.cc
@@ -105,7 +105,7 @@
 
       // Input #0 is the file name
       Tensor input_0(DT_STRING, TensorShape({}));
-      input_0.scalar<string>()() = filename;
+      input_0.scalar<tstring>()() = filename;
       inputs.push_back({nullptr, &input_0});
 
       // Input #1 is the tensor names
@@ -213,7 +213,7 @@
     // The 1-d integer tensor
     {
       MakeRestoreOp(DT_INT32);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[1];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[1];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({10});
@@ -225,7 +225,7 @@
     // The 2-d float tensor
     {
       MakeRestoreOp(DT_FLOAT);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[2];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[2];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({2, 4});
@@ -237,7 +237,7 @@
     // The 2-d double tensor
     {
       MakeRestoreOp(DT_DOUBLE);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[3];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[3];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({2, 4});
@@ -249,7 +249,7 @@
     // The 2-d qint8 tensor
     {
       MakeRestoreOp(DT_QINT8);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[4];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[4];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({3, 2});
@@ -261,7 +261,7 @@
     // The 2-d qint32 tensor
     {
       MakeRestoreOp(DT_QINT32);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[5];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[5];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({2, 3});
@@ -274,7 +274,7 @@
     // The 1-d uint8 tensor
     {
       MakeRestoreOp(DT_UINT8);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[6];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[6];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({11});
@@ -286,7 +286,7 @@
     // The 1-d int8 tensor
     {
       MakeRestoreOp(DT_INT8);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[7];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[7];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({7});
@@ -298,7 +298,7 @@
     // The 1-d int16 tensor
     {
       MakeRestoreOp(DT_INT16);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[8];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[8];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({7});
@@ -310,7 +310,7 @@
     // The 1-d int64 tensor
     {
       MakeRestoreOp(DT_INT64);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[9];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[9];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({9});
@@ -322,7 +322,7 @@
     // The 2-d complex64 tensor
     {
       MakeRestoreOp(DT_COMPLEX64);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[10];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[10];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({2, 3});
@@ -334,7 +334,7 @@
     // The 2-d half tensor
     {
       MakeRestoreOp(DT_HALF);
-      (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[11];
+      (*mutable_input(1).tensor).flat<tstring>()(0) = tensor_names[11];
       TF_ASSERT_OK(RunOpKernel());
       Tensor* output = GetOutput(0);
       TensorShape expected({2, 4});
diff --git a/tensorflow/core/kernels/rnn/BUILD b/tensorflow/core/kernels/rnn/BUILD
index b7fafe9..8096e21 100644
--- a/tensorflow/core/kernels/rnn/BUILD
+++ b/tensorflow/core/kernels/rnn/BUILD
@@ -7,7 +7,7 @@
     "tf_kernel_library",
 )
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load(
diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc
index 7e067b3..57d3e9b 100644
--- a/tensorflow/core/kernels/rnn/lstm_ops.cc
+++ b/tensorflow/core/kernels/rnn/lstm_ops.cc
@@ -41,7 +41,7 @@
 
 namespace functor {
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellFpropWithEigen(
     const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
     const float forget_bias, const float cell_clip, bool use_peephole,
@@ -52,7 +52,7 @@
     typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
     typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
     typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
-    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
+    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix gates,
     typename TTypes<T>::Matrix h) {
   // Concat xh = [x, h].
   xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
@@ -62,10 +62,10 @@
   typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
   TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
       ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
-      w, typename gemm_compute_type<T>::type(0.f), icfo);
+      w, typename gemm_compute_type<T>::type(0.f), gates);
   Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
   Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
-  icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
+  gates.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
 
   Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
   Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
@@ -74,26 +74,30 @@
   if (use_peephole) {
     auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
     i.device(d) =
-        (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
+        (gates.slice(cell.gates_i_offsets(), cell.cell_extents()) + i_peep)
             .sigmoid();
   } else {
     i.device(d) =
-        icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
+        gates.slice(cell.gates_i_offsets(), cell.cell_extents()).sigmoid();
   }
 
   // Cell input.
-  ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();
+  ci.device(d) =
+      gates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
+          .tanh();
 
   // Forget gate (w/ bias).
   if (use_peephole) {
     auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
-    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
-                   f.constant(T(forget_bias)) + f_peep)
-                      .sigmoid();
+    f.device(d) =
+        (gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
+         f.constant(T(forget_bias)) + f_peep)
+            .sigmoid();
   } else {
-    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
-                   f.constant(T(forget_bias)))
-                      .sigmoid();
+    f.device(d) =
+        (gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
+         f.constant(T(forget_bias)))
+            .sigmoid();
   }
 
   // cs = ci .* i + f .* cs_prev
@@ -111,18 +115,18 @@
   if (use_peephole) {
     auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
     o.device(d) =
-        (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
+        (gates.slice(cell.gates_o_offsets(), cell.cell_extents()) + o_peep)
             .sigmoid();
   } else {
     o.device(d) =
-        icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
+        gates.slice(cell.gates_o_offsets(), cell.cell_extents()).sigmoid();
   }
 
   // h = o .* co
   h.device(d) = o * co;
 }
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, GateLayout gate_layout>
 void LSTMBlockCellBpropWithEigen(
     const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
     bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -137,7 +141,7 @@
     typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
     typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
     typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
-    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
+    typename TTypes<T>::Matrix dgates, typename TTypes<T>::Matrix cs_prev_grad,
     typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
     typename TTypes<T>::Vec wco_grad) {
   // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
@@ -162,10 +166,12 @@
   // di[t] = sigm'(i[t]) dcs[t] ci[t]
   di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
 
-  dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
-  dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
-  dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
-  dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;
+  dgates.slice(cell.gates_i_offsets(), cell.cell_extents()).device(d) = di;
+  dgates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
+      .device(d) = dci;
+  dgates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents())
+      .device(d) = df;
+  dgates.slice(cell.gates_o_offsets(), cell.cell_extents()).device(d) = do_;
 
   cs_prev_grad.device(d) = dcs * f;
   if (use_peephole) {
@@ -178,10 +184,69 @@
   }
 }
 
-#define DEFINE_CPU_SPECS(T)                                                   \
+#define DECLARE_CPU_FBPROP(T, GATE_LAYOUT)                                     \
+  template <>                                                                  \
+  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
+  operator()(                                                                  \
+      OpKernelContext* ctx, const CPUDevice& d, const float forget_bias,       \
+      const float cell_clip, bool use_peephole,                                \
+      typename TTypes<T>::ConstMatrix x,                                       \
+      typename TTypes<T>::ConstMatrix cs_prev,                                 \
+      typename TTypes<T>::ConstMatrix h_prev,                                  \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
+      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
+      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
+      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
+      typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) {        \
+    LSTMBlockCellFpropWithEigen<T, GATE_LAYOUT>(                               \
+        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,       \
+        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h);       \
+  }                                                                            \
+  template <>                                                                  \
+  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
+  operator()(                                                                  \
+      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,             \
+      typename TTypes<T>::ConstMatrix x,                                       \
+      typename TTypes<T>::ConstMatrix cs_prev,                                 \
+      typename TTypes<T>::ConstMatrix h_prev,                                  \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
+      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
+      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
+      typename TTypes<T>::ConstMatrix co,                                      \
+      typename TTypes<T>::ConstMatrix cs_grad,                                 \
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
+      typename TTypes<T>::Matrix dgates,                                       \
+      typename TTypes<T>::Matrix cs_prev_grad,                                 \
+      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
+      typename TTypes<T>::Vec wco_grad) {                                      \
+    LSTMBlockCellBpropWithEigen<CPUDevice, T, GATE_LAYOUT>(                    \
+        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b,  \
+        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates,   \
+        cs_prev_grad, wci_grad, wcf_grad, wco_grad);                           \
+  }                                                                            \
+  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */,     \
+                                     GATE_LAYOUT>;                             \
+  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */,     \
+                                     GATE_LAYOUT>;
+
+#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);
+
+DECLARE_CPU_SPECS(Eigen::half);
+DECLARE_CPU_SPECS(float);
+#undef DECLARE_CPU_SPECS
+#undef DECLARE_CPU_FBPROP
+
+#if GOOGLE_CUDA
+#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT)                                    \
   template <>                                                                 \
-  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
-      OpKernelContext* ctx, const CPUDevice& d, const float forget_bias,      \
+  void LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>::operator()(       \
+      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,      \
       const float cell_clip, bool use_peephole,                               \
       typename TTypes<T>::ConstMatrix x,                                      \
       typename TTypes<T>::ConstMatrix cs_prev,                                \
@@ -192,14 +257,10 @@
       typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,            \
       typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,             \
       typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,           \
-      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {        \
-    LSTMBlockCellFpropWithEigen<T>(                                           \
-        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,      \
-        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);       \
-  }                                                                           \
+      typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h);        \
   template <>                                                                 \
-  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
-      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,            \
+  void LSTMBlockCellBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()(       \
+      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
       typename TTypes<T>::ConstMatrix x,                                      \
       typename TTypes<T>::ConstMatrix cs_prev,                                \
       typename TTypes<T>::ConstMatrix h_prev,                                 \
@@ -213,25 +274,25 @@
       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
-      typename TTypes<T>::Matrix dicfo,                                       \
+      typename TTypes<T>::Matrix dgates,                                      \
       typename TTypes<T>::Matrix cs_prev_grad,                                \
       typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
-      typename TTypes<T>::Vec wco_grad) {                                     \
-    LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(        \
-        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
-        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,   \
-        cs_prev_grad, wci_grad, wcf_grad, wco_grad);                          \
-  }                                                                           \
-  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;   \
-  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
+      typename TTypes<T>::Vec wco_grad);                                      \
+                                                                              \
+  extern template struct LSTMBlockCellBprop<                                  \
+      GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>;                      \
+  extern template struct LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>;
 
-DEFINE_CPU_SPECS(float);
-DEFINE_CPU_SPECS(Eigen::half);
-#undef DEFINE_CPU_SPECS
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T, ICFO);
 
+DECLARE_GPU_SPECS(float);
+DECLARE_GPU_SPECS(Eigen::half);
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_FBROP
+#endif  // GOOGLE_CUDA
 }  // namespace functor
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class LSTMBlockCellOp : public OpKernel {
  public:
   explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -345,23 +406,24 @@
                             TensorShape({batch_size, input_size + cell_size}),
                             &xh_tensor));
 
-    Tensor icfo_tensor;
+    Tensor gates_tensor;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                       TensorShape({batch_size, cell_size * 4}),
-                                      &icfo_tensor));
+                                      &gates_tensor));
 
     const Device& device = ctx->eigen_device<Device>();
 
-    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-                                                       cell_size)(
+    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
+        batch_size, input_size, cell_size)(
         ctx, device, forget_bias_, cell_clip_, use_peephole_,
         x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
         h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
         wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
         xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
         f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
-        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
+        co_tensor->matrix<T>(), gates_tensor.matrix<T>(),
+        h_tensor->matrix<T>());
   }
 
  private:
@@ -373,48 +435,24 @@
 #define REGISTER_KERNEL(T)                                             \
   REGISTER_KERNEL_BUILDER(                                             \
       Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      LSTMBlockCellOp<CPUDevice, T, false>);
-REGISTER_KERNEL(float);
+      LSTMBlockCellOp<CPUDevice, T, false, ICFO>);
+
 REGISTER_KERNEL(Eigen::half);
+REGISTER_KERNEL(float);
 #undef REGISTER_KERNEL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                \
-  template <>                                                              \
-  void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                 \
-      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,   \
-      const float cell_clip, bool use_peephole,                            \
-      typename TTypes<T>::ConstMatrix x,                                   \
-      typename TTypes<T>::ConstMatrix cs_prev,                             \
-      typename TTypes<T>::ConstMatrix h_prev,                              \
-      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
-      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,  \
-      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,       \
-      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,         \
-      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,          \
-      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,        \
-      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);      \
-                                                                           \
-  extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
-
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-#undef DECLARE_GPU_SPEC
-}  // end namespace functor
-
 #define REGISTER_GPU_KERNEL(T)                                         \
   REGISTER_KERNEL_BUILDER(                                             \
       Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
-      LSTMBlockCellOp<GPUDevice, T, true>);
+      LSTMBlockCellOp<GPUDevice, T, true, ICFO>);
 
-REGISTER_GPU_KERNEL(float);
 REGISTER_GPU_KERNEL(Eigen::half);
-// REGISTER_GPU_KERNEL(double);
+REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class LSTMBlockCellGradOp : public OpKernel {
  public:
   explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -586,10 +624,10 @@
                  {"cs_grad"}, "cs_prev_grad",
                  TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));
 
-    Tensor* dicfo_tensor = nullptr;
+    Tensor* dgates_tensor = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(
                             "dicfo", TensorShape({batch_size, cell_size * 4}),
-                            &dicfo_tensor));
+                            &dgates_tensor));
 
     Tensor* wci_grad_tensor = nullptr;
     OP_REQUIRES_OK(
@@ -638,8 +676,8 @@
     functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
     functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
 
-    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-                                                       cell_size)(
+    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS, gate_layout>(
+        batch_size, input_size, cell_size)(
         ctx, device, use_peephole_, x_tensor->matrix<T>(),
         cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
         w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -648,9 +686,10 @@
         ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
         cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
         do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
-        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
-        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
-        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
+        df_tensor.matrix<T>(), di_tensor.matrix<T>(),
+        dgates_tensor->matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
+        wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
+        wco_grad_tensor->vec<T>());
   }
 
  protected:
@@ -660,52 +699,19 @@
 #define REGISTER_KERNEL(T)                                                 \
   REGISTER_KERNEL_BUILDER(                                                 \
       Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      LSTMBlockCellGradOp<CPUDevice, T, false>);
+      LSTMBlockCellGradOp<CPUDevice, T, false, ICFO>);
 REGISTER_KERNEL(float);
 REGISTER_KERNEL(Eigen::half);
 #undef REGISTER_KERNEL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                   \
-  template <>                                                                 \
-  void LSTMBlockCellBprop<GPUDevice, T, true>::operator()(                    \
-      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
-      typename TTypes<T>::ConstMatrix x,                                      \
-      typename TTypes<T>::ConstMatrix cs_prev,                                \
-      typename TTypes<T>::ConstMatrix h_prev,                                 \
-      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
-      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
-      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
-      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
-      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
-      typename TTypes<T>::ConstMatrix co,                                     \
-      typename TTypes<T>::ConstMatrix cs_grad,                                \
-      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
-      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
-      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
-      typename TTypes<T>::Matrix dicfo,                                       \
-      typename TTypes<T>::Matrix cs_prev_grad,                                \
-      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
-      typename TTypes<T>::Vec wco_grad);                                      \
-                                                                              \
-  extern template struct LSTMBlockCellBprop<GPUDevice, T,                     \
-                                            true /* USE_CUBLAS */>;
-
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-// DECLARE_GPU_SPEC(double);
-#undef DECLARE_GPU_SPEC
-}  // namespace functor
-
 #define REGISTER_GPU_KERNEL(T)                                             \
   REGISTER_KERNEL_BUILDER(                                                 \
       Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
-      LSTMBlockCellGradOp<GPUDevice, T, true>);
+      LSTMBlockCellGradOp<GPUDevice, T, true, ICFO>);
 
-REGISTER_GPU_KERNEL(float);
 REGISTER_GPU_KERNEL(Eigen::half);
-// REGISTER_GPU_KERNEL(double);
+REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
@@ -817,7 +823,7 @@
 
 }  // namespace
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class BlockLSTMOp : public OpKernel {
  public:
   explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -948,11 +954,11 @@
                             TensorShape({batch_size, input_size + cell_size}),
                             &xh_tensor));
 
-    Tensor icfo_tensor;
+    Tensor gates_tensor;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                       TensorShape({batch_size, cell_size * 4}),
-                                      &icfo_tensor));
+                                      &gates_tensor));
 
     const Device& device = ctx->eigen_device<Device>();
 
@@ -974,16 +980,16 @@
       Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
       Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
 
-      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-                                                         cell_size)(
+      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
+          batch_size, input_size, cell_size)(
           ctx, device, forget_bias_, cell_clip_, use_peephole_,
           x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
           h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
           wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
           b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
           cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
-          ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
-          h_tensor.matrix<T>());
+          ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
+          gates_tensor.matrix<T>(), h_tensor.matrix<T>());
       slicer.FinishTimeStep();
     }
 
@@ -1007,14 +1013,15 @@
 #define REGISTER_KERNEL(T)                                         \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      BlockLSTMOp<CPUDevice, T, false>);
-REGISTER_KERNEL(float);
+      BlockLSTMOp<CPUDevice, T, false, ICFO>);
+
 REGISTER_KERNEL(Eigen::half);
+REGISTER_KERNEL(float);
 #undef REGISTER_KERNEL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                              \
+#define DECLARE_GPU_SPECS(T)                                             \
   template <>                                                            \
   void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,          \
                                             typename TTypes<T>::Flat t); \
@@ -1027,10 +1034,9 @@
                                                                          \
   extern template struct TensorUnalignedZero<GPUDevice, T>;
 
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-// DECLARE_GPU_SPEC(double);
-#undef DECLARE_GPU_SPEC
+DECLARE_GPU_SPECS(Eigen::half);
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPECS
 }  // end namespace functor
 
 #define REGISTER_GPU_KERNEL(T)                           \
@@ -1038,15 +1044,14 @@
                               .Device(DEVICE_GPU)        \
                               .HostMemory("seq_len_max") \
                               .TypeConstraint<T>("T"),   \
-                          BlockLSTMOp<GPUDevice, T, true>);
+                          BlockLSTMOp<GPUDevice, T, true, ICFO>);
 
-REGISTER_GPU_KERNEL(float);
 REGISTER_GPU_KERNEL(Eigen::half);
-// REGISTER_GPU_KERNEL(double);
+REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class BlockLSTMGradOp : public OpKernel {
  public:
   explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1188,11 +1193,11 @@
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                            batch_cell_shape, &di_tensor));
 
-    Tensor dicfo_tensor;
+    Tensor dgates_tensor;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                       TensorShape({batch_size, cell_size * 4}),
-                                      &dicfo_tensor));
+                                      &dgates_tensor));
 
     Tensor cs_grad_tensor;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
@@ -1249,8 +1254,8 @@
       const Tensor& const_h_grad_tensor = h_grad_tensor;
 
       Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
-      functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-                                                     cell_size)(
+      functor::BlockLSTMBprop<Device, T, USE_CUBLAS, gate_layout>(
+          batch_size, input_size, cell_size)(
           ctx, device, use_peephole_, x_tensor.matrix<T>(),
           cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
           w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -1260,7 +1265,7 @@
           const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
           do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
           df_tensor.matrix<T>(), di_tensor.matrix<T>(),
-          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
+          dgates_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
           h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
           x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
           wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
@@ -1282,14 +1287,41 @@
 #define REGISTER_KERNEL(T)                                             \
   REGISTER_KERNEL_BUILDER(                                             \
       Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      BlockLSTMGradOp<CPUDevice, T, false>);
-REGISTER_KERNEL(float);
+      BlockLSTMGradOp<CPUDevice, T, false, ICFO>);
+
 REGISTER_KERNEL(Eigen::half);
+REGISTER_KERNEL(float);
 #undef REGISTER_KERNEL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                    \
+#define DECLARE_GPU_BPROP(T, GATE_LAYOUT)                                     \
+  template <>                                                                 \
+  void BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()(           \
+      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
+      typename TTypes<T>::ConstMatrix x,                                      \
+      typename TTypes<T>::ConstMatrix cs_prev,                                \
+      typename TTypes<T>::ConstMatrix h_prev,                                 \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,          \
+      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,  \
+      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,   \
+      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co, \
+      typename TTypes<T>::ConstMatrix cs_grad,                                \
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
+      typename TTypes<T>::Matrix dgates,                                      \
+      typename TTypes<T>::Matrix cs_prev_grad,                                \
+      typename TTypes<T>::Matrix h_prev_grad,                                 \
+      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,  \
+      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,    \
+      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,     \
+      typename TTypes<T>::Vec b_grad);                                        \
+  extern template struct BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>;
+
+#define DECLARE_GPU_SPECS(T)                                                   \
   template <>                                                                  \
   void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
                                             typename TTypes<T>::ConstFlat src, \
@@ -1310,38 +1342,15 @@
       const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
       typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
                                                                                \
-  template <>                                                                  \
-  void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
-      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
-      typename TTypes<T>::ConstMatrix x,                                       \
-      typename TTypes<T>::ConstMatrix cs_prev,                                 \
-      typename TTypes<T>::ConstMatrix h_prev,                                  \
-      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
-      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
-      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
-      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
-      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
-      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
-      typename TTypes<T>::ConstMatrix cs_grad,                                 \
-      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
-      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
-      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
-      typename TTypes<T>::Matrix dicfo,                                        \
-      typename TTypes<T>::Matrix cs_prev_grad,                                 \
-      typename TTypes<T>::Matrix h_prev_grad,                                  \
-      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
-      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
-      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
-      typename TTypes<T>::Vec b_grad);                                         \
-                                                                               \
   extern template struct TensorCopy<GPUDevice, T>;                             \
   extern template struct TensorAdd<GPUDevice, T>;                              \
-  extern template struct BlockLSTMBprop<GPUDevice, T, true>;
+                                                                               \
+  DECLARE_GPU_BPROP(T, ICFO);
 
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-// DECLARE_GPU_SPEC(double);
-#undef DECLARE_GPU_SPEC
+DECLARE_GPU_SPECS(Eigen::half);
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_BPROP
 }  // end namespace functor
 
 #define REGISTER_GPU_KERNEL(T)                           \
@@ -1349,11 +1358,10 @@
                               .Device(DEVICE_GPU)        \
                               .HostMemory("seq_len_max") \
                               .TypeConstraint<T>("T"),   \
-                          BlockLSTMGradOp<GPUDevice, T, true>);
+                          BlockLSTMGradOp<GPUDevice, T, true, ICFO>);
 
-REGISTER_GPU_KERNEL(float);
 REGISTER_GPU_KERNEL(Eigen::half);
-// REGISTER_GPU_KERNEL(double);
+REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
diff --git a/tensorflow/core/kernels/rnn/lstm_ops.h b/tensorflow/core/kernels/rnn/lstm_ops.h
index 8885d7c..834a923 100644
--- a/tensorflow/core/kernels/rnn/lstm_ops.h
+++ b/tensorflow/core/kernels/rnn/lstm_ops.h
@@ -25,6 +25,16 @@
 namespace tensorflow {
 class OpKernelContext;
 
+enum GateLayout { ICFO, IFCO };
+
+constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
+  return (gate_layout == ICFO) ? cell_size : cell_size * 2;
+}
+
+constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
+  return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
+}
+
 namespace functor {
 
 template <typename Device, typename T>
@@ -103,19 +113,21 @@
 
   int cell_size() const { return cell_size_; }
 
-  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
+  inline Eigen::array<Eigen::DenseIndex, 2> gates_i_offsets() const {
     return {0, 0};
   }
 
-  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
-    return {0, cell_size_};
+  inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets(
+      const GateLayout gate_layout) const {
+    return {0, gate_c_offset(gate_layout, cell_size_)};
   }
 
-  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
-    return {0, cell_size_ * 2};
+  inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets(
+      const GateLayout gate_layout) const {
+    return {0, gate_f_offset(gate_layout, cell_size_)};
   }
 
-  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
+  inline Eigen::array<Eigen::DenseIndex, 2> gates_o_offsets() const {
     return {0, cell_size_ * 3};
   }
 
@@ -147,7 +159,7 @@
 
 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
 // GPUDevice implementation.
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct LSTMBlockCellFprop : public LSTMBlockCell {
   LSTMBlockCellFprop(const int batch_size, const int input_size,
                      const int cell_size)
@@ -166,13 +178,13 @@
                   typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
                   typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
                   typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
-                  typename TTypes<T>::Matrix icfo,
+                  typename TTypes<T>::Matrix gates,
                   typename TTypes<T>::Matrix h);
 };
 
 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
 // GPUDevice implementation.
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct LSTMBlockCellBprop : public LSTMBlockCell {
   LSTMBlockCellBprop(const int batch_size, const int input_size,
                      const int cell_size)
@@ -192,12 +204,12 @@
       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
-      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
-      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
-      typename TTypes<T>::Vec wco_grad);
+      typename TTypes<T>::Matrix dgates,
+      typename TTypes<T>::Matrix cs_prev_grad, typename TTypes<T>::Vec wci_grad,
+      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad);
 };
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct BlockLSTMBprop : public LSTMBlockCell {
   BlockLSTMBprop(const int batch_size, const int input_size,
                  const int cell_size)
@@ -218,7 +230,8 @@
       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
-      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
+      typename TTypes<T>::Matrix dgates,
+      typename TTypes<T>::Matrix cs_prev_grad,
       typename TTypes<T>::Matrix h_prev_grad,
       typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
       typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
@@ -246,10 +259,10 @@
     // di[t] = sigm'(i[t]) dcs[t] ci[t]
     di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
 
-    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
-    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
-    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
-    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;
+    dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di;
+    dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci;
+    dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df;
+    dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_;
 
     cs_prev_grad.device(d) = dcs * f;
     if (use_peephole) {
@@ -260,10 +273,10 @@
     }
 
     // xh_grad.
-    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
-                                                dicfo.dimensions());
+    typename TTypes<T>::ConstMatrix const_dgates(dgates.data(),
+                                                 dgates.dimensions());
     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
-        ctx, d, false, true, 1.f, const_dicfo, w, 0.f, xh_grad);
+        ctx, d, false, true, 1.f, const_dgates, w, 0.f, xh_grad);
 
     // xh.
     xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
@@ -276,10 +289,10 @@
 
     // w_grad.
     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
-        ctx, d, true, false, 1.f, const_xh, const_dicfo, 1.f, w_grad);
+        ctx, d, true, false, 1.f, const_xh, const_dgates, 1.f, w_grad);
 
     // b_grad.
-    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));
+    b_grad.device(d) += dgates.sum(Eigen::array<int, 1>({0}));
 
     if (use_peephole) {
       wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
index 256591a..f3f3785 100644
--- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
@@ -81,8 +81,8 @@
 // Launch with blocks of (batch x 32)
 //
 // TODO(b/67600500): Try making 'use_peephole' a template parameter.
-template <typename T, bool use_peephole>
-__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
+template <typename T, bool use_peephole, GateLayout gate_layout>
+__global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
                            const T* wci, const T* wcf, const T* wco, T* o, T* h,
                            T* ci, T* cs, T* co, T* i, T* f,
                            const float forget_bias, const float cell_clip,
@@ -98,7 +98,7 @@
   // The following code assumes the input arrays are of the following
   // shapes and interpretations.
   //
-  // 1) 'icfo' is a matrix such that,
+  // 1) 'gates' is a matrix such that,
   //
   //   cell_size  cell_size  cell_size  cell_size
   //  +----------+----------+----------+----------+
@@ -107,7 +107,8 @@
   //  |          |          |          |          |
   //  +----------+----------+----------+----------+
   //
-  // 'gid' is the index assigned to this thread for 'icfo' in the 'i' submatrix.
+  // 'gid' is the index assigned to this thread for 'gates' in the 'i'
+  // submatrix.
   //
   // 2) 'b' is a vector such that,
   //
@@ -146,23 +147,27 @@
 
   T i_local;
   if (use_peephole) {
-    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] +
-                         cs_prev[cid] * wci[act_id]);
+    i_local =
+        sigmoid_op(gates[0 * cell_size + gid] + b[0 * cell_size + act_id] +
+                   cs_prev[cid] * wci[act_id]);
   } else {
-    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]);
+    i_local =
+        sigmoid_op(gates[0 * cell_size + gid] + b[0 * cell_size + act_id]);
   }
   i[cid] = i_local;
 
-  const T ci_local =
-      tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]);
+  const int c_offset = gate_c_offset(gate_layout, cell_size);
+  const int f_offset = gate_f_offset(gate_layout, cell_size);
+
+  const T ci_local = tanh_op(gates[c_offset + gid] + b[c_offset + act_id]);
   ci[cid] = ci_local;
 
   T f_local;
   if (use_peephole) {
-    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
+    f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
                          forget_bias_t + cs_prev[cid] * wcf[act_id]);
   } else {
-    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
+    f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
                          forget_bias_t);
   }
   f[cid] = f_local;
@@ -178,10 +183,11 @@
 
   T o_local;
   if (use_peephole) {
-    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] +
-                         cs_local * wco[act_id]);
+    o_local = sigmoid_op(gates[3 * cell_size + gid] +
+                         b[3 * cell_size + act_id] + cs_local * wco[act_id]);
   } else {
-    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]);
+    o_local =
+        sigmoid_op(gates[3 * cell_size + gid] + b[3 * cell_size + act_id]);
   }
   o[cid] = o_local;
 
@@ -217,7 +223,7 @@
   }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellFpropWithCUDA(
     OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
     const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -228,7 +234,7 @@
     typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
     typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
     typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
-    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
+    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix gates,
     typename TTypes<T>::Matrix h, int batch_size, int cell_size,
     int input_size) {
   const auto& cu_stream = GetGpuStream(ctx);
@@ -249,7 +255,7 @@
   typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
   TensorBlasGemm<GPUDevice, T, true /* USE_CUBLAS */>::compute(
       ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
-      w, typename gemm_compute_type<T>::type(0.f), icfo);
+      w, typename gemm_compute_type<T>::type(0.f), gates);
 
   // Add bias, apply non-linearities and gating.
   //
@@ -262,20 +268,22 @@
 
   if (use_peephole) {
     TF_CHECK_OK(GpuLaunchKernel(
-        lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
-        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
-        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
-        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
+        lstm_gates<T, true, gate_layout>, grid_dim_2d, block_dim_2d, 0,
+        cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
+        wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
+        co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
+        cell_size));
   } else {
     TF_CHECK_OK(GpuLaunchKernel(
-        lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
-        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
-        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
-        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
+        lstm_gates<T, false, gate_layout>, grid_dim_2d, block_dim_2d, 0,
+        cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
+        wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
+        co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
+        cell_size));
   }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 __global__ void lstm_gates_bprop(
     const T* cs_prev,  // [batch_size, cell_size]
     const T* h_prev,   // [batch_size, cell_size]
@@ -297,7 +305,7 @@
     T* dci,            // [batch_size, cell_size]
     T* df,             // [batch_size, cell_size]
     T* di,             // [batch_size, cell_size]
-    T* dicfo,          // [input_size + cell_size, 4 * cell_size]
+    T* dgates,         // [input_size + cell_size, 4 * cell_size]
     T* cs_prev_grad,   // [batch_size, cell_size]
     const int batch_size, const int cell_size, const bool use_peephole) {
   const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -341,10 +349,10 @@
   const T di_local = i_local * (one - i_local) * dcs_local * ci_local;
   di[cid] = di_local;
 
-  dicfo[gid + 0 * cell_size] = di_local;
-  dicfo[gid + 1 * cell_size] = dci_local;
-  dicfo[gid + 2 * cell_size] = df_local;
-  dicfo[gid + 3 * cell_size] = do_local;
+  dgates[gid + 0 * cell_size] = di_local;
+  dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
+  dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
+  dgates[gid + 3 * cell_size] = do_local;
 
   cs_prev_grad[cid] = dcs_local * f_local;
   if (use_peephole) {
@@ -352,7 +360,7 @@
   }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellBpropWithCUDA(
     OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
     typename TTypes<T>::ConstMatrix cs_prev,
@@ -366,7 +374,7 @@
     typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
     typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
     typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
-    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
+    typename TTypes<T>::Matrix dgates, typename TTypes<T>::Matrix cs_prev_grad,
     typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
     typename TTypes<T>::Vec wco_grad, const int batch_size, const int cell_size,
     const bool use_peephole) {
@@ -377,11 +385,11 @@
                    Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));
 
   TF_CHECK_OK(GpuLaunchKernel(
-      lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
+      lstm_gates_bprop<T, gate_layout>, grid_dim_2d, block_dim_2d, 0, cu_stream,
       cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
       wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
       co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
-      dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(),
+      dci.data(), df.data(), di.data(), dgates.data(), cs_prev_grad.data(),
       batch_size, cell_size, use_peephole));
 
   if (use_peephole) {
@@ -398,66 +406,74 @@
 
 }  // namespace
 
-#define DEFINE_GPU_SPECS(T)                                                    \
-  template struct TensorZero<GPUDevice, T>;                                    \
-  template struct TensorUnalignedZero<GPUDevice, T>;                           \
-  template struct TensorCopy<GPUDevice, T>;                                    \
-  template struct TensorCopyUnaligned<GPUDevice, T>;                           \
-  template struct TensorCopyToUnaligned<GPUDevice, T>;                         \
-  template struct TensorAdd<GPUDevice, T>;                                     \
-  template <>                                                                  \
-  void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
-      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,       \
-      const float cell_clip, bool use_peephole,                                \
-      typename TTypes<T>::ConstMatrix x,                                       \
-      typename TTypes<T>::ConstMatrix cs_prev,                                 \
-      typename TTypes<T>::ConstMatrix h_prev,                                  \
-      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
-      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
-      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
-      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
-      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
-      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
-      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {         \
-    LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip,              \
-                                  use_peephole, x, cs_prev, h_prev, w, wci,    \
-                                  wcf, wco, b, xh, i, cs, f, o, ci, co, icfo,  \
-                                  h, batch_size_, cell_size_, input_size_);    \
-  }                                                                            \
-  template <>                                                                  \
-  void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
-      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
-      typename TTypes<T>::ConstMatrix x,                                       \
-      typename TTypes<T>::ConstMatrix cs_prev,                                 \
-      typename TTypes<T>::ConstMatrix h_prev,                                  \
-      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
-      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
-      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
-      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
-      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
-      typename TTypes<T>::ConstMatrix co,                                      \
-      typename TTypes<T>::ConstMatrix cs_grad,                                 \
-      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
-      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
-      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
-      typename TTypes<T>::Matrix dicfo,                                        \
-      typename TTypes<T>::Matrix cs_prev_grad,                                 \
-      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
-      typename TTypes<T>::Vec wco_grad) {                                      \
-    LSTMBlockCellBpropWithCUDA<T>(                                             \
-        ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co,  \
-        cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \
-        wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole);            \
-  }                                                                            \
-  template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
-  template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
-  template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;
+#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT)                                    \
+  template <>                                                                 \
+  void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
+  operator()(                                                                 \
+      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,      \
+      const float cell_clip, bool use_peephole,                               \
+      typename TTypes<T>::ConstMatrix x,                                      \
+      typename TTypes<T>::ConstMatrix cs_prev,                                \
+      typename TTypes<T>::ConstMatrix h_prev,                                 \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,          \
+      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,            \
+      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,             \
+      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,           \
+      typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) {       \
+    LSTMBlockCellFpropWithCUDA<T, GATE_LAYOUT>(                               \
+        ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, h_prev, w,  \
+        wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h, batch_size_,     \
+        cell_size_, input_size_);                                             \
+  }                                                                           \
+  template <>                                                                 \
+  void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
+  operator()(                                                                 \
+      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
+      typename TTypes<T>::ConstMatrix x,                                      \
+      typename TTypes<T>::ConstMatrix cs_prev,                                \
+      typename TTypes<T>::ConstMatrix h_prev,                                 \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
+      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
+      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
+      typename TTypes<T>::ConstMatrix co,                                     \
+      typename TTypes<T>::ConstMatrix cs_grad,                                \
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
+      typename TTypes<T>::Matrix dgates,                                      \
+      typename TTypes<T>::Matrix cs_prev_grad,                                \
+      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
+      typename TTypes<T>::Vec wco_grad) {                                     \
+    LSTMBlockCellBpropWithCUDA<T, GATE_LAYOUT>(                               \
+        ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
+        cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad,         \
+        wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
+  }                                                                           \
+  template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */,     \
+                                     GATE_LAYOUT>;                            \
+  template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */,     \
+                                     GATE_LAYOUT>;                            \
+  template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */,         \
+                                 GATE_LAYOUT>;
 
-DEFINE_GPU_SPECS(float);
-DEFINE_GPU_SPECS(Eigen::half);
-// DEFINE_GPU_SPECS(double);
-#undef DEFINE_GPU_SPECS
+#define DECLARE_GPU_SPECS(T)                           \
+  template struct TensorZero<GPUDevice, T>;            \
+  template struct TensorUnalignedZero<GPUDevice, T>;   \
+  template struct TensorCopy<GPUDevice, T>;            \
+  template struct TensorCopyUnaligned<GPUDevice, T>;   \
+  template struct TensorCopyToUnaligned<GPUDevice, T>; \
+  template struct TensorAdd<GPUDevice, T>;             \
+                                                       \
+  DECLARE_GPU_FBPROP(T, ICFO);
 
+DECLARE_GPU_SPECS(Eigen::half);
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_FBPROP
 }  // end namespace functor
 }  // end namespace tensorflow
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc
index f87e0fa..f53976c 100644
--- a/tensorflow/core/kernels/save_op.cc
+++ b/tensorflow/core/kernels/save_op.cc
@@ -62,8 +62,8 @@
     }
     Tensor* out = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
-    out->scalar<string>()() = strings::Printf(
-        "%s-%05d-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+    out->scalar<tstring>()() = strings::Printf(
+        "%s-%05d-of-%05d", ctx->input(0).scalar<tstring>()().c_str(),
         ctx->input(1).scalar<int32>()(), ctx->input(2).scalar<int32>()());
   }
 };
@@ -85,8 +85,8 @@
     }
     Tensor* out = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
-    out->scalar<string>()() = strings::Printf(
-        "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+    out->scalar<tstring>()() = strings::Printf(
+        "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar<tstring>()().c_str(),
         ctx->input(1).scalar<int32>()());
   }
 };
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index faafed3..f0a2867 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -70,7 +70,7 @@
                                 "shapes and slices but got ",
                                 tensor_shapes_and_slices_t.NumElements()));
     tensor_shapes_and_slices_ptr =
-        tensor_shapes_and_slices_t.flat<string>().data();
+        tensor_shapes_and_slices_t.flat<tstring>().data();
   }
   OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
               errors::InvalidArgument("Expected totally ", N + kFixedInputs,
@@ -79,13 +79,13 @@
                                       N, " names, but received ",
                                       context->num_inputs(), " inputs"));
 
-  VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
+  VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
           << "...";
-  checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
+  checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
                                        std::move(builder_func));
 
   Status s;
-  auto tensor_names_flat = tensor_names_t.flat<string>();
+  auto tensor_names_flat = tensor_names_t.flat<tstring>();
 
   // Process tensors in sorted name order.  This allows us to avoid seeking
   // during restoration in the common case where we are restoring a full
@@ -153,10 +153,10 @@
             "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
             size, "elements"));
   }
-  const string& file_pattern = file_pattern_t.flat<string>()(0);
+  const string& file_pattern = file_pattern_t.flat<tstring>()(0);
 
   const Tensor& tensor_name_t = context->input(1);
-  const string& tensor_name = tensor_name_t.flat<string>()(restore_index);
+  const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);
 
   // If we cannot find a cached reader we will allocate our own.
   std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
@@ -192,7 +192,7 @@
   TensorShape output_shape(saved_shape);
   TensorSlice slice_to_load(saved_shape.dims());
   if (restore_slice) {
-    const string& shape_spec = context->input(2).flat<string>()(restore_index);
+    const string& shape_spec = context->input(2).flat<tstring>()(restore_index);
     if (!shape_spec.empty()) {
       TensorShape parsed_shape;
       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
@@ -318,10 +318,10 @@
                         const Tensor& tensor_names,
                         const Tensor& shape_and_slices,
                         gtl::ArraySlice<DataType> dtypes) {
-  const string& prefix_string = prefix.scalar<string>()();
+  const string& prefix_string = prefix.scalar<tstring>()();
 
-  const auto& tensor_names_flat = tensor_names.flat<string>();
-  const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+  const auto& tensor_names_flat = tensor_names.flat<tstring>();
+  const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
 
   // Sort lookup keys to improve locality when reading multiple tensors.
   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ed1195c..512fd9b 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -101,9 +101,9 @@
 
     const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
     const int num_tensors = static_cast<int>(tensor_names.NumElements());
-    const string& prefix_string = prefix.scalar<string>()();
-    const auto& tensor_names_flat = tensor_names.flat<string>();
-    const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+    const string& prefix_string = prefix.scalar<tstring>()();
+    const auto& tensor_names_flat = tensor_names.flat<tstring>();
+    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
 
     BundleWriter writer(Env::Default(), prefix_string);
     OP_REQUIRES_OK(context, writer.status());
@@ -157,7 +157,7 @@
     ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                    shape_and_slices);
 
-    const string& prefix_string = prefix.scalar<string>()();
+    const string& prefix_string = prefix.scalar<tstring>()();
 
     // Intention: we plan to use the RestoreV2 op as a backward-compatible
     // reader as we upgrade to the V2 format.  This allows transparent upgrade.
@@ -215,7 +215,7 @@
     const gtl::ArraySlice<string> input_prefixes =
         gtl::ArraySlice<string>(checkpoint_prefixes.flat<string>());
     Env* env = Env::Default();
-    const string& merged_prefix = destination_prefix.scalar<string>()();
+    const string& merged_prefix = destination_prefix.scalar<tstring>()();
     OP_REQUIRES_OK(
         context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
 
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index d0e0b15..4fdb7d1 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -312,7 +312,7 @@
     OP_REQUIRES_OK(context, context->allocate_output(
                                 0, TensorShape({num_elements, 2}), &out));
 
-    const auto in_values = input.flat<string>();
+    const auto in_values = input.flat<tstring>();
     auto out_values = out->matrix<int64>();
 
     for (int64 i = 0; i < num_elements; ++i) {
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
index 90bd3ea..5e09e5f 100644
--- a/tensorflow/core/kernels/sendrecv_ops.cc
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -169,6 +169,12 @@
   Rendezvous::Args args;
   args.device_context = ctx->op_device_context();
   args.alloc_attrs = ctx->output_alloc_attr(0);
+  if (ctx->is_eager()) {
+    // NOTE(fishx): Only set cancellation_manager in eager mode. Because in
+    // Tensorflow 1.x, session (or graph_mgr) will abort the underlying
+    // rendezvous if it encounters any error.
+    args.cancellation_manager = ctx->cancellation_manager();
+  }
 
   FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
   if (frame_iter == FrameAndIter(0, 0)) {
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index f2dd281..d83a714 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -57,7 +57,7 @@
       handle->scalar<ResourceHandle>()() = resource_handle;
     } else {
       // Legacy behavior in V1.
-      handle->flat<string>().setConstant(tk.GetHandle(name()));
+      handle->flat<tstring>().setConstant(tk.GetHandle(name()));
     }
   }
 
@@ -110,7 +110,7 @@
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& handle = ctx->input(0);
-    const string& name = handle.scalar<string>()();
+    const string& name = handle.scalar<tstring>()();
     Tensor val;
     OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val));
     ctx->set_output(0, val);
@@ -153,7 +153,7 @@
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& handle = ctx->input(0);
-    const string& name = handle.scalar<string>()();
+    const string& name = handle.scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name));
   }
 
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
index b90381d..df84414 100644
--- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -165,8 +165,8 @@
           context, const_cast<T*>(max_logits.flat<T>().data()),
           reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);
 
-      const int numThreads = 128;
-      const int numBlocks = Eigen::divup(rows * cols, numThreads);
+      const int numThreadsPerBlock = 128;
+      const int numBlocks = Eigen::divup(rows * cols, numThreadsPerBlock);
 
       gpuprim::CountingInputIterator<int> counting_iterator(0);
       using InputIterType =
@@ -185,7 +185,7 @@
           input_itr, rows, cols);
 
       TF_CHECK_OK(GpuLaunchKernel(
-          GenerateNormalizedProb<T, acc_type>, numBlocks, numThreads, 0,
+          GenerateNormalizedProb<T, acc_type>, numBlocks, numThreadsPerBlock, 0,
           cu_stream, reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
           reinterpret_cast<const acc_type*>(sum_probs.flat<acc_type>().data()),
           reinterpret_cast<const T*>(max_logits.flat<T>().data()),
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index 8e92c9e..a6a4060 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -78,7 +78,7 @@
 int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
   const int64 start = feature_start_indices_[batch];
   if (DT_STRING == values_.dtype())
-    return Fingerprint64(values_.vec<string>().data()[start + n]);
+    return Fingerprint64(values_.vec<tstring>().data()[start + n]);
   return values_.vec<int64>().data()[start + n];
 }
 
@@ -87,7 +87,7 @@
 string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
   const int64 start = feature_start_indices_[batch];
   if (DT_STRING == values_.dtype())
-    return values_.vec<string>().data()[start + n];
+    return values_.vec<tstring>().data()[start + n];
   return std::to_string(values_.vec<int64>().data()[start + n]);
 }
 
@@ -95,7 +95,7 @@
 StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
                                                      int64 n) const {
   const int64 start = feature_start_indices_[batch];
-  return values_.vec<string>().data()[start + n];
+  return values_.vec<tstring>().data()[start + n];
 }
 
 // A column that is backed by a dense tensor.
@@ -118,21 +118,21 @@
 template <>
 int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
   if (DT_STRING == tensor_.dtype())
-    return Fingerprint64(tensor_.matrix<string>()(batch, n));
+    return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
   return tensor_.matrix<int64>()(batch, n);
 }
 
 // Internal type is string or StringPiece when using StringCrosser.
 template <>
 string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
-  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
+  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
   return std::to_string(tensor_.matrix<int64>()(batch, n));
 }
 
 template <>
 StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
                                                     int64 n) const {
-  return tensor_.matrix<string>()(batch, n);
+  return tensor_.matrix<tstring>()(batch, n);
 }
 
 // Updates Output tensors with sparse crosses.
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 9c9e737..1ad86b6 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -1550,7 +1550,7 @@
   // Note buffer needs enough space to hold at most a KR * NR matrix since that
   // is the block size per iteration.
   const int buffer_num_rows =
-      std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
+      std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
   MatrixR buffer(buffer_num_rows, N);
   std::vector<ConstMatrixMapR*> right_slices;
 
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 6b9db8f..6e84e22 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -21,7 +21,6 @@
 #include "tensorflow/core/platform/types.h"
 
 #if defined(PLATFORM_WINDOWS)
-#include "tensorflow/core/platform/windows/cpu_info.h"
 #include "tensorflow/core/platform/windows/intrinsics_port.h"
 #endif
 
diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc
index 033b9f3..af8f760 100644
--- a/tensorflow/core/kernels/stack.cc
+++ b/tensorflow/core/kernels/stack.cc
@@ -134,8 +134,8 @@
           "Stack handle must have two elements, but had shape: ",
           Tstack_handle.shape().DebugString());
     }
-    const string& container = Tstack_handle.flat<string>()(0);
-    const string& stack_name = Tstack_handle.flat<string>()(1);
+    const string& container = Tstack_handle.flat<tstring>()(0);
+    const string& stack_name = Tstack_handle.flat<tstring>()(1);
     string key = strings::StrCat(container, stack_name);
     ResourceMgr* rm = ctx->resource_manager();
     if (rm == nullptr) {
@@ -184,10 +184,10 @@
   ResourceMgr* rm = ctx->resource_manager();
   OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
   string key = strings::StrCat(kContainer, stack_name);
-  Stack* stack = new Stack(elem_type_, stack_name, size);
   auto* step_container = ctx->step_container();
   OP_REQUIRES(ctx, step_container != nullptr,
               errors::Internal("No step container."));
+  Stack* stack = new Stack(elem_type_, stack_name, size);
   OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack));
   if (IsRefType(ctx->expected_output_dtype(0))) {
     // Create the stack handle.
@@ -196,7 +196,7 @@
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
                                            tensorflow::TensorShape({2}),
                                            &stack->handle_, alloc_attr));
-    auto handle = stack->handle_.flat<string>();
+    auto handle = stack->handle_.flat<tstring>();
     handle(0) = kContainer;
     handle(1) = std::move(stack_name);
     ctx->set_output_ref(0, stack->mu(), &stack->handle_);
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 5d4ee52..2e6a264 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -437,6 +437,8 @@
                           StridedSliceAssignOp<CPUDevice, type, true>)
 
 TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
+TF_CALL_uint32(REGISTER_STRIDED_SLICE);
+TF_CALL_uint64(REGISTER_STRIDED_SLICE);
 
 #undef REGISTER_STRIDED_SLICE
 
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index e7d9a5e..bf69a19 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -291,6 +291,8 @@
 #endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
+TF_CALL_uint32(DECLARE_FOR_N_CPU);
+TF_CALL_uint64(DECLARE_FOR_N_CPU);
 
 #ifdef TENSORFLOW_USE_SYCL
 #define PREVENT_FOR_N_SYCL(T) \
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
index e4a1887..0caec3e 100644
--- a/tensorflow/core/kernels/string_format_op.cc
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -50,7 +50,7 @@
       strings::StrAppend(&msg, split_template_[i + 1].c_str());
     }
 
-    formatted_string->scalar<string>()() = msg;
+    formatted_string->scalar<tstring>()() = std::move(msg);
   }
 
  private:
diff --git a/tensorflow/core/kernels/string_join_op.cc b/tensorflow/core/kernels/string_join_op.cc
index 4b9c19d..5532f6d 100644
--- a/tensorflow/core/kernels/string_join_op.cc
+++ b/tensorflow/core/kernels/string_join_op.cc
@@ -42,7 +42,7 @@
     std::vector<TTypes<string>::ConstFlat> inputs;
 
     for (const auto& input : input_list) {
-      inputs.push_back(input.flat<string>());
+      inputs.push_back(input.flat<tstring>());
       is_scalar.push_back(TensorShapeUtils::IsScalar(input.shape()));
       if (!TensorShapeUtils::IsScalar(input.shape())) {
         if (TensorShapeUtils::IsScalar(input_shape)) {
@@ -60,7 +60,7 @@
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output("output", input_shape,
                                                      &output_tensor));
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
 
     std::vector<StringPiece> strings(input_list.size());
     for (size_t i = 0; i < input_shape.num_elements(); ++i) {
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index 435a7ab..53a1613 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -34,7 +34,7 @@
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
 
-    auto src = input.flat<string>();
+    auto src = input.flat<tstring>();
     auto dst = output->flat<int32>();
 
     switch (unit_) {
diff --git a/tensorflow/core/kernels/string_lower_op.cc b/tensorflow/core/kernels/string_lower_op.cc
index e24eedc..07065d2 100644
--- a/tensorflow/core/kernels/string_lower_op.cc
+++ b/tensorflow/core/kernels/string_lower_op.cc
@@ -45,8 +45,8 @@
     OP_REQUIRES_OK(
         ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
 
-    const auto input = input_tensor->flat<string>();
-    auto output = output_tensor->flat<string>();
+    const auto input = input_tensor->flat<tstring>();
+    auto output = output_tensor->flat<tstring>();
 
     if (encoding_.empty()) {
       for (int64 i = 0; i < input.size(); ++i) {
diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc
new file mode 100644
index 0000000..430d91b
--- /dev/null
+++ b/tensorflow/core/kernels/string_ngrams_op.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <locale>
+#include <string>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace text {
+
+namespace {
+template <typename SPLITS_TYPE>
+class StringNGramsOp : public tensorflow::OpKernel {
+ public:
+  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
+      : tensorflow::OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
+    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
+    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
+    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
+    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
+                                             &preserve_short_));
+  }
+
+  int get_pad_width(const int ngram_width) const {
+    // Ngrams can be padded with either a fixed pad width or a dynamic pad
+    // width depending on the 'pad_width' arg, but in no case should the padding
+    // ever be wider than 'ngram_width' - 1.
+    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
+                    ngram_width - 1);
+  }
+
+  int get_num_ngrams(const int length, const int ngram_width) const {
+    int pad_width = get_pad_width(ngram_width);
+    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
+  }
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    const tensorflow::Tensor* data;
+    OP_REQUIRES_OK(context, context->input("data", &data));
+    const auto& input_data = data->flat<tstring>().data();
+
+    const tensorflow::Tensor* splits;
+    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
+    const auto& splits_vec = splits->flat<SPLITS_TYPE>();
+
+    // If there is no data or size, return an empty RT.
+    if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
+      tensorflow::Tensor* empty;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, data->shape(), &empty));
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, splits->shape(), &empty));
+      return;
+    }
+
+    int num_batch_items = splits_vec.size() - 1;
+    tensorflow::Tensor* ngrams_splits;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
+    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();
+
+    ngrams_splits_data[0] = 0;
+    for (int i = 1; i <= num_batch_items; ++i) {
+      int length = splits_vec(i) - splits_vec(i - 1);
+      int num_ngrams = 0;
+      for (int ngram_width : ngram_widths_)
+        num_ngrams += get_num_ngrams(length, ngram_width);
+      if (preserve_short_ && length > 0 && num_ngrams == 0) {
+        num_ngrams = 1;
+      }
+      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
+    }
+
+    tensorflow::Tensor* ngrams;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
+    auto ngrams_data = ngrams->flat<tstring>().data();
+
+    for (int i = 0; i < num_batch_items; ++i) {
+      auto data_start = &input_data[splits_vec(i)];
+      int output_start_idx = ngrams_splits_data[i];
+      for (int ngram_width : ngram_widths_) {
+        auto output_start = &ngrams_data[output_start_idx];
+        int length = splits_vec(i + 1) - splits_vec(i);
+        int num_ngrams = get_num_ngrams(length, ngram_width);
+        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
+        output_start_idx += num_ngrams;
+      }
+      // If we're preserving short sequences, check to see if no sequence was
+      // generated by comparing the current output start idx to the original
+      // one (ngram_splits_data). If no ngrams were generated, then they will
+      // be equal (since we increment output_start_idx by num_ngrams every
+      // time we create a set of ngrams.)
+      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
+        int data_length = splits_vec(i + 1) - splits_vec(i);
+        // One legitimate reason to not have any ngrams when preserve_short_
+        // is true is if the sequence itself is empty. In that case, move on.
+        if (data_length == 0) {
+          continue;
+        }
+        // We don't have to worry about dynamic padding sizes here: if padding
+        // was dynamic, every sequence would have had sufficient padding to
+        // generate at least one ngram.
+        int ngram_width = data_length + 2 * pad_width_;
+        auto output_start = &ngrams_data[output_start_idx];
+        int num_ngrams = 1;
+        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
+      }
+    }
+  }
+
+  void CreateNgrams(const string* data, string* output, int num_ngrams,
+                    int ngram_width) const {
+    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
+      int pad_width = get_pad_width(ngram_width);
+      int left_padding = std::max(0, pad_width - ngram_index);
+      int right_padding =
+          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
+      int num_tokens = ngram_width - (left_padding + right_padding);
+      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;
+
+      // Calculate the total expected size of the ngram so we can reserve the
+      // correct amount of space in the string.
+      int ngram_size = 0;
+      // Size of the left padding.
+      ngram_size += left_padding * left_pad_.length();
+      // Size of the tokens.
+      for (int n = 0; n < num_tokens; ++n) {
+        ngram_size += data[data_start_index + n].length();
+      }
+      // Size of the right padding.
+      ngram_size += right_padding * right_pad_.length();
+      // Size of the separators.
+      int num_separators = left_padding + right_padding + num_tokens - 1;
+      ngram_size += num_separators * separator_.length();
+
+      // Build the ngram.
+      string* ngram = &output[ngram_index];
+      ngram->reserve(ngram_size);
+      for (int n = 0; n < left_padding; ++n) {
+        *ngram += left_pad_;
+        *ngram += separator_;
+      }
+      for (int n = 0; n < num_tokens - 1; ++n) {
+        *ngram += data[data_start_index + n];
+        *ngram += separator_;
+      }
+      *ngram += data[data_start_index + num_tokens - 1];
+      for (int n = 0; n < right_padding; ++n) {
+        *ngram += separator_;
+        *ngram += right_pad_;
+      }
+
+      // In debug mode only: validate that we've reserved enough space for the
+      // ngram.
+      DCHECK_EQ(ngram_size, ngram->size());
+    }
+  }
+
+  string separator_;
+  string left_pad_;
+  string right_pad_;
+  bool use_pad_;
+  bool extend_pad_;
+  bool preserve_short_;
+
+  std::vector<int> ngram_widths_;
+  int pad_width_;
+};
+
+}  // namespace
+REGISTER_KERNEL_BUILDER(Name("StringNGrams")
+                            .Device(tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        StringNGramsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("StringNGrams")
+                            .Device(tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        StringNGramsOp<int64>);
+
+}  // namespace text
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_ngrams_op_test.cc b/tensorflow/core/kernels/string_ngrams_op_test.cc
new file mode 100644
index 0000000..afd1700
--- /dev/null
+++ b/tensorflow/core/kernels/string_ngrams_op_test.cc
@@ -0,0 +1,554 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace text {
+
+using tensorflow::FakeInput;
+using tensorflow::NodeDefBuilder;
+using tensorflow::Status;
+using tensorflow::TensorShape;
+
+class NgramKernelTest : public tensorflow::OpsTestBase {
+ public:
+  void MakeOp(string separator, std::vector<int> ngram_width, string left_pad,
+              string right_pad, int pad_width, bool preserve) {
+    TF_ASSERT_OK(NodeDefBuilder("tested_op", "StringNGrams")
+                     .Attr("separator", separator)
+                     .Attr("ngram_widths", ngram_width)
+                     .Attr("left_pad", left_pad)
+                     .Attr("right_pad", right_pad)
+                     .Attr("pad_width", pad_width)
+                     .Attr("preserve_short_sequences", preserve)
+                     .Input(FakeInput())
+                     .Input(FakeInput())
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  void assert_string_equal(const std::vector<string> &expected,
+                           const Tensor &value) {
+    Tensor expected_tensor(allocator(), DT_STRING,
+                           TensorShape({static_cast<int64>(expected.size())}));
+    test::FillValues<string>(&expected_tensor, expected);
+    test::ExpectTensorEqual<string>(expected_tensor, value);
+  }
+  void assert_int64_equal(const std::vector<int64> &expected,
+                          const Tensor &value) {
+    Tensor expected_tensor(allocator(), DT_INT64,
+                           TensorShape({static_cast<int64>(expected.size())}));
+    test::FillValues<int64>(&expected_tensor, expected);
+    test::ExpectTensorEqual<int64>(expected_tensor, value);
+  }
+};
+
+TEST_F(NgramKernelTest, TestPaddedTrigrams) {
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                              //
+      {"LP|LP|a", "LP|a|b", "a|b|c", "b|c|d", "c|d|RP", "d|RP|RP",  // 0
+       "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});                  // 1
+  std::vector<int64> expected_splits({0, 6, 10});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedBigramsAndTrigrams) {
+  MakeOp("|", {2, 3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|LP|a", "LP|a|b", "a|b|c",
+       "b|c|d", "c|d|RP", "d|RP|RP",                                       // 0
+       "LP|e", "e|f", "f|RP", "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});  // 1
+  std::vector<int64> expected_splits({0, 11, 18});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedBigrams) {
+  MakeOp("|", {2}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(       //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "LP|e", "e|f", "f|RP"});              // 1
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddingIsAtMostNGramSizeMinus1) {
+  MakeOp("|", {2}, "LP", "RP", 4, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(       //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "LP|e", "e|f", "f|RP"});              // 1
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedUnigramAndBigrams) {
+  MakeOp("|", {1, 2}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                           //
+      {"a", "b", "c", "d", "LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "e", "f", "LP|e", "e|f", "f|RP"});                        // 1
+  std::vector<int64> expected_splits({0, 9, 14});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingPaddedNGrams) {
+  // This test validates that n-grams with both left and right padding in a
+  // single ngram token are created correctly.
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                     //
+      {"LP|LP|a", "LP|a|RP", "a|RP|RP",                    // ngrams for elem. 0
+       "LP|LP|b", "LP|b|c", "b|c|d", "c|d|RP", "d|RP|RP",  // ngrams for elem. 1
+       "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});         // ngrams for elem. 2
+  std::vector<int64> expected_splits({0, 3, 8, 12});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingPaddedMultiCharNGrams) {
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}),
+                            {"aa", "bb", "cc", "dd", "ee", "ff"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                              //
+      {"LP|LP|aa", "LP|aa|RP", "aa|RP|RP",                          //
+       "LP|LP|bb", "LP|bb|cc", "bb|cc|dd", "cc|dd|RP", "dd|RP|RP",  //
+       "LP|LP|ee", "LP|ee|ff", "ee|ff|RP", "ff|RP|RP"});            //
+  std::vector<int64> expected_splits({0, 3, 8, 12});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestMultiOverlappingPaddedNGrams) {
+  // This test validates that n-grams with more than 1 padding value on each
+  // side are created correctly.
+  MakeOp("|", {5}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  AddInputFromArray<string>(TensorShape({1}), {"a"});
+  AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
+                                       "LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
+                                       "a|RP|RP|RP|RP"});
+  std::vector<int64> expected_splits({0, 5});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigrams) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d"});
+  std::vector<int64> expected_splits({0, 2, 2});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithEmptySequence) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d"});
+  std::vector<int64> expected_splits({0, 2, 2, 2});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShort) {
+  MakeOp("|", {3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShortAndEmptySequence) {
+  MakeOp("|", {3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 2, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndQuadgramsWithPreserveShort) {
+  MakeOp("|", {4, 3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 3, 4});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigrams) {
+  MakeOp("|", {2, 3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigramsWithPreserveShort) {
+  MakeOp("|", {2, 3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Note that in this case, because the bigram 'e|f' was already generated,
+  // the op will not generate a special preserve_short bigram.
+  std::vector<string> expected_values(
+      {"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndBigramsWithPreserveShort) {
+  MakeOp("|", {3, 2}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Note that in this case, because the bigram 'e|f' was already generated,
+  // the op will not generate a special preserve_short bigram.
+  std::vector<string> expected_values(
+      {"a|b|c", "b|c|d", "a|b", "b|c", "c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigrams) {
+  MakeOp("|", {2}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b", "b|c", "c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 3, 4});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGrams) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"b|c|d"});
+  std::vector<int64> expected_splits({0, 0, 1, 1});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGramsNoOutput) {
+  MakeOp("|", {5}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({});
+  std::vector<int64> expected_splits({0, 0, 0, 0});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedTrigrams) {
+  MakeOp("|", {3}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|a|b", "a|b|c", "b|c|d", "c|d|RP",  //
+                                       "LP|e|f", "e|f|RP"});
+  std::vector<int64> expected_splits({0, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedBigrams) {
+  MakeOp("|", {2}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP",  //
+                                       "LP|e", "e|f", "f|RP"});
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedBigramsAnd5grams) {
+  MakeOp("|", {2, 5}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                                   //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|a|b|c|d", "a|b|c|d|RP",  //
+       "LP|e", "e|f", "f|RP"});
+  std::vector<int64> expected_splits({0, 7, 10});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPadded5gramsWithPreserveShort) {
+  MakeOp("|", {5}, "LP", "RP", 1, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(  //
+      {"LP|a|b|c|d", "a|b|c|d|RP",      //
+       "LP|e|f|RP"});
+  std::vector<int64> expected_splits({0, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGrams) {
+  MakeOp("|", {3}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"LP|a|RP",                    // ngrams for elem. 0
+       "LP|b|c", "b|c|d", "c|d|RP",  // ngrams for elem. 1
+       "LP|e|f", "e|f|RP"});         // ngrams for elem. 2
+  std::vector<int64> expected_splits({0, 1, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGramsNoOutput) {
+  MakeOp("|", {5}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|b|c|d|RP"});
+  std::vector<int64> expected_splits({0, 0, 1, 1});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
+  MakeOp("|", {1}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a", "b", "c", "d", "e", "f"});
+  std::vector<int64> expected_splits({0, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestEmptyInput) {
+  MakeOp("|", {1}, "LP", "RP", 3, false);
+  AddInputFromArray<string>(TensorShape({0}), {});
+  AddInputFromArray<int64>(TensorShape({0}), {});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({});
+  std::vector<int64> expected_splits({});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, ShapeFn) {
+  ShapeInferenceTestOp op("StringNGrams");
+  INFER_OK(op, "?;?", "[?];[?]");
+  INFER_OK(op, "[1];?", "[?];[?]");
+  INFER_OK(op, "[1];[2]", "[?];in1");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
+}
+
+}  // namespace text
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 3884370..d6d27de 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -178,7 +178,7 @@
                 errors::InvalidArgument("input must be a vector, got shape: ",
                                         input_tensor->shape().DebugString()));
 
-    const auto input_vec = input_tensor->vec<string>();
+    const auto input_vec = input_tensor->vec<tstring>();
     const int64 batch_size = input_vec.dimension(0);
 
     const Tensor* delimiter_tensor;
@@ -220,7 +220,7 @@
     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
 
     auto sp_indices = sp_indices_t->matrix<int64>();
-    auto sp_tokens = sp_tokens_t->vec<string>();
+    auto sp_tokens = sp_tokens_t->vec<tstring>();
     auto sp_shape = sp_shape_t->vec<int64>();
     sp_shape(0) = batch_size;
     sp_shape(1) = max_num_entries;
@@ -253,7 +253,7 @@
                 errors::InvalidArgument("input must be a vector, got shape: ",
                                         input_tensor->shape().DebugString()));
 
-    const auto input_vec = input_tensor->vec<string>();
+    const auto input_vec = input_tensor->vec<tstring>();
     const int64 batch_size = input_vec.dimension(0);
 
     const Tensor* sep_tensor;
@@ -261,7 +261,7 @@
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()),
                 errors::InvalidArgument("sep must be a scalar, got shape: ",
                                         sep_tensor->shape().DebugString()));
-    const auto sep_vec = sep_tensor->flat<string>();
+    const auto sep_vec = sep_tensor->flat<tstring>();
     StringPiece sep(sep_vec(0));
     std::vector<StringPiece> tokens;
     // Guess that we'll be unpacking a handful of tokens per example.
@@ -290,7 +290,7 @@
     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
 
     auto sp_indices = sp_indices_t->matrix<int64>();
-    auto sp_tokens = sp_tokens_t->vec<string>();
+    auto sp_tokens = sp_tokens_t->vec<tstring>();
     auto sp_shape = sp_shape_t->vec<int64>();
     sp_shape(0) = batch_size;
     sp_shape(1) = max_num_entries;
diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc
index 58ad61a..4494cf9 100644
--- a/tensorflow/core/kernels/string_split_op_test.cc
+++ b/tensorflow/core/kernels/string_split_op_test.cc
@@ -57,7 +57,7 @@
 Tensor GetTestTensor(int batch) {
   const int sz = TF_ARRAYSIZE(lines);
   Tensor t(DT_STRING, {batch});
-  auto s = t.flat<string>();
+  auto s = t.flat<tstring>();
   for (int i = 0; i < batch; ++i) {
     s(i) = lines[i % sz];
   }
@@ -67,7 +67,7 @@
 Graph* SetupStringSplitGraph(const Tensor& input) {
   Graph* g = new Graph(OpRegistry::Global());
   Tensor delim(DT_STRING, TensorShape({}));
-  delim.flat<string>().setConstant(" ");
+  delim.flat<tstring>().setConstant(" ");
 
   TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit")
                   .Input(test::graph::Constant(g, input))
@@ -98,7 +98,7 @@
 Graph* SetupStringSplitV2Graph(const Tensor& input) {
   Graph* g = new Graph(OpRegistry::Global());
   Tensor sep(DT_STRING, TensorShape({}));
-  sep.flat<string>().setConstant(" ");
+  sep.flat<tstring>().setConstant(" ");
 
   TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2")
                   .Input(test::graph::Constant(g, input))
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 544dca9..715ec27 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -37,8 +37,8 @@
     OP_REQUIRES_OK(
         ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
 
-    const auto input = input_tensor->flat<string>();
-    auto output = output_tensor->flat<string>();
+    const auto input = input_tensor->flat<tstring>();
+    auto output = output_tensor->flat<tstring>();
 
     for (int64 i = 0; i < input.size(); ++i) {
       StringPiece entry(input(i));
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
index 10fc6ee..1505ddb 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
@@ -33,7 +33,7 @@
   void Compute(OpKernelContext* context) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h
index 62ef35b..8647695 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.h
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h
@@ -36,7 +36,7 @@
   void Compute(OpKernelContext* context) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(context, context->input("input", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
@@ -78,7 +78,7 @@
   void Compute(OpKernelContext* context) override {
     const Tensor* input_tensor;
     OP_REQUIRES_OK(context, context->input("input", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc
index 22742dd..8340f35 100644
--- a/tensorflow/core/kernels/string_to_number_op.cc
+++ b/tensorflow/core/kernels/string_to_number_op.cc
@@ -40,7 +40,7 @@
     // underlying storage.
     const Tensor* input_tensor;
     OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor));
-    const auto& input_flat = input_tensor->flat<string>();
+    const auto& input_flat = input_tensor->flat<tstring>();
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/kernels/string_upper_op.cc b/tensorflow/core/kernels/string_upper_op.cc
index f2a1d33..d9f088a 100644
--- a/tensorflow/core/kernels/string_upper_op.cc
+++ b/tensorflow/core/kernels/string_upper_op.cc
@@ -45,8 +45,8 @@
     OP_REQUIRES_OK(
         ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
 
-    const auto input = input_tensor->flat<string>();
-    auto output = output_tensor->flat<string>();
+    const auto input = input_tensor->flat<tstring>();
+    auto output = output_tensor->flat<tstring>();
     if (encoding_.empty()) {
       for (int64 i = 0; i < input.size(); ++i) {
         StringPiece entry(input(i));
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 77b16b9..458d67c 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -59,13 +59,13 @@
       // Do not need to do broadcasting
 
       // Reshape input
-      auto input = input_tensor.flat<string>();
+      auto input = input_tensor.flat<tstring>();
       // Allocate output
       Tensor* output_tensor = nullptr;
       OP_REQUIRES_OK(context,
                      context->allocate_output("output", input_tensor.shape(),
                                               &output_tensor));
-      auto output = output_tensor->flat<string>();
+      auto output = output_tensor->flat<tstring>();
       if (is_scalar) {
         // Perform Op with scalar pos/len
         const T pos =
@@ -141,8 +141,8 @@
       switch (ndims) {
         case 1: {
           // Reshape tensors according to BCast results
-          auto input = input_tensor.shaped<string, 1>(bcast.x_reshape());
-          auto output = output_tensor->shaped<string, 1>(bcast.result_shape());
+          auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape());
+          auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape());
           auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
           auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
 
@@ -204,8 +204,8 @@
         }
         case 2: {
           // Reshape tensors according to BCast results
-          auto input = input_tensor.shaped<string, 2>(bcast.x_reshape());
-          auto output = output_tensor->shaped<string, 2>(bcast.result_shape());
+          auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape());
+          auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape());
           auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
           auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
 
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
index ea6b1ed..3aebfe3 100644
--- a/tensorflow/core/kernels/substr_op_test.cc
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -115,7 +115,7 @@
 Tensor GetTestTensor(int batch) {
   const int sz = TF_ARRAYSIZE(ascii_lines);
   Tensor t(DT_STRING, {batch});
-  auto s = t.flat<string>();
+  auto s = t.flat<tstring>();
   for (int i = 0; i < batch; ++i) {
     s(i) = ascii_lines[i % sz];
   }
@@ -125,7 +125,7 @@
 Tensor GetTestUTF8Tensor(int batch) {
   const int sz = TF_ARRAYSIZE(unicode_lines);
   Tensor t(DT_STRING, {batch});
-  auto s = t.flat<string>();
+  auto s = t.flat<tstring>();
   for (int i = 0; i < batch; ++i) {
     s(i) = unicode_lines[i % sz];
   }
diff --git a/tensorflow/core/kernels/summary_audio_op.cc b/tensorflow/core/kernels/summary_audio_op.cc
index f5ddb90..fbb1c2c 100644
--- a/tensorflow/core/kernels/summary_audio_op.cc
+++ b/tensorflow/core/kernels/summary_audio_op.cc
@@ -44,7 +44,7 @@
     OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3,
                 errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ",
                                         tensor.shape().DebugString()));
-    const string& base_tag = tag.scalar<string>()();
+    const string& base_tag = tag.scalar<tstring>()();
 
     float sample_rate = sample_rate_attr_;
     if (!has_sample_rate_attr_) {
diff --git a/tensorflow/core/kernels/summary_audio_op_test.cc b/tensorflow/core/kernels/summary_audio_op_test.cc
index 1b957c5..7c6ec04 100644
--- a/tensorflow/core/kernels/summary_audio_op_test.cc
+++ b/tensorflow/core/kernels/summary_audio_op_test.cc
@@ -93,7 +93,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   CheckAndRemoveEncodedAudio(&summary);
   EXPECT_SummaryMatches(summary, R"(
@@ -127,7 +127,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   CheckAndRemoveEncodedAudio(&summary);
   EXPECT_SummaryMatches(summary, R"(
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc
index 68f17c2..bfba449 100644
--- a/tensorflow/core/kernels/summary_image_op.cc
+++ b/tensorflow/core/kernels/summary_image_op.cc
@@ -61,7 +61,7 @@
                 errors::InvalidArgument(
                     "Tensor must be 4-D with last dim 1, 3, or 4, not ",
                     tensor.shape().DebugString()));
-    const string& base_tag = tags.scalar<string>()();
+    const string& base_tag = tags.scalar<tstring>()();
 
     OP_REQUIRES(c,
                 tensor.dim_size(0) < (1LL << 31) &&
diff --git a/tensorflow/core/kernels/summary_image_op_test.cc b/tensorflow/core/kernels/summary_image_op_test.cc
index 74e0d09..be8e44d 100644
--- a/tensorflow/core/kernels/summary_image_op_test.cc
+++ b/tensorflow/core/kernels/summary_image_op_test.cc
@@ -87,7 +87,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   CheckAndRemoveEncodedImages(&summary);
   EXPECT_SummaryMatches(summary, R"(
@@ -110,7 +110,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   CheckAndRemoveEncodedImages(&summary);
   EXPECT_SummaryMatches(summary, R"(
@@ -142,7 +142,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   CheckAndRemoveEncodedImages(&summary);
   EXPECT_SummaryMatches(summary, R"(
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index e17e28e..7f888da 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -38,13 +38,13 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp));
-    const string logdir = tmp->scalar<string>()();
+    const string logdir = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp));
     const int32 max_queue = tmp->scalar<int32>()();
     OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp));
     const int32 flush_millis = tmp->scalar<int32>()();
     OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
-    const string filename_suffix = tmp->scalar<string>()();
+    const string filename_suffix = tmp->scalar<tstring>()();
 
     core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
@@ -67,13 +67,13 @@
   void Compute(OpKernelContext* ctx) override {
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
-    const string db_uri = tmp->scalar<string>()();
+    const string db_uri = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
-    const string experiment_name = tmp->scalar<string>()();
+    const string experiment_name = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
-    const string run_name = tmp->scalar<string>()();
+    const string run_name = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
-    const string user_name = tmp->scalar<string>()();
+    const string user_name = tmp->scalar<tstring>()();
 
     core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(
@@ -132,9 +132,9 @@
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
-    const string& tag = tmp->scalar<string>()();
+    const string& tag = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
-    const string& serialized_metadata = tmp->scalar<string>()();
+    const string& serialized_metadata = tmp->scalar<tstring>()();
 
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
@@ -166,7 +166,7 @@
     // Each Summary proto contains just one repeated field "value" of Value
     // messages with the actual data, so repeated Merge() is equivalent to
     // concatenating all the Value entries together into a single Event.
-    const auto summary_pbs = t->flat<string>();
+    const auto summary_pbs = t->flat<tstring>();
     for (int i = 0; i < summary_pbs.size(); ++i) {
       if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) {
         ctx->CtxFailureWithWarning(errors::DataLoss(
@@ -191,7 +191,7 @@
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("event", &t));
     std::unique_ptr<Event> event{new Event};
-    if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
+    if (!ParseProtoUnlimited(event.get(), t->scalar<tstring>()())) {
       ctx->CtxFailureWithWarning(
           errors::DataLoss("Bad tf.Event binary proto tensor string"));
       return;
@@ -212,7 +212,7 @@
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
-    const string& tag = tmp->scalar<string>()();
+    const string& tag = tmp->scalar<tstring>()();
 
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("value", &t));
@@ -234,7 +234,7 @@
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
-    const string& tag = tmp->scalar<string>()();
+    const string& tag = tmp->scalar<tstring>()();
 
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("values", &t));
@@ -262,7 +262,7 @@
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
-    const string& tag = tmp->scalar<string>()();
+    const string& tag = tmp->scalar<tstring>()();
     const Tensor* bad_color;
     OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color));
     OP_REQUIRES(
@@ -297,7 +297,7 @@
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
-    const string& tag = tmp->scalar<string>()();
+    const string& tag = tmp->scalar<tstring>()();
     OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
     const float sample_rate = tmp->scalar<float>()();
 
@@ -326,7 +326,7 @@
     const int64 step = t->scalar<int64>()();
     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
     std::unique_ptr<GraphDef> graph{new GraphDef};
-    if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
+    if (!ParseProtoUnlimited(graph.get(), t->scalar<tstring>()())) {
       ctx->CtxFailureWithWarning(
           errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
       return;
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
index 1053aa7..a765825 100644
--- a/tensorflow/core/kernels/summary_op.cc
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -47,7 +47,7 @@
         errors::InvalidArgument(
             "tags and values not the same shape: ", tags.shape().DebugString(),
             " != ", values.shape().DebugString(), SingleTag(tags)));
-    auto Ttags = tags.flat<string>();
+    auto Ttags = tags.flat<tstring>();
     auto Tvalues = values.flat<T>();
     Summary s;
     for (int i = 0; i < Ttags.size(); i++) {
@@ -64,7 +64,7 @@
   // If there's only one tag, include it in the error message
   static string SingleTag(const Tensor& tags) {
     if (tags.NumElements() == 1) {
-      return strings::StrCat(" (tag '", tags.flat<string>()(0), "')");
+      return strings::StrCat(" (tag '", tags.flat<tstring>()(0), "')");
     } else {
       return "";
     }
@@ -138,7 +138,7 @@
     std::unordered_set<string> tags;
     for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
       const Tensor& in = c->input(input_num);
-      auto in_vec = in.flat<string>();
+      auto in_vec = in.flat<tstring>();
       for (int i = 0; i < in_vec.dimension(0); i++) {
         const string& s_in = in_vec(i);
         Summary summary_in;
diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc
index 697c03a..9dcc98e 100644
--- a/tensorflow/core/kernels/summary_op_test.cc
+++ b/tensorflow/core/kernels/summary_op_test.cc
@@ -88,7 +88,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   EXPECT_SummaryMatches(summary, R"(
       value { tag: 'tag1' simple_value: 1.0 }
       value { tag: 'tag2' simple_value: -0.73 }
@@ -100,7 +100,7 @@
   MakeOp(DT_HALF);
 
   // Feed and run
-  AddInputFromList<string>(TensorShape({3}), {"tag1", "tag2", "tag3"});
+  AddInputFromList<tstring>(TensorShape({3}), {"tag1", "tag2", "tag3"});
   AddInputFromList<Eigen::half>(TensorShape({3}), {1.0, -2.0, 10000.0});
   TF_ASSERT_OK(RunOpKernel());
 
@@ -108,7 +108,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   EXPECT_SummaryMatches(summary, R"(
       value { tag: 'tag1' simple_value: 1.0 }
       value { tag: 'tag2' simple_value: -2.0 }
@@ -177,7 +177,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   ASSERT_EQ(summary.value_size(), 1);
   EXPECT_EQ(summary.value(0).tag(), "taghisto");
   histogram::Histogram histo;
@@ -205,7 +205,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   ASSERT_EQ(summary.value_size(), 1);
   EXPECT_EQ(summary.value(0).tag(), "taghisto");
   histogram::Histogram histo;
@@ -234,7 +234,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   ASSERT_EQ(summary.value_size(), 1);
   EXPECT_EQ(summary.value(0).tag(), "taghisto");
   histogram::Histogram histo;
@@ -308,7 +308,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   EXPECT_SummaryMatches(summary,
                         "value { tag: \"tag1\" simple_value: 1.0 } "
@@ -342,7 +342,7 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
 
   EXPECT_SummaryMatches(summary,
                         "value { tag: \"tag1\" simple_value: 1.0 } "
diff --git a/tensorflow/core/kernels/summary_tensor_op_test.cc b/tensorflow/core/kernels/summary_tensor_op_test.cc
index 55a0cb3..6bc4d15 100644
--- a/tensorflow/core/kernels/summary_tensor_op_test.cc
+++ b/tensorflow/core/kernels/summary_tensor_op_test.cc
@@ -80,14 +80,14 @@
   Tensor* out_tensor = GetOutput(0);
   ASSERT_EQ(0, out_tensor->dims());
   Summary summary;
-  ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+  ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
   ASSERT_EQ(1, summary.value_size());
 
   // Check the content of the tensor stored in the summary.
   Tensor string_content_tensor;
   CHECK(string_content_tensor.FromProto(summary.value(0).tensor()));
   ASSERT_EQ("some string tensor content",
-            string_content_tensor.scalar<string>()());
+            string_content_tensor.scalar<tstring>()());
 
   // Check plugin-related data.
   ASSERT_EQ("tag_foo", summary.value(0).tag());
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 8e8faf8..2bd6ac0 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -91,8 +91,8 @@
   if (tensors_.size() != rhs->tensors_.size()) {
     return errors::InvalidArgument(
         "TensorArray sizes do not match during CopyShapesFrom: ",
-        handle_.vec<string>()(1), " has size ", tensors_.size(), " but rhs ",
-        rhs->handle_.vec<string>()(1), " has size ", rhs->tensors_.size());
+        handle_.vec<tstring>()(1), " has size ", tensors_.size(), " but rhs ",
+        rhs->handle_.vec<tstring>()(1), " has size ", rhs->tensors_.size());
   }
   for (std::size_t i = 0; i < tensors_.size(); ++i) {
     // Skip "soft copy" of indices which have not been written.
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 964b463..bea97d1 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -365,7 +365,7 @@
 
   Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (closed_) {
-      return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
+      return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                      " has already been closed.");
     }
     return Status::OK();
@@ -447,7 +447,7 @@
   size_t index_size = static_cast<size_t>(index);
   if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
     return errors::InvalidArgument(
-        "TensorArray ", handle_.vec<string>()(1), ": Tried to write to index ",
+        "TensorArray ", handle_.vec<tstring>()(1), ": Tried to write to index ",
         index, " but array is not resizeable and size is: ", tensors_.size());
   }
   if (dynamic_size_) {
@@ -464,14 +464,14 @@
   Tensor* value_t = value->AccessTensor(ctx);
   if (value_t->dtype() != dtype_) {
     return errors::InvalidArgument(
-        "TensorArray ", handle_.vec<string>()(1),
+        "TensorArray ", handle_.vec<tstring>()(1),
         ": Could not write to TensorArray index ", index,
         " because the value dtype is ", DataTypeString(value_t->dtype()),
         " but TensorArray dtype is ", DataTypeString(dtype_), ".");
   }
   if (!element_shape_.IsCompatibleWith(value_t->shape())) {
     return errors::InvalidArgument(
-        "TensorArray ", handle_.vec<string>()(1),
+        "TensorArray ", handle_.vec<tstring>()(1),
         ": Could not write to TensorArray index ", index,
         " because the value shape is ", value_t->shape().DebugString(),
         " which is incompatible with the TensorArray's inferred element "
@@ -482,13 +482,13 @@
   }
 
   if (t.read) {
-    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
+    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                    ": Could not write to TensorArray index ",
                                    index, " because it has already been read.");
   }
 
   if (!multiple_writes_aggregate_ && t.written) {
-    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
+    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                    ": Could not write to TensorArray index ",
                                    index,
                                    " because it has already been written to.");
@@ -500,7 +500,7 @@
     // Check that value_t shape matches t.shape
     if (value_t->shape() != t.shape) {
       return errors::InvalidArgument(
-          "TensorArray ", handle_.vec<string>()(1),
+          "TensorArray ", handle_.vec<tstring>()(1),
           ": Could not aggregate to TensorArray index ", index,
           " because the existing shape is ", t.shape.DebugString(),
           " but the new input shape is ", value_t->shape().DebugString(), ".");
@@ -568,7 +568,7 @@
       element_shape = tensors_[index].shape;
     } else if (!element_shape_.IsFullyDefined()) {
       return errors::InvalidArgument(
-          "TensorArray ", handle_.vec<string>()(1),
+          "TensorArray ", handle_.vec<tstring>()(1),
           ": Could not read from TensorArray index ", index,
           ".  Furthermore, the element shape is not fully defined: ",
           element_shape_.DebugString(),
@@ -598,7 +598,7 @@
   TensorAndState& t = tensors_[index];
 
   if (t.cleared) {
-    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
+    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                    ": Could not read index ", index,
                                    " twice because it was cleared after a "
                                    "previous read (perhaps try setting "
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index d5c9470..52162e9 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -65,7 +65,7 @@
           "Tensor array handle must be 2-element vector, but had shape: ",
           tensor.shape().DebugString());
     }
-    auto h = tensor.flat<string>();
+    auto h = tensor.flat<tstring>();
     *container = h(0);
     *ta_handle = h(1);
   }
@@ -194,7 +194,7 @@
       return errors::InvalidArgument("Size should be >= 0.");
     }
 
-    auto handle = tensor_array_output_handle->flat<string>();
+    auto handle = tensor_array_output_handle->flat<tstring>();
     string unique_tensor_array_name =
         strings::StrCat(tensor_array_name_, "_",
                         TensorArray::tensor_array_counter.fetch_add(1));
@@ -301,7 +301,7 @@
           string(StringPiece(resource.name()).substr(container.size()));
     }
 
-    auto output_handle = tensor_array_output_handle->flat<string>();
+    auto output_handle = tensor_array_output_handle->flat<tstring>();
     output_handle(0) = "_tensor_array_grads";
     output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
 
diff --git a/tensorflow/core/kernels/tensor_forest/resource_ops.cc b/tensorflow/core/kernels/tensor_forest/resource_ops.cc
index c225d83..0c7b9e9 100644
--- a/tensorflow/core/kernels/tensor_forest/resource_ops.cc
+++ b/tensorflow/core/kernels/tensor_forest/resource_ops.cc
@@ -34,7 +34,7 @@
 
     auto* const result = new TensorForestTreeResource();
 
-    if (!result->InitFromSerialized(tree_config_t->scalar<string>()())) {
+    if (!result->InitFromSerialized(tree_config_t->scalar<tstring>()())) {
       result->Unref();
       OP_REQUIRES(context, false,
                   errors::InvalidArgument("Unable to parse tree config."));
@@ -63,7 +63,7 @@
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(0, TensorShape(), &output_config_t));
-    output_config_t->scalar<string>()() =
+    output_config_t->scalar<tstring>()() =
         decision_tree_resource->decision_tree().SerializeAsString();
   }
 };
@@ -86,7 +86,7 @@
     decision_tree_resource->Reset();
 
     if (!decision_tree_resource->InitFromSerialized(
-            tree_config_t->scalar<string>()())) {
+            tree_config_t->scalar<tstring>()())) {
       OP_REQUIRES(context, false,
                   errors::InvalidArgument("Unable to parse tree config."));
     }
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 7451004..330e02c 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -162,6 +162,20 @@
 };
 
 template <typename T>
+struct ApplyAdagradV2<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
+    if (update_slots) {
+      accum.device(d) += grad.square();
+    }
+    var.device(d) -= grad * lr() / (accum.sqrt() + epsilon());
+  }
+};
+
+template <typename T>
 struct ApplyProximalAdagrad<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat accum,
@@ -1265,6 +1279,106 @@
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
+class ApplyAdagradV2Op : public OpKernel {
+ public:
+  explicit ApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const bool sparse = false;
+    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
+        ctx, use_exclusive_lock_, sparse, {0, 1});
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 0, use_exclusive_lock_, sparse, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 1, use_exclusive_lock_, sparse, &accum));
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(0)));
+    OP_REQUIRES(
+        ctx, accum.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(1)));
+    const Tensor& lr = ctx->input(2);
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    const Tensor& epsilon = ctx->input(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+    const Tensor& grad = ctx->input(4);
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(accum.shape()),
+        errors::InvalidArgument("var and accum do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                accum.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(grad.shape()),
+        errors::InvalidArgument("var and grad do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                grad.shape().DebugString()));
+
+    const Device& device = ctx->template eigen_device<Device>();
+    functor::ApplyAdagradV2<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
+                                         lr.scalar<T>(), epsilon.scalar<T>(),
+                                         grad.flat<T>(), update_slots_);
+
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+  bool update_slots_;
+};
+
+#define REGISTER_KERNELS(D, T)                                          \
+  REGISTER_KERNEL_BUILDER(                                              \
+      Name("ApplyAdagradV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+      ApplyAdagradV2Op<D##Device, T>);                                  \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradV2")                \
+                              .HostMemory("var")                        \
+                              .HostMemory("accum")                      \
+                              .Device(DEVICE_##D)                       \
+                              .TypeConstraint<T>("T"),                  \
+                          ApplyAdagradV2Op<D##Device, T>);
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
+
+TF_CALL_half(REGISTER_CPU_KERNELS);
+TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
+TF_CALL_float(REGISTER_CPU_KERNELS);
+TF_CALL_double(REGISTER_CPU_KERNELS);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                               \
+  template <>                                                             \
+  void ApplyAdagradV2<GPUDevice, T>::operator()(                          \
+      const GPUDevice& d, typename TTypes<T>::Flat var,                   \
+      typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
+      typename TTypes<T>::ConstScalar epsilon,                            \
+      typename TTypes<T>::ConstFlat grad, bool update_slots);             \
+  extern template struct ApplyAdagradV2<GPUDevice, T>;
+DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half);
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
 class ApplyProximalAdagradOp : public OpKernel {
  public:
   explicit ApplyProximalAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1532,6 +1646,179 @@
 
 // Note, this op works on cpu only.
 template <typename T, typename Tindex>
+class SparseApplyAdagradV2Op : public OpKernel {
+ public:
+  explicit SparseApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
+  }
+
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    const bool sparse = true;
+    auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
+        ctx, use_exclusive_lock_, sparse, {0, 1});
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+                            ctx, 0, use_exclusive_lock_, sparse, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
+                            ctx, 1, use_exclusive_lock_, sparse, &accum));
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(0)));
+    OP_REQUIRES(
+        ctx, accum.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(1)));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(accum.shape()),
+        errors::InvalidArgument("var and accum do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                accum.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+                errors::InvalidArgument("var must be at least 1 dimensional"));
+
+    const Tensor& lr = ctx->input(2);
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    const Tensor& epsilon = ctx->input(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+    const Tensor& grad = ctx->input(4);
+    const Tensor& indices = ctx->input(5);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+                errors::InvalidArgument("indices must be one-dimensional"));
+
+    int64 inner_dim = 1;
+    for (int d = 1; d < var.dims(); d++) {
+      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+                  errors::InvalidArgument(strings::StrCat(
+                      "var and grad must match in dimension ", d)));
+      inner_dim *= grad.dim_size(d);
+    }
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    OP_REQUIRES(ctx, inner_dim > 0,
+                errors::InvalidArgument(
+                    "Inner dimension should be greater than zero."));
+
+    // This op is implemented only for CPU device.
+    const auto& d = ctx->eigen_cpu_device();
+
+    if (N > 0) {
+      const int in_bytes = inner_dim * sizeof(T) * 3;
+      const int out_bytes = inner_dim * sizeof(T) * 2;
+      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                      Eigen::TensorOpCost::MulCost<T>() * 2);
+      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+
+      if (inner_dim > 1) {
+        const Tindex first_dim_size = var.dim_size(0);
+        auto indices_vec = indices.vec<Tindex>();
+        auto var_flat = var.flat_outer_dims<T>();
+        auto accum_flat = accum.flat_outer_dims<T>();
+        auto grad_flat = grad.flat_outer_dims<T>();
+        const T lr_scalar = lr.scalar<T>()();
+        const T epsilon_scalar = epsilon.scalar<T>()();
+
+        for (Tindex i = 0; i < N; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
+                      errors::InvalidArgument(
+                          strings::StrCat("Index ", index, " at offset ", i,
+                                          " in indices is out of range")));
+        }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            auto a = accum_flat.template chip<0>(index);
+            auto g = grad_flat.template chip<0>(i);
+            auto v = var_flat.template chip<0>(index);
+            if (update_slots_) {
+              a += g.square();
+            }
+            v -= g.constant(lr_scalar) * g /
+                 (a.sqrt() + a.constant(epsilon_scalar));
+          }
+        };
+
+        d.parallelFor(N, cost, shard);
+
+      } else {
+        auto indices_vec = indices.vec<Tindex>();
+        auto var_flat = var.flat<T>();
+        auto accum_flat = accum.flat<T>();
+        auto grad_flat = grad.flat<T>();
+        T lr_scalar = lr.scalar<T>()();
+        const T epsilon_scalar = epsilon.scalar<T>()();
+        const Tindex first_dim_size = accum_flat.size();
+
+        for (Tindex i = 0; i < N; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
+                      errors::InvalidArgument(
+                          strings::StrCat("Index ", index, " at offset ", i,
+                                          " in indices is out of range")));
+        }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            T& a = accum_flat(index);
+            const T& g = grad_flat(i);
+            if (update_slots_) {
+              a += g * g;
+            }
+            var_flat(index) -=
+                lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon_scalar);
+          }
+        };
+
+        d.parallelFor(N, cost, shard);
+      }
+    }
+
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+  bool update_slots_;
+};
+
+#define REGISTER_KERNELS(T, Tindices)                                \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2")               \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdagradV2Op<T, Tindices>);      \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2")       \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdagradV2Op<T, Tindices>);
+#define REGISTER_CPU_KERNELS(T) \
+  REGISTER_KERNELS(T, int32);   \
+  REGISTER_KERNELS(T, int64);
+
+TF_CALL_half(REGISTER_CPU_KERNELS);
+TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
+TF_CALL_float(REGISTER_CPU_KERNELS);
+TF_CALL_double(REGISTER_CPU_KERNELS);
+
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
+template <typename T, typename Tindex>
 class SparseApplyProximalAdagradOp : public OpKernel {
  public:
   explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 054f073..e1776dd 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -72,6 +72,15 @@
 };
 
 template <typename Device, typename T>
+struct ApplyAdagradV2 {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad, bool update_slots);
+};
+
+template <typename Device, typename T>
 struct ApplyAdagradDA {
   void operator()(const Device& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat gradient_accum,
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index e67ac07..b9240cc 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -54,6 +54,25 @@
 };
 
 template <typename T>
+struct ApplyAdagradV2<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    if (update_slots) {
+      accum.device(d) += grad.square();
+    }
+    const auto update =
+        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
+    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
+  }
+};
+
+template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat accum,
@@ -348,6 +367,10 @@
 template struct functor::ApplyAdagrad<GPUDevice, float>;
 template struct functor::ApplyAdagrad<GPUDevice, double>;
 
+template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
+template struct functor::ApplyAdagradV2<GPUDevice, float>;
+template struct functor::ApplyAdagradV2<GPUDevice, double>;
+
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc
index 59ebbed..0bb5f0f 100644
--- a/tensorflow/core/kernels/unicode_ops.cc
+++ b/tensorflow/core/kernels/unicode_ops.cc
@@ -295,10 +295,10 @@
     } else {
       OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
                                                &output_tensor));
-      output_tensor->flat<string>() = input_tensor->flat<string>();
+      output_tensor->flat<tstring>() = input_tensor->flat<tstring>();
     }
 
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
     bool found_any_format_error = false;
     for (size_t i = 0; i < output_flat.size(); ++i) {
       Transcode(&(output_flat(i)), input_encoder->converter_,
@@ -404,7 +404,7 @@
     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
 
     // Go through all the strings in `input`.
-    const auto& input_vec = input_tensor->flat<string>();
+    const auto& input_vec = input_tensor->flat<tstring>();
 
     std::unique_ptr<WrappedConverter> input_encoder =
         absl::make_unique<WrappedConverter>();
@@ -538,7 +538,7 @@
     Tensor* output_tensor;
     OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                      &output_tensor));
-    auto output_tensor_flat = output_tensor->flat<string>();
+    auto output_tensor_flat = output_tensor->flat<tstring>();
 
     // Use a single index over the flattened input values tensor.
     int idx = 0;
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index adf84ba..4968284 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -237,6 +237,7 @@
                           UniqueOp<type, int64>)
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
 REGISTER_UNIQUE(string)
+REGISTER_UNIQUE(bool)
 #undef REGISTER_UNIQUE
 
 // Fake integer GPU kernels so that the use of Unique in optimizers (to
diff --git a/tensorflow/core/kernels/unsorted_segment_join_op.cc b/tensorflow/core/kernels/unsorted_segment_join_op.cc
index 4ab890c..f0b9388 100644
--- a/tensorflow/core/kernels/unsorted_segment_join_op.cc
+++ b/tensorflow/core/kernels/unsorted_segment_join_op.cc
@@ -115,9 +115,9 @@
                                                      &output_tensor));
 
     // Preprating flat tensors.
-    auto output_flat = output_tensor->flat<string>();
+    auto output_flat = output_tensor->flat<tstring>();
     auto flat_segment_id = segment_id.flat<INDICES_TYPE>();
-    auto flat_input = input.flat<string>();
+    auto flat_input = input.flat<tstring>();
 
     for (int i = 0; i < flat_segment_id.size(); i++) {
       OP_REQUIRES(
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index b617b76..1e3b7fd 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -135,14 +135,14 @@
                 errors::InvalidArgument(
                     "Contents tensor must be scalar, but had shape: ",
                     contents_input->shape().DebugString()));
-    const string& filename = filename_input->scalar<string>()();
+    const string& filename = filename_input->scalar<tstring>()();
     const string dir(io::Dirname(filename));
     if (!context->env()->FileExists(dir).ok()) {
       OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
     }
     OP_REQUIRES_OK(context,
                    WriteStringToFile(context->env(), filename,
-                                     contents_input->scalar<string>()()));
+                                     contents_input->scalar<tstring>()()));
   }
 };
 
diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc
index 3477445..42b70e9 100644
--- a/tensorflow/core/kernels/word2vec_kernels.cc
+++ b/tensorflow/core/kernels/word2vec_kernels.cc
@@ -209,14 +209,14 @@
     vocab_size_ = static_cast<int32>(1 + ordered.size());
     Tensor word(DT_STRING, TensorShape({vocab_size_}));
     Tensor freq(DT_INT32, TensorShape({vocab_size_}));
-    word.flat<string>()(0) = "UNK";
+    word.flat<tstring>()(0) = "UNK";
     static const int32 kUnkId = 0;
     std::unordered_map<string, int32> word_id;
     int64 total_counted = 0;
     for (std::size_t i = 0; i < ordered.size(); ++i) {
       const auto& w = ordered[i].first;
       auto id = i + 1;
-      word.flat<string>()(id) = w;
+      word.flat<tstring>()(id) = w;
       auto word_count = ordered[i].second;
       freq.flat<int32>()(id) = word_count;
       total_counted += word_count;
diff --git a/tensorflow/core/lib/bfloat16/BUILD b/tensorflow/core/lib/bfloat16/BUILD
new file mode 100644
index 0000000..4f955c3
--- /dev/null
+++ b/tensorflow/core/lib/bfloat16/BUILD
@@ -0,0 +1,21 @@
+package(
+    default_visibility = [
+        "//tensorflow:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "bfloat16",
+    srcs = ["bfloat16.cc"],
+    hdrs = ["bfloat16.h"],
+    deps = [
+        "//tensorflow/core/platform:byte_order",
+        "//third_party/eigen3",
+    ],
+)
+
+# TODO(bmzhao): Remove the following once references in core/BUILD is removed.
+exports_files(
+    glob(["*"]),
+)
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 6edff13..ba389e5 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -26,13 +26,6 @@
 #ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
 #define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
 
-#include "absl/strings/string_view.h"
-
-namespace tensorflow {
-
-// Deprecated: please use absl::string_view directly.
-using StringPiece = absl::string_view;
-
-}  // namespace tensorflow
+#include "tensorflow/core/platform/stringpiece.h"
 
 #endif  // TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
diff --git a/tensorflow/core/lib/core/stringpiece_test.cc b/tensorflow/core/lib/core/stringpiece_test.cc
deleted file mode 100644
index e4b489f..0000000
--- a/tensorflow/core/lib/core/stringpiece_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-#include <unordered_map>
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-TEST(StringPiece, Ctor) {
-  {
-    // const char* without size.
-    const char* hello = "hello";
-    StringPiece s20(hello);
-    EXPECT_TRUE(s20.data() == hello);
-    EXPECT_EQ(5, s20.size());
-
-    // const char* with size.
-    StringPiece s21(hello, 4);
-    EXPECT_TRUE(s21.data() == hello);
-    EXPECT_EQ(4, s21.size());
-
-    // Not recommended, but valid C++
-    StringPiece s22(hello, 6);
-    EXPECT_TRUE(s22.data() == hello);
-    EXPECT_EQ(6, s22.size());
-  }
-
-  {
-    string hola = "hola";
-    StringPiece s30(hola);
-    EXPECT_TRUE(s30.data() == hola.data());
-    EXPECT_EQ(4, s30.size());
-
-    // std::string with embedded '\0'.
-    hola.push_back('\0');
-    hola.append("h2");
-    hola.push_back('\0');
-    StringPiece s31(hola);
-    EXPECT_TRUE(s31.data() == hola.data());
-    EXPECT_EQ(8, s31.size());
-  }
-}
-
-TEST(StringPiece, ConversionToString) {
-  EXPECT_EQ("", string(StringPiece("")));
-  EXPECT_EQ("foo", string(StringPiece("foo")));
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/BUILD b/tensorflow/core/lib/gtl/BUILD
new file mode 100644
index 0000000..fca412f
--- /dev/null
+++ b/tensorflow/core/lib/gtl/BUILD
@@ -0,0 +1,187 @@
+package(
+    default_visibility = [
+        "//tensorflow:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Todo(bmzhao): Remaining targets to add to this BUILD file are:
+# compactptrset, flatmap, flatset, manual_constructor, + all tests.
+
+cc_library(
+    name = "array_slice",
+    hdrs = ["array_slice.h"],
+    deps = [
+        "//tensorflow/core/lib/gtl:inlined_vector",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "cleanup",
+    hdrs = ["cleanup.h"],
+    deps = ["//tensorflow/core/platform:macros"],
+)
+
+cc_library(
+    name = "edit_distance",
+    hdrs = ["edit_distance.h"],
+    deps = [
+        "//tensorflow/core/lib/gtl:array_slice",
+        "//tensorflow/core/lib/gtl:inlined_vector",
+    ],
+)
+
+cc_library(
+    name = "flatrep",
+    hdrs = ["flatrep.h"],
+    deps = [
+        "//tensorflow/core/platform:prefetch",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "inlined_vector",
+    hdrs = ["inlined_vector.h"],
+    deps = [
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+        "@com_google_absl//absl/container:inlined_vector",
+    ],
+)
+
+cc_library(
+    name = "int_type",
+    hdrs = ["int_type.h"],
+    deps = [
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
+cc_library(
+    name = "iterator_range",
+    hdrs = ["iterator_range.h"],
+    deps = [],
+)
+
+cc_library(
+    name = "map_util",
+    srcs = [
+        "map_util.h",
+        "subtle/map_traits.h",
+    ],
+    hdrs = ["map_util.h"],
+)
+
+cc_library(
+    name = "optional",
+    hdrs = ["optional.h"],
+    deps = ["@com_google_absl//absl/types:optional"],
+)
+
+cc_library(
+    name = "priority_queue_util",
+    hdrs = ["priority_queue_util.h"],
+    deps = [],
+)
+
+cc_library(
+    name = "stl_util",
+    hdrs = ["stl_util.h"],
+    deps = ["@com_google_absl//absl/meta:type_traits"],
+)
+
+cc_library(
+    name = "top_n",
+    hdrs = ["top_n.h"],
+    deps = ["//tensorflow/core/platform:logging"],
+)
+
+filegroup(
+    name = "legacy_lib_gtl_headers",
+    srcs = [
+        "array_slice.h",
+        "cleanup.h",
+        "compactptrset.h",
+        "edit_distance.h",
+        "flatmap.h",
+        "flatset.h",
+        "inlined_vector.h",
+        "optional.h",
+        "priority_queue_util.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_internal_public_gtl_headers",
+    srcs = [
+        "edit_distance.h",
+        "int_type.h",
+        "iterator_range.h",
+        "manual_constructor.h",
+        "map_util.h",
+        "stl_util.h",
+        "top_n.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_test_internal_headers",
+    srcs = [
+        "manual_constructor.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_android_gif_internal_headers",
+    srcs = [
+        "cleanup.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_gtl_all_headers",
+    srcs = [
+        "array_slice.h",
+        "cleanup.h",
+        "compactptrset.h",
+        "edit_distance.h",
+        "flatmap.h",
+        "flatrep.h",
+        "flatset.h",
+        "inlined_vector.h",
+        "int_type.h",
+        "iterator_range.h",
+        "manual_constructor.h",
+        "map_util.h",
+        "optional.h",
+        "priority_queue_util.h",
+        "stl_util.h",
+        "subtle/map_traits.h",
+        "top_n.h",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_gtl_tests",
+    srcs = [
+        "cleanup_test.cc",
+        "compactptrset_test.cc",
+        "edit_distance_test.cc",
+        "flatmap_test.cc",
+        "flatset_test.cc",
+        "int_type_test.cc",
+        "iterator_range_test.cc",
+        "manual_constructor_test.cc",
+        "map_util_test.cc",
+        "top_n_test.cc",
+    ],
+    visibility = ["//tensorflow/core:__pkg__"],
+)
diff --git a/tensorflow/core/lib/strings/stringprintf.cc b/tensorflow/core/lib/strings/stringprintf.cc
deleted file mode 100644
index bbffa06..0000000
--- a/tensorflow/core/lib/strings/stringprintf.cc
+++ /dev/null
@@ -1,93 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/strings/stringprintf.h"
-
-#include <errno.h>
-#include <stdarg.h>  // For va_list and related operations
-#include <stdio.h>   // MSVC requires this for _vsnprintf
-
-namespace tensorflow {
-namespace strings {
-
-void Appendv(string* dst, const char* format, va_list ap) {
-  // First try with a small fixed size buffer
-  static const int kSpaceLength = 1024;
-  char space[kSpaceLength];
-
-  // It's possible for methods that use a va_list to invalidate
-  // the data in it upon use.  The fix is to make a copy
-  // of the structure before using it and use that copy instead.
-  va_list backup_ap;
-  va_copy(backup_ap, ap);
-  int result = vsnprintf(space, kSpaceLength, format, backup_ap);
-  va_end(backup_ap);
-
-  if (result < kSpaceLength) {
-    if (result >= 0) {
-      // Normal case -- everything fit.
-      dst->append(space, result);
-      return;
-    }
-
-#ifdef _MSC_VER
-      // Error or MSVC running out of space.  MSVC 8.0 and higher
-      // can be asked about space needed with the special idiom below:
-      va_copy(backup_ap, ap);
-      result = vsnprintf(nullptr, 0, format, backup_ap);
-      va_end(backup_ap);
-#endif
-
-    if (result < 0) {
-      // Just an error.
-      return;
-    }
-  }
-
-  // Increase the buffer size to the size requested by vsnprintf,
-  // plus one for the closing \0.
-  int length = result + 1;
-  char* buf = new char[length];
-
-  // Restore the va_list before we use it again
-  va_copy(backup_ap, ap);
-  result = vsnprintf(buf, length, format, backup_ap);
-  va_end(backup_ap);
-
-  if (result >= 0 && result < length) {
-    // It fit
-    dst->append(buf, result);
-  }
-  delete[] buf;
-}
-
-string Printf(const char* format, ...) {
-  va_list ap;
-  va_start(ap, format);
-  string result;
-  Appendv(&result, format, ap);
-  va_end(ap);
-  return result;
-}
-
-void Appendf(string* dst, const char* format, ...) {
-  va_list ap;
-  va_start(ap, format);
-  Appendv(dst, format, ap);
-  va_end(ap);
-}
-
-}  // namespace strings
-}  // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h
index 52af410..836632d 100644
--- a/tensorflow/core/lib/strings/stringprintf.h
+++ b/tensorflow/core/lib/strings/stringprintf.h
@@ -23,30 +23,6 @@
 #ifndef TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
 #define TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
 
-#include <stdarg.h>
-#include <string>
-
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace strings {
-
-// Return a C++ string
-extern string Printf(const char* format, ...)
-    // Tell the compiler to do printf format string checking.
-    TF_PRINTF_ATTRIBUTE(1, 2);
-
-// Append result to a supplied string
-extern void Appendf(string* dst, const char* format, ...)
-    // Tell the compiler to do printf format string checking.
-    TF_PRINTF_ATTRIBUTE(2, 3);
-
-// Lower-level routine that takes a va_list and appends to a specified
-// string.  All other routines are just convenience wrappers around it.
-extern void Appendv(string* dst, const char* format, va_list ap);
-
-}  // namespace strings
-}  // namespace tensorflow
+#include "tensorflow/core/platform/stringprintf.h"
 
 #endif  // TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
diff --git a/tensorflow/core/lib/strings/stringprintf_test.cc b/tensorflow/core/lib/strings/stringprintf_test.cc
deleted file mode 100644
index 02cf4cb..0000000
--- a/tensorflow/core/lib/strings/stringprintf_test.cc
+++ /dev/null
@@ -1,128 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/strings/stringprintf.h"
-
-#include <string>
-
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace strings {
-namespace {
-
-TEST(PrintfTest, Empty) {
-  EXPECT_EQ("", Printf("%s", string().c_str()));
-  EXPECT_EQ("", Printf("%s", ""));
-}
-
-TEST(PrintfTest, Misc) {
-// MSVC does not support $ format specifier.
-#if !defined(_MSC_VER)
-  EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123));
-#endif  // !_MSC_VER
-}
-
-TEST(AppendfTest, Empty) {
-  string value("Hello");
-  const char* empty = "";
-  Appendf(&value, "%s", empty);
-  EXPECT_EQ("Hello", value);
-}
-
-TEST(AppendfTest, EmptyString) {
-  string value("Hello");
-  Appendf(&value, "%s", "");
-  EXPECT_EQ("Hello", value);
-}
-
-TEST(AppendfTest, String) {
-  string value("Hello");
-  Appendf(&value, " %s", "World");
-  EXPECT_EQ("Hello World", value);
-}
-
-TEST(AppendfTest, Int) {
-  string value("Hello");
-  Appendf(&value, " %d", 123);
-  EXPECT_EQ("Hello 123", value);
-}
-
-TEST(PrintfTest, Multibyte) {
-  // If we are in multibyte mode and feed invalid multibyte sequence,
-  // Printf should return an empty string instead of running
-  // out of memory while trying to determine destination buffer size.
-  // see b/4194543.
-
-  char* old_locale = setlocale(LC_CTYPE, nullptr);
-  // Push locale with multibyte mode
-  setlocale(LC_CTYPE, "en_US.utf8");
-
-  const char kInvalidCodePoint[] = "\375\067s";
-  string value = Printf("%.*s", 3, kInvalidCodePoint);
-
-  // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf
-  // returns error given an invalid codepoint. Other versions
-  // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim.
-  // We test that the output is one of the above.
-  EXPECT_TRUE(value.empty() || value == kInvalidCodePoint);
-
-  // Repeat with longer string, to make sure that the dynamically
-  // allocated path in StringAppendV is handled correctly.
-  int n = 2048;
-  char* buf = new char[n + 1];
-  memset(buf, ' ', n - 3);
-  memcpy(buf + n - 3, kInvalidCodePoint, 4);
-  value = Printf("%.*s", n, buf);
-  // See GRTEv2 vs. GRTEv3 comment above.
-  EXPECT_TRUE(value.empty() || value == buf);
-  delete[] buf;
-
-  setlocale(LC_CTYPE, old_locale);
-}
-
-TEST(PrintfTest, NoMultibyte) {
-  // No multibyte handling, but the string contains funny chars.
-  char* old_locale = setlocale(LC_CTYPE, nullptr);
-  setlocale(LC_CTYPE, "POSIX");
-  string value = Printf("%.*s", 3, "\375\067s");
-  setlocale(LC_CTYPE, old_locale);
-  EXPECT_EQ("\375\067s", value);
-}
-
-TEST(PrintfTest, DontOverwriteErrno) {
-  // Check that errno isn't overwritten unless we're printing
-  // something significantly larger than what people are normally
-  // printing in their badly written PLOG() statements.
-  errno = ECHILD;
-  string value = Printf("Hello, %s!", "World");
-  EXPECT_EQ(ECHILD, errno);
-}
-
-TEST(PrintfTest, LargeBuf) {
-  // Check that the large buffer is handled correctly.
-  int n = 2048;
-  char* buf = new char[n + 1];
-  memset(buf, ' ', n);
-  buf[n] = 0;
-  string value = Printf("%s", buf);
-  EXPECT_EQ(buf, value);
-  delete[] buf;
-}
-
-}  // namespace
-
-}  // namespace strings
-}  // namespace tensorflow
diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD
index b1f7bcf..f48061c 100644
--- a/tensorflow/core/nccl/BUILD
+++ b/tensorflow/core/nccl/BUILD
@@ -6,7 +6,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_copts")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index 20ba3ca..0746250 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -209,6 +209,9 @@
   std::sort(collective->participants.begin(), collective->participants.end(),
             [](const std::unique_ptr<Participant>& a,
                const std::unique_ptr<Participant>& b) {
+              if (a->executor == b->executor) {
+                return a->global_rank < b->global_rank;
+              }
               return a->executor < b->executor;
             });
 
@@ -402,6 +405,8 @@
       if (CheckReady(collective_key, collective)) {
         to_run = collective;
       }
+      VLOG(2) << "SignalMultiNodeReady collective " << collective_key
+              << " to_run " << to_run;
     }
   }
 
@@ -480,7 +485,18 @@
           collective->participants.size(),
           " with one more participant being added");
     }
+    if (collective->status.ok() && collective->root_rank >= 0 &&
+        context.source_rank >= 0 &&
+        collective->root_rank != context.source_rank) {
+      collective->status = errors::Internal(
+          "Collective ", collective->collective_key, " already has root_rank ",
+          collective->root_rank, " but new participant has root_rank ",
+          context.source_rank);
+    }
 
+    if (context.source_rank >= 0) {
+      collective->root_rank = context.source_rank;
+    }
     collective->participants.emplace_back(std::move(participant));
     ++collective->available_participants;
 
@@ -508,19 +524,12 @@
 void NcclManager::RunCollective(Collective* collective) {
   static mutex collective_mu(LINKER_INITIALIZED);
 
-  Status s = collective->status;
-  if (s.ok()) {
-    s = GetCommunicator(collective, &collective->communicator);
-  }
-  if (!s.ok()) {
-    for (int i = 0; i < collective->num_local_devices; ++i) {
-      collective->participants[i]->done_callback(s);
-    }
-    collective->Unref();
-    return;
+  Status status = collective->status;
+  if (status.ok()) {
+    status = GetCommunicator(collective, &collective->communicator);
   }
 
-  for (int i = 0; i < collective->num_local_devices; ++i) {
+  for (int i = 0; status.ok() && i < collective->num_local_devices; ++i) {
     Participant* p = collective->participants[i].get();
     NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
     CHECK(nccl_stream != nullptr);
@@ -533,13 +542,30 @@
       nccl_stream->stream->ThenWaitFor(p->tensor_stream);
     }
     if (p->root) {
-      CHECK_EQ(collective->root_rank, -1);
-      collective->root_rank = rank;
+      if (collective->root_rank == -1) {
+        collective->root_rank = rank;
+      } else if (collective->root_rank != rank) {
+        status = errors::Internal(
+            "Inconsistent root rank ", collective->root_rank, " and GPU id ",
+            p->gpu_device_id, " rank ", rank, " also marked as root.");
+      }
     }
+    VLOG(2) << "RunCollective rank " << rank << " global_rank "
+            << p->global_rank << " root_rank " << collective->root_rank;
   }
 
-  if (collective->type == kBroadcast) {
-    CHECK_NE(collective->root_rank, -1);
+  if (status.ok() && collective->type == kBroadcast &&
+      collective->root_rank < 0) {
+    status = errors::Internal("Root rank not indicated for collective ",
+                              collective->collective_key);
+  }
+
+  if (!status.ok()) {
+    for (int i = 0; i < collective->num_local_devices; ++i) {
+      collective->participants[i]->done_callback(status);
+    }
+    collective->Unref();
+    return;
   }
 
   {
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index ebb2aab..a4d5d13 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -115,11 +115,13 @@
   // operation key, number of participants, and communicator key.
   struct Context {
     Context(const string& collective_key, int num_local_devices,
-            int num_global_devices, const string& communicator_key)
+            int num_global_devices, const string& communicator_key,
+            int source_rank)
         : collective_key(collective_key),
           num_local_devices(num_local_devices),
           num_global_devices(num_global_devices),
-          communicator_key(communicator_key) {}
+          communicator_key(communicator_key),
+          source_rank(source_rank) {}
 
     // Unique key for this collective instance
     const string& collective_key;
@@ -137,6 +139,9 @@
     // `communicator_key` is not required for single-node collectives and can be
     // empty.
     const string& communicator_key;
+
+    // Rank of broadcast source.
+    int source_rank;
   };
 
   // Adds one participant to an all-reduce.
diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc
index 161a889..44ae34a 100644
--- a/tensorflow/core/nccl/nccl_manager_test.cc
+++ b/tensorflow/core/nccl/nccl_manager_test.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 
@@ -60,7 +61,8 @@
 
     mutex mu;
     Status final_status;
-    int num_completed = 0;
+    int num_completed GUARDED_BY(mu) = 0;
+    condition_variable done_cv;
   };
 
   static void SetUpTestSuite() {
@@ -68,13 +70,20 @@
     setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
     devices_ = new std::vector<std::unique_ptr<BaseGPUDevice>>(GetGPUDevices());
     LOG(INFO) << "Running test with " << devices_->size() << " gpus";
+    work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test");
   }
 
-  void SetUp() override { ASSERT_GT(devices_->size(), 0) << "No GPUs found"; }
+  void SetUp() override {
+    ASSERT_GT(devices_->size(), 0) << "No GPUs found";
+    ASSERT_NE(work_queue_, nullptr);
+  }
 
   static int32 NumGPUs() { return static_cast<int32>(devices_->size()); }
 
-  static void TearDownTestSuite() { delete devices_; }
+  static void TearDownTestSuite() {
+    delete devices_;
+    delete work_queue_;
+  }
 
   TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node,
                                   ncclRedOp_t reduction_op, TensorShape shape,
@@ -221,13 +230,10 @@
 
   // Waits for the done callback to be called for each participant.
   void WaitForTestCompletion(TestCase* test_case) {
-    test_case->mu.lock();
+    mutex_lock l(test_case->mu);
     while (test_case->num_completed != test_case->outs.size()) {
-      test_case->mu.unlock();
-      Env::Default()->SleepForMicroseconds(10);
-      test_case->mu.lock();
+      test_case->done_cv.wait(l);
     }
-    test_case->mu.unlock();
   }
 
   void VerifyResults(TestCase* test_case) {
@@ -259,23 +265,31 @@
   NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
     return [this, test_case](Status s) {
       mutex_lock l(test_case->mu);
-      ++test_case->num_completed;
       test_case->final_status.Update(s);
+      if (++test_case->num_completed == test_case->outs.size()) {
+        test_case->done_cv.notify_one();
+      }
     };
   }
 
-  void RunMultiNodeTest(const int num_nodes, const int num_ranks_per_node) {
+  struct NodeState {
+    NcclManager nccl_manager;
+    std::atomic<int> launched{0};
+  };
+
+  void RunMultiNodeAllReduceTest(const int num_nodes,
+                                 const int num_ranks_per_node) {
     const int num_global_ranks = num_nodes * num_ranks_per_node;
-    std::vector<NcclManager> nccl_managers(num_nodes);
+    std::vector<NodeState> node_states(num_nodes);
     const string collective_key = "allreduce";
     // The NcclManagers in this test synchronize in real-time, so we need to run
     // each node's code in a separate thread.
     // Specifically, the call to ncclGroupEnd() after calling ncclCommInitRank
     // waits for all communicators before returning.
-    thread::ThreadPool pool(Env::Default(), "test_multi_node_nccl", num_nodes);
 
     // First, initialize the communicator_key used for this collective.
-    const string communicator_key = nccl_managers[0].GenerateCommunicatorKey();
+    const string communicator_key =
+        node_states[0].nccl_manager.GenerateCommunicatorKey();
 
     for (int op = 0; op < 4; ++op) {
       ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
@@ -284,7 +298,7 @@
                                       reduction_op, TensorShape({2, 3}), 0.0f));
       for (int node = 0; node < num_nodes; ++node) {
         auto node_fn = [this, node, num_ranks_per_node, num_global_ranks,
-                        &nccl_managers, &communicator_key, &collective_key,
+                        &node_states, &communicator_key, &collective_key,
                         reduction_op, &test_case] {
           for (int local_rank = 0; local_rank < num_ranks_per_node;
                ++local_rank) {
@@ -296,19 +310,19 @@
                 device->executor(), stream, event_mgr, device->gpu_id(),
                 &test_case->ins[global_rank], &test_case->outs[global_rank],
                 global_rank, this->CreateDoneCallback(test_case.get()));
-            nccl_managers[node].AddToAllReduce(
+            node_states[node].nccl_manager.AddToAllReduce(
                 std::move(participant),
                 {collective_key, num_ranks_per_node, num_global_ranks,
-                 communicator_key},
+                 communicator_key, /*source_rank=*/-1},
                 reduction_op);
             VLOG(1) << "AddToAllReduce node " << node << " global_rank "
                     << global_rank;
           }
 
           // Signal collective ready to launch at this node.
-          nccl_managers[node].SignalMultiNodeReady(collective_key);
+          node_states[node].nccl_manager.SignalMultiNodeReady(collective_key);
         };
-        pool.Schedule(node_fn);
+        this->work_queue_->Schedule(node_fn);
       }
 
       VLOG(2) << "Verifying results";
@@ -316,10 +330,74 @@
     }
   }
 
+  void RunMultiNodeBroadcastTest(const int num_nodes,
+                                 const int num_ranks_per_node,
+                                 const int src_node, const int src_local_rank,
+                                 const bool in_place) {
+    const int num_global_ranks = num_nodes * num_ranks_per_node;
+    const int src_global_rank = src_node * num_ranks_per_node + src_local_rank;
+    const string collective_key = "broadcast";
+    std::vector<NodeState> node_states(num_nodes);
+    const string communicator_key =
+        node_states[0].nccl_manager.GenerateCommunicatorKey();
+    std::unique_ptr<TestCase> test_case(this->MakeBroadcastTestCase(
+        num_nodes, num_ranks_per_node, TensorShape({5, 6}), src_node,
+        src_local_rank, in_place));
+    for (int node = 0; node < num_nodes; ++node) {
+      for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
+        // Launch each rank in a separate thread to test concurrent,
+        // randomly-ordered calls into NcclManager.
+        auto rank_fn = [this, node, num_ranks_per_node, num_global_ranks,
+                        src_global_rank, local_rank, &node_states,
+                        &collective_key, &communicator_key, &test_case]() {
+          auto* device = this->GetDevice(local_rank);
+          auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+          auto* stream = device->tensorflow_gpu_device_info()->stream;
+          const int global_rank = node * num_ranks_per_node + local_rank;
+          auto* input = global_rank == src_global_rank
+                            ? &test_case->ins[global_rank]
+                            : nullptr;
+          auto* output = test_case->outs[global_rank].NumElements() == 0
+                             ? nullptr
+                             : &test_case->outs[global_rank];
+          auto participant = absl::make_unique<NcclManager::Participant>(
+              device->executor(), stream, event_mgr, device->gpu_id(), input,
+              output, global_rank, this->CreateDoneCallback(test_case.get()));
+          if (global_rank == src_global_rank) {
+            VLOG(1) << "AddBroadcastSend node " << node << " global_rank "
+                    << global_rank;
+            node_states[node].nccl_manager.AddBroadcastSend(
+                std::move(participant),
+                {collective_key, num_ranks_per_node, num_global_ranks,
+                 communicator_key, src_global_rank});
+          } else {
+            VLOG(1) << "AddBroadcastRecv node " << node << " global_rank "
+                    << global_rank;
+            node_states[node].nccl_manager.AddBroadcastRecv(
+                std::move(participant),
+                {collective_key, num_ranks_per_node, num_global_ranks,
+                 communicator_key, src_global_rank});
+          }
+
+          if (++node_states[node].launched == num_ranks_per_node) {
+            // Signal collective ready to launch at this node.
+            node_states[node].nccl_manager.SignalMultiNodeReady(collective_key);
+          }
+        };
+        this->work_queue_->Schedule(std::move(rank_fn));
+      }
+    }
+
+    VLOG(2) << "Verifying results";
+    this->VerifyResults(test_case.get());
+  }
+
   static BaseGPUDevice* GetDevice(size_t rank) {
     return devices_->at(rank % devices_->size()).get();
   }
 
+  static UnboundedWorkQueue* work_queue_;
+
  private:
   static Allocator* GpuAllocator(BaseGPUDevice* device) {
     return device->GetAllocator(AllocatorAttributes());
@@ -331,7 +409,6 @@
     return typed;
   }
 
- private:
   static std::vector<std::unique_ptr<BaseGPUDevice>>* devices_;
   static const DataType data_type_;
   static const Scalar max_;
@@ -346,6 +423,8 @@
 template <typename Scalar>
 const Scalar NcclManagerTest<Scalar>::max_ =
     Eigen::NumTraits<Scalar>::highest();
+template <typename Scalar>
+UnboundedWorkQueue* NcclManagerTest<Scalar>::work_queue_ = nullptr;
 
 // Instantiate tests for float and double.
 using TypeList = ::testing::Types<float, double>;
@@ -372,7 +451,8 @@
       NcclManager::instance()->AddToAllReduce(
           std::move(participant),
           {"allreduce", /*num_local_devices=*/num_ranks,
-           /*num_global_devices=*/num_ranks, /*communicator_key=*/""},
+           /*num_global_devices=*/num_ranks, /*communicator_key=*/"",
+           /*source_rank=*/-1},
           reduction_op);
     }
 
@@ -389,7 +469,6 @@
 TYPED_TEST(NcclManagerTest, MultipleCallers) {
   const int num_ranks = 4;
   const int num_collectives_per_iteration = 10;
-  const int num_threads = num_ranks * 2;
   const int time_limit_micros = 1 * 1000 * 1000;  // 1 second
 
   int64 start = Env::Default()->NowMicros();
@@ -417,8 +496,6 @@
                  std::mt19937(std::random_device()()));
 
     mutex mu;  // guards case_and_rank.
-    std::unique_ptr<thread::ThreadPool> pool(
-        new thread::ThreadPool(Env::Default(), "test", num_threads));
     const int to_schedule = case_and_rank.size();
     for (int i = 0; i < to_schedule; ++i) {
       auto fn = [&]() {
@@ -443,12 +520,11 @@
             {strings::StrCat("allreduce", test_num),
              /*num_local_devices=*/num_ranks,
              /*num_global_devices=*/num_ranks,
-             /*communicator_key=*/""},
+             /*communicator_key=*/"", /*source_rank=*/-1},
             ncclSum);
       };
-      pool->Schedule(fn);
+      this->work_queue_->Schedule(fn);
     }
-    pool.reset();  // wait for all work to be scheduled.
 
     VLOG(2) << "Verifying results for " << num_collectives_per_iteration
             << " collectives";
@@ -484,7 +560,8 @@
       NcclManager::instance()->AddToAllGather(
           std::move(participant),
           {"allgather", /*num_local_devices=*/num_ranks,
-           /*num_global_devices=*/num_ranks, /*communicator_key=*/""});
+           /*num_global_devices=*/num_ranks, /*communicator_key=*/"",
+           /*source_rank=*/-1});
     }
 
     LOG(INFO) << "Verifying results";
@@ -494,41 +571,27 @@
 
 // Test basic broadcast.
 TYPED_TEST(NcclManagerTest, BasicBroadcast) {
-  const int num_ranks = 4;
-  const int src_rank = 2;
-  for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
-    bool in_place = in_place_idx == 1;
-    std::unique_ptr<typename TestFixture::TestCase> test_case(
-        this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
-                                    TensorShape({5, 6}), /*src_node=*/0,
-                                    src_rank, in_place));
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      auto* device = this->GetDevice(rank);
-      auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
-      auto* stream = device->tensorflow_gpu_device_info()->stream;
-      auto* input = rank == src_rank ? &test_case->ins[rank] : nullptr;
-      auto* output = test_case->outs[rank].NumElements() == 0
-                         ? nullptr
-                         : &test_case->outs[rank];
-      auto participant = absl::make_unique<NcclManager::Participant>(
-          device->executor(), stream, event_mgr, device->gpu_id(), input,
-          output, rank, this->CreateDoneCallback(test_case.get()));
-      if (rank == src_rank) {
-        NcclManager::instance()->AddBroadcastSend(
-            std::move(participant),
-            {"broadcast", /*num_local_devices=*/num_ranks,
-             /*num_global_devices=*/num_ranks,
-             /*communicator_key=*/""});
-      } else {
-        NcclManager::instance()->AddBroadcastRecv(
-            std::move(participant),
-            {"broadcast", /*num_local_devices=*/num_ranks,
-             /*num_global_devices=*/num_ranks,
-             /*communicator_key=*/""});
-      }
-    }
+  this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4,
+                                  /*src_node=*/0, /*src_local_rank=*/2,
+                                  /*in_place=*/false);
+}
 
-    this->VerifyResults(test_case.get());
+// Test in-place broadcast.
+TYPED_TEST(NcclManagerTest, InPlaceBroadcast) {
+  this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4,
+                                  /*src_node=*/0, /*src_local_rank=*/1,
+                                  /*in_place=*/true);
+}
+
+// Test broadcast with increasing ranks.
+TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) {
+  for (int num_ranks = 4; num_ranks <= 8; ++num_ranks) {
+    const int src_rank = static_cast<int>(random::New64() % num_ranks);
+    for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
+      const bool in_place = in_place_idx == 0;
+      this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, num_ranks,
+                                      /*src_node=*/0, src_rank, in_place);
+    }
   }
 }
 
@@ -544,13 +607,20 @@
 // environment.  It works on a single node and reuses GPUs.  It enqueues NCCL
 // kernels on separate stream per rank.
 TYPED_TEST(NcclManagerTest, MultiNode) {
-  this->RunMultiNodeTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4);
+  this->RunMultiNodeAllReduceTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4);
 }
 
 // Tests that specifying `communicator_key` with a single node NCCL collective
 // works well.
 TYPED_TEST(NcclManagerTest, MultiNodeSingle) {
-  this->RunMultiNodeTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4);
+  this->RunMultiNodeAllReduceTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4);
+}
+
+// Multi-node broadcast.
+TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) {
+  this->RunMultiNodeBroadcastTest(/*num_nodes=*/4, /*num_ranks_per_node=*/8,
+                                  /*src_node=*/2, /*src_local_rank=*/3,
+                                  /*in_place=*/true);
 }
 
 // Checks that we return error status if a collective_key is used for different
@@ -574,14 +644,16 @@
                                               {"bad_coll_type",
                                                /*num_local_devices=*/num_ranks,
                                                /*num_global_devices=*/num_ranks,
-                                               /*communicator_key=*/""},
+                                               /*communicator_key=*/"",
+                                               /*source_rank=*/-1},
                                               ncclSum);
     } else {
       NcclManager::instance()->AddBroadcastSend(
-          std::move(participant), {"bad_coll_type",
-                                   /*num_local_devices=*/num_ranks,
-                                   /*num_global_devices=*/num_ranks,
-                                   /*communicator_key=*/""});
+          std::move(participant),
+          {"bad_coll_type",
+           /*num_local_devices=*/num_ranks,
+           /*num_global_devices=*/num_ranks,
+           /*communicator_key=*/"", /*source_rank=*/-1});
     }
   }
 
@@ -609,7 +681,8 @@
         {"bad_coll_type",
          /*num_local_devices=*/num_ranks,
          /*num_global_devices=*/num_ranks,
-         rank == 0 ? "" : NcclManager::instance()->GenerateCommunicatorKey()},
+         rank == 0 ? "" : NcclManager::instance()->GenerateCommunicatorKey(),
+         /*source_rank=*/-1},
         ncclSum);
   }
 
@@ -637,12 +710,95 @@
                                             {"bad_coll_type",
                                              /*num_local_devices=*/num_devices,
                                              /*num_global_devices=*/num_devices,
-                                             /*communicator_key=*/""},
+                                             /*communicator_key=*/"",
+                                             /*source_rank=*/-1},
                                             ncclSum);
   }
 
   this->VerifyError(test_case.get());
-}  // namespace tensorflow
+}
+
+// Checks that we return error status if a broadcast does not have source.
+TYPED_TEST(NcclManagerTest, BroadcastNoSource) {
+  const int num_ranks = 2;
+
+  std::unique_ptr<typename TestFixture::TestCase> test_case(
+      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
+                                  TensorShape({2, 3}), /*src_node=*/-1,
+                                  /*src_rank=*/-1, false));
+  for (int rank = 0; rank < num_ranks; ++rank) {
+    auto* device = this->GetDevice(rank);
+    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* stream = device->tensorflow_gpu_device_info()->stream;
+    auto participant = absl::make_unique<NcclManager::Participant>(
+        device->executor(), stream, event_mgr, device->gpu_id(), nullptr,
+        &test_case->outs[rank], rank,
+        this->CreateDoneCallback(test_case.get()));
+    NcclManager::instance()->AddBroadcastRecv(std::move(participant),
+                                              {"bcast_no_send",
+                                               /*num_local_devices=*/num_ranks,
+                                               /*num_global_devices=*/num_ranks,
+                                               /*communicator_key=*/"",
+                                               /*source_rank=*/-1});
+  }
+
+  this->VerifyError(test_case.get());
+}
+
+// Checks that we return error status if a broadcast has multiple sends.
+TYPED_TEST(NcclManagerTest, BroadcastMultipleSends) {
+  const int num_ranks = 2;
+
+  std::unique_ptr<typename TestFixture::TestCase> test_case(
+      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
+                                  TensorShape({2, 3}), /*src_node=*/-1,
+                                  /*src_rank=*/-1, false));
+  for (int rank = 0; rank < num_ranks; ++rank) {
+    auto* device = this->GetDevice(rank);
+    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* stream = device->tensorflow_gpu_device_info()->stream;
+    auto participant = absl::make_unique<NcclManager::Participant>(
+        device->executor(), stream, event_mgr, device->gpu_id(),
+        &test_case->outs[rank], &test_case->outs[rank], rank,
+        this->CreateDoneCallback(test_case.get()));
+    NcclManager::instance()->AddBroadcastSend(std::move(participant),
+                                              {"bcast_multiple_send",
+                                               /*num_local_devices=*/num_ranks,
+                                               /*num_global_devices=*/num_ranks,
+                                               /*communicator_key=*/"",
+                                               /*source_rank=*/-1});
+  }
+
+  this->VerifyError(test_case.get());
+}
+
+// Checks that we return error status if a broadcast has inconsistent source
+// ranks.
+TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
+  const int num_ranks = 2;
+
+  std::unique_ptr<typename TestFixture::TestCase> test_case(
+      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
+                                  TensorShape({2, 3}), /*src_node=*/-1,
+                                  /*src_rank=*/-1, false));
+  for (int rank = 0; rank < num_ranks; ++rank) {
+    auto* device = this->GetDevice(rank);
+    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* stream = device->tensorflow_gpu_device_info()->stream;
+    auto participant = absl::make_unique<NcclManager::Participant>(
+        device->executor(), stream, event_mgr, device->gpu_id(),
+        &test_case->outs[rank], &test_case->outs[rank], rank,
+        this->CreateDoneCallback(test_case.get()));
+    NcclManager::instance()->AddBroadcastRecv(std::move(participant),
+                                              {"bcast_inconsistent_source",
+                                               /*num_local_devices=*/num_ranks,
+                                               /*num_global_devices=*/num_ranks,
+                                               /*communicator_key=*/"",
+                                               /*source_rank=*/rank});
+  }
+
+  this->VerifyError(test_case.get());
+}
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 8d9759c..3f0f0c2 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1555,6 +1555,7 @@
     .Output("output: T")
     .Attr("T: {bfloat16, half, float, double}")
     .Attr("message: string")
+    .SetIsStateful()
     .SetShapeFn(shape_inference::UnchangedShape);
 
 // --------------------------------------------------------------------------
@@ -3397,7 +3398,7 @@
           return errors::InvalidArgument("`method` must be rank 0: ",
                                          method->shape());
         }
-        const string& method_string = method->scalar<string>()();
+        const string& method_string = method->scalar<tstring>()();
         if (method_string != "farmhash64") {
           return errors::InvalidArgument("Unsupported method: ", method_string);
         }
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 4e33bcd..db19357 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -105,7 +105,7 @@
     .Input("tree_complexity: float")
     .Input("min_node_weight: float")
     .Attr("logits_dimension: int >= 1")
-    .Attr("split_type: {'inequality'} = 'inequality'")
+    .Attr("split_type: {'inequality', 'equality'} = 'inequality'")
     .Output("node_ids: int32")
     .Output("gains: float32")
     .Output("feature_dimensions: int32")
diff --git a/tensorflow/core/ops/compat/BUILD b/tensorflow/core/ops/compat/BUILD
index 566fa89..299076d 100644
--- a/tensorflow/core/ops/compat/BUILD
+++ b/tensorflow/core/ops/compat/BUILD
@@ -34,11 +34,11 @@
     size = "small",
     srcs = ["backwards_compatibility_test.cc"],
     data = [
-        ":ops_history.v0.pbtxt",
-        ":ops_history.v1.pbtxt",
-        ":ops_history.v2.pbtxt",
         "//tensorflow/core:ops/ops.pbtxt",
-    ],
+    ] + glob([
+        "ops_history_v*/*.pbtxt",
+        "ops_history.v*.pbtxt",
+    ]),
     deps = [
         ":op_compatibility_lib",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.cc b/tensorflow/core/ops/compat/op_compatibility_lib.cc
index a44fead..9005e74 100644
--- a/tensorflow/core/ops/compat/op_compatibility_lib.cc
+++ b/tensorflow/core/ops/compat/op_compatibility_lib.cc
@@ -27,17 +27,83 @@
 
 namespace tensorflow {
 
+static string OpsHistoryDirectory(const string& ops_prefix,
+                                  const string& history_version) {
+  return io::JoinPath(ops_prefix,
+                      strings::StrCat("compat/ops_history_", history_version));
+}
+
 static string OpsHistoryFile(const string& ops_prefix,
                              const string& history_version) {
   return io::JoinPath(ops_prefix, strings::StrCat("compat/ops_history.",
                                                   history_version, ".pbtxt"));
 }
 
+static string FileNameFromOpName(const string& op_name) {
+  return strings::StrCat(op_name, ".pbtxt");
+}
+
+static void AddNewOpToHistory(const OpDef& op,
+                              OpCompatibilityLib::OpHistory* out_op_history) {
+  if (out_op_history != nullptr) {
+    out_op_history->emplace_back(FileNameFromOpName(op.name()), OpList());
+    *out_op_history->back().second.add_op() = op;
+  }
+}
+
+static Status ReadOpHistory(Env* env, const string& file,
+                            const string& directory,
+                            OpCompatibilityLib::OpHistory* out) {
+  // Read op history form `directory` if it exists there.
+  std::vector<string> matching_files;
+  Status status = env->GetMatchingPaths(io::JoinPath(directory, "*.pbtxt"),
+                                        &matching_files);
+  if (status.ok() && !matching_files.empty()) {
+    printf("Reading op history from %s/*.pbtxt...\n", directory.c_str());
+    std::sort(matching_files.begin(), matching_files.end());
+    for (const string& full_file : matching_files) {
+      string op_history_str;
+      TF_RETURN_IF_ERROR(ReadFileToString(env, full_file, &op_history_str));
+      OpList in_op_history;
+      protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
+      const string file_tail = FileNameFromOpName(in_op_history.op(0).name());
+      const string expected = io::JoinPath(directory, file_tail);
+      if (full_file != expected) {
+        return errors::Internal("Expected file paths to match but '", full_file,
+                                "' != '", expected, "'");
+      }
+      out->emplace_back(file_tail, in_op_history);
+    }
+  } else {  // Otherwise, fall back to reading op history from `file`.
+    printf("Reading op history from %s...\n", file.c_str());
+    string op_history_str;
+    TF_RETURN_IF_ERROR(ReadFileToString(env, file, &op_history_str));
+    OpList in_op_history;
+    protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
+    // Convert from a linear OpList to OpHistory format with one OpList per
+    // unique op name.
+    int start = 0;
+    while (start < in_op_history.op_size()) {
+      int end = start + 1;
+      while (end < in_op_history.op_size() &&
+             in_op_history.op(start).name() == in_op_history.op(end).name()) {
+        ++end;
+      }
+      AddNewOpToHistory(in_op_history.op(start), out);
+      for (++start; start < end; ++start) {
+        *out->back().second.add_op() = in_op_history.op(start);
+      }
+    }
+  }
+  return Status::OK();
+}
+
 OpCompatibilityLib::OpCompatibilityLib(const string& ops_prefix,
                                        const string& history_version,
                                        const std::set<string>* stable_ops)
     : ops_file_(io::JoinPath(ops_prefix, "ops.pbtxt")),
       op_history_file_(OpsHistoryFile(ops_prefix, history_version)),
+      op_history_directory_(OpsHistoryDirectory(ops_prefix, history_version)),
       stable_ops_(stable_ops) {
   // Get the sorted list of all registered OpDefs.
   printf("Getting all registered ops...\n");
@@ -46,7 +112,7 @@
 
 Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
                                               int* added_ops,
-                                              OpList* out_op_history) {
+                                              OpHistory* out_op_history) {
   *changed_ops = 0;
   *added_ops = 0;
 
@@ -78,104 +144,90 @@
     }
   }
 
-  OpList in_op_history;
-  {  // Read op history.
-    printf("Reading op history from %s...\n", op_history_file_.c_str());
-    string op_history_str;
-    TF_RETURN_IF_ERROR(
-        ReadFileToString(env, op_history_file_, &op_history_str));
-    protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
-  }
+  OpHistory in_op_history;
+  TF_RETURN_IF_ERROR(ReadOpHistory(env, op_history_file_, op_history_directory_,
+                                   &in_op_history));
 
   int cur = 0;
-  int start = 0;
+  int hist = 0;
 
   printf("Verifying updates are compatible...\n");
-  // Note: Op history is in (alphabetical, oldest-first) order.
-  while (cur < op_list_.op_size() && start < in_op_history.op_size()) {
-    const string& op_name = op_list_.op(cur).name();
-    if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
+  // Note: Op history is one OpList per unique op name in alphabetical order.
+  // Within the OplList it has versions in oldest-first order.
+  while (cur < op_list_.op_size() && hist < in_op_history.size()) {
+    const OpDef& cur_op = op_list_.op(cur);
+    const string& cur_op_name = cur_op.name();
+    const OpList& history_op_list = in_op_history[hist].second;
+    const string& history_op_name = history_op_list.op(0).name();
+    if (stable_ops_ != nullptr && stable_ops_->count(cur_op_name) == 0) {
       // Ignore unstable op.
       for (++cur; cur < op_list_.op_size(); ++cur) {
-        if (op_list_.op(cur).name() != op_name) break;
+        if (op_list_.op(cur).name() != cur_op_name) break;
       }
-    } else if (op_name < in_op_history.op(start).name()) {
+    } else if (cur_op_name < history_op_name) {
       // New op: add it.
-      if (out_op_history != nullptr) {
-        *out_op_history->add_op() = op_list_.op(cur);
-      }
+      AddNewOpToHistory(cur_op, out_op_history);
       ++*added_ops;
       ++cur;
-    } else if (op_name > in_op_history.op(start).name()) {
+    } else if (cur_op_name > history_op_name) {
       if (stable_ops_ != nullptr) {
         // Okay to remove ops from the history that have been made unstable.
-        for (++start; start < in_op_history.op_size(); ++start) {
-          if (op_name <= in_op_history.op(start).name()) break;
-        }
+        ++hist;
       } else {
         // Op removed: error.
         return errors::InvalidArgument("Error, removed op: ",
-                                       SummarizeOpDef(in_op_history.op(start)));
+                                       SummarizeOpDef(history_op_list.op(0)));
       }
     } else {
       // Op match.
-
-      // Find all historical version of this op.
-      int end = start + 1;
-      for (; end < in_op_history.op_size(); ++end) {
-        if (in_op_history.op(end).name() != op_name) break;
-      }
-
       if (out_op_history != nullptr) {
         // Copy from in_op_history to *out_op_history.
-        for (int i = start; i < end; ++i) {
-          *out_op_history->add_op() = in_op_history.op(i);
-        }
+        out_op_history->push_back(in_op_history[hist]);
       }
 
+      const int end = history_op_list.op_size();
       // Is the last op in the history the same as the current op?
       // Compare using their serialized representations.
       string history_str, cur_str;
-      in_op_history.op(end - 1).SerializeToString(&history_str);
-      op_list_.op(cur).SerializeToString(&cur_str);
+      history_op_list.op(end - 1).SerializeToString(&history_str);
+      cur_op.SerializeToString(&cur_str);
 
       if (history_str != cur_str) {
         // Op changed, verify the change is compatible.
-        for (int i = start; i < end; ++i) {
-          TF_RETURN_IF_ERROR(
-              OpDefCompatible(in_op_history.op(i), op_list_.op(cur)));
+        for (int i = 0; i < end; ++i) {
+          TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
         }
 
         // Verify default value of attrs has not been added/removed/modified
         // as compared to only the last historical version.
-        TF_RETURN_IF_ERROR(OpDefAttrDefaultsUnchanged(in_op_history.op(end - 1),
-                                                      op_list_.op(cur)));
+        TF_RETURN_IF_ERROR(
+            OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));
 
-        // Check that attrs missing from in_op_history.op(start) don't
-        // change their defaults.
-        if (start < end - 1) {
+        // Check that attrs missing from history_op_list.op(0) don't change
+        // their defaults.
+        if (end > 1) {
           TF_RETURN_IF_ERROR(OpDefAddedDefaultsUnchanged(
-              in_op_history.op(start), in_op_history.op(end - 1),
-              op_list_.op(cur)));
+              history_op_list.op(0), history_op_list.op(end - 1), cur_op));
         }
 
         // Compatible! Add changed op to the end of the history.
         if (out_op_history != nullptr) {
-          *out_op_history->add_op() = op_list_.op(cur);
+          *out_op_history->back().second.add_op() = cur_op;
         }
         ++*changed_ops;
       }
 
       // Advance past this op.
-      start = end;
+      ++hist;
       ++cur;
     }
   }
 
   // Error if missing ops.
-  if (stable_ops_ == nullptr && start < in_op_history.op_size()) {
-    return errors::InvalidArgument("Error, removed op: ",
-                                   SummarizeOpDef(in_op_history.op(start)));
+  if (stable_ops_ == nullptr && hist < in_op_history.size()) {
+    return errors::InvalidArgument(
+        "Error, removed op: ",
+        SummarizeOpDef(in_op_history[hist].second.op(0)));
   }
 
   // Add remaining new ops.
@@ -184,9 +236,7 @@
     if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
       // Ignore unstable op.
     } else {
-      if (out_op_history) {
-        *out_op_history->add_op() = op_list_.op(cur);
-      }
+      AddNewOpToHistory(op_list_.op(cur), out_op_history);
       ++*added_ops;
     }
   }
diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.h b/tensorflow/core/ops/compat/op_compatibility_lib.h
index 054f903..2f26fd6 100644
--- a/tensorflow/core/ops/compat/op_compatibility_lib.h
+++ b/tensorflow/core/ops/compat/op_compatibility_lib.h
@@ -45,6 +45,11 @@
   // order.
   const string& op_history_file() const { return op_history_file_; }
 
+  // Name of the directory that contains all versions of *stable* ops,
+  // without docs.  Op history is one file per op, in oldest-first
+  // order within the file.
+  const string& op_history_directory() const { return op_history_directory_; }
+
   // Should match the contents of ops_file().  Run before calling
   // ValidateCompatible().
   string OpsString() const { return op_list_.DebugString(); }
@@ -53,17 +58,21 @@
   // just stable ops.
   int num_all_ops() const { return op_list_.op_size(); }
 
+  // <file name, file contents> pairs representing op history.
+  typedef std::vector<std::pair<string, OpList>> OpHistory;
+
   // Make sure the current version of the *stable* ops are compatible
   // with the historical versions, and if out_op_history != nullptr,
   // generate a new history adding all changed ops.  Sets
   // *changed_ops/*added_ops to the number of changed/added ops
   // (ignoring doc changes).
   Status ValidateCompatible(Env* env, int* changed_ops, int* added_ops,
-                            OpList* out_op_history);
+                            OpHistory* out_op_history);
 
  private:
   const string ops_file_;
   const string op_history_file_;
+  const string op_history_directory_;
   const std::set<string>* stable_ops_;
   OpList op_list_;
 };
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
deleted file mode 100644
index d163bf5..0000000
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ /dev/null
@@ -1,94911 +0,0 @@
-op {
-  name: "Abort"
-  attr {
-    name: "error_msg"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "exit_without_error"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "Abs"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Abs"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Abs"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "AccumulateNV2"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AccumulateNV2"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AccumulateNV2"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorNumAccumulated"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "num_accumulated"
-    type: DT_INT32
-  }
-}
-op {
-  name: "AccumulatorSetGlobalStep"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "new_global_step"
-    type: DT_INT64
-  }
-}
-op {
-  name: "AccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "average"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "average"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "average"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "AccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "average"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Acos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Acos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Acos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Acosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Acosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Acosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Add"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "Add"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "Add"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "AddManySparseToTensorsMap"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_handles"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "AddN"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddN"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_VARIANT
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddN"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_VARIANT
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddN"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-        type: DT_VARIANT
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddN"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "sum"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_VARIANT
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddSparseToTensorsMap"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_handle"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "AddV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AddV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_aggregate: true
-  is_commutative: true
-}
-op {
-  name: "AdjustContrast"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "contrast_factor"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_value"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_value"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 2
-  }
-}
-op {
-  name: "AdjustContrastv2"
-  input_arg {
-    name: "images"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "contrast_factor"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "AdjustContrastv2"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "contrast_factor"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "AdjustHue"
-  input_arg {
-    name: "images"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "delta"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "AdjustHue"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "AdjustSaturation"
-  input_arg {
-    name: "images"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "AdjustSaturation"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "All"
-  input_arg {
-    name: "input"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type: DT_BOOL
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "AllCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "AllCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "AllToAll"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "group_assignment"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "concat_dimension"
-    type: "int"
-  }
-  attr {
-    name: "split_dimension"
-    type: "int"
-  }
-  attr {
-    name: "split_count"
-    type: "int"
-  }
-}
-op {
-  name: "AllToAll"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "group_assignment"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BOOL
-      }
-    }
-  }
-  attr {
-    name: "concat_dimension"
-    type: "int"
-  }
-  attr {
-    name: "split_dimension"
-    type: "int"
-  }
-  attr {
-    name: "split_count"
-    type: "int"
-  }
-}
-op {
-  name: "Angle"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AnonymousIterator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "AnonymousIteratorV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "deleter"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "AnonymousMultiDeviceIterator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "deleter"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "devices"
-    type: "list(string)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "Any"
-  input_arg {
-    name: "input"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type: DT_BOOL
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ApplyAdaMax"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "update_slots"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdam"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdam"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdam"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdam"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAdam"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAddSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAddSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyAddSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyPowerSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyPowerSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyPowerSign"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ApproximateEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "tolerance"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "ApproximateEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "tolerance"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "ApproximateEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "tolerance"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "ApproximateEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "tolerance"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "ArgMax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMin"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMin"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMin"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMin"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ArgMin"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dimension"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "AsString"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_BOOL
-        type: DT_INT8
-      }
-    }
-  }
-  attr {
-    name: "precision"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "scientific"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "shortest"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "width"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "fill"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "AsString"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_BOOL
-      }
-    }
-  }
-  attr {
-    name: "precision"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "scientific"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "shortest"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "width"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "fill"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "AsString"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_BOOL
-      }
-    }
-  }
-  attr {
-    name: "precision"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "scientific"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "shortest"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "width"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "fill"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "Asin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Asin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Asin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Asinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Asinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Asinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Assert"
-  input_arg {
-    name: "condition"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "summarize"
-    type: "int"
-    default_value {
-      i: 3
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "AssertNextDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "transformations"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Assign"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "validate_shape"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "AssignAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignAddVariableOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "AssignSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AssignSubVariableOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "AssignVariableOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "Atan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Atan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Atan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Atan2"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Atan2"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Atan2"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Atanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Atanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Atanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "AudioSpectrogram"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "spectrogram"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "window_size"
-    type: "int"
-  }
-  attr {
-    name: "stride"
-    type: "int"
-  }
-  attr {
-    name: "magnitude_squared"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "AudioSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "sample_rate"
-    type: "float"
-  }
-  attr {
-    name: "max_outputs"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "AudioSummaryV2"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sample_rate"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "max_outputs"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "AutoShardDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "AvgPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3DGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3DGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3DGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPool3DGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPoolGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPoolGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "AvgPoolGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "AvgPoolGrad"
-  input_arg {
-    name: "orig_input_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Barrier"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "BarrierClose"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "cancel_pending_enqueues"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BarrierIncompleteSize"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-}
-op {
-  name: "BarrierInsertMany"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "keys"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "component_index"
-    type: "int"
-  }
-}
-op {
-  name: "BarrierReadySize"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-}
-op {
-  name: "BarrierTakeMany"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_elements"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "keys"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "allow_small_batch"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "wait_for_incomplete"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "Batch"
-  input_arg {
-    name: "in_tensors"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "batched_tensors"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "batch_index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "id"
-    type: DT_INT64
-  }
-  attr {
-    name: "num_batch_threads"
-    type: "int"
-  }
-  attr {
-    name: "max_batch_size"
-    type: "int"
-  }
-  attr {
-    name: "batch_timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "allowed_batch_sizes"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "grad_timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "batching_queue"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Batch"
-  input_arg {
-    name: "in_tensors"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "batched_tensors"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "batch_index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "id"
-    type: DT_INT64
-  }
-  attr {
-    name: "num_batch_threads"
-    type: "int"
-  }
-  attr {
-    name: "max_batch_size"
-    type: "int"
-  }
-  attr {
-    name: "max_enqueued_batches"
-    type: "int"
-    default_value {
-      i: 10
-    }
-  }
-  attr {
-    name: "batch_timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "allowed_batch_sizes"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "grad_timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "batching_queue"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BatchCholesky"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchCholeskyGrad"
-  input_arg {
-    name: "l"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "BatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BatchDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BatchDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "parallel_copy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BatchFFT"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchFFT2D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchFFT3D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchFunction"
-  input_arg {
-    name: "in_tensors"
-    type_list_attr: "Tin"
-  }
-  input_arg {
-    name: "captured_tensors"
-    type_list_attr: "Tcaptured"
-  }
-  output_arg {
-    name: "out_tensors"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "num_batch_threads"
-    type: "int"
-  }
-  attr {
-    name: "max_batch_size"
-    type: "int"
-  }
-  attr {
-    name: "batch_timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "max_enqueued_batches"
-    type: "int"
-    default_value {
-      i: 10
-    }
-  }
-  attr {
-    name: "allowed_batch_sizes"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "batching_queue"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tcaptured"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BatchIFFT"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchIFFT2D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchIFFT3D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-  deprecation {
-    version: 15
-  }
-}
-op {
-  name: "BatchMatMul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "adj_x"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adj_y"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BatchMatMul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "adj_x"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adj_y"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BatchMatMul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "adj_x"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adj_y"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BatchMatMul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "adj_x"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adj_y"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BatchMatMulV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "adj_x"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adj_y"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "BatchMatrixBandPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_lower"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_upper"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "band"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 14
-  }
-}
-op {
-  name: "BatchMatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchMatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchMatrixDiag"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 14
-  }
-}
-op {
-  name: "BatchMatrixDiagPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 14
-  }
-}
-op {
-  name: "BatchMatrixInverse"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchMatrixSetDiag"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 14
-  }
-}
-op {
-  name: "BatchMatrixSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchMatrixSolveLs"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_regularizer"
-    type: DT_DOUBLE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchMatrixTriangularSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "lower"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalizationGrad"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dx"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dm"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dv"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "db"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dg"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalizationGrad"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dx"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dm"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dv"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "db"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dg"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalizationGrad"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dx"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dm"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dv"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "db"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dg"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchNormWithGlobalNormalizationGrad"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "m"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dx"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dm"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dv"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "db"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dg"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-  deprecation {
-    version: 9
-  }
-}
-op {
-  name: "BatchSelfAdjointEig"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 11
-  }
-}
-op {
-  name: "BatchSelfAdjointEigV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_v"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchSvd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "s"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "u"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_uv"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "full_matrices"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 13
-  }
-}
-op {
-  name: "BatchToSpace"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "crops"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "BatchToSpaceND"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "block_shape"
-    type_attr: "Tblock_shape"
-  }
-  input_arg {
-    name: "crops"
-    type_attr: "Tcrops"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tblock_shape"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tcrops"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "BesselI0e"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "BesselI1e"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Betainc"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "BiasAdd"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAdd"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAdd"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAdd"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddGrad"
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddGrad"
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddGrad"
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddGrad"
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddV1"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddV1"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddV1"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "BiasAddV1"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Bincount"
-  input_arg {
-    name: "arr"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "weights"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "bins"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Bitcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Bitcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Bitcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Bitcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "Bitcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "BitwiseAnd"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BitwiseAnd"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BitwiseOr"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BitwiseOr"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BitwiseXor"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BitwiseXor"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "BlockLSTM"
-  input_arg {
-    name: "seq_len_max"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wcf"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wco"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "i"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "cs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "f"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "o"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "ci"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "co"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "h"
-    type_attr: "T"
-  }
-  attr {
-    name: "forget_bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "cell_clip"
-    type: "float"
-    default_value {
-      f: 3
-    }
-  }
-  attr {
-    name: "use_peephole"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "BlockLSTMGrad"
-  input_arg {
-    name: "seq_len_max"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wcf"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wco"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "i"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "f"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "o"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "co"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "x_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "cs_prev_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "h_prev_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "w_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wci_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wcf_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wco_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "b_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "use_peephole"
-    type: "bool"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "BoostedTreesAggregateStats"
-  input_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "hessians"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "feature"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "stats_summary"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "max_splits"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BoostedTreesBucketize"
-  input_arg {
-    name: "float_values"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "bucket_boundaries"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "buckets"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "BoostedTreesCalculateBestFeatureSplit"
-  input_arg {
-    name: "node_id_range"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "stats_summary"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l1"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l2"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "tree_complexity"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_node_weight"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "gains"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "feature_dimensions"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "thresholds"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "left_node_contribs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "right_node_contribs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "split_with_default_directions"
-    type: DT_STRING
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "split_type"
-    type: "string"
-    default_value {
-      s: "inequality"
-    }
-    allowed_values {
-      list {
-        s: "inequality"
-      }
-    }
-  }
-}
-op {
-  name: "BoostedTreesCalculateBestGainsPerFeature"
-  input_arg {
-    name: "node_id_range"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "stats_summary_list"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "l1"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l2"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "tree_complexity"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_node_weight"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "node_ids_list"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "gains_list"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "thresholds_list"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "left_node_contribs_list"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "right_node_contribs_list"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  attr {
-    name: "max_splits"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BoostedTreesCenterBias"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mean_gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "mean_hessians"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l1"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l2"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "continue_centering"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesCreateEnsemble"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "stamp_token"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tree_ensemble_serialized"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesCreateQuantileStreamResource"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "epsilon"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "num_streams"
-    type: DT_INT64
-  }
-  attr {
-    name: "max_elements"
-    type: "int"
-    default_value {
-      i: 1099511627776
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesDeserializeEnsemble"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "stamp_token"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tree_ensemble_serialized"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesEnsembleResourceHandleOp"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesExampleDebugOutputs"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "bucketized_features"
-    type: DT_INT32
-    number_attr: "num_bucketized_features"
-  }
-  output_arg {
-    name: "examples_debug_outputs_serialized"
-    type: DT_STRING
-  }
-  attr {
-    name: "num_bucketized_features"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesFlushQuantileSummaries"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "summaries"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesGetEnsembleStates"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "stamp_token"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "num_trees"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "num_finalized_trees"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "num_attempted_layers"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "last_layer_nodes_range"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesMakeQuantileSummaries"
-  input_arg {
-    name: "float_values"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "example_weights"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "epsilon"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "summaries"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "BoostedTreesMakeStatsSummary"
-  input_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "hessians"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "bucketized_features_list"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  output_arg {
-    name: "stats_summary"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "max_splits"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BoostedTreesPredict"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "bucketized_features"
-    type: DT_INT32
-    number_attr: "num_bucketized_features"
-  }
-  output_arg {
-    name: "logits"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bucketized_features"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesQuantileStreamResourceAddSummaries"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "summaries"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesQuantileStreamResourceDeserialize"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "bucket_boundaries"
-    type: DT_FLOAT
-    number_attr: "num_streams"
-  }
-  attr {
-    name: "num_streams"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesQuantileStreamResourceFlush"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "num_buckets"
-    type: DT_INT64
-  }
-  attr {
-    name: "generate_quantiles"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "bucket_boundaries"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesQuantileStreamResourceHandleOp"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesSerializeEnsemble"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "stamp_token"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "tree_ensemble_serialized"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesSparseAggregateStats"
-  input_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "hessians"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "feature_indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "feature_values"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "feature_shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "stats_summary_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "stats_summary_values"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "stats_summary_shape"
-    type: DT_INT32
-  }
-  attr {
-    name: "max_splits"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "BoostedTreesSparseCalculateBestFeatureSplit"
-  input_arg {
-    name: "node_id_range"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "stats_summary_indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "stats_summary_values"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "stats_summary_shape"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "l1"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "l2"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "tree_complexity"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_node_weight"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "gains"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "feature_dimensions"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "thresholds"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "left_node_contribs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "right_node_contribs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "split_with_default_directions"
-    type: DT_STRING
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "split_type"
-    type: "string"
-    default_value {
-      s: "inequality"
-    }
-    allowed_values {
-      list {
-        s: "inequality"
-      }
-    }
-  }
-}
-op {
-  name: "BoostedTreesTrainingPredict"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "cached_tree_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "cached_node_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "bucketized_features"
-    type: DT_INT32
-    number_attr: "num_bucketized_features"
-  }
-  output_arg {
-    name: "partial_logits"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "tree_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "node_ids"
-    type: DT_INT32
-  }
-  attr {
-    name: "num_bucketized_features"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "BoostedTreesUpdateEnsemble"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "feature_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "node_ids"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "gains"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "thresholds"
-    type: DT_INT32
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "left_node_contribs"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "right_node_contribs"
-    type: DT_FLOAT
-    number_attr: "num_features"
-  }
-  input_arg {
-    name: "max_depth"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "learning_rate"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "pruning_mode"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "BroadcastArgs"
-  input_arg {
-    name: "s0"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "s1"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r0"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "BroadcastGradientArgs"
-  input_arg {
-    name: "s0"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "s1"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r0"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r1"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "BroadcastTo"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Bucketize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "boundaries"
-    type: "list(float)"
-  }
-}
-op {
-  name: "BytesProducedStatsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "CSVDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "header"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "field_delim"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "use_quote_delim"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "na_value"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "select_cols"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "output_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "CTCBeamSearchDecoder"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sequence_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "decoded_indices"
-    type: DT_INT64
-    number_attr: "top_paths"
-  }
-  output_arg {
-    name: "decoded_values"
-    type: DT_INT64
-    number_attr: "top_paths"
-  }
-  output_arg {
-    name: "decoded_shape"
-    type: DT_INT64
-    number_attr: "top_paths"
-  }
-  output_arg {
-    name: "log_probability"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "beam_width"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "top_paths"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "merge_repeated"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "CTCGreedyDecoder"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sequence_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "decoded_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "decoded_values"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "decoded_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "log_probability"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "merge_repeated"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "CTCLoss"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "labels_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "labels_values"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "sequence_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "loss"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "preprocess_collapse_repeated"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "ctc_merge_repeated"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "CTCLoss"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "labels_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "labels_values"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "sequence_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "loss"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "preprocess_collapse_repeated"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "ctc_merge_repeated"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "ignore_longer_outputs_than_inputs"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "CacheDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "CacheDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Case"
-  input_arg {
-    name: "branch_index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "branches"
-    type: "list(func)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Cast"
-  input_arg {
-    name: "x"
-    type_attr: "SrcT"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "DstT"
-  }
-  attr {
-    name: "SrcT"
-    type: "type"
-  }
-  attr {
-    name: "DstT"
-    type: "type"
-  }
-}
-op {
-  name: "Cast"
-  input_arg {
-    name: "x"
-    type_attr: "SrcT"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "DstT"
-  }
-  attr {
-    name: "SrcT"
-    type: "type"
-  }
-  attr {
-    name: "DstT"
-    type: "type"
-  }
-  attr {
-    name: "Truncate"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "Ceil"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Ceil"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Ceil"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "CheckNumerics"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "message"
-    type: "string"
-  }
-}
-op {
-  name: "CheckNumerics"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "message"
-    type: "string"
-  }
-}
-op {
-  name: "CheckNumerics"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "message"
-    type: "string"
-  }
-}
-op {
-  name: "Cholesky"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "Cholesky"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cholesky"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "CholeskyGrad"
-  input_arg {
-    name: "l"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "CholeskyGrad"
-  input_arg {
-    name: "l"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "ChooseFastestBranchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "ratio_numerator"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "ratio_denominator"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "num_elements_per_branch"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "branches"
-    type: "list(func)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "other_arguments_lengths"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ChooseFastestDataset"
-  input_arg {
-    name: "input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "num_experiments"
-    type: "int"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ClipByValue"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "clip_value_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "clip_value_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "CloseSummaryWriter"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "CollectiveBcastRecv"
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "group_size"
-    type: "int"
-  }
-  attr {
-    name: "group_key"
-    type: "int"
-  }
-  attr {
-    name: "instance_key"
-    type: "int"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "CollectiveBcastSend"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "group_size"
-    type: "int"
-  }
-  attr {
-    name: "group_key"
-    type: "int"
-  }
-  attr {
-    name: "instance_key"
-    type: "int"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "CollectiveGather"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "group_size"
-    type: "int"
-  }
-  attr {
-    name: "group_key"
-    type: "int"
-  }
-  attr {
-    name: "instance_key"
-    type: "int"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "CollectivePermute"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "source_target_pairs"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "CollectiveReduce"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "group_size"
-    type: "int"
-  }
-  attr {
-    name: "group_key"
-    type: "int"
-  }
-  attr {
-    name: "instance_key"
-    type: "int"
-  }
-  attr {
-    name: "merge_op"
-    type: "string"
-    allowed_values {
-      list {
-        s: "Min"
-        s: "Max"
-        s: "Mul"
-        s: "Add"
-      }
-    }
-  }
-  attr {
-    name: "final_op"
-    type: "string"
-    allowed_values {
-      list {
-        s: "Id"
-        s: "Div"
-      }
-    }
-  }
-  attr {
-    name: "subdiv_offsets"
-    type: "list(int)"
-  }
-  is_stateful: true
-}
-op {
-  name: "CollectiveReduce"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "group_size"
-    type: "int"
-  }
-  attr {
-    name: "group_key"
-    type: "int"
-  }
-  attr {
-    name: "instance_key"
-    type: "int"
-  }
-  attr {
-    name: "merge_op"
-    type: "string"
-    allowed_values {
-      list {
-        s: "Min"
-        s: "Max"
-        s: "Mul"
-        s: "Add"
-      }
-    }
-  }
-  attr {
-    name: "final_op"
-    type: "string"
-    allowed_values {
-      list {
-        s: "Id"
-        s: "Div"
-      }
-    }
-  }
-  attr {
-    name: "subdiv_offsets"
-    type: "list(int)"
-  }
-  attr {
-    name: "wait_for"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CombinedNonMaxSuppression"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size_per_class"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "max_total_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_boxes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_scores"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_classes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "valid_detections"
-    type: DT_INT32
-  }
-  attr {
-    name: "pad_per_class"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "CombinedNonMaxSuppression"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size_per_class"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "max_total_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_boxes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_scores"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "nmsed_classes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "valid_detections"
-    type: DT_INT32
-  }
-  attr {
-    name: "pad_per_class"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clip_boxes"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "CompareAndBitpack"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "threshold"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_UINT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Complex"
-  input_arg {
-    name: "real"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "imag"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ComplexAbs"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "ComputeAccidentalHits"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "ids"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "weights"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "Concat"
-  input_arg {
-    name: "concat_dim"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "ConcatOffset"
-  input_arg {
-    name: "concat_dim"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "shape"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  output_arg {
-    name: "offset"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-}
-op {
-  name: "ConcatV2"
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ConcatenateDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "another_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ConcatenateDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "another_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reduction_type"
-    type: "string"
-    default_value {
-      s: "MEAN"
-    }
-    allowed_values {
-      list {
-        s: "MEAN"
-        s: "SUM"
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ConfigureDistributedTPU"
-  output_arg {
-    name: "topology"
-    type: DT_STRING
-  }
-  attr {
-    name: "embedding_config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "tpu_embedding_config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "is_global_init"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Conj"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Conj"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_VARIANT
-      }
-    }
-  }
-}
-op {
-  name: "ConjugateTranspose"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "perm"
-    type_attr: "Tperm"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tperm"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Const"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "value"
-    type: "tensor"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "ConsumeMutexLock"
-  input_arg {
-    name: "mutex_lock"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "ControlTrigger"
-}
-op {
-  name: "Conv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-        s: "EXPLICIT"
-      }
-    }
-  }
-  attr {
-    name: "explicit_paddings"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-        s: "EXPLICIT"
-      }
-    }
-  }
-  attr {
-    name: "explicit_paddings"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv2DBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "use_cudnn_on_gpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-        s: "EXPLICIT"
-      }
-    }
-  }
-  attr {
-    name: "explicit_paddings"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropFilterV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropFilterV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropFilterV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  deprecation {
-    version: 10
-  }
-}
-op {
-  name: "Conv3DBackpropInputV2"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropInputV2"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropInputV2"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Conv3DBackpropInputV2"
-  input_arg {
-    name: "input_sizes"
-    type_attr: "Tshape"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "Tshape"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Copy"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "Copy"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_ops_spec"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "CopyHost"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "CopyHost"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_ops_spec"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "Cos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cos"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Cosh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "CountUpTo"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "limit"
-    type: "int"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "CreateSummaryDbWriter"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "db_uri"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "experiment_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "run_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "user_name"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "CreateSummaryFileWriter"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "logdir"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "max_queue"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flush_millis"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filename_suffix"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "CropAndResize"
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "crop_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "crops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-      }
-    }
-  }
-  attr {
-    name: "extrapolation_value"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-}
-op {
-  name: "CropAndResize"
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "crop_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "crops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-      }
-    }
-  }
-  attr {
-    name: "extrapolation_value"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-}
-op {
-  name: "CropAndResize"
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "crop_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "crops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-        s: "nearest"
-      }
-    }
-  }
-  attr {
-    name: "extrapolation_value"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-}
-op {
-  name: "CropAndResizeGradBoxes"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-      }
-    }
-  }
-}
-op {
-  name: "CropAndResizeGradBoxes"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-      }
-    }
-  }
-}
-op {
-  name: "CropAndResizeGradImage"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "image_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-      }
-    }
-  }
-}
-op {
-  name: "CropAndResizeGradImage"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "box_ind"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "image_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "method"
-    type: "string"
-    default_value {
-      s: "bilinear"
-    }
-    allowed_values {
-      list {
-        s: "bilinear"
-        s: "nearest"
-      }
-    }
-  }
-}
-op {
-  name: "Cross"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Cross"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Cross"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Cross"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "CrossReplicaSum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "group_assignment"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "CrossReplicaSum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "group_assignment"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_INT32
-        type: DT_UINT32
-      }
-    }
-  }
-}
-op {
-  name: "CudnnRNN"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNBackprop"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_h_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_c_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "params_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNBackpropV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  output_arg {
-    name: "input_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_h_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_c_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "params_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNBackpropV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  output_arg {
-    name: "input_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_h_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_c_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "params_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNBackpropV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  output_arg {
-    name: "input_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_h_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_c_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "params_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "time_major"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNBackpropV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_h_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_c_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  output_arg {
-    name: "input_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_h_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "input_c_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "params_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_proj"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "time_major"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNCanonicalToParams"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "weights"
-    type_attr: "T"
-    number_attr: "num_params"
-  }
-  input_arg {
-    name: "biases"
-    type_attr: "T"
-    number_attr: "num_params"
-  }
-  output_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "num_params"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNCanonicalToParamsV2"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "weights"
-    type_attr: "T"
-    number_attr: "num_params_weights"
-  }
-  input_arg {
-    name: "biases"
-    type_attr: "T"
-    number_attr: "num_params_biases"
-  }
-  output_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "num_params_weights"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_params_biases"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_proj"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNParamsSize"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "params_size"
-    type_attr: "S"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNParamsSize"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "params_size"
-    type_attr: "S"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_proj"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNParamsToCanonical"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "weights"
-    type_attr: "T"
-    number_attr: "num_params"
-  }
-  output_arg {
-    name: "biases"
-    type_attr: "T"
-    number_attr: "num_params"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "num_params"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNParamsToCanonicalV2"
-  input_arg {
-    name: "num_layers"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_units"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "weights"
-    type_attr: "T"
-    number_attr: "num_params_weights"
-  }
-  output_arg {
-    name: "biases"
-    type_attr: "T"
-    number_attr: "num_params_biases"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "num_params_weights"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_params_biases"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_proj"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "CudnnRNNV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "time_major"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "CudnnRNNV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_h"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "params"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sequence_lengths"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "host_reserved"
-    type: DT_INT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "rnn_mode"
-    type: "string"
-    default_value {
-      s: "lstm"
-    }
-    allowed_values {
-      list {
-        s: "rnn_relu"
-        s: "rnn_tanh"
-        s: "lstm"
-        s: "gru"
-      }
-    }
-  }
-  attr {
-    name: "input_mode"
-    type: "string"
-    default_value {
-      s: "linear_input"
-    }
-    allowed_values {
-      list {
-        s: "linear_input"
-        s: "skip_input"
-        s: "auto_select"
-      }
-    }
-  }
-  attr {
-    name: "direction"
-    type: "string"
-    default_value {
-      s: "unidirectional"
-    }
-    allowed_values {
-      list {
-        s: "unidirectional"
-        s: "bidirectional"
-      }
-    }
-  }
-  attr {
-    name: "dropout"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_proj"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "time_major"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Cumprod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumprod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumprod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumprod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumsum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumsum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumsum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Cumsum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "CumulativeLogsumexp"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  attr {
-    name: "exclusive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "DataFormatDimMap"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "src_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-  }
-  attr {
-    name: "dst_format"
-    type: "string"
-    default_value {
-      s: "NCHW"
-    }
-  }
-}
-op {
-  name: "DataFormatVecPermute"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "src_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-  }
-  attr {
-    name: "dst_format"
-    type: "string"
-    default_value {
-      s: "NCHW"
-    }
-  }
-}
-op {
-  name: "DatasetCardinality"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "cardinality"
-    type: DT_INT64
-  }
-}
-op {
-  name: "DatasetFromGraph"
-  input_arg {
-    name: "graph_def"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-}
-op {
-  name: "DatasetToGraph"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "graph"
-    type: DT_STRING
-  }
-}
-op {
-  name: "DatasetToSingleElement"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "DatasetToSingleElement"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "DatasetToTFRecord"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "DebugGradientIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugGradientRefIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "device_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNanCount"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNanCount"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNanCount"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "device_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNumericSummary"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_DOUBLE
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNumericSummary"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_DOUBLE
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "lower_bound"
-    type: "float"
-    default_value {
-      f: -inf
-    }
-  }
-  attr {
-    name: "upper_bound"
-    type: "float"
-    default_value {
-      f: inf
-    }
-  }
-  attr {
-    name: "mute_if_healthy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNumericSummary"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_DOUBLE
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "lower_bound"
-    type: "float"
-    default_value {
-      f: -inf
-    }
-  }
-  attr {
-    name: "upper_bound"
-    type: "float"
-    default_value {
-      f: inf
-    }
-  }
-  attr {
-    name: "mute_if_healthy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DebugNumericSummary"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_DOUBLE
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "device_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "tensor_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "debug_urls"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "lower_bound"
-    type: "float"
-    default_value {
-      f: -inf
-    }
-  }
-  attr {
-    name: "upper_bound"
-    type: "float"
-    default_value {
-      f: inf
-    }
-  }
-  attr {
-    name: "mute_if_healthy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "gated_grpc"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "DecodeAndCropJpeg"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "crop_window"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "image"
-    type: DT_UINT8
-  }
-  attr {
-    name: "channels"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ratio"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "fancy_upscaling"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "try_recover_truncated"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "acceptable_fraction"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "dct_method"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "DecodeBase64"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-}
-op {
-  name: "DecodeBmp"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "image"
-    type: DT_UINT8
-  }
-  attr {
-    name: "channels"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "DecodeCSV"
-  input_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "OUT_TYPE"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "OUT_TYPE"
-  }
-  attr {
-    name: "OUT_TYPE"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "field_delim"
-    type: "string"
-    default_value {
-      s: ","
-    }
-  }
-}
-op {
-  name: "DecodeCSV"
-  input_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "OUT_TYPE"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "OUT_TYPE"
-  }
-  attr {
-    name: "OUT_TYPE"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "field_delim"
-    type: "string"
-    default_value {
-      s: ","
-    }
-  }
-  attr {
-    name: "use_quote_delim"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodeCSV"
-  input_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "OUT_TYPE"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "OUT_TYPE"
-  }
-  attr {
-    name: "OUT_TYPE"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "field_delim"
-    type: "string"
-    default_value {
-      s: ","
-    }
-  }
-  attr {
-    name: "use_quote_delim"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "na_value"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "DecodeCSV"
-  input_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "OUT_TYPE"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "OUT_TYPE"
-  }
-  attr {
-    name: "OUT_TYPE"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "field_delim"
-    type: "string"
-    default_value {
-      s: ","
-    }
-  }
-  attr {
-    name: "use_quote_delim"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "na_value"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "DecodeCSV"
-  input_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "OUT_TYPE"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "OUT_TYPE"
-  }
-  attr {
-    name: "OUT_TYPE"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "field_delim"
-    type: "string"
-    default_value {
-      s: ","
-    }
-  }
-  attr {
-    name: "use_quote_delim"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "na_value"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "select_cols"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "DecodeCompressed"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "compression_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "DecodeGif"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "image"
-    type: DT_UINT8
-  }
-}
-op {
-  name: "DecodeJSONExample"
-  input_arg {
-    name: "json_examples"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "binary_examples"
-    type: DT_STRING
-  }
-}
-op {
-  name: "DecodeJpeg"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "image"
-    type: DT_UINT8
-  }
-  attr {
-    name: "channels"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ratio"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "fancy_upscaling"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "try_recover_truncated"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "acceptable_fraction"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "dct_method"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "DecodePaddedRaw"
-  input_arg {
-    name: "input_bytes"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "fixed_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT16
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "little_endian"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodePng"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "image"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "channels"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_UINT8
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-}
-op {
-  name: "DecodeProtoV2"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "sizes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "message_type"
-    type: "string"
-  }
-  attr {
-    name: "field_names"
-    type: "list(string)"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "descriptor_source"
-    type: "string"
-    default_value {
-      s: "local://"
-    }
-  }
-  attr {
-    name: "message_format"
-    type: "string"
-    default_value {
-      s: "binary"
-    }
-  }
-  attr {
-    name: "sanitize"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "DecodeRaw"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "little_endian"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodeRaw"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT16
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "little_endian"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodeRaw"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT16
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "little_endian"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodeRaw"
-  input_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT16
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_BOOL
-      }
-    }
-  }
-  attr {
-    name: "little_endian"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "DecodeWav"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "audio"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sample_rate"
-    type: DT_INT32
-  }
-  attr {
-    name: "desired_channels"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "desired_samples"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "DeepCopy"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "DeleteIterator"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "deleter"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "DeleteMultiDeviceIterator"
-  input_arg {
-    name: "multi_device_iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "iterators"
-    type: DT_RESOURCE
-    number_attr: "N"
-  }
-  input_arg {
-    name: "deleter"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "DeleteSessionTensor"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-}
-op {
-  name: "DeleteSessionTensor"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "DenseToDenseSetOperation"
-  input_arg {
-    name: "set1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set2"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "result_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "set_operation"
-    type: "string"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "DenseToSparseBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "row_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "DenseToSparseSetOperation"
-  input_arg {
-    name: "set1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set2_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "set2_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set2_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "result_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "result_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "set_operation"
-    type: "string"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "DepthToSpace"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-}
-op {
-  name: "DepthToSpace"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNative"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNative"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNative"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "DepthwiseConv2dNativeBackpropInput"
-  input_arg {
-    name: "input_sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "Dequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-      }
-    }
-  }
-}
-op {
-  name: "Dequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "Dequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "DeserializeIterator"
-  input_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "serialized"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "DeserializeManySparse"
-  input_arg {
-    name: "serialized_sparse"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "DeserializeSparse"
-  input_arg {
-    name: "serialized_sparse"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "DeserializeSparse"
-  input_arg {
-    name: "serialized_sparse"
-    type_attr: "Tserialized"
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "Tserialized"
-    type: "type"
-    default_value {
-      type: DT_STRING
-    }
-    allowed_values {
-      list {
-        type: DT_STRING
-        type: DT_VARIANT
-      }
-    }
-  }
-}
-op {
-  name: "DestroyResourceOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "ignore_lookup_error"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "DestroyTemporaryVariable"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "var_name"
-    type: "string"
-  }
-}
-op {
-  name: "Diag"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Diag"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Diag"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "DiagPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "DiagPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "DiagPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Digamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Digamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Digamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "filter_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "filter_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "filter_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropFilter"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "filter_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "in_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "in_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "in_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "Dilation2DBackpropInput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "in_backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "DirectedInterleaveDataset"
-  input_arg {
-    name: "selector_input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "data_input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Div"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Div"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Div"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "DivNoNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "DivNoNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "DrawBoundingBoxes"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "DrawBoundingBoxesV2"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "colors"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "DynamicPartition"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "partitions"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "outputs"
-    type_attr: "T"
-    number_attr: "num_partitions"
-  }
-  attr {
-    name: "num_partitions"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "DynamicStitch"
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "merged"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "EagerPyFunc"
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "token"
-    type: "string"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "EditDistance"
-  input_arg {
-    name: "hypothesis_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "hypothesis_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "hypothesis_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "truth_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "truth_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "truth_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "normalize"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Einsum"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "equation"
-    type: "string"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Elu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Elu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Elu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "EluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "outputs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "EluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "outputs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "EluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "outputs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Empty"
-  input_arg {
-    name: "shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "init"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "EmptyTensorList"
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  input_arg {
-    name: "max_num_elements"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "EncodeBase64"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "pad"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "EncodeJpeg"
-  input_arg {
-    name: "image"
-    type: DT_UINT8
-  }
-  output_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  attr {
-    name: "format"
-    type: "string"
-    default_value {
-      s: ""
-    }
-    allowed_values {
-      list {
-        s: ""
-        s: "grayscale"
-        s: "rgb"
-      }
-    }
-  }
-  attr {
-    name: "quality"
-    type: "int"
-    default_value {
-      i: 95
-    }
-  }
-  attr {
-    name: "progressive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "optimize_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "chroma_downsampling"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "density_unit"
-    type: "string"
-    default_value {
-      s: "in"
-    }
-    allowed_values {
-      list {
-        s: "in"
-        s: "cm"
-      }
-    }
-  }
-  attr {
-    name: "x_density"
-    type: "int"
-    default_value {
-      i: 300
-    }
-  }
-  attr {
-    name: "y_density"
-    type: "int"
-    default_value {
-      i: 300
-    }
-  }
-  attr {
-    name: "xmp_metadata"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "EncodeJpegVariableQuality"
-  input_arg {
-    name: "images"
-    type: DT_UINT8
-  }
-  input_arg {
-    name: "quality"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-}
-op {
-  name: "EncodePng"
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  attr {
-    name: "compression"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_UINT8
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-}
-op {
-  name: "EncodeProto"
-  input_arg {
-    name: "sizes"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "values"
-    type_list_attr: "Tinput_types"
-  }
-  output_arg {
-    name: "bytes"
-    type: DT_STRING
-  }
-  attr {
-    name: "field_names"
-    type: "list(string)"
-  }
-  attr {
-    name: "message_type"
-    type: "string"
-  }
-  attr {
-    name: "descriptor_source"
-    type: "string"
-    default_value {
-      s: "local://"
-    }
-  }
-  attr {
-    name: "Tinput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "EncodeWav"
-  input_arg {
-    name: "audio"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sample_rate"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-}
-op {
-  name: "EnqueueTPUEmbeddingIntegerBatch"
-  input_arg {
-    name: "batch"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "EnqueueTPUEmbeddingSparseBatch"
-  input_arg {
-    name: "sample_indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type: DT_FLOAT
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "EnqueueTPUEmbeddingSparseBatch"
-  input_arg {
-    name: "sample_indices"
-    type_attr: "T1"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type_attr: "T2"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type_attr: "T3"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T3"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "EnqueueTPUEmbeddingSparseTensorBatch"
-  input_arg {
-    name: "sample_indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type: DT_FLOAT
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "table_ids"
-    type: "list(int)"
-  }
-  is_stateful: true
-}
-op {
-  name: "EnqueueTPUEmbeddingSparseTensorBatch"
-  input_arg {
-    name: "sample_indices"
-    type_attr: "T1"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type_attr: "T2"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type_attr: "T3"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T3"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "table_ids"
-    type: "list(int)"
-  }
-  is_stateful: true
-}
-op {
-  name: "EnqueueTPUEmbeddingSparseTensorBatch"
-  input_arg {
-    name: "sample_indices"
-    type_attr: "T1"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type_attr: "T2"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type_attr: "T3"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T3"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "table_ids"
-    type: "list(int)"
-  }
-  attr {
-    name: "max_sequence_lengths"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "EnsureShape"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Enter"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "frame_name"
-    type: "string"
-  }
-  attr {
-    name: "is_constant"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "parallel_iterations"
-    type: "int"
-    default_value {
-      i: 10
-    }
-  }
-}
-op {
-  name: "Equal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Equal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Equal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Erf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Erf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Erf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Erfc"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Erfc"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Erfc"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "EuclideanNorm"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Exit"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Exp"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Exp"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Exp"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ExpandDims"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dim"
-    type_attr: "Tdim"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tdim"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ExperimentalAssertNextDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "transformations"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalAutoShardDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalBytesProducedStatsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalCSVDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "header"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "field_delim"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "use_quote_delim"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "na_value"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "select_cols"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "record_defaults"
-    type_list_attr: "output_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalChooseFastestDataset"
-  input_arg {
-    name: "input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "num_experiments"
-    type: "int"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalChooseFastestDataset"
-  input_arg {
-    name: "input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "num_experiments"
-    type: "int"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalDatasetCardinality"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "cardinality"
-    type: DT_INT64
-  }
-}
-op {
-  name: "ExperimentalDatasetToTFRecord"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ExperimentalDatasetToTFRecord"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalDenseToSparseBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "row_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalDenseToSparseBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "row_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalDirectedInterleaveDataset"
-  input_arg {
-    name: "selector_input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "data_input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalGroupByReducerDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key_func_other_arguments"
-    type_list_attr: "Tkey_func_other_arguments"
-  }
-  input_arg {
-    name: "init_func_other_arguments"
-    type_list_attr: "Tinit_func_other_arguments"
-  }
-  input_arg {
-    name: "reduce_func_other_arguments"
-    type_list_attr: "Treduce_func_other_arguments"
-  }
-  input_arg {
-    name: "finalize_func_other_arguments"
-    type_list_attr: "Tfinalize_func_other_arguments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "key_func"
-    type: "func"
-  }
-  attr {
-    name: "init_func"
-    type: "func"
-  }
-  attr {
-    name: "reduce_func"
-    type: "func"
-  }
-  attr {
-    name: "finalize_func"
-    type: "func"
-  }
-  attr {
-    name: "Tkey_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tinit_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Treduce_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tfinalize_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalGroupByWindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key_func_other_arguments"
-    type_list_attr: "Tkey_func_other_arguments"
-  }
-  input_arg {
-    name: "reduce_func_other_arguments"
-    type_list_attr: "Treduce_func_other_arguments"
-  }
-  input_arg {
-    name: "window_size_func_other_arguments"
-    type_list_attr: "Twindow_size_func_other_arguments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "key_func"
-    type: "func"
-  }
-  attr {
-    name: "reduce_func"
-    type: "func"
-  }
-  attr {
-    name: "window_size_func"
-    type: "func"
-  }
-  attr {
-    name: "Tkey_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Treduce_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Twindow_size_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalGroupByWindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key_func_other_arguments"
-    type_list_attr: "Tkey_func_other_arguments"
-  }
-  input_arg {
-    name: "reduce_func_other_arguments"
-    type_list_attr: "Treduce_func_other_arguments"
-  }
-  input_arg {
-    name: "window_size_func_other_arguments"
-    type_list_attr: "Twindow_size_func_other_arguments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "key_func"
-    type: "func"
-  }
-  attr {
-    name: "reduce_func"
-    type: "func"
-  }
-  attr {
-    name: "window_size_func"
-    type: "func"
-  }
-  attr {
-    name: "Tkey_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Treduce_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Twindow_size_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalIgnoreErrorsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalIteratorGetDevice"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "device"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalLMDBDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalLatencyStatsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalMapAndBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalMapAndBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ExperimentalMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ExperimentalMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ExperimentalMatchingFilesDataset"
-  input_arg {
-    name: "patterns"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalMaxIntraOpParallelismDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "max_intra_op_parallelism"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalNonSerializableDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalParallelInterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sloppy"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "buffer_output_elements"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "prefetch_input_elements"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalParseExampleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense_defaults"
-    type_list_attr: "Tdense"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tdense"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalParseExampleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense_defaults"
-    type_list_attr: "Tdense"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tdense"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "sloppy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ExperimentalPrivateThreadPoolDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_threads"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalRandomDataset"
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalRebatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalRebatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_fallback"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ExperimentalScanDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "initial_state"
-    type_list_attr: "Tstate"
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Tstate"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalScanDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "initial_state"
-    type_list_attr: "Tstate"
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Tstate"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ExperimentalSetStatsAggregatorDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "stats_aggregator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "counter_prefix"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalSleepDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "sleep_microseconds"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalSlidingWindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "window_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "window_shift"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "window_stride"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalSqlDataset"
-  input_arg {
-    name: "driver_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data_source_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "query"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalStatsAggregatorHandle"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalStatsAggregatorSummary"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalTakeWhileDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "predicate"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalThreadPoolDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "thread_pool"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalThreadPoolHandle"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "num_threads"
-    type: "int"
-  }
-  attr {
-    name: "max_intra_op_parallelism"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "display_name"
-    type: "string"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ExperimentalUnbatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ExperimentalUniqueDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Expm1"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Expm1"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Expm1"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ExtractGlimpse"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "offsets"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "glimpse"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "centered"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "normalized"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "uniform_noise"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ExtractGlimpse"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "offsets"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "glimpse"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "centered"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "normalized"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "uniform_noise"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "noise"
-    type: "string"
-    default_value {
-      s: "uniform"
-    }
-  }
-}
-op {
-  name: "ExtractImagePatches"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "patches"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksizes"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "ExtractImagePatches"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "patches"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksizes"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "ExtractImagePatches"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "patches"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksizes"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "ExtractImagePatches"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "patches"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksizes"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "rates"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "ExtractJpegShape"
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "image_shape"
-    type_attr: "output_type"
-  }
-  attr {
-    name: "output_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ExtractVolumePatches"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "patches"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksizes"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "FFT"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "FFT"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FFT2D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "FFT2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FFT3D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "FFT3D"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FIFOQueue"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FIFOQueueV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Fact"
-  output_arg {
-    name: "fact"
-    type: DT_STRING
-  }
-}
-op {
-  name: "FakeParam"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgs"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgs"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgs"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxArgsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "min"
-    type: "float"
-    default_value {
-      f: -6
-    }
-  }
-  attr {
-    name: "max"
-    type: "float"
-    default_value {
-      f: 6
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVars"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVars"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVars"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannel"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannel"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannel"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-}
-op {
-  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
-  input_arg {
-    name: "gradients"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprops_wrt_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "backprop_wrt_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "FakeQueue"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  is_stateful: true
-}
-op {
-  name: "Fill"
-  input_arg {
-    name: "dims"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Fill"
-  input_arg {
-    name: "dims"
-    type_attr: "index_type"
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "index_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "FilterByLastComponentDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "output"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "FilterDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "predicate"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "FilterDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "predicate"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Fingerprint"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "method"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "fingerprint"
-    type: DT_UINT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "FixedLengthRecordDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "header_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "record_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "footer_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordDatasetV2"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "header_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "record_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "footer_bytes"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "hop_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "hop_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "hop_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedLengthRecordReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "header_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "record_bytes"
-    type: "int"
-  }
-  attr {
-    name: "footer_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "hop_bytes"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "encoding"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FixedUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "vocab_file"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "distortion"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "num_reserved_ids"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-    default_value {
-      i: 1
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shard"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "unigrams"
-    type: "list(float)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "FixedUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "vocab_file"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "distortion"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "num_reserved_ids"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-    default_value {
-      i: 1
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shard"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "unigrams"
-    type: "list(float)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "FlatMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "FlatMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Floor"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Floor"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Floor"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "FloorDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FloorDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FloorDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "FloorMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "FloorMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "FloorMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "FlushSummaryWriter"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "For"
-  input_arg {
-    name: "start"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "limit"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "delta"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "body"
-    type: "func"
-  }
-}
-op {
-  name: "FractionalAvgPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "row_pooling_sequence"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "col_pooling_sequence"
-    type: DT_INT64
-  }
-  attr {
-    name: "pooling_ratio"
-    type: "list(float)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "pseudo_random"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "overlapping"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "deterministic"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "FractionalAvgPoolGrad"
-  input_arg {
-    name: "orig_input_tensor_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "row_pooling_sequence"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "col_pooling_sequence"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "overlapping"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "FractionalMaxPool"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "row_pooling_sequence"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "col_pooling_sequence"
-    type: DT_INT64
-  }
-  attr {
-    name: "pooling_ratio"
-    type: "list(float)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "pseudo_random"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "overlapping"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "deterministic"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "FractionalMaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "out_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "row_pooling_sequence"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "col_pooling_sequence"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "overlapping"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "FusedBatchNorm"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "offset"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "mean"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "variance"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "batch_mean"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "batch_variance"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space_1"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space_2"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedBatchNormGrad"
-  input_arg {
-    name: "y_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space_1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reserve_space_2"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "x_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "scale_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "offset_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space_3"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "reserve_space_4"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedBatchNormGradV2"
-  input_arg {
-    name: "y_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "reserve_space_1"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "reserve_space_2"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "x_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "scale_backprop"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "offset_backprop"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_3"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_4"
-    type_attr: "U"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "U"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedBatchNormGradV3"
-  input_arg {
-    name: "y_backprop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "reserve_space_1"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "reserve_space_2"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "reserve_space_3"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "x_backprop"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "scale_backprop"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "offset_backprop"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_4"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_5"
-    type_attr: "U"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "U"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedBatchNormV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "offset"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "mean"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "variance"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "batch_mean"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "batch_variance"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_1"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_2"
-    type_attr: "U"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "U"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedBatchNormV3"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "offset"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "mean"
-    type_attr: "U"
-  }
-  input_arg {
-    name: "variance"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "batch_mean"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "batch_variance"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_1"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_2"
-    type_attr: "U"
-  }
-  output_arg {
-    name: "reserve_space_3"
-    type_attr: "U"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "U"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "epsilon"
-    type: "float"
-    default_value {
-      f: 0.0001
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "is_training"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "FusedPadConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "FusedPadConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "FusedResizeAndPadConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "paddings"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "resize_align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "FusedResizeAndPadConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "paddings"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "resize_align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "GRUBlockCell"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w_ru"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_ru"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "u"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "c"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "h"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "GRUBlockCellGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w_ru"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_ru"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "r"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "u"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "c"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "d_h"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_h_prev"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_c_bar"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_r_bar_u_bar"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "Gather"
-  input_arg {
-    name: "params"
-    type_attr: "Tparams"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tparams"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "Tparams"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "GatherNd"
-  input_arg {
-    name: "params"
-    type_attr: "Tparams"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tparams"
-  }
-  attr {
-    name: "Tparams"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "GatherV2"
-  input_arg {
-    name: "params"
-    type_attr: "Tparams"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Taxis"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tparams"
-  }
-  attr {
-    name: "Tparams"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Taxis"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "GatherV2"
-  input_arg {
-    name: "params"
-    type_attr: "Tparams"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Taxis"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tparams"
-  }
-  attr {
-    name: "batch_dims"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "Tparams"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Taxis"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "GenerateVocabRemapping"
-  input_arg {
-    name: "new_vocab_file"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "old_vocab_file"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "remapping"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "num_present"
-    type: DT_INT32
-  }
-  attr {
-    name: "new_vocab_offset"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_new_vocab"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "GenerateVocabRemapping"
-  input_arg {
-    name: "new_vocab_file"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "old_vocab_file"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "remapping"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "num_present"
-    type: DT_INT32
-  }
-  attr {
-    name: "new_vocab_offset"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_new_vocab"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "old_vocab_size"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-}
-op {
-  name: "GeneratorDataset"
-  input_arg {
-    name: "init_func_other_args"
-    type_list_attr: "Tinit_func_args"
-  }
-  input_arg {
-    name: "next_func_other_args"
-    type_list_attr: "Tnext_func_args"
-  }
-  input_arg {
-    name: "finalize_func_other_args"
-    type_list_attr: "Tfinalize_func_args"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "init_func"
-    type: "func"
-  }
-  attr {
-    name: "next_func"
-    type: "func"
-  }
-  attr {
-    name: "finalize_func"
-    type: "func"
-  }
-  attr {
-    name: "Tinit_func_args"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tnext_func_args"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tfinalize_func_args"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "GetSessionHandle"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "GetSessionHandle"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 23
-  }
-}
-op {
-  name: "GetSessionHandle"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "GetSessionHandle"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "GetSessionHandleV2"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "GetSessionTensor"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "GetSessionTensor"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "Greater"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Greater"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Greater"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Greater"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "GreaterEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "GreaterEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "GreaterEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "GreaterEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "GroupByReducerDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key_func_other_arguments"
-    type_list_attr: "Tkey_func_other_arguments"
-  }
-  input_arg {
-    name: "init_func_other_arguments"
-    type_list_attr: "Tinit_func_other_arguments"
-  }
-  input_arg {
-    name: "reduce_func_other_arguments"
-    type_list_attr: "Treduce_func_other_arguments"
-  }
-  input_arg {
-    name: "finalize_func_other_arguments"
-    type_list_attr: "Tfinalize_func_other_arguments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "key_func"
-    type: "func"
-  }
-  attr {
-    name: "init_func"
-    type: "func"
-  }
-  attr {
-    name: "reduce_func"
-    type: "func"
-  }
-  attr {
-    name: "finalize_func"
-    type: "func"
-  }
-  attr {
-    name: "Tkey_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tinit_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Treduce_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tfinalize_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "GroupByWindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key_func_other_arguments"
-    type_list_attr: "Tkey_func_other_arguments"
-  }
-  input_arg {
-    name: "reduce_func_other_arguments"
-    type_list_attr: "Treduce_func_other_arguments"
-  }
-  input_arg {
-    name: "window_size_func_other_arguments"
-    type_list_attr: "Twindow_size_func_other_arguments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "key_func"
-    type: "func"
-  }
-  attr {
-    name: "reduce_func"
-    type: "func"
-  }
-  attr {
-    name: "window_size_func"
-    type: "func"
-  }
-  attr {
-    name: "Tkey_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Treduce_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Twindow_size_func_other_arguments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "GuaranteeConst"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "HSVToRGB"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "HSVToRGB"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "HashTable"
-  output_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "HashTableV2"
-  output_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "HistogramFixedWidth"
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "value_range"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "nbins"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "out"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "HistogramSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "HistogramSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "HistogramSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "HistogramSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "HostConst"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "value"
-    type: "tensor"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "IFFT"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "IFFT"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "IFFT2D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "IFFT2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "IFFT3D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "IFFT3D"
-  input_arg {
-    name: "input"
-    type_attr: "Tcomplex"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tcomplex"
-  }
-  attr {
-    name: "Tcomplex"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "IRFFT"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "IRFFT2D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "IRFFT3D"
-  input_arg {
-    name: "input"
-    type: DT_COMPLEX64
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "Identity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "IdentityN"
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "IdentityReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "IdentityReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "IdentityReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "If"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-}
-op {
-  name: "If"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-}
-op {
-  name: "If"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-}
-op {
-  name: "If"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-  is_stateful: true
-}
-op {
-  name: "If"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Igamma"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IgammaGradA"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Igammac"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IgnoreErrorsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Imag"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "ImageSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "max_images"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "bad_color"
-    type: "tensor"
-    default_value {
-      tensor {
-        dtype: DT_UINT8
-        tensor_shape {
-          dim {
-            size: 4
-          }
-        }
-        int_val: 255
-        int_val: 0
-        int_val: 0
-        int_val: 255
-      }
-    }
-  }
-}
-op {
-  name: "ImageSummary"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "max_images"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "bad_color"
-    type: "tensor"
-    default_value {
-      tensor {
-        dtype: DT_UINT8
-        tensor_shape {
-          dim {
-            size: 4
-          }
-        }
-        int_val: 255
-        int_val: 0
-        int_val: 0
-        int_val: 255
-      }
-    }
-  }
-}
-op {
-  name: "ImmutableConst"
-  output_arg {
-    name: "tensor"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "memory_region_name"
-    type: "string"
-  }
-}
-op {
-  name: "ImportEvent"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "event"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "InTopK"
-  input_arg {
-    name: "predictions"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "targets"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "precision"
-    type: DT_BOOL
-  }
-  attr {
-    name: "k"
-    type: "int"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "InTopKV2"
-  input_arg {
-    name: "predictions"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "targets"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "precision"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "InfeedDequeue"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "InfeedDequeueTuple"
-  output_arg {
-    name: "outputs"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-  }
-  is_stateful: true
-}
-op {
-  name: "InfeedEnqueue"
-  input_arg {
-    name: "input"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  attr {
-    name: "layout"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "InfeedEnqueuePrelinearizedBuffer"
-  input_arg {
-    name: "input"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "InfeedEnqueueTuple"
-  input_arg {
-    name: "inputs"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-  }
-  attr {
-    name: "layouts"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "InitializeTable"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tkey"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tval"
-  }
-  attr {
-    name: "Tkey"
-    type: "type"
-  }
-  attr {
-    name: "Tval"
-    type: "type"
-  }
-}
-op {
-  name: "InitializeTableFromTextFile"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  attr {
-    name: "key_index"
-    type: "int"
-    has_minimum: true
-    minimum: -2
-  }
-  attr {
-    name: "value_index"
-    type: "int"
-    has_minimum: true
-    minimum: -2
-  }
-  attr {
-    name: "vocab_size"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "delimiter"
-    type: "string"
-    default_value {
-      s: "\t"
-    }
-  }
-}
-op {
-  name: "InitializeTableFromTextFileV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  attr {
-    name: "key_index"
-    type: "int"
-    has_minimum: true
-    minimum: -2
-  }
-  attr {
-    name: "value_index"
-    type: "int"
-    has_minimum: true
-    minimum: -2
-  }
-  attr {
-    name: "vocab_size"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "delimiter"
-    type: "string"
-    default_value {
-      s: "\t"
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "InitializeTableV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tkey"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tval"
-  }
-  attr {
-    name: "Tkey"
-    type: "type"
-  }
-  attr {
-    name: "Tval"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "InplaceAdd"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "i"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "InplaceSub"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "i"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "InplaceUpdate"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "i"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "InterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "InterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Inv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 17
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "InvGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Invert"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-}
-op {
-  name: "Invert"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "InvertPermutation"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "IsBoostedTreesEnsembleInitialized"
-  input_arg {
-    name: "tree_ensemble_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "is_initialized"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "IsBoostedTreesQuantileStreamResourceInitialized"
-  input_arg {
-    name: "quantile_stream_resource_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "is_initialized"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "IsFinite"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsFinite"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsFinite"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsInf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsInf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsInf"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "IsVariableInitialized"
-  input_arg {
-    name: "ref"
-    type_attr: "dtype"
-    is_ref: true
-  }
-  output_arg {
-    name: "is_initialized"
-    type: DT_BOOL
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "Iterator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-  }
-  attr {
-    name: "container"
-    type: "string"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorFromStringHandle"
-  input_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorFromStringHandle"
-  input_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorFromStringHandleV2"
-  input_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorGetDevice"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "device"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorGetNext"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorGetNextAsOptional"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "optional"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorGetNextSync"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorToStringHandle"
-  input_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "IteratorV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-  }
-  attr {
-    name: "container"
-    type: "string"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "KMC2ChainInitialization"
-  input_arg {
-    name: "distances"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "index"
-    type: DT_INT64
-  }
-}
-op {
-  name: "KmeansPlusPlusInitialization"
-  input_arg {
-    name: "points"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "num_to_sample"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_retries_per_sample"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "samples"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "L2Loss"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "L2Loss"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "L2Loss"
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LMDBDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "LMDBReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "LRN"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "depth_radius"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "beta"
-    type: "float"
-    default_value {
-      f: 0.5
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "LRN"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "depth_radius"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "beta"
-    type: "float"
-    default_value {
-      f: 0.5
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "LRNGrad"
-  input_arg {
-    name: "input_grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "depth_radius"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "beta"
-    type: "float"
-    default_value {
-      f: 0.5
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "LRNGrad"
-  input_arg {
-    name: "input_grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "output_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "depth_radius"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "beta"
-    type: "float"
-    default_value {
-      f: 0.5
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "LSTMBlockCell"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wcf"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wco"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "i"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "cs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "f"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "o"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "ci"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "co"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "h"
-    type_attr: "T"
-  }
-  attr {
-    name: "forget_bias"
-    type: "float"
-    default_value {
-      f: 1
-    }
-  }
-  attr {
-    name: "cell_clip"
-    type: "float"
-    default_value {
-      f: 3
-    }
-  }
-  attr {
-    name: "use_peephole"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "LSTMBlockCellGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_prev"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "w"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wcf"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "wco"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "i"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "f"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "o"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ci"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "co"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "cs_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "h_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "cs_prev_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dicfo"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wci_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wcf_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "wco_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "use_peephole"
-    type: "bool"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "LatencyStatsDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "LeakyRelu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 0.2
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LeakyRelu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 0.2
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LeakyReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 0.2
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LeakyReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "alpha"
-    type: "float"
-    default_value {
-      f: 0.2
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LearnedUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "LearnedUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "LeftShift"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "LeftShift"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Less"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Less"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Less"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Less"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "LessEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "LessEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "LessEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "LessEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Lgamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Lgamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Lgamma"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LinSpace"
-  input_arg {
-    name: "start"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "stop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "LinSpace"
-  input_arg {
-    name: "start"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "stop"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ListDiff"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "LoadAndRemapMatrix"
-  input_arg {
-    name: "ckpt_path"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "old_tensor_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "row_remapping"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "col_remapping"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "initializing_values"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_matrix"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_rows"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_cols"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "max_rows_in_memory"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingADAMParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "velocities"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingADAMParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "velocities"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingAdadeltaParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "updates"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "updates"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingAdagradParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingAdagradParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingCenteredRMSPropParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "mg"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingFTRLParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "linears"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "linears"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingMDLAdagradLightParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "weights"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "benefits"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingMomentumParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingMomentumParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingProximalAdagradParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingRMSPropParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
-  input_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "Log"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Log"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Log"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Log1p"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Log1p"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Log1p"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "LogMatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sign"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "log_abs_determinant"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "LogMatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sign"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "log_abs_determinant"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "LogSoftmax"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "logsoftmax"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LogSoftmax"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "logsoftmax"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "LogUniformCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "LogUniformCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "LogicalAnd"
-  input_arg {
-    name: "x"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  is_commutative: true
-}
-op {
-  name: "LogicalNot"
-  input_arg {
-    name: "x"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-}
-op {
-  name: "LogicalOr"
-  input_arg {
-    name: "x"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "y"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  is_commutative: true
-}
-op {
-  name: "LookupTableExport"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "keys"
-    type_attr: "Tkeys"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "Tvalues"
-  }
-  attr {
-    name: "Tkeys"
-    type: "type"
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-}
-op {
-  name: "LookupTableExportV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "keys"
-    type_attr: "Tkeys"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "Tvalues"
-  }
-  attr {
-    name: "Tkeys"
-    type: "type"
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "LookupTableFind"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "default_value"
-    type_attr: "Tout"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-}
-op {
-  name: "LookupTableFindV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "default_value"
-    type_attr: "Tout"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "LookupTableImport"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-}
-op {
-  name: "LookupTableImportV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "LookupTableInsert"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-}
-op {
-  name: "LookupTableInsertV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "LookupTableRemoveV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "keys"
-    type_attr: "Tin"
-  }
-  attr {
-    name: "Tin"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "LookupTableSize"
-  input_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT64
-  }
-}
-op {
-  name: "LookupTableSizeV2"
-  input_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "LoopCond"
-  input_arg {
-    name: "input"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output"
-    type: DT_BOOL
-  }
-}
-op {
-  name: "LowerBound"
-  input_arg {
-    name: "sorted_inputs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Lu"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "lu"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "p"
-    type_attr: "output_idx_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "output_idx_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Lu"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "lu"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "p"
-    type_attr: "output_idx_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "output_idx_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "MakeIterator"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "MapAndBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "MapClear"
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "MapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "MapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "MapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "MapDefun"
-  input_arg {
-    name: "arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "MapDefun"
-  input_arg {
-    name: "arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "captured_inputs"
-    type_list_attr: "Tcaptured"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tcaptured"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "MapDefun"
-  input_arg {
-    name: "arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "captured_inputs"
-    type_list_attr: "Tcaptured"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tcaptured"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "max_intra_op_parallelism"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-}
-op {
-  name: "MapIncompleteSize"
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapPeek"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapSize"
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapStage"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "values"
-    type_list_attr: "fake_dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "fake_dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapUnstage"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MapUnstageNoKey"
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatchingFiles"
-  input_arg {
-    name: "pattern"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-}
-op {
-  name: "MatchingFilesDataset"
-  input_arg {
-    name: "patterns"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "MatrixBandPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_lower"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_upper"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "band"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixBandPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_lower"
-    type_attr: "Tindex"
-  }
-  input_arg {
-    name: "num_upper"
-    type_attr: "Tindex"
-  }
-  output_arg {
-    name: "band"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindex"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "MatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "MatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixDeterminant"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixDiag"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixDiagPart"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixDiagPartV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "padding_value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixDiagV2"
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_rows"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_cols"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "padding_value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixExponential"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixExponential"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 27
-  }
-}
-op {
-  name: "MatrixExponential"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  deprecation {
-    version: 27
-  }
-}
-op {
-  name: "MatrixInverse"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MatrixInverse"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixInverse"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixLogarithm"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixSetDiag"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixSetDiagV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "diagonal"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "MatrixSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixSolveLs"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_regularizer"
-    type: DT_DOUBLE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "MatrixSolveLs"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_regularizer"
-    type: DT_DOUBLE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "MatrixSolveLs"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_regularizer"
-    type: DT_DOUBLE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "MatrixSquareRoot"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixSquareRoot"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixTriangularSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "lower"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MatrixTriangularSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "lower"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MatrixTriangularSolve"
-  input_arg {
-    name: "matrix"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "lower"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "adjoint"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Max"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Max"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Max"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Max"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxIntraOpParallelismDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "max_intra_op_parallelism"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "MaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_QINT8
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_QINT8
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3D"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGrad"
-  input_arg {
-    name: "orig_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "orig_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGrad"
-  input_arg {
-    name: "orig_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "orig_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "TInput"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "TInput"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "TInput"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "TInput"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "MaxPool3DGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 5
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NDHWC"
-    }
-    allowed_values {
-      list {
-        s: "NDHWC"
-        s: "NCDHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGrad"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "include_batch_in_index"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradV2"
-  input_arg {
-    name: "orig_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "orig_output"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolGradWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "include_batch_in_index"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_QINT8
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "ksize"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "strides"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_QINT8
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "MaxPoolWithArgmax"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "argmax"
-    type_attr: "Targmax"
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-    has_minimum: true
-    minimum: 4
-  }
-  attr {
-    name: "Targmax"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "include_batch_in_index"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Maximum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Maximum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Maximum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Maximum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Mean"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Mean"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Mean"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Mean"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Merge"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "value_index"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "MergeSummary"
-  input_arg {
-    name: "inputs"
-    type: DT_STRING
-    number_attr: "N"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "MergeV2Checkpoints"
-  input_arg {
-    name: "checkpoint_prefixes"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "destination_prefix"
-    type: DT_STRING
-  }
-  attr {
-    name: "delete_old_dirs"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "MergeV2Checkpoints"
-  input_arg {
-    name: "checkpoint_prefixes"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "destination_prefix"
-    type: DT_STRING
-  }
-  attr {
-    name: "delete_old_dirs"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Mfcc"
-  input_arg {
-    name: "spectrogram"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sample_rate"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "upper_frequency_limit"
-    type: "float"
-    default_value {
-      f: 4000
-    }
-  }
-  attr {
-    name: "lower_frequency_limit"
-    type: "float"
-    default_value {
-      f: 20
-    }
-  }
-  attr {
-    name: "filterbank_channel_count"
-    type: "int"
-    default_value {
-      i: 40
-    }
-  }
-  attr {
-    name: "dct_coefficient_count"
-    type: "int"
-    default_value {
-      i: 13
-    }
-  }
-}
-op {
-  name: "Min"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Min"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Min"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Min"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Minimum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Minimum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Minimum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Minimum"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "MirrorPad"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-}
-op {
-  name: "MirrorPadGrad"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    allowed_values {
-      list {
-        s: "REFLECT"
-        s: "SYMMETRIC"
-      }
-    }
-  }
-}
-op {
-  name: "Mod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Mod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Mod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "ModelDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ModelDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "cpu_budget"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ModelDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "algorithm"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "cpu_budget"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Mul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Mul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Mul"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "MulNoNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "MulNoNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "MulNoNan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "MultiDeviceIterator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "devices"
-    type: "list(string)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-  }
-  attr {
-    name: "container"
-    type: "string"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "MultiDeviceIteratorFromStringHandle"
-  input_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "multi_device_iterator"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "MultiDeviceIteratorGetNextFromShard"
-  input_arg {
-    name: "multi_device_iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "shard_num"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "incarnation_id"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "MultiDeviceIteratorInit"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "multi_device_iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "max_buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "incarnation_id"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "MultiDeviceIteratorToStringHandle"
-  input_arg {
-    name: "multi_device_iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "string_handle"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "Multinomial"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_samples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Multinomial"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_samples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Multinomial"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_samples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "output_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Multinomial"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_samples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "output_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableDenseHashTable"
-  input_arg {
-    name: "empty_key"
-    type_attr: "key_dtype"
-  }
-  output_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  attr {
-    name: "initial_num_buckets"
-    type: "int"
-    default_value {
-      i: 131072
-    }
-  }
-  attr {
-    name: "max_load_factor"
-    type: "float"
-    default_value {
-      f: 0.8
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableDenseHashTableV2"
-  input_arg {
-    name: "empty_key"
-    type_attr: "key_dtype"
-  }
-  input_arg {
-    name: "deleted_key"
-    type_attr: "key_dtype"
-  }
-  output_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  attr {
-    name: "initial_num_buckets"
-    type: "int"
-    default_value {
-      i: 131072
-    }
-  }
-  attr {
-    name: "max_load_factor"
-    type: "float"
-    default_value {
-      f: 0.8
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableHashTable"
-  output_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableHashTableOfTensors"
-  output_arg {
-    name: "table_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableHashTableOfTensorsV2"
-  output_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "MutableHashTableV2"
-  output_arg {
-    name: "table_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_node_name_sharing"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "key_dtype"
-    type: "type"
-  }
-  attr {
-    name: "value_dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "MutexLock"
-  input_arg {
-    name: "mutex"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "mutex_lock"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "MutexV2"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "NcclAllReduce"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "reduction"
-    type: "string"
-    allowed_values {
-      list {
-        s: "min"
-        s: "max"
-        s: "prod"
-        s: "sum"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "num_devices"
-    type: "int"
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "NcclBroadcast"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "NcclReduce"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-    number_attr: "num_devices"
-  }
-  output_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  attr {
-    name: "reduction"
-    type: "string"
-    allowed_values {
-      list {
-        s: "min"
-        s: "max"
-        s: "prod"
-        s: "sum"
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "num_devices"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "NearestNeighbors"
-  input_arg {
-    name: "points"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "centers"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "nearest_center_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "nearest_center_distances"
-    type: DT_FLOAT
-  }
-}
-op {
-  name: "Neg"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Neg"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Neg"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "NegTrain"
-  input_arg {
-    name: "w_in"
-    type: DT_FLOAT
-    is_ref: true
-  }
-  input_arg {
-    name: "w_out"
-    type: DT_FLOAT
-    is_ref: true
-  }
-  input_arg {
-    name: "examples"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "labels"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "lr"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "vocab_count"
-    type: "list(int)"
-  }
-  attr {
-    name: "num_negative_samples"
-    type: "int"
-  }
-  deprecation {
-    version: 19
-  }
-  is_stateful: true
-}
-op {
-  name: "NextAfter"
-  input_arg {
-    name: "x1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x2"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "NextIteration"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "NoOp"
-}
-op {
-  name: "NonDeterministicInts"
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "NonMaxSuppression"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "iou_threshold"
-    type: "float"
-    default_value {
-      f: 0.5
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV2"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-}
-op {
-  name: "NonMaxSuppressionV2"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV2"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type_attr: "T_threshold"
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "T_threshold"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV3"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-}
-op {
-  name: "NonMaxSuppressionV3"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV3"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type_attr: "T_threshold"
-  }
-  input_arg {
-    name: "score_threshold"
-    type_attr: "T_threshold"
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "T_threshold"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV4"
-  input_arg {
-    name: "boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "valid_outputs"
-    type: DT_INT32
-  }
-  attr {
-    name: "pad_to_max_output_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV4"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "valid_outputs"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "pad_to_max_output_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV4"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type_attr: "T_threshold"
-  }
-  input_arg {
-    name: "score_threshold"
-    type_attr: "T_threshold"
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "valid_outputs"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "T_threshold"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "pad_to_max_output_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionV5"
-  input_arg {
-    name: "boxes"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scores"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "iou_threshold"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "score_threshold"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "soft_nms_sigma"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "selected_scores"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "valid_outputs"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "pad_to_max_output_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "NonMaxSuppressionWithOverlaps"
-  input_arg {
-    name: "overlaps"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "scores"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_output_size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "overlap_threshold"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "score_threshold"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "selected_indices"
-    type: DT_INT32
-  }
-}
-op {
-  name: "NonSerializableDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "NotEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "NotEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "NotEqual"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type: DT_BOOL
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_QUINT8
-        type: DT_QINT8
-        type: DT_QINT32
-        type: DT_STRING
-        type: DT_BOOL
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "NthElement"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "NthElement"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "NthElement"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  attr {
-    name: "reverse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "OneHot"
-  input_arg {
-    name: "indices"
-    type_attr: "TI"
-  }
-  input_arg {
-    name: "depth"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "on_value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "off_value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "axis"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "TI"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "OneShotIterator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "dataset_factory"
-    type: "func"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OnesLike"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "OnesLike"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_BOOL
-      }
-    }
-  }
-}
-op {
-  name: "OnesLike"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_BOOL
-      }
-    }
-  }
-}
-op {
-  name: "OptimizeDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "optimizations"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "OptimizeDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "optimizations"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "optimization_configs"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "OptionalFromValue"
-  input_arg {
-    name: "components"
-    type_list_attr: "Toutput_types"
-  }
-  output_arg {
-    name: "optional"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "OptionalGetValue"
-  input_arg {
-    name: "optional"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "OptionalHasValue"
-  input_arg {
-    name: "optional"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "has_value"
-    type: DT_BOOL
-  }
-}
-op {
-  name: "OptionalNone"
-  output_arg {
-    name: "optional"
-    type: DT_VARIANT
-  }
-}
-op {
-  name: "OrderedMapClear"
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapIncompleteSize"
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapPeek"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapSize"
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapStage"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "values"
-    type_list_attr: "fake_dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "fake_dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapUnstage"
-  input_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OrderedMapUnstageNoKey"
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "key"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OutfeedDequeue"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OutfeedDequeueTuple"
-  output_arg {
-    name: "outputs"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "OutfeedEnqueue"
-  input_arg {
-    name: "input"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "OutfeedEnqueueTuple"
-  input_arg {
-    name: "inputs"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "Pack"
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "axis"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "Pad"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "PadV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  input_arg {
-    name: "constant_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "PaddedBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "padded_shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "padding_values"
-    type_list_attr: "Toutput_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "PaddedBatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "padded_shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "padding_values"
-    type_list_attr: "Toutput_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "PaddedBatchDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "padded_shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "padding_values"
-    type_list_attr: "Toutput_types"
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "PaddedBatchDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "batch_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "padded_shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "padding_values"
-    type_list_attr: "Toutput_types"
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "parallel_copy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "PaddingFIFOQueue"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PaddingFIFOQueueV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ParallelConcat"
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-}
-op {
-  name: "ParallelDynamicStitch"
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-    number_attr: "N"
-  }
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "merged"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "ParallelInterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sloppy"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "buffer_output_elements"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "prefetch_input_elements"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ParallelInterleaveDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ParallelInterleaveDatasetV2"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "sloppy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ParallelMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ParallelMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ParallelMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ParallelMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "sloppy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ParallelMapDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "sloppy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ParameterizedTruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "means"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "stdevs"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "minvals"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "maxvals"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ParameterizedTruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "means"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "stdevs"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "minvals"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "maxvals"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ParseExample"
-  input_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "sparse_keys"
-    type: DT_STRING
-    number_attr: "Nsparse"
-  }
-  input_arg {
-    name: "dense_keys"
-    type: DT_STRING
-    number_attr: "Ndense"
-  }
-  input_arg {
-    name: "dense_defaults"
-    type_list_attr: "Tdense"
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-    number_attr: "Nsparse"
-  }
-  output_arg {
-    name: "sparse_values"
-    type_list_attr: "sparse_types"
-  }
-  output_arg {
-    name: "sparse_shapes"
-    type: DT_INT64
-    number_attr: "Nsparse"
-  }
-  output_arg {
-    name: "dense_values"
-    type_list_attr: "Tdense"
-  }
-  attr {
-    name: "Nsparse"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "Ndense"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tdense"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-}
-op {
-  name: "ParseExampleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_parallel_calls"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense_defaults"
-    type_list_attr: "Tdense"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tdense"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "sloppy"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ParseSequenceExample"
-  input_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "debug_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "context_dense_defaults"
-    type_list_attr: "Tcontext_dense"
-  }
-  output_arg {
-    name: "context_sparse_indices"
-    type: DT_INT64
-    number_attr: "Ncontext_sparse"
-  }
-  output_arg {
-    name: "context_sparse_values"
-    type_list_attr: "context_sparse_types"
-  }
-  output_arg {
-    name: "context_sparse_shapes"
-    type: DT_INT64
-    number_attr: "Ncontext_sparse"
-  }
-  output_arg {
-    name: "context_dense_values"
-    type_list_attr: "Tcontext_dense"
-  }
-  output_arg {
-    name: "feature_list_sparse_indices"
-    type: DT_INT64
-    number_attr: "Nfeature_list_sparse"
-  }
-  output_arg {
-    name: "feature_list_sparse_values"
-    type_list_attr: "feature_list_sparse_types"
-  }
-  output_arg {
-    name: "feature_list_sparse_shapes"
-    type: DT_INT64
-    number_attr: "Nfeature_list_sparse"
-  }
-  output_arg {
-    name: "feature_list_dense_values"
-    type_list_attr: "feature_list_dense_types"
-  }
-  output_arg {
-    name: "feature_list_dense_lengths"
-    type: DT_INT64
-    number_attr: "Nfeature_list_dense"
-  }
-  attr {
-    name: "feature_list_dense_missing_assumed_empty"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "context_sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "context_dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "feature_list_sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "feature_list_dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "Ncontext_sparse"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Ncontext_dense"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Nfeature_list_sparse"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Nfeature_list_dense"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "context_sparse_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tcontext_dense"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "feature_list_dense_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "context_dense_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "feature_list_sparse_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "feature_list_dense_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-}
-op {
-  name: "ParseSingleExample"
-  input_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "dense_defaults"
-    type_list_attr: "Tdense"
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-    number_attr: "num_sparse"
-  }
-  output_arg {
-    name: "sparse_values"
-    type_list_attr: "sparse_types"
-  }
-  output_arg {
-    name: "sparse_shapes"
-    type: DT_INT64
-    number_attr: "num_sparse"
-  }
-  output_arg {
-    name: "dense_values"
-    type_list_attr: "Tdense"
-  }
-  attr {
-    name: "num_sparse"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "dense_keys"
-    type: "list(string)"
-    has_minimum: true
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tdense"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-}
-op {
-  name: "ParseSingleSequenceExample"
-  input_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "feature_list_dense_missing_assumed_empty"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "context_sparse_keys"
-    type: DT_STRING
-    number_attr: "Ncontext_sparse"
-  }
-  input_arg {
-    name: "context_dense_keys"
-    type: DT_STRING
-    number_attr: "Ncontext_dense"
-  }
-  input_arg {
-    name: "feature_list_sparse_keys"
-    type: DT_STRING
-    number_attr: "Nfeature_list_sparse"
-  }
-  input_arg {
-    name: "feature_list_dense_keys"
-    type: DT_STRING
-    number_attr: "Nfeature_list_dense"
-  }
-  input_arg {
-    name: "context_dense_defaults"
-    type_list_attr: "Tcontext_dense"
-  }
-  input_arg {
-    name: "debug_name"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "context_sparse_indices"
-    type: DT_INT64
-    number_attr: "Ncontext_sparse"
-  }
-  output_arg {
-    name: "context_sparse_values"
-    type_list_attr: "context_sparse_types"
-  }
-  output_arg {
-    name: "context_sparse_shapes"
-    type: DT_INT64
-    number_attr: "Ncontext_sparse"
-  }
-  output_arg {
-    name: "context_dense_values"
-    type_list_attr: "Tcontext_dense"
-  }
-  output_arg {
-    name: "feature_list_sparse_indices"
-    type: DT_INT64
-    number_attr: "Nfeature_list_sparse"
-  }
-  output_arg {
-    name: "feature_list_sparse_values"
-    type_list_attr: "feature_list_sparse_types"
-  }
-  output_arg {
-    name: "feature_list_sparse_shapes"
-    type: DT_INT64
-    number_attr: "Nfeature_list_sparse"
-  }
-  output_arg {
-    name: "feature_list_dense_values"
-    type_list_attr: "feature_list_dense_types"
-  }
-  attr {
-    name: "Ncontext_sparse"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Ncontext_dense"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Nfeature_list_sparse"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "Nfeature_list_dense"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "context_sparse_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "Tcontext_dense"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "feature_list_dense_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "context_dense_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "feature_list_sparse_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "feature_list_dense_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-}
-op {
-  name: "ParseTensor"
-  input_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-  }
-}
-op {
-  name: "PartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "PartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "PartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "executor_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "PartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "config_proto"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "executor_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "Placeholder"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-}
-op {
-  name: "Placeholder"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-}
-op {
-  name: "PlaceholderV2"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-}
-op {
-  name: "PlaceholderV2"
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  deprecation {
-    version: 23
-  }
-}
-op {
-  name: "PlaceholderWithDefault"
-  input_arg {
-    name: "input"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-}
-op {
-  name: "Polygamma"
-  input_arg {
-    name: "a"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "PopulationCount"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_UINT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-      }
-    }
-  }
-}
-op {
-  name: "PopulationCount"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type: DT_UINT8
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Pow"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Pow"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Pow"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "PrefetchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "PrefetchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "PrefetchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "slack_period"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "Prelinearize"
-  input_arg {
-    name: "input"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-    default_value {
-      shape {
-      }
-    }
-  }
-  attr {
-    name: "layout"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "PrelinearizeTuple"
-  input_arg {
-    name: "inputs"
-    type_list_attr: "dtypes"
-  }
-  output_arg {
-    name: "output"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-  }
-  attr {
-    name: "layouts"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "PreventGradient"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "message"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "Print"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "U"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "U"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "message"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "first_n"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "summarize"
-    type: "int"
-    default_value {
-      i: 3
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Print"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "U"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "U"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "message"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "first_n"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "summarize"
-    type: "int"
-    default_value {
-      i: 3
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PrintV2"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  attr {
-    name: "output_stream"
-    type: "string"
-    default_value {
-      s: "stderr"
-    }
-    allowed_values {
-      list {
-        s: "stdout"
-        s: "stderr"
-        s: "log(info)"
-        s: "log(warning)"
-        s: "log(error)"
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PrintV2"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  attr {
-    name: "output_stream"
-    type: "string"
-    default_value {
-      s: "stderr"
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PrintV2"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  attr {
-    name: "output_stream"
-    type: "string"
-    default_value {
-      s: "stderr"
-    }
-  }
-  attr {
-    name: "end"
-    type: "string"
-    default_value {
-      s: "\n"
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PriorityQueue"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PriorityQueueV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "PrivateThreadPoolDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_threads"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Prod"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Prod"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Prod"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Prod"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "PyFunc"
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "token"
-    type: "string"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  is_stateful: true
-}
-op {
-  name: "PyFuncStateless"
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "token"
-    type: "string"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-}
-op {
-  name: "Qr"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "q"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r"
-    type_attr: "T"
-  }
-  attr {
-    name: "full_matrices"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Qr"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "q"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "r"
-    type_attr: "T"
-  }
-  attr {
-    name: "full_matrices"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "input_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "input_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 21
-  }
-}
-op {
-  name: "QuantizeAndDequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "input_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 22
-  }
-}
-op {
-  name: "QuantizeAndDequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "input_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 22
-  }
-}
-op {
-  name: "QuantizeAndDequantize"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "input_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 22
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "round_mode"
-    type: "string"
-    default_value {
-      s: "HALF_TO_EVEN"
-    }
-    allowed_values {
-      list {
-        s: "HALF_TO_EVEN"
-        s: "HALF_UP"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "num_bits"
-    type: "int"
-    default_value {
-      i: 8
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "round_mode"
-    type: "string"
-    default_value {
-      s: "HALF_TO_EVEN"
-    }
-    allowed_values {
-      list {
-        s: "HALF_TO_EVEN"
-        s: "HALF_UP"
-      }
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_bits"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_bits"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_bits"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeAndDequantizeV3"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_max"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_bits"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "signed_input"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "range_given"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "narrow_range"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "QuantizeDownAndShrinkRange"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeDownAndShrinkRange"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeV2"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeV2"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeV2"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-  attr {
-    name: "round_mode"
-    type: "string"
-    default_value {
-      s: "HALF_AWAY_FROM_ZERO"
-    }
-    allowed_values {
-      list {
-        s: "HALF_AWAY_FROM_ZERO"
-        s: "HALF_TO_EVEN"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizeV2"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_range"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_range"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "mode"
-    type: "string"
-    default_value {
-      s: "MIN_COMBINED"
-    }
-    allowed_values {
-      list {
-        s: "MIN_COMBINED"
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-  attr {
-    name: "round_mode"
-    type: "string"
-    default_value {
-      s: "HALF_AWAY_FROM_ZERO"
-    }
-    allowed_values {
-      list {
-        s: "HALF_AWAY_FROM_ZERO"
-        s: "HALF_TO_EVEN"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedAdd"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "QuantizedAdd"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "QuantizedAdd"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedAvgPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedAvgPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedBatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "t_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "t_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "m"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "m_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "m_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "v"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "v_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "v_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "beta_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "beta_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "gamma_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gamma_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "result"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "result_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "result_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-}
-op {
-  name: "QuantizedBatchNormWithGlobalNormalization"
-  input_arg {
-    name: "t"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "t_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "t_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "m"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "m_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "m_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "v"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "v_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "v_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "beta_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "beta_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gamma"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "gamma_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "gamma_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "result"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "result_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "result_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-  }
-  attr {
-    name: "scale_after_normalization"
-    type: "bool"
-  }
-}
-op {
-  name: "QuantizedBiasAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_bias"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedBiasAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_bias"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConcat"
-  input_arg {
-    name: "concat_dim"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "input_mins"
-    type: DT_FLOAT
-    number_attr: "N"
-  }
-  input_arg {
-    name: "input_maxes"
-    type: DT_FLOAT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "QuantizedConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DPerChannel"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBias"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBias"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type_attr: "Tsummand"
-  }
-  input_arg {
-    name: "min_summand"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tsummand"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type_attr: "Tsummand"
-  }
-  input_arg {
-    name: "min_summand"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tsummand"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSumAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSumAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type_attr: "Tsummand"
-  }
-  input_arg {
-    name: "min_summand"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tsummand"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "summand"
-    type_attr: "Tsummand"
-  }
-  input_arg {
-    name: "min_summand"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_summand"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Tsummand"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-  attr {
-    name: "padding_list"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedDepthwiseConv2D"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedDepthwiseConv2DWithBias"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedDepthwiseConv2DWithBiasAndRelu"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "filter"
-    type_attr: "Tfilter"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_filter"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tfilter"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-  attr {
-    name: "dilations"
-    type: "list(int)"
-    default_value {
-      list {
-        i: 1
-        i: 1
-        i: 1
-        i: 1
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedInstanceNorm"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "x_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "y_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "output_range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "given_y_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "given_y_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  attr {
-    name: "min_separation"
-    type: "float"
-    default_value {
-      f: 0.001
-    }
-  }
-}
-op {
-  name: "QuantizedInstanceNorm"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "x_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "x_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "y_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "output_range_given"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "given_y_min"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "given_y_max"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "variance_epsilon"
-    type: "float"
-    default_value {
-      f: 1e-05
-    }
-  }
-  attr {
-    name: "min_separation"
-    type: "float"
-    default_value {
-      f: 0.001
-    }
-  }
-}
-op {
-  name: "QuantizedMatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_b"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tactivation"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMatMul"
-  input_arg {
-    name: "a"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_b"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tactivation"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMatMulWithBias"
-  input_arg {
-    name: "a"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_b"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_quant_mode"
-    type: "string"
-    default_value {
-      s: "MIN_FIRST"
-    }
-    allowed_values {
-      list {
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMatMulWithBiasAndRelu"
-  input_arg {
-    name: "a"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "bias"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_b"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_quant_mode"
-    type: "string"
-    default_value {
-      s: "MIN_FIRST"
-    }
-    allowed_values {
-      list {
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMatMulWithBiasAndReluAndRequantize"
-  input_arg {
-    name: "a"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "bias"
-    type_attr: "Tbias"
-  }
-  input_arg {
-    name: "min_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_a"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_b"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_freezed_output"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_freezed_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_out"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Tbias"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "input_quant_mode"
-    type: "string"
-    default_value {
-      s: "MIN_FIRST"
-    }
-    allowed_values {
-      list {
-        s: "MIN_FIRST"
-        s: "SCALED"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMaxPool"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "min_input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_input"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "min_output"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "ksize"
-    type: "list(int)"
-  }
-  attr {
-    name: "strides"
-    type: "list(int)"
-  }
-  attr {
-    name: "padding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "SAME"
-        s: "VALID"
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedMul"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "QuantizedMul"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "QuantizedMul"
-  input_arg {
-    name: "x"
-    type_attr: "T1"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T2"
-  }
-  input_arg {
-    name: "min_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_x"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_y"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_y"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "z"
-    type_attr: "Toutput"
-  }
-  output_arg {
-    name: "min_z"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_z"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "Toutput"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedRelu"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedRelu"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedRelu6"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedRelu6"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedReluX"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "max_value"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedReluX"
-  input_arg {
-    name: "features"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "max_value"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_features"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "min_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "max_activations"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedReshape"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "Tshape"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tshape"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "QuantizedResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "resized_images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "QuantizedResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "resized_images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "QueueClose"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "cancel_pending_enqueues"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "QueueCloseV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "cancel_pending_enqueues"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueDequeue"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "QueueDequeueMany"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "QueueDequeueManyV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueDequeueUpTo"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "QueueDequeueUpToV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "n"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueDequeueV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "component_types"
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueEnqueue"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "components"
-    type_list_attr: "Tcomponents"
-  }
-  attr {
-    name: "Tcomponents"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "QueueEnqueueMany"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "components"
-    type_list_attr: "Tcomponents"
-  }
-  attr {
-    name: "Tcomponents"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "QueueEnqueueManyV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "components"
-    type_list_attr: "Tcomponents"
-  }
-  attr {
-    name: "Tcomponents"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueEnqueueV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "components"
-    type_list_attr: "Tcomponents"
-  }
-  attr {
-    name: "Tcomponents"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "timeout_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueIsClosed"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "is_closed"
-    type: DT_BOOL
-  }
-}
-op {
-  name: "QueueIsClosedV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "is_closed"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "QueueSize"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-}
-op {
-  name: "QueueSizeV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "RFFT"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "RFFT2D"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "RFFT3D"
-  input_arg {
-    name: "input"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "fft_length"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_COMPLEX64
-  }
-}
-op {
-  name: "RGBToHSV"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "RGBToHSV"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "RaggedGather"
-  input_arg {
-    name: "params_nested_splits"
-    type: DT_INT64
-    number_attr: "PARAMS_RAGGED_RANK"
-  }
-  input_arg {
-    name: "params_dense_values"
-    type_attr: "Tvalues"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output_nested_splits"
-    type: DT_INT64
-    number_attr: "OUTPUT_RAGGED_RANK"
-  }
-  output_arg {
-    name: "output_dense_values"
-    type_attr: "Tvalues"
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "PARAMS_RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "OUTPUT_RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "RaggedGather"
-  input_arg {
-    name: "params_nested_splits"
-    type_attr: "Tsplits"
-    number_attr: "PARAMS_RAGGED_RANK"
-  }
-  input_arg {
-    name: "params_dense_values"
-    type_attr: "Tvalues"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output_nested_splits"
-    type_attr: "Tsplits"
-    number_attr: "OUTPUT_RAGGED_RANK"
-  }
-  output_arg {
-    name: "output_dense_values"
-    type_attr: "Tvalues"
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "PARAMS_RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "OUTPUT_RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "RaggedRange"
-  input_arg {
-    name: "starts"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "limits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "deltas"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "rt_nested_splits"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "rt_dense_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "RaggedRange"
-  input_arg {
-    name: "starts"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "limits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "deltas"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "rt_nested_splits"
-    type_attr: "Tsplits"
-  }
-  output_arg {
-    name: "rt_dense_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "RaggedTensorFromVariant"
-  input_arg {
-    name: "encoded_ragged"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "output_nested_splits"
-    type_attr: "Tsplits"
-    number_attr: "output_ragged_rank"
-  }
-  output_arg {
-    name: "output_dense_values"
-    type_attr: "Tvalues"
-  }
-  attr {
-    name: "input_ragged_rank"
-    type: "int"
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "output_ragged_rank"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "RaggedTensorToSparse"
-  input_arg {
-    name: "rt_nested_splits"
-    type: DT_INT64
-    number_attr: "RAGGED_RANK"
-  }
-  input_arg {
-    name: "rt_dense_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sparse_dense_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "RaggedTensorToSparse"
-  input_arg {
-    name: "rt_nested_splits"
-    type_attr: "Tsplits"
-    number_attr: "RAGGED_RANK"
-  }
-  input_arg {
-    name: "rt_dense_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sparse_dense_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "RaggedTensorToVariant"
-  input_arg {
-    name: "rt_nested_splits"
-    type_attr: "Tsplits"
-    number_attr: "RAGGED_RANK"
-  }
-  input_arg {
-    name: "rt_dense_values"
-    type_attr: "Tvalues"
-  }
-  output_arg {
-    name: "encoded_ragged"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "RAGGED_RANK"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "batched_input"
-    type: "bool"
-  }
-}
-op {
-  name: "RandomCrop"
-  input_arg {
-    name: "image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  deprecation {
-    version: 8
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomDataset"
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomGamma"
-  input_arg {
-    name: "shape"
-    type_attr: "S"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomGammaGrad"
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sample"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "RandomPoisson"
-  input_arg {
-    name: "shape"
-    type_attr: "S"
-  }
-  input_arg {
-    name: "rate"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomPoisson"
-  input_arg {
-    name: "shape"
-    type_attr: "S"
-  }
-  input_arg {
-    name: "rate"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  deprecation {
-    version: 25
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomPoissonV2"
-  input_arg {
-    name: "shape"
-    type_attr: "S"
-  }
-  input_arg {
-    name: "rate"
-    type_attr: "R"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "R"
-    type: "type"
-    default_value {
-      type: DT_DOUBLE
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomShuffle"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomShuffleQueue"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "min_after_dequeue"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomShuffleQueueV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "component_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "min_after_dequeue"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomStandardNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomStandardNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomUniform"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomUniform"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RandomUniformInt"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "minval"
-    type_attr: "Tout"
-  }
-  input_arg {
-    name: "maxval"
-    type_attr: "Tout"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Range"
-  input_arg {
-    name: "start"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "limit"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tidx"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Range"
-  input_arg {
-    name: "start"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "limit"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tidx"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "RangeDataset"
-  input_arg {
-    name: "start"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "stop"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "Rank"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "ReadFile"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ReadVariableOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderNumRecordsProduced"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "records_produced"
-    type: DT_INT64
-  }
-}
-op {
-  name: "ReaderNumRecordsProducedV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "records_produced"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderNumWorkUnitsCompleted"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "units_completed"
-    type: DT_INT64
-  }
-}
-op {
-  name: "ReaderNumWorkUnitsCompletedV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "units_completed"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderRead"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "queue_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "key"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "value"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ReaderReadUpTo"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "queue_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_records"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "keys"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "values"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ReaderReadUpToV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "queue_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "num_records"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "keys"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "values"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderReadV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "queue_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "key"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "value"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderReset"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-}
-op {
-  name: "ReaderResetV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderRestoreState"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "state"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ReaderRestoreStateV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "state"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "ReaderSerializeState"
-  input_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "state"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ReaderSerializeStateV2"
-  input_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "state"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "Real"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tout"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_COMPLEX64
-    }
-    allowed_values {
-      list {
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  attr {
-    name: "Tout"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "RealDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RealDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RealDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RebatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "RebatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_workers"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_fallback"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "Reciprocal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Reciprocal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Reciprocal"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ReciprocalGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ReciprocalGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ReciprocalGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ReciprocalGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RecordInput"
-  output_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  attr {
-    name: "file_pattern"
-    type: "string"
-  }
-  attr {
-    name: "file_random_seed"
-    type: "int"
-    default_value {
-      i: 301
-    }
-  }
-  attr {
-    name: "file_shuffle_shift_ratio"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "file_buffer_size"
-    type: "int"
-    default_value {
-      i: 10000
-    }
-  }
-  attr {
-    name: "file_parallelism"
-    type: "int"
-    default_value {
-      i: 16
-    }
-  }
-  attr {
-    name: "batch_size"
-    type: "int"
-    default_value {
-      i: 32
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RecordInput"
-  output_arg {
-    name: "records"
-    type: DT_STRING
-  }
-  attr {
-    name: "file_pattern"
-    type: "string"
-  }
-  attr {
-    name: "file_random_seed"
-    type: "int"
-    default_value {
-      i: 301
-    }
-  }
-  attr {
-    name: "file_shuffle_shift_ratio"
-    type: "float"
-    default_value {
-      f: 0
-    }
-  }
-  attr {
-    name: "file_buffer_size"
-    type: "int"
-    default_value {
-      i: 10000
-    }
-  }
-  attr {
-    name: "file_parallelism"
-    type: "int"
-    default_value {
-      i: 16
-    }
-  }
-  attr {
-    name: "batch_size"
-    type: "int"
-    default_value {
-      i: 32
-    }
-  }
-  attr {
-    name: "compression_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RecvTPUEmbeddingActivations"
-  output_arg {
-    name: "outputs"
-    type: DT_FLOAT
-    number_attr: "num_outputs"
-  }
-  attr {
-    name: "num_outputs"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "config"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "ReduceDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "initial_state"
-    type_list_attr: "Tstate"
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Tstate"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ReduceDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "initial_state"
-    type_list_attr: "Tstate"
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "components"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Tstate"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "use_inter_op_parallelism"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ReduceJoin"
-  input_arg {
-    name: "inputs"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "reduction_indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "separator"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "RefEnter"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "frame_name"
-    type: "string"
-  }
-  attr {
-    name: "is_constant"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "parallel_iterations"
-    type: "int"
-    default_value {
-      i: 10
-    }
-  }
-}
-op {
-  name: "RefExit"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "RefIdentity"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "RefMerge"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "value_index"
-    type: DT_INT32
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "RefNextIteration"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "RefSelect"
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-    is_ref: true
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "RefSwitch"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "pred"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output_false"
-    type_attr: "T"
-    is_ref: true
-  }
-  output_arg {
-    name: "output_true"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  allows_uninitialized_input: true
-}
-op {
-  name: "RegexFullMatch"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "pattern"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_BOOL
-  }
-}
-op {
-  name: "RegexReplace"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "pattern"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "rewrite"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "replace_global"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "Relu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Relu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Relu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Relu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Relu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_QINT8
-      }
-    }
-  }
-}
-op {
-  name: "Relu6"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Relu6"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Relu6"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Relu6"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Relu6Grad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Relu6Grad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Relu6Grad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Relu6Grad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "ReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "ReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "ReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "ReluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "RemoteCall"
-  input_arg {
-    name: "target"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "RemoteCall"
-  input_arg {
-    name: "target"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  is_stateful: true
-}
-op {
-  name: "RemoteFusedGraphExecute"
-  input_arg {
-    name: "inputs"
-    type_list_attr: "Tinputs"
-  }
-  output_arg {
-    name: "outputs"
-    type_list_attr: "Toutputs"
-  }
-  attr {
-    name: "Tinputs"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Toutputs"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "serialized_remote_fused_graph_execute_info"
-    type: "string"
-  }
-}
-op {
-  name: "RepeatDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "RepeatDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "RequantizationRange"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "RequantizationRange"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "RequantizationRangePerChannel"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "clip_value_max"
-    type: "float"
-  }
-}
-op {
-  name: "Requantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT16
-        type: DT_QUINT16
-        type: DT_QINT32
-      }
-    }
-  }
-}
-op {
-  name: "Requantize"
-  input_arg {
-    name: "input"
-    type_attr: "Tinput"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "Tinput"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "RequantizePerChannel"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "input_max"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_min"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "requested_output_max"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_min"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output_max"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_QINT32
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_QUINT8
-    }
-    allowed_values {
-      list {
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_QINT16
-        type: DT_QUINT16
-      }
-    }
-  }
-}
-op {
-  name: "Reshape"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "Tshape"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tshape"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ResizeArea"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeArea"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBicubic"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBicubic"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBicubic"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBicubicGrad"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBicubicGrad"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinear"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinearGrad"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinearGrad"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeBilinearGrad"
-  input_arg {
-    name: "grads"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeNearestNeighbor"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeNearestNeighbor"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeNearestNeighbor"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "resized_images"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeNearestNeighborGrad"
-  input_arg {
-    name: "grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResizeNearestNeighborGrad"
-  input_arg {
-    name: "grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "align_corners"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "half_pixel_centers"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ResourceAccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceAccumulatorNumAccumulated"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "num_accumulated"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceAccumulatorSetGlobalStep"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "new_global_step"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceAccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "average"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdaMax"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "update_slots"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdam"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdam"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdam"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdam"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdam"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAdamWithAmsgrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "v"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "vhat"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "beta1_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2_power"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAddSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAddSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyAddSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyKerasMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyPowerSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyPowerSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyPowerSign"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "m"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "logbase"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sign_decay"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "beta"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "delta"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reduction_type"
-    type: "string"
-    default_value {
-      s: "MEAN"
-    }
-    allowed_values {
-      list {
-        s: "MEAN"
-        s: "SUM"
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceCountUpTo"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "limit"
-    type: "int"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceGather"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceGather"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "batch_dims"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceGatherNd"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterAdd"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterAdd"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterAdd"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterAdd"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterDiv"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterMax"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterMin"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterMul"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterNdAdd"
-  input_arg {
-    name: "ref"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterNdSub"
-  input_arg {
-    name: "ref"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterNdUpdate"
-  input_arg {
-    name: "ref"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterSub"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterUpdate"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterUpdate"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterUpdate"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceScatterUpdate"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum_update"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "update_slots"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mg"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "linear"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyKerasMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "accum"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceSparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "ms"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "mom"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ResourceStridedSliceAssign"
-  input_arg {
-    name: "ref"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "end"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "strides"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "begin_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "end_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ellipsis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "new_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "shrink_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Restore"
-  input_arg {
-    name: "file_pattern"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_name"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "dt"
-  }
-  attr {
-    name: "dt"
-    type: "type"
-  }
-  attr {
-    name: "preferred_shard"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "Restore"
-  input_arg {
-    name: "file_pattern"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_name"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "dt"
-  }
-  attr {
-    name: "dt"
-    type: "type"
-  }
-  attr {
-    name: "preferred_shard"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RestoreSlice"
-  input_arg {
-    name: "file_pattern"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slice"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "dt"
-  }
-  attr {
-    name: "dt"
-    type: "type"
-  }
-  attr {
-    name: "preferred_shard"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "RestoreSlice"
-  input_arg {
-    name: "file_pattern"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slice"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "dt"
-  }
-  attr {
-    name: "dt"
-    type: "type"
-  }
-  attr {
-    name: "preferred_shard"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "RestoreV2"
-  input_arg {
-    name: "prefix"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slices"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensors"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "RestoreV2"
-  input_arg {
-    name: "prefix"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slices"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "tensors"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingADAMParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "velocities"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "velocities"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingAdadeltaParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "updates"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "updates"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingAdagradParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingAdagradParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingCenteredRMSPropParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "mg"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingFTRLParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "linears"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "linears"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "weights"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "benefits"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingMomentumParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "momenta"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingProximalAdagradParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "accumulators"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingRMSPropParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "ms"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "mom"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "gradient_accumulators"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
-  output_arg {
-    name: "parameters"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    default_value {
-      i: -1
-    }
-    has_minimum: true
-    minimum: -1
-  }
-  attr {
-    name: "table_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "num_shards"
-    type: "int"
-  }
-  attr {
-    name: "shard_id"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "Reverse"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dims"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Reverse"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dims"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "Reverse"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dims"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "ReverseSequence"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seq_lengths"
-    type_attr: "Tlen"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "seq_dim"
-    type: "int"
-  }
-  attr {
-    name: "batch_dim"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tlen"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ReverseV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ReverseV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "ReverseV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "ReverseV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "ReverseV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BOOL
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "RightShift"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "RightShift"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Rint"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Rint"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Rint"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "RngSkip"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "delta"
-    type: DT_INT64
-  }
-  is_stateful: true
-}
-op {
-  name: "Roll"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shift"
-    type_attr: "Tshift"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Taxis"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tshift"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Taxis"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Round"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Round"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Round"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Rpc"
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "method"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "request"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "response"
-    type: DT_STRING
-  }
-  attr {
-    name: "protocol"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "fail_fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "timeout_in_ms"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Rsqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Rsqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Rsqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RsqrtGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RsqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RsqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "RsqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SampleDistortedBoundingBox"
-  input_arg {
-    name: "image_size"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bounding_boxes"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "begin"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "size"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "bboxes"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "min_object_covered"
-    type: "float"
-    default_value {
-      f: 0.1
-    }
-  }
-  attr {
-    name: "aspect_ratio_range"
-    type: "list(float)"
-    default_value {
-      list {
-        f: 0.75
-        f: 1.33
-      }
-    }
-  }
-  attr {
-    name: "area_range"
-    type: "list(float)"
-    default_value {
-      list {
-        f: 0.05
-        f: 1
-      }
-    }
-  }
-  attr {
-    name: "max_attempts"
-    type: "int"
-    default_value {
-      i: 100
-    }
-  }
-  attr {
-    name: "use_image_if_no_bounding_boxes"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SampleDistortedBoundingBoxV2"
-  input_arg {
-    name: "image_size"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bounding_boxes"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "min_object_covered"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "begin"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "size"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "bboxes"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "aspect_ratio_range"
-    type: "list(float)"
-    default_value {
-      list {
-        f: 0.75
-        f: 1.33
-      }
-    }
-  }
-  attr {
-    name: "area_range"
-    type: "list(float)"
-    default_value {
-      list {
-        f: 0.05
-        f: 1
-      }
-    }
-  }
-  attr {
-    name: "max_attempts"
-    type: "int"
-    default_value {
-      i: 100
-    }
-  }
-  attr {
-    name: "use_image_if_no_bounding_boxes"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SamplingDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "rate"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Save"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Save"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "SaveSlices"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shapes_and_slices"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "SaveSlices"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shapes_and_slices"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "SaveV2"
-  input_arg {
-    name: "prefix"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slices"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensors"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "SaveV2"
-  input_arg {
-    name: "prefix"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor_names"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shape_and_slices"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensors"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ScalarSummary"
-  input_arg {
-    name: "tags"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "ScalarSummary"
-  input_arg {
-    name: "tags"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "ScalarSummary"
-  input_arg {
-    name: "tags"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "ScalarSummary"
-  input_arg {
-    name: "tags"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "ScaleAndTranslate"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "translation"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "kernel_type"
-    type: "string"
-    default_value {
-      s: "lanczos3"
-    }
-  }
-}
-op {
-  name: "ScaleAndTranslate"
-  input_arg {
-    name: "images"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "translation"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "resized_images"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_UINT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "kernel_type"
-    type: "string"
-    default_value {
-      s: "lanczos3"
-    }
-  }
-  attr {
-    name: "antialias"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ScaleAndTranslateGrad"
-  input_arg {
-    name: "grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "translation"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "kernel_type"
-    type: "string"
-    default_value {
-      s: "lanczos3"
-    }
-  }
-}
-op {
-  name: "ScaleAndTranslateGrad"
-  input_arg {
-    name: "grads"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "original_image"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "scale"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "translation"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-      }
-    }
-  }
-  attr {
-    name: "kernel_type"
-    type: "string"
-    default_value {
-      s: "lanczos3"
-    }
-  }
-  attr {
-    name: "antialias"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ScanDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "initial_state"
-    type_list_attr: "Tstate"
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Tstate"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "preserve_cardinality"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterDiv"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterDiv"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterDiv"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterDiv"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMax"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMin"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMul"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMul"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMul"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterMul"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNd"
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdAdd"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdNonAliasingAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdNonAliasingAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdNonAliasingAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdNonAliasingAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdNonAliasingAdd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BOOL
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ScatterNdSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterNdUpdate"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "ScatterSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterSub"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "ScatterUpdate"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "SdcaFprint"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-}
-op {
-  name: "SdcaOptimizer"
-  input_arg {
-    name: "sparse_example_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_values"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features_with_values"
-  }
-  input_arg {
-    name: "dense_features"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_weights"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "example_labels"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_delta_sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  output_arg {
-    name: "out_delta_dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  attr {
-    name: "loss_type"
-    type: "string"
-    allowed_values {
-      list {
-        s: "logistic_loss"
-        s: "squared_loss"
-        s: "hinge_loss"
-        s: "smooth_hinge_loss"
-      }
-    }
-  }
-  attr {
-    name: "adaptative"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "num_sparse_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_sparse_features_with_values"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_dense_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "l1"
-    type: "float"
-  }
-  attr {
-    name: "l2"
-    type: "float"
-  }
-  attr {
-    name: "num_loss_partitions"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_inner_iterations"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "SdcaOptimizer"
-  input_arg {
-    name: "sparse_example_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_values"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features_with_values"
-  }
-  input_arg {
-    name: "dense_features"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_weights"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "example_labels"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_delta_sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  output_arg {
-    name: "out_delta_dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  attr {
-    name: "loss_type"
-    type: "string"
-    allowed_values {
-      list {
-        s: "logistic_loss"
-        s: "squared_loss"
-        s: "hinge_loss"
-        s: "smooth_hinge_loss"
-        s: "poisson_loss"
-      }
-    }
-  }
-  attr {
-    name: "adaptative"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "num_sparse_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_sparse_features_with_values"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_dense_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "l1"
-    type: "float"
-  }
-  attr {
-    name: "l2"
-    type: "float"
-  }
-  attr {
-    name: "num_loss_partitions"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_inner_iterations"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "SdcaOptimizerV2"
-  input_arg {
-    name: "sparse_example_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_feature_values"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features_with_values"
-  }
-  input_arg {
-    name: "dense_features"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_weights"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "example_labels"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  input_arg {
-    name: "dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  input_arg {
-    name: "example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_example_state_data"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "out_delta_sparse_weights"
-    type: DT_FLOAT
-    number_attr: "num_sparse_features"
-  }
-  output_arg {
-    name: "out_delta_dense_weights"
-    type: DT_FLOAT
-    number_attr: "num_dense_features"
-  }
-  attr {
-    name: "loss_type"
-    type: "string"
-    allowed_values {
-      list {
-        s: "logistic_loss"
-        s: "squared_loss"
-        s: "hinge_loss"
-        s: "smooth_hinge_loss"
-        s: "poisson_loss"
-      }
-    }
-  }
-  attr {
-    name: "adaptive"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "num_sparse_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_sparse_features_with_values"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_dense_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "l1"
-    type: "float"
-  }
-  attr {
-    name: "l2"
-    type: "float"
-  }
-  attr {
-    name: "num_loss_partitions"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_inner_iterations"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "SdcaShrinkL1"
-  input_arg {
-    name: "weights"
-    type: DT_FLOAT
-    number_attr: "num_features"
-    is_ref: true
-  }
-  attr {
-    name: "num_features"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "l1"
-    type: "float"
-  }
-  attr {
-    name: "l2"
-    type: "float"
-  }
-}
-op {
-  name: "SegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMin"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMin"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMin"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentMin"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Select"
-  input_arg {
-    name: "condition"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SelectV2"
-  input_arg {
-    name: "condition"
-    type: DT_BOOL
-  }
-  input_arg {
-    name: "t"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SelfAdjointEig"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-  deprecation {
-    version: 11
-  }
-}
-op {
-  name: "SelfAdjointEig"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-  deprecation {
-    version: 11
-  }
-}
-op {
-  name: "SelfAdjointEigV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_v"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-      }
-    }
-  }
-}
-op {
-  name: "SelfAdjointEigV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_v"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SelfAdjointEigV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "e"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_v"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Selu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Selu"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SeluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "outputs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SeluGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "outputs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SendTPUEmbeddingGradients"
-  input_arg {
-    name: "inputs"
-    type: DT_FLOAT
-    number_attr: "N"
-  }
-  input_arg {
-    name: "learning_rates"
-    type: DT_FLOAT
-    number_attr: "NN"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "NN"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "config"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "SerializeIterator"
-  input_arg {
-    name: "resource_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "serialized"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "SerializeManySparse"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "serialized_sparse"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SerializeManySparse"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "serialized_sparse"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_STRING
-    }
-    allowed_values {
-      list {
-        type: DT_STRING
-        type: DT_VARIANT
-      }
-    }
-  }
-}
-op {
-  name: "SerializeSparse"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "serialized_sparse"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SerializeSparse"
-  input_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "serialized_sparse"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_STRING
-    }
-    allowed_values {
-      list {
-        type: DT_STRING
-        type: DT_VARIANT
-      }
-    }
-  }
-}
-op {
-  name: "SerializeTensor"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "serialized"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SetSize"
-  input_arg {
-    name: "set_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "set_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "SetStatsAggregatorDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "stats_aggregator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "counter_prefix"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "Shape"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ShapeN"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-    number_attr: "N"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "ShardDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_shards"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ShardDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "num_shards"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "require_non_empty"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ShardedFilename"
-  input_arg {
-    name: "basename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "shard"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_shards"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ShardedFilespec"
-  input_arg {
-    name: "basename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "num_shards"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-}
-op {
-  name: "ShuffleAndRepeatDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ShuffleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ShuffleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ShuffleDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "seed2"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "reshuffle_each_iteration"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "ShutdownDistributedTPU"
-  is_stateful: true
-}
-op {
-  name: "Sigmoid"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sigmoid"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sigmoid"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SigmoidGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SigmoidGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SigmoidGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SigmoidGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sign"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sign"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sign"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sin"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sinh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Size"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SkipDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "SkipDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Skipgram"
-  output_arg {
-    name: "vocab_word"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "vocab_freq"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "words_per_epoch"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "current_epoch"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "total_words_processed"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "examples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "labels"
-    type: DT_INT32
-  }
-  attr {
-    name: "filename"
-    type: "string"
-  }
-  attr {
-    name: "batch_size"
-    type: "int"
-  }
-  attr {
-    name: "window_size"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "min_count"
-    type: "int"
-    default_value {
-      i: 5
-    }
-  }
-  attr {
-    name: "subsample"
-    type: "float"
-    default_value {
-      f: 0.001
-    }
-  }
-  deprecation {
-    version: 19
-  }
-  is_stateful: true
-}
-op {
-  name: "SleepDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "sleep_microseconds"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Slice"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "size"
-    type_attr: "Index"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SlidingWindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "window_size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "window_shift"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "window_stride"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Snapshot"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SnapshotDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "compression"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reader_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "writer_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "SnapshotDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "compression"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reader_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "writer_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shard_size_bytes"
-    type: "int"
-    default_value {
-      i: 10737418240
-    }
-  }
-  attr {
-    name: "pending_snapshot_expiry_seconds"
-    type: "int"
-    default_value {
-      i: 86400
-    }
-  }
-}
-op {
-  name: "SnapshotDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "compression"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reader_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "writer_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shard_size_bytes"
-    type: "int"
-    default_value {
-      i: 10737418240
-    }
-  }
-  attr {
-    name: "pending_snapshot_expiry_seconds"
-    type: "int"
-    default_value {
-      i: 86400
-    }
-  }
-  attr {
-    name: "num_reader_threads"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "reader_buffer_size"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-}
-op {
-  name: "SnapshotDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "compression"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reader_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "writer_path_prefix"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shard_size_bytes"
-    type: "int"
-    default_value {
-      i: 10737418240
-    }
-  }
-  attr {
-    name: "pending_snapshot_expiry_seconds"
-    type: "int"
-    default_value {
-      i: 86400
-    }
-  }
-  attr {
-    name: "num_reader_threads"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "reader_buffer_size"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "num_writer_threads"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "writer_buffer_size"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-}
-op {
-  name: "Softmax"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "softmax"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Softmax"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "softmax"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SoftmaxCrossEntropyWithLogits"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "labels"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "loss"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SoftmaxCrossEntropyWithLogits"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "labels"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "loss"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Softplus"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Softplus"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Softplus"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Softplus"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Softplus"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SoftplusGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SoftplusGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SoftplusGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SoftplusGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SoftplusGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "Softsign"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "Softsign"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Softsign"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "Softsign"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Softsign"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "activations"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SoftsignGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SoftsignGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SoftsignGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SoftsignGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SoftsignGrad"
-  input_arg {
-    name: "gradients"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprops"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SpaceToBatch"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-}
-op {
-  name: "SpaceToBatchND"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "block_shape"
-    type_attr: "Tblock_shape"
-  }
-  input_arg {
-    name: "paddings"
-    type_attr: "Tpaddings"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tblock_shape"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tpaddings"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SpaceToDepth"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-}
-op {
-  name: "SpaceToDepth"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "block_size"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "data_format"
-    type: "string"
-    default_value {
-      s: "NHWC"
-    }
-    allowed_values {
-      list {
-        s: "NHWC"
-        s: "NCHW"
-        s: "NCHW_VECT_C"
-      }
-    }
-  }
-}
-op {
-  name: "SparseAccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_values"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "gradient_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "has_known_shape"
-    type: "bool"
-  }
-}
-op {
-  name: "SparseAccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_values"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "gradient_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "has_known_shape"
-    type: "bool"
-  }
-}
-op {
-  name: "SparseAccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_values"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "gradient_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "has_known_shape"
-    type: "bool"
-  }
-}
-op {
-  name: "SparseAccumulatorApplyGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "local_step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "gradient_values"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "gradient_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "has_known_shape"
-    type: "bool"
-  }
-}
-op {
-  name: "SparseAccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseAccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseAccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseAccumulatorTakeGradient"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "num_required"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseAdd"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "thresh"
-    type_attr: "Treal"
-  }
-  output_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sum_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sum_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Treal"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseAdd"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "thresh"
-    type_attr: "Treal"
-  }
-  output_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sum_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sum_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Treal"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseAdd"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "thresh"
-    type_attr: "Treal"
-  }
-  output_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sum_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sum_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Treal"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseAdd"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "thresh"
-    type_attr: "Treal"
-  }
-  output_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sum_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "sum_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Treal"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseAddGrad"
-  input_arg {
-    name: "backprop_val_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "a_val_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "b_val_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseAddGrad"
-  input_arg {
-    name: "backprop_val_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "a_val_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "b_val_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseAddGrad"
-  input_arg {
-    name: "backprop_val_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "a_val_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "b_val_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseAddGrad"
-  input_arg {
-    name: "backprop_val_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sum_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "a_val_grad"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "b_val_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdadelta"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum_update"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "update_slots"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyAdagradDA"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "gradient_squared_accumulator"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "global_step"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyCenteredRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mg"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrl"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyFtrlV2"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "linear"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2_shrinkage"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lr_power"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyMomentum"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "use_nesterov"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalAdagrad"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "accum"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyProximalGradientDescent"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "alpha"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l1"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "l2"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseApplyRMSProp"
-  input_arg {
-    name: "var"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "ms"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "mom"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "lr"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rho"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "momentum"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "epsilon"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  output_arg {
-    name: "out"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "use_locking"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseConcat"
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "concat_dim"
-    type: "int"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 2
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseConditionalAccumulator"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "reduction_type"
-    type: "string"
-    default_value {
-      s: "MEAN"
-    }
-    allowed_values {
-      list {
-        s: "MEAN"
-        s: "SUM"
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseCross"
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "values"
-    type_list_attr: "sparse_types"
-  }
-  input_arg {
-    name: "shapes"
-    type: DT_INT64
-    number_attr: "N"
-  }
-  input_arg {
-    name: "dense_inputs"
-    type_list_attr: "dense_types"
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "out_type"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "hashed_output"
-    type: "bool"
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "hash_key"
-    type: "int"
-  }
-  attr {
-    name: "sparse_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "dense_types"
-    type: "list(type)"
-    has_minimum: true
-    allowed_values {
-      list {
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-  attr {
-    name: "internal_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT64
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseAdd"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseAdd"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseAdd"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseAdd"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseDiv"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseDiv"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseDiv"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseDiv"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseMul"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseMul"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseMul"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseDenseCwiseMul"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseFillEmptyRows"
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dense_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "default_value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "empty_row_indicator"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "reverse_index_map"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseFillEmptyRowsGrad"
-  input_arg {
-    name: "reverse_index_map"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "grad_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "d_default_value"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseMatMul"
-  input_arg {
-    name: "a"
-    type_attr: "Ta"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "Tb"
-  }
-  output_arg {
-    name: "product"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "transpose_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "transpose_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "a_is_sparse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "b_is_sparse"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Ta"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tb"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMax"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMax"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMax"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMax"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMaxSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMaxSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMaxSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceMaxSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSum"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSum"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSum"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSum"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSumSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSumSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSumSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseReduceSumSparse"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "reduction_axes"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseReorder"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseReshape"
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "new_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-}
-op {
-  name: "SparseSegmentMean"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentMeanGrad"
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "output_dim0"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentMeanWithNumSegments"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSqrtN"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSqrtNGrad"
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "output_dim0"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSqrtNWithNumSegments"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSumWithNumSegments"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSegmentSumWithNumSegments"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "segment_ids"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSlice"
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "start"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseSliceGrad"
-  input_arg {
-    name: "backprop_val_grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "input_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "input_start"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "val_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSoftmax"
-  input_arg {
-    name: "sp_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "sp_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "sp_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "SparseSoftmaxCrossEntropyWithLogits"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "labels"
-    type_attr: "Tlabels"
-  }
-  output_arg {
-    name: "loss"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tlabels"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSoftmaxCrossEntropyWithLogits"
-  input_arg {
-    name: "features"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "labels"
-    type_attr: "Tlabels"
-  }
-  output_arg {
-    name: "loss"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "backprop"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "Tlabels"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMaximum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMaximum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMaximum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMaximum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMinimum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMinimum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMinimum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "SparseSparseMinimum"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "b_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseSplit"
-  input_arg {
-    name: "split_dim"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_indices"
-    type: DT_INT64
-    number_attr: "num_split"
-  }
-  output_arg {
-    name: "output_values"
-    type_attr: "T"
-    number_attr: "num_split"
-  }
-  output_arg {
-    name: "output_shape"
-    type: DT_INT64
-    number_attr: "num_split"
-  }
-  attr {
-    name: "num_split"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SparseTensorDenseAdd"
-  input_arg {
-    name: "a_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseTensorDenseAdd"
-  input_arg {
-    name: "a_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseTensorDenseAdd"
-  input_arg {
-    name: "a_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseTensorDenseAdd"
-  input_arg {
-    name: "a_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseTensorDenseMatMul"
-  input_arg {
-    name: "a_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "adjoint_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adjoint_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseTensorDenseMatMul"
-  input_arg {
-    name: "a_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "a_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "a_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "b"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "product"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "adjoint_a"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "adjoint_b"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "SparseTensorSliceDataset"
-  input_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "values"
-    type_attr: "Tvalues"
-  }
-  input_arg {
-    name: "dense_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Tvalues"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "SparseToDense"
-  input_arg {
-    name: "sparse_indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "output_shape"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "sparse_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "default_value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "dense"
-    type_attr: "T"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SparseToSparseSetOperation"
-  input_arg {
-    name: "set1_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "set1_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set1_shape"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "set2_indices"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "set2_values"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "set2_shape"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "result_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "result_values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "result_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "set_operation"
-    type: "string"
-  }
-  attr {
-    name: "validate_indices"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT8
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_STRING
-      }
-    }
-  }
-}
-op {
-  name: "Split"
-  input_arg {
-    name: "split_dim"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    number_attr: "num_split"
-  }
-  attr {
-    name: "num_split"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SplitV"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "size_splits"
-    type_attr: "Tlen"
-  }
-  input_arg {
-    name: "split_dim"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    number_attr: "num_split"
-  }
-  attr {
-    name: "num_split"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tlen"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SqlDataset"
-  input_arg {
-    name: "driver_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "data_source_name"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "query"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "Sqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sqrt"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SqrtGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SqrtGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Square"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Square"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Square"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "SquaredDifference"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "SquaredDifference"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "SquaredDifference"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-  is_commutative: true
-}
-op {
-  name: "Squeeze"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "squeeze_dims"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-    has_minimum: true
-  }
-}
-op {
-  name: "Stack"
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "elem_type"
-    type: "type"
-  }
-  attr {
-    name: "stack_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StackClose"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-}
-op {
-  name: "StackCloseV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "StackPop"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  output_arg {
-    name: "elem"
-    type_attr: "elem_type"
-  }
-  attr {
-    name: "elem_type"
-    type: "type"
-  }
-}
-op {
-  name: "StackPopV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "elem"
-    type_attr: "elem_type"
-  }
-  attr {
-    name: "elem_type"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "StackPush"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "elem"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "swap_memory"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "StackPushV2"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "elem"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "swap_memory"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StackV2"
-  input_arg {
-    name: "max_size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "elem_type"
-    type: "type"
-  }
-  attr {
-    name: "stack_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Stage"
-  input_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Stage"
-  input_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StageClear"
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StagePeek"
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StageSize"
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulPartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulPartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulPartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "executor_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulPartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "config"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "config_proto"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "executor_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulRandomBinomial"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "S"
-  }
-  input_arg {
-    name: "counts"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "probs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "S"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_DOUBLE
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulStandardNormal"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulStandardNormal"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulStandardNormal"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  deprecation {
-    version: 29
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulStandardNormalV2"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulTruncatedNormal"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulUniform"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulUniformFullInt"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_UINT64
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatefulUniformInt"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "algorithm"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shape"
-    type_attr: "shape_dtype"
-  }
-  input_arg {
-    name: "minval"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "maxval"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  attr {
-    name: "shape_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatelessIf"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-}
-op {
-  name: "StatelessIf"
-  input_arg {
-    name: "cond"
-    type_attr: "Tcond"
-  }
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tcond"
-    type: "type"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "then_branch"
-    type: "func"
-  }
-  attr {
-    name: "else_branch"
-    type: "func"
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "StatelessMultinomial"
-  input_arg {
-    name: "logits"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "num_samples"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "output_dtype"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "output_dtype"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomUniform"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomUniform"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomUniform"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessRandomUniformInt"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  input_arg {
-    name: "minval"
-    type_attr: "dtype"
-  }
-  input_arg {
-    name: "maxval"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessTruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessTruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessTruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "seed"
-    type_attr: "Tseed"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tseed"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StatelessWhile"
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "cond"
-    type: "func"
-  }
-  attr {
-    name: "body"
-    type: "func"
-  }
-}
-op {
-  name: "StaticRegexFullMatch"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_BOOL
-  }
-  attr {
-    name: "pattern"
-    type: "string"
-  }
-}
-op {
-  name: "StaticRegexReplace"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "pattern"
-    type: "string"
-  }
-  attr {
-    name: "rewrite"
-    type: "string"
-  }
-  attr {
-    name: "replace_global"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "StatsAggregatorHandle"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatsAggregatorHandleV2"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "StatsAggregatorSetSummaryWriter"
-  input_arg {
-    name: "stats_aggregator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "summary"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "StatsAggregatorSummary"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "StopGradient"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "StridedSlice"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "end"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "strides"
-    type_attr: "Index"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "begin_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "end_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ellipsis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "new_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "shrink_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "StridedSliceAssign"
-  input_arg {
-    name: "ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "end"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "strides"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_ref"
-    type_attr: "T"
-    is_ref: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "begin_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "end_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ellipsis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "new_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "shrink_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "StridedSliceGrad"
-  input_arg {
-    name: "shape"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "end"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "strides"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "begin_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "end_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ellipsis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "new_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "shrink_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "StringFormat"
-  input_arg {
-    name: "inputs"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "template"
-    type: "string"
-    default_value {
-      s: "%s"
-    }
-  }
-  attr {
-    name: "placeholder"
-    type: "string"
-    default_value {
-      s: "%s"
-    }
-  }
-  attr {
-    name: "summarize"
-    type: "int"
-    default_value {
-      i: 3
-    }
-  }
-}
-op {
-  name: "StringJoin"
-  input_arg {
-    name: "inputs"
-    type: DT_STRING
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "separator"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "StringLength"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT32
-  }
-}
-op {
-  name: "StringLength"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT32
-  }
-  attr {
-    name: "unit"
-    type: "string"
-    default_value {
-      s: "BYTE"
-    }
-    allowed_values {
-      list {
-        s: "BYTE"
-        s: "UTF8_CHAR"
-      }
-    }
-  }
-}
-op {
-  name: "StringLower"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "encoding"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "StringSplit"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "delimiter"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-}
-op {
-  name: "StringSplit"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "delimiter"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "skip_empty"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-}
-op {
-  name: "StringSplitV2"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "sep"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "values"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "maxsplit"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "StringStrip"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-}
-op {
-  name: "StringToHashBucket"
-  input_arg {
-    name: "string_tensor"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "StringToHashBucketFast"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "StringToHashBucketStrong"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT64
-  }
-  attr {
-    name: "num_buckets"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "key"
-    type: "list(int)"
-  }
-}
-op {
-  name: "StringToNumber"
-  input_arg {
-    name: "string_tensor"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_INT32
-      }
-    }
-  }
-}
-op {
-  name: "StringToNumber"
-  input_arg {
-    name: "string_tensor"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "StringUpper"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "encoding"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "Sub"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sub"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sub"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Sub"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Substr"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "pos"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "len"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Substr"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "pos"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "len"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "unit"
-    type: "string"
-    default_value {
-      s: "BYTE"
-    }
-    allowed_values {
-      list {
-        s: "BYTE"
-        s: "UTF8_CHAR"
-      }
-    }
-  }
-}
-op {
-  name: "Sum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Sum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Sum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Sum"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "reduction_indices"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "keep_dims"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "SummaryWriter"
-  output_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Svd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "s"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "u"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_uv"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "full_matrices"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Svd"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "s"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "u"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "v"
-    type_attr: "T"
-  }
-  attr {
-    name: "compute_uv"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "full_matrices"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_HALF
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Switch"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "pred"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "output_false"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output_true"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "SymbolicGradient"
-  input_arg {
-    name: "input"
-    type_list_attr: "Tin"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "TFRecordDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "TFRecordReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "compression_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TFRecordReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "compression_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "TFRecordReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "compression_type"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TPUCompilationResult"
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-}
-op {
-  name: "TPUEmbeddingActivations"
-  input_arg {
-    name: "embedding_variable"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sliced_activations"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "output"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "table_id"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "lookup_id"
-    type: "int"
-    has_minimum: true
-  }
-}
-op {
-  name: "TPUOrdinalSelector"
-  output_arg {
-    name: "device_ordinals"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "TPUPartitionedCall"
-  input_arg {
-    name: "args"
-    type_list_attr: "Tin"
-  }
-  input_arg {
-    name: "device_ordinal"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "Tout"
-  }
-  attr {
-    name: "Tin"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "Tout"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-}
-op {
-  name: "TPUReplicateMetadata"
-  attr {
-    name: "num_replicas"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_cores_per_replica"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "topology"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_tpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "device_assignment"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "computation_shape"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "host_compute_core"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "padding_map"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-}
-op {
-  name: "TPUReplicateMetadata"
-  attr {
-    name: "num_replicas"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_cores_per_replica"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "topology"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_tpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "device_assignment"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "computation_shape"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "host_compute_core"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "padding_map"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "step_marker_location"
-    type: "string"
-    default_value {
-      s: "STEP_MARK_AT_ENTRY"
-    }
-  }
-}
-op {
-  name: "TPUReplicateMetadata"
-  attr {
-    name: "num_replicas"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "num_cores_per_replica"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "topology"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "use_tpu"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "device_assignment"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "computation_shape"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "host_compute_core"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "padding_map"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "step_marker_location"
-    type: "string"
-    default_value {
-      s: "STEP_MARK_AT_ENTRY"
-    }
-  }
-  attr {
-    name: "allow_soft_placement"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "TPUReplicatedInput"
-  input_arg {
-    name: "inputs"
-    type_attr: "T"
-    number_attr: "N"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TPUReplicatedOutput"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "outputs"
-    type_attr: "T"
-    number_attr: "num_replicas"
-  }
-  attr {
-    name: "num_replicas"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TakeDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "TakeDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "count"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "TakeManySparseFromTensorsMap"
-  input_arg {
-    name: "sparse_handles"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_indices"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sparse_values"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "sparse_shape"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TakeWhileDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "predicate"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "Tan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Tan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Tan"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Tanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Tanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Tanh"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TanhGrad"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TanhGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TanhGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TanhGrad"
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "dy"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TemporaryVariable"
-  output_arg {
-    name: "ref"
-    type_attr: "dtype"
-    is_ref: true
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "var_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArray"
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "dynamic_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clear_after_read"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "tensor_array_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  deprecation {
-    version: 16
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayClose"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayCloseV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-}
-op {
-  name: "TensorArrayCloseV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArrayCloseV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayConcat"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape_except0"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayConcatV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape_except0"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-}
-op {
-  name: "TensorArrayConcatV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape_except0"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGather"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayGatherV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-}
-op {
-  name: "TensorArrayGatherV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArrayGatherV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGrad"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "grad_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "source"
-    type: "string"
-  }
-  deprecation {
-    version: 16
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGradV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "grad_handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "source"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGradV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "grad_handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "source"
-    type: "string"
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGradV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "grad_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "source"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayGradWithShape"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "shape_to_prepend"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "grad_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "source"
-    type: "string"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayPack"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayRead"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayReadV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorArrayReadV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArrayReadV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "value"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayScatter"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 19
-  }
-}
-op {
-  name: "TensorArrayScatterV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TensorArrayScatterV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArrayScatterV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArraySize"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArraySizeV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-}
-op {
-  name: "TensorArraySizeV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArraySizeV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArraySplit"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArraySplitV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TensorArraySplitV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArraySplitV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayUnpack"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 20
-  }
-}
-op {
-  name: "TensorArrayV2"
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  attr {
-    name: "dynamic_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clear_after_read"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "tensor_array_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayV2"
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  attr {
-    name: "dynamic_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clear_after_read"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "tensor_array_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayV3"
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "flow"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  attr {
-    name: "dynamic_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clear_after_read"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "tensor_array_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayV3"
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "flow"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-  attr {
-    name: "dynamic_size"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "clear_after_read"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "identical_element_shapes"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "tensor_array_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorArrayWrite"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 16
-  }
-}
-op {
-  name: "TensorArrayWriteV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TensorArrayWriteV2"
-  input_arg {
-    name: "handle"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 26
-  }
-}
-op {
-  name: "TensorArrayWriteV3"
-  input_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "flow_in"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "flow_out"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorDataset"
-  input_arg {
-    name: "components"
-    type_list_attr: "Toutput_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestCreateTreeVariable"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "tree_config"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreeDeserialize"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "tree_config"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreeIsInitializedOp"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "is_initialized"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreePredict"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "dense_features"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "logits"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "logits_dimension"
-    type: "int"
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreeResourceHandleOp"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreeSerialize"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "tree_config"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorForestTreeSize"
-  input_arg {
-    name: "tree_handle"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "tree_size"
-    type: DT_INT32
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorListConcat"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListConcat"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "element_shape"
-    type: "shape"
-    default_value {
-      shape {
-        unknown_rank: true
-      }
-    }
-  }
-}
-op {
-  name: "TensorListConcatLists"
-  input_arg {
-    name: "input_a"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "input_b"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "output"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListConcatV2"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  input_arg {
-    name: "leading_dims"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListElementShape"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListFromTensor"
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListGather"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "element_shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "element_dtype"
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListGetItem"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "element_shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "item"
-    type_attr: "element_dtype"
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListLength"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "length"
-    type: DT_INT32
-  }
-}
-op {
-  name: "TensorListPopBack"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "element_shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListPushBack"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListPushBackBatch"
-  input_arg {
-    name: "input_handles"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "output_handles"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListReserve"
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  input_arg {
-    name: "num_elements"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListResize"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorListScatter"
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListScatterIntoExistingList"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListScatterV2"
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  input_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  input_arg {
-    name: "num_elements"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListSetItem"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "index"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "item"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}
-op {
-  name: "TensorListSplit"
-  input_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  input_arg {
-    name: "element_shape"
-    type_attr: "shape_type"
-  }
-  input_arg {
-    name: "lengths"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape_type"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorListStack"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "element_shape"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "tensor"
-    type_attr: "element_dtype"
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-  attr {
-    name: "num_elements"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-}
-op {
-  name: "TensorScatterAdd"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorScatterSub"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorScatterUpdate"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "indices"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "updates"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TensorSliceDataset"
-  input_arg {
-    name: "components"
-    type_list_attr: "Toutput_types"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "Toutput_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "TensorStridedSliceUpdate"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "begin"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "end"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "strides"
-    type_attr: "Index"
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Index"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "begin_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "end_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "ellipsis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "new_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "shrink_axis_mask"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "TensorSummary"
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "description"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "labels"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "display_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-}
-op {
-  name: "TensorSummaryV2"
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "serialized_summary_metadata"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "summary"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "TextLineDataset"
-  input_arg {
-    name: "filenames"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "compression_type"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "buffer_size"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  is_stateful: true
-}
-op {
-  name: "TextLineReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "skip_header_lines"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TextLineReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "skip_header_lines"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  deprecation {
-    version: 26
-  }
-  is_stateful: true
-}
-op {
-  name: "TextLineReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "skip_header_lines"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ThreadPoolDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "thread_pool"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ThreadPoolHandle"
-  output_arg {
-    name: "handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "num_threads"
-    type: "int"
-  }
-  attr {
-    name: "max_intra_op_parallelism"
-    type: "int"
-    default_value {
-      i: 1
-    }
-  }
-  attr {
-    name: "display_name"
-    type: "string"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "ThreadUnsafeUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "ThreadUnsafeUnigramCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Tile"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "multiples"
-    type_attr: "Tmultiples"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tmultiples"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TileGrad"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "multiples"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  deprecation {
-    version: 3
-  }
-}
-op {
-  name: "Timestamp"
-  output_arg {
-    name: "ts"
-    type: DT_DOUBLE
-  }
-  is_stateful: true
-}
-op {
-  name: "TopK"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "k"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  deprecation {
-    version: 7
-  }
-}
-op {
-  name: "TopK"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "k"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  deprecation {
-    version: 7
-  }
-}
-op {
-  name: "TopK"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "k"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  deprecation {
-    version: 7
-  }
-}
-op {
-  name: "TopK"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "k"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  deprecation {
-    version: 7
-  }
-}
-op {
-  name: "TopKV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-}
-op {
-  name: "TopKV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "TopKV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-}
-op {
-  name: "TopKV2"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "k"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "indices"
-    type: DT_INT32
-  }
-  attr {
-    name: "sorted"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-}
-op {
-  name: "Transpose"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "perm"
-    type_attr: "Tperm"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Tperm"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "TridiagonalMatMul"
-  input_arg {
-    name: "superdiag"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "maindiag"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "subdiag"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TridiagonalSolve"
-  input_arg {
-    name: "diagonals"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TridiagonalSolve"
-  input_arg {
-    name: "diagonals"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "rhs"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "partial_pivoting"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_DOUBLE
-        type: DT_FLOAT
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TruncateDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TruncateDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TruncateDiv"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_UINT8
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "TruncateMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "TruncateMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "TruncateMod"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "TruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TruncatedNormal"
-  input_arg {
-    name: "shape"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "dtype"
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_BFLOAT16
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "TryRpc"
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "method"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "request"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "response"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "status_code"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "status_message"
-    type: DT_STRING
-  }
-  attr {
-    name: "protocol"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "fail_fast"
-    type: "bool"
-    default_value {
-      b: true
-    }
-  }
-  attr {
-    name: "timeout_in_ms"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Unbatch"
-  input_arg {
-    name: "batched_tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "batch_index"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "id"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "unbatched_tensor"
-    type_attr: "T"
-  }
-  attr {
-    name: "timeout_micros"
-    type: "int"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "UnbatchDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "UnbatchGrad"
-  input_arg {
-    name: "original_input"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "batch_index"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "grad"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "id"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "batched_grad"
-    type_attr: "T"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "UnicodeDecode"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "row_splits"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "char_values"
-    type: DT_INT32
-  }
-  attr {
-    name: "input_encoding"
-    type: "string"
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "strict"
-        s: "replace"
-        s: "ignore"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "replace_control_characters"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "UnicodeDecode"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "row_splits"
-    type_attr: "Tsplits"
-  }
-  output_arg {
-    name: "char_values"
-    type: DT_INT32
-  }
-  attr {
-    name: "input_encoding"
-    type: "string"
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "strict"
-        s: "replace"
-        s: "ignore"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "replace_control_characters"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnicodeDecodeWithOffsets"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "row_splits"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "char_values"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "char_to_byte_starts"
-    type: DT_INT64
-  }
-  attr {
-    name: "input_encoding"
-    type: "string"
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "strict"
-        s: "replace"
-        s: "ignore"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "replace_control_characters"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "UnicodeDecodeWithOffsets"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "row_splits"
-    type_attr: "Tsplits"
-  }
-  output_arg {
-    name: "char_values"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "char_to_byte_starts"
-    type: DT_INT64
-  }
-  attr {
-    name: "input_encoding"
-    type: "string"
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "strict"
-        s: "replace"
-        s: "ignore"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "replace_control_characters"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnicodeEncode"
-  input_arg {
-    name: "input_values"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_splits"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "ignore"
-        s: "replace"
-        s: "strict"
-      }
-    }
-  }
-  attr {
-    name: "output_encoding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "UTF-8"
-        s: "UTF-16-BE"
-        s: "UTF-32-BE"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-}
-op {
-  name: "UnicodeEncode"
-  input_arg {
-    name: "input_values"
-    type: DT_INT32
-  }
-  input_arg {
-    name: "input_splits"
-    type_attr: "Tsplits"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "ignore"
-        s: "replace"
-        s: "strict"
-      }
-    }
-  }
-  attr {
-    name: "output_encoding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "UTF-8"
-        s: "UTF-16-BE"
-        s: "UTF-32-BE"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "Tsplits"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnicodeScript"
-  input_arg {
-    name: "input"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type: DT_INT32
-  }
-}
-op {
-  name: "UnicodeTranscode"
-  input_arg {
-    name: "input"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "input_encoding"
-    type: "string"
-  }
-  attr {
-    name: "output_encoding"
-    type: "string"
-    allowed_values {
-      list {
-        s: "UTF-8"
-        s: "UTF-16-BE"
-        s: "UTF-32-BE"
-      }
-    }
-  }
-  attr {
-    name: "errors"
-    type: "string"
-    default_value {
-      s: "replace"
-    }
-    allowed_values {
-      list {
-        s: "strict"
-        s: "replace"
-        s: "ignore"
-      }
-    }
-  }
-  attr {
-    name: "replacement_char"
-    type: "int"
-    default_value {
-      i: 65533
-    }
-  }
-  attr {
-    name: "replace_control_characters"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
-}
-op {
-  name: "UniformCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "UniformCandidateSampler"
-  input_arg {
-    name: "true_classes"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "sampled_candidates"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "true_expected_count"
-    type: DT_FLOAT
-  }
-  output_arg {
-    name: "sampled_expected_count"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "num_true"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "num_sampled"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "unique"
-    type: "bool"
-  }
-  attr {
-    name: "range_max"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "seed"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  attr {
-    name: "seed2"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Unique"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UniqueDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "UniqueV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UniqueV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Taxis"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Taxis"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UniqueWithCounts"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  output_arg {
-    name: "count"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UniqueWithCountsV2"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "axis"
-    type_attr: "Taxis"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "idx"
-    type_attr: "out_idx"
-  }
-  output_arg {
-    name: "count"
-    type_attr: "out_idx"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "Taxis"
-    type: "type"
-    default_value {
-      type: DT_INT64
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "out_idx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Unpack"
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-    number_attr: "num"
-  }
-  attr {
-    name: "num"
-    type: "int"
-    has_minimum: true
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "axis"
-    type: "int"
-    default_value {
-      i: 0
-    }
-  }
-}
-op {
-  name: "UnravelIndex"
-  input_arg {
-    name: "indices"
-    type_attr: "Tidx"
-  }
-  input_arg {
-    name: "dims"
-    type_attr: "Tidx"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "Tidx"
-  }
-  attr {
-    name: "Tidx"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentJoin"
-  input_arg {
-    name: "inputs"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type: DT_STRING
-  }
-  attr {
-    name: "separator"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentMax"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentMin"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentProd"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type: DT_INT32
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "UnsortedSegmentSum"
-  input_arg {
-    name: "data"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "segment_ids"
-    type_attr: "Tindices"
-  }
-  input_arg {
-    name: "num_segments"
-    type_attr: "Tnumsegments"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  attr {
-    name: "Tindices"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "Tnumsegments"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "Unstage"
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Unstage"
-  output_arg {
-    name: "values"
-    type_list_attr: "dtypes"
-  }
-  attr {
-    name: "capacity"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "memory_limit"
-    type: "int"
-    default_value {
-      i: 0
-    }
-    has_minimum: true
-  }
-  attr {
-    name: "dtypes"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "UnwrapDatasetVariant"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-}
-op {
-  name: "UpperBound"
-  input_arg {
-    name: "sorted_inputs"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-}
-op {
-  name: "VarHandleOp"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  is_stateful: true
-}
-op {
-  name: "VarIsInitializedOp"
-  input_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "is_initialized"
-    type: DT_BOOL
-  }
-  is_stateful: true
-}
-op {
-  name: "Variable"
-  output_arg {
-    name: "ref"
-    type_attr: "dtype"
-    is_ref: true
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "VariableShape"
-  input_arg {
-    name: "input"
-    type: DT_RESOURCE
-  }
-  output_arg {
-    name: "output"
-    type_attr: "out_type"
-  }
-  attr {
-    name: "out_type"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "VariableV2"
-  output_arg {
-    name: "ref"
-    type_attr: "dtype"
-    is_ref: true
-  }
-  attr {
-    name: "shape"
-    type: "shape"
-  }
-  attr {
-    name: "dtype"
-    type: "type"
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "Where"
-  input_arg {
-    name: "input"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "index"
-    type: DT_INT64
-  }
-}
-op {
-  name: "Where"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_BOOL
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BOOL
-      }
-    }
-  }
-}
-op {
-  name: "Where"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_BOOL
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT64
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_UINT16
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BFLOAT16
-        type: DT_BOOL
-      }
-    }
-  }
-}
-op {
-  name: "Where"
-  input_arg {
-    name: "input"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "index"
-    type: DT_INT64
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_BOOL
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_COMPLEX128
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-        type: DT_BOOL
-      }
-    }
-  }
-}
-op {
-  name: "While"
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "cond"
-    type: "func"
-  }
-  attr {
-    name: "body"
-    type: "func"
-  }
-  is_stateful: true
-}
-op {
-  name: "While"
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "cond"
-    type: "func"
-  }
-  attr {
-    name: "body"
-    type: "func"
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "While"
-  input_arg {
-    name: "input"
-    type_list_attr: "T"
-  }
-  output_arg {
-    name: "output"
-    type_list_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "cond"
-    type: "func"
-  }
-  attr {
-    name: "body"
-    type: "func"
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "parallel_iterations"
-    type: "int"
-    default_value {
-      i: 10
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WholeFileReader"
-  output_arg {
-    name: "reader_handle"
-    type: DT_STRING
-    is_ref: true
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WholeFileReaderV2"
-  output_arg {
-    name: "reader_handle"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WindowDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "size"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "shift"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "stride"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "drop_remainder"
-    type: DT_BOOL
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
-op {
-  name: "WorkerHeartbeat"
-  input_arg {
-    name: "request"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "response"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "WrapDatasetVariant"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  output_arg {
-    name: "output_handle"
-    type: DT_VARIANT
-  }
-}
-op {
-  name: "WriteAudioSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type: DT_FLOAT
-  }
-  input_arg {
-    name: "sample_rate"
-    type: DT_FLOAT
-  }
-  attr {
-    name: "max_outputs"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteFile"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-}
-op {
-  name: "WriteFile"
-  input_arg {
-    name: "filename"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "contents"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteGraphSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tensor"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteHistogramSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "values"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteImageSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "bad_color"
-    type: DT_UINT8
-  }
-  attr {
-    name: "max_images"
-    type: "int"
-    default_value {
-      i: 3
-    }
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "T"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_UINT8
-        type: DT_FLOAT
-        type: DT_HALF
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteRawProtoSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tensor"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteScalarSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "value"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_UINT8
-        type: DT_INT16
-        type: DT_INT8
-        type: DT_INT64
-        type: DT_BFLOAT16
-        type: DT_UINT16
-        type: DT_HALF
-        type: DT_UINT32
-        type: DT_UINT64
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
-  name: "WriteSummary"
-  input_arg {
-    name: "writer"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "step"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "tensor"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "tag"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "summary_metadata"
-    type: DT_STRING
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-  is_stateful: true
-}
-op {
-  name: "Xdivy"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "Xlogy"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_HALF
-        type: DT_FLOAT
-        type: DT_DOUBLE
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
-      }
-    }
-  }
-}
-op {
-  name: "ZerosLike"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "y"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-  }
-}
-op {
-  name: "Zeta"
-  input_arg {
-    name: "x"
-    type_attr: "T"
-  }
-  input_arg {
-    name: "q"
-    type_attr: "T"
-  }
-  output_arg {
-    name: "z"
-    type_attr: "T"
-  }
-  attr {
-    name: "T"
-    type: "type"
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-}
-op {
-  name: "ZipDataset"
-  input_arg {
-    name: "input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ZipDataset"
-  input_arg {
-    name: "input_datasets"
-    type: DT_VARIANT
-    number_attr: "N"
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Abort.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Abort.pbtxt
new file mode 100644
index 0000000..4752385
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Abort.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "Abort"
+  attr {
+    name: "error_msg"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "exit_without_error"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Abs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Abs.pbtxt
new file mode 100644
index 0000000..80e7c7f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Abs.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Abs"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Abs"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Abs"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AccumulateNV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AccumulateNV2.pbtxt
new file mode 100644
index 0000000..5c9523d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AccumulateNV2.pbtxt
@@ -0,0 +1,146 @@
+op {
+  name: "AccumulateNV2"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AccumulateNV2"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AccumulateNV2"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_aggregate: true
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AccumulatorApplyGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorApplyGradient.pbtxt
new file mode 100644
index 0000000..b29f4a0a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorApplyGradient.pbtxt
@@ -0,0 +1,160 @@
+op {
+  name: "AccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AccumulatorNumAccumulated.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorNumAccumulated.pbtxt
new file mode 100644
index 0000000..f378509
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorNumAccumulated.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "AccumulatorNumAccumulated"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "num_accumulated"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AccumulatorSetGlobalStep.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorSetGlobalStep.pbtxt
new file mode 100644
index 0000000..9b4170d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorSetGlobalStep.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "AccumulatorSetGlobalStep"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "new_global_step"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AccumulatorTakeGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorTakeGradient.pbtxt
new file mode 100644
index 0000000..22b521a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AccumulatorTakeGradient.pbtxt
@@ -0,0 +1,160 @@
+op {
+  name: "AccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "average"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "average"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "average"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "AccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "average"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Acos.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Acos.pbtxt
new file mode 100644
index 0000000..3ed4518
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Acos.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Acos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Acos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Acos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Acosh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Acosh.pbtxt
new file mode 100644
index 0000000..e53c817
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Acosh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Acosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Acosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Acosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Add.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Add.pbtxt
new file mode 100644
index 0000000..ce30e6d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Add.pbtxt
@@ -0,0 +1,104 @@
+op {
+  name: "Add"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "Add"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "Add"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AddManySparseToTensorsMap.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AddManySparseToTensorsMap.pbtxt
new file mode 100644
index 0000000..c1433cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AddManySparseToTensorsMap.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "AddManySparseToTensorsMap"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_handles"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AddN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AddN.pbtxt
new file mode 100644
index 0000000..076f3e0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AddN.pbtxt
@@ -0,0 +1,222 @@
+op {
+  name: "AddN"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddN"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_VARIANT
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddN"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_VARIANT
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddN"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+        type: DT_VARIANT
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddN"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "sum"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_VARIANT
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AddSparseToTensorsMap.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AddSparseToTensorsMap.pbtxt
new file mode 100644
index 0000000..8a4c020
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AddSparseToTensorsMap.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "AddSparseToTensorsMap"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_handle"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AddV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AddV2.pbtxt
new file mode 100644
index 0000000..8781485
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AddV2.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "AddV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
+op {
+  name: "AddV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_aggregate: true
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AdjustContrast.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AdjustContrast.pbtxt
new file mode 100644
index 0000000..e51900d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AdjustContrast.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "AdjustContrast"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "contrast_factor"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_value"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_value"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 2
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AdjustContrastv2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AdjustContrastv2.pbtxt
new file mode 100644
index 0000000..6869f26
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AdjustContrastv2.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "AdjustContrastv2"
+  input_arg {
+    name: "images"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "contrast_factor"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "AdjustContrastv2"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "contrast_factor"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AdjustHue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AdjustHue.pbtxt
new file mode 100644
index 0000000..9a6c72d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AdjustHue.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "AdjustHue"
+  input_arg {
+    name: "images"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "delta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "AdjustHue"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AdjustSaturation.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AdjustSaturation.pbtxt
new file mode 100644
index 0000000..918ea18
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AdjustSaturation.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "AdjustSaturation"
+  input_arg {
+    name: "images"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "AdjustSaturation"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/All.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/All.pbtxt
new file mode 100644
index 0000000..c0bc8f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/All.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "All"
+  input_arg {
+    name: "input"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AllCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AllCandidateSampler.pbtxt
new file mode 100644
index 0000000..e452850
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AllCandidateSampler.pbtxt
@@ -0,0 +1,99 @@
+op {
+  name: "AllCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "AllCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AllToAll.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AllToAll.pbtxt
new file mode 100644
index 0000000..9f6bafd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AllToAll.pbtxt
@@ -0,0 +1,90 @@
+op {
+  name: "AllToAll"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_assignment"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "concat_dimension"
+    type: "int"
+  }
+  attr {
+    name: "split_dimension"
+    type: "int"
+  }
+  attr {
+    name: "split_count"
+    type: "int"
+  }
+}
+op {
+  name: "AllToAll"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_assignment"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BOOL
+      }
+    }
+  }
+  attr {
+    name: "concat_dimension"
+    type: "int"
+  }
+  attr {
+    name: "split_dimension"
+    type: "int"
+  }
+  attr {
+    name: "split_count"
+    type: "int"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Angle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Angle.pbtxt
new file mode 100644
index 0000000..ce28927
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Angle.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "Angle"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AnonymousIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AnonymousIterator.pbtxt
new file mode 100644
index 0000000..bf8f8fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AnonymousIterator.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "AnonymousIterator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AnonymousIteratorV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AnonymousIteratorV2.pbtxt
new file mode 100644
index 0000000..e7dca69
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AnonymousIteratorV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "AnonymousIteratorV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AnonymousMemoryCache.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AnonymousMemoryCache.pbtxt
new file mode 100644
index 0000000..7f15df3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AnonymousMemoryCache.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "AnonymousMemoryCache"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AnonymousMultiDeviceIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AnonymousMultiDeviceIterator.pbtxt
new file mode 100644
index 0000000..b8afaa3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AnonymousMultiDeviceIterator.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "AnonymousMultiDeviceIterator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "devices"
+    type: "list(string)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AnonymousRandomSeedGenerator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AnonymousRandomSeedGenerator.pbtxt
new file mode 100644
index 0000000..da2558b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AnonymousRandomSeedGenerator.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "AnonymousRandomSeedGenerator"
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Any.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Any.pbtxt
new file mode 100644
index 0000000..da02090
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Any.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "Any"
+  input_arg {
+    name: "input"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdaMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdaMax.pbtxt
new file mode 100644
index 0000000..5d9e414
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdaMax.pbtxt
@@ -0,0 +1,79 @@
+op {
+  name: "ApplyAdaMax"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdadelta.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdadelta.pbtxt
new file mode 100644
index 0000000..5d98b49
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdadelta.pbtxt
@@ -0,0 +1,280 @@
+op {
+  name: "ApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagrad.pbtxt
new file mode 100644
index 0000000..25cbf4b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagrad.pbtxt
@@ -0,0 +1,293 @@
+op {
+  name: "ApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradDA.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradDA.pbtxt
new file mode 100644
index 0000000..8a8053f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradDA.pbtxt
@@ -0,0 +1,296 @@
+op {
+  name: "ApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..d2d70ee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdagradV2.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "ApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAdam.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdam.pbtxt
new file mode 100644
index 0000000..44f1560
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAdam.pbtxt
@@ -0,0 +1,436 @@
+op {
+  name: "ApplyAdam"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdam"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdam"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdam"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAdam"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyAddSign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyAddSign.pbtxt
new file mode 100644
index 0000000..83d621b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyAddSign.pbtxt
@@ -0,0 +1,209 @@
+op {
+  name: "ApplyAddSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAddSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyAddSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyCenteredRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyCenteredRMSProp.pbtxt
new file mode 100644
index 0000000..300fbea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyCenteredRMSProp.pbtxt
@@ -0,0 +1,316 @@
+op {
+  name: "ApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrl.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrl.pbtxt
new file mode 100644
index 0000000..23c85a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrl.pbtxt
@@ -0,0 +1,296 @@
+op {
+  name: "ApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrlV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrlV2.pbtxt
new file mode 100644
index 0000000..3a8e207
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyFtrlV2.pbtxt
@@ -0,0 +1,312 @@
+op {
+  name: "ApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyGradientDescent.pbtxt
new file mode 100644
index 0000000..c4df817
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyGradientDescent.pbtxt
@@ -0,0 +1,208 @@
+op {
+  name: "ApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyMomentum.pbtxt
new file mode 100644
index 0000000..c723244
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyMomentum.pbtxt
@@ -0,0 +1,272 @@
+op {
+  name: "ApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyPowerSign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyPowerSign.pbtxt
new file mode 100644
index 0000000..7013010
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyPowerSign.pbtxt
@@ -0,0 +1,209 @@
+op {
+  name: "ApplyPowerSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyPowerSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyPowerSign"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalAdagrad.pbtxt
new file mode 100644
index 0000000..bf76e30
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalAdagrad.pbtxt
@@ -0,0 +1,260 @@
+op {
+  name: "ApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalGradientDescent.pbtxt
new file mode 100644
index 0000000..89ab89c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyProximalGradientDescent.pbtxt
@@ -0,0 +1,240 @@
+op {
+  name: "ApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApplyRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApplyRMSProp.pbtxt
new file mode 100644
index 0000000..b30e9b0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApplyRMSProp.pbtxt
@@ -0,0 +1,296 @@
+op {
+  name: "ApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ApproximateEqual.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ApproximateEqual.pbtxt
new file mode 100644
index 0000000..40c9025
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ApproximateEqual.pbtxt
@@ -0,0 +1,188 @@
+op {
+  name: "ApproximateEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "tolerance"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "ApproximateEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "tolerance"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "ApproximateEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "tolerance"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "ApproximateEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "tolerance"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ArgMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ArgMax.pbtxt
new file mode 100644
index 0000000..6fd71eb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ArgMax.pbtxt
@@ -0,0 +1,310 @@
+op {
+  name: "ArgMax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ArgMin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ArgMin.pbtxt
new file mode 100644
index 0000000..b6fa24e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ArgMin.pbtxt
@@ -0,0 +1,310 @@
+op {
+  name: "ArgMin"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMin"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMin"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMin"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ArgMin"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dimension"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AsString.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AsString.pbtxt
new file mode 100644
index 0000000..2bbf48d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AsString.pbtxt
@@ -0,0 +1,186 @@
+op {
+  name: "AsString"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_BOOL
+        type: DT_INT8
+      }
+    }
+  }
+  attr {
+    name: "precision"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "scientific"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "shortest"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "width"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "fill"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "AsString"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_BOOL
+      }
+    }
+  }
+  attr {
+    name: "precision"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "scientific"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "shortest"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "width"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "fill"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "AsString"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_BOOL
+      }
+    }
+  }
+  attr {
+    name: "precision"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "scientific"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "shortest"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "width"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "fill"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Asin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Asin.pbtxt
new file mode 100644
index 0000000..7df768f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Asin.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Asin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Asin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Asin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Asinh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Asinh.pbtxt
new file mode 100644
index 0000000..7f31ec1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Asinh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Asinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Asinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Asinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Assert.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Assert.pbtxt
new file mode 100644
index 0000000..a891ca8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Assert.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "Assert"
+  input_arg {
+    name: "condition"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssertNextDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssertNextDataset.pbtxt
new file mode 100644
index 0000000..7ca9d56
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssertNextDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "AssertNextDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "transformations"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Assign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Assign.pbtxt
new file mode 100644
index 0000000..9255e12
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Assign.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "Assign"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "validate_shape"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssignAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssignAdd.pbtxt
new file mode 100644
index 0000000..d36449e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssignAdd.pbtxt
@@ -0,0 +1,192 @@
+op {
+  name: "AssignAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssignAddVariableOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssignAddVariableOp.pbtxt
new file mode 100644
index 0000000..c3a8b74
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssignAddVariableOp.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "AssignAddVariableOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssignSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssignSub.pbtxt
new file mode 100644
index 0000000..55ed790
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssignSub.pbtxt
@@ -0,0 +1,192 @@
+op {
+  name: "AssignSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "AssignSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssignSubVariableOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssignSubVariableOp.pbtxt
new file mode 100644
index 0000000..a5c9a56
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssignSubVariableOp.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "AssignSubVariableOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AssignVariableOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AssignVariableOp.pbtxt
new file mode 100644
index 0000000..5fb0396
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AssignVariableOp.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "AssignVariableOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Atan.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Atan.pbtxt
new file mode 100644
index 0000000..86f0628
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Atan.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Atan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Atan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Atan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Atan2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Atan2.pbtxt
new file mode 100644
index 0000000..e58675d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Atan2.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "Atan2"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Atan2"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Atan2"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Atanh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Atanh.pbtxt
new file mode 100644
index 0000000..28d417a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Atanh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Atanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Atanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Atanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AudioSpectrogram.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AudioSpectrogram.pbtxt
new file mode 100644
index 0000000..dbc2a22
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AudioSpectrogram.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "AudioSpectrogram"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "spectrogram"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "window_size"
+    type: "int"
+  }
+  attr {
+    name: "stride"
+    type: "int"
+  }
+  attr {
+    name: "magnitude_squared"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AudioSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AudioSummary.pbtxt
new file mode 100644
index 0000000..4b18305
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AudioSummary.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "AudioSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "sample_rate"
+    type: "float"
+  }
+  attr {
+    name: "max_outputs"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AudioSummaryV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AudioSummaryV2.pbtxt
new file mode 100644
index 0000000..313c044
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AudioSummaryV2.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "AudioSummaryV2"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "max_outputs"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AutoShardDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AutoShardDataset.pbtxt
new file mode 100644
index 0000000..2b7dcfa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AutoShardDataset.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "AutoShardDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AvgPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AvgPool.pbtxt
new file mode 100644
index 0000000..8e7db13
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AvgPool.pbtxt
@@ -0,0 +1,229 @@
+op {
+  name: "AvgPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AvgPool3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AvgPool3D.pbtxt
new file mode 100644
index 0000000..f3f60cb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AvgPool3D.pbtxt
@@ -0,0 +1,214 @@
+op {
+  name: "AvgPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AvgPool3DGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AvgPool3DGrad.pbtxt
new file mode 100644
index 0000000..67fef95
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AvgPool3DGrad.pbtxt
@@ -0,0 +1,230 @@
+op {
+  name: "AvgPool3DGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3DGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3DGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPool3DGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/AvgPoolGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/AvgPoolGrad.pbtxt
new file mode 100644
index 0000000..6c72eff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/AvgPoolGrad.pbtxt
@@ -0,0 +1,245 @@
+op {
+  name: "AvgPoolGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPoolGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "AvgPoolGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "AvgPoolGrad"
+  input_arg {
+    name: "orig_input_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Barrier.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Barrier.pbtxt
new file mode 100644
index 0000000..9391157
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Barrier.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "Barrier"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BarrierClose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BarrierClose.pbtxt
new file mode 100644
index 0000000..6923048
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BarrierClose.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "BarrierClose"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "cancel_pending_enqueues"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BarrierIncompleteSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BarrierIncompleteSize.pbtxt
new file mode 100644
index 0000000..0d17c18
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BarrierIncompleteSize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "BarrierIncompleteSize"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BarrierInsertMany.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BarrierInsertMany.pbtxt
new file mode 100644
index 0000000..86b64f6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BarrierInsertMany.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "BarrierInsertMany"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "keys"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "component_index"
+    type: "int"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BarrierReadySize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BarrierReadySize.pbtxt
new file mode 100644
index 0000000..e7b0630
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BarrierReadySize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "BarrierReadySize"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BarrierTakeMany.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BarrierTakeMany.pbtxt
new file mode 100644
index 0000000..e324042
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BarrierTakeMany.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "BarrierTakeMany"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_elements"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "keys"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "allow_small_batch"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "wait_for_incomplete"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Batch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Batch.pbtxt
new file mode 100644
index 0000000..5e2ea7a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Batch.pbtxt
@@ -0,0 +1,147 @@
+op {
+  name: "Batch"
+  input_arg {
+    name: "in_tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "batched_tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "batch_index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "id"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_batch_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_batch_size"
+    type: "int"
+  }
+  attr {
+    name: "batch_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "allowed_batch_sizes"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "grad_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "batching_queue"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "Batch"
+  input_arg {
+    name: "in_tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "batched_tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "batch_index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "id"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_batch_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_batch_size"
+    type: "int"
+  }
+  attr {
+    name: "max_enqueued_batches"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  attr {
+    name: "batch_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "allowed_batch_sizes"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "grad_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "batching_queue"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchCholesky.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchCholesky.pbtxt
new file mode 100644
index 0000000..5d38acc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchCholesky.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "BatchCholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchCholeskyGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchCholeskyGrad.pbtxt
new file mode 100644
index 0000000..286ae3a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchCholeskyGrad.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "BatchCholeskyGrad"
+  input_arg {
+    name: "l"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchDataset.pbtxt
new file mode 100644
index 0000000..6770f2f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "BatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "BatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchDatasetV2.pbtxt
new file mode 100644
index 0000000..f553418
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchDatasetV2.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "BatchDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "BatchDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "parallel_copy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchFFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT.pbtxt
new file mode 100644
index 0000000..4fe86a3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchFFT"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchFFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT2D.pbtxt
new file mode 100644
index 0000000..b52a6bd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT2D.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchFFT2D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchFFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT3D.pbtxt
new file mode 100644
index 0000000..7f19cf1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchFFT3D.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchFFT3D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchFunction.pbtxt
new file mode 100644
index 0000000..daf3c46
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchFunction.pbtxt
@@ -0,0 +1,84 @@
+op {
+  name: "BatchFunction"
+  input_arg {
+    name: "in_tensors"
+    type_list_attr: "Tin"
+  }
+  input_arg {
+    name: "captured_tensors"
+    type_list_attr: "Tcaptured"
+  }
+  output_arg {
+    name: "out_tensors"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "num_batch_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_batch_size"
+    type: "int"
+  }
+  attr {
+    name: "batch_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "max_enqueued_batches"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  attr {
+    name: "allowed_batch_sizes"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "batching_queue"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tcaptured"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT.pbtxt
new file mode 100644
index 0000000..09d7b4a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchIFFT"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT2D.pbtxt
new file mode 100644
index 0000000..23cc9cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT2D.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchIFFT2D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT3D.pbtxt
new file mode 100644
index 0000000..10a78fa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchIFFT3D.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "BatchIFFT3D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+  deprecation {
+    version: 15
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatMul.pbtxt
new file mode 100644
index 0000000..29c5a6c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatMul.pbtxt
@@ -0,0 +1,176 @@
+op {
+  name: "BatchMatMul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "adj_x"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adj_y"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "BatchMatMul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "adj_x"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adj_y"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "BatchMatMul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "adj_x"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adj_y"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "BatchMatMul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "adj_x"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adj_y"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatMulV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatMulV2.pbtxt
new file mode 100644
index 0000000..77224c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatMulV2.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "BatchMatMulV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "adj_x"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adj_y"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixBandPart.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixBandPart.pbtxt
new file mode 100644
index 0000000..413681e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixBandPart.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "BatchMatrixBandPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_lower"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_upper"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "band"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 14
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDeterminant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDeterminant.pbtxt
new file mode 100644
index 0000000..4bc6081
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDeterminant.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "BatchMatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
+op {
+  name: "BatchMatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiag.pbtxt
new file mode 100644
index 0000000..6104bef
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiag.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "BatchMatrixDiag"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 14
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiagPart.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiagPart.pbtxt
new file mode 100644
index 0000000..9bd200f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixDiagPart.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "BatchMatrixDiagPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 14
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixInverse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixInverse.pbtxt
new file mode 100644
index 0000000..03a694d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixInverse.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "BatchMatrixInverse"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSetDiag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSetDiag.pbtxt
new file mode 100644
index 0000000..f459184
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSetDiag.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "BatchMatrixSetDiag"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 14
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolve.pbtxt
new file mode 100644
index 0000000..909502e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolve.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "BatchMatrixSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolveLs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolveLs.pbtxt
new file mode 100644
index 0000000..8c9d24e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixSolveLs.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "BatchMatrixSolveLs"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_regularizer"
+    type: DT_DOUBLE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixTriangularSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixTriangularSolve.pbtxt
new file mode 100644
index 0000000..406fa62
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchMatrixTriangularSolve.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "BatchMatrixTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalization.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalization.pbtxt
new file mode 100644
index 0000000..a15b037
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalization.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "BatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalizationGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalizationGrad.pbtxt
new file mode 100644
index 0000000..ef973cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchNormWithGlobalNormalizationGrad.pbtxt
@@ -0,0 +1,312 @@
+op {
+  name: "BatchNormWithGlobalNormalizationGrad"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dx"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dm"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dv"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "db"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dg"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalizationGrad"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dx"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dm"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dv"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "db"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dg"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalizationGrad"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dx"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dm"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dv"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "db"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dg"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
+op {
+  name: "BatchNormWithGlobalNormalizationGrad"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "m"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dx"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dm"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dv"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "db"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dg"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+  deprecation {
+    version: 9
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEig.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEig.pbtxt
new file mode 100644
index 0000000..42ba041
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEig.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "BatchSelfAdjointEig"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 11
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEigV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEigV2.pbtxt
new file mode 100644
index 0000000..df3996e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchSelfAdjointEigV2.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "BatchSelfAdjointEigV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchSvd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchSvd.pbtxt
new file mode 100644
index 0000000..0595ffc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchSvd.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "BatchSvd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "s"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "u"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_uv"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "full_matrices"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 13
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchToSpace.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchToSpace.pbtxt
new file mode 100644
index 0000000..ac089e5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchToSpace.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "BatchToSpace"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "crops"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BatchToSpaceND.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BatchToSpaceND.pbtxt
new file mode 100644
index 0000000..464beb3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BatchToSpaceND.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "BatchToSpaceND"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "block_shape"
+    type_attr: "Tblock_shape"
+  }
+  input_arg {
+    name: "crops"
+    type_attr: "Tcrops"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tblock_shape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tcrops"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BesselI0e.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BesselI0e.pbtxt
new file mode 100644
index 0000000..299cf82
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BesselI0e.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "BesselI0e"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BesselI1e.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BesselI1e.pbtxt
new file mode 100644
index 0000000..a9c8d0e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BesselI1e.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "BesselI1e"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Betainc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Betainc.pbtxt
new file mode 100644
index 0000000..b1523bf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Betainc.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "Betainc"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BiasAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BiasAdd.pbtxt
new file mode 100644
index 0000000..0d64c07
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BiasAdd.pbtxt
@@ -0,0 +1,208 @@
+op {
+  name: "BiasAdd"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAdd"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAdd"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAdd"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BiasAddGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BiasAddGrad.pbtxt
new file mode 100644
index 0000000..bb86721
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BiasAddGrad.pbtxt
@@ -0,0 +1,192 @@
+op {
+  name: "BiasAddGrad"
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddGrad"
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddGrad"
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddGrad"
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BiasAddV1.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BiasAddV1.pbtxt
new file mode 100644
index 0000000..ecd1245
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BiasAddV1.pbtxt
@@ -0,0 +1,156 @@
+op {
+  name: "BiasAddV1"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddV1"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddV1"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "BiasAddV1"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Bincount.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Bincount.pbtxt
new file mode 100644
index 0000000..12135bb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Bincount.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "Bincount"
+  input_arg {
+    name: "arr"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "weights"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "bins"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Bitcast.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Bitcast.pbtxt
new file mode 100644
index 0000000..993a0c6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Bitcast.pbtxt
@@ -0,0 +1,301 @@
+op {
+  name: "Bitcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Bitcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Bitcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Bitcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "Bitcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BitwiseAnd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BitwiseAnd.pbtxt
new file mode 100644
index 0000000..4b90e0e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BitwiseAnd.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "BitwiseAnd"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "BitwiseAnd"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BitwiseOr.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BitwiseOr.pbtxt
new file mode 100644
index 0000000..393a506
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BitwiseOr.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "BitwiseOr"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "BitwiseOr"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BitwiseXor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BitwiseXor.pbtxt
new file mode 100644
index 0000000..c72b23fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BitwiseXor.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "BitwiseXor"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "BitwiseXor"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BlockLSTM.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BlockLSTM.pbtxt
new file mode 100644
index 0000000..63180f5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BlockLSTM.pbtxt
@@ -0,0 +1,98 @@
+op {
+  name: "BlockLSTM"
+  input_arg {
+    name: "seq_len_max"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wcf"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wco"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "i"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "cs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "f"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "o"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "ci"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "co"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "h"
+    type_attr: "T"
+  }
+  attr {
+    name: "forget_bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "cell_clip"
+    type: "float"
+    default_value {
+      f: 3
+    }
+  }
+  attr {
+    name: "use_peephole"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BlockLSTMGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BlockLSTMGrad.pbtxt
new file mode 100644
index 0000000..e7b6458
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BlockLSTMGrad.pbtxt
@@ -0,0 +1,121 @@
+op {
+  name: "BlockLSTMGrad"
+  input_arg {
+    name: "seq_len_max"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wcf"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wco"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "i"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "f"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "o"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "co"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "x_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "cs_prev_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "h_prev_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "w_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wci_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wcf_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wco_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "b_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "use_peephole"
+    type: "bool"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesAggregateStats.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesAggregateStats.pbtxt
new file mode 100644
index 0000000..7299409
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesAggregateStats.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "BoostedTreesAggregateStats"
+  input_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "hessians"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "feature"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "stats_summary"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "max_splits"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesBucketize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000..5f277d3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesBucketize.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "BoostedTreesBucketize"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "buckets"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplit.pbtxt
new file mode 100644
index 0000000..929d54d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplit.pbtxt
@@ -0,0 +1,147 @@
+op {
+  name: "BoostedTreesCalculateBestFeatureSplit"
+  input_arg {
+    name: "node_id_range"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "stats_summary"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "tree_complexity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_node_weight"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "feature_dimensions"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "thresholds"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "left_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "right_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "split_with_default_directions"
+    type: DT_STRING
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "split_type"
+    type: "string"
+    default_value {
+      s: "inequality"
+    }
+    allowed_values {
+      list {
+        s: "inequality"
+      }
+    }
+  }
+}
+op {
+  name: "BoostedTreesCalculateBestFeatureSplit"
+  input_arg {
+    name: "node_id_range"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "stats_summary"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "tree_complexity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_node_weight"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "feature_dimensions"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "thresholds"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "left_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "right_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "split_with_default_directions"
+    type: DT_STRING
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "split_type"
+    type: "string"
+    default_value {
+      s: "inequality"
+    }
+    allowed_values {
+      list {
+        s: "inequality"
+        s: "equality"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestGainsPerFeature.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestGainsPerFeature.pbtxt
new file mode 100644
index 0000000..f100db7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestGainsPerFeature.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "BoostedTreesCalculateBestGainsPerFeature"
+  input_arg {
+    name: "node_id_range"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "stats_summary_list"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "l1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "tree_complexity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_node_weight"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "node_ids_list"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "gains_list"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "thresholds_list"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "left_node_contribs_list"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "right_node_contribs_list"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "max_splits"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCenterBias.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCenterBias.pbtxt
new file mode 100644
index 0000000..5c2fb9b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCenterBias.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "BoostedTreesCenterBias"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mean_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "mean_hessians"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "continue_centering"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateEnsemble.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateEnsemble.pbtxt
new file mode 100644
index 0000000..cea6d23
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateEnsemble.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "BoostedTreesCreateEnsemble"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "stamp_token"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tree_ensemble_serialized"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000..3d0d64a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "BoostedTreesCreateQuantileStreamResource"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_streams"
+    type: DT_INT64
+  }
+  attr {
+    name: "max_elements"
+    type: "int"
+    default_value {
+      i: 1099511627776
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesDeserializeEnsemble.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesDeserializeEnsemble.pbtxt
new file mode 100644
index 0000000..b6d55ea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesDeserializeEnsemble.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "BoostedTreesDeserializeEnsemble"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "stamp_token"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tree_ensemble_serialized"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesEnsembleResourceHandleOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesEnsembleResourceHandleOp.pbtxt
new file mode 100644
index 0000000..00573c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesEnsembleResourceHandleOp.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "BoostedTreesEnsembleResourceHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesExampleDebugOutputs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesExampleDebugOutputs.pbtxt
new file mode 100644
index 0000000..066be04
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesExampleDebugOutputs.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "BoostedTreesExampleDebugOutputs"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "bucketized_features"
+    type: DT_INT32
+    number_attr: "num_bucketized_features"
+  }
+  output_arg {
+    name: "examples_debug_outputs_serialized"
+    type: DT_STRING
+  }
+  attr {
+    name: "num_bucketized_features"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesFlushQuantileSummaries.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesFlushQuantileSummaries.pbtxt
new file mode 100644
index 0000000..ae35e10
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesFlushQuantileSummaries.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "BoostedTreesFlushQuantileSummaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesGetEnsembleStates.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesGetEnsembleStates.pbtxt
new file mode 100644
index 0000000..1959384
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesGetEnsembleStates.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "BoostedTreesGetEnsembleStates"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "stamp_token"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "num_trees"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "num_finalized_trees"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "num_attempted_layers"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "last_layer_nodes_range"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000..bbefa8b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "BoostedTreesMakeQuantileSummaries"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeStatsSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeStatsSummary.pbtxt
new file mode 100644
index 0000000..49a82d2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesMakeStatsSummary.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "BoostedTreesMakeStatsSummary"
+  input_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "hessians"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "bucketized_features_list"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "stats_summary"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "max_splits"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesPredict.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesPredict.pbtxt
new file mode 100644
index 0000000..7f176cd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesPredict.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "BoostedTreesPredict"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "bucketized_features"
+    type: DT_INT32
+    number_attr: "num_bucketized_features"
+  }
+  output_arg {
+    name: "logits"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bucketized_features"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000..97e875f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "BoostedTreesQuantileStreamResourceAddSummaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceDeserialize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceDeserialize.pbtxt
new file mode 100644
index 0000000..c3f01fe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceDeserialize.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "BoostedTreesQuantileStreamResourceDeserialize"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_streams"
+  }
+  attr {
+    name: "num_streams"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000..fc2613a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "BoostedTreesQuantileStreamResourceFlush"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "num_buckets"
+    type: DT_INT64
+  }
+  attr {
+    name: "generate_quantiles"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000..b2aa8dd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000..ca40a0a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "BoostedTreesQuantileStreamResourceHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSerializeEnsemble.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSerializeEnsemble.pbtxt
new file mode 100644
index 0000000..29d19f0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSerializeEnsemble.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "BoostedTreesSerializeEnsemble"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "stamp_token"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "tree_ensemble_serialized"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseAggregateStats.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseAggregateStats.pbtxt
new file mode 100644
index 0000000..9260634
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseAggregateStats.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "BoostedTreesSparseAggregateStats"
+  input_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "hessians"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "feature_indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "feature_values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "feature_shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "stats_summary_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "stats_summary_values"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "stats_summary_shape"
+    type: DT_INT32
+  }
+  attr {
+    name: "max_splits"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseCalculateBestFeatureSplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseCalculateBestFeatureSplit.pbtxt
new file mode 100644
index 0000000..86f7a5f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesSparseCalculateBestFeatureSplit.pbtxt
@@ -0,0 +1,81 @@
+op {
+  name: "BoostedTreesSparseCalculateBestFeatureSplit"
+  input_arg {
+    name: "node_id_range"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "stats_summary_indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "stats_summary_values"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "stats_summary_shape"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "l1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "tree_complexity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_node_weight"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "feature_dimensions"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "thresholds"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "left_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "right_node_contribs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "split_with_default_directions"
+    type: DT_STRING
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "split_type"
+    type: "string"
+    default_value {
+      s: "inequality"
+    }
+    allowed_values {
+      list {
+        s: "inequality"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesTrainingPredict.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesTrainingPredict.pbtxt
new file mode 100644
index 0000000..615f52c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesTrainingPredict.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "BoostedTreesTrainingPredict"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "cached_tree_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "cached_node_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "bucketized_features"
+    type: DT_INT32
+    number_attr: "num_bucketized_features"
+  }
+  output_arg {
+    name: "partial_logits"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "tree_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "node_ids"
+    type: DT_INT32
+  }
+  attr {
+    name: "num_bucketized_features"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesUpdateEnsemble.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesUpdateEnsemble.pbtxt
new file mode 100644
index 0000000..9cd779e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesUpdateEnsemble.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "BoostedTreesUpdateEnsemble"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "feature_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "node_ids"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "gains"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "thresholds"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "left_node_contribs"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "right_node_contribs"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "max_depth"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "pruning_mode"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BroadcastArgs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BroadcastArgs.pbtxt
new file mode 100644
index 0000000..e6dc399
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BroadcastArgs.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "BroadcastArgs"
+  input_arg {
+    name: "s0"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "s1"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r0"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BroadcastGradientArgs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BroadcastGradientArgs.pbtxt
new file mode 100644
index 0000000..2e1d739
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BroadcastGradientArgs.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "BroadcastGradientArgs"
+  input_arg {
+    name: "s0"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "s1"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r0"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r1"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BroadcastTo.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BroadcastTo.pbtxt
new file mode 100644
index 0000000..4d29f9e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BroadcastTo.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "BroadcastTo"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Bucketize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Bucketize.pbtxt
new file mode 100644
index 0000000..abe818e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Bucketize.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "Bucketize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "boundaries"
+    type: "list(float)"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BytesProducedStatsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BytesProducedStatsDataset.pbtxt
new file mode 100644
index 0000000..fabe50c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BytesProducedStatsDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "BytesProducedStatsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CSVDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CSVDataset.pbtxt
new file mode 100644
index 0000000..3cac405
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CSVDataset.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "CSVDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "header"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "field_delim"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "use_quote_delim"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "na_value"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "select_cols"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "output_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CTCBeamSearchDecoder.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CTCBeamSearchDecoder.pbtxt
new file mode 100644
index 0000000..aa489be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CTCBeamSearchDecoder.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "CTCBeamSearchDecoder"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sequence_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "decoded_indices"
+    type: DT_INT64
+    number_attr: "top_paths"
+  }
+  output_arg {
+    name: "decoded_values"
+    type: DT_INT64
+    number_attr: "top_paths"
+  }
+  output_arg {
+    name: "decoded_shape"
+    type: DT_INT64
+    number_attr: "top_paths"
+  }
+  output_arg {
+    name: "log_probability"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "beam_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "top_paths"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "merge_repeated"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CTCGreedyDecoder.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CTCGreedyDecoder.pbtxt
new file mode 100644
index 0000000..c13070b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CTCGreedyDecoder.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "CTCGreedyDecoder"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sequence_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "decoded_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "decoded_values"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "decoded_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "log_probability"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "merge_repeated"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CTCLoss.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CTCLoss.pbtxt
new file mode 100644
index 0000000..6947879
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CTCLoss.pbtxt
@@ -0,0 +1,89 @@
+op {
+  name: "CTCLoss"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "labels_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "labels_values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sequence_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "loss"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "preprocess_collapse_repeated"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "ctc_merge_repeated"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "CTCLoss"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "labels_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "labels_values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sequence_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "loss"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "preprocess_collapse_repeated"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "ctc_merge_repeated"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "ignore_longer_outputs_than_inputs"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CacheDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CacheDataset.pbtxt
new file mode 100644
index 0000000..c53eb7a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CacheDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "CacheDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "CacheDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CacheDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CacheDatasetV2.pbtxt
new file mode 100644
index 0000000..c65c4c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CacheDatasetV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "CacheDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "cache"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Case.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Case.pbtxt
new file mode 100644
index 0000000..39cfc3f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Case.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "Case"
+  input_arg {
+    name: "branch_index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "branches"
+    type: "list(func)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cast.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cast.pbtxt
new file mode 100644
index 0000000..581a8e4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cast.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "Cast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+}
+op {
+  name: "Cast"
+  input_arg {
+    name: "x"
+    type_attr: "SrcT"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "DstT"
+  }
+  attr {
+    name: "SrcT"
+    type: "type"
+  }
+  attr {
+    name: "DstT"
+    type: "type"
+  }
+  attr {
+    name: "Truncate"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Ceil.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Ceil.pbtxt
new file mode 100644
index 0000000..cdec085
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Ceil.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Ceil"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Ceil"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Ceil"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CheckNumerics.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CheckNumerics.pbtxt
new file mode 100644
index 0000000..9e63b17
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CheckNumerics.pbtxt
@@ -0,0 +1,108 @@
+op {
+  name: "CheckNumerics"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "message"
+    type: "string"
+  }
+}
+op {
+  name: "CheckNumerics"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "message"
+    type: "string"
+  }
+}
+op {
+  name: "CheckNumerics"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "message"
+    type: "string"
+  }
+}
+op {
+  name: "CheckNumerics"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "message"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cholesky.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cholesky.pbtxt
new file mode 100644
index 0000000..e3cee5f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cholesky.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Cholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "Cholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Cholesky"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CholeskyGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CholeskyGrad.pbtxt
new file mode 100644
index 0000000..0f7c7ef
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CholeskyGrad.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "CholeskyGrad"
+  input_arg {
+    name: "l"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "CholeskyGrad"
+  input_arg {
+    name: "l"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestBranchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestBranchDataset.pbtxt
new file mode 100644
index 0000000..8638a91
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestBranchDataset.pbtxt
@@ -0,0 +1,58 @@
+op {
+  name: "ChooseFastestBranchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "ratio_numerator"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "ratio_denominator"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "num_elements_per_branch"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "branches"
+    type: "list(func)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "other_arguments_lengths"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestDataset.pbtxt
new file mode 100644
index 0000000..d4ead6d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ChooseFastestDataset.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "ChooseFastestDataset"
+  input_arg {
+    name: "input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "num_experiments"
+    type: "int"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ClipByValue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ClipByValue.pbtxt
new file mode 100644
index 0000000..3c3919b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ClipByValue.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "ClipByValue"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "clip_value_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "clip_value_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CloseSummaryWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CloseSummaryWriter.pbtxt
new file mode 100644
index 0000000..f67e1aa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CloseSummaryWriter.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "CloseSummaryWriter"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastRecv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastRecv.pbtxt
new file mode 100644
index 0000000..0b77d8b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastRecv.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "CollectiveBcastRecv"
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "group_size"
+    type: "int"
+  }
+  attr {
+    name: "group_key"
+    type: "int"
+  }
+  attr {
+    name: "instance_key"
+    type: "int"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastSend.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastSend.pbtxt
new file mode 100644
index 0000000..137f044
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CollectiveBcastSend.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "CollectiveBcastSend"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "group_size"
+    type: "int"
+  }
+  attr {
+    name: "group_key"
+    type: "int"
+  }
+  attr {
+    name: "instance_key"
+    type: "int"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CollectiveGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CollectiveGather.pbtxt
new file mode 100644
index 0000000..69cd90e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CollectiveGather.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "CollectiveGather"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "group_size"
+    type: "int"
+  }
+  attr {
+    name: "group_key"
+    type: "int"
+  }
+  attr {
+    name: "instance_key"
+    type: "int"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CollectivePermute.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CollectivePermute.pbtxt
new file mode 100644
index 0000000..fd224d3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CollectivePermute.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "CollectivePermute"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "source_target_pairs"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CollectiveReduce.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CollectiveReduce.pbtxt
new file mode 100644
index 0000000..e23edde
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CollectiveReduce.pbtxt
@@ -0,0 +1,134 @@
+op {
+  name: "CollectiveReduce"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "group_size"
+    type: "int"
+  }
+  attr {
+    name: "group_key"
+    type: "int"
+  }
+  attr {
+    name: "instance_key"
+    type: "int"
+  }
+  attr {
+    name: "merge_op"
+    type: "string"
+    allowed_values {
+      list {
+        s: "Min"
+        s: "Max"
+        s: "Mul"
+        s: "Add"
+      }
+    }
+  }
+  attr {
+    name: "final_op"
+    type: "string"
+    allowed_values {
+      list {
+        s: "Id"
+        s: "Div"
+      }
+    }
+  }
+  attr {
+    name: "subdiv_offsets"
+    type: "list(int)"
+  }
+  is_stateful: true
+}
+op {
+  name: "CollectiveReduce"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "group_size"
+    type: "int"
+  }
+  attr {
+    name: "group_key"
+    type: "int"
+  }
+  attr {
+    name: "instance_key"
+    type: "int"
+  }
+  attr {
+    name: "merge_op"
+    type: "string"
+    allowed_values {
+      list {
+        s: "Min"
+        s: "Max"
+        s: "Mul"
+        s: "Add"
+      }
+    }
+  }
+  attr {
+    name: "final_op"
+    type: "string"
+    allowed_values {
+      list {
+        s: "Id"
+        s: "Div"
+      }
+    }
+  }
+  attr {
+    name: "subdiv_offsets"
+    type: "list(int)"
+  }
+  attr {
+    name: "wait_for"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CombinedNonMaxSuppression.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CombinedNonMaxSuppression.pbtxt
new file mode 100644
index 0000000..55e2712
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CombinedNonMaxSuppression.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "CombinedNonMaxSuppression"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size_per_class"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "max_total_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_boxes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_scores"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_classes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "valid_detections"
+    type: DT_INT32
+  }
+  attr {
+    name: "pad_per_class"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "CombinedNonMaxSuppression"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size_per_class"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "max_total_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_boxes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_scores"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "nmsed_classes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "valid_detections"
+    type: DT_INT32
+  }
+  attr {
+    name: "pad_per_class"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clip_boxes"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CompareAndBitpack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CompareAndBitpack.pbtxt
new file mode 100644
index 0000000..b2df17b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CompareAndBitpack.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "CompareAndBitpack"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "threshold"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_UINT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Complex.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Complex.pbtxt
new file mode 100644
index 0000000..5d17643
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Complex.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "Complex"
+  input_arg {
+    name: "real"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "imag"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ComplexAbs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ComplexAbs.pbtxt
new file mode 100644
index 0000000..6e7cfc1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ComplexAbs.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ComplexAbs"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ComputeAccidentalHits.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ComputeAccidentalHits.pbtxt
new file mode 100644
index 0000000..0bac269
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ComputeAccidentalHits.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "ComputeAccidentalHits"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "ids"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "weights"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Concat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Concat.pbtxt
new file mode 100644
index 0000000..21ff0fd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Concat.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "Concat"
+  input_arg {
+    name: "concat_dim"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConcatOffset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConcatOffset.pbtxt
new file mode 100644
index 0000000..8057672
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConcatOffset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "ConcatOffset"
+  input_arg {
+    name: "concat_dim"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "shape"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  output_arg {
+    name: "offset"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConcatV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConcatV2.pbtxt
new file mode 100644
index 0000000..d11dc14
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConcatV2.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "ConcatV2"
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConcatenateDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConcatenateDataset.pbtxt
new file mode 100644
index 0000000..058d0ea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConcatenateDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "ConcatenateDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "another_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ConcatenateDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "another_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConditionalAccumulator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConditionalAccumulator.pbtxt
new file mode 100644
index 0000000..d6da788
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConditionalAccumulator.pbtxt
@@ -0,0 +1,269 @@
+op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConfigureDistributedTPU.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConfigureDistributedTPU.pbtxt
new file mode 100644
index 0000000..8d19a02
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConfigureDistributedTPU.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "ConfigureDistributedTPU"
+  output_arg {
+    name: "topology"
+    type: DT_STRING
+  }
+  attr {
+    name: "embedding_config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "tpu_embedding_config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "is_global_init"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConfigureTPUEmbedding.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConfigureTPUEmbedding.pbtxt
new file mode 100644
index 0000000..6e61f88
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConfigureTPUEmbedding.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "ConfigureTPUEmbedding"
+  attr {
+    name: "config"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conj.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conj.pbtxt
new file mode 100644
index 0000000..6e98e16
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conj.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "Conj"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Conj"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_VARIANT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConjugateTranspose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConjugateTranspose.pbtxt
new file mode 100644
index 0000000..417a2a5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConjugateTranspose.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "ConjugateTranspose"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "perm"
+    type_attr: "Tperm"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tperm"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Const.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Const.pbtxt
new file mode 100644
index 0000000..6512d22
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Const.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "Const"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "value"
+    type: "tensor"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ConsumeMutexLock.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ConsumeMutexLock.pbtxt
new file mode 100644
index 0000000..4340267
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ConsumeMutexLock.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "ConsumeMutexLock"
+  input_arg {
+    name: "mutex_lock"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ControlTrigger.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ControlTrigger.pbtxt
new file mode 100644
index 0000000..2fe84fe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ControlTrigger.pbtxt
@@ -0,0 +1,3 @@
+op {
+  name: "ControlTrigger"
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv2D.pbtxt
new file mode 100644
index 0000000..10f6c1d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv2D.pbtxt
@@ -0,0 +1,286 @@
+op {
+  name: "Conv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+        s: "EXPLICIT"
+      }
+    }
+  }
+  attr {
+    name: "explicit_paddings"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropFilter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropFilter.pbtxt
new file mode 100644
index 0000000..1c65612
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropFilter.pbtxt
@@ -0,0 +1,302 @@
+op {
+  name: "Conv2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+        s: "EXPLICIT"
+      }
+    }
+  }
+  attr {
+    name: "explicit_paddings"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropInput.pbtxt
new file mode 100644
index 0000000..7d46813
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv2DBackpropInput.pbtxt
@@ -0,0 +1,302 @@
+op {
+  name: "Conv2DBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv2DBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "use_cudnn_on_gpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+        s: "EXPLICIT"
+      }
+    }
+  }
+  attr {
+    name: "explicit_paddings"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv3D.pbtxt
new file mode 100644
index 0000000..a04a660
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv3D.pbtxt
@@ -0,0 +1,164 @@
+op {
+  name: "Conv3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilter.pbtxt
new file mode 100644
index 0000000..3091cfd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilter.pbtxt
@@ -0,0 +1,159 @@
+op {
+  name: "Conv3DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
+op {
+  name: "Conv3DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
+op {
+  name: "Conv3DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilterV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilterV2.pbtxt
new file mode 100644
index 0000000..2494eba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropFilterV2.pbtxt
@@ -0,0 +1,176 @@
+op {
+  name: "Conv3DBackpropFilterV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3DBackpropFilterV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3DBackpropFilterV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInput.pbtxt
new file mode 100644
index 0000000..7fa3a55
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInput.pbtxt
@@ -0,0 +1,159 @@
+op {
+  name: "Conv3DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
+op {
+  name: "Conv3DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
+op {
+  name: "Conv3DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  deprecation {
+    version: 10
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInputV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInputV2.pbtxt
new file mode 100644
index 0000000..e01b33d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Conv3DBackpropInputV2.pbtxt
@@ -0,0 +1,262 @@
+op {
+  name: "Conv3DBackpropInputV2"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3DBackpropInputV2"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+}
+op {
+  name: "Conv3DBackpropInputV2"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "Conv3DBackpropInputV2"
+  input_arg {
+    name: "input_sizes"
+    type_attr: "Tshape"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Copy.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Copy.pbtxt
new file mode 100644
index 0000000..258aecc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Copy.pbtxt
@@ -0,0 +1,54 @@
+op {
+  name: "Copy"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "Copy"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_ops_spec"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CopyHost.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CopyHost.pbtxt
new file mode 100644
index 0000000..07eb864
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CopyHost.pbtxt
@@ -0,0 +1,54 @@
+op {
+  name: "CopyHost"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "CopyHost"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_ops_spec"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cos.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cos.pbtxt
new file mode 100644
index 0000000..52b7c1e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cos.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Cos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Cos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Cos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cosh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cosh.pbtxt
new file mode 100644
index 0000000..7a29316
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cosh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Cosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Cosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Cosh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CountUpTo.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CountUpTo.pbtxt
new file mode 100644
index 0000000..05726df
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CountUpTo.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "CountUpTo"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "limit"
+    type: "int"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryDbWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryDbWriter.pbtxt
new file mode 100644
index 0000000..7a5f844
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryDbWriter.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "CreateSummaryDbWriter"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "db_uri"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "experiment_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "run_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "user_name"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryFileWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryFileWriter.pbtxt
new file mode 100644
index 0000000..61106e9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CreateSummaryFileWriter.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "CreateSummaryFileWriter"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "logdir"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "max_queue"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flush_millis"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filename_suffix"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CropAndResize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CropAndResize.pbtxt
new file mode 100644
index 0000000..57b02c6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CropAndResize.pbtxt
@@ -0,0 +1,177 @@
+op {
+  name: "CropAndResize"
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "crop_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "crops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+      }
+    }
+  }
+  attr {
+    name: "extrapolation_value"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+}
+op {
+  name: "CropAndResize"
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "crop_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "crops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+      }
+    }
+  }
+  attr {
+    name: "extrapolation_value"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+}
+op {
+  name: "CropAndResize"
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "crop_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "crops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+        s: "nearest"
+      }
+    }
+  }
+  attr {
+    name: "extrapolation_value"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradBoxes.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradBoxes.pbtxt
new file mode 100644
index 0000000..d3f62e3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradBoxes.pbtxt
@@ -0,0 +1,103 @@
+op {
+  name: "CropAndResizeGradBoxes"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+      }
+    }
+  }
+}
+op {
+  name: "CropAndResizeGradBoxes"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradImage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradImage.pbtxt
new file mode 100644
index 0000000..6ae744f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CropAndResizeGradImage.pbtxt
@@ -0,0 +1,93 @@
+op {
+  name: "CropAndResizeGradImage"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "image_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+      }
+    }
+  }
+}
+op {
+  name: "CropAndResizeGradImage"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "box_ind"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "image_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "method"
+    type: "string"
+    default_value {
+      s: "bilinear"
+    }
+    allowed_values {
+      list {
+        s: "bilinear"
+        s: "nearest"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cross.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cross.pbtxt
new file mode 100644
index 0000000..b80215f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cross.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "Cross"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Cross"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Cross"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Cross"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CrossReplicaSum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CrossReplicaSum.pbtxt
new file mode 100644
index 0000000..09c2402
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CrossReplicaSum.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "CrossReplicaSum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_assignment"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "CrossReplicaSum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "group_assignment"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_INT32
+        type: DT_UINT32
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNN.pbtxt
new file mode 100644
index 0000000..deab608
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNN.pbtxt
@@ -0,0 +1,117 @@
+op {
+  name: "CudnnRNN"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackprop.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackprop.pbtxt
new file mode 100644
index 0000000..2100ed2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackprop.pbtxt
@@ -0,0 +1,138 @@
+op {
+  name: "CudnnRNNBackprop"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_h_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_c_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "params_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV2.pbtxt
new file mode 100644
index 0000000..eac269a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV2.pbtxt
@@ -0,0 +1,142 @@
+op {
+  name: "CudnnRNNBackpropV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  output_arg {
+    name: "input_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_h_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_c_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "params_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV3.pbtxt
new file mode 100644
index 0000000..c5342f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNBackpropV3.pbtxt
@@ -0,0 +1,459 @@
+op {
+  name: "CudnnRNNBackpropV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  output_arg {
+    name: "input_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_h_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_c_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "params_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "CudnnRNNBackpropV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  output_arg {
+    name: "input_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_h_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_c_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "params_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "time_major"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "CudnnRNNBackpropV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_h_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_c_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  output_arg {
+    name: "input_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_h_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "input_c_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "params_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_proj"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "time_major"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParams.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParams.pbtxt
new file mode 100644
index 0000000..63a7022
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParams.pbtxt
@@ -0,0 +1,109 @@
+op {
+  name: "CudnnRNNCanonicalToParams"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "weights"
+    type_attr: "T"
+    number_attr: "num_params"
+  }
+  input_arg {
+    name: "biases"
+    type_attr: "T"
+    number_attr: "num_params"
+  }
+  output_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "num_params"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParamsV2.pbtxt
new file mode 100644
index 0000000..ce8a7f8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNCanonicalToParamsV2.pbtxt
@@ -0,0 +1,122 @@
+op {
+  name: "CudnnRNNCanonicalToParamsV2"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "weights"
+    type_attr: "T"
+    number_attr: "num_params_weights"
+  }
+  input_arg {
+    name: "biases"
+    type_attr: "T"
+    number_attr: "num_params_biases"
+  }
+  output_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "num_params_weights"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_params_biases"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_proj"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsSize.pbtxt
new file mode 100644
index 0000000..50c8a28
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsSize.pbtxt
@@ -0,0 +1,213 @@
+op {
+  name: "CudnnRNNParamsSize"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "params_size"
+    type_attr: "S"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "CudnnRNNParamsSize"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "params_size"
+    type_attr: "S"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_proj"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonical.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonical.pbtxt
new file mode 100644
index 0000000..568e338
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonical.pbtxt
@@ -0,0 +1,109 @@
+op {
+  name: "CudnnRNNParamsToCanonical"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "weights"
+    type_attr: "T"
+    number_attr: "num_params"
+  }
+  output_arg {
+    name: "biases"
+    type_attr: "T"
+    number_attr: "num_params"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "num_params"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonicalV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonicalV2.pbtxt
new file mode 100644
index 0000000..e9c02a2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNParamsToCanonicalV2.pbtxt
@@ -0,0 +1,122 @@
+op {
+  name: "CudnnRNNParamsToCanonicalV2"
+  input_arg {
+    name: "num_layers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_units"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "weights"
+    type_attr: "T"
+    number_attr: "num_params_weights"
+  }
+  output_arg {
+    name: "biases"
+    type_attr: "T"
+    number_attr: "num_params_biases"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "num_params_weights"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_params_biases"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_proj"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV2.pbtxt
new file mode 100644
index 0000000..6ad7a0e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV2.pbtxt
@@ -0,0 +1,121 @@
+op {
+  name: "CudnnRNNV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV3.pbtxt
new file mode 100644
index 0000000..a655183
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CudnnRNNV3.pbtxt
@@ -0,0 +1,396 @@
+op {
+  name: "CudnnRNNV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "CudnnRNNV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "time_major"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "CudnnRNNV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_h"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "params"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sequence_lengths"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "host_reserved"
+    type: DT_INT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "rnn_mode"
+    type: "string"
+    default_value {
+      s: "lstm"
+    }
+    allowed_values {
+      list {
+        s: "rnn_relu"
+        s: "rnn_tanh"
+        s: "lstm"
+        s: "gru"
+      }
+    }
+  }
+  attr {
+    name: "input_mode"
+    type: "string"
+    default_value {
+      s: "linear_input"
+    }
+    allowed_values {
+      list {
+        s: "linear_input"
+        s: "skip_input"
+        s: "auto_select"
+      }
+    }
+  }
+  attr {
+    name: "direction"
+    type: "string"
+    default_value {
+      s: "unidirectional"
+    }
+    allowed_values {
+      list {
+        s: "unidirectional"
+        s: "bidirectional"
+      }
+    }
+  }
+  attr {
+    name: "dropout"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_proj"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "time_major"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cumprod.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cumprod.pbtxt
new file mode 100644
index 0000000..72e2c03
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cumprod.pbtxt
@@ -0,0 +1,264 @@
+op {
+  name: "Cumprod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumprod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumprod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumprod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Cumsum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Cumsum.pbtxt
new file mode 100644
index 0000000..8249a8f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Cumsum.pbtxt
@@ -0,0 +1,264 @@
+op {
+  name: "Cumsum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumsum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumsum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Cumsum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/CumulativeLogsumexp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CumulativeLogsumexp.pbtxt
new file mode 100644
index 0000000..3b70a7b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/CumulativeLogsumexp.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "CumulativeLogsumexp"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  attr {
+    name: "exclusive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DataFormatDimMap.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DataFormatDimMap.pbtxt
new file mode 100644
index 0000000..a01806b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DataFormatDimMap.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "DataFormatDimMap"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "src_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+  }
+  attr {
+    name: "dst_format"
+    type: "string"
+    default_value {
+      s: "NCHW"
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DataFormatVecPermute.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DataFormatVecPermute.pbtxt
new file mode 100644
index 0000000..e439414
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DataFormatVecPermute.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "DataFormatVecPermute"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "src_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+  }
+  attr {
+    name: "dst_format"
+    type: "string"
+    default_value {
+      s: "NCHW"
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DatasetCardinality.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DatasetCardinality.pbtxt
new file mode 100644
index 0000000..61638a6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DatasetCardinality.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DatasetCardinality"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "cardinality"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DatasetFromGraph.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DatasetFromGraph.pbtxt
new file mode 100644
index 0000000..30f5384
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DatasetFromGraph.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DatasetFromGraph"
+  input_arg {
+    name: "graph_def"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DatasetToGraph.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DatasetToGraph.pbtxt
new file mode 100644
index 0000000..1f1751c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DatasetToGraph.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DatasetToGraph"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "graph"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DatasetToSingleElement.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DatasetToSingleElement.pbtxt
new file mode 100644
index 0000000..d9080d7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DatasetToSingleElement.pbtxt
@@ -0,0 +1,47 @@
+op {
+  name: "DatasetToSingleElement"
+  input_arg {
+    name: "dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "DatasetToSingleElement"
+  input_arg {
+    name: "dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DatasetToTFRecord.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DatasetToTFRecord.pbtxt
new file mode 100644
index 0000000..e13f88e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DatasetToTFRecord.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "DatasetToTFRecord"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DebugGradientIdentity.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DebugGradientIdentity.pbtxt
new file mode 100644
index 0000000..e1b4257
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DebugGradientIdentity.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "DebugGradientIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DebugGradientRefIdentity.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DebugGradientRefIdentity.pbtxt
new file mode 100644
index 0000000..f75b778
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DebugGradientRefIdentity.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "DebugGradientRefIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DebugIdentity.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DebugIdentity.pbtxt
new file mode 100644
index 0000000..50f97d8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DebugIdentity.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "DebugIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "device_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DebugNanCount.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DebugNanCount.pbtxt
new file mode 100644
index 0000000..82ae073
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DebugNanCount.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "DebugNanCount"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugNanCount"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugNanCount"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "device_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DebugNumericSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DebugNumericSummary.pbtxt
new file mode 100644
index 0000000..d108b54
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DebugNumericSummary.pbtxt
@@ -0,0 +1,208 @@
+op {
+  name: "DebugNumericSummary"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_DOUBLE
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugNumericSummary"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_DOUBLE
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "lower_bound"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "upper_bound"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "mute_if_healthy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugNumericSummary"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_DOUBLE
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "lower_bound"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "upper_bound"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "mute_if_healthy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
+op {
+  name: "DebugNumericSummary"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_DOUBLE
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "device_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "lower_bound"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "upper_bound"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "mute_if_healthy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "gated_grpc"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeAndCropJpeg.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeAndCropJpeg.pbtxt
new file mode 100644
index 0000000..d1d767f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeAndCropJpeg.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "DecodeAndCropJpeg"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "crop_window"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "image"
+    type: DT_UINT8
+  }
+  attr {
+    name: "channels"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ratio"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "fancy_upscaling"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "try_recover_truncated"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "acceptable_fraction"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "dct_method"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeBase64.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeBase64.pbtxt
new file mode 100644
index 0000000..cc7d61c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeBase64.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DecodeBase64"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeBmp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeBmp.pbtxt
new file mode 100644
index 0000000..40ac5f0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeBmp.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "DecodeBmp"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "image"
+    type: DT_UINT8
+  }
+  attr {
+    name: "channels"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeCSV.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeCSV.pbtxt
new file mode 100644
index 0000000..f4fee2a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeCSV.pbtxt
@@ -0,0 +1,239 @@
+op {
+  name: "DecodeCSV"
+  input_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "OUT_TYPE"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "OUT_TYPE"
+  }
+  attr {
+    name: "OUT_TYPE"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "field_delim"
+    type: "string"
+    default_value {
+      s: ","
+    }
+  }
+}
+op {
+  name: "DecodeCSV"
+  input_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "OUT_TYPE"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "OUT_TYPE"
+  }
+  attr {
+    name: "OUT_TYPE"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "field_delim"
+    type: "string"
+    default_value {
+      s: ","
+    }
+  }
+  attr {
+    name: "use_quote_delim"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "DecodeCSV"
+  input_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "OUT_TYPE"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "OUT_TYPE"
+  }
+  attr {
+    name: "OUT_TYPE"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "field_delim"
+    type: "string"
+    default_value {
+      s: ","
+    }
+  }
+  attr {
+    name: "use_quote_delim"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "na_value"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "DecodeCSV"
+  input_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "OUT_TYPE"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "OUT_TYPE"
+  }
+  attr {
+    name: "OUT_TYPE"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "field_delim"
+    type: "string"
+    default_value {
+      s: ","
+    }
+  }
+  attr {
+    name: "use_quote_delim"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "na_value"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "DecodeCSV"
+  input_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "OUT_TYPE"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "OUT_TYPE"
+  }
+  attr {
+    name: "OUT_TYPE"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "field_delim"
+    type: "string"
+    default_value {
+      s: ","
+    }
+  }
+  attr {
+    name: "use_quote_delim"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "na_value"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "select_cols"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeCompressed.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeCompressed.pbtxt
new file mode 100644
index 0000000..8a345ff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeCompressed.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "DecodeCompressed"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "compression_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeGif.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeGif.pbtxt
new file mode 100644
index 0000000..89b21b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeGif.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DecodeGif"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "image"
+    type: DT_UINT8
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeJSONExample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeJSONExample.pbtxt
new file mode 100644
index 0000000..ec37ae5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeJSONExample.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "DecodeJSONExample"
+  input_arg {
+    name: "json_examples"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "binary_examples"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeJpeg.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeJpeg.pbtxt
new file mode 100644
index 0000000..9a4b4e4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeJpeg.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "DecodeJpeg"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "image"
+    type: DT_UINT8
+  }
+  attr {
+    name: "channels"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ratio"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "fancy_upscaling"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "try_recover_truncated"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "acceptable_fraction"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "dct_method"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodePaddedRaw.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodePaddedRaw.pbtxt
new file mode 100644
index 0000000..dac2d95
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodePaddedRaw.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "DecodePaddedRaw"
+  input_arg {
+    name: "input_bytes"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "fixed_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT16
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodePng.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodePng.pbtxt
new file mode 100644
index 0000000..dd7bd02
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodePng.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "DecodePng"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "image"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "channels"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_UINT8
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeProtoV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeProtoV2.pbtxt
new file mode 100644
index 0000000..ae72d29
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeProtoV2.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "DecodeProtoV2"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "sizes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "message_type"
+    type: "string"
+  }
+  attr {
+    name: "field_names"
+    type: "list(string)"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "descriptor_source"
+    type: "string"
+    default_value {
+      s: "local://"
+    }
+  }
+  attr {
+    name: "message_format"
+    type: "string"
+    default_value {
+      s: "binary"
+    }
+  }
+  attr {
+    name: "sanitize"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeRaw.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeRaw.pbtxt
new file mode 100644
index 0000000..77f27f9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeRaw.pbtxt
@@ -0,0 +1,144 @@
+op {
+  name: "DecodeRaw"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "DecodeRaw"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT16
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "DecodeRaw"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT16
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "DecodeRaw"
+  input_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT16
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_BOOL
+      }
+    }
+  }
+  attr {
+    name: "little_endian"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DecodeWav.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DecodeWav.pbtxt
new file mode 100644
index 0000000..8eba7b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DecodeWav.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "DecodeWav"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "audio"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sample_rate"
+    type: DT_INT32
+  }
+  attr {
+    name: "desired_channels"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "desired_samples"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeepCopy.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeepCopy.pbtxt
new file mode 100644
index 0000000..e673960
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeepCopy.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "DeepCopy"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeleteIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeleteIterator.pbtxt
new file mode 100644
index 0000000..3050ea9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeleteIterator.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "DeleteIterator"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeleteMemoryCache.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeleteMemoryCache.pbtxt
new file mode 100644
index 0000000..821293b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeleteMemoryCache.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "DeleteMemoryCache"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeleteMultiDeviceIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeleteMultiDeviceIterator.pbtxt
new file mode 100644
index 0000000..b4ae640
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeleteMultiDeviceIterator.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "DeleteMultiDeviceIterator"
+  input_arg {
+    name: "multi_device_iterator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "iterators"
+    type: DT_RESOURCE
+    number_attr: "N"
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeleteRandomSeedGenerator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeleteRandomSeedGenerator.pbtxt
new file mode 100644
index 0000000..0c0d2d1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeleteRandomSeedGenerator.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "DeleteRandomSeedGenerator"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeleteSessionTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeleteSessionTensor.pbtxt
new file mode 100644
index 0000000..def4c10
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeleteSessionTensor.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "DeleteSessionTensor"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+}
+op {
+  name: "DeleteSessionTensor"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DenseToDenseSetOperation.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DenseToDenseSetOperation.pbtxt
new file mode 100644
index 0000000..5188a82
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DenseToDenseSetOperation.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "DenseToDenseSetOperation"
+  input_arg {
+    name: "set1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set2"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "result_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "set_operation"
+    type: "string"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseBatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseBatchDataset.pbtxt
new file mode 100644
index 0000000..051589d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseBatchDataset.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "DenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseSetOperation.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseSetOperation.pbtxt
new file mode 100644
index 0000000..71c9c37
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DenseToSparseSetOperation.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "DenseToSparseSetOperation"
+  input_arg {
+    name: "set1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set2_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "set2_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set2_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "result_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "result_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "set_operation"
+    type: "string"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DepthToSpace.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DepthToSpace.pbtxt
new file mode 100644
index 0000000..422fe7f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DepthToSpace.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "DepthToSpace"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+}
+op {
+  name: "DepthToSpace"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNative.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNative.pbtxt
new file mode 100644
index 0000000..14dc12e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNative.pbtxt
@@ -0,0 +1,157 @@
+op {
+  name: "DepthwiseConv2dNative"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNative"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNative"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropFilter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropFilter.pbtxt
new file mode 100644
index 0000000..9ae9df1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropFilter.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "DepthwiseConv2dNativeBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropInput.pbtxt
new file mode 100644
index 0000000..d3329f2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DepthwiseConv2dNativeBackpropInput.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "DepthwiseConv2dNativeBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "DepthwiseConv2dNativeBackpropInput"
+  input_arg {
+    name: "input_sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt
new file mode 100644
index 0000000..6471f9d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt
@@ -0,0 +1,137 @@
+op {
+  name: "Dequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+      }
+    }
+  }
+}
+op {
+  name: "Dequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
+op {
+  name: "Dequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeserializeIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeserializeIterator.pbtxt
new file mode 100644
index 0000000..1ae290e9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeserializeIterator.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "DeserializeIterator"
+  input_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "serialized"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeserializeManySparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeserializeManySparse.pbtxt
new file mode 100644
index 0000000..f0e75d9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeserializeManySparse.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "DeserializeManySparse"
+  input_arg {
+    name: "serialized_sparse"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DeserializeSparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DeserializeSparse.pbtxt
new file mode 100644
index 0000000..c23a9b5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DeserializeSparse.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "DeserializeSparse"
+  input_arg {
+    name: "serialized_sparse"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
+op {
+  name: "DeserializeSparse"
+  input_arg {
+    name: "serialized_sparse"
+    type_attr: "Tserialized"
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "Tserialized"
+    type: "type"
+    default_value {
+      type: DT_STRING
+    }
+    allowed_values {
+      list {
+        type: DT_STRING
+        type: DT_VARIANT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DestroyResourceOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DestroyResourceOp.pbtxt
new file mode 100644
index 0000000..aa16c5a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DestroyResourceOp.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "DestroyResourceOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "ignore_lookup_error"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DestroyTemporaryVariable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DestroyTemporaryVariable.pbtxt
new file mode 100644
index 0000000..7e073b2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DestroyTemporaryVariable.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "DestroyTemporaryVariable"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "var_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Diag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Diag.pbtxt
new file mode 100644
index 0000000..92cb207
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Diag.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "Diag"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Diag"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Diag"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DiagPart.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DiagPart.pbtxt
new file mode 100644
index 0000000..aec8c87
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DiagPart.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "DiagPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "DiagPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "DiagPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Digamma.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Digamma.pbtxt
new file mode 100644
index 0000000..0c294e5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Digamma.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Digamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Digamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Digamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dilation2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dilation2D.pbtxt
new file mode 100644
index 0000000..1db8503
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Dilation2D.pbtxt
@@ -0,0 +1,224 @@
+op {
+  name: "Dilation2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropFilter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropFilter.pbtxt
new file mode 100644
index 0000000..5a5a9f1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropFilter.pbtxt
@@ -0,0 +1,240 @@
+op {
+  name: "Dilation2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "filter_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "filter_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "filter_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropFilter"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "filter_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropInput.pbtxt
new file mode 100644
index 0000000..8944211
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Dilation2DBackpropInput.pbtxt
@@ -0,0 +1,240 @@
+op {
+  name: "Dilation2DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "in_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "in_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "in_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "Dilation2DBackpropInput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "in_backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DirectedInterleaveDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000..dccdf1e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DirectedInterleaveDataset.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "DirectedInterleaveDataset"
+  input_arg {
+    name: "selector_input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "data_input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Div.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Div.pbtxt
new file mode 100644
index 0000000..6ccb981
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Div.pbtxt
@@ -0,0 +1,104 @@
+op {
+  name: "Div"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Div"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Div"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DivNoNan.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DivNoNan.pbtxt
new file mode 100644
index 0000000..17ec867
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DivNoNan.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "DivNoNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "DivNoNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxes.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxes.pbtxt
new file mode 100644
index 0000000..7298173
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxes.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "DrawBoundingBoxes"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxesV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxesV2.pbtxt
new file mode 100644
index 0000000..0a56179
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DrawBoundingBoxesV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "DrawBoundingBoxesV2"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "colors"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DynamicPartition.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DynamicPartition.pbtxt
new file mode 100644
index 0000000..3565bd6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DynamicPartition.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "DynamicPartition"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "partitions"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "outputs"
+    type_attr: "T"
+    number_attr: "num_partitions"
+  }
+  attr {
+    name: "num_partitions"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/DynamicStitch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/DynamicStitch.pbtxt
new file mode 100644
index 0000000..aba8346
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/DynamicStitch.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "DynamicStitch"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "merged"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EagerPyFunc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EagerPyFunc.pbtxt
new file mode 100644
index 0000000..84f3510
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EagerPyFunc.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "EagerPyFunc"
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "token"
+    type: "string"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EditDistance.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EditDistance.pbtxt
new file mode 100644
index 0000000..aba098b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EditDistance.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "EditDistance"
+  input_arg {
+    name: "hypothesis_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "hypothesis_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "hypothesis_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "truth_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "truth_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "truth_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "normalize"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Einsum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Einsum.pbtxt
new file mode 100644
index 0000000..3855daa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Einsum.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "Einsum"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "equation"
+    type: "string"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Elu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Elu.pbtxt
new file mode 100644
index 0000000..4b8a815
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Elu.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "Elu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Elu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Elu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EluGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EluGrad.pbtxt
new file mode 100644
index 0000000..cfbc9f9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EluGrad.pbtxt
@@ -0,0 +1,79 @@
+op {
+  name: "EluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "outputs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "EluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "outputs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "EluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "outputs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Empty.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Empty.pbtxt
new file mode 100644
index 0000000..147854b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Empty.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "Empty"
+  input_arg {
+    name: "shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "init"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EmptyTensorList.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EmptyTensorList.pbtxt
new file mode 100644
index 0000000..829a6d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EmptyTensorList.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "EmptyTensorList"
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  input_arg {
+    name: "max_num_elements"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodeBase64.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodeBase64.pbtxt
new file mode 100644
index 0000000..6e5241d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodeBase64.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "EncodeBase64"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "pad"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodeJpeg.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodeJpeg.pbtxt
new file mode 100644
index 0000000..9f3c345
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodeJpeg.pbtxt
@@ -0,0 +1,87 @@
+op {
+  name: "EncodeJpeg"
+  input_arg {
+    name: "image"
+    type: DT_UINT8
+  }
+  output_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  attr {
+    name: "format"
+    type: "string"
+    default_value {
+      s: ""
+    }
+    allowed_values {
+      list {
+        s: ""
+        s: "grayscale"
+        s: "rgb"
+      }
+    }
+  }
+  attr {
+    name: "quality"
+    type: "int"
+    default_value {
+      i: 95
+    }
+  }
+  attr {
+    name: "progressive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "optimize_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "chroma_downsampling"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "density_unit"
+    type: "string"
+    default_value {
+      s: "in"
+    }
+    allowed_values {
+      list {
+        s: "in"
+        s: "cm"
+      }
+    }
+  }
+  attr {
+    name: "x_density"
+    type: "int"
+    default_value {
+      i: 300
+    }
+  }
+  attr {
+    name: "y_density"
+    type: "int"
+    default_value {
+      i: 300
+    }
+  }
+  attr {
+    name: "xmp_metadata"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodeJpegVariableQuality.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodeJpegVariableQuality.pbtxt
new file mode 100644
index 0000000..94c41ea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodeJpegVariableQuality.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "EncodeJpegVariableQuality"
+  input_arg {
+    name: "images"
+    type: DT_UINT8
+  }
+  input_arg {
+    name: "quality"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodePng.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodePng.pbtxt
new file mode 100644
index 0000000..7d2cbd8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodePng.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "EncodePng"
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  attr {
+    name: "compression"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_UINT8
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodeProto.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodeProto.pbtxt
new file mode 100644
index 0000000..e619618
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodeProto.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "EncodeProto"
+  input_arg {
+    name: "sizes"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type_list_attr: "Tinput_types"
+  }
+  output_arg {
+    name: "bytes"
+    type: DT_STRING
+  }
+  attr {
+    name: "field_names"
+    type: "list(string)"
+  }
+  attr {
+    name: "message_type"
+    type: "string"
+  }
+  attr {
+    name: "descriptor_source"
+    type: "string"
+    default_value {
+      s: "local://"
+    }
+  }
+  attr {
+    name: "Tinput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EncodeWav.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EncodeWav.pbtxt
new file mode 100644
index 0000000..b013362
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EncodeWav.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "EncodeWav"
+  input_arg {
+    name: "audio"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingIntegerBatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingIntegerBatch.pbtxt
new file mode 100644
index 0000000..26d63b6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingIntegerBatch.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "EnqueueTPUEmbeddingIntegerBatch"
+  input_arg {
+    name: "batch"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseBatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseBatch.pbtxt
new file mode 100644
index 0000000..64b8cb5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseBatch.pbtxt
@@ -0,0 +1,127 @@
+op {
+  name: "EnqueueTPUEmbeddingSparseBatch"
+  input_arg {
+    name: "sample_indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type: DT_FLOAT
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "EnqueueTPUEmbeddingSparseBatch"
+  input_arg {
+    name: "sample_indices"
+    type_attr: "T1"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type_attr: "T2"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type_attr: "T3"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T3"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseTensorBatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseTensorBatch.pbtxt
new file mode 100644
index 0000000..14849fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EnqueueTPUEmbeddingSparseTensorBatch.pbtxt
@@ -0,0 +1,230 @@
+op {
+  name: "EnqueueTPUEmbeddingSparseTensorBatch"
+  input_arg {
+    name: "sample_indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type: DT_FLOAT
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "table_ids"
+    type: "list(int)"
+  }
+  is_stateful: true
+}
+op {
+  name: "EnqueueTPUEmbeddingSparseTensorBatch"
+  input_arg {
+    name: "sample_indices"
+    type_attr: "T1"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type_attr: "T2"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type_attr: "T3"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T3"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "table_ids"
+    type: "list(int)"
+  }
+  is_stateful: true
+}
+op {
+  name: "EnqueueTPUEmbeddingSparseTensorBatch"
+  input_arg {
+    name: "sample_indices"
+    type_attr: "T1"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type_attr: "T2"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type_attr: "T3"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T3"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "table_ids"
+    type: "list(int)"
+  }
+  attr {
+    name: "max_sequence_lengths"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EnsureShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EnsureShape.pbtxt
new file mode 100644
index 0000000..24fa558
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EnsureShape.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "EnsureShape"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Enter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Enter.pbtxt
new file mode 100644
index 0000000..d39d15f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Enter.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "Enter"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "frame_name"
+    type: "string"
+  }
+  attr {
+    name: "is_constant"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Equal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Equal.pbtxt
new file mode 100644
index 0000000..590849b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Equal.pbtxt
@@ -0,0 +1,119 @@
+op {
+  name: "Equal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Equal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Equal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Erf.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Erf.pbtxt
new file mode 100644
index 0000000..680b736
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Erf.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Erf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Erf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Erf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Erfc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Erfc.pbtxt
new file mode 100644
index 0000000..2fcfc68
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Erfc.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Erfc"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Erfc"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Erfc"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/EuclideanNorm.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/EuclideanNorm.pbtxt
new file mode 100644
index 0000000..1117fce
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/EuclideanNorm.pbtxt
@@ -0,0 +1,60 @@
+op {
+  name: "EuclideanNorm"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Exit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Exit.pbtxt
new file mode 100644
index 0000000..56a1371
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Exit.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "Exit"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Exp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Exp.pbtxt
new file mode 100644
index 0000000..7afeb67
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Exp.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Exp"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Exp"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Exp"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExpandDims.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExpandDims.pbtxt
new file mode 100644
index 0000000..c7bb353
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExpandDims.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "ExpandDims"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dim"
+    type_attr: "Tdim"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tdim"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000..8f3d58c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalAssertNextDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "transformations"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAutoShardDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAutoShardDataset.pbtxt
new file mode 100644
index 0000000..92ef9f6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalAutoShardDataset.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "ExperimentalAutoShardDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalBytesProducedStatsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalBytesProducedStatsDataset.pbtxt
new file mode 100644
index 0000000..a06fc97
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalBytesProducedStatsDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalBytesProducedStatsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalCSVDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000..54706da
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "ExperimentalCSVDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "header"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "field_delim"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "use_quote_delim"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "na_value"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "select_cols"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "record_defaults"
+    type_list_attr: "output_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalChooseFastestDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalChooseFastestDataset.pbtxt
new file mode 100644
index 0000000..4500a8a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalChooseFastestDataset.pbtxt
@@ -0,0 +1,77 @@
+op {
+  name: "ExperimentalChooseFastestDataset"
+  input_arg {
+    name: "input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "num_experiments"
+    type: "int"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ExperimentalChooseFastestDataset"
+  input_arg {
+    name: "input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "num_experiments"
+    type: "int"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetCardinality.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetCardinality.pbtxt
new file mode 100644
index 0000000..f6ba365
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetCardinality.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "ExperimentalDatasetCardinality"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "cardinality"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetToTFRecord.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetToTFRecord.pbtxt
new file mode 100644
index 0000000..0d0e46c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDatasetToTFRecord.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "ExperimentalDatasetToTFRecord"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+}
+op {
+  name: "ExperimentalDatasetToTFRecord"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDenseToSparseBatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDenseToSparseBatchDataset.pbtxt
new file mode 100644
index 0000000..886168c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDenseToSparseBatchDataset.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "ExperimentalDenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ExperimentalDenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000..e0d0dc6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "ExperimentalDirectedInterleaveDataset"
+  input_arg {
+    name: "selector_input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "data_input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByReducerDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByReducerDataset.pbtxt
new file mode 100644
index 0000000..87977da
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByReducerDataset.pbtxt
@@ -0,0 +1,76 @@
+op {
+  name: "ExperimentalGroupByReducerDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "init_func_other_arguments"
+    type_list_attr: "Tinit_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "finalize_func_other_arguments"
+    type_list_attr: "Tfinalize_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "init_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "finalize_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tinit_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tfinalize_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByWindowDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByWindowDataset.pbtxt
new file mode 100644
index 0000000..500e8eb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalGroupByWindowDataset.pbtxt
@@ -0,0 +1,125 @@
+op {
+  name: "ExperimentalGroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ExperimentalGroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000..e334e93
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "ExperimentalIgnoreErrorsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000..8e1e102
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ExperimentalIteratorGetDevice"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "device"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLMDBDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000..7f06e8d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "ExperimentalLMDBDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLatencyStatsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLatencyStatsDataset.pbtxt
new file mode 100644
index 0000000..601867b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalLatencyStatsDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalLatencyStatsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapAndBatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapAndBatchDataset.pbtxt
new file mode 100644
index 0000000..c586c33
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapAndBatchDataset.pbtxt
@@ -0,0 +1,103 @@
+op {
+  name: "ExperimentalMapAndBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ExperimentalMapAndBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapDataset.pbtxt
new file mode 100644
index 0000000..f3e13c9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMapDataset.pbtxt
@@ -0,0 +1,93 @@
+op {
+  name: "ExperimentalMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "ExperimentalMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMatchingFilesDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMatchingFilesDataset.pbtxt
new file mode 100644
index 0000000..b67c8af
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMatchingFilesDataset.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ExperimentalMatchingFilesDataset"
+  input_arg {
+    name: "patterns"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMaxIntraOpParallelismDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMaxIntraOpParallelismDataset.pbtxt
new file mode 100644
index 0000000..f6510b7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalMaxIntraOpParallelismDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalMaxIntraOpParallelismDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "max_intra_op_parallelism"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalNonSerializableDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalNonSerializableDataset.pbtxt
new file mode 100644
index 0000000..546dd09
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalNonSerializableDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "ExperimentalNonSerializableDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParallelInterleaveDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParallelInterleaveDataset.pbtxt
new file mode 100644
index 0000000..8827543
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParallelInterleaveDataset.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "ExperimentalParallelInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sloppy"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "buffer_output_elements"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "prefetch_input_elements"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParseExampleDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParseExampleDataset.pbtxt
new file mode 100644
index 0000000..6ed7e88
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalParseExampleDataset.pbtxt
@@ -0,0 +1,147 @@
+op {
+  name: "ExperimentalParseExampleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense_defaults"
+    type_list_attr: "Tdense"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tdense"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ExperimentalParseExampleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense_defaults"
+    type_list_attr: "Tdense"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tdense"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "sloppy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalPrivateThreadPoolDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalPrivateThreadPoolDataset.pbtxt
new file mode 100644
index 0000000..799dab7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalPrivateThreadPoolDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalPrivateThreadPoolDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_threads"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRandomDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRandomDataset.pbtxt
new file mode 100644
index 0000000..e3dc22c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRandomDataset.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "ExperimentalRandomDataset"
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRebatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRebatchDataset.pbtxt
new file mode 100644
index 0000000..8f56e85
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalRebatchDataset.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "ExperimentalRebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ExperimentalRebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_fallback"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "ExperimentalRebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_replicas"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_fallback"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalScanDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalScanDataset.pbtxt
new file mode 100644
index 0000000..d708354
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalScanDataset.pbtxt
@@ -0,0 +1,99 @@
+op {
+  name: "ExperimentalScanDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ExperimentalScanDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSetStatsAggregatorDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSetStatsAggregatorDataset.pbtxt
new file mode 100644
index 0000000..f360120
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSetStatsAggregatorDataset.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "ExperimentalSetStatsAggregatorDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "stats_aggregator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "counter_prefix"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSleepDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSleepDataset.pbtxt
new file mode 100644
index 0000000..19d20c7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSleepDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ExperimentalSleepDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "sleep_microseconds"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSlidingWindowDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSlidingWindowDataset.pbtxt
new file mode 100644
index 0000000..344dc28
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSlidingWindowDataset.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "ExperimentalSlidingWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "window_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "window_shift"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "window_stride"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSqlDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSqlDataset.pbtxt
new file mode 100644
index 0000000..c4663c5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalSqlDataset.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "ExperimentalSqlDataset"
+  input_arg {
+    name: "driver_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_source_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "query"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorHandle.pbtxt
new file mode 100644
index 0000000..b00cadb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorHandle.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "ExperimentalStatsAggregatorHandle"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorSummary.pbtxt
new file mode 100644
index 0000000..7886f7a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalStatsAggregatorSummary.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ExperimentalStatsAggregatorSummary"
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalTakeWhileDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalTakeWhileDataset.pbtxt
new file mode 100644
index 0000000..5e8a62a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalTakeWhileDataset.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "ExperimentalTakeWhileDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "predicate"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000..1be5fd2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "ExperimentalThreadPoolDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "thread_pool"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000..8b230f9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ExperimentalThreadPoolHandle"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "num_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_intra_op_parallelism"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "display_name"
+    type: "string"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUnbatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUnbatchDataset.pbtxt
new file mode 100644
index 0000000..ab48c84
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUnbatchDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "ExperimentalUnbatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUniqueDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000..aacdfba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "ExperimentalUniqueDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Expm1.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Expm1.pbtxt
new file mode 100644
index 0000000..b09aac4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Expm1.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Expm1"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Expm1"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Expm1"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExtractGlimpse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExtractGlimpse.pbtxt
new file mode 100644
index 0000000..597a77a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExtractGlimpse.pbtxt
@@ -0,0 +1,87 @@
+op {
+  name: "ExtractGlimpse"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "offsets"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "glimpse"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "centered"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "normalized"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "uniform_noise"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "ExtractGlimpse"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "offsets"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "glimpse"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "centered"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "normalized"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "uniform_noise"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "noise"
+    type: "string"
+    default_value {
+      s: "uniform"
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExtractImagePatches.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExtractImagePatches.pbtxt
new file mode 100644
index 0000000..ebbbd75
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExtractImagePatches.pbtxt
@@ -0,0 +1,232 @@
+op {
+  name: "ExtractImagePatches"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "patches"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksizes"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "ExtractImagePatches"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "patches"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksizes"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "ExtractImagePatches"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "patches"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksizes"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "ExtractImagePatches"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "patches"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksizes"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "rates"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExtractJpegShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExtractJpegShape.pbtxt
new file mode 100644
index 0000000..ac3d34c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExtractJpegShape.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "ExtractJpegShape"
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "image_shape"
+    type_attr: "output_type"
+  }
+  attr {
+    name: "output_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ExtractVolumePatches.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ExtractVolumePatches.pbtxt
new file mode 100644
index 0000000..09cc21a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ExtractVolumePatches.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "ExtractVolumePatches"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "patches"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksizes"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FFT.pbtxt
new file mode 100644
index 0000000..e986f32
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FFT.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "FFT"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "FFT"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FFT2D.pbtxt
new file mode 100644
index 0000000..adb1c25
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FFT2D.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "FFT2D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "FFT2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FFT3D.pbtxt
new file mode 100644
index 0000000..9266d6d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FFT3D.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "FFT3D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "FFT3D"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FIFOQueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FIFOQueue.pbtxt
new file mode 100644
index 0000000..c3321a8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FIFOQueue.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "FIFOQueue"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FIFOQueueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FIFOQueueV2.pbtxt
new file mode 100644
index 0000000..9b1c840
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FIFOQueueV2.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "FIFOQueueV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Fact.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Fact.pbtxt
new file mode 100644
index 0000000..90a0ad8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Fact.pbtxt
@@ -0,0 +1,7 @@
+op {
+  name: "Fact"
+  output_arg {
+    name: "fact"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeParam.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeParam.pbtxt
new file mode 100644
index 0000000..dc2a7c5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeParam.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "FakeParam"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgs.pbtxt
new file mode 100644
index 0000000..2d8eac8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgs.pbtxt
@@ -0,0 +1,96 @@
+op {
+  name: "FakeQuantWithMinMaxArgs"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxArgs"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxArgs"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgsGradient.pbtxt
new file mode 100644
index 0000000..5d02f59
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxArgsGradient.pbtxt
@@ -0,0 +1,108 @@
+op {
+  name: "FakeQuantWithMinMaxArgsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxArgsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxArgsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "min"
+    type: "float"
+    default_value {
+      f: -6
+    }
+  }
+  attr {
+    name: "max"
+    type: "float"
+    default_value {
+      f: 6
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVars.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVars.pbtxt
new file mode 100644
index 0000000..233f5cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVars.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "FakeQuantWithMinMaxVars"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVars"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVars"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsGradient.pbtxt
new file mode 100644
index 0000000..cf8ed6f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsGradient.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "FakeQuantWithMinMaxVarsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannel.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannel.pbtxt
new file mode 100644
index 0000000..551ae79
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannel.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannel"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannel"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannel"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
new file mode 100644
index 0000000..a787e25
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+}
+op {
+  name: "FakeQuantWithMinMaxVarsPerChannelGradient"
+  input_arg {
+    name: "gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprops_wrt_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "backprop_wrt_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FakeQueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FakeQueue.pbtxt
new file mode 100644
index 0000000..5e4cb62
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FakeQueue.pbtxt
@@ -0,0 +1,13 @@
+op {
+  name: "FakeQueue"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Fill.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Fill.pbtxt
new file mode 100644
index 0000000..543ae42
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Fill.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "Fill"
+  input_arg {
+    name: "dims"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "Fill"
+  input_arg {
+    name: "dims"
+    type_attr: "index_type"
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "index_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FilterByLastComponentDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FilterByLastComponentDataset.pbtxt
new file mode 100644
index 0000000..d1e814c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FilterByLastComponentDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "FilterByLastComponentDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "output"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FilterDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FilterDataset.pbtxt
new file mode 100644
index 0000000..217e420
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FilterDataset.pbtxt
@@ -0,0 +1,73 @@
+op {
+  name: "FilterDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "predicate"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "FilterDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "predicate"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Fingerprint.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Fingerprint.pbtxt
new file mode 100644
index 0000000..3a55857
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Fingerprint.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "Fingerprint"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "method"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "fingerprint"
+    type: DT_UINT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDataset.pbtxt
new file mode 100644
index 0000000..c743db8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDataset.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "FixedLengthRecordDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "header_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "record_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "footer_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDatasetV2.pbtxt
new file mode 100644
index 0000000..cb9b65a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordDatasetV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "FixedLengthRecordDatasetV2"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "header_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "record_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "footer_bytes"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReader.pbtxt
new file mode 100644
index 0000000..75b6018
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReader.pbtxt
@@ -0,0 +1,140 @@
+op {
+  name: "FixedLengthRecordReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "FixedLengthRecordReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "hop_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "FixedLengthRecordReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "hop_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReaderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReaderV2.pbtxt
new file mode 100644
index 0000000..b16e522
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FixedLengthRecordReaderV2.pbtxt
@@ -0,0 +1,141 @@
+op {
+  name: "FixedLengthRecordReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "FixedLengthRecordReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "hop_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "FixedLengthRecordReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "header_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "record_bytes"
+    type: "int"
+  }
+  attr {
+    name: "footer_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "hop_bytes"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "encoding"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FixedUnigramCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FixedUnigramCandidateSampler.pbtxt
new file mode 100644
index 0000000..a791134
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FixedUnigramCandidateSampler.pbtxt
@@ -0,0 +1,203 @@
+op {
+  name: "FixedUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "vocab_file"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "distortion"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "num_reserved_ids"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+    default_value {
+      i: 1
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shard"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "unigrams"
+    type: "list(float)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "FixedUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "vocab_file"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "distortion"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "num_reserved_ids"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+    default_value {
+      i: 1
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shard"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "unigrams"
+    type: "list(float)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FlatMapDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FlatMapDataset.pbtxt
new file mode 100644
index 0000000..7dd76ee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FlatMapDataset.pbtxt
@@ -0,0 +1,73 @@
+op {
+  name: "FlatMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "FlatMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Floor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Floor.pbtxt
new file mode 100644
index 0000000..27e405e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Floor.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Floor"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Floor"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Floor"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FloorDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FloorDiv.pbtxt
new file mode 100644
index 0000000..08f232b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FloorDiv.pbtxt
@@ -0,0 +1,104 @@
+op {
+  name: "FloorDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "FloorDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "FloorDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FloorMod.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FloorMod.pbtxt
new file mode 100644
index 0000000..c53a6c8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FloorMod.pbtxt
@@ -0,0 +1,84 @@
+op {
+  name: "FloorMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "FloorMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "FloorMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FlushSummaryWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FlushSummaryWriter.pbtxt
new file mode 100644
index 0000000..f928d4a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FlushSummaryWriter.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "FlushSummaryWriter"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/For.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/For.pbtxt
new file mode 100644
index 0000000..139990f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/For.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "For"
+  input_arg {
+    name: "start"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "limit"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "delta"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPool.pbtxt
new file mode 100644
index 0000000..5fc527b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPool.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "FractionalAvgPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "row_pooling_sequence"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "col_pooling_sequence"
+    type: DT_INT64
+  }
+  attr {
+    name: "pooling_ratio"
+    type: "list(float)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "pseudo_random"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "overlapping"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "deterministic"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPoolGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPoolGrad.pbtxt
new file mode 100644
index 0000000..cceb2fe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FractionalAvgPoolGrad.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "FractionalAvgPoolGrad"
+  input_arg {
+    name: "orig_input_tensor_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "row_pooling_sequence"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "col_pooling_sequence"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "overlapping"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPool.pbtxt
new file mode 100644
index 0000000..a11b4ef
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPool.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "FractionalMaxPool"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "row_pooling_sequence"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "col_pooling_sequence"
+    type: DT_INT64
+  }
+  attr {
+    name: "pooling_ratio"
+    type: "list(float)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "pseudo_random"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "overlapping"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "deterministic"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPoolGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPoolGrad.pbtxt
new file mode 100644
index 0000000..711e98a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FractionalMaxPoolGrad.pbtxt
@@ -0,0 +1,46 @@
+op {
+  name: "FractionalMaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "out_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "row_pooling_sequence"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "col_pooling_sequence"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "overlapping"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNorm.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNorm.pbtxt
new file mode 100644
index 0000000..9f30c2a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNorm.pbtxt
@@ -0,0 +1,79 @@
+op {
+  name: "FusedBatchNorm"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "offset"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "mean"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "variance"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "batch_mean"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "batch_variance"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space_1"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space_2"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGrad.pbtxt
new file mode 100644
index 0000000..bff7eec
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGrad.pbtxt
@@ -0,0 +1,79 @@
+op {
+  name: "FusedBatchNormGrad"
+  input_arg {
+    name: "y_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space_1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reserve_space_2"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "x_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "scale_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "offset_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space_3"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "reserve_space_4"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV2.pbtxt
new file mode 100644
index 0000000..dea20af
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV2.pbtxt
@@ -0,0 +1,90 @@
+op {
+  name: "FusedBatchNormGradV2"
+  input_arg {
+    name: "y_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "reserve_space_1"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "reserve_space_2"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "x_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "scale_backprop"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "offset_backprop"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_3"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_4"
+    type_attr: "U"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "U"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV3.pbtxt
new file mode 100644
index 0000000..b1576ff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormGradV3.pbtxt
@@ -0,0 +1,94 @@
+op {
+  name: "FusedBatchNormGradV3"
+  input_arg {
+    name: "y_backprop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "reserve_space_1"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "reserve_space_2"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "reserve_space_3"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "x_backprop"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "scale_backprop"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "offset_backprop"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_4"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_5"
+    type_attr: "U"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "U"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV2.pbtxt
new file mode 100644
index 0000000..170a90a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV2.pbtxt
@@ -0,0 +1,90 @@
+op {
+  name: "FusedBatchNormV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "offset"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "mean"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "variance"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "batch_mean"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "batch_variance"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_1"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_2"
+    type_attr: "U"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "U"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV3.pbtxt
new file mode 100644
index 0000000..f79e493
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedBatchNormV3.pbtxt
@@ -0,0 +1,94 @@
+op {
+  name: "FusedBatchNormV3"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "offset"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "mean"
+    type_attr: "U"
+  }
+  input_arg {
+    name: "variance"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "batch_mean"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "batch_variance"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_1"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_2"
+    type_attr: "U"
+  }
+  output_arg {
+    name: "reserve_space_3"
+    type_attr: "U"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "U"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+    default_value {
+      f: 0.0001
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "is_training"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedPadConv2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedPadConv2D.pbtxt
new file mode 100644
index 0000000..7dc3eec
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedPadConv2D.pbtxt
@@ -0,0 +1,106 @@
+op {
+  name: "FusedPadConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "FusedPadConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/FusedResizeAndPadConv2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/FusedResizeAndPadConv2D.pbtxt
new file mode 100644
index 0000000..cfc716f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/FusedResizeAndPadConv2D.pbtxt
@@ -0,0 +1,128 @@
+op {
+  name: "FusedResizeAndPadConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "paddings"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "resize_align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "FusedResizeAndPadConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "paddings"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "resize_align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCell.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCell.pbtxt
new file mode 100644
index 0000000..7c0dd9d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCell.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "GRUBlockCell"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w_ru"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_ru"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "u"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "c"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "h"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCellGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCellGrad.pbtxt
new file mode 100644
index 0000000..723bcbd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GRUBlockCellGrad.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "GRUBlockCellGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w_ru"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_ru"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "r"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "u"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "c"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "d_h"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_h_prev"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_c_bar"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_r_bar_u_bar"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Gather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Gather.pbtxt
new file mode 100644
index 0000000..264a836
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Gather.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "Gather"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GatherNd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GatherNd.pbtxt
new file mode 100644
index 0000000..43b7d3e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GatherNd.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "GatherNd"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GatherV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GatherV2.pbtxt
new file mode 100644
index 0000000..bec3fa9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GatherV2.pbtxt
@@ -0,0 +1,93 @@
+op {
+  name: "GatherV2"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "GatherV2"
+  input_arg {
+    name: "params"
+    type_attr: "Tparams"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tparams"
+  }
+  attr {
+    name: "batch_dims"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "Tparams"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GenerateVocabRemapping.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GenerateVocabRemapping.pbtxt
new file mode 100644
index 0000000..a095253
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GenerateVocabRemapping.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "GenerateVocabRemapping"
+  input_arg {
+    name: "new_vocab_file"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "old_vocab_file"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "remapping"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "num_present"
+    type: DT_INT32
+  }
+  attr {
+    name: "new_vocab_offset"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_new_vocab"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
+  name: "GenerateVocabRemapping"
+  input_arg {
+    name: "new_vocab_file"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "old_vocab_file"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "remapping"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "num_present"
+    type: DT_INT32
+  }
+  attr {
+    name: "new_vocab_offset"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_new_vocab"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "old_vocab_size"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GeneratorDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GeneratorDataset.pbtxt
new file mode 100644
index 0000000..86d75b2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GeneratorDataset.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "GeneratorDataset"
+  input_arg {
+    name: "init_func_other_args"
+    type_list_attr: "Tinit_func_args"
+  }
+  input_arg {
+    name: "next_func_other_args"
+    type_list_attr: "Tnext_func_args"
+  }
+  input_arg {
+    name: "finalize_func_other_args"
+    type_list_attr: "Tfinalize_func_args"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "init_func"
+    type: "func"
+  }
+  attr {
+    name: "next_func"
+    type: "func"
+  }
+  attr {
+    name: "finalize_func"
+    type: "func"
+  }
+  attr {
+    name: "Tinit_func_args"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tnext_func_args"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tfinalize_func_args"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandle.pbtxt
new file mode 100644
index 0000000..e5345ec
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandle.pbtxt
@@ -0,0 +1,64 @@
+op {
+  name: "GetSessionHandle"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "GetSessionHandle"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 23
+  }
+}
+op {
+  name: "GetSessionHandle"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "GetSessionHandle"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandleV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandleV2.pbtxt
new file mode 100644
index 0000000..6040523
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GetSessionHandleV2.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "GetSessionHandleV2"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GetSessionTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GetSessionTensor.pbtxt
new file mode 100644
index 0000000..5c4cf8a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GetSessionTensor.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "GetSessionTensor"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
+op {
+  name: "GetSessionTensor"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Greater.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Greater.pbtxt
new file mode 100644
index 0000000..8860e3c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Greater.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "Greater"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Greater"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Greater"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Greater"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GreaterEqual.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GreaterEqual.pbtxt
new file mode 100644
index 0000000..5bcdd37
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GreaterEqual.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "GreaterEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "GreaterEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "GreaterEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "GreaterEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GroupByReducerDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GroupByReducerDataset.pbtxt
new file mode 100644
index 0000000..412cdab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GroupByReducerDataset.pbtxt
@@ -0,0 +1,76 @@
+op {
+  name: "GroupByReducerDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "init_func_other_arguments"
+    type_list_attr: "Tinit_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "finalize_func_other_arguments"
+    type_list_attr: "Tfinalize_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "init_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "finalize_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tinit_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tfinalize_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GroupByWindowDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GroupByWindowDataset.pbtxt
new file mode 100644
index 0000000..5c07855
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GroupByWindowDataset.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "GroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/GuaranteeConst.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/GuaranteeConst.pbtxt
new file mode 100644
index 0000000..71d47e3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/GuaranteeConst.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "GuaranteeConst"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HSVToRGB.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HSVToRGB.pbtxt
new file mode 100644
index 0000000..2b209cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HSVToRGB.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "HSVToRGB"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "HSVToRGB"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HashTable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HashTable.pbtxt
new file mode 100644
index 0000000..83afe2b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HashTable.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "HashTable"
+  output_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HashTableV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HashTableV2.pbtxt
new file mode 100644
index 0000000..24a9bc7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HashTableV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "HashTableV2"
+  output_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HistogramFixedWidth.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HistogramFixedWidth.pbtxt
new file mode 100644
index 0000000..f39eabe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HistogramFixedWidth.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "HistogramFixedWidth"
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "value_range"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "nbins"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "out"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HistogramSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HistogramSummary.pbtxt
new file mode 100644
index 0000000..0c46f39
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HistogramSummary.pbtxt
@@ -0,0 +1,148 @@
+op {
+  name: "HistogramSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "HistogramSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "HistogramSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "HistogramSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/HostConst.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/HostConst.pbtxt
new file mode 100644
index 0000000..6dd4c17
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/HostConst.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "HostConst"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "value"
+    type: "tensor"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IFFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IFFT.pbtxt
new file mode 100644
index 0000000..8571a13
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IFFT.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "IFFT"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "IFFT"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IFFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IFFT2D.pbtxt
new file mode 100644
index 0000000..0b208d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IFFT2D.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "IFFT2D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "IFFT2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IFFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IFFT3D.pbtxt
new file mode 100644
index 0000000..8b9667f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IFFT3D.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "IFFT3D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
+op {
+  name: "IFFT3D"
+  input_arg {
+    name: "input"
+    type_attr: "Tcomplex"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tcomplex"
+  }
+  attr {
+    name: "Tcomplex"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IRFFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IRFFT.pbtxt
new file mode 100644
index 0000000..0975c35
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IRFFT.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "IRFFT"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IRFFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IRFFT2D.pbtxt
new file mode 100644
index 0000000..b850a6a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IRFFT2D.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "IRFFT2D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IRFFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IRFFT3D.pbtxt
new file mode 100644
index 0000000..1cc8666
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IRFFT3D.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "IRFFT3D"
+  input_arg {
+    name: "input"
+    type: DT_COMPLEX64
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Identity.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Identity.pbtxt
new file mode 100644
index 0000000..f3ca3db
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Identity.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "Identity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IdentityN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IdentityN.pbtxt
new file mode 100644
index 0000000..61c3b63
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IdentityN.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "IdentityN"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IdentityReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IdentityReader.pbtxt
new file mode 100644
index 0000000..3330154
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IdentityReader.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "IdentityReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "IdentityReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IdentityReaderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IdentityReaderV2.pbtxt
new file mode 100644
index 0000000..f37e9ce
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IdentityReaderV2.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "IdentityReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/If.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/If.pbtxt
new file mode 100644
index 0000000..7ccb12a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/If.pbtxt
@@ -0,0 +1,198 @@
+op {
+  name: "If"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+}
+op {
+  name: "If"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+}
+op {
+  name: "If"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+}
+op {
+  name: "If"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+  is_stateful: true
+}
+op {
+  name: "If"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Igamma.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Igamma.pbtxt
new file mode 100644
index 0000000..822871d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Igamma.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "Igamma"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IgammaGradA.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IgammaGradA.pbtxt
new file mode 100644
index 0000000..964067d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IgammaGradA.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "IgammaGradA"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Igammac.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Igammac.pbtxt
new file mode 100644
index 0000000..46254f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Igammac.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "Igammac"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IgnoreErrorsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000..0670fd6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IgnoreErrorsDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "IgnoreErrorsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Imag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Imag.pbtxt
new file mode 100644
index 0000000..1444b0c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Imag.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "Imag"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ImageSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ImageSummary.pbtxt
new file mode 100644
index 0000000..fafd717
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ImageSummary.pbtxt
@@ -0,0 +1,113 @@
+op {
+  name: "ImageSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "max_images"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "bad_color"
+    type: "tensor"
+    default_value {
+      tensor {
+        dtype: DT_UINT8
+        tensor_shape {
+          dim {
+            size: 4
+          }
+        }
+        int_val: 255
+        int_val: 0
+        int_val: 0
+        int_val: 255
+      }
+    }
+  }
+}
+op {
+  name: "ImageSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "max_images"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "bad_color"
+    type: "tensor"
+    default_value {
+      tensor {
+        dtype: DT_UINT8
+        tensor_shape {
+          dim {
+            size: 4
+          }
+        }
+        int_val: 255
+        int_val: 0
+        int_val: 0
+        int_val: 255
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ImmutableConst.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ImmutableConst.pbtxt
new file mode 100644
index 0000000..ba11809
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ImmutableConst.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "ImmutableConst"
+  output_arg {
+    name: "tensor"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "memory_region_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ImportEvent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ImportEvent.pbtxt
new file mode 100644
index 0000000..7be31dd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ImportEvent.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ImportEvent"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "event"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InTopK.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InTopK.pbtxt
new file mode 100644
index 0000000..6acd3b6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InTopK.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "InTopK"
+  input_arg {
+    name: "predictions"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "targets"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "precision"
+    type: DT_BOOL
+  }
+  attr {
+    name: "k"
+    type: "int"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InTopKV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InTopKV2.pbtxt
new file mode 100644
index 0000000..a6ca2b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InTopKV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "InTopKV2"
+  input_arg {
+    name: "predictions"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "targets"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "precision"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeue.pbtxt
new file mode 100644
index 0000000..a48d840
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeue.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "InfeedDequeue"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeueTuple.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeueTuple.pbtxt
new file mode 100644
index 0000000..dc6ab2b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InfeedDequeueTuple.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "InfeedDequeueTuple"
+  output_arg {
+    name: "outputs"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueue.pbtxt
new file mode 100644
index 0000000..759b914
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueue.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "InfeedEnqueue"
+  input_arg {
+    name: "input"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  attr {
+    name: "layout"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueuePrelinearizedBuffer.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueuePrelinearizedBuffer.pbtxt
new file mode 100644
index 0000000..d281b70
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueuePrelinearizedBuffer.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "InfeedEnqueuePrelinearizedBuffer"
+  input_arg {
+    name: "input"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueueTuple.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueueTuple.pbtxt
new file mode 100644
index 0000000..459c5d9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InfeedEnqueueTuple.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "InfeedEnqueueTuple"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+  }
+  attr {
+    name: "layouts"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InitializeTable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InitializeTable.pbtxt
new file mode 100644
index 0000000..35a46a9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InitializeTable.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "InitializeTable"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tkey"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tval"
+  }
+  attr {
+    name: "Tkey"
+    type: "type"
+  }
+  attr {
+    name: "Tval"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFile.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFile.pbtxt
new file mode 100644
index 0000000..c4de3da
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFile.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "InitializeTableFromTextFile"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  attr {
+    name: "key_index"
+    type: "int"
+    has_minimum: true
+    minimum: -2
+  }
+  attr {
+    name: "value_index"
+    type: "int"
+    has_minimum: true
+    minimum: -2
+  }
+  attr {
+    name: "vocab_size"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "delimiter"
+    type: "string"
+    default_value {
+      s: "\t"
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFileV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFileV2.pbtxt
new file mode 100644
index 0000000..0096e94
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableFromTextFileV2.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "InitializeTableFromTextFileV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  attr {
+    name: "key_index"
+    type: "int"
+    has_minimum: true
+    minimum: -2
+  }
+  attr {
+    name: "value_index"
+    type: "int"
+    has_minimum: true
+    minimum: -2
+  }
+  attr {
+    name: "vocab_size"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "delimiter"
+    type: "string"
+    default_value {
+      s: "\t"
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InitializeTableV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableV2.pbtxt
new file mode 100644
index 0000000..62c5659
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InitializeTableV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "InitializeTableV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tkey"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tval"
+  }
+  attr {
+    name: "Tkey"
+    type: "type"
+  }
+  attr {
+    name: "Tval"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InplaceAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InplaceAdd.pbtxt
new file mode 100644
index 0000000..7c66857
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InplaceAdd.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "InplaceAdd"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "i"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InplaceSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InplaceSub.pbtxt
new file mode 100644
index 0000000..42d6c14
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InplaceSub.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "InplaceSub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "i"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InplaceUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InplaceUpdate.pbtxt
new file mode 100644
index 0000000..94b7f24
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InplaceUpdate.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "InplaceUpdate"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "i"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InterleaveDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InterleaveDataset.pbtxt
new file mode 100644
index 0000000..ac36177
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InterleaveDataset.pbtxt
@@ -0,0 +1,89 @@
+op {
+  name: "InterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "InterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Inv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Inv.pbtxt
new file mode 100644
index 0000000..ca20866
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Inv.pbtxt
@@ -0,0 +1,170 @@
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InvGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InvGrad.pbtxt
new file mode 100644
index 0000000..af882a9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InvGrad.pbtxt
@@ -0,0 +1,213 @@
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 17
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "InvGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Invert.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Invert.pbtxt
new file mode 100644
index 0000000..cd9c812
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Invert.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "Invert"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+}
+op {
+  name: "Invert"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/InvertPermutation.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/InvertPermutation.pbtxt
new file mode 100644
index 0000000..fa02896
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/InvertPermutation.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "InvertPermutation"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesEnsembleInitialized.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesEnsembleInitialized.pbtxt
new file mode 100644
index 0000000..1b19fef
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesEnsembleInitialized.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "IsBoostedTreesEnsembleInitialized"
+  input_arg {
+    name: "tree_ensemble_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000..359e0e9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "IsBoostedTreesQuantileStreamResourceInitialized"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsFinite.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsFinite.pbtxt
new file mode 100644
index 0000000..8410dce
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsFinite.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "IsFinite"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsFinite"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsFinite"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsInf.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsInf.pbtxt
new file mode 100644
index 0000000..1ce6c74
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsInf.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "IsInf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsInf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsInf"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsNan.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsNan.pbtxt
new file mode 100644
index 0000000..826f2ff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsNan.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "IsNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "IsNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IsVariableInitialized.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IsVariableInitialized.pbtxt
new file mode 100644
index 0000000..03496db
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IsVariableInitialized.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "IsVariableInitialized"
+  input_arg {
+    name: "ref"
+    type_attr: "dtype"
+    is_ref: true
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Iterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Iterator.pbtxt
new file mode 100644
index 0000000..76b9fde
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Iterator.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Iterator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+  }
+  attr {
+    name: "container"
+    type: "string"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000..ebd3437
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandle.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "IteratorFromStringHandle"
+  input_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
+op {
+  name: "IteratorFromStringHandle"
+  input_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandleV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandleV2.pbtxt
new file mode 100644
index 0000000..624c473
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorFromStringHandleV2.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "IteratorFromStringHandleV2"
+  input_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorGetDevice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetDevice.pbtxt
new file mode 100644
index 0000000..8d379c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetDevice.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "IteratorGetDevice"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "device"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNext.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNext.pbtxt
new file mode 100644
index 0000000..f204011
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNext.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "IteratorGetNext"
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextAsOptional.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextAsOptional.pbtxt
new file mode 100644
index 0000000..4c13586
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextAsOptional.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "IteratorGetNextAsOptional"
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "optional"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextSync.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextSync.pbtxt
new file mode 100644
index 0000000..e1a7351
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorGetNextSync.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "IteratorGetNextSync"
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorToStringHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorToStringHandle.pbtxt
new file mode 100644
index 0000000..87f2dff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorToStringHandle.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "IteratorToStringHandle"
+  input_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/IteratorV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/IteratorV2.pbtxt
new file mode 100644
index 0000000..6f7ab70
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/IteratorV2.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "IteratorV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+  }
+  attr {
+    name: "container"
+    type: "string"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/KMC2ChainInitialization.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/KMC2ChainInitialization.pbtxt
new file mode 100644
index 0000000..e964097
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/KMC2ChainInitialization.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "KMC2ChainInitialization"
+  input_arg {
+    name: "distances"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "index"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/KmeansPlusPlusInitialization.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/KmeansPlusPlusInitialization.pbtxt
new file mode 100644
index 0000000..27ab4b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/KmeansPlusPlusInitialization.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "KmeansPlusPlusInitialization"
+  input_arg {
+    name: "points"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_to_sample"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_retries_per_sample"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "samples"
+    type: DT_FLOAT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/L2Loss.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/L2Loss.pbtxt
new file mode 100644
index 0000000..90e8619
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/L2Loss.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "L2Loss"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "L2Loss"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "L2Loss"
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LMDBDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LMDBDataset.pbtxt
new file mode 100644
index 0000000..ff42e6f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LMDBDataset.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LMDBDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LMDBReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LMDBReader.pbtxt
new file mode 100644
index 0000000..967c74b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LMDBReader.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "LMDBReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LRN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LRN.pbtxt
new file mode 100644
index 0000000..7588068
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LRN.pbtxt
@@ -0,0 +1,105 @@
+op {
+  name: "LRN"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "depth_radius"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "beta"
+    type: "float"
+    default_value {
+      f: 0.5
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "LRN"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "depth_radius"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "beta"
+    type: "float"
+    default_value {
+      f: 0.5
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LRNGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LRNGrad.pbtxt
new file mode 100644
index 0000000..37db775
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LRNGrad.pbtxt
@@ -0,0 +1,121 @@
+op {
+  name: "LRNGrad"
+  input_arg {
+    name: "input_grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "depth_radius"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "beta"
+    type: "float"
+    default_value {
+      f: 0.5
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "LRNGrad"
+  input_arg {
+    name: "input_grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "output_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "depth_radius"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "beta"
+    type: "float"
+    default_value {
+      f: 0.5
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCell.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCell.pbtxt
new file mode 100644
index 0000000..f1071f7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCell.pbtxt
@@ -0,0 +1,94 @@
+op {
+  name: "LSTMBlockCell"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wcf"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wco"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "i"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "cs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "f"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "o"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "ci"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "co"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "h"
+    type_attr: "T"
+  }
+  attr {
+    name: "forget_bias"
+    type: "float"
+    default_value {
+      f: 1
+    }
+  }
+  attr {
+    name: "cell_clip"
+    type: "float"
+    default_value {
+      f: 3
+    }
+  }
+  attr {
+    name: "use_peephole"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCellGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCellGrad.pbtxt
new file mode 100644
index 0000000..b20d47c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LSTMBlockCellGrad.pbtxt
@@ -0,0 +1,101 @@
+op {
+  name: "LSTMBlockCellGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_prev"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "w"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wcf"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "wco"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "i"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "f"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "o"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ci"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "co"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "cs_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "h_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "cs_prev_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dicfo"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wci_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wcf_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "wco_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "use_peephole"
+    type: "bool"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LatencyStatsDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LatencyStatsDataset.pbtxt
new file mode 100644
index 0000000..2459e86
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LatencyStatsDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "LatencyStatsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LeakyRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LeakyRelu.pbtxt
new file mode 100644
index 0000000..c0358f9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LeakyRelu.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "LeakyRelu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 0.2
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "LeakyRelu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 0.2
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LeakyReluGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LeakyReluGrad.pbtxt
new file mode 100644
index 0000000..7868722
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LeakyReluGrad.pbtxt
@@ -0,0 +1,73 @@
+op {
+  name: "LeakyReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 0.2
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "LeakyReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "alpha"
+    type: "float"
+    default_value {
+      f: 0.2
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LearnedUnigramCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LearnedUnigramCandidateSampler.pbtxt
new file mode 100644
index 0000000..71466c5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LearnedUnigramCandidateSampler.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "LearnedUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "LearnedUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LeftShift.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LeftShift.pbtxt
new file mode 100644
index 0000000..c3f56be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LeftShift.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "LeftShift"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "LeftShift"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Less.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Less.pbtxt
new file mode 100644
index 0000000..e4f1245
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Less.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "Less"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Less"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Less"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Less"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LessEqual.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LessEqual.pbtxt
new file mode 100644
index 0000000..9162a68
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LessEqual.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "LessEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "LessEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "LessEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "LessEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Lgamma.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Lgamma.pbtxt
new file mode 100644
index 0000000..fcb0241
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Lgamma.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "Lgamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Lgamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Lgamma"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LinSpace.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LinSpace.pbtxt
new file mode 100644
index 0000000..931c751
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LinSpace.pbtxt
@@ -0,0 +1,85 @@
+op {
+  name: "LinSpace"
+  input_arg {
+    name: "start"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "stop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "LinSpace"
+  input_arg {
+    name: "start"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "stop"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ListDiff.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ListDiff.pbtxt
new file mode 100644
index 0000000..39c3ee8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ListDiff.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "ListDiff"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadAndRemapMatrix.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadAndRemapMatrix.pbtxt
new file mode 100644
index 0000000..54b4a68
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadAndRemapMatrix.pbtxt
@@ -0,0 +1,46 @@
+op {
+  name: "LoadAndRemapMatrix"
+  input_arg {
+    name: "ckpt_path"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "old_tensor_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "row_remapping"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "col_remapping"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "initializing_values"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_matrix"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_rows"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_cols"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "max_rows_in_memory"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParameters.pbtxt
new file mode 100644
index 0000000..bdf33d6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingADAMParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocities"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..a033105
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingADAMParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingADAMParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocities"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParameters.pbtxt
new file mode 100644
index 0000000..02136e5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingAdadeltaParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "updates"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..32485bf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "updates"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParameters.pbtxt
new file mode 100644
index 0000000..e40d457
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "LoadTPUEmbeddingAdagradParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..ba403c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingAdagradParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingCenteredRMSPropParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingCenteredRMSPropParameters.pbtxt
new file mode 100644
index 0000000..36280a5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingCenteredRMSPropParameters.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingCenteredRMSPropParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "mg"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParameters.pbtxt
new file mode 100644
index 0000000..8785f4e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingFTRLParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linears"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..640801b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linears"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMDLAdagradLightParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMDLAdagradLightParameters.pbtxt
new file mode 100644
index 0000000..2b86a8e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMDLAdagradLightParameters.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingMDLAdagradLightParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "benefits"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParameters.pbtxt
new file mode 100644
index 0000000..1622c9a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "LoadTPUEmbeddingMomentumParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..fe66f27
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingMomentumParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParameters.pbtxt
new file mode 100644
index 0000000..75a3ca5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "LoadTPUEmbeddingProximalAdagradParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..58ea405
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParameters.pbtxt
new file mode 100644
index 0000000..2867fda
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "LoadTPUEmbeddingRMSPropParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..506e17e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingStochasticGradientDescentParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingStochasticGradientDescentParameters.pbtxt
new file mode 100644
index 0000000..2c69b16
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoadTPUEmbeddingStochasticGradientDescentParameters.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "LoadTPUEmbeddingStochasticGradientDescentParameters"
+  input_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Log.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Log.pbtxt
new file mode 100644
index 0000000..a16862c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Log.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Log"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Log"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Log"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Log1p.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Log1p.pbtxt
new file mode 100644
index 0000000..1f8ba12
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Log1p.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Log1p"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Log1p"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Log1p"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogMatrixDeterminant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogMatrixDeterminant.pbtxt
new file mode 100644
index 0000000..3807cdd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogMatrixDeterminant.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "LogMatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sign"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "log_abs_determinant"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "LogMatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sign"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "log_abs_determinant"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogSoftmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogSoftmax.pbtxt
new file mode 100644
index 0000000..92d2727
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogSoftmax.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "LogSoftmax"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "logsoftmax"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "LogSoftmax"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "logsoftmax"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogUniformCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogUniformCandidateSampler.pbtxt
new file mode 100644
index 0000000..9ec4557
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogUniformCandidateSampler.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "LogUniformCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "LogUniformCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogicalAnd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogicalAnd.pbtxt
new file mode 100644
index 0000000..b10b115
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogicalAnd.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "LogicalAnd"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogicalNot.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogicalNot.pbtxt
new file mode 100644
index 0000000..5cf13ad
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogicalNot.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "LogicalNot"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LogicalOr.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LogicalOr.pbtxt
new file mode 100644
index 0000000..635a66d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LogicalOr.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "LogicalOr"
+  input_arg {
+    name: "x"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "y"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableExport.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableExport.pbtxt
new file mode 100644
index 0000000..6c56cde
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableExport.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableExport"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "keys"
+    type_attr: "Tkeys"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "Tvalues"
+  }
+  attr {
+    name: "Tkeys"
+    type: "type"
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableExportV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableExportV2.pbtxt
new file mode 100644
index 0000000..b86fd3a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableExportV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableExportV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "keys"
+    type_attr: "Tkeys"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "Tvalues"
+  }
+  attr {
+    name: "Tkeys"
+    type: "type"
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableFind.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableFind.pbtxt
new file mode 100644
index 0000000..5923b50
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableFind.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "LookupTableFind"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "Tout"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableFindV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableFindV2.pbtxt
new file mode 100644
index 0000000..53cbafb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableFindV2.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "LookupTableFindV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "Tout"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableImport.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableImport.pbtxt
new file mode 100644
index 0000000..73b53a5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableImport.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableImport"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableImportV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableImportV2.pbtxt
new file mode 100644
index 0000000..41c03b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableImportV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableImportV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsert.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsert.pbtxt
new file mode 100644
index 0000000..b96cb47
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsert.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableInsert"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsertV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsertV2.pbtxt
new file mode 100644
index 0000000..19d7d49
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableInsertV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "LookupTableInsertV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableRemoveV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableRemoveV2.pbtxt
new file mode 100644
index 0000000..d7fe0bb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableRemoveV2.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "LookupTableRemoveV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "keys"
+    type_attr: "Tin"
+  }
+  attr {
+    name: "Tin"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableSize.pbtxt
new file mode 100644
index 0000000..0d4bf61
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableSize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "LookupTableSize"
+  input_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LookupTableSizeV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LookupTableSizeV2.pbtxt
new file mode 100644
index 0000000..511beed
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LookupTableSizeV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "LookupTableSizeV2"
+  input_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LoopCond.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LoopCond.pbtxt
new file mode 100644
index 0000000..7111fff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LoopCond.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "LoopCond"
+  input_arg {
+    name: "input"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/LowerBound.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/LowerBound.pbtxt
new file mode 100644
index 0000000..b7d1dee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/LowerBound.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "LowerBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Lu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Lu.pbtxt
new file mode 100644
index 0000000..59c28e0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Lu.pbtxt
@@ -0,0 +1,81 @@
+op {
+  name: "Lu"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "lu"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "p"
+    type_attr: "output_idx_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "output_idx_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Lu"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "lu"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "p"
+    type_attr: "output_idx_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "output_idx_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MakeIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MakeIterator.pbtxt
new file mode 100644
index 0000000..b11c2b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MakeIterator.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "MakeIterator"
+  input_arg {
+    name: "dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapAndBatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapAndBatchDataset.pbtxt
new file mode 100644
index 0000000..2f6b79a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapAndBatchDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "MapAndBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapClear.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapClear.pbtxt
new file mode 100644
index 0000000..22c5e5f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapClear.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "MapClear"
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapDataset.pbtxt
new file mode 100644
index 0000000..59354f3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapDataset.pbtxt
@@ -0,0 +1,166 @@
+op {
+  name: "MapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "MapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "MapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "MapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapDefun.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapDefun.pbtxt
new file mode 100644
index 0000000..7cb9d19
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapDefun.pbtxt
@@ -0,0 +1,132 @@
+op {
+  name: "MapDefun"
+  input_arg {
+    name: "arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
+op {
+  name: "MapDefun"
+  input_arg {
+    name: "arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "captured_inputs"
+    type_list_attr: "Tcaptured"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tcaptured"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
+op {
+  name: "MapDefun"
+  input_arg {
+    name: "arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "captured_inputs"
+    type_list_attr: "Tcaptured"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tcaptured"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "max_intra_op_parallelism"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapIncompleteSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapIncompleteSize.pbtxt
new file mode 100644
index 0000000..ca9c629
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapIncompleteSize.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "MapIncompleteSize"
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapPeek.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapPeek.pbtxt
new file mode 100644
index 0000000..4a61cb9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapPeek.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "MapPeek"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapSize.pbtxt
new file mode 100644
index 0000000..6828f8f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapSize.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "MapSize"
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapStage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapStage.pbtxt
new file mode 100644
index 0000000..4ad2131
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapStage.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "MapStage"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type_list_attr: "fake_dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "fake_dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapUnstage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapUnstage.pbtxt
new file mode 100644
index 0000000..9901130
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapUnstage.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "MapUnstage"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MapUnstageNoKey.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MapUnstageNoKey.pbtxt
new file mode 100644
index 0000000..ee4cca5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MapUnstageNoKey.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "MapUnstageNoKey"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatMul.pbtxt
new file mode 100644
index 0000000..83e157c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatMul.pbtxt
@@ -0,0 +1,176 @@
+op {
+  name: "MatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatchingFiles.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatchingFiles.pbtxt
new file mode 100644
index 0000000..3f8af5f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatchingFiles.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "MatchingFiles"
+  input_arg {
+    name: "pattern"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatchingFilesDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatchingFilesDataset.pbtxt
new file mode 100644
index 0000000..8c46cf6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatchingFilesDataset.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "MatchingFilesDataset"
+  input_arg {
+    name: "patterns"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixBandPart.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixBandPart.pbtxt
new file mode 100644
index 0000000..c25aa96
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixBandPart.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "MatrixBandPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_lower"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_upper"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "band"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "MatrixBandPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_lower"
+    type_attr: "Tindex"
+  }
+  input_arg {
+    name: "num_upper"
+    type_attr: "Tindex"
+  }
+  output_arg {
+    name: "band"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindex"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixDeterminant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixDeterminant.pbtxt
new file mode 100644
index 0000000..4dd524d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixDeterminant.pbtxt
@@ -0,0 +1,68 @@
+op {
+  name: "MatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "MatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixDeterminant"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixDiag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiag.pbtxt
new file mode 100644
index 0000000..9b0ddb0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiag.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "MatrixDiag"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPart.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPart.pbtxt
new file mode 100644
index 0000000..efb1e18
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPart.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "MatrixDiagPart"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPartV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPartV2.pbtxt
new file mode 100644
index 0000000..f709c6d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagPartV2.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "MatrixDiagPartV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "padding_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagV2.pbtxt
new file mode 100644
index 0000000..3f6aa1e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixDiagV2.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "MatrixDiagV2"
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_rows"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_cols"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "padding_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixExponential.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixExponential.pbtxt
new file mode 100644
index 0000000..008291a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixExponential.pbtxt
@@ -0,0 +1,76 @@
+op {
+  name: "MatrixExponential"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixExponential"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 27
+  }
+}
+op {
+  name: "MatrixExponential"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  deprecation {
+    version: 27
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixInverse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixInverse.pbtxt
new file mode 100644
index 0000000..81d35ad
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixInverse.pbtxt
@@ -0,0 +1,89 @@
+op {
+  name: "MatrixInverse"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MatrixInverse"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixInverse"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixLogarithm.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixLogarithm.pbtxt
new file mode 100644
index 0000000..0a87e59
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixLogarithm.pbtxt
@@ -0,0 +1,21 @@
+op {
+  name: "MatrixLogarithm"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiag.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiag.pbtxt
new file mode 100644
index 0000000..e8c08f8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiag.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "MatrixSetDiag"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiagV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiagV2.pbtxt
new file mode 100644
index 0000000..1147220
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixSetDiagV2.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "MatrixSetDiagV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "diagonal"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixSolve.pbtxt
new file mode 100644
index 0000000..2a28fa0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixSolve.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "MatrixSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixSolveLs.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixSolveLs.pbtxt
new file mode 100644
index 0000000..5df48fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixSolveLs.pbtxt
@@ -0,0 +1,113 @@
+op {
+  name: "MatrixSolveLs"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_regularizer"
+    type: DT_DOUBLE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "MatrixSolveLs"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_regularizer"
+    type: DT_DOUBLE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "MatrixSolveLs"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_regularizer"
+    type: DT_DOUBLE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixSquareRoot.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixSquareRoot.pbtxt
new file mode 100644
index 0000000..32ff859
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixSquareRoot.pbtxt
@@ -0,0 +1,47 @@
+op {
+  name: "MatrixSquareRoot"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixSquareRoot"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MatrixTriangularSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MatrixTriangularSolve.pbtxt
new file mode 100644
index 0000000..1755e7d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MatrixTriangularSolve.pbtxt
@@ -0,0 +1,122 @@
+op {
+  name: "MatrixTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MatrixTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "MatrixTriangularSolve"
+  input_arg {
+    name: "matrix"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "lower"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "adjoint"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Max.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Max.pbtxt
new file mode 100644
index 0000000..4c931cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Max.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "Max"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Max"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Max"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Max"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxIntraOpParallelismDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxIntraOpParallelismDataset.pbtxt
new file mode 100644
index 0000000..dd209ee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxIntraOpParallelismDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "MaxIntraOpParallelismDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "max_intra_op_parallelism"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPool.pbtxt
new file mode 100644
index 0000000..ff78964
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPool.pbtxt
@@ -0,0 +1,262 @@
+op {
+  name: "MaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_QINT8
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_QINT8
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPool3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3D.pbtxt
new file mode 100644
index 0000000..7af4fca
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3D.pbtxt
@@ -0,0 +1,210 @@
+op {
+  name: "MaxPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3D"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGrad.pbtxt
new file mode 100644
index 0000000..77edcb4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGrad.pbtxt
@@ -0,0 +1,353 @@
+op {
+  name: "MaxPool3DGrad"
+  input_arg {
+    name: "orig_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "orig_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3DGrad"
+  input_arg {
+    name: "orig_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "orig_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3DGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "TInput"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3DGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "TInput"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3DGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "TInput"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "TInput"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGradGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGradGrad.pbtxt
new file mode 100644
index 0000000..55d26c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPool3DGradGrad.pbtxt
@@ -0,0 +1,137 @@
+op {
+  name: "MaxPool3DGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "MaxPool3DGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 5
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NDHWC"
+    }
+    allowed_values {
+      list {
+        s: "NDHWC"
+        s: "NCDHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGrad.pbtxt
new file mode 100644
index 0000000..b54e555
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGrad.pbtxt
@@ -0,0 +1,371 @@
+op {
+  name: "MaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGrad.pbtxt
new file mode 100644
index 0000000..9b1f4de
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGrad.pbtxt
@@ -0,0 +1,292 @@
+op {
+  name: "MaxPoolGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGrad"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradV2.pbtxt
new file mode 100644
index 0000000..fba1ab5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradV2.pbtxt
@@ -0,0 +1,276 @@
+op {
+  name: "MaxPoolGradGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradWithArgmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradWithArgmax.pbtxt
new file mode 100644
index 0000000..3c3cdbb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradGradWithArgmax.pbtxt
@@ -0,0 +1,358 @@
+op {
+  name: "MaxPoolGradGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "include_batch_in_index"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradV2.pbtxt
new file mode 100644
index 0000000..7e38cf8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradV2.pbtxt
@@ -0,0 +1,288 @@
+op {
+  name: "MaxPoolGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradV2"
+  input_arg {
+    name: "orig_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "orig_output"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradWithArgmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradWithArgmax.pbtxt
new file mode 100644
index 0000000..7c3ab4a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolGradWithArgmax.pbtxt
@@ -0,0 +1,422 @@
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolGradWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "include_batch_in_index"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolV2.pbtxt
new file mode 100644
index 0000000..3ef7da8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolV2.pbtxt
@@ -0,0 +1,191 @@
+op {
+  name: "MaxPoolV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_QINT8
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "ksize"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "strides"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_QINT8
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MaxPoolWithArgmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolWithArgmax.pbtxt
new file mode 100644
index 0000000..d33bbd2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MaxPoolWithArgmax.pbtxt
@@ -0,0 +1,416 @@
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "MaxPoolWithArgmax"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "argmax"
+    type_attr: "Targmax"
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+    has_minimum: true
+    minimum: 4
+  }
+  attr {
+    name: "Targmax"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "include_batch_in_index"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Maximum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Maximum.pbtxt
new file mode 100644
index 0000000..6ca1504
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Maximum.pbtxt
@@ -0,0 +1,118 @@
+op {
+  name: "Maximum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Maximum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Maximum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Maximum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Mean.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Mean.pbtxt
new file mode 100644
index 0000000..ae37662
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Mean.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "Mean"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Mean"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Mean"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Mean"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Merge.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Merge.pbtxt
new file mode 100644
index 0000000..d08f9cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Merge.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "Merge"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "value_index"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MergeSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MergeSummary.pbtxt
new file mode 100644
index 0000000..d9b14d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MergeSummary.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "MergeSummary"
+  input_arg {
+    name: "inputs"
+    type: DT_STRING
+    number_attr: "N"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MergeV2Checkpoints.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MergeV2Checkpoints.pbtxt
new file mode 100644
index 0000000..44158b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MergeV2Checkpoints.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "MergeV2Checkpoints"
+  input_arg {
+    name: "checkpoint_prefixes"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "destination_prefix"
+    type: DT_STRING
+  }
+  attr {
+    name: "delete_old_dirs"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "MergeV2Checkpoints"
+  input_arg {
+    name: "checkpoint_prefixes"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "destination_prefix"
+    type: DT_STRING
+  }
+  attr {
+    name: "delete_old_dirs"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Mfcc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Mfcc.pbtxt
new file mode 100644
index 0000000..4c22eb8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Mfcc.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "Mfcc"
+  input_arg {
+    name: "spectrogram"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "upper_frequency_limit"
+    type: "float"
+    default_value {
+      f: 4000
+    }
+  }
+  attr {
+    name: "lower_frequency_limit"
+    type: "float"
+    default_value {
+      f: 20
+    }
+  }
+  attr {
+    name: "filterbank_channel_count"
+    type: "int"
+    default_value {
+      i: 40
+    }
+  }
+  attr {
+    name: "dct_coefficient_count"
+    type: "int"
+    default_value {
+      i: 13
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Min.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Min.pbtxt
new file mode 100644
index 0000000..f0ebdb0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Min.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "Min"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Min"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Min"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Min"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Minimum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Minimum.pbtxt
new file mode 100644
index 0000000..9cebfc5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Minimum.pbtxt
@@ -0,0 +1,118 @@
+op {
+  name: "Minimum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Minimum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Minimum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Minimum"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MirrorPad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MirrorPad.pbtxt
new file mode 100644
index 0000000..bf64a6c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MirrorPad.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "MirrorPad"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MirrorPadGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MirrorPadGrad.pbtxt
new file mode 100644
index 0000000..b544cfb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MirrorPadGrad.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "MirrorPadGrad"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    allowed_values {
+      list {
+        s: "REFLECT"
+        s: "SYMMETRIC"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Mod.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Mod.pbtxt
new file mode 100644
index 0000000..6c39ed6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Mod.pbtxt
@@ -0,0 +1,85 @@
+op {
+  name: "Mod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Mod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Mod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ModelDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ModelDataset.pbtxt
new file mode 100644
index 0000000..81973fd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ModelDataset.pbtxt
@@ -0,0 +1,90 @@
+op {
+  name: "ModelDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ModelDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "cpu_budget"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ModelDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "algorithm"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "cpu_budget"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Mul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Mul.pbtxt
new file mode 100644
index 0000000..a52bc14
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Mul.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "Mul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Mul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "Mul"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MulNoNan.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MulNoNan.pbtxt
new file mode 100644
index 0000000..1ce9f1d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MulNoNan.pbtxt
@@ -0,0 +1,83 @@
+op {
+  name: "MulNoNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "MulNoNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "MulNoNan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIterator.pbtxt
new file mode 100644
index 0000000..d85c553
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIterator.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "MultiDeviceIterator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "devices"
+    type: "list(string)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+  }
+  attr {
+    name: "container"
+    type: "string"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000..384b147
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorFromStringHandle.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "MultiDeviceIteratorFromStringHandle"
+  input_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "multi_device_iterator"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorGetNextFromShard.pbtxt
new file mode 100644
index 0000000..2e007c2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorGetNextFromShard.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "MultiDeviceIteratorGetNextFromShard"
+  input_arg {
+    name: "multi_device_iterator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "shard_num"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "incarnation_id"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorInit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorInit.pbtxt
new file mode 100644
index 0000000..a011997
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorInit.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "MultiDeviceIteratorInit"
+  input_arg {
+    name: "dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "multi_device_iterator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "max_buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "incarnation_id"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorToStringHandle.pbtxt
new file mode 100644
index 0000000..d7780d7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MultiDeviceIteratorToStringHandle.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "MultiDeviceIteratorToStringHandle"
+  input_arg {
+    name: "multi_device_iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "string_handle"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Multinomial.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Multinomial.pbtxt
new file mode 100644
index 0000000..c258fa6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Multinomial.pbtxt
@@ -0,0 +1,222 @@
+op {
+  name: "Multinomial"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_samples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Multinomial"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_samples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Multinomial"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_samples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "output_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Multinomial"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_samples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "output_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTable.pbtxt
new file mode 100644
index 0000000..eecaeb2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTable.pbtxt
@@ -0,0 +1,64 @@
+op {
+  name: "MutableDenseHashTable"
+  input_arg {
+    name: "empty_key"
+    type_attr: "key_dtype"
+  }
+  output_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  attr {
+    name: "initial_num_buckets"
+    type: "int"
+    default_value {
+      i: 131072
+    }
+  }
+  attr {
+    name: "max_load_factor"
+    type: "float"
+    default_value {
+      f: 0.8
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTableV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTableV2.pbtxt
new file mode 100644
index 0000000..739079c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableDenseHashTableV2.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "MutableDenseHashTableV2"
+  input_arg {
+    name: "empty_key"
+    type_attr: "key_dtype"
+  }
+  input_arg {
+    name: "deleted_key"
+    type_attr: "key_dtype"
+  }
+  output_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  attr {
+    name: "initial_num_buckets"
+    type: "int"
+    default_value {
+      i: 131072
+    }
+  }
+  attr {
+    name: "max_load_factor"
+    type: "float"
+    default_value {
+      f: 0.8
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableHashTable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTable.pbtxt
new file mode 100644
index 0000000..a8ecc34
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTable.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "MutableHashTable"
+  output_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensors.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensors.pbtxt
new file mode 100644
index 0000000..bdec2ff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensors.pbtxt
@@ -0,0 +1,46 @@
+op {
+  name: "MutableHashTableOfTensors"
+  output_arg {
+    name: "table_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensorsV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensorsV2.pbtxt
new file mode 100644
index 0000000..dc46d07
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableOfTensorsV2.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "MutableHashTableOfTensorsV2"
+  output_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableV2.pbtxt
new file mode 100644
index 0000000..610214d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutableHashTableV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "MutableHashTableV2"
+  output_arg {
+    name: "table_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_node_name_sharing"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutexLock.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutexLock.pbtxt
new file mode 100644
index 0000000..243770b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutexLock.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "MutexLock"
+  input_arg {
+    name: "mutex"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "mutex_lock"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/MutexV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MutexV2.pbtxt
new file mode 100644
index 0000000..b20f9b1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/MutexV2.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "MutexV2"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NcclAllReduce.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NcclAllReduce.pbtxt
new file mode 100644
index 0000000..80f91ed
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NcclAllReduce.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "NcclAllReduce"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "reduction"
+    type: "string"
+    allowed_values {
+      list {
+        s: "min"
+        s: "max"
+        s: "prod"
+        s: "sum"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "num_devices"
+    type: "int"
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NcclBroadcast.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NcclBroadcast.pbtxt
new file mode 100644
index 0000000..02a5487
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NcclBroadcast.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "NcclBroadcast"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NcclReduce.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NcclReduce.pbtxt
new file mode 100644
index 0000000..507f92c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NcclReduce.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "NcclReduce"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+    number_attr: "num_devices"
+  }
+  output_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  attr {
+    name: "reduction"
+    type: "string"
+    allowed_values {
+      list {
+        s: "min"
+        s: "max"
+        s: "prod"
+        s: "sum"
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "num_devices"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NearestNeighbors.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NearestNeighbors.pbtxt
new file mode 100644
index 0000000..5d1e5ed
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NearestNeighbors.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "NearestNeighbors"
+  input_arg {
+    name: "points"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "centers"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "nearest_center_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "nearest_center_distances"
+    type: DT_FLOAT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Neg.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Neg.pbtxt
new file mode 100644
index 0000000..77bb4a5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Neg.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Neg"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Neg"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Neg"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NegTrain.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NegTrain.pbtxt
new file mode 100644
index 0000000..f12529f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NegTrain.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "NegTrain"
+  input_arg {
+    name: "w_in"
+    type: DT_FLOAT
+    is_ref: true
+  }
+  input_arg {
+    name: "w_out"
+    type: DT_FLOAT
+    is_ref: true
+  }
+  input_arg {
+    name: "examples"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "labels"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "lr"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "vocab_count"
+    type: "list(int)"
+  }
+  attr {
+    name: "num_negative_samples"
+    type: "int"
+  }
+  deprecation {
+    version: 19
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NextAfter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NextAfter.pbtxt
new file mode 100644
index 0000000..70e4afe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NextAfter.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "NextAfter"
+  input_arg {
+    name: "x1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x2"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NextIteration.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NextIteration.pbtxt
new file mode 100644
index 0000000..7186fc0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NextIteration.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "NextIteration"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NoOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NoOp.pbtxt
new file mode 100644
index 0000000..8f03706
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NoOp.pbtxt
@@ -0,0 +1,3 @@
+op {
+  name: "NoOp"
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonDeterministicInts.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonDeterministicInts.pbtxt
new file mode 100644
index 0000000..3fa5aa4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonDeterministicInts.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "NonDeterministicInts"
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppression.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppression.pbtxt
new file mode 100644
index 0000000..ded8b37
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppression.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "NonMaxSuppression"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "iou_threshold"
+    type: "float"
+    default_value {
+      f: 0.5
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV2.pbtxt
new file mode 100644
index 0000000..90c23bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV2.pbtxt
@@ -0,0 +1,108 @@
+op {
+  name: "NonMaxSuppressionV2"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+}
+op {
+  name: "NonMaxSuppressionV2"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "NonMaxSuppressionV2"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type_attr: "T_threshold"
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "T_threshold"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV3.pbtxt
new file mode 100644
index 0000000..daeffd8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV3.pbtxt
@@ -0,0 +1,120 @@
+op {
+  name: "NonMaxSuppressionV3"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+}
+op {
+  name: "NonMaxSuppressionV3"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "NonMaxSuppressionV3"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type_attr: "T_threshold"
+  }
+  input_arg {
+    name: "score_threshold"
+    type_attr: "T_threshold"
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "T_threshold"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV4.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV4.pbtxt
new file mode 100644
index 0000000..07ca92f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV4.pbtxt
@@ -0,0 +1,153 @@
+op {
+  name: "NonMaxSuppressionV4"
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "valid_outputs"
+    type: DT_INT32
+  }
+  attr {
+    name: "pad_to_max_output_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "NonMaxSuppressionV4"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "valid_outputs"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "pad_to_max_output_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "NonMaxSuppressionV4"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type_attr: "T_threshold"
+  }
+  input_arg {
+    name: "score_threshold"
+    type_attr: "T_threshold"
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "valid_outputs"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "T_threshold"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "pad_to_max_output_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV5.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV5.pbtxt
new file mode 100644
index 0000000..cabec76
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionV5.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "NonMaxSuppressionV5"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "score_threshold"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "soft_nms_sigma"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "selected_scores"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "valid_outputs"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "pad_to_max_output_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionWithOverlaps.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionWithOverlaps.pbtxt
new file mode 100644
index 0000000..d89eeee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonMaxSuppressionWithOverlaps.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "NonMaxSuppressionWithOverlaps"
+  input_arg {
+    name: "overlaps"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "scores"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "overlap_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NonSerializableDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NonSerializableDataset.pbtxt
new file mode 100644
index 0000000..a1290dc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NonSerializableDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "NonSerializableDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NotEqual.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NotEqual.pbtxt
new file mode 100644
index 0000000..b08a3cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NotEqual.pbtxt
@@ -0,0 +1,119 @@
+op {
+  name: "NotEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "NotEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "NotEqual"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type: DT_BOOL
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_QUINT8
+        type: DT_QINT8
+        type: DT_QINT32
+        type: DT_STRING
+        type: DT_BOOL
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/NthElement.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/NthElement.pbtxt
new file mode 100644
index 0000000..c9e7972
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/NthElement.pbtxt
@@ -0,0 +1,125 @@
+op {
+  name: "NthElement"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "NthElement"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "NthElement"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  attr {
+    name: "reverse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OneHot.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OneHot.pbtxt
new file mode 100644
index 0000000..7c2d6b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OneHot.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "OneHot"
+  input_arg {
+    name: "indices"
+    type_attr: "TI"
+  }
+  input_arg {
+    name: "depth"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "on_value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "off_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "axis"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "TI"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OneShotIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OneShotIterator.pbtxt
new file mode 100644
index 0000000..a2969bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OneShotIterator.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "OneShotIterator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "dataset_factory"
+    type: "func"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OnesLike.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OnesLike.pbtxt
new file mode 100644
index 0000000..270d01a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OnesLike.pbtxt
@@ -0,0 +1,88 @@
+op {
+  name: "OnesLike"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "OnesLike"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_BOOL
+      }
+    }
+  }
+}
+op {
+  name: "OnesLike"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_BOOL
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OptimizeDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OptimizeDataset.pbtxt
new file mode 100644
index 0000000..cd611bf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OptimizeDataset.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "OptimizeDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "optimizations"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "OptimizeDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "optimizations"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "optimization_configs"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OptionalFromValue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OptionalFromValue.pbtxt
new file mode 100644
index 0000000..b079f56
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OptionalFromValue.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "OptionalFromValue"
+  input_arg {
+    name: "components"
+    type_list_attr: "Toutput_types"
+  }
+  output_arg {
+    name: "optional"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OptionalGetValue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OptionalGetValue.pbtxt
new file mode 100644
index 0000000..e7364a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OptionalGetValue.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "OptionalGetValue"
+  input_arg {
+    name: "optional"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OptionalHasValue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OptionalHasValue.pbtxt
new file mode 100644
index 0000000..da76333
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OptionalHasValue.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "OptionalHasValue"
+  input_arg {
+    name: "optional"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "has_value"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OptionalNone.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OptionalNone.pbtxt
new file mode 100644
index 0000000..c47d6a7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OptionalNone.pbtxt
@@ -0,0 +1,7 @@
+op {
+  name: "OptionalNone"
+  output_arg {
+    name: "optional"
+    type: DT_VARIANT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapClear.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapClear.pbtxt
new file mode 100644
index 0000000..726e26e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapClear.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "OrderedMapClear"
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapIncompleteSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapIncompleteSize.pbtxt
new file mode 100644
index 0000000..9a9572a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapIncompleteSize.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "OrderedMapIncompleteSize"
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapPeek.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapPeek.pbtxt
new file mode 100644
index 0000000..0d9fd20
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapPeek.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "OrderedMapPeek"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapSize.pbtxt
new file mode 100644
index 0000000..ea07d7e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapSize.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "OrderedMapSize"
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapStage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapStage.pbtxt
new file mode 100644
index 0000000..76af456
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapStage.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "OrderedMapStage"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type_list_attr: "fake_dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "fake_dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstage.pbtxt
new file mode 100644
index 0000000..c09b4be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstage.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "OrderedMapUnstage"
+  input_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstageNoKey.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstageNoKey.pbtxt
new file mode 100644
index 0000000..bc3e8c7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OrderedMapUnstageNoKey.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "OrderedMapUnstageNoKey"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "key"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeue.pbtxt
new file mode 100644
index 0000000..29dc8b5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeue.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "OutfeedDequeue"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeueTuple.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeueTuple.pbtxt
new file mode 100644
index 0000000..3e0d310
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OutfeedDequeueTuple.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "OutfeedDequeueTuple"
+  output_arg {
+    name: "outputs"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueue.pbtxt
new file mode 100644
index 0000000..d8c16f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueue.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "OutfeedEnqueue"
+  input_arg {
+    name: "input"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueueTuple.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueueTuple.pbtxt
new file mode 100644
index 0000000..0bf1a5b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/OutfeedEnqueueTuple.pbtxt
@@ -0,0 +1,14 @@
+op {
+  name: "OutfeedEnqueueTuple"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Pack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Pack.pbtxt
new file mode 100644
index 0000000..65eb675
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Pack.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "Pack"
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "axis"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Pad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Pad.pbtxt
new file mode 100644
index 0000000..1c7b9c7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Pad.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "Pad"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PadV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PadV2.pbtxt
new file mode 100644
index 0000000..463cb71
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PadV2.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "PadV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  input_arg {
+    name: "constant_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDataset.pbtxt
new file mode 100644
index 0000000..69834b6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDataset.pbtxt
@@ -0,0 +1,85 @@
+op {
+  name: "PaddedBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "padded_shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "padding_values"
+    type_list_attr: "Toutput_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "PaddedBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "padded_shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "padding_values"
+    type_list_attr: "Toutput_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDatasetV2.pbtxt
new file mode 100644
index 0000000..52b5acc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PaddedBatchDatasetV2.pbtxt
@@ -0,0 +1,99 @@
+op {
+  name: "PaddedBatchDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "padded_shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "padding_values"
+    type_list_attr: "Toutput_types"
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "PaddedBatchDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "padded_shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "padding_values"
+    type_list_attr: "Toutput_types"
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "parallel_copy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueue.pbtxt
new file mode 100644
index 0000000..f5eca52
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueue.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "PaddingFIFOQueue"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueueV2.pbtxt
new file mode 100644
index 0000000..c398f9e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PaddingFIFOQueueV2.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "PaddingFIFOQueueV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParallelConcat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParallelConcat.pbtxt
new file mode 100644
index 0000000..b0d1cc3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParallelConcat.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "ParallelConcat"
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParallelDynamicStitch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParallelDynamicStitch.pbtxt
new file mode 100644
index 0000000..9ab18a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParallelDynamicStitch.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "ParallelDynamicStitch"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "merged"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDataset.pbtxt
new file mode 100644
index 0000000..6b9d2a7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDataset.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "ParallelInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sloppy"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "buffer_output_elements"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "prefetch_input_elements"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000..de73483
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,103 @@
+op {
+  name: "ParallelInterleaveDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ParallelInterleaveDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "sloppy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParallelMapDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParallelMapDataset.pbtxt
new file mode 100644
index 0000000..31aed52
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParallelMapDataset.pbtxt
@@ -0,0 +1,243 @@
+op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "sloppy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "sloppy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParameterizedTruncatedNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParameterizedTruncatedNormal.pbtxt
new file mode 100644
index 0000000..1f96da6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParameterizedTruncatedNormal.pbtxt
@@ -0,0 +1,127 @@
+op {
+  name: "ParameterizedTruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "means"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "stdevs"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "minvals"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "maxvals"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ParameterizedTruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "means"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "stdevs"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "minvals"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "maxvals"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseExample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseExample.pbtxt
new file mode 100644
index 0000000..a1e35bd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseExample.pbtxt
@@ -0,0 +1,82 @@
+op {
+  name: "ParseExample"
+  input_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "sparse_keys"
+    type: DT_STRING
+    number_attr: "Nsparse"
+  }
+  input_arg {
+    name: "dense_keys"
+    type: DT_STRING
+    number_attr: "Ndense"
+  }
+  input_arg {
+    name: "dense_defaults"
+    type_list_attr: "Tdense"
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+    number_attr: "Nsparse"
+  }
+  output_arg {
+    name: "sparse_values"
+    type_list_attr: "sparse_types"
+  }
+  output_arg {
+    name: "sparse_shapes"
+    type: DT_INT64
+    number_attr: "Nsparse"
+  }
+  output_arg {
+    name: "dense_values"
+    type_list_attr: "Tdense"
+  }
+  attr {
+    name: "Nsparse"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "Ndense"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tdense"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseExampleDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseExampleDataset.pbtxt
new file mode 100644
index 0000000..6a51864
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseExampleDataset.pbtxt
@@ -0,0 +1,77 @@
+op {
+  name: "ParseExampleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense_defaults"
+    type_list_attr: "Tdense"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tdense"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "sloppy"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseSequenceExample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseSequenceExample.pbtxt
new file mode 100644
index 0000000..03ac5be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseSequenceExample.pbtxt
@@ -0,0 +1,195 @@
+op {
+  name: "ParseSequenceExample"
+  input_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "debug_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "context_dense_defaults"
+    type_list_attr: "Tcontext_dense"
+  }
+  output_arg {
+    name: "context_sparse_indices"
+    type: DT_INT64
+    number_attr: "Ncontext_sparse"
+  }
+  output_arg {
+    name: "context_sparse_values"
+    type_list_attr: "context_sparse_types"
+  }
+  output_arg {
+    name: "context_sparse_shapes"
+    type: DT_INT64
+    number_attr: "Ncontext_sparse"
+  }
+  output_arg {
+    name: "context_dense_values"
+    type_list_attr: "Tcontext_dense"
+  }
+  output_arg {
+    name: "feature_list_sparse_indices"
+    type: DT_INT64
+    number_attr: "Nfeature_list_sparse"
+  }
+  output_arg {
+    name: "feature_list_sparse_values"
+    type_list_attr: "feature_list_sparse_types"
+  }
+  output_arg {
+    name: "feature_list_sparse_shapes"
+    type: DT_INT64
+    number_attr: "Nfeature_list_sparse"
+  }
+  output_arg {
+    name: "feature_list_dense_values"
+    type_list_attr: "feature_list_dense_types"
+  }
+  output_arg {
+    name: "feature_list_dense_lengths"
+    type: DT_INT64
+    number_attr: "Nfeature_list_dense"
+  }
+  attr {
+    name: "feature_list_dense_missing_assumed_empty"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "context_sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "context_dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "feature_list_sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "feature_list_dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "Ncontext_sparse"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Ncontext_dense"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Nfeature_list_sparse"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Nfeature_list_dense"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "context_sparse_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tcontext_dense"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "feature_list_dense_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "context_dense_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "feature_list_sparse_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "feature_list_dense_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseSingleExample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseSingleExample.pbtxt
new file mode 100644
index 0000000..aaa69af
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseSingleExample.pbtxt
@@ -0,0 +1,73 @@
+op {
+  name: "ParseSingleExample"
+  input_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "dense_defaults"
+    type_list_attr: "Tdense"
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+    number_attr: "num_sparse"
+  }
+  output_arg {
+    name: "sparse_values"
+    type_list_attr: "sparse_types"
+  }
+  output_arg {
+    name: "sparse_shapes"
+    type: DT_INT64
+    number_attr: "num_sparse"
+  }
+  output_arg {
+    name: "dense_values"
+    type_list_attr: "Tdense"
+  }
+  attr {
+    name: "num_sparse"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "dense_keys"
+    type: "list(string)"
+    has_minimum: true
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tdense"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseSingleSequenceExample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseSingleSequenceExample.pbtxt
new file mode 100644
index 0000000..a0f52db
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseSingleSequenceExample.pbtxt
@@ -0,0 +1,189 @@
+op {
+  name: "ParseSingleSequenceExample"
+  input_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "feature_list_dense_missing_assumed_empty"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "context_sparse_keys"
+    type: DT_STRING
+    number_attr: "Ncontext_sparse"
+  }
+  input_arg {
+    name: "context_dense_keys"
+    type: DT_STRING
+    number_attr: "Ncontext_dense"
+  }
+  input_arg {
+    name: "feature_list_sparse_keys"
+    type: DT_STRING
+    number_attr: "Nfeature_list_sparse"
+  }
+  input_arg {
+    name: "feature_list_dense_keys"
+    type: DT_STRING
+    number_attr: "Nfeature_list_dense"
+  }
+  input_arg {
+    name: "context_dense_defaults"
+    type_list_attr: "Tcontext_dense"
+  }
+  input_arg {
+    name: "debug_name"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "context_sparse_indices"
+    type: DT_INT64
+    number_attr: "Ncontext_sparse"
+  }
+  output_arg {
+    name: "context_sparse_values"
+    type_list_attr: "context_sparse_types"
+  }
+  output_arg {
+    name: "context_sparse_shapes"
+    type: DT_INT64
+    number_attr: "Ncontext_sparse"
+  }
+  output_arg {
+    name: "context_dense_values"
+    type_list_attr: "Tcontext_dense"
+  }
+  output_arg {
+    name: "feature_list_sparse_indices"
+    type: DT_INT64
+    number_attr: "Nfeature_list_sparse"
+  }
+  output_arg {
+    name: "feature_list_sparse_values"
+    type_list_attr: "feature_list_sparse_types"
+  }
+  output_arg {
+    name: "feature_list_sparse_shapes"
+    type: DT_INT64
+    number_attr: "Nfeature_list_sparse"
+  }
+  output_arg {
+    name: "feature_list_dense_values"
+    type_list_attr: "feature_list_dense_types"
+  }
+  attr {
+    name: "Ncontext_sparse"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Ncontext_dense"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Nfeature_list_sparse"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "Nfeature_list_dense"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "context_sparse_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "Tcontext_dense"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "feature_list_dense_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "context_dense_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "feature_list_sparse_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "feature_list_dense_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ParseTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ParseTensor.pbtxt
new file mode 100644
index 0000000..63d1f12
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ParseTensor.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "ParseTensor"
+  input_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PartitionedCall.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PartitionedCall.pbtxt
new file mode 100644
index 0000000..b51bd1d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PartitionedCall.pbtxt
@@ -0,0 +1,142 @@
+op {
+  name: "PartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
+op {
+  name: "PartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "PartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "executor_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "PartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "config_proto"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "executor_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Placeholder.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Placeholder.pbtxt
new file mode 100644
index 0000000..7c0f57a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Placeholder.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "Placeholder"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+}
+op {
+  name: "Placeholder"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PlaceholderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PlaceholderV2.pbtxt
new file mode 100644
index 0000000..b2cd20b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PlaceholderV2.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "PlaceholderV2"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+}
+op {
+  name: "PlaceholderV2"
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  deprecation {
+    version: 23
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PlaceholderWithDefault.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PlaceholderWithDefault.pbtxt
new file mode 100644
index 0000000..79a2ffb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PlaceholderWithDefault.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "PlaceholderWithDefault"
+  input_arg {
+    name: "input"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Polygamma.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Polygamma.pbtxt
new file mode 100644
index 0000000..6bf0d9b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Polygamma.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "Polygamma"
+  input_arg {
+    name: "a"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PopulationCount.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PopulationCount.pbtxt
new file mode 100644
index 0000000..d66c1ac
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PopulationCount.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "PopulationCount"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_UINT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+      }
+    }
+  }
+}
+op {
+  name: "PopulationCount"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type: DT_UINT8
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Pow.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Pow.pbtxt
new file mode 100644
index 0000000..9908655
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Pow.pbtxt
@@ -0,0 +1,92 @@
+op {
+  name: "Pow"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Pow"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Pow"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PrefetchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PrefetchDataset.pbtxt
new file mode 100644
index 0000000..396c691
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PrefetchDataset.pbtxt
@@ -0,0 +1,130 @@
+op {
+  name: "PrefetchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "PrefetchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "PrefetchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "slack_period"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "PrefetchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "slack_period"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "legacy_autotune"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Prelinearize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Prelinearize.pbtxt
new file mode 100644
index 0000000..b5ed810
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Prelinearize.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "Prelinearize"
+  input_arg {
+    name: "input"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+    default_value {
+      shape {
+      }
+    }
+  }
+  attr {
+    name: "layout"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PrelinearizeTuple.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PrelinearizeTuple.pbtxt
new file mode 100644
index 0000000..bb1ae7d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PrelinearizeTuple.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "PrelinearizeTuple"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "dtypes"
+  }
+  output_arg {
+    name: "output"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+  }
+  attr {
+    name: "layouts"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PreventGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PreventGradient.pbtxt
new file mode 100644
index 0000000..1649fc8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PreventGradient.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "PreventGradient"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "message"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Print.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Print.pbtxt
new file mode 100644
index 0000000..fbbb514
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Print.pbtxt
@@ -0,0 +1,93 @@
+op {
+  name: "Print"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "U"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "U"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "message"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "first_n"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Print"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "U"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "U"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "message"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "first_n"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PrintV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PrintV2.pbtxt
new file mode 100644
index 0000000..c5942f0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PrintV2.pbtxt
@@ -0,0 +1,61 @@
+op {
+  name: "PrintV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  attr {
+    name: "output_stream"
+    type: "string"
+    default_value {
+      s: "stderr"
+    }
+    allowed_values {
+      list {
+        s: "stdout"
+        s: "stderr"
+        s: "log(info)"
+        s: "log(warning)"
+        s: "log(error)"
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "PrintV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  attr {
+    name: "output_stream"
+    type: "string"
+    default_value {
+      s: "stderr"
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "PrintV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  attr {
+    name: "output_stream"
+    type: "string"
+    default_value {
+      s: "stderr"
+    }
+  }
+  attr {
+    name: "end"
+    type: "string"
+    default_value {
+      s: "\n"
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PriorityQueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PriorityQueue.pbtxt
new file mode 100644
index 0000000..b44d83d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PriorityQueue.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "PriorityQueue"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PriorityQueueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PriorityQueueV2.pbtxt
new file mode 100644
index 0000000..a4e7c75
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PriorityQueueV2.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "PriorityQueueV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PrivateThreadPoolDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PrivateThreadPoolDataset.pbtxt
new file mode 100644
index 0000000..91a1017
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PrivateThreadPoolDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "PrivateThreadPoolDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_threads"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Prod.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Prod.pbtxt
new file mode 100644
index 0000000..0583fc1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Prod.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "Prod"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Prod"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Prod"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Prod"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PyFunc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PyFunc.pbtxt
new file mode 100644
index 0000000..987f0280
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PyFunc.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "PyFunc"
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "token"
+    type: "string"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/PyFuncStateless.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/PyFuncStateless.pbtxt
new file mode 100644
index 0000000..2a587d5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/PyFuncStateless.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "PyFuncStateless"
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "token"
+    type: "string"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Qr.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Qr.pbtxt
new file mode 100644
index 0000000..8319528
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Qr.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "Qr"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "q"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r"
+    type_attr: "T"
+  }
+  attr {
+    name: "full_matrices"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Qr"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "q"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "r"
+    type_attr: "T"
+  }
+  attr {
+    name: "full_matrices"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantize.pbtxt
new file mode 100644
index 0000000..fb662f7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantize.pbtxt
@@ -0,0 +1,295 @@
+op {
+  name: "QuantizeAndDequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "input_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "input_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 21
+  }
+}
+op {
+  name: "QuantizeAndDequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "input_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 22
+  }
+}
+op {
+  name: "QuantizeAndDequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "input_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 22
+  }
+}
+op {
+  name: "QuantizeAndDequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "input_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 22
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV2.pbtxt
new file mode 100644
index 0000000..46375bd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV2.pbtxt
@@ -0,0 +1,290 @@
+op {
+  name: "QuantizeAndDequantizeV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "round_mode"
+    type: "string"
+    default_value {
+      s: "HALF_TO_EVEN"
+    }
+    allowed_values {
+      list {
+        s: "HALF_TO_EVEN"
+        s: "HALF_UP"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "num_bits"
+    type: "int"
+    default_value {
+      i: 8
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "round_mode"
+    type: "string"
+    default_value {
+      s: "HALF_TO_EVEN"
+    }
+    allowed_values {
+      list {
+        s: "HALF_TO_EVEN"
+        s: "HALF_UP"
+      }
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV3.pbtxt
new file mode 100644
index 0000000..3ece936
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizeAndDequantizeV3.pbtxt
@@ -0,0 +1,200 @@
+op {
+  name: "QuantizeAndDequantizeV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_bits"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_bits"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_bits"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeAndDequantizeV3"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_max"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_bits"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "signed_input"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "range_given"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizeDownAndShrinkRange.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizeDownAndShrinkRange.pbtxt
new file mode 100644
index 0000000..42783d3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizeDownAndShrinkRange.pbtxt
@@ -0,0 +1,106 @@
+op {
+  name: "QuantizeDownAndShrinkRange"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeDownAndShrinkRange"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizeV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizeV2.pbtxt
new file mode 100644
index 0000000..9386906
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizeV2.pbtxt
@@ -0,0 +1,241 @@
+op {
+  name: "QuantizeV2"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeV2"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeV2"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+  attr {
+    name: "round_mode"
+    type: "string"
+    default_value {
+      s: "HALF_AWAY_FROM_ZERO"
+    }
+    allowed_values {
+      list {
+        s: "HALF_AWAY_FROM_ZERO"
+        s: "HALF_TO_EVEN"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizeV2"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+  attr {
+    name: "round_mode"
+    type: "string"
+    default_value {
+      s: "HALF_AWAY_FROM_ZERO"
+    }
+    allowed_values {
+      list {
+        s: "HALF_AWAY_FROM_ZERO"
+        s: "HALF_TO_EVEN"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedAdd.pbtxt
new file mode 100644
index 0000000..4532bc2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedAdd.pbtxt
@@ -0,0 +1,245 @@
+op {
+  name: "QuantizedAdd"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "QuantizedAdd"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "QuantizedAdd"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedAvgPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedAvgPool.pbtxt
new file mode 100644
index 0000000..0ae3390
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedAvgPool.pbtxt
@@ -0,0 +1,116 @@
+op {
+  name: "QuantizedAvgPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedAvgPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedBatchNormWithGlobalNormalization.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedBatchNormWithGlobalNormalization.pbtxt
new file mode 100644
index 0000000..832b8ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedBatchNormWithGlobalNormalization.pbtxt
@@ -0,0 +1,218 @@
+op {
+  name: "QuantizedBatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "t_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "t_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "m"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "m_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "m_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "v"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "v_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "v_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "beta_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "gamma_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gamma_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "result"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "result_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "result_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+}
+op {
+  name: "QuantizedBatchNormWithGlobalNormalization"
+  input_arg {
+    name: "t"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "t_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "t_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "m"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "m_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "m_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "v"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "v_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "v_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "beta_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gamma"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "gamma_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "gamma_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "result"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "result_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "result_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+  }
+  attr {
+    name: "scale_after_normalization"
+    type: "bool"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedBiasAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedBiasAdd.pbtxt
new file mode 100644
index 0000000..b479c2c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedBiasAdd.pbtxt
@@ -0,0 +1,156 @@
+op {
+  name: "QuantizedBiasAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_bias"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedBiasAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_bias"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConcat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConcat.pbtxt
new file mode 100644
index 0000000..449f588
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConcat.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "QuantizedConcat"
+  input_arg {
+    name: "concat_dim"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "input_mins"
+    type: DT_FLOAT
+    number_attr: "N"
+  }
+  input_arg {
+    name: "input_maxes"
+    type: DT_FLOAT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2D.pbtxt
new file mode 100644
index 0000000..b1cf1c8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2D.pbtxt
@@ -0,0 +1,309 @@
+op {
+  name: "QuantizedConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRelu.pbtxt
new file mode 100644
index 0000000..229e4c4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRelu.pbtxt
@@ -0,0 +1,222 @@
+op {
+  name: "QuantizedConv2DAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..bc56689
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndReluAndRequantize.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "QuantizedConv2DAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRequantize.pbtxt
new file mode 100644
index 0000000..5d26709
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DAndRequantize.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "QuantizedConv2DAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DPerChannel.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DPerChannel.pbtxt
new file mode 100644
index 0000000..9364094
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DPerChannel.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "QuantizedConv2DPerChannel"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBias.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBias.pbtxt
new file mode 100644
index 0000000..8372a88
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBias.pbtxt
@@ -0,0 +1,230 @@
+op {
+  name: "QuantizedConv2DWithBias"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBias"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRelu.pbtxt
new file mode 100644
index 0000000..af0ce39
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRelu.pbtxt
@@ -0,0 +1,230 @@
+op {
+  name: "QuantizedConv2DWithBiasAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..599f19e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndReluAndRequantize.pbtxt
@@ -0,0 +1,266 @@
+op {
+  name: "QuantizedConv2DWithBiasAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRequantize.pbtxt
new file mode 100644
index 0000000..8cf8fbb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasAndRequantize.pbtxt
@@ -0,0 +1,266 @@
+op {
+  name: "QuantizedConv2DWithBiasAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..e46786a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.pbtxt
@@ -0,0 +1,316 @@
+op {
+  name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type_attr: "Tsummand"
+  }
+  input_arg {
+    name: "min_summand"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tsummand"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type_attr: "Tsummand"
+  }
+  input_arg {
+    name: "min_summand"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tsummand"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndRelu.pbtxt
new file mode 100644
index 0000000..d74439b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndRelu.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "QuantizedConv2DWithBiasSumAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasSumAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..70c2366
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedConv2DWithBiasSumAndReluAndRequantize.pbtxt
@@ -0,0 +1,316 @@
+op {
+  name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type_attr: "Tsummand"
+  }
+  input_arg {
+    name: "min_summand"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tsummand"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "summand"
+    type_attr: "Tsummand"
+  }
+  input_arg {
+    name: "min_summand"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_summand"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Tsummand"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+  attr {
+    name: "padding_list"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2D.pbtxt
new file mode 100644
index 0000000..f88bba2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2D.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "QuantizedDepthwiseConv2D"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBias.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBias.pbtxt
new file mode 100644
index 0000000..4faf839
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBias.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "QuantizedDepthwiseConv2DWithBias"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndRelu.pbtxt
new file mode 100644
index 0000000..cc6d923
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndRelu.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "QuantizedDepthwiseConv2DWithBiasAndRelu"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..5413d15
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.pbtxt
@@ -0,0 +1,129 @@
+op {
+  name: "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "filter"
+    type_attr: "Tfilter"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_filter"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tfilter"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+  attr {
+    name: "dilations"
+    type: "list(int)"
+    default_value {
+      list {
+        i: 1
+        i: 1
+        i: 1
+        i: 1
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedInstanceNorm.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedInstanceNorm.pbtxt
new file mode 100644
index 0000000..98136d8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedInstanceNorm.pbtxt
@@ -0,0 +1,150 @@
+op {
+  name: "QuantizedInstanceNorm"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "x_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "y_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "output_range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "given_y_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "given_y_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  attr {
+    name: "min_separation"
+    type: "float"
+    default_value {
+      f: 0.001
+    }
+  }
+}
+op {
+  name: "QuantizedInstanceNorm"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "x_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "y_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "output_range_given"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "given_y_min"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "given_y_max"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "variance_epsilon"
+    type: "float"
+    default_value {
+      f: 1e-05
+    }
+  }
+  attr {
+    name: "min_separation"
+    type: "float"
+    default_value {
+      f: 0.001
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMul.pbtxt
new file mode 100644
index 0000000..7e4707a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMul.pbtxt
@@ -0,0 +1,222 @@
+op {
+  name: "QuantizedMatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_b"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tactivation"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedMatMul"
+  input_arg {
+    name: "a"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_b"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tactivation"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBias.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBias.pbtxt
new file mode 100644
index 0000000..a59adb7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBias.pbtxt
@@ -0,0 +1,122 @@
+op {
+  name: "QuantizedMatMulWithBias"
+  input_arg {
+    name: "a"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_b"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_quant_mode"
+    type: "string"
+    default_value {
+      s: "MIN_FIRST"
+    }
+    allowed_values {
+      list {
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndRelu.pbtxt
new file mode 100644
index 0000000..cd0acb9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndRelu.pbtxt
@@ -0,0 +1,112 @@
+op {
+  name: "QuantizedMatMulWithBiasAndRelu"
+  input_arg {
+    name: "a"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "bias"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_b"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_quant_mode"
+    type: "string"
+    default_value {
+      s: "MIN_FIRST"
+    }
+    allowed_values {
+      list {
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt
new file mode 100644
index 0000000..b591d3f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMatMulWithBiasAndReluAndRequantize.pbtxt
@@ -0,0 +1,130 @@
+op {
+  name: "QuantizedMatMulWithBiasAndReluAndRequantize"
+  input_arg {
+    name: "a"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "bias"
+    type_attr: "Tbias"
+  }
+  input_arg {
+    name: "min_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_a"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_b"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_freezed_output"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_freezed_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_out"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tbias"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "input_quant_mode"
+    type: "string"
+    default_value {
+      s: "MIN_FIRST"
+    }
+    allowed_values {
+      list {
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMaxPool.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMaxPool.pbtxt
new file mode 100644
index 0000000..47d6ac8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMaxPool.pbtxt
@@ -0,0 +1,116 @@
+op {
+  name: "QuantizedMaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedMaxPool"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_input"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "min_output"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "ksize"
+    type: "list(int)"
+  }
+  attr {
+    name: "strides"
+    type: "list(int)"
+  }
+  attr {
+    name: "padding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "SAME"
+        s: "VALID"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMul.pbtxt
new file mode 100644
index 0000000..795ab13
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedMul.pbtxt
@@ -0,0 +1,245 @@
+op {
+  name: "QuantizedMul"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "QuantizedMul"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "QuantizedMul"
+  input_arg {
+    name: "x"
+    type_attr: "T1"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T2"
+  }
+  input_arg {
+    name: "min_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_x"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_y"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_y"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "z"
+    type_attr: "Toutput"
+  }
+  output_arg {
+    name: "min_z"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_z"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Toutput"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu.pbtxt
new file mode 100644
index 0000000..724d8b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu.pbtxt
@@ -0,0 +1,112 @@
+op {
+  name: "QuantizedRelu"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedRelu"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu6.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu6.pbtxt
new file mode 100644
index 0000000..0f389d5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedRelu6.pbtxt
@@ -0,0 +1,112 @@
+op {
+  name: "QuantizedRelu6"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedRelu6"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedReluX.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedReluX.pbtxt
new file mode 100644
index 0000000..9ee6f0d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedReluX.pbtxt
@@ -0,0 +1,120 @@
+op {
+  name: "QuantizedReluX"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "max_value"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "QuantizedReluX"
+  input_arg {
+    name: "features"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "max_value"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_features"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "min_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "max_activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedReshape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedReshape.pbtxt
new file mode 100644
index 0000000..f54db98
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedReshape.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "QuantizedReshape"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QuantizedResizeBilinear.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QuantizedResizeBilinear.pbtxt
new file mode 100644
index 0000000..bee577e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QuantizedResizeBilinear.pbtxt
@@ -0,0 +1,105 @@
+op {
+  name: "QuantizedResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "resized_images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "QuantizedResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "resized_images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueClose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueClose.pbtxt
new file mode 100644
index 0000000..582eecc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueClose.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "QueueClose"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "cancel_pending_enqueues"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueCloseV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueCloseV2.pbtxt
new file mode 100644
index 0000000..e0544c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueCloseV2.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "QueueCloseV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "cancel_pending_enqueues"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeue.pbtxt
new file mode 100644
index 0000000..f06745f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeue.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueDequeue"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueMany.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueMany.pbtxt
new file mode 100644
index 0000000..374ecfb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueMany.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "QueueDequeueMany"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueManyV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueManyV2.pbtxt
new file mode 100644
index 0000000..f3ebc6c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueManyV2.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "QueueDequeueManyV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpTo.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpTo.pbtxt
new file mode 100644
index 0000000..6fa30ac
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpTo.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "QueueDequeueUpTo"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpToV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpToV2.pbtxt
new file mode 100644
index 0000000..2016cc7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueUpToV2.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "QueueDequeueUpToV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "n"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueV2.pbtxt
new file mode 100644
index 0000000..e338ccb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueDequeueV2.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueDequeueV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "component_types"
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueue.pbtxt
new file mode 100644
index 0000000..fb94d28
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueue.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueEnqueue"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "components"
+    type_list_attr: "Tcomponents"
+  }
+  attr {
+    name: "Tcomponents"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueMany.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueMany.pbtxt
new file mode 100644
index 0000000..2d95824
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueMany.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueEnqueueMany"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "components"
+    type_list_attr: "Tcomponents"
+  }
+  attr {
+    name: "Tcomponents"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueManyV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueManyV2.pbtxt
new file mode 100644
index 0000000..c327d27
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueManyV2.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueEnqueueManyV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "components"
+    type_list_attr: "Tcomponents"
+  }
+  attr {
+    name: "Tcomponents"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueV2.pbtxt
new file mode 100644
index 0000000..da8cdd3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueEnqueueV2.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "QueueEnqueueV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "components"
+    type_list_attr: "Tcomponents"
+  }
+  attr {
+    name: "Tcomponents"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "timeout_ms"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosed.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosed.pbtxt
new file mode 100644
index 0000000..11a421b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosed.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "QueueIsClosed"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "is_closed"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosedV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosedV2.pbtxt
new file mode 100644
index 0000000..7cf1fde
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueIsClosedV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "QueueIsClosedV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_closed"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueSize.pbtxt
new file mode 100644
index 0000000..d2a4962
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueSize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "QueueSize"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/QueueSizeV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/QueueSizeV2.pbtxt
new file mode 100644
index 0000000..46eb229
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/QueueSizeV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "QueueSizeV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RFFT.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RFFT.pbtxt
new file mode 100644
index 0000000..0d65e7c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RFFT.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RFFT"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RFFT2D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RFFT2D.pbtxt
new file mode 100644
index 0000000..4e4ef53
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RFFT2D.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RFFT2D"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RFFT3D.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RFFT3D.pbtxt
new file mode 100644
index 0000000..2f044b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RFFT3D.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RFFT3D"
+  input_arg {
+    name: "input"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "fft_length"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_COMPLEX64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RGBToHSV.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RGBToHSV.pbtxt
new file mode 100644
index 0000000..9ed50d3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RGBToHSV.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "RGBToHSV"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "RGBToHSV"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedGather.pbtxt
new file mode 100644
index 0000000..afa14e8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedGather.pbtxt
@@ -0,0 +1,113 @@
+op {
+  name: "RaggedGather"
+  input_arg {
+    name: "params_nested_splits"
+    type: DT_INT64
+    number_attr: "PARAMS_RAGGED_RANK"
+  }
+  input_arg {
+    name: "params_dense_values"
+    type_attr: "Tvalues"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output_nested_splits"
+    type: DT_INT64
+    number_attr: "OUTPUT_RAGGED_RANK"
+  }
+  output_arg {
+    name: "output_dense_values"
+    type_attr: "Tvalues"
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "PARAMS_RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "OUTPUT_RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
+  name: "RaggedGather"
+  input_arg {
+    name: "params_nested_splits"
+    type_attr: "Tsplits"
+    number_attr: "PARAMS_RAGGED_RANK"
+  }
+  input_arg {
+    name: "params_dense_values"
+    type_attr: "Tvalues"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output_nested_splits"
+    type_attr: "Tsplits"
+    number_attr: "OUTPUT_RAGGED_RANK"
+  }
+  output_arg {
+    name: "output_dense_values"
+    type_attr: "Tvalues"
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "PARAMS_RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "OUTPUT_RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedRange.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedRange.pbtxt
new file mode 100644
index 0000000..866c9b4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedRange.pbtxt
@@ -0,0 +1,91 @@
+op {
+  name: "RaggedRange"
+  input_arg {
+    name: "starts"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "limits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "deltas"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "rt_nested_splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "rt_dense_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "RaggedRange"
+  input_arg {
+    name: "starts"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "limits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "deltas"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "rt_nested_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "rt_dense_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorFromVariant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorFromVariant.pbtxt
new file mode 100644
index 0000000..1d2201c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorFromVariant.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "RaggedTensorFromVariant"
+  input_arg {
+    name: "encoded_ragged"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "output_nested_splits"
+    type_attr: "Tsplits"
+    number_attr: "output_ragged_rank"
+  }
+  output_arg {
+    name: "output_dense_values"
+    type_attr: "Tvalues"
+  }
+  attr {
+    name: "input_ragged_rank"
+    type: "int"
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "output_ragged_rank"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToSparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToSparse.pbtxt
new file mode 100644
index 0000000..f9172b4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToSparse.pbtxt
@@ -0,0 +1,81 @@
+op {
+  name: "RaggedTensorToSparse"
+  input_arg {
+    name: "rt_nested_splits"
+    type: DT_INT64
+    number_attr: "RAGGED_RANK"
+  }
+  input_arg {
+    name: "rt_dense_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sparse_dense_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "RaggedTensorToSparse"
+  input_arg {
+    name: "rt_nested_splits"
+    type_attr: "Tsplits"
+    number_attr: "RAGGED_RANK"
+  }
+  input_arg {
+    name: "rt_dense_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sparse_dense_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToTensor.pbtxt
new file mode 100644
index 0000000..60fceb5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToTensor.pbtxt
@@ -0,0 +1,58 @@
+op {
+  name: "RaggedTensorToTensor"
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "row_partition_tensors"
+    type_attr: "Tindex"
+    number_attr: "num_row_partition_tensors"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindex"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    name: "num_row_partition_tensors"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "row_partition_types"
+    type: "list(string)"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToVariant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToVariant.pbtxt
new file mode 100644
index 0000000..6121fbd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RaggedTensorToVariant.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RaggedTensorToVariant"
+  input_arg {
+    name: "rt_nested_splits"
+    type_attr: "Tsplits"
+    number_attr: "RAGGED_RANK"
+  }
+  input_arg {
+    name: "rt_dense_values"
+    type_attr: "Tvalues"
+  }
+  output_arg {
+    name: "encoded_ragged"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "RAGGED_RANK"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "batched_input"
+    type: "bool"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomCrop.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomCrop.pbtxt
new file mode 100644
index 0000000..a5353cf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomCrop.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "RandomCrop"
+  input_arg {
+    name: "image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  deprecation {
+    version: 8
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomDataset.pbtxt
new file mode 100644
index 0000000..777c509
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomDataset.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "RandomDataset"
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomGamma.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomGamma.pbtxt
new file mode 100644
index 0000000..2f38a20
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomGamma.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "RandomGamma"
+  input_arg {
+    name: "shape"
+    type_attr: "S"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomGammaGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomGammaGrad.pbtxt
new file mode 100644
index 0000000..1e1c072
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomGammaGrad.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "RandomGammaGrad"
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sample"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomPoisson.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomPoisson.pbtxt
new file mode 100644
index 0000000..5499e8d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomPoisson.pbtxt
@@ -0,0 +1,105 @@
+op {
+  name: "RandomPoisson"
+  input_arg {
+    name: "shape"
+    type_attr: "S"
+  }
+  input_arg {
+    name: "rate"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "RandomPoisson"
+  input_arg {
+    name: "shape"
+    type_attr: "S"
+  }
+  input_arg {
+    name: "rate"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  deprecation {
+    version: 25
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomPoissonV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomPoissonV2.pbtxt
new file mode 100644
index 0000000..6c3d982
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomPoissonV2.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "RandomPoissonV2"
+  input_arg {
+    name: "shape"
+    type_attr: "S"
+  }
+  input_arg {
+    name: "rate"
+    type_attr: "R"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "R"
+    type: "type"
+    default_value {
+      type: DT_DOUBLE
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomShuffle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffle.pbtxt
new file mode 100644
index 0000000..ddd1a8d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffle.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "RandomShuffle"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueue.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueue.pbtxt
new file mode 100644
index 0000000..550acae
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueue.pbtxt
@@ -0,0 +1,66 @@
+op {
+  name: "RandomShuffleQueue"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "min_after_dequeue"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueueV2.pbtxt
new file mode 100644
index 0000000..7d9807c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomShuffleQueueV2.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "RandomShuffleQueueV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "component_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "min_after_dequeue"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomStandardNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomStandardNormal.pbtxt
new file mode 100644
index 0000000..71fe5e5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomStandardNormal.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "RandomStandardNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "RandomStandardNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomUniform.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomUniform.pbtxt
new file mode 100644
index 0000000..449a9ef
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomUniform.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "RandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "RandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RandomUniformInt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RandomUniformInt.pbtxt
new file mode 100644
index 0000000..3b89715
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RandomUniformInt.pbtxt
@@ -0,0 +1,54 @@
+op {
+  name: "RandomUniformInt"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "minval"
+    type_attr: "Tout"
+  }
+  input_arg {
+    name: "maxval"
+    type_attr: "Tout"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Range.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Range.pbtxt
new file mode 100644
index 0000000..fc13408
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Range.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "Range"
+  input_arg {
+    name: "start"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "limit"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tidx"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Range"
+  input_arg {
+    name: "start"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "limit"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tidx"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RangeDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RangeDataset.pbtxt
new file mode 100644
index 0000000..782d23a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RangeDataset.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "RangeDataset"
+  input_arg {
+    name: "start"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "stop"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Rank.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Rank.pbtxt
new file mode 100644
index 0000000..c12fd9a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Rank.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "Rank"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReadFile.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReadFile.pbtxt
new file mode 100644
index 0000000..ce1985e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReadFile.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "ReadFile"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReadVariableOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReadVariableOp.pbtxt
new file mode 100644
index 0000000..5459632
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReadVariableOp.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "ReadVariableOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProduced.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProduced.pbtxt
new file mode 100644
index 0000000..50b1ea0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProduced.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderNumRecordsProduced"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "records_produced"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProducedV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProducedV2.pbtxt
new file mode 100644
index 0000000..f560f01
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumRecordsProducedV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderNumRecordsProducedV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "records_produced"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompleted.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompleted.pbtxt
new file mode 100644
index 0000000..b1e361e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompleted.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderNumWorkUnitsCompleted"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "units_completed"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompletedV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompletedV2.pbtxt
new file mode 100644
index 0000000..ee4c93e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderNumWorkUnitsCompletedV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderNumWorkUnitsCompletedV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "units_completed"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderRead.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderRead.pbtxt
new file mode 100644
index 0000000..b2a9338
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderRead.pbtxt
@@ -0,0 +1,21 @@
+op {
+  name: "ReaderRead"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "queue_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "key"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "value"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpTo.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpTo.pbtxt
new file mode 100644
index 0000000..e3bb64e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpTo.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "ReaderReadUpTo"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "queue_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_records"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "keys"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpToV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpToV2.pbtxt
new file mode 100644
index 0000000..2ad62b1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadUpToV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "ReaderReadUpToV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "queue_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "num_records"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "keys"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderReadV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadV2.pbtxt
new file mode 100644
index 0000000..3a15731
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderReadV2.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "ReaderReadV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "queue_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "key"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "value"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderReset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderReset.pbtxt
new file mode 100644
index 0000000..9607f83
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderReset.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "ReaderReset"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderResetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderResetV2.pbtxt
new file mode 100644
index 0000000..56f862a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderResetV2.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "ReaderResetV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreState.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreState.pbtxt
new file mode 100644
index 0000000..717a5c3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreState.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderRestoreState"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "state"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreStateV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreStateV2.pbtxt
new file mode 100644
index 0000000..f75b04f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderRestoreStateV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderRestoreStateV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "state"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeState.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeState.pbtxt
new file mode 100644
index 0000000..2f708cb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeState.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderSerializeState"
+  input_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "state"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeStateV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeStateV2.pbtxt
new file mode 100644
index 0000000..c4ade14
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReaderSerializeStateV2.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ReaderSerializeStateV2"
+  input_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "state"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Real.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Real.pbtxt
new file mode 100644
index 0000000..d7e783e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Real.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "Real"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_COMPLEX64
+    }
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RealDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RealDiv.pbtxt
new file mode 100644
index 0000000..43d814a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RealDiv.pbtxt
@@ -0,0 +1,104 @@
+op {
+  name: "RealDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RealDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RealDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RebatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RebatchDataset.pbtxt
new file mode 100644
index 0000000..b0f222c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RebatchDataset.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "RebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "RebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_workers"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_fallback"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "RebatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_replicas"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_fallback"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Reciprocal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Reciprocal.pbtxt
new file mode 100644
index 0000000..5ea1abe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Reciprocal.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Reciprocal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Reciprocal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Reciprocal"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReciprocalGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReciprocalGrad.pbtxt
new file mode 100644
index 0000000..8884c79
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReciprocalGrad.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "ReciprocalGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "ReciprocalGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "ReciprocalGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "ReciprocalGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RecordInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RecordInput.pbtxt
new file mode 100644
index 0000000..a723744
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RecordInput.pbtxt
@@ -0,0 +1,101 @@
+op {
+  name: "RecordInput"
+  output_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  attr {
+    name: "file_pattern"
+    type: "string"
+  }
+  attr {
+    name: "file_random_seed"
+    type: "int"
+    default_value {
+      i: 301
+    }
+  }
+  attr {
+    name: "file_shuffle_shift_ratio"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "file_buffer_size"
+    type: "int"
+    default_value {
+      i: 10000
+    }
+  }
+  attr {
+    name: "file_parallelism"
+    type: "int"
+    default_value {
+      i: 16
+    }
+  }
+  attr {
+    name: "batch_size"
+    type: "int"
+    default_value {
+      i: 32
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "RecordInput"
+  output_arg {
+    name: "records"
+    type: DT_STRING
+  }
+  attr {
+    name: "file_pattern"
+    type: "string"
+  }
+  attr {
+    name: "file_random_seed"
+    type: "int"
+    default_value {
+      i: 301
+    }
+  }
+  attr {
+    name: "file_shuffle_shift_ratio"
+    type: "float"
+    default_value {
+      f: 0
+    }
+  }
+  attr {
+    name: "file_buffer_size"
+    type: "int"
+    default_value {
+      i: 10000
+    }
+  }
+  attr {
+    name: "file_parallelism"
+    type: "int"
+    default_value {
+      i: 16
+    }
+  }
+  attr {
+    name: "batch_size"
+    type: "int"
+    default_value {
+      i: 32
+    }
+  }
+  attr {
+    name: "compression_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RecvTPUEmbeddingActivations.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RecvTPUEmbeddingActivations.pbtxt
new file mode 100644
index 0000000..0fec828
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RecvTPUEmbeddingActivations.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "RecvTPUEmbeddingActivations"
+  output_arg {
+    name: "outputs"
+    type: DT_FLOAT
+    number_attr: "num_outputs"
+  }
+  attr {
+    name: "num_outputs"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "config"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReduceDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReduceDataset.pbtxt
new file mode 100644
index 0000000..4ec1948
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReduceDataset.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "ReduceDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
+  name: "ReduceDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReduceJoin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReduceJoin.pbtxt
new file mode 100644
index 0000000..28880fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReduceJoin.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "ReduceJoin"
+  input_arg {
+    name: "inputs"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "reduction_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "separator"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefEnter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefEnter.pbtxt
new file mode 100644
index 0000000..9af599d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefEnter.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "RefEnter"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "frame_name"
+    type: "string"
+  }
+  attr {
+    name: "is_constant"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefExit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefExit.pbtxt
new file mode 100644
index 0000000..1f9e84e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefExit.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "RefExit"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefIdentity.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefIdentity.pbtxt
new file mode 100644
index 0000000..d2293fd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefIdentity.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "RefIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefMerge.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefMerge.pbtxt
new file mode 100644
index 0000000..fc4794d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefMerge.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "RefMerge"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "value_index"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefNextIteration.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefNextIteration.pbtxt
new file mode 100644
index 0000000..d447a3a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefNextIteration.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "RefNextIteration"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefSelect.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefSelect.pbtxt
new file mode 100644
index 0000000..aa2645f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefSelect.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "RefSelect"
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+    is_ref: true
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RefSwitch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RefSwitch.pbtxt
new file mode 100644
index 0000000..6d12be2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RefSwitch.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "RefSwitch"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "pred"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output_false"
+    type_attr: "T"
+    is_ref: true
+  }
+  output_arg {
+    name: "output_true"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  allows_uninitialized_input: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RegexFullMatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RegexFullMatch.pbtxt
new file mode 100644
index 0000000..f2c0b7b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RegexFullMatch.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "RegexFullMatch"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "pattern"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RegexReplace.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RegexReplace.pbtxt
new file mode 100644
index 0000000..591773c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RegexReplace.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "RegexReplace"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "pattern"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "rewrite"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "replace_global"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Relu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Relu.pbtxt
new file mode 100644
index 0000000..703fbbe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Relu.pbtxt
@@ -0,0 +1,152 @@
+op {
+  name: "Relu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Relu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Relu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Relu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Relu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_QINT8
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Relu6.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Relu6.pbtxt
new file mode 100644
index 0000000..311c329
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Relu6.pbtxt
@@ -0,0 +1,120 @@
+op {
+  name: "Relu6"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Relu6"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Relu6"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Relu6"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Relu6Grad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Relu6Grad.pbtxt
new file mode 100644
index 0000000..618e13a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Relu6Grad.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "Relu6Grad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Relu6Grad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Relu6Grad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Relu6Grad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReluGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReluGrad.pbtxt
new file mode 100644
index 0000000..b14f23b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReluGrad.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "ReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "ReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "ReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "ReluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RemoteCall.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RemoteCall.pbtxt
new file mode 100644
index 0000000..c6bc594
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RemoteCall.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "RemoteCall"
+  input_arg {
+    name: "target"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
+op {
+  name: "RemoteCall"
+  input_arg {
+    name: "target"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RemoteFusedGraphExecute.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RemoteFusedGraphExecute.pbtxt
new file mode 100644
index 0000000..c47a45f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RemoteFusedGraphExecute.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "RemoteFusedGraphExecute"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "Tinputs"
+  }
+  output_arg {
+    name: "outputs"
+    type_list_attr: "Toutputs"
+  }
+  attr {
+    name: "Tinputs"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Toutputs"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "serialized_remote_fused_graph_execute_info"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RepeatDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RepeatDataset.pbtxt
new file mode 100644
index 0000000..de78c67
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RepeatDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "RepeatDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "RepeatDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RequantizationRange.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RequantizationRange.pbtxt
new file mode 100644
index 0000000..6a48908
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RequantizationRange.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "RequantizationRange"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "RequantizationRange"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RequantizationRangePerChannel.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RequantizationRangePerChannel.pbtxt
new file mode 100644
index 0000000..b621afb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RequantizationRangePerChannel.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "RequantizationRangePerChannel"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "clip_value_max"
+    type: "float"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Requantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Requantize.pbtxt
new file mode 100644
index 0000000..c04d32f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Requantize.pbtxt
@@ -0,0 +1,122 @@
+op {
+  name: "Requantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT16
+        type: DT_QUINT16
+        type: DT_QINT32
+      }
+    }
+  }
+}
+op {
+  name: "Requantize"
+  input_arg {
+    name: "input"
+    type_attr: "Tinput"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "Tinput"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RequantizePerChannel.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RequantizePerChannel.pbtxt
new file mode 100644
index 0000000..3ed03fe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RequantizePerChannel.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "RequantizePerChannel"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "input_max"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_min"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "requested_output_max"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_min"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output_max"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_QINT32
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_QUINT8
+    }
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Reshape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Reshape.pbtxt
new file mode 100644
index 0000000..e422ffa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Reshape.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "Reshape"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeArea.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeArea.pbtxt
new file mode 100644
index 0000000..6887280
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeArea.pbtxt
@@ -0,0 +1,77 @@
+op {
+  name: "ResizeArea"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeArea"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubic.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubic.pbtxt
new file mode 100644
index 0000000..9abf628
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubic.pbtxt
@@ -0,0 +1,123 @@
+op {
+  name: "ResizeBicubic"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBicubic"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBicubic"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubicGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubicGrad.pbtxt
new file mode 100644
index 0000000..6de227d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeBicubicGrad.pbtxt
@@ -0,0 +1,71 @@
+op {
+  name: "ResizeBicubicGrad"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBicubicGrad"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinear.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinear.pbtxt
new file mode 100644
index 0000000..4e7c772
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinear.pbtxt
@@ -0,0 +1,164 @@
+op {
+  name: "ResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBilinear"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinearGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinearGrad.pbtxt
new file mode 100644
index 0000000..79d1605
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeBilinearGrad.pbtxt
@@ -0,0 +1,108 @@
+op {
+  name: "ResizeBilinearGrad"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBilinearGrad"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeBilinearGrad"
+  input_arg {
+    name: "grads"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighbor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighbor.pbtxt
new file mode 100644
index 0000000..627bae8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighbor.pbtxt
@@ -0,0 +1,123 @@
+op {
+  name: "ResizeNearestNeighbor"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeNearestNeighbor"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeNearestNeighbor"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "resized_images"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighborGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighborGrad.pbtxt
new file mode 100644
index 0000000..b16307e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResizeNearestNeighborGrad.pbtxt
@@ -0,0 +1,79 @@
+op {
+  name: "ResizeNearestNeighborGrad"
+  input_arg {
+    name: "grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ResizeNearestNeighborGrad"
+  input_arg {
+    name: "grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "align_corners"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "half_pixel_centers"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorApplyGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorApplyGradient.pbtxt
new file mode 100644
index 0000000..ba21fe2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorApplyGradient.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "ResourceAccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorNumAccumulated.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorNumAccumulated.pbtxt
new file mode 100644
index 0000000..398171d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorNumAccumulated.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ResourceAccumulatorNumAccumulated"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "num_accumulated"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorSetGlobalStep.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorSetGlobalStep.pbtxt
new file mode 100644
index 0000000..3e9c5a2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorSetGlobalStep.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "ResourceAccumulatorSetGlobalStep"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "new_global_step"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorTakeGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorTakeGradient.pbtxt
new file mode 100644
index 0000000..56e941c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceAccumulatorTakeGradient.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "ResourceAccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "average"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdaMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdaMax.pbtxt
new file mode 100644
index 0000000..0eb3c7c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdaMax.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "ResourceApplyAdaMax"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdadelta.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdadelta.pbtxt
new file mode 100644
index 0000000..b2267d6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdadelta.pbtxt
@@ -0,0 +1,252 @@
+op {
+  name: "ResourceApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagrad.pbtxt
new file mode 100644
index 0000000..9c3ee6a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagrad.pbtxt
@@ -0,0 +1,263 @@
+op {
+  name: "ResourceApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradDA.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradDA.pbtxt
new file mode 100644
index 0000000..acdf74c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradDA.pbtxt
@@ -0,0 +1,268 @@
+op {
+  name: "ResourceApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..e9071d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdagradV2.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "ResourceApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdam.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdam.pbtxt
new file mode 100644
index 0000000..0344526
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdam.pbtxt
@@ -0,0 +1,401 @@
+op {
+  name: "ResourceApplyAdam"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdam"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdam"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdam"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAdam"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdamWithAmsgrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdamWithAmsgrad.pbtxt
new file mode 100644
index 0000000..7e0c6a0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAdamWithAmsgrad.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "ResourceApplyAdamWithAmsgrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "v"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "vhat"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "beta1_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2_power"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAddSign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAddSign.pbtxt
new file mode 100644
index 0000000..ca9ee73
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyAddSign.pbtxt
@@ -0,0 +1,191 @@
+op {
+  name: "ResourceApplyAddSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAddSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyAddSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyCenteredRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyCenteredRMSProp.pbtxt
new file mode 100644
index 0000000..65248ab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyCenteredRMSProp.pbtxt
@@ -0,0 +1,284 @@
+op {
+  name: "ResourceApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrl.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrl.pbtxt
new file mode 100644
index 0000000..a879b76
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrl.pbtxt
@@ -0,0 +1,268 @@
+op {
+  name: "ResourceApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrlV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrlV2.pbtxt
new file mode 100644
index 0000000..6b11b0c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyFtrlV2.pbtxt
@@ -0,0 +1,284 @@
+op {
+  name: "ResourceApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyGradientDescent.pbtxt
new file mode 100644
index 0000000..7badaf3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyGradientDescent.pbtxt
@@ -0,0 +1,188 @@
+op {
+  name: "ResourceApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyKerasMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyKerasMomentum.pbtxt
new file mode 100644
index 0000000..6837960
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyKerasMomentum.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "ResourceApplyKerasMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyMomentum.pbtxt
new file mode 100644
index 0000000..a72c0ad
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyMomentum.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ResourceApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyPowerSign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyPowerSign.pbtxt
new file mode 100644
index 0000000..1c4bae7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyPowerSign.pbtxt
@@ -0,0 +1,191 @@
+op {
+  name: "ResourceApplyPowerSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyPowerSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyPowerSign"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "m"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "logbase"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sign_decay"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "beta"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalAdagrad.pbtxt
new file mode 100644
index 0000000..ab95b2d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalAdagrad.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "ResourceApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalGradientDescent.pbtxt
new file mode 100644
index 0000000..e9abbd7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyProximalGradientDescent.pbtxt
@@ -0,0 +1,220 @@
+op {
+  name: "ResourceApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "delta"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyRMSProp.pbtxt
new file mode 100644
index 0000000..6fcefea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceApplyRMSProp.pbtxt
@@ -0,0 +1,268 @@
+op {
+  name: "ResourceApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceConditionalAccumulator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceConditionalAccumulator.pbtxt
new file mode 100644
index 0000000..36ffdbb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceConditionalAccumulator.pbtxt
@@ -0,0 +1,64 @@
+op {
+  name: "ResourceConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceCountUpTo.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceCountUpTo.pbtxt
new file mode 100644
index 0000000..352935c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceCountUpTo.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "ResourceCountUpTo"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "limit"
+    type: "int"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceGather.pbtxt
new file mode 100644
index 0000000..9aa33d9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceGather.pbtxt
@@ -0,0 +1,81 @@
+op {
+  name: "ResourceGather"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceGather"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "batch_dims"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceGatherNd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceGatherNd.pbtxt
new file mode 100644
index 0000000..04794f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceGatherNd.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "ResourceGatherNd"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterAdd.pbtxt
new file mode 100644
index 0000000..9524367
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterAdd.pbtxt
@@ -0,0 +1,200 @@
+op {
+  name: "ResourceScatterAdd"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterAdd"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterAdd"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterAdd"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterDiv.pbtxt
new file mode 100644
index 0000000..d428855
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterDiv.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "ResourceScatterDiv"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMax.pbtxt
new file mode 100644
index 0000000..41ef2a3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMax.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "ResourceScatterMax"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMin.pbtxt
new file mode 100644
index 0000000..d6a50b0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMin.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "ResourceScatterMin"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMul.pbtxt
new file mode 100644
index 0000000..9d124a0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterMul.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "ResourceScatterMul"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdAdd.pbtxt
new file mode 100644
index 0000000..507b30e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdAdd.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ResourceScatterNdAdd"
+  input_arg {
+    name: "ref"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdSub.pbtxt
new file mode 100644
index 0000000..9d1a74d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdSub.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ResourceScatterNdSub"
+  input_arg {
+    name: "ref"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdUpdate.pbtxt
new file mode 100644
index 0000000..4305163
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterNdUpdate.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ResourceScatterNdUpdate"
+  input_arg {
+    name: "ref"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterSub.pbtxt
new file mode 100644
index 0000000..af78b06
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterSub.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "ResourceScatterSub"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterUpdate.pbtxt
new file mode 100644
index 0000000..55101b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceScatterUpdate.pbtxt
@@ -0,0 +1,182 @@
+op {
+  name: "ResourceScatterUpdate"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterUpdate"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterUpdate"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceScatterUpdate"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdadelta.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdadelta.pbtxt
new file mode 100644
index 0000000..24f1a40
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdadelta.pbtxt
@@ -0,0 +1,308 @@
+op {
+  name: "ResourceSparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum_update"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagrad.pbtxt
new file mode 100644
index 0000000..1bac35a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagrad.pbtxt
@@ -0,0 +1,333 @@
+op {
+  name: "ResourceSparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradDA.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradDA.pbtxt
new file mode 100644
index 0000000..f37acfc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradDA.pbtxt
@@ -0,0 +1,324 @@
+op {
+  name: "ResourceSparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..2c88969
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyAdagradV2.pbtxt
@@ -0,0 +1,77 @@
+op {
+  name: "ResourceSparseApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyCenteredRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyCenteredRMSProp.pbtxt
new file mode 100644
index 0000000..feedcd5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyCenteredRMSProp.pbtxt
@@ -0,0 +1,340 @@
+op {
+  name: "ResourceSparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mg"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrl.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrl.pbtxt
new file mode 100644
index 0000000..9f45b6b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrl.pbtxt
@@ -0,0 +1,324 @@
+op {
+  name: "ResourceSparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrlV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrlV2.pbtxt
new file mode 100644
index 0000000..e4a3aa2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyFtrlV2.pbtxt
@@ -0,0 +1,340 @@
+op {
+  name: "ResourceSparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "linear"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyKerasMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyKerasMomentum.pbtxt
new file mode 100644
index 0000000..84e146e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyKerasMomentum.pbtxt
@@ -0,0 +1,77 @@
+op {
+  name: "ResourceSparseApplyKerasMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyMomentum.pbtxt
new file mode 100644
index 0000000..4248207
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyMomentum.pbtxt
@@ -0,0 +1,304 @@
+op {
+  name: "ResourceSparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalAdagrad.pbtxt
new file mode 100644
index 0000000..35f0409
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalAdagrad.pbtxt
@@ -0,0 +1,292 @@
+op {
+  name: "ResourceSparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalGradientDescent.pbtxt
new file mode 100644
index 0000000..d63e4e9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyProximalGradientDescent.pbtxt
@@ -0,0 +1,276 @@
+op {
+  name: "ResourceSparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyRMSProp.pbtxt
new file mode 100644
index 0000000..4bf71a9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceSparseApplyRMSProp.pbtxt
@@ -0,0 +1,324 @@
+op {
+  name: "ResourceSparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "ResourceSparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "ms"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "mom"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ResourceStridedSliceAssign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ResourceStridedSliceAssign.pbtxt
new file mode 100644
index 0000000..867f205
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ResourceStridedSliceAssign.pbtxt
@@ -0,0 +1,73 @@
+op {
+  name: "ResourceStridedSliceAssign"
+  input_arg {
+    name: "ref"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "end"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "strides"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "begin_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "end_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ellipsis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "new_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "shrink_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Restore.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Restore.pbtxt
new file mode 100644
index 0000000..1db0290
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Restore.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "Restore"
+  input_arg {
+    name: "file_pattern"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_name"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "dt"
+  }
+  attr {
+    name: "dt"
+    type: "type"
+  }
+  attr {
+    name: "preferred_shard"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
+op {
+  name: "Restore"
+  input_arg {
+    name: "file_pattern"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_name"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "dt"
+  }
+  attr {
+    name: "dt"
+    type: "type"
+  }
+  attr {
+    name: "preferred_shard"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RestoreSlice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RestoreSlice.pbtxt
new file mode 100644
index 0000000..03d2aa3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RestoreSlice.pbtxt
@@ -0,0 +1,61 @@
+op {
+  name: "RestoreSlice"
+  input_arg {
+    name: "file_pattern"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slice"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "dt"
+  }
+  attr {
+    name: "dt"
+    type: "type"
+  }
+  attr {
+    name: "preferred_shard"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
+op {
+  name: "RestoreSlice"
+  input_arg {
+    name: "file_pattern"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slice"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "dt"
+  }
+  attr {
+    name: "dt"
+    type: "type"
+  }
+  attr {
+    name: "preferred_shard"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RestoreV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RestoreV2.pbtxt
new file mode 100644
index 0000000..a88db31
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RestoreV2.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "RestoreV2"
+  input_arg {
+    name: "prefix"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slices"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensors"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "RestoreV2"
+  input_arg {
+    name: "prefix"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slices"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "tensors"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParameters.pbtxt
new file mode 100644
index 0000000..4a31692
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingADAMParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "velocities"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..dd1651c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingADAMParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "velocities"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParameters.pbtxt
new file mode 100644
index 0000000..145e322
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingAdadeltaParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updates"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..64bb295
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updates"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParameters.pbtxt
new file mode 100644
index 0000000..ceb4b68
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "RetrieveTPUEmbeddingAdagradParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..9959a8e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingAdagradParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingAdagradParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingCenteredRMSPropParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingCenteredRMSPropParameters.pbtxt
new file mode 100644
index 0000000..27e66ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingCenteredRMSPropParameters.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingCenteredRMSPropParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "mg"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParameters.pbtxt
new file mode 100644
index 0000000..28b74a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingFTRLParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "linears"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..917d4a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingFTRLParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "linears"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMDLAdagradLightParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMDLAdagradLightParameters.pbtxt
new file mode 100644
index 0000000..2510f7e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMDLAdagradLightParameters.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "weights"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "benefits"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParameters.pbtxt
new file mode 100644
index 0000000..555a8c1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "RetrieveTPUEmbeddingMomentumParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..fba454a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingMomentumParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParameters.pbtxt
new file mode 100644
index 0000000..fdbcf9d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParameters.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "RetrieveTPUEmbeddingProximalAdagradParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..1fbf9a2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "accumulators"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParameters.pbtxt
new file mode 100644
index 0000000..73ae099
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParameters.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "RetrieveTPUEmbeddingRMSPropParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt
new file mode 100644
index 0000000..193af7d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "ms"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "mom"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "gradient_accumulators"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingStochasticGradientDescentParameters.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingStochasticGradientDescentParameters.pbtxt
new file mode 100644
index 0000000..7c70f9f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RetrieveTPUEmbeddingStochasticGradientDescentParameters.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "RetrieveTPUEmbeddingStochasticGradientDescentParameters"
+  output_arg {
+    name: "parameters"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    default_value {
+      i: -1
+    }
+    has_minimum: true
+    minimum: -1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "num_shards"
+    type: "int"
+  }
+  attr {
+    name: "shard_id"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Reverse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Reverse.pbtxt
new file mode 100644
index 0000000..99b3f2e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Reverse.pbtxt
@@ -0,0 +1,103 @@
+op {
+  name: "Reverse"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dims"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Reverse"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dims"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "Reverse"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dims"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReverseSequence.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReverseSequence.pbtxt
new file mode 100644
index 0000000..74d3601
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReverseSequence.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "ReverseSequence"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seq_lengths"
+    type_attr: "Tlen"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "seq_dim"
+    type: "int"
+  }
+  attr {
+    name: "batch_dim"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tlen"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ReverseV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ReverseV2.pbtxt
new file mode 100644
index 0000000..39ee8b2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ReverseV2.pbtxt
@@ -0,0 +1,242 @@
+op {
+  name: "ReverseV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "ReverseV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "ReverseV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "ReverseV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
+op {
+  name: "ReverseV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BOOL
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RightShift.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RightShift.pbtxt
new file mode 100644
index 0000000..97257a0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RightShift.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "RightShift"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "RightShift"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Rint.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Rint.pbtxt
new file mode 100644
index 0000000..feed3bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Rint.pbtxt
@@ -0,0 +1,66 @@
+op {
+  name: "Rint"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Rint"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Rint"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RngSkip.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RngSkip.pbtxt
new file mode 100644
index 0000000..dc3e9b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RngSkip.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "RngSkip"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "delta"
+    type: DT_INT64
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Roll.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Roll.pbtxt
new file mode 100644
index 0000000..ac81404
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Roll.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "Roll"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shift"
+    type_attr: "Tshift"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tshift"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Round.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Round.pbtxt
new file mode 100644
index 0000000..4f59b21
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Round.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Round"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Round"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Round"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt
new file mode 100644
index 0000000..224e52e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "Rpc"
+  input_arg {
+    name: "address"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "method"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "request"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "response"
+    type: DT_STRING
+  }
+  attr {
+    name: "protocol"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "fail_fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "timeout_in_ms"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Rsqrt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Rsqrt.pbtxt
new file mode 100644
index 0000000..6d066c9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Rsqrt.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Rsqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Rsqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Rsqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/RsqrtGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/RsqrtGrad.pbtxt
new file mode 100644
index 0000000..4509b1a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/RsqrtGrad.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "RsqrtGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RsqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RsqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "RsqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBox.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBox.pbtxt
new file mode 100644
index 0000000..95b4a2d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBox.pbtxt
@@ -0,0 +1,92 @@
+op {
+  name: "SampleDistortedBoundingBox"
+  input_arg {
+    name: "image_size"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bounding_boxes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "begin"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "size"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "bboxes"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "min_object_covered"
+    type: "float"
+    default_value {
+      f: 0.1
+    }
+  }
+  attr {
+    name: "aspect_ratio_range"
+    type: "list(float)"
+    default_value {
+      list {
+        f: 0.75
+        f: 1.33
+      }
+    }
+  }
+  attr {
+    name: "area_range"
+    type: "list(float)"
+    default_value {
+      list {
+        f: 0.05
+        f: 1
+      }
+    }
+  }
+  attr {
+    name: "max_attempts"
+    type: "int"
+    default_value {
+      i: 100
+    }
+  }
+  attr {
+    name: "use_image_if_no_bounding_boxes"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBoxV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBoxV2.pbtxt
new file mode 100644
index 0000000..d857ee0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SampleDistortedBoundingBoxV2.pbtxt
@@ -0,0 +1,89 @@
+op {
+  name: "SampleDistortedBoundingBoxV2"
+  input_arg {
+    name: "image_size"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bounding_boxes"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "min_object_covered"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "begin"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "size"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "bboxes"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "aspect_ratio_range"
+    type: "list(float)"
+    default_value {
+      list {
+        f: 0.75
+        f: 1.33
+      }
+    }
+  }
+  attr {
+    name: "area_range"
+    type: "list(float)"
+    default_value {
+      list {
+        f: 0.05
+        f: 1
+      }
+    }
+  }
+  attr {
+    name: "max_attempts"
+    type: "int"
+    default_value {
+      i: 100
+    }
+  }
+  attr {
+    name: "use_image_if_no_bounding_boxes"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SamplingDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SamplingDataset.pbtxt
new file mode 100644
index 0000000..fd183f8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SamplingDataset.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "SamplingDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Save.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Save.pbtxt
new file mode 100644
index 0000000..c632730
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Save.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "Save"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "Save"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SaveSlices.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SaveSlices.pbtxt
new file mode 100644
index 0000000..306d67b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SaveSlices.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "SaveSlices"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shapes_and_slices"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "SaveSlices"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shapes_and_slices"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SaveV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SaveV2.pbtxt
new file mode 100644
index 0000000..d9bae4c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SaveV2.pbtxt
@@ -0,0 +1,51 @@
+op {
+  name: "SaveV2"
+  input_arg {
+    name: "prefix"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slices"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensors"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "SaveV2"
+  input_arg {
+    name: "prefix"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor_names"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shape_and_slices"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensors"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScalarSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScalarSummary.pbtxt
new file mode 100644
index 0000000..bf49480
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScalarSummary.pbtxt
@@ -0,0 +1,136 @@
+op {
+  name: "ScalarSummary"
+  input_arg {
+    name: "tags"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "ScalarSummary"
+  input_arg {
+    name: "tags"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "ScalarSummary"
+  input_arg {
+    name: "tags"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "ScalarSummary"
+  input_arg {
+    name: "tags"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslate.pbtxt
new file mode 100644
index 0000000..516cca3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslate.pbtxt
@@ -0,0 +1,103 @@
+op {
+  name: "ScaleAndTranslate"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "translation"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "kernel_type"
+    type: "string"
+    default_value {
+      s: "lanczos3"
+    }
+  }
+}
+op {
+  name: "ScaleAndTranslate"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "translation"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "resized_images"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_UINT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "kernel_type"
+    type: "string"
+    default_value {
+      s: "lanczos3"
+    }
+  }
+  attr {
+    name: "antialias"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslateGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslateGrad.pbtxt
new file mode 100644
index 0000000..8eaa03c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScaleAndTranslateGrad.pbtxt
@@ -0,0 +1,85 @@
+op {
+  name: "ScaleAndTranslateGrad"
+  input_arg {
+    name: "grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "translation"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "kernel_type"
+    type: "string"
+    default_value {
+      s: "lanczos3"
+    }
+  }
+}
+op {
+  name: "ScaleAndTranslateGrad"
+  input_arg {
+    name: "grads"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "original_image"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scale"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "translation"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "kernel_type"
+    type: "string"
+    default_value {
+      s: "lanczos3"
+    }
+  }
+  attr {
+    name: "antialias"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScanDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScanDataset.pbtxt
new file mode 100644
index 0000000..8ac7432
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScanDataset.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "ScanDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterAdd.pbtxt
new file mode 100644
index 0000000..71b48857
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterAdd.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterDiv.pbtxt
new file mode 100644
index 0000000..256da22
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterDiv.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterDiv"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterDiv"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterDiv"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterDiv"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterMax.pbtxt
new file mode 100644
index 0000000..fe176e1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterMax.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "ScatterMax"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterMin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterMin.pbtxt
new file mode 100644
index 0000000..7099d89
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterMin.pbtxt
@@ -0,0 +1,52 @@
+op {
+  name: "ScatterMin"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterMul.pbtxt
new file mode 100644
index 0000000..ae16baf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterMul.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterMul"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterMul"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterMul"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterMul"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterNd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterNd.pbtxt
new file mode 100644
index 0000000..62cfd05
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterNd.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "ScatterNd"
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterNdAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdAdd.pbtxt
new file mode 100644
index 0000000..5eb6287
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdAdd.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterNdAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdAdd"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterNdNonAliasingAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdNonAliasingAdd.pbtxt
new file mode 100644
index 0000000..d65f471
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdNonAliasingAdd.pbtxt
@@ -0,0 +1,267 @@
+op {
+  name: "ScatterNdNonAliasingAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ScatterNdNonAliasingAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ScatterNdNonAliasingAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ScatterNdNonAliasingAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "ScatterNdNonAliasingAdd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BOOL
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterNdSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdSub.pbtxt
new file mode 100644
index 0000000..d13fe16
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdSub.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterNdSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterNdSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterNdUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdUpdate.pbtxt
new file mode 100644
index 0000000..73def71
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterNdUpdate.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "ScatterNdUpdate"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterSub.pbtxt
new file mode 100644
index 0000000..dbfb97f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterSub.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "ScatterSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "ScatterSub"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ScatterUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ScatterUpdate.pbtxt
new file mode 100644
index 0000000..2f29273
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ScatterUpdate.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "ScatterUpdate"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SdcaFprint.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SdcaFprint.pbtxt
new file mode 100644
index 0000000..979c001
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SdcaFprint.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "SdcaFprint"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizer.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizer.pbtxt
new file mode 100644
index 0000000..3746f95
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizer.pbtxt
@@ -0,0 +1,237 @@
+op {
+  name: "SdcaOptimizer"
+  input_arg {
+    name: "sparse_example_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_values"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features_with_values"
+  }
+  input_arg {
+    name: "dense_features"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "example_labels"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_delta_sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  output_arg {
+    name: "out_delta_dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  attr {
+    name: "loss_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "logistic_loss"
+        s: "squared_loss"
+        s: "hinge_loss"
+        s: "smooth_hinge_loss"
+      }
+    }
+  }
+  attr {
+    name: "adaptative"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "num_sparse_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_sparse_features_with_values"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_dense_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "l1"
+    type: "float"
+  }
+  attr {
+    name: "l2"
+    type: "float"
+  }
+  attr {
+    name: "num_loss_partitions"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_inner_iterations"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "SdcaOptimizer"
+  input_arg {
+    name: "sparse_example_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_values"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features_with_values"
+  }
+  input_arg {
+    name: "dense_features"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "example_labels"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_delta_sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  output_arg {
+    name: "out_delta_dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  attr {
+    name: "loss_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "logistic_loss"
+        s: "squared_loss"
+        s: "hinge_loss"
+        s: "smooth_hinge_loss"
+        s: "poisson_loss"
+      }
+    }
+  }
+  attr {
+    name: "adaptative"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "num_sparse_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_sparse_features_with_values"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_dense_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "l1"
+    type: "float"
+  }
+  attr {
+    name: "l2"
+    type: "float"
+  }
+  attr {
+    name: "num_loss_partitions"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_inner_iterations"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizerV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizerV2.pbtxt
new file mode 100644
index 0000000..cb16c8f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SdcaOptimizerV2.pbtxt
@@ -0,0 +1,119 @@
+op {
+  name: "SdcaOptimizerV2"
+  input_arg {
+    name: "sparse_example_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_feature_values"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features_with_values"
+  }
+  input_arg {
+    name: "dense_features"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "example_labels"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  input_arg {
+    name: "dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  input_arg {
+    name: "example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_example_state_data"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "out_delta_sparse_weights"
+    type: DT_FLOAT
+    number_attr: "num_sparse_features"
+  }
+  output_arg {
+    name: "out_delta_dense_weights"
+    type: DT_FLOAT
+    number_attr: "num_dense_features"
+  }
+  attr {
+    name: "loss_type"
+    type: "string"
+    allowed_values {
+      list {
+        s: "logistic_loss"
+        s: "squared_loss"
+        s: "hinge_loss"
+        s: "smooth_hinge_loss"
+        s: "poisson_loss"
+      }
+    }
+  }
+  attr {
+    name: "adaptive"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "num_sparse_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_sparse_features_with_values"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_dense_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "l1"
+    type: "float"
+  }
+  attr {
+    name: "l2"
+    type: "float"
+  }
+  attr {
+    name: "num_loss_partitions"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_inner_iterations"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SdcaShrinkL1.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SdcaShrinkL1.pbtxt
new file mode 100644
index 0000000..23d9fdd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SdcaShrinkL1.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "SdcaShrinkL1"
+  input_arg {
+    name: "weights"
+    type: DT_FLOAT
+    number_attr: "num_features"
+    is_ref: true
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "l1"
+    type: "float"
+  }
+  attr {
+    name: "l2"
+    type: "float"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SegmentMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SegmentMax.pbtxt
new file mode 100644
index 0000000..a1d5968
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SegmentMax.pbtxt
@@ -0,0 +1,176 @@
+op {
+  name: "SegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SegmentMean.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SegmentMean.pbtxt
new file mode 100644
index 0000000..e359dbe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SegmentMean.pbtxt
@@ -0,0 +1,226 @@
+op {
+  name: "SegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SegmentMin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SegmentMin.pbtxt
new file mode 100644
index 0000000..bf87e82
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SegmentMin.pbtxt
@@ -0,0 +1,176 @@
+op {
+  name: "SegmentMin"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMin"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMin"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentMin"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SegmentProd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SegmentProd.pbtxt
new file mode 100644
index 0000000..1666c49
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SegmentProd.pbtxt
@@ -0,0 +1,196 @@
+op {
+  name: "SegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SegmentSum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SegmentSum.pbtxt
new file mode 100644
index 0000000..c7957a9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SegmentSum.pbtxt
@@ -0,0 +1,196 @@
+op {
+  name: "SegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Select.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Select.pbtxt
new file mode 100644
index 0000000..38d00af
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Select.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "Select"
+  input_arg {
+    name: "condition"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SelectV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SelectV2.pbtxt
new file mode 100644
index 0000000..a7c59f0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SelectV2.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "SelectV2"
+  input_arg {
+    name: "condition"
+    type: DT_BOOL
+  }
+  input_arg {
+    name: "t"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEig.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEig.pbtxt
new file mode 100644
index 0000000..3657cc1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEig.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "SelfAdjointEig"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+  deprecation {
+    version: 11
+  }
+}
+op {
+  name: "SelfAdjointEig"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+  deprecation {
+    version: 11
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEigV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEigV2.pbtxt
new file mode 100644
index 0000000..8fbbfc9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SelfAdjointEigV2.pbtxt
@@ -0,0 +1,101 @@
+op {
+  name: "SelfAdjointEigV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
+  name: "SelfAdjointEigV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SelfAdjointEigV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Selu.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Selu.pbtxt
new file mode 100644
index 0000000..2acf579
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Selu.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "Selu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Selu"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SeluGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SeluGrad.pbtxt
new file mode 100644
index 0000000..f96c7cb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SeluGrad.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "SeluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "outputs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "SeluGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "outputs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SendTPUEmbeddingGradients.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SendTPUEmbeddingGradients.pbtxt
new file mode 100644
index 0000000..f6c486f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SendTPUEmbeddingGradients.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "SendTPUEmbeddingGradients"
+  input_arg {
+    name: "inputs"
+    type: DT_FLOAT
+    number_attr: "N"
+  }
+  input_arg {
+    name: "learning_rates"
+    type: DT_FLOAT
+    number_attr: "NN"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "NN"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "config"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SerializeIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SerializeIterator.pbtxt
new file mode 100644
index 0000000..618ff27
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SerializeIterator.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "SerializeIterator"
+  input_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "serialized"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SerializeManySparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SerializeManySparse.pbtxt
new file mode 100644
index 0000000..9e74163
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SerializeManySparse.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "SerializeManySparse"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "serialized_sparse"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "SerializeManySparse"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "serialized_sparse"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_STRING
+    }
+    allowed_values {
+      list {
+        type: DT_STRING
+        type: DT_VARIANT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SerializeSparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SerializeSparse.pbtxt
new file mode 100644
index 0000000..5040d77
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SerializeSparse.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "SerializeSparse"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "serialized_sparse"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "SerializeSparse"
+  input_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "serialized_sparse"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_STRING
+    }
+    allowed_values {
+      list {
+        type: DT_STRING
+        type: DT_VARIANT
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SerializeTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SerializeTensor.pbtxt
new file mode 100644
index 0000000..4d7b5cf5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SerializeTensor.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "SerializeTensor"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SetSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SetSize.pbtxt
new file mode 100644
index 0000000..185e5e0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SetSize.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "SetSize"
+  input_arg {
+    name: "set_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "set_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SetStatsAggregatorDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SetStatsAggregatorDataset.pbtxt
new file mode 100644
index 0000000..35613fa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SetStatsAggregatorDataset.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "SetStatsAggregatorDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "stats_aggregator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "counter_prefix"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Shape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Shape.pbtxt
new file mode 100644
index 0000000..3618b52
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Shape.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Shape"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShapeN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShapeN.pbtxt
new file mode 100644
index 0000000..15e9f11
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShapeN.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "ShapeN"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+    number_attr: "N"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShardDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShardDataset.pbtxt
new file mode 100644
index 0000000..e21b5ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShardDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "ShardDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_shards"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ShardDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "num_shards"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "require_non_empty"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShardedFilename.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShardedFilename.pbtxt
new file mode 100644
index 0000000..cf46ffd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShardedFilename.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "ShardedFilename"
+  input_arg {
+    name: "basename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "shard"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_shards"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShardedFilespec.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShardedFilespec.pbtxt
new file mode 100644
index 0000000..7d1badc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShardedFilespec.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "ShardedFilespec"
+  input_arg {
+    name: "basename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "num_shards"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShuffleAndRepeatDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShuffleAndRepeatDataset.pbtxt
new file mode 100644
index 0000000..5af8dd5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShuffleAndRepeatDataset.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "ShuffleAndRepeatDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShuffleDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShuffleDataset.pbtxt
new file mode 100644
index 0000000..70d1e1d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShuffleDataset.pbtxt
@@ -0,0 +1,113 @@
+op {
+  name: "ShuffleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ShuffleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
+  name: "ShuffleDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "reshuffle_each_iteration"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShuffleDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShuffleDatasetV2.pbtxt
new file mode 100644
index 0000000..e2dd11d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShuffleDatasetV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "ShuffleDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed_generator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ShutdownDistributedTPU.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ShutdownDistributedTPU.pbtxt
new file mode 100644
index 0000000..9e60b7f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ShutdownDistributedTPU.pbtxt
@@ -0,0 +1,4 @@
+op {
+  name: "ShutdownDistributedTPU"
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sigmoid.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sigmoid.pbtxt
new file mode 100644
index 0000000..dee59f6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sigmoid.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Sigmoid"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sigmoid"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sigmoid"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SigmoidGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SigmoidGrad.pbtxt
new file mode 100644
index 0000000..788c338
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SigmoidGrad.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sign.pbtxt
new file mode 100644
index 0000000..9afaa14
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sign.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Sign"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sign"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sign"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sin.pbtxt
new file mode 100644
index 0000000..f6122e6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sin.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Sin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sinh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sinh.pbtxt
new file mode 100644
index 0000000..7225234
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sinh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Sinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sinh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Size.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Size.pbtxt
new file mode 100644
index 0000000..db039e4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Size.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Size"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SkipDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SkipDataset.pbtxt
new file mode 100644
index 0000000..6f5d1f6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SkipDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "SkipDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "SkipDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Skipgram.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Skipgram.pbtxt
new file mode 100644
index 0000000..d31bc82
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Skipgram.pbtxt
@@ -0,0 +1,64 @@
+op {
+  name: "Skipgram"
+  output_arg {
+    name: "vocab_word"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "vocab_freq"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "words_per_epoch"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "current_epoch"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "total_words_processed"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "examples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "labels"
+    type: DT_INT32
+  }
+  attr {
+    name: "filename"
+    type: "string"
+  }
+  attr {
+    name: "batch_size"
+    type: "int"
+  }
+  attr {
+    name: "window_size"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "min_count"
+    type: "int"
+    default_value {
+      i: 5
+    }
+  }
+  attr {
+    name: "subsample"
+    type: "float"
+    default_value {
+      f: 0.001
+    }
+  }
+  deprecation {
+    version: 19
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SleepDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SleepDataset.pbtxt
new file mode 100644
index 0000000..ed669a7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SleepDataset.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "SleepDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "sleep_microseconds"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Slice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Slice.pbtxt
new file mode 100644
index 0000000..ced3fb6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Slice.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "Slice"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "size"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SlidingWindowDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SlidingWindowDataset.pbtxt
new file mode 100644
index 0000000..87298ca
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SlidingWindowDataset.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "SlidingWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "window_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "window_shift"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "window_stride"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Snapshot.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Snapshot.pbtxt
new file mode 100644
index 0000000..aea213f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Snapshot.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "Snapshot"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SnapshotDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SnapshotDataset.pbtxt
new file mode 100644
index 0000000..8a104bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SnapshotDataset.pbtxt
@@ -0,0 +1,276 @@
+op {
+  name: "SnapshotDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "path"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "compression"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
+op {
+  name: "SnapshotDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "path"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "compression"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shard_size_bytes"
+    type: "int"
+    default_value {
+      i: 10737418240
+    }
+  }
+  attr {
+    name: "pending_snapshot_expiry_seconds"
+    type: "int"
+    default_value {
+      i: 86400
+    }
+  }
+}
+op {
+  name: "SnapshotDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "path"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "compression"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shard_size_bytes"
+    type: "int"
+    default_value {
+      i: 10737418240
+    }
+  }
+  attr {
+    name: "pending_snapshot_expiry_seconds"
+    type: "int"
+    default_value {
+      i: 86400
+    }
+  }
+  attr {
+    name: "num_reader_threads"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "reader_buffer_size"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+}
+op {
+  name: "SnapshotDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "path"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "compression"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_path_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shard_size_bytes"
+    type: "int"
+    default_value {
+      i: 10737418240
+    }
+  }
+  attr {
+    name: "pending_snapshot_expiry_seconds"
+    type: "int"
+    default_value {
+      i: 86400
+    }
+  }
+  attr {
+    name: "num_reader_threads"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "reader_buffer_size"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "num_writer_threads"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "writer_buffer_size"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Softmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Softmax.pbtxt
new file mode 100644
index 0000000..03f4997
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Softmax.pbtxt
@@ -0,0 +1,45 @@
+op {
+  name: "Softmax"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "softmax"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "Softmax"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "softmax"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SoftmaxCrossEntropyWithLogits.pbtxt
new file mode 100644
index 0000000..8ac8052
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SoftmaxCrossEntropyWithLogits.pbtxt
@@ -0,0 +1,61 @@
+op {
+  name: "SoftmaxCrossEntropyWithLogits"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "labels"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "loss"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "SoftmaxCrossEntropyWithLogits"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "labels"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "loss"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Softplus.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Softplus.pbtxt
new file mode 100644
index 0000000..3757e8d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Softplus.pbtxt
@@ -0,0 +1,143 @@
+op {
+  name: "Softplus"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Softplus"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Softplus"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Softplus"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Softplus"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SoftplusGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SoftplusGrad.pbtxt
new file mode 100644
index 0000000..331b1ab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SoftplusGrad.pbtxt
@@ -0,0 +1,163 @@
+op {
+  name: "SoftplusGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SoftplusGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SoftplusGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SoftplusGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SoftplusGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Softsign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Softsign.pbtxt
new file mode 100644
index 0000000..c83bc99
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Softsign.pbtxt
@@ -0,0 +1,143 @@
+op {
+  name: "Softsign"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "Softsign"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Softsign"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "Softsign"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "Softsign"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "activations"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SoftsignGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SoftsignGrad.pbtxt
new file mode 100644
index 0000000..5411f9b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SoftsignGrad.pbtxt
@@ -0,0 +1,163 @@
+op {
+  name: "SoftsignGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SoftsignGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SoftsignGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SoftsignGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SoftsignGrad"
+  input_arg {
+    name: "gradients"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprops"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatch.pbtxt
new file mode 100644
index 0000000..155e1b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatch.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "SpaceToBatch"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatchND.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatchND.pbtxt
new file mode 100644
index 0000000..c38026e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SpaceToBatchND.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "SpaceToBatchND"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "block_shape"
+    type_attr: "Tblock_shape"
+  }
+  input_arg {
+    name: "paddings"
+    type_attr: "Tpaddings"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tblock_shape"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tpaddings"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SpaceToDepth.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SpaceToDepth.pbtxt
new file mode 100644
index 0000000..c7dd03e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SpaceToDepth.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "SpaceToDepth"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+}
+op {
+  name: "SpaceToDepth"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "block_size"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "data_format"
+    type: "string"
+    default_value {
+      s: "NHWC"
+    }
+    allowed_values {
+      list {
+        s: "NHWC"
+        s: "NCHW"
+        s: "NCHW_VECT_C"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorApplyGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorApplyGradient.pbtxt
new file mode 100644
index 0000000..45b7e1c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorApplyGradient.pbtxt
@@ -0,0 +1,208 @@
+op {
+  name: "SparseAccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_values"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "gradient_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "has_known_shape"
+    type: "bool"
+  }
+}
+op {
+  name: "SparseAccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_values"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "gradient_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "has_known_shape"
+    type: "bool"
+  }
+}
+op {
+  name: "SparseAccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_values"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "gradient_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "has_known_shape"
+    type: "bool"
+  }
+}
+op {
+  name: "SparseAccumulatorApplyGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "local_step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "gradient_values"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "gradient_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "has_known_shape"
+    type: "bool"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorTakeGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorTakeGradient.pbtxt
new file mode 100644
index 0000000..e12c8c2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseAccumulatorTakeGradient.pbtxt
@@ -0,0 +1,192 @@
+op {
+  name: "SparseAccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseAccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseAccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseAccumulatorTakeGradient"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "num_required"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseAdd.pbtxt
new file mode 100644
index 0000000..a7f3970
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseAdd.pbtxt
@@ -0,0 +1,344 @@
+op {
+  name: "SparseAdd"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "thresh"
+    type_attr: "Treal"
+  }
+  output_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sum_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sum_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Treal"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseAdd"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "thresh"
+    type_attr: "Treal"
+  }
+  output_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sum_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sum_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Treal"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseAdd"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "thresh"
+    type_attr: "Treal"
+  }
+  output_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sum_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sum_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Treal"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseAdd"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "thresh"
+    type_attr: "Treal"
+  }
+  output_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sum_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "sum_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Treal"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseAddGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseAddGrad.pbtxt
new file mode 100644
index 0000000..87edd91
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseAddGrad.pbtxt
@@ -0,0 +1,204 @@
+op {
+  name: "SparseAddGrad"
+  input_arg {
+    name: "backprop_val_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "a_val_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "b_val_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseAddGrad"
+  input_arg {
+    name: "backprop_val_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "a_val_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "b_val_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseAddGrad"
+  input_arg {
+    name: "backprop_val_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "a_val_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "b_val_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseAddGrad"
+  input_arg {
+    name: "backprop_val_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sum_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "a_val_grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "b_val_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdadelta.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdadelta.pbtxt
new file mode 100644
index 0000000..655b73d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdadelta.pbtxt
@@ -0,0 +1,336 @@
+op {
+  name: "SparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdadelta"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum_update"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagrad.pbtxt
new file mode 100644
index 0000000..70d42a4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagrad.pbtxt
@@ -0,0 +1,363 @@
+op {
+  name: "SparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradDA.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradDA.pbtxt
new file mode 100644
index 0000000..dedda34
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradDA.pbtxt
@@ -0,0 +1,352 @@
+op {
+  name: "SparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyAdagradDA"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "gradient_squared_accumulator"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "global_step"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradV2.pbtxt
new file mode 100644
index 0000000..8a6a92b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyAdagradV2.pbtxt
@@ -0,0 +1,83 @@
+op {
+  name: "SparseApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyCenteredRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyCenteredRMSProp.pbtxt
new file mode 100644
index 0000000..4ae5fb1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyCenteredRMSProp.pbtxt
@@ -0,0 +1,372 @@
+op {
+  name: "SparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyCenteredRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mg"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrl.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrl.pbtxt
new file mode 100644
index 0000000..f20f106
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrl.pbtxt
@@ -0,0 +1,352 @@
+op {
+  name: "SparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrl"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrlV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrlV2.pbtxt
new file mode 100644
index 0000000..93a7eff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyFtrlV2.pbtxt
@@ -0,0 +1,368 @@
+op {
+  name: "SparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyFtrlV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "linear"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2_shrinkage"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lr_power"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyMomentum.pbtxt
new file mode 100644
index 0000000..7ea3d81
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyMomentum.pbtxt
@@ -0,0 +1,328 @@
+op {
+  name: "SparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyMomentum"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalAdagrad.pbtxt
new file mode 100644
index 0000000..165a892
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalAdagrad.pbtxt
@@ -0,0 +1,316 @@
+op {
+  name: "SparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalAdagrad"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalGradientDescent.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalGradientDescent.pbtxt
new file mode 100644
index 0000000..f661a63
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyProximalGradientDescent.pbtxt
@@ -0,0 +1,296 @@
+op {
+  name: "SparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyProximalGradientDescent"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "alpha"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l1"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "l2"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseApplyRMSProp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyRMSProp.pbtxt
new file mode 100644
index 0000000..254fb7f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseApplyRMSProp.pbtxt
@@ -0,0 +1,352 @@
+op {
+  name: "SparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseApplyRMSProp"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "ms"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "mom"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rho"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "momentum"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseConcat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseConcat.pbtxt
new file mode 100644
index 0000000..ac291f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseConcat.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "SparseConcat"
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "concat_dim"
+    type: "int"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 2
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseConditionalAccumulator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseConditionalAccumulator.pbtxt
new file mode 100644
index 0000000..8ede8ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseConditionalAccumulator.pbtxt
@@ -0,0 +1,269 @@
+op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseCross.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseCross.pbtxt
new file mode 100644
index 0000000..f25372f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseCross.pbtxt
@@ -0,0 +1,93 @@
+op {
+  name: "SparseCross"
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "values"
+    type_list_attr: "sparse_types"
+  }
+  input_arg {
+    name: "shapes"
+    type: DT_INT64
+    number_attr: "N"
+  }
+  input_arg {
+    name: "dense_inputs"
+    type_list_attr: "dense_types"
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "out_type"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "hashed_output"
+    type: "bool"
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "hash_key"
+    type: "int"
+  }
+  attr {
+    name: "sparse_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "dense_types"
+    type: "list(type)"
+    has_minimum: true
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+  attr {
+    name: "internal_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseAdd.pbtxt
new file mode 100644
index 0000000..87eb5e4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseAdd.pbtxt
@@ -0,0 +1,188 @@
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseDiv.pbtxt
new file mode 100644
index 0000000..e3b0f58
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseDiv.pbtxt
@@ -0,0 +1,188 @@
+op {
+  name: "SparseDenseCwiseDiv"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseDiv"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseDiv"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseDiv"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseMul.pbtxt
new file mode 100644
index 0000000..494ce78
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseDenseCwiseMul.pbtxt
@@ -0,0 +1,188 @@
+op {
+  name: "SparseDenseCwiseMul"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseMul"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseMul"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseDenseCwiseMul"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRows.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRows.pbtxt
new file mode 100644
index 0000000..d99257a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRows.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "SparseFillEmptyRows"
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dense_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "empty_row_indicator"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "reverse_index_map"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRowsGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRowsGrad.pbtxt
new file mode 100644
index 0000000..87f1c5c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseFillEmptyRowsGrad.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "SparseFillEmptyRowsGrad"
+  input_arg {
+    name: "reverse_index_map"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "grad_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "d_default_value"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseMatMul.pbtxt
new file mode 100644
index 0000000..d1eaa6a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseMatMul.pbtxt
@@ -0,0 +1,69 @@
+op {
+  name: "SparseMatMul"
+  input_arg {
+    name: "a"
+    type_attr: "Ta"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "Tb"
+  }
+  output_arg {
+    name: "product"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "transpose_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "transpose_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "a_is_sparse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "b_is_sparse"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Ta"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tb"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMax.pbtxt
new file mode 100644
index 0000000..4df1254
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMax.pbtxt
@@ -0,0 +1,196 @@
+op {
+  name: "SparseReduceMax"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMax"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMax"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMax"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMaxSparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMaxSparse.pbtxt
new file mode 100644
index 0000000..8189644
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceMaxSparse.pbtxt
@@ -0,0 +1,228 @@
+op {
+  name: "SparseReduceMaxSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMaxSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMaxSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceMaxSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSum.pbtxt
new file mode 100644
index 0000000..48a5627
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSum.pbtxt
@@ -0,0 +1,216 @@
+op {
+  name: "SparseReduceSum"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSum"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSum"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSum"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSumSparse.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSumSparse.pbtxt
new file mode 100644
index 0000000..1c13464
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReduceSumSparse.pbtxt
@@ -0,0 +1,248 @@
+op {
+  name: "SparseReduceSumSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSumSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSumSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseReduceSumSparse"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "reduction_axes"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReorder.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReorder.pbtxt
new file mode 100644
index 0000000..5c5ad90
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReorder.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "SparseReorder"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseReshape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseReshape.pbtxt
new file mode 100644
index 0000000..934b501
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseReshape.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "SparseReshape"
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "new_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMean.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMean.pbtxt
new file mode 100644
index 0000000..0447e6f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMean.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "SparseSegmentMean"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanGrad.pbtxt
new file mode 100644
index 0000000..c31439f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanGrad.pbtxt
@@ -0,0 +1,46 @@
+op {
+  name: "SparseSegmentMeanGrad"
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "output_dim0"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanWithNumSegments.pbtxt
new file mode 100644
index 0000000..ed3693a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentMeanWithNumSegments.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "SparseSegmentMeanWithNumSegments"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtN.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtN.pbtxt
new file mode 100644
index 0000000..f856480
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtN.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "SparseSegmentSqrtN"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNGrad.pbtxt
new file mode 100644
index 0000000..569b5b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNGrad.pbtxt
@@ -0,0 +1,46 @@
+op {
+  name: "SparseSegmentSqrtNGrad"
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "output_dim0"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNWithNumSegments.pbtxt
new file mode 100644
index 0000000..753cfe4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -0,0 +1,59 @@
+op {
+  name: "SparseSegmentSqrtNWithNumSegments"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSum.pbtxt
new file mode 100644
index 0000000..9ecc207
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSum.pbtxt
@@ -0,0 +1,204 @@
+op {
+  name: "SparseSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSumWithNumSegments.pbtxt
new file mode 100644
index 0000000..0608745
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSegmentSumWithNumSegments.pbtxt
@@ -0,0 +1,138 @@
+op {
+  name: "SparseSegmentSumWithNumSegments"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSegmentSumWithNumSegments"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "segment_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSlice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSlice.pbtxt
new file mode 100644
index 0000000..a6434cb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSlice.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "SparseSlice"
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "start"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSliceGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSliceGrad.pbtxt
new file mode 100644
index 0000000..ce82f94
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSliceGrad.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "SparseSliceGrad"
+  input_arg {
+    name: "backprop_val_grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "input_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "input_start"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "val_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmax.pbtxt
new file mode 100644
index 0000000..efa3df8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmax.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "SparseSoftmax"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmaxCrossEntropyWithLogits.pbtxt
new file mode 100644
index 0000000..57d8f4c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSoftmaxCrossEntropyWithLogits.pbtxt
@@ -0,0 +1,87 @@
+op {
+  name: "SparseSoftmaxCrossEntropyWithLogits"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "labels"
+    type_attr: "Tlabels"
+  }
+  output_arg {
+    name: "loss"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tlabels"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSoftmaxCrossEntropyWithLogits"
+  input_arg {
+    name: "features"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "labels"
+    type_attr: "Tlabels"
+  }
+  output_arg {
+    name: "loss"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "backprop"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "Tlabels"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMaximum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMaximum.pbtxt
new file mode 100644
index 0000000..bdd017c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMaximum.pbtxt
@@ -0,0 +1,216 @@
+op {
+  name: "SparseSparseMaximum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMaximum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMaximum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMaximum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMinimum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMinimum.pbtxt
new file mode 100644
index 0000000..52fc6be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSparseMinimum.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "SparseSparseMinimum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMinimum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMinimum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "SparseSparseMinimum"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "b_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseSplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseSplit.pbtxt
new file mode 100644
index 0000000..997b2b2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseSplit.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "SparseSplit"
+  input_arg {
+    name: "split_dim"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_indices"
+    type: DT_INT64
+    number_attr: "num_split"
+  }
+  output_arg {
+    name: "output_values"
+    type_attr: "T"
+    number_attr: "num_split"
+  }
+  output_arg {
+    name: "output_shape"
+    type: DT_INT64
+    number_attr: "num_split"
+  }
+  attr {
+    name: "num_split"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseAdd.pbtxt
new file mode 100644
index 0000000..397d3fd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseAdd.pbtxt
@@ -0,0 +1,228 @@
+op {
+  name: "SparseTensorDenseAdd"
+  input_arg {
+    name: "a_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseTensorDenseAdd"
+  input_arg {
+    name: "a_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseTensorDenseAdd"
+  input_arg {
+    name: "a_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "SparseTensorDenseAdd"
+  input_arg {
+    name: "a_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseMatMul.pbtxt
new file mode 100644
index 0000000..ce66c53
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorDenseMatMul.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "SparseTensorDenseMatMul"
+  input_arg {
+    name: "a_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "adjoint_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adjoint_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "SparseTensorDenseMatMul"
+  input_arg {
+    name: "a_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "a_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "a_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "b"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "product"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "adjoint_a"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "adjoint_b"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseTensorSliceDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorSliceDataset.pbtxt
new file mode 100644
index 0000000..0009238
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseTensorSliceDataset.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "SparseTensorSliceDataset"
+  input_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "values"
+    type_attr: "Tvalues"
+  }
+  input_arg {
+    name: "dense_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Tvalues"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseToDense.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseToDense.pbtxt
new file mode 100644
index 0000000..3516034
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseToDense.pbtxt
@@ -0,0 +1,44 @@
+op {
+  name: "SparseToDense"
+  input_arg {
+    name: "sparse_indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "output_shape"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "sparse_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SparseToSparseSetOperation.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SparseToSparseSetOperation.pbtxt
new file mode 100644
index 0000000..a7775a2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SparseToSparseSetOperation.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "SparseToSparseSetOperation"
+  input_arg {
+    name: "set1_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "set1_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set1_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "set2_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "set2_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "set2_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "result_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "result_values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "result_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "set_operation"
+    type: "string"
+  }
+  attr {
+    name: "validate_indices"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_STRING
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Split.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Split.pbtxt
new file mode 100644
index 0000000..49428f7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Split.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "Split"
+  input_arg {
+    name: "split_dim"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    number_attr: "num_split"
+  }
+  attr {
+    name: "num_split"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SplitV.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SplitV.pbtxt
new file mode 100644
index 0000000..a7a9839
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SplitV.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "SplitV"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "size_splits"
+    type_attr: "Tlen"
+  }
+  input_arg {
+    name: "split_dim"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    number_attr: "num_split"
+  }
+  attr {
+    name: "num_split"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tlen"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SqlDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SqlDataset.pbtxt
new file mode 100644
index 0000000..337b379
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SqlDataset.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "SqlDataset"
+  input_arg {
+    name: "driver_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_source_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "query"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sqrt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sqrt.pbtxt
new file mode 100644
index 0000000..3c566b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sqrt.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Sqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sqrt"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SqrtGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SqrtGrad.pbtxt
new file mode 100644
index 0000000..d738e20
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SqrtGrad.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "SqrtGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "SqrtGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Square.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Square.pbtxt
new file mode 100644
index 0000000..4d07faf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Square.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Square"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Square"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Square"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SquaredDifference.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SquaredDifference.pbtxt
new file mode 100644
index 0000000..29ea33c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SquaredDifference.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "SquaredDifference"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "SquaredDifference"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
+op {
+  name: "SquaredDifference"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  is_commutative: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Squeeze.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Squeeze.pbtxt
new file mode 100644
index 0000000..5433554
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Squeeze.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "Squeeze"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "squeeze_dims"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Stack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Stack.pbtxt
new file mode 100644
index 0000000..e8e459c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Stack.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "Stack"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "elem_type"
+    type: "type"
+  }
+  attr {
+    name: "stack_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackClose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackClose.pbtxt
new file mode 100644
index 0000000..8c916ab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackClose.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "StackClose"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackCloseV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackCloseV2.pbtxt
new file mode 100644
index 0000000..18c5934
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackCloseV2.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "StackCloseV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackPop.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackPop.pbtxt
new file mode 100644
index 0000000..80e3ef7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackPop.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "StackPop"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  output_arg {
+    name: "elem"
+    type_attr: "elem_type"
+  }
+  attr {
+    name: "elem_type"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackPopV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackPopV2.pbtxt
new file mode 100644
index 0000000..438d52b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackPopV2.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "StackPopV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "elem"
+    type_attr: "elem_type"
+  }
+  attr {
+    name: "elem_type"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackPush.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackPush.pbtxt
new file mode 100644
index 0000000..44fae0c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackPush.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "StackPush"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "elem"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "swap_memory"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackPushV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackPushV2.pbtxt
new file mode 100644
index 0000000..7149b4f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackPushV2.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "StackPushV2"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "elem"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "swap_memory"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StackV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StackV2.pbtxt
new file mode 100644
index 0000000..606361d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StackV2.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "StackV2"
+  input_arg {
+    name: "max_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "elem_type"
+    type: "type"
+  }
+  attr {
+    name: "stack_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Stage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Stage.pbtxt
new file mode 100644
index 0000000..8a64d69
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Stage.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "Stage"
+  input_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Stage"
+  input_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StageClear.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StageClear.pbtxt
new file mode 100644
index 0000000..1f43cdb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StageClear.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "StageClear"
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StagePeek.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StagePeek.pbtxt
new file mode 100644
index 0000000..a7397c4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StagePeek.pbtxt
@@ -0,0 +1,48 @@
+op {
+  name: "StagePeek"
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StageSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StageSize.pbtxt
new file mode 100644
index 0000000..6f22fd3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StageSize.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "StageSize"
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulPartitionedCall.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulPartitionedCall.pbtxt
new file mode 100644
index 0000000..b6f9b97
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulPartitionedCall.pbtxt
@@ -0,0 +1,146 @@
+op {
+  name: "StatefulPartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  is_stateful: true
+}
+op {
+  name: "StatefulPartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "StatefulPartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "executor_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "StatefulPartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "config"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "config_proto"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "executor_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulRandomBinomial.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulRandomBinomial.pbtxt
new file mode 100644
index 0000000..97eb7d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulRandomBinomial.pbtxt
@@ -0,0 +1,70 @@
+op {
+  name: "StatefulRandomBinomial"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "S"
+  }
+  input_arg {
+    name: "counts"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "probs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "S"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_DOUBLE
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormal.pbtxt
new file mode 100644
index 0000000..44ef92c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormal.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "StatefulStandardNormal"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "StatefulStandardNormal"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "StatefulStandardNormal"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  deprecation {
+    version: 29
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormalV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormalV2.pbtxt
new file mode 100644
index 0000000..1b99b23
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulStandardNormalV2.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "StatefulStandardNormalV2"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulTruncatedNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulTruncatedNormal.pbtxt
new file mode 100644
index 0000000..e74de4f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulTruncatedNormal.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "StatefulTruncatedNormal"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulUniform.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniform.pbtxt
new file mode 100644
index 0000000..fd2b87c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniform.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "StatefulUniform"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformFullInt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformFullInt.pbtxt
new file mode 100644
index 0000000..35ab70e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformFullInt.pbtxt
@@ -0,0 +1,34 @@
+op {
+  name: "StatefulUniformFullInt"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_UINT64
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformInt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformInt.pbtxt
new file mode 100644
index 0000000..06f62fa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatefulUniformInt.pbtxt
@@ -0,0 +1,42 @@
+op {
+  name: "StatefulUniformInt"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "algorithm"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shape"
+    type_attr: "shape_dtype"
+  }
+  input_arg {
+    name: "minval"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "maxval"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  attr {
+    name: "shape_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessIf.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessIf.pbtxt
new file mode 100644
index 0000000..6eda6df
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessIf.pbtxt
@@ -0,0 +1,82 @@
+op {
+  name: "StatelessIf"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+}
+op {
+  name: "StatelessIf"
+  input_arg {
+    name: "cond"
+    type_attr: "Tcond"
+  }
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tcond"
+    type: "type"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "then_branch"
+    type: "func"
+  }
+  attr {
+    name: "else_branch"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessMultinomial.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessMultinomial.pbtxt
new file mode 100644
index 0000000..16dac7d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessMultinomial.pbtxt
@@ -0,0 +1,65 @@
+op {
+  name: "StatelessMultinomial"
+  input_arg {
+    name: "logits"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "num_samples"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "output_dtype"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "output_dtype"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomNormal.pbtxt
new file mode 100644
index 0000000..804d904
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomNormal.pbtxt
@@ -0,0 +1,153 @@
+op {
+  name: "StatelessRandomNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessRandomNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessRandomNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniform.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniform.pbtxt
new file mode 100644
index 0000000..22a5b25
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniform.pbtxt
@@ -0,0 +1,153 @@
+op {
+  name: "StatelessRandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessRandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessRandomUniform"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniformInt.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000..834a6fd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,56 @@
+op {
+  name: "StatelessRandomUniformInt"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  input_arg {
+    name: "minval"
+    type_attr: "dtype"
+  }
+  input_arg {
+    name: "maxval"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessTruncatedNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessTruncatedNormal.pbtxt
new file mode 100644
index 0000000..c8c8d85
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessTruncatedNormal.pbtxt
@@ -0,0 +1,153 @@
+op {
+  name: "StatelessTruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessTruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "StatelessTruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "seed"
+    type_attr: "Tseed"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tseed"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatelessWhile.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatelessWhile.pbtxt
new file mode 100644
index 0000000..28579ed
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatelessWhile.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "StatelessWhile"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+}
+op {
+  name: "StatelessWhile"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StaticRegexFullMatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000..be6078c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StaticRegexFullMatch.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "StaticRegexFullMatch"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+  attr {
+    name: "pattern"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StaticRegexReplace.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StaticRegexReplace.pbtxt
new file mode 100644
index 0000000..fe3eb69
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StaticRegexReplace.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "StaticRegexReplace"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "pattern"
+    type: "string"
+  }
+  attr {
+    name: "rewrite"
+    type: "string"
+  }
+  attr {
+    name: "replace_global"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandle.pbtxt
new file mode 100644
index 0000000..2d55e00
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandle.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "StatsAggregatorHandle"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandleV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandleV2.pbtxt
new file mode 100644
index 0000000..7dc361e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorHandleV2.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "StatsAggregatorHandleV2"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSetSummaryWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSetSummaryWriter.pbtxt
new file mode 100644
index 0000000..24730ad
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSetSummaryWriter.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "StatsAggregatorSetSummaryWriter"
+  input_arg {
+    name: "stats_aggregator"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "summary"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSummary.pbtxt
new file mode 100644
index 0000000..a0702a1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StatsAggregatorSummary.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "StatsAggregatorSummary"
+  input_arg {
+    name: "iterator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StopGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StopGradient.pbtxt
new file mode 100644
index 0000000..26f7c67
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StopGradient.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "StopGradient"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StridedSlice.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StridedSlice.pbtxt
new file mode 100644
index 0000000..7c5fd7b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StridedSlice.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "StridedSlice"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "end"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "strides"
+    type_attr: "Index"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "begin_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "end_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ellipsis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "new_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "shrink_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StridedSliceAssign.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StridedSliceAssign.pbtxt
new file mode 100644
index 0000000..8393dc7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StridedSliceAssign.pbtxt
@@ -0,0 +1,78 @@
+op {
+  name: "StridedSliceAssign"
+  input_arg {
+    name: "ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "end"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "strides"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_ref"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "begin_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "end_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ellipsis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "new_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "shrink_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StridedSliceGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StridedSliceGrad.pbtxt
new file mode 100644
index 0000000..14f6a46
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StridedSliceGrad.pbtxt
@@ -0,0 +1,76 @@
+op {
+  name: "StridedSliceGrad"
+  input_arg {
+    name: "shape"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "end"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "strides"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "begin_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "end_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ellipsis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "new_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "shrink_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringFormat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringFormat.pbtxt
new file mode 100644
index 0000000..bea3290
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringFormat.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "StringFormat"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "template"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "placeholder"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringJoin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringJoin.pbtxt
new file mode 100644
index 0000000..5854eb6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringJoin.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "StringJoin"
+  input_arg {
+    name: "inputs"
+    type: DT_STRING
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "separator"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringLength.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringLength.pbtxt
new file mode 100644
index 0000000..5bdf993
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringLength.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "StringLength"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT32
+  }
+}
+op {
+  name: "StringLength"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT32
+  }
+  attr {
+    name: "unit"
+    type: "string"
+    default_value {
+      s: "BYTE"
+    }
+    allowed_values {
+      list {
+        s: "BYTE"
+        s: "UTF8_CHAR"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringLower.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringLower.pbtxt
new file mode 100644
index 0000000..1c88614
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringLower.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "StringLower"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "encoding"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringNGrams.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringNGrams.pbtxt
new file mode 100644
index 0000000..025fc05
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringNGrams.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "StringNGrams"
+  input_arg {
+    name: "data"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "ngrams"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "ngrams_splits"
+    type_attr: "Tsplits"
+  }
+  attr {
+    name: "separator"
+    type: "string"
+  }
+  attr {
+    name: "ngram_widths"
+    type: "list(int)"
+    has_minimum: true
+  }
+  attr {
+    name: "left_pad"
+    type: "string"
+  }
+  attr {
+    name: "right_pad"
+    type: "string"
+  }
+  attr {
+    name: "pad_width"
+    type: "int"
+  }
+  attr {
+    name: "preserve_short_sequences"
+    type: "bool"
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringSplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringSplit.pbtxt
new file mode 100644
index 0000000..35e8594
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringSplit.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "StringSplit"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "delimiter"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+}
+op {
+  name: "StringSplit"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "delimiter"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "skip_empty"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringSplitV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringSplitV2.pbtxt
new file mode 100644
index 0000000..fbdf8e0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringSplitV2.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "StringSplitV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "sep"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "maxsplit"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringStrip.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringStrip.pbtxt
new file mode 100644
index 0000000..3fff999
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringStrip.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "StringStrip"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucket.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucket.pbtxt
new file mode 100644
index 0000000..7147a40
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucket.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "StringToHashBucket"
+  input_arg {
+    name: "string_tensor"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketFast.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketFast.pbtxt
new file mode 100644
index 0000000..8ef1227
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketFast.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "StringToHashBucketFast"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketStrong.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketStrong.pbtxt
new file mode 100644
index 0000000..2dbd992
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringToHashBucketStrong.pbtxt
@@ -0,0 +1,21 @@
+op {
+  name: "StringToHashBucketStrong"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "key"
+    type: "list(int)"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringToNumber.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringToNumber.pbtxt
new file mode 100644
index 0000000..6809380
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringToNumber.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "StringToNumber"
+  input_arg {
+    name: "string_tensor"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_INT32
+      }
+    }
+  }
+}
+op {
+  name: "StringToNumber"
+  input_arg {
+    name: "string_tensor"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/StringUpper.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/StringUpper.pbtxt
new file mode 100644
index 0000000..8df4881
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/StringUpper.pbtxt
@@ -0,0 +1,18 @@
+op {
+  name: "StringUpper"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "encoding"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sub.pbtxt
new file mode 100644
index 0000000..b9b17ec
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sub.pbtxt
@@ -0,0 +1,134 @@
+op {
+  name: "Sub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Sub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Substr.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Substr.pbtxt
new file mode 100644
index 0000000..a5c1d2c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Substr.pbtxt
@@ -0,0 +1,71 @@
+op {
+  name: "Substr"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "pos"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "len"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Substr"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "pos"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "len"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "unit"
+    type: "string"
+    default_value {
+      s: "BYTE"
+    }
+    allowed_values {
+      list {
+        s: "BYTE"
+        s: "UTF8_CHAR"
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Sum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Sum.pbtxt
new file mode 100644
index 0000000..d21e444
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Sum.pbtxt
@@ -0,0 +1,236 @@
+op {
+  name: "Sum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Sum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Sum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "Sum"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SummaryWriter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SummaryWriter.pbtxt
new file mode 100644
index 0000000..a6fd9170
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SummaryWriter.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "SummaryWriter"
+  output_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Svd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Svd.pbtxt
new file mode 100644
index 0000000..4800390
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Svd.pbtxt
@@ -0,0 +1,91 @@
+op {
+  name: "Svd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "s"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "u"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_uv"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "full_matrices"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Svd"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "s"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "u"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_uv"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "full_matrices"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_HALF
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Switch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Switch.pbtxt
new file mode 100644
index 0000000..0856f34
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Switch.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "Switch"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "pred"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "output_false"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output_true"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/SymbolicGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SymbolicGradient.pbtxt
new file mode 100644
index 0000000..aae5457
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/SymbolicGradient.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "SymbolicGradient"
+  input_arg {
+    name: "input"
+    type_list_attr: "Tin"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TFRecordDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TFRecordDataset.pbtxt
new file mode 100644
index 0000000..dd8ac37
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TFRecordDataset.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "TFRecordDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TFRecordReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TFRecordReader.pbtxt
new file mode 100644
index 0000000..684c21e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TFRecordReader.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "TFRecordReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "compression_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "TFRecordReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "compression_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TFRecordReaderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TFRecordReaderV2.pbtxt
new file mode 100644
index 0000000..bcdb476
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TFRecordReaderV2.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "TFRecordReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "compression_type"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUCompilationResult.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUCompilationResult.pbtxt
new file mode 100644
index 0000000..04a95cc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUCompilationResult.pbtxt
@@ -0,0 +1,7 @@
+op {
+  name: "TPUCompilationResult"
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUEmbeddingActivations.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUEmbeddingActivations.pbtxt
new file mode 100644
index 0000000..3975077
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUEmbeddingActivations.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "TPUEmbeddingActivations"
+  input_arg {
+    name: "embedding_variable"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sliced_activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_id"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "lookup_id"
+    type: "int"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUOrdinalSelector.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUOrdinalSelector.pbtxt
new file mode 100644
index 0000000..3fb2725
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUOrdinalSelector.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "TPUOrdinalSelector"
+  output_arg {
+    name: "device_ordinals"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUPartitionedCall.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUPartitionedCall.pbtxt
new file mode 100644
index 0000000..534f78f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUPartitionedCall.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "TPUPartitionedCall"
+  input_arg {
+    name: "args"
+    type_list_attr: "Tin"
+  }
+  input_arg {
+    name: "device_ordinal"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUReplicateMetadata.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicateMetadata.pbtxt
new file mode 100644
index 0000000..1d1b1ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicateMetadata.pbtxt
@@ -0,0 +1,204 @@
+op {
+  name: "TPUReplicateMetadata"
+  attr {
+    name: "num_replicas"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_cores_per_replica"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "topology"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_tpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "device_assignment"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "computation_shape"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "host_compute_core"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "padding_map"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
+op {
+  name: "TPUReplicateMetadata"
+  attr {
+    name: "num_replicas"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_cores_per_replica"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "topology"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_tpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "device_assignment"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "computation_shape"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "host_compute_core"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "padding_map"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "step_marker_location"
+    type: "string"
+    default_value {
+      s: "STEP_MARK_AT_ENTRY"
+    }
+  }
+}
+op {
+  name: "TPUReplicateMetadata"
+  attr {
+    name: "num_replicas"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "num_cores_per_replica"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "topology"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "use_tpu"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "device_assignment"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "computation_shape"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "host_compute_core"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "padding_map"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "step_marker_location"
+    type: "string"
+    default_value {
+      s: "STEP_MARK_AT_ENTRY"
+    }
+  }
+  attr {
+    name: "allow_soft_placement"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedInput.pbtxt
new file mode 100644
index 0000000..431df69
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedInput.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "TPUReplicatedInput"
+  input_arg {
+    name: "inputs"
+    type_attr: "T"
+    number_attr: "N"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedOutput.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedOutput.pbtxt
new file mode 100644
index 0000000..70b7d0a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TPUReplicatedOutput.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "TPUReplicatedOutput"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "outputs"
+    type_attr: "T"
+    number_attr: "num_replicas"
+  }
+  attr {
+    name: "num_replicas"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TakeDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TakeDataset.pbtxt
new file mode 100644
index 0000000..9993d40
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TakeDataset.pbtxt
@@ -0,0 +1,55 @@
+op {
+  name: "TakeDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "TakeDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "count"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TakeManySparseFromTensorsMap.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TakeManySparseFromTensorsMap.pbtxt
new file mode 100644
index 0000000..0e3ca63
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TakeManySparseFromTensorsMap.pbtxt
@@ -0,0 +1,38 @@
+op {
+  name: "TakeManySparseFromTensorsMap"
+  input_arg {
+    name: "sparse_handles"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sparse_values"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "sparse_shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TakeWhileDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TakeWhileDataset.pbtxt
new file mode 100644
index 0000000..87841fa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TakeWhileDataset.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "TakeWhileDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "predicate"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Tan.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Tan.pbtxt
new file mode 100644
index 0000000..7dc7f84
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Tan.pbtxt
@@ -0,0 +1,80 @@
+op {
+  name: "Tan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Tan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Tan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Tanh.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Tanh.pbtxt
new file mode 100644
index 0000000..1672b0d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Tanh.pbtxt
@@ -0,0 +1,74 @@
+op {
+  name: "Tanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Tanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "Tanh"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TanhGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TanhGrad.pbtxt
new file mode 100644
index 0000000..67d28f8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TanhGrad.pbtxt
@@ -0,0 +1,114 @@
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "dy"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TemporaryVariable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TemporaryVariable.pbtxt
new file mode 100644
index 0000000..191354e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TemporaryVariable.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TemporaryVariable"
+  output_arg {
+    name: "ref"
+    type_attr: "dtype"
+    is_ref: true
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "var_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArray.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArray.pbtxt
new file mode 100644
index 0000000..74b1a54
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArray.pbtxt
@@ -0,0 +1,50 @@
+op {
+  name: "TensorArray"
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "dynamic_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clear_after_read"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "tensor_array_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  deprecation {
+    version: 16
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayClose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayClose.pbtxt
new file mode 100644
index 0000000..63c0100
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayClose.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "TensorArrayClose"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV2.pbtxt
new file mode 100644
index 0000000..b0fb580
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV2.pbtxt
@@ -0,0 +1,17 @@
+op {
+  name: "TensorArrayCloseV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+}
+op {
+  name: "TensorArrayCloseV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV3.pbtxt
new file mode 100644
index 0000000..c5d1c2b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayCloseV3.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "TensorArrayCloseV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcat.pbtxt
new file mode 100644
index 0000000..e2c59ab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcat.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "TensorArrayConcat"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape_except0"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV2.pbtxt
new file mode 100644
index 0000000..72376bd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV2.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "TensorArrayConcatV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape_except0"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV3.pbtxt
new file mode 100644
index 0000000..91e575c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayConcatV3.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorArrayConcatV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape_except0"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGather.pbtxt
new file mode 100644
index 0000000..a8ded38
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGather.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "TensorArrayGather"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV2.pbtxt
new file mode 100644
index 0000000..f729683
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV2.pbtxt
@@ -0,0 +1,67 @@
+op {
+  name: "TensorArrayGatherV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
+op {
+  name: "TensorArrayGatherV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV3.pbtxt
new file mode 100644
index 0000000..c87538a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGatherV3.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorArrayGatherV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGrad.pbtxt
new file mode 100644
index 0000000..4221545
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TensorArrayGrad"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "grad_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "source"
+    type: "string"
+  }
+  deprecation {
+    version: 16
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV2.pbtxt
new file mode 100644
index 0000000..d989c40
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV2.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "TensorArrayGradV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "grad_handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "source"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
+  name: "TensorArrayGradV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "grad_handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "source"
+    type: "string"
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV3.pbtxt
new file mode 100644
index 0000000..53e2042
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradV3.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TensorArrayGradV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "grad_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "source"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradWithShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradWithShape.pbtxt
new file mode 100644
index 0000000..1ce7390
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayGradWithShape.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "TensorArrayGradWithShape"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "shape_to_prepend"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "grad_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "source"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayPack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayPack.pbtxt
new file mode 100644
index 0000000..f608e45
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayPack.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "TensorArrayPack"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayRead.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayRead.pbtxt
new file mode 100644
index 0000000..62660be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayRead.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "TensorArrayRead"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV2.pbtxt
new file mode 100644
index 0000000..cd0a2a3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV2.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "TensorArrayReadV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+}
+op {
+  name: "TensorArrayReadV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV3.pbtxt
new file mode 100644
index 0000000..59e66fc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayReadV3.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TensorArrayReadV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "value"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatter.pbtxt
new file mode 100644
index 0000000..b201716
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatter.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "TensorArrayScatter"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 19
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV2.pbtxt
new file mode 100644
index 0000000..1eacf2d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV2.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "TensorArrayScatterV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "TensorArrayScatterV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV3.pbtxt
new file mode 100644
index 0000000..5053ed6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayScatterV3.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "TensorArrayScatterV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySize.pbtxt
new file mode 100644
index 0000000..7f6ce95
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySize.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "TensorArraySize"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV2.pbtxt
new file mode 100644
index 0000000..8ee9eda
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV2.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorArraySizeV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+}
+op {
+  name: "TensorArraySizeV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV3.pbtxt
new file mode 100644
index 0000000..8932b0d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySizeV3.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "TensorArraySizeV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplit.pbtxt
new file mode 100644
index 0000000..06bf8bf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplit.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "TensorArraySplit"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV2.pbtxt
new file mode 100644
index 0000000..b45ea7a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV2.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "TensorArraySplitV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "TensorArraySplitV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV3.pbtxt
new file mode 100644
index 0000000..c072c0c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArraySplitV3.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "TensorArraySplitV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayUnpack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayUnpack.pbtxt
new file mode 100644
index 0000000..81e5abe
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayUnpack.pbtxt
@@ -0,0 +1,27 @@
+op {
+  name: "TensorArrayUnpack"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 20
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV2.pbtxt
new file mode 100644
index 0000000..1293e19
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV2.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "TensorArrayV2"
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  attr {
+    name: "dynamic_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clear_after_read"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "tensor_array_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "TensorArrayV2"
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  attr {
+    name: "dynamic_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clear_after_read"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "tensor_array_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV3.pbtxt
new file mode 100644
index 0000000..906e407
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayV3.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "TensorArrayV3"
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "flow"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  attr {
+    name: "dynamic_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clear_after_read"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "tensor_array_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "TensorArrayV3"
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "flow"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+  attr {
+    name: "dynamic_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "clear_after_read"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "identical_element_shapes"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "tensor_array_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWrite.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWrite.pbtxt
new file mode 100644
index 0000000..8f1a94c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWrite.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "TensorArrayWrite"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 16
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV2.pbtxt
new file mode 100644
index 0000000..fa0c1a6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV2.pbtxt
@@ -0,0 +1,57 @@
+op {
+  name: "TensorArrayWriteV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
+op {
+  name: "TensorArrayWriteV2"
+  input_arg {
+    name: "handle"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 26
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV3.pbtxt
new file mode 100644
index 0000000..45327d4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorArrayWriteV3.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "TensorArrayWriteV3"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "flow_in"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "flow_out"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorDataset.pbtxt
new file mode 100644
index 0000000..ecb4fb5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorDataset.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TensorDataset"
+  input_arg {
+    name: "components"
+    type_list_attr: "Toutput_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestCreateTreeVariable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestCreateTreeVariable.pbtxt
new file mode 100644
index 0000000..e09d1be
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestCreateTreeVariable.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "TensorForestCreateTreeVariable"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "tree_config"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeDeserialize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeDeserialize.pbtxt
new file mode 100644
index 0000000..932eda7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeDeserialize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "TensorForestTreeDeserialize"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "tree_config"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeIsInitializedOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeIsInitializedOp.pbtxt
new file mode 100644
index 0000000..df8b190
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeIsInitializedOp.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "TensorForestTreeIsInitializedOp"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreePredict.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreePredict.pbtxt
new file mode 100644
index 0000000..8ee1a9b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreePredict.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "TensorForestTreePredict"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "dense_features"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "logits"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "logits_dimension"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeResourceHandleOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeResourceHandleOp.pbtxt
new file mode 100644
index 0000000..881aead
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeResourceHandleOp.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "TensorForestTreeResourceHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSerialize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSerialize.pbtxt
new file mode 100644
index 0000000..24350a7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSerialize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "TensorForestTreeSerialize"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "tree_config"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSize.pbtxt
new file mode 100644
index 0000000..4416110
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorForestTreeSize.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "TensorForestTreeSize"
+  input_arg {
+    name: "tree_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "tree_size"
+    type: DT_INT32
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListConcat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcat.pbtxt
new file mode 100644
index 0000000..010be2e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcat.pbtxt
@@ -0,0 +1,47 @@
+op {
+  name: "TensorListConcat"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
+op {
+  name: "TensorListConcat"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "element_shape"
+    type: "shape"
+    default_value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatLists.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatLists.pbtxt
new file mode 100644
index 0000000..faf228c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatLists.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "TensorListConcatLists"
+  input_arg {
+    name: "input_a"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "input_b"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "output"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatV2.pbtxt
new file mode 100644
index 0000000..0bb9546
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListConcatV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "TensorListConcatV2"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  input_arg {
+    name: "leading_dims"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListElementShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListElementShape.pbtxt
new file mode 100644
index 0000000..26b982f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListElementShape.pbtxt
@@ -0,0 +1,21 @@
+op {
+  name: "TensorListElementShape"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListFromTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListFromTensor.pbtxt
new file mode 100644
index 0000000..6372e1e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListFromTensor.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "TensorListFromTensor"
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListGather.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListGather.pbtxt
new file mode 100644
index 0000000..43b4773
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListGather.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorListGather"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "element_shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "element_dtype"
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListGetItem.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListGetItem.pbtxt
new file mode 100644
index 0000000..fa124bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListGetItem.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorListGetItem"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "element_shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "item"
+    type_attr: "element_dtype"
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListLength.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListLength.pbtxt
new file mode 100644
index 0000000..b4ea660
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListLength.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "TensorListLength"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "length"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListPopBack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListPopBack.pbtxt
new file mode 100644
index 0000000..35aa68e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListPopBack.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorListPopBack"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "element_shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBack.pbtxt
new file mode 100644
index 0000000..f603c30
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBack.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "TensorListPushBack"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBackBatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000..186a36f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListPushBackBatch.pbtxt
@@ -0,0 +1,19 @@
+op {
+  name: "TensorListPushBackBatch"
+  input_arg {
+    name: "input_handles"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "output_handles"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListReserve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListReserve.pbtxt
new file mode 100644
index 0000000..deacc5e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListReserve.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "TensorListReserve"
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  input_arg {
+    name: "num_elements"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListResize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListResize.pbtxt
new file mode 100644
index 0000000..55f8030
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListResize.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "TensorListResize"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListScatter.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatter.pbtxt
new file mode 100644
index 0000000..9e9bdcc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatter.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorListScatter"
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterIntoExistingList.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterIntoExistingList.pbtxt
new file mode 100644
index 0000000..a67c344
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterIntoExistingList.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorListScatterIntoExistingList"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterV2.pbtxt
new file mode 100644
index 0000000..2391a21
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListScatterV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "TensorListScatterV2"
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  input_arg {
+    name: "num_elements"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListSetItem.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListSetItem.pbtxt
new file mode 100644
index 0000000..0d1fb78
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListSetItem.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorListSetItem"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "index"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "item"
+    type_attr: "element_dtype"
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListSplit.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListSplit.pbtxt
new file mode 100644
index 0000000..2acc29a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListSplit.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorListSplit"
+  input_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  input_arg {
+    name: "element_shape"
+    type_attr: "shape_type"
+  }
+  input_arg {
+    name: "lengths"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape_type"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorListStack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorListStack.pbtxt
new file mode 100644
index 0000000..5a8e7bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorListStack.pbtxt
@@ -0,0 +1,26 @@
+op {
+  name: "TensorListStack"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "element_shape"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "tensor"
+    type_attr: "element_dtype"
+  }
+  attr {
+    name: "element_dtype"
+    type: "type"
+  }
+  attr {
+    name: "num_elements"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorScatterAdd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterAdd.pbtxt
new file mode 100644
index 0000000..5fb5b8c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterAdd.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorScatterAdd"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorScatterSub.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterSub.pbtxt
new file mode 100644
index 0000000..8192052
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterSub.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorScatterSub"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorScatterUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterUpdate.pbtxt
new file mode 100644
index 0000000..f6de5ea
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorScatterUpdate.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "TensorScatterUpdate"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "updates"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorSliceDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorSliceDataset.pbtxt
new file mode 100644
index 0000000..af024aa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorSliceDataset.pbtxt
@@ -0,0 +1,24 @@
+op {
+  name: "TensorSliceDataset"
+  input_arg {
+    name: "components"
+    type_list_attr: "Toutput_types"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "Toutput_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorStridedSliceUpdate.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorStridedSliceUpdate.pbtxt
new file mode 100644
index 0000000..3854eee
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorStridedSliceUpdate.pbtxt
@@ -0,0 +1,76 @@
+op {
+  name: "TensorStridedSliceUpdate"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "begin"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "end"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "strides"
+    type_attr: "Index"
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Index"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "begin_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "end_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "ellipsis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "new_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "shrink_axis_mask"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorSummary.pbtxt
new file mode 100644
index 0000000..bf4114a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorSummary.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "TensorSummary"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "description"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "labels"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "display_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TensorSummaryV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TensorSummaryV2.pbtxt
new file mode 100644
index 0000000..39092b0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TensorSummaryV2.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorSummaryV2"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "serialized_summary_metadata"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TextLineDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TextLineDataset.pbtxt
new file mode 100644
index 0000000..66cedaf
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TextLineDataset.pbtxt
@@ -0,0 +1,20 @@
+op {
+  name: "TextLineDataset"
+  input_arg {
+    name: "filenames"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "compression_type"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TextLineReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TextLineReader.pbtxt
new file mode 100644
index 0000000..baf1ef1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TextLineReader.pbtxt
@@ -0,0 +1,63 @@
+op {
+  name: "TextLineReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "skip_header_lines"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "TextLineReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "skip_header_lines"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  deprecation {
+    version: 26
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TextLineReaderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TextLineReaderV2.pbtxt
new file mode 100644
index 0000000..c669951
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TextLineReaderV2.pbtxt
@@ -0,0 +1,29 @@
+op {
+  name: "TextLineReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "skip_header_lines"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolDataset.pbtxt
new file mode 100644
index 0000000..eac7485
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolDataset.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "ThreadPoolDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "thread_pool"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolHandle.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolHandle.pbtxt
new file mode 100644
index 0000000..e2518b1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ThreadPoolHandle.pbtxt
@@ -0,0 +1,37 @@
+op {
+  name: "ThreadPoolHandle"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "num_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_intra_op_parallelism"
+    type: "int"
+    default_value {
+      i: 1
+    }
+  }
+  attr {
+    name: "display_name"
+    type: "string"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ThreadUnsafeUnigramCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ThreadUnsafeUnigramCandidateSampler.pbtxt
new file mode 100644
index 0000000..89106aa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ThreadUnsafeUnigramCandidateSampler.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "ThreadUnsafeUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "ThreadUnsafeUnigramCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Tile.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Tile.pbtxt
new file mode 100644
index 0000000..67de1e5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Tile.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "Tile"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "multiples"
+    type_attr: "Tmultiples"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tmultiples"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TileGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TileGrad.pbtxt
new file mode 100644
index 0000000..f710e1c
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TileGrad.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "TileGrad"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "multiples"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  deprecation {
+    version: 3
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Timestamp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Timestamp.pbtxt
new file mode 100644
index 0000000..6e51504
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Timestamp.pbtxt
@@ -0,0 +1,8 @@
+op {
+  name: "Timestamp"
+  output_arg {
+    name: "ts"
+    type: DT_DOUBLE
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TopK.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TopK.pbtxt
new file mode 100644
index 0000000..71c98b7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TopK.pbtxt
@@ -0,0 +1,196 @@
+op {
+  name: "TopK"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "k"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  deprecation {
+    version: 7
+  }
+}
+op {
+  name: "TopK"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "k"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  deprecation {
+    version: 7
+  }
+}
+op {
+  name: "TopK"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "k"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  deprecation {
+    version: 7
+  }
+}
+op {
+  name: "TopK"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "k"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  deprecation {
+    version: 7
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TopKV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TopKV2.pbtxt
new file mode 100644
index 0000000..5089d2b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TopKV2.pbtxt
@@ -0,0 +1,180 @@
+op {
+  name: "TopKV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+}
+op {
+  name: "TopKV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
+op {
+  name: "TopKV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+}
+op {
+  name: "TopKV2"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "k"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "sorted"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Transpose.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Transpose.pbtxt
new file mode 100644
index 0000000..fa4fb6d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Transpose.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "Transpose"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "perm"
+    type_attr: "Tperm"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tperm"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TridiagonalMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TridiagonalMatMul.pbtxt
new file mode 100644
index 0000000..117d68b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TridiagonalMatMul.pbtxt
@@ -0,0 +1,35 @@
+op {
+  name: "TridiagonalMatMul"
+  input_arg {
+    name: "superdiag"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "maindiag"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "subdiag"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TridiagonalSolve.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TridiagonalSolve.pbtxt
new file mode 100644
index 0000000..ee2cf74
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TridiagonalSolve.pbtxt
@@ -0,0 +1,61 @@
+op {
+  name: "TridiagonalSolve"
+  input_arg {
+    name: "diagonals"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TridiagonalSolve"
+  input_arg {
+    name: "diagonals"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "rhs"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "partial_pivoting"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_DOUBLE
+        type: DT_FLOAT
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TruncateDiv.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TruncateDiv.pbtxt
new file mode 100644
index 0000000..82eda3e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TruncateDiv.pbtxt
@@ -0,0 +1,104 @@
+op {
+  name: "TruncateDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TruncateDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
+op {
+  name: "TruncateDiv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TruncateMod.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TruncateMod.pbtxt
new file mode 100644
index 0000000..70ce81b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TruncateMod.pbtxt
@@ -0,0 +1,84 @@
+op {
+  name: "TruncateMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "TruncateMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
+op {
+  name: "TruncateMod"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TruncatedNormal.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TruncatedNormal.pbtxt
new file mode 100644
index 0000000..018d657
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TruncatedNormal.pbtxt
@@ -0,0 +1,95 @@
+op {
+  name: "TruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "TruncatedNormal"
+  input_arg {
+    name: "shape"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt
new file mode 100644
index 0000000..e585195
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "TryRpc"
+  input_arg {
+    name: "address"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "method"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "request"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "response"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "status_code"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "status_message"
+    type: DT_STRING
+  }
+  attr {
+    name: "protocol"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "fail_fast"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "timeout_in_ms"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Unbatch.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Unbatch.pbtxt
new file mode 100644
index 0000000..3934b18
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Unbatch.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "Unbatch"
+  input_arg {
+    name: "batched_tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "batch_index"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "id"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "unbatched_tensor"
+    type_attr: "T"
+  }
+  attr {
+    name: "timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnbatchDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnbatchDataset.pbtxt
new file mode 100644
index 0000000..cd61d31
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnbatchDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "UnbatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnbatchGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnbatchGrad.pbtxt
new file mode 100644
index 0000000..97240f0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnbatchGrad.pbtxt
@@ -0,0 +1,41 @@
+op {
+  name: "UnbatchGrad"
+  input_arg {
+    name: "original_input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "batch_index"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "id"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "batched_grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecode.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecode.pbtxt
new file mode 100644
index 0000000..fa036b3
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecode.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "UnicodeDecode"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "row_splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "char_values"
+    type: DT_INT32
+  }
+  attr {
+    name: "input_encoding"
+    type: "string"
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "strict"
+        s: "replace"
+        s: "ignore"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "replace_control_characters"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "UnicodeDecode"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "row_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "char_values"
+    type: DT_INT32
+  }
+  attr {
+    name: "input_encoding"
+    type: "string"
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "strict"
+        s: "replace"
+        s: "ignore"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "replace_control_characters"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecodeWithOffsets.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecodeWithOffsets.pbtxt
new file mode 100644
index 0000000..29d2747
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnicodeDecodeWithOffsets.pbtxt
@@ -0,0 +1,115 @@
+op {
+  name: "UnicodeDecodeWithOffsets"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "row_splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "char_values"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "char_to_byte_starts"
+    type: DT_INT64
+  }
+  attr {
+    name: "input_encoding"
+    type: "string"
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "strict"
+        s: "replace"
+        s: "ignore"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "replace_control_characters"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
+  name: "UnicodeDecodeWithOffsets"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "row_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "char_values"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "char_to_byte_starts"
+    type: DT_INT64
+  }
+  attr {
+    name: "input_encoding"
+    type: "string"
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "strict"
+        s: "replace"
+        s: "ignore"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "replace_control_characters"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnicodeEncode.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnicodeEncode.pbtxt
new file mode 100644
index 0000000..31a7a5b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnicodeEncode.pbtxt
@@ -0,0 +1,107 @@
+op {
+  name: "UnicodeEncode"
+  input_arg {
+    name: "input_values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "ignore"
+        s: "replace"
+        s: "strict"
+      }
+    }
+  }
+  attr {
+    name: "output_encoding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "UTF-8"
+        s: "UTF-16-BE"
+        s: "UTF-32-BE"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+}
+op {
+  name: "UnicodeEncode"
+  input_arg {
+    name: "input_values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "input_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "ignore"
+        s: "replace"
+        s: "strict"
+      }
+    }
+  }
+  attr {
+    name: "output_encoding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "UTF-8"
+        s: "UTF-16-BE"
+        s: "UTF-32-BE"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnicodeScript.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnicodeScript.pbtxt
new file mode 100644
index 0000000..60877b5
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnicodeScript.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "UnicodeScript"
+  input_arg {
+    name: "input"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT32
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnicodeTranscode.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnicodeTranscode.pbtxt
new file mode 100644
index 0000000..5cab737
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnicodeTranscode.pbtxt
@@ -0,0 +1,54 @@
+op {
+  name: "UnicodeTranscode"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "input_encoding"
+    type: "string"
+  }
+  attr {
+    name: "output_encoding"
+    type: "string"
+    allowed_values {
+      list {
+        s: "UTF-8"
+        s: "UTF-16-BE"
+        s: "UTF-32-BE"
+      }
+    }
+  }
+  attr {
+    name: "errors"
+    type: "string"
+    default_value {
+      s: "replace"
+    }
+    allowed_values {
+      list {
+        s: "strict"
+        s: "replace"
+        s: "ignore"
+      }
+    }
+  }
+  attr {
+    name: "replacement_char"
+    type: "int"
+    default_value {
+      i: 65533
+    }
+  }
+  attr {
+    name: "replace_control_characters"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UniformCandidateSampler.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UniformCandidateSampler.pbtxt
new file mode 100644
index 0000000..bea963f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UniformCandidateSampler.pbtxt
@@ -0,0 +1,111 @@
+op {
+  name: "UniformCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
+op {
+  name: "UniformCandidateSampler"
+  input_arg {
+    name: "true_classes"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "sampled_candidates"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "true_expected_count"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sampled_expected_count"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "num_true"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sampled"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "unique"
+    type: "bool"
+  }
+  attr {
+    name: "range_max"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "seed"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "seed2"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Unique.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Unique.pbtxt
new file mode 100644
index 0000000..be389ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Unique.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "Unique"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UniqueDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UniqueDataset.pbtxt
new file mode 100644
index 0000000..ef44284
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UniqueDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "UniqueDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UniqueV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UniqueV2.pbtxt
new file mode 100644
index 0000000..83113e1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UniqueV2.pbtxt
@@ -0,0 +1,85 @@
+op {
+  name: "UniqueV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UniqueV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCounts.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCounts.pbtxt
new file mode 100644
index 0000000..c386059
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCounts.pbtxt
@@ -0,0 +1,36 @@
+op {
+  name: "UniqueWithCounts"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  output_arg {
+    name: "count"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCountsV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCountsV2.pbtxt
new file mode 100644
index 0000000..85a12b7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UniqueWithCountsV2.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "UniqueWithCountsV2"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "axis"
+    type_attr: "Taxis"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "idx"
+    type_attr: "out_idx"
+  }
+  output_arg {
+    name: "count"
+    type_attr: "out_idx"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Taxis"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "out_idx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Unpack.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Unpack.pbtxt
new file mode 100644
index 0000000..cc5fd91
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Unpack.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Unpack"
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+    number_attr: "num"
+  }
+  attr {
+    name: "num"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "axis"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnravelIndex.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnravelIndex.pbtxt
new file mode 100644
index 0000000..df2c2bc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnravelIndex.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "UnravelIndex"
+  input_arg {
+    name: "indices"
+    type_attr: "Tidx"
+  }
+  input_arg {
+    name: "dims"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "Tidx"
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentJoin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentJoin.pbtxt
new file mode 100644
index 0000000..dcbb91b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentJoin.pbtxt
@@ -0,0 +1,49 @@
+op {
+  name: "UnsortedSegmentJoin"
+  input_arg {
+    name: "inputs"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "separator"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMax.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMax.pbtxt
new file mode 100644
index 0000000..ee8578f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMax.pbtxt
@@ -0,0 +1,218 @@
+op {
+  name: "UnsortedSegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentMax"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMin.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMin.pbtxt
new file mode 100644
index 0000000..6a8e5ba
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentMin.pbtxt
@@ -0,0 +1,62 @@
+op {
+  name: "UnsortedSegmentMin"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentProd.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentProd.pbtxt
new file mode 100644
index 0000000..e255518
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentProd.pbtxt
@@ -0,0 +1,129 @@
+op {
+  name: "UnsortedSegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentProd"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentSum.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentSum.pbtxt
new file mode 100644
index 0000000..1925328
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnsortedSegmentSum.pbtxt
@@ -0,0 +1,238 @@
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type_attr: "Tnumsegments"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "Tnumsegments"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Unstage.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Unstage.pbtxt
new file mode 100644
index 0000000..4bcfd02
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Unstage.pbtxt
@@ -0,0 +1,72 @@
+op {
+  name: "Unstage"
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "Unstage"
+  output_arg {
+    name: "values"
+    type_list_attr: "dtypes"
+  }
+  attr {
+    name: "capacity"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "memory_limit"
+    type: "int"
+    default_value {
+      i: 0
+    }
+    has_minimum: true
+  }
+  attr {
+    name: "dtypes"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UnwrapDatasetVariant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UnwrapDatasetVariant.pbtxt
new file mode 100644
index 0000000..10e23a9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UnwrapDatasetVariant.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "UnwrapDatasetVariant"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/UpperBound.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/UpperBound.pbtxt
new file mode 100644
index 0000000..d1b3fa0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/UpperBound.pbtxt
@@ -0,0 +1,32 @@
+op {
+  name: "UpperBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/VarHandleOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/VarHandleOp.pbtxt
new file mode 100644
index 0000000..b5722b9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/VarHandleOp.pbtxt
@@ -0,0 +1,30 @@
+op {
+  name: "VarHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/VarIsInitializedOp.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/VarIsInitializedOp.pbtxt
new file mode 100644
index 0000000..3953601
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/VarIsInitializedOp.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "VarIsInitializedOp"
+  input_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Variable.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Variable.pbtxt
new file mode 100644
index 0000000..943c24d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Variable.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "Variable"
+  output_arg {
+    name: "ref"
+    type_attr: "dtype"
+    is_ref: true
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/VariableShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/VariableShape.pbtxt
new file mode 100644
index 0000000..570b4f2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/VariableShape.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "VariableShape"
+  input_arg {
+    name: "input"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/VariableV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/VariableV2.pbtxt
new file mode 100644
index 0000000..c27112f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/VariableV2.pbtxt
@@ -0,0 +1,31 @@
+op {
+  name: "VariableV2"
+  output_arg {
+    name: "ref"
+    type_attr: "dtype"
+    is_ref: true
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Where.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Where.pbtxt
new file mode 100644
index 0000000..c85edfd
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Where.pbtxt
@@ -0,0 +1,130 @@
+op {
+  name: "Where"
+  input_arg {
+    name: "input"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "index"
+    type: DT_INT64
+  }
+}
+op {
+  name: "Where"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_BOOL
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BOOL
+      }
+    }
+  }
+}
+op {
+  name: "Where"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_BOOL
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BFLOAT16
+        type: DT_BOOL
+      }
+    }
+  }
+}
+op {
+  name: "Where"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "index"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_BOOL
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_BOOL
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/While.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/While.pbtxt
new file mode 100644
index 0000000..807461b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/While.pbtxt
@@ -0,0 +1,98 @@
+op {
+  name: "While"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  is_stateful: true
+}
+op {
+  name: "While"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "While"
+  input_arg {
+    name: "input"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "cond"
+    type: "func"
+  }
+  attr {
+    name: "body"
+    type: "func"
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WholeFileReader.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WholeFileReader.pbtxt
new file mode 100644
index 0000000..729d765
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WholeFileReader.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "WholeFileReader"
+  output_arg {
+    name: "reader_handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WholeFileReaderV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WholeFileReaderV2.pbtxt
new file mode 100644
index 0000000..2430494
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WholeFileReaderV2.pbtxt
@@ -0,0 +1,22 @@
+op {
+  name: "WholeFileReaderV2"
+  output_arg {
+    name: "reader_handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WindowDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WindowDataset.pbtxt
new file mode 100644
index 0000000..9b68c41
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WindowDataset.pbtxt
@@ -0,0 +1,39 @@
+op {
+  name: "WindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "shift"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "stride"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WorkerHeartbeat.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WorkerHeartbeat.pbtxt
new file mode 100644
index 0000000..ae5c7b8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WorkerHeartbeat.pbtxt
@@ -0,0 +1,12 @@
+op {
+  name: "WorkerHeartbeat"
+  input_arg {
+    name: "request"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "response"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WrapDatasetVariant.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WrapDatasetVariant.pbtxt
new file mode 100644
index 0000000..0b1e436
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WrapDatasetVariant.pbtxt
@@ -0,0 +1,11 @@
+op {
+  name: "WrapDatasetVariant"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteAudioSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteAudioSummary.pbtxt
new file mode 100644
index 0000000..8cc81eb
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteAudioSummary.pbtxt
@@ -0,0 +1,33 @@
+op {
+  name: "WriteAudioSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "max_outputs"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteFile.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteFile.pbtxt
new file mode 100644
index 0000000..6a15b39
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteFile.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "WriteFile"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+}
+op {
+  name: "WriteFile"
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "contents"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteGraphSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteGraphSummary.pbtxt
new file mode 100644
index 0000000..2957e22
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteGraphSummary.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "WriteGraphSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tensor"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteHistogramSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteHistogramSummary.pbtxt
new file mode 100644
index 0000000..544dab6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteHistogramSummary.pbtxt
@@ -0,0 +1,43 @@
+op {
+  name: "WriteHistogramSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteImageSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteImageSummary.pbtxt
new file mode 100644
index 0000000..d4248e6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteImageSummary.pbtxt
@@ -0,0 +1,47 @@
+op {
+  name: "WriteImageSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "bad_color"
+    type: DT_UINT8
+  }
+  attr {
+    name: "max_images"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteRawProtoSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteRawProtoSummary.pbtxt
new file mode 100644
index 0000000..82ac51a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteRawProtoSummary.pbtxt
@@ -0,0 +1,16 @@
+op {
+  name: "WriteRawProtoSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tensor"
+    type: DT_STRING
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteScalarSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteScalarSummary.pbtxt
new file mode 100644
index 0000000..0f359a8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteScalarSummary.pbtxt
@@ -0,0 +1,40 @@
+op {
+  name: "WriteScalarSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "value"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/WriteSummary.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/WriteSummary.pbtxt
new file mode 100644
index 0000000..a641ece
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/WriteSummary.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "WriteSummary"
+  input_arg {
+    name: "writer"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "step"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "summary_metadata"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Xdivy.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Xdivy.pbtxt
new file mode 100644
index 0000000..472d872
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Xdivy.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Xdivy"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Xlogy.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Xlogy.pbtxt
new file mode 100644
index 0000000..cf727d2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Xlogy.pbtxt
@@ -0,0 +1,28 @@
+op {
+  name: "Xlogy"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ZerosLike.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ZerosLike.pbtxt
new file mode 100644
index 0000000..5bb8d0a
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ZerosLike.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "ZerosLike"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Zeta.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Zeta.pbtxt
new file mode 100644
index 0000000..c391bd1
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/Zeta.pbtxt
@@ -0,0 +1,25 @@
+op {
+  name: "Zeta"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "q"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v1/ZipDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/ZipDataset.pbtxt
new file mode 100644
index 0000000..16e7ed8
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/ZipDataset.pbtxt
@@ -0,0 +1,61 @@
+op {
+  name: "ZipDataset"
+  input_arg {
+    name: "input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "ZipDataset"
+  input_arg {
+    name: "input_datasets"
+    type: DT_VARIANT
+    number_attr: "N"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+}
diff --git a/tensorflow/core/ops/compat/update_ops_main.cc b/tensorflow/core/ops/compat/update_ops_main.cc
index 79c830a..3618617 100644
--- a/tensorflow/core/ops/compat/update_ops_main.cc
+++ b/tensorflow/core/ops/compat/update_ops_main.cc
@@ -16,7 +16,9 @@
 #include <stdio.h>
 
 #include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/ops/compat/op_compatibility_lib.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -41,19 +43,26 @@
   // Make sure the current version of ops are compatible with the
   // historical versions, and generate a new history adding all
   // changed ops.
-  OpList out_op_history;
+  OpCompatibilityLib::OpHistory out_op_history;
   int changed_ops = 0;
   int added_ops = 0;
   TF_QCHECK_OK(compatibility.ValidateCompatible(env, &changed_ops, &added_ops,
                                                 &out_op_history));
   printf("%d changed ops\n%d added ops\n", changed_ops, added_ops);
 
-  if (changed_ops + added_ops > 0) {
+  const string& history_dir = compatibility.op_history_directory();
+  Status status = env->CreateDir(history_dir);
+  if (!errors::IsAlreadyExists(status)) {
+    TF_QCHECK_OK(status);
+  }
+  if (changed_ops + added_ops > 0 || !errors::IsAlreadyExists(status)) {
     // Write out new op history.
-    const string& history_file = compatibility.op_history_file();
-    printf("Writing updated op history to %s...\n", history_file.c_str());
-    TF_QCHECK_OK(
-        WriteStringToFile(env, history_file, out_op_history.DebugString()));
+    printf("Writing updated op history to %s/...\n", history_dir.c_str());
+    for (const auto& op_file : out_op_history) {
+      TF_QCHECK_OK(WriteStringToFile(env,
+                                     io::JoinPath(history_dir, op_file.first),
+                                     op_file.second.DebugString()));
+    }
   }
 }
 
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index b4cd4f6..0b689b6 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -168,6 +168,7 @@
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
     .Attr("slack_period: int = 0")
+    .Attr("legacy_autotune: bool = true")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       shape_inference::ShapeHandle unused;
       // buffer_size should be a scalar.
@@ -354,6 +355,22 @@
       return shape_inference::ScalarShape(c);
     });
 
+REGISTER_OP("AnonymousRandomSeedGenerator")
+    .Input("seed: int64")
+    .Input("seed2: int64")
+    .Output("handle: resource")
+    .Output("deleter: variant")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("DeleteRandomSeedGenerator")
+    .Input("handle: resource")
+    .Input("deleter: variant")
+    .SetShapeFn(shape_inference::NoOutputs);
+
 REGISTER_OP("ShuffleDataset")
     .Input("input_dataset: variant")
     .Input("buffer_size: int64")
@@ -372,6 +389,21 @@
       return shape_inference::ScalarShape(c);
     });
 
+REGISTER_OP("ShuffleDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("buffer_size: int64")
+    .Input("seed_generator: resource")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // buffer_size, seed, and seed2 should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
+
 REGISTER_OP("ShuffleAndRepeatDataset")
     .Input("input_dataset: variant")
     .Input("buffer_size: int64")
@@ -391,6 +423,20 @@
       return shape_inference::ScalarShape(c);
     });
 
+REGISTER_OP("AnonymousMemoryCache")
+    .Output("handle: resource")
+    .Output("deleter: variant")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("DeleteMemoryCache")
+    .Input("handle: resource")
+    .Input("deleter: variant")
+    .SetShapeFn(shape_inference::NoOutputs);
+
 REGISTER_OP("CacheDataset")
     .Input("input_dataset: variant")
     .Input("filename: string")
@@ -404,6 +450,22 @@
       return shape_inference::ScalarShape(c);
     });
 
+REGISTER_OP("CacheDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("filename: string")
+    .Input("cache: resource")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // filename should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      // cache should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
+
 REGISTER_OP("TextLineDataset")
     .Input("filenames: string")
     .Input("compression_type: string")
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index 5504f5e5..68823c8 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -658,7 +658,7 @@
 
 REGISTER_OP("ExperimentalRebatchDataset")
     .Input("input_dataset: variant")
-    .Input("num_workers: int64")
+    .Input("num_replicas: int64")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
@@ -667,7 +667,7 @@
 
 REGISTER_OP("RebatchDataset")
     .Input("input_dataset: variant")
-    .Input("num_workers: int64")
+    .Input("num_replicas: int64")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 8f1ac77..f5f7244 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -195,6 +195,31 @@
       by T.
 )doc");
 
+Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) {
+  std::vector<PartialTensorShape> output_shapes;
+  TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+  // If `output_shapes` attr is set use that as the shapes of the outputs
+  // else use the input shapes.
+  if (!output_shapes.empty()) {
+    if (output_shapes.size() != c->num_outputs()) {
+      return errors::InvalidArgument(
+          "`output_shapes` must be the same length as num outputs (",
+          output_shapes.size(), " vs. ", c->num_outputs());
+    }
+    for (size_t i = 0; i < output_shapes.size(); ++i) {
+      shape_inference::ShapeHandle output_shape_handle;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+          output_shapes[i], &output_shape_handle));
+      c->set_output(static_cast<int>(i), output_shape_handle);
+    }
+  } else {
+    for (int i = 0; i < c->num_outputs(); ++i) {
+      c->set_output(i, c->input(i));
+    }
+  }
+  return Status::OK();
+}
+
 REGISTER_OP("While")
     .Input("input: T")
     .Output("output: T")
@@ -204,30 +229,7 @@
     .Attr("output_shapes: list(shape) = []")
     .Attr("parallel_iterations: int = 10")
     .SetIsStateful()
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      std::vector<PartialTensorShape> output_shapes;
-      TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
-      // If `output_shapes` attr is set use that as the shapes of the outputs
-      // else use the input shapes.
-      if (!output_shapes.empty()) {
-        if (output_shapes.size() != c->num_outputs()) {
-          return errors::InvalidArgument(
-              "`output_shapes` must be the same length as num outputs (",
-              output_shapes.size(), " vs. ", c->num_outputs());
-        }
-        for (size_t i = 0; i < output_shapes.size(); ++i) {
-          shape_inference::ShapeHandle output_shape_handle;
-          TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
-              output_shapes[i], &output_shape_handle));
-          c->set_output(static_cast<int>(i), output_shape_handle);
-        }
-      } else {
-        for (int i = 0; i < c->num_outputs(); ++i) {
-          c->set_output(i, c->input(i));
-        }
-      }
-      return Status::OK();
-    });
+    .SetShapeFn(WhileShapeInferenceFn);
 
 REGISTER_OP("StatelessWhile")
     .Input("input: T")
@@ -235,12 +237,9 @@
     .Attr("T: list(type) >= 0")
     .Attr("cond: func")
     .Attr("body: func")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      for (int i = 0; i < c->num_outputs(); ++i) {
-        c->set_output(i, c->input(i));
-      }
-      return Status::OK();
-    });
+    .Attr("output_shapes: list(shape) = []")
+    .Attr("parallel_iterations: int = 10")
+    .SetShapeFn(WhileShapeInferenceFn);
 
 REGISTER_OP("For")
     .Input("start: int32")
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 1f2edee..e2078e0 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -101,7 +101,7 @@
       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
       if (shape_and_slices_tensor) {
         const auto& shape_and_slices_flat =
-            shape_and_slices_tensor->flat<string>();
+            shape_and_slices_tensor->flat<tstring>();
         if (shape_and_slices_flat.size() != c->num_outputs()) {
           return errors::InvalidArgument(
               "The number of shape_and_slice doesn't match tensor outputs.");
@@ -222,7 +222,7 @@
       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
       if (shape_and_slices_tensor) {
         const auto& shape_and_slice =
-            shape_and_slices_tensor->flat<string>()(0);
+            shape_and_slices_tensor->flat<tstring>()(0);
         if (shape_and_slice.empty()) {
           c->set_output(0, c->UnknownShape());
         } else {
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index 7a0ccb1..c5d21ef 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -590,18 +590,17 @@
 
       auto* handle_data_a = c->input_handle_shapes_and_types(0);
       auto* handle_data_b = c->input_handle_shapes_and_types(1);
-      if ((handle_data_a == nullptr || handle_data_a->empty()) &&
-          (handle_data_b == nullptr || handle_data_b->empty())) {
+      bool handle_data_a_nonempty = handle_data_a && !handle_data_a->empty();
+      bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
+      if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
         c->set_output_handle_shapes_and_types(
             0, {{c->UnknownShape(), element_dtype}});
         return Status::OK();
       }
       shape_inference::ShapeAndType list_shape_type_a =
-          (handle_data_a && !handle_data_a->empty()) ? handle_data_a->at(0)
-                                                     : handle_data_b->at(0);
+          handle_data_a_nonempty ? handle_data_a->at(0) : handle_data_b->at(0);
       const shape_inference::ShapeAndType& list_shape_type_b =
-          (handle_data_b && !handle_data_b->empty()) ? handle_data_b->at(0)
-                                                     : handle_data_a->at(0);
+          handle_data_b_nonempty ? handle_data_b->at(0) : handle_data_a->at(0);
       if (list_shape_type_a.dtype != element_dtype) {
         return errors::InvalidArgument("input_a.type != element_dtype: ",
                                        DataTypeString(list_shape_type_a.dtype),
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index a55dde6..4d10054 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1652,6 +1652,25 @@
 expected to invoke these operators.
 )doc");
 
+REGISTER_OP("_MklEagerConv2D")
+    .Input("input: T")
+    .Input("filter: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, float}")
+    .Attr("strides: list(int)")
+    .Attr("use_cudnn_on_gpu: bool = true")
+    .Attr(GetPaddingAttrStringWithExplicit())
+    .Attr(GetExplicitPaddingsAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
+    .Doc(R"doc(
+    MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
+
+    NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
+    expected to invoke these operators.
+    )doc");
+
 REGISTER_OP("__MklDummyConv2DWithBias")
     .Input("input: T")
     .Input("filter: T")
@@ -1782,6 +1801,33 @@
 expected to invoke these operators.
 )doc");
 
+REGISTER_OP("_MklEagerConv2DBackpropFilter")
+    .Input("input: T")
+    .Input("filter_sizes: int32")
+    .Input("out_backprop: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, float}")
+    .Attr("strides: list(int)")
+    .Attr("use_cudnn_on_gpu: bool = true")
+    .Attr(GetPaddingAttrStringWithExplicit())
+    .Attr(GetExplicitPaddingsAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+      c->set_output(0, s);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+MKL version of Conv2DBackpropFilter for Eager mode. Uses MKL DNN APIs
+to compute the gradients of convolution with respect to the filter.
+
+NOTE Do not invoke this operator directly in Python. Eager Op rewrite pass is
+expected to invoke these operators.
+)doc");
+
 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
     .Input("input: T")
     .Input("filter_sizes: int32")
@@ -1915,6 +1961,33 @@
 expected to invoke these operators.
 )doc");
 
+REGISTER_OP("_MklEagerConv2DBackpropInput")
+    .Input("input_sizes: int32")
+    .Input("filter: T")
+    .Input("out_backprop: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, float}")
+    .Attr("strides: list(int)")
+    .Attr("use_cudnn_on_gpu: bool = true")
+    .Attr(GetPaddingAttrStringWithExplicit())
+    .Attr(GetExplicitPaddingsAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+      c->set_output(0, s);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+MKL version of Convolution2D backward input for Eager mode. Uses MKL DNN APIs
+to compute the gradients of convolution with respect to the input.
+
+NOTE Do not invoke this operator directly in Python. Eager op rewrite is
+expected to invoke these operators.
+)doc");
+
 REGISTER_OP("_MklConv3D")
     .Input("input: T")
     .Input("filter: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b119eee..45f5604 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -786,6 +786,18 @@
   is_stateful: true
 }
 op {
+  name: "AnonymousMemoryCache"
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
+op {
   name: "AnonymousMultiDeviceIterator"
   output_arg {
     name: "handle"
@@ -816,6 +828,26 @@
   is_stateful: true
 }
 op {
+  name: "AnonymousRandomSeedGenerator"
+  input_arg {
+    name: "seed"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed2"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
+op {
   name: "Any"
   input_arg {
     name: "input"
@@ -1141,6 +1173,75 @@
   }
 }
 op {
+  name: "ApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
   name: "ApplyAdam"
   input_arg {
     name: "var"
@@ -4874,6 +4975,7 @@
     allowed_values {
       list {
         s: "inequality"
+        s: "equality"
       }
     }
   }
@@ -5928,6 +6030,38 @@
   }
 }
 op {
+  name: "CacheDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "filename"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "cache"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
   name: "Case"
   input_arg {
     name: "branch_index"
@@ -6042,6 +6176,7 @@
     name: "message"
     type: "string"
   }
+  is_stateful: true
 }
 op {
   name: "Cholesky"
@@ -6884,6 +7019,14 @@
   is_stateful: true
 }
 op {
+  name: "ConfigureTPUEmbedding"
+  attr {
+    name: "config"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
   name: "Conj"
   input_arg {
     name: "input"
@@ -10385,6 +10528,18 @@
   is_stateful: true
 }
 op {
+  name: "DeleteMemoryCache"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
+op {
   name: "DeleteMultiDeviceIterator"
   input_arg {
     name: "multi_device_iterator"
@@ -10407,6 +10562,18 @@
   is_stateful: true
 }
 op {
+  name: "DeleteRandomSeedGenerator"
+  input_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "deleter"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
+op {
   name: "DeleteSessionTensor"
   input_arg {
     name: "handle"
@@ -13144,7 +13311,7 @@
     type: DT_VARIANT
   }
   input_arg {
-    name: "num_workers"
+    name: "num_replicas"
     type: DT_INT64
   }
   output_arg {
@@ -25402,6 +25569,13 @@
       i: 0
     }
   }
+  attr {
+    name: "legacy_autotune"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
 }
 op {
   name: "Prelinearize"
@@ -30070,6 +30244,64 @@
   }
 }
 op {
+  name: "RaggedTensorToTensor"
+  input_arg {
+    name: "shape"
+    type_attr: "Tshape"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "default_value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "row_partition_tensors"
+    type_attr: "Tindex"
+    number_attr: "num_row_partition_tensors"
+  }
+  output_arg {
+    name: "result"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "Tindex"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    name: "Tshape"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT64
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    name: "num_row_partition_tensors"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "row_partition_types"
+    type: "list(string)"
+  }
+}
+op {
   name: "RaggedTensorToVariant"
   input_arg {
     name: "rt_nested_splits"
@@ -31090,7 +31322,7 @@
     type: DT_VARIANT
   }
   input_arg {
-    name: "num_workers"
+    name: "num_replicas"
     type: DT_INT64
   }
   output_arg {
@@ -32667,6 +32899,69 @@
   is_stateful: true
 }
 op {
+  name: "ResourceApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "ResourceApplyAdam"
   input_arg {
     name: "var"
@@ -34372,6 +34667,83 @@
   is_stateful: true
 }
 op {
+  name: "ResourceSparseApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "accum"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "ResourceSparseApplyCenteredRMSProp"
   input_arg {
     name: "var"
@@ -38437,6 +38809,38 @@
   }
 }
 op {
+  name: "ShuffleDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "buffer_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "seed_generator"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
   name: "ShutdownDistributedTPU"
   is_stateful: true
 }
@@ -39663,6 +40067,89 @@
   }
 }
 op {
+  name: "SparseApplyAdagradV2"
+  input_arg {
+    name: "var"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "accum"
+    type_attr: "T"
+    is_ref: true
+  }
+  input_arg {
+    name: "lr"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "epsilon"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "grad"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "indices"
+    type_attr: "Tindices"
+  }
+  output_arg {
+    name: "out"
+    type_attr: "T"
+    is_ref: true
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "use_locking"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "update_slots"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
   name: "SparseApplyCenteredRMSProp"
   input_arg {
     name: "var"
@@ -43238,6 +43725,21 @@
     name: "body"
     type: "func"
   }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "parallel_iterations"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
 }
 op {
   name: "StaticRegexFullMatch"
@@ -43694,6 +44196,63 @@
   }
 }
 op {
+  name: "StringNGrams"
+  input_arg {
+    name: "data"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_splits"
+    type_attr: "Tsplits"
+  }
+  output_arg {
+    name: "ngrams"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "ngrams_splits"
+    type_attr: "Tsplits"
+  }
+  attr {
+    name: "separator"
+    type: "string"
+  }
+  attr {
+    name: "ngram_widths"
+    type: "list(int)"
+    has_minimum: true
+  }
+  attr {
+    name: "left_pad"
+    type: "string"
+  }
+  attr {
+    name: "right_pad"
+    type: "string"
+  }
+  attr {
+    name: "pad_width"
+    type: "int"
+  }
+  attr {
+    name: "preserve_short_sequences"
+    type: "bool"
+  }
+  attr {
+    name: "Tsplits"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "StringSplit"
   input_arg {
     name: "input"
diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc
index 5794b89..78fa5db 100644
--- a/tensorflow/core/ops/ragged_conversion_ops.cc
+++ b/tensorflow/core/ops/ragged_conversion_ops.cc
@@ -15,16 +15,84 @@
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/ops/ragged_to_dense_util.h"
 
 namespace tensorflow {
 
+using errors::InvalidArgument;
 using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
+namespace {
+tensorflow::Status ValidateRowPartitionTypesAndShapes(
+    const std::vector<RowPartitionType>& row_partition_types,
+    InferenceContext* c) {
+  // Note: the allowed types may be extended in the future.
+  for (RowPartitionType row_partition_type : row_partition_types) {
+    switch (row_partition_type) {
+      case RowPartitionType::FIRST_DIM_SIZE:
+      case RowPartitionType::VALUE_ROWIDS:
+      case RowPartitionType::ROW_SPLITS:
+        break;
+      default:
+        return InvalidArgument("Unsupported partition type: ",
+                               RowPartitionTypeToString(row_partition_type));
+    }
+  }
+
+  if (row_partition_types.empty()) {
+    return InvalidArgument("Partition info types should not be empty");
+  }
+  for (int i = 1; i < row_partition_types.size(); ++i) {
+    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
+      return InvalidArgument("FIRST_DIM_SIZE must be first");
+    }
+  }
+  if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
+      (row_partition_types.size() < 2 ||
+       row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
+    return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
+  }
+  if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
+    return InvalidArgument("VALUE_ROWIDS cannot be first");
+  }
+
+  int num_row_partition_tensors;
+  TF_RETURN_IF_ERROR(
+      c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
+  if (num_row_partition_tensors != row_partition_types.size()) {
+    return InvalidArgument(
+        "Number of row partition tensors (", num_row_partition_tensors,
+        ") does not equal the number of row partition types(",
+        row_partition_types.size(), ").");
+  }
+
+  for (int i = 0; i < num_row_partition_tensors; ++i) {
+    TensorShapeProto partition_shape;
+    c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
+    if (partition_shape.unknown_rank()) {
+      continue;
+    }
+    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
+      if (partition_shape.dim_size() != 0) {
+        return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
+      }
+    } else {
+      if (partition_shape.dim_size() != 1) {
+        return InvalidArgument("Row partition must be a vector.");
+      }
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+}  // namespace
+
 Status RaggedTensorToSparseShapeFn(InferenceContext* c);
 Status RaggedTensorToVariantShapeFn(InferenceContext* c);
 Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
+tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
 
 //==============================================================================
 // Registered Ops
@@ -61,6 +129,19 @@
     .Attr("Tsplits: {int32, int64}")
     .SetShapeFn(RaggedTensorFromVariantShapeFn);
 
+REGISTER_OP("RaggedTensorToTensor")
+    .Attr("T: type")
+    .Attr("Tindex: {int64, int32}")
+    .Attr("Tshape: {int64, int32}")
+    .Attr("num_row_partition_tensors: int")
+    .Attr("row_partition_types: list(string)")
+    .Input("shape: Tshape")
+    .Input("values: T")
+    .Input("default_value: T")
+    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
+    .Output("result: T")
+    .SetShapeFn(RaggedTensorToTensorShapeFn);
+
 //==============================================================================
 // Shape Functions
 //==============================================================================
@@ -136,4 +217,47 @@
   return Status::OK();
 }
 
+tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
+  TensorShapeProto shape;
+  {
+    ShapeHandle shape_handle;
+    TF_RETURN_IF_ERROR(
+        c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
+    c->ShapeHandleToProto(shape_handle, &shape);
+  }
+
+  std::vector<RowPartitionType> row_partition_types;
+  TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
+  int ragged_rank = GetRaggedRank(row_partition_types);
+  TF_RETURN_IF_ERROR(
+      ValidateRowPartitionTypesAndShapes(row_partition_types, c));
+
+  TensorShapeProto value_shape;
+  c->ShapeHandleToProto(c->input(1), &value_shape);
+
+  TensorShapeProto default_value_shape;
+  c->ShapeHandleToProto(c->input(2), &default_value_shape);
+
+  TF_RETURN_IF_ERROR(
+      ValidateDefaultValueShape(default_value_shape, value_shape));
+
+  // TODO(martinz): Theoretically, we could check the first dimension of
+  // value_shape against the first dimension of the last row_partition_tensor
+  // assuming it is a VALUE_ROWIDS type.
+  // TODO(martinz): Although we normally don't know the first dimension of the
+  // output, we could infer it from the first dimension of the first
+  // row_partition_tensor if it is ROW_SPLITS type.
+  // TODO(martinz): If the shape is provided, but the value_shape has missing
+  // dimensions, we can check the default_value_shape against the shape.
+  TensorShapeProto output_shape;
+  TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
+      ragged_rank, shape, value_shape, &output_shape));
+
+  ShapeHandle output_shape_handle;
+  TF_RETURN_IF_ERROR(
+      c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
+  c->set_output(0, output_shape_handle);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/ragged_to_dense_util.cc b/tensorflow/core/ops/ragged_to_dense_util.cc
new file mode 100644
index 0000000..246f724
--- /dev/null
+++ b/tensorflow/core/ops/ragged_to_dense_util.cc
@@ -0,0 +1,162 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/ops/ragged_to_dense_util.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace tensorflow {
+
+using errors::InvalidArgument;
+
+string RowPartitionTypeToString(RowPartitionType row_partition_type) {
+  switch (row_partition_type) {
+    case RowPartitionType::FIRST_DIM_SIZE:
+      return "FIRST_DIM_SIZE";
+    case RowPartitionType::VALUE_ROWIDS:
+      return "VALUE_ROWIDS";
+    case RowPartitionType::ROW_LENGTHS:
+      return "ROW_LENGTHS";
+    case RowPartitionType::ROW_SPLITS:
+      return "ROW_SPLITS";
+    case RowPartitionType::ROW_LIMITS:
+      return "ROW_LIMITS";
+    case RowPartitionType::ROW_STARTS:
+      return "ROW_STARTS";
+    default:
+      return "UNKNOWN ROW PARTITION TYPE";
+  }
+}
+tensorflow::Status GetRowPartitionTypesHelper(
+    const std::vector<string>& row_partition_type_strings,
+    std::vector<RowPartitionType>* row_partition_types) {
+  static const auto kStringToType =
+      new std::unordered_map<string, RowPartitionType>(
+          {{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE},
+           {"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS},
+           {"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS},
+           {"ROW_SPLITS", RowPartitionType::ROW_SPLITS},
+           {"ROW_LIMITS", RowPartitionType::ROW_LIMITS},
+           {"ROW_STARTS", RowPartitionType::ROW_STARTS}});
+
+  for (const string& type_str : row_partition_type_strings) {
+    const auto iter = kStringToType->find(type_str);
+    if (iter == kStringToType->end()) {
+      return InvalidArgument("Unknown string for partition info type: ",
+                             type_str);
+    }
+    row_partition_types->push_back(iter->second);
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status CombineRaggedTensorToTensorShapes(
+    int ragged_rank, const TensorShapeProto& shape,
+    const TensorShapeProto& value_shape, TensorShapeProto* output_shape) {
+  // Test for consistency of value_shape and shape specified.
+  // If shape is unspecified and value_shape is specified, then copy
+  // over the size from the value_shape dimension.
+
+  if (value_shape.unknown_rank() && shape.unknown_rank()) {
+    output_shape->Clear();
+    output_shape->set_unknown_rank(true);
+    return tensorflow::Status::OK();
+  }
+
+  if (shape.unknown_rank()) {
+    // Here, value_shape must be of known size.
+    while (output_shape->dim_size() < ragged_rank + value_shape.dim_size()) {
+      output_shape->add_dim()->set_size(-1);
+    }
+  } else {
+    *output_shape = shape;
+  }
+  if (value_shape.unknown_rank()) {
+    return tensorflow::Status::OK();
+  }
+  // At this point, value_shape and output_shape have known ranks.
+  if (ragged_rank + value_shape.dim_size() != output_shape->dim_size()) {
+    return InvalidArgument("Value shape (", value_shape.DebugString(),
+                           "), ragged_rank(", ragged_rank, ") and shape(",
+                           shape.DebugString(),
+                           ") do not have a consistent number of dimensions");
+  }
+
+  for (int i = 1; i < value_shape.dim_size(); ++i) {
+    const TensorShapeProto::Dim& value_dim = value_shape.dim(i);
+    TensorShapeProto::Dim* output_shape_dim = output_shape->mutable_dim(
+        output_shape->dim_size() - value_shape.dim_size() + i);
+
+    if (value_dim.size() >= 0) {
+      if (output_shape_dim->size() >= 0) {
+        if (output_shape_dim->size() != value_dim.size()) {
+          return InvalidArgument("Value and shape dimension are inconsistent.");
+        }
+      } else {
+        output_shape_dim->set_size(value_dim.size());
+      }
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types) {
+  if (row_partition_types.empty()) {
+    return 0;
+  }
+  if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) {
+    return row_partition_types.size() - 1;
+  }
+  return row_partition_types.size();
+}
+
+tensorflow::Status ValidateDefaultValueShape(
+    const TensorShapeProto& default_value_shape,
+    const TensorShapeProto& value_shape) {
+  if (default_value_shape.unknown_rank() || value_shape.unknown_rank()) {
+    return tensorflow::Status::OK();
+  }
+
+  if (default_value_shape.dim_size() > value_shape.dim_size()) {
+    // TODO(martinz): This constraint is unnecessary. The
+    // default value could have as many dimensions as shape. If there is a
+    // discrepancy, it will be picked up when we broadcast the default value.
+    // For now, I'll relax the constraint only slightly.
+    return InvalidArgument(
+        "default_value_shape must have no more dimensions than the value. "
+        "default_value_shape: ",
+        default_value_shape.DebugString(),
+        " default_value_shape.dim_size(): ", default_value_shape.dim_size(),
+        " value_shape: ", value_shape.DebugString(),
+        " value_shape.dim_size(): ", value_shape.dim_size());
+  }
+  for (int i = 0;
+       i < std::min(default_value_shape.dim_size(), value_shape.dim_size() - 1);
+       ++i) {
+    if (default_value_shape.dim(i).size() >= 0 &&
+        value_shape.dim(i + 1).size() >= 0 &&
+        default_value_shape.dim(i).size() != 1 &&
+        default_value_shape.dim(i).size() != value_shape.dim(i + 1).size()) {
+      return InvalidArgument(
+          "default_value_shape and value_shape do not match on dimension ", i);
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/ragged_to_dense_util.h b/tensorflow/core/ops/ragged_to_dense_util.h
new file mode 100644
index 0000000..d29d6a5
--- /dev/null
+++ b/tensorflow/core/ops/ragged_to_dense_util.h
@@ -0,0 +1,63 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_OPS_RAGGED_TO_DENSE_UTIL_H_
+#define TENSORFLOW_CORE_OPS_RAGGED_TO_DENSE_UTIL_H_
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace tensorflow {
+enum class RowPartitionType {
+  FIRST_DIM_SIZE,
+  VALUE_ROWIDS,
+  ROW_LENGTHS,
+  ROW_SPLITS,
+  ROW_LIMITS,
+  ROW_STARTS
+};
+
+string RowPartitionTypeToString(RowPartitionType row_partition_type);
+
+Status GetRowPartitionTypesHelper(
+    const std::vector<string>& row_partition_type_strings,
+    std::vector<RowPartitionType>* row_partition_types);
+
+// ContextType must be InferenceContext or OpKernelConstruction.
+template <typename ContextType>
+Status GetRowPartitionTypes(
+    ContextType* context, std::vector<RowPartitionType>* row_partition_types) {
+  std::vector<string> row_partition_type_strings;
+  TF_RETURN_IF_ERROR(
+      context->GetAttr("row_partition_types", &row_partition_type_strings));
+  return GetRowPartitionTypesHelper(row_partition_type_strings,
+                                    row_partition_types);
+}
+
+Status CombineRaggedTensorToTensorShapes(int ragged_rank,
+                                         const TensorShapeProto& shape,
+                                         const TensorShapeProto& value_shape,
+                                         TensorShapeProto* output_shape);
+
+int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types);
+
+Status ValidateDefaultValueShape(const TensorShapeProto& default_value_shape,
+                                 const TensorShapeProto& value_shape);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_OPS_RAGGED_TO_DENSE_UTIL_H_
diff --git a/tensorflow/core/ops/ragged_to_dense_util_test.cc b/tensorflow/core/ops/ragged_to_dense_util_test.cc
new file mode 100644
index 0000000..d3d9e68
--- /dev/null
+++ b/tensorflow/core/ops/ragged_to_dense_util_test.cc
@@ -0,0 +1,214 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/ops/ragged_to_dense_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(CombineRaggedTensorToTensorShapes, UnknownShapeUnknownValue) {
+  TensorShapeProto shape_proto;
+  shape_proto.set_unknown_rank(true);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.set_unknown_rank(true);
+  int ragged_rank = 1;
+
+  TensorShapeProto actual_output_shape_proto;
+  TF_ASSERT_OK(CombineRaggedTensorToTensorShapes(
+      ragged_rank, shape_proto, value_shape_proto, &actual_output_shape_proto));
+
+  EXPECT_EQ(true, actual_output_shape_proto.unknown_rank());
+}
+
+TEST(CombineRaggedTensorToTensorShapes, UnknownShape) {
+  TensorShapeProto shape_proto;
+  shape_proto.set_unknown_rank(true);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(6);
+  int ragged_rank = 1;
+
+  TensorShapeProto actual_output_shape_proto;
+  TF_ASSERT_OK(CombineRaggedTensorToTensorShapes(
+      ragged_rank, shape_proto, value_shape_proto, &actual_output_shape_proto));
+
+  ASSERT_EQ(actual_output_shape_proto.dim_size(), 2);
+  EXPECT_EQ(actual_output_shape_proto.dim(0).size(), -1);
+  EXPECT_EQ(actual_output_shape_proto.dim(1).size(), -1);
+}
+
+TEST(CombineRaggedTensorToTensorShapes, UnknownShapeDenseValue) {
+  TensorShapeProto shape_proto;
+  shape_proto.set_unknown_rank(true);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(6);
+  value_shape_proto.add_dim()->set_size(3);
+  int ragged_rank = 1;
+
+  TensorShapeProto actual_output_shape_proto;
+  TF_ASSERT_OK(CombineRaggedTensorToTensorShapes(
+      ragged_rank, shape_proto, value_shape_proto, &actual_output_shape_proto));
+
+  ASSERT_EQ(actual_output_shape_proto.dim_size(), 3);
+  EXPECT_EQ(actual_output_shape_proto.dim(0).size(), -1);
+  EXPECT_EQ(actual_output_shape_proto.dim(1).size(), -1);
+  EXPECT_EQ(actual_output_shape_proto.dim(2).size(), 3);
+}
+
+TEST(GetRowPartitionTypesHelper, BasicTest) {
+  const std::vector<string> row_partition_type_strings = {
+      "FIRST_DIM_SIZE", "VALUE_ROWIDS", "ROW_SPLITS"};
+  std::vector<RowPartitionType> row_partition_types;
+  TF_ASSERT_OK(GetRowPartitionTypesHelper(row_partition_type_strings,
+                                          &row_partition_types));
+  EXPECT_THAT(row_partition_types,
+              ::testing::ElementsAre(RowPartitionType::FIRST_DIM_SIZE,
+                                     RowPartitionType::VALUE_ROWIDS,
+                                     RowPartitionType::ROW_SPLITS));
+}
+
+TEST(RowPartitionTypeToString, BasicTest) {
+  EXPECT_EQ("FIRST_DIM_SIZE",
+            RowPartitionTypeToString(RowPartitionType::FIRST_DIM_SIZE));
+  EXPECT_EQ("VALUE_ROWIDS",
+            RowPartitionTypeToString(RowPartitionType::VALUE_ROWIDS));
+  EXPECT_EQ("ROW_SPLITS",
+            RowPartitionTypeToString(RowPartitionType::ROW_SPLITS));
+}
+
+TEST(ValidateDefaultValueShape, UnknownDefaultValueShape) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.set_unknown_rank(true);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(6);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, UnknownValueShape) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(5);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.set_unknown_rank(true);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, ScalarShape) {
+  TensorShapeProto default_value_shape_proto;
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, TensorShapeEqual) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(2);
+  default_value_shape_proto.add_dim()->set_size(3);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(2);
+  value_shape_proto.add_dim()->set_size(3);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, TensorDimensionUnknown) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(-1);
+  default_value_shape_proto.add_dim()->set_size(3);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(2);
+  value_shape_proto.add_dim()->set_size(3);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, TensorDimensionUnknownForValue) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(2);
+  default_value_shape_proto.add_dim()->set_size(3);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(-1);
+  value_shape_proto.add_dim()->set_size(3);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, TensorDimensionFewDims) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(3);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(-1);
+  value_shape_proto.add_dim()->set_size(3);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+TEST(ValidateDefaultValueShape, WrongNumberOfDimensions) {
+  // I have modified this test to make the default value shape have more
+  // dimensions, instead of the same number.
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(-1);
+  default_value_shape_proto.add_dim()->set_size(-1);
+  default_value_shape_proto.add_dim()->set_size(-1);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(-1);
+  value_shape_proto.add_dim()->set_size(-1);
+  EXPECT_FALSE(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto)
+          .ok());
+}
+
+TEST(ValidateDefaultValueShape, WrongDimensionSize) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(3);
+  default_value_shape_proto.add_dim()->set_size(-1);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(6);
+  value_shape_proto.add_dim()->set_size(-1);
+  EXPECT_FALSE(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto)
+          .ok());
+}
+
+// This is the case where broadcast could work, but we throw an error.
+TEST(ValidateDefaultValueShape, WrongDimensionSizeBut1) {
+  TensorShapeProto default_value_shape_proto;
+  default_value_shape_proto.add_dim()->set_size(3);
+  default_value_shape_proto.add_dim()->set_size(1);
+  TensorShapeProto value_shape_proto;
+  value_shape_proto.add_dim()->set_size(5);
+  value_shape_proto.add_dim()->set_size(3);
+  value_shape_proto.add_dim()->set_size(7);
+  TF_EXPECT_OK(
+      ValidateDefaultValueShape(default_value_shape_proto, value_shape_proto));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 2e07db3..4d9ad0a 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -365,4 +365,26 @@
       return Status::OK();
     });
 
+REGISTER_OP("StringNGrams")
+    .Attr("separator: string")
+    .Attr("ngram_widths: list(int) >= 0")
+    .Attr("left_pad: string")
+    .Attr("right_pad: string")
+    .Attr("pad_width: int")
+    .Attr("preserve_short_sequences: bool")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
+    .Input("data: string")
+    .Input("data_splits: Tsplits")
+    .Output("ngrams: string")
+    .Output("ngrams_splits: Tsplits")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->UnknownShapeOfRank(1));
+      ShapeHandle data = c->input(0);
+      TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
+      ShapeHandle data_splits = c->input(1);
+      TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
+      c->set_output(1, data_splits);
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/tpu_configuration_ops.cc b/tensorflow/core/ops/tpu_configuration_ops.cc
index febb250..94a6b32 100644
--- a/tensorflow/core/ops/tpu_configuration_ops.cc
+++ b/tensorflow/core/ops/tpu_configuration_ops.cc
@@ -199,4 +199,9 @@
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("ConfigureTPUEmbedding")
+    .Attr("config: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape);
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/tpu_cross_replica_ops.cc b/tensorflow/core/ops/tpu_cross_replica_ops.cc
index c26b49e..adce0b5 100644
--- a/tensorflow/core/ops/tpu_cross_replica_ops.cc
+++ b/tensorflow/core/ops/tpu_cross_replica_ops.cc
@@ -40,6 +40,9 @@
       }
       int concat_dimension;
       int split_dimension;
+      int split_count;
+
+      TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
 
       TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
 
@@ -58,14 +61,13 @@
       dims.resize(rank);
 
       for (int32 i = 0; i < rank; ++i) {
-        int64 in_idx = i;
+        dims[i] = c->Dim(input, i);
         if (i == concat_dimension) {
-          in_idx = split_dimension;
-        } else if (i == split_dimension) {
-          in_idx = concat_dimension;
+          dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
         }
-
-        dims[i] = c->Dim(input, in_idx);
+        if (i == split_dimension) {
+          dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
+        }
       }
 
       c->set_output(0, c->MakeShape(dims));
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 995ed42..08794a9 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -245,6 +245,20 @@
   return Status::OK();
 }
 
+static Status ApplyAdagradV2ShapeFn(InferenceContext* c, bool sparse) {
+  ShapeHandle unused;
+  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
+  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // epsilon
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
+  if (c->num_outputs() > 0) {
+    c->set_output(0, s);
+  }
+  return Status::OK();
+}
+
 REGISTER_OP("ApplyAdagrad")
     .Input("var: Ref(T)")
     .Input("accum: Ref(T)")
@@ -270,6 +284,33 @@
       return ApplyAdagradShapeFn(c, false /* sparse */);
     });
 
+REGISTER_OP("ApplyAdagradV2")
+    .Input("var: Ref(T)")
+    .Input("accum: Ref(T)")
+    .Input("lr: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Output("out: Ref(T)")
+    .Attr("T: numbertype")
+    .Attr("use_locking: bool = false")
+    .Attr("update_slots: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      return ApplyAdagradV2ShapeFn(c, false /* sparse */);
+    });
+
+REGISTER_OP("ResourceApplyAdagradV2")
+    .Input("var: resource")
+    .Input("accum: resource")
+    .Input("lr: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Attr("T: numbertype")
+    .Attr("use_locking: bool = false")
+    .Attr("update_slots: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      return ApplyAdagradV2ShapeFn(c, false /* sparse */);
+    });
+
 static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
   ShapeHandle unused;
   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
@@ -341,6 +382,37 @@
       return ApplyAdagradShapeFn(c, true /* sparse */);
     });
 
+REGISTER_OP("SparseApplyAdagradV2")
+    .Input("var: Ref(T)")
+    .Input("accum: Ref(T)")
+    .Input("lr: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Input("indices: Tindices")
+    .Output("out: Ref(T)")
+    .Attr("T: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = false")
+    .Attr("update_slots: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      return ApplyAdagradV2ShapeFn(c, true /* sparse */);
+    });
+
+REGISTER_OP("ResourceSparseApplyAdagradV2")
+    .Input("var: resource")
+    .Input("accum: resource")
+    .Input("lr: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Input("indices: Tindices")
+    .Attr("T: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = false")
+    .Attr("update_slots: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      return ApplyAdagradV2ShapeFn(c, true /* sparse */);
+    });
+
 static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
   ShapeHandle unused;
   ShapeHandle s = ShapeOrHandleShape(c, 0);  // var
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
new file mode 100644
index 0000000..c50fc30
--- /dev/null
+++ b/tensorflow/core/platform/BUILD
@@ -0,0 +1,455 @@
+# Description:
+#   TensorFlow Base libraries.
+#   This package contains the following libraries:
+#     - Platform dependent libraries that require different implementations
+#       across different OSs or environments.
+#     - STL replacement libraries rest of TensorFlow should depend on.
+#
+#   The libraries in this package are not allowed to have ANY dependencies
+#   to any TensorFlow code outside this package.
+
+load(
+    "//tensorflow/core/platform:default/build_config.bzl",
+    "tf_additional_device_tracer_srcs",
+    "tf_additional_lib_hdrs",
+    "tf_additional_lib_srcs",
+    "tf_additional_libdevice_srcs",
+    "tf_additional_minimal_lib_srcs",
+    "tf_additional_monitoring_srcs",
+    "tf_additional_proto_hdrs",
+    "tf_additional_rocdl_deps",
+    "tf_additional_rocdl_srcs",
+    "tf_additional_test_srcs",
+    "tf_env_time_srcs",
+    "tf_logging_absl_deps",
+    "tf_platform_hdrs",
+    "tf_platform_srcs",
+)
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_copts",
+)
+
+package(
+    default_visibility = [
+        "//tensorflow:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "abi",
+    srcs = ["abi.cc"],
+    hdrs = ["abi.h"],
+    deps = [":types"],
+)
+
+cc_library(
+    name = "annotation",
+    hdrs = ["annotation.h"],
+    deps = [
+        ":macros",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "byte_order",
+    hdrs = ["byte_order.h"],
+)
+
+cc_library(
+    name = "cpu_feature_guard",
+    srcs = ["cpu_feature_guard.cc"],
+    hdrs = ["cpu_feature_guard.h"],
+    deps = [
+        ":byte_order",
+        ":cpu_info",
+        ":logging",
+    ],
+)
+
+cc_library(
+    name = "cpu_info",
+    srcs = ["cpu_info.cc"] + tf_platform_srcs([
+        "cpu_info.h",
+    ]),
+    hdrs = ["cpu_info.h"],
+    copts = tf_copts(),
+    deps = [
+        ":byte_order",
+        ":logging",
+        ":platform",
+        ":types",
+    ],
+)
+
+cc_library(
+    name = "denormal",
+    srcs = ["denormal.cc"],
+    hdrs = ["denormal.h"],
+    deps = [
+        ":byte_order",
+        ":cpu_info",
+        ":logging",
+        ":macros",
+        ":platform",
+    ],
+)
+
+cc_library(
+    name = "env_time",
+    srcs = ["env_time.cc"] + tf_env_time_srcs(),
+    hdrs = ["env_time.h"],
+    deps = [
+        ":types",
+    ],
+)
+
+cc_library(
+    name = "file_statistics",
+    hdrs = ["file_statistics.h"],
+    deps = [":types"],
+)
+
+cc_library(
+    name = "host_info",
+    hdrs = ["host_info.h"],
+    deps = [":types"],
+)
+
+cc_library(
+    name = "logging",
+    srcs = tf_platform_hdrs(["logging.h"]) + tf_platform_srcs(["logging.cc"]),
+    hdrs = ["logging.h"],
+    deps = [
+        ":env_time",
+        ":macros",
+        ":platform",
+        ":types",
+        "//tensorflow/core/platform/default/build_config:base",
+    ] + tf_logging_absl_deps(),
+)
+
+cc_library(
+    name = "macros",
+    hdrs = ["macros.h"],
+)
+
+cc_library(
+    name = "rocm_rocdl_path",
+    srcs = ["rocm_rocdl_path.cc"] + tf_additional_rocdl_srcs(),
+    hdrs = ["rocm_rocdl_path.h"],
+    deps = [
+        ":types",
+        "//tensorflow/core:lib",
+    ] + tf_additional_rocdl_deps(),
+)
+
+cc_library(
+    name = "platform",
+    hdrs = ["platform.h"],
+)
+
+cc_library(
+    name = "prefetch",
+    hdrs = ["prefetch.h"],
+    deps = [":platform"],
+)
+
+cc_library(
+    name = "stacktrace",
+    srcs = glob(["*/stacktrace.h"]),
+    hdrs = ["stacktrace.h"],
+    deps = [
+        ":abi",
+        ":platform",
+        "//tensorflow/core/platform/default/build_config:stacktrace",
+    ],
+)
+
+cc_library(
+    name = "setround",
+    srcs = ["setround.cc"],
+    hdrs = ["setround.h"],
+    deps = [
+        ":logging",
+        ":macros",
+    ],
+)
+
+cc_library(
+    name = "stringpiece",
+    hdrs = ["stringpiece.h"],
+    deps = [
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "stringprintf",
+    srcs = ["stringprintf.cc"],
+    hdrs = ["stringprintf.h"],
+    deps = [
+        ":macros",
+        ":types",
+    ],
+)
+
+cc_library(
+    name = "tstring",
+    hdrs = ["tstring.h"],
+)
+
+cc_library(
+    name = "types",
+    srcs = tf_platform_hdrs(["integral_types.h"]),
+    hdrs = ["types.h"],
+    deps = [
+        ":platform",
+        ":tstring",
+        "//tensorflow/core/platform/default/build_config:base",
+    ],
+)
+
+# --------------------------------------------------------------------------
+#     Below libraries are here only to make sure the legacy build rules
+#     in tensorflow/core/BUILD are working!
+#
+#     DO NOT add any new dependencies on these rules!
+#
+# --------------------------------------------------------------------------
+
+filegroup(
+    name = "legacy_platform_lib_hdrs",
+    srcs = tf_additional_lib_hdrs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_platform_lib_srcs",
+    srcs = tf_additional_lib_srcs(
+        exclude = [
+            "*test*",
+            "**/*test*",
+            "**/cuda.h",
+            "**/cuda_libdevice_path.cc",
+            "**/rocm.h",
+            "**/monitoring.cc",
+            "**/stream_executor.h",
+            "**/env_time.cc",
+            "**/device_tracer.cc",
+            "**/logger.cc",
+            "**/logging.cc",
+            "**/human_readable_json.cc",
+            "**/rocm.h",
+            "**/rocm_rocdl_path.cc",
+            "abi.cc",
+            "cpu_info.cc",
+            "platform_strings.cc",
+            "protobuf.cc",
+            "stringprintf.cc",
+        ],
+    ),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_proto_hdrs",
+    srcs = tf_additional_proto_hdrs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_srcs_no_runtime",
+    srcs = glob(
+        [
+            "**/*.h",
+            "**/*.cc",
+        ],
+        exclude = [
+            "*test.*",
+            "*testutil*",
+            "*testlib*",
+            "*main.cc",
+            "**/*test.*",
+            "**/*testutil*",
+            "**/*testlib*",
+            "**/*main.cc",
+            "**/cuda_libdevice_path.*",
+            "**/logger.cc",
+            # Exclude env_time and logging to avoid collisions with
+            # :platform_base, a common dependency for downstream targets.
+            "**/env_time.cc",
+            "**/logging.cc",
+            "**/rocm_rocdl_path.*",
+            "default/test_benchmark.*",
+            "cuda.h",
+            "rocm.h",
+            "google/**/*",
+            "hadoop/**/*",
+            "gif.h",
+            "jpeg.h",
+            "png.h",
+            "stream_executor.*",
+            "windows/**/*",
+        ],
+    ),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_internal_headers",
+    srcs = glob(
+        [
+            "*.h",
+            "profile_utils/**/*.h",
+        ],
+        exclude = [
+            "gif.h",
+            "jpeg.h",
+            "png.h",
+            "stringprintf.h",
+            "**/cuda.h",
+            "**/rocm.h",
+            "**/stream_executor.h",
+        ],
+    ),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_lib_internal_srcs",
+    srcs = glob(
+        [
+            "*.cc",
+            "profile_utils/**/*.cc",
+        ],
+        exclude = [
+            "*test*",
+            "**/*test*",
+            "**/env_time.cc",
+            "**/monitoring.cc",
+            "**/cuda_libdevice_path.cc",
+            "**/device_tracer.cc",
+            "**/logger.cc",
+            "**/logging.cc",
+            "**/human_readable_json.cc",
+            "**/rocm_rocdl_path.cc",
+            "abi.cc",
+            "cpu_info.cc",
+            "platform_strings.cc",
+            "protobuf.cc",
+            "stringprintf.cc",
+        ],
+    ),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_test_srcs",
+    srcs = tf_additional_test_srcs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_device_tracer_srcs",
+    srcs = tf_additional_device_tracer_srcs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_minimal_lib_srcs",
+    srcs = tf_additional_minimal_lib_srcs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_libdevice_srcs",
+    srcs = tf_additional_libdevice_srcs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_monitoring_srcs",
+    srcs = tf_additional_monitoring_srcs(),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_platform_port_srcs",
+    srcs = tf_platform_hdrs([
+        "cpu_info.h",
+        "dynamic_annotations.h",
+        "thread_annotations.h",
+        "mutex.h",
+    ]) + tf_platform_srcs([
+        "port.cc",
+    ]),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_platform_env_srcs",
+    srcs = tf_platform_srcs([
+        "env.cc",
+        "load_library.cc",
+    ]) + tf_platform_hdrs([
+        "wide_char.h",
+    ]),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_file_system_hdrs",
+    srcs = tf_platform_hdrs([
+        "windows_file_system.h",
+    ]),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_platform_other_srcs",
+    srcs = tf_platform_srcs([
+        "subprocess.cc",
+        "net.cc",
+        "tracing.cc",
+    ]) + tf_platform_hdrs([
+        "tracing.h",
+        "error.h",
+        "context.h",
+        "fingerprint.h",
+        "notification.h",
+        "strong_hash.h",
+        "subprocess.h",
+        "tracing_impl.h",
+    ]),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+filegroup(
+    name = "legacy_human_readable_json_src",
+    srcs = tf_platform_srcs(["human_readable_json.cc"]),
+    visibility = ["//tensorflow/core:__pkg__"],
+)
+
+# TODO(gunan): Remove the following once references in core/BUILD is removed.
+exports_files(
+    glob(
+        [
+            "*",
+            "**",
+        ],
+        exclude = [
+            "abi.h",
+            "byte_order.h",
+            "cpu_info.cc",
+            "cpu_info.h",
+            "logging.h",
+            "macros.h",
+            "platform.h",
+            "types.h",
+            "stacktrace.h",
+        ],
+    ),
+)
diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h
index b2d0f21..60574bf 100644
--- a/tensorflow/core/platform/cpu_info.h
+++ b/tensorflow/core/platform/cpu_info.h
@@ -23,7 +23,8 @@
 #include "tensorflow/core/platform/byte_order.h"
 
 #if defined(_MSC_VER)
-#include "tensorflow/core/platform/windows/cpu_info.h"
+// included so __cpuidex function is available for GETCPUID on Windows
+#include <intrin.h>
 #endif
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index b1248af..417f37f 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -4,7 +4,7 @@
 load("//tensorflow:tensorflow.bzl", "if_not_mobile")
 load("//tensorflow:tensorflow.bzl", "if_windows")
 load("//tensorflow:tensorflow.bzl", "if_not_windows")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 load(
@@ -471,12 +471,12 @@
 # must be compiled in the 'default' platform, this is a list of all headers
 # mentioned in the platform/* files.
 def tf_platform_hdrs(files):
-    return native.glob(["platform/*/" + f for f in files])
+    return native.glob(["*/" + f for f in files])
 
 def tf_platform_srcs(files):
-    base_set = ["platform/default/" + f for f in files]
-    windows_set = base_set + ["platform/windows/" + f for f in files]
-    posix_set = base_set + ["platform/posix/" + f for f in files]
+    base_set = ["default/" + f for f in files]
+    windows_set = base_set + ["windows/" + f for f in files]
+    posix_set = base_set + ["posix/" + f for f in files]
 
     # Handle cases where we must also bring the posix file in. Usually, the list
     # of files to build on windows builds is just all the stuff in the
@@ -485,7 +485,7 @@
     # file instead of making a copy in 'windows'.
     for f in files:
         if f == "error.cc":
-            windows_set.append("platform/posix/" + f)
+            windows_set.append("posix/" + f)
 
     return select({
         "//tensorflow:windows": native.glob(windows_set),
@@ -494,29 +494,29 @@
 
 def tf_additional_lib_hdrs(exclude = []):
     windows_hdrs = native.glob([
-        "platform/default/*.h",
-        "platform/windows/*.h",
-        "platform/posix/error.h",
+        "default/*.h",
+        "windows/*.h",
+        "posix/error.h",
     ], exclude = exclude)
     return select({
         "//tensorflow:windows": windows_hdrs,
         "//conditions:default": native.glob([
-            "platform/default/*.h",
-            "platform/posix/*.h",
+            "default/*.h",
+            "posix/*.h",
         ], exclude = exclude),
     })
 
 def tf_additional_lib_srcs(exclude = []):
     windows_srcs = native.glob([
-        "platform/default/*.cc",
-        "platform/windows/*.cc",
-        "platform/posix/error.cc",
+        "default/*.cc",
+        "windows/*.cc",
+        "posix/error.cc",
     ], exclude = exclude)
     return select({
         "//tensorflow:windows": windows_srcs,
         "//conditions:default": native.glob([
-            "platform/default/*.cc",
-            "platform/posix/*.cc",
+            "default/*.cc",
+            "posix/*.cc",
         ], exclude = exclude),
     })
 
@@ -525,29 +525,24 @@
 
 def tf_additional_monitoring_srcs():
     return [
-        "platform/default/monitoring.cc",
+        "default/monitoring.cc",
     ]
 
 def tf_additional_minimal_lib_srcs():
     return [
-        "platform/default/integral_types.h",
-        "platform/default/mutex.h",
-        "platform/default/mutex_data.h",
+        "default/integral_types.h",
+        "default/mutex.h",
+        "default/mutex_data.h",
     ]
 
 def tf_additional_proto_hdrs():
     return [
-        "platform/default/integral_types.h",
-        "platform/default/logging.h",
+        "default/integral_types.h",
+        "default/logging.h",
     ] + if_windows([
-        "platform/windows/integral_types.h",
+        "windows/integral_types.h",
     ])
 
-def tf_additional_proto_srcs():
-    return [
-        "platform/protobuf.cc",
-    ]
-
 def tf_additional_human_readable_json_deps():
     return []
 
@@ -596,7 +591,7 @@
     ]
 
 def tf_additional_device_tracer_srcs():
-    return ["platform/default/device_tracer.cc"]
+    return ["default/device_tracer.cc"]
 
 def tf_additional_device_tracer_cuda_deps():
     return []
@@ -627,20 +622,26 @@
     return ["@local_config_cuda//cuda:cuda_headers"]
 
 def tf_additional_libdevice_srcs():
-    return ["platform/default/cuda_libdevice_path.cc"]
+    return ["default/cuda_libdevice_path.cc"]
+
+def tf_additional_rocdl_deps():
+    return ["@local_config_rocm//rocm:rocm_headers"]
+
+def tf_additional_rocdl_srcs():
+    return ["default/rocm_rocdl_path.cc"]
 
 def tf_additional_test_deps():
     return []
 
 def tf_additional_test_srcs():
     return [
-        "platform/default/test_benchmark.cc",
+        "default/test_benchmark.cc",
     ] + select({
         "//tensorflow:windows": [
-            "platform/windows/test.cc",
+            "windows/test.cc",
         ],
         "//conditions:default": [
-            "platform/posix/test.cc",
+            "posix/test.cc",
         ],
     })
 
@@ -816,3 +817,22 @@
             "-DTENSORFLOW_USE_NUMA",
         ],
     })
+
+def tf_additional_rpc_deps():
+    return []
+
+def tf_logging_absl_deps():
+    return [
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/strings",
+    ]
+
+def tf_env_time_srcs():
+    return select({
+        "//tensorflow:windows": [
+            "windows/env_time.cc",
+        ],
+        "//conditions:default": [
+            "posix/env_time.cc",
+        ],
+    })
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index d917d44..4f96be2 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -12,7 +12,7 @@
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 load("//tensorflow:tensorflow.bzl", "tf_copts")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 load("@local_config_sycl//sycl:platform.bzl", "sycl_library_path")
 load("@local_config_sycl//sycl:build_defs.bzl", "if_ccpp")
 
@@ -65,13 +65,13 @@
     name = "stream_executor_cuda",
     deps = [
         ":stream_executor_no_cuda",
-    ] + if_static(
-        [
+    ] + select({
+        "//tensorflow:oss": ["//tensorflow/stream_executor/cuda:cudart_stub"],
+        "//conditions:default": [
             "//tensorflow/stream_executor/cuda:all_runtime",
             ":cuda",
         ],
-        ["//tensorflow/stream_executor/cuda:cudart_stub"],
-    ) + select({
+    }) + select({
         "@local_config_cuda//cuda:darwin": ["IOKit"],
         "//conditions:default": [],
     }),
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index 27565a7..541f8e4 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -653,8 +653,7 @@
 }  // namespace
 
 // Not in anonymous namespace for testing purposes.
-std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer(
-    const ProfilerContext*) {
+std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer() {
   auto status = cuInit(0);
   if (status != CUDA_SUCCESS) {
     LogIfError(ToStatus(status));
diff --git a/tensorflow/core/platform/default/rocm_rocdl_path.cc b/tensorflow/core/platform/default/rocm_rocdl_path.cc
new file mode 100644
index 0000000..1419604
--- /dev/null
+++ b/tensorflow/core/platform/default/rocm_rocdl_path.cc
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
+
+#include <stdlib.h>
+
+#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+string RocmRoot() {
+#if TENSORFLOW_USE_ROCM
+  VLOG(3) << "ROCM root = " << TF_ROCM_TOOLKIT_PATH;
+  return TF_ROCM_TOOLKIT_PATH;
+#else
+  return "";
+#endif
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/default/unbounded_work_queue.cc b/tensorflow/core/platform/default/unbounded_work_queue.cc
new file mode 100644
index 0000000..3cc66b6
--- /dev/null
+++ b/tensorflow/core/platform/default/unbounded_work_queue.cc
@@ -0,0 +1,96 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/unbounded_work_queue.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+UnboundedWorkQueue::UnboundedWorkQueue(Env* env, const string& thread_name)
+    : env_(env), thread_name_(thread_name) {}
+
+UnboundedWorkQueue::~UnboundedWorkQueue() {
+  {
+    mutex_lock l(work_queue_mu_);
+    // Wake up all `PooledThreadFunc` threads and cause them to terminate before
+    // joining them when `threads_` is cleared.
+    cancelled_ = true;
+    work_queue_cv_.notify_all();
+    if (!work_queue_.empty()) {
+      LOG(ERROR) << "UnboundedWorkQueue named \"" << thread_name_ << "\" was "
+                 << "deleted with pending work in its queue. This may indicate "
+                 << "a potential use-after-free bug.";
+    }
+  }
+
+  {
+    mutex_lock l(thread_pool_mu_);
+    // Clear the list of pooled threads, which will eventually terminate due to
+    // the previous notification.
+    //
+    // NOTE: It is safe to do this while holding `thread_pool_mu_`, because
+    // no subsequent calls to `this->Schedule()` should be issued after the
+    // destructor starts.
+    thread_pool_.clear();
+  }
+}
+
+void UnboundedWorkQueue::Schedule(WorkFunction fn) {
+  // Enqueue a work item for the new thread's function, and wake up a
+  // cached thread to process it.
+  mutex_lock l(work_queue_mu_);
+  work_queue_.push_back(std::move(fn));
+  work_queue_cv_.notify_one();
+  // NOTE: The queue may be non-empty, so we must account for queued work when
+  // considering how many threads are free.
+  if (work_queue_.size() > num_idle_threads_) {
+    // Spawn a new physical thread to process the given function.
+    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
+    // at the beginning of its work loop.
+    Thread* new_thread =
+        env_->StartThread({}, thread_name_, [this]() { PooledThreadFunc(); });
+
+    mutex_lock l(thread_pool_mu_);
+    thread_pool_.emplace_back(new_thread);
+  }
+}
+
+void UnboundedWorkQueue::PooledThreadFunc() {
+  while (true) {
+    WorkFunction fn;
+    {
+      mutex_lock l(work_queue_mu_);
+      ++num_idle_threads_;
+      while (!cancelled_ && work_queue_.empty()) {
+        // Wait for a new work function to be submitted, or the cache to be
+        // destroyed.
+        work_queue_cv_.wait(l);
+      }
+      if (cancelled_) {
+        return;
+      }
+      fn = std::move(work_queue_.front());
+      work_queue_.pop_front();
+      --num_idle_threads_;
+    }
+
+    fn();
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/default/unbounded_work_queue.h b/tensorflow/core/platform/default/unbounded_work_queue.h
new file mode 100644
index 0000000..cba8362
--- /dev/null
+++ b/tensorflow/core/platform/default/unbounded_work_queue.h
@@ -0,0 +1,65 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// An `UnboundedWorkQueue` provides a mechanism for temporally multiplexing a
+// potentially large number of "logical" threads onto a smaller number of
+// "physical" threads. The multiplexing is achieved by maintaining an internal
+// pool of long-running "physical" threads that are used to execute the
+// "logical" threads.  Like a regular thread, a "logical" thread may block on
+// other threads, and the size of the pool will increase to ensure that progress
+// is made. This mechanism is recommended in situations where short-lived
+// threads are created repeatedly, to avoid the overhead and memory
+// fragmentation that can result from excessive thread creation.
+class UnboundedWorkQueue {
+ public:
+  UnboundedWorkQueue(Env* env, const string& thread_name);
+  ~UnboundedWorkQueue();
+
+  using WorkFunction = std::function<void()>;
+
+  // Schedule `fn` on a thread.  `fn` may perform blocking work, so if all the
+  // existing threads are blocked or busy, this may spawn a new thread which
+  // will be added to the thread pool managed by this work queue.
+  void Schedule(WorkFunction fn);
+
+ private:
+  void PooledThreadFunc();
+
+  Env* const env_;  // Not owned.
+  const string thread_name_;
+  mutex work_queue_mu_;
+  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
+  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
+  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
+  std::deque<WorkFunction> work_queue_ GUARDED_BY(work_queue_mu_);
+  mutex thread_pool_mu_;
+  std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
diff --git a/tensorflow/core/platform/device_tracer_test.cc b/tensorflow/core/platform/device_tracer_test.cc
index d90e126..e43711f 100644
--- a/tensorflow/core/platform/device_tracer_test.cc
+++ b/tensorflow/core/platform/device_tracer_test.cc
@@ -39,15 +39,12 @@
 #include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
-struct ProfilerContext;
 
 #if GOOGLE_CUDA
-std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer(
-    const ProfilerContext*);
+std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer();
 #else
 // We don't have device tracer for non-cuda case.
-std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer(
-    const ProfilerContext*) {
+std::unique_ptr<profiler::ProfilerInterface> CreateDeviceTracer() {
   return nullptr;
 }
 #endif
@@ -111,21 +108,21 @@
 };
 
 TEST_F(DeviceTracerTest, StartStop) {
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
   TF_EXPECT_OK(tracer->Start());
   TF_EXPECT_OK(tracer->Stop());
 }
 
 TEST_F(DeviceTracerTest, StopBeforeStart) {
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
   TF_EXPECT_OK(tracer->Stop());
   TF_EXPECT_OK(tracer->Stop());
 }
 
 TEST_F(DeviceTracerTest, CollectBeforeStart) {
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
   RunMetadata run_metadata;
   TF_EXPECT_OK(tracer->CollectData(&run_metadata));
@@ -133,7 +130,7 @@
 }
 
 TEST_F(DeviceTracerTest, CollectBeforeStop) {
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
   TF_EXPECT_OK(tracer->Start());
   RunMetadata run_metadata;
@@ -143,8 +140,8 @@
 }
 
 TEST_F(DeviceTracerTest, StartTwoTracers) {
-  auto tracer1 = CreateDeviceTracer(nullptr);
-  auto tracer2 = CreateDeviceTracer(nullptr);
+  auto tracer1 = CreateDeviceTracer();
+  auto tracer2 = CreateDeviceTracer();
   if (!tracer1 || !tracer2) return;
 
   TF_EXPECT_OK(tracer1->Start());
@@ -157,7 +154,7 @@
 
 TEST_F(DeviceTracerTest, RunWithTracer) {
   // On non-GPU platforms, we may not support DeviceTracer.
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
 
   Initialize({3, 2, -1, 0});
@@ -184,7 +181,7 @@
 }
 
 TEST_F(DeviceTracerTest, TraceToStepStatsCollector) {
-  auto tracer = CreateDeviceTracer(nullptr);
+  auto tracer = CreateDeviceTracer();
   if (!tracer) return;
 
   Initialize({3, 2, -1, 0});
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index fe781a5..f7a91c7 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -32,6 +32,11 @@
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 
+// Delete the definition of CopyFile as the linker gets confused.
+#ifdef PLATFORM_WINDOWS
+#undef CopyFile
+#endif
+
 namespace tensorflow {
 
 class Thread;
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 8ab43c4..21d9f3f 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -32,6 +32,7 @@
 
 #ifdef PLATFORM_WINDOWS
 #undef DeleteFile
+#undef CopyFile
 #endif
 
 namespace tensorflow {
diff --git a/tensorflow/core/platform/grpc_services.h b/tensorflow/core/platform/grpc_services.h
index cd91819..13b84ca 100644
--- a/tensorflow/core/platform/grpc_services.h
+++ b/tensorflow/core/platform/grpc_services.h
@@ -12,11 +12,11 @@
 #ifndef TENSORFLOW_CORE_PLATFORM_GRPC_SERVICES_H_
 #define TENSORFLOW_CORE_PLATFORM_GRPC_SERVICES_H_
 
+#include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
 
 #if !defined(PLATFORM_GOOGLE)
-
 namespace tensorflow {
 namespace grpc {
 
diff --git a/tensorflow/core/platform/platform_strings.cc b/tensorflow/core/platform/platform_strings.cc
index c185263..489a211 100644
--- a/tensorflow/core/platform/platform_strings.cc
+++ b/tensorflow/core/platform/platform_strings.cc
@@ -15,14 +15,12 @@
 
 #include "tensorflow/core/platform/platform_strings.h"
 
+#include <cerrno>
 #include <cstdio>
 #include <cstring>
-
 #include <string>
 #include <vector>
 
-#include "tensorflow/core/lib/core/status.h"
-
 namespace tensorflow {
 
 int GetPlatformStrings(const std::string& path,
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index a3699de..47f4aba 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -303,10 +303,6 @@
 
 std::size_t MallocExtension_GetAllocatedSize(const void* p) { return 0; }
 
-void AdjustFilenameForLogging(string* filename) {
-  // Nothing to do
-}
-
 bool Snappy_Compress(const char* input, size_t length, string* output) {
 #ifdef TF_USE_SNAPPY
   output->resize(snappy::MaxCompressedLength(length));
diff --git a/tensorflow/core/platform/rocm_rocdl_path.cc b/tensorflow/core/platform/rocm_rocdl_path.cc
new file mode 100644
index 0000000..bf5b2bf
--- /dev/null
+++ b/tensorflow/core/platform/rocm_rocdl_path.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
+
+#include "tensorflow/core/lib/io/path.h"
+
+namespace tensorflow {
+
+string RocdlRoot() {
+  return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "hcc/lib");
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/rocm_rocdl_path.h b/tensorflow/core/platform/rocm_rocdl_path.h
new file mode 100644
index 0000000..e83ef5b
--- /dev/null
+++ b/tensorflow/core/platform/rocm_rocdl_path.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_
+#define TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Returns the root directory of the ROCM SDK, which contains sub-folders such
+// as bin, lib, and rocdl.
+string RocmRoot();
+
+// Returns the directory that contains ROCm-Device-Libs files in the ROCm SDK.
+string RocdlRoot();
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_
diff --git a/tensorflow/core/platform/rocm_rocdl_path_test.cc b/tensorflow/core/platform/rocm_rocdl_path_test.cc
new file mode 100644
index 0000000..4a4d9b8
--- /dev/null
+++ b/tensorflow/core/platform/rocm_rocdl_path_test.cc
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+#if TENSORFLOW_USE_ROCM
+TEST(RocmRocdlPathTest, ROCDLPath) {
+  VLOG(2) << "ROCm-Deivce-Libs root = " << RocdlRoot();
+  std::vector<string> rocdl_files;
+  TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
+      io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files));
+  EXPECT_LT(0, rocdl_files.size());
+}
+#endif
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/stringpiece.h b/tensorflow/core/platform/stringpiece.h
new file mode 100644
index 0000000..4ca42b4
--- /dev/null
+++ b/tensorflow/core/platform/stringpiece.h
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// StringPiece is a simple structure containing a pointer into some external
+// storage and a size.  The user of a StringPiece must ensure that the slice
+// is not used after the corresponding external storage has been
+// deallocated.
+//
+// Multiple threads can invoke const methods on a StringPiece without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same StringPiece must use
+// external synchronization.
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_
+#define TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_
+
+#include "absl/strings/string_view.h"
+
+namespace tensorflow {
+
+using StringPiece = absl::string_view;
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_
diff --git a/tensorflow/core/platform/stringpiece_test.cc b/tensorflow/core/platform/stringpiece_test.cc
new file mode 100644
index 0000000..4643c0c
--- /dev/null
+++ b/tensorflow/core/platform/stringpiece_test.cc
@@ -0,0 +1,64 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/stringpiece.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(StringPiece, Ctor) {
+  {
+    // const char* without size.
+    const char* hello = "hello";
+    StringPiece s20(hello);
+    EXPECT_TRUE(s20.data() == hello);
+    EXPECT_EQ(5, s20.size());
+
+    // const char* with size.
+    StringPiece s21(hello, 4);
+    EXPECT_TRUE(s21.data() == hello);
+    EXPECT_EQ(4, s21.size());
+
+    // Not recommended, but valid C++
+    StringPiece s22(hello, 6);
+    EXPECT_TRUE(s22.data() == hello);
+    EXPECT_EQ(6, s22.size());
+  }
+
+  {
+    string hola = "hola";
+    StringPiece s30(hola);
+    EXPECT_TRUE(s30.data() == hola.data());
+    EXPECT_EQ(4, s30.size());
+
+    // std::string with embedded '\0'.
+    hola.push_back('\0');
+    hola.append("h2");
+    hola.push_back('\0');
+    StringPiece s31(hola);
+    EXPECT_TRUE(s31.data() == hola.data());
+    EXPECT_EQ(8, s31.size());
+  }
+}
+
+TEST(StringPiece, ConversionToString) {
+  EXPECT_EQ("", string(StringPiece("")));
+  EXPECT_EQ("foo", string(StringPiece("foo")));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/stringprintf.cc b/tensorflow/core/platform/stringprintf.cc
new file mode 100644
index 0000000..89d99c8
--- /dev/null
+++ b/tensorflow/core/platform/stringprintf.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/stringprintf.h"
+
+#include <errno.h>
+#include <stdarg.h>  // For va_list and related operations
+#include <stdio.h>   // MSVC requires this for _vsnprintf
+
+namespace tensorflow {
+namespace strings {
+
+void Appendv(string* dst, const char* format, va_list ap) {
+  // First try with a small fixed size buffer
+  static const int kSpaceLength = 1024;
+  char space[kSpaceLength];
+
+  // It's possible for methods that use a va_list to invalidate
+  // the data in it upon use.  The fix is to make a copy
+  // of the structure before using it and use that copy instead.
+  va_list backup_ap;
+  va_copy(backup_ap, ap);
+  int result = vsnprintf(space, kSpaceLength, format, backup_ap);
+  va_end(backup_ap);
+
+  if (result < kSpaceLength) {
+    if (result >= 0) {
+      // Normal case -- everything fit.
+      dst->append(space, result);
+      return;
+    }
+
+#ifdef _MSC_VER
+      // Error or MSVC running out of space.  MSVC 8.0 and higher
+      // can be asked about space needed with the special idiom below:
+      va_copy(backup_ap, ap);
+      result = vsnprintf(nullptr, 0, format, backup_ap);
+      va_end(backup_ap);
+#endif
+
+    if (result < 0) {
+      // Just an error.
+      return;
+    }
+  }
+
+  // Increase the buffer size to the size requested by vsnprintf,
+  // plus one for the closing \0.
+  int length = result + 1;
+  char* buf = new char[length];
+
+  // Restore the va_list before we use it again
+  va_copy(backup_ap, ap);
+  result = vsnprintf(buf, length, format, backup_ap);
+  va_end(backup_ap);
+
+  if (result >= 0 && result < length) {
+    // It fit
+    dst->append(buf, result);
+  }
+  delete[] buf;
+}
+
+string Printf(const char* format, ...) {
+  va_list ap;
+  va_start(ap, format);
+  string result;
+  Appendv(&result, format, ap);
+  va_end(ap);
+  return result;
+}
+
+void Appendf(string* dst, const char* format, ...) {
+  va_list ap;
+  va_start(ap, format);
+  Appendv(dst, format, ap);
+  va_end(ap);
+}
+
+}  // namespace strings
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/stringprintf.h b/tensorflow/core/platform/stringprintf.h
new file mode 100644
index 0000000..802b568
--- /dev/null
+++ b/tensorflow/core/platform/stringprintf.h
@@ -0,0 +1,52 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Printf variants that place their output in a C++ string.
+//
+// Usage:
+//      string result = strings::Printf("%d %s\n", 10, "hello");
+//      strings::Appendf(&result, "%d %s\n", 20, "there");
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_
+#define TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_
+
+#include <stdarg.h>
+
+#include <string>
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace strings {
+
+// Return a C++ string
+extern string Printf(const char* format, ...)
+    // Tell the compiler to do printf format string checking.
+    TF_PRINTF_ATTRIBUTE(1, 2);
+
+// Append result to a supplied string
+extern void Appendf(string* dst, const char* format, ...)
+    // Tell the compiler to do printf format string checking.
+    TF_PRINTF_ATTRIBUTE(2, 3);
+
+// Lower-level routine that takes a va_list and appends to a specified
+// string.  All other routines are just convenience wrappers around it.
+extern void Appendv(string* dst, const char* format, va_list ap);
+
+}  // namespace strings
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_
diff --git a/tensorflow/core/platform/stringprintf_test.cc b/tensorflow/core/platform/stringprintf_test.cc
new file mode 100644
index 0000000..d24523b
--- /dev/null
+++ b/tensorflow/core/platform/stringprintf_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/stringprintf.h"
+
+#include <string>
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace strings {
+namespace {
+
+TEST(PrintfTest, Empty) {
+  EXPECT_EQ("", Printf("%s", string().c_str()));
+  EXPECT_EQ("", Printf("%s", ""));
+}
+
+TEST(PrintfTest, Misc) {
+// MSVC does not support $ format specifier.
+#if !defined(_MSC_VER)
+  EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123));
+#endif  // !_MSC_VER
+}
+
+TEST(AppendfTest, Empty) {
+  string value("Hello");
+  const char* empty = "";
+  Appendf(&value, "%s", empty);
+  EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, EmptyString) {
+  string value("Hello");
+  Appendf(&value, "%s", "");
+  EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, String) {
+  string value("Hello");
+  Appendf(&value, " %s", "World");
+  EXPECT_EQ("Hello World", value);
+}
+
+TEST(AppendfTest, Int) {
+  string value("Hello");
+  Appendf(&value, " %d", 123);
+  EXPECT_EQ("Hello 123", value);
+}
+
+TEST(PrintfTest, Multibyte) {
+  // If we are in multibyte mode and feed invalid multibyte sequence,
+  // Printf should return an empty string instead of running
+  // out of memory while trying to determine destination buffer size.
+  // see b/4194543.
+
+  char* old_locale = setlocale(LC_CTYPE, nullptr);
+  // Push locale with multibyte mode
+  setlocale(LC_CTYPE, "en_US.utf8");
+
+  const char kInvalidCodePoint[] = "\375\067s";
+  string value = Printf("%.*s", 3, kInvalidCodePoint);
+
+  // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf
+  // returns error given an invalid codepoint. Other versions
+  // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim.
+  // We test that the output is one of the above.
+  EXPECT_TRUE(value.empty() || value == kInvalidCodePoint);
+
+  // Repeat with longer string, to make sure that the dynamically
+  // allocated path in StringAppendV is handled correctly.
+  int n = 2048;
+  char* buf = new char[n + 1];
+  memset(buf, ' ', n - 3);
+  memcpy(buf + n - 3, kInvalidCodePoint, 4);
+  value = Printf("%.*s", n, buf);
+  // See GRTEv2 vs. GRTEv3 comment above.
+  EXPECT_TRUE(value.empty() || value == buf);
+  delete[] buf;
+
+  setlocale(LC_CTYPE, old_locale);
+}
+
+TEST(PrintfTest, NoMultibyte) {
+  // No multibyte handling, but the string contains funny chars.
+  char* old_locale = setlocale(LC_CTYPE, nullptr);
+  setlocale(LC_CTYPE, "POSIX");
+  string value = Printf("%.*s", 3, "\375\067s");
+  setlocale(LC_CTYPE, old_locale);
+  EXPECT_EQ("\375\067s", value);
+}
+
+TEST(PrintfTest, DontOverwriteErrno) {
+  // Check that errno isn't overwritten unless we're printing
+  // something significantly larger than what people are normally
+  // printing in their badly written PLOG() statements.
+  errno = ECHILD;
+  string value = Printf("Hello, %s!", "World");
+  EXPECT_EQ(ECHILD, errno);
+}
+
+TEST(PrintfTest, LargeBuf) {
+  // Check that the large buffer is handled correctly.
+  int n = 2048;
+  char* buf = new char[n + 1];
+  memset(buf, ' ', n);
+  buf[n] = 0;
+  string value = Printf("%s", buf);
+  EXPECT_EQ(buf, value);
+  delete[] buf;
+}
+
+}  // namespace
+
+}  // namespace strings
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/tstring.h b/tensorflow/core/platform/tstring.h
new file mode 100644
index 0000000..64a7a2d
--- /dev/null
+++ b/tensorflow/core/platform/tstring.h
@@ -0,0 +1,187 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_TSTRING_H_
+#define TENSORFLOW_CORE_PLATFORM_TSTRING_H_
+
+#include <string>
+
+// TODO(b/138799229): Used to toggle until global presubmits pass.
+// #define USE_TSTRING
+
+#ifdef USE_TSTRING
+
+// The inclusion of absl/strings/string_view.h in tstring.h would preclude the
+// use of tstring in tflite.  Given that, in order to mitigate the forced
+// inclusion of absl/strings/string_view.h while providing convenience methods
+// for implicit conversion, we replace explicit uses of absl::string_view with a
+// forward declaration and associated templates.
+namespace absl {
+class string_view;
+}
+
+namespace tensorflow {
+
+// tensorflow::tstring is the scalar type for DT_STRING tensors.
+//
+// TODO(b/138799229): In order to ease migration from tensorflow::string to
+// tensorflow::tstring, we define a simplified tstring class which wraps
+// std::string.  The API defined below is the expected subset of methods for
+// tstring.
+//
+// The underlying implementation of tstring will be replaced with the one
+// defined in [1] once the migration in tensorflow/ is complete.
+//
+// [1] https://github.com/tensorflow/community/pull/91
+class tstring {
+  std::string str_;
+
+ public:
+  tstring() = default;
+
+  tstring(const tstring&) = default;
+
+  tstring(const std::string& str) : str_(str) {}
+
+  tstring(const char* str, size_t len) : str_(str, len) {}
+
+  tstring(const char* str) : str_(str) {}
+
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, absl::string_view>::value, T>>
+  explicit tstring(const T& str) : str_(str.data(), str.size()) {}
+
+  tstring(tstring&&) noexcept = default;
+
+  ~tstring() = default;
+
+  tstring& operator=(const tstring& str) = default;
+
+  tstring& operator=(const std::string& str) {
+    str_ = str;
+
+    return *this;
+  }
+
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, absl::string_view>::value, T>>
+  tstring& operator=(const T& str) {
+    str_.assign(str.data(), str.size());
+
+    return *this;
+  }
+
+  tstring& operator=(const char* str) {
+    str_ = str;
+
+    return *this;
+  }
+
+  tstring& operator=(tstring&&) noexcept = default;
+
+  bool operator<(const tstring& o) const { return str_ < o.str_; }
+
+  bool operator>(const tstring& o) const { return str_ > o.str_; }
+
+  bool operator==(const char* o) const { return str_ == o; }
+
+  bool operator==(const tstring& o) const { return str_ == o.str_; }
+
+  bool operator!=(const char* o) const { return str_ != o; }
+
+  bool operator!=(const tstring& o) const { return str_ != o.str_; }
+
+  operator std::string() const { return str_; }
+
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, absl::string_view>::value, T>>
+  operator T() const {
+    return T(str_.data(), str_.size());
+  }
+
+  bool empty() const { return str_.empty(); }
+
+  size_t length() const { return str_.length(); }
+
+  size_t size() const { return str_.size(); }
+
+  const char* c_str() const { return str_.c_str(); }
+
+  const char* data() const { return str_.data(); }
+
+  const char& operator[](size_t i) const { return str_[i]; }
+
+  char* data() { return str_.data(); }
+
+  char& operator[](size_t i) { return str_[i]; }
+
+  void resize(size_t new_size) { str_.resize(new_size); }
+
+  tstring& assign(const char* str, size_t len) {
+    str_.assign(str, len);
+
+    return *this;
+  }
+
+  tstring& assign(const char* str) {
+    str_.assign(str);
+
+    return *this;
+  }
+
+  friend const tstring operator+(const tstring& a, const tstring& b);
+  friend bool operator==(const char* a, const tstring& b);
+  friend bool operator==(const std::string& a, const tstring& b);
+  friend std::ostream& operator<<(std::ostream& o, const tstring& str);
+  friend std::hash<tstring>;
+};
+
+inline bool operator==(const char* a, const tstring& b) { return a == b.str_; }
+
+inline bool operator==(const std::string& a, const tstring& b) {
+  return a == b.str_;
+}
+
+inline const tstring operator+(const tstring& a, const tstring& b) {
+  return tstring(a.str_ + b.str_);
+}
+
+inline std::ostream& operator<<(std::ostream& o, const tstring& str) {
+  return o << str.str_;
+}
+
+}  // namespace tensorflow
+
+namespace std {
+template <>
+struct hash<tensorflow::tstring> {
+  size_t operator()(const tensorflow::tstring& o) const {
+    std::hash<std::string> fn;
+    return fn(o.str_);
+  }
+};
+}  // namespace std
+
+#else  // USE_TSTRING
+
+namespace tensorflow {
+
+typedef std::string tstring;
+
+}  // namespace tensorflow
+
+#endif  // USE_TSTRING
+
+#endif  // TENSORFLOW_CORE_PLATFORM_TSTRING_H_
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index b82d9cc..ef6a8f9 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -17,7 +17,9 @@
 #define TENSORFLOW_CORE_PLATFORM_TYPES_H_
 
 #include <string>
+
 #include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/tstring.h"
 
 // Include appropriate platform-dependent implementations
 #if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES)
diff --git a/tensorflow/core/platform/unbounded_work_queue.h b/tensorflow/core/platform/unbounded_work_queue.h
new file mode 100644
index 0000000..242980d
--- /dev/null
+++ b/tensorflow/core/platform/unbounded_work_queue.h
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_
+#define TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_
+
+#include "tensorflow/core/platform/platform.h"
+
+// An `UnboundedWorkQueue` feeds potentially-blocking work into a thread-pool
+// whose size automatically increases with demand.
+
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/unbounded_work_queue.h"
+#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
+    defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_WINDOWS)
+#include "tensorflow/core/platform/default/unbounded_work_queue.h"
+#else
+#error Define the appropriate PLATFORM_<foo> macro for this platform
+#endif
+
+#endif  // TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_
diff --git a/tensorflow/core/platform/unbounded_work_queue_test.cc b/tensorflow/core/platform/unbounded_work_queue_test.cc
new file mode 100644
index 0000000..03d91cd
--- /dev/null
+++ b/tensorflow/core/platform/unbounded_work_queue_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/unbounded_work_queue.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class UnboundedWorkQueueTest : public ::testing::Test {
+ protected:
+  UnboundedWorkQueueTest()
+      : work_queue_(
+            absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test")) {}
+  ~UnboundedWorkQueueTest() override = default;
+
+  void RunMultipleCopiesOfClosure(const int num_closures,
+                                  std::function<void()> fn) {
+    for (int i = 0; i < num_closures; ++i) {
+      work_queue_->Schedule([this, fn]() {
+        fn();
+        mutex_lock l(mu_);
+        ++closure_count_;
+        cond_var_.notify_all();
+      });
+    }
+  }
+
+  void BlockUntilClosuresDone(const int num_closures) {
+    mutex_lock l(mu_);
+    while (closure_count_ < num_closures) {
+      cond_var_.wait(l);
+    }
+  }
+
+  void ResetQueue() { work_queue_.reset(); }
+
+  int NumClosuresExecuted() {
+    mutex_lock l(mu_);
+    return closure_count_;
+  }
+
+ private:
+  mutex mu_;
+  int closure_count_ GUARDED_BY(mu_) = 0;
+  condition_variable cond_var_;
+  std::unique_ptr<UnboundedWorkQueue> work_queue_;
+};
+
+TEST_F(UnboundedWorkQueueTest, SingleClosure) {
+  constexpr int num_closures = 1;
+  RunMultipleCopiesOfClosure(num_closures, []() {});
+  BlockUntilClosuresDone(num_closures);
+}
+
+TEST_F(UnboundedWorkQueueTest, MultipleClosures) {
+  constexpr int num_closures = 10;
+  RunMultipleCopiesOfClosure(num_closures, []() {});
+  BlockUntilClosuresDone(num_closures);
+}
+
+TEST_F(UnboundedWorkQueueTest, MultipleClosuresSleepingRandomly) {
+  constexpr int num_closures = 1000;
+  RunMultipleCopiesOfClosure(num_closures, []() {
+    Env::Default()->SleepForMicroseconds(random::New64() % 10);
+  });
+  BlockUntilClosuresDone(num_closures);
+}
+
+TEST_F(UnboundedWorkQueueTest, NestedClosures) {
+  constexpr int num_closures = 10;
+  // Run `num_closures` closures, each of which runs `num_closures` closures.
+  RunMultipleCopiesOfClosure(num_closures, [this]() {
+    RunMultipleCopiesOfClosure(num_closures, []() {});
+  });
+  BlockUntilClosuresDone(num_closures * num_closures + num_closures);
+}
+
+TEST_F(UnboundedWorkQueueTest, RacyDestructor) {
+  constexpr int num_closures = 100;
+  // Run `num_closures` closures, then delete `work_queue_`.
+  RunMultipleCopiesOfClosure(num_closures, []() {});
+  ResetQueue();
+  EXPECT_LE(NumClosuresExecuted(), num_closures);
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h
deleted file mode 100644
index 8b42cbe..0000000
--- a/tensorflow/core/platform/windows/cpu_info.h
+++ /dev/null
@@ -1,22 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
-#define TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
-
-// included so __cpuidex function is available for GETCPUID on Windows
-#include <intrin.h>
-
-#endif  // TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 2aa84d6..2303b58 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -128,10 +128,6 @@
 
 std::size_t MallocExtension_GetAllocatedSize(const void* p) { return 0; }
 
-void AdjustFilenameForLogging(string* filename) {
-  // Nothing to do
-}
-
 bool Snappy_Compress(const char* input, size_t length, string* output) {
 #ifdef TF_USE_SNAPPY
   output->resize(snappy::MaxCompressedLength(length));
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index 470472e..73b938e 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library", "tf_additional_all_protos")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 # Placeholder for Google-internal load statements.
 
 package(
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index 71a3542..7439a7f 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -370,7 +370,8 @@
     srcs = ["traceme_recorder.cc"],
     hdrs = ["traceme_recorder.h"],
     visibility = [
-        "//perftools/accelerators/xprof/xprofilez:__pkg__",  # alias xprof::TraceMeRecorder
+        "//perftools/accelerators/xprof/xprofilez/cpu:__pkg__",  # host_tracer
+        "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__",  # traceme_test
         "//tensorflow/core:__pkg__",  # executor.cc
         "//tensorflow/core/profiler/internal/cpu:__pkg__",  # host_tracer
         "//tensorflow/core/profiler/lib:__pkg__",  # traceme
@@ -378,7 +379,6 @@
     deps = [
         "//tensorflow/core:lib",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
@@ -437,3 +437,13 @@
         "@com_google_absl//absl/strings",
     ],
 )
+
+tf_cuda_library(
+    name = "python_traceme",
+    hdrs = ["python_traceme.h"],
+    visibility = ["//tensorflow/python:__pkg__"],
+    deps = [
+        "//tensorflow/core/profiler/lib:traceme",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
index 6fddd58..4c45e24 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
@@ -141,7 +141,7 @@
 }  // namespace
 
 // Not in anonymous namespace for testing purposes.
-std::unique_ptr<ProfilerInterface> CreateHostTracer(const ProfilerContext*) {
+std::unique_ptr<ProfilerInterface> CreateHostTracer() {
   int host_trace_level = 2;
   return absl::make_unique<HostTracer>(host_trace_level);
 }
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index 8b0e027..e047d9d 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -28,7 +28,8 @@
 namespace tensorflow {
 namespace profiler {
 namespace cpu {
-std::unique_ptr<ProfilerInterface> CreateHostTracer(const ProfilerContext*);
+
+std::unique_ptr<ProfilerInterface> CreateHostTracer();
 
 namespace {
 
@@ -80,7 +81,7 @@
 TEST(HostTracerTest, CollectsTraceMeEvents) {
   uint32 thread_id = Env::Default()->GetCurrentThreadId();
 
-  auto tracer = CreateHostTracer(nullptr);
+  auto tracer = CreateHostTracer();
 
   TF_ASSERT_OK(tracer->Start());
   { TraceMe traceme("hello"); }
diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD
index 4622be2..dce7d39 100644
--- a/tensorflow/core/profiler/internal/gpu/BUILD
+++ b/tensorflow/core/profiler/internal/gpu/BUILD
@@ -7,3 +7,34 @@
     name = "device_tracer",
     actual = "//tensorflow/core:device_tracer",
 )
+
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_copts",
+    "tf_cuda_library",
+    "if_cuda_is_configured_compat",
+)
+
+tf_cuda_library(
+    name = "cupti_interface",
+    hdrs = if_cuda_is_configured_compat(["cupti_interface.h"]),
+    copts = tf_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:platform_base",
+        "//tensorflow/stream_executor/cuda:cupti_stub",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+tf_cuda_library(
+    name = "cupti_wrapper",
+    srcs = if_cuda_is_configured_compat(["cupti_wrapper.cc"]),
+    hdrs = if_cuda_is_configured_compat(["cupti_wrapper.h"]),
+    copts = tf_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":cupti_interface",
+        "//tensorflow/stream_executor/cuda:cupti_stub",
+    ],
+)
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_interface.h b/tensorflow/core/profiler/internal/gpu/cupti_interface.h
new file mode 100644
index 0000000..11baac4
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/cupti_interface.h
@@ -0,0 +1,196 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Provides a wrapper interface to every single CUPTI API function. This class
+// is needed to create an easy mock object for CUPTI API calls. All member
+// functions are defined in the following order: activity related APIs, callback
+// related APIs, Event APIs, and metric APIs. Within each category, we follow
+// the order in the original CUPTI documentation.
+class CuptiInterface {
+ public:
+  CuptiInterface() {}
+
+  virtual ~CuptiInterface() {}
+
+  // CUPTI activity API
+  virtual CUptiResult ActivityDisable(CUpti_ActivityKind kind) = 0;
+
+  virtual CUptiResult ActivityEnable(CUpti_ActivityKind kind) = 0;
+
+  virtual CUptiResult ActivityFlushAll(uint32_t flag) = 0;
+
+  virtual CUptiResult ActivityGetNextRecord(uint8_t* buffer,
+                                            size_t valid_buffer_size_bytes,
+                                            CUpti_Activity** record) = 0;
+
+  virtual CUptiResult ActivityGetNumDroppedRecords(CUcontext context,
+                                                   uint32_t stream_id,
+                                                   size_t* dropped) = 0;
+
+  virtual CUptiResult ActivityConfigureUnifiedMemoryCounter(
+      CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) = 0;
+
+  virtual CUptiResult ActivityRegisterCallbacks(
+      CUpti_BuffersCallbackRequestFunc func_buffer_requested,
+      CUpti_BuffersCallbackCompleteFunc func_buffer_completed) = 0;
+
+  virtual CUptiResult GetDeviceId(CUcontext context, uint32* deviceId) = 0;
+
+  virtual CUptiResult GetTimestamp(uint64_t* timestamp) = 0;
+
+  virtual CUptiResult Finalize() = 0;
+
+  // CUPTI callback API
+  virtual CUptiResult EnableCallback(uint32_t enable,
+                                     CUpti_SubscriberHandle subscriber,
+                                     CUpti_CallbackDomain domain,
+                                     CUpti_CallbackId cbid) = 0;
+
+  virtual CUptiResult EnableDomain(uint32_t enable,
+                                   CUpti_SubscriberHandle subscriber,
+                                   CUpti_CallbackDomain domain) = 0;
+
+  virtual CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber,
+                                CUpti_CallbackFunc callback,
+                                void* userdata) = 0;
+
+  virtual CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) = 0;
+
+  // CUPTI event API
+  virtual CUptiResult DeviceEnumEventDomains(
+      CUdevice device, size_t* array_size_bytes,
+      CUpti_EventDomainID* domain_array) = 0;
+
+  virtual CUptiResult DeviceGetEventDomainAttribute(
+      CUdevice device, CUpti_EventDomainID event_domain,
+      CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) = 0;
+
+  virtual CUptiResult DisableKernelReplayMode(CUcontext context) = 0;
+
+  virtual CUptiResult EnableKernelReplayMode(CUcontext context) = 0;
+
+  virtual CUptiResult DeviceGetNumEventDomains(CUdevice device,
+                                               uint32_t* num_domains) = 0;
+
+  virtual CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain,
+                                            size_t* array_size_bytes,
+                                            CUpti_EventID* event_array) = 0;
+
+  virtual CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain,
+                                              uint32_t* num_events) = 0;
+
+  virtual CUptiResult EventGetAttribute(CUpti_EventID event,
+                                        CUpti_EventAttribute attrib,
+                                        size_t* value_size, void* value) = 0;
+
+  virtual CUptiResult EventGetIdFromName(CUdevice device,
+                                         const char* event_name,
+                                         CUpti_EventID* event) = 0;
+
+  virtual CUptiResult EventGroupDisable(CUpti_EventGroup event_group) = 0;
+
+  virtual CUptiResult EventGroupEnable(CUpti_EventGroup event_group) = 0;
+
+  virtual CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group,
+                                             CUpti_EventGroupAttribute attrib,
+                                             size_t* value_size,
+                                             void* value) = 0;
+
+  virtual CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group,
+                                          CUpti_ReadEventFlags flags,
+                                          CUpti_EventID event,
+                                          size_t* event_value_buffer_size_bytes,
+                                          uint64_t* eventValueBuffer) = 0;
+
+  virtual CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group,
+                                             CUpti_EventGroupAttribute attrib,
+                                             size_t value_size,
+                                             void* value) = 0;
+
+  virtual CUptiResult EventGroupSetsCreate(
+      CUcontext context, size_t event_id_array_size_bytes,
+      CUpti_EventID* event_id_array,
+      CUpti_EventGroupSets** event_group_passes) = 0;
+
+  virtual CUptiResult EventGroupSetsDestroy(
+      CUpti_EventGroupSets* event_group_sets) = 0;
+
+  // CUPTI metric API
+  virtual CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes,
+                                        CUpti_MetricID* metricArray) = 0;
+
+  virtual CUptiResult DeviceGetNumMetrics(CUdevice device,
+                                          uint32_t* num_metrics) = 0;
+
+  virtual CUptiResult MetricGetIdFromName(CUdevice device,
+                                          const char* metric_name,
+                                          CUpti_MetricID* metric) = 0;
+
+  virtual CUptiResult MetricGetNumEvents(CUpti_MetricID metric,
+                                         uint32_t* num_events) = 0;
+
+  virtual CUptiResult MetricEnumEvents(CUpti_MetricID metric,
+                                       size_t* event_id_array_size_bytes,
+                                       CUpti_EventID* event_id_array) = 0;
+
+  virtual CUptiResult MetricGetAttribute(CUpti_MetricID metric,
+                                         CUpti_MetricAttribute attrib,
+                                         size_t* value_size, void* value) = 0;
+
+  virtual CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric,
+                                     size_t event_id_array_size_bytes,
+                                     CUpti_EventID* event_id_array,
+                                     size_t event_value_array_size_bytes,
+                                     uint64_t* event_value_array,
+                                     uint64_t time_duration,
+                                     CUpti_MetricValue* metric_value) = 0;
+
+  virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
+
+  // Interface maintenance functions. Not directly related to CUPTI, but
+  // required for implementing an error resilient layer over CUPTI API.
+
+  // Performance any clean up work that is required each time profile session
+  // is done. Therefore this can be called multiple times during process life
+  // time.
+  virtual void CleanUp() = 0;
+
+  // Whether CUPTI API is currently disabled due to unrecoverable errors.
+  // All subsequent calls will fail immediately without forwarding calls to
+  // CUPTI library.
+  virtual bool Disabled() const = 0;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CuptiInterface);
+};
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_wrapper.cc b/tensorflow/core/profiler/internal/gpu/cupti_wrapper.cc
new file mode 100644
index 0000000..ef2aff3
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/cupti_wrapper.cc
@@ -0,0 +1,237 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/internal/gpu/cupti_wrapper.h"
+
+#include <type_traits>
+
+namespace tensorflow {
+namespace profiler {
+
+CUptiResult CuptiWrapper::ActivityDisable(CUpti_ActivityKind kind) {
+  return cuptiActivityDisable(kind);
+}
+
+CUptiResult CuptiWrapper::ActivityEnable(CUpti_ActivityKind kind) {
+  return cuptiActivityEnable(kind);
+}
+
+CUptiResult CuptiWrapper::ActivityFlushAll(uint32_t flag) {
+  return cuptiActivityFlushAll(flag);
+}
+
+CUptiResult CuptiWrapper::ActivityGetNextRecord(uint8_t* buffer,
+                                                size_t valid_buffer_size_bytes,
+                                                CUpti_Activity** record) {
+  return cuptiActivityGetNextRecord(buffer, valid_buffer_size_bytes, record);
+}
+
+CUptiResult CuptiWrapper::ActivityGetNumDroppedRecords(CUcontext context,
+                                                       uint32_t stream_id,
+                                                       size_t* dropped) {
+  return cuptiActivityGetNumDroppedRecords(context, stream_id, dropped);
+}
+
+CUptiResult CuptiWrapper::ActivityConfigureUnifiedMemoryCounter(
+    CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) {
+  return cuptiActivityConfigureUnifiedMemoryCounter(config, count);
+}
+
+CUptiResult CuptiWrapper::ActivityRegisterCallbacks(
+    CUpti_BuffersCallbackRequestFunc func_buffer_requested,
+    CUpti_BuffersCallbackCompleteFunc func_buffer_completed) {
+  return cuptiActivityRegisterCallbacks(func_buffer_requested,
+                                        func_buffer_completed);
+}
+
+CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, uint32* deviceId) {
+  return cuptiGetDeviceId(context, deviceId);
+}
+
+CUptiResult CuptiWrapper::GetTimestamp(uint64_t* timestamp) {
+  return cuptiGetTimestamp(timestamp);
+}
+
+CUptiResult CuptiWrapper::Finalize() { return cuptiFinalize(); }
+
+CUptiResult CuptiWrapper::EnableCallback(uint32_t enable,
+                                         CUpti_SubscriberHandle subscriber,
+                                         CUpti_CallbackDomain domain,
+                                         CUpti_CallbackId cbid) {
+  return cuptiEnableCallback(enable, subscriber, domain, cbid);
+}
+
+CUptiResult CuptiWrapper::EnableDomain(uint32_t enable,
+                                       CUpti_SubscriberHandle subscriber,
+                                       CUpti_CallbackDomain domain) {
+  return cuptiEnableDomain(enable, subscriber, domain);
+}
+
+CUptiResult CuptiWrapper::Subscribe(CUpti_SubscriberHandle* subscriber,
+                                    CUpti_CallbackFunc callback,
+                                    void* userdata) {
+  return cuptiSubscribe(subscriber, callback, userdata);
+}
+
+CUptiResult CuptiWrapper::Unsubscribe(CUpti_SubscriberHandle subscriber) {
+  return cuptiUnsubscribe(subscriber);
+}
+
+CUptiResult CuptiWrapper::DeviceEnumEventDomains(
+    CUdevice device, size_t* array_size_bytes,
+    CUpti_EventDomainID* domain_array) {
+  return cuptiDeviceEnumEventDomains(device, array_size_bytes, domain_array);
+}
+
+CUptiResult CuptiWrapper::DeviceGetEventDomainAttribute(
+    CUdevice device, CUpti_EventDomainID event_domain,
+    CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) {
+  return cuptiDeviceGetEventDomainAttribute(device, event_domain, attrib,
+                                            value_size, value);
+}
+
+CUptiResult CuptiWrapper::DisableKernelReplayMode(CUcontext context) {
+  return cuptiDisableKernelReplayMode(context);
+}
+
+CUptiResult CuptiWrapper::EnableKernelReplayMode(CUcontext context) {
+  return cuptiEnableKernelReplayMode(context);
+}
+
+CUptiResult CuptiWrapper::DeviceGetNumEventDomains(CUdevice device,
+                                                   uint32_t* num_domains) {
+  return cuptiDeviceGetNumEventDomains(device, num_domains);
+}
+
+CUptiResult CuptiWrapper::EventDomainEnumEvents(
+    CUpti_EventDomainID event_domain, size_t* array_size_bytes,
+    CUpti_EventID* event_array) {
+  return cuptiEventDomainEnumEvents(event_domain, array_size_bytes,
+                                    event_array);
+}
+
+CUptiResult CuptiWrapper::EventDomainGetNumEvents(
+    CUpti_EventDomainID event_domain, uint32_t* num_events) {
+  return cuptiEventDomainGetNumEvents(event_domain, num_events);
+}
+
+CUptiResult CuptiWrapper::EventGetAttribute(CUpti_EventID event,
+                                            CUpti_EventAttribute attrib,
+                                            size_t* value_size, void* value) {
+  return cuptiEventGetAttribute(event, attrib, value_size, value);
+}
+
+CUptiResult CuptiWrapper::EventGetIdFromName(CUdevice device,
+                                             const char* event_name,
+                                             CUpti_EventID* event) {
+  return cuptiEventGetIdFromName(device, event_name, event);
+}
+
+CUptiResult CuptiWrapper::EventGroupDisable(CUpti_EventGroup event_group) {
+  return cuptiEventGroupDisable(event_group);
+}
+
+CUptiResult CuptiWrapper::EventGroupEnable(CUpti_EventGroup event_group) {
+  return cuptiEventGroupEnable(event_group);
+}
+
+CUptiResult CuptiWrapper::EventGroupGetAttribute(
+    CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib,
+    size_t* value_size, void* value) {
+  return cuptiEventGroupGetAttribute(event_group, attrib, value_size, value);
+}
+
+CUptiResult CuptiWrapper::EventGroupReadEvent(
+    CUpti_EventGroup event_group, CUpti_ReadEventFlags flags,
+    CUpti_EventID event, size_t* event_value_buffer_size_bytes,
+    uint64_t* event_value_buffer) {
+  return cuptiEventGroupReadEvent(event_group, flags, event,
+                                  event_value_buffer_size_bytes,
+                                  event_value_buffer);
+}
+
+CUptiResult CuptiWrapper::EventGroupSetAttribute(
+    CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib,
+    size_t value_size, void* value) {
+  return cuptiEventGroupSetAttribute(event_group, attrib, value_size, value);
+}
+
+CUptiResult CuptiWrapper::EventGroupSetsCreate(
+    CUcontext context, size_t event_id_array_size_bytes,
+    CUpti_EventID* event_id_array, CUpti_EventGroupSets** event_group_passes) {
+  return cuptiEventGroupSetsCreate(context, event_id_array_size_bytes,
+                                   event_id_array, event_group_passes);
+}
+
+CUptiResult CuptiWrapper::EventGroupSetsDestroy(
+    CUpti_EventGroupSets* event_group_sets) {
+  return cuptiEventGroupSetsDestroy(event_group_sets);
+}
+
+// CUPTI metric API
+CUptiResult CuptiWrapper::DeviceEnumMetrics(CUdevice device,
+                                            size_t* arraySizeBytes,
+                                            CUpti_MetricID* metricArray) {
+  return cuptiDeviceEnumMetrics(device, arraySizeBytes, metricArray);
+}
+
+CUptiResult CuptiWrapper::DeviceGetNumMetrics(CUdevice device,
+                                              uint32_t* num_metrics) {
+  return cuptiDeviceGetNumMetrics(device, num_metrics);
+}
+
+CUptiResult CuptiWrapper::MetricGetIdFromName(CUdevice device,
+                                              const char* metric_name,
+                                              CUpti_MetricID* metric) {
+  return cuptiMetricGetIdFromName(device, metric_name, metric);
+}
+
+CUptiResult CuptiWrapper::MetricGetNumEvents(CUpti_MetricID metric,
+                                             uint32_t* num_events) {
+  return cuptiMetricGetNumEvents(metric, num_events);
+}
+
+CUptiResult CuptiWrapper::MetricEnumEvents(CUpti_MetricID metric,
+                                           size_t* event_id_array_size_bytes,
+                                           CUpti_EventID* event_id_array) {
+  return cuptiMetricEnumEvents(metric, event_id_array_size_bytes,
+                               event_id_array);
+}
+
+CUptiResult CuptiWrapper::MetricGetAttribute(CUpti_MetricID metric,
+                                             CUpti_MetricAttribute attrib,
+                                             size_t* value_size, void* value) {
+  return cuptiMetricGetAttribute(metric, attrib, value_size, value);
+}
+
+CUptiResult CuptiWrapper::MetricGetValue(CUdevice device, CUpti_MetricID metric,
+                                         size_t event_id_array_size_bytes,
+                                         CUpti_EventID* event_id_array,
+                                         size_t event_value_array_size_bytes,
+                                         uint64_t* event_value_array,
+                                         uint64_t time_duration,
+                                         CUpti_MetricValue* metric_value) {
+  return cuptiMetricGetValue(device, metric, event_id_array_size_bytes,
+                             event_id_array, event_value_array_size_bytes,
+                             event_value_array, time_duration, metric_value);
+}
+
+CUptiResult CuptiWrapper::GetResultString(CUptiResult result,
+                                          const char** str) {
+  return cuptiGetResultString(result, str);
+}
+
+}  // namespace profiler
+}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_wrapper.h b/tensorflow/core/profiler/internal/gpu/cupti_wrapper.h
new file mode 100644
index 0000000..e7a586d
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/cupti_wrapper.h
@@ -0,0 +1,179 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_WRAPPER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_WRAPPER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "tensorflow/core/profiler/internal/gpu/cupti_interface.h"
+
+namespace tensorflow {
+namespace profiler {
+
+class CuptiWrapper : public tensorflow::profiler::CuptiInterface {
+ public:
+  CuptiWrapper() {}
+
+  ~CuptiWrapper() override {}
+
+  // CUPTI activity API
+  CUptiResult ActivityDisable(CUpti_ActivityKind kind) override;
+
+  CUptiResult ActivityEnable(CUpti_ActivityKind kind) override;
+
+  CUptiResult ActivityFlushAll(uint32_t flag) override;
+
+  CUptiResult ActivityGetNextRecord(uint8_t* buffer,
+                                    size_t valid_buffer_size_bytes,
+                                    CUpti_Activity** record) override;
+
+  CUptiResult ActivityGetNumDroppedRecords(CUcontext context,
+                                           uint32_t stream_id,
+                                           size_t* dropped) override;
+
+  CUptiResult ActivityConfigureUnifiedMemoryCounter(
+      CUpti_ActivityUnifiedMemoryCounterConfig* config,
+      uint32_t count) override;
+
+  CUptiResult ActivityRegisterCallbacks(
+      CUpti_BuffersCallbackRequestFunc func_buffer_requested,
+      CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override;
+
+  CUptiResult GetDeviceId(CUcontext context, uint32* deviceId) override;
+
+  CUptiResult GetTimestamp(uint64_t* timestamp) override;
+
+  // cuptiFinalize is only defined in CUDA8 and above.
+  // To enable it in CUDA8, the environment variable CUPTI_ENABLE_FINALIZE must
+  // be set to 1.
+  CUptiResult Finalize() override;
+
+  // CUPTI callback API
+  CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber,
+                             CUpti_CallbackDomain domain,
+                             CUpti_CallbackId cbid) override;
+
+  CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber,
+                           CUpti_CallbackDomain domain) override;
+
+  CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber,
+                        CUpti_CallbackFunc callback, void* userdata) override;
+
+  CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override;
+
+  // CUPTI event API
+  CUptiResult DeviceEnumEventDomains(
+      CUdevice device, size_t* array_size_bytes,
+      CUpti_EventDomainID* domain_array) override;
+
+  CUptiResult DeviceGetEventDomainAttribute(CUdevice device,
+                                            CUpti_EventDomainID event_domain,
+                                            CUpti_EventDomainAttribute attrib,
+                                            size_t* value_size,
+                                            void* value) override;
+
+  CUptiResult DisableKernelReplayMode(CUcontext context) override;
+
+  CUptiResult EnableKernelReplayMode(CUcontext context) override;
+
+  CUptiResult DeviceGetNumEventDomains(CUdevice device,
+                                       uint32_t* num_domains) override;
+
+  CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain,
+                                    size_t* array_size_bytes,
+                                    CUpti_EventID* event_array) override;
+
+  CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain,
+                                      uint32_t* num_events) override;
+
+  CUptiResult EventGetAttribute(CUpti_EventID event,
+                                CUpti_EventAttribute attrib, size_t* value_size,
+                                void* value) override;
+
+  CUptiResult EventGetIdFromName(CUdevice device, const char* event_name,
+                                 CUpti_EventID* event) override;
+
+  CUptiResult EventGroupDisable(CUpti_EventGroup event_group) override;
+
+  CUptiResult EventGroupEnable(CUpti_EventGroup event_group) override;
+
+  CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group,
+                                     CUpti_EventGroupAttribute attrib,
+                                     size_t* value_size, void* value) override;
+
+  CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group,
+                                  CUpti_ReadEventFlags flags,
+                                  CUpti_EventID event,
+                                  size_t* event_value_buffer_size_bytes,
+                                  uint64_t* event_value_buffer) override;
+
+  CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group,
+                                     CUpti_EventGroupAttribute attrib,
+                                     size_t value_size, void* value) override;
+
+  CUptiResult EventGroupSetsCreate(
+      CUcontext context, size_t event_id_array_size_bytes,
+      CUpti_EventID* event_id_array,
+      CUpti_EventGroupSets** event_group_passes) override;
+
+  CUptiResult EventGroupSetsDestroy(
+      CUpti_EventGroupSets* event_group_sets) override;
+
+  // CUPTI metric API
+  CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes,
+                                CUpti_MetricID* metricArray) override;
+
+  CUptiResult DeviceGetNumMetrics(CUdevice device,
+                                  uint32_t* num_metrics) override;
+
+  CUptiResult MetricGetIdFromName(CUdevice device, const char* metric_name,
+                                  CUpti_MetricID* metric) override;
+
+  CUptiResult MetricGetNumEvents(CUpti_MetricID metric,
+                                 uint32_t* num_events) override;
+
+  CUptiResult MetricEnumEvents(CUpti_MetricID metric,
+                               size_t* event_id_array_size_bytes,
+                               CUpti_EventID* event_id_array) override;
+
+  CUptiResult MetricGetAttribute(CUpti_MetricID metric,
+                                 CUpti_MetricAttribute attrib,
+                                 size_t* value_size, void* value) override;
+
+  CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric,
+                             size_t event_id_array_size_bytes,
+                             CUpti_EventID* event_id_array,
+                             size_t event_value_array_size_bytes,
+                             uint64_t* event_value_array,
+                             uint64_t time_duration,
+                             CUpti_MetricValue* metric_value) override;
+
+  CUptiResult GetResultString(CUptiResult result, const char** str) override;
+
+  void CleanUp() override {}
+  bool Disabled() const override { return false; }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CuptiWrapper);
+};
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // PERFTOOLS_ACCELERATORS_XPROF_XPROFILEZ_NVIDIA_GPU_CUPTI_WRAPPER_H_
diff --git a/tensorflow/core/profiler/internal/profiler_interface.cc b/tensorflow/core/profiler/internal/profiler_interface.cc
index 2f48102..f71e538 100644
--- a/tensorflow/core/profiler/internal/profiler_interface.cc
+++ b/tensorflow/core/profiler/internal/profiler_interface.cc
@@ -34,11 +34,10 @@
 }
 
 void CreateProfilers(
-    const ProfilerContext* context,
     std::vector<std::unique_ptr<profiler::ProfilerInterface>>* result) {
   absl::MutexLock lock(GetMutex());
   for (auto factory : *GetFactories()) {
-    if (auto profiler = factory(context)) {
+    if (auto profiler = factory()) {
       result->push_back(std::move(profiler));
     }
   }
diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h
index 4754f4f..09dbe51 100644
--- a/tensorflow/core/profiler/internal/profiler_interface.h
+++ b/tensorflow/core/profiler/internal/profiler_interface.h
@@ -15,15 +15,13 @@
 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
 #define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
 
+#include <memory>
+#include <vector>
+
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 
 namespace tensorflow {
-class EagerContext;
-struct ProfilerContext {
-  EagerContext* eager_context = nullptr;
-};
-
 namespace profiler {
 
 // Interface for tensorflow profiler plugins.
@@ -50,13 +48,11 @@
 
 }  // namespace profiler
 
-using ProfilerFactory =
-    std::unique_ptr<profiler::ProfilerInterface> (*)(const ProfilerContext*);
+using ProfilerFactory = std::unique_ptr<profiler::ProfilerInterface> (*)();
 
 void RegisterProfilerFactory(ProfilerFactory factory);
 
 void CreateProfilers(
-    const ProfilerContext* context,
     std::vector<std::unique_ptr<profiler::ProfilerInterface>>* result);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/python_traceme.h b/tensorflow/core/profiler/internal/python_traceme.h
new file mode 100644
index 0000000..ceb1154
--- /dev/null
+++ b/tensorflow/core/profiler/internal/python_traceme.h
@@ -0,0 +1,44 @@
+/* Copyright 2019 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PYTHON_TRACEME_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_PYTHON_TRACEME_H_
+
+#include <string>
+#include <utility>
+
+#include "absl/types/optional.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// DO NOT USE THIS CLASS DIRECTLY IN C++ CODE.
+// This class is only used to implement TraceMe as a python context manager.
+class PythonTraceMe {
+ public:
+  explicit PythonTraceMe(const std::string& name) : activity_name_(name) {}
+  void Enter() { current_.emplace(std::move(activity_name_)); }
+  void Exit() { current_.reset(); }
+
+ private:
+  std::string activity_name_;
+  absl::optional<TraceMe> current_;
+};
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_PYTHON_TRACEME_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc
index 88bcf9e..b39d9a9 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.cc
+++ b/tensorflow/core/profiler/internal/tfprof_code.cc
@@ -295,13 +295,23 @@
     io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
         file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
     s = zlib_output_buffer->Init();
-    if (!s.ok()) return s;
+    if (!s.ok()) {
+      delete zlib_output_buffer;
+      return s;
+    }
     s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
-    if (!s.ok()) return s;
+    if (!s.ok()) {
+      delete zlib_output_buffer;
+      return s;
+    }
     s = zlib_output_buffer->Close();
-    if (!s.ok()) return s;
+    if (!s.ok()) {
+      delete zlib_output_buffer;
+      return s;
+    }
     fprintf(stdout, "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
             filename.c_str());
+    delete zlib_output_buffer;
     return s;
   }
 
diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc
index b586708..237814f 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats.cc
@@ -198,8 +198,9 @@
       continue;
     }
     node_added = true;
+    size_t num_nodes = nodes_map_.size();
     nodes_map_[node.name()] = std::unique_ptr<TFGraphNode>(
-        new TFGraphNode(&node, nodes_map_.size(), &nodes_map_));
+        new TFGraphNode(&node, num_nodes, &nodes_map_));
     node_defs[node.name()] = &node;
   }
   for (auto it = node_defs.begin(); it != node_defs.end(); it++) {
@@ -292,8 +293,9 @@
       if (node == nodes_map_.end()) {
         NodeDef def;
         if (CreateRunMetadataNode(name, &def)) {
+          size_t num_nodes = nodes_map_.size();
           nodes_map_[name] = std::unique_ptr<TFGraphNode>(
-              new TFGraphNode(&def, nodes_map_.size(), &nodes_map_));
+              new TFGraphNode(&def, num_nodes, &nodes_map_));
           nodes_map_.at(name)->AddStepStat(step, dev_stat.device(), node_stat);
         }
       } else {
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h
index 3740297..921df85 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.h
+++ b/tensorflow/core/profiler/internal/traceme_recorder.h
@@ -18,10 +18,10 @@
 #include <atomic>
 #include <cstddef>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "absl/base/optimization.h"
-#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -109,7 +109,7 @@
   mutex mutex_;
   // Map of the static container instances (thread_local storage) for each
   // thread. While active, a ThreadLocalRecorder stores trace events.
-  absl::flat_hash_map<int32, ThreadLocalRecorder*> threads_ GUARDED_BY(mutex_);
+  std::unordered_map<int32, ThreadLocalRecorder*> threads_ GUARDED_BY(mutex_);
   // Events from threads that died during recording.
   TraceMeRecorder::Events orphaned_events_ GUARDED_BY(mutex_);
 };
diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD
index d4dd151..81d3307 100644
--- a/tensorflow/core/profiler/lib/BUILD
+++ b/tensorflow/core/profiler/lib/BUILD
@@ -3,7 +3,7 @@
     "tf_cuda_library",
 )
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_profiler_lib_deps",
 )
 
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index 390ab14..fb84d5b 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -130,9 +130,8 @@
 }
 }  // namespace
 
-/*static*/ std::unique_ptr<ProfilerSession> ProfilerSession::Create(
-    ProfilerContext* const context) {
-  return absl::WrapUnique(new ProfilerSession(context));
+/*static*/ std::unique_ptr<ProfilerSession> ProfilerSession::Create() {
+  return absl::WrapUnique(new ProfilerSession());
 }
 
 Status ProfilerSession::Status() {
@@ -173,7 +172,7 @@
   return Status::OK();
 }
 
-ProfilerSession::ProfilerSession(ProfilerContext* const context)
+ProfilerSession::ProfilerSession()
     : active_(!session_active.exchange(true)),
       start_time_micros_(Env::Default()->NowNanos() / EnvTime::kMicrosToNanos) {
   if (!active_) {
@@ -184,7 +183,7 @@
 
   LOG(INFO) << "Profiler session started.";
 
-  CreateProfilers(context, &profilers_);
+  CreateProfilers(&profilers_);
   status_ = Status::OK();
 
   for (auto& profiler : profilers_) {
diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h
index b1a1233..b5a96c5 100644
--- a/tensorflow/core/profiler/lib/profiler_session.h
+++ b/tensorflow/core/profiler/lib/profiler_session.h
@@ -32,8 +32,7 @@
 class ProfilerSession {
  public:
   // Creates and ProfilerSession and starts profiling.
-  static std::unique_ptr<ProfilerSession> Create(
-      ProfilerContext* const context);
+  static std::unique_ptr<ProfilerSession> Create();
 
   // Deletes an exsiting Profiler and enables starting a new one.
   ~ProfilerSession();
@@ -45,9 +44,9 @@
 
  private:
   // Constructs an instance of the class and starts profiling
-  explicit ProfilerSession(ProfilerContext* const context);
+  ProfilerSession();
 
-  // Profiler is neither copyable or movable.
+  // ProfilerSession is neither copyable or movable.
   ProfilerSession(const ProfilerSession&) = delete;
   ProfilerSession& operator=(const ProfilerSession&) = delete;
 
diff --git a/tensorflow/core/profiler/lib/traceme.cc b/tensorflow/core/profiler/lib/traceme.cc
index 90272b8..7d02cfa 100644
--- a/tensorflow/core/profiler/lib/traceme.cc
+++ b/tensorflow/core/profiler/lib/traceme.cc
@@ -32,14 +32,14 @@
     absl::string_view activity_name) {
   uint64 activity_id = NewActivityId();
   TraceMeRecorder::Record({activity_id, string(activity_name),
-                           /*start_time=*/Env::Default()->NowNanos(),
+                           /*start_time=*/EnvTime::Default()->NowNanos(),
                            /*end_time=*/0});
   return activity_id;
 }
 
 /* static */ void TraceMe::ActivityEndImpl(uint64 activity_id) {
   TraceMeRecorder::Record({activity_id, /*name=*/"", /*start_time=*/0,
-                           /*end_time=*/Env::Default()->NowNanos()});
+                           /*end_time=*/EnvTime::Default()->NowNanos()});
 }
 
 }  // namespace profiler
diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h
index 5a5ba52..b8e4acf 100644
--- a/tensorflow/core/profiler/lib/traceme.h
+++ b/tensorflow/core/profiler/lib/traceme.h
@@ -18,7 +18,7 @@
 #include <string>
 
 #include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/internal/traceme_recorder.h"
@@ -81,7 +81,7 @@
     DCHECK_GE(level, 1);
     if (TraceMeRecorder::Active(level)) {
       new (&no_init_.name) string(activity_name);
-      start_time_ = Env::Default()->NowNanos();
+      start_time_ = EnvTime::Default()->NowNanos();
     } else {
       start_time_ = kUntracedActivity;
     }
@@ -96,7 +96,7 @@
     DCHECK_GE(level, 1);
     if (TraceMeRecorder::Active(level)) {
       new (&no_init_.name) string(std::move(activity_name));
-      start_time_ = Env::Default()->NowNanos();
+      start_time_ = EnvTime::Default()->NowNanos();
     } else {
       start_time_ = kUntracedActivity;
     }
@@ -126,7 +126,7 @@
     DCHECK_GE(level, 1);
     if (TraceMeRecorder::Active(level)) {
       new (&no_init_.name) string(name_generator());
-      start_time_ = Env::Default()->NowNanos();
+      start_time_ = EnvTime::Default()->NowNanos();
     } else {
       start_time_ = kUntracedActivity;
     }
@@ -147,7 +147,7 @@
     if (start_time_ != kUntracedActivity) {
       if (TraceMeRecorder::Active()) {
         TraceMeRecorder::Record({kCompleteActivity, std::move(no_init_.name),
-                                 start_time_, Env::Default()->NowNanos()});
+                                 start_time_, EnvTime::Default()->NowNanos()});
       }
       no_init_.name.~string();
       start_time_ = kUntracedActivity;
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index 7d9e826..c343742 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -14,7 +14,7 @@
         "//tensorflow:grpc++",
         "//tensorflow/core:framework",
         "//tensorflow/core:grpc_services",
-        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core:lib",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler/lib:profiler_lib",
         "//tensorflow/core/profiler/lib:profiler_session",
@@ -32,7 +32,7 @@
         "//tensorflow:grpc++",
         "//tensorflow/core:framework",
         "//tensorflow/core:grpc_services",
-        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core:lib",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler/lib:profiler_lib",
         "//tensorflow/core/profiler/lib:profiler_session",
diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc
index 257e4e0..38fe9c1 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.cc
+++ b/tensorflow/core/profiler/rpc/profiler_server.cc
@@ -14,27 +14,26 @@
 ==============================================================================*/
 
 #include "tensorflow/core/profiler/rpc/profiler_server.h"
+
 #include <memory>
 #include <utility>
+
 #include "grpcpp/grpcpp.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/grpc_services.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
 #include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 
-std::unique_ptr<Thread> StartProfilerServer(
-    ProfilerContext* const profiler_context, int32 port) {
-  Env* env = profiler_context->eager_context != nullptr
-                 ? profiler_context->eager_context->TFEnv()
-                 : Env::Default();
-  // Starting the server in the child thread may be delay and user may already
-  // delete the profiler context at that point. So we need to make a copy.
-  ProfilerContext ctx = *profiler_context;
-  return WrapUnique(env->StartThread({}, "profiler server", [ctx, port]() {
+std::unique_ptr<Thread> StartProfilerServer(int32 port) {
+  Env* env = Env::Default();
+  return WrapUnique(env->StartThread({}, "profiler server", [port]() {
     string server_address = strings::StrCat("0.0.0.0:", port);
     std::unique_ptr<grpc::ProfilerService::Service> service =
-        CreateProfilerService(ctx);
+        CreateProfilerService();
     ::grpc::ServerBuilder builder;
     builder.AddListeningPort(server_address,
                              ::grpc::InsecureServerCredentials());
diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h
index 21898d4..fd51612 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.h
+++ b/tensorflow/core/profiler/rpc/profiler_server.h
@@ -15,11 +15,16 @@
 #ifndef TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_
 #define TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_
 
-#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include <memory>
+
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
+
 class Thread;
-std::unique_ptr<Thread> StartProfilerServer(
-    ProfilerContext* const profiler_context, int32 port);
+
+std::unique_ptr<Thread> StartProfilerServer(int32 port);
+
 }  // namespace tensorflow
+
 #endif  // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
index f25ee66..3b80519 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
@@ -14,8 +14,9 @@
 ==============================================================================*/
 
 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
+
 #include "grpcpp/support/status.h"
-#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/grpc_services.h"
 #include "tensorflow/core/profiler/lib/profiler_session.h"
 #include "tensorflow/core/util/ptr_util.h"
@@ -25,10 +26,6 @@
 
 class ProfilerServiceImpl : public grpc::ProfilerService::Service {
  public:
-  explicit ProfilerServiceImpl(const ProfilerContext& profiler_context)
-      : profiler_context_(profiler_context) {}
-  ~ProfilerServiceImpl() override {}
-
   ::grpc::Status Monitor(::grpc::ServerContext* ctx, const MonitorRequest* req,
                          MonitorResponse* response) override {
     return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "unimplemented.");
@@ -37,16 +34,13 @@
   ::grpc::Status Profile(::grpc::ServerContext* ctx, const ProfileRequest* req,
                          ProfileResponse* response) override {
     LOG(INFO) << "Received a profile request.";
-    std::unique_ptr<ProfilerSession> profiler =
-        ProfilerSession::Create(&profiler_context_);
+    std::unique_ptr<ProfilerSession> profiler = ProfilerSession::Create();
     if (!profiler->Status().ok()) {
       return ::grpc::Status(::grpc::StatusCode::INTERNAL,
                             profiler->Status().error_message());
     }
 
-    Env* env = profiler_context_.eager_context != nullptr
-                   ? profiler_context_.eager_context->TFEnv()
-                   : Env::Default();
+    Env* env = Env::Default();
     for (size_t i = 0; i < req->duration_ms(); ++i) {
       env->SleepForMicroseconds(1000);
       if (ctx->IsCancelled()) {
@@ -61,15 +55,11 @@
 
     return ::grpc::Status::OK;
   }
-
- private:
-  ProfilerContext profiler_context_;
 };
 }  // namespace
 
-std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService(
-    const ProfilerContext& profiler_context) {
-  return MakeUnique<ProfilerServiceImpl>(profiler_context);
+std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService() {
+  return MakeUnique<ProfilerServiceImpl>();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h
index 64ae01d..c003040 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.h
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h
@@ -18,14 +18,13 @@
 #include "grpcpp/grpcpp.h"
 #include "grpcpp/server_context.h"
 #include "grpcpp/support/status.h"
-#include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/platform/grpc_services.h"
 #include "tensorflow/core/profiler/lib/profiler_session.h"
 
 namespace tensorflow {
 
-std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService(
-    const ProfilerContext& profiler_context);
+std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
diff --git a/tensorflow/core/protobuf/autotuning.proto b/tensorflow/core/protobuf/autotuning.proto
index 86cbc4a..f43dbbe 100644
--- a/tensorflow/core/protobuf/autotuning.proto
+++ b/tensorflow/core/protobuf/autotuning.proto
@@ -80,5 +80,7 @@
   // stream_executor::DeviceDescription::pci_bus_id.
   string device_pci_bus_id = 5;
 
-  // Next ID: 6
+  string blas_version = 6;
+
+  // Next ID: 7
 }
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index e0283e0..4a5bdb2 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -335,6 +335,9 @@
   // while with it we'll be able to complete long steps (like complex
   // initializations) in the face of some network errors during RecvTensor.
   bool cache_rpc_response = 4;
+
+  // Disables TCP connection sharing when opening a new RPC channel.
+  bool disable_session_connection_sharing = 5;
 }
 
 // Metadata about the session.
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
index 8ca76c4..3cfab17 100644
--- a/tensorflow/core/protobuf/debug.proto
+++ b/tensorflow/core/protobuf/debug.proto
@@ -10,13 +10,15 @@
 // Option for watching a node in TensorFlow Debugger (tfdbg).
 message DebugTensorWatch {
   // Name of the node to watch.
+  // Use "*" for wildcard. But note: currently, regex is not supported in
+  // general.
   string node_name = 1;
 
   // Output slot to watch.
-  // The semantics of output_slot == -1 is that the node is only watched for
-  // completion, but not for any output tensors. See NodeCompletionCallback
-  // in debug_gateway.h.
-  // TODO(cais): Implement this semantics.
+  // The semantics of output_slot == -1 is that all outputs of the node
+  // will be watched (i.e., a wildcard).
+  // Other negative values of output_slot are invalid and will lead to
+  // errors currently.
   int32 output_slot = 2;
 
   // Name(s) of the debugging op(s).
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 0560394..a8993d6 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -7,6 +7,7 @@
 import "tensorflow/core/framework/function.proto";
 import "tensorflow/core/framework/tensor.proto";
 import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
 import "tensorflow/core/framework/versions.proto";
 import "tensorflow/core/protobuf/tensorflow_server.proto";
 
@@ -15,6 +16,14 @@
   int64 op_id = 1;
   // The index into the outputs of the operation that produced this tensor.
   int32 output_num = 2;
+  // Device of the operation that produced this tensor. Cannot be empty.
+  // For multi-device functions, it's the default device passed to placer.
+  string device = 3;
+  // Device where the tensor is located. Can be empty if the operation producing
+  // this tensor is a multi-device function.
+  string op_device = 4;
+  // Tensor type.
+  DataType dtype = 5;
 }
 
 // A proto representation of an eager operation.
diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto
index fa0192c..1eb2023 100644
--- a/tensorflow/core/protobuf/meta_graph.proto
+++ b/tensorflow/core/protobuf/meta_graph.proto
@@ -14,6 +14,7 @@
 import "tensorflow/core/framework/types.proto";
 import "tensorflow/core/protobuf/saved_object_graph.proto";
 import "tensorflow/core/protobuf/saver.proto";
+import "tensorflow/core/protobuf/struct.proto";
 
 // NOTE: This protocol buffer is evolving, and will go through revisions in the
 // coming months.
@@ -225,6 +226,15 @@
     string dense_shape_tensor_name = 3;
   }
 
+  // Generic encoding for composite tensors.
+  message CompositeTensor {
+    // The serialized TypeSpec for the composite tensor.
+    TypeSpecProto type_spec = 1;
+
+    // A TensorInfo for each flattened component tensor.
+    repeated TensorInfo components = 2;
+  }
+
   oneof encoding {
     // For dense `Tensor`s, the name of the tensor in the graph.
     string name = 1;
@@ -233,6 +243,8 @@
     // uses only the COO encoding.  This is supported and documented in the
     // SparseTensor Python class.
     CooSparse coo_sparse = 4;
+    // Generic encoding for CompositeTensors.
+    CompositeTensor composite_tensor = 5;
   }
   DataType dtype = 2;
   // The static shape should be recorded here, to the extent that it can
diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto
index 48a97c9..ecf4877 100644
--- a/tensorflow/core/protobuf/struct.proto
+++ b/tensorflow/core/protobuf/struct.proto
@@ -125,4 +125,10 @@
 
   // The value returned by TypeSpec._serialize().
   StructuredValue type_state = 2;
+
+  // This is currently redundant with the type_spec_class enum, and is only
+  // used for error reporting.  In particular, if you use an older binary to
+  // load a newer model, and the model uses a TypeSpecClass that the older
+  // binary doesn't support, then this lets us display a useful error message.
+  string type_spec_class_name = 3;
 }
diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD
index 33db44a..98aa1b8 100644
--- a/tensorflow/core/protobuf/tpu/BUILD
+++ b/tensorflow/core/protobuf/tpu/BUILD
@@ -1,5 +1,5 @@
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_additional_all_protos",
     "tf_proto_library",
     "tf_proto_library_py",
diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
index 7190001..f52f7bf 100644
--- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto
+++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -166,7 +166,7 @@
   float initial_benefit = 15;
 }
 
-// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://www.tensorflow.org/api_docs/python/tf/train/AdadeltaOptimizer
 // https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
 message AdadeltaParameters {
   float rho = 1;
@@ -175,7 +175,7 @@
   float initial_update = 4;
 }
 
-// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://www.tensorflow.org/api_docs/python/tf/train/ProximalAdagradOptimizer
 // https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
 message ProximalAdagradParameters {
   float l1 = 1;
@@ -183,6 +183,45 @@
   float initial_accumulator = 3;
 }
 
+// The online Yogi optimizer does not implement hyper-parameter update; use the
+// dynamic learning rate feature instead, setting the learning rate to:
+// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+// Here, t is the current timestep.
+//
+// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
+// plus some extensions based on FTRL.
+//
+// Note that the code by default implements the lazy version of online Yogi.
+message OnlineYogiParameters {
+  // The L1 regularization parameter (used analogously to the one in FTRL).
+  float l1 = 1;
+
+  // The L2 regularization parameter (used analogously to the one in FTRL).
+  float l2 = 2;
+
+  // \beta_2 from Algorithm 2 in the paper.
+  float beta2 = 3;
+
+  // Initial value of V variable in paper.
+  float initial_v = 4;
+
+  // Initial value of linear variable in FTRL.
+  float initial_linear = 5;
+
+  // x -> copysign(1, x) (i.e., return 1 for an input of +0 rather than 0).
+  message SignActivation {}
+
+  // x -> tanh(x * 10)
+  message TanhActivation {}
+
+  // Activation to use to replace sign function in v_t update in Algorithm 2 of
+  // paper.
+  oneof activation {
+    SignActivation sign = 6;
+    TanhActivation tanh = 7;
+  }
+}
+
 // Status of using gradient accumulation (doing two passes over the input
 // gradients: one to accumulate them into a temporary array and another to apply
 // them using the actual optimization algorithm). The extra message is to wrap
@@ -253,6 +292,7 @@
     MdlAdagradLightParameters mdl_adagrad_light = 11;
     AdadeltaParameters adadelta = 12;
     ProximalAdagradParameters proximal_adagrad = 14;
+    OnlineYogiParameters online_yogi = 20;
   }
 
   reserved 15;  // Old use_gradient_accumulation.
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 304eef4..183a247 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 106  // Updated: 2019/7/24
+#define TF_GRAPH_DEF_VERSION 122  // Updated: 2019/8/9
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/summary/summary_db_writer.cc b/tensorflow/core/summary/summary_db_writer.cc
index b203d43..1a9bd33 100644
--- a/tensorflow/core/summary/summary_db_writer.cc
+++ b/tensorflow/core/summary/summary_db_writer.cc
@@ -676,7 +676,7 @@
                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
     if (t.dtype() == DT_STRING) {
       if (t.dims() == 0) {
-        return Update(db, step, computed_time, t, t.scalar<string>()(), rowid);
+        return Update(db, step, computed_time, t, t.scalar<tstring>()(), rowid);
       } else {
         SqliteTransaction txn(*db);
         TF_RETURN_IF_ERROR(
@@ -735,7 +735,7 @@
     )sql";
     SqliteStatement inserter;
     TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
-    auto flat = t.flat<string>();
+    auto flat = t.flat<tstring>();
     for (int64 i = 0; i < flat.size(); ++i) {
       inserter.BindInt(1, tensor_rowid);
       inserter.BindInt(2, i);
@@ -751,7 +751,7 @@
     unflushed_bytes_ = 0;
     if (t.dtype() == DT_STRING) {
       if (t.dims() == 0) {
-        TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size()));
+        TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<tstring>()().size()));
       } else {
         TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
       }
@@ -1106,9 +1106,9 @@
     // See tensorboard/plugins/image/summary.py and data_compat.py
     Tensor t{DT_STRING, {3}};
     auto img = s->mutable_image();
-    t.flat<string>()(0) = strings::StrCat(img->width());
-    t.flat<string>()(1) = strings::StrCat(img->height());
-    t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
+    t.flat<tstring>()(0) = strings::StrCat(img->width());
+    t.flat<tstring>()(1) = strings::StrCat(img->height());
+    t.flat<tstring>()(2) = std::move(*img->mutable_encoded_image_string());
     int64 tag_id;
     PatchPluginName(s->mutable_metadata(), kImagePluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
@@ -1120,8 +1120,8 @@
     // See tensorboard/plugins/audio/summary.py and data_compat.py
     Tensor t{DT_STRING, {1, 2}};
     auto wav = s->mutable_audio();
-    t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
-    t.flat<string>()(1) = "";
+    t.flat<tstring>()(0) = std::move(*wav->mutable_encoded_audio_string());
+    t.flat<tstring>()(1) = "";
     int64 tag_id;
     PatchPluginName(s->mutable_metadata(), kAudioPluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
diff --git a/tensorflow/core/summary/summary_file_writer_test.cc b/tensorflow/core/summary/summary_file_writer_test.cc
index 41060d7..932ae80 100644
--- a/tensorflow/core/summary/summary_file_writer_test.cc
+++ b/tensorflow/core/summary/summary_file_writer_test.cc
@@ -109,7 +109,7 @@
       "string_tensor_test",
       [](SummaryWriterInterface* writer) {
         Tensor hello(DT_STRING, TensorShape({}));
-        hello.scalar<string>()() = "hello";
+        hello.scalar<tstring>()() = "hello";
         TF_RETURN_IF_ERROR(writer->WriteTensor(
             2, hello, "name", SummaryMetadata().SerializeAsString()));
         TF_RETURN_IF_ERROR(writer->Flush());
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
index 915348b..fa49c42 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc
@@ -47,6 +47,8 @@
       return "Adadelta";
     case OptimizationAlgorithm::kProximalAdagrad:
       return "ProximalAdagrad";
+    case OptimizationAlgorithm::kOnlineYogi:
+      return "OnlineYogi";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "*** Not set ***";
   }
@@ -77,6 +79,8 @@
       return "Adadelta";
     case OptimizationAlgorithm::kProximalAdagrad:
       return "proximal Adagrad";
+    case OptimizationAlgorithm::kOnlineYogi:
+      return "online Yogi";
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return "unknown (not specified)";
   }
@@ -121,6 +125,9 @@
     case OptimizationAlgorithm::kProximalAdagrad:
       *count = 1;
       return Status::OK();
+    case OptimizationAlgorithm::kOnlineYogi:
+      *count = 2;
+      return Status::OK();
     case OptimizationAlgorithm::PARAMETERS_NOT_SET:
       return errors::InvalidArgument("No optimization algorithm specified");
   }
@@ -242,6 +249,13 @@
           MakeStandardStateVariableSpecification("accumulators", 0.1));
       break;
     }
+    case OptimizationAlgorithm::kOnlineYogi: {
+      state_variables->push_back(
+          MakeStandardStateVariableSpecification("vs", 0.0));
+      state_variables->push_back(
+          MakeStandardStateVariableSpecification("linears", 0.0));
+      break;
+    }
     case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
       return errors::InvalidArgument("No optimization algorithm specified");
     }
@@ -277,6 +291,7 @@
       OptimizationAlgorithm::kMdlAdagradLight,
       OptimizationAlgorithm::kAdadelta,
       OptimizationAlgorithm::kProximalAdagrad,
+      OptimizationAlgorithm::kOnlineYogi,
   };
 }
 
@@ -508,7 +523,8 @@
       *internal = false;
       return Status::OK();
     }
-    case OptimizationAlgorithm::kBoundedAdagrad: {
+    case OptimizationAlgorithm::kBoundedAdagrad:
+    case OptimizationAlgorithm::kOnlineYogi: {
       *internal = true;
       return Status::OK();
     }
diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
index 320863d..bdd3c15 100644
--- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
+++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h
@@ -101,7 +101,7 @@
     OpRegistrationData *op_reg_data);
 
 // Returns whether an optimization algorithm is only supported internally.
-// Returns an error if the algorithm is not recongized at all.
+// Returns an error if the algorithm is not recognized at all.
 Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
                                        bool *internal);
 
diff --git a/tensorflow/core/util/batch_util.cc b/tensorflow/core/util/batch_util.cc
index e1c32cd..3d704c4 100644
--- a/tensorflow/core/util/batch_util.cc
+++ b/tensorflow/core/util/batch_util.cc
@@ -107,8 +107,8 @@
 template <>
 void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index,
                                   bool can_move) {
-  auto parent_as_matrix = parent->flat_outer_dims<string>();
-  auto element_flat = element->flat<string>();
+  auto parent_as_matrix = parent->flat_outer_dims<tstring>();
+  auto element_flat = element->flat<tstring>();
   if (can_move) {
     for (int64 i = 0; i < element->NumElements(); ++i) {
       element_flat(i) = std::move(parent_as_matrix(index, i));
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 2dc5c83..4e49d4e 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -852,8 +852,8 @@
         break;
       }
       case DT_STRING: {
-        std::copy_n(in.flat<string>().data(), num_elements,
-                    out.flat<string>().data() + offset);
+        std::copy_n(in.flat<tstring>().data(), num_elements,
+                    out.flat<tstring>().data() + offset);
         break;
       }
       default:
@@ -1194,7 +1194,7 @@
         }
         case DT_STRING: {
           std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(),
-                    values->flat<string>().data() + offset);
+                    values->flat<tstring>().data() + offset);
           break;
         }
         default:
@@ -1273,8 +1273,8 @@
   return Status::OK();
 }
 
-Status FastParseSingleExample(const Config& config, const string& serialized,
-                              Result* result) {
+Status FastParseSingleExample(const Config& config,
+                              absl::string_view serialized, Result* result) {
   DCHECK(result != nullptr);
   // Check config so we can safely CHECK(false) in switches on config.*.dtype
   for (auto& c : config.sparse) {
@@ -1578,7 +1578,7 @@
         case DT_STRING: {
           *out = Tensor(out_dtype, out_shape);
           CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
-                          out->flat<string>().data());
+                          out->flat<tstring>().data());
           break;
         }
         default:
@@ -2079,7 +2079,7 @@
     int64* out_int64 = nullptr;
     switch (dtype) {
       case DT_STRING:
-        out_bytes = context_result->dense_values[t].flat<string>().data();
+        out_bytes = context_result->dense_values[t].flat<tstring>().data();
         break;
       case DT_FLOAT:
         out_float = context_result->dense_values[t].flat<float>().data();
@@ -2113,7 +2113,7 @@
         size_t num = 0;
         switch (dtype) {
           case DT_STRING:
-            in_bytes = c.default_value.flat<string>().data();
+            in_bytes = c.default_value.flat<tstring>().data();
             num = c.default_value.NumElements();
             for (int p = 0; p < num; p++) {
               *out_bytes++ = *in_bytes++;
@@ -2190,7 +2190,7 @@
     int64* out_int64 = nullptr;
     switch (dtype) {
       case DT_STRING:
-        out_bytes = context_result->sparse_values[t].flat<string>().data();
+        out_bytes = context_result->sparse_values[t].flat<tstring>().data();
         break;
       case DT_FLOAT:
         out_float = context_result->sparse_values[t].flat<float>().data();
@@ -2281,7 +2281,7 @@
     int64* out_int64 = nullptr;
     switch (dtype) {
       case DT_STRING:
-        out_bytes = feature_list_result->dense_values[t].flat<string>().data();
+        out_bytes = feature_list_result->dense_values[t].flat<tstring>().data();
         break;
       case DT_FLOAT:
         out_float = feature_list_result->dense_values[t].flat<float>().data();
@@ -2392,7 +2392,8 @@
     int64* out_int64 = nullptr;
     switch (dtype) {
       case DT_STRING:
-        out_bytes = feature_list_result->sparse_values[t].flat<string>().data();
+        out_bytes =
+            feature_list_result->sparse_values[t].flat<tstring>().data();
         break;
       case DT_FLOAT:
         out_float = feature_list_result->sparse_values[t].flat<float>().data();
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index 055d9c2..c2734fa 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -107,7 +107,7 @@
 typedef FastParseExampleConfig FastParseSingleExampleConfig;
 
 Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
-                              const string& serialized, Result* result);
+                              absl::string_view serialized, Result* result);
 
 // Parses a batch of serialized SequenceExample protos and converts them into
 // result according to given config.
diff --git a/tensorflow/core/util/example_proto_helper_test.cc b/tensorflow/core/util/example_proto_helper_test.cc
index 1bf430b..141c240 100644
--- a/tensorflow/core/util/example_proto_helper_test.cc
+++ b/tensorflow/core/util/example_proto_helper_test.cc
@@ -57,7 +57,7 @@
     string_dense_config.dtype = DT_STRING;
     string_dense_config.shape = TensorShape({1});
     string_dense_config.default_value = Tensor(DT_STRING, TensorShape({1}));
-    string_dense_config.default_value.scalar<string>()() = "default";
+    string_dense_config.default_value.scalar<tstring>()() = "default";
     dense_vec_.push_back(string_dense_config);
 
     // Setup sparse feature configuration.
@@ -115,7 +115,7 @@
 
   const std::vector<Tensor>& string_tensor_vec = output_sparse_values_tmp[2];
   EXPECT_EQ(1, string_tensor_vec.size());
-  EXPECT_EQ("forty-two", string_tensor_vec[0].vec<string>()(0));
+  EXPECT_EQ("forty-two", string_tensor_vec[0].vec<tstring>()(0));
 }
 
 TEST_F(SingleExampleProtoToTensorsTest, SparseOnlyEmpty) {
@@ -143,7 +143,7 @@
 
   const std::vector<Tensor>& string_tensor_vec = output_sparse_values_tmp[2];
   EXPECT_EQ(1, string_tensor_vec.size());
-  EXPECT_EQ(0, string_tensor_vec[0].vec<string>().size());
+  EXPECT_EQ(0, string_tensor_vec[0].vec<tstring>().size());
 }
 
 TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyTrivial) {
@@ -182,8 +182,8 @@
   EXPECT_EQ(1, float_dense_output.matrix<float>().size());
   EXPECT_NEAR(4.2, float_dense_output.matrix<float>()(0, 0), 0.001);
 
-  EXPECT_EQ(1, str_dense_output.matrix<string>().size());
-  EXPECT_EQ("forty-two", str_dense_output.matrix<string>()(0, 0));
+  EXPECT_EQ(1, str_dense_output.matrix<tstring>().size());
+  EXPECT_EQ("forty-two", str_dense_output.matrix<tstring>()(0, 0));
 }
 
 TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyDefaults) {
@@ -211,8 +211,8 @@
   EXPECT_EQ(1, float_dense_output.matrix<float>().size());
   EXPECT_NEAR(0.0, float_dense_output.matrix<float>()(0, 0), 0.001);
 
-  EXPECT_EQ(1, str_dense_output.matrix<string>().size());
-  EXPECT_EQ("default", str_dense_output.matrix<string>()(0, 0));
+  EXPECT_EQ(1, str_dense_output.matrix<tstring>().size());
+  EXPECT_EQ("default", str_dense_output.matrix<tstring>()(0, 0));
 }
 
 }  // namespace
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 9cd69e3..ff218f2 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -204,6 +204,7 @@
 
 TensorFormat MklDnn3DDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
 TensorFormat MklDnnDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
+
 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
                                         const memory::dims& strides,
@@ -696,15 +697,24 @@
 }
 
 // Get the MKL shape from the second string tensor
+inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
+                        bool eager_mode) {
+  if (!eager_mode) {
+    mklshape->DeSerializeMklDnnShape(
+        ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+            .flat<uint8>()
+            .data(),
+        ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+                .flat<uint8>()
+                .size() *
+            sizeof(uint8));
+  } else {
+    mklshape->SetMklTensor(false);
+  }
+}
+
 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
-  mklshape->DeSerializeMklDnnShape(
-      ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
-          .flat<uint8>()
-          .data(),
-      ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
-              .flat<uint8>()
-              .size() *
-          sizeof(uint8));
+  GetMklShape(ctext, n, mklshape, false);
 }
 
 // Gets the actual input
@@ -733,14 +743,15 @@
 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
 /// If the input tensor is in MKL layout, then obtains TensorShape from
 /// MklShape.
-inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
+inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
+                              bool eager_mode = false) {
   // Sanity check.
   CHECK_NOTNULL(context);
   CHECK_LT(input_idx, context->num_inputs());
 
   MklDnnShape input_mkl_shape;
-  GetMklShape(context, input_idx, &input_mkl_shape);
-  if (input_mkl_shape.IsMklTensor()) {
+  GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
+  if (input_mkl_shape.IsMklTensor() && !eager_mode) {
     return input_mkl_shape.GetTfShape();
   } else {
     const Tensor& t = MklGetInput(context, input_idx);
@@ -768,19 +779,22 @@
 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                                       Tensor** output,
                                       const TensorShape& tf_shape,
-                                      const MklDnnShape& mkl_shape) {
-  Tensor* second_tensor = nullptr;
-  TensorShape second_shape;
-  second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
+                                      const MklDnnShape& mkl_shape,
+                                      bool eager_mode = false) {
   OP_REQUIRES_OK(
       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
                                     tf_shape, output));
-  OP_REQUIRES_OK(ctext, ctext->allocate_output(
-                            GetTensorMetaDataIndex(n, ctext->num_outputs()),
-                            second_shape, &second_tensor));
-  mkl_shape.SerializeMklDnnShape(
-      second_tensor->flat<uint8>().data(),
-      second_tensor->flat<uint8>().size() * sizeof(uint8));
+  if (!eager_mode) {
+    Tensor* second_tensor = nullptr;
+    TensorShape second_shape;
+    second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
+    OP_REQUIRES_OK(ctext, ctext->allocate_output(
+                              GetTensorMetaDataIndex(n, ctext->num_outputs()),
+                              second_shape, &second_tensor));
+    mkl_shape.SerializeMklDnnShape(
+        second_tensor->flat<uint8>().data(),
+        second_tensor->flat<uint8>().size() * sizeof(uint8));
+  }
 }
 
 // Allocates a temp tensor and returns the data buffer for temporary storage.
@@ -1194,6 +1208,27 @@
   return memory::desc(md);
 }
 
+inline void CreateAndExecuteReorder(const reorder::primitive_desc& reorder_desc,
+                                    const memory& src_mem,
+                                    const memory& dst_mem,
+                                    const engine& engine) {
+  std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+  net.push_back(mkldnn::reorder(reorder_desc));
+  std::vector<MemoryArgsMap> net_args;
+  net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
+  DCHECK_EQ(net.size(), net_args.size());
+  stream cpu_stream(engine);
+  for (size_t i = 0; i < net.size(); ++i) {
+    net.at(i).execute(cpu_stream, net_args.at(i));
+  }
+  cpu_stream.wait();
+#else
+  net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
+  stream(stream::kind::eager).submit(net).wait();
+#endif  // ENABLE_MKLDNN_V1
+}
+
 template <typename T>
 inline primitive FindOrCreateReorder(const memory* from, const memory* to);
 
diff --git a/tensorflow/core/util/reporter.cc b/tensorflow/core/util/reporter.cc
index 0268709..eb69e29 100644
--- a/tensorflow/core/util/reporter.cc
+++ b/tensorflow/core/util/reporter.cc
@@ -21,25 +21,57 @@
 
 namespace tensorflow {
 
-TestReporter::TestReporter(const string& fname, const string& test_name)
+TestReportFile::TestReportFile(const string& fname, const string& test_name)
     : closed_(true), fname_(fname), test_name_(test_name) {}
 
-Status TestReporter::Close() {
+Status TestReportFile::Append(const string& content) {
   if (closed_) return Status::OK();
+  return log_file_->Append(content);
+}
+
+Status TestReportFile::Close() {
+  if (closed_) return Status::OK();
+  closed_ = true;
+  return log_file_->Close();
+}
+
+Status TestReportFile::Initialize() {
+  if (fname_.empty()) {
+    return Status::OK();
+  }
+  string mangled_fname = strings::StrCat(
+      fname_, absl::StrJoin(str_util::Split(test_name_, '/'), "__"));
+  Env* env = Env::Default();
+  if (env->FileExists(mangled_fname).ok()) {
+    return errors::InvalidArgument(
+        "Cannot create TestReportFile, file exists: ", mangled_fname);
+  }
+  TF_RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_));
+  TF_RETURN_IF_ERROR(log_file_->Flush());
+
+  closed_ = false;
+  return Status::OK();
+}
+
+TestReporter::TestReporter(const string& fname, const string& test_name)
+    : report_file_(fname, test_name) {
+  benchmark_entry_.set_name(test_name);
+}
+
+Status TestReporter::Close() {
+  if (report_file_.IsClosed()) return Status::OK();
 
   BenchmarkEntries entries;
   *entries.add_entry() = benchmark_entry_;
-  TF_RETURN_IF_ERROR(log_file_->Append(entries.SerializeAsString()));
-
+  TF_RETURN_IF_ERROR(report_file_.Append(entries.SerializeAsString()));
   benchmark_entry_.Clear();
-  closed_ = true;
 
-  return log_file_->Close();
+  return report_file_.Close();
 }
 
 Status TestReporter::Benchmark(int64 iters, double cpu_time, double wall_time,
                                double throughput) {
-  if (closed_) return Status::OK();
+  if (report_file_.IsClosed()) return Status::OK();
   benchmark_entry_.set_iters(iters);
   benchmark_entry_.set_cpu_time(cpu_time / iters);
   benchmark_entry_.set_wall_time(wall_time / iters);
@@ -48,34 +80,17 @@
 }
 
 Status TestReporter::SetProperty(const string& name, const string& value) {
-  if (closed_) return Status::OK();
+  if (report_file_.IsClosed()) return Status::OK();
   (*benchmark_entry_.mutable_extras())[name].set_string_value(value);
   return Status::OK();
 }
 
 Status TestReporter::SetProperty(const string& name, double value) {
-  if (closed_) return Status::OK();
+  if (report_file_.IsClosed()) return Status::OK();
   (*benchmark_entry_.mutable_extras())[name].set_double_value(value);
   return Status::OK();
 }
 
-Status TestReporter::Initialize() {
-  if (fname_.empty()) {
-    return Status::OK();
-  }
-  string mangled_fname = strings::StrCat(
-      fname_, absl::StrJoin(str_util::Split(test_name_, '/'), "__"));
-  Env* env = Env::Default();
-  if (env->FileExists(mangled_fname).ok()) {
-    return errors::InvalidArgument("Cannot create TestReporter, file exists: ",
-                                   mangled_fname);
-  }
-  TF_RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_));
-  TF_RETURN_IF_ERROR(log_file_->Flush());
-
-  benchmark_entry_.set_name(test_name_);
-  closed_ = false;
-  return Status::OK();
-}
+Status TestReporter::Initialize() { return report_file_.Initialize(); }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/reporter.h b/tensorflow/core/util/reporter.h
index e551e2e..51d7502 100644
--- a/tensorflow/core/util/reporter.h
+++ b/tensorflow/core/util/reporter.h
@@ -29,6 +29,34 @@
 
 namespace tensorflow {
 
+// The TestReportFile provides a file abstraction for TF tests to use.
+class TestReportFile {
+ public:
+  // Create a TestReportFile with the test name 'test_name'.
+  TestReportFile(const string& fname, const string& test_name);
+
+  // Initialize the TestReportFile.  If the reporting env flag is set,
+  // try to create the reporting file.  Fails if the file already exists.
+  Status Initialize();
+
+  // Append the report file w/ 'content'.
+  Status Append(const string& content);
+
+  // Close the report file.
+  Status Close();
+
+  bool IsClosed() const { return closed_; }
+
+  ~TestReportFile() { Close().IgnoreError(); }  // Autoclose in destructor.
+
+ private:
+  bool closed_;
+  string fname_;
+  string test_name_;
+  std::unique_ptr<WritableFile> log_file_;
+  TF_DISALLOW_COPY_AND_ASSIGN(TestReportFile);
+};
+
 // The TestReporter writes test / benchmark output to binary Protobuf files when
 // the environment variable "TEST_REPORT_FILE_PREFIX" is defined.
 //
@@ -91,10 +119,7 @@
     const char* fname_ptr = getenv(kTestReporterEnv);
     return (fname_ptr != nullptr) ? fname_ptr : "";
   }
-  bool closed_;
-  string fname_;
-  string test_name_;
-  std::unique_ptr<WritableFile> log_file_;
+  TestReportFile report_file_;
   BenchmarkEntry benchmark_entry_;
   TF_DISALLOW_COPY_AND_ASSIGN(TestReporter);
 };
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 4e53c59..d33bd03 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -312,7 +312,11 @@
                                        str_util::Join(shape_, ","), "]");
       }
       if (!increasing) {
-        return errors::InvalidArgument(index, " is out of order");
+        return errors::InvalidArgument(
+            index,
+            " is out of order. Many sparse ops require sorted indices.\n"
+            "    Use `tf.sparse.reorder` to create a correctly ordered copy."
+            "\n\n");
       }
       if (!different) {
         return errors::InvalidArgument(index, " is repeated");
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
index 5ab0a3d..24d0a2b 100644
--- a/tensorflow/core/util/sparse/sparse_tensor_test.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -181,7 +181,7 @@
   Tensor vals(DT_STRING, TensorShape({N}));
 
   auto ix_t = ix.matrix<int64>();
-  auto vals_t = vals.vec<string>();
+  auto vals_t = vals.vec<tstring>();
   vals_t = vals_c;
   ix_t = ix_c;
 
@@ -191,8 +191,12 @@
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
   Status st_indices_valid = st.IndicesValid();
   EXPECT_FALSE(st_indices_valid.ok());
-  EXPECT_EQ("indices[2] = [2,0,0] is out of order",
-            st_indices_valid.error_message());
+  EXPECT_EQ(
+      "indices[2] = [2,0,0] is out of order. "
+      "Many sparse ops require sorted indices.\n"
+      "    Use `tf.sparse.reorder` to create a correctly ordered copy."
+      "\n\n",
+      st_indices_valid.error_message());
 
   // Regardless of how order is updated; so long as there are no
   // duplicates, the resulting indices are valid.
@@ -362,7 +366,7 @@
   Tensor vals(DT_STRING, TensorShape({N}));
 
   auto ix_t = GetSimpleIndexTensor(N, NDIM);
-  auto vals_t = vals.vec<string>();
+  auto vals_t = vals.vec<tstring>();
 
   ix.matrix<int64>() = ix_t;
 
@@ -402,7 +406,7 @@
   Tensor vals(DT_STRING, TensorShape({N}));
 
   auto ix_t = GetSimpleIndexTensor(N, NDIM);
-  auto vals_t = vals.vec<string>();
+  auto vals_t = vals.vec<tstring>();
 
   ix.matrix<int64>() = ix_t;
 
@@ -540,7 +544,7 @@
   auto ix_c = GetSimpleIndexTensor(N, NDIM);
 
   auto ix_t = ix.matrix<int64>();
-  auto vals_t = vals.vec<string>();
+  auto vals_t = vals.vec<tstring>();
 
   ix_t = ix_c;
 
@@ -561,7 +565,7 @@
   TF_EXPECT_OK(concatted.IndicesValid());
 
   auto conc_ix_t = concatted.indices().matrix<int64>();
-  auto conc_vals_t = concatted.values().vec<string>();
+  auto conc_vals_t = concatted.values().vec<tstring>();
 
   for (int n = 0; n < 4; ++n) {
     for (int i = 0; i < N; ++i) {
@@ -750,7 +754,7 @@
   TensorShape shape;
   std::vector<int64> order;
   auto ix_t = ix.matrix<int64>();
-  auto vals_t = vals.vec<string>();
+  auto vals_t = vals.vec<tstring>();
   for (int i = 0; i < N32; ++i) {
     int len = rnd.Rand32() % 1000;
     vals_t(i).resize(len);
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index 2117042..99f4c08 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -146,6 +146,15 @@
           ds.device().find("/stream:all") == std::string::npos) {
         continue;
       }
+      // NOTE(fishx): We will record ops execution time twice: one as CPU
+      // activity with device name "/host:CPU" and the other as TF runtime
+      // activity with device name started with "/job:*". It is safe to ignore
+      // CPU activties here.
+      // TODO(b/138729463): Read ops execution time from CPU activities instead
+      // of runtime acitivities.
+      if (ds.device().find("/host:CPU") != std::string::npos) {
+        continue;
+      }
 
       std::string name = ns.node_name();
       std::string op_type = "<>";
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 82af5c5..aea7021 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -123,6 +124,9 @@
       // Note: the VECT_W is not counted as an independent spatial dim here,
       // since it just a component of the width dimension.
       return num_dims - 3;  // Exclude N,C,VectDim.
+    default:
+      LOG(FATAL) << "Unknown format " << format;
+      return -1;  // Avoid compiler warning about missing return value
   }
 }
 
@@ -147,6 +151,9 @@
     case FORMAT_NCHW_VECT_C:
     case FORMAT_NHWC_VECT_W:
       return num_spatial_dims + 3;  // Include N,C,VectDim.
+    default:
+      LOG(FATAL) << "Unknown format " << format;
+      return -1;  // Avoid compiler warning about missing return value
   }
 }
 
@@ -441,7 +448,9 @@
                                           filter_tensor_format) == 3)
                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
-  CHECK(index >= 0 && index < dimension_attribute.size())
+  using size_type = typename gtl::ArraySlice<T>::size_type;
+  CHECK(index >= 0 &&
+        static_cast<size_type>(index) < dimension_attribute.size())
       << "Invalid index from the dimension: " << index << ", "
       << filter_tensor_format << ", " << dimension;
   return dimension_attribute[index];
diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc
index 7c1d325..a2b8ca7 100644
--- a/tensorflow/core/util/tensor_slice_set.cc
+++ b/tensorflow/core/util/tensor_slice_set.cc
@@ -30,8 +30,7 @@
 
 TensorSliceSet::~TensorSliceSet() {}
 
-Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
-                                const float* data) {
+Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) {
   TensorShape result_shape;
   TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
   string str = slice.DebugString();
@@ -53,69 +52,11 @@
     slices_hull_.UpdateToCover(slice);
   }
 
-  TensorSliceSet::SliceInfo info = {slice, tag, data,
-                                    result_shape.num_elements()};
+  TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()};
   slices_.insert(std::make_pair(str, info));
   return Status::OK();
 }
 
-// TODO(yangke): merge Query() with QueryMeta()
-bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
-  Status s;
-  string str = slice.DebugString();
-  // First we check if there is an exactly match (this is the dominant case).
-  const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
-  if (info) {
-    if (data) {
-      std::copy_n(info->data, info->num_floats, data);
-    }
-    return true;
-  } else {
-    // We didn't find any exact match but there is still a possibility that
-    // multiple existing slices can be patched together to output the slice.
-    // We figure this out by computing the intersection of each of the existing
-    // slices with the query slice, and check if the union of all these
-    // intersections cover the entire slice. We rely on the fact that the
-    // existing slices don't have any intersection among themselves.
-    TensorShape target_shape;
-    Status s;
-    s = slice.SliceTensorShape(shape_, &target_shape);
-    if (!s.ok()) {
-      LOG(WARNING) << s;
-      return false;
-    }
-    int64 total_size = target_shape.num_elements();
-
-    int64 overlap_size = 0;
-    TensorSlice intersection;
-    TensorShape inter_shape;
-    for (const auto& x : slices_) {
-      if (slice.Intersect(x.second.slice, &intersection)) {
-        s = intersection.SliceTensorShape(shape_, &inter_shape);
-        if (!s.ok()) {
-          LOG(WARNING) << s;
-          return false;
-        }
-        overlap_size += inter_shape.num_elements();
-      }
-    }
-    if (total_size == overlap_size) {
-      // We have it!
-      // Now we need to copy the data to "data"
-      if (data) {
-        for (const auto& x : slices_) {
-          CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
-                                               x.second.data, data);
-        }
-      }
-      return true;
-    } else {
-      // We don't have all the data for the asked tensor slice
-      return false;
-    }
-  }
-}
-
 bool TensorSliceSet::QueryMeta(
     const TensorSlice& slice,
     std::vector<std::pair<TensorSlice, string>>* results) const {
@@ -194,7 +135,7 @@
     }
   }
   // Register the tensor slices without the actual data.
-  return tss->Register(slice, tag, nullptr);
+  return tss->Register(slice, tag);
 }
 
 }  // namespace checkpoint
diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h
index 22baed0..7ab3586 100644
--- a/tensorflow/core/util/tensor_slice_set.h
+++ b/tensorflow/core/util/tensor_slice_set.h
@@ -16,11 +16,8 @@
 // A class to manage slices of a tensor. You can "register" set of slices for a
 // tensor and then "query" if we have data for a given slice.
 
-// TODO(yangke): consider moving it to a more private place so that we don't
-// need to expose the API.
-
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
 
 #include <string>  // for string
 #include <unordered_map>
@@ -49,18 +46,7 @@
   // associated with the slice (in one application it denotes the name of the
   // file that contains the slice); the "data" points to the data of the tensor
   // slice (it can be a nullptr).
-  // We don't take the ownership of "data" and the caller needs to make sure
-  // the data is always available during the life time of the tensor slice set
-  // if it is not nullptr.
-  Status Register(const TensorSlice& slice, const string& tag,
-                  const float* data);
-
-  // Query about a new slice: checks if we have data for "slice" and if we have
-  // the data and "data" is not nullptr, fill "data" with the slice data. The
-  // caller needs to make sure "data" point to a large enough buffer.
-  // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
-  // pointer.
-  bool Query(const TensorSlice& slice, float* data) const;
+  Status Register(const TensorSlice& slice, const string& tag);
 
   // Alternative way of querying about a new slice: instead of copying the
   // data, it returns a list of meta data about the stored slices that will
@@ -72,7 +58,6 @@
   struct SliceInfo {
     TensorSlice slice;
     const string tag;
-    const float* data;
     int64 num_floats;
   };
 
@@ -105,4 +90,4 @@
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+#endif  // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc
index 8e12f7c..919629e 100644
--- a/tensorflow/core/util/tensor_slice_set_test.cc
+++ b/tensorflow/core/util/tensor_slice_set_test.cc
@@ -36,107 +36,6 @@
 //
 // We assume this is a row-major matrix.
 //
-// We store the tensor in a couple of slices and verify that we can recover all
-// of them.
-TEST(TensorSliceSetTest, QueryTwoD) {
-  TensorShape shape({4, 5});
-
-  TensorSliceSet tss(shape, DT_FLOAT);
-  // We store a few slices.
-
-  // Slice #1 is the top two rows:
-  //   0   1   2   3   4
-  //   5   6   7   8   9
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-  TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
-  TF_CHECK_OK(tss.Register(slice_1, "", src_1));
-
-  // Slice #2 is the bottom left corner
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  //  10  11  12   .   .
-  //  15  16  17   .   .
-  const float src_2[] = {10, 11, 12, 15, 16, 17};
-  TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
-  TF_CHECK_OK(tss.Register(slice_2, "", src_2));
-
-  // Slice #3 is the bottom right corner
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  //   .   .   .  18  19
-  const float src_3[] = {18, 19};
-  TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
-  TF_CHECK_OK(tss.Register(slice_3, "", src_3));
-
-  // Notice that we leave a hole in the tensor
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  //   .   .   . (13) (14)
-  //   .   .   .   .   .
-
-  // Now we query some of the slices
-
-  // Slice #1 is an exact match
-  //   0   1   2   3   4
-  //   5   6   7   8   9
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  {
-    TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
-    float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-    float results[10];
-    EXPECT_TRUE(tss.Query(s, results));
-    for (int i = 0; i < 10; ++i) {
-      EXPECT_EQ(expected[i], results[i]);
-    }
-  }
-
-  // Slice #2 is a subset match
-  //   .   .   .   .   .
-  //   5   6   7   8   9
-  //   .   .   .   .   .
-  //   .   .   .   .   .
-  {
-    TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
-    float expected[] = {5, 6, 7, 8, 9};
-    float results[5];
-    EXPECT_TRUE(tss.Query(s, results));
-    for (int i = 0; i < 5; ++i) {
-      EXPECT_EQ(expected[i], results[i]);
-    }
-  }
-
-  // Slice #3 is a more complicated match: it needs the combination of a couple
-  // of slices
-  //   .   .   .   .   .
-  //   5   6   7   .   .
-  //  10  11  12   .   .
-  //   .   .   .   .   .
-  {
-    TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
-    float expected[] = {5, 6, 7, 10, 11, 12};
-    float results[6];
-    EXPECT_TRUE(tss.Query(s, results));
-    for (int i = 0; i < 6; ++i) {
-      EXPECT_EQ(expected[i], results[i]);
-    }
-  }
-
-  // Slice #4 includes the hole and so there is no match
-  //   .   .   .   .   .
-  //   .   .   7   8   9
-  //   .   .  12  13  14
-  //   .   .   .   .   .
-  {
-    TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
-    float results[6];
-    EXPECT_FALSE(tss.Query(s, results));
-  }
-}
-
 // Testing the meta version of the tensor slice set.
 TEST(TensorSliceSetTest, QueryMetaTwoD) {
   TensorShape shape({4, 5});
@@ -150,7 +49,7 @@
   //   .   .   .   .   .
   //   .   .   .   .   .
   TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
-  TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
+  TF_CHECK_OK(tss.Register(slice_1, "slice_1"));
 
   // Slice #2 is the bottom left corner
   //   .   .   .   .   .
@@ -158,7 +57,7 @@
   //  10  11  12   .   .
   //  15  16  17   .   .
   TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
-  TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
+  TF_CHECK_OK(tss.Register(slice_2, "slice_2"));
 
   // Slice #3 is the bottom right corner
   //   .   .   .   .   .
@@ -166,7 +65,7 @@
   //   .   .   .   .   .
   //   .   .   .  18  19
   TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
-  TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
+  TF_CHECK_OK(tss.Register(slice_3, "slice_3"));
 
   // Notice that we leave a hole in the tensor
   //   .   .   .   .   .
@@ -250,7 +149,7 @@
   TensorSliceSet slice_set(shape, DT_INT32);
   for (int i = 0; i < parts; ++i) {
     TensorSlice part({{i, 1}, {0, -1}});
-    TF_CHECK_OK(slice_set.Register(part, part.DebugString(), nullptr));
+    TF_CHECK_OK(slice_set.Register(part, part.DebugString()));
   }
 }
 
diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD
index 9feb6eb..47dc19c 100644
--- a/tensorflow/examples/adding_an_op/BUILD
+++ b/tensorflow/examples/adding_an_op/BUILD
@@ -2,7 +2,7 @@
 # Code examples referenced by adding_an_op
 
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
     "tf_exec_compatible_with",
 )
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index 4e4e168..bb646d2 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -45,7 +45,7 @@
 
 ## Prebuilt Components:
 
-The fastest path to trying the demo is to download the [prebuilt demo APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
+The fastest path to trying the demo is to download the [prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
 
 Also available are precompiled native libraries, and a jcenter package that you
 may simply drop into your own applications. See
@@ -109,7 +109,9 @@
 
 NOTE: Bazel does not currently support building for Android on Windows. Full
 support for gradle/cmake builds is coming soon, but in the meantime we suggest
-that Windows users download the [prebuilt demo APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead.
+that Windows users download the
+[prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
+instead.
 
 ##### Install Bazel and Android Prerequisites
 
diff --git a/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py
index 64e0c6a..49fd033 100644
--- a/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py
+++ b/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py
@@ -79,5 +79,13 @@
     np.testing.assert_allclose(y_lite, y_tf, rtol=0, atol=1e-5,
                                err_msg='Mismatch at test example %d' % i)
 
+  # Test that it loads correctly with v1 load APIs as well.
+  with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
+    tf.compat.v1.saved_model.load(
+        session,
+        [tf.compat.v1.saved_model.SERVING],
+        FLAGS.saved_model_dir)
+
+
 if __name__ == '__main__':
   app.run(main)
diff --git a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py
index 6b94fda..f61631a 100644
--- a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py
+++ b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py
@@ -41,6 +41,11 @@
     'epochs', 10,
     'Number of epochs to train.')
 flags.DEFINE_bool(
+    'use_keras_save_api', False,
+    'Uses tf.keras.models.save_model() on the feature extractor '
+    'instead of tf.saved_model.save() on a manually wrapped version. '
+    'With this, the exported model as no hparams.')
+flags.DEFINE_bool(
     'fast_test_mode', False,
     'Shortcut training for running in unit tests.')
 flags.DEFINE_bool(
@@ -180,11 +185,19 @@
   # Save the feature extractor to a framework-agnostic SavedModel for reuse.
   # Note that the feature_extractor object has not been compiled or fitted,
   # so it does not contain an optimizer and related state.
-  exportable = wrap_keras_model_for_export(feature_extractor,
-                                           (None,) + mnist_util.INPUT_SHAPE,
-                                           set_feature_extractor_hparams,
-                                           default_hparams)
-  tf.saved_model.save(exportable, FLAGS.export_dir)
+  if FLAGS.use_keras_save_api:
+    # Use Keras' built-in way of creating reusable SavedModels.
+    # This has no support for adjustable hparams at this time (July 2019).
+    # (We could also call tf.saved_model.save(feature_extractor, ...),
+    # point is we're passing a Keras model, not a plain Checkpoint.)
+    tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
+  else:
+    # Assemble a reusable SavedModel manually, with adjustable hparams.
+    exportable = wrap_keras_model_for_export(feature_extractor,
+                                             (None,) + mnist_util.INPUT_SHAPE,
+                                             set_feature_extractor_hparams,
+                                             default_hparams)
+    tf.saved_model.save(exportable, FLAGS.export_dir)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
index b516b8e..232a5b5 100644
--- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
+++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
@@ -74,16 +74,19 @@
       combinations=(
           combinations.combine(
               # Test all combinations with tf.saved_model.save().
-              use_keras_save_api=False,
+              # Test all combinations using tf.keras.models.save_model()
+              # for both the reusable and the final full model.
+              use_keras_save_api=True,
               named_strategy=list(ds_utils.named_strategies.values()),
               retrain_flag_value=["true", "false"],
               regularization_loss_multiplier=[None, 2],  # Test for b/134528831.
           ) + combinations.combine(
-              # Test few critcial combinations with tf.keras.models.save_model()
-              # which is merely a thin wrapper (as of June 2019).
-              use_keras_save_api=True,
+              # Test few critcial combinations with raw tf.saved_model.save(),
+              # including export of a reusable SavedModel that gets assembled
+              # manually, including support for adjustable hparams.
+              use_keras_save_api=False,
               named_strategy=None,
-              retrain_flag_value="true",
+              retrain_flag_value=["true", "false"],
               regularization_loss_multiplier=[None, 2],  # Test for b/134528831.
           )),
       test_combinations=[combinations.NamedGPUCombination()])
@@ -97,24 +100,19 @@
     fast_test_mode = True
     temp_dir = self.get_temp_dir()
     feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
-
-    # TODO(b/135043074): remove this if-else.
-    if named_strategy is None:
-      full_model_dir = os.path.join(temp_dir, "full_model")
-    else:
-      full_model_dir = None
+    full_model_dir = os.path.join(temp_dir, "full_model")
 
     self.assertCommandSucceeded(
         "export_mnist_cnn",
         fast_test_mode=fast_test_mode,
-        export_dir=feature_extrator_dir)
+        export_dir=feature_extrator_dir,
+        use_keras_save_api=use_keras_save_api)
 
     use_kwargs = dict(fast_test_mode=fast_test_mode,
                       input_saved_model_dir=feature_extrator_dir,
                       retrain=retrain_flag_value,
+                      output_saved_model_dir=full_model_dir,
                       use_keras_save_api=use_keras_save_api)
-    if full_model_dir is not None:
-      use_kwargs["output_saved_model_dir"] = full_model_dir
     if named_strategy:
       use_kwargs["strategy"] = str(named_strategy)
     if regularization_loss_multiplier is not None:
@@ -122,11 +120,10 @@
           "regularization_loss_multiplier"] = regularization_loss_multiplier
     self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
 
-    if full_model_dir is not None:
-      self.assertCommandSucceeded(
-          "deploy_mnist_cnn",
-          fast_test_mode=fast_test_mode,
-          saved_model_dir=full_model_dir)
+    self.assertCommandSucceeded(
+        "deploy_mnist_cnn",
+        fast_test_mode=fast_test_mode,
+        saved_model_dir=full_model_dir)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py
index 24d1be4..ae45a02 100644
--- a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py
+++ b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py
@@ -47,7 +47,8 @@
     'If set, the imported SavedModel is trained further.')
 flags.DEFINE_float(
     'dropout_rate', None,
-    'If set, dropout rate passed to the SavedModel.')
+    'If set, dropout rate passed to the SavedModel. '
+    'Requires a SavedModel with support for adjustable hyperparameters.')
 flags.DEFINE_float(
     'regularization_loss_multiplier', None,
     'If set, multiplier for the regularization losses in the SavedModel.')
diff --git a/tensorflow/examples/speech_commands/README.md b/tensorflow/examples/speech_commands/README.md
index 63be04e..8290781 100644
--- a/tensorflow/examples/speech_commands/README.md
+++ b/tensorflow/examples/speech_commands/README.md
@@ -1,4 +1,4 @@
 # Speech Commands Example
 
 This is a basic speech recognition example. For more information, see the
-tutorial at https://www.tensorflow.org/versions/master/tutorials/audio_recognition.
+tutorial at https://www.tensorflow.org/tutorials/sequences/audio_recognition.
diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py
index c61e564..57981ac 100644
--- a/tensorflow/examples/speech_commands/freeze.py
+++ b/tensorflow/examples/speech_commands/freeze.py
@@ -90,7 +90,8 @@
       window_stride_ms, feature_bin_count, preprocess)
   runtime_settings = {'clip_stride_ms': clip_stride_ms}
 
-  wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
+  wav_data_placeholder = tf.compat.v1.placeholder(tf.string, [],
+                                                  name='wav_data')
   decoded_sample_data = contrib_audio.decode_wav(
       wav_data_placeholder,
       desired_channels=1,
@@ -104,7 +105,7 @@
 
   if preprocess == 'average':
     fingerprint_input = tf.nn.pool(
-        tf.expand_dims(spectrogram, -1),
+        input=tf.expand_dims(spectrogram, -1),
         window_shape=[1, model_settings['average_window_width']],
         strides=[1, model_settings['average_window_width']],
         pooling_type='AVG',
@@ -155,7 +156,7 @@
 def main(_):
 
   # Create the model and load its weights.
-  sess = tf.InteractiveSession()
+  sess = tf.compat.v1.InteractiveSession()
   create_inference_graph(
       FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
       FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
@@ -167,12 +168,12 @@
   # Turn all the variables into inline constants inside the graph and save it.
   frozen_graph_def = graph_util.convert_variables_to_constants(
       sess, sess.graph_def, ['labels_softmax'])
-  tf.train.write_graph(
+  tf.io.write_graph(
       frozen_graph_def,
       os.path.dirname(FLAGS.output_file),
       os.path.basename(FLAGS.output_file),
       as_text=False)
-  tf.logging.info('Saved frozen graph to %s', FLAGS.output_file)
+  tf.compat.v1.logging.info('Saved frozen graph to %s', FLAGS.output_file)
 
 
 if __name__ == '__main__':
@@ -236,4 +237,4 @@
       default='mfcc',
       help='Spectrogram processing mode. Can be "mfcc" or "average"')
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
index 9858906..d3df7f4 100644
--- a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
+++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
@@ -174,7 +174,7 @@
       '--data_url',
       type=str,
       # pylint: disable=line-too-long
-      default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
+      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
       # pylint: enable=line-too-long
       help='Location of speech training data')
   parser.add_argument(
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 6c2ce3f..2bb48e0 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -123,7 +123,7 @@
     Numpy array holding the sample data as floats between -1.0 and 1.0.
   """
   with tf.compat.v1.Session(graph=tf.Graph()) as sess:
-    wav_filename_placeholder = tf.placeholder(tf.string, [])
+    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
     wav_loader = io_ops.read_file(wav_filename_placeholder)
     wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
     return sess.run(
@@ -140,9 +140,9 @@
     sample_rate: Samples per second to encode in the file.
   """
   with tf.compat.v1.Session(graph=tf.Graph()) as sess:
-    wav_filename_placeholder = tf.placeholder(tf.string, [])
-    sample_rate_placeholder = tf.placeholder(tf.int32, [])
-    wav_data_placeholder = tf.placeholder(tf.float32, [None, 1])
+    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
+    sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, [])
+    wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1])
     wav_encoder = contrib_audio.encode_wav(wav_data_placeholder,
                                            sample_rate_placeholder)
     wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
@@ -230,15 +230,16 @@
       try:
         filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
       except:
-        tf.logging.error('Failed to download URL: %s to folder: %s', data_url,
-                         filepath)
-        tf.logging.error('Please make sure you have enough free space and'
-                         ' an internet connection')
+        tf.compat.v1.logging.error(
+            'Failed to download URL: %s to folder: %s', data_url, filepath)
+        tf.compat.v1.logging.error(
+            'Please make sure you have enough free space and'
+            ' an internet connection')
         raise
       print()
       statinfo = os.stat(filepath)
-      tf.logging.info('Successfully downloaded %s (%d bytes)', filename,
-                      statinfo.st_size)
+      tf.compat.v1.logging.info('Successfully downloaded %s (%d bytes)',
+                                filename, statinfo.st_size)
     tarfile.open(filepath, 'r:gz').extractall(dest_directory)
 
   def prepare_data_index(self, silence_percentage, unknown_percentage,
@@ -350,7 +351,7 @@
     if not os.path.exists(background_dir):
       return self.background_data
     with tf.compat.v1.Session(graph=tf.Graph()) as sess:
-      wav_filename_placeholder = tf.placeholder(tf.string, [])
+      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
       wav_loader = io_ops.read_file(wav_filename_placeholder)
       wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
       search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
@@ -389,34 +390,34 @@
       ValueError: If the preprocessing mode isn't recognized.
       Exception: If the preprocessor wasn't compiled in.
     """
-    with tf.get_default_graph().name_scope('data'):
+    with tf.compat.v1.get_default_graph().name_scope('data'):
       desired_samples = model_settings['desired_samples']
-      self.wav_filename_placeholder_ = tf.placeholder(
+      self.wav_filename_placeholder_ = tf.compat.v1.placeholder(
           tf.string, [], name='wav_filename')
       wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
       wav_decoder = contrib_audio.decode_wav(
           wav_loader, desired_channels=1, desired_samples=desired_samples)
       # Allow the audio sample's volume to be adjusted.
-      self.foreground_volume_placeholder_ = tf.placeholder(
+      self.foreground_volume_placeholder_ = tf.compat.v1.placeholder(
           tf.float32, [], name='foreground_volume')
       scaled_foreground = tf.multiply(wav_decoder.audio,
                                       self.foreground_volume_placeholder_)
       # Shift the sample's start position, and pad any gaps with zeros.
-      self.time_shift_padding_placeholder_ = tf.placeholder(
+      self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder(
           tf.int32, [2, 2], name='time_shift_padding')
-      self.time_shift_offset_placeholder_ = tf.placeholder(
+      self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder(
           tf.int32, [2], name='time_shift_offset')
       padded_foreground = tf.pad(
-          scaled_foreground,
-          self.time_shift_padding_placeholder_,
+          tensor=scaled_foreground,
+          paddings=self.time_shift_padding_placeholder_,
           mode='CONSTANT')
       sliced_foreground = tf.slice(padded_foreground,
                                    self.time_shift_offset_placeholder_,
                                    [desired_samples, -1])
       # Mix in background noise.
-      self.background_data_placeholder_ = tf.placeholder(
+      self.background_data_placeholder_ = tf.compat.v1.placeholder(
           tf.float32, [desired_samples, 1], name='background_data')
-      self.background_volume_placeholder_ = tf.placeholder(
+      self.background_volume_placeholder_ = tf.compat.v1.placeholder(
           tf.float32, [], name='background_volume')
       background_mul = tf.multiply(self.background_data_placeholder_,
                                    self.background_volume_placeholder_)
@@ -428,7 +429,7 @@
           window_size=model_settings['window_size_samples'],
           stride=model_settings['window_stride_samples'],
           magnitude_squared=True)
-      tf.summary.image(
+      tf.compat.v1.summary.image(
           'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
       # The number of buckets in each FFT row in the spectrogram will depend on
       # how many input samples there are in each window. This can be quite
@@ -440,18 +441,20 @@
       # algorithm to shrink the representation.
       if model_settings['preprocess'] == 'average':
         self.output_ = tf.nn.pool(
-            tf.expand_dims(spectrogram, -1),
+            input=tf.expand_dims(spectrogram, -1),
             window_shape=[1, model_settings['average_window_width']],
             strides=[1, model_settings['average_window_width']],
             pooling_type='AVG',
             padding='SAME')
-        tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
+        tf.compat.v1.summary.image('shrunk_spectrogram',
+                                   self.output_,
+                                   max_outputs=1)
       elif model_settings['preprocess'] == 'mfcc':
         self.output_ = contrib_audio.mfcc(
             spectrogram,
             wav_decoder.sample_rate,
             dct_coefficient_count=model_settings['fingerprint_width'])
-        tf.summary.image(
+        tf.compat.v1.summary.image(
             'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
       elif model_settings['preprocess'] == 'micro':
         if not frontend_op:
@@ -474,7 +477,7 @@
             out_scale=1,
             out_type=tf.float32)
         self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
-        tf.summary.image(
+        tf.compat.v1.summary.image(
             'micro',
             tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
             max_outputs=1)
@@ -485,10 +488,10 @@
 
       # Merge all the summaries and write them out to /tmp/retrain_logs (by
       # default)
-      self.merged_summaries_ = tf.summary.merge_all(scope='data')
+      self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data')
       if summaries_dir:
-        self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
-                                                     tf.get_default_graph())
+        self.summary_writer_ = tf.compat.v1.summary.FileWriter(
+            summaries_dir + '/data', tf.compat.v1.get_default_graph())
 
   def set_size(self, mode):
     """Calculates the number of samples in the dataset partition.
@@ -655,11 +658,11 @@
     data = np.zeros((sample_count, desired_samples))
     labels = []
     with tf.compat.v1.Session(graph=tf.Graph()) as sess:
-      wav_filename_placeholder = tf.placeholder(tf.string, [])
+      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
       wav_loader = io_ops.read_file(wav_filename_placeholder)
       wav_decoder = contrib_audio.decode_wav(
           wav_loader, desired_channels=1, desired_samples=desired_samples)
-      foreground_volume_placeholder = tf.placeholder(tf.float32, [])
+      foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, [])
       scaled_foreground = tf.multiply(wav_decoder.audio,
                                       foreground_volume_placeholder)
       for i in range(sample_count):
diff --git a/tensorflow/examples/speech_commands/label_wav.py b/tensorflow/examples/speech_commands/label_wav.py
index 5af1669..eef0fc2 100644
--- a/tensorflow/examples/speech_commands/label_wav.py
+++ b/tensorflow/examples/speech_commands/label_wav.py
@@ -45,15 +45,15 @@
 
 def load_graph(filename):
   """Unpersists graph from file as default graph."""
-  with tf.gfile.GFile(filename, 'rb') as f:
-    graph_def = tf.GraphDef()
+  with tf.io.gfile.GFile(filename, 'rb') as f:
+    graph_def = tf.compat.v1.GraphDef()
     graph_def.ParseFromString(f.read())
     tf.import_graph_def(graph_def, name='')
 
 
 def load_labels(filename):
   """Read in labels, one label per line."""
-  return [line.rstrip() for line in tf.gfile.GFile(filename)]
+  return [line.rstrip() for line in tf.io.gfile.GFile(filename)]
 
 
 def run_graph(wav_data, labels, input_layer_name, output_layer_name,
@@ -79,14 +79,14 @@
 
 def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
   """Loads the model and labels, and runs the inference to print predictions."""
-  if not wav or not tf.gfile.Exists(wav):
-    tf.logging.fatal('Audio file does not exist %s', wav)
+  if not wav or not tf.io.gfile.exists(wav):
+    tf.compat.v1.logging.fatal('Audio file does not exist %s', wav)
 
-  if not labels or not tf.gfile.Exists(labels):
-    tf.logging.fatal('Labels file does not exist %s', labels)
+  if not labels or not tf.io.gfile.exists(labels):
+    tf.compat.v1.logging.fatal('Labels file does not exist %s', labels)
 
-  if not graph or not tf.gfile.Exists(graph):
-    tf.logging.fatal('Graph file does not exist %s', graph)
+  if not graph or not tf.io.gfile.exists(graph):
+    tf.compat.v1.logging.fatal('Graph file does not exist %s', graph)
 
   labels_list = load_labels(labels)
 
@@ -130,4 +130,4 @@
       help='Number of results to show.')
 
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/label_wav_test.py b/tensorflow/examples/speech_commands/label_wav_test.py
index 3c833d6..8dbbb71 100644
--- a/tensorflow/examples/speech_commands/label_wav_test.py
+++ b/tensorflow/examples/speech_commands/label_wav_test.py
@@ -49,7 +49,7 @@
     output_name = "test_output"
     graph_filename = os.path.join(tmp_dir, "test_graph.pb")
     with tf.compat.v1.Session() as sess:
-      tf.placeholder(tf.string, name=input_name)
+      tf.compat.v1.placeholder(tf.string, name=input_name)
       tf.zeros([1, 3], name=output_name)
       with open(graph_filename, "wb") as f:
         f.write(sess.graph.as_graph_def().SerializeToString())
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index 23bc55d..4fc144b 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -157,7 +157,7 @@
     sess: TensorFlow session.
     start_checkpoint: Path to saved checkpoint on disk.
   """
-  saver = tf.train.Saver(tf.global_variables())
+  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
   saver.restore(sess, start_checkpoint)
 
 
@@ -187,15 +187,16 @@
     placeholder.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
   fingerprint_size = model_settings['fingerprint_size']
   label_count = model_settings['label_count']
-  weights = tf.get_variable(
+  weights = tf.compat.v1.get_variable(
       name='weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.001),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.001),
       shape=[fingerprint_size, label_count])
-  bias = tf.get_variable(
-      name='bias', initializer=tf.zeros_initializer, shape=[label_count])
+  bias = tf.compat.v1.get_variable(name='bias',
+                                   initializer=tf.compat.v1.zeros_initializer,
+                                   shape=[label_count])
   logits = tf.matmul(fingerprint_input, weights) + bias
   if is_training:
     return logits, dropout_prob
@@ -252,7 +253,7 @@
     placeholder.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
   input_frequency_size = model_settings['fingerprint_width']
   input_time_size = model_settings['spectrogram_length']
   fingerprint_4d = tf.reshape(fingerprint_input,
@@ -260,41 +261,48 @@
   first_filter_width = 8
   first_filter_height = 20
   first_filter_count = 64
-  first_weights = tf.get_variable(
+  first_weights = tf.compat.v1.get_variable(
       name='first_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_filter_height, first_filter_width, 1, first_filter_count])
-  first_bias = tf.get_variable(
+  first_bias = tf.compat.v1.get_variable(
       name='first_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_filter_count])
-  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
-                            'SAME') + first_bias
+  first_conv = tf.nn.conv2d(input=fingerprint_4d,
+                            filters=first_weights,
+                            strides=[1, 1, 1, 1],
+                            padding='SAME') + first_bias
   first_relu = tf.nn.relu(first_conv)
   if is_training:
-    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+    first_dropout = tf.compat.v1.nn.dropout(first_relu, dropout_prob)
   else:
     first_dropout = first_relu
-  max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
+  max_pool = tf.nn.max_pool2d(input=first_dropout,
+                              ksize=[1, 2, 2, 1],
+                              strides=[1, 2, 2, 1],
+                              padding='SAME')
   second_filter_width = 4
   second_filter_height = 10
   second_filter_count = 64
-  second_weights = tf.get_variable(
+  second_weights = tf.compat.v1.get_variable(
       name='second_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[
           second_filter_height, second_filter_width, first_filter_count,
           second_filter_count
       ])
-  second_bias = tf.get_variable(
+  second_bias = tf.compat.v1.get_variable(
       name='second_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[second_filter_count])
-  second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
-                             'SAME') + second_bias
+  second_conv = tf.nn.conv2d(input=max_pool,
+                             filters=second_weights,
+                             strides=[1, 1, 1, 1],
+                             padding='SAME') + second_bias
   second_relu = tf.nn.relu(second_conv)
   if is_training:
-    second_dropout = tf.nn.dropout(second_relu, dropout_prob)
+    second_dropout = tf.compat.v1.nn.dropout(second_relu, dropout_prob)
   else:
     second_dropout = second_relu
   second_conv_shape = second_dropout.get_shape()
@@ -306,13 +314,13 @@
   flattened_second_conv = tf.reshape(second_dropout,
                                      [-1, second_conv_element_count])
   label_count = model_settings['label_count']
-  final_fc_weights = tf.get_variable(
+  final_fc_weights = tf.compat.v1.get_variable(
       name='final_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[second_conv_element_count, label_count])
-  final_fc_bias = tf.get_variable(
+  final_fc_bias = tf.compat.v1.get_variable(
       name='final_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[label_count])
   final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
   if is_training:
@@ -368,7 +376,7 @@
     placeholder.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
   input_frequency_size = model_settings['fingerprint_width']
   input_time_size = model_settings['spectrogram_length']
   fingerprint_4d = tf.reshape(fingerprint_input,
@@ -378,20 +386,21 @@
   first_filter_count = 186
   first_filter_stride_x = 1
   first_filter_stride_y = 1
-  first_weights = tf.get_variable(
+  first_weights = tf.compat.v1.get_variable(
       name='first_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_filter_height, first_filter_width, 1, first_filter_count])
-  first_bias = tf.get_variable(
+  first_bias = tf.compat.v1.get_variable(
       name='first_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_filter_count])
-  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
-      1, first_filter_stride_y, first_filter_stride_x, 1
-  ], 'VALID') + first_bias
+  first_conv = tf.nn.conv2d(
+      input=fingerprint_4d, filters=first_weights,
+      strides=[1, first_filter_stride_y, first_filter_stride_x, 1],
+      padding='VALID') + first_bias
   first_relu = tf.nn.relu(first_conv)
   if is_training:
-    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+    first_dropout = tf.compat.v1.nn.dropout(first_relu, dropout_prob)
   else:
     first_dropout = first_relu
   first_conv_output_width = math.floor(
@@ -405,41 +414,41 @@
   flattened_first_conv = tf.reshape(first_dropout,
                                     [-1, first_conv_element_count])
   first_fc_output_channels = 128
-  first_fc_weights = tf.get_variable(
+  first_fc_weights = tf.compat.v1.get_variable(
       name='first_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_conv_element_count, first_fc_output_channels])
-  first_fc_bias = tf.get_variable(
+  first_fc_bias = tf.compat.v1.get_variable(
       name='first_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_fc_output_channels])
   first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
   if is_training:
-    second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
+    second_fc_input = tf.compat.v1.nn.dropout(first_fc, dropout_prob)
   else:
     second_fc_input = first_fc
   second_fc_output_channels = 128
-  second_fc_weights = tf.get_variable(
+  second_fc_weights = tf.compat.v1.get_variable(
       name='second_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_fc_output_channels, second_fc_output_channels])
-  second_fc_bias = tf.get_variable(
+  second_fc_bias = tf.compat.v1.get_variable(
       name='second_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[second_fc_output_channels])
   second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
   if is_training:
-    final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
+    final_fc_input = tf.compat.v1.nn.dropout(second_fc, dropout_prob)
   else:
     final_fc_input = second_fc
   label_count = model_settings['label_count']
-  final_fc_weights = tf.get_variable(
+  final_fc_weights = tf.compat.v1.get_variable(
       name='final_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[second_fc_output_channels, label_count])
-  final_fc_bias = tf.get_variable(
+  final_fc_bias = tf.compat.v1.get_variable(
       name='final_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[label_count])
   final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
   if is_training:
@@ -504,7 +513,7 @@
       ValueError: If the inputs tensor is incorrectly shaped.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
 
   input_frequency_size = model_settings['fingerprint_width']
   input_time_size = model_settings['spectrogram_length']
@@ -528,12 +537,12 @@
   num_filters = rank * num_units
   # Create the runtime memory: [num_filters, batch, input_time_size]
   batch = 1
-  memory = tf.get_variable(
-      initializer=tf.zeros_initializer,
+  memory = tf.compat.v1.get_variable(
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[num_filters, batch, input_time_size],
       trainable=False,
       name='runtime-memory')
-  first_time_flag = tf.get_variable(
+  first_time_flag = tf.compat.v1.get_variable(
       name="first_time_flag",
       dtype=tf.int32,
       initializer=1)
@@ -547,9 +556,9 @@
     window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
                            model_settings['sample_rate'])
     num_new_frames = tf.cond(
-        tf.equal(first_time_flag, 1),
-        lambda: input_time_size,
-        lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
+        pred=tf.equal(first_time_flag, 1),
+        true_fn=lambda: input_time_size,
+        false_fn=lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))  # pylint:disable=line-too-long
   first_time_flag = 0
   new_fingerprint_input = fingerprint_input[
       :, -num_new_frames*input_frequency_size:]
@@ -557,20 +566,22 @@
   new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
 
   # Create the frequency filters.
-  weights_frequency = tf.get_variable(
+  weights_frequency = tf.compat.v1.get_variable(
       name='weights_frequency',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[input_frequency_size, num_filters])
   # Expand to add input channels dimensions.
   # weights_frequency: [input_frequency_size, 1, num_filters]
   weights_frequency = tf.expand_dims(weights_frequency, 1)
   # Convolve the 1D feature filters sliding over the time dimension.
   # activations_time: [batch, num_new_frames, num_filters]
-  activations_time = tf.nn.conv1d(
-      new_fingerprint_input, weights_frequency, input_frequency_size, 'VALID')
+  activations_time = tf.nn.conv1d(input=new_fingerprint_input,
+                                  filters=weights_frequency,
+                                  stride=input_frequency_size,
+                                  padding='VALID')
   # Rearrange such that we can perform the batched matmul.
   # activations_time: [num_filters, batch, num_new_frames]
-  activations_time = tf.transpose(activations_time, perm=[2, 0, 1])
+  activations_time = tf.transpose(a=activations_time, perm=[2, 0, 1])
 
   # Runtime memory optimization.
   if not is_training:
@@ -578,13 +589,13 @@
     # then add those corresponding to the new frames.
     new_memory = memory[:, :, num_new_frames:]
     new_memory = tf.concat([new_memory, activations_time], 2)
-    tf.assign(memory, new_memory)
+    tf.compat.v1.assign(memory, new_memory)
     activations_time = new_memory
 
   # Create the time filters.
-  weights_time = tf.get_variable(
+  weights_time = tf.compat.v1.get_variable(
       name='weights_time',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[num_filters, input_time_size])
   # Apply the time filter on the outputs of the feature filters.
   # weights_time: [num_filters, input_time_size, 1]
@@ -597,59 +608,60 @@
   # [num_filters, batch, 1] => [num_units, rank, batch]
   outputs = tf.reshape(outputs, [num_units, rank, -1])
   # Sum the rank outputs per unit => [num_units, batch].
-  units_output = tf.reduce_sum(outputs, axis=1)
+  units_output = tf.reduce_sum(input_tensor=outputs, axis=1)
   # Transpose to shape [batch, num_units]
-  units_output = tf.transpose(units_output)
+  units_output = tf.transpose(a=units_output)
 
   # Appy bias.
-  bias = tf.get_variable(
-      name='bias', initializer=tf.zeros_initializer, shape=[num_units])
+  bias = tf.compat.v1.get_variable(name='bias',
+                                   initializer=tf.compat.v1.zeros_initializer,
+                                   shape=[num_units])
   first_bias = tf.nn.bias_add(units_output, bias)
 
   # Relu.
   first_relu = tf.nn.relu(first_bias)
 
   if is_training:
-    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+    first_dropout = tf.compat.v1.nn.dropout(first_relu, dropout_prob)
   else:
     first_dropout = first_relu
 
   first_fc_output_channels = 256
-  first_fc_weights = tf.get_variable(
+  first_fc_weights = tf.compat.v1.get_variable(
       name='first_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[num_units, first_fc_output_channels])
-  first_fc_bias = tf.get_variable(
+  first_fc_bias = tf.compat.v1.get_variable(
       name='first_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_fc_output_channels])
   first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
   if is_training:
-    second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
+    second_fc_input = tf.compat.v1.nn.dropout(first_fc, dropout_prob)
   else:
     second_fc_input = first_fc
   second_fc_output_channels = 256
-  second_fc_weights = tf.get_variable(
+  second_fc_weights = tf.compat.v1.get_variable(
       name='second_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_fc_output_channels, second_fc_output_channels])
-  second_fc_bias = tf.get_variable(
+  second_fc_bias = tf.compat.v1.get_variable(
       name='second_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[second_fc_output_channels])
   second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
   if is_training:
-    final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
+    final_fc_input = tf.compat.v1.nn.dropout(second_fc, dropout_prob)
   else:
     final_fc_input = second_fc
   label_count = model_settings['label_count']
-  final_fc_weights = tf.get_variable(
+  final_fc_weights = tf.compat.v1.get_variable(
       name='final_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[second_fc_output_channels, label_count])
-  final_fc_bias = tf.get_variable(
+  final_fc_bias = tf.compat.v1.get_variable(
       name='final_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[label_count])
   final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
   if is_training:
@@ -698,7 +710,7 @@
     placeholder.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
   input_frequency_size = model_settings['fingerprint_width']
   input_time_size = model_settings['spectrogram_length']
   fingerprint_4d = tf.reshape(fingerprint_input,
@@ -706,22 +718,23 @@
   first_filter_width = 8
   first_filter_height = 10
   first_filter_count = 8
-  first_weights = tf.get_variable(
+  first_weights = tf.compat.v1.get_variable(
       name='first_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_filter_height, first_filter_width, 1, first_filter_count])
-  first_bias = tf.get_variable(
+  first_bias = tf.compat.v1.get_variable(
       name='first_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_filter_count])
   first_conv_stride_x = 2
   first_conv_stride_y = 2
-  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
-                            [1, first_conv_stride_y, first_conv_stride_x, 1],
-                            'SAME') + first_bias
+  first_conv = tf.nn.conv2d(
+      input=fingerprint_4d, filters=first_weights,
+      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
+      padding='SAME') + first_bias
   first_relu = tf.nn.relu(first_conv)
   if is_training:
-    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+    first_dropout = tf.compat.v1.nn.dropout(first_relu, dropout_prob)
   else:
     first_dropout = first_relu
   first_dropout_shape = first_dropout.get_shape()
@@ -733,13 +746,13 @@
   flattened_first_dropout = tf.reshape(first_dropout,
                                        [-1, first_dropout_element_count])
   label_count = model_settings['label_count']
-  final_fc_weights = tf.get_variable(
+  final_fc_weights = tf.compat.v1.get_variable(
       name='final_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_dropout_element_count, label_count])
-  final_fc_bias = tf.get_variable(
+  final_fc_bias = tf.compat.v1.get_variable(
       name='final_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[label_count])
   final_fc = (
       tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
@@ -802,7 +815,7 @@
     placeholder.
   """
   if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+    dropout_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_prob')
   input_frequency_size = model_settings['fingerprint_width']
   input_time_size = model_settings['spectrogram_length']
   fingerprint_4d = tf.reshape(fingerprint_input,
@@ -811,47 +824,49 @@
   first_filter_width = 8
   first_filter_height = 10
   first_filter_count = 8
-  first_weights = tf.get_variable(
+  first_weights = tf.compat.v1.get_variable(
       name='first_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[first_filter_height, first_filter_width, 1, first_filter_count])
-  first_bias = tf.get_variable(
+  first_bias = tf.compat.v1.get_variable(
       name='first_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[first_filter_count])
   first_conv_stride_x = 2
   first_conv_stride_y = 2
-  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
-                            [1, first_conv_stride_y, first_conv_stride_x, 1],
-                            'SAME') + first_bias
+  first_conv = tf.nn.conv2d(
+      input=fingerprint_4d, filters=first_weights,
+      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
+      padding='SAME') + first_bias
   first_relu = tf.nn.relu(first_conv)
   if is_training:
-    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+    first_dropout = tf.compat.v1.nn.dropout(first_relu, dropout_prob)
   else:
     first_dropout = first_relu
 
   second_filter_width = 8
   second_filter_height = 10
   second_filter_count = 8
-  second_weights = tf.get_variable(
+  second_weights = tf.compat.v1.get_variable(
       name='second_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[
           second_filter_height, second_filter_width, first_filter_count,
           second_filter_count
       ])
-  second_bias = tf.get_variable(
+  second_bias = tf.compat.v1.get_variable(
       name='second_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[second_filter_count])
   second_conv_stride_x = 8
   second_conv_stride_y = 8
-  second_conv = tf.nn.conv2d(first_dropout, second_weights,
-                             [1, second_conv_stride_y, second_conv_stride_x, 1],
-                             'SAME') + second_bias
+  second_conv = tf.nn.conv2d(
+      input=first_dropout, filters=second_weights,
+      strides=[1, second_conv_stride_y, second_conv_stride_x, 1],
+      padding='SAME') + second_bias
   second_relu = tf.nn.relu(second_conv)
   if is_training:
-    second_dropout = tf.nn.dropout(second_relu, dropout_prob)
+    second_dropout = tf.compat.v1.nn.dropout(second_relu, dropout_prob)
   else:
     second_dropout = second_relu
 
@@ -864,13 +879,13 @@
   flattened_second_dropout = tf.reshape(second_dropout,
                                         [-1, second_dropout_element_count])
   label_count = model_settings['label_count']
-  final_fc_weights = tf.get_variable(
+  final_fc_weights = tf.compat.v1.get_variable(
       name='final_fc_weights',
-      initializer=tf.truncated_normal_initializer(stddev=0.01),
+      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
       shape=[second_dropout_element_count, label_count])
-  final_fc_bias = tf.get_variable(
+  final_fc_bias = tf.compat.v1.get_variable(
       name='final_fc_bias',
-      initializer=tf.zeros_initializer,
+      initializer=tf.compat.v1.zeros_initializer,
       shape=[label_count])
   final_fc = (
       tf.matmul(flattened_second_dropout, final_fc_weights) + final_fc_bias)
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index 43a399b..446e351 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -86,11 +86,11 @@
 
 
 def main(_):
-  # We want to see all the logging messages for this tutorial.
-  tf.logging.set_verbosity(tf.logging.INFO)
+  # Set the verbosity based on flags (default is INFO, so we see all messages)
+  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)
 
   # Start a new TensorFlow session.
-  sess = tf.InteractiveSession()
+  sess = tf.compat.v1.InteractiveSession()
 
   # Begin by making sure we have the training data we need. If you already have
   # training data of your own, use `--data_url= ` on the command line to avoid
@@ -122,12 +122,12 @@
         'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                    len(learning_rates_list)))
 
-  input_placeholder = tf.placeholder(
+  input_placeholder = tf.compat.v1.placeholder(
       tf.float32, [None, fingerprint_size], name='fingerprint_input')
   if FLAGS.quantize:
     fingerprint_min, fingerprint_max = input_data.get_features_range(
         model_settings)
-    fingerprint_input = tf.fake_quant_with_min_max_args(
+    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
         input_placeholder, fingerprint_min, fingerprint_max)
   else:
     fingerprint_input = input_placeholder
@@ -139,48 +139,52 @@
       is_training=True)
 
   # Define loss and optimizer
-  ground_truth_input = tf.placeholder(
+  ground_truth_input = tf.compat.v1.placeholder(
       tf.int64, [None], name='groundtruth_input')
 
   # Optionally we can add runtime checks to spot when NaNs or other symptoms of
   # numerical errors start occurring during training.
   control_dependencies = []
   if FLAGS.check_nans:
-    checks = tf.add_check_numerics_ops()
+    checks = tf.compat.v1.add_check_numerics_ops()
     control_dependencies = [checks]
 
   # Create the back propagation and training evaluation machinery in the graph.
-  with tf.name_scope('cross_entropy'):
-    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
+  with tf.compat.v1.name_scope('cross_entropy'):
+    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
         labels=ground_truth_input, logits=logits)
   if FLAGS.quantize:
     tf.contrib.quantize.create_training_graph(quant_delay=0)
-  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
-    learning_rate_input = tf.placeholder(
+  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
+      control_dependencies):
+    learning_rate_input = tf.compat.v1.placeholder(
         tf.float32, [], name='learning_rate_input')
-    train_step = tf.train.GradientDescentOptimizer(
+    train_step = tf.compat.v1.train.GradientDescentOptimizer(
         learning_rate_input).minimize(cross_entropy_mean)
-  predicted_indices = tf.argmax(logits, 1)
+  predicted_indices = tf.argmax(input=logits, axis=1)
   correct_prediction = tf.equal(predicted_indices, ground_truth_input)
-  confusion_matrix = tf.confusion_matrix(
-      ground_truth_input, predicted_indices, num_classes=label_count)
-  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-  with tf.get_default_graph().name_scope('eval'):
-    tf.summary.scalar('cross_entropy', cross_entropy_mean)
-    tf.summary.scalar('accuracy', evaluation_step)
+  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
+                                              predictions=predicted_indices,
+                                              num_classes=label_count)
+  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
+                                                        tf.float32))
+  with tf.compat.v1.get_default_graph().name_scope('eval'):
+    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
+    tf.compat.v1.summary.scalar('accuracy', evaluation_step)
 
-  global_step = tf.train.get_or_create_global_step()
-  increment_global_step = tf.assign(global_step, global_step + 1)
+  global_step = tf.compat.v1.train.get_or_create_global_step()
+  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)
 
-  saver = tf.train.Saver(tf.global_variables())
+  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
 
   # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
-  merged_summaries = tf.summary.merge_all(scope='eval')
-  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
-                                       sess.graph)
-  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
+  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
+  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
+                                                 sess.graph)
+  validation_writer = tf.compat.v1.summary.FileWriter(
+      FLAGS.summaries_dir + '/validation')
 
-  tf.global_variables_initializer().run()
+  tf.compat.v1.global_variables_initializer().run()
 
   start_step = 1
 
@@ -188,11 +192,11 @@
     models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
     start_step = global_step.eval(session=sess)
 
-  tf.logging.info('Training from step: %d ', start_step)
+  tf.compat.v1.logging.info('Training from step: %d ', start_step)
 
   # Save graph.pbtxt.
-  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
-                       FLAGS.model_architecture + '.pbtxt')
+  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
+                    FLAGS.model_architecture + '.pbtxt')
 
   # Save list of words.
   with gfile.GFile(
@@ -230,9 +234,10 @@
             dropout_prob: 0.5
         })
     train_writer.add_summary(train_summary, training_step)
-    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
-                    (training_step, learning_rate_value, train_accuracy * 100,
-                     cross_entropy_value))
+    tf.compat.v1.logging.info(
+        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
+        (training_step, learning_rate_value, train_accuracy * 100,
+         cross_entropy_value))
     is_last_step = (training_step == training_steps_max)
     if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
       set_size = audio_processor.set_size('validation')
@@ -258,20 +263,21 @@
           total_conf_matrix = conf_matrix
         else:
           total_conf_matrix += conf_matrix
-      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
-      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
-                      (training_step, total_accuracy * 100, set_size))
+      tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
+      tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
+                                (training_step, total_accuracy * 100, set_size))
 
     # Save the model checkpoint periodically.
     if (training_step % FLAGS.save_step_interval == 0 or
         training_step == training_steps_max):
       checkpoint_path = os.path.join(FLAGS.train_dir,
                                      FLAGS.model_architecture + '.ckpt')
-      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
+      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
+                                training_step)
       saver.save(sess, checkpoint_path, global_step=training_step)
 
   set_size = audio_processor.set_size('testing')
-  tf.logging.info('set_size=%d', set_size)
+  tf.compat.v1.logging.info('set_size=%d', set_size)
   total_accuracy = 0
   total_conf_matrix = None
   for i in xrange(0, set_size, FLAGS.batch_size):
@@ -290,9 +296,9 @@
       total_conf_matrix = conf_matrix
     else:
       total_conf_matrix += conf_matrix
-  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
-  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100,
-                                                           set_size))
+  tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
+  tf.compat.v1.logging.info('Final test accuracy = %.1f%% (N=%d)' %
+                            (total_accuracy * 100, set_size))
 
 
 if __name__ == '__main__':
@@ -301,7 +307,7 @@
       '--data_url',
       type=str,
       # pylint: disable=line-too-long
-      default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
+      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
       # pylint: enable=line-too-long
       help='Location of speech training data archive on the web.')
   parser.add_argument(
@@ -448,5 +454,33 @@
       default='mfcc',
       help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
 
+  # Function used to parse --verbosity argument
+  def verbosity_arg(value):
+    """Parses verbosity argument.
+
+    Args:
+      value: A member of tf.logging.
+    Raises:
+      ArgumentTypeError: Not an expected value.
+    """
+    value = value.upper()
+    if value == 'INFO':
+      return tf.compat.v1.logging.INFO
+    elif value == 'DEBUG':
+      return tf.compat.v1.logging.DEBUG
+    elif value == 'ERROR':
+      return tf.compat.v1.logging.ERROR
+    elif value == 'FATAL':
+      return tf.compat.v1.logging.FATAL
+    elif value == 'WARN':
+      return tf.compat.v1.logging.WARN
+    else:
+      raise argparse.ArgumentTypeError('Not an expected value')
+  parser.add_argument(
+      '--verbosity',
+      type=verbosity_arg,
+      default=tf.compat.v1.logging.INFO,
+      help='Log verbosity. Can be "INFO", "DEBUG", "ERROR", "FATAL", or "WARN"')
+
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/train_test.py b/tensorflow/examples/speech_commands/train_test.py
index db19576..f17e2ba 100644
--- a/tensorflow/examples/speech_commands/train_test.py
+++ b/tensorflow/examples/speech_commands/train_test.py
@@ -100,6 +100,7 @@
         'background_frequency': 0.8,
         'eval_step_interval': 1,
         'save_step_interval': 1,
+        'verbosity': tf.compat.v1.logging.INFO
     }
     return DictStruct(**flags)
 
diff --git a/tensorflow/examples/speech_commands/wav_to_features.py b/tensorflow/examples/speech_commands/wav_to_features.py
index d7f2446..be3d045 100644
--- a/tensorflow/examples/speech_commands/wav_to_features.py
+++ b/tensorflow/examples/speech_commands/wav_to_features.py
@@ -62,7 +62,7 @@
   """
 
   # Start a new TensorFlow session.
-  sess = tf.InteractiveSession()
+  sess = tf.compat.v1.InteractiveSession()
 
   model_settings = models.prepare_model_settings(
       0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
@@ -124,12 +124,12 @@
 
 def main(_):
   # We want to see all the logging messages.
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms,
                   FLAGS.window_size_ms, FLAGS.window_stride_ms,
                   FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess,
                   FLAGS.input_wav, FLAGS.output_c_file)
-  tf.logging.info('Wrote to "%s"' % (FLAGS.output_c_file))
+  tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file))
 
 
 if __name__ == '__main__':
@@ -182,4 +182,4 @@
       help='Where to save the generated C source file containing the features')
 
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/tutorials/deepdream/README.md b/tensorflow/examples/tutorials/deepdream/README.md
index 403e4b3..e16b366 100644
--- a/tensorflow/examples/tutorials/deepdream/README.md
+++ b/tensorflow/examples/tutorials/deepdream/README.md
@@ -5,11 +5,18 @@
 This directory contains Jupyter notebook that demonstrates a number of Convolutional Neural Network
 image generation techniques implemented with TensorFlow:
 
-- visualizing individual feature channels and their combinations to explore the space of patterns learned by the neural network (see [GoogLeNet](http://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) and [VGG16](http://storage.googleapis.com/deepdream/visualz/vgg16/index.html) galleries)
-- embedding TensorBoard graph visualizations into Jupyter notebooks
-- producing high-resolution images with tiled computation ([example](http://storage.googleapis.com/deepdream/pilatus_flowers.jpg))
-- using Laplacian Pyramid Gradient Normalization to produce smooth and colorful visuals at low cost
-- generating DeepDream-like images with TensorFlow
+-   visualizing individual feature channels and their combinations to explore
+    the space of patterns learned by the neural network (see
+    [GoogLeNet](https://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html)
+    and
+    [VGG16](https://storage.googleapis.com/deepdream/visualz/vgg16/index.html)
+    galleries)
+-   embedding TensorBoard graph visualizations into Jupyter notebooks
+-   producing high-resolution images with tiled computation
+    ([example](https://storage.googleapis.com/deepdream/pilatus_flowers.jpg))
+-   using Laplacian Pyramid Gradient Normalization to produce smooth and
+    colorful visuals at low cost
+-   generating DeepDream-like images with TensorFlow
 
 You can view "deepdream.ipynb" directly on GitHub. Note that GitHub Jupyter notebook preview removes
 embedded graph visualizations. You can still see them online
diff --git a/tensorflow/examples/tutorials/deepdream/deepdream.ipynb b/tensorflow/examples/tutorials/deepdream/deepdream.ipynb
index 15112aa..448f3f6 100644
--- a/tensorflow/examples/tutorials/deepdream/deepdream.ipynb
+++ b/tensorflow/examples/tutorials/deepdream/deepdream.ipynb
@@ -40,14 +40,14 @@
       "source": [
         "This notebook demonstrates a number of Convolutional Neural Network image generation techniques implemented with TensorFlow for fun and science:\n",
         "\n",
-        "- visualize individual feature channels and their combinations to explore the space of patterns learned by the neural network (see [GoogLeNet](http://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) and [VGG16](http://storage.googleapis.com/deepdream/visualz/vgg16/index.html) galleries)\n",
+        "- visualize individual feature channels and their combinations to explore the space of patterns learned by the neural network (see [GoogLeNet](https://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) and [VGG16](https://storage.googleapis.com/deepdream/visualz/vgg16/index.html) galleries)\n",
         "- embed TensorBoard graph visualizations into Jupyter notebooks\n",
-        "- produce high-resolution images with tiled computation ([example](http://storage.googleapis.com/deepdream/pilatus_flowers.jpg))\n",
+        "- produce high-resolution images with tiled computation ([example](https://storage.googleapis.com/deepdream/pilatus_flowers.jpg))\n",
         "- use Laplacian Pyramid Gradient Normalization to produce smooth and colorful visuals at low cost\n",
         "- generate DeepDream-like images with TensorFlow (DogSlugs included)\n",
         "\n",
         "\n",
-        "The network under examination is the [GoogLeNet architecture](http://arxiv.org/abs/1409.4842), trained to classify images into one of 1000 categories of the [ImageNet](http://image-net.org/) dataset. It consists of a set of layers that apply a sequence of transformations to the input image. The parameters of these transformations were determined during the training process by a variant of gradient descent algorithm. The internal image representations may seem obscure, but it is possible to visualize and interpret them. In this notebook we are going to present a few tricks that allow to make these visualizations both efficient to generate and even beautiful. Impatient readers can start with exploring the full galleries of images generated by the method described here for [GoogLeNet](http://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) and [VGG16](http://storage.googleapis.com/deepdream/visualz/vgg16/index.html) architectures."
+        "The network under examination is the [GoogLeNet architecture](http://arxiv.org/abs/1409.4842), trained to classify images into one of 1000 categories of the [ImageNet](http://image-net.org/) dataset. It consists of a set of layers that apply a sequence of transformations to the input image. The parameters of these transformations were determined during the training process by a variant of gradient descent algorithm. The internal image representations may seem obscure, but it is possible to visualize and interpret them. In this notebook we are going to present a few tricks that allow to make these visualizations both efficient to generate and even beautiful. Impatient readers can start with exploring the full galleries of images generated by the method described here for [GoogLeNet](https://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) and [VGG16](https://storage.googleapis.com/deepdream/visualz/vgg16/index.html) architectures."
       ]
     },
     {
@@ -1117,7 +1117,7 @@
         "id": "mYsY6_Ngpfwl"
       },
       "source": [
-        "Don't hesitate to use higher resolution inputs (also increase the number of octaves)! Here is an [example](http://storage.googleapis.com/deepdream/pilatus_flowers.jpg) of running the flower dream over the bigger image."
+        "Don't hesitate to use higher resolution inputs (also increase the number of octaves)! Here is an [example](https://storage.googleapis.com/deepdream/pilatus_flowers.jpg) of running the flower dream over the bigger image."
       ]
     },
     {
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index 5d93620..264d084 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -1,7 +1,7 @@
 # Description:
 # Example TensorFlow models for MNIST used in tutorials
 
-load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
 
 package(
     licenses = ["notice"],  # Apache 2.0
@@ -94,43 +94,42 @@
     ],
 )
 
-py_test(
+tf_py_test(
     name = "fully_connected_feed_test",
-    size = "medium",
     srcs = [
         "fully_connected_feed.py",
     ],
+    additional_deps = [
+        ":input_data",
+        ":mnist",
+        "//tensorflow:tensorflow_py",
+    ],
     args = [
         "--fake_data",
         "--max_steps=10",
     ],
     main = "fully_connected_feed.py",
-    python_version = "PY2",
-    srcs_version = "PY2AND3",
-    deps = [
-        ":input_data",
-        ":mnist",
-        "//tensorflow:tensorflow_py",
-    ],
+    tags = ["no_pip"],
 )
 
-py_test(
+tf_py_test(
     name = "mnist_with_summaries_test",
     size = "small",
     srcs = [
         "mnist_with_summaries.py",
     ],
+    additional_deps = [
+        ":input_data",
+        "//tensorflow:tensorflow_py",
+    ],
     args = [
         "--fake_data",
         "--max_steps=10",
         "--learning_rate=0.00",
     ],
     main = "mnist_with_summaries.py",
-    python_version = "PY2",
-    srcs_version = "PY2AND3",
-    tags = ["notsan"],  # http://b/29184009
-    deps = [
-        ":input_data",
-        "//tensorflow:tensorflow_py",
+    tags = [
+        "no_pip",
+        "notsan",  # http://b/29184009
     ],
 )
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
index e61cbab..8eb5710 100644
--- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
@@ -50,9 +50,9 @@
   # Note that the shapes of the placeholders match the shapes of the full
   # image and label tensors, except the first dimension is now batch_size
   # rather than the full size of the train or test data sets.
-  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
-                                                         mnist.IMAGE_PIXELS))
-  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
+  images_placeholder = tf.compat.v1.placeholder(
+      tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
+  labels_placeholder = tf.compat.v1.placeholder(tf.int32, shape=(batch_size))
   return images_placeholder, labels_placeholder
 
 
@@ -140,19 +140,19 @@
     eval_correct = mnist.evaluation(logits, labels_placeholder)
 
     # Build the summary Tensor based on the TF collection of Summaries.
-    summary = tf.summary.merge_all()
+    summary = tf.compat.v1.summary.merge_all()
 
     # Add the variable initializer Op.
-    init = tf.global_variables_initializer()
+    init = tf.compat.v1.global_variables_initializer()
 
     # Create a saver for writing training checkpoints.
-    saver = tf.train.Saver()
+    saver = tf.compat.v1.train.Saver()
 
     # Create a session for running Ops on the Graph.
     sess = tf.compat.v1.Session()
 
     # Instantiate a SummaryWriter to output summaries and the Graph.
-    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
+    summary_writer = tf.compat.v1.summary.FileWriter(FLAGS.log_dir, sess.graph)
 
     # And then after everything is built:
 
@@ -216,9 +216,9 @@
 
 
 def main(_):
-  if tf.gfile.Exists(FLAGS.log_dir):
-    tf.gfile.DeleteRecursively(FLAGS.log_dir)
-  tf.gfile.MakeDirs(FLAGS.log_dir)
+  if tf.io.gfile.exists(FLAGS.log_dir):
+    tf.io.gfile.rmtree(FLAGS.log_dir)
+  tf.io.gfile.makedirs(FLAGS.log_dir)
   run_training()
 
 
@@ -276,4 +276,4 @@
   )
 
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/tutorials/mnist/mnist.py b/tensorflow/examples/tutorials/mnist/mnist.py
index 7cedd0e..0141d4b 100644
--- a/tensorflow/examples/tutorials/mnist/mnist.py
+++ b/tensorflow/examples/tutorials/mnist/mnist.py
@@ -54,29 +54,29 @@
     softmax_linear: Output tensor with the computed logits.
   """
   # Hidden 1
-  with tf.name_scope('hidden1'):
+  with tf.compat.v1.name_scope('hidden1'):
     weights = tf.Variable(
-        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
-                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
-        name='weights')
+        tf.random.truncated_normal(
+            [IMAGE_PIXELS, hidden1_units],
+            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights')
     biases = tf.Variable(tf.zeros([hidden1_units]),
                          name='biases')
     hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
   # Hidden 2
-  with tf.name_scope('hidden2'):
+  with tf.compat.v1.name_scope('hidden2'):
     weights = tf.Variable(
-        tf.truncated_normal([hidden1_units, hidden2_units],
-                            stddev=1.0 / math.sqrt(float(hidden1_units))),
-        name='weights')
+        tf.random.truncated_normal(
+            [hidden1_units, hidden2_units],
+            stddev=1.0 / math.sqrt(float(hidden1_units))), name='weights')
     biases = tf.Variable(tf.zeros([hidden2_units]),
                          name='biases')
     hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
   # Linear
-  with tf.name_scope('softmax_linear'):
+  with tf.compat.v1.name_scope('softmax_linear'):
     weights = tf.Variable(
-        tf.truncated_normal([hidden2_units, NUM_CLASSES],
-                            stddev=1.0 / math.sqrt(float(hidden2_units))),
-        name='weights')
+        tf.random.truncated_normal(
+            [hidden2_units, NUM_CLASSES],
+            stddev=1.0 / math.sqrt(float(hidden2_units))), name='weights')
     biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                          name='biases')
     logits = tf.matmul(hidden2, weights) + biases
@@ -93,8 +93,9 @@
   Returns:
     loss: Loss tensor of type float.
   """
-  labels = tf.to_int64(labels)
-  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+  labels = tf.cast(labels, dtype=tf.int64)
+  return tf.compat.v1.losses.sparse_softmax_cross_entropy(
+      labels=labels, logits=logits)
 
 
 def training(loss, learning_rate):
@@ -115,9 +116,9 @@
     train_op: The Op for training.
   """
   # Add a scalar summary for the snapshot loss.
-  tf.summary.scalar('loss', loss)
+  tf.compat.v1.summary.scalar('loss', loss)
   # Create the gradient descent optimizer with the given learning rate.
-  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
   # Create a variable to track the global step.
   global_step = tf.Variable(0, name='global_step', trainable=False)
   # Use the optimizer to apply the gradients that minimize the loss
@@ -142,6 +143,6 @@
   # It returns a bool tensor with shape [batch_size] that is true for
   # the examples where the label is in the top k (here k=1)
   # of all logits for that example.
-  correct = tf.nn.in_top_k(logits, labels, 1)
+  correct = tf.nn.in_top_k(predictions=logits, targets=labels, k=1)
   # Return the number of true entries.
-  return tf.reduce_sum(tf.cast(correct, tf.int32))
+  return tf.reduce_sum(input_tensor=tf.cast(correct, tf.int32))
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index efe35ca..04315ad 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -40,22 +40,22 @@
   mnist = input_data.read_data_sets(FLAGS.data_dir,
                                     fake_data=FLAGS.fake_data)
 
-  sess = tf.InteractiveSession()
+  sess = tf.compat.v1.InteractiveSession()
   # Create a multilayer model.
 
   # Input placeholders
-  with tf.name_scope('input'):
-    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
-    y_ = tf.placeholder(tf.int64, [None], name='y-input')
+  with tf.compat.v1.name_scope('input'):
+    x = tf.compat.v1.placeholder(tf.float32, [None, 784], name='x-input')
+    y_ = tf.compat.v1.placeholder(tf.int64, [None], name='y-input')
 
-  with tf.name_scope('input_reshape'):
+  with tf.compat.v1.name_scope('input_reshape'):
     image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
-    tf.summary.image('input', image_shaped_input, 10)
+    tf.compat.v1.summary.image('input', image_shaped_input, 10)
 
   # We can't initialize these variables to 0 - the network will get stuck.
   def weight_variable(shape):
     """Create a weight variable with appropriate initialization."""
-    initial = tf.truncated_normal(shape, stddev=0.1)
+    initial = tf.random.truncated_normal(shape, stddev=0.1)
     return tf.Variable(initial)
 
   def bias_variable(shape):
@@ -65,15 +65,15 @@
 
   def variable_summaries(var):
     """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
-    with tf.name_scope('summaries'):
-      mean = tf.reduce_mean(var)
-      tf.summary.scalar('mean', mean)
-      with tf.name_scope('stddev'):
-        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
-      tf.summary.scalar('stddev', stddev)
-      tf.summary.scalar('max', tf.reduce_max(var))
-      tf.summary.scalar('min', tf.reduce_min(var))
-      tf.summary.histogram('histogram', var)
+    with tf.compat.v1.name_scope('summaries'):
+      mean = tf.reduce_mean(input_tensor=var)
+      tf.compat.v1.summary.scalar('mean', mean)
+      with tf.compat.v1.name_scope('stddev'):
+        stddev = tf.sqrt(tf.reduce_mean(input_tensor=tf.square(var - mean)))
+      tf.compat.v1.summary.scalar('stddev', stddev)
+      tf.compat.v1.summary.scalar('max', tf.reduce_max(input_tensor=var))
+      tf.compat.v1.summary.scalar('min', tf.reduce_min(input_tensor=var))
+      tf.compat.v1.summary.histogram('histogram', var)
 
   def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
     """Reusable code for making a simple neural net layer.
@@ -83,32 +83,32 @@
     and adds a number of summary ops.
     """
     # Adding a name scope ensures logical grouping of the layers in the graph.
-    with tf.name_scope(layer_name):
+    with tf.compat.v1.name_scope(layer_name):
       # This Variable will hold the state of the weights for the layer
-      with tf.name_scope('weights'):
+      with tf.compat.v1.name_scope('weights'):
         weights = weight_variable([input_dim, output_dim])
         variable_summaries(weights)
-      with tf.name_scope('biases'):
+      with tf.compat.v1.name_scope('biases'):
         biases = bias_variable([output_dim])
         variable_summaries(biases)
-      with tf.name_scope('Wx_plus_b'):
+      with tf.compat.v1.name_scope('Wx_plus_b'):
         preactivate = tf.matmul(input_tensor, weights) + biases
-        tf.summary.histogram('pre_activations', preactivate)
+        tf.compat.v1.summary.histogram('pre_activations', preactivate)
       activations = act(preactivate, name='activation')
-      tf.summary.histogram('activations', activations)
+      tf.compat.v1.summary.histogram('activations', activations)
       return activations
 
   hidden1 = nn_layer(x, 784, 500, 'layer1')
 
-  with tf.name_scope('dropout'):
-    keep_prob = tf.placeholder(tf.float32)
-    tf.summary.scalar('dropout_keep_probability', keep_prob)
+  with tf.compat.v1.name_scope('dropout'):
+    keep_prob = tf.compat.v1.placeholder(tf.float32)
+    tf.compat.v1.summary.scalar('dropout_keep_probability', keep_prob)
     dropped = tf.nn.dropout(hidden1, rate=(1 - keep_prob))
 
   # Do not apply softmax activation yet, see below.
   y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)
 
-  with tf.name_scope('cross_entropy'):
+  with tf.compat.v1.name_scope('cross_entropy'):
     # The raw formulation of cross-entropy,
     #
     # tf.reduce_mean(-tf.reduce_sum(y_ * tf.math.log(tf.softmax(y)),
@@ -119,28 +119,30 @@
     # So here we use tf.compat.v1.losses.sparse_softmax_cross_entropy on the
     # raw logit outputs of the nn_layer above, and then average across
     # the batch.
-    with tf.name_scope('total'):
-      cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+    with tf.compat.v1.name_scope('total'):
+      cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
           labels=y_, logits=y)
-  tf.summary.scalar('cross_entropy', cross_entropy)
+  tf.compat.v1.summary.scalar('cross_entropy', cross_entropy)
 
-  with tf.name_scope('train'):
-    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
+  with tf.compat.v1.name_scope('train'):
+    train_step = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate).minimize(
         cross_entropy)
 
-  with tf.name_scope('accuracy'):
-    with tf.name_scope('correct_prediction'):
-      correct_prediction = tf.equal(tf.argmax(y, 1), y_)
-    with tf.name_scope('accuracy'):
-      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-  tf.summary.scalar('accuracy', accuracy)
+  with tf.compat.v1.name_scope('accuracy'):
+    with tf.compat.v1.name_scope('correct_prediction'):
+      correct_prediction = tf.equal(tf.argmax(input=y, axis=1), y_)
+    with tf.compat.v1.name_scope('accuracy'):
+      accuracy = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
+                                                     tf.float32))
+  tf.compat.v1.summary.scalar('accuracy', accuracy)
 
   # Merge all the summaries and write them out to
   # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
-  merged = tf.summary.merge_all()
-  train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
-  test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
-  tf.global_variables_initializer().run()
+  merged = tf.compat.v1.summary.merge_all()
+  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.log_dir + '/train',
+                                                 sess.graph)
+  test_writer = tf.compat.v1.summary.FileWriter(FLAGS.log_dir + '/test')
+  tf.compat.v1.global_variables_initializer().run()
 
   # Train the model, and also write summaries.
   # Every 10th step, measure test-set accuracy, and write test summaries
@@ -163,8 +165,9 @@
       print('Accuracy at step %s: %s' % (i, acc))
     else:  # Record train set summaries, and train
       if i % 100 == 99:  # Record execution stats
-        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
-        run_metadata = tf.RunMetadata()
+        run_options = tf.compat.v1.RunOptions(
+            trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+        run_metadata = tf.compat.v1.RunMetadata()
         summary, _ = sess.run([merged, train_step],
                               feed_dict=feed_dict(True),
                               options=run_options,
@@ -180,9 +183,9 @@
 
 
 def main(_):
-  if tf.gfile.Exists(FLAGS.log_dir):
-    tf.gfile.DeleteRecursively(FLAGS.log_dir)
-  tf.gfile.MakeDirs(FLAGS.log_dir)
+  if tf.io.gfile.exists(FLAGS.log_dir):
+    tf.io.gfile.rmtree(FLAGS.log_dir)
+  tf.io.gfile.makedirs(FLAGS.log_dir)
   with tf.Graph().as_default():
     train()
 
@@ -211,4 +214,4 @@
                            'tensorflow/mnist/logs/mnist_with_summaries'),
       help='Summaries log directory')
   FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index f1dffb1..5a96757 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -90,6 +90,17 @@
           </execution>
         </executions>
       </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <configuration>
+          <archive>
+            <manifest>
+              <addDefaultImplementationEntries>true</addDefaultImplementationEntries>
+            </manifest>
+          </archive>
+        </configuration>
+      </plugin>
     </plugins>
   </build>
 
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 27ae193..3899ebb 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -85,6 +85,21 @@
   curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64
   curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
 
+  # Get rid of symlinks, those are not supported by jar. As of tensorflow 1.14,
+  # libtensorflow_jni.so expects to find
+  # libtensorflow_framework.so.<majorVersion>.
+  MAJOR_VERSION="${TF_VERSION/\.*/}"
+
+  FRAMEWORK_SO="$(readlink -f linux-x86_64/libtensorflow_framework.so)"
+  rm linux-x86_64/libtensorflow_framework.so
+  rm "linux-x86_64/libtensorflow_framework.so.${MAJOR_VERSION}"
+  mv "${FRAMEWORK_SO}" "linux-x86_64/libtensorflow_framework.so.${MAJOR_VERSION}"
+
+  FRAMEWORK_DYLIB="$(readlink -f darwin-x86_64/libtensorflow_framework.dylib)"
+  rm darwin-x86_64/libtensorflow_framework.dylib
+  rm "darwin-x86_64/libtensorflow_framework.${MAJOR_VERSION}.dylib"
+  mv "${FRAMEWORK_DYLIB}" "darwin-x86_64/libtensorflow_framework.${MAJOR_VERSION}.dylib"
+
   unzip /tmp/windows.zip -d windows-x86_64
   rm -f /tmp/windows.zip
   # Updated timestamps seem to be required to get Maven to pick up the file.
@@ -105,6 +120,11 @@
   curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
   curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
 
+  FRAMEWORK_SO="$(readlink -f linux-x86_64/libtensorflow_framework.so)"
+  rm linux-x86_64/libtensorflow_framework.so
+  rm "linux-x86_64/libtensorflow_framework.so.${MAJOR_VERSION}"
+  mv "${FRAMEWORK_SO}" "linux-x86_64/libtensorflow_framework.so.${MAJOR_VERSION}"
+
   unzip /tmp/windows.zip -d windows-x86_64
   rm -f /tmp/windows.zip
 
diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
index 972e9cc..cbb878e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
@@ -141,7 +141,7 @@
      * <p>{@link DevicePlacementPolicy#SILENT} is used by default.
      *
      * @param value policy to apply
-     * @see {@link DevicePlacementPolicy}
+     * @see DevicePlacementPolicy
      */
     public Options devicePlacementPolicy(DevicePlacementPolicy value) {
       devicePlacementPolicy = value;
@@ -154,7 +154,7 @@
      * <p>{@link ResourceCleanupStrategy#IN_BACKGROUND} is used by default.
      *
      * @param value strategy to use
-     * @see {@link ResourceCleanupStrategy}
+     * @see ResourceCleanupStrategy
      */
     public Options resourceCleanupStrategy(ResourceCleanupStrategy value) {
       resourceCleanupStrategy = value;
@@ -169,8 +169,8 @@
      * not be supported on public endpoints in the future.
      *
      * @param value a serialized config proto
-     * @see
-     *     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto
+     * @see <a
+     *     href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto"/>
      */
     public Options config(byte[] value) {
       config = value;
@@ -231,7 +231,7 @@
    * @param options options to use to build default session
    * @return default eager session
    * @throws IllegalStateException if the default session is already initialized
-   * @see {@link #getDefault()}
+   * @see #getDefault()
    */
   public static EagerSession initDefault(Options options) {
     synchronized (EagerSession.class) {
@@ -262,12 +262,12 @@
    * Ops tf = Ops.create();
    *
    * // Starting to build eager operations using default session, by calling
-   * // EagerSession.getDefault() explictly
+   * // EagerSession.getDefault() explicitly
    * Ops tf = Ops.create(EagerSession.getDefault());
    * }</pre>
    *
    * @return default eager session
-   * @see {@link #initDefault(Options)}
+   * @see #initDefault
    */
   public static EagerSession getDefault() {
     if (defaultSession == null) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index a0e14f1..3a175b1 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -226,8 +226,8 @@
    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
    * i.e., {@code dy/dx_1, dy/dx_2...}
    * <p>
-   * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
-   * a single output, {@code dx} is null and {@code prefix} is null.
+   * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])
+   * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null.
    *
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index cf773e1..2ab0e47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -65,7 +65,7 @@
         NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName);
     // Extract the JNI's dependency
     final String frameworkLibName =
-        maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework"));
+        getVersionedLibraryName(System.mapLibraryName("tensorflow_framework"));
     final String frameworkResourceName = makeResourceName(frameworkLibName);
     log("frameworkResourceName: " + frameworkResourceName);
     final InputStream frameworkResource =
@@ -126,22 +126,66 @@
     }
   }
 
-  private static String maybeAdjustForMacOS(String libFilename) {
-    if (!System.getProperty("os.name").contains("OS X")) {
+  private static boolean resourceExists(String baseName) {
+    return NativeLibrary.class.getClassLoader().getResource(makeResourceName(baseName)) != null;
+  }
+
+  private static String getVersionedLibraryName(String libFilename) {
+    // If the resource exists as an unversioned file, return that.
+    if (resourceExists(libFilename)) {
       return libFilename;
     }
-    // This is macOS, and the TensorFlow release process might have setup dependencies on
-    // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that.
-    final ClassLoader cl = NativeLibrary.class.getClassLoader();
-    if (cl.getResource(makeResourceName(libFilename)) != null) {
-      return libFilename;
+
+    final String versionName = getMajorVersionNumber();
+
+    // If we're on darwin, the versioned libraries look like blah.1.dylib.
+    final String darwinSuffix = ".dylib";
+    if (libFilename.endsWith(darwinSuffix)) {
+      final String prefix = libFilename.substring(0, libFilename.length() - darwinSuffix.length());
+      if (versionName != null) {
+        final String darwinVersionedLibrary = prefix + "." + versionName + darwinSuffix;
+        if (resourceExists(darwinVersionedLibrary)) {
+          return darwinVersionedLibrary;
+        }
+      } else {
+        // If we're here, we're on darwin, but we couldn't figure out the major version number. We
+        // already tried the library name without any changes, but let's do one final try for the
+        // library with a .so suffix.
+        final String darwinSoName = prefix + ".so";
+        if (resourceExists(darwinSoName)) {
+          return darwinSoName;
+        }
+      }
+    } else if (libFilename.endsWith(".so")) {
+      // Libraries ending in ".so" are versioned like "libfoo.so.1", so try that.
+      final String versionedSoName = libFilename + "." + versionName;
+      if (versionName != null && resourceExists(versionedSoName)) {
+        return versionedSoName;
+      }
     }
-    // liftensorflow_framework.dylib not found, try libtensorflow_framework.so
-    final String suffix = ".dylib";
-    if (!libFilename.endsWith(suffix)) {
-      return libFilename;
+
+    // Otherwise, we've got no idea.
+    return libFilename;
+  }
+
+  /**
+   * Returns the major version number of this TensorFlow Java API, or {@code null} if it cannot be
+   * determined.
+   */
+  private static String getMajorVersionNumber() {
+    String version = NativeLibrary.class.getPackage().getImplementationVersion();
+    // expecting a string like 1.14.0, we want to get the first '1'.
+    int dotIndex;
+    if (version == null || (dotIndex = version.indexOf('.')) == -1) {
+      return null;
     }
-    return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so";
+    String majorVersion = version.substring(0, dotIndex);
+    try {
+      Integer.parseInt(majorVersion);
+      return majorVersion;
+    } catch (NumberFormatException unused) {
+      return null;
+    }
   }
 
   private static String extractResource(
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index b5e0f7a..bdcb4fd 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -22,7 +22,7 @@
  * Driver for {@link Graph} execution.
  *
  * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
- * {@link Graph} are executed to compute {@link Tensor}s. For example:
+ * {@link Graph} are executed to compute {@link Tensor Tensors}. For example:
  *
  * <pre>{@code
  * // Let's say graph is an instance of the Graph class
@@ -109,12 +109,13 @@
   }
 
   /**
-   * Run {@link Operation}s and evaluate {@link Tensor}s.
+   * Run {@link Operation}s and evaluate {@link Tensor Tensors}.
    *
    * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
-   * evaluate the {@link Tensor}s to fetch. The {@link #feed(String,int,Tensor)} call allows callers
-   * to override the value of {@link Tensor}s in the graph by substituting the provided {@link
-   * Tensor}s for the outputs of the operations provided to {@link #feed(String,int,Tensor)}.
+   * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows
+   * callers to override the value of {@link Tensor Tensors} in the graph by substituting the
+   * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link
+   * #feed(String,int,Tensor)}.
    */
   public final class Runner {
     /**
@@ -201,7 +202,8 @@
     }
 
     /**
-     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
+     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
+     * Tensors}.
      */
     public Runner addTarget(String operation) {
       GraphOperation op = operationByName(operation);
@@ -212,9 +214,10 @@
     }
 
     /**
-     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
+     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
+     * Tensors}.
      *
-     * @throws execption if the operation is not a {@link GraphOperation}
+     * @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
      */
     public Runner addTarget(Operation operation) {
       if (!(operation instanceof GraphOperation)) {
@@ -226,9 +229,10 @@
       targets.add((GraphOperation) operation);
       return this;
     }
-    
+
     /**
-     * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
+     * Make {@link #run} execute {@code operand}, but not return any evaluated {@link Tensor
+     * Tensors}.
      */
     public Runner addTarget(Operand<?> operand) {
       return addTarget(operand.asOutput().op());
@@ -256,8 +260,8 @@
     /**
      * Execute the graph fragments necessary to compute all requested fetches.
      *
-     * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor}s, i.e., the
-     * caller must call {@link Tensor#close()} on all elements of the returned list to free up
+     * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor Tensors}, i.e.,
+     * the caller must call {@link Tensor#close} on all elements of the returned list to free up
      * resources.
      *
      * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
@@ -458,7 +462,7 @@
    * @param inputOpIndices (see inputTensorHandles)
    * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
    *     that are being "fed" (do not need to be computed) during graph execution.
-   *     inputTensorHandles[i] (which correponds to a Tensor.nativeHandle) is considered to be the
+   *     inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the
    *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
    *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
    * @param outputOpHandles (see outputOpIndices)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index ebc5b01..8472509 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -247,23 +247,8 @@
     return ret;
   }
 
-  /**
-   * Creates a Tensor of any type with data from the given buffer.
-   *
-   * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
-   * encoded into {@code data} as per the specification of the TensorFlow <a
-   * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C
-   * API</a>.
-   *
-   * @param <T> The tensor element type
-   * @param type the tensor element type, specified as a DataType. This must agree with T.
-   * @param shape the tensor shape.
-   * @param data a buffer containing the tensor data.
-   * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
-   *     buffer
-   */
   private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
-    int nremaining = 0;
+    int nremaining;
     if (dtype != DataType.STRING) {
       int elemBytes = elemByteSize(dtype);
       if (data.remaining() % elemBytes != 0) {
@@ -633,7 +618,7 @@
    *
    * <p>This helper class wraps the tensor native handle and support both situations; If an eager
    * reference to the tensor exists, it will take care of releasing the tensor at the end of its
-   * life. If the tensor is being explicetly closed before this happens, it will take cake of
+   * life. If the tensor is being explicitly closed before this happens, it will take cake of
    * clearing its association with any eager session before cleaning up the resources.
    */
   private static class NativeReference {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
index 4042fb1..a3d6edd 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
@@ -24,6 +24,6 @@
  *
  * <p>TensorFlow element types are also separately represented by the {@link
  * org.tensorflow.DataType} enum, with one enum value per element type. The enum representation is
- * not usually needed, but can be obtained using {@link org.tensorflow.DataType.fromClass}.
+ * not usually needed, but can be obtained using {@link org.tensorflow.DataType#fromClass}.
  */
 package org.tensorflow.types;
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index e97de3d..e353edd 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -224,6 +224,7 @@
         "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "//tensorflow/lite/nnapi:nnapi_implementation",
         "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/experimental/resource_variable:resource_variable",
     ] + select({
         ":with_select_tf_ops": [
             "//tensorflow/lite/delegates/flex:delegate",
@@ -327,6 +328,7 @@
         "testdata/2_subgraphs.bin",
         "testdata/empty_model.bin",
         "testdata/multi_add_flex.bin",
+        "testdata/test_min_runtime.bin",
         "testdata/test_model.bin",
         "testdata/test_model_broken.bin",
     ],
diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc
index e695c43..3258f61 100644
--- a/tensorflow/lite/arena_planner.cc
+++ b/tensorflow/lite/arena_planner.cc
@@ -153,7 +153,7 @@
     }
   }
   // Go through the graph in execution order.
-  for (int i = 0; i < graph_info_->num_nodes(); ++i) {
+  for (size_t i = 0; i < graph_info_->num_nodes(); ++i) {
     const TfLiteNode& node = graph_info_->node(i);
 
     // First queue output tensors for allocation.
@@ -193,7 +193,7 @@
   TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node));
   TF_LITE_ENSURE_STATUS(Commit());
 
-  for (int i = 0; i < graph_info_->num_tensors(); ++i) {
+  for (int i = 0; i < static_cast<int>(graph_info_->num_tensors()); ++i) {
     // TODO(ahentz): we could do this only for the tensors that were modified
     // in CalculateAllocations(), instead of redoing it for tensors that
     // already had proper pointers. However we must be very careful, because
@@ -237,9 +237,14 @@
     }
   }
 
-  // Don't forget to deallocate temporaries of last node.
-  TF_LITE_ENSURE_STATUS(
-      CalculateDeallocationOfInternalTensors(active_node - 1));
+  // For the case if the graph is empty the node index can be negative since we
+  // substract from the active node, so the node_index can be zero for those
+  // cases
+  if (active_node > 0) {
+    // Don't forget to deallocate temporaries of last node.
+    TF_LITE_ENSURE_STATUS(
+        CalculateDeallocationOfInternalTensors(active_node - 1));
+  }
 
   return kTfLiteOk;
 }
@@ -284,8 +289,8 @@
 
 TfLiteStatus ArenaPlanner::CalculateAllocationOfInternalTensors(
     int node_index) {
-  if (node_index < graph_info_->num_nodes()) {
-    const TfLiteNode& node = graph_info_->node(node_index);
+  if (node_index < static_cast<int>(graph_info_->num_nodes())) {
+    const TfLiteNode& node = graph_info_->node(static_cast<size_t>(node_index));
     TfLiteIntArray* node_temporaries = node.temporaries;
     for (int i = 0; i < node_temporaries->size; ++i) {
       int tensor_index = node_temporaries->data[i];
@@ -297,8 +302,8 @@
 
 TfLiteStatus ArenaPlanner::CalculateDeallocationOfInternalTensors(
     int node_index) {
-  if (node_index < graph_info_->num_nodes()) {
-    const TfLiteNode& node = graph_info_->node(node_index);
+  if (node_index < static_cast<int>(graph_info_->num_nodes())) {
+    const TfLiteNode& node = graph_info_->node(static_cast<size_t>(node_index));
     TfLiteIntArray* node_temporaries = node.temporaries;
     for (int i = 0; i < node_temporaries->size; ++i) {
       int tensor_index = node_temporaries->data[i];
diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc
index 3b6c9d5..0e80d42 100644
--- a/tensorflow/lite/arena_planner_test.cc
+++ b/tensorflow/lite/arena_planner_test.cc
@@ -211,6 +211,18 @@
   Execute(0, 10);
 }
 
+TEST_F(ArenaPlannerTest, DeallocationOfInputTensor) {
+  // This is a negative TC, which will try to make sure that no allocation for
+  // input tensors is done, when making call with negative node_index, since
+  // previous check was doing comparison of node_index which was int and
+  // unsigned int, implicit conversion was passing this case, as the negative
+  // number was converted to unsigned it making it invalid.The new check
+  // takes care of this problem and removes the warning as well.
+  TestGraph graph({-1}, {}, {1});
+  SetGraph(&graph);
+  Execute(0, 10);
+}
+
 TEST_F(ArenaPlannerTest, GraphWithNoOps) {
   TestGraph graph({0, 10}, {}, {5, 11});
   SetGraph(&graph);
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index cb98f69..cc7df7f 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -51,7 +51,6 @@
     return select({
         "//tensorflow:android": [
             "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
-            "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
             "-Wl,--gc-sections",  # Eliminate unused code and data.
             "-Wl,--as-needed",  # Don't link unused libs.
         ],
@@ -110,6 +109,7 @@
         linkstatic = 1,
         testonly = 0,
         deps = [],
+        tags = [],
         srcs = []):
     """Builds a jni binary for TFLite."""
     linkopts = linkopts + select({
@@ -130,6 +130,7 @@
         linkstatic = linkstatic,
         deps = deps + [linkscript, exported_symbols],
         srcs = srcs,
+        tags = tags,
         linkopts = linkopts,
         testonly = testonly,
     )
diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h
index e1c54cb..c31d3e5 100644
--- a/tensorflow/lite/c/c_api_internal.h
+++ b/tensorflow/lite/c/c_api_internal.h
@@ -51,7 +51,11 @@
   kTfLiteMaxExternalContexts = 4
 } TfLiteExternalContextType;
 
+// Forward declare so dependent structs and methods can reference these types
+// prior to the struct definitions.
 struct TfLiteContext;
+struct TfLiteDelegate;
+struct TfLiteRegistration;
 
 // An external context is a collection of information unrelated to the TF Lite
 // framework, but useful to a subset of the ops. TF Lite knows very little
@@ -63,10 +67,6 @@
   TfLiteStatus (*Refresh)(struct TfLiteContext* context);
 } TfLiteExternalContext;
 
-// Forward declare so GetNode can use this is in Context.
-typedef struct _TfLiteRegistration TfLiteRegistration;
-typedef struct _TfLiteDelegate TfLiteDelegate;
-
 #define kOptionalTensor (-1)
 
 // Fixed size list of integers. Used for dimensions and inputs/outputs tensor
@@ -330,7 +330,7 @@
 
   // The delegate which knows how to handle `buffer_handle`.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
+  struct TfLiteDelegate* delegate;
 
   // An integer buffer handle that can be handled by `delegate`.
   // The value is valid only when delegate is not null.
@@ -405,7 +405,7 @@
   // The pointer to the delegate. This is non-null only when the node is
   // created by calling `interpreter.ModifyGraphWithDelegate`.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
+  struct TfLiteDelegate* delegate;
 } TfLiteNode;
 
 typedef struct TfLiteContext {
@@ -451,15 +451,15 @@
 
   // Get a Tensor node by node_index.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
-                                         TfLiteNode** node,
-                                         TfLiteRegistration** registration);
+  TfLiteStatus (*GetNodeAndRegistration)(
+      struct TfLiteContext*, int node_index, TfLiteNode** node,
+      struct TfLiteRegistration** registration);
 
   // Replace ops with one or more stub delegate operations. This function
   // does not take ownership of `nodes_to_replace`.
   TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
-      struct TfLiteContext*, TfLiteRegistration registration,
-      const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+      struct TfLiteContext*, struct TfLiteRegistration registration,
+      const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);
 
   // Number of threads that are recommended to subsystems like gemmlowp and
   // eigen.
@@ -484,7 +484,7 @@
   void* profiler;
 } TfLiteContext;
 
-typedef struct _TfLiteRegistration {
+typedef struct TfLiteRegistration {
   // Initializes the op from serialized data.
   // If a built-in op:
   //   `buffer` is the op's params data (TfLiteLSTMParams*).
@@ -560,7 +560,7 @@
 } TfLiteDelegateFlags;
 
 // WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
+typedef struct TfLiteDelegate {
   // Data that delegate needs to identify itself. This data is owned by the
   // delegate. The delegate is owned in the user code, so the delegate is
   // responsible for doing this when it is destroyed.
@@ -571,20 +571,21 @@
   // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
   // to ask the TensorFlow lite runtime to create macro-nodes to represent
   // delegated subgraphs of the original graph.
-  TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+  TfLiteStatus (*Prepare)(TfLiteContext* context,
+                          struct TfLiteDelegate* delegate);
 
   // Copy the data from delegate buffer handle into raw memory of the given
   // 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
   // bytes as long as it follows the rules for kTfLiteDynamic tensors.
   TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
-                                       TfLiteDelegate* delegate,
+                                       struct TfLiteDelegate* delegate,
                                        TfLiteBufferHandle buffer_handle,
                                        TfLiteTensor* tensor);
 
   // Copy the data from raw memory of the given 'tensor' to delegate buffer
   // handle. This can be null if the delegate doesn't use its own buffer.
   TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
-                                     TfLiteDelegate* delegate,
+                                     struct TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* tensor);
 
@@ -592,7 +593,8 @@
   // this doesn't release the underlying resource (e.g. textures). The
   // resources are either owned by application layer or the delegate.
   // This can be null if the delegate doesn't use its own buffer.
-  void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+  void (*FreeBufferHandle)(TfLiteContext* context,
+                           struct TfLiteDelegate* delegate,
                            TfLiteBufferHandle* handle);
 
   // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 53a4e8f..d1121e9 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -435,6 +435,9 @@
                                    lstm_params->kernel_type());
             return kTfLiteError;
         }
+      } else {
+        error_reporter->Report("No valid LSTM builtin options exist");
+        return kTfLiteError;
       }
       *builtin_data = reinterpret_cast<void*>(params.release());
       break;
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index acbd41d..42fa0c3 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -156,12 +156,14 @@
 
 Subgraph::Subgraph(ErrorReporter* error_reporter,
                    TfLiteExternalContext** external_contexts,
-                   std::vector<std::unique_ptr<Subgraph>>* subgraphs)
+                   std::vector<std::unique_ptr<Subgraph>>* subgraphs,
+                   ResourceVariableMap* resource_variables)
     : external_contexts_(external_contexts),
       error_reporter_(error_reporter),
       next_execution_plan_index_to_prepare_(0),
       next_execution_plan_index_to_plan_allocation_(0),
-      subgraphs_(subgraphs) {
+      subgraphs_(subgraphs),
+      resource_variables_(resource_variables) {
   context_.impl_ = static_cast<void*>(this);
   context_.ResizeTensor = ResizeTensor;
   context_.ReportError = ReportErrorC;
@@ -1056,6 +1058,9 @@
 }
 
 TfLiteStatus Subgraph::UndoAllDelegates() {
+  // Return early if there is nothing to reset to.
+  if (pre_delegation_execution_plan_.empty()) return kTfLiteOk;
+
   // First free all delegate nodes.
   for (int execution_plan_index = 0;
        execution_plan_index < execution_plan_.size(); ++execution_plan_index) {
@@ -1069,6 +1074,7 @@
 
   // Reset execution plan.
   execution_plan_ = pre_delegation_execution_plan_;
+  pre_delegation_execution_plan_.clear();
 
   // Delegate nodes are appended to nodes_and_registration_. Therefore,
   // cleanup nodes_and_registration_ to only contain nodes from
@@ -1147,24 +1153,41 @@
   // Setup additional context interface.
   SwitchToDelegateContext();
 
+  auto reset_delegation_if_not_ok = [this](TfLiteStatus status) {
+    if (status != kTfLiteOk) {
+      // This will undo all delegate nodes currently in the graph.
+      TF_LITE_ENSURE_STATUS(this->UndoAllDelegates());
+      // This will call AllocateTensors, thus-reapplying any (successfully
+      // applied) previous delegates.
+      TF_LITE_ENSURE_STATUS(this->EnsureMemoryAllocations());
+      ReportError(
+          "Restored previous execution plan after delegate application "
+          "failure.");
+      return kTfLiteError;
+    }
+    return kTfLiteOk;
+  };
+
   TfLiteStatus status = delegate->Prepare(&context_, delegate);
 
   // Remove additional context info.
   SwitchToKernelContext();
 
-  TF_LITE_ENSURE_OK(&context_, status);
+  TF_LITE_ENSURE_STATUS(reset_delegation_if_not_ok(status));
 
   if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
     // Reset the state to force tensor/op reallocation.
     state_ = kStateUninvokable;
-    TF_LITE_ENSURE_OK(&context_, EnsureMemoryAllocations());
+    TF_LITE_ENSURE_STATUS(
+        reset_delegation_if_not_ok(EnsureMemoryAllocations()));
     // After using a delegate which doesn't support dynamic tensors, make the
     // entire graph immutable.
     state_ = kStateInvokableAndImmutable;
   } else if (was_invokable_before_delegate) {
     // If the graph was invokable prior to delegate application, flush
     // allocation now to leave it in a consistent state.
-    TF_LITE_ENSURE_OK(&context_, EnsureMemoryAllocations());
+    TF_LITE_ENSURE_STATUS(
+        reset_delegation_if_not_ok(EnsureMemoryAllocations()));
   }
   delegates_applied_.push_back(delegate);
 
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 0a6bb63..b9736d8 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -16,12 +16,14 @@
 #define TENSORFLOW_LITE_CORE_SUBGRAPH_H_
 
 #include <cstdlib>
+#include <map>
 #include <vector>
 
 #include "tensorflow/lite/allocation.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/core/api/profiler.h"
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
 #include "tensorflow/lite/memory_planner.h"
 #include "tensorflow/lite/util.h"
 
@@ -36,7 +38,8 @@
 
   Subgraph(ErrorReporter* error_reporter,
            TfLiteExternalContext** external_contexts,
-           std::vector<std::unique_ptr<Subgraph>>* subgraphs);
+           std::vector<std::unique_ptr<Subgraph>>* subgraphs,
+           ResourceVariableMap* resource_variables);
 
   Subgraph(const Subgraph&) = delete;
 
@@ -160,6 +163,10 @@
   // Read only access to list of variable tensors.
   const std::vector<int>& variables() const { return variables_; }
 
+  // WARNING: Experimental interface, subject to change.
+  // TODO(ycling): Move this function to an external context interface.
+  ResourceVariableMap& resource_variables() { return *resource_variables_; }
+
   size_t tensors_size() const { return tensors_.size(); }
 
   // Return the number of ops in the model.
@@ -581,6 +588,10 @@
   // Reference to data used by the cancellation function in
   // `check_cancelled_func_`.
   void* cancellation_data_ = nullptr;
+
+  // A map of resource variables. Owned by interpreter and shared by multiple
+  // subgraphs.
+  ResourceVariableMap* resource_variables_ = nullptr;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 7eacef0..431fcab 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -52,6 +52,7 @@
         "//tensorflow/lite/delegates/gpu/gl:command_queue",
         "//tensorflow/lite/delegates/gpu/gl:compiler",
         "//tensorflow/lite/delegates/gpu/gl:egl_environment",
+        "//tensorflow/lite/delegates/gpu/gl:request_gpu_info",
         "//tensorflow/lite/delegates/gpu/gl:gl_call",
         "//tensorflow/lite/delegates/gpu/gl/converters:bhwc_to_phwc4",
         "//tensorflow/lite/delegates/gpu/gl/converters:phwc4_to_bhwc",
@@ -73,7 +74,6 @@
     name = "metal_delegate",
     srcs = ["metal_delegate.mm"],
     hdrs = ["metal_delegate.h"],
-    copts = ["-std=c++11"],
     sdk_frameworks = ["Metal"],
     deps = [
         "//tensorflow/lite:kernel_api",
@@ -96,6 +96,16 @@
     ],
 )
 
+objc_library(
+    name = "metal_delegate_internal",
+    hdrs = ["metal_delegate_internal.h"],
+    copts = ["-std=c++11"],
+    sdk_frameworks = ["Metal"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu:metal_delegate",
+    ],
+)
+
 # build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_gl.so
 cc_binary(
     name = "libtensorflowlite_gpu_gl.so",
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index fe5f5ed..9cb80e8 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -25,6 +25,15 @@
 )
 
 cc_library(
+    name = "gpu_info",
+    srcs = ["gpu_info.cc"],
+    hdrs = ["gpu_info.h"],
+    deps = [
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
     name = "data_type",
     srcs = ["data_type.cc"],
     hdrs = ["data_type.h"],
@@ -77,6 +86,7 @@
         ":tensor",
         "//tensorflow/lite:context",
         "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:util",
         "//tensorflow/lite/c:c_api_internal",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/schema:schema_fbs",
diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc
index 53db297..81d09b2 100644
--- a/tensorflow/lite/delegates/gpu/common/convert.cc
+++ b/tensorflow/lite/delegates/gpu/common/convert.cc
@@ -29,21 +29,9 @@
 constexpr int kPhwo4i4ChannelsInPlane = 4;
 constexpr int kPiohw4ChannelsInPlane = 4;
 
-}  // namespace
-
-uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) {
-  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
-         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
-}
-
-uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) {
-  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
-         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
-}
-
 // Layout is Po,H,W,OI4x4.
 Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
-                        absl::Span<float> out) {
+                        absl::Span<float> out, bool reverse_space) {
   if (in.size() != shape.DimensionsProduct()) {
     return InvalidArgumentError(absl::StrCat(
         "ConvertToPHWO4I4: Input data size does not match expected size: ",
@@ -70,7 +58,9 @@
                 // tensor is in OHWI
                 int tensor_o = p * kPhwo4i4ChannelsInPlane + co;
                 int tensor_i = c * kPhwo4i4ChannelsInPlane + ci;
-                value = in[shape.LinearIndex({tensor_o, h, w, tensor_i})];
+                const int in_h = reverse_space ? shape.h - 1 - h : h;
+                const int in_w = reverse_space ? shape.w - 1 - w : w;
+                value = in[shape.LinearIndex({tensor_o, in_h, in_w, tensor_i})];
               }
               (*output++) = value;
             }
@@ -82,11 +72,34 @@
   return OkStatus();
 }
 
+}  // namespace
+
+uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) {
+  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
+         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
+}
+
+uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) {
+  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
+         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
+}
+
 std::vector<float> ConvertToPHWO4I4(
     const Tensor<OHWI, DataType::FLOAT32>& tensor) {
   std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
   ConvertToPHWO4I4(tensor.data, tensor.shape,
-                   absl::MakeSpan(transposed.data(), transposed.size()))
+                   absl::MakeSpan(transposed.data(), transposed.size()),
+                   /*reverse_space=*/false)
+      .IgnoreError();
+  return transposed;
+}
+
+std::vector<float> ConvertToPHWO4I4Transposed(
+    const Tensor<OHWI, DataType::FLOAT32>& tensor) {
+  std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
+  ConvertToPHWO4I4(tensor.data, tensor.shape,
+                   absl::MakeSpan(transposed.data(), transposed.size()),
+                   /*reverse_space=*/true)
       .IgnoreError();
   return transposed;
 }
diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h
index fdf9e02..30a0a5f 100644
--- a/tensorflow/lite/delegates/gpu/common/convert.h
+++ b/tensorflow/lite/delegates/gpu/common/convert.h
@@ -63,14 +63,14 @@
 // @return number of elements when shape is converted into PHWO4I4.
 uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape);
 
-// Layout is Po,H,W,OI4x4.
-Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
-                        absl::Span<float> out);
-
 // Convenience wrapper around a method above.
 std::vector<float> ConvertToPHWO4I4(
     const Tensor<OHWI, DataType::FLOAT32>& tensor);
 
+// Convenience wrapper around a method above, for Transposed Convolution.
+std::vector<float> ConvertToPHWO4I4Transposed(
+    const Tensor<OHWI, DataType::FLOAT32>& tensor);
+
 // @return (x,y,z) size for PHWO4I4 to access elements where each element
 // consists of 4 values.
 uint3 Get3DSizeForPHWO4I4(const OHWI& shape);
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
new file mode 100644
index 0000000..14fb48a
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -0,0 +1,103 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+
+#include "absl/strings/ascii.h"
+
+namespace tflite {
+namespace gpu {
+namespace {
+
+GpuType GetGpuType(const std::string& renderer) {
+  if (renderer.find("mali") != renderer.npos) {
+    return GpuType::MALI;
+  }
+  if (renderer.find("adreno") != renderer.npos) {
+    return GpuType::ADRENO;
+  }
+  if (renderer.find("powervr") != renderer.npos) {
+    return GpuType::POWERVR;
+  }
+  if (renderer.find("intel") != renderer.npos) {
+    return GpuType::INTEL;
+  }
+  if (renderer.find("nvidia") != renderer.npos) {
+    return GpuType::NVIDIA;
+  }
+  return GpuType::UNKNOWN;
+}
+
+GpuModel GetGpuModel(const std::string& renderer) {
+  auto found_model = [&](std::string model) -> bool {
+    return renderer.find(model) != renderer.npos;
+  };
+  // Adreno 6xx series
+  if (found_model("640")) return GpuModel::ADRENO640;
+  if (found_model("630")) return GpuModel::ADRENO630;
+  if (found_model("616")) return GpuModel::ADRENO616;
+  if (found_model("615")) return GpuModel::ADRENO615;
+  if (found_model("612")) return GpuModel::ADRENO612;
+  if (found_model("605")) return GpuModel::ADRENO605;
+  // Adreno 5xx series
+  if (found_model("540")) return GpuModel::ADRENO540;
+  if (found_model("530")) return GpuModel::ADRENO530;
+  if (found_model("512")) return GpuModel::ADRENO512;
+  if (found_model("510")) return GpuModel::ADRENO510;
+  if (found_model("509")) return GpuModel::ADRENO509;
+  if (found_model("508")) return GpuModel::ADRENO508;
+  if (found_model("506")) return GpuModel::ADRENO506;
+  if (found_model("505")) return GpuModel::ADRENO505;
+  if (found_model("504")) return GpuModel::ADRENO504;
+  // Adreno 4xx series
+  if (found_model("430")) return GpuModel::ADRENO430;
+  if (found_model("420")) return GpuModel::ADRENO420;
+  if (found_model("418")) return GpuModel::ADRENO418;
+  if (found_model("405")) return GpuModel::ADRENO405;
+  // Adreno 3xx series
+  if (found_model("330")) return GpuModel::ADRENO330;
+  if (found_model("320")) return GpuModel::ADRENO320;
+  if (found_model("308")) return GpuModel::ADRENO308;
+  if (found_model("306")) return GpuModel::ADRENO306;
+  if (found_model("305")) return GpuModel::ADRENO305;
+  if (found_model("304")) return GpuModel::ADRENO304;
+  // Adreno 2xx series
+  if (found_model("225")) return GpuModel::ADRENO225;
+  if (found_model("220")) return GpuModel::ADRENO220;
+  if (found_model("205")) return GpuModel::ADRENO205;
+  if (found_model("203")) return GpuModel::ADRENO203;
+  if (found_model("200")) return GpuModel::ADRENO200;
+  // Adreno 1xx series
+  if (found_model("130")) return GpuModel::ADRENO130;
+  return GpuModel::UNKNOWN;
+}
+
+}  // namespace
+
+void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model,
+                        GpuType* gpu_type) {
+  std::string lowered = renderer;
+  absl::AsciiStrToLower(&lowered);
+  *gpu_type = GetGpuType(lowered);
+  *gpu_model =
+      *gpu_type == GpuType::ADRENO ? GetGpuModel(lowered) : GpuModel::UNKNOWN;
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
new file mode 100644
index 0000000..44d10b3
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_GPU_INFO_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_GPU_INFO_H_
+
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace gpu {
+
+enum class GpuType { UNKNOWN, MALI, ADRENO, POWERVR, INTEL, NVIDIA };
+enum class GpuModel {
+  UNKNOWN,
+  // Adreno 6xx series
+  ADRENO640,
+  ADRENO630,
+  ADRENO616,
+  ADRENO615,
+  ADRENO612,
+  ADRENO605,
+  // Adreno 5xx series
+  ADRENO540,
+  ADRENO530,
+  ADRENO512,
+  ADRENO510,
+  ADRENO509,
+  ADRENO508,
+  ADRENO506,
+  ADRENO505,
+  ADRENO504,
+  // Adreno 4xx series
+  ADRENO430,
+  ADRENO420,
+  ADRENO418,
+  ADRENO405,
+  // Adreno 3xx series
+  ADRENO330,
+  ADRENO320,
+  ADRENO308,
+  ADRENO306,
+  ADRENO305,
+  ADRENO304,
+  // Adreno 2xx series
+  ADRENO225,
+  ADRENO220,
+  ADRENO205,
+  ADRENO203,
+  ADRENO200,
+  // Adreno 1xx series
+  ADRENO130,
+};
+
+struct GpuInfo {
+  GpuType type = GpuType::UNKNOWN;
+  std::string renderer_name;
+  std::string vendor_name;
+  std::string version;
+  GpuModel gpu_model;
+  int major_version = -1;
+  int minor_version = -1;
+  std::vector<std::string> extensions;
+  int max_ssbo_bindings = 0;
+  int max_image_bindings = 0;
+  std::vector<int> max_work_group_size;
+  int max_work_group_invocations;
+  int max_texture_size = 0;
+  int max_image_units = 0;
+  int max_array_texture_layers = 0;
+};
+
+// Analyzes `renderer` and returns matching `GpuType` and `GpuModel`.
+void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model,
+                        GpuType* gpu_type);
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_GPU_INFO_H_
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc
index a5d5fc9..8bfdd83 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc
@@ -60,6 +60,11 @@
   size_t object_id;
 };
 
+bool CompareBySize(const TensorUsageWithIndex<size_t>& first,
+                   const TensorUsageWithIndex<size_t>& second) {
+  return first.usage_record->tensor_size > second.usage_record->tensor_size;
+}
+
 // Implements memory management with a naive algorithm.
 //
 // The problem of memory management is NP-complete. This implements a
@@ -206,6 +211,88 @@
   return OkStatus();
 }
 
+// Assigns given tensors to offsets, using the following greedy algorithm:
+// - We have tensor usage records of all intermideate tensors as an input. Each
+// record consists of tensor size, first and last tasks, that use it. Let's call
+// [first_task..last_task] a tensor usage interval;
+// - Iterate through tensor usage records in non-increasing order of
+// corresponding tensor sizes;
+// - For each of these records consider already assigned tensors, which usage
+// intervals intersect with usage interval of current tensor, and find the
+// smallest gap in memory between them such, that current tensor fits into that
+// gap;
+// - If such a gap has been found, current tensor should be allocated into this
+// gap. Otherwise we can allocate it after the rightmost tensor, which usage
+// interval intersects with usage inteval of current tensor. So we assign
+// corresponding offset to current tensor and the tensor becomes assigned.
+Status GreedyBySizeAssignment(
+    const std::vector<TensorUsageRecord<size_t>>& usage_records,
+    OffsetsAssignment* assignment) {
+  const size_t num_tensors = usage_records.size();
+  assignment->offsets.resize(num_tensors);
+  assignment->total_size = 0;
+
+  // Ordered records are to be sorted by size of corrseponding tensor.
+  std::vector<TensorUsageWithIndex<size_t>> ordered_records;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    ordered_records.emplace_back(&usage_records[i], i);
+  }
+  std::sort(ordered_records.begin(), ordered_records.end(), CompareBySize);
+
+  // Vector of ids of already allocated tensors, ordered by offset.
+  std::vector<size_t> ordered_allocs;
+
+  for (const auto& rec_with_idx : ordered_records) {
+    const TensorUsageRecord<size_t>* rec = rec_with_idx.usage_record;
+    size_t best_diff = kNotAssigned;
+    size_t best_offset = kNotAssigned;
+    size_t prev_offset = 0;
+    for (const auto& allocated_id : ordered_allocs) {
+      if (usage_records[allocated_id].last_task < rec->first_task ||
+          usage_records[allocated_id].first_task > rec->last_task) {
+        // Tensor allocated_id has usage interval, that doesn't intersect with
+        // current tensor's usage interval, so we skip it.
+        continue;
+      }
+      size_t cur_offset = assignment->offsets[allocated_id];
+      if (cur_offset >= prev_offset) {
+        size_t diff = cur_offset - prev_offset;
+        // Check, if current_tensor fits into the gap, located directly to the
+        // left of tensor allocated_id offset, and that this gap is the smallest
+        // of previously considered suitable gaps.
+        if (diff >= rec->tensor_size && diff < best_diff) {
+          best_diff = diff;
+          best_offset = prev_offset;
+        }
+      }
+      prev_offset = std::max(
+          prev_offset, cur_offset + usage_records[allocated_id].tensor_size);
+    }
+    if (assignment->total_size < prev_offset) {
+      return InternalError("Total size is wrong.");
+    }
+
+    // If no suitable gap found, we should allocate current tensor after the
+    // rightmost tensor, which usage interval intersects with the current one.
+    if (best_offset == kNotAssigned) {
+      best_offset = prev_offset;
+    }
+
+    // Assign best_offset to the current tensor and find the correct place to
+    // insert information about it into ordered_allocs to save the order.
+    auto it = ordered_allocs.begin();
+    while (it != ordered_allocs.end() &&
+           assignment->offsets[*it] <= best_offset) {
+      ++it;
+    }
+    ordered_allocs.insert(it, rec_with_idx.idx);
+    assignment->offsets[rec_with_idx.idx] = best_offset;
+    assignment->total_size =
+        std::max(assignment->total_size, best_offset + rec->tensor_size);
+  }
+  return OkStatus();
+}
+
 // This class build flow graph and solves Minimum-cost flow problem in it.
 class MinCostFlowSolver {
  public:
@@ -401,6 +488,28 @@
 
 }  // namespace
 
+bool CompareBySize(const TensorUsageWithIndex<size_t>& first,
+                   const TensorUsageWithIndex<size_t>& second) {
+  return first.usage_record->tensor_size > second.usage_record->tensor_size;
+}
+
+OffsetsAssignment ObjectsToOffsets(
+    const ObjectsAssignment<size_t>& obj_assignment) {
+  size_t num_tensors = obj_assignment.object_ids.size();
+  size_t num_objects = obj_assignment.object_sizes.size();
+  OffsetsAssignment result = {/*offsets=*/std::vector<size_t>(num_tensors),
+                              /*total_size=*/0};
+  std::vector<size_t> ids_to_offset(num_objects);
+  for (size_t i = 0; i < num_objects; ++i) {
+    ids_to_offset[i] = result.total_size;
+    result.total_size += obj_assignment.object_sizes[i];
+  }
+  for (size_t i = 0; i < num_tensors; ++i) {
+    result.offsets[i] = ids_to_offset[obj_assignment.object_ids[i]];
+  }
+  return result;
+}
+
 Status AssignObjectsToTensors(
     const std::vector<TensorUsageRecord<size_t>>& usage_records,
     const MemoryStrategy& strategy, ObjectsAssignment<size_t>* assignment) {
@@ -413,6 +522,9 @@
       return GreedyAssignment(usage_records, assignment);
     case MemoryStrategy::MINCOSTFLOW:
       return MinCostFlowAssignment(usage_records, assignment);
+    default:
+      return InternalError(
+          "MemoryStrategy is not supported with current tensor size type.");
   }
   return OkStatus();
 }
@@ -432,5 +544,18 @@
   return OkStatus();
 }
 
+Status AssignOffsetsToTensors(
+    const std::vector<TensorUsageRecord<size_t>>& usage_records,
+    const MemoryStrategy& strategy, OffsetsAssignment* assignment) {
+  if (strategy == MemoryStrategy::GREEDY_BY_SIZE) {
+    return GreedyBySizeAssignment(usage_records, assignment);
+  }
+  ObjectsAssignment<size_t> objects_assignment;
+  RETURN_IF_ERROR(
+      AssignObjectsToTensors(usage_records, strategy, &objects_assignment));
+  *assignment = ObjectsToOffsets(objects_assignment);
+  return OkStatus();
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h
index d3fec0a..ca74d2c 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.h
@@ -48,6 +48,19 @@
   }
 };
 
+template <typename TensorSizeT>
+struct TensorUsageWithIndex {
+  const TensorUsageRecord<TensorSizeT>* usage_record;
+  size_t idx;
+
+  TensorUsageWithIndex(const TensorUsageRecord<TensorSizeT>* usage_record,
+                       size_t idx)
+      : usage_record(usage_record), idx(idx) {}
+};
+
+bool CompareBySize(const TensorUsageWithIndex<size_t>& first,
+                   const TensorUsageWithIndex<size_t>& second);
+
 // Information about assignment of tensors to shared objects
 template <typename TensorSizeT>
 struct ObjectsAssignment {
@@ -57,6 +70,18 @@
   std::vector<TensorSizeT> object_sizes;
 };
 
+// Information about assignment of tensors to offsets for the case, when all of
+// them are going to be allocated in one continuous memory block.
+struct OffsetsAssignment {
+  std::vector<size_t> offsets;
+  size_t total_size;
+};
+
+// Converts given assignment of tensors to shared objects to the assignment of
+// the same tensors to offsets in continuous memory block.
+OffsetsAssignment ObjectsToOffsets(
+    const ObjectsAssignment<size_t>& obj_assignment);
+
 enum class MemoryStrategy {
   // Naive strategy is to allocate each object separately.
   // Can be useful for debugging to see all intermediate outputs.
@@ -66,10 +91,17 @@
   // tensors with the same size, but non-intersecting usage intervals.
   EQUALITY,
 
-  // Greedy strategy uses greedy algorithm to reuse memory from tensors, that
+  // Greedy strategy uses greedy algorithm, iterating through all the tensors in
+  // order of their first_task, to reuse memory from tensors, that
   // won't be used anymore, for new ones.
   GREEDY,
 
+  // Greedy by size strategy uses greedy algorithm, iterating through all the
+  // tensors in
+  // non-increasing of their size, to reuse memory from tensors, that
+  // won't be used anymore, for new ones.
+  GREEDY_BY_SIZE,
+
   // Mincostflow strategy consists of building auxiliary flow graph and solving
   // the minimum-cost flow problem in it. In the end edges with zero residual
   // capacity determine assignment of shared objects to tensors.
@@ -90,6 +122,12 @@
     const std::vector<TensorUsageRecord<BHWC>>& usage_records,
     const MemoryStrategy& strategy, ObjectsAssignment<BHWC>* assignment);
 
+// Calculates the assignement of tensors to offsets, considering those tensors
+// are going to be allocated in one continuous memory block.
+Status AssignOffsetsToTensors(
+    const std::vector<TensorUsageRecord<size_t>>& usage_records,
+    const MemoryStrategy& strategy, OffsetsAssignment* assignment);
+
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
index 34cc684..df745d4 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
@@ -24,6 +24,35 @@
 
 using ::testing::ElementsAre;
 
+TEST(Model, EmptyAssignment) {
+  ObjectsAssignment<size_t> objects_assignment;
+  OffsetsAssignment result = ObjectsToOffsets(objects_assignment);
+  EXPECT_TRUE(result.offsets.empty());
+  EXPECT_EQ(result.total_size, 0);
+}
+
+TEST(Model, OneObjectAssignment) {
+  ObjectsAssignment<size_t> objects_assignment;
+  objects_assignment.object_sizes = {16};
+  objects_assignment.object_ids = {0};
+  OffsetsAssignment result = ObjectsToOffsets(objects_assignment);
+  EXPECT_EQ(result.total_size, 16);
+  EXPECT_THAT(result.offsets, ElementsAre(0));
+
+  objects_assignment.object_ids = {0, 0, 0};
+  result = ObjectsToOffsets(objects_assignment);
+  EXPECT_EQ(result.total_size, 16);
+  EXPECT_THAT(result.offsets, ElementsAre(0, 0, 0));
+}
+
+TEST(Model, ManyObjectsAssignment) {
+  ObjectsAssignment<size_t> objects_assignment;
+  objects_assignment.object_sizes = {16, 8, 32, 32, 4, 16};
+  objects_assignment.object_ids = {2, 0, 2, 1, 3, 3, 1, 5};
+  OffsetsAssignment result = ObjectsToOffsets(objects_assignment);
+  EXPECT_THAT(result.offsets, ElementsAre(24, 0, 24, 16, 56, 56, 16, 92));
+}
+
 TEST(Model, EmptyRecords) {
   ObjectsAssignment<size_t> assignment;
   ASSERT_TRUE(
@@ -46,11 +75,19 @@
           .ok());
   EXPECT_TRUE(assignment.object_ids.empty());
   EXPECT_TRUE(assignment.object_sizes.empty());
+
+  OffsetsAssignment offsets_assignment;
+  ASSERT_TRUE(AssignOffsetsToTensors({}, MemoryStrategy::GREEDY_BY_SIZE,
+                                     &offsets_assignment)
+                  .ok());
+  EXPECT_TRUE(offsets_assignment.offsets.empty());
+  EXPECT_EQ(offsets_assignment.total_size, 0);
 }
 
 TEST(Model, OneRecord) {
   std::vector<TensorUsageRecord<size_t>> usage_records{
       {/*size=*/16, /*first=*/0, /*last=*/1}};
+
   ObjectsAssignment<size_t> assignment;
   ASSERT_TRUE(
       AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment)
@@ -75,6 +112,14 @@
                   .ok());
   EXPECT_THAT(assignment.object_ids, ElementsAre(0));
   EXPECT_THAT(assignment.object_sizes, ElementsAre(16));
+
+  OffsetsAssignment offsets_assignment;
+  ASSERT_TRUE(AssignOffsetsToTensors(usage_records,
+                                     MemoryStrategy::GREEDY_BY_SIZE,
+                                     &offsets_assignment)
+                  .ok());
+  EXPECT_THAT(offsets_assignment.offsets, ElementsAre(0));
+  EXPECT_EQ(offsets_assignment.total_size, 16);
 }
 
 TEST(Model, ChainRecords) {
@@ -85,6 +130,7 @@
       {/*size=*/32, /*first=*/3, /*last=*/4},
       {/*size=*/8, /*first=*/4, /*last=*/5},
   };
+
   ObjectsAssignment<size_t> assignment;
   ASSERT_TRUE(
       AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment)
@@ -109,6 +155,14 @@
                   .ok());
   EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 0, 1, 0));
   EXPECT_THAT(assignment.object_sizes, ElementsAre(64, 32));
+
+  OffsetsAssignment offsets_assignment;
+  ASSERT_TRUE(AssignOffsetsToTensors(usage_records,
+                                     MemoryStrategy::GREEDY_BY_SIZE,
+                                     &offsets_assignment)
+                  .ok());
+  EXPECT_THAT(offsets_assignment.offsets, ElementsAre(0, 64, 0, 64, 0));
+  EXPECT_EQ(offsets_assignment.total_size, 96);
 }
 
 TEST(Model, ComplexRecords) {
@@ -122,6 +176,7 @@
       {/*size=*/8, /*first=*/6, /*last=*/8},
       {/*size=*/8, /*first=*/7, /*last=*/8},
       {/*size=*/16, /*first=*/8, /*last=*/9}};
+
   ObjectsAssignment<size_t> assignment;
   ASSERT_TRUE(
       AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment)
@@ -147,6 +202,15 @@
                   .ok());
   EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 0, 3, 1, 3, 2, 0));
   EXPECT_THAT(assignment.object_sizes, ElementsAre(32, 64, 8, 8));
+
+  OffsetsAssignment offsets_assignment;
+  ASSERT_TRUE(AssignOffsetsToTensors(usage_records,
+                                     MemoryStrategy::GREEDY_BY_SIZE,
+                                     &offsets_assignment)
+                  .ok());
+  EXPECT_THAT(offsets_assignment.offsets,
+              ElementsAre(0, 32, 80, 64, 88, 0, 64, 72, 0));
+  EXPECT_EQ(offsets_assignment.total_size, 96);
 }
 
 TEST(Model, BHWCRecords) {
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 159eec5..5e8d597 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -33,6 +33,7 @@
 #include "absl/strings/string_view.h"
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/context.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
@@ -43,14 +44,12 @@
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/util.h"
 
 namespace tflite {
 namespace gpu {
 namespace {
 
-using ::absl::make_unique;
-using ::absl::StrCat;
-
 // Creates a node that consumes output from the given node. Because output need
 // to stay the same, newly created node will inherit the output from the given
 // node, which will in turn get newly created copy of output. This is necessary
@@ -77,8 +76,8 @@
 Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) {
   if (tensor.bytes % sizeof(T) != 0) {
     return InvalidArgumentError(
-        StrCat("Input data size ", tensor.bytes,
-               " is not aligned to expected type: ", sizeof(T)));
+        absl::StrCat("Input data size ", tensor.bytes,
+                     " is not aligned to expected type: ", sizeof(T)));
   }
   std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes);
   return OkStatus();
@@ -171,7 +170,7 @@
 Status SetAllDimensions<OHWI>(const TfLiteIntArray* dimensions, OHWI* shape) {
   if (dimensions->size != 4) {
     return InvalidArgumentError(
-        StrCat("Dimensions are not OHWI: ", dimensions->size));
+        absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
   }
   shape->o = dimensions->data[0];
   shape->h = dimensions->data[1];
@@ -184,7 +183,7 @@
 Status SetAllDimensions<IHWO>(const TfLiteIntArray* dimensions, IHWO* shape) {
   if (dimensions->size != 4) {
     return InvalidArgumentError(
-        StrCat("Dimensions are not IHWO: ", dimensions->size));
+        absl::StrCat("Dimensions are not IHWO: ", dimensions->size));
   }
   shape->i = dimensions->data[0];
   shape->h = dimensions->data[1];
@@ -265,7 +264,8 @@
 
   Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) const {
     if (idx >= tflite_node_->inputs->size) {
-      return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
+      return OutOfRangeError(
+          absl::StrCat("ReadValue: input tensor index: ", idx));
     }
     return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
   }
@@ -276,11 +276,11 @@
 
   Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
     if (idx >= tflite_node_->inputs->size) {
-      return OutOfRangeError(StrCat("Input tensor index: ", idx));
+      return OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
     }
     const int tensor_idx = tflite_node_->inputs->data[idx];
     if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
-      return OutOfRangeError(StrCat("Tensor index: ", tensor_idx));
+      return OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
     }
     const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
     *dimensions = *tflite_tensor.dims;
@@ -303,9 +303,9 @@
 
   Status AddOutput(const Node* node, int id) {
     if (tflite_node_->outputs->size <= id) {
-      return InvalidArgumentError(
-          StrCat("Data id ", id, " must be less than tflite node outputs size ",
-                 tflite_node_->outputs->size));
+      return InvalidArgumentError(absl::StrCat(
+          "Data id ", id, " must be less than tflite node outputs size ",
+          tflite_node_->outputs->size));
     }
     int output_tensor_idx = tflite_node_->outputs->data[id];
     Value<TensorRef<BHWC>>* value;
@@ -331,13 +331,13 @@
                               Value<TensorRef<BHWC>>** value) const {
     if (tensor_idx >= tensor_to_value_->size()) {
       return OutOfRangeError(
-          StrCat("ReadValue: input tensor index: ", tensor_idx));
+          absl::StrCat("ReadValue: input tensor index: ", tensor_idx));
     }
     if ((*tensor_to_value_)[tensor_idx] == nullptr) {
       const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
       if (tflite::IsConstantTensor(&tflite_tensor)) {
-        return NotFoundError(
-            StrCat("ReadValue: value is a constant tensor: ", tensor_idx));
+        return NotFoundError(absl::StrCat(
+            "ReadValue: value is a constant tensor: ", tensor_idx));
       }
       Value<TensorRef<BHWC>>* value = graph_->NewValue();
       RETURN_IF_ERROR(
@@ -463,7 +463,7 @@
       break;
     default:
       return NotFoundError(
-          StrCat("Unsupported fused activation: ", fused_activation));
+          absl::StrCat("Unsupported fused activation: ", fused_activation));
   }
   return OkStatus();
 }
@@ -708,7 +708,6 @@
       }
     }
     node->operation.attributes = std::move(attr);
-
     const auto* tf_options =
         reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
     if (!tf_options) {
@@ -1047,8 +1046,6 @@
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    TfLiteSubParams* tf_options;
-    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (IsOneArgumentOperation()) {
       RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
                                          /*outputs=*/1));
@@ -1058,7 +1055,9 @@
     } else {
       return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
     }
-    return IsActivationSupported(tf_options->activation);
+    TfLiteFusedActivation activation;
+    RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
+    return IsActivationSupported(activation);
   }
 
   Status Parse(const TfLiteNode* tflite_node,
@@ -1111,6 +1110,27 @@
   }
 
  private:
+  Status GetActivation(const TfLiteNode* tflite_node,
+                       TfLiteFusedActivation* activation) const {
+    if (operation_type_ == OperationType::DIV) {
+      TfLiteDivParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+      *activation = tf_options ? tf_options->activation : kTfLiteActNone;
+      return OkStatus();
+    }
+    if (operation_type_ == OperationType::SUB) {
+      TfLiteSubParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+      *activation = tf_options ? tf_options->activation : kTfLiteActNone;
+      return OkStatus();
+    }
+
+    // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
+    // TfLiteXxxParams.activation.
+    *activation = kTfLiteActNone;
+    return OkStatus();
+  }
+
   bool IsOneArgumentOperation() const {
     switch (operation_type_) {
       case OperationType::ABS:
@@ -1715,7 +1735,7 @@
                const TfLiteRegistration* registration, GraphFloat32* graph,
                ObjectReader* reader) final {
     Node* node = graph->NewNode();
-    node->operation.type = ToString(OperationType::SOFT_MAX);
+    node->operation.type = ToString(OperationType::SOFTMAX);
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
 
@@ -1731,8 +1751,7 @@
       // auto mul_node = reader->NewPassthroughNode(node);
       // mul_node->operation.type = ToString(OperationType::MUL);
     }
-    // TODO(impjdi): Rename to SoftmaxAttributes.
-    SoftMaxAttributes attr;
+    SoftmaxAttributes attr;
     attr.axis = Axis::CHANNELS;  // always by channels
     node->operation.attributes = attr;
     return OkStatus();
@@ -2096,86 +2115,89 @@
   const absl::string_view custom_name = registration->custom_name;
   switch (builtin_code) {
     case kTfLiteBuiltinAbs:
-      return make_unique<ElementwiseOperationParser>(OperationType::ABS);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::ABS);
     case kTfLiteBuiltinAdd:
-      return make_unique<AddOperationParser>();
+      return absl::make_unique<AddOperationParser>();
     case kTfLiteBuiltinAveragePool2d:
-      return make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
+      return absl::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
     case kTfLiteBuiltinConcatenation:
-      return make_unique<ConcatenationOperationParser>();
+      return absl::make_unique<ConcatenationOperationParser>();
     case kTfLiteBuiltinConv2d:
-      return make_unique<Conv2DOperationParser>();
+      return absl::make_unique<Conv2DOperationParser>();
     case kTfLiteBuiltinCos:
-      return make_unique<ElementwiseOperationParser>(OperationType::COS);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::COS);
     case kTfLiteBuiltinDepthwiseConv2d:
-      return make_unique<DepthwiseConvolutionOperationParser>();
+      return absl::make_unique<DepthwiseConvolutionOperationParser>();
     case kTfLiteBuiltinDiv:
-      return make_unique<ElementwiseOperationParser>(OperationType::DIV);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::DIV);
     case kTfLiteBuiltinFullyConnected:
-      return make_unique<FullyConnectedOperationParser>();
+      return absl::make_unique<FullyConnectedOperationParser>();
     case kTfLiteBuiltinHardSwish:
-      return make_unique<HardSwishOperationParser>();
+      return absl::make_unique<HardSwishOperationParser>();
     case kTfLiteBuiltinLogistic:
-      return make_unique<ElementwiseOperationParser>(OperationType::SIGMOID);
+      return absl::make_unique<ElementwiseOperationParser>(
+          OperationType::SIGMOID);
     case kTfLiteBuiltinLog:
-      return make_unique<ElementwiseOperationParser>(OperationType::LOG);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::LOG);
     case kTfLiteBuiltinLstm:
-      return make_unique<LSTMOperationParser>();
+      return absl::make_unique<LSTMOperationParser>();
     case kTfLiteBuiltinMaxPool2d:
-      return make_unique<Pooling2DOperationParser>(PoolingType::MAX);
+      return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
     case kTfLiteBuiltinMul:
-      return make_unique<MulOperationParser>();
+      return absl::make_unique<MulOperationParser>();
     case kTfLiteBuiltinPad:
-      return make_unique<PadOperationParser>();
+      return absl::make_unique<PadOperationParser>();
     case kTfLiteBuiltinPow:
-      return make_unique<ElementwiseOperationParser>(OperationType::POW);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::POW);
     case kTfLiteBuiltinRelu:
-      return make_unique<ReLUOperationParser>(0);
+      return absl::make_unique<ReLUOperationParser>(0);
     case kTfLiteBuiltinRelu6:
-      return make_unique<ReLUOperationParser>(6);
+      return absl::make_unique<ReLUOperationParser>(6);
     case kTfLiteBuiltinLeakyRelu:
-      return make_unique<ReLUOperationParser>(0);
+      return absl::make_unique<ReLUOperationParser>(0);
     case kTfLiteBuiltinPrelu:
-      return make_unique<PReLUOperationParser>();
+      return absl::make_unique<PReLUOperationParser>();
     case kTfLiteBuiltinReshape:
-      return make_unique<ReshapeOperationParser>();
+      return absl::make_unique<ReshapeOperationParser>();
     case kTfLiteBuiltinResizeBilinear:
-      return make_unique<ResizeBilinearOperationParser>();
+      return absl::make_unique<ResizeBilinearOperationParser>();
     case kTfLiteBuiltinRsqrt:
-      return make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
+      return absl::make_unique<ElementwiseOperationParser>(
+          OperationType::RSQRT);
     case kTfLiteBuiltinSin:
-      return make_unique<ElementwiseOperationParser>(OperationType::SIN);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::SIN);
     case kTfLiteBuiltinSoftmax:
-      return make_unique<SoftmaxOperationParser>();
+      return absl::make_unique<SoftmaxOperationParser>();
     case kTfLiteBuiltinStridedSlice:
-      return make_unique<StridedSliceOperationParser>();
+      return absl::make_unique<StridedSliceOperationParser>();
     case kTfLiteBuiltinSqrt:
-      return make_unique<ElementwiseOperationParser>(OperationType::SQRT);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
     case kTfLiteBuiltinSquare:
-      return make_unique<ElementwiseOperationParser>(OperationType::SQUARE);
+      return absl::make_unique<ElementwiseOperationParser>(
+          OperationType::SQUARE);
     case kTfLiteBuiltinSquaredDifference:
-      return make_unique<ElementwiseOperationParser>(
+      return absl::make_unique<ElementwiseOperationParser>(
           OperationType::SQUARED_DIFF);
     case kTfLiteBuiltinSub:
-      return make_unique<ElementwiseOperationParser>(OperationType::SUB);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
     case kTfLiteBuiltinTanh:
-      return make_unique<ElementwiseOperationParser>(OperationType::TANH);
+      return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
     case kTfLiteBuiltinTransposeConv:
-      return make_unique<TransposeConvOperationParser>();
+      return absl::make_unique<TransposeConvOperationParser>();
 
     case kTfLiteBuiltinCustom:
       if (custom_name == "Convolution2DTransposeBias") {
-        return make_unique<Convolution2DTransposeBiasParser>();
+        return absl::make_unique<Convolution2DTransposeBiasParser>();
       }
       if (custom_name == "MaxPoolingWithArgmax2D") {
-        return make_unique<Pooling2DOperationParser>(PoolingType::MAX);
+        return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
       }
       if (custom_name == "MaxUnpooling2D") {
-        return make_unique<Unpooling2DOperationParser>();
+        return absl::make_unique<Unpooling2DOperationParser>();
       }
       break;
   }
-  return make_unique<UnsupportedOperationParser>();
+  return absl::make_unique<UnsupportedOperationParser>();
 }
 
 }  // namespace
@@ -2220,12 +2242,114 @@
                               TfLiteRegistration** registration) {
   if (context->GetNodeAndRegistration(context, node_id, tflite_node,
                                       registration) != kTfLiteOk) {
-    return InvalidArgumentError(
-        StrCat("Couldn't get node and registration info for op: ", node_id));
+    return InvalidArgumentError(absl::StrCat(
+        "Couldn't get node and registration info for op: ", node_id));
   }
   return OkStatus();
 }
 
+TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
+  TfLiteIntArray* execution_plan = nullptr;
+  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
+    context->ReportError(context, "Unable to get graph execution plan.");
+    return nullptr;
+  }
+  std::set<std::string> errors;
+  std::unordered_map<int, int> dequant_nodes;
+  std::vector<int> ops_to_replace;
+  std::vector<int> dequant_nodes_to_save;
+
+  // Map the output tensor of a Dequantize nodes to its input tensor.
+  std::unordered_map<int, int> node_map;
+  for (int i = 0; i < execution_plan->size; ++i) {
+    bool replace_node = false;
+    // Keep track of any inputs from a Dequantize node.
+    std::vector<int> inputs_from_dequant;
+    std::vector<int> orig_inputs;
+
+    const int node_id = execution_plan->data[i];
+    TfLiteNode* node = nullptr;
+    TfLiteRegistration* registration = nullptr;
+    auto status =
+        GetNodeAndRegistration(context, node_id, &node, &registration);
+    if (!status.ok()) {
+      context->ReportError(context, status.error_message().c_str());
+      return nullptr;
+    }
+    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
+        context->tensors[node->inputs->data[0]].type ==
+            TfLiteType::kTfLiteFloat16) {
+      // Record the output->input mapping for the op.
+      node_map[node->outputs->data[0]] = node->inputs->data[0];
+      // For now, add the node to the list of ops to replace.
+      ops_to_replace.push_back(node_id);
+      // Record the dequant node id, indexed by output id.
+      dequant_nodes[node->outputs->data[0]] = node_id;
+      continue;
+    }
+    TfLiteIntArray* inputs = node->inputs;
+    // Fix the node's inputs (i.e. prune out the preceding dequantize node)
+    // in order to test if it is supported on the GPU.
+    for (int j = 0; j < inputs->size; ++j) {
+      orig_inputs.push_back(inputs->data[j]);
+      if (node_map.find(inputs->data[j]) != node_map.end()) {
+        inputs_from_dequant.push_back(dequant_nodes[inputs->data[j]]);
+        // Remap inputs of this node to the inputs of the preceding dequant.
+        inputs->data[j] = node_map[inputs->data[j]];
+      }
+    }
+    status = IsSupported(context, node, registration);
+    if (status.ok() &&
+        // TODO(eignasheva): resolve sub operation support for metal delegate
+        // registration->builtin_code != kTfLiteBuiltinSub &&
+        IsAllFloatTensors(context, node->inputs) &&
+        IsAllFloatTensors(context, node->outputs)) {
+      if (errors.empty()) {
+        replace_node = true;
+        ops_to_replace.push_back(i);
+      }
+    } else {
+      // Unable to replace this node. Restore the inputs to the original
+      // if they were modified.
+      if (!inputs_from_dequant.empty()) {
+        TfLiteIntArray* inputs = node->inputs;
+        for (int j = 0; j < inputs->size; ++j) {
+          inputs->data[j] = orig_inputs[j];
+        }
+      }
+      errors.insert(GetOpNameByRegistration(registration) + ": " +
+                    status.error_message());
+    }
+    // if any input is the output of a dequantize node AND we failed to
+    // replace this op, mark the corresponding dequantize node as a node to
+    // save.
+    if (!replace_node && !inputs_from_dequant.empty()) {
+      dequant_nodes_to_save.insert(dequant_nodes_to_save.end(),
+                                   inputs_from_dequant.begin(),
+                                   inputs_from_dequant.end());
+    }
+  }
+  if (!errors.empty()) {
+    std::string unsupported = absl::StrJoin(errors, "\n");
+    std::string error_message =
+        "Next operations are not supported by GPU delegate:\n" + unsupported +
+        "\nFirst " + std::to_string(ops_to_replace.size()) +
+        " operations will run on the GPU, and the remaining " +
+        std::to_string(execution_plan->size - ops_to_replace.size()) +
+        " on the CPU.";
+    context->ReportError(context, error_message.c_str());
+  }
+  // Pop all dequantize nodes that must be preserved.
+  for (int i = 0; i < dequant_nodes_to_save.size(); ++i) {
+    auto it = std::find(ops_to_replace.begin(), ops_to_replace.end(),
+                        dequant_nodes_to_save[i]);
+    if (it != ops_to_replace.end()) {
+      ops_to_replace.erase(it);
+    }
+  }
+  return ConvertVectorToTfLiteIntArray(ops_to_replace);
+}
+
 // TODO(impjdi): Check number of input/output tensors and their dimensions.
 // TODO(impjdi): Check ops' parameters.
 TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
@@ -2234,27 +2358,38 @@
     context->ReportError(context, "Unable to get graph execution plan.");
     return nullptr;
   }
-  TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
-  subgraph->size = 0;
-  std::set<std::string> errors;
 
-  // Map the output tensor of a Dequantize nodes to its input tensor.
-  std::unordered_map<int, int> node_map;
+  // Dispatch to another function if graph has Dequantize nodes.
   for (int i = 0; i < execution_plan->size; ++i) {
+    const int node_id = execution_plan->data[i];
     TfLiteNode* node = nullptr;
     TfLiteRegistration* registration = nullptr;
-    auto status = GetNodeAndRegistration(context, i, &node, &registration);
+    auto status =
+        GetNodeAndRegistration(context, node_id, &node, &registration);
     if (!status.ok()) {
       context->ReportError(context, status.error_message().c_str());
-      TfLiteIntArrayFree(subgraph);
       return nullptr;
     }
     if (registration->builtin_code == kTfLiteBuiltinDequantize &&
         context->tensors[node->inputs->data[0]].type ==
             TfLiteType::kTfLiteFloat16) {
-      // Record the output->input mapping for the op.
-      node_map[node->outputs->data[0]] = node->inputs->data[0];
-      continue;
+      return GetOpsToReplaceFromGraphWithDequantize(context);
+    }
+  }
+
+  // No Dequantize nodes. Iterate through graph and find ops to replace.
+  TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
+  subgraph->size = 0;
+  std::set<std::string> errors;
+  for (int i = 0; i < execution_plan->size; ++i) {
+    const int node_id = execution_plan->data[i];
+    TfLiteNode* node;
+    TfLiteRegistration* registration;
+    auto status =
+        GetNodeAndRegistration(context, node_id, &node, &registration);
+    if (!status.ok()) {
+      context->ReportError(context, status.error_message().c_str());
+      return nullptr;
     }
     status = IsSupported(context, node, registration);
     if (status.ok() &&
@@ -2262,18 +2397,10 @@
         // registration->builtin_code != kTfLiteBuiltinSub &&
         IsAllFloatTensors(context, node->inputs) &&
         IsAllFloatTensors(context, node->outputs)) {
-      // Fix the node's inputs (i.e. prune out the preceding dequantize node)
-      // if the op is supported.
-      TfLiteIntArray* inputs = node->inputs;
-      for (int j = 0; j < inputs->size; ++j) {
-        if (node_map.find(inputs->data[j]) != node_map.end()) {
-          inputs->data[j] = node_map[inputs->data[j]];
-        }
-      }
-      if (errors.empty()) subgraph->data[subgraph->size++] = i;
+      if (errors.empty()) subgraph->data[subgraph->size++] = node_id;
     } else {
-      errors.insert(GetOpNameByRegistration(registration) + ": " +
-                    status.error_message());
+      errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ",
+                                 status.error_message()));
     }
   }
   if (!errors.empty()) {
@@ -2292,32 +2419,42 @@
                   const TfLiteDelegateParams* delegate_params,
                   GraphFloat32* graph) {
   std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
+  std::vector<int> tflite_nodes;
   for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
     TfLiteNode* tflite_node = nullptr;
     TfLiteRegistration* registration = nullptr;
     RETURN_IF_ERROR(GetNodeAndRegistration(
         context, delegate_params->nodes_to_replace->data[i], &tflite_node,
         &registration));
+    if (registration->builtin_code == kTfLiteBuiltinDequantize) {
+      // Ignore Dequantize nodes.
+      continue;
+    }
     auto op_parser = NewOperationParser(registration);
     if (!op_parser) {
       return UnimplementedError(
-          StrCat("Operation ", registration->builtin_code, "(",
-                 registration->custom_name,
-                 ") is not supported by TFLite GPU Delegate."));
+          absl::StrCat("Operation ", registration->builtin_code, "(",
+                       registration->custom_name,
+                       ") is not supported by TFLite GPU Delegate."));
     }
     operations.push_back(std::move(op_parser));
+    tflite_nodes.push_back(i);
   }
   std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
                                                        nullptr);
-  for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
-    TfLiteNode* tflite_node = nullptr;
-    TfLiteRegistration* registration = nullptr;
+  for (int i = 0; i < operations.size(); ++i) {
+    TfLiteNode* tflite_node;
+    TfLiteRegistration* registration;
     RETURN_IF_ERROR(GetNodeAndRegistration(
-        context, delegate_params->nodes_to_replace->data[i], &tflite_node,
-        &registration));
+        context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
+        &tflite_node, &registration));
     ObjectReader reader(graph, context, tflite_node, &tensor_to_value);
-    RETURN_IF_ERROR(
-        operations[i]->Parse(tflite_node, registration, graph, &reader));
+    const auto status =
+        operations[i]->Parse(tflite_node, registration, graph, &reader);
+    if (!status.ok()) {
+      return InternalError(absl::StrCat(GetOpNameByRegistration(registration),
+                                        ": ", status.error_message()));
+    }
   }
   return OkStatus();
 }
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
index 31c7c57..f737612 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
@@ -212,7 +212,8 @@
   //   t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
   //   t2 (FP16) -> DequantNode -> t3 (FP32) --/
   //
-  // After pruning, the graph has one node:
+  // OpsToReplace should choose all three nodes for replacement, and
+  // the graph on the GPU will look like this (no Dequants):
   //
   //   t0 (FP16) --> Add -> t4
   //   t2 (FP16) --/
@@ -237,11 +238,11 @@
 
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
 
-  // Just one node left.
-  EXPECT_EQ(ops_to_replace->size, 1);
+  // Replace all nodes.
+  EXPECT_EQ(ops_to_replace->size, 3);
   TfLiteNode* node = nullptr;
   TfLiteRegistration* registration = nullptr;
-  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
+  context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
                                   &registration);
   EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
             TfLiteType::kTfLiteFloat16);
@@ -416,6 +417,174 @@
   TfLiteIntArrayFree(ops_to_replace);
 }
 
+class InterpreterMultiNode {
+ public:
+  InterpreterMultiNode() {
+    void* builtin_data = malloc(sizeof(int));
+    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
+    EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
+    EXPECT_EQ(interpreter_.SetOutputs({6, 7}), kTfLiteOk);
+
+    // Add 3 Dequantize Nodes with float16 input.
+    for (int i = 0; i < 3; ++i) {
+      const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
+                                              /*free=*/nullptr,
+                                              /*prepare=*/nullptr,
+                                              /*invoke=*/nullptr,
+                                              /*profiling_string=*/nullptr,
+                                              kTfLiteBuiltinDequantize};
+      EXPECT_EQ(interpreter_.AddNodeWithParameters(
+                    /*inputs=*/{i}, /*outputs=*/{i + 3}, /*init_data=*/nullptr,
+                    /*init_data_size=*/0, /*builtin_data=*/nullptr,
+                    /*registration=*/&reg_dequant),
+                kTfLiteOk);
+    }
+
+    // Add the ADD op node that GPU delegate supports.
+    const TfLiteRegistration reg_add0 = {
+        [](TfLiteContext* context, const char* buffer, size_t length) {
+          return reinterpret_cast<void*>(new int(1));
+        },
+        [](TfLiteContext* context, void* buffer) {
+          delete reinterpret_cast<int*>(buffer);
+        },
+        nullptr,
+        nullptr,
+        nullptr,
+        kTfLiteBuiltinAdd};
+
+    EXPECT_EQ(interpreter_.AddNodeWithParameters(
+                  /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
+                  /*init_data_size=*/0,
+                  /*builtin_data=*/builtin_data,
+                  /*registration=*/&reg_add0),
+              kTfLiteOk);
+
+    // Add the GreaterThan op node that GPU delegate doesn't support.
+    const TfLiteRegistration reg_greater = {
+        [](TfLiteContext* context, const char* buffer, size_t length) {
+          return reinterpret_cast<void*>(new int(1));
+        },
+        [](TfLiteContext* context, void* buffer) {
+          delete reinterpret_cast<int*>(buffer);
+        },
+        nullptr,
+        nullptr,
+        nullptr,
+        kTfLiteBuiltinGreater};
+
+    EXPECT_EQ(interpreter_.AddNodeWithParameters(
+                  /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
+                  /*init_data_size=*/0,
+                  /*builtin_data=*/builtin_data,
+                  /*registration=*/&reg_greater),
+              kTfLiteOk);
+
+    const std::vector<int> dims = {1};
+    TfLiteQuantization quantization;
+    quantization.type = kTfLiteNoQuantization;
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            1, TfLiteType::kTfLiteFloat16, "t1", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            6, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
+        kTfLiteOk);
+    EXPECT_EQ(
+        interpreter_.SetTensorParametersReadWrite(
+            7, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
+        kTfLiteOk);
+    exec_plan_ = TfLiteIntArrayCreate(5);
+    exec_plan_->data[0] = 0;
+    exec_plan_->data[1] = 1;
+    exec_plan_->data[2] = 2;
+    exec_plan_->data[3] = 3;
+    exec_plan_->data[4] = 4;
+  }
+
+  ~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); }
+
+  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
+  TfLiteIntArray* exec_plan() const { return exec_plan_; }
+
+ private:
+  Interpreter interpreter_;
+  TfLiteIntArray* exec_plan_;
+};
+
+InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
+
+TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) {
+  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
+  // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
+  //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
+  //   t1 (FP16) --> Dequant --> t4 (FP32) --/
+  //                                       --\
+  //   t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
+  //
+  //  OpsToReplace should replace the 'Add' op and the Dequant outputing
+  //  t5, but leave the other Dequant nodes because 'Greater' must run
+  //  on the CPU.
+  TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
+
+  // These functions are meant to be called inside delegates. Swap out
+  // for similar functions to permit direct calling of GetOpsToReplace.
+  context->GetExecutionPlan = [](struct TfLiteContext* context,
+                                 TfLiteIntArray** execution_plan) {
+    *execution_plan = interpreter_mn->exec_plan();
+    return kTfLiteOk;
+  };
+  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
+                                       TfLiteNode** node,
+                                       TfLiteRegistration** registration) {
+    auto& node_and_reg =
+        interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index];
+    *node = &node_and_reg.first;
+    *registration = &node_and_reg.second;
+    return kTfLiteOk;
+  };
+
+  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
+
+  EXPECT_EQ(ops_to_replace->size, 2);
+  // Op at index 2 is the Dequant op (t3 -> t5).
+  EXPECT_EQ(ops_to_replace->data[0], 2);
+  // Op at index 3 is the Add op.
+  EXPECT_EQ(ops_to_replace->data[1], 3);
+
+  TfLiteNode* node = nullptr;
+  TfLiteRegistration* registration = nullptr;
+  // Verify that Add op has fp16 inputs.
+  context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
+                                  &registration);
+  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
+            TfLiteType::kTfLiteFloat16);
+  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
+            TfLiteType::kTfLiteFloat16);
+  TfLiteIntArrayFree(ops_to_replace);
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index eb1f018..8ce1202 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -106,8 +106,8 @@
       return "sin";
     case OperationType::SLICE:
       return "slice";
-    case OperationType::SOFT_MAX:
-      return "soft_max";
+    case OperationType::SOFTMAX:
+      return "softmax";
     case OperationType::SPACE_TO_BATCH:
       return "space_to_batch";
     case OperationType::SQRT:
@@ -158,7 +158,7 @@
           {"sigmoid", OperationType::SIGMOID},
           {"sin", OperationType::SIN},
           {"slice", OperationType::SLICE},
-          {"soft_max", OperationType::SOFT_MAX},
+          {"softmax", OperationType::SOFTMAX},
           {"sqrt", OperationType::SQRT},
           {"square", OperationType::SQUARE},
           {"subtract", OperationType::SUB},
diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h
index 5e564f6..89f4610 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.h
+++ b/tensorflow/lite/delegates/gpu/common/operations.h
@@ -63,7 +63,7 @@
   SIGMOID,
   SIN,
   SLICE,
-  SOFT_MAX,
+  SOFTMAX,
   SPACE_TO_BATCH,
   SQRT,
   SQUARE,
@@ -239,7 +239,7 @@
       alpha;
 };
 
-struct SoftMaxAttributes {
+struct SoftmaxAttributes {
   Axis axis = Axis::UNKNOWN;
 };
 
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
index cf7bbc1..586c7a3 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
@@ -184,7 +184,7 @@
       const float add_value = add ? add->data[s] : *add_scalar;
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
           attr->bias.data[d] += attr->weights.data[index] * add_value;
         }
       }
@@ -206,7 +206,7 @@
       const int d = s * attr->weights.shape.o + g;
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
           attr->bias.data[d] += attr->weights.data[index] * add_value;
         }
       }
@@ -225,7 +225,7 @@
   for (int d = 0; d < attr->weights.shape.o; ++d) {
     for (int s = 0; s < attr->weights.shape.i; ++s) {
       const float add_value = add ? add->data[s] : *add_scalar;
-      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
+      const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
       attr->bias.data[d] += attr->weights.data[index] * add_value;
     }
   }
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
index 3090c3f..fc351db 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
@@ -164,7 +164,7 @@
     for (int s = 0; s < attr->weights.shape.i; ++s) {
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -186,7 +186,7 @@
       const float multiplier = mul ? mul->data[d] : *mul_scalar;
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -207,7 +207,7 @@
     for (int s = 0; s < attr->weights.shape.i; ++s) {
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -225,7 +225,7 @@
   for (int d = 0; d < attr->weights.shape.o; ++d) {
     const float multiplier = mul ? mul->data[d] : *mul_scalar;
     for (int s = 0; s < attr->weights.shape.i; ++s) {
-      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
+      const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
       attr->weights.data[index] *= multiplier;
     }
     if (!attr->bias.data.empty()) {
@@ -243,7 +243,7 @@
     for (int d = 0; d < attr->weights.shape.o; ++d) {
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -261,7 +261,7 @@
     for (int g = 0; g < attr->weights.shape.o; ++g) {
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -279,7 +279,7 @@
     for (int d = 0; d < attr->weights.shape.o; ++d) {
       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
-          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
+          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
           attr->weights.data[index] *= multiplier;
         }
       }
@@ -294,7 +294,7 @@
   for (int s = 0; s < attr->weights.shape.i; ++s) {
     const float multiplier = mul ? mul->data[s] : *mul_scalar;
     for (int d = 0; d < attr->weights.shape.o; ++d) {
-      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
+      const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
       attr->weights.data[index] *= multiplier;
     }
   }
diff --git a/tensorflow/lite/delegates/gpu/gl/BUILD b/tensorflow/lite/delegates/gpu/gl/BUILD
index b3385ea..7983833 100644
--- a/tensorflow/lite/delegates/gpu/gl/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/BUILD
@@ -16,7 +16,7 @@
         ":compiler",
         ":compiler_options",
         ":gl_call",
-        ":gpu_info",
+        ":request_gpu_info",
         ":node_shader",
         ":object",
         ":object_manager",
@@ -48,8 +48,8 @@
         ":gl_call",
         ":gl_program",
         ":gl_sync",
-        ":gpu_info",
         ":portable",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "@com_google_absl//absl/memory",
@@ -80,9 +80,9 @@
     deps = [
         ":compiler_options",
         ":float16_conversions",
-        ":gpu_info",
         ":node_shader",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
@@ -103,7 +103,6 @@
     name = "compiler_options",
     hdrs = ["compiler_options.h"],
     deps = [
-        ":gpu_info",
         ":object",
     ],
 )
@@ -128,8 +127,8 @@
         ":egl_context",
         ":egl_surface",
         ":gl_call",
-        ":gpu_info",
         ":portable",
+        ":request_gpu_info",
         "//tensorflow/lite/delegates/gpu/common:status",
         "@com_google_absl//absl/memory",
     ],
@@ -272,18 +271,6 @@
     ],
 )
 
-cc_library(
-    name = "gpu_info",
-    srcs = ["gpu_info.cc"],
-    hdrs = ["gpu_info.h"],
-    deps = [
-        ":gl_errors",
-        ":portable",
-        "//tensorflow/lite/delegates/gpu/common:status",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
 flatbuffer_cc_library(
     name = "metadata_cc_fbs",
     srcs = ["metadata.fbs"],
@@ -298,9 +285,9 @@
     hdrs = ["node_shader.h"],
     deps = [
         ":compiler_options",
-        ":gpu_info",
         ":object",
         ":variable",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
@@ -345,6 +332,19 @@
 )
 
 cc_library(
+    name = "request_gpu_info",
+    srcs = ["request_gpu_info.cc"],
+    hdrs = ["request_gpu_info.h"],
+    deps = [
+        ":gl_errors",
+        ":portable",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
     name = "runtime",
     srcs = ["runtime.cc"],
     hdrs = ["runtime.h"],
@@ -356,7 +356,6 @@
         ":gl_program",
         ":gl_shader",
         ":gl_texture",
-        ":gpu_info",
         ":object",
         ":object_manager",
         ":portable",
@@ -364,6 +363,7 @@
         ":stats",
         ":variable",
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/gl/runtime:shared_buffer",
diff --git a/tensorflow/lite/delegates/gpu/gl/api.cc b/tensorflow/lite/delegates/gpu/gl/api.cc
index 2767bc3..fc9fcae 100644
--- a/tensorflow/lite/delegates/gpu/gl/api.cc
+++ b/tensorflow/lite/delegates/gpu/gl/api.cc
@@ -31,9 +31,9 @@
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 #include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/runtime.h"
 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
 
diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.cc b/tensorflow/lite/delegates/gpu/gl/command_queue.cc
index 8e0e085..62f40bf 100644
--- a/tensorflow/lite/delegates/gpu/gl/command_queue.cc
+++ b/tensorflow/lite/delegates/gpu/gl/command_queue.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
 
 #include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.h b/tensorflow/lite/delegates/gpu/gl/command_queue.h
index bf313b4..a4c2100 100644
--- a/tensorflow/lite/delegates/gpu/gl/command_queue.h
+++ b/tensorflow/lite/delegates/gpu/gl/command_queue.h
@@ -18,10 +18,10 @@
 
 #include <memory>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc
index 12ee49d..cef8139 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc
@@ -24,6 +24,7 @@
 #include "absl/memory/memory.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.h b/tensorflow/lite/delegates/gpu/gl/compiler.h
index 3b69211..e8b4348 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler.h
+++ b/tensorflow/lite/delegates/gpu/gl/compiler.h
@@ -20,11 +20,11 @@
 #include <memory>
 #include <unordered_set>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD
index 6ff3457..5a2ba10 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD
@@ -29,40 +29,12 @@
 )
 
 cc_library(
-    name = "parameter_accessor",
-    srcs = ["parameter_accessor.cc"],
-    hdrs = ["parameter_accessor.h"],
-    deps = [
-        ":preprocessor",
-        "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/gl:variable",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:variant",
-    ],
-)
-
-cc_test(
-    name = "parameter_accessor_test",
-    srcs = ["parameter_accessor_test.cc"],
-    tags = [
-        "local",
-        "tflite_not_portable_ios",
-    ],
-    deps = [
-        ":parameter_accessor",
-        "//tensorflow/lite/delegates/gpu/common:types",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_library(
     name = "object_accessor",
     srcs = ["object_accessor.cc"],
     hdrs = ["object_accessor.h"],
     deps = [
-        ":parameter_accessor",
         ":preprocessor",
+        ":variable_accessor",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/gl:object",
@@ -80,7 +52,7 @@
     ],
     deps = [
         ":object_accessor",
-        ":parameter_accessor",
+        ":variable_accessor",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/gl:variable",
         "@com_google_absl//absl/types:variant",
@@ -106,13 +78,13 @@
     deps = [
         ":compiled_node",
         ":object_accessor",
-        ":parameter_accessor",
         ":preprocessor",
         ":shader_code",
+        ":variable_accessor",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/gl:compiler_options",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
         "//tensorflow/lite/delegates/gpu/gl:object",
         "//tensorflow/lite/delegates/gpu/gl:variable",
         "@com_google_absl//absl/strings",
@@ -172,8 +144,8 @@
     hdrs = ["rename.h"],
     deps = [
         ":object_accessor",
-        ":parameter_accessor",
         ":preprocessor",
+        ":variable_accessor",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/gl:node_shader",
         "//tensorflow/lite/delegates/gpu/gl:object",
@@ -198,4 +170,32 @@
     ],
 )
 
+cc_library(
+    name = "variable_accessor",
+    srcs = ["variable_accessor.cc"],
+    hdrs = ["variable_accessor.h"],
+    deps = [
+        ":preprocessor",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/gl:variable",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:variant",
+    ],
+)
+
+cc_test(
+    name = "variable_accessor_test",
+    srcs = ["variable_accessor_test.cc"],
+    tags = [
+        "local",
+        "tflite_not_portable_ios",
+    ],
+    deps = [
+        ":variable_accessor",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 tflite_portable_test_suite()
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc
index fff15b4..e874a24 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc
@@ -441,21 +441,21 @@
   void operator()(uint32_t) const {}
 
   void operator()(const uint2& size) const {
-    parameters->AddParameter(
+    variable_accessor->AddUniformParameter(
         {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
   }
 
   // p1 and p2 are padding. For some reason buffer does not map correctly
   // without it.
   void operator()(const uint3& size) const {
-    parameters->AddParameter(
+    variable_accessor->AddUniformParameter(
         {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
-    parameters->AddParameter(
+    variable_accessor->AddUniformParameter(
         {absl::StrCat(object_name, "_h"), static_cast<int32_t>(size.y)});
   }
 
   absl::string_view object_name;
-  ParameterAccessor* parameters;
+  VariableAccessor* variable_accessor;
 };
 
 // Adds necessary parameters to parameter accessor that represent object size
@@ -464,7 +464,7 @@
 //  - 2D : 'int object_name_w'
 //  - 3D : 'int object_name_w' + 'int object_name_h'
 void AddSizeParameters(absl::string_view object_name, const Object& object,
-                       ParameterAccessor* parameters) {
+                       VariableAccessor* parameters) {
   absl::visit(SizeParametersAdder{object_name, parameters}, object.size);
 }
 
@@ -533,7 +533,7 @@
   auto status = GenerateReadAccessor(it->second, element, sampler_textures_,
                                      output, &requires_sizes);
   if (requires_sizes) {
-    AddSizeParameters(it->first, it->second, parameter_accessor_);
+    AddSizeParameters(it->first, it->second, variable_accessor_);
   }
   return status;
 }
@@ -555,7 +555,7 @@
   auto status = GenerateWriteAccessor(it->second, element, value, output,
                                       &requires_sizes);
   if (requires_sizes) {
-    AddSizeParameters(it->first, it->second, parameter_accessor_);
+    AddSizeParameters(it->first, it->second, variable_accessor_);
   }
   return status;
 }
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h
index e5bf628..78e7a2f 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h
@@ -20,8 +20,8 @@
 #include <unordered_map>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 
 namespace tflite {
@@ -54,15 +54,15 @@
 //
 class ObjectAccessor : public InlineRewrite {
  public:
-  ObjectAccessor(bool is_mali, ParameterAccessor* parameter_accessor)
-      : ObjectAccessor(is_mali, /*sampler_textures=*/false,
-                       parameter_accessor) {}
+  ObjectAccessor(bool is_mali, VariableAccessor* variable_accessor)
+      : ObjectAccessor(is_mali, /*sampler_textures=*/false, variable_accessor) {
+  }
 
   ObjectAccessor(bool is_mali, bool sampler_textures,
-                 ParameterAccessor* parameter_accessor)
+                 VariableAccessor* variable_accessor)
       : is_mali_(is_mali),
         sampler_textures_(sampler_textures),
-        parameter_accessor_(parameter_accessor) {}
+        variable_accessor_(variable_accessor) {}
 
   RewriteStatus Rewrite(absl::string_view input, std::string* output) final;
 
@@ -89,7 +89,7 @@
 
   const bool is_mali_;
   const bool sampler_textures_;
-  ParameterAccessor* parameter_accessor_;
+  VariableAccessor* variable_accessor_;
 };
 
 // Implementation details below.
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc
index 0b04210..c344d8f 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc
@@ -22,7 +22,7 @@
 #include <gtest/gtest.h>
 #include "absl/types/variant.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
 
 namespace tflite {
@@ -46,101 +46,101 @@
 namespace {
 
 TEST(Preprocessor, CornerCases) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   std::string result;
   ASSERT_EQ(accessor.Rewrite("", &result), RewriteStatus::NOT_RECOGNIZED);
   ASSERT_EQ(accessor.Rewrite("=", &result), RewriteStatus::NOT_RECOGNIZED);
 }
 
 TEST(Preprocessor, ReadFromBuffer) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "obj.data[i]");
 }
 
 TEST(Preprocessor, ReadFromBufferLinear) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector<float>{1.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "obj.data[i]");
 }
 
 TEST(Preprocessor, ReadFromBufferByIndex) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector<float>{1.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[x,y + 5,z]", &result),
             RewriteStatus::SUCCESS);
-  EXPECT_THAT(parameters.GetUniformParameters(),
+  EXPECT_THAT(variable_accessor.GetUniformParameters(),
               testing::UnorderedElementsAre(Variable{"obj_w", 1},
                                             Variable{"obj_h", 2}));
   ASSERT_EQ(result, "obj.data[x + $obj_w$ * (y + 5 + $obj_h$ * (z))]");
 }
 
 TEST(Preprocessor, ReadFromTexture) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyTexture(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[i,j,k]", &result), RewriteStatus::SUCCESS);
   // textures don't need extra variables to be stored for indexed access
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "imageLoad(obj, ivec3(i, j, k))");
 }
 
 TEST(Preprocessor, ReadFromTexture1D) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "imageLoad(obj, ivec2(i, 0))");
 }
 
 TEST(Preprocessor, WriteToBuffer) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite(" obj[i]  =value", &result),
             RewriteStatus::SUCCESS);
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "obj.data[i] = value");
 }
 
 TEST(Preprocessor, WriteToBufferByIndex) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyBuffer(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite(" obj[i,j,k]  =value", &result),
             RewriteStatus::SUCCESS);
-  EXPECT_THAT(parameters.GetUniformParameters(),
+  EXPECT_THAT(variable_accessor.GetUniformParameters(),
               testing::UnorderedElementsAre(Variable{"obj_w", 1},
                                             Variable{"obj_h", 2}));
   ASSERT_EQ(result, "obj.data[i + $obj_w$ * (j + $obj_h$ * (k))] = value");
 }
 
 TEST(Preprocessor, WriteToTexture) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
   std::string result;
@@ -150,20 +150,20 @@
 }
 
 TEST(Preprocessor, WriteToTexture1D) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0})));
   std::string result;
   EXPECT_EQ(accessor.Rewrite("obj[i]= value ", &result),
             RewriteStatus::SUCCESS);
-  EXPECT_TRUE(parameters.GetUniformParameters().empty());
+  EXPECT_TRUE(variable_accessor.GetUniformParameters().empty());
   ASSERT_EQ(result, "imageStore(obj, ivec2(i, 0), value)");
 }
 
 TEST(Preprocessor, FailedWriteToBuffer) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
   std::string result;
@@ -173,8 +173,8 @@
 }
 
 TEST(Preprocessor, FailedWriteToTexture) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
   std::string result;
@@ -183,8 +183,8 @@
 }
 
 TEST(Preprocessor, DeclareTexture) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(false, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(false, &variable_accessor);
   ASSERT_TRUE(accessor.AddObject(
       "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
   ASSERT_EQ(accessor.GetObjectDeclarations(),
@@ -193,8 +193,8 @@
 }
 
 TEST(Preprocessor, DeclareBuffer) {
-  ParameterAccessor parameters(false);
-  ObjectAccessor accessor(true, &parameters);
+  VariableAccessor variable_accessor(/*inline_values=*/false);
+  ObjectAccessor accessor(true, &variable_accessor);
   ASSERT_TRUE(
       accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
   ASSERT_EQ(accessor.GetObjectDeclarations(),
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc
deleted file mode 100644
index 55d7152..0000000
--- a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc
+++ /dev/null
@@ -1,368 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/variant.h"
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-
-namespace tflite {
-namespace gpu {
-namespace gl {
-namespace parameter_accessor_internal {
-
-// Parse the following regex manually
-// name(\[index\])?(\.field)?
-ParameterReference Parse(absl::string_view input) {
-  ParameterReference ref;
-  auto start_index = input.find('[');
-  if (start_index != std::string::npos) {
-    auto end_index = input.rfind(']');
-    if (end_index == std::string::npos) {
-      return ref;
-    }
-    ref.index = input.substr(start_index + 1, end_index - start_index - 1);
-    ref.name = input.substr(0, start_index);
-    ref.field = input.substr(end_index + 1);
-  } else {
-    auto dot = input.find('.');
-    if (dot != std::string::npos) {
-      ref.name = input.substr(0, dot);
-      ref.field = input.substr(dot);
-    } else {
-      ref.name = input;
-    }
-  }
-  return ref;
-}
-
-}  // namespace parameter_accessor_internal
-
-namespace {
-
-struct UniformTypeGetter {
-  std::string operator()(int) const { return "int"; }
-  std::string operator()(const int2&) const { return "ivec2"; }
-  std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
-  std::string operator()(const int4&) const { return "ivec4"; }
-  std::string operator()(unsigned int) const { return "uint"; }
-  std::string operator()(const uint4&) const { return "uvec4"; }
-  std::string operator()(float) const { return "float"; }
-  std::string operator()(const float2&) const { return "vec2"; }
-  std::string operator()(const float4&) const { return "vec4"; }
-};
-
-// Returns GLSL uniform type of the given parameter.
-std::string GetUniformType(const Variable::ValueType& value) {
-  return absl::visit(UniformTypeGetter(), value);
-}
-
-template <typename T>
-void FormatValue(std::string* result, T t) {
-  absl::StrAppend(result, t);
-}
-
-template <>
-void FormatValue(std::string* result, float t) {
-  absl::StrAppend(result, absl::StrFormat("%.9ff", t));
-}
-
-// Unfortunately absl::StrJoin with custom formatter requires formatter to use
-// string, not std::string. Therefore, due to this compatibility issue data
-// needs to be converted to string representation first and then joined.
-template <typename T, int N>
-std::vector<std::string> ToString(const std::array<T, N>& data) {
-  std::vector<std::string> result(N);
-  for (int i = 0; i < N; ++i) {
-    FormatValue(&result[i], data[i]);
-  }
-  return result;
-}
-
-struct ConstGenerator {
-  template <typename T>
-  void operator()(T t) const {
-    FormatValue(result, t);
-  }
-
-  template <typename T>
-  void operator()(const Vec2<T>& v) const {
-    absl::StrAppend(result, UniformTypeGetter()(v), "(",
-                    absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
-  }
-
-  template <typename T>
-  void operator()(const Vec3<T>& v) const {
-    absl::StrAppend(result, UniformTypeGetter()(v), "(",
-                    absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
-  }
-
-  template <typename T>
-  void operator()(const Vec4<T>& v) const {
-    absl::StrAppend(result, UniformTypeGetter()(v), "(",
-                    absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
-  }
-
-  template <typename T>
-  void operator()(const std::vector<T>& v) const {
-    std::string type = UniformTypeGetter()(v);
-    absl::StrAppend(result, type, "[", v.size(), "](");
-    bool first = true;
-    for (const auto& i : v) {
-      if (first) {
-        first = false;
-      } else {
-        absl::StrAppend(result, ",");
-      }
-      (*this)(i);
-    }
-    absl::StrAppend(result, ")");
-  }
-
-  std::string* result;
-};
-
-// Appends string representation of a parameter value.
-void GetValue(const Variable::ValueType& value, std::string* result) {
-  absl::visit(ConstGenerator{result}, value);
-}
-
-struct UniformDeclarationGenerator {
-  template <typename T>
-  void operator()(const T&) const {
-    absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ",
-                    param.name, ";\n");
-  }
-
-  template <typename T>
-  void operator()(const std::vector<T>& v) const {
-    absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ",
-                    param.name, "[", v.size(), "];\n");
-  }
-
-  const Variable& param;
-  std::string* result;
-};
-
-void GenerateUniformDeclaration(const Variable& parameter,
-                                std::string* result) {
-  absl::visit(UniformDeclarationGenerator{parameter, result}, parameter.value);
-}
-
-struct VariableLengthGetter {
-  template <typename T>
-  bool operator()(const T&) const {
-    return false;
-  }
-  template <typename T>
-  bool operator()(const std::vector<T>&) const {
-    return true;
-  }
-};
-
-// Returns true if value is a vector
-bool IsVariableLength(const Variable::ValueType& value) {
-  return absl::visit(VariableLengthGetter(), value);
-}
-
-enum Field : uint8_t { UNKNOWN = 4, X = 0, Y = 1, Z = 2, W = 3 };
-
-Field ToField(absl::string_view field_name) {
-  if (field_name.size() == 2 && field_name[0] == '.') {
-    switch (field_name[1]) {
-      case 'x':
-        return Field::X;
-      case 'y':
-        return Field::Y;
-      case 'z':
-        return Field::Z;
-      case 'w':
-        return Field::W;
-    }
-  }
-  return Field::UNKNOWN;
-}
-
-struct FieldAccessor {
-  template <typename T>
-  void operator()(const T&) const {}
-
-  template <typename T>
-  void operator()(const Vec2<T>& v) const {
-    FormatValue(result, v[field]);
-  }
-
-  template <typename T>
-  void operator()(const Vec3<T>& v) const {
-    FormatValue(result, v[field]);
-  }
-
-  template <typename T>
-  void operator()(const Vec4<T>& v) const {
-    FormatValue(result, v[field]);
-  }
-
-  Field field;
-  std::string* result;
-};
-
-// Appends formatted value of the given field.
-void GetValue(const Variable::ValueType& value, Field field,
-              std::string* result) {
-  absl::visit(FieldAccessor{field, result}, value);
-}
-
-struct FieldChecker {
-  // For trivial as well as variable-length types indexed access is not allowed.
-  template <typename T>
-  bool operator()(const T&) const {
-    return false;
-  }
-
-  template <typename T>
-  bool operator()(const Vec2<T>& v) const {
-    return field < v.size();
-  }
-
-  template <typename T>
-  bool operator()(const Vec3<T>& v) const {
-    return field < v.size();
-  }
-
-  template <typename T>
-  bool operator()(const Vec4<T>& v) const {
-    return field < v.size();
-  }
-
-  template <typename T>
-  bool operator()(const std::vector<T>&) const {
-    // technically accessing [0] element of an empty vector is UB, but we need
-    // only type information for this check. Therefore, construct default T and
-    // use it instead.
-    T t;
-    return (*this)(t);
-  }
-
-  Field field;
-};
-
-// Returns true if field has field access and field is not out of bounds.
-bool HasField(const Variable::ValueType& value, Field field) {
-  return absl::visit(FieldChecker{field}, value);
-}
-
-void AssembleAccessor(absl::string_view name, absl::string_view index,
-                      absl::string_view field, std::string* result) {
-  if (index.empty()) {
-    absl::StrAppend(result, name, field);
-  } else {
-    absl::StrAppend(result, name, "[", index, "]", field);
-  }
-}
-
-}  // namespace
-
-RewriteStatus ParameterAccessor::Rewrite(absl::string_view input,
-                                         std::string* output) {
-  auto ref = parameter_accessor_internal::Parse(input);
-  if (ref.name.empty()) {
-    absl::StrAppend(output, "INVALID_SYNTAX");
-    return RewriteStatus::ERROR;
-  }
-
-  auto it = name_to_param_.find(std::string(ref.name.data(), ref.name.size()));
-  if (it == name_to_param_.end()) {
-    // Uniform with this name is not registered.
-    return RewriteStatus::NOT_RECOGNIZED;
-  }
-  const auto& value = it->second.value;
-
-  if (!ref.index.empty() && !IsVariableLength(value)) {
-    // Trying to access parameter by index, but it is not variable-length.
-    absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
-    return RewriteStatus::ERROR;
-  }
-
-  Field f = ToField(ref.field);
-  if (!ref.field.empty() && !HasField(value, f)) {
-    // Trying to access a parameter by field, but it does not have it.
-    absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
-    return RewriteStatus::ERROR;
-  }
-
-  // Error checks are complete now.
-
-  // All variable-length parameters are encoded as-is without inlining.
-  if (!inline_values_ || IsVariableLength(value)) {
-    AssembleAccessor(it->second.name, ref.index, ref.field, output);
-  } else {
-    // Parameter + field is replaced with field value.
-    if (f != Field::UNKNOWN) {
-      GetValue(value, f, output);
-    } else {
-      // Parameter is accessed directly.
-      GetValue(value, output);
-    }
-  }
-  return RewriteStatus::SUCCESS;
-}
-
-bool ParameterAccessor::AddParameter(Variable param) {
-  std::string name = param.name;
-  return name_to_param_.insert({name, std::move(param)}).second;
-}
-
-std::string ParameterAccessor::GetConstDeclarations() const {
-  // Variable length parameters are declared as const and accessed via variable
-  // with index.
-  std::string declarations;
-  for (auto& param : name_to_param_) {
-    const auto& value = param.second.value;
-    if (IsVariableLength(value)) {
-      absl::StrAppend(&declarations, "const ", GetUniformType(value), " ",
-                      param.second.name, "[] = ");
-      GetValue(value, &declarations);
-      absl::StrAppend(&declarations, ";\n");
-    }
-  }
-  return declarations;
-}
-
-std::string ParameterAccessor::GetUniformDeclarations() const {
-  std::string declarations;
-  if (!inline_values_) {
-    for (auto& param : name_to_param_) {
-      GenerateUniformDeclaration(param.second, &declarations);
-    }
-  }
-  return declarations;
-}
-
-std::vector<Variable> ParameterAccessor::GetUniformParameters() const {
-  std::vector<Variable> params;
-  if (!inline_values_) {
-    for (auto& param : name_to_param_) {
-      params.push_back(param.second);
-    }
-  }
-  return params;
-}
-
-}  // namespace gl
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h
deleted file mode 100644
index 3dacc34..0000000
--- a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h
+++ /dev/null
@@ -1,92 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
-#include "tensorflow/lite/delegates/gpu/gl/variable.h"
-
-namespace tflite {
-namespace gpu {
-namespace gl {
-
-// This rewrite handles access to parameters. It may rewrite a parameter with
-// actual values if inline_values is set to true.
-//
-// The following syntax is supported to access parameters:
-//  - simple parameter: name
-//  - parameter with field: name.(x|y|z|w)
-//  - parameter with index: name[i]
-//  - parameter with index and field: name[i].(x|y|z|w)
-//
-// If 'inline_values' is set to true, non variable-length parameters will be
-// inlined. For example, 'base.x' will be replaced with value of 'x' field from
-// 'base'. Variable-length are declared as const and accessed via index.
-// These declarations are returned by GetConstDeclarations.
-//
-// If 'inline_values' is set to false, all parameters will be declared as
-// uniforms. Uniform declarations are returned by GetUniformDeclarations.
-class ParameterAccessor : public InlineRewrite {
- public:
-  explicit ParameterAccessor(bool inline_values)
-      : inline_values_(inline_values) {}
-
-  RewriteStatus Rewrite(absl::string_view input, std::string* output) final;
-
-  // Return true if parameter was successfully added.
-  bool AddParameter(Variable param);
-
-  // Returns const parameters that need to be inlined in the a shader's code.
-  std::string GetConstDeclarations() const;
-
-  // Returns uniforms declarations that need to be inlined in a shader's code.
-  std::string GetUniformDeclarations() const;
-
-  // Returns a collection of uniform parameters.
-  std::vector<Variable> GetUniformParameters() const;
-
- private:
-  const bool inline_values_;
-  // Unique parameter index used for obfuscation.
-  uint32_t unique_param_index_ = 0;
-
-  std::unordered_map<std::string, Variable> name_to_param_;
-};
-
-// Implementation details below.
-
-namespace parameter_accessor_internal {
-
-struct ParameterReference {
-  absl::string_view name;
-  absl::string_view index;
-  absl::string_view field;
-};
-
-// Parse the following regex manually
-// name(\[index\])?(\.field)?
-ParameterReference Parse(absl::string_view input);
-
-}  // namespace parameter_accessor_internal
-}  // namespace gl
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc
deleted file mode 100644
index d8c634e..0000000
--- a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
-
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-
-namespace tflite {
-namespace gpu {
-namespace gl {
-namespace {
-
-TEST(Preprocessor, CornerCases) {
-  ParameterAccessor accessor(true);
-  std::string result;
-  ASSERT_EQ(accessor.Rewrite("unknown", &result),
-            RewriteStatus::NOT_RECOGNIZED);
-}
-
-TEST(Preprocessor, Value) {
-  ParameterAccessor accessor(true);
-  ASSERT_TRUE(accessor.AddParameter({"var", int32_t(1)}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
-  ASSERT_EQ(result, "1");
-}
-
-TEST(Preprocessor, ValueVec) {
-  ParameterAccessor accessor(true);
-  ASSERT_TRUE(accessor.AddParameter({"var", int2(1, 2)}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
-  ASSERT_EQ(result, "ivec2(1,2)");
-}
-
-TEST(Preprocessor, Field) {
-  ParameterAccessor accessor(true);
-  ASSERT_TRUE(accessor.AddParameter({"var", float2(1.0, 2.1234567)}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::SUCCESS);
-  ASSERT_EQ(result, "2.123456717f");
-}
-
-TEST(Preprocessor, FieldFail) {
-  ParameterAccessor accessor(true);
-  ASSERT_TRUE(accessor.AddParameter({"var", 1.0f}));
-  ASSERT_TRUE(accessor.AddParameter({"vec", float2(1.0, 1.0)}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::ERROR);
-  ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD");
-
-  result.clear();
-  EXPECT_EQ(accessor.Rewrite("vec.z", &result), RewriteStatus::ERROR);
-  ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD");
-}
-
-TEST(Preprocessor, Variable) {
-  ParameterAccessor accessor(true);
-  std::vector<int2> v;
-  v.push_back(int2(1, 2));
-  ASSERT_TRUE(accessor.AddParameter({"var", v}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var[i].y", &result), RewriteStatus::SUCCESS);
-  ASSERT_EQ(result, "var[i].y");
-  ASSERT_EQ(accessor.GetConstDeclarations(),
-            "const ivec2 var[] = ivec2[1](ivec2(1,2));\n");
-}
-
-TEST(Preprocessor, InlineVariableFail) {
-  ParameterAccessor accessor(true);
-  ASSERT_TRUE(accessor.AddParameter({"var", 1}));
-  std::string result;
-  EXPECT_EQ(accessor.Rewrite("var[i]", &result), RewriteStatus::ERROR);
-  ASSERT_EQ(result, "INVALID_ACCESS_BY_INDEX");
-}
-
-}  // namespace
-}  // namespace gl
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc
index e8d1d78..674002b 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc
@@ -25,8 +25,8 @@
 #include "absl/strings/str_split.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
 
@@ -35,24 +35,24 @@
 namespace gl {
 namespace {
 
-// Rewrites names of all parameters according to returned values from the
+// Rewrites names of all variables according to returned values from the
 // given NameFunctor.
-class ParameterRewriter : public InlineRewrite {
+class VariableRewriter : public InlineRewrite {
  public:
-  ParameterRewriter(const std::string& inline_delimiter,
-                    const NameFunctor& name_func)
+  VariableRewriter(const std::string& inline_delimiter,
+                   const NameFunctor& name_func)
       : inline_delimiter_(inline_delimiter), name_func_(name_func) {}
 
   RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
-    auto ref = parameter_accessor_internal::Parse(input);
+    auto ref = variable_accessor_internal::Parse(input);
     if (ref.name.empty()) {
       absl::StrAppend(output, "INVALID_SYNTAX");
       return RewriteStatus::ERROR;
     }
 
     auto it =
-        name_to_param_.find(std::string(ref.name.data(), ref.name.size()));
-    if (it == name_to_param_.end()) {
+        name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
+    if (it == name_to_variable_.end()) {
       return RewriteStatus::NOT_RECOGNIZED;
     }
 
@@ -65,28 +65,28 @@
     return RewriteStatus::SUCCESS;
   }
 
-  // Return true if parameter was successfully added.
-  bool AddParameter(Variable param) {
-    std::string old_name = param.name;
-    param.name = name_func_(old_name);
-    return name_to_param_.insert({old_name, std::move(param)}).second;
+  // Return true if variable was successfully added.
+  bool AddVariable(Variable&& variable) {
+    std::string old_name = variable.name;
+    variable.name = name_func_(old_name);
+    return name_to_variable_.insert({old_name, std::move(variable)}).second;
   }
 
   // Returns a collection of uniform parameters with updated names.
   std::vector<Variable> GetUniformParameters() const {
-    std::vector<Variable> params;
-    params.reserve(name_to_param_.size());
-    for (auto& param : name_to_param_) {
-      params.push_back(param.second);
+    std::vector<Variable> variables;
+    variables.reserve(name_to_variable_.size());
+    for (const auto& variable : name_to_variable_) {
+      variables.push_back(variable.second);
     }
-    return params;
+    return variables;
   }
 
  private:
   const std::string inline_delimiter_;
   const NameFunctor name_func_;
 
-  std::unordered_map<std::string, Variable> name_to_param_;
+  std::unordered_map<std::string, Variable> name_to_variable_;
 };
 
 // Rewrites names of all objects according to returned values from the
@@ -122,7 +122,7 @@
   std::vector<std::pair<std::string, Object>> GetObjects() const {
     std::vector<std::pair<std::string, Object>> objects;
     objects.reserve(name_to_object_.size());
-    for (auto& o : name_to_object_) {
+    for (const auto& o : name_to_object_) {
       objects.push_back(o.second);
     }
     return objects;
@@ -175,11 +175,11 @@
 }  // namespace
 
 Status Rename(const NameFunctor& name_func, GeneratedCode* code) {
-  ParameterRewriter param_rewriter("$", name_func);
+  VariableRewriter variable_rewriter("$", name_func);
   ObjectRewriter object_rewriter("$", name_func);
-  for (auto&& param : code->parameters) {
-    if (!param_rewriter.AddParameter(std::move(param))) {
-      return InternalError("Parameter name already exists");
+  for (auto&& uniform_parameter : code->parameters) {
+    if (!variable_rewriter.AddVariable(std::move(uniform_parameter))) {
+      return InternalError("Variable name already exists");
     }
   }
   for (auto&& object : code->objects) {
@@ -187,13 +187,13 @@
       return InternalError("Object name already exists");
     }
   }
-  TextPreprocessor preprocessor('$', /* keep_unknown_rewrites = */ true);
-  preprocessor.AddRewrite(&param_rewriter);
+  TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
+  preprocessor.AddRewrite(&variable_rewriter);
   preprocessor.AddRewrite(&object_rewriter);
   std::string source_code;
   RETURN_IF_ERROR(preprocessor.Rewrite(code->source_code, &source_code));
   code->source_code = source_code;
-  code->parameters = param_rewriter.GetUniformParameters();
+  code->parameters = variable_rewriter.GetUniformParameters();
   code->objects = object_rewriter.GetObjects();
   return OkStatus();
 }
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc
index 30da547..4b61948 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc
@@ -18,8 +18,10 @@
 #include <algorithm>
 
 #include "absl/strings/str_cat.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
 
 namespace tflite {
@@ -32,32 +34,41 @@
 
 Status ShaderCodegen::Build(CompiledNodeAttributes attr,
                             ShaderCode* shader_code) const {
-  ParameterAccessor parameters(options_.inline_parameters);
-  ObjectAccessor objects(gpu_type_ == GpuType::MALI, options_.sampler_textures,
-                         &parameters);
+  VariableAccessor variable_accessor(options_.inline_parameters);
+  ObjectAccessor object_accessor(gpu_type_ == GpuType::MALI,
+                                 options_.sampler_textures, &variable_accessor);
 
-  auto add_object = [&](const std::string& name, Object&& object) {
-    if (!objects.AddObject(name, std::forward<Object>(object))) {
-      return InternalError("There is an object with the same name");
+  const auto add_object = [&](const std::string& name, Object&& object) {
+    if (!object_accessor.AddObject(name, std::forward<Object>(object))) {
+      return AlreadyExistsError(absl::StrCat("Object \"", name, "\""));
     }
     return OkStatus();
   };
 
-  auto add_parameter = [&](Variable&& param) {
-    if (!parameters.AddParameter(std::forward<Variable>(param))) {
-      return InternalError("There is a parameter with the same name");
+  const auto add_uniform_parameter = [&](Variable&& variable) {
+    const std::string name = variable.name;
+    if (!variable_accessor.AddUniformParameter(std::move(variable))) {
+      return AlreadyExistsError(
+          absl::StrCat("Uniform parameter \"", name, "\""));
     }
     return OkStatus();
   };
 
-  for (auto&& param : attr.code.parameters) {
-    RETURN_IF_ERROR(add_parameter(std::move(param)));
-  }
-
   for (auto&& object : attr.code.objects) {
     RETURN_IF_ERROR(add_object(object.first, std::move(object.second)));
   }
 
+  for (auto&& variable : attr.code.shared_variables) {
+    const std::string name = variable.name;
+    if (!variable_accessor.AddSharedVariable(std::move(variable))) {
+      return AlreadyExistsError(absl::StrCat("Shared variable \"", name, "\""));
+    }
+  }
+
+  for (auto&& variable : attr.code.parameters) {
+    RETURN_IF_ERROR(add_uniform_parameter(std::move(variable)));
+  }
+
   int index = 0;
   for (auto&& input : attr.inputs) {
     RETURN_IF_ERROR(
@@ -71,14 +82,14 @@
 
   // TODO(akulik): workload params need to go away and be replaced with
   // output_data_0_w
-  RETURN_IF_ERROR(add_parameter(
+  RETURN_IF_ERROR(add_uniform_parameter(
       {"workload_x", static_cast<int32_t>(attr.code.workload.x)}));
-  RETURN_IF_ERROR(add_parameter(
+  RETURN_IF_ERROR(add_uniform_parameter(
       {"workload_y", static_cast<int32_t>(attr.code.workload.y)}));
-  RETURN_IF_ERROR(add_parameter(
+  RETURN_IF_ERROR(add_uniform_parameter(
       {"workload_z", static_cast<int32_t>(attr.code.workload.z)}));
 
-  std::string source_code = R"(
+  std::string main_source_code = R"(
   ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
   if (gid.x >= $workload_x$ || gid.y >= $workload_y$ || gid.z >= $workload_z$) {
     return;
@@ -88,60 +99,68 @@
   switch (attr.code.input) {
     case IOStructure::ONLY_DEFINITIONS:
       for (int i = 0; i < attr.inputs.size(); ++i) {
-        absl::StrAppend(&source_code, "  highp vec4 value_", i,
+        absl::StrAppend(&main_source_code, "  highp vec4 value_", i,
                         " = vec4(0);\n");
       }
       break;
     case IOStructure::AUTO: {
       for (int i = 0; i < attr.inputs.size(); ++i) {
-        absl::StrAppend(&source_code, "  highp vec4 value_", i,
+        absl::StrAppend(&main_source_code, "  highp vec4 value_", i,
                         " = $input_data_", i, "[gid.x, gid.y, gid.z]$;\n");
       }
       break;
     }
   }
 
-  source_code.append(attr.code.source_code);
+  main_source_code.append(attr.code.source_code);
 
   if (attr.code.output == IOStructure::AUTO) {
     for (int i = 0; i < attr.outputs.size(); ++i) {
-      absl::StrAppend(&source_code, "  $output_data_", i,
+      absl::StrAppend(&main_source_code, "  $output_data_", i,
                       "[gid.x, gid.y, gid.z] = value_", i, "$;\n");
     }
   }
 
   // At this point main function is already generated. Now we need to process
-  // object and parameter accessors.
+  // object and variable accessors.
 
   // process objects first. Object accessor may introduce new uniform
   // parameters that need to be rewritten in the subsequent pass.
   {
     TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
-    preprocessor.AddRewrite(&objects);
-    RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code));
+    preprocessor.AddRewrite(&object_accessor);
+    RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
   }
 
   {
     TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/false);
-    preprocessor.AddRewrite(&parameters);
-    RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code));
+    preprocessor.AddRewrite(&variable_accessor);
+    RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
   }
 
   if (options_.inline_parameters) {
-    source_code = absl::StrCat(parameters.GetConstDeclarations(), source_code);
+    main_source_code = absl::StrCat(variable_accessor.GetConstDeclarations(),
+                                    main_source_code);
   }
 
-  std::string declarations = absl::StrCat(
-      objects.GetFunctionsDeclarations(), "\n", objects.GetObjectDeclarations(),
-      "\n", parameters.GetUniformDeclarations());
-  *shader_code = ShaderCode(
-      parameters.GetUniformParameters(), objects.GetObjects(),
-      attr.code.workload, attr.code.workgroup,
-      absl::StrCat("layout(std430) buffer;\nprecision ",
-                   (options_.allow_precision_loss ? "mediump" : "highp"),
-                   " float;\n", declarations, "\nvoid main() {\n", source_code,
-                   "\n}"),
-      attr.node_indices);
+  // partial_source_code is only missing the following which is added later:
+  // #version 310 es
+  // layout(local_size_x = ..., local_size_y = ..., local_size_z = ...) in;
+  const char* precision = options_.allow_precision_loss ? "mediump" : "highp";
+  const std::string partial_source_code = absl::StrCat(
+      "layout(std430) buffer;\n",                                 //
+      "precision ", precision, " float;\n",                       //
+      object_accessor.GetFunctionsDeclarations(), "\n",           //
+      object_accessor.GetObjectDeclarations(), "\n",              //
+      variable_accessor.GetSharedVariableDeclarations(), "\n",    //
+      variable_accessor.GetUniformParameterDeclarations(), "\n",  //
+      "void main() {\n",                                          //
+      main_source_code,                                           //
+      "}");
+  *shader_code =
+      ShaderCode(variable_accessor.GetUniformParameters(),
+                 object_accessor.GetObjects(), attr.code.workload,
+                 attr.code.workgroup, partial_source_code, attr.node_indices);
   return OkStatus();
 }
 
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h
index 06e4cf8..c4f09a3 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h
@@ -19,14 +19,13 @@
 #include <string>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
-#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc
new file mode 100644
index 0000000..b0e22b6
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc
@@ -0,0 +1,423 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/variant.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace variable_accessor_internal {
+
+// Parse the following regex manually
+// name(\[index\])?(\.field)?
+VariableReference Parse(absl::string_view input) {
+  VariableReference ref;
+  auto start_index = input.find('[');
+  if (start_index != std::string::npos) {
+    auto end_index = input.rfind(']');
+    if (end_index == std::string::npos) {
+      return ref;
+    }
+    ref.index = input.substr(start_index + 1, end_index - start_index - 1);
+    ref.name = input.substr(0, start_index);
+    ref.field = input.substr(end_index + 1);
+  } else {
+    auto dot = input.find('.');
+    if (dot != std::string::npos) {
+      ref.name = input.substr(0, dot);
+      ref.field = input.substr(dot);
+    } else {
+      ref.name = input;
+    }
+  }
+  return ref;
+}
+
+}  // namespace variable_accessor_internal
+
+namespace {
+
+struct VariableTypeGetter {
+  std::string operator()(int) const { return "int"; }
+  std::string operator()(const int2&) const { return "ivec2"; }
+  std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
+  std::string operator()(const int4&) const { return "ivec4"; }
+  std::string operator()(unsigned int) const { return "uint"; }
+  std::string operator()(const uint4&) const { return "uvec4"; }
+  std::string operator()(float) const { return "float"; }
+  std::string operator()(const float2&) const { return "vec2"; }
+  std::string operator()(const float4&) const { return "vec4"; }
+  std::string operator()(const std::vector<float4>&) const { return "vec4"; }
+};
+
+// Returns GLSL uniform type of the given variable.
+std::string GetVariableType(const Variable::ValueType& value) {
+  return absl::visit(VariableTypeGetter(), value);
+}
+
+template <typename T>
+void FormatValue(std::string* result, T t) {
+  absl::StrAppend(result, t);
+}
+
+template <>
+void FormatValue(std::string* result, float t) {
+  absl::StrAppend(result, absl::StrFormat("%.9ff", t));
+}
+
+// Unfortunately absl::StrJoin with custom formatter requires formatter to use
+// string, not std::string. Therefore, due to this compatibility issue data
+// needs to be converted to string representation first and then joined.
+template <typename T, int N>
+std::vector<std::string> ToString(const std::array<T, N>& data) {
+  std::vector<std::string> result(N);
+  for (int i = 0; i < N; ++i) {
+    FormatValue(&result[i], data[i]);
+  }
+  return result;
+}
+
+struct ConstGenerator {
+  template <typename T>
+  void operator()(T t) const {
+    FormatValue(result, t);
+  }
+
+  template <typename T>
+  void operator()(const Vec2<T>& v) const {
+    absl::StrAppend(result, VariableTypeGetter()(v), "(",
+                    absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
+  }
+
+  template <typename T>
+  void operator()(const Vec3<T>& v) const {
+    absl::StrAppend(result, VariableTypeGetter()(v), "(",
+                    absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
+  }
+
+  template <typename T>
+  void operator()(const Vec4<T>& v) const {
+    absl::StrAppend(result, VariableTypeGetter()(v), "(",
+                    absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
+  }
+
+  template <typename T>
+  void operator()(const std::vector<T>& v) const {
+    std::string type = VariableTypeGetter()(v);
+    absl::StrAppend(result, type, "[", v.size(), "](");
+    bool first = true;
+    for (const auto& i : v) {
+      if (first) {
+        first = false;
+      } else {
+        absl::StrAppend(result, ",");
+      }
+      (*this)(i);
+    }
+    absl::StrAppend(result, ")");
+  }
+
+  std::string* result;
+};
+
+// Appends string representation of a variable value.
+void GetValue(const Variable::ValueType& value, std::string* result) {
+  absl::visit(ConstGenerator{result}, value);
+}
+
+struct SharedVariableDeclarationGenerator {
+  template <typename T>
+  void operator()(const T&) const {
+    absl::StrAppend(result, "shared ", GetVariableType(variable.value), " ",
+                    variable.name, ";\n");
+  }
+
+  template <typename T>
+  void operator()(const std::vector<T>& v) const {
+    absl::StrAppend(result, "shared ", GetVariableType(variable.value), " ",
+                    variable.name, "[", v.size(), "];\n");
+  }
+
+  const Variable& variable;
+  std::string* result;
+};
+
+void GenerateSharedVariableDeclaration(const Variable& variable,
+                                       std::string* result) {
+  absl::visit(SharedVariableDeclarationGenerator{variable, result},
+              variable.value);
+}
+
+struct UniformParameterDeclarationGenerator {
+  template <typename T>
+  void operator()(const T&) const {
+    absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
+                    variable.name, ";\n");
+  }
+
+  template <typename T>
+  void operator()(const std::vector<T>& v) const {
+    absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
+                    variable.name, "[", v.size(), "];\n");
+  }
+
+  const Variable& variable;
+  std::string* result;
+};
+
+void GenerateUniformParameterDeclaration(const Variable& variable,
+                                         std::string* result) {
+  absl::visit(UniformParameterDeclarationGenerator{variable, result},
+              variable.value);
+}
+
+struct VariableLengthGetter {
+  template <typename T>
+  bool operator()(const T&) const {
+    return false;
+  }
+  template <typename T>
+  bool operator()(const std::vector<T>&) const {
+    return true;
+  }
+};
+
+// Returns true if value is a vector
+bool IsVariableLength(const Variable::ValueType& value) {
+  return absl::visit(VariableLengthGetter(), value);
+}
+
+enum Field : uint8_t { UNKNOWN = 4, X = 0, Y = 1, Z = 2, W = 3 };
+
+Field ToField(absl::string_view field_name) {
+  if (field_name.size() == 2 && field_name[0] == '.') {
+    switch (field_name[1]) {
+      case 'x':
+        return Field::X;
+      case 'y':
+        return Field::Y;
+      case 'z':
+        return Field::Z;
+      case 'w':
+        return Field::W;
+    }
+  }
+  return Field::UNKNOWN;
+}
+
+struct FieldAccessor {
+  template <typename T>
+  void operator()(const T&) const {}
+
+  template <typename T>
+  void operator()(const Vec2<T>& v) const {
+    FormatValue(result, v[field]);
+  }
+
+  template <typename T>
+  void operator()(const Vec3<T>& v) const {
+    FormatValue(result, v[field]);
+  }
+
+  template <typename T>
+  void operator()(const Vec4<T>& v) const {
+    FormatValue(result, v[field]);
+  }
+
+  Field field;
+  std::string* result;
+};
+
+// Appends formatted value of the given field.
+void GetValue(const Variable::ValueType& value, Field field,
+              std::string* result) {
+  absl::visit(FieldAccessor{field, result}, value);
+}
+
+struct FieldChecker {
+  // For trivial as well as variable-length types indexed access is not allowed.
+  template <typename T>
+  bool operator()(const T&) const {
+    return false;
+  }
+
+  template <typename T>
+  bool operator()(const Vec2<T>& v) const {
+    return field < v.size();
+  }
+
+  template <typename T>
+  bool operator()(const Vec3<T>& v) const {
+    return field < v.size();
+  }
+
+  template <typename T>
+  bool operator()(const Vec4<T>& v) const {
+    return field < v.size();
+  }
+
+  template <typename T>
+  bool operator()(const std::vector<T>&) const {
+    // technically accessing [0] element of an empty vector is UB, but we need
+    // only type information for this check. Therefore, construct default T and
+    // use it instead.
+    T t;
+    return (*this)(t);
+  }
+
+  Field field;
+};
+
+// Returns true if field has field access and field is not out of bounds.
+bool HasField(const Variable::ValueType& value, Field field) {
+  return absl::visit(FieldChecker{field}, value);
+}
+
+void AssembleAccessor(absl::string_view name, absl::string_view index,
+                      absl::string_view field, std::string* result) {
+  if (index.empty()) {
+    absl::StrAppend(result, name, field);
+  } else {
+    absl::StrAppend(result, name, "[", index, "]", field);
+  }
+}
+
+}  // namespace
+
+RewriteStatus VariableAccessor::Rewrite(absl::string_view input,
+                                        std::string* output) {
+  auto ref = variable_accessor_internal::Parse(input);
+  if (ref.name.empty()) {
+    absl::StrAppend(output, "INVALID_SYNTAX");
+    return RewriteStatus::ERROR;
+  }
+
+  auto it =
+      name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
+  if (it == name_to_variable_.end()) {
+    // Uniform with this name is not registered.
+    return RewriteStatus::NOT_RECOGNIZED;
+  }
+  const auto& value = it->second.value;
+
+  if (!ref.index.empty() && !IsVariableLength(value)) {
+    // Trying to access variable by index, but it is not variable-length.
+    absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
+    return RewriteStatus::ERROR;
+  }
+
+  Field f = ToField(ref.field);
+  if (!ref.field.empty() && !HasField(value, f)) {
+    // Trying to access a variable by field, but it does not have it.
+    absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
+    return RewriteStatus::ERROR;
+  }
+
+  // Error checks are complete now.
+
+  // All variable-length variables are encoded as-is without inlining.
+  if (!inline_values_ || IsVariableLength(value)) {
+    AssembleAccessor(it->second.name, ref.index, ref.field, output);
+  } else {
+    // Parameter + field is replaced with field value.
+    if (f != Field::UNKNOWN) {
+      GetValue(value, f, output);
+    } else {
+      // Parameter is accessed directly.
+      GetValue(value, output);
+    }
+  }
+  return RewriteStatus::SUCCESS;
+}
+
+bool VariableAccessor::AddSharedVariable(Variable&& variable) {
+  const std::string name = variable.name;
+  if (!name_to_variable_.insert({name, std::move(variable)}).second) {
+    return false;
+  }
+  shared_variables_.insert(name);
+  return true;
+}
+
+bool VariableAccessor::AddUniformParameter(Variable&& variable) {
+  const std::string name = variable.name;
+  if (!name_to_variable_.insert({name, std::move(variable)}).second) {
+    return false;
+  }
+  uniform_parameters_.insert(name);
+  return true;
+}
+
+std::string VariableAccessor::GetConstDeclarations() const {
+  // Variable length variables are declared as const and accessed via variable
+  // with index.
+  std::string declarations;
+  for (const auto& variable : name_to_variable_) {
+    // Skip shared variables.
+    if (shared_variables_.find(variable.second.name) !=
+        shared_variables_.end()) {
+      continue;
+    }
+    const auto& value = variable.second.value;
+    if (IsVariableLength(value)) {
+      absl::StrAppend(&declarations, "const ", GetVariableType(value), " ",
+                      variable.second.name, "[] = ");
+      GetValue(value, &declarations);
+      absl::StrAppend(&declarations, ";\n");
+    }
+  }
+  return declarations;
+}
+
+std::string VariableAccessor::GetSharedVariableDeclarations() const {
+  std::string declarations;
+  for (const auto& name : shared_variables_) {
+    const auto& variable = name_to_variable_.at(name);
+    GenerateSharedVariableDeclaration(variable, &declarations);
+  }
+  return declarations;
+}
+
+std::string VariableAccessor::GetUniformParameterDeclarations() const {
+  std::string declarations;
+  if (!inline_values_) {
+    for (const auto& name : uniform_parameters_) {
+      const auto& variable = name_to_variable_.at(name);
+      GenerateUniformParameterDeclaration(variable, &declarations);
+    }
+  }
+  return declarations;
+}
+
+std::vector<Variable> VariableAccessor::GetUniformParameters() const {
+  std::vector<Variable> variables;
+  if (!inline_values_) {
+    variables.reserve(name_to_variable_.size());
+    for (const auto& variable : name_to_variable_) {
+      variables.push_back(variable.second);
+    }
+  }
+  return variables;
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h
new file mode 100644
index 0000000..d6a1063
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_VARIABLE_ACCESSOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_VARIABLE_ACCESSOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <set>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
+#include "tensorflow/lite/delegates/gpu/gl/variable.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+// This rewrite handles access to variables. It may rewrite a variable with
+// actual values if 'inline_values' is set to true.
+//
+// The following syntax is supported to access variables:
+//  - simple variable: name
+//  - variable with field: name.(x|y|z|w)
+//  - variable with index: name[i]
+//  - variable with index and field: name[i].(x|y|z|w)
+//
+// If 'inline_values' is set to true, non-variable-length variables will be
+// inlined. For example, 'base.x' will be replaced with value of 'x' field from
+// 'base'. Variable-length variables are declared as const and accessed via
+// index. These declarations are returned by GetConstDeclarations.
+//
+// If 'inline_values' is set to false, all variables will be declared as
+// uniforms. Uniform declarations are returned by GetUniformDeclarations.
+class VariableAccessor : public InlineRewrite {
+ public:
+  explicit VariableAccessor(bool inline_values)
+      : inline_values_(inline_values) {}
+
+  RewriteStatus Rewrite(absl::string_view input, std::string* output) final;
+
+  // Returns true if variable was successfully added.
+  bool AddSharedVariable(Variable&& variable);
+
+  // Returns true if variable was successfully added.
+  bool AddUniformParameter(Variable&& variable);
+
+  // Returns const variables that need to be inlined in the a shader's code.
+  std::string GetConstDeclarations() const;
+
+  // Returns shared varaible declarations that need to be inlined.
+  std::string GetSharedVariableDeclarations() const;
+
+  // Returns uniform parameter declarations that need to be inlined.
+  std::string GetUniformParameterDeclarations() const;
+
+  // Returns a collection of uniform parameters.
+  std::vector<Variable> GetUniformParameters() const;
+
+ private:
+  const bool inline_values_;
+  std::unordered_map<std::string, Variable> name_to_variable_;
+  std::set<std::string> shared_variables_;
+  std::set<std::string> uniform_parameters_;
+};
+
+// Implementation details below.
+
+namespace variable_accessor_internal {
+
+struct VariableReference {
+  absl::string_view name;
+  absl::string_view index;
+  absl::string_view field;
+};
+
+// Parse the following regex manually
+// name(\[index\])?(\.field)?
+VariableReference Parse(absl::string_view input);
+
+}  // namespace variable_accessor_internal
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_VARIABLE_ACCESSOR_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc
new file mode 100644
index 0000000..0e8be2a
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace {
+
+TEST(PreprocessorTest, CornerCases) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  std::string result;
+  EXPECT_EQ(variable_accessor.Rewrite("unknown", &result),
+            RewriteStatus::NOT_RECOGNIZED);
+}
+
+TEST(PreprocessorTest, Value) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"var", int32_t(1)}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
+  EXPECT_EQ(result, "1");
+}
+
+TEST(PreprocessorTest, ValueVec) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"var", int2(1, 2)}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
+  EXPECT_EQ(result, "ivec2(1,2)");
+}
+
+TEST(PreprocessorTest, Field) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  ASSERT_TRUE(
+      variable_accessor.AddUniformParameter({"var", float2(1.0, 2.1234567)}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var.y", &result),
+            RewriteStatus::SUCCESS);
+  EXPECT_EQ(result, "2.123456717f");
+}
+
+TEST(PreprocessorTest, FieldFail) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"var", 1.0f}));
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"vec", float2(1.0, 1.0)}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var.y", &result), RewriteStatus::ERROR);
+  EXPECT_EQ(result, "INVALID_ACCESS_BY_FIELD");
+
+  result.clear();
+  ASSERT_EQ(variable_accessor.Rewrite("vec.z", &result), RewriteStatus::ERROR);
+  EXPECT_EQ(result, "INVALID_ACCESS_BY_FIELD");
+}
+
+TEST(PreprocessorTest, Variable) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  std::vector<int2> v;
+  v.push_back(int2(1, 2));
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"var", v}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var[i].y", &result),
+            RewriteStatus::SUCCESS);
+  ASSERT_EQ(result, "var[i].y");
+  EXPECT_EQ(variable_accessor.GetConstDeclarations(),
+            "const ivec2 var[] = ivec2[1](ivec2(1,2));\n");
+}
+
+TEST(PreprocessorTest, InlineVariableFail) {
+  VariableAccessor variable_accessor(/*inline_values=*/true);
+  ASSERT_TRUE(variable_accessor.AddUniformParameter({"var", 1}));
+  std::string result;
+  ASSERT_EQ(variable_accessor.Rewrite("var[i]", &result), RewriteStatus::ERROR);
+  EXPECT_EQ(result, "INVALID_ACCESS_BY_INDEX");
+}
+
+}  // namespace
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler_options.h b/tensorflow/lite/delegates/gpu/gl/compiler_options.h
index a4545f5..6dbe7cb 100644
--- a/tensorflow/lite/delegates/gpu/gl/compiler_options.h
+++ b/tensorflow/lite/delegates/gpu/gl/compiler_options.h
@@ -16,7 +16,6 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_
 
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc
index 7179696..baf6002 100644
--- a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc
+++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc
@@ -18,6 +18,7 @@
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.h b/tensorflow/lite/delegates/gpu/gl/egl_environment.h
index e23cc9c..fa7ca04 100644
--- a/tensorflow/lite/delegates/gpu/gl/egl_environment.h
+++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.h
@@ -21,9 +21,9 @@
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/egl_context.h"
 #include "tensorflow/lite/delegates/gpu/gl/egl_surface.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/portable_egl.h"
 #include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.cc b/tensorflow/lite/delegates/gpu/gl/gl_program.cc
index 8e63128..def8235 100644
--- a/tensorflow/lite/delegates/gpu/gl/gl_program.cc
+++ b/tensorflow/lite/delegates/gpu/gl/gl_program.cc
@@ -57,14 +57,17 @@
     return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id,
                               value);
   }
+
   Status operator()(const int2& value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id,
                               value.x, value.y);
   }
+
   Status operator()(const int4& value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id,
                               value.x, value.y, value.z, value.w);
   }
+
   Status operator()(const std::vector<int2>& value) {
     std::vector<GLint> ints(value.size() * 2, 0);
     for (int i = 0; i < value.size(); ++i) {
@@ -74,27 +77,44 @@
     return TFLITE_GPU_CALL_GL(glProgramUniform2iv, program_id, uniform_id,
                               ints.size(), ints.data());
   }
+
   Status operator()(unsigned int value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id,
                               value);
   }
+
   Status operator()(const uint4& value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id,
                               value.x, value.y, value.z, value.w);
   }
+
   Status operator()(float value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id,
                               value);
   }
+
   Status operator()(const float2& value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id,
                               value.x, value.y);
   }
+
   Status operator()(const float4& value) {
     return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id,
                               value.x, value.y, value.z, value.w);
   }
 
+  Status operator()(const std::vector<float4>& value) {
+    std::vector<GLfloat> floats(value.size() * 4, 0);
+    for (int i = 0; i < value.size(); ++i) {
+      floats[i * 4] = value[i].x;
+      floats[i * 4 + 1] = value[i].y;
+      floats[i * 4 + 2] = value[i].z;
+      floats[i * 4 + 3] = value[i].w;
+    }
+    return TFLITE_GPU_CALL_GL(glProgramUniform4fv, program_id, uniform_id,
+                              floats.size(), floats.data());
+  }
+
   const GLuint program_id;
   const GLint uniform_id;
 };
diff --git a/tensorflow/lite/delegates/gpu/gl/gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/gpu_info.cc
deleted file mode 100644
index d40910c..0000000
--- a/tensorflow/lite/delegates/gpu/gl/gpu_info.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
-
-#include <algorithm>
-#include <cctype>
-#include <string>
-
-#include "absl/strings/ascii.h"
-#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"
-#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
-
-namespace tflite {
-namespace gpu {
-namespace gl {
-namespace {
-
-GpuType GetGpuType(const std::string& renderer) {
-  if (renderer.find("mali") != renderer.npos) {
-    return GpuType::MALI;
-  }
-  if (renderer.find("adreno") != renderer.npos) {
-    return GpuType::ADRENO;
-  }
-  if (renderer.find("powervr") != renderer.npos) {
-    return GpuType::POWERVR;
-  }
-  if (renderer.find("intel") != renderer.npos) {
-    return GpuType::INTEL;
-  }
-  if (renderer.find("nvidia") != renderer.npos) {
-    return GpuType::NVIDIA;
-  }
-  return GpuType::UNKNOWN;
-}
-
-GpuModel GetGpuModel(const std::string& renderer) {
-  auto found_model = [&](std::string model) -> bool {
-    return renderer.find(model) != renderer.npos;
-  };
-  // Adreno 6xx series
-  if (found_model("640")) return GpuModel::ADRENO640;
-  if (found_model("630")) return GpuModel::ADRENO630;
-  if (found_model("616")) return GpuModel::ADRENO616;
-  if (found_model("615")) return GpuModel::ADRENO615;
-  if (found_model("612")) return GpuModel::ADRENO612;
-  if (found_model("605")) return GpuModel::ADRENO605;
-  // Adreno 5xx series
-  if (found_model("540")) return GpuModel::ADRENO540;
-  if (found_model("530")) return GpuModel::ADRENO530;
-  if (found_model("512")) return GpuModel::ADRENO512;
-  if (found_model("510")) return GpuModel::ADRENO510;
-  if (found_model("509")) return GpuModel::ADRENO509;
-  if (found_model("508")) return GpuModel::ADRENO508;
-  if (found_model("506")) return GpuModel::ADRENO506;
-  if (found_model("505")) return GpuModel::ADRENO505;
-  if (found_model("504")) return GpuModel::ADRENO504;
-  // Adreno 4xx series
-  if (found_model("430")) return GpuModel::ADRENO430;
-  if (found_model("420")) return GpuModel::ADRENO420;
-  if (found_model("418")) return GpuModel::ADRENO418;
-  if (found_model("405")) return GpuModel::ADRENO405;
-  // Adreno 3xx series
-  if (found_model("330")) return GpuModel::ADRENO330;
-  if (found_model("320")) return GpuModel::ADRENO320;
-  if (found_model("308")) return GpuModel::ADRENO308;
-  if (found_model("306")) return GpuModel::ADRENO306;
-  if (found_model("305")) return GpuModel::ADRENO305;
-  if (found_model("304")) return GpuModel::ADRENO304;
-  // Adreno 2xx series
-  if (found_model("225")) return GpuModel::ADRENO225;
-  if (found_model("220")) return GpuModel::ADRENO220;
-  if (found_model("205")) return GpuModel::ADRENO205;
-  if (found_model("203")) return GpuModel::ADRENO203;
-  if (found_model("200")) return GpuModel::ADRENO200;
-  // Adreno 1xx series
-  if (found_model("130")) return GpuModel::ADRENO130;
-  return GpuModel::UNKNOWN;
-}
-
-}  // namespace
-
-void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model,
-                        GpuType* gpu_type) {
-  std::string lowered = renderer;
-  absl::AsciiStrToLower(&lowered);
-  *gpu_type = GetGpuType(lowered);
-  *gpu_model =
-      *gpu_type == GpuType::ADRENO ? GetGpuModel(lowered) : GpuModel::UNKNOWN;
-}
-
-Status RequestGpuInfo(GpuInfo* gpu_info) {
-  GpuInfo info;
-
-  const GLubyte* renderer_name = glGetString(GL_RENDERER);
-  if (renderer_name) {
-    info.renderer_name = reinterpret_cast<const char*>(renderer_name);
-    GetGpuModelAndType(info.renderer_name, &info.gpu_model, &info.type);
-  }
-
-  const GLubyte* vendor_name = glGetString(GL_VENDOR);
-  if (vendor_name) {
-    info.vendor_name = reinterpret_cast<const char*>(vendor_name);
-  }
-
-  const GLubyte* version_name = glGetString(GL_VERSION);
-  if (version_name) {
-    info.version = reinterpret_cast<const char*>(version_name);
-  }
-
-  glGetIntegerv(GL_MAJOR_VERSION, &info.major_version);
-  glGetIntegerv(GL_MINOR_VERSION, &info.minor_version);
-
-  GLint extensions_count;
-  glGetIntegerv(GL_NUM_EXTENSIONS, &extensions_count);
-  info.extensions.resize(extensions_count);
-  for (int i = 0; i < extensions_count; ++i) {
-    info.extensions[i] = std::string(
-        reinterpret_cast<const char*>(glGetStringi(GL_EXTENSIONS, i)));
-  }
-  glGetIntegerv(GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS, &info.max_ssbo_bindings);
-  glGetIntegerv(GL_MAX_COMPUTE_IMAGE_UNIFORMS, &info.max_image_bindings);
-  info.max_work_group_size.resize(3);
-  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0,
-                  &info.max_work_group_size[0]);
-  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
-                  &info.max_work_group_size[1]);
-  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2,
-                  &info.max_work_group_size[2]);
-  glGetIntegerv(GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS,
-                &info.max_work_group_invocations);
-  glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.max_texture_size);
-  glGetIntegerv(GL_MAX_IMAGE_UNITS, &info.max_image_units);
-  glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers);
-  RETURN_IF_ERROR(GetOpenGlErrors());
-  *gpu_info = info;
-  return OkStatus();
-}
-
-}  // namespace gl
-}  // namespace gpu
-}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/gpu_info.h b/tensorflow/lite/delegates/gpu/gl/gpu_info.h
deleted file mode 100644
index ba7e0a5..0000000
--- a/tensorflow/lite/delegates/gpu/gl/gpu_info.h
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_
-
-#include <string>
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-
-namespace tflite {
-namespace gpu {
-namespace gl {
-
-enum class GpuType { UNKNOWN, MALI, ADRENO, POWERVR, INTEL, NVIDIA };
-enum class GpuModel {
-  UNKNOWN,
-  // Adreno 6xx series
-  ADRENO640,
-  ADRENO630,
-  ADRENO616,
-  ADRENO615,
-  ADRENO612,
-  ADRENO605,
-  // Adreno 5xx series
-  ADRENO540,
-  ADRENO530,
-  ADRENO512,
-  ADRENO510,
-  ADRENO509,
-  ADRENO508,
-  ADRENO506,
-  ADRENO505,
-  ADRENO504,
-  // Adreno 4xx series
-  ADRENO430,
-  ADRENO420,
-  ADRENO418,
-  ADRENO405,
-  // Adreno 3xx series
-  ADRENO330,
-  ADRENO320,
-  ADRENO308,
-  ADRENO306,
-  ADRENO305,
-  ADRENO304,
-  // Adreno 2xx series
-  ADRENO225,
-  ADRENO220,
-  ADRENO205,
-  ADRENO203,
-  ADRENO200,
-  // Adreno 1xx series
-  ADRENO130,
-};
-
-struct GpuInfo {
-  GpuType type = GpuType::UNKNOWN;
-  std::string renderer_name;
-  std::string vendor_name;
-  std::string version;
-  GpuModel gpu_model;
-  int major_version = -1;
-  int minor_version = -1;
-  std::vector<std::string> extensions;
-  int max_ssbo_bindings = 0;
-  int max_image_bindings = 0;
-  std::vector<int> max_work_group_size;
-  int max_work_group_invocations;
-  int max_texture_size = 0;
-  int max_image_units = 0;
-  int max_array_texture_layers = 0;
-};
-
-// Analyzes `renderer` and returns matching `GpuType` and `GpuModel`.
-void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model,
-                        GpuType* gpu_type);
-
-// This method performs multiple GL calls, therefore, egl context needs to be
-// created upfront.
-Status RequestGpuInfo(GpuInfo* gpu_info);
-
-}  // namespace gl
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
index 50d204c..63b0683 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
@@ -114,6 +114,13 @@
 )
 
 cc_library(
+    name = "custom_registry",
+    srcs = ["custom_registry.cc"],
+    hdrs = ["custom_registry.h"],
+    deps = ["//tensorflow/lite/delegates/gpu/gl:node_shader"],
+)
+
+cc_library(
     name = "depthwise_conv",
     srcs = ["depthwise_conv.cc"],
     hdrs = ["depthwise_conv.h"],
@@ -575,9 +582,9 @@
         "//tensorflow/lite/delegates/gpu/gl:compiler_options",
         "//tensorflow/lite/delegates/gpu/gl:egl_environment",
         "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
         "//tensorflow/lite/delegates/gpu/gl:node_shader",
         "//tensorflow/lite/delegates/gpu/gl:object_manager",
+        "//tensorflow/lite/delegates/gpu/gl:request_gpu_info",
         "//tensorflow/lite/delegates/gpu/gl:runtime_options",
         "//tensorflow/lite/delegates/gpu/gl/workgroups:default_calculator",
         "@com_google_googletest//:gtest",
@@ -690,6 +697,7 @@
                "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [],
                "//conditions:default": NON_TFLITE_GPU_BINARY_RELEASE_OPERATORS,
            }) + [
+        ":custom_registry",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/gl:node_shader",
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.cc
new file mode 100644
index 0000000..f5c5429
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.cc
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+void RegisterCustomOps(
+    std::unordered_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
+        shaders) {}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h
new file mode 100644
index 0000000..9a979a9
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CUSTOM_REGISTRY_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CUSTOM_REGISTRY_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+// Registers custom operations.
+void RegisterCustomOps(
+    std::unordered_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
+        shaders_);
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CUSTOM_REGISTRY_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc
index 7c93ebd..3744a77 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/lite/delegates/gpu/gl/kernels/add.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h"
+#include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
@@ -86,7 +87,7 @@
     insert_op(Type::RELU, NewReLUNodeShader);
     insert_op(Type::RESHAPE, NewReshapeNodeShader);
     insert_op(Type::SLICE, NewSliceNodeShader);
-    insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader);
+    insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
     insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
 
     insert_elementwise_op(Type::ABS);
@@ -106,6 +107,7 @@
 
 #ifndef TFLITE_GPU_BINARY_RELEASE
     insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);
+    RegisterCustomOps(&shaders_);
 #endif  // TFLITE_GPU_BINARY_RELEASE
   }
 
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
index 9067ec9..6d95590 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
@@ -15,10 +15,9 @@
 
 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
 
-#include <algorithm>
-#include <cstdint>
-#include <cstring>
+#include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "absl/memory/memory.h"
@@ -33,33 +32,117 @@
 namespace gl {
 namespace {
 
-class SoftMax : public NodeShader {
+float4 GetMask(int num_channels) {
+  float4 mask(0.0f);
+  const int remainder = num_channels % 4 == 0 ? 4 : num_channels % 4;
+  for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
+  return mask;
+}
+
+class Softmax : public NodeShader {
  public:
   Status GenerateCode(const GenerationContext& ctx,
                       GeneratedCode* generated_code) const final {
-    auto input = ctx.graph->FindInputs(ctx.node->id)[0];
-    auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
-    auto attr =
-        absl::any_cast<SoftMaxAttributes>(ctx.node->operation.attributes);
+    const auto* input = ctx.graph->FindInputs(ctx.node->id)[0];
+    const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0];
+    const auto& attr = absl::any_cast<const SoftmaxAttributes&>(
+        ctx.node->operation.attributes);
     if (input->tensor.shape != output->tensor.shape) {
-      return InvalidArgumentError("Input and output shape does not match");
+      return InvalidArgumentError("Input and output shapes do not match.");
     }
     if (attr.axis != Axis::CHANNELS) {
       return UnimplementedError("Softmax is only supported for channels axis.");
     }
+    return input->tensor.shape.h == 1 && input->tensor.shape.w == 1
+               ? GenerateCodeFor1x1(ctx, generated_code)
+               : GenerateCodeGeneral(ctx, generated_code);
+  }
 
-    float4 mask(0.0f);
-    const int channels = output->tensor.shape.c;
-    const int reminder = (channels % 4 == 0) ? 4 : channels % 4;
-    for (int i = 0; i < reminder; ++i) {
-      mask[i] = 1.0f;
+ private:
+  Status GenerateCodeFor1x1(const GenerationContext& ctx,
+                            GeneratedCode* generated_code) const {
+    const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0];
+    const int depth = IntegralDivideRoundUp(output->tensor.shape.c, 4);
+    std::vector<Variable> shared_variables = {
+        {"partial_sum", std::vector<float4>(8)},
+    };
+    std::vector<Variable> uniform_parameters = {
+        {"depth", depth},
+        {"depth_div_32", IntegralDivideRoundUp(depth, 32)},
+        {"mask", GetMask(output->tensor.shape.c)},
+    };
+    std::string source_code = R"(
+  highp float sum = 0.0f;
+  int offset = 0;
+  int s = 0;
+  int tid = int(gl_LocalInvocationID.x);
+  do {
+    int z = offset + tid;
+    if (z < $depth$) {
+      vec4 mask_temp = z == $depth$ - 1 ? $mask$ : vec4(1.0f);
+      vec4 src = $input_data_0[0, 0, z]$;
+      sum += dot(mask_temp, exp(src));
+      offset += 32;
     }
+    s++;
+  } while (s < $depth_div_32$);
+
+  partial_sum[tid / 4][tid % 4] = sum;
+
+  memoryBarrierShared();
+  barrier();
+
+  if (tid == 0) {
+    sum = dot(vec4(1.0f), partial_sum[0]);
+    sum += dot(vec4(1.0f), partial_sum[1]);
+    sum += dot(vec4(1.0f), partial_sum[2]);
+    sum += dot(vec4(1.0f), partial_sum[3]);
+    sum += dot(vec4(1.0f), partial_sum[4]);
+    sum += dot(vec4(1.0f), partial_sum[5]);
+    sum += dot(vec4(1.0f), partial_sum[6]);
+    sum += dot(vec4(1.0f), partial_sum[7]);
+    partial_sum[0][0] = 1.0 / sum;
+  }
+
+  memoryBarrierShared();
+  barrier();
+
+  sum = partial_sum[0][0];
+
+  offset = 0;
+  s = 0;
+  do {
+    int z = offset + tid;
+    if (z < $depth$) {
+      vec4 temp = exp($input_data_0[0, 0, z]$) * sum;
+      $output_data_0[0, 0, z]$ = temp;
+      offset += 32;
+    }
+    s++;
+  } while (s < $depth_div_32$);
+)";
+    *generated_code = {
+        /*parameters=*/std::move(uniform_parameters),
+        /*objects=*/{},
+        /*shared_variables=*/std::move(shared_variables),
+        /*workload=*/uint3(32, 1, 1),
+        /*workgroup=*/uint3(32, 1, 1),
+        /*source_code=*/std::move(source_code),
+        /*input=*/IOStructure::ONLY_DEFINITIONS,
+        /*output=*/IOStructure::ONLY_DEFINITIONS,
+    };
+    return OkStatus();
+  }
+
+  Status GenerateCodeGeneral(const GenerationContext& ctx,
+                             GeneratedCode* generated_code) const {
+    const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0];
     std::vector<Variable> parameters = {
         {"src_depth", IntegralDivideRoundUp(output->tensor.shape.c, 4)},
-        {"mask", mask},
+        {"mask", GetMask(output->tensor.shape.c)},
     };
 
-    std::string source = R"(
+    std::string source_code = R"(
   highp float sum = 0.0;
   for (int d = 0; d < $src_depth$ - 1; ++d) {
     sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
@@ -79,7 +162,7 @@
         /*shared_variables=*/{},
         /*workload=*/uint3(output->tensor.shape.w, output->tensor.shape.h, 1),
         /*workgroup=*/uint3(),
-        /*source_code=*/std::move(source),
+        /*source_code=*/std::move(source_code),
         /*input=*/IOStructure::ONLY_DEFINITIONS,
         /*output=*/IOStructure::ONLY_DEFINITIONS,
     };
@@ -89,8 +172,8 @@
 
 }  // namespace
 
-std::unique_ptr<NodeShader> NewSoftMaxNodeShader() {
-  return absl::make_unique<SoftMax>();
+std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
+  return absl::make_unique<Softmax>();
 }
 
 }  // namespace gl
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h
index 2eaf91b..2b6c786 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h
@@ -25,7 +25,7 @@
 namespace gpu {
 namespace gl {
 
-std::unique_ptr<NodeShader> NewSoftMaxNodeShader();
+std::unique_ptr<NodeShader> NewSoftmaxNodeShader();
 
 }  // namespace gl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
index 1c82a80..1707e1e 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
 
+#include <cmath>
 #include <vector>
 
 #include <gmock/gmock.h>
@@ -31,7 +32,7 @@
 namespace gl {
 namespace {
 
-TEST(SoftmaxTest, WorksForChannelsAxis) {
+TEST(SoftmaxTest, Softmax) {
   TensorRef<BHWC> input;
   input.type = DataType::FLOAT32;
   input.ref = 0;
@@ -42,14 +43,15 @@
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::CHANNELS;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2}));
-  ASSERT_OK(model.Invoke(*NewSoftMaxNodeShader()));
-  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 1, 1, 1}));
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f}));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6f), {1.0f, 1.0f, 1.0f, 1.0f}));
 }
 
 TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
@@ -63,15 +65,13 @@
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::HEIGHT;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
-  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_THAT(
-      model.Invoke(*NewSoftMaxNodeShader()).message(),
-      testing::HasSubstr("Softmax is only supported for channels axis."));
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f}));
+  EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
 }
 
 TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
@@ -85,15 +85,40 @@
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::WIDTH;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
-  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_THAT(
-      model.Invoke(*NewSoftMaxNodeShader()).message(),
-      testing::HasSubstr("Softmax is only supported for channels axis."));
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f}));
+  EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
+}
+
+TEST(SoftmaxTest, Softmax1x1) {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  const float sum =
+      std::exp(0.1f) + std::exp(0.2f) + std::exp(0.3f) + std::exp(0.4f);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
+                      {output});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1f, 0.2f, 0.3f, 0.4f}));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6f),
+                        {std::exp(0.1f) / sum, std::exp(0.2f) / sum,
+                         std::exp(0.3f) / sum, std::exp(0.4f) / sum}));
 }
 
 }  // namespace
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc
index e55eaf4..de6e324 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc
@@ -28,8 +28,8 @@
 #include "tensorflow/lite/delegates/gpu/gl/api.h"
 #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object_manager.h"
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
index b9ecd09..e84f3ef 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
@@ -41,8 +41,6 @@
     auto attr = absl::any_cast<const ConvolutionTransposedAttributes&>(
         ctx.node->operation.attributes);
     auto weights = attr.weights.shape;
-    const int32_t inner_size_w = (weights.w - 1) / attr.stride.w + 1;
-    const int32_t inner_size_h = (weights.h - 1) / attr.stride.h + 1;
 
     std::vector<Variable> parameters = {
         {"input_data_0_h", input->tensor.shape.h},
@@ -50,33 +48,25 @@
         {"src_depth", IntegralDivideRoundUp(weights.i, 4)},
         {"kernel_size", int2(weights.w, weights.h)},
         {"stride", int2(attr.stride.w, attr.stride.h)},
-        {"padding", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
-        {"inner_size", int2(inner_size_w, inner_size_h)},
+        {"padding", int2(weights.w - 1 - attr.padding.prepended.w,
+                         weights.h - 1 - attr.padding.prepended.h)},
     };
 
     std::vector<std::pair<std::string, Object>> objects = {
-        {"weights", MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape),
-                                       ConvertToPHWO4I4(attr.weights))}};
+        {"weights",
+         MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape),
+                            ConvertToPHWO4I4Transposed(attr.weights))}};
 
     std::string source = R"(
-    ivec2 kernel_offset = $kernel_size$ - ivec2(1,1);
-    ivec2 offset = gid.xy + $padding$ - kernel_offset;
-    offset %= $stride$;
-    offset += $stride$;
-    offset %= $stride$;
-    ivec2 f_offset;
-    f_offset.x = offset.x == 0 ? 0 : ($stride.x$ - offset.x);
-    f_offset.y = offset.y == 0 ? 0 : ($stride.y$ - offset.y);
-    for (int ky = 0; ky < $inner_size.y$; ++ky) {
-      for (int kx = 0; kx < $inner_size.x$; ++kx) {
-        ivec2 index = ivec2(kx, ky) * $stride$ + f_offset;
-        bool inside_kernel = index.x < $kernel_size.x$ && index.y < $kernel_size.y$;
-        ivec2 coord = (gid.xy + index + $padding$ - kernel_offset) / $stride$;
-        bool outside = coord.x < 0 || coord.y < 0 ||
-                       coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$;
-        if (inside_kernel && !outside) {
-          index = kernel_offset - index;
-          int i = index.y * $kernel_size.x$ + index.x;
+    #define IN_BOUNDS(p, p0, p1) (all(greaterThanEqual(p, p0)) && all(lessThan(p, p1)))
+
+    ivec2 p0 = ($padding$ + $stride$ - gid.xy % $stride$) % $stride$;
+    for (int y = p0.y; y < $kernel_size.y$; y += $stride.y$) {
+      for (int x = p0.x; x < $kernel_size.x$; x += $stride.x$) {
+        int i = y * $kernel_size.x$ + x;
+        ivec2 idx = gid.xy + ivec2(x, y) - $padding$;
+        if (IN_BOUNDS(idx, ivec2(0), ivec2($input_data_0_w$, $input_data_0_h$) * $stride$)) {
+          ivec2 coord = idx / $stride$;
           for (int l = 0; l < $src_depth$; ++l) {
             vec4 src_color = $input_data_0[coord.x, coord.y, l]$;
             value_0.x += dot(src_color, $weights[l * 4 + 0, i, gid.z]$);
diff --git a/tensorflow/lite/delegates/gpu/gl/node_shader.h b/tensorflow/lite/delegates/gpu/gl/node_shader.h
index 0225a7c..3836465 100644
--- a/tensorflow/lite/delegates/gpu/gl/node_shader.h
+++ b/tensorflow/lite/delegates/gpu/gl/node_shader.h
@@ -21,11 +21,11 @@
 #include <string>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
 
diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc
new file mode 100644
index 0000000..7134fc0
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc
@@ -0,0 +1,81 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+
+#include "absl/strings/ascii.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"
+#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+Status RequestGpuInfo(GpuInfo* gpu_info) {
+  GpuInfo info;
+
+  const GLubyte* renderer_name = glGetString(GL_RENDERER);
+  if (renderer_name) {
+    info.renderer_name = reinterpret_cast<const char*>(renderer_name);
+    GetGpuModelAndType(info.renderer_name, &info.gpu_model, &info.type);
+  }
+
+  const GLubyte* vendor_name = glGetString(GL_VENDOR);
+  if (vendor_name) {
+    info.vendor_name = reinterpret_cast<const char*>(vendor_name);
+  }
+
+  const GLubyte* version_name = glGetString(GL_VERSION);
+  if (version_name) {
+    info.version = reinterpret_cast<const char*>(version_name);
+  }
+
+  glGetIntegerv(GL_MAJOR_VERSION, &info.major_version);
+  glGetIntegerv(GL_MINOR_VERSION, &info.minor_version);
+
+  GLint extensions_count;
+  glGetIntegerv(GL_NUM_EXTENSIONS, &extensions_count);
+  info.extensions.resize(extensions_count);
+  for (int i = 0; i < extensions_count; ++i) {
+    info.extensions[i] = std::string(
+        reinterpret_cast<const char*>(glGetStringi(GL_EXTENSIONS, i)));
+  }
+  glGetIntegerv(GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS, &info.max_ssbo_bindings);
+  glGetIntegerv(GL_MAX_COMPUTE_IMAGE_UNIFORMS, &info.max_image_bindings);
+  info.max_work_group_size.resize(3);
+  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0,
+                  &info.max_work_group_size[0]);
+  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
+                  &info.max_work_group_size[1]);
+  glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2,
+                  &info.max_work_group_size[2]);
+  glGetIntegerv(GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS,
+                &info.max_work_group_invocations);
+  glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.max_texture_size);
+  glGetIntegerv(GL_MAX_IMAGE_UNITS, &info.max_image_units);
+  glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers);
+  RETURN_IF_ERROR(GetOpenGlErrors());
+  *gpu_info = info;
+  return OkStatus();
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h
new file mode 100644
index 0000000..4eba7a5
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.h
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_REQUEST_GPU_INFO_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_GL_REQUEST_GPU_INFO_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+// This method performs multiple GL calls, therefore, egl context needs to be
+// created upfront.
+Status RequestGpuInfo(GpuInfo* gpu_info);
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_REQUEST_GPU_INFO_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.cc b/tensorflow/lite/delegates/gpu/gl/runtime.cc
index 7249ac4..37bf66e 100644
--- a/tensorflow/lite/delegates/gpu/gl/runtime.cc
+++ b/tensorflow/lite/delegates/gpu/gl/runtime.cc
@@ -22,6 +22,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.h b/tensorflow/lite/delegates/gpu/gl/runtime.h
index 23fff93..46e0732 100644
--- a/tensorflow/lite/delegates/gpu/gl/runtime.h
+++ b/tensorflow/lite/delegates/gpu/gl/runtime.h
@@ -18,13 +18,13 @@
 
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/object.h"
 #include "tensorflow/lite/delegates/gpu/gl/object_manager.h"
 #include "tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h"
diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.cc b/tensorflow/lite/delegates/gpu/gl/serialization.cc
index 200ca1f..17db339 100644
--- a/tensorflow/lite/delegates/gpu/gl/serialization.cc
+++ b/tensorflow/lite/delegates/gpu/gl/serialization.cc
@@ -37,12 +37,14 @@
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const int2& value) {
     auto offset = builder->CreateVector(std::vector<int32_t>{value.x, value.y});
     data::DataInt32Builder data(*builder);
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const int4& value) {
     auto offset = builder->CreateVector(
         std::vector<int32_t>{value.x, value.y, value.z, value.w});
@@ -50,6 +52,7 @@
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const std::vector<int2>& value) {
     std::vector<int32_t> d(value.size() * 2);
     for (size_t i = 0; i < value.size(); ++i) {
@@ -61,12 +64,14 @@
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(uint32_t value) {
     auto offset = builder->CreateVector(std::vector<uint32_t>{value});
     data::DataUint32Builder data(*builder);
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const uint4& value) {
     auto offset = builder->CreateVector(
         std::vector<uint32_t>{value.x, value.y, value.z, value.w});
@@ -74,18 +79,21 @@
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(float value) {
     auto offset = builder->CreateVector(std::vector<float>{value});
     data::DataFloatBuilder data(*builder);
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const float2& value) {
     auto offset = builder->CreateVector(std::vector<float>{value.x, value.y});
     data::DataFloatBuilder data(*builder);
     data.add_data(offset);
     return data.Finish().Union();
   }
+
   Offset<void> operator()(const float4& value) {
     auto offset = builder->CreateVector(
         std::vector<float>{value.x, value.y, value.z, value.w});
@@ -94,6 +102,20 @@
     return data.Finish().Union();
   }
 
+  Offset<void> operator()(const std::vector<float4>& value) {
+    std::vector<float> d(value.size() * 4);
+    for (size_t i = 0; i < value.size(); ++i) {
+      d[i * 4] = value[i].x;
+      d[i * 4 + 1] = value[i].y;
+      d[i * 4 + 2] = value[i].z;
+      d[i * 4 + 3] = value[i].w;
+    }
+    auto offset = builder->CreateVector(d);
+    data::DataFloatBuilder data(*builder);
+    data.add_data(offset);
+    return data.Finish().Union();
+  }
+
   ::flatbuffers::FlatBufferBuilder* builder;
 };
 
@@ -101,60 +123,84 @@
   data::DataVariant operator()(int32_t) const {
     return data::DataVariant::DataInt32;
   }
+
   data::DataVariant operator()(const int2&) const {
     return data::DataVariant::DataInt32;
   }
+
   data::DataVariant operator()(const int4&) const {
     return data::DataVariant::DataInt32;
   }
+
   data::DataVariant operator()(const std::vector<int2>&) const {
     return data::DataVariant::DataInt32;
   }
+
   data::DataVariant operator()(uint32_t) const {
     return data::DataVariant::DataUint32;
   }
+
   data::DataVariant operator()(const uint4&) const {
     return data::DataVariant::DataUint32;
   }
+
   data::DataVariant operator()(float) const {
     return data::DataVariant::DataFloat;
   }
+
   data::DataVariant operator()(const float2&) const {
     return data::DataVariant::DataFloat;
   }
+
   data::DataVariant operator()(const float4&) const {
     return data::DataVariant::DataFloat;
   }
+
+  data::DataVariant operator()(const std::vector<float4>&) const {
+    return data::DataVariant::DataFloat;
+  }
 };
 
 struct ParameterTypeGetter {
   data::ParameterType operator()(int32_t) const {
     return data::ParameterType::INT32;
   }
+
   data::ParameterType operator()(const int2&) const {
     return data::ParameterType::INT32;
   }
+
   data::ParameterType operator()(const int4&) const {
     return data::ParameterType::INT32;
   }
+
   data::ParameterType operator()(const std::vector<int2>&) const {
     return data::ParameterType::INT32_2;
   }
+
   data::ParameterType operator()(uint32_t) const {
     return data::ParameterType::UINT32;
   }
+
   data::ParameterType operator()(const uint4&) const {
     return data::ParameterType::UINT32;
   }
+
   data::ParameterType operator()(float) const {
     return data::ParameterType::FLOAT32;
   }
+
   data::ParameterType operator()(const float2&) const {
     return data::ParameterType::FLOAT32;
   }
+
   data::ParameterType operator()(const float4&) const {
     return data::ParameterType::FLOAT32;
   }
+
+  data::ParameterType operator()(const std::vector<float4>&) const {
+    return data::ParameterType::FLOAT32;
+  }
 };
 
 data::DataType ToFB(DataType type) {
diff --git a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc
index 38db441..27a3583a 100644
--- a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc
+++ b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc
@@ -70,14 +70,17 @@
   bool operator()(int32_t value) const {
     return value == absl::get<int32_t>(a.value);
   }
+
   bool operator()(const int2& value) const {
     auto v = absl::get<int2>(a.value);
     return value.x == v.x && value.y == v.y;
   }
+
   bool operator()(const int4& value) const {
     auto v = absl::get<int4>(a.value);
     return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
   }
+
   bool operator()(const std::vector<int2>& value) const {
     auto v = absl::get<std::vector<int2>>(a.value);
     if (v.size() != value.size()) {
@@ -90,24 +93,43 @@
     }
     return true;
   }
+
   bool operator()(uint32_t value) const {
     return value == absl::get<uint32_t>(a.value);
   }
+
   bool operator()(const uint4& value) const {
     auto v = absl::get<uint4>(a.value);
     return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
   }
+
   bool operator()(float value) const {
     return value == absl::get<float>(a.value);
   }
+
   bool operator()(float2 value) const {
     auto v = absl::get<float2>(a.value);
     return value.x == v.x && value.y == v.y;
   }
+
   bool operator()(const float4& value) const {
     auto v = absl::get<float4>(a.value);
     return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
   }
+
+  bool operator()(const std::vector<float4>& value) const {
+    auto v = absl::get<std::vector<float4>>(a.value);
+    if (v.size() != value.size()) {
+      return false;
+    }
+    for (int i = 0; i < v.size(); ++i) {
+      if (v[i].x != value[i].x || v[i].y != value[i].y) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   Variable a;
 };
 
diff --git a/tensorflow/lite/delegates/gpu/gl/variable.h b/tensorflow/lite/delegates/gpu/gl/variable.h
index f2f3979..1c5bb26 100644
--- a/tensorflow/lite/delegates/gpu/gl/variable.h
+++ b/tensorflow/lite/delegates/gpu/gl/variable.h
@@ -28,8 +28,9 @@
 namespace gl {
 
 struct Variable {
-  using ValueType = absl::variant<int32_t, int2, int4, uint32_t, uint4, float,
-                                  float2, float4, std::vector<int2>>;
+  using ValueType =
+      absl::variant<int32_t, int2, int4, uint32_t, uint4, float, float2, float4,
+                    std::vector<int2>, std::vector<float4>>;
 
   std::string name;
   ValueType value;
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD b/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD
index 28a172b..52fdb74 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD
@@ -8,8 +8,8 @@
     srcs = ["calculator.cc"],
     hdrs = ["calculator.h"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
         "//tensorflow/lite/delegates/gpu/gl/compiler:shader_code",
     ],
 )
@@ -20,8 +20,8 @@
     hdrs = ["default_calculator.h"],
     deps = [
         ":calculator",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
     ],
 )
 
@@ -35,7 +35,7 @@
             ":default_calculator",
             "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs",
             "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs",
-            "//tensorflow/lite/delegates/gpu/gl:gpu_info",
+            "//tensorflow/lite/delegates/gpu/common:gpu_info",
             "//tensorflow/lite/delegates/gpu/gl:metadata_cc_fbs",
             ":calculator",
             "@com_google_absl//absl/memory",
@@ -52,7 +52,7 @@
     deps = [
         ":calculator",
         ":default_calculator",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
     ] + select({
         "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [],
         "//conditions:default": [
@@ -67,9 +67,9 @@
     hdrs = ["ideal_workgroup_picker.h"],
     deps = [
         ":calculator",
+        "//tensorflow/lite/delegates/gpu/common:gpu_info",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
-        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc
index f0a1c4f..528d75d 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h"
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
 
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h
index 56d192d..e277e45 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h
@@ -16,7 +16,7 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_
 
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
index 82ddf00..e21538b 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc
@@ -15,9 +15,9 @@
 
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h
index c59a943..1322474 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h
@@ -18,9 +18,9 @@
 
 #include <memory>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc
index 673eedc3..b258f2c 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc
@@ -20,15 +20,15 @@
 #include <memory>
 #include <unordered_map>
 
+#include "absl/memory/memory.h"
+#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups_generated.h"
 
-#include "absl/memory/memory.h"
-#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-
 #endif  // TFLITE_GPU_BINARY_RELEASE
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h
index cca859f..4c034b1 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h
@@ -16,7 +16,7 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_
 
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc
index ebfba14..7b6358e 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h
index c8840ab..6053c9e 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h
@@ -16,7 +16,7 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_
 
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc
index 07dffa3..65636fe 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc
@@ -18,10 +18,10 @@
 #include <map>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h
index 34461bd..34f628c 100644
--- a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h
+++ b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h
@@ -16,10 +16,10 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_IDEAL_WORKGROUP_PICKER_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_IDEAL_WORKGROUP_PICKER_H_
 
+#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
-#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc
index f624fb9..2576ed4 100644
--- a/tensorflow/lite/delegates/gpu/gl_delegate.cc
+++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
 #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h"
+#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h"
 #include "tensorflow/lite/minimal_logging.h"
 
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 5c16a5f..c6dc95d 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -22,6 +22,7 @@
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/metal/kernels",
+        "//tensorflow/lite/delegates/gpu/metal/kernels:custom_registry",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index ae0b8c4..7395f7b 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -17,6 +17,7 @@
 
 #include <vector>
 
+#include "absl/strings/substitute.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
@@ -27,6 +28,7 @@
 #include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
@@ -115,12 +117,148 @@
   }
 }
 
+Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
+                          const std::vector<ValueId>& inputs,
+                          const std::vector<ValueId>& outputs,
+                          const RuntimeOptions& options,
+                          std::vector<ComputeTaskDescriptorPtr>* tasks) {
+  int node_id = static_cast<int>(node->id);
+  auto op_type = OperationTypeFromString(node->operation.type);
+  switch (op_type) {
+    case OperationType::ADD:
+      *tasks = Add(node_id, inputs, outputs[0],
+                   absl::any_cast<AddAttributes>(node->operation.attributes),
+                   options);
+      break;
+    case OperationType::CONCAT: {
+      std::vector<BHWC> input_shapes;
+      for (auto& input : graph.FindInputs(node->id)) {
+        input_shapes.push_back(input->tensor.shape);
+      }
+      *tasks =
+          Concat(node_id, inputs, outputs[0],
+                 absl::any_cast<ConcatAttributes>(node->operation.attributes),
+                 input_shapes);
+      break;
+    }
+    case OperationType::CONVOLUTION_2D:
+      *tasks = SelectConvolution(
+          graph, node_id, inputs[0], outputs[0],
+          absl::any_cast<Convolution2DAttributes>(node->operation.attributes),
+          options);
+      break;
+    case OperationType::CONVOLUTION_TRANSPOSED:
+      *tasks =
+          ConvolutionTransposed(node_id, inputs[0], outputs[0],
+                                absl::any_cast<ConvolutionTransposedAttributes>(
+                                    node->operation.attributes),
+                                options);
+      break;
+    case OperationType::DEPTHWISE_CONVOLUTION:
+      *tasks =
+          SelectDepthWiseConv(node_id, inputs[0], outputs[0],
+                              absl::any_cast<DepthwiseConvolution2DAttributes>(
+                                  node->operation.attributes),
+                              options);
+      break;
+    case OperationType::FULLY_CONNECTED:
+      *tasks = FullyConnected(
+          node_id, inputs[0], outputs[0],
+          absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
+          options);
+      break;
+    case OperationType::HARD_SWISH:
+      *tasks = HardSwish(node_id, inputs[0], outputs[0], options);
+      break;
+    case OperationType::MAX_UNPOOLING_2D:
+      *tasks = MaxUnpooling(
+          node_id, inputs[0], inputs[1], outputs[0],
+          absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes));
+      break;
+    case OperationType::MULTIPLY_SCALAR:
+      *tasks = Multiply(
+          node_id, inputs[0], outputs[0],
+          absl::any_cast<MultiplyScalarAttributes>(node->operation.attributes),
+          options);
+      break;
+    case OperationType::PAD:
+      *tasks =
+          Padding(node_id, inputs[0], outputs[0],
+                  absl::any_cast<PadAttributes>(node->operation.attributes));
+      break;
+    case OperationType::POOLING_2D:
+      *tasks = Pooling(
+          node_id, inputs[0], outputs,
+          absl::any_cast<Pooling2DAttributes>(node->operation.attributes));
+      break;
+    case OperationType::PRELU:
+      *tasks = PReLU(
+          node_id, inputs[0], outputs[0],
+          absl::any_cast<PReLUAttributes>(node->operation.attributes), options);
+      break;
+    case OperationType::RELU:
+      *tasks = ReLU(node_id, inputs[0], outputs[0],
+                    absl::any_cast<ReLUAttributes>(node->operation.attributes));
+      break;
+    case OperationType::RESHAPE:
+      *tasks = SelectReshape(
+          graph, node_id, inputs[0], outputs[0],
+          absl::any_cast<ReshapeAttributes>(node->operation.attributes));
+      break;
+    case OperationType::SLICE:
+      *tasks =
+          Slice(node_id, inputs[0], outputs[0],
+                absl::any_cast<SliceAttributes>(node->operation.attributes));
+      break;
+    case OperationType::SOFTMAX: {
+      auto attr = absl::any_cast<SoftmaxAttributes>(node->operation.attributes);
+      if (attr.axis != Axis::CHANNELS) {
+        return UnimplementedError("Softmax supports only CHANNELS dimension");
+      }
+      *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
+      break;
+    }
+    case OperationType::UPSAMPLE_2D:
+      *tasks = Upsample(
+          node_id, inputs[0], outputs[0],
+          absl::any_cast<Upsample2DAttributes>(node->operation.attributes));
+      break;
+    case OperationType::ABS:
+    case OperationType::COS:
+    case OperationType::LOG:
+    case OperationType::RSQRT:
+    case OperationType::SIGMOID:
+    case OperationType::SIN:
+    case OperationType::SQRT:
+    case OperationType::SQUARE:
+    case OperationType::TANH:
+      *tasks = ElementwiseWithOneInput(node_id, inputs[0], outputs[0], op_type);
+      break;
+    case OperationType::SUB:
+    case OperationType::DIV:
+    case OperationType::POW:
+    case OperationType::SQUARED_DIFF:
+      *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
+      break;
+    case OperationType::APPLY_MASK:
+    case OperationType::BATCH_NORMALIZATION:
+    case OperationType::BATCH_TO_SPACE:
+    case OperationType::CONST:
+    case OperationType::LSTM:
+    case OperationType::MUL:
+    case OperationType::RESIZE:
+    case OperationType::SPACE_TO_BATCH:
+    case OperationType::UNKNOWN:
+      return UnimplementedError("Unsupported op: " + node->operation.type);
+  }
+  return OkStatus();
+}
+
 }  // namespace
 
 Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
                CompiledModel* compiled_model) {
   for (const auto& node : graph.nodes()) {
-    int node_id = static_cast<int>(node->id);
     std::vector<ValueId> inputs;
     for (auto& input : graph.FindInputs(node->id)) {
       inputs.push_back(static_cast<ValueId>(input->id));
@@ -129,142 +267,19 @@
     for (auto& output : graph.FindOutputs(node->id)) {
       outputs.push_back(static_cast<ValueId>(output->id));
     }
-
     std::vector<ComputeTaskDescriptorPtr> tasks;
-    auto op_type = OperationTypeFromString(node->operation.type);
-    switch (op_type) {
-      case OperationType::ADD:
-        tasks = Add(node_id, inputs, outputs[0],
-                    absl::any_cast<AddAttributes>(node->operation.attributes),
-                    options);
-        break;
-      case OperationType::CONCAT: {
-        std::vector<BHWC> input_shapes;
-        for (auto& input : graph.FindInputs(node->id)) {
-          input_shapes.push_back(input->tensor.shape);
-        }
-        tasks =
-            Concat(node_id, inputs, outputs[0],
-                   absl::any_cast<ConcatAttributes>(node->operation.attributes),
-                   input_shapes);
-        break;
+    auto custom_status =
+        RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
+    if (!custom_status.ok()) {
+      auto primary_status =
+          RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks);
+      if (!primary_status.ok()) {
+        return UnimplementedError(
+            absl::Substitute("Unsupported op type: $0; custom registry error: "
+                             "$1; primary registry error: $2;",
+                             node->operation.type, custom_status.message(),
+                             primary_status.message()));
       }
-      case OperationType::CONVOLUTION_2D:
-        tasks = SelectConvolution(
-            graph, node_id, inputs[0], outputs[0],
-            absl::any_cast<Convolution2DAttributes>(node->operation.attributes),
-            options);
-        break;
-      case OperationType::CONVOLUTION_TRANSPOSED:
-        tasks = ConvolutionTransposed(
-            node_id, inputs[0], outputs[0],
-            absl::any_cast<ConvolutionTransposedAttributes>(
-                node->operation.attributes),
-            options);
-        break;
-      case OperationType::DEPTHWISE_CONVOLUTION:
-        tasks = SelectDepthWiseConv(
-            node_id, inputs[0], outputs[0],
-            absl::any_cast<DepthwiseConvolution2DAttributes>(
-                node->operation.attributes),
-            options);
-        break;
-      case OperationType::FULLY_CONNECTED:
-        tasks = FullyConnected(node_id, inputs[0], outputs[0],
-                               absl::any_cast<FullyConnectedAttributes>(
-                                   node->operation.attributes),
-                               options);
-        break;
-      case OperationType::HARD_SWISH:
-        tasks = HardSwish(node_id, inputs[0], outputs[0], options);
-        break;
-      case OperationType::MAX_UNPOOLING_2D:
-        tasks = MaxUnpooling(node_id, inputs[0], inputs[1], outputs[0],
-                             absl::any_cast<MaxUnpooling2DAttributes>(
-                                 node->operation.attributes));
-        break;
-      case OperationType::MULTIPLY_SCALAR:
-        tasks = Multiply(node_id, inputs[0], outputs[0],
-                         absl::any_cast<MultiplyScalarAttributes>(
-                             node->operation.attributes),
-                         options);
-        break;
-      case OperationType::PAD:
-        tasks =
-            Padding(node_id, inputs[0], outputs[0],
-                    absl::any_cast<PadAttributes>(node->operation.attributes));
-        break;
-      case OperationType::POOLING_2D:
-        tasks = Pooling(
-            node_id, inputs[0], outputs,
-            absl::any_cast<Pooling2DAttributes>(node->operation.attributes));
-        break;
-      case OperationType::PRELU:
-        tasks =
-            PReLU(node_id, inputs[0], outputs[0],
-                  absl::any_cast<PReLUAttributes>(node->operation.attributes),
-                  options);
-        break;
-      case OperationType::RELU:
-        tasks =
-            ReLU(node_id, inputs[0], outputs[0],
-                 absl::any_cast<ReLUAttributes>(node->operation.attributes));
-        break;
-      case OperationType::RESHAPE:
-        tasks = SelectReshape(
-            graph, node_id, inputs[0], outputs[0],
-            absl::any_cast<ReshapeAttributes>(node->operation.attributes));
-        break;
-      case OperationType::SLICE:
-        tasks =
-            Slice(node_id, inputs[0], outputs[0],
-                  absl::any_cast<SliceAttributes>(node->operation.attributes));
-        break;
-      case OperationType::SOFT_MAX: {
-        auto attr =
-            absl::any_cast<SoftMaxAttributes>(node->operation.attributes);
-        if (attr.axis != Axis::CHANNELS) {
-          return UnimplementedError("Softmax supports only CHANNELS dimension");
-        }
-        tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
-        break;
-      }
-      case OperationType::UPSAMPLE_2D:
-        tasks = Upsample(
-            node_id, inputs[0], outputs[0],
-            absl::any_cast<Upsample2DAttributes>(node->operation.attributes));
-        break;
-
-      case OperationType::ABS:
-      case OperationType::COS:
-      case OperationType::LOG:
-      case OperationType::RSQRT:
-      case OperationType::SIGMOID:
-      case OperationType::SIN:
-      case OperationType::SQRT:
-      case OperationType::SQUARE:
-      case OperationType::TANH:
-        tasks =
-            ElementwiseWithOneInput(node_id, inputs[0], outputs[0], op_type);
-        break;
-
-      case OperationType::SUB:
-      case OperationType::DIV:
-      case OperationType::POW:
-      case OperationType::SQUARED_DIFF:
-        tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
-        break;
-
-      case OperationType::APPLY_MASK:
-      case OperationType::BATCH_NORMALIZATION:
-      case OperationType::BATCH_TO_SPACE:
-      case OperationType::CONST:
-      case OperationType::LSTM:
-      case OperationType::MUL:
-      case OperationType::RESIZE:
-      case OperationType::SPACE_TO_BATCH:
-      case OperationType::UNKNOWN:
-        return UnimplementedError("Unsupported op: " + node->operation.type);
     }
     compiled_model->insert(compiled_model->end(), tasks.begin(), tasks.end());
   }
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 3a33b73..17e59e7 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -142,6 +142,19 @@
 )
 
 cc_library(
+    name = "custom_registry",
+    srcs = ["custom_registry.cc"],
+    hdrs = ["custom_registry.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
+    ],
+)
+
+cc_library(
     name = "depthwise_conv",
     srcs = ["depthwise_conv.cc"],
     hdrs = ["depthwise_conv.h"],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc
new file mode 100644
index 0000000..228583c
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h"
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+Status RegisterCustomOps(const GraphFloat32& graph, const Node* node,
+                         const std::vector<ValueId>& inputs,
+                         const std::vector<ValueId>& outputs,
+                         const RuntimeOptions& options,
+                         std::vector<ComputeTaskDescriptorPtr>* tasks) {
+  return UnimplementedError("Unsupported op: " + node->operation.type);
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h
new file mode 100644
index 0000000..bef2ba2
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h
@@ -0,0 +1,41 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+// Registers custom operations.
+Status RegisterCustomOps(const GraphFloat32& graph, const Node* node,
+                         const std::vector<ValueId>& inputs,
+                         const std::vector<ValueId>& outputs,
+                         const RuntimeOptions& options,
+                         std::vector<ComputeTaskDescriptorPtr>* tasks);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CUSTOM_REGISTRY_H_
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.h b/tensorflow/lite/delegates/gpu/metal_delegate.h
index d38e73a..a47be53 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.h
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.h
@@ -18,8 +18,6 @@
 
 #import <Metal/Metal.h>
 
-#include <functional>
-
 #include "tensorflow/lite/c/c_api_internal.h"
 
 // Creates a new delegate instance that need to be destroyed with
@@ -48,25 +46,18 @@
 // When `options` is set to `nullptr`, the following default values are used:
 // .precision_loss_allowed = false,
 // .wait_type = kPassive,
-TfLiteDelegate* NewGpuDelegate(const GpuDelegateOptions* options);
+TfLiteDelegate* TFLGpuDelegateCreate(const GpuDelegateOptions* options);
 
-// Destroys a delegate created with `NewGpuDelegate` call.
-void DeleteGpuDelegate(TfLiteDelegate* delegate);
+// Destroys a delegate created with `TFLGpuDelegateCreate` call.
+void TFLGpuDelegateDelete(TfLiteDelegate* delegate);
 
 // Binds Metal buffer to an input or an output tensor in the initialized
 // delegate.  Bound buffer should have sufficient storage to accommodate all
 // elements of a tensor.  Returns non-zero on success, or zero otherwise.
 //
 // *** Must be called *before* `Interpreter::ModifyGraphWithDelegate`. ***
-bool BindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
-                             id<MTLBuffer> metal_buffer);
-
-// Binds user-defined MTLComputeCommandEncoder. The delegate puts all GPU tasks
-// into this encoder instead of the internal encoder.
-// The callback is a user-defined function to take control over encoder and
-// command buffer. Can be nullptr.
-bool TFLSetCommandEncoder(
-    TfLiteDelegate* delegate, id<MTLComputeCommandEncoder> encoder,
-    std::function<id<MTLComputeCommandEncoder>(bool is_last)> control_encoder);
+bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate,
+                                           int tensor_index,
+                                           id<MTLBuffer> metal_buffer);
 
 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm
index 36d60e9..a24aa30 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.mm
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm
@@ -107,10 +107,12 @@
       total_alarms_ = 1;
       NSString* error;
       id<MTLComputePipelineState> program;
+      // TODO(impjdi): Properly handle returned status.
       CreateComputeProgram(device_,
                            @"kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) "
                            @"{ output_buffer[0] = 0; }",
-                           @"ComputeFunction", nullptr, &program);
+                           @"ComputeFunction", nullptr, &program)
+          .IgnoreError();
       stub_program_ = program;
       stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4
                                           options:MTLResourceHazardTrackingModeUntracked];
@@ -185,7 +187,9 @@
         )";
       NSString* error;
       id<MTLComputePipelineState> signal_program;
-      CreateComputeProgram(metal_device_, code, @"ComputeFunction", nullptr, &signal_program);
+      // TODO(impjdi): Properly handle returned status.
+      CreateComputeProgram(metal_device_, code, @"ComputeFunction", nullptr, &signal_program)
+          .IgnoreError();
       signal_program_ = signal_program;
       signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4
                                                   options:MTLResourceStorageModeShared |
@@ -613,22 +617,25 @@
 }  // namespace gpu
 }  // namespace tflite
 
-TfLiteDelegate* NewGpuDelegate(const GpuDelegateOptions* options) {
+TfLiteDelegate* TFLGpuDelegateCreate(const GpuDelegateOptions* options) {
   TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal.");
   auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options);
   return metal_delegate ? metal_delegate->tflite_delegate() : nullptr;
 }
 
-void DeleteGpuDelegate(TfLiteDelegate* delegate) {
+void TFLGpuDelegateDelete(TfLiteDelegate* delegate) {
   delete ::tflite::gpu::metal::GetMetalDelegate(delegate);
 }
 
-bool BindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index, id<MTLBuffer> buffer) {
+bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
+                                           id<MTLBuffer> buffer) {
   auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
   return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok();
 }
 
-bool TFLSetCommandEncoder(
+// Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
+// `metal_delegate_internal.h`.
+bool TFLGpuDelegateSetCommandEncoder(
     TfLiteDelegate* delegate, id<MTLComputeCommandEncoder> encoder,
     std::function<id<MTLComputeCommandEncoder>(bool is_last)> control_encoder) {
   auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
new file mode 100644
index 0000000..bc8ecdc
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_INTERNAL_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_INTERNAL_H_
+
+#import <Metal/Metal.h>
+
+#include <functional>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+
+// Binds user-defined MTLComputeCommandEncoder. The delegate puts all GPU tasks
+// into this encoder instead of the internal encoder.
+// The callback is a user-defined function to take control over encoder and
+// command buffer. Can be nullptr.
+bool TFLGpuDelegateSetCommandEncoder(
+    TfLiteDelegate* delegate, id<MTLComputeCommandEncoder> encoder,
+    std::function<id<MTLComputeCommandEncoder>(bool is_last)> control_encoder);
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_INTERNAL_H_
diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 7cd5d14..954a943 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -18,9 +18,14 @@
         ],
         "//conditions:default": [
             "nnapi_delegate.cc",
+            "quant_lstm_sup.h",
+            "quant_lstm_sup.cc",
         ],
     }),
-    hdrs = ["nnapi_delegate.h"],
+    hdrs = [
+        "nnapi_delegate.h",
+        "nnapi_delegate_kernel.h",
+    ],
     deps = [
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:kernel_api",
@@ -51,4 +56,22 @@
     ],
 )
 
+cc_test(
+    name = "quant_lstm_sup_test",
+    size = "small",
+    srcs = [
+        "quant_lstm_sup.cc",
+        "quant_lstm_sup.h",
+        "quant_lstm_sup_test.cc",
+    ],
+    deps = [
+        ":nnapi_delegate",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/kernels:kernel_util",
+        "//tensorflow/lite/testing:util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 tflite_portable_test_suite()
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 4b4737c..10b743a 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -14,27 +14,23 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 
+#include <algorithm>
 #include <cstdarg>
+#include <cstddef>
 #include <cstdint>
 #include <cstring>
 #include <functional>
+#include <initializer_list>
 #include <iostream>
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <vector>
 
-#include "tensorflow/lite/allocation.h"
-#include "tensorflow/lite/builtin_op_data.h"
-#include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/context_util.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/minimal_logging.h"
-#include "tensorflow/lite/nnapi/nnapi_implementation.h"
-#include "tensorflow/lite/util.h"
-
+// This section needs to be before the import of nnapi_delegate_kernel
+// because the code changes according to  the definition of
+// TFLITE_NNAPI_ALLOW_MMAP_SHARING
 #ifdef __ANDROID__
 #include <sys/system_properties.h>
 #endif
@@ -44,6 +40,19 @@
 #include <unistd.h>
 #endif
 
+#include "tensorflow/lite/allocation.h"
+#include "tensorflow/lite/builtin_op_data.h"
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
+#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/minimal_logging.h"
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+#include "tensorflow/lite/util.h"
+
 namespace tflite {
 namespace {
 
@@ -59,8 +68,6 @@
     }                                                                         \
   } while (0)
 
-namespace {
-
 bool IsFloat(TfLiteType type) {
   switch (type) {
     case kTfLiteFloat32:
@@ -120,6 +127,12 @@
   return IsFloatOrUInt8(input_type);
 }
 
+bool IsFloatOrQuant8Operator(const TfLiteContext* context,
+                             const TfLiteNode* node) {
+  const auto input_type = context->tensors[node->inputs->data[0]].type;
+  return IsFloat(input_type) || IsQuantized(input_type);
+}
+
 // Check if the operation requires explict conversion from int8 to uint8 values.
 bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
                         const TfLiteNode* node) {
@@ -140,9 +153,50 @@
       }
       return false;
     }
+    case kTfLiteBuiltinSelect: {
+      const auto value_type = context->tensors[node->inputs->data[1]].type;
+      return value_type == kTfLiteInt8;
+    }
+    case kTfLiteBuiltinAdd:
+    case kTfLiteBuiltinArgMax:
+    case kTfLiteBuiltinArgMin:
+    case kTfLiteBuiltinAveragePool2d:
+    case kTfLiteBuiltinBatchToSpaceNd:
+    case kTfLiteBuiltinConcatenation:
+    case kTfLiteBuiltinEqual:
+    case kTfLiteBuiltinExpandDims:
+    case kTfLiteBuiltinGreater:
+    case kTfLiteBuiltinGreaterEqual:
     case kTfLiteBuiltinL2Normalization:
+    case kTfLiteBuiltinLess:
+    case kTfLiteBuiltinLessEqual:
+    case kTfLiteBuiltinLogistic:
+    case kTfLiteBuiltinMaximum:
+    case kTfLiteBuiltinMaxPool2d:
+    case kTfLiteBuiltinMean:
+    case kTfLiteBuiltinMinimum:
+    case kTfLiteBuiltinMul:
+    case kTfLiteBuiltinNotEqual:
+    case kTfLiteBuiltinPad:
+    case kTfLiteBuiltinPadv2:
+    case kTfLiteBuiltinReduceMax:
+    case kTfLiteBuiltinReduceMin:
+    case kTfLiteBuiltinRelu:
+    case kTfLiteBuiltinReluN1To1:
+    case kTfLiteBuiltinRelu6:
+    case kTfLiteBuiltinResizeBilinear:
+    case kTfLiteBuiltinResizeNearestNeighbor:
+    case kTfLiteBuiltinReshape:
+    case kTfLiteBuiltinSlice:
+    case kTfLiteBuiltinSoftmax:
+    case kTfLiteBuiltinSpaceToBatchNd:
+    case kTfLiteBuiltinSpaceToDepth:
+    case kTfLiteBuiltinStridedSlice:
     case kTfLiteBuiltinSub:
-    case kTfLiteBuiltinTanh: {
+    case kTfLiteBuiltinTanh:
+    case kTfLiteBuiltinTile:
+    case kTfLiteBuiltinTopkV2:
+    case kTfLiteBuiltinTranspose: {
       return input_type == kTfLiteInt8;
     }
     default:
@@ -150,6 +204,22 @@
   }
 }
 
+constexpr int kLstmFullKernelInputSize = 24;
+// The 20 input version is deprecated and kept only to
+// support old model. The latest version of the LSTM Full Kernel
+// is the one with 24 inputs
+constexpr int kLstmFullKernelNoOptionalParamsInputSize = 20;
+constexpr int kLstmBasicKernelInputSize = 5;
+
+inline bool isLstmBasicKernel(const TfLiteNode* node) {
+  return node->inputs->size == kLstmBasicKernelInputSize;
+}
+
+inline bool isLstmFullKernel(const TfLiteNode* node) {
+  return node->inputs->size == kLstmFullKernelInputSize ||
+         node->inputs->size == kLstmFullKernelNoOptionalParamsInputSize;
+}
+
 bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
                       const TfLiteNode* node) {
   switch (builtin_code) {
@@ -161,7 +231,15 @@
       const TfLiteType filter_type = context->tensors[filter_id].type;
       return IsFloat(input_type) && IsQuantized(filter_type);
     }
-    case kTfLiteBuiltinLstm:
+    case kTfLiteBuiltinLstm: {
+      const int input_id = node->inputs->data[0];
+      // Input #1 is optional so use #2 to determine if hybrid.
+      const int weights_id = node->inputs->data[2];
+      const TfLiteType input_type = context->tensors[input_id].type;
+      const TfLiteType weights_type = context->tensors[weights_id].type;
+      return isLstmFullKernel(node) && IsFloat(input_type) &&
+             IsQuantized(weights_type);
+    }
     case kTfLiteBuiltinUnidirectionalSequenceLstm: {
       const int input_id = node->inputs->data[0];
       // Input #1 is optional so use #2 to determine if hybrid.
@@ -207,9 +285,6 @@
   return input_scale * filter_scale < output_scale;
 }
 
-constexpr int32_t kMinSdkVersionForNNAPI = 27;
-constexpr int32_t kMinSdkVersionForNNAPI11 = 28;
-constexpr int32_t kMinSdkVersionForNNAPI12 = 29;
 constexpr size_t kDefaultByteAlignmentForNNAPI = 16;
 
 static size_t getNumPaddingBytes(size_t byte_size) {
@@ -221,16 +296,33 @@
   return num_padding_bytes;
 }
 
+std::string SimpleJoin(const std::vector<const char*>& elements,
+                       const char* separator) {
+  // Note that we avoid use of sstream to avoid binary size bloat.
+  std::string joined_elements;
+  for (auto it = elements.begin(); it != elements.end(); ++it) {
+    if (separator && it != elements.begin()) {
+      joined_elements += separator;
+    }
+    if (*it) {
+      joined_elements += *it;
+    }
+  }
+  return joined_elements;
+}
+
 // Return NNAPI device handle with the provided null-terminated device name. If
 // no matching device could be found, nullptr will be returned.
-ANeuralNetworksDevice* GetDeviceHandle(const char* device_name_ptr) {
+ANeuralNetworksDevice* GetDeviceHandle(TfLiteContext* context,
+                                       const char* device_name_ptr) {
   if (!device_name_ptr) return nullptr;
   ANeuralNetworksDevice* device_handle = nullptr;
   std::string device_name(device_name_ptr);
-  uint32_t numDevices = 0;
-  NnApiImplementation()->ANeuralNetworks_getDeviceCount(&numDevices);
+  uint32_t num_devices = 0;
+  NnApiImplementation()->ANeuralNetworks_getDeviceCount(&num_devices);
 
-  for (uint32_t i = 0; i < numDevices; i++) {
+  std::vector<const char*> device_names;
+  for (uint32_t i = 0; i < num_devices; i++) {
     ANeuralNetworksDevice* device = nullptr;
     const char* buffer = nullptr;
     NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
@@ -239,6 +331,14 @@
       device_handle = device;
       break;
     }
+    device_names.push_back(buffer);
+  }
+  if (!device_handle) {
+    context->ReportError(context,
+                         "Could not find the specified NNAPI accelerator: %s. "
+                         "Must be one of: {%s}.",
+                         device_name_ptr,
+                         SimpleJoin(device_names, ",").c_str());
   }
   return device_handle;
 }
@@ -270,18 +370,8 @@
 
 }  // namespace
 
-// RAII NN API Model Destructor for use with std::unique_ptr
-struct NNFreeModel {
-  void operator()(ANeuralNetworksModel* model) {
-    NnApiImplementation()->ANeuralNetworksModel_free(model);
-  }
-};
-// RAII NN API Compilation Destructor for use with std::unique_ptr
-struct NNFreeCompilation {
-  void operator()(ANeuralNetworksCompilation* model) {
-    NnApiImplementation()->ANeuralNetworksCompilation_free(model);
-  }
-};
+namespace delegate {
+namespace nnapi {
 
 // RAII NN API Execution Destructor for use with std::unique_ptr
 struct NNFreeExecution {
@@ -290,110 +380,6 @@
   }
 };
 
-// Manage NNAPI shared memory handle
-class NNMemory {
- public:
-#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
-  NNMemory(const NnApi* nnapi, const char* name, size_t size) {
-    if (name && size > 0) {
-      nnapi_ = nnapi;
-      byte_size_ = size;
-      fd_ = nnapi_->ASharedMemory_create(name, size);
-      data_ptr_ = reinterpret_cast<uint8_t*>(
-          mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0));
-      nnapi_->ANeuralNetworksMemory_createFromFd(size, PROT_READ | PROT_WRITE,
-                                                 fd_, 0, &nn_memory_handle_);
-    }
-  }
-#else
-  NNMemory(const NnApi* /*nnapi*/, const char* /*name*/, size_t /*size*/) {}
-#endif
-
-  ~NNMemory() {
-#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
-    if (data_ptr_) {
-      munmap(data_ptr_, byte_size_);
-    }
-    if (nn_memory_handle_) {
-      nnapi_->ANeuralNetworksMemory_free(nn_memory_handle_);
-    }
-    if (fd_ > 0) close(fd_);
-#endif
-  }
-
-  ANeuralNetworksMemory* get_handle() { return nn_memory_handle_; }
-  uint8_t* get_data_ptr() { return data_ptr_; }
-
- private:
-#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
-  const NnApi* nnapi_;
-  int fd_ = 0;
-  size_t byte_size_ = 0;
-#endif
-  uint8_t* data_ptr_ = nullptr;
-  ANeuralNetworksMemory* nn_memory_handle_ = nullptr;
-};  // namespace
-
-// Track tensor indices to NN API tensor indices mapping.
-class OperandMapping {
- public:
-  // Given a TFLite index return the ANN index. If it doesn't exist
-  // return -1.
-  int lite_index_to_ann(int index) const {
-    if (index < lite_tensor_to_ann_tensor_.size())
-      return lite_tensor_to_ann_tensor_[index];
-    else
-      return -1;
-  }
-
-  // NN API uses non tensor operands instead of structs. This creates one
-  // and returns the index. It uses a std::vector and resizes it as needed
-  // keeping -1 to unmapped values. Intermediate tensors likely will not
-  // be mapped.
-  int add_new_non_tensor_operand() { return next_ann_tensor_index_++; }
-
-  // Add a new mapping from `tflite_index` and return the NN API tensor index.
-  int add_new_ann_tensor_index(int tflite_index) {
-    if (tflite_index >= lite_tensor_to_ann_tensor_.size()) {
-      lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1);
-    }
-    int new_tensor_index = next_ann_tensor_index_++;
-    lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index;
-    return new_tensor_index;
-  }
-
-  // Given a TFLite index returns a TFLite type to which a tensor must be
-  // converted during copying the data to the memory allocated for NN API.
-  // kTfLiteNoType means no conversion is needed.
-  TfLiteType lite_index_to_ann_type_conversion(int index) const {
-    if (index >= 0 && index < index_to_type_conversion_.size())
-      return index_to_type_conversion_[index];
-    else
-      return kTfLiteNoType;
-  }
-
-  // Add a new mapping from TFLite index to a type conversion.
-  void add_type_conversion(int tflite_index, TfLiteType tflite_type) {
-    if (tflite_index >= index_to_type_conversion_.size()) {
-      index_to_type_conversion_.resize(tflite_index + 1, kTfLiteNoType);
-    }
-    index_to_type_conversion_[tflite_index] = tflite_type;
-  }
-
- private:
-  // Next index of ann tensor
-  int next_ann_tensor_index_ = 0;
-
-  // Mapping from lite index. Use a std::vector for speed and code size
-  // rather than a map.
-  std::vector<int> lite_tensor_to_ann_tensor_;
-  // Mapping from lite index to a type which tensor must be converted to during
-  // the copying of the data to the memory allocated for NN API. kTfLiteNoType
-  // means no conversion is needed. Use an std::vector for speed and code size
-  // rather than a map.
-  std::vector<TfLiteType> index_to_type_conversion_;
-};
-
 class DequantizeMapping {
  public:
   int DequantizedAnnIndex(int ann_index, TfLiteType type) const {
@@ -450,7 +436,14 @@
   TfLiteStatus AddVectorInt32Operand(const int32_t* values,
                                      uint32_t num_values) {
     return AddVectorOperand<int32_t>(values, num_values,
-                                     ANEURALNETWORKS_TENSOR_INT32);
+                                     ANEURALNETWORKS_TENSOR_INT32,
+                                     /*scale=*/0.f, /*zero_point=*/0);
+  }
+
+  TfLiteStatus AddVectorInt32Operand(const int32_t* values, uint32_t num_values,
+                                     float scale, int32_t zero_point) {
+    return AddVectorOperand<int32_t>(
+        values, num_values, ANEURALNETWORKS_TENSOR_INT32, scale, zero_point);
   }
 
   TfLiteStatus AddVectorFloat32Operand(const float* values,
@@ -577,6 +570,70 @@
     return kTfLiteOk;
   }
 
+  template <typename T>
+  TfLiteStatus AddNewInputConstantTensor(
+      int32_t nn_type, TfLiteType type, const TfLiteIntArray* dims,
+      const std::vector<T>& tensor_value,
+      const TfLiteQuantizationParams& quant_params, int* tensor_index) {
+    TF_LITE_ENSURE_OK(context_,
+                      context_->AddTensors(context_, 1, tensor_index));
+
+    TfLiteTensor* new_tensor = &context_->tensors[*tensor_index];
+    new_tensor->type = type;
+    new_tensor->allocation_type = kTfLiteDynamic;
+    new_tensor->params = quant_params;
+
+    // Not removing the new tensor in case of resizing errors since it will
+    // be cleared by the context
+    TF_LITE_ENSURE_OK(
+        context_,
+        context_->ResizeTensor(
+            context_, new_tensor,
+            // Resize Tensor takes ownership of the dims array passed as param
+            TfLiteIntArrayCopy(dims)));
+
+    memcpy(new_tensor->data.raw,
+           reinterpret_cast<const char*>(tensor_value.data()),
+           tensor_value.size() * sizeof(T));
+
+    const uint32_t tensor_rank = static_cast<uint32_t>(dims->size);
+    const uint32_t* tensor_dims = reinterpret_cast<const uint32_t*>(dims->data);
+    ANeuralNetworksOperandType operand_type{nn_type, tensor_rank, tensor_dims,
+                                            quant_params.scale,
+                                            quant_params.zero_point};
+
+    const int ann_tensor_index =
+        operand_mapping_->add_delegate_generated_input_ann_tensors_operand();
+
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context_,
+        nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+
+    augmented_inputs_.push_back(ann_tensor_index);
+
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context_, nnapi_->ANeuralNetworksModel_setOperandValue(
+                      nn_model_, ann_tensor_index, new_tensor->data.raw,
+                      new_tensor->bytes));
+
+    return kTfLiteOk;
+  }
+
+  template <typename T>
+  TfLiteStatus AddNewInputConstantTensor(
+      int32_t nn_type, TfLiteType type, std::initializer_list<int> dims,
+      const std::vector<T>& tensor_value,
+      const TfLiteQuantizationParams& quant_params, int* tensor_index) {
+    TfLiteIntArray* dim_array = TfLiteIntArrayCreate(dims.size());
+    dim_array->size = dims.size();
+    std::copy(dims.begin(), dims.end(), dim_array->data);
+
+    const auto result = AddNewInputConstantTensor(
+        nn_type, type, dim_array, tensor_value, quant_params, tensor_index);
+    TfLiteIntArrayFree(dim_array);
+    return result;
+  }
+
  private:
   // Returns a TF Lite type which has the same memory representation as a
   // provided NN API type.
@@ -614,9 +671,13 @@
 
   template <typename T>
   TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values,
-                                int32_t nn_type) {
-    ANeuralNetworksOperandType operand_type{
-        .type = nn_type, .dimensionCount = 1, .dimensions = &num_values};
+                                int32_t nn_type, float scale,
+                                int32_t zero_point) {
+    ANeuralNetworksOperandType operand_type{.type = nn_type,
+                                            .dimensionCount = 1,
+                                            .dimensions = &num_values,
+                                            .scale = scale,
+                                            .zeroPoint = zero_point};
 
     RETURN_TFLITE_ERROR_IF_NN_ERROR(
         context_,
@@ -630,6 +691,13 @@
     return kTfLiteOk;
   }
 
+  template <typename T>
+  TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values,
+                                int32_t nn_type) {
+    return AddVectorOperand(values, num_values, nn_type, /*scale=*/0.f,
+                            /*zero_point=*/0);
+  }
+
   TfLiteStatus AddFloat32OutputTensor(uint32_t dimension_count,
                                       const uint32_t* dimension_data,
                                       int* ann_index_out) {
@@ -712,6 +780,11 @@
       case kTfLiteBool:
         nn_type = ANEURALNETWORKS_TENSOR_BOOL8;
         break;
+      case kTfLiteInt16:
+        nn_type = ANEURALNETWORKS_TENSOR_QUANT16_SYMM;
+        scale = tensor->params.scale;
+        zeroPoint = tensor->params.zero_point;
+        break;
       default:
         context_->ReportError(
             context_, "Failed to add NN API tensor: type %s is not supported.",
@@ -829,14 +902,6 @@
   std::vector<uint32_t> augmented_outputs_;
 };
 
-struct NNAPIOpMappingArgs {
-  TfLiteContext* context;
-  NNAPIOpBuilder* builder;
-  TfLiteNode* node;
-  std::vector<int>* model_state_outputs;
-  std::vector<int>* model_state_tfl_inputs;
-};
-
 // Mapping function simply returning the operation type without adding any
 // additional parameter.
 template <ANeuralNetworksOperationType OperationType>
@@ -845,193 +910,175 @@
   return OperationType;
 }
 
-// The kernel that represents the node sub set of TF Lite being run on NN API.
-class NNAPIDelegateKernel {
- public:
-  NNAPIDelegateKernel() { nnapi_ = NnApiImplementation(); }
-  ~NNAPIDelegateKernel() {
-    for (auto content : allocation_memory_mapping_) {
-      nnapi_->ANeuralNetworksMemory_free(content.second);
-    }
-  }
-
-  typedef ANeuralNetworksOperationType (*MappingFn)(
-      const NNAPIOpMappingArgs& mapping_args);
-
-  // Return a function that knows how to translate a node into its operands
-  // when called. You can use this function to see if a node is supported
-  // (i.e. if the returned MappingFn is null, then the node is not supported).
-  static MappingFn Map(const TfLiteContext* context, int builtin_code,
-                       int version, int android_sdk_version,
-                       const TfLiteNode* node) {
-    switch (builtin_code) {
-      case kTfLiteBuiltinAdd:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteAddParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_ADD;
-          };
+// Return a function that knows how to translate a node into its operands
+// when called. You can use this function to see if a node is supported
+// (i.e. if the returned MappingFn is null, then the node is not supported).
+NNAPIDelegateKernel::MappingFn NNAPIDelegateKernel::Map(
+    const TfLiteContext* context, int builtin_code, int version,
+    int android_sdk_version, const TfLiteNode* node,
+    bool is_accelerator_specified) {
+  switch (builtin_code) {
+    case kTfLiteBuiltinAdd:
+      if (version <= 2) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinArgMax:
-      case kTfLiteBuiltinArgMin:
-        if (version == 1) {
-          // Those operators were introduced in NNAPI 1.2.
-          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
-            return nullptr;
-          }
-          // Only certain input types are supported.
-          auto input_type = context->tensors[node->inputs->data[0]].type;
-          if (input_type != kTfLiteFloat16 && input_type != kTfLiteFloat32 &&
-              input_type != kTfLiteInt32 && input_type != kTfLiteUInt8) {
-            return nullptr;
-          }
-          // NNAPI only supports axis as int32. If the axis type is int64 and
-          // constant we can convert it to int32 if the value isn't too large.
-          const auto& axis_tensor = context->tensors[node->inputs->data[1]];
-          if (axis_tensor.type == kTfLiteInt64) {
-            if (axis_tensor.allocation_type != kTfLiteMmapRo ||
-                *axis_tensor.data.i64 > std::numeric_limits<int32_t>::max() ||
-                *axis_tensor.data.i64 < std::numeric_limits<int32_t>::min()) {
-              return nullptr;
-            }
-          } else if (axis_tensor.type != kTfLiteInt32) {
-            return nullptr;
-          }
-          if (builtin_code == kTfLiteBuiltinArgMax) {
-            // NNAPI only supports int32 output.
-            auto builtin =
-                reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
-            if (builtin->output_type != kTfLiteInt32) {
-              return nullptr;
-            }
-            return BasicMappingFn<ANEURALNETWORKS_ARGMAX>;
-          } else {
-            // NNAPI only supports int32 output.
-            auto builtin =
-                reinterpret_cast<TfLiteArgMinParams*>(node->builtin_data);
-            if (builtin->output_type != kTfLiteInt32) {
-              return nullptr;
-            }
-            return BasicMappingFn<ANEURALNETWORKS_ARGMIN>;
-          }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteAddParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_ADD;
+        };
+      }
+      break;
+    case kTfLiteBuiltinArgMax:
+    case kTfLiteBuiltinArgMin:
+      if (version <= 2) {
+        // Those operators were introduced in NNAPI 1.2.
+        if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinMul:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteMulParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_MUL;
-          };
+        // Only certain input types are supported.
+        auto input_type = context->tensors[node->inputs->data[0]].type;
+        if (input_type != kTfLiteFloat16 && input_type != kTfLiteFloat32 &&
+            input_type != kTfLiteInt32 && input_type != kTfLiteUInt8 &&
+            input_type != kTfLiteInt8) {
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinAveragePool2d:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
+        // NNAPI only supports axis as int32. If the axis type is int64 and
+        // constant we can convert it to int32 if the value isn't too large.
+        const auto& axis_tensor = context->tensors[node->inputs->data[1]];
+        if (axis_tensor.type == kTfLiteInt64) {
+          if (axis_tensor.allocation_type != kTfLiteMmapRo ||
+              *axis_tensor.data.i64 > std::numeric_limits<int32_t>::max() ||
+              *axis_tensor.data.i64 < std::numeric_limits<int32_t>::min()) {
             return nullptr;
           }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            mapping_args.builder->AddPoolingParams(
-                mapping_args.node->builtin_data);
-            return ANEURALNETWORKS_AVERAGE_POOL_2D;
-          };
+        } else if (axis_tensor.type != kTfLiteInt32) {
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinMaxPool2d:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            mapping_args.builder->AddPoolingParams(
-                mapping_args.node->builtin_data);
-            return ANEURALNETWORKS_MAX_POOL_2D;
-          };
-        }
-        break;
-      case kTfLiteBuiltinL2Pool2d:
-        if (version == 1) {
-          if (!IsFloatOperator(context, node)) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            mapping_args.builder->AddPoolingParams(
-                mapping_args.node->builtin_data);
-            return ANEURALNETWORKS_L2_POOL_2D;
-          };
-        }
-        break;
-      case kTfLiteBuiltinConv2d:
-        if (version <= 2) {
-          if ((android_sdk_version < kMinSdkVersionForNNAPI12) &&
-              (IsHybridOperator(context, builtin_code, node) ||
-               !IsFloatOrUint8Operator(context, node))) {
-            // Hybrid operators not supported before NNAPI 1.2.
-            return nullptr;
-          }
-          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
-            // Per-channel quantized convolution not supported before NNAPI 1.2.
-            const auto& filter_tensor = context->tensors[node->inputs->data[1]];
-            if (filter_tensor.quantization.type == kTfLiteAffineQuantization) {
-              TfLiteAffineQuantization* quantization_params =
-                  static_cast<TfLiteAffineQuantization*>(
-                      filter_tensor.quantization.params);
-              if (quantization_params->scale->size > 1) {
-                return nullptr;
-              }
-            }
-          }
-          const auto input_type = context->tensors[node->inputs->data[0]].type;
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              input_type == kTfLiteUInt8 &&
-              !IsRestrictedScalesCompliant(context, node)) {
-            return nullptr;
-          }
+        if (builtin_code == kTfLiteBuiltinArgMax) {
+          // NNAPI only supports int32 output.
           auto builtin =
-              reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-          if (node->inputs->size != 3) {
-            // TODO(b/132950584): Add support for Conv2D with omitted bias
+              reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
+          if (builtin->output_type != kTfLiteInt32) {
             return nullptr;
           }
-          // NNAPI supports dilated Conv2D since NNAPI 1.2.
-          if (builtin->dilation_width_factor != 1 ||
-              builtin->dilation_height_factor != 1) {
-            if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+          return BasicMappingFn<ANEURALNETWORKS_ARGMAX>;
+        } else {
+          // NNAPI only supports int32 output.
+          auto builtin =
+              reinterpret_cast<TfLiteArgMinParams*>(node->builtin_data);
+          if (builtin->output_type != kTfLiteInt32) {
+            return nullptr;
+          }
+          return BasicMappingFn<ANEURALNETWORKS_ARGMIN>;
+        }
+      }
+      break;
+    case kTfLiteBuiltinMul:
+      if (version <= 2) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteMulParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_MUL;
+        };
+      }
+      break;
+    case kTfLiteBuiltinAveragePool2d:
+      if (version <= 2) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        auto builtin = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+        // TODO(b/138756912): Large filter window would overflow on the
+        // reference CPU path.
+        if (!is_accelerator_specified &&
+            (builtin->filter_width * builtin->filter_height > 256)) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          mapping_args.builder->AddPoolingParams(
+              mapping_args.node->builtin_data);
+          return ANEURALNETWORKS_AVERAGE_POOL_2D;
+        };
+      }
+      break;
+    case kTfLiteBuiltinMaxPool2d:
+      if (version <= 2) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          mapping_args.builder->AddPoolingParams(
+              mapping_args.node->builtin_data);
+          return ANEURALNETWORKS_MAX_POOL_2D;
+        };
+      }
+      break;
+    case kTfLiteBuiltinL2Pool2d:
+      if (version == 1) {
+        if (!IsFloatOperator(context, node)) {
+          return nullptr;
+        }
+        auto builtin = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+        // Pre-Q devices may not support fused activation for l2_pool.
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            builtin->activation != kTfLiteActNone) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          mapping_args.builder->AddPoolingParams(
+              mapping_args.node->builtin_data);
+          return ANEURALNETWORKS_L2_POOL_2D;
+        };
+      }
+      break;
+    case kTfLiteBuiltinConv2d:
+      if (version <= 3) {
+        if ((android_sdk_version < kMinSdkVersionForNNAPI12) &&
+            (IsHybridOperator(context, builtin_code, node) ||
+             !IsFloatOrUint8Operator(context, node))) {
+          // Hybrid operators not supported before NNAPI 1.2.
+          return nullptr;
+        }
+        if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+          // Per-channel quantized convolution not supported before NNAPI 1.2.
+          const auto& filter_tensor = context->tensors[node->inputs->data[1]];
+          if (filter_tensor.quantization.type == kTfLiteAffineQuantization) {
+            TfLiteAffineQuantization* quantization_params =
+                static_cast<TfLiteAffineQuantization*>(
+                    filter_tensor.quantization.params);
+            if (quantization_params->scale->size > 1) {
               return nullptr;
             }
-            return [](const NNAPIOpMappingArgs& mapping_args)
-                       -> ANeuralNetworksOperationType {
-              auto builtin = reinterpret_cast<TfLiteConvParams*>(
-                  mapping_args.node->builtin_data);
-              mapping_args.builder->AddScalarInt32Operand(builtin->padding);
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->stride_width);
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->stride_height);
-              mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-              mapping_args.builder->AddScalarBoolOperand(
-                  false);  // Use NHWC format
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->dilation_width_factor);
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->dilation_height_factor);
-              return ANEURALNETWORKS_CONV_2D;
-            };
+          }
+        }
+        const auto input_type = context->tensors[node->inputs->data[0]].type;
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            input_type == kTfLiteUInt8 &&
+            !IsRestrictedScalesCompliant(context, node)) {
+          return nullptr;
+        }
+        auto builtin = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+        if (node->inputs->size != 3) {
+          // TODO(b/132950584): Add support for Conv2D with omitted bias
+          return nullptr;
+        }
+        // NNAPI supports dilated Conv2D since NNAPI 1.2.
+        if (builtin->dilation_width_factor != 1 ||
+            builtin->dilation_height_factor != 1) {
+          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+            return nullptr;
           }
           return [](const NNAPIOpMappingArgs& mapping_args)
                      -> ANeuralNetworksOperationType {
@@ -1041,1627 +1088,2138 @@
             mapping_args.builder->AddScalarInt32Operand(builtin->stride_width);
             mapping_args.builder->AddScalarInt32Operand(builtin->stride_height);
             mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+            mapping_args.builder->AddScalarBoolOperand(
+                false);  // Use NHWC format
+            mapping_args.builder->AddScalarInt32Operand(
+                builtin->dilation_width_factor);
+            mapping_args.builder->AddScalarInt32Operand(
+                builtin->dilation_height_factor);
             return ANEURALNETWORKS_CONV_2D;
           };
         }
-        break;
-      case kTfLiteBuiltinDepthwiseConv2d:
-        if (version == 1) {
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              !IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          const auto input_type = context->tensors[node->inputs->data[0]].type;
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              input_type == kTfLiteUInt8 &&
-              !IsRestrictedScalesCompliant(context, node)) {
-            return nullptr;
-          }
-          auto builtin =
-              reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              (builtin->dilation_width_factor != 1 ||
-               builtin->dilation_height_factor != 1)) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->padding);
-            mapping_args.builder->AddScalarInt32Operand(builtin->stride_width);
-            mapping_args.builder->AddScalarInt32Operand(builtin->stride_height);
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteConvParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->padding);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_width);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_height);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_CONV_2D;
+        };
+      }
+      break;
+    case kTfLiteBuiltinDepthwiseConv2d:
+      if (version <= 3) {
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            !IsFloatOrUint8Operator(context, node)) {
+          return nullptr;
+        }
+        const auto input_type = context->tensors[node->inputs->data[0]].type;
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            input_type == kTfLiteUInt8 &&
+            !IsRestrictedScalesCompliant(context, node)) {
+          return nullptr;
+        }
+        auto builtin =
+            reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            (builtin->dilation_width_factor != 1 ||
+             builtin->dilation_height_factor != 1)) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->padding);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_width);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_height);
+          mapping_args.builder->AddScalarInt32Operand(
+              builtin->depth_multiplier);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          if (builtin->dilation_width_factor != 1 ||
+              builtin->dilation_height_factor != 1) {
+            mapping_args.builder->AddScalarBoolOperand(
+                false);  // Use NHWC format
             mapping_args.builder->AddScalarInt32Operand(
-                builtin->depth_multiplier);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            if (builtin->dilation_width_factor != 1 ||
-                builtin->dilation_height_factor != 1) {
-              mapping_args.builder->AddScalarBoolOperand(
-                  false);  // Use NHWC format
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->dilation_width_factor);
-              mapping_args.builder->AddScalarInt32Operand(
-                  builtin->dilation_height_factor);
-            }
-            return ANEURALNETWORKS_DEPTHWISE_CONV_2D;
-          };
+                builtin->dilation_width_factor);
+            mapping_args.builder->AddScalarInt32Operand(
+                builtin->dilation_height_factor);
+          }
+          return ANEURALNETWORKS_DEPTHWISE_CONV_2D;
+        };
+      }
+      break;
+    case kTfLiteBuiltinFullyConnected:
+      if (version <= 4) {
+        if (node->inputs->size != 3 ||
+            node->inputs->data[2] == kOptionalTensor) {
+          // TODO(b/132950584): Add support for FullyConnected with no bias.
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinFullyConnected:
-        if (version == 1) {
-          if (node->inputs->size != 3 ||
-              node->inputs->data[2] == kOptionalTensor) {
-            // TODO(b/132950584): Add support for FullyConnected with no bias.
-            return nullptr;
-          }
-          const auto output_type =
-              context->tensors[node->outputs->data[0]].type;
-          if (output_type == kTfLiteInt16) {
-            return nullptr;
-          }
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              (IsHybridOperator(context, builtin_code, node) ||
-               !IsFloatOrUint8Operator(context, node))) {
-            // Hybrid operators not supported before NNAPI 1.2.
-            return nullptr;
-          }
-          const auto input_type = context->tensors[node->inputs->data[0]].type;
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              input_type == kTfLiteUInt8 &&
-              !IsRestrictedScalesCompliant(context, node)) {
-            return nullptr;
-          }
+        const auto output_type = context->tensors[node->outputs->data[0]].type;
+        if (output_type == kTfLiteInt16) {
+          return nullptr;
+        }
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            (IsHybridOperator(context, builtin_code, node) ||
+             !IsFloatOrUint8Operator(context, node))) {
+          // Hybrid operators not supported before NNAPI 1.2.
+          return nullptr;
+        }
+        const auto input_type = context->tensors[node->inputs->data[0]].type;
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            input_type == kTfLiteUInt8 &&
+            !IsRestrictedScalesCompliant(context, node)) {
+          return nullptr;
+        }
+        auto builtin =
+            reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+        if (builtin->keep_num_dims) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_FULLY_CONNECTED;
+        };
+      }
+      break;
+    case kTfLiteBuiltinSoftmax:
+      if (version <= 2) {
+        const auto& input = context->tensors[node->outputs->data[0]];
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        const int input_rank = input.dims->size;
+        if (input_rank > 4) return nullptr;
+        // Before API level 29 only 2D and 4D input tensors were supported.
+        if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+          if (input_rank != 2 && input_rank != 4) return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->beta);
+          // Optional scalar specifying the dimension the activation would be
+          // performed on is not added. Default to -1.
+          return ANEURALNETWORKS_SOFTMAX;
+        };
+      }
+      break;
+    case kTfLiteBuiltinReshape:
+      if (version == 1) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        // The shape input tensor must be constant.
+        if ((node->inputs->size < 2) ||
+            (context->tensors[node->inputs->data[1]].allocation_type !=
+             kTfLiteMmapRo)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_RESHAPE>;
+      }
+      break;
+    case kTfLiteBuiltinResizeBilinear:
+      if (version <= 2) {
+        const auto& input = context->tensors[node->inputs->data[0]];
+        const auto output_dims = context->tensors[node->outputs->data[0]].dims;
+        if (input.dims->size != 4) return nullptr;
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        // The size input tensor must be constant.
+        if ((node->inputs->size < 2) ||
+            (context->tensors[node->inputs->data[1]].allocation_type !=
+             kTfLiteMmapRo)) {
+          return nullptr;
+        }
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            output_dims->data[1] != output_dims->data[2]) {
+          // Require width == height due to driver differences in NNAPI < 1.2
+          return nullptr;
+        }
+        auto builtin =
+            reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+        if (builtin->align_corners) {
+          // NNAPI does not support align_corners == true.
+          return nullptr;
+        }
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            input.type != kTfLiteFloat32) {
+          // NNAPI 1.0 & 1.1 only supports float input.
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          const int output_id = mapping_args.node->outputs->data[0];
+          auto& output = mapping_args.context->tensors[output_id];
+          const int output_height = output.dims->data[1];
+          const int output_width = output.dims->data[2];
+          mapping_args.builder->AddScalarInt32Operand(output_width);
+          mapping_args.builder->AddScalarInt32Operand(output_height);
+          return ANEURALNETWORKS_RESIZE_BILINEAR;
+        };
+      }
+      break;
+    case kTfLiteBuiltinResizeNearestNeighbor: {
+      if (version > 2 || android_sdk_version < kMinSdkVersionForNNAPI12) {
+        return nullptr;
+      }
+      if (!IsFloatOrQuant8Operator(context, node)) {
+        return nullptr;
+      }
+      auto builtin = reinterpret_cast<TfLiteResizeNearestNeighborParams*>(
+          node->builtin_data);
+      if (builtin->align_corners) {
+        // NNAPI does not support align_corners == true.
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        const TfLiteTensor& new_shape =
+            mapping_args.context->tensors[mapping_args.node->inputs->data[1]];
+        // NNAPI uses scalar inputs for height and width.
+        mapping_args.builder->AddScalarInt32Operand(new_shape.data.i32[1]);
+        mapping_args.builder->AddScalarInt32Operand(new_shape.data.i32[0]);
+        mapping_args.builder->AddScalarBoolOperand(false);  // Use NHWC format
+
+        return ANEURALNETWORKS_RESIZE_NEAREST_NEIGHBOR;
+      };
+    } break;
+    case kTfLiteBuiltinSqueeze:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
+        auto builtin =
+            reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
+        if (android_sdk_version == kMinSdkVersionForNNAPI11 &&
+            builtin->num_squeeze_dims == 0) {
+          // NNAPI 1.1 does not support null squeeze_dims properly.
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteSqueezeParams*>(
+              mapping_args.node->builtin_data);
+          // Note that we add the squeeze dimensions even if the dimensions
+          // were unspecified (empty), as NNAPI requires the operand.
+          mapping_args.builder->AddVectorInt32Operand(
+              builtin->num_squeeze_dims ? builtin->squeeze_dims : nullptr,
+              static_cast<uint32_t>(builtin->num_squeeze_dims));
+          return ANEURALNETWORKS_SQUEEZE;
+        };
+      }
+      break;
+    case kTfLiteBuiltinUnidirectionalSequenceLstm:
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        if (IsHybridOperator(context, builtin_code, node)) {
+          // Hybrid version of this op is not supported by NN API.
+          return nullptr;
+        }
+        if (node->inputs->size != 20 && node->inputs->size != 24) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
           auto builtin =
-              reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-          if (builtin->keep_num_dims) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_FULLY_CONNECTED;
-          };
-        }
-        break;
-      case kTfLiteBuiltinSoftmax:
-        if (version == 1) {
-          const auto& input = context->tensors[node->outputs->data[0]];
-          if (input.type != kTfLiteFloat32 && input.type != kTfLiteUInt8) {
-            return nullptr;
-          }
-          const int input_rank = input.dims->size;
-          if (input_rank > 4) return nullptr;
-          // Before API level 29 only 2D and 4D input tensors were supported.
-          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
-            if (input_rank != 2 && input_rank != 4) return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->beta);
-            // Optional scalar specifying the dimension the activation would be
-            // performed on is not added. Default to -1.
-            return ANEURALNETWORKS_SOFTMAX;
-          };
-        }
-        break;
-      case kTfLiteBuiltinReshape:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          // The shape input tensor must be constant.
-          if ((node->inputs->size < 2) ||
-              (context->tensors[node->inputs->data[1]].allocation_type !=
-               kTfLiteMmapRo)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_RESHAPE>;
-        }
-        break;
-      case kTfLiteBuiltinResizeBilinear:
-        if (version == 1) {
-          const auto& input = context->tensors[node->inputs->data[0]];
-          const auto output_dims =
-              context->tensors[node->outputs->data[0]].dims;
-          if (input.dims->size != 4) return nullptr;
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          // The size input tensor must be constant.
-          if ((node->inputs->size < 2) ||
-              (context->tensors[node->inputs->data[1]].allocation_type !=
-               kTfLiteMmapRo)) {
-            return nullptr;
-          }
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              output_dims->data[1] != output_dims->data[2]) {
-            // Require width == height due to driver differences in NNAPI < 1.2
-            return nullptr;
-          }
-          auto builtin =
-              reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
-          if (builtin->align_corners) {
-            // NNAPI does not support align_corners == true.
-            return nullptr;
-          }
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              input.type != kTfLiteFloat32) {
-            // NNAPI 1.0 & 1.1 only supports float input.
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            const int output_id = mapping_args.node->outputs->data[0];
-            auto& output = mapping_args.context->tensors[output_id];
-            const int output_height = output.dims->data[1];
-            const int output_width = output.dims->data[2];
-            mapping_args.builder->AddScalarInt32Operand(output_width);
-            mapping_args.builder->AddScalarInt32Operand(output_height);
-            return ANEURALNETWORKS_RESIZE_BILINEAR;
-          };
-        }
-        break;
-      case kTfLiteBuiltinSqueeze:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteSqueezeParams*>(
-                mapping_args.node->builtin_data);
-            // Note that we add the squeeze dimensions even if the dimensions
-            // were unspecified (empty), as NNAPI requires the operand.
-            mapping_args.builder->AddVectorInt32Operand(
-                builtin->num_squeeze_dims ? builtin->squeeze_dims : nullptr,
-                static_cast<uint32_t>(builtin->num_squeeze_dims));
-            return ANEURALNETWORKS_SQUEEZE;
-          };
-        }
-        break;
-      case kTfLiteBuiltinUnidirectionalSequenceLstm:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          if (IsHybridOperator(context, builtin_code, node)) {
-            // Hybrid version of this op is not supported by NN API.
-            return nullptr;
-          }
-          if (node->inputs->size != 20 && node->inputs->size != 24) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin =
-                reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
-                    mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
-            mapping_args.builder->AddScalarBoolOperand(builtin->time_major);
-            const bool hybrid_op = IsHybridOperator(
-                mapping_args.context, kTfLiteBuiltinUnidirectionalSequenceLstm,
-                mapping_args.node);
-            if (mapping_args.node->inputs->size == 24) {
-              // Add layer normalization tensors if they are provided.
-              for (int i = 20; i < 24; ++i) {
-                const int input_index = mapping_args.node->inputs->data[i];
-                if (input_index != kOptionalTensor) {
-                  mapping_args.builder->AddTensorInput(input_index, hybrid_op);
-                } else {
-                  mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
-                }
-              }
-            } else {
-              for (int i = 0; i < 4; ++i) {
+              reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+                  mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
+          mapping_args.builder->AddScalarBoolOperand(builtin->time_major);
+          const bool hybrid_op = IsHybridOperator(
+              mapping_args.context, kTfLiteBuiltinUnidirectionalSequenceLstm,
+              mapping_args.node);
+          if (mapping_args.node->inputs->size == 24) {
+            // Add layer normalization tensors if they are provided.
+            for (int i = 20; i < 24; ++i) {
+              const int input_index = mapping_args.node->inputs->data[i];
+              if (input_index != kOptionalTensor) {
+                mapping_args.builder->AddTensorInput(input_index, hybrid_op);
+              } else {
                 mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
               }
             }
-
-            return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM;
-          };
-        }
-        break;
-      case kTfLiteBuiltinL2Normalization: {
-        if (version == 1) {
-          const auto& input = context->tensors[node->inputs->data[0]];
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              (!IsFloatOperator(context, node) || input.dims->size != 4)) {
-            return nullptr;
-          }
-          auto builtin =
-              reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
-          if (builtin->activation == kTfLiteActNone) {
-            return BasicMappingFn<ANEURALNETWORKS_L2_NORMALIZATION>;
-          }
-        }
-        break;
-      }
-      case kTfLiteBuiltinLocalResponseNormalization:
-        if (version == 1) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteLocalResponseNormParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->radius);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->bias);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->alpha);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->beta);
-            return ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION;
-          };
-        }
-        break;
-      case kTfLiteBuiltinLshProjection:
-        if (version == 1) {
-          // NNAPI does not support sparse projection correctly (b/111751836).
-          if (reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data)
-                  ->type == kTfLiteLshProjectionSparse) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteLSHProjectionParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->type);
-            return ANEURALNETWORKS_LSH_PROJECTION;
-          };
-        }
-        break;
-      case kTfLiteBuiltinConcatenation:
-        if (version == 1 &&
-            reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data)
-                    ->activation == kTfLiteActNone) {
-          if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8 &&
-              android_sdk_version < kMinSdkVersionForNNAPI12) {
-            // NNAPI 1.0-1 only supported concatenating quantized tensor of the
-            // same scale and offset.
-            auto first_param = context->tensors[node->inputs->data[0]].params;
-            for (int i = 1; i < node->inputs->size; i++) {
-              auto curr_param = context->tensors[node->inputs->data[i]].params;
-              if (curr_param.scale != first_param.scale ||
-                  curr_param.zero_point != first_param.zero_point) {
-                return nullptr;
-              }
-            }
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->axis);
-            return ANEURALNETWORKS_CONCATENATION;
-          };
-        }
-        break;
-      case kTfLiteBuiltinDequantize:
-        if (version == 1 || version == 2) {
-          const auto& input = context->tensors[node->inputs->data[0]];
-          if (input.type == kTfLiteFloat16) {
-            return nullptr;
-          }
-          const auto zero_point = input.params.zero_point;
-          // NN API supports int8 type since version 1.2 but only for symmetric
-          // quantization.
-          if (input.type == kTfLiteInt8 &&
-              (zero_point != 0 ||
-               android_sdk_version < kMinSdkVersionForNNAPI12)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_DEQUANTIZE>;
-        }
-        break;
-      case kTfLiteBuiltinFloor:
-        if (version == 1) {
-          return BasicMappingFn<ANEURALNETWORKS_FLOOR>;
-        }
-        break;
-      case kTfLiteBuiltinRelu:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_RELU>;
-        }
-        break;
-      case kTfLiteBuiltinReluN1To1:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_RELU1>;
-        }
-        break;
-      case kTfLiteBuiltinRelu6:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_RELU6>;
-        }
-        break;
-      case kTfLiteBuiltinLogistic:
-        if (version == 1) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_LOGISTIC>;
-        }
-        break;
-      case kTfLiteBuiltinTanh:
-        // TODO(miaowang): add additional checks for the parameters.
-        if (version == 1) {
-          const TfLiteType input_type =
-              context->tensors[node->inputs->data[0]].type;
-          if (IsFloat(input_type) ||
-              (IsQuantized(input_type) &&
-               android_sdk_version >= kMinSdkVersionForNNAPI12)) {
-            // NNAPI only support float tanh.
-            return BasicMappingFn<ANEURALNETWORKS_TANH>;
-          }
-        }
-        break;
-      case kTfLiteBuiltinSub:
-        if (version == 1) {
-          const TfLiteType input_type =
-              context->tensors[node->inputs->data[0]].type;
-          if ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-               IsFloat(input_type)) ||
-              (android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-               IsQuantized(input_type))) {
-            // NNAPI only support float sub.
-            return [](const NNAPIOpMappingArgs& mapping_args)
-                       -> ANeuralNetworksOperationType {
-              auto builtin = reinterpret_cast<TfLiteSubParams*>(
-                  mapping_args.node->builtin_data);
-              mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-              return ANEURALNETWORKS_SUB;
-            };
-          }
-        }
-        break;
-      case kTfLiteBuiltinDiv:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-            context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
-          // NNAPI only support float div.
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteDivParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_DIV;
-          };
-        }
-        break;
-      case kTfLiteBuiltinPad:
-      case kTfLiteBuiltinPadv2: {
-        const TfLiteType input_type =
-            context->tensors[node->inputs->data[0]].type;
-        if (version == 1 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8)) {
-          const TfLiteIntArrayView input_shape(
-              context->tensors[node->inputs->data[0]].dims);
-          if (HasZeroes(input_shape)) {
-            // NN API pad ops do not support input tensors with no elements
-            return nullptr;
-          }
-          if (node->inputs->size == 2 &&
-              android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-              (context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 ||
-               android_sdk_version >= kMinSdkVersionForNNAPI12)) {
-            // NNAPI does not support specifying the padding value.
-            // Before 1.2, NNAPI pads physical zero for quantized tensors, so
-            // only delegate float pad to NNAPI. NNAPI 1.2 onwards pads with
-            // zero-point, so delegate quantized pad as well.
-            return BasicMappingFn<ANEURALNETWORKS_PAD>;
-          } else if (node->inputs->size == 3 &&
-                     android_sdk_version >= kMinSdkVersionForNNAPI12) {
-            const int constant_value_id = node->inputs->data[2];
-            if (constant_value_id == kOptionalTensor) {
-              return BasicMappingFn<ANEURALNETWORKS_PAD>;
-            }
-            return BasicMappingFn<ANEURALNETWORKS_PAD_V2>;
-          }
-        }
-      } break;
-      case kTfLiteBuiltinUnidirectionalSequenceRnn:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          if (IsHybridOperator(context, builtin_code, node)) {
-            // Hybrid version of this op is not supported by NN API.
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteSequenceRNNParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            mapping_args.builder->AddScalarInt32Operand(builtin->time_major);
-            return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN;
-          };
-        }
-        break;
-      case kTfLiteBuiltinSpaceToBatchNd:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
-          return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;
-        }
-        break;
-      case kTfLiteBuiltinStridedSlice:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteStridedSliceParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->begin_mask);
-            mapping_args.builder->AddScalarInt32Operand(builtin->end_mask);
-            mapping_args.builder->AddScalarInt32Operand(
-                builtin->shrink_axis_mask);
-            return ANEURALNETWORKS_STRIDED_SLICE;
-          };
-        }
-        break;
-      case kTfLiteBuiltinTranspose:
-        // Note that the permutation input tensor value dictates the output
-        // dimensions.
-        // TODO(b/110888333): Support dynamically-sized tensors in delegates.
-        if ((version == 1) &&
-            (android_sdk_version >= kMinSdkVersionForNNAPI11) &&
-            (node->inputs->size > 1) &&
-            (context->tensors[node->inputs->data[1]].allocation_type ==
-             kTfLiteMmapRo)) {
-          return BasicMappingFn<ANEURALNETWORKS_TRANSPOSE>;
-        }
-        break;
-      case kTfLiteBuiltinAbs:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_ABS>;
-        }
-        break;
-      case kTfLiteBuiltinExp:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_EXP>;
-        }
-        break;
-      case kTfLiteBuiltinLog:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_LOG>;
-        }
-        break;
-      case kTfLiteBuiltinRsqrt:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_RSQRT>;
-        }
-        break;
-      case kTfLiteBuiltinPow:
-        // NN API only supports float inputs to this op.
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
-          return BasicMappingFn<ANEURALNETWORKS_POW>;
-        }
-        break;
-      case kTfLiteBuiltinSlice: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        const auto begin_type = context->tensors[node->inputs->data[1]].type;
-        const auto size_type = context->tensors[node->inputs->data[2]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteInt32 ||
-             input_type == kTfLiteUInt8) &&
-            begin_type == kTfLiteInt32 && size_type == kTfLiteInt32) {
-          return BasicMappingFn<ANEURALNETWORKS_SLICE>;
-        }
-      } break;
-      case kTfLiteBuiltinSin:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_SIN>;
-        }
-        break;
-      case kTfLiteBuiltinSqrt:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          return BasicMappingFn<ANEURALNETWORKS_SQRT>;
-        }
-        break;
-      case kTfLiteBuiltinRnn:
-        // NNAPI only support float32 weights.
-        if (version == 1 && node->inputs->size == 5 &&
-            context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type ==
-                kTfLiteFloat32) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            // NNAPI need both state_in and state_out.
-            int ann_index;
-            mapping_args.builder->AddStateFloat32Tensor(
-                mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4],
-                &ann_index);
-            mapping_args.model_state_outputs->push_back(ann_index);
-            mapping_args.model_state_tfl_inputs->push_back(
-                mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]);
-            auto builtin = reinterpret_cast<TfLiteRNNParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_RNN;
-          };
-        }
-        break;
-      case kTfLiteBuiltinSpaceToDepth: {
-        const TfLiteType input_type =
-            context->tensors[node->inputs->data[0]].type;
-        if (version == 1 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8)) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteSpaceToDepthParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->block_size);
-            return ANEURALNETWORKS_SPACE_TO_DEPTH;
-          };
-        }
-      } break;
-      case kTfLiteBuiltinSvdf:
-        // NNAPI only support float32 weights.
-        // Only delegate to NNAPI 1.1, as SVDF does not support rank > 1 on 1.0.
-        if (version == 1 && node->inputs->size == 5 &&
-            android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-            context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
-                    .type == kTfLiteFloat32) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            // NNAPI need both state_in and state_out.
-            int ann_index;
-            mapping_args.builder->AddStateFloat32Tensor(
-                mapping_args.node->inputs
-                    ->data[/*kInputActivationStateTensor*/ 4],
-                &ann_index);
-            mapping_args.model_state_outputs->push_back(ann_index);
-            mapping_args.model_state_tfl_inputs->push_back(
-                mapping_args.node->inputs
-                    ->data[/*kInputActivationStateTensor*/ 4]);
-
-            auto builtin = reinterpret_cast<TfLiteSVDFParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->rank);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            return ANEURALNETWORKS_SVDF;
-          };
-        }
-        break;
-      case kTfLiteBuiltinLstm:
-        // TODO(miaowang): add loggings to indicate why the op is rejected.
-        if (version == 1) {
-          if (android_sdk_version < kMinSdkVersionForNNAPI11) {
-            // Only delegate to NNAPI 1.1+, as 1.0 has a bug for optional
-            // tensors which would affect LSTM.
-            return nullptr;
-          }
-          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              IsHybridOperator(context, builtin_code, node)) {
-            // Hybrid operators not supported before NNAPI 1.2.
-            return nullptr;
-          }
-          // TODO(levp): name the constants for number of inputs in LSTM kernel.
-          if (node->inputs->size != 20 && node->inputs->size != 24) {
-            return nullptr;
-          }
-          if (node->inputs->size == 24 &&
-              android_sdk_version < kMinSdkVersionForNNAPI12) {
-            // LSTM with layer norm introduced in API level 29
-            return nullptr;
-          }
-          const TfLiteType weight_type =
-              context
-                  ->tensors[node->inputs
-                                ->data[/*kInputToOutputWeightsTensor*/ 4]]
-                  .type;
-          if (weight_type != kTfLiteFloat32 && weight_type != kTfLiteUInt8) {
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteLSTMParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
-
-            // Current NNAPI implementation requires the scratch_buffer as
-            // output.
-            mapping_args.builder->AddAdditionalFloat32OutputTensor(2);
-
-            // NNAPI need both state_in and state_out for cell_state and
-            // output_state.
-            int ann_index;
-            mapping_args.builder->AddStateFloat32Tensor(
-                mapping_args.node->inputs
-                    ->data[/*kInputActivationStateTensor*/ 18],
-                &ann_index);
-            mapping_args.model_state_outputs->push_back(ann_index);
-            mapping_args.model_state_tfl_inputs->push_back(
-                mapping_args.node->inputs
-                    ->data[/*kInputActivationStateTensor*/ 18]);
-            mapping_args.builder->AddStateFloat32Tensor(
-                mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19],
-                &ann_index);
-            mapping_args.model_state_outputs->push_back(ann_index);
-            mapping_args.model_state_tfl_inputs->push_back(
-                mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]);
-
-            const bool hybrid_op = IsHybridOperator(
-                mapping_args.context, kTfLiteBuiltinLstm, mapping_args.node);
-
-            if (mapping_args.node->inputs->size == 24) {
-              for (int i = 20; i < 24; ++i) {
-                const auto input_index = mapping_args.node->inputs->data[i];
-                if (input_index != kOptionalTensor) {
-                  mapping_args.builder->AddTensorInput(input_index, hybrid_op);
-                } else {
-                  mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
-                }
-              }
-            }
-
-            return ANEURALNETWORKS_LSTM;
-          };
-        }
-        break;
-      case kTfLiteBuiltinMean:
-        // NNAPI does not support generating a scalar as output for MEAN.
-        if (version == 1 &&
-            ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-              context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) ||
-             (android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-              context->tensors[node->inputs->data[0]].type == kTfLiteUInt8)) &&
-            context->tensors[node->outputs->data[0]].dims->size > 0) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteReducerParams*>(
-                mapping_args.node->builtin_data);
-            int32_t keep_dims = 0;
-            if (builtin->keep_dims) keep_dims = 1;
-            mapping_args.builder->AddScalarInt32Operand(keep_dims);
-            return ANEURALNETWORKS_MEAN;
-          };
-        }
-        break;
-      case kTfLiteBuiltinEmbeddingLookup:
-        // NNAPI only support float32 values.
-        if (version == 1 &&
-            context->tensors[node->inputs->data[1]].type == kTfLiteFloat32) {
-          return BasicMappingFn<ANEURALNETWORKS_EMBEDDING_LOOKUP>;
-        }
-        break;
-      case kTfLiteBuiltinHashtableLookup:
-        // NNAPI only support float32 output.
-        if (version == 1 &&
-            context->tensors[node->outputs->data[0]].type == kTfLiteFloat32) {
-          return BasicMappingFn<ANEURALNETWORKS_HASHTABLE_LOOKUP>;
-        }
-        break;
-      case kTfLiteBuiltinMaximum: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_MAXIMUM>;
-        }
-      } break;
-      case kTfLiteBuiltinMinimum: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
-        }
-      } break;
-      case kTfLiteBuiltinCast: {
-        const TfLiteType input_type =
-            context->tensors[node->inputs->data[0]].type;
-        const TfLiteType output_type =
-            context->tensors[node->outputs->data[0]].type;
-        auto is_supported_tensor_type = [](const TfLiteType& type) {
-          return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
-                  type == kTfLiteUInt8);
-        };
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            is_supported_tensor_type(input_type) &&
-            is_supported_tensor_type(output_type)) {
-          return BasicMappingFn<ANEURALNETWORKS_CAST>;
-        }
-      } break;
-      case kTfLiteBuiltinPrelu:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          if (!IsFloatOrUint8Operator(context, node)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_PRELU>;
-        }
-        break;
-      case kTfLiteBuiltinTile: {
-        // NN API doesn't support int64 and boolean inputs to this op
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        const auto multipliers_type =
-            context->tensors[node->inputs->data[1]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteInt32) &&
-            (multipliers_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_TILE>;
-        }
-      } break;
-      case kTfLiteBuiltinLogicalOr: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            input_type == kTfLiteBool) {
-          return BasicMappingFn<ANEURALNETWORKS_LOGICAL_OR>;
-        }
-      } break;
-      case kTfLiteBuiltinLogicalAnd: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            input_type == kTfLiteBool) {
-          return BasicMappingFn<ANEURALNETWORKS_LOGICAL_AND>;
-        }
-      } break;
-      case kTfLiteBuiltinLogicalNot: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            input_type == kTfLiteBool) {
-          return BasicMappingFn<ANEURALNETWORKS_LOGICAL_NOT>;
-        }
-      } break;
-      case kTfLiteBuiltinLess: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_LESS>;
-        }
-      } break;
-      case kTfLiteBuiltinLessEqual: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_LESS_EQUAL>;
-        }
-      } break;
-      case kTfLiteBuiltinGreater: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_GREATER>;
-        }
-      } break;
-      case kTfLiteBuiltinGreaterEqual: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_GREATER_EQUAL>;
-        }
-      } break;
-      case kTfLiteBuiltinEqual: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_EQUAL>;
-        }
-      } break;
-      case kTfLiteBuiltinNotEqual: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
-             input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_NOT_EQUAL>;
-        }
-      } break;
-      case kTfLiteBuiltinNeg: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat32 || input_type == kTfLiteInt32)) {
-          return BasicMappingFn<ANEURALNETWORKS_NEG>;
-        }
-      } break;
-      case kTfLiteBuiltinTopkV2: {
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          const auto& input = context->tensors[node->outputs->data[0]];
-          const auto& k_param = context->tensors[node->outputs->data[1]];
-          if ((input.type == kTfLiteFloat32 || input.type == kTfLiteInt32 ||
-               input.type == kTfLiteInt8) &&
-              (k_param.type == kTfLiteInt32 &&
-               k_param.allocation_type == kTfLiteMmapRo)) {
-            return [](const NNAPIOpMappingArgs& mapping_args)
-                       -> ANeuralNetworksOperationType {
-              const TfLiteTensor& k_param =
-                  mapping_args.context
-                      ->tensors[mapping_args.node->inputs->data[1]];
-              mapping_args.builder->AddScalarInt32Operand(*k_param.data.i32);
-              return ANEURALNETWORKS_TOPK_V2;
-            };
           } else {
-            return nullptr;
-          }
-        }
-      } break;
-      case kTfLiteBuiltinSelect: {
-        const auto value_type = context->tensors[node->inputs->data[1]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (value_type == kTfLiteFloat32 || value_type == kTfLiteUInt8 ||
-             value_type == kTfLiteInt32)) {
-          TfLiteIntArray* condition_shape =
-              context->tensors[node->inputs->data[0]].dims;
-          TfLiteIntArray* input_shape =
-              context->tensors[node->inputs->data[1]].dims;
-          // The Android Q-variant of select does not support broadcasting.
-          if (!TfLiteIntArrayEqual(condition_shape, input_shape)) {
-            return nullptr;
-          }
-          return BasicMappingFn<ANEURALNETWORKS_SELECT>;
-        }
-      } break;
-      case kTfLiteBuiltinGather: {
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          const auto& input = context->tensors[node->inputs->data[0]];
-          const auto& positions = context->tensors[node->inputs->data[1]];
-
-          auto is_supported_input_type = [](const TfLiteTensor& t) {
-            return (t.type == kTfLiteFloat32 || t.type == kTfLiteFloat16 ||
-                    t.type == kTfLiteInt32 || t.type == kTfLiteUInt8);
-          };
-
-          if (!is_supported_input_type(input) ||
-              !is_supported_input_type(positions)) {
-            return nullptr;
-          }
-
-          // 0-dimension args are not supported by NNAPI.
-          if (positions.dims->size == 0) {
-            return nullptr;
-          }
-
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin = reinterpret_cast<TfLiteGatherParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddTensorInput(
-                mapping_args.node->inputs->data[0],
-                /* hybrid_op */ false,
-                /* scalar_as_tensor */ false);
-
-            mapping_args.builder->AddScalarInt32Operand(builtin->axis);
-
-            mapping_args.builder->AddTensorInput(
-                mapping_args.node->inputs->data[1],
-                /* hybrid_op */ false,
-                /* scalar_as_tensor */ false);
-
-            return ANEURALNETWORKS_GATHER;
-          };
-        }
-      } break;
-      case kTfLiteBuiltinBidirectionalSequenceLstm:
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
-          if (IsHybridOperator(context, builtin_code, node)) {
-            // Hybrid version of this op is not supported by NN API.
-            return nullptr;
-          }
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            auto builtin =
-                reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
-                    mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
-            mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
-            mapping_args.builder->AddScalarBoolOperand(builtin->merge_outputs);
-            mapping_args.builder->AddScalarBoolOperand(builtin->time_major);
-            // TF Lite doesn't support layer normalization in bidirectional
-            // sequence LSTM, so we insert optional tensors for NNAPI
-            for (int i = 0; i < 8; ++i) {
+            for (int i = 0; i < 4; ++i) {
               mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
             }
-            return ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM;
-          };
+          }
+
+          return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM;
+        };
+      }
+      break;
+    case kTfLiteBuiltinL2Normalization: {
+      if (version <= 2) {
+        const auto& input = context->tensors[node->inputs->data[0]];
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            (!IsFloatOperator(context, node) || input.dims->size != 4)) {
+          return nullptr;
         }
-        break;
-      case kTfLiteBuiltinExpandDims: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        const auto axis = context->tensors[node->inputs->data[1]];
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input_type == kTfLiteFloat16 || input_type == kTfLiteFloat32 ||
-             input_type == kTfLiteInt32 || input_type == kTfLiteUInt8) &&
-            // TFLite supports axis also as int64 but NNAPI only int32
-            (axis.type == kTfLiteInt32 &&
-             axis.allocation_type == kTfLiteMmapRo)) {
+        auto builtin =
+            reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
+        if (builtin->activation == kTfLiteActNone) {
+          return BasicMappingFn<ANEURALNETWORKS_L2_NORMALIZATION>;
+        }
+      }
+      break;
+    }
+    case kTfLiteBuiltinLocalResponseNormalization:
+      if (version == 1) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteLocalResponseNormParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->radius);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->bias);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->alpha);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->beta);
+          return ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION;
+        };
+      }
+      break;
+    case kTfLiteBuiltinLshProjection:
+      if (version == 1) {
+        if (reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data)
+                ->type == kTfLiteLshProjectionSparse) {
+          // NNAPI does not support sparse projection correctly pre-Q
+          // (b/111751836).
+          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+            return nullptr;
+          }
+          // NNAPI does not support weights for sparse projects.
+          if (node->inputs->size != 2) {
+            return nullptr;
+          }
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteLSHProjectionParams*>(
+              mapping_args.node->builtin_data);
+          int type = builtin->type;
+          // In Android Q+, NNAPI uses 3 to denote
+          // kTfLiteLshProjectionSparse.
+          const int kNNAPILshProjectionSparse = 3;
+          if (builtin->type == kTfLiteLshProjectionSparse) {
+            type = kNNAPILshProjectionSparse;
+            // Add NNAPI null weight operand.
+            mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
+          }
+          mapping_args.builder->AddScalarInt32Operand(type);
+          return ANEURALNETWORKS_LSH_PROJECTION;
+        };
+      }
+      break;
+    case kTfLiteBuiltinConcatenation:
+      if (version <= 2 &&
+          reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data)
+                  ->activation == kTfLiteActNone &&
+          context->tensors[node->inputs->data[0]].dims->size <= 4) {
+        if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8 &&
+            android_sdk_version < kMinSdkVersionForNNAPI12) {
+          // NNAPI 1.0-1 only supported concatenating quantized tensor of
+          // the same scale and offset.
+          auto first_param = context->tensors[node->inputs->data[0]].params;
+          for (int i = 1; i < node->inputs->size; i++) {
+            auto curr_param = context->tensors[node->inputs->data[i]].params;
+            if (curr_param.scale != first_param.scale ||
+                curr_param.zero_point != first_param.zero_point) {
+              return nullptr;
+            }
+          }
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(
+              mapping_args.node->builtin_data);
+          int axis = builtin->axis < 0
+                         ? mapping_args.context
+                                   ->tensors[mapping_args.node->inputs->data[0]]
+                                   .dims->size +
+                               builtin->axis
+                         : builtin->axis;
+          mapping_args.builder->AddScalarInt32Operand(axis);
+          return ANEURALNETWORKS_CONCATENATION;
+        };
+      }
+      break;
+    case kTfLiteBuiltinDequantize:
+      if (version == 1 || version == 2) {
+        const auto& input = context->tensors[node->inputs->data[0]];
+        if (input.type == kTfLiteFloat16) {
+          return nullptr;
+        }
+        const auto zero_point = input.params.zero_point;
+        // NN API supports int8 type since version 1.2 but only for
+        // symmetric quantization.
+        if (input.type == kTfLiteInt8 &&
+            (zero_point != 0 ||
+             android_sdk_version < kMinSdkVersionForNNAPI12)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_DEQUANTIZE>;
+      }
+      break;
+    case kTfLiteBuiltinFloor:
+      if (version == 1) {
+        return BasicMappingFn<ANEURALNETWORKS_FLOOR>;
+      }
+      break;
+    case kTfLiteBuiltinRelu:
+      if (version == 1) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_RELU>;
+      }
+      break;
+    case kTfLiteBuiltinReluN1To1:
+      if (version == 1) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_RELU1>;
+      }
+      break;
+    case kTfLiteBuiltinRelu6:
+      if (version == 1) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_RELU6>;
+      }
+      break;
+    case kTfLiteBuiltinLogistic:
+      if (version <= 2) {
+        if (!IsFloatOrQuant8Operator(context, node)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_LOGISTIC>;
+      }
+      break;
+    case kTfLiteBuiltinTanh:
+      if (version <= 2) {
+        const TfLiteType input_type =
+            context->tensors[node->inputs->data[0]].type;
+        if (IsFloat(input_type) ||
+            (IsQuantized(input_type) &&
+             android_sdk_version >= kMinSdkVersionForNNAPI12)) {
+          // NNAPI only support float tanh.
+          return BasicMappingFn<ANEURALNETWORKS_TANH>;
+        }
+      }
+      break;
+    case kTfLiteBuiltinSub:
+      if (version <= 2) {
+        const TfLiteType input_type =
+            context->tensors[node->inputs->data[0]].type;
+        if ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+             IsFloat(input_type)) ||
+            (android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+             IsQuantized(input_type))) {
+          // NNAPI only support float sub.
           return [](const NNAPIOpMappingArgs& mapping_args)
                      -> ANeuralNetworksOperationType {
-            const TfLiteTensor& axis_param =
+            auto builtin = reinterpret_cast<TfLiteSubParams*>(
+                mapping_args.node->builtin_data);
+            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+            return ANEURALNETWORKS_SUB;
+          };
+        }
+      }
+      break;
+    case kTfLiteBuiltinDiv:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+          context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
+        // NNAPI only support float div.
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteDivParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_DIV;
+        };
+      }
+      break;
+    case kTfLiteBuiltinPad:
+    case kTfLiteBuiltinPadv2: {
+      if (version <= 2 && IsFloatOrQuant8Operator(context, node)) {
+        const TfLiteIntArrayView input_shape(
+            context->tensors[node->inputs->data[0]].dims);
+        if (HasZeroes(input_shape)) {
+          // NN API pad ops do not support input tensors with no elements
+          return nullptr;
+        }
+        if (node->inputs->size == 2 &&
+            android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+            (context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 ||
+             android_sdk_version >= kMinSdkVersionForNNAPI12)) {
+          // NNAPI does not support specifying the padding value.
+          // Before 1.2, NNAPI pads physical zero for quantized tensors, so
+          // only delegate float pad to NNAPI. NNAPI 1.2 onwards pads with
+          // zero-point, so delegate quantized pad as well.
+          return BasicMappingFn<ANEURALNETWORKS_PAD>;
+        } else if (node->inputs->size == 3 &&
+                   android_sdk_version >= kMinSdkVersionForNNAPI12) {
+          const int constant_value_id = node->inputs->data[2];
+          if (constant_value_id == kOptionalTensor) {
+            return BasicMappingFn<ANEURALNETWORKS_PAD>;
+          }
+          return BasicMappingFn<ANEURALNETWORKS_PAD_V2>;
+        }
+      }
+    } break;
+    case kTfLiteBuiltinUnidirectionalSequenceRnn:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        if (IsHybridOperator(context, builtin_code, node)) {
+          // Hybrid version of this op is not supported by NN API.
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteSequenceRNNParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          mapping_args.builder->AddScalarInt32Operand(builtin->time_major);
+          return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN;
+        };
+      }
+      break;
+    case kTfLiteBuiltinSpaceToBatchNd:
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
+        return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;
+      }
+      break;
+    case kTfLiteBuiltinBatchToSpaceNd:
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
+        auto crops = context->tensors[node->inputs->data[2]];
+        auto crops_data = crops.data.i32;
+        // Check if all crops are 0.
+        if (!crops_data || crops.bytes != 16 || crops_data[0] != 0 ||
+            crops_data[1] != 0 || crops_data[2] != 0 || crops_data[3] != 0) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_BATCH_TO_SPACE_ND>;
+      }
+      break;
+    case kTfLiteBuiltinStridedSlice:
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteStridedSliceParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->begin_mask);
+          mapping_args.builder->AddScalarInt32Operand(builtin->end_mask);
+          mapping_args.builder->AddScalarInt32Operand(
+              builtin->shrink_axis_mask);
+          return ANEURALNETWORKS_STRIDED_SLICE;
+        };
+      }
+      break;
+    case kTfLiteBuiltinTranspose:
+      // Note that the permutation input tensor value dictates the output
+      // dimensions.
+      // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+      if ((version <= 2) && (android_sdk_version >= kMinSdkVersionForNNAPI11) &&
+          (node->inputs->size > 1) &&
+          (context->tensors[node->inputs->data[1]].allocation_type ==
+           kTfLiteMmapRo)) {
+        return BasicMappingFn<ANEURALNETWORKS_TRANSPOSE>;
+      }
+      break;
+    case kTfLiteBuiltinAbs:
+      // NN API only supports float inputs to this op.
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_ABS>;
+      }
+      break;
+    case kTfLiteBuiltinExp:
+      // NN API only supports float inputs to this op.
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_EXP>;
+      }
+      break;
+    case kTfLiteBuiltinLog:
+      // NN API only supports float inputs to this op.
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_LOG>;
+      }
+      break;
+    case kTfLiteBuiltinRsqrt:
+      // NN API only supports float inputs to this op.
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloatOperator(context, node)) {
+        return BasicMappingFn<ANEURALNETWORKS_RSQRT>;
+      }
+      break;
+    case kTfLiteBuiltinPow:
+      // NN API only supports float inputs to this op.
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_POW>;
+      }
+      break;
+    case kTfLiteBuiltinSlice: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      const auto begin_type = context->tensors[node->inputs->data[1]].type;
+      const auto size_type = context->tensors[node->inputs->data[2]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteInt32 ||
+           input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) &&
+          begin_type == kTfLiteInt32 && size_type == kTfLiteInt32) {
+        return BasicMappingFn<ANEURALNETWORKS_SLICE>;
+      }
+    } break;
+    case kTfLiteBuiltinSin:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_SIN>;
+      }
+      break;
+    case kTfLiteBuiltinTransposeConv:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          const bool hybrid_op =
+              IsHybridOperator(mapping_args.context,
+                               kTfLiteBuiltinTransposeConv, mapping_args.node);
+          mapping_args.builder->AddTensorInput(/*kDataInputTensor*/ 2,
+                                               hybrid_op);
+          mapping_args.builder->AddTensorInput(/*kWeightsTensor*/ 1, hybrid_op);
+
+          // NNAPI requires a bias tensor, so we allocate a new tensor to fill
+          // it with zeroes. It is deleted with other tensors in the context
+          // during subgraph destructor call.
+          int bias_index = -1;
+          mapping_args.context->AddTensors(mapping_args.context, 1,
+                                           &bias_index);
+          TfLiteTensor* bias_tensor =
+              &mapping_args.context->tensors[bias_index];
+          const auto input_type =
+              mapping_args.context
+                  ->tensors[mapping_args.node->inputs
+                                ->data[/*kDataInputTensor*/ 2]]
+                  .type;
+          if (input_type == kTfLiteFloat32) {
+            bias_tensor->type = kTfLiteFloat32;
+          } else {
+            bias_tensor->type = kTfLiteInt32;
+          }
+
+          // Create an array with a required bias shape and resize the bias
+          // tensor.
+          TfLiteIntArray* bias_shape = TfLiteIntArrayCreate(1);
+          const TfLiteTensor& output_shape =
+              mapping_args.context->tensors
+                  [mapping_args.node->inputs->data[/*kOutputShapeTensor*/ 0]];
+          const int output_depth = output_shape.data.i32[3];
+          bias_shape->data[0] = output_depth;
+          bias_tensor->allocation_type = kTfLiteDynamic;
+          mapping_args.context->ResizeTensor(mapping_args.context, bias_tensor,
+                                             bias_shape);
+
+          // Set tensor's values to zeroes and add it using AddVector*, so
+          // that the values are copied to NNAPI. We don't use the AddTensor
+          // function because it doesn't copy values and the tensor we just
+          // created is not in the node->inputs.
+          if (input_type == kTfLiteFloat32) {
+            memset(bias_tensor->data.f, 0, output_depth * sizeof(float));
+            mapping_args.builder->AddVectorFloat32Operand(bias_tensor->data.f,
+                                                          output_depth);
+          } else {
+            memset(bias_tensor->data.i32, 0, output_depth * sizeof(int));
+            const TfLiteTensor& input_tensor =
+                mapping_args.context->tensors
+                    [mapping_args.node->inputs->data[/*kDataInputTensor*/ 2]];
+            const TfLiteTensor& filter_tensor =
+                mapping_args.context->tensors[mapping_args.node->inputs
+                                                  ->data[/*kWeightsTensor*/ 1]];
+            // NNAPI requires bias scale to be a product of an input scale and
+            // a filter scale.
+            bias_tensor->params.scale =
+                input_tensor.params.scale * filter_tensor.params.scale;
+            mapping_args.builder->AddVectorInt32Operand(
+                bias_tensor->data.i32, output_depth,
+                input_tensor.params.scale * filter_tensor.params.scale,
+                /*zero_point=*/0);
+          }
+
+          mapping_args.builder->AddTensorInput(/*kOutputShapeTensor*/ 0,
+                                               hybrid_op);
+
+          auto builtin = reinterpret_cast<TfLiteTransposeConvParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->padding);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_width);
+          mapping_args.builder->AddScalarInt32Operand(builtin->stride_height);
+          mapping_args.builder->AddScalarInt32Operand(
+              /*ANEURALNETWORKS_FUSED_NONE*/ 0);
+          // Use NHWC layout for input and output
+          mapping_args.builder->AddScalarBoolOperand(false);
+          return ANEURALNETWORKS_TRANSPOSE_CONV;
+        };
+      }
+      break;
+    case kTfLiteBuiltinSqrt:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          IsFloat(context->tensors[node->inputs->data[0]].type)) {
+        return BasicMappingFn<ANEURALNETWORKS_SQRT>;
+      }
+      break;
+    case kTfLiteBuiltinRnn:
+      // NNAPI only support float32 weights.
+      if (version == 1 && node->inputs->size == 5 &&
+          context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type ==
+              kTfLiteFloat32) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          // NNAPI need both state_in and state_out.
+          int ann_index;
+          mapping_args.builder->AddStateFloat32Tensor(
+              mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4],
+              &ann_index);
+          mapping_args.model_state_outputs->push_back(ann_index);
+          mapping_args.model_state_tfl_inputs->push_back(
+              mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]);
+          auto builtin = reinterpret_cast<TfLiteRNNParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_RNN;
+        };
+      }
+      break;
+    case kTfLiteBuiltinSpaceToDepth: {
+      const TfLiteType input_type =
+          context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8)) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteSpaceToDepthParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->block_size);
+          return ANEURALNETWORKS_SPACE_TO_DEPTH;
+        };
+      }
+    } break;
+    case kTfLiteBuiltinSvdf:
+      // NNAPI only support float32 weights.
+      // Only delegate to NNAPI 1.1, as SVDF does not support rank > 1
+      // on 1.0.
+      if (version == 1 && node->inputs->size == 5 &&
+          android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+          context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
+                  .type == kTfLiteFloat32) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          // NNAPI need both state_in and state_out.
+          int ann_index;
+          mapping_args.builder->AddStateFloat32Tensor(
+              mapping_args.node->inputs
+                  ->data[/*kInputActivationStateTensor*/ 4],
+              &ann_index);
+          mapping_args.model_state_outputs->push_back(ann_index);
+          mapping_args.model_state_tfl_inputs->push_back(
+              mapping_args.node->inputs
+                  ->data[/*kInputActivationStateTensor*/ 4]);
+
+          auto builtin = reinterpret_cast<TfLiteSVDFParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->rank);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          return ANEURALNETWORKS_SVDF;
+        };
+      }
+      break;
+    case kTfLiteBuiltinLstm:
+      // TODO(miaowang): add loggings to indicate why the op is rejected.
+      if (version <= 3) {
+        if (android_sdk_version < kMinSdkVersionForNNAPI11) {
+          // Only delegate to NNAPI 1.1+, as 1.0 has a bug for optional
+          // tensors which would affect LSTM.
+          return nullptr;
+        }
+        if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+            IsHybridOperator(context, builtin_code, node)) {
+          // Hybrid operators not supported before NNAPI 1.2.
+          return nullptr;
+        }
+
+        const auto weight_input_index =
+            isLstmBasicKernel(node) ? 2 /*  basic::kInputWeights */
+                                    : 4 /* full::kInputToOutputWeightsTensor */;
+
+        const TfLiteType weight_type =
+            context->tensors[node->inputs->data[weight_input_index]].type;
+
+        if (isLstmBasicKernel(node)) {
+          if (weight_type != kTfLiteUInt8) {
+            return nullptr;
+          }
+          const auto input_quantization_params =
+              context->tensors[node->inputs->data[0]].params;
+          if (input_quantization_params.scale != 1. / 128. ||
+              input_quantization_params.zero_point != 128) {
+            return nullptr;
+          }
+
+          const auto output_quantization_params =
+              context->tensors[node->outputs->data[0]].params;
+          if (output_quantization_params.scale != 1. / 128. ||
+              output_quantization_params.zero_point != 128) {
+            return nullptr;
+          }
+
+          const auto cell_state_quantization_params =
+              context->tensors[node->outputs->data[1]].params;
+          if (cell_state_quantization_params.scale != 16. / 32768. ||
+              cell_state_quantization_params.zero_point != 0) {
+            return nullptr;
+          }
+
+          auto is_const_tensor = [&node, &context](int tensor_idx) {
+            return context->tensors[node->inputs->data[tensor_idx]]
+                       .allocation_type == kTfLiteMmapRo;
+          };
+
+          if (!is_const_tensor(2 /* kInputWeights */)) {
+            return nullptr;
+          }
+
+          if (!is_const_tensor(3 /* kInputBiases */)) {
+            return nullptr;
+          }
+
+          return [](const NNAPIOpMappingArgs& mapping_args)
+                     -> ANeuralNetworksOperationType {
+            const auto output_dims =
+                mapping_args.context
+                    ->tensors[mapping_args.node->outputs->data[1]]
+                    .dims;
+
+            // Inputs kInputData
+            mapping_args.builder->AddTensorInput(
+                mapping_args.node->inputs->data[0 /* kInputData */],
+                /* hybrid_op */ false,
+                /* scalar_as_tensor */ false);
+
+            // The 8 weights tensors are set decomposing the
+            // kInputWeights param
+            const auto weight_tensor =
+                mapping_args.context->tensors
+                    [mapping_args.node->inputs->data[2 /* kInputWeights */]];
+
+            std::vector<uint8_t> recurrent_to_input;
+            std::vector<uint8_t> input_to_input;
+            std::vector<uint8_t> recurrent_to_cell;
+            std::vector<uint8_t> input_to_cell;
+            std::vector<uint8_t> recurrent_to_forget;
+            std::vector<uint8_t> input_to_forget;
+            std::vector<uint8_t> recurrent_to_output;
+            std::vector<uint8_t> input_to_output;
+            tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor(
+                weight_tensor.data.uint8, weight_tensor.dims,
+                &recurrent_to_input, &input_to_input, &recurrent_to_cell,
+                &input_to_cell, &recurrent_to_forget, &input_to_forget,
+                &recurrent_to_output, &input_to_output);
+
+            TfLiteIntArray* recurrent_weight_dims = TfLiteIntArrayCreate(2);
+            TfLiteIntArray* input_weight_dims = TfLiteIntArrayCreate(2);
+            tflite::delegate::nnapi::SetWeightSubmatrixDims(
+                weight_tensor.dims, recurrent_weight_dims, input_weight_dims);
+
+            int new_tensor_index = -1;
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                input_weight_dims, input_to_input, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                input_weight_dims, input_to_forget, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                input_weight_dims, input_to_cell, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                input_weight_dims, input_to_output, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                recurrent_weight_dims, recurrent_to_input, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                recurrent_weight_dims, recurrent_to_forget,
+                weight_tensor.params, &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                recurrent_weight_dims, recurrent_to_cell, weight_tensor.params,
+                &new_tensor_index);
+
+            mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
+                ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
+                recurrent_weight_dims, recurrent_to_output,
+                weight_tensor.params, &new_tensor_index);
+
+            TfLiteIntArrayFree(input_weight_dims);
+            TfLiteIntArrayFree(recurrent_weight_dims);
+
+            // Biases have to be split in four
+            const auto bias_size = output_dims->data[1];
+            const TfLiteTensor& biases_tensor =
+                mapping_args.context->tensors[mapping_args.node->inputs
+                                                  ->data[3 /* kInputBiases */]];
+
+            std::vector<int32_t> input_bias;
+            std::vector<int32_t> cell_bias;
+            std::vector<int32_t> forget_bias;
+            std::vector<int32_t> output_bias;
+            delegate::nnapi::DecomposeBiasTensor(
+                biases_tensor.data.i32, bias_size, &input_bias, &cell_bias,
+                &forget_bias, &output_bias);
+
+            int input_bias_tensor = -1;
+            mapping_args.builder->AddNewInputConstantTensor<int32_t>(
+                ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
+                input_bias, biases_tensor.params, &input_bias_tensor);
+            int forget_bias_tensor = -1;
+            mapping_args.builder->AddNewInputConstantTensor(
+                ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
+                forget_bias, biases_tensor.params, &forget_bias_tensor);
+            int cell_gate_bias_tensor = -1;
+            mapping_args.builder->AddNewInputConstantTensor(
+                ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
+                cell_bias, biases_tensor.params, &cell_gate_bias_tensor);
+            int output_gate_bias_tensor = -1;
+            mapping_args.builder->AddNewInputConstantTensor(
+                ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
+                output_bias, biases_tensor.params, &output_gate_bias_tensor);
+
+            mapping_args.builder->AddTensorInput(
+                mapping_args.node->inputs->data[4 /* kInputPrevState */],
+                /* hybrid_op */ false,
+                /* scalar_as_tensor */ false);
+
+            // kInputPrevActivation
+            mapping_args.builder->AddTensorInput(
+                mapping_args.node->inputs->data[1 /* kInputPrevActivation */],
+                /* hybrid_op */ false,
+                /* scalar_as_tensor */ false);
+
+            // Configuring the copy from the activation, state outputs
+            // to their associated inputs
+            mapping_args.feedback_loops->push_back(std::make_tuple(
+                0 /*kOutputActivation*/, 1 /*kInputPrevActivation*/));
+
+            mapping_args.feedback_loops->push_back(
+                std::make_tuple(1 /*kOutputState*/, 4 /*kInputPrevState*/));
+
+            // OUTPUTS
+            // Setting only the first two since the remaining ones are
+            // ignored by NNAPI
+            mapping_args.builder->AddTensorOutput(
+                mapping_args.node->outputs->data[1 /* kOutputState */], 0);
+
+            mapping_args.builder->AddTensorOutput(
+                mapping_args.node->outputs
+                    ->data[0 /* kOutputkOutputActivationState */],
+                0);
+
+            return ANEURALNETWORKS_QUANTIZED_16BIT_LSTM;
+          };
+        }
+        if (node->inputs->size == 24 &&
+            android_sdk_version < kMinSdkVersionForNNAPI12) {
+          // LSTM with layer norm introduced in API level 29
+          return nullptr;
+        }
+        if (weight_type != kTfLiteFloat32 && weight_type != kTfLiteUInt8) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteLSTMParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
+
+          // Current NNAPI implementation requires the scratch_buffer as
+          // output.
+          mapping_args.builder->AddAdditionalFloat32OutputTensor(2);
+
+          // NNAPI need both state_in and state_out for cell_state and
+          // output_state.
+          int ann_index;
+          mapping_args.builder->AddStateFloat32Tensor(
+              mapping_args.node->inputs
+                  ->data[/*kInputActivationStateTensor*/ 18],
+              &ann_index);
+          mapping_args.model_state_outputs->push_back(ann_index);
+          mapping_args.model_state_tfl_inputs->push_back(
+              mapping_args.node->inputs
+                  ->data[/*kInputActivationStateTensor*/ 18]);
+          mapping_args.builder->AddStateFloat32Tensor(
+              mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19],
+              &ann_index);
+          mapping_args.model_state_outputs->push_back(ann_index);
+          mapping_args.model_state_tfl_inputs->push_back(
+              mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]);
+
+          const bool hybrid_op = IsHybridOperator(
+              mapping_args.context, kTfLiteBuiltinLstm, mapping_args.node);
+
+          if (mapping_args.node->inputs->size == 24) {
+            for (int i = 20; i < 24; ++i) {
+              const auto input_index = mapping_args.node->inputs->data[i];
+              if (input_index != kOptionalTensor) {
+                mapping_args.builder->AddTensorInput(input_index, hybrid_op);
+              } else {
+                mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
+              }
+            }
+          }
+
+          return ANEURALNETWORKS_LSTM;
+        };
+      }
+      break;
+    case kTfLiteBuiltinMean:
+      // NNAPI does not support generating a scalar as output for MEAN.
+      if (version <= 2 &&
+          ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+            context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) ||
+           (android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+            IsQuantized(context->tensors[node->inputs->data[0]].type))) &&
+          context->tensors[node->outputs->data[0]].dims->size > 0) {
+        auto input_param = context->tensors[node->inputs->data[0]].params;
+        auto output_param = context->tensors[node->outputs->data[0]].params;
+        // NNAPI requires that the input and output have the same
+        // quantization parameters.
+        if (input_param.scale != output_param.scale ||
+            input_param.zero_point != output_param.zero_point) {
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+              mapping_args.node->builtin_data);
+          int32_t keep_dims = 0;
+          if (builtin->keep_dims) keep_dims = 1;
+          mapping_args.builder->AddScalarInt32Operand(keep_dims);
+          return ANEURALNETWORKS_MEAN;
+        };
+      }
+      break;
+    case kTfLiteBuiltinEmbeddingLookup:
+      // NNAPI only support float32 values.
+      if (version == 1 &&
+          context->tensors[node->inputs->data[1]].type == kTfLiteFloat32) {
+        return BasicMappingFn<ANEURALNETWORKS_EMBEDDING_LOOKUP>;
+      }
+      break;
+    case kTfLiteBuiltinHashtableLookup:
+      // NNAPI only support float32 output.
+      if (version == 1 &&
+          context->tensors[node->outputs->data[0]].type == kTfLiteFloat32) {
+        return BasicMappingFn<ANEURALNETWORKS_HASHTABLE_LOOKUP>;
+      }
+      break;
+    case kTfLiteBuiltinMaximum: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_MAXIMUM>;
+      }
+    } break;
+    case kTfLiteBuiltinMinimum: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
+      }
+    } break;
+    case kTfLiteBuiltinCast: {
+      const TfLiteType input_type =
+          context->tensors[node->inputs->data[0]].type;
+      const TfLiteType output_type =
+          context->tensors[node->outputs->data[0]].type;
+      auto is_supported_tensor_type = [](const TfLiteType& type) {
+        return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
+                type == kTfLiteUInt8);
+      };
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          is_supported_tensor_type(input_type) &&
+          is_supported_tensor_type(output_type)) {
+        return BasicMappingFn<ANEURALNETWORKS_CAST>;
+      }
+    } break;
+    case kTfLiteBuiltinPrelu:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        if (!IsFloatOrUint8Operator(context, node)) {
+          return nullptr;
+        }
+        return BasicMappingFn<ANEURALNETWORKS_PRELU>;
+      }
+      break;
+    case kTfLiteBuiltinTile: {
+      // NN API doesn't support int64 and boolean inputs to this op
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      const auto multipliers_type =
+          context->tensors[node->inputs->data[1]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteInt32) &&
+          (multipliers_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_TILE>;
+      }
+    } break;
+    case kTfLiteBuiltinLogicalOr: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          input_type == kTfLiteBool) {
+        return BasicMappingFn<ANEURALNETWORKS_LOGICAL_OR>;
+      }
+    } break;
+    case kTfLiteBuiltinLogicalAnd: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          input_type == kTfLiteBool) {
+        return BasicMappingFn<ANEURALNETWORKS_LOGICAL_AND>;
+      }
+    } break;
+    case kTfLiteBuiltinLogicalNot: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          input_type == kTfLiteBool) {
+        return BasicMappingFn<ANEURALNETWORKS_LOGICAL_NOT>;
+      }
+    } break;
+    case kTfLiteBuiltinLess: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_LESS>;
+      }
+    } break;
+    case kTfLiteBuiltinLessEqual: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_LESS_EQUAL>;
+      }
+    } break;
+    case kTfLiteBuiltinGreater: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_GREATER>;
+      }
+    } break;
+    case kTfLiteBuiltinGreaterEqual: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_GREATER_EQUAL>;
+      }
+    } break;
+    case kTfLiteBuiltinEqual: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_EQUAL>;
+      }
+    } break;
+    case kTfLiteBuiltinNotEqual: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8 || input_type == kTfLiteBool ||
+           input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_NOT_EQUAL>;
+      }
+    } break;
+    case kTfLiteBuiltinNeg: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat32 || input_type == kTfLiteInt32)) {
+        return BasicMappingFn<ANEURALNETWORKS_NEG>;
+      }
+    } break;
+    case kTfLiteBuiltinTopkV2: {
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        const auto& input = context->tensors[node->outputs->data[0]];
+        const auto& k_param = context->tensors[node->outputs->data[1]];
+        if ((input.type == kTfLiteFloat32 || input.type == kTfLiteInt32 ||
+             input.type == kTfLiteUInt8 || input.type == kTfLiteInt8) &&
+            (k_param.type == kTfLiteInt32 &&
+             k_param.allocation_type == kTfLiteMmapRo)) {
+          return [](const NNAPIOpMappingArgs& mapping_args)
+                     -> ANeuralNetworksOperationType {
+            const TfLiteTensor& k_param =
                 mapping_args.context
                     ->tensors[mapping_args.node->inputs->data[1]];
-            mapping_args.builder->AddScalarInt32Operand(*axis_param.data.i32);
-            return ANEURALNETWORKS_EXPAND_DIMS;
+            mapping_args.builder->AddScalarInt32Operand(*k_param.data.i32);
+            return ANEURALNETWORKS_TOPK_V2;
           };
+        } else {
+          return nullptr;
         }
-      } break;
-      case kTfLiteBuiltinSplit: {
-        // Tensor indices: split_dim: 0, value: 1
-        const TfLiteTensor& axis = context->tensors[node->inputs->data[0]];
-        const TfLiteTensor& input = context->tensors[node->inputs->data[1]];
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            (input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 ||
-             input.type == kTfLiteInt32) &&
-            (axis.type == kTfLiteInt32 &&
-             axis.allocation_type == kTfLiteMmapRo)) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            const TfLiteTensor& axis =
-                mapping_args.context
-                    ->tensors[mapping_args.node->inputs->data[0]];
-            auto builtin = reinterpret_cast<TfLiteSplitParams*>(
-                mapping_args.node->builtin_data);
-            mapping_args.builder->AddScalarInt32Operand(*axis.data.i32);
-            mapping_args.builder->AddScalarInt32Operand(builtin->num_splits);
-            return ANEURALNETWORKS_SPLIT;
-          };
+      }
+    } break;
+    case kTfLiteBuiltinSelect: {
+      const auto value_type = context->tensors[node->inputs->data[1]].type;
+      if (version <= 2 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (value_type == kTfLiteFloat32 || value_type == kTfLiteUInt8 ||
+           value_type == kTfLiteInt8 || value_type == kTfLiteInt32)) {
+        TfLiteIntArray* condition_shape =
+            context->tensors[node->inputs->data[0]].dims;
+        TfLiteIntArray* input_shape =
+            context->tensors[node->inputs->data[1]].dims;
+        // The Android Q-variant of select does not support broadcasting.
+        if (!TfLiteIntArrayEqual(condition_shape, input_shape)) {
+          return nullptr;
         }
-      } break;
-      case kTfLiteBuiltinLogSoftmax: {
-        const auto input_type = context->tensors[node->inputs->data[0]].type;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            input_type == kTfLiteFloat32) {
-          return [](const NNAPIOpMappingArgs& mapping_args)
-                     -> ANeuralNetworksOperationType {
-            // Scaling and axis are hardcoded to respectively 1 and -1
-            // in TFLite.
-            mapping_args.builder->AddScalarFloat32Operand(1);
-            mapping_args.builder->AddScalarInt32Operand(-1);
-            return ANEURALNETWORKS_LOG_SOFTMAX;
-          };
+        return BasicMappingFn<ANEURALNETWORKS_SELECT>;
+      }
+    } break;
+    case kTfLiteBuiltinGather: {
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        const auto& input = context->tensors[node->inputs->data[0]];
+        const auto& positions = context->tensors[node->inputs->data[1]];
+
+        auto is_supported_input_type = [](const TfLiteTensor& t) {
+          return (t.type == kTfLiteFloat32 || t.type == kTfLiteFloat16 ||
+                  t.type == kTfLiteInt32 || t.type == kTfLiteUInt8);
+        };
+
+        if (!is_supported_input_type(input) ||
+            !is_supported_input_type(positions)) {
+          return nullptr;
         }
-      } break;
-      case kTfLiteBuiltinQuantize: {
-        const auto value_type = context->tensors[node->inputs->data[0]].type;
-        const auto output_type = context->tensors[node->outputs->data[0]].type;
-        const auto quantization_params =
-            context->tensors[node->outputs->data[0]].params;
-        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-            value_type == kTfLiteFloat32 && output_type == kTfLiteUInt8 &&
-            quantization_params.scale > 0.f) {
-          return BasicMappingFn<ANEURALNETWORKS_QUANTIZE>;
+
+        // 0-dimension args are not supported by NNAPI.
+        if (positions.dims->size == 0) {
+          return nullptr;
         }
-      } break;
-      default:
-        // All other operators are not mapped.
+
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin = reinterpret_cast<TfLiteGatherParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddTensorInput(
+              mapping_args.node->inputs->data[0],
+              /* hybrid_op */ false,
+              /* scalar_as_tensor */ false);
+
+          mapping_args.builder->AddScalarInt32Operand(builtin->axis);
+
+          mapping_args.builder->AddTensorInput(
+              mapping_args.node->inputs->data[1],
+              /* hybrid_op */ false,
+              /* scalar_as_tensor */ false);
+
+          return ANEURALNETWORKS_GATHER;
+        };
+      }
+    } break;
+    case kTfLiteBuiltinBidirectionalSequenceLstm:
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
+        if (IsHybridOperator(context, builtin_code, node)) {
+          // Hybrid version of this op is not supported by NN API.
+          return nullptr;
+        }
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          auto builtin =
+              reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+                  mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
+          mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
+          mapping_args.builder->AddScalarBoolOperand(builtin->merge_outputs);
+          mapping_args.builder->AddScalarBoolOperand(builtin->time_major);
+          // TF Lite doesn't support layer normalization in bidirectional
+          // sequence LSTM, so we insert optional tensors for NNAPI
+          for (int i = 0; i < 8; ++i) {
+            mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
+          }
+          return ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM;
+        };
+      }
+      break;
+    case kTfLiteBuiltinExpandDims: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      const auto axis = context->tensors[node->inputs->data[1]];
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input_type == kTfLiteFloat16 || input_type == kTfLiteFloat32 ||
+           input_type == kTfLiteInt32 || input_type == kTfLiteUInt8 ||
+           input_type == kTfLiteInt8) &&
+          // TFLite supports axis also as int64 but NNAPI only int32
+          (axis.type == kTfLiteInt32 &&
+           axis.allocation_type == kTfLiteMmapRo)) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          const TfLiteTensor& axis_param =
+              mapping_args.context->tensors[mapping_args.node->inputs->data[1]];
+          mapping_args.builder->AddScalarInt32Operand(*axis_param.data.i32);
+          return ANEURALNETWORKS_EXPAND_DIMS;
+        };
+      }
+    } break;
+    case kTfLiteBuiltinSplit: {
+      // Tensor indices: split_dim: 0, value: 1
+      const TfLiteTensor& axis = context->tensors[node->inputs->data[0]];
+      const TfLiteTensor& input = context->tensors[node->inputs->data[1]];
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          (input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 ||
+           input.type == kTfLiteInt32) &&
+          (axis.type == kTfLiteInt32 &&
+           axis.allocation_type == kTfLiteMmapRo)) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          const TfLiteTensor& axis =
+              mapping_args.context->tensors[mapping_args.node->inputs->data[0]];
+          auto builtin = reinterpret_cast<TfLiteSplitParams*>(
+              mapping_args.node->builtin_data);
+          mapping_args.builder->AddScalarInt32Operand(*axis.data.i32);
+          mapping_args.builder->AddScalarInt32Operand(builtin->num_splits);
+          return ANEURALNETWORKS_SPLIT;
+        };
+      }
+    } break;
+    case kTfLiteBuiltinLogSoftmax: {
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          input_type == kTfLiteFloat32) {
+        return [](const NNAPIOpMappingArgs& mapping_args)
+                   -> ANeuralNetworksOperationType {
+          // Scaling and axis are hardcoded to respectively 1 and -1
+          // in TFLite.
+          mapping_args.builder->AddScalarFloat32Operand(1);
+          mapping_args.builder->AddScalarInt32Operand(-1);
+          return ANEURALNETWORKS_LOG_SOFTMAX;
+        };
+      }
+    } break;
+    case kTfLiteBuiltinQuantize: {
+      const auto value_type = context->tensors[node->inputs->data[0]].type;
+      const auto output_type = context->tensors[node->outputs->data[0]].type;
+      const auto quantization_params =
+          context->tensors[node->outputs->data[0]].params;
+      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+          value_type == kTfLiteFloat32 && output_type == kTfLiteUInt8 &&
+          quantization_params.scale > 0.f) {
+        return BasicMappingFn<ANEURALNETWORKS_QUANTIZE>;
+      }
+    } break;
+    case kTfLiteBuiltinReduceAny: {
+      if (version != 1 || android_sdk_version < kMinSdkVersionForNNAPI12) {
         return nullptr;
-    }
-    return nullptr;
+      }
+      // NNAPI does not support generating a scalar as output for REDUCE_ANY.
+      if (context->tensors[node->outputs->data[0]].dims->size == 0) {
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+            mapping_args.node->builtin_data);
+        mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims);
+        return ANEURALNETWORKS_REDUCE_ANY;
+      };
+    } break;
+    case kTfLiteBuiltinReduceMin: {
+      if (version > 2 || android_sdk_version < kMinSdkVersionForNNAPI12) {
+        return nullptr;
+      }
+      // NNAPI does not support generating a scalar as output for REDUCE_MIN.
+      if (context->tensors[node->outputs->data[0]].dims->size == 0) {
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+            mapping_args.node->builtin_data);
+        mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims);
+        return ANEURALNETWORKS_REDUCE_MIN;
+      };
+    } break;
+    case kTfLiteBuiltinReduceMax: {
+      if (version > 2 || android_sdk_version < kMinSdkVersionForNNAPI12) {
+        return nullptr;
+      }
+      // NNAPI does not support generating a scalar as output for REDUCE_MAX.
+      if (context->tensors[node->outputs->data[0]].dims->size == 0) {
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+            mapping_args.node->builtin_data);
+        mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims);
+        return ANEURALNETWORKS_REDUCE_MAX;
+      };
+    } break;
+    case kTfLiteBuiltinReduceProd: {
+      if (version != 1 || android_sdk_version < kMinSdkVersionForNNAPI12) {
+        return nullptr;
+      }
+      // NNAPI only supports floating point REDUCE_PROD.
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (input_type != kTfLiteFloat32) {
+        return nullptr;
+      }
+      // NNAPI does not support generating a scalar as output for REDUCE_PROD.
+      if (context->tensors[node->outputs->data[0]].dims->size == 0) {
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+            mapping_args.node->builtin_data);
+        mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims);
+        return ANEURALNETWORKS_REDUCE_PROD;
+      };
+    } break;
+    case kTfLiteBuiltinSum: {
+      if (version != 1 || android_sdk_version < kMinSdkVersionForNNAPI12) {
+        return nullptr;
+      }
+      // NNAPI only supports floating point REDUCE_SUM.
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      if (input_type != kTfLiteFloat32) {
+        return nullptr;
+      }
+      // NNAPI does not support generating a scalar as output for REDUCE_SUM.
+      if (context->tensors[node->outputs->data[0]].dims->size == 0) {
+        return nullptr;
+      }
+      return [](const NNAPIOpMappingArgs& mapping_args)
+                 -> ANeuralNetworksOperationType {
+        auto builtin = reinterpret_cast<TfLiteReducerParams*>(
+            mapping_args.node->builtin_data);
+        mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims);
+        return ANEURALNETWORKS_REDUCE_SUM;
+      };
+    } break;
+    default:
+      // All other operators are not mapped.
+      return nullptr;
+  }
+  return nullptr;
+}
+
+// Initialize the kernel (a NN model).
+TfLiteStatus NNAPIDelegateKernel::Init(TfLiteContext* context,
+                                       const TfLiteDelegateParams* params) {
+  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+    nodes_.push_back(node_index);
   }
 
-  // Initialize the kernel (a NN model).
-  TfLiteStatus Init(TfLiteContext* context,
-                    const TfLiteDelegateParams* params) {
-    for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
-      nodes_.push_back(node_index);
+  const auto delegate_options =
+      StatefulNnApiDelegate::GetOptions(params->delegate);
+  const char* device_name_ptr = delegate_options.accelerator_name;
+  // user specified an acclelerator to use.
+  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+      device_name_ptr != nullptr) {
+    nnapi_device_ = GetDeviceHandle(context, device_name_ptr);
+    if (nnapi_device_ == nullptr) {
+      return kTfLiteError;
     }
+  }
 
-    const auto delegate_options =
-        StatefulNnApiDelegate::GetOptions(params->delegate);
-    const char* device_name_ptr = delegate_options.accelerator_name;
-    // user specified an acclelerator to use.
-    if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-        device_name_ptr != nullptr) {
-      nnapi_device_ = GetDeviceHandle(device_name_ptr);
-      if (nnapi_device_ == nullptr) {
-        context->ReportError(context,
-                             "Could not find the specified accelerator: %s.",
-                             device_name_ptr);
-        return kTfLiteError;
-      }
-    }
+  // Mark the handle backed tensors.
+  tensor_memory_map_ =
+      &StatefulNnApiDelegate::GetTensorMemoryMap(params->delegate);
 
-    // Mark the handle backed tensors.
-    tensor_memory_map_ =
-        &StatefulNnApiDelegate::GetTensorMemoryMap(params->delegate);
+  if (!nn_model_) {
+    ANeuralNetworksModel* model = nullptr;
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context, nnapi_->ANeuralNetworksModel_create(&model));
+    nn_model_.reset(model);
 
-    if (!nn_model_) {
-      ANeuralNetworksModel* model = nullptr;
+    TF_LITE_ENSURE_STATUS(
+        BuildGraph(context, params->input_tensors, params->output_tensors));
+  }
+
+  if (!nn_compilation_) {
+    ANeuralNetworksCompilation* compilation = nullptr;
+    if (nnapi_device_ != nullptr) {
+      // Compile for the selected accelerator.
       RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context, nnapi_->ANeuralNetworksModel_create(&model));
-      nn_model_.reset(model);
-
-      TF_LITE_ENSURE_STATUS(
-          BuildGraph(context, params->input_tensors, params->output_tensors));
+          context, nnapi_->ANeuralNetworksCompilation_createForDevices(
+                       nn_model_.get(), &nnapi_device_, 1, &compilation));
+    } else {
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context, nnapi_->ANeuralNetworksCompilation_create(nn_model_.get(),
+                                                             &compilation));
     }
 
-    if (!nn_compilation_) {
-      ANeuralNetworksCompilation* compilation = nullptr;
-      if (nnapi_device_ != nullptr) {
-        // Compile for the selected accelerator.
-        RETURN_TFLITE_ERROR_IF_NN_ERROR(
-            context, nnapi_->ANeuralNetworksCompilation_createForDevices(
-                         nn_model_.get(), &nnapi_device_, 1, &compilation));
-      } else {
-        RETURN_TFLITE_ERROR_IF_NN_ERROR(
-            context, nnapi_->ANeuralNetworksCompilation_create(nn_model_.get(),
-                                                               &compilation));
-      }
-
-      auto preference = delegate_options.execution_preference;
-      if (preference !=
-          StatefulNnApiDelegate::Options::ExecutionPreference::kUndefined) {
-        const int preference_result =
-            nnapi_->ANeuralNetworksCompilation_setPreference(compilation,
-                                                             preference);
-        if (preference_result != ANEURALNETWORKS_NO_ERROR) {
-          nnapi_->ANeuralNetworksCompilation_free(compilation);
-          compilation = nullptr;
-        }
-        RETURN_TFLITE_ERROR_IF_NN_ERROR(context, preference_result);
-      }
-
-      const char* cache_dir = delegate_options.cache_dir;
-      const char* model_token = delegate_options.model_token;
-      if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-          cache_dir && model_token) {
-        // Compilation caching could be enabled, try construct the uint8 token.
-        // TODO(133342794): use a generic token generator class.
-        uint64_t token_parts[4];
-        // bits from model_token.
-        token_parts[0] = std::hash<std::string>{}(model_token);
-        // bits from params->nodes_to_replace.
-        token_parts[1] = GetHash(params->nodes_to_replace);
-        // bits from params->input_tensors.
-        token_parts[2] = GetHash(params->input_tensors);
-        // bits from params->output_tensors.
-        token_parts[3] = GetHash(params->output_tensors);
-        // NNAPI requires the token to be 256bit long.
-        std::vector<uint8_t> nnapi_cache_token(32, 0);
-        // Copy the token bits.
-        uint8_t* p = reinterpret_cast<uint8_t*>(token_parts);
-        for (int i = 0; i < 4 * sizeof(uint64_t); i++) {
-          nnapi_cache_token[i] = p[i];
-        }
-        const int set_caching_result =
-            nnapi_->ANeuralNetworksCompilation_setCaching(
-                compilation, cache_dir, nnapi_cache_token.data());
-        if (set_caching_result != ANEURALNETWORKS_NO_ERROR) {
-          nnapi_->ANeuralNetworksCompilation_free(compilation);
-          compilation = nullptr;
-        }
-        RETURN_TFLITE_ERROR_IF_NN_ERROR(context, set_caching_result);
-      }
-      const int finish_result =
-          nnapi_->ANeuralNetworksCompilation_finish(compilation);
-      if (finish_result != ANEURALNETWORKS_NO_ERROR) {
+    auto preference = delegate_options.execution_preference;
+    if (preference !=
+        StatefulNnApiDelegate::Options::ExecutionPreference::kUndefined) {
+      const int preference_result =
+          nnapi_->ANeuralNetworksCompilation_setPreference(compilation,
+                                                           preference);
+      if (preference_result != ANEURALNETWORKS_NO_ERROR) {
         nnapi_->ANeuralNetworksCompilation_free(compilation);
         compilation = nullptr;
       }
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result);
-      nn_compilation_.reset(compilation);
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(context, preference_result);
     }
-    return kTfLiteOk;
+
+    const char* cache_dir = delegate_options.cache_dir;
+    const char* model_token = delegate_options.model_token;
+    if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 && cache_dir &&
+        model_token) {
+      // Compilation caching could be enabled, try construct the uint8
+      // token.
+      // TODO(133342794): use a generic token generator class.
+      uint64_t token_parts[4];
+      // bits from model_token.
+      token_parts[0] = std::hash<std::string>{}(model_token);
+      // bits from params->nodes_to_replace.
+      token_parts[1] = GetHash(params->nodes_to_replace);
+      // bits from params->input_tensors.
+      token_parts[2] = GetHash(params->input_tensors);
+      // bits from params->output_tensors.
+      token_parts[3] = GetHash(params->output_tensors);
+      // NNAPI requires the token to be 256bit long.
+      std::vector<uint8_t> nnapi_cache_token(32, 0);
+      // Copy the token bits.
+      uint8_t* p = reinterpret_cast<uint8_t*>(token_parts);
+      for (int i = 0; i < 4 * sizeof(uint64_t); i++) {
+        nnapi_cache_token[i] = p[i];
+      }
+      const int set_caching_result =
+          nnapi_->ANeuralNetworksCompilation_setCaching(
+              compilation, cache_dir, nnapi_cache_token.data());
+      if (set_caching_result != ANEURALNETWORKS_NO_ERROR) {
+        nnapi_->ANeuralNetworksCompilation_free(compilation);
+        compilation = nullptr;
+      }
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(context, set_caching_result);
+    }
+    const int finish_result =
+        nnapi_->ANeuralNetworksCompilation_finish(compilation);
+    if (finish_result != ANEURALNETWORKS_NO_ERROR) {
+      nnapi_->ANeuralNetworksCompilation_free(compilation);
+      compilation = nullptr;
+    }
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result);
+    nn_compilation_.reset(compilation);
   }
+  return kTfLiteOk;
+}
 
-  TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
-    ANeuralNetworksExecution* execution = nullptr;
-    RETURN_TFLITE_ERROR_IF_NN_ERROR(
-        context, nnapi_->ANeuralNetworksExecution_create(nn_compilation_.get(),
-                                                         &execution));
-    std::unique_ptr<ANeuralNetworksExecution, NNFreeExecution>
-        execution_unique_ptr(execution);
+TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context,
+                                          TfLiteNode* node) {
+  if (!nn_compilation_) {
+    // Compilation failed earlier, return error.
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
 
-    // Set the input tensor buffers. Note: we access tflite tensors using
-    // absolute indices but NN api indices inputs by relative indices.
-    int relative_input_index = 0;
+TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
+                                         TfLiteNode* node) {
+  ANeuralNetworksExecution* execution = nullptr;
+  RETURN_TFLITE_ERROR_IF_NN_ERROR(
+      context, nnapi_->ANeuralNetworksExecution_create(nn_compilation_.get(),
+                                                       &execution));
+  std::unique_ptr<ANeuralNetworksExecution, NNFreeExecution>
+      execution_unique_ptr(execution);
 
-    size_t input_offset = 0;
-    for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
-      if (absolute_input_index == kOptionalTensor) {
-        continue;
-      }
-      TfLiteTensor* tensor = &context->tensors[absolute_input_index];
-      // TODO(miaowang): make sure the delegation works with dequantized weights
-      // as intermediate tensors.
-      if (tensor->allocation_type != kTfLiteMmapRo) {
-        if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
-            tensor->buffer_handle < tensor_memory_map_->size()) {
-          RETURN_TFLITE_ERROR_IF_NN_ERROR(
-              context, nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                           execution, relative_input_index, nullptr,
-                           tensor_memory_map_->at(tensor->buffer_handle).memory,
-                           0, tensor->bytes));
-          relative_input_index++;
-          continue;
-        }
-        TfLiteType ann_type_equivalent =
-            operand_mapping_.lite_index_to_ann_type_conversion(
-                absolute_input_index);
-        int tensor_size = 0;
-        if (ann_type_equivalent != kTfLiteNoType) {
-          if (tensor->type == kTfLiteUInt8 &&
-              ann_type_equivalent == kTfLiteInt32) {
-            for (int i = 0; i < NumElements(tensor); ++i) {
-              reinterpret_cast<int32_t*>(nn_input_memory_->get_data_ptr() +
-                                         input_offset)[i] =
-                  static_cast<const int32_t>(tensor->data.raw_const[i]);
-            }
-          } else if (tensor->type == kTfLiteInt8 &&
-                     ann_type_equivalent == kTfLiteUInt8) {
-            // Explicitly convert int8 values to uint8 values.
-            uint8_t* input_ptr = reinterpret_cast<uint8_t*>(
-                nn_input_memory_->get_data_ptr() + input_offset);
-            for (int i = 0; i < NumElements(tensor); ++i) {
-              input_ptr[i] = static_cast<const uint8_t>(
-                  static_cast<int32_t>(tensor->data.int8[i]) + 128);
-            }
-          } else {
-            context->ReportError(
-                context,
-                "NN API Delegate: unsupported tensor types conversion: "
-                "from type code %d to type code %d.\n",
-                tensor->type, ann_type_equivalent);
-            return kTfLiteError;
-          }
-          size_t type_size;
-          TF_LITE_ENSURE_OK(
-              context, GetSizeOfType(context, ann_type_equivalent, &type_size));
-          tensor_size = NumElements(tensor) * type_size;
-          RETURN_TFLITE_ERROR_IF_NN_ERROR(
-              context,
-              nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                  execution, relative_input_index, nullptr,
-                  nn_input_memory_->get_handle(), input_offset, tensor_size));
-        } else {
-          // copy data to pre-allocated shared memory.
-          memcpy(nn_input_memory_->get_data_ptr() + input_offset,
-                 tensor->data.raw, tensor->bytes);
-          RETURN_TFLITE_ERROR_IF_NN_ERROR(
-              context,
-              nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                  execution, relative_input_index, nullptr,
-                  nn_input_memory_->get_handle(), input_offset, tensor->bytes));
-          tensor_size = tensor->bytes;
-        }
-        input_offset += tensor_size;
-        input_offset += getNumPaddingBytes(tensor_size);
-        relative_input_index++;
-      }
+  // Set the input tensor buffers. Note: we access tflite tensors using
+  // absolute indices but NN api indices inputs by relative indices.
+  int relative_input_index = 0;
+
+  size_t input_offset = 0;
+  for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
+    if (absolute_input_index == kOptionalTensor) {
+      continue;
     }
-
-    // Set the output tensor buffers.
-    int relative_output_index = 0;
-    size_t output_offset = 0;
-    for (auto output_index : TfLiteIntArrayView(node->outputs)) {
-      TfLiteTensor* tensor = &context->tensors[output_index];
+    TfLiteTensor* tensor = &context->tensors[absolute_input_index];
+    if (tensor->allocation_type != kTfLiteMmapRo) {
       if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
           tensor->buffer_handle < tensor_memory_map_->size()) {
         RETURN_TFLITE_ERROR_IF_NN_ERROR(
-            context, nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
-                         execution, relative_output_index, nullptr,
+            context, nnapi_->ANeuralNetworksExecution_setInputFromMemory(
+                         execution, relative_input_index, nullptr,
                          tensor_memory_map_->at(tensor->buffer_handle).memory,
                          0, tensor->bytes));
-
-      } else {
-        RETURN_TFLITE_ERROR_IF_NN_ERROR(
-            context,
-            nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
-                execution, relative_output_index, nullptr,
-                nn_output_memory_->get_handle(), output_offset, tensor->bytes));
-        output_offset += tensor->bytes;
-        output_offset += getNumPaddingBytes(tensor->bytes);
-      }
-      relative_output_index++;
-    }
-
-    // The state_out of previous invocation need to be mapped to state_in of
-    // current invocation.
-    for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
-      int state_tensor_idx = model_state_tfl_inputs_[i];
-      TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
-      // Here we are using a deep copy for state_in tensors so that we are not
-      // reading and writing into the same buffer during a invocation.
-      // TODO(110369471): using double shared buffer to minimize the copies.
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context, nnapi_->ANeuralNetworksExecution_setOutput(
-                       execution, relative_output_index, nullptr,
-                       tensor->data.raw, tensor->bytes));
-      relative_output_index++;
-    }
-    // Invoke ANN in blocking fashion.
-    if (nnapi_->android_sdk_version < kMinSdkVersionForNNAPI12) {
-      ANeuralNetworksEvent* event = nullptr;
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context,
-          nnapi_->ANeuralNetworksExecution_startCompute(execution, &event));
-      const int wait_result = nnapi_->ANeuralNetworksEvent_wait(event);
-      nnapi_->ANeuralNetworksEvent_free(event);
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(context, wait_result);
-    } else {
-      // Use synchronous execution for NNAPI 1.2+.
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context, nnapi_->ANeuralNetworksExecution_compute(execution));
-    }
-
-    // copy results from shared memory to the destination.
-    output_offset = 0;
-    for (auto output_index : TfLiteIntArrayView(node->outputs)) {
-      TfLiteTensor* tensor = &context->tensors[output_index];
-      if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
+        relative_input_index++;
         continue;
       }
       TfLiteType ann_type_equivalent =
-          operand_mapping_.lite_index_to_ann_type_conversion(output_index);
-      if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
-        // Explicitly convert uint8 values to int8 values.
-        uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
-            nn_output_memory_->get_data_ptr() + output_offset);
-        for (int i = 0; i < NumElements(tensor); ++i) {
-          output_ptr[i] =
-              static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
+          operand_mapping_.lite_index_to_ann_type_conversion(
+              absolute_input_index);
+      int tensor_size = 0;
+      if (ann_type_equivalent != kTfLiteNoType) {
+        const auto num_elements = NumElements(tensor);
+        uint8_t* input_ptr = nn_input_memory_->get_data_ptr() + input_offset;
+        if (tensor->type == kTfLiteUInt8 &&
+            ann_type_equivalent == kTfLiteInt32) {
+          for (int i = 0; i < num_elements; ++i) {
+            reinterpret_cast<int32_t*>(input_ptr)[i] =
+                static_cast<const int32_t>(tensor->data.raw_const[i]);
+          }
+        } else if (tensor->type == kTfLiteInt8 &&
+                   ann_type_equivalent == kTfLiteUInt8) {
+          // Explicitly convert int8 values to uint8 values.
+          for (int i = 0; i < num_elements; ++i) {
+            input_ptr[i] = static_cast<const uint8_t>(
+                static_cast<int32_t>(tensor->data.int8[i]) + 128);
+          }
+        } else if (tensor->type == kTfLiteInt8 &&
+                   ann_type_equivalent == kTfLiteInt32) {
+          for (int i = 0; i < num_elements; ++i) {
+            reinterpret_cast<int32_t*>(input_ptr)[i] =
+                static_cast<const int32_t>(tensor->data.raw_const[i]) + 128;
+          }
+        } else {
+          context->ReportError(
+              context,
+              "NN API Delegate: unsupported tensor types conversion: "
+              "from type code %d to type code %d.\n",
+              tensor->type, ann_type_equivalent);
+          return kTfLiteError;
         }
+        size_t type_size;
+        TF_LITE_ENSURE_OK(
+            context, GetSizeOfType(context, ann_type_equivalent, &type_size));
+        tensor_size = NumElements(tensor) * type_size;
+        RETURN_TFLITE_ERROR_IF_NN_ERROR(
+            context,
+            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
+                execution, relative_input_index, nullptr,
+                nn_input_memory_->get_handle(), input_offset, tensor_size));
+      } else {
+        // copy data to pre-allocated shared memory.
+        memcpy(nn_input_memory_->get_data_ptr() + input_offset,
+               tensor->data.raw, tensor->bytes);
+        RETURN_TFLITE_ERROR_IF_NN_ERROR(
+            context,
+            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
+                execution, relative_input_index, nullptr,
+                nn_input_memory_->get_handle(), input_offset, tensor->bytes));
+        tensor_size = tensor->bytes;
       }
-      memcpy(tensor->data.raw,
-             nn_output_memory_->get_data_ptr() + output_offset, tensor->bytes);
+      input_offset += tensor_size;
+      input_offset += getNumPaddingBytes(tensor_size);
+      relative_input_index++;
+    }
+  }
+
+  // Set the output tensor buffers.
+  int relative_output_index = 0;
+  size_t output_offset = 0;
+  for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+    // If the NNAPI implementation doesn't have some of the outputs
+    // they are left unmapped and we should not try to read their value here
+    if (operand_mapping_.lite_index_to_ann(output_index) == -1) {
+      continue;
+    }
+    TfLiteTensor* tensor = &context->tensors[output_index];
+    if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
+        tensor->buffer_handle < tensor_memory_map_->size()) {
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context, nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
+                       execution, relative_output_index, nullptr,
+                       tensor_memory_map_->at(tensor->buffer_handle).memory, 0,
+                       tensor->bytes));
+
+    } else {
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
+              execution, relative_output_index, nullptr,
+              nn_output_memory_->get_handle(), output_offset, tensor->bytes));
       output_offset += tensor->bytes;
       output_offset += getNumPaddingBytes(tensor->bytes);
     }
-
-    return kTfLiteOk;
+    relative_output_index++;
   }
 
- private:
-  // Access to NNApi.
-  const NnApi* nnapi_;
-  // ANN device handle.
-  ANeuralNetworksDevice* nnapi_device_ = nullptr;
-  // ANN API state.
-  std::unique_ptr<ANeuralNetworksModel, NNFreeModel> nn_model_;
-  std::unique_ptr<ANeuralNetworksCompilation, NNFreeCompilation>
-      nn_compilation_;
-  // Node indices that this delegate is responsible for. Indices here
-  // indexes into the nodes array in the TfLiteContext.
-  std::vector<int> nodes_;
-  // Track indices we use
-  OperandMapping operand_mapping_;
-  std::map<const MMAPAllocation*, ANeuralNetworksMemory*>
-      allocation_memory_mapping_;
-  // Track memory map
-  const std::vector<StatefulNnApiDelegate::MemoryRegistration>*
-      tensor_memory_map_;
-  std::vector<int> model_state_outputs_;
-  std::vector<int> model_state_tfl_inputs_;
-
-  std::unique_ptr<NNMemory> nn_input_memory_;
-  std::unique_ptr<NNMemory> nn_output_memory_;
-
-  void AddDequantizeOperatorsWhereNeeded(const TfLiteContext* context,
-                                         int builtin_code,
-                                         const TfLiteNode* node,
-                                         NNAPIOpBuilder* builder) {
-    // Depending on the operator and the input data format, Dequantize
-    // operators may need to be added. For example when the input is
-    // floating-point but weights are quantized then the weights will first be
-    // dequantized to the same format as the input before being passed to the
-    // operator.
-
-    // The tensor determining whether the inputs should be floating-point.
-    int input_tensor_index = -1;
-    std::vector<int> inputs_to_potentially_dequantize;
-
-    switch (builtin_code) {
-      case kTfLiteBuiltinConv2d:
-      case kTfLiteBuiltinFullyConnected: {
-        input_tensor_index = 0;
-        // Weights and bias are inputs #1 and #2 respectively and may require
-        // dequantization.
-        inputs_to_potentially_dequantize = {1, 2};
-        break;
-      }
-      case kTfLiteBuiltinLstm: {
-        input_tensor_index = 0;
-        inputs_to_potentially_dequantize = {1,  2,  3,  4,  5,  6,  7,
-                                            8,  9,  10, 11, 12, 13, 14,
-                                            15, 16, 17, 20, 21, 22, 23};
-        break;
-      }
-      default:
-        return;
-    }
-
-    int tensor_id = node->inputs->data[input_tensor_index];
-    if (tensor_id < 0) return;
-
-    // Nothing to do if the input is not floating-point.
-    if (!IsFloat(context->tensors[tensor_id].type)) return;
-
-    for (int i : inputs_to_potentially_dequantize) {
-      if (i < 0 || i >= node->inputs->size) continue;  // Ignore invalid index.
-      tensor_id = node->inputs->data[i];
-      if (tensor_id < 0) continue;  // Ignore optional input.
-
-      const TfLiteType type = context->tensors[tensor_id].type;
-      // Nothing to do for this tensor if it's not quantized.
-      if (!IsQuantized(type)) continue;
-
-      // Insert Dequantize operator if it hasn't been done already and change
-      // the node's input accordingly.
-      builder->AddDequantize(i, node->inputs->data[i], type);
-    }
+  // The state_out of previous invocation need to be mapped to state_in of
+  // current invocation.
+  for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
+    int state_tensor_idx = model_state_tfl_inputs_[i];
+    TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
+    // Here we are using a deep copy for state_in tensors so that we are not
+    // reading and writing into the same buffer during a invocation.
+    // TODO(110369471): using double shared buffer to minimize the copies.
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context, nnapi_->ANeuralNetworksExecution_setOutput(
+                     execution, relative_output_index, nullptr,
+                     tensor->data.raw, tensor->bytes));
+    relative_output_index++;
+  }
+  // Invoke ANN in blocking fashion.
+  if (nnapi_->android_sdk_version < kMinSdkVersionForNNAPI12) {
+    ANeuralNetworksEvent* event = nullptr;
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context,
+        nnapi_->ANeuralNetworksExecution_startCompute(execution, &event));
+    const int wait_result = nnapi_->ANeuralNetworksEvent_wait(event);
+    nnapi_->ANeuralNetworksEvent_free(event);
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(context, wait_result);
+  } else {
+    // Use synchronous execution for NNAPI 1.2+.
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context, nnapi_->ANeuralNetworksExecution_compute(execution));
   }
 
-  TfLiteStatus AddOpsAndTensors(TfLiteContext* context) {
-    DequantizeMapping dequantize_mapping;
-    // The operand builder allows creating a single op. It is created outside
-    // the for loop to avoid reallocating the vectors.
-    NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_,
-                           &dequantize_mapping, &allocation_memory_mapping_,
-                           nn_model_.get());
-    // Add Tensors.
-    for (auto node_index : nodes_) {
-      // Obtain the op and registration.
-      TfLiteNode* node;
-      TfLiteRegistration* reg;
-      TF_LITE_ENSURE_STATUS(
-          context->GetNodeAndRegistration(context, node_index, &node, &reg));
-
-      const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
-      const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
-      const bool need_int8_conversion =
-          NeedInt8Conversion(context, reg->builtin_code, node);
-      int input_tensor_flags = 0;
-      if (scalar_as_tensor) {
-        input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
+  // copy results from shared memory to the destination.
+  output_offset = 0;
+  for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+    TfLiteTensor* tensor = &context->tensors[output_index];
+    if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
+      continue;
+    }
+    TfLiteType ann_type_equivalent =
+        operand_mapping_.lite_index_to_ann_type_conversion(output_index);
+    if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
+      // Explicitly convert uint8 values to int8 values.
+      uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
+          nn_output_memory_->get_data_ptr() + output_offset);
+      for (int i = 0; i < NumElements(tensor); ++i) {
+        output_ptr[i] =
+            static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
       }
+    }
+    memcpy(tensor->data.raw, nn_output_memory_->get_data_ptr() + output_offset,
+           tensor->bytes);
+    output_offset += tensor->bytes;
+    output_offset += getNumPaddingBytes(tensor->bytes);
+  }
 
-      // Map inputs to NN API tensor indices.
-      for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
-        const auto input_index = node->inputs->data[input_pos];
-        if (need_int8_conversion &&
-            (input_pos == 0 ||
-             reg->builtin_code == kTfLiteBuiltinFullyConnected ||
-             reg->builtin_code == kTfLiteBuiltinSub)) {
-          // Only selected inputs require int8 conversion.
-          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
-              input_index, hybrid_op,
-              input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
-          continue;
-        }
-        if (reg->builtin_code == kTfLiteBuiltinLstm && input_pos >= 20) {
+  // copy output of all output tensors in feedback_loops_ into the
+  // associated input
+  for (auto feedback_loop : feedback_loops_) {
+    int output_tensor_idx;
+    int input_tensor_idx;
+    std::tie(output_tensor_idx, input_tensor_idx) = feedback_loop;
+    TfLiteTensor* src =
+        &context->tensors[node->outputs->data[output_tensor_idx]];
+    TfLiteTensor* dest =
+        &context->tensors[node->inputs->data[input_tensor_idx]];
+
+    memcpy(dest->data.raw, src->data.raw, src->bytes);
+  }
+
+  return kTfLiteOk;
+}
+
+void NNAPIDelegateKernel::AddDequantizeOperatorsWhereNeeded(
+    const TfLiteContext* context, int builtin_code, const TfLiteNode* node,
+    NNAPIOpBuilder* builder) {
+  // Depending on the operator and the input data format, Dequantize
+  // operators may need to be added. For example when the input is
+  // floating-point but weights are quantized then the weights will first be
+  // dequantized to the same format as the input before being passed to the
+  // operator.
+
+  // The tensor determining whether the inputs should be floating-point.
+  int input_tensor_index = -1;
+  std::vector<int> inputs_to_potentially_dequantize;
+
+  switch (builtin_code) {
+    case kTfLiteBuiltinConv2d:
+    case kTfLiteBuiltinFullyConnected: {
+      input_tensor_index = 0;
+      // Weights and bias are inputs #1 and #2 respectively and may require
+      // dequantization.
+      inputs_to_potentially_dequantize = {1, 2};
+      break;
+    }
+    case kTfLiteBuiltinLstm: {
+      input_tensor_index = 0;
+      inputs_to_potentially_dequantize = {1,  2,  3,  4,  5,  6,  7,
+                                          8,  9,  10, 11, 12, 13, 14,
+                                          15, 16, 17, 20, 21, 22, 23};
+      break;
+    }
+    default:
+      return;
+  }
+
+  int tensor_id = node->inputs->data[input_tensor_index];
+  if (tensor_id < 0) return;
+
+  // Nothing to do if the input is not floating-point.
+  if (!IsFloat(context->tensors[tensor_id].type)) return;
+
+  for (int i : inputs_to_potentially_dequantize) {
+    if (i < 0 || i >= node->inputs->size) continue;  // Ignore invalid index.
+    tensor_id = node->inputs->data[i];
+    if (tensor_id < 0) continue;  // Ignore optional input.
+
+    const TfLiteType type = context->tensors[tensor_id].type;
+    // Nothing to do for this tensor if it's not quantized.
+    if (!IsQuantized(type)) continue;
+
+    // Insert Dequantize operator if it hasn't been done already and change
+    // the node's input accordingly.
+    builder->AddDequantize(i, node->inputs->data[i], type);
+  }
+}
+
+TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
+  DequantizeMapping dequantize_mapping;
+  // The operand builder allows creating a single op. It is created outside
+  // the for loop to avoid reallocating the vectors.
+  NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_,
+                         &dequantize_mapping, &allocation_memory_mapping_,
+                         nn_model_.get());
+  // Add Tensors.
+  for (auto node_index : nodes_) {
+    // Obtain the op and registration.
+    TfLiteNode* node;
+    TfLiteRegistration* reg;
+    TF_LITE_ENSURE_STATUS(
+        context->GetNodeAndRegistration(context, node_index, &node, &reg));
+
+    const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
+    const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
+    const bool need_int8_conversion =
+        NeedInt8Conversion(context, reg->builtin_code, node);
+    int input_tensor_flags = 0;
+    if (scalar_as_tensor) {
+      input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
+    }
+
+    // Map inputs to NN API tensor indices.
+    for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
+      const auto input_index = node->inputs->data[input_pos];
+      if (need_int8_conversion &&
+          (input_pos == 0 ||
+           reg->builtin_code == kTfLiteBuiltinFullyConnected ||
+           reg->builtin_code == kTfLiteBuiltinAdd ||
+           reg->builtin_code == kTfLiteBuiltinMul ||
+           reg->builtin_code == kTfLiteBuiltinSub ||
+           reg->builtin_code == kTfLiteBuiltinConcatenation ||
+           reg->builtin_code == kTfLiteBuiltinMaximum ||
+           reg->builtin_code == kTfLiteBuiltinMinimum ||
+           reg->builtin_code == kTfLiteBuiltinLess ||
+           reg->builtin_code == kTfLiteBuiltinLessEqual ||
+           reg->builtin_code == kTfLiteBuiltinGreater ||
+           reg->builtin_code == kTfLiteBuiltinGreaterEqual ||
+           reg->builtin_code == kTfLiteBuiltinEqual ||
+           reg->builtin_code == kTfLiteBuiltinNotEqual ||
+           reg->builtin_code == kTfLiteBuiltinSelect)) {
+        // Only selected inputs require int8 conversion.
+        TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
+            input_index, hybrid_op,
+            input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
+        continue;
+      }
+      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmFullKernel(node) &&
+          input_pos >= 20) {
+        // Skip layer normalization weights. They are added in the Map
+        // function (after all the other inputs added there) since layer
+        // normalization weights are the last four inputs of the LSTM op in
+        // NNAPI.
+        continue;
+      }
+      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmBasicKernel(node)) {
+        // Configuring all inputs in the Map function
+        continue;
+      }
+      if (reg->builtin_code == kTfLiteBuiltinUnidirectionalSequenceLstm) {
+        if (input_pos >= 20) {
           // Skip layer normalization weights. They are added in the Map
           // function (after all the other inputs added there) since layer
-          // normalization weights are the last four inputs of the LSTM op in
-          // NNAPI.
+          // normalization weights are the last four inputs of the
+          // unidirectional sequence LSTM op in NNAPI.
           continue;
         }
-        if (reg->builtin_code == kTfLiteBuiltinUnidirectionalSequenceLstm) {
-          if (input_pos >= 20) {
-            // Skip layer normalization weights. They are added in the Map
-            // function (after all the other inputs added there) since layer
-            // normalization weights are the last four inputs of the
-            // unidirectional sequence LSTM op in NNAPI.
-            continue;
-          }
-          if (input_index == kOptionalTensor) {
-            TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
-            continue;
-          }
-        }
-
-        if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
-            (input_index == node->inputs->data[0])) {
-          // Skip the axis input tensor; it will be added as a scalar operand
-          // by the Map() mapping.
+        if (input_index == kOptionalTensor) {
+          TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
           continue;
         }
+      }
+      if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
+          (input_index == node->inputs->data[0])) {
+        // Skip the axis input tensor; it will be added as a scalar operand
+        // by the Map() mapping.
+        continue;
+      }
+      if (reg->builtin_code == kTfLiteBuiltinTransposeConv) {
+        // Everything is added during Map since input tensors
+        // have different order.
+        continue;
+      }
 
-        // Pad and Padv2 have an optional parameter for a pad value which has
-        // to be converted to a scalar type in NN API.
-        if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
-             reg->builtin_code == kTfLiteBuiltinPad) &&
-            node->inputs->size == 3 && input_pos == 2) {
-          const int constant_value_id = node->inputs->data[2];
-          if (constant_value_id == kOptionalTensor) {
-            continue;
-          }
-          const TfLiteTensor constant_value =
-              context->tensors[constant_value_id];
+      // Pad and Padv2 have an optional parameter for a pad value which has
+      // to be converted to a scalar type in NN API.
+      if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
+           reg->builtin_code == kTfLiteBuiltinPad) &&
+          node->inputs->size == 3 && input_pos == 2) {
+        const int constant_value_id = node->inputs->data[2];
+        if (constant_value_id == kOptionalTensor) {
+          continue;
+        }
+        const TfLiteTensor constant_value = context->tensors[constant_value_id];
 
-          switch (constant_value.type) {
-            case kTfLiteFloat32:
-              if (constant_value.allocation_type == kTfLiteMmapRo) {
-                builder.AddScalarFloat32Operand(*constant_value.data.f);
+        switch (constant_value.type) {
+          case kTfLiteFloat32:
+            if (constant_value.allocation_type == kTfLiteMmapRo) {
+              builder.AddScalarFloat32Operand(*constant_value.data.f);
+            } else {
+              builder.AddSingleValueTensorAsScalarOperand(
+                  constant_value_id, ANEURALNETWORKS_FLOAT32);
+            }
+            break;
+          case kTfLiteUInt8:
+            if (constant_value.allocation_type == kTfLiteMmapRo) {
+              builder.AddScalarInt32Operand(
+                  static_cast<int32_t>(*constant_value.data.uint8));
+            } else {
+              builder.AddSingleValueTensorAsScalarOperand(
+                  constant_value_id, ANEURALNETWORKS_INT32);
+            }
+            break;
+          case kTfLiteInt8:
+            if (constant_value.allocation_type == kTfLiteMmapRo) {
+              builder.AddScalarInt32Operand(
+                  static_cast<int32_t>(*constant_value.data.int8) + 128);
+            } else {
+              builder.AddSingleValueTensorAsScalarOperand(
+                  constant_value_id, ANEURALNETWORKS_INT32);
+            }
+            break;
+          default:
+            context->ReportError(context,
+                                 "Unsupported type of pad value for pad_v2\n");
+            return kTfLiteError;
+        }
+        continue;
+      }
+
+      if (input_index == kOptionalTensor &&
+          (reg->builtin_code == kTfLiteBuiltinLstm ||
+           reg->builtin_code == kTfLiteBuiltinSvdf ||
+           reg->builtin_code == kTfLiteBuiltinBidirectionalSequenceLstm)) {
+        // properly handle the optional tensor for LSTM and SVDF.
+        // currently only support float32.
+        TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
+      } else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear ||
+                 reg->builtin_code == kTfLiteBuiltinResizeNearestNeighbor) {
+        if (input_pos == 0) {
+          // Only the first input tensor is added. The second one,
+          // specifying the output height and width, is not added and
+          // instead the height and width will be added individually as
+          // scalars by the mapping function returned by Map().
+          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op));
+        }
+      } else if (reg->builtin_code == kTfLiteBuiltinTopkV2 && input_pos > 0) {
+        // The K parameter tensor is not handled here but by the functor
+        // returned by Map, the input tensor is instead added in
+        // the else clause below
+        continue;
+      } else if (reg->builtin_code == kTfLiteBuiltinGather) {
+        // Everything is added during Map since input tensors
+        // have different order.
+        continue;
+      } else if (reg->builtin_code == kTfLiteBuiltinExpandDims &&
+                 input_pos == 1) {
+        // The axis param is added during Map
+        continue;
+      } else if (reg->builtin_code == kTfLiteBuiltinBatchToSpaceNd &&
+                 input_pos == 2) {
+        // NNAPI does not support crops.
+        // The Map fucntion will check if all crops are zero.
+        continue;
+      } else if (reg->builtin_code == kTfLiteBuiltinArgMin ||
+                 reg->builtin_code == kTfLiteBuiltinArgMax) {
+        // The first input tensor is added as is. The second one, specifying
+        // the axis, needs to be converted to a scalar since TFLite uses a
+        // tensor but NNAPI uses a scalar as the axis.
+        if (input_pos == 0) {
+          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op));
+        } else {
+          const int axis_id = node->inputs->data[1];
+          const TfLiteTensor& axis_tensor = context->tensors[axis_id];
+          switch (axis_tensor.type) {
+            case kTfLiteInt32:
+              if (axis_tensor.allocation_type == kTfLiteMmapRo) {
+                TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
+                    static_cast<int32_t>(*axis_tensor.data.i32)));
               } else {
-                builder.AddSingleValueTensorAsScalarOperand(
-                    constant_value_id, ANEURALNETWORKS_FLOAT32);
+                TF_LITE_ENSURE_STATUS(
+                    builder.AddSingleValueTensorAsScalarOperand(
+                        axis_id, ANEURALNETWORKS_INT32));
               }
               break;
-            case kTfLiteUInt8:
-              if (constant_value.allocation_type == kTfLiteMmapRo) {
-                builder.AddScalarInt32Operand(
-                    static_cast<int32_t>(*constant_value.data.uint8));
-              } else {
-                builder.AddSingleValueTensorAsScalarOperand(
-                    constant_value_id, ANEURALNETWORKS_INT32);
-              }
+            case kTfLiteInt64:
+              // Map() function already makes sure int64 input is constant.
+              TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
+                  static_cast<int32_t>(*axis_tensor.data.i64)));
               break;
             default:
-              context->ReportError(
-                  context, "Unsupported type of pad value for pad_v2\n");
               return kTfLiteError;
           }
-          continue;
         }
-
-        if (input_index == kOptionalTensor &&
-            (reg->builtin_code == kTfLiteBuiltinLstm ||
-             reg->builtin_code == kTfLiteBuiltinSvdf ||
-             reg->builtin_code == kTfLiteBuiltinBidirectionalSequenceLstm)) {
-          // properly handle the optional tensor for LSTM and SVDF.
-          // currently only support float32.
-          // TODO(miaowang): make sure this is also able to handle quantized
-          // tensor when supported by NNAPI.
-          TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
-        } else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear) {
-          if (input_pos == 0) {
-            // Only the first input tensor is added. The second one,
-            // specifying the output height and width, is not added and
-            // instead the height and width will be added individually as
-            // scalars by the mapping function returned by Map().
-            TF_LITE_ENSURE_STATUS(
-                builder.AddTensorInput(input_index, hybrid_op));
-          }
-        } else if (reg->builtin_code == kTfLiteBuiltinTopkV2 && input_pos > 0) {
-          // The K parameter tensor is not handled here but by the functor
-          // returned by Map, the input tensor is instead added in
-          // the else clause below
-          continue;
-        } else if (reg->builtin_code == kTfLiteBuiltinGather) {
-          // Everything is added during Map since input tensors
-          // have different order.
-          continue;
-        } else if (reg->builtin_code == kTfLiteBuiltinExpandDims &&
-                   input_pos == 1) {
-          // The axis param is added during Map
-          continue;
-        } else if (reg->builtin_code == kTfLiteBuiltinArgMin ||
-                   reg->builtin_code == kTfLiteBuiltinArgMax) {
-          // The first input tensor is added as is. The second one, specifying
-          // the axis, needs to be converted to a scalar since TFLite uses a
-          // tensor but NNAPI uses a scalar as the axis.
-          if (input_pos == 0) {
-            TF_LITE_ENSURE_STATUS(
-                builder.AddTensorInput(input_index, hybrid_op));
-          } else {
-            const int axis_id = node->inputs->data[1];
-            const TfLiteTensor& axis_tensor = context->tensors[axis_id];
-            switch (axis_tensor.type) {
-              case kTfLiteInt32:
-                if (axis_tensor.allocation_type == kTfLiteMmapRo) {
-                  TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
-                      static_cast<int32_t>(*axis_tensor.data.i32)));
-                } else {
-                  TF_LITE_ENSURE_STATUS(
-                      builder.AddSingleValueTensorAsScalarOperand(
-                          axis_id, ANEURALNETWORKS_INT32));
-                }
-                break;
-              case kTfLiteInt64:
-                // Map() function already makes sure int64 input is constant.
-                TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
-                    static_cast<int32_t>(*axis_tensor.data.i64)));
-                break;
-              default:
-                return kTfLiteError;
-            }
-          }
-        } else {
-          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
-                                                       input_tensor_flags));
-        }
-      }
-      // Get op type and operands
-      int nn_op_type = Map(
-          context, reg->builtin_code, reg->version, nnapi_->android_sdk_version,
-          node)({context, &builder, node, &model_state_outputs_,
-                 &model_state_tfl_inputs_});
-      // Map outputs to NN API tensor indices.
-      int output_tensor_flags = 0;
-      if (need_int8_conversion) {
-        output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
-      }
-      for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+      } else {
         TF_LITE_ENSURE_STATUS(
-            builder.AddTensorOutput(output_index, output_tensor_flags));
+            builder.AddTensorInput(input_index, hybrid_op, input_tensor_flags));
+      }
+    }
+    // Get op type and operands
+    int nn_op_type = Map(context, reg->builtin_code, reg->version,
+                         nnapi_->android_sdk_version, node,
+                         /*is_accelerator_specified=*/nnapi_device_ != nullptr)(
+        {context, &builder, node, &model_state_outputs_,
+         &model_state_tfl_inputs_, &feedback_loops_});
+    // Map outputs to NN API tensor indices.
+    int output_tensor_flags = 0;
+    if (need_int8_conversion) {
+      output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
+    }
+    for (int output_pos = 0; output_pos < node->outputs->size; ++output_pos) {
+      const auto output_index = node->outputs->data[output_pos];
+
+      // Outputs for  basic LSTM cell are set in the Map function since
+      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmBasicKernel(node)) {
+        continue;
       }
 
-      // Dequantize operators may have to be added in case inputs are to be
-      // floating-point.
-      AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
-                                        &builder);
-
-      builder.FinalizeAddOperation(nn_op_type);
+      TF_LITE_ENSURE_STATUS(
+          builder.AddTensorOutput(output_index, output_tensor_flags));
     }
-    return kTfLiteOk;
+
+    // Dequantize operators may have to be added in case inputs are to be
+    // floating-point.
+    AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
+                                      &builder);
+
+    builder.FinalizeAddOperation(nn_op_type);
   }
+  return kTfLiteOk;
+}
 
-  TfLiteStatus BuildGraph(TfLiteContext* context,
-                          const TfLiteIntArray* input_tensors,
-                          const TfLiteIntArray* output_tensors) {
-    // Build the ops and tensors.
-    TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context));
-    // Map input and output tensor indices to ANN
-    std::vector<uint32_t> inputs;
-    inputs.reserve(input_tensors->size);
-    std::vector<uint32_t> outputs;
-    outputs.reserve(output_tensors->size);
+TfLiteStatus NNAPIDelegateKernel::BuildGraph(
+    TfLiteContext* context, const TfLiteIntArray* input_tensors,
+    const TfLiteIntArray* output_tensors) {
+  // Build the ops and tensors.
+  TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context));
+  // Map input and output tensor indices to ANN
+  std::vector<uint32_t> inputs;
+  inputs.reserve(input_tensors->size);
+  std::vector<uint32_t> outputs;
+  outputs.reserve(output_tensors->size);
 
-    size_t total_input_byte_size = 0;
-    // Make the TensorFlow Lite inputs and outputs to ann_indices.
-    for (int i : TfLiteIntArrayView(input_tensors)) {
-      // Constant tensors are not NNAPI inputs.
-      if (i != kOptionalTensor &&
-          context->tensors[i].allocation_type != kTfLiteMmapRo) {
-        inputs.push_back(operand_mapping_.lite_index_to_ann(i));
-        if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
-          continue;
-        }
-        const TfLiteType nn_type_conversion =
-            operand_mapping_.lite_index_to_ann_type_conversion(i);
-        int tensor_size = 0;
-        if (nn_type_conversion == kTfLiteNoType) {
-          tensor_size = context->tensors[i].bytes;
-        } else {
-          size_t type_size;
-          TF_LITE_ENSURE_OK(
-              context, GetSizeOfType(context, nn_type_conversion, &type_size));
-          tensor_size = NumElements(&context->tensors[i]) * type_size;
-        }
-        total_input_byte_size += tensor_size;
-        total_input_byte_size += getNumPaddingBytes(tensor_size);
-      }
-    }
-
-    size_t total_output_byte_size = 0;
-    for (int i : TfLiteIntArrayView(output_tensors)) {
-      outputs.push_back(operand_mapping_.lite_index_to_ann(i));
+  size_t total_input_byte_size = 0;
+  // Make the TensorFlow Lite inputs and outputs to ann_indices.
+  for (int i : TfLiteIntArrayView(input_tensors)) {
+    // Constant tensors are not NNAPI inputs.
+    if (i != kOptionalTensor &&
+        context->tensors[i].allocation_type != kTfLiteMmapRo &&
+        // The delegate might not have mapped this input (this can
+        // happen if one tensor is split in several ones)
+        operand_mapping_.lite_index_to_ann(i) != -1) {
+      inputs.push_back(operand_mapping_.lite_index_to_ann(i));
       if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
         continue;
       }
-      total_output_byte_size += context->tensors[i].bytes;
-      total_output_byte_size += getNumPaddingBytes(context->tensors[i].bytes);
+      const TfLiteType nn_type_conversion =
+          operand_mapping_.lite_index_to_ann_type_conversion(i);
+      int tensor_size = 0;
+      if (nn_type_conversion == kTfLiteNoType) {
+        tensor_size = context->tensors[i].bytes;
+      } else {
+        size_t type_size;
+        TF_LITE_ENSURE_OK(
+            context, GetSizeOfType(context, nn_type_conversion, &type_size));
+        tensor_size = NumElements(&context->tensors[i]) * type_size;
+      }
+      total_input_byte_size += tensor_size;
+      total_input_byte_size += getNumPaddingBytes(tensor_size);
     }
-
-    // Add state output tensors as model outputs.
-    for (int i : model_state_outputs_) {
-      outputs.push_back(i);
-    }
-
-    // Tell ANN to declare inputs/outputs
-    RETURN_TFLITE_ERROR_IF_NN_ERROR(
-        context, nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs(
-                     nn_model_.get(), inputs.size(), inputs.data(),
-                     outputs.size(), outputs.data()));
-
-    // Set relaxed computation mode for fp32 if possible.
-    if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI11) {
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context,
-          nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16(
-              nn_model_.get(), context->allow_fp32_relax_to_fp16));
-    }
-
-    // Finalize the model
-    RETURN_TFLITE_ERROR_IF_NN_ERROR(
-        context, nnapi_->ANeuralNetworksModel_finish(nn_model_.get()));
-
-    // Create shared memory pool for inputs and outputs.
-    nn_input_memory_.reset(
-        new NNMemory(nnapi_, "input_pool", total_input_byte_size));
-    nn_output_memory_.reset(
-        new NNMemory(nnapi_, "output_pool", total_output_byte_size));
-
-    return kTfLiteOk;
   }
-};
 
-}  // namespace
+  size_t total_output_byte_size = 0;
+  for (int i : TfLiteIntArrayView(output_tensors)) {
+    const int output_tensor_ann_index = operand_mapping_.lite_index_to_ann(i);
+    // Unmapped outputs are not added
+    if (output_tensor_ann_index != -1) {
+      outputs.push_back(output_tensor_ann_index);
+    }
+    if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
+      continue;
+    }
+    total_output_byte_size += context->tensors[i].bytes;
+    total_output_byte_size += getNumPaddingBytes(context->tensors[i].bytes);
+  }
+
+  // Add state output tensors as model outputs.
+  for (int i : model_state_outputs_) {
+    outputs.push_back(i);
+  }
+
+  // Tell ANN to declare inputs/outputs
+  RETURN_TFLITE_ERROR_IF_NN_ERROR(
+      context, nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs(
+                   nn_model_.get(), inputs.size(), inputs.data(),
+                   outputs.size(), outputs.data()));
+
+  // Set relaxed computation mode for fp32 if possible.
+  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI11) {
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context, nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+                     nn_model_.get(), context->allow_fp32_relax_to_fp16));
+  }
+
+  // Finalize the model
+  RETURN_TFLITE_ERROR_IF_NN_ERROR(
+      context, nnapi_->ANeuralNetworksModel_finish(nn_model_.get()));
+
+  // Create shared memory pool for inputs and outputs.
+  nn_input_memory_.reset(
+      new NNMemory(nnapi_, "input_pool", total_input_byte_size));
+  nn_output_memory_.reset(
+      new NNMemory(nnapi_, "output_pool", total_output_byte_size));
+
+  return kTfLiteOk;
+}
+
+}  // namespace nnapi
+}  // namespace delegate
+
+using ::tflite::delegate::nnapi::NNAPIDelegateKernel;
 
 StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
     : TfLiteDelegate(TfLiteDelegateCreate()),
@@ -2761,6 +3319,9 @@
   }
 }
 
+using ::tflite::delegate::nnapi::kMinSdkVersionForNNAPI;
+using ::tflite::delegate::nnapi::kMinSdkVersionForNNAPI12;
+
 TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
                                               TfLiteDelegate* delegate) {
   // Do not check nodes_ if NN API is unavailable.
@@ -2769,18 +3330,20 @@
       !nnapi->nnapi_exists) {
     return kTfLiteOk;
   }
+  bool is_accelerator_specified = false;
   // For NNAPI 1.2+, check if there is any accelerator available.
   // If not, don't delegate to NNAPI's CPU reference implementation.
   if (nnapi->android_sdk_version >= kMinSdkVersionForNNAPI12) {
     // Check if user specified an acclelerator to use.
     const char* device_name_ptr = GetOptions(delegate).accelerator_name;
     if (device_name_ptr) {
-      if (!GetDeviceHandle(device_name_ptr)) {
+      if (!GetDeviceHandle(context, device_name_ptr)) {
         // If the selected accelerator cannot be found, NNAPI will not be used.
-        context->ReportError(context,
-                             "Could not find the specified accelerator: %s.",
-                             device_name_ptr);
         return kTfLiteOk;
+      } else {
+        // also check if the selected device is not CPU reference impl.
+        const string kNnapiReferenceImplName = "nnapi-reference";
+        is_accelerator_specified = kNnapiReferenceImplName != device_name_ptr;
       }
     } else {
       // If no accelerator is specified, only use NNAPI if an accelerator is
@@ -2805,7 +3368,6 @@
 
   int android_sdk_version = NnApiImplementation()->android_sdk_version;
   // Check for every node if it is supported
-  // TODO(b/80625235): Fix this to do more careful checking of versioning.
   for (int node_index : TfLiteIntArrayView(plan)) {
     TfLiteNode* node;
     TfLiteRegistration* registration;
@@ -2813,7 +3375,7 @@
         context, node_index, &node, &registration));
     if (NNAPIDelegateKernel::Map(context, registration->builtin_code,
                                  registration->version, android_sdk_version,
-                                 node)) {
+                                 node, is_accelerator_specified)) {
       supported_nodes.push_back(node_index);
     }
   }
@@ -2842,9 +3404,9 @@
       },
 
       .prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
-        // Since the underlying resize happened ahead of delegation
-        // worked. This does nothing.
-        return kTfLiteOk;
+        NNAPIDelegateKernel* state =
+            reinterpret_cast<NNAPIDelegateKernel*>(node->user_data);
+        return state->Prepare(context, node);
       },
 
       .invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
new file mode 100644
index 0000000..3a65c3d
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
@@ -0,0 +1,243 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_KERNEL_H_
+#define TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_KERNEL_H_
+
+#include <map>
+#include <memory>
+
+#include "tensorflow/lite/allocation.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace delegate {
+namespace nnapi {
+
+constexpr int32_t kMinSdkVersionForNNAPI = 27;
+constexpr int32_t kMinSdkVersionForNNAPI11 = 28;
+constexpr int32_t kMinSdkVersionForNNAPI12 = 29;
+
+// Track tensor indices to NN API tensor indices mapping.
+class OperandMapping {
+ public:
+  // Given a TFLite index return the ANN index. If it doesn't exist
+  // return -1.
+  int lite_index_to_ann(int index) const {
+    if (index >= 0 && index < lite_tensor_to_ann_tensor_.size())
+      return lite_tensor_to_ann_tensor_[index];
+    else
+      return -1;
+  }
+
+  // NN API uses non tensor operands instead of structs. This creates one
+  // and returns the index. It uses a std::vector and resizes it as needed
+  // keeping -1 to unmapped values. Intermediate tensors likely will not
+  // be mapped.
+  int add_new_non_tensor_operand() { return next_ann_tensor_index_++; }
+
+  // This call is necessary for input operands generated by the delegate
+  // to map constant inputs not present in TFLite but required by NNAPI,
+  // for example when splitting one input in several ones.
+  int add_delegate_generated_input_ann_tensors_operand() {
+    return next_ann_tensor_index_++;
+  }
+
+  // Add a new mapping from `tflite_index` and return the NN API tensor index.
+  int add_new_ann_tensor_index(int tflite_index) {
+    if (tflite_index >= lite_tensor_to_ann_tensor_.size()) {
+      lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1);
+    }
+    const int new_tensor_index = next_ann_tensor_index_++;
+    lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index;
+    return new_tensor_index;
+  }
+
+  // Given a TFLite index returns a TFLite type to which a tensor must be
+  // converted during copying the data to the memory allocated for NN API.
+  // kTfLiteNoType means no conversion is needed.
+  TfLiteType lite_index_to_ann_type_conversion(int index) const {
+    if (index >= 0 && index < index_to_type_conversion_.size())
+      return index_to_type_conversion_[index];
+    else
+      return kTfLiteNoType;
+  }
+
+  // Add a new mapping from TFLite index to a type conversion.
+  void add_type_conversion(int tflite_index, TfLiteType tflite_type) {
+    if (tflite_index >= index_to_type_conversion_.size()) {
+      index_to_type_conversion_.resize(tflite_index + 1, kTfLiteNoType);
+    }
+    index_to_type_conversion_[tflite_index] = tflite_type;
+  }
+
+ private:
+  // Next index of ann tensor
+  int next_ann_tensor_index_ = 0;
+
+  // Mapping from lite index. Use a std::vector for speed and code size
+  // rather than a map.
+  std::vector<int> lite_tensor_to_ann_tensor_;
+  // Mapping from lite index to a type which tensor must be converted to during
+  // the copying of the data to the memory allocated for NN API. kTfLiteNoType
+  // means no conversion is needed. Use an std::vector for speed and code size
+  // rather than a map.
+  std::vector<TfLiteType> index_to_type_conversion_;
+};
+
+class NNAPIOpBuilder;
+
+// The kernel that represents the node sub set of TF Lite being run on NN API.
+struct NNAPIOpMappingArgs {
+  TfLiteContext* context;
+  NNAPIOpBuilder* builder;
+  TfLiteNode* node;
+  std::vector<int>* model_state_outputs;
+  std::vector<int>* model_state_tfl_inputs;
+  std::vector<std::tuple<int, int>>* feedback_loops;
+};
+
+// RAII NN API Model Destructor for use with std::unique_ptr
+struct NNFreeModel {
+  void operator()(ANeuralNetworksModel* model) {
+    NnApiImplementation()->ANeuralNetworksModel_free(model);
+  }
+};
+// RAII NN API Compilation Destructor for use with std::unique_ptr
+struct NNFreeCompilation {
+  void operator()(ANeuralNetworksCompilation* model) {
+    NnApiImplementation()->ANeuralNetworksCompilation_free(model);
+  }
+};
+
+// Manage NNAPI shared memory handle
+class NNMemory {
+ public:
+#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
+  NNMemory(const NnApi* nnapi, const char* name, size_t size) {
+    if (name && size > 0) {
+      nnapi_ = nnapi;
+      byte_size_ = size;
+      fd_ = nnapi_->ASharedMemory_create(name, size);
+      data_ptr_ = reinterpret_cast<uint8_t*>(
+          mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0));
+      nnapi_->ANeuralNetworksMemory_createFromFd(size, PROT_READ | PROT_WRITE,
+                                                 fd_, 0, &nn_memory_handle_);
+    }
+  }
+#else
+  NNMemory(const NnApi* /*nnapi*/, const char* /*name*/, size_t /*size*/) {}
+#endif
+
+  ~NNMemory() {
+#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
+    if (data_ptr_) {
+      munmap(data_ptr_, byte_size_);
+    }
+    if (nn_memory_handle_) {
+      nnapi_->ANeuralNetworksMemory_free(nn_memory_handle_);
+    }
+    if (fd_ > 0) close(fd_);
+#endif
+  }
+
+  ANeuralNetworksMemory* get_handle() { return nn_memory_handle_; }
+  uint8_t* get_data_ptr() { return data_ptr_; }
+
+ private:
+#ifdef TFLITE_NNAPI_ALLOW_MMAP_SHARING
+  const NnApi* nnapi_;
+  int fd_ = 0;
+  size_t byte_size_ = 0;
+#endif
+  uint8_t* data_ptr_ = nullptr;
+  ANeuralNetworksMemory* nn_memory_handle_ = nullptr;
+};
+
+// The kernel that represents the node sub set of TF Lite being run on NN API.
+class NNAPIDelegateKernel {
+ public:
+  NNAPIDelegateKernel() { nnapi_ = NnApiImplementation(); }
+  ~NNAPIDelegateKernel() {
+    for (auto content : allocation_memory_mapping_) {
+      nnapi_->ANeuralNetworksMemory_free(content.second);
+    }
+  }
+
+  typedef ANeuralNetworksOperationType (*MappingFn)(
+      const NNAPIOpMappingArgs& mapping_args);
+
+  // Return a function that knows how to translate a node into its operands
+  // when called. You can use this function to see if a node is supported
+  // (i.e. if the returned MappingFn is null, then the node is not supported).
+  static MappingFn Map(const TfLiteContext* context, int builtin_code,
+                       int version, int android_sdk_version,
+                       const TfLiteNode* node, bool is_accelerator_specified);
+
+  // Initialize the kernel (a NN model).
+  TfLiteStatus Init(TfLiteContext* context, const TfLiteDelegateParams* params);
+
+  TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node);
+
+  TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node);
+
+ private:
+  // Access to NNApi.
+  const NnApi* nnapi_;
+  // ANN device handle.
+  ANeuralNetworksDevice* nnapi_device_ = nullptr;
+  // ANN API state.
+  std::unique_ptr<ANeuralNetworksModel, NNFreeModel> nn_model_;
+  std::unique_ptr<ANeuralNetworksCompilation, NNFreeCompilation>
+      nn_compilation_;
+  // Node indices that this delegate is responsible for. Indices here
+  // indexes into the nodes array in the TfLiteContext.
+  std::vector<int> nodes_;
+  // Track indices we use
+  OperandMapping operand_mapping_;
+  std::map<const MMAPAllocation*, ANeuralNetworksMemory*>
+      allocation_memory_mapping_;
+  // Track memory map
+  const std::vector<StatefulNnApiDelegate::MemoryRegistration>*
+      tensor_memory_map_;
+  std::vector<int> model_state_outputs_;
+  std::vector<int> model_state_tfl_inputs_;
+  // This is the equivalent of the pair model_state_outputs_,
+  // model_state_tfl_inputs_ for all tensors where we have to keep the output
+  // data available for TFLite model users
+  std::vector<std::tuple<int, int>> feedback_loops_;
+
+  std::unique_ptr<NNMemory> nn_input_memory_;
+  std::unique_ptr<NNMemory> nn_output_memory_;
+
+  void AddDequantizeOperatorsWhereNeeded(const TfLiteContext* context,
+                                         int builtin_code,
+                                         const TfLiteNode* node,
+                                         NNAPIOpBuilder* builder);
+
+  TfLiteStatus AddOpsAndTensors(TfLiteContext* context);
+
+  TfLiteStatus BuildGraph(TfLiteContext* context,
+                          const TfLiteIntArray* input_tensors,
+                          const TfLiteIntArray* output_tensors);
+};
+
+}  // namespace nnapi
+}  // namespace delegate
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_KERNEL_H_
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
index dbbe212..b1b1dcd 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -259,6 +259,32 @@
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
 }
 
+// Sanity check for the state-ful NNAPI delegate with invalid accelerator_name
+// specified.
+TEST(NNAPIDelegate, StatefulDelegateWithInvalidAcceleratorName) {
+  if (!NnApiImplementation()->ANeuralNetworksDevice_getName) {
+    GTEST_SKIP();
+  }
+  testing::internal::CaptureStderr();
+  StatefulNnApiDelegate::Options options;
+  options.execution_preference =
+      StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
+  options.accelerator_name = "foo";
+
+  FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  EXPECT_THAT(testing::internal::GetCapturedStderr(),
+              testing::HasSubstr(
+                  "Could not find the specified NNAPI accelerator: foo"));
+
+  // Execution should fall back to the default CPU path.
+  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
 // Sanity check for the state-ful NNAPI delegate with compilation caching
 // enabled.
 TEST(NNAPIDelegate, StatefulDelegateWithCompilationCaching) {
diff --git a/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc
new file mode 100644
index 0000000..bcf2ff6
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc
@@ -0,0 +1,152 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
+
+#include <algorithm>
+
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace delegate {
+namespace nnapi {
+
+// The function extracts a submatrix of the weights at a given row
+// and column offsets from  a 2D matrix
+void ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray* submatrix_dims,
+                                      const int32_t offset_row,
+                                      const int32_t offset_column,
+                                      const TfLiteIntArray* weight_dims,
+                                      const uint8_t* weights,
+                                      std::vector<uint8_t>* submatrix) {
+  auto const& submatrix_rows = submatrix_dims->data[0];
+  auto const& submatrix_cols = submatrix_dims->data[1];
+  auto const& weight_cols = weight_dims->data[1];
+
+  submatrix->resize(NumElements(submatrix_dims));
+
+  for (uint32_t i = 0; i < submatrix_rows * submatrix_cols; ++i) {
+    const uint32_t row = i / submatrix_cols;
+    const uint32_t column = i % submatrix_cols;
+    (*submatrix)[i] =
+        weights[(row + offset_row) * weight_cols + column + offset_column];
+  }
+}
+
+inline int OutputDepth(const TfLiteIntArray* weight_dims) {
+  return weight_dims->data[0] / 4;
+}
+
+inline int InputDepth(const TfLiteIntArray* weight_dims) {
+  return weight_dims->data[1] - OutputDepth(weight_dims);
+}
+
+void SetWeightSubmatrixDims(const TfLiteIntArray* weight_dims,
+                            TfLiteIntArray* recurrent_submatrix_dims,
+                            TfLiteIntArray* input_submatrix_dims) {
+  const auto input_depth = InputDepth(weight_dims);
+  const auto output_depth = OutputDepth(weight_dims);
+
+  recurrent_submatrix_dims->data[0] = output_depth;
+  recurrent_submatrix_dims->data[1] = output_depth;
+
+  input_submatrix_dims->data[0] = output_depth;
+  input_submatrix_dims->data[1] = input_depth;
+}
+
+// Doing exactly the opposite work of QuantizedLSTMCell::concatenateWeights
+// in NNAPI, decomposing the concat_weights tensor data into its 8 components
+// according to the following diagram
+//
+// +-----------------------------------+
+// | recurrentToInput  | inputToInput  |
+// |-------------------+---------------|
+// | recurrentToCell   | inputToCell   |
+// |-------------------+---------------|
+// | recurrentToForget | inputToForget |
+// |-------------------+---------------|
+// | recurrentToOutput | inputToOutput |
+// +-----------------------------------+
+void DecomposeQuantLstmWeightsTensor(const uint8_t* concat_weights,
+                                     const TfLiteIntArray* weight_dims,
+                                     std::vector<uint8_t>* recurrent_to_input,
+                                     std::vector<uint8_t>* input_to_input,
+                                     std::vector<uint8_t>* recurrent_to_cell,
+                                     std::vector<uint8_t>* input_to_cell,
+                                     std::vector<uint8_t>* recurrent_to_forget,
+                                     std::vector<uint8_t>* input_to_forget,
+                                     std::vector<uint8_t>* recurrent_to_output,
+                                     std::vector<uint8_t>* input_to_output) {
+  const auto output_depth = OutputDepth(weight_dims);
+
+  TfLiteIntArray* recurrent_submatrix_dims = TfLiteIntArrayCreate(2);
+  TfLiteIntArray* input_submatrix_dims = TfLiteIntArrayCreate(2);
+  SetWeightSubmatrixDims(weight_dims, recurrent_submatrix_dims,
+                         input_submatrix_dims);
+
+  ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 0 * output_depth,
+                                   0, weight_dims, concat_weights,
+                                   recurrent_to_input);
+  ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 0 * output_depth,
+                                   output_depth, weight_dims, concat_weights,
+                                   input_to_input);
+
+  ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 1 * output_depth,
+                                   0, weight_dims, concat_weights,
+                                   recurrent_to_cell);
+  ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 1 * output_depth,
+                                   output_depth, weight_dims, concat_weights,
+                                   input_to_cell);
+
+  ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 2 * output_depth,
+                                   0, weight_dims, concat_weights,
+                                   recurrent_to_forget);
+  ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 2 * output_depth,
+                                   output_depth, weight_dims, concat_weights,
+                                   input_to_forget);
+
+  ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 3 * output_depth,
+                                   0, weight_dims, concat_weights,
+                                   recurrent_to_output);
+  ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 3 * output_depth,
+                                   output_depth, weight_dims, concat_weights,
+                                   input_to_output);
+
+  TfLiteIntArrayFree(recurrent_submatrix_dims);
+  TfLiteIntArrayFree(input_submatrix_dims);
+}
+
+void DecomposeBiasTensor(const int32_t* biases, int bias_size,
+                         std::vector<int32_t>* input_bias,
+                         std::vector<int32_t>* cell_bias,
+                         std::vector<int32_t>* forget_bias,
+                         std::vector<int32_t>* output_bias) {
+  input_bias->resize(bias_size);
+  std::copy(biases, biases + bias_size, input_bias->begin());
+
+  cell_bias->resize(bias_size);
+  std::copy(biases + bias_size, biases + 2 * bias_size, cell_bias->begin());
+
+  forget_bias->resize(bias_size);
+  std::copy(biases + 2 * bias_size, biases + 3 * bias_size,
+            forget_bias->begin());
+
+  output_bias->resize(bias_size);
+  std::copy(biases + 3 * bias_size, biases + 4 * bias_size,
+            output_bias->begin());
+}
+
+}  // namespace nnapi
+}  // namespace delegate
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/nnapi/quant_lstm_sup.h b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.h
new file mode 100644
index 0000000..1385b92
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.h
@@ -0,0 +1,58 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_
+#define TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_
+
+#include <vector>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+
+namespace tflite {
+namespace delegate {
+namespace nnapi {
+
+void ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray* submatrix_dims,
+                                      const int32_t offset_row,
+                                      const int32_t offset_column,
+                                      const TfLiteIntArray* weight_dims,
+                                      const uint8_t* weights,
+                                      std::vector<uint8_t>* submatrix);
+
+void DecomposeQuantLstmWeightsTensor(const uint8_t* concat_weights,
+                                     const TfLiteIntArray* weight_dims,
+                                     std::vector<uint8_t>* recurrent_to_input,
+                                     std::vector<uint8_t>* input_to_input,
+                                     std::vector<uint8_t>* recurrent_to_cell,
+                                     std::vector<uint8_t>* input_to_cell,
+                                     std::vector<uint8_t>* recurrent_to_forget,
+                                     std::vector<uint8_t>* input_to_forget,
+                                     std::vector<uint8_t>* recurrent_to_output,
+                                     std::vector<uint8_t>* input_to_output);
+
+void SetWeightSubmatrixDims(const TfLiteIntArray* weight_dims,
+                            TfLiteIntArray* recurrent_submatrix_dims,
+                            TfLiteIntArray* input_submatrix_dims);
+
+void DecomposeBiasTensor(const int32_t* biases, int bias_size,
+                         std::vector<int32_t>* input_bias,
+                         std::vector<int32_t>* cell_bias,
+                         std::vector<int32_t>* forget_bias,
+                         std::vector<int32_t>* output_bias);
+
+}  // namespace nnapi
+}  // namespace delegate
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_
diff --git a/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc b/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc
new file mode 100644
index 0000000..2bbf52c
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc
@@ -0,0 +1,344 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
+
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace {
+
+using ::testing::ElementsAreArray;
+using ::testing::Test;
+
+class DimsAllocatingTest : public Test {
+ protected:
+  DimsAllocatingTest() : allocated_dims_() {}
+
+  ~DimsAllocatingTest() override {
+    for (TfLiteIntArray* dim : allocated_dims_) {
+      TfLiteIntArrayFree(dim);
+    }
+  }
+
+  TfLiteIntArray* CreateDimArray(int size,
+                                 std::initializer_list<int> dimensions) {
+    TfLiteIntArray* dims = TfLiteIntArrayCreate(size);
+    allocated_dims_.push_back(dims);
+
+    int i = 0;
+    for (const int dimension : dimensions) {
+      dims->data[i++] = dimension;
+    }
+
+    return dims;
+  }
+
+ private:
+  std::vector<TfLiteIntArray*> allocated_dims_;
+};
+
+using tflite::delegate::nnapi::ExtractQuantLstmWeightsSubmatrix;
+
+class ExtractQuantLstmWeightsSubmatrixTest : public DimsAllocatingTest {};
+
+TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopLeftSubmatrixIsExtracted) {
+  std::vector<uint8_t> weights = {1,   2,   3,   4,   5,    //
+                                  11,  12,  13,  14,  15,   //
+                                  101, 102, 103, 104, 105,  //
+                                  111, 112, 113, 114, 115,  //
+                                  201, 202, 203, 204, 205,  //
+                                  211, 212, 213, 214, 215,  //
+                                  221, 222, 223, 224, 225,  //
+                                  231, 232, 233, 234, 235};
+  const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
+
+  std::vector<uint8_t> submatrix;
+  const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 3});
+
+  ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
+                                   0 /* offset_column */, weight_dims,
+                                   weights.data(), &submatrix);
+
+  EXPECT_THAT(submatrix, ElementsAreArray({1, 2, 3, 11, 12, 13}));
+}
+
+TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopRightSubmatrixIsExtracted) {
+  std::vector<uint8_t> weights = {1,   2,   3,   4,   5,    //
+                                  11,  12,  13,  14,  15,   //
+                                  101, 102, 103, 104, 105,  //
+                                  111, 112, 113, 114, 115,  //
+                                  201, 202, 203, 204, 205,  //
+                                  211, 212, 213, 214, 215,  //
+                                  221, 222, 223, 224, 225,  //
+                                  231, 232, 233, 234, 235};
+  const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
+
+  std::vector<uint8_t> submatrix;
+  const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});
+
+  ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
+                                   3 /* offset_column */, weight_dims,
+                                   weights.data(), &submatrix);
+
+  EXPECT_THAT(submatrix, ElementsAreArray({4, 5, 14, 15}));
+}
+
+TEST_F(ExtractQuantLstmWeightsSubmatrixTest, RightCentralSubmatrixIsExtracted) {
+  std::vector<uint8_t> weights = {1,   2,   3,   4,   5,    //
+                                  11,  12,  13,  14,  15,   //
+                                  101, 102, 103, 104, 105,  //
+                                  111, 112, 113, 114, 115,  //
+                                  201, 202, 203, 204, 205,  //
+                                  211, 212, 213, 214, 215,  //
+                                  221, 222, 223, 224, 225,  //
+                                  231, 232, 233, 234, 235};
+  const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});
+
+  std::vector<uint8_t> submatrix;
+  const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});
+
+  ExtractQuantLstmWeightsSubmatrix(
+      submatrix_dims, 1 * submatrix_dims->data[0] /* offset_row */,
+      3 /* offset_column */, weight_dims, weights.data(), &submatrix);
+
+  EXPECT_THAT(submatrix, ElementsAreArray({104, 105, 114, 115}));
+}
+
+using tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor;
+
+class QuantLstmWeightDecompTest : public DimsAllocatingTest {
+ protected:
+  QuantLstmWeightDecompTest()
+      : weights_({1,   2,   3,   4,   5,    //
+                  11,  12,  13,  14,  15,   //
+                  101, 102, 103, 104, 105,  //
+                  111, 112, 113, 114, 115,  //
+                  201, 202, 203, 204, 205,  //
+                  211, 212, 213, 214, 215,  //
+                  221, 222, 223, 224, 225,  //
+                  231, 232, 233, 234, 235}),
+        // Creating the arrays empty, the size is set by the decomposition
+        // function
+        recurrent_to_input_(),
+        input_to_input_(),
+        recurrent_to_cell_(),
+        input_to_cell_(),
+        recurrent_to_forget_(),
+        input_to_forget_(),
+        recurrent_to_output_(),
+        input_to_output_() {
+    weight_dims_ = CreateDimArray(2, {8, 5});
+  }
+
+  const std::vector<uint8_t> weights_;
+  const TfLiteIntArray* weight_dims_;
+  std::vector<uint8_t> recurrent_to_input_;
+  std::vector<uint8_t> input_to_input_;
+  std::vector<uint8_t> recurrent_to_cell_;
+  std::vector<uint8_t> input_to_cell_;
+  std::vector<uint8_t> recurrent_to_forget_;
+  std::vector<uint8_t> input_to_forget_;
+  std::vector<uint8_t> recurrent_to_output_;
+  std::vector<uint8_t> input_to_output_;
+};
+
+TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToInput) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(recurrent_to_input_, ElementsAreArray({1, 2,  //
+                                                     11, 12}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractInputToInput) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(input_to_input_, ElementsAreArray({3, 4, 5,  //
+                                                 13, 14, 15}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToCell) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(recurrent_to_cell_, ElementsAreArray({101, 102,  //
+                                                    111, 112}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractInputToCell) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(input_to_cell_, ElementsAreArray({103, 104, 105,  //
+                                                113, 114, 115}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToForget) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(recurrent_to_forget_, ElementsAreArray({201, 202,  //
+                                                      211, 212}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractInputToForget) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(input_to_forget_, ElementsAreArray({203, 204, 205,  //
+                                                  213, 214, 215}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToOutput) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(recurrent_to_output_, ElementsAreArray({221, 222,  //
+                                                      231, 232}));
+}
+
+TEST_F(QuantLstmWeightDecompTest, ExtractInputToOutput) {
+  DecomposeQuantLstmWeightsTensor(
+      weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
+      &recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
+      &input_to_forget_, &recurrent_to_output_, &input_to_output_);
+
+  EXPECT_THAT(input_to_output_, ElementsAreArray({223, 224, 225,  //
+                                                  233, 234, 235}));
+}
+
+using tflite::delegate::nnapi::DecomposeBiasTensor;
+
+TEST(DecomposeBiasTensor, ExtractInputBias) {
+  // clang-format off
+  std::vector<int32_t> biases
+      // inputGateBias
+      {-7876, 13488, -726, 32839,
+      // cellGateBias
+      39481, 48624, 48976, -21419,
+      // forgetGateBias
+      9206, -46884, -11693, -38724,
+      // outputGateBias
+      -58999, -17050, -41852, -40538};
+  // clang-format on
+
+  std::vector<int32_t> input_bias;
+  std::vector<int32_t> cell_bias;
+  std::vector<int32_t> forget_bias;
+  std::vector<int32_t> output_bias;
+  DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
+                      &output_bias);
+
+  EXPECT_THAT(input_bias, ElementsAreArray({-7876, 13488, -726, 32839}));
+}
+
+TEST(DecomposeBiasTensor, ExtractCellBias) {
+  // clang-format off
+  std::vector<int32_t> biases
+      // inputGateBias
+      {-7876, 13488, -726, 32839,
+      // cellGateBias
+      39481, 48624, 48976, -21419,
+      // forgetGateBias
+      9206, -46884, -11693, -38724,
+      // outputGateBias
+      -58999, -17050, -41852, -40538};
+  // clang-format on
+
+  std::vector<int32_t> input_bias;
+  std::vector<int32_t> cell_bias;
+  std::vector<int32_t> forget_bias;
+  std::vector<int32_t> output_bias;
+  DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
+                      &output_bias);
+
+  EXPECT_THAT(cell_bias, ElementsAreArray({39481, 48624, 48976, -21419}));
+}
+
+TEST(DecomposeBiasTensor, ExtractForgetBias) {
+  // clang-format off
+  std::vector<int32_t> biases
+      // inputGateBias
+      {-7876, 13488, -726, 32839,
+      // cellGateBias
+      39481, 48624, 48976, -21419,
+      // forgetGateBias
+      9206, -46884, -11693, -38724,
+      // outputGateBias
+      -58999, -17050, -41852, -40538};
+  // clang-format on
+
+  std::vector<int32_t> input_bias;
+  std::vector<int32_t> cell_bias;
+  std::vector<int32_t> forget_bias;
+  std::vector<int32_t> output_bias;
+  DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
+                      &output_bias);
+
+  EXPECT_THAT(forget_bias, ElementsAreArray({9206, -46884, -11693, -38724}));
+}
+
+TEST(DecomposeBiasTensor, ExtractOutputBias) {
+  // clang-format off
+  std::vector<int32_t> biases
+      // inputGateBias
+      {-7876, 13488, -726, 32839,
+      // cellGateBias
+      39481, 48624, 48976, -21419,
+      // forgetGateBias
+      9206, -46884, -11693, -38724,
+      // outputGateBias
+      -58999, -17050, -41852, -40538};
+  // clang-format on
+
+  std::vector<int32_t> input_bias;
+  std::vector<int32_t> cell_bias;
+  std::vector<int32_t> forget_bias;
+  std::vector<int32_t> output_bias;
+  DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
+                      &output_bias);
+
+  EXPECT_THAT(output_bias, ElementsAreArray({-58999, -17050, -41852, -40538}));
+}
+
+}  // namespace
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm
index c6a38e7..1b9792c 100644
--- a/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm
+++ b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm
@@ -386,7 +386,7 @@
 - (void)dealloc {
 #if TFLITE_USE_GPU_DELEGATE
   if (delegate) {
-    DeleteGpuDelegate(delegate);
+    TFLGpuDelegateDelete(delegate);
   }
 #endif
   [self teardownAVCapture];
@@ -418,7 +418,7 @@
   GpuDelegateOptions options;
   options.allow_precision_loss = true;
   options.wait_type = GpuDelegateOptions::WaitType::kActive;
-  delegate = NewGpuDelegate(&options);
+  delegate = TFLGpuDelegateCreate(&options);
   interpreter->ModifyGraphWithDelegate(delegate);
 #endif
 
diff --git a/tensorflow/lite/examples/ios/download_models.sh b/tensorflow/lite/examples/ios/download_models.sh
index a450aba..68a9c96 100755
--- a/tensorflow/lite/examples/ios/download_models.sh
+++ b/tensorflow/lite/examples/ios/download_models.sh
@@ -17,8 +17,8 @@
 set -ex
 
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-FLOAT_MODEL_URL="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
-QUANTIZED_MODEL_URL="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
+FLOAT_MODEL_URL="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
+QUANTIZED_MODEL_URL="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
 DOWNLOADS_DIR=$(mktemp -d)
 
 cd "$SCRIPT_DIR"
diff --git a/tensorflow/lite/examples/python/README.md b/tensorflow/lite/examples/python/README.md
new file mode 100644
index 0000000..ddfedb2
--- /dev/null
+++ b/tensorflow/lite/examples/python/README.md
@@ -0,0 +1,47 @@
+# TensorFlow Lite Python image classification demo
+
+This `label_image.py` script shows how you can load a pre-trained and converted
+TensorFlow Lite model and use it to recognize objects in images. The Python
+script accepts arguments specifying the model to use, the corresponding labels
+file, and the image to process.
+
+Before you begin,
+make sure you [have TensorFlow installed](https://www.tensorflow.org/install).
+
+
+## Download sample model and image
+
+You can use any compatible model, but the following MobileNet v1 model offers
+a good demonstration of a model trained to recognize 1,000 different objects.
+
+```
+# Get photo
+curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp
+# Get model
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp
+# Get labels
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz  | tar xzv -C /tmp  mobilenet_v1_1.0_224/labels.txt
+
+mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
+```
+
+## Run the sample
+
+Note: Instead use `python` if you're using Python 2.x.
+
+```
+python3 label_image.py \
+  --model_file /tmp/mobilenet_v1_1.0_224.tflite \
+  --label_file /tmp/labels.txt \
+  --image /tmp/grace_hopper.bmp
+```
+
+You should see results like this:
+
+```
+0.728693: military uniform
+0.116163: Windsor tie
+0.035517: bow tie
+0.014874: mortarboard
+0.011758: bolo tie
+```
diff --git a/tensorflow/lite/examples/python/label_image.md b/tensorflow/lite/examples/python/label_image.md
deleted file mode 100644
index b4ec42f..0000000
--- a/tensorflow/lite/examples/python/label_image.md
+++ /dev/null
@@ -1,50 +0,0 @@
-
-With model, input image (grace_hopper.bmp), and labels file (labels.txt)
-in /tmp.
-
-The example input image and labels file are from TensorFlow repo and
-MobileNet V1 model files.
-
-```
-curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp
-
-curl  https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz  | tar xzv -C /tmp  mobilenet_v1_1.0_224/labels.txt
-mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
-
-```
-
-Run
-
-```
-curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C /tmp
-bazel run --config opt //tensorflow/lite/examples/python:label_image
-```
-
-We can get results like
-
-```
-0.470588: military uniform
-0.337255: Windsor tie
-0.047059: bow tie
-0.031373: mortarboard
-0.019608: suit
-```
-
-Run
-
-```
-curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp
-bazel run --config opt //tensorflow/lite/examples/python:label_image \
--- --model_file /tmp/mobilenet_v1_1.0_224.tflite
-```
-
-We can get results like
-```
-0.728693: military uniform
-0.116163: Windsor tie
-0.035517: bow tie
-0.014874: mortarboard
-0.011758: bolo tie
-```
-
-Check [models](../../g3doc/models.md) for models hosted by Google.
diff --git a/tensorflow/lite/examples/python/label_image.py b/tensorflow/lite/examples/python/label_image.py
index 0bc15d3..e9eaa98 100644
--- a/tensorflow/lite/examples/python/label_image.py
+++ b/tensorflow/lite/examples/python/label_image.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""label_image for tflite"""
+"""label_image for tflite."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,46 +23,49 @@
 
 from PIL import Image
 
-from tensorflow.lite.python import interpreter as interpreter_wrapper
+from tensorflow.lite.python.interpreter import Interpreter
+
 
 def load_labels(filename):
-  my_labels = []
-  input_file = open(filename, 'r')
-  for l in input_file:
-    my_labels.append(l.strip())
-  return my_labels
+  with open(filename, 'r') as f:
+    return [line.strip() for line in f.readlines()]
 
-if __name__ == "__main__":
-  floating_model = False
 
+if __name__ == '__main__':
   parser = argparse.ArgumentParser()
-  parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \
-    help="image to be classified")
-  parser.add_argument("-m", "--model_file", \
-    default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \
-    help=".tflite model to be executed")
-  parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \
-    help="name of file containing labels")
-  parser.add_argument("--input_mean", default=127.5, help="input_mean")
-  parser.add_argument("--input_std", default=127.5, \
-    help="input standard deviation")
+  parser.add_argument(
+      '-i',
+      '--image',
+      default='/tmp/grace_hopper.bmp',
+      help='image to be classified')
+  parser.add_argument(
+      '-m',
+      '--model_file',
+      default='/tmp/mobilenet_v1_1.0_224_quant.tflite',
+      help='.tflite model to be executed')
+  parser.add_argument(
+      '-l',
+      '--label_file',
+      default='/tmp/labels.txt',
+      help='name of file containing labels')
+  parser.add_argument('--input_mean', default=127.5, help='input_mean')
+  parser.add_argument(
+      '--input_std', default=127.5, help='input standard deviation')
   args = parser.parse_args()
 
-  interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
+  interpreter = Interpreter(model_path=args.model_file)
   interpreter.allocate_tensors()
 
   input_details = interpreter.get_input_details()
   output_details = interpreter.get_output_details()
 
   # check the type of the input tensor
-  if input_details[0]['dtype'] == np.float32:
-    floating_model = True
+  floating_model = input_details[0]['dtype'] == np.float32
 
   # NxHxWxC, H:1, W:2
   height = input_details[0]['shape'][1]
   width = input_details[0]['shape'][2]
-  img = Image.open(args.image)
-  img = img.resize((width, height))
+  img = Image.open(args.image).resize((width, height))
 
   # add N dim
   input_data = np.expand_dims(img, axis=0)
@@ -81,6 +84,6 @@
   labels = load_labels(args.label_file)
   for i in top_k:
     if floating_model:
-      print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
+      print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
     else:
-      print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])
+      print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
diff --git a/tensorflow/lite/experimental/c/c_api.cc b/tensorflow/lite/experimental/c/c_api.cc
index 67f826c..7118e36 100644
--- a/tensorflow/lite/experimental/c/c_api.cc
+++ b/tensorflow/lite/experimental/c/c_api.cc
@@ -123,6 +123,12 @@
         TFL_InterpreterOptions::kDefaultNumThreads) {
       interpreter->SetNumThreads(optional_options->num_threads);
     }
+
+    for (auto* delegate : optional_options->delegates) {
+      if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
+        return nullptr;
+      }
+    }
   }
 
   return new TFL_Interpreter{model->impl, std::move(optional_error_reporter),
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.cc b/tensorflow/lite/experimental/c/c_api_experimental.cc
index a246ed9..0fc4169 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/lite/experimental/c/c_api_experimental.cc
@@ -41,6 +41,11 @@
   options->op_resolver.AddCustom(name, registration, min_version, max_version);
 }
 
+void TFL_InterpreterOptionsAddDelegate(TFL_InterpreterOptions* options,
+                                       TFL_Delegate* delegate) {
+  options->delegates.push_back(delegate);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.h b/tensorflow/lite/experimental/c/c_api_experimental.h
index 0f082c0..2df7487 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/lite/experimental/c/c_api_experimental.h
@@ -23,6 +23,7 @@
 #endif  // __cplusplus
 
 typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
+typedef TfLiteDelegate TFL_Delegate;
 
 // Resets all variable tensors to zero.
 TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors(
@@ -42,12 +43,22 @@
 //
 // NOTE: The interpreter will make a copy of `registration` internally, so the
 // caller should ensure that its contents (function pointers, etc...) remain
-// valid for the duration of the interpreter's lifetime. A common practice is
-// making the provided TFL_Registration instance static.
+// valid for the duration of any created interpreter's lifetime. A common
+// practice is making the provided TFL_Registration instance static.
 TFL_CAPI_EXPORT void TFL_InterpreterOptionsAddCustomOp(
     TFL_InterpreterOptions* options, const char* name,
     const TFL_Registration* registration, int min_version, int max_version);
 
+// Adds a delegate to be applied during `TFL_Interpreter` creation.
+//
+// If delegate application fails, interpreter creation will also fail with an
+// associated error logged.
+//
+// NOTE: The caller retains ownership of the delegate and should ensure that it
+// remains valid for the duration of any created interpreter's lifetime.
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsAddDelegate(
+    TFL_InterpreterOptions* options, TFL_Delegate* delegate);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/tensorflow/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/lite/experimental/c/c_api_experimental_test.cc
index e79c720..fc01ac4 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/lite/experimental/c/c_api_experimental_test.cc
@@ -32,7 +32,7 @@
   return &registration;
 }
 
-TEST(CApiExperimentalSimple, Smoke) {
+TEST(CApiExperimentalTest, Smoke) {
   TFL_Model* model = TFL_NewModelFromFile(
       "tensorflow/lite/testdata/add.bin");
   ASSERT_NE(model, nullptr);
@@ -52,6 +52,52 @@
   TFL_DeleteModel(model);
 }
 
+TEST(CApiExperimentalTest, Delegate) {
+  TFL_Model* model =
+      TFL_NewModelFromFile("tensorflow/lite/testdata/add.bin");
+
+  // Create and install a delegate instance.
+  bool delegate_prepared = false;
+  TfLiteDelegate delegate = TfLiteDelegateCreate();
+  delegate.data_ = &delegate_prepared;
+  delegate.Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) {
+    *static_cast<bool*>(delegate->data_) = true;
+    return kTfLiteOk;
+  };
+  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+  TFL_InterpreterOptionsAddDelegate(options, &delegate);
+  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
+
+  // The delegate should have been applied.
+  EXPECT_TRUE(delegate_prepared);
+
+  // Subsequent exectuion should behave properly (the delegate is a no-op).
+  TFL_DeleteInterpreterOptions(options);
+  TFL_DeleteModel(model);
+  EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
+  TFL_DeleteInterpreter(interpreter);
+}
+
+TEST(CApiExperimentalTest, DelegateFails) {
+  TFL_Model* model =
+      TFL_NewModelFromFile("tensorflow/lite/testdata/add.bin");
+
+  // Create and install a delegate instance.
+  TfLiteDelegate delegate = TfLiteDelegateCreate();
+  delegate.Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) {
+    return kTfLiteError;
+  };
+  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+  TFL_InterpreterOptionsAddDelegate(options, &delegate);
+  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
+
+  // Interpreter creation should fail as delegate preparation failed.
+  EXPECT_EQ(nullptr, interpreter);
+
+  TFL_DeleteInterpreterOptions(options);
+  TFL_DeleteModel(model);
+}
+
 }  // namespace
 
 int main(int argc, char** argv) {
diff --git a/tensorflow/lite/experimental/c/c_api_internal.h b/tensorflow/lite/experimental/c/c_api_internal.h
index 8a2987c..b058ec5 100644
--- a/tensorflow/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/lite/experimental/c/c_api_internal.h
@@ -43,6 +43,8 @@
   void (*error_reporter)(void* user_data, const char* format,
                          va_list args) = nullptr;
   void* error_reporter_user_data = nullptr;
+
+  std::vector<TfLiteDelegate*> delegates;
 };
 
 struct TFL_Interpreter {
diff --git a/tensorflow/lite/experimental/c/c_api_types.h b/tensorflow/lite/experimental/c/c_api_types.h
index e1c54cb..c31d3e5 100644
--- a/tensorflow/lite/experimental/c/c_api_types.h
+++ b/tensorflow/lite/experimental/c/c_api_types.h
@@ -51,7 +51,11 @@
   kTfLiteMaxExternalContexts = 4
 } TfLiteExternalContextType;
 
+// Forward declare so dependent structs and methods can reference these types
+// prior to the struct definitions.
 struct TfLiteContext;
+struct TfLiteDelegate;
+struct TfLiteRegistration;
 
 // An external context is a collection of information unrelated to the TF Lite
 // framework, but useful to a subset of the ops. TF Lite knows very little
@@ -63,10 +67,6 @@
   TfLiteStatus (*Refresh)(struct TfLiteContext* context);
 } TfLiteExternalContext;
 
-// Forward declare so GetNode can use this is in Context.
-typedef struct _TfLiteRegistration TfLiteRegistration;
-typedef struct _TfLiteDelegate TfLiteDelegate;
-
 #define kOptionalTensor (-1)
 
 // Fixed size list of integers. Used for dimensions and inputs/outputs tensor
@@ -330,7 +330,7 @@
 
   // The delegate which knows how to handle `buffer_handle`.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
+  struct TfLiteDelegate* delegate;
 
   // An integer buffer handle that can be handled by `delegate`.
   // The value is valid only when delegate is not null.
@@ -405,7 +405,7 @@
   // The pointer to the delegate. This is non-null only when the node is
   // created by calling `interpreter.ModifyGraphWithDelegate`.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
+  struct TfLiteDelegate* delegate;
 } TfLiteNode;
 
 typedef struct TfLiteContext {
@@ -451,15 +451,15 @@
 
   // Get a Tensor node by node_index.
   // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
-                                         TfLiteNode** node,
-                                         TfLiteRegistration** registration);
+  TfLiteStatus (*GetNodeAndRegistration)(
+      struct TfLiteContext*, int node_index, TfLiteNode** node,
+      struct TfLiteRegistration** registration);
 
   // Replace ops with one or more stub delegate operations. This function
   // does not take ownership of `nodes_to_replace`.
   TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
-      struct TfLiteContext*, TfLiteRegistration registration,
-      const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+      struct TfLiteContext*, struct TfLiteRegistration registration,
+      const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);
 
   // Number of threads that are recommended to subsystems like gemmlowp and
   // eigen.
@@ -484,7 +484,7 @@
   void* profiler;
 } TfLiteContext;
 
-typedef struct _TfLiteRegistration {
+typedef struct TfLiteRegistration {
   // Initializes the op from serialized data.
   // If a built-in op:
   //   `buffer` is the op's params data (TfLiteLSTMParams*).
@@ -560,7 +560,7 @@
 } TfLiteDelegateFlags;
 
 // WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
+typedef struct TfLiteDelegate {
   // Data that delegate needs to identify itself. This data is owned by the
   // delegate. The delegate is owned in the user code, so the delegate is
   // responsible for doing this when it is destroyed.
@@ -571,20 +571,21 @@
   // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
   // to ask the TensorFlow lite runtime to create macro-nodes to represent
   // delegated subgraphs of the original graph.
-  TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+  TfLiteStatus (*Prepare)(TfLiteContext* context,
+                          struct TfLiteDelegate* delegate);
 
   // Copy the data from delegate buffer handle into raw memory of the given
   // 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
   // bytes as long as it follows the rules for kTfLiteDynamic tensors.
   TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
-                                       TfLiteDelegate* delegate,
+                                       struct TfLiteDelegate* delegate,
                                        TfLiteBufferHandle buffer_handle,
                                        TfLiteTensor* tensor);
 
   // Copy the data from raw memory of the given 'tensor' to delegate buffer
   // handle. This can be null if the delegate doesn't use its own buffer.
   TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
-                                     TfLiteDelegate* delegate,
+                                     struct TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* tensor);
 
@@ -592,7 +593,8 @@
   // this doesn't release the underlying resource (e.g. textures). The
   // resources are either owned by application layer or the delegate.
   // This can be null if the delegate doesn't use its own buffer.
-  void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+  void (*FreeBufferHandle)(TfLiteContext* context,
+                           struct TfLiteDelegate* delegate,
                            TfLiteBufferHandle* handle);
 
   // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index 2d78b21..24d975c 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -1,19 +1,13 @@
 # TensorFlow Lite for iOS
 
-load("//tensorflow/lite/experimental/ios:ios.bzl", "TFL_IOS_BUILD_VERSION", "TFL_MINIMUM_OS_VERSION")
+load("//tensorflow/lite/experimental/ios:ios.bzl", "TFL_MINIMUM_OS_VERSION")
 load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
-load("@build_bazel_rules_apple//apple:versioning.bzl", "apple_bundle_version")
 
 package(
     default_visibility = ["//visibility:private"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-apple_bundle_version(
-    name = "TensorFlowLiteC_version",
-    build_version = TFL_IOS_BUILD_VERSION,
-)
-
 ios_static_framework(
     name = "TensorFlowLiteC_framework",
     hdrs = [
@@ -22,6 +16,5 @@
     ],
     bundle_name = "TensorFlowLiteC",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
-    version = ":TensorFlowLiteC_version",
     deps = ["//tensorflow/lite/experimental/c:c_api"],
 )
diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec
index cb16346..5efd12c 100644
--- a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec
+++ b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec
@@ -1,10 +1,10 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteC'
-  s.version          = '0.2.0'
+  s.version          = '1.14.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
-  s.source           = { :http => "https://dl.google.com/dl/cpdc/9d0ec5e53f4ff34a/TensorFlowLiteC-#{s.version}.tar.gz" }
+  s.source           = { :http => "https://dl.google.com/dl/cpdc/0e27bc28472e2519/TensorFlowLiteC-#{s.version}.tar.gz" }
   s.summary          = 'TensorFlow Lite'
   s.description      = <<-DESC
 
diff --git a/tensorflow/lite/experimental/ios/ios.bzl b/tensorflow/lite/experimental/ios/ios.bzl
index 1698134..976c6b0 100644
--- a/tensorflow/lite/experimental/ios/ios.bzl
+++ b/tensorflow/lite/experimental/ios/ios.bzl
@@ -1,8 +1,5 @@
 """TensorFlow Lite Build Configurations for iOS"""
 
-# Current version of the TensorFlow Lite iOS libraries.
-TFL_IOS_BUILD_VERSION = "0.2.0"
-
 TFL_MINIMUM_OS_VERSION = "9.0"
 
 # Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
diff --git a/tensorflow/lite/experimental/kernels/fp16/BUILD b/tensorflow/lite/experimental/kernels/fp16/BUILD
deleted file mode 100644
index 14f9ff4..0000000
--- a/tensorflow/lite/experimental/kernels/fp16/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-# Experimental FP16-on-CPU implementation of a few select layers.
-
-package(
-    licenses = ["notice"],  # Apache 2.0
-)
-
-cc_library(
-    name = "common",
-    hdrs = [
-        "common.h",
-    ],
-    deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/c:c_api_internal",
-        "//tensorflow/lite/kernels/internal:tensor",
-    ],
-)
diff --git a/tensorflow/lite/experimental/kernels/fp16/common.h b/tensorflow/lite/experimental/kernels/fp16/common.h
deleted file mode 100644
index 8b82f14..0000000
--- a/tensorflow/lite/experimental/kernels/fp16/common.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_FP16_COMMON_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_FP16_COMMON_H_
-
-// Experimental half precision floating point type compatible with IEEE 754-2008
-// binary16 format.
-
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-
-#if __GNUC__ && ((__clang__ && (__aarch64__ || __arm__)) || \
-                 (!__cplusplus && __ARM_FP16_FORMAT_IEEE))
-#define TFL_HAS_IEEE_FP16 1
-#endif
-#if __GNUC__ && \
-    (__clang__ || __ARM_FP16_FORMAT_IEEE || __ARM_FP16_FORMAT_ALTERNATIVE)
-#define TFL_HAS_ARM_FP16 1
-#endif
-
-namespace tflite {
-
-#if TFL_HAS_IEEE_FP16
-typedef _Float16 tfl_float16_t;
-#elif TFL_HAS_ARM_FP16
-typedef __fp16 tfl_float16_t;
-#else
-// TODO(b/138252484): implement tfl_float16_t using third_party/FP16
-#error "This header requires FP16 support."
-#endif
-
-// Check tfl_float16_t is 'compatible' with the placeholder type.
-static_assert(sizeof(tfl_float16_t) == sizeof(TfLiteFloat16),
-              "Size of real and placeholder FP16 types don't match.");
-static_assert(alignof(tfl_float16_t) == alignof(TfLiteFloat16),
-              "Alignment of real and placeholder FP16 types don't match.");
-
-// Specialization of typeToTfLiteType with tfl_float16_t.
-// Template is declared in interpreter.h
-template <>
-constexpr TfLiteType typeToTfLiteType<tfl_float16_t>() {
-  return kTfLiteFloat16;
-}
-
-// Specialization of GetTensorData with tfl_float16_t.
-// Template is declared in kernels/internal/tensor_ctypes.h
-template <>
-inline tfl_float16_t* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? reinterpret_cast<tfl_float16_t*>(tensor->data.f16)
-                           : nullptr;
-}
-
-template <>
-inline const tfl_float16_t* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr
-             ? reinterpret_cast<const tfl_float16_t*>(tensor->data.f16)
-             : nullptr;
-}
-
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_FP16_COMMON_H_
diff --git a/tensorflow/lite/experimental/micro/README.md b/tensorflow/lite/experimental/micro/README.md
index b70aeb6..bf6e7dc 100644
--- a/tensorflow/lite/experimental/micro/README.md
+++ b/tensorflow/lite/experimental/micro/README.md
@@ -341,7 +341,7 @@
     to down load the Tensorflow source code and the support libraries \(but do
     not run the make command shown there.\)
 2.  Download the Eta Compute SDK, version 0.0.17. Contact info@etacompute.com
-3.  You will need the the Arm compiler arm-none-eabi-gcc, version 7.3.1
+3.  You will need the Arm compiler arm-none-eabi-gcc, version 7.3.1
     20180622, release ARM/embedded-7-branch revision 261907, 7-2018-q2-update.
     This compiler is downloaded through make.
 4.  Edit the file
diff --git a/tensorflow/lite/experimental/micro/arduino/debug_log.cc b/tensorflow/lite/experimental/micro/arduino/debug_log.cc
index 4d18f6f9..3cdd006 100644
--- a/tensorflow/lite/experimental/micro/arduino/debug_log.cc
+++ b/tensorflow/lite/experimental/micro/arduino/debug_log.cc
@@ -34,5 +34,5 @@
     DEBUG_SERIAL_OBJECT.begin(9600);
     is_initialized = true;
   }
-  DEBUG_SERIAL_OBJECT.println(s);
+  DEBUG_SERIAL_OBJECT.print(s);
 }
diff --git a/tensorflow/lite/experimental/micro/examples/hello_world/README.md b/tensorflow/lite/experimental/micro/examples/hello_world/README.md
index e0b593f..89804d4 100644
--- a/tensorflow/lite/experimental/micro/examples/hello_world/README.md
+++ b/tensorflow/lite/experimental/micro/examples/hello_world/README.md
@@ -82,11 +82,12 @@
 ### Obtain and import the library
 
 To use this sample application with Arduino, we've created an Arduino library
-that includes it as an example that you can open in the Arduino IDE.
+that includes it as an example that you can open in the Arduino Desktop IDE.
 
 Download the current nightly build of the library: [hello_world.zip](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/tensorflow/lite/experimental/micro/tools/make/gen/arduino_x86_64/prj/hello_world/hello_world.zip)
 
-Next, import this zip file into the Arduino IDE by going to `Sketch -> Include Library -> Add .ZIP Library...`.
+Next, import this zip file into the Arduino Desktop IDE by going to `Sketch ->
+Include Library -> Add .ZIP Library...`.
 
 #### Building the library
 
@@ -104,7 +105,8 @@
 tensorflow/lite/experimental/micro/tools/make/gen/arduino_x86_64/prj/hello_world/hello_world.zip
 ```
 
-You can then import this zip file into the Arduino IDE by going to `Sketch -> Include Library -> Add .ZIP Library...`.
+You can then import this zip file into the Arduino Desktop IDE by going to
+`Sketch -> Include Library -> Add .ZIP Library...`.
 
 ### Load and run the example
 
@@ -112,11 +114,11 @@
 example near the bottom of the list named `TensorFlowLite:hello_world`. Select
 it and click `hello_world` to load the example.
 
-Use the Arduino IDE to build and upload the example. Once it is running, you
-should see the built-in LED on your device flashing.
+Use the Arduino Desktop IDE to build and upload the example. Once it is running,
+you should see the built-in LED on your device flashing.
 
-The Arduino IDE includes a plotter that we can use to display the sine wave
-graphically. To view it, go to `Tools -> Serial Plotter`. You will see one
+The Arduino Desktop IDE includes a plotter that we can use to display the sine
+wave graphically. To view it, go to `Tools -> Serial Plotter`. You will see one
 datapoint being logged for each inference cycle, expressed as a number between 0
 and 255.
 
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/README.md b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md
index b9c9957..1d41572 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/README.md
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md
@@ -1,81 +1,508 @@
-# Micro Speech Example
+# Micro Speech example
 
-This examples shows how you can use TensorFlow Lite to run a 20 kilobyte neural
+This example shows how you can use TensorFlow Lite to run a 20 kilobyte neural
 network model to recognize keywords in speech. It's designed to run on systems
-with very small amounts of memory such as microcontrollers and DSPs. The code
-itself also has a small footprint (for example around 22 kilobytes on a Cortex
+with very small amounts of memory such as microcontrollers and DSPs.
+
+The example application listens to its surroundings with a microphone and
+indicates when it has detected a word by lighting an LED or displaying data on a
+screen, depending on the capabilities of the device.
+
+The code has a small footprint (for example around 22 kilobytes on a Cortex
 M3) and only uses about 10 kilobytes of RAM for working memory, so it's able to
 run on systems like an STM32F103 with only 20 kilobytes of total SRAM and 64
 kilobytes of Flash.
 
-## Table of Contents
+## Table of contents
 
--   [Getting Started](#getting-started)
--   [Getting Started on a Microcontroller](#getting-started-on-a-microcontroller)
--   [Calculating the Input to the Neural Network](#calculating-the-input-to-the-neural-network)
--   [Creating Your Own Model](#creating-your-own-model)
+-   [Getting started](#getting-started)
+-   [Run on macOS](#run-on-macos)
+-   [Deploy to Arduino](#deploy-to-arduino)
+-   [Deploy to SparkFun Edge](#deploy-to-sparkfun-edge)
+-   [Deploy to STM32F746](#deploy-to-STM32F746)
+-   [Deploy to NXP FRDM K66F](#deploy-to-nxp-frdm-k66f)
+-   [Calculating the input to the neural network](#calculating-the-input-to-the-neural-network)
+-   [Train your own model](#train-your-own-model)
 
-## Getting Started
 
-To compile and test this example on a desktop Linux or MacOS machine, download
+## Getting started
+
+This code has been tested on the following devices:
+
+* [SparkFun Edge](https://sparkfun.com/products/15170)
+* [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers)
+* [ST Microelectronics STM32F746G Discovery kit](https://os.mbed.com/platforms/ST-Discovery-F746NG/)
+* [NXP FRDM K66F](https://www.nxp.com/design/development-boards/freedom-development-boards/mcu-boards/freedom-development-platform-for-kinetis-k66-k65-and-k26-mcus:FRDM-K66F)
+
+This readme contains instructions for building the code on Linux and macOS, and
+deploying the code to the above microcontroller platforms and macOS.
+
+### Build the tests
+
+To compile and test this example on a desktop Linux or macOS machine, download
 [the TensorFlow source code](https://github.com/tensorflow/tensorflow), `cd`
 into the source directory from a terminal, and then run the following command:
 
 ```
-make -f tensorflow/lite/experimental/micro/tools/make/Makefile
+make -f tensorflow/lite/experimental/micro/tools/make/Makefile test_micro_speech_test
 ```
 
 This will take a few minutes, and downloads frameworks the code uses like
 [CMSIS](https://developer.arm.com/embedded/cmsis) and
 [flatbuffers](https://google.github.io/flatbuffers/). Once that process has
-finished, run:
+finished, you should see a series of files get compiled, followed by some
+logging output from a test, which should conclude with `~~~ALL TESTS PASSED~~~`.
 
-```
-make -f tensorflow/lite/experimental/micro/tools/make/Makefile test_micro_speech
-```
+If you see this, it means that a small program has been built and run that loads
+the trained TensorFlow model, runs some example inputs through it, and got the
+expected outputs.
 
-You should see a series of files get compiled, followed by some logging output
-from a test, which should conclude with `~~~ALL TESTS PASSED~~~`. If you see
-this, it means that a small program has been built and run that loads a trained
-TensorFlow model, runs some example inputs through it, and got the expected
-outputs. This particular test runs spectrograms generated from recordings of
-people saying "Yes" and "No", and checks that the network correctly identifies
-them.
-
-To understand how TensorFlow Lite does this, you can look at the `TestInvoke()`
-function in
-[micro_speech_test.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc).
-It's a fairly small amount of code, creating an interpreter, getting a handle to
-a model that's been compiled into the program, and then invoking the interpreter
+To understand how TensorFlow Lite does this, you can look at the source in
+[hello_world_test.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc).
+It's a fairly small amount of code that creates an interpreter, gets a handle to
+a model that's been compiled into the program, and then invokes the interpreter
 with the model and sample inputs.
 
-## Getting Started on a Microcontroller
+### Run on macOS
 
-Once you have downloaded the dependencies and got the x86/Linux build working,
-you can try building a version for the STM32F103 'bluepill' device. The
-following command will build the test and then run it on an emulator, assuming
-you have Docker installed:
+The example contains an audio provider compatible with macOS. If you have access
+to a Mac, you can run the example on your development machine.
 
-*On Mac OS you need to have ARM compiler installed, one way of doing so is with
-brew: brew install caskroom/cask/gcc-arm-embedded*
+First, use the following command to build it:
 
 ```
-make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test_micro_speech
+make -f tensorflow/lite/experimental/micro/tools/make/Makefile micro_speech
 ```
 
-If you have a real device
-[(see here for how to set one up)](https://github.com/google/stm32_bare_lib/tree/master/README.md)
-you can then convert the ELF file into a a `.bin` format executable to load onto
-it by running:
+Once the build completes, you can run the example with the following command:
 
 ```
-arm-none-eabi-objcopy \
-tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test \
-tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test.bin \
---output binary
+tensorflow/lite/experimental/micro/tools/make/gen/osx_x86_64/bin/micro_speech
 ```
 
-## Calculating the Input to the Neural Network
+You might see a pop-up asking for microphone access. If so, grant it, and the
+program will start.
+
+Try saying "yes" and "no". You should see output that looks like the following:
+
+```
+Heard yes (201) @4056ms
+Heard no (205) @6448ms
+Heard unknown (201) @13696ms
+Heard yes (205) @15000ms
+Heard yes (205) @16856ms
+Heard unknown (204) @18704ms
+Heard no (206) @21000ms
+```
+
+The number after each detected word is its score. By default, the recognize
+commands component only considers matches as valid if their score is over 200,
+so all of the scores you see will be at least 200.
+
+The number after the score is the number of milliseconds since the program was
+started.
+
+If you don't see any output, make sure your Mac's internal microphone is
+selected in the Mac's *Sound* menu, and that its input volume is turned up high
+enough.
+
+## Deploy to Arduino
+
+The following instructions will help you build and deploy this sample
+to [Arduino](https://www.arduino.cc/) devices.
+
+The sample has been tested with the following devices:
+
+- [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers)
+
+The Arduino Nano 33 BLE Sense is currently the only Arduino with a built-in
+microphone. If you're using a different Arduino board and attaching your own
+microphone, you'll need to implement your own +audio_provider.cc+. It also has a
+built-in LED, which is used to indicate that a word has been recognized.
+
+### Obtain and import the library
+
+To use this sample application with Arduino, we've created an Arduino library
+that includes it as an example that you can open in the Arduino IDE.
+
+Download the current nightly build of the library: [micro_speech.zip](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/tensorflow/lite/experimental/micro/tools/make/gen/arduino_x86_64/prj/micro_speech/micro_speech.zip)
+
+Next, import this zip file into the Arduino IDE by going to
+`Sketch -> Include Library -> Add .ZIP Library...`.
+
+#### Build the library
+
+If you need to build the library from source (for example, if you're making
+modifications to the code), run this command to generate a zip file containing
+the required source files:
+
+```
+make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=arduino TAGS="portable_optimized" generate_micro_speech_arduino_library_zip
+```
+
+A zip file will be created at the following location:
+
+```
+tensorflow/lite/experimental/micro/tools/make/gen/arduino_x86_64/prj/micro_speech/micro_speech.zip
+```
+
+You can then import this zip file into the Arduino IDE by going to
+`Sketch -> Include Library -> Add .ZIP Library...`.
+
+### Load and run the example
+
+Once the library has been added, go to `File -> Examples`. You should see an
+example near the bottom of the list named `TensorFlowLite:micro_speech`. Select
+it and click `micro_speech` to load the example.
+
+Use the Arduino IDE to build and upload the example. Once it is running, you
+should see the built-in LED on your device flashing. Saying the word "yes" will
+cause the LED to remain on for 3 seconds. The current model has fairly low
+accuracy, so you may have to repeat "yes" a few times.
+
+The program also outputs inference results to the serial port, which appear as
+follows:
+
+```
+Heard yes (201) @4056ms
+Heard no (205) @6448ms
+Heard unknown (201) @13696ms
+Heard yes (205) @15000ms
+```
+
+The number after each detected word is its score. By default, the program only
+considers matches as valid if their score is over 200, so all of the scores you
+see will be at least 200.
+
+When the program is run, it waits 5 seconds for a USB-serial connection to be
+available. If there is no connection available, it will not output data. To see
+the serial output in the Arduino desktop IDE, do the following:
+
+1. Open the Arduino IDE
+1. Connect the Arduino board to your computer via USB
+1. Press the reset button on the Arduino board
+1. Within 5 seconds, go to `Tools -> Serial Monitor` in the Arduino IDE. You may
+   have to try several times, since the board will take a moment to connect.
+
+If you don't see any output, repeat the process again.
+
+## Deploy to SparkFun Edge
+
+The following instructions will help you build and deploy this sample on the
+[SparkFun Edge development board](https://sparkfun.com/products/15170).
+
+The program will toggle the blue LED on and off with each inference. It will
+switch on the yellow LED when a "yes" is heard, the red LED when a "no" is
+heard, and the green LED when an unknown command is heard.
+
+The [AI on a microcontroller with TensorFlow Lite and SparkFun Edge](https://codelabs.developers.google.com/codelabs/sparkfun-tensorflow)
+walks through the deployment process in detail. The steps are also
+summarized below.
+
+### Compile the binary
+
+The following command will download the required dependencies and then compile a
+binary for the SparkFun Edge:
+
+```
+make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=sparkfun_edge micro_speech_bin
+```
+
+The binary will be created in the following location:
+
+```
+tensorflow/lite/experimental/micro/tools/make/gen/sparkfun_edge_cortex-m4/bin/micro_speech.bin
+```
+
+### Sign the binary
+
+The binary must be signed with cryptographic keys to be deployed to the device.
+We'll now run some commands that will sign our binary so it can be flashed to
+the SparkFun Edge. The scripts we are using come from the Ambiq SDK, which is
+downloaded when the `Makefile` is run.
+
+Enter the following command to set up some dummy cryptographic keys we can use
+for development:
+
+```
+cp tensorflow/lite/experimental/micro/tools/make/downloads/AmbiqSuite-Rel2.0.0/tools/apollo3_scripts/keys_info0.py \
+tensorflow/lite/experimental/micro/tools/make/downloads/AmbiqSuite-Rel2.0.0/tools/apollo3_scripts/keys_info.py
+```
+
+Next, run the following command to create a signed binary:
+
+```
+python3 tensorflow/lite/experimental/micro/tools/make/downloads/AmbiqSuite-Rel2.0.0/tools/apollo3_scripts/create_cust_image_blob.py \
+--bin tensorflow/lite/experimental/micro/tools/make/gen/sparkfun_edge_cortex-m4/bin/micro_speech.bin \
+--load-address 0xC000 \
+--magic-num 0xCB \
+-o main_nonsecure_ota \
+--version 0x0
+```
+
+This will create the file `main_nonsecure_ota.bin`. We'll now run another
+command to create a final version of the file that can be used to flash our
+device with the bootloader script we will use in the next step:
+
+```
+python3 tensorflow/lite/experimental/micro/tools/make/downloads/AmbiqSuite-Rel2.0.0/tools/apollo3_scripts/create_cust_wireupdate_blob.py \
+--load-address 0x20000 \
+--bin main_nonsecure_ota.bin \
+-i 6 \
+-o main_nonsecure_wire \
+--options 0x1
+```
+
+You should now have a file called `main_nonsecure_wire.bin` in the directory
+where you ran the commands. This is the file we'll be flashing to the device.
+
+### Flash the binary
+
+Next, attach the board to your computer via a USB-to-serial adapter.
+
+**Note:** If you're using the [SparkFun Serial Basic Breakout](https://www.sparkfun.com/products/15096),
+you should [install the latest drivers](https://learn.sparkfun.com/tutorials/sparkfun-serial-basic-ch340c-hookup-guide#drivers-if-you-need-them)
+before you continue.
+
+Once connected, assign the USB device name to an environment variable:
+
+```
+export DEVICENAME=put your device name here
+```
+
+Set another variable with the baud rate:
+
+```
+export BAUD_RATE=921600
+```
+
+Now, hold the button marked `14` on the device. While still holding the button,
+hit the button marked `RST`. Continue holding the button marked `14` while
+running the following command:
+
+```
+python3 tensorflow/lite/experimental/micro/tools/make/downloads/AmbiqSuite-Rel2.0.0/tools/apollo3_scripts/uart_wired_update.py \
+-b ${BAUD_RATE} ${DEVICENAME} \
+-r 1 \
+-f main_nonsecure_wire.bin \
+-i 6
+```
+
+You should see a long stream of output as the binary is flashed to the device.
+Once you see the following lines, flashing is complete:
+
+```
+Sending Reset Command.
+Done.
+```
+
+If you don't see these lines, flashing may have failed. Try running through the
+steps in [Flash the binary](#flash-the-binary) again (you can skip over setting
+the environment variables). If you continue to run into problems, follow the
+[AI on a microcontroller with TensorFlow Lite and SparkFun Edge](https://codelabs.developers.google.com/codelabs/sparkfun-tensorflow)
+codelab, which includes more comprehensive instructions for the flashing
+process.
+
+The binary should now be deployed to the device. Hit the button marked `RST` to
+reboot the board.
+
+You should see the device's blue LED flashing. The yellow LED should light when
+a "yes" is heard, the red LED when a "no" is heard, and the green LED when an
+unknown command is heard. The current model has fairly low accuracy, so you may
+have to repeat "yes" a few times.
+
+Debug information is logged by the board while the program is running. To view
+it, establish a serial connection to the board using a baud rate of `115200`.
+On OSX and Linux, the following command should work:
+
+```
+screen ${DEVICENAME} 115200
+```
+
+You will see a line output for every word that is detected:
+
+```
+Heard yes (201) @4056ms
+Heard no (205) @6448ms
+Heard unknown (201) @13696ms
+Heard yes (205) @15000ms
+```
+
+The number after each detected word is its score. By default, the program only
+considers matches as valid if their score is over 200, so all of the scores you
+see will be at least 200.
+
+To stop viewing the debug output with `screen`, hit `Ctrl+A`, immediately
+followed by the `K` key, then hit the `Y` key.
+
+## Deploy to STM32F746
+
+The following instructions will help you build and deploy the sample to the
+[STM32F7 discovery kit](https://os.mbed.com/platforms/ST-Discovery-F746NG/)
+using [ARM Mbed](https://github.com/ARMmbed/mbed-cli).
+
+Before we begin, you'll need the following:
+
+- STM32F7 discovery kit board
+- Mini-USB cable
+- ARM Mbed CLI ([installation instructions](https://os.mbed.com/docs/mbed-os/v5.12/tools/installation-and-setup.html))
+- Python 2.7 and pip
+
+Since Mbed requires a special folder structure for projects, we'll first run a
+command to generate a subfolder containing the required source files in this
+structure:
+
+```
+make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=mbed TAGS="CMSIS disco_f746ng" generate_micro_speech_mbed_project
+```
+
+This will result in the creation of a new folder:
+
+```
+tensorflow/lite/experimental/micro/tools/make/gen/mbed_cortex-m4/prj/hello_world/mbed
+```
+
+This folder contains all of the example's dependencies structured in the correct
+way for Mbed to be able to build it.
+
+Change into the directory and run the following commands, making sure you are
+using Python 2.7.15.
+
+First, tell Mbed that the current directory is the root of an Mbed project:
+
+```
+mbed config root .
+```
+
+Next, tell Mbed to download the dependencies and prepare to build:
+
+```
+mbed deploy
+```
+
+By default, Mbed will build the project using C++98. However, TensorFlow Lite
+requires C++11. Run the following Python snippet to modify the Mbed
+configuration files so that it uses C++11:
+
+```
+python -c 'import fileinput, glob;
+for filename in glob.glob("mbed-os/tools/profiles/*.json"):
+  for line in fileinput.input(filename, inplace=True):
+    print line.replace("\"-std=gnu++98\"","\"-std=c++11\", \"-fpermissive\"")'
+
+```
+
+Finally, run the following command to compile:
+
+```
+mbed compile -m DISCO_F746NG -t GCC_ARM
+```
+
+This should result in a binary at the following path:
+
+```
+./BUILD/DISCO_F746NG/GCC_ARM/mbed.bin
+```
+
+To deploy, plug in your STM board and copy the file to it. On macOS, you can do
+this with the following command:
+
+```
+cp ./BUILD/DISCO_F746NG/GCC_ARM/mbed.bin /Volumes/DIS_F746NG/
+```
+
+Copying the file will initiate the flashing process.
+
+The inference results are logged by the board while the program is running.
+To view it, establish a serial connection to the board
+using a baud rate of `9600`. On OSX and Linux, the following command should
+work, replacing `/dev/tty.devicename` with the name of your device as it appears
+in `/dev`:
+
+```
+screen /dev/tty.devicename 9600
+```
+
+You will see a line output for every word that is detected:
+
+```
+Heard yes (201) @4056ms
+Heard no (205) @6448ms
+Heard unknown (201) @13696ms
+Heard yes (205) @15000ms
+```
+
+The number after each detected word is its score. By default, the program only
+considers matches as valid if their score is over 200, so all of the scores you
+see will be at least 200.
+
+To stop viewing the debug output with `screen`, hit `Ctrl+A`, immediately
+followed by the `K` key, then hit the `Y` key.
+
+## Deploy to NXP FRDM K66F
+
+The following instructions will help you build and deploy the sample to the
+[NXP FRDM K66F](https://www.nxp.com/design/development-boards/freedom-development-boards/mcu-boards/freedom-development-platform-for-kinetis-k66-k65-and-k26-mcus:FRDM-K66F)
+using [ARM Mbed](https://github.com/ARMmbed/mbed-cli).
+
+1.  Download [the TensorFlow source code](https://github.com/tensorflow/tensorflow).
+2.  Follow instructions from [mbed website](https://os.mbed.com/docs/mbed-os/v5.13/tools/installation-and-setup.html) to setup and install mbed CLI.
+3.  Compile TensorFlow with the following command to generate mbed project:
+
+    ```
+    make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=mbed TAGS="nxp_k66f" generate_micro_speech_mbed_project
+    ```
+4.  Go to the location of the generated project. The generated project is usally
+    in `tensorflow/lite/experimental/micro/tools/make/gen/mbed_cortex-m4/prj/micro_speech/mbed`
+5.  Create a mbed project using the generated files: `mbed new .`
+6.  Change the project setting to use C++ 11 rather than C++ 14 using:
+
+    ```
+    python -c 'import fileinput, glob;
+    for filename in glob.glob("mbed-os/tools/profiles/*.json"):
+      for line in fileinput.input(filename, inplace=True):
+        print line.replace("\"-std=gnu++14\"","\"-std=c++11\", \"-fpermissive\"")'
+    ```
+7.  To compile project, use the following command:
+
+    ```
+    mbed compile --target K66F --toolchain GCC_ARM --profile release
+    ```
+8.  For some mbed compliers, you may get compile error in mbed_rtc_time.cpp.
+    Go to `mbed-os/platform/mbed_rtc_time.h` and comment line 32 and line 37:
+
+    ```
+    //#if !defined(__GNUC__) || defined(__CC_ARM) || defined(__clang__)
+    struct timeval {
+    time_t tv_sec;
+    int32_t tv_usec;
+    };
+    //#endif
+    ```
+9.  Look at helpful resources from NXP website such as [NXP FRDM-K66F User guide](https://www.nxp.com/docs/en/user-guide/FRDMK66FUG.pdf) and [NXP FRDM-K66F Getting Started](https://www.nxp.com/document/guide/get-started-with-the-frdm-k66f:NGS-FRDM-K66F)
+    to understand information about the board.
+10. Connect USB cable to micro USB port. When ethernet port is face towards you,
+    The micro USB port is left of the ethernet port.
+11.  To compile and flash in a single step, add the `--flash` option:
+
+    ```
+    mbed compile --target K66F --toolchain GCC_ARM --profile release --flash
+    ```
+12. Disconnect USB cable from the device to power down the device and connect
+    back the power cable to start running the model.
+13. Connect to serial port with baud rate of 9600 and correct serial device
+    to view the output from the MCU. In linux, you can run the following screen
+    command if the serial device is `/dev/ttyACM0`:
+
+    ```
+    sudo screen /dev/ttyACM0 9600
+    ```
+14. Saying "Yes" will print "Yes" and "No" will print "No" on the serial port.
+15. A loopback path from microphone to headset jack is enabled. Headset jack is
+    in black color. If there is no output on the serial port, you can connect
+    headphone to headphone port to check if audio loopback path is working.
+
+## Calculating the input to the neural network
 
 The TensorFlow Lite model doesn't take in raw audio sample data. Instead it
 works with spectrograms, which are two dimensional arrays that are made up of
@@ -88,54 +515,91 @@
 The recipe for creating the spectrogram data is that each frequency slice is
 created by running an FFT across a 30ms section of the audio sample data. The
 input samples are treated as being between -1 and +1 as real values (encoded as
--32,768 and 32,767 in 16-bit signed integer samples). This results in an FFT
-with 256 entries. Every sequence of six entries is averaged together, giving a
-total of 43 frequency buckets in the final slice. The results are stored as
-unsigned eight-bit values, where 0 represents a real number of zero, and 255
-represents 127.5 as a real number. Each adjacent frequency entry is stored in
-ascending memory order (frequency bucket 0 at data[0], bucket 1 at data [1],
-etc). The window for the frequency analysis is then moved forward by 20ms, and
-the process repeated, storing the results in the next memory row (for example
-bucket 0 in this moved window would be in data[43 + 0], etc). This process
-happens 49 times in total, producing a single channel image that is 43 pixels
-wide, and 49 rows high. Here's an illustration of the process:
+-32,768 and 32,767 in 16-bit signed integer samples).
+
+This results in an FFT with 256 entries. Every sequence of six entries is
+averaged together, giving a total of 43 frequency buckets in the final slice.
+The results are stored as unsigned eight-bit values, where 0 represents a real
+number of zero, and 255 represents 127.5 as a real number.
+
+Each adjacent frequency entry is stored in ascending memory order (frequency
+bucket 0 at data[0], bucket 1 at data [1], etc). The window for the frequency
+analysis is then moved forward by 20ms, and the process repeated, storing the
+results in the next memory row (for example bucket 0 in this moved window would
+be in data[43 + 0], etc). This process happens 49 times in total, producing a
+single channel image that is 43 pixels wide, and 49 rows high.
+
+Here's an illustration of the process:
 
 ![spectrogram diagram](https://storage.googleapis.com/download.tensorflow.org/example_images/spectrogram_diagram.png)
 
-The test data files have been generated by running the following commands:
+The test data files have been generated by running the following commands. See
+the training instructions below to learn how to set up the environment to run
+them.
 
 ```
-bazel run tensorflow/examples/speech_commands:wav_to_features -- \
---input_wav=${HOME}/speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav \
---output_c_file=yes_features_data.cc \
+python tensorflow/tensorflow/examples/speech_commands/wav_to_features.py \
+--input_wav=/content/speech_dataset/yes/f2e59fea_nohash_1.wav \
+--output_c_file=/content/yes_features_data.cc \
 --window_stride=20 --preprocess=average --quantize=1
 
-bazel run tensorflow/examples/speech_commands:wav_to_features -- \
---input_wav=${HOME}/speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav \
---output_c_file=no_features_data.cc \
+python tensorflow/tensorflow/examples/speech_commands/wav_to_features.py \
+--input_wav=/content/speech_dataset/no/f9643d42_nohash_4.wav \
+--output_c_file=/content/no_features_data.cc \
 --window_stride=20 --preprocess=average --quantize=1
 ```
 
-## Creating Your Own Model
+## Train your own model
 
 The neural network model used in this example was built using the
 [TensorFlow speech commands tutorial](https://www.tensorflow.org/tutorials/sequences/audio_recognition).
+You can retrain it to recognize any combination of words from this list:
 
-If you would like to create your own, you can start by training a model with the
-following commands. Note that this will begin a full build of TensorFlow from
-source; it is not currently possible to use the TensorFlow pip package. Due to
-the complexity of setting up a build environment, it's easiest to run these
-commands in a
+```
+yes
+no
+up
+down
+left
+right
+on
+off
+stop
+go
+```
+
+### Use Google Colaboratory
+
+The easiest way to train your own speech model is by running [`train_speech_model.ipynb`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb)
+in Google Colaboratory. This avoids the need to install dependencies, and allows
+the use of GPUs for training. Total training time will be 1.5-2hrs.
+
+We strongly recommend trying this approach first.
+
+### Use your local machine
+
+You can use the following commands to train the model on your own machine.
+
+It may be easiest to run these commands in a
 [TensorFlow Docker container](https://www.tensorflow.org/install/docker). A full
 build may take a couple of hours.
 
+You must currently use the TensorFlow Nightly `pip` package. This version is
+confirmed to work:
+
+```
+tf-nightly-gpu==1.15.0.dev20190729
+```
+
 To begin training, run the following:
 
 ```
-bazel run -c opt --copt=-mavx2 --copt=-mfma \
-tensorflow/examples/speech_commands:train -- \
+python tensorflow/tensorflow/examples/speech_commands/train.py \
 --model_architecture=tiny_conv --window_stride=20 --preprocess=micro \
---wanted_words="yes,no" --silence_percentage=25 --unknown_percentage=25 --quantize=1
+--wanted_words="yes,no" --silence_percentage=25 --unknown_percentage=25 \
+--quantize=1 --verbosity=WARN --how_many_training_steps="15000,3000" \
+--learning_rate="0.001,0.0001" --summaries_dir=/tmp/retrain_logs \
+--data_dir=/tmp/speech_dataset --train_dir=/tmp/speech_commands_train
 ```
 
 If you see a compiling error on older machines, try leaving out the `--copt`
@@ -144,7 +608,7 @@
 has completed, the next step is to freeze the variables:
 
 ```
-bazel run tensorflow/examples/speech_commands:freeze -- \
+python tensorflow/tensorflow/examples/speech_commands/freeze.py \
 --model_architecture=tiny_conv --window_stride=20 --preprocess=micro \
 --wanted_words="yes,no" --quantize=1 --output_file=/tmp/tiny_conv.pb \
 --start_checkpoint=/tmp/speech_commands_train/tiny_conv.ckpt-18000
@@ -153,10 +617,10 @@
 The next step is to create a TensorFlow Lite file from the frozen graph:
 
 ```
-bazel run tensorflow/lite/toco:toco -- \
---input_file=/tmp/tiny_conv.pb --output_file=/tmp/tiny_conv.tflite \
---input_shapes=1,49,40,1 --input_arrays=Reshape_1 --output_arrays='labels_softmax' \
---inference_type=QUANTIZED_UINT8 --mean_values=0 --std_values=9.8077
+toco \
+--graph_def_file=/content/tiny_conv.pb --output_file=/tmp/tiny_conv.tflite \
+--input_shapes=1,1960 --input_arrays=Reshape_1 --output_arrays='labels_softmax' \
+--inference_type=QUANTIZED_UINT8 --mean_values=0 --std_dev_values=9.8077
 ```
 
 Finally, convert the file into a C source file that can be compiled into an
@@ -166,45 +630,7 @@
 xxd -i /tmp/tiny_conv.tflite > /tmp/tiny_conv_micro_features_model_data.cc
 ```
 
-Next, we need to update `tiny_conv_micro_features_model_data.cc` so that it is
-compatible with the `micro_features` sample code.
-
-First, open the file. The top two lines should look approximately as follows
-(the exact hex values may be different):
-
-```cpp
-unsigned char _tmp_tiny_conv_tflite[] = {
-  0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00,
-```
-
-You need to add the include from the following snippet, and tweak the variable
-declaration. Don’t change the hex values, though:
-
-```cpp
-#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h"
-
-const unsigned char g_tiny_conv_micro_features_model_data[] = {
-  0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00,
-```
-
-Next, go to the very bottom of the file and find the variable named
-`_tmp_tiny_conv_tflite_len`.
-
-```cpp
-unsigned int _tmp_tiny_conv_tflite_len = 19800;
-```
-
-Change the declaration as follows, but do not change the number assigned to it,
-even if your number is different from the one in this guide.
-
-```cpp
-const int g_tiny_conv_micro_features_model_data_len = 19800;
-```
-
-Finally, save the file, then copy the `tiny_conv_micro_features_model_data.cc`
-file into the `micro_features/` subdirectory of your `tf_microspeech/` project.
-
-### Creating Your Own Model With Google Cloud
+### Use Google Cloud
 
 If want to train your model in Google Cloud you can do so by using
 pre-configured Deep Learning images.
@@ -231,28 +657,8 @@
 gcloud compute ssh "jupyter@${INSTANCE_NAME}"
 ```
 
-now install Bazel:
-
-```
-wget https://github.com/bazelbuild/bazel/releases/download/0.15.0/bazel-0.15.0-installer-linux-x86_64.sh
-sudo bash ./bazel-0.15.0-installer-linux-x86_64.sh
-source /usr/local/lib/bazel/bin/bazel-complete.bash
-sudo ln /usr/local/bin/bazel /usr/bin/bazel
-```
-
-and finally run the build:
-
-```
-# TensorFlow already pre-baked on the image
-cd src/tensorflow
-bazel run -c opt --copt=-mavx2 --copt=-mfma \
-tensorflow/examples/speech_commands:train -- \
---model_architecture=tiny_conv --window_stride=20 --preprocess=average \
---wanted_words="yes,no" --silence_percentage=25 --unknown_percentage=25 --quantize=1
-```
-
-After build is over follow the rest of the instructions from this tutorial. And
-finally do not forget to remove the instance when training is done:
+Finally, follow the instructions in the previous section to train the model. Do
+not forget to remove the instance when training is done:
 
 ```
 gcloud compute instances delete "${INSTANCE_NAME}" --zone="${ZONE}"
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/audio_provider.cc
new file mode 100644
index 0000000..e8c27c8
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/audio_provider.cc
@@ -0,0 +1,118 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
+
+#include "PDM.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h"
+
+namespace {
+bool g_is_audio_initialized = false;
+// An internal buffer able to fit 16x our sample size
+constexpr int kAudioCaptureBufferSize = DEFAULT_PDM_BUFFER_SIZE * 16;
+int16_t g_audio_capture_buffer[kAudioCaptureBufferSize];
+// A buffer that holds our output
+int16_t g_audio_output_buffer[kMaxAudioSampleSize];
+// Mark as volatile so we can check in a while loop to see if
+// any samples have arrived yet.
+volatile int32_t g_latest_audio_timestamp = 0;
+}  // namespace
+
+void CaptureSamples() {
+  // This is how many bytes of new data we have each time this is called
+  const int number_of_samples = DEFAULT_PDM_BUFFER_SIZE;
+  // Calculate what timestamp the last audio sample represents
+  const int32_t time_in_ms =
+      g_latest_audio_timestamp +
+      (number_of_samples / (kAudioSampleFrequency / 1000));
+  // Determine the index, in the history of all samples, of the last sample
+  const int32_t start_sample_offset =
+      g_latest_audio_timestamp * (kAudioSampleFrequency / 1000);
+  // Determine the index of this sample in our ring buffer
+  const int capture_index = start_sample_offset % kAudioCaptureBufferSize;
+  // Read the data to the correct place in our buffer
+  PDM.read(g_audio_capture_buffer + capture_index, DEFAULT_PDM_BUFFER_SIZE);
+  // This is how we let the outside world know that new audio data has arrived.
+  g_latest_audio_timestamp = time_in_ms;
+}
+
+TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) {
+  // Hook up the callback that will be called with each sample
+  PDM.onReceive(CaptureSamples);
+  // Start listening for audio: MONO @ 16KHz with gain at 20
+  PDM.begin(1, kAudioSampleFrequency);
+  PDM.setGain(20);
+  // Block until we have our first audio sample
+  while (!g_latest_audio_timestamp) {
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
+                             int start_ms, int duration_ms,
+                             int* audio_samples_size, int16_t** audio_samples) {
+  // Set everything up to start receiving audio
+  if (!g_is_audio_initialized) {
+    TfLiteStatus init_status = InitAudioRecording(error_reporter);
+    if (init_status != kTfLiteOk) {
+      return init_status;
+    }
+    g_is_audio_initialized = true;
+  }
+  // This next part should only be called when the main thread notices that the
+  // latest audio sample data timestamp has changed, so that there's new data
+  // in the capture ring buffer. The ring buffer will eventually wrap around and
+  // overwrite the data, but the assumption is that the main thread is checking
+  // often enough and the buffer is large enough that this call will be made
+  // before that happens.
+
+  // Determine the index, in the history of all samples, of the first
+  // sample we want
+  const int start_offset = start_ms * (kAudioSampleFrequency / 1000);
+  // Determine how many samples we want in total
+  const int duration_sample_count =
+      duration_ms * (kAudioSampleFrequency / 1000);
+  for (int i = 0; i < duration_sample_count; ++i) {
+    // For each sample, transform its index in the history of all samples into
+    // its index in g_audio_capture_buffer
+    const int capture_index = (start_offset + i) % kAudioCaptureBufferSize;
+    // Write the sample to the output buffer
+    g_audio_output_buffer[i] = g_audio_capture_buffer[capture_index];
+  }
+
+  // Set pointers to provide access to the audio
+  *audio_samples_size = kMaxAudioSampleSize;
+  *audio_samples = g_audio_output_buffer;
+
+  return kTfLiteOk;
+}
+
+int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; }
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/command_responder.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/command_responder.cc
new file mode 100644
index 0000000..c98b8fb
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/arduino/command_responder.cc
@@ -0,0 +1,61 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h"
+
+#include "Arduino.h"
+
+// Toggles the LED every inference, and keeps it on for ~2 seconds if a "yes"
+// was heard
+void RespondToCommand(tflite::ErrorReporter* error_reporter,
+                      int32_t current_time, const char* found_command,
+                      uint8_t score, bool is_new_command) {
+  static bool is_initialized = false;
+  if (!is_initialized) {
+    pinMode(LED_BUILTIN, OUTPUT);
+    is_initialized = true;
+  }
+  static int32_t last_yes_time = 0;
+  static int count = 0;
+
+  if (is_new_command) {
+    error_reporter->Report("Heard %s (%d) @%dms", found_command, score,
+                           current_time);
+    // If we heard a "yes", switch on an LED and store the time.
+    if (found_command[0] == 'y') {
+      last_yes_time = current_time;
+      digitalWrite(LED_BUILTIN, HIGH);
+    }
+  }
+
+  // If last_yes_time is non-zero but was >3 seconds ago, zero it
+  // and switch off the LED.
+  if (last_yes_time != 0) {
+    if (last_yes_time < (current_time - 3000)) {
+      last_yes_time = 0;
+      digitalWrite(LED_BUILTIN, LOW);
+    }
+    // If it is non-zero but <3 seconds ago, do nothing.
+    return;
+  }
+
+  // Otherwise, toggle the LED every time an inference is performed.
+  ++count;
+  if (count & 1) {
+    digitalWrite(LED_BUILTIN, HIGH);
+  } else {
+    digitalWrite(LED_BUILTIN, LOW);
+  }
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
index b5dfa3d..ebb0207 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
@@ -95,8 +95,10 @@
       const int32_t slice_start_ms = (new_step * kFeatureSliceStrideMs);
       int16_t* audio_samples = nullptr;
       int audio_samples_size = 0;
-      GetAudioSamples(error_reporter, slice_start_ms, kFeatureSliceDurationMs,
-                      &audio_samples_size, &audio_samples);
+      // TODO(petewarden): Fix bug that leads to non-zero slice_start_ms
+      GetAudioSamples(error_reporter, (slice_start_ms > 0 ? slice_start_ms : 0),
+                      kFeatureSliceDurationMs, &audio_samples_size,
+                      &audio_samples);
       if (audio_samples_size < kMaxAudioSampleSize) {
         error_reporter->Report("Audio data size %d too small, want %d",
                                audio_samples_size, kMaxAudioSampleSize);
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/nxp_k66f/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/nxp_k66f/audio_provider.cc
new file mode 100644
index 0000000..55267e5
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/nxp_k66f/audio_provider.cc
@@ -0,0 +1,380 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TensorFlow Headers
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h"
+
+// mbed and NXP FRDM-K66F Headers
+#include "fsl_clock_config.h"  // NOLINT
+#include "fsl_common.h"        // NOLINT
+#include "fsl_dmamux.h"        // NOLINT
+#include "fsl_edma.h"          // NOLINT
+#include "fsl_gpio.h"          // NOLINT
+#include "fsl_i2c.h"           // NOLINT
+#include "fsl_lmem_cache.h"    // NOLINT
+#include "fsl_port.h"          // NOLINT
+#include "fsl_sai.h"           // NOLINT
+#include "fsl_sai_edma.h"      // NOLINT
+#include "mbed.h"              // NOLINT
+
+// Compiler pragma for alignment of data to make efficient use of DMA
+#if (defined(__ICCARM__))
+#if ((!(defined(FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION) && \
+        FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION)) &&        \
+     defined(FSL_FEATURE_L1ICACHE_LINESIZE_BYTE))
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
+  SDK_PRAGMA(data_alignment = alignbytes) var @"NonCacheable"
+#else
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
+  SDK_PRAGMA(data_alignment = alignbytes) var
+#endif
+#elif (defined(__CC_ARM) || defined(__ARMCC_VERSION))
+#if ((!(defined(FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION) && \
+        FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION)) &&        \
+     defined(FSL_FEATURE_L1ICACHE_LINESIZE_BYTE))
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
+  __attribute__((section("NonCacheable"), zero_init))  \
+      __attribute__((aligned(alignbytes))) var
+#else
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
+  __attribute__((aligned(alignbytes))) var
+#endif
+#elif (defined(__GNUC__))
+#if ((!(defined(FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION) && \
+        FSL_FEATURE_HAS_NO_NONCACHEABLE_SECTION)) &&        \
+     defined(FSL_FEATURE_L1ICACHE_LINESIZE_BYTE))
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes)          \
+  __attribute__((section("NonCacheable,\"aw\",%nobits @"))) var \
+      __attribute__((aligned(alignbytes)))
+#else
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) \
+  var __attribute__((aligned(alignbytes)))
+#endif
+#else
+#error Toolchain not supported.
+#define AT_NONCACHEABLE_SECTION_ALIGN(var, alignbytes) var
+#endif
+
+namespace {
+
+// Buffer configuration for receiving audio data
+constexpr int kNoOfSamples = 512;
+constexpr int kBufferSize = kNoOfSamples * 2;
+constexpr int kNoOfBuffers = 4;
+constexpr int kOverSampleRate = 384;
+
+// Buffer management
+AT_NONCACHEABLE_SECTION_ALIGN(
+    static int16_t g_rx_buffer[kNoOfBuffers * kNoOfSamples], 4);
+sai_edma_handle_t g_tx_sai_handle;
+sai_edma_handle_t g_rx_sai_handle;
+static volatile uint32_t g_tx_index = 0;
+static volatile uint32_t g_rx_index = 0;
+edma_handle_t g_tx_dma_handle = {0};
+edma_handle_t g_rx_dma_handle = {0};
+sai_transfer_t g_sai_transfer;
+
+bool g_is_audio_initialized = false;
+constexpr int kAudioCaptureBufferSize = kAudioSampleFrequency * 0.5;
+int16_t g_audio_capture_buffer[kAudioCaptureBufferSize];
+int16_t g_audio_output_buffer[kMaxAudioSampleSize];
+int32_t g_latest_audio_timestamp = 0;
+
+// DA7212 configuration
+constexpr int da7212ConfigurationSize = 48;
+constexpr int da7212I2cAddress = 0x1A;
+volatile uint8_t g_da7212_register_config[da7212ConfigurationSize][2] = {
+    {0x21, 0x10},  // Set DIG_ROUTING_DAI to ADC right and ADC left
+    {0x22, 0x05},  // Set Sampling rate to 16 KHz
+    {0x23, 0x08},  // Enable master bias
+    {0x24, 0x00},  // Clear PLL Fractional division top
+    {0x25, 0x00},  // Clear PLL Fractional division bottom
+    {0x26, 0x20},  // Set PLL Integer division to 32
+    {0x27, 0x80},  // Set PLL input range to 2-10 MHz,system clock is PLL output
+    {0x28, 0x01},  // 64  BCLK per WCLK and S
+    {0x29, 0xC0},  // I2S 16-bit per channel, output is driven, DAI enable
+    {0x2A, 0x32},  // One stream for left and another for right
+    {0x45, 0x67},  // Set DAC Gain to 6 dB
+    {0x46, 0x67},  // Set DAC Gain to 6 dB
+    {0x47, 0xF1},  // Enable charge pump
+    {0x4B, 0x08},  // DAC_L selected
+    {0x4C, 0x08},  // DAC_R selected
+    {0x69, 0xA0},  // Enable DAC_L
+    {0x6A, 0xA0},  // Enable DAC_R
+    {0x6B, 0xB8},  // Enable HP_L
+    {0x6C, 0xB8},  // Enable HP_R
+    {0x6E, 0x98},  // Enable MIXOUT_L
+    {0x6F, 0x98},  // Enable MIXOUT_R
+    {0x95, 0x32}, {0xE0, 0x00}, {0x32, 0x80},  // Enable MIC
+    {0x33, 0x80},                              // Enable MIC
+    {0x34, 0x03},                              // Add MXIN Gain
+    {0x35, 0x03},                              // Add MXIN Gain
+    {0x36, 0x78},                              // Add ADC Gain
+    {0x37, 0x78},                              // Add ADC Gain
+    {0x60, 0xB0}, {0x61, 0xB0}, {0x65, 0x88}, {0x66, 0x88}, {0x67, 0xA0},
+    {0x68, 0xA0}, {0x62, 0xA9}, {0x50, 0xFE}, {0x51, 0xF7}, {0x93, 0x07},
+    {0x3A, 0x04}, {0x64, 0x84}, {0x39, 0x01}, {0x63, 0x80}, {0x38, 0x88},
+    {0x24, 0x00}, {0x25, 0x00}, {0x26, 0x20}, {0x20, 0x80}};
+
+// Save audio samples into intermediate buffer
+void CaptureSamples(const int16_t *sample_data) {
+  const int sample_size = kNoOfSamples;
+  const int32_t time_in_ms =
+      g_latest_audio_timestamp + (sample_size / (kAudioSampleFrequency / 1000));
+
+  const int32_t start_sample_offset =
+      g_latest_audio_timestamp * (kAudioSampleFrequency / 1000);
+  for (int i = 0; i < sample_size; ++i) {
+    const int capture_index =
+        (start_sample_offset + i) % kAudioCaptureBufferSize;
+    g_audio_capture_buffer[capture_index] = sample_data[i];
+  }
+  // This is how we let the outside world know that new audio data has arrived.
+  g_latest_audio_timestamp = time_in_ms;
+}
+
+// Callback function for SAI RX EDMA transfer complete
+static void SaiRxCallback(I2S_Type *base, sai_edma_handle_t *handle,
+                          status_t status, void *userData) {
+  if (kStatus_SAI_RxError == status) {
+    // Handle the error
+  } else {
+    // Save audio data into intermediate buffer
+    CaptureSamples(
+        reinterpret_cast<int16_t *>(g_rx_buffer + g_tx_index * kNoOfSamples));
+
+    // Submit received audio buffer to SAI TX for audio loopback debug
+    g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_tx_index * kNoOfSamples);
+    g_sai_transfer.dataSize = kBufferSize;
+    if (kStatus_Success ==
+        SAI_TransferSendEDMA(I2S0, &g_tx_sai_handle, &g_sai_transfer)) {
+      g_tx_index++;
+    }
+    if (g_tx_index == kNoOfBuffers) {
+      g_tx_index = 0U;
+    }
+
+    // Submit buffer to SAI RX to receive audio data
+    g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+    g_sai_transfer.dataSize = kBufferSize;
+    if (kStatus_Success ==
+        SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
+      g_rx_index++;
+    }
+    if (g_rx_index == kNoOfBuffers) {
+      g_rx_index = 0U;
+    }
+  }
+}
+
+// Callback function for TX Buffer transfer
+static void SaiTxCallback(I2S_Type *base, sai_edma_handle_t *handle,
+                          status_t status, void *userData) {
+  if (kStatus_SAI_TxError == status) {
+    // Handle the error
+  }
+  // Do nothing
+}
+
+// Initialize MCU pins
+void McuInitializePins(void) {
+  // Port B Clock Gate Control: Clock enabled
+  CLOCK_EnableClock(kCLOCK_PortB);
+  // Port C Clock Gate Control: Clock enabled
+  CLOCK_EnableClock(kCLOCK_PortC);
+  // Port E Clock Gate Control: Clock enabled
+  CLOCK_EnableClock(kCLOCK_PortE);
+
+  // PORTB16 (pin E10) is configured as UART0_RX
+  PORT_SetPinMux(PORTB, 16U, kPORT_MuxAlt3);
+  // PORTB17 (pin E9) is configured as UART0_TX
+  PORT_SetPinMux(PORTB, 17U, kPORT_MuxAlt3);
+  // PORTC1 (pin B11) is configured as I2S0_TXD0
+  PORT_SetPinMux(PORTC, 1U, kPORT_MuxAlt6);
+
+  // PORTC10 (pin C7) is configured as I2C1_SCL
+  const port_pin_config_t portc10_pinC7_config = {
+      kPORT_PullUp,          kPORT_FastSlewRate,     kPORT_PassiveFilterDisable,
+      kPORT_OpenDrainEnable, kPORT_LowDriveStrength, kPORT_MuxAlt2,
+      kPORT_UnlockRegister};
+  PORT_SetPinConfig(PORTC, 10U, &portc10_pinC7_config);
+
+  // PORTC11 (pin B7) is configured as I2C1_SDA
+  const port_pin_config_t portc11_pinB7_config = {
+      kPORT_PullUp,          kPORT_FastSlewRate,     kPORT_PassiveFilterDisable,
+      kPORT_OpenDrainEnable, kPORT_LowDriveStrength, kPORT_MuxAlt2,
+      kPORT_UnlockRegister};
+  PORT_SetPinConfig(PORTC, 11U, &portc11_pinB7_config);
+
+  // PORTC6 (pin C8) is configured as I2S0_MCLK
+  PORT_SetPinMux(PORTC, 6U, kPORT_MuxAlt6);
+  // PORTE11 (pin G4) is configured as I2S0_TX_FS
+  PORT_SetPinMux(PORTE, 11U, kPORT_MuxAlt4);
+  // PORTE12 (pin G3) is configured as I2S0_TX_BCLK
+  PORT_SetPinMux(PORTE, 12U, kPORT_MuxAlt4);
+  SIM->SOPT5 =
+      ((SIM->SOPT5 & (~(SIM_SOPT5_UART0TXSRC_MASK))) | SIM_SOPT5_UART0TXSRC(0));
+  // PORTE7 (pin F4) is configured as I2S0_RXD0
+  PORT_SetPinMux(PORTE, 7U, kPORT_MuxAlt4);
+  SIM->SOPT5 =
+      ((SIM->SOPT5 & (~(SIM_SOPT5_UART0TXSRC_MASK))) | SIM_SOPT5_UART0TXSRC(0));
+}
+
+// Write DA7212 registers using I2C
+status_t Da7212WriteRegister(uint8_t register_address, uint8_t register_data) {
+  uint8_t data[1];
+  data[0] = (uint8_t)register_data;
+  i2c_master_transfer_t i2c_data;
+  i2c_data.slaveAddress = da7212I2cAddress;
+  i2c_data.direction = kI2C_Write;
+  i2c_data.subaddress = register_address;
+  i2c_data.subaddressSize = 1;
+  i2c_data.data = (uint8_t * volatile) data;
+  i2c_data.dataSize = 1;
+  i2c_data.flags = kI2C_TransferDefaultFlag;
+  return I2C_MasterTransferBlocking(I2C1, &i2c_data);
+}
+
+// Initialize DA7212
+void Da7212Initialize(void) {
+  for (uint32_t i = 0; i < da7212ConfigurationSize; i++) {
+    Da7212WriteRegister(g_da7212_register_config[i][0],
+                        g_da7212_register_config[i][1]);
+  }
+}
+
+// Initalization for receiving audio data
+TfLiteStatus InitAudioRecording(tflite::ErrorReporter *error_reporter) {
+  edma_config_t dma_config = {0};
+  sai_config_t sai_config;
+  sai_transfer_format_t sai_format;
+  volatile uint32_t delay_cycle = 500000;
+  i2c_master_config_t i2c_config = {0};
+
+  // Initialize FRDM-K66F pins
+  McuInitializePins();
+
+  // Set Clock to 180 MHz
+  // BOARD_BootClockRUN();
+  BOARD_BootClockHSRUN();
+
+  // Enable Code Caching to improve performance
+  LMEM_EnableCodeCache(LMEM, true);
+
+  // Initialize I2C
+  I2C_MasterGetDefaultConfig(&i2c_config);
+  I2C_MasterInit(I2C1, &i2c_config, CLOCK_GetFreq(kCLOCK_BusClk));
+
+  // Initialize SAI
+  memset(&sai_format, 0U, sizeof(sai_transfer_format_t));
+  SAI_TxGetDefaultConfig(&sai_config);
+  SAI_TxInit(I2S0, &sai_config);
+  SAI_RxGetDefaultConfig(&sai_config);
+  SAI_RxInit(I2S0, &sai_config);
+  sai_format.bitWidth = kSAI_WordWidth16bits;
+  sai_format.channel = 0U;
+  sai_format.sampleRate_Hz = kSAI_SampleRate16KHz;
+  sai_format.masterClockHz = kOverSampleRate * sai_format.sampleRate_Hz;
+  sai_format.protocol = sai_config.protocol;
+  sai_format.stereo = kSAI_MonoRight;
+  sai_format.watermark = FSL_FEATURE_SAI_FIFO_COUNT / 2U;
+
+  // Initialize DA7212
+  Da7212Initialize();
+
+  // Initialize SAI EDMA
+  EDMA_GetDefaultConfig(&dma_config);
+  EDMA_Init(DMA0, &dma_config);
+  EDMA_CreateHandle(&g_tx_dma_handle, DMA0, 0);
+  EDMA_CreateHandle(&g_rx_dma_handle, DMA0, 1);
+
+  // Initialize DMA MUX
+  DMAMUX_Init(DMAMUX);
+  DMAMUX_SetSource(DMAMUX, 0, (uint8_t)kDmaRequestMux0I2S0Tx);
+  DMAMUX_EnableChannel(DMAMUX, 0);
+  DMAMUX_SetSource(DMAMUX, 1, (uint8_t)kDmaRequestMux0I2S0Rx);
+  DMAMUX_EnableChannel(DMAMUX, 1);
+
+  // Wait few cycles for DA7212
+  while (delay_cycle) {
+    __ASM("nop");
+    delay_cycle--;
+  }
+
+  // Setup SAI EDMA Callbacks
+  SAI_TransferTxCreateHandleEDMA(I2S0, &g_tx_sai_handle, SaiTxCallback, NULL,
+                                 &g_tx_dma_handle);
+  SAI_TransferRxCreateHandleEDMA(I2S0, &g_rx_sai_handle, SaiRxCallback, NULL,
+                                 &g_rx_dma_handle);
+  SAI_TransferTxSetFormatEDMA(I2S0, &g_tx_sai_handle, &sai_format,
+                              CLOCK_GetFreq(kCLOCK_CoreSysClk),
+                              sai_format.masterClockHz);
+  SAI_TransferRxSetFormatEDMA(I2S0, &g_rx_sai_handle, &sai_format,
+                              CLOCK_GetFreq(kCLOCK_CoreSysClk),
+                              sai_format.masterClockHz);
+
+  // Submit buffers to SAI RX to start receiving audio
+  g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+  g_sai_transfer.dataSize = kBufferSize;
+  if (kStatus_Success ==
+      SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
+    g_rx_index++;
+  }
+  if (g_rx_index == kNoOfBuffers) {
+    g_rx_index = 0U;
+  }
+  g_sai_transfer.data = (uint8_t *)(g_rx_buffer + g_rx_index * kNoOfSamples);
+  g_sai_transfer.dataSize = kBufferSize;
+  if (kStatus_Success ==
+      SAI_TransferReceiveEDMA(I2S0, &g_rx_sai_handle, &g_sai_transfer)) {
+    g_rx_index++;
+  }
+  if (g_rx_index == kNoOfBuffers) {
+    g_rx_index = 0U;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+// Main entry point for getting audio data.
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter *error_reporter,
+                             int start_ms, int duration_ms,
+                             int *audio_samples_size, int16_t **audio_samples) {
+  if (!g_is_audio_initialized) {
+    TfLiteStatus init_status = InitAudioRecording(error_reporter);
+    if (init_status != kTfLiteOk) {
+      return init_status;
+    }
+    g_is_audio_initialized = true;
+  }
+  // This should only be called when the main thread notices that the latest
+  // audio sample data timestamp has changed, so that there's new data in the
+  // capture ring buffer. The ring buffer will eventually wrap around and
+  // overwrite the data, but the assumption is that the main thread is checking
+  // often enough and the buffer is large enough that this call will be made
+  // before that happens.
+  const int start_offset = start_ms * (kAudioSampleFrequency / 1000);
+  const int duration_sample_count =
+      duration_ms * (kAudioSampleFrequency / 1000);
+  for (int i = 0; i < duration_sample_count; ++i) {
+    const int capture_index = (start_offset + i) % kAudioCaptureBufferSize;
+    g_audio_output_buffer[i] = g_audio_capture_buffer[capture_index];
+  }
+  *audio_samples_size = kMaxAudioSampleSize;
+  *audio_samples = g_audio_output_buffer;
+  return kTfLiteOk;
+}
+
+int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; }
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc
index 6582c94..875ffac 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc
@@ -78,8 +78,10 @@
 
   RecognizeCommands recognize_commands(error_reporter);
 
+  std::initializer_list<uint8_t> result_data = {255, 0, 0, 0};
+  auto result_dims = {2, 1, 4};
   TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
-      {255, 0, 0, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
+      result_data, tflite::testing::IntArrayFromInitializer(result_dims),
       "input_tensor", 0.0f, 128.0f);
 
   const char* found_command;
@@ -96,8 +98,10 @@
 
   RecognizeCommands recognize_commands(error_reporter, 1000, 51);
 
+  std::initializer_list<uint8_t> yes_data = {0, 0, 255, 0};
+  auto yes_dims = {2, 1, 4};
   TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor(
-      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
+      yes_data, tflite::testing::IntArrayFromInitializer(yes_dims),
       "input_tensor", 0.0f, 128.0f);
 
   bool has_found_new_command = false;
@@ -122,8 +126,10 @@
     TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command));
   }
 
+  std::initializer_list<uint8_t> no_data = {0, 0, 0, 255};
+  auto no_dims = {2, 1, 4};
   TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor(
-      {0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
+      no_data, tflite::testing::IntArrayFromInitializer(no_dims),
       "input_tensor", 0.0f, 128.0f);
   has_found_new_command = false;
   new_command = "";
@@ -155,8 +161,10 @@
 
   RecognizeCommands recognize_commands(error_reporter, 1000, 51);
 
+  std::initializer_list<uint8_t> bad_data = {0, 0, 255};
+  auto bad_dims = {2, 1, 3};
   TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor(
-      {0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 3}),
+      bad_data, tflite::testing::IntArrayFromInitializer(bad_dims),
       "input_tensor", 0.0f, 128.0f);
 
   const char* found_command;
@@ -173,8 +181,10 @@
 
   RecognizeCommands recognize_commands(error_reporter, 1000, 51);
 
+  std::initializer_list<uint8_t> result_data = {0, 0, 255, 0};
+  auto result_dims = {2, 1, 4};
   TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
-      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
+      result_data, tflite::testing::IntArrayFromInitializer(result_dims),
       "input_tensor", 0.0f, 128.0f);
 
   const char* found_command;
@@ -194,8 +204,10 @@
 
   RecognizeCommands recognize_commands(error_reporter, 1000, 51);
 
+  std::initializer_list<uint8_t> result_data = {0, 0, 255, 0};
+  auto result_dims = {2, 1, 4};
   TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
-      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
+      result_data, tflite::testing::IntArrayFromInitializer(result_dims),
       "input_tensor", 0.0f, 128.0f);
 
   const char* found_command;
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb b/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb
new file mode 100644
index 0000000..3832e3e
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb
@@ -0,0 +1,327 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Train simple audio recognition model",
+      "version": "0.3.2",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "pO4-CY_TCZZS",
+        "colab_type": "text"
+      },
+      "source": [
+        "# Train a Simple Audio Recognition model for microcontroller use"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "BaFfr7DHRmGF",
+        "colab_type": "text"
+      },
+      "source": [
+        "This notebook demonstrates how to train a 20kb [Simple Audio Recognition](https://www.tensorflow.org/tutorials/sequences/audio_recognition) model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [micro_speech](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro/examples/micro_speech) example application.\n",
+        "\n",
+        "The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
+        "\n",
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/examples/micro_speech/train_speech_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "XaVtYN4nlCft",
+        "colab_type": "text"
+      },
+      "source": [
+        "The notebook runs Python scripts to train and freeze the model, and uses the TensorFlow Lite converter to convert it for use with TensorFlow Lite for Microcontrollers.\n",
+        "\n",
+        "**Training is much faster using GPU acceleration.** Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training 18,000 iterations will take 1.5-2 hours on a GPU runtime.\n",
+        "\n",
+        "## Configure training\n",
+        "\n",
+        "The following `os.environ` lines can be customized to set the words that will be trained for, and the steps and learning rate of the training. The default values will result in the same model that is used in the micro_speech example. Run the cell to set the configuration:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "ludfxbNIaegy",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "import os\n",
+        "\n",
+        "# A comma-delimited list of the words you want to train for.\n",
+        "# The options are: yes,no,up,down,left,right,on,off,stop,go\n",
+        "# All other words will be used to train an \"unknown\" category.\n",
+        "os.environ[\"WANTED_WORDS\"] = \"yes,no\"\n",
+        "\n",
+        "# The number of steps and learning rates can be specified as comma-separated\n",
+        "# lists to define the rate at each stage. For example,\n",
+        "# TRAINING_STEPS=15000,3000 and LEARNING_RATE=0.001,0.0001\n",
+        "# will run 18,000 training loops in total, with a rate of 0.001 for the first\n",
+        "# 15,000, and 0.0001 for the final 3,000.\n",
+        "os.environ[\"TRAINING_STEPS\"]=\"15000,3000\"\n",
+        "os.environ[\"LEARNING_RATE\"]=\"0.001,0.0001\"\n",
+        "\n",
+        "# Calculate the total number of steps, which is used to identify the checkpoint\n",
+        "# file name.\n",
+        "total_steps = sum(map(lambda string: int(string),\n",
+        "                  os.environ[\"TRAINING_STEPS\"].split(\",\")))\n",
+        "os.environ[\"TOTAL_STEPS\"] = str(total_steps)\n",
+        "\n",
+        "# Print the configuration to confirm it\n",
+        "!echo \"Training these words: ${WANTED_WORDS}\"\n",
+        "!echo \"Training steps in each stage: ${TRAINING_STEPS}\"\n",
+        "!echo \"Learning rate in each stage: ${LEARNING_RATE}\"\n",
+        "!echo \"Total number of training steps: ${TOTAL_STEPS}\"\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "gCgeOpvY9pAi",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Install dependencies\n",
+        "\n",
+        "Next, we'll install a GPU build of TensorFlow, so we can use GPU acceleration for training."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Nd1iM1o2ymvA",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Replace Colab's default TensorFlow install with a more recent\n",
+        "# build that contains the operations that are needed for training\n",
+        "!pip uninstall -y tensorflow tensorflow_estimator\n",
+        "!pip install -q tf-estimator-nightly==1.14.0.dev2019072901 tf-nightly-gpu==1.15.0.dev20190729"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "T9Ty5mR58E4i",
+        "colab_type": "text"
+      },
+      "source": [
+        "We'll also clone the TensorFlow repository, which contains the scripts that train and freeze the model."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "APGx0fEh7hFF",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Clone the repository from GitHub\n",
+        "!git clone -q https://github.com/tensorflow/tensorflow\n",
+        "# Check out a commit that has been tested to work\n",
+        "# with the build of TensorFlow we're using\n",
+        "!git -c advice.detachedHead=false -C tensorflow checkout 17ce384df70"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "aV_0qkYh98LD",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Load TensorBoard\n",
+        "\n",
+        "Now, set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "yZArmzT85SLq",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Delete any old logs from previous runs\n",
+        "!rm -rf /content/retrain_logs\n",
+        "# Load TensorBoard\n",
+        "%load_ext tensorboard\n",
+        "%tensorboard --logdir /content/retrain_logs"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "x1J96Ron-O4R",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Begin training\n",
+        "\n",
+        "Next, run the following script to begin training. The script will first download the training data:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "VJsEZx6lynbY",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "!python tensorflow/tensorflow/examples/speech_commands/train.py \\\n",
+        "--model_architecture=tiny_conv --window_stride=20 --preprocess=micro \\\n",
+        "--wanted_words=${WANTED_WORDS} --silence_percentage=25 --unknown_percentage=25 \\\n",
+        "--quantize=1 --verbosity=WARN --how_many_training_steps=${TRAINING_STEPS} \\\n",
+        "--learning_rate=${LEARNING_RATE} --summaries_dir=/content/retrain_logs \\\n",
+        "--data_dir=/content/speech_dataset --train_dir=/content/speech_commands_train \\\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "XQUJLrdS-ftl",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Freeze the graph\n",
+        "\n",
+        "Once training is complete, run the following cell to freeze the graph."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "xyc3_eLh9sAg",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "!python tensorflow/tensorflow/examples/speech_commands/freeze.py \\\n",
+        "--model_architecture=tiny_conv --window_stride=20 --preprocess=micro \\\n",
+        "--wanted_words=${WANTED_WORDS} --quantize=1 --output_file=/content/tiny_conv.pb \\\n",
+        "--start_checkpoint=/content/speech_commands_train/tiny_conv.ckpt-${TOTAL_STEPS}"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "_DBGDxVI-nKG",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Convert the model\n",
+        "\n",
+        "Run this cell to use the TensorFlow Lite converter to convert the frozen graph into the TensorFlow Lite format, fully quantized for use with embedded devices."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "lBj_AyCh1cC0",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "!toco \\\n",
+        "--graph_def_file=/content/tiny_conv.pb --output_file=/content/tiny_conv.tflite \\\n",
+        "--input_shapes=1,1960 --input_arrays=Reshape_1 --output_arrays='labels_softmax' \\\n",
+        "--inference_type=QUANTIZED_UINT8 --mean_values=0 --std_dev_values=9.8077"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "dt6Zqbxu-wIi",
+        "colab_type": "text"
+      },
+      "source": [
+        "The following cell will print the model size, which will be under 20 kilobytes."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "XohZOTjR8ZyE",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "import os\n",
+        "model_size = os.path.getsize(\"/content/tiny_conv.tflite\")\n",
+        "print(\"Model is %d bytes\" % model_size)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "2pQnN0i_-0L2",
+        "colab_type": "text"
+      },
+      "source": [
+        "Finally, we use xxd to transform the model into a source file that can be included in a C++ project and loaded by TensorFlow Lite for Microcontrollers."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "eoYyh0VU8pca",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Install xxd if it is not available\n",
+        "!apt-get -qq install xxd\n",
+        "# Save the file as a C source file\n",
+        "!xxd -i /content/tiny_conv.tflite > /content/tiny_conv.cc\n",
+        "# Print the source file\n",
+        "!cat /content/tiny_conv.cc"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tensorflow/lite/experimental/micro/examples/micro_vision/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_vision/Makefile.inc
index af79262..b2aae87 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_vision/Makefile.inc
+++ b/tensorflow/lite/experimental/micro/examples/micro_vision/Makefile.inc
@@ -1,8 +1,8 @@
-$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model,))
+$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
 
 MICRO_VISION_MODEL_SRCS := \
 tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.cc \
-$(MAKEFILE_DIR)/downloads/person_model/person_detect_model_data.cc
+$(MAKEFILE_DIR)/downloads/person_model_grayscale/person_detect_model_data.cc
 
 MICRO_VISION_MODEL_HDRS := \
 tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.h \
@@ -10,8 +10,8 @@
 
 MICRO_VISION_TEST_SRCS := \
 tensorflow/lite/experimental/micro/examples/micro_vision/micro_vision_test.cc \
-$(MAKEFILE_DIR)/downloads/person_model/no_person_image_data.cc \
-$(MAKEFILE_DIR)/downloads/person_model/person_image_data.cc \
+$(MAKEFILE_DIR)/downloads/person_model_grayscale/no_person_image_data.cc \
+$(MAKEFILE_DIR)/downloads/person_model_grayscale/person_image_data.cc \
 $(MICRO_VISION_MODEL_SRCS)
 
 MICRO_VISION_TEST_HDRS := \
diff --git a/tensorflow/lite/experimental/micro/examples/micro_vision/micro_vision_test.cc b/tensorflow/lite/experimental/micro/examples/micro_vision/micro_vision_test.cc
index a0874de..c573576 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_vision/micro_vision_test.cc
+++ b/tensorflow/lite/experimental/micro/examples/micro_vision/micro_vision_test.cc
@@ -69,10 +69,9 @@
   TF_LITE_MICRO_EXPECT_NE(nullptr, input);
   TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size);
   TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
-  TF_LITE_MICRO_EXPECT_EQ(96, input->dims->data[1]);
-  TF_LITE_MICRO_EXPECT_EQ(96, input->dims->data[2]);
-  // TODO(rocky): This will be a single channel for monochrome inputs
-  TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[3]);
+  TF_LITE_MICRO_EXPECT_EQ(kNumRows, input->dims->data[1]);
+  TF_LITE_MICRO_EXPECT_EQ(kNumCols, input->dims->data[2]);
+  TF_LITE_MICRO_EXPECT_EQ(kNumChannels, input->dims->data[3]);
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type);
 
   // Copy an image with a person into the memory area used for the input.
@@ -95,7 +94,7 @@
   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[1]);
   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[2]);
-  TF_LITE_MICRO_EXPECT_EQ(3, output->dims->data[3]);
+  TF_LITE_MICRO_EXPECT_EQ(kCategoryCount, output->dims->data[3]);
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
 
   // Make sure that the expected "Person" score is higher than the other class.
diff --git a/tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.h b/tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.h
index c50688a..e3cec7a 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.h
+++ b/tensorflow/lite/experimental/micro/examples/micro_vision/model_settings.h
@@ -23,7 +23,7 @@
 // if you change your model you'll need to update these constants.
 constexpr int kNumCols = 96;
 constexpr int kNumRows = 96;
-constexpr int kNumChannels = 3;
+constexpr int kNumChannels = 1;
 
 constexpr int kMaxImageSize = kNumCols * kNumRows * kNumChannels;
 
diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD
index 5121bc3..6dfb694 100644
--- a/tensorflow/lite/experimental/micro/kernels/BUILD
+++ b/tensorflow/lite/experimental/micro/kernels/BUILD
@@ -14,12 +14,17 @@
 cc_library(
     name = "micro_ops",
     srcs = [
+        "arg_min_max.cc",
         "conv.cc",
         "depthwise_conv.cc",
         "elementwise.cc",
+        "floor.cc",
         "fully_connected.cc",
+        "logical.cc",
+        "maximum_minimum.cc",
         "pooling.cc",
         "prelu.cc",
+        "reshape.cc",
         "softmax.cc",
     ],
     hdrs = [
@@ -27,6 +32,7 @@
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro/kernels:micro_utils",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",
         "//tensorflow/lite/kernels:padding",
@@ -55,12 +61,17 @@
 cc_library(
     name = "portable_optimized_micro_ops",
     srcs = [
+        "arg_min_max.cc",
         "conv.cc",
         "elementwise.cc",
+        "floor.cc",
         "fully_connected.cc",
+        "logical.cc",
+        "maximum_minimum.cc",
         "pooling.cc",
         "portable_optimized/depthwise_conv.cc",
         "prelu.cc",
+        "reshape.cc",
         "softmax.cc",
     ],
     hdrs = [
@@ -68,6 +79,7 @@
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro/kernels:micro_utils",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",
         "//tensorflow/lite/kernels:padding",
@@ -194,3 +206,74 @@
         "//tensorflow/lite/experimental/micro/testing:micro_test",
     ],
 )
+
+tflite_micro_cc_test(
+    name = "floor_test",
+    srcs = [
+        "floor_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
+
+tflite_micro_cc_test(
+    name = "logical_test",
+    srcs = [
+        "logical_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
+
+tflite_micro_cc_test(
+    name = "maximum_minimum_test",
+    srcs = [
+        "maximum_minimum_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
+
+tflite_micro_cc_test(
+    name = "arg_min_max_test",
+    srcs = [
+        "arg_min_max_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/kernels:micro_utils",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
+
+cc_library(
+    name = "micro_utils",
+    hdrs = ["micro_utils.h"],
+)
+
+tflite_micro_cc_test(
+    name = "reshape_test",
+    srcs = [
+        "reshape_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
diff --git a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
index c54cdf7..1439d9d 100644
--- a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
+++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -23,7 +23,22 @@
 TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_MAX_POOL_2D();
 TfLiteRegistration* Register_ABS();
+TfLiteRegistration* Register_SIN();
+TfLiteRegistration* Register_COS();
+TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_SQRT();
+TfLiteRegistration* Register_RSQRT();
+TfLiteRegistration* Register_SQUARE();
 TfLiteRegistration* Register_PRELU();
+TfLiteRegistration* Register_FLOOR();
+TfLiteRegistration* Register_MAXIMUM();
+TfLiteRegistration* Register_MINIMUM();
+TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
+TfLiteRegistration* Register_LOGICAL_OR();
+TfLiteRegistration* Register_LOGICAL_AND();
+TfLiteRegistration* Register_LOGICAL_NOT();
+TfLiteRegistration* Register_RESHAPE();
 
 AllOpsResolver::AllOpsResolver() {
   AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
@@ -35,7 +50,22 @@
   AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
   AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
   AddBuiltin(BuiltinOperator_ABS, Register_ABS());
+  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
+  AddBuiltin(BuiltinOperator_COS, Register_COS());
+  AddBuiltin(BuiltinOperator_LOG, Register_LOG());
+  AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
+  AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
+  AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
+  AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
+  AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
+  AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
+  AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+  AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
+  AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
+  AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
+  AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
+  AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
 }
 
 }  // namespace micro
diff --git a/tensorflow/lite/experimental/micro/kernels/arg_min_max.cc b/tensorflow/lite/experimental/micro/kernels/arg_min_max.cc
new file mode 100644
index 0000000..8b54096
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/arg_min_max.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/micro_utils.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace arg_min_max {
+
+constexpr int kInputTensor = 0;
+constexpr int kAxis = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+template <typename T1, typename T2, typename T3>
+inline void ArgMinMaxHelper(const RuntimeShape& input1_shape,
+                            const T1* input1_data, const T3* input2_data,
+                            const RuntimeShape& output_shape, T2* output_data,
+                            bool is_arg_max) {
+  if (is_arg_max) {
+    reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
+                             output_shape, output_data, micro::Greater());
+  } else {
+    reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
+                             output_shape, output_data, micro::Less());
+  }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* axis = GetInput(context, node, kAxis);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type)            \
+  ArgMinMaxHelper(GetTensorShape(input), GetTensorData<data_type>(input), \
+                  GetTensorData<axis_type>(axis), GetTensorShape(output), \
+                  GetTensorData<output_type>(output), is_arg_max)
+  if (axis->type == kTfLiteInt32) {
+    if (output->type == kTfLiteInt32) {
+      switch (input->type) {
+        case kTfLiteFloat32:
+          TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
+          break;
+        case kTfLiteUInt8:
+          TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
+          break;
+        case kTfLiteInt8:
+          TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
+          break;
+        default:
+          context->ReportError(context,
+                               "Only float32, uint8 and int8 are "
+                               "supported currently, got %s.",
+                               TfLiteTypeGetName(input->type));
+          return kTfLiteError;
+      }
+    } else {
+      context->ReportError(context,
+                           "Only int32 are supported currently, got %s.",
+                           TfLiteTypeGetName(output->type));
+      return kTfLiteError;
+    }
+  } else {
+    context->ReportError(context, "Only int32 are supported currently, got %s.",
+                         TfLiteTypeGetName(axis->type));
+    return kTfLiteError;
+  }
+
+#undef TF_LITE_ARG_MIN_MAX
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+  return Eval(context, node, false);
+}
+
+TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
+  return Eval(context, node, true);
+}
+
+}  // namespace arg_min_max
+
+TfLiteRegistration* Register_ARG_MAX() {
+  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+                                 arg_min_max::ArgMaxEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_ARG_MIN() {
+  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+                                 arg_min_max::ArgMinEval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/arg_min_max_test.cc b/tensorflow/lite/experimental/micro/kernels/arg_min_max_test.cc
new file mode 100644
index 0000000..0c987e4
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/arg_min_max_test.cc
@@ -0,0 +1,397 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// If expected output is empty, the test is expected to fail.
+void TestArgMinMax(TfLiteTensor* input_tensor, TfLiteTensor* axis_tensor,
+                   TfLiteTensor* output_tensor,
+                   std::initializer_list<int> expected_output_data,
+                   bool using_min = false) {
+  const int output_dims_count = ElementCount(*output_tensor->dims);
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      *input_tensor,
+      *axis_tensor,
+      *output_tensor,
+  };
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration;
+  if (using_min) {
+    registration = resolver.FindOp(tflite::BuiltinOperator_ARG_MIN, 1);
+  } else {
+    registration = resolver.FindOp(tflite::BuiltinOperator_ARG_MAX, 1);
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, nullptr, init_data_size);
+  }
+  int inputs_array_data[] = {2, 0, 1};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 2};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  if (!expected_output_data.size()) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                            registration->invoke(&context, &node));
+    return;
+  }
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i],
+                              output_tensor->data.i32[i], 1e-5f);
+  }
+}
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(GetMaxArgFloat) {
+  int32_t output_data[1];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  auto input_tensor = tflite::testing::CreateFloatTensor(
+      {0.1, 0.9, 0.7, 0.3}, input_dims, "input_tensor");
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {1});
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgUInt8) {
+  using tflite::testing::F2Q;
+  int32_t output_data[1];
+  float input_min = 0;
+  float input_max = 15.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  auto input_data = {
+      F2Q(1., input_min, input_max), F2Q(9., input_min, input_max),
+      F2Q(7., input_min, input_max), F2Q(3., input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {1});
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgInt8) {
+  int32_t output_data[1];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  std::initializer_list<int8_t> input_data = {1, 9, 7, 3};
+  auto input_tensor = tflite::testing::CreateTensor<int8_t, kTfLiteInt8>(
+      input_data, input_dims, "input_tensor");
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {1});
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgInt32) {
+  using tflite::testing::F2Q32;
+  int32_t output_data[1];
+  float input_min = 0;
+  float input_max = 31.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  auto input_data = {
+      F2Q32(1, input_min, input_max), F2Q32(9, input_min, input_max),
+      F2Q32(7, input_min, input_max), F2Q32(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantized32Tensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {});  // Expects {1} if supported.
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgMulDimensions) {
+  using tflite::testing::F2Q;
+  int32_t output_data[2];
+  float input_min = 0;
+  float input_max = 15.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {3, 1});
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgNegativeAxis) {
+  using tflite::testing::F2Q;
+  int32_t output_data[4];
+  float input_min = 0;
+  float input_max = 15.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {-2}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 4}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {0, 1, 0, 0});
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgOutput64) {
+  using tflite::testing::F2Q;
+  int64_t output_data[2];
+  float input_min = 0;
+  float input_max = 15.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max),  F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int64_t, kTfLiteInt64>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {});  // Expects {0, 1} if supported.
+}
+
+TF_LITE_MICRO_TEST(GetMaxArgAxis64) {
+  using tflite::testing::F2Q;
+  int32_t output_data[2];
+  float input_min = 0;
+  float input_max = 15.9375;
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max),  F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int64_t, kTfLiteInt64>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {});  // Expects {0, 1} if supported.
+}
+
+TF_LITE_MICRO_TEST(GetMinArgFloat) {
+  int32_t output_data[1];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  auto input_tensor = tflite::testing::CreateFloatTensor(
+      {0.1, 0.9, 0.7, 0.3}, input_dims, "input_tensor");
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {0}, true);
+}
+
+TF_LITE_MICRO_TEST(GetMinArgUInt8) {
+  using tflite::testing::F2Q;
+  float input_min = 0;
+  float input_max = 15.9375;
+  int32_t output_data[1];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  // Getting weird error when defining input_data directly in
+  // CreateQuantizedTensor. So I have to define it ahead.
+  auto input_data = {
+      F2Q(1.0, input_min, input_max), F2Q(9.0, input_min, input_max),
+      F2Q(7.0, input_min, input_max), F2Q(3.0, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {0}, true);
+}
+
+TF_LITE_MICRO_TEST(GetMinArgInt8) {
+  int32_t output_data[1];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
+  std::initializer_list<int8_t> input_data = {1, 9, 7, 3};
+  auto input_tensor = tflite::testing::CreateTensor<int8_t, kTfLiteInt8>(
+      input_data, input_dims, "input_tensor");
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {0}, true);
+}
+
+TF_LITE_MICRO_TEST(GetMinArgMulDimensions) {
+  using tflite::testing::F2Q;
+  float input_min = 0;
+  float input_max = 15.9375;
+  int32_t output_data[2];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {0, 0}, true);
+}
+
+TF_LITE_MICRO_TEST(GetMinArgOutput64) {
+  using tflite::testing::F2Q;
+  float input_min = 0;
+  float input_max = 15.9375;
+  int64_t output_data[2];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max),  F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int64_t, kTfLiteInt64>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {}, true);  // Expects {1, 0} if supported.
+}
+
+TF_LITE_MICRO_TEST(GetMinArgAxis64) {
+  using tflite::testing::F2Q;
+  float input_min = 0;
+  float input_max = 15.9375;
+  int32_t output_data[2];
+  TfLiteIntArray* input_dims =
+      tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
+  auto input_data = {
+      F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(8, input_min, input_max),
+      F2Q(1, input_min, input_max),  F2Q(9, input_min, input_max),
+      F2Q(7, input_min, input_max),  F2Q(3, input_min, input_max)};
+  auto input_tensor = tflite::testing::CreateQuantizedTensor(
+      input_data, input_dims, "input_tensor", input_min, input_max);
+  auto axis_tensor = tflite::testing::CreateTensor<int64_t, kTfLiteInt64>(
+      {3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
+      "axis_tensor");
+  auto output_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
+      output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
+      "output_tensor");
+  tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
+                                 {}, true);  // Expects {1, 0} if supported
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/elementwise.cc b/tensorflow/lite/experimental/micro/kernels/elementwise.cc
index e6ea15e..eb90302 100644
--- a/tensorflow/lite/experimental/micro/kernels/elementwise.cc
+++ b/tensorflow/lite/experimental/micro/kernels/elementwise.cc
@@ -29,6 +29,10 @@
   return type == kTfLiteFloat32;
 }
 
+bool IsLogicalSupportedType(const TfLiteType type) {
+  return type == kTfLiteBool;
+}
+
 typedef bool (*IsSupportedType)(TfLiteType);
 template <IsSupportedType>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -65,10 +69,43 @@
   return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
 }
 
+inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+                                bool bool_func(bool)) {
+  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+}
+
 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalNumeric(context, node, std::abs);
 }
 
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, std::sin);
+}
+
+TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, std::cos);
+}
+
+TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, std::log);
+}
+
+TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, std::sqrt);
+}
+
+TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
+}
+
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
+TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalLogical(context, node, [](bool v) { return !v; });
+}
+
 }  // namespace
 }  // namespace elementwise
 
@@ -79,6 +116,63 @@
       elementwise::AbsEval};
   return &r;
 }
+
+TfLiteRegistration* Register_SIN() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SinEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_COS() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOG() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_SQRT() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SqrtEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_RSQRT() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::RsqrtEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_SQUARE() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr, /* free */ nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SquareEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOGICAL_NOT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+      elementwise::LogicalNotEval};
+  return &r;
+}
+
 }  // namespace micro
 }  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/elementwise_test.cc b/tensorflow/lite/experimental/micro/kernels/elementwise_test.cc
index 1ba98af..c369260 100644
--- a/tensorflow/lite/experimental/micro/kernels/elementwise_test.cc
+++ b/tensorflow/lite/experimental/micro/kernels/elementwise_test.cc
@@ -22,7 +22,8 @@
 namespace tflite {
 namespace testing {
 
-void TestElementwiseFloat(std::initializer_list<int> input_dims_data,
+void TestElementwiseFloat(tflite::BuiltinOperator op,
+                          std::initializer_list<int> input_dims_data,
                           std::initializer_list<float> input_data,
                           std::initializer_list<int> output_dims_data,
                           std::initializer_list<float> expected_output_data,
@@ -47,7 +48,73 @@
   PopulateContext(tensors, tensors_size, &context);
   tflite::ops::micro::AllOpsResolver resolver;
   const TfLiteRegistration* registration =
-      resolver.FindOp(tflite::BuiltinOperator_ABS, /* version= */ 1);
+      resolver.FindOp(op, /* version= */ 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, nullptr, 0);
+  }
+  auto inputs_array_data = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInitializer(inputs_array_data);
+  auto outputs_array_data = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInitializer(outputs_array_data);
+  auto temporaries_array_data = {0};
+  TfLiteIntArray* temporaries_array =
+      IntArrayFromInitializer(temporaries_array_data);
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5f);
+  }
+}
+
+void TestElementwiseBool(tflite::BuiltinOperator op,
+                         std::initializer_list<int> input_dims_data,
+                         std::initializer_list<bool> input_data,
+                         std::initializer_list<int> output_dims_data,
+                         std::initializer_list<bool> expected_output_data,
+                         bool* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  constexpr int input_size = 1;
+  constexpr int output_size = 1;
+  constexpr int tensors_size = input_size + output_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateBoolTensor(input_data, input_dims, "input_tensor"),
+      CreateBoolTensor(output_data, output_dims, "output_tensor")};
+
+  // Place false in the uninitialized output buffer.
+  for (int i = 0; i < output_dims_count; ++i) {
+    output_data[i] = false;
+  }
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(op, /* version= */ 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
   void* user_data = nullptr;
@@ -78,8 +145,7 @@
     registration->free(&context, user_data);
   }
   for (int i = 0; i < output_dims_count; ++i) {
-    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
-                              1e-5f);
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
   }
 }
 
@@ -92,14 +158,95 @@
   constexpr int output_dims_count = 4;
   float output_data[output_dims_count];
   tflite::testing::TestElementwiseFloat(
-      {2, 2, 2},  // Input shape
-      {
-          0.01, -0.01, 10, -10,  // Input values
-      },
-      {2, 2, 2},  // Output shape
-      {
-          0.01, 0.01, 10, 10,  // Output values
-      },
+      tflite::BuiltinOperator_ABS,  // ABS operator
+      {2, 2, 2},                    // Input shape
+      {0.01, -0.01, 10, -10},       // Input values
+      {2, 2, 2},                    // Output shape
+      {0.01, 0.01, 10, 10},         // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Sin) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_SIN,    // SIN operator
+      {2, 2, 2},                      // Input shape
+      {0, 3.1415926, -3.1415926, 1},  // Input values
+      {2, 2, 2},                      // Output shape
+      {0, 0, 0, 0.84147},             // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Cos) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_COS,    // COS operator
+      {2, 2, 2},                      // Input shape
+      {0, 3.1415926, -3.1415926, 1},  // Input values
+      {2, 2, 2},                      // Output shape
+      {1, -1, -1, 0.54030},           // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Log) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_LOG,    // LOG operator
+      {2, 2, 2},                      // Input shape
+      {1, 2.7182818, 0.5, 2},         // Input values
+      {2, 2, 2},                      // Output shape
+      {0, 1, -0.6931472, 0.6931472},  // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Sqrt) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_SQRT,  // SQRT operator
+      {2, 2, 2},                     // Input shape
+      {0, 1, 2, 4},                  // Input values
+      {2, 2, 2},                     // Output shape
+      {0, 1, 1.41421, 2},            // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Rsqrt) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_RSQRT,  // RSQRT operator
+      {2, 2, 2},                      // Input shape
+      {1, 2, 4, 9},                   // Input values
+      {2, 2, 2},                      // Output shape
+      {1, 0.7071, 0.5, 0.33333},      // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(Square) {
+  constexpr int output_dims_count = 4;
+  float output_data[output_dims_count];
+  tflite::testing::TestElementwiseFloat(
+      tflite::BuiltinOperator_SQUARE,  // SQARE operator
+      {2, 2, 2},                       // Input shape
+      {1, 2, 0.5, -3.0},               // Input values
+      {2, 2, 2},                       // Output shape
+      {1, 4.0, 0.25, 9.0},             // Output values
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(LogicalNot) {
+  constexpr int output_dims_count = 4;
+  bool output_data[output_dims_count];
+  tflite::testing::TestElementwiseBool(
+      tflite::BuiltinOperator_LOGICAL_NOT,  // Logical NOT operator
+      {2, 2, 2},                            // Input shape
+      {true, false, false, true},           // Input values
+      {2, 2, 2},                            // Output shape
+      {false, true, true, false},           // Output values
       output_data);
 }
 
diff --git a/tensorflow/lite/experimental/micro/kernels/floor.cc b/tensorflow/lite/experimental/micro/kernels/floor.cc
new file mode 100644
index 0000000..7b55cff
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/floor.cc
@@ -0,0 +1,48 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/floor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace floor {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  reference_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+                       GetTensorShape(output), GetTensorData<float>(output));
+  return kTfLiteOk;
+}
+}  // namespace floor
+
+TfLiteRegistration* Register_FLOOR() {
+  static TfLiteRegistration r = {/*init=*/nullptr,
+                                 /*free=*/nullptr, /*prepare=*/nullptr,
+                                 floor::Eval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/floor_test.cc b/tensorflow/lite/experimental/micro/kernels/floor_test.cc
new file mode 100644
index 0000000..7b65a40
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/floor_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestFloor(std::initializer_list<int> input_dims_data,
+               std::initializer_list<float> input_data,
+               std::initializer_list<float> expected_output_data,
+               std::initializer_list<int> output_dims_data,
+               float* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 1;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateFloatTensor(input_data, input_dims, "input_tensor"),
+      CreateFloatTensor(output_data, output_dims, "output_tensor"),
+  };
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_FLOOR, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  int inputs_array_data[] = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = nullptr;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5f);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(FloorOpSingleDimFloat32) {
+  float output_data[2];
+  tflite::testing::TestFloor(/*input_dims_data=*/{1, 2},
+                             /*input_data=*/{8.5f, 0.0f},
+                             /*expected_output_data=*/{8, 0},
+                             /*output_dims_data*/ {1, 2},
+                             /*output_data=*/output_data);
+}
+
+TF_LITE_MICRO_TEST(FloorOpMultiDimFloat32) {
+  float output_data[10];
+  tflite::testing::TestFloor(
+      /*input_dims_data=*/{4, 2, 1, 1, 5},
+      /*input_data=*/
+      {0.0001f, 8.0001f, 0.9999f, 9.9999f, 0.5f, -0.0001f, -8.0001f, -0.9999f,
+       -9.9999f, -0.5f},
+      /*expected_output_data=*/
+      {0.0f, 8.0f, 0.0f, 9.0f, 0.0f, -1.0f, -9.0f, -1.0f, -10.0f, -1.0f},
+      /*output_dims_data=*/{4, 2, 1, 1, 5},
+      /*output_data=*/output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/logical.cc b/tensorflow/lite/experimental/micro/kernels/logical.cc
new file mode 100644
index 0000000..8c2aa34
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/logical.cc
@@ -0,0 +1,87 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace logical {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
+                         bool (*func)(bool, bool)) {
+  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (HaveSameShapes(input1, input2)) {
+    reference_ops::BinaryFunction<bool, bool, bool>(
+        GetTensorShape(input1), GetTensorData<bool>(input1),
+        GetTensorShape(input2), GetTensorData<bool>(input2),
+        GetTensorShape(output), GetTensorData<bool>(output), func);
+  } else {
+    reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
+        GetTensorShape(input1), GetTensorData<bool>(input1),
+        GetTensorShape(input2), GetTensorData<bool>(input2),
+        GetTensorShape(output), GetTensorData<bool>(output), func);
+  }
+
+  return kTfLiteOk;
+}
+
+bool LogicalOr(bool x, bool y) { return x || y; }
+
+TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
+  return LogicalImpl(context, node, LogicalOr);
+}
+
+bool LogicalAnd(bool x, bool y) { return x && y; }
+
+TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
+  return LogicalImpl(context, node, LogicalAnd);
+}
+
+}  // namespace
+}  // namespace logical
+
+TfLiteRegistration* Register_LOGICAL_OR() {
+  // Init, Free, Prepare, Eval are satisfying the Interface required by
+  // TfLiteRegistration.
+  static TfLiteRegistration r = {/* init */ nullptr, /* free */ nullptr,
+                                 /* prepare */ nullptr, logical::LogicalOrEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOGICAL_AND() {
+  // Init, Free, Prepare, Eval are satisfying the Interface required by
+  // TfLiteRegistration.
+  static TfLiteRegistration r = {/* init */ nullptr, /* free */ nullptr,
+                                 /* prepare */ nullptr,
+                                 logical::LogicalAndEval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/logical_test.cc b/tensorflow/lite/experimental/micro/kernels/logical_test.cc
new file mode 100644
index 0000000..55dfaca
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/logical_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestLogicalOp(tflite::BuiltinOperator op,
+                   std::initializer_list<int> input1_dims_data,
+                   std::initializer_list<bool> input1_data,
+                   std::initializer_list<int> input2_dims_data,
+                   std::initializer_list<bool> input2_data,
+                   std::initializer_list<int> output_dims_data,
+                   std::initializer_list<bool> expected_output_data,
+                   bool* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateBoolTensor(input1_data, input1_dims, "input1_tensor"),
+      CreateBoolTensor(input2_data, input2_dims, "input2_tensor"),
+      CreateBoolTensor(output_data, output_dims, "output_tensor"),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration = resolver.FindOp(op, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
+  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = nullptr;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 4);
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(LogicalOr) {
+  bool output_data[4];
+  tflite::testing::TestLogicalOp(
+      tflite::BuiltinOperator_LOGICAL_OR,           // operator
+      {4, 1, 1, 1, 4}, {true, false, false, true},  // input1
+      {4, 1, 1, 1, 4}, {true, false, true, false},  // input2
+      {4, 1, 1, 1, 4}, {true, false, true, true},   // expected output
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
+  bool output_data[4];
+  tflite::testing::TestLogicalOp(
+      tflite::BuiltinOperator_LOGICAL_OR,           // operator
+      {4, 1, 1, 1, 4}, {true, false, false, true},  // input1
+      {4, 1, 1, 1, 1}, {false},                     // input2
+      {4, 1, 1, 1, 4}, {true, false, false, true},  // expected output
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(LogicalAnd) {
+  bool output_data[4];
+  tflite::testing::TestLogicalOp(
+      tflite::BuiltinOperator_LOGICAL_AND,           // operator
+      {4, 1, 1, 1, 4}, {true, false, false, true},   // input1
+      {4, 1, 1, 1, 4}, {true, false, true, false},   // input2
+      {4, 1, 1, 1, 4}, {true, false, false, false},  // expected output
+      output_data);
+}
+
+TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
+  bool output_data[4];
+  tflite::testing::TestLogicalOp(
+      tflite::BuiltinOperator_LOGICAL_AND,          // operator
+      {4, 1, 1, 1, 4}, {true, false, false, true},  // input1
+      {4, 1, 1, 1, 1}, {true},                      // input2
+      {4, 1, 1, 1, 4}, {true, false, false, true},  // expected output
+      output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/maximum_minimum.cc b/tensorflow/lite/experimental/micro/kernels/maximum_minimum.cc
new file mode 100644
index 0000000..bbbfb03
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/maximum_minimum.cc
@@ -0,0 +1,141 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace maximum_minimum {
+namespace {
+
+// This file has a reference implementation of TFMaximum/TFMinimum.
+enum KernelType {
+  kReference,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+struct OpContext {
+  OpContext(TfLiteContext* context, TfLiteNode* node) {
+    input1 = GetInput(context, node, kInputTensor1);
+    input2 = GetInput(context, node, kInputTensor2);
+    output = GetOutput(context, node, kOutputTensor);
+  }
+  const TfLiteTensor* input1;
+  const TfLiteTensor* input2;
+  TfLiteTensor* output;
+};
+
+struct MaximumOp {
+  template <typename data_type>
+  static data_type op(data_type el1, data_type el2) {
+    return el1 > el2 ? el1 : el2;
+  }
+};
+
+struct MinimumOp {
+  template <typename data_type>
+  static data_type op(data_type el1, data_type el2) {
+    return el1 < el2 ? el1 : el2;
+  }
+};
+
+}  // namespace
+
+template <typename data_type, typename op_type>
+void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
+                     const OpContext& op_context) {
+  reference_ops::MaximumMinimumBroadcast4DSlow(
+      GetTensorShape(op_context.input1),
+      GetTensorData<data_type>(op_context.input1),
+      GetTensorShape(op_context.input2),
+      GetTensorData<data_type>(op_context.input2),
+      GetTensorShape(op_context.output),
+      GetTensorData<data_type>(op_context.output),
+      op_type::template op<data_type>);
+}
+
+template <KernelType kernel_type, typename OpType>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+
+  if (kernel_type == kReference) {
+    switch (op_context.output->type) {
+      case kTfLiteFloat32:
+        TFLiteOperation<float, OpType>(context, node, op_context);
+        break;
+      case kTfLiteUInt8:
+        TFLiteOperation<uint8_t, OpType>(context, node, op_context);
+        break;
+      case kTfLiteInt8:
+        TFLiteOperation<int8_t, OpType>(context, node, op_context);
+        break;
+      case kTfLiteInt32:
+        TFLiteOperation<int32_t, OpType>(context, node, op_context);
+        break;
+      case kTfLiteInt64:
+        TFLiteOperation<int64_t, OpType>(context, node, op_context);
+        break;
+      default:
+        context->ReportError(context,
+                             "Type %d is not supported by Maximum/Minimum.",
+                             op_context.output->type);
+        return kTfLiteError;
+    }
+  } else {
+    context->ReportError(context,
+                         "Kernel type not supported by Maximum/Minimum.",
+                         op_context.output->type);
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace maximum_minimum
+
+TfLiteRegistration* Register_MAXIMUM() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr,
+      /* free */ nullptr,
+      /* prepare */ nullptr,
+      maximum_minimum::Eval<maximum_minimum::kReference,
+                            maximum_minimum::MaximumOp>};
+  return &r;
+}
+
+TfLiteRegistration* Register_MINIMUM() {
+  static TfLiteRegistration r = {
+      /* init */ nullptr,
+      /* free */ nullptr,
+      /* prepare */ nullptr,
+      maximum_minimum::Eval<maximum_minimum::kReference,
+                            maximum_minimum::MinimumOp>};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/maximum_minimum_test.cc b/tensorflow/lite/experimental/micro/kernels/maximum_minimum_test.cc
new file mode 100644
index 0000000..b944b4b
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/maximum_minimum_test.cc
@@ -0,0 +1,314 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestMaxMinFloat(tflite::BuiltinOperator op,
+                     std::initializer_list<int> input1_dims_data,
+                     std::initializer_list<float> input1_data,
+                     std::initializer_list<int> input2_dims_data,
+                     std::initializer_list<float> input2_data,
+                     std::initializer_list<float> expected_output_data,
+                     std::initializer_list<int> output_dims_data,
+                     float* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateFloatTensor(input1_data, input1_dims, "input1_tensor"),
+      CreateFloatTensor(input2_data, input2_dims, "input2_tensor"),
+      CreateFloatTensor(output_data, output_dims, "output_tensor"),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration = resolver.FindOp(op, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
+  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = nullptr;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5);
+  }
+}
+
+void TestMaxMinQuantized(
+    tflite::BuiltinOperator op, std::initializer_list<int> input1_dims_data,
+    std::initializer_list<uint8_t> input1_data, float input1_min,
+    float input1_max, std::initializer_list<int> input2_dims_data,
+    std::initializer_list<uint8_t> input2_data, float input2_min,
+    float input2_max, std::initializer_list<uint8_t> expected_output_data,
+    float output_min, float output_max,
+    std::initializer_list<int> output_dims_data, uint8_t* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input1_data, input1_dims, "input1_tensor",
+                            input1_min, input1_max),
+      CreateQuantizedTensor(input2_data, input2_dims, "input2_tensor",
+                            input2_min, input2_max),
+      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+                            output_min, output_max),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration = resolver.FindOp(op, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
+  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = nullptr;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+  }
+}
+
+void TestMaxMinQuantizedInt32(
+    tflite::BuiltinOperator op, std::initializer_list<int> input1_dims_data,
+    std::initializer_list<int32_t> input1_data, float input1_min,
+    float input1_max, std::initializer_list<int> input2_dims_data,
+    std::initializer_list<int32_t> input2_data, float input2_min,
+    float input2_max, std::initializer_list<int32_t> expected_output_data,
+    float output_min, float output_max,
+    std::initializer_list<int> output_dims_data, int32_t* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantized32Tensor(input1_data, input1_dims, "input1_tensor",
+                              input1_min, input1_max),
+      CreateQuantized32Tensor(input2_data, input2_dims, "input2_tensor",
+                              input2_min, input2_max),
+      CreateQuantized32Tensor(output_data, output_dims, "output_tensor",
+                              output_min, output_max),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration = resolver.FindOp(op, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
+  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = nullptr;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(FloatTest) {
+  std::initializer_list<float> data1 = {1.0, 0.0, -1.0, 11.0, -2.0, -1.44};
+  std::initializer_list<float> data2 = {-1.0, 0.0, 1.0, 12.0, -3.0, -1.43};
+  float output_data[6];
+
+  tflite::testing::TestMaxMinFloat(
+      tflite::BuiltinOperator_MAXIMUM, {3, 3, 1, 2},
+      data1,                               // input1 shape and data
+      {3, 3, 1, 2}, data2,                 // input2 shape and data
+      {1.0, 0.0, 1.0, 12.0, -2.0, -1.43},  // expected output
+      {3, 3, 1, 2}, output_data);          // output shape and data buffer
+
+  tflite::testing::TestMaxMinFloat(
+      tflite::BuiltinOperator_MINIMUM, {3, 3, 1, 2},
+      data1,                                 // input1 shape and data
+      {3, 3, 1, 2}, data2,                   // input2 shape and data
+      {-1.0, 0.0, -1.0, 11.0, -3.0, -1.44},  // expected output
+      {3, 3, 1, 2}, output_data);            // output shape and data buffer
+}
+
+TF_LITE_MICRO_TEST(Uint8Test) {
+  std::initializer_list<uint8_t> data1 = {1, 0, 2, 11, 2, 23};
+  std::initializer_list<uint8_t> data2 = {0, 0, 1, 12, 255, 1};
+  const float input1_min = -63.5;
+  const float input1_max = 64;
+  const float input2_min = -63.5;
+  const float input2_max = 64;
+  const float output_min = -63.5;
+  const float output_max = 64;
+
+  uint8_t output_data[6];
+
+  tflite::testing::TestMaxMinQuantized(
+      tflite::BuiltinOperator_MAXIMUM,
+      // input1 shape, data and bounds
+      {3, 3, 1, 2}, data1, input1_min, input1_max,
+      // input2 shape, data and bounds
+      {3, 3, 1, 2}, data2, input2_min, input2_max,
+      // expected output
+      {1, 0, 2, 12, 255, 23},
+      // output bounds, shape and data buffer
+      output_min, output_max, {3, 3, 1, 2}, output_data);
+
+  tflite::testing::TestMaxMinQuantized(
+      tflite::BuiltinOperator_MINIMUM,
+      // input1 shape, data and bounds
+      {3, 3, 1, 2}, data1, input1_min, input1_max,
+      // input2 shape, data and bounds
+      {3, 3, 1, 2}, data2, input2_min, input2_max,
+      // expected output
+      {0, 0, 1, 11, 2, 1},
+      // output bounds, shape and data buffer
+      output_min, output_max, {3, 3, 1, 2}, output_data);
+}
+
+TF_LITE_MICRO_TEST(FloatWithBroadcastTest) {
+  std::initializer_list<float> data1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0};
+  std::initializer_list<float> data2 = {0.5, 2.0};
+  float output_data[6];
+
+  tflite::testing::TestMaxMinFloat(
+      tflite::BuiltinOperator_MAXIMUM, {3, 3, 1, 2},
+      data1,                            // input1 shape and data
+      {1, 2}, data2,                    // input2 shape and data
+      {1.0, 2.0, 0.5, 2.0, 0.5, 11.0},  // expected output
+      {3, 3, 1, 2}, output_data);       // output shape and data buffer
+
+  tflite::testing::TestMaxMinFloat(
+      tflite::BuiltinOperator_MINIMUM, {3, 3, 1, 2},
+      data1,                               // input1 shape and data
+      {1, 2}, data2,                       // input2 shape and data
+      {0.5, 0.0, -1.0, -2.0, -1.44, 2.0},  // expected output
+      {3, 3, 1, 2}, output_data);          // output shape and data buffer
+}
+
+TF_LITE_MICRO_TEST(Int32WithBroadcastTest) {
+  const float input1_min = -63.5;
+  const float input1_max = 64;
+  const float input2_min = -63.5;
+  const float input2_max = 64;
+  const float output_min = -63.5;
+  const float output_max = 64;
+  std::initializer_list<int32_t> data1 = {1, 0, -1, -2, 3, 11};
+  std::initializer_list<int32_t> data2 = {2};
+  int32_t output_data[6];
+
+  tflite::testing::TestMaxMinQuantizedInt32(
+      tflite::BuiltinOperator_MAXIMUM,
+      // input1 shape, data and bounds
+      {3, 3, 1, 2}, data1, input1_min, input1_max,
+      // input2 shape, data and bounds
+      {1, 1}, data2, input2_min, input2_max,
+      // expected output
+      {2, 2, 2, 2, 3, 11},
+      // output bounds, shape and data buffer
+      output_min, output_max, {3, 3, 1, 2}, output_data);
+
+  tflite::testing::TestMaxMinQuantizedInt32(
+      tflite::BuiltinOperator_MINIMUM,
+      // input1 shape, data and bounds
+      {3, 3, 1, 2}, data1, input1_min, input1_max,
+      // input2 shape, data and bounds
+      {1, 1}, data2, input2_min, input2_max,
+      // expected output
+      {1, 0, -1, -2, 2, 2},
+      // output bounds, shape and data buffer
+      output_min, output_max, {3, 3, 1, 2}, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/micro_utils.h b/tensorflow/lite/experimental/micro/kernels/micro_utils.h
new file mode 100644
index 0000000..dcb691f
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/micro_utils.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
+namespace tflite {
+namespace ops {
+namespace micro {
+
+// Same as gtl::Greater but defined here to reduce dependencies and
+// binary size for micro environment.
+struct Greater {
+  template <typename T>
+  bool operator()(const T& x, const T& y) const {
+    return x > y;
+  }
+};
+
+struct Less {
+  template <typename T>
+  bool operator()(const T& x, const T& y) const {
+    return x < y;
+  }
+};
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
diff --git a/tensorflow/lite/experimental/micro/kernels/prelu_test.cc b/tensorflow/lite/experimental/micro/kernels/prelu_test.cc
index 583b43b..6bc96ab 100644
--- a/tensorflow/lite/experimental/micro/kernels/prelu_test.cc
+++ b/tensorflow/lite/experimental/micro/kernels/prelu_test.cc
@@ -155,14 +155,14 @@
 TF_LITE_MICRO_TEST(FloatPreluActivationsOpTest) {
   const int output_dims_count = 12;
   float output_data[output_dims_count];
-  tflite::testing::TestPreluFloat({1, 2, 2, 3},  // input shape
+  tflite::testing::TestPreluFloat({4, 1, 2, 2, 3},  // input shape
                                   {
                                       0.0f, 0.0f, 0.0f,     // Row 1, Column 1
                                       1.0f, 1.0f, 1.0f,     // Row 1, Column 2
                                       -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
                                       -2.0f, -2.0f, -2.0f,  // Row 1, Column 2
                                   },
-                                  {1, 1, 3},           // alpha shape
+                                  {3, 1, 1, 3},        // alpha shape
                                   {0.0f, 1.0f, 2.0f},  // alpha values
                                   {
                                       0.0f, 0.0f, 0.0f,    // Row 1, Column 1
@@ -170,7 +170,7 @@
                                       0.0f, -1.0f, -2.0f,  // Row 2, Column 1
                                       0.0f, -2.0f, -4.0f,  // Row 1, Column 2
                                   },
-                                  {1, 2, 2, 3},  // output shape
+                                  {4, 1, 2, 2, 3},  // output shape
                                   output_data);
 }
 
@@ -183,13 +183,13 @@
   const int output_dims_count = 12;
   uint8_t output_data[output_dims_count];
   tflite::testing::TestPreluQuantized(
-      {1, 2, 2, 3},  // input shape
+      {4, 1, 2, 2, 3},  // input shape
       {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
        F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
        F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax),
        F2Q(-0.25f, kMin, kMax), F2Q(-0.25f, kMin, kMax),
        F2Q(-0.25f, kMin, kMax)},
-      kMin, kMax, {1, 1, 3},  // alpha shape
+      kMin, kMax, {3, 1, 1, 3},  // alpha shape
       {F2Q(0.0f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(-0.5f, kMin, kMax)},
       kMin, kMax,
       {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
@@ -197,7 +197,7 @@
        F2Q(0.0f, kMin, kMax), F2Q(-0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
        F2Q(0.0f, kMin, kMax), F2Q(-0.125f, kMin, kMax),
        F2Q(0.125f, kMin, kMax)},
-      {1, 2, 2, 3},  // output shape
+      {4, 1, 2, 2, 3},  // output shape
       kMin, kMax, output_data);
 }
 
diff --git a/tensorflow/lite/experimental/micro/kernels/reshape.cc b/tensorflow/lite/experimental/micro/kernels/reshape.cc
new file mode 100644
index 0000000..338fc52
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/reshape.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace reshape {
+
+constexpr int kInputTensor = 0;
+constexpr int kShapeTensor = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  // Tensorflow's Reshape allows one of the shape components to have the
+  // special -1 value, meaning it will be calculated automatically based on the
+  // input. Here we calculate what that dimension should be so that the number
+  // of output elements in the same as the number of input elements.
+  int num_input_elements = NumElements(input);
+  TfLiteIntArray* output_shape = output->dims;
+
+  if (NumInputs(node) == 1 &&  // Legacy scalar supported with params.
+      output_shape->size == 1 && output_shape->data[0] == 0) {
+    // Legacy tflite models use a shape parameter of [0] to indicate scalars,
+    // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
+    // toco conversion.
+    output_shape->size = 0;
+  }
+
+  int num_output_elements = 1;
+  int stretch_dim = -1;
+  for (int i = 0; i < output_shape->size; ++i) {
+    int value = output_shape->data[i];
+    if (value == -1) {
+      TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
+      stretch_dim = i;
+    } else {
+      num_output_elements *= value;
+    }
+  }
+  if (stretch_dim != -1) {
+    output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
+    num_output_elements *= output_shape->data[stretch_dim];
+  }
+
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
+  return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  if (ReshapeOutput(context, node) != kTfLiteOk) {
+    return kTfLiteError;
+  }
+
+  for (int i = 0; i < input->bytes; ++i) {
+    output->data.raw[i] = input->data.raw[i];
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace reshape
+
+TfLiteRegistration* Register_RESHAPE() {
+  static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
+                                 reshape::Eval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/reshape_test.cc b/tensorflow/lite/experimental/micro/kernels/reshape_test.cc
new file mode 100644
index 0000000..9a62943
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/reshape_test.cc
@@ -0,0 +1,356 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+inline TfLiteTensor CreateInt32Tensor(std::initializer_list<int32_t> data,
+                                      TfLiteIntArray* dims, const char* name) {
+  TfLiteTensor result;
+  result.type = kTfLiteInt32;
+  result.data.i32 = const_cast<int32_t*>(data.begin());
+  result.dims = dims;
+  result.params = {};
+  result.allocation_type = kTfLiteMemNone;
+  result.bytes = ElementCount(*dims) * sizeof(int32_t);
+  result.allocation = nullptr;
+  result.name = name;
+  result.is_variable = true;
+  return result;
+}
+
+inline TfLiteTensor CreateInt32ConstTensor(std::initializer_list<int32_t> data,
+                                           TfLiteIntArray* dims,
+                                           const char* name) {
+  auto result = CreateInt32Tensor(data, dims, name);
+  result.is_variable = false;
+  return result;
+}
+
+TfLiteReshapeParams create_params(int* shape_data) {
+  TfLiteReshapeParams op_params = {};
+  op_params.num_dimensions = shape_data[0];
+  for (int i = 0; i < shape_data[0]; ++i)
+    op_params.shape[i] = shape_data[i + 1];
+  return op_params;
+}
+
+// If expected output is empty, the test is expected to fail.
+void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
+                     TfLiteTensor* output_tensor, int expected_output_size,
+                     const float* expected_output,
+                     std::initializer_list<int> expected_dims) {
+  TfLiteContext context;
+  TfLiteTensor tensors[3];
+  if (shape_tensor == nullptr) {
+    constexpr int inputs_size = 1;
+    constexpr int outputs_size = 1;
+    constexpr int tensors_size = inputs_size + outputs_size;
+    tensors[0] = *input_tensor;
+    tensors[1] = *output_tensor,
+    PopulateContext(tensors, tensors_size, &context);
+  } else {
+    constexpr int inputs_size = 2;
+    constexpr int outputs_size = 1;
+    constexpr int tensors_size = inputs_size + outputs_size;
+    tensors[0] = *input_tensor;
+    tensors[1] = *shape_tensor;
+    tensors[2] = *output_tensor;
+    PopulateContext(tensors, tensors_size, &context);
+  }
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_RESHAPE, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+  TfLiteReshapeParams builtin_data =
+      create_params(reinterpret_cast<int*>(output_tensor->dims));
+  const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, init_data, init_data_size);
+  }
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+  TfLiteNode node;
+  if (shape_tensor == nullptr) {
+    node.inputs = IntArrayFromInitializer({1, 0});
+    node.outputs = IntArrayFromInitializer({1, 1});
+  } else {
+    node.inputs = IntArrayFromInitializer({2, 0, 1});
+    node.outputs = IntArrayFromInitializer({1, 2});
+  }
+  node.temporaries = temporaries_array;
+  node.user_data = user_data;
+  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  if (expected_output_size == 0) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                            registration->invoke(&context, &node));
+    return;
+  }
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  const int output_dims_count = ElementCount(*output_tensor->dims);
+  switch (output_tensor->type) {
+    case kTfLiteFloat32:
+      for (int i = 0; i < expected_output_size; ++i) {
+        TF_LITE_MICRO_EXPECT_NEAR(expected_output[i], output_tensor->data.f[i],
+                                  1e-5f);
+      }
+      break;
+    case kTfLiteUInt8:
+      for (int i = 0; i < expected_output_size; ++i) {
+        TF_LITE_MICRO_EXPECT_NEAR(expected_output[i],
+                                  output_tensor->data.uint8[i], 1e-5f);
+      }
+      break;
+    case kTfLiteInt8:
+      for (int i = 0; i < expected_output_size; ++i) {
+        TF_LITE_MICRO_EXPECT_NEAR(expected_output[i],
+                                  output_tensor->data.int8[i], 1e-5f);
+      }
+      break;
+    default:
+      break;
+  }
+  TF_LITE_MICRO_EXPECT_EQ(expected_dims.size(), output_tensor->dims->size);
+  for (int i = 0; i < expected_dims.size(); ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_dims.begin()[i],
+                              output_tensor->dims->data[i], 1e-5f);
+  }
+}
+
+void TestReshapeTyped(TfLiteTensor* input_tensor,
+                      std::initializer_list<int> shape_dims_data,
+                      std::initializer_list<int32_t> shape_data,
+                      int* output_dims_data, TfLiteTensor* output_tensor,
+                      int expected_output_size, const float* expected_output,
+                      std::initializer_list<int> expected_dims) {
+  TestReshapeImpl(input_tensor, nullptr, output_tensor, expected_output_size,
+                  expected_output, expected_dims);
+  TfLiteIntArray* shape_dims = IntArrayFromInitializer(shape_dims_data);
+  auto shape_tensor = CreateInt32Tensor(shape_data, shape_dims, "shape_tensor");
+  TestReshapeImpl(input_tensor, &shape_tensor, output_tensor,
+                  expected_output_size, expected_output, expected_dims);
+  auto shape_const_tensor =
+      CreateInt32ConstTensor(shape_data, shape_dims, "shape_tensor");
+  TestReshapeImpl(input_tensor, &shape_const_tensor, output_tensor,
+                  expected_output_size, expected_output, expected_dims);
+}
+
+void TestReshape(std::initializer_list<int> input_dims_data,
+                 std::initializer_list<float> input_data,
+                 std::initializer_list<int> shape_dims_data,
+                 std::initializer_list<int32_t> shape_data,
+                 int* output_dims_data, float* output_data,
+                 std::initializer_list<float> expected_output,
+                 std::initializer_list<int> expected_dims) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  int expected_output_size = expected_output.size();
+  // Testing float input.
+  auto input_tensor = CreateFloatTensor(input_data, input_dims, "input_tensor");
+  auto output_tensor =
+      CreateFloatTensor(output_data, output_dims, "input_tensor");
+  TestReshapeTyped(&input_tensor, shape_dims_data, shape_data, output_dims_data,
+                   &output_tensor, expected_output_size,
+                   expected_output.begin(), expected_dims);
+  // Testing uint8 input.
+  float expected_uint8[16], expected_int8[16];
+  uint8_t input_uint8[16], output_uint8[16];
+  int8_t input_int8[16], output_int8[16];
+  float input_min = 0;
+  float input_max = 15.9375;
+  for (int i = 0; i < input_data.size(); ++i) {
+    input_uint8[i] = F2Q(input_data.begin()[i], input_min, input_max);
+  }
+  for (int i = 0; i < expected_output.size(); ++i) {
+    expected_uint8[i] = F2Q(expected_output.begin()[i], input_min, input_max);
+  }
+  input_tensor = CreateQuantizedTensor(input_uint8, input_dims, "input_tensor",
+                                       input_min, input_max);
+  output_tensor = CreateQuantizedTensor(output_uint8, output_dims,
+                                        "input_tensor", input_min, input_max);
+  TestReshapeTyped(&input_tensor, shape_dims_data, shape_data, output_dims_data,
+                   &output_tensor, expected_output_size, expected_uint8,
+                   expected_dims);
+}
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(MismatchedDimensions) {
+  float output_data[8];
+  int output_dims[3] = {2, 2, 1};
+  tflite::testing::TestReshape({4, 1, 2, 4, 1},  // input_dims
+                               {3},              // input_data
+                               {1, 2},           // shape_dims
+                               {2, 1},           // shape_data
+                               output_dims,      // output_dims
+                               output_data, {},  // expected_output
+                               {}                // expected_dims
+  );
+}
+
+TF_LITE_MICRO_TEST(TooManyDimensions) {
+  float output_data[2];
+  int output_dims[10] = {9, 1, 1, 1, 1, 1, 1, 1, 1, 2};
+  tflite::testing::TestReshape({9, 1, 1, 2, 1, 1, 1, 1, 1, 1},  // input_dims
+                               {3, 2},                          // input_data
+                               {1, 9},                          // shape_dims
+                               {1, 1, 1, 1, 1, 1, 1, 1, 2},     // shape_data
+                               output_dims,                     // output_dims
+                               output_data, {3, 2},         // expected_output
+                               {1, 1, 1, 1, 1, 1, 1, 1, 2}  // expected_dims
+  );
+}
+
+// Number of dimensions > 8 is accepted in micro since it does not use
+// TfLiteReshapeParams.
+TF_LITE_MICRO_TEST(TooManySpecialDimensions) {
+  float output_data[8];
+  int output_dims[5] = {4, -1, -1, 2, 4};
+  tflite::testing::TestReshape({4, 1, 2, 4, 1},  // input_dims
+                               {3},              // input_data
+                               {1, 4},           // shape_dims
+                               {-1, -1, 2, 4},   // shape_data
+                               output_dims,      // output_dims
+                               output_data, {},  // expected_output
+                               {}                // expected_dims
+  );
+}
+
+// Create the model with a 2x2 shape. Processing still works because the new
+// shape ends up being hardcoded as a flat vector.
+TF_LITE_MICRO_TEST(InvalidShape) {
+  using tflite::testing::CreateFloatTensor;
+  using tflite::testing::IntArrayFromInitializer;
+  using tflite::testing::IntArrayFromInts;
+  TfLiteIntArray* input_dims = IntArrayFromInitializer({3, 1, 2, 2});
+  auto input_data = {3.0f};
+  auto input_tensor = CreateFloatTensor(input_data, input_dims, "input_tensor");
+  float output_data[4];
+  int output_dims_data[6] = {2, 2, 1, 2, 2, 1};
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  auto output_tensor =
+      CreateFloatTensor(output_data, output_dims, "input_tensor");
+  tflite::testing::TestReshapeImpl(&input_tensor,          // input_tensor
+                                   nullptr,                // shape_tensor
+                                   &output_tensor, 0, {},  // expected_output
+                                   {}                      // expected_dims
+  );
+}
+
+TF_LITE_MICRO_TEST(RegularShapes) {
+  float output_data[8];
+  int output_dims[4] = {3, 2, 2, 2};
+  tflite::testing::TestReshape({4, 1, 2, 4, 1},           // input_dims
+                               {1, 2, 3, 4, 5, 6, 7, 8},  // input_data
+                               {1, 3},                    // shape_dims
+                               {2, 2, 2},                 // shape_data
+                               output_dims,               // output_dims
+                               output_data,
+                               {1, 2, 3, 4, 5, 6, 7, 8},  // expected_output
+                               {2, 2, 2}                  // expected_dims
+  );
+}
+
+TF_LITE_MICRO_TEST(WithStretchDimension) {
+  float output_data[8];
+  int output_dims[4] = {3, 2, 1, -1};
+  tflite::testing::TestReshape({4, 1, 2, 4, 1},           // input_dims
+                               {1, 2, 3, 4, 5, 6, 7, 8},  // input_data
+                               {1, 3},                    // shape_dims
+                               {2, 1, -1},                // shape_data
+                               output_dims,               // output_dims
+                               output_data,
+                               {1, 2, 3, 4, 5, 6, 7, 8},  // expected_output
+                               {2, 1, 4}                  // expected_dims
+  );
+}
+
+// Shape is specified as '[]', which is the modern way to represent scalar
+// input and output.
+TF_LITE_MICRO_TEST(ScalarOutput) {
+  float output_data[1];
+  int output_dims[1] = {0};
+  tflite::testing::TestReshape({1, 1},            // input_dims
+                               {3},               // input_data
+                               {0},               // shape_dims
+                               {},                // shape_data
+                               output_dims,       // output_dims
+                               output_data, {3},  // expected_output
+                               {}                 // expected_dims
+  );
+}
+
+// Some old models specify '[0]' as the new shape, indicating that both input
+// and output are scalars.
+TF_LITE_MICRO_TEST(LegacyScalarOutput) {
+  using tflite::testing::CreateFloatTensor;
+  using tflite::testing::IntArrayFromInitializer;
+  using tflite::testing::IntArrayFromInts;
+  TfLiteIntArray* input_dims = IntArrayFromInitializer({1, 1});
+  auto input_data = {3.0f};
+  auto input_tensor = CreateFloatTensor(input_data, input_dims, "input_tensor");
+  float output_data[1];
+  int output_dims_data[2] = {1, 0};
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  auto output_tensor =
+      CreateFloatTensor(output_data, output_dims, "input_tensor");
+  TfLiteIntArray* shape_dims = tflite::testing::IntArrayFromInitializer({1, 0});
+  auto shape_tensor =
+      tflite::testing::CreateInt32Tensor({0}, shape_dims, "shape_tensor");
+  tflite::testing::TestReshapeImpl(&input_tensor,          // input_tensor
+                                   &shape_tensor,          // shape_tensor
+                                   &output_tensor, 0, {},  // expected_output
+                                   {}                      // expected_dims
+  );
+  auto shape_const_tensor =
+      tflite::testing::CreateInt32ConstTensor({0}, shape_dims, "shape_tensor");
+  tflite::testing::TestReshapeImpl(&input_tensor,          // input_tensor
+                                   &shape_const_tensor,    // shape_tensor
+                                   &output_tensor, 0, {},  // expected_output
+                                   {}                      // expected_dims
+  );
+  float expected_ouput[1] = {3};
+  tflite::testing::TestReshapeImpl(&input_tensor,  // input_tensor
+                                   nullptr,        // shape_tensor
+                                   &output_tensor, 1,
+                                   expected_ouput,  // expected_output
+                                   {}               // expected_dims
+  );
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/softmax.cc b/tensorflow/lite/experimental/micro/kernels/softmax.cc
index 6d2d8b4..ff4ee43 100644
--- a/tensorflow/lite/experimental/micro/kernels/softmax.cc
+++ b/tensorflow/lite/experimental/micro/kernels/softmax.cc
@@ -42,7 +42,7 @@
                                     OpData* data) {
   if (input->type == kTfLiteUInt8) {
     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+    TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
 
     static const int kScaledDiffIntegerBits = 5;
 
diff --git a/tensorflow/lite/experimental/micro/micro_interpreter.cc b/tensorflow/lite/experimental/micro/micro_interpreter.cc
index 5997196..d614aba 100644
--- a/tensorflow/lite/experimental/micro/micro_interpreter.cc
+++ b/tensorflow/lite/experimental/micro/micro_interpreter.cc
@@ -136,8 +136,10 @@
 
     if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
       error_reporter_->Report(
-          "Found builtin operator %s with custom options.\n",
+          "Unsupported behavior: found builtin operator %s with custom "
+          "options.\n",
           EnumNameBuiltinOperator(op_type));
+      return kTfLiteError;
     }
     StackDataAllocator stack_data_allocator;
     const char* custom_data = nullptr;
diff --git a/tensorflow/lite/experimental/micro/testing/test_utils.h b/tensorflow/lite/experimental/micro/testing/test_utils.h
index 5130901..aba15d3 100644
--- a/tensorflow/lite/experimental/micro/testing/test_utils.h
+++ b/tensorflow/lite/experimental/micro/testing/test_utils.h
@@ -73,7 +73,7 @@
 }
 
 // Converts a float value into a signed thirty-two-bit quantized value.
-inline uint8_t F2Q32(const float value, const float min, const float max) {
+inline int32_t F2Q32(const float value, const float min, const float max) {
   return static_cast<int32_t>((value - ZeroPointFromMinMax<int32_t>(min, max)) /
                               ScaleFromMinMax<int32_t>(min, max));
 }
@@ -123,6 +123,25 @@
   return CreateFloatTensor(data.begin(), dims, name);
 }
 
+inline TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims,
+                                     const char* name) {
+  TfLiteTensor result;
+  result.type = kTfLiteBool;
+  result.data.b = const_cast<bool*>(data);
+  result.dims = dims;
+  result.params = {};
+  result.allocation_type = kTfLiteMemNone;
+  result.bytes = ElementCount(*dims) * sizeof(bool);
+  result.allocation = nullptr;
+  result.name = name;
+  return result;
+}
+
+inline TfLiteTensor CreateBoolTensor(std::initializer_list<bool> data,
+                                     TfLiteIntArray* dims, const char* name) {
+  return CreateBoolTensor(data.begin(), dims, name);
+}
+
 inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data,
                                           TfLiteIntArray* dims,
                                           const char* name, float min,
@@ -171,6 +190,29 @@
   return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
 }
 
+template <typename input_type = int32_t,
+          TfLiteType tensor_input_type = kTfLiteInt32>
+inline TfLiteTensor CreateTensor(const input_type* data, TfLiteIntArray* dims,
+                                 const char* name) {
+  TfLiteTensor result;
+  result.type = tensor_input_type;
+  result.data.raw = reinterpret_cast<char*>(const_cast<input_type*>(data));
+  result.dims = dims;
+  result.allocation_type = kTfLiteMemNone;
+  result.bytes = ElementCount(*dims) * sizeof(input_type);
+  result.allocation = nullptr;
+  result.name = name;
+  result.is_variable = true;
+  return result;
+}
+
+template <typename input_type = int32_t,
+          TfLiteType tensor_input_type = kTfLiteInt32>
+inline TfLiteTensor CreateTensor(std::initializer_list<input_type> data,
+                                 TfLiteIntArray* dims, const char* name) {
+  return CreateTensor<input_type, tensor_input_type>(data.begin(), dims, name);
+}
+
 // Do a simple string comparison for testing purposes, without requiring the
 // standard C library.
 inline int TestStrcmp(const char* a, const char* b) {
diff --git a/tensorflow/lite/experimental/micro/tools/ci_build/install_arduino_cli.sh b/tensorflow/lite/experimental/micro/tools/ci_build/install_arduino_cli.sh
index 3c7ebe3..f55c354 100755
--- a/tensorflow/lite/experimental/micro/tools/ci_build/install_arduino_cli.sh
+++ b/tensorflow/lite/experimental/micro/tools/ci_build/install_arduino_cli.sh
@@ -21,9 +21,8 @@
 cd /tmp
 
 rm -rf arduino-cli*
-curl -L -O "https://downloads.arduino.cc/arduino-cli/arduino-cli-latest-linux64.tar.bz2"
-tar xjf arduino-cli-latest-linux64.tar.bz2
-mv arduino-cli-*linux64 arduino-cli
+curl -L -O "https://github.com/arduino/arduino-cli/releases/download/0.4.0/arduino-cli_0.4.0_Linux_64bit.tar.gz"
+tar xzf arduino-cli_0.4.0_Linux_64bit.tar.gz
 
 /tmp/arduino-cli core update-index
-/tmp/arduino-cli core install arduino:sam
+/tmp/arduino-cli core install arduino:mbed
diff --git a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
index bb4a33f..c068176 100755
--- a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
+++ b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
@@ -23,15 +23,22 @@
 ARDUINO_HOME_DIR=${HOME}/Arduino
 ARDUINO_LIBRARIES_DIR=${ARDUINO_HOME_DIR}/libraries
 ARDUINO_CLI_TOOL=/tmp/arduino-cli
+# Necessary due to bug in arduino-cli that allows it to build files in pwd
+TEMP_BUILD_DIR=/tmp/tflite-arduino-build
 
 LIBRARY_ZIP=${1}
 
 rm -rf ${ARDUINO_LIBRARIES_DIR}
+rm -rf ${TEMP_BUILD_DIR}
 
 mkdir -p ${ARDUINO_HOME_DIR}/libraries
+mkdir -p ${TEMP_BUILD_DIR}
 
 unzip -q ${LIBRARY_ZIP} -d ${ARDUINO_LIBRARIES_DIR}
 
+# Change into this dir before running the tests
+cd ${TEMP_BUILD_DIR}
+
 for f in ${ARDUINO_LIBRARIES_DIR}/*/examples/*/*.ino; do
-  ${ARDUINO_CLI_TOOL} compile --fqbn arduino:sam:arduino_due_x $f
+  ${ARDUINO_CLI_TOOL} compile --fqbn arduino:mbed:nano33ble $f
 done
diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile
index f382892..5bca0f5 100644
--- a/tensorflow/lite/experimental/micro/tools/make/Makefile
+++ b/tensorflow/lite/experimental/micro/tools/make/Makefile
@@ -107,13 +107,17 @@
 tensorflow/lite/kernels/internal/common.h \
 tensorflow/lite/kernels/internal/compatibility.h \
 tensorflow/lite/kernels/internal/optimized/neon_check.h \
+tensorflow/lite/kernels/internal/reference/binary_function.h \
 tensorflow/lite/kernels/internal/reference/conv.h \
 tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
 tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
+tensorflow/lite/kernels/internal/reference/floor.h \
 tensorflow/lite/kernels/internal/reference/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/pooling.h \
 tensorflow/lite/kernels/internal/reference/prelu.h \
+tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
 tensorflow/lite/kernels/internal/reference/softmax.h \
+tensorflow/lite/kernels/internal/reference/arg_min_max.h \
 tensorflow/lite/kernels/internal/round.h \
 tensorflow/lite/kernels/internal/tensor_ctypes.h \
 tensorflow/lite/kernels/internal/types.h \
diff --git a/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc b/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
index b991bc6..87a6b0b 100644
--- a/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
+++ b/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
@@ -137,7 +137,7 @@
 	@mkdir -p $$(dir $$@)
 	@python tensorflow/lite/experimental/micro/tools/make/transform_arduino_source.py \
         --third_party_headers="$(4)" < $$< | \
-        sed -E 's/<string.h>/<string.h>\n#include <stdint.h>/g' > $$@
+        sed -E 's@#include <string.h>@//#include <string.h> /* Patched by helper_functions.inc for Arduino compatibility */@g' > $$@
 
 $(PRJDIR)$(2)/arduino/%: tensorflow/lite/experimental/micro/tools/make/templates/%
 	@mkdir -p $$(dir $$@)
diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/arduino_example.ino b/tensorflow/lite/experimental/micro/tools/make/templates/arduino_example.ino
index 02ebe5f..ac8813f 100644
--- a/tensorflow/lite/experimental/micro/tools/make/templates/arduino_example.ino
+++ b/tensorflow/lite/experimental/micro/tools/make/templates/arduino_example.ino
@@ -18,11 +18,25 @@
 // Include an empty header so that Arduino knows to build the TF Lite library.
 #include <TensorFlowLite.h>
 
+// TensorFlow Lite defines its own main function
 extern int tflite_micro_main(int argc, char* argv[]);
 
+// So the example works with or without a serial connection,
+// wait to see one for 5 seconds before giving up.
+void waitForSerial() {
+  int start = millis();
+  while(!Serial) {
+    int diff = millis() - start;
+    if (diff > 5000) break;
+  }
+}
+
+// Runs once when the program starts
 void setup() {
+  waitForSerial();
   tflite_micro_main(0, NULL);
 }
 
+// Leave the loop unused
 void loop() {
 }
\ No newline at end of file
diff --git a/tensorflow/lite/experimental/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/experimental/micro/tools/make/third_party_downloads.inc
index 42ecf3f..fb8bbf6 100644
--- a/tensorflow/lite/experimental/micro/tools/make/third_party_downloads.inc
+++ b/tensorflow/lite/experimental/micro/tools/make/third_party_downloads.inc
@@ -46,5 +46,5 @@
 KISSFFT_URL="https://github.com/mborgerding/kissfft/archive/v130.zip"
 KISSFFT_MD5="438ba1fef5783cc5f5f201395cc477ca"
 
-PERSON_MODEL_URL := "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data.tgz"
-PERSON_MODEL_MD5 := "dc0ffad71adb651fb7b2d472b6c901ef"
+PERSON_MODEL_URL := "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale.zip"
+PERSON_MODEL_MD5 := "cd1059dd1c94afadd59608202732ad63"
diff --git a/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py b/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py
index 3ce8617..913c330 100644
--- a/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py
+++ b/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py
@@ -21,7 +21,7 @@
 import tensorflow as tf
 
 from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import ops
 
 SAMPLE_RATE = 1000
 WINDOW_SIZE = 25
@@ -34,7 +34,10 @@
 
 class AudioFeatureGenerationTest(tf.test.TestCase):
 
-  @test_util.run_v1_only("b/120545219")
+  def setUp(self):
+    super(AudioFeatureGenerationTest, self).setUp()
+    ops.disable_eager_execution()
+
   def testSimple(self):
     with self.test_session():
       audio = tf.constant(
@@ -53,7 +56,6 @@
       self.assertAllEqual(filterbanks.eval(),
                           [[479, 425], [436, 378], [410, 350], [391, 325]])
 
-  @test_util.run_v1_only("b/120545219")
   def testSimpleFloatScaled(self):
     with self.test_session():
       audio = tf.constant(
@@ -75,7 +77,6 @@
                           [[7.484375, 6.640625], [6.8125, 5.90625],
                            [6.40625, 5.46875], [6.109375, 5.078125]])
 
-  @test_util.run_v1_only("b/120545219")
   def testStacking(self):
     with self.test_session():
       audio = tf.constant(
@@ -118,7 +119,6 @@
           [[479, 425, 479, 425, 436, 378], [479, 425, 436, 378, 410, 350],
            [436, 378, 410, 350, 391, 325], [410, 350, 391, 325, 391, 325]])
 
-  @test_util.run_v1_only("b/120545219")
   def testStackingDropFrame(self):
     with self.test_session():
       audio = tf.constant(
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
index d89e024..41af895 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
+++ b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
@@ -1,10 +1,10 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteObjC'
-  s.version          = '0.2.0'
+  s.version          = '1.14.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
-  s.source           = { :git => 'https://github.com/tensorflow/tensorflow.git', :commit => '37c101d' }
+  s.source           = { :git => 'https://github.com/tensorflow/tensorflow.git', :tag => "v#{s.version}" }
   s.summary          = 'TensorFlow Lite for Objective-C'
   s.description      = <<-DESC
 
diff --git a/tensorflow/lite/experimental/resource_variable/BUILD b/tensorflow/lite/experimental/resource_variable/BUILD
new file mode 100644
index 0000000..af2ed19
--- /dev/null
+++ b/tensorflow/lite/experimental/resource_variable/BUILD
@@ -0,0 +1,17 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "resource_variable",
+    srcs = [
+        "resource_variable.cc",
+    ],
+    hdrs = [
+        "resource_variable.h",
+    ],
+    deps = [
+        "//tensorflow/lite/c:c_api_internal",
+    ],
+)
diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.cc b/tensorflow/lite/experimental/resource_variable/resource_variable.cc
new file mode 100644
index 0000000..502ca27
--- /dev/null
+++ b/tensorflow/lite/experimental/resource_variable/resource_variable.cc
@@ -0,0 +1,78 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
+
+#include <cstdlib>
+#include <cstring>
+#include <map>
+
+namespace tflite {
+
+ResourceVariable::ResourceVariable() {
+  memset(&tensor_, 0, sizeof(TfLiteTensor));
+}
+
+ResourceVariable::ResourceVariable(ResourceVariable&& other) {
+  tensor_ = other.tensor_;
+  is_initialized_ = other.is_initialized_;
+
+  memset(&other.tensor_, 0, sizeof(TfLiteTensor));
+  other.is_initialized_ = false;
+}
+
+ResourceVariable::~ResourceVariable() {
+  if (is_initialized_) {
+    free(tensor_.data.raw);
+    if (tensor_.dims) {
+      TfLiteIntArrayFree(tensor_.dims);
+    }
+  }
+}
+
+TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) {
+  // Save the old allocated resources and attributes that we might use.
+  char* old_raw = tensor_.data.raw;
+  size_t old_bytes = tensor_.bytes;
+  TfLiteIntArray* old_dims = tensor_.dims;
+
+  // Copy primitive parameters.
+  memset(&tensor_, 0, sizeof(tensor_));
+  tensor_.allocation_type = kTfLiteDynamic;
+  tensor_.type = tensor->type;
+  tensor_.params = tensor->params;
+  tensor_.quantization = tensor->quantization;
+
+  // Copy old shape if possible otherwise create a new one.
+  if (TfLiteIntArrayEqual(old_dims, tensor->dims)) {
+    tensor_.dims = old_dims;
+  } else {
+    TfLiteIntArrayFree(old_dims);
+    tensor_.dims = TfLiteIntArrayCopy(tensor->dims);
+  }
+
+  // Reuse the same buffer if possible otherwise allocate a new one.
+  tensor_.data.raw = old_raw;
+  if (old_bytes != tensor->bytes) {
+    TfLiteTensorRealloc(tensor->bytes, &tensor_);
+  }
+
+  memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes);
+  is_initialized_ = true;
+
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.h b/tensorflow/lite/experimental/resource_variable/resource_variable.h
new file mode 100644
index 0000000..6a93848
--- /dev/null
+++ b/tensorflow/lite/experimental/resource_variable/resource_variable.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
+
+#include <unordered_map>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+
+namespace tflite {
+
+/// WARNING: Experimental interface, subject to change.
+// A resource variable class. It's similar to TensorFlow Resource
+// Variable, but it's identified with int32 ID in TFLite (instead of
+// using Resource handle like TensorFlow).
+//
+// TODO(b/137042749): TFLite converter cannot convert variables yet.
+// Variable functionalities are only tested with unit tests now.
+class ResourceVariable {
+ public:
+  ResourceVariable();
+  ResourceVariable(ResourceVariable&& other);
+
+  ResourceVariable(const ResourceVariable&) = delete;
+  ResourceVariable& operator=(const ResourceVariable&) = delete;
+
+  ~ResourceVariable();
+
+  // Assigns data from a tensor. Copies its type, shape and data over.
+  TfLiteStatus AssignFrom(const TfLiteTensor* tensor);
+
+  // Get the data tensor stored in the resource variable.
+  // Returns `nullptr` if the variable is never initialized by calling
+  // `AssignFrom`.
+  TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; }
+
+ private:
+  // The tensor (and its buffer stored in `tensor_.data` is fully owned by
+  // the `ResourceVariable` object.
+  TfLiteTensor tensor_;
+  // True if `AssignFrom` function is every called.
+  // False if and only if `tensor_` is filled with zeros.
+  bool is_initialized_ = false;
+};
+
+using ResourceVariableMap = std::unordered_map<int, ResourceVariable>;
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
diff --git a/tensorflow/lite/experimental/ruy/BUILD b/tensorflow/lite/experimental/ruy/BUILD
index 6c75783..75fdc84 100644
--- a/tensorflow/lite/experimental/ruy/BUILD
+++ b/tensorflow/lite/experimental/ruy/BUILD
@@ -7,6 +7,25 @@
 load(":ruy_test.bzl", "ruy_benchmark", "ruy_benchmark_opt_sets", "ruy_test")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
 
+# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support
+#    ARM32 without NEON then we'll implement runtime detection and dispatch at that point.
+# 2. Explicitly pass -O3 on mobile configs where just "-c opt" means "optimize for code size".
+#    We would want to only do that when compilation_mode is "opt", but limitations of
+#    the "select" keyword (no nested selects, no AND boolean) seem to make that difficult
+#    at the moment. For debugging purposes, this can be overridded on the command line, e.g.
+#      bazel build -c dbg --copt=-O0 ...
+RUY_COPTS = select({
+    "//tensorflow:android_arm64": [
+        "-O3",
+    ],
+    "//tensorflow:android_arm": [
+        "-O3",
+        "-mfpu=neon",
+    ],
+    "//conditions:default": [
+    ],
+})
+
 package(
     default_visibility = ["//visibility:private"],
     licenses = ["notice"],  # Apache 2.0
@@ -15,28 +34,33 @@
 cc_library(
     name = "platform",
     hdrs = ["platform.h"],
+    copts = RUY_COPTS,
 )
 
 cc_library(
     name = "check_macros",
     hdrs = ["check_macros.h"],
+    copts = RUY_COPTS,
     deps = ["//tensorflow/lite/kernels/internal:compatibility"],
 )
 
 cc_library(
     name = "opt_set",
     hdrs = ["opt_set.h"],
+    copts = RUY_COPTS,
 )
 
 cc_library(
     name = "time",
     hdrs = ["time.h"],
+    copts = RUY_COPTS,
 )
 
 cc_library(
     name = "wait",
     srcs = ["wait.cc"],
     hdrs = ["wait.h"],
+    copts = RUY_COPTS,
     deps = [":time"],
 )
 
@@ -52,6 +76,7 @@
 cc_library(
     name = "size_util",
     hdrs = ["size_util.h"],
+    copts = RUY_COPTS,
     deps = [":check_macros"],
 )
 
@@ -63,6 +88,7 @@
     hdrs = [
         "tune.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":opt_set",
         ":platform",
@@ -95,6 +121,7 @@
     hdrs = [
         "allocator.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":check_macros",
         ":size_util",
@@ -111,6 +138,13 @@
 )
 
 cc_library(
+    name = "side_pair",
+    hdrs = ["side_pair.h"],
+    copts = RUY_COPTS,
+    deps = [":check_macros"],
+)
+
+cc_library(
     name = "block_map",
     srcs = [
         "block_map.cc",
@@ -118,9 +152,11 @@
     hdrs = [
         "block_map.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":check_macros",
         ":opt_set",
+        ":side_pair",
         ":size_util",
         "@gemmlowp//:profiler",
     ],
@@ -134,6 +170,7 @@
     hdrs = [
         "blocking_counter.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":check_macros",
         ":wait",
@@ -148,6 +185,7 @@
     hdrs = [
         "thread_pool.h",
     ],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [
         ":blocking_counter",
@@ -164,12 +202,14 @@
     hdrs = [
         "detect_dotprod.h",
     ],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
 )
 
 cc_library(
     name = "path",
     hdrs = ["path.h"],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [
         ":platform",
@@ -185,10 +225,11 @@
     hdrs = [
         "trace.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":block_map",
         ":check_macros",
-        ":common",
+        ":side_pair",
         ":time",
     ],
 )
@@ -201,6 +242,7 @@
     hdrs = [
         "context.h",
     ],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [
         ":allocator",
@@ -216,6 +258,7 @@
 cc_library(
     name = "matrix",
     hdrs = ["matrix.h"],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [":check_macros"],
 )
@@ -223,6 +266,7 @@
 cc_library(
     name = "spec",
     hdrs = ["spec.h"],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [":matrix"],
 )
@@ -230,6 +274,7 @@
 cc_library(
     name = "internal_matrix",
     hdrs = ["internal_matrix.h"],
+    copts = RUY_COPTS,
     deps = [
         ":check_macros",
         ":common",
@@ -243,6 +288,7 @@
     hdrs = [
         "common.h",
     ],
+    copts = RUY_COPTS,
     deps = [
         ":check_macros",
         ":matrix",
@@ -257,16 +303,24 @@
     srcs = [
         "kernel_arm32.cc",
         "kernel_arm64.cc",
+        "kernel_avx512.cc",
     ],
     hdrs = [
         "kernel.h",
+        "kernel_arm.h",
+        "kernel_common.h",
+        "kernel_x86.h",
     ],
+    copts = RUY_COPTS,
     deps = [
+        ":check_macros",
         ":common",
         ":internal_matrix",
+        ":matrix",
         ":opt_set",
         ":path",
         ":platform",
+        ":side_pair",
         ":size_util",
         ":spec",
         ":tune",
@@ -278,38 +332,60 @@
 cc_library(
     name = "pack",
     srcs = [
-        "pack.cc",
+        "pack_arm.cc",
+        "pack_avx512.cc",
     ],
     hdrs = [
         "pack.h",
+        "pack_arm.h",
+        "pack_common.h",
+        "pack_x86.h",
     ],
+    copts = RUY_COPTS,
     deps = [
+        ":check_macros",
         ":common",
         ":internal_matrix",
+        ":matrix",
         ":opt_set",
         ":path",
         ":platform",
-        ":spec",
         ":tune",
         "@gemmlowp//:profiler",
     ],
 )
 
 cc_library(
+    name = "trmul_params",
+    hdrs = ["trmul_params.h"],
+    copts = RUY_COPTS,
+    deps = [
+        ":internal_matrix",
+        ":side_pair",
+        ":tune",
+    ],
+)
+
+cc_library(
     name = "trmul",
     srcs = ["trmul.cc"],
     hdrs = ["trmul.h"],
+    copts = RUY_COPTS,
     deps = [
         ":allocator",
         ":block_map",
+        ":check_macros",
         ":common",
         ":context",
         ":internal_matrix",
-        ":kernel",
+        ":matrix",
         ":opt_set",
-        ":pack",
+        ":side_pair",
+        ":size_util",
+        ":spec",
         ":thread_pool",
         ":trace",
+        ":trmul_params",
         ":tune",
         "@gemmlowp//:profiler",
     ],
@@ -326,16 +402,23 @@
         "ruy.h",
         "ruy_advanced.h",
     ],
+    copts = RUY_COPTS,
     visibility = ruy_visibility(),
     deps = [
         ":check_macros",
         ":common",
         ":context",
+        ":internal_matrix",
+        ":kernel",
         ":matrix",
+        ":opt_set",
+        ":pack",
         ":path",
+        ":side_pair",
         ":size_util",
         ":spec",
         ":trmul",
+        ":trmul_params",
         ":tune",
         "@gemmlowp//:profiler",
     ],
@@ -346,7 +429,12 @@
     name = "example",
     srcs = ["example.cc"],
     deps = [
+        ":context",
+        ":internal_matrix",
+        ":matrix",
+        ":path",
         ":ruy",
+        ":spec",
     ],
 )
 
@@ -355,7 +443,12 @@
     name = "example_advanced",
     srcs = ["example_advanced.cc"],
     deps = [
+        ":context",
+        ":internal_matrix",
+        ":matrix",
+        ":path",
         ":ruy",
+        ":spec",
     ],
 )
 
@@ -365,6 +458,7 @@
     testonly = True,
     srcs = ["pmu.cc"],
     hdrs = ["pmu.h"],
+    copts = RUY_COPTS,
     deps = [":check_macros"],
 )
 
@@ -373,6 +467,7 @@
     name = "test_lib",
     testonly = True,
     hdrs = ["test.h"],
+    copts = RUY_COPTS,
     # need defines, not copts, because it's controlling a header, test.h
     defines = ruy_test_ext_defines(),
     linkopts = select({
@@ -380,8 +475,10 @@
         "//conditions:default": ["-lm"],
     }),
     deps = [
+        ":matrix",
         ":pmu",
         ":ruy",
+        ":spec",
         ":time",
         "@com_google_googletest//:gtest",
         ":platform",
@@ -391,6 +488,7 @@
 ruy_benchmark(
     name = "benchmark",
     srcs = ["benchmark.cc"],
+    copts = RUY_COPTS,
     lhs_rhs_accum_dst = [
         ("f32", "f32", "f32", "f32"),
         ("u8", "u8", "i32", "u8"),
@@ -404,6 +502,7 @@
 ruy_test(
     name = "test_fast",
     srcs = ["test_fast.cc"],
+    copts = RUY_COPTS,
     lhs_rhs_accum_dst = [
         ("f32", "f32", "f32", "f32"),
         ("f64", "f32", "f64", "f32"),
@@ -419,6 +518,7 @@
 ruy_test(
     name = "test_slow",
     srcs = ["test_slow.cc"],
+    copts = RUY_COPTS,
     lhs_rhs_accum_dst = [
         ("f32", "f32", "f32", "f32"),
         ("u8", "u8", "i32", "u8"),
@@ -432,6 +532,7 @@
 ruy_test(
     name = "test_special_specs",
     srcs = ["test_special_specs.cc"],
+    copts = RUY_COPTS,
     lhs_rhs_accum_dst = [
         ("f32", "f32", "f32", "f32"),
         ("u8", "u8", "i32", "u8"),
@@ -442,6 +543,7 @@
 ruy_benchmark_opt_sets(
     name = "benchmark_opt_set",
     srcs = ["benchmark.cc"],
+    copts = RUY_COPTS,
     lhs_rhs_accum_dst = [
         ("f32", "f32", "f32", "f32"),
         ("u8", "u8", "i32", "u8"),
diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/allocator.h
index ef1db4d..aabaf61 100644
--- a/tensorflow/lite/experimental/ruy/allocator.h
+++ b/tensorflow/lite/experimental/ruy/allocator.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_
 
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <vector>
diff --git a/tensorflow/lite/experimental/ruy/allocator_test.cc b/tensorflow/lite/experimental/ruy/allocator_test.cc
index 7006b0d..4bc9956 100644
--- a/tensorflow/lite/experimental/ruy/allocator_test.cc
+++ b/tensorflow/lite/experimental/ruy/allocator_test.cc
@@ -15,8 +15,6 @@
 
 #include "tensorflow/lite/experimental/ruy/allocator.h"
 
-#include <cstdlib>
-
 #include <gtest/gtest.h>
 
 namespace ruy {
diff --git a/tensorflow/lite/experimental/ruy/benchmark.cc b/tensorflow/lite/experimental/ruy/benchmark.cc
index 7d05579..b1db2c0 100644
--- a/tensorflow/lite/experimental/ruy/benchmark.cc
+++ b/tensorflow/lite/experimental/ruy/benchmark.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdio>
 #include <cstdlib>
 #include <string>
 
diff --git a/tensorflow/lite/experimental/ruy/block_map.cc b/tensorflow/lite/experimental/ruy/block_map.cc
index 7405580..bb74c12 100644
--- a/tensorflow/lite/experimental/ruy/block_map.cc
+++ b/tensorflow/lite/experimental/ruy/block_map.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/lite/experimental/ruy/block_map.h"
 
+#include <algorithm>
+#include <cstdint>
+
 #include "profiling/instrumentation.h"
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/opt_set.h"
@@ -22,46 +25,52 @@
 
 namespace ruy {
 
-void GetBlockByIndex(const BlockMap& block_map, std::uint32_t index,
-                     std::uint16_t* block_r, std::uint16_t* block_c) {
+void GetBlockByIndex(const BlockMap& block_map, int index,
+                     SidePair<int>* block) {
   gemmlowp::ScopedProfilingLabel label("GetBlockByIndex");
-  std::uint16_t rectr =
-      index & ((1 << block_map.rows_rectangularness_log2) - 1);
-  std::uint16_t rectc =
-      index & ((1 << block_map.cols_rectangularness_log2) - 1);
+  const std::uint32_t index_u32 = index;
+  const std::uint32_t rectr =
+      index_u32 & ((1u << block_map.rectangularness_log2[Side::kLhs]) - 1);
+  const std::uint32_t rectc =
+      index_u32 & ((1u << block_map.rectangularness_log2[Side::kRhs]) - 1);
 
-  std::uint16_t n1 = index >> (block_map.rows_rectangularness_log2 +
-                               block_map.cols_rectangularness_log2);
-  RUY_DCHECK_EQ(index, (n1 << (block_map.rows_rectangularness_log2 +
-                               block_map.cols_rectangularness_log2)) +
-                           rectr + rectc);
+  const std::uint32_t n1 =
+      index_u32 >> (block_map.rectangularness_log2[Side::kLhs] +
+                    block_map.rectangularness_log2[Side::kRhs]);
+  RUY_DCHECK_EQ(index_u32,
+                (n1 << (block_map.rectangularness_log2[Side::kLhs] +
+                        block_map.rectangularness_log2[Side::kRhs])) +
+                    rectr + rectc);
 
-  std::uint16_t br, bc;
+  std::uint32_t br, bc;
   if (block_map.traversal_order == BlockMapTraversalOrder::kLinear) {
-    br = n1 & ((1 << block_map.num_blocks_base_log2) - 1);
+    br = n1 & ((1u << block_map.num_blocks_base_log2) - 1);
     bc = n1 >> block_map.num_blocks_base_log2;
   } else {
     // Decode fractal z-order
-    std::uint16_t n2 =
-        (n1 & 0x9999) | ((n1 & 0x4444) >> 1) | ((n1 & 0x2222) << 1);
-    std::uint16_t n4 =
-        (n2 & 0xc3c3) | ((n2 & 0x3030) >> 2) | ((n2 & 0x0c0c) << 2);
-    std::uint16_t n8 =
-        (n4 & 0xf00f) | ((n4 & 0x0f00) >> 4) | ((n4 & 0x00f0) << 4);
-    br = n8 & 0xff;
-    bc = n8 >> 8;
+    const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
+                             ((n1 & 0x22222222u) << 1);
+    const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
+                             ((n2 & 0x0c0c0c0cu) << 2);
+    const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
+                             ((n4 & 0x00f000f0u) << 4);
+    const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
+                              ((n8 & 0x0000ff00u) << 8);
+
+    br = n16 & 0xffff;
+    bc = n16 >> 16;
     if (block_map.traversal_order == BlockMapTraversalOrder::kFractalU) {
       // Change fractal z-order to u-order
       br ^= bc;
     }
   }
 
-  br = (br << block_map.rows_rectangularness_log2) + rectr;
-  bc = (bc << block_map.cols_rectangularness_log2) + rectc;
+  br = (br << block_map.rectangularness_log2[Side::kLhs]) + rectr;
+  bc = (bc << block_map.rectangularness_log2[Side::kRhs]) + rectc;
 
   // Store
-  *block_r = br;
-  *block_c = bc;
+  (*block)[Side::kLhs] = br;
+  (*block)[Side::kRhs] = bc;
 }
 
 namespace {
@@ -85,6 +94,8 @@
   gemmlowp::ScopedProfilingLabel label("MakeBlockMap");
   RUY_DCHECK_GE(rows, kernel_rows);
   RUY_DCHECK_GE(cols, kernel_cols);
+  RUY_DCHECK_EQ(rows % kernel_rows, 0);
+  RUY_DCHECK_EQ(cols % kernel_cols, 0);
 
   block_map->traversal_order = BlockMapTraversalOrder::kLinear;
   if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL) &&
@@ -161,86 +172,62 @@
                         ceil_log2(std::max(lhs_scalar_size, rhs_scalar_size)));
   l1_size_log2 = std::max(l1_size_log2, kernel_width_log2);
   l1_size_log2 = std::min(l1_size_log2, size_floor_log2);
-  l1_size_log2 = std::max(l1_size_log2, size_floor_log2 - 8);
 
   int num_blocks_base_log2 = size_floor_log2 - l1_size_log2;
   RUY_DCHECK_GE(num_blocks_base_log2, 0);
-  RUY_DCHECK_LE(num_blocks_base_log2, 8);
-  if (num_blocks_base_log2 == 0) {
-    if ((rows % kernel_rows) || (cols % kernel_cols)) {
-      num_blocks_base_log2 = 1;
-    }
-  }
-  RUY_DCHECK_LE(num_blocks_base_log2 + rows_rectangularness_log2, 16);
-  RUY_DCHECK_LE(num_blocks_base_log2 + cols_rectangularness_log2, 16);
-
-  int rows_rounded_up = round_up_pot(rows, kernel_rows);
-  int cols_rounded_up = round_up_pot(cols, kernel_cols);
 
   const int num_blocks_of_rows_log2 =
       num_blocks_base_log2 + rows_rectangularness_log2;
   const int num_blocks_of_cols_log2 =
       num_blocks_base_log2 + cols_rectangularness_log2;
 
-  std::uint16_t smallr =
-      round_down_pot(rows_rounded_up >> num_blocks_of_rows_log2, kernel_rows);
-  std::uint16_t smallc =
-      round_down_pot(cols_rounded_up >> num_blocks_of_cols_log2, kernel_cols);
-  std::uint16_t missr =
-      round_up_pot(rows_rounded_up - (smallr << num_blocks_of_rows_log2),
-                   kernel_rows) /
-      kernel_rows;
-  std::uint16_t missc =
-      round_up_pot(cols_rounded_up - (smallc << num_blocks_of_cols_log2),
-                   kernel_cols) /
-      kernel_cols;
+  const int smallr =
+      round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
+  const int smallc =
+      round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
+  const int missr =
+      round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >>
+      floor_log2(kernel_rows);
+  const int missc =
+      round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >>
+      floor_log2(kernel_cols);
 
-  block_map->rows = rows;
-  block_map->cols = cols;
-  block_map->kernel_rows = kernel_rows;
-  block_map->kernel_cols = kernel_cols;
+  block_map->dims[Side::kLhs] = rows;
+  block_map->dims[Side::kRhs] = cols;
+  block_map->kernel_dims[Side::kLhs] = kernel_rows;
+  block_map->kernel_dims[Side::kRhs] = kernel_cols;
   block_map->num_blocks_base_log2 = num_blocks_base_log2;
-  block_map->rows_rectangularness_log2 = rows_rectangularness_log2;
-  block_map->cols_rectangularness_log2 = cols_rectangularness_log2;
-  block_map->smallr = smallr;
-  block_map->smallc = smallc;
-  block_map->missr = missr;
-  block_map->missc = missc;
+  block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
+  block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
+  block_map->small_block_dims[Side::kLhs] = smallr;
+  block_map->small_block_dims[Side::kRhs] = smallc;
+  block_map->large_blocks[Side::kLhs] = missr;
+  block_map->large_blocks[Side::kRhs] = missc;
 }
 
-void GetBlockMatrixCoords(const BlockMap& block_map, std::uint16_t block_r,
-                          std::uint16_t block_c, int* start_r, int* start_c,
-                          int* end_r, int* end_c) {
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+                          int* start, int* end) {
   gemmlowp::ScopedProfilingLabel label("GetBlockMatrixCoords");
-  int sr = block_r * block_map.smallr +
-           std::min(block_r, block_map.missr) * block_map.kernel_rows;
-  int er = sr + block_map.smallr +
-           (block_r < block_map.missr) * block_map.kernel_rows;
-  int sc = block_c * block_map.smallc +
-           std::min(block_c, block_map.missc) * block_map.kernel_cols;
-  int ec = sc + block_map.smallc +
-           (block_c < block_map.missc) * block_map.kernel_cols;
-  sc = round_down_pot(sc, block_map.kernel_cols);
-  ec = round_down_pot(ec, block_map.kernel_cols);
-  sr = round_down_pot(sr, block_map.kernel_rows);
-  er = round_down_pot(er, block_map.kernel_rows);
+  *start = block * block_map.small_block_dims[side] +
+           std::min(block, block_map.large_blocks[side]) *
+               block_map.kernel_dims[side];
+  *end =
+      *start + block_map.small_block_dims[side] +
+      (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
 
-  ec = std::min(ec, block_map.cols);
-  er = std::min(er, block_map.rows);
-  sc = std::max(0, ec - round_up_pot(ec - sc, block_map.kernel_cols));
-  sr = std::max(0, er - round_up_pot(er - sr, block_map.kernel_rows));
+  RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
+  RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
+  RUY_DCHECK_LE(*end, block_map.dims[side]);
+  RUY_DCHECK_LT(*start, *end);
+  RUY_DCHECK_GE(*start, 0);
+}
 
-  *start_c = sc;
-  *end_c = ec;
-  *start_r = sr;
-  *end_r = er;
-
-  RUY_DCHECK_LE(ec, block_map.cols);
-  RUY_DCHECK_LE(er, block_map.rows);
-  RUY_DCHECK_LT(sc, ec);
-  RUY_DCHECK_LT(sr, er);
-  RUY_DCHECK_GE(sc, 0);
-  RUY_DCHECK_GE(sr, 0);
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+                          SidePair<int>* start, SidePair<int>* end) {
+  for (Side side : {Side::kLhs, Side::kRhs}) {
+    GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
+                         &(*end)[side]);
+  }
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/block_map.h b/tensorflow/lite/experimental/ruy/block_map.h
index b0567ea..b51a1f5 100644
--- a/tensorflow/lite/experimental/ruy/block_map.h
+++ b/tensorflow/lite/experimental/ruy/block_map.h
@@ -16,7 +16,7 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_
 
-#include <cstdint>
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
 
 namespace ruy {
 
@@ -82,30 +82,22 @@
   // The order in which to traverse the matrix of which this BlockMap represents
   // a tiling (hereafter "the matrix").
   BlockMapTraversalOrder traversal_order;
-  // The number of rows in the matrix.
-  int rows;
-  // The number of columns in the matrix.
-  int cols;
+  // The dimensions of the block_map, that is, of the destination
+  // matrix rounded up to next multiples of kernel_dims.
+  SidePair<int> dims;
   // Log2 of the minimum number of subdivisions of the grid along either axis.
   int num_blocks_base_log2;
-  // Log2 of the additional subdivision of the rows axis.
-  int rows_rectangularness_log2;
-  // Log2 of the additional subdivision of the columns axis.
-  int cols_rectangularness_log2;
-  // Requested alignment of the subdivions grid along the rows axis.
-  int kernel_rows;
-  // Requested alignment of the subdivions grid along the columns axis.
-  int kernel_cols;
-  // Internal helper. Minimum number of rows in each block.
-  std::uint16_t smallr;
-  // Internal helper. Minimum number of columns in each block.
-  std::uint16_t smallc;
-  // Internal helper. Number of rows that would be missed at the end if
-  // all blocks had exactly `smallr` rows.
-  std::uint16_t missr;
-  // Internal helper. Number of columns that would be missed at the end if
-  // all blocks had exactly `smallc` columns.
-  std::uint16_t missc;
+  // Log2 of the additional subdivision of the rows/columns axis.
+  SidePair<int> rectangularness_log2;
+  // Requested alignment of the subdivisions of the grid along the rows/columns
+  // axis.
+  SidePair<int> kernel_dims;
+  // Internal helper. Minimum number of rows/columns in each block.
+  SidePair<int> small_block_dims;
+  // Internal helper. Number of blocks along each dimension that need to have
+  // their size in that dimension be given by (small_block_dims + kernel_dims)
+  // instead of just small_block_dims.
+  SidePair<int> large_blocks;
 };
 
 // Create a BlockMap suitable for tiling the destination matrix in a
@@ -114,28 +106,28 @@
                   int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                   int cache_friendly_traversal_threshold, BlockMap* block_map);
 
-// Maps an integer index to a (block_r, block_c) block position in the grid.
-void GetBlockByIndex(const BlockMap& block_map, std::uint32_t index,
-                     std::uint16_t* block_r, std::uint16_t* block_c);
+// Maps an integer index to a block position in the grid.
+void GetBlockByIndex(const BlockMap& block_map, int index,
+                     SidePair<int>* block);
 
-// Given a (block_r, block_c) block position in the grid, returns its actual
+// Given a block position in the grid, returns its actual
+// position in the matrix that the BlockMap refers to in the dimension
+// referred to by `side`: along rows if side==kLhs, along columns if
+// side==kRhs.
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+                          int* start, int* end);
+
+// Given a block position in the grid, returns its actual
 // position in the matrix that the BlockMap refers to in terms of
-// actual row/column indices: starting at row start_r and column start_c,
-// ending at row (end_r - 1) and column (end_c - 1).
-void GetBlockMatrixCoords(const BlockMap& block_map, std::uint16_t block_r,
-                          std::uint16_t block_c, int* start_r, int* start_c,
-                          int* end_r, int* end_c);
+// actual row/column indices.
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+                          SidePair<int>* start, SidePair<int>* end);
 
-// Returns the number of grid subdivisions along the rows dimension.
-inline std::uint16_t NumBlocksOfRows(const BlockMap& block_map) {
+// Returns the number of grid subdivisions along the rows dimension (if
+// side == kLhs) or columns dimension (if side == kRhs).
+inline int NumBlocksPerSide(Side side, const BlockMap& block_map) {
   return 1 << (block_map.num_blocks_base_log2 +
-               block_map.rows_rectangularness_log2);
-}
-
-// Returns the number of grid subdivisions along the columns dimension.
-inline std::uint16_t NumBlocksOfCols(const BlockMap& block_map) {
-  return 1 << (block_map.num_blocks_base_log2 +
-               block_map.cols_rectangularness_log2);
+               block_map.rectangularness_log2[side]);
 }
 
 // Returns the overall number of blocks in
@@ -145,10 +137,10 @@
 // Note that it is always true that
 //   NumBlocks == NumBlocksOfRows * NumBlocksOfCols
 // because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0.
-inline std::uint32_t NumBlocks(const BlockMap& block_map) {
+inline int NumBlocks(const BlockMap& block_map) {
   return 1 << (2 * block_map.num_blocks_base_log2 +
-               block_map.rows_rectangularness_log2 +
-               block_map.cols_rectangularness_log2);
+               block_map.rectangularness_log2[Side::kLhs] +
+               block_map.rectangularness_log2[Side::kRhs]);
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/blocking_counter.cc b/tensorflow/lite/experimental/ruy/blocking_counter.cc
index ac8a328..97b096d 100644
--- a/tensorflow/lite/experimental/ruy/blocking_counter.cc
+++ b/tensorflow/lite/experimental/ruy/blocking_counter.cc
@@ -15,9 +15,6 @@
 
 #include "tensorflow/lite/experimental/ruy/blocking_counter.h"
 
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <mutex>               // NOLINT(build/c++11)
-
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/wait.h"
 
diff --git a/tensorflow/lite/experimental/ruy/common.h b/tensorflow/lite/experimental/ruy/common.h
index 9a59681..66bb4c5 100644
--- a/tensorflow/lite/experimental/ruy/common.h
+++ b/tensorflow/lite/experimental/ruy/common.h
@@ -18,7 +18,6 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_
 
-#include <atomic>
 #include <limits>
 #include <type_traits>
 
@@ -28,10 +27,6 @@
 #include "tensorflow/lite/experimental/ruy/path.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
 
-#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32))
-#include <arm_neon.h>
-#endif
-
 #if RUY_OPT_ENABLED(RUY_OPT_PREFETCH)
 #define RUY_PREFETCH(X) X
 #else
@@ -56,20 +51,6 @@
   return const_cast<void*>(static_cast<const void*>(p));
 }
 
-// We need this where we have multiple threads potentially writing concurrently
-// to the same memory location. That is currently the case for Pack (see
-// the comment in TrMulTask where Pack is called) and in tracing.
-//
-// This is a strict-aliasing violation. For nicer things, see C++20 atomic_ref
-// and the defunct N4013. (Thanks to hboehm@).
-template <typename T>
-void relaxed_atomic_store(T* ptr, T value) {
-  static_assert(sizeof(std::atomic<T>) == sizeof(T), "");
-  std::atomic<T>* atomic = reinterpret_cast<std::atomic<T>*>(ptr);
-  RUY_DCHECK(atomic->is_lock_free());
-  atomic->store(value, std::memory_order_relaxed);
-}
-
 template <typename Scalar>
 Scalar SymmetricZeroPoint() {
   if (std::is_floating_point<Scalar>::value) {
diff --git a/tensorflow/lite/experimental/ruy/context.h b/tensorflow/lite/experimental/ruy/context.h
index 194e0af..3ca6a63 100644
--- a/tensorflow/lite/experimental/ruy/context.h
+++ b/tensorflow/lite/experimental/ruy/context.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_
 
+#include <cstddef>
 #include <memory>
 #include <vector>
 
diff --git a/tensorflow/lite/experimental/ruy/detect_dotprod.cc b/tensorflow/lite/experimental/ruy/detect_dotprod.cc
index 5aa1e30..35c812e 100644
--- a/tensorflow/lite/experimental/ruy/detect_dotprod.cc
+++ b/tensorflow/lite/experimental/ruy/detect_dotprod.cc
@@ -78,12 +78,12 @@
 
 #include <setjmp.h>
 #include <signal.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
 #include <unistd.h>
 
-#include <mutex>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <mutex>  // NOLINT(build/c++11)
 
 // Intentionally keep checking for __linux__ here in case we want to
 // extend RUY_IMPLEMENT_DETECT_DOTPROD outside of linux in the future.
@@ -113,7 +113,7 @@
 sigjmp_buf& global_sigjmp_buf_just_before_trying_snippet() {
   static sigjmp_buf g;
   return g;
-};
+}
 
 // SIGILL signal handler. Long-jumps to just before
 // we ran the snippet that we know is the only thing that could have generated
@@ -173,7 +173,7 @@
       : "x0", "v0", "v1");
   // Expecting 100 (input accumulator value) + 100 * 100 + ... (repeat 4 times)
   return result == 40100;
-};
+}
 
 bool DetectDotprodBySigIllMethod() {
   return try_asm_snippet(dotprod_asm_snippet);
diff --git a/tensorflow/lite/experimental/ruy/dispatch.h b/tensorflow/lite/experimental/ruy/dispatch.h
index 9044be7..de74ef7 100644
--- a/tensorflow/lite/experimental/ruy/dispatch.h
+++ b/tensorflow/lite/experimental/ruy/dispatch.h
@@ -33,14 +33,28 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
 
-#include <limits>
+#include <algorithm>
+#include <cstdint>
+#include <limits>  // IWYU pragma: keep
+#include <type_traits>
 
 #include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/common.h"
 #include "tensorflow/lite/experimental/ruy/context.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/kernel.h"
+#include "tensorflow/lite/experimental/ruy/kernel_common.h"
 #include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/pack.h"
+#include "tensorflow/lite/experimental/ruy/pack_common.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
+#include "tensorflow/lite/experimental/ruy/size_util.h"
 #include "tensorflow/lite/experimental/ruy/spec.h"
 #include "tensorflow/lite/experimental/ruy/trmul.h"
+#include "tensorflow/lite/experimental/ruy/trmul_params.h"
 
 namespace ruy {
 
@@ -108,10 +122,10 @@
   RUY_DCHECK(spec.multiplier_exponent_perchannel == nullptr);
 }
 
-inline bool IsColMajorTrMul(const DMatrix& lhs, const DMatrix& rhs,
-                            const DMatrix& dst) {
-  return IsColMajor(lhs.layout) && IsColMajor(rhs.layout) &&
-         IsColMajor(dst.layout);
+inline bool IsColMajorTrMul(const TrMulParams& params) {
+  return IsColMajor(params.src[Side::kLhs].layout) &&
+         IsColMajor(params.src[Side::kRhs].layout) &&
+         IsColMajor(params.dst.layout);
 }
 
 inline void CreatePackedLayout(const Layout& src, const Type& scalar,
@@ -131,8 +145,8 @@
 }
 
 template <typename Scalar, typename PackedScalar>
-void CreatePackedMatrix(const DMatrix& src, const KernelLayout& kernel_layout,
-                        PMatrix* packed) {
+void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
+                        TrMulParams* params) {
   // Ruy always uses 32-bit signed accumulators for quantized
   // matrix multiplication, so we would like to always use std::int32_t
   // unconditionally for SumsType.
@@ -142,6 +156,8 @@
       typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
                                 std::int32_t>::type;
 
+  const DMatrix& src = params->src[side];
+  PMatrix* packed = &params->packed[side];
   packed->data_type = Type::Create<PackedScalar>();
   packed->sums_type = Type::Create<SumsType>();
   CreatePackedLayout(src.layout, packed->data_type, kernel_layout,
@@ -160,7 +176,7 @@
   if (ThePath != Path::kStandardCpp) {
     // The optimized code paths currently only handle the case of all matrices
     // being column major.
-    if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
+    if (!IsColMajorTrMul(*params)) {
       fallback_to_standard_cpp = true;
     }
   }
@@ -179,13 +195,12 @@
   using RhsKernelLayout = typename Kernel::RhsLayout;
 
   CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
-      params->lhs, ToKernelLayout<LhsKernelLayout>(), &params->packed_lhs);
+      Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
   CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
-      params->rhs, ToKernelLayout<RhsKernelLayout>(), &params->packed_rhs);
-
-  params->lhs_run_pack =
+      Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
+  params->run_pack[Side::kLhs] =
       &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
-  params->rhs_run_pack =
+  params->run_pack[Side::kRhs] =
       &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
   params->run_kernel =
       &RunKernel<ThePath, PackedLhsScalar, PackedRhsScalar, DstScalar, Spec>;
@@ -304,8 +319,8 @@
                        Context* context, Matrix<DstScalar>* dst, Path the_path,
                        TrMulParams* params) {
   // Fill in the fields we already know.
-  params->lhs = ToDMatrix(lhs);
-  params->rhs = ToDMatrix(rhs);
+  params->src[Side::kLhs] = ToDMatrix(lhs);
+  params->src[Side::kRhs] = ToDMatrix(rhs);
   params->dst = ToDMatrix(*dst);
   params->spec = ToVoidPtr(&spec);
 
diff --git a/tensorflow/lite/experimental/ruy/example.cc b/tensorflow/lite/experimental/ruy/example.cc
index 31da97d..c1a3d27 100644
--- a/tensorflow/lite/experimental/ruy/example.cc
+++ b/tensorflow/lite/experimental/ruy/example.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <iostream>
 
 #include "tensorflow/lite/experimental/ruy/ruy.h"
diff --git a/tensorflow/lite/experimental/ruy/example_advanced.cc b/tensorflow/lite/experimental/ruy/example_advanced.cc
index 802c85c..f4415e1 100644
--- a/tensorflow/lite/experimental/ruy/example_advanced.cc
+++ b/tensorflow/lite/experimental/ruy/example_advanced.cc
@@ -13,7 +13,10 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
 #include <iostream>
+#include <memory>
+#include <vector>
 
 #include "tensorflow/lite/experimental/ruy/ruy_advanced.h"
 
diff --git a/tensorflow/lite/experimental/ruy/internal_matrix.h b/tensorflow/lite/experimental/ruy/internal_matrix.h
index f44ce44..34826f1 100644
--- a/tensorflow/lite/experimental/ruy/internal_matrix.h
+++ b/tensorflow/lite/experimental/ruy/internal_matrix.h
@@ -90,9 +90,12 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_
 
+#include <cstddef>
+#include <cstdint>
 #include <type_traits>
 #include <utility>
 
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/common.h"
 #include "tensorflow/lite/experimental/ruy/matrix.h"
 #include "tensorflow/lite/experimental/ruy/size_util.h"
diff --git a/tensorflow/lite/experimental/ruy/kernel.h b/tensorflow/lite/experimental/ruy/kernel.h
index 0c7a2e3..a096a10a 100644
--- a/tensorflow/lite/experimental/ruy/kernel.h
+++ b/tensorflow/lite/experimental/ruy/kernel.h
@@ -16,552 +16,16 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_
 
-#include <cstddef>
-#include <cstdint>
-
-#include "fixedpoint/fixedpoint.h"
-#include "profiling/instrumentation.h"
-#include "tensorflow/lite/experimental/ruy/common.h"
-#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
-#include "tensorflow/lite/experimental/ruy/opt_set.h"
-#include "tensorflow/lite/experimental/ruy/path.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
-#include "tensorflow/lite/experimental/ruy/size_util.h"
-#include "tensorflow/lite/experimental/ruy/spec.h"
-#include "tensorflow/lite/experimental/ruy/tune.h"
 
-namespace ruy {
-
-template <Path ThePath, typename LhsScalar, typename RhsScalar,
-          typename DstScalar, typename Spec>
-struct Kernel {};
-
-template <Path ThePath, typename LhsScalar, typename RhsScalar,
-          typename DstScalar, typename Spec>
-void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
-                    const PackedMatrix<RhsScalar>& rhs, const Spec& spec,
-                    int start_row, int start_col, int end_row, int end_col,
-                    Matrix<DstScalar>* dst) {
-  using Kernel = Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>;
-  Kernel kernel(tuning);
-  using LhsLayout = typename Kernel::LhsLayout;
-  using RhsLayout = typename Kernel::RhsLayout;
-  // end_row and end_col may be larger than dst dimensions.
-  // that is because kernels write directly to the destination matrix, whose
-  // dimensions may not be a multiple of the kernel dimensions, and we try to
-  // keep this annoyance localized as an implementation detail in kernels,
-  // by allowing to pass rounded-up values down as far as possible.
-  // These assertions encode the contract.
-  RUY_DCHECK_LE(0, start_row);
-  RUY_DCHECK_LE(start_row, end_row);
-  RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
-  RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
-  RUY_DCHECK_LE(0, start_col);
-  RUY_DCHECK_LE(start_col, end_col);
-  RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
-  RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
-#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
-  kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst);
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM(NEON)
+#include "tensorflow/lite/experimental/ruy/kernel_arm.h"
+#elif RUY_PLATFORM(AVX512)
+#include "tensorflow/lite/experimental/ruy/kernel_x86.h"
 #else
-  for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
-    int block_end_col = std::min(col + RhsLayout::kCols, end_col);
-    for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
-      int block_end_row = std::min(row + LhsLayout::kCols, end_row);
-      kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst);
-    }
-  }
+#include "tensorflow/lite/experimental/ruy/kernel_common.h"
 #endif
-}
-
-// Main entry point for kernels.
-template <Path ThePath, typename LhsScalar, typename RhsScalar,
-          typename DstScalar, typename Spec>
-void RunKernel(Tuning tuning, const PMatrix& lhs, const PMatrix& rhs,
-               void* spec, int start_row, int start_col, int end_row,
-               int end_col, DMatrix* dst) {
-  Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
-  RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>(
-      tuning, ToPackedMatrix<LhsScalar>(lhs), ToPackedMatrix<RhsScalar>(rhs),
-      *static_cast<const Spec*>(spec), start_row, start_col, end_row, end_col,
-      &mdst);
-}
-
-// The signature of RunKernel is the same, regardless of template parameters.
-using RunKernelFn =
-    decltype(RunKernel<Path::kStandardCpp, std::int8_t, std::int8_t,
-                       std::int8_t, BasicSpec<std::int32_t, std::int8_t>>);
-
-// Copied from TF Lite code.
-inline std::int32_t MultiplyByQuantizedMultiplier(
-    std::int32_t x, std::int32_t quantized_multiplier, int shift) {
-  using gemmlowp::RoundingDivideByPOT;
-  using gemmlowp::SaturatingRoundingDoublingHighMul;
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                                 x * (1 << left_shift), quantized_multiplier),
-                             right_shift);
-}
-
-// Helper to apply a fixed-point multiplier.  Only 'applicable' if AccumScalar
-// is int32 (i.e. in all cases except floating-point) and if the destination is
-// not int32 (i.e. unless the user wants to get raw accumulators).
-template <typename Spec,
-          bool IsApplicable =
-              std::is_same<typename Spec::AccumScalar, std::int32_t>::value &&
-              !std::is_same<typename Spec::DstScalar, std::int32_t>::value>
-struct ApplyMultiplierImpl {};
-
-// Specialization in non-applicable case: do nothing, just check that values
-// are default.
-template <typename Spec>
-struct ApplyMultiplierImpl<Spec, false> {
-  using AccumScalar = typename Spec::AccumScalar;
-  using DstScalar = typename Spec::DstScalar;
-  static void Run(const Spec& spec, int row, AccumScalar* accum) {
-    RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
-    RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
-  }
-};
-
-template <typename Spec>
-struct ApplyMultiplierImpl<Spec, true> {
-  using AccumScalar = typename Spec::AccumScalar;
-  using DstScalar = typename Spec::DstScalar;
-  static void Run(const Spec& spec, int row, AccumScalar* accum) {
-    AccumScalar m = spec.multiplier_fixedpoint_perchannel
-                        ? spec.multiplier_fixedpoint_perchannel[row]
-                        : spec.multiplier_fixedpoint;
-    int e = spec.multiplier_exponent_perchannel
-                ? spec.multiplier_exponent_perchannel[row]
-                : spec.multiplier_exponent;
-    *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
-  }
-};
-
-template <typename Spec>
-void ApplyMultiplier(const Spec& spec, int row,
-                     typename Spec::AccumScalar* accum) {
-  ApplyMultiplierImpl<Spec>::Run(spec, row, accum);
-}
-
-template <typename LhsScalar, typename RhsScalar, typename DstScalar,
-          typename Spec>
-struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
-  using AccumScalar = typename Spec::AccumScalar;
-  using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
-  using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
-  explicit Kernel(Tuning) {}
-  void Run(const PackedMatrix<LhsScalar>& lhs,
-           const PackedMatrix<RhsScalar>& rhs, const Spec& spec, int start_row,
-           int start_col, int end_row, int end_col,
-           Matrix<DstScalar>* dst) const {
-    // See the comment in RunKernelTyped. end_row may be larger than
-    // dst->layout.rows. It's the responsibility of the kernel to avoid
-    // overrunning dst boundaries, which we do here by computing
-    // clamped_end_row.
-    int clamped_end_row = std::min(end_row, dst->layout.rows);
-    int clamped_end_col = std::min(end_col, dst->layout.cols);
-    RUY_DCHECK_LE(0, start_row);
-    RUY_DCHECK_LE(start_row, clamped_end_row);
-    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
-    RUY_DCHECK_LE(clamped_end_row, end_row);
-    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
-    RUY_DCHECK_LE(0, start_col);
-    RUY_DCHECK_LE(start_col, clamped_end_col);
-    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
-    RUY_DCHECK_LE(clamped_end_col, end_col);
-    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
-    gemmlowp::ScopedProfilingLabel label("Kernel (Standard Cpp)");
-    const int depth = lhs.layout.rows;
-    for (int i = start_row; i < clamped_end_row; i++) {
-      for (int j = start_col; j < clamped_end_col; j++) {
-        using AccumScalar = typename Spec::AccumScalar;
-        AccumScalar accum = 0;
-        for (int k = 0; k < depth; k++) {
-          AccumScalar lhs_val = Element(lhs, k, i);
-          AccumScalar rhs_val = Element(rhs, k, j);
-          accum += lhs_val * rhs_val;
-        }
-        if (spec.bias) {
-          accum += spec.bias[i];
-        }
-        if (lhs.zero_point) {
-          accum -= lhs.zero_point * rhs.sums[j];
-        }
-        if (rhs.zero_point) {
-          accum -= rhs.zero_point * lhs.sums[i];
-        }
-        if (lhs.zero_point && rhs.zero_point) {
-          accum += lhs.zero_point * rhs.zero_point * depth;
-        }
-        ApplyMultiplier(spec, i, &accum);
-        accum += dst->zero_point;
-        accum = std::min<AccumScalar>(accum, spec.clamp_max);
-        accum = std::max<AccumScalar>(accum, spec.clamp_min);
-        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
-      }
-    }
-  }
-};
-
-#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                  \
-  template <typename LhsScalar, typename RhsScalar, typename DstScalar,    \
-            typename Spec>                                                 \
-  struct Kernel<CHILD, LhsScalar, RhsScalar, DstScalar, Spec>              \
-      : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec> {            \
-    explicit Kernel(Tuning tuning)                                         \
-        : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
-  };
-
-RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
-RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
-
-// KernelParams are shared across 32-bit and 64-bit NEON code.
-#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
-    (RUY_OPT_ENABLED(RUY_OPT_ASM))
-
-#define RUY_ASM_FLAG_HAS_BIAS 0x1
-#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
-#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
-#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
-#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
-
-#define RUY_ASM_TYPE_ID_UINT8 1
-#define RUY_ASM_TYPE_ID_INT8 2
-#define RUY_ASM_TYPE_ID_INT16 3
-#define RUY_ASM_TYPE_ID_INT32 4
-
-template <typename DstScalar>
-struct DstTypeId {};
-
-template <>
-struct DstTypeId<std::uint8_t> {
-  static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
-};
-
-template <>
-struct DstTypeId<std::int8_t> {
-  static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
-};
-
-template <>
-struct DstTypeId<std::int16_t> {
-  static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
-};
-
-template <>
-struct DstTypeId<std::int32_t> {
-  static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
-};
-
-template <int LhsCols, int RhsCols>
-struct KernelParams8bit {
-  static constexpr int kMaxDstTypeSize = 4;
-
-  const std::int32_t* bias;
-  const std::int32_t* lhs_sums;
-  const std::int32_t* rhs_sums;
-  const std::int8_t* lhs_base_ptr;
-  const std::int32_t* multiplier_fixedpoint;
-  const std::int32_t* multiplier_exponent;
-  const std::int8_t* rhs_base_ptr;
-  void* dst_base_ptr;
-  std::int32_t lhs_zero_point;
-  std::int32_t rhs_zero_point;
-  std::int32_t dst_zero_point;
-  std::int32_t prod_zp_depth;
-  std::int32_t start_row;
-  std::int32_t start_col;
-  std::int32_t last_row;
-  std::int32_t last_col;
-  std::int32_t dst_rows;
-  std::int32_t dst_cols;
-  std::int32_t lhs_stride;
-  std::int32_t rhs_stride;
-  std::int32_t dst_stride;
-  std::int32_t depth;
-  std::int32_t clamp_min;
-  std::int32_t clamp_max;
-  std::uint8_t flags;
-  std::uint8_t dst_type_id;
-  const std::int32_t zero_data[LhsCols] = {0};
-  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
-  std::int32_t multiplier_fixedpoint_buf[LhsCols];
-  std::int32_t multiplier_exponent_buf[LhsCols];
-};
-
-template <typename DstScalar, int LhsCols, int RhsCols>
-void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
-                          const PackedMatrix<std::int8_t>& rhs,
-                          const BasicSpec<std::int32_t, DstScalar>& spec,
-                          int start_row, int start_col, int end_row,
-                          int end_col, Matrix<DstScalar>* dst,
-                          KernelParams8bit<LhsCols, RhsCols>* params) {
-  using Params = KernelParams8bit<LhsCols, RhsCols>;
-
-  static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
-
-  const int depth = lhs.layout.rows;
-  RUY_DCHECK_EQ(start_row % LhsCols, 0);
-  RUY_DCHECK_EQ(start_col % RhsCols, 0);
-  RUY_DCHECK_EQ(end_row % LhsCols, 0);
-  RUY_DCHECK_EQ(end_col % RhsCols, 0);
-
-  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
-  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
-  params->flags = 0;
-  params->bias = params->zero_data;
-  if (spec.bias) {
-    params->bias = spec.bias;
-    params->flags |= RUY_ASM_FLAG_HAS_BIAS;
-  }
-  if (lhs.sums) {
-    params->lhs_sums = lhs.sums;
-    params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
-  }
-  if (rhs.sums) {
-    params->rhs_sums = rhs.sums;
-    params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
-  }
-  params->start_row = start_row;
-  params->start_col = start_col;
-  params->last_row = end_row - LhsCols;
-  params->last_col = end_col - RhsCols;
-  params->lhs_stride = lhs.layout.stride;
-  params->rhs_stride = rhs.layout.stride;
-  params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
-  params->lhs_zero_point = lhs.zero_point;
-  params->rhs_zero_point = rhs.zero_point;
-  params->dst_zero_point = dst->zero_point;
-  params->depth = depth;
-  params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
-  if (spec.multiplier_fixedpoint_perchannel) {
-    params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
-    params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
-    params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel;
-    params->multiplier_exponent = spec.multiplier_exponent_perchannel;
-  } else {
-    if (spec.multiplier_exponent > 0) {
-      params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
-    }
-    params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
-    params->multiplier_exponent = params->multiplier_exponent_buf;
-    for (int i = 0; i < LhsCols; i++) {
-      params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint;
-      params->multiplier_exponent_buf[i] = spec.multiplier_exponent;
-    }
-  }
-  params->clamp_min = spec.clamp_min;
-  params->clamp_max = spec.clamp_max;
-  params->dst_rows = dst->layout.rows;
-  params->dst_cols = dst->layout.cols;
-
-  RUY_DCHECK_LT(params->last_row, params->dst_rows);
-  RUY_DCHECK_LT(params->last_col, params->dst_cols);
-
-  params->dst_type_id = DstTypeId<DstScalar>::kValue;
-  params->dst_base_ptr =
-      dst->data.get() + start_col * dst->layout.stride + start_row;
-}
-
-void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params);
-void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params);
-void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params);
-void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params);
-
-#if RUY_PLATFORM(NEON_64)
-template <typename DstScalar>
-struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
-              BasicSpec<std::int32_t, DstScalar>> {
-  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
-  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
-  Tuning tuning = Tuning::kAuto;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PackedMatrix<std::int8_t>& lhs,
-           const PackedMatrix<std::int8_t>& rhs,
-           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
-           int start_col, int end_row, int end_col,
-           Matrix<DstScalar>* dst) const {
-    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
-    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
-                         dst, &params);
-    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-      Kernel8bitNeonInOrder(params);
-    } else {
-      Kernel8bitNeonOutOfOrder(params);
-    }
-  }
-};
-
-template <typename DstScalar>
-struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, DstScalar,
-              BasicSpec<std::int32_t, DstScalar>> {
-  Tuning tuning = Tuning::kAuto;
-  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
-  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PackedMatrix<std::int8_t>& lhs,
-           const PackedMatrix<std::int8_t>& rhs,
-           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
-           int start_col, int end_row, int end_col,
-           Matrix<DstScalar>* dst) const {
-    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
-    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
-                         dst, &params);
-    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-      Kernel8bitNeonDotprodInOrder(params);
-    } else {
-      Kernel8bitNeonDotprodOutOfOrder(params);
-    }
-  }
-};
-#endif
-
-template <int LhsCols, int RhsCols>
-struct KernelParamsFloat {
-  const float* lhs_base_ptr;
-  const float* rhs_base_ptr;
-  float* dst_base_ptr;
-  const float* bias;
-  std::int32_t start_row;
-  std::int32_t start_col;
-  std::int32_t last_row;
-  std::int32_t last_col;
-  std::int32_t dst_rows;
-  std::int32_t dst_cols;
-  std::int32_t lhs_stride;
-  std::int32_t rhs_stride;
-  std::int32_t dst_stride;
-  std::int32_t depth;
-  float clamp_min;
-  float clamp_max;
-  std::uint8_t flags;
-  const float zero_data[LhsCols] = {0};
-  float dst_tmp_buf[LhsCols * RhsCols];
-};
-
-template <int LhsCols, int RhsCols>
-inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
-                                  const PackedMatrix<float>& rhs,
-                                  const BasicSpec<float, float>& spec,
-                                  int start_row, int start_col, int end_row,
-                                  int end_col, Matrix<float>* dst,
-                                  KernelParamsFloat<LhsCols, RhsCols>* params) {
-  using Params = KernelParamsFloat<LhsCols, RhsCols>;
-
-  const int depth = lhs.layout.rows;
-  RUY_DCHECK_EQ(start_row % LhsCols, 0);
-  RUY_DCHECK_EQ(start_col % RhsCols, 0);
-  RUY_DCHECK_EQ(end_row % LhsCols, 0);
-  RUY_DCHECK_EQ(end_col % RhsCols, 0);
-
-  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
-  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
-  params->dst_base_ptr =
-      dst->data.get() + start_col * dst->layout.stride + start_row;
-
-  std::uint8_t flags = 0;
-  params->bias = params->zero_data;
-  if (spec.bias) {
-    params->bias = spec.bias;
-    flags |= RUY_ASM_FLAG_HAS_BIAS;
-  }
-  params->flags = flags;
-  params->start_row = start_row;
-  params->start_col = start_col;
-  params->last_row = end_row - LhsCols;
-  params->last_col = end_col - RhsCols;
-  params->lhs_stride = sizeof(float) * lhs.layout.stride;
-  params->rhs_stride = sizeof(float) * rhs.layout.stride;
-  params->dst_stride = sizeof(float) * dst->layout.stride;
-  params->depth = depth;
-  params->clamp_min = spec.clamp_min;
-  params->clamp_max = spec.clamp_max;
-  params->dst_rows = dst->layout.rows;
-  params->dst_cols = dst->layout.cols;
-
-  RUY_DCHECK_LT(params->last_row, params->dst_rows);
-  RUY_DCHECK_LT(params->last_col, params->dst_cols);
-}
-
-void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params);
-void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params);
-void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params);
-void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params);
-
-#if RUY_PLATFORM(NEON_64)
-// A Float kernel for ARM64 Neon.
-template <>
-struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
-  Tuning tuning = Tuning::kAuto;
-  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
-           const BasicSpec<float, float>& spec, int start_row, int start_col,
-           int end_row, int end_col, Matrix<float>* dst) const {
-    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
-    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
-                          end_col, dst, &params);
-    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-      KernelFloatNeonInOrder(params);
-    } else {
-      KernelFloatNeonOutOfOrder(params);
-    }
-  }
-};
-#endif
-
-#if RUY_PLATFORM(NEON_32)
-// A Float kernel for ARM32 Neon.
-template <>
-struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
-  Tuning tuning = Tuning::kAuto;
-  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
-           const BasicSpec<float, float>& spec, int start_row, int start_col,
-           int end_row, int end_col, Matrix<float>* dst) const {
-    KernelParamsFloat<8, 4> params;
-
-    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
-                          end_col, dst, &params);
-
-    KernelFloat32NeonOutOfOrder(params);
-  }
-};
-#endif
-
-// While the dotprod NEON extension does not concern floating-point arithmetic,
-// its presence allows us to distinguish, in the in-order tuning case, between
-// A53 and A55r1. TODO: should this be folded into tuning?
-template <>
-struct Kernel<Path::kNeonDotprod, float, float, float,
-              BasicSpec<float, float>> {
-  Tuning tuning = Tuning::kAuto;
-  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using Base =
-      Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>>;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
-           const BasicSpec<float, float>& spec, int start_row, int start_col,
-           int end_row, int end_col, Matrix<float>* dst) const {
-    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
-    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
-                          end_col, dst, &params);
-    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-      KernelFloatNeonDotprodInOrder(params);
-    } else {
-      KernelFloatNeonOutOfOrder(params);
-    }
-  }
-};
-
-#endif  // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
-        // (RUY_OPT_ENABLED(RUY_OPT_ASM)
-}  // namespace ruy
+// IWYU pragma: end_exports
 
 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_
diff --git a/tensorflow/lite/experimental/ruy/kernel_arm.h b/tensorflow/lite/experimental/ruy/kernel_arm.h
new file mode 100644
index 0000000..6f49dc4
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/kernel_arm.h
@@ -0,0 +1,172 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "fixedpoint/fixedpoint.h"
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/kernel_common.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
+#include "tensorflow/lite/experimental/ruy/size_util.h"
+#include "tensorflow/lite/experimental/ruy/spec.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params);
+
+#if RUY_PLATFORM(NEON_64)
+template <typename DstScalar>
+struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+  Tuning tuning = Tuning::kAuto;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+                         dst, &params);
+    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+      Kernel8bitNeonInOrder(params);
+    } else {
+      Kernel8bitNeonOutOfOrder(params);
+    }
+  }
+};
+
+template <typename DstScalar>
+struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+                         dst, &params);
+    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+      Kernel8bitNeonDotprodInOrder(params);
+    } else {
+      Kernel8bitNeonDotprodOutOfOrder(params);
+    }
+  }
+};
+#endif
+
+void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params);
+void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params);
+void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params);
+void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params);
+
+#if RUY_PLATFORM(NEON_64)
+// A Float kernel for ARM64 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+      KernelFloatNeonInOrder(params);
+    } else {
+      KernelFloatNeonOutOfOrder(params);
+    }
+  }
+};
+#endif
+
+#if RUY_PLATFORM(NEON_32)
+// A Float kernel for ARM32 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<8, 4> params;
+
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+
+    KernelFloat32NeonOutOfOrder(params);
+  }
+};
+#endif
+
+// While the dotprod NEON extension does not concern floating-point arithmetic,
+// its presence allows us to distinguish, in the in-order tuning case, between
+// A53 and A55r1. TODO: should this be folded into tuning?
+template <>
+struct Kernel<Path::kNeonDotprod, float, float, float,
+              BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using Base =
+      Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+      KernelFloatNeonDotprodInOrder(params);
+    } else {
+      KernelFloatNeonOutOfOrder(params);
+    }
+  }
+};
+
+#endif  // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_
diff --git a/tensorflow/lite/experimental/ruy/kernel_arm32.cc b/tensorflow/lite/experimental/ruy/kernel_arm32.cc
index 8607f25..c002ba4 100644
--- a/tensorflow/lite/experimental/ruy/kernel_arm32.cc
+++ b/tensorflow/lite/experimental/ruy/kernel_arm32.cc
@@ -15,6 +15,7 @@
 
 #include "profiling/instrumentation.h"
 #include "tensorflow/lite/experimental/ruy/kernel.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
 
 namespace ruy {
@@ -130,15 +131,12 @@
         // clang-format off
 
         // Load the first 32 bytes of LHS and RHS data.
-        // Load q0
-        "vld1.32 {d0}, [%[lhs_ptr]]!\n"
-        "vld1.32 {d1}, [%[lhs_ptr]]!\n"
-        // Load q1
-        "vld1.32 {d2}, [%[lhs_ptr]]!\n"
-        "vld1.32 {d3}, [%[lhs_ptr]]!\n"
+        // Load q0, q1
+        "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+        "pld [%[lhs_ptr]]\n"
         // Load q2
-        "vld1.32 {d4}, [%[rhs_ptr]]!\n"
-        "vld1.32 {d5}, [%[rhs_ptr]]!\n"
+        "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+        "pld [%[rhs_ptr]]\n"
 
         "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
 
@@ -189,17 +187,16 @@
         "vmla.f32 q5, q0, d4[1]\n"
         "vmla.f32 q7, q0, d5[0]\n"
         "vmla.f32 q9, q0, d5[1]\n"
-        "vld1.32 {d0}, [%[lhs_ptr]]!\n" // Reload LHS 1 into r0
-        "vld1.32 {d1}, [%[lhs_ptr]]!\n" // Reload LHS 1 into r0
+        "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
 
         "vmla.f32 q4, q1, d4[0]\n"
         "vmla.f32 q6, q1, d4[1]\n"
         "vmla.f32 q8, q1, d5[0]\n"
         "vmla.f32 q10, q1, d5[1]\n"
-        "vld1.32 {d2}, [%[lhs_ptr]]!\n" // Reload LHS 2 into r1
-        "vld1.32 {d3}, [%[lhs_ptr]]!\n" // Reload LHS 2 into r1
-        "vld1.32 {d4}, [%[rhs_ptr]]!\n" // Reload RHS into r2
-        "vld1.32 {d5}, [%[rhs_ptr]]!\n" // Reload RHS into r2
+        "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+        "pld [%[lhs_ptr]]\n"
+        "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
+        "pld [%[rhs_ptr]]\n"
 
         "add r1, r1, #1\n"
         "cmp r1, r2\n"
@@ -291,25 +288,18 @@
         "movne r1, r5\n"
 
         // Load 8 bias values.
-        "vld1.32 {d24}, [r1]!\n"
-        "vld1.32 {d25}, [r1]!\n"
-        "vld1.32 {d26}, [r1]!\n"
-        "vld1.32 {d27}, [r1]\n"
+        "vld1.32 {d24, d25, d26, d27}, [r1]\n"
 
         // Now that we know what LHS and RHS data the next iteration of the
         // main loop will need to load, we start loading the first 32 bytes of
         // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
         // in the rest of the work on the current block.
-        // Load q0
-        "vld1.32 {d0}, [%[lhs_ptr]]!\n"
-        "vld1.32 {d1}, [%[lhs_ptr]]!\n"
-        // Load q1
-        "vld1.32 {d2}, [%[lhs_ptr]]!\n"
-        "vld1.32 {d3}, [%[lhs_ptr]]!\n"
+        // Load q0, q1
+        "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+        "pld [%[lhs_ptr]]\n"
         // Load q2
-        "vld1.32 {d4}, [%[rhs_ptr]]!\n"
-        "vld1.32 {d5}, [%[rhs_ptr]]!\n"
-
+        "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+        "pld [%[rhs_ptr]]\n"
 
         // Perform the bias-addition (per the above, we have just folded into
         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
@@ -391,40 +381,20 @@
         "31:\n"
 
         // Write our float values to the destination described by
-        // (r3 address, r4 stride).
-        // q3 = d6, d7
-        "vstr d6, [r3, #0]\n"
-        "vstr d7, [r3, #8]\n"
-        // q4 = d8, d9
-        "vstr d8, [r3, #16]\n"
-        "vstr d9, [r3, #24]\n"
+        // (r3 address, r4 stride)
+        "vst1.32 {d6, d7, d8, d9}, [r3]\n"
         "add r3, r3, r4\n"
         RUY_MAKE_ZERO(q3)
         RUY_MAKE_ZERO(q4)
-        // q5 = d10, d11
-        "vstr d10, [r3, #0]\n"
-        "vstr d11, [r3, #8]\n"
-        // q6 = d12, d13
-        "vstr d12, [r3, #16]\n"
-        "vstr d13, [r3, #24]\n"
+        "vst1.32 {d10, d11, d12, d13}, [r3]\n"
         "add r3, r3, r4\n"
         RUY_MAKE_ZERO(q5)
         RUY_MAKE_ZERO(q6)
-        // q7 = d14, d15
-        "vstr d14, [r3, #0]\n"
-        "vstr d15, [r3, #8]\n"
-        // q8 = d16, d17
-        "vstr d16, [r3, #16]\n"
-        "vstr d17, [r3, #24]\n"
+        "vst1.32 {d14, d15, d16, d17}, [r3]\n"
         "add r3, r3, r4\n"
         RUY_MAKE_ZERO(q7)
         RUY_MAKE_ZERO(q8)
-        // q9 = d18, d19
-        "vstr d18, [r3, #0]\n"
-        "vstr d19, [r3, #8]\n"
-        // q10 = d20, d21
-        "vstr d20, [r3, #16]\n"
-        "vstr d21, [r3, #24]\n"
+        "vst1.32 {d18, d19, d20, d21}, [r3]\n"
         "add r3, r3, r4\n"
         RUY_MAKE_ZERO(q9)
         RUY_MAKE_ZERO(q10)
@@ -518,10 +488,12 @@
         // clang-format on
         : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
         : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+        // Clobber list must specify q registers (and not their constituent
+        // d registers). There is a (currently unexplained) slowdown if
+        // d registers are listed in the clobbers list.
         : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
-          "memory", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8",
-          "d9", "d10", "d12", "d13", "d14", "d15", "d16", "d17", "d18","d19",
-          "d20", "d21", "d22", "d23", "d24", "d25", "d26");
+          "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+          "q9", "q10", "q12", "q13");
 }
 
 #undef RUY_OFFSET_BIAS
diff --git a/tensorflow/lite/experimental/ruy/kernel_arm64.cc b/tensorflow/lite/experimental/ruy/kernel_arm64.cc
index d2bcc10..6fa71bd 100644
--- a/tensorflow/lite/experimental/ruy/kernel_arm64.cc
+++ b/tensorflow/lite/experimental/ruy/kernel_arm64.cc
@@ -13,8 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
+
 #include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
 #include "tensorflow/lite/experimental/ruy/kernel.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
 
 namespace ruy {
diff --git a/tensorflow/lite/experimental/ruy/kernel_avx512.cc b/tensorflow/lite/experimental/ruy/kernel_avx512.cc
new file mode 100644
index 0000000..03443e8
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/kernel_avx512.cc
@@ -0,0 +1,813 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/kernel.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+inline std::int32_t mm512_get1_epi32(const __m512i v, int i) {
+  __m256i a =
+      i < 8 ? _mm512_extracti32x8_epi32(v, 0) : _mm512_extracti32x8_epi32(v, 1);
+  switch (i & ~8) {
+    case 0:
+      return _mm256_extract_epi32(a, 0);
+    case 1:
+      return _mm256_extract_epi32(a, 1);
+    case 2:
+      return _mm256_extract_epi32(a, 2);
+    case 3:
+      return _mm256_extract_epi32(a, 3);
+    case 4:
+      return _mm256_extract_epi32(a, 4);
+    case 5:
+      return _mm256_extract_epi32(a, 5);
+    case 6:
+      return _mm256_extract_epi32(a, 6);
+    case 7:
+      return _mm256_extract_epi32(a, 7);
+    default:
+      RUY_DCHECK(i < 16);
+      return 0;
+  }
+}
+
+inline __m512i mm512_set1_epi32(__m512i* v, int i, std::int32_t x) {
+  return *v = _mm512_mask_set1_epi32(*v, 1 << i, x);
+}
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
+
+  std::int32_t dst_stride;
+  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+    dst_stride = params.dst_stride;
+  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int16_t);
+  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int32_t);
+  } else {
+    RUY_DCHECK(false);
+  }
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_col_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_col_ptr += params.start_row;
+  }
+
+  for (int col = params.start_col; col <= params.last_col; col += 16) {
+    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+    void* dst_ptr = dst_col_ptr;
+    const std::int32_t* bias_ptr = bias_col_ptr;
+
+    for (int row = params.start_row; row <= params.last_row; row += 16) {
+      const int residual_rows = std::min(params.dst_rows - row, 16);
+      const int residual_cols = std::min(params.dst_cols - col, 16);
+
+      __m512i accum_data_v[16];
+      __m512i accum_data_v_low[16];
+      __m512i accum_data_v_high[16];
+
+      // Initialize with bias.
+      const __mmask16 row_mask =
+          (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+      const __m512i initial_accum_data =
+          _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+      __m512i initial_accum_data_low = initial_accum_data;
+      __m512i initial_accum_data_high = _mm512_setzero_epi32();
+      bias_ptr += bias_ptr_block_increment;
+
+      for (int j = 0; j < 16; ++j) {
+        accum_data_v_low[j] = initial_accum_data_low;
+        accum_data_v_high[j] = initial_accum_data_high;
+      }
+
+      //
+
+      const std::int8_t* lhs_ptr = lhs_col_ptr;
+      const std::int8_t* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; d += 4) {
+        const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+        __m512i rhs_data = _mm512_loadu_epi8(rhs_ptr);
+
+        // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+        __m512i lhs_16_bit_low =
+            _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+        // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+        __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+            _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+        for (int j = 0; j < 16; ++j) {
+          // Mask that drops the 0th element.
+          static constexpr std::uint16_t shift_mask = 0xfffe;
+          const __m256i dup_rhs_element_low =
+              _mm256_broadcastw_epi16(_mm512_castsi512_si128(rhs_data));
+          // Shift rhs_data, moving next element into 0 position.
+          const __m256i dup_rhs_element_high = _mm256_set1_epi16(
+              _mm_extract_epi16(_mm512_castsi512_si128(rhs_data), 1));
+          // Shift rhs_data, moving next element into 0 position.
+          rhs_data = _mm512_maskz_compress_epi32(shift_mask, rhs_data);
+
+          __m512i rhs_16_bit_dup_low =
+              _mm512_cvtepi8_epi16(dup_rhs_element_low);
+          __m512i rhs_16_bit_dup_high =
+              _mm512_cvtepi8_epi16(dup_rhs_element_high);
+
+          accum_data_v_low[j] = _mm512_add_epi32(
+              accum_data_v_low[j],
+              _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+          accum_data_v_high[j] = _mm512_add_epi32(
+              accum_data_v_high[j],
+              _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+        }
+
+        lhs_ptr += 16 * 4;
+        rhs_ptr += 16 * 4;
+      }
+      for (int j = 0; j < 16; ++j) {
+        accum_data_v[j] =
+            _mm512_add_epi32(accum_data_v_low[j], accum_data_v_high[j]);
+      }
+
+      // Move most of this up to bias, or even outside row loop.
+
+      const std::int32_t lhs_zero_point = params.lhs_zero_point;
+      const std::int32_t rhs_zero_point = params.rhs_zero_point;
+      const std::int32_t prod_zp_depth = params.prod_zp_depth;
+      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+        const __m512i lhs_sums_offset =
+            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+                               _mm512_loadu_epi32(&params.lhs_sums[row]));
+        for (int j = 0; j < 16; ++j) {
+          accum_data_v[j] = _mm512_sub_epi32(accum_data_v[j], lhs_sums_offset);
+        }
+      }
+      if (((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point) ||
+          prod_zp_depth) {
+        __m512i non_lhs_sums_offset =
+            _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+                               _mm512_loadu_epi32(&params.rhs_sums[col]));
+        non_lhs_sums_offset = _mm512_sub_epi32(
+            non_lhs_sums_offset, _mm512_set1_epi32(prod_zp_depth));
+
+        for (int j = 0; j < 16; ++j) {
+          accum_data_v[j] = _mm512_sub_epi32(
+              accum_data_v[j],
+              _mm512_set1_epi32(mm512_get1_epi32(non_lhs_sums_offset, j)));
+        }
+      }
+
+      //
+
+      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+        __m512i m_vector;
+        __m512i e_vector;
+        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+        if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+          m_vector = _mm512_maskz_loadu_epi32(
+              row_mask, &params.multiplier_fixedpoint[row]);
+          e_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                              &params.multiplier_exponent[row]);
+        } else {
+          // These arrays have size LhsCols, and are pre-filled.
+          m_vector =
+              _mm512_maskz_loadu_epi32(row_mask, params.multiplier_fixedpoint);
+          e_vector =
+              _mm512_maskz_loadu_epi32(row_mask, params.multiplier_exponent);
+        }
+
+        const __m512i m_64bit_low =
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+        const __m512i m_64bit_high =
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+        const __m512i zero_vector = _mm512_setzero_epi32();
+        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+        const __m512i final_right_shift =
+            _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+            _mm512_extracti32x8_epi32(final_right_shift, 0));
+        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+            _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+        const __m512i offset_vector =
+            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+        // Really these should be shifted by neg_e_vector, but tests pass when
+        // using right_shift.
+        const __m512i offset_vector_low = _mm512_sllv_epi64(
+            offset_vector,
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
+        const __m512i offset_vector_high = _mm512_sllv_epi64(
+            offset_vector,
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+        for (int j = 0; j < 16; ++j) {
+          accum_data_v[j] = _mm512_sllv_epi32(accum_data_v[j], left_shift);
+          // Apply the fixed-point part of the multiplier.
+          __m512i scaled_v_low =
+              _mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
+                                   accum_data_v[j], 0)),
+                               m_64bit_low);
+          __m512i scaled_v_high =
+              _mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
+                                   accum_data_v[j], 1)),
+                               m_64bit_high);
+
+          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+          scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+          scaled_v_high =
+              _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+          accum_data_v[j] =
+              _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+          accum_data_v[j] = _mm512_inserti32x8(
+              accum_data_v[j], _mm512_cvtepi64_epi32(scaled_v_high), 1);
+
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+          RUY_DCHECK(false);
+#endif
+        }
+
+        if (params.dst_zero_point) {
+          __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+          for (int j = 0; j < 16; ++j) {
+            accum_data_v[j] = _mm512_add_epi32(accum_data_v[j], dst_zero_point);
+          }
+        }
+        __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+        __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+        for (int j = 0; j < 16; ++j) {
+          accum_data_v[j] = _mm512_min_epi32(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm512_max_epi32(accum_data_v[j], clamp_min_v);
+        }
+      }
+      const bool store_full_block =
+          (residual_rows == 16) && (residual_cols == 16);
+
+      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+            _mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            _mm_mask_storeu_epi8(tmp_ptr, row_mask,
+                                 _mm512_cvtepi32_epi8(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+            _mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            _mm_mask_storeu_epi8(tmp_ptr, row_mask,
+                                 _mm512_cvtepi32_epi8(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+            _mm256_storeu_epi16(tmp_ptr,
+                                _mm512_cvtepi32_epi16(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
+                                     _mm512_cvtepi32_epi16(accum_data_v[j]));
+            tmp_ptr += block_col_offset;
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+        if (store_full_block) {
+          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+          const int block_col_offset = dst_stride;
+          for (int j = 0; j < 16; ++j) {
+            _mm512_storeu_epi32(tmp_ptr, accum_data_v[j]);
+            tmp_ptr += block_col_offset;
+          }
+        } else {
+          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+          for (int j = 0; j < residual_cols; ++j) {
+            _mm512_mask_storeu_epi32(dst_block_ptr, row_mask, accum_data_v[j]);
+            dst_block_ptr += dst_stride;
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+      } else {
+        RUY_DCHECK(false);
+      }
+
+      lhs_col_ptr += 16 * params.lhs_stride;
+    }  // End row-block loop.
+
+    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+                                     16 * params.dst_stride);
+    rhs_col_ptr += 16 * params.rhs_stride;
+  }  // End col-block loop.
+}
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
+  RUY_DCHECK_EQ(16, 16);
+
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  const std::int64_t dst_stride = params.dst_stride >> 2;
+  const std::int64_t rhs_stride = params.rhs_stride >> 2;
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  const int end_row = std::min(params.dst_rows, params.last_row + 16);
+  const int end_col = std::min(params.dst_cols, params.last_col + 16);
+
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+  int col = params.start_col;
+  for (; col <= end_col - 16; col += 16) {
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    int row = params.start_row;
+    for (; row <= end_row - 16; row += 16) {
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
+
+      // Process block in two halves, split by columns.
+      {
+        constexpr int mmm = 0;
+
+        __m512 accum_data_v0 = initial_accum_data;
+        __m512 accum_data_v1 = initial_accum_data;
+        __m512 accum_data_v2 = initial_accum_data;
+        __m512 accum_data_v3 = initial_accum_data;
+        __m512 accum_data_v4 = initial_accum_data;
+        __m512 accum_data_v5 = initial_accum_data;
+        __m512 accum_data_v6 = initial_accum_data;
+        __m512 accum_data_v7 = initial_accum_data;
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < (params.depth - 1); ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          lhs_ptr += 16;
+          rhs_ptr += 16;
+
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+        }
+        {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+          {
+            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+          }
+        }
+      }  // Inner half-block loop, unrolled, first iteration.
+      {
+        constexpr int mmm = 1;
+
+        __m512 accum_data_v0 = initial_accum_data;
+        __m512 accum_data_v1 = initial_accum_data;
+        __m512 accum_data_v2 = initial_accum_data;
+        __m512 accum_data_v3 = initial_accum_data;
+        __m512 accum_data_v4 = initial_accum_data;
+        __m512 accum_data_v5 = initial_accum_data;
+        __m512 accum_data_v6 = initial_accum_data;
+        __m512 accum_data_v7 = initial_accum_data;
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < (params.depth - 1); ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          lhs_ptr += 16;
+          rhs_ptr += 16;
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+        }
+        {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+          {
+            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+          }
+        }
+      }  // Inner half-block loop, unrolled, second iteration.
+    }    // End row-block loop.
+
+    // The unrolling within this conditional may be somewhat pointless. It
+    // depends on the kinds of models.
+    if (row < end_row) {
+      const int residual_rows = end_row - row;
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __mmask16 row_mask =
+          (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+      const __m512 initial_accum_data =
+          _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+      // Process block in two halves, split by columns.
+      for (int mmm = 0; mmm < 2; ++mmm) {
+        __m512 accum_data_v0 = initial_accum_data;
+        __m512 accum_data_v1 = initial_accum_data;
+        __m512 accum_data_v2 = initial_accum_data;
+        __m512 accum_data_v3 = initial_accum_data;
+        __m512 accum_data_v4 = initial_accum_data;
+        __m512 accum_data_v5 = initial_accum_data;
+        __m512 accum_data_v6 = initial_accum_data;
+        __m512 accum_data_v7 = initial_accum_data;
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < (params.depth - 1); ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          lhs_ptr += 16;
+          rhs_ptr += 16;
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+        }
+        {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+          {
+            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+            accum_data_v0 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+            accum_data_v1 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+            accum_data_v2 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+            accum_data_v3 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+            accum_data_v4 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+            accum_data_v5 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+            accum_data_v6 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+            accum_data_v7 =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+          }
+          {
+            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
+                                  accum_data_v0);
+            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
+                                  accum_data_v1);
+            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
+                                  accum_data_v2);
+            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
+                                  accum_data_v3);
+            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
+                                  accum_data_v4);
+            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
+                                  accum_data_v5);
+            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
+                                  accum_data_v6);
+            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
+                                  accum_data_v7);
+          }
+        }
+      }  // Inner half-block loop.
+    }    // Residual rows, main col-block loop.
+  }      // End col-block loop.
+
+  if (col < end_col) {
+    RUY_DCHECK_GE(end_col - col, 0);
+    RUY_DCHECK_LT(end_col - col, 16);
+
+    __m512 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    for (int row = params.start_row; row < end_row; row += 16) {
+      const int residual_rows = std::min(end_row - row, 16);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __mmask16 row_mask =
+          (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+      const __m512 initial_accum_data =
+          _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+      // Process block in two halves, split by columns.
+      for (int mmm = 0; mmm < 2; ++mmm) {
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = initial_accum_data;
+        }
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < params.depth; ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+
+          for (int j = 0; j < 8; ++j) {
+            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
+            accum_data_v[j] =
+                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+          }
+          lhs_ptr += 16;
+          rhs_ptr += 16;
+        }
+
+        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
+
+        if (residual_rows == 16) {
+          if (residual_cols == 8) {
+            for (int j = 0; j < 8; ++j) {
+              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+            }
+          } else {
+            for (int j = 0; j < residual_cols; ++j) {
+              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+            }
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
+          }
+        }
+      }  // Inner half-block loop.
+    }    // End row-block loop.
+  }      // Residual cols.
+}
+
+#endif  //  RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/kernel_common.h b/tensorflow/lite/experimental/ruy/kernel_common.h
new file mode 100644
index 0000000..31d93b2
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/kernel_common.h
@@ -0,0 +1,447 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+
+#include "fixedpoint/fixedpoint.h"
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
+#include "tensorflow/lite/experimental/ruy/size_util.h"
+#include "tensorflow/lite/experimental/ruy/spec.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+struct Kernel {};
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
+                    const PackedMatrix<RhsScalar>& rhs, const Spec& spec,
+                    int start_row, int start_col, int end_row, int end_col,
+                    Matrix<DstScalar>* dst) {
+  using Kernel = Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>;
+  Kernel kernel(tuning);
+  using LhsLayout = typename Kernel::LhsLayout;
+  using RhsLayout = typename Kernel::RhsLayout;
+  // end_row and end_col may be larger than dst dimensions.
+  // that is because kernels write directly to the destination matrix, whose
+  // dimensions may not be a multiple of the kernel dimensions, and we try to
+  // keep this annoyance localized as an implementation detail in kernels,
+  // by allowing to pass rounded-up values down as far as possible.
+  // These assertions encode the contract.
+  RUY_DCHECK_LE(0, start_row);
+  RUY_DCHECK_LE(start_row, end_row);
+  RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
+  RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
+  RUY_DCHECK_LE(0, start_col);
+  RUY_DCHECK_LE(start_col, end_col);
+  RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
+  RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
+#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
+  kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst);
+#else
+  for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
+    int block_end_col = std::min(col + RhsLayout::kCols, end_col);
+    for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
+      int block_end_row = std::min(row + LhsLayout::kCols, end_row);
+      kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst);
+    }
+  }
+#endif
+}
+
+// Main entry point for kernels.
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+void RunKernel(Tuning tuning, const SidePair<PMatrix>& src, void* spec,
+               const SidePair<int>& start, const SidePair<int>& end,
+               DMatrix* dst) {
+  Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
+  RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>(
+      tuning, ToPackedMatrix<LhsScalar>(src[Side::kLhs]),
+      ToPackedMatrix<RhsScalar>(src[Side::kRhs]),
+      *static_cast<const Spec*>(spec), start[Side::kLhs], start[Side::kRhs],
+      end[Side::kLhs], end[Side::kRhs], &mdst);
+}
+
+// Copied from TF Lite code.
+inline std::int32_t MultiplyByQuantizedMultiplier(
+    std::int32_t x, std::int32_t quantized_multiplier, int shift) {
+  using gemmlowp::RoundingDivideByPOT;
+  using gemmlowp::SaturatingRoundingDoublingHighMul;
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                                 x * (1 << left_shift), quantized_multiplier),
+                             right_shift);
+}
+
+// Helper to apply a fixed-point multiplier.  Only 'applicable' if AccumScalar
+// is int32 (i.e. in all cases except floating-point) and if the destination is
+// not int32 (i.e. unless the user wants to get raw accumulators).
+template <typename Spec,
+          bool IsApplicable =
+              std::is_same<typename Spec::AccumScalar, std::int32_t>::value &&
+              !std::is_same<typename Spec::DstScalar, std::int32_t>::value>
+struct ApplyMultiplierImpl {};
+
+// Specialization in non-applicable case: do nothing, just check that values
+// are default.
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, false> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using DstScalar = typename Spec::DstScalar;
+  static void Run(const Spec& spec, int row, AccumScalar* accum) {
+    RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
+    RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
+  }
+};
+
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, true> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using DstScalar = typename Spec::DstScalar;
+  static void Run(const Spec& spec, int row, AccumScalar* accum) {
+    AccumScalar m = spec.multiplier_fixedpoint_perchannel
+                        ? spec.multiplier_fixedpoint_perchannel[row]
+                        : spec.multiplier_fixedpoint;
+    int e = spec.multiplier_exponent_perchannel
+                ? spec.multiplier_exponent_perchannel[row]
+                : spec.multiplier_exponent;
+    *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
+  }
+};
+
+template <typename Spec>
+void ApplyMultiplier(const Spec& spec, int row,
+                     typename Spec::AccumScalar* accum) {
+  ApplyMultiplierImpl<Spec>::Run(spec, row, accum);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+          typename Spec>
+struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
+  using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
+  explicit Kernel(Tuning) {}
+  void Run(const PackedMatrix<LhsScalar>& lhs,
+           const PackedMatrix<RhsScalar>& rhs, const Spec& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    // See the comment in RunKernelTyped. end_row may be larger than
+    // dst->layout.rows. It's the responsibility of the kernel to avoid
+    // overrunning dst boundaries, which we do here by computing
+    // clamped_end_row.
+    int clamped_end_row = std::min(end_row, dst->layout.rows);
+    int clamped_end_col = std::min(end_col, dst->layout.cols);
+    RUY_DCHECK_LE(0, start_row);
+    RUY_DCHECK_LE(start_row, clamped_end_row);
+    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
+    RUY_DCHECK_LE(clamped_end_row, end_row);
+    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
+    RUY_DCHECK_LE(0, start_col);
+    RUY_DCHECK_LE(start_col, clamped_end_col);
+    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
+    RUY_DCHECK_LE(clamped_end_col, end_col);
+    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
+    gemmlowp::ScopedProfilingLabel label("Kernel (Standard Cpp)");
+    const int depth = lhs.layout.rows;
+    for (int i = start_row; i < clamped_end_row; i++) {
+      for (int j = start_col; j < clamped_end_col; j++) {
+        using AccumScalar = typename Spec::AccumScalar;
+        AccumScalar accum = 0;
+        for (int k = 0; k < depth; k++) {
+          AccumScalar lhs_val = Element(lhs, k, i);
+          AccumScalar rhs_val = Element(rhs, k, j);
+          accum += lhs_val * rhs_val;
+        }
+        if (spec.bias) {
+          accum += spec.bias[i];
+        }
+        if (lhs.zero_point) {
+          accum -= lhs.zero_point * rhs.sums[j];
+        }
+        if (rhs.zero_point) {
+          accum -= rhs.zero_point * lhs.sums[i];
+        }
+        if (lhs.zero_point && rhs.zero_point) {
+          accum += lhs.zero_point * rhs.zero_point * depth;
+        }
+        ApplyMultiplier(spec, i, &accum);
+        accum += dst->zero_point;
+        accum = std::min<AccumScalar>(accum, spec.clamp_max);
+        accum = std::max<AccumScalar>(accum, spec.clamp_min);
+        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+      }
+    }
+  }
+};
+
+#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                  \
+  template <typename LhsScalar, typename RhsScalar, typename DstScalar,    \
+            typename Spec>                                                 \
+  struct Kernel<CHILD, LhsScalar, RhsScalar, DstScalar, Spec>              \
+      : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec> {            \
+    explicit Kernel(Tuning tuning)                                         \
+        : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
+  };
+
+#if RUY_PLATFORM(NEON)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
+#elif RUY_PLATFORM(AVX512)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx512)
+#endif
+
+// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 AVX-512
+// code.
+#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) || \
+     RUY_PLATFORM(AVX512)) &&                          \
+    RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#define RUY_ASM_FLAG_HAS_BIAS 0x1
+#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
+#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
+#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
+#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
+
+#define RUY_ASM_TYPE_ID_UINT8 1
+#define RUY_ASM_TYPE_ID_INT8 2
+#define RUY_ASM_TYPE_ID_INT16 3
+#define RUY_ASM_TYPE_ID_INT32 4
+
+template <typename DstScalar>
+struct DstTypeId {};
+
+template <>
+struct DstTypeId<std::uint8_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
+};
+
+template <>
+struct DstTypeId<std::int8_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
+};
+
+template <>
+struct DstTypeId<std::int16_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
+};
+
+template <>
+struct DstTypeId<std::int32_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
+};
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {
+  static constexpr int kMaxDstTypeSize = 4;
+
+  const std::int32_t* bias;
+  const std::int32_t* lhs_sums;
+  const std::int32_t* rhs_sums;
+  const std::int8_t* lhs_base_ptr;
+  const std::int32_t* multiplier_fixedpoint;
+  const std::int32_t* multiplier_exponent;
+  const std::int8_t* rhs_base_ptr;
+  void* dst_base_ptr;
+  std::int32_t lhs_zero_point;
+  std::int32_t rhs_zero_point;
+  std::int32_t dst_zero_point;
+  std::int32_t prod_zp_depth;
+  std::int32_t start_row;
+  std::int32_t start_col;
+  std::int32_t last_row;
+  std::int32_t last_col;
+  std::int32_t dst_rows;
+  std::int32_t dst_cols;
+  std::int32_t lhs_stride;
+  std::int32_t rhs_stride;
+  std::int32_t dst_stride;
+  std::int32_t depth;
+  std::int32_t clamp_min;
+  std::int32_t clamp_max;
+  std::uint8_t flags;
+  std::uint8_t dst_type_id;
+  const std::int32_t zero_data[LhsCols] = {0};
+  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
+  std::int32_t multiplier_fixedpoint_buf[LhsCols];
+  std::int32_t multiplier_exponent_buf[LhsCols];
+};
+
+template <typename DstScalar, int LhsCols, int RhsCols>
+void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
+                          const PackedMatrix<std::int8_t>& rhs,
+                          const BasicSpec<std::int32_t, DstScalar>& spec,
+                          int start_row, int start_col, int end_row,
+                          int end_col, Matrix<DstScalar>* dst,
+                          KernelParams8bit<LhsCols, RhsCols>* params) {
+  using Params = KernelParams8bit<LhsCols, RhsCols>;
+
+  static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
+
+  const int depth = lhs.layout.rows;
+  RUY_DCHECK_EQ(start_row % LhsCols, 0);
+  RUY_DCHECK_EQ(start_col % RhsCols, 0);
+  RUY_DCHECK_EQ(end_row % LhsCols, 0);
+  RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+  params->flags = 0;
+  params->bias = params->zero_data;
+  if (spec.bias) {
+    params->bias = spec.bias;
+    params->flags |= RUY_ASM_FLAG_HAS_BIAS;
+  }
+  if (lhs.sums) {
+    params->lhs_sums = lhs.sums;
+    params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
+  }
+  if (rhs.sums) {
+    params->rhs_sums = rhs.sums;
+    params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
+  }
+  params->start_row = start_row;
+  params->start_col = start_col;
+  params->last_row = end_row - LhsCols;
+  params->last_col = end_col - RhsCols;
+  params->lhs_stride = lhs.layout.stride;
+  params->rhs_stride = rhs.layout.stride;
+  params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
+  params->lhs_zero_point = lhs.zero_point;
+  params->rhs_zero_point = rhs.zero_point;
+  params->dst_zero_point = dst->zero_point;
+  params->depth = depth;
+  params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
+  if (spec.multiplier_fixedpoint_perchannel) {
+    params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+    params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
+    params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel;
+    params->multiplier_exponent = spec.multiplier_exponent_perchannel;
+  } else {
+    if (spec.multiplier_exponent > 0) {
+      params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+    }
+    params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
+    params->multiplier_exponent = params->multiplier_exponent_buf;
+    for (int i = 0; i < LhsCols; i++) {
+      params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint;
+      params->multiplier_exponent_buf[i] = spec.multiplier_exponent;
+    }
+  }
+  params->clamp_min = spec.clamp_min;
+  params->clamp_max = spec.clamp_max;
+  params->dst_rows = dst->layout.rows;
+  params->dst_cols = dst->layout.cols;
+
+  RUY_DCHECK_LT(params->last_row, params->dst_rows);
+  RUY_DCHECK_LT(params->last_col, params->dst_cols);
+
+  params->dst_type_id = DstTypeId<DstScalar>::kValue;
+  params->dst_base_ptr =
+      dst->data.get() + start_col * dst->layout.stride + start_row;
+}
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {
+  const float* lhs_base_ptr;
+  const float* rhs_base_ptr;
+  float* dst_base_ptr;
+  const float* bias;
+  std::int32_t start_row;
+  std::int32_t start_col;
+  std::int32_t last_row;
+  std::int32_t last_col;
+  std::int32_t dst_rows;
+  std::int32_t dst_cols;
+  std::int32_t lhs_stride;
+  std::int32_t rhs_stride;
+  std::int32_t dst_stride;
+  std::int32_t depth;
+  float clamp_min;
+  float clamp_max;
+  std::uint8_t flags;
+  const float zero_data[LhsCols] = {0};
+  float dst_tmp_buf[LhsCols * RhsCols];
+};
+
+template <int LhsCols, int RhsCols>
+inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
+                                  const PackedMatrix<float>& rhs,
+                                  const BasicSpec<float, float>& spec,
+                                  int start_row, int start_col, int end_row,
+                                  int end_col, Matrix<float>* dst,
+                                  KernelParamsFloat<LhsCols, RhsCols>* params) {
+  using Params = KernelParamsFloat<LhsCols, RhsCols>;
+
+  const int depth = lhs.layout.rows;
+  RUY_DCHECK_EQ(start_row % LhsCols, 0);
+  RUY_DCHECK_EQ(start_col % RhsCols, 0);
+  RUY_DCHECK_EQ(end_row % LhsCols, 0);
+  RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+  params->dst_base_ptr =
+      dst->data.get() + start_col * dst->layout.stride + start_row;
+
+  std::uint8_t flags = 0;
+  params->bias = params->zero_data;
+  if (spec.bias) {
+    params->bias = spec.bias;
+    flags |= RUY_ASM_FLAG_HAS_BIAS;
+  }
+  params->flags = flags;
+  params->start_row = start_row;
+  params->start_col = start_col;
+  params->last_row = end_row - LhsCols;
+  params->last_col = end_col - RhsCols;
+  params->lhs_stride = sizeof(float) * lhs.layout.stride;
+  params->rhs_stride = sizeof(float) * rhs.layout.stride;
+  params->dst_stride = sizeof(float) * dst->layout.stride;
+  params->depth = depth;
+  params->clamp_min = spec.clamp_min;
+  params->clamp_max = spec.clamp_max;
+  params->dst_rows = dst->layout.rows;
+  params->dst_cols = dst->layout.cols;
+
+  RUY_DCHECK_LT(params->last_row, params->dst_rows);
+  RUY_DCHECK_LT(params->last_col, params->dst_cols);
+}
+
+#endif  // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) ||
+        //  RUY_PLATFORM(AVX512)) &&
+        // RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_
diff --git a/tensorflow/lite/experimental/ruy/kernel_x86.h b/tensorflow/lite/experimental/ruy/kernel_x86.h
new file mode 100644
index 0000000..58f416f
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/kernel_x86.h
@@ -0,0 +1,76 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/kernel_common.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/spec.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+                         dst, &params);
+    Kernel8bitAvx512(params);
+  }
+};
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
+
+template <>
+struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    KernelFloatAvx512(params);
+  }
+};
+#endif  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_
diff --git a/tensorflow/lite/experimental/ruy/matrix.h b/tensorflow/lite/experimental/ruy/matrix.h
index 3f26f09..b059628 100644
--- a/tensorflow/lite/experimental/ruy/matrix.h
+++ b/tensorflow/lite/experimental/ruy/matrix.h
@@ -17,7 +17,7 @@
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_
 
 #include <cstddef>
-#include <cstdint>
+#include <cstdint>  // IWYU pragma: keep
 #include <type_traits>
 
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
diff --git a/tensorflow/lite/experimental/ruy/opt_set.h b/tensorflow/lite/experimental/ruy/opt_set.h
index 122cb75..525ba22 100644
--- a/tensorflow/lite/experimental/ruy/opt_set.h
+++ b/tensorflow/lite/experimental/ruy/opt_set.h
@@ -23,7 +23,7 @@
 // Each bit in RUY_OPT_SET controls a particular optimization done in Ruy.
 #if !defined(RUY_OPT_SET)
 // Default to all optimizations.
-#define RUY_OPT_SET 0x3ff
+#define RUY_OPT_SET 0x7ff
 #endif
 
 #define RUY_OPT_INTRINSICS 0x1
@@ -36,6 +36,7 @@
 #define RUY_OPT_AVOID_ALIASING 0x80
 #define RUY_OPT_MAX_STREAMING 0x100
 #define RUY_OPT_PREFETCH 0x200
+#define RUY_OPT_PACK_AHEAD 0x400
 
 #define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0)
 
diff --git a/tensorflow/lite/experimental/ruy/pack.cc b/tensorflow/lite/experimental/ruy/pack.cc
deleted file mode 100644
index 2ac955c..0000000
--- a/tensorflow/lite/experimental/ruy/pack.cc
+++ /dev/null
@@ -1,1533 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/experimental/ruy/pack.h"
-
-#include "tensorflow/lite/experimental/ruy/platform.h"
-
-namespace ruy {
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
-                            const void* src_ptr2, const void* src_ptr3,
-                            int src_inc0, int src_inc1, int src_inc2,
-                            int src_inc3, int src_rows, int src_zero_point,
-                            std::int8_t* packed_ptr, int start_col, int end_col,
-                            std::int32_t* sums_ptr, int input_xor) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeon, optimized for out-of-order cores)");
-  asm volatile(
-      // clang-format off
-          "dup v26.16b, %w[input_xor]\n"
-          "mov w1, #0\n"
-          "dup v28.4s, wzr\n"
-          "dup v29.4s, wzr\n"
-          "dup v30.4s, wzr\n"
-          "dup v31.4s, wzr\n"
-
-          "and w2, %w[rows], #-16\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-
-          "add w1, w1, #16\n"
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "cmp w1, w2\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "beq 2f\n"
-
-          "1:\n"
-
-          "add w1, w1, #16\n"
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-
-          "saddlp v16.8h, v4.16b\n"
-          "str q4, [%[packed_ptr], #0]\n"
-          "saddlp v17.8h, v5.16b\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "saddlp v18.8h, v6.16b\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "saddlp v19.8h, v7.16b\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "sadalp v28.4s, v16.8h\n"
-          "cmp w1, w2\n"
-          "sadalp v29.4s, v17.8h\n"
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-          "sadalp v30.4s, v18.8h\n"
-          "sadalp v31.4s, v19.8h\n"
-
-          "bne 1b\n"
-
-          "2:\n"
-
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-
-          "saddlp v16.8h, v4.16b\n"
-          "str q4, [%[packed_ptr], #0]\n"
-          "saddlp v17.8h, v5.16b\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "saddlp v18.8h, v6.16b\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "saddlp v19.8h, v7.16b\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "sadalp v28.4s, v16.8h\n"
-          "sadalp v29.4s, v17.8h\n"
-          "sadalp v30.4s, v18.8h\n"
-          "sadalp v31.4s, v19.8h\n"
-
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #15\n"
-          "beq 4f\n"
-          "dup v0.16b, %w[src_zero_point]\n"
-          "dup v1.16b, %w[src_zero_point]\n"
-          "dup v2.16b, %w[src_zero_point]\n"
-          "dup v3.16b, %w[src_zero_point]\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
-  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
-  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
-  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-          RUY_LOAD_ONE_ROW(4)
-          RUY_LOAD_ONE_ROW(5)
-          RUY_LOAD_ONE_ROW(6)
-          RUY_LOAD_ONE_ROW(7)
-          RUY_LOAD_ONE_ROW(8)
-          RUY_LOAD_ONE_ROW(9)
-          RUY_LOAD_ONE_ROW(10)
-          RUY_LOAD_ONE_ROW(11)
-          RUY_LOAD_ONE_ROW(12)
-          RUY_LOAD_ONE_ROW(13)
-          RUY_LOAD_ONE_ROW(14)
-          RUY_LOAD_ONE_ROW(15)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-
-          "saddlp v16.8h, v4.16b\n"
-          "saddlp v17.8h, v5.16b\n"
-          "saddlp v18.8h, v6.16b\n"
-          "saddlp v19.8h, v7.16b\n"
-          "sadalp v28.4s, v16.8h\n"
-          "sadalp v29.4s, v17.8h\n"
-          "sadalp v30.4s, v18.8h\n"
-          "sadalp v31.4s, v19.8h\n"
-
-          "str q4, [%[packed_ptr], #0]\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-
-          "4:\n"
-
-          "addp v28.4s, v28.4s, v29.4s\n"
-          "addp v30.4s, v30.4s, v31.4s\n"
-          "addp v28.4s, v28.4s, v30.4s\n"
-
-          "cmp %[sums_ptr], #0\n"
-          "beq 6f\n"
-          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
-          "6:\n"
-      // clang-format on
-
-      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
-        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
-        [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
-      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
-        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
-        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
-        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
-        [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point),
-        [ input_xor ] "r"(input_xor)
-      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
-        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
-        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
-        "v27", "v28", "v29", "v30", "v31");
-}
-
-void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
-                         const void* src_ptr2, const void* src_ptr3,
-                         int src_inc0, int src_inc1, int src_inc2, int src_inc3,
-                         int src_rows, int src_zero_point,
-                         std::int8_t* packed_ptr, int start_col, int end_col,
-                         std::int32_t* sums_ptr, int input_xor) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeon, optimized for in-order cores)");
-  asm volatile(
-          // clang-format off
-          "dup v26.16b, %w[input_xor]\n"
-          "mov w1, #0\n"
-          "dup v28.4s, wzr\n"
-          "dup v29.4s, wzr\n"
-          "dup v30.4s, wzr\n"
-          "dup v31.4s, wzr\n"
-
-          "and w2, %w[rows], #-16\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-          "ldr x10, [%[src_ptr0], #8]\n"
-          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ldr x11, [%[src_ptr1], #8]\n"
-          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ldr x12, [%[src_ptr2], #8]\n"
-          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ldr x13, [%[src_ptr3], #8]\n"
-          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
-          "add w1, w1, #16\n"
-          "cmp w1, w2\n"
-
-          "beq 2f\n"
-
-          "1:\n"
-          "add w1, w1, #16\n"
-          "ins v0.d[1], x10\n"
-          "ldr x10, [%[src_ptr0], #8]\n"
-          "ins v1.d[1], x11\n"
-          "ldr x11, [%[src_ptr1], #8]\n"
-          "ins v2.d[1], x12\n"
-          "ldr x12, [%[src_ptr2], #8]\n"
-          "ins v3.d[1], x13\n"
-          "ldr x13, [%[src_ptr3], #8]\n"
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
-          "saddlp v16.8h, v4.16b\n"
-          "str q4, [%[packed_ptr], #0]\n"
-          "saddlp v17.8h, v5.16b\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "saddlp v18.8h, v6.16b\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "saddlp v19.8h, v7.16b\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "sadalp v28.4s, v16.8h\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
-          "cmp w1, w2\n"
-          "sadalp v29.4s, v17.8h\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-          "sadalp v30.4s, v18.8h\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
-          "sadalp v31.4s, v19.8h\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
-
-          "bne 1b\n"
-
-          "2:\n"
-          "ins v0.d[1], x10\n"
-          "ins v1.d[1], x11\n"
-          "ins v2.d[1], x12\n"
-          "ins v3.d[1], x13\n"
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-
-          "saddlp v16.8h, v4.16b\n"
-          "str q4, [%[packed_ptr], #0]\n"
-          "saddlp v17.8h, v5.16b\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "saddlp v18.8h, v6.16b\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "saddlp v19.8h, v7.16b\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "sadalp v28.4s, v16.8h\n"
-          "sadalp v29.4s, v17.8h\n"
-          "sadalp v30.4s, v18.8h\n"
-          "sadalp v31.4s, v19.8h\n"
-
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #15\n"
-          "beq 4f\n"
-          "dup v0.16b, %w[src_zero_point]\n"
-          "dup v1.16b, %w[src_zero_point]\n"
-          "dup v2.16b, %w[src_zero_point]\n"
-          "dup v3.16b, %w[src_zero_point]\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
-  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
-  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
-  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-          RUY_LOAD_ONE_ROW(4)
-          RUY_LOAD_ONE_ROW(5)
-          RUY_LOAD_ONE_ROW(6)
-          RUY_LOAD_ONE_ROW(7)
-          RUY_LOAD_ONE_ROW(8)
-          RUY_LOAD_ONE_ROW(9)
-          RUY_LOAD_ONE_ROW(10)
-          RUY_LOAD_ONE_ROW(11)
-          RUY_LOAD_ONE_ROW(12)
-          RUY_LOAD_ONE_ROW(13)
-          RUY_LOAD_ONE_ROW(14)
-          RUY_LOAD_ONE_ROW(15)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-
-          "saddlp v16.8h, v4.16b\n"
-          "saddlp v17.8h, v5.16b\n"
-          "saddlp v18.8h, v6.16b\n"
-          "saddlp v19.8h, v7.16b\n"
-          "sadalp v28.4s, v16.8h\n"
-          "sadalp v29.4s, v17.8h\n"
-          "sadalp v30.4s, v18.8h\n"
-          "sadalp v31.4s, v19.8h\n"
-
-          "str q4, [%[packed_ptr], #0]\n"
-          "str q5, [%[packed_ptr], #16]\n"
-          "str q6, [%[packed_ptr], #32]\n"
-          "str q7, [%[packed_ptr], #48]\n"
-          "add %[packed_ptr], %[packed_ptr], #64\n"
-
-          "4:\n"
-
-          "addp v28.4s, v28.4s, v29.4s\n"
-          "addp v30.4s, v30.4s, v31.4s\n"
-          "addp v28.4s, v28.4s, v30.4s\n"
-
-          "cmp %[sums_ptr], #0\n"
-          "beq 6f\n"
-          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
-          "6:\n"
-          // clang-format on
-
-          : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
-            [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
-            [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
-          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
-            [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
-            [ rows ] "r"(src_rows),
-            [ src_zero_point ] "r"(src_zero_point),
-            [input_xor] "r"(input_xor)
-          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5",
-            "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
-            "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
-            "v25", "v26", "v27", "v28", "v29", "v30", "v31");
-}
-
-void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
-                                const void* src_ptr2, const void* src_ptr3,
-                                int src_inc0, int src_inc1, int src_inc2,
-                                int src_inc3, int src_rows, int src_zero_point,
-                                std::int8_t* packed_ptr, int start_col,
-                                int end_col, std::int32_t* sums_ptr,
-                                int input_xor) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeonDotprod, optimized for in-order cores)");
-  asm volatile(
-          // clang-format off
-          "dup v26.16b, %w[input_xor]\n"
-          "mov w1, #1\n"
-          "dup v27.16b, w1\n"
-          "mov w1, #0\n"
-          "dup v28.4s, wzr\n"
-          "dup v29.4s, wzr\n"
-          "dup v30.4s, wzr\n"
-          "dup v31.4s, wzr\n"
-
-          "and w2, %w[rows], #-16\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-          "ldr x10, [%[src_ptr0], #8]\n"
-          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ldr x11, [%[src_ptr1], #8]\n"
-          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ldr x12, [%[src_ptr2], #8]\n"
-          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ldr x13, [%[src_ptr3], #8]\n"
-          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
-          "add w1, w1, #16\n"
-          "cmp w1, w2\n"
-
-          "beq 2f\n"
-
-          "1:\n"
-          "add w1, w1, #16\n"
-          "ins v0.d[1], x10\n"
-          "ldr x10, [%[src_ptr0], #8]\n"
-          "ins v1.d[1], x11\n"
-          "ldr x11, [%[src_ptr1], #8]\n"
-          "ins v2.d[1], x12\n"
-          "ldr x12, [%[src_ptr2], #8]\n"
-          "ins v3.d[1], x13\n"
-          "ldr x13, [%[src_ptr3], #8]\n"
-
-          "eor v4.16b, v0.16b, v26.16b\n"
-          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
-          "eor v5.16b, v1.16b, v26.16b\n"
-          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
-          "eor v6.16b, v2.16b, v26.16b\n"
-          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
-          "eor v7.16b, v3.16b, v26.16b\n"
-          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
-
-          "trn1 v16.4s, v4.4s, v5.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
-          "trn2 v17.4s, v4.4s, v5.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
-          "trn1 v18.4s, v6.4s, v7.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
-          "trn2 v19.4s, v6.4s, v7.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-          "cmp w1, w2\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          "str q20, [%[packed_ptr], #0]\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-          "str q23, [%[packed_ptr], #96]\n"
-
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "bne 1b\n"
-
-          "2:\n"
-          "ins v0.d[1], x10\n"
-          "ins v1.d[1], x11\n"
-          "ins v2.d[1], x12\n"
-          "ins v3.d[1], x13\n"
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          "str q20, [%[packed_ptr], #0]\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #15\n"
-          "beq 4f\n"
-          "dup v0.16b, %w[src_zero_point]\n"
-          "dup v1.16b, %w[src_zero_point]\n"
-          "dup v2.16b, %w[src_zero_point]\n"
-          "dup v3.16b, %w[src_zero_point]\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
-  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
-  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
-  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-          RUY_LOAD_ONE_ROW(4)
-          RUY_LOAD_ONE_ROW(5)
-          RUY_LOAD_ONE_ROW(6)
-          RUY_LOAD_ONE_ROW(7)
-          RUY_LOAD_ONE_ROW(8)
-          RUY_LOAD_ONE_ROW(9)
-          RUY_LOAD_ONE_ROW(10)
-          RUY_LOAD_ONE_ROW(11)
-          RUY_LOAD_ONE_ROW(12)
-          RUY_LOAD_ONE_ROW(13)
-          RUY_LOAD_ONE_ROW(14)
-          RUY_LOAD_ONE_ROW(15)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          "str q20, [%[packed_ptr], #0]\n"
-          "cmp w2, #4\n"
-          "ble 4f\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "cmp w2, #8\n"
-          "ble 4f\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "cmp w2, #12\n"
-          "ble 4f\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "4:\n"
-
-          "add v28.4s, v28.4s, v29.4s\n"
-          "add v30.4s, v30.4s, v31.4s\n"
-          "add v28.4s, v28.4s, v30.4s\n"
-
-          "cmp %[sums_ptr], #0\n"
-          "beq 6f\n"
-          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
-          "6:\n"
-          // clang-format on
-
-          : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
-            [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
-          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
-            [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
-                [rows] "r"(src_rows),
-            [src_zero_point] "r"(static_cast<int>(src_zero_point)),
-            [input_xor] "r"(input_xor)
-          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-            "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
-}
-
-void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
-                                   const void* src_ptr2, const void* src_ptr3,
-                                   int src_inc0, int src_inc1, int src_inc2,
-                                   int src_inc3, int src_rows,
-                                   int src_zero_point, std::int8_t* packed_ptr,
-                                   int start_col, int end_col,
-                                   std::int32_t* sums_ptr, int input_xor) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeonDotprod, optimized for out-of-order cores)");
-  asm volatile(
-      // clang-format off
-          "dup v26.16b, %w[input_xor]\n"
-          "mov w1, #1\n"
-          "dup v27.16b, w1\n"
-          "mov w1, #0\n"
-          "dup v28.4s, wzr\n"
-          "dup v29.4s, wzr\n"
-          "dup v30.4s, wzr\n"
-          "dup v31.4s, wzr\n"
-
-#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
-          "and w2, %w[rows], #-64\n"
-          "cmp w1, w2\n"
-          "beq 9f\n"
-
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #64\n"
-          "cmp w1, w2\n"
-          "beq 8f\n"
-
-          "7:\n"
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v4.16b, v4.16b, v26.16b\n"
-          "eor v5.16b, v5.16b, v26.16b\n"
-          "eor v6.16b, v6.16b, v26.16b\n"
-          "eor v7.16b, v7.16b, v26.16b\n"
-
-          "trn1 v16.4s, v4.4s, v5.4s\n"
-          "trn2 v17.4s, v4.4s, v5.4s\n"
-          "trn1 v18.4s, v6.4s, v7.4s\n"
-          "trn2 v19.4s, v6.4s, v7.4s\n"
-
-          "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v8.16b, v8.16b, v26.16b\n"
-          "eor v9.16b, v9.16b, v26.16b\n"
-          "eor v10.16b, v10.16b, v26.16b\n"
-          "eor v11.16b, v11.16b, v26.16b\n"
-
-          "trn1 v16.4s, v8.4s, v9.4s\n"
-          "trn2 v17.4s, v8.4s, v9.4s\n"
-          "trn1 v18.4s, v10.4s, v11.4s\n"
-          "trn2 v19.4s, v10.4s, v11.4s\n"
-
-          "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v12.16b, v12.16b, v26.16b\n"
-          "eor v13.16b, v13.16b, v26.16b\n"
-          "eor v14.16b, v14.16b, v26.16b\n"
-          "eor v15.16b, v15.16b, v26.16b\n"
-
-          "trn1 v16.4s, v12.4s, v13.4s\n"
-          "trn2 v17.4s, v12.4s, v13.4s\n"
-          "trn1 v18.4s, v14.4s, v15.4s\n"
-          "trn2 v19.4s, v14.4s, v15.4s\n"
-
-          "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "cmp w1, w2\n"
-          "bne 7b\n"
-
-          "8:\n"
-
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v4.16b, v4.16b, v26.16b\n"
-          "eor v5.16b, v5.16b, v26.16b\n"
-          "eor v6.16b, v6.16b, v26.16b\n"
-          "eor v7.16b, v7.16b, v26.16b\n"
-
-          "trn1 v16.4s, v4.4s, v5.4s\n"
-          "trn2 v17.4s, v4.4s, v5.4s\n"
-          "trn1 v18.4s, v6.4s, v7.4s\n"
-          "trn2 v19.4s, v6.4s, v7.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v8.16b, v8.16b, v26.16b\n"
-          "eor v9.16b, v9.16b, v26.16b\n"
-          "eor v10.16b, v10.16b, v26.16b\n"
-          "eor v11.16b, v11.16b, v26.16b\n"
-
-          "trn1 v16.4s, v8.4s, v9.4s\n"
-          "trn2 v17.4s, v8.4s, v9.4s\n"
-          "trn1 v18.4s, v10.4s, v11.4s\n"
-          "trn2 v19.4s, v10.4s, v11.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "eor v12.16b, v12.16b, v26.16b\n"
-          "eor v13.16b, v13.16b, v26.16b\n"
-          "eor v14.16b, v14.16b, v26.16b\n"
-          "eor v15.16b, v15.16b, v26.16b\n"
-
-          "trn1 v16.4s, v12.4s, v13.4s\n"
-          "trn2 v17.4s, v12.4s, v13.4s\n"
-          "trn1 v18.4s, v14.4s, v15.4s\n"
-          "trn2 v19.4s, v14.4s, v15.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "9:\n"
-#endif  // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
-          "and w2, %w[rows], #-16\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-          "cmp w1, w2\n"
-          "beq 2f\n"
-
-          "1:\n"
-
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #16\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "cmp w1, w2\n"
-          "bne 1b\n"
-
-          "2:\n"
-
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #15\n"
-          "beq 4f\n"
-          "dup v0.16b, %w[src_zero_point]\n"
-          "dup v1.16b, %w[src_zero_point]\n"
-          "dup v2.16b, %w[src_zero_point]\n"
-          "dup v3.16b, %w[src_zero_point]\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
-  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
-  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
-  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-          RUY_LOAD_ONE_ROW(4)
-          RUY_LOAD_ONE_ROW(5)
-          RUY_LOAD_ONE_ROW(6)
-          RUY_LOAD_ONE_ROW(7)
-          RUY_LOAD_ONE_ROW(8)
-          RUY_LOAD_ONE_ROW(9)
-          RUY_LOAD_ONE_ROW(10)
-          RUY_LOAD_ONE_ROW(11)
-          RUY_LOAD_ONE_ROW(12)
-          RUY_LOAD_ONE_ROW(13)
-          RUY_LOAD_ONE_ROW(14)
-          RUY_LOAD_ONE_ROW(15)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "eor v0.16b, v0.16b, v26.16b\n"
-          "eor v1.16b, v1.16b, v26.16b\n"
-          "eor v2.16b, v2.16b, v26.16b\n"
-          "eor v3.16b, v3.16b, v26.16b\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
-          "str q20, [%[packed_ptr], #0]\n"
-          "cmp w2, #4\n"
-          "ble 4f\n"
-          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "cmp w2, #8\n"
-          "ble 4f\n"
-          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "cmp w2, #12\n"
-          "ble 4f\n"
-          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "4:\n"
-
-          "add v28.4s, v28.4s, v29.4s\n"
-          "add v30.4s, v30.4s, v31.4s\n"
-          "add v28.4s, v28.4s, v30.4s\n"
-
-          "cmp %[sums_ptr], #0\n"
-          "beq 6f\n"
-          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
-          "6:\n"
-      // clang-format on
-
-      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
-        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
-        [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
-      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
-        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
-        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
-        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
-        [ rows ] "r"(src_rows),
-        [ src_zero_point ] "r"(static_cast<int>(src_zero_point)),
-        [ input_xor ] "r"(input_xor)
-      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
-        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
-        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
-        "v27", "v28", "v29", "v30", "v31");
-}
-
-#endif  // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
-                             const float* src_ptr2, const float* src_ptr3,
-                             int src_inc0, int src_inc1, int src_inc2,
-                             int src_inc3, int src_rows, int src_zero_point,
-                             float* packed_ptr, int start_col, int end_col) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeon, optimized for out-of-order cores)");
-  asm volatile(
-      // clang-format off
-          "mov w1, #0\n"
-
-          "and w2, %w[rows], #-4\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
-          "add w1, w1, #4\n"
-          "cmp w1, w2\n"
-
-          "beq 2f\n"
-
-          "1:\n"
-          "add w1, w1, #4\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-          "cmp w1, w2\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "bne 1b\n"
-
-          "2:\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #3\n"
-          "beq 4f\n"
-          "dup v0.16b, wzr\n"
-          "dup v1.16b, wzr\n"
-          "dup v2.16b, wzr\n"
-          "dup v3.16b, wzr\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
-  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
-  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
-  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          "mov x1, #32\n"
-
-#define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
-          "cmp w2, #" #ROW "\n"                           \
-          "beq 4f\n"                                      \
-          "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
-
-          RUY_STORE_ONE_ROW(0, v20)
-          RUY_STORE_ONE_ROW(1, v21)
-          RUY_STORE_ONE_ROW(2, v22)
-          RUY_STORE_ONE_ROW(3, v23)
-
-#undef RUY_STORE_ONE_ROW
-
-          "4:\n"
-
-      // clang-format on
-
-      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
-        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
-        [ packed_ptr ] "+r"(packed_ptr)
-      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
-        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
-        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
-        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
-        [ rows ] "r"(src_rows)
-      : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
-        "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
-        "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
-        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
-}
-#endif
-
-#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
-                             const float* src_ptr2, const float* src_ptr3,
-                             int src_inc, int src_rows, int src_zero_point,
-                             float* packed_ptr, int start_col, int end_col,
-                             int output_stride) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeon, optimized for out-of-order cores)");
-  asm volatile(
-      // clang-format off
-          "mov r1, #0\n"
-          "and r2, %[rows], #-4\n"
-          "cmp r1, r2\n"
-          "beq 3f\n"
-#define RUY_LOAD_FOUR_BY_FOUR()               \
-  /* Load q0 */                               \
-  "vldr d0, [%[src_ptr0], #0]\n"              \
-  "vldr d1, [%[src_ptr0], #8]\n"              \
-  /* if src_inc0 != 0, add 16 to src_ptr0 */  \
-  "and r3, %[src_inc], #1\n"                  \
-  "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
-  /* Load q1 */                               \
-  "vldr d2, [%[src_ptr1], #0]\n"              \
-  "vldr d3, [%[src_ptr1], #8]\n"              \
-  /* if src_inc1 != 0, add 16 to src_ptr0 */  \
-  "and r3, %[src_inc], #2\n"                  \
-  "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
-  /* Load q2 */                               \
-  "vldr d4, [%[src_ptr2], #0]\n"              \
-  "vldr d5, [%[src_ptr2], #8]\n"              \
-  /* if src_inc2 != 0, add 16 to src_ptr0 */  \
-  "and r3, %[src_inc], #4\n"                  \
-  "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
-  /* Load q3 */                               \
-  "vldr d6, [%[src_ptr3], #0]\n"              \
-  "vldr d7, [%[src_ptr3], #8]\n"              \
-  /* if src_inc3 != 0, add 16 to src_ptr0 */  \
-  "and r3, %[src_inc], #8\n"                  \
-  "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\
-
-          RUY_LOAD_FOUR_BY_FOUR()
-          "add r1, r1, #4\n"
-          "cmp r1, r2\n"
-
-          "beq 2f\n"
-
-          "1:\n"
-          "add r1, r1, #4\n"
-
-          // Transpose 4x4 matrix.
-          "vzip.32 q0, q1\n"
-          "vzip.32 q2, q3\n"
-
-          "vtrn.32 q0, q2\n"
-          "vtrn.32 q1, q3\n"
-
-          "vzip.32 q0, q2\n"
-          "vzip.32 q1, q3\n"
-
-          "vmov q8, q0\n"
-          "vmov q9, q1\n"
-          "vmov q10, q2\n"
-          "vmov q11, q3\n"
-
-          RUY_LOAD_FOUR_BY_FOUR()
-#undef RUY_LOAD_FOUR_BY_FOUR
-
-#define RUY_STORE_FOUR_BY_FOUR()                  \
-  /* Store q8, q10, q9, q11 */                    \
-  /* q8 = d16, d17 */                             \
-  "vstr d16, [%[packed_ptr], #0]\n"               \
-  "vstr d17, [%[packed_ptr], #8]\n"               \
-  /* q10 = d20, d21 */                            \
-  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
-  "vstr d20, [%[packed_ptr], #0]\n"               \
-  "vstr d21, [%[packed_ptr], #8]\n"               \
-  /* q9 = d18, d19 */                             \
-  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
-  "vstr d18, [%[packed_ptr], #0]\n"               \
-  "vstr d19, [%[packed_ptr], #8]\n"               \
-  /* q11 = d22, d23 */                            \
-  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
-  "vstr d22, [%[packed_ptr], #0]\n"               \
-  "vstr d23, [%[packed_ptr], #8]\n"               \
-  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
-
-          RUY_STORE_FOUR_BY_FOUR()
-          "cmp r1, r2\n"
-
-          "bne 1b\n"
-
-          "2:\n"
-
-          // Transpose 4x4 matrix.
-          "vzip.32 q0, q1\n"
-          "vzip.32 q2, q3\n"
-
-          "vtrn.32 q0, q2\n"
-          "vtrn.32 q1, q3\n"
-
-          "vzip.32 q0, q2\n"
-          "vzip.32 q1, q3\n"
-
-          "vmov q8, q0\n"
-          "vmov q9, q1\n"
-          "vmov q10, q2\n"
-          "vmov q11, q3\n"
-
-          RUY_STORE_FOUR_BY_FOUR()
-#undef RUY_STORE_FOUR_BY_FOUR
-          "3:\n"
-
-          "ands r2, %[rows], #3\n"
-          "beq 4f\n"
-          "mov r0, 0\n"
-          // Zero out q0 - q3
-          "vdup.32 q0, r0\n"
-          "vdup.32 q1, r0\n"
-          "vdup.32 q2, r0\n"
-          "vdup.32 q3, r0\n"
-#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I)    \
-  "cmp r2, #" #R "\n"                        \
-  "beq 5f\n"                                 \
-  "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
-  "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
-  "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
-  "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"
-
-#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I)      \
-  "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
-  "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
-  "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
-  "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"
-
-          RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
-          RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
-          RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
-          RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
-#undef RUY_LOAD_ONE_ROW_SECOND_HALF
-#undef RUY_LOAD_ONE_ROW_FIRST_HALF
-          "5:\n"
-
-          // Transpose 4x4 matrix.
-          "vzip.32 q0, q1\n"
-          "vzip.32 q2, q3\n"
-
-          "vtrn.32 q0, q2\n"
-          "vtrn.32 q1, q3\n"
-
-          "vzip.32 q0, q2\n"
-          "vzip.32 q1, q3\n"
-
-          "vmov q8, q0\n"
-          "vmov q9, q1\n"
-          "vmov q10, q2\n"
-          "vmov q11, q3\n"
-
-          "mov r1, #32\n"
-
-#define RUY_STORE_ONE_ROW(ROW, REGISTER1, REGISTER2)      \
-          "cmp r2, #" #ROW "\n"                           \
-          "beq 4f\n"                                      \
-          "vstr " #REGISTER1 ", [%[packed_ptr]]\n"    \
-          "vstr " #REGISTER2 ", [%[packed_ptr], #8]\n"    \
-          "add %[packed_ptr], %[packed_ptr], %[stride]\n"
-
-          // Store q8
-          RUY_STORE_ONE_ROW(0, d16, d17)
-          // Store q10
-          RUY_STORE_ONE_ROW(1, d20, d21)
-          // Store q9
-          RUY_STORE_ONE_ROW(2, d18, d19)
-          // Store q11
-          RUY_STORE_ONE_ROW(3, d22, d23)
-
-#undef RUY_STORE_ONE_ROW
-
-          "4:\n"
-
-      // clang-format on
-      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
-        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
-        [ packed_ptr ] "+r"(packed_ptr)
-      : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)),
-        [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
-      : "cc", "memory", "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3",
-        "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13",
-        "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23");
-}
-
-#endif  // (RUY_PLATFORM(NEON_32)
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
-                          const float* src_ptr2, const float* src_ptr3,
-                          int src_inc0, int src_inc1, int src_inc2,
-                          int src_inc3, int src_rows, int src_zero_point,
-                          float* packed_ptr, int start_col, int end_col) {
-  gemmlowp::ScopedProfilingLabel label(
-      "Pack (kNeon, optimized for in-order cores)");
-
-  asm volatile(
-          // clang-format off
-          "mov w1, #0\n"
-
-          "and w2, %w[rows], #-4\n"
-          "cmp w1, w2\n"
-          "beq 3f\n"
-          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
-          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
-          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
-          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
-          "add w1, w1, #4\n"
-          "cmp w1, w2\n"
-
-          "beq 2f\n"
-
-          "1:\n"
-          "add w1, w1, #4\n"
-
-          "ldr x10, [%[src_ptr0], #8]\n"
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
-          "ldr x11, [%[src_ptr1], #8]\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
-          "ldr x12, [%[src_ptr2], #8]\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
-          "ldr x13, [%[src_ptr3], #8]\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
-
-          "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-          "cmp w1, w2\n"
-
-          "ins v0.d[1], x10\n"
-          "str q20, [%[packed_ptr], #0]\n"
-          "ins v1.d[1], x11\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "ins v2.d[1], x12\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "ins v3.d[1], x13\n"
-          "str q23, [%[packed_ptr], #96]\n"
-
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "bne 1b\n"
-
-          "2:\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          "str q20, [%[packed_ptr], #0]\n"
-          "str q21, [%[packed_ptr], #32]\n"
-          "str q22, [%[packed_ptr], #64]\n"
-          "str q23, [%[packed_ptr], #96]\n"
-          "add %[packed_ptr], %[packed_ptr], #128\n"
-
-          "3:\n"
-
-          "ands w2, %w[rows], #3\n"
-          "beq 4f\n"
-          "dup v0.16b, wzr\n"
-          "dup v1.16b, wzr\n"
-          "dup v2.16b, wzr\n"
-          "dup v3.16b, wzr\n"
-#define RUY_LOAD_ONE_ROW(R)                   \
-  "cmp w2, #" #R "\n"                         \
-  "beq 5f\n"                                  \
-  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
-  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
-  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
-  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
-
-          RUY_LOAD_ONE_ROW(0)
-          RUY_LOAD_ONE_ROW(1)
-          RUY_LOAD_ONE_ROW(2)
-          RUY_LOAD_ONE_ROW(3)
-#undef RUY_LOAD_ONE_ROW
-          "5:\n"
-
-          "trn1 v16.4s, v0.4s, v1.4s\n"
-          "trn2 v17.4s, v0.4s, v1.4s\n"
-          "trn1 v18.4s, v2.4s, v3.4s\n"
-          "trn2 v19.4s, v2.4s, v3.4s\n"
-
-          "trn1 v20.2d, v16.2d, v18.2d\n"
-          "trn2 v22.2d, v16.2d, v18.2d\n"
-          "trn1 v21.2d, v17.2d, v19.2d\n"
-          "trn2 v23.2d, v17.2d, v19.2d\n"
-
-          "mov x1, #32\n"
-
-#define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
-          "cmp w2, #" #ROW "\n"                           \
-          "beq 4f\n"                                      \
-          "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
-
-          RUY_STORE_ONE_ROW(0, v20)
-          RUY_STORE_ONE_ROW(1, v21)
-          RUY_STORE_ONE_ROW(2, v22)
-          RUY_STORE_ONE_ROW(3, v23)
-
-#undef RUY_STORE_ONE_ROW
-
-          "4:\n"
-
-          // clang-format on
-
-          : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
-            [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr)
-          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
-            [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), [rows] "r"(src_rows)
-          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-            "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
-}
-#endif  // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/pack.h b/tensorflow/lite/experimental/ruy/pack.h
index 8a4034c..f1dc1b6 100644
--- a/tensorflow/lite/experimental/ruy/pack.h
+++ b/tensorflow/lite/experimental/ruy/pack.h
@@ -83,420 +83,16 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_
 
-#include <cstdint>
-
-#include "profiling/instrumentation.h"
-#include "tensorflow/lite/experimental/ruy/common.h"
-#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
-#include "tensorflow/lite/experimental/ruy/opt_set.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
-#include "tensorflow/lite/experimental/ruy/tune.h"
 
-namespace ruy {
-
-template <Path ThePath, typename Scalar>
-struct PackedTypeImpl {
-  using Type = Scalar;
-};
-
-template <>
-struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
-  using Type = std::int8_t;
-};
-template <>
-struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
-  using Type = std::int8_t;
-};
-
-template <Path ThePath, typename Scalar>
-using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
-
-template <typename PackedScalar, typename Scalar>
-PackedScalar Pack(Scalar x) {
-  return x - SymmetricZeroPoint<Scalar>() + SymmetricZeroPoint<PackedScalar>();
-}
-
-template <Path ThePath, typename FixedKernelLayout, typename Scalar,
-          typename PackedScalar, typename SumsType>
-struct PackImpl {};
-
-#define RUY_INHERIT_PACK(PARENT, CHILD)                                       \
-  template <typename FixedKernelLayout, typename Scalar,                      \
-            typename PackedScalar, typename SumsType>                         \
-  struct PackImpl<CHILD, FixedKernelLayout, Scalar, PackedScalar, SumsType>   \
-      : PackImpl<PARENT, FixedKernelLayout, Scalar, PackedScalar, SumsType> { \
-  };
-
-template <typename FixedKernelLayout, typename Scalar, typename PackedScalar,
-          typename SumsType>
-struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
-                SumsType> {
-  static void Run(Tuning, const Matrix<Scalar>& src_matrix,
-                  PackedMatrix<PackedScalar>* packed_matrix, int start_col,
-                  int end_col) {
-    gemmlowp::ScopedProfilingLabel label("Pack (generic)");
-    RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0);
-    SumsType* sums = packed_matrix->sums;
-    for (int col = start_col; col < end_col; col++) {
-      SumsType accum = 0;
-      for (int row = 0; row < packed_matrix->layout.rows; row++) {
-        PackedScalar packed_val;
-        if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) {
-          packed_val = Pack<PackedScalar>(Element(src_matrix, row, col));
-        } else {
-          packed_val = packed_matrix->zero_point;
-        }
-        accum += packed_val;
-        relaxed_atomic_store(ElementPtr(packed_matrix, row, col), packed_val);
-      }
-      if (sums) {
-        relaxed_atomic_store(sums + col, accum);
-      }
-    }
-  }
-};
-
-RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
-#endif
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
-                            const void* src_ptr2, const void* src_ptr3,
-                            int src_inc0, int src_inc1, int src_inc2,
-                            int src_inc3, int src_rows, int src_zero_point,
-                            std::int8_t* packed_ptr, int start_col, int end_col,
-                            std::int32_t* sums_ptr, int input_xor);
-void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
-                         const void* src_ptr2, const void* src_ptr3,
-                         int src_inc0, int src_inc1, int src_inc2, int src_inc3,
-                         int src_rows, int src_zero_point,
-                         std::int8_t* packed_ptr, int start_col, int end_col,
-                         std::int32_t* sums_ptr, int input_xor);
-void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
-                                   const void* src_ptr2, const void* src_ptr3,
-                                   int src_inc0, int src_inc1, int src_inc2,
-                                   int src_inc3, int src_rows,
-                                   int src_zero_point, std::int8_t* packed_ptr,
-                                   int start_col, int end_col,
-                                   std::int32_t* sums_ptr, int input_xor);
-void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
-                                const void* src_ptr2, const void* src_ptr3,
-                                int src_inc0, int src_inc1, int src_inc2,
-                                int src_inc3, int src_rows, int src_zero_point,
-                                std::int8_t* packed_ptr, int start_col,
-                                int end_col, std::int32_t* sums_ptr,
-                                int input_xor);
-
-template <typename Scalar>
-struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
-                std::int8_t, std::int32_t> {
-  static_assert(std::is_same<Scalar, std::int8_t>::value ||
-                    std::is_same<Scalar, std::uint8_t>::value,
-                "");
-  static constexpr int kInputXor =
-      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
-
-  static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
-                  PackedMatrix<std::int8_t>* packed_matrix, int start_col,
-                  int end_col) {
-    RUY_DCHECK(IsColMajor(src_matrix.layout));
-    RUY_DCHECK(IsColMajor(packed_matrix->layout));
-    RUY_DCHECK_EQ(start_col % 4, 0);
-    std::int32_t* sums = packed_matrix->sums;
-    Scalar zerobuf[16];
-    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
-    for (int block_col = start_col; block_col < end_col; block_col += 4) {
-      int src_stride = src_matrix.layout.stride;
-      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
-      const Scalar* src_ptr1 = src_ptr0 + src_stride;
-      const Scalar* src_ptr2 = src_ptr1 + src_stride;
-      const Scalar* src_ptr3 = src_ptr2 + src_stride;
-      int src_inc0 = 16;
-      int src_inc1 = 16;
-      int src_inc2 = 16;
-      int src_inc3 = 16;
-      if (block_col >= src_matrix.layout.cols - 3) {
-        if (block_col >= src_matrix.layout.cols - 0) {
-          src_ptr0 = zerobuf;
-          src_inc0 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 1) {
-          src_ptr1 = zerobuf;
-          src_inc1 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 2) {
-          src_ptr2 = zerobuf;
-          src_inc2 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 3) {
-          src_ptr3 = zerobuf;
-          src_inc3 = 0;
-        }
-      }
-      std::int8_t* packed_ptr =
-          packed_matrix->data + packed_matrix->layout.stride * block_col;
-      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
-      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-        Pack8bitNeonInOrder(
-            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
-            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
-            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
-      } else {
-        Pack8bitNeonOutOfOrder(
-            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
-            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
-            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
-      }
-    }
-  }
-};
-
-template <typename Scalar>
-struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
-                Scalar, std::int8_t, std::int32_t> {
-  static_assert(std::is_same<Scalar, std::int8_t>::value ||
-                    std::is_same<Scalar, std::uint8_t>::value,
-                "");
-  static constexpr int kInputXor =
-      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
-
-  static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
-                  PackedMatrix<std::int8_t>* packed_matrix, int start_col,
-                  int end_col) {
-    RUY_DCHECK(IsColMajor(src_matrix.layout));
-    RUY_DCHECK(IsColMajor(packed_matrix->layout));
-    RUY_DCHECK_EQ(start_col % 8, 0);
-    std::int32_t* sums = packed_matrix->sums;
-    Scalar zerobuf[16];
-    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
-    for (int block_col = start_col; block_col < end_col; block_col += 4) {
-      int src_stride = src_matrix.layout.stride;
-      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
-      const Scalar* src_ptr1 = src_ptr0 + src_stride;
-      const Scalar* src_ptr2 = src_ptr1 + src_stride;
-      const Scalar* src_ptr3 = src_ptr2 + src_stride;
-      std::int64_t src_inc0 = 16;
-      std::int64_t src_inc1 = 16;
-      std::int64_t src_inc2 = 16;
-      std::int64_t src_inc3 = 16;
-      if (block_col >= src_matrix.layout.cols - 3) {
-        if (block_col >= src_matrix.layout.cols - 0) {
-          src_ptr0 = zerobuf;
-          src_inc0 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 1) {
-          src_ptr1 = zerobuf;
-          src_inc1 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 2) {
-          src_ptr2 = zerobuf;
-          src_inc2 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 3) {
-          src_ptr3 = zerobuf;
-          src_inc3 = 0;
-        }
-      }
-      std::int8_t* packed_ptr =
-          packed_matrix->data +
-          packed_matrix->layout.stride * (block_col & ~7) +
-          ((block_col & 4) * 4);
-      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
-      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-        Pack8bitNeonDotprodInOrder(
-            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
-            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
-            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
-      } else {
-        Pack8bitNeonDotprodOutOfOrder(
-            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
-            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
-            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
-      }
-    }
-  }
-};
-#endif  // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
-                             const float* src_ptr2, const float* src_ptr3,
-                             int src_inc0, int src_inc1, int src_inc2,
-                             int src_inc3, int src_rows, int src_zero_point,
-                             float* packed_ptr, int start_col, int end_col);
-void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
-                          const float* src_ptr2, const float* src_ptr3,
-                          int src_inc0, int src_inc1, int src_inc2,
-                          int src_inc3, int src_rows, int src_zero_point,
-                          float* packed_ptr, int start_col, int end_col);
-
-#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
-                             const float* src_ptr2, const float* src_ptr3,
-                             int src_inc, int src_rows, int src_zero_point,
-                             float* packed_ptr, int start_col, int end_col,
-                             int stride);
-#endif  // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \
-    RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-template <>
-struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
-                float, float> {
-  static void Run(Tuning tuning, const Matrix<float>& src_matrix,
-                  PackedMatrix<float>* packed_matrix, int start_col,
-                  int end_col) {
-    RUY_DCHECK(IsColMajor(src_matrix.layout));
-    RUY_DCHECK(IsColMajor(packed_matrix->layout));
-    RUY_DCHECK_EQ(start_col % 8, 0);
-    const float zerobuf[4] = {0};
-    for (int block_col = start_col; block_col < end_col; block_col += 4) {
-      int src_stride = src_matrix.layout.stride;
-      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
-      const float* src_ptr1 = src_ptr0 + src_stride;
-      const float* src_ptr2 = src_ptr1 + src_stride;
-      const float* src_ptr3 = src_ptr2 + src_stride;
-      std::int64_t src_inc0 = 16;
-      std::int64_t src_inc1 = 16;
-      std::int64_t src_inc2 = 16;
-      std::int64_t src_inc3 = 16;
-      if (block_col >= src_matrix.layout.cols - 3) {
-        if (block_col >= src_matrix.layout.cols - 0) {
-          src_ptr0 = zerobuf;
-          src_inc0 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 1) {
-          src_ptr1 = zerobuf;
-          src_inc1 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 2) {
-          src_ptr2 = zerobuf;
-          src_inc2 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 3) {
-          src_ptr3 = zerobuf;
-          src_inc3 = 0;
-        }
-      }
-      float* packed_ptr = packed_matrix->data +
-                          packed_matrix->layout.stride * (block_col & ~7) +
-                          ((block_col & 4));
-#if RUY_PLATFORM(NEON_64)
-      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
-        PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0,
-                             src_inc1, src_inc2, src_inc3,
-                             src_matrix.layout.rows, src_matrix.zero_point,
-                             packed_ptr, start_col, end_col);
-      } else {
-        PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
-                                src_inc0, src_inc1, src_inc2, src_inc3,
-                                src_matrix.layout.rows, src_matrix.zero_point,
-                                packed_ptr, start_col, end_col);
-      }
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM(NEON)
+#include "tensorflow/lite/experimental/ruy/pack_arm.h"
+#elif RUY_PLATFORM(AVX512)
+#include "tensorflow/lite/experimental/ruy/pack_x86.h"
 #else
-      // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
-      // to save on registers (we have fewer general purpose registers in
-      // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
-      // values that are each either 16 or 0 and use them directly. For the
-      // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
-      // use the value 16 (bit is set) or 0 (bit is not set) for the
-      // respective increment value.
-      std::int64_t src_inc = 0;
-      src_inc += src_inc0 == 16 ? 1 : 0;
-      src_inc += src_inc1 == 16 ? 2 : 0;
-      src_inc += src_inc2 == 16 ? 4 : 0;
-      src_inc += src_inc3 == 16 ? 8 : 0;
-      const int kOutputStride = 32;
-      PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
-                              src_matrix.layout.rows, src_matrix.zero_point,
-                              packed_ptr, start_col, end_col, kOutputStride);
-#endif  // RUY_PLATFORM(NEON_64)
-    }
-  }
-};
-
-#if RUY_PLATFORM(NEON_32)
-// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
-// specialization for a FixedKernelLayout with 4 columns.
-template <>
-struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
-                float, float> {
-  static void Run(Tuning tuning, const Matrix<float>& src_matrix,
-                  PackedMatrix<float>* packed_matrix, int start_col,
-                  int end_col) {
-    RUY_DCHECK(IsColMajor(src_matrix.layout));
-    RUY_DCHECK(IsColMajor(packed_matrix->layout));
-    RUY_DCHECK_EQ(start_col % 4, 0);
-    const float zerobuf[4] = {0};
-    for (int block_col = start_col; block_col < end_col; block_col += 4) {
-      int src_stride = src_matrix.layout.stride;
-      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
-      const float* src_ptr1 = src_ptr0 + src_stride;
-      const float* src_ptr2 = src_ptr1 + src_stride;
-      const float* src_ptr3 = src_ptr2 + src_stride;
-      std::int64_t src_inc0 = 16;
-      std::int64_t src_inc1 = 16;
-      std::int64_t src_inc2 = 16;
-      std::int64_t src_inc3 = 16;
-      if (block_col >= src_matrix.layout.cols - 3) {
-        if (block_col >= src_matrix.layout.cols - 0) {
-          src_ptr0 = zerobuf;
-          src_inc0 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 1) {
-          src_ptr1 = zerobuf;
-          src_inc1 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 2) {
-          src_ptr2 = zerobuf;
-          src_inc2 = 0;
-        }
-        if (block_col >= src_matrix.layout.cols - 3) {
-          src_ptr3 = zerobuf;
-          src_inc3 = 0;
-        }
-      }
-      float* packed_ptr =
-          packed_matrix->data + packed_matrix->layout.stride * (block_col);
-      // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc
-      // to save registers.
-      std::int64_t src_inc = 0;
-      src_inc += src_inc0 == 16 ? 1 : 0;
-      src_inc += src_inc1 == 16 ? 2 : 0;
-      src_inc += src_inc2 == 16 ? 4 : 0;
-      src_inc += src_inc3 == 16 ? 8 : 0;
-      const int kOutputStride = 16;
-      PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
-                              src_matrix.layout.rows, src_matrix.zero_point,
-                              packed_ptr, start_col, end_col, kOutputStride);
-    }
-  }
-};
-#endif  // (RUY_PLATFORM(NEON_32))
-#endif  // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
-        // RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-// Main entry point for packing.
-template <Path ThePath, typename FixedKernelLayout, typename Scalar,
-          typename PackedScalar>
-void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix,
-             int start_col, int end_col) {
-  using SumsType = typename PackedMatrix<PackedScalar>::SumsType;
-  Matrix<Scalar> src = ToMatrix<Scalar>(src_matrix);
-  PackedMatrix<PackedScalar> packed =
-      ToPackedMatrix<PackedScalar>(*packed_matrix);
-  PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType>::Run(
-      tuning, src, &packed, start_col, end_col);
-}
-
-// The signature of RunPack is the same, regardless of its template parameters.
-using RunPackFn = decltype(
-    RunPack<Path::kStandardCpp, FixedKernelLayout<Order::kColMajor, 1, 1>,
-            std::int8_t, std::int8_t>);
-
-}  // namespace ruy
+#include "tensorflow/lite/experimental/ruy/pack_common.h"
+#endif
+// IWYU pragma: end_exports
 
 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_
diff --git a/tensorflow/lite/experimental/ruy/pack_arm.cc b/tensorflow/lite/experimental/ruy/pack_arm.cc
new file mode 100644
index 0000000..84db027
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/pack_arm.cc
@@ -0,0 +1,1527 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/pack.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+                            const void* src_ptr2, const void* src_ptr3,
+                            int src_inc0, int src_inc1, int src_inc2,
+                            int src_inc3, int src_rows, int src_zero_point,
+                            std::int8_t* packed_ptr, int start_col, int end_col,
+                            std::int32_t* sums_ptr, int input_xor) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeon, optimized for out-of-order cores)");
+  asm volatile(
+      // clang-format off
+          "dup v26.16b, %w[input_xor]\n"
+          "mov w1, #0\n"
+          "dup v28.4s, wzr\n"
+          "dup v29.4s, wzr\n"
+          "dup v30.4s, wzr\n"
+          "dup v31.4s, wzr\n"
+
+          "and w2, %w[rows], #-16\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+
+          "add w1, w1, #16\n"
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "cmp w1, w2\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "beq 2f\n"
+
+          "1:\n"
+
+          "add w1, w1, #16\n"
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+
+          "saddlp v16.8h, v4.16b\n"
+          "str q4, [%[packed_ptr], #0]\n"
+          "saddlp v17.8h, v5.16b\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "saddlp v18.8h, v6.16b\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "saddlp v19.8h, v7.16b\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "sadalp v28.4s, v16.8h\n"
+          "cmp w1, w2\n"
+          "sadalp v29.4s, v17.8h\n"
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+          "sadalp v30.4s, v18.8h\n"
+          "sadalp v31.4s, v19.8h\n"
+
+          "bne 1b\n"
+
+          "2:\n"
+
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+
+          "saddlp v16.8h, v4.16b\n"
+          "str q4, [%[packed_ptr], #0]\n"
+          "saddlp v17.8h, v5.16b\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "saddlp v18.8h, v6.16b\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "saddlp v19.8h, v7.16b\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "sadalp v28.4s, v16.8h\n"
+          "sadalp v29.4s, v17.8h\n"
+          "sadalp v30.4s, v18.8h\n"
+          "sadalp v31.4s, v19.8h\n"
+
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #15\n"
+          "beq 4f\n"
+          "dup v0.16b, %w[src_zero_point]\n"
+          "dup v1.16b, %w[src_zero_point]\n"
+          "dup v2.16b, %w[src_zero_point]\n"
+          "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+          RUY_LOAD_ONE_ROW(4)
+          RUY_LOAD_ONE_ROW(5)
+          RUY_LOAD_ONE_ROW(6)
+          RUY_LOAD_ONE_ROW(7)
+          RUY_LOAD_ONE_ROW(8)
+          RUY_LOAD_ONE_ROW(9)
+          RUY_LOAD_ONE_ROW(10)
+          RUY_LOAD_ONE_ROW(11)
+          RUY_LOAD_ONE_ROW(12)
+          RUY_LOAD_ONE_ROW(13)
+          RUY_LOAD_ONE_ROW(14)
+          RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+
+          "saddlp v16.8h, v4.16b\n"
+          "saddlp v17.8h, v5.16b\n"
+          "saddlp v18.8h, v6.16b\n"
+          "saddlp v19.8h, v7.16b\n"
+          "sadalp v28.4s, v16.8h\n"
+          "sadalp v29.4s, v17.8h\n"
+          "sadalp v30.4s, v18.8h\n"
+          "sadalp v31.4s, v19.8h\n"
+
+          "str q4, [%[packed_ptr], #0]\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+
+          "4:\n"
+
+          "addp v28.4s, v28.4s, v29.4s\n"
+          "addp v30.4s, v30.4s, v31.4s\n"
+          "addp v28.4s, v28.4s, v30.4s\n"
+
+          "cmp %[sums_ptr], #0\n"
+          "beq 6f\n"
+          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+          "6:\n"
+      // clang-format on
+
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+        [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point),
+        [ input_xor ] "r"(input_xor)
+      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+        "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
+                         const void* src_ptr2, const void* src_ptr3,
+                         int src_inc0, int src_inc1, int src_inc2, int src_inc3,
+                         int src_rows, int src_zero_point,
+                         std::int8_t* packed_ptr, int start_col, int end_col,
+                         std::int32_t* sums_ptr, int input_xor) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeon, optimized for in-order cores)");
+  asm volatile(
+          // clang-format off
+          "dup v26.16b, %w[input_xor]\n"
+          "mov w1, #0\n"
+          "dup v28.4s, wzr\n"
+          "dup v29.4s, wzr\n"
+          "dup v30.4s, wzr\n"
+          "dup v31.4s, wzr\n"
+
+          "and w2, %w[rows], #-16\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+          "ldr x10, [%[src_ptr0], #8]\n"
+          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ldr x11, [%[src_ptr1], #8]\n"
+          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ldr x12, [%[src_ptr2], #8]\n"
+          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ldr x13, [%[src_ptr3], #8]\n"
+          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
+          "add w1, w1, #16\n"
+          "cmp w1, w2\n"
+
+          "beq 2f\n"
+
+          "1:\n"
+          "add w1, w1, #16\n"
+          "ins v0.d[1], x10\n"
+          "ldr x10, [%[src_ptr0], #8]\n"
+          "ins v1.d[1], x11\n"
+          "ldr x11, [%[src_ptr1], #8]\n"
+          "ins v2.d[1], x12\n"
+          "ldr x12, [%[src_ptr2], #8]\n"
+          "ins v3.d[1], x13\n"
+          "ldr x13, [%[src_ptr3], #8]\n"
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+          "saddlp v16.8h, v4.16b\n"
+          "str q4, [%[packed_ptr], #0]\n"
+          "saddlp v17.8h, v5.16b\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "saddlp v18.8h, v6.16b\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "saddlp v19.8h, v7.16b\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "sadalp v28.4s, v16.8h\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
+          "cmp w1, w2\n"
+          "sadalp v29.4s, v17.8h\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+          "sadalp v30.4s, v18.8h\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
+          "sadalp v31.4s, v19.8h\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+          "bne 1b\n"
+
+          "2:\n"
+          "ins v0.d[1], x10\n"
+          "ins v1.d[1], x11\n"
+          "ins v2.d[1], x12\n"
+          "ins v3.d[1], x13\n"
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+
+          "saddlp v16.8h, v4.16b\n"
+          "str q4, [%[packed_ptr], #0]\n"
+          "saddlp v17.8h, v5.16b\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "saddlp v18.8h, v6.16b\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "saddlp v19.8h, v7.16b\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "sadalp v28.4s, v16.8h\n"
+          "sadalp v29.4s, v17.8h\n"
+          "sadalp v30.4s, v18.8h\n"
+          "sadalp v31.4s, v19.8h\n"
+
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #15\n"
+          "beq 4f\n"
+          "dup v0.16b, %w[src_zero_point]\n"
+          "dup v1.16b, %w[src_zero_point]\n"
+          "dup v2.16b, %w[src_zero_point]\n"
+          "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+          RUY_LOAD_ONE_ROW(4)
+          RUY_LOAD_ONE_ROW(5)
+          RUY_LOAD_ONE_ROW(6)
+          RUY_LOAD_ONE_ROW(7)
+          RUY_LOAD_ONE_ROW(8)
+          RUY_LOAD_ONE_ROW(9)
+          RUY_LOAD_ONE_ROW(10)
+          RUY_LOAD_ONE_ROW(11)
+          RUY_LOAD_ONE_ROW(12)
+          RUY_LOAD_ONE_ROW(13)
+          RUY_LOAD_ONE_ROW(14)
+          RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+
+          "saddlp v16.8h, v4.16b\n"
+          "saddlp v17.8h, v5.16b\n"
+          "saddlp v18.8h, v6.16b\n"
+          "saddlp v19.8h, v7.16b\n"
+          "sadalp v28.4s, v16.8h\n"
+          "sadalp v29.4s, v17.8h\n"
+          "sadalp v30.4s, v18.8h\n"
+          "sadalp v31.4s, v19.8h\n"
+
+          "str q4, [%[packed_ptr], #0]\n"
+          "str q5, [%[packed_ptr], #16]\n"
+          "str q6, [%[packed_ptr], #32]\n"
+          "str q7, [%[packed_ptr], #48]\n"
+          "add %[packed_ptr], %[packed_ptr], #64\n"
+
+          "4:\n"
+
+          "addp v28.4s, v28.4s, v29.4s\n"
+          "addp v30.4s, v30.4s, v31.4s\n"
+          "addp v28.4s, v28.4s, v30.4s\n"
+
+          "cmp %[sums_ptr], #0\n"
+          "beq 6f\n"
+          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+          "6:\n"
+          // clang-format on
+
+          : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+            [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+            [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+            [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+            [ rows ] "r"(src_rows),
+            [ src_zero_point ] "r"(src_zero_point),
+            [input_xor] "r"(input_xor)
+          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5",
+            "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+            "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
+            "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
+                                const void* src_ptr2, const void* src_ptr3,
+                                int src_inc0, int src_inc1, int src_inc2,
+                                int src_inc3, int src_rows, int src_zero_point,
+                                std::int8_t* packed_ptr, int start_col,
+                                int end_col, std::int32_t* sums_ptr,
+                                int input_xor) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeonDotprod, optimized for in-order cores)");
+  asm volatile(
+          // clang-format off
+          "dup v26.16b, %w[input_xor]\n"
+          "mov w1, #1\n"
+          "dup v27.16b, w1\n"
+          "mov w1, #0\n"
+          "dup v28.4s, wzr\n"
+          "dup v29.4s, wzr\n"
+          "dup v30.4s, wzr\n"
+          "dup v31.4s, wzr\n"
+
+          "and w2, %w[rows], #-16\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+          "ldr x10, [%[src_ptr0], #8]\n"
+          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ldr x11, [%[src_ptr1], #8]\n"
+          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ldr x12, [%[src_ptr2], #8]\n"
+          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ldr x13, [%[src_ptr3], #8]\n"
+          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
+          "add w1, w1, #16\n"
+          "cmp w1, w2\n"
+
+          "beq 2f\n"
+
+          "1:\n"
+          "add w1, w1, #16\n"
+          "ins v0.d[1], x10\n"
+          "ldr x10, [%[src_ptr0], #8]\n"
+          "ins v1.d[1], x11\n"
+          "ldr x11, [%[src_ptr1], #8]\n"
+          "ins v2.d[1], x12\n"
+          "ldr x12, [%[src_ptr2], #8]\n"
+          "ins v3.d[1], x13\n"
+          "ldr x13, [%[src_ptr3], #8]\n"
+
+          "eor v4.16b, v0.16b, v26.16b\n"
+          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+          "eor v5.16b, v1.16b, v26.16b\n"
+          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+          "eor v6.16b, v2.16b, v26.16b\n"
+          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+          "eor v7.16b, v3.16b, v26.16b\n"
+          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+
+          "trn1 v16.4s, v4.4s, v5.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
+          "trn2 v17.4s, v4.4s, v5.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
+          "trn1 v18.4s, v6.4s, v7.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
+          "trn2 v19.4s, v6.4s, v7.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+          "cmp w1, w2\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          "str q20, [%[packed_ptr], #0]\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+          "str q23, [%[packed_ptr], #96]\n"
+
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "bne 1b\n"
+
+          "2:\n"
+          "ins v0.d[1], x10\n"
+          "ins v1.d[1], x11\n"
+          "ins v2.d[1], x12\n"
+          "ins v3.d[1], x13\n"
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          "str q20, [%[packed_ptr], #0]\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #15\n"
+          "beq 4f\n"
+          "dup v0.16b, %w[src_zero_point]\n"
+          "dup v1.16b, %w[src_zero_point]\n"
+          "dup v2.16b, %w[src_zero_point]\n"
+          "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+          RUY_LOAD_ONE_ROW(4)
+          RUY_LOAD_ONE_ROW(5)
+          RUY_LOAD_ONE_ROW(6)
+          RUY_LOAD_ONE_ROW(7)
+          RUY_LOAD_ONE_ROW(8)
+          RUY_LOAD_ONE_ROW(9)
+          RUY_LOAD_ONE_ROW(10)
+          RUY_LOAD_ONE_ROW(11)
+          RUY_LOAD_ONE_ROW(12)
+          RUY_LOAD_ONE_ROW(13)
+          RUY_LOAD_ONE_ROW(14)
+          RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          "str q20, [%[packed_ptr], #0]\n"
+          "cmp w2, #4\n"
+          "ble 4f\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "cmp w2, #8\n"
+          "ble 4f\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "cmp w2, #12\n"
+          "ble 4f\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "4:\n"
+
+          "add v28.4s, v28.4s, v29.4s\n"
+          "add v30.4s, v30.4s, v31.4s\n"
+          "add v28.4s, v28.4s, v30.4s\n"
+
+          "cmp %[sums_ptr], #0\n"
+          "beq 6f\n"
+          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+          "6:\n"
+          // clang-format on
+
+          : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
+            [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
+          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+            [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+                [rows] "r"(src_rows),
+            [src_zero_point] "r"(static_cast<int>(src_zero_point)),
+            [input_xor] "r"(input_xor)
+          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+            "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+                                   const void* src_ptr2, const void* src_ptr3,
+                                   int src_inc0, int src_inc1, int src_inc2,
+                                   int src_inc3, int src_rows,
+                                   int src_zero_point, std::int8_t* packed_ptr,
+                                   int start_col, int end_col,
+                                   std::int32_t* sums_ptr, int input_xor) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeonDotprod, optimized for out-of-order cores)");
+  asm volatile(
+      // clang-format off
+          "dup v26.16b, %w[input_xor]\n"
+          "mov w1, #1\n"
+          "dup v27.16b, w1\n"
+          "mov w1, #0\n"
+          "dup v28.4s, wzr\n"
+          "dup v29.4s, wzr\n"
+          "dup v30.4s, wzr\n"
+          "dup v31.4s, wzr\n"
+
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+          "and w2, %w[rows], #-64\n"
+          "cmp w1, w2\n"
+          "beq 9f\n"
+
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #64\n"
+          "cmp w1, w2\n"
+          "beq 8f\n"
+
+          "7:\n"
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v4.16b, v4.16b, v26.16b\n"
+          "eor v5.16b, v5.16b, v26.16b\n"
+          "eor v6.16b, v6.16b, v26.16b\n"
+          "eor v7.16b, v7.16b, v26.16b\n"
+
+          "trn1 v16.4s, v4.4s, v5.4s\n"
+          "trn2 v17.4s, v4.4s, v5.4s\n"
+          "trn1 v18.4s, v6.4s, v7.4s\n"
+          "trn2 v19.4s, v6.4s, v7.4s\n"
+
+          "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v8.16b, v8.16b, v26.16b\n"
+          "eor v9.16b, v9.16b, v26.16b\n"
+          "eor v10.16b, v10.16b, v26.16b\n"
+          "eor v11.16b, v11.16b, v26.16b\n"
+
+          "trn1 v16.4s, v8.4s, v9.4s\n"
+          "trn2 v17.4s, v8.4s, v9.4s\n"
+          "trn1 v18.4s, v10.4s, v11.4s\n"
+          "trn2 v19.4s, v10.4s, v11.4s\n"
+
+          "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v12.16b, v12.16b, v26.16b\n"
+          "eor v13.16b, v13.16b, v26.16b\n"
+          "eor v14.16b, v14.16b, v26.16b\n"
+          "eor v15.16b, v15.16b, v26.16b\n"
+
+          "trn1 v16.4s, v12.4s, v13.4s\n"
+          "trn2 v17.4s, v12.4s, v13.4s\n"
+          "trn1 v18.4s, v14.4s, v15.4s\n"
+          "trn2 v19.4s, v14.4s, v15.4s\n"
+
+          "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "cmp w1, w2\n"
+          "bne 7b\n"
+
+          "8:\n"
+
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v4.16b, v4.16b, v26.16b\n"
+          "eor v5.16b, v5.16b, v26.16b\n"
+          "eor v6.16b, v6.16b, v26.16b\n"
+          "eor v7.16b, v7.16b, v26.16b\n"
+
+          "trn1 v16.4s, v4.4s, v5.4s\n"
+          "trn2 v17.4s, v4.4s, v5.4s\n"
+          "trn1 v18.4s, v6.4s, v7.4s\n"
+          "trn2 v19.4s, v6.4s, v7.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v8.16b, v8.16b, v26.16b\n"
+          "eor v9.16b, v9.16b, v26.16b\n"
+          "eor v10.16b, v10.16b, v26.16b\n"
+          "eor v11.16b, v11.16b, v26.16b\n"
+
+          "trn1 v16.4s, v8.4s, v9.4s\n"
+          "trn2 v17.4s, v8.4s, v9.4s\n"
+          "trn1 v18.4s, v10.4s, v11.4s\n"
+          "trn2 v19.4s, v10.4s, v11.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "eor v12.16b, v12.16b, v26.16b\n"
+          "eor v13.16b, v13.16b, v26.16b\n"
+          "eor v14.16b, v14.16b, v26.16b\n"
+          "eor v15.16b, v15.16b, v26.16b\n"
+
+          "trn1 v16.4s, v12.4s, v13.4s\n"
+          "trn2 v17.4s, v12.4s, v13.4s\n"
+          "trn1 v18.4s, v14.4s, v15.4s\n"
+          "trn2 v19.4s, v14.4s, v15.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "9:\n"
+#endif  // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+          "and w2, %w[rows], #-16\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+          "cmp w1, w2\n"
+          "beq 2f\n"
+
+          "1:\n"
+
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #16\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "cmp w1, w2\n"
+          "bne 1b\n"
+
+          "2:\n"
+
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #15\n"
+          "beq 4f\n"
+          "dup v0.16b, %w[src_zero_point]\n"
+          "dup v1.16b, %w[src_zero_point]\n"
+          "dup v2.16b, %w[src_zero_point]\n"
+          "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+          RUY_LOAD_ONE_ROW(4)
+          RUY_LOAD_ONE_ROW(5)
+          RUY_LOAD_ONE_ROW(6)
+          RUY_LOAD_ONE_ROW(7)
+          RUY_LOAD_ONE_ROW(8)
+          RUY_LOAD_ONE_ROW(9)
+          RUY_LOAD_ONE_ROW(10)
+          RUY_LOAD_ONE_ROW(11)
+          RUY_LOAD_ONE_ROW(12)
+          RUY_LOAD_ONE_ROW(13)
+          RUY_LOAD_ONE_ROW(14)
+          RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "eor v0.16b, v0.16b, v26.16b\n"
+          "eor v1.16b, v1.16b, v26.16b\n"
+          "eor v2.16b, v2.16b, v26.16b\n"
+          "eor v3.16b, v3.16b, v26.16b\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
+          "str q20, [%[packed_ptr], #0]\n"
+          "cmp w2, #4\n"
+          "ble 4f\n"
+          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "cmp w2, #8\n"
+          "ble 4f\n"
+          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "cmp w2, #12\n"
+          "ble 4f\n"
+          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "4:\n"
+
+          "add v28.4s, v28.4s, v29.4s\n"
+          "add v30.4s, v30.4s, v31.4s\n"
+          "add v28.4s, v28.4s, v30.4s\n"
+
+          "cmp %[sums_ptr], #0\n"
+          "beq 6f\n"
+          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+          "6:\n"
+      // clang-format on
+
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+        [ rows ] "r"(src_rows),
+        [ src_zero_point ] "r"(static_cast<int>(src_zero_point)),
+        [ input_xor ] "r"(input_xor)
+      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+        "v27", "v28", "v29", "v30", "v31");
+}
+
+#endif  // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+                             const float* src_ptr2, const float* src_ptr3,
+                             int src_inc0, int src_inc1, int src_inc2,
+                             int src_inc3, int src_rows, int src_zero_point,
+                             float* packed_ptr, int start_col, int end_col) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeon, optimized for out-of-order cores)");
+  asm volatile(
+      // clang-format off
+          "mov w1, #0\n"
+
+          "and w2, %w[rows], #-4\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+          "add w1, w1, #4\n"
+          "cmp w1, w2\n"
+
+          "beq 2f\n"
+
+          "1:\n"
+          "add w1, w1, #4\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+          "cmp w1, w2\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "bne 1b\n"
+
+          "2:\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #3\n"
+          "beq 4f\n"
+          "dup v0.16b, wzr\n"
+          "dup v1.16b, wzr\n"
+          "dup v2.16b, wzr\n"
+          "dup v3.16b, wzr\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          "mov x1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
+          "cmp w2, #" #ROW "\n"                           \
+          "beq 4f\n"                                      \
+          "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+          RUY_STORE_ONE_ROW(0, v20)
+          RUY_STORE_ONE_ROW(1, v21)
+          RUY_STORE_ONE_ROW(2, v22)
+          RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+          "4:\n"
+
+      // clang-format on
+
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr)
+      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+        [ rows ] "r"(src_rows)
+      : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
+        "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+        "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
+        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif
+
+#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+                             const float* src_ptr2, const float* src_ptr3,
+                             int src_inc, int src_rows, int src_zero_point,
+                             float* packed_ptr, int start_col, int end_col,
+                             int output_stride) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeon, optimized for out-of-order cores)");
+  asm volatile(
+      // clang-format off
+          "mov r1, #0\n"
+          "and r2, %[rows], #-4\n"
+          "cmp r1, r2\n"
+          "beq 3f\n"
+#define RUY_LOAD_FOUR_BY_FOUR()               \
+  /* Load q0 */                               \
+  "vld1.32 {d0, d1}, [%[src_ptr0]]\n"         \
+  /* if src_inc0 != 0, add 16 to src_ptr0 */  \
+  "and r3, %[src_inc], #1\n"                  \
+  "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
+  /* Load q1 */                               \
+  "vld1.32 {d2, d3}, [%[src_ptr1]]\n"         \
+  /* if src_inc1 != 0, add 16 to src_ptr0 */  \
+  "and r3, %[src_inc], #2\n"                  \
+  "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
+  /* Load q2 */                               \
+  "vld1.32 {d4, d5}, [%[src_ptr2]]\n"         \
+  /* if src_inc2 != 0, add 16 to src_ptr0 */  \
+  "and r3, %[src_inc], #4\n"                  \
+  "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
+  /* Load q3 */                               \
+  "vld1.32 {d6, d7}, [%[src_ptr3]]\n"         \
+  /* if src_inc3 != 0, add 16 to src_ptr0 */  \
+  "and r3, %[src_inc], #8\n"                  \
+  "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\
+
+          RUY_LOAD_FOUR_BY_FOUR()
+          "add r1, r1, #4\n"
+          "cmp r1, r2\n"
+
+          "beq 2f\n"
+
+          "1:\n"
+          "add r1, r1, #4\n"
+
+          // Transpose 4x4 matrix.
+          "vzip.32 q0, q1\n"
+          "vzip.32 q2, q3\n"
+
+          "vtrn.32 q0, q2\n"
+          "vtrn.32 q1, q3\n"
+
+          "vzip.32 q0, q2\n"
+          "vzip.32 q1, q3\n"
+
+          "vmov q8, q0\n"
+          "vmov q9, q1\n"
+          "vmov q10, q2\n"
+          "vmov q11, q3\n"
+
+          RUY_LOAD_FOUR_BY_FOUR()
+#undef RUY_LOAD_FOUR_BY_FOUR
+
+#define RUY_STORE_FOUR_BY_FOUR()                  \
+  /* Store q8, q10, q9, q11 */                    \
+  /* q8 = d16, d17 */                             \
+  "vst1.32 {d16, d17}, [%[packed_ptr]]\n"         \
+  /* q10 = d20, d21 */                            \
+  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+  "vst1.32 {d20, d21}, [%[packed_ptr]]\n"         \
+  /* q9 = d18, d19 */                             \
+  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+  "vst1.32 {d18, d19}, [%[packed_ptr]]\n"         \
+  /* q11 = d22, d23 */                            \
+  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+  "vst1.32 {d22, d23}, [%[packed_ptr]]\n"         \
+  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+
+          RUY_STORE_FOUR_BY_FOUR()
+          "cmp r1, r2\n"
+
+          "bne 1b\n"
+
+          "2:\n"
+
+          // Transpose 4x4 matrix.
+          "vzip.32 q0, q1\n"
+          "vzip.32 q2, q3\n"
+
+          "vtrn.32 q0, q2\n"
+          "vtrn.32 q1, q3\n"
+
+          "vzip.32 q0, q2\n"
+          "vzip.32 q1, q3\n"
+
+          "vmov q8, q0\n"
+          "vmov q9, q1\n"
+          "vmov q10, q2\n"
+          "vmov q11, q3\n"
+
+          RUY_STORE_FOUR_BY_FOUR()
+#undef RUY_STORE_FOUR_BY_FOUR
+          "3:\n"
+
+          "ands r2, %[rows], #3\n"
+          "beq 4f\n"
+          "mov r0, #0\n"
+          // Zero out q0 - q3
+          "vdup.32 q0, r0\n"
+          "vdup.32 q1, r0\n"
+          "vdup.32 q2, r0\n"
+          "vdup.32 q3, r0\n"
+#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I)    \
+  "cmp r2, #" #R "\n"                        \
+  "beq 5f\n"                                 \
+  "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
+  "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
+  "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
+  "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"
+
+#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I)      \
+  "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
+  "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
+  "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
+  "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"
+
+          RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
+          RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
+          RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
+          RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
+#undef RUY_LOAD_ONE_ROW_SECOND_HALF
+#undef RUY_LOAD_ONE_ROW_FIRST_HALF
+          "5:\n"
+
+          // Transpose 4x4 matrix.
+          "vzip.32 q0, q1\n"
+          "vzip.32 q2, q3\n"
+
+          "vtrn.32 q0, q2\n"
+          "vtrn.32 q1, q3\n"
+
+          "vzip.32 q0, q2\n"
+          "vzip.32 q1, q3\n"
+
+          "vmov q8, q0\n"
+          "vmov q9, q1\n"
+          "vmov q10, q2\n"
+          "vmov q11, q3\n"
+
+          "mov r1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER)      \
+          "cmp r2, #" #ROW "\n"                           \
+          "beq 4f\n"                                      \
+          "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n"    \
+          "add %[packed_ptr], %[packed_ptr], %[stride]\n"
+
+          // Store q8
+          RUY_STORE_ONE_ROW(0, q8)
+          // Store q10
+          RUY_STORE_ONE_ROW(1, q10)
+          // Store q9
+          RUY_STORE_ONE_ROW(2, q9)
+          // Store q11
+          RUY_STORE_ONE_ROW(3, q11)
+
+#undef RUY_STORE_ONE_ROW
+
+          "4:\n"
+
+      // clang-format on
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr)
+      : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)),
+        [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
+      : "cc", "memory", "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3",
+        "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13",
+        "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23");
+}
+
+#endif  // (RUY_PLATFORM(NEON_32)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
+                          const float* src_ptr2, const float* src_ptr3,
+                          int src_inc0, int src_inc1, int src_inc2,
+                          int src_inc3, int src_rows, int src_zero_point,
+                          float* packed_ptr, int start_col, int end_col) {
+  gemmlowp::ScopedProfilingLabel label(
+      "Pack (kNeon, optimized for in-order cores)");
+
+  asm volatile(
+          // clang-format off
+          "mov w1, #0\n"
+
+          "and w2, %w[rows], #-4\n"
+          "cmp w1, w2\n"
+          "beq 3f\n"
+          "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+          "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+          "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+          "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #64]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #128]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #192]\n")
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #192]\n")
+          "add w1, w1, #4\n"
+          "cmp w1, w2\n"
+
+          "beq 2f\n"
+
+          "1:\n"
+          "add w1, w1, #4\n"
+
+          "ldr x10, [%[src_ptr0], #8]\n"
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr0], #240]\n")
+          "ldr x11, [%[src_ptr1], #8]\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr1], #240]\n")
+          "ldr x12, [%[src_ptr2], #8]\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr2], #240]\n")
+          "ldr x13, [%[src_ptr3], #8]\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+          RUY_PREFETCH("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+          "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+          "cmp w1, w2\n"
+
+          "ins v0.d[1], x10\n"
+          "str q20, [%[packed_ptr], #0]\n"
+          "ins v1.d[1], x11\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "ins v2.d[1], x12\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "ins v3.d[1], x13\n"
+          "str q23, [%[packed_ptr], #96]\n"
+
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "bne 1b\n"
+
+          "2:\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          "str q20, [%[packed_ptr], #0]\n"
+          "str q21, [%[packed_ptr], #32]\n"
+          "str q22, [%[packed_ptr], #64]\n"
+          "str q23, [%[packed_ptr], #96]\n"
+          "add %[packed_ptr], %[packed_ptr], #128\n"
+
+          "3:\n"
+
+          "ands w2, %w[rows], #3\n"
+          "beq 4f\n"
+          "dup v0.16b, wzr\n"
+          "dup v1.16b, wzr\n"
+          "dup v2.16b, wzr\n"
+          "dup v3.16b, wzr\n"
+#define RUY_LOAD_ONE_ROW(R)                   \
+  "cmp w2, #" #R "\n"                         \
+  "beq 5f\n"                                  \
+  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+          RUY_LOAD_ONE_ROW(0)
+          RUY_LOAD_ONE_ROW(1)
+          RUY_LOAD_ONE_ROW(2)
+          RUY_LOAD_ONE_ROW(3)
+#undef RUY_LOAD_ONE_ROW
+          "5:\n"
+
+          "trn1 v16.4s, v0.4s, v1.4s\n"
+          "trn2 v17.4s, v0.4s, v1.4s\n"
+          "trn1 v18.4s, v2.4s, v3.4s\n"
+          "trn2 v19.4s, v2.4s, v3.4s\n"
+
+          "trn1 v20.2d, v16.2d, v18.2d\n"
+          "trn2 v22.2d, v16.2d, v18.2d\n"
+          "trn1 v21.2d, v17.2d, v19.2d\n"
+          "trn2 v23.2d, v17.2d, v19.2d\n"
+
+          "mov x1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
+          "cmp w2, #" #ROW "\n"                           \
+          "beq 4f\n"                                      \
+          "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+          RUY_STORE_ONE_ROW(0, v20)
+          RUY_STORE_ONE_ROW(1, v21)
+          RUY_STORE_ONE_ROW(2, v22)
+          RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+          "4:\n"
+
+          // clang-format on
+
+          : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
+            [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr)
+          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
+            [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), [rows] "r"(src_rows)
+          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+            "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif  // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/pack_arm.h b/tensorflow/lite/experimental/ruy/pack_arm.h
new file mode 100644
index 0000000..c3696e0
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/pack_arm.h
@@ -0,0 +1,422 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can be come significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/pack_common.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+                            const void* src_ptr2, const void* src_ptr3,
+                            int src_inc0, int src_inc1, int src_inc2,
+                            int src_inc3, int src_rows, int src_zero_point,
+                            std::int8_t* packed_ptr, int start_col, int end_col,
+                            std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
+                         const void* src_ptr2, const void* src_ptr3,
+                         int src_inc0, int src_inc1, int src_inc2, int src_inc3,
+                         int src_rows, int src_zero_point,
+                         std::int8_t* packed_ptr, int start_col, int end_col,
+                         std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+                                   const void* src_ptr2, const void* src_ptr3,
+                                   int src_inc0, int src_inc1, int src_inc2,
+                                   int src_inc3, int src_rows,
+                                   int src_zero_point, std::int8_t* packed_ptr,
+                                   int start_col, int end_col,
+                                   std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
+                                const void* src_ptr2, const void* src_ptr3,
+                                int src_inc0, int src_inc1, int src_inc2,
+                                int src_inc3, int src_rows, int src_zero_point,
+                                std::int8_t* packed_ptr, int start_col,
+                                int end_col, std::int32_t* sums_ptr,
+                                int input_xor);
+
+template <typename Scalar>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
+                std::int8_t, std::int32_t> {
+  static_assert(std::is_same<Scalar, std::int8_t>::value ||
+                    std::is_same<Scalar, std::uint8_t>::value,
+                "");
+  static constexpr int kInputXor =
+      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+  static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+                  PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ(start_col % 4, 0);
+    std::int32_t* sums = packed_matrix->sums;
+    Scalar zerobuf[16];
+    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+    for (int block_col = start_col; block_col < end_col; block_col += 4) {
+      int src_stride = src_matrix.layout.stride;
+      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+      const Scalar* src_ptr1 = src_ptr0 + src_stride;
+      const Scalar* src_ptr2 = src_ptr1 + src_stride;
+      const Scalar* src_ptr3 = src_ptr2 + src_stride;
+      int src_inc0 = 16;
+      int src_inc1 = 16;
+      int src_inc2 = 16;
+      int src_inc3 = 16;
+      if (block_col >= src_matrix.layout.cols - 3) {
+        if (block_col >= src_matrix.layout.cols - 0) {
+          src_ptr0 = zerobuf;
+          src_inc0 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 1) {
+          src_ptr1 = zerobuf;
+          src_inc1 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 2) {
+          src_ptr2 = zerobuf;
+          src_inc2 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 3) {
+          src_ptr3 = zerobuf;
+          src_inc3 = 0;
+        }
+      }
+      std::int8_t* packed_ptr =
+          packed_matrix->data + packed_matrix->layout.stride * block_col;
+      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+        Pack8bitNeonInOrder(
+            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+      } else {
+        Pack8bitNeonOutOfOrder(
+            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+      }
+    }
+  }
+};
+
+template <typename Scalar>
+struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
+                Scalar, std::int8_t, std::int32_t> {
+  static_assert(std::is_same<Scalar, std::int8_t>::value ||
+                    std::is_same<Scalar, std::uint8_t>::value,
+                "");
+  static constexpr int kInputXor =
+      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+  static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+                  PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ(start_col % 8, 0);
+    std::int32_t* sums = packed_matrix->sums;
+    Scalar zerobuf[16];
+    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+    for (int block_col = start_col; block_col < end_col; block_col += 4) {
+      int src_stride = src_matrix.layout.stride;
+      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+      const Scalar* src_ptr1 = src_ptr0 + src_stride;
+      const Scalar* src_ptr2 = src_ptr1 + src_stride;
+      const Scalar* src_ptr3 = src_ptr2 + src_stride;
+      std::int64_t src_inc0 = 16;
+      std::int64_t src_inc1 = 16;
+      std::int64_t src_inc2 = 16;
+      std::int64_t src_inc3 = 16;
+      if (block_col >= src_matrix.layout.cols - 3) {
+        if (block_col >= src_matrix.layout.cols - 0) {
+          src_ptr0 = zerobuf;
+          src_inc0 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 1) {
+          src_ptr1 = zerobuf;
+          src_inc1 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 2) {
+          src_ptr2 = zerobuf;
+          src_inc2 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 3) {
+          src_ptr3 = zerobuf;
+          src_inc3 = 0;
+        }
+      }
+      std::int8_t* packed_ptr =
+          packed_matrix->data +
+          packed_matrix->layout.stride * (block_col & ~7) +
+          ((block_col & 4) * 4);
+      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+        Pack8bitNeonDotprodInOrder(
+            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+      } else {
+        Pack8bitNeonDotprodOutOfOrder(
+            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+            packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+      }
+    }
+  }
+};
+#endif  // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+                             const float* src_ptr2, const float* src_ptr3,
+                             int src_inc0, int src_inc1, int src_inc2,
+                             int src_inc3, int src_rows, int src_zero_point,
+                             float* packed_ptr, int start_col, int end_col);
+void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
+                          const float* src_ptr2, const float* src_ptr3,
+                          int src_inc0, int src_inc1, int src_inc2,
+                          int src_inc3, int src_rows, int src_zero_point,
+                          float* packed_ptr, int start_col, int end_col);
+
+#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+                             const float* src_ptr2, const float* src_ptr3,
+                             int src_inc, int src_rows, int src_zero_point,
+                             float* packed_ptr, int start_col, int end_col,
+                             int stride);
+#endif  // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \
+    RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+                float, float> {
+  static void Run(Tuning tuning, const Matrix<float>& src_matrix,
+                  PackedMatrix<float>* packed_matrix, int start_col,
+                  int end_col) {
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ(start_col % 8, 0);
+    const float zerobuf[4] = {0};
+    for (int block_col = start_col; block_col < end_col; block_col += 4) {
+      int src_stride = src_matrix.layout.stride;
+      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+      const float* src_ptr1 = src_ptr0 + src_stride;
+      const float* src_ptr2 = src_ptr1 + src_stride;
+      const float* src_ptr3 = src_ptr2 + src_stride;
+      std::int64_t src_inc0 = 16;
+      std::int64_t src_inc1 = 16;
+      std::int64_t src_inc2 = 16;
+      std::int64_t src_inc3 = 16;
+      if (block_col >= src_matrix.layout.cols - 3) {
+        if (block_col >= src_matrix.layout.cols - 0) {
+          src_ptr0 = zerobuf;
+          src_inc0 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 1) {
+          src_ptr1 = zerobuf;
+          src_inc1 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 2) {
+          src_ptr2 = zerobuf;
+          src_inc2 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 3) {
+          src_ptr3 = zerobuf;
+          src_inc3 = 0;
+        }
+      }
+      float* packed_ptr = packed_matrix->data +
+                          packed_matrix->layout.stride * (block_col & ~7) +
+                          ((block_col & 4));
+#if RUY_PLATFORM(NEON_64)
+      if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+        PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0,
+                             src_inc1, src_inc2, src_inc3,
+                             src_matrix.layout.rows, src_matrix.zero_point,
+                             packed_ptr, start_col, end_col);
+      } else {
+        PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
+                                src_inc0, src_inc1, src_inc2, src_inc3,
+                                src_matrix.layout.rows, src_matrix.zero_point,
+                                packed_ptr, start_col, end_col);
+      }
+#else
+      // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
+      // to save on registers (we have fewer general purpose registers in
+      // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
+      // values that are each either 16 or 0 and use them directly. For the
+      // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
+      // use the value 16 (bit is set) or 0 (bit is not set) for the
+      // respective increment value.
+      std::int64_t src_inc = 0;
+      src_inc += src_inc0 == 16 ? 1 : 0;
+      src_inc += src_inc1 == 16 ? 2 : 0;
+      src_inc += src_inc2 == 16 ? 4 : 0;
+      src_inc += src_inc3 == 16 ? 8 : 0;
+      const int kOutputStride = 32;
+      PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+                              src_matrix.layout.rows, src_matrix.zero_point,
+                              packed_ptr, start_col, end_col, kOutputStride);
+#endif  // RUY_PLATFORM(NEON_64)
+    }
+  }
+};
+
+#if RUY_PLATFORM(NEON_32)
+// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
+// specialization for a FixedKernelLayout with 4 columns.
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
+                float, float> {
+  static void Run(Tuning tuning, const Matrix<float>& src_matrix,
+                  PackedMatrix<float>* packed_matrix, int start_col,
+                  int end_col) {
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ(start_col % 4, 0);
+    const float zerobuf[4] = {0};
+    for (int block_col = start_col; block_col < end_col; block_col += 4) {
+      int src_stride = src_matrix.layout.stride;
+      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+      const float* src_ptr1 = src_ptr0 + src_stride;
+      const float* src_ptr2 = src_ptr1 + src_stride;
+      const float* src_ptr3 = src_ptr2 + src_stride;
+      std::int64_t src_inc0 = 16;
+      std::int64_t src_inc1 = 16;
+      std::int64_t src_inc2 = 16;
+      std::int64_t src_inc3 = 16;
+      if (block_col >= src_matrix.layout.cols - 3) {
+        if (block_col >= src_matrix.layout.cols - 0) {
+          src_ptr0 = zerobuf;
+          src_inc0 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 1) {
+          src_ptr1 = zerobuf;
+          src_inc1 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 2) {
+          src_ptr2 = zerobuf;
+          src_inc2 = 0;
+        }
+        if (block_col >= src_matrix.layout.cols - 3) {
+          src_ptr3 = zerobuf;
+          src_inc3 = 0;
+        }
+      }
+      float* packed_ptr =
+          packed_matrix->data + packed_matrix->layout.stride * (block_col);
+      // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc
+      // to save registers.
+      std::int64_t src_inc = 0;
+      src_inc += src_inc0 == 16 ? 1 : 0;
+      src_inc += src_inc1 == 16 ? 2 : 0;
+      src_inc += src_inc2 == 16 ? 4 : 0;
+      src_inc += src_inc3 == 16 ? 8 : 0;
+      const int kOutputStride = 16;
+      PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+                              src_matrix.layout.rows, src_matrix.zero_point,
+                              packed_ptr, start_col, end_col, kOutputStride);
+    }
+  }
+};
+#endif  // (RUY_PLATFORM(NEON_32))
+#endif  // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
+        // RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_
diff --git a/tensorflow/lite/experimental/ruy/pack_avx512.cc b/tensorflow/lite/experimental/ruy/pack_avx512.cc
new file mode 100644
index 0000000..1f2d29e
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/pack_avx512.cc
@@ -0,0 +1,546 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/pack.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx512 =
+    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+             std::int8_t, std::int8_t, std::int32_t>;
+
+namespace {
+
+inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
+                               std::int8_t* packed_ptr) {
+  using Layout = PackImpl8bitAvx512::Layout;
+  static constexpr int kHalfLayoutCols =
+      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                            // block.
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  const int non_trailing_blocks = (src_rows & ~31) >> 2;
+  // This routine fills half blocks, and typically fills the second halves.
+  // Thus packed_ptr is already offset by 8 * 4.
+  for (int k = 0; k < non_trailing_blocks; ++k) {
+    for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
+      packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
+    }
+  }
+}
+
+inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
+                               std::int8_t input_xor,
+                               const std::int8_t* zerobuf, int src_stride,
+                               int remaining_src_cols, int src_rows,
+                               std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+                               std::int8_t* trailing_buf) {
+  using Layout = PackImpl8bitAvx512::Layout;
+  static constexpr int kHalfLayoutCols =
+      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                            // block.
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+
+  std::int8_t in_data[kHalfLayoutCols][kHalfLayoutCols][Layout::kCols];
+
+  const std::int8_t* src_ptr0 = src_ptr;
+  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
+  // We process 8 of these chunks at a time, padding short input chunks.
+  constexpr int kNumRowChunks = 8;
+  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+  std::int64_t src_inc0 = kNumChunkedSrcRows;
+  std::int64_t src_inc1 = kNumChunkedSrcRows;
+  std::int64_t src_inc2 = kNumChunkedSrcRows;
+  std::int64_t src_inc3 = kNumChunkedSrcRows;
+  std::int64_t src_inc4 = kNumChunkedSrcRows;
+  std::int64_t src_inc5 = kNumChunkedSrcRows;
+  std::int64_t src_inc6 = kNumChunkedSrcRows;
+  std::int64_t src_inc7 = kNumChunkedSrcRows;
+  // Handle cases where source does not have kHalfLayoutCols (8) columns.
+  if (remaining_src_cols < 8) {
+    if (remaining_src_cols <= 0) {
+      src_ptr0 = zerobuf;
+      src_inc0 = 0;
+    }
+    if (remaining_src_cols <= 1) {
+      src_ptr1 = zerobuf;
+      src_inc1 = 0;
+    }
+    if (remaining_src_cols <= 2) {
+      src_ptr2 = zerobuf;
+      src_inc2 = 0;
+    }
+    if (remaining_src_cols <= 3) {
+      src_ptr3 = zerobuf;
+      src_inc3 = 0;
+    }
+    if (remaining_src_cols <= 4) {
+      src_ptr4 = zerobuf;
+      src_inc4 = 0;
+    }
+    if (remaining_src_cols <= 5) {
+      src_ptr5 = zerobuf;
+      src_inc5 = 0;
+    }
+    if (remaining_src_cols <= 6) {
+      src_ptr6 = zerobuf;
+      src_inc6 = 0;
+    }
+    src_ptr7 = zerobuf;
+    src_inc7 = 0;
+  }
+
+  const std::int8_t zero_point = zerobuf[0];
+
+  if (sums_ptr) {
+    // i: kHalfLayoutCols.
+    for (int i = 0; i < 8; ++i) {
+      sums_ptr[i] = 0;
+    }
+  }
+
+  // The overall packing effectively pads the source rows to
+  // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
+  // only pack for (src_rows + 31) & ~31. When there is an incomplete
+  // destination block, this is stored into trailing_buf instead of packed_ptr.
+  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
+    // m: {0, 1} for 2 chunks of rows.
+    for (int m = 0; m < 2; ++m) {
+      // Available source rows.
+      // If this is less than 0 (for m=1), we skip, having filled trailing
+      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+      // exactly to the end of the column in the packed buffer.
+      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
+      // Effectively,
+      // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m));
+      // treat each case separately.
+      if (available_src_rows >= kNumChunkedSrcRows) {
+        // i: chunks, s: Layout::Rows.
+        for (int i = 0; i < 8; ++i) {
+          for (int s = 0; s < 4; ++s) {
+            in_data[0][i][s] = src_ptr0[i * 4 + s];
+            in_data[1][i][s] = src_ptr1[i * 4 + s];
+            in_data[2][i][s] = src_ptr2[i * 4 + s];
+            in_data[3][i][s] = src_ptr3[i * 4 + s];
+            in_data[4][i][s] = src_ptr4[i * 4 + s];
+            in_data[5][i][s] = src_ptr5[i * 4 + s];
+            in_data[6][i][s] = src_ptr6[i * 4 + s];
+            in_data[7][i][s] = src_ptr7[i * 4 + s];
+          }
+        }
+        // i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
+        for (int i = 0; i < 8; ++i) {
+          for (int j = 0; j < 8; ++j) {
+            for (int s = 0; s < 4; ++s) {
+              // 16 * 4 * i is offset for each block, that is
+              // (Layout::kCols * Layout::kRows * i)
+              packed_ptr[(16 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
+            }
+            if (sums_ptr) {
+              for (int s = 0; s < 4; ++s) {
+                sums_ptr[j] += in_data[j][i][s] ^ input_xor;
+              }
+            }
+          }
+        }
+      } else if (available_src_rows > 0) {
+        RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
+        int i = 0;
+        // Consume chunks of 4 rows that are complete.
+        for (; i < (available_src_rows >> 2); ++i) {
+          for (int s = 0; s < 4; ++s) {
+            in_data[0][i][s] = src_ptr0[i * 4 + s];
+            in_data[1][i][s] = src_ptr1[i * 4 + s];
+            in_data[2][i][s] = src_ptr2[i * 4 + s];
+            in_data[3][i][s] = src_ptr3[i * 4 + s];
+            in_data[4][i][s] = src_ptr4[i * 4 + s];
+            in_data[5][i][s] = src_ptr5[i * 4 + s];
+            in_data[6][i][s] = src_ptr6[i * 4 + s];
+            in_data[7][i][s] = src_ptr7[i * 4 + s];
+          }
+        }
+        // Consume any incomplete chunk.
+        if (i < ((available_src_rows + 3) >> 2)) {
+          int s = 0;
+          for (; s < (available_src_rows & 3); ++s) {
+            in_data[0][i][s] = src_ptr0[i * 4 + s];
+            in_data[1][i][s] = src_ptr1[i * 4 + s];
+            in_data[2][i][s] = src_ptr2[i * 4 + s];
+            in_data[3][i][s] = src_ptr3[i * 4 + s];
+            in_data[4][i][s] = src_ptr4[i * 4 + s];
+            in_data[5][i][s] = src_ptr5[i * 4 + s];
+            in_data[6][i][s] = src_ptr6[i * 4 + s];
+            in_data[7][i][s] = src_ptr7[i * 4 + s];
+          }
+          RUY_DCHECK_LE(s, 4);
+          for (; s < 4; ++s) {
+            // j: kHalfLayoutCols.
+            for (int j = 0; j < 8; ++j) {
+              in_data[j][i][s] = zero_point;
+            }
+          }
+          ++i;
+        }
+        // We do not care what goes into the trailing buffer, but we want
+        // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+        //
+        // It might prove better in optimized code to pad uniformly with
+        // zero_point, and compensate by initializing the summations with the
+        // compensating offset, effectively
+        // ((input_xor - zero_point) ^ input_xor) *
+        //                         4 * (8 - ((available_src_rows + 3) >> 2)).
+        for (; i < 8; ++i) {
+          for (int s = 0; s < 4; ++s) {
+            for (int j = 0; j < 8; ++j) {
+              in_data[j][i][s] = input_xor;
+            }
+          }
+        }
+        // We loop through [0, 8) rather than
+        // [0, (available_src_rows + 3) >> 2), since that emulates what we might
+        // do in fully-optimized code.
+        //
+        // i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
+        if (sums_ptr) {
+          for (int i = 0; i < 8; ++i) {
+            for (int j = 0; j < 8; ++j) {
+              for (int s = 0; s < 4; ++s) {
+                trailing_buf[(16 * i + j) * 4 + s] =
+                    in_data[j][i][s] ^ input_xor;
+                sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor);
+              }
+            }
+          }
+        } else {
+          for (int i = 0; i < 8; ++i) {
+            for (int j = 0; j < 8; ++j) {
+              for (int s = 0; s < 4; ++s) {
+                trailing_buf[(16 * i + j) * 4 + s] =
+                    in_data[j][i][s] ^ input_xor;
+              }
+            }
+          }
+        }
+      }
+
+      packed_ptr += 16 * kNumChunkedSrcRows;
+      src_ptr0 += src_inc0;
+      src_ptr1 += src_inc1;
+      src_ptr2 += src_inc2;
+      src_ptr3 += src_inc3;
+      src_ptr4 += src_inc4;
+      src_ptr5 += src_inc5;
+      src_ptr6 += src_inc6;
+      src_ptr7 += src_inc7;
+    }
+  }
+}
+
+inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
+  __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
+  return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
+}
+
+inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
+                           const float* addr_hi) {
+  __m512 lower_filled =
+      _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
+  return _mm512_insertf32x8(lower_filled,
+                            _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
+}
+
+inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
+  return _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
+  return _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
+                                int src_stride, int remaining_src_cols,
+                                int src_rows, float* packed_ptr,
+                                float* trailing_buf) {
+  const float* src_ptr0 = src_ptr;
+  const float* src_ptr1 = src_ptr0 + src_stride;
+  const float* src_ptr2 = src_ptr1 + src_stride;
+  const float* src_ptr3 = src_ptr2 + src_stride;
+  const float* src_ptr4 = src_ptr3 + src_stride;
+  const float* src_ptr5 = src_ptr4 + src_stride;
+  const float* src_ptr6 = src_ptr5 + src_stride;
+  const float* src_ptr7 = src_ptr6 + src_stride;
+  std::int64_t src_inc0 = 8;
+  std::int64_t src_inc1 = 8;
+  std::int64_t src_inc2 = 8;
+  std::int64_t src_inc3 = 8;
+  std::int64_t src_inc4 = 8;
+  std::int64_t src_inc5 = 8;
+  std::int64_t src_inc6 = 8;
+  std::int64_t src_inc7 = 8;
+  if (remaining_src_cols < 8) {
+    if (remaining_src_cols <= 0) {
+      src_ptr0 = zerobuf;
+      src_inc0 = 0;
+    }
+    if (remaining_src_cols <= 1) {
+      src_ptr1 = zerobuf;
+      src_inc1 = 0;
+    }
+    if (remaining_src_cols <= 2) {
+      src_ptr2 = zerobuf;
+      src_inc2 = 0;
+    }
+    if (remaining_src_cols <= 3) {
+      src_ptr3 = zerobuf;
+      src_inc3 = 0;
+    }
+    if (remaining_src_cols <= 4) {
+      src_ptr4 = zerobuf;
+      src_inc4 = 0;
+    }
+    if (remaining_src_cols <= 5) {
+      src_ptr5 = zerobuf;
+      src_inc5 = 0;
+    }
+    if (remaining_src_cols <= 6) {
+      src_ptr6 = zerobuf;
+      src_inc6 = 0;
+    }
+    src_ptr7 = zerobuf;
+    src_inc7 = 0;
+  }
+
+  for (int k = 0; k < src_rows; k += 16) {
+    for (int m = 0; m < 2; ++m) {
+      const int available_src_rows = src_rows - k - 8 * m;
+      // Effectively,
+      // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
+      // but treat each case separately.
+      if (available_src_rows > 7) {
+        __m512 t0, t1, t2, t3;
+        __m512 r0, r1, r2, r3;
+
+        t0 = LoaduTwo(src_ptr0, src_ptr4);
+        t1 = LoaduTwo(src_ptr1, src_ptr5);
+        t2 = LoaduTwo(src_ptr2, src_ptr6);
+        t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+        r0 = _mm512_unpacklo_ps(t0, t1);
+        r2 = _mm512_unpackhi_ps(t0, t1);
+        r1 = _mm512_unpacklo_ps(t2, t3);
+        r3 = _mm512_unpackhi_ps(t2, t3);
+
+        t0 = Mm512UnpackloPsx2(r0, r1);
+        t2 = Mm512UnpackhiPsx2(r0, r1);
+        t1 = Mm512UnpackloPsx2(r2, r3);
+        t3 = Mm512UnpackhiPsx2(r2, r3);
+
+        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+        _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
+        _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+        _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
+        _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+        _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
+        _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+        _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
+        _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
+      } else if (available_src_rows > 0) {
+        const __mmask8 row_mask =
+            (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
+
+        __m512 t0, t1, t2, t3;
+        __m512 r0, r1, r2, r3;
+
+        t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
+        t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
+        t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
+        t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
+
+        r0 = _mm512_unpacklo_ps(t0, t1);
+        r2 = _mm512_unpackhi_ps(t0, t1);
+        r1 = _mm512_unpacklo_ps(t2, t3);
+        r3 = _mm512_unpackhi_ps(t2, t3);
+
+        t0 = Mm512UnpackloPsx2(r0, r1);
+        t2 = Mm512UnpackhiPsx2(r0, r1);
+        t1 = Mm512UnpackloPsx2(r2, r3);
+        t3 = Mm512UnpackhiPsx2(r2, r3);
+
+        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+        _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
+        _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+        _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
+        _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+        _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
+        _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+        _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
+        // Do not store _mm512_extractf32x8_ps(r3, 1).
+      }
+
+      packed_ptr += 16 * 8;
+      src_ptr0 += src_inc0;
+      src_ptr1 += src_inc1;
+      src_ptr2 += src_inc2;
+      src_ptr3 += src_inc3;
+      src_ptr4 += src_inc4;
+      src_ptr5 += src_inc5;
+      src_ptr6 += src_inc6;
+      src_ptr7 += src_inc7;
+    }
+  }
+}
+
+inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
+  const int non_trailing_rows = src_rows & ~7;
+  for (int k = 0; k < non_trailing_rows; ++k) {
+    for (int j = 0; j < 8; ++j) {
+      packed_ptr[j] = 0.0f;
+    }
+    packed_ptr += 16;
+  }
+}
+
+}  // namespace.
+
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+                    const std::int8_t* zerobuf, int src_stride,
+                    int remaining_src_cols, int src_rows,
+                    std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+  gemmlowp::ScopedProfilingLabel label("Pack kAvx512 8bit");
+
+  using Layout = PackImpl8bitAvx512::Layout;
+  constexpr int kHalfBlockOffset = 32;
+  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kRows);
+  static constexpr int kHalfLayoutCols =
+      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                            // block.
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
+  // We process 8 of these chunks at a time, padding short input chunks.
+  constexpr int kNumRowChunks = 8;
+
+  // Each packed block is 4*16, and there are normally 8. The trailing block is
+  // only slightly shorter.
+  constexpr int kTrailingBufSize =
+      kNumRowChunks * Layout::kCols * Layout::kRows;
+  std::int8_t trailing_buf[kTrailingBufSize];
+  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+  std::int32_t* second_sums_ptr =
+      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
+  if (remaining_src_cols > kHalfLayoutCols) {
+    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+                       trailing_buf);
+    HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
+                       zerobuf, src_stride,
+                       remaining_src_cols - kHalfLayoutCols, src_rows,
+                       packed_ptr + kHalfBlockOffset, second_sums_ptr,
+                       trailing_buf + kHalfBlockOffset);
+  } else {
+    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+                       trailing_buf);
+    ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
+                       packed_ptr + kHalfBlockOffset);
+    // The kernel may not need the second half-blocks sums to be set.
+    if (second_sums_ptr) {
+      for (int i = 0; i < kHalfLayoutCols; ++i) {
+        second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
+      }
+    }
+  }
+  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kChunkedRowMask, there
+  // will be data in the trailing buffer,
+  if (trailing_data > 0) {
+    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+    // Destination "rows" are padded to next highest multiple of Layout::kRows.
+    const int dst_rows = (src_rows + 3) & ~3;
+    const int trailing_rows = dst_rows - non_trailing_rows;
+    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+           Layout::kCols * trailing_rows * sizeof(std::int8_t));
+  }
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+                     int remaining_src_cols, int src_rows, float* packed_ptr) {
+  gemmlowp::ScopedProfilingLabel label("Pack kAvx512 float");
+  float trailing_buf[7 * 16];
+  if (remaining_src_cols > 8) {
+    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, trailing_buf);
+    HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
+                        remaining_src_cols - 8, src_rows, packed_ptr + 8,
+                        trailing_buf + 8);
+  } else {
+    memset(trailing_buf, 0, sizeof(trailing_buf));
+    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, trailing_buf);
+    ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
+  }
+  const int trailing_rows = src_rows & 7;
+  if (trailing_rows > 0) {
+    const int non_trailing_rows = src_rows & ~7;
+    memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+           16 * trailing_rows * sizeof(float));
+  }
+}
+
+#endif  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+
+}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/pack_common.h b/tensorflow/lite/experimental/ruy/pack_common.h
new file mode 100644
index 0000000..ecad726
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/pack_common.h
@@ -0,0 +1,193 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can be come significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_
+
+#include <cstdint>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+template <Path ThePath, typename Scalar>
+struct PackedTypeImpl {
+  using Type = Scalar;
+};
+
+#if RUY_PLATFORM(NEON)
+template <>
+struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
+  using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
+  using Type = std::int8_t;
+};
+#elif RUY_PLATFORM(AVX512)
+template <>
+struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
+  using Type = std::int8_t;
+};
+#endif
+
+template <Path ThePath, typename Scalar>
+using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
+
+template <typename PackedScalar, typename Scalar>
+PackedScalar Pack(Scalar x) {
+  return x - SymmetricZeroPoint<Scalar>() + SymmetricZeroPoint<PackedScalar>();
+}
+
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+          typename PackedScalar, typename SumsType>
+struct PackImpl {};
+
+#define RUY_INHERIT_PACK(PARENT, CHILD)                                       \
+  template <typename FixedKernelLayout, typename Scalar,                      \
+            typename PackedScalar, typename SumsType>                         \
+  struct PackImpl<CHILD, FixedKernelLayout, Scalar, PackedScalar, SumsType>   \
+      : PackImpl<PARENT, FixedKernelLayout, Scalar, PackedScalar, SumsType> { \
+  };
+
+template <typename FixedKernelLayout, typename Scalar, typename PackedScalar,
+          typename SumsType>
+struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
+                SumsType> {
+  static void Run(Tuning, const Matrix<Scalar>& src_matrix,
+                  PackedMatrix<PackedScalar>* packed_matrix, int start_col,
+                  int end_col) {
+    gemmlowp::ScopedProfilingLabel label("Pack (generic)");
+    RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0);
+    SumsType* sums = packed_matrix->sums;
+    for (int col = start_col; col < end_col; col++) {
+      SumsType accum = 0;
+      for (int row = 0; row < packed_matrix->layout.rows; row++) {
+        PackedScalar packed_val;
+        if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) {
+          packed_val = Pack<PackedScalar>(Element(src_matrix, row, col));
+        } else {
+          packed_val = packed_matrix->zero_point;
+        }
+        accum += packed_val;
+        *ElementPtr(packed_matrix, row, col) = packed_val;
+      }
+      if (sums) {
+        sums[col] = accum;
+      }
+    }
+  }
+};
+
+#if RUY_PLATFORM(NEON)
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
+#endif
+#elif RUY_PLATFORM(AVX512)
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx512)
+#endif
+
+// Main entry point for packing.
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+          typename PackedScalar>
+void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix,
+             int start_col, int end_col) {
+  using SumsType = typename PackedMatrix<PackedScalar>::SumsType;
+  Matrix<Scalar> src = ToMatrix<Scalar>(src_matrix);
+  PackedMatrix<PackedScalar> packed =
+      ToPackedMatrix<PackedScalar>(*packed_matrix);
+  PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType>::Run(
+      tuning, src, &packed, start_col, end_col);
+}
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_
diff --git a/tensorflow/lite/experimental/ruy/pack_x86.h b/tensorflow/lite/experimental/ruy/pack_x86.h
new file mode 100644
index 0000000..a4d12bb
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/pack_x86.h
@@ -0,0 +1,190 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can be come significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/pack_common.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/platform.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+                    const std::int8_t* zerobuf, int src_stride,
+                    int remaining_src_cols, int src_rows,
+                    std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+                Scalar, std::int8_t, std::int32_t> {
+  static_assert(std::is_same<Scalar, std::int8_t>::value ||
+                    std::is_same<Scalar, std::uint8_t>::value,
+                "");
+  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  static constexpr int kHalfLayoutCols =
+      8;  // Half the number of cols in a block.
+  static constexpr std::int8_t kInputXor =
+      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+  static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+                  PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+                  int end_col) {
+    gemmlowp::ScopedProfilingLabel label("Pack (AVX-512)");
+
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+    std::int32_t* sums = packed_matrix->sums;
+    Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
+    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+           kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
+    for (int block_col = start_col; block_col < end_col;
+         block_col += Layout::kCols) {
+      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+      int src_stride = src_matrix.layout.stride;
+      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+      int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
+      std::int8_t* packed_ptr =
+          packed_matrix->data +
+          packed_matrix->layout.stride * (block_col & block_col_mask);
+      Pack8bitAvx512(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+                     reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+                     remaining_src_cols, src_matrix.layout.rows, packed_ptr,
+                     sums_ptr);
+    }
+  }
+};
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+                     int remaining_src_cols, int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
+                float, float, float> {
+  static void Run(Tuning, const Matrix<float>& src_matrix,
+                  PackedMatrix<float>* packed_matrix, int start_col,
+                  int end_col) {
+    using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+    RUY_DCHECK(IsColMajor(src_matrix.layout));
+    RUY_DCHECK(IsColMajor(packed_matrix->layout));
+    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+    const float zerobuf[Layout::kCols] = {
+        0.0f};  // Remainder default inits to 0.0f.
+    for (int block_col = start_col; block_col < end_col;
+         block_col += Layout::kCols) {
+      int src_stride = src_matrix.layout.stride;
+      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+      int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
+      float* packed_ptr =
+          packed_matrix->data +
+          packed_matrix->layout.stride * (block_col & block_col_mask);
+      PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                      src_matrix.layout.rows, packed_ptr);
+    }
+  }
+};
+#endif  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_
diff --git a/tensorflow/lite/experimental/ruy/path.h b/tensorflow/lite/experimental/ruy/path.h
index b82e302..142abc7 100644
--- a/tensorflow/lite/experimental/ruy/path.h
+++ b/tensorflow/lite/experimental/ruy/path.h
@@ -51,6 +51,11 @@
 // given base architecture (such as ARM). Higher values of this enum correspond
 // to "better" code paths within a given base architecture for which Ruy has
 // optimized code paths.
+//
+// Values are reused across architectures.
+// Rationale: Scale better to N architectures, it is good to have small values
+// both for the compile-time logic to select paths, and when manually spelling
+// out Path values, such as when invoking a test or benchmark.
 enum class Path : std::uint8_t {
   // This is a special null value, representing the absence of any path.
   kNone = 0,
@@ -66,11 +71,19 @@
   //
   // This is intended for testing/development.
   kStandardCpp = 0x2,
+  //
+  // ARM architectures.
+  //
   // Optimized path using a widely available subset of ARM NEON instructions.
   kNeon = 0x4,
   // Optimized path making use of ARM NEON dot product instructions that are
   // available on newer ARM cores.
   kNeonDotprod = 0x8,
+  //
+  // x86 architectures.
+  //
+  // Optimized for AVX-512.
+  kAvx512 = 0x4,
 };
 
 inline constexpr Path operator|(Path p, Path q) {
@@ -104,6 +117,11 @@
     Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod;
 #elif RUY_PLATFORM(NEON_32)
 constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
+#elif RUY_PLATFORM(AVX512)
+// TODO(b/138433137): kAllPaths should always contain kAvx512 regardless of
+// whether AVX-512 is enabled in the translation unit #including this header.
+constexpr Path kAllPaths =
+    Path::kReference | Path::kStandardCpp | Path::kAvx512;
 #else
 constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
 #endif
@@ -111,6 +129,9 @@
 // We don't know how to do runtime dotprod detection outside of linux for now.
 #if RUY_PLATFORM(NEON)
 constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
+#elif RUY_PLATFORM(AVX512)
+constexpr Path kAllPaths =
+    Path::kReference | Path::kStandardCpp | Path::kAvx512;
 #else
 constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
 #endif
diff --git a/tensorflow/lite/experimental/ruy/platform.h b/tensorflow/lite/experimental/ruy/platform.h
index 29c0fc2..c7ef11a 100644
--- a/tensorflow/lite/experimental/ruy/platform.h
+++ b/tensorflow/lite/experimental/ruy/platform.h
@@ -18,38 +18,74 @@
 
 #define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0)
 
-// Detect ARM 32-bit
+// Architecture-level platform detection.
+//
+// Ruy requires these to be mutually exclusive.
+
+// Detect x86.
+#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \
+    defined(__x86__) || defined(__X86__) || defined(_X86_) ||      \
+    defined(_M_IX86) || defined(_M_X64)
+#define RUY_DONOTUSEDIRECTLY_X86 1
+#else
+#define RUY_DONOTUSEDIRECTLY_X86 0
+#endif
+
+// Detect ARM 32-bit.
 #ifdef __arm__
 #define RUY_DONOTUSEDIRECTLY_ARM_32 1
 #else
 #define RUY_DONOTUSEDIRECTLY_ARM_32 0
 #endif
 
-// Detect ARM 64-bit
+// Detect ARM 64-bit.
 #ifdef __aarch64__
 #define RUY_DONOTUSEDIRECTLY_ARM_64 1
 #else
 #define RUY_DONOTUSEDIRECTLY_ARM_64 0
 #endif
 
-// Detect NEON
-#if (defined __ARM_NEON) || (defined __ARM_NEON__)
+// Combined ARM.
+#define RUY_DONOTUSEDIRECTLY_ARM \
+  (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32)
+
+// Feature and capability platform detection.
+//
+// These are mostly sub-selections of architectures.
+
+// Detect NEON. Explictly avoid emulation, or anything like it, on x86.
+#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86)
 #define RUY_DONOTUSEDIRECTLY_NEON 1
 #else
 #define RUY_DONOTUSEDIRECTLY_NEON 0
 #endif
 
-// Define ARM 32-bit NEON
+// Define ARM 32-bit NEON.
 #define RUY_DONOTUSEDIRECTLY_NEON_32 \
   (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32)
 
-// Define ARM 64-bit NEON
+// Define ARM 64-bit NEON.
 // Note: NEON is implied by ARM64, so this define is redundant.
 // It still allows some conveyance of intent.
 #define RUY_DONOTUSEDIRECTLY_NEON_64 \
   (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64)
 
-// Detect APPLE
+// These CPU capabilities will all be true when Skylake is enabled during
+// compilation.
+//
+// TODO(b/138433137) Select AVX-512 at runtime rather than via compile options.
+//
+// Disabled on __APPLE__ because b/138922878, see comment #8, we may only need
+// to disable this on XCode <= 10.2.
+#if RUY_PLATFORM(X86) && defined(__AVX512F__) && defined(__AVX512DQ__) &&      \
+    defined(__AVX512CD__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    !defined(__APPLE__)
+#define RUY_DONOTUSEDIRECTLY_AVX512 1
+#else
+#define RUY_DONOTUSEDIRECTLY_AVX512 0
+#endif
+
+// Detect APPLE.
 #ifdef __APPLE__
 #define RUY_DONOTUSEDIRECTLY_APPLE 1
 #else
diff --git a/tensorflow/lite/experimental/ruy/pmu.cc b/tensorflow/lite/experimental/ruy/pmu.cc
index 40f5f50..3ec62bb 100644
--- a/tensorflow/lite/experimental/ruy/pmu.cc
+++ b/tensorflow/lite/experimental/ruy/pmu.cc
@@ -21,7 +21,9 @@
 #include <asm/unistd.h>
 #include <linux/perf_event.h>
 #include <sys/ioctl.h>
+#include <syscall.h>
 #include <unistd.h>
+
 #include <cstdio>
 #endif
 
@@ -47,7 +49,8 @@
     pe.exclude_hv = 1;
     fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0);
     if (fd_ == -1) {
-      fprintf(stderr, "perf_event_open failed for config 0x%lx\n", config);
+      fprintf(stderr, "perf_event_open failed for config 0x%lx\n",
+              static_cast<unsigned long>(config));
       // abort();
     }
     ioctl(fd_, PERF_EVENT_IOC_RESET, 0);
diff --git a/tensorflow/lite/experimental/ruy/pmu.h b/tensorflow/lite/experimental/ruy/pmu.h
index b77882c..03f0cb7 100644
--- a/tensorflow/lite/experimental/ruy/pmu.h
+++ b/tensorflow/lite/experimental/ruy/pmu.h
@@ -16,8 +16,6 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_
 
-#include <cstdint>
-
 namespace ruy {
 
 class PmuEventsPrivate;
diff --git a/tensorflow/lite/experimental/ruy/prepack.h b/tensorflow/lite/experimental/ruy/prepack.h
index 9019efa..5966a5e 100644
--- a/tensorflow/lite/experimental/ruy/prepack.h
+++ b/tensorflow/lite/experimental/ruy/prepack.h
@@ -18,13 +18,20 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_
 
+#include <cstddef>
 #include <functional>
 
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/context.h"
 #include "tensorflow/lite/experimental/ruy/dispatch.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
 #include "tensorflow/lite/experimental/ruy/matrix.h"
 #include "tensorflow/lite/experimental/ruy/path.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
 #include "tensorflow/lite/experimental/ruy/spec.h"
+#include "tensorflow/lite/experimental/ruy/trmul.h"
+#include "tensorflow/lite/experimental/ruy/trmul_params.h"
 #include "tensorflow/lite/experimental/ruy/tune.h"
 
 namespace ruy {
@@ -34,8 +41,7 @@
 void PrePackForMulInternal(const Matrix<LhsScalar>& lhs,
                            const Matrix<RhsScalar>& rhs, const Spec& spec,
                            Context* context, Matrix<DstScalar>* dst,
-                           PrepackedMatrix* prepacked_lhs,
-                           PrepackedMatrix* prepacked_rhs,
+                           SidePair<PrepackedMatrix*> prepacked,
                            std::function<void*(std::size_t)> alloc_fn) {
   gemmlowp::ScopedProfilingLabel label("PrePackForMul");
   Path the_path = context->GetPathToTake<CompiledPaths>();
@@ -47,24 +53,21 @@
   CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
                                         the_path, &params);
 
+  const SidePair<int> origin{0, 0};
+  const SidePair<int> rounded_dims{params.packed[Side::kLhs].layout.cols,
+                                   params.packed[Side::kRhs].layout.cols};
+
   Tuning tuning = context->GetMainThreadTuning();
-  if (prepacked_lhs) {
-    prepacked_lhs->data_size = DataSize(params.packed_lhs);
-    prepacked_lhs->sums_size = SumsSize(params.packed_lhs);
-    prepacked_lhs->data = alloc_fn(prepacked_lhs->data_size);
-    prepacked_lhs->sums = alloc_fn(prepacked_lhs->sums_size);
-    params.packed_lhs.data = prepacked_lhs->data;
-    params.packed_lhs.sums = prepacked_lhs->sums;
-    params.LhsRunPack(tuning, 0, params.packed_lhs.layout.cols);
-  }
-  if (prepacked_rhs) {
-    prepacked_rhs->data_size = DataSize(params.packed_rhs);
-    prepacked_rhs->sums_size = SumsSize(params.packed_rhs);
-    prepacked_rhs->data = alloc_fn(prepacked_rhs->data_size);
-    prepacked_rhs->sums = alloc_fn(prepacked_rhs->sums_size);
-    params.packed_rhs.data = prepacked_rhs->data;
-    params.packed_rhs.sums = prepacked_rhs->sums;
-    params.RhsRunPack(tuning, 0, params.packed_rhs.layout.cols);
+  for (Side side : {Side::kLhs, Side::kRhs}) {
+    if (prepacked[side]) {
+      prepacked[side]->data_size = DataSize(params.packed[side]);
+      prepacked[side]->sums_size = SumsSize(params.packed[side]);
+      prepacked[side]->data = alloc_fn(prepacked[side]->data_size);
+      prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size);
+      params.packed[side].data = prepacked[side]->data;
+      params.packed[side].sums = prepacked[side]->sums;
+      params.RunPack(side, tuning, origin[side], rounded_dims[side]);
+    }
   }
 }
 
@@ -73,8 +76,7 @@
 void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs,
                               const Matrix<RhsScalar>& rhs, const Spec& spec,
                               Context* context, Matrix<DstScalar>* dst,
-                              PrepackedMatrix* prepacked_lhs,
-                              PrepackedMatrix* prepacked_rhs) {
+                              SidePair<PrepackedMatrix*> prepacked) {
   gemmlowp::ScopedProfilingLabel label("MulWithPrepacked");
 
   EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
@@ -90,16 +92,14 @@
   CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
                                         the_path, &params);
 
-  if (prepacked_lhs) {
-    params.packed_lhs.data = prepacked_lhs->data;
-    params.packed_lhs.sums = prepacked_lhs->sums;
-    params.lhs_is_prepacked = true;
+  for (Side side : {Side::kLhs, Side::kRhs}) {
+    if (prepacked[side]) {
+      params.packed[side].data = prepacked[side]->data;
+      params.packed[side].sums = prepacked[side]->sums;
+      params.is_prepacked[side] = true;
+    }
   }
-  if (prepacked_rhs) {
-    params.packed_rhs.data = prepacked_rhs->data;
-    params.packed_rhs.sums = prepacked_rhs->sums;
-    params.rhs_is_prepacked = true;
-  }
+
   TrMul(&params, context);
 }
 
diff --git a/tensorflow/lite/experimental/ruy/ruy.h b/tensorflow/lite/experimental/ruy/ruy.h
index e28e397..436b1af 100644
--- a/tensorflow/lite/experimental/ruy/ruy.h
+++ b/tensorflow/lite/experimental/ruy/ruy.h
@@ -21,6 +21,7 @@
 #include "tensorflow/lite/experimental/ruy/context.h"
 #include "tensorflow/lite/experimental/ruy/dispatch.h"
 #include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
 #include "tensorflow/lite/experimental/ruy/spec.h"
 
 namespace ruy {
diff --git a/tensorflow/lite/experimental/ruy/ruy_advanced.h b/tensorflow/lite/experimental/ruy/ruy_advanced.h
index 36382e7..6874819 100644
--- a/tensorflow/lite/experimental/ruy/ruy_advanced.h
+++ b/tensorflow/lite/experimental/ruy/ruy_advanced.h
@@ -16,7 +16,14 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
 
+#include <cstddef>
+#include <functional>
+
+#include "tensorflow/lite/experimental/ruy/context.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
+#include "tensorflow/lite/experimental/ruy/path.h"
 #include "tensorflow/lite/experimental/ruy/prepack.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
 
 namespace ruy {
 
@@ -40,8 +47,9 @@
                    PrepackedMatrix* prepacked_lhs,
                    PrepackedMatrix* prepacked_rhs,
                    std::function<void*(std::size_t)> alloc_fn) {
-  PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
-                                       prepacked_lhs, prepacked_rhs, alloc_fn);
+  SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
+  PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, dst, prepacked,
+                                       alloc_fn);
 }
 
 template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
@@ -51,8 +59,9 @@
                       Context* context, Matrix<DstScalar>* dst,
                       PrepackedMatrix* prepacked_lhs,
                       PrepackedMatrix* prepacked_rhs) {
+  SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
   MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
-                                          prepacked_lhs, prepacked_rhs);
+                                          prepacked);
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/ruy_test.bzl b/tensorflow/lite/experimental/ruy/ruy_test.bzl
index df9f58c..986bddf 100644
--- a/tensorflow/lite/experimental/ruy/ruy_test.bzl
+++ b/tensorflow/lite/experimental/ruy/ruy_test.bzl
@@ -6,12 +6,12 @@
 and destination.
 """
 
-def ruy_test(name, srcs, lhs_rhs_accum_dst, tags = []):
+def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = []):
     for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
         native.cc_test(
             name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
             srcs = srcs,
-            copts = [
+            copts = copts + [
                 "-DRUY_TEST_LHSSCALAR=%s" % lhs,
                 "-DRUY_TEST_RHSSCALAR=%s" % rhs,
                 "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
@@ -24,13 +24,14 @@
             tags = tags,
         )
 
-def ruy_benchmark(name, srcs, lhs_rhs_accum_dst):
+def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts):
+    tags = ["req_dep=@gemmlowp//:profiler"]
     for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
         native.cc_binary(
             name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
             testonly = True,
             srcs = srcs,
-            copts = [
+            copts = copts + [
                 "-DRUY_TEST_LHSSCALAR=%s" % lhs,
                 "-DRUY_TEST_RHSSCALAR=%s" % rhs,
                 "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
@@ -38,18 +39,20 @@
             ],
             deps = [
                 "//tensorflow/lite/experimental/ruy:test_lib",
-                "@gemmlowp//:profiler",
+                "@gemmlowp//:profiler",  # Note also tagged as req_dep.
             ],
+            tags = tags,
         )
 
-def ruy_benchmark_opt_sets(name, opt_sets, srcs, lhs_rhs_accum_dst):
+def ruy_benchmark_opt_sets(name, opt_sets, srcs, lhs_rhs_accum_dst, copts):
+    tags = ["req_dep=@gemmlowp//:profiler"]
     for opt_set in opt_sets:
         for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
             native.cc_binary(
                 name = "%s_%s_%s_%s_%s_%s" % (name, opt_set, lhs, rhs, accum, dst),
                 testonly = True,
                 srcs = srcs,
-                copts = [
+                copts = copts + [
                     "-DRUY_TEST_LHSSCALAR=%s" % lhs,
                     "-DRUY_TEST_RHSSCALAR=%s" % rhs,
                     "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
@@ -58,6 +61,7 @@
                 ],
                 deps = [
                     "//tensorflow/lite/experimental/ruy:test_lib",
-                    "@gemmlowp//:profiler",
+                    "@gemmlowp//:profiler",  # Note also tagged as req_dep.
                 ],
+                tags = tags,
             )
diff --git a/tensorflow/lite/experimental/ruy/side_pair.h b/tensorflow/lite/experimental/ruy/side_pair.h
new file mode 100644
index 0000000..b20a2d1
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/side_pair.h
@@ -0,0 +1,54 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_
+
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
+
+namespace ruy {
+
+enum class Side { kLhs = 0, kRhs = 1 };
+
+template <typename T>
+class SidePair final {
+ public:
+  SidePair() {}
+  SidePair(const T& a, const T& b) : elem_{a, b} {}
+  const T& operator[](Side side) const {
+    const int index = static_cast<int>(side);
+    // Technically this check is vacuous, since other values would be
+    // out-of-range for enum Side.
+    RUY_DCHECK(index == 0 || index == 1);
+    return elem_[index];
+  }
+
+  T& operator[](Side side) {
+    const int index = static_cast<int>(side);
+    // Technically this check is vacuous, since other values would be
+    // out-of-range for enum Side.
+    RUY_DCHECK(index == 0 || index == 1);
+    return elem_[index];
+  }
+
+ private:
+  static_assert(static_cast<int>(Side::kLhs) == 0, "");
+  static_assert(static_cast<int>(Side::kRhs) == 1, "");
+  T elem_[2];
+};
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_
diff --git a/tensorflow/lite/experimental/ruy/size_util.h b/tensorflow/lite/experimental/ruy/size_util.h
index 78ff90f..1e2fd20 100644
--- a/tensorflow/lite/experimental/ruy/size_util.h
+++ b/tensorflow/lite/experimental/ruy/size_util.h
@@ -27,7 +27,7 @@
 inline int floor_log2(int n) {
   RUY_DCHECK_GE(n, 1);
 #ifdef _WIN32
-  unsigned long result;
+  unsigned long result;  // NOLINT[runtime/int]
   _BitScanReverse(&result, n);
   return result;
 #else
diff --git a/tensorflow/lite/experimental/ruy/spec.h b/tensorflow/lite/experimental/ruy/spec.h
index 0913445..1d8c339 100644
--- a/tensorflow/lite/experimental/ruy/spec.h
+++ b/tensorflow/lite/experimental/ruy/spec.h
@@ -16,7 +16,6 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_
 
-#include <cstdint>
 #include <limits>
 #include <type_traits>
 
diff --git a/tensorflow/lite/experimental/ruy/test.h b/tensorflow/lite/experimental/ruy/test.h
index d604bc7..8b2f0e1 100644
--- a/tensorflow/lite/experimental/ruy/test.h
+++ b/tensorflow/lite/experimental/ruy/test.h
@@ -16,23 +16,33 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_
 
+#include <math.h>
+
 #include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
 #include <ctime>
-#include <initializer_list>
 #include <iostream>
+#include <iterator>
 #include <limits>
+#include <memory>
 #include <random>
 #include <set>
 #include <sstream>
 #include <string>
+#include <tuple>
 #include <type_traits>
 #include <vector>
 
-#include <gtest/gtest.h>
+#include <gtest/gtest.h>  // IWYU pragma: export
+#include "tensorflow/lite/experimental/ruy/matrix.h"  // IWYU pragma: export
 #include "tensorflow/lite/experimental/ruy/platform.h"
 #include "tensorflow/lite/experimental/ruy/pmu.h"
 #include "tensorflow/lite/experimental/ruy/ruy.h"
 #include "tensorflow/lite/experimental/ruy/ruy_advanced.h"
+#include "tensorflow/lite/experimental/ruy/spec.h"  // IWYU pragma: export
 #include "tensorflow/lite/experimental/ruy/time.h"
 
 #ifdef RUY_TEST_EXTERNAL_PATHS
@@ -66,8 +76,12 @@
   switch (path) {
     RUY_PATHNAME_CASE(kReference)
     RUY_PATHNAME_CASE(kStandardCpp)
+#if RUY_PLATFORM(NEON)
     RUY_PATHNAME_CASE(kNeon)
     RUY_PATHNAME_CASE(kNeonDotprod)
+#elif RUY_PLATFORM(AVX512)
+    RUY_PATHNAME_CASE(kAvx512)
+#endif
     default:
       RUY_CHECK(false);
       return nullptr;
@@ -245,7 +259,7 @@
 inline std::default_random_engine& global_random_engine() {
   static std::default_random_engine engine;
   return engine;
-};
+}
 
 template <typename Scalar>
 struct UniformRandomDistribution {
@@ -660,7 +674,7 @@
           LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
           &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
           -lhs.zero_point, -rhs.zero_point, output_pipeline);
-    } else
+    } else  // NOLINT[readability/braces]
 #endif
     {
       const auto& output_pipeline =
@@ -680,7 +694,7 @@
           LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
           &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
           -lhs.zero_point, -rhs.zero_point, output_pipeline);
-    } else
+    } else  // NOLINT[readability/braces]
 #endif
     {
       const auto& output_pipeline = std::make_tuple(
@@ -1866,11 +1880,11 @@
     if (record_pmu) {
       pmu_events.StartRecording();
     }
-    TimePoint time_start = Clock::now();
+    TimePoint time_start = Now();
     TimePoint t = time_start;
     int iters = 0;
     int iters_at_a_time = 1;
-    while (ToSeconds(t - time_start) < benchmark_min_secs) {
+    while (ToFloatSeconds(t - time_start) < benchmark_min_secs) {
       for (int i = 0; i < iters_at_a_time; i++) {
         if (cold) {
           lhs.matrix.data = cold_lhs.Next();
@@ -1887,10 +1901,10 @@
         iters++;
       }
       iters_at_a_time *= 2;
-      t = Clock::now();
+      t = Now();
     }
-    latency = std::min(latency,
-                       static_cast<float>(ToSeconds(t - time_start) / iters));
+    latency = std::min(
+        latency, static_cast<float>(ToFloatSeconds(t - time_start) / iters));
     if (record_pmu) {
       pmu_events.StopRecording();
       const float normalization_factor =
diff --git a/tensorflow/lite/experimental/ruy/test_fast.cc b/tensorflow/lite/experimental/ruy/test_fast.cc
index 8e23b57..8e93d89 100644
--- a/tensorflow/lite/experimental/ruy/test_fast.cc
+++ b/tensorflow/lite/experimental/ruy/test_fast.cc
@@ -15,6 +15,8 @@
 
 // This test contains cheap test cases, completes in a few seconds.
 
+#include <vector>
+
 #include "tensorflow/lite/experimental/ruy/test.h"
 
 namespace ruy {
diff --git a/tensorflow/lite/experimental/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/thread_pool.cc
index db69dc8..83ae085 100644
--- a/tensorflow/lite/experimental/ruy/thread_pool.cc
+++ b/tensorflow/lite/experimental/ruy/thread_pool.cc
@@ -18,10 +18,12 @@
 #include <atomic>
 #include <chrono>              // NOLINT(build/c++11)
 #include <condition_variable>  // NOLINT(build/c++11)
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
 #include <mutex>               // NOLINT(build/c++11)
 #include <thread>              // NOLINT(build/c++11)
 
-#include "tensorflow/lite/experimental/ruy/blocking_counter.h"
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/wait.h"
 
@@ -153,6 +155,13 @@
 
 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
   RUY_DCHECK_GE(task_count, 1);
+
+  // Case of 1 thread: just run the single task on the current thread.
+  if (task_count == 1) {
+    (tasks + 0)->Run();
+    return;
+  }
+
   // Task #0 will be run on the current thread.
   CreateThreads(task_count - 1);
   counter_to_decrement_when_ready_.Reset(task_count - 1);
@@ -160,8 +169,10 @@
     auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
     threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
   }
-  // Execute task #0 workload immediately on the current thread.
+
+  // Execute task #0 immediately on the current thread.
   (tasks + 0)->Run();
+
   // Wait for the threads submitted above to finish.
   counter_to_decrement_when_ready_.Wait();
 }
diff --git a/tensorflow/lite/experimental/ruy/time.h b/tensorflow/lite/experimental/ruy/time.h
index 0c656ec..07d6caa 100644
--- a/tensorflow/lite/experimental/ruy/time.h
+++ b/tensorflow/lite/experimental/ruy/time.h
@@ -17,20 +17,62 @@
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_
 
 #include <chrono>  // NOLINT(build/c++11)
+#include <cstdint>  // IWYU pragma: keep
+#include <ratio>    // NOLINT(build/c++11)
+
+#ifdef __linux__
+#include <sys/time.h>  // for CLOCK_MONOTONIC_COARSE
+
+#include <ctime>
+#endif
 
 namespace ruy {
 
-using Clock = std::chrono::steady_clock;
+using InternalDefaultClock = std::chrono::steady_clock;
 
-using TimePoint = Clock::time_point;
-using Duration = Clock::duration;
+using TimePoint = InternalDefaultClock::time_point;
+using Duration = InternalDefaultClock::duration;
 
-inline double ToSeconds(Duration d) {
-  return std::chrono::duration_cast<std::chrono::duration<double>>(d).count();
+template <typename RepresentationType>
+Duration DurationFromSeconds(RepresentationType representation) {
+  return std::chrono::duration_cast<Duration>(
+      std::chrono::duration<RepresentationType>(representation));
 }
 
-inline Duration DurationFromSeconds(double s) {
-  return std::chrono::duration_cast<Duration>(std::chrono::duration<double>(s));
+template <typename RepresentationType>
+Duration DurationFromMilliseconds(RepresentationType representation) {
+  return std::chrono::duration_cast<Duration>(
+      std::chrono::duration<RepresentationType, std::milli>(representation));
+}
+
+template <typename RepresentationType>
+Duration DurationFromNanoseconds(RepresentationType representation) {
+  return std::chrono::duration_cast<Duration>(
+      std::chrono::duration<RepresentationType, std::nano>(representation));
+}
+
+inline float ToFloatSeconds(const Duration& duration) {
+  return std::chrono::duration_cast<std::chrono::duration<float>>(duration)
+      .count();
+}
+
+inline std::int64_t ToInt64Nanoseconds(const Duration& duration) {
+  return std::chrono::duration_cast<
+             std::chrono::duration<std::int64_t, std::nano>>(duration)
+      .count();
+}
+
+inline TimePoint Now() { return InternalDefaultClock::now(); }
+
+inline TimePoint CoarseNow() {
+#ifdef __linux__
+  timespec t;
+  clock_gettime(CLOCK_MONOTONIC_COARSE, &t);
+  return TimePoint(
+      DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec));
+#else
+  return Now();
+#endif
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/trace.cc b/tensorflow/lite/experimental/ruy/trace.cc
index c84a59e..55b2fed 100644
--- a/tensorflow/lite/experimental/ruy/trace.cc
+++ b/tensorflow/lite/experimental/ruy/trace.cc
@@ -16,207 +16,166 @@
 #include "tensorflow/lite/experimental/ruy/trace.h"
 
 #include <algorithm>
-#include <cerrno>
-#include <cstdint>
+#include <cerrno>  // IWYU pragma: keep
 #include <cstdio>
+#include <cstdlib>
 #include <string>
 #include <vector>
 
-#include "tensorflow/lite/experimental/ruy/block_map.h"
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
-#include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
 #include "tensorflow/lite/experimental/ruy/time.h"
 
 namespace ruy {
 
 #ifdef RUY_TRACE
 
-struct BlockTraceEntry {
-  std::uint32_t thread_id = 0;
-  TimePoint time_reserved;
-  TimePoint time_computed_coords;
-  TimePoint time_packed_lhs;
-  TimePoint time_packed_rhs;
-  TimePoint time_finished;
+enum class TraceEvent : std::uint8_t {
+  kNone,
+  kThreadStart,
+  kThreadLoopStart,
+  kThreadEnd,
+  kBlockReserved,
+  kBlockPackedLhs,
+  kBlockPackedRhs,
+  kBlockFinished
 };
 
-struct ThreadTraceEntry {
-  TimePoint time_start;
-  TimePoint time_loop_start;
-  TimePoint time_end;
+struct TraceEntry {
+  TimePoint time_point;
+  TraceEvent event;
+  // ruy-internal thread id i.e. contiguous index into array of threads,
+  // with 0 designating the main thread.
+  std::uint16_t thread_id = 0;
+  // Additional parameters whose meaning depends on the 'event' type.
+  std::uint32_t params[1];
 };
 
 struct Trace {
-  enum class LifeStage {
-    kInitial,
-    kRecordingRootFields,
-    kRecordingBlockAndThreadFields,
-    kComplete
-  };
-  void StartRecordingBlockAndThreadFields(const BlockMap& block_map_,
-                                          int thread_count_) {
-    RUY_DCHECK(life_stage == LifeStage::kRecordingRootFields);
-    block_map = block_map_;
-    thread_count = thread_count_;
-    int num_blocks = NumBlocks(block_map);
-    if (num_blocks > block_entries.size()) {
-      block_entries.resize(NumBlocks(block_map));
-    }
-    if (thread_count > thread_entries.size()) {
-      thread_entries.resize(thread_count);
-    }
-    life_stage = LifeStage::kRecordingBlockAndThreadFields;
-  }
   BlockMap block_map;
   int thread_count = 0;
-  std::vector<BlockTraceEntry> block_entries;
-  std::vector<ThreadTraceEntry> thread_entries;
+  // During recording, to avoid having to use locks or atomics, we let
+  // each thread append to its own specific vector.
+  std::vector<std::vector<TraceEntry>> thread_specific_entries;
+  // Global vector of entries into which we coalesce thread_specific_entries
+  // after recording is finished, when dumping a trace. See
+  // AggregateThreadSpecificEntries.
+  std::vector<TraceEntry> entries;
   TimePoint time_start;
   TimePoint time_execute;
   TimePoint time_end;
-  LifeStage life_stage = LifeStage::kInitial;
 };
 
-struct ProcessedTrace {
-  enum class Event : std::uint8_t {
-    kNone,
-    kThreadStart,
-    kThreadLoopStart,
-    kThreadEnd,
-    kBlockReserved,
-    kBlockComputedCoords,
-    kBlockPackedLhs,
-    kBlockPackedRhs,
-    kBlockFinished
-  };
-  struct Entry {
-    Event event = Event::kNone;
-    std::uint32_t thread_id = 0;
-    std::uint32_t block_id = 0;
-    TimePoint time;
-  };
+namespace {
 
-  BlockMap block_map;
-  int thread_count = 0;
-  TimePoint time_start;
-  TimePoint time_execute;
-  TimePoint time_end;
-  std::vector<Entry> entries;
-  void Add(Event event, std::uint32_t thread_id, std::uint32_t block_id,
-           TimePoint time) {
-    // If the time point is still in its default-constructed state,
-    // that means we didn't record it.
-    if (!time.time_since_epoch().count()) {
-      return;
+// Coalesce Trace::thread_specific_entries into Trace::entries.
+void AggregateThreadSpecificEntries(Trace* trace) {
+  RUY_CHECK(trace->entries.empty());
+  for (auto& thread_specific_entries_vector : trace->thread_specific_entries) {
+    for (const TraceEntry& entry : thread_specific_entries_vector) {
+      trace->entries.push_back(entry);
     }
-    Entry entry;
-    entry.event = event;
-    entry.thread_id = thread_id;
-    entry.block_id = block_id;
-    entry.time = time;
-    entries.push_back(entry);
+    thread_specific_entries_vector.clear();
   }
-  void Process(const Trace& trace) {
-    thread_count = trace.thread_count;
-    block_map = trace.block_map;
-    time_start = trace.time_start;
-    time_execute = trace.time_execute;
-    time_end = trace.time_end;
-    entries.clear();
-    for (int i = 0; i < trace.thread_count; i++) {
-      const auto& entry = trace.thread_entries[i];
-      Add(Event::kThreadStart, i, 0, entry.time_start);
-      Add(Event::kThreadLoopStart, i, 0, entry.time_loop_start);
-      Add(Event::kThreadEnd, i, 0, entry.time_end);
-    }
-    std::uint32_t num_blocks = NumBlocks(block_map);
-    for (int i = 0; i < num_blocks; i++) {
-      const auto& entry = trace.block_entries[i];
-      Add(Event::kBlockReserved, entry.thread_id, i, entry.time_reserved);
-      Add(Event::kBlockComputedCoords, entry.thread_id, i,
-          entry.time_computed_coords);
-      Add(Event::kBlockPackedLhs, entry.thread_id, i, entry.time_packed_lhs);
-      Add(Event::kBlockPackedRhs, entry.thread_id, i, entry.time_packed_rhs);
-      Add(Event::kBlockFinished, entry.thread_id, i, entry.time_finished);
-    }
-    std::sort(entries.begin(), entries.end(),
-              [](const Entry& a, const Entry& b) -> bool {
-                return a.time < b.time ||
-                       (a.time == b.time &&
-                        static_cast<int>(a.event) < static_cast<int>(b.event));
-              });
-  }
-  void Dump() {
-    const char* trace_filename = getenv("RUY_TRACE_FILE");
-    FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
-    if (!trace_file) {
-      fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
-              errno);
-      RUY_CHECK(false);
-    }
-    fprintf(trace_file, "thread_count:%d\n", thread_count);
-    fprintf(trace_file, "num_blocks:%d\n", NumBlocks(block_map));
-    fprintf(trace_file, "rows:%d\n", block_map.rows);
-    fprintf(trace_file, "cols:%d\n", block_map.cols);
-    fprintf(trace_file, "Execute: %.9f\n",
-            ToSeconds(time_execute - time_start));
-    for (const Entry& entry : entries) {
-      double time = ToSeconds(entry.time - time_start);
-      switch (entry.event) {
-        case Event::kThreadStart:
-          fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id);
-          break;
-        case Event::kThreadLoopStart:
-          fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time,
-                  entry.thread_id);
-          break;
-        case Event::kThreadEnd:
-          fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id);
-          break;
-        case Event::kBlockReserved: {
-          std::uint16_t block_r, block_c;
-          int start_r, start_c, end_r, end_c;
-          GetBlockByIndex(block_map, entry.block_id, &block_r, &block_c);
-          GetBlockMatrixCoords(block_map, block_r, block_c, &start_r, &start_c,
-                               &end_r, &end_c);
-          fprintf(trace_file, "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d\n",
-                  time, entry.thread_id, entry.block_id, start_r, start_c,
-                  end_r, end_c);
-          break;
-        }
-        case Event::kBlockComputedCoords:
-          fprintf(trace_file, "BlockComputedCoords: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockPackedLhs:
-          fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockPackedRhs:
-          fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockFinished:
-          fprintf(trace_file, "BlockFinished: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        default:
-          RUY_CHECK(false);
-      }
-    }
-    fprintf(trace_file, "End: %.9f\n", ToSeconds(time_end - time_start));
-    if (trace_filename) {
-      fclose(trace_file);
-    }
-  }
-};
-
-void DumpTrace(const Trace& trace) {
-  ProcessedTrace processed_trace;
-  processed_trace.Process(trace);
-  processed_trace.Dump();
 }
 
+// Sort Trace::entries by ascending time. In case of equal timepoints,
+// sort by some semi-arbitrary ordering of event types.
+void Sort(Trace* trace) {
+  std::sort(std::begin(trace->entries), std::end(trace->entries),
+            [](const TraceEntry& a, const TraceEntry& b) -> bool {
+              return a.time_point < b.time_point ||
+                     (a.time_point == b.time_point &&
+                      static_cast<int>(a.event) < static_cast<int>(b.event));
+            });
+}
+
+// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have
+// already been called on it.
+//
+// On some architectures long long ints are not same as std::int64_t, and
+// time is printed as %lld, so static_casts are necessary.
+void Dump(const Trace& trace) {
+  const char* trace_filename = getenv("RUY_TRACE_FILE");
+  FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
+  if (!trace_file) {
+    fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
+            errno);
+    RUY_CHECK(false);
+  }
+  fprintf(trace_file, "thread_count:%d\n", trace.thread_count);
+  fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]);
+  fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]);
+  fprintf(trace_file, "Execute: %lld\n",
+          static_cast<long long int>(
+              ToInt64Nanoseconds(trace.time_execute - trace.time_start)));
+  for (const TraceEntry& entry : trace.entries) {
+    long long int time = static_cast<long long int>(
+        ToInt64Nanoseconds(entry.time_point - trace.time_start));
+    switch (entry.event) {
+      case TraceEvent::kThreadStart:
+        fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kThreadLoopStart:
+        fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time,
+                entry.thread_id);
+        break;
+      case TraceEvent::kThreadEnd:
+        fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kBlockReserved: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        SidePair<int> start, end;
+        GetBlockMatrixCoords(trace.block_map, block, &start, &end);
+        fprintf(trace_file,
+                "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs],
+                start[Side::kLhs], start[Side::kRhs], end[Side::kLhs],
+                end[Side::kRhs]);
+        break;
+      }
+      case TraceEvent::kBlockPackedLhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockPackedRhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockFinished: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs],
+                block[Side::kRhs]);
+        break;
+      }
+      default:
+        RUY_CHECK(false);
+    }
+  }
+  fprintf(trace_file, "End: %lld\n",
+          static_cast<long long int>(
+              ToInt64Nanoseconds(trace.time_end - trace.time_start)));
+  if (trace_filename) {
+    fclose(trace_file);
+  }
+}
+
+}  // anonymous namespace
+
+// Get a Trace object to record to, or null of tracing is not enabled.
 Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) {
   if (!tracing->initialized) {
     tracing->initialized = true;
@@ -253,130 +212,114 @@
   return tracing->trace;
 }
 
+// The trace recorded on a context is finalized and dumped by
+// this TracingContext destructor.
+//
+// The idea of dumping on context destructor is that typically one wants to
+// run many matrix multiplications, e.g. to hit a steady state in terms of
+// performance characteristics, but only trace the last repetition of the
+// workload, when that steady state was attained.
 TracingContext::~TracingContext() {
   if (trace) {
-    DumpTrace(*trace);
+    AggregateThreadSpecificEntries(trace);
+    Sort(trace);
+    Dump(*trace);
   }
   delete trace;
 }
 
+void TraceRecordStart(Trace* trace) {
+  if (trace) {
+    trace->time_start = Now();
+  }
+}
+
+void TraceRecordExecute(const BlockMap& block_map, int thread_count,
+                        Trace* trace) {
+  if (trace) {
+    trace->time_execute = Now();
+    trace->block_map = block_map;
+    trace->thread_count = thread_count;
+    trace->thread_specific_entries.resize(thread_count);
+    for (int thread = 0; thread < thread_count; thread++) {
+      trace->thread_specific_entries[thread].clear();
+      // Reserve some large size to avoid frequent heap allocations
+      // affecting the recorded timings.
+      trace->thread_specific_entries[thread].reserve(16384);
+    }
+  }
+}
+
+void TraceRecordEnd(Trace* trace) {
+  if (trace) {
+    trace->time_end = Now();
+  }
+}
+
 void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    relaxed_atomic_store(&trace->block_entries[thread_id].thread_id, thread_id);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[thread_id].time_reserved, now);
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_start, now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadStart;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_loop_start,
-                         now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadLoopStart;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
                               Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    // This is typically called on the next block id just obtained by atomic
-    // increment; this may be out of range.
-    if (block_id < trace->block_entries.size()) {
-      relaxed_atomic_store(&trace->block_entries[block_id].thread_id,
-                           thread_id);
-      TimePoint now = Clock::now();
-      relaxed_atomic_store(&trace->block_entries[block_id].time_reserved, now);
-    }
+    TraceEntry entry;
+    entry.event = TraceEvent::kBlockReserved;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
-void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace) {
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+                            Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_computed_coords,
-                         now);
+    TraceEntry entry;
+    entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs
+                                     : TraceEvent::kBlockPackedRhs;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
-void TraceRecordBlockPackedLhs(std::uint32_t block_id, Trace* trace) {
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+                              Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_packed_lhs, now);
-  }
-}
-
-void TraceRecordBlockPackedRhs(std::uint32_t block_id, Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_packed_rhs, now);
-  }
-}
-
-void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_finished, now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kBlockFinished;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_end, now);
-  }
-}
-
-void TraceRecordStart(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kInitial ||
-               trace->life_stage == Trace::LifeStage::kComplete);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_start, now);
-    trace->life_stage = Trace::LifeStage::kRecordingRootFields;
-  }
-}
-
-void TraceRecordExecute(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_execute, now);
-  }
-}
-
-void TraceRecordEnd(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_end, now);
-    trace->life_stage = Trace::LifeStage::kComplete;
-  }
-}
-
-void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map,
-                                             int thread_count, Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
-    trace->StartRecordingBlockAndThreadFields(block_map, thread_count);
-    trace->life_stage = Trace::LifeStage::kRecordingBlockAndThreadFields;
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadEnd;
+    entry.time_point = Now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
diff --git a/tensorflow/lite/experimental/ruy/trace.h b/tensorflow/lite/experimental/ruy/trace.h
index ecd793d..db02821 100644
--- a/tensorflow/lite/experimental/ruy/trace.h
+++ b/tensorflow/lite/experimental/ruy/trace.h
@@ -16,12 +16,10 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_
 
-#include <algorithm>
 #include <cstdint>
-#include <cstdio>
-#include <vector>
 
 #include "tensorflow/lite/experimental/ruy/block_map.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
 
 namespace ruy {
 
@@ -39,23 +37,20 @@
   ~TracingContext();
 };
 
-void DumpTrace(const Trace& trace);
-
 Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols);
 void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace);
 void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace);
 void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
                               Trace* trace);
-void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace);
-void TraceRecordBlockPackedLhs(std::uint32_t block_id, Trace* trace);
-void TraceRecordBlockPackedRhs(std::uint32_t block_id, Trace* trace);
-void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace);
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+                            Trace* trace);
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+                              Trace* trace);
 void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace);
 void TraceRecordStart(Trace* trace);
-void TraceRecordExecute(Trace* trace);
+void TraceRecordExecute(const BlockMap& block_map, int thread_count,
+                        Trace* trace);
 void TraceRecordEnd(Trace* trace);
-void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map,
-                                             int thread_count, Trace* trace);
 
 #else
 
@@ -65,16 +60,12 @@
 inline void TraceRecordThreadStart(std::uint32_t, Trace*) {}
 inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {}
 inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {}
-inline void TraceRecordBlockCoordsComputed(std::uint32_t, Trace*) {}
-inline void TraceRecordBlockPackedLhs(std::uint32_t, Trace*) {}
-inline void TraceRecordBlockPackedRhs(std::uint32_t, Trace*) {}
-inline void TraceRecordBlockFinished(std::uint32_t, Trace*) {}
+inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {}
+inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {}
 inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {}
 inline void TraceRecordStart(Trace*) {}
-inline void TraceRecordExecute(Trace*) {}
+inline void TraceRecordExecute(const BlockMap&, int, Trace*) {}
 inline void TraceRecordEnd(Trace*) {}
-inline void TraceStartRecordingBlockAndThreadFields(const BlockMap&, int,
-                                                    Trace*) {}
 
 #endif
 
diff --git a/tensorflow/lite/experimental/ruy/trmul.cc b/tensorflow/lite/experimental/ruy/trmul.cc
index 39f0171..5776a89 100644
--- a/tensorflow/lite/experimental/ruy/trmul.cc
+++ b/tensorflow/lite/experimental/ruy/trmul.cc
@@ -15,114 +15,98 @@
 
 #include "tensorflow/lite/experimental/ruy/trmul.h"
 
+#include <atomic>
+#include <cstdint>
 #include <cstring>
+#include <memory>
+#include <vector>
 
 #include "profiling/instrumentation.h"
 #include "tensorflow/lite/experimental/ruy/allocator.h"
 #include "tensorflow/lite/experimental/ruy/block_map.h"
+#include "tensorflow/lite/experimental/ruy/check_macros.h"
 #include "tensorflow/lite/experimental/ruy/common.h"
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/matrix.h"
 #include "tensorflow/lite/experimental/ruy/opt_set.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
+#include "tensorflow/lite/experimental/ruy/size_util.h"
+#include "tensorflow/lite/experimental/ruy/spec.h"
 #include "tensorflow/lite/experimental/ruy/thread_pool.h"
 #include "tensorflow/lite/experimental/ruy/trace.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
 
 namespace ruy {
 
 namespace {
 
+enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished };
+
 struct TrMulTask final : Task {
   TrMulTask(TrMulParams* params_, const BlockMap& block_map_,
-            std::atomic<std::uint32_t>* atomic_n_, std::uint32_t thread_id_,
-            std::atomic<bool>* lhs_packed_, std::atomic<bool>* rhs_packed_,
+            std::atomic<int>* atomic_block_id_, int thread_id_,
+            bool need_atomics_,
+            SidePair<std::atomic<PackingStatus>*> packing_status_,
             TuningResolver* tuning_resolver_, Allocator* local_allocator_,
             Trace* trace_)
       : params(params_),
         block_map(block_map_),
-        atomic_n(atomic_n_),
+        atomic_block_id(atomic_block_id_),
         thread_id(thread_id_),
-        lhs_packed(lhs_packed_),
-        rhs_packed(rhs_packed_),
+        need_atomics(need_atomics_),
+        packing_status(packing_status_),
         tuning_resolver(tuning_resolver_),
         local_allocator(local_allocator_),
-        trace(trace_) {}
+        trace(trace_),
+        local_packed{nullptr, nullptr} {}
 
   void Run() override {
     TraceRecordThreadStart(thread_id, trace);
 
-    std::uint16_t num_blocks_of_rows = NumBlocksOfRows(block_map);
-    std::uint16_t num_blocks_of_cols = NumBlocksOfCols(block_map);
-    std::uint32_t num_blocks = NumBlocks(block_map);
-
-    bool* local_lhs_packed = nullptr;
-    bool* local_rhs_packed = nullptr;
-
-    if (lhs_packed) {
-      local_allocator->Allocate(num_blocks_of_rows, &local_lhs_packed);
-      memset(local_lhs_packed, 0, num_blocks_of_rows * sizeof(bool));
+    for (Side side : {Side::kLhs, Side::kRhs}) {
+      if (!params->is_prepacked[side]) {
+        const int size = NumBlocksPerSide(side, block_map);
+        local_allocator->Allocate(size, &local_packed[side]);
+        memset(local_packed[side], 0, size * sizeof(bool));
+      }
     }
-    if (rhs_packed) {
-      local_allocator->Allocate(num_blocks_of_cols, &local_rhs_packed);
-      memset(local_rhs_packed, 0, num_blocks_of_cols * sizeof(bool));
-    }
+
+    const int num_blocks = NumBlocks(block_map);
 
     const Tuning tuning = tuning_resolver->Resolve();
 
     TraceRecordThreadLoopStart(thread_id, trace);
 
-    std::uint16_t block_r, block_c;
-    int start_r, start_c, end_r, end_c;
+    SidePair<int> block;
+    SidePair<int> start;
+    SidePair<int> end;
 
     // Each thread starts by initially reserving the block whose id
     // is the thread id.
-    std::uint32_t n = thread_id;
-    TraceRecordBlockReserved(thread_id, n, trace);
+    int block_id = thread_id;
+    TraceRecordBlockReserved(thread_id, block_id, trace);
 
-    while (n < num_blocks) {
+    while (block_id < num_blocks) {
       // Reserve the next block to handle. In order to hide the latency
       // (typically comparable to an access to the level of data cache that
       // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019)
       // of this atomic operation, we structure this code so as to avoid
       // immediately depending on the `next_n` result.
-      const std::uint32_t next_n =
-          atomic_n->fetch_add(1, std::memory_order_relaxed);
-      TraceRecordBlockReserved(thread_id, next_n, trace);
+      const int next_block_id =
+          atomic_block_id->fetch_add(1, std::memory_order_relaxed);
+      TraceRecordBlockReserved(thread_id, next_block_id, trace);
       // Get coordinates of the current block to handle, in "block space".
-      GetBlockByIndex(block_map, n, &block_r, &block_c);
+      GetBlockByIndex(block_map, block_id, &block);
       // Get coordinates of the current block to handle, in matrix space.
-      GetBlockMatrixCoords(block_map, block_r, block_c, &start_r, &start_c,
-                           &end_r, &end_c);
-      TraceRecordBlockCoordsComputed(n, trace);
-      // Maybe pack the current LHS block, if not already packed.
-      // Note that if two threads concurrently hit the same LHS block to pack,
-      // we allow them to concurrently pack it, writing the same packed matrix
-      // data to the same location. That is considered worth it to avoid
-      // having one thread blocked on another one. Avoiding that is considered
-      // important especially on mobile, where there can be large speed
-      // discrepancy between threads, e.g. if different threads are scheduled
-      // on CPU cores of different types (big/little), different clock speed,
-      // different contention with other processes.
-      if (local_lhs_packed && !local_lhs_packed[block_r]) {
-        if (!lhs_packed[block_r].load(std::memory_order_acquire)) {
-          params->LhsRunPack(tuning, start_r, end_r);
-          TraceRecordBlockPackedLhs(n, trace);
-          local_lhs_packed[block_r] = true;
-          lhs_packed[block_r].store(true, std::memory_order_release);
-        }
-      }
-      // Maybe pack the current RHS block. Same comments as above for LHS.
-      if (local_rhs_packed && !local_rhs_packed[block_c]) {
-        if (!rhs_packed[block_c].load(std::memory_order_acquire)) {
-          params->RhsRunPack(tuning, start_c, end_c);
-          TraceRecordBlockPackedRhs(n, trace);
-          local_rhs_packed[block_c] = true;
-          rhs_packed[block_c].store(true, std::memory_order_release);
-        }
-      }
+      GetBlockMatrixCoords(block_map, block, &start, &end);
+      // Maybe pack the current LHS/RHS block, if not already packed.
+      EnsurePacked(block, start, end, tuning);
       // Actually do matrix multiplication work
-      params->RunKernel(tuning, start_r, start_c, end_r, end_c);
-      TraceRecordBlockFinished(n, trace);
+      params->RunKernel(tuning, start, end);
+      TraceRecordBlockFinished(thread_id, block_id, trace);
       // Move on to the next block as obtained by the atomic increment
       // at the start of this while loop iteration.
-      n = next_n;
+      block_id = next_block_id;
     }
 
     local_allocator->FreeAll();
@@ -131,15 +115,128 @@
   }
 
  private:
+  // Tries to pack a block, without blocking.
+  // If the block was already packed, returns true.
+  // If the block was not started packing, packs it and returns true.
+  // If the block was being packed by another thread, returns false.
+  bool TryPack(Side side, int block, int start, int end, Tuning tuning) {
+    if (params->is_prepacked[side]) {
+      return true;
+    }
+    if (!local_packed[side][block]) {
+      if (need_atomics) {
+        // Explanation of this compare_exchange_strong operation:
+        // This atomically performs all of the following:
+        // 1. Read `status` with "acquire" memory order.
+        //    * That this read uses "acquire" is because both memory orders
+        //      specified have "acquire" as their read-component.
+        // 2. Compare (bitwise) with `exchanged_status`.
+        // 3. If equal, stores the value kInProgress to `status` with "release"
+        //    memory order, and returns true, so we take this 'if' branch.
+        //    * That this store uses "release" is because of the _rel part in
+        //      memory_order_acq_rel passed as the first memory order argument.
+        // 4. If not equal, stores the loaded value of `status` to
+        //    `exchanged_status` with "relaxed" semantics, and returns false,
+        //    so we take the 'else' branch.
+        //    * That this store uses "relaxed" is because the second memory
+        //      order argument, memory_order_acquire, implies no particular
+        //      store semantics. "relaxed" is acceptable here because this
+        //      stores to a local stack variable.
+        //
+        // Rationale for compare_exchange_strong as opposed to
+        // compare_exchange_weak:
+        // The spurious-failure case with compare_exchange_weak will actually
+        // happen a lot here, because the atomic 'status' bytes are stored
+        // contiguously in arrays and neighboring values will be accessed
+        // by multiple threads concurrently. On a typical ARM CPU, an exclusives
+        // reservation granule is 64 bytes, so a lot of false-sharing may
+        // happen. Using compare_exchange_weak would thus result in often having
+        // TryPack return 'false' when it could instead have done the packing
+        // work and returned 'true'. Heuristically, that is not a good thing.
+        // Moreover, this changes the TryPack contract, loosening it and making
+        // it harder for the caller to reason about. Finally, the overhead of
+        // atomic operations is mitigated by the enclosing check on
+        // local_packed, so maybe the overhead of compare_exchange_strong isn't
+        // such a problem. But we don't really know for sure, that would be
+        // interesting to experiment more with.
+        PackingStatus exchanged_status = PackingStatus::kNotStarted;
+        std::atomic<PackingStatus>& status = packing_status[side][block];
+        if (status.compare_exchange_strong(
+                exchanged_status, PackingStatus::kInProgress,
+                std::memory_order_acq_rel, std::memory_order_acquire)) {
+          // In this branch, the status was kNotStarted and we just atomically
+          // changed it to kInProgress as we are about to handle the packing
+          // ourselves.
+          params->RunPack(side, tuning, start, end);
+          TraceRecordBlockPacked(thread_id, side, block, trace);
+          status.store(PackingStatus::kFinished, std::memory_order_release);
+        } else if (exchanged_status == PackingStatus::kInProgress) {
+          // Another thread is currently packing this block.
+          return false;
+        }
+        RUY_DCHECK(status.load(std::memory_order_acquire) ==
+                   PackingStatus::kFinished);
+      } else {
+        // Single-threaded case: no need for expensive atomics, local_packed
+        // is the truth already.
+        params->RunPack(side, tuning, start, end);
+        TraceRecordBlockPacked(thread_id, side, block, trace);
+      }
+      local_packed[side][block] = true;
+    }
+    return true;
+  }
+
+  // Ensures that both the LHS and RHS blocks required by the specified block
+  // are packed. In the event that they are already being packed on another
+  // threads, this function may perform the packing of some other block while
+  // waiting for that other thread to finish packing the requested block.
+  void EnsurePacked(const SidePair<int>& block, const SidePair<int>& start,
+                    const SidePair<int>& end, Tuning tuning) {
+#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD)
+    SidePair<int> next_runahead_block{block[Side::kLhs] + 1,
+                                      block[Side::kRhs] + 1};
+    Side next_runahead_side = Side::kLhs;
+#endif
+    while (true) {
+      bool both_sides_packed = true;
+      for (Side side : {Side::kLhs, Side::kRhs}) {
+        both_sides_packed &=
+            TryPack(side, block[side], start[side], end[side], tuning);
+      }
+      if (both_sides_packed) {
+        break;
+      }
+#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD)
+      const Side runahead_side = next_runahead_side;
+      const int runahead_block = next_runahead_block[runahead_side];
+      next_runahead_side =
+          next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs;
+      if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) {
+        continue;
+      }
+      int runahead_block_start, runahead_block_end;
+      GetBlockMatrixCoords(runahead_side, block_map, runahead_block,
+                           &runahead_block_start, &runahead_block_end);
+      TryPack(runahead_side, runahead_block, runahead_block_start,
+              runahead_block_end, tuning);
+      next_runahead_block[runahead_side] = runahead_block + 1;
+#endif
+    }
+  }
+
   TrMulParams* params;
   const BlockMap& block_map;
-  std::atomic<std::uint32_t>* atomic_n;
-  std::uint32_t thread_id;
-  std::atomic<bool>* lhs_packed;
-  std::atomic<bool>* rhs_packed;
+  std::atomic<int>* atomic_block_id;
+  int thread_id;
+  bool need_atomics;
+  SidePair<std::atomic<PackingStatus>*> packing_status;
   TuningResolver* tuning_resolver;
   Allocator* local_allocator;
   Trace* trace;
+
+  // Local indicators of packedness to avoid the overhead of atomic ops.
+  SidePair<bool*> local_packed;
 };
 
 void AllocatePMatrix(Allocator* allocator, PMatrix* packed) {
@@ -169,16 +266,14 @@
 void TrMul(TrMulParams* params, Context* context) {
   gemmlowp::ScopedProfilingLabel label("TrMul");
 
-  PMatrix& packed_lhs = params->packed_lhs;
-  PMatrix& packed_rhs = params->packed_rhs;
-  DMatrix& lhs = params->lhs;
-  DMatrix& rhs = params->rhs;
+  PMatrix& packed_lhs = params->packed[Side::kLhs];
+  PMatrix& packed_rhs = params->packed[Side::kRhs];
+  DMatrix& lhs = params->src[Side::kLhs];
+  DMatrix& rhs = params->src[Side::kRhs];
 
   const int rows = lhs.layout.cols;
   const int cols = rhs.layout.cols;
   const int depth = lhs.layout.rows;
-  const int rows_rounded_up = packed_lhs.layout.cols;
-  const int cols_rounded_up = packed_rhs.layout.cols;
 
   int thread_count = GetThreadCount(context, rows, cols, depth);
   const auto loop_structure =
@@ -186,24 +281,30 @@
                        params->cache_friendly_traversal_threshold);
   Allocator* allocator = context->GetMainAllocator();
 
-  if (!params->lhs_is_prepacked) {
-    AllocatePMatrix(allocator, &packed_lhs);
-  }
-  if (!params->rhs_is_prepacked) {
-    AllocatePMatrix(allocator, &packed_rhs);
+  // Allocate packed matrices
+  for (Side side : {Side::kLhs, Side::kRhs}) {
+    if (!params->is_prepacked[side]) {
+      AllocatePMatrix(allocator, &params->packed[side]);
+    }
   }
 
+  // Case of running this TrMul as a simple loop.
+  // This is a good place to start reading this function: all the rest
+  // of this function is just an optimized, but functionally equivalent,
+  // version of that.
   if (loop_structure == LoopStructure::kSimple) {
     gemmlowp::ScopedProfilingLabel label_simple("TrMulImpl, simple loop");
     Tuning tuning = context->GetMainThreadTuning();
 
-    if (!params->lhs_is_prepacked) {
-      params->LhsRunPack(tuning, 0, rows_rounded_up);
+    const SidePair<int> origin{0, 0};
+    const SidePair<int> rounded_dims{packed_lhs.layout.cols,
+                                     packed_rhs.layout.cols};
+    for (Side side : {Side::kLhs, Side::kRhs}) {
+      if (!params->is_prepacked[side]) {
+        params->RunPack(side, tuning, origin[side], rounded_dims[side]);
+      }
     }
-    if (!params->rhs_is_prepacked) {
-      params->RhsRunPack(tuning, 0, cols_rounded_up);
-    }
-    params->RunKernel(tuning, 0, 0, rows_rounded_up, cols_rounded_up);
+    params->RunKernel(tuning, origin, rounded_dims);
 
     allocator->FreeAll();
     return;
@@ -216,60 +317,56 @@
 
   // Initialize block map.
   BlockMap block_map;
-  MakeBlockMap(rows_rounded_up, cols_rounded_up, depth,
+  MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth,
                packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols,
                packed_lhs.data_type.size, packed_rhs.data_type.size,
                params->cache_friendly_traversal_threshold, &block_map);
-  std::uint16_t num_blocks_of_rows = NumBlocksOfRows(block_map);
-  std::uint16_t num_blocks_of_cols = NumBlocksOfCols(block_map);
-  std::uint32_t num_blocks = NumBlocks(block_map);
-  RUY_DCHECK_EQ(num_blocks, num_blocks_of_rows * num_blocks_of_cols);
 
   // Initialize per-thread state.
-  thread_count = clamp(thread_count, 1, num_blocks);
+  thread_count = clamp(thread_count, 1, NumBlocks(block_map));
+  const bool need_atomics = thread_count > 1;
   context->EnsureNPerThreadStates(thread_count);
   for (auto& per_thread_state : context->per_thread_states) {
     per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning);
   }
 
-  // Allocate memory.
-  std::atomic<bool>* lhs_packed = nullptr;
-  if (!params->lhs_is_prepacked) {
-    allocator->Allocate(num_blocks_of_rows, &lhs_packed);
+  // In the need_atomics case, allocate and initialize atomic values tracking
+  // the packing status of blocks.
+  SidePair<std::atomic<PackingStatus>*> packing_status{nullptr, nullptr};
+  if (need_atomics) {
+    for (Side side : {Side::kLhs, Side::kRhs}) {
+      if (!params->is_prepacked[side]) {
+        const int size = NumBlocksPerSide(side, block_map);
+        allocator->Allocate(size, &packing_status[side]);
+        for (int i = 0; i < size; i++) {
+          packing_status[side][i].store(PackingStatus::kNotStarted,
+                                        std::memory_order_relaxed);
+        }
+      }
+    }
   }
-  std::atomic<bool>* rhs_packed = nullptr;
-  if (!params->rhs_is_prepacked) {
-    allocator->Allocate(num_blocks_of_cols, &rhs_packed);
-  }
-  std::atomic<std::uint32_t>* atomic_n;
-  allocator->Allocate(1, &atomic_n);
+
+  // Create the atomic block id, allocate it using Allocator so that
+  // we get the alignment ensuring that it sits alone in its exclusives
+  // reservation granule.
+  std::atomic<int>* atomic_block_id;
+  allocator->Allocate(1, &atomic_block_id);
+
+  // Create task objects.
   TrMulTask* tasks;
   allocator->Allocate(thread_count, &tasks);
 
-  // Initialize allocated data.
-  if (lhs_packed != nullptr) {
-    for (int i = 0; i < num_blocks_of_rows; i++) {
-      lhs_packed[i].store(false, std::memory_order_release);
-    }
-  }
-  if (rhs_packed != nullptr) {
-    for (int i = 0; i < num_blocks_of_cols; i++) {
-      rhs_packed[i].store(false, std::memory_order_release);
-    }
-  }
-  atomic_n->store(thread_count);
+  atomic_block_id->store(thread_count);
 
   for (int i = 0; i < thread_count; i++) {
-    new (tasks + i)
-        TrMulTask(params, block_map, atomic_n, i, lhs_packed, rhs_packed,
-                  &context->per_thread_states[i]->tuning_resolver,
-                  &context->per_thread_states[i]->allocator, trace);
+    new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i,
+                              need_atomics, packing_status,
+                              &context->per_thread_states[i]->tuning_resolver,
+                              &context->per_thread_states[i]->allocator, trace);
   }
 
   // Do the computation.
-  TraceRecordExecute(trace);
-  TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace);
-
+  TraceRecordExecute(block_map, thread_count, trace);
   context->workers_pool.Execute(thread_count, tasks);
 
   // Finish up.
@@ -277,9 +374,8 @@
     tasks[i].~TrMulTask();
   }
 
-  TraceRecordEnd(trace);
-
   allocator->FreeAll();
+  TraceRecordEnd(trace);
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/trmul.h b/tensorflow/lite/experimental/ruy/trmul.h
index 1a3872b..6f7d7ba 100644
--- a/tensorflow/lite/experimental/ruy/trmul.h
+++ b/tensorflow/lite/experimental/ruy/trmul.h
@@ -27,47 +27,10 @@
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_
 
 #include "tensorflow/lite/experimental/ruy/context.h"
-#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
-#include "tensorflow/lite/experimental/ruy/kernel.h"
-#include "tensorflow/lite/experimental/ruy/pack.h"
-#include "tensorflow/lite/experimental/ruy/tune.h"
+#include "tensorflow/lite/experimental/ruy/trmul_params.h"
 
 namespace ruy {
 
-// Type-erased data needed for implementing TrMul.
-struct TrMulParams {
-  // Helper functions for invoking the function pointers.
-  void LhsRunPack(Tuning tuning, int start_c, int end_c) {
-    lhs_run_pack(tuning, lhs, &packed_lhs, start_c, end_c);
-  }
-  void RhsRunPack(Tuning tuning, int start_c, int end_c) {
-    rhs_run_pack(tuning, rhs, &packed_rhs, start_c, end_c);
-  }
-  void RunKernel(Tuning tuning, int start_r, int start_c, int end_r,
-                 int end_c) {
-    run_kernel(tuning, packed_lhs, packed_rhs, spec, start_r, start_c, end_r,
-               end_c, &dst);
-  }
-
-  // Function pointers to type-erased entry points for kernels and packers.
-  RunPackFn* lhs_run_pack = nullptr;
-  RunPackFn* rhs_run_pack = nullptr;
-  RunKernelFn* run_kernel = nullptr;
-
-  // Matrices and packed matrices.
-  DMatrix lhs;
-  DMatrix rhs;
-  DMatrix dst;
-  PMatrix packed_lhs;
-  PMatrix packed_rhs;
-  bool lhs_is_prepacked = false;
-  bool rhs_is_prepacked = false;
-  int cache_friendly_traversal_threshold = 0;
-
-  // Type-erased Spec.
-  void* spec = nullptr;
-};
-
 void TrMul(TrMulParams* params, Context* context);
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/trmul_params.h b/tensorflow/lite/experimental/ruy/trmul_params.h
new file mode 100644
index 0000000..2d06604
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/trmul_params.h
@@ -0,0 +1,59 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
+
+#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
+#include "tensorflow/lite/experimental/ruy/side_pair.h"
+#include "tensorflow/lite/experimental/ruy/tune.h"
+
+namespace ruy {
+
+using RunKernelFn = void(Tuning, const SidePair<PMatrix>&, void*,
+                         const SidePair<int>&, const SidePair<int>&, DMatrix*);
+
+using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int);
+
+// Type-erased data needed for implementing TrMul.
+struct TrMulParams {
+  TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {}
+  // Helper functions for invoking the function pointers.
+  void RunPack(Side side, Tuning tuning, int start, int end) {
+    run_pack[side](tuning, src[side], &packed[side], start, end);
+  }
+  void RunKernel(Tuning tuning, const SidePair<int>& start,
+                 const SidePair<int>& end) {
+    run_kernel(tuning, packed, spec, start, end, &dst);
+  }
+
+  // Function pointers to type-erased entry points for kernels and packers.
+  SidePair<RunPackFn*> run_pack;
+  RunKernelFn* run_kernel = nullptr;
+
+  // Matrices and packed matrices.
+  SidePair<DMatrix> src;
+  DMatrix dst;
+  SidePair<PMatrix> packed;
+  SidePair<bool> is_prepacked;
+  int cache_friendly_traversal_threshold = 0;
+
+  // Type-erased Spec.
+  void* spec = nullptr;
+};
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
diff --git a/tensorflow/lite/experimental/ruy/tune.cc b/tensorflow/lite/experimental/ruy/tune.cc
index 58a956e..3249b5b 100644
--- a/tensorflow/lite/experimental/ruy/tune.cc
+++ b/tensorflow/lite/experimental/ruy/tune.cc
@@ -18,8 +18,6 @@
 #include <algorithm>
 #include <cstdint>
 
-#include "tensorflow/lite/experimental/ruy/time.h"
-
 namespace ruy {
 
 #ifdef RUY_IMPLEMENT_TUNING
@@ -88,16 +86,17 @@
   Duration timing_nicely_ordered = Duration::max();
 
   for (int r = 0; r < kRepeats; r++) {
-    TimePoint t0 = Clock::now();
+    TimePoint t0 = Now();
     PoorlyOrderedKernel(kLoopIters);
-    TimePoint t1 = Clock::now();
+    TimePoint t1 = Now();
     NicelyOrderedKernel(kLoopIters);
-    TimePoint t2 = Clock::now();
+    TimePoint t2 = Now();
     timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0);
     timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1);
   }
 
-  return ToSeconds(timing_nicely_ordered) / ToSeconds(timing_poorly_ordered);
+  return ToFloatSeconds(timing_nicely_ordered) /
+         ToFloatSeconds(timing_poorly_ordered);
 }
 
 float TuningResolver::ThresholdRatio() {
@@ -138,17 +137,15 @@
 
 #endif
 
-static constexpr double kExpirySecs = 0.25;
-
 TuningResolver::TuningResolver()
-    : expiry_duration_(DurationFromSeconds(kExpirySecs)) {}
+    : expiry_duration_(DurationFromMilliseconds(250)) {}
 
 Tuning TuningResolver::Resolve() {
 #ifdef RUY_IMPLEMENT_TUNING
   if (unresolved_tuning_ != Tuning::kAuto) {
     return unresolved_tuning_;
   }
-  TimePoint new_timepoint = Clock::now();
+  TimePoint new_timepoint = CoarseNow();
   if (last_resolved_tuning_ != Tuning::kAuto &&
       (new_timepoint - last_resolved_timepoint_) < expiry_duration_) {
     return last_resolved_tuning_;
diff --git a/tensorflow/lite/experimental/ruy/tune.h b/tensorflow/lite/experimental/ruy/tune.h
index a1d0eb9..c625778 100644
--- a/tensorflow/lite/experimental/ruy/tune.h
+++ b/tensorflow/lite/experimental/ruy/tune.h
@@ -72,8 +72,6 @@
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_
 
-#include <cstdint>
-
 #include "tensorflow/lite/experimental/ruy/opt_set.h"
 #include "tensorflow/lite/experimental/ruy/platform.h"
 #include "tensorflow/lite/experimental/ruy/time.h"
diff --git a/tensorflow/lite/experimental/ruy/wait.cc b/tensorflow/lite/experimental/ruy/wait.cc
index 56000f3..310f53d 100644
--- a/tensorflow/lite/experimental/ruy/wait.cc
+++ b/tensorflow/lite/experimental/ruy/wait.cc
@@ -15,11 +15,7 @@
 
 #include "tensorflow/lite/experimental/ruy/wait.h"
 
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <functional>
-#include <mutex>  // NOLINT(build/c++11)
-
-#include "tensorflow/lite/experimental/ruy/time.h"
+#include <chrono>  // NOLINT(build/c++11)
 
 namespace ruy {
 
@@ -32,8 +28,8 @@
   }
 
   // Then try busy-waiting.
-  const TimePoint wait_start = Clock::now();
-  while (Clock::now() - wait_start < spin_duration) {
+  const TimePoint wait_start = Now();
+  while (Now() - wait_start < spin_duration) {
     if (condition()) {
       return;
     }
@@ -67,8 +63,7 @@
   // a little while, then start on a new GEMM. In that case the wait interval
   // may be a little longer. There may also not be another GEMM for a long time,
   // in which case we'll end up passively waiting below.
-  const double kMaxBusyWaitSeconds = 2e-3;
-  const Duration spin_duration = DurationFromSeconds(kMaxBusyWaitSeconds);
+  const Duration spin_duration = DurationFromMilliseconds(2);
   WaitUntil(condition, spin_duration, condvar, mutex);
 }
 
diff --git a/tensorflow/lite/experimental/ruy/wait.h b/tensorflow/lite/experimental/ruy/wait.h
index df4f3e3..ae38836 100644
--- a/tensorflow/lite/experimental/ruy/wait.h
+++ b/tensorflow/lite/experimental/ruy/wait.h
@@ -18,7 +18,7 @@
 
 #include <condition_variable>  // NOLINT(build/c++11)
 #include <functional>
-#include <mutex>  // NOLINT(build/c++11)
+#include <mutex>  //  NOLINT(build/c++11)
 
 #include "tensorflow/lite/experimental/ruy/time.h"
 
diff --git a/tensorflow/lite/experimental/ruy/wait_test.cc b/tensorflow/lite/experimental/ruy/wait_test.cc
index a19d8c8..7c99f10 100644
--- a/tensorflow/lite/experimental/ruy/wait_test.cc
+++ b/tensorflow/lite/experimental/ruy/wait_test.cc
@@ -39,9 +39,15 @@
         condvar_(condvar),
         mutex_(mutex) {}
   void operator()() {
+    // end_value_==-1 is how the master thread will tell us it's OK to terminate
     while (end_value_.load() != -1) {
+      // wait until end_value is set to a higher value
+      while (value_->load() == end_value_.load()) {
+      }
+      // increment value as long as it's lower than end_value
       while (value_->fetch_add(1) < end_value_.load() - 1) {
       }
+      // when value has reached end_value, notify the master thread.
       while (value_->load() == end_value_.load()) {
         std::lock_guard<std::mutex> lock(*mutex_);
         condvar_->notify_all();
@@ -56,13 +62,14 @@
   std::mutex* mutex_;
 };
 
-void WaitTest(const Duration& spin_duration) {
+void WaitTest(const Duration& spin_duration, const Duration& delay) {
   std::condition_variable condvar;
   std::mutex mutex;
   std::atomic<int> value(0);
   std::atomic<int> end_value(0);
   ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
   std::thread thread(thread_callable);
+  std::this_thread::sleep_for(delay);
   for (int i = 1; i < 10; i++) {
     end_value.store(1000 * i);
     const auto& condition = [&value, &end_value]() {
@@ -75,17 +82,26 @@
   thread.join();
 }
 
-TEST(WaitTest, WaitTestNoSpin) { WaitTest(DurationFromSeconds(0)); }
+TEST(WaitTest, WaitTestNoSpin) {
+  WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
+}
 
 TEST(WaitTest, WaitTestSpinOneMicrosecond) {
-  WaitTest(DurationFromSeconds(1e-6));
+  WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
 }
 
 TEST(WaitTest, WaitTestSpinOneMillisecond) {
-  WaitTest(DurationFromSeconds(1e-3));
+  WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
 }
 
-TEST(WaitTest, WaitTestSpinOneSecond) { WaitTest(DurationFromSeconds(1)); }
+TEST(WaitTest, WaitTestSpinOneSecond) {
+  WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
+}
+
+// Testcase to consistently reproduce the hang in b/139062384.
+TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) {
+  WaitTest(DurationFromSeconds(0), DurationFromSeconds(1));
+}
 
 }  // namespace
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/swift/BUILD.apple b/tensorflow/lite/experimental/swift/BUILD.apple
index 0a2126b..7a78c98 100644
--- a/tensorflow/lite/experimental/swift/BUILD.apple
+++ b/tensorflow/lite/experimental/swift/BUILD.apple
@@ -25,7 +25,9 @@
     name = "Tests",
     size = "small",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
-    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS + [
+        "nozapfhahn",  # Fails during coverage build, see b/139134323.
+    ],
     deps = [
         ":TestsLibrary",
     ],
@@ -35,7 +37,9 @@
     name = "TestsLibrary",
     testonly = 1,
     srcs = glob(["Tests/*.swift"]),
-    tags = TFL_DEFAULT_TAGS,
+    tags = TFL_DEFAULT_TAGS + [
+        "nozapfhahn",  # Fails during coverage build, see b/139134323.
+    ],
     deps = [
         ":Resources",
         ":TensorFlowLite",
diff --git a/tensorflow/lite/experimental/swift/Sources/Delegate.swift b/tensorflow/lite/experimental/swift/Sources/Delegate.swift
new file mode 100644
index 0000000..11a609f
--- /dev/null
+++ b/tensorflow/lite/experimental/swift/Sources/Delegate.swift
@@ -0,0 +1,24 @@
+// Copyright 2019 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlowLiteC
+
+/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
+public protocol Delegate: class {
+  /// `TFL_Delegate` C pointer type.
+  typealias CDelegate = OpaquePointer
+
+  /// Delegate that performs model computations.
+  var cDelegate: CDelegate? { get }
+}
diff --git a/tensorflow/lite/experimental/swift/Sources/Interpreter.swift b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift
index 457ca41..a2b0556 100644
--- a/tensorflow/lite/experimental/swift/Sources/Interpreter.swift
+++ b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift
@@ -17,8 +17,7 @@
 
 /// A TensorFlow Lite interpreter that performs inference from a given model.
 public final class Interpreter {
-
-  /// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
+  /// `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
   private typealias CInterpreter = OpaquePointer
 
   /// Total number of input tensors associated with the model.
@@ -31,15 +30,15 @@
     return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter))
   }
 
-  /// The underlying `TFL_Interpreter` C pointer.
+  /// Underlying `TFL_Interpreter` C pointer.
   private var cInterpreter: CInterpreter?
 
   /// Creates a new model interpreter instance.
   ///
   /// - Parameters:
   ///   - modelPath: Local file path to a TensorFlow Lite model.
-  ///   - options: Custom configurations for the interpreter. The default is `nil` indicating that
-  ///       the interpreter will determine the configuration options.
+  ///   - options: Custom configurations for the interpreter. Default is `nil` indicating that the
+  ///       interpreter will determine the configuration options.
   /// - Throws: An error if the model could not be loaded or the interpreter could not be created.
   public init(modelPath: String, options: InterpreterOptions? = nil) throws {
     guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
diff --git a/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift
index a07f857..3a8e5bc 100644
--- a/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift
+++ b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift
@@ -15,7 +15,7 @@
 import Foundation
 
 /// TensorFlow Lite interpreter errors.
-public enum InterpreterError: Error {
+public enum InterpreterError: Error, Equatable, Hashable {
   case invalidTensorIndex(index: Int, maxIndex: Int)
   case invalidTensorDataCount(provided: Int, required: Int)
   case invalidTensorDataType
@@ -37,8 +37,8 @@
     switch self {
     case .invalidTensorIndex(let index, let maxIndex):
       return "Invalid tensor index \(index), max index is \(maxIndex)."
-    case .invalidTensorDataCount(let providedCount, let requiredCount):
-      return "Provided data count \(providedCount) must match the required count \(requiredCount)."
+    case .invalidTensorDataCount(let provided, let required):
+      return "Provided data count \(provided) must match the required count \(required)."
     case .invalidTensorDataType:
       return "Tensor data type is unsupported or could not be determined due to a model error."
     case .failedToLoadModel:
@@ -63,9 +63,5 @@
 
 extension InterpreterError: CustomStringConvertible {
   /// Textual representation of the TensorFlow Lite interpreter error.
-  public var description: String {
-    return errorDescription ?? "Unknown error."
-  }
+  public var description: String { return errorDescription ?? "Unknown error." }
 }
-
-extension InterpreterError: Equatable {}
diff --git a/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift b/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift
index ae2bbc4..255bce2 100644
--- a/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift
+++ b/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift
@@ -14,9 +14,8 @@
 
 /// Custom configuration options for a TensorFlow Lite `Interpreter`.
 public struct InterpreterOptions: Equatable {
-
-  /// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
-  /// indicates that the `Interpreter` will decide the number of threads to use.
+  /// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
+  /// that the `Interpreter` will decide the number of threads to use.
   public var threadCount: Int? = nil
 
   /// Creates a new instance of interpreter options.
diff --git a/tensorflow/lite/experimental/swift/Sources/Model.swift b/tensorflow/lite/experimental/swift/Sources/Model.swift
index 6d52dcc..0635e8c 100644
--- a/tensorflow/lite/experimental/swift/Sources/Model.swift
+++ b/tensorflow/lite/experimental/swift/Sources/Model.swift
@@ -16,11 +16,10 @@
 
 /// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
 final class Model {
-
-  /// The `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
+  /// `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
   typealias CModel = OpaquePointer
 
-  /// The underlying `TFL_Model` C pointer.
+  /// Underlying `TFL_Model` C pointer.
   let cModel: CModel?
 
   /// Creates a new model instance.
diff --git a/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift b/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift
index 254ab3f..e3f4a52 100644
--- a/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift
+++ b/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift
@@ -15,8 +15,7 @@
 /// Parameters that determine the mapping of quantized values to real values. Quantized values can
 /// be mapped to float values using the following conversion:
 /// `realValue = scale * (quantizedValue - zeroPoint)`.
-public struct QuantizationParameters {
-
+public struct QuantizationParameters: Equatable, Hashable {
   /// Difference between real values corresponding to consecutive quantized values differing by 1.
   /// For example, the range of quantized values for `UInt8` data type is [0, 255].
   public let scale: Float
diff --git a/tensorflow/lite/experimental/swift/Sources/Tensor.swift b/tensorflow/lite/experimental/swift/Sources/Tensor.swift
index 317914f..4684bc2 100644
--- a/tensorflow/lite/experimental/swift/Sources/Tensor.swift
+++ b/tensorflow/lite/experimental/swift/Sources/Tensor.swift
@@ -16,8 +16,7 @@
 import TensorFlowLiteC
 
 /// An input or output tensor in a TensorFlow Lite graph.
-public struct Tensor {
-
+public struct Tensor: Equatable, Hashable {
   /// Name of the tensor.
   public let name: String
 
@@ -38,9 +37,10 @@
   /// - Parameters:
   ///   - name: Name of the tensor.
   ///   - dataType: Data type of the tensor.
+  ///   - shape: Shape of the tensor.
   ///   - data: Data in the input tensor.
   ///   - quantizationParameters Quantization parameters for the tensor if using a quantized model.
-  ///       The default is `nil`.
+  ///       Default is `nil`.
   init(
     name: String,
     dataType: TensorDataType,
@@ -57,7 +57,7 @@
 }
 
 /// Supported TensorFlow Lite tensor data types.
-public enum TensorDataType: Equatable {
+public enum TensorDataType: Equatable, Hashable {
   /// Boolean.
   case bool
   /// 8-bit unsigned integer.
@@ -102,7 +102,7 @@
 }
 
 /// The shape of a TensorFlow Lite tensor.
-public struct TensorShape {
+public struct TensorShape: Equatable, Hashable {
 
   /// The number of dimensions of the tensor.
   public let rank: Int
diff --git a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec
index 3210ccc..f50e99b 100644
--- a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec
+++ b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec
@@ -1,10 +1,10 @@
 Pod::Spec.new do |s|
   s.name             = 'TensorFlowLiteSwift'
-  s.version          = '0.2.0'
+  s.version          = '1.14.0'
   s.authors          = 'Google Inc.'
   s.license          = { :type => 'Apache' }
   s.homepage         = 'https://github.com/tensorflow/tensorflow'
-  s.source           = { :git => 'https://github.com/tensorflow/tensorflow.git', :commit => '37c101d' }
+  s.source           = { :git => 'https://github.com/tensorflow/tensorflow.git', :tag => "v#{s.version}" }
   s.summary          = 'TensorFlow Lite for Swift'
   s.description      = <<-DESC
 
diff --git a/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift b/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift
index 65648c2..f836889 100644
--- a/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift
+++ b/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift
@@ -33,11 +33,3 @@
     XCTAssertNotEqual(parameters2, parameters3)
   }
 }
-
-// MARK: - Extensions
-
-extension QuantizationParameters: Equatable {
-  public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool {
-    return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint
-  }
-}
diff --git a/tensorflow/lite/experimental/swift/Tests/TensorTests.swift b/tensorflow/lite/experimental/swift/Tests/TensorTests.swift
index 4540043..1fad0e7 100644
--- a/tensorflow/lite/experimental/swift/Tests/TensorTests.swift
+++ b/tensorflow/lite/experimental/swift/Tests/TensorTests.swift
@@ -39,6 +39,38 @@
     XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
   }
 
+  func testTensor_Equatable() {
+    let name = "Tensor"
+    let dataType: TensorDataType = .uInt8
+    let shape = TensorShape(Constant.dimensions)
+    guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
+    let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
+    let tensor1 = Tensor(
+      name: name,
+      dataType: dataType,
+      shape: shape,
+      data: data,
+      quantizationParameters: quantizationParameters
+    )
+    var tensor2 = Tensor(
+      name: name,
+      dataType: dataType,
+      shape: shape,
+      data: data,
+      quantizationParameters: quantizationParameters
+    )
+    XCTAssertEqual(tensor1, tensor2)
+
+    tensor2 = Tensor(
+      name: "Tensor2",
+      dataType: dataType,
+      shape: shape,
+      data: data,
+      quantizationParameters: quantizationParameters
+    )
+    XCTAssertNotEqual(tensor1, tensor2)
+  }
+
   // MARK: - TensorShape
 
   func testTensorShape_InitWithArray() {
@@ -58,6 +90,15 @@
     XCTAssertEqual(shape.rank, Constant.dimensions.count)
     XCTAssertEqual(shape.dimensions, Constant.dimensions)
   }
+
+  func testTensorShape_Equatable() {
+    let shape1 = TensorShape(2, 2, 3)
+    var shape2: TensorShape = [2, 2, 3]
+    XCTAssertEqual(shape1, shape2)
+
+    shape2 = [2, 2, 4]
+    XCTAssertNotEqual(shape1, shape2)
+  }
 }
 
 // MARK: - Constants
@@ -66,18 +107,3 @@
   /// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]].
   static let dimensions = [2, 2, 3]
 }
-
-// MARK: - Extensions
-
-extension TensorShape: Equatable {
-  public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
-    return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions
-  }
-}
-
-extension Tensor: Equatable {
-  public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
-    return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape &&
-           lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
-  }
-}
diff --git a/tensorflow/lite/external_cpu_backend_context.h b/tensorflow/lite/external_cpu_backend_context.h
index 8d5125d..8809863 100644
--- a/tensorflow/lite/external_cpu_backend_context.h
+++ b/tensorflow/lite/external_cpu_backend_context.h
@@ -49,7 +49,7 @@
 // serialized way. Here's an example to illustrate the context sharing among 2
 // TF Lite interpreters:
 //
-//  TfLiteInternalBackendContext* global_ctxt = new ExternalCpuBackendContext();
+//  TfLiteExternalContext* global_ctxt = new ExternalCpuBackendContext();
 //  interpreter1 = /*...*/;
 //  interpreter1->SetExternalContext(kTfLiteCpuBackendContext, global_ctxt);
 //  interpreter2 = /*...*/;
diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index b7b954e..b9c4fea 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -24,6 +24,8 @@
         path: /lite/guide/android
       - title: "iOS quickstart"
         path: /lite/guide/ios
+      - title: "Python quickstart"
+        path: /lite/guide/python
       - title: "FAQ"
         path: /lite/guide/faq
       - title: "Roadmap"
@@ -74,11 +76,12 @@
         path: /lite/performance/model_optimization
       - title: "Post-training quantization"
         path: /lite/performance/post_training_quantization
-      - title: "Post-training quantization example"
-        path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb
-      - title: "Post-training integer quantization example"
-        path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_integer_quant.ipynb
-        status: external
+      - title: "Post-training weight quantization"
+        path: /lite/performance/post_training_quant
+      - title: "Post-training integer quantization"
+        path: /lite/performance/post_training_integer_quant
+      - title: "Post-training float16 quantization"
+        path: /lite/performance/post_training_float16_quant
       - title: "Delegates"
         path: /lite/performance/delegates
       - title: "GPU delegate"
diff --git a/tensorflow/lite/g3doc/convert/quantization.md b/tensorflow/lite/g3doc/convert/quantization.md
index 895f3e6..9dfc7a2 100644
--- a/tensorflow/lite/g3doc/convert/quantization.md
+++ b/tensorflow/lite/g3doc/convert/quantization.md
@@ -14,7 +14,29 @@
 
 ```
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+tflite_quant_model = converter.convert()
+```
+
+# Full integer quantization of weights and activations
+
+We can get further latency improvements, reductions in peak memory usage, and
+access to integer only hardware accelerators by making sure all model math is
+quantized. To do this, we need to measure the dynamic range of activations and
+inputs with a representative data set. You can simply create an input data
+generator and provide it to our converter.
+
+```
+import tensorflow as tf
+
+def representative_dataset_gen():
+  for _ in range(num_calibration_steps):
+    # Get sample input data as a numpy array in a method of your choosing.
+    yield [input]
+
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.representative_dataset = representative_dataset_gen
 tflite_quant_model = converter.convert()
 ```
 
@@ -25,10 +47,13 @@
 Currently, this requires training a model with
 ["fake-quantization" nodes](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize).
 
+This is only available in the v1 converter. A longer term solution that's
+compatible with 2.0 semantics is in progress.
+
 Convert the graph:
 
 ```
-converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
 input_arrays = converter.get_input_arrays()
 converter.quantized_input_stats = {input_arrays[0] : (0., 1.)}  # mean, std_dev
diff --git a/tensorflow/lite/g3doc/convert/rnn.md b/tensorflow/lite/g3doc/convert/rnn.md
index 7beaf32..52bc287 100644
--- a/tensorflow/lite/g3doc/convert/rnn.md
+++ b/tensorflow/lite/g3doc/convert/rnn.md
@@ -11,26 +11,27 @@
 ## Currently supported
 
 Currently, RNN models using
-[`tf.nn.static_rnn`](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
+[`tf.compat.v1.nn.static_rnn`](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
 can be converted successfully as long as no `sequence_length` is specified.
 
-The following `tf.nn.rnn_cell` operations work with `tf.nn.static_rnn`:
+The following `tf.compat.v1.nn.rnn_cell` operations work with
+`tf.compat.v1.nn.static_rnn`:
 
-*   [tf.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
-*   [tf.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
-*   [tf.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell)
-*   [tf.nn.rnn_cell.BasicLSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell)
-*   [tf.nn.rnn_cell.BasicRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicRNNCell)
+*   [tf.compat.v1.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
+*   [tf.compat.v1.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
+*   [tf.compat.v1.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell)
+*   [tf.compat.v1.nn.rnn_cell.BasicLSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell)
+*   [tf.compat.v1.nn.rnn_cell.BasicRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicRNNCell)
 
 In addition, TensorFlow Lite provides some experimental drop-in replacements for
 RNN operations that enable dynamic RNN architectures with TensorFlow Lite.
 
 Drop-in replacements are available for the following:
 
-*   [tf.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
-*   [tf.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
-*   [tf.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
-*   [tf.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
+*   [tf.compat.v1.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
+*   [tf.compat.v1.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
+*   [tf.compat.v1.nn.rnn_cell.RNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell)
+*   [tf.compat.v1.nn.rnn_cell.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell)
 
 ## Not currently supported
 
@@ -40,10 +41,10 @@
 in the next section are employed, models built with the following TensorFlow
 functions will not convert successfully:
 
-*   [tf.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
+*   [tf.compat.v1.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
     where a `sequence_length` is specified
-*   [tf.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
-*   [tf.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
+*   [tf.compat.v1.nn.dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
+*   [tf.compat.v1.nn.bidirectional_dynamic_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
 
 Note: TensorFlow Lite plans to implement all required Control Flow operations by
 the end of 2019. At this point, all RNN architectures will convert successfully.
@@ -56,7 +57,7 @@
 ### 1. Refactoring
 
 The simplest approach, if possible, is to refactor the model architecture to use
-[tf.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
+[tf.compat.v1.nn.static_rnn](https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn)
 without `sequence_length`.
 
 ### 2. Drop-in replacements that use op hints and fused ops
@@ -69,24 +70,24 @@
 
 The following drop-in replacements are available:
 
-*   [tf.lite.experimental.nn.dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L41)
+*   [tf.compat.v1.lite.experimental.nn.dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L41)
     *   replacement for tf.nn.dynamic_rnn
-*   [tf.lite.experimental.nn.bidirectional_dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L279)
+*   [tf.compat.v1.lite.experimental.nn.bidirectional_dynamic_rnn](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn.py#L279)
     *   replacement for tf.nn.bidirectional_dynamic_rnn
-*   [tf.lite.experimental.nn.TfLiteRNNCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L39)
+*   [tf.compat.v1.lite.experimental.nn.TfLiteRNNCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L39)
     *   replacement for tf.nn.rnn_cell.RNNCell
-*   [tf.lite.experimental.nn.TfLiteLSTMCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L159)
+*   [tf.compat.v1.lite.experimental.nn.TfLiteLSTMCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/rnn_cell.py#L159)
     *   replacement for tf.nn.rnn_cell.LSTMCell
 
 Note: These replacements must be used together. For example, if you are using
-`tf.lite.experimental.nn.dynamic_rnn`, you must combine it with
-`tf.lite.experimental.nn.TfLiteRNNCell` instead of using
-`tf.nn.rnn_cell.RNNCell`.
+`tf.compat.v1.lite.experimental.nn.dynamic_rnn`, you must combine it with
+`tf.compat.v1.lite.experimental.nn.TfLiteRNNCell` instead of using
+`tf.compat.v1.nn.rnn_cell.RNNCell`.
 
 Instead of
-[tf.nn.rnn_cell.MultiRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/MultiRNNCell),
+[tf.compat.v1.nn.rnn_cell.MultiRNNCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/MultiRNNCell),
 you should use
-[tf.keras.layers.StackedRNNCells](https://www.tensorflow.org/api_docs/python/tf/keras/layers/StackedRNNCells).
+[tf.compat.v1.keras.layers.StackedRNNCells](https://www.tensorflow.org/api_docs/python/tf/keras/layers/StackedRNNCells).
 
 For a tutorial on using these replacements, see
 [TensorFlow Lite LSTM ops API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/g3doc/README.md).
@@ -95,4 +96,4 @@
 [TensorFlowLite_LSTM_Keras_Tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/TensorFlowLite_LSTM_Keras_Tutorial.ipynb).
 
 Note: There is no replacement available for
-[tf.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell).
+[tf.compat.v1.nn.rnn_cell.GRUCell](https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell).
diff --git a/tensorflow/lite/g3doc/guide/build_arm64.md b/tensorflow/lite/g3doc/guide/build_arm64.md
index 304b721..825e235 100644
--- a/tensorflow/lite/g3doc/guide/build_arm64.md
+++ b/tensorflow/lite/g3doc/guide/build_arm64.md
@@ -1,23 +1,37 @@
 # Build TensorFlow Lite for ARM64 boards
 
-## Cross compiling
+This page describes how to build the TensorFlow Lite static library for
+ARM64-based computers. If you just want to start using TensorFlow Lite to
+execute your models, the fastest option is to install the TensorFlow Lite
+runtime package as shown in the [Python quickstart](python.md).
 
-### Installing the toolchain
+Note: This page shows how to compile only the C++ static library for
+TensorFlow Lite. Alternative install options include: [install just the Python
+interpreter API](python.md) (for inferencing only); [install the full
+TensorFlow package from pip](https://www.tensorflow.org/install/pip);
+or [build the full TensorFlow package](
+https://www.tensorflow.org/install/source).
+
+## Cross-compile for ARM64
+
+To ensure the proper build environment, we recommend using one of our TensorFlow
+Docker images such as [tensorflow/tensorflow:nightly-devel](
+https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+To get started, install the toolchain and libs:
 
 ```bash
 sudo apt-get update
 sudo apt-get install crossbuild-essential-arm64
 ```
 
-> If you are using Docker, you may not use `sudo`.
+If you are using Docker, you may not use `sudo`.
 
-### Building
-
-Clone this Tensorflow repository. Run this script at the root of the repository
-to download all the dependencies:
-
-> The Tensorflow repository is in `/tensorflow` if you are using
-> `tensorflow/tensorflow:nightly-devel` docker image, just try it.
+Now git-clone the TensorFlow repository
+(`https://github.com/tensorflow/tensorflow`)—if you're using the TensorFlow
+Docker image, the repo is already provided in `/tensorflow_src/`—and then run
+this script at the root of the TensorFlow repository to download all the
+build dependencies:
 
 ```bash
 ./tensorflow/lite/tools/make/download_dependencies.sh
@@ -25,7 +39,7 @@
 
 Note that you only need to do this once.
 
-Compile:
+Then compile:
 
 ```bash
 ./tensorflow/lite/tools/make/build_aarch64_lib.sh
@@ -34,17 +48,19 @@
 This should compile a static library in:
 `tensorflow/lite/tools/make/gen/aarch64_armv8-a/lib/libtensorflow-lite.a`.
 
-## Native compiling
+## Compile natively on ARM64
 
 These steps were tested on HardKernel Odroid C2, gcc version 5.4.0.
 
-Log in to your board, install the toolchain.
+Log in to your board and install the toolchain:
 
 ```bash
 sudo apt-get install build-essential
 ```
 
-First, clone the TensorFlow repository. Run this at the root of the repository:
+Now git-clone the TensorFlow repository
+(`https://github.com/tensorflow/tensorflow`) and run this at the root of
+the repository:
 
 ```bash
 ./tensorflow/lite/tools/make/download_dependencies.sh
@@ -52,7 +68,7 @@
 
 Note that you only need to do this once.
 
-Compile:
+Then compile:
 
 ```bash
 ./tensorflow/lite/tools/make/build_aarch64_lib.sh
diff --git a/tensorflow/lite/g3doc/guide/build_rpi.md b/tensorflow/lite/g3doc/guide/build_rpi.md
index 1a438ab..7ab4b43 100644
--- a/tensorflow/lite/g3doc/guide/build_rpi.md
+++ b/tensorflow/lite/g3doc/guide/build_rpi.md
@@ -1,30 +1,42 @@
 # Build TensorFlow Lite for Raspberry Pi
 
-## Cross compiling
+This page describes how to build the TensorFlow Lite static library for
+Raspberry Pi. If you just want to start using TensorFlow Lite to execute your
+models, the fastest option is to install the TensorFlow Lite runtime package as
+shown in the [Python quickstart](python.md).
 
-### Installing the toolchain
+Note: This page shows how to compile only the C++ static library for
+TensorFlow Lite. Alternative install options include: [install just the Python
+interpreter API](python.md) (for inferencing only); [install the full
+TensorFlow package from pip](https://www.tensorflow.org/install/pip);
+or [build the full TensorFlow package](
+https://www.tensorflow.org/install/source_rpi).
 
-This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image
+
+## Cross-compile for Raspberry Pi
+
+This has been tested on Ubuntu 16.04.3 64bit and TensorFlow devel docker image
 [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
 
-To cross compile TensorFlow Lite, first install the toolchain and libs.
+To cross compile TensorFlow Lite, first install the toolchain and libs:
 
 ```bash
 sudo apt-get update
 sudo apt-get install crossbuild-essential-armhf
 ```
 
-> If you are using Docker, you may not use `sudo`.
+If you are using Docker, you may not use `sudo`.
 
-### Building
-
-Clone this Tensorflow repository, Run this script at the root of the repository to download all the dependencies:
-
-> The Tensorflow repository is in `/tensorflow` if you are using `tensorflow/tensorflow:nightly-devel` docker image, just try it.
+Now git-clone the TensorFlow repository
+(`https://github.com/tensorflow/tensorflow`)—if you're using the TensorFlow
+Docker image, the repo is already provided in `/tensorflow_src/`—and then run
+this script at the root of the TensorFlow repository to download all the
+build dependencies:
 
 ```bash
 ./tensorflow/lite/tools/make/download_dependencies.sh
 ```
+
 Note that you only need to do this once.
 
 You should then be able to compile:
@@ -36,23 +48,29 @@
 This should compile a static library in:
 `tensorflow/lite/tools/make/gen/rpi_armv7l/lib/libtensorflow-lite.a`.
 
-## Native compiling
+
+## Compile natively on Raspberry Pi
+
 This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).
 
-Log in to you Raspberry Pi, install the toolchain.
+Log in to your Raspberry Pi and install the toolchain:
 
 ```bash
 sudo apt-get install build-essential
 ```
 
-First, clone the TensorFlow repository. Run this at the root of the repository:
+Now git-clone the TensorFlow repository
+(`https://github.com/tensorflow/tensorflow`) and run this at the root of
+the repository:
 
 ```bash
 ./tensorflow/lite/tools/make/download_dependencies.sh
 ```
+
 Note that you only need to do this once.
 
 You should then be able to compile:
+
 ```bash
 ./tensorflow/lite/tools/make/build_rpi_lib.sh
 ```
diff --git a/tensorflow/lite/g3doc/guide/get_started.md b/tensorflow/lite/g3doc/guide/get_started.md
index 72ddff4..ce16b79 100644
--- a/tensorflow/lite/g3doc/guide/get_started.md
+++ b/tensorflow/lite/g3doc/guide/get_started.md
@@ -211,9 +211,15 @@
 
 ### Linux
 
-Embedded Linux is an important platform for deploying machine learning. We
-provide build instructions for both [Raspberry Pi](build_rpi.md) and
-[Arm64-based boards](build_arm64.md) such as Odroid C2, Pine64, and NanoPi.
+Embedded Linux is an important platform for deploying machine learning. To get
+started using Python to perform inference with your TensorFlow Lite models,
+follow the [Python quickstart](python.md).
+
+To instead install the C++ library, see the
+build instructions for [Raspberry Pi](build_rpi.md) or
+[Arm64-based boards](build_arm64.md) (for boards such as Odroid C2, Pine64, and
+NanoPi).
+
 
 ### Microcontrollers
 
@@ -266,11 +272,16 @@
 import tensorflow as tf
 
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
 tflite_quant_model = converter.convert()
 open("converted_model.tflite", "wb").write(tflite_quantized_model)
 ```
 
+TensorFlow Lite supports reducing precision of values from full floating point
+to half-precision floats (float16) or 8-bit integers. There are trade-offs in
+model size and accuracy for each choice, and some operations have optimized
+implementations for these reduced precision types.
+
 To learn more about quantization, see
 [Post-training quantization](../performance/post_training_quantization.md).
 
@@ -289,5 +300,8 @@
 
 *   If you're a mobile developer, visit [Android quickstart](android.md) or
     [iOS quickstart](ios.md).
+*   If you're building Linux embedded devices, see the [Python quickstart](
+    python.md) or C++ build instructions for [Raspberry Pi](build_rpi.md) and
+    [Arm64-based boards](build_arm64.md).
 *   Explore our [pre-trained models](../models).
 *   Try our [example apps](https://www.tensorflow.org/lite/examples).
diff --git a/tensorflow/lite/g3doc/guide/hosted_models.md b/tensorflow/lite/g3doc/guide/hosted_models.md
index 323d31b..560a026 100644
--- a/tensorflow/lite/g3doc/guide/hosted_models.md
+++ b/tensorflow/lite/g3doc/guide/hosted_models.md
@@ -21,29 +21,29 @@
 classification models offer the smallest model size and fastest performance, at
 the expense of accuracy.
 
-Model name                  | Paper and model                                                                                                                                           | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance
---------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb     | 39.5%          | 64.4%          | 3.7 ms
-Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb     | 42.8%          | 68.1%          | 5.5 ms
-Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb     | 45.7%          | 70.8%          | 7.9 ms
-Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb     | 48.2%          | 72.8%          | 10.4 ms
-Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz)  | 1.4 Mb     | 54.9%          | 78.1%          | 8.8 ms
-Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz)  | 1.4 Mb     | 57.2%          | 80.5%          | 13.0 ms
-Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz)  | 1.4 Mb     | 59.9%          | 82.1%          | 18.3 ms
-Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz)  | 1.4 Mb     | 61.2%          | 83.2%          | 24.7 ms
-Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb     | 55.9%          | 79.1%          | 16.2 ms
-Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb     | 62.4%          | 83.7%          | 24.3 ms
-Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb     | 66.1%          | 86.2%          | 33.8 ms
-Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb     | 66.9%          | 86.9%          | 45.4 ms
-Mobilenet_V1_1.0_128_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz)  | 4.3 Mb     | 63.3%          | 84.1%          | 24.9 ms
-Mobilenet_V1_1.0_160_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz)  | 4.3 Mb     | 66.9%          | 86.7%          | 37.4 ms
-Mobilenet_V1_1.0_192_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz)  | 4.3 Mb     | 69.1%          | 88.1%          | 51.9 ms
-Mobilenet_V1_1.0_224_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)  | 4.3 Mb     | 70.0%          | 89.0%          | 70.2 ms
-Mobilenet_V2_1.0_224_quant  | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz)              | 3.4 Mb     | 70.8%          | 89.9%          | 53.4 ms
-Inception_V1_quant          | [paper](https://arxiv.org/abs/1409.4842), [tflite&pb](http://download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz)                          | 6.4 Mb     | 70.1%          | 89.8%          | 154.5 ms
-Inception_V2_quant          | [paper](https://arxiv.org/abs/1512.00567), [tflite&pb](http://download.tensorflow.org/models/inception_v2_224_quant_20181026.tgz)                         | 11 Mb      | 73.5%          | 91.4%          | 235.0 ms
-Inception_V3_quant          | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz)                       | 23 Mb      | 77.5%          | 93.7%          | 637 ms
-Inception_V4_quant          | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](http://download.tensorflow.org/models/inception_v4_299_quant_20181026.tgz)                         | 41 Mb      | 79.5%          | 93.9%          | 1250.8 ms
+Model name                  | Paper and model                                                                                                                                                                   | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance
+--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb     | 39.5%          | 64.4%          | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb     | 42.8%          | 68.1%          | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb     | 45.7%          | 70.8%          | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb     | 48.2%          | 72.8%          | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz)  | 1.4 Mb     | 54.9%          | 78.1%          | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz)  | 1.4 Mb     | 57.2%          | 80.5%          | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz)  | 1.4 Mb     | 59.9%          | 82.1%          | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz)  | 1.4 Mb     | 61.2%          | 83.2%          | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb     | 55.9%          | 79.1%          | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb     | 62.4%          | 83.7%          | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb     | 66.1%          | 86.2%          | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb     | 66.9%          | 86.9%          | 45.4 ms
+Mobilenet_V1_1.0_128_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz)  | 4.3 Mb     | 63.3%          | 84.1%          | 24.9 ms
+Mobilenet_V1_1.0_160_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz)  | 4.3 Mb     | 66.9%          | 86.7%          | 37.4 ms
+Mobilenet_V1_1.0_192_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz)  | 4.3 Mb     | 69.1%          | 88.1%          | 51.9 ms
+Mobilenet_V1_1.0_224_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)  | 4.3 Mb     | 70.0%          | 89.0%          | 70.2 ms
+Mobilenet_V2_1.0_224_quant  | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz)              | 3.4 Mb     | 70.8%          | 89.9%          | 53.4 ms
+Inception_V1_quant          | [paper](https://arxiv.org/abs/1409.4842), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz)                          | 6.4 Mb     | 70.1%          | 89.8%          | 154.5 ms
+Inception_V2_quant          | [paper](https://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/inception_v2_224_quant_20181026.tgz)                         | 11 Mb      | 73.5%          | 91.4%          | 235.0 ms
+Inception_V3_quant          | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz)                       | 23 Mb      | 77.5%          | 93.7%          | 637 ms
+Inception_V4_quant          | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/inception_v4_299_quant_20181026.tgz)                         | 41 Mb      | 79.5%          | 93.9%          | 1250.8 ms
 
 Note: The model files include both TF Lite FlatBuffer and Tensorflow frozen
 Graph.
@@ -68,23 +68,23 @@
 Inception_V3          | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz)         | 95.3 Mb    | 77.9%          | 93.8%          | 1433 ms             | 1522 ms
 Inception_V4          | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz)         | 170.7 Mb   | 80.1%          | 95.1%          | 2986 ms             | 3139 ms
 Inception_ResNet_V2   | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb   | 77.5%          | 94.0%          | 2731 ms             | 2926 ms
-Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)                                       | 1.9 Mb     | 41.4%          | 66.2%          | 6.2 ms              | 13.0 ms
-Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz)                                       | 1.9 Mb     | 45.4%          | 70.2%          | 8.6 ms              | 19.5 ms
-Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz)                                       | 1.9 Mb     | 47.1%          | 72.0%          | 12.1 ms             | 27.8 ms
-Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz)                                       | 1.9 Mb     | 49.7%          | 74.1%          | 16.2 ms             | 37.3 ms
-Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz)                                        | 5.3 Mb     | 56.2%          | 79.3%          | 18.1 ms             | 29.9 ms
-Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)                                        | 5.3 Mb     | 59.0%          | 81.8%          | 26.8 ms             | 45.9 ms
-Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz)                                        | 5.3 Mb     | 61.7%          | 83.5%          | 35.6 ms             | 65.3 ms
-Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz)                                        | 5.3 Mb     | 63.2%          | 84.9%          | 47.6 ms             | 164.2 ms
-Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz)                                       | 10.3 Mb    | 62.0%          | 83.8%          | 34.6 ms             | 48.7 ms
-Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz)                                       | 10.3 Mb    | 65.2%          | 85.9%          | 51.3 ms             | 75.2 ms
-Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz)                                       | 10.3 Mb    | 67.1%          | 87.2%          | 71.7 ms             | 107.0 ms
-Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz)                                       | 10.3 Mb    | 68.3%          | 88.1%          | 95.7 ms             | 143.4 ms
-Mobilenet_V1_1.0_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz)                                        | 16.9 Mb    | 65.2%          | 85.7%          | 57.4 ms             | 76.8 ms
-Mobilenet_V1_1.0_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz)                                        | 16.9 Mb    | 68.0%          | 87.7%          | 86.0 ms             | 117.7 ms
-Mobilenet_V1_1.0_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz)                                        | 16.9 Mb    | 69.9%          | 89.1%          | 118.6 ms            | 167.3 ms
-Mobilenet_V1_1.0_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)                                        | 16.9 Mb    | 71.0%          | 89.9%          | 160.1 ms            | 224.3 ms
-Mobilenet_V2_1.0_224  | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz)                                                | 14.0 Mb    | 71.8%          | 90.6%          | 117 ms              |
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)               | 1.9 Mb     | 41.4%          | 66.2%          | 6.2 ms              | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz)               | 1.9 Mb     | 45.4%          | 70.2%          | 8.6 ms              | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz)               | 1.9 Mb     | 47.1%          | 72.0%          | 12.1 ms             | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz)               | 1.9 Mb     | 49.7%          | 74.1%          | 16.2 ms             | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz)                | 5.3 Mb     | 56.2%          | 79.3%          | 18.1 ms             | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)                | 5.3 Mb     | 59.0%          | 81.8%          | 26.8 ms             | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz)                | 5.3 Mb     | 61.7%          | 83.5%          | 35.6 ms             | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz)                | 5.3 Mb     | 63.2%          | 84.9%          | 47.6 ms             | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz)               | 10.3 Mb    | 62.0%          | 83.8%          | 34.6 ms             | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz)               | 10.3 Mb    | 65.2%          | 85.9%          | 51.3 ms             | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz)               | 10.3 Mb    | 67.1%          | 87.2%          | 71.7 ms             | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz)               | 10.3 Mb    | 68.3%          | 88.1%          | 95.7 ms             | 143.4 ms
+Mobilenet_V1_1.0_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz)                | 16.9 Mb    | 65.2%          | 85.7%          | 57.4 ms             | 76.8 ms
+Mobilenet_V1_1.0_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz)                | 16.9 Mb    | 68.0%          | 87.7%          | 86.0 ms             | 117.7 ms
+Mobilenet_V1_1.0_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz)                | 16.9 Mb    | 69.9%          | 89.1%          | 118.6 ms            | 167.3 ms
+Mobilenet_V1_1.0_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)                | 16.9 Mb    | 71.0%          | 89.9%          | 160.1 ms            | 224.3 ms
+Mobilenet_V2_1.0_224  | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz)                        | 14.0 Mb    | 71.8%          | 90.6%          | 117 ms              |
 
 ### AutoML mobile models
 
@@ -113,7 +113,7 @@
 The object detection model we currently host is
 **coco_ssd_mobilenet_v1_1.0_quant_2018_06_29**.
 
-<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
+<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
 model and labels</a>
 
 ## Pose estimation
diff --git a/tensorflow/lite/g3doc/guide/ops_select.md b/tensorflow/lite/g3doc/guide/ops_select.md
index 7f2d5e3..7990a7a 100644
--- a/tensorflow/lite/g3doc/guide/ops_select.md
+++ b/tensorflow/lite/g3doc/guide/ops_select.md
@@ -40,7 +40,7 @@
 *   `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops.
 *   `SELECT_TF_OPS` - Converts models using TensorFlow ops. The exact subset of
     supported ops can be found in the whitelist at
-    `lite/toco/tflite/whitelisted_flex_ops.cc`.
+    `lite/delegates/flex/whitelisted_flex_ops.cc`.
 
 Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
 
diff --git a/tensorflow/lite/g3doc/guide/ops_version.md b/tensorflow/lite/g3doc/guide/ops_version.md
index 9418ce4..c83ea56 100644
--- a/tensorflow/lite/g3doc/guide/ops_version.md
+++ b/tensorflow/lite/g3doc/guide/ops_version.md
@@ -155,7 +155,7 @@
 
 ### Change TOCO TFLite exporter
 
-The last step is to make TOCO populate the minimum version that's required to
+The next step is to make TOCO populate the minimum version that's required to
 execute the op. In this example, it means:
 
 *   Populate version=1 when dilation factors are all 1.
@@ -184,6 +184,21 @@
 }
 ```
 
+### Update the operator version map
+
+The last step is to add the new version info into the operator version map. This
+step is required because we need generate the model's minimum required runtime
+version based on this version map.
+
+To do this, you need to add a new map entry in `lite/toco/tflite/op_version.cc`.
+
+In this example, it means you need to add the following into `op_version_map`:
+```
+{{OperatorType::kConv, 3}, "kPendingReleaseOpVersion"}
+```
+(`kPendingReleaseOpVersion` will be replaced with the appropriate release
+version in the next stable release.)
+
 ### Delegation Implementation
 
 TensorFlow Lite provides a delegation API which enables delegating ops to
diff --git a/tensorflow/lite/g3doc/guide/python.md b/tensorflow/lite/g3doc/guide/python.md
new file mode 100644
index 0000000..fbedd08
--- /dev/null
+++ b/tensorflow/lite/g3doc/guide/python.md
@@ -0,0 +1,100 @@
+# Python quickstart
+
+Using TensorFlow Lite with Python is great for embedded devices based on Linux,
+such as [Raspberry Pi](https://www.raspberrypi.org/){:.external} and
+[Coral devices with Edge TPU](https://coral.withgoogle.com/){:.external},
+among many others.
+
+This page shows how you can start running TensorFlow Lite models with Python in
+just a few minutes. All you need is a TensorFlow model [converted to TensorFlow
+Lite](../convert/). (If you don't have a model converted yet, you can experiment
+using the model provided with the example linked below.)
+
+## Install just the TensorFlow Lite interpreter
+
+To quickly start executing TensorFlow Lite models with Python, you can install
+just the TensorFlow Lite interpreter, instead of all TensorFlow packages.
+
+This interpreter-only package is a fraction the size of the full TensorFlow
+package and includes the bare minimum code required to run inferences with
+TensorFlow Lite—it includes only the [`tf.lite.Interpreter`](
+https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter) Python class.
+This small package is ideal when all you want to do is execute `.tflite` models
+and avoid wasting disk space with the large TensorFlow library.
+
+Note: If you need access to other Python APIs, such as the [TensorFlow Lite
+Converter](../convert/python_api.md), you must install the [full TensorFlow
+package](https://www.tensorflow.org/install/).
+
+To install just the interpreter, download the appropriate Python wheel for your
+system from the following table, and then install it with the `pip install`
+command.
+
+For example, if you're setting up a Raspberry Pi (using Raspbian Buster, which
+has Python 3.7), install the Python wheel as follows (after you click to
+download the `.whl` file below):
+
+<pre class="devsite-terminal devsite-click-to-copy">
+pip3 install tflite_runtime-1.14.0-cp37-cp37m-linux_armv7l.whl
+</pre>
+
+<table>
+<tr><th></th><th>ARM 32</th><th>ARM 64</th><th>x86-64</th></tr>
+<tr><th style="white-space:nowrap">Python 3.5</th>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp35-cp35m-linux_armv7l.whl"
+    >tflite_runtime-1.14.0-cp35-cp35m-linux_armv7l.whl</a></td>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp35-cp35m-linux_aarch64.whl"
+    >tflite_runtime-1.14.0-cp35-cp35m-linux_aarch64.whl</a></td>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp35-cp35m-linux_x86_64.whl"
+    >tflite_runtime-1.14.0-cp35-cp35m-linux_x86_64.whl</a></td>
+</tr>
+<tr><th>Python 3.6</th>
+  <td>N/A</td>
+  <td>N/A</td>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp36-cp36m-linux_x86_64.whl"
+    >tflite_runtime-1.14.0-cp36-cp36m-linux_x86_64.whl</a></td>
+</tr>
+<tr><th>Python 3.7</th>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp37-cp37m-linux_armv7l.whl"
+    >tflite_runtime-1.14.0-cp37-cp37m-linux_armv7l.whl</a></td>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp37-cp37m-linux_aarch64.whl"
+    >tflite_runtime-1.14.0-cp37-cp37m-linux_aarch64.whl</a></td>
+  <td><a href="https://dl.google.com/coral/python/tflite_runtime-1.14.0-cp37-cp37m-linux_x86_64.whl"
+    >tflite_runtime-1.14.0-cp37-cp37m-linux_x86_64.whl</a></td>
+</tr>
+</table>
+
+
+## Run an inference using tflite_runtime
+
+To distinguish this interpreter-only package from the full TensorFlow package
+(allowing both to be installed, if you choose), the Python module provided in
+the above wheel is named `tflite_runtime`.
+
+So instead of importing `Interpreter` from the `tensorflow` module, you need to
+import it from `tflite_runtime`.
+
+For example, after you install the package above, copy and run the
+[`label_image.py`](
+https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/examples/python/)
+file. It will (probably) fail because you don't have the `tensorflow` library
+installed. To fix it, simply edit this line of the file:
+
+```python
+from tensorflow.lite.python.interpreter import Interpreter
+```
+
+So it instead reads:
+
+```python
+from tflite_runtime import Interpreter
+```
+
+Now run `label_image.py` again. That's it! You're now executing TensorFlow Lite
+models.
+
+For more details about the `Interpreter` API, read [Load and run a model
+in Python](inference.md#load-and-run-a-model-in-python).
+
+To convert other TensorFlow models to TensorFlow Lite, read about the
+the [TensorFlow Lite Converter](../convert/).
diff --git a/tensorflow/lite/g3doc/models/object_detection/overview.md b/tensorflow/lite/g3doc/models/object_detection/overview.md
index 94df4aa..f9da639 100644
--- a/tensorflow/lite/g3doc/models/object_detection/overview.md
+++ b/tensorflow/lite/g3doc/models/object_detection/overview.md
@@ -20,7 +20,7 @@
 familiar with the <a href="https://www.tensorflow.org/api_docs/python/tf/lite">TensorFlow Lite APIs</a>, you can
 download our starter object detection model and the accompanying labels.
 
-<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
+<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
 starter model and labels</a>
 
 For more information about the starter model, see
@@ -185,7 +185,7 @@
 We recommend starting with this pre-trained quantized COCO SSD MobileNet v1
 model.
 
-<a class="button button-primary" href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
+<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">Download
 starter model and labels</a>
 
 ### Uses and limitations
@@ -193,7 +193,7 @@
 The object detection model we provide can identify and locate up to 10 objects
 in an image. It is trained to recognize 80 classes of object. For a full list of
 classes, see the labels file in the
-<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
+<a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
 zip</a>.
 
 If you want to train a model to recognize new classes, see
@@ -256,7 +256,7 @@
 
 The pre-trained models we provide are trained to detect 80 classes of object.
 For a full list of classes, see the labels file in the
-<a href="http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
+<a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip">model
 zip</a>.
 
 You can use a technique known as transfer learning to re-train a model to
diff --git a/tensorflow/lite/g3doc/models/pose_estimation/overview.md b/tensorflow/lite/g3doc/models/pose_estimation/overview.md
index b64d8b1..3ff915c 100644
--- a/tensorflow/lite/g3doc/models/pose_estimation/overview.md
+++ b/tensorflow/lite/g3doc/models/pose_estimation/overview.md
@@ -7,14 +7,22 @@
 _PoseNet_ is a vision model that can be used to estimate the pose of a person in
 an image or video by estimating where key body joints are.
 
-<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite">Download
-starter model</a>
+<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite">
+Download starter model</a>
 
 Android and iOS end-to-end tutorials are coming soon. In the meantime, if you
 want to experiment this on a web browser, check out the
 <a href="https://github.com/tensorflow/tfjs-models/tree/master/posenet">TensorFlow.js
 GitHub repository</a>.
 
+### Example applications and guides
+
+There is a TensorFlow Lite sample application that demonstrates the PoseNet
+model on Android.
+
+<a class="button button-primary" href="https://github.com/tensorflow/examples/tree/master/lite/examples/posenet/android">
+Android example</a>.
+
 ## How it works
 
 Pose estimation refers to computer vision techniques that detect human figures
@@ -138,6 +146,7 @@
 <ul>
   <li><a href="https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5">Blog post: Real-time Human Pose Estimation in the Browser with TensorFlow.js</a></li>
   <li><a href="https://github.com/tensorflow/tfjs-models/tree/master/posenet">TF.js GitHub: Pose Detection in the Browser: PoseNet Model</a></li>
+   <li><a href="https://medium.com/tensorflow/track-human-poses-in-real-time-on-android-with-tensorflow-lite-e66d0f3e6f9e">Blog post: Track human poses in real-time on Android with TensorFlow Lite</a></li>
 </ul>
 
 ### Use cases
diff --git a/tensorflow/lite/g3doc/models/smart_reply/overview.md b/tensorflow/lite/g3doc/models/smart_reply/overview.md
index b2363ad..abfcc8c 100644
--- a/tensorflow/lite/g3doc/models/smart_reply/overview.md
+++ b/tensorflow/lite/g3doc/models/smart_reply/overview.md
@@ -8,7 +8,7 @@
 suggestions are intended to be contextually relevant, one-touch responses that
 help the user to easily reply to an incoming message.
 
-<a class="button button-primary" href="http://download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip">Download
+<a class="button button-primary" href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip">Download
 starter model and labels</a>
 
 ### Sample application
diff --git a/tensorflow/lite/g3doc/performance/benchmarks.md b/tensorflow/lite/g3doc/performance/benchmarks.md
index a51fdb4..c730520 100644
--- a/tensorflow/lite/g3doc/performance/benchmarks.md
+++ b/tensorflow/lite/g3doc/performance/benchmarks.md
@@ -46,7 +46,7 @@
   </thead>
   <tr>
     <td rowspan = 2>
-      <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
     </td>
     <td>Pixel 2 </td>
     <td>123.3 ms</td>
@@ -57,7 +57,7 @@
   </tr>
   <tr>
     <td rowspan = 2>
-      <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
     </td>
     <td>Pixel 2 </td>
     <td>65.4 ms</td>
@@ -130,14 +130,14 @@
   </thead>
   <tr>
     <td>
-      <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
     </td>
     <td>iPhone 8 </td>
     <td>32.2 ms</td>
   </tr>
   <tr>
     <td>
-      <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
     </td>
     <td>iPhone 8 </td>
     <td>24.4 ms</td>
diff --git a/tensorflow/lite/g3doc/performance/images/optimization.jpg b/tensorflow/lite/g3doc/performance/images/optimization.jpg
index 1a419f6..f866768 100644
--- a/tensorflow/lite/g3doc/performance/images/optimization.jpg
+++ b/tensorflow/lite/g3doc/performance/images/optimization.jpg
Binary files differ
diff --git a/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
new file mode 100644
index 0000000..87f5081
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
@@ -0,0 +1,683 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "post_training-float16-quant.ipynb",
+      "version": "0.3.2",
+      "provenance": [],
+      "private_outputs": true,
+      "collapsed_sections": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "c8Cx-rUMVX25",
+        "colab_type": "text"
+      },
+      "source": [
+        "##### Copyright 2019 The TensorFlow Authors."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "I9sUhVL_VZNO",
+        "colab_type": "code",
+        "colab": {},
+        "cellView": "form"
+      },
+      "source": [
+        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+        "# you may not use this file except in compliance with the License.\n",
+        "# You may obtain a copy of the License at\n",
+        "#\n",
+        "# https://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing, software\n",
+        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+        "# See the License for the specific language governing permissions and\n",
+        "# limitations under the License."
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "6Y8E0lw5eYWm"
+      },
+      "source": [
+        "# Post-training float16 quantization"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "CGuqeuPSVNo-",
+        "colab_type": "text"
+      },
+      "source": [
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_float16_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "BTC1rDAuei_1"
+      },
+      "source": [
+        "## Overview\n",
+        "\n",
+        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
+        "converting weights to 16-bit floating point values during model conversion from TensorFlow to TensorFlow Lite's flat buffer format. This results in a 2x reduction in model size. Some harware, like GPUs, can compute natively in this reduced precision arithmetic, realizing a speedup over traditional floating point execution. The Tensorflow Lite GPU delegate can be configured to run in this way. However, a model converted to float16 weights can still run on the CPU without additional modification: the float16 weights are  upsampled to float32 prior to the first inference. This permits a significant reduction in model size in exchange for a minimal impacts to latency and accuracy.\n",
+        "\n",
+        "In this tutorial, you train an MNIST model from scratch, check its accuracy in TensorFlow, and then convert the saved model into a Tensorflow Lite flatbuffer\n",
+        "with float16 quantization. Finally, check the\n",
+        "accuracy of the converted model and compare it to the original saved model. The training script, `mnist.py`, is available from the\n",
+        "[TensorFlow official MNIST tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2XsEP17Zelz9"
+      },
+      "source": [
+        "## Build an MNIST model"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "dDqqUIZjZjac"
+      },
+      "source": [
+        "### Setup"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "gyqAw1M9lyab",
+        "colab": {}
+      },
+      "source": [
+        "! pip uninstall -y tensorflow\n",
+        "! pip install -U tf-nightly"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "WsN6s5L1ieNl",
+        "colab": {}
+      },
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()\n",
+        "\n",
+        "import numpy as np\n",
+        "\n",
+        "tf.logging.set_verbosity(tf.logging.DEBUG)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "00U0taBoe-w7",
+        "colab": {}
+      },
+      "source": [
+        "! git clone --depth 1 https://github.com/tensorflow/models"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "c6nb7OPlXs_3",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "tf.lite.constants.FLOAT16"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "4XZPtSh-fUOc",
+        "colab": {}
+      },
+      "source": [
+        "import sys\n",
+        "import os\n",
+        "\n",
+        "if sys.version_info.major >= 3:\n",
+        "    import pathlib\n",
+        "else:\n",
+        "    import pathlib2 as pathlib\n",
+        "\n",
+        "# Add `models` to the python path.\n",
+        "models_path = os.path.join(os.getcwd(), \"models\")\n",
+        "sys.path.append(models_path)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "eQ6Q0qqKZogR"
+      },
+      "source": [
+        "### Train and export the model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eMsw_6HujaqM",
+        "colab": {}
+      },
+      "source": [
+        "saved_models_root = \"/tmp/mnist_saved_model\""
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "hWSAjQWagIHl",
+        "colab": {}
+      },
+      "source": [
+        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "5NMaNZQCkW9X"
+      },
+      "source": [
+        "For the example, you trained the model for just a single epoch, so it only trains to ~96% accuracy."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "xl8_fzVAZwOh"
+      },
+      "source": [
+        "### Convert to a TensorFlow Lite model\n",
+        "\n",
+        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Xp5oClaZkbtn",
+        "colab": {}
+      },
+      "source": [
+        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+        "saved_model_dir"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "AT8BgkKmljOy"
+      },
+      "source": [
+        "Using the [Python `TFLiteConverter`](https://www.tensorflow.org/lite/convert/python_api), the saved model can be converted into a TensorFlow Lite model.\n",
+        "\n",
+        "First load the model using the `TFLiteConverter`:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "_i8B2nDZmAgQ",
+        "colab": {}
+      },
+      "source": [
+        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
+        "tflite_model = converter.convert()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "F2o2ZfF0aiCx"
+      },
+      "source": [
+        "Write it out to a `.tflite` file:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "vptWZq2xnclo",
+        "colab": {}
+      },
+      "source": [
+        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Ie9pQaQrn5ue",
+        "colab": {}
+      },
+      "source": [
+        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+        "tflite_model_file.write_bytes(tflite_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "7BONhYtYocQY"
+      },
+      "source": [
+        "To instead quantize the model to float16 on export, first set the `optimizations` flag to use default optimizations. Then specify that float16 is the supported type on the target platform:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "HEZ6ET1AHAS3",
+        "colab": {}
+      },
+      "source": [
+        "tf.logging.set_verbosity(tf.logging.INFO)\n",
+        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
+        "converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "xW84iMYjHd9t",
+        "colab_type": "text"
+      },
+      "source": [
+        "Finally, convert the model like usual. Note, by default the converted model will still use float input and outputs for invocation convenience."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "yuNfl3CoHNK3",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "tflite_fp16_model = converter.convert()\n",
+        "tflite_model_fp16_file = tflite_models_dir/\"mnist_model_quant_f16.tflite\"\n",
+        "tflite_model_fp16_file.write_bytes(tflite_fp16_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "PhMmUTl4sbkz"
+      },
+      "source": [
+        "Note how the resulting file is approximately `1/2` the size."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "JExfcfLDscu4",
+        "colab": {}
+      },
+      "source": [
+        "!ls -lh {tflite_models_dir}"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L8lQHMp_asCq"
+      },
+      "source": [
+        "## Run the TensorFlow Lite models"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-5l6-ciItvX6"
+      },
+      "source": [
+        "Run the TensorFlow Lite model using the Python TensorFlow Lite Interpreter. \n",
+        "\n",
+        "### Load the test data\n",
+        "\n",
+        "First, let's load the MNIST test data to feed to the model:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eTIuU07NuKFL",
+        "colab": {}
+      },
+      "source": [
+        "_, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+        "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
+        "\n",
+        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Ap_jE7QRvhPf"
+      },
+      "source": [
+        "### Load the model into the interpreters"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Jn16Rc23zTss",
+        "colab": {}
+      },
+      "source": [
+        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
+        "interpreter.allocate_tensors()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "J8Pztk1mvNVL",
+        "colab": {}
+      },
+      "source": [
+        "interpreter_fp16 = tf.lite.Interpreter(model_path=str(tflite_model_fp16_file))\n",
+        "interpreter_fp16.allocate_tensors()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2opUt_JTdyEu"
+      },
+      "source": [
+        "### Test the models on one image"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "AKslvo2kwWac",
+        "colab": {}
+      },
+      "source": [
+        "for img, label in mnist_ds:\n",
+        "  break\n",
+        "\n",
+        "interpreter.set_tensor(interpreter.get_input_details()[0][\"index\"], img)\n",
+        "interpreter.invoke()\n",
+        "predictions = interpreter.get_tensor(\n",
+        "    interpreter.get_output_details()[0][\"index\"])"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "XZClM2vo3_bm",
+        "colab": {}
+      },
+      "source": [
+        "import matplotlib.pylab as plt\n",
+        "\n",
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0])))\n",
+        "plt.grid(False)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "3gwhv4lKbYZ4",
+        "colab": {}
+      },
+      "source": [
+        "interpreter_fp16.set_tensor(\n",
+        "    interpreter_fp16.get_input_details()[0][\"index\"], img)\n",
+        "interpreter_fp16.invoke()\n",
+        "predictions = interpreter_fp16.get_tensor(\n",
+        "    interpreter_fp16.get_output_details()[0][\"index\"])"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "CIH7G_MwbY2x",
+        "colab": {}
+      },
+      "source": [
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0])))\n",
+        "plt.grid(False)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "LwN7uIdCd8Gw"
+      },
+      "source": [
+        "### Evaluate the models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "05aeAuWjvjPx",
+        "colab": {}
+      },
+      "source": [
+        "def eval_model(interpreter, mnist_ds):\n",
+        "  total_seen = 0\n",
+        "  num_correct = 0\n",
+        "\n",
+        "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
+        "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
+        "  for img, label in mnist_ds:\n",
+        "    total_seen += 1\n",
+        "    interpreter.set_tensor(input_index, img)\n",
+        "    interpreter.invoke()\n",
+        "    predictions = interpreter.get_tensor(output_index)\n",
+        "    if predictions == label.numpy():\n",
+        "      num_correct += 1\n",
+        "\n",
+        "    if total_seen % 500 == 0:\n",
+        "      print(\"Accuracy after %i images: %f\" %\n",
+        "            (total_seen, float(num_correct) / float(total_seen)))\n",
+        "\n",
+        "  return float(num_correct) / float(total_seen)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "T5mWkSbMcU5z",
+        "colab": {}
+      },
+      "source": [
+        "# Create smaller dataset for demonstration purposes\n",
+        "mnist_ds_demo = mnist_ds.take(2000)\n",
+        "\n",
+        "print(eval_model(interpreter, mnist_ds_demo))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Km3cY9ry8ZlG"
+      },
+      "source": [
+        "Repeat the evaluation on the float16 quantized model to obtain:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "-9cnwiPp6EGm",
+        "colab": {}
+      },
+      "source": [
+        "# NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite\n",
+        "# doesn't have super optimized server CPU kernels. For this reason this may be\n",
+        "# slower than the above float interpreter. But for mobile CPUs, considerable\n",
+        "# speedup can be observed.\n",
+        "print(eval_model(interpreter_fp16, mnist_ds_demo))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L7lfxkor8pgv"
+      },
+      "source": [
+        "In this example, you have quantized a model to float16 with no difference in the accuracy.\n",
+        "\n",
+        "It's also possible to evaluate the fp16 quantized model on the GPU. To perform all arithmetic with the reduced precision values, be sure to create the `TfLiteGPUDelegateOptions` struct in your app and set `precision_loss_allowed` to `1`, like this:\n",
+        "\n",
+        "```\n",
+        "//Prepare GPU delegate.\n",
+        "const TfLiteGpuDelegateOptions options = {\n",
+        "  .metadata = NULL,\n",
+        "  .compile_options = {\n",
+        "    .precision_loss_allowed = 1,  // FP16\n",
+        "    .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,\n",
+        "    .dynamic_batch_enabled = 0,   // Not fully functional yet\n",
+        "  },\n",
+        "};\n",
+        "```\n",
+        "\n",
+        "Detailed documentation on the TFLite GPU delegate and how to use it in your application can be found [here](https://www.tensorflow.org/lite/performance/gpu_advanced?source=post_page---------------------------)"
+      ]
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
new file mode 100644
index 0000000..c715011
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
@@ -0,0 +1,691 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "post_training_integer_quant.ipynb",
+      "version": "0.3.2",
+      "provenance": [],
+      "private_outputs": true,
+      "collapsed_sections": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "_DDaAex5Q7u-",
+        "colab_type": "text"
+      },
+      "source": [
+        "##### Copyright 2019 The TensorFlow Authors."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "W1dWWdNHQ9L0",
+        "colab_type": "code",
+        "colab": {},
+        "cellView": "form"
+      },
+      "source": [
+        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+        "# you may not use this file except in compliance with the License.\n",
+        "# You may obtain a copy of the License at\n",
+        "#\n",
+        "# https://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing, software\n",
+        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+        "# See the License for the specific language governing permissions and\n",
+        "# limitations under the License."
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "6Y8E0lw5eYWm"
+      },
+      "source": [
+        "# Post-training integer quantization"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "CIGrZZPTZVeO"
+      },
+      "source": [
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_integer_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "BTC1rDAuei_1"
+      },
+      "source": [
+        "## Overview\n",
+        "\n",
+        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
+        "converting an entire model (weights and activations) to 8-bit during model conversion from TensorFlow to TensorFlow Lite's flat buffer format. This results in a 4x reduction in model size and a 3 to 4x performance improvement on CPU performance. In addition, this fully quantized model can be consumed by integer-only hardware accelerators.\n",
+        "\n",
+        "In contrast to [post-training \"on-the-fly\" quantization](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb)\n",
+        ", which only stores weights as 8-bit ints, in this technique all weights *and* activations are quantized statically during model conversion.\n",
+        "\n",
+        "In this tutorial, you train an MNIST model from scratch, check its accuracy in TensorFlow, and then convert the saved model into a Tensorflow Lite flatbuffer\n",
+        "with full quantization. Finally, check the\n",
+        "accuracy of the converted model and compare it to the original saved model. The training script, `mnist.py`, is available from the\n",
+        "[TensorFlow official MNIST tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2XsEP17Zelz9"
+      },
+      "source": [
+        "## Build an MNIST model"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "dDqqUIZjZjac"
+      },
+      "source": [
+        "### Setup"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "gyqAw1M9lyab",
+        "colab": {}
+      },
+      "source": [
+        "! pip uninstall -y tensorflow\n",
+        "! pip install -U tf-nightly"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "WsN6s5L1ieNl",
+        "colab": {}
+      },
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "00U0taBoe-w7",
+        "colab": {}
+      },
+      "source": [
+        "! git clone --depth 1 https://github.com/tensorflow/models"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "4XZPtSh-fUOc",
+        "colab": {}
+      },
+      "source": [
+        "import sys\n",
+        "import os\n",
+        "\n",
+        "if sys.version_info.major >= 3:\n",
+        "    import pathlib\n",
+        "else:\n",
+        "    import pathlib2 as pathlib\n",
+        "\n",
+        "# Add `models` to the python path.\n",
+        "models_path = os.path.join(os.getcwd(), \"models\")\n",
+        "sys.path.append(models_path)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "eQ6Q0qqKZogR"
+      },
+      "source": [
+        "### Train and export the model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eMsw_6HujaqM",
+        "colab": {}
+      },
+      "source": [
+        "saved_models_root = \"/tmp/mnist_saved_model\""
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "hWSAjQWagIHl",
+        "colab": {}
+      },
+      "source": [
+        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+        "# Note: channels_last is required here or the conversion may fail. \n",
+        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "5NMaNZQCkW9X"
+      },
+      "source": [
+        "For the example, you train the model for just a single epoch, so it only trains to ~96% accuracy."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "xl8_fzVAZwOh"
+      },
+      "source": [
+        "### Convert to a TensorFlow Lite model\n",
+        "\n",
+        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Xp5oClaZkbtn",
+        "colab": {}
+      },
+      "source": [
+        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+        "saved_model_dir"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "AT8BgkKmljOy"
+      },
+      "source": [
+        "Using the [Python `TFLiteConverter`](https://www.tensorflow.org/lite/convert/python_api), the saved model can be converted into a TensorFlow Lite model.\n",
+        "\n",
+        "First load the model using the `TFLiteConverter`:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "_i8B2nDZmAgQ",
+        "colab": {}
+      },
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()\n",
+        "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+        "\n",
+        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
+        "tflite_model = converter.convert()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "F2o2ZfF0aiCx"
+      },
+      "source": [
+        "Write it out to a `.tflite` file:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "vptWZq2xnclo",
+        "colab": {}
+      },
+      "source": [
+        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Ie9pQaQrn5ue",
+        "colab": {}
+      },
+      "source": [
+        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+        "tflite_model_file.write_bytes(tflite_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "7BONhYtYocQY"
+      },
+      "source": [
+        "To instead quantize the model on export, first set the `optimizations` flag to optimize for size:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "HEZ6ET1AHAS3",
+        "colab": {}
+      },
+      "source": [
+        "tf.logging.set_verbosity(tf.logging.INFO)\n",
+        "converter.optimizations = [tf.lite.Optimize.DEFAULT]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "rTe8avZJHMDO",
+        "colab_type": "text"
+      },
+      "source": [
+        "Now, construct and provide a representative dataset, this is used to get the dynamic range of activations."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "FiwiWU3gHdkW",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "mnist_train, _ = tf.keras.datasets.mnist.load_data()\n",
+        "images = tf.cast(mnist_train[0], tf.float32)/255.0\n",
+        "mnist_ds = tf.data.Dataset.from_tensor_slices((images)).batch(1)\n",
+        "def representative_data_gen():\n",
+        "  for input_value in mnist_ds.take(100):\n",
+        "    yield [input_value]\n",
+        "\n",
+        "converter.representative_dataset = representative_data_gen"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "xW84iMYjHd9t",
+        "colab_type": "text"
+      },
+      "source": [
+        "Finally, convert the model like usual. By default, the converted model will still use float input and outputs for invocation convenience."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "yuNfl3CoHNK3",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "tflite_quant_model = converter.convert()\n",
+        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "PhMmUTl4sbkz"
+      },
+      "source": [
+        "Note how the resulting file is approximately `1/4` the size."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "JExfcfLDscu4",
+        "colab": {}
+      },
+      "source": [
+        "!ls -lh {tflite_models_dir}"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L8lQHMp_asCq"
+      },
+      "source": [
+        "## Run the TensorFlow Lite models"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-5l6-ciItvX6"
+      },
+      "source": [
+        "Run the TensorFlow Lite model using the Python TensorFlow Lite\n",
+        "Interpreter. \n",
+        "\n",
+        "### Load the test data\n",
+        "\n",
+        "First, let's load the MNIST test data to feed to the model:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eTIuU07NuKFL",
+        "colab": {}
+      },
+      "source": [
+        "import numpy as np\n",
+        "_, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+        "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
+        "\n",
+        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Ap_jE7QRvhPf"
+      },
+      "source": [
+        "### Load the model into the interpreters"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Jn16Rc23zTss",
+        "colab": {}
+      },
+      "source": [
+        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
+        "interpreter.allocate_tensors()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "J8Pztk1mvNVL",
+        "colab": {}
+      },
+      "source": [
+        "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))\n",
+        "interpreter_quant.allocate_tensors()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2opUt_JTdyEu"
+      },
+      "source": [
+        "### Test the models on one image"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "AKslvo2kwWac",
+        "colab": {}
+      },
+      "source": [
+        "for img, label in mnist_ds:\n",
+        "  break\n",
+        "\n",
+        "interpreter.set_tensor(interpreter.get_input_details()[0][\"index\"], img)\n",
+        "interpreter.invoke()\n",
+        "predictions = interpreter.get_tensor(\n",
+        "    interpreter.get_output_details()[0][\"index\"])"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "XZClM2vo3_bm",
+        "colab": {}
+      },
+      "source": [
+        "import matplotlib.pylab as plt\n",
+        "\n",
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0])))\n",
+        "plt.grid(False)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "3gwhv4lKbYZ4",
+        "colab": {}
+      },
+      "source": [
+        "interpreter_quant.set_tensor(\n",
+        "    interpreter_quant.get_input_details()[0][\"index\"], img)\n",
+        "interpreter_quant.invoke()\n",
+        "predictions = interpreter_quant.get_tensor(\n",
+        "    interpreter_quant.get_output_details()[0][\"index\"])"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "CIH7G_MwbY2x",
+        "colab": {}
+      },
+      "source": [
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0])))\n",
+        "plt.grid(False)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "LwN7uIdCd8Gw"
+      },
+      "source": [
+        "### Evaluate the models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "05aeAuWjvjPx",
+        "colab": {}
+      },
+      "source": [
+        "def eval_model(interpreter, mnist_ds):\n",
+        "  total_seen = 0\n",
+        "  num_correct = 0\n",
+        "\n",
+        "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
+        "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
+        "\n",
+        "  for img, label in mnist_ds:\n",
+        "    total_seen += 1\n",
+        "    interpreter.set_tensor(input_index, img)\n",
+        "    interpreter.invoke()\n",
+        "    predictions = interpreter.get_tensor(output_index)\n",
+        "    if predictions == label.numpy():\n",
+        "      num_correct += 1\n",
+        "\n",
+        "    if total_seen % 500 == 0:\n",
+        "      print(\"Accuracy after %i images: %f\" %\n",
+        "            (total_seen, float(num_correct) / float(total_seen)))\n",
+        "\n",
+        "  return float(num_correct) / float(total_seen)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "T5mWkSbMcU5z",
+        "colab": {}
+      },
+      "source": [
+        "# Create smaller dataset for demonstration purposes\n",
+        "mnist_ds_demo = mnist_ds.take(2000)\n",
+        "\n",
+        "print(eval_model(interpreter, mnist_ds_demo))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Km3cY9ry8ZlG"
+      },
+      "source": [
+        "Repeat the evaluation on the fully quantized model to obtain:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "-9cnwiPp6EGm",
+        "colab": {}
+      },
+      "source": [
+        "# NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite\n",
+        "# doesn't have super optimized server CPU kernels. For this reason this may be\n",
+        "# slower than the above float interpreter. But for mobile CPUs, considerable\n",
+        "# speedup can be observed.\n",
+        "# Only use 2000 for demonstration purposes\n",
+        "print(eval_model(interpreter_quant, mnist_ds_demo))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L7lfxkor8pgv"
+      },
+      "source": [
+        "In this example, you have fully quantized a model with no difference in the accuracy."
+      ]
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tensorflow/lite/g3doc/performance/post_training_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
new file mode 100644
index 0000000..89b2c2b
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
@@ -0,0 +1,735 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "post_training_quant.ipynb",
+      "version": "0.3.2",
+      "provenance": [],
+      "private_outputs": true,
+      "collapsed_sections": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "_-GR0EDHM1SO",
+        "colab_type": "text"
+      },
+      "source": [
+        "##### Copyright 2019 The TensorFlow Authors."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "R3yYtBPkM2qZ",
+        "colab_type": "code",
+        "colab": {},
+        "cellView": "form"
+      },
+      "source": [
+        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+        "# you may not use this file except in compliance with the License.\n",
+        "# You may obtain a copy of the License at\n",
+        "#\n",
+        "# https://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing, software\n",
+        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+        "# See the License for the specific language governing permissions and\n",
+        "# limitations under the License."
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "6Y8E0lw5eYWm"
+      },
+      "source": [
+        "# Post-training weight quantization"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "CIGrZZPTZVeO"
+      },
+      "source": [
+        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/performance/post_training_quant\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "BTC1rDAuei_1"
+      },
+      "source": [
+        "## Overview\n",
+        "\n",
+        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
+        "converting weights to 8 bit precision as part of model conversion from\n",
+        "tensorflow graphdefs to TensorFlow Lite's flat buffer format. Weight quantization\n",
+        "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
+        "fly quantization and dequantization of activations to allow for:\n",
+        "\n",
+        "1.  Using quantized kernels for faster implementation when available.\n",
+        "2.  Mixing of floating-point kernels with quantized kernels for different parts\n",
+        "    of the graph.\n",
+        "\n",
+        "The activations are always stored in floating point. For ops that\n",
+        "support quantized kernels, the activations are quantized to 8 bits of precision\n",
+        "dynamically prior to processing and are de-quantized to float precision after\n",
+        "processing. Depending on the model being converted, this can give a speedup over\n",
+        "pure floating point computation.\n",
+        "\n",
+        "In contrast to\n",
+        "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize)\n",
+        ", the weights are quantized post training and the activations are quantized dynamically \n",
+        "at inference in this method.\n",
+        "Therefore, the model weights are not retrained to compensate for quantization\n",
+        "induced errors. It is important to check the accuracy of the quantized model to\n",
+        "ensure that the degradation is acceptable.\n",
+        "\n",
+        "This tutorial trains an MNIST model from scratch, checks its accuracy in\n",
+        "TensorFlow, and then converts the saved model into a Tensorflow Lite flatbuffer\n",
+        "with weight quantization. Finally, it checks the\n",
+        "accuracy of the converted model and compare it to the original saved model. The training script, `mnist.py`, is from\n",
+        "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2XsEP17Zelz9"
+      },
+      "source": [
+        "## Build an MNIST model"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "dDqqUIZjZjac"
+      },
+      "source": [
+        "### Setup"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "gyqAw1M9lyab",
+        "colab": {}
+      },
+      "source": [
+        "! pip uninstall -y tensorflow\n",
+        "! pip install -U tf-nightly"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "WsN6s5L1ieNl",
+        "colab": {}
+      },
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "00U0taBoe-w7",
+        "colab": {}
+      },
+      "source": [
+        "! git clone --depth 1 https://github.com/tensorflow/models"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "4XZPtSh-fUOc",
+        "colab": {}
+      },
+      "source": [
+        "import sys\n",
+        "import os\n",
+        "\n",
+        "if sys.version_info.major >= 3:\n",
+        "    import pathlib\n",
+        "else:\n",
+        "    import pathlib2 as pathlib\n",
+        "\n",
+        "# Add `models` to the python path.\n",
+        "models_path = os.path.join(os.getcwd(), \"models\")\n",
+        "sys.path.append(models_path)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "eQ6Q0qqKZogR"
+      },
+      "source": [
+        "### Train and export the model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eMsw_6HujaqM",
+        "colab": {}
+      },
+      "source": [
+        "saved_models_root = \"/tmp/mnist_saved_model\""
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "hWSAjQWagIHl",
+        "colab": {}
+      },
+      "source": [
+        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+        "# Note: channels_last is required here or the conversion may fail. \n",
+        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "5NMaNZQCkW9X"
+      },
+      "source": [
+        "For the example, since you trained the model for just a single epoch, so it only trains to ~96% accuracy.\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "xl8_fzVAZwOh"
+      },
+      "source": [
+        "### Convert to a TFLite model\n",
+        "\n",
+        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Xp5oClaZkbtn",
+        "colab": {}
+      },
+      "source": [
+        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+        "saved_model_dir"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "AT8BgkKmljOy"
+      },
+      "source": [
+        "Using the python `TFLiteConverter`, the saved model can be converted into a TFLite model.\n",
+        "\n",
+        "First load the model using the `TFLiteConverter`:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "_i8B2nDZmAgQ",
+        "colab": {}
+      },
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()\n",
+        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
+        "tflite_model = converter.convert()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "F2o2ZfF0aiCx"
+      },
+      "source": [
+        "Write it out to a tflite file:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "vptWZq2xnclo",
+        "colab": {}
+      },
+      "source": [
+        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Ie9pQaQrn5ue",
+        "colab": {}
+      },
+      "source": [
+        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+        "tflite_model_file.write_bytes(tflite_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "7BONhYtYocQY"
+      },
+      "source": [
+        "To quantize the model on export, set the `optimizations` flag to optimize for size:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "g8PUvLWDlmmz",
+        "colab": {}
+      },
+      "source": [
+        "# Note: If you don't have a recent tf-nightly installed, the\n",
+        "# \"optimizations\" line will have no effect.\n",
+        "tf.logging.set_verbosity(tf.logging.INFO)\n",
+        "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n",
+        "tflite_quant_model = converter.convert()\n",
+        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "PhMmUTl4sbkz"
+      },
+      "source": [
+        "Note how the resulting file, is approximately `1/4` the size."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "JExfcfLDscu4",
+        "colab": {}
+      },
+      "source": [
+        "!ls -lh {tflite_models_dir}"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L8lQHMp_asCq"
+      },
+      "source": [
+        "## Run the TFLite models"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-5l6-ciItvX6"
+      },
+      "source": [
+        "Run the TensorFlow Lite model using the Python TensorFlow Lite\n",
+        "Interpreter. \n",
+        "\n",
+        "### load the test data\n",
+        "\n",
+        "First let's load the mnist test data to feed to it:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "eTIuU07NuKFL",
+        "colab": {}
+      },
+      "source": [
+        "import numpy as np\n",
+        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+        "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
+        "\n",
+        "# Note: If you change the batch size, then use \n",
+        "# `tf.lite.Interpreter.resize_tensor_input` to also change it for\n",
+        "# the interpreter.\n",
+        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Ap_jE7QRvhPf"
+      },
+      "source": [
+        "### Load the model into an interpreter"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Jn16Rc23zTss",
+        "colab": {}
+      },
+      "source": [
+        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
+        "interpreter.allocate_tensors()\n",
+        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
+        "output_index = interpreter.get_output_details()[0][\"index\"]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "J8Pztk1mvNVL",
+        "colab": {}
+      },
+      "source": [
+        "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+        "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "Afl6yGvWyqAr",
+        "colab": {}
+      },
+      "source": [
+        "interpreter_quant.allocate_tensors()\n",
+        "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n",
+        "output_index = interpreter_quant.get_output_details()[0][\"index\"]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2opUt_JTdyEu"
+      },
+      "source": [
+        "### Test the model on one image"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "AKslvo2kwWac",
+        "colab": {}
+      },
+      "source": [
+        "for img, label in mnist_ds.take(1):\n",
+        "  break\n",
+        "\n",
+        "interpreter.set_tensor(input_index, img)\n",
+        "interpreter.invoke()\n",
+        "predictions = interpreter.get_tensor(output_index)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "XZClM2vo3_bm",
+        "colab": {}
+      },
+      "source": [
+        "import matplotlib.pylab as plt\n",
+        "\n",
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0])))\n",
+        "plt.grid(False)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "LwN7uIdCd8Gw"
+      },
+      "source": [
+        "### Evaluate the models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "05aeAuWjvjPx",
+        "colab": {}
+      },
+      "source": [
+        "def eval_model(interpreter, mnist_ds):\n",
+        "  total_seen = 0\n",
+        "  num_correct = 0\n",
+        "\n",
+        "  for img, label in mnist_ds:\n",
+        "    total_seen += 1\n",
+        "    interpreter.set_tensor(input_index, img)\n",
+        "    interpreter.invoke()\n",
+        "    predictions = interpreter.get_tensor(output_index)\n",
+        "    if predictions == label.numpy():\n",
+        "      num_correct += 1\n",
+        "\n",
+        "    if total_seen % 500 == 0:\n",
+        "        print(\"Accuracy after %i images: %f\" %\n",
+        "              (total_seen, float(num_correct) / float(total_seen)))\n",
+        "\n",
+        "  return float(num_correct) / float(total_seen)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "DqXBnDfJ7qxL",
+        "colab": {}
+      },
+      "source": [
+        "print(eval_model(interpreter, mnist_ds))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Km3cY9ry8ZlG"
+      },
+      "source": [
+        "Repeat the evaluation on the weight quantized model to obtain:\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "-9cnwiPp6EGm",
+        "colab": {}
+      },
+      "source": [
+        "print(eval_model(interpreter_quant, mnist_ds))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L7lfxkor8pgv"
+      },
+      "source": [
+        "\n",
+        "In this example, the compressed model has no difference in the accuracy."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "M0o1FtmWeKZm"
+      },
+      "source": [
+        "## Optimizing an existing model\n",
+        "\n",
+        "Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
+        "  Pre-trained frozen graph for resnet-v2-101 is available at the\n",
+        "  [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md).\n",
+        "\n",
+        "You can convert the frozen graph to a TensorFLow Lite flatbuffer with quantization by:\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "v5p5VcNPjILQ",
+        "colab": {}
+      },
+      "source": [
+        "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
+        "archive_path = pathlib.Path(archive_path)\n",
+        "archive_dir = str(archive_path.parent)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-sxnXQuC4ThD"
+      },
+      "source": [
+        "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "g_Q_OMEJ4LIc",
+        "colab": {}
+      },
+      "source": [
+        "! cat {archive_dir}/resnet_v2_101_299_info.txt"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "ujCAFhqm-C6H",
+        "colab": {}
+      },
+      "source": [
+        "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
+        "input_arrays = [\"input\"] \n",
+        "output_arrays = [\"output\"]\n",
+        "converter = tf.lite.TFLiteConverter.from_frozen_graph(\n",
+        "  str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
+        "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n",
+        "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
+        "resnet_tflite_file.write_bytes(converter.convert())"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "id": "vhOjeg1x9Knp",
+        "colab": {}
+      },
+      "source": [
+        "!ls -lh {archive_dir}/*.tflite"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "qqHLaqFMCjRZ"
+      },
+      "source": [
+        "\n",
+        "The model size reduces from 171 MB to 43 MB.\n",
+        "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/accuracy/ilsvrc).\n",
+        "\n",
+        "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
+      ]
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 69ebf7e..30f8c09 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -8,6 +8,20 @@
 
 ### Optimization options
 
+There are several post training quantization options to choose from. Here is a
+summary table of the choices and the benefits they provide:
+
+| Technique              | Benefits                  | Hardware            |
+| ---------------------- | ------------------------- | ------------------- |
+| Post training "hybrid" | 4x smaller, 2-3x speedup, | CPU                 |
+:                        : accuracy                  :                     :
+| Post training integer  | 4x smaller, More speedup  | CPU, Edge TPU, etc. |
+| Post training fp16     | 2x smaller, Potential GPU | CPU/GPU             |
+:                        : acceleration              :                     :
+
+This decision tree can help determine which post-training quantization method is
+best for your use case:
+
 ![post-training optimization options](images/optimization.jpg)
 
 ### Quantizing weights
@@ -78,6 +92,35 @@
 This makes the converter throw an error if it encounters an operation it cannot
 currently quantize.
 
+### Float16 quantization of weights
+
+We can reduce the size of a floating point model by quantizing the weights to
+float16, the IEEE standard for 16 bit floating point numbers. The advantages of
+this quantization are as follows:
+
+-   reduce model size by up to half (since all weights are now half the original
+    size)
+-   minimal loss in accuracy
+-   some delegates (e.g. the GPU delegate) can operate directly on float16 data,
+    which results in faster execution than float32 computations.
+
+This quantization may not be a good choice if you need maximum performance (a
+quantization to fixed point math would be better in that case). To enable
+float16 quantization of weights, specify "DEFAULT" optimization as above and
+then specify that float16 is in supported types for the target_spec:
+
+```
+import tensorflow as tf
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
+tflite_quant_model = converter.convert()
+```
+
+By default, a float16 quantized model will "dequantize" the weights values to
+float32 when run on the CPU. The GPU delegate will not perform this
+dequantization, since it can operate on float16 data.
+
 ### Model accuracy
 
 Since weights are quantized post training, there could be an accuracy loss,
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index bf72f78..6ef6c2c 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -134,8 +134,8 @@
 
   subgraphs_.reserve(base_index + subgraphs_to_add);
   for (int i = 0; i < subgraphs_to_add; ++i) {
-    Subgraph* subgraph =
-        new Subgraph(error_reporter_, external_contexts_, &subgraphs_);
+    Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
+                                      &subgraphs_, &resource_variables_);
     subgraphs_.emplace_back(subgraph);
   }
 }
diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index 8eef585..397d47a 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -28,6 +28,7 @@
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/core/api/profiler.h"
 #include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
 #include "tensorflow/lite/external_cpu_backend_context.h"
 #include "tensorflow/lite/memory_planner.h"
 #include "tensorflow/lite/stderr_reporter.h"
@@ -539,6 +540,10 @@
 
   // Subgraphs
   std::vector<std::unique_ptr<Subgraph>> subgraphs_;
+
+  // A map of resource variables. Owned by interpreter and shared by multiple
+  // subgraphs.
+  ResourceVariableMap resource_variables_;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc
index 71dc1ef..f8ab53f 100644
--- a/tensorflow/lite/interpreter_test.cc
+++ b/tensorflow/lite/interpreter_test.cc
@@ -1155,8 +1155,9 @@
     // value-copyable and compatible with TfLite.
     explicit SimpleDelegate(
         const std::vector<int>& nodes,
-        TfLiteDelegateFlags delegate_flags = kTfLiteDelegateFlagsNone)
-        : nodes_(nodes) {
+        TfLiteDelegateFlags delegate_flags = kTfLiteDelegateFlagsNone,
+        bool fail_node_prepare = false)
+        : nodes_(nodes), fail_delegate_node_prepare_(fail_node_prepare) {
       delegate_.Prepare = [](TfLiteContext* context,
                              TfLiteDelegate* delegate) -> TfLiteStatus {
         auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_);
@@ -1191,7 +1192,8 @@
         }
 
         context->ReplaceNodeSubsetsWithDelegateKernels(
-            context, FakeFusedRegistration(), nodes_to_separate, delegate);
+            context, simple->FakeFusedRegistration(), nodes_to_separate,
+            delegate);
         TfLiteIntArrayFree(nodes_to_separate);
         return kTfLiteOk;
       };
@@ -1224,7 +1226,7 @@
       delegate_.flags = delegate_flags;
     }
 
-    static TfLiteRegistration FakeFusedRegistration() {
+    TfLiteRegistration FakeFusedRegistration() {
       TfLiteRegistration reg = {nullptr};
       reg.custom_name = "fake_fused_op";
 
@@ -1270,6 +1272,12 @@
             context, output, TfLiteIntArrayCopy(input1->dims)));
         return kTfLiteOk;
       };
+      if (fail_delegate_node_prepare_) {
+        reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+          return kTfLiteError;
+        };
+      }
+
       return reg;
     }
 
@@ -1278,7 +1286,9 @@
    private:
     std::vector<int> nodes_;
     TfLiteDelegate delegate_;
+    bool fail_delegate_node_prepare_ = false;
   };
+
   std::unique_ptr<Interpreter> interpreter_;
   std::unique_ptr<SimpleDelegate> delegate_, delegate2_;
 };
@@ -1291,7 +1301,7 @@
   int node = interpreter_->execution_plan()[0];
   const auto* node_and_reg = interpreter_->node_and_registration(node);
   EXPECT_EQ(node_and_reg->second.custom_name,
-            SimpleDelegate::FakeFusedRegistration().custom_name);
+            delegate_->FakeFusedRegistration().custom_name);
 
   const TfLiteDelegateParams* params =
       reinterpret_cast<const TfLiteDelegateParams*>(
@@ -1310,6 +1320,73 @@
   EXPECT_EQ(params->output_tensors->data[1], 4);
 }
 
+TEST_F(TestDelegate, DelegateNodePrepareFailure) {
+  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
+      {0, 1, 2}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/));
+  // ModifyGraphWithDelegate fails, since the Prepare() method in the node's
+  // TfLiteRegistration returns an error status.
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+      kTfLiteError);
+  // Execution plan should remain unchanged.
+  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
+
+  std::vector<float> input = {1.0f, 2.0f, 3.0f};
+  std::vector<float> expected_output = {2.0f, 4.0f, 6.0f};
+  constexpr int kOutputTensorIndex = 3;
+  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+
+  // Verify Invoke() behavior.
+  memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float));
+  memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float));
+  interpreter_->Invoke();
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i;
+  }
+}
+
+TEST_F(TestDelegate, SecondDelegationPrepareFailure) {
+  // First delegate only supports nodes 1, 2. Gets applied successfully.
+  // This delegate should support dynamic tensors, otherwise the second won't be
+  // applied.
+  delegate_ = std::unique_ptr<SimpleDelegate>(
+      new SimpleDelegate({1, 2}, kTfLiteDelegateFlagsAllowDynamicTensors));
+  // Second delegate supports node 0, but fails during the delegate-node's
+  // Prepare.
+  delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate(
+      {0}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/));
+
+  // Initially, execution plan has 3 nodes.
+  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
+  // First delegate should be applied successfully, yielding a plan with 2
+  // nodes.
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+      kTfLiteOk);
+  ASSERT_EQ(interpreter_->execution_plan().size(), 2);
+  // Second delegate won't get applied. However, we should be back to the
+  // previous 2-node plan.
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()),
+      kTfLiteError);
+  ASSERT_EQ(interpreter_->execution_plan().size(), 2);
+
+  std::vector<float> input = {1.0f, 2.0f, 3.0f};
+  // Node 0: tensor_2 = tensor0 + tensor0
+  // Delegated node: tensor_2 + tensor_1
+  std::vector<float> expected_output = {3.0f, 6.0f, 9.0f};
+  constexpr int kOutputTensorIndex = 3;
+  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+
+  // Verify Invoke() behavior to ensure Interpreter isn't broken.
+  memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float));
+  memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float));
+  interpreter_->Invoke();
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i;
+  }
+}
+
 TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) {
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
   ASSERT_EQ(
@@ -1343,7 +1420,7 @@
   ASSERT_EQ(interpreter_->execution_plan()[1], 3);
   const auto* node_and_reg = interpreter_->node_and_registration(3);
   ASSERT_EQ(node_and_reg->second.custom_name,
-            SimpleDelegate::FakeFusedRegistration().custom_name);
+            delegate_->FakeFusedRegistration().custom_name);
 }
 
 TEST_F(TestDelegate, SetBufferHandleToInput) {
diff --git a/tensorflow/lite/java/demo/app/build.gradle b/tensorflow/lite/java/demo/app/build.gradle
index c353b2c..fca1843 100644
--- a/tensorflow/lite/java/demo/app/build.gradle
+++ b/tensorflow/lite/java/demo/app/build.gradle
@@ -60,8 +60,8 @@
 }
 
 def targetFolder = "src/main/assets"
-def modelFloatDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
-def modelQuantDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
+def modelFloatDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
+def modelQuantDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
 def localCacheFloat = "build/intermediates/mobilenet_v1_1.0_224.tgz"
 def localCacheQuant = "build/intermediates/mmobilenet_v1_1.0_224_quant.tgz"
 
diff --git a/tensorflow/lite/java/jni/BUILD b/tensorflow/lite/java/jni/BUILD
index 3121cda..137ca32 100644
--- a/tensorflow/lite/java/jni/BUILD
+++ b/tensorflow/lite/java/jni/BUILD
@@ -18,6 +18,7 @@
         "//tensorflow:android": [],
         "//conditions:default": ["."],
     }),
+    visibility = ["//visibility:public"],
 )
 
 # Silly rules to make
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 5aef4fb..37f8b38 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -388,6 +388,18 @@
     wrapper.modifyGraphWithDelegate(delegate);
   }
 
+  /**
+   * Advanced: Resets all variable tensors to the default value.
+   *
+   * <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public void resetVariableTensors() {
+    checkNotClosed();
+    wrapper.resetVariableTensors();
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 160d4df..abe0ec7 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -193,6 +193,10 @@
     delegates.add(delegate);
   }
 
+  void resetVariableTensors() {
+    resetVariableTensors(interpreterHandle, errorHandle);
+  }
+
   /** Gets index of an input given its name. */
   int getInputIndex(String name) {
     if (inputsIndexes == null) {
@@ -374,6 +378,8 @@
   private static native void applyDelegate(
       long interpreterHandle, long errorHandle, long delegateHandle);
 
+  private static native void resetVariableTensors(long interpreterHandle, long errorHandle);
+
   private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
 
   static {
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index c2abbab..b865097 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -508,6 +508,25 @@
   }
 }
 
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
+    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
+  tflite::Interpreter* interpreter =
+      convertLongToInterpreter(env, interpreter_handle);
+  if (interpreter == nullptr) return;
+
+  BufferErrorReporter* error_reporter =
+      convertLongToErrorReporter(env, error_handle);
+  if (error_reporter == nullptr) return;
+
+  TfLiteStatus status = interpreter->ResetVariableTensors();
+  if (status != kTfLiteOk) {
+    ThrowException(env, kIllegalArgumentException,
+                   "Internal error: Failed to reset variable tensors: %s",
+                   error_reporter->CachedErrorMessage());
+  }
+}
+
 JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
     JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
     jlong interpreter_handle) {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index d62b1e1..6f22764 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -479,6 +479,23 @@
     }
   }
 
+  @Test
+  public void testResetVariableTensors() throws Exception {
+    float[][][][] inputs = new float[2][8][8][3];
+    float[][][][] parsedOutputs = new float[2][8][8][3];
+
+    // Smoke test to ensure resetting variables at various times in a simple graph doesn't fail.
+    // TODO(b/138197256): Test with model that has variables.
+    try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
+      interpreter.resetVariableTensors();
+      interpreter.run(inputs, parsedOutputs);
+
+      interpreter.resetVariableTensors();
+      interpreter.resetVariableTensors();
+      interpreter.run(inputs, parsedOutputs);
+    }
+  }
+
   private static native long getNativeHandleForDelegate();
 
   private static native long getNativeHandleForInvalidDelegate();
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 2b550c9..c5012b7 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -494,6 +494,36 @@
 )
 
 cc_library(
+    name = "variable_op_kernels",
+    srcs = [
+        "assign_variable.cc",
+        "read_variable.cc",
+    ],
+    deps = [
+        ":kernel_util",
+        ":op_macros",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/kernels/internal:tensor",
+    ],
+)
+
+cc_test(
+    name = "variable_ops_test",
+    size = "small",
+    srcs = [
+        "variable_ops_test.cc",
+    ],
+    deps = [
+        ":test_main",
+        ":test_util",
+        ":variable_op_kernels",
+        "//tensorflow/lite:framework",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
     name = "custom_ops",
     srcs = ["rfft2d.cc"],
     hdrs = ["custom_ops_register.h"],
@@ -519,6 +549,7 @@
         ":op_macros",
         "//tensorflow/lite/c:c_api_internal",
         "//tensorflow/lite/kernels/internal:kernel_utils",
+        "//tensorflow/lite/kernels/internal:tensor",
         "//tensorflow/lite/kernels/internal:tensor_utils",
         "//third_party/eigen3",
         "@gemmlowp",
@@ -708,6 +739,7 @@
     name = "batch_to_space_nd_test",
     size = "small",
     srcs = ["batch_to_space_nd_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -735,6 +767,7 @@
     name = "concatenation_test",
     size = "small",
     srcs = ["concatenation_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -871,6 +904,7 @@
     name = "elementwise_test",
     size = "small",
     srcs = ["elementwise_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -939,6 +973,7 @@
     name = "exp_test",
     size = "small",
     srcs = ["exp_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -979,6 +1014,7 @@
     name = "reduce_test",
     size = "small",
     srcs = ["reduce_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1006,6 +1042,7 @@
     name = "pad_test",
     size = "small",
     srcs = ["pad_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1091,6 +1128,7 @@
     name = "resize_nearest_neighbor_test",
     size = "small",
     srcs = ["resize_nearest_neighbor_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1104,6 +1142,7 @@
     name = "svdf_test",
     size = "small",
     srcs = ["svdf_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1117,6 +1156,7 @@
     name = "embedding_lookup_test",
     size = "small",
     srcs = ["embedding_lookup_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1131,6 +1171,7 @@
     name = "embedding_lookup_sparse_test",
     size = "small",
     srcs = ["embedding_lookup_sparse_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1174,6 +1215,7 @@
     name = "pooling_test",
     size = "small",
     srcs = ["pooling_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1231,6 +1273,7 @@
     name = "hashtable_lookup_test",
     size = "small",
     srcs = ["hashtable_lookup_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1245,6 +1288,7 @@
     name = "lstm_test",
     size = "small",
     srcs = ["lstm_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1313,6 +1357,7 @@
     name = "squeeze_test",
     size = "small",
     srcs = ["squeeze_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1326,6 +1371,7 @@
     name = "strided_slice_test",
     size = "small",
     srcs = ["strided_slice_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1416,6 +1462,7 @@
     name = "transpose_conv_test",
     size = "small",
     srcs = ["transpose_conv_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
@@ -1834,3 +1881,18 @@
         "@com_google_googletest//:gtest",
     ],
 )
+
+cc_test(
+    name = "quant_basic_lstm_test",
+    size = "small",
+    srcs = ["quant_basic_lstm_test.cc"],
+    tags = ["tflite_nnapi"],
+    deps = [
+        ":builtin_ops",
+        ":kernel_util",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "@com_google_googletest//:gtest",
+    ],
+)
diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index 7efd5df..793f90b 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -55,6 +55,11 @@
   uint8_t* table_zero = nullptr;
 };
 
+struct SoftmaxOpData {
+  struct SoftmaxParams params = {};
+  float table[256];
+};
+
 struct LogSoftmaxOpData : public OpData {
   int32_t reverse_scaling_divisor = 0;
   int32_t reverse_scaling_right_shift = 0;
@@ -131,6 +136,14 @@
   return new OpData;
 }
 
+void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
+  return new SoftmaxOpData;
+}
+
+void SoftmaxFree(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<SoftmaxOpData*>(buffer);
+}
+
 void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
                      size_t length) {
   return new LogSoftmaxOpData;
@@ -363,7 +376,7 @@
 
 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -375,16 +388,11 @@
   TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
 
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    if (CheckOutputQuantParams(context, input, output) == kTfLiteError) {
-      return kTfLiteError;
-    }
-
-    static const int kScaledDiffIntegerBits = 5;
-    tflite::PreprocessSoftmaxScaling(
-        params->beta, input->params.scale, kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+    data->params.table = data->table;
+    optimized_ops::PopulateSoftmaxLookupTable(
+        &data->params, input->params.scale, params->beta);
+    data->params.zero_point = output->params.zero_point;
+    data->params.scale = output->params.scale;
   }
 
   return context->ResizeTensor(context, output,
@@ -749,61 +757,25 @@
   }
 }
 
-TfLiteStatus SoftmaxQuantizedUint8(TfLiteContext* context,
-                                   const TfLiteTensor* input,
-                                   TfLiteTensor* output,
-                                   TfLiteSoftmaxParams* params, OpData* data) {
-  switch (NumDimensions(input)) {
-    case 1:
-    case 2:
-    case 3:
-    case 4:
-      SoftmaxParams op_params;
-      op_params.input_multiplier = data->input_multiplier;
-      op_params.input_left_shift = data->input_left_shift;
-      op_params.diff_min = data->diff_min;
-      optimized_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(output), GetTensorData<uint8_t>(output));
-      return kTfLiteOk;
-    default:
-      context->ReportError(
-          context,
-          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
-          NumDimensions(input));
-      return kTfLiteError;
-  }
-}
-
-TfLiteStatus SoftmaxQuantizedInt8(TfLiteContext* context,
-                                  const TfLiteTensor* input,
-                                  TfLiteTensor* output,
-                                  TfLiteSoftmaxParams* params, OpData* data) {
-  switch (NumDimensions(input)) {
-    case 1:
-    case 2:
-    case 3:
-    case 4:
-      SoftmaxParams op_params;
-      op_params.input_multiplier = data->input_multiplier;
-      op_params.input_left_shift = data->input_left_shift;
-      op_params.diff_min = data->diff_min;
-      optimized_integer_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-          GetTensorShape(output), GetTensorData<int8_t>(output));
-      return kTfLiteOk;
-    default:
-      context->ReportError(
-          context,
-          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
-          NumDimensions(input));
-      return kTfLiteError;
+template <typename T>
+TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
+                              TfLiteTensor* output, SoftmaxOpData* data) {
+  if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
+    optimized_ops::Softmax(data->params, GetTensorShape(input),
+                           GetTensorData<T>(input), GetTensorShape(output),
+                           GetTensorData<T>(output));
+    return kTfLiteOk;
+  } else {
+    context->ReportError(
+        context, "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
+        NumDimensions(input));
+    return kTfLiteError;
   }
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -815,10 +787,10 @@
       return SoftmaxFloat(context, input, output, params);
     }
     case kTfLiteUInt8: {
-      return SoftmaxQuantizedUint8(context, input, output, params, data);
+      return SoftmaxQuantized<uint8_t>(context, input, output, data);
     }
     case kTfLiteInt8: {
-      return SoftmaxQuantizedInt8(context, input, output, params, data);
+      return SoftmaxQuantized<int8_t>(context, input, output, data);
     }
 
     default:
@@ -1055,9 +1027,9 @@
 }
 
 TfLiteRegistration* Register_SOFTMAX() {
-  static TfLiteRegistration r = {activations::Init, activations::Free,
-                                 activations::SoftmaxPrepare,
-                                 activations::SoftmaxEval};
+  static TfLiteRegistration r = {
+      activations::SoftmaxInit, activations::SoftmaxFree,
+      activations::SoftmaxPrepare, activations::SoftmaxEval};
   return &r;
 }
 
diff --git a/tensorflow/lite/kernels/add_test.cc b/tensorflow/lite/kernels/add_test.cc
index 9449981..9dd7df1 100644
--- a/tensorflow/lite/kernels/add_test.cc
+++ b/tensorflow/lite/kernels/add_test.cc
@@ -109,7 +109,7 @@
 TEST(FloatAddOpModel, VariousInputShapes) {
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
                       {TensorType_FLOAT32, test_shapes[i]},
                       {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
@@ -125,7 +125,7 @@
 TEST(FloatAddOpModel, WithBroadcast) {
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
                       {TensorType_FLOAT32, {}},  // always a scalar
                       {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
@@ -162,7 +162,7 @@
 TEST(IntegerAddOpModel, VariousInputShapes) {
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
                         {TensorType_INT32, test_shapes[i]},
                         {TensorType_INT32, {}}, ActivationFunctionType_NONE);
@@ -177,7 +177,7 @@
 TEST(IntegerAddOpModel, WithBroadcast) {
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
                         {TensorType_INT32, {}},  // always a scalar
                         {TensorType_INT32, {}}, ActivationFunctionType_NONE);
@@ -199,7 +199,7 @@
       {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
   std::vector<std::vector<float>> results = {
       {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
-  for (int i = 0; i < inputs1.size(); ++i) {
+  for (size_t i = 0; i < inputs1.size(); ++i) {
     QuantizedAddOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
                           {tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
                           {tensor_type, {}, -1.0, 1.0},
@@ -232,7 +232,7 @@
       {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
   std::vector<std::vector<float>> results = {
       {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
-  for (int i = 0; i < inputs1.size(); ++i) {
+  for (size_t i = 0; i < inputs1.size(); ++i) {
     QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
                           {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
                           {TensorType_INT16, {}, kMin, kMax},
@@ -256,7 +256,7 @@
                                              {0.6, 0.4, -0.8, 0.5}};
   std::vector<std::vector<float>> results = {{-0.2, 0.6, 1.0, -0.1},
                                              {-0.2, 0.6, -0.1, 0.8}};
-  for (int i = 0; i < inputs1.size(); ++i) {
+  for (size_t i = 0; i < inputs1.size(); ++i) {
     QuantizedAddOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
                           {tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
                           {tensor_type, {}, -1.0, 1.0},
@@ -284,7 +284,7 @@
   float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     QuantizedAddOpModel m({tensor_type, test_shapes[i], -3.0, 3.0},
                           {tensor_type, test_shapes[i], -3.0, 3.0},
                           {tensor_type, {}, -3.0, 3.0},
@@ -314,7 +314,7 @@
   float kQuantizedTolerance = GetTolerance(-3.f, 3.f);
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     QuantizedAddOpModel model_fixture(
         {tensor_type, test_shapes[i], -3.f, 3.f}, {tensor_type, {}, -3.f, 3.f},
         {tensor_type, {}, -3.f, 3.f}, ActivationFunctionType_NONE);
@@ -330,7 +330,7 @@
         << "With shape number " << i;
   }
   // Re-run with exchanged inputs.
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     QuantizedAddOpModel model_fixture(
         {tensor_type, {}, -3.f, 3.f}, {tensor_type, test_shapes[i], -3.f, 3.f},
         {tensor_type, {}, -3.f, 3.f}, ActivationFunctionType_NONE);
@@ -374,7 +374,7 @@
        1.0f,  -0.7f, 0.9f, 1.2f, -1.7f, 1.7f, -1.2f, 1.6f, -1.3f},
       {-0.1f, 2.5f, 1.2f, 0.8f, 0.4f, -1.5f, 1.7f, 3.0f, -0.6f, 1.0f, 1.6f,
        -1.3f}};
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     QuantizedAddOpModel model_fixture({tensor_type, base_shape, -3.f, 3.f},
                                       {tensor_type, test_shapes[i], -3.f, 3.f},
                                       {tensor_type, {}, -3.f, 3.f},
@@ -391,7 +391,7 @@
         << "With shape number " << i;
   }
   // Re-run with exchanged inputs.
-  for (int i = 0; i < test_shapes.size(); ++i) {
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
     QuantizedAddOpModel model_fixture({tensor_type, test_shapes[i], -3.f, 3.f},
                                       {tensor_type, base_shape, -3.f, 3.f},
                                       {tensor_type, {}, -3.f, 3.f},
diff --git a/tensorflow/lite/kernels/assign_variable.cc b/tensorflow/lite/kernels/assign_variable.cc
new file mode 100644
index 0000000..099b8e1
--- /dev/null
+++ b/tensorflow/lite/kernels/assign_variable.cc
@@ -0,0 +1,86 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string.h>
+
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace assign_variable {
+
+constexpr int kInputVariableId = 0;
+constexpr int kInputValue = 1;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  // TODO(b/137042749): TFLite infrastructure (converter, delegate) doesn't
+  // fully support 0-output ops yet. Currently it works if we manually crfat
+  // a TFLite graph that contains variable ops. Note:
+  // * The TFLite Converter need to be changed to be able to produce an op
+  //   with 0 output.
+  // * The delegation code need to be changed to handle 0 output ops. However
+  //   everything still works fine when variable ops aren't used.
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
+
+  const TfLiteTensor* input_variable_id_tensor =
+      GetInput(context, node, kInputVariableId);
+  TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
+  TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
+
+  const TfLiteTensor* input_variable_id_tensor =
+      GetInput(context, node, kInputVariableId);
+  const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
+
+  int variable_id = input_variable_id_tensor->data.i32[0];
+  auto& resource_variables = subgraph->resource_variables();
+
+  auto variable_iterator = resource_variables.find(variable_id);
+  if (variable_iterator == resource_variables.end()) {
+    auto ret = resource_variables.emplace(variable_id, ResourceVariable());
+    variable_iterator = ret.first;
+  }
+
+  auto& variable = variable_iterator->second;
+  variable.AssignFrom(input_value_tensor);
+
+  return kTfLiteOk;
+}
+
+}  // namespace assign_variable
+
+TfLiteRegistration* Register_ASSIGN_VARIABLE() {
+  static TfLiteRegistration r = {nullptr, nullptr, assign_variable::Prepare,
+                                 assign_variable::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/basic_rnn.cc b/tensorflow/lite/kernels/basic_rnn.cc
index 630f1b3..8106b2e 100644
--- a/tensorflow/lite/kernels/basic_rnn.cc
+++ b/tensorflow/lite/kernels/basic_rnn.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/activation_functor.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 
@@ -167,7 +168,6 @@
                         TfLiteTensor* hidden_state_scratch,
                         TfLiteTensor* scaling_factors,
                         TfLiteTensor* hidden_state, TfLiteTensor* output) {
-  const bool is_uint8_hybrid = input_weights->type == kTfLiteUInt8;
   const int batch_size = input->dims->data[0];
   const int num_units = input_weights->dims->data[0];
   const int input_size = input->dims->data[1];
@@ -180,18 +180,17 @@
   const float* input_ptr_batch = input->data.f;
   float* output_ptr_batch = output->data.f;
   // Initialize input_weights, recurrent_weights and bias.
-  const int8_t* input_weights_ptr =
-      GetInt8DataPtr(input_weights, is_uint8_hybrid);
+  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
   const int8_t* recurrent_weights_ptr =
-      GetInt8DataPtr(recurrent_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(recurrent_weights);
   const float* bias_ptr = bias->data.f;
   // Get the scale of the quantized weights.
   float input_weights_scale = input_weights->params.scale;
   float recurrent_weights_scale = recurrent_weights->params.scale;
   // Initialize temporary storage for quantized values.
-  int8_t* quantized_input_ptr = GetInt8DataPtr(input_scratch, is_uint8_hybrid);
+  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
   int8_t* quantized_hidden_state_ptr =
-      GetInt8DataPtr(hidden_state_scratch, is_uint8_hybrid);
+      GetTensorData<int8_t>(hidden_state_scratch);
   float* scaling_factors_ptr = scaling_factors->data.f;
 
   kernel_utils::RnnBatchStep(
diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
index acf6663..d3946aa 100644
--- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/activation_functor.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 
@@ -360,7 +361,8 @@
             input->data.f + b * input_size * max_time + s * input_size;
         const float* aux_input_ptr_batch =
             (aux_input != nullptr)
-                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                ? aux_input->data.f + b * aux_input_size * max_time +
+                      s * aux_input_size
                 : nullptr;
         float* output_ptr_batch = fw_output_offset + s * fw_output_step;
 
@@ -383,7 +385,8 @@
             input->data.f + b * input_size * max_time + s * input_size;
         const float* aux_input_ptr_batch =
             (aux_input != nullptr)
-                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                ? aux_input->data.f + b * aux_input_size * max_time +
+                      s * aux_input_size
                 : nullptr;
         float* output_ptr_batch = bw_output_offset + s * bw_output_step;
 
@@ -413,7 +416,6 @@
     TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
     TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
     TfLiteTensor* bw_output) {
-  const bool is_uint8_hybrid = fw_input_weights->type == kTfLiteUInt8;
   const bool time_major = params->time_major;
   const int batch_size =
       (time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -424,46 +426,40 @@
 
   const int fw_num_units = fw_input_weights->dims->data[0];
   const float* fw_bias_ptr = fw_bias->data.f;
-  const int8_t* fw_input_weights_ptr =
-      GetInt8DataPtr(fw_input_weights, is_uint8_hybrid);
+  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
   float fw_input_weights_scale = fw_input_weights->params.scale;
   const int8_t* fw_recurrent_weights_ptr =
-      GetInt8DataPtr(fw_recurrent_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(fw_recurrent_weights);
   float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;
 
   const int bw_num_units = bw_input_weights->dims->data[0];
   const float* bw_bias_ptr = bw_bias->data.f;
-  const int8_t* bw_input_weights_ptr =
-      GetInt8DataPtr(bw_input_weights, is_uint8_hybrid);
+  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
   float bw_input_weights_scale = bw_input_weights->params.scale;
   const int8_t* bw_recurrent_weights_ptr =
-      GetInt8DataPtr(bw_recurrent_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(bw_recurrent_weights);
   float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
 
   // Set the auxiliary pointers and scales if needed.
-  int8_t* aux_fw_input_weights_ptr = nullptr;
+  const int8_t* aux_fw_input_weights_ptr = nullptr;
   float aux_fw_input_weights_scale = 0.0f;
-  int8_t* aux_bw_input_weights_ptr = nullptr;
+  const int8_t* aux_bw_input_weights_ptr = nullptr;
   float aux_bw_input_weights_scale = 0.0f;
   int8_t* aux_quantized_input_ptr = nullptr;
   if (aux_input_size > 0) {
-    aux_fw_input_weights_ptr =
-        GetInt8DataPtr(aux_fw_input_weights, is_uint8_hybrid);
+    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
     aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
-    aux_bw_input_weights_ptr =
-        GetInt8DataPtr(aux_bw_input_weights, is_uint8_hybrid);
+    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
     aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
-    aux_quantized_input_ptr =
-        GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
+    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
   }
 
   // Initialize temporary storage for quantized values.
-  int8_t* quantized_input_ptr =
-      GetInt8DataPtr(input_quantized, is_uint8_hybrid);
+  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
   int8_t* fw_quantized_hidden_state_ptr =
-      GetInt8DataPtr(fw_hidden_state_quantized, is_uint8_hybrid);
+      GetTensorData<int8_t>(fw_hidden_state_quantized);
   int8_t* bw_quantized_hidden_state_ptr =
-      GetInt8DataPtr(bw_hidden_state_quantized, is_uint8_hybrid);
+      GetTensorData<int8_t>(bw_hidden_state_quantized);
   float* scaling_factors_ptr = scaling_factors->data.f;
 
   const int fw_output_step =
diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
index eb7cb0b..a5210da 100644
--- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -27,6 +27,12 @@
 namespace tflite {
 namespace {
 
+enum class AuxInputMode {
+  kNoAuxInput,
+  kCrossLinking,
+  kNoCrossLinking,
+};
+
 using ::testing::ElementsAreArray;
 
 static float rnn_input[] = {
@@ -654,13 +660,15 @@
 class BidirectionalRNNOpModel : public SingleOpModel {
  public:
   BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
-                          int bw_units, int input_size, bool use_aux_input,
-                          bool time_major, bool merge_outputs)
+                          int bw_units, int input_size, int aux_input_size,
+                          AuxInputMode aux_input_mode, bool time_major,
+                          bool merge_outputs)
       : batches_(batches),
         sequence_len_(sequence_len),
         fw_units_(fw_units),
         bw_units_(bw_units),
-        input_size_(input_size) {
+        input_size_(input_size),
+        aux_input_size_(aux_input_size) {
     input_ = AddInput(TensorType_FLOAT32);
     fw_weights_ = AddInput(TensorType_FLOAT32);
     fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
@@ -671,15 +679,33 @@
     bw_bias_ = AddInput(TensorType_FLOAT32);
     bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
 
-    int aux_input_size = 0;
-    if (use_aux_input) {
+    const auto input_shape =
+        (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
+                     : std::vector<int>({batches_, sequence_len_, input_size_});
+
+    std::vector<int> aux_input_shape = {0};
+    std::vector<int> aux_fw_weights_shape = {0};
+    std::vector<int> aux_bw_weights_shape = {0};
+    if (aux_input_mode != AuxInputMode::kNoAuxInput) {
       aux_input_ = AddInput(TensorType_FLOAT32);
-      aux_input_size = input_size_;
+      aux_input_shape =
+          (time_major)
+              ? std::vector<int>({sequence_len_, batches_, aux_input_size_})
+              : std::vector<int>({batches_, sequence_len_, aux_input_size_});
     } else {
       aux_input_ = AddNullInput();
     }
-    aux_fw_weights_ = AddNullInput();
-    aux_bw_weights_ = AddNullInput();
+
+    if (aux_input_mode == AuxInputMode::kCrossLinking) {
+      aux_fw_weights_ = AddInput(TensorType_FLOAT32);
+      aux_bw_weights_ = AddInput(TensorType_FLOAT32);
+
+      aux_fw_weights_shape = {fw_units, aux_input_size_};
+      aux_bw_weights_shape = {bw_units, aux_input_size_};
+    } else {
+      aux_fw_weights_ = AddNullInput();
+      aux_bw_weights_ = AddNullInput();
+    }
 
     fw_output_ = AddOutput(TensorType_FLOAT32);
     if (!merge_outputs) {
@@ -692,23 +718,20 @@
         CreateBidirectionalSequenceRNNOptions(
             builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
             .Union());
-    const auto input_shape =
-        (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
-                     : std::vector<int>({batches_, sequence_len_, input_size_});
 
     BuildInterpreter({
-        input_shape,                                // input
-        {fw_units_, input_size_},                   // fw_weights
-        {fw_units_, fw_units_},                     // fw_recurrent_weights
-        {fw_units_},                                // fw_bias
-        {batches_, fw_units_},                      // fw_hidden_state
-        {bw_units_, input_size_},                   // bw_weights
-        {bw_units_, bw_units_},                     // bw_recurrent_weights
-        {bw_units_},                                // bw_bias
-        {batches_, bw_units_},                      // bw_hidden_state
-        {batches_, sequence_len_, aux_input_size},  // aux_input
-        {fw_units_, 0},                             // aux_fw_weights
-        {bw_units_, 0},                             // aux_bw_weights
+        input_shape,               // input
+        {fw_units_, input_size_},  // fw_weights
+        {fw_units_, fw_units_},    // fw_recurrent_weights
+        {fw_units_},               // fw_bias
+        {batches_, fw_units_},     // fw_hidden_state
+        {bw_units_, input_size_},  // bw_weights
+        {bw_units_, bw_units_},    // bw_recurrent_weights
+        {bw_units_},               // bw_bias
+        {batches_, bw_units_},     // bw_hidden_state
+        aux_input_shape,           // aux_input
+        aux_fw_weights_shape,      // aux_fw_weights
+        aux_bw_weights_shape,      // aux_bw_weights
     });
   }
 
@@ -720,19 +743,19 @@
     PopulateTensor(bw_bias_, f);
   }
 
-  void SetFwWeights(std::initializer_list<float> f) {
+  void SetFwWeights(const std::vector<float>& f) {
     PopulateTensor(fw_weights_, f);
   }
 
-  void SetBwWeights(std::initializer_list<float> f) {
+  void SetBwWeights(const std::vector<float>& f) {
     PopulateTensor(bw_weights_, f);
   }
 
-  void SetFwRecurrentWeights(std::initializer_list<float> f) {
+  void SetFwRecurrentWeights(const std::vector<float>& f) {
     PopulateTensor(fw_recurrent_weights_, f);
   }
 
-  void SetBwRecurrentWeights(std::initializer_list<float> f) {
+  void SetBwRecurrentWeights(const std::vector<float>& f) {
     PopulateTensor(bw_recurrent_weights_, f);
   }
 
@@ -748,10 +771,19 @@
     PopulateTensor(aux_input_, offset, begin, end);
   }
 
+  void SetAuxFwWeights(const std::vector<float>& f) {
+    PopulateTensor(aux_fw_weights_, f);
+  }
+
+  void SetAuxBwWeights(const std::vector<float>& f) {
+    PopulateTensor(aux_bw_weights_, f);
+  }
+
   std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
   std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
 
   int input_size() { return input_size_; }
+  int aux_input_size() { return aux_input_size_; }
   int num_fw_units() { return fw_units_; }
   int num_bw_units() { return bw_units_; }
   int num_batches() { return batches_; }
@@ -778,6 +810,7 @@
   int fw_units_;
   int bw_units_;
   int input_size_;
+  int aux_input_size_;
 };
 
 // TODO(mirkov): add another test which directly compares to TF once TOCO
@@ -785,7 +818,8 @@
 TEST(BidirectionalRNNOpTest, BlackBoxTest) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/false,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -824,7 +858,8 @@
 TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/true,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -861,7 +896,8 @@
 TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/false,
                               /*merge_outputs=*/true);
   rnn.SetFwWeights(weights);
@@ -900,7 +936,8 @@
 TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajorMergeOutputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/true,
                               /*merge_outputs=*/true);
   rnn.SetFwWeights(weights);
@@ -945,7 +982,8 @@
 TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/false,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -993,7 +1031,8 @@
 TEST(BidirectionalRNNOpTest, EndToEndTest) {
   BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/false,
+                              /*input_size=*/8, /*aux_input_size=*/0,
+                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                               /*time_major=*/false,
                               /*merge_outputs=*/false);
   const int output_size = 4;
@@ -1061,11 +1100,15 @@
   }
 }
 
-// Same as BlackBox test, but has aux input.
-TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInput) {
+// Same as BlackBox test, but has an auxiliary input. The layer has no
+// cross-linking, i.e. the regular input is passed as an input to the forward
+// network only and the auxiliary input is passed as an input to the backward
+// network only.
+TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularAndAuxInput) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/true,
+                              /*input_size=*/8, /*aux_input_size=*/8,
+                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                               /*time_major=*/true,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -1092,20 +1135,29 @@
   rnn.Invoke();
 
   std::vector<float> fw_expected;
+  std::vector<float> bw_expected;
   for (int i = 0; i < rnn.sequence_len(); i++) {
     float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
     float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
     fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
     fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+
+    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_fw_units();
+    float* golden_bw_end = golden_bw_start + rnn.num_fw_units();
+    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
+    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
   }
   EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }
 
-// Same as previous test, but has aux input is all zeros.
-TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInputZeros) {
+// Same as above but the auxiliary input is set to zeroes. This test makes sure
+// that the forward network works as expected in a no-cross-linking mode.
+TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularInputOnly) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/true,
+                              /*input_size=*/8, /*aux_input_size=*/8,
+                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                               /*time_major=*/true,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -1146,12 +1198,14 @@
   EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
 }
 
-// Same as previous test, but has input is all zeros, and aux input is the real
-// input. This is testing the bw path is functional.
-TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInputInputZeros) {
+// Same as above but the regular (i.e. not auxiliary) input is set to zeroes.
+// This test makes sure that the backward network works as expected in a
+// no-cross-linking mode.
+TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingAuxInputOnly) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*use_aux_input=*/true,
+                              /*input_size=*/8, /*aux_input_size=*/8,
+                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                               /*time_major=*/true,
                               /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
@@ -1192,5 +1246,204 @@
   EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }
 
+// Same as BlackBox test, but an input is passed to auxiliary input instead of
+// the regular one. Regular input and weights are set to zero.
+TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnly) {
+  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+                              /*fw_units=*/16, /*bw_units=*/16,
+                              /*input_size=*/8, /*aux_input_size=*/8,
+                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
+                              /*time_major=*/false,
+                              /*merge_outputs=*/false);
+  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
+  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
+  rnn.SetFwBias(biases);
+  rnn.SetBwBias(biases);
+  rnn.SetFwRecurrentWeights(recurrent_weights);
+  rnn.SetBwRecurrentWeights(recurrent_weights);
+  rnn.SetAuxFwWeights(weights);
+  rnn.SetAuxBwWeights(weights);
+
+  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+  std::vector<float> zero_input(input_sequence_size, 0.f);
+  float* batch_start = rnn_input;
+  float* batch_end = batch_start + input_sequence_size;
+  // Set batch 0 inputs
+  rnn.SetInput(0, zero_input.data(), zero_input.data() + zero_input.size());
+  rnn.SetAuxInput(0, batch_start, batch_end);
+  // Set batch 1 inputs
+  rnn.SetInput(input_sequence_size, zero_input.data(),
+               zero_input.data() + zero_input.size());
+  rnn.SetAuxInput(input_sequence_size, batch_start, batch_end);
+
+  rnn.Invoke();
+
+  float* golden_fw_start = rnn_golden_fw_output;
+  float* golden_fw_end =
+      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
+  std::vector<float> fw_expected;
+  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+
+  float* golden_bw_start = rnn_golden_bw_output;
+  float* golden_bw_end =
+      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
+  std::vector<float> bw_expected;
+  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
+  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
+  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
+// Same as BlackBox test, but an input is passed to auxiliary input instead of
+// the regular one. Regular input and weights are set to zero. Time major inputs
+// and outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
+  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+                              /*fw_units=*/16, /*bw_units=*/16,
+                              /*input_size=*/8, /*aux_input_size=*/8,
+                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
+                              /*time_major=*/true,
+                              /*merge_outputs=*/false);
+  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
+  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
+  rnn.SetFwBias(biases);
+  rnn.SetBwBias(biases);
+  rnn.SetFwRecurrentWeights(recurrent_weights);
+  rnn.SetBwRecurrentWeights(recurrent_weights);
+  rnn.SetAuxFwWeights(weights);
+  rnn.SetAuxBwWeights(weights);
+
+  std::vector<float> zero_input(rnn.sequence_len(), 0.f);
+
+  // Insert the inputs in time_major format. The batch_major format is:
+  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
+  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
+  for (int i = 0; i < rnn.sequence_len(); i++) {
+    float* batch_start = rnn_input + i * rnn.input_size();
+    float* batch_end = batch_start + rnn.input_size();
+    // The two batches are identical.
+    // Set batch 0 inputs
+    rnn.SetInput(2 * i * rnn.input_size(), &zero_input.front(),
+                 &zero_input.back() + 1);
+    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
+    // Set batch 1 inputs
+    rnn.SetInput((2 * i + 1) * rnn.input_size(), &zero_input.front(),
+                 &zero_input.back() + 1);
+    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+  }
+
+  rnn.Invoke();
+
+  std::vector<float> fw_expected;
+  for (int i = 0; i < rnn.sequence_len(); i++) {
+    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
+    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
+    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  }
+  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+}
+
+// Same as BlackBox test, but the input tensor and weights tensor are split
+// along the last dimension and passed to both regular and auxiliry inputs and
+// weights. The output in this case is the same. To understand this, let's
+// define W and V as regular input weights matrix and auxiliary input weights
+// matrix correspondingly. It's easy to see that this is equivalent to a regular
+// RNN with weights U = (W|V) and z^T = x^T | y^T, where .|. denotes
+// concatenation along horizontal axis:
+//   f(z) = Uz + b
+// is equivalent to:
+//   f((x^T|y^T)^T) = (Wx + Vy) + b.
+void run_blackbox_test_with_input_split(int input_size, int aux_input_size) {
+  const int num_units = 16;
+  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+                              /*fw_units=*/num_units, /*bw_units=*/num_units,
+                              input_size, aux_input_size,
+                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
+                              /*time_major=*/false,
+                              /*merge_outputs=*/false);
+  std::vector<float> reg_weights(num_units * rnn.input_size());
+  std::vector<float> aux_weights(num_units * rnn.aux_input_size());
+  int full_weights_size = weights.size();
+  int reg_weights_offset = 0;
+  int aux_weights_offset = 0;
+  int weights_offset = 0;
+  // Alternating copying to regular input weights and auxiliary input weights to
+  // split the original weight matrix in half along the last axis.
+  while (weights_offset < full_weights_size) {
+    std::copy(weights.begin() + weights_offset,
+              weights.begin() + weights_offset + rnn.input_size(),
+              reg_weights.begin() + reg_weights_offset);
+    weights_offset += rnn.input_size();
+    reg_weights_offset += rnn.input_size();
+
+    std::copy(weights.begin() + weights_offset,
+              weights.begin() + weights_offset + rnn.aux_input_size(),
+              aux_weights.begin() + aux_weights_offset);
+    weights_offset += rnn.aux_input_size();
+    aux_weights_offset += rnn.aux_input_size();
+  }
+
+  rnn.SetFwWeights(reg_weights);
+  rnn.SetBwWeights(reg_weights);
+  rnn.SetFwBias(biases);
+  rnn.SetBwBias(biases);
+  rnn.SetFwRecurrentWeights(recurrent_weights);
+  rnn.SetBwRecurrentWeights(recurrent_weights);
+  rnn.SetAuxFwWeights(aux_weights);
+  rnn.SetAuxBwWeights(aux_weights);
+
+  int full_input_size =
+      (rnn.input_size() + rnn.aux_input_size()) * rnn.sequence_len();
+  int reg_input_offset = 0;
+  int aux_input_offset = 0;
+  // Alternating copying to regular input tensor and auxiliary input tensor to
+  // split the original input matrix in half along the last axis.
+  for (int batch = 0; batch < 2; ++batch) {
+    int input_offset = 0;
+    while (input_offset < full_input_size) {
+      rnn.SetInput(reg_input_offset, rnn_input + input_offset,
+                   rnn_input + input_offset + rnn.input_size());
+      input_offset += rnn.input_size();
+      reg_input_offset += rnn.input_size();
+
+      rnn.SetAuxInput(aux_input_offset, rnn_input + input_offset,
+                      rnn_input + input_offset + rnn.aux_input_size());
+      input_offset += rnn.aux_input_size();
+      aux_input_offset += rnn.aux_input_size();
+    }
+  }
+
+  rnn.Invoke();
+
+  float* golden_fw_start = rnn_golden_fw_output;
+  float* golden_fw_end =
+      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
+  std::vector<float> fw_expected;
+  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+
+  float* golden_bw_start = rnn_golden_bw_output;
+  float* golden_bw_end =
+      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
+  std::vector<float> bw_expected;
+  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
+  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
+  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
+TEST(BidirectionalRNNOpTest,
+     BlackBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
+  run_blackbox_test_with_input_split(/*input_size=*/4, /*aux_input_size=*/4);
+}
+
+// Same as above but the input tensor and the weights tensor are split unevenly.
+TEST(BidirectionalRNNOpTest,
+     BlackBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
+  run_blackbox_test_with_input_split(/*input_size=*/2, /*aux_input_size=*/6);
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h
index 017f166..aa41f03 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h
@@ -541,8 +541,10 @@
       // being processed.
 
       // Add bias values.
-      int32x4_t bias_vec = vld1q_s32(params.bias + row);
-      reduced = vaddq_s32(reduced, bias_vec);
+      if (params.bias) {
+        int32x4_t bias_vec = vld1q_s32(params.bias + row);
+        reduced = vaddq_s32(reduced, bias_vec);
+      }
 
       // Get multiplier parameters.
       int32x4_t multiplier_fixedpoint;
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
index 3c63443..a73149c 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
@@ -92,9 +92,6 @@
 
     using ColVectorMap =
         gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
-    ColVectorMap bias_vector(params.bias, lhs_params.rows);
-    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-    bias_addition_stage.bias_vector = bias_vector;
     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
     scale_stage.result_offset_after_shift = dst_params.zero_point;
     scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
@@ -105,12 +102,25 @@
     clamp_stage.min = params.clamp_min;
     clamp_stage.max = params.clamp_max;
     SaturatingCastStageType saturating_cast_stage;
-    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
-                                           clamp_stage, saturating_cast_stage);
     using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
+    if (params.bias) {
+      ColVectorMap bias_vector(params.bias, lhs_params.rows);
+      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+      bias_addition_stage.bias_vector = bias_vector;
+      auto output_pipeline = std::make_tuple(
+          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
+      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
+          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
+          output_pipeline);
+    } else {
+      auto output_pipeline =
+          std::make_tuple(scale_stage, clamp_stage, saturating_cast_stage);
+      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
+          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
+          output_pipeline);
+    }
   }
 };
 
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 40e81dc..27c2738 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -158,20 +158,12 @@
     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
   } else if (quantization_flavor ==
              QuantizationFlavor::kIntegerWithUniformMultiplier) {
-    // For now require a bias vector. Ruy does not care, but for gemmlowp
-    // it's a separate instantiation of the whole GEMM, so we save a lot of
-    // binary size by requiring a bias vector, and that's what we've been
-    // doing all along in our usage of gemmlowp, so somehow that must
-    // be OK with all existing users.
-    TFLITE_DCHECK(params.bias);
     TFLITE_DCHECK(params.multiplier_fixedpoint);
     // Nothing to check about multiplier_exponent
     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
   } else if (quantization_flavor ==
              QuantizationFlavor::kIntegerWithPerRowMultiplier) {
-    // See above comment about requiring bias.
-    TFLITE_DCHECK(params.bias);
     TFLITE_DCHECK(!params.multiplier_fixedpoint);
     TFLITE_DCHECK(!params.multiplier_exponent);
     TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
index fe2792b..427c6ab 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
@@ -416,8 +416,7 @@
   }
 
   GemmParams<AccumScalar, DstScalar> params;
-  if (use_golden || !std::is_floating_point<AccumScalar>::value ||
-      (random_engine() % 2)) {
+  if (use_golden || (random_engine() % 2)) {
     // cpu_backend_gemm supports bias=null only in the float path. Test that
     // in 50% of float testcases.
     params.bias = bias_data.data();
diff --git a/tensorflow/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc
index 3a8a62c..75b4d5e 100644
--- a/tensorflow/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/kernels/depthwise_conv_test.cc
@@ -616,6 +616,8 @@
   }
 };
 
+// Only enable this test for neon.
+#ifdef USE_NEON
 TEST_F(QuantizedDepthwiseConvolutionOpTest, LargeOutputChannelTest) {
   const TensorData input({TensorType_UINT8, {1, 4, 4, 2400}, -63.5, 64});
   const TensorData filter({TensorType_UINT8, {1, 3, 3, 2400}, -63.5, 64});
@@ -646,6 +648,7 @@
   reference_impl.SetInput(input_data);
   reference_impl.SetFilter(filter_data);
   reference_impl.SetBias(bias_data);
+  reference_impl.Invoke();
 
   QuantizedDepthwiseConvolutionOpModel optimized_impl(
       ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT(), input, filter,
@@ -653,9 +656,11 @@
   optimized_impl.SetInput(input_data);
   optimized_impl.SetFilter(filter_data);
   optimized_impl.SetBias(bias_data);
+  optimized_impl.Invoke();
 
-  // EXPECT_THAT(reference_impl.GetOutput(), optimized_impl.GetOutput());
+  EXPECT_THAT(reference_impl.GetOutput(), optimized_impl.GetOutput());
 }
+#endif
 
 // In this test we set the input and output scales so that the results match
 // exactly the 'non-quantized' version.
diff --git a/tensorflow/lite/kernels/dequantize.cc b/tensorflow/lite/kernels/dequantize.cc
index db7e23e..5ba94a6 100644
--- a/tensorflow/lite/kernels/dequantize.cc
+++ b/tensorflow/lite/kernels/dequantize.cc
@@ -33,6 +33,12 @@
 namespace builtin {
 namespace dequantize {
 
+// This file has two implementation of Dequantize.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
     input = GetInput(context, node, 0);
@@ -78,6 +84,7 @@
                                TfLiteIntArrayCopy(op_context.input->dims));
 }
 
+template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
   OpContext op_context(context, node);
@@ -91,24 +98,45 @@
   op_params.scale = op_context.input->params.scale;
   switch (op_context.input->type) {
     case kTfLiteUInt8:
-      optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
-                                GetTensorData<uint8_t>(op_context.input),
-                                GetTensorShape(op_context.output),
-                                GetTensorData<float>(op_context.output));
+      if (kernel_type == kReference) {
+        reference_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+                                  GetTensorData<uint8_t>(op_context.input),
+                                  GetTensorShape(op_context.output),
+                                  GetTensorData<float>(op_context.output));
+      } else {
+        optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+                                  GetTensorData<uint8_t>(op_context.input),
+                                  GetTensorShape(op_context.output),
+                                  GetTensorData<float>(op_context.output));
+      }
       break;
     case kTfLiteInt8:
-      reference_integer_ops::Dequantize<int8_t>(
-          op_params, GetTensorShape(op_context.input),
-          GetTensorData<int8_t>(op_context.input),
-          GetTensorShape(op_context.output),
-          GetTensorData<float>(op_context.output));
+      if (kernel_type == kReference) {
+        reference_integer_ops::Dequantize<int8_t>(
+            op_params, GetTensorShape(op_context.input),
+            GetTensorData<int8_t>(op_context.input),
+            GetTensorShape(op_context.output),
+            GetTensorData<float>(op_context.output));
+      } else {
+        optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+                                  GetTensorData<int8_t>(op_context.input),
+                                  GetTensorShape(op_context.output),
+                                  GetTensorData<float>(op_context.output));
+      }
       break;
     case kTfLiteInt16:
-      reference_integer_ops::Dequantize<int16_t>(
-          op_params, GetTensorShape(op_context.input),
-          GetTensorData<int16_t>(op_context.input),
-          GetTensorShape(op_context.output),
-          GetTensorData<float>(op_context.output));
+      if (kernel_type == kReference) {
+        reference_integer_ops::Dequantize<int16_t>(
+            op_params, GetTensorShape(op_context.input),
+            GetTensorData<int16_t>(op_context.input),
+            GetTensorShape(op_context.output),
+            GetTensorData<float>(op_context.output));
+      } else {
+        optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+                                  GetTensorData<int16_t>(op_context.input),
+                                  GetTensorShape(op_context.output),
+                                  GetTensorData<float>(op_context.output));
+      }
       break;
     case kTfLiteFloat16: {
       const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
@@ -134,12 +162,26 @@
 }  // namespace dequantize
 
 TfLiteRegistration* Register_DEQUANTIZE_OPT() {
-  static TfLiteRegistration r = {dequantize::Init, dequantize::Free,
-                                 dequantize::Prepare, dequantize::Eval};
+  static TfLiteRegistration r = {
+      dequantize::Init, dequantize::Free, dequantize::Prepare,
+      dequantize::Eval<dequantize::kGenericOptimized>};
   return &r;
 }
 
-TfLiteRegistration* Register_DEQUANTIZE() { return Register_DEQUANTIZE_OPT(); }
+TfLiteRegistration* Register_DEQUANTIZE_REF() {
+  static TfLiteRegistration r = {dequantize::Init, dequantize::Free,
+                                 dequantize::Prepare,
+                                 dequantize::Eval<dequantize::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_DEQUANTIZE() {
+#ifdef USE_NEON
+  return Register_DEQUANTIZE_OPT();
+#else
+  return Register_DEQUANTIZE_REF();
+#endif
+}
 
 }  // namespace builtin
 }  // namespace ops
diff --git a/tensorflow/lite/kernels/floor.cc b/tensorflow/lite/kernels/floor.cc
index b6ccce3..7607419 100644
--- a/tensorflow/lite/kernels/floor.cc
+++ b/tensorflow/lite/kernels/floor.cc
@@ -13,9 +13,10 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/lite/kernels/internal/reference/floor.h"
+
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
-#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 64da153..f02c405 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -17,6 +17,7 @@
 
 #include <cassert>
 #include <cmath>
+#include <cstdint>
 #include <cstdio>
 #include <cstdlib>
 #include <iostream>
@@ -131,7 +132,7 @@
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   // Check we have all the inputs and outputs we need.
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
+  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
   // Shuffled formats need a workspace to store the shuffled input activations.
   const int expected_outputs_count =
       params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
@@ -140,7 +141,10 @@
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  const TfLiteTensor* bias =
+      (node->inputs->size == 3)
+          ? GetOptionalInputTensor(context, node, kBiasTensor)
+          : nullptr;
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   // Check proper datatype match among all Input Tensors
@@ -524,7 +528,10 @@
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  const TfLiteTensor* bias =
+      (node->inputs->size == 3)
+          ? GetOptionalInputTensor(context, node, kBiasTensor)
+          : nullptr;
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   switch (filter->type) {
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index 637ee6b..c564a52 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -139,7 +139,8 @@
       bool keep_num_dims = false, bool bias_tensor_optional = false,
       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
       FullyConnectedOptionsWeightsFormat weights_format =
-          FullyConnectedOptionsWeightsFormat_DEFAULT)
+          FullyConnectedOptionsWeightsFormat_DEFAULT,
+      bool add_bias_for_quantized = true)
       : batches_(batches), units_(units) {
     int total_input_size = 1;
     for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -155,7 +156,7 @@
       bias_ = AddNullInput();
     } else if (input.type == TensorType_FLOAT32) {
       bias_ = AddInput({TensorType_FLOAT32, {units_}});
-    } else {
+    } else if (add_bias_for_quantized) {
       // This is a quantized version. The scale of 'bias' depends on the scales
       // of input and filter. Supposedly this is correctly set during quantized
       // training.
@@ -176,9 +177,13 @@
                      .Union());
     resolver_ = absl::make_unique<SingleOpResolver>(
         BuiltinOperator_FULLY_CONNECTED, registration);
-    BuildInterpreter(
-        {GetShape(input_), GetShape(weights_),
-         (bias_ == kOptionalTensor) ? std::vector<int>() : GetShape(bias_)});
+    std::vector<std::vector<int>> inputs = {GetShape(input_),
+                                            GetShape(weights_)};
+    if (add_bias_for_quantized) {
+      inputs.push_back((bias_ == kOptionalTensor) ? std::vector<int>()
+                                                  : GetShape(bias_));
+    }
+    BuildInterpreter(inputs);
   }
 
   int input_size() { return input_size_; }
@@ -465,6 +470,40 @@
               ElementsAre(151, 152, 153, 185, 186, 187));
 }
 
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) {
+  QuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
+      /*output=*/{TensorType_UINT8, {}, -127, 128},
+      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false,
+      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
+      /*FullyConnectedOptionsWeightsFormat weights_format =*/
+      FullyConnectedOptionsWeightsFormat_DEFAULT,
+      /*add_bias_for_quantized =*/false);
+
+  // input_product_scale < output_scale was not true.
+  m.SetWeights<uint8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  });
+
+  m.SetInput<uint8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear({
+                  23, 23, 23,  //
+                  57, 57, 57,  //
+              })));
+  EXPECT_THAT(m.GetOutput<uint8_t>(),
+              ElementsAre(150, 150, 150, 184, 184, 184));
+}
+
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
   QuantizedFullyConnectedOpModel m(
       GetRegistration(), /*units=*/3, /*batches*/ 2,
@@ -491,6 +530,36 @@
   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
 }
 
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8NoBias) {
+  QuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
+      /*output=*/{TensorType_INT8, {}, -127, 128},
+      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false,
+      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
+      /*FullyConnectedOptionsWeightsFormat weights_format =*/
+      FullyConnectedOptionsWeightsFormat_DEFAULT,
+      /*add_bias_for_quantized =*/false);
+
+  // input_product_scale < output_scale was not true.
+  m.SetWeights<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  });
+
+  m.SetInput<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
+}
+
 // Test the GEMV path.
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestSingleBatchQuantizedInt8) {
   QuantizedFullyConnectedOpModel m(
diff --git a/tensorflow/lite/kernels/gather.cc b/tensorflow/lite/kernels/gather.cc
index 54d05ad..85eb423 100644
--- a/tensorflow/lite/kernels/gather.cc
+++ b/tensorflow/lite/kernels/gather.cc
@@ -60,6 +60,7 @@
     case kTfLiteInt8:
     case kTfLiteInt64:
     case kTfLiteInt32:
+    case kTfLiteBool:
       break;
     case kTfLiteString: {
       // Only 1D input is supported.
@@ -142,6 +143,8 @@
         return Gather<int32_t, int32_t>(*params, input, positions, output);
       case kTfLiteInt64:
         return Gather<int64_t, int32_t>(*params, input, positions, output);
+      case kTfLiteBool:
+        return Gather<bool, int32_t>(*params, input, positions, output);
       case kTfLiteString:
         return GatherStrings<int32_t>(context, input, positions, output);
       default:
@@ -162,6 +165,8 @@
         return Gather<int32_t, int64_t>(*params, input, positions, output);
       case kTfLiteInt64:
         return Gather<int64_t, int64_t>(*params, input, positions, output);
+      case kTfLiteBool:
+        return Gather<bool, int64_t>(*params, input, positions, output);
       case kTfLiteString:
         return GatherStrings<int64_t>(context, input, positions, output);
       default:
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 199909c..08262f3 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -167,6 +167,16 @@
     },
 )
 
+config_setting(
+    name = "raspberry_pi_with_neon",
+    define_values = {
+        "raspberry_pi_with_neon": "true",
+    },
+    values = {
+        "cpu": "armeabi",
+    },
+)
+
 cc_library(
     name = "common",
     srcs = [],
@@ -194,6 +204,7 @@
         "optimized/integer_ops/depthwise_conv.h",
         "optimized/integer_ops/depthwise_conv_3x3_filter.h",
         "optimized/integer_ops/fully_connected.h",
+        "optimized/integer_ops/mean.h",
         "optimized/integer_ops/mul.h",
         "optimized/integer_ops/pooling.h",
         "optimized/integer_ops/softmax.h",
@@ -347,9 +358,14 @@
     name = "reference_base",
     srcs = [],
     hdrs = [
+        "reference/add.h",
+        "reference/arg_min_max.h",
+        "reference/binary_function.h",
+        "reference/comparisons.h",
         "reference/conv.h",
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
+        "reference/floor.h",
         "reference/fully_connected.h",
         "reference/integer_ops/add.h",
         "reference/integer_ops/conv.h",
@@ -364,11 +380,14 @@
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
+        "reference/maximum_minimum.h",
         "reference/pooling.h",
         "reference/prelu.h",
+        "reference/process_broadcast_shapes.h",
         "reference/reference_ops.h",
         "reference/softmax.h",
         "reference/strided_slice.h",
+        "reference/svdf.h",
     ],
     deps = [
         ":common",
@@ -377,6 +396,7 @@
         ":round",
         ":strided_slice_logic",
         ":tensor",
+        ":tensor_utils",
         ":types",
         "@gemmlowp//:fixedpoint",
         "@gemmlowp//:profiler",
@@ -400,13 +420,20 @@
     name = "legacy_reference_base",
     srcs = [],
     hdrs = [
+        "reference/add.h",
+        "reference/arg_min_max.h",
+        "reference/binary_function.h",
+        "reference/comparisons.h",
         "reference/conv.h",
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
+        "reference/floor.h",
         "reference/fully_connected.h",
         "reference/legacy_reference_ops.h",
+        "reference/maximum_minimum.h",
         "reference/pooling.h",
         "reference/prelu.h",
+        "reference/process_broadcast_shapes.h",
         "reference/reference_ops.h",
         "reference/softmax.h",
         "reference/strided_slice.h",
@@ -575,7 +602,6 @@
         ":cpu_check",
         ":types",
         "//tensorflow/lite/c:c_api_internal",
-        "@arm_neon_2_x86_sse",
         "//tensorflow/lite/kernels:cpu_backend_context",
         "//tensorflow/lite/kernels:op_macros",
         "@gemmlowp//:fixedpoint",
@@ -604,6 +630,9 @@
         ":ios_arm64": [
             ":neon_tensor_utils",
         ],
+        ":raspberry_pi_with_neon": [
+            ":neon_tensor_utils",
+        ],
         ":ios_x86_64": [
             ":sse_tensor_utils",
         ],
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h
index 7fb2d88..f6127c5 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h
@@ -55,7 +55,9 @@
   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
   const int output_rows = output_shape.Dims(output_dim_count - 1);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
-  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
+  }
 
   cpu_backend_gemm::MatrixParams<int8> lhs_params;
   lhs_params.rows = filter_rows;
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
new file mode 100644
index 0000000..4afec53
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h
@@ -0,0 +1,236 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MEAN_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MEAN_H_
+
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+
+namespace tflite {
+namespace optimized_integer_ops {
+
+#ifdef USE_NEON
+
+using optimized_ops::DivideSumForMeanImpl;
+using optimized_ops::RoundToNearest;
+
+#endif  // USE_NEON
+
+inline void MeanImpl(const tflite::MeanParams& op_params,
+                     const RuntimeShape& input_shape, const int8_t* input_data,
+                     int32 input_zero_point, float input_scale,
+                     const RuntimeShape& output_shape, int8_t* output_data,
+                     int32 output_zero_point, float output_scale,
+                     int start_depth, int end_depth) {
+  gemmlowp::ScopedProfilingLabel label("Mean4D/Int8/MeanImpl");
+
+  // Current implementation only supports dimension equals 4 and simultaneous
+  // reduction over width and height.
+  const int output_batch = output_shape.Dims(0);
+  const int output_height = output_shape.Dims(2);
+  const int output_width = output_shape.Dims(2);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const float num_elements_in_axis = input_width * input_height;
+
+  TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+  TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
+  TFLITE_DCHECK_EQ(output_height, 1);
+  TFLITE_DCHECK_EQ(output_width, 1);
+
+  const bool ordinary_mean =
+      (input_zero_point == output_zero_point && input_scale == output_scale);
+  float scale = 0.0f, bias = 0.0f;
+  if (!ordinary_mean) {
+    scale = input_scale / output_scale;
+    bias = -input_zero_point * scale + 0.5;
+  }
+
+#ifdef USE_NEON
+  const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
+  // This is only an approximation as NEON does not offer division instruction.
+  const float32x4_t scale_dup = vdupq_n_f32(scale);
+  const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
+  float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias);
+#endif  // USE_NEON
+
+  for (int out_b = 0; out_b < output_batch; ++out_b) {
+    int out_d = start_depth;
+#ifdef USE_NEON
+
+    for (; out_d < end_depth - 8; out_d += 8) {
+      float32x4_t temp_sum_1 = vdupq_n_f32(0);
+      float32x4_t temp_sum_2 = vdupq_n_f32(0);
+      for (int in_h = 0; in_h < input_height; ++in_h) {
+        for (int in_w = 0; in_w < input_width; ++in_w) {
+          const int8_t* input_data_ptr =
+              input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
+          int8x8_t input_data_val = vld1_s8(input_data_ptr);
+          int16x8_t input_data_val_shift = vmovl_s8(input_data_val);
+          float32x4_t input_float_1 =
+              vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
+          float32x4_t input_float_2 =
+              vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
+          temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
+          temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
+        }
+      }
+
+      const float32x4_t mean_1 =
+          DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean,
+                               scale_dup, zero_point_with_bias_dup);
+      const float32x4_t mean_2 =
+          DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean,
+                               scale_dup, zero_point_with_bias_dup);
+
+      int32x4_t casted_mean_1 = RoundToNearest(mean_1);
+      int16x4_t narrow_range_mean_1 = vmovn_s32(casted_mean_1);
+      int32x4_t casted_mean_2 = RoundToNearest(mean_2);
+      int16x4_t narrow_range_mean_2 = vmovn_s32(casted_mean_2);
+      int16x8_t combined_mean =
+          vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
+      int8x8_t narrowed_combined_mean = vmovn_s16(combined_mean);
+      int8_t* output_data_ptr =
+          output_data + Offset(output_shape, out_b, 0, 0, out_d);
+      vst1_s8(output_data_ptr, narrowed_combined_mean);
+    }
+#endif  // USE_NEON
+
+    for (; out_d < end_depth; ++out_d) {
+      float temp_value = 0;
+      for (int in_h = 0; in_h < input_height; ++in_h) {
+        for (int in_w = 0; in_w < input_width; ++in_w) {
+          temp_value +=
+              input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
+        }
+      }
+
+      temp_value = temp_value / num_elements_in_axis;
+      if (ordinary_mean) {
+        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+            static_cast<int8_t>(round(temp_value));
+      } else {
+        output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+            static_cast<int8_t>(round(temp_value * scale + bias)) +
+            output_zero_point;
+      }
+    }
+  }
+}
+
+struct MeanWorkerTask : cpu_backend_threadpool::Task {
+  MeanWorkerTask(const tflite::MeanParams& op_params,
+                 const RuntimeShape& input_shape, const int8_t* input_data,
+                 int32 input_zero_point, float input_scale,
+                 const RuntimeShape& output_shape, int8_t* output_data,
+                 int32 output_zero_point, float output_scale, int start_height,
+                 int end_height)
+      : op_params(op_params),
+        input_shape(input_shape),
+        input_data(input_data),
+        input_zero_point(input_zero_point),
+        input_scale(input_scale),
+        output_shape(output_shape),
+        output_data(output_data),
+        output_zero_point(output_zero_point),
+        output_scale(output_scale),
+        start_height(start_height),
+        end_height(end_height) {}
+
+  void Run() override {
+    MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
+             output_shape, output_data, output_zero_point, output_scale,
+             start_height, end_height);
+  }
+
+ private:
+  const tflite::MeanParams& op_params;
+  const RuntimeShape& input_shape;
+  const int8_t* input_data;
+  int32 input_zero_point;
+  float input_scale;
+  const RuntimeShape& output_shape;
+  int8_t* output_data;
+  int32 output_zero_point;
+  float output_scale;
+  int start_height;
+  int end_height;
+};
+
+inline void Mean(const tflite::MeanParams& op_params,
+                 const RuntimeShape& unextended_input_shape,
+                 const int8_t* input_data, int32 input_zero_point,
+                 float input_scale, const RuntimeShape& unextended_output_shape,
+                 int8_t* output_data, int32 output_zero_point,
+                 float output_scale, CpuBackendContext* cpu_backend_context) {
+  gemmlowp::ScopedProfilingLabel label("Mean4D/Int8");
+  // Current implementation only supports dimension equals 4 and simultaneous
+  // reduction over width and height.
+  TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+
+  TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+  TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
+  TFLITE_DCHECK_EQ(output_height, 1);
+  TFLITE_DCHECK_EQ(output_width, 1);
+
+  constexpr int kMinDepthPerThread = 8;
+  int thread_count = output_depth / kMinDepthPerThread;
+  thread_count = thread_count > 0 ? thread_count : 1;
+  const int capped_thread_count =
+      std::min(thread_count, cpu_backend_context->max_num_threads());
+
+  if (capped_thread_count == 1) {
+    MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
+             output_shape, output_data, output_zero_point, output_scale, 0,
+             output_depth);
+  } else {
+    // Instead parrallel for batch, we loop for the output_depth since batch
+    // is typical 1.
+    std::vector<MeanWorkerTask> tasks;
+    // TODO(b/131746020) don't create new heap allocations every time.
+    // At least we make it a single heap allocation by using reserve().
+    tasks.reserve(capped_thread_count);
+    int depth_start = 0;
+    for (int i = 0; i < capped_thread_count; ++i) {
+      // Try to distribute the tasks as even as possible.
+      int depth_end = depth_start +
+                      (output_depth - depth_start) / (capped_thread_count - i);
+      tasks.emplace_back(op_params, input_shape, input_data, input_zero_point,
+                         input_scale, output_shape, output_data,
+                         output_zero_point, output_scale, depth_start,
+                         depth_end);
+      depth_start = depth_end;
+    }
+    cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
+                                    cpu_backend_context);
+  }
+}
+
+}  // namespace optimized_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MEAN_H_
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
index fa95f09..08b8da0 100644
--- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h
@@ -44,6 +44,9 @@
       vdup_n_s8(params.quantized_activation_min);
   const auto output_activation_max_vector =
       vdup_n_s8(params.quantized_activation_max);
+  const int left_shift = std::max(0, params.output_shift);
+  const int right_shift = std::max(0, -params.output_shift);
+  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
   for (; i <= size - 8; i += 8) {
     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
     const auto input1_val_original = vld1_s8(input1_data + i);
@@ -61,14 +64,16 @@
     auto p1 = vmull_s16(input2_val_low, input1_val_low);
     auto p2 = vmull_s16(input2_val_high, input1_val_high);
 
+    p1 = vshlq_s32(p1, left_shift_vec);
+    p2 = vshlq_s32(p2, left_shift_vec);
     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
     using gemmlowp::RoundingDivideByPOT;
-    p1 = RoundingDivideByPOT(p1, -params.output_shift);
-    p2 = RoundingDivideByPOT(p2, -params.output_shift);
+    p1 = RoundingDivideByPOT(p1, right_shift);
+    p2 = RoundingDivideByPOT(p2, right_shift);
 
-    const auto p1_narrowed = vmovn_s32(p1);
-    const auto p2_narrowed = vmovn_s32(p2);
+    const auto p1_narrowed = vqmovn_s32(p1);
+    const auto p2_narrowed = vqmovn_s32(p2);
     const auto p =
         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
     const auto clamped =
@@ -83,9 +88,9 @@
     const int32 input2_val = params.input2_offset + input2_data[i];
     const int32 unclamped_result =
         params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
-                                                       params.output_multiplier,
-                                                       params.output_shift);
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
     const int32 clamped_output =
         std::min(params.quantized_activation_max,
                  std::max(params.quantized_activation_min, unclamped_result));
@@ -114,6 +119,9 @@
       vdup_n_s8(params.quantized_activation_min);
   const auto output_activation_max_vector =
       vdup_n_s8(params.quantized_activation_max);
+  const int left_shift = std::max(0, params.output_shift);
+  const int right_shift = std::max(0, -params.output_shift);
+  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
   for (; i <= size - 8; i += 8) {
     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
     const auto input2_val_original = vld1_s8(input2_data + i);
@@ -126,14 +134,16 @@
     auto p1 = vmull_n_s16(input2_val_low, input1_val);
     auto p2 = vmull_n_s16(input2_val_high, input1_val);
 
+    p1 = vshlq_s32(p1, left_shift_vec);
+    p2 = vshlq_s32(p2, left_shift_vec);
     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
     using gemmlowp::RoundingDivideByPOT;
-    p1 = RoundingDivideByPOT(p1, -params.output_shift);
-    p2 = RoundingDivideByPOT(p2, -params.output_shift);
+    p1 = RoundingDivideByPOT(p1, right_shift);
+    p2 = RoundingDivideByPOT(p2, right_shift);
 
-    const auto p1_narrowed = vmovn_s32(p1);
-    const auto p2_narrowed = vmovn_s32(p2);
+    const auto p1_narrowed = vqmovn_s32(p1);
+    const auto p2_narrowed = vqmovn_s32(p2);
     const auto p =
         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
     const auto clamped =
@@ -147,9 +157,9 @@
     const int32 input2_val = params.input2_offset + input2_data[i];
     const int32 unclamped_result =
         params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
-                                                       params.output_multiplier,
-                                                       params.output_shift);
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
     const int32 clamped_output =
         std::min(params.quantized_activation_max,
                  std::max(params.quantized_activation_min, unclamped_result));
diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 9154645..b930516 100644
--- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -51,7 +51,6 @@
 using reference_ops::Concatenation;
 using reference_ops::ConcatenationWithScaling;
 using reference_ops::DepthConcatenation;
-using reference_ops::Dequantize;
 using reference_ops::Div;
 using reference_ops::FakeQuant;
 using reference_ops::Gather;
@@ -3973,6 +3972,208 @@
              filter_width, filter_height, output_data, output_dims);
 }
 
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
+  // The representation chosen for the input to the exp() function is Q5.26.
+  // We need to leave extra space since values that we skip might be as large as
+  // -32 before multiplying by input_beta_multiplier, and therefore as large as
+  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
+  // accumulation, but exp(-16) definitely is.
+  static const int kScaledDiffIntegerBits = 5;
+  static const int kAccumulationIntegerBits = 12;
+  using FixedPointScaledDiff =
+      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int depth =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  for (int b = 0; b < outer_size; ++b) {
+    const uint8* input_data_ptr = input_data + b * depth;
+    uint8* output_data_ptr = output_data + b * depth;
+
+    // Determine the largest entry in the current row
+    uint8 max_in_row = 0;
+    {
+      int c = 0;
+#ifdef USE_NEON
+      uint8x16_t max16_0 = vdupq_n_u8(0);
+      uint8x16_t max16_1 = vdupq_n_u8(0);
+      for (; c <= depth - 32; c += 32) {
+        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
+        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
+      }
+      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
+      if (c <= depth - 16) {
+        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
+        c += 16;
+      }
+      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
+      if (c <= depth - 8) {
+        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
+        c += 8;
+      }
+      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
+      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
+      uint8x8_t max1 = vpmax_u8(max2, max2);
+      max_in_row = vget_lane_u8(max1, 0);
+#endif
+      for (; c < depth; ++c) {
+        max_in_row = std::max(max_in_row, input_data_ptr[c]);
+      }
+    }
+
+#ifdef USE_NEON
+    using FixedPointAccumInt32x4 =
+        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
+    using FixedPointScaledDiffInt32x4 =
+        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
+    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
+    FixedPoint0Int32x4 input_beta_multiplier_f0 =
+        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
+    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
+#endif
+
+    // Compute the sum of exponentials of the differences of entries in the
+    // current row from the largest entry in the current row.
+    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+    {
+      int c = 0;
+#ifdef USE_NEON
+      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
+      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
+      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
+      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
+      for (; c <= depth - 8; c += 8) {
+        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
+        int16x8_t input_diff_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
+        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
+        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
+        int32x4_t mask_0 =
+            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
+        int32x4_t mask_1 =
+            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
+        FixedPointScaledDiffInt32x4 scaled_diff_0 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
+        FixedPointScaledDiffInt32x4 scaled_diff_1 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
+        FixedPointAccumInt32x4 exps_0 =
+            gemmlowp::Rescale<kAccumulationIntegerBits>(
+                exp_on_negative_values(scaled_diff_0));
+        FixedPointAccumInt32x4 exps_1 =
+            gemmlowp::Rescale<kAccumulationIntegerBits>(
+                exp_on_negative_values(scaled_diff_1));
+        FixedPointAccumInt32x4 masked_exps_0 =
+            SelectUsingMask(mask_0, exps_0, zeros);
+        FixedPointAccumInt32x4 masked_exps_1 =
+            SelectUsingMask(mask_1, exps_1, zeros);
+        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
+        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
+      }
+      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
+      int32x2_t sum_of_exps_reduced_2 =
+          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
+                   vget_high_s32(sum_of_exps_reduced_4));
+      int32x2_t sum_of_exps_reduced_1 =
+          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
+      sum_of_exps =
+          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
+#endif
+      for (; c < depth; ++c) {
+        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
+        if (input_diff >= diff_min) {
+          const int32 input_diff_rescaled =
+              MultiplyByQuantizedMultiplierGreaterThanOne(
+                  input_diff, input_beta_multiplier, input_beta_left_shift);
+          const FixedPointScaledDiff scaled_diff_f8 =
+              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+          sum_of_exps =
+              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+                                exp_on_negative_values(scaled_diff_f8));
+        }
+      }
+    }
+
+    // Compute the fixed-point multiplier and shift that we need to apply to
+    // perform a division by the above-computed sum-of-exponentials.
+    int num_bits_over_unit = 0;
+    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
+        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
+
+    // Compute the quotients of exponentials of differences of entries in the
+    // current row from the largest entry, over the previously-computed sum of
+    // exponentials.
+    {
+      int c = 0;
+#ifdef USE_NEON
+      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
+      for (; c <= depth - 8; c += 8) {
+        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
+        int16x8_t input_diff_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
+        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
+        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
+        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
+        FixedPointScaledDiffInt32x4 scaled_diff_0 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
+        FixedPointScaledDiffInt32x4 scaled_diff_1 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
+        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
+        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
+        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
+            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
+            num_bits_over_unit + 31 - 8);
+        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
+            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
+            num_bits_over_unit + 31 - 8);
+        int16x8_t output_s16 =
+            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
+        uint8x8_t output_u8 = vqmovun_s16(output_s16);
+        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
+        vst1_u8(output_data_ptr + c, masked_output);
+      }
+#endif
+      for (; c < depth; ++c) {
+        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
+        if (input_diff >= diff_min) {
+          const int32 input_diff_rescaled =
+              MultiplyByQuantizedMultiplierGreaterThanOne(
+                  input_diff, input_beta_multiplier, input_beta_left_shift);
+          const FixedPointScaledDiff scaled_diff_f8 =
+              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
+
+        } else {
+          output_data_ptr[c] = 0;
+        }
+      }
+    }
+  }
+}
+
 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
                     float beta, float* output_data,
                     const RuntimeShape& output_shape) {
@@ -4706,6 +4907,17 @@
           DimsToShape(output_dims), output_data);
 }
 
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+                       int32 zero_point, double scale, float* output_data,
+                       const Dims<4>& output_dims) {
+  tflite::DequantizationParams op_params;
+  op_params.zero_point = zero_point;
+  op_params.scale = scale;
+
+  Dequantize(op_params, DimsToShape(input_dims), input_data,
+             DimsToShape(output_dims), output_data);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 11caae6..4eaa9a9 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -33,14 +33,13 @@
 
 #define kFloatWeightsPerNeonLane 4
 
+// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
 #if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
 #if !defined(__ANDROID__) || __ANDROID_API__ >= 28
-#define TFLITE_USE_STD_ALIGN
+#if !defined(__APPLE__)  // Apple does not provide aligned_alloc.
+#define TFLITE_USE_STD_ALIGNED_ALLOC
 #endif
 #endif
-
-#ifdef TFLITE_USE_STD_ALIGN
-#include <stdalign.h>
 #endif
 
 namespace tflite {
@@ -54,7 +53,7 @@
 // the passed freeing_buffer pointer.
 inline void* aligned_alloc(size_t alignment, size_t size,
                            void** freeing_buffer) {
-#ifdef TFLITE_USE_STD_ALIGN
+#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
   *freeing_buffer = ::aligned_alloc(
       alignment, (size + alignment - 1) / alignment * alignment);
   return *freeing_buffer;
@@ -870,6 +869,23 @@
   }
 }
 
+// TODO(renjieliu): Avoid duplicating the logic.
+// Also consider changing the rounding stragey from "ties to away" to
+// "ties to even" since vcvtnq_s32_f32 is generally more available.
+inline int32x4_t RoundToNearest(const float32x4_t input) {
+#if defined(_ACAT_ARM64)
+  return vcvtaq_s32_f32(input);
+#else
+  static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
+  static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
+
+  const int32x4_t mask = vreinterpretq_s32_u32(vcltq_f32(input, zero_val_dup));
+  const float32x4_t casted_mask = vcvtq_f32_s32(mask);
+  const float32x4_t round = vaddq_f32(casted_mask, point5_val_dup);
+  return vcvtq_s32_f32(vaddq_f32(input, round));
+#endif
+}
+
 void NeonSymmetricQuantizeFloats(const float* values, const int size,
                                  int8_t* quantized_values, float* min,
                                  float* max, float* scaling_factor) {
@@ -892,8 +908,6 @@
 
   // Vectorized constants.
   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
-  const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
-  const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
   const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
 
@@ -901,29 +915,13 @@
     // Implements the vectorized version of the following:
     // const int32 quantized_value = static_cast<int32>(
     //    std::round(*scaling_factor * values[i]));
-    // Since the vectorized round intrinsics (vrndqa_f32) is not supported
-    // on all Neon flavors, we use the following method for rounding: if (x
-    // < 0) (int)(x - 0.5) if (x >= 0) (int)(x + 0.5)
     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
     float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
 
-    int32x4_t cmp_with_zero0_ui32x4 =
-        (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4);  // NOLINT
-    int32x4_t cmp_with_zero1_ui32x4 =
-        (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4);  // NOLINT
-
-    float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
-    float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
-    cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
-    cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);
-
-    mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
-    mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);
-
-    int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
-    int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);
+    const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4);
+    const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4);
 
     // Implements the vectorized version of the folowing block:
     //  quantized_values[i] = std::min(kScale, std::max(-kScale,
@@ -1021,24 +1019,6 @@
   }
 }
 
-void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) {
-  // This variable keeps track of the next to the last index which is being
-  // copied to make sure we are not out of the vector boundary.
-  int last_index_copy = kFloatWeightsPerNeonLane;
-  int current_index_copy = 0;
-  while (last_index_copy < v_size) {
-    float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1);
-    vst1q_f32(vector + current_index_copy, v_f32x4);
-    current_index_copy += kFloatWeightsPerNeonLane;
-    last_index_copy += kFloatWeightsPerNeonLane;
-  }
-  // Postamble loop.
-  for (int i = current_index_copy; i < v_size - 1; i++) {
-    vector[i] = vector[i + 1];
-  }
-  vector[v_size - 1] = shift_value;
-}
-
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index c4f13a1..af1bb7f 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -152,10 +152,6 @@
                    min_value, max_value, scaling_factor);
 }
 
-void VectorShiftLeft(float* vector, int v_size, float shift_value) {
-  NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
-}
-
 void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size) {
   NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 5798ea3..2388d78 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -56,7 +56,6 @@
 namespace optimized_ops {
 
 // Unoptimized reference ops:
-using reference_ops::AffineQuantize;
 using reference_ops::ArgMax;
 using reference_ops::ArgMinMax;
 using reference_ops::Broadcast4DSlowGreater;
@@ -73,7 +72,6 @@
 using reference_ops::Concatenation;
 using reference_ops::ConcatenationWithScaling;
 using reference_ops::DepthConcatenation;
-using reference_ops::Dequantize;
 using reference_ops::Div;
 using reference_ops::Elu;
 using reference_ops::FakeQuant;
@@ -268,7 +266,9 @@
   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
   const int output_rows = output_shape.Dims(output_dim_count - 1);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
-  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
+  }
 
   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
   lhs_params.rows = filter_rows;
@@ -801,6 +801,51 @@
                                   cpu_backend_context);
 }
 
+#ifdef USE_NEON
+
+inline float32x4_t DivideSumForMeanImpl(
+    const float32x4_t sum, const float32x4_t num_elements_reverse,
+    const bool ordinary_mean, const float32x4_t scale_dup,
+    const float32x4_t zero_point_with_bias_dup) {
+  const float32x4_t val = vmulq_f32(sum, num_elements_reverse);
+  if (!ordinary_mean) {
+#ifdef ARM_FEATURE_FMA
+    return vfmaq_f32(zero_point_with_bias_dup, scale_dup, val);
+#else
+    return vmlaq_f32(zero_point_with_bias_dup, scale_dup, val);
+#endif  // ARM_FEATURE_FMA
+  }
+  return val;
+}
+
+inline int32x4_t RoundToNearest(const float32x4_t input) {
+#if !defined(__aarch64__) && !defined(__SSE4_1__)
+  static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
+  static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
+  static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
+
+  const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
+  const float32x4_t round =
+      vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
+  return vcvtq_s32_f32(vaddq_f32(input, round));
+#else
+  return vcvtnq_s32_f32(input);
+#endif  // !defined(__aarch64__)
+}
+
+inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
+#if defined(__aarch64__) && !defined(__SSE4_1__)
+  // Note that vcvtnq_u32_f32 is not available on the arm_neon_sse.h.
+  return vcvtnq_u32_f32(input);
+#else
+  static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
+
+  return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
+#endif  // defined(__aarch64__) && !defined(__SSE4_1__)
+}
+
+#endif  // USE_NEON
+
 inline void MeanImpl(const tflite::MeanParams& op_params,
                      const RuntimeShape& input_shape, const uint8_t* input_data,
                      int32 input_zero_point, float input_scale,
@@ -826,7 +871,7 @@
 
   const bool ordinary_mean =
       (input_zero_point == output_zero_point && input_scale == output_scale);
-  float scale, bias;
+  float scale = 0.0f, bias = 0.0f;
   if (!ordinary_mean) {
     scale = input_scale / output_scale;
     bias = -input_zero_point * scale + 0.5;
@@ -835,15 +880,10 @@
 #ifdef USE_NEON
   const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
   // This is only an approximation as NEON does not offer division instruction.
+  const float32x4_t scale_dup = vdupq_n_f32(scale);
   const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
-  const float32x4_t kRounding = vdupq_n_f32(0.5);
-  float32x4_t bias_dup;
-  float32x4_t output_zero_point_dup;
-  if (!ordinary_mean) {
-    bias_dup = vdupq_n_f32(bias);
-    output_zero_point_dup = vdupq_n_f32(output_zero_point);
-  }
-#endif
+  float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias);
+#endif  // USE_NEON
 
   for (int out_b = 0; out_b < output_batch; ++out_b) {
     int out_d = start_depth;
@@ -868,28 +908,16 @@
         }
       }
 
-      float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse);
-      float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse);
+      const float32x4_t mean_1 =
+          DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean,
+                               scale_dup, zero_point_with_bias_dup);
+      const float32x4_t mean_2 =
+          DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean,
+                               scale_dup, zero_point_with_bias_dup);
 
-      if (!ordinary_mean) {
-        // maq is not supported, break down into two ops.
-        mean_1 = vmulq_n_f32(mean_1, scale);
-        mean_1 = vaddq_f32(mean_1, bias_dup);
-        mean_2 = vmulq_n_f32(mean_2, scale);
-        mean_2 = vaddq_f32(mean_2, bias_dup);
-      }
-
-      if (!ordinary_mean) {
-        mean_1 = vaddq_f32(mean_1, output_zero_point_dup);
-        mean_2 = vaddq_f32(mean_2, output_zero_point_dup);
-      }
-
-      // Rounding.
-      mean_1 = vaddq_f32(mean_1, kRounding);
-      mean_2 = vaddq_f32(mean_2, kRounding);
-      uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1);
+      uint32x4_t casted_mean_1 = RoundToNearestUnsigned(mean_1);
       uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
-      uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2);
+      uint32x4_t casted_mean_2 = RoundToNearestUnsigned(mean_2);
       uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
       uint16x8_t combined_mean =
           vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
@@ -898,7 +926,7 @@
           output_data + Offset(output_shape, out_b, 0, 0, out_d);
       vst1_u8(output_data_ptr, narrowed_combined_mean);
     }
-#endif
+#endif  // USE_NEON
 
     for (; out_d < end_depth; ++out_d) {
       float temp_value = 0;
@@ -2055,6 +2083,9 @@
       vdup_n_u8(params.quantized_activation_min);
   const auto output_activation_max_vector =
       vdup_n_u8(params.quantized_activation_max);
+  const int left_shift = std::max(0, params.output_shift);
+  const int right_shift = std::max(0, -params.output_shift);
+  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
   for (; i <= size - 8; i += 8) {
     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
     const auto input1_val_original = vld1_u8(input1_data + i);
@@ -2074,14 +2105,16 @@
     auto p1 = vmull_s16(input2_val_low, input1_val_low);
     auto p2 = vmull_s16(input2_val_high, input1_val_high);
 
+    p1 = vshlq_s32(p1, left_shift_vec);
+    p2 = vshlq_s32(p2, left_shift_vec);
     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
     using gemmlowp::RoundingDivideByPOT;
-    p1 = RoundingDivideByPOT(p1, -params.output_shift);
-    p2 = RoundingDivideByPOT(p2, -params.output_shift);
+    p1 = RoundingDivideByPOT(p1, right_shift);
+    p2 = RoundingDivideByPOT(p2, right_shift);
 
-    const auto p1_narrowed = vmovn_s32(p1);
-    const auto p2_narrowed = vmovn_s32(p2);
+    const auto p1_narrowed = vqmovn_s32(p1);
+    const auto p2_narrowed = vqmovn_s32(p2);
     const auto p =
         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
     const auto clamped =
@@ -2096,9 +2129,9 @@
     const int32 input2_val = params.input2_offset + input2_data[i];
     const int32 unclamped_result =
         params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
-                                                       params.output_multiplier,
-                                                       params.output_shift);
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
     const int32 clamped_output =
         std::min(params.quantized_activation_max,
                  std::max(params.quantized_activation_min, unclamped_result));
@@ -2126,6 +2159,9 @@
       vdup_n_u8(params.quantized_activation_min);
   const auto output_activation_max_vector =
       vdup_n_u8(params.quantized_activation_max);
+  const int left_shift = std::max(0, params.output_shift);
+  const int right_shift = std::max(0, -params.output_shift);
+  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
   for (; i <= size - 8; i += 8) {
     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
     const auto input2_val_original = vld1_u8(input2_data + i);
@@ -2139,11 +2175,13 @@
     auto p1 = vmull_n_s16(input2_val_low, input1_val);
     auto p2 = vmull_n_s16(input2_val_high, input1_val);
 
+    p1 = vshlq_s32(p1, left_shift_vec);
+    p2 = vshlq_s32(p2, left_shift_vec);
     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
     using gemmlowp::RoundingDivideByPOT;
-    p1 = RoundingDivideByPOT(p1, -params.output_shift);
-    p2 = RoundingDivideByPOT(p2, -params.output_shift);
+    p1 = RoundingDivideByPOT(p1, right_shift);
+    p2 = RoundingDivideByPOT(p2, right_shift);
 
     const auto p1_narrowed = vmovn_s32(p1);
     const auto p2_narrowed = vmovn_s32(p2);
@@ -3485,205 +3523,64 @@
   out_mat.array().rowwise() *= scale;
 }
 
-inline void Softmax(const SoftmaxParams& params,
-                    const RuntimeShape& input_shape, const uint8* input_data,
-                    const RuntimeShape& output_shape, uint8* output_data) {
-  const int32 input_beta_multiplier = params.input_multiplier;
-  const int32 input_beta_left_shift = params.input_left_shift;
-  const int diff_min = params.diff_min;
-  // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
-  static const int kScaledDiffIntegerBits = 5;
-  static const int kAccumulationIntegerBits = 12;
-  using FixedPointScaledDiff =
-      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
-  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
-  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+inline int32_t QuantizeSoftmaxOutput(int8_t* output_data, float prob_rescaled,
+                                     int32_t zero_point) {
+  const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
+  return prob_rnd + zero_point;
+}
 
-  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
+inline int32_t QuantizeSoftmaxOutput(uint8_t* output_data, float prob_rescaled,
+                                     int32_t zero_point) {
+  return static_cast<int32_t>(prob_rescaled + 0.5);
+}
+
+inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
+                                       float beta) {
+  const float scale = -input_scale * beta;
+  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+  for (int32_t val = 0; val <= max_uint8; ++val) {
+    data->table[max_uint8 - val] = expf(scale * val);
+  }
+}
+
+template <typename T>
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const T* input_data,
+                    const RuntimeShape& output_shape, T* output_data) {
   const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
+  const int excluding_last_dim =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
+  const int last_dim =
       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
 
-  for (int b = 0; b < outer_size; ++b) {
-    const uint8* input_data_ptr = input_data + b * depth;
-    uint8* output_data_ptr = output_data + b * depth;
-
-    // Determine the largest entry in the current row
-    uint8 max_in_row = 0;
-    {
-      int c = 0;
-#ifdef USE_NEON
-      uint8x16_t max16_0 = vdupq_n_u8(0);
-      uint8x16_t max16_1 = vdupq_n_u8(0);
-      for (; c <= depth - 32; c += 32) {
-        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
-        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
-      }
-      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
-      if (c <= depth - 16) {
-        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
-        c += 16;
-      }
-      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
-      if (c <= depth - 8) {
-        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
-        c += 8;
-      }
-      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
-      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
-      uint8x8_t max1 = vpmax_u8(max2, max2);
-      max_in_row = vget_lane_u8(max1, 0);
-#endif
-      for (; c < depth; ++c) {
-        max_in_row = std::max(max_in_row, input_data_ptr[c]);
-      }
+  const int32_t clamp_max = std::numeric_limits<T>::max();
+  const int32_t clamp_min = std::numeric_limits<T>::min();
+  for (int i = 0; i < excluding_last_dim; ++i) {
+    int32_t max_val = std::numeric_limits<T>::min();
+    // Find max quantized value.
+    for (int j = 0; j < last_dim; ++j) {
+      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
     }
 
-#ifdef USE_NEON
-    using FixedPointAccumInt32x4 =
-        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
-    using FixedPointScaledDiffInt32x4 =
-        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
-    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
-    FixedPoint0Int32x4 input_beta_multiplier_f0 =
-        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
-    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
-#endif
-
-    // Compute the sum of exponentials of the differences of entries in the
-    // current row from the largest entry in the current row.
-    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
-    {
-      int c = 0;
-#ifdef USE_NEON
-      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
-      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
-      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
-      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
-      for (; c <= depth - 8; c += 8) {
-        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
-        int16x8_t input_diff_s16 =
-            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
-        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
-        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
-        int32x4_t mask_0 =
-            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
-        int32x4_t mask_1 =
-            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
-        FixedPointScaledDiffInt32x4 scaled_diff_0 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
-        FixedPointScaledDiffInt32x4 scaled_diff_1 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
-        FixedPointAccumInt32x4 exps_0 =
-            gemmlowp::Rescale<kAccumulationIntegerBits>(
-                exp_on_negative_values(scaled_diff_0));
-        FixedPointAccumInt32x4 exps_1 =
-            gemmlowp::Rescale<kAccumulationIntegerBits>(
-                exp_on_negative_values(scaled_diff_1));
-        FixedPointAccumInt32x4 masked_exps_0 =
-            SelectUsingMask(mask_0, exps_0, zeros);
-        FixedPointAccumInt32x4 masked_exps_1 =
-            SelectUsingMask(mask_1, exps_1, zeros);
-        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
-        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
-      }
-      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
-      int32x2_t sum_of_exps_reduced_2 =
-          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
-                   vget_high_s32(sum_of_exps_reduced_4));
-      int32x2_t sum_of_exps_reduced_1 =
-          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
-      sum_of_exps =
-          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
-#endif
-      for (; c < depth; ++c) {
-        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
-        if (input_diff >= diff_min) {
-          const int32 input_diff_rescaled =
-              MultiplyByQuantizedMultiplierGreaterThanOne(
-                  input_diff, input_beta_multiplier, input_beta_left_shift);
-          const FixedPointScaledDiff scaled_diff_f8 =
-              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-          sum_of_exps =
-              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
-                                exp_on_negative_values(scaled_diff_f8));
-        }
-      }
+    float sum_exp = 0.0f;
+    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+    const float* table_offset = &params.table[max_uint8 - max_val];
+    // Calculate normalizer sum(exp(x)).
+    for (int j = 0; j < last_dim; ++j) {
+      sum_exp += table_offset[input_data[j]];
     }
 
-    // Compute the fixed-point multiplier and shift that we need to apply to
-    // perform a division by the above-computed sum-of-exponentials.
-    int num_bits_over_unit = 0;
-    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
-        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
-
-    // Compute the quotients of exponentials of differences of entries in the
-    // current row from the largest entry, over the previously-computed sum of
-    // exponentials.
-    {
-      int c = 0;
-#ifdef USE_NEON
-      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
-      for (; c <= depth - 8; c += 8) {
-        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
-        int16x8_t input_diff_s16 =
-            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
-        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
-        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
-        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
-        FixedPointScaledDiffInt32x4 scaled_diff_0 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
-        FixedPointScaledDiffInt32x4 scaled_diff_1 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
-        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
-        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
-        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
-            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
-            num_bits_over_unit + 31 - 8);
-        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
-            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
-            num_bits_over_unit + 31 - 8);
-        int16x8_t output_s16 =
-            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
-        uint8x8_t output_u8 = vqmovun_s16(output_s16);
-        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
-        vst1_u8(output_data_ptr + c, masked_output);
-      }
-#endif
-      for (; c < depth; ++c) {
-        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
-        if (input_diff >= diff_min) {
-          const int32 input_diff_rescaled =
-              MultiplyByQuantizedMultiplierGreaterThanOne(
-                  input_diff, input_beta_multiplier, input_beta_left_shift);
-          const FixedPointScaledDiff scaled_diff_f8 =
-              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
-          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
-          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
-              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
-
-          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
-
-        } else {
-          output_data_ptr[c] = 0;
-        }
-      }
+    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
+    // Normalize and quantize probabilities.
+    for (int j = 0; j < last_dim; ++j) {
+      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
+      const int32_t prob_quantized =
+          QuantizeSoftmaxOutput(output_data, prob_rescaled, params.zero_point);
+      output_data[j] = static_cast<T>(
+          std::max(std::min(clamp_max, prob_quantized), clamp_min));
     }
+    input_data += last_dim;
+    output_data += last_dim;
   }
 }
 
@@ -5258,9 +5155,10 @@
 #ifdef USE_NEON
 
 inline void MultiplyByQuantizedMultiplier4Rows(
-    int32x4_t input_val_1, int32x4_t input_val_2, int32x4_t input_val_3,
-    int32x4_t input_val_4, int32_t multiplier, int32_t left_shifted_one,
-    int32_t right_shift, int32x4_t* result_val_1, int32x4_t* result_val_2,
+    const int32x4_t input_val_1, const int32x4_t input_val_2,
+    const int32x4_t input_val_3, const int32x4_t input_val_4,
+    const int32_t multiplier, const int32_t left_shifted_one,
+    const int32_t right_shift, int32x4_t* result_val_1, int32x4_t* result_val_2,
     int32x4_t* result_val_3, int32x4_t* result_val_4) {
   using gemmlowp::RoundingDivideByPOT;
   using gemmlowp::SaturatingRoundingDoublingHighMul;
@@ -5300,20 +5198,21 @@
   int i = 0;
 #ifdef USE_NEON
   // Constants.
-  int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
-  int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
-  int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
-  int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
+  const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
+  const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
+  const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
+  const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
 
   // Left shift & right shift unconditionally.
-  int32_t left_shifted_one =
+  const int32_t left_shifted_one =
       effective_scale_shift > 0 ? 1 << effective_scale_shift : 1;
-  int32_t right_shift = effective_scale_shift > 0 ? 0 : -effective_scale_shift;
+  const int32_t right_shift =
+      effective_scale_shift > 0 ? 0 : -effective_scale_shift;
 
   for (; i <= size - 16; i += 16) {
-    int8x16_t input_vec = vld1q_s8(input_data + i);
-    int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
-    int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
+    const int8x16_t input_vec = vld1q_s8(input_data + i);
+    const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
+    const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
     int32x4_t input_val_1 = vmovl_s16(vget_low_s16(first_half));
     int32x4_t input_val_2 = vmovl_s16(vget_high_s16(first_half));
     int32x4_t input_val_3 = vmovl_s16(vget_low_s16(second_half));
@@ -5338,21 +5237,27 @@
     result_val_3 = vmaxq_s32(vminq_s32(result_val_3, max_val_dup), min_val_dup);
     result_val_4 = vmaxq_s32(vminq_s32(result_val_4, max_val_dup), min_val_dup);
 
-    uint32x4_t result_val_1_unsigned = vreinterpretq_u32_s32(result_val_1);
-    uint32x4_t result_val_2_unsigned = vreinterpretq_u32_s32(result_val_2);
-    uint32x4_t result_val_3_unsigned = vreinterpretq_u32_s32(result_val_3);
-    uint32x4_t result_val_4_unsigned = vreinterpretq_u32_s32(result_val_4);
+    const uint32x4_t result_val_1_unsigned =
+        vreinterpretq_u32_s32(result_val_1);
+    const uint32x4_t result_val_2_unsigned =
+        vreinterpretq_u32_s32(result_val_2);
+    const uint32x4_t result_val_3_unsigned =
+        vreinterpretq_u32_s32(result_val_3);
+    const uint32x4_t result_val_4_unsigned =
+        vreinterpretq_u32_s32(result_val_4);
 
-    uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
-    uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
-    uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
-    uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
-    uint16x8_t output_first_half = vcombine_u16(narrowed_val_1, narrowed_val_2);
-    uint16x8_t output_second_half =
+    const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
+    const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
+    const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
+    const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
+    const uint16x8_t output_first_half =
+        vcombine_u16(narrowed_val_1, narrowed_val_2);
+    const uint16x8_t output_second_half =
         vcombine_u16(narrowed_val_3, narrowed_val_4);
-    uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
-    uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
-    uint8x16_t result = vcombine_u8(narrowed_first_half, narrowed_second_half);
+    const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
+    const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
+    const uint8x16_t result =
+        vcombine_u8(narrowed_first_half, narrowed_second_half);
     vst1q_u8(output_data + i, result);
   }
 
@@ -5376,7 +5281,7 @@
                                         int32_t input_zeropoint,
                                         int32_t output_zeropoint,
                                         int8_t* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Requantize/UInt8ToInt8");
+  gemmlowp::ScopedProfilingLabel label("Requantize/Uint8ToInt8");
 
   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
@@ -5384,20 +5289,21 @@
   int i = 0;
 #ifdef USE_NEON
   // Constants.
-  int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
-  int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
-  int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
-  int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
+  const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
+  const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
+  const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
+  const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
 
   // Left shift & right shift unconditionally.
-  int32_t left_shifted_one =
+  const int32_t left_shifted_one =
       effective_scale_shift > 0 ? 1 << effective_scale_shift : 1;
-  int32_t right_shift = effective_scale_shift > 0 ? 0 : -effective_scale_shift;
+  const int32_t right_shift =
+      effective_scale_shift > 0 ? 0 : -effective_scale_shift;
 
   for (; i <= size - 16; i += 16) {
-    uint8x16_t input_vec = vld1q_u8(input_data + i);
-    uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
-    uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
+    const uint8x16_t input_vec = vld1q_u8(input_data + i);
+    const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
+    const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
     int32x4_t input_val_1 =
         vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
     int32x4_t input_val_2 =
@@ -5426,15 +5332,18 @@
     result_val_3 = vmaxq_s32(vminq_s32(result_val_3, max_val_dup), min_val_dup);
     result_val_4 = vmaxq_s32(vminq_s32(result_val_4, max_val_dup), min_val_dup);
 
-    int16x4_t narrowed_val_1 = vqmovn_s32(result_val_1);
-    int16x4_t narrowed_val_2 = vqmovn_s32(result_val_2);
-    int16x4_t narrowed_val_3 = vqmovn_s32(result_val_3);
-    int16x4_t narrowed_val_4 = vqmovn_s32(result_val_4);
-    int16x8_t output_first_half = vcombine_s16(narrowed_val_1, narrowed_val_2);
-    int16x8_t output_second_half = vcombine_s16(narrowed_val_3, narrowed_val_4);
-    int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
-    int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
-    int8x16_t result = vcombine_s8(narrowed_first_half, narrowed_second_half);
+    const int16x4_t narrowed_val_1 = vqmovn_s32(result_val_1);
+    const int16x4_t narrowed_val_2 = vqmovn_s32(result_val_2);
+    const int16x4_t narrowed_val_3 = vqmovn_s32(result_val_3);
+    const int16x4_t narrowed_val_4 = vqmovn_s32(result_val_4);
+    const int16x8_t output_first_half =
+        vcombine_s16(narrowed_val_1, narrowed_val_2);
+    const int16x8_t output_second_half =
+        vcombine_s16(narrowed_val_3, narrowed_val_4);
+    const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
+    const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
+    const int8x16_t result =
+        vcombine_s8(narrowed_first_half, narrowed_second_half);
     vst1q_s8(output_data + i, result);
   }
 
@@ -5451,6 +5360,180 @@
   }
 }
 
+template <>
+inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
+                                       int32_t effective_scale_multiplier,
+                                       int32_t effective_scale_shift,
+                                       int32_t input_zeropoint,
+                                       int32_t output_zeropoint,
+                                       int8_t* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Requantize/Int8ToInt8");
+
+  static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
+  static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
+
+  int i = 0;
+#ifdef USE_NEON
+  // Constants.
+  const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
+  const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
+  const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
+  const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
+
+  // Left shift & right shift unconditionally.
+  int32_t left_shifted_one =
+      effective_scale_shift > 0 ? 1 << effective_scale_shift : 1;
+  int32_t right_shift = effective_scale_shift > 0 ? 0 : -effective_scale_shift;
+
+  for (; i <= size - 16; i += 16) {
+    const int8x16_t input_vec = vld1q_s8(input_data + i);
+    const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
+    const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
+    int32x4_t input_val_1 = vmovl_s16(vget_low_s16(first_half));
+    int32x4_t input_val_2 = vmovl_s16(vget_high_s16(first_half));
+    int32x4_t input_val_3 = vmovl_s16(vget_low_s16(second_half));
+    int32x4_t input_val_4 = vmovl_s16(vget_high_s16(second_half));
+
+    input_val_1 = vaddq_s32(input_val_1, input_zero_point_dup);
+    input_val_2 = vaddq_s32(input_val_2, input_zero_point_dup);
+    input_val_3 = vaddq_s32(input_val_3, input_zero_point_dup);
+    input_val_4 = vaddq_s32(input_val_4, input_zero_point_dup);
+
+    int32x4_t result_val_1, result_val_2, result_val_3, result_val_4;
+    MultiplyByQuantizedMultiplier4Rows(
+        input_val_1, input_val_2, input_val_3, input_val_4,
+        effective_scale_multiplier, left_shifted_one, right_shift,
+        &result_val_1, &result_val_2, &result_val_3, &result_val_4);
+
+    result_val_1 = vaddq_s32(result_val_1, output_zero_point_dup);
+    result_val_2 = vaddq_s32(result_val_2, output_zero_point_dup);
+    result_val_3 = vaddq_s32(result_val_3, output_zero_point_dup);
+    result_val_4 = vaddq_s32(result_val_4, output_zero_point_dup);
+    result_val_1 = vmaxq_s32(vminq_s32(result_val_1, max_val_dup), min_val_dup);
+    result_val_2 = vmaxq_s32(vminq_s32(result_val_2, max_val_dup), min_val_dup);
+    result_val_3 = vmaxq_s32(vminq_s32(result_val_3, max_val_dup), min_val_dup);
+    result_val_4 = vmaxq_s32(vminq_s32(result_val_4, max_val_dup), min_val_dup);
+
+    const int16x4_t narrowed_val_1 = vqmovn_s32(result_val_1);
+    const int16x4_t narrowed_val_2 = vqmovn_s32(result_val_2);
+    const int16x4_t narrowed_val_3 = vqmovn_s32(result_val_3);
+    const int16x4_t narrowed_val_4 = vqmovn_s32(result_val_4);
+    const int16x8_t output_first_half =
+        vcombine_s16(narrowed_val_1, narrowed_val_2);
+    const int16x8_t output_second_half =
+        vcombine_s16(narrowed_val_3, narrowed_val_4);
+    const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
+    const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
+    const int8x16_t result =
+        vcombine_s8(narrowed_first_half, narrowed_second_half);
+    vst1q_s8(output_data + i, result);
+  }
+
+#endif
+  for (; i < size; ++i) {
+    const int32_t input = input_data[i] - input_zeropoint;
+    const int32_t output =
+        MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
+                                      effective_scale_shift) +
+        output_zeropoint;
+    const int32_t clamped_output =
+        std::max(std::min(output, kMaxOutput), kMinOutput);
+    output_data[i] = static_cast<int8_t>(clamped_output);
+  }
+}
+
+template <>
+inline void Requantize<uint8_t, uint8_t>(
+    const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
+    int32_t effective_scale_shift, int32_t input_zeropoint,
+    int32_t output_zeropoint, uint8_t* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Requantize/Uint8ToUint8");
+
+  static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
+  static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
+
+  int i = 0;
+#ifdef USE_NEON
+  // Constants.
+  const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
+  const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
+  const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
+  const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
+
+  // Left shift & right shift unconditionally.
+  int32_t left_shifted_one =
+      effective_scale_shift > 0 ? 1 << effective_scale_shift : 1;
+  int32_t right_shift = effective_scale_shift > 0 ? 0 : -effective_scale_shift;
+
+  for (; i <= size - 16; i += 16) {
+    const uint8x16_t input_vec = vld1q_u8(input_data + i);
+    const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
+    const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
+    int32x4_t input_val_1 =
+        vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
+    int32x4_t input_val_2 =
+        vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
+    int32x4_t input_val_3 =
+        vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
+    int32x4_t input_val_4 =
+        vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
+    input_val_1 = vaddq_s32(input_val_1, input_zero_point_dup);
+    input_val_2 = vaddq_s32(input_val_2, input_zero_point_dup);
+    input_val_3 = vaddq_s32(input_val_3, input_zero_point_dup);
+    input_val_4 = vaddq_s32(input_val_4, input_zero_point_dup);
+
+    int32x4_t result_val_1, result_val_2, result_val_3, result_val_4;
+    MultiplyByQuantizedMultiplier4Rows(
+        input_val_1, input_val_2, input_val_3, input_val_4,
+        effective_scale_multiplier, left_shifted_one, right_shift,
+        &result_val_1, &result_val_2, &result_val_3, &result_val_4);
+
+    result_val_1 = vaddq_s32(result_val_1, output_zero_point_dup);
+    result_val_2 = vaddq_s32(result_val_2, output_zero_point_dup);
+    result_val_3 = vaddq_s32(result_val_3, output_zero_point_dup);
+    result_val_4 = vaddq_s32(result_val_4, output_zero_point_dup);
+    result_val_1 = vmaxq_s32(vminq_s32(result_val_1, max_val_dup), min_val_dup);
+    result_val_2 = vmaxq_s32(vminq_s32(result_val_2, max_val_dup), min_val_dup);
+    result_val_3 = vmaxq_s32(vminq_s32(result_val_3, max_val_dup), min_val_dup);
+    result_val_4 = vmaxq_s32(vminq_s32(result_val_4, max_val_dup), min_val_dup);
+
+    const uint32x4_t result_val_1_unsigned =
+        vreinterpretq_u32_s32(result_val_1);
+    const uint32x4_t result_val_2_unsigned =
+        vreinterpretq_u32_s32(result_val_2);
+    const uint32x4_t result_val_3_unsigned =
+        vreinterpretq_u32_s32(result_val_3);
+    const uint32x4_t result_val_4_unsigned =
+        vreinterpretq_u32_s32(result_val_4);
+
+    const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
+    const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
+    const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
+    const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
+    const uint16x8_t output_first_half =
+        vcombine_u16(narrowed_val_1, narrowed_val_2);
+    const uint16x8_t output_second_half =
+        vcombine_u16(narrowed_val_3, narrowed_val_4);
+    const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
+    const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
+    const uint8x16_t result =
+        vcombine_u8(narrowed_first_half, narrowed_second_half);
+    vst1q_u8(output_data + i, result);
+  }
+
+#endif
+  for (; i < size; ++i) {
+    const int32_t input = input_data[i] - input_zeropoint;
+    const int32_t output =
+        MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
+                                      effective_scale_shift) +
+        output_zeropoint;
+    const int32_t clamped_output =
+        std::max(std::min(output, kMaxOutput), kMinOutput);
+    output_data[i] = static_cast<uint8_t>(clamped_output);
+  }
+}
+
 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("HardSwish/Float");
@@ -5748,6 +5831,324 @@
                                     unextended_output_shape, output_data);
 }
 
+#ifdef USE_NEON
+
+inline void ScaleWithNewZeroPoint(const int32x4_t input,
+                                  const float32x4_t scale_dup,
+                                  const float32x4_t zero_times_scale_dup,
+                                  float32x4_t* output) {
+#ifdef __ARM_FEATURE_FMA
+  *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
+#else
+  *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
+                      zero_times_scale_dup);
+#endif
+}
+
+#endif  // USE_NEON
+
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+                       const RuntimeShape& input_shape,
+                       const uint8_t* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Dequantize/Uint8");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = op_params.scale;
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
+  const float32x4_t zero_times_scale_dup =
+      vdupq_n_f32(static_cast<float>(-zero_point * scale));
+  for (; i <= flat_size - 8; i += 8) {
+    const uint8x8_t input_u8 = vld1_u8(input_data + i);
+    const uint16x8_t input_u16 = vmovl_u8(input_u8);
+    const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
+    const int16x4_t input_s16_low = vget_low_s16(input_s16);
+    const int16x4_t input_s16_high = vget_high_s16(input_s16);
+    const int32x4_t val_low = vmovl_s16(input_s16_low);
+    const int32x4_t val_high = vmovl_s16(input_s16_high);
+
+    float32x4_t result_low, result_high;
+    ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
+                          &result_low);
+    ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
+                          &result_high);
+
+    vst1q_f32(output_data + i, result_low);
+    vst1q_f32(output_data + i + 4, result_high);
+  }
+#endif  // NEON
+  for (; i < flat_size; ++i) {
+    const int32 val = input_data[i];
+    const float result = static_cast<float>(scale * (val - zero_point));
+    output_data[i] = result;
+  }
+}
+
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+                       const RuntimeShape& input_shape,
+                       const int8_t* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Dequantize/Int8");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = op_params.scale;
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
+  const float32x4_t zero_times_scale_dup =
+      vdupq_n_f32(static_cast<float>(-zero_point * scale));
+  for (; i <= flat_size - 8; i += 8) {
+    const int8x8_t input_s8 = vld1_s8(input_data + i);
+    const int16x8_t input_s16 = vmovl_s8(input_s8);
+    const int16x4_t input_s16_low = vget_low_s16(input_s16);
+    const int16x4_t input_s16_high = vget_high_s16(input_s16);
+    const int32x4_t val_low = vmovl_s16(input_s16_low);
+    const int32x4_t val_high = vmovl_s16(input_s16_high);
+
+    float32x4_t result_low, result_high;
+    ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
+                          &result_low);
+    ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
+                          &result_high);
+
+    vst1q_f32(output_data + i, result_low);
+    vst1q_f32(output_data + i + 4, result_high);
+  }
+#endif  // NEON
+  for (; i < flat_size; ++i) {
+    const int32 val = input_data[i];
+    const float result = static_cast<float>(scale * (val - zero_point));
+    output_data[i] = result;
+  }
+}
+
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+                       const RuntimeShape& input_shape,
+                       const int16_t* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Dequantize/Int16");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = op_params.scale;
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
+  const float32x4_t zero_times_scale_dup =
+      vdupq_n_f32(static_cast<float>(-zero_point * scale));
+  for (; i <= flat_size - 8; i += 8) {
+    const int16x4_t input_s16_low = vld1_s16(input_data + i);
+    const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
+    const int32x4_t val_low = vmovl_s16(input_s16_low);
+    const int32x4_t val_high = vmovl_s16(input_s16_high);
+
+    float32x4_t result_low, result_high;
+    ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
+                          &result_low);
+    ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
+                          &result_high);
+
+    vst1q_f32(output_data + i, result_low);
+    vst1q_f32(output_data + i + 4, result_high);
+  }
+#endif  // NEON
+  for (; i < flat_size; ++i) {
+    const int32 val = input_data[i];
+    const float result = static_cast<float>(scale * (val - zero_point));
+    output_data[i] = result;
+  }
+}
+
+inline void Dequantize(const RuntimeShape& input_shape,
+                       const Eigen::half* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
+}
+
+template <typename T>
+inline void AffineQuantize(const tflite::QuantizationParams& op_params,
+                           const RuntimeShape& input_shape,
+                           const float* input_data,
+                           const RuntimeShape& output_shape, T* output_data) {
+  reference_ops::AffineQuantize(op_params, input_shape, input_data,
+                                output_shape, output_data);
+}
+
+template <>
+inline void AffineQuantize(const tflite::QuantizationParams& op_params,
+                           const RuntimeShape& input_shape,
+                           const float* input_data,
+                           const RuntimeShape& output_shape,
+                           int8_t* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Quantize/Int8");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = static_cast<double>(op_params.scale);
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
+  static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
+  const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
+  const int32x4_t min_val_dup = vdupq_n_s32(min_val);
+  const int32x4_t max_val_dup = vdupq_n_s32(max_val);
+
+  for (; i <= flat_size - 8; i += 8) {
+    const float* src_data_ptr = input_data + i;
+    float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
+    float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
+
+    input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
+    input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
+
+    int32x4_t casted_val_0 = RoundToNearest(input_val_0);
+    int32x4_t casted_val_1 = RoundToNearest(input_val_1);
+
+    casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
+    casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
+
+    // Clamp the values to fit the target type's range.
+    casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
+    casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
+    casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
+    casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
+
+    const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
+    const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
+    const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
+    const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
+    vst1_s8(output_data + i, combined_val_narrowed);
+  }
+#endif  // NEON
+
+  for (; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const int32 unclamped =
+        static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
+    const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
+    output_data[i] = clamped;
+  }
+}
+
+template <>
+inline void AffineQuantize(const tflite::QuantizationParams& op_params,
+                           const RuntimeShape& input_shape,
+                           const float* input_data,
+                           const RuntimeShape& output_shape,
+                           uint8_t* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Quantize/Uint8");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = static_cast<double>(op_params.scale);
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
+  static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
+  const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
+  const int32x4_t min_val_dup = vdupq_n_s32(min_val);
+  const int32x4_t max_val_dup = vdupq_n_s32(max_val);
+
+  for (; i <= flat_size - 8; i += 8) {
+    const float* src_data_ptr = input_data + i;
+    float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
+    float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
+
+    input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
+    input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
+
+    int32x4_t casted_val_0 = RoundToNearest(input_val_0);
+    int32x4_t casted_val_1 = RoundToNearest(input_val_1);
+
+    casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
+    casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
+
+    // Clamp the values to fit the target type's range.
+    casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
+    casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
+    casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
+    casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
+
+    const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
+    const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
+    const uint16x8_t combined_val =
+        vcombine_u16(narrowed_val_0, narrowed_val_1);
+    const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
+    vst1_u8(output_data + i, combined_val_narrowed);
+  }
+#endif  // NEON
+
+  for (; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const int32 unclamped =
+        static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
+    const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
+    output_data[i] = clamped;
+  }
+}
+
+template <>
+inline void AffineQuantize(const tflite::QuantizationParams& op_params,
+                           const RuntimeShape& input_shape,
+                           const float* input_data,
+                           const RuntimeShape& output_shape,
+                           int16_t* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Quantize/Int16");
+  const int32 zero_point = op_params.zero_point;
+  const double scale = static_cast<double>(op_params.scale);
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
+  static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
+
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
+  const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
+  const int32x4_t min_val_dup = vdupq_n_s32(min_val);
+  const int32x4_t max_val_dup = vdupq_n_s32(max_val);
+
+  for (; i <= flat_size - 8; i += 8) {
+    const float* src_data_ptr = input_data + i;
+    float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
+    float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
+
+    input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
+    input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
+
+    int32x4_t casted_val_0 = RoundToNearest(input_val_0);
+    int32x4_t casted_val_1 = RoundToNearest(input_val_1);
+
+    casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
+    casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
+
+    // Clamp the values to fit the target type's range.
+    casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
+    casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
+    casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
+    casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
+
+    const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
+    const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
+    vst1_s16(output_data + i, narrowed_val_0);
+    vst1_s16(output_data + i + 4, narrowed_val_1);
+  }
+#endif  // NEON
+
+  for (; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const int32 unclamped =
+        static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
+    const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
+    output_data[i] = clamped;
+  }
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index 41f7194..373f75f 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -163,10 +163,6 @@
                    min_value, max_value, scaling_factor);
 }
 
-void VectorShiftLeft(float* vector, int v_size, float shift_value) {
-  NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
-}
-
 void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size) {
   NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
diff --git a/tensorflow/lite/kernels/internal/quantization_util.cc b/tensorflow/lite/kernels/internal/quantization_util.cc
index 71eef71..af07c5a 100644
--- a/tensorflow/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util.cc
@@ -72,6 +72,20 @@
     ++*shift;
   }
   TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+  // A shift amount smaller than -31 would cause all bits to be shifted out
+  // and thus all results would be zero. We implement that instead with
+  // q_fixed==0, so as to avoid hitting issues with right-shift
+  // operations with shift amounts greater than 31. Note that this happens
+  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
+  // that we're effectively flushing tiny double_multiplier's to zero.
+  // We could conceivably handle values in the range (roughly) [32, 63]
+  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
+  // the present handling is just doing 'flush denormals to zero'. We could
+  // reconsider and actually generate nonzero denormals if a need arises.
+  if (*shift < -31) {
+    *shift = 0;
+    q_fixed = 0;
+  }
   *quantized_multiplier = static_cast<int32_t>(q_fixed);
 }
 
diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc
index ca4ff37..56c7720 100644
--- a/tensorflow/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc
@@ -381,6 +381,20 @@
   EXPECT_THAT(quantize(2), Pair(1073741824, 2));
 }
 
+TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
+  auto quantize = [](double d) {
+    int32_t q;
+    int s;
+    QuantizeMultiplier(d, &q, &s);
+    return std::pair<int32_t, int>{q, s};
+  };
+
+  EXPECT_THAT(quantize(std::ldexp(1.0f, -31)), Pair(1073741824, -30));
+  EXPECT_THAT(quantize(std::ldexp(1.0f, -32)), Pair(1073741824, -31));
+  EXPECT_THAT(quantize(std::ldexp(0.99f, -32)), Pair(0, 0));
+  EXPECT_THAT(quantize(std::ldexp(1.0f, -33)), Pair(0, 0));
+}
+
 TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
   auto quantize = [](double beta, double scale, int integer_bits) {
     int32_t q;
diff --git a/tensorflow/lite/kernels/internal/reference/add.h b/tensorflow/lite/kernels/internal/reference/add.h
new file mode 100644
index 0000000..5193a58
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/add.h
@@ -0,0 +1,418 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+template <typename T>
+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const T* input1_data,
+                const RuntimeShape& input2_shape, const T* input2_data,
+                const RuntimeShape& output_shape, T* output_data) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] + input2_data[i], params.quantized_activation_min,
+        params.quantized_activation_max);
+  }
+}
+
+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const float* input1_data,
+                const RuntimeShape& input2_shape, const float* input2_data,
+                const RuntimeShape& output_shape, float* output_data) {
+  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < size; i++) {
+    auto x = input1_data[i] + input2_data[i];
+    output_data[i] = ActivationFunctionWithMinMax(
+        x, params.float_activation_min, params.float_activation_max);
+  }
+}
+
+// Element-wise add that can often be used for inner loop of broadcast add as
+// well as the non-broadcast add.
+inline void AddElementwise(int size, const ArithmeticParams& params,
+                           const uint8* input1_data, const uint8* input2_data,
+                           uint8* output_data) {
+  TFLITE_DCHECK_GT(params.input1_offset, -256);
+  TFLITE_DCHECK_GT(params.input2_offset, -256);
+  TFLITE_DCHECK_LT(params.input1_offset, 256);
+  TFLITE_DCHECK_LT(params.input2_offset, 256);
+
+  for (int i = 0; i < size; ++i) {
+    const int32 input1_val = params.input1_offset + input1_data[i];
+    const int32 input2_val = params.input2_offset + input2_data[i];
+    const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+    const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+    const int32 scaled_input1_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input1_val, params.input1_multiplier, params.input1_shift);
+    const int32 scaled_input2_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input2_val, params.input2_multiplier, params.input2_shift);
+    const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+    const int32 raw_output =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            raw_sum, params.output_multiplier, params.output_shift) +
+        params.output_offset;
+    const int32 clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, raw_output));
+    output_data[i] = static_cast<uint8>(clamped_output);
+  }
+}
+
+// Scalar-broadcast add that can be used for inner loop of more general
+// broadcast add, so that, for example, scalar-broadcast with batch will still
+// be fast.
+inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
+                               uint8 input1_data, const uint8* input2_data,
+                               uint8* output_data) {
+  TFLITE_DCHECK_GT(params.input1_offset, -256);
+  TFLITE_DCHECK_GT(params.input2_offset, -256);
+  TFLITE_DCHECK_LT(params.input1_offset, 256);
+  TFLITE_DCHECK_LT(params.input2_offset, 256);
+
+  const int32 input1_val = params.input1_offset + input1_data;
+  const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+  const int32 scaled_input1_val =
+      MultiplyByQuantizedMultiplierSmallerThanOneExp(
+          shifted_input1_val, params.input1_multiplier, params.input1_shift);
+  for (int i = 0; i < size; ++i) {
+    const int32 input2_val = params.input2_offset + input2_data[i];
+    const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+    const int32 scaled_input2_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input2_val, params.input2_multiplier, params.input2_shift);
+    const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+    const int32 raw_output =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            raw_sum, params.output_multiplier, params.output_shift) +
+        params.output_offset;
+    const int32 clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, raw_output));
+    output_data[i] = static_cast<uint8>(clamped_output);
+  }
+}
+
+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const uint8* input1_data,
+                const RuntimeShape& input2_shape, const uint8* input2_data,
+                const RuntimeShape& output_shape, uint8* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+  TFLITE_DCHECK_GT(params.input1_offset, -256);
+  TFLITE_DCHECK_GT(params.input2_offset, -256);
+  TFLITE_DCHECK_LT(params.input1_offset, 256);
+  TFLITE_DCHECK_LT(params.input2_offset, 256);
+  AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const int16* input1_data,
+                const RuntimeShape& input2_shape, const int16* input2_data,
+                const RuntimeShape& output_shape, int16* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+
+  const int input1_shift = params.input1_shift;
+  const int flat_size =
+      MatchingFlatSize(output_shape, input1_shape, input2_shape);
+  const int16 output_activation_min = params.quantized_activation_min;
+  const int16 output_activation_max = params.quantized_activation_max;
+
+  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+  TFLITE_DCHECK_LE(input1_shift, 0);
+  TFLITE_DCHECK_LE(params.input2_shift, 0);
+  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
+  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
+  const int input_right_shift =
+      input1_shift == 0 ? -params.input2_shift : -input1_shift;
+
+  for (int i = 0; i < flat_size; i++) {
+    // F0 uses 0 integer bits, range [-1, 1].
+    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+
+    F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
+    F0 scaled_input = F0::FromRaw(
+        gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
+    F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
+    const int16 raw_output = result.raw();
+    const int16 clamped_output = std::min(
+        output_activation_max, std::max(output_activation_min, raw_output));
+    output_data[i] = clamped_output;
+  }
+}
+
+// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+                               const RuntimeShape& input1_shape,
+                               const float* input1_data,
+                               const RuntimeShape& input2_shape,
+                               const float* input2_data,
+                               const RuntimeShape& output_shape,
+                               float* output_data) {
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+                                      &desc2);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for the
+  // best cache behavior.
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          output_data[Offset(extended_output_shape, b, y, x, c)] =
+              ActivationFunctionWithMinMax(
+                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+                  params.float_activation_min, params.float_activation_max);
+        }
+      }
+    }
+  }
+}
+
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+                               const RuntimeShape& input1_shape,
+                               const int32* input1_data,
+                               const RuntimeShape& input2_shape,
+                               const int32* input2_data,
+                               const RuntimeShape& output_shape,
+                               int32* output_data) {
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+                                      &desc2);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for the
+  // best cache behavior.
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          output_data[Offset(extended_output_shape, b, y, x, c)] =
+              ActivationFunctionWithMinMax(
+                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+                  params.quantized_activation_min,
+                  params.quantized_activation_max);
+        }
+      }
+    }
+  }
+}
+
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+                               const RuntimeShape& input1_shape,
+                               const uint8* input1_data,
+                               const RuntimeShape& input2_shape,
+                               const uint8* input2_data,
+                               const RuntimeShape& output_shape,
+                               uint8* output_data) {
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+                                      &desc2);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for the
+  // best cache behavior.
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          const int32 input1_val =
+              params.input1_offset +
+              input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+          const int32 input2_val =
+              params.input2_offset +
+              input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+          const int32 shifted_input1_val =
+              input1_val * (1 << params.left_shift);
+          const int32 shifted_input2_val =
+              input2_val * (1 << params.left_shift);
+          const int32 scaled_input1_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input1_val, params.input1_multiplier,
+                  params.input1_shift);
+          const int32 scaled_input2_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input2_val, params.input2_multiplier,
+                  params.input2_shift);
+          const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+          const int32 raw_output =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  raw_sum, params.output_multiplier, params.output_shift) +
+              params.output_offset;
+          const int32 clamped_output =
+              std::min(params.quantized_activation_max,
+                       std::max(params.quantized_activation_min, raw_output));
+          output_data[Offset(extended_output_shape, b, y, x, c)] =
+              static_cast<uint8>(clamped_output);
+        }
+      }
+    }
+  }
+}
+
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+                                 const RuntimeShape& unswitched_input1_shape,
+                                 const uint8* unswitched_input1_data,
+                                 const RuntimeShape& unswitched_input2_shape,
+                                 const uint8* unswitched_input2_data,
+                                 const RuntimeShape& output_shape,
+                                 uint8* output_data) {
+  ArithmeticParams switched_params = unswitched_params;
+  switched_params.input1_offset = unswitched_params.input2_offset;
+  switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+  switched_params.input1_shift = unswitched_params.input2_shift;
+  switched_params.input2_offset = unswitched_params.input1_offset;
+  switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+  switched_params.input2_shift = unswitched_params.input1_shift;
+
+  const bool use_unswitched =
+      unswitched_params.broadcast_category ==
+      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+  const ArithmeticParams& params =
+      use_unswitched ? unswitched_params : switched_params;
+  const uint8* input1_data =
+      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+  const uint8* input2_data =
+      use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+  // Fivefold nested loops. The second input resets its position for each
+  // iteration of the second loop. The first input resets its position at the
+  // beginning of the fourth loop. The innermost loop is an elementwise add of
+  // sections of the arrays.
+  uint8* output_data_ptr = output_data;
+  const uint8* input1_data_ptr = input1_data;
+  const uint8* input2_data_reset = input2_data;
+  // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
+  // between input shapes. y3 for input 1 is always broadcast, and so the
+  // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
+  // Put another way,
+  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
+  // input2.shape.FlatSize = y0 * y2 * y3 * y4.
+  int y0 = params.broadcast_shape[0];
+  int y1 = params.broadcast_shape[1];
+  int y2 = params.broadcast_shape[2];
+  int y3 = params.broadcast_shape[3];
+  int y4 = params.broadcast_shape[4];
+  if (y4 > 1) {
+    // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
+    // dimension.
+    for (int i0 = 0; i0 < y0; ++i0) {
+      const uint8* input2_data_ptr;
+      for (int i1 = 0; i1 < y1; ++i1) {
+        input2_data_ptr = input2_data_reset;
+        for (int i2 = 0; i2 < y2; ++i2) {
+          for (int i3 = 0; i3 < y3; ++i3) {
+            AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+                           output_data_ptr);
+            input2_data_ptr += y4;
+            output_data_ptr += y4;
+          }
+          // We have broadcast y4 of input1 data y3 times, and now move on.
+          input1_data_ptr += y4;
+        }
+      }
+      // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
+      input2_data_reset = input2_data_ptr;
+    }
+  } else {
+    // Special case of y4 == 1, in which the innermost loop is a single element
+    // and can be combined with the next (y3) as an inner broadcast.
+    //
+    // Note that this handles the case of pure scalar broadcast when
+    // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
+    // broadcast with batch (as y2 > 1).
+    //
+    // NOTE The process is the same as the above general case except simplified
+    // for y4 == 1 and the loop over y3 is contained within the
+    // AddScalarBroadcast function.
+    for (int i0 = 0; i0 < y0; ++i0) {
+      const uint8* input2_data_ptr;
+      for (int i1 = 0; i1 < y1; ++i1) {
+        input2_data_ptr = input2_data_reset;
+        for (int i2 = 0; i2 < y2; ++i2) {
+          AddScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
+                             output_data_ptr);
+          input2_data_ptr += y3;
+          output_data_ptr += y3;
+          input1_data_ptr += 1;
+        }
+      }
+      input2_data_reset = input2_data_ptr;
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
diff --git a/tensorflow/lite/kernels/internal/reference/arg_min_max.h b/tensorflow/lite/kernels/internal/reference/arg_min_max.h
new file mode 100644
index 0000000..e6f34fd
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/arg_min_max.h
@@ -0,0 +1,68 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+               const T3* input2_data, const RuntimeShape& output_shape,
+               T2* output_data, const Cmp& cmp) {
+  TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
+  TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
+                   output_shape.DimensionsCount());
+  int axis = input2_data[0];
+  if (axis < 0) {
+    axis += input1_shape.DimensionsCount();
+  }
+  const int axis_size = input1_shape.Dims(axis);
+
+  int outer_size = 1;
+  for (int i = 0; i < axis; ++i) {
+    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
+    outer_size *= input1_shape.Dims(i);
+  }
+
+  int inner_size = 1;
+  const int dims_count = input1_shape.DimensionsCount();
+  for (int i = axis + 1; i < dims_count; ++i) {
+    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
+    inner_size *= input1_shape.Dims(i);
+  }
+  for (int outer = 0; outer < outer_size; ++outer) {
+    for (int inner = 0; inner < inner_size; ++inner) {
+      auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
+      T2 min_max_index = 0;
+      for (int i = 1; i < axis_size; ++i) {
+        const auto& curr_value =
+            input1_data[(outer * axis_size + i) * inner_size + inner];
+        if (cmp(curr_value, min_max_value)) {
+          min_max_value = curr_value;
+          min_max_index = static_cast<T2>(i);
+        }
+      }
+      output_data[outer * inner_size + inner] = min_max_index;
+    }
+  }
+}
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
diff --git a/tensorflow/lite/kernels/internal/reference/binary_function.h b/tensorflow/lite/kernels/internal/reference/binary_function.h
new file mode 100644
index 0000000..82095af
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/binary_function.h
@@ -0,0 +1,84 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
+//
+// Also appears to duplicte MinimumMaximum.
+//
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction4DSlow(
+    const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+    const RuntimeShape& unextended_output_shape, R* output_data,
+    R (*func)(T1, T2)) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          auto out_idx = Offset(output_shape, b, y, x, c);
+          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+          auto in1_val = input1_data[in1_idx];
+          auto in2_val = input2_data[in2_idx];
+          output_data[out_idx] = func(in1_val, in2_val);
+        }
+      }
+    }
+  }
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+                           const T1* input1_data,
+                           const RuntimeShape& input2_shape,
+                           const T2* input2_data,
+                           const RuntimeShape& output_shape, R* output_data,
+                           R (*func)(T1, T2)) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = func(input1_data[i], input2_data[i]);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h
new file mode 100644
index 0000000..7f8072f
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/comparisons.h
@@ -0,0 +1,276 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
+
+#include "profiling/instrumentation.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+template <typename T>
+inline bool EqualFn(T lhs, T rhs) {
+  return lhs == rhs;
+}
+
+template <typename T>
+inline bool NotEqualFn(T lhs, T rhs) {
+  return lhs != rhs;
+}
+
+template <typename T>
+inline bool GreaterFn(T lhs, T rhs) {
+  return lhs > rhs;
+}
+template <typename T>
+inline bool GreaterEqualFn(T lhs, T rhs) {
+  return lhs >= rhs;
+}
+template <typename T>
+inline bool LessFn(T lhs, T rhs) {
+  return lhs < rhs;
+}
+template <typename T>
+inline bool LessEqualFn(T lhs, T rhs) {
+  return lhs <= rhs;
+}
+
+template <typename T>
+using ComparisonFn = bool (*)(T, T);
+
+template <typename T, ComparisonFn<T> F>
+inline void ComparisonImpl(
+    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+    const T* input1_data, const RuntimeShape& input2_shape,
+    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+  const int64_t flatsize =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int64_t i = 0; i < flatsize; ++i) {
+    output_data[i] = F(input1_data[i], input2_data[i]);
+  }
+}
+
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+                       const RuntimeShape& input1_shape,
+                       const float* input1_data,
+                       const RuntimeShape& input2_shape,
+                       const float* input2_data,
+                       const RuntimeShape& output_shape, bool* output_data) {
+  ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+                           input2_data, output_shape, output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void ComparisonWithScaling(
+    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+    const T* input1_data, const RuntimeShape& input2_shape,
+    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+  int left_shift = op_params.left_shift;
+  int32 input1_offset = op_params.input1_offset;
+  int32 input1_multiplier = op_params.input1_multiplier;
+  int input1_shift = op_params.input1_shift;
+  int32 input2_offset = op_params.input2_offset;
+  int32 input2_multiplier = op_params.input2_multiplier;
+  int input2_shift = op_params.input2_shift;
+
+  const int64_t flatsize =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int64_t i = 0; i < flatsize; ++i) {
+    const int32 input1_val = input1_offset + input1_data[i];
+    const int32 input2_val = input2_offset + input2_data[i];
+    const int32 shifted_input1_val = input1_val * (1 << left_shift);
+    const int32 shifted_input2_val = input2_val * (1 << left_shift);
+    const int32 scaled_input1_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input1_val, input1_multiplier, input1_shift);
+    const int32 scaled_input2_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input2_val, input2_multiplier, input2_shift);
+    output_data[i] = F(scaled_input1_val, scaled_input2_val);
+  }
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison4DSlowImpl(
+    const ComparisonParams& op_params,
+    const RuntimeShape& unextended_input1_shape, const T* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T* input2_data,
+    const RuntimeShape& unextended_output_shape, bool* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          output_data[Offset(output_shape, b, y, x, c)] =
+              F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+                input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
+        }
+      }
+    }
+  }
+}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+                                      const RuntimeShape& input1_shape,
+                                      const float* input1_data,
+                                      const RuntimeShape& input2_shape,
+                                      const float* input2_data,
+                                      const RuntimeShape& output_shape,
+                                      bool* output_data) {
+  BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+                                          input2_shape, input2_data,
+                                          output_shape, output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison4DSlowWithScaling(
+    const ComparisonParams& op_params,
+    const RuntimeShape& unextended_input1_shape, const T* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T* input2_data,
+    const RuntimeShape& unextended_output_shape, bool* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  int left_shift = op_params.left_shift;
+  int32 input1_offset = op_params.input1_offset;
+  int32 input1_multiplier = op_params.input1_multiplier;
+  int input1_shift = op_params.input1_shift;
+  int32 input2_offset = op_params.input2_offset;
+  int32 input2_multiplier = op_params.input2_multiplier;
+  int input2_shift = op_params.input2_shift;
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          const int32 input1_val =
+              input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+          const int32 input2_val =
+              input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+          const int32 shifted_input1_val = input1_val * (1 << left_shift);
+          const int32 shifted_input2_val = input2_val * (1 << left_shift);
+          const int32 scaled_input1_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input1_val, input1_multiplier, input1_shift);
+          const int32 scaled_input2_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input2_val, input2_multiplier, input2_shift);
+          output_data[Offset(output_shape, b, y, x, c)] =
+              F(scaled_input1_val, scaled_input2_val);
+        }
+      }
+    }
+  }
+}
+
+#define TFLITE_COMPARISON_OP(name)                                             \
+  inline void name(const ComparisonParams& op_params,                          \
+                   const RuntimeShape& input1_shape, const float* input1_data, \
+                   const RuntimeShape& input2_shape, const float* input2_data, \
+                   const RuntimeShape& output_shape, bool* output_data) {      \
+    gemmlowp::ScopedProfilingLabel label(#name);                               \
+    Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape,   \
+                         input2_data, output_shape, output_data);              \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void name##NoScaling(                                                 \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label(#name "NoScaling");                   \
+    ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data,          \
+                                input2_shape, input2_data, output_shape,       \
+                                output_data);                                  \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void name##WithScaling(                                               \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit");            \
+    ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data,   \
+                                       input2_shape, input2_data,              \
+                                       output_shape, output_data);             \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void Broadcast4DSlow##name##NoScaling(                                \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
+    BroadcastComparison4DSlowImpl<T, name##Fn>(                                \
+        op_params, input1_shape, input1_data, input2_shape, input2_data,       \
+        output_shape, output_data);                                            \
+  }                                                                            \
+  inline void Broadcast4DSlow##name(                                           \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const float* input1_data, const RuntimeShape& input2_shape,              \
+      const float* input2_data, const RuntimeShape& output_shape,              \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name);             \
+    BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data,  \
+                                        input2_shape, input2_data,             \
+                                        output_shape, output_data);            \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void Broadcast4DSlow##name##WithScaling(                              \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit");     \
+    BroadcastComparison4DSlowWithScaling<T, name##Fn>(                         \
+        op_params, input1_shape, input1_data, input2_shape, input2_data,       \
+        output_shape, output_data);                                            \
+  }
+TFLITE_COMPARISON_OP(Equal);
+TFLITE_COMPARISON_OP(NotEqual);
+TFLITE_COMPARISON_OP(Greater);
+TFLITE_COMPARISON_OP(GreaterEqual);
+TFLITE_COMPARISON_OP(Less);
+TFLITE_COMPARISON_OP(LessEqual);
+#undef TFLITE_COMPARISON_OP
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
diff --git a/tensorflow/lite/kernels/internal/reference/floor.h b/tensorflow/lite/kernels/internal/reference/floor.h
new file mode 100644
index 0000000..0693fd4
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/floor.h
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_H_
+
+#include <cmath>
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+                  const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; i++) {
+    int offset = i;
+    output_data[offset] = std::floor(input_data[offset]);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_H_
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
index dad17fb..9c629ff 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h
@@ -30,9 +30,9 @@
     const int32 input2_val = params.input2_offset + input2_data[i];
     const int32 unclamped_result =
         params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
-                                                       params.output_multiplier,
-                                                       params.output_shift);
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
     const int32 clamped_output =
         std::min(params.quantized_activation_max,
                  std::max(params.quantized_activation_min, unclamped_result));
@@ -112,9 +112,9 @@
               input2_data[SubscriptToIndex(desc2, b, y, x, c)];
           const int32 unclamped_result =
               params.output_offset +
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  input1_val * input2_val, params.output_multiplier,
-                  params.output_shift);
+              MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                            params.output_multiplier,
+                                            params.output_shift);
           const int32 clamped_output = std::min(
               params.quantized_activation_max,
               std::max(params.quantized_activation_min, unclamped_result));
diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
index 082f86e..615abdf 100644
--- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -2192,25 +2192,6 @@
                      DimsToShape(output_dims), output_data);
 }
 
-inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
-                    const bool* input2_data, const Dims<4>& input2_dims,
-                    bool* output_data, const Dims<4>& output_dims,
-                    const std::function<bool(bool, bool)>& func) {
-  Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
-          input2_data, DimsToShape(output_dims), output_data, func);
-}
-
-inline void BroadcastLogical(const bool* input1_data,
-                             const Dims<4>& input1_dims,
-                             const bool* input2_data,
-                             const Dims<4>& input2_dims, bool* output_data,
-                             const Dims<4>& output_dims,
-                             const std::function<bool(bool, bool)>& func) {
-  BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
-                         DimsToShape(input2_dims), input2_data,
-                         DimsToShape(output_dims), output_data, func);
-}
-
 // R: Result type. T1: Input 1 type. T2: Input 2 type.
 template <typename R, typename T1, typename T2>
 inline void BroadcastBinaryFunction(const T1* input1_data,
diff --git a/tensorflow/lite/kernels/internal/reference/maximum_minimum.h b/tensorflow/lite/kernels/internal/reference/maximum_minimum.h
new file mode 100644
index 0000000..480069a
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/maximum_minimum.h
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MAXIMUM_MINIMUM_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MAXIMUM_MINIMUM_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+template <typename T, typename Op>
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
+                                   const T* input1_data,
+                                   const RuntimeShape& unextended_input2_shape,
+                                   const T* input2_data,
+                                   const RuntimeShape& unextended_output_shape,
+                                   T* output_data, Op op) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          auto out_idx = Offset(output_shape, b, y, x, c);
+          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+          auto in1_val = input1_data[in1_idx];
+          auto in2_val = input2_data[in2_idx];
+          output_data[out_idx] = op(in1_val, in2_val);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MAXIMUM_MINIMUM_H_
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index 472425e..0b91677 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -301,14 +301,6 @@
   }
 }
 
-void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) {
-  TF_LITE_ASSERT(v_size > 0);
-  for (int i = 0; i < v_size - 1; i++) {
-    vector[i] = vector[i + 1];
-  }
-  vector[v_size - 1] = shift_value;
-}
-
 void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                 int output_size, int reduction_size) {
   const float* input_vector_ptr = input_vector;
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index 28ca981..fb24c9b 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -155,10 +155,6 @@
   PortableClipVector(vector, v_size, abs_limit, result);
 }
 
-void VectorShiftLeft(float* vector, int v_size, float shift_value) {
-  PortableVectorShiftLeft(vector, v_size, shift_value);
-}
-
 void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size) {
   PortableReductionSumVector(input_vector, output_vector, output_size,
diff --git a/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h b/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h
new file mode 100644
index 0000000..d903022
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h
@@ -0,0 +1,119 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+// Return true for broadcast case, false otherwise.
+inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
+                                   const RuntimeShape& shape1,
+                                   tflite::ArithmeticParams* params) {
+  const int dims_count =
+      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
+
+  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+  RuntimeShape scalar_shape(dims_count, 1);
+
+  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
+  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
+
+  // Check for "exact" match, implicitly accepting any scalar shapes.
+  if (extended_shape0 == extended_shape1) {
+    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
+    return false;
+  }
+
+  for (int i = dims_count - 1; i >= 0; --i) {
+    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
+      continue;
+    } else if (extended_shape0.Dims(i) == 1) {
+      params->broadcast_category =
+          BroadcastableOpCategory::kFirstInputBroadcastsFast;
+      break;
+    } else if (extended_shape1.Dims(i) == 1) {
+      params->broadcast_category =
+          BroadcastableOpCategory::kSecondInputBroadcastsFast;
+      break;
+    } else {
+      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+      break;
+    }
+  }
+
+  if (params->broadcast_category !=
+          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
+      params->broadcast_category !=
+          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
+    return false;
+  }
+
+  // From this point it is assumed contractually that corresponding dimensions
+  // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
+  const bool swap_inputs = params->broadcast_category ==
+                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
+  const RuntimeShape* shape_a =
+      swap_inputs ? &extended_shape1 : &extended_shape0;
+  const RuntimeShape* shape_b =
+      swap_inputs ? &extended_shape0 : &extended_shape1;
+
+  int i = dims_count - 1;
+  params->broadcast_shape[0] = 1;
+  params->broadcast_shape[1] = 1;
+  params->broadcast_shape[2] = 1;
+  params->broadcast_shape[3] = 1;
+  params->broadcast_shape[4] = 1;
+  // y_0 is greedy: include dims if both or neither equal 1: in other words,
+  // test for equality rather than (shape_a->Dims(i) != 1).
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[4] *= shape_b->Dims(i);
+    --i;
+  }
+  // Here either input_a or input_b has dim of 1 (if i >= 0).  If it is input_b
+  // that has the unit dimension, the next two loops are not entered.
+  while (i >= 0 && shape_a->Dims(i) == 1) {
+    params->broadcast_shape[3] *= shape_b->Dims(i);
+    --i;
+  }
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[2] *= shape_a->Dims(i);
+    --i;
+  }
+  // Here either input_a or input_b has dim of 1 (if i >= 0).
+  while (i >= 0 && shape_b->Dims(i) == 1) {
+    params->broadcast_shape[1] *= shape_a->Dims(i);
+    --i;
+  }
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[0] *= shape_b->Dims(i);
+    --i;
+  }
+
+  // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
+  // loop.
+  if (i >= 0) {
+    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+  }
+  return true;
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 92b3b47..91fd8d9 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -32,10 +32,17 @@
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/add.h"
+#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
+#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
+#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
 #include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/kernels/internal/reference/floor.h"
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
+#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
 #include "tensorflow/lite/kernels/internal/round.h"
@@ -47,98 +54,6 @@
 
 namespace reference_ops {
 
-// Return true for broadcast case, false otherwise.
-inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
-                                   const RuntimeShape& shape1,
-                                   tflite::ArithmeticParams* params) {
-  const int dims_count =
-      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
-
-  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
-  RuntimeShape scalar_shape(dims_count, 1);
-
-  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
-  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
-
-  // Check for "exact" match, implicitly accepting any scalar shapes.
-  if (extended_shape0 == extended_shape1) {
-    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
-    return false;
-  }
-
-  for (int i = dims_count - 1; i >= 0; --i) {
-    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
-      continue;
-    } else if (extended_shape0.Dims(i) == 1) {
-      params->broadcast_category =
-          BroadcastableOpCategory::kFirstInputBroadcastsFast;
-      break;
-    } else if (extended_shape1.Dims(i) == 1) {
-      params->broadcast_category =
-          BroadcastableOpCategory::kSecondInputBroadcastsFast;
-      break;
-    } else {
-      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
-      break;
-    }
-  }
-
-  if (params->broadcast_category !=
-          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
-      params->broadcast_category !=
-          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
-    return false;
-  }
-
-  // From this point it is assumed contractually that corresponding dimensions
-  // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
-  const bool swap_inputs = params->broadcast_category ==
-                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
-  const RuntimeShape* shape_a =
-      swap_inputs ? &extended_shape1 : &extended_shape0;
-  const RuntimeShape* shape_b =
-      swap_inputs ? &extended_shape0 : &extended_shape1;
-
-  int i = dims_count - 1;
-  params->broadcast_shape[0] = 1;
-  params->broadcast_shape[1] = 1;
-  params->broadcast_shape[2] = 1;
-  params->broadcast_shape[3] = 1;
-  params->broadcast_shape[4] = 1;
-  // y_0 is greedy: include dims if both or neither equal 1: in other words,
-  // test for equality rather than (shape_a->Dims(i) != 1).
-  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
-    params->broadcast_shape[4] *= shape_b->Dims(i);
-    --i;
-  }
-  // Here either input_a or input_b has dim of 1 (if i >= 0).  If it is input_b
-  // that has the unit dimension, the next two loops are not entered.
-  while (i >= 0 && shape_a->Dims(i) == 1) {
-    params->broadcast_shape[3] *= shape_b->Dims(i);
-    --i;
-  }
-  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
-    params->broadcast_shape[2] *= shape_a->Dims(i);
-    --i;
-  }
-  // Here either input_a or input_b has dim of 1 (if i >= 0).
-  while (i >= 0 && shape_b->Dims(i) == 1) {
-    params->broadcast_shape[1] *= shape_a->Dims(i);
-    --i;
-  }
-  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
-    params->broadcast_shape[0] *= shape_b->Dims(i);
-    --i;
-  }
-
-  // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
-  // loop.
-  if (i >= 0) {
-    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
-  }
-  return true;
-}
-
 template <typename T>
 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
                          const RuntimeShape& unextended_input_shape,
@@ -403,32 +318,6 @@
   }
 }
 
-template <typename T>
-inline void Add(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const T* input1_data,
-                const RuntimeShape& input2_shape, const T* input2_data,
-                const RuntimeShape& output_shape, T* output_data) {
-  const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = ActivationFunctionWithMinMax(
-        input1_data[i] + input2_data[i], params.quantized_activation_min,
-        params.quantized_activation_max);
-  }
-}
-
-inline void Add(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const float* input1_data,
-                const RuntimeShape& input2_shape, const float* input2_data,
-                const RuntimeShape& output_shape, float* output_data) {
-  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < size; i++) {
-    auto x = input1_data[i] + input2_data[i];
-    output_data[i] = ActivationFunctionWithMinMax(
-        x, params.float_activation_min, params.float_activation_max);
-  }
-}
-
 // T is expected to be either float or int.
 template <typename T>
 inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
@@ -445,373 +334,6 @@
   }
 }
 
-// Element-wise add that can often be used for inner loop of broadcast add as
-// well as the non-broadcast add.
-inline void AddElementwise(int size, const ArithmeticParams& params,
-                           const uint8* input1_data, const uint8* input2_data,
-                           uint8* output_data) {
-  TFLITE_DCHECK_GT(params.input1_offset, -256);
-  TFLITE_DCHECK_GT(params.input2_offset, -256);
-  TFLITE_DCHECK_LT(params.input1_offset, 256);
-  TFLITE_DCHECK_LT(params.input2_offset, 256);
-
-  for (int i = 0; i < size; ++i) {
-    const int32 input1_val = params.input1_offset + input1_data[i];
-    const int32 input2_val = params.input2_offset + input2_data[i];
-    const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
-    const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
-    const int32 scaled_input1_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input1_val, params.input1_multiplier, params.input1_shift);
-    const int32 scaled_input2_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input2_val, params.input2_multiplier, params.input2_shift);
-    const int32 raw_sum = scaled_input1_val + scaled_input2_val;
-    const int32 raw_output =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            raw_sum, params.output_multiplier, params.output_shift) +
-        params.output_offset;
-    const int32 clamped_output =
-        std::min(params.quantized_activation_max,
-                 std::max(params.quantized_activation_min, raw_output));
-    output_data[i] = static_cast<uint8>(clamped_output);
-  }
-}
-
-// Scalar-broadcast add that can be used for inner loop of more general
-// broadcast add, so that, for example, scalar-broadcast with batch will still
-// be fast.
-inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
-                               uint8 input1_data, const uint8* input2_data,
-                               uint8* output_data) {
-  TFLITE_DCHECK_GT(params.input1_offset, -256);
-  TFLITE_DCHECK_GT(params.input2_offset, -256);
-  TFLITE_DCHECK_LT(params.input1_offset, 256);
-  TFLITE_DCHECK_LT(params.input2_offset, 256);
-
-  const int32 input1_val = params.input1_offset + input1_data;
-  const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
-  const int32 scaled_input1_val =
-      MultiplyByQuantizedMultiplierSmallerThanOneExp(
-          shifted_input1_val, params.input1_multiplier, params.input1_shift);
-  for (int i = 0; i < size; ++i) {
-    const int32 input2_val = params.input2_offset + input2_data[i];
-    const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
-    const int32 scaled_input2_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input2_val, params.input2_multiplier, params.input2_shift);
-    const int32 raw_sum = scaled_input1_val + scaled_input2_val;
-    const int32 raw_output =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            raw_sum, params.output_multiplier, params.output_shift) +
-        params.output_offset;
-    const int32 clamped_output =
-        std::min(params.quantized_activation_max,
-                 std::max(params.quantized_activation_min, raw_output));
-    output_data[i] = static_cast<uint8>(clamped_output);
-  }
-}
-
-inline void Add(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const uint8* input1_data,
-                const RuntimeShape& input2_shape, const uint8* input2_data,
-                const RuntimeShape& output_shape, uint8* output_data) {
-  TFLITE_DCHECK_LE(params.quantized_activation_min,
-                   params.quantized_activation_max);
-  const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-
-  TFLITE_DCHECK_GT(params.input1_offset, -256);
-  TFLITE_DCHECK_GT(params.input2_offset, -256);
-  TFLITE_DCHECK_LT(params.input1_offset, 256);
-  TFLITE_DCHECK_LT(params.input2_offset, 256);
-  AddElementwise(flat_size, params, input1_data, input2_data, output_data);
-}
-
-inline void Add(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const int16* input1_data,
-                const RuntimeShape& input2_shape, const int16* input2_data,
-                const RuntimeShape& output_shape, int16* output_data) {
-  TFLITE_DCHECK_LE(params.quantized_activation_min,
-                   params.quantized_activation_max);
-
-  const int input1_shift = params.input1_shift;
-  const int flat_size =
-      MatchingFlatSize(output_shape, input1_shape, input2_shape);
-  const int16 output_activation_min = params.quantized_activation_min;
-  const int16 output_activation_max = params.quantized_activation_max;
-
-  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
-  TFLITE_DCHECK_LE(input1_shift, 0);
-  TFLITE_DCHECK_LE(params.input2_shift, 0);
-  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
-  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
-  const int input_right_shift =
-      input1_shift == 0 ? -params.input2_shift : -input1_shift;
-
-  for (int i = 0; i < flat_size; i++) {
-    // F0 uses 0 integer bits, range [-1, 1].
-    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
-
-    F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
-    F0 scaled_input = F0::FromRaw(
-        gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
-    F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
-    const int16 raw_output = result.raw();
-    const int16 clamped_output = std::min(
-        output_activation_max, std::max(output_activation_min, raw_output));
-    output_data[i] = clamped_output;
-  }
-}
-
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
-                               const RuntimeShape& input1_shape,
-                               const float* input1_data,
-                               const RuntimeShape& input2_shape,
-                               const float* input2_data,
-                               const RuntimeShape& output_shape,
-                               float* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
-                                      &desc2);
-  const RuntimeShape extended_output_shape =
-      RuntimeShape::ExtendedShape(4, output_shape);
-
-  // In Tensorflow, the dimensions are canonically named (batch_number, row,
-  // col, channel), with extents (batches, height, width, depth), with the
-  // trailing dimension changing most rapidly (channels has the smallest stride,
-  // typically 1 element).
-  //
-  // In generated C code, we store arrays with the dimensions reversed. The
-  // first dimension has smallest stride.
-  //
-  // We name our variables by their Tensorflow convention, but generate C code
-  // nesting loops such that the innermost loop has the smallest stride for the
-  // best cache behavior.
-  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
-    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
-      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
-        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
-          output_data[Offset(extended_output_shape, b, y, x, c)] =
-              ActivationFunctionWithMinMax(
-                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
-                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
-                  params.float_activation_min, params.float_activation_max);
-        }
-      }
-    }
-  }
-}
-
-inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
-                               const RuntimeShape& input1_shape,
-                               const int32* input1_data,
-                               const RuntimeShape& input2_shape,
-                               const int32* input2_data,
-                               const RuntimeShape& output_shape,
-                               int32* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
-                                      &desc2);
-  const RuntimeShape extended_output_shape =
-      RuntimeShape::ExtendedShape(4, output_shape);
-
-  // In Tensorflow, the dimensions are canonically named (batch_number, row,
-  // col, channel), with extents (batches, height, width, depth), with the
-  // trailing dimension changing most rapidly (channels has the smallest stride,
-  // typically 1 element).
-  //
-  // In generated C code, we store arrays with the dimensions reversed. The
-  // first dimension has smallest stride.
-  //
-  // We name our variables by their Tensorflow convention, but generate C code
-  // nesting loops such that the innermost loop has the smallest stride for the
-  // best cache behavior.
-  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
-    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
-      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
-        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
-          output_data[Offset(extended_output_shape, b, y, x, c)] =
-              ActivationFunctionWithMinMax(
-                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
-                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
-                  params.quantized_activation_min,
-                  params.quantized_activation_max);
-        }
-      }
-    }
-  }
-}
-
-inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
-                               const RuntimeShape& input1_shape,
-                               const uint8* input1_data,
-                               const RuntimeShape& input2_shape,
-                               const uint8* input2_data,
-                               const RuntimeShape& output_shape,
-                               uint8* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
-                                      &desc2);
-  const RuntimeShape extended_output_shape =
-      RuntimeShape::ExtendedShape(4, output_shape);
-
-  // In Tensorflow, the dimensions are canonically named (batch_number, row,
-  // col, channel), with extents (batches, height, width, depth), with the
-  // trailing dimension changing most rapidly (channels has the smallest stride,
-  // typically 1 element).
-  //
-  // In generated C code, we store arrays with the dimensions reversed. The
-  // first dimension has smallest stride.
-  //
-  // We name our variables by their Tensorflow convention, but generate C code
-  // nesting loops such that the innermost loop has the smallest stride for the
-  // best cache behavior.
-  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
-    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
-      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
-        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
-          const int32 input1_val =
-              params.input1_offset +
-              input1_data[SubscriptToIndex(desc1, b, y, x, c)];
-          const int32 input2_val =
-              params.input2_offset +
-              input2_data[SubscriptToIndex(desc2, b, y, x, c)];
-          const int32 shifted_input1_val =
-              input1_val * (1 << params.left_shift);
-          const int32 shifted_input2_val =
-              input2_val * (1 << params.left_shift);
-          const int32 scaled_input1_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input1_val, params.input1_multiplier,
-                  params.input1_shift);
-          const int32 scaled_input2_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input2_val, params.input2_multiplier,
-                  params.input2_shift);
-          const int32 raw_sum = scaled_input1_val + scaled_input2_val;
-          const int32 raw_output =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  raw_sum, params.output_multiplier, params.output_shift) +
-              params.output_offset;
-          const int32 clamped_output =
-              std::min(params.quantized_activation_max,
-                       std::max(params.quantized_activation_min, raw_output));
-          output_data[Offset(extended_output_shape, b, y, x, c)] =
-              static_cast<uint8>(clamped_output);
-        }
-      }
-    }
-  }
-}
-
-inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
-                                 const RuntimeShape& unswitched_input1_shape,
-                                 const uint8* unswitched_input1_data,
-                                 const RuntimeShape& unswitched_input2_shape,
-                                 const uint8* unswitched_input2_data,
-                                 const RuntimeShape& output_shape,
-                                 uint8* output_data) {
-  ArithmeticParams switched_params = unswitched_params;
-  switched_params.input1_offset = unswitched_params.input2_offset;
-  switched_params.input1_multiplier = unswitched_params.input2_multiplier;
-  switched_params.input1_shift = unswitched_params.input2_shift;
-  switched_params.input2_offset = unswitched_params.input1_offset;
-  switched_params.input2_multiplier = unswitched_params.input1_multiplier;
-  switched_params.input2_shift = unswitched_params.input1_shift;
-
-  const bool use_unswitched =
-      unswitched_params.broadcast_category ==
-      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
-
-  const ArithmeticParams& params =
-      use_unswitched ? unswitched_params : switched_params;
-  const uint8* input1_data =
-      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
-  const uint8* input2_data =
-      use_unswitched ? unswitched_input2_data : unswitched_input1_data;
-
-  // Fivefold nested loops. The second input resets its position for each
-  // iteration of the second loop. The first input resets its position at the
-  // beginning of the fourth loop. The innermost loop is an elementwise add of
-  // sections of the arrays.
-  uint8* output_data_ptr = output_data;
-  const uint8* input1_data_ptr = input1_data;
-  const uint8* input2_data_reset = input2_data;
-  // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
-  // between input shapes. y3 for input 1 is always broadcast, and so the
-  // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
-  // Put another way,
-  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
-  // input2.shape.FlatSize = y0 * y2 * y3 * y4.
-  int y0 = params.broadcast_shape[0];
-  int y1 = params.broadcast_shape[1];
-  int y2 = params.broadcast_shape[2];
-  int y3 = params.broadcast_shape[3];
-  int y4 = params.broadcast_shape[4];
-  if (y4 > 1) {
-    // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
-    // dimension.
-    for (int i0 = 0; i0 < y0; ++i0) {
-      const uint8* input2_data_ptr;
-      for (int i1 = 0; i1 < y1; ++i1) {
-        input2_data_ptr = input2_data_reset;
-        for (int i2 = 0; i2 < y2; ++i2) {
-          for (int i3 = 0; i3 < y3; ++i3) {
-            AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
-                           output_data_ptr);
-            input2_data_ptr += y4;
-            output_data_ptr += y4;
-          }
-          // We have broadcast y4 of input1 data y3 times, and now move on.
-          input1_data_ptr += y4;
-        }
-      }
-      // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
-      input2_data_reset = input2_data_ptr;
-    }
-  } else {
-    // Special case of y4 == 1, in which the innermost loop is a single element
-    // and can be combined with the next (y3) as an inner broadcast.
-    //
-    // Note that this handles the case of pure scalar broadcast when
-    // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
-    // broadcast with batch (as y2 > 1).
-    //
-    // NOTE The process is the same as the above general case except simplified
-    // for y4 == 1 and the loop over y3 is contained within the
-    // AddScalarBroadcast function.
-    for (int i0 = 0; i0 < y0; ++i0) {
-      const uint8* input2_data_ptr;
-      for (int i1 = 0; i1 < y1; ++i1) {
-        input2_data_ptr = input2_data_reset;
-        for (int i2 = 0; i2 < y2; ++i2) {
-          AddScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
-                             output_data_ptr);
-          input2_data_ptr += y3;
-          output_data_ptr += y3;
-          input1_data_ptr += 1;
-        }
-      }
-      input2_data_reset = input2_data_ptr;
-    }
-  }
-}
-
 template <typename T>
 inline void Mul(const ArithmeticParams& params,
                 const RuntimeShape& input1_shape, const T* input1_data,
@@ -898,9 +420,9 @@
     const int32 input2_val = params.input2_offset + input2_data[i];
     const int32 unclamped_result =
         params.output_offset +
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
-                                                       params.output_multiplier,
-                                                       params.output_shift);
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
     const int32 clamped_output =
         std::min(params.quantized_activation_max,
                  std::max(params.quantized_activation_min, unclamped_result));
@@ -1002,9 +524,9 @@
               input2_data[SubscriptToIndex(desc2, b, y, x, c)];
           const int32 unclamped_result =
               params.output_offset +
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  input1_val * input2_val, params.output_multiplier,
-                  params.output_shift);
+              MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                            params.output_multiplier,
+                                            params.output_shift);
           const int32 clamped_output = std::min(
               params.quantized_activation_max,
               std::max(params.quantized_activation_min, unclamped_result));
@@ -2468,7 +1990,6 @@
   Tanh(input_shape, input_data, output_shape, output_data);
 }
 
-
 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
                  const int16* input_data, const RuntimeShape& output_shape,
                  int16* output_data) {
@@ -2632,21 +2153,11 @@
                                             std::modulus<T>, FloatMod>::type;
   ModFunc mod_func;
   T trunc_mod = mod_func(input1, input2);
-  return trunc_mod != 0 && ((input2 < 0) != (trunc_mod < 0))
-             ? trunc_mod + input2
+  return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0))
+             ? (trunc_mod + input2)
              : trunc_mod;
 }
 
-inline void Floor(const RuntimeShape& input_shape, const float* input_data,
-                  const RuntimeShape& output_shape, float* output_data) {
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
-  for (int i = 0; i < flat_size; i++) {
-    int offset = i;
-    output_data[offset] = std::floor(input_data[offset]);
-  }
-}
-
 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3546,87 +3057,6 @@
   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
 }
 
-template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
-                                   const T* input1_data,
-                                   const RuntimeShape& unextended_input2_shape,
-                                   const T* input2_data,
-                                   const RuntimeShape& unextended_output_shape,
-                                   T* output_data, Op op) {
-  gemmlowp::ScopedProfilingLabel label("MaximumMinimumBroadcast4DSlow");
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          auto out_idx = Offset(output_shape, b, y, x, c);
-          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
-          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
-          auto in1_val = input1_data[in1_idx];
-          auto in2_val = input2_data[in2_idx];
-          output_data[out_idx] = op(in1_val, in2_val);
-        }
-      }
-    }
-  }
-}
-
-template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
-               const T3* input2_data, const RuntimeShape& output_shape,
-               T2* output_data, const Cmp& cmp) {
-  gemmlowp::ScopedProfilingLabel label("ArgMinMax");
-  TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
-  TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
-                   output_shape.DimensionsCount());
-
-  int axis = input2_data[0];
-  if (axis < 0) {
-    axis += input1_shape.DimensionsCount();
-  }
-
-  const int axis_size = input1_shape.Dims(axis);
-
-  int outer_size = 1;
-  for (int i = 0; i < axis; ++i) {
-    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
-    outer_size *= input1_shape.Dims(i);
-  }
-
-  int inner_size = 1;
-  const int dims_count = input1_shape.DimensionsCount();
-  for (int i = axis + 1; i < dims_count; ++i) {
-    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
-    inner_size *= input1_shape.Dims(i);
-  }
-
-  for (int outer = 0; outer < outer_size; ++outer) {
-    for (int inner = 0; inner < inner_size; ++inner) {
-      auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
-      int min_max_index = 0;
-      for (int i = 1; i < axis_size; ++i) {
-        const auto& curr_value =
-            input1_data[(outer * axis_size + i) * inner_size + inner];
-        if (cmp(curr_value, min_max_value)) {
-          min_max_value = curr_value;
-          min_max_index = i;
-        }
-      }
-      output_data[outer * inner_size + inner] = min_max_index;
-    }
-  }
-}
-
 template <typename T1, typename T2, typename T3>
 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
             const T3* input2_data, const RuntimeShape& output_shape,
@@ -3856,253 +3286,6 @@
   }
 }
 
-template <typename T>
-inline bool EqualFn(T lhs, T rhs) {
-  return lhs == rhs;
-}
-
-template <typename T>
-inline bool NotEqualFn(T lhs, T rhs) {
-  return lhs != rhs;
-}
-
-template <typename T>
-inline bool GreaterFn(T lhs, T rhs) {
-  return lhs > rhs;
-}
-template <typename T>
-inline bool GreaterEqualFn(T lhs, T rhs) {
-  return lhs >= rhs;
-}
-template <typename T>
-inline bool LessFn(T lhs, T rhs) {
-  return lhs < rhs;
-}
-template <typename T>
-inline bool LessEqualFn(T lhs, T rhs) {
-  return lhs <= rhs;
-}
-
-template <typename T>
-using ComparisonFn = bool (*)(T, T);
-
-template <typename T, ComparisonFn<T> F>
-inline void ComparisonImpl(
-    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
-    const T* input1_data, const RuntimeShape& input2_shape,
-    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
-  const int64_t flatsize =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int64_t i = 0; i < flatsize; ++i) {
-    output_data[i] = F(input1_data[i], input2_data[i]);
-  }
-}
-
-template <ComparisonFn<float> F>
-inline void Comparison(const ComparisonParams& op_params,
-                       const RuntimeShape& input1_shape,
-                       const float* input1_data,
-                       const RuntimeShape& input2_shape,
-                       const float* input2_data,
-                       const RuntimeShape& output_shape, bool* output_data) {
-  ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
-                           input2_data, output_shape, output_data);
-}
-
-template <typename T, ComparisonFn<int32> F>
-inline void ComparisonWithScaling(
-    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
-    const T* input1_data, const RuntimeShape& input2_shape,
-    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
-  int left_shift = op_params.left_shift;
-  int32 input1_offset = op_params.input1_offset;
-  int32 input1_multiplier = op_params.input1_multiplier;
-  int input1_shift = op_params.input1_shift;
-  int32 input2_offset = op_params.input2_offset;
-  int32 input2_multiplier = op_params.input2_multiplier;
-  int input2_shift = op_params.input2_shift;
-
-  const int64_t flatsize =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int64_t i = 0; i < flatsize; ++i) {
-    const int32 input1_val = input1_offset + input1_data[i];
-    const int32 input2_val = input2_offset + input2_data[i];
-    const int32 shifted_input1_val = input1_val * (1 << left_shift);
-    const int32 shifted_input2_val = input2_val * (1 << left_shift);
-    const int32 scaled_input1_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input1_val, input1_multiplier, input1_shift);
-    const int32 scaled_input2_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input2_val, input2_multiplier, input2_shift);
-    output_data[i] = F(scaled_input1_val, scaled_input2_val);
-  }
-}
-
-template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison4DSlowImpl(
-    const ComparisonParams& op_params,
-    const RuntimeShape& unextended_input1_shape, const T* input1_data,
-    const RuntimeShape& unextended_input2_shape, const T* input2_data,
-    const RuntimeShape& unextended_output_shape, bool* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          output_data[Offset(output_shape, b, y, x, c)] =
-              F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
-                input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
-        }
-      }
-    }
-  }
-}
-template <ComparisonFn<float> F>
-inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
-                                      const RuntimeShape& input1_shape,
-                                      const float* input1_data,
-                                      const RuntimeShape& input2_shape,
-                                      const float* input2_data,
-                                      const RuntimeShape& output_shape,
-                                      bool* output_data) {
-  BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
-                                          input2_shape, input2_data,
-                                          output_shape, output_data);
-}
-
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison4DSlowWithScaling(
-    const ComparisonParams& op_params,
-    const RuntimeShape& unextended_input1_shape, const T* input1_data,
-    const RuntimeShape& unextended_input2_shape, const T* input2_data,
-    const RuntimeShape& unextended_output_shape, bool* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  int left_shift = op_params.left_shift;
-  int32 input1_offset = op_params.input1_offset;
-  int32 input1_multiplier = op_params.input1_multiplier;
-  int input1_shift = op_params.input1_shift;
-  int32 input2_offset = op_params.input2_offset;
-  int32 input2_multiplier = op_params.input2_multiplier;
-  int input2_shift = op_params.input2_shift;
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          const int32 input1_val =
-              input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
-          const int32 input2_val =
-              input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
-          const int32 shifted_input1_val = input1_val * (1 << left_shift);
-          const int32 shifted_input2_val = input2_val * (1 << left_shift);
-          const int32 scaled_input1_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input1_val, input1_multiplier, input1_shift);
-          const int32 scaled_input2_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input2_val, input2_multiplier, input2_shift);
-          output_data[Offset(output_shape, b, y, x, c)] =
-              F(scaled_input1_val, scaled_input2_val);
-        }
-      }
-    }
-  }
-}
-
-#define TFLITE_COMPARISON_OP(name)                                             \
-  inline void name(const ComparisonParams& op_params,                          \
-                   const RuntimeShape& input1_shape, const float* input1_data, \
-                   const RuntimeShape& input2_shape, const float* input2_data, \
-                   const RuntimeShape& output_shape, bool* output_data) {      \
-    gemmlowp::ScopedProfilingLabel label(#name);                               \
-    Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape,   \
-                         input2_data, output_shape, output_data);              \
-  }                                                                            \
-  template <typename T>                                                        \
-  inline void name##NoScaling(                                                 \
-      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
-      const T* input1_data, const RuntimeShape& input2_shape,                  \
-      const T* input2_data, const RuntimeShape& output_shape,                  \
-      bool* output_data) {                                                     \
-    gemmlowp::ScopedProfilingLabel label(#name "NoScaling");                   \
-    ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data,          \
-                                input2_shape, input2_data, output_shape,       \
-                                output_data);                                  \
-  }                                                                            \
-  template <typename T>                                                        \
-  inline void name##WithScaling(                                               \
-      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
-      const T* input1_data, const RuntimeShape& input2_shape,                  \
-      const T* input2_data, const RuntimeShape& output_shape,                  \
-      bool* output_data) {                                                     \
-    gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit");            \
-    ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data,   \
-                                       input2_shape, input2_data,              \
-                                       output_shape, output_data);             \
-  }                                                                            \
-  template <typename T>                                                        \
-  inline void Broadcast4DSlow##name##NoScaling(                                \
-      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
-      const T* input1_data, const RuntimeShape& input2_shape,                  \
-      const T* input2_data, const RuntimeShape& output_shape,                  \
-      bool* output_data) {                                                     \
-    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
-    BroadcastComparison4DSlowImpl<T, name##Fn>(                                \
-        op_params, input1_shape, input1_data, input2_shape, input2_data,       \
-        output_shape, output_data);                                            \
-  }                                                                            \
-  inline void Broadcast4DSlow##name(                                           \
-      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
-      const float* input1_data, const RuntimeShape& input2_shape,              \
-      const float* input2_data, const RuntimeShape& output_shape,              \
-      bool* output_data) {                                                     \
-    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name);             \
-    BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data,  \
-                                        input2_shape, input2_data,             \
-                                        output_shape, output_data);            \
-  }                                                                            \
-  template <typename T>                                                        \
-  inline void Broadcast4DSlow##name##WithScaling(                              \
-      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
-      const T* input1_data, const RuntimeShape& input2_shape,                  \
-      const T* input2_data, const RuntimeShape& output_shape,                  \
-      bool* output_data) {                                                     \
-    gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit");     \
-    BroadcastComparison4DSlowWithScaling<T, name##Fn>(                         \
-        op_params, input1_shape, input1_data, input2_shape, input2_data,       \
-        output_shape, output_data);                                            \
-  }
-TFLITE_COMPARISON_OP(Equal);
-TFLITE_COMPARISON_OP(NotEqual);
-TFLITE_COMPARISON_OP(Greater);
-TFLITE_COMPARISON_OP(GreaterEqual);
-TFLITE_COMPARISON_OP(Less);
-TFLITE_COMPARISON_OP(LessEqual);
-#undef TFLITE_COMPARISON_OP
-
 template <typename D, typename T>
 void Select(const RuntimeShape& input_condition_shape,
             const D* input_condition_data, const RuntimeShape& input_x_shape,
@@ -4252,104 +3435,6 @@
   }
 }
 
-inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
-                    const RuntimeShape& input2_shape, const bool* input2_data,
-                    const RuntimeShape& output_shape, bool* output_data,
-                    const std::function<bool(bool, bool)>& func) {
-  const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = func(input1_data[i], input2_data[i]);
-  }
-}
-
-inline void BroadcastLogical4DSlow(
-    const RuntimeShape& unextended_input1_shape, const bool* input1_data,
-    const RuntimeShape& unextended_input2_shape, const bool* input2_data,
-    const RuntimeShape& unextended_output_shape, bool* output_data,
-    const std::function<bool(bool, bool)>& func) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          auto out_idx = Offset(output_shape, b, y, x, c);
-          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
-          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
-          auto in1_val = input1_data[in1_idx];
-          auto in2_val = input2_data[in2_idx];
-          output_data[out_idx] = func(in1_val, in2_val);
-        }
-      }
-    }
-  }
-}
-
-// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
-// generalized and efficient BroadcastBinaryFunction.
-//
-// Also appears to duplicte MinimumMaximum.
-//
-// R: Result type. T1: Input 1 type. T2: Input 2 type.
-template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction4DSlow(
-    const RuntimeShape& unextended_input1_shape, const T1* input1_data,
-    const RuntimeShape& unextended_input2_shape, const T2* input2_data,
-    const RuntimeShape& unextended_output_shape, R* output_data,
-    R (*func)(T1, T2)) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          auto out_idx = Offset(output_shape, b, y, x, c);
-          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
-          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
-          auto in1_val = input1_data[in1_idx];
-          auto in2_val = input2_data[in2_idx];
-          output_data[out_idx] = func(in1_val, in2_val);
-        }
-      }
-    }
-  }
-}
-
-// R: Result type. T1: Input 1 type. T2: Input 2 type.
-// TODO(renjieliu): Refactor other binary functions to use this one.
-template <typename R, typename T1, typename T2>
-inline void BinaryFunction(const RuntimeShape& input1_shape,
-                           const T1* input1_data,
-                           const RuntimeShape& input2_shape,
-                           const T2* input2_data,
-                           const RuntimeShape& output_shape, R* output_data,
-                           R (*func)(T1, T2)) {
-  const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = func(input1_data[i], input2_data[i]);
-  }
-}
-
 template <typename T>
 inline void ResizeNearestNeighbor(
     const tflite::ResizeNearestNeighborParams& op_params,
diff --git a/tensorflow/lite/kernels/internal/reference/svdf.h b/tensorflow/lite/kernels/internal/reference/svdf.h
new file mode 100644
index 0000000..fe2ea9f
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/svdf.h
@@ -0,0 +1,210 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SVDF_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SVDF_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+// SVDF op that compresses a fully connected op via low-rank matrix
+// factorization. See https://research.google.com/pubs/archive/43813.pdf for
+// details.
+
+namespace tflite {
+namespace reference_ops {
+
+static inline void ApplyTimeWeightsBiasAndActivation(
+    int batch_size, int memory_size, int num_filters, int num_units, int rank,
+    const TfLiteTensor* weights_time, const TfLiteTensor* bias,
+    TfLiteFusedActivation activation, TfLiteTensor* activation_state,
+    TfLiteTensor* scratch, TfLiteTensor* output) {
+  // Compute matmul(state, weights_time).
+  // The rightmost column is used to save temporary output (with the size of
+  // num_filters). This is achieved by starting at
+  // GetTensorData<float>(activation_state), and having the stride equal to
+  // memory_size.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch =
+        GetTensorData<float>(activation_state) + b * memory_size * num_filters;
+    float* scratch_ptr_batch = GetTensorData<float>(scratch) + b * num_filters;
+    tensor_utils::BatchVectorBatchVectorDotProduct(
+        GetTensorData<float>(weights_time), state_ptr_batch, memory_size,
+        num_filters, scratch_ptr_batch, /*result_stride=*/1);
+  }
+
+  // Initialize output with bias if provided.
+  if (bias) {
+    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
+                                          batch_size,
+                                          GetTensorData<float>(output));
+  } else {
+    tensor_utils::ZeroVector(GetTensorData<float>(output),
+                             batch_size * num_units);
+  }
+
+  // Reduction sum.
+  for (int b = 0; b < batch_size; ++b) {
+    float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
+    float* scratch_ptr_batch = GetTensorData<float>(scratch) + b * num_filters;
+    tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
+                                     num_units, rank);
+  }
+
+  // Apply activation.
+  for (int b = 0; b < batch_size; ++b) {
+    float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
+    tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
+                                          activation, output_ptr_batch);
+  }
+
+  // Left shift the activation_state to make room for next cycle's activation.
+  // TODO(alanchiao): explore collapsing this into a single loop.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch =
+        GetTensorData<float>(activation_state) + b * memory_size * num_filters;
+    for (int f = 0; f < num_filters; ++f) {
+      tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
+                                    /*shift_value=*/0.0f);
+      state_ptr_batch += memory_size;
+    }
+  }
+}
+
+inline void EvalFloatSVDF(TfLiteContext* context, TfLiteNode* node,
+                          const TfLiteTensor* input,
+                          const TfLiteTensor* weights_feature,
+                          const TfLiteTensor* weights_time,
+                          const TfLiteTensor* bias,
+                          const TfLiteSVDFParams* params, TfLiteTensor* scratch,
+                          TfLiteTensor* state, TfLiteTensor* output) {
+  const int rank = params->rank;
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  const int num_filters = weights_feature->dims->data[0];
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  // Clear the activation (state's leftmost column).
+  // TODO(ghodrat): Add a test which initialize activation_state with invalid
+  // values in leftmost column and make sure it passes.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch =
+        GetTensorData<float>(state) + b * memory_size * num_filters;
+    for (int c = 0; c < num_filters; ++c) {
+      float* state_ptr = state_ptr_batch + c * memory_size;
+      state_ptr[memory_size - 1] = 0.0f;
+    }
+  }
+
+  // Compute conv1d(inputs, weights_feature).
+  // The state's rightmost column is used to save current cycle activation. This
+  // is achieved by starting at GetTensorData<float>(state)[memory_size - 1] and
+  // having the stride equal to memory_size.
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      GetTensorData<float>(weights_feature), num_filters, input_size,
+      GetTensorData<float>(input), batch_size,
+      &GetTensorData<float>(state)[memory_size - 1], memory_size);
+
+  ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+                                    num_units, rank, weights_time, bias,
+                                    params->activation, state, scratch, output);
+}
+
+inline void EvalHybridSVDF(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+    const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
+    const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+    TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
+    TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
+  const int rank = params->rank;
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  const int num_filters = weights_feature->dims->data[0];
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  // Initialize the pointer to input.
+  const float* input_ptr_batch = GetTensorData<float>(input);
+
+  // Initialize the pointer to storage for quantized values and the weights
+  // feature.
+  int8_t* quantized_input_ptr_batch;
+  const int8_t* weights_feature_ptr;
+  if (weights_feature->type == kTfLiteUInt8) {
+    quantized_input_ptr_batch =
+        reinterpret_cast<int8_t*>(GetTensorData<uint8_t>(input_quantized));
+    weights_feature_ptr = reinterpret_cast<const int8_t*>(
+        GetTensorData<uint8_t>(weights_feature));
+  } else {
+    quantized_input_ptr_batch = GetTensorData<int8_t>(input_quantized);
+    weights_feature_ptr = GetTensorData<int8_t>(weights_feature);
+  }
+
+  // Initialize the pointer to storage for scaling factors.
+  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
+
+  // Initialize the weights scale.
+  const float weights_feature_scale = weights_feature->params.scale;
+
+  // Clear the activation (state's leftmost column).
+  // TODO(ghodrat): Add a test which initialize state with invalid values in
+  // the leftmost column and make sure it passes.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch =
+        GetTensorData<float>(state) + b * memory_size * num_filters;
+    for (int c = 0; c < num_filters; ++c) {
+      float* state_ptr = state_ptr_batch + c * memory_size;
+      state_ptr[memory_size - 1] = 0.0;
+    }
+  }
+
+  if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
+    // Quantize input from float to int8.
+    float unused_min, unused_max;
+    for (int b = 0; b < batch_size; ++b) {
+      const int offset = b * input_size;
+      tensor_utils::SymmetricQuantizeFloats(
+          input_ptr_batch + offset, input_size,
+          quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+          &scaling_factors_ptr[b]);
+      scaling_factors_ptr[b] *= weights_feature_scale;
+    }
+
+    // Compute conv1d(inputs, weights_feature).
+    // The rightmost column of state is used to save the current cycle
+    // activation.
+    // This is achieved by starting at GetTensorData<float>(state)[memory_size -
+    // 1] and having the stride equal to memory_size.
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
+        scaling_factors_ptr, batch_size,
+        &GetTensorData<float>(state)[memory_size - 1], memory_size);
+  }
+
+  // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying
+  // time weights so that the inner loop multiplies eight elements at a time.
+  ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+                                    num_units, rank, weights_time, bias,
+                                    params->activation, state, scratch, output);
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SVDF_H_
diff --git a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
index ea69f49..269dc98 100644
--- a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
@@ -124,9 +124,14 @@
                                                      input_beta_left_shift);
 
   SoftmaxParams params;
+  float table[256];
   params.input_multiplier = input_beta_multiplier;
   params.input_left_shift = input_beta_left_shift;
   params.diff_min = diff_min;
+  params.scale = 1.0f / 256;
+  params.zero_point = 0;
+  params.table = table;
+  optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
   optimized_ops::Softmax(params, shape_common, input_data, shape_common,
                          optimized_softmax_output.data());
   reference_ops::Softmax(params, shape_common, input_data, shape_common,
@@ -137,7 +142,7 @@
                            "Optimized vs float reference", false);
   CheckOutputData<uint8_t>(optimized_softmax_output.data(),
                            reference_quant_softmax_output.data(), shape_common,
-                           "Optimized vs quant reference", true);
+                           "Optimized vs quant reference", false);
   CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
                            reference_float_softmax_output.data(), shape_common,
                            "Quant reference vs float reference", false);
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index 8eba2f1..c2bd92c 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -15,6 +15,8 @@
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
 
+#include <algorithm>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 
 #if defined(_MSC_VER)
@@ -179,7 +181,13 @@
                 float* result);
 
 // Shift left a vector in place with v_size size.
-void VectorShiftLeft(float* vector, int v_size, float shift_value);
+template <typename T>
+void VectorShiftLeft(T* vector, int v_size, const T& shift_value) {
+  // When copying overlapping ranges, std::copy is appropriate when beginning of
+  // the destination range is outside the source range.
+  std::copy(vector + 1, vector + v_size, vector);
+  vector[v_size - 1] = shift_value;
+}
 
 // Reduce-sum on a float input vector:
 // input_vector: float pointer to input vector.
diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
index 0918c8d..5b07cf1 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
@@ -874,7 +874,7 @@
   constexpr int kVectorSize = 5;
   static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
   std::vector<float> result(kVectorSize);
-  VectorShiftLeft(input, kVectorSize, 3.0);
+  VectorShiftLeft(input, kVectorSize, 3.0f);
   result.assign(input, input + kVectorSize);
   EXPECT_THAT(result,
               ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0})));
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index b786bde..eb7b630 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -16,6 +16,7 @@
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
 
 #include <algorithm>
+#include <cstdint>
 #include <cstring>
 #include <initializer_list>
 
@@ -985,6 +986,9 @@
   int32 reverse_scaling_divisor;
   int32 reverse_scaling_right_shift;
   int diff_min;
+  int32_t zero_point;
+  float scale;
+  float* table;
 };
 
 struct SpaceToBatchParams {
diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index 7f5ab19..5fdb301 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -23,19 +23,6 @@
 
 namespace tflite {
 
-void GuardedQuantizeMultiplier(double effective_output_scale,
-                               int32_t* significand, int* shift) {
-  QuantizeMultiplier(effective_output_scale, significand, shift);
-  // Additional guard to make sure RoundingDivideByPOT does not fail.
-  if (*shift < -31) {
-    // If shift is less than -31, RoundingDivideByPOT fails. This happens when
-    // min and max are close and small. For this particular case, both
-    // significand and shift are set to zero.
-    *significand = 0;
-    *shift = 0;
-  }
-}
-
 TfLiteStatus PopulateConvolutionQuantizationParams(
     TfLiteContext* context, const TfLiteTensor* input,
     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
@@ -79,7 +66,7 @@
                                           static_cast<double>(output_scale);
     int32_t significand;
     int shift;
-    GuardedQuantizeMultiplier(effective_output_scale, &significand, &shift);
+    QuantizeMultiplier(effective_output_scale, &significand, &shift);
     per_channel_multiplier[i] = significand;
     per_channel_shift[i] = shift;
   }
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index a76d925..7eb2997 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -54,14 +54,18 @@
   return node->intermediates->size;
 }
 
-inline int64_t NumElements(const TfLiteTensor* t) {
+inline int64_t NumElements(const TfLiteIntArray* dims) {
   int64_t count = 1;
-  for (int i = 0; i < NumDimensions(t); ++i) {
-    count *= SizeOfDimension(t, i);
+  for (int i = 0; i < dims->size; ++i) {
+    count *= dims->data[i];
   }
   return count;
 }
 
+inline int64_t NumElements(const TfLiteTensor* t) {
+  return NumElements(t->dims);
+}
+
 inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
                                                   const TfLiteNode* node,
                                                   int index) {
@@ -72,14 +76,6 @@
   return nullptr;
 }
 
-inline int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
-  if (is_uint8) {
-    return reinterpret_cast<int8_t*>(tensor->data.uint8);
-  } else {
-    return tensor->data.int8;
-  }
-}
-
 // Determines whether tensor is constant.
 inline bool IsConstantTensor(const TfLiteTensor* tensor) {
   return tensor->allocation_type == kTfLiteMmapRo;
@@ -114,10 +110,6 @@
     int32_t* output_activation_min, int32_t* output_activation_max,
     int32_t* per_channel_multiplier, int* per_channel_shift);
 
-// QuantizedMultiplier with the guard that shift will not be smaller than -31.
-void GuardedQuantizeMultiplier(double effective_output_scale,
-                               int32_t* significand, int* shift);
-
 // Calculates the multiplication factor for a quantized convolution (or
 // quantized depthwise convolution) involving the given tensors. Returns an
 // error if the scales of the tensors are not compatible.
diff --git a/tensorflow/lite/kernels/kernel_util_test.cc b/tensorflow/lite/kernels/kernel_util_test.cc
index 04d559a..95759db 100644
--- a/tensorflow/lite/kernels/kernel_util_test.cc
+++ b/tensorflow/lite/kernels/kernel_util_test.cc
@@ -389,15 +389,9 @@
   auto* filter_params = reinterpret_cast<TfLiteAffineQuantization*>(
       malloc(sizeof(TfLiteAffineQuantization)));
   filter_params->scale = TfLiteFloatArrayCreate(3);
-  int32_t two_pow_neg_31 = 0x30000000;  // 2^-31 so shift = -30.
-  int32_t two_pow_neg_32 = 0x2F800000;  // 2^-32 so shift = -31.
-  int32_t two_pow_neg_33 = 0x2F000000;  // 2^-33 so shift = -32.
-  float* scale_date = reinterpret_cast<float*>(&two_pow_neg_31);
-  filter_params->scale->data[0] = *scale_date;
-  scale_date = reinterpret_cast<float*>(&two_pow_neg_32);
-  filter_params->scale->data[1] = *scale_date;
-  scale_date = reinterpret_cast<float*>(&two_pow_neg_33);
-  filter_params->scale->data[2] = *scale_date;
+  filter_params->scale->data[0] = std::ldexp(1.0f, -31);
+  filter_params->scale->data[1] = std::ldexp(1.0f, -32);
+  filter_params->scale->data[2] = std::ldexp(1.0f, -33);
   filter_params->zero_point = TfLiteIntArrayCreate(3);
   filter_params->zero_point->data[0] = 0;
   filter_params->zero_point->data[1] = 0;
@@ -416,9 +410,9 @@
   auto* bias_params = reinterpret_cast<TfLiteAffineQuantization*>(
       malloc(sizeof(TfLiteAffineQuantization)));
   bias_params->scale = TfLiteFloatArrayCreate(3);
-  bias_params->scale->data[0] = 4.6566129e-10;  // 2^-31
-  bias_params->scale->data[1] = 2.3283064e-10;  // 2^-32
-  bias_params->scale->data[2] = 1.1641532e-10;  // 2^-33
+  bias_params->scale->data[0] = std::ldexp(1.0f, -31);
+  bias_params->scale->data[1] = std::ldexp(1.0f, -32);
+  bias_params->scale->data[2] = std::ldexp(1.0f, -33);
   bias_params->zero_point = TfLiteIntArrayCreate(3);
   bias_params->zero_point->data[0] = 11;
   bias_params->zero_point->data[1] = 12;
diff --git a/tensorflow/lite/kernels/logical.cc b/tensorflow/lite/kernels/logical.cc
index 582bcff..7a2805d 100644
--- a/tensorflow/lite/kernels/logical.cc
+++ b/tensorflow/lite/kernels/logical.cc
@@ -78,7 +78,7 @@
 }
 
 TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
-                         const std::function<bool(bool, bool)>& func) {
+                         bool (*func)(bool, bool)) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
@@ -86,28 +86,30 @@
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   if (data->requires_broadcast) {
-    reference_ops::BroadcastLogical4DSlow(
+    reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
         GetTensorShape(input1), GetTensorData<bool>(input1),
         GetTensorShape(input2), GetTensorData<bool>(input2),
         GetTensorShape(output), GetTensorData<bool>(output), func);
   } else {
-    reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
-                           GetTensorShape(input2), GetTensorData<bool>(input2),
-                           GetTensorShape(output), GetTensorData<bool>(output),
-                           func);
+    reference_ops::BinaryFunction<bool, bool, bool>(
+        GetTensorShape(input1), GetTensorData<bool>(input1),
+        GetTensorShape(input2), GetTensorData<bool>(input2),
+        GetTensorShape(output), GetTensorData<bool>(output), func);
   }
 
   return kTfLiteOk;
 }
 
+bool LogicalOr(bool x, bool y) { return x || y; }
+
 TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
-  const auto logical_or_func = std::logical_or<bool>();
-  return LogicalImpl(context, node, logical_or_func);
+  return LogicalImpl(context, node, LogicalOr);
 }
 
+bool LogicalAnd(bool x, bool y) { return x && y; }
+
 TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
-  const auto logical_and_func = std::logical_and<bool>();
-  return LogicalImpl(context, node, logical_and_func);
+  return LogicalImpl(context, node, LogicalAnd);
 }
 
 }  // namespace
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index a518daf..960c770 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -24,8 +24,8 @@
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -1123,9 +1123,6 @@
     TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
     TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
     TfLiteTensor* cell_state, TfLiteTensor* output) {
-  // For operations that use int8 instead of uint8 we need to fetch raw data
-  // from the tensor different. We use this bool for that condition.
-  const bool is_uint8_hybrid = input_to_output_weights->type == kTfLiteUInt8;
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int n_input = input->dims->data[input->dims->size - 1];
   int max_time, n_batch;
@@ -1164,37 +1161,33 @@
   }
 
   // Check optional tensors, the respective pointers can be null.
-  int8_t* input_to_input_weights_ptr = nullptr;
+  const int8_t* input_to_input_weights_ptr = nullptr;
   float input_to_input_weights_scale = 1.0f;
-  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  const int8_t* recurrent_to_input_weights_ptr = nullptr;
   float recurrent_to_input_weights_scale = 1.0f;
   float* input_gate_bias_ptr = nullptr;
   if (!use_cifg) {
-    input_to_input_weights_ptr =
-        GetInt8DataPtr(input_to_input_weights, is_uint8_hybrid);
+    input_to_input_weights_ptr = GetTensorData<int8_t>(input_to_input_weights);
     recurrent_to_input_weights_ptr =
-        GetInt8DataPtr(recurrent_to_input_weights, is_uint8_hybrid);
+        GetTensorData<int8_t>(recurrent_to_input_weights);
     input_gate_bias_ptr = input_gate_bias->data.f;
     input_to_input_weights_scale = input_to_input_weights->params.scale;
     recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
   }
 
-  int8_t* cell_to_input_weights_ptr = nullptr;
-  int8_t* cell_to_forget_weights_ptr = nullptr;
-  int8_t* cell_to_output_weights_ptr = nullptr;
+  const int8_t* cell_to_input_weights_ptr = nullptr;
+  const int8_t* cell_to_forget_weights_ptr = nullptr;
+  const int8_t* cell_to_output_weights_ptr = nullptr;
   float cell_to_input_weights_scale = 1.0f;
   float cell_to_forget_weights_scale = 1.0f;
   float cell_to_output_weights_scale = 1.0f;
   if (use_peephole) {
     if (!use_cifg) {
-      cell_to_input_weights_ptr =
-          GetInt8DataPtr(cell_to_input_weights, is_uint8_hybrid);
+      cell_to_input_weights_ptr = GetTensorData<int8_t>(cell_to_input_weights);
       cell_to_input_weights_scale = cell_to_input_weights->params.scale;
     }
-    cell_to_forget_weights_ptr =
-        GetInt8DataPtr(cell_to_forget_weights, is_uint8_hybrid);
-    cell_to_output_weights_ptr =
-        GetInt8DataPtr(cell_to_output_weights, is_uint8_hybrid);
+    cell_to_forget_weights_ptr = GetTensorData<int8_t>(cell_to_forget_weights);
+    cell_to_output_weights_ptr = GetTensorData<int8_t>(cell_to_output_weights);
     cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
     cell_to_output_weights_scale = cell_to_output_weights->params.scale;
   }
@@ -1212,7 +1205,7 @@
   const int8_t* projection_weights_ptr =
       (projection_weights == nullptr)
           ? nullptr
-          : GetInt8DataPtr(projection_weights, is_uint8_hybrid);
+          : GetTensorData<int8_t>(projection_weights);
   const float projection_weights_scale =
       (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
   const float* projection_bias_ptr =
@@ -1220,26 +1213,26 @@
 
   // Required tensors, pointers are non-null.
   const int8_t* input_to_forget_weights_ptr =
-      GetInt8DataPtr(input_to_forget_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(input_to_forget_weights);
   const float input_to_forget_weights_scale =
       input_to_forget_weights->params.scale;
   const int8_t* input_to_cell_weights_ptr =
-      GetInt8DataPtr(input_to_cell_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(input_to_cell_weights);
   const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
   const int8_t* input_to_output_weights_ptr =
-      GetInt8DataPtr(input_to_output_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(input_to_output_weights);
   const float input_to_output_weights_scale =
       input_to_output_weights->params.scale;
   const int8_t* recurrent_to_forget_weights_ptr =
-      GetInt8DataPtr(recurrent_to_forget_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(recurrent_to_forget_weights);
   const float recurrent_to_forget_weights_scale =
       recurrent_to_forget_weights->params.scale;
   const int8_t* recurrent_to_cell_weights_ptr =
-      GetInt8DataPtr(recurrent_to_cell_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(recurrent_to_cell_weights);
   const float recurrent_to_cell_weights_scale =
       recurrent_to_cell_weights->params.scale;
   const int8_t* recurrent_to_output_weights_ptr =
-      GetInt8DataPtr(recurrent_to_output_weights, is_uint8_hybrid);
+      GetTensorData<int8_t>(recurrent_to_output_weights);
   const float recurrent_to_output_weights_scale =
       recurrent_to_output_weights->params.scale;
   const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
@@ -1247,26 +1240,25 @@
   const float* output_gate_bias_ptr = output_gate_bias->data.f;
 
   // Temporary storage for quantized values and scaling factors.
-  int8_t* quantized_input_ptr =
-      GetInt8DataPtr(input_quantized, is_uint8_hybrid);
+  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
   int8_t* quantized_aux_input_ptr =
       (aux_input_quantized == nullptr)
           ? nullptr
-          : GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
+          : GetTensorData<int8_t>(aux_input_quantized);
   int8_t* quantized_output_state_ptr =
-      GetInt8DataPtr(output_state_quantized, is_uint8_hybrid);
+      GetTensorData<int8_t>(output_state_quantized);
   int8_t* quantized_cell_state_ptr =
-      GetInt8DataPtr(cell_state_quantized, is_uint8_hybrid);
+      GetTensorData<int8_t>(cell_state_quantized);
   float* scaling_factors_ptr = scaling_factors->data.f;
   float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
   float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
 
   // Auxiliary input and weights.
   float* aux_input_ptr = nullptr;
-  int8_t* aux_input_to_input_weights_ptr = nullptr;
-  int8_t* aux_input_to_forget_weights_ptr = nullptr;
-  int8_t* aux_input_to_cell_weights_ptr = nullptr;
-  int8_t* aux_input_to_output_weights_ptr = nullptr;
+  const int8_t* aux_input_to_input_weights_ptr = nullptr;
+  const int8_t* aux_input_to_forget_weights_ptr = nullptr;
+  const int8_t* aux_input_to_cell_weights_ptr = nullptr;
+  const int8_t* aux_input_to_output_weights_ptr = nullptr;
   float aux_input_to_input_weights_scale = 0.0f;
   float aux_input_to_forget_weights_scale = 0.0f;
   float aux_input_to_cell_weights_scale = 0.0f;
@@ -1274,14 +1266,14 @@
   if (aux_input_size > 0) {
     if (!use_cifg) {
       aux_input_to_input_weights_ptr =
-          GetInt8DataPtr(aux_input_to_input_weights, is_uint8_hybrid);
+          GetTensorData<int8_t>(aux_input_to_input_weights);
     }
     aux_input_to_forget_weights_ptr =
-        GetInt8DataPtr(aux_input_to_forget_weights, is_uint8_hybrid);
+        GetTensorData<int8_t>(aux_input_to_forget_weights);
     aux_input_to_cell_weights_ptr =
-        GetInt8DataPtr(aux_input_to_cell_weights, is_uint8_hybrid);
+        GetTensorData<int8_t>(aux_input_to_cell_weights);
     aux_input_to_output_weights_ptr =
-        GetInt8DataPtr(aux_input_to_output_weights, is_uint8_hybrid);
+        GetTensorData<int8_t>(aux_input_to_output_weights);
     if (!use_cifg) {
       aux_input_to_input_weights_scale =
           aux_input_to_input_weights->params.scale;
diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc
index 84ddc25..d02b11e 100644
--- a/tensorflow/lite/kernels/lstm_test.cc
+++ b/tensorflow/lite/kernels/lstm_test.cc
@@ -38,12 +38,13 @@
               bool use_peephole, bool use_projection_weights,
               bool use_projection_bias, float cell_clip, float proj_clip,
               const std::vector<std::vector<int>>& input_shapes,
-              const TensorType& weight_type = TensorType_FLOAT32,
-              bool is_layer_norm = false)
+              const TensorType weight_type, bool is_layer_norm)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
-        n_output_(n_output) {
+        n_output_(n_output),
+        weight_type_(weight_type),
+        is_layer_norm_(is_layer_norm) {
     input_ = AddInput(TensorType_FLOAT32);
 
     if (use_cifg) {
@@ -103,9 +104,9 @@
 
     // Adding the 2 input state tensors.
     input_activation_state_ =
-        AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+        AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
     input_cell_state_ =
-        AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+        AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
 
     // Layer norm weights.
     if (is_layer_norm) {
@@ -134,178 +135,10 @@
                                    cell_clip, proj_clip)
                      .Union());
 
-    BuildInterpreter(input_shapes);
-  }
-
-  void SetInputToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_input_weights_, f);
-  }
-
-  void SetInputToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_forget_weights_, f);
-  }
-
-  void SetInputToCellWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_cell_weights_, f);
-  }
-
-  void SetInputToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_output_weights_, f);
-  }
-
-  void SetRecurrentToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_input_weights_, f);
-  }
-
-  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_forget_weights_, f);
-  }
-
-  void SetRecurrentToCellWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_cell_weights_, f);
-  }
-
-  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_output_weights_, f);
-  }
-
-  void SetCellToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_input_weights_, f);
-  }
-
-  void SetCellToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_forget_weights_, f);
-  }
-
-  void SetCellToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_output_weights_, f);
-  }
-
-  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(input_layer_norm_coefficients_, f);
-  }
-
-  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(forget_layer_norm_coefficients_, f);
-  }
-
-  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(cell_layer_norm_coefficients_, f);
-  }
-
-  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(output_layer_norm_coefficients_, f);
-  }
-
-  void SetInputGateBias(const std::vector<float>& f) {
-    PopulateTensor(input_gate_bias_, f);
-  }
-
-  void SetForgetGateBias(const std::vector<float>& f) {
-    PopulateTensor(forget_gate_bias_, f);
-  }
-
-  void SetCellBias(const std::vector<float>& f) {
-    PopulateTensor(cell_bias_, f);
-  }
-
-  void SetOutputGateBias(const std::vector<float>& f) {
-    PopulateTensor(output_gate_bias_, f);
-  }
-
-  void SetProjectionWeights(const std::vector<float>& f) {
-    PopulateTensor(projection_weights_, f);
-  }
-
-  void SetProjectionBias(const std::vector<float>& f) {
-    PopulateTensor(projection_bias_, f);
-  }
-
-  void SetInput(int offset, const float* begin, const float* end) {
-    PopulateTensor(input_, offset, const_cast<float*>(begin),
-                   const_cast<float*>(end));
-  }
-
-  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-
-  int num_inputs() { return n_input_; }
-  int num_outputs() { return n_output_; }
-  int num_cells() { return n_cell_; }
-  int num_batches() { return n_batch_; }
-
- protected:
-  int input_;
-  int input_to_input_weights_;
-  int input_to_forget_weights_;
-  int input_to_cell_weights_;
-  int input_to_output_weights_;
-
-  int recurrent_to_input_weights_;
-  int recurrent_to_forget_weights_;
-  int recurrent_to_cell_weights_;
-  int recurrent_to_output_weights_;
-
-  int cell_to_input_weights_;
-  int cell_to_forget_weights_;
-  int cell_to_output_weights_;
-
-  int input_layer_norm_coefficients_;
-  int forget_layer_norm_coefficients_;
-  int cell_layer_norm_coefficients_;
-  int output_layer_norm_coefficients_;
-
-  int input_gate_bias_;
-  int forget_gate_bias_;
-  int cell_bias_;
-  int output_gate_bias_;
-
-  int projection_weights_;
-  int projection_bias_;
-  int input_activation_state_;
-  int input_cell_state_;
-
-  int output_;
-  int output_state_;
-  int cell_state_;
-
-  int n_batch_;
-  int n_input_;
-  int n_cell_;
-  int n_output_;
-
- private:
-  int AddLayerNormCoeffsTensor(
-      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
-    if (input_shapes[tensor_index][0] != 0) {
-      return AddInput(TensorType_FLOAT32);
-    } else {
-      return AddNullInput();
-    }
-  }
-};
-
-class HybridLSTMOpModel : public LSTMOpModel {
- public:
-  HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                    bool use_cifg, bool use_peephole,
-                    bool use_projection_weights, bool use_projection_bias,
-                    float cell_clip, float proj_clip,
-                    const std::vector<std::vector<int>>& input_shapes,
-                    TensorType tensor_type)
-      : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
-                    use_projection_weights, use_projection_bias, cell_clip,
-                    proj_clip, input_shapes, tensor_type) {
-    tensor_type_ = tensor_type;
-  }
-
-  TensorType tensor_type_;
-
-  void SetWeights(int weights_idx, const std::vector<float>& f) {
-    if (tensor_type_ == TensorType_UINT8) {
-      SymmetricQuantizeAndPopulate(weights_idx, f);
-    } else {
-      SignedSymmetricQuantizeAndPopulate(weights_idx, f);
-    }
+    // Do not apply delegate yet since tensor values are not known (and more
+    // specifically scales in quantized tensors are not known).
+    BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false,
+                     /*apply_delegate=*/false);
   }
 
   void SetInputToInputWeights(const std::vector<float>& f) {
@@ -352,9 +185,136 @@
     SetWeights(cell_to_output_weights_, f);
   }
 
+  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
+    PopulateTensor(input_layer_norm_coefficients_, f);
+  }
+
+  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
+    PopulateTensor(forget_layer_norm_coefficients_, f);
+  }
+
+  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
+    PopulateTensor(cell_layer_norm_coefficients_, f);
+  }
+
+  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
+    PopulateTensor(output_layer_norm_coefficients_, f);
+  }
+
+  void SetInputGateBias(const std::vector<float>& f) {
+    PopulateTensor(input_gate_bias_, f);
+  }
+
+  void SetForgetGateBias(const std::vector<float>& f) {
+    PopulateTensor(forget_gate_bias_, f);
+  }
+
+  void SetCellBias(const std::vector<float>& f) {
+    PopulateTensor(cell_bias_, f);
+  }
+
+  void SetOutputGateBias(const std::vector<float>& f) {
+    PopulateTensor(output_gate_bias_, f);
+  }
+
   void SetProjectionWeights(const std::vector<float>& f) {
     SetWeights(projection_weights_, f);
   }
+
+  void SetProjectionBias(const std::vector<float>& f) {
+    PopulateTensor(projection_bias_, f);
+  }
+
+  void SetInput(int offset, const float* begin, const float* end) {
+    SingleOpModel::PopulateTensor(input_, offset, const_cast<float*>(begin),
+                                  const_cast<float*>(end));
+  }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  int num_inputs() { return n_input_; }
+  int num_outputs() { return n_output_; }
+  int num_cells() { return n_cell_; }
+  int num_batches() { return n_batch_; }
+
+ protected:
+  int input_;
+  int input_to_input_weights_;
+  int input_to_forget_weights_;
+  int input_to_cell_weights_;
+  int input_to_output_weights_;
+
+  int recurrent_to_input_weights_;
+  int recurrent_to_forget_weights_;
+  int recurrent_to_cell_weights_;
+  int recurrent_to_output_weights_;
+
+  int cell_to_input_weights_;
+  int cell_to_forget_weights_;
+  int cell_to_output_weights_;
+
+  int input_layer_norm_coefficients_ = kOptionalTensor;
+  int forget_layer_norm_coefficients_ = kOptionalTensor;
+  int cell_layer_norm_coefficients_ = kOptionalTensor;
+  int output_layer_norm_coefficients_ = kOptionalTensor;
+
+  int input_gate_bias_;
+  int forget_gate_bias_;
+  int cell_bias_;
+  int output_gate_bias_;
+
+  int projection_weights_;
+  int projection_bias_;
+  int input_activation_state_;
+  int input_cell_state_;
+
+  int output_;
+  int output_state_;
+  int cell_state_;
+
+  int n_batch_;
+  int n_input_;
+  int n_cell_;
+  int n_output_;
+
+ private:
+  int AddLayerNormCoeffsTensor(
+      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
+    if (input_shapes[tensor_index][0] != 0) {
+      return AddInput(TensorType_FLOAT32);
+    } else {
+      return AddNullInput();
+    }
+  }
+
+  template <typename T>
+  void PopulateTensor(int index, const std::vector<T>& data) {
+    // Nothing to do if tensor is an optional input or if data vector is empty.
+    if ((index == kOptionalTensor) || data.empty()) return;
+    SingleOpModel::PopulateTensor(index, data);
+  }
+
+  void SetWeights(int index, const std::vector<float>& data) {
+    if (data.empty()) return;
+    if (index == kOptionalTensor) return;
+    switch (weight_type_) {
+      case TensorType_FLOAT32:
+        PopulateTensor(index, data);
+        break;
+      case TensorType_UINT8:
+        SymmetricQuantizeAndPopulate(index, data);
+        break;
+      case TensorType_INT8:
+        SignedSymmetricQuantizeAndPopulate(index, data);
+        break;
+      default:
+        GTEST_FAIL() << "Type not supported: " << weight_type_;
+        break;
+    }
+  }
+
+  const TensorType weight_type_;
+  const bool is_layer_norm_;
 };
 
 class BaseLstmTest : public ::testing::Test {
@@ -376,6 +336,10 @@
   std::vector<float> cell_to_forget_weights_;
   std::vector<float> cell_to_output_weights_;
   std::vector<float> projection_weights_;
+  std::vector<float> input_layer_norm_coefficients_;
+  std::vector<float> forget_layer_norm_coefficients_;
+  std::vector<float> cell_layer_norm_coefficients_;
+  std::vector<float> output_layer_norm_coefficients_;
 
   // LSTM input is stored as num_batch x num_inputs vector.
   std::vector<std::vector<float>> lstm_input_;
@@ -386,6 +350,16 @@
   void VerifyGoldens(const std::vector<std::vector<float>>& input,
                      const std::vector<std::vector<float>>& output,
                      LSTMOpModel* lstm, float tolerance = 1e-5) {
+    // Weights are set twice:
+    // - The delegate, if used, needs to know the scales and zero-points of
+    //   quantized tensors, which are computed dynamically when weights are set,
+    //   so weights have to be set before applying the delegate.
+    // - Applying a delegate will invalidate the tensor data so weights have to
+    //   be set a second time.
+    SetAllWeightsAndBiases(lstm);
+    lstm->ApplyDelegate();
+    SetAllWeightsAndBiases(lstm);
+
     const int num_batches = input.size();
     EXPECT_GT(num_batches, 0);
     const int num_inputs = lstm->num_inputs();
@@ -413,6 +387,37 @@
                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
     }
   }
+
+  // Sets all weights and biases that have been defined by test. The test can
+  // define only a subset of all those vectors, and only the ones that have been
+  // defined will be set.
+  void SetAllWeightsAndBiases(LSTMOpModel* lstm) {
+    lstm->SetInputToInputWeights(input_to_input_weights_);
+    lstm->SetInputToCellWeights(input_to_cell_weights_);
+    lstm->SetInputToForgetWeights(input_to_forget_weights_);
+    lstm->SetInputToOutputWeights(input_to_output_weights_);
+
+    lstm->SetInputGateBias(input_gate_bias_);
+    lstm->SetCellBias(cell_gate_bias_);
+    lstm->SetForgetGateBias(forget_gate_bias_);
+    lstm->SetOutputGateBias(output_gate_bias_);
+
+    lstm->SetRecurrentToInputWeights(recurrent_to_input_weights_);
+    lstm->SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+    lstm->SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+    lstm->SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+    lstm->SetCellToInputWeights(cell_to_input_weights_);
+    lstm->SetCellToForgetWeights(cell_to_forget_weights_);
+    lstm->SetCellToOutputWeights(cell_to_output_weights_);
+
+    lstm->SetProjectionWeights(projection_weights_);
+
+    lstm->SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
+    lstm->SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
+    lstm->SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
+    lstm->SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
+  }
 };
 
 class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
@@ -500,22 +505,9 @@
 
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
-                   });
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+                   },
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
@@ -572,21 +564,6 @@
                    /*weight_type=*/TensorType_FLOAT32,
                    /*is_layer_norm=*/true);
 
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
@@ -598,52 +575,38 @@
   const int n_cell = 4;
   const int n_output = 4;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/false, /*use_peephole=*/false,
-      /*use_projection_weights=*/false,
-      /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/false,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false, /*cell_clip=*/0.0,
+                   /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {n_cell, n_input},  // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {n_cell, n_output},  // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {n_cell, n_output},  // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {0},  // cell_to_input_weight tensor
-          {0},  // cell_to_forget_weight tensor
-          {0},  // cell_to_output_weight tensor
+                       {0},  // cell_to_input_weight tensor
+                       {0},  // cell_to_forget_weight tensor
+                       {0},  // cell_to_output_weight tensor
 
-          {n_cell},  // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {0, 0},  // projection_weight tensor
-          {0},     // projection_bias tensor
-      },
-      TensorType_UINT8);
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_UINT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
@@ -657,52 +620,38 @@
   const int n_cell = 4;
   const int n_output = 4;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/false, /*use_peephole=*/false,
-      /*use_projection_weights=*/false,
-      /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/false,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false, /*cell_clip=*/0.0,
+                   /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {n_cell, n_input},  // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {n_cell, n_output},  // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {n_cell, n_output},  // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {0},  // cell_to_input_weight tensor
-          {0},  // cell_to_forget_weight tensor
-          {0},  // cell_to_output_weight tensor
+                       {0},  // cell_to_input_weight tensor
+                       {0},  // cell_to_forget_weight tensor
+                       {0},  // cell_to_output_weight tensor
 
-          {n_cell},  // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {0, 0},  // projection_weight tensor
-          {0},     // projection_bias tensor
-      },
-      TensorType_INT8);
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_INT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
@@ -791,22 +740,9 @@
 
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
-                   });
-
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
+                   },
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
@@ -819,53 +755,38 @@
   const int n_cell = 4;
   const int n_output = 4;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/true, /*use_peephole=*/true,
-      /*use_projection_weights=*/false,
-      /*use_projection_bias=*/false,
-      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/true, /*use_peephole=*/true,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {0, 0},             // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {0, 0},             // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {0, 0},              // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {0, 0},              // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {0},       // cell_to_input_weight tensor
-          {n_cell},  // cell_to_forget_weight tensor
-          {n_cell},  // cell_to_output_weight tensor
+                       {0},       // cell_to_input_weight tensor
+                       {n_cell},  // cell_to_forget_weight tensor
+                       {n_cell},  // cell_to_output_weight tensor
 
-          {0},       // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {0},       // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {0, 0},  // projection_weight tensor
-          {0},     // projection_bias tensor
-      },
-      TensorType_UINT8);
-
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_UINT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
@@ -878,53 +799,38 @@
   const int n_cell = 4;
   const int n_output = 4;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/true, /*use_peephole=*/true,
-      /*use_projection_weights=*/false,
-      /*use_projection_bias=*/false,
-      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/true, /*use_peephole=*/true,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {0, 0},             // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {0, 0},             // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {0, 0},              // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {0, 0},              // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {0},       // cell_to_input_weight tensor
-          {n_cell},  // cell_to_forget_weight tensor
-          {n_cell},  // cell_to_output_weight tensor
+                       {0},       // cell_to_input_weight tensor
+                       {n_cell},  // cell_to_forget_weight tensor
+                       {n_cell},  // cell_to_output_weight tensor
 
-          {0},       // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {0},       // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {0, 0},  // projection_weight tensor
-          {0},     // projection_bias tensor
-      },
-      TensorType_INT8);
-
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_INT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
@@ -1563,91 +1469,51 @@
 
                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor
-                   });
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToInputWeights(cell_to_input_weights_);
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  lstm.SetProjectionWeights(projection_weights_);
+                   },
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
-TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 20;
   const int n_output = 16;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/false, /*use_peephole=*/true,
-      /*use_projection_weights=*/true,
-      /*use_projection_bias=*/false,
-      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/true,
+                   /*use_projection_weights=*/true,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {n_cell, n_input},  // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {n_cell, n_output},  // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {n_cell, n_output},  // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {n_cell},  // cell_to_input_weight tensor
-          {n_cell},  // cell_to_forget_weight tensor
-          {n_cell},  // cell_to_output_weight tensor
+                       {n_cell},  // cell_to_input_weight tensor
+                       {n_cell},  // cell_to_forget_weight tensor
+                       {n_cell},  // cell_to_output_weight tensor
 
-          {n_cell},  // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {n_output, n_cell},  // projection_weight tensor
-          {0},                 // projection_bias tensor
-      },
-      TensorType_UINT8);
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToInputWeights(cell_to_input_weights_);
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  lstm.SetProjectionWeights(projection_weights_);
+                       {n_output, n_cell},  // projection_weight tensor
+                       {0},                 // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_INT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
 }
@@ -1659,233 +1525,44 @@
   const int n_cell = 20;
   const int n_output = 16;
 
-  HybridLSTMOpModel lstm(
-      n_batch, n_input, n_cell, n_output,
-      /*use_cifg=*/false, /*use_peephole=*/true,
-      /*use_projection_weights=*/true,
-      /*use_projection_bias=*/false,
-      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-      {
-          {n_batch, n_input},  // input tensor
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/true,
+                   /*use_projection_weights=*/true,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
 
-          {n_cell, n_input},  // input_to_input_weight tensor
-          {n_cell, n_input},  // input_to_forget_weight tensor
-          {n_cell, n_input},  // input_to_cell_weight tensor
-          {n_cell, n_input},  // input_to_output_weight tensor
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
 
-          {n_cell, n_output},  // recurrent_to_input_weight tensor
-          {n_cell, n_output},  // recurrent_to_forget_weight tensor
-          {n_cell, n_output},  // recurrent_to_cell_weight tensor
-          {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {n_cell, n_output},  // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
 
-          {n_cell},  // cell_to_input_weight tensor
-          {n_cell},  // cell_to_forget_weight tensor
-          {n_cell},  // cell_to_output_weight tensor
+                       {n_cell},  // cell_to_input_weight tensor
+                       {n_cell},  // cell_to_forget_weight tensor
+                       {n_cell},  // cell_to_output_weight tensor
 
-          {n_cell},  // input_gate_bias tensor
-          {n_cell},  // forget_gate_bias tensor
-          {n_cell},  // cell_bias tensor
-          {n_cell},  // output_gate_bias tensor
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
 
-          {n_output, n_cell},  // projection_weight tensor
-          {0},                 // projection_bias tensor
-      },
-      TensorType_INT8);
-
-  lstm.SetInputToInputWeights(input_to_input_weights_);
-  lstm.SetInputToCellWeights(input_to_cell_weights_);
-  lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  lstm.SetInputGateBias(input_gate_bias_);
-  lstm.SetCellBias(cell_gate_bias_);
-  lstm.SetForgetGateBias(forget_gate_bias_);
-  lstm.SetOutputGateBias(output_gate_bias_);
-
-  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  lstm.SetCellToInputWeights(cell_to_input_weights_);
-  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  lstm.SetProjectionWeights(projection_weights_);
+                       {n_output, n_cell},  // projection_weight tensor
+                       {0},                 // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_UINT8,
+                   /*is_layer_norm=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
 }
 
-class LayerNormLSTMOpModel : public LSTMOpModel {
- public:
-  LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                       bool use_cifg, bool use_peephole,
-                       bool use_projection_weights, bool use_projection_bias,
-                       float cell_clip, float proj_clip,
-                       const std::vector<std::vector<int>>& input_shapes,
-                       const TensorType& weight_type = TensorType_FLOAT32)
-      : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
-                    use_projection_weights, use_projection_bias, cell_clip,
-                    proj_clip, input_shapes, weight_type,
-                    /*is_layer_norm*/ true) {}
-};
-
-class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
- public:
-  HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                             bool use_cifg, bool use_peephole,
-                             bool use_projection_weights,
-                             bool use_projection_bias, float cell_clip,
-                             float proj_clip,
-                             const std::vector<std::vector<int>>& input_shapes,
-                             TensorType tensor_type)
-      : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
-                             use_peephole, use_projection_weights,
-                             use_projection_bias, cell_clip, proj_clip,
-                             input_shapes, tensor_type) {
-    tensor_type_ = tensor_type;
-  }
-
-  TensorType tensor_type_;
-
-  void SetWeights(int weights_idx, const std::vector<float>& f) {
-    if (tensor_type_ == TensorType_UINT8) {
-      SymmetricQuantizeAndPopulate(weights_idx, f);
-    } else {
-      SignedSymmetricQuantizeAndPopulate(weights_idx, f);
-    }
-  }
-
-  void SetInputToInputWeights(const std::vector<float>& f) {
-    SetWeights(input_to_input_weights_, f);
-  }
-
-  void SetInputToForgetWeights(const std::vector<float>& f) {
-    SetWeights(input_to_forget_weights_, f);
-  }
-
-  void SetInputToCellWeights(const std::vector<float>& f) {
-    SetWeights(input_to_cell_weights_, f);
-  }
-
-  void SetInputToOutputWeights(const std::vector<float>& f) {
-    SetWeights(input_to_output_weights_, f);
-  }
-
-  void SetRecurrentToInputWeights(const std::vector<float>& f) {
-    SetWeights(recurrent_to_input_weights_, f);
-  }
-
-  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
-    SetWeights(recurrent_to_forget_weights_, f);
-  }
-
-  void SetRecurrentToCellWeights(const std::vector<float>& f) {
-    SetWeights(recurrent_to_cell_weights_, f);
-  }
-
-  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
-    SetWeights(recurrent_to_output_weights_, f);
-  }
-
-  void SetCellToInputWeights(const std::vector<float>& f) {
-    SetWeights(cell_to_input_weights_, f);
-  }
-
-  void SetCellToForgetWeights(const std::vector<float>& f) {
-    SetWeights(cell_to_forget_weights_, f);
-  }
-
-  void SetCellToOutputWeights(const std::vector<float>& f) {
-    SetWeights(cell_to_output_weights_, f);
-  }
-
-  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(input_layer_norm_coefficients_, f);
-  }
-
-  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(forget_layer_norm_coefficients_, f);
-  }
-
-  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(cell_layer_norm_coefficients_, f);
-  }
-
-  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
-    PopulateTensor(output_layer_norm_coefficients_, f);
-  }
-
-  void SetProjectionWeights(const std::vector<float>& f) {
-    SetWeights(projection_weights_, f);
-  }
-};
-
-class BaseLayerNormLstmTest : public ::testing::Test {
- protected:
-  // Weights of the Layer Norm LSTM model. Some are optional.
-  std::vector<float> input_to_input_weights_;
-  std::vector<float> input_to_cell_weights_;
-  std::vector<float> input_to_forget_weights_;
-  std::vector<float> input_to_output_weights_;
-  std::vector<float> input_gate_bias_;
-  std::vector<float> cell_gate_bias_;
-  std::vector<float> forget_gate_bias_;
-  std::vector<float> output_gate_bias_;
-  std::vector<float> recurrent_to_input_weights_;
-  std::vector<float> recurrent_to_cell_weights_;
-  std::vector<float> recurrent_to_forget_weights_;
-  std::vector<float> recurrent_to_output_weights_;
-  std::vector<float> cell_to_input_weights_;
-  std::vector<float> cell_to_forget_weights_;
-  std::vector<float> cell_to_output_weights_;
-  std::vector<float> projection_weights_;
-  std::vector<float> input_layer_norm_coefficients_;
-  std::vector<float> forget_layer_norm_coefficients_;
-  std::vector<float> cell_layer_norm_coefficients_;
-  std::vector<float> output_layer_norm_coefficients_;
-
-  // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
-  std::vector<std::vector<float>> layer_norm_lstm_input_;
-
-  // Compares output up to tolerance to the result of the layer_norm_lstm given
-  // the input.
-  void VerifyGoldens(const std::vector<std::vector<float>>& input,
-                     const std::vector<std::vector<float>>& output,
-                     LayerNormLSTMOpModel* layer_norm_lstm,
-                     float tolerance = 1e-5) {
-    const int num_batches = input.size();
-    EXPECT_GT(num_batches, 0);
-    const int num_inputs = layer_norm_lstm->num_inputs();
-    EXPECT_GT(num_inputs, 0);
-    const int input_sequence_size = input[0].size() / num_inputs;
-    EXPECT_GT(input_sequence_size, 0);
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* batch_start = input[b].data() + i * num_inputs;
-        const float* batch_end = batch_start + num_inputs;
-
-        layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
-                                  batch_start, batch_end);
-      }
-
-      layer_norm_lstm->Invoke();
-
-      const int num_outputs = layer_norm_lstm->num_outputs();
-      std::vector<float> expected;
-      for (int b = 0; b < num_batches; ++b) {
-        const float* golden_start_batch = output[b].data() + i * num_outputs;
-        const float* golden_end_batch = golden_start_batch + num_outputs;
-        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
-      }
-      EXPECT_THAT(layer_norm_lstm->GetOutput(),
-                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
-    }
-  }
-};
-
 class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
-    : public BaseLayerNormLstmTest {
+    : public BaseLstmTest {
   void SetUp() override {
     input_to_input_weights_ = {0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,
                                0.3,  -0.4, 0.5,  -0.8, 0.7,  -0.6, 0.5,
@@ -1937,7 +1614,7 @@
     projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                            0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};
 
-    layer_norm_lstm_input_ = {
+    lstm_input_ = {
         {// Batch0: 3 (input_sequence_size) * 5 (n_input)
          0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
          0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
@@ -1960,7 +1637,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  LayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -1997,53 +1674,25 @@
           {n_cell},  // forget_layer_norm_coefficient tensor
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
-      });
-
-  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+      },
+      /*weight_type=*/TensorType_FLOAT32,
+      /*is_layer_norm=*/true);
 
   // Verify the final output.
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
-      {
-          // Batch0: 3 (input_sequence_size) * 3 (n_output)
-          0.0244077, 0.128027, -0.00170918,  // seq 0
-          0.0137642, 0.140751, 0.0395835,    // seq 1
-          -0.00459231, 0.155278, 0.0837377,  // seq 2
-      },
-      {
-          // Batch1: 3 (input_sequence_size) * 3 (n_output)
-          -0.00692428, 0.0848741, 0.063445,  // seq 0
-          -0.00403912, 0.139963, 0.072681,   // seq 1
-          0.00752706, 0.161903, 0.0561371,   // seq 2
-      }};
+  lstm_golden_output_ = {{
+                             // Batch0: 3 (input_sequence_size) * 3 (n_output)
+                             0.0244077, 0.128027, -0.00170918,  // seq 0
+                             0.0137642, 0.140751, 0.0395835,    // seq 1
+                             -0.00459231, 0.155278, 0.0837377,  // seq 2
+                         },
+                         {
+                             // Batch1: 3 (input_sequence_size) * 3 (n_output)
+                             -0.00692428, 0.0848741, 0.063445,  // seq 0
+                             -0.00403912, 0.139963, 0.072681,   // seq 1
+                             0.00752706, 0.161903, 0.0561371,   // seq 2
+                         }};
 
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
 
 TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
@@ -2055,7 +1704,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  HybridLayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -2093,52 +1742,24 @@
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      TensorType_UINT8);
+      /*weight_type=*/TensorType_UINT8,
+      /*is_layer_norm=*/true);
 
-  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+  lstm_golden_output_ = {{
+                             // Batch0: 3 (input_sequence_size) * 3 (n_output)
+                             0.0244576, 0.127847, -0.00181765,  // seq 0
+                             0.0137518, 0.140892, 0.0402234,    // seq 1
+                             -0.0048839, 0.155096, 0.0840309,   // seq 2
+                         },
+                         {
+                             // Batch1: 3 (input_sequence_size) * 3 (n_output)
+                             -0.00728636, 0.0843957, 0.0634786,  // seq 0
+                             -0.00448382, 0.139278, 0.0737372,   // seq 1
+                             0.00734616, 0.161793, 0.0560238,    // seq 2
+                         }};
 
-  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
-      {
-          // Batch0: 3 (input_sequence_size) * 3 (n_output)
-          0.0244576, 0.127847, -0.00181765,  // seq 0
-          0.0137518, 0.140892, 0.0402234,    // seq 1
-          -0.0048839, 0.155096, 0.0840309,   // seq 2
-      },
-      {
-          // Batch1: 3 (input_sequence_size) * 3 (n_output)
-          -0.00728636, 0.0843957, 0.0634786,  // seq 0
-          -0.00448382, 0.139278, 0.0737372,   // seq 1
-          0.00734616, 0.161793, 0.0560238,    // seq 2
-      }};
-
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
+                /*tolerance=*/0.0010907);
 }
 
 TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
@@ -2150,7 +1771,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  HybridLayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -2188,56 +1809,26 @@
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      TensorType_INT8);
+      /*weight_type=*/TensorType_INT8,
+      /*is_layer_norm=*/true);
 
-  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+  lstm_golden_output_ = {{
+                             // Batch0: 3 (input_sequence_size) * 3 (n_output)
+                             0.0244576, 0.127847, -0.00181765,  // seq 0
+                             0.0137518, 0.140892, 0.0402234,    // seq 1
+                             -0.0048839, 0.155096, 0.0840309,   // seq 2
+                         },
+                         {
+                             // Batch1: 3 (input_sequence_size) * 3 (n_output)
+                             -0.00728636, 0.0843957, 0.0634786,  // seq 0
+                             -0.00448382, 0.139278, 0.0737372,   // seq 1
+                             0.00734616, 0.161793, 0.0560238,    // seq 2
+                         }};
 
-  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
-      {
-          // Batch0: 3 (input_sequence_size) * 3 (n_output)
-          0.0244576, 0.127847, -0.00181765,  // seq 0
-          0.0137518, 0.140892, 0.0402234,    // seq 1
-          -0.0048839, 0.155096, 0.0840309,   // seq 2
-      },
-      {
-          // Batch1: 3 (input_sequence_size) * 3 (n_output)
-          -0.00728636, 0.0843957, 0.0634786,  // seq 0
-          -0.00448382, 0.139278, 0.0737372,   // seq 1
-          0.00734616, 0.161793, 0.0560238,    // seq 2
-      }};
-
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
 
-class CifgPeepholeProjectionNoClippingLayerNormLstmTest
-    : public BaseLayerNormLstmTest {
+class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
   void SetUp() override {
     input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
                                 -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
@@ -2269,7 +1860,7 @@
     projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                            0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};
 
-    layer_norm_lstm_input_ = {
+    lstm_input_ = {
         {// Batch0: 3 (input_sequence_size) * 5 (n_input)
          0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
          0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
@@ -2292,7 +1883,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  LayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -2329,33 +1920,12 @@
           {n_cell},  // forget_layer_norm_coefficient tensor
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
-      });
-
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+      },
+      /*weight_type=*/TensorType_FLOAT32,
+      /*is_layer_norm=*/true);
 
   // Verify the final output.
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+  lstm_golden_output_ = {
       {
           // Batch0: 3 (input_sequence_size) * 3 (n_output)
           0.02129706, 0.140816242, 0.0112733059,     // seq 0
@@ -2369,8 +1939,7 @@
           -0.0103429332, 0.173016444, 0.0720508844,   // seq 2
       }};
 
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
 
 TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
@@ -2382,7 +1951,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  HybridLayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -2420,33 +1989,11 @@
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      TensorType_UINT8);
-
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+      /*weight_type=*/TensorType_UINT8,
+      /*is_layer_norm=*/true);
 
   // Verify the final output.
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+  lstm_golden_output_ = {
       {
           // Batch0: 3 (input_sequence_size) * 3 (n_output)
           0.0212250091, 0.140474007, 0.0115012666,   // seq 0
@@ -2460,8 +2007,8 @@
           -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
       }};
 
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
+                /*tolerance=*/0.000902065);
 }
 
 TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
@@ -2473,7 +2020,7 @@
   const float cell_clip = 0.0;
   const float proj_clip = 0.0;
 
-  HybridLayerNormLSTMOpModel layer_norm_lstm(
+  LSTMOpModel layer_norm_lstm(
       n_batch, n_input, n_cell, n_output,
       /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
@@ -2511,33 +2058,11 @@
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      TensorType_INT8);
-
-  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
-  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
-  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
-  layer_norm_lstm.SetCellBias(cell_gate_bias_);
-  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
-  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
-  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
-  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
-  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
-  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
-  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
-  layer_norm_lstm.SetForgetLayerNormCoefficients(
-      forget_layer_norm_coefficients_);
-  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
-  layer_norm_lstm.SetOutputLayerNormCoefficients(
-      output_layer_norm_coefficients_);
-
-  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+      /*weight_type=*/TensorType_INT8,
+      /*is_layer_norm=*/true);
 
   // Verify the final output.
-  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+  lstm_golden_output_ = {
       {
           // Batch0: 3 (input_sequence_size) * 3 (n_output)
           0.0212250091, 0.140474007, 0.0115012666,   // seq 0
@@ -2551,8 +2076,7 @@
           -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
       }};
 
-  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
-                &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
 
 #ifdef GTEST_HAS_DEATH_TEST
@@ -2593,7 +2117,8 @@
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_INT32),
+                   /*weight_type=*/TensorType_INT32,
+                   /*is_layer_norm=*/false),
                "");
 
   EXPECT_DEATH(LSTMOpModel lstm(
@@ -2627,7 +2152,8 @@
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_COMPLEX64),
+                   /*weight_type=*/TensorType_COMPLEX64,
+                   /*is_layer_norm=*/false),
                "");
 }
 #endif
diff --git a/tensorflow/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc
index f11a1f3..8c7eaf8 100644
--- a/tensorflow/lite/kernels/mul.cc
+++ b/tensorflow/lite/kernels/mul.cc
@@ -101,8 +101,8 @@
       output->type == kTfLiteInt16) {
     double real_multiplier =
         input1->params.scale * input2->params.scale / output->params.scale;
-    QuantizeMultiplierSmallerThanOneExp(
-        real_multiplier, &data->output_multiplier, &data->output_shift);
+    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
+                       &data->output_shift);
   }
 
   return context->ResizeTensor(context, output, output_size);
diff --git a/tensorflow/lite/kernels/mul_test.cc b/tensorflow/lite/kernels/mul_test.cc
index b6a7700..30f9c52 100644
--- a/tensorflow/lite/kernels/mul_test.cc
+++ b/tensorflow/lite/kernels/mul_test.cc
@@ -206,12 +206,37 @@
                                               kQuantizedTolerance)));
 }
 
+template <TensorType tensor_type, typename integer_dtype>
+void NoActivationLargeMultiplier() {
+  // TODO(b/138722124): Remove this after setting the appropriate op version (3)
+  // for dependent tests.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    // NNAPI doesn't currently support Mul with multiplier>1.
+    return;
+  }
+  // Intentionally pathological output range much narrower than needed
+  // to represent input values to exercise the multiplier>1 case.
+  QuantizedMulOpModel m({tensor_type, {1, 2, 2, 1}, -100, 100},
+                        {tensor_type, {1, 2, 2, 1}, -100, 100},
+                        {tensor_type, {}, -10, 10},
+                        ActivationFunctionType_NONE);
+  m.QuantizeAndPopulate<integer_dtype>(m.input1(), {-4, 2, 3, 1});
+  m.QuantizeAndPopulate<integer_dtype>(m.input2(), {-1, -3, 4, 2});
+  m.Invoke();
+  // Note the large tolerance. This computation is inherently inaccurate.
+  const float kTolerance = 1.4f;
+  EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
+              ElementsAreArray(ArrayFloatNear({4, -6, 10, 2}, kTolerance)));
+}
+
 TEST(QuantizedMulOpTest, NoActivationUInt8) {
   NoActivation<TensorType_UINT8, uint8_t>();
+  NoActivationLargeMultiplier<TensorType_UINT8, uint8_t>();
 }
 
 TEST(QuantizedMulOpTest, NoActivationInt8) {
   NoActivation<TensorType_INT8, int8_t>();
+  NoActivationLargeMultiplier<TensorType_INT8, int8_t>();
 }
 
 TEST(QuantizedMulOpTest, NoActivationInt16) {
diff --git a/tensorflow/lite/kernels/quant_basic_lstm_test.cc b/tensorflow/lite/kernels/quant_basic_lstm_test.cc
new file mode 100644
index 0000000..e8f7ad3
--- /dev/null
+++ b/tensorflow/lite/kernels/quant_basic_lstm_test.cc
@@ -0,0 +1,230 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class QuantizedLSTMOpModel : public SingleOpModel {
+ public:
+  QuantizedLSTMOpModel(int numBatches, int inputSize, float weightsScale,
+                       int32_t weightsZeroPoint, int outputSize,
+                       std::initializer_list<uint8_t> weights,
+                       std::initializer_list<int32_t> biases) {
+    std::vector<uint32_t> inputs;
+
+    input_size_ = inputSize;
+    output_size_ = outputSize;
+
+    std::vector<int> input_shape{numBatches, inputSize};
+    std::vector<int> output_shape{numBatches, outputSize};
+    std::vector<int> weight_shape{4 * outputSize, outputSize + inputSize};
+    std::vector<int> state_shape{numBatches, outputSize};
+    std::vector<int> bias_shape{4 * outputSize};
+
+    input_ =
+        AddInput({TensorType_UINT8, input_shape, 0.0f, 0.0f, 1. / 128., 128});
+    prev_output_ =
+        AddInput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
+    // Biases and Weights have to be constant in order to allow NNAPI
+    // delegation
+    weights_ = AddConstInput<uint8_t>({TensorType_UINT8, weight_shape, 0.0f,
+                                       0.0f, weightsScale, weightsZeroPoint},
+                                      weights);
+    biases_ = AddConstInput<int32_t>(
+        {TensorType_INT32, bias_shape, 0.0f, 0.0f, weightsScale / 128, 0},
+        biases);
+    prev_cell_state_ =
+        AddInput({TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0});
+
+    output_ =
+        AddOutput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
+    cell_state_out_ =
+        AddOutput({TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0});
+    output_concat_temp_ =
+        AddOutput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
+    output_activation_temp_ =
+        AddOutput({TensorType_INT16, output_shape, 0.0f, 0.0f, 1. / 128., 128});
+
+    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH, 0.0,
+                                   0.0, LSTMKernelType_BASIC)
+                     .Union());
+
+    BuildInterpreter({GetShape(input_), GetShape(prev_output_),
+                      GetShape(weights_), GetShape(biases_),
+                      GetShape(prev_cell_state_)});
+
+    // init feedback inputs to zero
+    std::vector<int16_t> initial_state(GetTensorSize(cell_state_out_), 0);
+    PopulateTensor(prev_cell_state_, initial_state);
+    std::vector<uint8_t> initial_prev_output(GetTensorSize(output_), 0);
+    PopulateTensor(prev_output_, initial_prev_output);
+  }
+
+  int inputSize() { return input_size_; }
+
+  int outputSize() { return output_size_; }
+
+  void setInput(const std::vector<uint8_t>& input) {
+    PopulateTensor(input_, input);
+  }
+
+  std::vector<uint8_t> getOutput() { return ExtractVector<uint8_t>(output_); }
+
+ private:
+  // Inputs
+  int input_;
+  int weights_;
+  int biases_;
+  int prev_cell_state_;
+  int prev_output_;
+  // Outputs
+  int cell_state_out_;
+  int output_;
+  int output_concat_temp_;
+  int output_activation_temp_;
+
+  int input_size_;
+  int output_size_;
+};
+
+class QuantizedLstmTest : public ::testing::Test {
+ protected:
+  void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
+                     const std::vector<std::vector<uint8_t>>& output,
+                     QuantizedLSTMOpModel* lstm) {
+    const int numBatches = input.size();
+    ASSERT_GT(numBatches, 0);
+    const int inputSize = lstm->inputSize();
+    ASSERT_GT(inputSize, 0);
+    const int inputSequenceSize = input[0].size() / inputSize;
+    ASSERT_GT(inputSequenceSize, 0);
+    for (int i = 0; i < inputSequenceSize; ++i) {
+      std::vector<uint8_t> inputStep;
+      for (int b = 0; b < numBatches; ++b) {
+        const uint8_t* batchStart = input[b].data() + i * inputSize;
+        const uint8_t* batchEnd = batchStart + inputSize;
+        inputStep.insert(inputStep.end(), batchStart, batchEnd);
+      }
+      lstm->setInput(inputStep);
+      lstm->Invoke();
+
+      const int outputSize = lstm->outputSize();
+      std::vector<float> expected;
+      for (int b = 0; b < numBatches; ++b) {
+        const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
+        const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
+        expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
+      }
+      EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
+    }
+  }
+};
+
+// Inputs and weights in this test are random and the test only checks that the
+// outputs are equal to outputs obtained from running TF Lite version of
+// quantized LSTM on the same inputs.
+TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
+  const int numBatches = 2;
+  const int inputSize = 2;
+  const int outputSize = 4;
+
+  float weightsScale = 0.00408021;
+  int weightsZeroPoint = 100;
+
+  QuantizedLSTMOpModel lstm(
+      numBatches, inputSize, weightsScale, weightsZeroPoint, outputSize,
+
+      // This data are copied from QuantizedLSTMTest.cpp in NNAPI source code
+      // I have to recompose the weight matrix before passing it to the model
+
+      // recurrentToInputWeights   inputToInputWeights
+      {254, 206, 77, 168, 146, 250, 71, 20, 215, 6, 235, 171, 223, 7, 118, 225,
+       10, 218, 59, 130, 174, 26, 171, 108,
+
+       // recurrentToCellWeights     inputToCellWeights
+       172, 60, 205, 65, 133, 34, 14, 0, 140, 168, 29, 49, 240, 223, 133, 56,
+       206, 109, 142, 64, 246, 216, 54, 183,
+
+       // recurrentToForgetWeights   inputToForgetWeights
+       137, 240, 103, 52, 24, 50, 68, 51, 237, 112, 132, 179, 0, 220, 89, 23,
+       158, 110, 69, 4, 207, 253, 3, 169,
+
+       // recurrentToOutputWeights  inputToOutputWeights
+       106, 214, 67, 23, 195, 187, 59, 158, 45, 3, 11, 99, 119, 132, 49, 205,
+       109, 10, 129, 218, 11, 98, 218, 48},
+
+      // inputGateBias
+      {-7876, 13488, -726, 32839,
+       // cellGateBias
+       39481, 48624, 48976, -21419,
+       // forgetGateBias
+       9206, -46884, -11693, -38724,
+       // outputGateBias
+       -58999, -17050, -41852, -40538});
+  // clang-format on
+
+  // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
+  std::vector<std::vector<uint8_t>> lstmInput;
+  // clang-format off
+    lstmInput = {{154, 166,
+                  166, 179,
+                  141, 141},
+                 {100, 200,
+                  50,  150,
+                  111, 222}};
+  // clang-format on
+
+  // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
+  std::vector<std::vector<uint8_t>> lstmGoldenOutput;
+  /*
+    This is the output used in NNAPI's QuantizedLSTMTest.cpp
+    I get slightly different values that are consistent running with or
+    without acceleration
+
+    lstmGoldenOutput = {{136, 150, 140, 115,
+                         140, 151, 146, 112,
+                         139, 153, 146, 114},
+                        {135, 152, 138, 112,
+                         136, 156, 142, 112,
+                         141, 154, 146, 108}};
+   */
+
+  // clang-format off
+    lstmGoldenOutput = {{131, 152, 136, 109,
+                         138, 150, 145, 111,
+                         139, 152, 146, 113},
+                        {131, 153, 135, 107,
+                         134, 154, 140, 111,
+                         140, 154, 145, 108}};
+  // clang-format on
+  VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/quantize.cc b/tensorflow/lite/kernels/quantize.cc
index 35a0bb5..234d6d3 100644
--- a/tensorflow/lite/kernels/quantize.cc
+++ b/tensorflow/lite/kernels/quantize.cc
@@ -24,6 +24,12 @@
 namespace builtin {
 namespace quantize {
 
+// This file has two implementation of Quantize.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 struct OpData {
   int32_t output_multiplier;
   int output_shift;
@@ -87,6 +93,7 @@
                                TfLiteIntArrayCopy(op_context.input->dims));
 }
 
+template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
@@ -100,17 +107,35 @@
       op_params.zero_point = output->params.zero_point;
       op_params.scale = output->params.scale;
       if (output->type == kTfLiteInt8) {
-        optimized_ops::AffineQuantize(
-            op_params, GetTensorShape(input), GetTensorData<float>(input),
-            GetTensorShape(output), GetTensorData<int8_t>(output));
+        if (kernel_type == kReference) {
+          reference_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<int8_t>(output));
+        } else {
+          optimized_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<int8_t>(output));
+        }
       } else if (output->type == kTfLiteUInt8) {
-        optimized_ops::AffineQuantize(
-            op_params, GetTensorShape(input), GetTensorData<float>(input),
-            GetTensorShape(output), GetTensorData<uint8_t>(output));
+        if (kernel_type == kReference) {
+          reference_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<uint8_t>(output));
+        } else {
+          optimized_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<uint8_t>(output));
+        }
       } else if (output->type == kTfLiteInt16) {
-        optimized_ops::AffineQuantize(
-            op_params, GetTensorShape(input), GetTensorData<float>(input),
-            GetTensorShape(output), GetTensorData<int16_t>(output));
+        if (kernel_type == kReference) {
+          reference_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<int16_t>(output));
+        } else {
+          optimized_ops::AffineQuantize(
+              op_params, GetTensorShape(input), GetTensorData<float>(input),
+              GetTensorShape(output), GetTensorData<int16_t>(output));
+        }
       } else {
         context->ReportError(
             context,
@@ -124,15 +149,29 @@
       const int32_t size =
           MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
       if (output->type == kTfLiteInt8) {
-        optimized_ops::Requantize<int8_t, int8_t>(
-            GetTensorData<int8_t>(input), size, data->output_multiplier,
-            data->output_shift, input->params.zero_point,
-            output->params.zero_point, GetTensorData<int8_t>(output));
+        if (kernel_type == kReference) {
+          reference_ops::Requantize<int8_t, int8_t>(
+              GetTensorData<int8_t>(input), size, data->output_multiplier,
+              data->output_shift, input->params.zero_point,
+              output->params.zero_point, GetTensorData<int8_t>(output));
+        } else {
+          optimized_ops::Requantize<int8_t, int8_t>(
+              GetTensorData<int8_t>(input), size, data->output_multiplier,
+              data->output_shift, input->params.zero_point,
+              output->params.zero_point, GetTensorData<int8_t>(output));
+        }
       } else if (output->type == kTfLiteUInt8) {
-        optimized_ops::Requantize<int8_t, uint8_t>(
-            GetTensorData<int8_t>(input), size, data->output_multiplier,
-            data->output_shift, input->params.zero_point,
-            output->params.zero_point, GetTensorData<uint8_t>(output));
+        if (kernel_type == kReference) {
+          reference_ops::Requantize<int8_t, uint8_t>(
+              GetTensorData<int8_t>(input), size, data->output_multiplier,
+              data->output_shift, input->params.zero_point,
+              output->params.zero_point, GetTensorData<uint8_t>(output));
+        } else {
+          optimized_ops::Requantize<int8_t, uint8_t>(
+              GetTensorData<int8_t>(input), size, data->output_multiplier,
+              data->output_shift, input->params.zero_point,
+              output->params.zero_point, GetTensorData<uint8_t>(output));
+        }
       } else {
         context->ReportError(
             context,
@@ -185,11 +224,25 @@
 // scale and zero point.
 TfLiteRegistration* Register_QUANTIZE_OPT() {
   static TfLiteRegistration r = {quantize::Init, quantize::Free,
-                                 quantize::Prepare, quantize::Eval};
+                                 quantize::Prepare,
+                                 quantize::Eval<quantize::kGenericOptimized>};
   return &r;
 }
 
-TfLiteRegistration* Register_QUANTIZE() { return Register_QUANTIZE_OPT(); }
+TfLiteRegistration* Register_QUANTIZE_REF() {
+  static TfLiteRegistration r = {quantize::Init, quantize::Free,
+                                 quantize::Prepare,
+                                 quantize::Eval<quantize::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_QUANTIZE() {
+#ifdef USE_NEON
+  return Register_QUANTIZE_OPT();
+#else
+  return Register_QUANTIZE_REF();
+#endif
+}
 
 }  // namespace builtin
 }  // namespace ops
diff --git a/tensorflow/lite/kernels/quantize_test.cc b/tensorflow/lite/kernels/quantize_test.cc
index e720f74..69b6f7d 100644
--- a/tensorflow/lite/kernels/quantize_test.cc
+++ b/tensorflow/lite/kernels/quantize_test.cc
@@ -129,6 +129,20 @@
               ElementsAreArray({1, 3, 5, 7, 9, 11, 13, 15, 17, 19}));
 }
 
+// Same as previous test, except more data to hit the neon path.
+TEST(QuantizeOpTest, Int8Int8SmallerScaleNeonPath) {
+  QuantizeOpModel m({TensorType_INT8, {1, 1, 4, 5}, -127, 128},
+                    {TensorType_INT8, {1, 1, 4, 5}, -63.5, 64});
+
+  // Input will quantized to {0,1,2,3,4,5,6,7,8,9,9,8,7,6,5,4,3,2,1,0}.
+  m.SetInputAndQuantize<int8_t>(
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int8_t>(),
+              ElementsAreArray({1,  3,  5,  7,  9,  11, 13, 15, 17, 19,
+                                19, 17, 15, 13, 11, 9,  7,  5,  3,  1}));
+}
+
 // Input scale 0.500000, output scale 0.500000, input zeropoint 127, output
 // zeropoint 127
 TEST(QuantizeOpTest, UInt8UInt8SameScale) {
@@ -171,6 +185,22 @@
       ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147}));
 }
 
+// Same as previous test, except more data to hit the neon path.
+TEST(QuantizeOpTest, Uint8Uint8SmallerScaleNeonPath) {
+  QuantizeOpModel m({TensorType_UINT8, {1, 1, 4, 5}, -127, 128},
+                    {TensorType_UINT8, {1, 1, 4, 5}, -63.5, 64});
+
+  // Input will quantized to {128, 129, 130, 131, 132, 133, 134, 135, 136, 137,
+  // 137, 136, 135, 134, 133, 132, 131, 130, 129, 128}.
+  m.SetInputAndQuantize<uint8_t>(
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
+  m.Invoke();
+  EXPECT_THAT(
+      m.GetOutput<uint8_t>(),
+      ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
+                        147, 145, 143, 141, 139, 137, 135, 133, 131, 129}));
+}
+
 // Input scale 1.000000, output scale 1.000000, input zeropoint -1, output
 // zeropoint 127
 TEST(QuantizeOpTest, Int8Uint8SameScale) {
diff --git a/tensorflow/lite/kernels/read_variable.cc b/tensorflow/lite/kernels/read_variable.cc
new file mode 100644
index 0000000..4996bcc
--- /dev/null
+++ b/tensorflow/lite/kernels/read_variable.cc
@@ -0,0 +1,88 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string.h>
+
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace read_variable {
+
+constexpr int kInputVariableId = 0;
+constexpr int kOutputValue = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  const TfLiteTensor* input_variable_id_tensor =
+      GetInput(context, node, kInputVariableId);
+  TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
+  TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputValue);
+  SetTensorToDynamic(output);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
+
+  const TfLiteTensor* input_variable_id_tensor =
+      GetInput(context, node, kInputVariableId);
+  int variable_id = input_variable_id_tensor->data.i32[0];
+  auto& resource_variables = subgraph->resource_variables();
+
+  const auto& variable_iterator = resource_variables.find(variable_id);
+  if (variable_iterator == resource_variables.end()) {
+    context->ReportError(context, "Variable ID %d is read before initialized.",
+                         variable_id);
+    return kTfLiteError;
+  }
+  auto& variable = variable_iterator->second;
+
+  TfLiteTensor* variable_tensor = variable.GetTensor();
+  TfLiteTensor* output = GetOutput(context, node, kOutputValue);
+
+  TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(
+                   context, output, TfLiteIntArrayCopy(variable_tensor->dims)));
+  memcpy(output->data.raw, variable_tensor->data.raw, output->bytes);
+
+  return kTfLiteOk;
+}
+
+}  // namespace read_variable
+
+TfLiteRegistration* Register_READ_VARIABLE() {
+  static TfLiteRegistration r = {nullptr, nullptr, read_variable::Prepare,
+                                 read_variable::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index 3474a40..1ac90c6 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -21,6 +21,7 @@
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
@@ -39,6 +40,7 @@
 // This file has reference implementation of reduce_* operators.
 enum KernelType {
   kReference,
+  kGenericOptimized,
 };
 
 struct OpData {
@@ -285,115 +287,160 @@
     TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
   }
 
-  // Defer to specialized implementation for 4D Mean across axes 1 & 2.
-  if (op_context.input->type == kTfLiteFloat32 ||
-      op_context.input->type == kTfLiteUInt8) {
-    tflite::MeanParams op_params;
-    op_params.axis_count = num_axis;
-    ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
-    const TfLiteTensor* input = op_context.input;
-    if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
-        op_params.axis_count == 2 &&
-        ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
-         (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
-      if (op_context.input->type == kTfLiteUInt8) {
-        optimized_ops::Mean(
-            op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-            op_context.input->params.zero_point, op_context.input->params.scale,
-            GetTensorShape(op_context.output),
-            GetTensorData<uint8_t>(op_context.output),
-            op_context.output->params.zero_point,
-            op_context.output->params.scale,
-            CpuBackendContext::GetFromContext(context));
-      } else {
-        reference_ops::Mean(op_params, GetTensorShape(input),
-                            GetTensorData<float>(input),
-                            GetTensorShape(op_context.output),
-                            GetTensorData<float>(op_context.output));
-      }
-      return kTfLiteOk;
-    }
-  }
-
-  if (op_context.input->type == kTfLiteInt8) {
-    tflite::MeanParams op_params;
-    op_params.axis_count = num_axis;
-    ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
-    const TfLiteTensor* input = op_context.input;
-    reference_integer_ops::Mean(
-        op_params, data->multiplier, data->shift, GetTensorShape(input),
-        GetTensorData<int8_t>(input), op_context.input->params.zero_point,
-        GetTensorShape(op_context.output),
-        GetTensorData<int8_t>(op_context.output),
-        op_context.output->params.zero_point);
-    return kTfLiteOk;
-  }
-
-#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type)        \
-  kernel_type::Mean<>(                                              \
-      GetTensorData<data_type>(op_context.input),                   \
-      op_context.input->dims->data, op_context.input->dims->size,   \
-      GetTensorData<data_type>(op_context.output),                  \
-      op_context.output->dims->data, op_context.output->dims->size, \
-      GetTensorData<int>(op_context.axis), num_axis,                \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index), \
-      GetTensorData<int>(resolved_axis),                            \
-      GetTensorData<temp_data_type>(temp_sum))
-
-  if (kernel_type == kReference) {
+  if (kernel_type == kGenericOptimized) {
+    // Use optimized ops if available.
     switch (op_context.input->type) {
-      case kTfLiteFloat32: {
+      case kTfLiteInt8: {
         tflite::MeanParams op_params;
         op_params.axis_count = num_axis;
         ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
         const TfLiteTensor* input = op_context.input;
+        optimized_integer_ops::Mean(
+            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            op_context.input->params.zero_point, op_context.input->params.scale,
+            GetTensorShape(op_context.output),
+            GetTensorData<int8_t>(op_context.output),
+            op_context.output->params.zero_point,
+            op_context.output->params.scale,
+            CpuBackendContext::GetFromContext(context));
+        return kTfLiteOk;
+      } break;
+      case kTfLiteUInt8: {
+        tflite::MeanParams op_params;
+        op_params.axis_count = num_axis;
+        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+        const TfLiteTensor* input = op_context.input;
+        // TODO(b/13910232): Handle the below special case in the optimized
+        // method.
         if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
             op_params.axis_count == 2 &&
             ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
              (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
-          reference_ops::Mean(op_params, GetTensorShape(input),
-                              GetTensorData<float>(input),
-                              GetTensorShape(op_context.output),
-                              GetTensorData<float>(op_context.output));
-        } else {
-          TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
+          optimized_ops::Mean(
+              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+              op_context.input->params.zero_point,
+              op_context.input->params.scale, GetTensorShape(op_context.output),
+              GetTensorData<uint8_t>(op_context.output),
+              op_context.output->params.zero_point,
+              op_context.output->params.scale,
+              CpuBackendContext::GetFromContext(context));
+          return kTfLiteOk;
         }
       } break;
-      case kTfLiteInt32:
-        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
-        break;
-      case kTfLiteInt64:
-        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
-        break;
-      case kTfLiteUInt8:
-        if (op_context.input->params.zero_point ==
-                op_context.output->params.zero_point &&
-            op_context.input->params.scale == op_context.output->params.scale) {
-          TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
-        } else {
-          TF_LITE_ENSURE(
-              context,
-              reference_ops::QuantizedMeanOrSum<>(
-                  GetTensorData<uint8_t>(op_context.input),
-                  op_context.input->params.zero_point,
-                  op_context.input->params.scale, op_context.input->dims->data,
-                  op_context.input->dims->size,
-                  GetTensorData<uint8_t>(op_context.output),
-                  op_context.output->params.zero_point,
-                  op_context.output->params.scale,
-                  op_context.output->dims->data, op_context.output->dims->size,
-                  GetTensorData<int>(op_context.axis), num_axis,
-                  op_context.params->keep_dims, GetTensorData<int>(temp_index),
-                  GetTensorData<int>(resolved_axis),
-                  GetTensorData<int>(temp_sum),
-                  /*compute_sum=*/false));
-        }
-        break;
       default:
-        return kTfLiteError;
+        break;
     }
   }
-#undef TF_LITE_MEAN
+
+  // From here, it uses the reference implementations.
+  // TODO(b/139102329): Clean up the function signatures to merge the variations
+  // and handle the specialized cases in the combined reference implementations
+  // per each op.
+  switch (op_context.input->type) {
+    case kTfLiteFloat32: {
+      tflite::MeanParams op_params;
+      op_params.axis_count = num_axis;
+      ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+      const TfLiteTensor* input = op_context.input;
+      // TODO(b/13910232): Handle the below special case in the combined
+      // reference method.
+      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
+      if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
+          op_params.axis_count == 2 &&
+          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+           (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
+        reference_ops::Mean(op_params, GetTensorShape(input),
+                            GetTensorData<float>(input),
+                            GetTensorShape(op_context.output),
+                            GetTensorData<float>(op_context.output));
+      } else {
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::Mean(
+                GetTensorData<float>(op_context.input),
+                op_context.input->dims->data, op_context.input->dims->size,
+                GetTensorData<float>(op_context.output),
+                op_context.output->dims->data, op_context.output->dims->size,
+                GetTensorData<int>(op_context.axis), num_axis,
+                op_context.params->keep_dims, GetTensorData<int>(temp_index),
+                GetTensorData<int>(resolved_axis),
+                GetTensorData<float>(temp_sum)));
+      }
+    } break;
+    case kTfLiteInt32:
+      TF_LITE_ENSURE(
+          context,
+          reference_ops::Mean(
+              GetTensorData<int>(op_context.input),
+              op_context.input->dims->data, op_context.input->dims->size,
+              GetTensorData<int>(op_context.output),
+              op_context.output->dims->data, op_context.output->dims->size,
+              GetTensorData<int>(op_context.axis), num_axis,
+              op_context.params->keep_dims, GetTensorData<int>(temp_index),
+              GetTensorData<int>(resolved_axis),
+              GetTensorData<int64_t>(temp_sum)));
+      break;
+    case kTfLiteInt64:
+      TF_LITE_ENSURE(
+          context,
+          reference_ops::Mean(
+              GetTensorData<int64_t>(op_context.input),
+              op_context.input->dims->data, op_context.input->dims->size,
+              GetTensorData<int64_t>(op_context.output),
+              op_context.output->dims->data, op_context.output->dims->size,
+              GetTensorData<int>(op_context.axis), num_axis,
+              op_context.params->keep_dims, GetTensorData<int>(temp_index),
+              GetTensorData<int>(resolved_axis),
+              GetTensorData<int64_t>(temp_sum)));
+      break;
+    case kTfLiteInt8: {
+      tflite::MeanParams op_params;
+      op_params.axis_count = num_axis;
+      ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+      const TfLiteTensor* input = op_context.input;
+      reference_integer_ops::Mean(
+          op_params, data->multiplier, data->shift, GetTensorShape(input),
+          GetTensorData<int8_t>(input), op_context.input->params.zero_point,
+          GetTensorShape(op_context.output),
+          GetTensorData<int8_t>(op_context.output),
+          op_context.output->params.zero_point);
+    } break;
+    case kTfLiteUInt8: {
+      if (op_context.input->params.zero_point ==
+              op_context.output->params.zero_point &&
+          op_context.input->params.scale == op_context.output->params.scale) {
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::Mean(
+                GetTensorData<uint8_t>(op_context.input),
+                op_context.input->dims->data, op_context.input->dims->size,
+                GetTensorData<uint8_t>(op_context.output),
+                op_context.output->dims->data, op_context.output->dims->size,
+                GetTensorData<int>(op_context.axis), num_axis,
+                op_context.params->keep_dims, GetTensorData<int>(temp_index),
+                GetTensorData<int>(resolved_axis),
+                GetTensorData<int>(temp_sum)));
+      } else {
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::QuantizedMeanOrSum<>(
+                GetTensorData<uint8_t>(op_context.input),
+                op_context.input->params.zero_point,
+                op_context.input->params.scale, op_context.input->dims->data,
+                op_context.input->dims->size,
+                GetTensorData<uint8_t>(op_context.output),
+                op_context.output->params.zero_point,
+                op_context.output->params.scale, op_context.output->dims->data,
+                op_context.output->dims->size,
+                GetTensorData<int>(op_context.axis), num_axis,
+                op_context.params->keep_dims, GetTensorData<int>(temp_index),
+                GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum),
+                /*compute_sum=*/false));
+      }
+    } break;
+    default:
+      return kTfLiteError;
+  }
   return kTfLiteOk;
 }
 
@@ -585,6 +632,13 @@
 }
 }  // namespace reduce
 
+TfLiteRegistration* Register_MEAN_OPT() {
+  static TfLiteRegistration r = {reduce::Init, reduce::Free,
+                                 reduce::PrepareMeanOrSum,
+                                 reduce::EvalMean<reduce::kGenericOptimized>};
+  return &r;
+}
+
 TfLiteRegistration* Register_MEAN_REF() {
   static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                  reduce::PrepareMeanOrSum,
@@ -626,8 +680,13 @@
   return &r;
 }
 
-// TODO(kanlig): add optimized implementation of Mean.
-TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
+TfLiteRegistration* Register_MEAN() {
+#ifdef USE_NEON
+  return Register_MEAN_OPT();
+#else
+  return Register_MEAN_REF();
+#endif
+}
 TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
 TfLiteRegistration* Register_REDUCE_PROD() {
   return Register_REDUCE_PROD_REF();
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 6832ac7..9a88231 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -200,7 +200,7 @@
              Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
              /* min_version */ 1,
-             /* max_version */ 5);
+             /* max_version */ 6);
   AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
   AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
   AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
@@ -219,7 +219,7 @@
              /* min_version */ 1,
              /* max_version */ 2);
   AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 3);
   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
              /* min_version */ 1,
              /* max_version */ 2);
@@ -251,7 +251,7 @@
              /* max_version */ 2);
   AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 3);
   AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
              /* min_version */ 1,
              /* max_version */ 3);
diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc
index 7da36a2..3cb0742 100644
--- a/tensorflow/lite/kernels/reshape.cc
+++ b/tensorflow/lite/kernels/reshape.cc
@@ -31,8 +31,10 @@
 constexpr int kShapeTensor = 1;
 constexpr int kOutputTensor = 0;
 
-TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
-                          TfLiteIntArray* output_shape) {
+TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*);
+
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteIntArray* output_shape = GetOutputShape(context, node);
   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
       scoped_output_shape(output_shape, TfLiteIntArrayFree);
 
@@ -65,8 +67,8 @@
   return context->ResizeTensor(context, output, scoped_output_shape.release());
 }
 
-TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
-                                         TfLiteNode* node) {
+inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
+                                                TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
 
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
@@ -77,8 +79,8 @@
   return output_shape;
 }
 
-TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
-                                        TfLiteNode* node) {
+inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
+                                               TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
 
   // The function is returned above this line if the shape tensor is usable.
@@ -99,7 +101,7 @@
 }
 
 // Check if the shape tensor is valid. Shapes should be int32 vectors.
-bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
+inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
   return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
 }
@@ -124,8 +126,7 @@
   if (output->type != kTfLiteString) {
     if (NumInputs(node) == 1 ||
         IsConstantTensor(GetInput(context, node, kShapeTensor))) {
-      TF_LITE_ENSURE_OK(
-          context, ResizeOutput(context, node, GetOutputShape(context, node)));
+      TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
     } else {
       SetTensorToDynamic(output);
     }
@@ -141,8 +142,7 @@
   // a string tensor, or its shape cannot be calculated during Prepare(). In
   // either case, we now have all the information to calculate its shape.
   if (IsDynamicTensor(output)) {
-    TF_LITE_ENSURE_OK(
-        context, ResizeOutput(context, node, GetOutputShape(context, node)));
+    TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
   }
 
   // Note that string tensors are always "dynamic" in the sense that their size
diff --git a/tensorflow/lite/kernels/resize_bilinear_test.cc b/tensorflow/lite/kernels/resize_bilinear_test.cc
index 1f1e05d..194ba51 100644
--- a/tensorflow/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/lite/kernels/resize_bilinear_test.cc
@@ -349,14 +349,16 @@
   });
   m.SetSize({3, 3});
   m.Invoke();
-  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
-                                         3, 5, 6,     //
-                                         7, 9, 10,    //
-                                         9, 11, 12,   //
-                                         4, 8, 10,    //
-                                         9, 12, 13,   //
-                                         12, 14, 16,  //
-                                     })));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
+                                         {
+                                             3, 5, 6,     //
+                                             7, 9, 10,    //
+                                             9, 11, 12,   //
+                                             4, 8, 10,    //
+                                             9, 12, 13,   //
+                                             12, 14, 16,  //
+                                         },
+                                         /*max_abs_error=*/1)));
 
   ResizeBilinearOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3});
   const_m.SetInput<int8_t>({
@@ -366,14 +368,16 @@
       12, 16  //
   });
   const_m.Invoke();
-  EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
-                                               3, 5, 6,     //
-                                               7, 9, 10,    //
-                                               9, 11, 12,   //
-                                               4, 8, 10,    //
-                                               9, 12, 13,   //
-                                               12, 14, 16,  //
-                                           })));
+  EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
+                                               {
+                                                   3, 5, 6,     //
+                                                   7, 9, 10,    //
+                                                   9, 11, 12,   //
+                                                   4, 8, 10,    //
+                                                   9, 12, 13,   //
+                                                   12, 14, 16,  //
+                                               },
+                                               /*max_abs_error=*/1)));
 }
 
 TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
@@ -415,11 +419,13 @@
   });
   m.SetSize({3, 3});
   m.Invoke();
-  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
-                                         3, 4, 5, 8, 6, 10,       //
-                                         7, 9, 10, 12, 11, 13,    //
-                                         10, 12, 12, 14, 14, 16,  //
-                                     })));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
+                                         {
+                                             3, 4, 5, 8, 6, 10,       //
+                                             7, 9, 10, 12, 11, 13,    //
+                                             10, 12, 12, 14, 14, 16,  //
+                                         },
+                                         /*max_abs_error=*/1)));
 
   ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3});
   const_m.SetInput<int8_t>({
@@ -427,11 +433,13 @@
       10, 12, 14, 16,  //
   });
   const_m.Invoke();
-  EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
-                                               3, 4, 5, 8, 6, 10,       //
-                                               7, 9, 10, 12, 11, 13,    //
-                                               10, 12, 12, 14, 14, 16,  //
-                                           })));
+  EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
+                                               {
+                                                   3, 4, 5, 8, 6, 10,       //
+                                                   7, 9, 10, 12, 11, 13,    //
+                                                   10, 12, 12, 14, 14, 16,  //
+                                               },
+                                               /*max_abs_error=*/1)));
 }
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc
index 3b4ee40..af30cad 100644
--- a/tensorflow/lite/kernels/slice.cc
+++ b/tensorflow/lite/kernels/slice.cc
@@ -43,10 +43,11 @@
 const int kMaxDim = 4;
 
 template <typename T>
-TfLiteStatus CalculateOutputShapeVector(
-    TfLiteContext* context, const TfLiteTensor* input,
-    const TfLiteTensor* begin, const TfLiteTensor* size,
-    std::vector<int64_t>* output_shape_vector) {
+TfLiteStatus CalculateOutputShapeVector(TfLiteContext* context,
+                                        const TfLiteTensor* input,
+                                        const TfLiteTensor* begin,
+                                        const TfLiteTensor* size,
+                                        std::vector<int>* output_shape_vector) {
   for (int idx = 0; idx < NumDimensions(input); ++idx) {
     T size_value = GetTensorData<T>(size)[idx];
     if (size_value < 0) {
@@ -62,7 +63,7 @@
         return kTfLiteError;
       }
     }
-    output_shape_vector->push_back(size_value);
+    output_shape_vector->push_back(static_cast<int>(size_value));
   }
   return kTfLiteOk;
 }
@@ -81,7 +82,7 @@
                                const TfLiteTensor* input,
                                const TfLiteTensor* begin,
                                const TfLiteTensor* size, TfLiteTensor* output) {
-  std::vector<int64_t> output_shape_vector;
+  std::vector<int> output_shape_vector;
 
   if (begin->type == kTfLiteInt32) {
     TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc
index ae04c96..3be938f 100644
--- a/tensorflow/lite/kernels/svdf.cc
+++ b/tensorflow/lite/kernels/svdf.cc
@@ -16,6 +16,9 @@
 // SVDF op that compresses a fully connected op via low-rank matrix
 // factorization. See https://research.google.com/pubs/archive/43813.pdf for
 // details.
+
+#include "tensorflow/lite/kernels/internal/reference/svdf.h"
+
 #include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -26,6 +29,7 @@
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/activation_functor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
@@ -43,60 +47,6 @@
   int activation_state_tensor_index;
 };
 
-static inline void ApplyTimeWeightsBiasAndActivation(
-    int batch_size, int memory_size, int num_filters, int num_units, int rank,
-    const TfLiteTensor* weights_time, const TfLiteTensor* bias,
-    TfLiteFusedActivation activation, TfLiteTensor* activation_state,
-    TfLiteTensor* scratch, TfLiteTensor* output) {
-  // Compute matmul(state, weights_time).
-  // The right most column is used to save temporary output (with the size of
-  // num_filters). This is achieved by starting at activation_state->data.f,
-  // and having the stride equal to memory_size.
-  for (int b = 0; b < batch_size; ++b) {
-    float* state_ptr_batch =
-        activation_state->data.f + b * memory_size * num_filters;
-    float* scratch_ptr_batch = scratch->data.f + b * num_filters;
-    tensor_utils::BatchVectorBatchVectorDotProduct(
-        weights_time->data.f, state_ptr_batch, memory_size, num_filters,
-        scratch_ptr_batch, /*result_stride=*/1);
-  }
-
-  // Initialize output with bias if provided.
-  if (bias) {
-    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
-                                          output->data.f);
-  } else {
-    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
-  }
-
-  // Reduction sum.
-  for (int b = 0; b < batch_size; ++b) {
-    float* output_ptr_batch = output->data.f + b * num_units;
-    float* scratch_ptr_batch = scratch->data.f + b * num_filters;
-    tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
-                                     num_units, rank);
-  }
-
-  // Apply activation.
-  for (int b = 0; b < batch_size; ++b) {
-    float* output_ptr_batch = output->data.f + b * num_units;
-    tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
-                                          activation, output_ptr_batch);
-  }
-
-  // Left shift the activation_state to make room for next cycle's activation.
-  // TODO(alanchiao): explore collapsing this into a single loop.
-  for (int b = 0; b < batch_size; ++b) {
-    float* state_ptr_batch =
-        activation_state->data.f + b * memory_size * num_filters;
-    for (int f = 0; f < num_filters; ++f) {
-      tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
-                                    /*shift_value=*/0.0f);
-      state_ptr_batch += memory_size;
-    }
-  }
-}
-
 }  // namespace
 
 // Input tensors.
@@ -113,6 +63,7 @@
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->float_weights_time_initialized = false;
+  // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise.
   context->AddTensors(context, /*tensors_to_add=*/4,
                       &op_data->scratch_tensor_index);
   return op_data;
@@ -241,123 +192,6 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       const TfLiteTensor* input,
-                       const TfLiteTensor* weights_feature,
-                       const TfLiteTensor* weights_time,
-                       const TfLiteTensor* bias, const TfLiteSVDFParams* params,
-                       TfLiteTensor* scratch, TfLiteTensor* state,
-                       TfLiteTensor* output) {
-  const int rank = params->rank;
-  const int batch_size = input->dims->data[0];
-  const int input_size = input->dims->data[1];
-  const int num_filters = weights_feature->dims->data[0];
-  const int num_units = num_filters / rank;
-  const int memory_size = weights_time->dims->data[1];
-
-  // Clear the activation (state left most column).
-  // TODO(ghodrat): Add a test which initialize activation_state with invalid
-  // values in left most column and make sure it passes.
-  for (int b = 0; b < batch_size; ++b) {
-    float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
-    for (int c = 0; c < num_filters; ++c) {
-      float* state_ptr = state_ptr_batch + c * memory_size;
-      state_ptr[memory_size - 1] = 0.0f;
-    }
-  }
-
-  // Compute conv1d(inputs, weights_feature).
-  // The state right most column is used to save current cycle activation. This
-  // is achieved by starting at state->data.f[memory_size - 1] and having the
-  // stride equal to memory_size.
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      weights_feature->data.f, num_filters, input_size, input->data.f,
-      batch_size, &state->data.f[memory_size - 1], memory_size);
-
-  ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
-                                    num_units, rank, weights_time, bias,
-                                    params->activation, state, scratch, output);
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
-    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
-    const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
-    const TfLiteTensor* bias, const TfLiteSVDFParams* params,
-    TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
-    TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
-  const int rank = params->rank;
-  const int batch_size = input->dims->data[0];
-  const int input_size = input->dims->data[1];
-  const int num_filters = weights_feature->dims->data[0];
-  const int num_units = num_filters / rank;
-  const int memory_size = weights_time->dims->data[1];
-
-  // Initialize the pointer to input.
-  const float* input_ptr_batch = input->data.f;
-
-  // Initialize the pointer to storage for quantized values and the weights
-  // feature.
-  int8_t* quantized_input_ptr_batch;
-  const int8_t* weights_feature_ptr;
-  if (weights_feature->type == kTfLiteUInt8) {
-    quantized_input_ptr_batch =
-        reinterpret_cast<int8_t*>(input_quantized->data.uint8);
-    weights_feature_ptr =
-        reinterpret_cast<int8_t*>(weights_feature->data.uint8);
-  } else {
-    quantized_input_ptr_batch = input_quantized->data.int8;
-    weights_feature_ptr = weights_feature->data.int8;
-  }
-
-  // Initialize the pointer to storage for scaling factors.
-  float* scaling_factors_ptr = scaling_factors->data.f;
-
-  // Initialize the weights scale.
-  const float weights_feature_scale = weights_feature->params.scale;
-
-  // Clear the activation (state left most column).
-  // TODO(ghodrat): Add a test which initialize state with invalid values in
-  // the left most column and make sure it passes.
-  for (int b = 0; b < batch_size; ++b) {
-    float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
-    for (int c = 0; c < num_filters; ++c) {
-      float* state_ptr = state_ptr_batch + c * memory_size;
-      state_ptr[memory_size - 1] = 0.0;
-    }
-  }
-
-  if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
-    // Quantize input from float to int8.
-    float unused_min, unused_max;
-    for (int b = 0; b < batch_size; ++b) {
-      const int offset = b * input_size;
-      tensor_utils::SymmetricQuantizeFloats(
-          input_ptr_batch + offset, input_size,
-          quantized_input_ptr_batch + offset, &unused_min, &unused_max,
-          &scaling_factors_ptr[b]);
-      scaling_factors_ptr[b] *= weights_feature_scale;
-    }
-
-    // Compute conv1d(inputs, weights_feature).
-    // The rightmost column of state is used to save the current cycle
-    // activation.
-    // This is achieved by starting at state->data.f[memory_size - 1]
-    // and having the stride equal to memory_size.
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
-        scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1],
-        memory_size);
-  }
-
-  // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying
-  // time weights so that the inner loop multiplies eight elements at a time.
-  ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
-                                    num_units, rank, weights_time, bias,
-                                    params->activation, state, scratch, output);
-  return kTfLiteOk;
-}
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
@@ -377,8 +211,10 @@
 
   switch (weights_feature->type) {
     case kTfLiteFloat32: {
-      return EvalFloat(context, node, input, weights_feature, weights_time,
-                       bias, params, scratch, activation_state, output);
+      reference_ops::EvalFloatSVDF(context, node, input, weights_feature,
+                                   weights_time, bias, params, scratch,
+                                   activation_state, output);
+      return kTfLiteOk;
       break;
     }
     case kTfLiteUInt8:
@@ -398,26 +234,29 @@
         const float dequantization_scale = weights_time->params.scale;
         const int8_t* weights_time_ptr;
         if (weights_feature->type == kTfLiteUInt8) {
-          weights_time_ptr =
-              reinterpret_cast<int8_t*>(weights_time->data.uint8);
+          weights_time_ptr = reinterpret_cast<const int8_t*>(
+              GetTensorData<uint8_t>(weights_time));
         } else {
-          weights_time_ptr = weights_time->data.int8;
+          weights_time_ptr = GetTensorData<int8_t>(weights_time);
         }
+        float* float_weights_time_ptr =
+            GetTensorData<float>(float_weights_time);
         for (int i = 0; i < NumElements(float_weights_time); ++i) {
-          float_weights_time->data.f[i] =
+          float_weights_time_ptr[i] =
               weights_time_ptr[i] * dequantization_scale;
         }
         op_data->float_weights_time_initialized = true;
       }
-      return EvalHybrid(context, node, input, weights_feature,
-                        float_weights_time, bias, params, scratch,
-                        scaling_factors, input_quantized, activation_state,
-                        output);
+      reference_ops::EvalHybridSVDF(context, node, input, weights_feature,
+                                    float_weights_time, bias, params, scratch,
+                                    scaling_factors, input_quantized,
+                                    activation_state, output);
+      return kTfLiteOk;
       break;
     }
     default:
-      context->ReportError(context, "Type %d not currently supported.",
-                           weights_feature->type);
+      context->ReportError(context, "Type %s not currently supported.",
+                           TfLiteTypeGetName(weights_feature->type));
       return kTfLiteError;
   }
 }
diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc
index 9c4dead..743d668 100644
--- a/tensorflow/lite/kernels/test_util.cc
+++ b/tensorflow/lite/kernels/test_util.cc
@@ -116,7 +116,8 @@
 
 void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
                                      int num_threads,
-                                     bool allow_fp32_relax_to_fp16) {
+                                     bool allow_fp32_relax_to_fp16,
+                                     bool apply_delegate) {
   auto opcodes = builder_.CreateVector(opcodes_);
   auto operators = builder_.CreateVector(operators_);
   auto tensors = builder_.CreateVector(tensors_);
@@ -161,6 +162,13 @@
       << "Cannot allocate tensors";
   interpreter_->ResetVariableTensors();
 
+  // In some rare cases a test may need to postpone modifying the graph with
+  // a delegate, e.g. if tensors are not fully specified. In such cases the
+  // test has to explicitly call ApplyDelegate() when necessary.
+  if (apply_delegate) ApplyDelegate();
+}
+
+void SingleOpModel::ApplyDelegate() {
   if (force_use_nnapi) {
     // TODO(b/124505407): Check the result and fail accordingly.
     interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
@@ -179,18 +187,22 @@
 void SingleOpModel::BuildInterpreter(
     std::vector<std::vector<int>> input_shapes) {
   BuildInterpreter(input_shapes, /*num_threads=*/-1,
-                   /*allow_fp32_relax_to_fp16=*/false);
+                   /*allow_fp32_relax_to_fp16=*/false,
+                   /*apply_delegate=*/true);
+}
+
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     bool allow_fp32_relax_to_fp16,
+                                     bool apply_delegate) {
+  BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
+                   apply_delegate);
 }
 
 void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
                                      int num_threads) {
   BuildInterpreter(input_shapes, num_threads,
-                   /*allow_fp32_relax_to_fp16=*/false);
-}
-
-void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
-                                     bool allow_fp32_relax_to_fp16) {
-  BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
+                   /*allow_fp32_relax_to_fp16=*/false,
+                   /*apply_delegate=*/true);
 }
 
 // static
@@ -198,6 +210,9 @@
   force_use_nnapi = use_nnapi;
 }
 
+// static
+bool SingleOpModel::GetForceUseNnapi() { return force_use_nnapi; }
+
 int32_t SingleOpModel::GetTensorSize(int index) const {
   TfLiteTensor* t = interpreter_->tensor(index);
   CHECK(t);
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 1faae70..31fde68 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -151,6 +151,8 @@
     apply_delegate_fn_ = apply_delegate_fn;
   }
 
+  void ApplyDelegate();
+
   // Copying or assignment is disallowed to simplify ownership semantics.
   SingleOpModel(const SingleOpModel&) = delete;
   SingleOpModel& operator=(const SingleOpModel&) = delete;
@@ -255,13 +257,14 @@
   // Build the interpreter for this model. Also, resize and allocate all
   // tensors given the shapes of the inputs.
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
-                        int num_threads, bool allow_fp32_relax_to_fp16);
+                        int num_threads, bool allow_fp32_relax_to_fp16,
+                        bool apply_delegate = true);
 
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
                         int num_threads);
 
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
-                        bool allow_fp32_relax_to_fp16);
+                        bool allow_fp32_relax_to_fp16, bool apply_delegate);
 
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
 
@@ -358,6 +361,7 @@
 
   // Enables NNAPI delegate application during interpreter creation.
   static void SetForceUseNnapi(bool use_nnapi);
+  static bool GetForceUseNnapi();
 
  protected:
   int32_t GetTensorSize(int index) const;
diff --git a/tensorflow/lite/kernels/tile.cc b/tensorflow/lite/kernels/tile.cc
index 1b74797..dc049ca 100644
--- a/tensorflow/lite/kernels/tile.cc
+++ b/tensorflow/lite/kernels/tile.cc
@@ -70,10 +70,10 @@
   }
 }
 
-template <typename T>
-void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
+template <typename T, typename M>
+void CopyMultipleTimes(const T* in_data, int32_t in_size, M multiplier,
                        T* out_data) {
-  for (int i = 0; i < multiplier; ++i) {
+  for (M i = 0; i < multiplier; ++i) {
     const T* in_end = in_data + in_size;
     T* new_out_data = std::copy(in_data, in_end, out_data);
     in_data = out_data;
@@ -109,8 +109,9 @@
   CopyMultipleTimes(out_data, total_tiled_stride_size,
                     multipliers[dimension] - 1,
                     out_data + total_tiled_stride_size);
-  return std::make_pair(total_stride_size,
-                        total_tiled_stride_size * multipliers[dimension]);
+  return std::make_pair(
+      total_stride_size,
+      static_cast<int>(total_tiled_stride_size * multipliers[dimension]));
 }
 
 template <typename T>
diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc
index 3af2e96..511ea85 100644
--- a/tensorflow/lite/kernels/unpack.cc
+++ b/tensorflow/lite/kernels/unpack.cc
@@ -36,7 +36,7 @@
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
-  TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+  TF_LITE_ENSURE(context, NumElements(input) > 0);
   int axis = data->axis;
   if (axis < 0) {
     axis += NumDimensions(input);
diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc
index fb38b50..28d21cc 100644
--- a/tensorflow/lite/kernels/unpack_test.cc
+++ b/tensorflow/lite/kernels/unpack_test.cc
@@ -126,6 +126,13 @@
                /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
 }
 
+TEST(UnpackOpTest, FloatVectorToScalar) {
+  Check<float>(/*axis=*/0, /*input_shape=*/{5},
+               /*input_data=*/{1, 2, 3, 4, 5},
+               /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+               /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}});
+}
+
 // int32 tests.
 TEST(UnpackOpTest, IntThreeOutputs) {
   Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@@ -159,6 +166,14 @@
                  /*type=*/TensorType_INT32);
 }
 
+TEST(UnpackOpTest, IntVectorToScalar) {
+  Check<int32_t>(/*axis=*/0, /*input_shape=*/{5},
+                 /*input_data=*/{1, 2, 3, 4, 5},
+                 /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+                 /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
+                 /*type=*/TensorType_INT32);
+}
+
 // uint8 tests.
 TEST(UnpackOpTest, Uint8ThreeOutputs) {
   Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@@ -208,6 +223,14 @@
                  /*type=*/TensorType_UINT8);
 }
 
+TEST(UnpackOpTest, Uint8VectorToScalar) {
+  Check<uint8_t>(/*axis=*/0, /*input_shape=*/{5},
+                 /*input_data=*/{1, 2, 3, 4, 5},
+                 /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+                 /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
+                 /*type=*/TensorType_UINT8);
+}
+
 // int8 tests.
 TEST(UnpackOpTest, Int8ThreeOutputs) {
   Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
@@ -257,5 +280,13 @@
                 /*type=*/TensorType_INT8);
 }
 
+TEST(UnpackOpTest, Int8VectorToScalar) {
+  Check<int8_t>(/*axis=*/0, /*input_shape=*/{5},
+                /*input_data=*/{1, 2, 3, 4, 5},
+                /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+                /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}},
+                /*type=*/TensorType_INT8);
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/variable_ops_test.cc b/tensorflow/lite/kernels/variable_ops_test.cc
new file mode 100644
index 0000000..e6e1a40
--- /dev/null
+++ b/tensorflow/lite/kernels/variable_ops_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+
+// Forward declaraction for op kernels.
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_ASSIGN_VARIABLE();
+TfLiteRegistration* Register_READ_VARIABLE();
+
+}  // namespace custom
+}  // namespace ops
+
+namespace {
+
+class VariableOpsTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    assign_registration_ = ::tflite::ops::custom::Register_ASSIGN_VARIABLE();
+    ASSERT_NE(assign_registration_, nullptr);
+    read_registration_ = ::tflite::ops::custom::Register_READ_VARIABLE();
+    ASSERT_NE(read_registration_, nullptr);
+
+    ConstructGraph();
+  }
+
+  void ConstructGraph() {
+    // Construct a graph like ths:
+    //   Input: %0, %1, %2
+    //   Output: %3
+    //   variable_assign(%0, %2)
+    //   %3 = read(%1)
+
+    int first_new_tensor_index;
+    ASSERT_EQ(interpreter_.AddTensors(4, &first_new_tensor_index), kTfLiteOk);
+    ASSERT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
+    ASSERT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk);
+    interpreter_.SetTensorParametersReadWrite(0, kTfLiteInt32, "", 0, nullptr,
+                                              {}, false);
+    interpreter_.SetTensorParametersReadWrite(1, kTfLiteInt32, "", 0, nullptr,
+                                              {}, false);
+    interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", 0, nullptr,
+                                              {}, false);
+    interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", 0, nullptr,
+                                              {}, false);
+    int node_index;
+    interpreter_.AddNodeWithParameters({0, 2}, {}, nullptr, 0, nullptr,
+                                       assign_registration_, &node_index);
+    interpreter_.AddNodeWithParameters({1}, {3}, nullptr, 0, nullptr,
+                                       read_registration_, &node_index);
+  }
+  TfLiteRegistration* assign_registration_;
+  TfLiteRegistration* read_registration_;
+  Interpreter interpreter_;
+};
+
+TEST_F(VariableOpsTest, TestAssignThenReadVariable) {
+  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
+  TfLiteTensor* input_assign_index = interpreter_.tensor(0);
+  input_assign_index->data.i32[0] = 1;
+  TfLiteTensor* input_read_index = interpreter_.tensor(1);
+  input_read_index->data.i32[0] = 1;
+  TfLiteTensor* input_data_index = interpreter_.tensor(2);
+  input_data_index->data.f[0] = 1717;
+  ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+
+  // Verify output.
+  TfLiteTensor* output = interpreter_.tensor(3);
+  ASSERT_EQ(output->dims->size, 0);
+  EXPECT_EQ(output->data.f[0], 1717);
+}
+
+TEST_F(VariableOpsTest, TestReadVariableBeforeAssign) {
+  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
+  TfLiteTensor* input_assign_index = interpreter_.tensor(0);
+  input_assign_index->data.i32[0] = 1;
+  TfLiteTensor* input_read_index = interpreter_.tensor(1);
+  input_read_index->data.i32[0] = 2;
+  TfLiteTensor* input_data_index = interpreter_.tensor(2);
+  input_data_index->data.f[0] = 1717;
+
+  // Error because variable 2 is never initialized.
+  ASSERT_EQ(interpreter_.Invoke(), kTfLiteError);
+}
+
+TEST_F(VariableOpsTest, TestReeasignToDifferentSize) {
+  // 1st invocation. The variable is assigned as a scalar.
+  {
+    ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
+
+    TfLiteTensor* input_assign_index = interpreter_.tensor(0);
+    input_assign_index->data.i32[0] = 1;
+    TfLiteTensor* input_read_index = interpreter_.tensor(1);
+    input_read_index->data.i32[0] = 1;
+    TfLiteTensor* input_data_index = interpreter_.tensor(2);
+    input_data_index->data.f[0] = 1717;
+    ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+
+    // Verify output.
+    TfLiteTensor* output = interpreter_.tensor(3);
+    ASSERT_EQ(output->dims->size, 0);
+    EXPECT_EQ(output->data.f[0], 1717);
+  }
+
+  // 2nd invocation. The variable is assigned as a 1D vector with 2 elements.
+  {
+    interpreter_.ResizeInputTensor(2, {2});
+    ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
+
+    TfLiteTensor* input_assign_index = interpreter_.tensor(0);
+    input_assign_index->data.i32[0] = 1;
+    TfLiteTensor* input_read_index = interpreter_.tensor(1);
+    input_read_index->data.i32[0] = 1;
+    TfLiteTensor* input_data_index = interpreter_.tensor(2);
+    input_data_index->data.f[0] = 1717;
+    input_data_index->data.f[1] = 2121;
+    ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+
+    // Verify output.
+    TfLiteTensor* output = interpreter_.tensor(3);
+    ASSERT_EQ(output->dims->size, 1);
+    ASSERT_EQ(output->dims->data[0], 2);
+    EXPECT_EQ(output->data.f[0], 1717);
+    EXPECT_EQ(output->data.f[1], 2121);
+  }
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index 5fd9e21..516ba69 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -159,6 +159,22 @@
   return model;
 }
 
+string FlatBufferModel::GetMinimumRuntime() const {
+  if (!model_ || !model_->metadata()) return "";
+
+  for (int i = 0; i < model_->metadata()->size(); ++i) {
+    auto metadata = model_->metadata()->Get(i);
+    if (metadata->name()->str() == "min_runtime_version") {
+      auto buf = metadata->buffer();
+      auto* buffer = (*model_->buffers())[buf];
+      auto* array = buffer->data();
+      return string(reinterpret_cast<const char*>(array->data()),
+                    array->size());
+    }
+  }
+  return "";
+}
+
 bool FlatBufferModel::CheckModelIdentifier() const {
   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
@@ -311,13 +327,21 @@
           EnumNameBuiltinOperator(op_type));
     }
 
-    if (op->custom_options()) {
-      subgraph->AddNodeWithParameters(
-          FlatBufferIntArrayToVector(op->inputs()),
-          FlatBufferIntArrayToVector(op->outputs()),
-          FlatBufferIntArrayToVector(op->intermediates()),
-          reinterpret_cast<const char*>(op->custom_options()->data()),
-          op->custom_options()->size(), nullptr, registration);
+    if (op_type == BuiltinOperator_CUSTOM) {
+      if (op->custom_options()) {
+        subgraph->AddNodeWithParameters(
+            FlatBufferIntArrayToVector(op->inputs()),
+            FlatBufferIntArrayToVector(op->outputs()),
+            FlatBufferIntArrayToVector(op->intermediates()),
+            reinterpret_cast<const char*>(op->custom_options()->data()),
+            op->custom_options()->size(), nullptr, registration);
+      } else {
+        subgraph->AddNodeWithParameters(
+            FlatBufferIntArrayToVector(op->inputs()),
+            FlatBufferIntArrayToVector(op->outputs()),
+            FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
+            nullptr, registration);
+      }
     } else {
       void* builtin_data = nullptr;
       MallocDataAllocator malloc_allocator;
diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h
index 6c56947..06dd2e2 100644
--- a/tensorflow/lite/model.h
+++ b/tensorflow/lite/model.h
@@ -135,6 +135,12 @@
   ErrorReporter* error_reporter() const { return error_reporter_; }
   const Allocation* allocation() const { return allocation_.get(); }
 
+  // Returns the minimum runtime version from the flatbuffer. This runtime
+  // version encodes the minimum required interpreter version to run the
+  // flatbuffer model. If the minimum version can't be determined, an empty
+  // string will be returned.
+  string GetMinimumRuntime() const;
+
   /// Returns true if the model identifier is correct (otherwise false and
   /// reports an error).
   bool CheckModelIdentifier() const;
diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc
index d58dbf4..7dc582b 100644
--- a/tensorflow/lite/model_test.cc
+++ b/tensorflow/lite/model_test.cc
@@ -315,6 +315,22 @@
   ASSERT_NE(interpreter, nullptr);
 }
 
+// Test reading the minimum runtime string from metadata in a Model flatbuffer.
+TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
+  // First read a model that doesn't have the runtime string.
+  auto model1 = FlatBufferModel::BuildFromFile(
+      "tensorflow/lite/testdata/test_model.bin");
+  ASSERT_TRUE(model1);
+  ASSERT_EQ(model1->GetMinimumRuntime(), "");
+
+  // Read a model that has minimum runtime string populated.
+  auto model2 = FlatBufferModel::BuildFromFile(
+      "tensorflow/lite/testdata/test_min_runtime.bin");
+  ASSERT_TRUE(model2);
+  // Check that we have read the runtime string correctly.
+  ASSERT_EQ(model2->GetMinimumRuntime(), "1.10.0");
+}
+
 // TODO(aselle): Add tests for serialization of builtin op data types.
 // These tests will occur with the evaluation tests of individual operators,
 // not here.
diff --git a/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
index fbd7505..cbd155b 100644
--- a/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
+++ b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
@@ -53,8 +53,13 @@
   @WorkerThread
   public synchronized void loadModel() {
     if (!isLibraryLoaded) {
-      System.loadLibrary(JNI_LIB);
-      isLibraryLoaded = true;
+      try {
+        System.loadLibrary(JNI_LIB);
+        isLibraryLoaded = true;
+      } catch (Exception e) {
+        Log.e(TAG, "Failed to load prebuilt smartreply_jni lib", e);
+        return;
+      }
     }
 
     try {
diff --git a/tensorflow/lite/models/smartreply/g3doc/README.md b/tensorflow/lite/models/smartreply/g3doc/README.md
index 1b8ff15..0443929 100644
--- a/tensorflow/lite/models/smartreply/g3doc/README.md
+++ b/tensorflow/lite/models/smartreply/g3doc/README.md
@@ -62,8 +62,8 @@
 ## How to use this Model?
 
 We have provided a pre-built demo APK that you can download, install and test on
-your phone ([demo APK
-here](http://download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)).
+your phone
+([demo APK here](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)).
 
 The On-Device Smart Reply demo App works in the following way:
 
diff --git a/tensorflow/lite/models/speech_test.cc b/tensorflow/lite/models/speech_test.cc
index 63436ef..4b40858 100644
--- a/tensorflow/lite/models/speech_test.cc
+++ b/tensorflow/lite/models/speech_test.cc
@@ -108,7 +108,7 @@
       "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0",
       /*output_tensor=*/"18", /*persistent_tensors=*/"4",
       /*sequence_size=*/40, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -120,7 +120,7 @@
       "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17",
       /*output_tensor=*/"18", /*persistent_tensors=*/"1",
       /*sequence_size=*/40, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -133,7 +133,7 @@
       /*output_tensor=*/"63",
       /*persistent_tensors=*/"18,19,38,39,58,59",
       /*sequence_size=*/80, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -146,7 +146,7 @@
                      /*output_tensor=*/"104",
                      /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
                      /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -159,7 +159,7 @@
       /*output_tensor=*/"104",
       /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
       /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -170,7 +170,7 @@
 // results.
 TEST_P(SpeechTest, DISABLED_AsrLmTest) {
   std::ifstream in_file;
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
   ASSERT_TRUE(
       testing::ParseAndRunTests(&in_file, &test_driver, GetMaxInvocations()))
@@ -185,7 +185,7 @@
       /*output_tensor=*/"56",
       /*persistent_tensors=*/"27,28,47,48",
       /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -198,7 +198,7 @@
                              /*output_tensor=*/"71",
                              /*persistent_tensors=*/"24,25,44,45,64,65,70",
                              /*sequence_size=*/334, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD
index 7ea58c5..228f7d4 100644
--- a/tensorflow/lite/nnapi/BUILD
+++ b/tensorflow/lite/nnapi/BUILD
@@ -12,6 +12,9 @@
         "NeuralNetworksTypes.h",
     ],
     linkopts = select({
+        "//tensorflow:emscripten": [],
+        "//tensorflow:ios": [],
+        "//tensorflow:macos": [],
         "//tensorflow:windows": [],
         "//conditions:default": ["-ldl"],
     }),
@@ -20,9 +23,15 @@
 cc_library(
     name = "nnapi_implementation",
     srcs = select({
+        "//tensorflow:emscripten": [
+            "nnapi_implementation_disabled.cc",
+        ],
         "//tensorflow:ios": [
             "nnapi_implementation_disabled.cc",
         ],
+        "//tensorflow:macos": [
+            "nnapi_implementation_disabled.cc",
+        ],
         "//tensorflow:windows": [
             "nnapi_implementation_disabled.cc",
         ],
@@ -34,12 +43,16 @@
         "nnapi_implementation.h",
     ],
     linkopts = select({
+        "//tensorflow:emscripten": [],
+        "//tensorflow:ios": [],
+        "//tensorflow:macos": [],
         "//tensorflow:windows": [],
         "//conditions:default": ["-ldl"],
     }) + select({
         "//tensorflow:android": [],
-        "//tensorflow:macos": [],
+        "//tensorflow:emscripten": [],
         "//tensorflow:ios": [],
+        "//tensorflow:macos": [],
         "//tensorflow:windows": [],
         "//conditions:default": ["-lrt"],
     }),
diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
index 6b5d8e2..b4ec12e 100644
--- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h
+++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
@@ -41,6 +41,7 @@
   ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
   ANEURALNETWORKS_BOOL = 6,
   ANEURALNETWORKS_TENSOR_BOOL8 = 9,
+  ANEURALNETWORKS_TENSOR_QUANT16_SYMM = 7,
   ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11,
   ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13,
 };
@@ -115,6 +116,12 @@
   ANEURALNETWORKS_POW = 70,
   ANEURALNETWORKS_PRELU = 71,
   ANEURALNETWORKS_QUANTIZE = 72,
+  ANEURALNETWORKS_QUANTIZED_16BIT_LSTM = 73,
+  ANEURALNETWORKS_REDUCE_ANY = 76,
+  ANEURALNETWORKS_REDUCE_MAX = 77,
+  ANEURALNETWORKS_REDUCE_MIN = 78,
+  ANEURALNETWORKS_REDUCE_PROD = 79,
+  ANEURALNETWORKS_REDUCE_SUM = 80,
   ANEURALNETWORKS_RSQRT = 83,
   ANEURALNETWORKS_SELECT = 84,
   ANEURALNETWORKS_SIN = 85,
@@ -123,8 +130,10 @@
   ANEURALNETWORKS_SQRT = 88,
   ANEURALNETWORKS_TILE = 89,
   ANEURALNETWORKS_TOPK_V2 = 90,
+  ANEURALNETWORKS_TRANSPOSE_CONV = 91,
   ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM = 92,
   ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93,
+  ANEURALNETWORKS_RESIZE_NEAREST_NEIGHBOR = 94,
 };
 
 /**
diff --git a/tensorflow/lite/nnapi/nnapi_implementation.cc b/tensorflow/lite/nnapi/nnapi_implementation.cc
index bc5159f..c30a24a 100644
--- a/tensorflow/lite/nnapi/nnapi_implementation.cc
+++ b/tensorflow/lite/nnapi/nnapi_implementation.cc
@@ -178,7 +178,13 @@
     }
   }
 #else
-  nnapi.ASharedMemory_create = ASharedMemory_create;
+  // Mock ASharedMemory_create only if libneuralnetworks.so was successfully
+  // loaded. This ensures identical behaviour on platforms which use this
+  // implementation, but don't have libneuralnetworks.so library, and
+  // platforms which use nnapi_implementation_disabled.cc stub.
+  if (libneuralnetworks != nullptr) {
+    nnapi.ASharedMemory_create = ASharedMemory_create;
+  }
 #endif  // __ANDROID__
 
   // API 28 (NN 1.1) methods.
diff --git a/tensorflow/lite/nnapi/nnapi_implementation_test.cc b/tensorflow/lite/nnapi/nnapi_implementation_test.cc
index 9f30b95..0d696af 100644
--- a/tensorflow/lite/nnapi/nnapi_implementation_test.cc
+++ b/tensorflow/lite/nnapi/nnapi_implementation_test.cc
@@ -116,7 +116,7 @@
   EXPECT_EQ(nnapi->ANeuralNetworksExecution_startCompute, nullptr);
   EXPECT_EQ(nnapi->ANeuralNetworksEvent_wait, nullptr);
   EXPECT_EQ(nnapi->ANeuralNetworksEvent_free, nullptr);
-  EXPECT_NE(nnapi->ASharedMemory_create, nullptr);
+  EXPECT_EQ(nnapi->ASharedMemory_create, nullptr);
   EXPECT_EQ(nnapi->ANeuralNetworks_getDeviceCount, nullptr);
   EXPECT_EQ(nnapi->ANeuralNetworks_getDevice, nullptr);
   EXPECT_EQ(nnapi->ANeuralNetworksDevice_getName, nullptr);
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index df7c07f..ca00546 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -111,7 +111,6 @@
     srcs = ["lite_v2_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_oss",
         "no_windows",
     ],
     deps = [
@@ -142,7 +141,6 @@
     srcs = ["lite_mlir_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_oss",
         "no_windows",
     ],
     deps = [
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 328e44e..9fe8b25 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -161,7 +161,7 @@
       # Some of the subtests within the "convert_test" unit-test fail
       # with the error shown above. So watch out for that scenario and
       # convert debug_info_str to bytes where needed
-      if isinstance(debug_info_str, str):
+      if not isinstance(debug_info_str, bytes):
         fp_debug.write(debug_info_str.encode("utf-8"))
       else:
         fp_debug.write(debug_info_str)
diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py
index 382c351..543ddda 100644
--- a/tensorflow/lite/python/convert_test.py
+++ b/tensorflow/lite/python/convert_test.py
@@ -25,6 +25,7 @@
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
 from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
@@ -34,32 +35,27 @@
 from tensorflow.python.platform import test
 
 
-@test_util.run_v1_only("Incompatible with 2.0.")
 class ConvertTest(test_util.TensorFlowTestCase):
 
   def testBasic(self):
-    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
-                                      dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Try running on valid graph
     tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
                                         [out_tensor])
     self.assertTrue(tflite_model)
 
-    # TODO(aselle): remove tests that fail (we must get TOCO to not fatal
-    # all the time).
-    # Try running on identity graph (known fail)
-    # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
-    #   result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
-
   def testQuantization(self):
-    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
-                                      dtype=dtypes.float32)
-    out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
-                                                        min=0., max=1.)
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1.)
+      sess = session.Session()
 
     tflite_model = convert.toco_convert(
         sess.graph_def, [in_tensor], [out_tensor],
@@ -68,11 +64,12 @@
     self.assertTrue(tflite_model)
 
   def testQuantizationInvalid(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor + in_tensor, min=0., max=1.)
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1.)
+      sess = session.Session()
 
     with self.assertRaises(ValueError) as error:
       convert.toco_convert(
@@ -83,10 +80,11 @@
         "QUANTIZED_UINT8.", str(error.exception))
 
   def testGraphDefBasic(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     tflite_model = convert.toco_convert_graph_def(
         sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
@@ -113,13 +111,14 @@
     self.assertEqual((0., 0.), output_details[0]["quantization"])
 
   def testGraphDefQuantization(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
-    in_tensor_2 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
-    _ = array_ops.fake_quant_with_min_max_args(
-        in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+      sess = session.Session()
 
     input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
     output_arrays = ["output"]
@@ -158,13 +157,14 @@
     self.assertTrue(output_details[0]["quantization"][0] > 0)  # scale
 
   def testGraphDefQuantizationInvalid(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
-    in_tensor_2 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
-    _ = array_ops.fake_quant_with_min_max_args(
-        in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+      sess = session.Session()
 
     input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
     output_arrays = ["output"]
@@ -180,7 +180,6 @@
         "QUANTIZED_UINT8.", str(error.exception))
 
 
-@test_util.run_v1_only("Incompatible with 2.0.")
 class ConvertTestOpHint(test_util.TensorFlowTestCase):
   """Test the hint to stub functionality."""
 
@@ -219,82 +218,91 @@
 
   def testSwishLiteHint(self):
     """Makes a custom op swish and makes sure it gets converted as a unit."""
-    image = array_ops.constant([1., 2., 3., 4.])
-    swish_scale = array_ops.constant(1.0)
+    with ops.Graph().as_default():
+      image = array_ops.constant([1., 2., 3., 4.])
+      swish_scale = array_ops.constant(1.0)
 
-    def _swish(input_tensor, scale):
-      custom = op_hint.OpHint("cool_activation")
-      input_tensor, scale = custom.add_inputs(input_tensor, scale)
-      output = math_ops.sigmoid(input_tensor) * input_tensor * scale
-      output, = custom.add_outputs(output)
-      return output
-    output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
+      def _swish(input_tensor, scale):
+        custom = op_hint.OpHint("cool_activation")
+        input_tensor, scale = custom.add_inputs(input_tensor, scale)
+        output = math_ops.sigmoid(input_tensor) * input_tensor * scale
+        output, = custom.add_outputs(output)
+        return output
 
-    with self.cached_session() as sess:
-      # check if identities have been put into the graph (2 input, 1 output,
-      # and 1 final output).
-      self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
+      output = array_ops.identity(
+          _swish(image, swish_scale), name="ModelOutput")
 
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
+      with self.cached_session() as sess:
+        # check if identities have been put into the graph (2 input, 1 output,
+        # and 1 final output).
+        self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
 
-      self.assertEqual(
-          self._getGraphOpTypes(
-              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
-          set(["cool_activation", "Const", "Identity"]))
+        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+            graph_def=sess.graph_def)
+
+        self.assertEqual(
+            self._getGraphOpTypes(
+                stubbed_graphdef,
+                output_nodes=[op_hint._tensor_name_base(output.name)]),
+            set(["cool_activation", "Const", "Identity"]))
 
   def testScaleAndBiasAndIdentity(self):
     """This tests a scaled add which has 3 inputs and 2 outputs."""
-    a = array_ops.constant(1.)
-    x = array_ops.constant([2., 3.])
-    b = array_ops.constant([4., 5.])
+    with ops.Graph().as_default():
+      a = array_ops.constant(1.)
+      x = array_ops.constant([2., 3.])
+      b = array_ops.constant([4., 5.])
 
-    def _scaled_and_bias_and_identity(a, x, b):
-      custom = op_hint.OpHint("scale_and_bias_and_identity")
-      a, x, b = custom.add_inputs(a, x, b)
-      return custom.add_outputs(a * x + b, x)
-    output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
-                                name="ModelOutput")
+      def _scaled_and_bias_and_identity(a, x, b):
+        custom = op_hint.OpHint("scale_and_bias_and_identity")
+        a, x, b = custom.add_inputs(a, x, b)
+        return custom.add_outputs(a * x + b, x)
 
-    with self.cached_session() as sess:
-      # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
-      # +1 for the final output
-      self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
+      output = array_ops.identity(
+          _scaled_and_bias_and_identity(a, x, b), name="ModelOutput")
 
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
+      with self.cached_session() as sess:
+        # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
+        # +1 for the final output
+        self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
 
-      self.assertEqual(
-          self._getGraphOpTypes(
-              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
-          set(["scale_and_bias_and_identity", "Const", "Identity", "Pack"]))
+        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+            graph_def=sess.graph_def)
+
+        self.assertEqual(
+            self._getGraphOpTypes(
+                stubbed_graphdef,
+                output_nodes=[op_hint._tensor_name_base(output.name)]),
+            set(["scale_and_bias_and_identity", "Const", "Identity", "Pack"]))
 
   def testTwoFunctions(self):
     """Tests if two functions are converted correctly."""
-    a = array_ops.constant([1.])
-    b = array_ops.constant([1.])
-    def _double_values(x):
-      custom = op_hint.OpHint("add_test")
-      x, = custom.add_inputs(x)
-      output = math_ops.multiply(x, x)
-      output, = custom.add_outputs(output)
-      return output
-    output = array_ops.identity(
-        math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
+    with ops.Graph().as_default():
+      a = array_ops.constant([1.])
+      b = array_ops.constant([1.])
 
-    with self.cached_session() as sess:
-      # make sure one identity for each input (2) and output (2) => 2 + 2
-      # +1 for the final output
-      self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
-      self.assertEqual(
-          self._getGraphOpTypes(
-              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
-          set(["add_test", "Const", "Identity", "Add"]))
+      def _double_values(x):
+        custom = op_hint.OpHint("add_test")
+        x, = custom.add_inputs(x)
+        output = math_ops.multiply(x, x)
+        output, = custom.add_outputs(output)
+        return output
+
+      output = array_ops.identity(
+          math_ops.add(_double_values(a), _double_values(b)),
+          name="ModelOutput")
+
+      with self.cached_session() as sess:
+        # make sure one identity for each input (2) and output (2) => 2 + 2
+        # +1 for the final output
+        self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
+        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+            graph_def=sess.graph_def)
+        self.assertEqual(
+            self._getGraphOpTypes(
+                stubbed_graphdef,
+                output_nodes=[op_hint._tensor_name_base(output.name)]),
+            set(["add_test", "Const", "Identity", "Add"]))
 
   def _get_input_index(self, x):
     return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
@@ -307,93 +315,97 @@
 
   def testTags(self):
     """Test if multiple args with the same tag are grouped."""
-    a = array_ops.constant([1.])
-    b = array_ops.constant([2.])
-    c = array_ops.constant([3.])
-    d = array_ops.constant([4.])
-    custom = op_hint.OpHint("test_tag")
-    a = custom.add_input(a, tag="mytag",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b, = custom.add_inputs(b)
-    c = custom.add_input(c, tag="mytag",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    d = custom.add_input(d, tag="mytag2",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
-    custom.add_outputs([res])
-    with self.cached_session():
-      self.assertEqual(self._get_input_index(a), 0)
-      self.assertEqual(self._get_sort_index(a), 0)
-      self.assertEqual(self._get_input_index(b), 1)
-      self.assertEqual(self._get_sort_index(b), 0)
-      self.assertEqual(self._get_input_index(c), 0)
-      self.assertEqual(self._get_sort_index(c), 1)
+    with ops.Graph().as_default():
+      a = array_ops.constant([1.])
+      b = array_ops.constant([2.])
+      c = array_ops.constant([3.])
+      d = array_ops.constant([4.])
+      custom = op_hint.OpHint("test_tag")
+      a = custom.add_input(
+          a, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      b, = custom.add_inputs(b)
+      c = custom.add_input(
+          c, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      d = custom.add_input(
+          d, tag="mytag2", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
+      custom.add_outputs([res])
+      with self.cached_session():
+        self.assertEqual(self._get_input_index(a), 0)
+        self.assertEqual(self._get_sort_index(a), 0)
+        self.assertEqual(self._get_input_index(b), 1)
+        self.assertEqual(self._get_sort_index(b), 0)
+        self.assertEqual(self._get_input_index(c), 0)
+        self.assertEqual(self._get_sort_index(c), 1)
 
   def testOverrideIndex(self):
-    a = array_ops.constant([1.])
-    b = array_ops.constant([2.])
-    c = array_ops.constant([3.])
-    custom = op_hint.OpHint("test_override")
-    b = custom.add_input(b)  # should auto assign 0
-    a = custom.add_input(a, index_override=1)
-    c = custom.add_input(c)  # should auto assign 2
-    with self.cached_session():
-      self.assertEqual(self._get_input_index(a), 1)
-      self.assertEqual(self._get_input_index(b), 0)
-      self.assertEqual(self._get_input_index(c), 2)
+    with ops.Graph().as_default():
+      a = array_ops.constant([1.])
+      b = array_ops.constant([2.])
+      c = array_ops.constant([3.])
+      custom = op_hint.OpHint("test_override")
+      b = custom.add_input(b)  # should auto assign 0
+      a = custom.add_input(a, index_override=1)
+      c = custom.add_input(c)  # should auto assign 2
+      with self.cached_session():
+        self.assertEqual(self._get_input_index(a), 1)
+        self.assertEqual(self._get_input_index(b), 0)
+        self.assertEqual(self._get_input_index(c), 2)
 
   def testAggregate(self):
-    a = array_ops.constant([3., 4.])
-    b = array_ops.constant([5., 6.])
-    hint = op_hint.OpHint("agg")
-    a0, a1 = array_ops.unstack(a)
-    b0, b1 = array_ops.unstack(b)
+    with ops.Graph().as_default():
+      a = array_ops.constant([3., 4.])
+      b = array_ops.constant([5., 6.])
+      hint = op_hint.OpHint("agg")
+      a0, a1 = array_ops.unstack(a)
+      b0, b1 = array_ops.unstack(b)
 
-    a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
 
-    c0 = math_ops.add(a0, b0, name="addleft")
-    c1 = math_ops.add(a1, b1, name="addright")
-    c0 = hint.add_output(
-        c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    c1 = hint.add_output(
-        c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      c0 = math_ops.add(a0, b0, name="addleft")
+      c1 = math_ops.add(a1, b1, name="addright")
+      c0 = hint.add_output(
+          c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+      c1 = hint.add_output(
+          c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
 
-    curr = array_ops.stack([c0, c1])
-    output = array_ops.identity(curr, name="FINAL_OUTPUT")
-    with self.cached_session() as sess:
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
-      self.assertEqual(
-          self._getGraphOpTypes(
-              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
-          set(["agg", "Const", "Identity"]))
+      curr = array_ops.stack([c0, c1])
+      output = array_ops.identity(curr, name="FINAL_OUTPUT")
+      with self.cached_session() as sess:
+        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+            graph_def=sess.graph_def)
+        self.assertEqual(
+            self._getGraphOpTypes(
+                stubbed_graphdef,
+                output_nodes=[op_hint._tensor_name_base(output.name)]),
+            set(["agg", "Const", "Identity"]))
 
   def testFindHintedOutputNodes(self):
     """Test if all hinted output nodes are correctly found."""
+    with ops.Graph().as_default():
 
-    def _build_ophinted_op(name, input1, input2):
-      custom_op = op_hint.OpHint(name)
-      input1 = custom_op.add_input(input1)
-      input2 = custom_op.add_input(input2)
-      output = math_ops.mul(input1, input2)
-      return custom_op.add_output(output)
+      def _build_ophinted_op(name, input1, input2):
+        custom_op = op_hint.OpHint(name)
+        input1 = custom_op.add_input(input1)
+        input2 = custom_op.add_input(input2)
+        output = math_ops.mul(input1, input2)
+        return custom_op.add_output(output)
 
-    output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
-                                  array_ops.constant([2.]))
-    output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
-                                  array_ops.constant([4.]))
-    with self.cached_session() as sess:
-      hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
-      expected_hinted_output_nodes = [
-          _node_name(output_1.name),
-          _node_name(output_2.name)
-      ]
-      self.assertEqual(
-          len(hinted_outputs_nodes), len(expected_hinted_output_nodes))
+      output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
+                                    array_ops.constant([2.]))
+      output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
+                                    array_ops.constant([4.]))
+      with self.cached_session() as sess:
+        hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
+        expected_hinted_output_nodes = [
+            _node_name(output_1.name),
+            _node_name(output_2.name)
+        ]
+        self.assertEqual(
+            len(hinted_outputs_nodes), len(expected_hinted_output_nodes))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 56fe360..da3e551 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -26,8 +26,7 @@
 // automatically move <Python.h> before <locale>.
 #include <Python.h>
 
-struct _TfLiteDelegate;
-typedef struct _TfLiteDelegate TfLiteDelegate;
+struct TfLiteDelegate;
 
 // We forward declare TFLite classes here to avoid exposing them to SWIG.
 namespace tflite {
diff --git a/tensorflow/lite/python/lite_flex_test.py b/tensorflow/lite/python/lite_flex_test.py
index c1fc54b..a3294d8 100644
--- a/tensorflow/lite/python/lite_flex_test.py
+++ b/tensorflow/lite/python/lite_flex_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
@@ -31,14 +32,14 @@
 from tensorflow.python.training.tracking import tracking
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromSessionTest(test_util.TensorFlowTestCase):
 
   def testFlexMode(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -58,10 +59,11 @@
         str(error.exception))
 
   def testDeprecatedFlags(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
diff --git a/tensorflow/lite/python/lite_mlir_test.py b/tensorflow/lite/python/lite_mlir_test.py
index f234eaf..8cdb100 100644
--- a/tensorflow/lite/python/lite_mlir_test.py
+++ b/tensorflow/lite/python/lite_mlir_test.py
@@ -23,6 +23,7 @@
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.interpreter import Interpreter
+from tensorflow.python import keras
 from tensorflow.python.client import session
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -40,14 +41,14 @@
 from tensorflow.python.training.tracking import tracking
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromSessionTest(test_util.TensorFlowTestCase):
 
   def testFloat(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -74,13 +75,15 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testString(self):
-    in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
-    out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
+      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_enable_mlir_converter = True
     tflite_model = converter.convert()
 
     # Check values from converted model.
@@ -100,13 +103,14 @@
     self.assertTrue(([2, 2] == output_details[0]['shape']).all())
 
   def testQuantization(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess,
@@ -147,13 +151,15 @@
 
   def testScalarValid(self):
     # Construct a graph using a scalar (empty shape) input.
-    in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Test conversion with the scalar input shape.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_enable_mlir_converter = True
     tflite_model = converter.convert()
 
     # Check values from converted model.
@@ -182,27 +188,31 @@
     self.assertTrue((expected_output == output_data).all())
 
   def testPostTrainingQuantize(self):
+    self.skipTest('b/124315492')
     np.random.seed(0)
-    # We need the tensor to have more than 1024 elements for quantize_weights
-    # to kick in. Thus, the [33, 33] shape.
-    in_tensor_1 = array_ops.placeholder(
-        shape=[33, 33], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = constant_op.constant(
-        np.random.uniform(low=-10., high=10., size=(33, 33)),
-        shape=[33, 33],
-        dtype=dtypes.float32,
-        name='inputB')
-    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      # We need the tensor to have more than 1024 elements for quantize_weights
+      # to kick in. Thus, the [33, 33] shape.
+      in_tensor_1 = array_ops.placeholder(
+          shape=[33, 33], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = constant_op.constant(
+          np.random.uniform(low=-10., high=10., size=(33, 33)),
+          shape=[33, 33],
+          dtype=dtypes.float32,
+          name='inputB')
+      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
                                                         [out_tensor])
+    float_converter.experimental_enable_mlir_converter = True
     float_tflite = float_converter.convert()
 
     # Convert quantized weights model.
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [in_tensor_1], [out_tensor])
+    quantized_converter.experimental_enable_mlir_converter = True
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_tflite = quantized_converter.convert()
 
@@ -231,6 +241,7 @@
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [placeholder],
                                                   [output_node])
+    converter.experimental_enable_mlir_converter = True
     tflite_model = converter.convert()
 
     # Check values from converted model.
@@ -416,15 +427,39 @@
         expected = expected.c.numpy()
       np.testing.assert_almost_equal(expected, actual)
 
+  @test_util.run_v2_only
+  def testKerasLSTM(self):
+    self.skipTest('b/138657502')
+    input_data = constant_op.constant(
+        np.array(np.random.random_sample((10, 10, 10)), dtype=np.float32))
+
+    model = keras.models.Sequential(
+        [keras.layers.LSTM(units=10, input_shape=(10, 10))])
+
+    run_model = def_function.function(model.__call__)
+    concrete_func = run_model.get_concrete_function(
+        tensor_spec.TensorSpec((10, 10, 10), dtype=dtypes.float32))
+
+    # Convert model.
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_enable_mlir_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    expected_value = concrete_func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
+    for expected, actual in zip(expected_value, actual_value):
+      np.testing.assert_almost_equal(expected, actual)
+
 
 class TestFlexMode(test_util.TensorFlowTestCase):
 
-  @test_util.run_v1_only('Incompatible with 2.0.')
   def testSession(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index ae68022..d206083 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -101,14 +101,14 @@
     self.assertTrue(converter._has_valid_tensors())
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromSessionTest(TestModels, parameterized.TestCase):
 
   def testFloat(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -135,9 +135,10 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testString(self):
-    in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
-    out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
+      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -164,13 +165,14 @@
     # interpreter API after support has been added.
 
   def testQuantization(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess,
@@ -210,13 +212,14 @@
     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
 
   def testQuantizationInvalid(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess,
@@ -232,11 +235,12 @@
 
   def testIntermediateInputArray(self):
     """Convert a model from an intermediate input array."""
-    in_tensor_init = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    in_tensor_final = in_tensor_init + in_tensor_init
-    out_tensor = in_tensor_final + in_tensor_final
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_init = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      in_tensor_final = in_tensor_init + in_tensor_init
+      out_tensor = in_tensor_final + in_tensor_final
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
@@ -263,9 +267,10 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testSizeNoneInvalid(self):
-    in_tensor = array_ops.placeholder(dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Test None as shape.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -277,9 +282,10 @@
 
   def testScalarValid(self):
     # Construct a graph using a scalar (empty shape) input.
-    in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Test conversion with the scalar input shape.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -313,10 +319,11 @@
     self.assertTrue((expected_output == output_data).all())
 
   def testSizeInvalid(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, None, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, None, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Test invalid shape. None after 1st dimension.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -329,10 +336,11 @@
         str(error.exception))
 
   def testBatchSizeValid(self):
-    in_tensor = array_ops.placeholder(
-        shape=[None, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[None, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -359,13 +367,14 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testFreezeGraph(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    var = variable_scope.get_variable(
-        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + var
-    sess = session.Session()
-    sess.run(_global_variables_initializer())
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      var = variable_scope.get_variable(
+          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + var
+      sess = session.Session()
+      sess.run(_global_variables_initializer())
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -391,12 +400,12 @@
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
-  # TODO(nupurgarg): Verify value of contents in GraphViz.
   def testGraphviz(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -405,12 +414,12 @@
     graphviz_output = converter.convert()
     self.assertTrue(graphviz_output)
 
-  # TODO(nupurgarg): Verify value of contents in GraphViz.
   def testDumpGraphviz(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -441,10 +450,11 @@
     self.assertTrue(num_items_graphviz_video > num_items_graphviz)
 
   def testInferenceInputType(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -472,10 +482,11 @@
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
 
   def testDefaultRangesStats(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -505,15 +516,16 @@
     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
 
   def testPostTrainingQuantizeDeprecatedAttribute(self):
-    in_tensor_1 = array_ops.placeholder(
-        shape=[33, 33], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = constant_op.constant(
-        np.random.uniform(low=-10., high=10., size=(33, 33)),
-        shape=[33, 33],
-        dtype=dtypes.float32,
-        name='inputB')
-    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[33, 33], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = constant_op.constant(
+          np.random.uniform(low=-10., high=10., size=(33, 33)),
+          shape=[33, 33],
+          dtype=dtypes.float32,
+          name='inputB')
+      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+      sess = session.Session()
 
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [in_tensor_1], [out_tensor])
@@ -528,17 +540,18 @@
 
   def testPostTrainingQuantize(self):
     np.random.seed(0)
-    # We need the tensor to have more than 1024 elements for quantize_weights
-    # to kick in. Thus, the [33, 33] shape.
-    in_tensor_1 = array_ops.placeholder(
-        shape=[33, 33], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = constant_op.constant(
-        np.random.uniform(low=-10., high=10., size=(33, 33)),
-        shape=[33, 33],
-        dtype=dtypes.float32,
-        name='inputB')
-    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      # We need the tensor to have more than 1024 elements for quantize_weights
+      # to kick in. Thus, the [33, 33] shape.
+      in_tensor_1 = array_ops.placeholder(
+          shape=[33, 33], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = constant_op.constant(
+          np.random.uniform(low=-10., high=10., size=(33, 33)),
+          shape=[33, 33],
+          dtype=dtypes.float32,
+          name='inputB')
+      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
@@ -574,8 +587,9 @@
     return (inp, output, calibration_gen)
 
   def testPostTrainingCalibrateAndQuantize(self):
-    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
@@ -604,8 +618,9 @@
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
   def testCalibrateAndQuantizeBuiltinInt8(self):
-    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
@@ -648,8 +663,9 @@
   def testQuantizeFloat16(self, use_rep_data, include_int8,
                           is_float16_quantized, is_error,
                           is_post_training_quantized):
-    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
@@ -698,8 +714,9 @@
         raise ValueError('Invalid test options.')
 
   def testInvalidQuantizeFloat16(self):
-    inp, output, _ = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, _ = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Specify float16 quantization
     quantized_converter = lite.TFLiteConverter.from_session(
@@ -718,17 +735,18 @@
 
   def testInvalidPostTrainingQuantize(self):
     np.random.seed(0)
-    # We need the tensor to have more than 1024 elements for quantize_weights
-    # to kick in. Thus, the [33, 33] shape.
-    in_tensor_1 = array_ops.placeholder(
-        shape=[33, 33], dtype=dtypes.float32, name='inputA')
-    in_tensor_2 = constant_op.constant(
-        np.random.uniform(low=-10., high=10., size=(33, 33)),
-        shape=[33, 33],
-        dtype=dtypes.float32,
-        name='inputB')
-    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      # We need the tensor to have more than 1024 elements for quantize_weights
+      # to kick in. Thus, the [33, 33] shape.
+      in_tensor_1 = array_ops.placeholder(
+          shape=[33, 33], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = constant_op.constant(
+          np.random.uniform(low=-10., high=10., size=(33, 33)),
+          shape=[33, 33],
+          dtype=dtypes.float32,
+          name='inputB')
+      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+      sess = session.Session()
 
     # Attempt to convert to quantized weights model.
     quantized_converter = lite.TFLiteConverter.from_session(
@@ -744,8 +762,9 @@
         'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
 
   def testPostTrainingCalibrateAndQuantizeFloatNotAllowed(self):
-    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
@@ -768,8 +787,9 @@
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
   def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
-    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
-    sess = session.Session()
+    with ops.Graph().as_default():
+      inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+      sess = session.Session()
 
     # Convert float model.
     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
@@ -801,10 +821,11 @@
 
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
@@ -817,9 +838,11 @@
 
   def testMultipleOutputNodeNames(self):
     """Tests converting a graph with an op that have multiple outputs."""
-    input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
-    out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
-    sess = session.Session()
+    with ops.Graph().as_default():
+      input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
+      out0, out1, out2, out3 = array_ops.split(
+          input_tensor, [1, 1, 1, 1], axis=0)
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
@@ -888,10 +911,11 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testInferenceInputOutputTypeFloatDefault(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -916,11 +940,12 @@
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
 
   def testInferenceInputOutputTypeQuantizedUint8Default(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor + in_tensor, min=0., max=1., name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1., name='output')
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -947,11 +972,12 @@
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
 
   def testReusingConverterWithDifferentPostTrainingQuantization(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor + in_tensor, min=0., max=1., name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1., name='output')
+      sess = session.Session()
 
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
@@ -969,16 +995,18 @@
     # This is a regression test for the case where shape of dynamic output
     # tensors changes between invocations.
     # See also https://github.com/tensorflow/tensorflow/issues/26549
-    input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
-    input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
+      input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)
 
-    # The bug is triggered only when dynamic tensor is intermediate. Putting
-    # some other ops around it.
-    neg = math_ops.negative(input2_tensor)
-    padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
-    output_tensor = array_ops.pad(input_tensor, padding) + neg
+      # The bug is triggered only when dynamic tensor is intermediate. Putting
+      # some other ops around it.
+      neg = math_ops.negative(input2_tensor)
+      padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
+      output_tensor = array_ops.pad(input_tensor, padding) + neg
 
-    sess = session.Session()
+      sess = session.Session()
+
     converter = lite.TFLiteConverter.from_session(
         sess, [input_tensor, padding, input2_tensor], [output_tensor])
     tflite_model = converter.convert()
@@ -1025,14 +1053,14 @@
     self.assertIn((func + 'add'), converter._debug_info.traces)
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromFrozenGraphFile(test_util.TensorFlowTestCase):
 
   def testFloat(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
@@ -1064,10 +1092,11 @@
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testFloatWithShapesArray(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
@@ -1090,12 +1119,13 @@
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
 
   def testFreezeGraph(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    var = variable_scope.get_variable(
-        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + var
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      var = variable_scope.get_variable(
+          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + var
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
@@ -1110,10 +1140,11 @@
                      str(error.exception))
 
   def testPbtxt(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
@@ -1166,10 +1197,11 @@
         str(error.exception))
 
   def testFloatTocoConverter(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
@@ -1188,10 +1220,11 @@
 
   def testGraphDebugInfo(self):
     """Test a frozen graph doesn't have debug info captured."""
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     # Write graph to file.
     graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
@@ -1296,21 +1329,21 @@
         str(error.exception))
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromSavedModelTest(TestModels):
 
   def _createSavedModel(self, shape):
     """Create a simple SavedModel."""
     saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
-    with session.Session() as sess:
-      in_tensor_1 = array_ops.placeholder(
-          shape=shape, dtype=dtypes.float32, name='inputB')
-      in_tensor_2 = array_ops.placeholder(
-          shape=shape, dtype=dtypes.float32, name='inputA')
-      out_tensor = in_tensor_1 + in_tensor_2
-      inputs = {'x': in_tensor_1, 'y': in_tensor_2}
-      outputs = {'z': out_tensor}
-      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+    with ops.Graph().as_default():
+      with session.Session() as sess:
+        in_tensor_1 = array_ops.placeholder(
+            shape=shape, dtype=dtypes.float32, name='inputB')
+        in_tensor_2 = array_ops.placeholder(
+            shape=shape, dtype=dtypes.float32, name='inputA')
+        out_tensor = in_tensor_1 + in_tensor_2
+        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+        outputs = {'z': out_tensor}
+        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
     return saved_model_dir
 
   def testSimpleModel(self):
@@ -1465,7 +1498,6 @@
     return config
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class FromKerasFile(TestModels, parameterized.TestCase):
 
   def setUp(self):
@@ -1578,6 +1610,7 @@
 
   def testSequentialModelInputArray(self):
     """Test a Sequential tf.keras model testing input arrays argument."""
+    ops.disable_eager_execution()
     self._getSequentialModel()
 
     # Invalid input array raises error.
@@ -1622,6 +1655,7 @@
 
   def testSequentialModelOutputArray(self):
     """Test a Sequential tf.keras model testing output arrays argument."""
+    ops.disable_eager_execution()
     self._getSequentialModel()
 
     # Invalid output array raises error.
@@ -1747,12 +1781,10 @@
 
     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 2)
-    self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 4] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
-    self.assertEqual('dropout/Identity', output_details[1]['name'])
     self.assertEqual(np.float32, output_details[1]['dtype'])
     self.assertTrue(([1, 4] == output_details[1]['shape']).all())
     self.assertEqual((0., 0.), output_details[1]['quantization'])
@@ -1800,7 +1832,6 @@
 
     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 1)
-    self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
@@ -1839,17 +1870,18 @@
       self.assertValidDebugInfo(converter._debug_info)
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class GrapplerTest(TestModels):
 
   def testConstantFolding(self):
+    ops.disable_eager_execution()
     # Constant folding handles the tf.broadcast_to operation which was not
     # supported by the TFLite at the time this test was added.
-    in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
-    y_const = constant_op.constant([1., 2., 3.])
-    y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
-    out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
+      y_const = constant_op.constant([1., 2., 3.])
+      y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
+      out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
+      sess = session.Session()
 
     # Convert model.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 390f4a0..5aa212a 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -854,7 +854,7 @@
     if n in reachable_by_output:
       if n not in reachable_by_input and n not in output_nodes_set:
         # special handle for while loop function def.
-        if node.op == "While":
+        if node.op == "While" or node.op == "StatelessWhile":
           body_name = node.attr["body"].func.name
           inputs_outside_loop = node.input
           for function_def in graph_def.library.function:
diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py
index f13fad5..0c76db2 100644
--- a/tensorflow/lite/python/util_test.py
+++ b/tensorflow/lite/python/util_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.client import session
 from tensorflow.python.framework import convert_to_constants
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -32,7 +33,6 @@
 
 
 # TODO(nupurgarg): Add test for Grappler and frozen graph related functions.
-@test_util.run_v1_only("Incompatible with 2.0.")
 class UtilTest(test_util.TensorFlowTestCase):
 
   def testConvertDtype(self):
@@ -59,50 +59,53 @@
         util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
 
   def testTensorName(self):
-    in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
-    # out_tensors should have names: "split:0", "split:1", "split:2", "split:3".
-    out_tensors = array_ops.split(
-        value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
-    expect_names = ["split", "split:1", "split:2", "split:3"]
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
+      out_tensors = array_ops.split(
+          value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
 
+    expect_names = ["split", "split:1", "split:2", "split:3"]
     for i in range(len(expect_names)):
       got_name = util.get_tensor_name(out_tensors[i])
       self.assertEqual(got_name, expect_names[i])
 
   @test_util.enable_control_flow_v2
   def testRemoveLowerUsingSwitchMerge(self):
-    i = array_ops.placeholder(shape=(), dtype=dtypes.int32)
-    c = lambda i: math_ops.less(i, 10)
-    b = lambda i: math_ops.add(i, 1)
-    control_flow_ops.while_loop(c, b, [i])
-    sess = session.Session()
+    with ops.Graph().as_default():
+      i = array_ops.placeholder(shape=(), dtype=dtypes.int32)
+      c = lambda i: math_ops.less(i, 10)
+      b = lambda i: math_ops.add(i, 1)
+      control_flow_ops.while_loop(c, b, [i])
+      sess = session.Session()
+
     new_graph_def = convert_to_constants.disable_lower_using_switch_merge(
         sess.graph_def)
     lower_using_switch_merge_is_removed = False
     for node in new_graph_def.node:
-      if node.op == "While":
+      if node.op == "While" or node.op == "StatelessWhile":
         if not node.attr["_lower_using_switch_merge"].b:
           lower_using_switch_merge_is_removed = True
     self.assertEqual(lower_using_switch_merge_is_removed, True)
 
 
-@test_util.run_v1_only("Incompatible with 2.0.")
 class TensorFunctionsTest(test_util.TensorFlowTestCase):
 
   def testGetTensorsValid(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     tensors = util.get_tensors_from_tensor_names(sess.graph, ["Placeholder"])
     self.assertEqual("Placeholder:0", tensors[0].name)
 
   def testGetTensorsInvalid(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32)
-    _ = in_tensor + in_tensor
-    sess = session.Session()
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      _ = in_tensor + in_tensor
+      sess = session.Session()
 
     with self.assertRaises(ValueError) as error:
       util.get_tensors_from_tensor_names(sess.graph, ["invalid-input"])
@@ -110,14 +113,16 @@
                      str(error.exception))
 
   def testSetTensorShapeValid(self):
-    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
     util.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
     self.assertEqual([5, 3, 5], tensor.shape.as_list())
 
   def testSetTensorShapeNoneValid(self):
-    tensor = array_ops.placeholder(dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      tensor = array_ops.placeholder(dtype=dtypes.float32)
     self.assertEqual(None, tensor.shape)
 
     util.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
@@ -125,7 +130,8 @@
 
   def testSetTensorShapeArrayInvalid(self):
     # Tests set_tensor_shape where the tensor name passed in doesn't exist.
-    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
     with self.assertRaises(ValueError) as error:
@@ -138,7 +144,8 @@
   @test_util.run_deprecated_v1
   def testSetTensorShapeDimensionInvalid(self):
     # Tests set_tensor_shape where the shape passed in is incompatiable.
-    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
     with self.assertRaises(ValueError) as error:
@@ -148,7 +155,8 @@
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
   def testSetTensorShapeEmpty(self):
-    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    with ops.Graph().as_default():
+      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
     util.set_tensor_shapes([tensor], {})
diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD
index 814fa62..338369138 100644
--- a/tensorflow/lite/schema/BUILD
+++ b/tensorflow/lite/schema/BUILD
@@ -95,7 +95,7 @@
         "tflite_not_portable_ios",
     ],
     deps = [
-        "//tensorflow/core:lib_platform",
+        "//tensorflow/core/platform",
         "@com_google_googletest//:gtest",
         "@flatbuffers//:flatc_library",
     ],
diff --git a/tensorflow/lite/testdata/test_min_runtime.bin b/tensorflow/lite/testdata/test_min_runtime.bin
new file mode 100644
index 0000000..c681743
--- /dev/null
+++ b/tensorflow/lite/testdata/test_min_runtime.bin
Binary files differ
diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index e2eb79d..4f89fda 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -74,11 +74,21 @@
 )
 
 py_library(
-    name = "generate_examples_lib",
-    srcs = ["generate_examples_lib.py"],
+    name = "toco_convert",
+    srcs = ["toco_convert.py"],
     data = [
         "//tensorflow/lite/toco",
     ],
+    deps = [
+        ":generate_examples_lib",
+        "//tensorflow:tensorflow_py",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
+    name = "generate_examples_lib",
+    srcs = ["generate_examples_lib.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":generate_examples_report",
@@ -93,13 +103,11 @@
 py_binary(
     name = "generate_examples",
     srcs = ["generate_examples.py"],
-    data = [
-        "//tensorflow/lite/toco",
-    ],
     python_version = "PY2",
     srcs_version = "PY2AND3",
     deps = [
         ":generate_examples_lib",
+        ":toco_convert",
         "//tensorflow:tensorflow_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
@@ -190,6 +198,7 @@
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels:custom_ops",
         "//tensorflow/lite/kernels:reference_ops",
+        "//tensorflow/lite/tools/evaluation:utils",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -368,6 +377,7 @@
     deps = [
         ":split",
         ":tflite_diff_util",
+        ":tflite_driver",
     ] + select({
         "//conditions:default": [
             "//tensorflow/core:framework_internal",
diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py
index 5d8662d..1678e68 100644
--- a/tensorflow/lite/testing/generate_examples.py
+++ b/tensorflow/lite/testing/generate_examples.py
@@ -35,6 +35,7 @@
 import os
 import sys
 from tensorflow.lite.testing import generate_examples_lib
+from tensorflow.lite.testing import toco_convert
 
 # TODO(aselle): Disable GPU for now
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
@@ -95,6 +96,7 @@
   options.run_with_flex = FLAGS.run_with_flex
   options.make_edgetpu_tests = FLAGS.make_edgetpu_tests
   options.make_forward_compat_test = FLAGS.make_forward_compat_test
+  options.tflite_convert_function = toco_convert.toco_convert
 
   generate_examples_lib.generate_examples(options)
 
diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py
index 472caae..53e286f 100644
--- a/tensorflow/lite/testing/generate_examples_lib.py
+++ b/tensorflow/lite/testing/generate_examples_lib.py
@@ -37,7 +37,6 @@
 import random
 import re
 import string
-import tempfile
 import traceback
 import zipfile
 import numpy as np
@@ -70,8 +69,6 @@
     # TOCO doesn't support scalars as input.
     # Concat doesn't work with a single input tensor
     r"concat.*num_tensors=1": "67378344",
-    # Transposition in MatMul is not fully supported.
-    "fully_connected.*transpose_a=True": "67586970",
     # Softmax graphs are too complex.
     r"softmax.*dim=0": "67749831",
     # BatchToSpaceND only supports 4D tensors.
@@ -79,7 +76,7 @@
     # Div will use floordiv.
     r"div.*int32": "72051395",
     # Strided slice cannot handle new_axis_mask.
-    r"strided_slice.*new_axis_num=1|2": "137470173",
+    r"strided_slice.*spec=\[None": "137470173",
 }
 
 
@@ -106,9 +103,7 @@
     self.make_edgetpu_tests = False
     # The function to convert a TensorFLow model to TFLite model.
     # See the document for `toco_convert` function for its required signature.
-    # TODO(ycling): Decouple `toco_convert` function from this module, and
-    # remove the `toco` attribute in this class.
-    self.tflite_convert_function = toco_convert
+    self.tflite_convert_function = None
     # A map from regular expression to bug number. Any test failure with label
     # matching the expression will be considered due to the corresponding bug.
     self.known_bugs = KNOWN_BUGS
@@ -158,47 +153,6 @@
     self.inference_output_type = None
 
 
-def toco_options(data_types,
-                 input_arrays,
-                 output_arrays,
-                 shapes,
-                 extra_toco_options=ExtraTocoOptions()):
-  """Create TOCO options to process a model.
-
-  Args:
-    data_types: input and inference types used by TOCO.
-    input_arrays: names of the input tensors
-    output_arrays: name of the output tensors
-    shapes: shapes of the input tensors
-    extra_toco_options: additional toco options
-  Returns:
-    the options in a string.
-  """
-  shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x])
-  inference_type = "FLOAT"
-  # TODO(ahentz): if we get multi-input quantization to work we need this
-  # to change
-  if data_types[0] == "QUANTIZED_UINT8":
-    inference_type = "QUANTIZED_UINT8"
-  s = (" --input_data_types=%s" % ",".join(data_types) +
-       " --inference_type=%s" % inference_type +
-       " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
-       " --input_arrays=%s" % ",".join(input_arrays) +
-       " --output_arrays=%s" % ",".join(output_arrays))
-  if shape_str:
-    s += (" --input_shapes=%s" % shape_str)
-  if extra_toco_options.drop_control_dependency:
-    s += " --drop_control_dependency"
-  if extra_toco_options.allow_custom_ops:
-    s += " --allow_custom_ops"
-  if extra_toco_options.rnn_states:
-    s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
-  if extra_toco_options.split_tflite_lstm_inputs is not None:
-    if extra_toco_options.split_tflite_lstm_inputs:
-      s += " --split_tflite_lstm_inputs=true"
-    else:
-      s += " --split_tflite_lstm_inputs=false"
-  return s
 
 
 def format_result(t):
@@ -268,7 +222,7 @@
     fp.write("}\n")
 
 
-_TF_TYPE_INFO = {
+TF_TYPE_INFO = {
     tf.float32: (np.float32, "FLOAT"),
     tf.float16: (np.float16, "FLOAT"),
     tf.int32: (np.int32, "INT32"),
@@ -283,8 +237,8 @@
 def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
   """Build tensor data spreading the range [min_value, max_value)."""
 
-  if dtype in _TF_TYPE_INFO:
-    dtype = _TF_TYPE_INFO[dtype][0]
+  if dtype in TF_TYPE_INFO:
+    dtype = TF_TYPE_INFO[dtype][0]
 
   if dtype in (tf.float32, tf.float16):
     value = (max_value-min_value)*np.random.random_sample(shape)+min_value
@@ -303,8 +257,8 @@
 def create_scalar_data(dtype, min_value=-100, max_value=100):
   """Build scalar tensor data range from min_value to max_value exclusively."""
 
-  if dtype in _TF_TYPE_INFO:
-    dtype = _TF_TYPE_INFO[dtype][0]
+  if dtype in TF_TYPE_INFO:
+    dtype = TF_TYPE_INFO[dtype][0]
 
   if dtype in (tf.float32, tf.float16):
     value = (max_value - min_value) * np.random.random() + min_value
@@ -361,100 +315,6 @@
       expected_tf_failures=3)
 
 
-def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
-  """Convert a model's graph def into a tflite model.
-
-  NOTE: this currently shells out to the toco binary, but we would like
-  convert to Python API tooling in the future.
-
-  Args:
-    options: An Options instance.
-    graph_def: A GraphDef object.
-    input_tensors: List of input tensor tuples `(name, shape, type)`.
-    output_tensors: List of output tensors (names).
-    **kwargs: Extra options to be passed.
-
-  Returns:
-    output tflite model, log_txt from conversion
-    or None, log_txt if it did not convert properly.
-  """
-  # Convert ophint ops if presented.
-  graph_def = tf.lite.experimental.convert_op_hints_to_stubs(
-      graph_def=graph_def)
-  graph_def_str = graph_def.SerializeToString()
-
-  extra_toco_options = kwargs.get("extra_toco_options", ExtraTocoOptions())
-  test_params = kwargs.get("test_params", {})
-  input_arrays = [x[0] for x in input_tensors]
-  data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
-
-  if test_params.get("fully_quantize", False):
-    with tempfile.NamedTemporaryFile() as graphdef_file:
-      graphdef_file.write(graph_def_str)
-      graphdef_file.flush()
-
-      input_shapes = get_input_shapes_map(input_tensors)
-      converter = tf.lite.TocoConverter.from_frozen_graph(
-          graphdef_file.name, input_arrays, output_tensors, input_shapes)
-
-      def representative_dataset(input_tensors):
-        calibration_inputs = []
-        for _, shape, _ in input_tensors:
-          if shape:
-            dims = [dim.value for dim in shape.dims]
-            calibration_inputs.append(
-                np.random.uniform(-1, 1, tuple(dims)).astype(np.float32))
-        return calibration_inputs
-
-      def representative_dataset_gen():
-        for _ in range(100):
-          yield representative_dataset(input_tensors)
-
-      converter.target_spec.supported_ops = [
-          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
-      ]
-      converter.representative_dataset = representative_dataset_gen
-      if extra_toco_options.inference_input_type:
-        converter.inference_input_type = (
-            extra_toco_options.inference_input_type)
-      if extra_toco_options.inference_output_type:
-        converter.inference_output_type = (
-            extra_toco_options.inference_output_type)
-
-      try:
-        tflite_model = converter.convert()
-        return tflite_model, ""
-      except Exception as e:
-        log = "{0}\n{1}".format(str(e), traceback.format_exc())
-        return None, log
-
-  else:
-    opts = toco_options(
-        data_types=data_types,
-        input_arrays=input_arrays,
-        shapes=[x[1] for x in input_tensors],
-        output_arrays=output_tensors,
-        extra_toco_options=extra_toco_options)
-
-    with tempfile.NamedTemporaryFile() as graphdef_file, \
-         tempfile.NamedTemporaryFile() as output_file, \
-         tempfile.NamedTemporaryFile("w+") as stdout_file:
-      graphdef_file.write(graph_def_str)
-      graphdef_file.flush()
-
-      # TODO(aselle): Switch this to subprocess at some point.
-      if options.run_with_flex:
-        opts += " --enable_select_tf_ops --force_select_tf_ops"
-      cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
-             (bin_path, graphdef_file.name, output_file.name, opts,
-              stdout_file.name))
-      exit_code = os.system(cmd)
-      log = (
-          cmd + "exited with code %d" % exit_code + "\n------------------\n" +
-          stdout_file.read())
-      return (None if exit_code != 0 else output_file.read()), log
-
-
 def get_input_shapes_map(input_tensors):
   """Gets a map of input names to shapes.
 
@@ -1103,7 +963,7 @@
 
   def build_inputs(parameters, sess, inputs, outputs):
     dummy_input = np.zeros(
-        parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
+        parameters["input_shape"], dtype=TF_TYPE_INFO[parameters["dtype"]][0])
     return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input})
 
   make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@@ -2316,6 +2176,12 @@
       "transpose_a": [False],
       "transpose_b": [True],
       "constant_filter": [True, False],
+  }, {
+      "shape1": [[5, 3]],
+      "shape2": [[5, 3]],
+      "transpose_a": [True],
+      "transpose_b": [False],
+      "constant_filter": [True, False],
   }]
 
   def build_graph(parameters):
@@ -3169,7 +3035,7 @@
     """Build inputs for stride_slice test."""
     input_values = create_tensor_data(parameters["dtype"],
                                       parameters["input_shape"])
-    index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
+    index_type = TF_TYPE_INFO[parameters["index_type"]][0]
     values = [input_values]
     if not parameters["constant_indices"]:
       begin_values = np.array(parameters["begin"]).astype(index_type)
@@ -3297,12 +3163,37 @@
   test_parameters = [
       {
           "dtype": [tf.float32],
-          "new_axis_num": [0, 1, 2],
-          "shape": [[12, 7], [33]],
-          "stride": [1, 2, 3],
-          "use_begin_end_mask": [True, False],
-          # share between begin and end to avoid creating too many combinations.
-          "begin_end_offset": [0, 1, 3]
+          "shape": [[12, 7], [33, 1]],
+          "spec": [[slice(3, 7, 2), slice(None)],
+                   [tf.newaxis,
+                    slice(3, 7, 1), tf.newaxis,
+                    slice(None)], [slice(1, 5, 1), slice(None)]],
+      },
+      # 1-D case
+      {
+          "dtype": [tf.float32],
+          "shape": [[44]],
+          "spec": [[slice(3, 7, 2)], [tf.newaxis, slice(None)]],
+      },
+      # Shrink mask.
+      {
+          "dtype": [tf.float32],
+          "shape": [[21, 15, 7]],
+          "spec": [[slice(3, 7, 2), slice(None), 2]],
+      },
+      # Ellipsis.
+      {
+          "dtype": [tf.float32],
+          "shape": [[21, 15, 7]],
+          "spec": [[slice(3, 7, 2), Ellipsis]],
+      },
+      # All combinations.
+      {
+          "dtype": [tf.float32],
+          "shape": [[21, 15, 7]],
+          "spec": [[tf.newaxis,
+                    slice(3, 7, 2),
+                    slice(None), Ellipsis]],
       },
   ]
 
@@ -3315,38 +3206,12 @@
     Returns:
       strided_slice spec, e.g., [2:3, :] or [tf.newaxis, :, tf.newaxis].
     """
-    shape = parameters["shape"]
-    new_axis_num = parameters["new_axis_num"]
-    insert_new_axis_array = [False] * len(shape)
-    for _ in range(new_axis_num):
-      insert_loc = np.random.randint(0, len(insert_new_axis_array) + 1)
-      insert_new_axis_array.insert(insert_loc, True)
-    slice_spec = []
-    index = 0
-    for insert_new_axis in insert_new_axis_array:
-      if insert_new_axis:
-        slice_spec.append(tf.newaxis)
-      else:
-        # Random pop up begin/end/strides or just use ":"
-        if parameters["use_begin_end_mask"]:
-          # use slice(None), means use all values, equivalent of ":".
-          slice_spec.append(slice(None))
-        else:
-          # Begin.
-          begin = parameters["begin_end_offset"]
-          # End.
-          end = shape[index] - parameters["begin_end_offset"]
-          # Strides.
-          stride = parameters["stride"]
-          slice_spec.append(slice(begin, end, stride))
-        index += 1
-    return slice_spec
 
   def build_graph(parameters):
     """Build a simple graph with np style strided_slice."""
     input_value = tf.placeholder(
         dtype=parameters["dtype"], shape=parameters["shape"])
-    out = input_value.__getitem__(build_strided_slice_spec(parameters))
+    out = input_value.__getitem__(parameters["spec"])
     return [input_value], [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
@@ -4123,7 +3988,7 @@
     """Build inputs for slice test."""
     input_values = create_tensor_data(parameters["dtype"],
                                       parameters["input_shape"])
-    index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
+    index_type = TF_TYPE_INFO[parameters["index_type"]][0]
 
     begin_values = np.array(parameters["begin"]).astype(index_type)
     size_values = np.array(parameters["size"]).astype(index_type)
@@ -4800,7 +4665,7 @@
     return [input_tensor], [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
-    numpy_type = _TF_TYPE_INFO[parameters["dtype"]][0]
+    numpy_type = TF_TYPE_INFO[parameters["dtype"]][0]
     input_value = np.array([[1, 0], [2, 1]], numpy_type)
     return [input_value], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_value])))
@@ -5262,13 +5127,8 @@
   make_zip_of_tests(options, test_parameters, build_graph, build_inputs,
                     extra_toco_options)
 
-# Toco binary path provided by the generate rule.
-bin_path = None
-
 
 def generate_examples(options):
-  global bin_path
-
   def mkdir_if_not_exist(x):
     if not os.path.isdir(x):
       os.mkdir(x)
@@ -5279,7 +5139,6 @@
   mkdir_if_not_exist(opstest_path)
 
   out = options.zip_to_output
-  bin_path = options.toco
   # Some zip filenames contain a postfix identifying the conversion mode. The
   # list of valid conversion modes is defined in
   # generated_test_conversion_modes() in build_def.bzl.
diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index b293611..df77b94 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -106,6 +106,9 @@
 
     // Select kernel doesn't support broadcasting yet.
     {R"(^\/where.*1,2,3,1)", "134692786"},
+
+    // Strided slice doesn't support ellipsis.
+    {R"(strided_slice.*Ellipsis)", "138098220"},
 };
 
 // Additional list of tests that are expected to fail when
@@ -262,7 +265,9 @@
 
   std::ifstream tflite_stream(tflite_test_case);
   ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
-  tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi);
+  tflite::testing::TfLiteDriver test_driver(
+      FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
+                      : TfLiteDriver::DelegateType::kNone);
 
   if (test_path.find("fully_quantize=True") != std::string::npos) {
     // TODO(b/134594898): Tighten this constraint.
diff --git a/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc b/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc
index 34c1728..dbbaf16 100644
--- a/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc
+++ b/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc
@@ -19,10 +19,13 @@
   tflite::testing::kernel_test::TestOptions options =
       tflite::testing::kernel_test::ParseTfliteKernelTestFlags(&argc, argv);
   const bool run_reference_kernel = options.kernel_type == "REFERENCE";
-  const bool use_nnapi = options.kernel_type == "NNAPI";
+  const tflite::testing::TfLiteDriver::DelegateType delegate_type =
+      options.kernel_type == "NNAPI"
+          ? tflite::testing::TfLiteDriver::DelegateType::kNnapi
+          : tflite::testing::TfLiteDriver::DelegateType::kNone;
 
   auto runner = absl::make_unique<tflite::testing::TfLiteDriver>(
-      use_nnapi, "", run_reference_kernel);
+      delegate_type, run_reference_kernel);
   if (tflite::testing::kernel_test::RunKernelTest(options, runner.get()) ==
       kTfLiteOk) {
     return 0;
diff --git a/tensorflow/lite/testing/kernel_test/util_test.cc b/tensorflow/lite/testing/kernel_test/util_test.cc
index cbec660..0599ded 100644
--- a/tensorflow/lite/testing/kernel_test/util_test.cc
+++ b/tensorflow/lite/testing/kernel_test/util_test.cc
@@ -34,7 +34,8 @@
       "tensorflow/lite/testdata/test_input.csv";
   options.dump_output_to_file = FLAGS_test_tmpdir + "/test_out.csv";
   options.kernel_type = "REFERENCE";
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(false, "", true));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver(
+      TfLiteDriver::DelegateType::kNone, /*reference_kernel=*/true));
   RunKernelTest(options, runner.get());
   std::string expected = "3";
   for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) {
diff --git a/tensorflow/lite/testing/model_coverage/BUILD b/tensorflow/lite/testing/model_coverage/BUILD
index 7e6a659..39ed70c 100644
--- a/tensorflow/lite/testing/model_coverage/BUILD
+++ b/tensorflow/lite/testing/model_coverage/BUILD
@@ -4,7 +4,7 @@
 
 licenses(["notice"])  # Apache 2.0
 
-py_binary(
+py_library(
     name = "model_coverage_lib",
     srcs = ["model_coverage_lib.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
index d1309b7..328ac9e 100644
--- a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
+++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -38,7 +38,6 @@
 from tensorflow.python.training.training_util import write_graph
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class EvaluateFrozenGraph(test.TestCase):
 
   def _saveFrozenGraph(self, sess):
@@ -47,27 +46,29 @@
     return graph_def_file
 
   def testFloat(self):
-    with session.Session().as_default() as sess:
-      in_tensor = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32)
-      _ = in_tensor + in_tensor
-    filename = self._saveFrozenGraph(sess)
+    with ops.Graph().as_default():
+      with session.Session().as_default() as sess:
+        in_tensor = array_ops.placeholder(
+            shape=[1, 16, 16, 3], dtype=dtypes.float32)
+        _ = in_tensor + in_tensor
 
+    filename = self._saveFrozenGraph(sess)
     model_coverage.test_frozen_graph(filename, ['Placeholder'], ['add'])
 
   def testMultipleOutputs(self):
-    with session.Session().as_default() as sess:
-      in_tensor_1 = array_ops.placeholder(
-          shape=[1, 16], dtype=dtypes.float32, name='inputA')
-      in_tensor_2 = array_ops.placeholder(
-          shape=[1, 16], dtype=dtypes.float32, name='inputB')
+    with ops.Graph().as_default():
+      with session.Session().as_default() as sess:
+        in_tensor_1 = array_ops.placeholder(
+            shape=[1, 16], dtype=dtypes.float32, name='inputA')
+        in_tensor_2 = array_ops.placeholder(
+            shape=[1, 16], dtype=dtypes.float32, name='inputB')
 
-      weight = constant_op.constant(-1.0, shape=[16, 16])
-      bias = constant_op.constant(-1.0, shape=[16])
-      layer = math_ops.matmul(in_tensor_1, weight) + bias
-      _ = math_ops.reduce_mean(math_ops.square(layer - in_tensor_2))
+        weight = constant_op.constant(-1.0, shape=[16, 16])
+        bias = constant_op.constant(-1.0, shape=[16])
+        layer = math_ops.matmul(in_tensor_1, weight) + bias
+        _ = math_ops.reduce_mean(math_ops.square(layer - in_tensor_2))
+
     filename = self._saveFrozenGraph(sess)
-
     model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
                                      ['add', 'Mean'])
 
@@ -94,17 +95,18 @@
 
   def _getQuantizedModel(self):
     np.random.seed(0)
-    with session.Session().as_default() as sess:
-      # The tensor needs to have more than 1024 elements for quantize_weights to
-      # kick in. Thus, the [33, 33] shape.
-      in_tensor_1 = array_ops.placeholder(
-          shape=[33, 33], dtype=dtypes.float32, name='inputA')
-      in_tensor_2 = constant_op.constant(
-          np.random.uniform(low=-10., high=10., size=(33, 33)),
-          shape=[33, 33],
-          dtype=dtypes.float32,
-          name='inputB')
-      _ = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+    with ops.Graph().as_default():
+      with session.Session().as_default() as sess:
+        # The tensor needs to have more than 1024 elements for quantize_weights
+        # to kick in. Thus, the [33, 33] shape.
+        in_tensor_1 = array_ops.placeholder(
+            shape=[33, 33], dtype=dtypes.float32, name='inputA')
+        in_tensor_2 = constant_op.constant(
+            np.random.uniform(low=-10., high=10., size=(33, 33)),
+            shape=[33, 33],
+            dtype=dtypes.float32,
+            name='inputB')
+        _ = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
 
     filename = self._saveFrozenGraph(sess)
     return filename
@@ -125,25 +127,24 @@
         target_ops=set([lite.OpsSet.SELECT_TF_OPS]))
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class EvaluateSavedModel(test.TestCase):
 
   def testFloat(self):
     saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
-    with session.Session().as_default() as sess:
-      in_tensor_1 = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
-      in_tensor_2 = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
-      out_tensor = in_tensor_1 + in_tensor_2
+    with ops.Graph().as_default():
+      with session.Session().as_default() as sess:
+        in_tensor_1 = array_ops.placeholder(
+            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+        in_tensor_2 = array_ops.placeholder(
+            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+        out_tensor = in_tensor_1 + in_tensor_2
 
-      inputs = {'x': in_tensor_1, 'y': in_tensor_2}
-      outputs = {'z': out_tensor}
-      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+        outputs = {'z': out_tensor}
+        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
     model_coverage.test_saved_model(saved_model_dir)
 
 
-@test_util.run_v1_only('Incompatible with 2.0.')
 class EvaluateKerasModel(test.TestCase):
 
   def _getSingleInputKerasModel(self):
diff --git a/tensorflow/lite/testing/nnapi_example.cc b/tensorflow/lite/testing/nnapi_example.cc
index 309cb19..a847ffa 100644
--- a/tensorflow/lite/testing/nnapi_example.cc
+++ b/tensorflow/lite/testing/nnapi_example.cc
@@ -42,7 +42,9 @@
   }
 
   printf("Use nnapi is set to: %d\n", use_nnapi);
-  tflite::testing::TfLiteDriver test_driver(use_nnapi);
+  tflite::testing::TfLiteDriver test_driver(
+      use_nnapi ? tflite::testing::TfLiteDriver::DelegateType::kNnapi
+                : tflite::testing::TfLiteDriver::DelegateType::kNone);
 
   test_driver.SetModelBaseDir(dirname(examples_filename));
   if (!tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver)) {
diff --git a/tensorflow/lite/testing/tflite_diff_flags.h b/tensorflow/lite/testing/tflite_diff_flags.h
index 2fe068e..8b1205e 100644
--- a/tensorflow/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/lite/testing/tflite_diff_flags.h
@@ -17,9 +17,10 @@
 
 #include <cstring>
 
+#include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/testing/split.h"
 #include "tensorflow/lite/testing/tflite_diff_util.h"
-#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/lite/testing/tflite_driver.h"
 
 namespace tflite {
 namespace testing {
@@ -33,9 +34,10 @@
     string input_layer_shape;
     string output_layer;
     int32_t num_runs_per_pass = 100;
-    string delegate;
+    string delegate_name;
   } values;
 
+  std::string delegate_name;
   std::vector<tensorflow::Flag> flags = {
       tensorflow::Flag("tensorflow_model", &values.tensorflow_model,
                        "Path of tensorflow model."),
@@ -55,9 +57,9 @@
                        "output_1,output_2."),
       tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
                        "[optional] Number of full runs in each pass."),
-      tensorflow::Flag("delegate", &values.delegate,
+      tensorflow::Flag("delegate", &values.delegate_name,
                        "[optional] Delegate to use for executing ops. Must be "
-                       "`{\"\", FLEX}`"),
+                       "`{\"\", NNAPI, GPU, FLEX}`"),
   };
 
   bool no_inputs = *argc == 1;
@@ -70,9 +72,20 @@
              values.input_layer_shape.empty() || values.output_layer.empty()) {
     fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
     return {};
-  } else if (!(values.delegate == "" || values.delegate == "FLEX")) {
-    fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
-    return {};
+  }
+
+  TfLiteDriver::DelegateType delegate = TfLiteDriver::DelegateType::kNone;
+  if (!values.delegate_name.empty()) {
+    if (delegate_name == "NNAPI") {
+      delegate = TfLiteDriver::DelegateType::kNnapi;
+    } else if (values.delegate_name == "GPU") {
+      delegate = TfLiteDriver::DelegateType::kGpu;
+    } else if (values.delegate_name == "FLEX") {
+      delegate = TfLiteDriver::DelegateType::kFlex;
+    } else {
+      fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+      return {};
+    }
   }
 
   return {values.tensorflow_model,
@@ -82,7 +95,7 @@
           Split<string>(values.input_layer_shape, ":"),
           Split<string>(values.output_layer, ","),
           values.num_runs_per_pass,
-          values.delegate};
+          delegate};
 }
 
 }  // namespace testing
diff --git a/tensorflow/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc
index 0142ae4..721830a 100644
--- a/tensorflow/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/lite/testing/tflite_diff_util.cc
@@ -33,7 +33,7 @@
           options.input_layer_shape, options.output_layer)) {
     return false;
   }
-  TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate);
+  TfLiteDriver tflite_driver(options.delegate);
   tflite_driver.LoadModel(options.tflite_model);
   return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
 }
diff --git a/tensorflow/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h
index 3f9f108..091134f 100644
--- a/tensorflow/lite/testing/tflite_diff_util.h
+++ b/tensorflow/lite/testing/tflite_diff_util.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "tensorflow/lite/string.h"
+#include "tensorflow/lite/testing/tflite_driver.h"
 
 namespace tflite {
 namespace testing {
@@ -44,9 +45,8 @@
   // each of the passes. The first pass has a single inference, while the
   // second pass does multiple inferences back to back.
   int num_runs_per_pass;
-  // Path to the delegate library to be loaded in order to execute ops. Must be
-  // `{"", FLEX}`.
-  string delegate;
+  // The type of delegate to apply during inference.
+  TfLiteDriver::DelegateType delegate;
 };
 
 // Run a single TensorFLow Lite diff test with a given options.
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index 50981c5..cbed30f 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/lite/string_util.h"
 #include "tensorflow/lite/testing/join.h"
 #include "tensorflow/lite/testing/split.h"
+#include "tensorflow/lite/tools/evaluation/utils.h"
 
 namespace tflite {
 namespace testing {
@@ -259,9 +260,8 @@
   }
 }
 
-TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name,
-                           bool reference_kernel)
-    : use_nnapi_(use_nnapi),
+TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
+    : delegate_(nullptr, nullptr),
       relative_threshold_(kRelativeThreshold),
       absolute_threshold_(kAbsoluteThreshold) {
   if (reference_kernel) {
@@ -274,8 +274,21 @@
                                    tflite::ops::custom::Register_RFFT2D());
   }
 
-  if (delegate_name == "FLEX") {
-    delegate_ = FlexDelegate::Create();
+  switch (delegate_type) {
+    case DelegateType::kNone:
+      break;
+    case DelegateType::kNnapi:
+      delegate_ = evaluation::CreateNNAPIDelegate();
+      break;
+    case DelegateType::kGpu:
+      delegate_ = evaluation::CreateGPUDelegate(/*model=*/nullptr);
+      break;
+    case DelegateType::kFlex:
+      delegate_ = Interpreter::TfLiteDelegatePtr(
+          FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
+            delete static_cast<tflite::FlexDelegate*>(delegate);
+          });
+      break;
   }
 }
 
@@ -310,8 +323,6 @@
     Invalidate("Failed build interpreter");
     return;
   }
-  interpreter_->UseNNAPI(use_nnapi_);
-
   if (delegate_) {
     if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
       Invalidate("Unable to the build graph using the delegate");
diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h
index 8dd6459..a9bd92a 100644
--- a/tensorflow/lite/testing/tflite_driver.h
+++ b/tensorflow/lite/testing/tflite_driver.h
@@ -31,7 +31,19 @@
 // A test runner that feeds inputs into TF Lite and verifies its outputs.
 class TfLiteDriver : public TestRunner {
  public:
-  explicit TfLiteDriver(bool use_nnapi, const string& delegate = "",
+  enum class DelegateType {
+    kNone,
+    kNnapi,
+    kGpu,
+    kFlex,
+  };
+
+  /**
+   * Creates a new TfLiteDriver
+   * @param  delegate         The (optional) delegate to use.
+   * @param  reference_kernel Whether to use the builtin reference kernel ops.
+   */
+  explicit TfLiteDriver(DelegateType delegate_type = DelegateType::kNone,
                         bool reference_kernel = false);
   ~TfLiteDriver() override;
 
@@ -71,8 +83,7 @@
   class Expectation;
 
   std::unique_ptr<OpResolver> resolver_;
-  std::unique_ptr<FlexDelegate> delegate_;
-  bool use_nnapi_ = false;
+  Interpreter::TfLiteDelegatePtr delegate_;
   std::unique_ptr<FlatBufferModel> model_;
   std::unique_ptr<Interpreter> interpreter_;
   std::map<int, std::unique_ptr<Expectation>> expected_output_;
diff --git a/tensorflow/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc
index 93125c4..99efd2d 100644
--- a/tensorflow/lite/testing/tflite_driver_test.cc
+++ b/tensorflow/lite/testing/tflite_driver_test.cc
@@ -24,7 +24,7 @@
 using ::testing::ElementsAre;
 
 TEST(TfliteDriverTest, SimpleTest) {
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver());
 
   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/multi_add.bin");
@@ -60,7 +60,8 @@
 
 TEST(TfliteDriverTest, SingleAddOpTest) {
   std::unique_ptr<TestRunner> runner(new TfLiteDriver(
-      /*use_nnapi*/ false, /*delegate*/ "", /*reference_kernel*/ true));
+      /*delegate_type=*/TfLiteDriver::DelegateType::kNone,
+      /*reference_kernel=*/true));
 
   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/multi_add.bin");
@@ -95,7 +96,7 @@
 }
 
 TEST(TfliteDriverTest, AddQuantizedInt8Test) {
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver());
 
   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/add_quantized_int8.bin");
diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py
new file mode 100644
index 0000000..d14a989
--- /dev/null
+++ b/tensorflow/lite/testing/toco_convert.py
@@ -0,0 +1,170 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import numpy as np
+import tensorflow as tf
+import traceback
+
+from tensorflow.lite.testing import generate_examples_lib
+
+
+def toco_options(data_types,
+                 input_arrays,
+                 output_arrays,
+                 shapes,
+                 extra_toco_options=None):
+  """Create TOCO options to process a model.
+
+  Args:
+    data_types: input and inference types used by TOCO.
+    input_arrays: names of the input tensors
+    output_arrays: name of the output tensors
+    shapes: shapes of the input tensors
+    extra_toco_options: additional toco options
+
+  Returns:
+    the options in a string.
+  """
+  if extra_toco_options is None:
+    extra_toco_options = generate_examples_lib.ExtraTocoOptions()
+
+  shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x])
+  inference_type = "FLOAT"
+  # TODO(ahentz): if we get multi-input quantization to work we need this
+  # to change
+  if data_types[0] == "QUANTIZED_UINT8":
+    inference_type = "QUANTIZED_UINT8"
+  s = (" --input_data_types=%s" % ",".join(data_types) +
+       " --inference_type=%s" % inference_type +
+       " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
+       " --input_arrays=%s" % ",".join(input_arrays) +
+       " --output_arrays=%s" % ",".join(output_arrays))
+  if shape_str:
+    s += (" --input_shapes=%s" % shape_str)
+  if extra_toco_options.drop_control_dependency:
+    s += " --drop_control_dependency"
+  if extra_toco_options.allow_custom_ops:
+    s += " --allow_custom_ops"
+  if extra_toco_options.rnn_states:
+    s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
+  if extra_toco_options.split_tflite_lstm_inputs is not None:
+    if extra_toco_options.split_tflite_lstm_inputs:
+      s += " --split_tflite_lstm_inputs=true"
+    else:
+      s += " --split_tflite_lstm_inputs=false"
+  return s
+
+
+def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
+  """Convert a model's graph def into a tflite model.
+
+  NOTE: this currently shells out to the toco binary, but we would like
+  convert to Python API tooling in the future.
+
+  Args:
+    options: An Options instance.
+    graph_def: A GraphDef object.
+    input_tensors: List of input tensor tuples `(name, shape, type)`.
+    output_tensors: List of output tensors (names).
+    **kwargs: Extra options to be passed.
+
+  Returns:
+    output tflite model, log_txt from conversion
+    or None, log_txt if it did not convert properly.
+  """
+  # Convert ophint ops if presented.
+  graph_def = tf.lite.experimental.convert_op_hints_to_stubs(
+      graph_def=graph_def)
+  graph_def_str = graph_def.SerializeToString()
+
+  extra_toco_options = kwargs.get(
+      "extra_toco_options", generate_examples_lib.ExtraTocoOptions())
+  test_params = kwargs.get("test_params", {})
+  input_arrays = [x[0] for x in input_tensors]
+  data_types = [
+      generate_examples_lib.TF_TYPE_INFO[x[2]][1] for x in input_tensors]
+
+  if test_params.get("fully_quantize", False):
+    with tempfile.NamedTemporaryFile() as graphdef_file:
+      graphdef_file.write(graph_def_str)
+      graphdef_file.flush()
+
+      input_shapes = generate_examples_lib.get_input_shapes_map(input_tensors)
+      converter = tf.lite.TocoConverter.from_frozen_graph(
+          graphdef_file.name, input_arrays, output_tensors, input_shapes)
+
+      def representative_dataset(input_tensors):
+        calibration_inputs = []
+        for _, shape, _ in input_tensors:
+          if shape:
+            dims = [dim.value for dim in shape.dims]
+            calibration_inputs.append(
+                np.random.uniform(-1, 1, tuple(dims)).astype(np.float32))
+        return calibration_inputs
+
+      def representative_dataset_gen():
+        for _ in range(100):
+          yield representative_dataset(input_tensors)
+
+      converter.target_spec.supported_ops = [
+          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
+      ]
+      converter.representative_dataset = representative_dataset_gen
+      if extra_toco_options.inference_input_type:
+        converter.inference_input_type = (
+            extra_toco_options.inference_input_type)
+      if extra_toco_options.inference_output_type:
+        converter.inference_output_type = (
+            extra_toco_options.inference_output_type)
+
+      try:
+        tflite_model = converter.convert()
+        return tflite_model, ""
+      except Exception as e:
+        log = "{0}\n{1}".format(str(e), traceback.format_exc())
+        return None, log
+
+  else:
+    opts = toco_options(
+        data_types=data_types,
+        input_arrays=input_arrays,
+        shapes=[x[1] for x in input_tensors],
+        output_arrays=output_tensors,
+        extra_toco_options=extra_toco_options)
+
+    with tempfile.NamedTemporaryFile() as graphdef_file, \
+         tempfile.NamedTemporaryFile() as output_file, \
+         tempfile.NamedTemporaryFile("w+") as stdout_file:
+      graphdef_file.write(graph_def_str)
+      graphdef_file.flush()
+
+      # TODO(aselle): Switch this to subprocess at some point.
+      if options.run_with_flex:
+        opts += " --enable_select_tf_ops --force_select_tf_ops"
+      cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
+             (options.toco, graphdef_file.name, output_file.name, opts,
+              stdout_file.name))
+      exit_code = os.system(cmd)
+      log = (
+          cmd + "exited with code %d" % exit_code + "\n------------------\n" +
+          stdout_file.read())
+      return (None if exit_code != 0 else output_file.read()), log
diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD
index 43714fc..32a86b4 100644
--- a/tensorflow/lite/toco/BUILD
+++ b/tensorflow/lite/toco/BUILD
@@ -1,5 +1,5 @@
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_cc",
     "tf_proto_library_py",
 )
@@ -222,6 +222,7 @@
         "graph_transformations/quantize.cc",
         "graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
         "graph_transformations/remove_final_dequantize_op.cc",
+        "graph_transformations/remove_successive_transpose.cc",
         "graph_transformations/remove_tensorflow_assert.cc",
         "graph_transformations/remove_tensorflow_identity.cc",
         "graph_transformations/remove_trivial_binary.cc",
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
index c53e070..8e05312 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
@@ -159,6 +159,7 @@
 DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
 DECLARE_GRAPH_TRANSFORMATION(Quantize)
 DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
+DECLARE_GRAPH_TRANSFORMATION(RemoveSuccesiveTranspose)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc
index 8e951e0..399f953 100644
--- a/tensorflow/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/quantize.cc
@@ -31,12 +31,17 @@
 
 namespace {
 
-bool SupportsQuantization(const Operator& op) {
+bool SupportsQuantization(Model* model, const Operator& op) {
   auto type = op.type;
   if (type == OperatorType::kUnsupported) {
     auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
     return unsupported->quantized;
   }
+  if (op.type == OperatorType::kRange) {
+    const auto& array = model->GetArray(op.outputs[0]);
+    return (array.data_type != ArrayDataType::kFloat &&
+            array.data_type != ArrayDataType::kFloat16);
+  }
   return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
          type == OperatorType::kFullyConnected ||
          type == OperatorType::kConcatenation ||
@@ -494,7 +499,7 @@
           << "Input array " << input << " is missing quantization_params";
     }
   }
-  if (!SupportsQuantization(op)) {
+  if (!SupportsQuantization(model, op)) {
     return tensorflow::errors::InvalidArgument(
         "Unimplemented: this graph contains an operator of type ",
         HelpfulOperatorTypeName(op),
diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
new file mode 100644
index 0000000..1f0fdf8
--- /dev/null
+++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
@@ -0,0 +1,95 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+bool TransformsToIdentity(std::vector<int> const& perm1,
+                          std::vector<int> const& perm2) {
+  if (perm2.size() != perm1.size() || perm1.empty()) {
+    return false;
+  }
+  // perm1 is the order of the indices after first transpose. When perm1 is
+  // reordered according to perm2, if the result is simple increasing sequence
+  // i.e., range(0, perm1.size()), then the two transposes cancel each other.
+  for (int i = 0; i < perm1.size(); ++i) {
+    if (perm1[i] < 0 || perm1[i] >= perm1.size() || perm2[i] < 0 ||
+        perm2[i] >= perm1.size()) {
+      return false;
+    }
+    if (perm1[perm2[i]] != i) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void ReplaceOpInputsWith(Model* model, const string& lookfor,
+                         const string& replacewith) {
+  for (const auto& op : model->operators) {
+    for (int i = 0; i < op->inputs.size(); ++i) {
+      if (op->inputs[i] == lookfor) {
+        op->inputs[i] = replacewith;
+      }
+    }
+  }
+}
+
+}  // namespace
+
+::tensorflow::Status RemoveSuccesiveTranspose::Run(Model* model,
+                                                   std::size_t op_index,
+                                                   bool* modified) {
+  *modified = false;
+  auto op = model->operators.begin() + op_index;
+  if (op->get()->type != OperatorType::kTranspose) {
+    return ::tensorflow::Status::OK();
+  }
+
+  TransposeOperator* t_op = static_cast<TransposeOperator*>(op->get());
+  if (CountOpsWithInput(*model, t_op->outputs[0]) != 1) {
+    return ::tensorflow::Status::OK();
+  }
+  Operator* next = GetOpWithInput(*model, t_op->outputs[0]);
+  if (!next || next->type != OperatorType::kTranspose) {
+    return ::tensorflow::Status::OK();
+  }
+
+  TransposeOperator* t_next = static_cast<TransposeOperator*>(next);
+  if (!CountOpsWithInput(*model, t_next->outputs[0])) {
+    return ::tensorflow::Status::OK();
+  }
+
+  if (TransformsToIdentity(t_op->perm, t_next->perm)) {
+    // Find the input tensor that uses the results of transpose t_next, then
+    // make it point to the input of t_op, effectively isolating both the
+    // transposes from the graph.
+    ReplaceOpInputsWith(model, t_next->outputs[0], t_op->inputs[0]);
+    DeleteOpAndArrays(model, t_next);
+    DeleteOpAndArrays(model, t_op);
+    *modified = true;
+  }
+
+  return ::tensorflow::Status::OK();
+}
+
+}  // namespace toco
diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 292c601..bd529bd 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -57,6 +57,12 @@
       to_array.final_data_type == ArrayDataType::kNone) {
     to_array.final_data_type = from_array.final_data_type;
   }
+  // The 'from' array may now be unused. We delete it here immediately
+  // so that this function doesn't violate graph invariants (no unused arrays)
+  // and as it's not trivial to get this right for the caller since
+  // DeleteOpAndArrays will no longer delete this array, since it's no longer
+  // referenced by this op.
+  DeleteArrayIfUnused(from, model);
 }
 
 }  // namespace
diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc
index 637579e..03d2f05 100644
--- a/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -165,7 +165,7 @@
   }
 }
 
-void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+bool EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                             const Operator* binary_op) {
   const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
   const auto output_data_type =
@@ -175,7 +175,7 @@
       output_data_type == OutputDataType) {                                 \
     EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
         model, binary_op);                                                  \
-    return;                                                                 \
+    return true;                                                            \
   }
   TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
   TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
@@ -183,8 +183,7 @@
   TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
   TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
   TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
-  LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
-             << "binary operator for these data types.";
+  return false;
 #undef TOCO_HANDLE_CASE
 }
 }  // namespace
@@ -245,7 +244,9 @@
       << static_cast<int>(input1_array.data_type) << ").";
 
   // Do the actual constants propagation
-  EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
+  if (!EvaluateBinaryOperatorOnConstantInputs(model, binary_op)) {
+    return ::tensorflow::Status::OK();
+  }
 
   DeleteOpAndArrays(model, binary_op);
   *modified = true;
diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index 1aa30bcf1..ac95d60 100644
--- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -66,25 +66,69 @@
   const auto* matmul_op =
       static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
 
-  // Handling transposition of the first input here isn't very simple because
-  // we need to know the actual shape in order to produce a proper
-  // TransposeOperator.  However, the second input is supposed to be 2D, so we
-  // can actually handle transposition of that matrix, which happens to be more
-  // common anyway.
+  auto refresh_matmul_iterator = [&model, &matmul_it, &matmul_op]() {
+    matmul_it = std::find_if(model->operators.begin(), model->operators.end(),
+                             [matmul_op](const std::unique_ptr<Operator>& op) {
+                               return op.get() == matmul_op;
+                             });
+    DCHECK_EQ(matmul_it->get(), matmul_op);
+  };
+
+  string input_lhs = matmul_op->inputs[0];
+  string input_rhs = matmul_op->inputs[1];
+
+  // Handle `transpose_a` with best effort: If the dimension of lhs is known,
+  // insert a `Transpose` op.
   if (matmul_op->transpose_a) {
-    AddMessageF(
-        "Not replacing %s by a FullyConnected operator, because it has "
-        "the transpose_a attribute",
-        LogName(*matmul_op));
-    return ::tensorflow::Status::OK();
+    Array& lhs_array = model->GetArray(input_lhs);
+    if (!lhs_array.has_shape()) {
+      AddMessageF(
+          "Not replacing %s by a FullyConnected operator, because it has "
+          "the transpose_a attribute and LHS has no shape",
+          LogName(*matmul_op));
+      return ::tensorflow::Status::OK();
+    }
+
+    int dimensions_count = lhs_array.shape().dimensions_count();
+    if (dimensions_count < 2) {
+      return ::tensorflow::errors::InvalidArgument(
+          "Inputs of MatMul should have dimension >= 2. Got %d dimensions",
+          dimensions_count);
+    }
+
+    // Create a permutation vector to exchange the last 2 dimensions.
+    // E.g. For 4D, create [0, 1, 3, 2].
+    std::vector<int> perm;
+    perm.reserve(dimensions_count);
+    for (int i = 0; i < dimensions_count; ++i) {
+      perm.push_back(i);
+    }
+    std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
+
+    auto* transpose_op = new TransposeOperator;
+    transpose_op->inputs = {
+        input_lhs,
+        CreateInt32Array(
+            model, AvailableArrayName(*model, input_lhs + "/transpose/perm"),
+            perm)};
+    transpose_op->outputs = {
+        AvailableArrayName(*model, input_lhs + "/transpose")};
+    model->GetOrCreateArray(transpose_op->outputs[0]);
+    model->operators.emplace(matmul_it, transpose_op);
+    // Sanity check
+    DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_lhs));
+    input_lhs = transpose_op->outputs[0];
+
+    refresh_matmul_iterator();
   }
 
+  // TODO(b/138662017): The following code assumes that RHS is 2D. This isn't
+  // always true in TensorFlow.
+  //
   // Reorder the axes on the second input. TensorFlow uses row-major ordering
   // on both inputs, however this is inefficient for the FullyConnected
   // operator. We'll transpose the second input to be in column-major order now
   // and let constant propagation optimize things (if possible).
-  string input_lhs = matmul_op->inputs[0];
-  string input_rhs = matmul_op->inputs[1];
   if (!matmul_op->transpose_b) {
     // Need to transpose input_rhs, by inserting a TransposeOperator.
     // First, check if there already is a TransposeOperator transposing that
@@ -108,6 +152,7 @@
       model->operators.emplace(matmul_it, transpose_op);
       // Sanity check
       DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+      refresh_matmul_iterator();
     } else {
       AddMessageF(
           "While replacing %s by a FullyConnected operator, reused existing "
@@ -118,15 +163,6 @@
     input_rhs = transpose_op->outputs[0];
   }
 
-  // Refresh iterator.
-  matmul_it = model->operators.begin();
-  for (; matmul_it != model->operators.end(); ++matmul_it) {
-    if (matmul_it->get() == matmul_op) {
-      break;
-    }
-  }
-  DCHECK_EQ(matmul_it->get(), matmul_op);
-
   // Construct the new FullyConnectedOperator.
   auto* fc_op = new FullyConnectedOperator;
   fc_op->inputs = {input_lhs, input_rhs};
@@ -181,14 +217,7 @@
     }
 
     // We may have just invalidated matmul_it, so let's refresh it now.
-    matmul_it = model->operators.begin();
-    for (; matmul_it != model->operators.end(); ++matmul_it) {
-      if (matmul_it->get() == matmul_op) {
-        break;
-      }
-    }
-    CHECK(matmul_it != model->operators.end());
-    CHECK(matmul_it->get() == matmul_op);
+    refresh_matmul_iterator();
   } else {
     AddMessageF("Replacing %s by a FullyConnected operator",
                 LogName(*matmul_op));
diff --git a/tensorflow/lite/toco/graph_transformations/tests/BUILD b/tensorflow/lite/toco/graph_transformations/tests/BUILD
index 4992429..099c083 100644
--- a/tensorflow/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/lite/toco/graph_transformations/tests/BUILD
@@ -32,6 +32,17 @@
 )
 
 tf_cc_test(
+    name = "remove_successive_transpose_test",
+    srcs = ["remove_successive_transpose_test.cc"],
+    deps = [
+        "//tensorflow/lite/toco:graph_transformations",
+        "//tensorflow/lite/toco:model",
+        "//tensorflow/lite/toco:tooling_util",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+tf_cc_test(
     name = "resolve_constant_concatenation_test",
     srcs = ["resolve_constant_concatenation_test.cc"],
     deps = [
diff --git a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc
new file mode 100644
index 0000000..a5a0afb
--- /dev/null
+++ b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc
@@ -0,0 +1,147 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/tooling_util.h"
+
+namespace {
+
+using ::testing::Test;
+
+class RemoveSuccessiveTransposeTest : public Test {
+ protected:
+  RemoveSuccessiveTransposeTest() {}
+
+  void SetUp() override { model_.reset(new toco::Model); }
+
+  void CreateArray(const std::string& name, const std::vector<int>& shape) {
+    toco::Array& array = model_->GetOrCreateArray(name);
+    array.data_type = toco::ArrayDataType::kFloat;
+    toco::Shape* array_shape = array.mutable_shape();
+    *(array_shape->mutable_dims()) = shape;
+  }
+
+  void CreateConstantArray(const std::string& name,
+                           const std::vector<int>& shape,
+                           const std::vector<float>& data) {
+    CreateArray(name, shape);
+    toco::Array& array = model_->GetOrCreateArray(name);
+    auto& array_buffer = array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
+    int bufsize = 1;
+    for (int dim : shape) {
+      bufsize *= dim;
+    }
+    array_buffer.data.resize(bufsize);
+    float* buf_ptr = array_buffer.data.data();
+    for (int i = 0; i < bufsize; ++i) {
+      buf_ptr[i] = data[i];
+    }
+  }
+
+  void CreateGraph(const std::vector<int>& perm1,
+                   const std::vector<int>& perm2) {
+    CreateArray("InputA", {2, 2});
+    CreateArray("InputB", {2, 2});
+    CreateArray("Input", {2, 2});
+    CreateArray("InputTranspose", {2, 2});
+    CreateArray("InputTransposeTranspose", {2, 2});
+    CreateArray("InputTransposeTransposePlusB", {2, 2});
+
+    auto* add_op = new toco::AddOperator;
+    add_op->inputs = {"InputA", "InputB"};
+    add_op->outputs = {"Input"};
+    model_->operators.push_back(std::unique_ptr<toco::Operator>(add_op));
+
+    auto* transpose_op = new toco::TransposeOperator;
+    transpose_op->inputs = {"Input"};
+    transpose_op->perm = perm1;
+    transpose_op->outputs = {"InputTranspose"};
+    model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose_op));
+
+    auto* transpose2_op = new toco::TransposeOperator;
+    transpose2_op->inputs = {"InputTranspose"};
+    transpose2_op->perm = perm2;
+    transpose2_op->outputs = {"InputTransposeTranspose"};
+    model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose2_op));
+
+    auto* add2_op = new toco::AddOperator;
+    add2_op->inputs = {"InputTransposeTranspose", "InputB"};
+    add2_op->outputs = {"InputTransposeTransposePlusB"};
+    model_->operators.push_back(std::unique_ptr<toco::Operator>(add2_op));
+  }
+
+  std::unique_ptr<toco::Model> model_;
+};
+
+TEST_F(RemoveSuccessiveTransposeTest, RemoveTranspose) {
+  // Creating a model.
+  CreateGraph({1, 0}, {1, 0});
+
+  toco::RemoveSuccesiveTranspose transformation;
+  bool modified;
+  ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
+  EXPECT_TRUE(modified);
+
+  ASSERT_EQ(model_->operators.size(), 2);
+  ASSERT_EQ(model_->operators[0]->type, toco::OperatorType::kAdd);
+  ASSERT_EQ(model_->operators[1]->type, toco::OperatorType::kAdd);
+  ASSERT_EQ(model_->operators[1]->inputs[0], model_->operators[0]->outputs[0]);
+}
+
+TEST_F(RemoveSuccessiveTransposeTest, DontRemoveNotIdentityTranspose) {
+  // Creating a model.
+  CreateGraph({0, 2, 1}, {1, 0, 2});
+
+  toco::RemoveSuccesiveTranspose transformation;
+  bool modified;
+  ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
+  EXPECT_FALSE(modified);
+}
+
+TEST_F(RemoveSuccessiveTransposeTest, DontRemoveTransposeOutputUnused) {
+  CreateArray("InputA", {2, 2});
+  CreateArray("InputB", {2, 2});
+  CreateArray("Input", {2, 2});
+  CreateArray("InputTranspose", {2, 2});
+  CreateArray("InputTransposeTranspose", {2, 2});
+
+  auto* add_op = new toco::AddOperator;
+  add_op->inputs = {"InputA", "InputB"};
+  add_op->outputs = {"Input"};
+  model_->operators.push_back(std::unique_ptr<toco::Operator>(add_op));
+
+  auto* transpose_op = new toco::TransposeOperator;
+  transpose_op->inputs = {"Input"};
+  transpose_op->perm = {0, 2, 1};
+  transpose_op->outputs = {"InputTranspose"};
+  model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose_op));
+
+  auto* transpose2_op = new toco::TransposeOperator;
+  transpose2_op->inputs = {"InputTranspose"};
+  transpose2_op->perm = {0, 2, 1};
+  transpose2_op->outputs = {"InputTransposeTranspose"};
+  model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose2_op));
+
+  toco::RemoveSuccesiveTranspose transformation;
+  bool modified;
+  ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
+  EXPECT_FALSE(modified);
+}
+}  // namespace
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index 17c7d71..0921189 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -818,9 +818,6 @@
     reorder->output_axes_order = AxesOrder::kOHWI;
     model->operators.emplace_back(reorder);
   }
-  auto* conv = new ConvOperator;
-  conv->inputs = {input_name, reordered_weights_name};
-  conv->outputs = {node.name()};
   if (!HasAttr(node, "strides")) {
     return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
   }
@@ -828,8 +825,8 @@
   TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
   TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
   TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
-  conv->stride_height = strides.i(1);
-  conv->stride_width = strides.i(2);
+  int dilation_height_factor;
+  int dilation_width_factor;
   if (HasAttr(node, "dilations")) {
     const auto& dilations = GetListAttr(node, "dilations");
     TF_RETURN_IF_ERROR(
@@ -841,21 +838,30 @@
           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
     }
-    conv->dilation_height_factor = dilations.i(1);
-    conv->dilation_width_factor = dilations.i(2);
+    dilation_height_factor = dilations.i(1);
+    dilation_width_factor = dilations.i(2);
   } else {
-    conv->dilation_height_factor = 1;
-    conv->dilation_width_factor = 1;
+    dilation_height_factor = 1;
+    dilation_width_factor = 1;
   }
   const auto& padding = GetStringAttr(node, "padding");
+  PaddingType padding_type;
   if (padding == "SAME") {
-    conv->padding.type = PaddingType::kSame;
+    padding_type = PaddingType::kSame;
   } else if (padding == "VALID") {
-    conv->padding.type = PaddingType::kValid;
+    padding_type = PaddingType::kValid;
   } else {
     return tensorflow::errors::InvalidArgument(
         "Bad padding (only SAME and VALID are supported)");
   }
+  auto* conv = new ConvOperator;
+  conv->inputs = {input_name, reordered_weights_name};
+  conv->outputs = {node.name()};
+  conv->stride_height = strides.i(1);
+  conv->stride_width = strides.i(2);
+  conv->dilation_height_factor = dilation_height_factor;
+  conv->dilation_width_factor = dilation_width_factor;
+  conv->padding.type = padding_type;
   model->operators.emplace_back(conv);
 
   return tensorflow::Status::OK();
@@ -894,15 +900,12 @@
     reorder->output_axes_order = AxesOrder::k1HWO;
     model->operators.emplace_back(reorder);
   }
-  auto* conv = new DepthwiseConvOperator;
-  conv->inputs = {input_name, reordered_weights_name};
-  conv->outputs = {node.name()};
   const auto& strides = GetListAttr(node, "strides");
-  CHECK_EQ(strides.i_size(), 4);
-  CHECK_EQ(strides.i(0), 1);
-  CHECK_EQ(strides.i(3), 1);
-  conv->stride_height = strides.i(1);
-  conv->stride_width = strides.i(2);
+  TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
+  TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
+  TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
+  int dilation_height_factor;
+  int dilation_width_factor;
   if (HasAttr(node, "dilations")) {
     const auto& dilations = GetListAttr(node, "dilations");
     TF_RETURN_IF_ERROR(
@@ -914,20 +917,30 @@
           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
     }
-    conv->dilation_height_factor = dilations.i(1);
-    conv->dilation_width_factor = dilations.i(2);
+    dilation_height_factor = dilations.i(1);
+    dilation_width_factor = dilations.i(2);
   } else {
-    conv->dilation_height_factor = 1;
-    conv->dilation_width_factor = 1;
+    dilation_height_factor = 1;
+    dilation_width_factor = 1;
   }
   const auto& padding = GetStringAttr(node, "padding");
+  PaddingType padding_type;
   if (padding == "SAME") {
-    conv->padding.type = PaddingType::kSame;
+    padding_type = PaddingType::kSame;
   } else if (padding == "VALID") {
-    conv->padding.type = PaddingType::kValid;
+    padding_type = PaddingType::kValid;
   } else {
-    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+    return tensorflow::errors::InvalidArgument(
+        "Bad padding (only SAME and VALID are supported)");
   }
+  auto* conv = new DepthwiseConvOperator;
+  conv->inputs = {input_name, reordered_weights_name};
+  conv->outputs = {node.name()};
+  conv->stride_height = strides.i(1);
+  conv->stride_width = strides.i(2);
+  conv->dilation_height_factor = dilation_height_factor;
+  conv->dilation_width_factor = dilation_width_factor;
+  conv->padding.type = padding_type;
   model->operators.emplace_back(conv);
   return tensorflow::Status::OK();
 }
@@ -1914,7 +1927,7 @@
         << "Dilation unsupported in TransposeConv. TensorFlow op \""
         << node.name() << "\" had dilations";
     CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
-          (dilations.i(1) == 1) && (dilations.i(3) == 1))
+          (dilations.i(2) == 1) && (dilations.i(3) == 1))
         << "Dilation unsupported in TransposeConv. TensorFlow op \""
         << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
         << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
@@ -2369,12 +2382,13 @@
     const ModelFlags& model_flags, Model* model) {
   DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
 
-  auto* op = new UnidirectionalSequenceLstmOperator();
   const auto& indices = GetListAttr(node, "_tflite_input_indices");
   if (indices.i_size() != node.input().size()) {
     return tensorflow::errors::InvalidArgument("Input size does not match.");
   }
 
+  auto* op = new UnidirectionalSequenceLstmOperator();
+
   // The input size needs to be the same as the TfLite UniDirectionalSequence
   // Lstm implementation.
   const int kInputsSize = 20;
@@ -2424,12 +2438,12 @@
     const ModelFlags& model_flags, Model* model) {
   DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
 
-  auto* op = new UnidirectionalSequenceRnnOperator();
   const auto& indices = GetListAttr(node, "_tflite_input_indices");
   if (indices.i_size() != node.input().size()) {
     return tensorflow::errors::InvalidArgument("Input size does not match.");
   }
 
+  auto* op = new UnidirectionalSequenceRnnOperator();
   for (const string& input : node.input()) {
     op->inputs.push_back(input);
   }
diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h
index 7a95e5d..3d1e82a 100644
--- a/tensorflow/lite/toco/model.h
+++ b/tensorflow/lite/toco/model.h
@@ -21,6 +21,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "absl/types/optional.h"
@@ -553,6 +554,10 @@
   FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
   FullyConnectedWeightsFormat weights_format =
       FullyConnectedWeightsFormat::kDefault;
+
+  // `keep_num_dims` is supported in the FullyConnected kernel version 5, but
+  // it's never supported by Toco.
+  bool keep_num_dims = false;
 };
 
 // Dequantization operator, converting a quantized array of integers with
diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD
index 01850bf..f27f0f9 100644
--- a/tensorflow/lite/toco/tflite/BUILD
+++ b/tensorflow/lite/toco/tflite/BUILD
@@ -89,6 +89,7 @@
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":op_version",
         ":operator",
         ":types",
         "//tensorflow/lite:schema_fbs_version",
@@ -108,9 +109,11 @@
     ],
     deps = [
         ":export",
+        ":operator",
         "//tensorflow/core:ops",
         "//tensorflow/lite/schema:schema_fbs",
         "@com_google_googletest//:gtest_main",
+        "@flatbuffers",
     ],
 )
 
diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc
index c32466b..227c6aa 100644
--- a/tensorflow/lite/toco/tflite/export.cc
+++ b/tensorflow/lite/toco/tflite/export.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/lite/context.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/toco/tflite/op_version.h"
 #include "tensorflow/lite/toco/tflite/operator.h"
 #include "tensorflow/lite/toco/tflite/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
@@ -38,9 +39,11 @@
 using ::tflite::BuiltinOperator_MAX;
 using ::tflite::BuiltinOperator_MIN;
 using ::tflite::CreateBuffer;
+using ::tflite::CreateMetadata;
 using ::tflite::CreateModel;
 using ::tflite::CreateOperator;
 using ::tflite::CreateTensor;
+using ::tflite::Metadata;
 using ::tflite::Operator;
 using ::tflite::OperatorCode;
 using ::tflite::SubGraph;
@@ -456,6 +459,17 @@
   }
 }
 
+// Exports a string buffer that contains the model's minimum required runtime
+// version.
+void ExportModelVersionBuffer(
+    const Model& model, std::vector<Offset<Vector<uint8_t>>>* buffers_to_write,
+    FlatBufferBuilder* builder) {
+  const std::string min_runtime = GetMinimumRuntimeVersionForModel(model);
+  buffers_to_write->push_back(builder->CreateVector(
+      reinterpret_cast<const uint8_t*>(min_runtime.data()),
+      min_runtime.size()));
+}
+
 tensorflow::Status Export(
     const Model& model, string* output_file_contents,
     const ExportParams& params,
@@ -612,11 +626,20 @@
         "not implemented yet.");
   }
 
+  // Write the minimum required runtime version into metadata.
+  auto metadata =
+      CreateMetadata(builder, builder.CreateString("min_runtime_version"),
+                     buffers_to_write.size());
+  ExportModelVersionBuffer(model, &buffers_to_write, &builder);
+  std::vector<flatbuffers::Offset<Metadata>> metadatas = {metadata};
+
   auto buffers = ExportBuffers(model, buffers_to_write, &builder);
   auto description = builder.CreateString("TOCO Converted.");
+
   auto new_model_location =
       CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
-                  builder.CreateVector(subgraphs), description, buffers);
+                  builder.CreateVector(subgraphs), description, buffers,
+                  /* metadata_buffer */ 0, builder.CreateVector(metadatas));
   ::tflite::FinishModelBuffer(builder, new_model_location);
 
   if (params.quantize_weights == QuantizedBufferType::NONE) {
diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc
index 0ae6104..bbb1c55 100644
--- a/tensorflow/lite/toco/tflite/export_test.cc
+++ b/tensorflow/lite/toco/tflite/export_test.cc
@@ -16,6 +16,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -245,6 +246,44 @@
   EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
 }
 
+TEST_F(ExportTest, ExportMinRuntime) {
+  AddOperatorsByName({"Conv", "Add", "Sub"});
+
+  ExportParams params;
+  params.allow_custom_ops = true;
+  params.enable_select_tf_ops = false;
+  params.quantize_weights = QuantizedBufferType::NONE;
+
+  string output;
+  auto status = Export(input_model_, &output, params);
+  auto* model = ::tflite::GetModel(output.data());
+  EXPECT_EQ(model->metadata()->size(), 1);
+  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
+  auto buf = model->metadata()->Get(0)->buffer();
+  auto* buffer = (*model->buffers())[buf];
+  auto* array = buffer->data();
+  string version(reinterpret_cast<const char*>(array->data()), array->size());
+  EXPECT_EQ(version, "1.6.0");
+}
+
+TEST_F(ExportTest, ExportEmptyMinRuntime) {
+  AddOperatorsByName({"Switch", "MyCustomOp", "Assert"});
+
+  ExportParams params;
+  params.allow_custom_ops = true;
+
+  string output;
+  auto status = Export(input_model_, &output, params);
+  auto* model = ::tflite::GetModel(output.data());
+  EXPECT_EQ(model->metadata()->size(), 1);
+  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
+  auto buf = model->metadata()->Get(0)->buffer();
+  auto* buffer = (*model->buffers())[buf];
+  auto* array = buffer->data();
+  string version(reinterpret_cast<const char*>(array->data()), array->size());
+  EXPECT_EQ(version, "");
+}
+
 TEST_F(ExportTest, UnsupportedControlFlowErrors) {
   AddOperatorsByName({"Conv", "Add", "Switch", "Merge"});
 
@@ -532,7 +571,7 @@
       auto* op = new ConvOperator;
       op->inputs.push_back("input");
       op->inputs.push_back("filter");
-      op->inputs.push_back("output");
+      op->outputs.push_back("output");
 
       op->padding.type = PaddingType::kSame;
       op->stride_width = 1;
diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc
index 1937f3e..50e0ea9 100644
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -61,8 +61,11 @@
           {{OperatorType::kFullyConnected, 2}, "1.10.0"},
           {{OperatorType::kFullyConnected, 3}, "1.14.0"},
           {{OperatorType::kFullyConnected, 4}, "1.14.0"},
+          {{OperatorType::kFullyConnected, 5}, "1.14.0"},
+          {{OperatorType::kFullyConnected, 6}, kPendingReleaseOpVersion},
           {{OperatorType::kGather, 1}, "1.6.0"},
           {{OperatorType::kGather, 2}, "1.14.0"},
+          {{OperatorType::kGather, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kGatherNd, 1}, "1.14.0"},
           {{OperatorType::kSvdf, 1}, "1.5.0"},
           {{OperatorType::kSvdf, 2}, "1.14.0"},
@@ -78,6 +81,7 @@
           {{OperatorType::kMinimum, 2}, "1.14.0"},
           {{OperatorType::kMul, 1}, "1.5.0"},
           {{OperatorType::kMul, 2}, "1.14.0"},
+          {{OperatorType::kMul, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kPad, 1}, "1.5.0"},
           {{OperatorType::kPad, 2}, "1.14.0"},
           {{OperatorType::kTile, 1}, "1.10.1"},
@@ -181,6 +185,7 @@
   op_signature.model = &model;
   string model_min_version;
   for (const auto& op : model.operators) {
+    if (op_types_map.find(op->type) == op_types_map.end()) continue;
     op_signature.op = op.get();
     const int version = op_types_map.at(op->type)->GetVersion(op_signature);
     std::pair<OperatorType, int> version_key = {op->type, version};
diff --git a/tensorflow/lite/toco/tflite/op_version.h b/tensorflow/lite/toco/tflite/op_version.h
index 9c2b167..54a7750 100644
--- a/tensorflow/lite/toco/tflite/op_version.h
+++ b/tensorflow/lite/toco/tflite/op_version.h
@@ -20,10 +20,10 @@
 namespace toco {
 namespace tflite {
 
-// Get the minimum TF Lite runtime required to run a model. Each operator in
-// the model will have its own minimum requirement of a runtime, and the model's
-// minimum requirement of runtime is defined as the maximum of all the
-// operators' minimum runtime.
+// Get the minimum TF Lite runtime required to run a model. Each built-in
+// operator in the model will have its own minimum requirement of a runtime, and
+// the model's minimum requirement of runtime is defined as the maximum of all
+// the built-in operators' minimum runtime.
 std::string GetMinimumRuntimeVersionForModel(const Model& model);
 
 }  // namespace tflite
diff --git a/tensorflow/lite/toco/tflite/op_version_test.cc b/tensorflow/lite/toco/tflite/op_version_test.cc
index daacc71..4d567c3 100644
--- a/tensorflow/lite/toco/tflite/op_version_test.cc
+++ b/tensorflow/lite/toco/tflite/op_version_test.cc
@@ -131,7 +131,7 @@
   fc->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
   model.operators.push_back(std::move(fc));
 
-  EXPECT_EQ(GetMinimumRuntimeVersionForModel(model), "1.10.0");
+  EXPECT_EQ(GetMinimumRuntimeVersionForModel(model), "");
 }
 
 }  // namespace
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index b064ea3..c9b8804 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -496,6 +496,14 @@
     const Array& input_array = op_signature.model->GetArray(input_name);
     const Array& weights_array = op_signature.model->GetArray(weights_name);
     const Array& output_array = op_signature.model->GetArray(output_name);
+    // 2 inputs (no bias) use case is supported starting from version 6.
+    if (op_signature.op->inputs.size() == 2) {
+      return 6;
+    }
+    // `keep_num_dims` is supported at verison 5.
+    if (fc_op.keep_num_dims) {
+      return 5;
+    }
     // Int8 fully fixed point kernel is at version 4.
     if (input_array.data_type == ArrayDataType::kInt8 &&
         weights_array.data_type == ArrayDataType::kInt8 &&
@@ -540,7 +548,11 @@
   int GetVersion(const OperatorSignature& op_signature) const override {
     const string& input_name = op_signature.op->inputs[0];
     const Array& input_array = op_signature.model->GetArray(input_name);
-    // If the op take int8 input, it is version 2.
+    // If the op takes bool input, it is version 3.
+    if (input_array.data_type == ArrayDataType::kBool) {
+      return 3;
+    }
+    // If the op takes int8 input, it is version 2.
     if (input_array.data_type == ArrayDataType::kInt8) {
       return 2;
     }
@@ -778,10 +790,23 @@
   }
 
   int GetVersion(const OperatorSignature& op_signature) const override {
-    const string& input_name = op_signature.op->inputs[0];
-    const Array& input_array = op_signature.model->GetArray(input_name);
+    const string& input1_name = op_signature.op->inputs[0];
+    const string& input2_name = op_signature.op->inputs[1];
+    const string& output_name = op_signature.op->outputs[0];
+    const Array& input1_array = op_signature.model->GetArray(input1_name);
+    const Array& input2_array = op_signature.model->GetArray(input2_name);
+    const Array& output_array = op_signature.model->GetArray(output_name);
+    const auto& input1_quant = input1_array.quantization_params;
+    const auto& input2_quant = input2_array.quantization_params;
+    const auto& output_quant = output_array.quantization_params;
+    // Version 3 supports have a rescale value greater than or equal to 1.
+    if (input1_quant && input2_quant && output_quant &&
+        (input1_quant->scale * input2_quant->scale / output_quant->scale) >=
+            1.0) {
+      return 3;
+    }
     // Version 2 supports signed int8 input types.
-    if (input_array.data_type == ArrayDataType::kInt8) {
+    if (input1_array.data_type == ArrayDataType::kInt8) {
       return 2;
     }
     return 1;
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 3b007cb..40313f8 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -917,7 +917,35 @@
 
 TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }
 
-TEST_F(OperatorTest, VersioningMulTest) { SimpleVersioningTest<MulOperator>(); }
+void SimpleMulVersioningTest(ArrayDataType data_type, float multiplier,
+                             int version) {
+  MulOperator op;
+  op.inputs = {"input1", "input2"};
+  op.outputs = {"output"};
+  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
+  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
+
+  Model model;
+  Array& input0 = model.GetOrCreateArray(op.inputs[0]);
+  Array& input1 = model.GetOrCreateArray(op.inputs[1]);
+  Array& output = model.GetOrCreateArray(op.outputs[0]);
+
+  input0.data_type = data_type;
+  input0.GetOrCreateQuantizationParams().scale = 1.0f;
+  input1.data_type = data_type;
+  input1.GetOrCreateQuantizationParams().scale = 1.0f;
+  output.data_type = data_type;
+  output.GetOrCreateQuantizationParams().scale = 1.0f / multiplier;
+
+  OperatorSignature signature = {.op = &op, .model = &model};
+  EXPECT_EQ(base_op->GetVersion(signature), version);
+}
+
+TEST_F(OperatorTest, VersioningMulTest) {
+  SimpleMulVersioningTest(ArrayDataType::kUint8, 0.5f, 1);
+  SimpleMulVersioningTest(ArrayDataType::kInt8, 0.5f, 2);
+  SimpleMulVersioningTest(ArrayDataType::kInt8, 2.0f, 3);
+}
 
 TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest<PadOperator>(); }
 
@@ -957,7 +985,7 @@
   output_uint8_array.data_type = ArrayDataType::kUint8;
   OperatorSignature uint8_signature = {.op = &fully_connected_op,
                                        .model = &uint8_model};
-  EXPECT_EQ(op->GetVersion(uint8_signature), 1);
+  EXPECT_EQ(op->GetVersion(uint8_signature), 6);
 
   Model int8_model;
   Array& input_int8_array =
@@ -971,7 +999,7 @@
   output_int8_array.data_type = ArrayDataType::kInt8;
   OperatorSignature int8_signature = {.op = &fully_connected_op,
                                       .model = &int8_model};
-  EXPECT_EQ(op->GetVersion(int8_signature), 4);
+  EXPECT_EQ(op->GetVersion(int8_signature), 6);
 }
 
 TEST_F(OperatorTest, VersioningDequantizeTest) {
diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc
index 020d228..c9143bb 100644
--- a/tensorflow/lite/toco/toco_tooling.cc
+++ b/tensorflow/lite/toco/toco_tooling.cc
@@ -67,6 +67,7 @@
   transformations->Add(new PropagateActivationFunctionIntoConstants);
   transformations->Add(new PropagateArrayDataTypes);
   transformations->Add(new PropagateFixedSizes);
+  transformations->Add(new RemoveSuccesiveTranspose);
   transformations->Add(new RemoveTensorFlowAssert);
   transformations->Add(new RemoveTensorFlowIdentity);
   transformations->Add(new RemoveTrivialConcatenation);
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index 69e8fc6..6860a4a 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -36,6 +36,26 @@
     ],
 )
 
+cc_binary(
+    name = "benchmark_model_performance_options",
+    srcs = [
+        "benchmark_tflite_performance_options_main.cc",
+    ],
+    copts = common_copts,
+    linkopts = tflite_linkopts() + select({
+        "//tensorflow:android": [
+            "-pie",  # Android 5.0 and later supports only PIE
+            "-lm",  # some builtin ops, e.g., tanh, need -lm
+        ],
+        "//conditions:default": [],
+    }),
+    deps = [
+        ":benchmark_performance_options",
+        ":benchmark_tflite_model_lib",
+        ":logging",
+    ],
+)
+
 tf_cc_binary(
     name = "benchmark_model_plus_flex",
     srcs = [
@@ -87,6 +107,7 @@
     copts = common_copts,
     deps = [
         ":benchmark_model_lib",
+        ":benchmark_utils",
         ":logging",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:string_util",
@@ -99,6 +120,24 @@
 )
 
 cc_library(
+    name = "benchmark_performance_options",
+    srcs = [
+        "benchmark_performance_options.cc",
+    ],
+    hdrs = ["benchmark_performance_options.h"],
+    copts = common_copts,
+    deps = [
+        ":benchmark_model_lib",
+        ":benchmark_params",
+        ":benchmark_utils",
+        ":logging",
+        "//tensorflow/core:stats_calculator_portable",
+        "//tensorflow/lite/profiling:time",
+        "//tensorflow/lite/tools:command_line_flags",
+    ],
+)
+
+cc_library(
     name = "benchmark_params",
     srcs = [
         "benchmark_params.cc",
@@ -117,6 +156,7 @@
     copts = common_copts,
     deps = [
         ":benchmark_params",
+        ":benchmark_utils",
         ":logging",
         "//tensorflow/core:stats_calculator_portable",
         "//tensorflow/lite:framework",
@@ -125,4 +165,27 @@
     ],
 )
 
+cc_library(
+    name = "benchmark_utils",
+    srcs = [
+        "benchmark_utils.cc",
+    ],
+    hdrs = ["benchmark_utils.h"],
+    copts = common_copts,
+    deps = ["//tensorflow/lite/profiling:time"],
+)
+
+cc_test(
+    name = "benchmark_utils_test",
+    srcs = [
+        "benchmark_utils_test.cc",
+    ],
+    copts = common_copts,
+    deps = [
+        ":benchmark_utils",
+        "//tensorflow/lite/profiling:time",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 tflite_portable_test_suite()
diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md
index 8e77a22..4fb2827 100644
--- a/tensorflow/lite/tools/benchmark/README.md
+++ b/tensorflow/lite/tools/benchmark/README.md
@@ -213,3 +213,22 @@
 
 Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
 ```
+
+## Benchmark multiple performance options in a single run
+
+A convenient and simple C++ binary is also provided to benchmark multiple
+performance options in a single run. This binary is built based on the
+aforementioned benchmark tool that could only benchmark a single performance
+option at a time. They share the same build/install/run process, but the BUILD
+target name of this binary is `benchmark_model_performance_options` and it takes
+some additional parameters as detailed below.
+
+### Additional Parameters
+*   `perf_options_list`: `string` (default='all') \
+    A comma-separated list of TFLite performance options to benchmark.
+*   `option_benchmark_run_delay`: `float` (default=-1.0) \
+    The delay between two consecutive runs of benchmarking performance options
+    in seconds.
+*   `random_shuffle_benchmark_runs`: `bool` (default=true) \
+    Whether to perform all benchmark runs, each of which has different
+    performance options, in a random order.
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc
index 3ee5500..1aee4ca 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc
@@ -19,22 +19,9 @@
 #include <sstream>
 
 #include "tensorflow/lite/profiling/time.h"
+#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
 #include "tensorflow/lite/tools/benchmark/logging.h"
 
-namespace {
-void SleepForSeconds(double sleep_seconds) {
-  if (sleep_seconds <= 0.0) {
-    return;
-  }
-  // If requested, sleep between runs for an arbitrary amount of time.
-  // This can be helpful to determine the effect of mobile processor
-  // scaling and thermal throttling.
-  return tflite::profiling::time::SleepForMicros(
-      static_cast<uint64_t>(sleep_seconds * 1e6));
-}
-
-}  // namespace
-
 namespace tflite {
 namespace benchmark {
 using tensorflow::Stat;
@@ -55,7 +42,7 @@
 
 BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
 
-void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
+void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults& results) {
   auto inference_us = results.inference_time_us();
   auto init_us = results.startup_latency_us();
   auto warmup_us = results.warmup_time_us();
@@ -143,7 +130,7 @@
     listeners_.OnSingleRunEnd();
 
     run_stats.UpdateStat(end_us - start_us);
-    SleepForSeconds(params_.Get<float>("run_delay"));
+    util::SleepForSeconds(params_.Get<float>("run_delay"));
     now_us = profiling::time::NowMicros();
   }
 
@@ -156,7 +143,7 @@
 
 bool BenchmarkModel::ValidateParams() { return true; }
 
-void BenchmarkModel::Run(int argc, char **argv) {
+void BenchmarkModel::Run(int argc, char** argv) {
   if (!ParseFlags(argc, argv)) {
     return;
   }
@@ -187,10 +174,10 @@
       {startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
 }
 
-bool BenchmarkModel::ParseFlags(int argc, char **argv) {
+bool BenchmarkModel::ParseFlags(int* argc, char** argv) {
   auto flag_list = GetFlags();
   const bool parse_result =
-      Flags::Parse(&argc, const_cast<const char **>(argv), flag_list);
+      Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
   if (!parse_result) {
     std::string usage = Flags::Usage(argv[0], flag_list);
     TFLITE_LOG(ERROR) << usage;
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h
index 132ee84..0e78370 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.h
@@ -129,8 +129,9 @@
 template <typename T>
 Flag CreateFlag(const char* name, BenchmarkParams* params,
                 const std::string& usage) {
-  return Flag(name, [params, name](const T& val) { params->Set<T>(name, val); },
-              params->Get<T>(name), usage);
+  return Flag(
+      name, [params, name](const T& val) { params->Set<T>(name, val); },
+      params->Get<T>(name), usage);
 }
 
 // Benchmarks a model.
@@ -150,11 +151,19 @@
     listeners_.AddListener(listener);
   }
 
+  BenchmarkParams* mutable_params() { return &params_; }
+
+  // Unparsable flags will remain in 'argv' in the original order and 'argc'
+  // will be updated accordingly.
+  bool ParseFlags(int* argc, char** argv);
+
  protected:
   virtual void LogParams();
   virtual bool ValidateParams();
-  bool ParseFlags(int argc, char** argv);
+
+  bool ParseFlags(int argc, char** argv) { return ParseFlags(&argc, argv); }
   virtual std::vector<Flag> GetFlags();
+
   virtual uint64_t ComputeInputBytes() = 0;
   virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                         float max_secs, RunType run_type);
diff --git a/tensorflow/lite/tools/benchmark/benchmark_params.cc b/tensorflow/lite/tools/benchmark/benchmark_params.cc
index 5ab3adf..caff971 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_params.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_params.cc
@@ -53,5 +53,13 @@
   TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
 }
 
+void BenchmarkParams::Set(const BenchmarkParams& other) {
+  for (const auto& param : params_) {
+    const BenchmarkParam* other_param = other.GetParam(param.first);
+    if (other_param == nullptr) continue;
+    param.second->Set(*other_param);
+  }
+}
+
 }  // namespace benchmark
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/benchmark/benchmark_params.h b/tensorflow/lite/tools/benchmark/benchmark_params.h
index c591cc2..07db44d 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_params.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_params.h
@@ -47,8 +47,17 @@
     AssertHasSameType(GetValueType<T>(), type_);
     return static_cast<TypedBenchmarkParam<T>*>(this);
   }
+
+  template <typename T>
+  const TypedBenchmarkParam<T>* AsConstTyped() const {
+    AssertHasSameType(GetValueType<T>(), type_);
+    return static_cast<const TypedBenchmarkParam<T>*>(this);
+  }
+
   virtual ~BenchmarkParam() {}
-  BenchmarkParam(ParamType type) : type_(type) {}
+  explicit BenchmarkParam(ParamType type) : type_(type) {}
+
+  virtual void Set(const BenchmarkParam&) {}
 
  private:
   static void AssertHasSameType(ParamType a, ParamType b);
@@ -59,11 +68,16 @@
 template <typename T>
 class TypedBenchmarkParam : public BenchmarkParam {
  public:
-  TypedBenchmarkParam(const T& value)
+  explicit TypedBenchmarkParam(const T& value)
       : BenchmarkParam(GetValueType<T>()), value_(value) {}
+
   void Set(const T& value) { value_ = value; }
 
-  T Get() { return value_; }
+  T Get() const { return value_; }
+
+  void Set(const BenchmarkParam& other) override {
+    Set(other.AsConstTyped<T>()->Get());
+  }
 
  private:
   T value_;
@@ -80,6 +94,12 @@
     return params_.find(name) != params_.end();
   }
 
+  const BenchmarkParam* GetParam(const std::string& name) const {
+    const auto& entry = params_.find(name);
+    if (entry == params_.end()) return nullptr;
+    return entry->second.get();
+  }
+
   template <typename T>
   void Set(const std::string& name, const T& value) {
     AssertParamExists(name);
@@ -92,6 +112,9 @@
     return params_.at(name)->AsTyped<T>()->Get();
   }
 
+  // Set the value of all same parameters from 'other'.
+  void Set(const BenchmarkParams& other);
+
  private:
   void AssertParamExists(const std::string& name) const;
   std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
new file mode 100644
index 0000000..0afba77
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
@@ -0,0 +1,247 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h"
+
+#include <algorithm>
+#include <iomanip>
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "tensorflow/core/util/stats_calculator.h"
+#include "tensorflow/lite/profiling/time.h"
+#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
+#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
+#include "tensorflow/lite/tools/benchmark/logging.h"
+#include "tensorflow/lite/tools/command_line_flags.h"
+
+namespace tflite {
+namespace benchmark {
+
+void MultiRunStatsRecorder::OnBenchmarkStart(const BenchmarkParams& params) {
+  current_run_name_.clear();
+
+  if (params.Get<bool>("use_nnapi")) {
+    current_run_name_ = "nnapi";
+    return;
+  }
+
+  if (params.Get<bool>("use_gpu")) {
+    current_run_name_ = "gpu";
+    return;
+  }
+
+  // Handle cases run on CPU
+  // Note: could use std::to_string to convert an integer to string but it
+  // requires C++11.
+  std::stringstream sstm;
+  sstm << "cpu w/ " << params.Get<int32_t>("num_threads") << " threads";
+  current_run_name_ = sstm.str();
+}
+
+void MultiRunStatsRecorder::OnBenchmarkEnd(const BenchmarkResults& results) {
+  each_run_stats_.emplace_back(std::make_pair(current_run_name_, results));
+}
+
+void MultiRunStatsRecorder::OutputStats() {
+  // Make a 80-character-long header.
+  TFLITE_LOG(INFO) << "\n==============Summary of All Runs w/ Different "
+                      "Performance Options==============";
+  std::sort(each_run_stats_.begin(), each_run_stats_.end(),
+            EachRunStatsEntryComparator());
+
+  for (const auto& run_stats : each_run_stats_) {
+    std::stringstream stream;
+    // Output the name of this run first.
+    stream << std::setw(26) << run_stats.first << ": ";
+    run_stats.second.inference_time_us().OutputToStream(&stream);
+    TFLITE_LOG(INFO) << stream.str();
+  }
+}
+
+BenchmarkPerformanceOptions::BenchmarkPerformanceOptions(
+    BenchmarkModel* single_option_run)
+    : BenchmarkPerformanceOptions(DefaultParams(), single_option_run,
+                                  DefaultRunStatsRecorder()) {}
+
+BenchmarkPerformanceOptions::BenchmarkPerformanceOptions(
+    BenchmarkParams params, BenchmarkModel* single_option_run,
+    std::unique_ptr<MultiRunStatsRecorder> all_run_stats)
+    : params_(std::move(params)),
+      single_option_run_(single_option_run),
+      single_option_run_params_(single_option_run->mutable_params()),
+      all_run_stats_(std::move(all_run_stats)) {
+  single_option_run_->AddListener(all_run_stats_.get());
+}
+
+BenchmarkParams BenchmarkPerformanceOptions::DefaultParams() {
+  BenchmarkParams params;
+  params.AddParam("perf_options_list",
+                  BenchmarkParam::Create<std::string>("all"));
+  params.AddParam("option_benchmark_run_delay",
+                  BenchmarkParam::Create<float>(-1.0f));
+  params.AddParam("random_shuffle_benchmark_runs",
+                  BenchmarkParam::Create<bool>(true));
+  return params;
+}
+
+std::unique_ptr<MultiRunStatsRecorder>
+BenchmarkPerformanceOptions::DefaultRunStatsRecorder() {
+  return std::unique_ptr<MultiRunStatsRecorder>(new MultiRunStatsRecorder());
+}
+
+std::vector<Flag> BenchmarkPerformanceOptions::GetFlags() {
+  return {
+      CreateFlag<std::string>(
+          "perf_options_list", &params_,
+          "A comma-separated list of TFLite performance options to benchmark. "
+          "By default, all performance options are benchmarked."),
+      CreateFlag<float>("option_benchmark_run_delay", &params_,
+                        "The delay between two consecutive runs of "
+                        "benchmarking performance options in seconds."),
+      CreateFlag<bool>(
+          "random_shuffle_benchmark_runs", &params_,
+          "Whether to perform all benchmark runs, each of which has different "
+          "performance options, in a random order. It is enabled by default."),
+  };
+}
+
+bool BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) {
+  auto flag_list = GetFlags();
+  const bool parse_result =
+      Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
+  if (!parse_result) {
+    std::string usage = Flags::Usage(argv[0], flag_list);
+    TFLITE_LOG(ERROR) << usage;
+    return false;
+  }
+
+  // Parse the value of --perf_options_list to find performance options to be
+  // benchmarked.
+  return ParsePerfOptions();
+}
+
+bool BenchmarkPerformanceOptions::ParsePerfOptions() {
+  const auto& perf_options_list = params_.Get<std::string>("perf_options_list");
+  if (!util::SplitAndParse(perf_options_list, ',', &perf_options_)) {
+    TFLITE_LOG(ERROR) << "Cannot parse --perf_options_list: '"
+                      << perf_options_list
+                      << "'. Please double-check its value.";
+    perf_options_.clear();
+    return false;
+  }
+
+  const auto valid_options = GetValidPerfOptions();
+  bool is_valid = true;
+  for (const auto& option : perf_options_) {
+    if (std::find(valid_options.begin(), valid_options.end(), option) ==
+        valid_options.end()) {
+      is_valid = false;
+      break;
+    }
+  }
+  if (!is_valid) {
+    std::string valid_options_str;
+    for (int i = 0; i < valid_options.size() - 1; ++i) {
+      valid_options_str += (valid_options[i] + ", ");
+    }
+    valid_options_str += valid_options.back();
+    TFLITE_LOG(ERROR)
+        << "There are invalid perf options in --perf_options_list: '"
+        << perf_options_list << "'. Valid perf options are: ["
+        << valid_options_str << "]";
+    perf_options_.clear();
+    return false;
+  }
+  return true;
+}
+
+std::vector<std::string> BenchmarkPerformanceOptions::GetValidPerfOptions()
+    const {
+  return {"all", "cpu", "gpu", "nnapi"};
+}
+
+bool BenchmarkPerformanceOptions::HasOption(const std::string& option) const {
+  return std::find(perf_options_.begin(), perf_options_.end(), option) !=
+         perf_options_.end();
+}
+
+void BenchmarkPerformanceOptions::ResetPerformanceOptions() {
+  single_option_run_params_->Set<int32_t>("num_threads", 1);
+  single_option_run_params_->Set<bool>("use_gpu", false);
+  single_option_run_params_->Set<bool>("use_nnapi", false);
+}
+
+void BenchmarkPerformanceOptions::CreatePerformanceOptions() {
+  TFLITE_LOG(INFO) << "The list of TFLite runtime options to be benchmarked: ["
+                   << params_.Get<std::string>("perf_options_list") << "]";
+
+  const bool benchmark_all = HasOption("all");
+
+  if (benchmark_all || HasOption("cpu")) {
+    const std::vector<int> num_threads = {1, 2, 4};
+    for (const int count : num_threads) {
+      BenchmarkParams params;
+      params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(count));
+      all_run_params_.emplace_back(std::move(params));
+    }
+  }
+
+  if (benchmark_all || HasOption("gpu")) {
+    BenchmarkParams params;
+    params.AddParam("use_gpu", BenchmarkParam::Create<bool>(true));
+    all_run_params_.emplace_back(std::move(params));
+  }
+
+  if (benchmark_all || HasOption("nnapi")) {
+    BenchmarkParams params;
+    params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(true));
+    all_run_params_.emplace_back(std::move(params));
+  }
+}
+
+void BenchmarkPerformanceOptions::Run(int argc, char** argv) {
+  // We first parse flags for single-option runs to get information like
+  // parameters of the input model etc.
+  if (!single_option_run_->ParseFlags(&argc, argv)) return;
+
+  // Now, we parse flags that are specified for this particular binary.
+  if (!ParseFlags(&argc, argv)) return;
+
+  // Now, the remaining are unrecognized flags and we simply print them out.
+  for (int i = 1; i < argc; ++i) {
+    TFLITE_LOG(WARN) << "WARNING: unrecognized commandline flag: " << argv[i];
+  }
+
+  CreatePerformanceOptions();
+
+  if (params_.Get<bool>("random_shuffle_benchmark_runs")) {
+    std::random_shuffle(all_run_params_.begin(), all_run_params_.end());
+  }
+
+  // Now perform all runs, each with different performance-affecting parameters.
+  for (const auto& run_params : all_run_params_) {
+    // Reset all performance-related options before any runs.
+    ResetPerformanceOptions();
+    single_option_run_params_->Set(run_params);
+    util::SleepForSeconds(params_.Get<float>("option_benchmark_run_delay"));
+    single_option_run_->Run();
+  }
+
+  all_run_stats_->OutputStats();
+}
+}  // namespace benchmark
+}  // namespace tflite
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.h b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h
new file mode 100644
index 0000000..df5aa81
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
+#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
+
+namespace tflite {
+namespace benchmark {
+
+class MultiRunStatsRecorder : public BenchmarkListener {
+ public:
+  void OnBenchmarkStart(const BenchmarkParams& params) override;
+  void OnBenchmarkEnd(const BenchmarkResults& results) override;
+
+  virtual void OutputStats();
+
+ protected:
+  using EachRunStatsEntry = std::pair<std::string, BenchmarkResults>;
+
+  // Use this to order the runs by the average inference time in increasing
+  // order (i.e. the fastest run ranks first.)
+  struct EachRunStatsEntryComparator {
+    bool operator()(const EachRunStatsEntry& i, const EachRunStatsEntry& j) {
+      return (i.second.inference_time_us().avg() <
+              j.second.inference_time_us().avg());
+    }
+  };
+
+  std::string current_run_name_;
+  std::vector<EachRunStatsEntry> each_run_stats_;
+};
+
+// Benchmarks all performance options on a model by repeatedly invoking the
+// single-performance-option run on a passed-in 'BenchmarkModel' object.
+class BenchmarkPerformanceOptions {
+ public:
+  // Doesn't own the memory of 'single_option_run'.
+  explicit BenchmarkPerformanceOptions(BenchmarkModel* single_option_run);
+
+  virtual ~BenchmarkPerformanceOptions() {}
+
+  void Run(int argc, char** argv);
+
+ protected:
+  static BenchmarkParams DefaultParams();
+  static std::unique_ptr<MultiRunStatsRecorder> DefaultRunStatsRecorder();
+
+  BenchmarkPerformanceOptions(
+      BenchmarkParams params, BenchmarkModel* single_option_run,
+      std::unique_ptr<MultiRunStatsRecorder> all_run_stats);
+
+  // Unparsable flags will remain in 'argv' in the original order and 'argc'
+  // will be updated accordingly.
+  bool ParseFlags(int* argc, char** argv);
+  virtual std::vector<Flag> GetFlags();
+
+  bool ParsePerfOptions();
+  virtual std::vector<std::string> GetValidPerfOptions() const;
+  bool HasOption(const std::string& option) const;
+
+  virtual void ResetPerformanceOptions();
+  virtual void CreatePerformanceOptions();
+
+  BenchmarkParams params_;
+  std::vector<std::string> perf_options_;
+
+  // The object that drives a single-performance-option run.
+  BenchmarkModel* const single_option_run_;          // Doesn't own the memory.
+  BenchmarkParams* const single_option_run_params_;  // Doesn't own the memory.
+
+  // Each element is a set of performance-affecting benchmark parameters to be
+  // all set for a particular benchmark run.
+  std::vector<BenchmarkParams> all_run_params_;
+
+  std::unique_ptr<MultiRunStatsRecorder> all_run_stats_;
+};
+
+}  // namespace benchmark
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 0035a0b..694cc060 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/lite/profiling/buffered_profiler.h"
 #include "tensorflow/lite/profiling/profile_summarizer.h"
 #include "tensorflow/lite/string_util.h"
+#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
 #include "tensorflow/lite/tools/benchmark/logging.h"
 #include "tensorflow/lite/tools/evaluation/utils.h"
 
@@ -119,40 +120,14 @@
 }
 
 std::vector<std::string> Split(const std::string& str, const char delim) {
-  std::istringstream input(str);
   std::vector<std::string> results;
-  std::string item;
-  while (std::getline(input, item, delim)) {
-    results.push_back(item);
+  if (!util::SplitAndParse(str, delim, &results)) {
+    results.clear();
   }
   return results;
 }
 
 template <typename T>
-bool SplitAndParse(const std::string& str, char delim, std::vector<T>* values) {
-  std::istringstream input(str);
-  bool first = true;
-  while (!input.eof()) {
-    if (!first) {
-      char c;
-      input >> c;
-      if (c != delim) {
-        return false;
-      }
-    } else {
-      first = false;
-    }
-    T val;
-    input >> val;
-    if (!input.eof() && !input.good()) {
-      return false;
-    }
-    values->push_back(val);
-  }
-  return true;
-}
-
-template <typename T>
 void FillRandomValue(T* ptr, int num_elements,
                      const std::function<T()>& random_func) {
   for (int i = 0; i < num_elements; ++i) {
@@ -197,7 +172,7 @@
 
     input.name = names[i];
 
-    TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape))
+    TFLITE_BENCHMARK_CHECK(util::SplitAndParse(shapes[i], ',', &input.shape))
         << "Incorrect size string specified: " << shapes[i];
     for (int dim : input.shape) {
       if (dim == -1) {
@@ -351,6 +326,12 @@
       FillRandomValue<float>(t_data.data.f, num_elements, []() {
         return static_cast<float>(rand()) / RAND_MAX - 0.5f;
       });
+    } else if (t->type == kTfLiteInt64) {
+      t_data.bytes = sizeof(int64_t) * num_elements;
+      t_data.data.raw = new char[t_data.bytes];
+      FillRandomValue<int64_t>(t_data.data.i64, num_elements, []() {
+        return static_cast<int64_t>(rand()) % 100;
+      });
     } else if (t->type == kTfLiteInt32) {
       // TODO(yunluli): This is currently only used for handling embedding input
       // for speech models. Generalize if necessary.
@@ -396,9 +377,15 @@
     if (t->type == kTfLiteFloat32) {
       std::memcpy(interpreter_->typed_tensor<float>(i), inputs_data_[j].data.f,
                   inputs_data_[j].bytes);
+    } else if (t->type == kTfLiteInt64) {
+      std::memcpy(interpreter_->typed_tensor<int64_t>(i),
+                  inputs_data_[j].data.i64, inputs_data_[j].bytes);
     } else if (t->type == kTfLiteInt32) {
       std::memcpy(interpreter_->typed_tensor<int32_t>(i),
                   inputs_data_[j].data.i32, inputs_data_[j].bytes);
+    } else if (t->type == kTfLiteInt64) {
+      std::memcpy(interpreter_->typed_tensor<int64_t>(i),
+                  inputs_data_[j].data.i64, inputs_data_[j].bytes);
     } else if (t->type == kTfLiteInt16) {
       std::memcpy(interpreter_->typed_tensor<int16_t>(i),
                   inputs_data_[j].data.i16, inputs_data_[j].bytes);
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc
new file mode 100644
index 0000000..c70a719
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc
@@ -0,0 +1,40 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h"
+#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
+#include "tensorflow/lite/tools/benchmark/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+int Main(int argc, char** argv) {
+#ifdef TFLITE_CUSTOM_OPS_HEADER
+  TFLITE_LOG(INFO) << "STARTING with custom ops!";
+#else
+  TFLITE_LOG(INFO) << "STARTING!";
+#endif
+  BenchmarkTfLiteModel benchmark;
+  BenchmarkLoggingListener listener;
+  benchmark.AddListener(&listener);
+
+  BenchmarkPerformanceOptions all_options_benchmark(&benchmark);
+  all_options_benchmark.Run(argc, argv);
+  return EXIT_SUCCESS;
+}
+}  // namespace benchmark
+}  // namespace tflite
+
+int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); }
diff --git a/tensorflow/lite/tools/benchmark/benchmark_utils.cc b/tensorflow/lite/tools/benchmark/benchmark_utils.cc
new file mode 100644
index 0000000..d8fe263
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_utils.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
+
+#include "tensorflow/lite/profiling/time.h"
+
+namespace tflite {
+namespace benchmark {
+namespace util {
+
+void SleepForSeconds(double sleep_seconds) {
+  if (sleep_seconds <= 0.0) {
+    return;
+  }
+  // If requested, sleep between runs for an arbitrary amount of time.
+  // This can be helpful to determine the effect of mobile processor
+  // scaling and thermal throttling.
+  tflite::profiling::time::SleepForMicros(
+      static_cast<uint64_t>(sleep_seconds * 1e6));
+}
+
+}  // namespace util
+}  // namespace benchmark
+}  // namespace tflite
diff --git a/tensorflow/lite/tools/benchmark/benchmark_utils.h b/tensorflow/lite/tools/benchmark/benchmark_utils.h
new file mode 100644
index 0000000..b690116
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_utils.h
@@ -0,0 +1,52 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_UTILS_H_
+#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_UTILS_H_
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace benchmark {
+namespace util {
+
+// A convenient function that wraps tflite::profiling::time::SleepForMicros and
+// simply return if 'sleep_seconds' is negative.
+void SleepForSeconds(double sleep_seconds);
+
+// Split the 'str' according to 'delim', and store each splitted element into
+// 'values'.
+template <typename T>
+bool SplitAndParse(const std::string& str, char delim, std::vector<T>* values) {
+  std::istringstream input(str);
+  for (std::string line; std::getline(input, line, delim);) {
+    std::istringstream to_parse(line);
+    T val;
+    to_parse >> val;
+    if (!to_parse.eof() && !to_parse.good()) {
+      return false;
+    }
+    values->emplace_back(val);
+  }
+  return true;
+}
+
+}  // namespace util
+}  // namespace benchmark
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_UTILS_H_
diff --git a/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc b/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc
new file mode 100644
index 0000000..cb15172
--- /dev/null
+++ b/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/profiling/time.h"
+
+namespace tflite {
+namespace benchmark {
+namespace {
+
+TEST(BenchmarkHelpersTest, SleepForNegativeSeconds) {
+  const auto start_ts = tflite::profiling::time::NowMicros();
+  // The following should return immediately.
+  util::SleepForSeconds(-5.0);
+  const auto end_ts = tflite::profiling::time::NowMicros();
+
+  // As we don't have a mocked clock, we simply expect <1 sec has elapsed, which
+  // is admittedly not quite accurate.
+  EXPECT_LT(end_ts - start_ts, 1000000);
+}
+
+TEST(BenchmarkHelpersTest, SleepForSomeSeconds) {
+  const auto start_ts = tflite::profiling::time::NowMicros();
+  // The following should return after 2.0 secs
+  util::SleepForSeconds(2.0);
+  const auto end_ts = tflite::profiling::time::NowMicros();
+
+  // As we don't have a mocked clock, we simply expect >1.9 sec has elapsed.
+  EXPECT_GT(end_ts - start_ts, 1900000);
+}
+
+TEST(BenchmarkHelpersTest, SplitAndParseFailed) {
+  std::vector<int> results;
+  const bool splitted = util::SplitAndParse("hello;world", ';', &results);
+
+  EXPECT_FALSE(splitted);
+}
+
+TEST(BenchmarkHelpersTest, SplitAndParseString) {
+  std::vector<std::string> results;
+  const bool splitted = util::SplitAndParse("hello,world", ',', &results);
+
+  EXPECT_TRUE(splitted);
+  EXPECT_EQ(2, results.size());
+
+  EXPECT_EQ("hello", results[0]);
+  EXPECT_EQ("world", results[1]);
+}
+
+TEST(BenchmarkHelpersTest, SplitAndParseInts) {
+  std::vector<int> results;
+  const bool splitted = util::SplitAndParse("1,2", ',', &results);
+
+  EXPECT_TRUE(splitted);
+  EXPECT_EQ(2, results.size());
+
+  EXPECT_EQ(1, results[0]);
+  EXPECT_EQ(2, results[1]);
+}
+
+}  // namespace
+}  // namespace benchmark
+}  // namespace tflite
diff --git a/tensorflow/lite/tools/benchmark/ios/README.md b/tensorflow/lite/tools/benchmark/ios/README.md
index 3a9ae27..5c772ac 100644
--- a/tensorflow/lite/tools/benchmark/ios/README.md
+++ b/tensorflow/lite/tools/benchmark/ios/README.md
@@ -13,7 +13,7 @@
 number of threads. The default values in the JSON file are for the
 Mobilenet_1.0_224 model
 ([paper](https://arxiv.org/pdf/1704.04861.pdf),
-[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz))
+[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz))
 
 ## To build/install/run
 
diff --git a/tensorflow/lite/tools/benchmark/logging.h b/tensorflow/lite/tools/benchmark/logging.h
index 42ccb2c..808090b 100644
--- a/tensorflow/lite/tools/benchmark/logging.h
+++ b/tensorflow/lite/tools/benchmark/logging.h
@@ -46,10 +46,19 @@
   std::stringstream& Stream() { return stream_; }
   ~LoggingWrapper() {
     if (should_log_) {
-      std::cerr << stream_.str() << std::endl;
-      if (severity_ == LogSeverity::FATAL) {
-        std::flush(std::cerr);
-        std::abort();
+      switch (severity_) {
+        case LogSeverity::INFO:
+        case LogSeverity::WARN:
+          std::cout << stream_.str() << std::endl;
+          break;
+        case LogSeverity::ERROR:
+          std::cerr << stream_.str() << std::endl;
+          break;
+        case LogSeverity::FATAL:
+          std::cerr << stream_.str() << std::endl;
+          std::flush(std::cerr);
+          std::abort();
+          break;
       }
     }
   }
diff --git a/tensorflow/lite/tools/evaluation/proto/BUILD b/tensorflow/lite/tools/evaluation/proto/BUILD
index fe4e028..8c265ff 100644
--- a/tensorflow/lite/tools/evaluation/proto/BUILD
+++ b/tensorflow/lite/tools/evaluation/proto/BUILD
@@ -19,7 +19,7 @@
 )
 
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_py",
 )
 
diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/README.md b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/README.md
index 07b9b18..382719f 100644
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/README.md
+++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/README.md
@@ -191,7 +191,7 @@
 (8) Run the binary.
 
 ```
-adb shell /data/local/tmp/imagenet_accuracy_eval \
+adb shell /data/local/tmp/run_eval \
   --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \
   --ground_truth_images_path=/data/local/tmp/ilsvrc_images \
   --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \
diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc
index 162acba..d40afba 100644
--- a/tensorflow/lite/tools/evaluation/utils.cc
+++ b/tensorflow/lite/tools/evaluation/utils.cc
@@ -73,6 +73,7 @@
   return kTfLiteOk;
 }
 
+// TODO(b/138448769): Migrate delegate helper APIs to lite/testing.
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
 #if defined(__ANDROID__)
   return Interpreter::TfLiteDelegatePtr(
@@ -108,7 +109,8 @@
     tflite::FlatBufferModel* model) {
 #if defined(__ANDROID__)
   TfLiteGpuDelegateOptions options;
-  options.metadata = TfLiteGpuDelegateGetModelMetadata(model->GetModel());
+  options.metadata =
+      model ? TfLiteGpuDelegateGetModelMetadata(model->GetModel()) : nullptr;
   options.compile_options.precision_loss_allowed = 1;
   options.compile_options.preferred_gl_object_type =
       TFLITE_GL_OBJECT_TYPE_FASTEST;
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index 89ef6e5..73c50d3 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -59,7 +59,6 @@
 CXXFLAGS += $(EXTRA_CXXFLAGS)
 CFLAGS := ${CXXFLAGS}
 CXXFLAGS += --std=c++11
-CFLAGS :=
 LDOPTS := -L/usr/local/lib
 ARFLAGS := -r
 TARGET_TOOLCHAIN_PREFIX :=
@@ -99,6 +98,7 @@
 $(wildcard tensorflow/lite/c/*.c) \
 $(wildcard tensorflow/lite/core/*.cc) \
 $(wildcard tensorflow/lite/core/api/*.cc) \
+tensorflow/lite/experimental/resource_variable/*.cc \
 tensorflow/lite/experimental/ruy/allocator.cc \
 tensorflow/lite/experimental/ruy/block_map.cc \
 tensorflow/lite/experimental/ruy/blocking_counter.cc \
@@ -106,7 +106,7 @@
 tensorflow/lite/experimental/ruy/detect_dotprod.cc \
 tensorflow/lite/experimental/ruy/kernel_arm32.cc \
 tensorflow/lite/experimental/ruy/kernel_arm64.cc \
-tensorflow/lite/experimental/ruy/pack.cc \
+tensorflow/lite/experimental/ruy/pack_arm.cc \
 tensorflow/lite/experimental/ruy/pmu.cc \
 tensorflow/lite/experimental/ruy/thread_pool.cc \
 tensorflow/lite/experimental/ruy/trace.cc \
@@ -161,9 +161,14 @@
 ifeq ($(TARGET),rpi)
 	BUILD_WITH_NNAPI=false
 endif
+ifeq ($(TARGET),generic-aarch64)
+	BUILD_WITH_NNAPI=false
+endif
 ifeq ($(BUILD_WITH_NNAPI),true)
 	CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+  CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc
 	CORE_CC_ALL_SRCS += tensorflow/lite/nnapi/nnapi_implementation.cc
+	LIBS += -lrt
 else
 	CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc
 	CORE_CC_ALL_SRCS += tensorflow/lite/nnapi/nnapi_implementation_disabled.cc
diff --git a/tensorflow/lite/tools/make/build_generic_aarch64_lib.sh b/tensorflow/lite/tools/make/build_generic_aarch64_lib.sh
new file mode 100755
index 0000000..d497b94
--- /dev/null
+++ b/tensorflow/lite/tools/make/build_generic_aarch64_lib.sh
@@ -0,0 +1,22 @@
+#!/bin/bash -x
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR/../../../.."
+
+CC_PREFIX=aarch64-linux-gnu- make -j 3 -f tensorflow/lite/tools/make/Makefile TARGET=generic-aarch64 TARGET_ARCH=armv8-a
diff --git a/tensorflow/lite/tools/make/targets/generic_aarch64_makefile.inc b/tensorflow/lite/tools/make/targets/generic_aarch64_makefile.inc
new file mode 100644
index 0000000..f4e4f1f
--- /dev/null
+++ b/tensorflow/lite/tools/make/targets/generic_aarch64_makefile.inc
@@ -0,0 +1,33 @@
+# Settings for generic aarch64 boards such as Odroid C2 or Pine64.
+ifeq ($(TARGET),generic-aarch64)
+  # The aarch64 architecture covers all 64-bit ARM chips. This arch mandates
+  # NEON, so FPU flags are not needed below.
+  TARGET_ARCH := armv8-a
+  TARGET_TOOLCHAIN_PREFIX := aarch64-linux-gnu-
+
+  CXXFLAGS += \
+    -march=armv8-a \
+    -funsafe-math-optimizations \
+    -ftree-vectorize \
+    -fPIC
+
+  CCFLAGS += \
+    -march=armv8-a \
+    -funsafe-math-optimizations \
+    -ftree-vectorize \
+    -fPIC
+
+  LDFLAGS := \
+    -Wl,--no-export-dynamic \
+    -Wl,--exclude-libs,ALL \
+    -Wl,--gc-sections \
+    -Wl,--as-needed
+
+       
+  LIBS := \
+    -lstdc++ \
+    -lpthread \
+    -lm \
+    -ldl
+
+endif
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 89965e1..451faae 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -112,6 +112,16 @@
   return {};
 }
 
+// Checks that a specific input can be quantized.
+bool IsQuantizedInput(const OperatorCodeT* op_code,
+                      const CustomOpMap& custom_op_map, int op_input_idx) {
+  const auto quantized_input_indices =
+      GetWeightInputIndices(op_code, custom_op_map);
+  return std::find(std::begin(quantized_input_indices),
+                   std::end(quantized_input_indices),
+                   op_input_idx) != std::end(quantized_input_indices);
+}
+
 // Returns true if the operator supports hybrid evaluation.
 bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
                           const CustomOpMap& custom_op_map) {
@@ -390,7 +400,9 @@
           use_hybrid_evaluation &&
           IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map) &&
           CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code,
-                                    custom_op_map);
+                                    custom_op_map) &&
+          IsQuantizedInput(consumer_op_code, custom_op_map,
+                           consumer_op_info.op_input_idx);
       if (!eval_hybrid) {
         dequant_op_infos.push_back(consumer_op_info);
       }
diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py
index ce604ee..2d83588 100644
--- a/tensorflow/lite/tools/visualize.py
+++ b/tensorflow/lite/tools/visualize.py
@@ -313,7 +313,7 @@
 
     nodes.append({
         "id": TensorName(tensor_index),
-        "name": "%r (%d)" % (tensor["shape"], tensor_index),
+        "name": "%r (%d)" % (getattr(tensor, "shape", []), tensor_index),
         "group": 1,
         "x": initial_y[1],
         "y": initial_y[0]
diff --git a/tensorflow/lite/tutorials/post_training_integer_quant.ipynb b/tensorflow/lite/tutorials/post_training_integer_quant.ipynb
deleted file mode 100644
index 1629b4c..0000000
--- a/tensorflow/lite/tutorials/post_training_integer_quant.ipynb
+++ /dev/null
@@ -1,651 +0,0 @@
-{
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "post-training--integer-quant.ipynb",
-      "version": "0.3.2",
-      "provenance": [],
-      "private_outputs": true,
-      "collapsed_sections": [],
-      "toc_visible": true
-    },
-    "kernelspec": {
-      "display_name": "Python 2",
-      "name": "python2"
-    }
-  },
-  "cells": [
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "6Y8E0lw5eYWm"
-      },
-      "source": [
-        "# Post Training Integer Quantization"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "CIGrZZPTZVeO"
-      },
-      "source": [
-        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
-        "  <td>\n",
-        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
-        "  </td>\n",
-        "  <td>\n",
-        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_integer_quant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
-        "  </td>\n",
-        "</table>"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "BTC1rDAuei_1"
-      },
-      "source": [
-        "## Overview\n",
-        "\n",
-        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
-        "converting an entire model (weights and activations) to 8-bit during model conversion from TensorFlow to TensorFlow Lite's flat buffer format. This results in a 4x reduction in model size and a 3 to 4x performance improvement on CPU performance. In addition, this fully quantized model can be consumed by integer-only hardware accelerators.\n",
-        "\n",
-        "In contrast to [post-training \"on-the-fly\" quantization](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb)\n",
-        ", which only stores weights as 8-bit ints, in this technique all weights *and* activations are quantized statically during model conversion.\n",
-        "\n",
-        "In this tutorial, we train an MNIST model from scratch, check its accuracy in TensorFlow, and then convert the saved model into a Tensorflow Lite flatbuffer\n",
-        "with full quantization. We finally check the\n",
-        "accuracy of the converted model and compare it to the original saved model. We\n",
-        "run the training script [mnist.py](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py) from\n",
-        "[Tensorflow official MNIST tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "2XsEP17Zelz9"
-      },
-      "source": [
-        "## Building an MNIST model"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "dDqqUIZjZjac"
-      },
-      "source": [
-        "### Setup"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "gyqAw1M9lyab",
-        "colab": {}
-      },
-      "source": [
-        "! pip uninstall -y tensorflow\n",
-        "! pip install -U tf-nightly"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "WsN6s5L1ieNl",
-        "colab": {}
-      },
-      "source": [
-        "import tensorflow as tf\n",
-        "tf.enable_eager_execution()"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "00U0taBoe-w7",
-        "colab": {}
-      },
-      "source": [
-        "! git clone --depth 1 https://github.com/tensorflow/models"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "4XZPtSh-fUOc",
-        "colab": {}
-      },
-      "source": [
-        "import sys\n",
-        "import os\n",
-        "\n",
-        "if sys.version_info.major >= 3:\n",
-        "    import pathlib\n",
-        "else:\n",
-        "    import pathlib2 as pathlib\n",
-        "\n",
-        "# Add `models` to the python path.\n",
-        "models_path = os.path.join(os.getcwd(), \"models\")\n",
-        "sys.path.append(models_path)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "eQ6Q0qqKZogR"
-      },
-      "source": [
-        "### Train and export the model"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "eMsw_6HujaqM",
-        "colab": {}
-      },
-      "source": [
-        "saved_models_root = \"/tmp/mnist_saved_model\""
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "hWSAjQWagIHl",
-        "colab": {}
-      },
-      "source": [
-        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
-        "# Note: channels_last is required here or the conversion may fail. \n",
-        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "5NMaNZQCkW9X"
-      },
-      "source": [
-        "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n",
-        "\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "xl8_fzVAZwOh"
-      },
-      "source": [
-        "### Convert to a TensorFlow Lite model\n",
-        "\n",
-        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "Xp5oClaZkbtn",
-        "colab": {}
-      },
-      "source": [
-        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
-        "saved_model_dir"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "AT8BgkKmljOy"
-      },
-      "source": [
-        "Using the [Python `TFLiteConverter`](https://www.tensorflow.org/lite/convert/python_api), the saved model can be converted into a TensorFlow Lite model.\n",
-        "\n",
-        "First load the model using the `TFLiteConverter`:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "_i8B2nDZmAgQ",
-        "colab": {}
-      },
-      "source": [
-        "import tensorflow as tf\n",
-        "tf.enable_eager_execution()\n",
-        "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
-        "\n",
-        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
-        "tflite_model = converter.convert()"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "F2o2ZfF0aiCx"
-      },
-      "source": [
-        "Write it out to a `.tflite` file:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "vptWZq2xnclo",
-        "colab": {}
-      },
-      "source": [
-        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
-        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "Ie9pQaQrn5ue",
-        "colab": {}
-      },
-      "source": [
-        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
-        "tflite_model_file.write_bytes(tflite_model)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "7BONhYtYocQY"
-      },
-      "source": [
-        "To instead quantize the model on export, first set the `optimizations` flag to optimize for size:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "HEZ6ET1AHAS3",
-        "colab": {}
-      },
-      "source": [
-        "tf.logging.set_verbosity(tf.logging.INFO)\n",
-        "converter.optimizations = [tf.lite.Optimize.DEFAULT]"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "id": "rTe8avZJHMDO",
-        "colab_type": "text"
-      },
-      "source": [
-        "Now, construct and provide a representative dataset, this is used to get the dynamic range of activations."
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "id": "FiwiWU3gHdkW",
-        "colab_type": "code",
-        "colab": {}
-      },
-      "source": [
-        "mnist_train, _ = tf.keras.datasets.mnist.load_data()\n",
-        "images = tf.cast(mnist_train[0], tf.float32)/255.0\n",
-        "mnist_ds = tf.data.Dataset.from_tensor_slices((images)).batch(1)\n",
-        "def representative_data_gen():\n",
-        "  for input_value in mnist_ds.take(100):\n",
-        "    yield [input_value]\n",
-        "\n",
-        "converter.representative_dataset = representative_data_gen"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "id": "xW84iMYjHd9t",
-        "colab_type": "text"
-      },
-      "source": [
-        "Finally, convert the model like usual. Note, by default the converted model will still use float input and outputs for invocation convenience."
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "id": "yuNfl3CoHNK3",
-        "colab_type": "code",
-        "colab": {}
-      },
-      "source": [
-        "tflite_quant_model = converter.convert()\n",
-        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
-        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "PhMmUTl4sbkz"
-      },
-      "source": [
-        "Note how the resulting file is approximately `1/4` the size."
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "JExfcfLDscu4",
-        "colab": {}
-      },
-      "source": [
-        "!ls -lh {tflite_models_dir}"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "L8lQHMp_asCq"
-      },
-      "source": [
-        "## Run the TensorFlow Lite models"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "-5l6-ciItvX6"
-      },
-      "source": [
-        "We can run the TensorFlow Lite model using the Python TensorFlow Lite\n",
-        "Interpreter. \n",
-        "\n",
-        "### Load the test data\n",
-        "\n",
-        "First, let's load the MNIST test data to feed to the model:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "eTIuU07NuKFL",
-        "colab": {}
-      },
-      "source": [
-        "import numpy as np\n",
-        "_, mnist_test = tf.keras.datasets.mnist.load_data()\n",
-        "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
-        "\n",
-        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "Ap_jE7QRvhPf"
-      },
-      "source": [
-        "### Load the model into the interpreters"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "Jn16Rc23zTss",
-        "colab": {}
-      },
-      "source": [
-        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
-        "interpreter.allocate_tensors()"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "J8Pztk1mvNVL",
-        "colab": {}
-      },
-      "source": [
-        "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))\n",
-        "interpreter_quant.allocate_tensors()"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "2opUt_JTdyEu"
-      },
-      "source": [
-        "### Test the models on one image"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "AKslvo2kwWac",
-        "colab": {}
-      },
-      "source": [
-        "for img, label in mnist_ds:\n",
-        "  break\n",
-        "\n",
-        "interpreter.set_tensor(interpreter.get_input_details()[0][\"index\"], img)\n",
-        "interpreter.invoke()\n",
-        "predictions = interpreter.get_tensor(\n",
-        "    interpreter.get_output_details()[0][\"index\"])"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "XZClM2vo3_bm",
-        "colab": {}
-      },
-      "source": [
-        "import matplotlib.pylab as plt\n",
-        "\n",
-        "plt.imshow(img[0])\n",
-        "template = \"True:{true}, predicted:{predict}\"\n",
-        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
-        "                              predict=str(predictions[0])))\n",
-        "plt.grid(False)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "3gwhv4lKbYZ4",
-        "colab": {}
-      },
-      "source": [
-        "interpreter_quant.set_tensor(\n",
-        "    interpreter_quant.get_input_details()[0][\"index\"], img)\n",
-        "interpreter_quant.invoke()\n",
-        "predictions = interpreter_quant.get_tensor(\n",
-        "    interpreter_quant.get_output_details()[0][\"index\"])"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "CIH7G_MwbY2x",
-        "colab": {}
-      },
-      "source": [
-        "plt.imshow(img[0])\n",
-        "template = \"True:{true}, predicted:{predict}\"\n",
-        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
-        "                              predict=str(predictions[0])))\n",
-        "plt.grid(False)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "LwN7uIdCd8Gw"
-      },
-      "source": [
-        "### Evaluate the models"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "05aeAuWjvjPx",
-        "colab": {}
-      },
-      "source": [
-        "def eval_model(interpreter, mnist_ds):\n",
-        "  total_seen = 0\n",
-        "  num_correct = 0\n",
-        "\n",
-        "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
-        "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
-        "  for img, label in mnist_ds:\n",
-        "    total_seen += 1\n",
-        "    interpreter.set_tensor(input_index, img)\n",
-        "    interpreter.invoke()\n",
-        "    predictions = interpreter.get_tensor(output_index)\n",
-        "    if predictions == label.numpy():\n",
-        "      num_correct += 1\n",
-        "\n",
-        "    if total_seen % 500 == 0:\n",
-        "      print(\"Accuracy after %i images: %f\" %\n",
-        "            (total_seen, float(num_correct) / float(total_seen)))\n",
-        "\n",
-        "  return float(num_correct) / float(total_seen)"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "T5mWkSbMcU5z",
-        "colab": {}
-      },
-      "source": [
-        "print(eval_model(interpreter, mnist_ds))"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "Km3cY9ry8ZlG"
-      },
-      "source": [
-        "We can repeat the evaluation on the fully quantized model to obtain:\n"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "colab_type": "code",
-        "id": "-9cnwiPp6EGm",
-        "colab": {}
-      },
-      "source": [
-        "# NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite\n",
-        "# doesn't have super optimized server CPU kernels. For this reason this may be\n",
-        "# slower than the above float interpreter. But for mobile CPUs, considerable\n",
-        "# speedup can be observed.\n",
-        "print(eval_model(interpreter_quant, mnist_ds))\n"
-      ],
-      "execution_count": 0,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "L7lfxkor8pgv"
-      },
-      "source": [
-        "In this example, we have fully quantized a model with no difference in the accuracy."
-      ]
-    }
-  ]
-}
diff --git a/tensorflow/lite/tutorials/post_training_quant.ipynb b/tensorflow/lite/tutorials/post_training_quant.ipynb
deleted file mode 100644
index 8bc02ee..0000000
--- a/tensorflow/lite/tutorials/post_training_quant.ipynb
+++ /dev/null
@@ -1,703 +0,0 @@
-{
-  "cells": [
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "6Y8E0lw5eYWm"
-      },
-      "source": [
-        "# Post Training Quantization"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "CIGrZZPTZVeO"
-      },
-      "source": [
-        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-        "  \u003c/td\u003e\n",
-        "\u003c/table\u003e"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "BTC1rDAuei_1"
-      },
-      "source": [
-        "## Overview\n",
-        "\n",
-        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
-        "converting weights to 8 bit precision as part of model conversion from\n",
-        "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
-        "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
-        "fly quantization and dequantization of activations to allow for:\n",
-        "\n",
-        "1.  Using quantized kernels for faster implementation when available.\n",
-        "\n",
-        "2.  Mixing of floating-point kernels with quantized kernels for different parts\n",
-        "    of the graph.\n",
-        "\n",
-        "Note that the activations are always stored in floating point. For ops that\n",
-        "support quantized kernels, the activations are quantized to 8 bits of precision\n",
-        "dynamically prior to processing and are de-quantized to float precision after\n",
-        "processing. Depending on the model being converted, this can give a speedup over\n",
-        "pure floating point computation.\n",
-        "\n",
-        "In contrast to\n",
-        "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)\n",
-        ", the weights are quantized post training and the activations are quantized dynamically \n",
-        "at inference in this method.\n",
-        "Therefore, the model weights are not retrained to compensate for quantization\n",
-        "induced errors. It is important to check the accuracy of the quantized model to\n",
-        "ensure that the degradation is acceptable.\n",
-        "\n",
-        "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n",
-        "tensorflow and then convert the saved model into a Tensorflow Lite flatbuffer\n",
-        "with weight quantization. We finally check the\n",
-        "accuracy of the converted model and compare it to the original saved model. We\n",
-        "run the training script mnist.py from\n",
-        "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "2XsEP17Zelz9"
-      },
-      "source": [
-        "## Building an MNIST model"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "dDqqUIZjZjac"
-      },
-      "source": [
-        "### Setup"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "gyqAw1M9lyab"
-      },
-      "outputs": [],
-      "source": [
-        "! pip uninstall -y tensorflow\n",
-        "! pip install -U tf-nightly"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "WsN6s5L1ieNl"
-      },
-      "outputs": [],
-      "source": [
-        "import tensorflow as tf\n",
-        "tf.enable_eager_execution()"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "00U0taBoe-w7"
-      },
-      "outputs": [],
-      "source": [
-        "! git clone --depth 1 https://github.com/tensorflow/models"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "4XZPtSh-fUOc"
-      },
-      "outputs": [],
-      "source": [
-        "import sys\n",
-        "import os\n",
-        "\n",
-        "if sys.version_info.major \u003e= 3:\n",
-        "    import pathlib\n",
-        "else:\n",
-        "    import pathlib2 as pathlib\n",
-        "\n",
-        "# Add `models` to the python path.\n",
-        "models_path = os.path.join(os.getcwd(), \"models\")\n",
-        "sys.path.append(models_path)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "eQ6Q0qqKZogR"
-      },
-      "source": [
-        "### Train and export the model"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "eMsw_6HujaqM"
-      },
-      "outputs": [],
-      "source": [
-        "saved_models_root = \"/tmp/mnist_saved_model\""
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "hWSAjQWagIHl"
-      },
-      "outputs": [],
-      "source": [
-        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
-        "# Note: channels_last is required here or the conversion may fail. \n",
-        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "5NMaNZQCkW9X"
-      },
-      "source": [
-        "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n",
-        "\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "xl8_fzVAZwOh"
-      },
-      "source": [
-        "### Convert to a TFLite model\n",
-        "\n",
-        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "Xp5oClaZkbtn"
-      },
-      "outputs": [],
-      "source": [
-        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
-        "saved_model_dir"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "AT8BgkKmljOy"
-      },
-      "source": [
-        "Using the python `TFLiteConverter`, the saved model can be converted into a TFLite model.\n",
-        "\n",
-        "First load the model using the `TFLiteConverter`:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "_i8B2nDZmAgQ"
-      },
-      "outputs": [],
-      "source": [
-        "import tensorflow as tf\n",
-        "tf.enable_eager_execution()\n",
-        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
-        "tflite_model = converter.convert()"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "F2o2ZfF0aiCx"
-      },
-      "source": [
-        "Write it out to a tflite file:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "vptWZq2xnclo"
-      },
-      "outputs": [],
-      "source": [
-        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
-        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "Ie9pQaQrn5ue"
-      },
-      "outputs": [],
-      "source": [
-        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
-        "tflite_model_file.write_bytes(tflite_model)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "7BONhYtYocQY"
-      },
-      "source": [
-        "To quantize the model on export, set the `optimizations` flag to optimize for size:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "g8PUvLWDlmmz"
-      },
-      "outputs": [],
-     "source": [
-        "# Note: If you don't have a recent tf-nightly installed, the\n",
-        "# \"optimizations\" line will have no effect.\n",
-        "tf.logging.set_verbosity(tf.logging.INFO)\n",
-        "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n",
-        "tflite_quant_model = converter.convert()\n",
-        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
-        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "PhMmUTl4sbkz"
-      },
-    "source": [
-        "Note how the resulting file, is approximately `1/4` the size."
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "JExfcfLDscu4"
-      },
-      "outputs": [],
-      "source": [
-        "!ls -lh {tflite_models_dir}"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "L8lQHMp_asCq"
-      },
-      "source": [
-        "## Run the TFLite models"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "-5l6-ciItvX6"
-      },
-      "source": [
-        "We can run the TensorFlow Lite model using the python TensorFlow Lite\n",
-        "Interpreter. \n",
-        "\n",
-        "### load the test data\n",
-        "\n",
-        "First let's load the mnist test data to feed to it:"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "eTIuU07NuKFL"
-      },
-      "outputs": [],
-      "source": [
-        "import numpy as np\n",
-        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
-        "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
-        "\n",
-        "# Note: If you change the batch size, then use \n",
-        "# `tf.lite.Interpreter.resize_tensor_input` to also change it for\n",
-        "# the interpreter.\n",
-        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "Ap_jE7QRvhPf"
-      },
-      "source": [
-        "### Load the model into an interpreter"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "Jn16Rc23zTss"
-      },
-      "outputs": [],
-      "source": [
-        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
-        "interpreter.allocate_tensors()\n",
-        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
-        "output_index = interpreter.get_output_details()[0][\"index\"]"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "J8Pztk1mvNVL"
-      },
-      "outputs": [],
-      "source": [
-        "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
-        "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "Afl6yGvWyqAr"
-      },
-      "outputs": [],
-      "source": [
-        "interpreter_quant.allocate_tensors()\n",
-        "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n",
-        "output_index = interpreter_quant.get_output_details()[0][\"index\"]\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "2opUt_JTdyEu"
-      },
-      "source": [
-        "### Test the model on one image"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "AKslvo2kwWac"
-      },
-      "outputs": [],
-      "source": [
-        "for img, label in mnist_ds.take(1):\n",
-        "  break\n",
-        "\n",
-        "interpreter.set_tensor(input_index, img)\n",
-        "interpreter.invoke()\n",
-        "predictions = interpreter.get_tensor(output_index)"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "XZClM2vo3_bm"
-      },
-      "outputs": [],
-      "source": [
-        "import matplotlib.pylab as plt\n",
-        "\n",
-        "plt.imshow(img[0])\n",
-        "template = \"True:{true}, predicted:{predict}\"\n",
-        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
-        "                              predict=str(predictions[0])))\n",
-        "plt.grid(False)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "LwN7uIdCd8Gw"
-      },
-      "source": [
-        "### Evaluate the models"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "05aeAuWjvjPx"
-      },
-      "outputs": [],
-      "source": [
-        "def eval_model(interpreter, mnist_ds):\n",
-        "  total_seen = 0\n",
-        "  num_correct = 0\n",
-        "\n",
-        "  for img, label in mnist_ds:\n",
-        "    total_seen += 1\n",
-        "    interpreter.set_tensor(input_index, img)\n",
-        "    interpreter.invoke()\n",
-        "    predictions = interpreter.get_tensor(output_index)\n",
-        "    if predictions == label.numpy():\n",
-        "      num_correct += 1\n",
-        "\n",
-        "    if total_seen % 500 == 0:\n",
-        "        print(\"Accuracy after %i images: %f\" %\n",
-        "              (total_seen, float(num_correct) / float(total_seen)))\n",
-        "\n",
-        "  return float(num_correct) / float(total_seen)"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "DqXBnDfJ7qxL"
-      },
-      "outputs": [],
-      "source": [
-        "print(eval_model(interpreter, mnist_ds))"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "Km3cY9ry8ZlG"
-      },
-      "source": [
-        "We can repeat the evaluation on the weight quantized model to obtain:\n"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "-9cnwiPp6EGm"
-      },
-      "outputs": [],
-      "source": [
-        "print(eval_model(interpreter_quant, mnist_ds))\n"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "L7lfxkor8pgv"
-      },
-      "source": [
-        "\n",
-        "In this example, we have compressed model with no difference in the accuracy."
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "M0o1FtmWeKZm"
-      },
-      "source": [
-        "\n",
-        "\n",
-        "## Optimizing an existing model\n",
-        "\n",
-        "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
-        "  Pre-trained frozen graph for resnet-v2-101 is available at the\n",
-        "  [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md).\n",
-        "\n",
-        "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "v5p5VcNPjILQ"
-      },
-      "outputs": [],
-      "source": [
-        "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
-        "archive_path = pathlib.Path(archive_path)\n",
-        "archive_dir = str(archive_path.parent)"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "-sxnXQuC4ThD"
-      },
-      "source": [
-        "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph."
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "g_Q_OMEJ4LIc"
-      },
-      "outputs": [],
-      "source": [
-        "! cat {archive_dir}/resnet_v2_101_299_info.txt"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "ujCAFhqm-C6H"
-      },
-      "outputs": [],
-      "source": [
-        "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
-        "input_arrays = [\"input\"] \n",
-        "output_arrays = [\"output\"]\n",
-        "converter = tf.lite.TFLiteConverter.from_frozen_graph(\n",
-        "  str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
-        "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n",
-        "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
-        "resnet_tflite_file.write_bytes(converter.convert())\n"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 0,
-      "metadata": {
-        "colab": {},
-        "colab_type": "code",
-        "id": "vhOjeg1x9Knp"
-      },
-      "outputs": [],
-      "source": [
-        "\n",
-        "!ls -lh {archive_dir}/*.tflite"
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "qqHLaqFMCjRZ"
-      },
-      "source": [
-        "\n",
-        "The model size reduces from 171 MB to 43 MB.\n",
-        "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/accuracy/ilsvrc).\n",
-        "\n",
-        "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
-      ]
-    }
-  ],
-  "metadata": {
-    "colab": {
-      "collapsed_sections": [],
-      "name": "post-training-quant.ipynb",
-      "private_outputs": true,
-      "provenance": [],
-      "toc_visible": true,
-      "version": "0.3.2"
-    },
-    "kernelspec": {
-      "display_name": "Python 2",
-      "name": "python2"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 0
-}
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index ccf39fe..9e0f62a 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -1,3 +1,4 @@
+llvm/llvm/projects/google_mlir/WORKSPACE
 tensorflow/contrib/mpi/BUILD
 tensorflow/stream_executor/build_defs.bzl
 tensorflow/python/autograph/core/config.py
@@ -71,9 +72,13 @@
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
+tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
+tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
+tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_config.bzl
+tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
 tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
 tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
 tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
@@ -185,11 +190,6 @@
 tensorflow/third_party/clang_toolchain/download_clang.bzl
 tensorflow/third_party/clang_toolchain/BUILD
 tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
-tensorflow/third_party/mlir/BUILD
-tensorflow/third_party/mlir/mlir_configure.bzl
-tensorflow/third_party/mlir/bindings/python/BUILD
-tensorflow/third_party/mlir/test/BUILD
-tensorflow/third_party/mlir/tblgen.bzl
 tensorflow/third_party/gast.BUILD
 tensorflow/third_party/llvm/BUILD
 tensorflow/third_party/llvm/expand_cmake_vars.py
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d5710ee..8946522 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -12,19 +12,19 @@
     "//tensorflow_models:__subpackages__",
     "//tensorflow_model_optimization:__subpackages__",
     "//third_party/py/cleverhans:__subpackages__",
-    "//third_party/py/neural_structured_learning/keras:__pkg__",
+    "//third_party/py/neural_structured_learning:__subpackages__",
     "//third_party/py/tensorflow_examples:__subpackages__",
     "//third_party/py/tf_slim:__subpackages__",
     # TODO(aselle): to pass open source test.
     "//bazel_pip/tensorflow/lite/toco/python:__pkg__",
 ]
 
-load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "tf_cuda_library", "tf_gen_op_wrapper_py", "py_test", "tf_py_test", "py_tests", "tf_py_build_info_genrule", "tf_cc_shared_object")
+load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
-load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library", "tf_proto_library", "tf_proto_library_py", "tf_additional_lib_deps", "tf_additional_all_protos", "tf_protos_grappler", "tf_additional_cupti_test_flags")  # @unused
-load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps", "tf_additional_verbs_deps", "tf_additional_mpi_deps", "tf_additional_gdr_deps", "if_static")
+load("//tensorflow/core/platform:default/build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler")  # @unused
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_mpi_deps", "tf_additional_plugin_deps", "tf_additional_verbs_deps")
 load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
 load(
     "//third_party/ngraph:build_defs.bzl",
@@ -2488,6 +2488,22 @@
     ],
 )
 
+cuda_py_test(
+    name = "collective_ops_gpu_test",
+    size = "small",
+    srcs = ["ops/collective_ops_gpu_test.py"],
+    additional_deps = [
+        ":client_testlib",
+        ":collective_ops",
+        ":framework_for_generated_wrappers",
+        "//third_party/py/numpy",
+    ],
+    tags = [
+        "no_cuda_on_cpu_tap",
+        "no_windows",
+    ],
+)
+
 py_library(
     name = "control_flow_grad",
     srcs =
@@ -2546,11 +2562,23 @@
     srcs_version = "PY2AND3",
     deps = [
         ":control_flow_util",
+        ":control_flow_v2_func_graphs",
         ":framework_ops",
         ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/keras:base_layer_utils",
+    ],
+)
+
+py_library(
+    name = "control_flow_v2_func_graphs",
+    srcs = ["ops/control_flow_v2_func_graphs.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":func_graph",
     ],
 )
 
@@ -2560,11 +2588,56 @@
     srcs_version = "PY2AND3",
     deps = [
         ":control_flow_util",
+        ":control_flow_util_v2",
         ":framework_ops",
         ":util",
     ],
 )
 
+tf_py_test(
+    name = "control_flow_v2_toggles_test",
+    size = "small",
+    srcs = ["ops/control_flow_v2_toggles_test.py"],
+    additional_deps = [
+        ":control_flow_v2_toggles",
+        ":control_flow_util_v2",
+        ":client_testlib",
+        ":platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "control_flow_v2_enable_test",
+    size = "small",
+    srcs = ["ops/control_flow_v2_enable_test.py"],
+    additional_deps = [
+        ":tf2",
+        ":control_flow_util",
+        ":client_testlib",
+        ":platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "control_flow_v2_disable_test",
+    size = "small",
+    srcs = ["ops/control_flow_v2_disable_test.py"],
+    additional_deps = [
+        ":tf2",
+        ":control_flow_util",
+        ":client_testlib",
+        ":platform_test",
+    ],
+    # This tests that it is possible to disable cfv2 using env vars.
+    # This does not apply to TF 2.0 nightly builds which enable
+    # v2 behavior using `tf.compat.v1.enable_v2_behavior()` in which case
+    # `tf.compat.v1.disable_control_flow_v2()` needs to be used.
+    tags = [
+        "no_oss",
+        "no_pip",
+    ],
+)
+
 py_library(
     name = "cond_v2",
     srcs = [
@@ -3882,6 +3955,7 @@
         ":array_ops",
         ":cond_v2",
         ":control_flow_ops",
+        ":control_flow_v2_toggles",
         ":embedding_ops",
         ":framework_for_generated_wrappers",
         ":framework_test_lib",
@@ -4473,6 +4547,12 @@
     ] + if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass"]),
 )
 
+tf_py_test(
+    name = "object_identity_test",
+    size = "small",
+    srcs = ["util/object_identity_test.py"],
+)
+
 # Placeholder for intenal nest_test comments.
 tf_py_test(
     name = "util_nest_test",
@@ -4803,6 +4883,7 @@
         "util/py_checkpoint_reader.i",
         "util/stat_summarizer.i",
         "util/tfprof.i",
+        "util/traceme.i",
         "util/transform_graph.i",
         "util/util.i",
         "//tensorflow/lite/toco/python:toco.i",
@@ -4851,6 +4932,7 @@
         "//tensorflow/core/debug",
         "//tensorflow/core/distributed_runtime:server_lib",
         "//tensorflow/core/profiler/internal:print_model_analysis",
+        "//tensorflow/core/profiler/internal:python_traceme",
         "//tensorflow/tools/graph_transforms:transform_graph_lib",
         "//tensorflow/lite/toco/python:toco_python_api",
         "//tensorflow/python/eager:pywrap_tfe_lib",
@@ -4893,6 +4975,22 @@
     output_group = "interface_library",
 )
 
+cc_import(
+    name = "_pywrap_tensorflow_internal_linux",
+    shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.so",
+)
+
+cc_import(
+    name = "_pywrap_tensorflow_internal_macos",
+    shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.dylib",
+)
+
+cc_import(
+    name = "_pywrap_tensorflow_internal_windows",
+    interface_library = "//tensorflow/python:pywrap_tensorflow_import_lib_file",
+    shared_library = "//tensorflow/python:_pywrap_tensorflow_internal.dll",
+)
+
 # Rename the import library for _pywrap_tensorflow_internal.pyd to _pywrap_tensorflow_internal.lib
 # (It was _pywrap_tensorflow_internal.so.if.lib).
 genrule(
@@ -5114,7 +5212,6 @@
     grpc_enabled = True,
     tags = [
         "no_oss",  # Test flaky due to port collisions.
-        "nofwdcompat",  # b/137641346
         "notsan",  # data race due to b/62910646
         "oss_serial",
     ],
@@ -5402,7 +5499,10 @@
         ":variable_scope",
         ":variables",
     ],
-    tags = ["notsan"],
+    tags = [
+        "no_windows",  # b/139083295: bfloat16 tests fail on Windows
+        "notsan",
+    ],
     xla_enable_strict_auto_jit = True,
 )
 
@@ -6133,6 +6233,22 @@
 )
 
 cuda_py_test(
+    name = "collective_ops_benchmark",
+    srcs = ["ops/collective_ops_benchmark.py"],
+    additional_deps = [
+        ":array_ops",
+        ":client",
+        ":client_testlib",
+        ":collective_ops",
+        ":framework_for_generated_wrappers",
+        ":platform",
+        ":variables",
+        "//tensorflow/core:protos_all_py",
+    ],
+    main = "ops/collective_ops_benchmark.py",
+)
+
+cuda_py_test(
     name = "concat_benchmark",
     srcs = ["ops/concat_benchmark.py"],
     additional_deps = [
@@ -6349,6 +6465,7 @@
     additional_deps = [
         ":array_ops",
         ":client_testlib",
+        ":framework_combinations",
         ":framework_for_generated_wrappers",
         ":tf_item",
         "//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 4975568..4e5477d 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -142,6 +142,13 @@
 from tensorflow.python.eager.def_function import function
 from tensorflow.python.framework.ops import enable_eager_execution
 
+# Check whether TF2_BEHAVIOR is turned on.
+from tensorflow.python.eager import monitoring as _monitoring
+from tensorflow.python import tf2 as _tf2
+_tf2_gauge = _monitoring.BoolGauge('/tensorflow/api/tf2_enable',
+                                   'Environment variable TF2_BEHAVIOR is set".')
+_tf2_gauge.get_cell().set(_tf2.enabled())
+
 # Necessary for the symbols in this module to be taken into account by
 # the namespace management system (API decorators).
 from tensorflow.python.ops import rnn
diff --git a/tensorflow/python/autograph/LIMITATIONS.md b/tensorflow/python/autograph/LIMITATIONS.md
deleted file mode 100644
index b4e4ca6..0000000
--- a/tensorflow/python/autograph/LIMITATIONS.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Capabilities and Limitations
-
-TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`.
-
-Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support.
-
-# Python Language Support Status
-
-Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved.
-
-Construct                   | Supported now? | Plan to support? | Notes
-:-------------------------- | :------------: | :--------------: | :----
-If statement                | Yes            |                  | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error.
-For statement               | Yes            |                  | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations.
-While statement             | Yes            |                  | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations.
-Continue and break          | Yes            |                  | Converts to boolean flags and extra predicates in loop tests.
-Composition of control flow | Yes            |                  | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested.
-Iterators                   | Some           | Yes              | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`.
-Multiple return values      | Yes            |                  | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so.
-Print expression            | Yes            |                  | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists.
-Static function calls       | Yes            |                  | Non-recursive function calls
-Nested call trees           | Yes            |                  | For example, `f` calls `g` which calls `h`, all of which need conversion.
-Recursive function calls    | No             | Maybe            | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant.
-Python built-ins            | Some           | Yes              | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html).
-List operations             | Yes            |                  | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation.
-Function variables          | Yes            |                  | e.g. `f_new = f_orig; f_new()`
-Lambda functions            | No             | Yes              | Planned feature.
-Classes                     | Yes            |                  | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods.
-Subclasses                  | Yes            |                  | Subclassing library objects like tf.keras.Model is also supported.
-Dynamic types               | Some           |                  | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case.
-Dynamic code / exec         | No             |                  |
-Reflection                  | No             |                  |
-Try / Except                | No             | No               | No current sane TF equivalent.
-Global variables            | Restricted     |                  | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code.
-Functions with side effects | Some           |                  | Side effects are allowed, under certain circumstances.
-Collections                 | Some           | Yes              | We currently support lists. There are currently no TF equivalents of dictionaries or tuples.
-List Comprehensions         | Yes            |                  | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority.
-Custom context managers     | No             | Yes              | Currently low priority. Left unconverted currently.
-Generators                  | No             | Maybe            | Could be achievable using queues; very low priority.
-Assertions                  | Yes            |                  | As `tf.Assert`
-Deletion                    | Yes            | Maybe            | Currently unconverted. If new semantics are required for `del`, we are able to add it in.
-Inline imports              | No             | Yes              | For example, `import numpy as np; np.eye(3)`. Currently low priority.
-Async                       | No             | No               |
-
-## Extra capabilities
-
- - We liberally add name scopes to generated functions
- - Operations get decent default names everywhere (planned)
- - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially.
-
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
deleted file mode 100644
index bfe21b4..0000000
--- a/tensorflow/python/autograph/README.md
+++ /dev/null
@@ -1,143 +0,0 @@
-# AutoGraph
-
-IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
-
-AutoGraph is a Python to TensorFlow compiler.
-
-With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops.  [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
-
-For example, this Python function:
-
-```
-def f(x):
-  if x < 0:
-    x = -x
-  return x
-```
-
-would be converted to this:
-
-```
-def graph_mode_f(x):
-  with tf.name_scope('f'):
-
-    def if_true():
-      with tf.name_scope('if_true'):
-        x_1, = x,
-        x_1 = tf.negative(x_1)
-        return x_1,
-
-    def if_false():
-      with tf.name_scope('if_false'):
-        x_1, = x,
-        return x_1,
-    x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
-    return x
-```
-
-so you can use it like an op:
-
-```
-with tf.Graph().as_default():
-  x = tf.constant(-1.0)
-
-  converted_f = autograph.to_graph(f)
-  y = converted_f(x)
-
-  with tf.Session() as sess:
-    print(sess.run(y))
-    # Output: 1
-```
-
-# Getting started
-
-Use AutoGraph in one of the following ways, described below:
-
- 1. Annotations (simpler)
- 2. Functional API (more flexible)
-
-To get started, install the latest nightly TensorFlow build:
-
-```shell
-pip install -U tf-nightly
-```
-
-Then import the `autograph` module from `tf.contrib`:
-
-```
-from tensorflow.python import autograph as ag
-```
-
-### Related links
-
-Articles:
-
- * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
-
-Interactive notebooks:
-
- * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
- * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
- * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
- * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
- * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
- * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
- * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
-
-## Using with annotations
-
-Annotating a function or class with `@convert` converts it in place:
-
-```
-@ag.convert()
-def f(x):
-  if x < 0:
-    x = -x
-  return x
-```
-
-... so that it always outputs TensorFlow code:
-
-```
-with tf.Graph().as_default():
-  x = tf.constant(-1)
-
-  y = f(x)
-
-  with tf.Session() as sess:
-    print(sess.run(y))
-    # Output: 1
-```
-
-## Using the functional API
-
-The functional API allows you to convert an existing function, class or object after it was defined:
-
-```
-converted_f = ag.to_graph(f)
-
-print(converted_f(tf.constant(-1)))
-# Output: Tensor
-
-print(f(-1))
-# Output: 1
-```
-
-You can use the functional API to inspect the generated code as well:
-
-```
-print(ag.to_code(f))
-# Output: <Python and TensorFlow code>
-```
-
-## Filing bugs and feature requests
-
-### Reporting a bug
-
- - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
-
-### Requesting a feature
-
-If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
index 0d2d8a1..4132cd5 100644
--- a/tensorflow/python/autograph/__init__.py
+++ b/tensorflow/python/autograph/__init__.py
@@ -40,7 +40,6 @@
 from tensorflow.python.autograph.impl.api import convert
 from tensorflow.python.autograph.impl.api import converted_call
 from tensorflow.python.autograph.impl.api import do_not_convert
-from tensorflow.python.autograph.impl.api import RunMode
 from tensorflow.python.autograph.impl.api import StackTraceMapper
 from tensorflow.python.autograph.impl.api import to_code
 from tensorflow.python.autograph.impl.api import to_graph
@@ -56,7 +55,6 @@
     'AutoGraphError',
     'ConversionOptions',
     'Feature',
-    'RunMode',
     'StackTraceMapper',
     'convert',
     'converted_call',
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 7a6a77e..0f6189c 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -32,7 +32,6 @@
         "lists.py",
         "logical_expressions.py",
         "return_statements.py",
-        "side_effect_guards.py",
         "slices.py",
     ],
     srcs_version = "PY2AND3",
@@ -183,18 +182,6 @@
 )
 
 py_test(
-    name = "side_effect_guards_test",
-    srcs = ["side_effect_guards_test.py"],
-    srcs_version = "PY2AND3",
-    tags = ["notsan"],
-    deps = [
-        ":converters",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python/autograph/core:test_lib",
-    ],
-)
-
-py_test(
     name = "return_statements_test",
     srcs = ["return_statements_test.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index 061b63f..0302964 100644
--- a/tensorflow/python/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -19,31 +19,28 @@
 from __future__ import print_function
 
 from tensorflow.python.autograph.converters import asserts
-from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import function_scopes
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 
 
 class AssertsTest(converter_testing.TestCase):
 
-  @test_util.run_deprecated_v1
   def test_basic(self):
 
     def test_fn(a):
-      assert a, 'test message'
-      return tf.no_op()  # pylint:disable=undefined-variable
+      assert a, 'testmsg'
+      return a
 
-    with self.converted(test_fn, (asserts, side_effect_guards), {},
-                        (gen_control_flow_ops.no_op,)) as result:
-      with self.cached_session() as sess:
+    with ops.Graph().as_default():
+      with self.converted(test_fn, (function_scopes, asserts), {}) as result:
         op = result.test_fn(constant_op.constant(False))
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     'test message'):
-          self.evaluate(op)
+
+      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
+        self.evaluate(op)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 52e6af5..5a5a2c9 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -16,6 +16,8 @@
 
 Note: this transformer does not rename the top level object being converted;
 that is the caller's responsibility.
+
+Requires function_scopes.
 """
 
 from __future__ import absolute_import
@@ -29,6 +31,7 @@
 from tensorflow.python.autograph.pyct import ast_util
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.utils import ag_logging
 
 
 # TODO(mdan): Rename to FunctionCallsTransformer.
@@ -38,12 +41,37 @@
 
   no_root = True
 
+  def __init__(self):
+    self.context_name = None
+
+
+set_trace_warned = False
+
 
 class CallTreeTransformer(converter.Base):
   """Transforms the call tree by renaming transformed symbols."""
 
+  def visit_Lambda(self, node):
+    if anno.hasanno(node, 'function_context_name'):
+      # Lambda functions created during the conversion process have no
+      # context manager.
+      self.state[_Function].enter()
+      self.state[_Function].context_name = anno.getanno(
+          node, 'function_context_name')
+      node = self.generic_visit(node)
+      self.state[_Function].exit()
+    else:
+      node = self.generic_visit(node)
+    return node
+
   def visit_FunctionDef(self, node):
     self.state[_Function].enter()
+    # Note: if the conversion process ever creates helper functions, this
+    # assumption will no longer hold.
+    assert anno.hasanno(node, 'function_context_name'), (
+        'The function_scopes converter always creates a scope for functions.')
+    self.state[_Function].context_name = anno.getanno(
+        node, 'function_context_name')
     node.args = self.visit(node.args)
     node.body = self.visit_block(node.body)
 
@@ -72,6 +100,7 @@
 
   def visit_Call(self, node):
     full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
+    function_context_name = self.state[_Function].context_name
     node = self.generic_visit(node)
 
     # TODO(mdan): Refactor converted_call as a 'Call' operator.
@@ -81,11 +110,26 @@
     if full_name.startswith('ag__.'):
       return node
 
+    # Calls to the function context manager (inserted by function_scopes) are
+    # also safe.
+    if full_name.startswith(function_context_name + '.'):
+      return node
+
     # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
     # the normal mechanisms to bypass these literals because they are sensitive
     # to the frame they are being called from.
     # TODO(mdan): Generalize this to a "static whitelist" config.
-    if full_name in ('pdb.set_trace', 'ipdb.set_trace'):
+    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
+      global set_trace_warned
+      if not set_trace_warned:
+        # TODO(mdan): Update and shorten once available on tensorflow.org.
+        ag_logging.warn(
+            'Detected `pdb.set_trace()` in converted code. The code'
+            ' generated by AutoGraph is not optimized for step-by-step'
+            ' debugging. See https://github.com/tensorflow/tensorflow/'
+            'blob/master/tensorflow/python/autograph/g3doc/reference/'
+            'debugging.md.')
+        set_trace_warned = True
       return node
 
     if (full_name == 'print' and
@@ -130,15 +174,15 @@
           keywords=ast_util.keywords_to_dict(normal_keywords))
 
     template = """
-      ag__.converted_call(func, options, args, kwargs)
+      ag__.converted_call(func, options, args, kwargs, function_ctx)
     """
     new_call = templates.replace_as_expression(
         template,
         func=func,
-        options=self.ctx.program.options.to_ast(
-            internal_convert_user_code=self.ctx.program.options.recursive),
+        options=parser.parse_expression(function_context_name + '.callopts'),
         args=args,
-        kwargs=kwargs)
+        kwargs=kwargs,
+        function_ctx=function_context_name)
 
     return new_call
 
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index b77248b..6336d38 100644
--- a/tensorflow/python/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -21,6 +21,7 @@
 import imp
 
 from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import function_scopes
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.platform import test
 
@@ -32,7 +33,7 @@
     def test_fn(f):
       return f() + 20
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(result.test_fn(lambda: 1), 21)
       self.assertListEqual(self.dynamic_calls, [((), None)])
 
@@ -41,7 +42,7 @@
     def test_fn(f, g):
       return f(g() + 20) + 4000
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
       self.assertListEqual(self.dynamic_calls, [
           ((), None),
@@ -53,7 +54,7 @@
     def test_fn(f, g):
       return f(g()) + 300
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
       self.assertListEqual(self.dynamic_calls, [
           ((), None),
@@ -68,8 +69,8 @@
     def test_fn():
       return get_one().__add__(20)
 
-    with self.converted(test_fn, call_trees, {'get_one': get_one},
-                        ()) as result:
+    with self.converted(test_fn, (function_scopes, call_trees),
+                        {'get_one': get_one}, ()) as result:
 
       self.assertEqual(result.test_fn(), 21)
 
@@ -83,7 +84,7 @@
     def test_fn(f, a, b):
       return f(a, c=b) + 300
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
       self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
 
@@ -92,7 +93,7 @@
     def test_fn(f, a, *args, **kwargs):
       return f(a, *args, **kwargs) + 5
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(
           result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
               'b': 4,
@@ -109,7 +110,8 @@
       args = [1, 20, 300]
       return f(*args) + 4000
 
-    with self.converted(test_fn, call_trees, {'f': f}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees),
+                        {'f': f}) as result:
       self.assertEqual(result.test_fn(), 4321)
       self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
 
@@ -118,7 +120,7 @@
     def test_fn(f, a, b, **kwargs):
       return f(a, b=b, **kwargs) + 5
 
-    with self.converted(test_fn, call_trees, {}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
       self.assertEqual(
           result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
       self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
@@ -133,7 +135,8 @@
     def test_fn():
       return pdb.set_trace()
 
-    with self.converted(test_fn, call_trees, {'pdb': pdb}) as result:
+    with self.converted(test_fn, (function_scopes, call_trees),
+                        {'pdb': pdb}) as result:
       result.test_fn()
       self.assertListEqual(tracking_list, [1])
 
@@ -148,7 +151,8 @@
         return self.other_method(a) + 300
 
     tc = TestClass()
-    with self.converted(TestClass.test_method, call_trees, {}) as result:
+    with self.converted(TestClass.test_method, (function_scopes, call_trees),
+                        {}) as result:
       self.assertEqual(321, result.test_method(tc, 1))
       self.assertListEqual(self.dynamic_calls, [((1,), None)])
 
@@ -163,7 +167,8 @@
         return self.other_method(a) + 300
 
     tc = TestClass()
-    with self.converted(tc.test_method, call_trees, {}) as result:
+    with self.converted(tc.test_method, (function_scopes, call_trees),
+                        {}) as result:
       self.assertEqual(321, result.test_method(tc, 1))
       self.assertListEqual(self.dynamic_calls, [((1,), None)])
 
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 7f846ba..c4b0e14 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -77,12 +77,14 @@
           template, body_name=body_name, body=body, return_stmt=return_stmt)
 
   def _create_cond_expr(self, results, test, body_name, orelse_name,
-                        state_getter_name,
-                        state_setter_name):
+                        state_getter_name, state_setter_name,
+                        basic_symbol_names, composite_symbol_names):
     if results is not None:
       template = """
         results = ag__.if_stmt(test, body_name, orelse_name,
-                               state_getter_name, state_setter_name)
+                               state_getter_name, state_setter_name,
+                               (basic_symbol_names,),
+                               (composite_symbol_names,))
       """
       return templates.replace(
           template,
@@ -91,10 +93,13 @@
           body_name=body_name,
           orelse_name=orelse_name,
           state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name)
+          state_setter_name=state_setter_name,
+          basic_symbol_names=basic_symbol_names,
+          composite_symbol_names=composite_symbol_names)
     else:
       template = """
-        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
+        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
+                     (basic_symbol_names,), (composite_symbol_names,))
       """
       return templates.replace(
           template,
@@ -102,7 +107,9 @@
           body_name=body_name,
           orelse_name=orelse_name,
           getter_name=state_getter_name,
-          setter_name=state_setter_name)
+          setter_name=state_setter_name,
+          basic_symbol_names=basic_symbol_names,
+          composite_symbol_names=composite_symbol_names)
 
   def _fmt_symbols(self, symbol_set):
     if not symbol_set:
@@ -119,10 +126,12 @@
     # Composite symbols are handled elsewhere see _create_state_functions
     return {s for s in modified_live if not s.is_composite()}
 
-  def _create_state_functions(self, composites,
-                              state_getter_name, state_setter_name):
+  def _create_state_functions(self, composites, state_getter_name,
+                              state_setter_name):
+
     if composites:
       composite_tuple = tuple(composites)
+
       template = """
         def state_getter_name():
           return composite_tuple,
@@ -231,6 +240,8 @@
     state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)
 
     returned_from_cond = tuple(returned_from_cond)
+    composites = tuple(composites)
+
     if returned_from_cond:
       if len(returned_from_cond) == 1:
         cond_results = returned_from_cond[0]
@@ -275,9 +286,15 @@
     composite_defs = self._create_state_functions(
         composites, state_getter_name, state_setter_name)
 
+    basic_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in returned_from_cond)
+    composite_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in composites)
+
     cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                        orelse_name, state_getter_name,
-                                       state_setter_name)
+                                       state_setter_name, basic_symbol_names,
+                                       composite_symbol_names)
 
     if_ast = (
         undefined_assigns + composite_defs + body_def + orelse_def +
@@ -361,6 +378,11 @@
     state_functions = self._create_state_functions(
         composite_loop_vars, state_getter_name, state_setter_name)
 
+    basic_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in basic_loop_vars)
+    composite_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in composite_loop_vars)
+
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
     # of *loop_vars, then a single template could be used.
@@ -377,7 +399,9 @@
             body_name,
             state_getter_name,
             state_setter_name,
-            (loop_vars,))
+            (loop_vars,),
+            (basic_symbol_names,),
+            (composite_symbol_names,))
       """
       node = templates.replace(
           template,
@@ -389,7 +413,9 @@
           body=node.body,
           state_functions=state_functions,
           state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name)
+          state_setter_name=state_setter_name,
+          basic_symbol_names=basic_symbol_names,
+          composite_symbol_names=composite_symbol_names)
     else:
       template = """
         state_functions
@@ -403,7 +429,9 @@
             body_name,
             state_getter_name,
             state_setter_name,
-            ())
+            (),
+            (),
+            (composite_symbol_names,))
       """
       node = templates.replace(
           template,
@@ -413,7 +441,8 @@
           body=node.body,
           state_functions=state_functions,
           state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name)
+          state_setter_name=state_setter_name,
+          composite_symbol_names=composite_symbol_names)
 
     undefined_assigns = self._create_undefined_assigns(possibly_undefs)
     return undefined_assigns + node
@@ -466,6 +495,11 @@
 
     undefined_assigns = self._create_undefined_assigns(possibly_undefs)
 
+    basic_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in basic_loop_vars)
+    composite_symbol_names = tuple(
+        gast.Str(str(symbol)) for symbol in composite_loop_vars)
+
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
     # of *loop_vars, then a single template could be used.
@@ -484,7 +518,9 @@
             body_name,
             state_getter_name,
             state_setter_name,
-            (loop_vars,))
+            (loop_vars,),
+            (basic_symbol_names,),
+            (composite_symbol_names,))
       """
       return templates.replace(
           template,
@@ -500,7 +536,9 @@
           body=node.body,
           state_functions=state_functions,
           state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name)
+          state_setter_name=state_setter_name,
+          basic_symbol_names=basic_symbol_names,
+          composite_symbol_names=composite_symbol_names)
     else:
       template = """
         undefined_assigns
@@ -516,7 +554,9 @@
             body_name,
             state_getter_name,
             state_setter_name,
-            ())
+            (),
+            (),
+            (composite_symbol_names,))
       """
       return templates.replace(
           template,
@@ -530,7 +570,8 @@
           body=node.body,
           state_functions=state_functions,
           state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name)
+          state_setter_name=state_setter_name,
+          composite_symbol_names=composite_symbol_names)
 
 
 def transform(node, ctx):
diff --git a/tensorflow/python/autograph/converters/function_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py
index 284b5b3..52bd701 100644
--- a/tensorflow/python/autograph/converters/function_scopes.py
+++ b/tensorflow/python/autograph/converters/function_scopes.py
@@ -21,54 +21,98 @@
 import gast
 
 from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis import annos
+
+
+class _Function(object):
+
+  def __init__(self):
+    self.context_name = None
 
 
 class FunctionBodyTransformer(converter.Base):
   """Wraps function bodies around autograph-specific boilerplate."""
 
-  def _name_for_current_scope(self):
-    innermost = self.enclosing_entities[-1]
-    if len(self.enclosing_entities) > 1:
-      parent = self.enclosing_entities[-2]
-      if isinstance(parent, gast.ClassDef):
-        # Methods also take the name of their class.
-        name = '%s/%s' % (parent.name, innermost.name)
-      else:
-        name = innermost.name
-    else:
-      name = innermost.name
+  def visit_Return(self, node):
+    if node.value is None:
+      return node
+    return templates.replace(
+        'return function_context_name.mark_return_value(value)',
+        function_context_name=self.state[_Function].context_name,
+        value=node.value)
 
-    # Sanitize the name.
-    # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope
-    # TensorFlow doesn't like leading underscores at the top level.
-    while name[0] == '_':
-      name = name[1:]
-    return name
-
-  def visit_FunctionDef(self, node):
+  def visit_Lambda(self, node):
+    self.state[_Function].enter()
     node = self.generic_visit(node)
 
-    final_body = []
-    indented_body = node.body
-    if node.body:
-      first_statement = node.body[0]
-      # Skip the docstring, if any.
-      if (isinstance(first_statement, gast.Expr) and
-          isinstance(first_statement.value, gast.Str)):
-        indented_body = indented_body[1:]
-        final_body.append(first_statement)
+    # Only wrap the top-level function. Theoretically, we can and should wrap
+    # everything, but that can lead to excessive boilerplate when lambdas are
+    # nested.
+    # TODO(mdan): Looks more closely for use cases that actually require this.
+    if self.state[_Function].level > 2:
+      self.state[_Function].exit()
+      return node
+
+    scope = anno.getanno(node, anno.Static.SCOPE)
+    function_context_name = self.ctx.namer.new_symbol('lambda_scope',
+                                                      scope.referenced)
+    self.state[_Function].context_name = function_context_name
+    anno.setanno(node, 'function_context_name', function_context_name)
 
     template = """
-      with ag__.function_scope(scope_name):
+      ag__.with_function_scope(
+          lambda function_context: body, function_context_name, options)
+    """
+    node.body = templates.replace_as_expression(
+        template,
+        options=self.ctx.program.options.to_ast(),
+        function_context=function_context_name,
+        function_context_name=gast.Str(function_context_name),
+        body=node.body)
+
+    self.state[_Function].exit()
+    return node
+
+  def visit_FunctionDef(self, node):
+    self.state[_Function].enter()
+    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+
+    function_context_name = self.ctx.namer.new_symbol(
+        '{}_scope'.format(node.name), scope.referenced)
+    self.state[_Function].context_name = function_context_name
+    anno.setanno(node, 'function_context_name', function_context_name)
+
+    node = self.generic_visit(node)
+
+    docstring_node = None
+    if node.body:
+      first_statement = node.body[0]
+      if (isinstance(first_statement, gast.Expr) and
+          isinstance(first_statement.value, gast.Str)):
+        docstring_node = first_statement
+        node.body = node.body[1:]
+
+    template = """
+      with ag__.FunctionScope(
+      function_name, context_name, options) as function_context:
         body
     """
-    scoped_body = templates.replace(
+    wrapped_body = templates.replace(
         template,
-        scope_name=gast.Str(self._name_for_current_scope()),
-        body=indented_body)
-    final_body.extend(scoped_body)
-    node.body = final_body
+        function_name=gast.Str(node.name),
+        context_name=gast.Str(function_context_name),
+        options=self.ctx.program.options.to_ast(),
+        function_context=function_context_name,
+        body=node.body)
+
+    if docstring_node is not None:
+      wrapped_body = [docstring_node] + wrapped_body
+
+    node.body = wrapped_body
+
+    self.state[_Function].exit()
     return node
 
 
diff --git a/tensorflow/python/autograph/converters/function_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py
index f973687..a123105 100644
--- a/tensorflow/python/autograph/converters/function_scopes_test.py
+++ b/tensorflow/python/autograph/converters/function_scopes_test.py
@@ -77,7 +77,7 @@
       first, second = result.test_fn(constant_op.constant(1))
       self.assertIn('test_fn/', first.op.name)
       self.assertNotIn('inner_fn', first.op.name)
-      self.assertIn('test_fn/inner_fn/', second.op.name)
+      self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
 
   @test_util.run_deprecated_v1
   def test_method(self):
@@ -98,9 +98,9 @@
 
     with self.compiled(node, {}, (ops.name_scope,)) as result:
       first, second = result.TestClass().test_fn(constant_op.constant(1))
-      self.assertIn('TestClass/test_fn/', first.op.name)
+      self.assertIn('test_fn/', first.op.name)
       self.assertNotIn('inner_fn', first.op.name)
-      self.assertIn('TestClass/test_fn/inner_fn/', second.op.name)
+      self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
deleted file mode 100644
index 21de0e4..0000000
--- a/tensorflow/python/autograph/converters/side_effect_guards.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Adds guards against function calls with side effects.
-
-Only standalone calls are guarded.
-
-WARNING: This mechanism is incomplete. Particularly, it only guards the
-arguments passed to functions, and does not account for indirectly modified
-state.
-
-Example:
-  y = tf.compat.v1.layers.dense(x)       # Creates TF variable 'foo'
-  loss = loss(y)
-  opt.minimize(loss)           # indirectly affects 'foo'
-  z = tf.compat.v1.get_variable('foo')   # Indirectly affects `loss` and 'foo'
-  # Here, `loss` can be guarded. But `z` cannot.
-
-# TODO(mdan): We should probably define a safe mode where we guard everything.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.python.autograph.core import converter
-from tensorflow.python.autograph.pyct import anno
-from tensorflow.python.autograph.pyct import ast_util
-from tensorflow.python.autograph.pyct import qual_names
-from tensorflow.python.autograph.pyct import templates
-from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
-
-
-class SymbolNamer(object):
-  """Describes the interface for SideEffectGuardTransformer's namer."""
-
-  def new_symbol(self, name_root, reserved_locals):
-    """Generate a new unique function_name.
-
-    Args:
-      name_root: String, used as stem in the new name.
-      reserved_locals: Set(string), additional local symbols that are reserved.
-    Returns:
-      String.
-    """
-    raise NotImplementedError()
-
-
-class SideEffectGuardTransformer(converter.Base):
-  """Adds control dependencies to functions with side effects."""
-
-  def _visit_and_reindent(self, nodes):
-    new_nodes = []
-    current_dest = new_nodes
-    alias_map = {}
-    reindent_requested = False
-    for n in nodes:
-      n = self.visit(n)
-      # NOTE: the order in which these statements execute is important; in
-      # particular, watch out for ending up with cycles in the AST.
-      if alias_map:
-        n = ast_util.rename_symbols(n, alias_map)
-      if isinstance(n, (list, tuple)):
-        current_dest.extend(n)
-      else:
-        current_dest.append(n)
-      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
-        reindent_requested = True
-        new_dest, new_alias_map = anno.getanno(
-            n, anno.Basic.INDENT_BLOCK_REMAINDER)
-        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
-        new_alias_map.update(alias_map)
-        alias_map = new_alias_map
-        current_dest = new_dest
-
-    if reindent_requested:
-      no_controls_to_gate = False
-      if not current_dest:
-        no_controls_to_gate = True
-      if len(current_dest) == 1:
-        if ast_util.matches(current_dest[0], 'return'):
-          no_controls_to_gate = True
-        if ast_util.matches(current_dest[0], 'return ()'):
-          no_controls_to_gate = True
-        if ast_util.matches(current_dest[0], 'return []'):
-          no_controls_to_gate = True
-        if ast_util.matches(current_dest[0], 'return {}'):
-          no_controls_to_gate = True
-      if no_controls_to_gate:
-        # TODO(mdan): There may still be something that could be done.
-        raise ValueError(
-            'Unable to insert statement into the computation flow: it is not'
-            ' followed by any computation which the statement could gate.')
-
-    return new_nodes
-
-  def visit_FunctionDef(self, node):
-    node.body = self._visit_and_reindent(node.body)
-    return node
-
-  def visit_With(self, node):
-    node.body = self._visit_and_reindent(node.body)
-    return node
-
-  def visit_If(self, node):
-    node.body = self._visit_and_reindent(node.body)
-    node.orelse = self._visit_and_reindent(node.orelse)
-    return node
-
-  def visit_While(self, node):
-    node.body = self._visit_and_reindent(node.body)
-    node.orelse = self._visit_and_reindent(node.orelse)
-    return node
-
-  # TODO(b/123995141) Remove once ExceptionHandlers are in the CFG
-  def visit_ExceptHandler(self, node):
-    return node
-
-  def visit_Expr(self, node):
-    self.generic_visit(node)
-    if isinstance(node.value, gast.Call):
-      # Patterns of single function calls, like:
-      #   opt.minimize(loss)
-      # or:
-      #   tf.compat.v1.py_func(...)
-
-      # First, attempt to gate future evaluation of args. If that's not
-      # possible, gate all remaining statements (and that may fail too, see
-      # _visit_and_reindent.
-      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
-      live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
-      # NOTE: We can't guard object attributes because they may not be writable.
-      # In addition, avoid renaming well-known names.
-      # TODO(mdan): Move these names into config.
-      unguarded_names = (qual_names.QN('self'), qual_names.QN('ag__'))
-      guarded_args = tuple(s for s in live_out
-                           if not s.is_composite() and s not in unguarded_names)
-
-      # TODO(mdan): Include all arguments which depended on guarded_args too.
-      # For example, the following will still cause a race:
-      #   tf.compat.v1.assign(a, a + 1)
-      #   b = a + 1
-      #   tf.compat.v1.assign(a, a + 1)  # Control deps here should include `b`
-      #   c = b + 1
-      # Or maybe we should just raise an "unsafe assign" error?
-
-      if guarded_args:
-        # The aliases may need new names to avoid incorrectly making them local.
-        # TODO(mdan): This is brutal. It will even rename modules - any fix?
-        need_alias = tuple(
-            s for s in guarded_args if s not in args_scope.parent.modified)
-        aliased_new_names = tuple(
-            qual_names.QN(
-                self.ctx.namer.new_symbol(
-                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
-        alias_map = dict(zip(need_alias, aliased_new_names))
-        if len(guarded_args) == 1:
-          s, = guarded_args
-          aliased_guarded_args = alias_map.get(s, s)
-        else:
-          aliased_guarded_args = gast.Tuple(
-              [alias_map.get(s, s).ast() for s in guarded_args], None)
-
-        template = """
-          with ag__.utils.control_dependency_on_returns(call):
-            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
-        """
-        control_deps_guard = templates.replace(
-            template,
-            call=node.value,
-            aliased_guarded_args=aliased_guarded_args,
-            guarded_args=guarded_args)[-1]
-      else:
-        alias_map = {}
-
-        template = """
-          with ag__.utils.control_dependency_on_returns(call):
-            pass
-        """
-        control_deps_guard = templates.replace(template, call=node.value)[-1]
-        control_deps_guard.body = []
-
-      node = control_deps_guard
-      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
-                   (node.body, alias_map))
-    return node
-
-
-def transform(node, ctx):
-  return SideEffectGuardTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
deleted file mode 100644
index ead05d0..0000000
--- a/tensorflow/python/autograph/converters/side_effect_guards_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for side_effect_guards module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.autograph.converters import side_effect_guards
-from tensorflow.python.autograph.core import converter_testing
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-tf = None  # Will be replaced by a mock.
-
-
-class SideEffectGuardsTest(converter_testing.TestCase):
-
-  @test_util.run_deprecated_v1
-  def test_side_effect_on_return_only_variable(self):
-
-    def test_fn(a):
-      tf.assign(a, a + 1)
-      return a
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body), 1)
-
-    with self.compiled(node, {}, (state_ops.assign,)) as result:
-      with self.cached_session() as sess:
-        v = variable_scope.get_variable('test', initializer=2)
-        self.evaluate(v.initializer)
-        self.evaluate(result.test_fn(v))
-        # TODO(mdan): Add support for this use case.
-        # Right now the variable `a` is not conditioned on the `assign` because
-        # there's no way to add control dependencies to a variable object.
-        self.assertEqual(2, self.evaluate(v))
-
-  def test_side_effect_on_used_variable(self):
-
-    def test_fn(a):
-      tf.assign(a, a + 1)
-      return a + 1
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body), 1)
-
-    with self.compiled(node, {}, (state_ops.assign,)) as result:
-      with self.cached_session() as sess:
-        v = variable_scope.get_variable('test', initializer=2)
-        self.evaluate(v.initializer)
-        self.evaluate(result.test_fn(v))
-        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
-        # Right now it's 3 or 4 based on whether the read is synchronized.
-        self.assertEqual(3, self.evaluate(v))
-
-  @test_util.run_deprecated_v1
-  def test_side_effect_on_tensor(self):
-
-    def test_fn(a):
-      tf.Assert(a > 0, ['expected in throw'])
-      return a
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body), 1)
-
-    with self.compiled(node, {}, (control_flow_ops.Assert,)) as result:
-      with self.cached_session() as sess:
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     'expected in throw'):
-          sess.run(result.test_fn(constant_op.constant(-1)))
-
-  def test_multiline_block(self):
-
-    def test_fn(a):
-      tf.assign_add(a, 1)
-      b = a + 1
-      tf.assign_add(a, 1)
-      b += 1
-      return b
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body), 1)
-
-    with self.compiled(node, {}, (state_ops.assign_add,)) as result:
-      with self.cached_session() as sess:
-        v = variable_scope.get_variable('test', initializer=2)
-        self.evaluate(v.initializer)
-        self.evaluate(result.test_fn(v))
-        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
-        self.assertEqual(4, self.evaluate(v))
-
-  def test_multiline_nested_block(self):
-
-    def test_fn(a):
-      with tf.name_scope('foo'):
-        tf.assign(a, a + 1)
-        b = a + 1
-      return b
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body[0].body), 1)
-
-    with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result:
-      with self.cached_session() as sess:
-        v = variable_scope.get_variable('test', initializer=2)
-        self.evaluate(v.initializer)
-        self.evaluate(result.test_fn(v))
-        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
-        self.assertEqual(3, self.evaluate(v))
-
-  def test_multiline_block_unsafe(self):
-
-    def test_fn(a):
-      tf.assign(a, a + 1)
-      b = a + 1
-      tf.assign_add(a, 1)
-      c = b + 1
-      return c
-
-    node, ctx = self.prepare(test_fn, {})
-    node = side_effect_guards.transform(node, ctx)
-
-    self.assertEqual(len(node.body), 1)
-
-    with self.compiled(node, {},
-                       (state_ops.assign, state_ops.assign_add)) as result:
-      with self.cached_session() as sess:
-        v = variable_scope.get_variable('test', initializer=2)
-        self.evaluate(v.initializer)
-        self.evaluate(result.test_fn(v))
-        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
-        self.assertEqual(4, self.evaluate(v))
-
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index a480f83..8d7fc1d 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -23,7 +23,7 @@
         "config.py",
         "config_lib.py",
         "converter.py",
-        "function_wrapping.py",
+        "function_wrappers.py",
         "naming.py",
         "unsupported_features_checker.py",
     ],
@@ -68,8 +68,8 @@
 )
 
 py_test(
-    name = "function_wrapping_test",
-    srcs = ["function_wrapping_test.py"],
+    name = "function_wrappers_test",
+    srcs = ["function_wrappers_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":core",
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 2ec12c6..e9bf009 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -148,9 +148,9 @@
   Attributes:
     recursive: bool, whether to recursively convert any user functions or
       classes that the converted function may use.
-    force_conversion: bool, whether to force convertinng the target entity. When
-      force_conversion is turned off, the converter may decide to return the
-      function as-is.
+    user_requested: bool, whether the conversion was explicitly requested by
+      the user, as opposed to being performed as a result of other logic. This
+      value always auto-resets resets to False in child conversions.
     optional_features: Union[Feature, Set[Feature]], controls the use of
       optional features in the conversion process. See Feature for available
       options.
@@ -158,11 +158,11 @@
 
   def __init__(self,
                recursive=False,
-               force_conversion=False,
+               user_requested=False,
                internal_convert_user_code=True,
                optional_features=Feature.ALL):
     self.recursive = recursive
-    self.force_conversion = force_conversion
+    self.user_requested = user_requested
     # TODO(mdan): Rename to conversion_recursion_depth?
     self.internal_convert_user_code = internal_convert_user_code
 
@@ -174,7 +174,7 @@
     self.optional_features = optional_features
 
   def as_tuple(self):
-    return (self.recursive, self.force_conversion,
+    return (self.recursive, self.user_requested,
             self.internal_convert_user_code, self.optional_features)
 
   def __hash__(self):
@@ -191,16 +191,20 @@
     return (Feature.ALL in self.optional_features or
             feature in self.optional_features)
 
-  def to_ast(self, internal_convert_user_code=None):
+  def call_options(self):
+    """Returns the corresponding options to be used for recursive conversion."""
+    return ConversionOptions(
+        recursive=self.recursive,
+        user_requested=False,
+        internal_convert_user_code=self.recursive,
+        optional_features=self.optional_features)
+
+  def to_ast(self):
     """Returns a representation of this object as an AST node.
 
     The AST node encodes a constructor that would create an object with the
     same contents.
 
-    Args:
-      internal_convert_user_code: Optional[bool], allows ovrriding the
-        corresponding value.
-
     Returns:
       ast.Node
     """
@@ -210,7 +214,7 @@
     template = """
       ag__.ConversionOptions(
           recursive=recursive_val,
-          force_conversion=force_conversion_val,
+          user_requested=user_requested_val,
           optional_features=optional_features_val,
           internal_convert_user_code=internal_convert_user_code_val)
     """
@@ -219,23 +223,19 @@
       return parser.parse_expression('({})'.format(', '.join(
           'ag__.{}'.format(str(v)) for v in values)))
 
-    if internal_convert_user_code is None:
-      internal_convert_user_code = self.internal_convert_user_code
-
     expr_ast = templates.replace(
         template,
         recursive_val=parser.parse_expression(str(self.recursive)),
-        force_conversion_val=parser.parse_expression(
-            str(self.force_conversion)),
+        user_requested_val=parser.parse_expression(str(self.user_requested)),
         internal_convert_user_code_val=parser.parse_expression(
-            str(internal_convert_user_code)),
+            str(self.internal_convert_user_code)),
         optional_features_val=list_of_features(self.optional_features))
     return expr_ast[0].value
 
 
 STANDARD_OPTIONS = ConversionOptions(
     recursive=True,
-    force_conversion=False,
+    user_requested=False,
     internal_convert_user_code=True,
     optional_features=None)
 
@@ -262,13 +262,15 @@
   Attributes:
     namer: Namer
     info: transformer.EntityInfo
-    program: ProgramContext
+    program: ProgramContext,
+    targe_name: Text
   """
 
-  def __init__(self, namer, entity_info, program_ctx):
+  def __init__(self, namer, entity_info, program_ctx, target_name=None):
     super(EntityContext, self).__init__(entity_info)
     self.namer = namer
     self.program = program_ctx
+    self.target_name = target_name
 
 
 class Base(transformer.Base):
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py
index 85b4d45..2d5b334 100644
--- a/tensorflow/python/autograph/core/converter_test.py
+++ b/tensorflow/python/autograph/core/converter_test.py
@@ -50,7 +50,7 @@
     reparsed_opts = reparsed.test_fn()
 
     self.assertEqual(opts.recursive, reparsed_opts.recursive)
-    self.assertEqual(opts.force_conversion, reparsed_opts.force_conversion)
+    self.assertEqual(opts.user_requested, False)
     self.assertEqual(
         opts.internal_convert_user_code,
         reparsed_opts.internal_convert_user_code)
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 507739f..7560b43 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -27,7 +27,7 @@
 from tensorflow.python.autograph import operators
 from tensorflow.python.autograph import utils
 from tensorflow.python.autograph.core import converter
-from tensorflow.python.autograph.core import function_wrapping
+from tensorflow.python.autograph.core import function_wrappers
 from tensorflow.python.autograph.core import naming
 from tensorflow.python.autograph.lang import special_functions
 from tensorflow.python.autograph.pyct import compiler
@@ -57,7 +57,7 @@
 
     self.dynamic_calls = []
     # See api.converted_call
-    def converted_call(f, unused_opts, args, kwargs):
+    def converted_call(f, unused_opts, args, kwargs, unused_function_ctx):
       """Mock version of api.converted_call."""
       self.dynamic_calls.append((args, kwargs))
       if kwargs is None:
@@ -78,7 +78,7 @@
       fake_ag.ConversionOptions = converter.ConversionOptions
       fake_ag.Feature = converter.Feature
       fake_ag.utils = utils
-      fake_ag.function_scope = function_wrapping.function_scope
+      fake_ag.FunctionScope = function_wrappers.FunctionScope
       result.ag__ = fake_ag
       result.ag_source_map__ = source_map
       for k, v in namespace.items():
@@ -135,7 +135,8 @@
         source_file='<fragment>',
         future_features=future_features,
         namespace=namespace)
-    ctx = converter.EntityContext(namer, entity_info, program_ctx)
+    ctx = converter.EntityContext(
+        namer, entity_info, program_ctx, 'test_fn')
     origin_info.resolve_entity(node, source, test_fn)
     node = converter.standard_analysis(node, ctx, is_initial=True)
     return node, ctx
diff --git a/tensorflow/python/autograph/core/function_wrappers.py b/tensorflow/python/autograph/core/function_wrappers.py
new file mode 100644
index 0000000..55b1071
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrappers.py
@@ -0,0 +1,108 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for wrapping converted functions bodies with auxiliary logic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.framework import auto_control_deps
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.util import nest
+
+
+class FunctionScope(object):
+  """Context manager that wraps the body of a converted function.
+
+  This context manager handles various operations related to the scope of a
+  function:
+    * optional TF name scopes - these name scopes match the name of the
+        function, for easy visualization in tensorBoard;
+    * optional automatic control dependencies - this adds the same mechanism
+        for control dependenecies that is used by `@tf.function`; it can be
+        optionally enabled when using `tf.autograph.to_graph`;
+    * tracking of autograph conversion state (whether it's enabled by the user,
+        conversion options;
+  """
+
+  def __init__(self, function_name, scope_name, options):
+    self.name = scope_name
+    self.options = options
+
+    if options.user_requested:
+      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
+                                                   options)
+    self.callopts = options.call_options()
+
+    use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
+    self.use_name_scope = use_name_scope
+    if use_name_scope:
+      self.name_scope = ops.name_scope(self._sanitize(function_name))
+
+    use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
+    self.use_auto_deps = use_auto_deps
+    if use_auto_deps:
+      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
+      self._return_value_marked = False
+
+  def _sanitize(self, name):
+    """See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope."""
+    # TensorFlow doesn't like leading underscores at the top level.
+    if name and name.startswith('_'):
+      name = 'fn' + name
+    return name
+
+  def __enter__(self):
+    if self.options.user_requested:
+      self.autograph_ctx.__enter__()
+    if self.use_name_scope:
+      self.name_scope.__enter__()
+    if self.use_auto_deps:
+      self.autodeps_scope.__enter__()
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    if self.options.user_requested:
+      self.autograph_ctx.__exit__(exc_type, exc_val, exc_tb)
+    if self.use_name_scope:
+      self.name_scope.__exit__(exc_type, exc_val, exc_tb)
+    if self.use_auto_deps:
+      self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)
+
+  def mark_return_value(self, value):
+    """Marks a value as returned from the function guarded by the scope."""
+    if self.use_auto_deps:
+      self._return_value_marked = True
+      if value is None:
+        # We don't create dummy returns, to preserve Python semantics. The user
+        # is responsible for adding a return value to the top-level function.
+        return None
+
+      def _mark_return_if_tensor(t):
+        if tensor_util.is_tensor(t):
+          return self.autodeps_scope.mark_as_return(t)
+        return t
+
+      value = nest.map_structure(_mark_return_if_tensor, value)
+    return value
+
+
+def with_function_scope(thunk, scope_name, options):
+  """Inline version of the FunctionScope context manager."""
+  with FunctionScope('lambda_', scope_name, options) as scope:
+    return thunk(scope)
diff --git a/tensorflow/python/autograph/core/function_wrappers_test.py b/tensorflow/python/autograph/core/function_wrappers_test.py
new file mode 100644
index 0000000..0191800
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrappers_test.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for function_wrappers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import function_wrappers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class FunctionWrappersTest(test.TestCase):
+
+  def test_name_scope(self):
+    if context.executing_eagerly():
+      self.skipTest('Tensor names are disabled in eager')
+
+    with function_wrappers.FunctionScope(
+        'test_name', None,
+        converter.ConversionOptions(
+            optional_features=converter.Feature.NAME_SCOPES)):
+      t = constant_op.constant(1)
+    self.assertIn('test_name', t.name)
+
+  def test_auto_cotrol_deps(self):
+    v = variables.Variable(1)
+    with function_wrappers.FunctionScope(
+        '_', None,
+        converter.ConversionOptions(
+            optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
+      v.assign(2)
+      op = scope.mark_return_value(constant_op.constant(1))
+    self.evaluate(op)
+    self.assertEqual(self.evaluate(v.read_value()), 2)
+
+  def test_all_disabled(self):
+    with function_wrappers.FunctionScope(None, None,
+                                         converter.STANDARD_OPTIONS):
+      t = constant_op.constant(1)
+    self.assertEqual(self.evaluate(t), 1)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/core/function_wrapping.py b/tensorflow/python/autograph/core/function_wrapping.py
deleted file mode 100644
index 21b66ef..0000000
--- a/tensorflow/python/autograph/core/function_wrapping.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Support for wrapping converted functions bodies with auxiliary logic."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.framework import ops
-
-
-@contextlib.contextmanager
-def function_scope(function_name):
-  """Returns a context manager for the converted body of a function."""
-  with ops.name_scope(function_name):
-    yield
diff --git a/tensorflow/python/autograph/core/function_wrapping_test.py b/tensorflow/python/autograph/core/function_wrapping_test.py
deleted file mode 100644
index 7e21b97..0000000
--- a/tensorflow/python/autograph/core/function_wrapping_test.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for function_wrapping module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.autograph.core import function_wrapping
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class FunctionWrappingTest(test.TestCase):
-
-  @test_util.run_deprecated_v1
-  def test_function_scope_name(self):
-    with function_wrapping.function_scope('test_name'):
-      t = constant_op.constant(1)
-    self.assertIn('test_name', t.name)
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/python/autograph/docs/pyfunc_dtypes.md b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
deleted file mode 100644
index c2427f5..0000000
--- a/tensorflow/python/autograph/docs/pyfunc_dtypes.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Specifying return data type for `py_func` calls
-
-The `py_func` op requires specifying a
-[data type](https://www.tensorflow.org/guide/tensors#data_types).
-
-When wrapping a function with `py_func`, for instance using
-`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
-options to specify the returned data type:
-
- * explicitly, with a specified `tf.DType` value
- * by matching the data type of an input argument, which is then assumed to be
-     a `Tensor`
-
-Examples:
-
-Specify an explicit data type:
-
-```
-  def foo(a):
-    return a + 1
-
-  autograph.util.wrap_py_func(f, return_dtypes=[tf.float32])
-```
-
-Match the data type of the first argument:
-
-```
-  def foo(a):
-    return a + 1
-
-  autograph.util.wrap_py_func(
-      f, return_dtypes=[autograph.utils.py_func.MatchDType(0)])
-```
diff --git a/tensorflow/python/autograph/g3doc/reference/common_errors.md b/tensorflow/python/autograph/g3doc/reference/common_errors.md
new file mode 100644
index 0000000..79867e0
--- /dev/null
+++ b/tensorflow/python/autograph/g3doc/reference/common_errors.md
@@ -0,0 +1,87 @@
+# AutoGraph reference
+
+[Index](index.md)
+
+## Common AutoGraph errors
+
+### "WARNING: `<name>` could not be transformed"
+
+This warning is output when AutoGraph could not convert a function, for an
+unexpected reason. The error message contains the reason why the function could
+not be converted, as well as guidance on how to proceed next.
+
+Note: AutoGraph does not always output a warning. For example, constructors
+are silently called without conversion.
+
+When this warning is printed, the code returned by AutoGraph still executes, but
+the functions indicated in the warning will be executed as they are, without
+conversion. If the functions contain pure Python or graph code (for example,
+they have no Tensor-dependent control flow), then the code is likely to still
+run without error. However, if it contains any constructs that are only
+supported in AutoGraph, expect subsequent exceptions.
+
+Note: the warning is output to the [abseil](https://github.com/abseil/abseil-py)
+logger, with `WARNING` severity. To direct these warnings to `stdout`, use
+`tf.autograph.set_verbosity(0, True)`.
+
+### "OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool`"
+
+This exception is raised whenever a `tf.Tensor` is type-cast as a Python `bool`,
+in a context where eager execution is not active. The exception is only raised
+when graph execution is active, for example inside a `@tf.function` with
+AutoGraph turned off. It can be caused by using a `tf.Tensor` value as:
+
+  * the condition of an `if` or `while` statement: `if <tensor>:`
+  * the argument in a logical expression: `tensor and another_tensor`
+  * the argument to the `bool` built-in: `bool(tensor)`
+
+Note: These operations are allowed when executing eagerly.
+
+Within the context of AutoGraph, it usually indicates eager-style control
+flow that has not been converted by AutoGraph, for any reason.
+
+When encountering this error, make sure that the function is either decorated
+with `@tf.function`, or called from another function decorated in this way. Also
+look at the console and logging output for conversion warnings (see the section
+above).
+
+### "OperatorNotAllowedInGraphError: iterating over `tf.Tensor`"
+
+This exception is raised whenever you try to iterate over a `tf.Tensor`,
+in a context where eager execution is not active. The exception is only raised
+when graph execution is active, for example inside a `@tf.function` with
+AutoGraph turned off. It can be caused by using a `tf.Tensor` value as:
+
+  * the iterated of a `for` statement: `for i in tensor:`
+  * the argument to the `iter` built-in: `iter(tensor)`
+
+Note: These operations are allowed when executing eagerly.
+
+This exception is similar to the previous example, and has similar causes and
+remedies.
+
+### "InaccessibleTensorError: The tensor `<name>` is defined in another function or code block"
+
+This exception is common to code which attempts to obtain values calculated
+within a `tf.cond`, `tf.while_loop`, or another `@tf.function` without using
+functional style or through mutable collections. See
+[Limitations](limitations.md) for more details.
+
+### "StagingError: in converted code"
+
+This exception is used by AutoGraph to wrap exceptions with custom constructors
+that it cannot re-raise with the original type. See
+[Error handling](error_handling.md) for more details. If your code uses custom
+exceptions, expect them to be wrapped by this exception.
+
+### "Unable to identify source code of lambda function"
+
+This error usually appears in the context of a conversion warning. It indicates
+that a lambda function could not be parsed (see [Limitations](limitations.md)).
+
+This type of errors can usually be avoided by creating lambda functions in
+separate simple assignments, for example:
+
+```
+l = lambda <args>: <body>
+```
diff --git a/tensorflow/python/autograph/g3doc/reference/control_flow.md b/tensorflow/python/autograph/g3doc/reference/control_flow.md
new file mode 100644
index 0000000..494e556
--- /dev/null
+++ b/tensorflow/python/autograph/g3doc/reference/control_flow.md
@@ -0,0 +1,517 @@
+# AutoGraph reference
+
+[Index](index.md)
+
+## Control flow
+
+AutoGraph rewrites all control flow statements with specialized AutoGraph
+function calls. These function calls are capable of executing the corresponding
+control flow statement using Python semantics for effects outside the Python
+interpreter itself (see the [Introduction](intro.md)).
+
+### Dispatch rules
+
+Key Point: Only statements that are conditioned on, or iterate over, a
+TensorFlow object such as `tf.Tensor`, are converted into TensorFlow ops.
+
+As described in the [Introduction](intro.md), AutoGraph aims to preserve the
+semantics of valid Python code. If a control flow statement runs in graph
+execution without raising an error, then AutoGraph will also execute it as
+normal Python control flow. Statements which would normally raise an error, for
+example because a `tf.Tensor` cannot be used as a `bool` in an `if` statement,
+are converted to TensorFlow control flow ops.
+
+#### Analogy with compile-time constants and code optimization
+
+From the perspective of a TensorFlow graph, non-Tensor values, for example an
+integer or a NumPy array, are _constants_: they do not change value while the
+graph executes.
+
+For example, in the graph below, the condition is always `True` (it is
+invariant):
+
+```
+x = 1
+y = tf.cond(x > 0, lambda: 3 * x, lambda 5 * x)
+```
+
+That is equivalent to the code below:
+
+```
+x = 1
+y = 3 * x
+```
+
+In the example above, we've optimized away the conditional on a constant
+condition. The AutoGraph dispatch rules have the same effect: anything that is
+not a TensorFlow object is a compile-time constant for TensorFlow, and can be
+optimized away. For this reason, you can usually mix Python and TensorFlow
+computation and it will transparently have the expected result even
+when only some computations are executed in the graph.
+
+<!-- TODO(mdan): This is actually a limitation (a very subtle one) -->
+Caution: The assumption of invariant code made above is not true if the
+TensorFlow graph had callbacks into the Python code. If you modify data
+from within a `tf.py_function`, then the code outside a `tf.py_function`
+will have unpredictable behavior if it depends on the same data.
+
+For example, the `tf.cond` that runs as part of the `if` statement below will
+miss the update made by `f`:
+
+```
+n = [10]
+def f():
+  n[0] = 20
+  return 0
+tf.py_function(f, (), (tf.int32,))
+if tf.equal(n[0], 10):
+  tf.print('n is 10')
+```
+
+```
+n is 10
+```
+
+### Compound symbols
+
+AutoGraph usually handles basic symbols:
+
+```
+if a < 0:
+  a = -a
+```
+
+```
+a = tf.cond(a < 0, lambda: -a, lambda: a)
+```
+
+But it can also handle complex symbols in many cases. For example, if we treat
+`a.b` as a symbol in the code below, then we can use it as if it were a basic
+symbol name:
+
+```
+if a.b < 0
+  a.b = -a.b
+```
+
+```
+a.b = tf.cond(a.b < 0, lambda: -a.b, lambda: a.b)
+```
+
+This is useful in methods, which can operate on properties of `self`, as well as
+working directly on more complex object structures or collections.
+
+Caution: There are certain [limitations](limitations.md) around using Python
+collections and object mutation. When in doubt, place the values you work
+with into local variables and operate on those.
+
+### Effects of the tracing process
+
+#### All Python code paths are executed during tracing
+
+When constructing a graph, TensorFlow _traces_ the code. The tracing of control
+flow requires visiting _every possible code path_ (usually once).
+
+Note: In rare cases, the runtime may decide to trace some code paths several
+times. For example, the condition of a `while` statement may be executed twice,
+first with a temporary graph, to determine whether it evaluates to a
+`tf.Tensor`, then if it is a `tf.Tensor`, it's executed a second time in the
+proper graph.
+
+In other words, when tracing executes both branches of an if statement.
+Similarly, the body of loops is executed once (even if the loop would otherwise
+not iterate at all).
+
+This explains why inserting `print` statements in an `if` statement produces
+this output:
+
+```
+print('before if')
+if tf.constant(True):
+  print('true branch')
+else:
+  print('false branch')
+print('after if')
+```
+
+```
+before if
+true branch
+false branch
+after if
+```
+
+Note: Control flow that is not executed as a TensorFlow graph is not traced. Its
+body will execute as expected.
+
+Example of code that runs as regular Python code:
+
+```
+print('before if')
+if True:  # Condition not a Tensor, running normally
+  print('true branch')
+else:
+  print('false branch')
+print('after if')
+```
+
+```
+before if
+true branch
+after if
+```
+
+#### Python values modified in TensorFlow control flow become Tensors
+
+If a symbol is modified in a TensorFlow control flow statement, then it becomes
+a `tf.Tensor`, even if it started off as a Python promitive value.
+
+For example, the conditional below will run as a `tf.cond` (its condition is a
+`tf.Tensor`), which in turn will cause `i` to become a `tf.Tensor`.
+
+```
+i = 0
+if tf.greater(i, 0):
+  i = 1
+# i is not a Tensor
+```
+
+### `if` statements
+
+`if` statements whose condition is a `tf.Tensor` are executed as TensorFlow
+conditionals by converting them to `tf.cond`:
+
+```
+if tf.random.uniform(()) > 0.5:
+  x = 1
+else:
+  x = 2
+```
+
+`if` statements whose condition is not a `tf.Tensor` are executed as normal
+Python:
+
+```
+if np.random.uniform() > 0.5:
+  x = 1
+else:
+  x = 2
+```
+
+`if` statements executed as TensorFlow conditionals are subject to restrictions
+(see [limitations](limitations.md)). All symbols affected by the statement and
+used thereafter must be:
+
+ * of a data type understood by TensorFlow
+ * defined in both branches
+ * of consistent dtypes in both branches, for TensorFlow entities
+ * of consistent structure in both branches, for static collections (such as
+   lists or tuples)
+
+### `while` statements
+
+`while` statements whose condition is a `tf.Tensor` are executed as TensorFlow
+loops by converting them to `tf.while_loop`:
+
+```
+x = 0
+while tf.random.uniform(()) > 0.5:
+  x = x + 1
+```
+
+`while` statements whose condition is not a `tf.Tensor` are executed as normal
+Python:
+
+```
+x = 0
+while np.random.uniform() > 0.5:
+  x = x + 1
+```
+
+`while` statements executed as TensorFlow loops are subject to restrictions
+(see [limitations](limitations.md)). All symbols affected by the statement and
+used thereafter must be:
+
+ * of a data type understood by TensorFlow
+ * defined before the loop
+ * of consistent dtype at the beginning and the end of the loop,
+   for TensorFlow entities
+ * either of consistent shape at the beginning and the end of the loop,
+   for TensorFlow entities, or declared in `shape_invariants`
+ * of consistent structure  at the beginning and the end of the loop, for
+   static collections (such as lists or tuples)
+
+Caution: A `while` loop whose condition is a Python scalar will execute as
+normal Python. If you intended to run the loop as a TensorFlow loop, the loop
+will replicate its body in the graph (it is unrolled). To avoid that, make sure
+its condition is converted to a `tf.Tensor`, using for instance `tf.constant`.
+
+For example, the following loop is unrolled, even though the list contains
+`tf.Tensor` values, because the type of `l` is a Python `list`:
+
+```
+l = [tf.constant(1), tf.constant(2), tf.constant(3)]
+for i in l:
+  tf.print(i)  # This is unrolled - three `tf.print`s are built in the graph. 
+```
+
+If you wish for the loop to run as a TensorFlow loop, stack the loop:
+
+```
+l = [tf.constant(1), tf.constant(2), tf.constant(3)]
+for i in tf.stack(l):
+  tf.print(i)  # This runs as a TensorFlow loop.
+```
+
+<!-- TODO(mdan): List this under limitations -->
+Caution: A loop in which the type of the condition condition changes across
+iterations, in a way that would influence the way the loop is executed, is not
+allowed in AutoGraph.
+
+For example, the loop below will generate an error. After the first iteration,
+`i` becomes a tf.Tensor, because
+
+```
+i = 0
+while i < 10:  # `i < 10` is a Python bool - run as normal while loop
+  i = tf.constant(1)  # Error -- `i < 10` would now be a `tf.Tensor`
+```
+
+### `for` statements
+
+`for` statements that iterate over a `tf.Tensor` are executed as TensorFlow
+loops by converting them to a `tf.while_loop` which iterates over the first
+dimension (equivalent to NumPy):
+
+```
+for i in tf.constant(((1, 2), (3, 4))):
+  tf.print('iteration:', i)
+```
+
+```
+iteration: [1, 2]
+iteration: [3, 4]
+```
+
+Note: If possible, AutoGraph will also set the `maximum_iteration` parameter
+of the `tf.while_loop`.
+
+`for` statements that iterate over a the output of a `tf.range` are executed as
+TensorFlow loops by converting them to a `tf.while_loop` which uses the
+arguments passed to the `tf.range`:
+
+```
+for i in tf.range(3):
+  tf.print('iteration:', i)
+```
+
+`for` statements that iterate over a `tf.data.Dataset` and which do not contain
+`break` or `return` statements are executed as TensorFlow loops by converting
+them to `tf.data.Dataset.reduce` ops:
+
+```
+for i in tf.data.Dataset.range(3):
+  tf.print('iteration:', i)
+```
+
+`for` statements that iterate over a _distributed_ `tf.data.Dataset` and which
+do not contain `break` or `return` statements are executed as TensorFlow loops
+by converting them to the datasets' `reduce` ops:
+
+```
+for i in tf.distribute.OneDeviceStrategy('cpu').experimental_distribute_dataset(
+    tf.data.Dataset.range(3)):
+  tf.print('iteration:', i)
+```
+
+`for` statements that iterate over a `tf.data.Dataset` and which contain
+`break` or `return` statements are executed as TensorFlow loops by converting
+them to a combination of `tf.data.Dataset.scan`, `tf.data.Dataset.take_while`
+and `tf.data.Dataset.reduce` ops:
+
+```
+for i in tf.data.Dataset.range(3):
+  tf.print('iteration:', i)
+  break
+```
+
+```
+iteration: 1
+```
+
+`for` statements that iterate over a `tf.data.Dataset` _iterator_ are executed
+as TensorFlow loops by converting them to a combination of `tf.while_loop`,
+and `tf.cond` ops:
+
+```
+for i in iter(tf.data.Dataset.range(3)):
+  tf.print('iteration:', i)
+```
+
+`for` statements that iterate over a type different from any of the above are
+executed as normal Python:
+
+```
+for i in [1, 2, 3]:
+  print('iteration:', i)
+```
+
+Caution: A `for` loop over a `list` or `tuple` of `tf.Tensor` is considered to
+iterate over a Python `list` (or respectively `tuple`), therefore will be
+executed as normal Python. If you intended to run it as a TensorFlow loop,
+use `tf.stack` or `tf.concat`.
+
+Caution: A `for` loop over a Python `range` will be executed as normal Python.
+If you intended to run it as a TensorFlow loop, `tf.range`.
+
+Note: AutoGraph may output a warning when it believes that you are unrolling
+a loop inefficiently. However, the warning thresholds are very conservative.
+
+### `break` statements
+
+Code blocks in which `break` statements are used are rewritten with equivalent
+code that uses extra control booleans and conditionals. The control booleans are
+used directly in `while` loops. In the case of `for` loops, the AutoGraph
+corresponding operator accepts an `extra_test` argument which is similar to
+the conditional of a while loop, and which contains the control boolean.
+
+For example, the `while` loop below is rewritten as (showing the output of the
+`break` transformation only):
+
+```
+while i < 10:
+  if i > 3:
+    break
+  i += 1
+```
+
+```
+break_ = False
+while i < 10 and not break_:
+  if i > 3:
+    break_ = True
+    continue  # The continue statement is also rewritten in a subsequent pass
+  i += 1
+```
+
+Another example shows how the control boolean is used in the overload of a `for`
+loop (showing portions of the final output):
+
+```
+for i in range(10):
+  if i > 3:
+    break
+```
+
+```
+break_ = False
+...
+def extra_test(break_):
+  return ag__.not_(break_)
+# break_ becomes a loop variable.
+break_, = ag__.for_stmt(range(10), extra_test, ..., (break_,))
+```
+
+### `continue` statements
+
+Code blocks in which `continue` statements are used are rewritten with
+equivalent code that uses extra control booleans and conditionals, similar to
+how `break` is handled.
+
+For example, the `for` loop below is rewritten as (showing the output of the
+`continue` transformation only):
+
+```
+for i in range(10):
+  if i > 3:
+    continue
+```
+
+```
+for i in range(10):
+  continue_ = False
+  if i > 3:
+    continue_ = True
+  if not continue_:
+    i += 1
+```
+
+Notice that unlike `break`, `continue` statements are local to the loop and do
+not influence the number of iterations.
+
+### `return` statements
+
+`return` statements are also rewritten using control symbols, in a manner
+similar to how `break` is converted. In the case of `return` statements, an
+additional symbol keeps track of the return value.
+
+Depending on the structure of the code, the return value might be undefined
+in parts of the code (for example on code paths in which no return statement
+has executed). AutoGraph keeps track of this by using a special value.
+This special value is converted to `None` (the default return value) upon
+exiting the function.
+
+Caution: TensorFlow control flow doe not support undefined values, and an
+undefined return value is no exception. Therefore, AutoGraph will raise an
+error for TensorFlow control flow in which the return value is not known for
+all code paths.
+
+For example, the following code raises an error because the return value would
+be undefined when the random number would be less than 0.5:
+
+```
+if tf.random.uniform(()) > 0.5:
+  return 1
+```
+
+```
+ValueError: A value must also be returned from the else branch.
+```
+
+An example of rewriting a `while` (showing the output of the `return`
+transformation only):
+
+```
+def f():
+  while i < 10:
+    if i > 3:
+      return 1
+    i += 1
+```
+
+```
+def f():
+  do_return = False
+  retval_ = ag__.UndefinedReturnValue()
+  while i < 10 and not do_return:
+    if i > 3:
+      do_return = True
+      retval_ = 1
+    if not do_return:
+      i += 1
+  return ag__.retval(retval_)  # Transforms any UndefinedReturnValue to None
+```
+
+Note: AutoGraph performs an additional code normalization in which an `if`
+statement with no `else` branch contains a `return` statement it is rewritten as
+an `if-else` statement in which the code that follows the statement is moved
+under the `else` branch.
+
+Example (showing the normalization only):
+
+```
+def f():
+  if i > 3:
+    return 1
+  i += 1
+```
+
+```
+def f():
+  if i > 3:
+    return 1
+  else:
+   i += 1
+```
+
+
diff --git a/tensorflow/python/autograph/g3doc/reference/error_handling.md b/tensorflow/python/autograph/g3doc/reference/error_handling.md
new file mode 100644
index 0000000..ce3a64f
--- /dev/null
+++ b/tensorflow/python/autograph/g3doc/reference/error_handling.md
@@ -0,0 +1,213 @@
+# AutoGraph reference
+
+[Index](index.md)
+
+## Error handling
+
+When an exception occurs in code generated by AutoGraph, the error message
+is augmented with information about the location in the original code,
+before conversion.
+
+When an error occurs in a TensorFlow graph constructed using AutoGraph code,
+the stack trace which points to where the failing op was created is modified
+to point to the original code, before conversion.
+
+### Python execution errors
+
+Python execution (or tracing) exceptions that are raised in AutoGraph code are
+caught and re-raised with an extended error message that contains references
+to the original code.
+
+These functions are re-raised by `@tf.function`. If you use a `try/catch` the
+exception inside `tf.function`, you will obtain the original exception.
+
+The exception traceback still contains the entire call stack, including frames
+corresponding to generated code.
+
+AutoGraph tries to re-raise an exception of the same type as the original
+exception. This is usually possible for subclasses of
+`Exception` that do not define a custom `__init__`. For more complex
+exception types which define a custom constructor, AutoGraph raises a
+`StagingError` instead.
+
+Among the distinctive features of the re-raised exception:
+
+ * the exception traceback indicates the call stack of the exception, up to the
+   first @tf.function
+ * the error message includes references to the original code within
+   the `@tf.function`
+ * the references corresponding to converted code are marked with an
+   asterisk (`*`)
+
+For example, the code below triggers an exception in the Python runtime, at
+graph construction time:
+
+```
+@tf.function
+def f():
+  tf.constant(1) + tf.constant(1.0)
+f()
+```
+
+An excerpt of the exception that is raised is shown below:
+
+```
+Traceback (most recent call last):
+  File "<ipython-input-10-1938a51c970d>", line 11, in <module>
+    f()
+  File "tensorflow/python/eager/def_function.py", line 417, in __call__
+    self._initialize(args, kwds, add_initializers_to=initializer_map)
+  ... more TensorFlow internal frames ...
+TypeError: in converted code:
+
+    <ipython-input-9-002fa22f79df>:8 f  *
+        tf.constant(1) + tf.constant(1.0)
+    tensorflow/python/ops/math_ops.py:900 binary_op_wrapper
+        return func(x, y, name=name)
+    ... more TensorFlow internal frames ...
+
+    TypeError: Input 'y' of 'AddV2' Op has type float32 that does not match type int32 of argument 'x'.
+
+```
+
+Note: the exact appearance of the various parts in the error message may change
+in the future.
+
+Let's look at the individual components of this exception.
+
+The traceback of the exception indicates the location until the call to
+`@tf.function`, including any frames internal to TensorFlow:
+
+```
+Traceback (most recent call last):
+  File "<ipython-input-10-1938a51c970d>", line 11, in <module>
+    f()
+  File "tensorflow/python/eager/def_function.py", line 417, in __call__
+    self._initialize(args, kwds, add_initializers_to=initializer_map)
+  File "tensorflow/python/eager/def_function.py", line 360, in _initialize
+    *args, **kwds))
+  File "tensorflow/python/eager/function.py", line 1688, in _get_concrete_function_internal_garbage_collected
+    graph_function, _, _ = self._maybe_define_function(args, kwargs)
+  File "tensorflow/python/eager/function.py", line 1992, in _maybe_define_function
+    graph_function = self._create_graph_function(args, kwargs)
+  File "tensorflow/python/eager/function.py", line 1878, in _create_graph_function
+    capture_by_value=self._capture_by_value),
+  File "tensorflow/python/framework/func_graph.py", line 791, in func_graph_from_py_func
+    func_outputs = python_func(*func_args, **func_kwargs)
+  File "tensorflow/python/eager/def_function.py", line 310, in wrapped_fn
+    return weak_wrapped_fn().__wrapped__(*args, **kwds)
+  File "tensorflow/python/framework/func_graph.py", line 781, in wrapper
+    raise e.ag_error_metadata.to_exception(type(e))
+```
+
+The exception message includes the location inside the converted function `f`:
+
+```
+TypeError: in converted code:
+
+    <ipython-input-9-002fa22f79df>:8 f  *
+        tf.constant(1) + tf.constant(1.0)
+    tensorflow/python/ops/math_ops.py:900 binary_op_wrapper
+        return func(x, y, name=name)
+    tensorflow/python/ops/math_ops.py:1198 _add_dispatch
+        return gen_math_ops.add_v2(x, y, name=name)
+    tensorflow/python/ops/gen_math_ops.py:549 add_v2
+        "AddV2", x=x, y=y, name=name)
+    tensorflow/python/framework/op_def_library.py:564 _apply_op_helper
+        inferred_from[input_arg.type_attr]))
+```
+
+Notice the frame corresponding to the call of `f`. The function is converted,
+which is being indicated by the asterisk `*` character displayed next to
+`f`:
+
+```
+    <ipython-input-9-002fa22f79df>:8 f  *
+        tf.constant(1) + tf.constant(1.0)
+```
+
+Lastly, the lower part includes the message that the exception originally
+reported:
+
+```
+    TypeError: Input 'y' of 'AddV2' Op has type float32 that does not match type int32 of argument 'x'.
+```
+
+Note: Typically, error messages raised by code internal to TensorFlow refers
+to arguments of the internal API that failed. Error messages raised by code
+internal to AutoGraph (that is, 'tensorflow/python/autograph') usually
+refer to symbols used in your code.
+
+### TensorFlow execution errors
+
+TensorFlow execution errors are displayed normally, but the portions of the
+error message which correspond to user code contain references to the original
+code.
+
+For example, the code below triggers an error in the TensorFlow runtime, at
+graph execution time:
+
+```
+@tf.function
+def my_function():
+  tf.Assert(tf.random.uniform(()) > 1.0, ['example error'])
+my_function()
+```
+
+An excerpt of the exception that is subsequently raised is shown below:
+
+```
+Traceback (most recent call last):
+  File "<ipython-input-16-af656fb445f0>", line 11, in <module>
+    my_function()
+  File "tensorflow/python/eager/def_function.py", line 435, in __call__
+    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)
+  File "tensorflow/python/eager/function.py", line 636, in _filtered_call
+    self.captured_inputs)
+  File "tensorflow/python/eager/function.py", line 734, in _call_flat
+    outputs = self._inference_function.call(ctx, args)
+  File "tensorflow/python/eager/function.py", line 460, in call
+    ctx=ctx)
+  File "tensorflow/python/eager/execute.py", line 68, in quick_execute
+    six.raise_from(core._status_to_exception(e.code, message), None)
+  File "<string>", line 3, in raise_from
+InvalidArgumentError:  assertion failed: [example error]
+    [[node Assert/Assert (defined at <ipython-input-16-af656fb445f0>:8) ]] [Op:__inference_my_function_79]
+```
+
+Notice the error message containing references to the location where the failing
+op was defined in the code (`<ipython-input-16-af656fb445f0>:8`):
+
+```
+InvalidArgumentError:  assertion failed: [example error]
+    [[node Assert/Assert (defined at <ipython-input-16-af656fb445f0>:8) ]] [Op:__inference_my_function_79]
+```
+
+### AutoGraph conversion exceptions
+
+Within `@tf.function`, when AutoGraph fails to convert a function, it displays
+a warning message and attempts to run the function without conversion.
+
+For example, the code below make a call to a Python
+[generator](https://wiki.python.org/moin/Generators) function, which is not
+supported by AutoGraph:
+
+```
+def example_generator():
+  yield 1
+
+@tf.function
+def f():
+  for i in example_generator():
+    print(i)
+```
+
+Calling `f()` will still run the code. AutoGraph will convert the function `f`,
+but skips the function `example_generator`. In addition, AutoGraph prints a
+warning to the console indicating that the function is called without being
+converted.
+
+```
+WARNING: Entity <function example_generator at 0x7f951b67f158> appears to be
+a generator function. It will not be converted by AutoGraph.
+```
diff --git a/tensorflow/python/autograph/g3doc/reference/functions.md b/tensorflow/python/autograph/g3doc/reference/functions.md
new file mode 100644
index 0000000..f2768a0
--- /dev/null
+++ b/tensorflow/python/autograph/g3doc/reference/functions.md
@@ -0,0 +1,65 @@
+# AutoGraph reference
+
+[Index](index.md)
+
+## Functions and function calls
+
+Typically, AutoGraph converts one function at a time. If a function calls other
+functions, the called function will be converted recursively, as described
+below.
+
+### Function calls
+
+AutoGraph rewrites all function calls with a special wrapper that may convert
+the called function at runtime.
+
+For example, the function call below:
+
+```
+f(x, y, z=1)
+```
+
+Is converted to code that schematically looks like this:
+
+```
+ag__.converted_call(f, ..., (x, y), {'z': 1}, ...)
+```
+
+All calls are rewritten, including calls to other types of callables, builtin
+functions, etc.
+
+If the originally called function is not converted, AutoGraph simply
+forwards the call to it, so that the wrapper is functionally equivalent with
+the original function call.
+
+If the originally called function is converted, then the conversion is performed
+first and the converted function is called instead.
+
+Note: a caching mechanism prevents the same function from being converted
+multiple times. This mechanism ensures that functions calls made with different
+[global or free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)
+are handled correctly.
+
+#### Function conversion rules
+
+The following types of functions are not converted:
+
+  * functions already converted
+  * functions defined in in a whitelisted module (see autograph/core/config.py)
+  * non-Python functions (such as native bindings)
+  * `print`, `pdb.set_trace`, `ipdb.set_trace`
+  * most built-in functions (exceptions are listed in
+    autograph/operators/py_builtins.py)
+  * constructors
+  * functions without source code attached (prints a warning)(see
+    [limitations](limitations.md))
+  * generator functions (prints a warning)
+
+When AutoGraph encounters a function that it cannot convert outside of this
+list, it prints a warning.
+
+### Nested functions
+
+Functions nested inside a function converted by AutoGraph are converted
+at the same time as the function containing them. If the nested function is
+returned, a converted version of it is returned.
diff --git a/tensorflow/python/autograph/g3doc/reference/index.md b/tensorflow/python/autograph/g3doc/reference/index.md
index 1a12596..6fb7ab6 100644
--- a/tensorflow/python/autograph/g3doc/reference/index.md
+++ b/tensorflow/python/autograph/g3doc/reference/index.md
@@ -2,11 +2,21 @@
 
 This reference document describes the semantics of AutoGraph transformations.
 
+In `@tf.function`, AutoGraph allows running Eager-style code as a TensorFlow
+graph.
+
 *   [Introduction](intro.md)
 *   [Interacting with the generated code](generated_code.md)
 *   [Debugging AutoGraph code](debugging.md)
-*   Control Flow (coming soon)
-*   Collections (coming soon)
-*   Exceptions (coming soon)
-*   Builtin Functions (coming soon)
-*   Datasets (coming soon)
+*   [Control flow](control_flow.md)
+*   [Functions and function calls](functions.md)
+*   [Error handling](error_handling.md)
+*   [Limitations](limitations.md)
+*   [Common errors](common_errors.md)
+
+For more information on AutoGraph, see the following articles:
+
+*   [AutoGraph tutorial](https://www.tensorflow.org/alpha/beta/autograph)
+*   [Eager tutorial](https://www.tensorflow.org/alpha/guide/eager)
+*   [TensorFlow 2.0 Alpha](https://www.tensorflow.org/alpha)
+*   [AutoGraph blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
diff --git a/tensorflow/python/autograph/g3doc/reference/intro.md b/tensorflow/python/autograph/g3doc/reference/intro.md
index 1c720fd..1de0069 100644
--- a/tensorflow/python/autograph/g3doc/reference/intro.md
+++ b/tensorflow/python/autograph/g3doc/reference/intro.md
@@ -4,15 +4,6 @@
 
 ## Introduction
 
-This document describes the semantics of AutoGraph's code transformations.
-
-For more information on AutoGraph, see the following articles:
-
-*   [AutoGraph tutorial](https://www.tensorflow.org/alpha/guide/autograph)
-*   [Eager tutorial](https://www.tensorflow.org/alpha/guide/eager)
-*   [TensorFlow 2.0 Alpha](https://www.tensorflow.org/alpha)
-*   [AutoGraph blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
-
 ### Terminology
 
 Typically, AutoGraph operates by converting a function into a new function with
diff --git a/tensorflow/python/autograph/g3doc/reference/limitations.md b/tensorflow/python/autograph/g3doc/reference/limitations.md
new file mode 100644
index 0000000..ebfab4b
--- /dev/null
+++ b/tensorflow/python/autograph/g3doc/reference/limitations.md
@@ -0,0 +1,463 @@
+# AutoGraph reference
+
+[Index](index.md)
+
+## Limitations
+
+When AutoGraph is applied to normal Python code, you should expect no change
+in functionality.
+However, when applied to TensorFlow control flow (for example, an if statement
+with a `tf.Tensor` condition), there are certain limitations. This section
+describes these limitations and practices that will allow you to avoid them.
+
+Key Term: Python variables refer to Python symbols (or symbols for short) and
+should not be confused with TensorFlow variables.
+
+Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
+value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.
+
+### Indirect modifications and hidden side effects in TensorFlow control flow
+
+<!-- TODO(mdan) Refine this paragraph well - it's important -->
+Key Point: We recommend using functional style and immutable Python collections.
+
+#### AutoGraph analyzes code to detect modifications
+
+One of the most important functions of AutoGraph is to rewrite Python control
+flow statements into equivalent TensorFlow ops. This process requires "wiring"
+variables in the Python code whose values are affected these statements control
+flow into the respective ops.
+
+The examples below use a `while` loop, but the same notions extend to all
+control flow: `if` and `for` statements.
+
+In the example below, `x` needs to become a loop variable of the
+corresponding `tf.while_loop':
+
+```
+while x > 0:
+  x = x - 1
+```
+```
+x = tf.while_loop(..., loop_vars=(x,)
+```
+
+TF control ops support only a limited set of types for loop variable. At the
+same time, the efficiency of TensorFlow graphs is influenced by the number of
+loop variables, so we don't want to create them unnecessarily. For this reason,
+AutoGraph only pulls symbols through loop variables if necessary.
+
+Note: If a symbol refers to a nested structure, such as a `dict` of `dict`s,
+then when that symbol is added to the loop variables the entire structure
+becomes part of the loop variables - TensorFlow automatically unpacks it.
+
+For example, the symbol 'y' below is not wired through the `tf.while_loop`'s
+`loop_vars` because it is not affected by the while loop:
+
+```
+y = 0
+while x > 0:
+  x = x - 1
+print(y)
+```
+```
+x = tf.while_loop(..., loop_vars=(x,)  # y does not need to be a loop variable
+```
+
+AutoGraph uses static analysis to determine which symbols are modified by the
+code, in order to transform them into control flow variables. Static analysis
+is generally performed on single functions - Python's dynamic nature limits its
+effectiveness across functions.
+
+#### Modifications are not detected across functions
+
+Because static analysis is limited to single functions, modifications that are
+performed in other functions are not visible to AutoGraph:
+
+```
+def change_y():
+  global y
+  y = y + 1
+
+while x > 0:
+  change_y()  # Problem -- change made to y is not visible here!
+```
+
+This can be easily remedied using functional style - writing functions that take
+their inputs as arguments, and return everything they calculate as return
+values:
+
+```
+def change(y):
+  y = y + 1
+  return y
+
+while x > 0:
+  y = change(y)  # Okay -- y can now be properly tracked!
+```
+
+#### Modifications are not detected in methods
+
+A special case of hidden side effects are methods, which are commonly used
+to change the value of objects:
+
+```
+def MyClass(object):
+  def change(self):
+    self.y += 1
+
+c = MyClass()
+while x > 0:
+  c.change()  # Problem -- modification to c.y is not visible here!
+```
+
+This can be addressed in a number of ways.
+
+One possibility is to operate directly on the object properties:
+
+```
+c = MyClass()
+while x > 0:
+  c.y += 1  # Okay -- c.y can now be properly tracked!
+```
+
+Another possibility is to rely on immutable objects. This may lead to many
+temporary objects when executing eagerly, but their number is greatly reduced
+in `@tf.function`:
+
+```
+def MyClass(object):
+  def change(self):
+    self.y += 1
+    return self
+
+c = MyClass()
+while x > 0:
+  c = c.change()  # Okay -- c is now a loop var.
+```
+
+Note: TensorFlow control flow does not currently support arbitrary Python
+objects, but it does support basic collection objects such as `list`, `dict`,
+`tuple`, `namedtuple` and their subclasses. Design your objects as subclasses
+of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple).
+
+### Python collections in TensorFlow control flow
+
+Key Point: Use TensorFlow collection classes instead of Python collections.
+Python collections are okay to use when they represent a fixed structure (that
+is, `list`s don't change length, `dict`s don't add or remove keys).
+
+#### Modifying Python collections in TensorFlow control flow is not allowed
+
+One of the advantages of eager execution is that you may use the usual Python
+collections, like `list` or `dict` to hold `tf.Tensor` values. However, these
+are generally not compatible with TensorFlow control flow. Specialized
+collections like `tf.TensorArray` are required.
+
+Consider the following example:
+
+```
+def fn():
+  l = []
+
+  def loop_cond(i):
+    return i < 10
+
+  def loop_body(i):
+    i = i + 1
+    l.append(i)
+    return i,
+
+  tf.while_loop(
+      cond=loop_cond,
+      body=loop_body,
+      loop_vars=(0,))
+
+  return l
+```
+
+This code works in eager execution, which does not use the TensorFlow runtime
+for the `tf.while_loop`:
+
+```
+fn()
+```
+
+However, it does not work in graph execution, because TensorFlow uses special
+mechanisms to ensure the computations are correctly sequenced in the dataflow
+graph:
+
+```
+tf.function(fn)()  # Error -- illegal tensor capture!
+```
+
+The equivalent AutoGraph code raises the same error:
+
+```
+l = []
+for i in tf.range(10):
+  l.append(i)  # Error -- illegal tensor capture!
+```
+
+Instead, use the specialized `tf.TensorArray` class:
+
+```
+l = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
+for i in tf.range(10):
+  l = l.write(l.size(), i)  # Okay
+```
+
+#### Python collections of fixed structure are allowed TensorFlow control flow
+
+An exception from the previous rule is made by Python collections that are
+static, that is, they don't grow in size for the duration of the computation.
+
+Caution: Use functional style when manipulating static collections.
+
+Examples:
+
+```
+static_list = [tf.constant(3)]
+while d.prop > 0:
+  static_list[0] -= 1  # Okay -- static_list does not change structure
+```
+```
+static_object = MyClass()
+static_object.field = tf.constant(3)
+while static_object.field > 0:
+  static_object.field -= 1  # Okay -- static_object does not change structure
+```
+```
+static_dict = {'field': tf.constant(3)}
+while static_dict['field'] > 0:
+  static_dict['field'] -= 1  # Okay -- static_dict does not change structure
+```
+
+However, remember to use functional style when these collections are used
+inside control flow.
+
+#### Python collections of fixed structure with dynamic index
+
+A more subtle error occurs when the collection is static, but is accessed in a
+dynamic way, that is with a key that is not constant.
+
+For example:
+
+```
+d = {'a': tf.constant(3)}
+for i in tf.range(10):
+  for key in d:
+    d[key] += i  # Problem -- accessing `dict` using non-constant key
+```
+
+The code above will raises an "illegal capture" error. To remedy it, write it
+in functional style:
+
+```
+d = {'a': tf.constant(3)}
+for i in tf.range(10):
+  d = {key: value + i for key, value in d.items()}  # Okay
+```
+
+### Shape and dtype consistency in TensorFlow control flow
+
+Unlike Python, TensorFlow has limited support for dynamic typing. This means
+that tensors must maintain consistent shapes and dtypes across control flow
+paths.
+
+Note: In general, these restrictions do not apply in control flow in Eager
+execution, because Eager execution uses Python control flow, rather than
+TensorFlow control flow ops.
+
+#### Consistency of dtype
+
+The dtypes across all code paths must be consistent in conditionals and loops.
+
+For example, if a `tf.cond` (and correspondingly, an AutoGraph `if`) sets a
+tensor value conditionally, then that tensor must have the same shape and dtype
+in both branches of the conditional.
+
+Example of illegal dtype change in a conditional:
+
+```
+x = tf.cond(
+    tf.random.uniform(()) > 0.5,
+    lambda: tf.constant(1, dtype=tf.int32),
+    lambda: tf.constant(1, dtype=tf.float32))  # Error -- inconsistent dtypes: int32, float32
+```
+
+The same restriction in AutoGraph code:
+
+```
+if tf.random.uniform(()) > 0.5:
+  x = tf.constant(1, dtype=tf.int32)
+else:
+  x = tf.constant(1, dtype=tf.float32)  # Error -- inconsistent dtypes: int32, float32
+```
+
+Example of illegal dtype change in a loop:
+
+```
+# This won't work - "x" changes dtype inside the loop.
+x = tf.while_loop(
+    lambda _: tf.random.uniform(()) > 0.5,
+    lambda x: tf.constant(1, dtype=tf.float32),
+    loop_vars=(tf.constant(1, dtype=tf.int32),))  # Error -- inconsistent dtypes: int32, float32
+```
+
+The same restriction in AutoGraph code:
+
+```
+x = tf.constant(0, dtype=tf.int32)
+while tf.random.uniform(()) > 0.5:
+  x = tf.constant(0, dtype=tf.float32)   # Error -- inconsistent dtypes: int32, float32
+```
+
+#### Consistency of shape
+
+The shapes across all code paths must be consistent in loops only. When tensors
+do need to change shape across iterations, use `shape_invariants`.
+
+Note: Shapes are allowed to be inconsistent in conditionals. The result will be
+a partially dynamic shape.
+
+In a `tf.while_loop` (and correspondingly, an AutoGraph `while` or `for` loop)
+all loop variables must maintain consistent shape and dtype across iterations.
+That is, every loop variable must have the same shape at the end of the loop
+body as the shape that it had at the beginning of the loop body.
+
+Example of illegal shape change in a loop:
+
+```
+def loop_body(x):  # x.shape is ()
+  return tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)
+
+x = tf.while_loop(
+    lambda _: tf.random.uniform(()) > 0.5,
+    loop_body,
+    loop_vars=(tf.constant(1,))
+```
+
+The same restriction in AutoGraph code:
+
+```
+x = tf.constant(0, dtype=tf.int32)
+while tf.random.uniform(()) > 0.5:
+  x = tf.constant(0, dtype=tf.float32)  # Error -- inconsistent shapes: (), (3,)
+```
+
+### Undefined and None values in TensorFlow
+
+TensorFlow does not support undefined and `None` values. All tensors must have
+a value.
+
+Example:
+
+```
+x = tf.cond(
+    tf.random.uniform(()) > 0.5,
+    lambda: tf.constant(1),
+    lambda: None)  # Error -- a Tensor cannot be None
+```
+
+The same restriction carries over in AutoGraph, but only if the symbol is used
+after the conditional (otherwise AutoGraph avoids making it a return value
+of the `tf.cond`):
+
+```
+if tf.random.uniform(()) > 0.5:
+  x = tf.constant(1)
+else:
+  x = None
+tf.print(x)  # Error -- x may be None here
+```
+
+A related but less obvious restriction in AutoGraph forbids symbols to be
+defined in only one branch of TensorFlow control flow, if the symbol is
+used afterwards:
+
+```
+del x
+if tf.random.uniform(()) > 0.5:
+  x = tf.constant(1)
+else:
+  pass
+tf.print(x)  # Error -- x may be undefined here
+```
+
+Similarly, variables defined in a loop may not be used outside the loop, again
+if the symbol is used afterwards:
+
+```
+del x
+if tf.random.uniform(()) > 0.5:
+  x = tf.constant(1)
+tf.print(x)  # Error -- x may be undefined here
+```
+
+Avoid these limitations by defining a default value before the control flow
+statement:
+
+```
+x = tf.constant()
+if tf.random.uniform(()) > 0.5:
+  x = tf.constant(1)
+tf.print(x)  # Okay -- x is either 0 or 1
+```
+
+Note: `None` values and undefined symbols are allowed in Eager control flow,
+because Eager execution uses Python control flow, rather than TensorFlow
+control flow ops.
+
+### Access to source code
+
+Key point: AutoGraph can only handle functions whose source code can be
+accessed at runtime.
+
+Almost all Python functions allow access to their source code. However, a few
+exceptions exist:
+
+ * functions created in the Python interactive shell
+ * functions with native bindings (these do not have Python source code)
+ * functions created dynamically, using `exec` or `eval`
+
+Use
+[inspect.getsource](https://docs.python.org/3/library/inspect.html#inspect.getsource)
+to quickly diagnose whether the source code is available for a function.
+
+#### Source code of lambda functions
+
+Key Point: Declare lambda functions on separate lines to avoid failures to
+load their source code.
+
+The Python runtime exposes the source code of lambda functions, however it
+may include surrounding code. Typically, the code includes all the lines that
+contained the lambda function, including surrounding code. This may make it
+impossible to parse the exact source code of the lambda function.
+
+For example, consider the declaration of a lambda function below, which
+is otherwise valid Python code:
+
+```
+foo = (
+ 'bar',
+ lambda: x)
+```
+
+The Python runtime will report the following source code for `foo[0]`:
+
+```
+>>> inspect.getsource(foo[0])
+' lambda: x)\n'
+```
+
+The code is the entire line of code at which the lambda was declared. Because
+the line is part of a larger expression, the line itself is not syntactically
+correct and cannot be parsed.
+
+This shortcoming can be avoided by declaring the lambda function separately:
+
+```
+my_lambda = lambda: x
+foo = ('bar', my_lambda)
+```
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index d850937..283e294 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -28,7 +28,6 @@
 import sys
 import textwrap
 import traceback
-from enum import Enum
 
 # pylint:disable=g-bad-import-order
 import six
@@ -42,7 +41,6 @@
 from tensorflow.python.autograph.pyct import inspect_utils
 from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.autograph.utils import ag_logging as logging
-from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -73,37 +71,41 @@
 class _ErrorMetadata(errors.ErrorMetadataBase):
   """AutoGraph-specific error metadata. See base class."""
 
-  def create_exception(self, preferred_type):
-    if preferred_type == errors_impl.OpError:
+  def create_exception(self, source_error):
+    preferred_type = type(source_error)
+    if issubclass(preferred_type, errors_impl.OpError):
       # Best-effort unpacking of OpError exceptions.
       # TODO(mdan): Use a mechanism that is more future-proof.
-      t = type(self.cause)
-      init_argspec = tf_inspect.getfullargspec(t.__init__)
+      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
       message = self.get_message()
-      init_args = tuple(init_argspec.argspec)
+      init_args = tuple(init_argspec.args)
       # At the time of this writing, TF errors either take 3 or 4 arguments,
       # with the fourth being error_code.
       if init_args == ('self', 'node_def', 'op', 'message', 'error_code'):
-        return t(
-            node_def=self.cause.node_def,
-            op=self.cause.op,
+        return preferred_type(
+            node_def=source_error.node_def,
+            op=source_error.op,
             message=message,
             error_code=self.error_code)
       elif init_args == ('self', 'node_def', 'op', 'message'):
         if 'error_code' in init_argspec.kwonlyargs:
-          return t(
-              node_def=self.cause.node_def,
-              op=self.cause.op,
+          return preferred_type(
+              node_def=source_error.node_def,
+              op=source_error.op,
               message=message,
               errro_code=self.error_code)
         else:
-          return t(
-              node_def=self.cause.node_def, op=self.cause.op, message=message)
+          return preferred_type(
+              node_def=source_error.node_def,
+              op=source_error.op,
+              message=message)
 
-    elif preferred_type in (AutoGraphError, ConversionError, StagingError):
+    elif preferred_type in (AutoGraphError, ConversionError, StagingError,
+                            errors_impl.InaccessibleTensorError,
+                            errors_impl.OperatorNotAllowedInGraphError):
       return preferred_type(self.get_message())
 
-    exc = super(_ErrorMetadata, self).create_exception(preferred_type)
+    exc = super(_ErrorMetadata, self).create_exception(source_error)
     if exc is not None:
       return exc
 
@@ -121,16 +123,33 @@
   def __init__(self, converted_fn):
     self._source_map = converted_fn.ag_source_map
 
-  def map(self, filename, lineno, name):
-    loc = origin_info.LineLocation(filename=filename, lineno=lineno)
-    if loc not in self._source_map:
-      return filename, lineno, name
+  def get_effective_source_map(self):
+    effective_source_map = self._effective_source_map
+    if effective_source_map is None:
+      if self.parent is not None:
+        parent_map = self.parent.get_effective_source_map()
+      else:
+        parent_map = {}
 
-    origin = self._source_map[loc]
-    return origin.loc.filename, origin.loc.lineno, origin.function_name
+      effective_source_map = {}
+      for loc, origin in self._source_map.items():
+        effective_source_map[(loc.filename, loc.lineno)] = (
+            origin.loc.filename, origin.loc.lineno, origin.function_name)
+
+      for key, value in parent_map.items():
+        filename, lineno, _ = value
+        value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
+        if value_loc in self._source_map:
+          origin = self._source_map[value_loc]
+          effective_source_map[key] = (
+              origin.loc.filename, origin.loc.lineno, origin.function_name)
+        else:
+          effective_source_map[key] = value
+      self._effective_source_map = effective_source_map
+    return effective_source_map
 
 
-def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
+def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
   """Decorator that applies AutoGraph to a function.
 
   Use in internal APIs.
@@ -147,8 +166,8 @@
     ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
     convert_by_default: bool, whether to use AutoGraph when the context doesn't
       specify.
-    force_conversion: bool, whether to ignore the conversion whitelist. See
-      ConversionOptions.force_conversion.
+    user_requested: bool, whether to ignore the conversion whitelist. See
+      ConversionOptions.user_requested.
 
   Returns:
     Either `f or the converted version of `f`.
@@ -161,12 +180,12 @@
 
   # TODO(mdan): Grab features from context.
   if ctx.status == ag_ctx.Status.ENABLED:
-    wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
+    wrapper = convert(recursive=True, user_requested=user_requested)(f)
   elif ctx.status == ag_ctx.Status.DISABLED:
     wrapper = do_not_convert(f)
   elif ctx.status == ag_ctx.Status.UNSPECIFIED:
     if convert_by_default:
-      wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
+      wrapper = convert(recursive=True, user_requested=user_requested)(f)
     else:
       wrapper = call_with_unspecified_conversion_status(f)
   else:
@@ -180,7 +199,7 @@
 
 
 # TODO(mdan): Make private.
-def convert(recursive=False, optional_features=None, force_conversion=True):
+def convert(recursive=False, optional_features=None, user_requested=True):
   """Decorator that compiles a function to use TensorFlow ops.
 
   The decorator is dynamic - it recompiles the target whenever the decorated
@@ -194,8 +213,8 @@
     optional_features: converted.Feature, allows toggling optional or
       experimental features. When set to None, only the core features are
       enabled.
-    force_conversion: bool, whether to ignore the conversion whitelist. See
-      ConversionOptions.force_conversion.
+    user_requested: bool, whether to ignore the conversion whitelist. See
+      ConversionOptions.user_requested.
 
   Returns:
     Callable, a decorator that converts the given function into an equivalent
@@ -207,21 +226,17 @@
 
     def wrapper(*args, **kwargs):
       """Wrapper that calls the converted version of f."""
-      with ag_ctx.ControlStatusCtx(
-          status=ag_ctx.Status.ENABLED, options=optional_features):
-        try:
-          return converted_call(
-              f,
-              converter.ConversionOptions(
-                  recursive=recursive,
-                  force_conversion=force_conversion,
-                  optional_features=optional_features,
-              ), args, kwargs)
-        except Exception as e:  # pylint:disable=broad-except
-          if hasattr(e, 'ag_error_metadata'):
-            raise e.ag_error_metadata.to_exception(type(e))
-          else:
-            raise
+      options = converter.ConversionOptions(
+          recursive=recursive,
+          user_requested=user_requested,
+          optional_features=optional_features)
+      try:
+        return converted_call(f, options, args, kwargs)
+      except Exception as e:  # pylint:disable=broad-except
+        if hasattr(e, 'ag_error_metadata'):
+          raise e.ag_error_metadata.to_exception(e)
+        else:
+          raise
 
     if inspect.isfunction(f) or inspect.ismethod(f):
       wrapper = functools.update_wrapper(wrapper, f)
@@ -236,20 +251,6 @@
   return decorator
 
 
-class RunMode(Enum):
-  """Specifies the way a converted function or method should be executed in TF.
-
-  Attributes:
-   * GRAPH: Call this function directly, as-is. This is suitable for functions
-     that were already designed for TF graphs and contain ops.
-   * PY_FUNC: Wrap this function into a py_func op. This is suitable for code
-     that will only run correctly in Python, for example code that renders to
-     the display, reads keyboard input, etc.
-  """
-  GRAPH = 1
-  PY_FUNC = 2
-
-
 def call_with_unspecified_conversion_status(func):
   """Decorator that resets the conversion context to the unspecified status."""
   def wrapper(*args, **kwargs):
@@ -270,18 +271,11 @@
 
 
 @tf_export('autograph.experimental.do_not_convert')
-def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
+def do_not_convert(func=None):
   """Decorator that suppresses the conversion of a function.
 
-  See also: docs/pyfunc_dtypes.md
-
   Args:
     func: function to decorate.
-    run_as: RunMode, specifies how to use the function in TensorFlow.
-    return_dtypes: Optional[Iterable[ Union[tf.DType,
-      utils.py_func.MatchDType]]], the return data types of the converted
-      function, if run_as is RunMode.PY_FUNC. Ignored otherwise. May be set to
-      None if the function has no return values.
 
   Returns:
     If `func` is not None, returns a `Callable` which is equivalent to
@@ -291,29 +285,12 @@
     above case.
   """
   if func is None:
-    return functools.partial(
-        do_not_convert,
-        run_as=run_as,
-        return_dtypes=return_dtypes)
+    return do_not_convert
 
-  def graph_wrapper(*args, **kwargs):
+  def wrapper(*args, **kwargs):
     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
       return func(*args, **kwargs)
 
-  def py_func_wrapper(*args, **kwargs):
-    if kwargs:
-      raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
-    # TODO(mdan): Add support for kwargs.
-    return py_func.wrap_py_func(
-        func, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
-
-  if run_as == RunMode.GRAPH:
-    wrapper = graph_wrapper
-  elif run_as == RunMode.PY_FUNC:
-    wrapper = py_func_wrapper
-  else:
-    raise ValueError('unknown value for run_as: %s' % run_as)
-
   if inspect.isfunction(func) or inspect.ismethod(func):
     wrapper = functools.update_wrapper(wrapper, func)
 
@@ -385,8 +362,23 @@
   return False
 
 
-def converted_call(f, options, args, kwargs):
-  """Compiles a function call inline. For internal use only."""
+def converted_call(f, options, args, kwargs, caller_fn_scope=None):
+  """Compiles a function call inline.
+
+  For internal use only.
+
+  Args:
+    f: The function to convert.
+    options: converter.ConversionOptions
+    args: Tuple, the original positional arguments of f
+    kwargs: Dict, the original keyword arguments of f
+    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
+      scope of the converted function in which this call was originally made.
+
+  Returns:
+    Any, the result of executing a possibly-converted `f` with the given
+      arguments.
+  """
   logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
               kwargs)
 
@@ -395,7 +387,9 @@
 
   if inspect_utils.isbuiltin(f):
     if f is eval:
-      return py_builtins.eval_in_original_context(f, args, 1)
+      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
+    if f is super:
+      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
     if kwargs:
       return py_builtins.overload_of(f)(*args, **kwargs)
     else:
@@ -441,7 +435,7 @@
     logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
     return _call_unconverted(f, args, kwargs, options)
 
-  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
+  if not options.user_requested and conversion.is_whitelisted_for_graph(f):
     return _call_unconverted(f, args, kwargs, options)
 
   # internal_convert_user_code is for example turned off when issuing a dynamic
@@ -507,10 +501,9 @@
                     target_entity)
         return _call_unconverted(f, args, kwargs, options)
 
-    converted_f = to_graph(
-        target_entity,
-        recursive=options.recursive,
-        experimental_optional_features=options.optional_features)
+    program_ctx = converter.ProgramContext(
+        options=options, autograph_module=tf_inspect.getmodule(converted_call))
+    converted_f = conversion.convert(target_entity, program_ctx)
 
     if logging.has_verbosity(2):
       logging.log(2, 'Defaults of %s : %s', converted_f,
@@ -615,6 +608,7 @@
     program_ctx = converter.ProgramContext(
         options=converter.ConversionOptions(
             recursive=recursive,
+            user_requested=True,
             optional_features=experimental_optional_features),
         autograph_module=tf_inspect.getmodule(to_graph))
     return conversion.convert(entity, program_ctx)
diff --git a/tensorflow/python/autograph/impl/api_py3_test.py b/tensorflow/python/autograph/impl/api_py3_test.py
index 951b313..d1ae215 100644
--- a/tensorflow/python/autograph/impl/api_py3_test.py
+++ b/tensorflow/python/autograph/impl/api_py3_test.py
@@ -38,6 +38,27 @@
                            (), {'a': constant_op.constant(-1)})
     self.assertEqual(-1, self.evaluate(x))
 
+  def test_super_with_no_arg(self):
+    test_case_self = self
+
+    class TestBase:
+
+      def plus_three(self, x):
+        return x + 3
+
+    class TestSubclass(TestBase):
+
+      def plus_three(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def no_arg(self, x):
+        return super().plus_three(x)
+
+    tc = api.converted_call(TestSubclass,
+                            converter.ConversionOptions(recursive=True), (), {})
+
+    self.assertEqual(5, tc.no_arg(2))
+
 
 if __name__ == '__main__':
   os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 43330f7..5a969c8 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -35,7 +35,6 @@
 from tensorflow.python.autograph.impl import api
 from tensorflow.python.autograph.pyct import inspect_utils
 from tensorflow.python.autograph.pyct import parser
-from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
@@ -108,11 +107,11 @@
       self.assertListEqual([0, 1], self.evaluate(x).tolist())
 
   @test_util.run_deprecated_v1
-  def test_convert_then_do_not_convert_graph(self):
+  def test_convert_then_do_not_convert(self):
 
     class TestClass(object):
 
-      @api.do_not_convert(run_as=api.RunMode.GRAPH)
+      @api.do_not_convert
       def called_member(self, a):
         return tf.negative(a)
 
@@ -129,32 +128,6 @@
     self.assertAllEqual((0, 1), self.evaluate(x))
 
   @test_util.run_deprecated_v1
-  def test_convert_then_do_not_convert_py_func(self):
-
-    class TestClass(object):
-
-      @api.do_not_convert(
-          run_as=api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1))
-      def called_member(self, a):
-        return np.negative(a)
-
-      @api.convert(recursive=True)
-      def test_method(self, x, s, a):
-        while tf.reduce_sum(x) > s:
-          y = self.called_member(a)
-          # set_shape works around while_loop's limitations.
-          # TODO(mdan): Allow specifying shapes (or ShapeLike) instead.
-          y.set_shape(a.shape)
-          x //= y
-        return x
-
-    tc = TestClass()
-    x = tc.test_method(
-        constant_op.constant((2, 4)), constant_op.constant(1),
-        constant_op.constant(-2))
-    self.assertAllEqual((0, 1), self.evaluate(x))
-
-  @test_util.run_deprecated_v1
   def test_decorator_calls_decorated(self):
 
     class TestClass(object):
@@ -456,8 +429,7 @@
     # tc is still a TestClass - constructors are whitelisted.
     # TODO(b/124016764): Support this use case.
     # The error below is specific to the `if` statement not being converted.
-    with self.assertRaisesRegex(TypeError,
-                                'Using a `tf.Tensor` as a Python `bool`'):
+    with self.assertRaises(TypeError):
       tc.test_method()
 
   def test_converted_call_mangled_properties(self):
@@ -518,7 +490,7 @@
       return x + 1
 
     x = api.converted_call(
-        f, converter.ConversionOptions(recursive=True, force_conversion=True),
+        f, converter.ConversionOptions(recursive=True, user_requested=True),
         (constant_op.constant(0),), {})
     self.assertTrue(self.evaluate(x))
 
@@ -538,8 +510,7 @@
     opts = converter.ConversionOptions(internal_convert_user_code=False)
 
     # f should not be converted, causing len to error out.
-    with self.assertRaisesRegexp(Exception,
-                                 'object of type \'Tensor\' has no len()'):
+    with self.assertRaisesRegexp(Exception, 'len is not well defined'):
       api.converted_call(f, opts, (constant_op.constant([0]),), {})
 
     # len on the other hand should work fine.
@@ -593,7 +564,7 @@
 
     # TODO(mdan): Add the missing level of support to LOGICAL_EXPRESSIONS.
     opts = converter.ConversionOptions(
-        force_conversion=True, optional_features=None)
+        user_requested=True, optional_features=None)
 
     x = api.converted_call(gen_math_ops.add, opts, (1, 1), {})
 
@@ -631,6 +602,53 @@
 
     self.assertTrue(inspect_utils.isnamedtuple(x))
 
+  def test_converted_call_namedtuple_subclass_bound_method(self):
+
+    class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
+
+      def test_method(self, x):
+        while tf.reduce_sum(x) > self.a:
+          x //= self.b
+        return x
+
+    opts = converter.ConversionOptions(recursive=True)
+
+    obj = TestClass(5, 2)
+    x = api.converted_call(obj.test_method, opts,
+                           (constant_op.constant([2, 4]),), {})
+
+    self.assertAllEqual(self.evaluate(x), [1, 2])
+
+  def test_converted_call_namedtuple_method(self):
+
+    class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
+      pass
+
+    opts = converter.ConversionOptions(recursive=True)
+
+    obj = TestClass(5, 2)
+    # _asdict is a documented method of namedtuple.
+    x = api.converted_call(obj._asdict, opts, (), {})
+
+    self.assertDictEqual(x, {'a': 5, 'b': 2})
+
+  def test_converted_call_namedtuple_subclass_unbound_method(self):
+
+    class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
+
+      def test_method(self, x):
+        while tf.reduce_sum(x) > self.a:
+          x //= self.b
+        return x
+
+    opts = converter.ConversionOptions(recursive=True)
+
+    obj = TestClass(5, 2)
+    x = api.converted_call(TestClass.test_method, opts,
+                           (obj, constant_op.constant([2, 4])), {})
+
+    self.assertAllEqual(self.evaluate(x), [1, 2])
+
   def test_converted_call_lambda(self):
 
     opts = converter.ConversionOptions(recursive=True)
@@ -674,7 +692,7 @@
     def f():
       return dataset_ops.Dataset.range(-3, 3).map(other_fn)
 
-    # Dataset iteration only works inside tf.function.
+    # Dataset iteration only works inside tf.
     @def_function.function
     def graph_fn():
       opts = converter.ConversionOptions(recursive=True)
@@ -851,13 +869,10 @@
 
     self.assertNotEqual(converted_recursive.ag_module,
                         converted_non_recursive.ag_module)
-    self.assertIn('ag__.STD', tf_inspect.getsource(converted_recursive))
-    self.assertNotIn('internal_convert_user_code=False',
-                     tf_inspect.getsource(converted_recursive))
-    self.assertIn('internal_convert_user_code=False',
-                  tf_inspect.getsource(converted_non_recursive))
-    self.assertNotIn('internal_convert_user_code=True',
-                     tf_inspect.getsource(converted_non_recursive))
+    self.assertRegex(tf_inspect.getsource(converted_recursive),
+                     'FunctionScope(.*recursive=True.*)')
+    self.assertRegex(tf_inspect.getsource(converted_non_recursive),
+                     'FunctionScope(.*recursive=False.*)')
 
   def test_to_graph_preserves_bindings(self):
     y = 3
@@ -880,6 +895,22 @@
 
     self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
 
+  def test_to_graph_sets_conversion_context(self):
+
+    def g():
+      self.assertEqual(ag_ctx.control_status_ctx().status,
+                       ag_ctx.Status.ENABLED)
+      return 0
+
+    # Note: the autograph=False sets the contect to Status.DISABLED. The test
+    # verifies that to_graph overrides that.
+    @def_function.function(autograph=False)
+    def f():
+      converted_g = api.to_graph(g)
+      converted_g()
+
+    f()
+
   def test_to_code_basic(self):
 
     def test_fn(x, s):
@@ -948,7 +979,7 @@
 
     decorated_f = tf_decorator.make_decorator(f, wrapper)
 
-    # Note: the autograph setting of tf.function has nothing to do with the
+    # Note: the autograph setting of tf has nothing to do with the
     # test case. We just disable it to avoid confusion.
     @def_function.function(autograph=False)
     def test_fn(ctx):
@@ -965,6 +996,50 @@
       # The code in `f` is only valid with AutoGraph.
       test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
 
+  def test_super_with_one_arg(self):
+    test_case_self = self
+
+    class TestBase(object):
+
+      def plus_three(self, x):
+        return x + 3
+
+    class TestSubclass(TestBase):
+
+      def plus_three(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def one_arg(self, x):
+        test_base_unbound = super(TestSubclass)
+        test_base = test_base_unbound.__get__(self, TestSubclass)
+        return test_base.plus_three(x)
+
+    tc = api.converted_call(TestSubclass,
+                            converter.ConversionOptions(recursive=True), (), {})
+
+    self.assertEqual(5, tc.one_arg(2))
+
+  def test_super_with_two_args(self):
+    test_case_self = self
+
+    class TestBase(object):
+
+      def plus_three(self, x):
+        return x + 3
+
+    class TestSubclass(TestBase):
+
+      def plus_three(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def two_args(self, x):
+        return super(TestSubclass, self).plus_three(x)
+
+    tc = api.converted_call(TestSubclass,
+                            converter.ConversionOptions(recursive=True), (), {})
+
+    self.assertEqual(5, tc.two_args(2))
+
 
 if __name__ == '__main__':
   os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index b97c7e5..a027572 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -43,11 +43,10 @@
 from tensorflow.python.autograph.converters import lists
 from tensorflow.python.autograph.converters import logical_expressions
 from tensorflow.python.autograph.converters import return_statements
-from tensorflow.python.autograph.converters import side_effect_guards
 from tensorflow.python.autograph.converters import slices
 from tensorflow.python.autograph.core import config
 from tensorflow.python.autograph.core import converter
-from tensorflow.python.autograph.core import function_wrapping
+from tensorflow.python.autograph.core import function_wrappers
 from tensorflow.python.autograph.core import naming
 from tensorflow.python.autograph.core import unsupported_features_checker
 from tensorflow.python.autograph.lang import special_functions
@@ -325,7 +324,9 @@
   return _instantiate(entity, converted_entity_info, free_nonglobal_var_names)
 
 
-def is_whitelisted_for_graph(o, check_call_override=True):
+# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
+def is_whitelisted_for_graph(
+    o, check_call_override=True, allow_namedtuple_subclass=False):
   """Checks whether an entity is whitelisted for use in graph mode.
 
   Examples of whitelisted entities include all members of the tensorflow
@@ -336,6 +337,8 @@
     check_call_override: Reserved for internal use. When set to `False`, it
       disables the rule according to which classes are whitelisted if their
       __call__ method is whitelisted.
+    allow_namedtuple_subclass: Reserved for internal use. When `True`,
+      namedtuple subclasses are not whitelisted.
 
   Returns:
     Boolean
@@ -399,7 +402,10 @@
         return True
 
       owner_class = inspect_utils.getdefiningclass(o, owner_class)
-      if is_whitelisted_for_graph(owner_class, check_call_override=False):
+      if is_whitelisted_for_graph(
+          owner_class,
+          check_call_override=False,
+          allow_namedtuple_subclass=True):
         logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                     owner_class)
         return True
@@ -408,8 +414,13 @@
     # Due to the way they're constructed, namedtuple types cannot be converted
     # because they don't expose source code. But we assume they are safe for
     # graph mode since they are just containers.
-    logging.log(2, 'Whitelisted: %s: named tuple', o)
-    return True
+    if allow_namedtuple_subclass:
+      if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
+        logging.log(2, 'Whitelisted: %s: named tuple', o)
+        return True
+    else:
+      logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
+      return True
 
   logging.log(2, 'Not whitelisted: %s: default rule', o)
   return False
@@ -601,7 +612,8 @@
     ag_internal.STD = converter.STANDARD_OPTIONS
     ag_internal.Feature = converter.Feature
     ag_internal.utils = utils
-    ag_internal.function_scope = function_wrapping.function_scope
+    ag_internal.FunctionScope = function_wrappers.FunctionScope
+    ag_internal.with_function_scope = function_wrappers.with_function_scope
     # TODO(mdan): Add safeguards against name clashes.
     # We don't want to create a submodule because we want the operators to be
     # accessible as ag__.<operator>
@@ -641,24 +653,27 @@
   _add_self_references(namespace, program_ctx.autograph_module)
   namer = naming.Namer(namespace)
 
+  if isinstance(node, gast.Lambda):
+    new_name = namer.new_symbol('tf__lambda', ())
+  elif do_rename:
+    new_name = namer.function_name(f.__name__)
+  else:
+    new_name = f.__name__
+
   entity_info = transformer.EntityInfo(
       source_code=source,
       source_file='<fragment>',
       future_features=future_features,
       namespace=namespace)
-  context = converter.EntityContext(namer, entity_info, program_ctx)
+  context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
   node = node_to_graph(node, context)
 
   if isinstance(node, gast.Lambda):
-    new_name = namer.new_symbol('tf__lambda', ())
     node = gast.Assign(
         targets=[gast.Name(new_name, gast.Store(), None)], value=node)
-
   elif do_rename:
-    new_name = namer.function_name(f.__name__)
     node.name = new_name
   else:
-    new_name = f.__name__
     assert node.name == new_name
 
   return (node,), new_name, entity_info
@@ -681,6 +696,7 @@
   unsupported_features_checker.verify(node)
 
   node = converter.standard_analysis(node, context, is_initial=True)
+  node = converter.apply_(node, context, function_scopes)
   node = converter.apply_(node, context, arg_defaults)
   node = converter.apply_(node, context, directives)
   node = converter.apply_(node, context, break_statements)
@@ -698,9 +714,4 @@
   node = converter.apply_(node, context, control_flow)
   node = converter.apply_(node, context, conditional_expressions)
   node = converter.apply_(node, context, logical_expressions)
-  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
-    node = converter.apply_(node, context, side_effect_guards)
-  # TODO(mdan): If function scopes ever does more, the toggle will need moving.
-  if context.program.options.uses(converter.Feature.NAME_SCOPES):
-    node = converter.apply_(node, context, function_scopes)
   return node
diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index 43f1253..25fefbd 100644
--- a/tensorflow/python/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -37,6 +37,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:list_ops",
@@ -46,6 +47,7 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/autograph/utils",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
     ],
 )
 
@@ -100,6 +102,24 @@
     deps = [
         ":operators",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core",
+    ],
+)
+
+py_test(
+    name = "py_builtins_py3_test",
+    srcs = ["py_builtins_py3_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    tags = [
+        "no_windows",
+        # TODO(kkimlabs): Temporay workaround since KokoroPresubmit was failing.
+        #                 cl/259400943 for more context.
+        "no_oss_py2",
+    ],
+    deps = [
+        ":operators",
+        "//tensorflow/python:client_testlib",
     ],
 )
 
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 9e179f5..3f0f53f 100644
--- a/tensorflow/python/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -59,6 +59,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import numpy as np
+
 from tensorflow.python.autograph.operators import py_builtins
 from tensorflow.python.autograph.operators import special_values
 from tensorflow.python.autograph.utils import ag_logging
@@ -74,8 +77,10 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.util import nest
 
 LIMIT_PYTHON_ITERATIONS = True
 PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
@@ -83,7 +88,6 @@
 INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000
 INEFFICIENT_UNROLL_MIN_OPS = 1
 
-
 def _disallow_undefs_into_loop(*values):
   """Ensures that all values in the state are defined when entering a loop."""
   undefined = tuple(filter(special_values.is_undefined, values))
@@ -98,10 +102,173 @@
       # return value if the loop contained a return statement.
       # TODO(mdan): This should be checked at the place where return occurs.
       raise ValueError(
-          'Return statements are not supported within a TensorFlow loop.')
+          'return statements are not supported within a TensorFlow loop.')
 
 
-def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
+def _shape_greater_than_or_equal(shape1, shape2):
+  """Check whether the shape2 is equal or more specific than shape1."""
+
+  # The following logic was mirrored from control_flow_ops.py's
+  # _ShapeLessThanOrEqual function.
+  if shape1.dims is None:
+    return True
+  if shape1.ndims != shape2.ndims:
+    return False
+  for dim1, dim2 in zip(shape1.dims, shape2.dims):
+    if dim1.value is not None and dim1.value != dim2.value:
+      return False
+  return True
+
+
+def _verify_tf_loop_vars(init_loop_vars,
+                         first_iter_vars,
+                         basic_symbol_names,
+                         composite_symbol_names,
+                         include_shapes=True):
+  """Verifies loop variables for consistency."""
+
+  # The whole point of _verify_tf_loop_vars is to give more useful error message
+  # than tf-level exception by including variable names.  If it's not available,
+  # there is no point at performing this verification here.  As of 2019-07-31,
+  # operators:control_flow_test does not pass the names.
+  if basic_symbol_names is None:
+    return
+
+  output_symbol_names = basic_symbol_names + composite_symbol_names
+
+  assert len(init_loop_vars) == len(first_iter_vars) == len(output_symbol_names)
+
+  for init_loop_var, first_iter_var, name in zip(init_loop_vars,
+                                                 first_iter_vars,
+                                                 output_symbol_names):
+
+    try:
+      nest.assert_same_structure(
+          init_loop_var, first_iter_var, expand_composites=True)
+    except (ValueError, TypeError) as e:
+      raise TypeError('"{}" does not have the same nested structure after one'
+                      ' iteration.\n\n{}'.format(name, e))
+
+    def _check_same_type(name, init_loop_var, first_iter_var):
+      """Ensures init_loop_var and first_iter_var are consistent."""
+      if isinstance(init_loop_var, (bool, int, float, str)):
+        init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
+
+      if isinstance(first_iter_var, (bool, int, float, str)):
+        first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
+
+      if (not tensor_util.is_tensor(init_loop_var) or
+          not tensor_util.is_tensor(first_iter_var)):
+        return
+
+      # TODO(mdan): Properly account for CompositeTensors.
+      if (not hasattr(init_loop_var, 'dtype') or
+          not hasattr(first_iter_var, 'dtype')):
+        return
+      if (not hasattr(init_loop_var, 'shape') or
+          not hasattr(first_iter_var, 'shape')):
+        return
+
+      if init_loop_var.dtype != first_iter_var.dtype:
+        raise TypeError(
+            '"{}" has dtype {} before the loop, but dtype {} after one'
+            ' iteration. TensorFlow control flow requires it stays the'
+            ' same.'.format(
+                name,
+                init_loop_var.dtype.name,
+                first_iter_var.dtype.name,
+            ))
+
+      if include_shapes:
+        init_shape = init_loop_var.shape
+        first_iter_shape = first_iter_var.shape
+        # TODO(b/135183013): Update needed once we support shape_invariants.
+        if not _shape_greater_than_or_equal(init_shape, first_iter_shape):
+          raise ValueError(
+              '"{}" has shape {} before the loop, but shape {} after one'
+              ' iteration. TensorFlow control flow requires it stays the'
+              ' same or be more specific.'.format(name, init_shape,
+                                                  first_iter_shape))
+
+    nest.map_structure(
+        functools.partial(_check_same_type, name), init_loop_var,
+        first_iter_var)
+
+
+def _verify_tf_cond_vars(body_outputs, orelse_outputs, basic_symbol_names,
+                         composite_symbol_names):
+  """Verifies variables manipulated by a conditional for consistency."""
+
+  # The whole point of _verify_tf_cond_vars is to give more useful error message
+  # than tf-level exception by including variable names.  If it's not available,
+  # there is no point at performing this verification here.  As of 2019-07-31,
+  # conditional expression does not pass the names.
+  if basic_symbol_names is None:
+    return
+
+  output_symbol_names = basic_symbol_names + composite_symbol_names
+
+  basic_body_outputs, composite_body_outputs = body_outputs
+  basic_orelse_outputs, composite_orelse_outputs = orelse_outputs
+  assert isinstance(composite_body_outputs, tuple)
+  assert isinstance(composite_orelse_outputs, tuple)
+
+  # TODO(kkimlabs): Make this more consistent.
+  # The basic outputs should always be a tuple.
+  if not isinstance(basic_body_outputs, tuple):
+    basic_body_outputs = (basic_body_outputs,)
+  if not isinstance(basic_orelse_outputs, tuple):
+    basic_orelse_outputs = (basic_orelse_outputs,)
+
+  body_outputs = basic_body_outputs + composite_body_outputs
+  orelse_outputs = basic_orelse_outputs + composite_orelse_outputs
+
+  for body_output, orelse_output, name in zip(body_outputs, orelse_outputs,
+                                              output_symbol_names):
+    try:
+      nest.assert_same_structure(
+          body_output, orelse_output, expand_composites=True)
+    except (ValueError, TypeError) as e:
+      raise TypeError(
+          '"{}" does not have the same nested structure in the TRUE and FALSE'
+          ' branches.\n\n{}'.format(name, str(e)))
+
+    def _check_same_type(name, body_output_var, orelse_output_var):
+      """Verfies that body_output_var and orelse_output_var have same dtype."""
+      if isinstance(body_output_var, (bool, int, float, str)):
+        body_output_var = ops.convert_to_tensor_v2(body_output_var)
+
+      if isinstance(orelse_output_var, (bool, int, float, str)):
+        orelse_output_var = ops.convert_to_tensor_v2(orelse_output_var)
+
+      if (not tensor_util.is_tensor(body_output_var) or
+          not tensor_util.is_tensor(orelse_output_var)):
+        return
+
+      # TODO(mdan): Properly account for CompositeTensors.
+      if (not hasattr(body_output_var, 'dtype') or
+          not hasattr(orelse_output_var, 'dtype')):
+        return
+
+      if body_output_var.dtype != orelse_output_var.dtype:
+        raise TypeError(
+            '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
+            ' branch. TensorFlow control flow requires that they are the'
+            ' same.'.format(name, body_output_var.dtype.name,
+                            orelse_output_var.dtype.name))
+
+    nest.map_structure(
+        functools.partial(_check_same_type, name), body_output, orelse_output)
+
+
+def for_stmt(iter_,
+             extra_test,
+             body,
+             get_state,
+             set_state,
+             init_vars,
+             basic_symbol_names=None,
+             composite_symbol_names=None):
   """Functional form of a for statement.
 
   The loop operates on a state, which includes all symbols that are
@@ -135,6 +302,8 @@
     set_state: Additional callable which save values captured by get_state back
       into the Python environment. This is only useful when staging the loop.
     init_vars: Tuple containing the initial state.
+    basic_symbol_names: Tuple containing basic loop var names.
+    composite_symbol_names: Tuple containing composite loop var names.
 
   Returns:
     Tuple containing the final state.
@@ -142,18 +311,22 @@
   if tensor_util.is_tensor(iter_):
     if tensors.is_range_tensor(iter_):
       return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
-                                init_vars)
+                                init_vars, basic_symbol_names,
+                                composite_symbol_names)
     else:
       return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
-                                    set_state, init_vars)
+                                    set_state, init_vars, basic_symbol_names,
+                                    composite_symbol_names)
 
   if isinstance(iter_, dataset_ops.DatasetV2):
     return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
-                                init_vars)
+                                init_vars, basic_symbol_names,
+                                composite_symbol_names)
 
   if isinstance(iter_, iterator_ops.IteratorV2):
     return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
-                                 init_vars)
+                                 init_vars, basic_symbol_names,
+                                 composite_symbol_names)
 
   # Note: This experimental interface is subject to change.
   custom_handler = getattr(iter_, '_autograph_for_loop', None)
@@ -179,7 +352,8 @@
 
 
 def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
-                           init_vars):
+                           init_vars, basic_symbol_names,
+                           composite_symbol_names):
   """Overload of for_stmt that iterates over TF entities that admit a length."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -191,8 +365,11 @@
   iter_ = ta.unstack(iter_)
 
   def while_body(iterate_index, *loop_vars):
+    """Main loop body."""
     iterate = iter_.read(iterate_index)
     new_vars = body(iterate, *loop_vars)
+    _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names,
+                         composite_symbol_names)
 
     loop_vars = (iterate_index + 1,)
     if new_vars:
@@ -206,13 +383,22 @@
           iterate_index < n, lambda: extra_test(*loop_vars), lambda: False)
     return iterate_index < n
 
+  opts = {}
+  # TODO(b/134181679): We do not always set maximum_iterations since that
+  # is significantly slower on GPU.
+  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
+    opts['maximum_iterations'] = n
+
   results = _tf_while_stmt(
       while_cond,
       while_body,
       get_state,
       set_state,
-      init_vars=(0,) + init_vars,
-      opts=dict(maximum_iterations=n))
+      (0,) + init_vars,
+      None,
+      None,
+      opts=opts,
+  )
 
   # Note: the iteration index is not returned by the while loop, however
   # if a symbol with the same name exists outside the loop, it will be captured
@@ -227,8 +413,8 @@
   return results
 
 
-def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
-                       init_vars):
+def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars,
+                       basic_symbol_names, composite_symbol_names):
   """Overload of for_stmt that iterates over a TF range (and elides it)."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -236,33 +422,61 @@
 
   def while_body(iterate, *loop_vars):
     new_vars = body(iterate, *loop_vars)
-
     loop_vars = (iterate + delta,)
+
     if new_vars:
       loop_vars += new_vars
 
     return loop_vars
 
   def while_cond(iterate, *loop_vars):
-    main_test = math_ops.logical_or(
-        math_ops.logical_and(delta >= 0, iterate < limit),
-        math_ops.logical_and(delta < 0, iterate > limit))
+    """Cond function for `tf.while_loop`."""
+
+    def build_main_test():
+      """Main iteration condition."""
+      # Note(b/138857806): LogicalAnd is slow on GPU so we avoid adding it if
+      # `delta` is a compile time constant.
+      delta_const = tensor_util.constant_value(delta)
+      if delta_const is not None:
+        # Support single element arrays.
+        delta_const = np.asscalar(delta_const)
+        if delta_const >= 0:
+          return iterate < limit
+        else:
+          return iterate > limit
+      else:
+        return math_ops.logical_or(
+            math_ops.logical_and(delta >= 0, iterate < limit),
+            math_ops.logical_and(delta < 0, iterate > limit))
+
+    main_test = build_main_test()
     if extra_test is not None:
       return control_flow_ops.cond(
           main_test, lambda: extra_test(*loop_vars), lambda: False)
     return main_test
 
-  # This specific dtype is required by while_loop.
-  maximum_iterations = math_ops.cast(
-      misc.get_range_len(start, limit, delta), dtypes.int32)
+  # The first loopvar corresponds to the iterate variable which is internal.
+  if isinstance(basic_symbol_names, tuple):
+    basic_symbol_names = (None,) + basic_symbol_names
+
+  opts = {}
+  # TODO(b/134181679): We do not always set maximum_iterations since that
+  # is significantly slower on GPU.
+  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
+    # This specific dtype is required by while_loop.
+    opts['maximum_iterations'] = math_ops.cast(
+        misc.get_range_len(start, limit, delta), dtypes.int32)
 
   results = _tf_while_stmt(
       while_cond,
       while_body,
       get_state,
       set_state,
-      init_vars=(start,) + init_vars,
-      opts=dict(maximum_iterations=maximum_iterations))
+      (start,) + init_vars,
+      basic_symbol_names,
+      composite_symbol_names,
+      opts=opts,
+  )
 
   # Note: the iteration index is not returned by the while loop, however
   # if a symbol with the same name exists outside the loop, it will be captured
@@ -278,12 +492,16 @@
 
 
 def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
-                          init_vars):
+                          init_vars, basic_symbol_names,
+                          composite_symbol_names):
   """Overload of for_stmt that iterates over TF Iterators. See for_loop."""
   _disallow_undefs_into_loop(*init_vars)
 
   def while_body_actual(opt_iterate, *loop_vars):
+    """Actual main loop body."""
     new_vars = body(opt_iterate.get_value(), *loop_vars)
+    _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names,
+                         composite_symbol_names)
     # TODO(mdan): Fix this inconsistency in the converter.
     if new_vars is None:
       new_vars = ()
@@ -318,31 +536,40 @@
           has_next, lambda: extra_test(*loop_vars), lambda: False)
     return has_next
 
+  # The first loopvar corresponds to the iterate variable which is internal.
   _, final_vars = _tf_while_stmt(
       while_cond,
       while_body,
       get_state,
       set_state,
-      init_vars=(True, init_vars),
-      opts=None)
+      (True, init_vars),
+      None,
+      None,
+      opts=None,
+  )
   return final_vars
 
 
-def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars):
+def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
+                         basic_symbol_names, composite_symbol_names):
   """Overload of for_stmt that iterates over TF Datasets."""
   _disallow_undefs_into_loop(*init_vars)
 
   if extra_test is not None:
     assert init_vars, 'Lowering should always add state.'
     return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
-                                             set_state, init_vars)
+                                             set_state, init_vars,
+                                             basic_symbol_names,
+                                             composite_symbol_names)
 
   return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
-                                         init_vars)
+                                         init_vars, basic_symbol_names,
+                                         composite_symbol_names)
 
 
 def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
-                                      set_state, init_vars):
+                                      set_state, init_vars, basic_symbol_names,
+                                      composite_symbol_names):
   """Overload of _dataset_for_stmt with early stopping. See for_stmt."""
 
   # TODO(mdan): Simplify this - following it is extremely difficult.
@@ -354,6 +581,12 @@
     def true_fn():
       set_state(state)
       outputs = body(iterate, *loop_vars)
+      _verify_tf_loop_vars(
+          loop_vars + state,
+          outputs + state,
+          basic_symbol_names,
+          composite_symbol_names,
+          include_shapes=False)
       return outputs, get_state()
 
     extra_cond = extra_test(*loop_vars)
@@ -385,7 +618,8 @@
   return final_vars
 
 
-def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars):
+def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
+                                    basic_symbol_names, composite_symbol_names):
   """Overload of _dataset_for_stmt without early stopping. See for_stmt."""
   init_state = get_state()
   assert isinstance(init_vars, tuple)
@@ -399,6 +633,8 @@
 
   if no_vars:
     init_vars = (constant_op.constant(0),)
+    if isinstance(basic_symbol_names, tuple):
+      basic_symbol_names = (None,) + basic_symbol_names
   if no_state:
     init_state = (constant_op.constant(0),)
 
@@ -419,6 +655,12 @@
     else:
       new_state = get_state()
 
+    _verify_tf_loop_vars(
+        loop_vars + state,
+        new_vars + new_state,
+        basic_symbol_names,
+        composite_symbol_names,
+        include_shapes=False)
     return new_vars, new_state
 
   aug_vars = init_vars, get_state()
@@ -430,7 +672,16 @@
   return final_vars
 
 
-def while_stmt(test, body, get_state, set_state, init_vars, opts=None):
+def while_stmt(
+    test,
+    body,
+    get_state,
+    set_state,
+    init_vars,
+    basic_symbol_names=None,
+    composite_symbol_names=None,
+    opts=None,
+):
   """Functional form of a while statement.
 
   The loop operates on a so-called state, which includes all symbols that are
@@ -449,11 +700,14 @@
     set_state: Additional callable which save values captured by get_state back
       into the Python environment. This is only useful when staging the loop.
     init_vars: Tuple containing the initial state.
+    basic_symbol_names: Tuple containing basic loop var names.
+    composite_symbol_names: Tuple containing composite loop var names.
     opts: Optional dict of extra loop parameters.
 
   Returns:
     Tuple containing the final state.
   """
+
   # Evaluate the initial test once in order to do the dispatch. The evaluation
   # is isolated to minimize unwanted side effects.
   # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
@@ -463,7 +717,8 @@
   # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
   # with the re-evaluation of `test` that `_tf_while_stmt` will make.
   if tensors.is_dense_tensor(init_test):
-    return _tf_while_stmt(test, body, get_state, set_state, init_vars, opts)
+    return _tf_while_stmt(test, body, get_state, set_state, init_vars,
+                          basic_symbol_names, composite_symbol_names, opts)
 
   # Normal Python: We already consumed one evaluation of `test`; consistently,
   # unroll one iteration before dispatching to a normal loop.
@@ -475,7 +730,11 @@
   return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
 
 
-def _tf_while_stmt(test, body, get_state, set_state, init_vars, opts):
+# TODO(kkimlabs): Some callers set basic_symbol_names=None and
+# composite_symbol_names=None and call _verify_tf_loop_vars(...) itself.  We can
+# remove these arguments once all callers do that.
+def _tf_while_stmt(test, body, get_state, set_state, init_vars,
+                   basic_symbol_names, composite_symbol_names, opts):
   """Overload of while_stmt that stages a TF while_stmt."""
   _disallow_undefs_into_loop(*init_vars)
 
@@ -495,7 +754,11 @@
     state = aug_loop_vars[state_slice]
     set_state(state)
     loop_vars = body(*aug_loop_vars[loop_vars_slice])
-    return loop_vars + get_state()
+    new_state = loop_vars + get_state()
+    _verify_tf_loop_vars(aug_loop_vars, new_state, basic_symbol_names,
+                         composite_symbol_names)
+
+    return new_state
 
   # Non-v2 while_loop unpacks the results when there is only one return value.
   # This enforces consistency across versions.
@@ -592,7 +855,13 @@
   return loop_vars
 
 
-def if_stmt(cond, body, orelse, get_state, set_state):
+def if_stmt(cond,
+            body,
+            orelse,
+            get_state,
+            set_state,
+            basic_symbol_names=None,
+            composite_symbol_names=None):
   """Functional form of an if statement.
 
   Args:
@@ -612,18 +881,22 @@
       restore checkpointed values. The single argument a tuple containing values
       for each composite symbol that may be modified in a branch of the
       conditional. The is usually the result of a call to get_state.
+    basic_symbol_names: Tuple containing basic loop var names.
+    composite_symbol_names: Tuple containing composite loop var names.
 
   Returns:
     Tuple containing the statement outputs.
   """
   # Note: tf.cond doesn't support SparseTensor.
   if tensors.is_dense_tensor(cond):
-    return tf_if_stmt(cond, body, orelse, get_state, set_state)
+    return tf_if_stmt(cond, body, orelse, get_state, set_state,
+                      basic_symbol_names, composite_symbol_names)
   else:
     return _py_if_stmt(cond, body, orelse)
 
 
-def tf_if_stmt(cond, body, orelse, get_state, set_state):
+def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
+               composite_symbol_names):
   """Overload of if_stmt that stages a TF cond."""
   body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
   orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
@@ -635,7 +908,28 @@
   # symbols (e.g. `a`) which cannot be passed by reference and must be returned.
   # See _isolate_state.
   # TODO(mdan): We should minimize calls to get/set_state.
-  final_vars, final_state = control_flow_ops.cond(cond, body, orelse)
+
+  body_branch = 0
+  orelse_branch = 1
+  result = [None, None]
+
+  def error_checking_body():
+    result[body_branch] = body()
+    if result[orelse_branch] is not None:
+      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
+                           basic_symbol_names, composite_symbol_names)
+    return result[body_branch]
+
+  def error_checking_orelse():
+    result[orelse_branch] = orelse()
+    if result[body_branch] is not None:
+      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
+                           basic_symbol_names, composite_symbol_names)
+    return result[orelse_branch]
+
+  final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
+                                                  error_checking_orelse)
+
   set_state(final_state)
 
   return final_vars
diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
index cf25075..7b6217c 100644
--- a/tensorflow/python/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -62,6 +63,19 @@
           init_vars=(0,))
       self.assertEqual(self.evaluate(s), (1234,))
 
+  def test_range_tensor_random_delta(self):
+
+    with ops.Graph().as_default():
+      random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
+      s = control_flow.for_stmt(
+          math_ops.range(0, 5, random_one),
+          extra_test=lambda s: True,
+          body=lambda i, s: (s * 10 + i,),
+          get_state=lambda: (),
+          set_state=lambda _: None,
+          init_vars=(0,))
+      self.assertEqual(self.evaluate(s), (1234,))
+
   def test_range_tensor_explicit_limit_delta(self):
     with ops.Graph().as_default():
       s = control_flow.for_stmt(
@@ -73,6 +87,21 @@
           init_vars=(0,))
       self.assertEqual(self.evaluate(s), (-171207,))
 
+  def test_range_tensor_random_negative_delta(self):
+    with ops.Graph().as_default():
+      random_neg_five = random_ops.random_uniform((),
+                                                  -5,
+                                                  -4,
+                                                  dtype=dtypes.int32)
+      s = control_flow.for_stmt(
+          math_ops.range(17, 3, random_neg_five),
+          extra_test=lambda s: True,
+          body=lambda i, s: (s * 100 + i,),
+          get_state=lambda: (),
+          set_state=lambda _: None,
+          init_vars=(0,))
+      self.assertEqual(self.evaluate(s), (171207,))
+
   def test_range_tensor_negative_delta(self):
     with ops.Graph().as_default():
       s = control_flow.for_stmt(
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index ab28228..435e103 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -27,6 +27,7 @@
 
 from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -48,11 +49,34 @@
   return f
 
 
-def eval_in_original_context(f, args, caller_level_delta):
-  """Executes the eval function with the user-specified globals/locals."""
+def _find_originating_frame(caller_fn_scope, innermost=True):
+  """Locates the frame in which `caller_fn_scope` was defined."""
   ctx_frame = inspect.currentframe()
-  for _ in range(caller_level_delta + 1):
+  result = None
+  while ctx_frame is not None:
+    # Note it should not be normally possible to get false positives this way
+    # because the function scope object is not accessible to user code (barring
+    # call stack introspection).
+    if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
+      result = ctx_frame
+      if innermost:
+        break
     ctx_frame = ctx_frame.f_back
+
+  assert result is not None, (
+      'the conversion process should ensure the caller_fn_scope is always'
+      ' found somewhere on the call stack')
+
+  return result
+
+
+def eval_in_original_context(f, args, caller_fn_scope):
+  """Executes the eval function in the context of a specified function."""
+  # When control flow is rewritten using functions, eval should use the
+  # variables found in the same block where it was called. That is equivalent
+  # to the innermost function call.
+  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)
+
   args = (
       args[0],
       ctx_frame.f_globals if len(args) < 2 else args[1],
@@ -61,6 +85,68 @@
   return f(*args)
 
 
+def super_in_original_context(f, args, caller_fn_scope):
+  """Executes the super function in the context of a specified function.
+
+  See https://docs.python.org/3/library/functions.html#super for the exact
+  details
+
+  Args:
+    f: Callable, typically the super builtin
+    args: List[Any], the original call arguments
+    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
+      scope of the converted function in which this call was originally made
+
+  Returns:
+    The result of calling `f` as if it was called in the frame indicated by
+      `caller_fn_scope`.
+  """
+
+  # Python 2 doesn't support implicit argument super variants.
+  if six.PY2:
+    return f(*args)
+
+  # Only the no-arg call is desugared.
+  if args:
+    return f(*args)
+
+  # Inner functions seem to include their closure in f_locals, so we need
+  # to find the outermost frame.
+  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)
+
+  # When super(..) is called without arguments, it looks for __class__ cell
+  # variable and the first argument passed in the enclosing function according
+  # to the spec https://www.python.org/dev/peps/pep-3135/ .
+  #
+  # We couldn't verify if `inspect.currentframe().f_code.co_varnames[0]` is
+  # guaranteed to be the first argument from an official doc or PEP, however,
+  # it's fairly stable and well established:
+  # - An unofficial community doc mentions it.
+  #   https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
+  # - CPython has tests checking that order, which was merged in 2008, and
+  #   unchanged since then.
+  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
+  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
+  #
+  # Note: the name can be more reliably obtained by inspecting the calling
+  # function's argspec.
+  #
+  # Even though methods can be declared using *args (def method(*args)),
+  # that pattern is disallowed by super() -- it raises super() no arguments.
+  # Method definitions using **kwargs are not allowed at all.
+  # In other words, we can always assume that self is on the first positional
+  # argument (for correct code).
+  #
+  # TODO(mdan): Consider additional checks in case the input code is incorrect.
+  # For example, the error might be cryptic compared to what super() regularly
+  # raises.
+
+  type_arg = ctx_frame.f_locals['__class__']
+  self_arg_name = ctx_frame.f_code.co_varnames[0]
+  self_arg = ctx_frame.f_locals[self_arg_name]
+  return f(type_arg, self_arg)
+
+
 def abs_(x):
   if tensor_util.is_tensor(x):
     return _tf_abs(x)
@@ -139,7 +225,7 @@
     return s.shape.dims[0].value
 
   # Static shape of unknown dimensions: use dynamic shape but statically
-  # chech that it's a scalar.
+  # check that it's a scalar.
   shape = array_ops.shape(s)
 
   assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
@@ -242,7 +328,21 @@
   return range(start_or_stop)
 
 
-SUPPORTED_BUILTINS = (abs, float, int, len, print, range)
+def enumerate_(s, start=0):
+  if isinstance(s, dataset_ops.DatasetV2):
+    return _tf_dataset_enumerate(s, start)
+  return _py_enumerate(s, start)
+
+
+def _tf_dataset_enumerate(s, start=0):
+  return s.enumerate(start)
+
+
+def _py_enumerate(s, start=0):
+  return enumerate(s, start)
+
+
+SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate)
 
 if six.PY2:
   SUPPORTED_BUILTINS += (xrange,)
@@ -256,4 +356,5 @@
     'range': range_,
     # TODO(mdan): This might make more sense as tf.data.range.
     'xrange': range_,
+    'enumerate': enumerate_,
 }
diff --git a/tensorflow/python/autograph/operators/py_builtins_py3_test.py b/tensorflow/python/autograph/operators/py_builtins_py3_test.py
new file mode 100644
index 0000000..11a33b9
--- /dev/null
+++ b/tensorflow/python/autograph/operators/py_builtins_py3_test.py
@@ -0,0 +1,123 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins_py3 module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import function_wrappers
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.platform import test
+
+
+class TestBaseClass(object):
+
+  def overridden_method(self, x):
+    return x + 20
+
+
+class PyBuiltinsTest(test.TestCase):
+
+  def _basic_function_scope(self):
+    return function_wrappers.FunctionScope(
+        'test_function_name',
+        'test_scope',  # Note: this must match the name in the `with` statement.
+        converter.ConversionOptions())
+
+  def test_super_in_original_context_niladic_call(self):
+    test_case_self = self
+
+    class TestSubclass(TestBaseClass):
+
+      def overridden_method(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self):
+        with test_case_self._basic_function_scope() as test_scope:
+          b = py_builtins.super_in_original_context(super, (), test_scope)
+          return b.overridden_method(1)
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(), 21)
+
+  def test_super_in_original_context_caller_with_locals(self):
+    test_case_self = self
+
+    class TestSubclass(TestBaseClass):
+
+      def overridden_method(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self, x):
+        y = 7
+        with test_case_self._basic_function_scope() as test_scope:
+          z = 7
+          return py_builtins.super_in_original_context(
+              super, (), test_scope).overridden_method(x + y - z)
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(1), 21)
+
+  def test_super_in_original_context_inner_function(self):
+    test_case_self = self
+
+    class TestSubclass(TestBaseClass):
+
+      def overridden_method(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self, x):
+        with test_case_self._basic_function_scope() as test_scope:
+          # Oddly, it's sufficient to use `self` in an inner function
+          # to gain access to __class__ in this scope.
+          # TODO(mdan): Is this true across implementations?
+          # Note: normally, it's illegal to use super() in inner functions (it
+          # throws an error), but the generated code may create them.
+          def inner_fn():
+            return py_builtins.super_in_original_context(
+                super, (), test_scope).overridden_method(x)
+
+          return inner_fn()
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(1), 21)
+
+  def test_super_in_original_context_inner_lambda(self):
+    test_case_self = self
+
+    class TestSubclass(TestBaseClass):
+
+      def overridden_method(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self, x):
+        with test_case_self._basic_function_scope() as test_scope:
+          # Oddly, it's sufficient to use `self` in an inner function
+          # to gain access to __class__ in this scope.
+          # TODO(mdan): Is this true across implementations?
+          # Note: normally, it's illegal to use super() in inner functions (it
+          # throws an error), but the generated code may create them.
+          l = lambda: py_builtins.super_in_original_context(  # pylint:disable=g-long-lambda
+              super, (), test_scope).overridden_method(x)
+          return l()
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(1), 21)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index 1be10bf..e706a28 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -22,8 +22,11 @@
 
 import six
 
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import function_wrappers
 from tensorflow.python.autograph.operators import data_structures
 from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -33,6 +36,12 @@
 from tensorflow.python.platform import test
 
 
+class TestBase(object):
+
+  def plus_twenty(self, x):
+    return x + 20
+
+
 class PyBuiltinsTest(test.TestCase):
 
   def test_abs(self):
@@ -137,23 +146,88 @@
       r = py_builtins.range_(5, constant_op.constant(2))
       self.assertAllEqual(self.evaluate(r), [])
 
+  def test_enumerate(self):
+    self.assertListEqual(
+        list(py_builtins.enumerate_([3, 2, 1])), [(0, 3), (1, 2), (2, 1)])
+    self.assertListEqual(
+        list(py_builtins.enumerate_([3, 2, 1], 5)), [(5, 3), (6, 2), (7, 1)])
+    self.assertListEqual(list(py_builtins.enumerate_([-8], -3)), [(-3, -8)])
+
+  def test_enumerate_dataset(self):
+    dataset = dataset_ops.DatasetV2.from_tensor_slices(['a', 'c'])
+    start = constant_op.constant(20, dtype=dtypes.int64)
+    dataset = py_builtins.enumerate_(dataset, start)
+    iterator = dataset_ops.make_one_shot_iterator(dataset)
+
+    with self.cached_session() as sess:
+      self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a'))
+      self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c'))
+
+  def _basic_function_scope(self):
+    return function_wrappers.FunctionScope(
+        'test_function_name',
+        'test_scope',  # Note: this must match the name in the `with` statement.
+        converter.ConversionOptions())
+
   def test_eval_in_original_context(self):
 
-    def caller_1(lvl_delta):
+    def test_fn():
       l = 1  # pylint:disable=unused-variable
-      return py_builtins.eval_in_original_context(eval, ('l',), lvl_delta)
+      with self._basic_function_scope() as test_scope:
+        return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
 
-    def caller_2(lvl_delta):
-      l = 2  # pylint:disable=unused-variable
-      return caller_1(lvl_delta)
+    self.assertEqual(test_fn(), 1)
 
-    def caller_3(lvl_delta):
-      l = 3  # pylint:disable=unused-variable
-      return caller_2(lvl_delta)
+  def test_eval_in_original_context_inner_function(self):
 
-    self.assertEqual(caller_3(0), 1)
-    self.assertEqual(caller_3(1), 2)
-    self.assertEqual(caller_3(2), 3)
+    def test_fn():
+      l = 1  # pylint:disable=unused-variable
+      with self._basic_function_scope() as test_scope:
+
+        def inner_fn():
+          # Note: a user function without a top-level function scope should
+          # never be found in user code; it's only possible in generated code.
+          l = 2  # pylint:disable=unused-variable
+          return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
+
+        return inner_fn()
+
+    self.assertEqual(test_fn(), 2)
+
+  def test_super_in_original_context_unary_call(self):
+    test_case_self = self
+
+    class TestSubclass(TestBase):
+
+      def plus_twenty(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self):
+        with test_case_self._basic_function_scope() as test_scope:
+          test_base_unbound = py_builtins.super_in_original_context(
+              super, (TestSubclass,), test_scope)
+          test_base = test_base_unbound.__get__(self, TestSubclass)
+          return test_base.plus_twenty(1)
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(), 21)
+
+  def test_super_in_original_context_binary_call(self):
+    test_case_self = self
+
+    class TestSubclass(TestBase):
+
+      def plus_twenty(self, x):
+        test_case_self.fail('This should never be called.')
+
+      def test_method(self):
+        with test_case_self._basic_function_scope() as test_scope:
+          test_base = py_builtins.super_in_original_context(
+              super, (TestSubclass, self), test_scope)
+          return test_base.plus_twenty(1)
+
+    tc = TestSubclass()
+    self.assertEqual(tc.test_method(), 21)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py
index ed141ae..216c023 100644
--- a/tensorflow/python/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py
@@ -156,7 +156,8 @@
   # A-normal form.  Thus they are left in by default, but could be pulled out
   # if the configuration calls for it.
   _literal_nodes = (
-      gast.Num, gast.Str, gast.Bytes, gast.NameConstant
+      gast.Num, gast.Str, gast.Bytes, gast.NameConstant,
+      gast.Name  # Name is here to cover True, False, and None in Python 2
   )
 
   def _match(self, pattern, parent, field, child):
@@ -198,7 +199,8 @@
     """
     if node is None:
       return node
-    if isinstance(node, self._trivial_nodes):
+    if (isinstance(node, self._trivial_nodes) and
+        not _is_py2_name_constant(node)):
       return node
     if isinstance(node, list):
       # If something's field was actually a list, e.g., variadic arguments.
@@ -493,6 +495,10 @@
     return node
 
 
+def _is_py2_name_constant(node):
+  return isinstance(node, gast.Name) and node.id in ['True', 'False', 'None']
+
+
 def transform(node, ctx, config=None, gensym_source=None):
   """Converts the given node to A-normal form (ANF).
 
diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
index df5ea5d..fe2e9b2 100644
--- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -479,7 +479,7 @@
   def test_constants_in_function_calls(self):
     # An example specific configuration that differs from the default: Moving
     # literals out of being directly passed to functions, but nothing else.
-    literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant)
+    literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant, gast.Name)
     config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, literals), anf.REPLACE)]
 
     def test_function(x, frob):
@@ -514,6 +514,24 @@
 
     self.assert_body_anfs_as_expected(expected_result, test_function, config)
 
+  def test_touching_name_constant(self):
+    # Checking that the nodes for `True`, `False`, and `None` can be manipulated
+    # by a configuration.  This is non-trivial, because in Python 2 those are
+    # represented as `Name`, which is the same node type as variable references.
+    specials = (gast.Name, gast.NameConstant)
+    config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, specials), anf.REPLACE)]
+
+    def test_function(f):
+      return f(True, False, None)
+
+    def expected_result(f):
+      tmp_1001 = True
+      tmp_1002 = False
+      tmp_1003 = None
+      return f(tmp_1001, tmp_1002, tmp_1003)
+
+    self.assert_body_anfs_as_expected(expected_result, test_function, config)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/autograph/pyct/errors.py b/tensorflow/python/autograph/pyct/errors.py
index 4960883..345320d 100644
--- a/tensorflow/python/autograph/pyct/errors.py
+++ b/tensorflow/python/autograph/pyct/errors.py
@@ -197,7 +197,8 @@
 
     return '\n'.join(lines)
 
-  def create_exception(self, preferred_type):
+  def create_exception(self, source_error):
+    preferred_type = type(source_error)
     if preferred_type.__init__ is Exception.__init__:
       return preferred_type(self.get_message())
     if preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS:
@@ -206,8 +207,8 @@
       return MultilineMessageKeyError(self.get_message(), self.cause_message)
     return None
 
-  def to_exception(self, preferred_type):
-    exc = self.create_exception(preferred_type)
+  def to_exception(self, source_error):
+    exc = self.create_exception(source_error)
     exc.__suppress_context__ = True
     exc.ag_error_metadata = self
     return exc
diff --git a/tensorflow/python/autograph/pyct/errors_test.py b/tensorflow/python/autograph/pyct/errors_test.py
index f6286a5..9640af1 100644
--- a/tensorflow/python/autograph/pyct/errors_test.py
+++ b/tensorflow/python/autograph/pyct/errors_test.py
@@ -36,7 +36,7 @@
         cause_metadata=None,
         cause_message='test message',
         source_map={})
-    exc = em.create_exception(CustomError)
+    exc = em.create_exception(CustomError())
     self.assertIsInstance(exc, CustomError)
     self.assertIn('test message', str(exc))
 
@@ -52,7 +52,7 @@
         cause_metadata=None,
         cause_message='test message',
         source_map={})
-    exc = em.create_exception(CustomError)
+    exc = em.create_exception(CustomError())
     self.assertIsNone(exc)
 
   def test_get_message_when_frame_info_code_is_none(self):
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index 6d4f252..47c52d2 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -81,7 +81,7 @@
 
 def isbuiltin(f):
   """Returns True if the argument is a built-in function."""
-  if f in six.moves.builtins.__dict__.values():
+  if any(f is builtin for builtin in six.moves.builtins.__dict__.values()):
     return True
   elif isinstance(f, types.BuiltinFunctionType):
     return True
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 2821205..fa2c706 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -360,6 +360,7 @@
     assert not self._in_function_def_args
     self.state[_Lambda].enter()
     node = self.generic_visit(node)
+    anno.setanno(node, anno.Static.SCOPE, self.scope)
     self.state[_Lambda].exit()
     return node
 
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 3e705dd..253e294 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -122,6 +122,7 @@
         anno.Basic.SKIP_PROCESSING,
         anno.Static.ORIG_DEFINITIONS,
         'extra_test',
+        'function_context_name',
     }
 
   def _prepare_replacement(self, replaced, key):
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 032781e..8d1be4b 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import functools
 import re
 import threading
@@ -30,6 +29,7 @@
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import pywrap_tensorflow as tf_session
 from tensorflow.python.eager import context
+from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import device
 from tensorflow.python.framework import error_interpolation
 from tensorflow.python.framework import errors
@@ -40,8 +40,13 @@
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.compat import collections_abc
 
+_python_session_create_counter = monitoring.Counter(
+    '/tensorflow/api/python/session_create_counter',
+    'Counter for number of sessions created in Python.')
 
 class SessionInterface(object):
   """Base class for implementations of TensorFlow client sessions."""
@@ -259,7 +264,7 @@
     elif isinstance(fetch, (list, tuple)):
       # NOTE(touts): This is also the code path for namedtuples.
       return _ListFetchMapper(fetch)
-    elif isinstance(fetch, collections.Mapping):
+    elif isinstance(fetch, collections_abc.Mapping):
       return _DictFetchMapper(fetch)
     elif _is_attrs_instance(fetch):
       return _AttrsFetchMapper(fetch)
@@ -470,9 +475,10 @@
     self._fetches = []
     self._targets = []
     self._feeds = feeds
-    self._feed_handles = feed_handles or {}
+    self._feed_handles = (
+        feed_handles or object_identity.ObjectIdentityDictionary())
     self._ops = []
-    self._fetch_handles = {}
+    self._fetch_handles = object_identity.ObjectIdentityDictionary()
     for fetch in self._fetch_mapper.unique_fetches():
       if isinstance(fetch, ops.Operation):
         self._assert_fetchable(graph, fetch)
@@ -491,8 +497,12 @@
 
   def _assert_fetchable(self, graph, op):
     if not graph.is_fetchable(op):
-      raise ValueError('Operation %r has been marked as not fetchable.' %
-                       op.name)
+      raise errors.InaccessibleTensorError(
+          'Operation %r has been marked as not fetchable. Typically this'
+          ' happens when it is defined in another function or code block.'
+          ' Use return values,explicit Python locals or TensorFlow collections'
+          ' to access it.'
+          % op.name)
 
   def fetches(self):
     """Return the unique names of tensors to fetch.
@@ -635,6 +645,7 @@
         creating the TensorFlow session.
       TypeError: If one of the arguments has the wrong type.
     """
+    _python_session_create_counter.get_cell().increase_by(1)
     if graph is None:
       self._graph = ops.get_default_graph()
     else:
@@ -1060,7 +1071,8 @@
 
     # Validate and process fetches.
     # TODO(touts): Support feeding and fetching the same tensor.
-    fetch_handler = _FetchHandler(self._graph, fetches, {})
+    fetch_handler = _FetchHandler(self._graph, fetches,
+                                  object_identity.ObjectIdentityDictionary())
 
     # Set up a graph with feeds and fetches for partial run.
     def _setup_fn(session, feed_list, fetch_list, target_list):
@@ -1094,7 +1106,7 @@
                          'graph before calling run().')
 
     # Create request.
-    feed_dict_tensor = {}
+    feed_dict_tensor = object_identity.ObjectIdentityDictionary()
     feed_map = {}
 
     # Validate and process feed_dict.
@@ -1228,7 +1240,8 @@
     self._extend_graph()
 
     # Create a fetch handler to take care of the structure of fetches.
-    fetch_handler = _FetchHandler(self._graph, fetches, {})
+    fetch_handler = _FetchHandler(self._graph, fetches,
+                                  object_identity.ObjectIdentityDictionary())
     # pylint: disable=protected-access
     fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
     target_list = [op._c_op for op in fetch_handler.targets()]
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 0c98002..9ec86c7 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 7, 24)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 8, 9)
 
 _FORWARD_COMPATIBILITY_HORIZON_OVERRIDDEN = False
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py
index 8538108..0ae672d 100644
--- a/tensorflow/python/compat/v2_compat.py
+++ b/tensorflow/python/compat/v2_compat.py
@@ -39,13 +39,15 @@
   This function is called in the main TensorFlow `__init__.py` file, user should
   not need to call it, except during complex migrations.
   """
+  # TF2 behavior is enabled if either 1) enable_v2_behavior() is called or
+  # 2) the TF2_BEHAVIOR=1 environment variable is set.  In the latter case,
+  # the modules below independently check if tf2.enabled().
   tf2.enable()
   ops.enable_eager_execution()
   tensor_shape.enable_v2_tensorshape()  # Also switched by tf2
   variable_scope.enable_resource_variables()
   # Enables TensorArrayV2 and control flow V2.
-  # TODO(b/134181885): Re-enable this.
-  # control_flow_v2_toggles.enable_control_flow_v2()
+  control_flow_v2_toggles.enable_control_flow_v2()
 
 
 @tf_export(v1=["disable_v2_behavior"])
diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
index 5699461..d44a0ec 100644
--- a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
@@ -153,8 +153,7 @@
           # runtime to allocate GPU memory.
           max_workspace_size_bytes=1 << 28,
           minimum_segment_size=2,
-          use_calibration=False,
-          use_function_backup=False)
+          use_calibration=False)
       graph_def = converter.convert()
       logging.info('Number of nodes after TF-TRT conversion: %d',
                    len(graph_def.node))
diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index 6b72cbe..30aae4d 100644
--- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -23,6 +23,7 @@
 import gc
 import itertools
 import os
+import re
 import shutil
 import tempfile
 import warnings
@@ -234,10 +235,8 @@
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
         use_calibration=run_params.use_calibration,
-        use_function_backup=False,
         max_batch_size=min(batch_list))
-    return conversion_params._replace(
-        use_function_backup=IsQuantizationWithCalibration(conversion_params))
+    return conversion_params
 
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -388,8 +387,7 @@
         minimum_segment_size=conversion_params.minimum_segment_size,
         is_dynamic_op=conversion_params.is_dynamic_op,
         maximum_cached_engines=conversion_params.maximum_cached_engines,
-        use_calibration=conversion_params.use_calibration,
-        use_function_backup=conversion_params.use_function_backup)
+        use_calibration=conversion_params.use_calibration)
 
   def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
     """Return trt converted graphdef in INT8 mode."""
@@ -558,21 +556,18 @@
       if node.op == "TRTEngineOp":
         logging.info("Found TRTEngineOp: " + node.name)
         num_engines += 1
-        segment_funcdef_name = node.attr["segment_funcdef_name"].s
+        segment_funcdef_name = node.attr["segment_func"].func.name
         function_name = node.name + "_native_segment"
-        if IsQuantizationWithCalibration(run_params):
-          self.assertNotEmpty(segment_funcdef_name, node.name)
-          self.assertIn(function_name, functions)
-        else:
-          self.assertEmpty(segment_funcdef_name, node.name)
-          self.assertNotIn(function_name, functions)
+        is_dynamic_engine = not node.attr["static_engine"].b
+        self.assertNotEmpty(segment_funcdef_name, node.name)
+        self.assertIn(function_name, functions)
+        if not IsQuantizationWithCalibration and not is_dynamic_engine:
+          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
         self.assertIn(node.name, expected_engines)
-        self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
         self.assertEqual(
             self._ToBytes(run_params.precision_mode),
             node.attr["precision_mode"].s, node.name)
 
-        is_dynamic_engine = not node.attr["static_engine"].b
         self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                          node.name)
         self.assertEqual(node.attr["use_calibration"].b,
@@ -602,10 +597,11 @@
         node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
     ]
     for func in gdef_to_verify.library.function:
-      for node in func.node_def:
-        all_op_names.append(node.name)
-        if node.op == "TRTEngineOp":
-          trt_op_names.append(node.name)
+      if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
+        for node in func.node_def:
+          all_op_names.append(node.name)
+          if node.op == "TRTEngineOp":
+            trt_op_names.append(node.name)
     # Remove the function name prefix.
     def _Canonicalize(names):
       return set([self._ToString(name.split("/")[-1]) for name in names])
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index b11938a..210983f 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -51,6 +51,7 @@
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import saver
 from tensorflow.python.training.tracking import tracking
+from tensorflow.python.util import nest
 from tensorflow.python.util.lazy_loader import LazyLoader
 
 # Lazily load the op, since it's not available in cpu-only builds. Importing
@@ -94,8 +95,10 @@
 
   @staticmethod
   def supported_precision_modes():
-    return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
-
+    precisions = [
+        TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
+    ]
+    return precisions + [p.lower() for p in precisions]
 
 # Use a large enough number as the default max_workspace_size for TRT engines,
 # so it can produce reasonable performance results with the default.
@@ -144,11 +147,6 @@
         # trained with fake quantization.
         "use_calibration",
 
-        # If set to True, it will create a FunctionDef for each subgraph that is
-        # converted to TRT op, and if TRT ops fail to execute at runtime, it'll
-        # invoke that function as a fallback.
-        "use_function_backup",
-
         # Max size for the input batch.
         # This option is deprecated in TF 2.0.
         "max_batch_size",
@@ -162,10 +160,8 @@
     is_dynamic_op=False,
     maximum_cached_engines=1,
     use_calibration=True,
-    use_function_backup=True,
     max_batch_size=1)
 
-_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
 _TRT_ENGINE_OP_NAME = "TRTEngineOp"
 
 
@@ -269,8 +265,6 @@
       "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
   optimizer.parameter_map[
       "use_calibration"].b = conversion_params.use_calibration
-  optimizer.parameter_map[
-      "use_function_backup"].b = conversion_params.use_function_backup
 
   if is_v2:
     # Static mode (a.k.a pre-generating TRT engines and make them node
@@ -341,8 +335,7 @@
                minimum_segment_size=3,
                is_dynamic_op=False,
                maximum_cached_engines=1,
-               use_calibration=True,
-               use_function_backup=True):
+               use_calibration=True):
     """Initialize the converter.
 
     Args:
@@ -381,13 +374,14 @@
         will occur. Please note that accuracy may be negatively affected if
         there is a mismatch between which tensors TRT quantizes and which
         tensors were trained with fake quantization.
-      use_function_backup: if set to True, it will create a FunctionDef for each
-        subgraph that is converted to TRT op, and if TRT ops fail to execute at
-        runtime, it'll invoke that function as a fallback.
 
     Raises:
       ValueError: if the combination of the parameters is invalid.
+      RuntimeError: if this class is used in TF 2.0.
     """
+    if context.executing_eagerly():
+      raise RuntimeError("Please use TrtGraphConverterV2 in TF 2.0.")
+
     if input_graph_def and input_saved_model_dir:
       raise ValueError(
           "Can only specify one of input_graph_def and input_saved_model_dir")
@@ -421,12 +415,6 @@
           "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
       is_dynamic_op = True
 
-    # TODO(laigd): consider provide a mechanism to remove the fallback path
-    # after calibration is done.
-    if self._need_calibration and not use_function_backup:
-      raise ValueError(
-          "Calibration requires enabling fallback to TF function execution.")
-
     # TODO(laigd):
     # - Verify in int8 mode that maximum_cached_engines is set properly.
     # - If it fails to build the int8 engine it should return error.
@@ -443,7 +431,6 @@
         is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_calibration=use_calibration,
-        use_function_backup=use_function_backup,
         max_batch_size=max_batch_size)
     _check_conversion_params(self._conversion_params)
 
@@ -590,14 +577,18 @@
     assert self._need_calibration
     assert not self._calibration_data_collected
 
-    if context.executing_eagerly():
-      raise RuntimeError("Calibration for TF 2.0 is not supported yet.")
-
     if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                            not input_map_fn):
       raise ValueError(
           "Should specify one and only one of feed_dict_fn and input_map_fn.")
 
+    if input_map_fn:
+      for k, v in input_map_fn().items():
+        if not isinstance(k, str):
+          raise ValueError("Keys of input_map_fn must be of type str")
+        if not isinstance(v, tf.Tensor):
+          raise ValueError("Values of input_map_fn must be of type tf.Tensor")
+
     self._calibration_graph = ops.Graph()
     with self._calibration_graph.as_default():
       fetches = importer.import_graph_def(
@@ -734,8 +725,7 @@
 
 def _get_resource_handle(name, device):
   with ops.device(device):
-    return gen_trt_ops.create_trt_engine_cache_handle(
-        container=_TRT_ENGINE_CACHE_CONTAINER_NAME, resource_name=name)
+    return gen_trt_ops.create_trt_resource_handle(resource_name=name)
 
 
 class TRTEngineResourceDeleter(tracking.CapturableResourceDeleter):
@@ -766,14 +756,14 @@
     self._resource_name = resource_name
     # Track the serialized engine file in the SavedModel.
     self._filename = self._track_trackable(
-        tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
+        tracking.TrackableAsset(filename), "_serialized_trt_resource_filename")
     self._maximum_cached_engines = maximum_cached_engines
 
   def _create_resource(self):
     return _get_resource_handle(self._resource_name, self._resource_device)
 
   def _initialize(self):
-    gen_trt_ops.populate_trt_engine_cache(
+    gen_trt_ops.initialize_trt_resource(
         self.resource_handle,
         self._filename,
         max_cached_engines_count=self._maximum_cached_engines)
@@ -911,6 +901,10 @@
         self._converted_graph_def,
         [tensor.name for tensor in frozen_func.inputs],
         [tensor.name for tensor in frozen_func.outputs])
+    # Reconstruct the output signatures using the ones from original model.
+    self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
+        func.graph.structured_outputs,
+        self._converted_func.graph.structured_outputs)
 
     self._converted = True
 
@@ -944,11 +938,10 @@
       filename = os.path.join(engine_asset_dir,
                               "trt-serialized-engine." + canonical_engine_name)
       try:
-        gen_trt_ops.dump_trt_engine_cache(
-            container=_TRT_ENGINE_CACHE_CONTAINER_NAME,
+        gen_trt_ops.serialize_trt_resource(
             resource_name=canonical_engine_name,
             filename=filename,
-            delete_cache_after_dump=True)
+            delete_resource=True)
       except errors.NotFoundError:
         # If user haven't run the function to populate the engine, it's fine,
         # and we don't need to track any serialized TRT engines.
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index 61ecd79..f49376f 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -42,13 +43,14 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.saved_model import utils
-from tensorflow.python.saved_model import load
-from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_utils
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util.lazy_loader import LazyLoader
@@ -67,6 +69,9 @@
   # memory.
   _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20
 
+  def mkdtemp(self):
+    return tempfile.mkdtemp(dir=self.get_temp_dir())
+
   def testGetTensorrtRewriterConfig(self):
     """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
     if not is_tensorrt_enabled():
@@ -200,8 +205,7 @@
                     max_batch_size=1,
                     minimum_segment_size=3,
                     is_dynamic_op=False,
-                    maximum_cached_engines=1,
-                    use_function_backup=False):
+                    maximum_cached_engines=1):
     """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
     converter = trt_convert.TrtGraphConverter(
         input_saved_model_dir=input_saved_model_dir,
@@ -215,8 +219,7 @@
                         else trt_convert.TrtPrecisionMode.FP32),
         minimum_segment_size=minimum_segment_size,
         is_dynamic_op=is_dynamic_op,
-        maximum_cached_engines=maximum_cached_engines,
-        use_function_backup=use_function_backup)
+        maximum_cached_engines=maximum_cached_engines)
     output_graph_def = converter.convert()
 
     if need_calibration:
@@ -249,8 +252,7 @@
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
         need_calibration=need_calibration,
-        is_dynamic_op=is_dynamic_op,
-        use_function_backup=need_calibration)
+        is_dynamic_op=is_dynamic_op)
     graph_defs_to_verify = [output_graph_def]
 
     if output_saved_model_dir:
@@ -291,8 +293,7 @@
     if not is_tensorrt_enabled():
       return
 
-    tmp_dir = self.get_temp_dir()
-    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+    input_saved_model_dir = self.mkdtemp()
     self._WriteInputSavedModel(input_saved_model_dir)
 
     for need_calibration in [False, True]:
@@ -300,22 +301,21 @@
       self._TestTrtGraphConverter()
 
       # Use SavedModel as input.
-      output_saved_model_dir = os.path.join(
-          tmp_dir, "out_dir1%s" % ("_int8" if need_calibration else ""))
       self._TestTrtGraphConverter(
           input_saved_model_dir=input_saved_model_dir,
-          output_saved_model_dir=output_saved_model_dir,
+          output_saved_model_dir=self.mkdtemp(),
           need_calibration=need_calibration)
 
-  def _CreateConverterV2(self, input_saved_model_dir):
+  def _CreateConverterV2(self,
+                         input_saved_model_dir,
+                         precision_mode=trt_convert.TrtPrecisionMode.FP32):
     return trt_convert.TrtGraphConverterV2(
         input_saved_model_dir=input_saved_model_dir,
         input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
         conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-            precision_mode=trt_convert.TrtPrecisionMode.FP32,
+            precision_mode=precision_mode,
             is_dynamic_op=True,
-            maximum_cached_engines=2,
-            use_function_backup=False))
+            maximum_cached_engines=2))
 
   @test_util.run_v2_only
   def testTrtGraphConverter_BasicConversion_v2(self):
@@ -326,7 +326,7 @@
     np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
 
     # Create a model and save it.
-    input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    input_saved_model_dir = self.mkdtemp()
     root = self._GetModelForV2()
     expected_output = root.run(np_input)
     save.save(root, input_saved_model_dir,
@@ -354,7 +354,7 @@
     _check_trt_ops(converted_concrete_func.graph.as_graph_def())
 
     # Save the converted model without any TRT engine cache.
-    output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
     unexpected_asset_file = os.path.join(
         output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
@@ -364,10 +364,10 @@
     output_with_trt = converted_func(np_input)
     self.assertEqual(1, len(output_with_trt))
     self.assertAllClose(
-        expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6)
+        expected_output, output_with_trt.values()[0], atol=1e-6, rtol=1e-6)
 
     # Save the converted model again with serialized engine cache.
-    output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
     expected_asset_file = os.path.join(
         output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
@@ -376,7 +376,7 @@
 
     # Load and verify the converted model.
     #
-    # TODO(laigd): the name of then new input_signature of the
+    # TODO(laigd): the name of the new input_signature of the
     # `root_with_trt.run` function is empty string (originaly was None),
     # investigate why.
     root_with_trt = load.load(output_saved_model_dir)
@@ -389,10 +389,54 @@
     output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
     # The output of running the converted signature is a dict due to
     # compatibility reasons with V1 SavedModel signature mechanism.
-    output_with_trt = output_with_trt[output_with_trt.keys()[0]]
+    output_with_trt = output_with_trt.values()[0]
     self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
 
   @test_util.run_v2_only
+  def testTrtGraphConverter_Int8Conversion_v2(self):
+    if not is_tensorrt_enabled():
+      return
+
+    np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
+
+    # Create a model and save it.
+    input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    root = self._GetModelForV2()
+    expected_output = root.run(np_input)
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    # Run TRT conversion.
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)
+    converted_func = converter.convert()
+
+    # Run the converted function for INT8 calibration.
+    calibration_output = converted_func(np_input)
+    self.assertEqual(1, len(calibration_output))
+    self.assertAllClose(
+        expected_output, calibration_output.values()[0], atol=1e-6, rtol=1e-6)
+
+    # Save the converted model again with serialized engine cache.
+    output_saved_model_dir = self.mkdtemp()
+    converter.save(output_saved_model_dir)
+    expected_asset_file = os.path.join(
+        output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
+    self.assertTrue(os.path.exists(expected_asset_file))
+    self.assertTrue(os.path.getsize(expected_asset_file))
+
+    # Load and verify the converted model.
+    root_with_trt = load.load(output_saved_model_dir)
+    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
+    output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
+    self.assertEqual(1, len(output_with_trt))
+
+    # The output of running the converted signature is a dict due to
+    # compatibility reasons with V1 SavedModel signature mechanism.
+    self.assertAllClose(
+        expected_output, output_with_trt.values()[0], atol=1e-6, rtol=1e-6)
+
+  @test_util.run_v2_only
   def testTrtGraphConverter_DestroyEngineCache(self):
     """Test case for trt_convert.TrtGraphConverter()."""
     if not is_tensorrt_enabled():
@@ -401,7 +445,7 @@
     np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
 
     # Create a model and save it.
-    input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    input_saved_model_dir = self.mkdtemp()
     root = self._GetModelForV2()
     save.save(root, input_saved_model_dir,
               {_SAVED_MODEL_SIGNATURE_KEY: root.run})
@@ -410,13 +454,12 @@
     converter = self._CreateConverterV2(input_saved_model_dir)
     converted_func = converter.convert()
     converted_func(np_input)  # Populate the TRT engine cache.
-    output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
 
     def _destroy_cache():
       with ops.device("GPU:0"):
-        handle = gen_trt_ops.create_trt_engine_cache_handle(
-            container=trt_convert._TRT_ENGINE_CACHE_CONTAINER_NAME,
+        handle = gen_trt_ops.create_trt_resource_handle(
             resource_name="TRTEngineOp_0")
         gen_resource_variable_ops.destroy_resource_op(
             handle, ignore_lookup_error=False)
@@ -442,20 +485,126 @@
                                  r"Resource .* does not exist."):
       _destroy_cache()
 
-  def _TestRun(self,
-               sess,
-               batch_size,
-               use_function_backup=False,
-               expect_engine_is_run=True):
-    try:
-      result = sess.run(
-          "output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
-      self.assertAllEqual([[[4.0]]] * batch_size, result)
-    except errors.OpError as e:
-      # This should happen only when fallback path is disabled and TRT engine
-      # fails to run.
-      self.assertTrue(not use_function_backup and not expect_engine_is_run)
-      self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
+  def _CompareSavedModel(self, model_class):
+    signature_key = "serving_default"
+
+    def _GetModelPaths(model_class):
+      input_saved_model_dir = self.mkdtemp()
+      root = model_class()
+      save.save(root, input_saved_model_dir)
+
+      converter = trt_convert.TrtGraphConverterV2(
+          input_saved_model_dir=input_saved_model_dir)
+      converter.convert()
+      output_saved_model_dir = self.mkdtemp()
+      converter.save(output_saved_model_dir)
+      return input_saved_model_dir, output_saved_model_dir
+
+    def _GetSignatureDef(export_dir):
+      saved_model_proto = loader_impl.parse_saved_model(export_dir)
+      self.assertEqual(1, len(saved_model_proto.meta_graphs))
+      meta_graph = saved_model_proto.meta_graphs[0]
+      self.assertIn(signature_key, meta_graph.signature_def)
+      return meta_graph.signature_def[signature_key]
+
+    def _CompareSignatureDef(original_def, converted_def, is_input):
+      endpoints = original_def.inputs if is_input else original_def.outputs
+      converted_endpoints = (
+          converted_def.inputs if is_input else converted_def.outputs)
+      self.assertEqual(set(endpoints.keys()), set(converted_endpoints.keys()))
+      for key in endpoints:
+        original_input = endpoints[key]
+        converted_input = converted_endpoints[key]
+        self.assertEqual(original_input.name, converted_input.name)
+        self.assertEqual(original_input.dtype, converted_input.dtype)
+        self.assertEqual(
+            tensor_shape.TensorShape(original_input.tensor_shape).as_list(),
+            tensor_shape.TensorShape(converted_input.tensor_shape).as_list())
+
+    def _GetStructuredOutputs(export_dir):
+      root = load.load(export_dir)
+      return root.signatures[signature_key].structured_outputs
+
+    saved_model_path, converted_saved_model_path = _GetModelPaths(model_class)
+    original_def = _GetSignatureDef(saved_model_path)
+    converted_def = _GetSignatureDef(converted_saved_model_path)
+    self.assertEqual(original_def.method_name, converted_def.method_name)
+    _CompareSignatureDef(original_def, converted_def, True)
+    _CompareSignatureDef(original_def, converted_def, False)
+
+    self.assertEqual(
+        _GetStructuredOutputs(saved_model_path),
+        _GetStructuredOutputs(converted_saved_model_path))
+
+  @test_util.run_v2_only
+  def testRetainSignatureInfo_NoInputs(self):
+
+    class _Model(tracking.AutoTrackable):
+
+      @def_function.function(input_signature=[])
+      def run(self):
+        return array_ops.constant(1.0)
+
+    self._CompareSavedModel(_Model)
+
+  @test_util.run_v2_only
+  def testRetainSignatureInfo_OneInput(self):
+
+    class _Model(tracking.AutoTrackable):
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
+      ])
+      def run(self, inp):
+        return inp + inp * inp
+
+    self._CompareSavedModel(_Model)
+
+  @test_util.run_v2_only
+  def testRetainSignatureInfo_TwoInputs(self):
+
+    class _Model(tracking.AutoTrackable):
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32),
+          tensor_spec.TensorSpec(shape=[None, 2], dtype=dtypes.float32)
+      ])
+      def run(self, inp1, inp2):
+        return inp1 + inp2 * inp2
+
+    self._CompareSavedModel(_Model)
+
+  @test_util.run_v2_only
+  def testRetainSignatureInfo_OneOutputSignatureKey(self):
+
+    class _Model(tracking.AutoTrackable):
+
+      @def_function.function(input_signature=[])
+      def run(self):
+        return {"my_output": array_ops.constant(1.0)}
+
+    self._CompareSavedModel(_Model)
+
+  @test_util.run_v2_only
+  def testRetainSignatureInfo_TwoOutputSignatureKeys(self):
+
+    class _Model(tracking.AutoTrackable):
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
+      ])
+      def run(self, inp):
+        # Here the keys are not ordered lexicographically on purpose.
+        return {
+            "output_b": array_ops.constant(1.0),
+            "output_a": inp + inp * inp
+        }
+
+    self._CompareSavedModel(_Model)
+
+  def _TestRun(self, sess, batch_size, expect_engine_is_run=True):
+    result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+    self.assertAllEqual([[[4.0]]] * batch_size, result)
 
   @test_util.deprecated_graph_mode_only
   def testTrtGraphConverter_MinimumSegmentSize(self):
@@ -478,16 +627,14 @@
     if not is_tensorrt_enabled():
       return
 
-    tmp_dir = self.get_temp_dir()
-    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
-    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+    input_saved_model_dir = self.mkdtemp()
+    output_saved_model_dir = self.mkdtemp()
     self._WriteInputSavedModel(input_saved_model_dir)
     output_graph_def = self._ConvertGraph(
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
         is_dynamic_op=True,
-        maximum_cached_engines=2,
-        use_function_backup=False)  # Disallow fallback.
+        maximum_cached_engines=2)
 
     # Test the output GraphDef.
     with ops.Graph().as_default():
@@ -513,19 +660,17 @@
         # the max, it should evict an old engine and create a new one.
         self._TestRun(sess, 3)
 
-  def _TestStaticOp(self, use_function_backup):
+  def _TestStaticOp(self):
     if not is_tensorrt_enabled():
       return
 
-    tmp_dir = self.get_temp_dir()
-    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
-    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+    input_saved_model_dir = self.mkdtemp()
+    output_saved_model_dir = self.mkdtemp()
     self._WriteInputSavedModel(input_saved_model_dir)
     output_graph_def = self._ConvertGraph(
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
-        maximum_cached_engines=2,  # This is noop, added just for testing.
-        use_function_backup=use_function_backup)
+        maximum_cached_engines=2)
 
     # Test the output GraphDef.
     with ops.Graph().as_default():
@@ -533,18 +678,10 @@
       with self.session(config=self._GetConfigProto()) as sess:
         # Run with batch size 1, the default engine embedded in the graphdef
         # will be used.
-        self._TestRun(
-            sess,
-            1,
-            use_function_backup=use_function_backup,
-            expect_engine_is_run=True)
+        self._TestRun(sess, 1, expect_engine_is_run=True)
         # Run with batch size 2, which exceed the max_batch_size, it should try
         # to fall back to TF function.
-        self._TestRun(
-            sess,
-            2,
-            use_function_backup=use_function_backup,
-            expect_engine_is_run=False)
+        self._TestRun(sess, 2, expect_engine_is_run=False)
 
     # Test the output SavedModel
     with ops.Graph().as_default():
@@ -552,26 +689,14 @@
         loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
         # Run with batch size 1, the default engine embedded in the graphdef
         # will be used.
-        self._TestRun(
-            sess,
-            1,
-            use_function_backup=use_function_backup,
-            expect_engine_is_run=True)
+        self._TestRun(sess, 1, expect_engine_is_run=True)
         # Run with batch size 2, which exceed the max_batch_size, it should try
         # to fall back to TF function.
-        self._TestRun(
-            sess,
-            2,
-            use_function_backup=use_function_backup,
-            expect_engine_is_run=False)
+        self._TestRun(sess, 2, expect_engine_is_run=False)
 
   @test_util.deprecated_graph_mode_only
-  def testTrtGraphConverter_StaticOp_NoFallback(self):
-    self._TestStaticOp(use_function_backup=False)
-
-  @test_util.deprecated_graph_mode_only
-  def testTrtGraphConverter_StaticOp_WithFallback(self):
-    self._TestStaticOp(use_function_backup=True)
+  def testTrtGraphConverter_StaticOp(self):
+    self._TestStaticOp()
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/compiler/xla/xla.py b/tensorflow/python/compiler/xla/xla.py
index 1fa462f..55bfaeb 100644
--- a/tensorflow/python/compiler/xla/xla.py
+++ b/tensorflow/python/compiler/xla/xla.py
@@ -223,7 +223,7 @@
       for index in xrange(len(op.inputs)):
         x = op.inputs[index]
         real_x = self.AddValue(x)
-        if real_x != x:
+        if real_x is not x:
           op._update_input(index, real_x)  # pylint: disable=protected-access
 
     if external_control_inputs:
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 6482458..d609100 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -540,7 +540,7 @@
     size = "medium",
     srcs = ["rejection_resample_test.py"],
     python_version = "PY2",
-    shard_count = 2,
+    shard_count = 5,
     srcs_version = "PY2AND3",
     tags = [
         "noasan",
@@ -574,6 +574,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_test_lib",
@@ -660,6 +661,7 @@
     name = "snapshot_test",
     srcs = ["snapshot_test.py"],
     python_version = "PY2",
+    shard_count = 10,
     srcs_version = "PY2AND3",
     deps = [
         ":reader_dataset_ops_test_base",
diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
index 5d964c1..e9d5b43 100644
--- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py
@@ -24,10 +24,11 @@
 from tensorflow.python.data.experimental.ops import interleave_ops
 from tensorflow.python.data.experimental.ops import readers
 from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
@@ -37,7 +38,6 @@
     yield l[i:i + n]
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                            parameterized.TestCase):
 
@@ -65,7 +65,10 @@
     else:
       self.assertDatasetProduces(dataset, list(chunk(expected, batch)))
 
-  @parameterized.parameters(True, False)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(shuffle=[True, False])))
   def testFlatMapReaderPipeline(self, shuffle):
     dataset = dataset_ops.Dataset.list_files(
         self.test_filenames, shuffle=shuffle)
@@ -80,6 +83,7 @@
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testZipReaderPipeline(self):
     dataset1 = dataset_ops.Dataset.list_files(
         self.test_filenames, shuffle=False)
@@ -101,7 +105,10 @@
 
     self.assertDatasetProduces(dataset, expected)
 
-  @parameterized.parameters(True, False)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(shuffle=[True, False])))
   def testConcatenateReaderPipeline(self, shuffle):
     dataset1 = dataset_ops.Dataset.list_files(
         self.test_filenames, shuffle=shuffle)
@@ -125,7 +132,10 @@
     expected += expected
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)
 
-  @parameterized.parameters(True, False)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(shuffle=[True, False])))
   def testPipelineWithMap(self, shuffle):
     dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
     dataset = dataset.apply(
@@ -141,6 +151,7 @@
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testDirectFilenameTFRecordReaderPipeline(self):
     dataset = core_readers.TFRecordDataset(self.test_filenames)
     dataset = distribute._AutoShardDataset(dataset, 5, 0)
@@ -152,7 +163,10 @@
     ]
     self.assertDatasetProduces(dataset, expected)
 
-  @parameterized.parameters(True, False)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(shuffle=[True, False])))
   def testValidPipelineWithRangeDataset(self, shuffle):
     dataset = dataset_ops.Dataset.range(self._num_files)
     dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
@@ -171,9 +185,13 @@
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
-  @parameterized.parameters((1, 0, 10, 10), (2, 1, 20, 5), (10, 1, 1, 10))
-  def testStandardReaderPipeline(self, num_epochs, index, batch_size,
-                                 parallel_reads):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
+                                       (10, 1, 1, 10)])))
+  def testStandardReaderPipeline(self, params):
+    num_epochs, index, batch_size, parallel_reads = params
     dataset = readers.make_tf_record_dataset(
         file_pattern=self.test_filenames,
         num_epochs=num_epochs,
@@ -195,7 +213,10 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(outputs())
 
-  @parameterized.parameters(True, False)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(shuffle=[True, False])))
   def testSampleResNetPipeline(self, shuffle):
     dataset = dataset_ops.Dataset.list_files(
         self.test_filenames, shuffle=shuffle)
@@ -211,6 +232,7 @@
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWorkersGreaterThanNumFiles(self):
     dataset = dataset_ops.Dataset.list_files(self.test_filenames)
     dataset = dataset.apply(
@@ -219,6 +241,7 @@
     dataset = distribute._AutoShardDataset(dataset, 500, 499)
     self.assertDatasetProduces(dataset, [])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testTFRecordReaderWithDirectFileNames(self):
     # Using `_TFRecordDataset` creates a raw op rather than wrapping it around
     # a flat_map automatically.
@@ -232,6 +255,7 @@
     ]
     self.assertDatasetProduces(dataset, expected)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testTFRecordReaderWithDirectFileNamesAndShapes(self):
     # Using `_TFRecordDataset` creates a raw op rather than wrapping it around
     # a flat_map automatically.
@@ -248,23 +272,27 @@
     ]
     self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
 
+  @combinations.generate(test_base.default_test_combinations())
   def testShardOutOfRange(self):
     dataset = dataset_ops.Dataset.range(5)
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = distribute._AutoShardDataset(dataset, 10, 0)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testShardOutOfRangeEmptyDataset(self):
     dataset = dataset_ops.Dataset.range(0)
     with self.assertRaises(errors.OutOfRangeError):
       dataset = distribute._AutoShardDataset(dataset, 10, 0)
       self.evaluate(self.getNext(dataset)())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testNoReaderPipelines(self):
     dataset = dataset_ops.Dataset.range(1024)
     dataset = distribute._AutoShardDataset(dataset, 2, 0)
     self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testUnknownOpInPipelineStillShardsAtTheEnd(self):
     dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
@@ -279,6 +307,7 @@
     ]
     self.assertDatasetProduces(dataset, expected)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testInvalidWorkerIndex(self):
     dataset = dataset_ops.Dataset.list_files(self.test_filenames)
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
@@ -289,7 +318,6 @@
       self.evaluate(self.getNext(dataset)())
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class AutoShardTextLineDatasetTest(
     reader_dataset_ops_test_base.TextLineDatasetTestBase,
     parameterized.TestCase):
@@ -300,6 +328,7 @@
     self._num_records = 10
     self.test_filenames = self._createFiles(self._num_files, self._num_records)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testDirectFilenameTextLineReaderPipeline(self):
     dataset = core_readers.TextLineDataset(self.test_filenames)
     dataset = distribute._AutoShardDataset(dataset, 5, 0)
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
index 267e3e8..16c323b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
@@ -221,6 +221,54 @@
           compression_type=compression_type,
       )
 
+  def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self):
+    """Tests `compression_type` argument."""
+    record_defaults = [
+        constant_op.constant([], dtypes.int32),
+        constant_op.constant([], dtypes.int64),
+        constant_op.constant([], dtypes.float32),
+        constant_op.constant([], dtypes.float64),
+        constant_op.constant([], dtypes.string)
+    ]
+
+    column_names = ["col%d" % i for i in range(5)]
+    inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
+              [
+                  ",".join(x for x in column_names), "10,11,12,13,14",
+                  "15,16,17,18,19"
+              ]]
+    expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+                       [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+    label = "col0"
+
+    self._test_dataset(
+        inputs,
+        expected_output=expected_output,
+        expected_keys=column_names,
+        label_name=label,
+        batch_size=1,
+        num_epochs=1,
+        shuffle=False,
+        header=True,
+        column_defaults=record_defaults,
+        compression_type="GZIP",
+    )
+
+    with self.assertRaisesRegexp(ValueError,
+                                 "compression_type .ZLIB. is not supported"):
+      self._test_dataset(
+          inputs,
+          expected_output=expected_output,
+          expected_keys=column_names,
+          label_name=label,
+          batch_size=1,
+          num_epochs=1,
+          shuffle=False,
+          header=True,
+          column_defaults=record_defaults,
+          compression_type="ZLIB",
+      )
+
   def testMakeCSVDataset_withBadInputs(self):
     """Tests that exception is raised when input is malformed.
     """
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
index 31b9cd6..ec17603 100644
--- a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
@@ -102,15 +102,17 @@
 
   def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
                     seed=None):
-    dataset = readers.make_tf_record_dataset(
-        file_pattern=self.test_filenames,
-        num_epochs=num_epochs,
-        batch_size=batch_size,
-        num_parallel_reads=num_parallel_reads,
-        shuffle=True,
-        shuffle_seed=seed)
 
-    next_element = self.getNext(dataset)
+    def dataset_fn():
+      return readers.make_tf_record_dataset(
+          file_pattern=self.test_filenames,
+          num_epochs=num_epochs,
+          batch_size=batch_size,
+          num_parallel_reads=num_parallel_reads,
+          shuffle=True,
+          shuffle_seed=seed)
+
+    next_element = self.getNext(dataset_fn())
     first_batches = []
     try:
       while True:
@@ -118,7 +120,7 @@
     except errors.OutOfRangeError:
       pass
 
-    next_element = self.getNext(dataset)
+    next_element = self.getNext(dataset_fn())
     second_batches = []
     try:
       while True:
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index a0253ad..61562f2 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -227,7 +227,7 @@
           array_ops.check_numerics(
               constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
       dataset = dataset.apply(batching.map_and_batch(lambda x: x, 14))
-      get_next = self.getNext(dataset)
+      get_next = self.getNext(dataset, requires_initialization=True)
       self.evaluate(get_next())
 
   def testMapAndBatchShapeMismatch(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
index 824cc68..8d429b0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py
@@ -17,9 +17,11 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import optimization
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
@@ -29,8 +31,13 @@
 class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
 
   def testShuffleAndRepeatFusion(self):
+    if tf2.enabled() and context.executing_eagerly():
+      expected = "Shuffle"
+    else:
+      expected = "ShuffleAndRepeat"
+
     dataset = dataset_ops.Dataset.range(10).apply(
-        optimization.assert_next(["ShuffleAndRepeat"])).shuffle(10).repeat(2)
+        optimization.assert_next([expected])).shuffle(10).repeat(2)
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.shuffle_and_repeat_fusion = True
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
index c36ea68..02523c1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -48,96 +48,98 @@
   return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
 
 
-@parameterized.named_parameters(("WithDropRemainder", True),
-                                ("WithoutDropRemainder", False))
 @test_util.run_all_in_graph_and_eager_modes
-class RebatchDatasetTest(test_base.DatasetTestBase):
+class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  drop_remainder_cases = [("WithDropRemainder", True),
+                          ("WithoutDropRemainder", False)]
+
+  @parameterized.named_parameters(drop_remainder_cases)
   def testBasic(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testScalarInputError(self, _):
+  def testScalarInputError(self):
     dataset = dataset_ops.Dataset.range(1024)
+    distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
     with self.assertRaisesRegexp(ValueError, "at least one dimension"):
-      distribute._RebatchDataset(dataset, num_workers=4)
+      distribute._RebatchDataset(dataset, num_replicas=4)
 
-  def testNotDivisible(self, drop_remainder):
+  @parameterized.named_parameters(drop_remainder_cases)
+  def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
-    expected_output = [[k for k in range(i, i + 7)] for i in range(0, 1022, 7)]  # pylint: disable=g-complex-comprehension
-    if not drop_remainder:
-      expected_output.append([1022, 1023])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    expected_output = []
+    i = 0
+    for _ in range(32):  # number of steps
+      # first four minibatches have seven elements
+      for _ in range(4):
+        expected_output.append([k for k in range(i, i + 7)])
+        i += 7
+      # last minibatch has four elements
+      expected_output.append([k for k in range(i, i + 4)])
+      i += 4
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testTupleOutput(self, drop_remainder):
-    dataset = (
-        dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
-            32, drop_remainder=drop_remainder))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+  def testTupleOutput(self):
+    dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                         [k for k in range(i, i + 8)])
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testNestedDictionaryOutput(self, drop_remainder):
+  def testNestedDictionaryOutput(self):
     dataset = dataset_ops.Dataset.range(1024).map(
-        lambda x: {"a": x, "b": {"c": x}}).batch(
-            32, drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+        lambda x: {"a": x, "b": {"c": x}}).batch(32)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                         "b": {"c": [k for k in range(i, i + 8)]}}
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testFinalPartialBatchOriginal(self, drop_remainder):
+  @parameterized.named_parameters(drop_remainder_cases)
+  def testFinalPartialBatch(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1032).batch(
         32, drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
+    # if drop_remainder, the final partial batch is dropped, even though it
+    # makes up a complete minibatch.
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    if not drop_remainder:
+      expected_output.append([k for k in range(1024, 1032)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
+  @parameterized.named_parameters(drop_remainder_cases)
   def testFinalPartialBatchAfterRebatch(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(34).batch(
         32, drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
     if not drop_remainder:
       expected_output += [[32, 33]]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testMultipleBatches(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(128).batch(
-        4, drop_remainder=drop_remainder)
-    dataset = dataset.batch(8, drop_remainder=drop_remainder)
-    self.assertEqual(
-        [[8, 4]] if drop_remainder else [[None, None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testMultipleBatches(self):
+    dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
+    self.assertEqual([[None, None]],
+                     [ts.as_list() for ts in _flat_shapes(dataset)])
+
     # Each element is a list of 8 elements where each element is a list of 4.
     expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 32, 4)]  # generates 8 elements
@@ -145,39 +147,30 @@
     self.assertDatasetProduces(dataset, expected_output)
 
     rebatched_dataset = distribute._RebatchDataset(dataset, 4)
-    self.assertEqual(
-        [[2, 4]] if drop_remainder else [[None, None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None, None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Each element is a list of 2 elements where each element is a list of 4.
     expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 8, 4)]  # generates 2 elements
                        for i in range(0, 128, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testMapAndBatch(self, drop_remainder):
+  def testMapAndBatch(self):
     dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(
-            math_ops.square, 32, drop_remainder=drop_remainder))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+        batching.map_and_batch(math_ops.square, 32))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testMapAndBatchWithCapturedInput(self, drop_remainder):
+  def testMapAndBatchWithCapturedInput(self):
     captured_t = variables.Variable(42)
     dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(
-            lambda x: captured_t, 32, drop_remainder=drop_remainder))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual([[32 if drop_remainder else None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8 if drop_remainder else None]],
+        batching.map_and_batch(lambda x: captured_t, 32))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                        for i in range(0, 1024, 8)]
@@ -185,22 +178,19 @@
     self.assertDatasetProduces(
         rebatched_dataset, expected_output, requires_initialization=True)
 
-  def testPaddedBatch(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
-        8, padded_shapes=[5], drop_remainder=drop_remainder)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8, 5]] if drop_remainder else [[None, 5]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testPaddedBatch(self):
+    dataset = dataset_ops.Dataset.range(128).batch(
+        4, drop_remainder=True).padded_batch(
+            8, padded_shapes=[5])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     # Each element is a list of 8 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 32, 4)]  # generates 8 elements
                        for i in range(0, 128, 32)]
     self.assertDatasetProduces(dataset, expected_output)
-    self.assertEqual(
-        [[2, 5]] if drop_remainder else [[None, 5]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None, 5]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Each element is a list of 2 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
@@ -208,32 +198,22 @@
                        for i in range(0, 128, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testConcatenate(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        8, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testConcatenate(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(8)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset1.concatenate(dataset2)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[2 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
                        [[i, i + 1] for i in range(0, 32, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testConcatenateDifferentShapes(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        16, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testConcatenateDifferentShapes(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(16)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset1.concatenate(dataset2)
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     self.assertEqual(
         [[None]],
         [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
@@ -241,73 +221,56 @@
                        [[i, i + 1] for i in range(0, 32, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testZip(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        8, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testZip(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(8)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8], [8]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[2], [2]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None], [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testZipDifferentShapes(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        16, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testZipDifferentShapes(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(16)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[16], [8]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[4], [2]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None], [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
                        for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testUnsupportedTransformError(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(1024).batch(
-        32, drop_remainder=drop_remainder).apply(sleep.sleep(10))
+  def testUnsupportedTransformError(self):
+    dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
     with self.assertRaises(errors.InvalidArgumentError):
       rebatched_dataset = distribute._RebatchDataset(
-          dataset, num_workers=4, use_fallback=False)
+          dataset, num_replicas=4, use_fallback=False)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
-  def testUnsupportedTransformInFlatMapError(self, drop_remainder):
+  def testUnsupportedTransformInFlatMapError(self):
     dataset = dataset_ops.Dataset.range(2).flat_map(
         lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder).apply(sleep.sleep(10)))
+            32).apply(sleep.sleep(10)))
     with self.assertRaises(errors.InvalidArgumentError):
       rebatched_dataset = distribute._RebatchDataset(
-          dataset, num_workers=4, use_fallback=False)
+          dataset, num_replicas=4, use_fallback=False)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
-  def testFlatMapBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder))
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testFlatMapBatching(self):
+    dataset = dataset_ops.Dataset.range(2).flat_map(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32))
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)
 
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Two elements where each element is a list of 4 elements where each element
     # is a list of 8.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -315,21 +278,18 @@
                        for i in range(0, 32, 8)]  # generates 4 elements
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testInterleaveBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder), cycle_length=2)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testInterleaveBatching(self):
+    dataset = dataset_ops.Dataset.range(2).interleave(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32),
+        cycle_length=2)
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)
 
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # List of 4 elements where each element is a list of 8 numbering from 0 to
     # 31 repeated twice.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -337,22 +297,19 @@
                        for _ in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testParallelInterleaveBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder), cycle_length=2,
-                      num_parallel_calls=2)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testParallelInterleaveBatching(self):
+    dataset = dataset_ops.Dataset.range(2).interleave(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32),
+        cycle_length=2,
+        num_parallel_calls=2)
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)
 
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # List of 4 elements where each element is a list of 8 numbering from 0 to
     # 31 repeated twice in collated fashion i.e [0...8], [0...8] etc.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -360,17 +317,17 @@
                        for _ in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testGroupByWindowStaticBatch(self, drop_remainder):
+  def testGroupByWindowStaticBatch(self):
     dataset = dataset_ops.Dataset.from_tensor_slices(
         [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
     reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
-        batch_size=10, drop_remainder=drop_remainder)
+        batch_size=10)
     dataset = dataset.apply(
         grouping.group_by_window(
             key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2)
 
-    self.assertEqual([[5, 3] if drop_remainder else [None, 3]],
+    self.assertEqual([[None, 3]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # pylint: disable=g-complex-comprehension
     expected_output = [[[j + i * 4 + k * 20] * 3
@@ -379,36 +336,90 @@
                        for k in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testGroupByWindowDynamicBatch(self, drop_remainder):
+  def testGroupByWindowDynamicBatch(self):
+    # {0, 1, 0, 1, ...}
     dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
-    reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
-        batch_size=(bucket_id + 1) * 5, drop_remainder=drop_remainder)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5)
+
     dataset = dataset.apply(
         grouping.group_by_window(
             key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
-    dataset = distribute._RebatchDataset(dataset, num_workers=2)
+    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
 
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
-    pairs = [(3, 0), (3, 0), (3, 0)]
-    if not drop_remainder:
-      pairs.extend([(1, 0)])
-    pairs.extend([(5, 1), (5, 1)])
+
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)]
     pairs = pairs * 2
     expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
 
-  def testScanAfterBatch(self, drop_remainder):
+  def testGroupByWindowDynamicBatchWithPartialBatch(self):
+    # {0, 1, 0, 1, ...}
+    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5)
+
+    dataset = dataset.apply(
+        grouping.group_by_window(
+            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
+    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
+
+    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
+
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
+             (3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
+    self.assertDatasetProduces(dataset, expected_output)
+
+  def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
+    # This test exercises nested batch functionality, dynamic batch size
+    # and drop_remainder=True together.
+    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True)
+
+    dataset = dataset.apply(
+        grouping.group_by_window(
+            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
+    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
+
+    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
+
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)]
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
+    self.assertDatasetProduces(dataset, expected_output)
+
+  def testScanAfterBatch(self):
     dataset = dataset_ops.Dataset.range(40).batch(10).apply(
         scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
-    dataset = distribute._RebatchDataset(dataset, num_workers=2)
+    dataset = distribute._RebatchDataset(dataset, num_replicas=2)
 
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
     expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)
 
-  def testMakeBatchedFeaturesDataset(self, drop_remainder):
+  def testMakeBatchedFeaturesDataset(self):
     # Set up
     fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
     writer = python_io.TFRecordWriter(fn)
@@ -429,13 +440,11 @@
         features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)},
         shuffle=False,
         num_epochs=1,
-        drop_final_batch=drop_remainder)
+        drop_final_batch=False)
 
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
-    self.assertEqual([[32 if drop_remainder else None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8 if drop_remainder else None]],
+    self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
     expected_output = [{
@@ -450,7 +459,7 @@
   def testWithNoBatchDataset(self):
     dataset = dataset_ops.Dataset.from_tensor_slices(
         [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)])  # pylint: disable=g-complex-comprehension
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
     self.assertEqual([[8]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
@@ -461,7 +470,7 @@
   def testWithUnhandledTransformation(self):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=True).apply(sleep.sleep(10))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
     self.assertEqual([[8]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
@@ -473,7 +482,7 @@
     dataset = dataset_ops.Dataset.range(2).flat_map(
         lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
             32, drop_remainder=True).apply(sleep.sleep(10)))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
     self.assertEqual([[8]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
@@ -491,7 +500,7 @@
 
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
@@ -503,11 +512,11 @@
 
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
+      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
-  def testBatchSizeIndivisibleByNumWorkers(self):
+  def testBatchSizeNotDivisibleByNumReplicas(self):
     # This doesn't work; reshape requires tensor shape to be exactly divisible
     # by the second dim.
     dataset = dataset_ops.Dataset.range(64).batch(
@@ -515,7 +524,7 @@
 
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
+      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
@@ -523,7 +532,7 @@
     dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
+      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
index 063e123..673e77f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
@@ -46,7 +46,8 @@
     initial_dist = [0.2] * 5 if initial_known else None
     classes = math_ops.cast(classes, dtypes.int64)  # needed for Windows build.
     dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
-        200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
+        200, seed=21, reshuffle_each_iteration=False).map(
+            lambda c: (c, string_ops.as_string(c))).repeat()
 
     get_next = self.getNext(
         dataset.apply(
diff --git a/tensorflow/python/data/experimental/kernel_tests/replicate_cluster_test.py b/tensorflow/python/data/experimental/kernel_tests/replicate_cluster_test.py
index 24913d4..5cf8419 100644
--- a/tensorflow/python/data/experimental/kernel_tests/replicate_cluster_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/replicate_cluster_test.py
@@ -33,6 +33,7 @@
 class ReplicateClusterTest(test_base.DatasetTestBase):
 
   def setUp(self):
+    super(ReplicateClusterTest, self).setUp()
     # Start the local server.
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 2
@@ -99,7 +100,7 @@
       it1 = dataset_ops.make_initializable_iterator(dataset1)
     # We don't support stateful ops in functions as of now.
     with session.Session(self._target) as sess:
-      with self.assertRaises(errors.InvalidArgumentError):
+      with self.assertRaises(errors.FailedPreconditionError):
         sess.run(it1.initializer)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py
index 55b8d25..120ad59 100644
--- a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py
@@ -73,13 +73,10 @@
       dataset0 = dataset_ops.Dataset.range(100).map(
           lambda _: counter_var.assign_add(1))
     # We don't support stateful ops in functions as of now.
-    with self.assertRaises(errors.InvalidArgumentError):
+    with self.assertRaises(errors.FailedPreconditionError):
       replicated_ds = distribute.replicate(dataset0,
                                            [self._device1, self._device2])
-      dataset1 = replicated_ds[self._device1]
-      with ops.device(self._device1):
-        self.assertDatasetProduces(
-            dataset1, range(100), requires_initialization=True)
+      self.evaluate(replicated_ds[self._device1]._variant_tensor)
 
 
 JOB_NAME = "remote_device"
@@ -120,6 +117,7 @@
     self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME
 
   def setUp(self):
+    super(RemoteReplicateTest, self).setUp()
     # Start the local server.
     local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
     context.set_server_def(
@@ -169,13 +167,10 @@
       dataset0 = dataset_ops.Dataset.range(100).map(
           lambda _: counter_var.assign_add(1))
     # We don't support stateful ops in functions as of now.
-    with self.assertRaises(errors.InvalidArgumentError):
+    with self.assertRaises(errors.FailedPreconditionError):
       replicated_ds = distribute.replicate(dataset0,
                                            [self._device1, self._device2])
-      dataset1 = replicated_ds[self._device1]
-      with ops.device(self._device1):
-        self.assertDatasetProduces(
-            dataset1, range(100), requires_initialization=True)
+      self.evaluate(replicated_ds[self._device1]._variant_tensor)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
index 0932a25..8f059c4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/scan_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import tensor_array_ops
@@ -156,6 +157,9 @@
 
   def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
 
+    if control_flow_v2_toggles.control_flow_v2_enabled():
+      self.skipTest("v1 only test")
+
     empty_ta = tensor_array_ops.TensorArray(
         size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py
index 5bf8365..ee1792f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py
@@ -58,7 +58,7 @@
       dataset = distribute._AutoShardDataset(dataset, 5, 3)
       return dataset
 
-    self.run_core_tests(build_dataset, None, 20)
+    self.run_core_tests(build_dataset, 20)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
index d72a6df..8766a1c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -44,7 +44,6 @@
     num_outputs = tensor_slice_len // batch_size
     self.run_core_tests(
         lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
-        lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
         num_outputs)
 
   def _build_dataset_dense_to_sparse(self, components):
@@ -54,11 +53,9 @@
 
   def testDenseToSparseBatchDatasetCore(self):
     components = np.random.randint(5, size=(40,)).astype(np.int32)
-    diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
 
     num_outputs = len(components) // 4
     self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
-                        lambda: self._build_dataset_dense_to_sparse(diff_comp),
                         num_outputs)
 
   def _sparse(self, i):
@@ -69,14 +66,13 @@
     return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
 
   def testSparseCore(self):
-    self.run_core_tests(self._build_dataset_sparse,
-                        lambda: self._build_dataset_sparse(2), 2)
+    self.run_core_tests(self._build_dataset_sparse, 2)
 
   def _build_dataset_nested_sparse(self):
     return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
 
   def testNestedSparseCore(self):
-    self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+    self.run_core_tests(self._build_dataset_nested_sparse, 1)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
index 2bcf77f..0f86e44 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -85,24 +85,14 @@
         ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
     self.assertSequenceEqual(outputs, range(8))
 
-    if is_memory:
-      outputs = outputs[:5]
-      outputs.extend(
-          self.gen_outputs(
-              ds_fn, [],
-              self.num_outputs - 5,
-              ckpt_saved=True,
-              verify_exhausted=False))
-      self.assertSequenceEqual(outputs, self.expected_outputs())
-    else:
-      # Restoring from checkpoint and running GetNext should return
-      # `AlreadExistsError` now because the lockfile already exists.
-      with self.assertRaises(errors.AlreadyExistsError):
+    outputs = outputs[:5]
+    outputs.extend(
         self.gen_outputs(
             ds_fn, [],
             self.num_outputs - 5,
             ckpt_saved=True,
-            verify_exhausted=False)
+            verify_exhausted=False))
+    self.assertSequenceEqual(outputs, self.expected_outputs())
 
   @parameterized.named_parameters(
       ('Memory', True),
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py
index eaedcae..d73420c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py
@@ -46,7 +46,7 @@
           ratio_numerator=10)
 
     for size in [100, 1000]:
-      self.run_core_tests(lambda: build_ds(size), None, size // 10)  # pylint: disable=cell-var-from-loop
+      self.run_core_tests(lambda: build_ds(size), size // 10)  # pylint: disable=cell-var-from-loop
 
   def testWithCapture(self):
 
@@ -64,7 +64,7 @@
       return optimization._ChooseFastestBranchDataset(
           dataset, [branch_0, branch_1], num_elements_per_branch=3)
 
-    self.run_core_tests(build_ds, None, 10)
+    self.run_core_tests(build_ds, 10)
 
   def testWithPrefetch(self):
 
@@ -82,7 +82,7 @@
       return optimization._ChooseFastestBranchDataset(
           dataset, [branch_0, branch_1], num_elements_per_branch=3)
 
-    self.run_core_tests(build_ds, None, 10)
+    self.run_core_tests(build_ds, 10)
 
   def testWithMoreOutputThanInput(self):
 
@@ -97,7 +97,7 @@
           ratio_denominator=10,
           num_elements_per_branch=100)
 
-    self.run_core_tests(build_ds, None, 1000)
+    self.run_core_tests(build_ds, 1000)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py
index 936dc22..73146a5 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py
@@ -38,7 +38,7 @@
           dataset.batch(batch_size).map(map_fn)
       ])
 
-    self.run_core_tests(build_ds, None, num_outputs // 2)
+    self.run_core_tests(build_ds, num_outputs // 2)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
index c075dff..968c858 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -39,9 +39,7 @@
   def testConcatenateCore(self):
     num_outputs = 9
     array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
-    diff_array = np.array([[1], [2], [3], [4], [5]])
     self.run_core_tests(lambda: self._build_concatenate_dataset(array),
-                        lambda: self._build_concatenate_dataset(diff_array),
                         num_outputs)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
index d498349..c1c91a6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -65,7 +65,6 @@
     defs = [[0]] * self._num_cols
     self.run_core_tests(
         lambda: self.ds_func(record_defaults=defs, buffer_size=2),
-        lambda: self.ds_func(record_defaults=defs, buffer_size=12),
         self._num_outputs)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
index 41a095f..2c31c23 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -37,9 +37,7 @@
     # Equal length components
     arr = np.array(1)
     num_outputs = 1
-    diff_arr = np.array(2)
     self.run_core_tests(lambda: self._build_tensor_dataset(arr),
-                        lambda: self._build_tensor_dataset(diff_arr),
                         num_outputs)
 
 
@@ -55,16 +53,12 @@
                   np.tile(np.array([[12], [13], [14], [15]]), 22),
                   np.array([37.0, 38.0, 39.0, 40.0]))
 
-    diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
-                 np.tile(np.array([[5], [6], [7], [8]]), 22),
-                 np.array([1.0, 2.0, 3.0, 4.0]))
-
     dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
 
     self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
-                        lambda: self._build_tensor_slices_dataset(diff_comp), 4)
+                        4)
     self.run_core_tests(
-        lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
+        lambda: self._build_tensor_slices_dataset(dict_components), 3)
 
 
 class FromSparseTensorSlicesSerializationTest(
@@ -82,11 +76,9 @@
 
   def testFromSparseTensorSlicesCore(self):
     slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
-    diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
 
     self.run_core_tests(
         lambda: self._build_sparse_tensor_slice_dataset(slices),
-        lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
         9,
         sparse_tensors=True)
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
index d4f377d..f6ab5a1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -24,7 +24,6 @@
 
 from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -62,13 +61,11 @@
   # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
   # (deprecated) saveable `SparseTensorSliceDataset`, once the API
   # `from_sparse_tensor_slices()`and related tests are deleted.
-  def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+  def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False):
     """Runs the core tests.
 
     Args:
-      ds_fn1: 0-argument function that returns a Dataset.
-      ds_fn2: 0-argument function that returns a Dataset different from
-        ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+      ds_fn: 0-argument function that returns a Dataset.
       num_outputs: Total number of outputs expected from this Dataset.
       sparse_tensors: Whether dataset is built from SparseTensor(s).
 
@@ -80,33 +77,19 @@
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
 
-    def ds_fn1_no_opt():
-      return ds_fn1().with_options(options)
+    def ds_fn_no_opt():
+      return ds_fn().with_options(options)
 
     self.verify_unused_iterator(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
+        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
     self.verify_fully_used_iterator(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
+        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
     self.verify_exhausted_iterator(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
-    self.verify_init_before_restore(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
+        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
     self.verify_multiple_breaks(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
+        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
     self.verify_reset_restored_iterator(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
-    self.verify_restore_in_empty_graph(
-        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
-    if ds_fn2:
-
-      def ds_fn2_no_opt():
-        return ds_fn2().with_options(options)
-
-      self.verify_restore_in_modified_graph(
-          ds_fn1_no_opt,
-          ds_fn2_no_opt,
-          num_outputs,
-          sparse_tensors=sparse_tensors)
+        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
 
   def verify_unused_iterator(self,
                              ds_fn,
@@ -176,30 +159,6 @@
         sparse_tensors=sparse_tensors)
     self.assertEqual(len(actual), 0)
 
-  def verify_init_before_restore(self,
-                                 ds_fn,
-                                 num_outputs,
-                                 sparse_tensors=False,
-                                 verify_exhausted=True):
-    """Verifies that restoring into an already initialized iterator works.
-
-    Args:
-      ds_fn: See `run_core_tests`.
-      num_outputs: See `run_core_tests`.
-      sparse_tensors: See `run_core_tests`.
-      verify_exhausted: See `gen_outputs`.
-
-    Raises:
-      AssertionError if any test fails.
-    """
-    self.verify_run_with_breaks(
-        ds_fn,
-        self.gen_break_points(num_outputs),
-        num_outputs,
-        init_before_restore=True,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=verify_exhausted)
-
   def verify_multiple_breaks(self,
                              ds_fn,
                              num_outputs,
@@ -270,6 +229,7 @@
           ds_fn, sparse_tensors=sparse_tensors)
       get_next_op = remove_variants(get_next_op)
       with self.session(graph=g) as sess:
+        self._initialize(init_op, sess)
         self._restore(saver, sess)
         self._initialize(init_op, sess)
         for _ in range(num_outputs):
@@ -279,130 +239,6 @@
             sess.run(get_next_op)
     self.match(expected, actual)
 
-  def verify_restore_in_modified_graph(self,
-                                       ds_fn1,
-                                       ds_fn2,
-                                       num_outputs,
-                                       break_point=None,
-                                       sparse_tensors=False,
-                                       verify_exhausted=True):
-    """Attempts to restore an iterator in a modified graph.
-
-    Builds an input pipeline using ds_fn1, runs it for `break_point` steps
-    and saves a checkpoint. Then builds a new graph using ds_fn2, restores
-    the checkpoint from ds_fn1 and verifies that the restore is successful.
-
-    Args:
-      ds_fn1: See `run_core_tests`.
-      ds_fn2: See `run_core_tests`.
-      num_outputs: See `run_core_tests`.
-      break_point: Break point. Optional. Defaults to num_outputs/2.
-      sparse_tensors: See `run_core_tests`.
-      verify_exhausted: See `gen_outputs`.
-
-    Raises:
-      AssertionError if any test fails.
-    """
-    break_point = num_outputs // 2 if not break_point else break_point
-
-    # Skip `break_point` items and store the remaining produced from ds_fn1
-    # in `expected`.
-    self.gen_outputs(
-        ds_fn1, [],
-        break_point,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=False)
-    expected = self.gen_outputs(
-        ds_fn1, [],
-        num_outputs - break_point,
-        ckpt_saved=True,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=verify_exhausted)
-
-    # Generate `break_point` items from ds_fn1 and save checkpoint.
-    self.gen_outputs(
-        ds_fn1, [],
-        break_point,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=False)
-
-    actual = []
-    # Build graph for ds_fn2 but load checkpoint for ds_fn1.
-    with ops.Graph().as_default() as g:
-      _, get_next_op, saver = self._build_graph(
-          ds_fn2, sparse_tensors=sparse_tensors)
-      get_next_op = remove_variants(get_next_op)
-      with self.session(graph=g) as sess:
-        self._restore(saver, sess)
-        for _ in range(num_outputs - break_point):
-          actual.append(sess.run(get_next_op))
-        if verify_exhausted:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    self.match(expected, actual)
-
-  def verify_restore_in_empty_graph(self,
-                                    ds_fn,
-                                    num_outputs,
-                                    break_point=None,
-                                    sparse_tensors=False,
-                                    verify_exhausted=True):
-    """Attempts to restore an iterator in an empty graph.
-
-    Builds an input pipeline using ds_fn, runs it for `break_point` steps
-    and saves a checkpoint. Then builds a new empty graph, restores
-    the checkpoint from ds_fn and verifies that the restore is successful.
-
-    Args:
-      ds_fn: See `run_core_tests`.
-      num_outputs: See `run_core_tests`.
-      break_point: Break point. Optional. Defaults to num_outputs/2.
-      sparse_tensors: See `run_core_tests`.
-      verify_exhausted: See `gen_outputs`.
-
-    Raises:
-      AssertionError if any test fails.
-    """
-    break_point = num_outputs // 2 if not break_point else break_point
-
-    # Skip `break_point` items and store the remaining produced from ds_fn
-    # in `expected`.
-    self.gen_outputs(
-        ds_fn, [],
-        break_point,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=False)
-    expected = self.gen_outputs(
-        ds_fn, [],
-        num_outputs - break_point,
-        ckpt_saved=True,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=verify_exhausted)
-
-    # Generate `break_point` items from ds_fn and save checkpoint.
-    self.gen_outputs(
-        ds_fn, [],
-        break_point,
-        sparse_tensors=sparse_tensors,
-        verify_exhausted=False)
-
-    actual = []
-    # Build an empty graph but load checkpoint for ds_fn.
-    with ops.Graph().as_default() as g:
-      get_next_op, saver = self._build_empty_graph(
-          ds_fn, sparse_tensors=sparse_tensors)
-      get_next_op = remove_variants(get_next_op)
-      with self.session(graph=g) as sess:
-        self._restore(saver, sess)
-        for _ in range(num_outputs - break_point):
-          actual.append(sess.run(get_next_op))
-        if verify_exhausted:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    self.match(expected, actual)
-
   def verify_error_on_save(self,
                            ds_fn,
                            num_outputs,
@@ -438,7 +274,6 @@
                              ds_fn,
                              break_points,
                              num_outputs,
-                             init_before_restore=False,
                              sparse_tensors=False,
                              verify_exhausted=True):
     """Verifies that ds_fn() produces the same outputs with and without breaks.
@@ -454,7 +289,6 @@
       ds_fn: See `gen_outputs`.
       break_points: See `gen_outputs`.
       num_outputs: See `gen_outputs`.
-      init_before_restore: See `gen_outputs`.
       sparse_tensors: See `run_core_tests`.
       verify_exhausted: See `gen_outputs`.
 
@@ -464,7 +298,6 @@
     expected = self.gen_outputs(
         ds_fn, [],
         num_outputs,
-        init_before_restore=init_before_restore,
         sparse_tensors=sparse_tensors,
         verify_exhausted=verify_exhausted)
 
@@ -472,7 +305,6 @@
         ds_fn,
         break_points,
         num_outputs,
-        init_before_restore=init_before_restore,
         sparse_tensors=sparse_tensors,
         verify_exhausted=verify_exhausted)
 
@@ -483,7 +315,6 @@
                   break_points,
                   num_outputs,
                   ckpt_saved=False,
-                  init_before_restore=False,
                   sparse_tensors=False,
                   verify_exhausted=True,
                   save_checkpoint_at_end=True):
@@ -501,11 +332,7 @@
         produce outputs till next checkpoint or till `num_outputs` elements
         have been produced. `break_point` must be <= `num_outputs`.
       num_outputs: The total number of outputs to produce from the iterator.
-      ckpt_saved: Whether a checkpoint already exists. If False, we build the
-        graph from ds_fn.
-      init_before_restore: Whether init should be called before saver.restore.
-        This is just so that we can verify that restoring an already initialized
-        iterator works.
+      ckpt_saved: Whether a checkpoint already exists.
       sparse_tensors:  Whether dataset is built from SparseTensor(s).
       verify_exhausted: Whether to verify that the iterator has been exhausted
         after producing `num_outputs` elements.
@@ -535,8 +362,7 @@
         get_next_op = remove_variants(get_next_op)
         with self.session(graph=g) as sess:
           if ckpt_saved:
-            if init_before_restore:
-              self._initialize(init_op, sess)
+            self._initialize(init_op, sess)
             self._restore(saver, sess)
           else:
             self._initialize(init_op, sess)
@@ -584,13 +410,11 @@
         for item1, item2 in zip(expected, actual):
           self.match(item1, item2)
     elif isinstance(expected, sparse_tensor.SparseTensorValue):
-      return self.match(
-          (expected.indices, expected.values, expected.dense_shape),
-          (actual.indices, actual.values, actual.dense_shape))
+      self.match((expected.indices, expected.values, expected.dense_shape),
+                 (actual.indices, actual.values, actual.dense_shape))
     elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
-      return self.match(
-          (expected.values, expected.row_splits),
-          (actual.values, actual.row_splits))
+      self.match((expected.values, expected.row_splits),
+                 (actual.values, actual.row_splits))
     else:
       self.assertEqual(expected, actual)
 
@@ -617,20 +441,6 @@
     saver = saver_lib.Saver(allow_empty=True)
     return init_op, get_next, saver
 
-  def _build_empty_graph(self, ds_fn, sparse_tensors=False):
-    iterator = iterator_ops.Iterator.from_structure(
-        self._get_output_types(ds_fn),
-        output_shapes=self._get_output_shapes(ds_fn),
-        output_classes=self._get_output_classes(ds_fn))
-    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
-    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
-    if sparse_tensors:
-      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
-    else:
-      get_next = iterator.get_next()
-    saver = saver_lib.Saver(allow_empty=True)
-    return get_next, saver
-
   def _add_iterator_ops_to_collection(self,
                                       init_op,
                                       get_next,
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
index e3ba8ad..4aaf450 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -35,7 +35,6 @@
     div = 3
     num_outputs = sum(x % 3 != 2 for x in range(100))
     self.run_core_tests(lambda: self._build_filter_range_graph(div),
-                        lambda: self._build_filter_range_graph(div * 2),
                         num_outputs)
 
   def _build_filter_dict_graph(self):
@@ -46,7 +45,7 @@
 
   def testFilterDictCore(self):
     num_outputs = sum((x**2) % 2 == 0 for x in range(10))
-    self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
+    self.run_core_tests(self._build_filter_dict_graph, num_outputs)
 
   def _build_sparse_filter(self):
 
@@ -62,7 +61,7 @@
 
   def testSparseCore(self):
     num_outputs = 5
-    self.run_core_tests(self._build_sparse_filter, None, num_outputs)
+    self.run_core_tests(self._build_sparse_filter, num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
index 70caf3e..4a9c6b1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -37,7 +37,6 @@
     num_epochs = 5
     num_outputs = num_epochs * self._num_files * self._num_records
     self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
-                        lambda: self._build_iterator_graph(num_epochs * 2),
                         num_outputs)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
index e18cfa5..b2da2c7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -43,7 +43,7 @@
 
       return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
 
-    self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25)
+    self.run_core_tests(lambda: build_ds(0), 25)
 
   def testMapThenFlatMap(self):
 
@@ -58,7 +58,7 @@
 
       return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
 
-    self.run_core_tests(build_ds, None, 500)
+    self.run_core_tests(build_ds, 500)
 
   def testCaptureDefunInMapFn(self):
 
@@ -74,7 +74,7 @@
 
       return dataset_ops.Dataset.range(100).flat_map(map_fn)
 
-    self.run_core_tests(build_ds, None, 100)
+    self.run_core_tests(build_ds, 100)
 
   def testDisallowVariableCapture(self):
 
@@ -84,7 +84,7 @@
       return dataset_ops.Dataset.range(5).flat_map(
           lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))
 
-    self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError)
+    self.verify_error_on_save(build_ds, 5, errors.FailedPreconditionError)
 
   def testDisallowCapturingStatefulOps(self):
 
@@ -100,7 +100,7 @@
 
       return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
 
-    self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
+    self.verify_error_on_save(build_ds, 500, errors.FailedPreconditionError)
 
   def testSparseCore(self):
 
@@ -115,7 +115,7 @@
     def _build_ds():
       return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
 
-    self.run_core_tests(_build_ds, None, 20)
+    self.run_core_tests(_build_ds, 20)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
index 169c884..d2f1ffb 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -41,20 +41,10 @@
     components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
     self.verify_unused_iterator(
         lambda: self._build_dataset(components), 5, verify_exhausted=True)
-    self.verify_init_before_restore(
-        lambda: self._build_dataset(components), 5, verify_exhausted=True)
     self.verify_multiple_breaks(
         lambda: self._build_dataset(components), 5, verify_exhausted=True)
     self.verify_reset_restored_iterator(
         lambda: self._build_dataset(components), 5, verify_exhausted=True)
-    self.verify_restore_in_empty_graph(
-        lambda: self._build_dataset(components), 5, verify_exhausted=True)
-    diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
-    self.verify_restore_in_modified_graph(
-        lambda: self._build_dataset(components),
-        lambda: self._build_dataset(diff_components),
-        5,
-        verify_exhausted=True)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
index e5bc762..69e28d4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -37,20 +37,10 @@
         [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
     self.verify_unused_iterator(
         lambda: self._build_dataset(components), 12, verify_exhausted=False)
-    self.verify_init_before_restore(
-        lambda: self._build_dataset(components), 12, verify_exhausted=False)
     self.verify_multiple_breaks(
         lambda: self._build_dataset(components), 12, verify_exhausted=False)
     self.verify_reset_restored_iterator(
         lambda: self._build_dataset(components), 12, verify_exhausted=False)
-    self.verify_restore_in_empty_graph(
-        lambda: self._build_dataset(components), 12, verify_exhausted=False)
-    diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
-    self.verify_restore_in_modified_graph(
-        lambda: self._build_dataset(components),
-        lambda: self._build_dataset(diff_components),
-        12,
-        verify_exhausted=False)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
index df1f431..5858bd2 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -17,8 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
 from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
 from tensorflow.python.data.experimental.ops import error_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -29,17 +27,14 @@
 class IgnoreErrorsSerializationTest(
     dataset_serialization_test_base.DatasetSerializationTestBase):
 
-  def _build_ds(self, components):
-    return dataset_ops.Dataset.from_tensor_slices(components).map(
-        lambda x: array_ops.check_numerics(x, "message")).apply(
+  def _build_ds(self):
+    return dataset_ops.Dataset.range(5).map(
+        array_ops.ones).map(lambda x: array_ops.gather(x, [0])).apply(
             error_ops.ignore_errors())
 
   def testIgnoreErrorsCore(self):
-    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-    diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
     num_outputs = 4
-    self.run_core_tests(lambda: self._build_ds(components),
-                        lambda: self._build_ds(diff_components), num_outputs)
+    self.run_core_tests(self._build_ds, num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
index 0c1d40c..f3daffb 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -57,8 +57,6 @@
     self.run_core_tests(
         lambda: self._build_iterator_graph(
             input_values, cycle_length, block_length, num_parallel_calls),
-        lambda: self._build_iterator_graph(
-            input_values, cycle_length * 2, block_length, num_parallel_calls),
         num_outputs)
     # pylint: enable=g-long-lambda
 
@@ -76,7 +74,7 @@
       return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
           _interleave_fn, cycle_length=1)
 
-    self.run_core_tests(_build_dataset, None, 20)
+    self.run_core_tests(_build_dataset, 20)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
index 8bfe6ce..9cffd39 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -52,10 +52,8 @@
                   num_parallel_batches=num_parallel_batches,
                   drop_remainder=drop_remainder))
 
-    self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
-                        num_outputs_keep_remainder)
-    self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
-                        num_outputs_drop_remainder)
+    self.run_core_tests(lambda: build_ds(10), num_outputs_keep_remainder)
+    self.run_core_tests(lambda: build_ds(10, True), num_outputs_drop_remainder)
 
   def testNumParallelCalls(self):
     range_size = 11
@@ -79,10 +77,8 @@
                   num_parallel_calls=num_parallel_calls,
                   drop_remainder=drop_remainder))
 
-    self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
-                        num_outputs_keep_remainder)
-    self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
-                        num_outputs_drop_remainder)
+    self.run_core_tests(lambda: build_ds(10), num_outputs_keep_remainder)
+    self.run_core_tests(lambda: build_ds(10, True), num_outputs_drop_remainder)
 
   def testSparse(self):
 
@@ -95,7 +91,7 @@
       return dataset_ops.Dataset.range(10).apply(
           batching.map_and_batch(map_fn, 5))
 
-    self.run_core_tests(build_dataset, None, 2)
+    self.run_core_tests(build_dataset, 2)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
index a8667c2..7380172 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -53,10 +53,7 @@
         .repeat(self._num_epochs))
 
   def testSaveRestoreCore(self):
-    self.run_core_tests(
-        self._build_ds,
-        lambda: self._build_ds(multiplier=15.0),
-        self._num_outputs)
+    self.run_core_tests(self._build_ds, self._num_outputs)
 
   def testSaveStatefulFunction(self):
 
@@ -68,7 +65,7 @@
 
       return dataset_ops.Dataset.range(100).map(_map_fn)
 
-    self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)
 
   def testCaptureVariableInMapFn(self):
 
@@ -78,7 +75,7 @@
       return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
           lambda _: counter_var.assign_add(1)))
 
-    self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)
 
   def testCaptureConstantInMapFn(self):
 
@@ -87,7 +84,7 @@
       return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
           lambda x: x + constant_var))
 
-    self.run_core_tests(_build_ds, None, 10)
+    self.run_core_tests(_build_ds, 10)
 
   def testCaptureDefunInMapFn(self):
     num_outputs = 100
@@ -100,7 +97,7 @@
 
       return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
 
-    self.run_core_tests(_build_ds, None, num_outputs)
+    self.run_core_tests(_build_ds, num_outputs)
 
   def testBuildDefunInMapFn(self):
     num_outputs = 100
@@ -119,7 +116,7 @@
 
       return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
 
-    self.run_core_tests(_build_ds, None, num_outputs)
+    self.run_core_tests(_build_ds, num_outputs)
 
   def testSparseCore(self):
 
@@ -133,8 +130,7 @@
       return dataset_ops.Dataset.range(num_outputs).map(_sparse)
 
     num_outputs = 10
-    self.run_core_tests(lambda: _build_ds(num_outputs),
-                        lambda: _build_ds(int(num_outputs / 2)), num_outputs)
+    self.run_core_tests(lambda: _build_ds(num_outputs), num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
index c026e97..94b5e1b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
@@ -55,7 +55,6 @@
 
     num_outputs = width * len(patterns)
     self.run_core_tests(lambda: self._build_iterator_graph(patterns),
-                        lambda: self._build_iterator_graph(patterns[0:1]),
                         num_outputs)
 
     shutil.rmtree(tmp_dir, ignore_errors=True)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
index aaa46ba..646f306 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -32,7 +32,7 @@
       return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
           batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
 
-    self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+    self.run_core_tests(lambda: build_dataset(200, 10), 20)
 
   def testWithNewFunction(self):
     """Tests that optimized datasets with new functions work."""
@@ -46,7 +46,7 @@
       dataset = dataset.apply(optimization.optimize(["map_vectorization"]))
       return dataset
 
-    self.run_core_tests(build_dataset, None, 20)
+    self.run_core_tests(build_dataset, 20)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
index 6f72b24..3988e64 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -36,10 +36,8 @@
           lambda x: array_ops.fill([x], x)).padded_batch(
               4, padded_shapes=[-1])
 
-    seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
-    seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
-    self.run_core_tests(lambda: build_dataset(seq_lens1),
-                        lambda: build_dataset(seq_lens2), 8)
+    seq_lens = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+    self.run_core_tests(lambda: build_dataset(seq_lens), 8)
 
   def testPaddedBatchNonDefaultPadding(self):
 
@@ -56,10 +54,8 @@
               padded_shapes=(padded_shape, padded_shape),
               padding_values=(-1, "<end>"))
 
-    seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
-    seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
-    self.run_core_tests(lambda: build_dataset(seq_lens1),
-                        lambda: build_dataset(seq_lens2), 8)
+    seq_lens = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+    self.run_core_tests(lambda: build_dataset(seq_lens), 8)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
index b8f38e8..c441ee7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -46,20 +46,18 @@
     # cycle_length > 1, block_length > 1
     cycle_length = 2
     block_length = 3
-    self.run_core_tests(
-        lambda: self._build_ds(cycle_length, block_length),
-        lambda: self._build_ds(cycle_length * 2, block_length * 1),
-        self.num_outputs)
+    self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+                        self.num_outputs)
     # cycle_length = 1
     cycle_length = 1
     block_length = 3
     self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
-                        None, self.num_outputs)
+                        self.num_outputs)
     # block_length = 1
     cycle_length = 2
     block_length = 1
     self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
-                        None, self.num_outputs)
+                        self.num_outputs)
 
   def testSerializationWithSloppy(self):
     break_points = self.gen_break_points(self.num_outputs, 10)
@@ -94,7 +92,7 @@
       return dataset_ops.Dataset.range(10).map(_map_fn).apply(
           interleave_ops.parallel_interleave(_interleave_fn, 1))
 
-    self.run_core_tests(_build_dataset, None, 20)
+    self.run_core_tests(_build_dataset, 20)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
index 4e4ed68..6ec012f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -63,10 +63,7 @@
 
   def testSaveRestoreCore(self):
     for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
-      self.run_core_tests(
-          ds_fn,
-          lambda: ds_fn(multiplier=15.0),  # pylint: disable=cell-var-from-loop
-          self._num_outputs)
+      self.run_core_tests(ds_fn, self._num_outputs)
 
   def testSaveStatefulFunction(self):
 
@@ -79,7 +76,7 @@
       return dataset_ops.Dataset.range(100).map(
           _map_fn, num_parallel_calls=2).prefetch(2)
 
-    self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)
 
   def testCaptureVariableInMapFn(self):
 
@@ -90,7 +87,7 @@
           lambda _: counter_var.assign_add(1),
           num_parallel_calls=2).prefetch(2))
 
-    self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)
 
   def testCaptureConstantInMapFn(self):
 
@@ -99,7 +96,7 @@
       return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
           lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
 
-    self.run_core_tests(_build_ds, None, 10)
+    self.run_core_tests(_build_ds, 10)
 
   def testCaptureDefunInMapFn(self):
     num_outputs = 100
@@ -113,7 +110,7 @@
       return dataset_ops.Dataset.range(num_outputs).map(
           defun_fn, num_parallel_calls=2).prefetch(2)
 
-    self.run_core_tests(_build_ds, None, num_outputs)
+    self.run_core_tests(_build_ds, num_outputs)
 
   def testBuildDefunInMapFn(self):
     num_outputs = 100
@@ -133,7 +130,7 @@
       return dataset_ops.Dataset.range(num_outputs).map(
           defun_fn, num_parallel_calls=2).prefetch(2)
 
-    self.run_core_tests(_build_ds, None, num_outputs)
+    self.run_core_tests(_build_ds, num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
index b3dfe21..6698fce 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -41,9 +41,7 @@
     # pylint: disable=g-long-lambda
     self.run_core_tests(
         lambda: self.ParseExampleDataset(
-            num_repeat=num_repeat, batch_size=batch_size),
-        lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
-        num_outputs)
+            num_repeat=num_repeat, batch_size=batch_size), num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
index 00d74c0..738d956 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -31,8 +31,7 @@
 
   def testCore(self):
     num_outputs = 100
-    self.run_core_tests(lambda: self.build_dataset(10),
-                        lambda: self.build_dataset(20), num_outputs)
+    self.run_core_tests(lambda: self.build_dataset(10), num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
index 34419a3..c06cd39 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -110,7 +110,6 @@
     stop = 10
     stop_1 = 8
     self.run_core_tests(lambda: self._build_range_dataset(start, stop),
-                        lambda: self._build_range_dataset(start, stop_1),
                         stop - start)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py
index a053d08..0ae2692 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py
@@ -32,9 +32,9 @@
       return distribute._RebatchDataset(
           dataset_ops.Dataset.range(num_elements).batch(
               4 * batch_size, drop_remainder=True),
-          num_workers=4)
+          num_replicas=4)
 
-    self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+    self.run_core_tests(lambda: build_dataset(200, 10), 20)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
index c23c1ec..f12267d 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -37,9 +37,7 @@
     return dataset.take(num_samples)
 
   def testSerializationCore(self):
-    self.run_core_tests(
-        lambda: self._build_dataset([0.5, 0.5], 100),
-        lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+    self.run_core_tests(lambda: self._build_dataset([0.5, 0.5], 100), 100)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
index 5f50160..33aa33c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -32,8 +32,7 @@
 
   def testScanCore(self):
     num_output = 5
-    self.run_core_tests(lambda: self._build_dataset(num_output),
-                        lambda: self._build_dataset(2), num_output)
+    self.run_core_tests(lambda: self._build_dataset(num_output), num_output)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
index fe99a3d..09c09aa 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -34,23 +34,21 @@
   def testSkipFewerThanInputs(self):
     count = 4
     num_outputs = 10 - count
-    self.run_core_tests(lambda: self._build_skip_dataset(count),
-                        lambda: self._build_skip_dataset(count + 2),
-                        num_outputs)
+    self.run_core_tests(lambda: self._build_skip_dataset(count), num_outputs)
 
   def testSkipVarious(self):
     # Skip more than inputs
-    self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
+    self.run_core_tests(lambda: self._build_skip_dataset(20), 0)
     # Skip exactly the input size
-    self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
-    self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
+    self.run_core_tests(lambda: self._build_skip_dataset(10), 0)
+    self.run_core_tests(lambda: self._build_skip_dataset(-1), 0)
     # Skip nothing
-    self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
+    self.run_core_tests(lambda: self._build_skip_dataset(0), 10)
 
   def testInvalidSkip(self):
     with self.assertRaisesRegexp(ValueError,
                                  'Shape must be rank 0 but is rank 1'):
-      self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
+      self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), 0)
 
 
 class TakeDatasetSerializationTest(
@@ -62,26 +60,22 @@
 
   def testTakeFewerThanInputs(self):
     count = 4
-    self.run_core_tests(
-        lambda: self._build_take_dataset(count),
-        lambda: self._build_take_dataset(count + 2),
-        count,
-    )
+    self.run_core_tests(lambda: self._build_take_dataset(count), count)
 
   def testTakeVarious(self):
     # Take more than inputs
-    self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
+    self.run_core_tests(lambda: self._build_take_dataset(20), 10)
     # Take exactly the input size
-    self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
+    self.run_core_tests(lambda: self._build_take_dataset(10), 10)
     # Take all
-    self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
+    self.run_core_tests(lambda: self._build_take_dataset(-1), 10)
     # Take nothing
-    self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
+    self.run_core_tests(lambda: self._build_take_dataset(0), 0)
 
   def testInvalidTake(self):
     with self.assertRaisesRegexp(ValueError,
                                  'Shape must be rank 0 but is rank 1'):
-      self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
+      self.run_core_tests(lambda: self._build_take_dataset([1, 2]), 0)
 
 
 class RepeatDatasetSerializationTest(
@@ -94,35 +88,26 @@
 
   def testFiniteRepeat(self):
     count = 10
-    self.run_core_tests(lambda: self._build_repeat_dataset(count),
-                        lambda: self._build_repeat_dataset(count + 2),
-                        3 * count)
+    self.run_core_tests(lambda: self._build_repeat_dataset(count), 3 * count)
 
   def testEmptyRepeat(self):
-    self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
+    self.run_core_tests(lambda: self._build_repeat_dataset(0), 0)
 
   def testInfiniteRepeat(self):
     self.verify_unused_iterator(
         lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
-    self.verify_init_before_restore(
-        lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
     self.verify_multiple_breaks(
         lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
     self.verify_reset_restored_iterator(
         lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
-    self.verify_restore_in_modified_graph(
-        lambda: self._build_repeat_dataset(-1),
-        lambda: self._build_repeat_dataset(2),
-        20,
-        verify_exhausted=False)
+
     # Test repeat empty dataset
-    self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
+    self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), 0)
 
   def testInvalidRepeat(self):
     with self.assertRaisesRegexp(
         ValueError, 'Shape must be rank 0 but is rank 1'):
-      self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0),
-                          None, 0)
+      self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), 0)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
index 2486db9..2cada3f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -72,6 +72,7 @@
       init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
                                                         num_outputs)
       with self.session(graph=g) as sess:
+        self.evaluate(init_ops)
         saver.restore(sess, self._ckpt_path())
         for _ in range(num_outputs - break_point):
           output = self.evaluate(get_next_ops)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py
index 99674b6..e180b10 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py
@@ -31,10 +31,9 @@
   def _build_dataset(self, num_elements, num_shards, index):
     return dataset_ops.Dataset.range(num_elements).shard(num_shards, index)
 
-  @parameterized.parameters((10, 5, 2, 3), (10, 10, 0, 9), (100, 2, 0, 1))
-  def testCore(self, elems, num_shards, index1, index2):
-    self.run_core_tests(lambda: self._build_dataset(elems, num_shards, index1),
-                        lambda: self._build_dataset(elems, num_shards, index2),
+  @parameterized.parameters((10, 5, 2), (10, 10, 0), (100, 2, 0))
+  def testCore(self, elems, num_shards, index):
+    self.run_core_tests(lambda: self._build_dataset(elems, num_shards, index),
                         elems // num_shards)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
index f847ac1..42f01b7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -31,8 +31,7 @@
         shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
 
   def testCore(self):
-    self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
-                        100)
+    self.run_core_tests(lambda: self._build_ds(10), 100)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
index e753a7a..8e05823 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -17,6 +17,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
+from absl.testing import parameterized
+
 from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
 from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -26,7 +30,8 @@
 
 
 class ShuffleDatasetSerializationTest(
-    dataset_serialization_test_base.DatasetSerializationTestBase):
+    dataset_serialization_test_base.DatasetSerializationTestBase,
+    parameterized.TestCase):
 
   def _build_shuffle_dataset(
       self,
@@ -36,113 +41,100 @@
       seed=None,
       reshuffle_each_iteration=None,
   ):
-    return dataset_ops.Dataset.range(range_limit).shuffle(
+    dataset = dataset_ops.Dataset.range(range_limit).shuffle(
         buffer_size,
         seed=seed,
         reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
+    # TODO(b/138399725): Re-enable default optimizations.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    return dataset.with_options(options)
 
-  def testShuffleCore(self):
-
+  @parameterized.parameters(itertools.product([True, False], [1, 3, 5, 8, 10]))
+  def testShuffleCore(self, reshuffle_each_iteration, buffer_size):
     seed = 55
     range_limit = 5
     num_repeats = 2
     num_outputs = range_limit * num_repeats
-    buffer_sizes = [1, 3, 5, 8, 10]
-    # pylint: disable=cell-var-from-loop
     # pylint: disable=g-long-lambda
-    for reshuffle_each_iteration in [True, False]:
-      for buffer_size in buffer_sizes:
-        self.run_core_tests(
-            lambda: self._build_shuffle_dataset(
-                range_limit=range_limit,
-                num_repeats=num_repeats,
-                buffer_size=buffer_size,
-                seed=seed,
-                reshuffle_each_iteration=reshuffle_each_iteration),
-            lambda: self._build_shuffle_dataset(
-                range_limit=range_limit,
-                num_repeats=num_repeats,
-                buffer_size=buffer_size,
-                seed=10,
-                reshuffle_each_iteration=reshuffle_each_iteration),
-            num_outputs)
-    # pylint: enable=cell-var-from-loop
-    # pylint: enable=g-long-lambda
+    self.run_core_tests(
+        lambda: self._build_shuffle_dataset(
+            range_limit=range_limit,
+            num_repeats=num_repeats,
+            buffer_size=buffer_size,
+            seed=seed,
+            reshuffle_each_iteration=reshuffle_each_iteration), num_outputs)
 
-  def testNonDeterministicSeeding(self):
-
+  # TODO(b/133780904): Re-enable this test once randomness state is hoisted out
+  # of the input pipeline.
+  @parameterized.parameters(itertools.product([True, False], [1, 3, 5, 8, 10]))
+  def _testNonDeterministicSeeding(self, reshuffle_each_iteration, buffer_size):
     range_limit = 5
     num_repeats = 2
     num_outputs = range_limit * num_repeats
-    buffer_sizes = [1, 3, 5, 8, 10]
-    for reshuffle_each_iteration in [True, False]:
-      for buffer_size in buffer_sizes:
 
-        def ds_fn():
-          # pylint: disable=cell-var-from-loop
-          return self._build_shuffle_dataset(
-              range_limit=range_limit,
-              num_repeats=num_repeats,
-              buffer_size=buffer_size,
-              seed=None,  # Iterator seeds are generated non-deterministically.
-              reshuffle_each_iteration=reshuffle_each_iteration)
-          # pylint: enable=cell-var-from-loop
+    def ds_fn():
+      # pylint: disable=cell-var-from-loop
+      return self._build_shuffle_dataset(
+          range_limit=range_limit,
+          num_repeats=num_repeats,
+          buffer_size=buffer_size,
+          seed=None,  # Iterator seeds are generated non-deterministically.
+          reshuffle_each_iteration=reshuffle_each_iteration)
+      # pylint: enable=cell-var-from-loop
 
-        # We checkpoint the initial state of the Dataset so that we can restore
-        # the seeds in the next run. Since the seeding is non-deterministic
-        # the dataset gets initialized with different seeds each time.
-        expected = self.gen_outputs(
-            ds_fn,
-            break_points=[0],
-            num_outputs=num_outputs,
-            ckpt_saved=False,
-            verify_exhausted=False,
-            save_checkpoint_at_end=False)
-        actual = self.gen_outputs(
-            ds_fn,
-            break_points=self.gen_break_points(num_outputs),
-            num_outputs=num_outputs,
-            ckpt_saved=True,
-            verify_exhausted=False)
+    # We checkpoint the initial state of the Dataset so that we can restore
+    # the seeds in the next run. Since the seeding is non-deterministic
+    # the dataset gets initialized with different seeds each time.
+    expected = self.gen_outputs(
+        ds_fn,
+        break_points=[0],
+        num_outputs=num_outputs,
+        ckpt_saved=False,
+        verify_exhausted=False,
+        save_checkpoint_at_end=False)
+    actual = self.gen_outputs(
+        ds_fn,
+        break_points=self.gen_break_points(num_outputs),
+        num_outputs=num_outputs,
+        ckpt_saved=True,
+        verify_exhausted=False)
+    self.match(expected, actual)
+
+  @parameterized.parameters(itertools.product([True, False], [1, 3, 5, 8, 10]))
+  def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
+    range_limit = 5
+    num_repeats = 2
+    num_outputs = range_limit * num_repeats
+
+    def ds_fn():
+      # pylint: disable=cell-var-from-loop
+      return self._build_shuffle_dataset(
+          range_limit=range_limit,
+          num_repeats=num_repeats,
+          buffer_size=buffer_size,
+          seed=None,  # Iterator seeds are generated non-deterministically.
+          reshuffle_each_iteration=reshuffle_each_iteration)
+      # pylint: enable=cell-var-from-loop
+
+    with ops.Graph().as_default() as g:
+      ds = ds_fn()
+      iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
+      get_next_ops = [it.get_next() for it in iterators]
+      saveables = [
+          contrib_iterator_ops.make_saveable_from_iterator(it)
+          for it in iterators
+      ]
+      for saveable in saveables:
+        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+      saver = saver_lib.Saver(allow_empty=True)
+      with self.session(graph=g) as sess:
+        self._save(sess, saver)
+        expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
+        self._restore(saver, sess)
+        actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
         self.match(expected, actual)
 
-  def testMultipleIterators(self):
-    range_limit = 5
-    num_repeats = 2
-    num_outputs = range_limit * num_repeats
-    buffer_sizes = [1, 3, 5, 8, 10]
-
-    for reshuffle_each_iteration in [True, False]:
-      for buffer_size in buffer_sizes:
-
-        def ds_fn():
-          # pylint: disable=cell-var-from-loop
-          return self._build_shuffle_dataset(
-              range_limit=range_limit,
-              num_repeats=num_repeats,
-              buffer_size=buffer_size,
-              seed=None,  # Iterator seeds are generated non-deterministically.
-              reshuffle_each_iteration=reshuffle_each_iteration)
-          # pylint: enable=cell-var-from-loop
-
-        with ops.Graph().as_default() as g:
-          ds = ds_fn()
-          iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
-          get_next_ops = [it.get_next() for it in iterators]
-          saveables = [
-              contrib_iterator_ops.make_saveable_from_iterator(it)
-              for it in iterators
-          ]
-          for saveable in saveables:
-            ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
-          saver = saver_lib.Saver(allow_empty=True)
-          with self.session(graph=g) as sess:
-            self._save(sess, saver)
-            expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
-            self._restore(saver, sess)
-            actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
-            self.match(expected, actual)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
index 006279b..e3a44a4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -44,9 +44,7 @@
   def testSQLSaveable(self):
     num_repeats = 4
     num_outputs = num_repeats * 2
-    self.run_core_tests(lambda: self._build_dataset(num_repeats),
-                        lambda: self._build_dataset(num_repeats // 2),
-                        num_outputs)
+    self.run_core_tests(lambda: self._build_dataset(num_repeats), num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
index 9372eef..66d4236 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -44,15 +44,13 @@
       # pylint: disable=g-long-lambda
       self.run_core_tests(
           lambda: dataset_ops.Dataset.range(100).apply(
-              stats_ops.bytes_produced_stats(["bytes_produced"])),
-          None, 100)
+              stats_ops.bytes_produced_stats(["bytes_produced"])), 100)
       # pylint: enable=g-long-lambda
 
   def testBytesStatsDatasetSaveableCore(self):
     num_outputs = 100
-    self.run_core_tests(
-        lambda: self._build_dataset_bytes_stats(num_outputs),
-        lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+    self.run_core_tests(lambda: self._build_dataset_bytes_stats(num_outputs),
+                        num_outputs)
 
   def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
     return dataset_ops.Dataset.range(num_elements).apply(
@@ -72,25 +70,23 @@
       self.run_core_tests(
           lambda: dataset_ops.Dataset.range(100).apply(
               stats_ops.latency_stats(["record_latency", "record_latency_2"])),
-          None, 100)
+          100)
       # pylint: enable=g-long-lambda
 
   def testLatencyStatsDatasetSaveableCore(self):
     num_outputs = 100
 
-    self.run_core_tests(
-        lambda: self._build_dataset_latency_stats(num_outputs),
-        lambda: self._build_dataset_latency_stats(num_outputs // 10),
-        num_outputs)
+    self.run_core_tests(lambda: self._build_dataset_latency_stats(num_outputs),
+                        num_outputs)
 
     self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
-                        None, num_outputs)
+                        num_outputs)
 
     tag1 = "record_latency"
     tag2 = "record_latency"
     self.run_core_tests(
         lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
-        None, num_outputs)
+        num_outputs)
 
   def _build_dataset_stats_aggregator(self):
     aggregator = stats_aggregator.StatsAggregator()
@@ -100,7 +96,7 @@
   def test_set_stats_aggregator_not_support_checkpointing(self):
     with self.assertRaisesRegexp(errors.UnimplementedError,
                                  "does not support checkpointing"):
-      self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+      self.run_core_tests(self._build_dataset_stats_aggregator, 10)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py
index 47899ea..67a27ac 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py
@@ -33,10 +33,9 @@
     return dataset_ops.Dataset.range(num_elements).apply(
         take_while_ops.take_while(lambda x: x < upper_bound))
 
-  @parameterized.parameters((23, 10, 7), (10, 50, 0), (25, 30, 25))
-  def testCore(self, num_elem1, num_elem2, upper_bound):
-    self.run_core_tests(lambda: self._build_dataset(num_elem1, upper_bound),
-                        lambda: self._build_dataset(num_elem2, upper_bound),
+  @parameterized.parameters((23, 7), (10, 0), (25, 25))
+  def testCore(self, num_elem, upper_bound):
+    self.run_core_tests(lambda: self._build_dataset(num_elem, upper_bound),
                         upper_bound)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
index c87a744..97827c8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -45,7 +45,7 @@
       # pylint: disable=cell-var-from-loop
       self.run_core_tests(
           lambda: self._build_iterator_graph(test_filenames, compression_type),
-          lambda: self._build_iterator_graph(test_filenames), num_outputs)
+          num_outputs)
       # pylint: enable=cell-var-from-loop
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
index f0dcc13..92cd8e0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -70,10 +70,9 @@
     self.run_core_tests(
         lambda: self._build_iterator_graph(num_epochs, batch_size,
                                            buffer_size=0),
-        lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
         num_outputs)
     self.run_core_tests(
-        lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
+        lambda: self._build_iterator_graph(num_epochs, buffer_size=0),
         num_outputs * batch_size)
     # pylint: enable=g-long-lambda
 
@@ -81,7 +80,6 @@
     num_epochs = 5
     num_outputs = num_epochs * self._num_files * self._num_records
     self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
-                        lambda: self._build_iterator_graph(num_epochs * 2),
                         num_outputs)
 
   def testTFRecordWithCompressionCore(self):
@@ -89,10 +87,10 @@
     num_outputs = num_epochs * self._num_files * self._num_records
     self.run_core_tests(
         lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
-        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+        num_outputs)
     self.run_core_tests(
         lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
-        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+        num_outputs)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
index 528598d..e900c56 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -43,7 +43,6 @@
     num_outputs = tensor_slice_len
     self.run_core_tests(
         lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
-        lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
         num_outputs)
 
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
index e2862af..278fd85 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -32,8 +32,7 @@
       return dataset_ops.Dataset.range(num_elements).map(
           lambda x: x % unique_elem_range).apply(unique.unique())
 
-    self.run_core_tests(lambda: build_dataset(200, 100),
-                        lambda: build_dataset(40, 100), 100)
+    self.run_core_tests(lambda: build_dataset(200, 100), 100)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
index 4ea6131..b26691f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -43,11 +43,10 @@
     # Equal length components
     arr = [37.0, 38.0, 39.0, 40.0]
     num_outputs = len(arr)
-    self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
+    self.run_core_tests(lambda: self._build_dataset(arr), num_outputs)
     # Variable length components
     diff_size_arr = [1.0, 2.0]
-    self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
-                        lambda: self._build_dataset(arr), 2)
+    self.run_core_tests(lambda: self._build_dataset(diff_size_arr), 2)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
index 313dabf..0f43f4d 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
@@ -21,18 +21,17 @@
 import time
 from absl.testing import parameterized
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
 from tensorflow.python.data.experimental.ops import snapshot
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import combinations
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                           parameterized.TestCase):
 
@@ -78,6 +77,7 @@
           self.assertEqual(filename, "%08d.snapshot" % file_counter)
           file_counter += 1
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWriteDifferentPipelinesInOneDirectory(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -91,6 +91,7 @@
 
     self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotMultipleSimultaneous(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -110,6 +111,7 @@
     # one that lost the race would be in passthrough mode.
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testGetNextCreatesDir(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -128,8 +130,12 @@
     # We check that only one directory is created.
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
-  @parameterized.parameters(snapshot.COMPRESSION_NONE,
-                            snapshot.COMPRESSION_GZIP)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              compression=[snapshot.COMPRESSION_NONE,
+                           snapshot.COMPRESSION_GZIP])))
   def testWriteSnapshotSimpleSuccessful(self, compression):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -139,6 +145,7 @@
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotRepeatAfterwards(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -149,8 +156,12 @@
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
-  @parameterized.parameters(snapshot.COMPRESSION_NONE,
-                            snapshot.COMPRESSION_GZIP)
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              compression=[snapshot.COMPRESSION_NONE,
+                           snapshot.COMPRESSION_GZIP])))
   def testReadSnapshotBackAfterWrite(self, compression):
     self.setUpTFRecord()
     filenames = self.test_filenames
@@ -174,123 +185,80 @@
         tmpdir, compression=compression))
     self.assertDatasetProduces(dataset2, expected)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadSnapshotParallelAfterWrite(self):
-    with compat.forward_compatibility_horizon(2019, 8, 16):
-      self.setUpTFRecord(10, 4000)
-      filenames = self.test_filenames
+    self.setUpTFRecord(10, 4000)
+    filenames = self.test_filenames
 
-      expected = [
-          b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-          for f in range(0, 10)
-          for r in range(0, 4000)
-      ]
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in range(0, 10)
+        for r in range(0, 4000)
+    ]
 
-      tmpdir = self.makeSnapshotDirectory()
-      dataset = core_readers._TFRecordDataset(filenames)
-      dataset = dataset.apply(
-          snapshot.snapshot(
-              tmpdir,
-              shard_size_bytes=1024 * 1024,
-              num_reader_threads=2,
-              reader_buffer_size=10))
-      self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
+    tmpdir = self.makeSnapshotDirectory()
+    dataset = core_readers._TFRecordDataset(filenames)
+    dataset = dataset.apply(
+        snapshot.snapshot(
+            tmpdir,
+            shard_size_bytes=1024 * 1024,
+            num_reader_threads=2,
+            reader_buffer_size=10))
+    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
 
-      # remove the original files and try to read the data back only from
-      # snapshot.
-      self.removeTFRecords()
+    # remove the original files and try to read the data back only from
+    # snapshot.
+    self.removeTFRecords()
 
-      dataset2 = core_readers._TFRecordDataset(filenames)
-      dataset2 = dataset2.apply(
-          snapshot.snapshot(
-              tmpdir,
-              shard_size_bytes=1024 * 1024,
-              num_reader_threads=2,
-              reader_buffer_size=10))
-      self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
+    dataset2 = core_readers._TFRecordDataset(filenames)
+    dataset2 = dataset2.apply(
+        snapshot.snapshot(
+            tmpdir,
+            shard_size_bytes=1024 * 1024,
+            num_reader_threads=2,
+            reader_buffer_size=10))
+    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
 
-  @parameterized.parameters(
-      {
-          "compression": snapshot.COMPRESSION_NONE,
-          "threads": 2,
-          "size": 1
-      },
-      {
-          "compression": snapshot.COMPRESSION_GZIP,
-          "threads": 2,
-          "size": 1
-      },
-      {
-          "compression": snapshot.COMPRESSION_NONE,
-          "threads": 2,
-          "size": 2
-      },
-      {
-          "compression": snapshot.COMPRESSION_GZIP,
-          "threads": 2,
-          "size": 2
-      },
-      {
-          "compression": snapshot.COMPRESSION_NONE,
-          "threads": 8,
-          "size": 1
-      },
-      {
-          "compression": snapshot.COMPRESSION_GZIP,
-          "threads": 8,
-          "size": 1
-      },
-      {
-          "compression": snapshot.COMPRESSION_NONE,
-          "threads": 8,
-          "size": 4
-      },
-      {
-          "compression": snapshot.COMPRESSION_GZIP,
-          "threads": 8,
-          "size": 4
-      },
-      {
-          "compression": snapshot.COMPRESSION_NONE,
-          "threads": 8,
-          "size": 8
-      },
-      {
-          "compression": snapshot.COMPRESSION_GZIP,
-          "threads": 8,
-          "size": 8
-      },
-  )
-  def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
-                                                  size):
-    with compat.forward_compatibility_horizon(2019, 8, 16):
-      self.setUpTFRecord()
-      filenames = self.test_filenames
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.times(
+              combinations.combine(
+                  compression=[snapshot.COMPRESSION_NONE,
+                               snapshot.COMPRESSION_GZIP]),
+              combinations.combine(threads=2, size=[1, 2]) +
+              combinations.combine(threads=8, size=[1, 4, 8]))))
+  def testReadSnapshotBackAfterMultiThreadedWrite(
+      self, compression, threads, size):
+    self.setUpTFRecord()
+    filenames = self.test_filenames
 
-      expected = [
-          b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-          for f in range(0, 10)
-          for r in range(0, 10)
-      ]
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in range(0, 10)
+        for r in range(0, 10)
+    ]
 
-      tmpdir = self.makeSnapshotDirectory()
-      dataset = core_readers._TFRecordDataset(filenames)
-      dataset = dataset.apply(
-          snapshot.snapshot(
-              tmpdir,
-              compression=compression,
-              num_writer_threads=threads,
-              writer_buffer_size=size))
-      self.assertDatasetProduces(dataset, expected)
+    tmpdir = self.makeSnapshotDirectory()
+    dataset = core_readers._TFRecordDataset(filenames)
+    dataset = dataset.apply(
+        snapshot.snapshot(
+            tmpdir,
+            compression=compression,
+            num_writer_threads=threads,
+            writer_buffer_size=size))
+    self.assertDatasetProduces(dataset, expected)
 
-      # remove the original files and try to read the data back only from
-      # snapshot
-      self.removeTFRecords()
+    # remove the original files and try to read the data back only from
+    # snapshot
+    self.removeTFRecords()
 
-      dataset2 = core_readers._TFRecordDataset(filenames)
-      dataset2 = dataset2.apply(
-          snapshot.snapshot(tmpdir, compression=compression))
-      self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
+    dataset2 = core_readers._TFRecordDataset(filenames)
+    dataset2 = dataset2.apply(
+        snapshot.snapshot(tmpdir, compression=compression))
+    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSameFingerprintWithDifferentInitializationOrder(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -312,6 +280,7 @@
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testExpiredSnapshotRewrite(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -340,6 +309,7 @@
       self.evaluate(next2())
     self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSpecifyShardSize(self):
     tmpdir = self.makeSnapshotDirectory()
 
@@ -355,6 +325,7 @@
 
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 4)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testAdditionalOperationsAfterReadBack(self):
     self.setUpTFRecord()
     filenames = self.test_filenames
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index ede4f8a..4f04a0a 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -362,7 +362,7 @@
 
     num_output = 100 // 16 + 1
     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalMapAndBatchDataset"},
+        dataset_fn, {"MapAndBatchDataset"},
         num_output,
         check_elements=False,
         function_processing_time=True)
@@ -391,7 +391,7 @@
       num_output = total_records // batch_size + 1
 
     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalParseExampleDataset"},
+        dataset_fn, {"ParseExampleDataset"},
         num_output,
         check_elements=False)
 
@@ -409,19 +409,19 @@
     handle = self.getHandle(aggregator)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "features_count"), total_records)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "feature_values_count"), total_records)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "features_count"), total_records * 4)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "feature_values_count"),
         self._sum_keywords(1) * num_epochs + 3 * total_records)
 
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
index 5dc2c1c..d028d35 100644
--- a/tensorflow/python/data/experimental/ops/batching.py
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.data.util import nest
@@ -247,18 +246,11 @@
         tensor_shape.TensorShape([None]).concatenate(self._row_shape),
         dataset_ops.get_legacy_output_types(input_dataset))
 
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
+    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._batch_size,
+        row_shape=convert.partial_shape_to_tensor(self._row_shape),
+        **self._flat_structure)
     super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                      variant_tensor)
 
@@ -302,26 +294,15 @@
           lambda component_spec: component_spec._batch(None),
           self._map_func.output_structure)
     # pylint: enable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          f=self._map_func.function,
-          batch_size=self._batch_size_t,
-          num_parallel_calls=self._num_parallel_calls_t,
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          f=self._map_func.function,
-          batch_size=self._batch_size_t,
-          num_parallel_calls=self._num_parallel_calls_t,
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
+    variant_tensor = ged_ops.map_and_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
+        f=self._map_func.function,
+        batch_size=self._batch_size_t,
+        num_parallel_calls=self._num_parallel_calls_t,
+        drop_remainder=self._drop_remainder_t,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
 
   def _functions(self):
diff --git a/tensorflow/python/data/experimental/ops/cardinality.py b/tensorflow/python/data/experimental/ops/cardinality.py
index d7f4764..db4bb8f 100644
--- a/tensorflow/python/data/experimental/ops/cardinality.py
+++ b/tensorflow/python/data/experimental/ops/cardinality.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.util.tf_export import tf_export
 
@@ -49,7 +48,4 @@
     constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
   """
 
-  if compat.forward_compatible(2019, 8, 3):
-    return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
-  else:
-    return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
+  return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py
index b834fe8..9bbd3ef 100644
--- a/tensorflow/python/data/experimental/ops/distribute.py
+++ b/tensorflow/python/data/experimental/ops/distribute.py
@@ -49,18 +49,11 @@
     self._input_dataset = input_dataset
 
     self._element_spec = input_dataset.element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
+    variant_tensor = ged_ops.auto_shard_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        num_workers=num_workers,
+        index=index,
+        **self._flat_structure)
     super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
@@ -74,19 +67,29 @@
 
 
 class _RebatchDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` that divides the batch size by `num_workers`."""
+  """A `Dataset` that divides the batch size by `num_replicas`.
 
-  def __init__(self, input_dataset, num_workers, use_fallback=True):
+  For each batch in the input dataset, the resulting dataset will produce
+  `num_replicas` minibatches whose sizes add up to the original batch size.
+  """
+
+  def __init__(self, input_dataset, num_replicas, use_fallback=True):
     self._input_dataset = input_dataset
 
     def recalculate_output_shapes(output_shapes):
-      """Recalculates the output_shapes after dividing it by num_workers."""
+      """Recalculates the output_shapes after dividing it by num_replicas."""
       if len(output_shapes) < 1:
         raise ValueError(
             "Input shape should have at least one dimension. "
             "Perhaps your input dataset is not batched?")
-      output_dims = [d for d in output_shapes.dims]
-      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
+      output_dims = [d.value for d in output_shapes.dims]
+
+      if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
+        output_dims[0] = output_dims[0] // num_replicas
+      else:
+        # Set the batch dimension to unknown. If the global batch size does not
+        # divide num_replicas evenly, the minibatches may have different sizes.
+        output_dims[0] = None
       return tensor_shape.TensorShape(output_dims)
 
     input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
@@ -99,18 +102,13 @@
     if compat.forward_compatible(2019, 8, 13) or not use_fallback:
       variant_tensor = ged_ops.rebatch_dataset(
           self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
+          num_replicas=num_replicas,
           use_fallback=use_fallback,
           **self._flat_structure)
-    elif compat.forward_compatible(2019, 8, 3):
+    else:
       variant_tensor = ged_ops.rebatch_dataset(
           self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_rebatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
+          num_replicas=num_replicas,
           **self._flat_structure)
     super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
 
diff --git a/tensorflow/python/data/experimental/ops/distribute_options.py b/tensorflow/python/data/experimental/ops/distribute_options.py
index 3c5b4a6..33f9420 100644
--- a/tensorflow/python/data/experimental/ops/distribute_options.py
+++ b/tensorflow/python/data/experimental/ops/distribute_options.py
@@ -47,6 +47,14 @@
       "option does nothing. If None, defaults to True.",
       default_factory=lambda: True)
 
+  _make_stateless = options.create_option(
+      name="_make_stateless",
+      ty=bool,
+      docstring=
+      "Determines whether the input pipeline should be rewritten to not "
+      "contain stateful transformations (so that its graph can be moved "
+      "between devices).")
+
   num_devices = options.create_option(
       name="num_devices",
       ty=int,
diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py
index 1aa2ad7..23937bb 100644
--- a/tensorflow/python/data/experimental/ops/error_ops.py
+++ b/tensorflow/python/data/experimental/ops/error_ops.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -60,14 +59,8 @@
   def __init__(self, input_dataset):
     """See `Dataset.ignore_errors()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.ignore_errors_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
index 3cbe784..e48ffbc 100644
--- a/tensorflow/python/data/experimental/ops/grouping.py
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -19,7 +19,6 @@
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -255,30 +254,17 @@
     self._make_init_func(reducer.init_func)
     self._make_reduce_func(reducer.reduce_func, input_dataset)
     self._make_finalize_func(reducer.finalize_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._init_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._finalize_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          init_func=self._init_func.function,
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._init_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._finalize_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          init_func=self._init_func.function,
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
+    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
+        self._init_func.function.captured_inputs,
+        self._reduce_func.function.captured_inputs,
+        self._finalize_func.function.captured_inputs,
+        key_func=self._key_func.function,
+        init_func=self._init_func.function,
+        reduce_func=self._reduce_func.function,
+        finalize_func=self._finalize_func.function,
+        **self._flat_structure)
     super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
 
   def _make_key_func(self, key_func, input_dataset):
@@ -390,26 +376,15 @@
     self._make_key_func(key_func, input_dataset)
     self._make_reduce_func(reduce_func, input_dataset)
     self._make_window_size_func(window_size_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._window_size_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._window_size_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
+    variant_tensor = ged_ops.group_by_window_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
+        self._reduce_func.function.captured_inputs,
+        self._window_size_func.function.captured_inputs,
+        key_func=self._key_func.function,
+        reduce_func=self._reduce_func.function,
+        window_size_func=self._window_size_func.function,
+        **self._flat_structure)
     super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
 
   def _make_window_size_func(self, window_size_func):
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
index 9abf8fb..07351b8 100644
--- a/tensorflow/python/data/experimental/ops/interleave_ops.py
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -17,87 +17,22 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import random_ops
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.ops import readers
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.ops import gen_stateless_random_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
-class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` that maps a function over its input and flattens the result."""
-
-  def __init__(self, input_dataset, map_func, cycle_length, block_length,
-               sloppy, buffer_output_elements, prefetch_input_elements):
-    """See `tf.data.experimental.parallel_interleave()` for details."""
-    self._input_dataset = input_dataset
-    self._map_func = dataset_ops.StructuredFunctionWrapper(
-        map_func, self._transformation_name(), dataset=input_dataset)
-    if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
-      raise TypeError("`map_func` must return a `Dataset` object.")
-    self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
-    self._cycle_length = ops.convert_to_tensor(
-        cycle_length, dtype=dtypes.int64, name="cycle_length")
-    self._block_length = ops.convert_to_tensor(
-        block_length, dtype=dtypes.int64, name="block_length")
-    self._sloppy = ops.convert_to_tensor(
-        sloppy, dtype=dtypes.bool, name="sloppy")
-    self._buffer_output_elements = convert.optional_param_to_tensor(
-        "buffer_output_elements",
-        buffer_output_elements,
-        argument_default=2 * block_length)
-    self._prefetch_input_elements = convert.optional_param_to_tensor(
-        "prefetch_input_elements",
-        prefetch_input_elements,
-        argument_default=2 * cycle_length)
-    # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    # pylint: enable=protected-access
-    super(_ParallelInterleaveDataset, self).__init__(input_dataset,
-                                                     variant_tensor)
-
-  def _functions(self):
-    return [self._map_func]
-
-  @property
-  def element_spec(self):
-    return self._element_spec
-
-  def _transformation_name(self):
-    return "tf.data.experimental.parallel_interleave()"
-
-
 @deprecation.deprecated(
     None,
     "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
@@ -154,7 +89,7 @@
     `tf.data.Dataset.apply`.
   """
   def _apply_fn(dataset):
-    return _ParallelInterleaveDataset(
+    return readers.ParallelInterleaveDataset(
         dataset, map_func, cycle_length, block_length, sloppy,
         buffer_output_elements, prefetch_input_elements)
 
@@ -191,19 +126,11 @@
 
   def _as_variant_tensor(self):
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      return (
-          ged_ops.directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    else:
-      return (
-          ged_ops.experimental_directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    # pylint: enable=protected-access
+    return (
+        gen_experimental_dataset_ops.directed_interleave_dataset(
+            self._selector_input._variant_tensor,
+            [data_input._variant_tensor for data_input in self._data_inputs],
+            **self._flat_structure))
 
   def _inputs(self):
     return [self._selector_input] + self._data_inputs
@@ -358,4 +285,3 @@
 # these aliases in place.
 choose_from_datasets = choose_from_datasets_v1
 sample_from_datasets = sample_from_datasets_v1
-
diff --git a/tensorflow/python/data/experimental/ops/matching_files.py b/tensorflow/python/data/experimental/ops/matching_files.py
index 59b477b..5bb0142 100644
--- a/tensorflow/python/data/experimental/ops/matching_files.py
+++ b/tensorflow/python/data/experimental/ops/matching_files.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -32,11 +31,7 @@
   def __init__(self, patterns):
     self._patterns = ops.convert_to_tensor(
         patterns, dtype=dtypes.string, name="patterns")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.matching_files_dataset(self._patterns)
-    else:
-      variant_tensor = ged_ops.experimental_matching_files_dataset(
-          self._patterns)
+    variant_tensor = ged_ops.matching_files_dataset(self._patterns)
     super(MatchingFilesDataset, self).__init__(variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
index 23c381e..a5f71d3 100644
--- a/tensorflow/python/data/experimental/ops/optimization.py
+++ b/tensorflow/python/data/experimental/ops/optimization.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -105,18 +104,11 @@
       raise ValueError("At least one transformation should be specified")
     self._transformations = ops.convert_to_tensor(
         transformations, dtype=dtypes.string, name="transformations")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.assert_next_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._transformations,
+            **self._flat_structure))
     super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
 
 
@@ -126,16 +118,10 @@
   def __init__(self, input_dataset):
     """See `non_serializable()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.non_serializable_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
 
 
@@ -171,18 +157,11 @@
     """
     self._datasets = list(datasets)
     self._element_spec = self._datasets[0].element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.choose_fastest_dataset(
+            [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
+            num_experiments=num_experiments,
+            **self._flat_structure))
     super(_ChooseFastestDataset, self).__init__(variant_tensor)
 
   def _inputs(self):
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
index 3dad40a..2f74eba 100644
--- a/tensorflow/python/data/experimental/ops/parsing_ops.py
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
@@ -81,28 +80,16 @@
     self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
 
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.parse_example_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._num_parallel_calls,
+            self._dense_defaults,
+            self._sparse_keys,
+            self._dense_keys,
+            self._sparse_types,
+            self._dense_shapes,
+            **self._flat_structure))
     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
index 873fe23..e4bd782 100644
--- a/tensorflow/python/data/experimental/ops/random_ops.py
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -19,7 +19,6 @@
 
 import functools
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import random_seed
 from tensorflow.python.framework import dtypes
@@ -35,12 +34,8 @@
   def __init__(self, seed=None):
     """A `Dataset` of pseudorandom values."""
     self._seed, self._seed2 = random_seed.get_seed(seed)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.random_dataset(
+        seed=self._seed, seed2=self._seed2, **self._flat_structure)
     super(RandomDatasetV2, self).__init__(variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py
index 91ebb52..f634d06 100644
--- a/tensorflow/python/data/experimental/ops/readers.py
+++ b/tensorflow/python/data/experimental/ops/readers.py
@@ -20,12 +20,13 @@
 import collections
 import csv
 import functools
+import gzip
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
 from tensorflow.python.data.experimental.ops import parsing_ops
 from tensorflow.python.data.experimental.ops import shuffle_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -36,6 +37,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.ops import io_ops
@@ -107,10 +109,11 @@
       return type_list[i]
 
 
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
+                  file_io_fn):
   """Generator that yields rows of CSV file(s) in order."""
   for fn in filenames:
-    with file_io.FileIO(fn, "r") as f:
+    with file_io_fn(fn) as f:
       rdr = csv.reader(
           f,
           delimiter=field_delim,
@@ -128,14 +131,15 @@
 
 def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
                            na_value, header, num_rows_for_inference,
-                           select_columns):
+                           select_columns, file_io_fn):
   """Infers column types from the first N valid CSV records of files."""
   if select_columns is None:
     select_columns = range(num_cols)
   inferred_types = [None] * len(select_columns)
 
   for i, csv_row in enumerate(
-      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
+      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
+                    file_io_fn)):
     if num_rows_for_inference is not None and i >= num_rows_for_inference:
       break
 
@@ -152,13 +156,13 @@
   ]
 
 
-def _infer_column_names(filenames, field_delim, use_quote_delim):
+def _infer_column_names(filenames, field_delim, use_quote_delim, file_io_fn):
   """Infers column names from first rows of files."""
   csv_kwargs = {
       "delimiter": field_delim,
       "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
   }
-  with file_io.FileIO(filenames[0], "r") as f:
+  with file_io_fn(filenames[0]) as f:
     try:
       column_names = next(csv.reader(f, **csv_kwargs))
     except StopIteration:
@@ -166,7 +170,7 @@
                         "of %s.  Empty file?") % filenames[0])
 
   for name in filenames[1:]:
-    with file_io.FileIO(name, "r") as f:
+    with file_io_fn(name) as f:
       try:
         if next(csv.reader(f, **csv_kwargs)) != column_names:
           raise ValueError(
@@ -425,12 +429,28 @@
     dataset = dataset.shuffle(len(filenames), shuffle_seed)
 
   # Clean arguments; figure out column names and defaults
-
+  if column_names is None or column_defaults is None:
+    # Find out which io function to open the file
+    file_io_fn = lambda filename: file_io.FileIO(filename, "r")
+    if compression_type is not None:
+      compression_type_value = tensor_util.constant_value(compression_type)
+      if compression_type_value is None:
+        raise ValueError("Received unkown compression_type")
+      if compression_type_value == "GZIP":
+        file_io_fn = lambda filename: gzip.open(filename, "rt")
+      elif compression_type_value == "ZLIB":
+        raise ValueError(
+            "compression_type (%s) is not supported for probing columns" %
+            compression_type)
+      elif compression_type_value != "":
+        raise ValueError("compression_type (%s) is not supported" %
+                         compression_type)
   if column_names is None:
     if not header:
       raise ValueError("Cannot infer column names without a header line.")
     # If column names are not provided, infer from the header lines
-    column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
+    column_names = _infer_column_names(filenames, field_delim, use_quote_delim,
+                                       file_io_fn)
   if len(column_names) != len(set(column_names)):
     raise ValueError("Cannot have duplicate column names.")
 
@@ -439,15 +459,18 @@
 
   if column_defaults is not None:
     column_defaults = [
-        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+        constant_op.constant([], dtype=x)
+        if not tensor_util.is_tensor(x) and x in _ACCEPTABLE_CSV_TYPES else x
         for x in column_defaults
     ]
   else:
     # If column defaults are not provided, infer from records at graph
     # construction time
-    column_defaults = _infer_column_defaults(
-        filenames, len(column_names), field_delim, use_quote_delim, na_value,
-        header, num_rows_for_inference, select_columns)
+    column_defaults = _infer_column_defaults(filenames, len(column_names),
+                                             field_delim, use_quote_delim,
+                                             na_value, header,
+                                             num_rows_for_inference,
+                                             select_columns, file_io_fn)
 
   if select_columns is not None and len(column_defaults) != len(select_columns):
     raise ValueError(
@@ -493,18 +516,9 @@
     return features
 
   # Read files sequentially (if num_parallel_reads=1) or in parallel
-  cycle_length = num_parallel_reads
-  if num_parallel_reads == dataset_ops.AUTOTUNE:
-    cycle_length = core_readers.DEFAULT_CYCLE_LENGTH
-  dataset = dataset.interleave(
-      filename_to_dataset,
-      cycle_length,
-      num_parallel_calls=num_parallel_reads)
-
-  if sloppy:
-    options = dataset_ops.Options()
-    options.experimental_deterministic = False
-    dataset = dataset.with_options(options)
+  dataset = dataset.apply(
+      interleave_ops.parallel_interleave(
+          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
 
   dataset = _maybe_shuffle_and_repeat(
       dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
@@ -650,7 +664,8 @@
         argument_default="",
         argument_dtype=dtypes.string)
     record_defaults = [
-        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+        constant_op.constant([], dtype=x)
+        if not tensor_util.is_tensor(x) and x in _ACCEPTABLE_CSV_TYPES else x
         for x in record_defaults
     ]
     self._record_defaults = ops.convert_n_to_tensor(
@@ -673,30 +688,17 @@
     )
     self._element_spec = tuple(
         tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
-          buffer_size=self._buffer_size,
-          header=self._header,
-          output_shapes=self._flat_shapes,
-          field_delim=self._field_delim,
-          use_quote_delim=self._use_quote_delim,
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
-          buffer_size=self._buffer_size,
-          header=self._header,
-          output_shapes=self._flat_shapes,
-          field_delim=self._field_delim,
-          use_quote_delim=self._use_quote_delim,
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
+    variant_tensor = gen_experimental_dataset_ops.csv_dataset(
+        filenames=self._filenames,
+        record_defaults=self._record_defaults,
+        buffer_size=self._buffer_size,
+        header=self._header,
+        output_shapes=self._flat_shapes,
+        field_delim=self._field_delim,
+        use_quote_delim=self._use_quote_delim,
+        na_value=self._na_value,
+        select_cols=self._select_cols,
+        compression_type=self._compression_type)
     super(CsvDatasetV2, self).__init__(variant_tensor)
 
   @property
@@ -846,18 +848,11 @@
     reader_args = []
 
   # Read files sequentially (if reader_num_threads=1) or in parallel
-  cycle_length = reader_num_threads
-  if reader_num_threads == dataset_ops.AUTOTUNE:
-    cycle_length = core_readers.DEFAULT_CYCLE_LENGTH
-  dataset = dataset.interleave(
-      lambda filename: reader(filename, *reader_args),
-      cycle_length,
-      num_parallel_calls=reader_num_threads)
-
-  if sloppy_ordering:
-    options = dataset_ops.Options()
-    options.experimental_deterministic = False
-    dataset = dataset.with_options(options)
+  dataset = dataset.apply(
+      interleave_ops.parallel_interleave(
+          lambda filename: reader(filename, *reader_args),
+          cycle_length=reader_num_threads,
+          sloppy=sloppy_ordering))
 
   # Extract values if the `Example` tensors are stored as key-value tuples.
   if dataset_ops.get_legacy_output_types(dataset) == (
@@ -986,14 +981,9 @@
         query, dtype=dtypes.string, name="query")
     self._element_spec = nest.map_structure(
         lambda dtype: tensor_spec.TensorSpec([], dtype), output_types)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
+        self._driver_name, self._data_source_name, self._query,
+        **self._flat_structure)
     super(SqlDatasetV2, self).__init__(variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py
index a81f5e6..27662d72c 100644
--- a/tensorflow/python/data/experimental/ops/scan_ops.py
+++ b/tensorflow/python/data/experimental/ops/scan_ops.py
@@ -17,14 +17,12 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -52,9 +50,8 @@
           input_structure=(self._state_structure,
                            input_dataset.element_spec),
           add_to_graph=False)
-      if not (
-          isinstance(wrapped_func.output_types, collections.Sequence) and
-          len(wrapped_func.output_types) == 2):
+      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
+              and len(wrapped_func.output_types) == 2):
         raise TypeError("The scan function must return a pair comprising the "
                         "new state and the output value.")
 
@@ -123,22 +120,13 @@
     self._scan_func = wrapped_func
     self._scan_func.function.add_to_graph(ops.get_default_graph())
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
-          self._scan_func.function.captured_inputs,
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
-          self._scan_func.function.captured_inputs,
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
+        self._input_dataset._variant_tensor,
+        structure.to_tensor_list(self._state_structure, self._initial_state),
+        self._scan_func.function.captured_inputs,
+        f=self._scan_func.function,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
 
   def _functions(self):
diff --git a/tensorflow/python/data/experimental/ops/sleep.py b/tensorflow/python/data/experimental/ops/sleep.py
index 837ec00..ff56436 100644
--- a/tensorflow/python/data/experimental/ops/sleep.py
+++ b/tensorflow/python/data/experimental/ops/sleep.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
 
@@ -28,16 +27,10 @@
   def __init__(self, input_dataset, sleep_microseconds):
     self._input_dataset = input_dataset
     self._sleep_microseconds = sleep_microseconds
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._sleep_microseconds,
+        **self._flat_structure)
     super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
 
 
diff --git a/tensorflow/python/data/experimental/ops/snapshot.py b/tensorflow/python/data/experimental/ops/snapshot.py
index 7e074cc..b0d66c2 100644
--- a/tensorflow/python/data/experimental/ops/snapshot.py
+++ b/tensorflow/python/data/experimental/ops/snapshot.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -66,31 +65,19 @@
     self._input_dataset = input_dataset
     self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
 
-    if compat.forward_compatible(2019, 8, 15):
-      variant_tensor = ged_ops.snapshot_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          path=self._path,
-          compression=self._compression,
-          reader_path_prefix=self._reader_path_prefix,
-          writer_path_prefix=self._writer_path_prefix,
-          shard_size_bytes=self._shard_size_bytes,
-          pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
-          num_reader_threads=self._num_reader_threads,
-          reader_buffer_size=self._reader_buffer_size,
-          num_writer_threads=self._num_writer_threads,
-          writer_buffer_size=self._writer_buffer_size,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.snapshot_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          path=self._path,
-          compression=self._compression,
-          reader_path_prefix=self._reader_path_prefix,
-          writer_path_prefix=self._writer_path_prefix,
-          shard_size_bytes=self._shard_size_bytes,
-          pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
-          **self._flat_structure)
-
+    variant_tensor = ged_ops.snapshot_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        path=self._path,
+        compression=self._compression,
+        reader_path_prefix=self._reader_path_prefix,
+        writer_path_prefix=self._writer_path_prefix,
+        shard_size_bytes=self._shard_size_bytes,
+        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
+        num_reader_threads=self._num_reader_threads,
+        reader_buffer_size=self._reader_buffer_size,
+        num_writer_threads=self._num_writer_threads,
+        writer_buffer_size=self._writer_buffer_size,
+        **self._flat_structure)
     super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
 
 
diff --git a/tensorflow/python/data/experimental/ops/stats_aggregator.py b/tensorflow/python/data/experimental/ops/stats_aggregator.py
index cb8239c..d8174ac 100644
--- a/tensorflow/python/data/experimental/ops/stats_aggregator.py
+++ b/tensorflow/python/data/experimental/ops/stats_aggregator.py
@@ -19,7 +19,6 @@
 
 import tempfile
 
-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.util.tf_export import tf_export
@@ -126,10 +125,7 @@
 
   def __init__(self):
     """Creates a `StatsAggregator`."""
-    if compat.forward_compatible(2019, 8, 3):
-      self._resource = ged_ops.stats_aggregator_handle()
-    else:
-      self._resource = ged_ops.experimental_stats_aggregator_handle()
+    self._resource = ged_ops.stats_aggregator_handle()
 
   def get_summary(self):
     """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -141,10 +137,7 @@
     Returns:
       A scalar string `tf.Tensor` that summarizes the aggregated statistics.
     """
-    if compat.forward_compatible(2019, 8, 3):
-      return ged_ops.stats_aggregator_summary(self._resource)
-    else:
-      return ged_ops.experimental_stats_aggregator_summary(self._resource)
+    return ged_ops.stats_aggregator_summary(self._resource)
 
 
 # TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
index 02d3f1e..c132a22 100644
--- a/tensorflow/python/data/experimental/ops/stats_ops.py
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -66,14 +65,8 @@
   """
 
   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset,
-          tag)
-    else:
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops
-          .experimental_bytes_produced_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset, tag)
 
   return _apply_fn
 
@@ -95,14 +88,8 @@
   """
 
   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.latency_stats_dataset, tag)
-    else:
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.latency_stats_dataset, tag)
 
   return _apply_fn
 
diff --git a/tensorflow/python/data/experimental/ops/take_while_ops.py b/tensorflow/python/data/experimental/ops/take_while_ops.py
index 3b8cb2b..fbaf0c2 100644
--- a/tensorflow/python/data/experimental/ops/take_while_ops.py
+++ b/tensorflow/python/data/experimental/ops/take_while_ops.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
@@ -42,18 +41,11 @@
       raise ValueError("`predicate` must return a scalar boolean tensor.")
 
     self._predicate = wrapped_func
-    if compat.forward_compatible(2019, 8, 3):
-      var_tensor = gen_experimental_dataset_ops.take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
-    else:
-      var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
+    var_tensor = gen_experimental_dataset_ops.take_while_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        other_arguments=self._predicate.function.captured_inputs,
+        predicate=self._predicate.function,
+        **self._flat_structure)
     super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
 
   def _functions(self):
diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py
index 0997e46..c30b36c 100644
--- a/tensorflow/python/data/experimental/ops/threadpool.py
+++ b/tensorflow/python/data/experimental/ops/threadpool.py
@@ -19,7 +19,6 @@
 
 import threading
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -47,31 +46,18 @@
     """Creates a `PrivateThreadPool` with the given number of threads."""
     if context.executing_eagerly():
       shared_name = _generate_shared_name("privatethreadpool")
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name,
+          shared_name=shared_name)
       self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
           handle=self._resource, handle_device=context.context().device_name)
     else:
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name)
 
 
 class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
@@ -80,16 +66,10 @@
   def __init__(self, input_dataset, thread_pool):
     self._input_dataset = input_dataset
     self._thread_pool = thread_pool
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.thread_pool_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._thread_pool._resource,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
 
 
diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py
index 396ec7a..057c9ca 100644
--- a/tensorflow/python/data/experimental/ops/unique.py
+++ b/tensorflow/python/data/experimental/ops/unique.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -60,12 +59,7 @@
       raise TypeError(
           "`tf.data.experimental.unique()` only supports inputs with a single "
           "`tf.int32`, `tf.int64`, or `tf.string` component.")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.unique_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
index 21c0c73..0d1785c 100644
--- a/tensorflow/python/data/experimental/ops/writers.py
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -84,9 +83,5 @@
           "produces shape {0} and types {1}".format(
               dataset_ops.get_legacy_output_shapes(dataset),
               dataset_ops.get_legacy_output_types(dataset)))
-    if compat.forward_compatible(2019, 8, 3):
-      return gen_experimental_dataset_ops.dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
-    else:
-      return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
+    return gen_experimental_dataset_ops.dataset_to_tf_record(
+        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index f76277a..18eb215 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -47,6 +47,24 @@
 )
 
 tf_py_test(
+    name = "checkpoint_test",
+    size = "medium",
+    srcs = ["checkpoint_test.py"],
+    additional_deps = [
+        ":test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/training/tracking:util",
+        "//tensorflow/python:checkpoint_management",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+    ],
+    grpc_enabled = True,
+)
+
+tf_py_test(
     name = "concatenate_test",
     size = "small",
     srcs = ["concatenate_test.py"],
@@ -62,28 +80,6 @@
 )
 
 tf_py_test(
-    name = "dataset_checkpoint_test",
-    size = "small",
-    srcs = ["dataset_checkpoint_test.py"],
-    additional_deps = [
-        ":test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:iterator_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dataset_ops_gen",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:io_ops",
-        "//tensorflow/python:parsing_ops",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:variables",
-    ],
-)
-
-tf_py_test(
     name = "dataset_test",
     size = "small",
     srcs = ["dataset_test.py"],
@@ -309,24 +305,6 @@
 )
 
 tf_py_test(
-    name = "iterator_checkpoint_test",
-    size = "medium",
-    srcs = ["iterator_checkpoint_test.py"],
-    additional_deps = [
-        ":test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/eager:context",
-        "//tensorflow/python/training/tracking:util",
-        "//tensorflow/python:checkpoint_management",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:math_ops",
-    ],
-    grpc_enabled = True,
-)
-
-tf_py_test(
     name = "iterator_cluster_test",
     size = "small",
     srcs = ["iterator_cluster_test.py"],
@@ -468,6 +446,7 @@
         "//tensorflow/python:framework_test_lib",
     ],
     tags = [
+        "no_oss",
         "no_windows_gpu",
     ],
     xla_enable_strict_auto_jit = True,
@@ -712,6 +691,7 @@
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
+        "//tensorflow/python:framework_combinations",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/python/data/kernel_tests/batch_test.py b/tensorflow/python/data/kernel_tests/batch_test.py
index 2f049e4..6a3a2cc 100644
--- a/tensorflow/python/data/kernel_tests/batch_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_test.py
@@ -24,10 +24,10 @@
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_concat_ops
@@ -37,15 +37,14 @@
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  @parameterized.named_parameters(
-      ('even', 28, 14, False),
-      ('uneven_with_remainder', 28, 15, False),
-      ('uneven_without_remainder', 28, 15, True),
-      ('empty', 0, 14, False),
-  )
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              count=[0, 28], batch_size=[14, 15], drop_remainder=[True,
+                                                                  False])))
   def testBasic(self, count, batch_size, drop_remainder):
     """Tests the batch dataset logic for various input configurations.
 
@@ -95,11 +94,13 @@
     with self.assertRaises(errors.OutOfRangeError):
       result = self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testInvalidBatchSize(self):
     with self.assertRaises(errors.InvalidArgumentError):
       dataset = (dataset_ops.Dataset.range(10).batch(0))
       self.evaluate(dataset._variant_tensor)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testDataset(self):
 
     def map_fn(i):
@@ -125,6 +126,7 @@
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSparseWithDifferentDenseShapes(self):
 
     def _sparse(i):
@@ -150,6 +152,7 @@
               dense_shape=[5, (i + 1) * 5 - 1]))
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testSparseNested(self):
 
     def _sparse(i):
@@ -166,6 +169,7 @@
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testShapeError(self):
 
     def generator():
@@ -183,7 +187,7 @@
             r'Cannot batch tensors with different shapes in component 0. First '
             r'element had shape \[3\] and element 2 had shape \[4\].'))
 
-  # Ragged Tensors.
+  @combinations.generate(test_base.default_test_combinations())
   def testRagged(self):
 
     def _ragged(i):
@@ -196,6 +200,7 @@
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testRaggedWithDifferentShapes(self):
     dataset = dataset_ops.Dataset.range(10).map(ragged_math_ops.range).batch(5)
     expected_output = [
@@ -205,6 +210,7 @@
     ]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testRaggedNested(self):
 
     def _ragged(i):
diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py
index 305092c..a7df114 100644
--- a/tensorflow/python/data/kernel_tests/cache_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_test.py
@@ -17,34 +17,39 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
 from os import path
 import shutil
 import tempfile
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class FileCacheTest(test_base.DatasetTestBase):
+class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def setUp(self):
+    super(FileCacheTest, self).setUp()
     self.tmp_dir = tempfile.mkdtemp()
     self.cache_prefix = path.join(self.tmp_dir, "cache")
 
   def tearDown(self):
+    super(FileCacheTest, self).tearDown()
     if self.tmp_dir:
       shutil.rmtree(self.tmp_dir, ignore_errors=True)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testCacheDatasetPassthrough(self):
     components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                   np.array([9.0, 10.0, 11.0, 12.0]))
@@ -97,6 +102,7 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcurrentWriters(self):
     components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                   np.array([9.0, 10.0, 11.0, 12.0]))
@@ -118,6 +124,7 @@
 
     self.evaluate(get_next1())  # this should continue to succeed
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcurrentReaders(self):
     components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                   np.array([9.0, 10.0, 11.0, 12.0]))
@@ -164,16 +171,36 @@
     self.assertAllEqual(elements, elements_itr1)
     self.assertAllEqual(elements, elements_itr2)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testReadingPastEndOfSequence(self):
     dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
     dataset = dataset.map(lambda a: a).batch(4).repeat(2)
     expected_output = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] * 2
     self.assertDatasetProduces(dataset, expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testCleaningUpCacheFiles(self):
 
-@test_util.run_all_in_graph_and_eager_modes
-class MemoryCacheTest(test_base.DatasetTestBase):
+    def do_test(i):
+      dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
+      get_next = self.getNext(dataset)
+      for _ in range(i):
+        try:
+          self.evaluate(get_next())
+        except errors.OutOfRangeError:
+          break
 
+    if not context.executing_eagerly():
+      self.skipTest(
+          "Test requires eager mode for iterators to be deconstructed")
+
+    for i in [0, 3, 10, 12, 15]:
+      do_test(i)
+
+
+class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  @combinations.generate(test_base.default_test_combinations())
   def testCacheDatasetPassthrough(self):
     with ops.device("cpu:0"):
       repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
@@ -208,6 +235,7 @@
       with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(cached_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testEmptyCacheReading(self):
     components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                   np.array([9.0, 10.0, 11.0, 12.0]))
@@ -220,11 +248,12 @@
     # caching, respectively.
     self.assertDatasetProduces(cache_dataset, expected_output=[])
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcurrentReaders(self):
 
-    dataset = dataset_ops.Dataset.range(5).cache()
-    d1 = dataset.map(lambda x: x + 1)
-    d2 = dataset.map(lambda x: x + 6)
+    dataset_fn = lambda: dataset_ops.Dataset.range(5).cache()
+    d1 = dataset_fn().map(lambda x: x + 1)
+    d2 = dataset_fn().map(lambda x: x + 6)
 
     get_next1 = self.getNext(d1)
 
@@ -248,12 +277,81 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next1())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testCacheTakeRepeat(self):
     dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
 
     expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testCacheRepeatEpochs(self):
+    counter = variables.Variable(0)
+    self.evaluate(counter.initializer)
+
+    def increment_fn(x):
+      counter.assign_add(1)
+      return x
+
+    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache().repeat(2)
+    get_next = self.getNext(dataset, requires_initialization=True)
+
+    # first epoch
+    for i in range(10):
+      self.assertEqual(i, self.evaluate(counter))
+      self.assertEqual(i, self.evaluate(get_next()))
+    # second epoch
+    for i in range(10):
+      self.assertEqual(10, self.evaluate(counter))
+      self.assertEqual(i, self.evaluate(get_next()))
+    with self.assertRaises(errors.OutOfRangeError):
+      self.evaluate(get_next())
+
+  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  def testCacheIterationEpochs(self):
+    counter = variables.Variable(0)
+    self.evaluate(counter.initializer)
+
+    def increment_fn(x):
+      counter.assign_add(1)
+      return x
+
+    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache()
+
+    # first epoch
+    i = 0
+    for elem in dataset:
+      self.assertEqual(i, self.evaluate(elem))
+      i += 1
+      self.assertEqual(i, self.evaluate(counter))
+
+    # second epoch
+    i = 0
+    for elem in dataset:
+      self.assertEqual(10, self.evaluate(counter))
+      self.assertEqual(i, self.evaluate(elem))
+      i += 1
+
+  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  def testCacheV2ResourceCapture(self):
+
+    def make_dataset():
+      ids = dataset_ops.Dataset.range(10)
+      ids = ids.cache()
+
+      def interleave_fn(dataset, _):
+        return dataset
+
+      dataset = dataset_ops.Dataset.range(1)
+      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
+      return dataset
+
+    results = []
+    for elem in make_dataset():
+      results.append(elem.numpy())
+
+    self.assertAllEqual(results, range(10))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/checkpoint_test.py b/tensorflow/python/data/kernel_tests/checkpoint_test.py
new file mode 100644
index 0000000..738d09b
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/checkpoint_test.py
@@ -0,0 +1,395 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for checkpointing tf.data iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.tracking import util as trackable_utils
+
+
+# TODO(jsimsa): Add missing test combinations.
+class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  def tearDown(self):
+    super(CheckpointTest, self).tearDown()
+    prefix = self._iterator_checkpoint_prefix()
+    pattern = prefix + "*"
+    files = gfile.Glob(pattern)
+    map(gfile.Remove, files)
+
+  def _iterator_checkpoint_prefix(self):
+    return os.path.join(self.get_temp_dir(), "iterator")
+
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  def testSaveRestore(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.make_initializable_iterator(
+          dataset_ops.Dataset.range(start, stop))
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+    # Saving and restoring in same session.
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  def testInitThenRestore(self):
+    # Note: Calling init_op before restore_op is redundant. This test just makes
+    # sure we do not fail if restore is called on an already initialized
+    # iterator resource.
+
+    def _build_graph(start, stop):
+      dataset = dataset_ops.Dataset.range(start, stop)
+      iterator = dataset_ops.make_initializable_iterator(dataset)
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  def testMultipleSaves(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.make_initializable_iterator(
+          dataset_ops.Dataset.range(start, stop))
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    break_point1 = 5
+    break_point2 = 7
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point1):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point1, break_point2):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    break_point2 = 7
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point2, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  def testSaveRestoreWithRepeat(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.make_initializable_iterator(
+          dataset_ops.Dataset.range(start, stop).repeat(num_epochs))
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    break_range = 5
+    break_epoch = 3
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(init_op)
+          sess.run(restore_op)
+        for _ in range(break_epoch - 1):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        for i in range(start, break_range):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_range, stop):
+          self.assertEqual(i, sess.run(get_next))
+        for _ in range(break_epoch, num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  def testSaveRestoreExhaustedIterator(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.make_initializable_iterator(
+          dataset_ops.Dataset.range(start, stop).repeat(num_epochs))
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(init_op)
+          sess.run(restore_op)
+        for _ in range(num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  def testSaveRestoreOneShotIterator(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
+        math_ops.square).batch(2)
+    # TODO(b/138399725): Re-enable default optimizations.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    iterator = iter(dataset)
+    get_next = iterator.get_next
+    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
+    self.assertAllEqual([1, 4], get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([9, 16], get_next())
+    self.assertAllEqual([25, 36], get_next())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual([9, 16], get_next())
+    self.assertAllEqual([25, 36], get_next())
+    with self.assertRaises(errors.OutOfRangeError):
+      get_next()
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  def testSaveRestoreMultipleIterator(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    dataset = dataset.map(math_ops.square).batch(2)
+    # TODO(b/138399725): Re-enable default optimizations.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    iterator_1 = iter(dataset)
+    get_next_1 = iterator_1.get_next
+    iterator_2 = iter(dataset)
+    get_next_2 = iterator_2.get_next
+    dataset_2 = dataset_ops.Dataset.range(10)
+    iterator_3 = iter(dataset_2)
+    get_next_3 = iterator_3.get_next
+    checkpoint = trackable_utils.Checkpoint(
+        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+    self.assertAllEqual([1, 4], get_next_1())
+    self.assertAllEqual(0, get_next_3())
+    self.assertAllEqual(1, get_next_3())
+    self.assertAllEqual(2, get_next_3())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([1, 4], get_next_2())
+    self.assertAllEqual([9, 16], get_next_2())
+    self.assertAllEqual(3, get_next_3())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual([9, 16], get_next_1())
+    self.assertAllEqual([1, 4], get_next_2())
+    self.assertAllEqual(3, get_next_3())
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  def testRestoreExhaustedIterator(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    dataset = dataset_ops.Dataset.range(3)
+    iterator = iter(dataset)
+    get_next = iterator.get_next
+    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
+    self.assertAllEqual(0, get_next())
+    self.assertAllEqual(1, get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual(2, get_next())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual(2, get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    checkpoint.restore(save_path).run_restore_ops()
+    with self.assertRaises(errors.OutOfRangeError):
+      get_next()
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  def testRestoreInReconstructedIteratorInitializable(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    dataset = dataset_ops.Dataset.range(10)
+    iterator = iter(dataset)
+    get_next = iterator.get_next
+    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
+    for i in range(5):
+      checkpoint.restore(
+          checkpoint_management.latest_checkpoint(
+              checkpoint_directory)).initialize_or_restore()
+      for j in range(2):
+        self.assertEqual(i * 2 + j, self.evaluate(get_next()))
+      checkpoint.save(file_prefix=checkpoint_prefix)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/kernel_tests/concatenate_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py
index 384fd28..bf72660 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_test.py
@@ -17,20 +17,21 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class ConcatenateTest(test_base.DatasetTestBase):
+class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcatenateDataset(self):
     input_components = (
         np.tile(np.array([[1], [2], [3], [4]]), 20),
@@ -64,6 +65,7 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcatenateDatasetDifferentShape(self):
     input_components = (
         np.tile(np.array([[1], [2], [3], [4]]), 20),
@@ -94,6 +96,7 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcatenateDatasetDifferentStructure(self):
     input_components = (
         np.tile(np.array([[1], [2], [3], [4]]), 5),
@@ -110,6 +113,7 @@
     with self.assertRaisesRegexp(TypeError, "have different types"):
       input_dataset.concatenate(dataset_to_concatenate)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcatenateDatasetDifferentKeys(self):
     input_components = {
         "foo": np.array([[1], [2], [3], [4]]),
@@ -127,6 +131,7 @@
     with self.assertRaisesRegexp(TypeError, "have different types"):
       input_dataset.concatenate(dataset_to_concatenate)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testConcatenateDatasetDifferentType(self):
     input_components = (
         np.tile(np.array([[1], [2], [3], [4]]), 5),
diff --git a/tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py b/tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py
deleted file mode 100644
index 82bdf20..0000000
--- a/tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py
+++ /dev/null
@@ -1,361 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Checkpoint tests for `tf.data.Dataset`."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-
-class DatasetCheckpointTest(test_base.DatasetTestBase):
-
-  def tearDown(self):
-    # Remove all checkpoint files.
-    prefix = self._iterator_checkpoint_prefix()
-    pattern = prefix + "*"
-    files = gfile.Glob(pattern)
-    map(gfile.Remove, files)
-
-  def _iterator_checkpoint_prefix(self):
-    return os.path.join(self.get_temp_dir(), "iterator")
-
-  def _save_op(self, iterator_resource):
-    iterator_state_variant = gen_dataset_ops.serialize_iterator(
-        iterator_resource)
-    save_op = io_ops.write_file(
-        self._iterator_checkpoint_prefix(),
-        parsing_ops.serialize_tensor(iterator_state_variant))
-    return save_op
-
-  def _restore_op(self, iterator_resource):
-    iterator_state_variant = parsing_ops.parse_tensor(
-        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
-    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
-                                                      iterator_state_variant)
-    return restore_op
-
-  def testSaveRestore(self):
-
-    def _build_graph(start, stop):
-      iterator = dataset_ops.make_initializable_iterator(
-          dataset_ops.Dataset.range(start, stop))
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    # Saving and restoring in different sessions.
-    start = 2
-    stop = 10
-    break_point = 5
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for i in range(start, break_point):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, _, restore_op = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(init_op)
-        sess.run(restore_op)
-        for i in range(break_point, stop):
-          self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-    # Saving and restoring in same session.
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for i in range(start, break_point):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-        sess.run(restore_op)
-        for i in range(break_point, stop):
-          self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testRestoreWithoutBuildingDatasetGraph(self):
-
-    def _build_graph(start, stop, num_epochs):
-      dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
-      iterator = dataset_ops.make_initializable_iterator(dataset)
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    # Saving and restoring in different sessions.
-    start = 2
-    stop = 10
-    num_epochs = 5
-    break_point = 5
-    break_epoch = 3
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for _ in range(break_epoch):
-          for i in range(start, stop):
-            self.assertEqual(i, sess.run(get_next))
-        for i in range(start, break_point):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      # Create an empty IteratorResource and restore the Iterator into it.
-      output_types = dtypes.int64
-      output_shapes = tensor_shape.TensorShape([])
-      iterator = iterator_ops.Iterator.from_structure(output_types,
-                                                      output_shapes)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      get_next = iterator.get_next()
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        for i in range(break_point, stop):
-          self.assertEqual(i, sess.run(get_next))
-        for _ in range(break_epoch + 1, num_epochs):
-          for i in range(start, stop):
-            self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testRestoreInModifiedGraph(self):
-
-    def _build_graph(start, stop):
-      dataset = dataset_ops.Dataset.range(start, stop)
-      iterator = dataset_ops.make_initializable_iterator(dataset)
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    # Saving and restoring in different sessions.
-    start = 2
-    stop = 10
-    stop_1 = 8
-    break_point = 5
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for i in range(start, break_point):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      # Intentionally build a graph with a different value for stop to make sure
-      # the original dataset graph is actually getting loaded.
-      init_op, get_next, _, restore_op = _build_graph(start, stop_1)
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        for i in range(break_point, stop):
-          self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testInitThenRestore(self):
-    # Note: Calling init_op before restore_op is redundant. This test just makes
-    # sure we do not fail if restore is called on an already initialized
-    # iterator resource.
-
-    def _build_graph(start, stop):
-      dataset = dataset_ops.Dataset.range(start, stop)
-      iterator = dataset_ops.make_initializable_iterator(dataset)
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    # Saving and restoring in different sessions.
-    start = 2
-    stop = 10
-    break_point = 5
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for i in range(start, break_point):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, _, restore_op = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(init_op)
-        sess.run(restore_op)
-        for i in range(break_point, stop):
-          self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testMultipleSaves(self):
-
-    def _build_graph(start, stop):
-      iterator = dataset_ops.make_initializable_iterator(
-          dataset_ops.Dataset.range(start, stop))
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    start = 2
-    stop = 10
-    break_point1 = 5
-    break_point2 = 7
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        for i in range(start, break_point1):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        for i in range(break_point1, break_point2):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    break_point2 = 7
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        for i in range(break_point2, stop):
-          self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testSaveRestoreWithRepeat(self):
-
-    def _build_graph(start, stop, num_epochs):
-      iterator = dataset_ops.make_initializable_iterator(
-          dataset_ops.Dataset.range(start, stop).repeat(num_epochs))
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    start = 2
-    stop = 10
-    num_epochs = 5
-    break_range = 5
-    break_epoch = 3
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, restore_op = _build_graph(
-          start, stop, num_epochs)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for _ in range(break_epoch - 1):
-          for i in range(start, stop):
-            self.assertEqual(i, sess.run(get_next))
-        for i in range(start, break_range):
-          self.assertEqual(i, sess.run(get_next))
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        for i in range(break_range, stop):
-          self.assertEqual(i, sess.run(get_next))
-        for _ in range(break_epoch, num_epochs):
-          for i in range(start, stop):
-            self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-  def testSaveRestoreExhaustedIterator(self):
-
-    def _build_graph(start, stop, num_epochs):
-      iterator = dataset_ops.make_initializable_iterator(
-          dataset_ops.Dataset.range(start, stop).repeat(num_epochs))
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = self._save_op(iterator._iterator_resource)
-      restore_op = self._restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    start = 2
-    stop = 10
-    num_epochs = 5
-    with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, restore_op = _build_graph(
-          start, stop, num_epochs)
-      with self.session(graph=g) as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for _ in range(num_epochs):
-          for i in range(start, stop):
-            self.assertEqual(i, sess.run(get_next))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
-      with self.session(graph=g) as sess:
-        sess.run(restore_op)
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py
index 348228b..b2d2fd2 100644
--- a/tensorflow/python/data/kernel_tests/dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_test.py
@@ -34,14 +34,16 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.platform import tf_logging
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -51,7 +53,13 @@
     dataset = dataset_ops.Dataset.range(10)
     graph = graph_pb2.GraphDef().FromString(
         self.evaluate(dataset._as_serialized_graph()))
-    self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+    self.assertTrue(any([node.op == "RangeDataset" for node in graph.node]))
+
+  def testAsSerializedGraphStateful(self):
+    dataset = dataset_ops.Dataset.range(10).map(
+        lambda _: random_ops.random_uniform(()))
+    with self.assertRaises(errors.FailedPreconditionError):
+      self.evaluate(dataset._as_serialized_graph())
 
   def testAsFunctionWithMap(self):
     if not context.executing_eagerly():
@@ -320,7 +328,7 @@
   def testSkipEagerSameGraphErrorOneShotSimple(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
-      with test.mock.patch.object(logging, "warning") as mock_log:
+      with test.mock.patch.object(tf_logging, "warning") as mock_log:
         _ = dataset_ops.make_one_shot_iterator(dataset)
         self.assertRegexpMatches(
             str(mock_log.call_args), "Please ensure that all datasets in the "
diff --git a/tensorflow/python/data/kernel_tests/enumerate_test.py b/tensorflow/python/data/kernel_tests/enumerate_test.py
index 0666449..1ff9b9d 100644
--- a/tensorflow/python/data/kernel_tests/enumerate_test.py
+++ b/tensorflow/python/data/kernel_tests/enumerate_test.py
@@ -17,18 +17,20 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class EnumerateTest(test_base.DatasetTestBase):
+class EnumerateTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testEnumerate(self):
     components = (["a", "b"], [1, 2], [37.0, 38])
     start = constant_op.constant(20, dtype=dtypes.int64)
diff --git a/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py b/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py
deleted file mode 100644
index dfb54b5..0000000
--- a/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Checkpoint tests for `tf.data.Iterator`."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import functools
-import os
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training.tracking import util as trackable_utils
-
-
-@test_util.run_all_in_graph_and_eager_modes
-class IteratorCheckpointingTest(test_base.DatasetTestBase):
-
-  def testSaveRestoreOneShotIterator(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
-        math_ops.square).batch(2)
-    iterator = iter(dataset) if context.executing_eagerly(
-    ) else dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next if context.executing_eagerly(
-    ) else functools.partial(self.evaluate, iterator.get_next())
-    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
-    self.assertAllEqual([1, 4], get_next())
-    save_path = checkpoint.save(checkpoint_prefix)
-    self.assertAllEqual([9, 16], get_next())
-    self.assertAllEqual([25, 36], get_next())
-    checkpoint.restore(save_path).run_restore_ops()
-    self.assertAllEqual([9, 16], get_next())
-    self.assertAllEqual([25, 36], get_next())
-    with self.assertRaises(errors.OutOfRangeError):
-      get_next()
-
-  def testSaveRestoreMultipleIterator(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    dataset = dataset_ops.Dataset.from_tensor_slices(
-        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
-    dataset = dataset.map(math_ops.square).batch(2)
-    iterator_1 = iter(dataset) if context.executing_eagerly(
-    ) else dataset_ops.make_one_shot_iterator(dataset)
-    get_next_1 = iterator_1.get_next if context.executing_eagerly(
-    ) else functools.partial(self.evaluate, iterator_1.get_next())
-    iterator_2 = iter(dataset) if context.executing_eagerly(
-    ) else dataset_ops.make_one_shot_iterator(dataset)
-    get_next_2 = iterator_2.get_next if context.executing_eagerly(
-    ) else functools.partial(self.evaluate, iterator_2.get_next())
-    dataset_2 = dataset_ops.Dataset.range(10)
-    iterator_3 = iter(dataset_2) if context.executing_eagerly(
-    ) else dataset_ops.make_one_shot_iterator(dataset_2)
-    get_next_3 = iterator_3.get_next if context.executing_eagerly(
-    ) else functools.partial(self.evaluate, iterator_3.get_next())
-    checkpoint = trackable_utils.Checkpoint(
-        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
-    self.assertAllEqual([1, 4], get_next_1())
-    self.assertAllEqual(0, get_next_3())
-    self.assertAllEqual(1, get_next_3())
-    self.assertAllEqual(2, get_next_3())
-    save_path = checkpoint.save(checkpoint_prefix)
-    self.assertAllEqual([1, 4], get_next_2())
-    self.assertAllEqual([9, 16], get_next_2())
-    self.assertAllEqual(3, get_next_3())
-    checkpoint.restore(save_path).run_restore_ops()
-    self.assertAllEqual([9, 16], get_next_1())
-    self.assertAllEqual([1, 4], get_next_2())
-    self.assertAllEqual(3, get_next_3())
-
-  def testRestoreExhaustedIterator(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    dataset = dataset_ops.Dataset.range(3)
-    iterator = iter(dataset) if context.executing_eagerly(
-    ) else dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next if context.executing_eagerly(
-    ) else functools.partial(self.evaluate, iterator.get_next())
-    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
-    self.assertAllEqual(0, get_next())
-    self.assertAllEqual(1, get_next())
-    save_path = checkpoint.save(checkpoint_prefix)
-    self.assertAllEqual(2, get_next())
-    checkpoint.restore(save_path).run_restore_ops()
-    self.assertAllEqual(2, get_next())
-    save_path = checkpoint.save(checkpoint_prefix)
-    checkpoint.restore(save_path).run_restore_ops()
-    with self.assertRaises(errors.OutOfRangeError):
-      get_next()
-
-  def testRestoreInReconstructedIteratorInitializable(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    dataset = dataset_ops.Dataset.range(10)
-    iterator = iter(dataset) if context.executing_eagerly(
-    ) else dataset_ops.make_initializable_iterator(dataset)
-    get_next = iterator.get_next
-    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
-    for i in range(5):
-      checkpoint.restore(
-          checkpoint_management.latest_checkpoint(
-              checkpoint_directory)).initialize_or_restore()
-      for j in range(2):
-        self.assertEqual(i * 2 + j, self.evaluate(get_next()))
-      checkpoint.save(file_prefix=checkpoint_prefix)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py
index caaf09c..4fc1566 100644
--- a/tensorflow/python/data/kernel_tests/iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_test.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
 import warnings
 
 from absl.testing import parameterized
@@ -30,7 +29,6 @@
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops import readers
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -45,9 +43,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -212,24 +208,22 @@
   @test_util.deprecated_graph_mode_only
   def testOneShotIteratorInitializerFails(self):
     # Define a dataset whose initialization will always fail.
-    dataset = dataset_ops.Dataset.from_tensors(
-        array_ops.check_numerics(
-            constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
+    dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
     iterator = dataset_ops.make_one_shot_iterator(dataset)
     next_element = iterator.get_next()
 
     with self.cached_session() as sess:
-      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
         sess.run(next_element)
 
       # Test that subsequent attempts to use the iterator also fail.
-      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
         sess.run(next_element)
 
     with self.cached_session() as sess:
 
       def consumer_thread():
-        with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+        with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
           sess.run(next_element)
 
       num_threads = 8
@@ -772,65 +766,6 @@
             })
 
   @test_util.deprecated_graph_mode_only
-  def testIncorrectIteratorRestore(self):
-
-    def _path():
-      return os.path.join(self.get_temp_dir(), "iterator")
-
-    def _save_op(iterator_resource):
-      iterator_state_variant = gen_dataset_ops.serialize_iterator(
-          iterator_resource)
-      save_op = io_ops.write_file(
-          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
-      return save_op
-
-    def _restore_op(iterator_resource):
-      iterator_state_variant = parsing_ops.parse_tensor(
-          io_ops.read_file(_path()), dtypes.variant)
-      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
-                                                        iterator_state_variant)
-      return restore_op
-
-    def _build_range_dataset_graph():
-      start = 1
-      stop = 10
-      iterator = dataset_ops.make_initializable_iterator(
-          dataset_ops.Dataset.range(start, stop))
-      init_op = iterator.initializer
-      get_next = iterator.get_next()
-      save_op = _save_op(iterator._iterator_resource)
-      restore_op = _restore_op(iterator._iterator_resource)
-      return init_op, get_next, save_op, restore_op
-
-    def _build_reader_dataset_graph():
-      filenames = ["test"]  # Does not exist but we don't care in this test.
-      iterator = dataset_ops.make_initializable_iterator(
-          readers.FixedLengthRecordDataset(filenames, 1, 0, 0))
-      init_op = iterator.initializer
-      get_next_op = iterator.get_next()
-      save_op = _save_op(iterator._iterator_resource)
-      restore_op = _restore_op(iterator._iterator_resource)
-      return init_op, get_next_op, save_op, restore_op
-
-    # Saving iterator for RangeDataset graph.
-    with ops.Graph().as_default() as g:
-      init_op, _, save_op, _ = _build_range_dataset_graph()
-      with self.session(graph=g) as sess:
-        sess.run(init_op)
-        sess.run(save_op)
-
-    # Attempt to restore the saved iterator into an IteratorResource of
-    # incompatible type. An iterator of RangeDataset has output type int64,
-    # while an iterator of FixedLengthRecordDataset has output type string.
-    # So an InvalidArgumentError should be raised by
-    # IteratorResource::set_iterator.
-    with ops.Graph().as_default() as g:
-      _, _, _, restore_op = _build_reader_dataset_graph()
-      with self.session(graph=g) as sess:
-        with self.assertRaises(errors.InvalidArgumentError):
-          sess.run(restore_op)
-
-  @test_util.deprecated_graph_mode_only
   def testRepeatedGetNextWarning(self):
     iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
     warnings.simplefilter("always")
diff --git a/tensorflow/python/data/kernel_tests/list_files_test.py b/tensorflow/python/data/kernel_tests/list_files_test.py
index 03cec7e..6168330 100644
--- a/tensorflow/python/data/kernel_tests/list_files_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_test.py
@@ -79,8 +79,9 @@
     filenames = ['a', 'b', 'c']
     self._touchTempFiles(filenames)
 
-    dataset = dataset_ops.Dataset.list_files(
-        path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
+    def dataset_fn():
+      return dataset_ops.Dataset.list_files(
+          path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
 
     expected_filenames = [
         compat.as_bytes(path.join(self.tmp_dir, filename))
@@ -90,7 +91,7 @@
     all_actual_filenames = []
     for _ in range(3):
       actual_filenames = []
-      next_element = self.getNext(dataset, requires_initialization=True)
+      next_element = self.getNext(dataset_fn(), requires_initialization=True)
       try:
         while True:
           actual_filenames.append(self.evaluate(next_element()))
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
index a0f56a4..e8a475a 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -18,25 +18,25 @@
 from __future__ import print_function
 
 import collections
+import functools
 
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
-
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  @combinations.generate(test_base.default_test_combinations())
   def testShuffleDataset(self):
     components = (
         np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
@@ -115,8 +115,8 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  @test_util.run_deprecated_v1
-  def testSkipEagerSeedZero(self):
+  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
+  def testSeedZero(self):
     """Test for same behavior when the seed is a Python or Tensor zero."""
     iterator = dataset_ops.make_one_shot_iterator(
         dataset_ops.Dataset.range(10).shuffle(10, seed=0))
@@ -141,6 +141,7 @@
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  @combinations.generate(test_base.default_test_combinations())
   def testDefaultArguments(self):
     components = [0, 1, 2, 3, 4]
     dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
@@ -154,42 +155,20 @@
     for i in range(5):
       self.assertEqual(10, counts[i])
 
-  @parameterized.named_parameters(
-      ("Reshuffle", True),
-      ("NoReshuffle", False),
-  )
-  def testReshuffle(self, reshuffle):
-    dataset = dataset_ops.Dataset.range(10).shuffle(
-        10, reshuffle_each_iteration=reshuffle).repeat(2)
-    next_element = self.getNext(dataset)
-
-    first_epoch = []
-    for _ in range(10):
-      first_epoch.append(self.evaluate(next_element()))
-
-    second_epoch = []
-    for _ in range(10):
-      second_epoch.append(self.evaluate(next_element()))
-
-    self.assertEqual(first_epoch == second_epoch, not reshuffle)
-
-  @parameterized.named_parameters(
-      ("ReshuffleGraphLevelSeed", True, 38, None),
-      ("ReshuffleOpLevelSeed", True, None, 42),
-      ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
-      ("NoReshuffleGraphLevelSeed", False, 38, None),
-      ("NoReshuffleOpLevelSeed", False, None, 42),
-      ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
-  )
-  def testSkipEagerShuffleSeed(self, reshuffle, graph_level_seed,
-                               op_level_seed):
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(tf_api_version=[1, 2], mode="graph"),
+          combinations.combine(reshuffle=[True, False]),
+          combinations.combine(graph_seed=38, op_seed=None) +
+          combinations.combine(graph_seed=None, op_seed=42) +
+          combinations.combine(graph_seed=38, op_seed=42)))
+  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
     results = []
     for _ in range(2):
       with ops.Graph().as_default() as g:
-        random_seed.set_random_seed(graph_level_seed)
+        random_seed.set_random_seed(graph_seed)
         dataset = dataset_ops.Dataset.range(10).shuffle(
-            10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
-                3)
+            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
         iterator = dataset_ops.make_one_shot_iterator(dataset)
         next_element = iterator.get_next()
 
@@ -203,15 +182,13 @@
 
     self.assertAllEqual(results[0], results[1])
 
-  # TODO(b/117581999): fails for eager mode with result[0] equal to result[1],
-  # debug.
-  @parameterized.named_parameters(
-      ("ReshuffleOneShot", True, False),
-      ("ReshuffleInitializable", True, True),
-      ("NoReshuffleOneShot", False, False),
-      ("NoReshuffleInitializable", False, True),
-  )
-  def testSkipEagerMultipleIterators(self, reshuffle, initializable):
+  # TODO(b/117581999): enable this test for eager-mode.
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(tf_api_version=[1, 2], mode="graph"),
+          combinations.combine(
+              reshuffle=[True, False], initializable=[True, False])))
+  def testMultipleIterators(self, reshuffle, initializable):
     with ops.Graph().as_default() as g:
       dataset = dataset_ops.Dataset.range(100).shuffle(
           10, reshuffle_each_iteration=reshuffle).repeat(3)
@@ -239,6 +216,62 @@
 
         self.assertNotEqual(results[0], results[1])
 
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
+  def testReshuffleRepeatEpochs(self, reshuffle, seed):
+    dataset = dataset_ops.Dataset.range(10).shuffle(
+        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
+    next_element = self.getNext(dataset)
+
+    first_epoch = []
+    for _ in range(10):
+      first_epoch.append(self.evaluate(next_element()))
+
+    second_epoch = []
+    for _ in range(10):
+      second_epoch.append(self.evaluate(next_element()))
+
+    self.assertEqual(first_epoch == second_epoch, not reshuffle)
+
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(tf_api_version=2, mode="eager"),
+          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
+  def testReshuffleIterationEpochs(self, reshuffle, seed):
+    dataset = dataset_ops.Dataset.range(10).shuffle(
+        10, seed=seed, reshuffle_each_iteration=reshuffle)
+
+    first_epoch = []
+    for elem in dataset:
+      first_epoch.append(elem.numpy())
+
+    second_epoch = []
+    for elem in dataset:
+      second_epoch.append(elem.numpy())
+
+    self.assertEqual(first_epoch == second_epoch, not reshuffle)
+
+  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  def testShuffleV2ResourceCapture(self):
+
+    def make_dataset():
+      ids = dataset_ops.Dataset.range(10)
+      ids = ids.shuffle(1)
+
+      def interleave_fn(dataset, _):
+        return dataset
+
+      dataset = dataset_ops.Dataset.range(1)
+      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
+      return dataset
+
+    results = []
+    for elem in make_dataset():
+      results.append(elem.numpy())
+
+    self.assertAllEqual(results, range(10))
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index f17f018..c81ec21 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -24,8 +24,10 @@
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import tensor_array_ops
@@ -33,6 +35,11 @@
 from tensorflow.python.platform import test
 
 
+def default_test_combinations():
+  """Returns the default test combinations for tf.data tests."""
+  return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
+
+
 class DatasetTestBase(test.TestCase):
   """Base class for dataset tests."""
 
@@ -87,7 +94,11 @@
         else:
           return r
       return _wrapper
-    if context.executing_eagerly():
+
+    # Create an anonymous iterator if we are in eager-mode or are graph inside
+    # of a tf.function.
+    building_function = ops.get_default_graph()._building_function  # pylint: disable=protected-access
+    if context.executing_eagerly() or building_function:
       iterator = iter(dataset)
       return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
     else:
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c60ebe9..db425fa 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -30,7 +30,7 @@
 
 
 from tensorflow.core.framework import graph_pb2
-from tensorflow.python.compat import compat
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
@@ -108,6 +108,25 @@
   A `Dataset` can be used to represent an input pipeline as a
   collection of elements and a "logical plan" of transformations that act on
   those elements.
+
+  A dataset contains elements that each have the same (nested) structure and the
+  individual components of the structure can be of any type representable by
+  `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`, `tf.SparseTensor`,
+  `tf.RaggedTensor`, or `tf.TensorArray`.
+
+  Example elements:
+  ```python
+  # Integer element
+  a = 1
+  # Float element
+  b = 2.0
+  # Tuple element with 2 components
+  c = (1, 2)
+  # Dict element with 3 components
+  d = {"a": (2, 2), "b": 3}
+  # Element containing a dataset
+  e = tf.data.Dataset.from_element(10)
+  ```
   """
 
   def __init__(self, variant_tensor):
@@ -1165,18 +1184,17 @@
 
     2) Use `tf.py_function`, which allows you to write arbitrary Python code but
     will generally result in worse performance than 1). For example:
-    
+
     ```python
     d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
-    
+
     # transform a string tensor to upper case string using a Python function
     def upper_case_fn(t: tf.Tensor) -> str:
         return t.numpy().decode('utf-8').upper()
-    
+
     d.map(lambda x: tf.py_function(func=upper_case_fn,
           inp=[x], Tout=tf.string))  # ==> [ "HELLO", "WORLD" ]
     ```
-    
 
     Args:
       map_func: A function mapping a dataset element to another dataset element.
@@ -1598,13 +1616,13 @@
         raise AttributeError("Please use _variant_tensor instead of "
                              "_as_variant_tensor() to obtain the variant "
                              "associated with a dataset")
-      raise AttributeError("A likely cause of this error is that the super "
+      raise AttributeError("{}: A likely cause of this error is that the super "
                            "call for this dataset is not the last line of the "
                            "__init__ method. The base class causes the "
                            "_as_variant_tensor call in its constructor and "
                            "if that uses attributes defined in the __init__ "
                            "method, those attrs need to be defined before the "
-                           "super call.")
+                           "super call.".format(e))
     super(DatasetV1, self).__init__(variant_tensor)
 
   @abc.abstractmethod
@@ -1724,12 +1742,8 @@
     dataset = self._apply_options()
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      iterator_resource = gen_dataset_ops.iterator_v2(
-          container="", shared_name=shared_name, **self._flat_structure)
-    else:
-      iterator_resource = gen_dataset_ops.iterator(
-          container="", shared_name=shared_name, **self._flat_structure)
+    iterator_resource = gen_dataset_ops.iterator_v2(
+        container="", shared_name=shared_name, **self._flat_structure)
     with ops.colocate_with(iterator_resource):
       initializer = gen_dataset_ops.make_iterator(
           dataset._variant_tensor,  # pylint: disable=protected-access
@@ -2195,7 +2209,7 @@
       name="experimental_distribute",
       ty=distribute_options.DistributeOptions,
       docstring=
-      "The distribution options associated with the dataset. See "
+      "The distribution strategy options associated with the dataset. See "
       "`tf.data.experimental.DistributeOptions` for more details.",
       default_factory=distribute_options.DistributeOptions)
 
@@ -2240,11 +2254,13 @@
 
     if self.experimental_deterministic is False:
       result.append("make_sloppy")
-    exp_stats_options = self.experimental_stats
-    if exp_stats_options and exp_stats_options.latency_all_edges:
+    if self.experimental_stats and self.experimental_stats.latency_all_edges:
       result.append("latency_all_edges")
     if self.experimental_slack:
       result.append("slack")
+    if (self.experimental_distribute and
+        self.experimental_distribute._make_stateless):  # pylint: disable=protected-access
+      result.append("make_stateless")
     return result
 
   def _static_optimization_configs(self):
@@ -2578,6 +2594,8 @@
                          "must be specified.")
       self._input_structure = input_structure
 
+    self._func = func
+
     if defun_kwargs is None:
       defun_kwargs = {}
 
@@ -2909,6 +2927,47 @@
     return self._structure
 
 
+class _MemoryCacheDeleter(object):
+  """An object which cleans up an anonymous memory cache resource.
+
+  An alternative to defining a __del__ method on an object. Even if the parent
+  object is part of a reference cycle, the cycle will be collectable.
+  """
+
+  def __init__(self, handle, device, deleter):
+    self._deleter = deleter
+    self._handle = handle
+    self._device = device
+    self._eager_mode = context.executing_eagerly()
+
+  def __del__(self):
+    with ops.device(self._device):
+      # Make sure the resource is deleted in the same mode as it was created in.
+      if self._eager_mode:
+        with context.eager_mode():
+          gen_dataset_ops.delete_memory_cache(
+              handle=self._handle, deleter=self._deleter)
+      else:
+        with context.graph_mode():
+          gen_dataset_ops.delete_memory_cache(
+              handle=self._handle, deleter=self._deleter)
+
+
+class _MemoryCache(object):
+  """Represents a memory cache resource."""
+
+  def __init__(self):
+    super(_MemoryCache, self).__init__()
+    self._device = context.context().device_name
+    self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache())
+    self._resource_deleter = _MemoryCacheDeleter(
+        handle=self._handle, device=self._device, deleter=self._deleter)
+
+  @property
+  def handle(self):
+    return self._handle
+
+
 class CacheDataset(UnaryUnchangedStructureDataset):
   """A `Dataset` that caches elements of its input."""
 
@@ -2917,13 +2976,64 @@
     self._input_dataset = input_dataset
     self._filename = ops.convert_to_tensor(
         filename, dtype=dtypes.string, name="filename")
-    variant_tensor = gen_dataset_ops.cache_dataset(
-        input_dataset._variant_tensor,  # pylint: disable=protected-access
-        filename=self._filename,
-        **self._flat_structure)
+    if tf2.enabled() and (context.executing_eagerly() or
+                          ops.get_default_graph()._building_function):  # pylint: disable=protected-access
+      self._cache = _MemoryCache()
+      variant_tensor = gen_dataset_ops.cache_dataset_v2(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          filename=self._filename,
+          cache=self._cache.handle,
+          **self._flat_structure)
+    else:
+      variant_tensor = gen_dataset_ops.cache_dataset(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          filename=self._filename,
+          **self._flat_structure)
     super(CacheDataset, self).__init__(input_dataset, variant_tensor)
 
 
+class _RandomSeedGeneratorDeleter(object):
+  """An object which cleans up an anonymous random seed generator resource.
+
+  An alternative to defining a __del__ method on an object. Even if the parent
+  object is part of a reference cycle, the cycle will be collectable.
+  """
+
+  def __init__(self, handle, device, deleter):
+    self._deleter = deleter
+    self._handle = handle
+    self._device = device
+    self._eager_mode = context.executing_eagerly()
+
+  def __del__(self):
+    with ops.device(self._device):
+      # Make sure the resource is deleted in the same mode as it was created in.
+      if self._eager_mode:
+        with context.eager_mode():
+          gen_dataset_ops.delete_random_seed_generator(
+              handle=self._handle, deleter=self._deleter)
+      else:
+        with context.graph_mode():
+          gen_dataset_ops.delete_random_seed_generator(
+              handle=self._handle, deleter=self._deleter)
+
+
+class _RandomSeedGenerator(object):
+  """Represents a random seed generator resource."""
+
+  def __init__(self, seed, seed2):
+    super(_RandomSeedGenerator, self).__init__()
+    self._device = context.context().device_name
+    self._handle, self._deleter = (
+        gen_dataset_ops.anonymous_random_seed_generator(seed=seed, seed2=seed2))
+    self._resource_deleter = _RandomSeedGeneratorDeleter(
+        handle=self._handle, device=self._device, deleter=self._deleter)
+
+  @property
+  def handle(self):
+    return self._handle
+
+
 class ShuffleDataset(UnaryUnchangedStructureDataset):
   """A `Dataset` that randomly shuffles the elements of its input."""
 
@@ -2960,13 +3070,24 @@
       self._reshuffle_each_iteration = True
     else:
       self._reshuffle_each_iteration = reshuffle_each_iteration
-    variant_tensor = gen_dataset_ops.shuffle_dataset(
-        input_dataset._variant_tensor,  # pylint: disable=protected-access
-        buffer_size=self._buffer_size,
-        seed=self._seed,
-        seed2=self._seed2,
-        reshuffle_each_iteration=self._reshuffle_each_iteration,
-        **self._flat_structure)
+
+    if tf2.enabled() and self._reshuffle_each_iteration and (
+        context.executing_eagerly() or
+        ops.get_default_graph()._building_function):  # pylint: disable=protected-access
+      self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
+      variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          buffer_size=self._buffer_size,
+          seed_generator=self._seed_generator.handle,
+          **self._flat_structure)
+    else:
+      variant_tensor = gen_dataset_ops.shuffle_dataset(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          buffer_size=self._buffer_size,
+          seed=self._seed,
+          seed2=self._seed2,
+          reshuffle_each_iteration=self._reshuffle_each_iteration,
+          **self._flat_structure)
     super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
 
 
@@ -3629,20 +3750,12 @@
     self._stats_aggregator = aggregator
     self._prefix = prefix
     self._counter_prefix = counter_prefix
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
+    variant_tensor = ged_ops.set_stats_aggregator_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._stats_aggregator._resource,  # pylint: disable=protected-access
+        self._prefix,
+        self._counter_prefix,
+        **self._flat_structure)
     super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                      variant_tensor)
 
@@ -3656,16 +3769,10 @@
         max_intra_op_parallelism,
         dtype=dtypes.int64,
         name="max_intra_op_parallelism")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
+    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._max_intra_op_parallelism,
+        **self._flat_structure)
     super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                         variant_tensor)
 
@@ -3677,16 +3784,10 @@
     self._input_dataset = input_dataset
     self._num_threads = ops.convert_to_tensor(
         num_threads, dtype=dtypes.int64, name="num_threads")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
+    variant_tensor = ged_ops.private_thread_pool_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._num_threads,
+        **self._flat_structure)
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
 
@@ -3725,14 +3826,9 @@
     self._structure = nest.map_structure(
         lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
         get_structure(input_dataset))
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 8d523d3..446cd09 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -20,7 +20,6 @@
 import threading
 import warnings
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -201,29 +200,22 @@
         output_types, output_shapes, output_classes)
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_v2(
-              container="",
-              shared_name=shared_name,
-              output_types=structure.get_flat_tensor_types(
-                  output_structure),
-              output_shapes=structure.get_flat_tensor_shapes(
-                  output_structure))
-      else:
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
         iterator_resource = gen_dataset_ops.iterator_v2(
             container="",
             shared_name=shared_name,
-            output_types=structure.get_flat_tensor_types(output_structure),
+            output_types=structure.get_flat_tensor_types(
+                output_structure),
             output_shapes=structure.get_flat_tensor_shapes(
                 output_structure))
     else:
-      iterator_resource = gen_dataset_ops.iterator(
+      iterator_resource = gen_dataset_ops.iterator_v2(
           container="",
           shared_name=shared_name,
           output_types=structure.get_flat_tensor_types(output_structure),
-          output_shapes=structure.get_flat_tensor_shapes(output_structure))
+          output_shapes=structure.get_flat_tensor_shapes(
+              output_structure))
     return Iterator(iterator_resource, None, output_types, output_shapes,
                     output_classes)
 
@@ -291,20 +283,14 @@
     output_structure = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
     string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
-              string_handle,
-              output_types=structure.get_flat_tensor_types(output_structure),
-              output_shapes=structure.get_flat_tensor_shapes(output_structure))
-      else:
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
         iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
             string_handle,
             output_types=structure.get_flat_tensor_types(output_structure),
             output_shapes=structure.get_flat_tensor_shapes(output_structure))
     else:
-      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
+      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
           string_handle,
           output_types=structure.get_flat_tensor_types(output_structure),
           output_shapes=structure.get_flat_tensor_shapes(output_structure))
@@ -795,8 +781,7 @@
     return IteratorSpec(value.element_spec)  # pylint: disable=protected-access
 
 
-# TODO(b/71645805): Expose trackable stateful objects from dataset
-# attributes(potential).
+# TODO(b/71645805): Expose trackable stateful objects from dataset.
 class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
   """SaveableObject for saving/restoring iterator state."""
 
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
index 0a5fd45..acebe54 100644
--- a/tensorflow/python/data/ops/multi_device_iterator_ops.py
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -348,7 +348,7 @@
 
   def _eager_reset(self):
     """Resets the MultiDeviceIterator in eager mode."""
-    if not context.executing_eagerly():
+    if not ops.executing_eagerly_outside_functions():
       raise ValueError("Eager reset is only supported in eager mode.")
     # pylint: disable=protected-access
     self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 07434f5..1e1402c 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -34,14 +34,16 @@
 @tf_export("data.experimental.Optional")
 @six.add_metaclass(abc.ABCMeta)
 class Optional(composite_tensor.CompositeTensor):
-  """Wraps a nested structure of tensors that may/may not be present at runtime.
+  """Wraps a value that may/may not be present at runtime.
 
   An `Optional` can represent the result of an operation that may fail as a
   value, rather than raising an exception and halting execution. For example,
   `tf.data.experimental.get_next_as_optional` returns an `Optional` that either
   contains the next value from a `tf.compat.v1.data.Iterator` if one exists, or
-  a "none"
-  value that indicates the end of the sequence has been reached.
+  a "none" value that indicates the end of the sequence has been reached.
+
+  `Optional` can only be used by values that are convertible to `Tensor` or
+  `CompositeTensor`.
   """
 
   @abc.abstractmethod
@@ -58,7 +60,7 @@
 
   @abc.abstractmethod
   def get_value(self, name=None):
-    """Returns a nested structure of values wrapped by this optional.
+    """Returns the value wrapped by this optional.
 
     If this optional does not have a value (i.e. `self.has_value()` evaluates
     to `False`), this operation will raise `tf.errors.InvalidArgumentError`
@@ -68,7 +70,7 @@
       name: (Optional.) A name for the created operation.
 
     Returns:
-      A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.
+      The wrapped value.
     """
     raise NotImplementedError("Optional.get_value()")
 
@@ -87,7 +89,8 @@
     """Returns an `Optional` that wraps the given value.
 
     Args:
-      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.
+      value: A value to wrap. The value must be convertible to `Tensor` or
+        `CompositeTensor`.
 
     Returns:
       An `Optional` that wraps `value`.
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index dab33fe..e6867b1 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -26,17 +25,13 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
 # TODO(b/64974358): Increase default buffer size to 256 MB.
 _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB
 
-# If the user requests the degree of interleave parallelism to be autotuned,
-# cycle length controls the maximum level of parallelism. We set it to a small
-# constant as a tradeoff between effective parallelism and memory and CPU usage.
-DEFAULT_CYCLE_LENGTH = 10
-
 
 def _create_or_validate_filenames_dataset(filenames):
   """Creates (or validates) a dataset of filenames.
@@ -84,13 +79,10 @@
   if num_parallel_reads is None:
     return filenames.flat_map(read_one_file)
   else:
-    cycle_length = num_parallel_reads
-    if num_parallel_reads == dataset_ops.AUTOTUNE:
-      cycle_length = DEFAULT_CYCLE_LENGTH
-    return filenames.interleave(
-        read_one_file,
-        cycle_length,
-        num_parallel_calls=num_parallel_reads)
+    return ParallelInterleaveDataset(
+        filenames, read_one_file, cycle_length=num_parallel_reads,
+        block_length=1, sloppy=False, buffer_output_elements=None,
+        prefetch_input_elements=None)
 
 
 class _TextLineDataset(dataset_ops.DatasetSource):
@@ -220,6 +212,56 @@
     return tensor_spec.TensorSpec([], dtypes.string)
 
 
+class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that maps a function over its input and flattens the result."""
+
+  def __init__(self, input_dataset, map_func, cycle_length, block_length,
+               sloppy, buffer_output_elements, prefetch_input_elements):
+    """See `tf.data.experimental.parallel_interleave()` for details."""
+    self._input_dataset = input_dataset
+    self._map_func = dataset_ops.StructuredFunctionWrapper(
+        map_func, self._transformation_name(), dataset=input_dataset)
+    if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
+      raise TypeError("`map_func` must return a `Dataset` object.")
+    self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
+    self._cycle_length = ops.convert_to_tensor(
+        cycle_length, dtype=dtypes.int64, name="cycle_length")
+    self._block_length = ops.convert_to_tensor(
+        block_length, dtype=dtypes.int64, name="block_length")
+    self._sloppy = ops.convert_to_tensor(
+        sloppy, dtype=dtypes.bool, name="sloppy")
+    self._buffer_output_elements = convert.optional_param_to_tensor(
+        "buffer_output_elements",
+        buffer_output_elements,
+        argument_default=2 * block_length)
+    self._prefetch_input_elements = convert.optional_param_to_tensor(
+        "prefetch_input_elements",
+        prefetch_input_elements,
+        argument_default=2 * cycle_length)
+    variant_tensor = ged_ops.parallel_interleave_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
+        self._cycle_length,
+        self._block_length,
+        self._sloppy,
+        self._buffer_output_elements,
+        self._prefetch_input_elements,
+        f=self._map_func.function,
+        **self._flat_structure)
+    super(ParallelInterleaveDataset, self).__init__(input_dataset,
+                                                    variant_tensor)
+
+  def _functions(self):
+    return [self._map_func]
+
+  @property
+  def element_spec(self):
+    return self._element_spec
+
+  def _transformation_name(self):
+    return "tf.data.experimental.parallel_interleave()"
+
+
 @tf_export("data.TFRecordDataset", v1=[])
 class TFRecordDatasetV2(dataset_ops.DatasetV2):
   """A `Dataset` comprising records from one or more TFRecord files."""
@@ -352,15 +394,9 @@
         compression_type,
         argument_default="",
         argument_dtype=dtypes.string)
-    if (self._compression_type is not None or
-        compat.forward_compatible(2018, 11, 30)):
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size, self._compression_type)
-    else:
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size)
+    variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
+        self._filenames, self._header_bytes, self._record_bytes,
+        self._footer_bytes, self._buffer_size, self._compression_type)
     super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
 
   @property
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index ebfd8af..24cdc97 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -35,12 +35,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections as _collections
-
 import six as _six
 
 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 from tensorflow.python.framework import sparse_tensor as _sparse_tensor
+from tensorflow.python.util.compat import collections_abc as _collections_abc
 
 
 def _sorted(dict_):
@@ -69,9 +68,8 @@
     # corresponding `OrderedDict` to pack it back).
     result = dict(zip(_sorted(instance), args))
     return type(instance)((key, result[key]) for key in instance)
-  elif (isinstance(instance, tuple) and
-        hasattr(instance, "_fields") and
-        isinstance(instance._fields, _collections.Sequence) and
+  elif (isinstance(instance, tuple) and hasattr(instance, "_fields") and
+        isinstance(instance._fields, _collections_abc.Sequence) and
         all(isinstance(f, _six.string_types) for f in instance._fields)):
     # This is a namedtuple
     return type(instance)(*args)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 86b9478..6a087df 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -800,6 +800,22 @@
 )
 
 cuda_py_test(
+    name = "debug_grappler_test",
+    size = "small",
+    srcs = ["lib/debug_grappler_test.py"],
+    additional_deps = [
+        ":debug_data",
+        ":debug_utils",
+        "//tensorflow/python:client",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variables",
+    ],
+    xla_enable_strict_auto_jit = False,  # Tests TF:Classic implementation
+)
+
+cuda_py_test(
     name = "session_debug_file_test",
     size = "small",
     srcs = ["lib/session_debug_file_test.py"],
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 586982d..7958498 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -47,6 +47,13 @@
 from tensorflow.python.util import tf_inspect
 
 
+# Helper function to accommodate MKL-enabled TensorFlow:
+# MatMul op is supported by MKL and its name is prefixed with "_Mkl" during the
+# MKL graph rewrite pass.
+def _matmul_op_name():
+  return "_MklMatMul" if test_util.IsMklEnabled() else "MatMul"
+
+
 def _cli_config_from_temp_file():
   return cli_config.CLIConfig(
       config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))
@@ -135,14 +142,9 @@
   attr_segs = out.font_attr_segs
   line_counter = 0
 
-  num_tensors = len(expected_tensor_names)
-
-  if tensor_filter_name is None:
-    tst.assertEqual("%d dumped tensor(s):" % num_tensors, next(line_iter))
-  else:
-    tst.assertEqual("%d dumped tensor(s) passing filter \"%s\":" %
-                    (num_tensors, tensor_filter_name), next(line_iter))
+  num_dumped_tensors = int(next(line_iter).split(" ")[0])
   line_counter += 1
+  tst.assertGreaterEqual(num_dumped_tensors, len(expected_tensor_names))
 
   if op_type_regex is not None:
     tst.assertEqual("Op type regex filter: \"%s\"" % op_type_regex,
@@ -669,7 +671,10 @@
         "simple_mul_add/u:0", "simple_mul_add/v:0", "simple_mul_add/u/read:0",
         "simple_mul_add/v/read:0", "simple_mul_add/matmul:0",
         "simple_mul_add/add:0"
-    ], ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"])
+    ], [
+        "VariableV2", "VariableV2", "Identity", "Identity",
+        _matmul_op_name(), "Add"
+    ])
 
     # Check the main menu.
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -683,8 +688,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="timestamp",
         reverse=True)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -697,8 +704,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="dump_size")
     check_main_menu(self, out, list_tensors_enabled=False)
 
@@ -710,8 +719,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="dump_size",
         reverse=True)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -730,8 +741,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="op_type",
         reverse=False)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -745,8 +758,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="op_type",
         reverse=True)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -760,8 +775,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="tensor_name",
         reverse=False)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -775,8 +792,10 @@
             "simple_mul_add/u:0", "simple_mul_add/v:0",
             "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
             "simple_mul_add/matmul:0", "simple_mul_add/add:0"
+        ], [
+            "VariableV2", "VariableV2", "Identity", "Identity",
+            _matmul_op_name(), "Add"
         ],
-        ["VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"],
         sort_by="tensor_name",
         reverse=True)
     check_main_menu(self, out, list_tensors_enabled=False)
@@ -803,13 +822,13 @@
         ["Identity", "Identity"],
         op_type_regex="Identity")
 
-    out = self._registry.dispatch_command("list_tensors",
-                                          ["-t", "(Add|MatMul)"])
+    out = self._registry.dispatch_command(
+        "list_tensors", ["-t", "(Add|" + _matmul_op_name() + ")"])
     assert_listed_tensors(
         self,
         out, ["simple_mul_add/add:0", "simple_mul_add/matmul:0"],
-        ["Add", "MatMul"],
-        op_type_regex="(Add|MatMul)")
+        ["Add", _matmul_op_name()],
+        op_type_regex=("(Add|" + _matmul_op_name() + ")"))
     check_main_menu(self, out, list_tensors_enabled=False)
 
   def testListTensorFilterByNodeNameRegexAndOpTypeRegex(self):
@@ -845,7 +864,9 @@
     assert_listed_tensors(
         self,
         out, ["simple_mul_add/matmul:0", "simple_mul_add/add:0"],
-        ["MatMul", "Add"], tensor_filter_name="is_2x1_vector")
+        [_matmul_op_name(), "Add"],
+        tensor_filter_name="is_2x1_vector")
+
     check_main_menu(self, out, list_tensors_enabled=False)
 
   def testListTensorsFilterNanOrInf(self):
@@ -884,7 +905,7 @@
 
     recipients = [("Add", "simple_mul_add/add"), ("Add", "simple_mul_add/add")]
 
-    assert_node_attribute_lines(self, out, node_name, "MatMul",
+    assert_node_attribute_lines(self, out, node_name, _matmul_op_name(),
                                 self._main_device,
                                 [("Identity", "simple_mul_add/u/read"),
                                  ("Identity", "simple_mul_add/v/read")], [],
@@ -906,17 +927,21 @@
     node_name = "simple_mul_add/matmul"
     out = self._registry.dispatch_command("node_info", ["-a", node_name])
 
+    test_attr_key_val_pairs = [("transpose_a", "b: false"),
+                               ("transpose_b", "b: false"),
+                               ("T", "type: DT_DOUBLE")]
+    if test_util.IsMklEnabled():
+      test_attr_key_val_pairs.append(("_kernel", 's: "MklNameChangeOp"'))
+
     assert_node_attribute_lines(
         self,
         out,
         node_name,
-        "MatMul",
+        _matmul_op_name(),
         self._main_device, [("Identity", "simple_mul_add/u/read"),
                             ("Identity", "simple_mul_add/v/read")], [],
         [("Add", "simple_mul_add/add"), ("Add", "simple_mul_add/add")], [],
-        attr_key_val_pairs=[("transpose_a", "b: false"),
-                            ("transpose_b", "b: false"),
-                            ("T", "type: DT_DOUBLE")])
+        attr_key_val_pairs=test_attr_key_val_pairs)
     check_main_menu(
         self,
         out,
@@ -933,7 +958,7 @@
         self,
         out,
         node_name,
-        "MatMul",
+        _matmul_op_name(),
         self._main_device, [("Identity", "simple_mul_add/u/read"),
                             ("Identity", "simple_mul_add/v/read")], [],
         [("Add", "simple_mul_add/add"), ("Add", "simple_mul_add/add")], [],
@@ -959,11 +984,12 @@
         self,
         out,
         node_name,
-        "MatMul",
+        _matmul_op_name(),
         self._main_device, [("Identity", "simple_mul_add/u/read"),
                             ("Identity", "simple_mul_add/v/read")], [],
         [("Add", "simple_mul_add/add"), ("Add", "simple_mul_add/add")], [],
-        show_stack_trace=True, stack_trace_available=False)
+        show_stack_trace=True,
+        stack_trace_available=False)
     check_main_menu(
         self,
         out,
@@ -982,11 +1008,12 @@
         self,
         out,
         node_name,
-        "MatMul",
+        _matmul_op_name(),
         self._main_device, [("Identity", "simple_mul_add/u/read"),
                             ("Identity", "simple_mul_add/v/read")], [],
         [("Add", "simple_mul_add/add"), ("Add", "simple_mul_add/add")], [],
-        show_stack_trace=True, stack_trace_available=True)
+        show_stack_trace=True,
+        stack_trace_available=True)
     check_main_menu(
         self,
         out,
@@ -1003,7 +1030,8 @@
     assert_node_attribute_lines(self, out, node_name, "Identity",
                                 self._main_device,
                                 [("VariableV2", "simple_mul_add/u")], [],
-                                [("MatMul", "simple_mul_add/matmul")], [])
+                                [(_matmul_op_name(), "simple_mul_add/matmul")],
+                                [])
     check_main_menu(
         self,
         out,
diff --git a/tensorflow/python/debug/lib/debug_grappler_test.py b/tensorflow/python/debug/lib/debug_grappler_test.py
new file mode 100644
index 0000000..7a3bf90
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_grappler_test.py
@@ -0,0 +1,121 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for debugger functionalities in tf.Session."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+def _grappler_enabled_session_config():
+  """Constructs a Session config proto that explicitly enables Grappler.
+
+  Returns:
+    A config proto that obtains extra safety for the unit tests in this
+    file by ensuring that the relevant Grappler rewrites are always enabled.
+  """
+  rewriter_config = rewriter_config_pb2.RewriterConfig(
+      disable_model_pruning=False,
+      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.ON)
+  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+  return config_pb2.ConfigProto(graph_options=graph_options)
+
+
+class SessionDebugGrapplerInteractionTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    super(SessionDebugGrapplerInteractionTest, self).setUp()
+    self._dump_root = tempfile.mkdtemp()
+    self._debug_url = "file://%s" % self._dump_root
+
+  def tearDown(self):
+    ops.reset_default_graph()
+    if os.path.isdir(self._dump_root):
+      shutil.rmtree(self._dump_root)
+    super(SessionDebugGrapplerInteractionTest, self).tearDown()
+
+  def testArithmeticOptimizationActive(self):
+    """Tests that tfdbg can dump the tensor from nodes created by Grappler."""
+    with session.Session(config=_grappler_enabled_session_config()) as sess:
+      u = variables.VariableV1([[1, 2], [3, 4]], name="u", dtype=dtypes.float32)
+      # The next two ops should be optimized by Grappler into a single op:
+      # either an AddN op or a Mul op.
+      x = math_ops.add(u, u)
+      x = math_ops.add(x, u)
+      y = math_ops.multiply(x, u)
+
+      sess.run(variables.global_variables_initializer())
+
+      run_options = config_pb2.RunOptions(output_partition_graphs=True)
+      debug_utils.watch_graph(
+          run_options,
+          sess.graph,
+          debug_ops=["DebugIdentity"],
+          debug_urls=[self._debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+      run_result = sess.run(y, options=run_options, run_metadata=run_metadata)
+      self.assertAllClose(run_result, [[3, 12], [27, 48]])
+
+      dump_data = debug_data.DebugDumpDir(
+          self._dump_root, partition_graphs=run_metadata.partition_graphs,
+          validate=True)
+
+      original_node_names = set([op.name for op in sess.graph.get_operations()])
+      dumped_node_names = set(dump_data.nodes())
+      grappler_created_node_names = dumped_node_names - original_node_names
+      grappler_removed_node_names = original_node_names - dumped_node_names
+
+      # Assert that Grappler should have replaced some of the nodes from the
+      # original graph with new nodes.
+      self.assertTrue(grappler_created_node_names)
+      self.assertTrue(grappler_removed_node_names)
+
+      # Iterate through the nodes created by Grappler. One of them should be
+      # be the result of replacing the original add ops with an AddN op or a
+      # Mul op.
+      found_optimized_node = False
+      for grappler_node_name in grappler_created_node_names:
+        node_op_type = dump_data.node_op_type(grappler_node_name)
+        # Look for the node created by Grappler's arithmetic optimization.
+        if node_op_type in ("AddN", "Mul"):
+          datum = dump_data.get_tensors(grappler_node_name, 0, "DebugIdentity")
+          self.assertEqual(1, len(datum))
+          self.assertAllClose(datum[0], [[3, 6], [9, 12]])
+          found_optimized_node = True
+          break
+      self.assertTrue(
+          found_optimized_node,
+          "Failed to find optimized node created by Grappler's arithmetic "
+          "optimization.")
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py
index f2a43a6..eb21694 100644
--- a/tensorflow/python/debug/lib/debug_utils.py
+++ b/tensorflow/python/debug/lib/debug_utils.py
@@ -134,6 +134,10 @@
     reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
       usage to zero (default: `False`).
   """
+  if not debug_ops:
+    raise ValueError("debug_ops must not be empty or None.")
+  if not debug_urls:
+    raise ValueError("debug_urls must not be empty or None.")
 
   if isinstance(debug_ops, str):
     debug_ops = [debug_ops]
@@ -173,6 +177,23 @@
           tolerate_debug_op_creation_failures=(
               tolerate_debug_op_creation_failures),
           global_step=global_step)
+
+  # If no filter for node or tensor is used, will add a wildcard node name, so
+  # that all nodes, including the ones created internally by TensorFlow itself
+  # (e.g., by Grappler), can be watched during debugging.
+  use_node_name_wildcard = (not node_name_pattern and
+                            not op_type_pattern and
+                            not tensor_dtype_pattern)
+  if use_node_name_wildcard:
+    add_debug_tensor_watch(
+        run_options,
+        "*",
+        output_slot=-1,
+        debug_ops=debug_ops,
+        debug_urls=debug_urls,
+        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
+        global_step=global_step)
+
   run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
 
 
diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py
index 9d59cfc..6e0b637 100644
--- a/tensorflow/python/debug/lib/debug_utils_test.py
+++ b/tensorflow/python/debug/lib/debug_utils_test.py
@@ -59,11 +59,13 @@
     cls._graph = cls._sess.graph
 
     # These are all the expected nodes in the graph:
-    #   Two variables (a, b), each with four nodes (Variable, init, Assign,
-    #       read).
-    #   One constant (c).
-    #   One add operation and one matmul operation.
-    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1
+    #   - Two variables (a, b), each with four nodes (Variable, init, Assign,
+    #     read).
+    #   - One constant (c).
+    #   - One add operation and one matmul operation.
+    #   - One wildcard node name ("*") that covers nodes created internally
+    #     by TensorFlow itself (e.g., Grappler).
+    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1 + 1
 
   def setUp(self):
     self._run_options = config_pb2.RunOptions()
@@ -88,9 +90,14 @@
     for watch in watch_opts:
       node_names.append(watch.node_name)
 
-      self.assertEqual(expected_output_slot, watch.output_slot)
-      self.assertEqual(expected_debug_ops, watch.debug_ops)
-      self.assertEqual(expected_debug_urls, watch.debug_urls)
+      if watch.node_name == "*":
+        self.assertEqual(-1, watch.output_slot)
+        self.assertEqual(expected_debug_ops, watch.debug_ops)
+        self.assertEqual(expected_debug_urls, watch.debug_urls)
+      else:
+        self.assertEqual(expected_output_slot, watch.output_slot)
+        self.assertEqual(expected_debug_ops, watch.debug_ops)
+        self.assertEqual(expected_debug_urls, watch.debug_urls)
 
     return node_names
 
@@ -203,19 +210,22 @@
                                       ["file:///tmp/tfdbg_1"])
 
     # Verify the node names.
-    self.assertTrue("a1_init" in node_names)
-    self.assertTrue("a1" in node_names)
-    self.assertTrue("a1/Assign" in node_names)
-    self.assertTrue("a1/read" in node_names)
+    self.assertIn("a1_init", node_names)
+    self.assertIn("a1", node_names)
+    self.assertIn("a1/Assign", node_names)
+    self.assertIn("a1/read", node_names)
 
-    self.assertTrue("b_init" in node_names)
-    self.assertTrue("b" in node_names)
-    self.assertTrue("b/Assign" in node_names)
-    self.assertTrue("b/read" in node_names)
+    self.assertIn("b_init", node_names)
+    self.assertIn("b", node_names)
+    self.assertIn("b/Assign", node_names)
+    self.assertIn("b/read", node_names)
 
-    self.assertTrue("c" in node_names)
-    self.assertTrue("p1" in node_names)
-    self.assertTrue("s" in node_names)
+    self.assertIn("c", node_names)
+    self.assertIn("p1", node_names)
+    self.assertIn("s", node_names)
+
+    # Assert that the wildcard node name has been created.
+    self.assertIn("*", node_names)
 
   @test_util.run_v1_only("b/120545219")
   def testWatchGraph_nodeNameWhitelist(self):
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index db93946..f3b187c 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -164,7 +164,7 @@
     self.assertAllClose(42.0, w_result)
 
     dump = debug_data.DebugDumpDir(self._dump_root)
-    self.assertEqual(5, dump.size)
+    self.assertLessEqual(5, dump.size)
     self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
     self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
     self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index d14399b..e2740d8 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -659,16 +659,15 @@
 
       # Verify that the nodes with bad values are caught through running find
       # on the debug dump.
-      self.assertEqual(3, len(bad_data))
-      self.assertEqual(x_name, bad_data[0].node_name)
-      self.assertEqual(y_name, bad_data[1].node_name)
-      self.assertEqual(z_name, bad_data[2].node_name)
+      self.assertLessEqual(3, len(bad_data))
+      node_names = [datum.node_name for datum in bad_data]
+      self.assertIn(x_name, node_names)
+      self.assertIn(y_name, node_names)
+      self.assertIn(z_name, node_names)
 
       # Test first_n kwarg of find(): Find the first offending tensor.
       first_bad_datum = dump.find(has_bad_value, first_n=1)
-
       self.assertEqual(1, len(first_bad_datum))
-      self.assertEqual(x_name, first_bad_datum[0].node_name)
 
   def testFindInfOrNanWithOpNameExclusion(self):
     with session.Session() as sess:
@@ -708,16 +707,15 @@
 
       # Verify that the nodes with bad values are caught through running find
       # on the debug dump.
-      self.assertEqual(2, len(bad_data))
+      self.assertLessEqual(2, len(bad_data))
       # Assert that the node `x` should have been excluded.
-      self.assertEqual(y_name, bad_data[0].node_name)
-      self.assertEqual(z_name, bad_data[1].node_name)
+      node_names = [datum.node_name for datum in bad_data]
+      self.assertIn(y_name, node_names)
+      self.assertIn(z_name, node_names)
 
       first_bad_datum = dump.find(
           debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$")
-
       self.assertEqual(1, len(first_bad_datum))
-      self.assertEqual(y_name, first_bad_datum[0].node_name)
 
   def _session_run_for_graph_structure_lookup(self):
     with session.Session(config=no_rewrite_session_config()) as sess:
@@ -1378,7 +1376,7 @@
           sess, y, debug_ops=["DebugNumericSummary(mute_if_healthy=true)"],
           validate=False)
 
-      self.assertEqual(2, dump.size)
+      self.assertLessEqual(2, dump.size)
       self.assertAllClose([[
           1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, np.nan,
           np.nan, 1.0, 0.0
@@ -1393,7 +1391,7 @@
       shutil.rmtree(self._dump_root)
       _, dump = self._debug_run_and_get_dump(
           sess, y, debug_ops=["DebugNumericSummary()"])
-      self.assertEqual(8, dump.size)
+      self.assertLessEqual(8, dump.size)
 
   def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
     with session.Session() as sess:
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 3d90fa0..83222f2 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -459,7 +459,8 @@
     self.assertEqual(2, len(debug_dumps))
     for debug_dump in debug_dumps:
       node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
-      self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+      self.assertIn("callable_a", node_names)
+      self.assertIn("callable_b", node_names)
 
   def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
     ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
@@ -486,7 +487,8 @@
     self.assertEqual(2, len(debug_dumps))
     for debug_dump in debug_dumps:
       node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
-      self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+      self.assertIn("callable_a", node_names)
+      self.assertIn("callable_b", node_names)
 
   def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
     variable_1 = variables.VariableV1(
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 3eebc63..38d12f8 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1,7 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
-load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
+load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test")
 
 package(
     default_visibility = ["//tensorflow:internal"],
@@ -75,6 +75,7 @@
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:tensor_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/tools/docs:doc_controls",
         "@six_archive//:six",
@@ -220,6 +221,7 @@
     srcs = ["distribute_coordinator_test.py"],
     python_version = "PY2",
     srcs_version = "PY2AND3",
+    tags = ["no_oss_py2"],  # b/138443278
     deps = [
         ":distribute_coordinator",
         "//tensorflow/core:protos_all_py",
@@ -882,6 +884,7 @@
         ":single_loss_example",
         "//tensorflow/contrib/tpu:tpu_lib",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:variable_scope",
@@ -1108,7 +1111,9 @@
     name = "saved_model_save_load_test",
     size = "medium",
     srcs = ["saved_model_save_load_test.py"],
+    full_precision = True,
     main = "saved_model_save_load_test.py",
+    shard_count = 5,
     deps = [
         ":saved_model_test_base",
         "//tensorflow/python/saved_model",
@@ -1116,27 +1121,12 @@
 )
 
 distribute_py_test(
-    name = "keras_experimental_saved_model_test",
-    size = "medium",
-    srcs = ["keras_experimental_saved_model_test.py"],
-    main = "keras_experimental_saved_model_test.py",
-    shard_count = 5,
-    tags = [
-        "no_oss",  # TODO(b/135287893) reenable
-        "no_rocm",
-    ],
-    deps = [
-        ":saved_model_test_base",
-        "//tensorflow/python/keras:saving",
-    ],
-)
-
-distribute_py_test(
     name = "keras_save_load_test",
     size = "medium",
     srcs = ["keras_save_load_test.py"],
+    full_precision = True,
     main = "keras_save_load_test.py",
-    shard_count = 3,
+    shard_count = 5,
     deps = [
         ":saved_model_test_base",
         "//tensorflow/python/keras:saving",
@@ -1147,7 +1137,9 @@
     name = "saved_model_mixed_api_test",
     size = "medium",
     srcs = ["saved_model_mixed_api_test.py"],
+    full_precision = True,
     main = "saved_model_mixed_api_test.py",
+    shard_count = 5,
     deps = [
         ":saved_model_test_base",
         "//tensorflow/python/keras:saving",
diff --git a/tensorflow/python/distribute/cluster_resolver/BUILD b/tensorflow/python/distribute/cluster_resolver/BUILD
index c4341ca..4862333 100644
--- a/tensorflow/python/distribute/cluster_resolver/BUILD
+++ b/tensorflow/python/distribute/cluster_resolver/BUILD
@@ -1,6 +1,10 @@
 # Description: Operations defined for Cluster Resolvers
 
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load(
+    "//tensorflow/core/platform:default/build_config.bzl",
+    "tf_additional_rpc_deps",
+)
 
 package(
     default_visibility = [
@@ -63,7 +67,7 @@
     deps = [
         ":base_cluster_resolver_py",
         "//tensorflow/python:training_server_lib",
-    ],
+    ] + tf_additional_rpc_deps(),
 )
 
 py_library(
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
index 757d2a4..be7df0e 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
@@ -276,6 +276,19 @@
     if self._is_google_environment():
       self._environment = 'google'
       self.rpc_layer = None
+
+      # TODO(rsopher): remove this logic when possible
+      if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')):
+        bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1)
+        if len(bns_and_port) == 2:
+          try:
+            int(bns_and_port[1])
+          except ValueError:
+            # Leave named ports.
+            pass
+          else:
+            # Strip numerical ports.
+            self._tpu = bns_and_port[0]
     else:
       self._environment = ''
       self.rpc_layer = 'grpc'
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py
index cb4d785..37c8216 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py
@@ -516,6 +516,20 @@
     cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
     self.assertEqual(cluster_resolver.environment, 'google')
     self.assertEqual(cluster_resolver.rpc_layer, None)
+    self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
+
+  def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self):
+    cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234')
+    self.assertEqual(cluster_resolver.environment, 'google')
+    self.assertEqual(cluster_resolver.rpc_layer, None)
+    self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
+
+  def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
+    cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
+    self.assertEqual(cluster_resolver.environment, 'google')
+    self.assertEqual(cluster_resolver.rpc_layer, None)
+    self.assertEqual(cluster_resolver._tpu,
+                     compat.as_bytes('/bns/ab/cd/ef:port'))
 
   def testEnvironmentAndRpcDetectionForGrpcString(self):
     cluster_resolver = resolver.TPUClusterResolver(
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index c43d28b..e35f95a 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -40,7 +40,6 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -336,11 +335,6 @@
 
           if self._num_workers > 1:
             if self._is_chief:
-              # Unwrap `initial_value` if it is a `CheckpointInitialValue`.
-              # TODO(b/138130844): Revert the following check once
-              # `CheckpointInitialValue` class is removed.
-              if isinstance(initial_value, trackable.CheckpointInitialValue):
-                initial_value = initial_value.wrapped_value
               bcast_send = collective_ops.broadcast_send(
                   initial_value, initial_value.shape, initial_value.dtype,
                   group_size, group_key, collective_instance_key)
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index f9e2a11..8c30366 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -580,12 +580,12 @@
       self._test_all_reduce_mean_gradient_tape(distribution)
 
   @combinations.generate(combinations.combine(mode=['graph']))
-  def testNumpyIterator(self):
+  def testNumpyDataset(self):
     num_gpus = 2
     if context.num_gpus() < num_gpus:
       self.skipTest('Not enough GPUs')
     strategy, _, _ = self._get_test_object(None, None, num_gpus=num_gpus)
-    self._test_numpy_iterator(strategy)
+    self._test_numpy_dataset(strategy)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 1932a5a..56ec856 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -30,6 +30,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.framework import kernels
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -84,7 +85,8 @@
   # If the same value is present on all replicas then the PerReplica value will
   # be a single value. We also handle the case when `value` is a single value
   # and equal to 0.
-  if value == 0:
+  # TODO:(b/138823479): handle the tensor value properly.
+  if not tensor_util.is_tensor(value) and value == 0:
     return 0
   # If there is only a single value and the reduce op is MEAN,
   # that value should be on all destinations.
@@ -262,7 +264,7 @@
       ValueError: if per_replica_value can't be converted to a PerReplica
         object.
     """
-    if not isinstance(per_replica_value, value_lib.PerReplica):
+    if not isinstance(per_replica_value, value_lib.DistributedValues):
       per_replica_value = _make_tensor_into_per_replica(per_replica_value)
 
     validate_destinations(destinations)
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index ec85cd3..85a33e6 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -97,8 +97,11 @@
 
 import copy
 import enum  # pylint: disable=g-bad-import-order
+import json
+import os
 import threading
 import weakref
+
 import six
 
 from tensorflow.python.autograph.core import ag_ctx
@@ -123,6 +126,8 @@
 from tensorflow.python.ops.losses import loss_reduction
 from tensorflow.python.ops.losses import losses_impl
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import server_lib
+from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
@@ -410,13 +415,16 @@
 # pylint: disable=line-too-long
 @tf_export("distribute.Strategy", v1=[])
 class Strategy(object):
-  """A list of devices with a state & compute distribution policy.
+  """A state & compute distribution policy on a list of devices.
 
   See [the guide](https://www.tensorflow.org/alpha/guide/distribute_strategy)
   for overview and examples.
 
   In short:
 
+  * To use it with Keras `compile`/`fit`,
+    [please
+    read](https://www.tensorflow.org/alpha/guide/distribute_strategy#using_tfdistributestrategy_with_keras).
   * You may pass descendant of `tf.distribute.Strategy` to
     `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
     should distribute its computation. See
@@ -425,11 +433,10 @@
     strategy should be used when building an executing your model.
     (This puts you in the "cross-replica context" for this strategy, which
     means the strategy is put in control of things like variable placement.)
-  * If using Keras `compile`/`fit`,
-    [that is it](https://www.tensorflow.org/alpha/guide/distribute_strategy#using_tfdistributestrategy_with_keras).
   * If you are writing a custom training loop, you will need to call a few more
     methods,
-    [see the guide](https://www.tensorflow.org/alpha/guide/distribute_strategy#using_tfdistributestrategy_with_custom_training_loops):
+    [see the
+    guide](https://www.tensorflow.org/alpha/guide/distribute_strategy#using_tfdistributestrategy_with_custom_training_loops):
 
       * Start by either creating a `tf.data.Dataset` normally or using
         `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of
@@ -484,7 +491,8 @@
   accumulate metrics across steps in a given epoch.
 
   See the
-  [custom training loop tutorial](https://www.tensorflow.org/alpha/tutorials/distribute/training_loops)
+  [custom training loop
+  tutorial](https://www.tensorflow.org/alpha/tutorials/distribute/training_loops)
   for a more detailed example.
 
   Note: `tf.distribute.Strategy` currently does not support TensorFlow's
@@ -725,14 +733,17 @@
     Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
     "per-replica" values, such as those produced by a "distributed `Dataset`",
     when `fn` is executed on a particular replica, it will be executed with the
-    component of those "per-replica" values that corresponds to that replica.
+    component of those "per-replica" values that correspond to that replica.
 
     `fn` may call `tf.distribute.get_replica_context()` to access members such
     as `all_reduce`.
 
-    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
-    used, and whether eager execution is enabled, `fn` may be called one or more
-    times (once for each replica).
+    All arguments in `args` or `kwargs` should either be nest of tensors or
+    per-replica objects containing tensors or composite tensors.
+
+    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
+    whether eager execution is enabled, `fn` may be called one or more times (
+    once for each replica).
 
     Args:
       fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
@@ -871,7 +882,7 @@
   def experimental_local_results(self, value):
     """Returns the list of all local per-replica values contained in `value`.
 
-    Note: This only returns values on the workers initiated by this client.
+    Note: This only returns values on the worker initiated by this client.
     When using a `tf.distribute.Strategy` like
     `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
     will be its own client, and this function will only return values
@@ -937,6 +948,20 @@
   def __copy__(self):
     raise RuntimeError("Must only deepcopy DistributionStrategy.")
 
+  def _in_multi_worker_mode(self):
+    """Method to infer if this `Strategy` is working in multi-worker settings.
+
+    Experimental. Signature and implementation are subject to change.
+
+    Returns:
+      Whether this strategy indicates working in multi-worker settings.
+    """
+    # TODO(b/137857865): Check for whether it is multi_worker_mode should not
+    # rely on TF_CONFIG environment variable.
+    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+    cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
+    return tf_config and "master" not in cluster_spec.jobs
+
 
 # TF v1.x version has additional deprecated APIs
 @tf_export(v1=["distribute.Strategy"])
@@ -1300,9 +1325,18 @@
   def _scope(self, strategy):
     """Implementation of tf.distribute.Strategy.scope()."""
     def creator_with_resource_vars(*args, **kwargs):
+      """Variable creator to use in `_CurrentDistributionContext`."""
       _require_strategy_scope_extended(self)
       kwargs["use_resource"] = True
       kwargs["distribute_strategy"] = strategy
+
+      # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
+      # dereferencing a `Tensor` that is without a `name`.
+      # TODO(b/138130844): Revisit the following check once
+      # `CheckpointInitialValue` class is removed.
+      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
+        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
+
       return self._create_variable(*args, **kwargs)
 
     def distributed_getter(getter, *args, **kwargs):
@@ -1431,7 +1465,7 @@
         all-reduction, pass `value` to `destinations`.
 
     Returns:
-      A value mirrored to `destinations`.
+      A tensor or value mirrored to `destinations`.
     """
     # TODO(josh11b): More docstring
     _require_cross_replica_or_default_context_extended(self)
@@ -2290,9 +2324,12 @@
     if kwargs.get("trainable", True):
       collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
       l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
-      for v in value_list:
-        if v in l:
-          l.remove(v)
+      for value in value_list:
+        for i, trainable_variable in enumerate(l):
+          if value is trainable_variable:
+            del l[i]
+            break
+
     g.add_to_collections(collections, result)
   elif ops.GraphKeys.GLOBAL_STEP in collections:
     ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 84b2351..d35bfa5 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -495,6 +495,11 @@
         else:
           raise
 
+    # TODO(b/138745411): Remove once stateful transformations are supported.
+    options = dataset_ops.Options()
+    options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
+    dataset = dataset.with_options(options)
+
     self._cloned_datasets = []
     if input_context:
       # Between-graph where we rely on the input_context for sharding
@@ -887,7 +892,7 @@
     Returns:
       A list of any initializer ops that should be run.
     """
-    if context.executing_eagerly():
+    if ops.executing_eagerly_outside_functions():
       self._iterator._eager_reset()  # pylint: disable=protected-access
       return []
     else:
@@ -963,6 +968,10 @@
     worker = input_workers.worker_devices[i]
     with ops.device(worker):
       dataset = dataset_fn(ctx)
+      # TODO(b/138745411): Remove once stateful transformations are supported.
+      options = dataset_ops.Options()
+      options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
+      dataset = dataset.with_options(options)
       devices = input_workers.compute_devices_for_worker(i)
       iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
       iterators.append(iterator)
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 46b4a6c..9fe6721 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -37,6 +37,7 @@
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import errors
 from tensorflow.python.ops import control_flow_ops
@@ -222,6 +223,31 @@
 
   @combinations.generate(
       combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
+          ]))
+  def testMultiDeviceIterInitialize(self, distribution):
+    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
+    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
+
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    device_map = values.ReplicaDeviceMap(devices)
+    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
+
+    dist_dataset = input_lib.get_distributed_dataset(
+        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
+
+    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)
+
+    @def_function.function
+    def init_func_for_iter():
+      self.evaluate(iterator.initializer)
+
+    init_func_for_iter()
+
+  @combinations.generate(
+      combinations.combine(
           mode=["graph"],
           distribution=[
               strategy_combinations.one_device_strategy,
diff --git a/tensorflow/python/distribute/keras_experimental_saved_model_test.py b/tensorflow/python/distribute/keras_experimental_saved_model_test.py
deleted file mode 100644
index 0a0a57f..0000000
--- a/tensorflow/python/distribute/keras_experimental_saved_model_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for saving and loading using keras experimental APIs with DS."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import saved_model_test_base as test_base
-from tensorflow.python.eager import test
-from tensorflow.python.keras.saving import saved_model_experimental as saved_model
-
-
-class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
-
-  def setUp(self):
-    self._root_dir = 'keras_experimental_save_load'
-    super(KerasExperimentalSaveLoadTest, self).setUp()
-
-  def _save_model(self, model, saved_dir):
-    saved_model.export_saved_model(model, saved_dir)
-
-  def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
-    restored_keras_model = saved_model.load_from_saved_model(saved_dir)
-    return restored_keras_model.predict(
-        predict_dataset, steps=test_base.PREDICT_STEPS)
-
-  @combinations.generate(test_base.simple_models_with_strategies())
-  def test_save_no_strategy_restore_strategy(self, model_and_input,
-                                             distribution, run_distributed):
-    self.run_test_save_no_strategy_restore_strategy(model_and_input,
-                                                    distribution,
-                                                    run_distributed)
-
-  @combinations.generate(
-      combinations.times(test_base.simple_models_with_strategies(),
-                         combinations.combine(save_in_scope=[True, False])))
-  def test_save_strategy_restore_no_strategy(self, model_and_input,
-                                             distribution, save_in_scope,
-                                             run_distributed):
-    self.run_test_save_strategy_restore_no_strategy(model_and_input,
-                                                    distribution, save_in_scope,
-                                                    run_distributed)
-
-  @combinations.generate(
-      combinations.times(test_base.simple_models_with_strategy_pairs(),
-                         combinations.combine(save_in_scope=[True, False])))
-  def test_save_strategy_restore_strategy(self, model_and_input,
-                                          distribution_for_saving,
-                                          distribution_for_restoring,
-                                          save_in_scope, run_distributed):
-    self.run_test_save_strategy_restore_strategy(model_and_input,
-                                                 distribution_for_saving,
-                                                 distribution_for_restoring,
-                                                 save_in_scope, run_distributed)
-
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/python/distribute/keras_save_load_test.py b/tensorflow/python/distribute/keras_save_load_test.py
index fcb4941..45bf27a 100644
--- a/tensorflow/python/distribute/keras_save_load_test.py
+++ b/tensorflow/python/distribute/keras_save_load_test.py
@@ -34,30 +34,32 @@
     model.save(saved_dir, save_format='tf')
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, experimental_run_tf_function):
     restored_keras_model = save.load_model(saved_dir)
+    restored_keras_model._experimental_run_tf_function = (
+        experimental_run_tf_function)
     return restored_keras_model.predict(
         predict_dataset, steps=test_base.PREDICT_STEPS)
 
   @combinations.generate(test_base.simple_models_with_strategies())
   def test_save_no_strategy_restore_strategy(self, model_and_input,
-                                             distribution, run_distributed):
-    self.run_test_save_no_strategy_restore_strategy(model_and_input,
-                                                    distribution,
-                                                    run_distributed)
+                                             distribution,
+                                             experimental_run_tf_function):
+    self.run_test_save_no_strategy_restore_strategy(
+        model_and_input, distribution, experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategies(),
                          combinations.combine(save_in_scope=[True, False])))
   def test_save_strategy_restore_no_strategy(self, model_and_input,
                                              distribution, save_in_scope,
-                                             run_distributed):
+                                             experimental_run_tf_function):
     if save_in_scope:
       self.skipTest(('b/134703272 - Saving model in tf.distribute.Strategy ',
                      'scope is not supported.'))
-    self.run_test_save_strategy_restore_no_strategy(model_and_input,
-                                                    distribution, save_in_scope,
-                                                    run_distributed)
+    self.run_test_save_strategy_restore_no_strategy(
+        model_and_input, distribution, save_in_scope,
+        experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategy_pairs(),
@@ -65,14 +67,16 @@
   def test_save_strategy_restore_strategy(self, model_and_input,
                                           distribution_for_saving,
                                           distribution_for_restoring,
-                                          save_in_scope, run_distributed):
+                                          save_in_scope,
+                                          experimental_run_tf_function):
     if save_in_scope:
       self.skipTest(('b/134703272 - Saving model in tf.distribute.Strategy ',
                      'scope is not supported.'))
     self.run_test_save_strategy_restore_strategy(model_and_input,
                                                  distribution_for_saving,
                                                  distribution_for_restoring,
-                                                 save_in_scope, run_distributed)
+                                                 save_in_scope,
+                                                 experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/minimize_loss_test.py b/tensorflow/python/distribute/minimize_loss_test.py
index c9422ab..789ee97 100644
--- a/tensorflow/python/distribute/minimize_loss_test.py
+++ b/tensorflow/python/distribute/minimize_loss_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.layers import core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
@@ -161,6 +162,9 @@
               optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
               mode=["graph"]))
   def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
+    if (not context.executing_eagerly() and
+        control_flow_v2_toggles.control_flow_v2_enabled()):
+      self.skipTest("b/138751864")
     created_variables = []
     trainable_variables = []
 
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 0afbb83..a1d520d 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -342,7 +342,8 @@
   This strategy uses one replica per device and sync replication for its
   multi-GPU version.
 
-  The multi-worker version will be added in the future.
+  To use `MirroredStrategy` with multiple workers, please refer to
+  `tf.distribute.MultiWorkerMirroredStrategy`.
 
   Args:
     devices: a list of device strings.  If `None`, all available GPUs are used.
@@ -374,8 +375,22 @@
 
   def __init__(self, container_strategy, devices=None, cross_device_ops=None):
     super(MirroredExtended, self).__init__(container_strategy)
-    if devices is None:
-      devices = all_devices()
+    if context.executing_eagerly():
+      if devices and not _is_device_list_local(devices):
+        raise RuntimeError("In-graph multi-worker training with "
+                           "`MirroredStrategy` is not supported in eager mode.")
+      else:
+        if TFConfigClusterResolver().cluster_spec().as_dict():
+          # if you are executing in eager mode, only the single machine code
+          # path is supported.
+          logging.info("Initializing local devices since in-graph multi-worker "
+                       "training with `MirroredStrategy` is not supported in "
+                       "eager mode. TF_CONFIG will be ignored when "
+                       "when initializing `MirroredStrategy`.")
+        devices = devices or all_local_devices()
+    else:
+      devices = devices or all_devices()
+
     assert devices, ("Got an empty `devices` list and unable to recognize "
                      "any local devices.")
     self._cross_device_ops = cross_device_ops
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 8f94f39..5c8e9a7 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -615,6 +615,7 @@
       self.assertIsInstance(mirrored_var, values.MirroredVariable)
       self.evaluate(variables.global_variables_initializer())
       self.assertEqual(1.0, self.evaluate(mirrored_var))
+      self.assertIsNotNone(ops.tensor_id(mirrored_var))
       mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
       self.assertEqual(6.0, mirrored_var_result)
 
@@ -1163,45 +1164,50 @@
         context.num_gpus())
 
   def testMinimizeLossGraph(self):
-    strategy = mirrored_strategy.MirroredStrategy(
-        cross_device_ops=self._make_cross_device_ops())
-    strategy.configure(cluster_spec=self._cluster_spec)
-    self._test_minimize_loss_graph(strategy, learning_rate=0.05)
-
-  def testMinimizeLossGraphMirroredStrategy(self):
-    strategy = mirrored_strategy.MirroredStrategy(
-        mirrored_strategy.all_local_devices(),
-        cross_device_ops=self._make_cross_device_ops())
-    strategy.configure(cluster_spec=self._cluster_spec)
-    self._test_minimize_loss_graph(strategy, learning_rate=0.05)
-
-  def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
-    cluster_spec = {}
-    cluster_spec["chief"] = self._cluster_spec["chief"]
-    tf_config = {"cluster": cluster_spec}
-    with test.mock.patch.dict("os.environ",
-                              {"TF_CONFIG": json.dumps(tf_config)}):
-      strategy = mirrored_strategy.MirroredStrategy()
-      self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
-                            cross_device_ops_lib.NcclAllReduce)
-    self.skipTest("b/130551176, run the following once fixed.")
-    self._test_minimize_loss_graph(strategy, learning_rate=0.05)
-
-  def testInitializeFromTFConfig(self):
-    tf_config = {"cluster": self._cluster_spec}
-    with test.mock.patch.dict("os.environ",
-                              {"TF_CONFIG": json.dumps(tf_config)}):
+    with context.graph_mode():
       strategy = mirrored_strategy.MirroredStrategy(
           cross_device_ops=self._make_cross_device_ops())
-      self.assertEqual(
-          max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)
+      strategy.configure(cluster_spec=self._cluster_spec)
+      self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+  def testMinimizeLossGraphMirroredStrategy(self):
+    with context.graph_mode():
+      strategy = mirrored_strategy.MirroredStrategy(
+          mirrored_strategy.all_local_devices(),
+          cross_device_ops=self._make_cross_device_ops())
+      strategy.configure(cluster_spec=self._cluster_spec)
+      self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+  def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
+    with context.graph_mode():
+      cluster_spec = {}
+      cluster_spec["chief"] = self._cluster_spec["chief"]
+      tf_config = {"cluster": cluster_spec}
+      with test.mock.patch.dict("os.environ",
+                                {"TF_CONFIG": json.dumps(tf_config)}):
+        strategy = mirrored_strategy.MirroredStrategy()
+        self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
+                              cross_device_ops_lib.NcclAllReduce)
+      self.skipTest("b/130551176, run the following once fixed.")
+      self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+  def testInitializeFromTFConfig(self):
+    with context.graph_mode():
+      tf_config = {"cluster": self._cluster_spec}
+      with test.mock.patch.dict("os.environ",
+                                {"TF_CONFIG": json.dumps(tf_config)}):
+        strategy = mirrored_strategy.MirroredStrategy(
+            cross_device_ops=self._make_cross_device_ops())
+        self.assertEqual(
+            max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)
 
   def testSummaryForReplicaZeroOnly(self):
-    strategy = mirrored_strategy.MirroredStrategy(
-        mirrored_strategy.all_local_devices(),
-        cross_device_ops=self._make_cross_device_ops())
-    strategy.configure(cluster_spec=self._cluster_spec)
-    self._test_summary_for_replica_zero_only(strategy)
+    with context.graph_mode():
+      strategy = mirrored_strategy.MirroredStrategy(
+          mirrored_strategy.all_local_devices(),
+          cross_device_ops=self._make_cross_device_ops())
+      strategy.configure(cluster_spec=self._cluster_spec)
+      self._test_summary_for_replica_zero_only(strategy)
 
 
 class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/distribute/model_collection/simple_models.py b/tensorflow/python/distribute/model_collection/simple_models.py
index 5dd5fc2..6a95f06 100644
--- a/tensorflow/python/distribute/model_collection/simple_models.py
+++ b/tensorflow/python/distribute/model_collection/simple_models.py
@@ -22,9 +22,12 @@
 
 from tensorflow.python import keras
 from tensorflow.python.distribute.model_collection import model_collection_base
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.module import module
+from tensorflow.python.ops import variables
 
 _BATCH_SIZE = 10
 
@@ -49,13 +52,14 @@
 
     model = keras.Model(inputs=x, outputs=y)
     optimizer = gradient_descent.SGD(learning_rate=0.001)
-    run_distributed = kwargs.pop('run_distributed', None)
-    assert run_distributed is not None
+    experimental_run_tf_function = kwargs.pop('experimental_run_tf_function',
+                                              None)
+    assert experimental_run_tf_function is not None
     model.compile(
         loss='mse',
         metrics=['mae'],
         optimizer=optimizer,
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
 
     return model, output_name
 
@@ -77,13 +81,14 @@
         5, dtype=dtypes.float32, name=output_name, input_dim=3)
     model.add(y)
     optimizer = gradient_descent.SGD(learning_rate=0.001)
-    run_distributed = kwargs.pop('run_distributed', None)
-    assert run_distributed is not None
+    experimental_run_tf_function = kwargs.pop('experimental_run_tf_function',
+                                              None)
+    assert experimental_run_tf_function is not None
     model.compile(
         loss='mse',
         metrics=['mae'],
         optimizer=optimizer,
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
 
     return model, output_name
 
@@ -112,14 +117,15 @@
   def get_model(self, **kwargs):
     model = _SimpleModel()
     optimizer = gradient_descent.SGD(learning_rate=0.001)
-    run_distributed = kwargs.pop('run_distributed', None)
-    assert run_distributed is not None
+    experimental_run_tf_function = kwargs.pop('experimental_run_tf_function',
+                                              None)
+    assert experimental_run_tf_function is not None
     model.compile(
         loss='mse',
         metrics=['mae'],
         cloning=False,
         optimizer=optimizer,
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
 
     return model, model.output_name
 
@@ -128,3 +134,27 @@
 
   def get_batch_size(self):
     return _BATCH_SIZE
+
+
+class _SimpleModule(module.Module):
+
+  def __init__(self):
+    self.v = variables.Variable(3.0)
+
+  @def_function.function
+  def __call__(self, x):
+    return self.v * x
+
+
+class SimpleTFModuleModel(model_collection_base.ModelAndInput):
+  """A simple model based on tf.Module and its data."""
+
+  def get_model(self, **kwargs):
+    model = _SimpleModule()
+    return model, 'foo'
+
+  def get_data(self):
+    return _get_data_for_simple_models()
+
+  def get_batch_size(self):
+    return _BATCH_SIZE
diff --git a/tensorflow/python/distribute/model_combinations.py b/tensorflow/python/distribute/model_combinations.py
index 798bf11..2d8ca79 100644
--- a/tensorflow/python/distribute/model_combinations.py
+++ b/tensorflow/python/distribute/model_combinations.py
@@ -29,3 +29,6 @@
 
 simple_subclass_model = combinations.NamedObject(
     "SimpleSubclassModel", simple_models.SimpleSubclassModel())
+
+simple_tfmodule_model = combinations.NamedObject(
+    "SimpleTFModuleModel", simple_models.SimpleTFModuleModel())
diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py
index dc03de8..096723a 100644
--- a/tensorflow/python/distribute/multi_worker_test_base.py
+++ b/tensorflow/python/distribute/multi_worker_test_base.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import contextlib
 import copy
 import json
@@ -49,6 +48,7 @@
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 
 original_run_std_server = dc._run_std_server  # pylint: disable=protected-access
@@ -359,7 +359,7 @@
     self._coord.join(threads)
 
 
-class MockOsEnv(collections.Mapping):
+class MockOsEnv(collections_abc.Mapping):
   """A class that allows per-thread TF_CONFIG."""
 
   def __init__(self, *args):
diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
index 918e025..983cd5f 100644
--- a/tensorflow/python/distribute/multi_worker_util.py
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -18,9 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import json
-import os
-
 from tensorflow.core.protobuf import cluster_pb2
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.training import server_lib
@@ -226,15 +223,6 @@
   raise ValueError("There is no id for task_type %r" % task_type)
 
 
-def in_multi_worker_mode():
-  """Whether the program is operating in Multi-Worker setting."""
-  # TODO(rchao): Consider a warning if user uses multiple `model` method
-  # calls in multi-worker setting.
-  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
-  cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
-  return tf_config and "master" not in cluster_spec.jobs
-
-
 def should_save_checkpoint():
   """Returns whether the current worker should save checkpoints.
 
diff --git a/tensorflow/python/distribute/saved_model_mixed_api_test.py b/tensorflow/python/distribute/saved_model_mixed_api_test.py
index 834cfbb..74d208d 100644
--- a/tensorflow/python/distribute/saved_model_mixed_api_test.py
+++ b/tensorflow/python/distribute/saved_model_mixed_api_test.py
@@ -42,30 +42,30 @@
     keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, experimental_run_tf_function):
     return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                        predict_dataset,
                                                        output_name)
 
   @combinations.generate(test_base.simple_models_with_strategies())
   def test_save_no_strategy_restore_strategy(self, model_and_input,
-                                             distribution, run_distributed):
-    self.run_test_save_no_strategy_restore_strategy(model_and_input,
-                                                    distribution,
-                                                    run_distributed)
+                                             distribution,
+                                             experimental_run_tf_function):
+    self.run_test_save_no_strategy_restore_strategy(
+        model_and_input, distribution, experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategies(),
                          combinations.combine(save_in_scope=[True, False])))
   def test_save_strategy_restore_no_strategy(self, model_and_input,
                                              distribution, save_in_scope,
-                                             run_distributed):
+                                             experimental_run_tf_function):
     if save_in_scope:
       self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
                      'supported.'))
-    self.run_test_save_strategy_restore_no_strategy(model_and_input,
-                                                    distribution, save_in_scope,
-                                                    run_distributed)
+    self.run_test_save_strategy_restore_no_strategy(
+        model_and_input, distribution, save_in_scope,
+        experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategy_pairs(),
@@ -73,14 +73,16 @@
   def test_save_strategy_restore_strategy(self, model_and_input,
                                           distribution_for_saving,
                                           distribution_for_restoring,
-                                          save_in_scope, run_distributed):
+                                          save_in_scope,
+                                          experimental_run_tf_function):
     if save_in_scope:
       self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
                      'supported.'))
     self.run_test_save_strategy_restore_strategy(model_and_input,
                                                  distribution_for_saving,
                                                  distribution_for_restoring,
-                                                 save_in_scope, run_distributed)
+                                                 save_in_scope,
+                                                 experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/saved_model_save_load_test.py b/tensorflow/python/distribute/saved_model_save_load_test.py
index 6c0b246..04903f1 100644
--- a/tensorflow/python/distribute/saved_model_save_load_test.py
+++ b/tensorflow/python/distribute/saved_model_save_load_test.py
@@ -21,43 +21,46 @@
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import saved_model_test_base as test_base
 from tensorflow.python.eager import test
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import array_ops
 from tensorflow.python.saved_model import saved_model
 
 
-class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
+class SavedModelKerasModelTest(test_base.TestSavedModelBase):
 
   def setUp(self):
     self._root_dir = 'saved_model_save_load'
-    super(SavedModelSaveAndLoadTest, self).setUp()
+    super(SavedModelKerasModelTest, self).setUp()
 
   def _save_model(self, model, saved_dir):
     saved_model.save(model, saved_dir)
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, experimental_run_tf_function):
     return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                        predict_dataset,
                                                        output_name)
 
   @combinations.generate(test_base.simple_models_with_strategies())
   def test_save_no_strategy_restore_strategy(self, model_and_input,
-                                             distribution, run_distributed):
-    self.run_test_save_no_strategy_restore_strategy(model_and_input,
-                                                    distribution,
-                                                    run_distributed)
+                                             distribution,
+                                             experimental_run_tf_function):
+    self.run_test_save_no_strategy_restore_strategy(
+        model_and_input, distribution, experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategies(),
                          combinations.combine(save_in_scope=[True, False])))
   def test_save_strategy_restore_no_strategy(self, model_and_input,
                                              distribution, save_in_scope,
-                                             run_distributed):
+                                             experimental_run_tf_function):
     if save_in_scope:
+      # TODO(b/134703272): Unskip this test when fixed.
       self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
                      'supported.'))
-    self.run_test_save_strategy_restore_no_strategy(model_and_input,
-                                                    distribution, save_in_scope,
-                                                    run_distributed)
+    self.run_test_save_strategy_restore_no_strategy(
+        model_and_input, distribution, save_in_scope,
+        experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(test_base.simple_models_with_strategy_pairs(),
@@ -65,14 +68,90 @@
   def test_save_strategy_restore_strategy(self, model_and_input,
                                           distribution_for_saving,
                                           distribution_for_restoring,
-                                          save_in_scope, run_distributed):
+                                          save_in_scope,
+                                          experimental_run_tf_function):
     if save_in_scope:
+      # TODO(b/134703272): Unskip this test when fixed.
       self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
                      'supported.'))
     self.run_test_save_strategy_restore_strategy(model_and_input,
                                                  distribution_for_saving,
                                                  distribution_for_restoring,
-                                                 save_in_scope, run_distributed)
+                                                 save_in_scope,
+                                                 experimental_run_tf_function)
+
+
+class SavedModelTFModuleTest(test_base.TestSavedModelBase):
+
+  def setUp(self):
+    self._root_dir = 'saved_model_save_load'
+    super(SavedModelTFModuleTest, self).setUp()
+
+  def _train_model(self, model, x_train, y_train, batch_size):
+    pass
+
+  def _predict_with_model(self, distribution, model, predict_dataset):
+    if distribution:
+      dist_predict_dataset = distribution.experimental_distribute_dataset(
+          predict_dataset)
+      per_replica_predict_data = next(iter(dist_predict_dataset))
+      result = distribution.experimental_run_v2(
+          model, args=(per_replica_predict_data,))
+      # Convert the per_replica value to a list, then concatenate them
+      reduced = distribution.experimental_local_results(result)
+      concat = array_ops.concat(reduced, 0)
+      return concat
+    else:
+      return model(next(iter(predict_dataset)))
+
+  def _save_model(self, model, saved_dir):
+    call = model.__call__.get_concrete_function(tensor_spec.TensorSpec(None))
+    saved_model.save(model, saved_dir, signatures=call)
+
+  def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
+                          output_name, experimental_run_tf_function):
+    del output_name, experimental_run_tf_function
+    model = saved_model.load(saved_dir)
+    return self._predict_with_model(distribution, model, predict_dataset)
+
+  @combinations.generate(test_base.tfmodule_models_with_strategies())
+  def test_save_no_strategy_restore_strategy(self, model_and_input,
+                                             distribution,
+                                             experimental_run_tf_function):
+    self.run_test_save_no_strategy_restore_strategy(
+        model_and_input, distribution, experimental_run_tf_function)
+
+  @combinations.generate(
+      combinations.times(test_base.tfmodule_models_with_strategies(),
+                         combinations.combine(save_in_scope=[True, False])))
+  def test_save_strategy_restore_no_strategy(
+      self, model_and_input, distribution, save_in_scope,
+      experimental_run_tf_function):
+    if save_in_scope:
+      # TODO(b/134703272): Unskip this test when fixed.
+      self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
+                     'supported.'))
+    self.run_test_save_strategy_restore_no_strategy(
+        model_and_input, distribution, save_in_scope,
+        experimental_run_tf_function)
+
+  @combinations.generate(
+      combinations.times(test_base.tfmodule_models_with_strategy_pairs(),
+                         combinations.combine(save_in_scope=[True, False])))
+  def test_save_strategy_restore_strategy(self, model_and_input,
+                                          distribution_for_saving,
+                                          distribution_for_restoring,
+                                          save_in_scope,
+                                          experimental_run_tf_function):
+    if save_in_scope:
+      # TODO(b/134703272): Unskip this test when fixed.
+      self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
+                     'supported.'))
+    self.run_test_save_strategy_restore_strategy(model_and_input,
+                                                 distribution_for_saving,
+                                                 distribution_for_restoring,
+                                                 save_in_scope,
+                                                 experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/saved_model_test_base.py b/tensorflow/python/distribute/saved_model_test_base.py
index c17c0e3..1001dd4 100644
--- a/tensorflow/python/distribute/saved_model_test_base.py
+++ b/tensorflow/python/distribute/saved_model_test_base.py
@@ -35,6 +35,9 @@
 _DEFAULT_FUNCTION_KEY = 'serving_default'
 
 _TOLERANCE = 1e-30
+# TPU uses bfloat16 for computation in hardware underlying, so it has less
+# precision than CPU/GPU.
+_TPU_TOLERANCE = 1e-7
 
 PREDICT_STEPS = 1
 
@@ -47,32 +50,62 @@
 ]
 
 
-strategies_minus_tpu = [
+strategies = [
     # TODO(b/132702156): include default strategy
     strategy_combinations.one_device_strategy,
     strategy_combinations.one_device_strategy_gpu,
     strategy_combinations.mirrored_strategy_with_one_cpu,
     strategy_combinations.mirrored_strategy_with_one_gpu,
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-    strategy_combinations.mirrored_strategy_with_two_gpus
+    strategy_combinations.mirrored_strategy_with_two_gpus,
+    strategy_combinations.tpu_strategy
 ]
 
 
+def is_tpu_strategy(distribution):
+  return (distribution is not None and
+          distribution.__class__.__name__.startswith('TPUStrategy'))
+
+
+def get_tolerance(save_distribution, restore_distribution):
+  if is_tpu_strategy(save_distribution) or is_tpu_strategy(
+      restore_distribution):
+    return _TPU_TOLERANCE
+  return _TOLERANCE
+
+
 def simple_models_with_strategies():
   return combinations.combine(
       model_and_input=simple_models,
-      distribution=strategies_minus_tpu,
+      distribution=strategies,
       mode=['eager'],
-      run_distributed=[True, False])
+      experimental_run_tf_function=[True, False])
 
 
 def simple_models_with_strategy_pairs():
   return combinations.combine(
       model_and_input=simple_models,
-      distribution_for_saving=strategies_minus_tpu,
-      distribution_for_restoring=strategies_minus_tpu,
+      distribution_for_saving=strategies,
+      distribution_for_restoring=strategies,
       mode=['eager'],
-      run_distributed=[True, False])
+      experimental_run_tf_function=[True, False])
+
+
+def tfmodule_models_with_strategies():
+  return combinations.combine(
+      model_and_input=[model_combinations.simple_tfmodule_model],
+      distribution=strategies,
+      mode=['eager'],
+      experimental_run_tf_function=[True])
+
+
+def tfmodule_models_with_strategy_pairs():
+  return combinations.combine(
+      model_and_input=[model_combinations.simple_tfmodule_model],
+      distribution_for_saving=strategies,
+      distribution_for_restoring=strategies,
+      mode=['eager'],
+      experimental_run_tf_function=[True])
 
 
 def load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset,
@@ -118,7 +151,7 @@
     raise NotImplementedError('must be implemented in descendants')
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, experimental_run_tf_function):
     """Load the model and run 1 step of predict with it.
 
     This method must be implemented by the subclasses.
@@ -131,6 +164,8 @@
         cross_replica context.
       output_name: the string representing the name of the output layer of the
         model.
+      experimental_run_tf_function: Whether to use the single execution path
+        for models.
     """
 
     raise NotImplementedError('must be implemented in descendants')
@@ -144,6 +179,9 @@
     # Train the model for 1 epoch
     model.fit(x=training_dataset, epochs=1, steps_per_epoch=100)
 
+  def _predict_with_model(self, distribution, model, predict_dataset):
+    return model.predict(predict_dataset, steps=PREDICT_STEPS)
+
   def _get_predict_dataset(self, x_predict, batch_size):
     predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
     predict_dataset = predict_dataset.repeat()
@@ -151,20 +189,20 @@
     return predict_dataset
 
   def run_test_save_no_strategy_restore_strategy(self, model_and_input,
-                                                 distribution, run_distributed):
+                                                 distribution,
+                                                 experimental_run_tf_function):
     """Save a model without DS, and restore it with DS."""
 
-    saved_dir = os.path.join(self.get_temp_dir(), self._root_dir,
-                             'test_save_no_dist_restore_dist')
+    saved_dir = os.path.join(self.get_temp_dir(), '0')
 
     model, output_name = model_and_input.get_model(
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
     x_train, y_train, x_predict = model_and_input.get_data()
     batch_size = model_and_input.get_batch_size()
+    predict_dataset = self._get_predict_dataset(x_predict, batch_size)
 
     self._train_model(model, x_train, y_train, batch_size)
-    predict_dataset = self._get_predict_dataset(x_predict, batch_size)
-    result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
+    result_before_save = self._predict_with_model(None, model, predict_dataset)
 
     self._save_model(model, saved_dir)
 
@@ -173,27 +211,29 @@
           distribution=distribution,
           saved_dir=saved_dir,
           predict_dataset=predict_dataset,
-          output_name=output_name)
+          output_name=output_name,
+          experimental_run_tf_function=experimental_run_tf_function)
 
-    self.assertAllClose(result_before_save, result_after_save, atol=_TOLERANCE)
+    tolerance = get_tolerance(None, distribution)
+    self.assertAllClose(result_before_save, result_after_save, atol=tolerance)
 
   def run_test_save_strategy_restore_no_strategy(self, model_and_input,
                                                  distribution, save_in_scope,
-                                                 run_distributed):
+                                                 experimental_run_tf_function):
     """Save a model with DS, and restore it without DS."""
 
-    saved_dir = os.path.join(self.get_temp_dir(), self._root_dir,
-                             'test_save_no_dist_restore_dist')
+    saved_dir = os.path.join(self.get_temp_dir(), '1')
 
     with distribution.scope():
       model, output_name = model_and_input.get_model(
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       x_train, y_train, x_predict = model_and_input.get_data()
       batch_size = model_and_input.get_batch_size()
 
       self._train_model(model, x_train, y_train, batch_size)
       predict_dataset = self._get_predict_dataset(x_predict, batch_size)
-      result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
+      result_before_save = self._predict_with_model(
+          distribution, model, predict_dataset)
 
     if save_in_scope:
       with distribution.scope():
@@ -205,28 +245,30 @@
         distribution=None,
         saved_dir=saved_dir,
         predict_dataset=predict_dataset,
-        output_name=output_name)
+        output_name=output_name,
+        experimental_run_tf_function=experimental_run_tf_function)
 
-    self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
+    tolerance = get_tolerance(distribution, None)
+    self.assertAllClose(result_before_save, load_result, atol=tolerance)
 
   def run_test_save_strategy_restore_strategy(self, model_and_input,
                                               distribution_for_saving,
                                               distribution_for_restoring,
-                                              save_in_scope, run_distributed):
+                                              save_in_scope,
+                                              experimental_run_tf_function):
     """Save a model with DS, and restore it with potentially different DS."""
-
-    saved_dir = os.path.join(self.get_temp_dir(), self._root_dir,
-                             'test_save_dist_restore_dist')
+    saved_dir = os.path.join(self.get_temp_dir(), '2')
 
     with distribution_for_saving.scope():
       model, output_name = model_and_input.get_model(
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       x_train, y_train, x_predict = model_and_input.get_data()
       batch_size = model_and_input.get_batch_size()
 
       self._train_model(model, x_train, y_train, batch_size)
       predict_dataset = self._get_predict_dataset(x_predict, batch_size)
-      result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
+      result_before_save = self._predict_with_model(
+          distribution_for_saving, model, predict_dataset)
 
     if save_in_scope:
       with distribution_for_saving.scope():
@@ -240,6 +282,9 @@
           distribution=distribution_for_restoring,
           saved_dir=saved_dir,
           predict_dataset=predict_dataset,
-          output_name=output_name)
+          output_name=output_name,
+          experimental_run_tf_function=experimental_run_tf_function)
 
-    self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
+    tolerance = get_tolerance(distribution_for_saving,
+                              distribution_for_restoring)
+    self.assertAllClose(result_before_save, load_result, atol=tolerance)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 05efd2c..cbed12b 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -613,6 +613,7 @@
     # We need to make _keras_initialized a member of DistributedVariable because
     # without this it will use `__getattr__` which will delegate to a component
     # variable.
+    self._id = ops.uid()
     self._keras_initialized = False
     # Typically, a `DistributedVariable`'s initializer is composed of the
     # initializers of the components variables. However, in some cases, such as
@@ -776,6 +777,9 @@
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
 
+  def _clone_with_new_values(self, new_values):
+    raise NotImplementedError("Must be implemented in descendents.")
+
 
 ops.register_dense_tensor_like_type(DistributedVariable)
 
@@ -1069,6 +1073,10 @@
     return ops.internal_convert_to_tensor(
         self.get(), dtype=dtype, name=name, as_ref=as_ref)
 
+  def _clone_with_new_values(self, new_values):
+    return type(self)(self._distribute_strategy, self._device_map, new_values,
+                      self._aggregation, logical_device=self._logical_device)
+
 
 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
@@ -1159,7 +1167,7 @@
         "Replica-local variables may only be assigned in a replica context.")
 
 
-class SyncOnReadVariable(DistributedVariable, PerReplica):
+class SyncOnReadVariable(DistributedVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
   def __init__(
@@ -1245,6 +1253,10 @@
     return ops.internal_convert_to_tensor(
         self.get(), dtype=dtype, name=name, as_ref=as_ref)
 
+  def _clone_with_new_values(self, new_values):
+    return type(self)(self._distribute_strategy, self._device_map, new_values,
+                      self._aggregation, logical_device=self._logical_device)
+
 
 # Register a conversion function for SyncOnReadVariable which allows as_ref to
 # be true.
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 753f3f3..fe56dae 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -757,6 +757,19 @@
       self.assertIsInstance(converted, ops.Tensor)
       self.assertEqual(converted.dtype, replica_local.dtype)
 
+  @test_util.run_v2_only
+  def testCanPassToDefFun(self):
+    @def_function.function
+    def add1(x):
+      return x + 1
+
+    v = variable_scope.get_variable(
+        name="v", initializer=[1.], use_resource=True)
+    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
+    replica_local = values.SyncOnReadVariable(
+        None, device_map, (v,), variable_scope.VariableAggregation.MEAN)
+    self.assertEqual(2., self.evaluate(add1(replica_local)))
+
 
 @combinations.generate(
     combinations.combine(
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index b58bf18..91615a9 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -14,10 +14,12 @@
     name = "pywrap_tfe_lib",
     srcs = [
         "pywrap_tensor.cc",
+        "pywrap_tensor_conversion.cc",
         "pywrap_tfe_src.cc",
     ],
     hdrs = [
         "pywrap_tensor.h",
+        "pywrap_tensor_conversion.h",
         "pywrap_tfe.h",
     ],
     visibility = [
@@ -34,6 +36,9 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:types",
         "//tensorflow/python:cpp_python_util",
         "//tensorflow/python:ndarray_tensor",
         "//tensorflow/python:ndarray_tensor_bridge",
@@ -42,6 +47,8 @@
         "//tensorflow/python:safe_ptr",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:variant",
     ],
@@ -106,11 +113,22 @@
 )
 
 py_library(
+    name = "executor",
+    srcs = ["executor.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:pywrap_tensorflow",
+    ],
+)
+
+py_library(
     name = "context",
     srcs = ["context.py"],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":executor",
         ":monitoring",
         "//tensorflow/python:device",
         "//tensorflow/python:device_spec",
@@ -192,6 +210,7 @@
 py_test(
     name = "profiler_client_test",
     srcs = ["profiler_client_test.py"],
+    python_version = "PY2",
     srcs_version = "PY2AND3",
     tags = ["no_pip"],
     visibility = ["//tensorflow:internal"],
@@ -256,6 +275,7 @@
 
 cuda_py_test(
     name = "core_test",
+    size = "small",
     srcs = ["core_test.py"],
     additional_deps = [
         ":context",
@@ -438,10 +458,10 @@
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":execute",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:tensor_shape",
-        "//tensorflow/python/eager:execute",
     ],
 )
 
@@ -451,7 +471,11 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":context",
+        ":core",
+        ":execute",
         ":graph_only_ops",
+        ":tape",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
@@ -459,10 +483,6 @@
         "//tensorflow/python:gradients_impl",
         "//tensorflow/python:graph_to_function_def",
         "//tensorflow/python:util",
-        "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:core",
-        "//tensorflow/python/eager:execute",
-        "//tensorflow/python/eager:tape",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -474,7 +494,10 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":context",
+        ":execute",
         ":imperative_grad",
+        ":tape",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
@@ -485,9 +508,6 @@
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:unconnected_gradients",
         "//tensorflow/python:util",
-        "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:execute",
-        "//tensorflow/python/eager:tape",
         "//tensorflow/python/ops/parallel_for:control_flow_ops",
         "@six_archive//:six",
     ],
@@ -655,6 +675,22 @@
     ],
 )
 
+cuda_py_test(
+    name = "def_function_xla_jit_test",
+    srcs = ["def_function_xla_jit_test.py"],
+    additional_deps = [
+        ":def_function",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+    ],
+    tags = [
+        "no_mac",
+        "no_windows",
+    ],
+    xla_enabled = True,
+)
+
 tf_xla_py_test(
     name = "def_function_xla_test",
     srcs = ["def_function_xla_test.py"],
@@ -684,6 +720,7 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:template",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/saved_model:nested_structure_coder",
         "//tensorflow/python/training/tracking:base",
     ],
 )
@@ -704,9 +741,9 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":context",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:platform",
-        "//tensorflow/python/eager:context",
     ],
 )
 
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index cebf0a8..e2ef2ba 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -1128,7 +1128,7 @@
     See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) for the
     definition of a Jacobian. This function is essentially an efficient
     implementation of the following:
-    
+
     `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
 
     Note that compared to `GradientTape.jacobian` which computes gradient of
@@ -1146,7 +1146,7 @@
       x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
       g.watch(x)
       y = x * x
-    batch_jacobian = g.batch_jacobian(y, x) 
+    batch_jacobian = g.batch_jacobian(y, x)
     # batch_jacobian is [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]
     ```
 
@@ -1229,10 +1229,11 @@
             " with experimental_use_pfor set to False.")
       output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
                                  parallel_iterations=parallel_iterations)
-    if output is None:
-      return None
-    output = array_ops.reshape(output,
-                               [target_row_size, batch_size, -1])
-    output = array_ops.transpose(output, [1, 0, 2])
     new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
-    return array_ops.reshape(output, new_shape)
+    if output is None:
+      return array_ops.zeros(new_shape)
+    else:
+      output = array_ops.reshape(output,
+                                 [target_row_size, batch_size, -1])
+      output = array_ops.transpose(output, [1, 0, 2])
+      return array_ops.reshape(output, new_shape)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 30d3a4f..248d161 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -18,6 +18,7 @@
 
 import functools
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
@@ -128,7 +129,7 @@
         _ = v + 1.0  # This reads the variable inside the loop context
         with backprop.GradientTape() as t:
           result = v * 2
-        self.assertTrue(t.gradient(result, v) is not None)
+        self.assertIsNotNone(t.gradient(result, v))
         return 1.0
 
       control_flow_ops.while_loop(lambda i: False, body, [1.0])
@@ -268,8 +269,8 @@
 
     grads = backprop.implicit_grad(f)()
     ordered_variables = [x[1] for x in grads]
-    self.assertTrue(ordered_variables[0] is v0)
-    self.assertTrue(ordered_variables[1] is v1)
+    self.assertIs(ordered_variables[0], v0)
+    self.assertIs(ordered_variables[1], v1)
 
   def testTapeNoOpGradient(self):
     x = constant_op.constant(3.0)
@@ -1482,7 +1483,7 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class BatchJacobianTest(test.TestCase):
+class BatchJacobianTest(test.TestCase, parameterized.TestCase):
 
   def _batch_jacobian(self, experimental_use_pfor):
     persistent = context.executing_eagerly and not experimental_use_pfor
@@ -1583,6 +1584,23 @@
     self.assertAllClose(g.batch_jacobian(y, x, parallel_iterations=2),
                         g.batch_jacobian(y, x, parallel_iterations=3))
 
+  @parameterized.parameters(
+      (True, True),
+      (True, False),
+      (False, True),
+      (False, False))
+  def test_degenerate_shape(self, use_function, use_pfor):
+
+    def f(x):
+      with backprop.GradientTape(persistent=True) as tape:
+        tape.watch(x)
+        y = x**2
+      return tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
+
+    if use_function:
+      f = def_function.function(f)
+    self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
+
 
 class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 615e8a8..929432d 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -184,10 +184,10 @@
     ctx = context.context()
     if device == GPU:
       # Warmup the GPU
-      ops.EagerTensor(value, context=ctx, device=device)
+      ops.EagerTensor(value, device=device)
 
     def func():
-      ops.EagerTensor(value, context=ctx, device=device, dtype=dtype)
+      ops.EagerTensor(value, device=device, dtype=dtype)
 
     self._run(func, 30000)
 
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 245228d..c0d1881 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -30,6 +30,7 @@
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import tf2
+from tensorflow.python.eager import executor
 from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import device as pydev
@@ -64,10 +65,9 @@
 MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
 MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
 
-_tf2_gauge = monitoring.BoolGauge("/tensorflow/api/tf2_enable",
-                                  "Whether tf2.enable() is called.")
-
-_tf2_gauge.get_cell().set(tf2.enabled())
+_python_eager_context_create_counter = monitoring.Counter(
+    "/tensorflow/api/python/eager_context_create_counter",
+    "Counter for number of eager contexts created in Python.")
 
 
 class _EagerTensorCache(object):
@@ -155,7 +155,6 @@
 
   def __init__(self):
     super(_TensorCaches, self).__init__()
-    self.scalar_cache = {}
     self._ones_rank_cache = None
     self._zeros_cache = None
 
@@ -186,8 +185,8 @@
     self.summary_recording = None
     self.summary_recording_distribution_strategy = True
     self.summary_step = None
-    self.execution_mode = SYNC
     self.function_call_options = None
+    self.executor = None
 
 
 ContextSwitch = collections.namedtuple(
@@ -383,6 +382,7 @@
     self._post_execution_callbacks = []
     self._seed = None
     self._initialize_lock = threading.Lock()
+    self._initialized = False
     if device_policy is None:
       device_policy = DEVICE_PLACEMENT_SILENT
     self._device_policy = device_policy
@@ -392,7 +392,7 @@
           "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
     if execution_mode is None:
       execution_mode = SYNC
-    self._execution_mode = execution_mode
+    self._default_is_async = execution_mode == ASYNC
     self._server_def = server_def
     self._collective_ops_server_def = None
     self._collective_leader = None
@@ -414,6 +414,7 @@
     self._log_device_placement = None
     self._optimizer_experimental_options = {}
 
+    _python_eager_context_create_counter.get_cell().increase_by(1)
   # pylint: enable=redefined-outer-name
 
   def _set_global_seed(self, seed):
@@ -467,8 +468,10 @@
 
   def ensure_initialized(self):
     """Initialize handle and devices if not already done so."""
+    if self._initialized:
+      return
     with self._initialize_lock:
-      if self._context_handle is not None:
+      if self._initialized:
         return
       assert self._context_devices is None
       opts = pywrap_tensorflow.TFE_NewContextOptions()
@@ -481,9 +484,9 @@
         if self._mirroring_policy is not None:
           pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy(
               opts, self._mirroring_policy)
-        if self._execution_mode == ASYNC:
+        if self._default_is_async == ASYNC:
           pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
-        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
+        context_handle = pywrap_tensorflow.TFE_NewContext(opts)
       finally:
         pywrap_tensorflow.TFE_DeleteContextOptions(opts)
       assert not (self._server_def and self._collective_ops_server_def), (
@@ -491,19 +494,21 @@
           "moment. If this is important to you, please file an issue.")
       if self._server_def is not None:
         server_def_str = self._server_def.SerializeToString()
-        pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600,
+        pywrap_tensorflow.TFE_ContextSetServerDef(context_handle, 600,
                                                   server_def_str)
       elif self._collective_ops_server_def is not None:
         server_def_str = self._collective_ops_server_def.SerializeToString()
-        pywrap_tensorflow.TFE_EnableCollectiveOps(self._context_handle,
+        pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle,
                                                   server_def_str)
 
+      self._context_handle = context_handle
       self._initialize_logical_devices()
+      self._initialized = True
 
   def _clear_caches(self):
-    self.scalar_cache().clear()
     self.ones_rank_cache().flush()
     self.zeros_cache().flush()
+    pywrap_tensorflow.TFE_ClearScalarCache()
 
   def set_server_def(self, server_def, keep_alive_secs=600):
     """Allow setting a server_def on the context.
@@ -533,12 +538,11 @@
       server_def_str = server_def.SerializeToString()
       pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
                                                 keep_alive_secs, server_def_str)
-
-      # Clear all the caches in case there are remote tensors in them.
-      self._clear_caches()
-
       self._initialize_logical_devices()
 
+    # Clear all the caches in case there are remote tensors in them.
+    self._clear_caches()
+
   def enable_collective_ops(self, server_def):
     """Enable distributed collective ops with an appropriate server_def.
 
@@ -650,10 +654,6 @@
     """Returns True if current thread has eager executing enabled."""
     return self._thread_local_data.is_eager
 
-  def scalar_cache(self):
-    """Per-device cache for scalars."""
-    return _tensor_caches_map[self._id].scalar_cache
-
   def ones_rank_cache(self):
     """Per-device cache for scalars."""
     return _tensor_caches_map[self._id].ones_rank_cache
@@ -745,18 +745,11 @@
     """List of the names of devices available to execute operations."""
     return self._devices
 
+  # TODO(fishx): remove this property.
   @property
   def execution_mode(self):
     """Gets execution mode for current thread."""
-    # Only get the execution mode from the context if it has already been
-    # initialized
-    if self._context_handle is None:
-      return self._execution_mode
-
-    mode = self._thread_local_data.execution_mode
-    if mode is None:
-      mode = self._execution_mode
-    return mode
+    return ASYNC if self.is_async() else SYNC
 
   @execution_mode.setter
   def execution_mode(self, mode):
@@ -764,18 +757,39 @@
     if mode not in (None, SYNC, ASYNC):
       raise ValueError(
           "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
+
     if mode is None:
       mode = SYNC
 
-    if self._thread_local_data.execution_mode != mode:
-      self._thread_local_data.execution_mode = mode
-
+    enable_async = (mode == ASYNC)
+    if self.is_async() != enable_async:
       # Only set the execution mode if the context has already been initialized
       if self._context_handle is not None:
-        pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._context_handle,
-                                                       mode == ASYNC)
+        self.executor.wait()
+        executor_new = executor.new_executor(enable_async)
+        self._thread_local_data.executor = executor_new
+        pywrap_tensorflow.TFE_ContextSetExecutorForThread(
+            self._context_handle, executor_new.handle())
       else:
-        self._execution_mode = mode
+        self._default_is_async = enable_async
+
+  def is_async(self):
+    if self._context_handle is not None:
+      return self.executor.is_async()
+    else:
+      return self._default_is_async
+
+  @property
+  def executor(self):
+    ensure_initialized()
+    return executor.Executor(
+        pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle))
+
+  @executor.setter
+  def executor(self, e):
+    ensure_initialized()
+    pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle,
+                                                      e.handle())
 
   @property
   def config(self):
@@ -943,14 +957,6 @@
     """Returns function call options for current thread."""
     self._thread_local_data.function_call_options = options
 
-  def async_wait(self):
-    """Waits for ops dispatched in ASYNC mode to finish."""
-    pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)
-
-  def async_clear_error(self):
-    """Clears errors raised during ASYNC execution."""
-    pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle)
-
   def num_gpus(self):
     """The number of GPUs available to execute operations."""
     self.ensure_initialized()
@@ -1730,16 +1736,40 @@
   context().execution_mode = mode
 
 
+# TODO(fishx): remove this method.
 @tf_contextlib.contextmanager
 def execution_mode(mode):
   """Context manager for setting execution mode for current thread."""
   ctx = context()
-  old_mode = ctx.execution_mode
+  executor_new = executor.new_executor(mode == ASYNC)
+  executor_old = ctx.executor
   try:
-    ctx.execution_mode = mode
+    executor_old.wait()
+    ctx.executor = executor_new
     yield
   finally:
-    ctx.execution_mode = old_mode
+    ctx.executor = executor_old
+    executor_new.wait()
+
+
+@tf_contextlib.contextmanager
+def executor_scope(e):
+  """Context manager for changing executor for current thread.
+
+  Args:
+    e: A Executor to execute eager ops under this scope. Setting it to None will
+      switch back to use the default executor for the context.
+
+  Yields:
+    Context manager for setting the executor for current thread.
+  """
+  ctx = context()
+  executor_old = ctx.executor
+  try:
+    ctx.executor = e
+    yield
+  finally:
+    ctx.executor = executor_old
 
 
 @tf_export("experimental.function_executor_type")
@@ -1765,14 +1795,19 @@
     context().function_call_options = old_options
 
 
+def is_async():
+  """Returns true if current thread is in async mode."""
+  return context().is_async()
+
+
 def async_wait():
   """Waits for ops dispatched in ASYNC mode to finish."""
-  return context().async_wait()
+  return context().executor.wait()
 
 
 def async_clear_error():
   """Clears errors raised during ASYNC execution mode."""
-  return context().async_clear_error()
+  return context().executor.clear_error()
 
 
 def num_gpus():
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index f2e77fe..cfe2dd0 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.eager import core
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute as execute_lib
+from tensorflow.python.eager import executor
 from tensorflow.python.eager import test
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
@@ -127,7 +128,7 @@
       self._test_hashable(variable_a, variable_b, True)
       ops.enable_tensor_equality()
       _v2_check(variable_a, variable_b)
-      self._test_hashable(variable_a, variable_b, True)
+      self._test_hashable(variable_a, variable_b, False)
 
       # We only test numpy behaviour in v2 mode since we'd like to match that.
       numpy_a = np.array(1.0)
@@ -178,7 +179,7 @@
       self._test_hashable(variable_a, variable_b, True)
       ops.enable_tensor_equality()
       _v2_check(variable_a, variable_b)
-      self._test_hashable(variable_a, variable_b, True)
+      self._test_hashable(variable_a, variable_b, False)
 
       numpy_a = np.array(float('nan'))
       numpy_b = np.array(float('nan'))
@@ -223,12 +224,14 @@
       with self.assertRaises(ValueError):
         bool(tf_a == tf_c)
       self.assertAllEqual(tf_a == tf_c, [True, False])
+      self.assertNotAllEqual(tf_a, tf_c)
       with self.assertRaises(ValueError):
         bool(np_a == np_b)
       self.assertAllEqual(np_a == np_b, [True, True])
       with self.assertRaises(ValueError):
         bool(np_a == np_c)
       self.assertAllEqual(np_a == np_c, [True, False])
+      self.assertNotAllEqual(np_a, np_c)
 
       # Warning even though we technically shouldn't be able to compare here,
       # since the id is the same both TF & numpy will handle lists with the same
@@ -488,13 +491,13 @@
       x = x.gpu()
       x = x.gpu()
       x = x.cpu()
-      context.async_wait()
+      context.context().executor.wait()
 
     # Invalid device
     with self.assertRaises(RuntimeError):
       x.gpu(context.context().num_gpus() + 1)
-      context.async_wait()
-    context.async_clear_error()
+      context.context().executor.wait()
+    context.context().executor.clear_error()
 
   @test_util.run_gpu_only
   def testCopyScope(self):
@@ -516,6 +519,22 @@
     test_var = variables.Variable([2., 3.])
     self.assertAllEqual(test_fn(test_var), 1.0)
 
+  def testPyFunctionAsync(self):
+
+    def simple_fn(v):
+      one = constant_op.constant(1.)
+      return v + one
+
+    @def_function.function
+    def test_fn(v):
+      return script_ops.eager_py_func(simple_fn, [v], dtypes.float32)
+
+    async_executor = executor.new_executor(enable_async=True)
+    with context.executor_scope(async_executor):
+      test_var = variables.Variable(2.)
+      self.assertAllEqual(test_fn(test_var), 3.0)
+    async_executor.wait()
+
   @test_util.run_gpu_only
   def testNumpyForceCPU(self):
     cpu = constant_op.constant([[1., 2.], [3., 4.]])
@@ -564,8 +583,8 @@
           inputs=[three, five],
           attrs=('transpose_a', False, 'transpose_b', False, 'T',
                  three.dtype.as_datatype_enum))
-      context.async_wait()
-    context.async_clear_error()
+      context.context().executor.wait()
+    context.context().executor.clear_error()
     context.context().execution_mode = context.SYNC
 
   def testExecuteTooManyNumOutputs(self):
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 66c7502..f910986 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -33,6 +33,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util.tf_export import tf_export
 
@@ -181,7 +182,16 @@
             with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
               self._initializer_op = resource_variable_ops.assign_variable_op(
                   self._handle, lifted_initializer, name=n)
+      elif context.executing_eagerly():
+        # In this case, both current scope and init scope are eager.
+        # Assign_variable_op will be executed immediately. So we don't need to
+        # add it to "add_initializers_to" to lift it out.
+        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+          resource_variable_ops.assign_variable_op(
+              self._handle, initial_value, name=n)
       else:
+        # Init scope is eager but current scope is graph. We will lift out this
+        # variable by addint it into "add_initializers_to".
         if add_initializers_to is not None:
           add_initializers_to[self] = initial_value
         def assign_fn():
@@ -195,7 +205,8 @@
         def not_assign_fn():
           return ops.convert_to_tensor(0)
         # Note: this cond is always guaranteed to run because we're inside a
-        # defun which will insert automatic control dependencies.
+        # defun which will insert automatic control dependencies. It will only
+        # execute assign_fn if lifting failed.
         control_flow_ops.cond(
             resource_variable_ops.var_is_initialized_op(self._handle),
             not_assign_fn, assign_fn)
@@ -252,7 +263,8 @@
                input_signature=None,
                autograph=True,
                experimental_autograph_options=None,
-               experimental_relax_shapes=False):
+               experimental_relax_shapes=False,
+               experimental_compile=None):
     """Initializes a `Function`.
 
     Args:
@@ -268,7 +280,19 @@
         conversion options when autograph is set to True.
       experimental_relax_shapes: When true, argument shapes may be relaxed to
         avoid unecessary retracing.
-
+      experimental_compile: If false, execute the function in a regular way. The
+        function is optimized by some graph rewrite passes (some ops might be
+        clustered into a single op) and interpreted by the standard TensorFlow
+        executor, which dispatches op kernels one by one as they become
+        executable. Set it to false when directly running a multi-device
+        function on TPUs (e.g. two TPU cores, one TPU core and its
+        host CPU). If True, the function is compiled directly by XLA. XLA would
+        fuse all the ops and emit more efficient code to run for some devices
+        (e.g. TPU, XLA_GPU) and some use cases (e.g. dense tensor computation).
+        It requires that the whole function is compilable by XLA. If None
+        (default), compile the function with XLA when running on TPU and go
+        through the regular function execution path when running on other
+        devices.
 
     Raises:
       ValueError: if `input_signature` is not None and the `python_function`'s
@@ -280,6 +304,7 @@
     self._autograph = autograph
     self._experimental_autograph_options = experimental_autograph_options
     self.experimental_relax_shapes = experimental_relax_shapes
+    self._experimental_compile = experimental_compile
     self._created_variables = None
     self._stateful_fn = None
     self._stateless_fn = None
@@ -316,9 +341,16 @@
 
   def _defun(self, fn):
     """Returns a defun generated from the input function."""
-    return function_lib.defun(
+    attributes = None
+    if self._experimental_compile is not None:
+      if self._experimental_compile:
+        attributes = {"_XlaCompile": True}
+      else:
+        attributes = {"_XlaCompile": False}
+    return function_lib.defun_with_attributes(
         fn,
         input_signature=self.input_signature,
+        attributes=attributes,
         autograph=self._autograph,
         experimental_autograph_options=self._experimental_autograph_options,
         experimental_relax_shapes=self.experimental_relax_shapes)
@@ -413,7 +445,7 @@
       return results
 
     # This is the first call of __call__, so we have to initialize.
-    initializer_map = {}
+    initializer_map = object_identity.ObjectIdentityDictionary()
     self._initialize(args, kwds, add_initializers_to=initializer_map)
     if self._created_variables:
       try:
@@ -511,7 +543,7 @@
     # Note: using defun here avoids an infinite recursion.
     @function_lib.defun
     def initialize_variables():
-      op_map = {}
+      op_map = object_identity.ObjectIdentityDictionary()
       for v, init in initializer_map.items():
         with ops.init_scope():
           if resource_variable_ops.var_is_initialized_op(v.handle):
@@ -552,7 +584,7 @@
           "has been used")
     # Here we trace the function, collect the initializers, and attempt to
     # extract them and run them eagerly. Fail only if we cannot do so.
-    initializer_map = {}
+    initializer_map = object_identity.ObjectIdentityDictionary()
     self._initialize(args, kwargs, add_initializers_to=initializer_map)
 
     # Note: using defun here avoids an infinite recursion.
@@ -581,13 +613,7 @@
       concrete_functions.extend(
           self._stateless_fn._function_cache.all_values())
     # pylint: enable=protected-access
-    deduplicated_concrete_functions = []
     seen_signatures = []
-    # We are using a list so that:
-    #  - the returned collection is deterministic, and
-    #  - we can use a custom equality operator (is_same_structure).
-    # This is run only at serialization time on likely very small inputs so we
-    # are not concerned about O(n^2) runtime.
     for concrete_function in concrete_functions:
       signature = concrete_function.structured_input_signature
       flattened = nest.flatten(signature)
@@ -599,9 +625,14 @@
       equal_to_signature = functools.partial(
           function_lib.is_same_structure, signature, check_values=True)
       if not any(equal_to_signature(s) for s in seen_signatures):
-        deduplicated_concrete_functions.append(concrete_function)
         seen_signatures.append(signature)
-    return deduplicated_concrete_functions
+
+    # Re-create concrete functions for these signatures. Re-creating ensures
+    # that if the cache key has changed, the function will be traced again.
+    concrete_functions = []
+    for args, kwargs in seen_signatures:
+      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
+    return concrete_functions
 
   def get_concrete_function(self, *args, **kwargs):
     """Returns a `ConcreteFunction` specialized to inputs and execution context.
@@ -680,7 +711,7 @@
       ValueError: if this object has not yet been called on concrete values.
     """
     if self._stateful_fn is None:
-      initializer_map = {}
+      initializer_map = object_identity.ObjectIdentityDictionary()
       self._initialize(args, kwargs, add_initializers_to=initializer_map)
       self._initialize_uninitialized_variables(initializer_map)
 
@@ -730,7 +761,8 @@
              input_signature=None,
              autograph=True,
              experimental_autograph_options=None,
-             experimental_relax_shapes=False):
+             experimental_relax_shapes=False,
+             experimental_compile=None):
   """Creates a callable TensorFlow graph from a Python function.
 
   `function` constructs a callable that executes a TensorFlow graph
@@ -987,6 +1019,21 @@
       autograph=True.
     experimental_relax_shapes: When true, argument shapes may be relaxed to
       avoid unecessary retracing.
+    experimental_compile: If false, execute the function in a regular way. The
+      function is optimized by some graph rewrite passes (some ops might be
+      clustered into a single op) and interpreted by the standard TensorFlow
+      executor, which dispatches op kernels one by one as they become
+      executable. Set it to false when directly running a multi-device function
+      on TPUs (e.g. two TPU cores, one TPU core and its host CPU). If True, the
+      function is compiled directly by XLA (https://www.tensorflow.org/xla).
+      XLA would fuse all the ops and emit more efficient code to run for some
+      devices (e.g. TPU, XLA_GPU) and some use cases (e.g. dense tensor
+      computation). It requires that the whole function is compilable by XLA
+      (e.g. static tensor shape, a subset of operations, no string, compile-time
+      constant input, etc). If None (default), compile the function with XLA
+      when running on TPU and go through the regular function execution path
+      when running on other devices. Note: TensorArrays on TPU don't work with
+      standard TensorFlow executor.
 
   Returns:
      If `func` is not None, returns a callable that will execute the compiled
@@ -1014,7 +1061,8 @@
             input_signature=input_signature,
             autograph=autograph,
             experimental_autograph_options=experimental_autograph_options,
-            experimental_relax_shapes=experimental_relax_shapes))
+            experimental_relax_shapes=experimental_relax_shapes,
+            experimental_compile=experimental_compile))
 
   # This code path is for the `foo = tf.function(foo, ...)` use case
   if func is not None:
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 4a7d6fe..9ab42b6 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -391,7 +391,8 @@
         outputs.append(inputs[t])
       return outputs
 
-    with self.assertRaisesRegexp(ValueError, 'inner'):
+    with self.assertRaisesRegexp(errors.InaccessibleTensorError,
+                                 'defined in another function or code block'):
       f(array_ops.zeros(shape=(8, 42, 3)))
 
   def testRuntimeErrorNotSticky(self):
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
new file mode 100644
index 0000000..c3e90cd
--- /dev/null
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -0,0 +1,47 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DefFunctionTest(test.TestCase):
+
+  def testCompileFunctionWithXLA(self):
+
+    def fn(x):
+      return array_ops.unique(x).y  # Unique is not supported by XLA
+
+    func = def_function.function(fn, experimental_compile=False)
+    xla_func = def_function.function(fn, experimental_compile=True)
+
+    inputs = constant_op.constant([1, 2, 2, 3, 3])
+    self.assertAllClose([1, 2, 3], func(inputs))
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 'node is not compilable'):
+      xla_func(inputs)
+
+
+if __name__ == '__main__':
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/eager/executor.py b/tensorflow/python/eager/executor.py
new file mode 100644
index 0000000..be84401
--- /dev/null
+++ b/tensorflow/python/eager/executor.py
@@ -0,0 +1,76 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executor for eager execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+
+
+class Executor(object):
+  """A class for handling eager execution.
+
+  The default behavior for asynchronous execution is to serialize all ops on
+  a single thread. Having different `Executor` objects in different threads
+  enables executing ops asynchronously in parallel:
+
+  ```python
+  def thread_function():
+    executor = executor.Executor(enable_async=True):
+    context.set_executor(executor)
+
+  a = threading.Thread(target=thread_function)
+  a.start()
+  b = threading.Thread(target=thread_function)
+  b.start()
+  ```
+  """
+
+  def __init__(self, handle):
+    self._handle = handle
+
+  def __del__(self):
+    try:
+      # pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
+      pywrap_tensorflow.TFE_DeleteExecutor(self._handle)
+    except TypeError:
+      # Suppress some exceptions, mainly for the case when we're running on
+      # module deletion. Things that can go wrong include the pywrap module
+      # already being unloaded, self._handle. no longer being
+      # valid, and so on. Printing warnings in these cases is silly
+      # (exceptions raised from __del__ are printed as warnings to stderr).
+      pass  # 'NoneType' object is not callable when the handle has been
+      # partially unloaded.
+
+  def is_async(self):
+    return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle)
+
+  def handle(self):
+    return self._handle
+
+  def wait(self):
+    """Waits for ops dispatched in this executor to finish."""
+    pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
+
+  def clear_error(self):
+    """Clears errors raised in this executor during execution."""
+    pywrap_tensorflow.TFE_ExecutorClearError(self._handle)
+
+
+def new_executor(enable_async):
+  handle = pywrap_tensorflow.TFE_NewExecutor(enable_async)
+  return Executor(handle)
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index ffc688a..bc46f08 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -282,9 +282,12 @@
       f = _forwardgrad(f)
     self.assertAllClose(expected, f(primal))
 
-  def testFunctionGradPureForward(self):
+  @parameterized.named_parameters(
+      [("Function", def_function.function),
+       ("NoFunction", lambda f: f)])
+  def testGradPureForward(self, decorator):
 
-    @def_function.function
+    @decorator
     def f(x):
       return x ** 3.5
 
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index d2bce69..94727c6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -60,6 +60,7 @@
 from tensorflow.python.util import lazy_loader
 from tensorflow.python.util import memory
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 
@@ -75,10 +76,38 @@
 BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
 
 
-CacheKey = collections.namedtuple("CacheKey", [
-    "input_signature", "parent_graph", "device_functions", "colocation_stack",
-    "in_cross_replica_context"
-])
+class CacheKey(
+    collections.namedtuple("CacheKey", [
+        "input_signature", "parent_graph", "device_functions",
+        "colocation_stack", "in_cross_replica_context"
+    ])):
+  """Named tuple used to key the function cache."""
+
+  def __hash__(self):
+    """Provide a hash even if the input signature objects aren't hashable."""
+    return hash((self._hash_fix(self.input_signature), self.parent_graph,
+                 self.device_functions, self.colocation_stack,
+                 self.in_cross_replica_context))
+
+  def _hash_fix(self, elem):
+    """Ensure elem is hashable even if a Variable is nested in it."""
+    # Descend into tuples
+    if isinstance(elem, tuple):
+      return tuple(self._hash_fix(i) for i in elem)
+
+    if isinstance(elem, set):
+      return {self._hash_fix(i) for i in elem}
+
+    # If the element is not hashable, assume it is a weakref to a variable and
+    # return the dtype & shape. Else, simply return the element
+    try:
+      hash(elem)
+    except TypeError:
+      v = elem()
+      return (v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype))
+
+    return elem
+
 
 CacheKey.replace = CacheKey._replace  # pylint: disable=protected-access
 
@@ -355,9 +384,11 @@
     operations = [op for op in graph.get_operations() if op not in input_ops]
 
     graph_output_names = graph._output_names  # pylint: disable=protected-access
-    if (graph_output_names is not None
-        and all(t in graph_output_names for t in outputs)):
-      output_names = [compat.as_bytes(graph_output_names[t]) for t in outputs]
+    if (graph_output_names is not None and
+        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
+      output_names = [
+          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
+      ]
       if len(set(output_names)) != len(output_names):
         # There are duplicate names for some reason, probably an invalid
         # signature. Revert to auto-naming.
@@ -449,7 +480,8 @@
     """
     if len(args) != len(self.signature.input_arg):
       raise ValueError(
-          "Arguments and signature arguments do not match: %s %s " %
+          "Arguments and signature arguments do not match. "
+          "got: %s, expected: %s " %
           (len(args), len(list(self.signature.input_arg))))
 
     function_call_options = ctx.function_call_options
@@ -512,6 +544,449 @@
       return outputs
 
 
+class _DelayedRewriteGradientFunctions(object):
+  """Caches forward/backward functions with a delayed forward rewrite."""
+
+  def __init__(self, func_graph, attrs, func_graph_deleter):
+    """Construct an inference function and initialize caches."""
+    # A map from the number of forward function outputs with accepted gradients
+    # to forward and backward functions, used to cache non-tape backward
+    # function generation.
+    self._cached_function_pairs = {}
+    self._func_graph = func_graph
+    self._inference_function = _EagerDefinedFunction(
+        _inference_name(self._func_graph.name), self._func_graph,
+        self._func_graph.inputs, self._func_graph.outputs, attrs)
+    self._attrs = attrs
+    self._gradient_name = None
+    # Note that the FuncGraph is mutated later, so we need to inspect it now to
+    # figure out the user-specified outputs of the inference function.
+    self._num_inference_outputs = len(self._func_graph.outputs)
+    self._func_graph_deleter = func_graph_deleter
+
+  def forward_backward(self, num_doutputs=None):
+    """A possibly-cached pair of forward and backward functions."""
+    if num_doutputs is None:
+      num_doutputs = self._num_inference_outputs
+    forward_backward = self._cached_function_pairs.get(num_doutputs)
+    if forward_backward is not None:
+      return forward_backward
+    forward, backward = self._construct_forward_backward(num_doutputs)
+    self._cached_function_pairs[num_doutputs] = (forward, backward)
+    return forward, backward
+
+  def _construct_forward_backward(self, num_doutputs):
+    """Constructs a pair of forward and backward functions.
+
+    Args:
+      num_doutputs: The constructed backprop function will take output gradients
+        for the first `num_doutputs` outputs of the forward function. Defaults
+        to the number of outputs for the inference function, but when
+        higher-order gradients are computed this will increase to include side
+        outputs.
+
+    Returns:
+      A pair of (forward_function, backward_function):
+        forward_function: A re-generated inference function (an
+          _EagerDefinedFunction) to account for new side outputs, if any extra
+          were required when building the backward pass.
+        backward_function: A ConcreteFunction that Takes `num_doutputs`
+          arguments and returns gradients with respect to inputs of the forward
+          function.
+    """
+    trainable_outputs = [
+        output for output in self._func_graph.outputs[:num_doutputs]
+        if gradients_util.IsTrainable(output)]
+
+    signature = []
+    for t in trainable_outputs:
+      signature.append(
+          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
+
+    def _backprop_function(*grad_ys):
+      return gradients_util._GradientsHelper(  # pylint: disable=protected-access
+          trainable_outputs,
+          self._func_graph.inputs,
+          grad_ys=grad_ys,
+          src_graph=self._func_graph)
+
+    with self._func_graph.as_default():
+      backwards_graph = func_graph_module.FuncGraph(
+          _backward_name(self._func_graph.name))
+      func_graph_module.func_graph_from_py_func(
+          name=backwards_graph.name,
+          python_func=_backprop_function,
+          args=[], kwargs={},
+          signature=signature,
+          func_graph=backwards_graph)
+      backwards_graph_captures = backwards_graph.external_captures
+      captures_from_forward = [
+          c for c in backwards_graph_captures if
+          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+
+      forward_function_name = _forward_name(self._func_graph.name)
+
+      existing_outputs = object_identity.ObjectIdentitySet(
+          self._func_graph.outputs)
+      for capture in captures_from_forward:
+        if capture not in existing_outputs:
+          existing_outputs.add(capture)
+          self._func_graph.outputs.append(capture)
+      backward_function_attr = _parse_func_attrs(
+          {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+      backward_function_attr.update(self._attrs)
+
+      backward_function = ConcreteFunction(
+          backwards_graph, attrs=backward_function_attr)
+      forward_function_attr = _parse_func_attrs({
+          BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+          backward_function.name})
+      forward_function_attr.update(self._attrs)
+
+      forward_function = _EagerDefinedFunction(
+          forward_function_name, self._func_graph, self._func_graph.inputs,
+          self._func_graph.outputs, forward_function_attr)
+      return forward_function, backward_function
+
+  def _rewrite_forward_and_call_backward(self, op, *doutputs):
+    """Add outputs to the forward call and feed them to the grad function."""
+    forward_function, backwards_function = self.forward_backward(len(doutputs))
+    if not backwards_function.outputs:
+      return []
+    forward_function.add_to_graph(op.graph)
+
+    # pylint: disable=protected-access
+    # Rewrite an inference call op to be a forward call op
+    op._set_func_attr("f", forward_function.name)
+    op._set_type_list_attr("Tout", forward_function._output_types)
+    op._add_outputs(
+        forward_function._output_types[len(op.outputs):],
+        forward_function._output_shapes[len(op.outputs):])
+    for i in range(len(op.outputs)):
+      func_graph_output = forward_function._func_graph_outputs[i]
+      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
+    # pylint: enable=protected-access
+
+    capture_mapping = dict(
+        zip([ops.tensor_id(t) for t in self._func_graph.outputs], op.outputs))
+    remapped_captures = [
+        capture_mapping.get(ops.tensor_id(capture), capture)
+        for capture in backwards_function.captured_inputs
+    ]
+
+    # Replace Nones with zeros since we're calling a graph function which
+    # expects numeric inputs.
+    cleaned_doutputs = []
+    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
+      if gradients_util.IsTrainable(placeholder):
+        if doutput is not None:
+          cleaned_doutputs.append(doutput)
+        else:
+          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
+
+    # Compute the gradients using the side outputs
+    return backwards_function._call_flat(  # pylint: disable=protected-access
+        cleaned_doutputs, remapped_captures)
+
+  def register(self):
+    """Registers a delayed-rewrite gradient with a unique name (idempotent).
+
+    The gradient rewrites an inference call op to a forward call op, but does
+    not modify a pre-existing forward call op. It then computes the gradient
+    from the output's gradients and the side outputs of the forward op.
+
+    Returns:
+      The name under which gradient was registered.
+    """
+    if self._gradient_name:
+      return self._gradient_name
+    self._gradient_name = "PartitionedCall-%s" % ops.uid()
+
+    @ops.RegisterGradient(self._gradient_name)
+    def _registered_grad_fn(op, *doutputs):  # pylint: disable=unused-variable
+      return self._rewrite_forward_and_call_backward(op, *doutputs)
+    return self._gradient_name
+
+  @property
+  def forward(self):
+    """A forward function with only user-specified outputs.
+
+    The call operation for the returned inference function can be rewritten into
+    a forward function. This only happens if the backward function (from the
+    `backward` method) ends up being used to compute gradients.
+
+    This approach avoids constructing unnecessary graphs, but it only works if
+    we are calling this function when not executing eagerly.
+
+    Returns:
+      An _EagerDefinedFunction.
+    """
+    return self._inference_function
+
+  def backward(self, outputs):
+    """Fetch a backward function for `outputs` from the forward function."""
+    def _backward_function(*args):
+      call_op = outputs[0].op
+      return self._rewrite_forward_and_call_backward(call_op, *args)
+    return _backward_function, outputs
+
+
+class _TapeGradientFunctions(object):
+  """Caches forward and backward functions compatible with eager gradients.
+
+  In contrast to the delayed-rewrite approach in
+  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
+  the forward function generated by this class has a fixed set of outputs which
+  may be preserved by a tape in order to compute gradients later.
+
+  This class is abstract; its child classes differ in how many side outputs of
+  the forward function their backward function accepts gradients for, which
+  determines whether higher-order tape gradients are possible.
+  """
+
+  def __init__(self, func_graph, attrs, func_graph_deleter):
+    self._func_graph = func_graph
+    self._attrs = attrs
+    self._forward = None
+    self._backward = None
+    self._num_outputs = len(func_graph.outputs)
+    self._func_graph_deleter = func_graph_deleter
+
+  def _build_functions_for_outputs(self, outputs):
+    """Forward+backward functions where the backward function sees `outputs`."""
+    # First figure out which of `outputs` are trainable. We'll accept gradients
+    # for each of these in the backward function.
+    handles_to_variables = self._func_graph.variable_captures
+    trainable_outputs = []
+    for output in outputs:
+      if gradients_util.IsTrainable(output):
+        # Swap in the Variable object for resource handles if we can so
+        # sparse gradients work.
+        output = handles_to_variables.get(ops.tensor_id(output), output)
+        trainable_outputs.append(output)
+
+    backwards_graph = func_graph_module.FuncGraph(
+        _backward_name(self._func_graph.name))
+    # Keep track of the forward graph so that if the backwards graph
+    # tries to capture tensors those will be correctly captured first in
+    # the forward graph. This is an edge case that can only happen with
+    # tf.custom_gradient.
+    backwards_graph._forward_func_graph = self._func_graph  # pylint: disable=protected-access
+    with backwards_graph.as_default():
+      gradients_wrt_outputs = []
+      for output in trainable_outputs:
+        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
+            output)
+        gradients_wrt_outputs.append(
+            graph_placeholder(gradient_dtype, gradient_shape))
+      gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
+          trainable_outputs,
+          self._func_graph.inputs,
+          grad_ys=gradients_wrt_outputs,
+          src_graph=self._func_graph)
+
+      captures_from_forward = [
+          c for c in backwards_graph.external_captures
+          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
+      ]
+      existing_outputs = object_identity.ObjectIdentitySet(
+          self._func_graph.outputs)
+      for capture in captures_from_forward:
+        if capture not in existing_outputs:
+          existing_outputs.add(capture)
+          self._func_graph.outputs.append(capture)
+
+    forward_function_name = _forward_name(self._func_graph.name)
+    backward_function_attr = _parse_func_attrs(
+        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+    backward_function_attr.update(self._attrs)
+
+    # The ordering of `backwards_graph.inputs` is important: inputs of
+    # `backward_function` correspond to outputs (including
+    # side outputs) of `self._tape_forward_function`.
+    backwards_graph.inputs = (
+        gradients_wrt_outputs + backwards_graph.internal_captures)
+    backwards_graph.outputs.extend(
+        grad
+        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
+        if grad is not None)
+    backwards_graph.structured_outputs = gradients_wrt_inputs
+    backward_function = ConcreteFunction(
+        backwards_graph, attrs=backward_function_attr)
+
+    forward_function_attr = _parse_func_attrs({
+        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+            backward_function.name})
+    forward_function_attr.update(self._attrs)
+
+    forward_function = _EagerDefinedFunction(
+        forward_function_name, self._func_graph, self._func_graph.inputs,
+        self._func_graph.outputs,
+        forward_function_attr)
+    return forward_function, backward_function
+
+  @property
+  def forward(self):
+    """Construct or fetch a forward function with side-outputs.
+
+    When graph building without a tape active, symbolic gradients rely on
+    regenerating the backward function for higher-order gradients (to account
+    for new side outputs of the rewritten forward function call). Thus there is
+    no fixed backward function for this case. However, when a tape is active
+    (eager or graph building), we generate fixed backward and forward functions
+    at forward function call time.
+
+    This difference between the tape and non-tape cases is to avoid building
+    unneeded backward functions while graph building (where we may or may not
+    eventually need gradients).
+
+    Returns:
+      A forward _EagerDefinedFunction.
+    """
+    if self._forward is None:
+      self._forward, self._backward = (
+          self._forward_and_backward_functions())
+    return self._forward
+
+  def backward(self, outputs):
+    """Create a backward function given `outputs` from the forward function."""
+    capture_mapping = dict(
+        zip([ops.tensor_id(t) for t in self._func_graph.outputs], outputs))
+    remapped_captures = [
+        capture_mapping.get(ops.tensor_id(capture), capture)
+        for capture in self._backward.captured_inputs
+    ]
+    # We may need to use zeros_like to get a zero for variant Tensors with
+    # unconnected gradients. We do that in advance so we don't have to hold on
+    # to the outputs themselves, which may not be needed otherwise.
+    variant_zeros_like = {}
+    backward_function_inputs = (
+        len(self._backward.inputs) - len(self._backward.captured_inputs))
+    recorded_outputs = []
+    trainable_recorded_outputs = 0
+    skip_positions = []
+    for output_index, output in enumerate(outputs):
+      if trainable_recorded_outputs < backward_function_inputs:
+        recorded_outputs.append(output)
+      if gradients_util.IsTrainable(output):
+        trainable_recorded_outputs += 1
+      else:
+        skip_positions.append(output_index)
+      if output.dtype == dtypes.variant:
+        variant_zeros_like[output_index] = default_gradient.zeros_like(output)
+
+    def _backward_function_wrapper(*args):
+      """Process output gradients and call the backward function."""
+      if not self._backward.outputs:
+        return []
+      processed_args = []
+      input_index = 0
+      for output_index, arg in enumerate(args):
+        if output_index in skip_positions:
+          continue
+        if arg is None:
+          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
+          # have a Tensor value for each Tensor we thought would be trainable
+          # based on its dtype, even if it ended up being unconnected.
+          input_placeholder = self._backward.inputs[
+              input_index]
+          if input_placeholder.dtype == dtypes.variant:
+            arg = variant_zeros_like[output_index]
+          else:
+            arg = array_ops.zeros(
+                *default_gradient.shape_and_dtype(input_placeholder))
+        processed_args.append(arg)
+        input_index += 1
+        if input_index >= backward_function_inputs:
+          break
+      return self._backward._call_flat(  # pylint: disable=protected-access
+          processed_args, remapped_captures)
+
+    return _backward_function_wrapper, recorded_outputs
+
+
+class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
+  """Caches tape-friendly functions for first-order gradients."""
+
+  def __init__(self, func_graph, attrs, func_graph_deleter):
+    super(_FirstOrderTapeGradientFunctions, self).__init__(
+        func_graph, attrs, func_graph_deleter)
+    self._num_inference_outputs = len(func_graph.outputs)
+    self._func_graph_deleter = func_graph_deleter
+
+  def _forward_and_backward_functions(self):
+    """Shortcut for when only first-order gradients are required.
+
+    The returned backward function does not accept gradients with respect to
+    side output of forward_function. This is fine as long as the user can't
+    possibly request second order tape gradients, as when they've used a single
+    non-persistent GradientTape. Since we don't need the backward function to
+    take gradients with respect to side outputs, we can skip some potentially
+    slow graph building.
+
+    Returns:
+      A tuple of (forward_function, backward_function):
+        forward_function: Takes the same inputs as the inference function, but
+          returns side outputs used by backward_function in addition to the
+          inference function's outputs.
+        backward_function: Takes side outputs from forward_function and
+          gradients with respect to the "real" outputs of forward_function and
+          returns gradients with respect to the inputs.
+    """
+    outputs = self._func_graph.outputs[:self._num_inference_outputs]
+    return self._build_functions_for_outputs(outputs)
+
+
+class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
+  """Caches tape-friendly functions for higher-order gradients."""
+
+  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
+  # generalizing if so.
+  def _forward_and_backward_functions(self):
+    """Forward and backward functions suitable for higher-order gradients.
+
+    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by
+    this method accepts gradients for all of the outputs of the returned forward
+    function, including side outputs.
+
+    Returns:
+      A tuple of (forward_function, backward_function):
+        forward_function: Takes the same inputs as the inference function, but
+          returns side outputs used by backward_function in addition to the
+          inference function's outputs.
+        backward_function: Takes side outputs from forward_function and
+          gradients with respect to all of its outputs, real and side. Returns
+          gradients with respect to the inputs.
+    """
+    outputs = []
+    # First we need to figure out how many side outputs from the forward pass
+    # will be required. We do this in a temporary graph to avoid actually
+    # running multiple copies of the backward pass (one per _GradientsHelper
+    # call).
+    #
+    # While computing gradients, the backward function captures Tensors from
+    # the forward function. We add these as side outputs of the original
+    # function. However, we then need to accept output gradients with respect
+    # to these side outputs for higher order gradients to work. Thus we loop
+    # until the number of outputs of the function stabilizes. Note that this
+    # is only required for tape gradients, where we need to declare in advance
+    # all of the forward op's outputs: symbolic gradients with tf.gradients
+    # instead rely on regenerating backward functions when higher-order
+    # gradients are requested.
+    while len(outputs) < len(self._func_graph.outputs):
+      new_outputs = self._func_graph.outputs[len(outputs):]
+      outputs = list(self._func_graph.outputs)
+      self._build_functions_for_outputs(new_outputs)
+    forward_function, backward_function = (
+        self._build_functions_for_outputs(outputs))
+    if len(self._func_graph.outputs) != len(outputs):
+      raise AssertionError(
+          ("Unexpectedly added new outputs to the forward function when "
+           "building the backward function: {}").format(
+               self._func_graph.outputs[len(outputs):]))
+    return forward_function, backward_function
+
+
 class _PossibleTapeGradientTypes(enum.Enum):
   """Represents the output of TFE_Py_TapeSetPossibleGradientTypes."""
   NONE = 0
@@ -526,7 +1001,8 @@
   is differentiable under `tf.GradientTape` objects.
   """
 
-  def __init__(self, func_graph, attrs=None, signature=None):
+  def __init__(self, func_graph, attrs=None, signature=None,
+               shared_func_graph=True):
     """Initialize a `ConcreteFunction`.
 
     Args:
@@ -536,6 +1012,9 @@
         definition.
      signature: a nested sequence of `TensorSpec` objects specifying the input
        signature of this function.
+     shared_func_graph: If False, the ConcreteFunction takes ownership of
+       `func_graph` and will break reference cycles when it is deleted. This
+       makes the FuncGraph inoperable.
 
     Raises:
       ValueError: If number of input_placeholders is not equal to the number
@@ -544,40 +1023,29 @@
     self._arg_keywords = None
     self._num_positional_args = None
     self._func_graph = func_graph
-    self._captured_inputs = list(self._func_graph.captures.keys())
-    self._captured_closures = [
-        x[0] for x in self._func_graph.deferred_captures.values()]
-    self._num_outputs = len(self._func_graph.outputs)
+    self._captured_inputs = self._func_graph.external_captures
+    self._captured_closures = self._func_graph.deferred_external_captures
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
-    self._attrs = _parse_func_attrs(attrs or {})
-
-    self._inference_function = _EagerDefinedFunction(
-        _inference_name(self._func_graph.name), self._func_graph,
-        self._func_graph.inputs, self._func_graph.outputs, self._attrs)
-
-    # When graph building without a tape active, symbolic gradients rely on
-    # regenerating the backward function for higher-order gradients (to account
-    # for new side outputs of the rewritten forward function call). Thus there
-    # is no fixed backward function for this case. However, when a tape is
-    # active (eager or graph building), we generate fixed backward and forward
-    # functions at forward function call time.
-    #
-    # This difference between the tape and non-tape cases is to avoid building
-    # unneeded backward functions while graph building (where we may or may not
-    # eventually need gradients).
-    self._tape_forward_function_first_order = None
-    self._tape_backward_function_first_order = None
-    self._tape_forward_function_higher_order = None
-    self._tape_backward_function_higher_order = None
-
-    # A map from the number of forward function outputs with accepted gradients
-    # to backward functions, used to cache non-tape backward function
-    # generation.
-    self._cached_graph_backprop_functions = {}
-
+    attrs = _parse_func_attrs(attrs or {})
     self._signature = signature
-    self._gradient_name = None
+
+    if shared_func_graph:
+      self._garbage_collector = None
+    else:
+      self._garbage_collector = ConcreteFunctionGarbageCollector(
+          func_graph)
+
+    # Pairs of forward and backward functions used for computing gradients.
+    #
+    # These each get a reference to the FuncGraph deleter since they use the
+    # FuncGraph directly.
+    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
+        func_graph, attrs, self._garbage_collector)
+    self._first_order_tape_functions = _FirstOrderTapeGradientFunctions(
+        func_graph, attrs, self._garbage_collector)
+    self._higher_order_tape_functions = _HigherOrderTapeGradientFunctions(
+        func_graph, attrs, self._garbage_collector)
 
   def __call__(self, *args, **kwargs):
     """Executes the wrapped function.
@@ -685,6 +1153,11 @@
     ctx = context.context()
     executing_eagerly = ctx.executing_eagerly()
 
+    # Copy saveable status of function's graph to current FuncGraph.
+    default_graph = ops.get_default_graph()
+    if default_graph.building_function and not self._func_graph.saveable:
+      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
+
     if any(isinstance(a, composite_tensor.CompositeTensor) for a in args):
       raise AssertionError("Expected all args to be Tensors or Variables; "
                            "but got CompositeTensor: %r" % args)
@@ -695,7 +1168,7 @@
         resource_variable_ops.variable_accessed(v)
 
     tensor_inputs = []
-    variables_used = set([])
+    variables_used = object_identity.ObjectIdentitySet([])
     for i, arg in enumerate(args):
       if isinstance(arg, resource_variable_ops.BaseResourceVariable):
         # We can pass a variable more than once, and in this case we need to
@@ -736,104 +1209,24 @@
                          "on invocation of %s, the %d-th input (%s) was not a "
                          "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captured_inputs
-
-    possible_gradient_type = _PossibleTapeGradientTypes(
-        pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
-    if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
-      if context.executing_eagerly():
-        # There is a single non-persistent tape active, so the user can only
-        # request first-order gradients from a tape. We can spend less time
-        # graph building since we know this.
-        #
-        # We may still end up computing higher-order gradients, but that'd be
-        # through `tf.gradients`, which can re-write the forward pass and so
-        # needs no preparation here.
-        forward_function, backward_function = (
-            self._tape_functions_for_first_order())
-        return self._tape_backprop_call(
-            args, forward_function, backward_function)
-      else:
-        # We can avoid computing second-order gradients in some cases by doing a
-        # delayed rewrite when graph building. Since we know we'll only compute
-        # first-order tape gradients, the delayed rewrite is safe: we won't need
-        # to tell the tape about side outputs.
-        #
-        # TODO(allenl): This case is really dirty. It would be better if we
-        # could temporarily pop all of the current tapes to avoid
-        # accidentally taking second-order gradients.
-        return self._backprop_call_with_delayed_rewrite(args)
-    elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
-      # Either there's a persistent tape watching, or there are multiple nested
-      # tapes. Either way, the user may request higher-order gradients. We'll
-      # spend a bit more time and make sure higher-order gradients are correct.
-      forward_function, backward_function = (
-          self._tape_functions_for_higher_order())
-      return self._tape_backprop_call(args, forward_function, backward_function)
-    # else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
-    # tape is recording.
-
-    # Only need to override the gradient in graph mode and when we have outputs.
-    if context.executing_eagerly() or not self.outputs:
-      outputs = self._inference_function.call(
+    forward_backward = self._select_forward_and_backward_functions(args)
+    forward_function = forward_backward.forward
+    if executing_eagerly:
+      flat_outputs = forward_function.call(
           ctx, args, cancellation_manager=cancellation_manager)
     else:
-      self._register_gradient()
+      gradient_name = self._delayed_rewrite_functions.register()
       with ops.get_default_graph().gradient_override_map(
-          {"PartitionedCall": self._gradient_name,
-           "StatefulPartitionedCall": self._gradient_name}):
-        outputs = self._inference_function.call(ctx, args)
-    return self._build_call_outputs(outputs)
-
-  def _register_gradient(self):
-    """Registers the gradient for this `ConcreteFunction`.
-
-    The gradient rewrites an inference call op to a forward call op, but does
-    not modify a pre-existing forward call op. It then computes the gradient
-    from the output's gradients and the side outputs of the forward op.
-    """
-    if self._gradient_name:
-      return
-    self._gradient_name = "PartitionedCall-%s" % ops.uid()
-
-    @ops.RegisterGradient(self._gradient_name)
-    def _registered_grad_fn(op, *doutputs):  # pylint: disable=unused-variable
-      return self._grad_fn(op, *doutputs)
-
-  def _grad_fn(self, op, *doutputs):
-    """Gradients of this function."""
-    backwards_function = self._graph_backprop_function(len(doutputs))
-    self._forward_function.add_to_graph(op.graph)
-
-    # pylint: disable=protected-access
-    # Rewrite an inference call op to be a forward call op
-    op._set_func_attr("f", self._forward_function.name)
-    op._set_type_list_attr("Tout", self._forward_function._output_types)
-    op._add_outputs(
-        self._forward_function._output_types[len(op.outputs):],
-        self._forward_function._output_shapes[len(op.outputs):])
-    for i in range(len(op.outputs)):
-      func_graph_output = self._forward_function._func_graph_outputs[i]
-      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
-    # pylint: enable=protected-access
-
-    capture_mapping = dict(zip(self._func_graph.outputs, op.outputs))
-    remapped_captures = []
-    for capture in backwards_function.captured_inputs:
-      remapped_captures.append(capture_mapping.get(capture, capture))
-
-    # Replace Nones with zeros since we're calling a graph function which
-    # expects numeric inputs.
-    cleaned_doutputs = []
-    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
-      if gradients_util.IsTrainable(placeholder):
-        if doutput is not None:
-          cleaned_doutputs.append(doutput)
-        else:
-          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
-
-    # Compute the gradients using the side outputs
-    return backwards_function._call_flat(  # pylint: disable=protected-access
-        cleaned_doutputs, remapped_captures)
+          {"PartitionedCall": gradient_name,
+           "StatefulPartitionedCall": gradient_name}):
+        flat_outputs = forward_function.call(ctx, args)
+    if isinstance(flat_outputs, ops.Operation) or flat_outputs is None:
+      # We only record function calls which have outputs.
+      return self._build_call_outputs(flat_outputs)
+    backward_function, to_record = forward_backward.backward(flat_outputs)
+    tape.record_operation(forward_function.signature.name,
+                          to_record, args, backward_function)
+    return self._build_call_outputs(flat_outputs)
 
   def _experimental_with_cancellation_manager(self, cancellation_manager):
     """Returns a callable that invokes a cancelable version of this function.
@@ -855,7 +1248,7 @@
   @property
   def name(self):
     """`ConcreteFunction` name."""
-    return self._inference_function.name
+    return self._delayed_rewrite_functions.forward.name
 
   @property
   def graph(self):
@@ -895,7 +1288,7 @@
   @property
   def function_def(self):
     """Returns a `FunctionDef` object representing this function."""
-    return self._inference_function.definition
+    return self._delayed_rewrite_functions.forward.definition
 
   @property
   def output_shapes(self):
@@ -915,390 +1308,81 @@
             self._func_graph.structured_outputs),
         expand_composites=False)
 
-  def add_to_graph(self, g=None, register_gradient_functions=False):
-    """Registers the function, adds it to the graph g or default graph."""
+  def add_to_graph(self, g=None):
+    """Registers the function, adds it to the graph g or default graph.
+
+    Args:
+      g: If specified, registers the function with this graph. Defaults to the
+        current context (either the default graph or the eager context).
+    """
     # If we are not executing eagerly, adds the function to default graph if no
     # graph is specified.
     # In case of eager execution, function definition gets added to context
     # during construction itself.
 
-    # TODO(allenl/shivaniagrawal): rename this to register to reflect the
-    # method's functionality better. Remove register_gradient_functions argument
-    # and figure out if these needs to be registered.
-
     if not context.executing_eagerly() and not g:
       g = ops.get_default_graph()
-    self._inference_function.add_to_graph(g)  # pylint: disable=protected-access
+    self._delayed_rewrite_functions.forward.add_to_graph(g)
 
-    # pylint: disable=protected-access
-    if register_gradient_functions:
-      # There are two situations for the actual call of a defun:
-      # 1. If none of the input args are resource variables or watch by any
-      #   tape, and it will run the _inference_function of concrete_func for
-      #   forward pass, the gradient will be generated by standard mechanism.
-      # 2. Otherwise, defun will create two functions, one for forward pass,
-      #   and the backward pass will be created via tape.
-      #   When registering the function, we register both cases.
-      backward_function = self._graph_backprop_function()._inference_function
-      forward_function = self._forward_function
-      # pylint: enable=protected-access
-      forward_function.add_to_graph(g)
-      backward_function.add_to_graph(g)
+  def add_gradient_functions_to_graph(self, g=None):
+    """Add forward/backward functions to graph `g` or the current context."""
+    if not context.executing_eagerly() and not g:
+      g = ops.get_default_graph()
+    self._delayed_rewrite_functions.forward.add_to_graph(g)
+    forward_function, backward_function = (
+        self._delayed_rewrite_functions.forward_backward())
+    forward_function.add_to_graph(g)
+    backward_function.add_to_graph(g)
 
-  def _graph_backprop_function(self, num_doutputs=None):
-    """A possibly-cached backprop function."""
-    backward_function = self._cached_graph_backprop_functions.get(
-        num_doutputs, None)
-    if backward_function is not None:
-      return backward_function
-    backward_function = self._construct_graph_backprop_function(num_doutputs)
-    self._cached_graph_backprop_functions[num_doutputs] = backward_function
-    return backward_function
+  def _register_delayed_rewrite_gradient(self):
+    """Registers a delayed-rewrite gradient function and returns the name."""
+    return self._delayed_rewrite_functions.register()
 
-  def _construct_graph_backprop_function(self, num_doutputs=None):
-    """Constructs a backprop function object for this function.
+  def _select_forward_and_backward_functions(self, args):
+    """Selects forward and backward functions based on the calling context.
+
+    The forward function computes the "real" function outputs, `self._outputs`,
+    and any extra values needed by the corresponding backward function.
 
     Args:
-      num_doutputs: The constructed backprop function will take output gradients
-        for the first `num_doutputs` outputs of the forward function. Defaults
-        to the number of outputs for the inference function, but when
-        higher-order gradients are computed this will increase to include side
-        outputs.
+      args: A flat list of Tensors with all of the inputs to the forward
+        function (including user-specified and captured inputs).
 
     Returns:
-      A backward function taking `num_doutputs` arguments and returning
-      gradients with respect to inputs of the forward function.
-
-      self._forward_function is re-generated to account for new side outputs, if
-      any extra were required when building the backward pass.
+      An object with a `forward` property containing an _EagerDefinedFunction,
+      and a corresponding `backward` method which takes outputs from the forward
+      function and returns a backward function.
     """
-    if num_doutputs is None:
-      num_doutputs = len(self._inference_function.signature.output_arg)
-    trainable_outputs = [
-        output for output in self._func_graph.outputs[:num_doutputs]
-        if gradients_util.IsTrainable(output)]
-
-    signature = []
-    for t in trainable_outputs:
-      signature.append(
-          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
-
-    def _backprop_function(*grad_ys):
-      return gradients_util._GradientsHelper(  # pylint: disable=protected-access
-          trainable_outputs,
-          self._func_graph.inputs,
-          grad_ys=grad_ys,
-          src_graph=self._func_graph)
-
-    with self._func_graph.as_default():
-      backwards_graph = func_graph_module.FuncGraph(
-          _backward_name(self._func_graph.name))
-      func_graph_module.func_graph_from_py_func(
-          name=backwards_graph.name,
-          python_func=_backprop_function,
-          args=[], kwargs={},
-          signature=signature,
-          func_graph=backwards_graph)
-      backwards_graph_captures = list(backwards_graph.captures.keys())
-      captures_from_forward = [
-          c for c in backwards_graph_captures if
-          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
-
-      forward_function_name = _forward_name(self._func_graph.name)
-
-      existing_outputs = set(self._func_graph.outputs)
-      for capture in captures_from_forward:
-        if capture not in existing_outputs:
-          existing_outputs.add(capture)
-          self._func_graph.outputs.append(capture)
-      backward_function_attr = _parse_func_attrs(
-          {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
-      backward_function_attr.update(self._attrs)
-
-      backward_function = ConcreteFunction(
-          backwards_graph, attrs=backward_function_attr)
-      forward_function_attr = _parse_func_attrs({
-          BACKWARD_FUNCTION_ATTRIBUTE_NAME:
-          backward_function._inference_function.name})  # pylint: disable=protected-access
-      forward_function_attr.update(self._attrs)
-
-      self._forward_function = _EagerDefinedFunction(
-          forward_function_name, self._func_graph, self._func_graph.inputs,
-          self._func_graph.outputs, forward_function_attr)
-      return backward_function
-
-  def _tape_functions_for_first_order(self):
-    """Shortcut for when only first-order gradients are required.
-
-    The returned backward function does not accept gradients with respect to
-    side output of forward_function. This is fine as long as the user can't
-    possibly request second order tape gradients, as when they've used a single
-    non-persistent GradientTape. Since we don't need the backward function to
-    take gradients with respect to side outputs, we can skip some potentially
-    slow graph building.
-
-    Returns:
-      A tuple of (forward_function, backward_function):
-        forward_function: Takes the same inputs as the inference function, but
-          returns side outputs used by backward_function in addition to the
-          inference function's outputs.
-        backward_function: Takes side outputs from forward_function and
-          gradients with respect to the "real" outputs of forward_function and
-          returns gradients with respect to the inputs.
-    """
-    if self._tape_forward_function_first_order is not None:
-      return (self._tape_forward_function_first_order,
-              self._tape_backward_function_first_order)
-    outputs = self._func_graph.outputs[
-        :len(self._inference_function.signature.output_arg)]
-    forward_function, backward_function = (
-        self._tape_forward_and_backward_functions(outputs))
-    self._tape_forward_function_first_order = forward_function
-    self._tape_backward_function_first_order = backward_function
-    return forward_function, backward_function
-
-  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
-  # generalizing if so.
-  def _tape_functions_for_higher_order(self):
-    """Forward and backward functions suitable for higher-order gradients.
-
-    Unlike `_tape_functions_for_first_order`, the backward function built by
-    this method accepts gradients for all of the outputs of the returned forward
-    function, including side outputs.
-
-    Returns:
-      A tuple of (forward_function, backward_function):
-        forward_function: Takes the same inputs as the inference function, but
-          returns side outputs used by backward_function in addition to the
-          inference function's outputs.
-        backward_function: Takes side outputs from forward_function and
-          gradients with respect to all of its outputs, real and side. Returns
-          gradients with respect to the inputs.
-    """
-    if self._tape_forward_function_higher_order is not None:
-      return (self._tape_forward_function_higher_order,
-              self._tape_backward_function_higher_order)
-    outputs = []
-    # First we need to figure out how many side outputs from the forward pass
-    # will be required. We do this in a temporary graph to avoid actually
-    # running multiple copies of the backward pass (one per _GradientsHelper
-    # call).
-    #
-    # While computing gradients, the backward function captures Tensors from
-    # the forward function. We add these as side outputs of the original
-    # function. However, we then need to accept output gradients with respect
-    # to these side outputs for higher order gradients to work. Thus we loop
-    # until the number of outputs of the function stabilizes. Note that this
-    # is only required for tape gradients, where we need to declare in advance
-    # all of the forward op's outputs: symbolic gradients with tf.gradients
-    # instead rely on regenerating backward functions when higher-order
-    # gradients are requested.
-    while len(outputs) < len(self._func_graph.outputs):
-      new_outputs = self._func_graph.outputs[len(outputs):]
-      outputs = list(self._func_graph.outputs)
-      self._tape_forward_and_backward_functions(new_outputs)
-    forward_function, backward_function = (
-        self._tape_forward_and_backward_functions(outputs))
-    if len(self._func_graph.outputs) != len(outputs):
-      raise AssertionError(
-          ("Unexpectedly added new outputs to the forward function when "
-           "building the backward function: {}").format(
-               self._func_graph.outputs[len(outputs):]))
-    self._tape_forward_function_higher_order = forward_function
-    self._tape_backward_function_higher_order = backward_function
-    return forward_function, backward_function
-
-  def _tape_forward_and_backward_functions(self, outputs):
-    """Constructs tape forward and back functions for `outputs`."""
-    # First figure out which of `outputs` are trainable. We'll accept gradients
-    # for each of these in the backward function.
-    handles_to_variables = {self._func_graph.captures[v.handle]: v
-                            for v in self._func_graph.variables
-                            if v.handle in self._func_graph.captures}
-    trainable_outputs = []
-    for output in outputs:
-      if gradients_util.IsTrainable(output):
-        # Swap in the Variable object for resource handles if we can so
-        # sparse gradients work.
-        output = handles_to_variables.get(output, output)
-        trainable_outputs.append(output)
-
-    backwards_graph = func_graph_module.FuncGraph(
-        _backward_name(self._func_graph.name))
-    # Keep track of the forward graph so that if the backwards graph
-    # tries to capture tensors those will be correctly captured first in
-    # the forward graph. This is an edge case that can only happen with
-    # tf.custom_gradient.
-    backwards_graph._forward_func_graph = self._func_graph  # pylint: disable=protected-access
-    with backwards_graph.as_default():
-      gradients_wrt_outputs = []
-      for output in trainable_outputs:
-        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
-            output)
-        gradients_wrt_outputs.append(
-            graph_placeholder(gradient_dtype, gradient_shape))
-      gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
-          trainable_outputs,
-          self._func_graph.inputs,
-          grad_ys=gradients_wrt_outputs,
-          src_graph=self._func_graph)
-
-      captures_from_forward = [
-          c for c in backwards_graph.captures.keys() if
-          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
-      existing_outputs = set(self._func_graph.outputs)
-      for capture in captures_from_forward:
-        if capture not in existing_outputs:
-          existing_outputs.add(capture)
-          self._func_graph.outputs.append(capture)
-
-    forward_function_name = _forward_name(self._func_graph.name)
-    backward_function_attr = _parse_func_attrs(
-        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
-    backward_function_attr.update(self._attrs)
-
-    # The ordering of `backwards_graph.inputs` is important: inputs of
-    # `backward_function` correspond to outputs (including
-    # side outputs) of `self._tape_forward_function`.
-    backwards_graph.inputs = (
-        gradients_wrt_outputs + list(backwards_graph.captures.values()))
-    backwards_graph.outputs.extend(
-        grad
-        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
-        if grad is not None)
-    backwards_graph.structured_outputs = gradients_wrt_inputs
-    backward_function = ConcreteFunction(
-        backwards_graph, attrs=backward_function_attr)
-
-    forward_function_attr = _parse_func_attrs({
-        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
-            backward_function._inference_function.name})  # pylint: disable=protected-access
-    forward_function_attr.update(self._attrs)
-
-    forward_function = _EagerDefinedFunction(
-        forward_function_name, self._func_graph, self._func_graph.inputs,
-        self._func_graph.outputs,
-        forward_function_attr)
-    return forward_function, backward_function
-
-  def _tape_backprop_call(self, args, forward_function, backward_function):
-    """Calls the forward function and records the result on a tape.
-
-    Args:
-      args: All inputs to the function, including resolved captured inputs
-      forward_function: The forward pass, outputting both user-specified and
-        side outputs.
-      backward_function: Computes gradients for inputs of forward_function given
-        output gradients for the first `N` of forward_function's outputs, not
-        necessarily all of them. See `_tape_functions_for_first_order` and
-        `_tape_functions_for_higher_order`.
-
-    Returns:
-      The call output.
-    """
-    ctx = context.context()
-
-    self._register_gradient()
-    with ops.get_default_graph().gradient_override_map(
-        {"PartitionedCall": self._gradient_name,
-         "StatefulPartitionedCall": self._gradient_name}):
-      outputs = forward_function.call(ctx, args)
-
-    if isinstance(outputs, ops.Operation) or outputs is None:
-      return outputs
-
-    # `real_outputs` are the actual outputs of the inference graph function;
-    # `side_outputs` are the intermediate Tensors that were added as outputs to
-    # the forward graph function so that we can compute its gradient.
-    real_outputs = outputs[:self._num_outputs]
-
-    capture_mapping = dict(zip(self._func_graph.outputs, outputs))
-    remapped_captures = [
-        capture_mapping.get(capture, capture)
-        for capture in backward_function.captured_inputs]
-    # We may need to use zeros_like to get a zero for variant Tensors with
-    # unconnected gradients. We do that in advance so we don't have to hold on
-    # to the outputs themselves, which may not be needed otherwise.
-    variant_zeros_like = {}
-    backward_function_inputs = (
-        len(backward_function.inputs) - len(backward_function.captured_inputs))
-    recorded_outputs = []
-    trainable_recorded_outputs = 0
-    skip_positions = []
-    for output_index, output in enumerate(outputs):
-      if trainable_recorded_outputs < backward_function_inputs:
-        recorded_outputs.append(output)
-      if gradients_util.IsTrainable(output):
-        trainable_recorded_outputs += 1
+    possible_gradient_type = _PossibleTapeGradientTypes(
+        pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
+    if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
+      if context.executing_eagerly():
+        # There is a single non-persistent tape active, so the user can only
+        # request first-order gradients from a tape. We can spend less time
+        # graph building since we know this.
+        #
+        # We may still end up computing higher-order gradients, but that'd be
+        # through `tf.gradients`, which can re-write the forward pass and so
+        # needs no preparation here.
+        return self._first_order_tape_functions
       else:
-        skip_positions.append(output_index)
-      if output.dtype == dtypes.variant:
-        variant_zeros_like[output_index] = default_gradient.zeros_like(output)
-
-    def _backward_function_wrapper(*args):
-      """Process output gradients and call the backward function."""
-      processed_args = []
-      input_index = 0
-      for output_index, arg in enumerate(args):
-        if output_index in skip_positions:
-          continue
-        if arg is None:
-          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
-          # have a Tensor value for each Tensor we thought would be trainable
-          # based on its dtype, even if it ended up being unconnected.
-          input_placeholder = backward_function.inputs[
-              input_index]
-          if input_placeholder.dtype == dtypes.variant:
-            arg = variant_zeros_like[output_index]
-          else:
-            arg = array_ops.zeros(
-                *default_gradient.shape_and_dtype(input_placeholder))
-        processed_args.append(arg)
-        input_index += 1
-      return backward_function._call_flat(  # pylint: disable=protected-access
-          processed_args, remapped_captures)
-
-    tape.record_operation(forward_function.signature.name,
-                          recorded_outputs, args, _backward_function_wrapper)
-    return self._build_call_outputs(real_outputs)
-
-  def _backprop_call_with_delayed_rewrite(self, args):
-    """Calls the inference function and records the result on a tape.
-
-    The recorded backwards function will construct the backwards graph and
-    rewrite the inference function to the forward function. This only happens
-    if the recorded backwards function ends up being used to compute gradients.
-
-    This approach avoids constructing unnecessary graphs, but it only works if
-    we are calling this function when not executing eagerly.
-
-    (Only records results on a tape if the function has outputs)
-
-    Args:
-      args: All inputs to the function, including resolved captured inputs
-
-    Returns:
-      The call output.
-    """
-    ctx = context.context()
-
-    self._register_gradient()
-    with ops.get_default_graph().gradient_override_map(
-        {"PartitionedCall": self._gradient_name,
-         "StatefulPartitionedCall": self._gradient_name}):
-      outputs = self._inference_function.call(ctx, args)
-
-    if isinstance(outputs, ops.Operation) or outputs is None:
-      return outputs
-
-    call_op = outputs[0].op
-
-    def backward_function(*args):
-      return self._grad_fn(call_op, *args)
-
-    tape.record_operation(self._inference_function.signature.name, outputs,
-                          args, backward_function)
-    return self._build_call_outputs(outputs)
+        # We can avoid computing second-order gradients in some cases by doing a
+        # delayed rewrite when graph building. Since we know we'll only compute
+        # first-order tape gradients, the delayed rewrite is safe: we won't need
+        # to tell the tape about side outputs.
+        #
+        # TODO(allenl): This case is really dirty. It would be better if we
+        # could temporarily pop all of the current tapes to avoid
+        # accidentally taking second-order gradients.
+        return self._delayed_rewrite_functions
+    elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
+      # Either there's a persistent tape watching, or there are multiple nested
+      # tapes. Either way, the user may request higher-order gradients. We'll
+      # spend a bit more time and make sure higher-order gradients are correct.
+      return self._higher_order_tape_functions
+    # else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
+    # tape is recording.
+    return self._delayed_rewrite_functions
 
   def _build_call_outputs(self, result):
     """Maps the fdef output list to actual output structure.
@@ -1800,7 +1884,8 @@
       args = self.input_signature
       kwargs = {}
     seen_names = set()
-    captured = frozenset(graph_function.graph.internal_captures)
+    captured = object_identity.ObjectIdentitySet(
+        graph_function.graph.internal_captures)
     # pylint: disable=protected-access
     graph_function._arg_keywords = []
     prefix_counts = {}
@@ -1943,17 +2028,12 @@
             arg_names=arg_names,
             override_flat_arg_shapes=override_flat_arg_shapes,
             capture_by_value=self._capture_by_value),
-        self._function_attributes)
-
-    # pylint: disable=protected-access
-    # Tell the ConcreteFunction to clean up its graph once it goes out of
-    # scope. ConcreteFunction does not do this in its constructor since it
-    # gets used in some places (like Keras) where the FuncGraph lives
-    # longer than the ConcreteFunction.
-    graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
-        graph_function.graph)
-    # pylint: enable=protected-access
-
+        self._function_attributes,
+        # Tell the ConcreteFunction to clean up its graph once it goes out of
+        # scope. This is not the default behavior since it gets used in some
+        # places (like Keras) where the FuncGraph lives longer than the
+        # ConcreteFunction.
+        shared_func_graph=False)
     return graph_function
 
   def _define_function_with_shape_relaxation(self, args, kwargs):
@@ -2083,7 +2163,8 @@
     raise ValueError("Only defun function is allowed to be registered. "
                      "Got type: %s" % type(func))
   concrete_func = func.get_concrete_function(*args, **kwargs)
-  concrete_func.add_to_graph(register_gradient_functions=True)
+  concrete_func.add_to_graph()
+  concrete_func.add_gradient_functions_to_graph()
   return concrete_func
 
 
diff --git a/tensorflow/python/eager/function_gradients_test.py b/tensorflow/python/eager/function_gradients_test.py
index f151bab..1b052ad 100644
--- a/tensorflow/python/eager/function_gradients_test.py
+++ b/tensorflow/python/eager/function_gradients_test.py
@@ -298,6 +298,24 @@
       y = f(x)
     self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
 
+  def testGraphLoopGradientInsideSession(self):
+    with ops.Graph().as_default():
+      n = constant_op.constant(2.0)
+      x = array_ops.placeholder(dtypes.float32, shape=None)
+
+      @def_function.function
+      def f():
+        c = lambda n: n < 10
+        b = lambda n: n * x
+        return control_flow_ops.while_loop(c, b, [n],
+                                           [tensor_shape.unknown_shape()])
+
+      l = f()
+      dx = gradients_impl.gradients(l, [x])[0]
+
+      with self.cached_session():
+        self.assertEqual(dx.eval(feed_dict={x: 2.0}), 24.0)
+
   def testDefunDifferentiable(self):
     v = resource_variable_ops.ResourceVariable(1.0)
 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a922baa..0a6b349 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -42,6 +42,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import function as tf_function
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
@@ -149,6 +150,15 @@
     r = add(x, v2)
     self.assertEqual(3.0, self.evaluate(r))
 
+  def testVariableOnly(self):
+    v = variables.Variable(1.0)
+    add = def_function.function(lambda x: x.assign_add(1.0))
+    r1 = add(v)
+    self.assertEqual(2.0, self.evaluate(r1))
+    c = constant_op.constant(1.0)
+    with self.assertRaisesRegexp(AttributeError, 'no attribute'):
+      add(c)
+
   def testExternalControlDependency(self):
     with ops.Graph().as_default(), self.test_session():
       v = variables.Variable(1.0)
@@ -287,7 +297,7 @@
     def f(_):
       return 1.0
 
-    with self.assertRaisesRegexp(TypeError, 'set'):
+    with self.assertRaisesRegexp(AttributeError, 'set'):
       f(set([]))
 
   def testFuncName(self):
@@ -2092,7 +2102,8 @@
     with context.graph_mode(), self.cached_session():
       with ops.get_default_graph().as_default():
         t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
-        composite.add_to_graph(register_gradient_functions=True)
+        composite.add_to_graph()
+        composite.add_gradient_functions_to_graph()
 
         graph = ops.get_default_graph()
         # pylint: disable=protected-access
@@ -2124,6 +2135,32 @@
         # is added.
         self.assertLen(graph._functions, 6)
 
+  @parameterized.named_parameters(
+      dict(testcase_name='Defun',
+           function_decorator=function.defun),
+      dict(testcase_name='DefFunction',
+           function_decorator=def_function.function))
+  def testEagerCaptures(self, function_decorator):
+    with context.eager_mode():
+      large_tensor = array_ops.ones(shape=(256,))
+      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
+
+      small_tensor = array_ops.ones(shape=(4,))
+      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
+
+      v = resource_variable_ops.ResourceVariable(0.0)
+
+    for captured, op_type in [(large_tensor, 'Placeholder'),
+                              (small_tensor, 'Const'), (v, 'Placeholder')]:
+      @function_decorator
+      def test_fn():
+        return captured + 1  # pylint: disable=cell-var-from-loop
+
+      g = test_fn.get_concrete_function().graph
+      internal_captures = g.internal_captures
+      self.assertLen(internal_captures, 1)
+      self.assertEqual(internal_captures[0].op.type, op_type)
+
   def testRegisterFunctionWithInputSignature(self):
     def matmul(x, y):
       return math_ops.matmul(x, y)
diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py
index a1c297e..6d71508 100644
--- a/tensorflow/python/eager/lift_to_graph.py
+++ b/tensorflow/python/eager/lift_to_graph.py
@@ -20,7 +20,6 @@
 from __future__ import print_function
 
 import collections
-import six
 
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -28,6 +27,7 @@
 from tensorflow.python.ops import op_selector
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import compat
+from tensorflow.python.util import object_identity
 
 
 UnliftableError = op_selector.UnliftableError
@@ -202,7 +202,7 @@
   op_map[s.op] = copied_placeholder.op
 
 
-def lift_to_graph(init_tensors,
+def lift_to_graph(tensors,
                   graph,
                   sources=None,
                   disallowed_placeholders=None,
@@ -213,7 +213,7 @@
   """Copies the tensor and all its inputs recursively to the outer graph.
 
   Args:
-    init_tensors: The Tensor to lift.
+    tensors: The Tensors to lift.
     graph: The graph to lift to.
     sources: Optional sequence of nodes to start from. If omitted the whole
       subgraph which feeds into `init_tensor` is lifted.
@@ -234,14 +234,18 @@
   Raises:
     UnliftableError: If a placeholder blocks lifting.
   """
-  variable_init_tensors = {i for i in init_tensors if isinstance(
-      i, resource_variable_ops.ResourceVariable)}
-  init_tensors = set(init_tensors).difference(variable_init_tensors)
-  base_graph = base_graph or list(init_tensors)[0].graph
-  op_map = op_map or {}
+  variable_init_tensors = []
+  init_tensors = []
+  for tensor in tensors:
+    if isinstance(tensor, resource_variable_ops.ResourceVariable):
+      variable_init_tensors.append(tensor)
+    else:
+      init_tensors.append(tensor)
+  base_graph = base_graph or init_tensors[0].graph
+  op_map = op_map or object_identity.ObjectIdentityDictionary()
 
   # Check that the initializer does not depend on any placeholders.
-  sources = set(sources or [])
+  sources = object_identity.ObjectIdentitySet(sources or [])
   visited_ops = set([x.op for x in sources])
   op_outputs = collections.defaultdict(set)
 
@@ -287,11 +291,15 @@
 
   # When lifting from one FuncGraph to another, we will need to capture the
   # relevant tensors as well.
-  captures = collections.OrderedDict()
+  captures = []
+  inverse_captures = object_identity.ObjectIdentityDictionary()
+  internal_captures = []
   if (isinstance(base_graph, func_graph.FuncGraph) and
       isinstance(graph, func_graph.FuncGraph)):
     captures = base_graph.captures
-  inverse_captures = {v: k for k, v in captures.items()}
+    for external_capture, internal_capture in captures:
+      inverse_captures[internal_capture] = external_capture
+    internal_captures = base_graph.internal_captures
 
   # ops_to_copy now holds a reverse topologically sorted list of ops which
   # ends in the initializer. We copy those to the outermost graph and
@@ -301,7 +309,7 @@
                   })  # Pass through variables.
     source_ops = set()
     # Add the sources in the same order as the original graph.
-    for s in six.itervalues(captures):
+    for s in internal_captures:
       if s in sources:
         sources.remove(s)
         source_ops.add(s.op)
diff --git a/tensorflow/python/eager/lift_to_graph_test.py b/tensorflow/python/eager/lift_to_graph_test.py
index 619b9dc..90db3eb 100644
--- a/tensorflow/python/eager/lift_to_graph_test.py
+++ b/tensorflow/python/eager/lift_to_graph_test.py
@@ -41,7 +41,7 @@
       return v1 + v2 + v3
 
     concrete_fn = fn.get_concrete_function()
-    original_captures = concrete_fn.graph.captures
+    original_captures = concrete_fn.graph.internal_captures
     outputs = concrete_fn.graph.outputs
 
     for _ in range(100):
@@ -49,11 +49,10 @@
 
       lift_to_graph.lift_to_graph(
           outputs, g, add_sources=True, handle_captures=True)
-      lifted_captures = g.captures
+      lifted_captures = g.internal_captures
       self.assertLen(lifted_captures, 3)
-      for original_capture, lifted_capture in zip(original_captures.values(),
-                                                  lifted_captures.values()):
-        self.assertEqual(original_capture.name, lifted_capture.name)
+      for original, lifted in zip(original_captures, lifted_captures):
+        self.assertEqual(original.name, lifted.name)
 
   def testClassAttrsRemoved(self):
     """Tests that _class attrs (from colocate_with()) are removed."""
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index 0a3eb2f..39697ec 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -273,24 +273,23 @@
   def testSilentCopy(self):
     # Temporarily replace the context
     # pylint: disable=protected-access
-    del context._context
-    context._context = context.Context()
+    old_context = context.context()
+    context._set_context(context.Context())
     try:
       config.set_device_policy('silent')
       cpu_tensor = constant_op.constant(1.0)
       gpu_tensor = cpu_tensor.gpu()
       self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)
     finally:
-      del context._context
-      context._context = context.Context()
+      context._set_context(old_context)
     # pylint: enable=protected-access
 
   @test_util.run_gpu_only
   def testSoftPlacement(self):
     # Temporarily replace the context
     # pylint: disable=protected-access
-    del context._context
-    context._context = context.Context()
+    old_context = context.context()
+    context._set_context(context.Context())
     try:
       config.set_device_policy('silent')
       config.set_soft_device_placement(True)
@@ -299,8 +298,7 @@
       self.assertEqual(result.device,
                        '/job:localhost/replica:0/task:0/device:GPU:0')
     finally:
-      del context._context
-      context._context = context.Context()
+      context._set_context(old_context)
     # pylint: enable=protected-access
 
   def testRandomUniform(self):
diff --git a/tensorflow/python/eager/profiler.py b/tensorflow/python/eager/profiler.py
index b40cec9..d906fc9 100644
--- a/tensorflow/python/eager/profiler.py
+++ b/tensorflow/python/eager/profiler.py
@@ -71,14 +71,9 @@
   with _profiler_lock:
     if _profiler is not None:
       raise ProfilerAlreadyRunningError('Another profiler is running.')
-    profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
     if context.default_execution_mode == context.EAGER_MODE:
       context.ensure_initialized()
-      pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
-          profiler_context,
-          context.context()._handle)  # pylint: disable=protected-access
-    _profiler = pywrap_tensorflow.TFE_NewProfiler(profiler_context)
-    pywrap_tensorflow.TFE_DeleteProfilerContext(profiler_context)
+    _profiler = pywrap_tensorflow.TFE_NewProfiler()
     if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
       logging.warning('Another profiler session is running which is probably '
                       'created by profiler server. Please avoid using profiler '
@@ -102,7 +97,7 @@
       raise ProfilerNotRunningError(
           'Cannot stop profiling. No profiler is running.')
     if context.default_execution_mode == context.EAGER_MODE:
-      context.async_wait()
+      context.context().executor.wait()
     with c_api_util.tf_buffer() as buffer_:
       pywrap_tensorflow.TFE_ProfilerSerializeToString(
           _profiler,
@@ -161,14 +156,9 @@
   Args:
     port: port profiler server listens to.
   """
-  profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
   if context.default_execution_mode == context.EAGER_MODE:
     context.ensure_initialized()
-    pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
-        profiler_context,
-        context.context()._handle)  # pylint: disable=protected-access
-  pywrap_tensorflow.TFE_StartProfilerServer(profiler_context, port)
-  pywrap_tensorflow.TFE_DeleteProfilerContext(profiler_context)
+  pywrap_tensorflow.TFE_StartProfilerServer(port)
 
 
 class Profiler(object):
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 40f7be5..b81edda 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/python/eager/pywrap_tensor.h"
 
 #include <stdlib.h>
+#include <string.h>
 
 #include "structmember.h"  // NOLINT // For PyMemberDef
 #include "tensorflow/c/c_api.h"
@@ -24,6 +25,7 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
 #include "tensorflow/python/eager/pywrap_tfe.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"
 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
@@ -40,18 +42,31 @@
 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
 PyObject* eager_tensor_profiler = nullptr;
 
-TFE_Context* GetContext(PyObject* ctx) {
-  TFE_Context* context =
-      reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
-  if (context == nullptr) {
+TFE_Context* GetContextHandle(PyObject* py_context) {
+  tensorflow::Safe_PyObjectPtr py_context_handle(
+      PyObject_GetAttrString(py_context, "_handle"));
+  if (py_context_handle == nullptr) {
+    // Current Python code makes sure this never happens. If it does, or
+    // becomes hard to maintain, we can call the ensure_initialized() method
+    // here.
+    PyErr_SetString(
+        PyExc_TypeError,
+        "Expected `context` argument in EagerTensor constructor to have a "
+        "`_handle` attribute but it did not. Was eager Context initialized?");
+    return nullptr;
+  }
+
+  auto* ctx = reinterpret_cast<TFE_Context*>(
+      PyCapsule_GetPointer(py_context_handle.get(), nullptr));
+  if (ctx == nullptr) {
     PyErr_SetString(PyExc_TypeError,
                     tensorflow::strings::StrCat(
                         "Expected context._handle to contain a PyCapsule "
                         "encoded pointer to TFE_Context. Got ",
-                        Py_TYPE(ctx)->tp_name)
+                        Py_TYPE(py_context_handle.get())->tp_name)
                         .c_str());
   }
-  return context;
+  return ctx;
 }
 
 // Convert a Python numpy.ndarray object to a TFE_TensorHandle.
@@ -105,41 +120,6 @@
   return ret;
 }
 
-TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
-                               PyObject* dev) {
-  const char* device = "";
-  if (dev != nullptr && dev != Py_None) {
-    device = PyBytes_AsString(dev);
-#if PY_MAJOR_VERSION >= 3
-    if (device == nullptr) {
-      PyErr_Clear();
-      device = PyUnicode_AsUTF8(dev);
-    }
-#endif
-    if (device == nullptr) {
-      PyErr_SetString(PyExc_TypeError,
-                      "Error parsing device argument to CopyToDevice");
-      return nullptr;
-    }
-  }
-  TFE_Context* context = GetContext(ctx);
-  if (context == nullptr) {  // PyErr already set by GetContext
-    return nullptr;
-  }
-  auto status = tensorflow::make_safe(TF_NewStatus());
-  TFE_TensorHandle* new_handle =
-      TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
-  if (TF_GetCode(status.get()) != TF_OK) {
-    PyErr_SetString(
-        PyExc_RuntimeError,
-        tensorflow::strings::StrCat("Error copying tensor to device: ", device,
-                                    ". ", TF_Message(status.get()))
-            .c_str());
-    return nullptr;
-  }
-  return new_handle;
-}
-
 // Helper function to convert `v` to a tensorflow::DataType and store it in
 // `*out`. Returns true on success, false otherwise.
 // Note that we assume that v is a python int (not long) representing a
@@ -168,6 +148,41 @@
 #endif
 }
 
+// PyObject->tensorflow::DataType conversion function to be used with
+// PyArg_Parse* APIs.
+int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
+  if (obj == Py_None) {
+    *dst = tensorflow::DataType::DT_INVALID;
+  } else if (!PyIntToDataType(obj, dst)) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        tensorflow::strings::StrCat(
+            "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
+            .c_str());
+    return 0;
+  }
+
+  return 1;
+}
+
+// Conversion function extracting a const char** device name from a PyObject.
+// The function should be used with PyArg_Parse* APIs.
+int ConvertDeviceName(PyObject* obj, const char** dst) {
+  if (obj == Py_None) {
+    *dst = nullptr;
+  } else {
+    auto device_name = TFE_GetPythonString(obj);
+    if (device_name == nullptr) {
+      PyErr_Clear();
+      PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
+      return 0;
+    }
+    *dst = device_name;
+  }
+
+  return 1;
+}
+
 }  // namespace
 
 namespace tensorflow {
@@ -252,8 +267,10 @@
   return new TFE_TensorHandle(handle);
 }
 
-TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
-                                       tensorflow::DataType dtype) {
+TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
+                                               PyObject* value,
+                                               tensorflow::DataType dtype,
+                                               const char* device_name) {
   tensorflow::Safe_PyObjectPtr value_decrefer;
   if (PyArray_IsScalar(value, Generic)) {
     // Convert numpy scalars to numpy arrays.
@@ -301,24 +318,22 @@
 
   if (handle == nullptr) return nullptr;
 
+  Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
   if (dtype != tensorflow::DT_INVALID &&
       dtype != static_cast<DataType>(handle_dtype)) {
     if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
-      Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
       handle = tensorflow::make_safe(
           tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
                                 static_cast<TF_DataType>(dtype), status.get()));
       if (TF_GetCode(status.get()) != TF_OK) {
-        PyErr_SetString(
-            PyExc_TypeError,
-            absl::StrCat(
-                "Error while casting from dtype ",
-                tensorflow::DataTypeString(static_cast<DataType>(handle_dtype)),
-                " to ",
-                tensorflow::DataTypeString(static_cast<DataType>(dtype)), ". ",
-                TF_Message(status.get()))
-                .c_str());
+        PyErr_SetString(PyExc_TypeError,
+                        absl::StrCat("Error while casting from dtype ",
+                                     tensorflow::DataTypeString(
+                                         static_cast<DataType>(handle_dtype)),
+                                     " to ", tensorflow::DataTypeString(dtype),
+                                     ". ", TF_Message(status.get()))
+                            .c_str());
         return nullptr;
       }
     } else {
@@ -333,9 +348,69 @@
     }
   }
 
+  // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
+  // memory. We approximate the same behavior for eager execution - keeping
+  // int32 tensors in host memory.
+  //
+  // We do so to preclude the need for callers into such kernels from having to
+  // explicitly place the int32 tensors in host memory. For example, without
+  // this, one needed:
+  //
+  // with tf.device('/gpu:0'):
+  //   ...// code here
+  //   with tf.device('/cpu:0'):
+  //     shape = tf.constant(...)
+  //   y = tf.random_uniform(shape)
+  //
+  // Without the CPU device block, tfe.ops.random_uniform would fail since the
+  // kernel expects the shape in host memory.
+  //
+  // With this support, we simplify the code:
+  //
+  // with tf.device('/gpu:0'):
+  //   y = tf.random_uniform(...)
+  //
+  // The approximation is not exact there are GPU kernels which do not require
+  // host memory for int32 tensors. This will lead to a discrepancy between
+  // eager and graph execution.
+  //
+  // To support remote execution copy int32 tensors to another CPU device.
+  // TODO(ashankar): Fix this.
+  if (device_name != nullptr &&
+      (TFE_TensorHandleDataType(handle.get()) != TF_INT32 ||
+       strstr(device_name, "/device:CPU:0") != nullptr)) {
+    // Note that this is a shallow copy and will share the underlying buffer
+    // if copying to the same device.
+    handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
+                                                    device_name, status.get()));
+    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
+      return nullptr;
+    }
+  }
+
   return handle.release();
 }
 
+TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
+                                       DataType dtype,
+                                       const char* device_name) {
+  // Reduce the overhead of allocation/transfer-to-device for scalars by
+  // caching the corresponding handles. Note that currently only Python
+  // scalars are cached.
+  // TODO(slebedev): also cache singleton NumPy arrays and scalars?
+  if (PyArray_IsPythonNumber(value)) {
+    auto* cache = TFE_TensorHandleCache::Get();
+    TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name);
+    if (handle != nullptr) return handle;
+    handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
+    if (handle == nullptr) return nullptr;
+    cache->Insert(value, dtype, device_name, handle);
+    return handle;
+  } else {
+    return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
+  }
+}
+
 }  // namespace tensorflow
 
 extern "C" {
@@ -433,106 +508,23 @@
   self->weakreflist = nullptr;
   self->context = nullptr;
   PyObject* value;
-  PyObject* context = nullptr;
-  PyObject* device = nullptr;
-  PyObject* dtype = Py_None;
-  PyObject* other_value = nullptr;
-  const char* kwlist[] = {"value", "context",     "device",
-                          "dtype", "other_value", nullptr};
-  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
-                                   const_cast<char**>(kwlist), &value, &context,
-                                   &device, &dtype, &other_value)) {
+  const char* device_name = nullptr;
+  tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
+  const char* kwlist[] = {"value", "device", "dtype", nullptr};
+  if (!PyArg_ParseTupleAndKeywords(
+          args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
+          ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
     return -1;
   }
 
-  tensorflow::Safe_PyObjectPtr context_handle(
-      PyObject_GetAttrString(context, "_handle"));
-  if (context_handle == nullptr) {
-    // Current Python code makes sure this never happens. If it does, or
-    // becomes hard to maintain, we can call the ensure_initialized() method
-    // here.
-    PyErr_SetString(
-        PyExc_TypeError,
-        "Expected `context` argument in EagerTensor constructor to have a "
-        "`_handle` field but it did not. Was eager Context initialized?");
-    return -1;
-  }
-  self->context = context;
-  Py_INCREF(self->context);
+  PyObject* py_context = GetPyEagerContext();
+  if (py_context == nullptr) return -1;
+  self->context = py_context;
 
-  if (other_value != nullptr) {
-    if (!EagerTensor_CheckExact(other_value)) {
-      PyErr_SetString(PyExc_TypeError,
-                      tensorflow::strings::StrCat(
-                          "Expecting an EagerTensor for other_value, got ",
-                          Py_TYPE(other_value)->tp_name)
-                          .c_str());
-
-      return -1;
-    }
-    EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
-    self->handle =
-        TFE_TensorHandleCopySharingTensor(other->handle, self->status);
-
-    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
-      return -1;
-    }
-
-    return 0;
-  }
-
-  // Extract dtype
-  tensorflow::DataType desired_dtype = tensorflow::DT_INVALID;
-  if (dtype != Py_None) {
-    if (!PyIntToDataType(dtype, &desired_dtype)) {
-      PyErr_SetString(PyExc_TypeError,
-                      tensorflow::strings::StrCat(
-                          "Expecting a DataType value for dtype. Got ",
-                          Py_TYPE(dtype)->tp_name)
-                          .c_str());
-      return -1;
-    }
-  }
-  PyErr_Clear();
-  tensorflow::Safe_TFE_TensorHandlePtr handle =
-      tensorflow::make_safe(tensorflow::ConvertToEagerTensor(
-          GetContext(context_handle.get()), value, desired_dtype));
+  auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
+                                                  value, dtype, device_name);
   if (handle == nullptr) return -1;
-
-  // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
-  // memory. We approximate the same behavior for eager execution - keeping
-  // int32 tensors in host memory.
-  //
-  // We do so to preclude the need for callers into such kernels from having to
-  // explicitly place the int32 tensors in host memory. For example, without
-  // this, one needed:
-  //
-  // with tf.device('/gpu:0'):
-  //   ...// code here
-  //   with tf.device('/cpu:0'):
-  //     shape = tf.constant(...)
-  //   y = tf.random_uniform(shape)
-  //
-  // Without the CPU device block, tfe.ops.random_uniform would fail since the
-  // kernel expects the shape in host memory.
-  //
-  // With this support, we simplify the code:
-  //
-  // with tf.device('/gpu:0'):
-  //   y = tf.random_uniform(...)
-  //
-  // The approximation is not exact there are GPU kernels which do not require
-  // host memory for int32 tensors. This will lead to a discrepancy between
-  // eager and graph execution.
-  // TODO(ashankar): Fix this.
-  if (TFE_TensorHandleDataType(handle.get()) != TF_INT32) {
-    // Note that this is a shallow copy and will share the underlying buffer
-    // if copying to the same device.
-    handle = tensorflow::make_safe(
-        CopyToDevice(handle.get(), context_handle.get(), device));
-    if (handle == nullptr) return -1;
-  }
-  self->handle = handle.release();
+  self->handle = handle;
 
   if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
     return -1;
@@ -667,15 +659,24 @@
 // Function `_copy_to_device`.
 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
                                             PyObject* kwds) {
-  const char* kwlist[] = {"context", "device", nullptr};
-  PyObject* ctx = nullptr;
-  PyObject* dev = nullptr;
-  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
-                                   &ctx, &dev) ||
-      !ctx || !dev) {
+  if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;
+
+  const char* device_name = nullptr;
+  if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
+                        &device_name)) {
     return nullptr;
   }
-  auto handle = CopyToDevice(self->handle, ctx, dev);
+
+  // Note that this is a shallow copy and will share the underlying buffer
+  // if copying to the same device.
+  TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
+      self->handle, GetContextHandle(self->context), device_name, self->status);
+  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_RuntimeError)) {
+    // Cleanup self->status before returning.
+    TF_SetStatus(self->status, TF_OK, "");
+    return nullptr;
+  }
+
   return EagerTensorFromHandle(handle);
 }
 
@@ -900,13 +901,13 @@
     t->handle = handle;
     t->status = TF_NewStatus();
     t->weakreflist = nullptr;
-    PyObject* context = GetPyEagerContext();
-    if (context == nullptr) {
+    PyObject* py_context = GetPyEagerContext();
+    if (py_context == nullptr) {
       LOG(ERROR) << "Cannot create an eager tensor before eager context has "
                     "been set or after it has been deleted";
       return nullptr;
     }
-    t->context = context;
+    t->context = py_context;
 
     if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
       return nullptr;
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 53c0d77..0a46217 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -27,10 +27,15 @@
 
 namespace tensorflow {
 
-// Converts value to a TFE_TensorHandle of a given dtype. Note that the
-// resulting handle is always allocated on CPU.
+// Converts a value to a TFE_TensorHandle of a given dtype. The handle is
+// first allocated on CPU and then copied to a device identified by
+// device_name, unless it is nullptr.
+//
+// Note that an DT_INT32 handle is always kept on CPU regardless of the
+// device_name argument.
 TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
-                                       DataType dtype);
+                                       DataType dtype,
+                                       const char* device_name = nullptr);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.cc b/tensorflow/python/eager/pywrap_tensor_conversion.cc
new file mode 100644
index 0000000..90bd62a
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor_conversion.cc
@@ -0,0 +1,69 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/hash/hash.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+auto* scalar_cache_hits = tensorflow::monitoring::Counter<0>::New(
+    "/tensorflow/eager/python/scalar_cache_hits",
+    "Number of times a scalar TFE_TensorHandle was retrieved from cache");
+auto* scalar_cache_misses = tensorflow::monitoring::Counter<0>::New(
+    "/tensorflow/eager/python/scalar_cache_misses",
+    "Number of times a scalar TFE_TensorHandle was not available in cache");
+
+TFE_TensorHandleCache* TFE_TensorHandleCache::Get() {
+  // TODO(slebedev): link with Context (in context.py) instead of having
+  // a static global?
+  static auto* cache = new TFE_TensorHandleCache();
+  return cache;
+}
+
+TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
+    PyObject* value, tensorflow::DataType dtype,
+    absl::string_view device_name) const {
+  CHECK_NOTNULL(value);
+  const auto& it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
+  if (it == cache.end()) {
+    scalar_cache_misses->GetCell()->IncrementBy(1);
+    return nullptr;
+  }
+
+  scalar_cache_hits->GetCell()->IncrementBy(1);
+  auto* handle = it->second;
+  handle->Ref();
+  return new TFE_TensorHandle(handle);
+}
+
+void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
+                                   absl::string_view device_name,
+                                   TFE_TensorHandle* handle) {
+  Py_INCREF(value);
+  handle->handle->Ref();
+  cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, handle->handle);
+}
+
+void TFE_TensorHandleCache::Clear() {
+  DecrefUnrefAll();
+  cache.clear();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.h b/tensorflow/python/eager/pywrap_tensor_conversion.h
new file mode 100644
index 0000000..5caf68c
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor_conversion.h
@@ -0,0 +1,101 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
+#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
+
+// Place `<locale>` before <Python.h> to avoid build failure in macOS.
+#include <locale>
+
+// The empty line above is on purpose as otherwise clang-format will
+// automatically move <Python.h> before <locale>.
+#include <Python.h>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/hash/hash.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace tensorflow {
+
+// Wrapper-class allowing to use Python hashing/comparison functions
+// for PyObject*.
+//
+// Note that unlike Safe_PyObjectPtr this class does not steal a
+// reference to a Python object. The caller is responsible for doing
+// Py_INCREF/Py_DECREF.
+struct PyObjectPtr {
+  template <typename H>
+  friend H AbslHashValue(H h, const PyObjectPtr& obj) {
+    return H::combine(std::move(h), PyObject_Hash(obj.ptr));
+  }
+
+  explicit PyObjectPtr(PyObject* ptr) : ptr(ptr) {}
+
+  explicit inline operator PyObject*() const { return ptr; }
+
+  inline bool operator==(const PyObjectPtr& other) const {
+    // We require exact type equality to account for 0 == 0.0 == False.
+    if (Py_TYPE(ptr) != Py_TYPE(other.ptr)) {
+      return false;
+    }
+
+    bool result = PyObject_RichCompareBool(ptr, other.ptr, Py_EQ) > 0;
+    CHECK(!PyErr_Occurred());
+    return result;
+  }
+
+ private:
+  PyObject* ptr;
+};
+
+// Cache mapping PyObject* to the corresponding on-device TFE_TensorHandles.
+// Used to speed up ConvertToEagerTensor for scalars.
+// TODO(slebedev): move ConvertToEagerTensor here.
+struct TFE_TensorHandleCache {
+  static TFE_TensorHandleCache* Get();
+
+  TFE_TensorHandleCache() { cache.reserve(64); }
+  ~TFE_TensorHandleCache() { DecrefUnrefAll(); }
+
+  TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
+                           absl::string_view device_name) const;
+
+  void Insert(PyObject* value, tensorflow::DataType dtype,
+              absl::string_view device_name, TFE_TensorHandle* handle);
+
+  void Clear();
+
+ private:
+  // TODO(slebedev): should the key depend on TFE_Context?
+  using Key = std::tuple<PyObjectPtr, tensorflow::DataType, absl::string_view>;
+
+  void DecrefUnrefAll() {
+    for (const auto& p : cache) {
+      Py_DECREF(static_cast<PyObject*>(std::get<0>(p.first)));
+      p.second->Unref();
+    }
+  }
+
+  // Not guarded by a mutex because the code is only used while the
+  // GIL is held.
+  absl::flat_hash_map<Key, tensorflow::TensorHandle*> cache;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_CONVERSION_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 9b6ac1a..b3a4bb2 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1905,6 +1905,12 @@
       if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
         return;
       }
+      if (accumulator->accumulator->BusyAccumulating()) {
+        // Ensure inner accumulators don't see outer accumulators' jvps. This
+        // mostly happens on its own, with some potentially surprising
+        // exceptions, so the blanket policy is for consistency.
+        break;
+      }
     }
   }
 }
@@ -2365,7 +2371,6 @@
           {"Relu6", {true, {}}},
           {"Elu", {true, {}}},
           {"Selu", {true, {}}},
-          {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
           {"Neg", {true, {}}},
           {"Inv", {true, {}}},
           {"Reciprocal", {true, {}}},
@@ -2383,6 +2388,7 @@
 
           // Ops that don't require a subset of inputs.
           {"FusedBatchNorm", {false, {2}}},
+          {"SparseSoftmaxCrossEntropyWithLogits", {false, {1}}},
       });
 
   auto it = m->find(op_name);
@@ -2698,25 +2704,15 @@
   // The hint comes from a supposedly similarly typed tensor.
   tensorflow::DataType dtype_hint = dtype_hint_getter();
 
-  tensorflow::Safe_TFE_TensorHandlePtr handle = tensorflow::make_safe(
-      tensorflow::ConvertToEagerTensor(op_exec_info.ctx, input, dtype_hint));
+  TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
+      op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
   if (handle == nullptr) {
     return MaybeRaiseExceptionFromTFStatus(status, nullptr);
   }
 
-  auto output_dtype = TFE_TensorHandleDataType(handle.get());
-  if (output_dtype != TF_INT32) {
-    // Note that this is a shallow copy and will share the underlying buffer
-    // if copying to the same device.
-    handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
-        handle.get(), op_exec_info.ctx, op_exec_info.device_name, status));
-    if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
-      return false;
-    }
-  }
-
-  output_handle->reset(EagerTensorFromHandle(handle.release()));
-  dtype_setter(static_cast<tensorflow::DataType>(output_dtype));
+  output_handle->reset(EagerTensorFromHandle(handle));
+  dtype_setter(
+      static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
 
   return true;
 }
@@ -3505,9 +3501,9 @@
 PyObject* weak_eager_context = nullptr;
 }  // namespace
 
-PyObject* TFE_Py_SetEagerContext(PyObject* python_context) {
+PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
   Py_XDECREF(weak_eager_context);
-  weak_eager_context = PyWeakref_NewRef(python_context, nullptr);
+  weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
   if (weak_eager_context == nullptr) {
     return nullptr;
   }
@@ -3516,14 +3512,14 @@
 
 PyObject* GetPyEagerContext() {
   if (weak_eager_context == nullptr) {
-    PyErr_SetString(PyExc_ValueError, "Python eager context is not set");
+    PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
     return nullptr;
   }
-  PyObject* context = PyWeakref_GET_OBJECT(weak_eager_context);
-  if (context == Py_None) {
-    LOG(ERROR) << "Eager context has been destroyed";
+  PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
+  if (py_context == Py_None) {
+    PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
     return nullptr;
   }
-  Py_INCREF(context);
-  return context;
+  Py_INCREF(py_context);
+  return py_context;
 }
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index cccec01..15dec68 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -24,6 +24,7 @@
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.distribute.cluster_resolver import cluster_resolver
 from tensorflow.python.eager import context
+from tensorflow.python.platform import remote_utils
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -73,11 +74,10 @@
 
 
 @tf_export("config.experimental_connect_to_cluster")
-def connect_to_cluster(
-    cluster_spec_or_resolver,
-    job_name="localhost",
-    task_index=0,
-    protocol="grpc"):
+def connect_to_cluster(cluster_spec_or_resolver,
+                       job_name="localhost",
+                       task_index=0,
+                       protocol=None):
   """Connects to the given cluster.
 
   Will make devices on the cluster available to use. Note that calling this more
@@ -92,8 +92,10 @@
       the cluster.
     job_name: The name of the local job.
     task_index: The local task index.
-    protocol: The communication protocol.
+    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
+      use the default from `python/platform/remote_utils.py`.
   """
+  protocol = protocol or remote_utils.get_default_communication_protocol()
   if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
     cluster_spec = cluster_spec_or_resolver
   elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index 1b13d0c..30746a8 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -18,6 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import random
+
+import numpy as np
+
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import remote
@@ -26,6 +30,7 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import server_lib
@@ -79,20 +84,35 @@
         cm.exception.message)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
+    self.skipTest('b/139212497')
 
     @def_function.function
     def ambiguous_device(i):
       with ops.device('cpu:0'):
         return i + constant_op.constant([2])
 
-    with self.assertRaises(errors.InvalidArgumentError) as cm:
+    with self.assertRaises(ValueError) as cm:
       with ops.device('/job:worker/replica:0/task:0/cpu:0'):
-        self.assertAllEqual(
-            ambiguous_device(constant_op.constant([2])).numpy(), [3])
+        ambiguous_device(constant_op.constant([2])).numpy()
 
     self.assertIn('the output node must match exactly one device',
                   cm.exception.message)
 
+  def testStreaming(self):
+    """A mini stress test for streaming - issuing many RPCs back to back."""
+    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
+      x = array_ops.ones([2, 2])
+      y = array_ops.zeros([2, 2])
+      num_iters = 200
+      for _ in range(num_iters):
+        y = x + y
+        # Ask for y's shape after every 10 additions on average.
+        # This exercises waiting for remote shape logic in TensorHandle.
+        if random.randint(1, 10) == 1:
+          _ = y.shape
+    np.testing.assert_array_equal(
+        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())
+
 
 class MultiWorkersTest(test.TestCase):
 
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index c433058..74b4b43 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -47,8 +47,7 @@
   if dtype is not None:
     dtype = dtype.as_datatype_enum
   try:
-    return ops.EagerTensor(
-        value, context=ctx, device=device, dtype=dtype)
+    return ops.EagerTensor(value, device=device, dtype=dtype)
   except core._NotOkStatusException as e:  # pylint: disable=protected-access
     raise core._status_to_exception(e.code, e.message)
 
@@ -68,35 +67,22 @@
     context.ensure_initialized()
     ctx = context.context()
     device = ctx.device_name
-    # Missing context.
-    with self.assertRaisesRegexp(
-        TypeError, r".*argument 'context' \(pos 2\).*"):
-      ops.EagerTensor(1, device=device)
     # Missing device.
-    with self.assertRaisesRegexp(
-        TypeError, r".*argument 'device' \(pos 3\).*"):
-      ops.EagerTensor(1, context=ctx)
+    with self.assertRaisesRegexp(TypeError, r".*argument 'device' \(pos 2\).*"):
+      ops.EagerTensor(1)
     # Bad dtype type.
     with self.assertRaisesRegexp(TypeError,
                                  "Expecting a DataType value for dtype. Got"):
-      ops.EagerTensor(1, context=ctx, device=device, dtype="1")
+      ops.EagerTensor(1, device=device, dtype="1")
 
     # Following errors happen when trying to copy to GPU.
     if not test_util.is_gpu_available():
       self.skipTest("No GPUs found")
 
     with ops.device("/device:GPU:0"):
-      device = ctx.device_name
-      # Bad context.
-      with self.assertRaisesRegexp(
-          TypeError,
-          "Expected `context` argument in EagerTensor constructor to have a "
-          "`_handle` field but it did not. Was eager Context initialized?"):
-        ops.EagerTensor(1.0, context=1, device=device)
       # Bad device.
-      with self.assertRaisesRegexp(
-          TypeError, "Error parsing device argument to CopyToDevice"):
-        ops.EagerTensor(1.0, context=ctx, device=1)
+      with self.assertRaisesRegexp(TypeError, "Error parsing device argument"):
+        ops.EagerTensor(1.0, device=1)
 
   def testNumpyValue(self):
     values = np.array([3.0])
@@ -122,8 +108,7 @@
     ctx = context.context()
     # Bad dtype value.
     with self.assertRaisesRegexp(TypeError, "Invalid dtype argument value"):
-      ops.EagerTensor(
-          values, context=ctx, device=ctx.device_name, dtype=12345)
+      ops.EagerTensor(values, device=ctx.device_name, dtype=12345)
 
   def testNumpyOrderHandling(self):
     n = np.array([[1, 2], [3, 4]], order="F")
diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py
index ad2f24e..625a7d3 100644
--- a/tensorflow/python/eager/wrap_function.py
+++ b/tensorflow/python/eager/wrap_function.py
@@ -22,9 +22,11 @@
 import weakref
 
 from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import struct_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.eager import lift_to_graph
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
@@ -34,8 +36,10 @@
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import nested_structure_coder
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -103,6 +107,14 @@
         graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
         graph.get_tensor_by_name(
             tensor_info.coo_sparse.dense_shape_tensor_name))
+  elif encoding == "composite_tensor":
+    struct_coder = nested_structure_coder.StructureCoder()
+    spec_proto = struct_pb2.StructuredValue(
+        type_spec_value=tensor_info.composite_tensor.type_spec)
+    spec = struct_coder.decode_proto(spec_proto)
+    components = [graph.get_tensor_by_name(component.name) for component in
+                  tensor_info.composite_tensor.components]
+    return spec._from_components(components)  # pylint: disable=protected-access
   else:
     raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
 
@@ -116,8 +128,7 @@
       trainable=old_variable.trainable,
       extra_handle_data=old_variable.handle)
   new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
-  graph.inputs.append(old_variable.handle)
-  graph.captures[new_variable.handle] = old_variable.handle
+  graph.add_capture(new_variable.handle, old_variable.handle)
   # Now that we've added the new variable to graph.captures,
   # graph.capture will use that cached value and do some post-processing
   # on the capture like recording it on the tape.
@@ -151,8 +162,9 @@
         ops.GraphKeys.GLOBAL_VARIABLES)
     local_collection_variables = ops.get_collection(
         ops.GraphKeys.LOCAL_VARIABLES)
-    existing_captures = set(graph.internal_captures)
-    lifted_variables = {}
+    existing_captures = object_identity.ObjectIdentitySet(
+        graph.internal_captures)
+    lifted_variables = object_identity.ObjectIdentityDictionary()
 
     def _should_lift_variable(v):
       return ((v._in_graph_mode  # pylint: disable=protected-access
@@ -242,15 +254,16 @@
     """
     # TODO(b/129646028): Add support for CompositeTensors.
     name = name or "pruned"
-    feeds = nest.map_structure(self.graph.as_graph_element, feeds)
-    flat_feeds = nest.flatten(feeds)
+    flat_feeds = nest.flatten(feeds, expand_composites=True)
+    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
     for f in flat_feeds:
       if not isinstance(f, ops.Tensor):
         raise ValueError("Feeds must be tensors.")
 
     # Ignoring all feeds that are captures allows prune to be called
     # using wrapped_func.inputs even when it uses variables
-    internal_captures = self.graph.internal_captures
+    internal_captures = object_identity.ObjectIdentitySet(
+        self.graph.internal_captures)
     flat_feeds = [f for f in flat_feeds if f not in internal_captures]
 
     operation_fetches = []
@@ -276,12 +289,13 @@
       elif isinstance(fetch, meta_graph_pb2.TensorInfo):
         tensor_infos.append(fetch)
         decoded = _get_element_from_tensor_info(fetch, self._func_graph)
-        if tensor_util.is_tensor(decoded):
+        if (tensor_util.is_tensor(decoded) or
+            isinstance(decoded, composite_tensor.CompositeTensor)):
           tensor_fetches.append(decoded)
         else:
           operation_fetches.append(decoded)
         return decoded
-      elif isinstance(fetch, ops.Tensor):
+      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
         tensor_fetches.append(fetch)
         return fetch
       else:
@@ -303,7 +317,7 @@
     lift_map = lift_to_graph.lift_to_graph(
         operation_fetches + tensor_fetches,
         pruned_graph,
-        sources=flat_feeds + internal_captures)
+        sources=flat_feeds + self.graph.internal_captures)
 
     # Note that we add the component tensors of any composite tensors to the
     # returned function's outputs list; the list must contain these component
@@ -311,10 +325,9 @@
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
-    for external_capture, internal_capture in self.graph.captures.items():
-      pruned_graph.captures[external_capture] = lift_map[internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
-    pruned_graph.inputs.extend(pruned_graph.captures.values())
+    for external_capture, internal_capture in self.graph.captures:
+      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
     for ti in tensor_infos:
       if ti.WhichOneof("encoding") == "name":  # Dense tensors only
         t = pruned_graph.as_graph_element(ti.name)
diff --git a/tensorflow/python/eager/wrap_function_test.py b/tensorflow/python/eager/wrap_function_test.py
index 1a135b3..4b592a5 100644
--- a/tensorflow/python/eager/wrap_function_test.py
+++ b/tensorflow/python/eager/wrap_function_test.py
@@ -36,6 +36,8 @@
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 from tensorflow.python.training import saver as saver_lib
 
@@ -84,6 +86,31 @@
     f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
     self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])
 
+  def testPruneRagged(self):
+
+    x_in = []
+    x_out = []
+
+    def f(x, y):
+      x_in.append(x)
+      xx = x * x
+      x_out.append(xx)
+      return xx, y * y
+
+    x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
+    y_spec = tensor_spec.TensorSpec((), dtypes.float32)
+
+    f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])
+
+    f_pruned = f_wrapped.prune(x_in[0], x_out[0])
+    rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
+    expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])
+
+    # Note: when we call f_pruned, we must pass the RaggedTensor in using
+    # its components, since that's the current convention for how concrete
+    # functions handle structured inputs.
+    self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)
+
   def _assert_single_captured_variable_argument(self, graph_def):
     # The single FunctionDef should have one argument, a captured variable
     function_def, = graph_def.library.function
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 0020eeb..1c6b837 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -6,7 +6,6 @@
 py_library(
     name = "estimator_py",
     srcs = [
-        "__init__.py",
         "estimator_lib.py",
     ],
     srcs_version = "PY2AND3",
@@ -383,7 +382,6 @@
 py_library(
     name = "inputs_queues",
     srcs = [
-        "inputs/queues/__init__.py",
         "inputs/queues/feeding_functions.py",
         "inputs/queues/feeding_queue_runner.py",
     ],
diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py
deleted file mode 100644
index 1e32161..0000000
--- a/tensorflow/python/estimator/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""estimator python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.python import estimator
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-estimator.__all__ = [s for s in dir(estimator) if not s.startswith('__')]
-
-from tensorflow_estimator.python.estimator import *
diff --git a/tensorflow/python/estimator/canned/__init__.py b/tensorflow/python/estimator/canned/__init__.py
deleted file mode 100644
index d640c8c..0000000
--- a/tensorflow/python/estimator/canned/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""canned python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.python.estimator import canned
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-canned.__all__ = [s for s in dir(canned) if not s.startswith('__')]
-
-from tensorflow_estimator.python.estimator.canned import *
diff --git a/tensorflow/python/estimator/export/__init__.py b/tensorflow/python/estimator/export/__init__.py
deleted file mode 100644
index 898efd4..0000000
--- a/tensorflow/python/estimator/export/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""export python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.python.estimator import export
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-export.__all__ = [s for s in dir(export) if not s.startswith('__')]
-
-from tensorflow_estimator.python.estimator.export import *
diff --git a/tensorflow/python/estimator/inputs/__init__.py b/tensorflow/python/estimator/inputs/__init__.py
deleted file mode 100644
index 045ede2..0000000
--- a/tensorflow/python/estimator/inputs/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""inputs python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.python.estimator import inputs
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-inputs.__all__ = [s for s in dir(inputs) if not s.startswith('__')]
-
-from tensorflow_estimator.python.estimator.inputs import *
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 1e8828d..38c3657 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -53,6 +53,8 @@
 py_library(
     name = "feature_column_v2",
     srcs = [
+        "dense_features.py",
+        "dense_features_v2.py",
         "feature_column_v2.py",
         "sequence_feature_column.py",
         "serialization.py",
@@ -115,6 +117,15 @@
     ],
 )
 
+tf_py_test(
+    name = "dense_features_test",
+    srcs = ["dense_features_test.py"],
+    additional_deps = [
+        ":feature_column_test_main_lib",
+    ],
+    tags = ["no_pip"],
+)
+
 py_library(
     name = "feature_column_test_main_lib",
     srcs = ["feature_column_test.py"],
@@ -156,6 +167,15 @@
     ],
 )
 
+tf_py_test(
+    name = "dense_features_v2_test",
+    srcs = ["dense_features_v2_test.py"],
+    additional_deps = [
+        ":feature_column_v2_test_main_lib",
+    ],
+    tags = ["no_pip"],
+)
+
 py_library(
     name = "feature_column_v2_test_main_lib",
     srcs = ["feature_column_v2_test.py"],
@@ -181,7 +201,6 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/estimator:estimator_py",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/python/feature_column/dense_features.py b/tensorflow/python/feature_column/dense_features.py
new file mode 100644
index 0000000..d150fcc
--- /dev/null
+++ b/tensorflow/python/feature_column/dense_features.py
@@ -0,0 +1,124 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A layer that produces a dense `Tensor` based on given `feature_columns`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export(v1=['keras.layers.DenseFeatures'])
+class DenseFeatures(fc._BaseFeaturesLayer):  # pylint: disable=protected-access
+  """A layer that produces a dense `Tensor` based on given `feature_columns`.
+
+  Generally a single example in training data is described with FeatureColumns.
+  At the first layer of the model, this column oriented data should be converted
+  to a single `Tensor`.
+
+  This layer can be called multiple times with different features.
+
+  This is the V1 version of this layer that uses variable_scope's to create
+  variables which works well with PartitionedVariables. Variable scopes are
+  deprecated in V2, so the V2 version uses name_scopes instead. But currently
+  that lacks support for partitioned variables. Use this if you need
+  partitioned variables.
+
+  Example:
+
+  ```python
+  price = numeric_column('price')
+  keywords_embedded = embedding_column(
+      categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
+  columns = [price, keywords_embedded, ...]
+  feature_layer = DenseFeatures(columns)
+
+  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
+  dense_tensor = feature_layer(features)
+  for units in [128, 64, 32]:
+    dense_tensor = tf.compat.v1.keras.layers.Dense(
+                       units, activation='relu')(dense_tensor)
+  prediction = tf.compat.v1.keras.layers.Dense(1)(dense_tensor)
+  ```
+  """
+
+  def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
+    """Constructs a DenseFeatures layer.
+
+    Args:
+      feature_columns: An iterable containing the FeatureColumns to use as
+        inputs to your model. All items should be instances of classes derived
+        from `DenseColumn` such as `numeric_column`, `embedding_column`,
+        `bucketized_column`, `indicator_column`. If you have categorical
+        features, you can wrap them with an `embedding_column` or
+        `indicator_column`.
+      trainable:  Boolean, whether the layer's variables will be updated via
+        gradient descent during training.
+      name: Name to give to the DenseFeatures.
+      **kwargs: Keyword arguments to construct a layer.
+
+    Raises:
+      ValueError: if an item in `feature_columns` is not a `DenseColumn`.
+    """
+    super(DenseFeatures, self).__init__(
+        feature_columns=feature_columns,
+        trainable=trainable,
+        name=name,
+        expected_column_type=fc.DenseColumn,
+        **kwargs)
+
+  @property
+  def _is_feature_layer(self):
+    return True
+
+  def _target_shape(self, input_shape, total_elements):
+    return (input_shape[0], total_elements)
+
+  def call(self, features, cols_to_output_tensors=None):
+    """Returns a dense tensor corresponding to the `feature_columns`.
+
+    Args:
+      features: A mapping from key to tensors. `FeatureColumn`s look up via
+        these keys. For example `numeric_column('price')` will look at 'price'
+        key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
+        on corresponding `FeatureColumn`.
+      cols_to_output_tensors: If not `None`, this will be filled with a dict
+        mapping feature columns to output tensors created.
+
+    Returns:
+      A `Tensor` which represents input layer of a model. Its shape
+      is (batch_size, first_layer_dimension) and its dtype is `float32`.
+      first_layer_dimension is determined based on given `feature_columns`.
+
+    Raises:
+      ValueError: If features are not a dictionary.
+    """
+    if not isinstance(features, dict):
+      raise ValueError('We expected a dictionary here. Instead we got: ',
+                       features)
+    transformation_cache = fc.FeatureTransformationCache(features)
+    output_tensors = []
+    for column in self._feature_columns:
+      with ops.name_scope(column.name):
+        tensor = column.get_dense_tensor(transformation_cache,
+                                         self._state_manager)
+        processed_tensors = self._process_dense_tensor(column, tensor)
+        if cols_to_output_tensors is not None:
+          cols_to_output_tensors[column] = processed_tensors
+        output_tensors.append(processed_tensors)
+    return self._verify_and_concat_tensors(output_tensors)
diff --git a/tensorflow/python/feature_column/dense_features_test.py b/tensorflow/python/feature_column/dense_features_test.py
new file mode 100644
index 0000000..bc9bea2
--- /dev/null
+++ b/tensorflow/python/feature_column/dense_features_test.py
@@ -0,0 +1,627 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dense_features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.feature_column import dense_features as df
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+
+
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
+  sess.run(variables_lib.global_variables_initializer())
+  sess.run(lookup_ops.tables_initializer())
+  return sess
+
+
+class DenseFeaturesTest(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_retrieving_input(self):
+    features = {'a': [0.]}
+    dense_features = df.DenseFeatures(fc.numeric_column('a'))
+    inputs = self.evaluate(dense_features(features))
+    self.assertAllClose([[0.]], inputs)
+
+  def test_reuses_variables(self):
+    with context.eager_mode():
+      sparse_input = sparse_tensor.SparseTensor(
+          indices=((0, 0), (1, 0), (2, 0)),
+          values=(0, 1, 2),
+          dense_shape=(3, 3))
+
+      # Create feature columns (categorical and embedding).
+      categorical_column = fc.categorical_column_with_identity(
+          key='a', num_buckets=3)
+      embedding_dimension = 2
+
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
+        del shape  # unused
+        del dtype  # unused
+        del partition_info  # unused
+        embedding_values = (
+            (1, 0),  # id 0
+            (0, 1),  # id 1
+            (1, 1))  # id 2
+        return embedding_values
+
+      embedding_column = fc.embedding_column(
+          categorical_column,
+          dimension=embedding_dimension,
+          initializer=_embedding_column_initializer)
+
+      dense_features = df.DenseFeatures([embedding_column])
+      features = {'a': sparse_input}
+
+      inputs = dense_features(features)
+      variables = dense_features.variables
+
+      # Sanity check: test that the inputs are correct.
+      self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
+
+      # Check that only one variable was created.
+      self.assertEqual(1, len(variables))
+
+      # Check that invoking dense_features on the same features does not create
+      # additional variables
+      _ = dense_features(features)
+      self.assertEqual(1, len(variables))
+      self.assertEqual(variables[0], dense_features.variables[0])
+
+  def test_feature_column_dense_features_gradient(self):
+    with context.eager_mode():
+      sparse_input = sparse_tensor.SparseTensor(
+          indices=((0, 0), (1, 0), (2, 0)),
+          values=(0, 1, 2),
+          dense_shape=(3, 3))
+
+      # Create feature columns (categorical and embedding).
+      categorical_column = fc.categorical_column_with_identity(
+          key='a', num_buckets=3)
+      embedding_dimension = 2
+
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
+        del shape  # unused
+        del dtype  # unused
+        del partition_info  # unused
+        embedding_values = (
+            (1, 0),  # id 0
+            (0, 1),  # id 1
+            (1, 1))  # id 2
+        return embedding_values
+
+      embedding_column = fc.embedding_column(
+          categorical_column,
+          dimension=embedding_dimension,
+          initializer=_embedding_column_initializer)
+
+      dense_features = df.DenseFeatures([embedding_column])
+      features = {'a': sparse_input}
+
+      def scale_matrix():
+        matrix = dense_features(features)
+        return 2 * matrix
+
+      # Sanity check: Verify that scale_matrix returns the correct output.
+      self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
+
+      # Check that the returned gradient is correct.
+      grad_function = backprop.implicit_grad(scale_matrix)
+      grads_and_vars = grad_function()
+      indexed_slice = grads_and_vars[0][0]
+      gradient = grads_and_vars[0][0].values
+
+      self.assertAllEqual([0, 1, 2], indexed_slice.indices)
+      self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+
+  def test_raises_if_empty_feature_columns(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'feature_columns must not be empty'):
+      df.DenseFeatures(feature_columns=[])(features={})
+
+  def test_should_be_dense_column(self):
+    with self.assertRaisesRegexp(ValueError, 'must be a .*DenseColumn'):
+      df.DenseFeatures(feature_columns=[
+          fc.categorical_column_with_hash_bucket('wire_cast', 4)
+      ])(
+          features={
+              'a': [[0]]
+          })
+
+  def test_does_not_support_dict_columns(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Expected feature_columns to be iterable, found dict.'):
+      df.DenseFeatures(feature_columns={'a': fc.numeric_column('a')})(
+          features={
+              'a': [[0]]
+          })
+
+  def test_bare_column(self):
+    with ops.Graph().as_default():
+      features = features = {'a': [0.]}
+      net = df.DenseFeatures(fc.numeric_column('a'))(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[0.]], self.evaluate(net))
+
+  def test_column_generator(self):
+    with ops.Graph().as_default():
+      features = features = {'a': [0.], 'b': [1.]}
+      columns = (fc.numeric_column(key) for key in features)
+      net = df.DenseFeatures(columns)(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[0., 1.]], self.evaluate(net))
+
+  def test_raises_if_duplicate_name(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Duplicate feature column name found for columns'):
+      df.DenseFeatures(
+          feature_columns=[fc.numeric_column('a'),
+                           fc.numeric_column('a')])(
+                               features={
+                                   'a': [[0]]
+                               })
+
+  def test_one_column(self):
+    price = fc.numeric_column('price')
+    with ops.Graph().as_default():
+      features = {'price': [[1.], [5.]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1.], [5.]], self.evaluate(net))
+
+  def test_multi_dimension(self):
+    price = fc.numeric_column('price', shape=2)
+    with ops.Graph().as_default():
+      features = {'price': [[1., 2.], [5., 6.]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
+
+  def test_compute_output_shape(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2', shape=4)
+    with ops.Graph().as_default():
+      features = {
+          'price1': [[1., 2.], [5., 6.]],
+          'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+      }
+      dense_features = df.DenseFeatures([price1, price2])
+      self.assertEqual((None, 6), dense_features.compute_output_shape((None,)))
+      net = dense_features(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]],
+                          self.evaluate(net))
+
+  def test_raises_if_shape_mismatch(self):
+    price = fc.numeric_column('price', shape=2)
+    with ops.Graph().as_default():
+      features = {'price': [[1.], [5.]]}
+      with self.assertRaisesRegexp(
+          Exception,
+          r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+        df.DenseFeatures([price])(features)
+
+  def test_reshaping(self):
+    price = fc.numeric_column('price', shape=[1, 2])
+    with ops.Graph().as_default():
+      features = {'price': [[[1., 2.]], [[5., 6.]]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
+
+  def test_multi_column(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+      net = df.DenseFeatures([price1, price2])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
+
+  def test_cols_to_output_tensors(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      cols_dict = {}
+      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+      dense_features = df.DenseFeatures([price1, price2])
+      net = dense_features(features, cols_dict)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]],
+                          self.evaluate(cols_dict[price1]))
+      self.assertAllClose([[3.], [4.]], self.evaluate(cols_dict[price2]))
+      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
+
+  def test_column_order(self):
+    price_a = fc.numeric_column('price_a')
+    price_b = fc.numeric_column('price_b')
+    with ops.Graph().as_default():
+      features = {
+          'price_a': [[1.]],
+          'price_b': [[3.]],
+      }
+      net1 = df.DenseFeatures([price_a, price_b])(features)
+      net2 = df.DenseFeatures([price_b, price_a])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 3.]], self.evaluate(net1))
+      self.assertAllClose([[1., 3.]], self.evaluate(net2))
+
+  def test_fails_for_categorical_column(self):
+    animal = fc.categorical_column_with_identity('animal', num_buckets=4)
+    with ops.Graph().as_default():
+      features = {
+          'animal':
+              sparse_tensor.SparseTensor(
+                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+      }
+      with self.assertRaisesRegexp(Exception, 'must be a .*DenseColumn'):
+        df.DenseFeatures([animal])(features)
+
+  def test_static_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': [[1.], [5.], [7.]],  # batchsize = 3
+          'price2': [[3.], [4.]]  # batchsize = 2
+      }
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
+        df.DenseFeatures([price1, price2])(features)
+
+  def test_subset_of_static_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    price3 = fc.numeric_column('price3')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
+          'price2': [[3.], [4.]],  # batchsize = 2
+          'price3': [[3.], [4.], [5.]]  # batchsize = 3
+      }
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
+        df.DenseFeatures([price1, price2, price3])(features)
+
+  def test_runtime_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
+          'price2': [[3.], [4.]]  # batchsize = 2
+      }
+      net = df.DenseFeatures([price1, price2])(features)
+      with _initialized_session() as sess:
+        with self.assertRaisesRegexp(errors.OpError,
+                                     'Dimensions of inputs should match'):
+          sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+  def test_runtime_batch_size_matches(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
+          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
+      }
+      net = df.DenseFeatures([price1, price2])(features)
+      with _initialized_session() as sess:
+        sess.run(
+            net,
+            feed_dict={
+                features['price1']: [[1.], [5.]],
+                features['price2']: [[1.], [5.]],
+            })
+
+  def test_multiple_layers_with_same_embedding_column(self):
+    some_sparse_column = fc.categorical_column_with_hash_bucket(
+        'sparse_feature', hash_bucket_size=5)
+    some_embedding_column = fc.embedding_column(
+        some_sparse_column, dimension=10)
+
+    with ops.Graph().as_default():
+      features = {
+          'sparse_feature': [['a'], ['x']],
+      }
+      all_cols = [some_embedding_column]
+      df.DenseFeatures(all_cols)(features)
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that 2 variables get created in this case.
+      self.assertEqual(2,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      expected_var_names = [
+          'dense_features/sparse_feature_embedding/embedding_weights:0',
+          'dense_features_1/sparse_feature_embedding/embedding_weights:0'
+      ]
+      self.assertItemsEqual(
+          expected_var_names,
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_multiple_layers_with_same_shared_embedding_column(self):
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
+        [categorical_column_b, categorical_column_a],
+        dimension=embedding_dimension)
+
+    with ops.Graph().as_default():
+      features = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+      all_cols = [embedding_column_a, embedding_column_b]
+      df.DenseFeatures(all_cols)(features)
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      self.assertItemsEqual(
+          ['aaa_bbb_shared_embedding:0'],
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
+        [categorical_column_b, categorical_column_a],
+        dimension=embedding_dimension)
+    all_cols = [embedding_column_a, embedding_column_b]
+
+    with ops.Graph().as_default():
+      features = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+
+    with ops.Graph().as_default():
+      features1 = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+
+      df.DenseFeatures(all_cols)(features1)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      self.assertItemsEqual(
+          ['aaa_bbb_shared_embedding:0'],
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_with_1d_sparse_tensor(self):
+    embedding_values = (
+        (1., 2., 3., 4., 5.),  # id 0
+        (6., 7., 8., 9., 10.),  # id 1
+        (11., 12., 13., 14., 15.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info=None):
+      del shape, dtype, partition_info
+      return embedding_values
+
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+
+    # one_hot_body_style has 3 dims in dense_features.
+    body_style = fc.categorical_column_with_vocabulary_list(
+        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+    one_hot_body_style = fc.indicator_column(body_style)
+
+    # embedded_body_style has 5 dims in dense_features.
+    country = fc.categorical_column_with_vocabulary_list(
+        'country', vocabulary_list=['US', 'JP', 'CA'])
+    embedded_country = fc.embedding_column(
+        country, dimension=5, initializer=_initializer)
+
+    # Provides 1-dim tensor and dense tensor.
+    features = {
+        'price':
+            constant_op.constant([
+                11.,
+                12.,
+            ]),
+        'body-style':
+            sparse_tensor.SparseTensor(
+                indices=((0,), (1,)),
+                values=('sedan', 'hardtop'),
+                dense_shape=(2,)),
+        # This is dense tensor for the categorical_column.
+        'country':
+            constant_op.constant(['CA', 'US']),
+    }
+    self.assertEqual(1, features['price'].shape.ndims)
+    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+    self.assertEqual(1, features['country'].shape.ndims)
+
+    net = df.DenseFeatures([price, one_hot_body_style, embedded_country])(
+        features)
+    self.assertEqual(1 + 3 + 5, net.shape[1])
+    with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `embedded_body_style`,
+      # `one_hot_body_style`, and `price` in order.
+      self.assertAllEqual([[0., 0., 1., 11., 12., 13., 14., 15., 11.],
+                           [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
+                          sess.run(net))
+
+  @test_util.run_deprecated_v1
+  def test_with_1d_unknown_shape_sparse_tensor(self):
+    embedding_values = (
+        (1., 2.),  # id 0
+        (6., 7.),  # id 1
+        (11., 12.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info=None):
+      del shape, dtype, partition_info
+      return embedding_values
+
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+
+    # one_hot_body_style has 3 dims in dense_features.
+    body_style = fc.categorical_column_with_vocabulary_list(
+        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+    one_hot_body_style = fc.indicator_column(body_style)
+
+    # embedded_body_style has 5 dims in dense_features.
+    country = fc.categorical_column_with_vocabulary_list(
+        'country', vocabulary_list=['US', 'JP', 'CA'])
+    embedded_country = fc.embedding_column(
+        country, dimension=2, initializer=_initializer)
+
+    # Provides 1-dim tensor and dense tensor.
+    features = {
+        'price': array_ops.placeholder(dtypes.float32),
+        'body-style': array_ops.sparse_placeholder(dtypes.string),
+        # This is dense tensor for the categorical_column.
+        'country': array_ops.placeholder(dtypes.string),
+    }
+    self.assertIsNone(features['price'].shape.ndims)
+    self.assertIsNone(features['body-style'].get_shape().ndims)
+    self.assertIsNone(features['country'].shape.ndims)
+
+    price_data = np.array([11., 12.])
+    body_style_data = sparse_tensor.SparseTensorValue(
+        indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+    country_data = np.array([['US'], ['CA']])
+
+    net = df.DenseFeatures([price, one_hot_body_style, embedded_country])(
+        features)
+    self.assertEqual(1 + 3 + 2, net.shape[1])
+    with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `embedded_body_style`,
+      # `one_hot_body_style`, and `price` in order.
+      self.assertAllEqual(
+          [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
+          sess.run(
+              net,
+              feed_dict={
+                  features['price']: price_data,
+                  features['body-style']: body_style_data,
+                  features['country']: country_data
+              }))
+
+  @test_util.run_deprecated_v1
+  def test_with_rank_0_feature(self):
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+    features = {
+        'price': constant_op.constant(0),
+    }
+    self.assertEqual(0, features['price'].shape.ndims)
+
+    # Static rank 0 should fail
+    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+      df.DenseFeatures([price])(features)
+
+    # Dynamic rank 0 should fail
+    features = {
+        'price': array_ops.placeholder(dtypes.float32),
+    }
+    net = df.DenseFeatures([price])(features)
+    self.assertEqual(1, net.shape[1])
+    with _initialized_session() as sess:
+      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+        sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/feature_column/dense_features_v2.py b/tensorflow/python/feature_column/dense_features_v2.py
new file mode 100644
index 0000000..3d17b48
--- /dev/null
+++ b/tensorflow/python/feature_column/dense_features_v2.py
@@ -0,0 +1,93 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A layer that produces a dense `Tensor` based on given `feature_columns`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.feature_column import dense_features
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export('keras.layers.DenseFeatures', v1=[])
+class DenseFeatures(dense_features.DenseFeatures):
+  """A layer that produces a dense `Tensor` based on given `feature_columns`.
+
+  Generally a single example in training data is described with FeatureColumns.
+  At the first layer of the model, this column oriented data should be converted
+  to a single `Tensor`.
+
+  This layer can be called multiple times with different features.
+
+  This is the V2 version of this layer that uses name_scopes to create
+  variables instead of variable_scopes. But this approach currently lacks
+  support for partitioned variables. In that case, use the V1 version instead.
+
+  Example:
+
+  ```python
+  price = numeric_column('price')
+  keywords_embedded = embedding_column(
+      categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
+  columns = [price, keywords_embedded, ...]
+  feature_layer = DenseFeatures(columns)
+
+  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
+  dense_tensor = feature_layer(features)
+  for units in [128, 64, 32]:
+    dense_tensor = tf.keras.layers.Dense(units, activation='relu')(dense_tensor)
+  prediction = tf.keras.layers.Dense(1)(dense_tensor)
+  ```
+  """
+
+  def __init__(self,
+               feature_columns,
+               trainable=True,
+               name=None,
+               **kwargs):
+    """Creates a DenseFeatures object.
+
+    Args:
+      feature_columns: An iterable containing the FeatureColumns to use as
+        inputs to your model. All items should be instances of classes derived
+        from `DenseColumn` such as `numeric_column`, `embedding_column`,
+        `bucketized_column`, `indicator_column`. If you have categorical
+        features, you can wrap them with an `embedding_column` or
+        `indicator_column`.
+      trainable:  Boolean, whether the layer's variables will be updated via
+        gradient descent during training.
+      name: Name to give to the DenseFeatures.
+      **kwargs: Keyword arguments to construct a layer.
+
+    Raises:
+      ValueError: if an item in `feature_columns` is not a `DenseColumn`.
+    """
+    super(DenseFeatures, self).__init__(
+        feature_columns=feature_columns,
+        trainable=trainable,
+        name=name,
+        **kwargs)
+    self._state_manager = fc._StateManagerImplV2(self, self.trainable)  # pylint: disable=protected-access
+
+  def build(self, _):
+    for column in self._feature_columns:
+      with ops.name_scope(column.name):
+        column.create_state(self._state_manager)
+    # We would like to call Layer.build and not _DenseFeaturesHelper.build.
+    # pylint: disable=protected-access
+    super(fc._BaseFeaturesLayer, self).build(None)  # pylint: disable=bad-super-call
diff --git a/tensorflow/python/feature_column/dense_features_v2_test.py b/tensorflow/python/feature_column/dense_features_v2_test.py
new file mode 100644
index 0000000..a281d8c
--- /dev/null
+++ b/tensorflow/python/feature_column/dense_features_v2_test.py
@@ -0,0 +1,627 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dense_features_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.feature_column import dense_features_v2 as df
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+
+
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
+  sess.run(variables_lib.global_variables_initializer())
+  sess.run(lookup_ops.tables_initializer())
+  return sess
+
+
+class DenseFeaturesTest(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_retrieving_input(self):
+    features = {'a': [0.]}
+    dense_features = df.DenseFeatures(fc.numeric_column('a'))
+    inputs = self.evaluate(dense_features(features))
+    self.assertAllClose([[0.]], inputs)
+
+  def test_reuses_variables(self):
+    with context.eager_mode():
+      sparse_input = sparse_tensor.SparseTensor(
+          indices=((0, 0), (1, 0), (2, 0)),
+          values=(0, 1, 2),
+          dense_shape=(3, 3))
+
+      # Create feature columns (categorical and embedding).
+      categorical_column = fc.categorical_column_with_identity(
+          key='a', num_buckets=3)
+      embedding_dimension = 2
+
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
+        del shape  # unused
+        del dtype  # unused
+        del partition_info  # unused
+        embedding_values = (
+            (1, 0),  # id 0
+            (0, 1),  # id 1
+            (1, 1))  # id 2
+        return embedding_values
+
+      embedding_column = fc.embedding_column(
+          categorical_column,
+          dimension=embedding_dimension,
+          initializer=_embedding_column_initializer)
+
+      dense_features = df.DenseFeatures([embedding_column])
+      features = {'a': sparse_input}
+
+      inputs = dense_features(features)
+      variables = dense_features.variables
+
+      # Sanity check: test that the inputs are correct.
+      self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
+
+      # Check that only one variable was created.
+      self.assertEqual(1, len(variables))
+
+      # Check that invoking dense_features on the same features does not create
+      # additional variables
+      _ = dense_features(features)
+      self.assertEqual(1, len(variables))
+      self.assertEqual(variables[0], dense_features.variables[0])
+
+  def test_feature_column_dense_features_gradient(self):
+    with context.eager_mode():
+      sparse_input = sparse_tensor.SparseTensor(
+          indices=((0, 0), (1, 0), (2, 0)),
+          values=(0, 1, 2),
+          dense_shape=(3, 3))
+
+      # Create feature columns (categorical and embedding).
+      categorical_column = fc.categorical_column_with_identity(
+          key='a', num_buckets=3)
+      embedding_dimension = 2
+
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
+        del shape  # unused
+        del dtype  # unused
+        del partition_info  # unused
+        embedding_values = (
+            (1, 0),  # id 0
+            (0, 1),  # id 1
+            (1, 1))  # id 2
+        return embedding_values
+
+      embedding_column = fc.embedding_column(
+          categorical_column,
+          dimension=embedding_dimension,
+          initializer=_embedding_column_initializer)
+
+      dense_features = df.DenseFeatures([embedding_column])
+      features = {'a': sparse_input}
+
+      def scale_matrix():
+        matrix = dense_features(features)
+        return 2 * matrix
+
+      # Sanity check: Verify that scale_matrix returns the correct output.
+      self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
+
+      # Check that the returned gradient is correct.
+      grad_function = backprop.implicit_grad(scale_matrix)
+      grads_and_vars = grad_function()
+      indexed_slice = grads_and_vars[0][0]
+      gradient = grads_and_vars[0][0].values
+
+      self.assertAllEqual([0, 1, 2], indexed_slice.indices)
+      self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+
+  def test_raises_if_empty_feature_columns(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'feature_columns must not be empty'):
+      df.DenseFeatures(feature_columns=[])(features={})
+
+  def test_should_be_dense_column(self):
+    with self.assertRaisesRegexp(ValueError, 'must be a .*DenseColumn'):
+      df.DenseFeatures(feature_columns=[
+          fc.categorical_column_with_hash_bucket('wire_cast', 4)
+      ])(
+          features={
+              'a': [[0]]
+          })
+
+  def test_does_not_support_dict_columns(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Expected feature_columns to be iterable, found dict.'):
+      df.DenseFeatures(feature_columns={'a': fc.numeric_column('a')})(
+          features={
+              'a': [[0]]
+          })
+
+  def test_bare_column(self):
+    with ops.Graph().as_default():
+      features = features = {'a': [0.]}
+      net = df.DenseFeatures(fc.numeric_column('a'))(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[0.]], self.evaluate(net))
+
+  def test_column_generator(self):
+    with ops.Graph().as_default():
+      features = features = {'a': [0.], 'b': [1.]}
+      columns = (fc.numeric_column(key) for key in features)
+      net = df.DenseFeatures(columns)(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[0., 1.]], self.evaluate(net))
+
+  def test_raises_if_duplicate_name(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Duplicate feature column name found for columns'):
+      df.DenseFeatures(
+          feature_columns=[fc.numeric_column('a'),
+                           fc.numeric_column('a')])(
+                               features={
+                                   'a': [[0]]
+                               })
+
+  def test_one_column(self):
+    price = fc.numeric_column('price')
+    with ops.Graph().as_default():
+      features = {'price': [[1.], [5.]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1.], [5.]], self.evaluate(net))
+
+  def test_multi_dimension(self):
+    price = fc.numeric_column('price', shape=2)
+    with ops.Graph().as_default():
+      features = {'price': [[1., 2.], [5., 6.]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
+
+  def test_compute_output_shape(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2', shape=4)
+    with ops.Graph().as_default():
+      features = {
+          'price1': [[1., 2.], [5., 6.]],
+          'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+      }
+      dense_features = df.DenseFeatures([price1, price2])
+      self.assertEqual((None, 6), dense_features.compute_output_shape((None,)))
+      net = dense_features(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]],
+                          self.evaluate(net))
+
+  def test_raises_if_shape_mismatch(self):
+    price = fc.numeric_column('price', shape=2)
+    with ops.Graph().as_default():
+      features = {'price': [[1.], [5.]]}
+      with self.assertRaisesRegexp(
+          Exception,
+          r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+        df.DenseFeatures([price])(features)
+
+  def test_reshaping(self):
+    price = fc.numeric_column('price', shape=[1, 2])
+    with ops.Graph().as_default():
+      features = {'price': [[[1., 2.]], [[5., 6.]]]}
+      net = df.DenseFeatures([price])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
+
+  def test_multi_column(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+      net = df.DenseFeatures([price1, price2])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
+
+  def test_cols_to_output_tensors(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      cols_dict = {}
+      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+      dense_features = df.DenseFeatures([price1, price2])
+      net = dense_features(features, cols_dict)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 2.], [5., 6.]],
+                          self.evaluate(cols_dict[price1]))
+      self.assertAllClose([[3.], [4.]], self.evaluate(cols_dict[price2]))
+      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
+
+  def test_column_order(self):
+    price_a = fc.numeric_column('price_a')
+    price_b = fc.numeric_column('price_b')
+    with ops.Graph().as_default():
+      features = {
+          'price_a': [[1.]],
+          'price_b': [[3.]],
+      }
+      net1 = df.DenseFeatures([price_a, price_b])(features)
+      net2 = df.DenseFeatures([price_b, price_a])(features)
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.evaluate(lookup_ops.tables_initializer())
+
+      self.assertAllClose([[1., 3.]], self.evaluate(net1))
+      self.assertAllClose([[1., 3.]], self.evaluate(net2))
+
+  def test_fails_for_categorical_column(self):
+    animal = fc.categorical_column_with_identity('animal', num_buckets=4)
+    with ops.Graph().as_default():
+      features = {
+          'animal':
+              sparse_tensor.SparseTensor(
+                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+      }
+      with self.assertRaisesRegexp(Exception, 'must be a .*DenseColumn'):
+        df.DenseFeatures([animal])(features)
+
+  def test_static_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': [[1.], [5.], [7.]],  # batchsize = 3
+          'price2': [[3.], [4.]]  # batchsize = 2
+      }
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
+        df.DenseFeatures([price1, price2])(features)
+
+  def test_subset_of_static_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    price3 = fc.numeric_column('price3')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
+          'price2': [[3.], [4.]],  # batchsize = 2
+          'price3': [[3.], [4.], [5.]]  # batchsize = 3
+      }
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
+        df.DenseFeatures([price1, price2, price3])(features)
+
+  def test_runtime_batch_size_mismatch(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
+          'price2': [[3.], [4.]]  # batchsize = 2
+      }
+      net = df.DenseFeatures([price1, price2])(features)
+      with _initialized_session() as sess:
+        with self.assertRaisesRegexp(errors.OpError,
+                                     'Dimensions of inputs should match'):
+          sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+  def test_runtime_batch_size_matches(self):
+    price1 = fc.numeric_column('price1')
+    price2 = fc.numeric_column('price2')
+    with ops.Graph().as_default():
+      features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
+          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
+      }
+      net = df.DenseFeatures([price1, price2])(features)
+      with _initialized_session() as sess:
+        sess.run(
+            net,
+            feed_dict={
+                features['price1']: [[1.], [5.]],
+                features['price2']: [[1.], [5.]],
+            })
+
+  def test_multiple_layers_with_same_embedding_column(self):
+    some_sparse_column = fc.categorical_column_with_hash_bucket(
+        'sparse_feature', hash_bucket_size=5)
+    some_embedding_column = fc.embedding_column(
+        some_sparse_column, dimension=10)
+
+    with ops.Graph().as_default():
+      features = {
+          'sparse_feature': [['a'], ['x']],
+      }
+      all_cols = [some_embedding_column]
+      df.DenseFeatures(all_cols)(features)
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that 2 variables get created in this case.
+      self.assertEqual(2,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      expected_var_names = [
+          'dense_features/sparse_feature_embedding/embedding_weights:0',
+          'dense_features_1/sparse_feature_embedding/embedding_weights:0'
+      ]
+      self.assertItemsEqual(
+          expected_var_names,
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_multiple_layers_with_same_shared_embedding_column(self):
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
+        [categorical_column_b, categorical_column_a],
+        dimension=embedding_dimension)
+
+    with ops.Graph().as_default():
+      features = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+      all_cols = [embedding_column_a, embedding_column_b]
+      df.DenseFeatures(all_cols)(features)
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      self.assertItemsEqual(
+          ['aaa_bbb_shared_embedding:0'],
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
+        [categorical_column_b, categorical_column_a],
+        dimension=embedding_dimension)
+    all_cols = [embedding_column_a, embedding_column_b]
+
+    with ops.Graph().as_default():
+      features = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+      df.DenseFeatures(all_cols)(features)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+
+    with ops.Graph().as_default():
+      features1 = {
+          'aaa':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(0, 1, 0),
+                  dense_shape=(2, 2)),
+          'bbb':
+              sparse_tensor.SparseTensor(
+                  indices=((0, 0), (1, 0), (1, 1)),
+                  values=(1, 2, 1),
+                  dense_shape=(2, 2)),
+      }
+
+      df.DenseFeatures(all_cols)(features1)
+      # Make sure that only 1 variable gets created in this case.
+      self.assertEqual(1,
+                       len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+      self.assertItemsEqual(
+          ['aaa_bbb_shared_embedding:0'],
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+  @test_util.run_deprecated_v1
+  def test_with_1d_sparse_tensor(self):
+    embedding_values = (
+        (1., 2., 3., 4., 5.),  # id 0
+        (6., 7., 8., 9., 10.),  # id 1
+        (11., 12., 13., 14., 15.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info=None):
+      del shape, dtype, partition_info
+      return embedding_values
+
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+
+    # one_hot_body_style has 3 dims in dense_features.
+    body_style = fc.categorical_column_with_vocabulary_list(
+        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+    one_hot_body_style = fc.indicator_column(body_style)
+
+    # embedded_body_style has 5 dims in dense_features.
+    country = fc.categorical_column_with_vocabulary_list(
+        'country', vocabulary_list=['US', 'JP', 'CA'])
+    embedded_country = fc.embedding_column(
+        country, dimension=5, initializer=_initializer)
+
+    # Provides 1-dim tensor and dense tensor.
+    features = {
+        'price':
+            constant_op.constant([
+                11.,
+                12.,
+            ]),
+        'body-style':
+            sparse_tensor.SparseTensor(
+                indices=((0,), (1,)),
+                values=('sedan', 'hardtop'),
+                dense_shape=(2,)),
+        # This is dense tensor for the categorical_column.
+        'country':
+            constant_op.constant(['CA', 'US']),
+    }
+    self.assertEqual(1, features['price'].shape.ndims)
+    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+    self.assertEqual(1, features['country'].shape.ndims)
+
+    net = df.DenseFeatures([price, one_hot_body_style, embedded_country])(
+        features)
+    self.assertEqual(1 + 3 + 5, net.shape[1])
+    with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `embedded_body_style`,
+      # `one_hot_body_style`, and `price` in order.
+      self.assertAllEqual([[0., 0., 1., 11., 12., 13., 14., 15., 11.],
+                           [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
+                          sess.run(net))
+
+  @test_util.run_deprecated_v1
+  def test_with_1d_unknown_shape_sparse_tensor(self):
+    embedding_values = (
+        (1., 2.),  # id 0
+        (6., 7.),  # id 1
+        (11., 12.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info=None):
+      del shape, dtype, partition_info
+      return embedding_values
+
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+
+    # one_hot_body_style has 3 dims in dense_features.
+    body_style = fc.categorical_column_with_vocabulary_list(
+        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+    one_hot_body_style = fc.indicator_column(body_style)
+
+    # embedded_body_style has 5 dims in dense_features.
+    country = fc.categorical_column_with_vocabulary_list(
+        'country', vocabulary_list=['US', 'JP', 'CA'])
+    embedded_country = fc.embedding_column(
+        country, dimension=2, initializer=_initializer)
+
+    # Provides 1-dim tensor and dense tensor.
+    features = {
+        'price': array_ops.placeholder(dtypes.float32),
+        'body-style': array_ops.sparse_placeholder(dtypes.string),
+        # This is dense tensor for the categorical_column.
+        'country': array_ops.placeholder(dtypes.string),
+    }
+    self.assertIsNone(features['price'].shape.ndims)
+    self.assertIsNone(features['body-style'].get_shape().ndims)
+    self.assertIsNone(features['country'].shape.ndims)
+
+    price_data = np.array([11., 12.])
+    body_style_data = sparse_tensor.SparseTensorValue(
+        indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+    country_data = np.array([['US'], ['CA']])
+
+    net = df.DenseFeatures([price, one_hot_body_style, embedded_country])(
+        features)
+    self.assertEqual(1 + 3 + 2, net.shape[1])
+    with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `embedded_body_style`,
+      # `one_hot_body_style`, and `price` in order.
+      self.assertAllEqual(
+          [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
+          sess.run(
+              net,
+              feed_dict={
+                  features['price']: price_data,
+                  features['body-style']: body_style_data,
+                  features['country']: country_data
+              }))
+
+  @test_util.run_deprecated_v1
+  def test_with_rank_0_feature(self):
+    # price has 1 dimension in dense_features
+    price = fc.numeric_column('price')
+    features = {
+        'price': constant_op.constant(0),
+    }
+    self.assertEqual(0, features['price'].shape.ndims)
+
+    # Static rank 0 should fail
+    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+      df.DenseFeatures([price])(features)
+
+    # Dynamic rank 0 should fail
+    features = {
+        'price': array_ops.placeholder(dtypes.float32),
+    }
+    net = df.DenseFeatures([price])(features)
+    self.assertEqual(1, net.shape[1])
+    with _initialized_session() as sess:
+      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+        sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index f783f21..ff33612 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -166,6 +166,7 @@
 from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.compat import collections_abc
 
 
 def _internal_input_layer(features,
@@ -2287,7 +2288,7 @@
   if isinstance(feature_columns, _FeatureColumn):
     feature_columns = [feature_columns]
 
-  if isinstance(feature_columns, collections.Iterator):
+  if isinstance(feature_columns, collections_abc.Iterator):
     feature_columns = list(feature_columns)
 
   if isinstance(feature_columns, dict):
diff --git a/tensorflow/python/feature_column/feature_column_lib.py b/tensorflow/python/feature_column/feature_column_lib.py
index 6b4dfe6..6a99584 100644
--- a/tensorflow/python/feature_column/feature_column_lib.py
+++ b/tensorflow/python/feature_column/feature_column_lib.py
@@ -18,7 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,line-too-long,wildcard-import
+# pylint: disable=unused-import,line-too-long,wildcard-import,g-bad-import-order
+# We import dense_features_v2 first so that the V1 DenseFeatures is the default
+# if users directly import feature_column_lib.
+from tensorflow.python.feature_column.dense_features_v2 import *
+from tensorflow.python.feature_column.dense_features import *
 from tensorflow.python.feature_column.feature_column import *
 from tensorflow.python.feature_column.feature_column_v2 import *
 from tensorflow.python.feature_column.sequence_feature_column import *
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 0ded2bf..e1bdef8 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -4394,8 +4394,7 @@
     id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
     self.assertIsNone(id_weight_pair.weight_tensor)
     with _initialized_session():
-      with self.assertRaisesRegexp(
-          errors.OpError, 'assert_greater_or_equal_0'):
+      with self.assertRaisesRegexp(errors.OpError, 'assert'):
         id_weight_pair.id_tensor.eval()
 
   @test_util.run_deprecated_v1
@@ -4408,8 +4407,7 @@
     id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
     self.assertIsNone(id_weight_pair.weight_tensor)
     with _initialized_session():
-      with self.assertRaisesRegexp(
-          errors.OpError, 'assert_less_than_num_buckets'):
+      with self.assertRaisesRegexp(errors.OpError, 'assert'):
         id_weight_pair.id_tensor.eval()
 
   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 260d0a2..8c6778d 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -167,8 +167,8 @@
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import keras_export
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.compat import collections_abc
 
 
 _FEATURE_COLUMN_DEPRECATION_DATE = None
@@ -318,6 +318,31 @@
     raise ValueError('Resource does not exist.')
 
 
+class _StateManagerImplV2(_StateManagerImpl):
+  """Manages the state of DenseFeatures."""
+
+  def create_variable(self,
+                      feature_column,
+                      name,
+                      shape,
+                      dtype=None,
+                      trainable=True,
+                      use_resource=True,
+                      initializer=None):
+    if name in self._cols_to_vars_map[feature_column]:
+      raise ValueError('Variable already exists.')
+
+    var = self._layer.add_variable(
+        name=name,
+        shape=shape,
+        dtype=dtype,
+        initializer=initializer,
+        trainable=self._trainable and trainable,
+        use_resource=use_resource)
+    self._cols_to_vars_map[feature_column][name] = var
+    return var
+
+
 class _BaseFeaturesLayer(Layer):
   """Base class for DenseFeatures and SequenceFeatures.
 
@@ -415,104 +440,6 @@
     return cls(**config_cp)
 
 
-@keras_export('keras.layers.DenseFeatures')
-class DenseFeatures(_BaseFeaturesLayer):
-  """A layer that produces a dense `Tensor` based on given `feature_columns`.
-
-  Generally a single example in training data is described with FeatureColumns.
-  At the first layer of the model, this column oriented data should be converted
-  to a single `Tensor`.
-
-  This layer can be called multiple times with different features.
-
-  Example:
-
-  ```python
-  price = numeric_column('price')
-  keywords_embedded = embedding_column(
-      categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
-  columns = [price, keywords_embedded, ...]
-  feature_layer = DenseFeatures(columns)
-
-  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
-  dense_tensor = feature_layer(features)
-  for units in [128, 64, 32]:
-    dense_tensor = tf.keras.layers.Dense(units, activation='relu')(dense_tensor)
-  prediction = tf.keras.layers.Dense(1)(dense_tensor)
-  ```
-  """
-
-  def __init__(self,
-               feature_columns,
-               trainable=True,
-               name=None,
-               **kwargs):
-    """Constructs a DenseFeatures.
-
-    Args:
-      feature_columns: An iterable containing the FeatureColumns to use as
-        inputs to your model. All items should be instances of classes derived
-        from `DenseColumn` such as `numeric_column`, `embedding_column`,
-        `bucketized_column`, `indicator_column`. If you have categorical
-        features, you can wrap them with an `embedding_column` or
-        `indicator_column`.
-      trainable:  Boolean, whether the layer's variables will be updated via
-        gradient descent during training.
-      name: Name to give to the DenseFeatures.
-      **kwargs: Keyword arguments to construct a layer.
-
-    Raises:
-      ValueError: if an item in `feature_columns` is not a `DenseColumn`.
-    """
-    super(DenseFeatures, self).__init__(
-        feature_columns=feature_columns,
-        trainable=trainable,
-        name=name,
-        expected_column_type=DenseColumn,
-        **kwargs)
-
-  @property
-  def _is_feature_layer(self):
-    return True
-
-  def _target_shape(self, input_shape, total_elements):
-    return (input_shape[0], total_elements)
-
-  def call(self, features, cols_to_output_tensors=None):
-    """Returns a dense tensor corresponding to the `feature_columns`.
-
-    Args:
-      features: A mapping from key to tensors. `FeatureColumn`s look up via
-        these keys. For example `numeric_column('price')` will look at 'price'
-        key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
-        on corresponding `FeatureColumn`.
-      cols_to_output_tensors: If not `None`, this will be filled with a dict
-        mapping feature columns to output tensors created.
-
-    Returns:
-      A `Tensor` which represents input layer of a model. Its shape
-      is (batch_size, first_layer_dimension) and its dtype is `float32`.
-      first_layer_dimension is determined based on given `feature_columns`.
-
-    Raises:
-      ValueError: If features are not a dictionary.
-    """
-    if not isinstance(features, dict):
-      raise ValueError('We expected a dictionary here. Instead we got: ',
-                       features)
-    transformation_cache = FeatureTransformationCache(features)
-    output_tensors = []
-    for column in self._feature_columns:
-      with ops.name_scope(column.name):
-        tensor = column.get_dense_tensor(transformation_cache,
-                                         self._state_manager)
-        processed_tensors = self._process_dense_tensor(column, tensor)
-        if cols_to_output_tensors is not None:
-          cols_to_output_tensors[column] = processed_tensors
-        output_tensors.append(processed_tensors)
-    return self._verify_and_concat_tensors(output_tensors)
-
-
 class _LinearModelLayer(Layer):
   """Layer that contains logic for `LinearModel`."""
 
@@ -2784,7 +2711,7 @@
   if isinstance(feature_columns, FeatureColumn):
     feature_columns = [feature_columns]
 
-  if isinstance(feature_columns, collections.Iterator):
+  if isinstance(feature_columns, collections_abc.Iterator):
     feature_columns = list(feature_columns)
 
   if isinstance(feature_columns, dict):
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 5b4c263..253ed9e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.feature_column import dense_features as df
 from tensorflow.python.feature_column import feature_column as fc_old
 from tensorflow.python.feature_column import feature_column_v2 as fc
 from tensorflow.python.framework import constant_op
@@ -47,10 +48,7 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
-from tensorflow.python.training import coordinator
-from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import rmsprop
-from tensorflow_estimator.python.estimator.inputs import numpy_io
 
 
 def _initialized_session(config=None):
@@ -2153,45 +2151,6 @@
             })
 
   @test_util.run_deprecated_v1
-  def test_with_numpy_input_fn(self):
-    price = fc.numeric_column('price')
-    price_buckets = fc.bucketized_column(
-        price, boundaries=[
-            0.,
-            10.,
-            100.,
-        ])
-    body_style = fc.categorical_column_with_vocabulary_list(
-        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
-    input_fn = numpy_io.numpy_input_fn(
-        x={
-            'price': np.array([-1., 2., 13., 104.]),
-            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
-        },
-        batch_size=2,
-        shuffle=False)
-    features = input_fn()
-    model = fc.LinearModel([price_buckets, body_style])
-    net = model(features)
-    # self.assertEqual(1 + 3 + 5, net.shape[1])
-    with _initialized_session() as sess:
-      coord = coordinator.Coordinator()
-      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
-      body_style_var, price_buckets_var, bias = model.variables
-
-      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
-      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
-      sess.run(bias.assign([5.]))
-
-      self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
-                          self.evaluate(net))
-
-      coord.request_stop()
-      coord.join(threads)
-
-  @test_util.run_deprecated_v1
   def test_with_1d_sparse_tensor(self):
     price = fc.numeric_column('price')
     price_buckets = fc.bucketized_column(
@@ -3201,630 +3160,6 @@
         fc_old.linear_model(features, all_cols)
 
 
-class DenseFeaturesTest(test.TestCase):
-
-  @test_util.run_in_graph_and_eager_modes()
-  def test_retrieving_input(self):
-    features = {'a': [0.]}
-    dense_features = fc.DenseFeatures(fc.numeric_column('a'))
-    inputs = self.evaluate(dense_features(features))
-    self.assertAllClose([[0.]], inputs)
-
-  def test_reuses_variables(self):
-    with context.eager_mode():
-      sparse_input = sparse_tensor.SparseTensor(
-          indices=((0, 0), (1, 0), (2, 0)),
-          values=(0, 1, 2),
-          dense_shape=(3, 3))
-
-      # Create feature columns (categorical and embedding).
-      categorical_column = fc.categorical_column_with_identity(
-          key='a', num_buckets=3)
-      embedding_dimension = 2
-
-      def _embedding_column_initializer(shape, dtype, partition_info):
-        del shape  # unused
-        del dtype  # unused
-        del partition_info  # unused
-        embedding_values = (
-            (1, 0),  # id 0
-            (0, 1),  # id 1
-            (1, 1))  # id 2
-        return embedding_values
-
-      embedding_column = fc.embedding_column(
-          categorical_column,
-          dimension=embedding_dimension,
-          initializer=_embedding_column_initializer)
-
-      dense_features = fc.DenseFeatures([embedding_column])
-      features = {'a': sparse_input}
-
-      inputs = dense_features(features)
-      variables = dense_features.variables
-
-      # Sanity check: test that the inputs are correct.
-      self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
-
-      # Check that only one variable was created.
-      self.assertEqual(1, len(variables))
-
-      # Check that invoking dense_features on the same features does not create
-      # additional variables
-      _ = dense_features(features)
-      self.assertEqual(1, len(variables))
-      self.assertEqual(variables[0], dense_features.variables[0])
-
-  def test_feature_column_dense_features_gradient(self):
-    with context.eager_mode():
-      sparse_input = sparse_tensor.SparseTensor(
-          indices=((0, 0), (1, 0), (2, 0)),
-          values=(0, 1, 2),
-          dense_shape=(3, 3))
-
-      # Create feature columns (categorical and embedding).
-      categorical_column = fc.categorical_column_with_identity(
-          key='a', num_buckets=3)
-      embedding_dimension = 2
-
-      def _embedding_column_initializer(shape, dtype, partition_info):
-        del shape  # unused
-        del dtype  # unused
-        del partition_info  # unused
-        embedding_values = (
-            (1, 0),  # id 0
-            (0, 1),  # id 1
-            (1, 1))  # id 2
-        return embedding_values
-
-      embedding_column = fc.embedding_column(
-          categorical_column,
-          dimension=embedding_dimension,
-          initializer=_embedding_column_initializer)
-
-      dense_features = fc.DenseFeatures([embedding_column])
-      features = {'a': sparse_input}
-
-      def scale_matrix():
-        matrix = dense_features(features)
-        return 2 * matrix
-
-      # Sanity check: Verify that scale_matrix returns the correct output.
-      self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
-
-      # Check that the returned gradient is correct.
-      grad_function = backprop.implicit_grad(scale_matrix)
-      grads_and_vars = grad_function()
-      indexed_slice = grads_and_vars[0][0]
-      gradient = grads_and_vars[0][0].values
-
-      self.assertAllEqual([0, 1, 2], indexed_slice.indices)
-      self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
-
-  def test_raises_if_empty_feature_columns(self):
-    with self.assertRaisesRegexp(ValueError,
-                                 'feature_columns must not be empty'):
-      fc.DenseFeatures(feature_columns=[])(features={})
-
-  def test_should_be_dense_column(self):
-    with self.assertRaisesRegexp(ValueError, 'must be a .*DenseColumn'):
-      fc.DenseFeatures(feature_columns=[
-          fc.categorical_column_with_hash_bucket('wire_cast', 4)
-      ])(
-          features={
-              'a': [[0]]
-          })
-
-  def test_does_not_support_dict_columns(self):
-    with self.assertRaisesRegexp(
-        ValueError, 'Expected feature_columns to be iterable, found dict.'):
-      fc.DenseFeatures(feature_columns={'a': fc.numeric_column('a')})(
-          features={
-              'a': [[0]]
-          })
-
-  def test_bare_column(self):
-    with ops.Graph().as_default():
-      features = features = {'a': [0.]}
-      net = fc.DenseFeatures(fc.numeric_column('a'))(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[0.]], self.evaluate(net))
-
-  def test_column_generator(self):
-    with ops.Graph().as_default():
-      features = features = {'a': [0.], 'b': [1.]}
-      columns = (fc.numeric_column(key) for key in features)
-      net = fc.DenseFeatures(columns)(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[0., 1.]], self.evaluate(net))
-
-  def test_raises_if_duplicate_name(self):
-    with self.assertRaisesRegexp(
-        ValueError, 'Duplicate feature column name found for columns'):
-      fc.DenseFeatures(
-          feature_columns=[fc.numeric_column('a'),
-                           fc.numeric_column('a')])(
-                               features={
-                                   'a': [[0]]
-                               })
-
-  def test_one_column(self):
-    price = fc.numeric_column('price')
-    with ops.Graph().as_default():
-      features = {'price': [[1.], [5.]]}
-      net = fc.DenseFeatures([price])(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1.], [5.]], self.evaluate(net))
-
-  def test_multi_dimension(self):
-    price = fc.numeric_column('price', shape=2)
-    with ops.Graph().as_default():
-      features = {'price': [[1., 2.], [5., 6.]]}
-      net = fc.DenseFeatures([price])(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
-
-  def test_compute_output_shape(self):
-    price1 = fc.numeric_column('price1', shape=2)
-    price2 = fc.numeric_column('price2', shape=4)
-    with ops.Graph().as_default():
-      features = {
-          'price1': [[1., 2.], [5., 6.]],
-          'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
-      }
-      dense_features = fc.DenseFeatures([price1, price2])
-      self.assertEqual((None, 6), dense_features.compute_output_shape((None,)))
-      net = dense_features(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]],
-                          self.evaluate(net))
-
-  def test_raises_if_shape_mismatch(self):
-    price = fc.numeric_column('price', shape=2)
-    with ops.Graph().as_default():
-      features = {'price': [[1.], [5.]]}
-      with self.assertRaisesRegexp(
-          Exception,
-          r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
-        fc.DenseFeatures([price])(features)
-
-  def test_reshaping(self):
-    price = fc.numeric_column('price', shape=[1, 2])
-    with ops.Graph().as_default():
-      features = {'price': [[[1., 2.]], [[5., 6.]]]}
-      net = fc.DenseFeatures([price])(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))
-
-  def test_multi_column(self):
-    price1 = fc.numeric_column('price1', shape=2)
-    price2 = fc.numeric_column('price2')
-    with ops.Graph().as_default():
-      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
-      net = fc.DenseFeatures([price1, price2])(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
-
-  def test_cols_to_output_tensors(self):
-    price1 = fc.numeric_column('price1', shape=2)
-    price2 = fc.numeric_column('price2')
-    with ops.Graph().as_default():
-      cols_dict = {}
-      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
-      dense_features = fc.DenseFeatures([price1, price2])
-      net = dense_features(features, cols_dict)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 2.], [5., 6.]],
-                          self.evaluate(cols_dict[price1]))
-      self.assertAllClose([[3.], [4.]], self.evaluate(cols_dict[price2]))
-      self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], self.evaluate(net))
-
-  def test_column_order(self):
-    price_a = fc.numeric_column('price_a')
-    price_b = fc.numeric_column('price_b')
-    with ops.Graph().as_default():
-      features = {
-          'price_a': [[1.]],
-          'price_b': [[3.]],
-      }
-      net1 = fc.DenseFeatures([price_a, price_b])(features)
-      net2 = fc.DenseFeatures([price_b, price_a])(features)
-
-      self.evaluate(variables_lib.global_variables_initializer())
-      self.evaluate(lookup_ops.tables_initializer())
-
-      self.assertAllClose([[1., 3.]], self.evaluate(net1))
-      self.assertAllClose([[1., 3.]], self.evaluate(net2))
-
-  def test_fails_for_categorical_column(self):
-    animal = fc.categorical_column_with_identity('animal', num_buckets=4)
-    with ops.Graph().as_default():
-      features = {
-          'animal':
-              sparse_tensor.SparseTensor(
-                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
-      }
-      with self.assertRaisesRegexp(Exception, 'must be a .*DenseColumn'):
-        fc.DenseFeatures([animal])(features)
-
-  def test_static_batch_size_mismatch(self):
-    price1 = fc.numeric_column('price1')
-    price2 = fc.numeric_column('price2')
-    with ops.Graph().as_default():
-      features = {
-          'price1': [[1.], [5.], [7.]],  # batchsize = 3
-          'price2': [[3.], [4.]]  # batchsize = 2
-      }
-      with self.assertRaisesRegexp(
-          ValueError,
-          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
-        fc.DenseFeatures([price1, price2])(features)
-
-  def test_subset_of_static_batch_size_mismatch(self):
-    price1 = fc.numeric_column('price1')
-    price2 = fc.numeric_column('price2')
-    price3 = fc.numeric_column('price3')
-    with ops.Graph().as_default():
-      features = {
-          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
-          'price2': [[3.], [4.]],  # batchsize = 2
-          'price3': [[3.], [4.], [5.]]  # batchsize = 3
-      }
-      with self.assertRaisesRegexp(
-          ValueError,
-          r'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
-        fc.DenseFeatures([price1, price2, price3])(features)
-
-  def test_runtime_batch_size_mismatch(self):
-    price1 = fc.numeric_column('price1')
-    price2 = fc.numeric_column('price2')
-    with ops.Graph().as_default():
-      features = {
-          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
-          'price2': [[3.], [4.]]  # batchsize = 2
-      }
-      net = fc.DenseFeatures([price1, price2])(features)
-      with _initialized_session() as sess:
-        with self.assertRaisesRegexp(errors.OpError,
-                                     'Dimensions of inputs should match'):
-          sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
-  def test_runtime_batch_size_matches(self):
-    price1 = fc.numeric_column('price1')
-    price2 = fc.numeric_column('price2')
-    with ops.Graph().as_default():
-      features = {
-          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
-          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
-      }
-      net = fc.DenseFeatures([price1, price2])(features)
-      with _initialized_session() as sess:
-        sess.run(
-            net,
-            feed_dict={
-                features['price1']: [[1.], [5.]],
-                features['price2']: [[1.], [5.]],
-            })
-
-  def test_multiple_layers_with_same_embedding_column(self):
-    some_sparse_column = fc.categorical_column_with_hash_bucket(
-        'sparse_feature', hash_bucket_size=5)
-    some_embedding_column = fc.embedding_column(
-        some_sparse_column, dimension=10)
-
-    with ops.Graph().as_default():
-      features = {
-          'sparse_feature': [['a'], ['x']],
-      }
-      all_cols = [some_embedding_column]
-      fc.DenseFeatures(all_cols)(features)
-      fc.DenseFeatures(all_cols)(features)
-      # Make sure that 2 variables get created in this case.
-      self.assertEqual(2, len(
-          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
-      expected_var_names = [
-          'dense_features/sparse_feature_embedding/embedding_weights:0',
-          'dense_features_1/sparse_feature_embedding/embedding_weights:0'
-      ]
-      self.assertItemsEqual(
-          expected_var_names,
-          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
-
-  @test_util.run_deprecated_v1
-  def test_multiple_layers_with_same_shared_embedding_column(self):
-    categorical_column_a = fc.categorical_column_with_identity(
-        key='aaa', num_buckets=3)
-    categorical_column_b = fc.categorical_column_with_identity(
-        key='bbb', num_buckets=3)
-    embedding_dimension = 2
-    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
-        [categorical_column_b, categorical_column_a],
-        dimension=embedding_dimension)
-
-    with ops.Graph().as_default():
-      features = {
-          'aaa':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(0, 1, 0),
-                  dense_shape=(2, 2)),
-          'bbb':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(1, 2, 1),
-                  dense_shape=(2, 2)),
-      }
-      all_cols = [embedding_column_a, embedding_column_b]
-      fc.DenseFeatures(all_cols)(features)
-      fc.DenseFeatures(all_cols)(features)
-      # Make sure that only 1 variable gets created in this case.
-      self.assertEqual(1, len(
-          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
-      self.assertItemsEqual(
-          ['aaa_bbb_shared_embedding:0'],
-          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
-
-  @test_util.run_deprecated_v1
-  def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
-    categorical_column_a = fc.categorical_column_with_identity(
-        key='aaa', num_buckets=3)
-    categorical_column_b = fc.categorical_column_with_identity(
-        key='bbb', num_buckets=3)
-    embedding_dimension = 2
-    embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
-        [categorical_column_b, categorical_column_a],
-        dimension=embedding_dimension)
-    all_cols = [embedding_column_a, embedding_column_b]
-
-    with ops.Graph().as_default():
-      features = {
-          'aaa':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(0, 1, 0),
-                  dense_shape=(2, 2)),
-          'bbb':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(1, 2, 1),
-                  dense_shape=(2, 2)),
-      }
-      fc.DenseFeatures(all_cols)(features)
-      # Make sure that only 1 variable gets created in this case.
-      self.assertEqual(1, len(
-          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
-
-    with ops.Graph().as_default():
-      features1 = {
-          'aaa':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(0, 1, 0),
-                  dense_shape=(2, 2)),
-          'bbb':
-              sparse_tensor.SparseTensor(
-                  indices=((0, 0), (1, 0), (1, 1)),
-                  values=(1, 2, 1),
-                  dense_shape=(2, 2)),
-      }
-
-      fc.DenseFeatures(all_cols)(features1)
-      # Make sure that only 1 variable gets created in this case.
-      self.assertEqual(1, len(
-          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
-      self.assertItemsEqual(
-          ['aaa_bbb_shared_embedding:0'],
-          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
-
-  @test_util.run_deprecated_v1
-  def test_with_numpy_input_fn(self):
-    embedding_values = (
-        (1., 2., 3., 4., 5.),  # id 0
-        (6., 7., 8., 9., 10.),  # id 1
-        (11., 12., 13., 14., 15.)  # id 2
-    )
-
-    def _initializer(shape, dtype, partition_info):
-      del shape, dtype, partition_info
-      return embedding_values
-
-    # price has 1 dimension in dense_features
-    price = fc.numeric_column('price')
-    body_style = fc.categorical_column_with_vocabulary_list(
-        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-    # one_hot_body_style has 3 dims in dense_features.
-    one_hot_body_style = fc.indicator_column(body_style)
-    # embedded_body_style has 5 dims in dense_features.
-    embedded_body_style = fc.embedding_column(
-        body_style, dimension=5, initializer=_initializer)
-
-    input_fn = numpy_io.numpy_input_fn(
-        x={
-            'price': np.array([11., 12., 13., 14.]),
-            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
-        },
-        batch_size=2,
-        shuffle=False)
-    features = input_fn()
-    net = fc.DenseFeatures([price, one_hot_body_style, embedded_body_style])(
-        features)
-    self.assertEqual(1 + 3 + 5, net.shape[1])
-    with _initialized_session() as sess:
-      coord = coordinator.Coordinator()
-      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
-      # Each row is formed by concatenating `embedded_body_style`,
-      # `one_hot_body_style`, and `price` in order.
-      self.assertAllEqual([[11., 12., 13., 14., 15., 0., 0., 1., 11.],
-                           [1., 2., 3., 4., 5., 1., 0., 0., 12]], sess.run(net))
-
-      coord.request_stop()
-      coord.join(threads)
-
-  @test_util.run_deprecated_v1
-  def test_with_1d_sparse_tensor(self):
-    embedding_values = (
-        (1., 2., 3., 4., 5.),  # id 0
-        (6., 7., 8., 9., 10.),  # id 1
-        (11., 12., 13., 14., 15.)  # id 2
-    )
-
-    def _initializer(shape, dtype, partition_info):
-      del shape, dtype, partition_info
-      return embedding_values
-
-    # price has 1 dimension in dense_features
-    price = fc.numeric_column('price')
-
-    # one_hot_body_style has 3 dims in dense_features.
-    body_style = fc.categorical_column_with_vocabulary_list(
-        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-    one_hot_body_style = fc.indicator_column(body_style)
-
-    # embedded_body_style has 5 dims in dense_features.
-    country = fc.categorical_column_with_vocabulary_list(
-        'country', vocabulary_list=['US', 'JP', 'CA'])
-    embedded_country = fc.embedding_column(
-        country, dimension=5, initializer=_initializer)
-
-    # Provides 1-dim tensor and dense tensor.
-    features = {
-        'price':
-            constant_op.constant([
-                11.,
-                12.,
-            ]),
-        'body-style':
-            sparse_tensor.SparseTensor(
-                indices=((0,), (1,)),
-                values=('sedan', 'hardtop'),
-                dense_shape=(2,)),
-        # This is dense tensor for the categorical_column.
-        'country':
-            constant_op.constant(['CA', 'US']),
-    }
-    self.assertEqual(1, features['price'].shape.ndims)
-    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-    self.assertEqual(1, features['country'].shape.ndims)
-
-    net = fc.DenseFeatures([price, one_hot_body_style, embedded_country])(
-        features)
-    self.assertEqual(1 + 3 + 5, net.shape[1])
-    with _initialized_session() as sess:
-
-      # Each row is formed by concatenating `embedded_body_style`,
-      # `one_hot_body_style`, and `price` in order.
-      self.assertAllEqual([[0., 0., 1., 11., 12., 13., 14., 15., 11.],
-                           [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
-                          sess.run(net))
-
-  @test_util.run_deprecated_v1
-  def test_with_1d_unknown_shape_sparse_tensor(self):
-    embedding_values = (
-        (1., 2.),  # id 0
-        (6., 7.),  # id 1
-        (11., 12.)  # id 2
-    )
-
-    def _initializer(shape, dtype, partition_info):
-      del shape, dtype, partition_info
-      return embedding_values
-
-    # price has 1 dimension in dense_features
-    price = fc.numeric_column('price')
-
-    # one_hot_body_style has 3 dims in dense_features.
-    body_style = fc.categorical_column_with_vocabulary_list(
-        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-    one_hot_body_style = fc.indicator_column(body_style)
-
-    # embedded_body_style has 5 dims in dense_features.
-    country = fc.categorical_column_with_vocabulary_list(
-        'country', vocabulary_list=['US', 'JP', 'CA'])
-    embedded_country = fc.embedding_column(
-        country, dimension=2, initializer=_initializer)
-
-    # Provides 1-dim tensor and dense tensor.
-    features = {
-        'price': array_ops.placeholder(dtypes.float32),
-        'body-style': array_ops.sparse_placeholder(dtypes.string),
-        # This is dense tensor for the categorical_column.
-        'country': array_ops.placeholder(dtypes.string),
-    }
-    self.assertIsNone(features['price'].shape.ndims)
-    self.assertIsNone(features['body-style'].get_shape().ndims)
-    self.assertIsNone(features['country'].shape.ndims)
-
-    price_data = np.array([11., 12.])
-    body_style_data = sparse_tensor.SparseTensorValue(
-        indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
-    country_data = np.array([['US'], ['CA']])
-
-    net = fc.DenseFeatures([price, one_hot_body_style, embedded_country])(
-        features)
-    self.assertEqual(1 + 3 + 2, net.shape[1])
-    with _initialized_session() as sess:
-
-      # Each row is formed by concatenating `embedded_body_style`,
-      # `one_hot_body_style`, and `price` in order.
-      self.assertAllEqual(
-          [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
-          sess.run(
-              net,
-              feed_dict={
-                  features['price']: price_data,
-                  features['body-style']: body_style_data,
-                  features['country']: country_data
-              }))
-
-  @test_util.run_deprecated_v1
-  def test_with_rank_0_feature(self):
-    # price has 1 dimension in dense_features
-    price = fc.numeric_column('price')
-    features = {
-        'price': constant_op.constant(0),
-    }
-    self.assertEqual(0, features['price'].shape.ndims)
-
-    # Static rank 0 should fail
-    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
-      fc.DenseFeatures([price])(features)
-
-    # Dynamic rank 0 should fail
-    features = {
-        'price': array_ops.placeholder(dtypes.float32),
-    }
-    net = fc.DenseFeatures([price])(features)
-    self.assertEqual(1, net.shape[1])
-    with _initialized_session() as sess:
-      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
-        sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
 class InputLayerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
@@ -3846,7 +3181,7 @@
           key='a', num_buckets=3)
       embedding_dimension = 2
 
-      def _embedding_column_initializer(shape, dtype, partition_info):
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
         del shape  # unused
         del dtype  # unused
         del partition_info  # unused
@@ -3891,7 +3226,7 @@
           key='a', num_buckets=3)
       embedding_dimension = 2
 
-      def _embedding_column_initializer(shape, dtype, partition_info):
+      def _embedding_column_initializer(shape, dtype, partition_info=None):
         del shape  # unused
         del dtype  # unused
         del partition_info  # unused
@@ -4268,7 +3603,7 @@
         (11., 12., 13., 14., 15.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       del shape, dtype, partition_info
       return embedding_values
 
@@ -4325,7 +3660,7 @@
         (11., 12.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       del shape, dtype, partition_info
       return embedding_values
 
@@ -5670,8 +5005,7 @@
             values=np.array((0, 1, 0), dtype=np.int64),
             dense_shape=(2, 2)), self.evaluate(id_weight_pair.id_tensor))
 
-  @test_util.run_deprecated_v1
-  def test_get_sparse_tensors_with_inputs_too_small(self):
+  def _test_get_sparse_tensors_with_inputs_too_small(self):
     column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     inputs = sparse_tensor.SparseTensorValue(
         indices=((0, 0), (1, 0), (1, 1)), values=(1, -1, 0), dense_shape=(2, 2))
@@ -5684,11 +5018,19 @@
     self.evaluate(variables_lib.global_variables_initializer())
     self.evaluate(lookup_ops.tables_initializer())
 
-    with self.assertRaisesRegexp(errors.OpError, 'assert_greater_or_equal_0'):
+    with self.assertRaisesRegexp(errors.OpError, 'assert'):
       self.evaluate(id_weight_pair.id_tensor)
 
   @test_util.run_deprecated_v1
-  def test_get_sparse_tensors_with_inputs_too_big(self):
+  def test_get_sparse_tensors_with_inputs_too_small(self):
+    self._test_get_sparse_tensors_with_inputs_too_small()
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_control_flow_v2
+  def test_get_sparse_tensors_with_inputs_too_small_v2(self):
+    self._test_get_sparse_tensors_with_inputs_too_small()
+
+  def _test_get_sparse_tensors_with_inputs_too_big(self):
     column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     inputs = sparse_tensor.SparseTensorValue(
         indices=((0, 0), (1, 0), (1, 1)), values=(1, 99, 0), dense_shape=(2, 2))
@@ -5701,11 +5043,19 @@
     self.evaluate(variables_lib.global_variables_initializer())
     self.evaluate(lookup_ops.tables_initializer())
 
-    with self.assertRaisesRegexp(errors.OpError,
-                                 'assert_less_than_num_buckets'):
+    with self.assertRaisesRegexp(errors.OpError, 'assert'):
       self.evaluate(id_weight_pair.id_tensor)
 
   @test_util.run_deprecated_v1
+  def test_get_sparse_tensors_with_inputs_too_big(self):
+    self._test_get_sparse_tensors_with_inputs_too_big()
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_control_flow_v2
+  def test_get_sparse_tensors_with_inputs_too_big_v2(self):
+    self._test_get_sparse_tensors_with_inputs_too_big()
+
+  @test_util.run_deprecated_v1
   def test_get_sparse_tensors_with_default_value(self):
     column = fc.categorical_column_with_identity(
         key='aaa', num_buckets=4, default_value=3)
@@ -6153,7 +5503,7 @@
               sparse_tensor.SparseTensor(
                   indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
       }
-      net = fc.DenseFeatures([animal])(features)
+      net = df.DenseFeatures([animal])(features)
 
       self.evaluate(variables_lib.global_variables_initializer())
       self.evaluate(lookup_ops.tables_initializer())
@@ -6426,7 +5776,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6492,7 +5842,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6557,7 +5907,7 @@
         (2., 7., 12.)  # id 3
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6624,7 +5974,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6766,7 +6116,7 @@
     embedding_shape = (vocabulary_size, embedding_dimension)
     zeros_embedding_values = np.zeros(embedding_shape)
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual(embedding_shape, shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6850,7 +6200,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6877,7 +6227,7 @@
         initializer=_initializer)
 
     # Provide sparse input and get dense result.
-    l = fc.DenseFeatures((embedding_column,))
+    l = df.DenseFeatures((embedding_column,))
     dense_features = l({'aaa': sparse_input})
 
     # Assert expected embedding variable and lookups.
@@ -6917,7 +6267,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -6945,7 +6295,7 @@
         trainable=False)
 
     # Provide sparse input and get dense result.
-    dense_features = fc.DenseFeatures((embedding_column,))({
+    dense_features = df.DenseFeatures((embedding_column,))({
         'aaa': sparse_input
     })
 
@@ -6983,7 +6333,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7046,7 +6396,7 @@
     embedding_shape = (vocabulary_size, embedding_dimension)
     zeros_embedding_values = np.zeros(embedding_shape)
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual(embedding_shape, shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7128,7 +6478,7 @@
     embedding_shape = (vocabulary_size, embedding_dimension)
     zeros_embedding_values = np.zeros(embedding_shape)
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual(embedding_shape, shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7248,7 +6598,7 @@
   @test_util.run_deprecated_v1
   def test_serialization_with_custom_initializer(self):
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       del shape, dtype, partition_info
       return ValueError('Not expected to be called')
 
@@ -7526,7 +6876,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7610,7 +6960,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7655,7 +7005,7 @@
     embedding_shape = (vocabulary_size, embedding_dimension)
     zeros_embedding_values = np.zeros(embedding_shape)
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual(embedding_shape, shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7769,7 +7119,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -7820,7 +7170,7 @@
     }
 
     # Provide sparse input and get dense result.
-    dense_features = fc.DenseFeatures(
+    dense_features = df.DenseFeatures(
         feature_columns=(embedding_column_b, embedding_column_a,
                          embedding_column_c, embedding_column_d))(
                              features)
@@ -7859,7 +7209,7 @@
   @test_util.run_deprecated_v1
   def test_serialization(self):
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       del shape, dtype, partition_info
       return ValueError('Not expected to be called')
 
diff --git a/tensorflow/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/python/feature_column/sequence_feature_column_integration_test.py
index b7c67945..03fc1b6 100644
--- a/tensorflow/python/feature_column/sequence_feature_column_integration_test.py
+++ b/tensorflow/python/feature_column/sequence_feature_column_integration_test.py
@@ -26,6 +26,7 @@
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import dense_features
 from tensorflow.python.feature_column import feature_column_v2 as fc
 from tensorflow.python.feature_column import sequence_feature_column as sfc
 from tensorflow.python.keras.layers import recurrent
@@ -92,7 +93,7 @@
     # Tile the context features across the sequence features
     sequence_input_layer = sfc.SequenceFeatures(seq_cols)
     seq_layer, _ = sequence_input_layer(features)
-    input_layer = fc.DenseFeatures(ctx_cols)
+    input_layer = dense_features.DenseFeatures(ctx_cols)
     ctx_layer = input_layer(features)
     input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
 
diff --git a/tensorflow/python/feature_column/sequence_feature_column_test.py b/tensorflow/python/feature_column/sequence_feature_column_test.py
index 53ccc32..d6da37f 100644
--- a/tensorflow/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/python/feature_column/sequence_feature_column_test.py
@@ -23,6 +23,7 @@
 import numpy as np
 
 from tensorflow.python.client import session
+from tensorflow.python.feature_column import dense_features
 from tensorflow.python.feature_column import feature_column_v2 as fc
 from tensorflow.python.feature_column import sequence_feature_column as sfc
 from tensorflow.python.feature_column import serialization
@@ -112,7 +113,8 @@
         (17., 18., 19.)  # id 2
     )
     def _get_initializer(embedding_dimension, embedding_values):
-      def _initializer(shape, dtype, partition_info):
+
+      def _initializer(shape, dtype, partition_info=None):
         self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
         self.assertEqual(dtypes.float32, dtype)
         self.assertIsNone(partition_info)
@@ -199,7 +201,7 @@
 
       def _get_initializer(embedding_dimension, embedding_values):
 
-        def _initializer(shape, dtype, partition_info):
+        def _initializer(shape, dtype, partition_info=None):
           self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
           self.assertEqual(dtypes.float32, dtype)
           self.assertIsNone(partition_info)
@@ -663,7 +665,7 @@
         ValueError,
         r'In embedding_column: aaa_embedding\. categorical_column must not be '
         r'of type SequenceCategoricalColumn\.'):
-      input_layer = fc.DenseFeatures([embedding_column_a])
+      input_layer = dense_features.DenseFeatures([embedding_column_a])
       _ = input_layer({'aaa': sparse_input})
 
   def test_indicator_column(self):
@@ -684,7 +686,7 @@
         ValueError,
         r'In indicator_column: aaa_indicator\. categorical_column must not be '
         r'of type SequenceCategoricalColumn\.'):
-      input_layer = fc.DenseFeatures([indicator_column_a])
+      input_layer = dense_features.DenseFeatures([indicator_column_a])
       _ = input_layer({'aaa': sparse_input})
 
 
@@ -971,7 +973,8 @@
         (3., 5.),  # id 1
         (7., 11.)  # id 2
     )
-    def _initializer(shape, dtype, partition_info):
+
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -1066,7 +1069,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
diff --git a/tensorflow/python/feature_column/serialization_test.py b/tensorflow/python/feature_column/serialization_test.py
index 9788349..8a9082d 100644
--- a/tensorflow/python/feature_column/serialization_test.py
+++ b/tensorflow/python/feature_column/serialization_test.py
@@ -20,6 +20,7 @@
 
 from absl.testing import parameterized
 
+from tensorflow.python.feature_column import dense_features
 from tensorflow.python.feature_column import feature_column_v2 as fc
 from tensorflow.python.feature_column import sequence_feature_column as sfc
 from tensorflow.python.feature_column import serialization
@@ -125,7 +126,8 @@
     cols = [fc.numeric_column('a'),
             fc.embedding_column(fc.categorical_column_with_identity(
                 key='b', num_buckets=3), dimension=2)]
-    orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
+    orig_layer = dense_features.DenseFeatures(
+        cols, trainable=trainable, name=name)
     config = orig_layer.get_config()
 
     self.assertEqual(config['name'], orig_layer.name)
@@ -147,10 +149,11 @@
                 'b', vocabulary_list=['1', '2', '3']), dimension=2),
             fc.indicator_column(fc.categorical_column_with_hash_bucket(
                 key='c', hash_bucket_size=3))]
-    orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
+    orig_layer = dense_features.DenseFeatures(
+        cols, trainable=trainable, name=name)
     config = orig_layer.get_config()
 
-    new_layer = fc.DenseFeatures.from_config(config)
+    new_layer = dense_features.DenseFeatures.from_config(config)
 
     self.assertEqual(new_layer.name, orig_layer.name)
     self.assertEqual(new_layer.trainable, trainable)
@@ -168,10 +171,10 @@
     ab = fc.crossed_column([a, b], hash_bucket_size=2)
     cols = [fc.indicator_column(ab)]
 
-    orig_layer = fc.DenseFeatures(cols)
+    orig_layer = dense_features.DenseFeatures(cols)
     config = orig_layer.get_config()
 
-    new_layer = fc.DenseFeatures.from_config(config)
+    new_layer = dense_features.DenseFeatures.from_config(config)
 
     self.assertLen(new_layer._feature_columns, 1)
     self.assertEqual(new_layer._feature_columns[0].name, 'a_X_b_indicator')
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index afdfa97..5d3b300 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -43,9 +43,10 @@
     try:
       return fn(*args, **kwargs)
     finally:
-      del context._context
-      context._context = context.Context()
-      ops.enable_eager_execution()
+      # Reset the context.
+      context._context = None
+      ops.enable_eager_execution_internal()
+      assert context._context is not None
 
   return wrapper
 
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index a4b2769..1f5bbfb 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -23,8 +23,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import context
@@ -95,22 +93,7 @@
     except AttributeError:
       dtype = dtypes.as_dtype(dtype).as_datatype_enum
   ctx.ensure_initialized()
-  device = ctx.device_name
-  if isinstance(value, (float,) + six.integer_types):
-    # Use a scalar cache. This will put each scalar of each type only once on
-    # each device. Scalars don't use much device memory but copying scalars can
-    # trigger memcpys which are slow.
-    cache_key = device, value, dtype, type(value)
-    scalar_cache = ctx.scalar_cache()
-    tensor = scalar_cache.get(cache_key, None)
-    if tensor is not None:
-      return ops.EagerTensor(
-          value, ctx, device, dtype, tensor)
-    t = ops.EagerTensor(value, ctx, device, dtype)
-    scalar_cache[cache_key] = t
-    return t
-  else:
-    return ops.EagerTensor(value, ctx, device, dtype)
+  return ops.EagerTensor(value, ctx.device_name, dtype)
 
 
 @tf_export(v1=["constant"])
diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py
index 4e2e24c..c6efc85 100644
--- a/tensorflow/python/framework/convert_to_constants.py
+++ b/tensorflow/python/framework/convert_to_constants.py
@@ -29,11 +29,13 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
+from tensorflow.python.util import object_identity
 from tensorflow.python.training.saver import export_meta_graph
 
 
 _CONDITIONAL_OPS = set(["If", "StatelessIf"])
-_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(set(["While"]))
+_LOOP_OPS = set(["While", "StatelessWhile"])
+_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)
 
 
 def disable_lower_using_switch_merge(graph_def):
@@ -176,13 +178,15 @@
     Dict
   """
   tensor_data = {}
-  map_index_to_variable = {
-      func.captured_inputs.index(var.handle): var
-      for var in func.graph.variables
-  }
+  map_index_to_variable = {}
+  for var in func.graph.variables:
+    for idx, captured_input in enumerate(func.captured_inputs):
+      if var.handle is captured_input:  # pylint: disable=protected-access
+        map_index_to_variable[idx] = var
+        break
 
   # Iterates through all captures which are represented as Placeholders.
-  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures.items()):
+  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
     tensor_name = _get_tensor_name(name_tensor.name)
     is_variable = idx in map_index_to_variable
     if is_variable:
@@ -202,10 +206,9 @@
 
   Creates a map from function name to a list of types and a list of shapes that
   correspond with the function arguments. The data is primarily determined from
-  the corresponding "If", "StatelessIf", or "While" op. If the argument is a
-  resource variable, then the type is determined from the type of the data
-  contained within the Tensor. The shape data is only determined in the case of
-  the "While" op.
+  the corresponding "If" or "While" op. If the argument is a resource variable,
+  then the type is determined from the type of the data contained within the
+  Tensor. The shape data is only determined in the case of the "While" op.
 
   `is_also_output_type` is used to identify the "While" bodies that require the
   output types to be updated at the same time the input types are updated.
@@ -249,7 +252,7 @@
 
       add_value(node.attr["then_branch"].func.name, arg_types, None, False)
       add_value(node.attr["else_branch"].func.name, arg_types, None, False)
-    elif node.op == "While":
+    elif node.op in _LOOP_OPS:
       arg_types = [dtype for dtype in node.attr["T"].list.type]
       output_shapes = [shape for shape in node.attr["output_shapes"].list.shape]
 
@@ -298,7 +301,7 @@
 
 
 def _populate_if_op(output_node, input_node, function_data):
-  """Updates the type attributes and the function names of If or StatelessIf.
+  """Updates the type attributes and function names of If or StatelessIf.
 
   Args:
     output_node: TensorFlow NodeDef.
@@ -317,7 +320,7 @@
 
 
 def _populate_while_op(output_node, input_node, function_data):
-  """Updates the type attributes and the function names of the While op.
+  """Updates the type attributes and function names of While or StatelessWhile.
 
   Args:
     output_node: TensorFlow NodeDef.
@@ -352,10 +355,11 @@
     ConcreteFunction.
   """
   # Create a ConcreteFunction from the new GraphDef.
-  input_tensors = list(func.graph.captures.values())
-  converted_inputs = set(
+  input_tensors = func.graph.internal_captures
+  converted_inputs = object_identity.ObjectIdentitySet(
       [input_tensors[index] for index in converted_input_indices])
-  not_converted_inputs = set(func.inputs).difference(converted_inputs)
+  not_converted_inputs = object_identity.ObjectIdentitySet(
+      func.inputs).difference(converted_inputs)
   not_converted_inputs_map = {
       tensor.name: tensor for tensor in not_converted_inputs
   }
@@ -432,7 +436,7 @@
         if input_name in tensor_data:
           dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
           _save_placeholder(_get_tensor_name(input_tensor), dtype)
-    elif node.op == "While":
+    elif node.op in _LOOP_OPS:
       # Get dtype and data for resource Placeholders.
       cond_func = node.attr["cond"].func.name
       arg_types = function_data[cond_func]["types"]
@@ -442,7 +446,7 @@
           dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
           _save_placeholder(_get_tensor_name(input_tensor), dtype)
     elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and
-          name_to_node[_get_tensor_name(node.input[0])].op == "While"):
+          name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS):
       # Store the dtype for Identity resource ops that are outputs of While ops.
       while_node = name_to_node[_get_tensor_name(node.input[0])]
       body_func = while_node.attr["body"].func.name
@@ -502,7 +506,7 @@
     # Update the function names and argument types for the conditional ops.
     elif input_node.op in _CONDITIONAL_OPS:
       _populate_if_op(output_node, input_node, function_data)
-    elif input_node.op == "While":
+    elif input_node.op in _LOOP_OPS:
       _populate_while_op(output_node, input_node, function_data)
     else:
       output_node.CopyFrom(input_node)
@@ -553,7 +557,7 @@
         # Update the function names and argument types for the conditional ops.
         elif input_node.op in _CONDITIONAL_OPS:
           _populate_if_op(output_node, input_node, function_data)
-        elif input_node.op == "While":
+        elif input_node.op in _LOOP_OPS:
           _populate_while_op(output_node, input_node, function_data)
         else:
           output_node.CopyFrom(input_node)
diff --git a/tensorflow/python/framework/convert_to_constants_test.py b/tensorflow/python/framework/convert_to_constants_test.py
index f962d5e..cbe8528 100644
--- a/tensorflow/python/framework/convert_to_constants_test.py
+++ b/tensorflow/python/framework/convert_to_constants_test.py
@@ -37,6 +37,7 @@
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import simple_save
 from tensorflow.python.saved_model.load import load
@@ -47,6 +48,24 @@
 
 class VariablesToConstantsTest(test.TestCase):
 
+  def _freezeModel(self, model):
+    """Freezes the model.
+
+    Args:
+      model: Function.
+
+    Returns:
+      root: AutoTrackable object with original ConcreteFunction.
+      output_func: frozen ConcreteFunction.
+    """
+    root = tracking.AutoTrackable()
+    root.f = model
+    input_func = root.f.get_concrete_function()
+
+    output_func = convert_to_constants.convert_variables_to_constants_v2(
+        input_func, lower_control_flow=False)
+    return root, output_func
+
   def _hasStatefulPartitionedCallOp(self, graph_def):
     """Determines if a StatefulPartitionedCall op exists in the graph."""
     for node in graph_def.node:
@@ -60,6 +79,11 @@
 
   def _testConvertedFunction(self, obj, func, converted_concrete_func,
                              input_data):
+    # Ensure the converted graph has no variables and no function calls.
+    constant_graph_def = converted_concrete_func.graph.as_graph_def()
+    self.assertEqual(0, self._getNumVariables(constant_graph_def))
+    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
+
     # Check that the converted ConcreteFunction produces the same result as the
     # original Function.
     expected_value = nest.flatten(func(**input_data))
@@ -104,10 +128,6 @@
 
     output_func = convert_to_constants.convert_variables_to_constants_v2(
         input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(constant_graph_def.library.function)
-
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
@@ -125,10 +145,6 @@
 
     output_func = convert_to_constants.convert_variables_to_constants_v2(
         input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
@@ -146,10 +162,6 @@
 
     output_func = convert_to_constants.convert_variables_to_constants_v2(
         input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
@@ -172,10 +184,6 @@
 
     output_func = convert_to_constants.convert_variables_to_constants_v2(
         input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
@@ -209,15 +217,12 @@
 
     output_func = convert_to_constants.convert_variables_to_constants_v2(
         input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
     self._testConvertedFunction(root, root.add, output_func, input_data)
 
   @test_util.run_v2_only
   def testKerasModel(self):
-    input_data = constant_op.constant(1., shape=[1, 1])
+    """Test a basic Keras model with Variables."""
+    input_data = {"x": constant_op.constant(1., shape=[1, 1])}
 
     # Create a simple Keras model.
     x = [-1, 0, 1, 2, 3, 4]
@@ -228,26 +233,14 @@
     model.compile(optimizer="sgd", loss="mean_squared_error")
     model.fit(x, y, epochs=1)
 
-    # Get the concrete function from the Keras model.
-    @def_function.function
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[1, 1], dtype=dtypes.float32)
+    ])
     def to_save(x):
       return model(x)
 
-    input_func = to_save.get_concrete_function(input_data)
-
-    variable_graph_def = input_func.graph.as_graph_def()
-    self.assertEqual(2, self._getNumVariables(variable_graph_def))
-
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
-    # Check value.
-    expected_value = to_save(input_data)
-    actual_value = nest.flatten(output_func(input_data))
-    self.assertEqual(expected_value.numpy(), actual_value)
+    root, output_func = self._freezeModel(to_save)
+    self._testConvertedFunction(root, root.f, output_func, input_data)
 
   def _singleMetaGraphSavedModel(self):
     export_graph = ops.Graph()
@@ -276,21 +269,20 @@
 
   @test_util.run_v2_only
   def testRefVariableImport(self):
+    """Test a model with 1.X ReferenceVariables."""
+    input_data = {"start": constant_op.constant(1., shape=[1, 1])}
+
     saved = self._singleMetaGraphSavedModel()
     imported = load(saved)
     fn = imported.signatures["serving_default"]
-    output_func = convert_to_constants.convert_variables_to_constants_v2(fn)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
 
-    input_data = {"start": constant_op.constant(1., shape=[1, 1])}
+    output_func = convert_to_constants.convert_variables_to_constants_v2(fn)
     root = tracking.AutoTrackable()
     self._testConvertedFunction(root, fn, output_func, input_data)
 
   @test_util.run_v2_only
   def testIf(self):
-    """Test whether If op freezes correctly."""
+    """Test a model with the If op."""
     input_data = {
         "x": constant_op.constant([1., 2.], shape=[1, 2]),
         "b": constant_op.constant(True)
@@ -312,22 +304,12 @@
       return control_flow_ops.cond(
           b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))
 
-    root = tracking.AutoTrackable()
-    root.f = model
-    input_func = root.f.get_concrete_function()
-    input_func(**input_data)
-
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func, lower_control_flow=False)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
+    root, output_func = self._freezeModel(model)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
   def testStatelessIf(self):
-    """Test whether StatelessIf op freezes correctly."""
+    """Test a model with the StatelessIf op."""
     input_data = {"b": constant_op.constant(True)}
 
     x = constant_op.constant([1., 2.], shape=[1, 2], name="x")
@@ -343,21 +325,12 @@
     def model(b):
       return cond_v2.cond_v2(b, true_fn, false_fn)
 
-    root = tracking.AutoTrackable()
-    root.f = model
-    input_func = root.f.get_concrete_function()
-    input_func(**input_data)
-
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func, lower_control_flow=False)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
+    root, output_func = self._freezeModel(model)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
   def testStaticRnn(self):
+    """Test a StaticRnn containing If ops."""
     input_data = {
         "x":
             constant_op.constant(
@@ -374,20 +347,12 @@
       return rnn.static_rnn(
           cell, seq, dtype=dtypes.float32, sequence_length=[1])
 
-    root = tracking.AutoTrackable()
-    root.f = model
-    input_func = root.f.get_concrete_function()
-
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func, lower_control_flow=False)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
-
+    root, output_func = self._freezeModel(model)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
-  def testLoop(self):
+  def testWhile(self):
+    """Test a While loop."""
     input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}
 
     weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)
@@ -404,21 +369,30 @@
     def model(x):
       return control_flow_ops.while_loop(condition, body, [x])
 
-    root = tracking.AutoTrackable()
-    root.f = model
-    input_func = root.f.get_concrete_function()
-    input_func(**input_data)
+    root, output_func = self._freezeModel(model)
+    self._testConvertedFunction(root, root.f, output_func, input_data)
 
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func, lower_control_flow=False)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
+  @test_util.run_v2_only
+  def testStatelessWhile(self):
+    """Test a StatelessWhile loop."""
+    input_data = {"x": constant_op.constant(2.)}
 
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
+    ])
+    def model(x):
+      return while_v2.while_loop(
+          lambda v: v < 4.,
+          lambda v: v * v, [x],
+          return_same_structure=False,
+          name="while_1")  # x**2
+
+    root, output_func = self._freezeModel(model)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
   @test_util.run_v2_only
   def testDynamicRnn(self):
+    """Test a DynamicRnn containing While loops."""
     input_data = {
         "x":
             constant_op.constant(
@@ -434,16 +408,29 @@
     def model(x):
       return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)
 
-    root = tracking.AutoTrackable()
-    root.f = model
-    input_func = root.f.get_concrete_function()
+    root, output_func = self._freezeModel(model)
+    self._testConvertedFunction(root, root.f, output_func, input_data)
 
-    output_func = convert_to_constants.convert_variables_to_constants_v2(
-        input_func, lower_control_flow=False)
-    constant_graph_def = output_func.graph.as_graph_def()
-    self.assertEqual(0, self._getNumVariables(constant_graph_def))
-    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))
+  @test_util.run_v2_only
+  def testKerasLSTM(self):
+    """Test a Keras LSTM containing dynamic_rnn ops."""
+    input_data = {
+        "x":
+            constant_op.constant(
+                np.array(
+                    np.random.random_sample((10, 10, 10)), dtype=np.float32))
+    }
 
+    model = keras.models.Sequential(
+        [keras.layers.LSTM(units=10, input_shape=(10, 10))])
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[10, 10, 10], dtype=dtypes.float32)
+    ])
+    def to_save(x):
+      return model(x)
+
+    root, output_func = self._freezeModel(to_save)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
 
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 16403b2..e817c31 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -566,6 +566,10 @@
     _NP_TO_TF[pdt] = next(
         _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
 
+
+TF_VALUE_DTYPES = set(_NP_TO_TF.values())
+
+
 _TF_TO_NP = {
     types_pb2.DT_HALF:
         np.float16,
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index fbdc2aa..caaeab4 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -46,6 +46,14 @@
   return compact_traces
 
 
+class InaccessibleTensorError(ValueError):
+  pass
+
+
+class OperatorNotAllowedInGraphError(TypeError):
+  pass
+
+
 @tf_export("errors.OpError", v1=["errors.OpError", "OpError"])
 @deprecation.deprecated_endpoints("OpError")
 class OpError(Exception):
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index f747110..30db860 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -22,15 +22,20 @@
 import itertools
 import weakref
 
+import numpy as np
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import type_spec
 from tensorflow.python.framework.auto_control_deps import AutomaticControlDependencies
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
@@ -40,6 +45,7 @@
 from tensorflow.python.util import compat
 from tensorflow.python.util import memory
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util.lazy_loader import LazyLoader
@@ -61,6 +67,9 @@
 ]
 
 
+_EAGER_CONST_THRESHOLD = 128
+
+
 class UnknownArgument(object):
   """Signifies an argument which is not currently handled."""
   pass
@@ -106,6 +115,7 @@
         type(None),
         dtypes.DType,
         tensor_spec.TensorSpec,
+        type_spec.TypeSpec,
     )):
       return arg
     return UnknownArgument()
@@ -151,9 +161,6 @@
       or the global default Graph.
     captures: Maps external tensor -> internal tensor (i.e. input placeholder).
       The entries are in the order they were captured.
-    deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
-      where at function call time the value of closure() will be used to feed
-      the nest of placeholders.
     control_captures: Set of external ops on which this graph has a control
       dependency.
     seed: The graph-level random seed.
@@ -190,14 +197,17 @@
     self.structured_input_signature = None
     self.structured_outputs = None
     self._weak_variables = []
-    self._watched_variables = weakref.WeakSet()
+    self._watched_variables = object_identity.ObjectIdentityWeakSet()
     self.outer_graph = ops.get_default_graph()
-    self.captures = py_collections.OrderedDict()
+    self._captures = py_collections.OrderedDict()
     # If not None, records the names of output args of this function. Used to
     # preserve the output names in the signature of a serialized+deserialized
     # function. Private at the moment mostly because it's often out of date.
     self._output_names = None
-    self.deferred_captures = py_collections.OrderedDict()
+    # Maps arbitrary key -> (closure, nest of placeholders), where at function
+    # call time the value of closure() will be used to feed the nest of
+    # placeholders.
+    self._deferred_captures = py_collections.OrderedDict()
     # Inherit capture-by-value from outer graph.
     if capture_by_value is not None:
       self.capture_by_value = capture_by_value
@@ -239,6 +249,12 @@
     else:
       self._collections = collections
 
+    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
+    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
+    # dependent functions as unsaveable.
+    self._saveable = True
+    self._saving_errors = set()
+
   def __str__(self):
     return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
 
@@ -272,7 +288,7 @@
     """
     if key is None:
       key = object()
-    if key not in self.deferred_captures:
+    if key not in self._deferred_captures:
 
       def convert_to_placeholder(s):
         if not isinstance(s, tensor_spec.TensorSpec):
@@ -295,8 +311,8 @@
         # pylint: enable=protected-access
         return nest.flatten(y, expand_composites=True)
 
-      self.deferred_captures[key] = (wrapped_closure, placeholder)
-    return self.deferred_captures[key][1]
+      self._deferred_captures[key] = (wrapped_closure, placeholder)
+    return self._deferred_captures[key][1]
 
   def control_dependencies(self, control_inputs):
     """Handles control dependencies.
@@ -438,7 +454,7 @@
       op_def=None,
       compute_device=True):
     # When capturing by value, do the read outside
-    reverse_captures = dict((v, k) for k, v in self.captures.items())
+    reverse_captures = dict((v, k) for k, v in self.captures)
     uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
     with ops.init_scope():
       if context.executing_eagerly():
@@ -546,6 +562,11 @@
 
     Returns:
       Tensor from this FuncGraph.
+
+    Raises:
+      InaccessibleTensorError: if any tensors are accessed in a manner that
+      bypasses the mechanisms required for the data dependencies to be correctly
+      wired.
     """
     # Note: _forward_func_graph is currently only set when building the gradient
     # graph graph of a defun call. If the backwards graph tries to capture
@@ -553,17 +574,6 @@
     # makes sure that any tensor needed by a custom_gradient is correctly
     # captured.
 
-    # TODO(b/134097853): figure out a better way to check distributed variables
-    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
-      # This checks if the 'tensor' is a DistributedVariable. When it is a
-      # DistributedVariable, we do not want to check its "graph" attr as the
-      # following if branch does, because "graph" is not an attr for the
-      # container DistributedVariable object, and the underlying components may
-      # not have been initialized yet.
-      # The reason we do not use isinstance() is due to cyclic dependency issue.
-      if name is None:
-        name = str("distributed_variable")
-      return self._capture_helper(tensor, name)
     if (getattr(tensor, "graph", None) is not self and
         hasattr(self, "_forward_func_graph") and
         isinstance(self._forward_func_graph, FuncGraph)):
@@ -571,6 +581,13 @@
     if isinstance(tensor, ops.EagerTensor):
       if name is None:
         name = str(ops.uid())
+
+      # Small EagerTensors are captured with Const ops
+      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
+          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
+        return self.capture_eager_tensor(tensor, name)
+
+      # Large EagerTensors and resources are captured with Placeholder ops
       return self._capture_helper(tensor, name)
     if tensor.graph is not self:
       if name is None:
@@ -578,39 +595,142 @@
       inner_graph = tensor.graph
       while inner_graph is not None and isinstance(inner_graph, FuncGraph):
         if inner_graph is self:
-          raise ValueError(
-              "Trying to capture a tensor from an inner function. This can be "
-              "caused by accessing a tensor defined inside a loop or "
-              "conditional body, or a subfunction, from a calling function, "
-              "without going through the proper return value mechanism. "
-              "Consider using TensorFlow mechanisms such as TensorArrays "
-              "to return tensors from inner functions or loop / conditional "
-              "bodies. Tensor: %s; tensor graph: %s; this graph: %s"
+          raise errors.InaccessibleTensorError(
+              "The tensor '%s' cannot be accessed here: it is defined"
+              " in another function or code block. Use return values,"
+              " explicit Python locals or TensorFlow collections to access"
+              " it. Defined in: %s; accessed from: %s.\n"
               % (tensor, tensor.graph, self))
         inner_graph = inner_graph.outer_graph
       return self._capture_helper(tensor, name)
     return tensor
 
   def _capture_helper(self, tensor, name):
-    captured_tensor = self.captures.get(tensor, None)
-    if captured_tensor is None:
-      captured_tensor = _create_substitute_placeholder(tensor, name=name,
-                                                       dtype=tensor.dtype)
-      self.captures[tensor] = captured_tensor
-      self.inputs.append(captured_tensor)
-    tape.record_operation("captured_value", [captured_tensor], [tensor],
+    capture = self._captures.get(ops.tensor_id(tensor))
+    if capture is None:
+      placeholder = _create_substitute_placeholder(
+          tensor, name=name, dtype=tensor.dtype)
+      self.add_capture(tensor, placeholder)
+    else:
+      placeholder = capture[1]
+    tape.record_operation("captured_value", [placeholder], [tensor],
                           lambda x: [x])
-    return captured_tensor
+    return placeholder
+
+  @property
+  def captures(self):
+    """Order list of tuples containing external and internal captures."""
+    return self._captures.values()
+
+  def add_capture(self, tensor, placeholder):
+    """Capture a specific tensor and utilize the provided placeholder.
+
+    Args:
+      tensor: Tensor to captures.
+      placeholder: Provided placeholder for the tensor.
+    """
+    self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
+    self.inputs.append(placeholder)
+
+  def reset_captures(self, capture_list):
+    """Set the captures with the provided list of captures & placeholder."""
+    self._captures = py_collections.OrderedDict()
+    for tensor, placeholder in capture_list:
+      self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
+
+  def pop_capture(self, tensor):
+    """Remove the capture and return the generated placeholder."""
+    capture = self._captures.pop(ops.tensor_id(tensor), None)
+    if capture is None:
+      return None
+
+    return capture[1]
+
+  def clear_captures(self):
+    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+    # Clearing captures using clear() leaves some cycles around.
+    while self._captures:
+      self._captures.popitem()
+    memory.dismantle_ordered_dict(self._captures)
+    while self._deferred_captures:
+      self._deferred_captures.popitem()
+    memory.dismantle_ordered_dict(self._deferred_captures)
+
+  def capture_distributed_variable(self, variable, placeholder):
+    """Add given distributed variable to captures with given placeholder."""
+    self._captures[ops.tensor_id(variable)] = (variable, placeholder)
+    tape.record_operation("captured_value", [placeholder], [variable],
+                          lambda x: [x])
+
+  def capture_eager_tensor(self, tensor, name):
+    capture = self._captures.get(ops.tensor_id(tensor))
+    if capture is None:
+      # We clear all control dependencies and place the Const op on the same
+      # device as the source tensor. The device placement may be relaxed at
+      # a later date.
+      with ops.control_dependencies(None), self.device(tensor.device):
+        graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
+                                           shape=tensor.shape, name=name)
+      self.add_capture(tensor, graph_const)
+    else:
+      graph_const = capture[1]
+    tape.record_operation("captured_value", [graph_const], [tensor],
+                          lambda x: [x])
+    return graph_const
 
   @property
   def external_captures(self):
     """External tensors captured by this function."""
-    return list(self.captures.keys())
+    return [c[0] for c in self._captures.values()]
 
   @property
   def internal_captures(self):
     """Placeholders in this function corresponding captured tensors."""
-    return list(self.captures.values())
+    return [c[1] for c in self._captures.values()]
+
+  @property
+  def deferred_external_captures(self):
+    """Ordered nest of tensors whose placeholders will be fed at call time."""
+    return [c[0] for c in self._deferred_captures.values()]
+
+  @property
+  def deferred_internal_captures(self):
+    """List of nest of placeholders which at call time will be fed."""
+    return [c[1] for c in self._deferred_captures.values()]
+
+  @property
+  def variable_captures(self):
+    """Map of tensor ids of variable handles to variables which are captured."""
+    return {
+        ops.tensor_id(self._captures[ops.tensor_id(v.handle)][1]): v
+        for v in self.variables
+        if ops.tensor_id(v.handle) in self._captures
+    }
+
+  def mark_as_unsaveable(self, error_message):
+    """Marks this FuncGraph as unsaveable.
+
+    Any attempts to export this FuncGraph will raise an error with the specified
+    message.
+
+    Args:
+      error_message: List or string containing the error message to be raised
+        when saving this FuncGraph to SavedModel.
+    """
+    self._saveable = False
+    if isinstance(error_message, str):
+      error_message = [error_message]
+    self._saving_errors.update(error_message)
+
+  @property
+  def saveable(self):
+    """Returns whether this FuncGraph is saveable."""
+    return self._saveable
+
+  @property
+  def saving_errors(self):
+    """Returns set of errors preventing this FuncGraph from being saved."""
+    return self._saving_errors
 
 
 def func_graph_from_py_func(name,
@@ -778,11 +898,11 @@
                 autograph.ConversionOptions(
                     recursive=True,
                     optional_features=autograph_options,
-                    force_conversion=True,
+                    user_requested=True,
                 ), args, kwargs)
           except Exception as e:  # pylint:disable=broad-except
             if hasattr(e, "ag_error_metadata"):
-              raise e.ag_error_metadata.to_exception(type(e))
+              raise e.ag_error_metadata.to_exception(e)
             else:
               raise
 
@@ -807,7 +927,7 @@
     # Variables in `func_args`, `func_kwargs` should be explicit inputs
     # to the function, not captured inputs.
     graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
-    arg_variables = set()
+    arg_variables = object_identity.ObjectIdentitySet()
     inputs = []
     for arg in (nest.flatten(func_args, expand_composites=True) +
                 nest.flatten(func_kwargs, expand_composites=True)):
@@ -815,7 +935,7 @@
         # Even if an argument variable was not used in the function, we've
         # already manually captured the resource Tensor when creating argument
         # placeholders.
-        resource_placeholder = func_graph.captures.pop(arg.handle, None)
+        resource_placeholder = func_graph.pop_capture(arg.handle)
         if resource_placeholder is None:
           continue
         arg_variables.add(arg)
@@ -824,12 +944,8 @@
         inputs.append(arg)
     variables = [v for v in graph_variables if v not in arg_variables]
     func_graph.inputs = (
-        inputs +
-        list(func_graph.captures.values()) +
-        nest.flatten(
-            [x[1] for x in func_graph.deferred_captures.values()],
-            expand_composites=True))
-
+        inputs + func_graph.internal_captures + nest.flatten(
+            func_graph.deferred_internal_captures, expand_composites=True))
     func_graph.structured_outputs = func_outputs
     # Returning a closed-over tensor does not trigger convert_to_tensor.
     func_graph.outputs.extend(
@@ -856,7 +972,7 @@
   """
   if (not isinstance(tensor, ops.EagerTensor) and
       tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
-    for input_t, placeholder_t in tensor.op.graph.captures.items():
+    for input_t, placeholder_t in tensor.op.graph.captures:
       if tensor == placeholder_t:
         return maybe_captured(input_t)
   # pylint: enable=protected-access
@@ -1066,12 +1182,5 @@
     func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
       after this function.
   """
-  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
-  # Clearing captures using clear() leaves some cycles around.
-  while func_graph.captures:
-    func_graph.captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.captures)
-  while func_graph.deferred_captures:
-    func_graph.deferred_captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.deferred_captures)
+  func_graph.clear_captures()
   ops.dismantle_graph(func_graph)
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index e607838..3404056 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -1186,6 +1186,15 @@
                      (attr_name, type(value)))
 
 
+def _get_kwarg_as_str_attr(attr_name, value):
+  """Creates an AttrValue for a python object."""
+  if isinstance(value, str):
+    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+  else:
+    raise ValueError("Unsupported attribute type for %s with type %s" %
+                     (attr_name, type(value)))
+
+
 def _parse_kwargs_as_attrs(func_name, **kwargs):
   """Parses **kwargs into a node's attributes."""
   attrs = {}
@@ -1218,7 +1227,10 @@
     if key.startswith("experimental_"):
       attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
       del kwargs[key]
-
+    # Support for https://github.com/tensorflow/community/pull/113/files.
+    elif key == "_implements" or key == "_reference":
+      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
+      del kwargs[key]
   if kwargs:
     raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
   return attrs
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 9d14495..9e7e3cc 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -62,7 +62,7 @@
 
   with func_graph.as_default():
     # Add all function nodes to the graph.
-    importer.import_graph_def(graph_def, name="")
+    importer.import_graph_def_for_function(graph_def, name="")
 
     # Initialize fields specific to FuncGraph.
 
@@ -99,8 +99,9 @@
     output_names = {}
     for ret_arg_def, tensor_name in zip(
         fdef.signature.output_arg, output_tensor_names):
-      output_names[func_graph.get_tensor_by_name(tensor_name)] = (
-          ret_arg_def.name)
+      output_names[ops.tensor_id(
+          func_graph.get_tensor_by_name(tensor_name))] = (
+              ret_arg_def.name)
     func_graph._output_names = output_names  # pylint: disable=protected-access
   return func_graph
 
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 3c58598..588ad6c 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -218,7 +218,7 @@
     # `function_def_to_graph` can find it.
     fn2_defun()
 
-    fdef = fn2_defun._inference_function.definition
+    fdef = fn2_defun.function_def
     func_graph = function_def_to_graph.function_def_to_graph(fdef)
     with func_graph.as_default():
       x_ph, y_ph = func_graph.inputs
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 7f679a4..58a1d37 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1388,6 +1388,20 @@
     self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
                      True)
 
+  def testImplementsReferenceAttrs(self):
+
+    @function.Defun(
+        dtypes.int32, _implements="org.google.lstm", _reference="arxiv.org")
+    def FunctionWithStrAttr(i):
+      return array_ops.identity(i)
+
+    self.assertIn("_implements", FunctionWithStrAttr.definition.attr)
+    self.assertEqual(FunctionWithStrAttr.definition.attr["_implements"].s,
+                     b"org.google.lstm")
+    self.assertIn("_reference", FunctionWithStrAttr.definition.attr)
+    self.assertEqual(FunctionWithStrAttr.definition.attr["_reference"].s,
+                     b"arxiv.org")
+
 
 class FunctionOverloadTest(test.TestCase):
 
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 5c131ab..ae7d9eb 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -45,6 +45,13 @@
     "VariableV2",
 }
 
+_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
+    "Switch",
+    "Enter",
+    "Exit",
+    "Identity",
+]
+
 
 def _is_variable_op(op):
   """Returns true if 'op' refers to a Variable node."""
@@ -290,11 +297,12 @@
       else:
         variable_names.append(variable_name + ":0")
     elif node.op in ["ReadVariableOp", "ResourceGather"]:
-      # There can be one or more Identity or Switch ops in between the
+      # There can be one or more Identity or control flow ops in between the
       # ReadVariableOp and VarHandleOp. Store the ops with the associated
       # dtypes.
       source_op_name = get_input_name(node)
-      while map_name_to_node[source_op_name].op in ["Identity", "Switch"]:
+      while (map_name_to_node[source_op_name].op in
+             _CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
         resource_op_types[source_op_name] = node.attr["dtype"]
         source_op_name = get_input_name(map_name_to_node[source_op_name])
       if map_name_to_node[source_op_name].op != "VarHandleOp":
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index d7626e9..1cae964 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -327,6 +327,10 @@
         sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
     ]
 
+  def _get_tensor_names(self, tensors):
+    """Returns a list of string names for the tensors specified."""
+    return [tensor.name.split(":")[0] for tensor in tensors]
+
   def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
     """Evaluates the GraphDef using Sessions."""
     with ops.Graph().as_default() as graph:
@@ -338,6 +342,19 @@
     return sess.run(
         output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
 
+  def _ensure_no_variables_in_graph(self, graph_def):
+    """Ensures there are no variables in the graph."""
+    for node in graph_def.node:
+      self.assertNotIn(
+          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
+
+  def _test_converted_keras_model(self, model, constant_graph_def, input_data):
+    """Compares the converted Keras model."""
+    expected_value = model.predict(input_data)
+    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
+                                            model.outputs, [input_data])
+    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
+
   def _test_variable_to_const_conversion(self, use_resource):
     with ops.Graph().as_default():
       with variable_scope.variable_scope("", use_resource=use_resource):
@@ -395,10 +412,7 @@
     with ops.Graph().as_default():
       _ = importer.import_graph_def(constant_graph_def, name="")
       self.assertEqual(4, len(constant_graph_def.node))
-      for node in constant_graph_def.node:
-        self.assertNotIn(
-            node.op,
-            ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
+      self._ensure_no_variables_in_graph(constant_graph_def)
       with session.Session() as sess:
         output_node = sess.graph.get_tensor_by_name("output_node:0")
         output = self.evaluate(output_node)
@@ -440,10 +454,7 @@
         constant_graph_def = graph_util.convert_variables_to_constants(
             sess, variable_graph_def, ["output_node"])
 
-    # Ensure there are no variables after freezing.
-    for node in constant_graph_def.node:
-      self.assertNotIn(
-          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
+    self._ensure_no_variables_in_graph(constant_graph_def)
 
   def testReferenceVariables(self):
     """Freezes a graph with reference variables."""
@@ -464,9 +475,6 @@
   @test_util.run_v1_only("Incompatible with TF 2.0")
   def testWithEmbeddings(self):
     """Freezes a graph with embeddings."""
-    input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
-
-    # Make model.
     state_input = keras.layers.Input(
         shape=(1,), name="state_input", dtype="int32")
     output = keras.layers.Embedding(
@@ -476,25 +484,19 @@
     model.compile(
         loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
 
-    # Get associated session.
+    # Freeze the graph.
     sess = keras.backend.get_session()
     variable_graph_def = sess.graph_def
-    output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
+    output_tensor = self._get_tensor_names(model.outputs)
     constant_graph_def = graph_util.convert_variables_to_constants(
         sess, variable_graph_def, output_tensor)
 
-    # Ensure graph has no variables.
-    for node in constant_graph_def.node:
-      self.assertNotIn(
-          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
+    # Validate converted graph.
+    input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
+    self._ensure_no_variables_in_graph(constant_graph_def)
+    self._test_converted_keras_model(model, constant_graph_def, input_data)
 
-    # Compare the value of the graphs.
-    expected_value = model.predict(input_data)
-    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
-                                            model.outputs, [input_data])
-    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
-
-  def testWithSwitch(self):
+  def testGraphWithSwitch(self):
     """Freezes a graph which contains a Switch with type RESOURCE_DT."""
     with ops.Graph().as_default():
       with variable_scope.variable_scope("", use_resource=True):
@@ -513,10 +515,25 @@
           constant_graph_def = graph_util.convert_variables_to_constants(
               sess, variable_graph_def, ["output_node"])
 
-    # Ensure there are no variables after freezing.
-    for node in constant_graph_def.node:
-      self.assertNotIn(
-          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
+    self._ensure_no_variables_in_graph(constant_graph_def)
+
+  @test_util.run_v1_only("Incompatible with TF 2.0")
+  def testLSTM(self):
+    """Freezes a Keras LSTM."""
+    model = keras.models.Sequential(
+        [keras.layers.LSTM(units=10, input_shape=(10, 10))])
+
+    # Freeze the model.
+    sess = keras.backend.get_session()
+    variable_graph_def = sess.graph_def
+    output_tensor = self._get_tensor_names(model.outputs)
+    constant_graph_def = graph_util.convert_variables_to_constants(
+        sess, variable_graph_def, output_tensor)
+
+    # Validate converted graph.
+    input_data = np.array(np.random.random_sample([10, 10, 10]), dtype=np.int32)
+    self._ensure_no_variables_in_graph(constant_graph_def)
+    self._test_converted_keras_model(model, constant_graph_def, input_data)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 3ba7176..95a7710 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -202,7 +202,8 @@
 
 
 def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
-                                     return_elements):
+                                     return_elements,
+                                     validate_colocation_constraints):
   """Populates the TF_ImportGraphDefOptions `options`."""
   c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
   c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
@@ -229,6 +230,9 @@
       c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
                                                        compat.as_str(name))
 
+  c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
+      options, validate_colocation_constraints)
+
 
 def _ProcessNewOps(graph):
   """Processes the newly-added TF_Operations in `graph`."""
@@ -392,6 +396,73 @@
       do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
       it refers to an unknown tensor).
   """
+  return _import_graph_def_internal(
+      graph_def,
+      input_map=input_map,
+      return_elements=return_elements,
+      name=name,
+      op_dict=op_dict,
+      producer_op_list=producer_op_list)
+
+
+def import_graph_def_for_function(  # pylint: disable=invalid-name
+    graph_def, name=None):
+  """Like import_graph_def but does not validate colocation constraints."""
+  return _import_graph_def_internal(
+      graph_def, validate_colocation_constraints=False, name=name)
+
+
+def _import_graph_def_internal(  # pylint: disable=invalid-name
+    graph_def,
+    input_map=None,
+    return_elements=None,
+    validate_colocation_constraints=True,
+    name=None,
+    op_dict=None,
+    producer_op_list=None):
+  """Imports the graph from `graph_def` into the current default `Graph`.
+
+  This function provides a way to import a serialized TensorFlow
+  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
+  protocol buffer, and extract individual objects in the `GraphDef` as
+  `tf.Tensor` and `tf.Operation` objects. Once extracted,
+  these objects are placed into the current default `Graph`. See
+  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
+  proto.
+
+  Args:
+    graph_def: A `GraphDef` proto containing operations to be imported into the
+      default graph.
+    input_map: A dictionary mapping input names (as strings) in `graph_def` to
+      `Tensor` objects. The values of the named input tensors in the imported
+      graph will be re-mapped to the respective `Tensor` values.
+    return_elements: A list of strings containing operation names in `graph_def`
+      that will be returned as `Operation` objects; and/or tensor names in
+      `graph_def` that will be returned as `Tensor` objects.
+    validate_colocation_constraints: Whether to validate colocation constraints.
+    name: (Optional.) A prefix that will be prepended to the names in
+      `graph_def`. Note that this does not apply to imported function names.
+      Defaults to `"import"`.
+    op_dict: (Optional.) Deprecated, do not use.
+    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
+      list of `OpDef`s used by the producer of the graph. If provided,
+      unrecognized attrs for ops in `graph_def` that have their default value
+      according to `producer_op_list` will be removed. This will allow some more
+      `GraphDef`s produced by later binaries to be accepted by earlier binaries.
+
+  Returns:
+    A list of `Operation` and/or `Tensor` objects from the imported graph,
+    corresponding to the names in `return_elements`,
+    and None if `returns_elements` is None.
+
+  Raises:
+    TypeError: If `graph_def` is not a `GraphDef` proto,
+      `input_map` is not a dictionary mapping strings to `Tensor` objects,
+      or `return_elements` is not a list of strings.
+    ValueError: If `input_map`, or `return_elements` contains names that
+      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
+      it refers to an unknown tensor).
+  """
   op_dict = op_def_registry.get_registered_ops()
 
   graph_def = _ProcessGraphDefParam(graph_def, op_dict)
@@ -416,8 +487,8 @@
 
   scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
   options = scoped_options.options
-  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
-                                   return_elements)
+  _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
+                                   validate_colocation_constraints)
 
   # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
   # Session.run call cannot occur between creating the TF_Operations in the
diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py
index 3026caa..8bc21ca 100644
--- a/tensorflow/python/framework/indexed_slices.py
+++ b/tensorflow/python/framework/indexed_slices.py
@@ -23,6 +23,7 @@
 import warnings
 import numpy as np
 
+from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
@@ -211,7 +212,7 @@
       self._dense_shape_dtype = None
     else:
       self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype)
-    self._indices_shape = tensor_shape.as_shape(indices_shape)
+    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)
 
   def _serialize(self):
     return (self._shape, self._values_dtype, self._indices_dtype,
@@ -235,7 +236,14 @@
       return (value.values, value.indices, value.dense_shape)
 
   def _from_components(self, tensor_list):
-    return IndexedSlices(*tensor_list)
+    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
+        not tf2.enabled()):
+      if len(tensor_list) == 2:
+        return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
+      else:
+        return IndexedSlicesValue(*tensor_list)
+    else:
+      return IndexedSlices(*tensor_list)
 
 
 @tf_export(v1=["convert_to_tensor_or_indexed_slices"])
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index a20cc83..eaa8c8d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -63,11 +63,18 @@
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lock_util
 from tensorflow.python.util import memory
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_stack
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.lazy_loader import LazyLoader
 from tensorflow.python.util.tf_export import tf_export
 
+ag_ctx = LazyLoader(
+    "ag_ctx", globals(),
+    "tensorflow.python.autograph.core.ag_ctx")
+
 
 # Temporary global switches determining if we should enable the work-in-progress
 # calls to the C API. These will be removed once all functionality is supported.
@@ -257,7 +264,7 @@
     text = "\n" + text
   return text
 
-
+@tf_export(v1=["enable_tensor_equality"])
 def enable_tensor_equality():
   """Compare Tensors with element-wise comparison and thus be unhashable.
 
@@ -268,7 +275,7 @@
   """
   Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
 
-
+@tf_export(v1=["disable_tensor_equality"])
 def disable_tensor_equality():
   """Compare Tensors by their id and be hashable.
 
@@ -500,11 +507,45 @@
     raise ValueError(
         "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
 
+  def _disallow_when_autograph_disabled(self, task):
+    raise errors.OperatorNotAllowedInGraphError(
+        "{} is not allowed: AutoGraph is disabled in this function."
+        " Try decorating it directly with @tf.function.".format(task))
+
+  def _disallow_when_autograph_enabled(self, task):
+    raise errors.OperatorNotAllowedInGraphError(
+        "{} is not allowed: AutoGraph did not convert this function. Try"
+        " decorating it directly with @tf.function.".format(task))
+
+  def _disallow_in_graph_mode(self, task):
+    raise errors.OperatorNotAllowedInGraphError(
+        "{} is not allowed in Graph execution. Use Eager execution or decorate"
+        " this function with @tf.function.".format(task))
+
+  def _disallow_bool_casting(self):
+    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
+      self._disallow_when_autograph_disabled(
+          "using a `tf.Tensor` as a Python `bool`")
+    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
+      self._disallow_when_autograph_enabled(
+          "using a `tf.Tensor` as a Python `bool`")
+    else:
+      # Default: V1-style Graph execution.
+      self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
+
+  def _disallow_iteration(self):
+    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
+      self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
+    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
+      self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
+    else:
+      # Default: V1-style Graph execution.
+      self._disallow_in_graph_mode("iterating over `tf.Tensor`")
+
   def __iter__(self):
     if not context.executing_eagerly():
-      raise TypeError(
-          "Tensor objects are only iterable when eager execution is "
-          "enabled. To iterate over this tensor use tf.map_fn.")
+      self._disallow_iteration()
+
     shape = self._shape_tuple()
     if shape is None:
       raise TypeError("Cannot iterate over a tensor with unknown shape.")
@@ -666,8 +707,11 @@
                                                    self._dtype.name)
 
   def __hash__(self):
-    if Tensor._USE_EQUALITY and executing_eagerly_outside_functions():
-      raise TypeError("Tensor is unhashable if Tensor equality is enabled.")
+    g = getattr(self, "graph", None)
+    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
+        (g is None or g._building_function)):  # pylint: disable=protected-access
+      raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
+                      "Instead, use tensor.experimental_ref() as the key.")
     else:
       return id(self)
 
@@ -687,6 +731,15 @@
   # with ndarrays.
   __array_priority__ = 100
 
+  def __array__(self):
+    raise NotImplementedError("Cannot convert a symbolic Tensor ({}) to a numpy"
+                              " array.".format(self.name))
+
+  def __len__(self):
+    raise TypeError("len is not well defined for symbolic Tensors. ({}) "
+                    "Please call `x.shape` rather than `len(x)` for "
+                    "shape information.".format(self.name))
+
   @staticmethod
   def _override_operator(operator, func):
     _override_helper(Tensor, operator, func)
@@ -695,8 +748,8 @@
     """Dummy method to prevent a tensor from being used as a Python `bool`.
 
     This overload raises a `TypeError` when the user inadvertently
-    treats a `Tensor` as a boolean (e.g. in an `if` statement). For
-    example:
+    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
+    statement), in code that was not converted by AutoGraph. For example:
 
     ```python
     if tf.constant(True):  # Will raise.
@@ -706,17 +759,10 @@
       # ...
     ```
 
-    This disallows ambiguities between testing the Python value vs testing the
-    dynamic condition of the `Tensor`.
-
     Raises:
       `TypeError`.
     """
-    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
-                    "Use `if t is not None:` instead of `if t:` to test if a "
-                    "tensor is defined, and use TensorFlow ops such as "
-                    "tf.cond to execute subgraphs conditioned on the value of "
-                    "a tensor.")
+    self._disallow_bool_casting()
 
   def __nonzero__(self):
     """Dummy method to prevent a tensor from being used as a Python `bool`.
@@ -726,11 +772,7 @@
     Raises:
       `TypeError`.
     """
-    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
-                    "Use `if t is not None:` instead of `if t:` to test if a "
-                    "tensor is defined, and use TensorFlow ops such as "
-                    "tf.cond to execute subgraphs conditioned on the value of "
-                    "a tensor.")
+    self._disallow_bool_casting()
 
   def eval(self, feed_dict=None, session=None):
     """Evaluates this tensor in a `Session`.
@@ -755,6 +797,59 @@
     """
     return _eval_using_default_session(self, feed_dict, self.graph, session)
 
+  def experimental_ref(self):
+    # tf.Variable also has the same experimental_ref() API.  If you update the
+    # documenation here, please update tf.Variable.experimental_ref() as well.
+    """Returns a hashable reference object to this Tensor.
+
+    Warning: Experimental API that could be changed or removed.
+
+    The primary usecase for this API is to put tensors in a set/dictionary.
+    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
+    available starting Tensorflow 2.0.
+
+    ```python
+    import tensorflow as tf
+
+    x = tf.constant(5)
+    y = tf.constant(10)
+    z = tf.constant(10)
+
+    # The followings will raise an exception starting 2.0
+    # TypeError: Tensor is unhashable if Tensor equality is enabled.
+    tensor_set = {x, y, z}
+    tensor_dict = {x: 'five', y: 'ten', z: 'ten'}
+    ```
+
+    Instead, we can use `tensor.experimental_ref()`.
+
+    ```python
+    tensor_set = {x.experimental_ref(),
+                  y.experimental_ref(),
+                  z.experimental_ref()}
+
+    print(x.experimental_ref() in tensor_set)
+    ==> True
+
+    tensor_dict = {x.experimental_ref(): 'five',
+                   y.experimental_ref(): 'ten',
+                   z.experimental_ref(): 'ten'}
+
+    print(tensor_dict[y.experimental_ref()])
+    ==> ten
+    ```
+
+    Also, the reference object provides `.deref()` function that returns the
+    original Tensor.
+
+    ```python
+    x = tf.constant(5)
+    print(x.experimental_ref().deref())
+    ==> tf.Tensor(5, shape=(), dtype=int32)
+    ```
+    """
+    return object_identity.Reference(self)
+
 
 # TODO(agarwal): consider getting rid of this.
 class _EagerTensorBase(Tensor):
@@ -893,7 +988,7 @@
     """
     raise NotImplementedError()
 
-  def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
+  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
     raise NotImplementedError()
 
   @staticmethod
@@ -902,7 +997,6 @@
 
   def _copy_nograd(self, ctx=None, device_name=None):
     """Copies tensor to dest device, but doesn't record the operation."""
-    # pylint: disable=protected-access
     # Creates a new tensor on the dest device.
     if ctx is None:
       ctx = context.context()
@@ -911,7 +1005,7 @@
     # pylint: disable=protected-access
     try:
       ctx.ensure_initialized()
-      new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
+      new_tensor = self._copy_to_device(device_name)
     except core._NotOkStatusException as e:
       six.raise_from(core._status_to_exception(e.code, e.message), None)
     return new_tensor
@@ -1158,19 +1252,21 @@
                                as_ref=False,
                                preferred_dtype=None,
                                ctx=None,
-                               accept_composite_tensors=False):
+                               accepted_result_types=(Tensor,)):
   """Implementation of the public convert_to_tensor."""
+  if isinstance(value, EagerTensor):
+    if ctx is None:
+      ctx = context.context()
+    if not ctx.executing_eagerly():
+      graph = get_default_graph()
+      if not graph.building_function:
+        raise RuntimeError("Attempting to capture an EagerTensor without "
+                           "building a function.")
+      return graph.capture(value, name=name)
+
   if dtype is not None:
     dtype = dtypes.as_dtype(dtype)
-  if ctx is None:
-    ctx = context.context()
-  if isinstance(value, EagerTensor) and not ctx.executing_eagerly():
-    graph = get_default_graph()
-    if not graph.building_function:
-      raise RuntimeError("Attempting to capture an EagerTensor without "
-                         "building a function.")
-    return graph.capture(value, name=name)
-  elif isinstance(value, Tensor):
+  if isinstance(value, Tensor):
     if dtype is not None and not dtype.is_compatible_with(value.dtype):
       raise ValueError(
           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
@@ -1179,7 +1275,6 @@
 
   if preferred_dtype is not None:
     preferred_dtype = dtypes.as_dtype(preferred_dtype)
-
   for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
     # If dtype is None but preferred_dtype is not None, we try to
     # cast to preferred_dtype first.
@@ -1204,11 +1299,7 @@
     if ret is NotImplemented:
       continue
 
-    is_acceptable_type = (
-        isinstance(ret, Tensor) or
-        (accept_composite_tensors and
-         isinstance(ret, composite_tensor.CompositeTensor)))
-    if not is_acceptable_type:
+    if not isinstance(ret, accepted_result_types):
       raise RuntimeError(
           "%sConversion function %r for type %s returned non-Tensor: %r" %
           (_error_prefix(name), conversion_func, base_type, ret))
@@ -1254,7 +1345,7 @@
     RuntimeError: If a registered conversion function returns an invalid
       value.
   """
-  if not isinstance(values, collections.Sequence):
+  if not isinstance(values, collections_abc.Sequence):
     raise TypeError("values must be a sequence.")
   ret = []
   if ctx is None:
@@ -1362,7 +1453,7 @@
         dtype=dtype,
         name=name,
         as_ref=as_ref,
-        accept_composite_tensors=True)
+        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))
 
 
 def internal_convert_n_to_tensor_or_composite(values,
@@ -1391,7 +1482,7 @@
     RuntimeError: If a registered conversion function returns an invalid
       value.
   """
-  if not isinstance(values, collections.Sequence):
+  if not isinstance(values, collections_abc.Sequence):
     raise TypeError("values must be a sequence.")
   ret = []
   for i, value in enumerate(values):
@@ -2041,8 +2132,6 @@
     """The list of `Tensor` objects representing the outputs of this op."""
     return self._outputs
 
-# pylint: disable=protected-access
-
   class _InputList(object):
     """Immutable input list wrapper."""
 
@@ -2064,9 +2153,6 @@
     def __getitem__(self, i):
       return self._inputs[i]
 
-
-# pylint: enable=protected-access
-
   @property
   def inputs(self):
     """The list of `Tensor` objects representing the data inputs of this op."""
@@ -2785,7 +2871,7 @@
     # self._thread_local._colocation_stack is used instead.
     self._graph_colocation_stack = traceable_stack.TraceableStack()
     # Set of tensors that are dangerous to feed!
-    self._unfeedable_tensors = set()
+    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
     # Set of operations that are dangerous to fetch!
     self._unfetchable_ops = set()
     # A map of tensor handle placeholder to tensor dtype.
@@ -2817,6 +2903,14 @@
     self._add_control_dependencies = False
     # Cache for OpDef protobufs retrieved via the C API.
     self._op_def_cache = {}
+    # Cache for constant results of `broadcast_gradient_args()`. The keys are
+    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
+    # values are tuples of reduction indices: (rx, ry).
+    self._bcast_grad_args_cache = {}
+    # Cache for constant results of `reduced_shape()`. The keys are pairs of
+    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
+    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
+    self._reduced_shape_cache = {}
 
     # TODO(skyewm): fold as much of the above as possible into the C
     # implementation
@@ -4961,9 +5055,10 @@
     return self._thread_local._auto_cast_variable_read_dtype  # pylint: disable=protected-access
 
   @_auto_cast_variable_read_dtype.setter
-  def _auto_cast_variable_read_dtype(self, _auto_cast_variable_read_dtype):
-    self._thread_local._auto_cast_variable_read_dtype = (  # pylint: disable=protected-access
-        _auto_cast_variable_read_dtype)
+  def _auto_cast_variable_read_dtype(self, dtype):
+    if dtype:
+      dtype = dtypes.as_dtype(dtype)
+    self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 
   @tf_contextlib.contextmanager
   def _enable_auto_casting_variables(self, dtype):
@@ -5847,8 +5942,9 @@
     The appropriate graph to use for the given inputs.
 
   """
-  if get_default_graph().building_function:
-    return get_default_graph()
+  current_default_graph = get_default_graph()
+  if current_default_graph.building_function:
+    return current_default_graph
 
   op_input_list = tuple(op_input_list)  # Handle generators correctly
   if graph and not isinstance(graph, Graph):
@@ -5881,7 +5977,7 @@
         raise ValueError("%s is not from the passed-in graph." % graph_element)
 
   # 2. If all else fails, we use the default graph, which is always there.
-  return graph or get_default_graph()
+  return graph or current_default_graph
 
 
 @tf_export(v1=["GraphKeys"])
@@ -6226,15 +6322,21 @@
         raise ValueError(
             "At least one of name (%s) and default_name (%s) must be provided."
             % (self._name, self._default_name))
-      if self._values is None:
-        self._values = []
-      if self._values:
-        g = _get_graph_from_inputs(self._values)
-        self._g_manager = g.as_default()
-        self._g_manager.__enter__()
+
+      g = get_default_graph()
+      if self._values and not g.building_function:
+        # Specialize based on the knowledge that `_get_graph_from_inputs()`
+        # ignores `inputs` when building a function.
+        g_from_inputs = _get_graph_from_inputs(self._values)
+        if g_from_inputs is not g:
+          g = g_from_inputs
+          self._g_manager = g.as_default()
+          self._g_manager.__enter__()
+        else:
+          self._g_manager = None
       else:
-        g = get_default_graph()
         self._g_manager = None
+
       try:
         self._name_scope = g.name_scope(self._name)
         return self._name_scope.__enter__()
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 1b272cf..6aae914 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -18,14 +18,18 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import gc
+import numpy as np
 import os
 import threading
 import weakref
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.client import session
+from tensorflow.python.compat import compat as forward_compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -38,9 +42,11 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
+from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
@@ -99,13 +105,47 @@
     self.assertEqual([1, 2, 3], t.get_shape())
 
   def testIterable(self):
+    if not context.executing_eagerly():
+      self.skipTest("Eager-mode test")
     op = ops.Operation(
         ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
     t = op.outputs[0]
-    self.assertTrue(isinstance(t, ops.Tensor))
-    with self.assertRaisesRegexp(TypeError, "iter"):
-      for _ in t:
-        pass
+    with self.assertRaisesRegexp(TypeError, "Cannot iterate"):
+      next(iter(t))
+
+  def testIterableGraph(self):
+    if context.executing_eagerly():
+      self.skipTest("Graph-mode test")
+
+    op = ops.Operation(
+        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
+    t = op.outputs[0]
+    with self.assertRaisesRegexp(TypeError, "iterating.*not allowed in Graph"):
+      next(iter(t))
+    with self.assertRaisesRegexp(
+        TypeError, "iterating.*AutoGraph did not convert"):
+      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
+        next(iter(t))
+    with self.assertRaisesRegexp(
+        TypeError, "iterating.*AutoGraph is disabled"):
+      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
+        next(iter(t))
+
+  def testImplicitBool(self):
+    op = ops.Operation(
+        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool])
+    t = op.outputs[0]
+    with self.assertRaisesRegexp(
+        TypeError, "using.*as a.*bool.*not allowed in Graph"):
+      bool(t)
+    with self.assertRaisesRegexp(
+        TypeError, "using.*as a.*bool.*AutoGraph did not convert"):
+      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
+        bool(t)
+    with self.assertRaisesRegexp(
+        TypeError, "using.*as a.*bool.*AutoGraph is disabled"):
+      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
+        bool(t)
 
   def testAddShape(self):
     with self.cached_session():
@@ -148,6 +188,130 @@
           r"\(op: 'Add(V2)?'\) with input shapes: \[1,2,3\], \[4,5,6\]."):
         _ = a + b
 
+  def testNumpyArray(self):
+    with ops.Graph().as_default():
+      x = array_ops.ones((3, 4), name="test_ones")
+
+    with self.assertRaisesRegexp(NotImplementedError,
+                                 r"Cannot convert a symbolic.+test_ones"):
+      np.array(x)
+
+    with self.assertRaisesRegexp(TypeError, "not well defined.+test_ones"):
+      len(x)
+
+    # EagerTensors should still behave as numpy arrays.
+    with context.eager_mode():
+      x = array_ops.ones((3, 4))
+
+    self.assertAllEqual(x, np.ones((3, 4)))
+    self.assertAllEqual(np.array(x), np.ones((3, 4)))
+    self.assertEqual(len(x), 3)
+
+  def testRef(self):
+    x1 = constant_op.constant(3)
+    x2 = x1
+    y = constant_op.constant(3)
+    z = constant_op.constant([6, 10])
+    w = variables.Variable(5)
+
+    self.assertEqual(x1.experimental_ref(), x1.experimental_ref())
+    self.assertEqual(x2.experimental_ref(), x2.experimental_ref())
+    self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
+    self.assertEqual(y.experimental_ref(), y.experimental_ref())
+    self.assertEqual(z.experimental_ref(), z.experimental_ref())
+    self.assertEqual(w.experimental_ref(), w.experimental_ref())
+
+    self.assertNotEqual(x1.experimental_ref(), y.experimental_ref())
+    self.assertNotEqual(x1.experimental_ref(), z.experimental_ref())
+    self.assertNotEqual(x1.experimental_ref(), w.experimental_ref())
+    self.assertNotEqual(y.experimental_ref(), z.experimental_ref())
+    self.assertNotEqual(y.experimental_ref(), w.experimental_ref())
+    self.assertNotEqual(z.experimental_ref(), w.experimental_ref())
+
+  def testRefDeref(self):
+    x1 = constant_op.constant(3)
+    x2 = x1
+    y = constant_op.constant(3)
+    z = constant_op.constant([6, 10])
+    w = variables.Variable(5)
+
+    self.assertIs(x1, x1.experimental_ref().deref())
+    self.assertIs(x2, x2.experimental_ref().deref())
+    self.assertIs(x1, x2.experimental_ref().deref())
+    self.assertIs(x2, x1.experimental_ref().deref())
+    self.assertIs(y, y.experimental_ref().deref())
+    self.assertIs(z, z.experimental_ref().deref())
+
+    self.assertIsNot(x1, y.experimental_ref().deref())
+    self.assertIsNot(x1, z.experimental_ref().deref())
+    self.assertIsNot(x1, w.experimental_ref().deref())
+    self.assertIsNot(y, z.experimental_ref().deref())
+    self.assertIsNot(y, w.experimental_ref().deref())
+    self.assertIsNot(z, w.experimental_ref().deref())
+
+  def testRefInSet(self):
+    x1 = constant_op.constant(3)
+    x2 = x1
+    y = constant_op.constant(3)
+    z = constant_op.constant([6, 10])
+    w = variables.Variable(5)
+
+    self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
+
+    tensor_set = {
+        x1.experimental_ref(),
+        x2.experimental_ref(),
+        y.experimental_ref(),
+        z.experimental_ref(),
+        w.experimental_ref(),
+    }
+
+    self.assertEqual(len(tensor_set), 4)
+    self.assertIn(x1.experimental_ref(), tensor_set)
+    self.assertIn(x2.experimental_ref(), tensor_set)
+    self.assertIn(y.experimental_ref(), tensor_set)
+    self.assertIn(z.experimental_ref(), tensor_set)
+    self.assertIn(w.experimental_ref(), tensor_set)
+
+  def testRefInDict(self):
+    x1 = constant_op.constant(3)
+    x2 = x1
+    y = constant_op.constant(3)
+    z = constant_op.constant([6, 10])
+    w = variables.Variable(5)
+
+    self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
+
+    tensor_dict = {
+        x1.experimental_ref(): "x1",
+        y.experimental_ref(): "y",
+        z.experimental_ref(): "z",
+        w.experimental_ref(): "w",
+    }
+
+    self.assertEqual(len(tensor_dict), 4)
+
+    # Overwriting x1
+    tensor_dict[x2.experimental_ref()] = "x2"
+    self.assertEqual(len(tensor_dict), 4)
+
+    self.assertEqual(tensor_dict[x1.experimental_ref()], "x2")
+    self.assertEqual(tensor_dict[x2.experimental_ref()], "x2")
+    self.assertEqual(tensor_dict[y.experimental_ref()], "y")
+    self.assertEqual(tensor_dict[z.experimental_ref()], "z")
+    self.assertEqual(tensor_dict[w.experimental_ref()], "w")
+
+  def testTensorRefStrong(self):
+    x = constant_op.constant(1.)
+    x_ref = x.experimental_ref()
+    del x
+    self.assertIsNotNone(x_ref.deref())
+
+  def testVariableRefStrong(self):
+    x = variables.Variable(1.)
+    x_ref = x.experimental_ref()
+    del x
+    self.assertIsNotNone(x_ref.deref())
 
 class IndexedSlicesTest(test_util.TensorFlowTestCase):
 
@@ -191,6 +355,136 @@
       self.assertAllEqual(x.indices.eval(), [0, 2])
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
+                            parameterized.TestCase):
+
+  def assertAllTensorsEqual(self, list1, list2):
+    self.assertLen(list1, len(list2))
+    for (t1, t2) in zip(list1, list2):
+      self.assertAllEqual(t1, t2)
+
+  def testConstruction(self):
+    spec1 = indexed_slices.IndexedSlicesSpec()
+    self.assertEqual(spec1._shape.rank, None)
+    self.assertEqual(spec1._values_dtype, dtypes.float32)
+    self.assertEqual(spec1._indices_dtype, dtypes.int64)
+    self.assertEqual(spec1._dense_shape_dtype, None)
+    self.assertEqual(spec1._indices_shape.as_list(), [None])
+
+    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
+                                             dtypes.int32, dtypes.int64, [10])
+    self.assertEqual(spec2._shape.as_list(), [None, None])
+    self.assertEqual(spec2._values_dtype, dtypes.string)
+    self.assertEqual(spec2._indices_dtype, dtypes.int32)
+    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
+    self.assertEqual(spec2._indices_shape.as_list(), [10])
+
+  def testValueType(self):
+    spec1 = indexed_slices.IndexedSlicesSpec()
+    self.assertEqual(spec1.value_type, ops.IndexedSlices)
+
+  @parameterized.parameters([
+      (indexed_slices.IndexedSlicesSpec(),
+       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
+        tensor_shape.TensorShape([None]))),
+      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
+       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
+        dtypes.int64, None, tensor_shape.TensorShape([None]))),
+      (indexed_slices.IndexedSlicesSpec(
+          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
+       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
+        dtypes.int64, tensor_shape.TensorShape([None]))),
+      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
+       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
+        tensor_shape.TensorShape([100]))),
+  ])  # pyformat: disable
+  def testSerialize(self, spec, expected):
+    serialization = spec._serialize()
+    # TensorShape has an unconventional definition of equality, so we can't use
+    # assertEqual directly here.  But repr() is deterministic and lossless for
+    # the expected values, so we can use that instead.
+    self.assertEqual(repr(serialization), repr(expected))
+
+  @parameterized.parameters([
+      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
+          tensor_spec.TensorSpec(None, dtypes.string),
+          tensor_spec.TensorSpec([None], dtypes.int64),
+      )),
+      (indexed_slices.IndexedSlicesSpec(
+          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
+              tensor_spec.TensorSpec(None, dtypes.string),
+              tensor_spec.TensorSpec([None], dtypes.int64),
+              tensor_spec.TensorSpec([None], dtypes.int32),
+          )),
+      (indexed_slices.IndexedSlicesSpec(
+          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
+              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
+              tensor_spec.TensorSpec([None], dtypes.int64),
+              tensor_spec.TensorSpec([3], dtypes.int32),
+          )),
+      (indexed_slices.IndexedSlicesSpec(
+          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
+          indices_shape=[20]), (
+              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
+              tensor_spec.TensorSpec([20], dtypes.int64),
+              tensor_spec.TensorSpec([3], dtypes.int32),
+          )),
+  ])
+  def testComponentSpecs(self, spec, expected):
+    self.assertEqual(spec._component_specs, expected)
+
+  @parameterized.parameters([
+      {
+          "spec": indexed_slices.IndexedSlicesSpec(),
+          "values": [3.0, 5.0],
+          "indices": [5, 10]
+      },
+      {
+          "spec":
+              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
+          "values": [3.0, 5.0],
+          "indices": [5, 10],
+          "dense_shape": [100]
+      },
+  ])
+  def testToFromComponents(self, spec, indices, values, dense_shape=None):
+    x = ops.IndexedSlices(indices, values, dense_shape)
+    actual_components = spec._to_components(x)
+    if dense_shape is None:
+      self.assertAllTensorsEqual(actual_components, [indices, values])
+    else:
+      self.assertAllTensorsEqual(actual_components,
+                                 [indices, values, dense_shape])
+    st_reconstructed = spec._from_components(actual_components)
+    self.assertAllEqual(x.indices, st_reconstructed.indices)
+    self.assertAllEqual(x.values, st_reconstructed.values)
+    if dense_shape is None:
+      self.assertIs(st_reconstructed.dense_shape, None)
+    else:
+      self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)
+
+  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
+  def testFromNumpyComponents(self):
+    indices = np.array([3, 8])
+    values = np.array([1.0, 9.0])
+    dense_shape = np.array([100])
+
+    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
+    st1 = spec1._from_components((values, indices, dense_shape))
+    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
+    self.assertAllEqual(st1.indices, indices)
+    self.assertAllEqual(st1.values, values)
+    self.assertAllEqual(st1.dense_shape, dense_shape)
+
+    spec2 = indexed_slices.IndexedSlicesSpec()
+    st2 = spec2._from_components((values, indices))
+    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
+    self.assertAllEqual(st2.indices, indices)
+    self.assertAllEqual(st2.values, values)
+    self.assertIs(st2.dense_shape, None)
+
+
 class NodeDefConstructorTest(test_util.TensorFlowTestCase):
 
   def testNoArgs(self):
@@ -413,6 +707,13 @@
       ops.convert_to_tensor(values, dtype=dtypes.int64)
 
   @test_util.run_in_graph_and_eager_modes
+  def testConvertToLongLongTensorType(self):
+    tensor = ops.convert_to_tensor(
+        # Get a numpy array of dtype NPY_LONGLONG
+        np.prod(constant_op.constant([1])._shape_tuple()))
+    self.assertEqual(dtypes.int64, tensor.dtype)
+
+  @test_util.run_in_graph_and_eager_modes
   def testConvertToTensorFromInvalidTensor(self):
     tensor = constant_op.constant(42.0, dtype=dtypes.float32)
     with self.assertRaises(ValueError):
@@ -624,33 +925,34 @@
   @test_util.enable_control_flow_v2
   @test_util.run_v1_only("b/120545219")
   def testAddWhileInput(self):
-    @eager_function.defun
-    def test():
-      output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
-                                           [1])
-      while_op = output.op.inputs[0].op
-      self.assertEqual(while_op.type, "While")
-      orig_num_inputs = len(while_op.inputs)
+    if forward_compat.forward_compatible(2019, 8, 23):
+      @eager_function.defun
+      def test():
+        output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
+                                             [1])
+        while_op = output.op.inputs[0].op
+        self.assertEqual(while_op.type, "StatelessWhile")
+        orig_num_inputs = len(while_op.inputs)
 
-      # Make sure we can handle the while op having a control input.
-      while_op._add_control_input(constant_op.constant(0).op)
+        # Make sure we can handle the while op having a control input.
+        while_op._add_control_input(constant_op.constant(0).op)
 
-      new_input1 = constant_op.constant(1.0)
-      new_input2 = constant_op.constant(True)
+        new_input1 = constant_op.constant(1.0)
+        new_input2 = constant_op.constant(True)
 
-      # Clear output shapes to bypass shape checking.
-      while_op._set_shape_list_attr("output_shapes", [])
-      while_op._set_type_list_attr("T",
-                                   [t.dtype for t in while_op.inputs] +
-                                   [new_input1.dtype, new_input2.dtype])
+        # Clear output shapes to bypass shape checking.
+        while_op._set_shape_list_attr("output_shapes", [])
+        while_op._set_type_list_attr("T",
+                                     [t.dtype for t in while_op.inputs] +
+                                     [new_input1.dtype, new_input2.dtype])
 
-      while_op._add_while_inputs([new_input1, new_input2])
-      # Can't add an edge beyond what's specified by "T"
-      with self.assertRaises(errors.OutOfRangeError):
-        while_op._add_while_inputs([new_input2])
-      self.assertEqual(len(while_op.inputs), orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert
+        while_op._add_while_inputs([new_input1, new_input2])
+        # Can't add an edge beyond what's specified by "T"
+        with self.assertRaises(errors.OutOfRangeError):
+          while_op._add_while_inputs([new_input2])
+        self.assertEqual(len(while_op.inputs), orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert
 
-    test()
+      test()
 
   @test_util.run_deprecated_v1
   def testOpDef(self):
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 788d0e9..a598f43 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -22,6 +22,7 @@
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import tf2
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -279,6 +280,16 @@
     return (self._shape, self._dtype)
 
   @property
+  def dtype(self):
+    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
+    return self._dtype
+
+  @property
+  def shape(self):
+    """The `tf.TensorShape` specified by this type for the SparseTensor."""
+    return self._shape
+
+  @property
   def _component_specs(self):
     rank = self._shape.ndims
     num_values = None
@@ -293,7 +304,11 @@
     return [value.indices, value.values, value.dense_shape]
 
   def _from_components(self, tensor_list):
-    return SparseTensor(*tensor_list)
+    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
+        not tf2.enabled()):
+      return SparseTensorValue(*tensor_list)
+    else:
+      return SparseTensor(*tensor_list)
 
   # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
   # to (un)box the component tensors in a way that allows for batching &
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index 03aa63b..0202a83 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -18,12 +18,15 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import googletest
@@ -108,5 +111,146 @@
             sparse_tensor_value.dense_shape, convertee.dense_shape)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class SparseTensorSpecTest(test_util.TensorFlowTestCase,
+                           parameterized.TestCase):
+
+  def assertAllTensorsEqual(self, list1, list2):
+    self.assertLen(list1, len(list2))
+    for (t1, t2) in zip(list1, list2):
+      self.assertAllEqual(t1, t2)
+
+  def testConstruction(self):
+    spec1 = sparse_tensor.SparseTensorSpec()
+    self.assertEqual(spec1.shape.rank, None)
+    self.assertEqual(spec1.dtype, dtypes.float32)
+
+    spec2 = sparse_tensor.SparseTensorSpec([None, None], dtypes.string)
+    self.assertEqual(spec2.shape.as_list(), [None, None])
+    self.assertEqual(spec2.dtype, dtypes.string)
+
+  def testValueType(self):
+    spec1 = sparse_tensor.SparseTensorSpec()
+    self.assertEqual(spec1.value_type, sparse_tensor.SparseTensor)
+
+  @parameterized.parameters([
+      (sparse_tensor.SparseTensorSpec(),
+       (tensor_shape.TensorShape(None), dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec(shape=[5, None, None]),
+       (tensor_shape.TensorShape([5, None, None]), dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec(dtype=dtypes.int32),
+       (tensor_shape.TensorShape(None), dtypes.int32)),
+  ])  # pyformat: disable
+  def testSerialize(self, st_spec, expected):
+    serialization = st_spec._serialize()
+    # TensorShape has an unconventional definition of equality, so we can't use
+    # assertEqual directly here.  But repr() is deterministic and lossless for
+    # the expected values, so we can use that instead.
+    self.assertEqual(repr(serialization), repr(expected))
+
+  @parameterized.parameters([
+      (sparse_tensor.SparseTensorSpec(dtype=dtypes.string), [
+          tensor_spec.TensorSpec([None, None], dtypes.int64),
+          tensor_spec.TensorSpec([None], dtypes.string),
+          tensor_spec.TensorSpec([None], dtypes.int64)
+      ]),
+      (sparse_tensor.SparseTensorSpec(shape=[5, None, None]), [
+          tensor_spec.TensorSpec([None, 3], dtypes.int64),
+          tensor_spec.TensorSpec([None], dtypes.float32),
+          tensor_spec.TensorSpec([3], dtypes.int64)
+      ]),
+  ])
+  def testComponentSpecs(self, st_spec, expected):
+    self.assertEqual(st_spec._component_specs, expected)
+
+  @parameterized.parameters([
+      {
+          "st_spec": sparse_tensor.SparseTensorSpec(),
+          "indices": [[0, 1], [10, 8]],
+          "values": [3.0, 5.0],
+          "dense_shape": [100, 100]
+      },
+      {
+          "st_spec": sparse_tensor.SparseTensorSpec([100, None, None]),
+          "indices": [[0, 1, 3], [10, 8, 2]],
+          "values": [3.0, 5.0],
+          "dense_shape": [100, 20, 20]
+      },
+  ])
+  def testToFromComponents(self, st_spec, indices, values, dense_shape):
+    st = sparse_tensor.SparseTensor(indices, values, dense_shape)
+    actual_components = st_spec._to_components(st)
+    self.assertAllTensorsEqual(actual_components,
+                               [indices, values, dense_shape])
+    st_reconstructed = st_spec._from_components(actual_components)
+    self.assertAllEqual(st.indices, st_reconstructed.indices)
+    self.assertAllEqual(st.values, st_reconstructed.values)
+    self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
+
+  @test_util.run_v1_only("SparseTensorValue is deprecated in v2")
+  def testFromNumpyComponents(self):
+    indices = np.array([[0], [8]])
+    values = np.array([1.0, 9.0])
+    dense_shape = np.array([100])
+    spec = sparse_tensor.SparseTensorSpec()
+    st = spec._from_components([indices, values, dense_shape])
+    self.assertIsInstance(st, sparse_tensor.SparseTensorValue)
+    self.assertAllEqual(st.indices, indices)
+    self.assertAllEqual(st.values, values)
+    self.assertAllEqual(st.dense_shape, dense_shape)
+
+  @parameterized.parameters([
+      sparse_tensor.SparseTensorSpec(dtype=dtypes.string),
+      sparse_tensor.SparseTensorSpec(shape=[5, None, None]),
+  ])
+  def testFlatTensorSpecs(self, st_spec):
+    self.assertEqual(st_spec._flat_tensor_specs,
+                     [tensor_spec.TensorSpec(None, dtypes.variant)])
+
+  @parameterized.parameters([
+      {
+          "st_spec": sparse_tensor.SparseTensorSpec(),
+          "indices": [[0, 1], [10, 8]],
+          "values": [3.0, 5.0],
+          "dense_shape": [100, 100]
+      },
+      {
+          "st_spec": sparse_tensor.SparseTensorSpec([100, None, None]),
+          "indices": [[0, 1, 3], [10, 8, 2]],
+          "values": [3.0, 5.0],
+          "dense_shape": [100, 20, 20]
+      },
+  ])
+  def testToFromTensorList(self, st_spec, indices, values, dense_shape):
+    st = sparse_tensor.SparseTensor(indices, values, dense_shape)
+    tensor_list = st_spec._to_tensor_list(st)
+    st_reconstructed = st_spec._from_tensor_list(tensor_list)
+    self.assertAllEqual(st.indices, st_reconstructed.indices)
+    self.assertAllEqual(st.values, st_reconstructed.values)
+    self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
+
+  @parameterized.parameters([
+      (sparse_tensor.SparseTensorSpec([2, None], dtypes.float32), 32,
+       sparse_tensor.SparseTensorSpec([32, 2, None], dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec([4, None], dtypes.float32), None,
+       sparse_tensor.SparseTensorSpec([None, 4, None], dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec([2], dtypes.float32), 32,
+       sparse_tensor.SparseTensorSpec([32, 2], dtypes.float32)),
+  ])
+  def testBatch(self, spec, batch_size, expected):
+    self.assertEqual(spec._batch(batch_size), expected)
+
+  @parameterized.parameters([
+      (sparse_tensor.SparseTensorSpec([32, None, None], dtypes.float32),
+       sparse_tensor.SparseTensorSpec([None, None], dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec([None, None, None], dtypes.float32),
+       sparse_tensor.SparseTensorSpec([None, None], dtypes.float32)),
+      (sparse_tensor.SparseTensorSpec([32, 2], dtypes.float32),
+       sparse_tensor.SparseTensorSpec([2], dtypes.float32)),
+  ])
+  def testUnbatch(self, spec, expected):
+    self.assertEqual(spec._unbatch(), expected)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index daf4b09..caf97df 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -22,6 +22,7 @@
 
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_like
@@ -332,6 +333,11 @@
 
 def _is_array_like(obj):  # pylint: disable=invalid-name
   """Check if a given object is array-like."""
+  if isinstance(obj, ops.Tensor) and not isinstance(obj, ops._EagerTensorBase):  # pylint: disable=protected-access
+    # Tensor implements __array__ only so it can inform the user that it is not
+    # a valid array.
+    return False
+
   # TODO(slebedev): an object could also implement C-level array interface.
   if (callable(getattr(obj, "__array__", None)) or
       isinstance(getattr(obj, "__array_interface__", None), dict)):
@@ -904,6 +910,18 @@
       pass
     except TypeError:  # Could come from slicing prev.
       pass
+  elif (tensor.op.type == "Placeholder" and
+        tensor.op.graph.building_function and
+        hasattr(tensor.op.graph, "internal_captures")):
+    # If we are inside a FuncGraph try to lookup the constant value of the
+    # corresponding external capture. Note that we only look at captures and
+    # not the fed inputs because those can be fed different values in different
+    # instantiations of the function call or different iterations of a
+    # tf.while_loop.
+    for i, capture in enumerate(tensor.op.graph.internal_captures):
+      if capture is tensor:
+        external_capture = tensor.op.graph.external_captures[i]
+        return constant_value_as_shape(external_capture)
 
   ret = tensor_shape.unknown_shape(shape.dims[0].value)
   value = constant_value(tensor)
@@ -944,3 +962,12 @@
       # not convertible to Tensors becasue of mixed content.
       shape = tuple(map(tensor_shape.dimension_value, shape))
   return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
+
+
+def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
+  if (not context.executing_eagerly() and
+      ops.get_default_graph().building_function and
+      not tensor.shape.is_fully_defined()):
+    shape = shape_tensor(shape)
+    const_shape = constant_value_as_shape(shape)
+    tensor.set_shape(const_shape)
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index 5d1386c..550d5ba 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -96,13 +96,13 @@
                    ctx->allocate_output("result", TensorShape({}), &output));
     switch (KL) {
       case DEFAULT_LABEL:
-        output->scalar<string>()() = "My label is: default";
+        output->scalar<tstring>()() = "My label is: default";
         break;
       case OVERLOAD_1_LABEL:
-        output->scalar<string>()() = "My label is: overload_1";
+        output->scalar<tstring>()() = "My label is: overload_1";
         break;
       case OVERLOAD_2_LABEL:
-        output->scalar<string>()() = "My label is: overload_2";
+        output->scalar<tstring>()() = "My label is: overload_2";
         break;
     }
   }
@@ -676,7 +676,7 @@
     Tensor* output;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output("device", TensorShape({}), &output));
-    output->scalar<string>()() = ctx->device()->name();
+    output->scalar<tstring>()() = ctx->device()->name();
   }
 };
 
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4eaae12..8857e76 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -19,7 +19,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 from collections import OrderedDict
 import contextlib
 import functools
@@ -69,6 +68,7 @@
 from tensorflow.python.framework import versions
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.ops import script_ops
@@ -83,6 +83,7 @@
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.protobuf import compare
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.compat import collections_abc
 
 
 # If the below import is made available through the BUILD rule, then this
@@ -536,6 +537,29 @@
   return wrapper
 
 
+def enable_output_all_intermediates(fn):
+  """Force-enable outputing all intermediates from functional control flow ops.
+
+  Args:
+    fn: the function to be wrapped
+
+  Returns:
+    The wrapped function
+  """
+
+  def wrapper(*args, **kwargs):
+    output_all_intermediates_old = \
+        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
+    try:
+      return fn(*args, **kwargs)
+    finally:
+      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
+          output_all_intermediates_old
+
+  return wrapper
+
+
 def assert_no_new_pyobjects_executing_eagerly(f):
   """Decorator for asserting that no new Python objects persist after a test.
 
@@ -915,11 +939,14 @@
 def run_all_in_graph_and_eager_modes(cls):
   """Execute all test methods in the given class with and without eager."""
   base_decorator = run_in_graph_and_eager_modes
-  for name, value in cls.__dict__.copy().items():
-    if callable(value) and name.startswith(
-        unittest.TestLoader.testMethodPrefix) and not (
-            name.startswith("testSkipEager") or
-            name.startswith("test_skip_eager") or name == "test_session"):
+  for name in dir(cls):
+    if (not name.startswith(unittest.TestLoader.testMethodPrefix) or
+        name.startswith("testSkipEager") or
+        name.startswith("test_skip_eager") or
+        name == "test_session"):
+      continue
+    value = getattr(cls, name, None)
+    if callable(value):
       setattr(cls, name, base_decorator(value))
   return cls
 
@@ -2301,8 +2328,8 @@
       a = a._asdict()
     if hasattr(b, "_asdict"):
       b = b._asdict()
-    a_is_dict = isinstance(a, collections.Mapping)
-    if a_is_dict != isinstance(b, collections.Mapping):
+    a_is_dict = isinstance(a, collections_abc.Mapping)
+    if a_is_dict != isinstance(b, collections_abc.Mapping):
       raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                        (path_str, path_str, msg))
     if a_is_dict:
@@ -2496,6 +2523,21 @@
       np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
 
   @py_func_if_in_function
+  def assertNotAllEqual(self, a, b, msg=None):
+    """Asserts that two numpy arrays or Tensors do not have the same values.
+
+    Args:
+      a: the expected numpy ndarray or anything can be converted to one.
+      b: the actual numpy ndarray or anything can be converted to one.
+      msg: Optional message to report on failure.
+    """
+    try:
+      self.assertAllEqual(a, b, msg)
+    except AssertionError:
+      return
+    raise AssertionError("The two values are equal at all elements")
+
+  @py_func_if_in_function
   def assertAllGreater(self, a, comparison_target):
     """Assert element values are all greater than a target value.
 
diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py
index 214dce5..ffc93b0 100644
--- a/tensorflow/python/framework/type_spec.py
+++ b/tensorflow/python/framework/type_spec.py
@@ -260,7 +260,8 @@
 
   def __eq__(self, other):
     # pylint: disable=protected-access
-    return self.__get_cmp_key() == other.__get_cmp_key()
+    return (type(other) is type(self) and
+            self.__get_cmp_key() == other.__get_cmp_key())
 
   def __ne__(self, other):
     return not self == other
diff --git a/tensorflow/python/framework/type_spec_test.py b/tensorflow/python/framework/type_spec_test.py
index dcc54b5..46e1ea3 100644
--- a/tensorflow/python/framework/type_spec_test.py
+++ b/tensorflow/python/framework/type_spec_test.py
@@ -143,6 +143,8 @@
                       tensor_spec.TensorSpec([4], name="a")),
        TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
                       tensor_spec.TensorSpec([4], name="b"))),
+      ("Non-TypeSpec",
+       TwoTensorsSpec([5, 3], dtypes.int32, [8], dtypes.bool), 5),
       )
   def testInequality(self, v1, v2):
     # pylint: disable=g-generic-assert
diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py
index ea16008..dc020d7 100644
--- a/tensorflow/python/grappler/auto_mixed_precision_test.py
+++ b/tensorflow/python/grappler/auto_mixed_precision_test.py
@@ -436,7 +436,7 @@
       self.assertEqual(num_to_fp32, 1)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_simple_loop(self):
     """Test graph with while loop."""
@@ -455,7 +455,7 @@
       self._assert_output_fp16(node_map, 'while/Relu')
       self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_loop_with_vars_intertwined(self):
     """Test graph with intertwined while loops."""
@@ -528,7 +528,7 @@
       self._assert_output_fp16(node_map, 'Relu_1')
       self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_recurrent_lstm(self):
     """Test graph with recurrent lstm."""
@@ -554,42 +554,42 @@
       self._assert_output_fp16(node_map, 'while/Tanh_1')
       self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_1(self):
     self._run_simple_loop_test('W', 'C', 'C')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_2(self):
     self._run_simple_loop_test('C', 'C', 'W')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_3(self):
     self._run_simple_loop_test('W', 'G', 'W')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_4(self):
     self._run_simple_loop_test('W', 'gbg', 'W')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_5(self):
     self._run_simple_loop_test('b', 'gWC', 'c')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_6(self):
     self._run_simple_loop_test('b', 'CWCG', 'C')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_7(self):
     self._run_simple_loop_test('C', 'GWCG', 'C')
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
   def test_propagation_through_simple_loop_8(self):
     self._run_simple_loop_test('C', 'CgbgWC', 'g')
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index cca0963..5e480ab 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -3,7 +3,6 @@
 
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -145,11 +144,12 @@
         ":tf_utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:auto_control_deps",
-        "//tensorflow/python:control_flow_util",
+        "//tensorflow/python:control_flow_v2_func_graphs",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:init_ops_v2",
+        "//tensorflow/python:tf2",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
@@ -492,6 +492,7 @@
         ":generic_utils",
         ":tf_utils",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:cudnn_rnn_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:embedding_ops",
@@ -723,7 +724,7 @@
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 12,
+    shard_count = 20,
 )
 
 tf_py_test(
@@ -755,6 +756,7 @@
     srcs = ["layers/tensorflow_op_layer_test.py"],
     additional_deps = [
         ":keras",
+        ":saving",
         "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/eager:backprop",
@@ -913,7 +915,7 @@
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 2,
+    shard_count = 4,
     tags = ["no_windows"],
 )
 
@@ -1124,7 +1126,7 @@
 
 tf_py_test(
     name = "wrappers_test",
-    size = "medium",
+    size = "large",
     srcs = ["layers/wrappers_test.py"],
     additional_deps = [
         ":keras",
@@ -1390,7 +1392,10 @@
         "//tensorflow/python:client_testlib",
     ],
     shard_count = 4,
-    tags = ["notsan"],
+    tags = [
+        "no_oss",
+        "notsan",
+    ],
 )
 
 tf_py_test(
@@ -1509,6 +1514,23 @@
 )
 
 tf_py_test(
+    name = "training_integration_test",
+    size = "medium",
+    srcs = ["engine/training_integration_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+    ],
+    shard_count = 30,
+    tags = [
+        "no_rocm",
+        "nomac",  # TODO(mihaimaruseac): b/127695564
+    ],
+)
+
+tf_py_test(
     name = "feature_columns_integration_test",
     size = "medium",
     srcs = ["engine/feature_columns_integration_test.py"],
@@ -1558,11 +1580,39 @@
     ],
 )
 
+py_library(
+    name = "model_subclassing_test_util",
+    srcs = ["model_subclassing_test_util.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+    ],
+)
+
 tf_py_test(
     name = "model_subclassing_test",
     size = "medium",
     srcs = ["model_subclassing_test.py"],
     additional_deps = [
+        ":model_subclassing_test_util",
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+    ],
+    shard_count = 4,
+    tags = [
+        "no_windows",
+        "notsan",
+    ],
+)
+
+tf_py_test(
+    name = "model_subclassing_compiled_test",
+    size = "medium",
+    srcs = ["model_subclassing_compiled_test.py"],
+    additional_deps = [
+        ":model_subclassing_test_util",
         ":keras",
         "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py
index 64fa731..9655254 100644
--- a/tensorflow/python/keras/__init__.py
+++ b/tensorflow/python/keras/__init__.py
@@ -38,6 +38,7 @@
 from tensorflow.python.keras import models
 from tensorflow.python.keras import ops
 from tensorflow.python.keras import optimizers
+from tensorflow.python.keras import premade
 from tensorflow.python.keras import preprocessing
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import utils
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 186b4f2..e6c06ee 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -37,7 +37,6 @@
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import lift_to_graph
@@ -72,6 +71,7 @@
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -271,10 +271,13 @@
   Returns:
       Learning phase (scalar integer tensor or Python integer).
   """
-  if ops.get_default_graph() is _GRAPH:
+  graph = ops.get_default_graph()
+  if graph is _GRAPH:
     # Don't enter an init_scope for the learning phase if eager execution
     # is enabled but we're inside the Keras workspace graph.
-    return symbolic_learning_phase()
+    learning_phase = symbolic_learning_phase()
+    _mark_func_graph_as_unsaveable(graph, learning_phase)
+    return learning_phase
   with ops.init_scope():
     # We always check & set the learning phase inside the init_scope,
     # otherwise the wrong default_graph will be used to look up the learning
@@ -288,13 +291,34 @@
         # Fallback to inference mode as default.
         return 0
       return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
-    return symbolic_learning_phase()
+    learning_phase = symbolic_learning_phase()
+    _mark_func_graph_as_unsaveable(graph, learning_phase)
+    return learning_phase
 
 
 def global_learning_phase_is_set():
   return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
 
 
+def _mark_func_graph_as_unsaveable(graph, learning_phase):
+  """Mark func graph as unsaveable due to use of symbolic keras learning phase.
+
+  Functions that capture the symbolic learning phase cannot be exported to
+  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
+  if it is exported.
+
+  Args:
+    graph: Graph or FuncGraph object.
+    learning_phase: Learning phase placeholder or int defined in the graph.
+  """
+  if graph.building_function and is_placeholder(learning_phase):
+    graph.mark_as_unsaveable(
+        'The keras learning phase placeholder was used inside a function. '
+        'Exporting placeholders is not supported when saving out a SavedModel. '
+        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
+        'to set the learning phase to a constant value.')
+
+
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
@@ -802,7 +826,7 @@
     return
   graph = v.graph if hasattr(v, 'graph') else get_graph()
   if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = weakref.WeakSet()
+    _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
@@ -1090,7 +1114,7 @@
 
     global _FREEZABLE_VARS
     if graph not in _FREEZABLE_VARS:
-      _FREEZABLE_VARS[graph] = weakref.WeakSet()
+      _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
     _FREEZABLE_VARS[graph].add(x)
   return x
 
@@ -2544,6 +2568,17 @@
 
   Returns:
       A tensor.
+
+  Example:
+      ```python
+      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+      >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
+      >>> tf.keras.backend.concatenate((a, b), axis=-1)
+      <tf.Tensor: id=14, shape=(3, 6), dtype=int32, numpy=
+      array([[ 1,  2,  3, 10, 20, 30],
+             [ 4,  5,  6, 40, 50, 60],
+             [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
+      ```
   """
   if axis < 0:
     rank = ndim(tensors[0])
@@ -2568,6 +2603,21 @@
 
   Returns:
       A tensor.
+
+  Example:
+    ```python
+      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+      >>> a
+      <tf.Tensor: id=32, shape=(4, 3), dtype=int32, numpy=
+      array([[ 1,  2,  3],
+             [ 4,  5,  6],
+             [ 7,  8,  9],
+             [10, 11, 12]], dtype=int32)>
+      >>> tf.keras.backend.reshape(a, shape=(2, 6))
+      <tf.Tensor: id=35, shape=(2, 6), dtype=int32, numpy=
+      array([[ 1,  2,  3,  4,  5,  6],
+             [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
+    ```
   """
   return array_ops.reshape(x, shape)
 
@@ -2583,6 +2633,22 @@
 
   Returns:
       A tensor.
+
+  Example:
+    ```python
+      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+      >>> a
+      <tf.Tensor: id=49, shape=(4, 3), dtype=int32, numpy=
+      array([[ 1,  2,  3],
+             [ 4,  5,  6],
+             [ 7,  8,  9],
+             [10, 11, 12]], dtype=int32)>
+      >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
+      <tf.Tensor: id=52, shape=(3, 4), dtype=int32, numpy=
+      array([[ 1,  4,  7, 10],
+             [ 2,  5,  8, 11],
+             [ 3,  6,  9, 12]], dtype=int32)>
+    ```
   """
   return array_ops.transpose(x, perm=pattern)
 
@@ -2696,6 +2762,14 @@
 
   Returns:
       A tensor.
+
+  Example:
+      ```python
+        >>> b = tf.constant([1, 2, 3])
+        >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
+        <tf.Tensor: id=70, shape=(6,), dtype=int32,
+            numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
+      ```
   """
   x_shape = x.shape.as_list()
   # For static axis
@@ -2748,6 +2822,21 @@
 
   Returns:
       A tensor.
+
+  Example:
+      ```python
+        >>> b = tf.constant([[1, 2], [3, 4]])
+        >>> b
+        <tf.Tensor: id=78, shape=(2, 2), dtype=int32, numpy=
+        array([[1, 2],
+               [3, 4]], dtype=int32)>
+        >>> tf.keras.backend.repeat(b, n=2)
+        <tf.Tensor: id=82, shape=(2, 2, 2), dtype=int32, numpy=
+        array([[[1, 2],
+                [1, 2]],
+               [[3, 4],
+                [3, 4]]], dtype=int32)>
+      ```
   """
   assert ndim(x) == 2
   x = array_ops.expand_dims(x, 1)
@@ -2775,6 +2864,14 @@
   Returns:
       An integer tensor.
 
+  Example:
+      ```python
+        >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
+        <tf.Tensor: id=96, shape=(7,), dtype=float32,
+            numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
+
+      ```
+
   """
   # Match the behavior of numpy and Theano by returning an empty sequence.
   if stop is None and start < 0:
@@ -2811,6 +2908,18 @@
 
   Returns:
       A tensor, reshaped into 1-D
+
+  Example:
+      ```python
+        >>> b = tf.constant([[1, 2], [3, 4]])
+        >>> b
+        <tf.Tensor: id=102, shape=(2, 2), dtype=int32, numpy=
+        array([[1, 2],
+               [3, 4]], dtype=int32)>
+        >>> tf.keras.backend.flatten(b)
+        <tf.Tensor: id=105, shape=(4,), dtype=int32,
+            numpy=array([1, 2, 3, 4], dtype=int32)>
+      ```
   """
   return array_ops.reshape(x, [-1])
 
@@ -2972,6 +3081,18 @@
 
   Returns:
       A tensor.
+
+  Example:
+      ```python
+        >>> a = tf.constant([[1, 2],[3, 4]])
+        >>> b = tf.constant([[10, 20],[30, 40]])
+        >>> tf.keras.backend.stack((a, b))
+        <tf.Tensor: id=146, shape=(2, 2, 2), dtype=int32, numpy=
+        array([[[ 1,  2],
+                [ 3,  4]],
+               [[10, 20],
+                [30, 40]]], dtype=int32)>
+      ```
   """
   return array_ops.stack(x, axis=axis)
 
@@ -3412,7 +3533,8 @@
 
     self._freezable_vars_to_feed = []
     self._freezable_vars_values = []
-    freezable_vars_from_keras_graph = _FREEZABLE_VARS.get(global_graph, {})
+    freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
+        _FREEZABLE_VARS.get(global_graph, {}))
     with _scratch_graph() as exec_graph:
       global_graph = get_graph()
       if source_graph not in (exec_graph, global_graph):
@@ -3424,8 +3546,12 @@
             [p_new for [_, p_new] in legacy_update_ops
              if isinstance(p_new, ops.Tensor)])
         lifted_map = lift_to_graph.lift_to_graph(
-            init_tensors=init_tensors, graph=exec_graph, sources=inputs,
-            add_sources=True, handle_captures=True, base_graph=source_graph)
+            tensors=init_tensors,
+            graph=exec_graph,
+            sources=inputs,
+            add_sources=True,
+            handle_captures=True,
+            base_graph=source_graph)
 
         inputs = [lifted_map[i] for i in inputs]
         outputs = [lifted_map[i] for i in outputs]
@@ -3457,8 +3583,7 @@
       with ops.control_dependencies(updates_ops):
         self.outputs[0] = array_ops.identity(self.outputs[0])
 
-      exec_graph.inputs = self._input_references + list(
-          exec_graph.captures.values())
+      exec_graph.inputs = self._input_references + exec_graph.internal_captures
       exec_graph.outputs = self.outputs
       graph_fn = eager_function.ConcreteFunction(exec_graph)
 
@@ -3472,8 +3597,8 @@
     with exec_graph.as_default():
       for x in self.inputs:
         if x.op.type == 'PlaceholderWithDefault':
-          self._placeholder_default_values[x] = tensor_util.constant_value(
-              x.op.inputs[0])
+          self._placeholder_default_values[ops.tensor_id(
+              x)] = tensor_util.constant_value(x.op.inputs[0])
 
   def __call__(self, inputs):
     input_values = nest.flatten(inputs, expand_composites=True)
@@ -3484,7 +3609,8 @@
     for tensor, value in zip(self._input_references, input_values):
       if value is None:
         # Assume `value` is a placeholder with default
-        value = self._placeholder_default_values.get(tensor, None)
+        value = self._placeholder_default_values.get(
+            ops.tensor_id(tensor), None)
         if value is None:
           raise ValueError(
               'You must feed a value for placeholder %s' % (tensor,))
@@ -3795,6 +3921,7 @@
         tensor_array_ops.TensorArray(
             dtype=out.dtype,
             size=time_steps_t,
+            element_shape=out.shape,
             tensor_array_name='output_ta_%s' % i)
         for i, out in enumerate(nest.flatten(output_time_zero)))
 
@@ -4024,17 +4151,19 @@
   if training is None:
     training = learning_phase()
 
-  if training == 1 or training is True:
-    if callable(x):
-      return x()
-    else:
-      return x
+  # TODO(b/138862903): Handle the case when training is tensor.
+  if not tensor_util.is_tensor(training):
+    if training == 1 or training is True:
+      if callable(x):
+        return x()
+      else:
+        return x
 
-  elif training == 0 or training is False:
-    if callable(alt):
-      return alt()
-    else:
-      return alt
+    elif training == 0 or training is False:
+      if callable(alt):
+        return alt()
+      else:
+        return alt
 
   # else: assume learning phase is a placeholder tensor.
   x = switch(training, x, alt)
@@ -4197,7 +4326,7 @@
 
   Raises:
       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
-      
+
   Example:
   ```python:
       import tensorflow as tf
@@ -5676,7 +5805,7 @@
 
     set_session(session)
 
-  if multi_worker_util.in_multi_worker_mode():
+  if distribution_strategy._in_multi_worker_mode():
     dc.run_distribute_coordinator(
         _create_session,
         distribution_strategy,
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index c735f71..492d583 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -43,10 +43,13 @@
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.python.util.compat import collections_abc
 
 try:
   import requests
@@ -898,8 +901,8 @@
       self.save_weights_only = True
 
   def on_train_begin(self, logs=None):
-    if multi_worker_util.in_multi_worker_mode():
-      # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    if self.model._in_multi_worker_mode():
       # MultiWorkerTrainingState is used to manage the training state needed
       # for preemption-recovery of a worker in multi-worker training.
       self.model._training_state = (
@@ -914,8 +917,8 @@
     # If this is not multi worker training, restoring is not needed, or
     # restoring failed, check if it should load weights on restart.
     if self.load_weights_on_restart:
-      if (not multi_worker_util.in_multi_worker_mode()
-          or multi_worker_util.should_load_checkpoint()):
+      if (not self.model._in_multi_worker_mode() or
+          multi_worker_util.should_load_checkpoint()):
         filepath_to_load = (
             self._get_most_recently_modified_file_matching_pattern(
                 self.filepath))
@@ -931,7 +934,8 @@
                 filepath_to_load, e))
 
   def on_train_end(self, logs=None):
-    if multi_worker_util.in_multi_worker_mode():
+    # pylint: disable=protected-access
+    if self.model._in_multi_worker_mode():
       # In multi-worker training, on successful exit of training, delete the
       # training state backup file that was saved for the purpose of worker
       # recovery.
@@ -954,14 +958,15 @@
 
   def on_epoch_end(self, epoch, logs=None):
     self.epochs_since_last_save += 1
+    # pylint: disable=protected-access
     if self.save_freq == 'epoch':
-      if multi_worker_util.in_multi_worker_mode():
+      if self.model._in_multi_worker_mode():
         # Exclude training state variables in user-requested checkpoint file.
         with self._training_state.untrack_vars():
           self._save_model(epoch=epoch, logs=logs)
       else:
         self._save_model(epoch=epoch, logs=logs)
-    if multi_worker_util.in_multi_worker_mode():
+    if self.model._in_multi_worker_mode():
       # For multi-worker training, back up the weights and current training
       # state for possible future recovery.
       # TODO(rchao): Call `back_up` at finer period such as N steps.
@@ -1013,7 +1018,8 @@
 
   def _get_file_path(self, epoch, logs):
     """Returns the file path for checkpoint."""
-    if not multi_worker_util.in_multi_worker_mode(
+    # pylint: disable=protected-access
+    if not self.model._in_multi_worker_mode(
     ) or multi_worker_util.should_save_checkpoint():
       return self.filepath.format(epoch=epoch + 1, **logs)
     else:
@@ -1031,8 +1037,9 @@
     # Remove the checkpoint directory in multi-worker training where this worker
     # should not checkpoint. It is a dummy directory previously saved for sync
     # distributed training.
-    if multi_worker_util.in_multi_worker_mode(
-    ) and not multi_worker_util.should_save_checkpoint():
+
+    if (self.model._in_multi_worker_mode() and  # pylint: disable=protected-access
+        not multi_worker_util.should_save_checkpoint()):
       file_io.delete_recursively(self._temp_file_dir)
       del self._temp_file_dir
 
@@ -1394,7 +1401,7 @@
         writes the losses and metrics to TensorBoard after each batch. The same
         applies for `'epoch'`. If using an integer, let's say `1000`, the
         callback will write the metrics and losses to TensorBoard every 1000
-        samples. Note that writing too frequently to TensorBoard can slow down
+        batches. Note that writing too frequently to TensorBoard can slow down
         your training.
       profile_batch: Profile the batch to sample compute characteristics. By
         default, it will profile the second batch. Set profile_batch=0 to
@@ -1441,16 +1448,14 @@
     self._samples_seen = 0
     self._samples_seen_at_last_write = 0
     self._current_batch = 0
-    self._total_batches_seen = 0
-    self._total_val_batches_seen = 0
 
     # A collection of file writers currently in use, to be closed when
     # training ends for this callback. Writers are keyed by the
     # directory name under the root logdir: e.g., "train" or
     # "validation".
-    self._writers = {}
     self._train_run_name = 'train'
     self._validation_run_name = 'validation'
+    self._writers = {}
 
     self._profile_batch = profile_batch
     # True when a trace is running.
@@ -1507,6 +1512,10 @@
     if self.embeddings_freq:
       self._configure_embeddings()
 
+    self._prev_summary_writer = context.context().summary_writer
+    self._prev_summary_recording = context.context().summary_recording
+    self._prev_summary_step = context.context().summary_step
+
   def _configure_embeddings(self):
     """Configure the Projector for embeddings."""
     # TODO(omalleyt): Add integration tests.
@@ -1575,12 +1584,55 @@
       self._writers[writer_name] = writer
     return self._writers[writer_name]
 
+  def _set_default_writer(self, writer_name):
+    """Sets the default writer for custom batch-level summaries."""
+    if self.update_freq == 'epoch':
+      # Writer is only used for custom summaries, which are written
+      # batch-by-batch.
+      return
+    writer = self._get_writer(writer_name)
+    step = self._total_batches_seen[writer_name]
+    context.context().summary_writer = writer
+
+    def _should_record():
+      return math_ops.equal(step % self.update_freq, 0)
+
+    context.context().summary_recording = _should_record
+    summary_ops_v2.set_step(step)
+
+  def _init_batch_steps(self):
+    """Create the total batch counters."""
+    if ops.executing_eagerly_outside_functions():
+      # Variables are needed for the `step` value of custom tf.summaries
+      # to be updated inside a tf.function.
+      self._total_batches_seen = {
+          self._train_run_name: variables.Variable(0, dtype='int64'),
+          self._validation_run_name: variables.Variable(0, dtype='int64')
+      }
+    else:
+      # Custom tf.summaries are not supported in legacy graph mode.
+      self._total_batches_seen = {
+          self._train_run_name: 0,
+          self._validation_run_name: 0
+      }
+
+  def _increment_step(self, writer_name):
+    step = self._total_batches_seen[writer_name]
+    if isinstance(step, variables.Variable):
+      step.assign_add(1)
+    else:
+      self._total_batches_seen[writer_name] += 1
+
   def on_train_begin(self, logs=None):
+    self._init_batch_steps()
     if self._profile_batch == 1:
       summary_ops_v2.trace_on(graph=True, profiler=True)
       self._is_tracing = True
 
-  def on_batch_end(self, batch, logs=None):
+  def on_test_begin(self, logs=None):
+    self._set_default_writer(self._validation_run_name)
+
+  def on_train_batch_end(self, batch, logs=None):
     """Writes scalar summaries for metrics on every training batch.
 
     Performs profiling if current batch is in profiler_batches.
@@ -1589,24 +1641,35 @@
       batch: Integer, index of batch within the current epoch.
       logs: Dict. Metric results for this batch.
     """
+    if self.update_freq == 'epoch' and self._profile_batch is None:
+      return
+
     # Don't output batch_size and batch number as TensorBoard summaries
     logs = logs or {}
-    self._samples_seen += logs.get('size', 1)
-    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
-    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
-      self._log_metrics(logs, prefix='batch_', step=self._total_batches_seen)
-      self._samples_seen_at_last_write = self._samples_seen
-    self._total_batches_seen += 1
-    if self._is_tracing:
-      self._log_trace()
-    elif (not self._is_tracing and
-          self._total_batches_seen == self._profile_batch - 1):
-      self._enable_trace()
+    train_batches = self._total_batches_seen[self._train_run_name]
+    if self.update_freq != 'epoch' and batch % self.update_freq == 0:
+      self._log_metrics(logs, prefix='batch_', step=train_batches)
+
+    self._increment_step(self._train_run_name)
+
+    if context.executing_eagerly():
+      if self._is_tracing:
+        self._log_trace()
+      elif (not self._is_tracing and
+            math_ops.equal(train_batches, self._profile_batch - 1)):
+        self._enable_trace()
+
+  def on_test_batch_end(self, batch, logs=None):
+    if self.update_freq == 'epoch':
+      return
+    self._increment_step(self._validation_run_name)
+
+  def on_epoch_begin(self, epoch, logs=None):
+    self._set_default_writer(self._train_run_name)
 
   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
-    step = epoch if self.update_freq == 'epoch' else self._samples_seen
-    self._log_metrics(logs, prefix='epoch_', step=step)
+    self._log_metrics(logs, prefix='epoch_', step=epoch)
 
     if self.histogram_freq and epoch % self.histogram_freq == 0:
       self._log_weights(epoch)
@@ -1619,19 +1682,25 @@
       self._log_trace()
     self._close_writers()
 
+    context.context().summary_writer = self._prev_summary_writer
+    context.context().summary_recording = self._prev_summary_recording
+    context.context().summary_step = self._prev_summary_step
+
   def _enable_trace(self):
     if context.executing_eagerly():
       summary_ops_v2.trace_on(graph=True, profiler=True)
       self._is_tracing = True
 
   def _log_trace(self):
+    """Logs the trace graph to TensorBoard."""
     if context.executing_eagerly():
       with self._get_writer(self._train_run_name).as_default(), \
           summary_ops_v2.always_record_summaries():
         # TODO(b/126388999): Remove step info in the summary name.
+        step = K.get_value(self._total_batches_seen[self._train_run_name])
         summary_ops_v2.trace_export(
-            name='batch_%d' % self._total_batches_seen,
-            step=self._total_batches_seen,
+            name='batch_%d' % step,
+            step=step,
             profiler_outdir=os.path.join(self.log_dir, 'train'))
       self._is_tracing = False
 
@@ -1902,7 +1971,7 @@
       is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
       if isinstance(k, six.string_types):
         return k
-      elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray:
+      elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
         return '"[%s]"' % (', '.join(map(str, k)))
       else:
         return k
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index f072384..46fdb9b 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -20,6 +20,7 @@
 
 import collections
 import csv
+import json
 import os
 import re
 import shutil
@@ -31,6 +32,7 @@
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.core.framework import summary_pb2
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import random_seed
@@ -40,6 +42,8 @@
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary_iterator
@@ -135,7 +139,7 @@
         adam.AdamOptimizer(0.001),
         'binary_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   @parameterized.named_parameters(('with_numpy', _get_numpy()),
@@ -238,7 +242,7 @@
         optimizer='rmsprop',
         metrics=[keras.metrics.CategoricalAccuracy(name='my_acc')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   @keras_parameterized.run_with_all_model_types
@@ -523,7 +527,7 @@
         mode=mode,
         save_freq=3)
 
-  def _run_load_weights_on_restart_test_common_iterations(self):
+  def _get_dummy_resource_for_model_checkpoint_testing(self):
 
     def get_input_datasets():
       # Simple training input.
@@ -549,12 +553,19 @@
 
     temp_dir = self.get_temp_dir()
     filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5')
-    initial_epochs = 3
 
     # The filepath shouldn't exist at the beginning.
     self.assertFalse(os.path.exists(filepath))
     callback = keras.callbacks.ModelCheckpoint(
         filepath=filepath, save_weights_only=True)
+
+    return model, train_ds, callback, filepath
+
+  def _run_load_weights_on_restart_test_common_iterations(self):
+
+    (model, train_ds, callback,
+     filepath) = self._get_dummy_resource_for_model_checkpoint_testing()
+    initial_epochs = 3
     model.fit(train_ds, epochs=initial_epochs, callbacks=[callback])
 
     # The files should exist after fitting with callback.
@@ -675,6 +686,23 @@
     self.assertNotAllClose(weights_before_additional_fit,
                            weights_after_additional_fit)
 
+  def test_fit_with_ModelCheckpoint_with_tf_config(self):
+    (model, train_ds, callback,
+     _) = self._get_dummy_resource_for_model_checkpoint_testing()
+
+    os.environ['TF_CONFIG'] = json.dumps({
+        'cluster': {
+            'worker': ['localhost:23333']
+        },
+        'task': {
+            'type': 'worker',
+            'index': 0
+        }
+    })
+
+    # `model.fit()` should work regardless of the presence of `TF_CONFIG`.
+    model.fit(train_ds, epochs=1, callbacks=[callback])
+
   def test_EarlyStopping(self):
     with self.cached_session():
       np.random.seed(123)
@@ -869,7 +897,7 @@
             num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
         model.compile(
             loss='categorical_crossentropy',
-            optimizer=keras.optimizers.SGD(lr=0.1))
+            optimizer=gradient_descent.SGD(lr=0.1))
         return model
 
       # TODO(psv): Make sure the callback works correctly when min_delta is
@@ -975,7 +1003,7 @@
             num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
         model.compile(
             loss='categorical_crossentropy',
-            optimizer=keras.optimizers.SGD(lr=0.1),
+            optimizer=gradient_descent.SGD(lr=0.1),
             metrics=['accuracy'])
         return model
 
@@ -1266,6 +1294,12 @@
             raise ValueError(
                 'Unexpected summary kind %r in event file %s:\n%r'
                 % (kind, path, event))
+          elif kind == 'tensor' and tag != 'keras':
+            # Check for V2 scalar summaries, which have a different PB
+            # structure.
+            if event.summary.value[
+                0].metadata.plugin_data.plugin_name == 'scalars':
+              container = result.scalars
           container.add(_ObservedSummary(logdir=dirpath, tag=tag))
   return result
 
@@ -1292,7 +1326,7 @@
         opt,
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def test_TensorBoard_default_logdir(self):
@@ -1477,6 +1511,57 @@
         },
     )
 
+  def test_custom_summary(self):
+    if not testing_utils.should_run_tf_function():
+      self.skipTest('Custom summaries only supported in V2 code path.')
+
+    def scalar_v2_mock(name, data, step=None):
+      """A reimplementation of the scalar plugin to avoid circular deps."""
+      metadata = summary_pb2.SummaryMetadata()
+      # Should match value in tensorboard/plugins/scalar/metadata.py.
+      metadata.plugin_data.plugin_name = 'scalars'
+      with summary_ops_v2.summary_scope(
+          name, 'scalar_summary', values=[data, step]) as (tag, _):
+        return summary_ops_v2.write(
+            tag=tag,
+            tensor=math_ops.cast(data, 'float32'),
+            step=step,
+            metadata=metadata)
+
+    class LayerWithSummary(keras.layers.Layer):
+
+      def call(self, x):
+        scalar_v2_mock('custom_summary', math_ops.reduce_sum(x))
+        return x
+
+    model = testing_utils.get_model_from_layers([LayerWithSummary()],
+                                                input_shape=(5,),
+                                                name='model')
+
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1)
+    x, y = np.ones((10, 5)), np.ones((10, 5))
+    model.fit(x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk])
+    summary_file = list_summaries(self.logdir)
+    self.assertEqual(
+        summary_file.scalars,
+        {
+            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
+            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
+            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
+            _ObservedSummary(
+                logdir=self.train_dir,
+                tag='model/layer_with_summary/custom_summary'),
+            _ObservedSummary(
+                logdir=self.validation_dir,
+                tag='model/layer_with_summary/custom_summary')
+        },
+    )
+
   def _strip_layer_names(self, summaries, model_type):
     """Deduplicate summary names modulo layer prefix.
 
@@ -1526,7 +1611,7 @@
         opt,
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def fitModelAndAssertKerasModelWritten(self, model):
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 2607fa7..ef84471 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -2,7 +2,7 @@
 #   keras/distribute package is intended to serve as the centralized place for things
 #   related to dist-strat used by Keras..
 
-load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
+load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -98,6 +98,21 @@
 )
 
 distribute_py_test(
+    name = "keras_premade_models_test",
+    srcs = ["keras_premade_models_test.py"],
+    full_precision = True,
+    main = "keras_premade_models_test.py",
+    shard_count = 4,
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        ":distribute_strategy_test_lib",
+        ":keras_correctness_test_lib",
+    ],
+)
+
+distribute_py_test(
     name = "distribute_strategy_test",
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
@@ -168,8 +183,6 @@
         "multi_and_single_gpu",
         "no_rocm",  # times out on ROCm
         "no_windows_gpu",
-        # TODO(b/134764123): Re-enable this test.
-        "notap",
         "notsan",
     ],
     deps = [
@@ -208,6 +221,7 @@
         "no_windows_gpu",
         "notsan",
     ],
+    xla_enable_strict_auto_jit = False,  # Tensorflow also fails.
     deps = [
         ":keras_correctness_test_lib",
     ],
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 8278b6b..6a5c7d3 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
@@ -39,6 +40,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops.losses import loss_reduction
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import rmsprop
 
@@ -247,10 +249,10 @@
   return (combinations.combine(
       distribution=strategies_minus_tpu,
       mode=['graph', 'eager'],
-      run_distributed=[True, False]) + combinations.combine(
+      experimental_run_tf_function=[True, False]) + combinations.combine(
           distribution=tpu_strategies,
           mode=['graph', 'eager'],
-          run_distributed=[False]))
+          experimental_run_tf_function=[False]))
 
 
 def all_strategy_minus_default_and_tpu_combinations():
@@ -274,16 +276,20 @@
       strategy_minus_tpu_combinations(),
       combinations.combine(
           optimizer=[
+              strategy_combinations.adagrad_optimizer_v1_fn,
+              strategy_combinations.adam_optimizer_v1_fn,
+              strategy_combinations.gradient_descent_optimizer_v1_fn,
+              strategy_combinations.rmsprop_optimizer_v1_fn,
               strategy_combinations.adagrad_optimizer_keras_v2_fn,
               strategy_combinations.adam_optimizer_keras_v2_fn,
               strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
               strategy_combinations.rmsprop_optimizer_keras_v2_fn
           ],
-          run_distributed=[True, False]))
+          experimental_run_tf_function=[True, False]))
   tpu_strategies_graph = combinations.combine(
       distribution=tpu_strategies,
       mode=['graph'],
-      run_distributed=[True],
+      experimental_run_tf_function=[True],
       optimizer=[
           strategy_combinations.adagrad_optimizer_v1_fn,
           strategy_combinations.adam_optimizer_v1_fn,
@@ -297,7 +303,7 @@
   tpu_strategies_eager = combinations.combine(
       distribution=tpu_strategies,
       mode=['eager'],
-      run_distributed=[False],
+      experimental_run_tf_function=[False],
       optimizer=[
           strategy_combinations.adagrad_optimizer_keras_v2_fn,
           strategy_combinations.adam_optimizer_keras_v2_fn,
@@ -424,7 +430,8 @@
             distribution, input_64_samples, steps=10, batch_size=13)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_calling_model_with_numpy_arrays(self, distribution, run_distributed):
+  def test_calling_model_with_numpy_arrays(self, distribution,
+                                           experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -433,7 +440,10 @@
         loss = 'mse'
         metrics = ['mae']
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
         inputs = np.zeros((64, 3), dtype=np.float32)
         targets = np.zeros((64, 4), dtype=np.float32)
@@ -457,14 +467,17 @@
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_calling_model_with_nested_numpy_arrays(self, distribution,
-                                                  run_distributed):
+                                                  experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
         optimizer = optimizer_fn(learning_rate=0.001)
         model = multi_input_output_model()
         loss = 'mse'
-        model.compile(optimizer, loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
       input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32)
@@ -489,13 +502,17 @@
       combinations.combine(
           distribution=strategies_minus_tpu,
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_numpy_with_sample_weights(self, distribution, run_distributed):
+          experimental_run_tf_function=[True, False]))
+  def test_numpy_with_sample_weights(self, distribution,
+                                     experimental_run_tf_function):
     with self.cached_session(), distribution.scope():
       model = get_sample_weights_model()
       optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
-      model.compile(optimizer, loss, run_distributed=run_distributed)
+      model.compile(
+          optimizer,
+          loss,
+          experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.array([[0], [1], [2], [3]], np.float32)
       targets = np.array([[2], [4], [6], [8]], np.float32)
@@ -526,14 +543,18 @@
       self.assertAllClose(result, 13.5)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_flatten_predict_outputs(self, distribution, run_distributed):
+  def test_flatten_predict_outputs(self, distribution,
+                                   experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         model = multi_input_output_model()
         optimizer_fn = gradient_descent_keras.SGD
         optimizer = optimizer_fn(learning_rate=0.001)
         loss = 'mse'
-        model.compile(optimizer, loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       # We take 6 input samples with each input having a dimension of 3 or 5.
       input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32)
@@ -594,9 +615,11 @@
           rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations_graph_only(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_predict_with_partial_batch(self, distribution, run_distributed):
+      combinations.times(
+          tpu_strategy_combinations_graph_only(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_predict_with_partial_batch(self, distribution,
+                                      experimental_run_tf_function):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
       loss = 'mse'
@@ -604,7 +627,9 @@
       with distribution.scope():
         model_with_ds_strategy = get_model()
         model_with_ds_strategy.compile(
-            optimizer, loss, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       cpu_model = get_model()
       cpu_model.compile(optimizer, loss)
@@ -655,10 +680,11 @@
         model.evaluate(inputs, steps=1)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations_graph_only(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          tpu_strategy_combinations_graph_only(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_predict_multi_output_model_with_partial_batch(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
       loss = 'mse'
@@ -666,7 +692,9 @@
       with distribution.scope():
         model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
         model_with_ds_strategy.compile(
-            optimizer, loss, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       cpu_model = simple_multi_inputs_multi_outputs_model()
       cpu_model.compile(optimizer, loss)
@@ -693,7 +721,8 @@
                                            parameterized.TestCase):
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_calling_model_on_same_dataset(self, distribution, run_distributed):
+  def test_calling_model_on_same_dataset(self, distribution,
+                                         experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -702,7 +731,10 @@
         loss = 'mse'
         metrics = ['mae', keras.metrics.CategoricalAccuracy()]
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = get_dataset(distribution)
 
@@ -724,8 +756,8 @@
       model.predict(get_predict_dataset(distribution), steps=2)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_model_interleaved_eval_same_as_direct_eval(self, distribution,
-                                                      run_distributed):
+  def test_model_interleaved_eval_same_as_direct_eval(
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -734,7 +766,7 @@
             optimizer_fn(0.001),
             loss='mse',
             metrics=['mae', keras.metrics.CategoricalAccuracy()],
-            run_distributed=run_distributed)
+            experimental_run_tf_function=experimental_run_tf_function)
 
         interleaved_model = get_model()
         interleaved_model.set_weights(user_controlled_model.get_weights())
@@ -742,7 +774,7 @@
             optimizer_fn(0.001),
             loss='mse',
             metrics=['mae', keras.metrics.CategoricalAccuracy()],
-            run_distributed=run_distributed)
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = get_dataset(distribution)
 
@@ -778,7 +810,7 @@
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution,
-                                                  run_distributed):
+                                                  experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -787,7 +819,10 @@
         loss = 'mse'
         metrics = ['mae', keras.metrics.CategoricalAccuracy()]
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       input_a_np = np.random.random((10, 3)).astype('float32')
       input_b_np = np.random.random((10, 5)).astype('float32')
@@ -814,7 +849,7 @@
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_fit_with_dictionary_in_the_dataset_b135161171(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
 
     def custom_loss(predict, label, weight):
       bce = keras.losses.binary_crossentropy(label, predict)
@@ -833,7 +868,9 @@
             inputs=[input_img, input_lbl, input_weight],
             outputs=[predict, my_loss])
         model.add_loss(model.get_layer('my_loss').output)
-        model.compile(optimizer='adam', run_distributed=run_distributed)
+        model.compile(
+            optimizer='adam',
+            experimental_run_tf_function=experimental_run_tf_function)
 
       def map_fn(img, lbl, weight):
         inputs = {'img': img, 'lbl': lbl, 'weight': weight}
@@ -851,7 +888,7 @@
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
   def test_fit_eval_and_predict_methods_on_dataset_without_steps(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -860,7 +897,10 @@
         loss = 'mse'
         metrics = ['mae', keras.metrics.CategoricalAccuracy()]
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((1000, 3), dtype=np.float32)
       targets = np.zeros((1000, 4), dtype=np.float32)
@@ -884,10 +924,11 @@
           predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)
 
   @combinations.generate(
-      combinations.times(strategy_minus_tpu_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          strategy_minus_tpu_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_on_dataset_with_unknown_cardinality_without_steps(
-      self, distribution, run_distributed, mode):
+      self, distribution, experimental_run_tf_function, mode):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -896,7 +937,10 @@
         loss = 'mse'
         metrics = ['mae', keras.metrics.CategoricalAccuracy()]
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((1000, 3), dtype=np.float32)
       targets = np.zeros((1000, 4), dtype=np.float32)
@@ -937,10 +981,11 @@
           rtol=1e-4)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          tpu_strategy_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_on_dataset_with_unknown_cardinality(self, distribution,
-                                               run_distributed):
+                                               experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         model = get_model()
@@ -950,7 +995,7 @@
             gradient_descent.GradientDescentOptimizer(0.001),
             loss,
             metrics=metrics,
-            run_distributed=run_distributed)
+            experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((1000, 3), dtype=np.float32)
       targets = np.zeros((1000, 4), dtype=np.float32)
@@ -982,8 +1027,8 @@
         model.fit(dataset, epochs=1)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_fit_eval_and_predict_methods_on_dataset(self, distribution,
-                                                   run_distributed):
+  def test_fit_eval_and_predict_methods_on_dataset(
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
@@ -992,7 +1037,10 @@
         loss = 'mse'
         metrics = ['mae', keras.metrics.CategoricalAccuracy()]
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = get_dataset(distribution)
 
@@ -1002,14 +1050,17 @@
 
   @combinations.generate(strategy_and_optimizer_combinations())
   def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
-                                               run_distributed):
+                                               experimental_run_tf_function):
     with self.cached_session():
 
       with distribution.scope():
 
         model = get_model()
         loss = 'mse'
-        model.compile(optimizer(), loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer(),
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = get_dataset(distribution)
 
@@ -1024,8 +1075,9 @@
               strategy_combinations.one_device_strategy
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_dataset_wrong_input_shape(self, distribution, run_distributed, mode):
+          experimental_run_tf_function=[True, False]))
+  def test_dataset_wrong_input_shape(self, distribution,
+                                     experimental_run_tf_function, mode):
     if mode == 'graph':
       self.skipTest(
           'TODO(b/120943676, b/120957836): Re-enable for graph once the '
@@ -1036,7 +1088,10 @@
         optimizer = optimizer_fn(learning_rate=0.001)
         model = get_model()
         loss = 'mse'
-        model.compile(optimizer, loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       # Wrong input shape
       inputs = np.zeros((10, 5), dtype=np.float32)
@@ -1054,16 +1109,19 @@
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_dataset_external_batch_input_validation(self, distribution,
-                                                   run_distributed):
+          experimental_run_tf_function=[True, False]))
+  def test_dataset_external_batch_input_validation(
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         optimizer_fn = gradient_descent_keras.SGD
         optimizer = optimizer_fn(learning_rate=0.001)
         model = get_model()
         loss = 'mse'
-        model.compile(optimizer, loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       # Batching is done outside tf.data's `batch`
       inputs = np.zeros((100, 10, 3), dtype=np.float32)
@@ -1080,8 +1138,9 @@
               strategy_combinations.mirrored_strategy_with_two_gpus
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_learning_phase_value(self, distribution, run_distributed):
+          experimental_run_tf_function=[True, False]))
+  def test_learning_phase_value(self, distribution,
+                                experimental_run_tf_function):
     # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
     # meaningful values. Currently we don't pass the learning phase if the
     # Lambda layer uses the learning phase.
@@ -1098,7 +1157,10 @@
         loss = 'mse'
         metrics = ['acc']
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       batch_size = 8
       if isinstance(distribution, mirrored_strategy.MirroredStrategy):
@@ -1128,13 +1190,17 @@
       self.assertArrayNear(output, ref_output, 1e-1)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def testOptimizerWithCallbacks(self, distribution, run_distributed):
+  def testOptimizerWithCallbacks(self, distribution,
+                                 experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         model = get_model()
         optimizer = gradient_descent_keras.SGD(0.01)
         loss = 'mse'
-        model.compile(optimizer, loss, run_distributed=run_distributed)
+        model.compile(
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = get_dataset(distribution)
 
@@ -1191,10 +1257,11 @@
           rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations_graph_only(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_predict_with_dataset_with_partial_batch(self, distribution,
-                                                   run_distributed):
+      combinations.times(
+          tpu_strategy_combinations_graph_only(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_predict_with_dataset_with_partial_batch(
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
       loss = 'mse'
@@ -1202,7 +1269,9 @@
       with distribution.scope():
         model_with_ds_strategy = get_model()
         model_with_ds_strategy.compile(
-            optimizer, loss, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       cpu_model = get_model()
       cpu_model.compile(optimizer, loss)
@@ -1222,10 +1291,11 @@
           rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations_graph_only(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          tpu_strategy_combinations_graph_only(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_predict_multi_output_model_with_dataset_with_partial_batch(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
       loss = 'mse'
@@ -1233,7 +1303,9 @@
       with distribution.scope():
         model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
         model_with_ds_strategy.compile(
-            optimizer, loss, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       cpu_model = simple_multi_inputs_multi_outputs_model()
       cpu_model.compile(optimizer, loss)
@@ -1314,13 +1386,17 @@
       combinations.combine(
           distribution=strategies_minus_tpu,
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_dataset_with_sample_weights(self, distribution, run_distributed):
+          experimental_run_tf_function=[True, False]))
+  def test_dataset_with_sample_weights(self, distribution,
+                                       experimental_run_tf_function):
     with self.cached_session(), distribution.scope():
       model = get_sample_weights_model()
       optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
-      model.compile(optimizer, loss, run_distributed=run_distributed)
+      model.compile(
+          optimizer,
+          loss,
+          experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.array([[0], [1], [2], [3]], np.float32)
       targets = np.array([[2], [4], [6], [8]], np.float32)
@@ -1373,8 +1449,8 @@
   @combinations.generate(
       combinations.times(
           strategy_combinations.all_strategy_combinations_minus_default(),
-          combinations.combine(run_distributed=[True, False])))
-  def test_regularizer_loss(self, distribution, run_distributed):
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_regularizer_loss(self, distribution, experimental_run_tf_function):
     batch_size = 2
     if not distributed_training_utils.global_batch_size_supported(distribution):
       batch_size //= distribution.num_replicas_in_sync
@@ -1396,7 +1472,7 @@
       model.compile(
           opt,
           loss=TestRegularizerLoss.loss_fn,
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       model.fit(
           x=np.array([[1.], [1.]], dtype=np.float32),
           y=np.array([[1.], [1.]], dtype=np.float32),
@@ -1409,14 +1485,17 @@
                                               parameterized.TestCase):
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_distribution_strategy_on_sequential_model(self, distribution,
-                                                     run_distributed):
+  def test_distribution_strategy_on_sequential_model(
+      self, distribution, experimental_run_tf_function):
     with distribution.scope():
       optimizer_fn = gradient_descent_keras.SGD
       optimizer = optimizer_fn(learning_rate=0.001)
       model = simple_sequential_model()
       loss = 'mse'
-      model.compile(optimizer, loss, run_distributed=run_distributed)
+      model.compile(
+          optimizer,
+          loss,
+          experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((20, 10), np.float32)
       targets = np.zeros((20, 2), np.float32)
@@ -1426,14 +1505,17 @@
     model.evaluate(inputs, targets, batch_size=10)
 
   @combinations.generate(all_strategy_combinations_plus_run_distributed())
-  def test_distribution_strategy_on_functional_model(self, distribution,
-                                                     run_distributed):
+  def test_distribution_strategy_on_functional_model(
+      self, distribution, experimental_run_tf_function):
     with distribution.scope():
       optimizer_fn = gradient_descent_keras.SGD
       optimizer = optimizer_fn(learning_rate=0.001)
       model = get_model()
       loss = 'mse'
-      model.compile(optimizer, loss, run_distributed=run_distributed)
+      model.compile(
+          optimizer,
+          loss,
+          experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((64, 3), dtype=np.float32)
       targets = np.zeros((64, 4), dtype=np.float32)
@@ -1443,10 +1525,11 @@
     model.evaluate(inputs, targets)
 
   @combinations.generate(
-      combinations.times(all_strategy_combinations_minus_default(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          all_strategy_combinations_minus_default(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_distribution_strategy_one_dimensional(self, distribution,
-                                                 run_distributed):
+                                                 experimental_run_tf_function):
     with distribution.scope():
       inp = keras.layers.Input(shape=(10,))
       out = keras.layers.Dense(3, activation='softmax')(inp)
@@ -1455,7 +1538,7 @@
           optimizer='rmsprop',
           loss='sparse_categorical_crossentropy',
           metrics=['sparse_categorical_accuracy'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
       x = np.random.random((64, 10)).astype('float32')
       y = np.random.randint(3, size=64)
@@ -1469,14 +1552,14 @@
               strategy_combinations.mirrored_strategy_with_two_gpus
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False],
+          experimental_run_tf_function=[True, False],
           reduction=[
               loss_reduction.ReductionV2.AUTO,
               loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE,
               loss_reduction.ReductionV2.SUM
           ]))
   def test_distribution_strategy_with_loss_reduction_types(
-      self, distribution, run_distributed, reduction):
+      self, distribution, experimental_run_tf_function, reduction):
     np.random.seed(_RANDOM_SEED)
 
     def _get_model():
@@ -1502,17 +1585,24 @@
       ds_model.compile(
           'sgd',
           loss=keras.losses.MeanSquaredError(reduction=reduction),
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       ds_history = ds_model.fit(
           dataset, steps_per_epoch=2, epochs=1, shuffle=False)
     self.assertArrayNear(history.history['loss'], ds_history.history['loss'],
                          1e-5)
 
   @combinations.generate(
-      combinations.times(all_strategy_combinations_minus_default(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_distribution_strategy_with_symbolic_add_loss(self, distribution,
-                                                        run_distributed):
+      combinations.times(
+          all_strategy_combinations_minus_default(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_distribution_strategy_with_symbolic_add_loss(
+      self, mode, distribution, experimental_run_tf_function):
+
+    # TODO(b/123533246): Enable the test for TPU once bug is fixed
+    if (isinstance(distribution,
+                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)) and
+        mode == 'graph' and not experimental_run_tf_function):
+      self.skipTest('TPU Strategy in graph mode fails with this test.')
 
     def _make_model_with_add_loss():
       inputs = keras.Input((10,))
@@ -1532,7 +1622,8 @@
 
     with distribution.scope():
       ds_model = _make_model_with_add_loss()
-      ds_model.compile('sgd', run_distributed=run_distributed)
+      ds_model.compile(
+          'sgd', experimental_run_tf_function=experimental_run_tf_function)
       ds_history = ds_model.fit(x, epochs=1)
 
     self.assertAllClose(history.history, ds_history.history)
@@ -1570,10 +1661,11 @@
     self.assertAllClose(history.history, ds_history.history)
 
   @combinations.generate(
-      combinations.times(all_strategy_minus_default_and_tpu_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          all_strategy_minus_default_and_tpu_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_distribution_strategy_with_add_metric_in_call(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
 
     class Bias(keras.layers.Layer):
 
@@ -1606,7 +1698,10 @@
     with distribution.scope():
       ds_model = _make_model_with_add_metric()
       self.assertLen(ds_model.metrics, 1)
-      ds_model.compile('sgd', 'mse', run_distributed=run_distributed)
+      ds_model.compile(
+          'sgd',
+          'mse',
+          experimental_run_tf_function=experimental_run_tf_function)
       ds_history = ds_model.fit(
           x, y, validation_data=(x, y), validation_steps=2, epochs=2)
       self.assertLen(ds_model.metrics, 1)
@@ -1622,9 +1717,9 @@
               strategy_combinations.mirrored_strategy_with_two_gpus,
           ],
           mode=['eager'],
-          run_distributed=[False]))
-  def test_distribution_strategy_with_add_metric_object(self, distribution,
-                                                        run_distributed):
+          experimental_run_tf_function=[False]))
+  def test_distribution_strategy_with_add_metric_object(
+      self, distribution, experimental_run_tf_function):
 
     class Bias(keras.layers.Layer):
 
@@ -1657,7 +1752,10 @@
     with distribution.scope():
       ds_model = _make_model_with_add_metric_object()
       self.assertLen(ds_model.metrics, 1)
-      ds_model.compile('sgd', 'mse', run_distributed=run_distributed)
+      ds_model.compile(
+          'sgd',
+          'mse',
+          experimental_run_tf_function=experimental_run_tf_function)
       ds_history = ds_model.fit(
           x, y, validation_data=(x, y), validation_steps=2, epochs=2)
       self.assertLen(ds_model.metrics, 1)
@@ -1666,10 +1764,11 @@
 
   @combinations.generate(
       # TODO(phillypham): Why does validation_steps > 1 not work on TPUs?
-      combinations.times(all_strategy_minus_default_and_tpu_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
+      combinations.times(
+          all_strategy_minus_default_and_tpu_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_distribution_strategy_with_add_metric_outside_call(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
 
     def _make_model_with_add_metric():
       inputs = keras.Input((10,))
@@ -1693,7 +1792,10 @@
     with distribution.scope():
       ds_model = _make_model_with_add_metric()
       self.assertLen(ds_model.metrics, 1)
-      ds_model.compile('sgd', 'mse', run_distributed=run_distributed)
+      ds_model.compile(
+          'sgd',
+          'mse',
+          experimental_run_tf_function=experimental_run_tf_function)
       ds_history = ds_model.fit(
           x, y, validation_data=(x, y), validation_steps=2, epochs=2)
       self.assertLen(ds_model.metrics, 1)
@@ -1702,6 +1804,68 @@
 
   @combinations.generate(
       combinations.combine(
+          distribution=strategies_minus_tpu,
+          mode=['eager'],
+          experimental_run_tf_function=[True]))
+  def test_sparse_tensor_outputs(self, distribution,
+                                 experimental_run_tf_function):
+
+    class ToSparse(keras.layers.Layer):
+      """Create a sparse tensor based on a given dense tensor."""
+
+      def call(self, inputs):
+        indices = array_ops.where_v2(math_ops.not_equal(inputs, 0))
+        values = array_ops.gather_nd(inputs, indices)
+        shape = array_ops.shape(inputs, out_type='int64')
+        return sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
+
+    model = keras.Sequential([ToSparse()])
+    model._experimental_run_tf_function = experimental_run_tf_function
+
+    # Define some input data with additional padding.
+    input_data = np.array([[1, 0, 0], [2, 3, 0]])
+    output = model.predict(input_data, batch_size=2)
+
+    expected_indices = np.array([[0, 0], [1, 0], [1, 1]])
+    expected_values = np.array([1, 2, 3])
+    expected_dense_shape = np.array([2, 3])
+
+    self.assertAllEqual(output.indices, expected_indices)
+    self.assertAllEqual(output.values, expected_values)
+    self.assertAllEqual(output.dense_shape, expected_dense_shape)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategies_minus_tpu,
+          mode=['eager'],
+          experimental_run_tf_function=[True]))
+  def test_ragged_tensor_outputs(self, distribution,
+                                 experimental_run_tf_function):
+
+    class ToRagged(keras.layers.Layer):
+      """Create a ragged tensor based on a given dense tensor."""
+
+      def __init__(self, padding, ragged_rank=1, **kwargs):
+        super(ToRagged, self).__init__(**kwargs)
+        self._padding = padding
+        self._ragged_rank = ragged_rank
+
+      def call(self, inputs):
+        return ragged_tensor.RaggedTensor.from_tensor(
+            inputs, padding=self._padding, ragged_rank=self._ragged_rank)
+
+    model = keras.Sequential([ToRagged(padding=0)])
+    model._experimental_run_tf_function = experimental_run_tf_function
+
+    # Define some input data with additional padding.
+    input_data = np.array([[1, 0, 0], [2, 3, 0]])
+    output = model.predict(input_data, batch_size=2)
+
+    expected_values = [[1], [2, 3]]
+    self.assertAllEqual(expected_values, output)
+
+  @combinations.generate(
+      combinations.combine(
           distribution=strategies_minus_default_minus_tpu + tpu_strategies,
           mode=['eager']))
   def test_correctness_of_add_loss_with_merge_call(self, distribution):
@@ -1825,6 +1989,36 @@
   return model
 
 
+def _functional_with_layer_reuse(input_shape, num_classes, l1, l2):
+  base_model = keras.Sequential([
+      keras.layers.Conv2D(
+          32, kernel_size=5, activation='relu', input_shape=input_shape),
+      keras.layers.MaxPooling2D(pool_size=2),
+      keras.layers.Conv2D(64, kernel_size=5, activation='relu'),
+      keras.layers.MaxPooling2D(pool_size=2),
+      keras.layers.Flatten(),
+      keras.layers.Dense(1024, activation='relu'),
+      keras.layers.Dense(num_classes, name='logits'),
+  ])
+  inputs = keras.Input(input_shape, name='images')
+  logits = base_model(inputs)
+  model = keras.Model(inputs=inputs, outputs=logits)
+  # Reuse sequential layer and create new nodes.
+  zero_logits = base_model(array_ops.zeros_like(inputs))
+  one_logits = base_model(array_ops.ones_like(inputs))
+  # L2 loss.
+  l2_loss = math_ops.reduce_mean(
+      math_ops.reduce_sum(math_ops.square(logits - zero_logits), -1))
+  model.add_loss(l2_loss * l2)
+  model.add_metric(l2_loss, aggregation='mean', name='l2_loss')
+  # L1 loss.
+  l1_loss = math_ops.reduce_mean(
+      math_ops.reduce_sum(math_ops.abs(logits - one_logits), -1))
+  model.add_loss(l1_loss * l1)
+  model.add_metric(l1_loss, aggregation='mean', name='l1_loss')
+  return model
+
+
 class TestDistributionStrategyWithMultipleAddLossAndMetricCalls(
     test.TestCase, parameterized.TestCase):
   """Tests complex models with multiple add loss and metric calls."""
@@ -1836,10 +2030,16 @@
               model_fn=[
                   _functional_with_add_loss_and_metric,
                   _sequential_with_add_loss_and_metric,
+                  _functional_with_layer_reuse,
               ],
               l1=[0.01],
               l2=[0.1])))
   def test_fit_and_evaluate(self, distribution, model_fn, l1, l2):
+    # TODO(b/138445028): Enable the test for TPU once bug is fixed.
+    if (isinstance(distribution,
+                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))):
+      self.skipTest('Flaky with TPUStrategy')
+
     # Make fake MNIST-like image data.
     dataset = dataset_ops.DatasetV2.from_tensor_slices(
         (np.random.uniform(size=(64, 28, 28, 1)).astype(np.float32),
diff --git a/tensorflow/python/keras/distribute/distributed_training_utils.py b/tensorflow/python/keras/distribute/distributed_training_utils.py
index 1f484ae..227fc01 100644
--- a/tensorflow/python/keras/distribute/distributed_training_utils.py
+++ b/tensorflow/python/keras/distribute/distributed_training_utils.py
@@ -32,6 +32,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks
@@ -42,7 +43,10 @@
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_concat_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
@@ -133,6 +137,38 @@
   return all_inputs, all_outputs, all_updates, all_session_args
 
 
+def unwrap_output_dict(strategy, grouped_outputs, mode):
+  """Unwrap the list of outputs contained in the PerReplica parameters."""
+  if mode == ModeKeys.PREDICT:
+    return flatten_per_replica_values(strategy, grouped_outputs)
+
+  # In the case of fit/eval, the grouped_outputs is a dict, whereas in predict,
+  # the output is as same structure as model output. They need to be treated
+  # differently
+  total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
+                               grouped_outputs['total_loss'][0], axis=None)
+  output_losses = flatten_per_replica_values(strategy,
+                                             grouped_outputs['output_losses'])
+  metrics = flatten_per_replica_values(strategy,
+                                       grouped_outputs['metrics'])
+  batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
+                               grouped_outputs['batch_size'], axis=None)
+  if (is_tpu_strategy(strategy) and
+      ops.executing_eagerly_outside_functions()):
+    # Choose 1 value per replica in the TPU case since all replicas produce the
+    # same output.
+    # We only do this in eager mode for now since this function is used in
+    # both graph and eager mode and in the graph case we currently don't use
+    # experimental_run so would need to be removed when we converge the graph
+    # code path as well.
+    output_losses = output_losses[::strategy.num_replicas_in_sync]
+    metrics = metrics[::strategy.num_replicas_in_sync]
+  return {'total_loss': [total_loss],
+          'output_losses': output_losses,
+          'metrics': metrics,
+          'batch_size': batch_size}
+
+
 def unwrap_outputs(distribution_strategy, grouped_outputs,
                    with_loss_tensor=False):
   """Unwrap the list of outputs contained in the PerReplica parameters.
@@ -304,7 +340,7 @@
 
   """
   # Convert the inputs and targets into a list of PerReplica objects.
-  per_replica_list = nest.flatten(x)
+  per_replica_list = nest.flatten(x, expand_composites=True)
   x_values_list = []
   for x in per_replica_list:
     if not tensor_util.is_tensor(x):
@@ -1009,14 +1045,15 @@
     model.set_weights(updated_weights)
 
 
-def _per_replica_aggregate_batch(batch_outs, model, mode):
+def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
   """Aggregates the per-replica batch-level outputs from a distributed step."""
-  if model._distribution_strategy is not None and mode == ModeKeys.PREDICT:
+  if strategy is not None and mode == ModeKeys.PREDICT:
     total_batch_outs = []
     for i in range(len(model.outputs)):
-      num_replicas = model._distribution_strategy.num_replicas_in_sync
+      num_replicas = strategy.num_replicas_in_sync
       nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
-      total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
+      total_batch_outs.append(
+          concat_along_batch_dimension(nest.flatten(nested_outs)))
     return total_batch_outs
   return batch_outs
 
@@ -1096,17 +1133,18 @@
   return dc_context.get_current_worker_context().is_chief
 
 
-def filter_distributed_callbacks(callbacks_list):
+def filter_distributed_callbacks(callbacks_list, model):
   """Filter Callbacks based on the worker context when running multi-worker.
 
   Arguments:
     callbacks_list: A list of `Callback` instances.
+    model: Keras model instance.
 
   Returns:
     The list of `Callback` instances that should be run on this worker.
   """
 
-  if not multi_worker_util.in_multi_worker_mode():
+  if not model._in_multi_worker_mode():
     raise ValueError(
         'filter_distributed_callbacks() should only be called when Keras '
         'is in multi worker mode.')
@@ -1148,3 +1186,12 @@
       if sample_weights and None not in sample_weights:
         for m, sw in zip(distributed_models, sample_weights):
           m._update_sample_weight_modes(sample_weights=[sw])
+
+
+def concat_along_batch_dimension(outputs):
+  """Concats prediction outputs along the batch dimension."""
+  if isinstance(outputs[0], sparse_tensor.SparseTensor):
+    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
+  if isinstance(outputs[0], ragged_tensor.RaggedTensor):
+    return ragged_concat_ops.concat(outputs, axis=0)
+  return np.concatenate(outputs)
diff --git a/tensorflow/python/keras/distribute/distributed_training_utils_test.py b/tensorflow/python/keras/distribute/distributed_training_utils_test.py
index 4adc8b5..39b4c36 100644
--- a/tensorflow/python/keras/distribute/distributed_training_utils_test.py
+++ b/tensorflow/python/keras/distribute/distributed_training_utils_test.py
@@ -22,14 +22,12 @@
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import adam as v1_adam
 
 
 class DistributedTrainingUtilsTest(test.TestCase):
 
-  @test.mock.patch.object(logging, 'warning', autospec=True)
-  def test_validate_callbacks_predefined_callbacks(self, mock_warning):
+  def test_validate_callbacks_predefined_callbacks(self):
     supported_predefined_callbacks = [
         callbacks.TensorBoard(),
         callbacks.CSVLogger(filename='./log.csv'),
@@ -55,8 +53,6 @@
         distributed_training_utils.validate_callbacks([callback],
                                                       v1_adam.AdamOptimizer())
 
-    self.assertEqual(0, mock_warning.call_count)
-
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index 1a446d1..915cd8c 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -64,7 +64,8 @@
 def all_strategy_and_input_config_combinations():
   return (combinations.times(
       combinations.combine(
-          distribution=all_strategies, run_distributed=[True, False]),
+          distribution=all_strategies,
+          experimental_run_tf_function=[True, False]),
       eager_mode_test_configuration() + graph_mode_test_configuration()))
 
 
@@ -97,10 +98,11 @@
   return (combinations.times(
       combinations.combine(
           distribution=strategies_for_embedding_models(),
-          run_distributed=[True, False]),
+          experimental_run_tf_function=[True, False]),
       (graph_mode_test_configuration())) + combinations.times(
           combinations.combine(
-              distribution=eager_mode_strategies, run_distributed=[False]),
+              distribution=eager_mode_strategies,
+              experimental_run_tf_function=[False]),
           (eager_mode_test_configuration())))
 
 
@@ -244,13 +246,13 @@
 def fit_eval_and_predict(initial_weights,
                          input_fn,
                          model_fn,
-                         run_distributed=None,
+                         experimental_run_tf_function=None,
                          distribution=None,
                          is_stateful_model=False):
   """Generates results for fit/predict/evaluate for given model."""
   training_inputs, eval_inputs, predict_inputs = input_fn()
   model = model_fn(
-      run_distributed=run_distributed,
+      experimental_run_tf_function=experimental_run_tf_function,
       initial_weights=initial_weights,
       distribution=distribution,
       input_shapes=get_shapes(training_inputs['x']))
@@ -418,28 +420,31 @@
 
   def get_model(self,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     raise NotImplementedError
 
-  def skip_unsupported_test_configuration(self, distribution, run_distributed):
-    if should_skip_tpu_with_eager(distribution) and run_distributed:
-      self.skipTest(
-          'TPUStrategy does not support eager mode with run_distributed.')
+  def skip_unsupported_test_configuration(self, distribution,
+                                          experimental_run_tf_function):
+    if should_skip_tpu_with_eager(
+        distribution) and experimental_run_tf_function:
+      self.skipTest('TPUStrategy does not support eager mode with '
+                    'experimental_run_tf_function.')
     return
 
   def run_correctness_test(self,
                            distribution,
                            use_numpy,
                            use_validation_data,
-                           run_distributed=None,
+                           experimental_run_tf_function=None,
                            with_batch_norm=False,
                            is_stateful_model=False,
                            partial_last_batch=None,
                            training_epochs=2):
     with self.cached_session():
       self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
-      self.skip_unsupported_test_configuration(distribution, run_distributed)
+      self.skip_unsupported_test_configuration(distribution,
+                                               experimental_run_tf_function)
 
       if partial_last_batch == 'eval':
         x_train, y_train, x_eval, y_eval, x_predict = (
@@ -456,7 +461,8 @@
       # This is used to initialize the model for both the distribution and
       # non-distribution run.
       model = self.get_model(
-          run_distributed=run_distributed, input_shapes=get_shapes(x_train))
+          experimental_run_tf_function=experimental_run_tf_function,
+          input_shapes=get_shapes(x_train))
       initial_weights = model.get_weights()
 
       ds_input_fn = functools.partial(
@@ -487,14 +493,14 @@
           initial_weights,
           input_fn=ds_input_fn,
           model_fn=self.get_model,
-          run_distributed=run_distributed,
+          experimental_run_tf_function=experimental_run_tf_function,
           distribution=distribution,
           is_stateful_model=is_stateful_model)
       results_without_ds = fit_eval_and_predict(
           initial_weights,
           input_fn=nods_input_fn,
           model_fn=self.get_model,
-          run_distributed=run_distributed,
+          experimental_run_tf_function=experimental_run_tf_function,
           distribution=None,
           is_stateful_model=is_stateful_model)
 
@@ -534,14 +540,18 @@
     training_input = kwargs
     return training_input, None, None
 
-  def run_dynamic_lr_test(self, distribution, run_distributed=None):
+  def run_dynamic_lr_test(self,
+                          distribution,
+                          experimental_run_tf_function=None):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution, run_distributed)
+      self.skip_unsupported_test_configuration(distribution,
+                                               experimental_run_tf_function)
 
       x_train, y_train, _ = self.get_data()
       model = self.get_model(
-          run_distributed=run_distributed, input_shapes=get_shapes(x_train))
+          experimental_run_tf_function=experimental_run_tf_function,
+          input_shapes=get_shapes(x_train))
       initial_weights = model.get_weights()
       update_freq = None
 
@@ -582,13 +592,13 @@
           initial_weights,
           input_fn=ds_input_fn,
           model_fn=self.get_model,
-          run_distributed=run_distributed,
+          experimental_run_tf_function=experimental_run_tf_function,
           distribution=distribution)
       results_without_ds = fit_eval_and_predict(
           initial_weights,
           input_fn=nods_input_fn,
           model_fn=self.get_model,
-          run_distributed=run_distributed,
+          experimental_run_tf_function=experimental_run_tf_function,
           distribution=None)
       compare_results(
           results_with_ds, results_without_ds, distribution, testcase=self)
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index 9d6724e..f68a927 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -34,14 +34,14 @@
   return (combinations.combine(
       distribution=keras_correctness_test_base.all_strategies,
       mode=['graph', 'eager'],
-      run_distributed=[True, False]))
+      experimental_run_tf_function=[True, False]))
 
 
 def all_strategy_combinations_with_graph_mode():
   return (combinations.combine(
       distribution=keras_correctness_test_base.all_strategies,
       mode=['graph'],
-      run_distributed=[True, False]))
+      experimental_run_tf_function=[True, False]))
 
 
 def is_default_strategy(strategy):
@@ -53,7 +53,7 @@
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
   def get_model(self,
-                run_distributed,
+                experimental_run_tf_function,
                 initial_weights=None,
                 distribution=None,
                 input_shapes=None):
@@ -76,7 +76,7 @@
           loss=keras.losses.mean_squared_error,
           optimizer=gradient_descent_keras.SGD(0.05),
           metrics=['mse'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       return model
 
   def get_data(self):
@@ -104,9 +104,9 @@
   @combinations.generate(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
-                           run_distributed):
+                           experimental_run_tf_function):
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_with_tpu_strategies())
@@ -131,14 +131,18 @@
         training_epochs=1)
 
   @combinations.generate(all_strategy_combinations_with_graph_mode())
-  def test_dnn_with_dynamic_learning_rate(self, distribution, run_distributed):
-    self.run_dynamic_lr_test(distribution, run_distributed)
+  def test_dnn_with_dynamic_learning_rate(self, distribution,
+                                          experimental_run_tf_function):
+    self.run_dynamic_lr_test(distribution, experimental_run_tf_function)
 
 
 class TestDistributionStrategyDnnMetricCorrectness(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
-  def get_model(self, run_distributed, distribution=None, input_shapes=None):
+  def get_model(self,
+                experimental_run_tf_function,
+                distribution=None,
+                input_shapes=None):
     with distribution.scope():
       model = keras.Sequential()
       model.add(
@@ -147,16 +151,19 @@
           loss=keras.losses.mean_squared_error,
           optimizer=gradient_descent_keras.SGD(0.05),
           metrics=[keras.metrics.BinaryAccuracy()],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
     return model
 
-  def run_metric_correctness_test(self, distribution, run_distributed):
+  def run_metric_correctness_test(self, distribution,
+                                  experimental_run_tf_function):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution, run_distributed)
+      self.skip_unsupported_test_configuration(distribution,
+                                               experimental_run_tf_function)
 
       x_train, y_train, _ = self.get_data()
-      model = self.get_model(run_distributed, distribution=distribution)
+      model = self.get_model(
+          experimental_run_tf_function, distribution=distribution)
 
       batch_size = 64
       batch_size = (
@@ -169,14 +176,18 @@
       self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0])
 
   @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes())
-  def test_simple_dnn_metric_correctness(self, distribution, run_distributed):
-    self.run_metric_correctness_test(distribution, run_distributed)
+  def test_simple_dnn_metric_correctness(self, distribution,
+                                         experimental_run_tf_function):
+    self.run_metric_correctness_test(distribution, experimental_run_tf_function)
 
 
 class TestDistributionStrategyDnnMetricEvalCorrectness(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
-  def get_model(self, run_distributed, distribution=None, input_shapes=None):
+  def get_model(self,
+                experimental_run_tf_function,
+                distribution=None,
+                input_shapes=None):
     with distribution.scope():
       model = keras.Sequential()
       model.add(
@@ -189,15 +200,18 @@
           loss='mae',
           metrics=['accuracy', keras.metrics.BinaryAccuracy()],
           optimizer=gradient_descent.GradientDescentOptimizer(0.001),
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
     return model
 
-  def run_eval_metrics_correctness_test(self, distribution, run_distributed):
+  def run_eval_metrics_correctness_test(self, distribution,
+                                        experimental_run_tf_function):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution, run_distributed)
+      self.skip_unsupported_test_configuration(distribution,
+                                               experimental_run_tf_function)
 
-      model = self.get_model(run_distributed, distribution=distribution)
+      model = self.get_model(
+          experimental_run_tf_function, distribution=distribution)
 
       # verify correctness of stateful and stateless metrics.
       x = np.ones((100, 4)).astype('float32')
@@ -217,8 +231,9 @@
 
   @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes())
   def test_identity_model_metric_eval_correctness(self, distribution,
-                                                  run_distributed):
-    self.run_eval_metrics_correctness_test(distribution, run_distributed)
+                                                  experimental_run_tf_function):
+    self.run_eval_metrics_correctness_test(distribution,
+                                           experimental_run_tf_function)
 
 
 class SubclassedModel(keras.Model):
@@ -249,7 +264,7 @@
     TestDistributionStrategyDnnCorrectness):
 
   def get_model(self,
-                run_distributed,
+                experimental_run_tf_function,
                 initial_weights=None,
                 distribution=None,
                 input_shapes=None):
@@ -260,23 +275,23 @@
           loss=keras.losses.mean_squared_error,
           optimizer=gradient_descent_keras.SGD(0.05),
           metrics=['mse'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
       return model
 
   @combinations.generate(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
-                           run_distributed):
+                           experimental_run_tf_function):
     if (context.executing_eagerly()) or is_default_strategy(distribution):
       self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                                run_distributed)
+                                experimental_run_tf_function)
     elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
       with self.assertRaisesRegexp(
           ValueError,
           'Expected `model` argument to be a functional `Model` instance, '
           'but got a subclass model instead.'):
         self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                                  run_distributed)
+                                  experimental_run_tf_function)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -284,27 +299,28 @@
           '`Sequential` model that is created without `input_shape`/'
           '`input_dim` set in its first layer or a subclassed model.'):
         self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                                  run_distributed)
+                                  experimental_run_tf_function)
 
   @combinations.generate(all_strategy_combinations_with_graph_mode())
-  def test_dnn_with_dynamic_learning_rate(self, distribution, run_distributed):
-    if ((not run_distributed and context.executing_eagerly() and
+  def test_dnn_with_dynamic_learning_rate(self, distribution,
+                                          experimental_run_tf_function):
+    if ((not experimental_run_tf_function and context.executing_eagerly() and
          not K.is_tpu_strategy(distribution)) or
         is_default_strategy(distribution)):
-      self.run_dynamic_lr_test(distribution, run_distributed)
+      self.run_dynamic_lr_test(distribution, experimental_run_tf_function)
     elif K.is_tpu_strategy(distribution):
       with self.assertRaisesRegexp(
           ValueError,
           'Expected `model` argument to be a functional `Model` instance, '
           'but got a subclass model instead.'):
-        self.run_dynamic_lr_test(distribution, run_distributed)
+        self.run_dynamic_lr_test(distribution, experimental_run_tf_function)
     else:
       with self.assertRaisesRegexp(
           ValueError,
           'We currently do not support distribution strategy with a '
           '`Sequential` model that is created without `input_shape`/'
           '`input_dim` set in its first layer or a subclassed model.'):
-        self.run_dynamic_lr_test(distribution, run_distributed)
+        self.run_dynamic_lr_test(distribution, experimental_run_tf_function)
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_with_tpu_strategies())
diff --git a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py
index 87fd174..a14293e 100644
--- a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py
@@ -33,7 +33,7 @@
                 max_words=10,
                 initial_weights=None,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     del input_shapes
     with keras_correctness_test_base.MaybeDistributionScope(distribution):
@@ -51,32 +51,30 @@
         model.set_weights(initial_weights)
 
       model.compile(
-          # TODO(b/130808953): Switch back the V1 optimizer once global_step is
-          # mirrored.
           optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
           loss='sparse_categorical_crossentropy',
           metrics=['sparse_categorical_accuracy'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
     return model
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_for_embedding_model())
   def test_embedding_model_correctness(self, distribution, use_numpy,
-                                       use_validation_data, run_distributed):
+                                       use_validation_data,
+                                       experimental_run_tf_function):
 
     self.use_distributed_dense = False
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_for_embedding_model())
-  def test_embedding_time_distributed_model_correctness(self, distribution,
-                                                        use_numpy,
-                                                        use_validation_data,
-                                                        run_distributed):
+  def test_embedding_time_distributed_model_correctness(
+      self, distribution, use_numpy, use_validation_data,
+      experimental_run_tf_function):
     self.use_distributed_dense = True
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
 
 class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
@@ -87,7 +85,7 @@
                 max_words=10,
                 initial_weights=None,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     del input_shapes
     with keras_correctness_test_base.MaybeDistributionScope(distribution):
@@ -121,7 +119,7 @@
       model.compile(
           optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
           loss='mse',
-          run_distributed=run_distributed,
+          experimental_run_tf_function=experimental_run_tf_function,
           metrics=['mse'])
     return model
 
@@ -159,9 +157,9 @@
       keras_correctness_test_base.test_combinations_for_embedding_model())
   def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
                                                use_validation_data,
-                                               run_distributed):
+                                               experimental_run_tf_function):
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
index b3d5706..8f050f8 100644
--- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
@@ -31,7 +31,7 @@
   def get_model(self,
                 initial_weights=None,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     del input_shapes
     with keras_correctness_test_base.MaybeDistributionScope(distribution):
@@ -58,7 +58,7 @@
           optimizer=gradient_descent.SGD(learning_rate=0.1),
           loss='sparse_categorical_crossentropy',
           metrics=['sparse_categorical_accuracy'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
     return model
 
@@ -93,22 +93,22 @@
   @combinations.generate(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_cnn_correctness(self, distribution, use_numpy, use_validation_data,
-                           run_distributed):
+                           experimental_run_tf_function):
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
   @combinations.generate(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                            use_validation_data,
-                                           run_distributed):
+                                           experimental_run_tf_function):
     self.skipTest('Flakily times out, b/134670856')
     self.run_correctness_test(
         distribution,
         use_numpy,
         use_validation_data,
         with_batch_norm=True,
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_with_tpu_strategies() +
diff --git a/tensorflow/python/keras/distribute/keras_lstm_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_lstm_model_correctness_test.py
index 5b403c2..149fbd1 100644
--- a/tensorflow/python/keras/distribute/keras_lstm_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_lstm_model_correctness_test.py
@@ -37,7 +37,7 @@
                 max_words=10,
                 initial_weights=None,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     del input_shapes
 
@@ -67,15 +67,16 @@
           optimizer=optimizer_fn(learning_rate=0.1),
           loss='sparse_categorical_crossentropy',
           metrics=['sparse_categorical_accuracy'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
     return model
 
   @combinations.generate(
       keras_correctness_test_base.test_combinations_for_embedding_model())
   def test_lstm_model_correctness(self, distribution, use_numpy,
-                                  use_validation_data, run_distributed):
+                                  use_validation_data,
+                                  experimental_run_tf_function):
     self.run_correctness_test(distribution, use_numpy, use_validation_data,
-                              run_distributed)
+                              experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/distribute/keras_optimizer_v2_test.py b/tensorflow/python/keras/distribute/keras_optimizer_v2_test.py
index 36918ca..012c10f 100644
--- a/tensorflow/python/keras/distribute/keras_optimizer_v2_test.py
+++ b/tensorflow/python/keras/distribute/keras_optimizer_v2_test.py
@@ -108,9 +108,9 @@
               strategy_combinations.central_storage_strategy_with_two_gpus,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
+          experimental_run_tf_function=[True, False]))
   def testOptimizerWithKerasModelAndNumpyArrays(self, distribution,
-                                                run_distributed):
+                                                experimental_run_tf_function):
     self.skipTest('b/130309197')
     with self.cached_session():
       with distribution.scope():
@@ -119,7 +119,10 @@
         loss = 'mse'
         metrics = ['mae']
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       inputs = np.zeros((64, 3), dtype=np.float32)
       targets = np.zeros((64, 4), dtype=np.float32)
diff --git a/tensorflow/python/keras/distribute/keras_premade_models_test.py b/tensorflow/python/keras/distribute/keras_premade_models_test.py
new file mode 100644
index 0000000..fa77ca2
--- /dev/null
+++ b/tensorflow/python/keras/distribute/keras_premade_models_test.py
@@ -0,0 +1,96 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for keras premade models using tf.distribute.Strategy."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import test
+from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.keras.premade import linear
+from tensorflow.python.keras.premade import wide_deep
+
+
+def strategy_combinations_eager_data_fn():
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.default_strategy,
+          strategy_combinations.one_device_strategy,
+          strategy_combinations.one_device_strategy_gpu,
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.mirrored_strategy_with_two_gpus
+      ],
+      mode=['eager'],
+      data_fn=[get_numpy, get_dataset])
+
+
+def get_numpy():
+  inputs = np.random.uniform(low=-5, high=5, size=(64, 2)).astype(np.float32)
+  output = .3 * inputs[:, 0] + .2 * inputs[:, 1]
+  return inputs, output
+
+
+def get_dataset():
+  inputs, output = get_numpy()
+  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, output))
+  dataset = dataset.batch(10).repeat(10)
+  return dataset
+
+
+class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(strategy_combinations_eager_data_fn())
+  def test_linear_model(self, distribution, data_fn):
+    with distribution.scope():
+      model = linear.LinearModel()
+      opt = gradient_descent.SGD(learning_rate=0.1)
+      model.compile(opt, 'mse', experimental_run_tf_function=True)
+      if data_fn == get_numpy:
+        inputs, output = get_numpy()
+        hist = model.fit(inputs, output, epochs=5)
+      else:
+        hist = model.fit(get_dataset(), epochs=5)
+      self.assertLess(hist.history['loss'][4], 0.2)
+
+  @combinations.generate(strategy_combinations_eager_data_fn())
+  def test_wide_deep_model(self, distribution, data_fn):
+    with distribution.scope():
+      linear_model = linear.LinearModel(units=1)
+      dnn_model = sequential.Sequential([core.Dense(units=1)])
+      wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+      linear_opt = gradient_descent.SGD(learning_rate=0.1)
+      dnn_opt = adagrad.Adagrad(learning_rate=0.2)
+      wide_deep_model.compile(
+          optimizer=[linear_opt, dnn_opt],
+          loss='mse',
+          experimental_run_tf_function=True)
+      if data_fn == get_numpy:
+        inputs, output = get_numpy()
+        hist = wide_deep_model.fit(inputs, output, epochs=5)
+      else:
+        hist = wide_deep_model.fit(get_dataset(), epochs=5)
+      self.assertLess(hist.history['loss'][4], 0.2)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
index 4802c8d..db5118f 100644
--- a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
@@ -41,7 +41,7 @@
       mode='graph',
       use_numpy=False,
       use_validation_data=False,
-      run_distributed=[True, False]))
+      experimental_run_tf_function=[True, False]))
 
 
 class DistributionStrategyStatefulLstmModelCorrectnessTest(
@@ -52,7 +52,7 @@
                 max_words=10,
                 initial_weights=None,
                 distribution=None,
-                run_distributed=None,
+                experimental_run_tf_function=None,
                 input_shapes=None):
     del input_shapes
     batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE
@@ -86,20 +86,22 @@
   # doesn't work and enable for DistributionStrategy more generally.
   @combinations.generate(test_combinations_for_stateful_embedding_model())
   def disabled_test_stateful_lstm_model_correctness(
-      self, distribution, use_numpy, use_validation_data, run_distributed):
+      self, distribution, use_numpy, use_validation_data,
+      experimental_run_tf_function):
     self.run_correctness_test(
         distribution,
         use_numpy,
         use_validation_data,
         is_stateful_model=True,
-        run_distributed=run_distributed)
+        experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(
           keras_correctness_test_base.test_combinations_with_tpu_strategies(),
-          combinations.combine(run_distributed=[True, False])))
+          combinations.combine(experimental_run_tf_function=[True, False])))
   def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
-      self, distribution, use_numpy, use_validation_data, run_distributed):
+      self, distribution, use_numpy, use_validation_data,
+      experimental_run_tf_function):
     with self.assertRaisesRegexp(
         ValueError,
         'Single core must be used for computation on stateful models. Consider '
@@ -109,7 +111,7 @@
           use_numpy,
           use_validation_data,
           is_stateful_model=True,
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py
index 8a790cf..c476912 100644
--- a/tensorflow/python/keras/distribute/keras_utils_test.py
+++ b/tensorflow/python/keras/distribute/keras_utils_test.py
@@ -72,16 +72,17 @@
                                             parameterized.TestCase):
 
   @combinations.generate(
-      combinations.times(keras_test_lib.all_strategy_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_callbacks_in_fit(self, distribution, run_distributed):
+      combinations.times(
+          keras_test_lib.all_strategy_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_callbacks_in_fit(self, distribution, experimental_run_tf_function):
     with distribution.scope():
       model = keras_test_lib.get_model()
       model.compile(
           optimizer='sgd',
           loss='mse',
           metrics=['mae'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
     dataset = keras_test_lib.get_dataset(distribution)
     counter = Counter()
@@ -127,16 +128,17 @@
         })
 
   @combinations.generate(
-      combinations.times(keras_test_lib.all_strategy_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_callbacks_in_eval(self, distribution, run_distributed):
+      combinations.times(
+          keras_test_lib.all_strategy_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_callbacks_in_eval(self, distribution, experimental_run_tf_function):
     with distribution.scope():
       model = keras_test_lib.get_model()
       model.compile(
           optimizer='sgd',
           loss='mse',
           metrics=['mae'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
     dataset = keras_test_lib.get_dataset(distribution)
     counter = Counter()
@@ -152,16 +154,18 @@
         })
 
   @combinations.generate(
-      combinations.times(keras_test_lib.all_strategy_combinations(),
-                         combinations.combine(run_distributed=[True, False])))
-  def test_callbacks_in_predict(self, distribution, run_distributed):
+      combinations.times(
+          keras_test_lib.all_strategy_combinations(),
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_callbacks_in_predict(self, distribution,
+                                experimental_run_tf_function):
     with distribution.scope():
       model = keras_test_lib.get_model()
       model.compile(
           optimizer='sgd',
           loss='mse',
           metrics=['mae'],
-          run_distributed=run_distributed)
+          experimental_run_tf_function=experimental_run_tf_function)
 
     dataset = keras_test_lib.get_dataset(distribution)
     counter = Counter()
@@ -238,8 +242,9 @@
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_unsupported_features(self, distribution, run_distributed, mode):
+          experimental_run_tf_function=[True, False]))
+  def test_unsupported_features(self, distribution,
+                                experimental_run_tf_function, mode):
     with self.cached_session():
       with distribution.scope():
         model = keras_test_lib.get_model()
@@ -247,18 +252,15 @@
         loss = 'mse'
         metrics = ['mae']
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = keras_test_lib.get_dataset(distribution)
-
-      if run_distributed and mode == 'eager':
-        exception_error_message = (
-            '`validation_split` argument is not supported when data adapter'
-            ' is.+')
-      else:
-        exception_error_message = (
-            '`validation_split` argument is not supported when input `x`'
-            ' is a dataset or a dataset iterator.+')
+      exception_error_message = (
+          '`validation_split` argument is not supported when input `x`'
+          ' is a dataset or a dataset iterator.+')
 
       # Test with validation split
       with self.assertRaisesRegexp(ValueError, exception_error_message):
@@ -308,9 +310,9 @@
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
+          experimental_run_tf_function=[True, False]))
   def test_calling_with_unsupported_predefined_callbacks(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         model = keras_test_lib.get_model()
@@ -318,7 +320,10 @@
         loss = 'mse'
         metrics = ['mae']
         model.compile(
-            optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+            optimizer,
+            loss,
+            metrics=metrics,
+            experimental_run_tf_function=experimental_run_tf_function)
 
       dataset = keras_test_lib.get_dataset(distribution)
 
@@ -349,22 +354,27 @@
       combinations.combine(
           distribution=[strategy_combinations.one_device_strategy],
           mode=['eager'],
-          run_distributed=[True, False]))
+          experimental_run_tf_function=[True, False]))
   def test_distribution_strategy_with_run_eagerly(self, distribution,
-                                                  run_distributed):
+                                                  experimental_run_tf_function):
     with distribution.scope():
       x = keras.layers.Input(shape=(1,))
       y = keras.layers.Dense(1, kernel_initializer='ones')(x)
       model = keras.models.Model(x, y)
 
-      if run_distributed:
-        model.compile('sgd', run_eagerly=True, run_distributed=run_distributed)
+      if experimental_run_tf_function:
+        model.compile(
+            'sgd',
+            run_eagerly=True,
+            experimental_run_tf_function=experimental_run_tf_function)
       else:
         err_msg = ('We currently do not support enabling `run_eagerly` with '
                    'distribution strategy.')
         with self.assertRaisesRegex(ValueError, err_msg):
           model.compile(
-              'sgd', run_eagerly=True, run_distributed=run_distributed)
+              'sgd',
+              run_eagerly=True,
+              experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       combinations.combine(
@@ -373,9 +383,9 @@
               strategy_combinations.one_device_strategy,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
-  def test_distribution_strategy_on_subclassed_model(self, distribution,
-                                                     run_distributed):
+          experimental_run_tf_function=[True, False]))
+  def test_distribution_strategy_on_subclassed_model(
+      self, distribution, experimental_run_tf_function):
     with distribution.scope():
 
       class _SimpleMLP(keras.Model):
@@ -395,9 +405,11 @@
             'We currently do not support distribution strategy with a '
             '`Sequential` model that is created without `input_shape`/'
             '`input_dim` set in its first layer or a subclassed model.'):
-          model.compile('sgd', run_distributed=run_distributed)
+          model.compile(
+              'sgd', experimental_run_tf_function=experimental_run_tf_function)
       else:
-        model.compile('sgd', run_distributed=run_distributed)
+        model.compile(
+            'sgd', experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       combinations.combine(
@@ -406,16 +418,17 @@
               strategy_combinations.one_device_strategy,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False]))
+          experimental_run_tf_function=[True, False]))
   def test_distribution_strategy_on_deferred_sequential_model(
-      self, distribution, run_distributed):
+      self, distribution, experimental_run_tf_function):
     with distribution.scope():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(16, activation='relu'))
       model.add(keras.layers.Dense(3, activation='softmax'))
 
       if context.executing_eagerly():
-        model.compile('sgd', run_distributed=run_distributed)
+        model.compile(
+            'sgd', experimental_run_tf_function=experimental_run_tf_function)
       else:
         with self.assertRaisesRegexp(
             ValueError,
@@ -423,7 +436,8 @@
             '`Sequential` model that is created without '
             '`input_shape`/`input_dim` set in its first layer or '
             'a subclassed model.'):
-          model.compile('sgd', run_distributed=run_distributed)
+          model.compile(
+              'sgd', experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       keras_test_lib.all_strategy_combinations_minus_default())
@@ -449,10 +463,10 @@
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           ],
           mode=['graph', 'eager'],
-          run_distributed=[True, False],
+          experimental_run_tf_function=[True, False],
           optimizer=strategy_combinations.gradient_descent_optimizer_keras_v2_fn
       ))
-  def test_masking(self, distribution, run_distributed, optimizer):
+  def test_masking(self, distribution, experimental_run_tf_function, optimizer):
     with self.cached_session():
       np.random.seed(1337)
       x = np.array([[[1], [1]], [[0], [0]]])
@@ -463,7 +477,9 @@
             keras.layers.TimeDistributed(
                 keras.layers.Dense(1, kernel_initializer='one')))
         model.compile(
-            loss='mse', optimizer=optimizer(), run_distributed=run_distributed)
+            loss='mse',
+            optimizer=optimizer(),
+            experimental_run_tf_function=experimental_run_tf_function)
       y = np.array([[[1], [1]], [[1], [1]]])
       dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
       dataset = dataset.repeat(100)
@@ -480,11 +496,11 @@
           keras_test_lib.all_strategy_combinations(),
           combinations.combine(
               fused=[True, False],
-              run_distributed=[True, False],
+              experimental_run_tf_function=[True, False],
               optimizer=strategy_combinations
               .gradient_descent_optimizer_keras_v2_fn)))
   def test_batchnorm_correctness(self, distribution, fused, optimizer,
-                                 run_distributed):
+                                 experimental_run_tf_function):
     with self.cached_session():
       with distribution.scope():
         model = keras.models.Sequential()
@@ -496,7 +512,9 @@
             ), momentum=0.8, fused=fused)
         model.add(norm)
         model.compile(
-            loss='mse', optimizer=optimizer(), run_distributed=run_distributed)
+            loss='mse',
+            optimizer=optimizer(),
+            experimental_run_tf_function=experimental_run_tf_function)
 
       # centered on 5.0, variance 10.0
       x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 20, 30))
@@ -525,21 +543,28 @@
       combinations.times(
           keras_test_lib.all_strategy_combinations_minus_default(),
           combinations.combine(
-              run_distributed=[True, False],
+              experimental_run_tf_function=[True, False],
               optimizer=strategy_combinations.rmsprop_optimizer_keras_v2_fn)))
-  def test_save_load_h5(self, distribution, optimizer, run_distributed):
+  def test_save_load_h5(self, distribution, optimizer,
+                        experimental_run_tf_function):
     with self.cached_session():
       dataset = keras_test_lib.get_dataset(distribution)
       with distribution.scope():
         model = keras_test_lib.get_model()
-        model.compile(optimizer(), 'mse', run_distributed=run_distributed)
+        model.compile(
+            optimizer(),
+            'mse',
+            experimental_run_tf_function=experimental_run_tf_function)
         model.fit(dataset, epochs=1, steps_per_epoch=1)
 
         weights_file = tempfile.mktemp('.h5')
         model.save_weights(weights_file)
 
         model_2 = keras_test_lib.get_model()
-        model_2.compile(optimizer(), 'mse', run_distributed=run_distributed)
+        model_2.compile(
+            optimizer(),
+            'mse',
+            experimental_run_tf_function=experimental_run_tf_function)
         model_2.load_weights(weights_file)
         model_2.predict(
             keras_test_lib.get_predict_dataset(distribution), steps=2)
@@ -549,9 +574,10 @@
       combinations.times(
           keras_test_lib.all_strategy_combinations_minus_default(),
           combinations.combine(
-              run_distributed=[True, False],
+              experimental_run_tf_function=[True, False],
               optimizer=strategy_combinations.rmsprop_optimizer_keras_v2_fn)))
-  def test_save_load_trackable(self, distribution, optimizer, run_distributed):
+  def test_save_load_trackable(self, distribution, optimizer,
+                               experimental_run_tf_function):
     # TODO(b/123533246): Enable the test for TPU once bug is fixed
     if (isinstance(distribution,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)) and
@@ -561,14 +587,20 @@
       dataset = keras_test_lib.get_dataset(distribution)
       with distribution.scope():
         model = keras_test_lib.get_model()
-        model.compile(optimizer(), 'mse', run_distributed=run_distributed)
+        model.compile(
+            optimizer(),
+            'mse',
+            experimental_run_tf_function=experimental_run_tf_function)
         model.fit(dataset, epochs=1, steps_per_epoch=1)
 
         weights_file = tempfile.mktemp()
         model.save_weights(weights_file)
 
         model_2 = keras_test_lib.get_model()
-        model_2.compile(optimizer(), 'mse', run_distributed=run_distributed)
+        model_2.compile(
+            optimizer(),
+            'mse',
+            experimental_run_tf_function=experimental_run_tf_function)
         model_2.load_weights(weights_file)
         model_2.predict(
             keras_test_lib.get_predict_dataset(distribution), steps=2)
@@ -580,8 +612,9 @@
   @combinations.generate(
       combinations.times(
           keras_test_lib.all_strategy_combinations_minus_default(),
-          combinations.combine(run_distributed=[True, False])))
-  def test_layer_outside_scope(self, distribution, run_distributed):
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_layer_outside_scope(self, distribution,
+                               experimental_run_tf_function):
     with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, 'was not created in the distribution strategy'):
@@ -593,13 +626,17 @@
           loss = 'mse'
           metrics = ['mae', keras.metrics.CategoricalAccuracy()]
           model.compile(
-              optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+              optimizer,
+              loss,
+              metrics=metrics,
+              experimental_run_tf_function=experimental_run_tf_function)
 
   @combinations.generate(
       combinations.times(
           keras_test_lib.all_strategy_combinations_minus_default(),
-          combinations.combine(run_distributed=[True, False])))
-  def test_model_outside_scope(self, distribution, run_distributed):
+          combinations.combine(experimental_run_tf_function=[True, False])))
+  def test_model_outside_scope(self, distribution,
+                               experimental_run_tf_function):
     with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, 'was not created in the distribution strategy'):
@@ -611,7 +648,10 @@
           loss = 'mse'
           metrics = ['mae', keras.metrics.CategoricalAccuracy()]
           model.compile(
-              optimizer, loss, metrics=metrics, run_distributed=run_distributed)
+              optimizer,
+              loss,
+              metrics=metrics,
+              experimental_run_tf_function=experimental_run_tf_function)
 
 
 class TestDistributionStrategyWithStaticShapes(test.TestCase,
diff --git a/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py b/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py
index b9ce689..2112888 100644
--- a/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py
@@ -34,10 +34,8 @@
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
-from tensorflow.python.keras.optimizer_v2 import rmsprop
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent as gradient_descent_v1
-from tensorflow.python.training import rmsprop as rmsprop_v1
 
 
 class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
@@ -136,17 +134,6 @@
         strategy_cls, gradient_descent.SGD,
         gradient_descent_v1.GradientDescentOptimizer)
 
-  @combinations.generate(
-      combinations.combine(
-          mode=['graph'],
-          strategy_cls=[collective_strategy.CollectiveAllReduceStrategy],
-          required_gpus=[0, 1]))
-  def test_rmsprop_optimizer_v1_v2_comparison(self, strategy_cls):
-    self.skipTest('There is an issue in collective ops (b/127700538) that '
-                  'prevent us from running this test with rmsprop optimizers.')
-    self.run_optimizer_comparison_with_simple_bias_model(
-        strategy_cls, rmsprop.RMSprop, rmsprop_v1.RMSPropOptimizer)
-
 
 if __name__ == '__main__':
   with test.mock.patch.object(sys, 'exit', os._exit):
diff --git a/tensorflow/python/keras/distribute/multi_worker_training_state.py b/tensorflow/python/keras/distribute/multi_worker_training_state.py
index 17ac85a..d4fc0fc 100644
--- a/tensorflow/python/keras/distribute/multi_worker_training_state.py
+++ b/tensorflow/python/keras/distribute/multi_worker_training_state.py
@@ -220,7 +220,8 @@
     return temp_dir, os.path.join(temp_dir, 'training_state')
 
   def _assert_in_multi_worker_mode(self):
-    if not multi_worker_util.in_multi_worker_mode():
+    # pylint: disable=protected-access
+    if not self._model._in_multi_worker_mode():
       raise ValueError('MultiWorkerTrainingState is only supposed to be used '
                        'in multi-worker training. This indicates some error '
                        'that needs to be fixed. Please submit a bug issue to '
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index b193f09..2b6f4ed 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -27,6 +27,7 @@
 import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
+from google.protobuf import json_format
 from tensorflow.core.framework import node_def_pb2
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
@@ -38,8 +39,10 @@
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
@@ -63,6 +66,8 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
@@ -107,8 +112,9 @@
   Arguments:
     trainable: Boolean, whether the layer's variables should be trainable.
     name: String name of the layer.
-    dtype: Default dtype of the layer's weights (default of `None` means use the
-      type of the first input).
+    dtype: The dtype of the layer's computations and weights (default of
+      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
+      of the first input in TensorFlow 1).
     dynamic: Set this to `True` if your layer should only be run eagerly, and
       should not be used to generate a static computation graph.
       This would be the case for a Tree-RNN or a recursive network,
@@ -118,8 +124,10 @@
 
   Read-only properties:
     name: The name of the layer (string).
-    dtype: Default dtype of the layer's weights (default of `None` means use the
-      type of the first input).
+    dtype: The dtype of the layer's computations and weights. If mixed
+      precision is used with a `tf.keras.mixed_precision.experimental.Policy`,
+      this is instead just the dtype of the layer's weights, as the computations
+      are done in a different dtype.
     updates: List of update ops of this layer.
     losses: List of losses added by this layer.
     trainable_weights: List of variables to be included in backprop.
@@ -132,6 +140,129 @@
     trainable: Whether the layer should be trained (boolean).
     input_spec: Optional (list of) `InputSpec` object(s) specifying the
       constraints on inputs that can be accepted by the layer.
+
+  ### Dtypes and casting
+  Each layer has a dtype, which is typically the dtype of the layer's
+  computations and variables. A layer's dtype can be queried via the
+  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
+  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
+  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
+  layers will cast their inputs to the layer's dtype in TensorFlow 2. For
+  example:
+
+  ```
+  x = tf.ones((4, 4, 4, 4), dtype='float64')
+  layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
+  print(layer.dtype)  # float32
+
+  # `layer` casts it's inputs to layer.dtype, which is float32, and does
+  # computations in float32.
+  y = layer(x)
+  ```
+
+  Currently, only tensors in the first argument to the layer's `call` method are
+  casted. For example:
+
+  ```
+  class MyLayer(tf.keras.layers.Layer):
+    # Bug! `b` will not be casted.
+    def call(self, a, b):
+      return a + 1., b + 1.
+
+  a = tf.constant(1., dtype="float32")
+  b = tf.constant(1., dtype="float32")
+
+  layer = MyLayer(dtype="float64")
+  x, y = layer(a, b)
+  print(x.dtype)  # float64
+  print(y.dtype)  # float32. Not casted since `b` was not passed to first input
+  ```
+
+  It is recommended to accept tensors only in the first argument. This way,
+  all tensors are casted to the layer's dtype. `MyLayer` should therefore be
+  written as:
+
+  ```
+  class MyLayer(tf.keras.layers.Layer):
+    # Now, all tensor inputs will be casted.
+    def call(self, inputs):
+      a, b = inputs
+      return a + 1., b + 1.
+
+  a = tf.constant(1., dtype="float32")
+  b = tf.constant(1., dtype="float32")
+
+  layer = MyLayer(dtype="float64")
+  x, y = layer((a, b))
+  print(x.dtype)  # float64
+  print(y.dtype)  # float64.
+  ```
+
+  In a future minor release, tensors in other arguments may be casted as well.
+
+  Currently, other arguments are not automatically casted for
+  technical reasons, but this may change in a future minor release.
+
+  A layer subclass can prevent its inputs from being autocasted by passing
+  `autocast=False` to the layer constructor. For example:
+
+  ```
+  class MyLayer(tf.keras.layers.Layer):
+
+    def __init__(self, **kwargs):
+      kwargs['autocast']=False
+      super(MyLayer, self).__init__(**kwargs)
+
+    def call(self, inp):
+      return inp
+
+  x = tf.ones((4, 4, 4, 4), dtype='float64')
+  layer = MyLayer()
+  print(layer.dtype)  # float32.
+  y = layer(x)  # MyLayer will not cast inputs to it's dtype of float32
+  print(y.dtype)  # float64
+  ```
+
+  #### Running models in float64 in TensorFlow 2
+
+  If you want to run a Model in float64, you can set floatx to be float64 by
+  calling `tf.keras.backend.set_floatx('float64')`. This will cause all layers
+  to default to float64 instead of float32:
+
+  ```
+  tf.keras.backend.set_floatx('float64')
+  layer1 = tf.keras.layers.Dense(4)
+  layer2 = tf.keras.layers.Dense(4)
+
+  x = tf.ones((4, 4))
+  y = layer2(layer1(x))  # Both layers run in float64
+  ```
+
+  Alternatively, you can pass `dtype='float64'` to each individual layer. Note
+  that if you have any layers which contain other layers as members, you must
+  ensure each sublayer gets `dtype='float64'` passed to it's constructor as
+  well:
+
+  ```
+  layer1 = tf.keras.layers.Dense(4, dtype='float64')
+  layer2 = tf.keras.layers.Dense(4, dtype='float64')
+
+  x = tf.ones((4, 4))
+  y = layer2(layer1(x))  # Both layers run in float64
+
+  class NestedLayer(tf.keras.layers.Layer):
+    def __init__(self, **kwargs):
+      super(NestedLayer, self).__init__(**kwargs)
+      self.dense = tf.keras.layers.Dense(4, dtype=kwargs.get('dtype'))
+
+    def call(self, inp):
+      return self.dense(inp)
+
+  layer3 = NestedLayer(dtype='float64')
+  z = layer3(x)  # layer3's dense layer runs in float64, since NestedLayer
+                 # correcty passed it's dtype to it's dense layer
+
+  ```
   """
 
   # See tf.Module for the usage of this property.
@@ -159,6 +290,7 @@
         'batch_size',
         'weights',
         'activity_regularizer',
+        'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -194,7 +326,12 @@
     # added using the `add_metric` API.
     self._metrics = []
 
-    self._set_dtype_and_policy(dtype)
+    self._set_dtype_policy(dtype)
+    # Boolean indicating whether the layer automatically casts its inputs to the
+    # layer's compute_dtype.
+    self._autocast = kwargs.get('autocast',
+                                base_layer_utils.v2_dtype_behavior_enabled())
+
     # Dependencies tracked via attribute assignment.
     self._maybe_create_attribute('_layers', [])
 
@@ -203,11 +340,7 @@
     self._inbound_nodes = []
     self._outbound_nodes = []
 
-    call_fn_args = self._call_fn_args
-    self._expects_training_arg = ('training' in call_fn_args or
-                                  self._call_accepts_kwargs)
-    self._expects_mask_arg = ('mask' in call_fn_args or
-                              self._call_accepts_kwargs)
+    self._init_call_fn_args()
 
     # Whether the `call` method can be used to build a TF graph without issues.
     self._dynamic = dynamic
@@ -329,8 +462,9 @@
     if dtype is None:
       dtype = self.dtype or backend.floatx()
     dtype = dtypes.as_dtype(dtype)
-    if self._dtype is None:
-      self._dtype = dtype.base_dtype.name
+    if self._dtype_policy.variable_dtype is None:
+      # The policy is "infer", so we infer the policy from the variable dtype.
+      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -362,7 +496,7 @@
         raise ValueError('An initializer for variable %s of type %s is required'
                          ' for layer %s' % (name, dtype.base_dtype, self.name))
 
-    if autocast and self._mixed_precision_policy.should_cast_variables:
+    if autocast and self._dtype_policy.should_cast_variables:
       # Wrap 'getter' with a version that returns an AutoCastVariable.
       old_getter = getter
       def getter(*args, **kwargs):  # pylint: disable=function-redefined
@@ -441,7 +575,7 @@
     if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
       raise NotImplementedError('Layers with arguments in `__init__` must '
                                 'override `get_config`.')
-    # TODO(reedwm): Handle serializing self._mixed_precision_policy.
+    # TODO(reedwm): Handle serializing self._dtype_policy.
     return config
 
   @classmethod
@@ -537,12 +671,7 @@
       return s.shape
     input_shape = nest.map_structure(check_type_return_shape, input_signature)
     output_shape = self.compute_output_shape(input_shape)
-    if self._mixed_precision_policy.should_cast_variables:
-      # If using mixed precision, and weights are cast to input dtype, we should
-      # not infer the dtype from self.dtype
-      dtype = None
-    else:
-      dtype = self.dtype
+    dtype = self._compute_dtype
     if dtype is None:
       input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
       # Default behavior when self.dtype is None, is to use the first input's
@@ -653,6 +782,12 @@
             training_value = backend.learning_phase()
 
       if self._expects_training_arg and training_value is not None:
+        # Force the training_value to be bool type which matches to the contract
+        # for layer/model call args.
+        if tensor_util.is_tensor(training_value):
+          training_value = math_ops.cast(training_value, dtypes.bool)
+        else:
+          training_value = bool(training_value)
         kwargs['training'] = training_value
         training_arg_passed_by_framework = True
 
@@ -674,6 +809,8 @@
       if build_graph:
         # Symbolic execution on symbolic tensors. We will attempt to build
         # the corresponding TF subgraph inside `backend.get_graph()`
+        # TODO(reedwm): We should assert input compatibility after the inputs
+        # are casted, not before.
         input_spec.assert_input_compatibility(self.input_spec, inputs,
                                               self.name)
         graph = backend.get_graph()
@@ -681,13 +818,16 @@
           # Build layer if applicable (if the `build` method has been
           # overridden).
           self._maybe_build(inputs)
+          cast_inputs = self._maybe_cast_inputs(inputs)
 
           # Wrapping `call` function in autograph to allow for dynamic control
-          # dependencies in call. We are limiting this to subclassed layers as
-          # autograph is strictly needed only for subclassed layers and models.
+          # flow and control dependencies in call. We are limiting this to
+          # subclassed layers as autograph is strictly needed only for
+          # subclassed layers and models.
           # tf_convert will respect the value of autograph setting in the
           # enclosing tf.function, if any.
-          if base_layer_utils.is_subclassed(self):
+          if (base_layer_utils.is_subclassed(self) and
+              not base_layer_utils.from_saved_model(self)):
             call_fn = autograph.tf_convert(
                 self.call, ag_ctx.control_status_ctx())
           else:
@@ -696,30 +836,25 @@
           if not self.dynamic:
             try:
               with base_layer_utils.autocast_context_manager(
-                  input_list,
-                  self._mixed_precision_policy.should_cast_variables):
+                  self._compute_dtype):
                 # Add auto_control_deps in V2 when they are not already added by
                 # a `tf.function`.
                 if (ops.executing_eagerly_outside_functions() and
                     not base_layer_utils.is_in_eager_or_tf_function()):
                   with auto_control_deps.AutomaticControlDependencies() as acd:
-                    outputs = call_fn(inputs, *args, **kwargs)
+                    outputs = call_fn(cast_inputs, *args, **kwargs)
                     # Wrap Tensors in `outputs` in `tf.identity` to avoid
                     # circular dependencies.
                     outputs = base_layer_utils.mark_as_return(outputs, acd)
                 else:
-                  outputs = call_fn(inputs, *args, **kwargs)
+                  outputs = call_fn(cast_inputs, *args, **kwargs)
 
-            except TypeError as e:
-              exception_str = str(e)
-              exception_msg = 'Tensor objects are only iterable when eager'
-              if exception_msg in exception_str:
-                raise TypeError('You are attempting to use Python control '
-                                'flow in a layer that was not declared to be '
-                                'dynamic. Pass `dynamic=True` to the class '
-                                'constructor.\nEncountered error:\n"""\n' +
-                                exception_str + '\n"""')
-              raise
+            except errors.OperatorNotAllowedInGraphError as e:
+              raise TypeError('You are attempting to use Python control '
+                              'flow in a layer that was not declared to be '
+                              'dynamic. Pass `dynamic=True` to the class '
+                              'constructor.\nEncountered error:\n"""\n' +
+                              str(e) + '\n"""')
           else:
             # We will use static shape inference to return symbolic tensors
             # matching the specifications of the layer outputs.
@@ -753,9 +888,10 @@
         # Eager execution on data tensors.
         with backend.name_scope(self._name_scope()):
           self._maybe_build(inputs)
+          cast_inputs = self._maybe_cast_inputs(inputs)
           with base_layer_utils.autocast_context_manager(
-              input_list, self._mixed_precision_policy.should_cast_variables):
-            outputs = self.call(inputs, *args, **kwargs)
+              self._compute_dtype):
+            outputs = self.call(cast_inputs, *args, **kwargs)
           self._handle_activity_regularization(inputs, outputs)
           self._set_mask_metadata(inputs, outputs, input_masks)
 
@@ -763,7 +899,7 @@
 
   @property
   def dtype(self):
-    return self._dtype
+    return self._dtype_policy.variable_dtype
 
   @property
   def name(self):
@@ -844,11 +980,10 @@
         if callable(u):
           try:
             u = u()
-          except ValueError as e:
-            if 'Trying to capture a tensor from an inner function' in str(e):
-              base_layer_utils.check_graph_consistency(
-                  method='add_update', force_raise=True)
-            raise
+          except errors.InaccessibleTensorError:
+            base_layer_utils.check_graph_consistency(
+                method='add_update', force_raise=True)
+            raise  # check_graph_consistency may not always raise.
         base_layer_utils.check_graph_consistency(u, method='add_update')
         updates.append(u)
     return updates + self._gather_children_attribute('updates')
@@ -1485,9 +1620,9 @@
           (in which case its weights aren't yet defined).
     """
     if not self.built:
-      if self.__class__.__name__ == 'Sequential':
+      if getattr(self, '_is_graph_network', False):
         with tf_utils.maybe_init_scope(self):
-          self._maybe_build()  # pylint: disable=no-value-for-parameter
+          self._maybe_build(self.inputs)
       else:
         raise ValueError('You tried to call `count_params` on ' + self.name +
                          ', but the layer isn\'t built. '
@@ -1590,23 +1725,110 @@
   # Methods & attributes below are all private and only used by the framework. #
   ##############################################################################
 
-  def _set_dtype_and_policy(self, dtype):
-    """Sets self._dtype and self._mixed_precision_policy."""
-    if dtype:
-      if isinstance(dtype, policy.Policy):
-        self._mixed_precision_policy = dtype
-        self._dtype = self._mixed_precision_policy.default_variable_dtype
-      else:
-        # If a non-policy dtype is passed, no casting should be done. So we use
-        # the "infer" policy, which does no casting.
-        self._mixed_precision_policy = policy.Policy('infer')
-        self._dtype = dtypes.as_dtype(dtype).name
+  def _set_dtype_policy(self, dtype):
+    """Sets self._dtype_policy."""
+    if isinstance(dtype, policy.Policy):
+      self._dtype_policy = dtype
+    elif dtype:
+      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
     else:
-      self._mixed_precision_policy = policy.global_policy()
-      # If the global policy has not been set, it will be an "infer" policy
-      # without a default variable dtype, and so self._dtype will be None. In
-      # that case, self._dtype will be set when the layer is built or called.
-      self._dtype = self._mixed_precision_policy.default_variable_dtype
+      self._dtype_policy = policy.global_policy()
+
+    if self._dtype_policy.should_cast_variables and backend.is_tpu_strategy(
+        ds_context.get_strategy()):
+      # TODO(b/137859335): Supoprt this. AutoCastVariables currently do not work
+      # properly when wrapping TPUMirroredVariables.
+      raise ValueError('DType Policies ending in "_with_float32_vars" are '
+                       'not yet supported with TPUStrategy. Got policy: %s' %
+                       self._dtype_policy.name)
+
+    # This has no impact on the layer behavior, and is only used for printing
+    # warnings.
+    self._dtype_defaulted_to_floatx = (not dtype and
+                                       policy.policy_defaults_to_floatx())
+
+  # TODO(reedwm): Expose this property?
+  @property
+  def _compute_dtype(self):
+    """The layer's compute dtype.
+
+    Unless mixed-precision is used, this is the same as `Layer.dtype`.
+
+    If self._autocast is True, layer's will cast floating-point inputs to this.
+
+    Returns:
+      The layer's compute dtype.
+    """
+    return self._dtype_policy.compute_dtype
+
+  def _maybe_cast_inputs(self, inputs):
+    """Maybe casts the inputs to the compute dtype.
+
+    If self._compute_dtype is floating-point, and self_autocast is True,
+    floating-point inputs are casted to self._compute_dtype.
+
+    Args:
+      inputs: Input tensor, or structure of input tensors.
+
+    Returns:
+      `inputs`, but tensors may have been casted to self._compute_dtype
+    """
+    compute_dtype = self._compute_dtype
+    if (self._autocast and compute_dtype and
+        dtypes.as_dtype(compute_dtype).is_floating):
+      def f(x):
+        cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
+                      ragged_tensor.RaggedTensor)
+        if (isinstance(x, cast_types) and x.dtype.is_floating and
+            x.dtype.base_dtype.name != compute_dtype):
+          if self._dtype_defaulted_to_floatx:
+            self._warn_about_input_casting(x.dtype.base_dtype)
+          return math_ops.cast(x, compute_dtype)
+        else:
+          return x
+      return nest.map_structure(f, inputs)
+    else:
+      return inputs
+
+  def _warn_about_input_casting(self, input_dtype):
+    # self._already_warned_about_input_casting is only retrieved or set in this
+    # function.
+    already_warned = getattr(self, '_already_warned_about_input_casting', False)
+    if not already_warned:
+      tf_logging.warn(
+          "Layer {self.name} is casting an input tensor from dtype "
+          "{input_dtype} to the layer's dtype of {layer_dtype}, which is new "
+          "behavior in TensorFlow 2.  The layer has dtype {layer_dtype} "
+          "because it's dtype defaults to floatx.\n\n"
+          ""
+          "If you intended to run this layer in {layer_dtype}, you can safely "
+          "ignore this warning. If in doubt, this warning is likely only an "
+          "issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n\n"
+          ""
+          "To change all layers to have dtype {input_dtype} by default, call "
+          "`tf.keras.backend.set_floatx('{input_dtype}')`. To change just this "
+          "layer, pass dtype='{input_dtype}' to the layer constructor. If you "
+          "are the author of this layer, you can disable autocasting by "
+          "passing autocast=False to the base Layer constructor.\n".format(
+              self=self,
+              input_dtype=input_dtype.name,
+              layer_dtype=self._compute_dtype))
+      self._already_warned_about_input_casting = True
+
+  # _dtype used to be an attribute set in the constructor. We still expose it
+  # because some clients still use it.
+  # TODO(reedwm): Deprecate, then remove the _dtype property.
+  @property
+  def _dtype(self):
+    # This is equivalent to returning self.dtype . We do not return self.dtype
+    # as it would cause infinite recursion in a few subclasses, which override
+    # "dtype" to return self._dtype.
+    return self._dtype_policy.variable_dtype
+
+  @_dtype.setter
+  def _dtype(self, value):
+    value = dtypes.as_dtype(value).name
+    self._dtype_policy = policy.Policy(value)
 
   def _name_scope(self):
     return self.name
@@ -1772,19 +1994,25 @@
       return None
     return input_masks
 
-  def _call_arg_was_passed(self, arg_name, args, kwargs):
+  def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
     if arg_name in kwargs:
       return True
-    # Ignore `inputs` arg.
-    if arg_name in dict(zip(self._call_fn_args[1:], args)):
+    call_fn_args = self._call_fn_args
+    if not inputs_in_args:
+      # Ignore `inputs` arg.
+      call_fn_args = call_fn_args[1:]
+    if arg_name in dict(zip(call_fn_args, args)):
       return True
     return False
 
-  def _get_call_arg_value(self, arg_name, args, kwargs):
+  def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
     if arg_name in kwargs:
       return kwargs[arg_name]
-    # Ignore `inputs` arg.
-    args_dict = dict(zip(self._call_fn_args[1:], args))
+    call_fn_args = self._call_fn_args
+    if not inputs_in_args:
+      # Ignore `inputs` arg.
+      call_fn_args = call_fn_args[1:]
+    args_dict = dict(zip(call_fn_args, args))
     return args_dict[arg_name]
 
   def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
@@ -1792,7 +2020,7 @@
     # If the layer returns tensors from its inputs, unmodified,
     # we copy them to avoid loss of tensor metadata.
     output_ls = nest.flatten(outputs)
-    inputs_ls = nest.flatten(inputs)
+    inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs))
     output_ls_copy = []
     for x in output_ls:
       if x in inputs_ls:
@@ -1894,11 +2122,14 @@
       input_spec.assert_input_compatibility(
           self.input_spec, inputs, self.name)
       input_list = nest.flatten(inputs)
-      if input_list and self._dtype is None:
+      if input_list and self._dtype_policy.compute_dtype is None:
         try:
-          self._dtype = input_list[0].dtype.base_dtype.name
+          dtype = input_list[0].dtype.base_dtype.name
         except AttributeError:
           pass
+        else:
+          self._dtype_policy = policy.with_input_dtype(self._dtype_policy,
+                                                       dtype)
       input_shapes = None
       if all(hasattr(x, 'shape') for x in input_list):
         input_shapes = nest.map_structure(lambda x: x.shape, inputs)
@@ -2112,6 +2343,17 @@
   def _is_layer(self):
     return True
 
+  def _init_call_fn_args(self):
+    # Clear cached call function arguments.
+    self.__class__._call_fn_args.fget.cache.pop(self, None)
+    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
+
+    call_fn_args = self._call_fn_args
+    self._expects_training_arg = ('training' in call_fn_args or
+                                  self._call_accepts_kwargs)
+    self._expects_mask_arg = ('mask' in call_fn_args or
+                              self._call_accepts_kwargs)
+
   @property
   @tracking.cached_per_instance
   def _call_fn_args(self):
@@ -2257,7 +2499,7 @@
   def _unique_trainable_weights(self):
     """Dedupe trainable weights while maintaining order as much as possible."""
     trainable_weights = self.trainable_weights
-    output, seen_weights = [], set()
+    output, seen_weights = [], object_identity.ObjectIdentitySet()
     for w in trainable_weights:
       if w not in seen_weights:
         output.append(w)
@@ -2285,11 +2527,11 @@
 
   Attributes:
     node_def: String, the serialized NodeDef of the Op this layer will wrap.
+    name: String, the name of the Layer.
     constants: Dict of NumPy arrays, the values of any Tensors needed for this
       Operation that do not originate from a Keras `Input` Layer. Since all
       placeholders must come from Keras `Input` Layers, these Tensors must be
       treated as constant in the Functional API.
-    name: String, the name of the Layer.
     trainable: Bool, whether this Layer is trainable. Currently Variables are
       not supported, and so this parameter has no effect.
     dtype: The default dtype of this Layer. Inherited from `Layer` and has no
@@ -2298,16 +2540,25 @@
 
   def __init__(self,
                node_def,
+               name,
                constants=None,
-               name=None,
                trainable=True,
                dtype=None):
+    # Pass autocast=False, as if inputs are cast, input types might not match
+    # Operation type.
     super(TensorFlowOpLayer, self).__init__(
-        name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype)
-    if not isinstance(node_def, bytes):
-      node_def = node_def.encode('utf-8')
-    self.node_def = node_def_pb2.NodeDef.FromString(node_def)
-    self.constants = constants or {}
+        name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
+        autocast=False)
+    if isinstance(node_def, dict):
+      self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef())
+    else:
+      if not isinstance(node_def, bytes):
+        node_def = node_def.encode('utf-8')
+      self.node_def = node_def_pb2.NodeDef.FromString(node_def)
+    # JSON serialization stringifies keys which are integer input indices.
+    self.constants = ({
+        int(index): constant for index, constant in constants.items()
+    } if constants is not None else {})
     # Layer uses original op unless it is called on new inputs.
     # This means `built` is not set in `__call__`.
     self.built = True
@@ -2365,7 +2616,9 @@
   def get_config(self):
     config = super(TensorFlowOpLayer, self).get_config()
     config.update({
-        'node_def': self.node_def.SerializeToString().decode('utf-8'),
+        # `__init__` prefixes the name. Revert to the constructor argument.
+        'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):],
+        'node_def': json_format.MessageToDict(self.node_def),
         'constants': {
             i: backend.get_value(c) for i, c in self.constants.items()
         }
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 7bacc1e..a0ae0cc 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -18,12 +18,9 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-import itertools as it
 import os
 import sys
 import traceback
-from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
@@ -32,12 +29,15 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.keras.optimizer_v2 import rmsprop
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.layers import core as legacy_core
@@ -47,9 +47,12 @@
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.summary import summary_iterator
+from tensorflow.python.util import nest
 
 
 class DynamicLayer(base_layer.Layer):
@@ -119,7 +122,7 @@
         return inputs
 
     with context.eager_mode():
-      layer = BuildCounter()
+      layer = BuildCounter(dtype=dtypes.float64)
       output_shape = layer.compute_output_shape((None, 10))
       self.assertEqual(layer.build_counter, 1)
       self.assertEqual(output_shape.as_list(), [None, 10])
@@ -221,7 +224,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
     self.assertEqual(loss, 2 * 3)
 
@@ -313,7 +316,8 @@
     def get_learning_phase_value():
       model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
       model._run_eagerly = testing_utils.should_run_eagerly()
-      model._run_distributed = testing_utils.should_run_distributed()
+      model._experimental_run_tf_function = (
+          testing_utils.should_run_tf_function())
       return np.sum(model(np.ones((1, 1))))
 
     self.assertEqual(get_learning_phase_value(), 0)
@@ -334,7 +338,7 @@
   @keras_parameterized.run_all_keras_modes
   def test_learning_phase_freezing_for_layers_in_predict(self):
     if not (testing_utils.should_run_eagerly() or
-            testing_utils.should_run_distributed()):
+            testing_utils.should_run_tf_function()):
       self.skipTest('Predict fails to override the outer learning phase in'
                     'the FuncGraph path.')
 
@@ -348,7 +352,8 @@
     def get_learning_phase_value():
       model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
       model._run_eagerly = testing_utils.should_run_eagerly()
-      model._run_distributed = testing_utils.should_run_distributed()
+      model._experimental_run_tf_function = (
+          testing_utils.should_run_tf_function())
       return np.sum(model.predict(np.ones((1, 1))))
 
     self.assertEqual(get_learning_phase_value(), 0)
@@ -447,7 +452,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x, y = np.ones((10, 10)), np.ones((10, 10))
     # Checks that variables get initialized.
     model.fit(x, y, batch_size=2, epochs=2)
@@ -494,7 +499,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     inputs = np.random.random((3, 10))
     out = model.predict(inputs)
     self.assertAllClose(model.layers[-1].get_weights()[0], kernel_value)
@@ -555,6 +560,24 @@
     # arguments, no error is thrown:
     self.assertEqual(MyLayerNew2(name='New').get_config()['name'], 'New')
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_count_params(self):
+    dense = keras.layers.Dense(16)
+    dense.build((None, 4))
+    self.assertEqual(dense.count_params(), 16 * 4 + 16)
+
+    dense = keras.layers.Dense(16)
+    with self.assertRaisesRegexp(ValueError, 'call `count_params`'):
+      dense.count_params()
+
+    model = keras.Sequential(keras.layers.Dense(16))
+    with self.assertRaisesRegexp(ValueError, 'call `count_params`'):
+      model.count_params()
+
+    dense = keras.layers.Dense(16, input_dim=4)
+    model = keras.Sequential(dense)
+    self.assertEqual(model.count_params(), 16 * 4 + 16)
+
 
 class SymbolicSupportTest(test.TestCase):
 
@@ -916,7 +939,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
     self.assertEqual(train_loss, 0.)
     test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
@@ -941,7 +964,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
     self.assertEqual(train_loss, 2 * 3)
     test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
@@ -966,7 +989,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     _, train_metric = model.train_on_batch(np.ones((2, 3)),
                                            np.ones((2, 3)))
     self.assertEqual(train_metric, 2 * 3)
@@ -998,7 +1021,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
     self.assertEqual(keras.backend.get_value(layer.counter), 1.)
 
@@ -1032,7 +1055,7 @@
           'sgd',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
       self.assertEqual(keras.backend.get_value(layer.counter), 6.)
     else:
@@ -1068,7 +1091,7 @@
           'sgd',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
       self.assertEqual(loss, 2 * 3)
     else:
@@ -1082,7 +1105,7 @@
             1, kernel_regularizer=keras.regularizers.l2(1e-4), input_shape=(1,))
     ])
     model._run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     def assert_graph(t):
       if not context.executing_eagerly():
@@ -1125,7 +1148,7 @@
           'sgd',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
       self.assertEqual(history.history['sum'][-1], 2 * 3)
     else:
@@ -1154,7 +1177,7 @@
         loss='mse',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
@@ -1188,7 +1211,7 @@
         loss='mse',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 3, 4))
     y = np.ones(shape=(10, 3, 2))
@@ -1201,66 +1224,237 @@
         model.fit(x, y, epochs=2, batch_size=5)
 
 
-_LAYERS_TO_TEST = [
-    (keras.layers.Dense, (1,), collections.OrderedDict(units=[1])),
-    (keras.layers.Activation, (2, 2),
-     collections.OrderedDict(activation=['relu'])),
-    (keras.layers.Dropout, (16,), collections.OrderedDict(rate=[0.25])),
-    (keras.layers.BatchNormalization, (8, 8, 3), collections.OrderedDict(
-        axis=[3], center=[True, False], scale=[True, False])),
-    (keras.layers.Conv1D, (8, 8), collections.OrderedDict(
-        filters=[1], kernel_size=[1, 3], strides=[1, 2],
-        padding=['valid', 'same'], use_bias=[True, False],
-        kernel_regularizer=[None, 'l2'])),
-    (keras.layers.Conv2D, (8, 8, 3), collections.OrderedDict(
-        filters=[1], kernel_size=[1, 3], strides=[1, 2],
-        padding=['valid', 'same'], use_bias=[True, False],
-        kernel_regularizer=[None, 'l2'])),
-    (keras.layers.LSTM, (8, 8), collections.OrderedDict(
-        units=[1],
-        activation=[None, 'relu'],
-        kernel_regularizer=[None, 'l2'],
-        dropout=[0, 0.5],
-        stateful=[True, False],
-        unroll=[True, False])),
-]
+class AddLayer(keras.layers.Layer):
+  """A layer which adds it's input to a variable.
 
-OUTPUT_TEST_CASES = []
-for layer_type, inp_shape, arg_dict in _LAYERS_TO_TEST:
-  arg_combinations = [[(k, i) for i in v] for k, v in arg_dict.items()]  # pylint: disable=g-complex-comprehension
-  for args in it.product(*arg_combinations):
-    name = '_{}_{}'.format(
-        layer_type.__name__, '_'.join('{}_{}'.format(k, v) for k, v in args))
-    OUTPUT_TEST_CASES.append(
-        (name, layer_type, inp_shape, {k: v for k, v in args}))
+  Useful for testing a layer with a variable
+  """
+
+  def build(self, _):
+    self.v = self.add_weight('v', (), initializer='ones')
+    self.built = True
+
+  def call(self, inputs):
+    return inputs + self.v
 
 
-class OutputTypeTest(keras_parameterized.TestCase):
-  """Test that layers and models produce the correct tensor types."""
+class IdentityLayer(keras.layers.Layer):
+  """A layer that returns it's input.
 
-  # In v1 graph there are only symbolic tensors.
-  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
-  @parameterized.named_parameters(*OUTPUT_TEST_CASES)
-  def test_layer_outputs(self, layer_to_test, input_shape, layer_kwargs):
-    layer = layer_to_test(**layer_kwargs)
+  Useful for testing a layer without a variable.
+  """
 
-    input_data = np.ones(shape=(2,) + input_shape, dtype=np.float32)
-    layer_result = layer(input_data)
+  def call(self, inputs):
+    return inputs
 
-    inp = keras.layers.Input(shape=input_shape, batch_size=2)
-    model = keras.models.Model(inp, layer_to_test(**layer_kwargs)(inp))
-    model_result = model(input_data)
 
-    for x in [layer_result, model_result]:
-      if not isinstance(x, ops.Tensor):
-        raise ValueError('Tensor or EagerTensor expected, got type {}'
-                         .format(type(x)))
+@test_util.run_all_in_graph_and_eager_modes
+class DTypeTest(keras_parameterized.TestCase):
 
-      if isinstance(x, ops.EagerTensor) != context.executing_eagerly():
-        expected_type = (ops.EagerTensor if context.executing_eagerly()
-                         else ops.Tensor)
-        raise ValueError('Expected type {}, got type {}'
-                         .format(expected_type, type(x)))
+  # This class only have tests relating to layer.dtype. Tests for dtype policies
+  # are in mixed_precision/experimental/keras_test.py
+
+  # TODO(reedwm): Maybe have a separate test file for input casting tests.
+
+  def _const(self, dtype):
+    return array_ops.constant(1, dtype=dtype)
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_dtype_defaults_to_floatx(self):
+    layer = AddLayer()
+    self.assertEqual(layer.dtype, 'float32')
+    layer(self._const('float64'))
+    self.assertEqual(layer.dtype, 'float32')  # dtype should not change
+
+    try:
+      backend.set_floatx('float64')
+      layer = AddLayer()
+      self.assertEqual(layer.dtype, 'float64')
+    finally:
+      backend.set_floatx('float32')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_passing_dtype_to_constructor(self):
+    layer = IdentityLayer(dtype='float64')
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    layer = IdentityLayer(dtype='int32')
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'int32')
+
+    layer = IdentityLayer(dtype=dtypes.float64)
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'float64')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def input_cast_to_dtype(self):
+    layer = AddLayer()
+
+    # Input should be cast to layer.dtype, so output should also be layer.dtype
+    self.assertEqual(layer(self._const('float64')).dtype, 'float32')
+
+    layer = AddLayer(dtype='float64')
+    self.assertEqual(layer(self._const('float32')).dtype, 'float64')
+
+    # Test inputs are not casted if layer.dtype is not floating-point
+    layer = IdentityLayer(dtype='int32')
+    self.assertEqual(layer(self._const('float64')).dtype, 'float64')
+
+    # Test inputs are not casted if the inputs are not floating-point
+    layer = IdentityLayer(dtype='float32')
+    self.assertEqual(layer(self._const('int32')).dtype, 'int32')
+
+    # Test Numpy arrays are casted
+    layer = IdentityLayer(dtype='float64')
+    self.assertEqual(layer(np.array(1, dtype='float32')).dtype, 'float64')
+
+    # Test Python floats are casted
+    layer = IdentityLayer(dtype='float64')
+    self.assertEqual(layer(1.).dtype, 'float64')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def multiple_inputs_cast_to_dtype(self):
+
+    class MultiIdentityLayer(keras.layers.Layer):
+
+      def call(self, inputs):
+        return [array_ops.identity(x) for x in inputs]
+
+    # Testing layer with default dtype of float32
+    layer = MultiIdentityLayer()
+    x, y = layer([self._const('float16'), self._const('float32')])
+    self.assertEqual(x.dtype, 'float32')
+    self.assertEqual(y.dtype, 'float32')
+
+    # Test passing dtype to the constructor
+    layer = MultiIdentityLayer(dtype='float64')
+    x, y = layer([self._const('float16'), self._const('float32')])
+    self.assertEqual(x.dtype, 'float64')
+    self.assertEqual(y.dtype, 'float64')
+
+    # Test several non-floating point types
+    layer = MultiIdentityLayer(dtype='float64')
+    x, y, z, w = layer([self._const('float16'), self._const('bool'),
+                        self._const('float64'), self._constant('complex64')])
+    self.assertEqual(x.dtype, 'float64')
+    self.assertEqual(y.dtype, 'bool')
+    self.assertEqual(z.dtype, 'float64')
+    self.assertEqual(w.dtype, 'complex64')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_extra_args_and_kwargs_not_casted(self):
+
+    class IdentityLayerWithArgs(keras.layers.Layer):
+
+      def call(self, inputs, *args, **kwargs):
+        return nest.flatten([inputs, args, kwargs])
+
+    layer = IdentityLayerWithArgs(dtype='float64')
+    x, y, z = layer(self._const('float16'), self._const('float16'),
+                    kwarg=self._const('float16'))
+    self.assertEqual(x.dtype, 'float64')
+    self.assertEqual(y.dtype, 'float16')
+    self.assertEqual(z.dtype, 'float16')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_layer_without_autocast(self):
+
+    class IdentityLayerWithoutAutocast(IdentityLayer):
+
+      def __init__(self, *args, **kwargs):
+        kwargs['autocast'] = False
+        super(IdentityLayerWithoutAutocast, self).__init__(*args, **kwargs)
+
+    layer = IdentityLayerWithoutAutocast(dtype='float64')
+    self.assertEqual(layer(self._const('float32')).dtype, 'float32')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_dtype_warnings(self):
+    # Test a layer warns when it casts inputs.
+    layer = IdentityLayer()
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      layer(self._const('float64'))
+      self.assertRegexpMatches(
+          str(mock_warn.call_args),
+          ".*from dtype float64 to the layer's dtype of float32.*"
+          "The layer has dtype float32 because.*")
+
+    # Test a layer does not warn a second time
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      layer(self._const('float64'))
+      mock_warn.assert_not_called()
+
+    # Test a new layer can warn even if a different layer already warned
+    layer = IdentityLayer()
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      layer(self._const('float64'))
+      self.assertRegexpMatches(
+          str(mock_warn.call_args),
+          ".*from dtype float64 to the layer's dtype of float32.*"
+          "The layer has dtype float32 because.*")
+
+    # Test a layer does not warn if a dtype is passed
+    layer = IdentityLayer(dtype='float32')
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      layer(self._const('float64'))
+      mock_warn.assert_not_called()
+
+    # Test a layer does not warn if a Policy is set:
+    with policy.policy_scope('float32'):
+      layer = IdentityLayer()
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        layer(self._const('float64'))
+        mock_warn.assert_not_called()
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_compute_output_signature(self):
+
+    class IdentityLayerWithOutputShape(IdentityLayer):
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    layer = IdentityLayerWithOutputShape(dtype='float64')
+    output_signature = layer.compute_output_signature(
+        tensor_spec.TensorSpec(shape=(), dtype='float32'))
+    self.assertEqual(output_signature.shape, ())
+    self.assertEqual(output_signature.dtype, 'float64')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_composite_tensors_input_casting(self):
+    sparse = sparse_tensor.SparseTensor(
+        indices=array_ops.constant([[0, 1], [2, 3]], dtype='int64'),
+        values=array_ops.constant([0., 1.], dtype='float32'),
+        dense_shape=array_ops.constant([4, 4], dtype='int64'))
+    ragged = ragged_tensor.RaggedTensor.from_row_splits(
+        values=array_ops.constant([1., 2., 3.], dtype='float32'),
+        row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
+
+    layer = IdentityLayer(dtype='float16')
+    for x in sparse, ragged:
+      self.assertEqual(x.dtype, 'float32')
+      y = layer(x)
+      self.assertEqual(y.dtype, 'float16')
+      self.assertEqual(type(x), type(y))
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_passing_non_tensor(self):
+    layer = IdentityLayer()
+    x = object()
+    y = layer(x)  # Layer should not cast 'x', as it's not a tensor
+    self.assertIs(x, y)
+
+  @testing_utils.disable_v2_dtype_behavior
+  def test_v1_behavior(self):
+    # Test dtype defaults to None and inferred from input
+    layer = IdentityLayer()
+    self.assertIsNone(layer.dtype)
+    layer(self._const('float64'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    # Test layer does not cast to dtype
+    self.assertEqual(layer(self._const('float32')).dtype, 'float32')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index ad0c7cc..a4826e5 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -19,6 +19,7 @@
 
 import threading
 
+from tensorflow.python import tf2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import variables as tf_variables
@@ -224,7 +225,8 @@
             # configured improperly.
             constants[i] = op_input
           else:
-            constants[i] = backend.function([], op_input)([])
+            with ops.init_scope():
+              constants[i] = backend.function([], op_input)([])
       processed_ops, created_layers = _create_keras_history_helper(
           layer_inputs, processed_ops, created_layers)
       name = op.name
@@ -238,7 +240,7 @@
   return processed_ops, created_layers
 
 
-def needs_keras_history(tensors):
+def needs_keras_history(tensors, ignore_call_context=False):
   """Check if any Tensors need to be wrapped in TensorFlowOpLayers.
 
   This will never return True inside a sublayer, because sublayers
@@ -248,12 +250,18 @@
 
   Arguments:
     tensors: An arbitrary nested structure of Tensors.
+    ignore_call_context: Whether to ignore the check of if currently
+      outside of a `call` context. This is `True` when creating
+      KerasHistory inside `Node`, where we always know that Tensors
+      are being used with the Functional API.
 
   Returns:
     Bool, whether at least one Tensor needs to be wrapped.
   """
   input_tensors = nest.flatten(tensors)
-  if call_context().in_call or all(
+  if call_context().in_call and not ignore_call_context:
+    return False
+  if all(
       getattr(tensor, '_keras_history', None) is not None
       for tensor in input_tensors):
     # KerasHistory already set.
@@ -308,21 +316,25 @@
   tensors_to_check = nest.flatten(tensors)
 
   while tensors_to_check:
-    new_tensors_to_check = set()
+    new_tensors_to_check = []
     for tensor in tensors_to_check:
+      if id(tensor) in checked_tensors:
+        continue
+
+      checked_tensors.add(id(tensor))
+
       if getattr(tensor, '_keras_history_checked', None) is not None:
         continue
       if getattr(tensor, '_keras_history', None) is not None:
         return True
 
       try:
-        new_tensors_to_check.update(tensor.op.inputs)
+        new_tensors_to_check.extend(tensor.op.inputs)
       except AttributeError:
         # In case `tensor` is a Variable created in an Eager context.
         pass
 
-    checked_tensors.update(tensors_to_check)
-    tensors_to_check = list(new_tensors_to_check - checked_tensors)
+    tensors_to_check = new_tensors_to_check
 
   # Mark that these Tensors have been checked once for `_keras_history`,
   # and should not be checked again for performance reasons.
@@ -422,32 +434,21 @@
   return 'training' in full_args and full_args['training'] is not None
 
 
-def _get_var_read_dtype(input_list, should_cast):
-  """Gets the dtype that AutoCastVariables should be read in."""
-  if should_cast and input_list and input_list[0].dtype.is_floating:
-    return input_list[0].dtype.base_dtype
-  else:
-    return None
-
-
-def autocast_context_manager(input_list, should_cast):
+def autocast_context_manager(dtype):
   """Returns a context manager to autocast AutoCastVariables.
 
-  Under this context manager, if `should_cast` is True, AutoCastVariables will
-  be casted. If `should_cast` is False, AutoCastVariables will not be casted,
-  which can be used to disable autocasting if nested under another
-  call to `autocast_context_manager`.
+  Under this context manager, AutoCastVariables will be casted to `dtype` if
+  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
 
   Args:
-    input_list: The inputs to the layer with the AutoCastVariables.
-    should_cast: Whether AutoCastVariables should be casted.
+    dtype: The dtype to cast AutoCastVariables to, or None.
 
   Returns:
     A context manager to automatically cast AutoCastVariables.
   """
-  var_read_dtype = _get_var_read_dtype(input_list, should_cast)
-  return ops.get_default_graph()._enable_auto_casting_variables(  # pylint: disable=protected-access
-      var_read_dtype)
+  if dtype and not dtypes.as_dtype(dtype).is_floating:
+    dtype = None
+  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
 
 
 def is_subclassed(layer):
@@ -456,6 +457,11 @@
           layer.__module__.find('keras.layers') == -1)
 
 
+def from_saved_model(layer):
+  """Returns whether the layer is loaded from a SavedModel."""
+  return layer.__module__.find('keras.saving.saved_model') != -1
+
+
 def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
   """Checks that tensors passed to `add_*` method match the Keras graph.
 
@@ -472,12 +478,13 @@
   Raises:
     RuntimeError: In case of an out-of-graph tensor.
   """
-  if (force_raise or (ops.executing_eagerly_outside_functions() and
-                      hasattr(tensor, 'graph') and
-                      isinstance(tensor.graph,
-                                 (control_flow_util_v2.CondBranchFuncGraph,
-                                  control_flow_util_v2.WhileCondFuncGraph,
-                                  control_flow_util_v2.WhileBodyFuncGraph)))):
+  if (force_raise or
+      (ops.executing_eagerly_outside_functions() and
+       hasattr(tensor, 'graph') and
+       isinstance(tensor.graph,
+                  (control_flow_v2_func_graphs.CondBranchFuncGraph,
+                   control_flow_v2_func_graphs.WhileCondFuncGraph,
+                   control_flow_v2_func_graphs.WhileBodyFuncGraph)))):
     if method == 'activity_regularizer':
       bad_example = """
       class TestModel(tf.keras.Model):
@@ -613,3 +620,60 @@
   """Decorates a method to detect overrides in subclasses."""
   method._is_default = True  # pylint: disable=protected-access
   return method
+
+
+V2_DTYPE_BEHAVIOR = None
+
+
+# These two functions are not exported because we plan on removing them in the
+# future.
+def enable_v2_dtype_behavior():
+  """Enable the V2 dtype behavior for Keras layers.
+
+  By default, the V2 dtype behavior is enabled in TensorFlow 2.
+
+  When enabled, the dtype of Keras layers defaults to floatx (which is typically
+  float32) instead of None. In addition, layers will automatically cast
+  floating-point inputs to the layer's dtype.
+
+  For example, once enabled, the following block will run a Conv2D layer
+  in float32:
+
+  ```python
+  x = tf.ones((4, 4, 4, 4), dtype='float64')
+  layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
+  print(layer.dtype)  # Float32 when enabled. None when disabled.
+  # When enabled, will cast inputs to the layer's dtype, which is float32. When
+  # disabled, will do no casting, so the layer is done in float64.
+  y = layer(x)
+  ```
+
+  A layer author can opt-out their layer from the automatic input casting by
+  passing `autocast=False` to the base Layer's constructor. This disables the
+  autocasting part of the V2 behavior for that layer, but not the defaulting to
+  floatx part of the V2 behavior.
+
+  When a global `tf.keras.mixed_precision.experimental.Policy` is set, the
+  layer's dtype will default to the global policy instead of floatx. Layers
+  will automatically cast inputs to the policy's compute_dtype.
+  """
+  global V2_DTYPE_BEHAVIOR
+  V2_DTYPE_BEHAVIOR = True
+
+
+def disable_v2_dtype_behavior():
+  """Disables the V2 dtype behavior for Keras layers.
+
+  See `enable_v2_dtype_behavior`.
+
+  This function will be removed in the future.
+  """
+  global V2_DTYPE_BEHAVIOR
+  V2_DTYPE_BEHAVIOR = False
+
+
+def v2_dtype_behavior_enabled():
+  """Returns True if the V2 dtype behavior is enabled."""
+  if V2_DTYPE_BEHAVIOR is None:
+    return tf2.enabled()
+  return V2_DTYPE_BEHAVIOR
diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer_test.py b/tensorflow/python/keras/engine/base_preprocessing_layer_test.py
index ac26e68..f500876 100644
--- a/tensorflow/python/keras/engine/base_preprocessing_layer_test.py
+++ b/tensorflow/python/keras/engine/base_preprocessing_layer_test.py
@@ -27,6 +27,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_preprocessing_layer
 from tensorflow.python.keras.engine import base_preprocessing_layer_v1
 from tensorflow.python.ops import init_ops
@@ -158,10 +159,12 @@
     layer = get_layer()
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     layer.set_total(15)
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_pre_build_adapt_update_numpy(self):
     """Test that preproc layers can adapt() before build() is called."""
@@ -173,8 +176,10 @@
     input_data = keras.Input(shape=(1,))
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_post_build_adapt_update_numpy(self):
     """Test that preproc layers can adapt() after build() is called."""
@@ -184,10 +189,12 @@
     layer = get_layer()
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     layer.adapt(input_dataset)
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_pre_build_injected_update(self):
     """Test external update injection before build() is called."""
@@ -203,8 +210,10 @@
     input_data = keras.Input(shape=(1,))
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_post_build_injected_update(self):
     """Test external update injection after build() is called."""
@@ -213,12 +222,14 @@
     layer = get_layer()
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     combiner = layer._combiner
     updates = combiner.extract(combiner.compute(input_dataset))
     layer._set_state_variables(updates)
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_pre_build_adapt_update_dataset(self):
     """Test that preproc layers can adapt() before build() is called."""
@@ -231,8 +242,10 @@
     input_data = keras.Input(shape=(1,))
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_post_build_adapt_update_dataset(self):
     """Test that preproc layers can adapt() after build() is called."""
@@ -243,10 +256,12 @@
     layer = get_layer()
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     layer.adapt(input_dataset)
 
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
   def test_further_tuning(self):
     """Test that models can be tuned with multiple calls to 'adapt'."""
@@ -259,10 +274,13 @@
     input_data = keras.Input(shape=(1,))
     output = layer(input_data)
     model = keras.Model(input_data, output)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
+
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
     layer.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))
 
   def test_further_tuning_post_injection(self):
     """Test that models can be tuned with multiple calls to 'adapt'."""
@@ -274,14 +292,16 @@
     input_data = keras.Input(shape=(1,))
     output = layer(input_data)
     model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     combiner = layer._combiner
     updates = combiner.extract(combiner.compute(input_dataset))
     layer._set_state_variables(updates)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
     layer.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))
 
   def test_weight_based_state_transfer(self):
     """Test that preproc layers can transfer state via get/set weights.."""
@@ -290,21 +310,25 @@
       input_data = keras.Input(shape=(1,))
       layer = get_layer()
       output = layer(input_data)
-      return (keras.Model(input_data, output), layer)
+      model = keras.Model(input_data, output)
+      model._run_eagerly = testing_utils.should_run_eagerly()
+      model._experimental_run_tf_function = (
+          testing_utils.should_run_tf_function())
+      return (model, layer)
 
     input_dataset = np.array([1, 2, 3, 4, 5])
     model, layer = get_model()
     layer.adapt(input_dataset)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
     # Create a new model and verify it has no state carryover.
     weights = model.get_weights()
     model_2, _ = get_model()
-    self.assertAllEqual([[1], [2], [3]], model_2.predict([1, 2, 3]))
+    self.assertAllEqual([[1], [2], [3]], model_2.predict([1., 2., 3.]))
 
     # Transfer state from model to model_2 via get/set weights.
     model_2.set_weights(weights)
-    self.assertAllEqual([[16], [17], [18]], model_2.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model_2.predict([1., 2., 3.]))
 
   def test_weight_based_state_transfer_with_further_tuning(self):
     """Test that transferred state can be used to further tune a model.."""
@@ -313,12 +337,16 @@
       input_data = keras.Input(shape=(1,))
       layer = get_layer()
       output = layer(input_data)
-      return (keras.Model(input_data, output), layer)
+      model = keras.Model(input_data, output)
+      model._run_eagerly = testing_utils.should_run_eagerly()
+      model._experimental_run_tf_function = (
+          testing_utils.should_run_tf_function())
+      return (model, layer)
 
     input_dataset = np.array([1, 2, 3, 4, 5])
     model, layer = get_model()
     layer.adapt(input_dataset)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1, 2, 3]))
+    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
     # Transfer state from model to model_2 via get/set weights.
     weights = model.get_weights()
@@ -327,7 +355,7 @@
 
     # Further adapt this layer based on the transferred weights.
     layer_2.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model_2.predict([1, 2, 3]))
+    self.assertAllEqual([[19], [20], [21]], model_2.predict([1., 2., 3.]))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/keras/engine/correctness_test.py b/tensorflow/python/keras/engine/correctness_test.py
index aa005aa..3f75b2b 100644
--- a/tensorflow/python/keras/engine/correctness_test.py
+++ b/tensorflow/python/keras/engine/correctness_test.py
@@ -70,7 +70,7 @@
         keras.optimizer_v2.gradient_descent.SGD(0.1),
         'mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def test_simple_bias_fit(self):
@@ -109,7 +109,7 @@
         keras.optimizer_v2.gradient_descent.SGD(0.1),
         'mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   @parameterized.named_parameters(('subclassed', True), ('functional', False))
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index bd29560..5a1fbb6 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -27,8 +27,10 @@
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import ops
+from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils import data_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
@@ -166,18 +168,38 @@
 
   @staticmethod
   def can_handle(x, y=None):
+    # TODO(kaftan): Check performance implications of using a flatten
+    #  here for other types of inputs.
     flat_inputs = nest.flatten(x)
     if y is not None:
       flat_inputs += nest.flatten(y)
 
-    return all(isinstance(v, (ops.Tensor, np.ndarray)) for v in flat_inputs)
+    def _is_tensor_or_composite(v):
+      if isinstance(v, (ops.Tensor, np.ndarray)):
+        return True
+      # Dataset inherits from CompositeTensor but shouldn't be handled here.
+      if (isinstance(v, composite_tensor.CompositeTensor) and
+          not isinstance(v, dataset_ops.DatasetV2)):
+        return True
+      return False
+
+    return all(_is_tensor_or_composite(v) for v in flat_inputs)
 
   def __init__(self, x, y=None, sample_weights=None, batch_size=None,
-               shuffle=False, **kwargs):
+               steps=None, shuffle=False, **kwargs):
     super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
     x = _process_numpy_inputs(x)
     y = _process_numpy_inputs(y)
     sample_weights = _process_numpy_inputs(sample_weights)
+
+    # If sample_weights are not specified for an output use 1.0 as weights.
+    if sample_weights is not None and None in sample_weights:
+      weight = next(s for s in sample_weights if s is not None)
+      sample_weights = training_utils.list_to_tuple([
+          array_ops.ones((weight.shape[0],)) if sw is None else sw
+          for sw in sample_weights
+      ])
+
     if y is not None and sample_weights is not None:
       inputs = (x, y, sample_weights)
     elif y is not None:
@@ -187,23 +209,25 @@
     else:
       inputs = (x,)
 
-    if not batch_size:
-      raise ValueError(
-          "`batch_size` is required for `Tensor` or `NumPy` input data.")
-
     dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs)
     num_samples = int(nest.flatten(x)[0].shape[0])
     if shuffle:
       dataset = dataset.shuffle(num_samples)
-    if batch_size:
-      dataset = dataset.batch(batch_size)
-      self._size = int(math.ceil(num_samples / batch_size))
-      self._batch_size = batch_size
-      self._has_partial_batch = (self._size != (num_samples // batch_size))
-    else:
-      self._size = 1
-      self._batch_size = num_samples
-      self._has_partial_batch = False
+
+    # If batch_size is not passed but steps is, calculate from the input data.
+    if steps and not batch_size:
+      batch_size = int(math.ceil(num_samples/steps))
+
+    if not batch_size:
+      raise ValueError(
+          "`batch_size` or `steps` is required for `Tensor` or `NumPy`"
+          " input data.")
+
+    dataset = dataset.batch(batch_size)
+    self._size = int(math.ceil(num_samples / batch_size))
+    self._batch_size = batch_size
+    self._has_partial_batch = (self._size != (num_samples // batch_size))
+
     self._partial_batch_size = None
     if self._has_partial_batch:
       self._partial_batch_size = (
@@ -227,6 +251,55 @@
     return self._partial_batch_size
 
 
+class ListsOfScalarsDataAdapter(DataAdapter):
+  """Adapter that handles lists of scalars and lists of lists of scalars."""
+
+  @staticmethod
+  def can_handle(x, y=None):
+    handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
+    handles_y = True
+    if y is not None:
+      handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
+    return handles_x and handles_y
+
+  @staticmethod
+  def _is_list_of_scalars(inp):
+    if isinstance(inp, (float, int, str)):
+      return True
+    if isinstance(inp, (list, tuple)):
+      return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
+    return False
+
+  def __init__(
+      self, x, y=None, sample_weights=None, batch_size=None,
+      shuffle=False, **kwargs):
+    super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs)
+    x = np.asarray(x)
+    if y is not None:
+      y = np.asarray(y)
+    if sample_weights is not None:
+      sample_weights = np.asarray(sample_weights)
+
+    self._internal_adapter = TensorLikeDataAdapter(
+        x, y=y, sample_weights=sample_weights,
+        batch_size=batch_size, shuffle=shuffle, **kwargs)
+
+  def get_dataset(self):
+    return self._internal_adapter.get_dataset()
+
+  def get_size(self):
+    return self._internal_adapter.get_size()
+
+  def batch_size(self):
+    return self._internal_adapter.batch_size()
+
+  def has_partial_batch(self):
+    return self._internal_adapter.has_partial_batch()
+
+  def partial_batch_size(self):
+    return self._internal_adapter.partial_batch_size()
+
+
 class DatasetAdapter(DataAdapter):
   """Adapter that handles `tf.data.Dataset`."""
 
@@ -358,6 +431,7 @@
 
 
 ALL_ADAPTER_CLS = [
+    ListsOfScalarsDataAdapter,
     TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
     KerasSequenceAdapter
 ]
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 5564e6c..8f5fe16 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
@@ -31,7 +32,7 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class DataAdapterTestBase(test.TestCase):
+class DataAdapterTestBase(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(DataAdapterTestBase, self).setUp()
@@ -83,7 +84,8 @@
     self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
 
   def test_iterator_expect_batch_size_numpy(self):
-    with self.assertRaisesRegexp(ValueError, r'`batch_size` is required'):
+    with self.assertRaisesRegexp(
+        ValueError, r'`batch_size` or `steps` is required'):
       self.adapter_cls(self.numpy_input, self.numpy_target)
 
   def test_size_numpy(self):
@@ -131,17 +133,33 @@
     self.assertEqual(adapter.get_size(), 10)
     self.assertFalse(adapter.has_partial_batch())
 
-  def test_batch_size(self):
+  @parameterized.named_parameters(
+      ('batch_size_5', 5, None, 5),
+      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
+      ('steps_1', None, 1, 50),
+      ('steps_4', None, 4, 13),
+      )
+  def test_batch_size(self, batch_size_in, steps, batch_size_out):
     adapter = self.adapter_cls(
-        self.tensor_input, self.tensor_target, batch_size=5)
-    self.assertEqual(adapter.batch_size(), 5)
+        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
+        steps=steps)
+    self.assertEqual(adapter.batch_size(), batch_size_out)
 
-  def test_partial_batch(self):
+  @parameterized.named_parameters(
+      ('batch_size_5', 5, None, 10, 0),
+      ('batch_size_4', 4, None, 13, 2),
+      ('steps_1', None, 1, 1, 0),
+      ('steps_5', None, 5, 5, 0),
+      ('steps_4', None, 4, 4, 11),
+      )
+  def test_partial_batch(
+      self, batch_size_in, steps, size, partial_batch_size):
     adapter = self.adapter_cls(
-        self.tensor_input, self.tensor_target, batch_size=4)
-    self.assertEqual(adapter.get_size(), 13)   # 50/4
-    self.assertTrue(adapter.has_partial_batch())
-    self.assertEqual(adapter.partial_batch_size(), 2)
+        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
+        steps=steps)
+    self.assertEqual(adapter.get_size(), size)   # 50/steps
+    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
+    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)
 
 
 class DatasetAdapterTest(DataAdapterTestBase):
diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py
index a151b84..f50508c 100644
--- a/tensorflow/python/keras/engine/feature_columns_integration_test.py
+++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py
@@ -49,8 +49,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_sequential_model(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     columns = [fc.numeric_column('a')]
     model = keras.models.Sequential([
         fc.DenseFeatures(columns),
@@ -62,7 +60,7 @@
         loss='categorical_crossentropy',
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = {'a': np.random.random((10, 1))}
     y = np.random.randint(20, size=(10, 1))
@@ -74,8 +72,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_sequential_model_with_ds_input(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     columns = [fc.numeric_column('a')]
     model = keras.models.Sequential([
         fc.DenseFeatures(columns),
@@ -87,7 +83,7 @@
         loss='categorical_crossentropy',
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     y = np.random.randint(20, size=(100, 1))
     y = keras.utils.to_categorical(y, num_classes=20)
@@ -141,8 +137,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_subclassed_model_with_feature_columns(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     col_a = fc.numeric_column('a')
     col_b = fc.numeric_column('b')
 
@@ -153,7 +147,7 @@
         loss='categorical_crossentropy',
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = {'a': np.random.random((10, 1)), 'b': np.random.random((10, 1))}
     y = np.random.randint(20, size=(10, 1))
@@ -165,8 +159,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_subclassed_model_with_feature_columns_with_ds_input(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     col_a = fc.numeric_column('a')
     col_b = fc.numeric_column('b')
 
@@ -177,7 +169,7 @@
         loss='categorical_crossentropy',
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     y = np.random.randint(20, size=(100, 1))
     y = keras.utils.to_categorical(y, num_classes=20)
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 2440448..82c2e2d 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -102,6 +102,7 @@
     super(InputLayer, self).__init__(dtype=dtype, name=name)
     self.built = True
     self.sparse = sparse
+    self.ragged = ragged
     self.batch_size = batch_size
     self.supports_masking = True
 
@@ -152,6 +153,7 @@
         'batch_input_shape': self._batch_input_shape,
         'dtype': self.dtype,
         'sparse': self.sparse,
+        'ragged': self.ragged,
         'name': self.name
     }
     return config
@@ -204,7 +206,8 @@
           values of 'None' in the 'shape' argument represent ragged dimensions.
           For more information about RaggedTensors, see
           https://www.tensorflow.org/guide/ragged_tensors.
-      **kwargs: deprecated arguments support.
+      **kwargs: deprecated arguments support. Supports `batch_shape` and
+          `batch_input_shape`.
 
   Returns:
     A `tensor`.
@@ -235,15 +238,21 @@
     raise ValueError(
         'Cannot set both sparse and ragged to True in a Keras input.')
 
-  batch_shape = None
-  if 'batch_shape' in kwargs:
-    batch_shape = kwargs.pop('batch_shape')
-    if shape and batch_shape:
-      raise ValueError('Only provide the shape OR '
-                       'batch_shape argument to '
-                       'Input, not both at the same time.')
-    batch_size = batch_shape[0]
-    shape = batch_shape[1:]
+  input_layer_config = {'name': name, 'dtype': dtype, 'sparse': sparse,
+                        'ragged': ragged, 'input_tensor': tensor}
+
+  batch_input_shape = kwargs.pop('batch_input_shape',
+                                 kwargs.pop('batch_shape', None))
+  if shape and batch_input_shape:
+    raise ValueError('Only provide the `shape` OR `batch_input_shape` argument '
+                     'to Input, not both at the same time.')
+  if batch_input_shape:
+    shape = batch_input_shape[1:]
+    input_layer_config.update({'batch_input_shape': batch_input_shape})
+  else:
+    input_layer_config.update(
+        {'batch_size': batch_size, 'input_shape': shape})
+
   if kwargs:
     raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
 
@@ -253,23 +262,7 @@
                      '`shape` does not include the batch '
                      'dimension.')
 
-  if batch_shape:
-    input_layer = InputLayer(
-        batch_input_shape=batch_shape,
-        name=name,
-        dtype=dtype,
-        sparse=sparse,
-        ragged=ragged,
-        input_tensor=tensor)
-  else:
-    input_layer = InputLayer(
-        input_shape=shape,
-        batch_size=batch_size,
-        name=name,
-        dtype=dtype,
-        sparse=sparse,
-        ragged=ragged,
-        input_tensor=tensor)
+  input_layer = InputLayer(**input_layer_config)
 
   # Return tensor including `_keras_history`.
   # Note that in this case train_output and test_output are the same pointer.
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 9569bf7..ff5a479 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -26,6 +26,7 @@
 import os
 import threading
 
+import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python import pywrap_tensorflow
@@ -41,7 +42,6 @@
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import tf_utils
@@ -54,6 +54,7 @@
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import serialization
 from tensorflow.python.util import tf_inspect
 
@@ -189,7 +190,8 @@
     # self.losses
     # self.updates
 
-    generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic'})
+    generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
+                                           'autocast'})
 
     # Object to store all thread local layer properties.
     self._thread_local = threading.local()
@@ -227,8 +229,14 @@
       self._graph = None
     else:
       self._graph = ops.get_default_graph()  # Used in symbolic mode only.
-      # A Network does not create weights of its own, thus has no dtype.
-    self._dtype = kwargs.get('dtype', None)
+
+    # Both graph and subclassed networks have a dtype policy. The policy is
+    # currently ignored for a graph network, as graph networks disable
+    # autocasting (making the policy's compute dtype meaningless) and graph
+    # networks have no variables (making the policy's variable_dtype
+    # meaningless). For subclassed networks, the dtype policy acts as it does
+    # for any ordinary layer.
+    self._set_dtype_policy(kwargs.get('dtype', None))
 
     # All layers in order of horizontal graph traversal.
     # Entries are unique. Includes input and output layers.
@@ -241,12 +249,6 @@
     self._trackable_saver = (
         trackable_utils.saver_with_op_caching(self))
 
-    # Networks do not need to do any casting of inputs or variables, because
-    # each of its layers will handle casting through the layer's own
-    # implementation. Therefore networks use the 'infer' policy, which does no
-    # casting.
-    self._mixed_precision_policy = policy.Policy('infer')
-
   @trackable.no_automatic_dependency_tracking
   def _init_graph_network(self, inputs, outputs, name=None, **kwargs):
     generic_utils.validate_kwargs(
@@ -278,6 +280,9 @@
     # present in the signature of the `call` method of a graph network.
     self._expects_training_arg = True
     self._expects_mask_arg = True
+    # A graph network does not autocast inputs, as its layers will cast them
+    # instead.
+    self._autocast = False
 
     self._input_layers = []
     self._output_layers = []
@@ -310,12 +315,11 @@
       self._input_coordinates.append((layer, node_index, tensor_index))
 
     # Keep track of the network's nodes and layers.
-    nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
+    nodes, nodes_by_depth, layers, _ = _map_graph_network(
         self.inputs, self.outputs)
     self._network_nodes = nodes
     self._nodes_by_depth = nodes_by_depth
     self._layers = layers
-    self._layers_by_depth = layers_by_depth
     self._layer_call_argspecs = {}
     for layer in self._layers:
       self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
@@ -372,10 +376,9 @@
   def _init_subclassed_network(self, name=None, **kwargs):
     self._base_init(name=name, **kwargs)
     self._is_graph_network = False
-    self._expects_training_arg = ('training' in self._call_fn_args or
-                                  self._call_accepts_kwargs)
-    self._expects_mask_arg = ('mask' in self._call_fn_args or
-                              self._call_accepts_kwargs)
+    self._init_call_fn_args()
+    self._autocast = kwargs.get('autocast',
+                                base_layer_utils.v2_dtype_behavior_enabled())
     self.outputs = []
     self.inputs = []
     self.built = False
@@ -869,11 +872,7 @@
     }
     node_conversion_map = {}
     for layer in self.layers:
-      if issubclass(layer.__class__, Network) and layer._is_graph_network:
-        # Networks start with a pre-existing node linking their input to output.
-        kept_nodes = 1
-      else:
-        kept_nodes = 0
+      kept_nodes = 1 if _should_skip_first_node(layer) else 0
       for original_node_index, node in enumerate(layer._inbound_nodes):
         node_key = _make_node_key(layer.name, original_node_index)
         if node_key in self._network_nodes:
@@ -891,9 +890,9 @@
           # The node is relevant to the model:
           # add to filtered_inbound_nodes.
           if node.arguments:
+            kwargs = _serialize_tensors(node.arguments)
             try:
-              json.dumps(node.arguments)
-              kwargs = node.arguments
+              json.dumps(kwargs)
             except TypeError:
               logging.warning(
                   'Layer ' + layer.name +
@@ -977,9 +976,8 @@
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
-    # Layer instances created during
-    # the graph reconstruction process
-    created_layers = {}
+    # Layer instances created during the graph reconstruction process.
+    created_layers = collections.OrderedDict()
 
     # Dictionary mapping layer instances to
     # node data that specifies a layer call.
@@ -1014,6 +1012,7 @@
           kwargs = {}
         elif len(input_data) == 4:
           kwargs = input_data[3]
+          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
         else:
           raise ValueError('Improperly formatted model config.')
 
@@ -1110,14 +1109,20 @@
         layer for layer in created_layers.values() if layer not in model.layers
     ]
     if ancillary_layers:
-      model._insert_layers(ancillary_layers)
+      relevant_nodes = nest.flatten([
+          layer.inbound_nodes[1:]
+          if _should_skip_first_node(layer) else layer.inbound_nodes
+          for layer in created_layers.values()
+      ])
+      model._insert_layers(ancillary_layers, relevant_nodes)
     return model
 
   def save(self,
            filepath,
            overwrite=True,
            include_optimizer=True,
-           save_format=None):
+           save_format=None,
+           signatures=None):
     """Saves the model to Tensorflow SavedModel or a single HDF5 file.
 
     The savefile includes:
@@ -1143,6 +1148,9 @@
           to Tensorflow SavedModel or HDF5. The default is currently 'h5', but
           will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently
           disabled (use `tf.keras.experimental.export_saved_model` instead).
+      signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
+        format only. Please see the `signatures` argument in
+        `tf.saved_model.save` for details.
 
     Example:
 
@@ -1157,7 +1165,8 @@
     model = load_model('my_model.h5')
     ```
     """
-    saving.save_model(self, filepath, overwrite, include_optimizer, save_format)
+    saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
+                      signatures)
 
   def save_weights(self, filepath, overwrite=True, save_format=None):
     """Saves all layer weights.
@@ -1452,7 +1461,7 @@
   def _validate_graph_inputs_and_outputs(self):
     """Validates the inputs and outputs of a Graph Network."""
     # Check for redundancy in inputs.
-    if len(set(self.inputs)) != len(self.inputs):
+    if len(object_identity.ObjectIdentitySet(self.inputs)) != len(self.inputs):
       raise ValueError('The list of inputs passed to the model '
                        'is redundant. '
                        'All inputs should only appear once.'
@@ -1538,7 +1547,7 @@
     def _get_min_depth(node):
       """Gets the minimum depth at which node can be computed."""
       min_depth = 0
-      for layer, node_id, _, _ in node.iterate_inbound():
+      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
         inbound_node = layer._inbound_nodes[node_id]
         if inbound_node in node_to_depth:
           min_depth = min(min_depth, node_to_depth[inbound_node])
@@ -1563,26 +1572,23 @@
 
       node = unprocessed_nodes.pop(0)
       depth = _get_min_depth(node)
-      if depth is None:
+      if depth is None:  # Defer until inbound nodes are processed.
         unprocessed_nodes.append(node)
-      else:
-        node_key = _make_node_key(
-            node.outbound_layer.name,
-            node.outbound_layer._inbound_nodes.index(node))
+        continue
+      node_key = _make_node_key(node.outbound_layer.name,
+                                node.outbound_layer._inbound_nodes.index(node))
+      if node_key not in self._network_nodes:
         node_to_depth[node] = depth
         self._network_nodes.add(node_key)
         self._nodes_by_depth[depth].append(node)
 
-    # Insert layers into `_layer_by_depth` and other layer attrs.
+    # Insert layers and update other layer attrs.
+    layer_set = set(self._layers)
     for layer in layers:
-      depth = min([
-          node_to_depth[node]
-          for node in layer.inbound_nodes
-          if node in network_nodes
-      ])
-      self._layers_by_depth[depth].append(layer)
-      self._layers.append(layer)
-      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
+      if layer not in layer_set:
+        self._layers.append(layer)
+        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
+        layer_set.add(layer)
 
   def _assert_weights_created(self):
     """Asserts that all the weights for the network have been created.
@@ -1615,20 +1621,22 @@
     return '_tf_keras_network'
 
   def _graph_network_add_loss(self, symbolic_loss):
-    new_layers = _diff_layers(self.inputs, [symbolic_loss], self._layers)
+    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
     # Losses must be keyed on inputs no matter what in order to be supported in
     # DistributionStrategy.
     add_loss_layer = base_layer.AddLoss(unconditional=False)
     add_loss_layer(symbolic_loss)
+    new_nodes.extend(add_loss_layer.inbound_nodes)
     new_layers.append(add_loss_layer)
-    self._insert_layers(new_layers)
+    self._insert_layers(new_layers, new_nodes)
 
   def _graph_network_add_metric(self, value, aggregation, name):
-    new_layers = _diff_layers(self.inputs, [value], self._layers)
+    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
     add_metric_layer = base_layer.AddMetric(aggregation, name)
     add_metric_layer(value)
+    new_nodes.extend(add_metric_layer.inbound_nodes)
     new_layers.append(add_metric_layer)
-    self._insert_layers(new_layers)
+    self._insert_layers(new_layers, new_nodes)
 
 
 def _is_hdf5_filepath(filepath):
@@ -1712,7 +1720,8 @@
     nodes_in_progress.add(node)
 
     # Propagate to all previous tensors connected to this node.
-    for layer, node_index, tensor_index, tensor in node.iterate_inbound():
+    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
+        include_arguments=True):
       build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                 tensor_index)
 
@@ -1775,7 +1784,7 @@
   depth_keys = list(layers_by_depth.keys())
   depth_keys.sort(reverse=True)
 
-  # Set self.layers and self._layers_by_depth.
+  # Set self.layers ordered by depth.
   layers = []
   for depth in depth_keys:
     layers_for_depth = layers_by_depth[depth]
@@ -1791,9 +1800,9 @@
   # Check that all tensors required are computable.
   # computable_tensors: all tensors in the graph
   # that can be computed from the inputs provided.
-  computable_tensors = []
+  computable_tensors = object_identity.ObjectIdentitySet()
   for x in inputs:
-    computable_tensors.append(x)
+    computable_tensors.add(x)
 
   layers_with_complete_input = []  # To provide a better error msg.
   for depth in depth_keys:
@@ -1809,7 +1818,7 @@
                              'were accessed without issue: ' +
                              str(layers_with_complete_input))
         for x in nest.flatten(node.output_tensors):
-          computable_tensors.append(x)
+          computable_tensors.add(x)
         layers_with_complete_input.append(layer.name)
 
   # Ensure name unicity, which will be crucial for serialization
@@ -1823,18 +1832,63 @@
   return network_nodes, nodes_by_depth, layers, layers_by_depth
 
 
-def _diff_layers(inputs, outputs, layers):
-  """Returns the layers in the network topology minus those in `layers`.
+def _map_subgraph_network(inputs, outputs):
+  """Returns the nodes and layers in the topology from `inputs` to `outputs`.
 
   Args:
     inputs: List of input tensors.
     outputs: List of output tensors.
-    layers: List of layers.
 
   Returns:
-    List of layers in the network topology not in `layers`.
+    A tuple of List{Node] and List[Layer].
   """
   base_layer_utils.create_keras_history(outputs)
-  # List of all layers in the topology betweeen inputs and outputs.
-  all_layers = _map_graph_network(inputs, outputs)[2]
-  return [layer for layer in all_layers if layer not in layers]
+  # Keep only nodes and layers in the topology betweeen inputs and outputs.
+  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
+  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
+
+
+def _should_skip_first_node(layer):
+  """Returns True if the first layer node should not be saved or loaded."""
+  # Networks start with a pre-existing node linking their input to output.
+  return issubclass(layer.__class__, Network) and layer._is_graph_network
+
+
+def _serialize_tensors(kwargs):
+  """Serializes Tensors passed to `call`."""
+
+  def _serialize_keras_tensor(t):
+    """Serializes a single Tensor passed to `call`."""
+    if hasattr(t, '_keras_history'):
+      kh = t._keras_history
+      return [kh.layer.name, kh.node_index, kh.tensor_index]
+
+    if isinstance(t, np.ndarray):
+      return t.tolist()
+
+    if isinstance(t, ops.Tensor):
+      return backend.get_value(t).tolist()
+
+    return t
+
+  return nest.map_structure(_serialize_keras_tensor, kwargs)
+
+
+def _deserialize_keras_tensors(kwargs, layer_map):
+  """Deserializes Keras Tensors passed to `call`.."""
+
+  def _deserialize_keras_tensor(t):
+    """Deserializes a single Keras Tensor passed to `call`."""
+    if isinstance(t, tf_utils.ListWrapper):
+      t = t.as_list()
+      layer_name = t[0]
+      node_index = t[1]
+      tensor_index = t[2]
+
+      layer = layer_map[layer_name]
+      node = layer._inbound_nodes[node_index]
+      return nest.flatten(node.output_tensors)[tensor_index]
+    return t
+
+  kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
+  return nest.map_structure(_deserialize_keras_tensor, kwargs)
diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py
index 53a2df6..5044ad3 100644
--- a/tensorflow/python/keras/engine/network_test.py
+++ b/tensorflow/python/keras/engine/network_test.py
@@ -810,7 +810,7 @@
     output = a(b(a(b(x))))
     m = keras.models.Model(x, output)
     m.run_eagerly = testing_utils.should_run_eagerly()
-    m._run_distributed = testing_utils.should_run_distributed()
+    m._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     output_val = m.predict(x_val)
 
@@ -838,7 +838,7 @@
 
     m = keras.models.Model(inputs=input_layer, outputs=output)
     m.run_eagerly = testing_utils.should_run_eagerly()
-    m._run_distributed = testing_utils.should_run_distributed()
+    m._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x_val = np.random.random((10, 16, 9, 3))
     output_val = m.predict(x_val)
@@ -868,7 +868,7 @@
         optimizer='sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     loss = model.train_on_batch(x, y)
     self.assertEqual(loss, 0)  # In inference mode, output is equal to input.
 
@@ -888,8 +888,26 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed()
-    )
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    history = model.fit(
+        x=[np.ones((10, 5, 10)), np.zeros((10, 5))],
+        y=np.zeros((10, 100)),
+        batch_size=2)
+    # All data is masked, returned values are 0's.
+    self.assertEqual(history.history['loss'][0], 0.0)
+    history = model.fit(
+        x=[np.ones((10, 5, 10)), np.ones((10, 5))],
+        y=np.zeros((10, 100)),
+        batch_size=2)
+    # Data is not masked, returned values are random.
+    self.assertGreater(history.history['loss'][0], 0.0)
+
+    model = keras.Model.from_config(model.get_config())
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(
         x=[np.ones((10, 5, 10)), np.zeros((10, 5))],
         y=np.zeros((10, 100)),
@@ -919,7 +937,22 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    history = model.fit(
+        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    # Check serialization.
+    model = keras.Model.from_config(
+        model.get_config(), custom_objects={'MyAdd': MyAdd})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(
         x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
         y=10 * np.ones((10, 10)),
@@ -945,7 +978,21 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    history = model.fit(
+        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    model = keras.Model.from_config(
+        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(
         x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
         y=10 * np.ones((10, 10)),
@@ -981,7 +1028,21 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    history = model.fit(
+        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
+        y=15 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that all inputs were correctly added.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    model = keras.Model.from_config(
+        model.get_config(), custom_objects={'AddAll': AddAll})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(
         x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
         y=15 * np.ones((10, 10)),
@@ -1006,13 +1067,14 @@
     o = keras.layers.add(o)
     model = keras.Model(i, o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     i2 = keras.layers.Input(shape=(3, 2, 1))
     o2 = model(i2)
     model2 = keras.Model(i2, o2)
     model2.run_eagerly = testing_utils.should_run_eagerly()
-    model2._run_distributed = testing_utils.should_run_distributed()
+    model2._experimental_run_tf_function = testing_utils.should_run_tf_function(
+    )
 
     x = np.random.random((4, 3, 2, 1))
     out = model2.predict(x)
@@ -1031,7 +1093,7 @@
         optimizer='sgd',
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     json_str = model.to_json()
     keras.models.model_from_json(json_str)
@@ -1331,7 +1393,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model_input = np.random.randint(
         low=1, high=5, size=(10, 3, 4)).astype('float32')
@@ -1516,14 +1578,14 @@
     model.compile(
         'sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, batch_size=2, epochs=1)
 
     model2 = model.from_config(model.get_config())
     model2.compile(
         'sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model2.set_weights(initial_weights)
     model2.fit(x, batch_size=2, epochs=1)
 
@@ -1548,7 +1610,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=2, epochs=1)
 
     model2 = model.from_config(model.get_config())
@@ -1556,7 +1618,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model2.set_weights(initial_weights)
     model2.fit(x, y, batch_size=2, epochs=1)
 
@@ -1625,5 +1687,36 @@
     self.assertEqual(len(model.weights), 1)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class DTypeTest(keras_parameterized.TestCase):
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_graph_network_dtype(self):
+    inputs = keras.Input((10,))
+    outputs = keras.layers.Dense(10)(inputs)
+    network = network_lib.Network(inputs, outputs)
+    self.assertEqual(network.dtype, 'float32')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_subclassed_network_dtype(self):
+
+    class IdentityNetwork(network_lib.Network):
+
+      def call(self, inputs):
+        return inputs
+
+    network = IdentityNetwork()
+    self.assertEqual(network.dtype, 'float32')
+    self.assertEqual(network(array_ops.constant(1, 'float64')).dtype, 'float32')
+
+    network = IdentityNetwork(dtype='float16')
+    self.assertEqual(network.dtype, 'float16')
+    self.assertEqual(network(array_ops.constant(1, 'float64')).dtype, 'float16')
+
+    network = IdentityNetwork(autocast=False)
+    self.assertEqual(network.dtype, 'float32')
+    self.assertEqual(network(array_ops.constant(1, 'float64')).dtype, 'float64')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py
index 9a7ecb7..4e00507 100644
--- a/tensorflow/python/keras/engine/node.py
+++ b/tensorflow/python/keras/engine/node.py
@@ -20,6 +20,7 @@
 
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.util import nest
 
 
@@ -111,6 +112,15 @@
     # Optional keyword arguments to layer's `call`.
     self.arguments = arguments
 
+    # Create Keras History for any Keras Tensors in `arguments`.
+    tensor_arguments = [
+        t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
+    ]
+    for tensor_argument in tensor_arguments:
+      if base_layer_utils.needs_keras_history(
+          tensor_argument, ignore_call_context=True):
+        base_layer_utils.create_keras_history(tensor_argument)
+
     # Add nodes to all layers involved.
     for layer in nest.flatten(inbound_layers):
       if layer is not None:
@@ -121,15 +131,39 @@
     # accessor here.
     outbound_layer.inbound_nodes.append(self)
 
-  def iterate_inbound(self):
+  def iterate_inbound(self, include_arguments=False):
     """Returns a list of tuples representing the inbound data.
 
+    Arguments:
+      include_arguments: Whether to also iterate over any Keras Tensors
+        passed as args, kwargs.
+
     Returns:
       List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
     """
-    return zip(
-        nest.flatten(self.inbound_layers), nest.flatten(self.node_indices),
-        nest.flatten(self.tensor_indices), nest.flatten(self.input_tensors))
+    inputs_inbound = list(
+        zip(
+            nest.flatten(self.inbound_layers),
+            nest.flatten(self.node_indices),
+            nest.flatten(self.tensor_indices),
+            nest.flatten(self.input_tensors)))
+
+    if include_arguments:
+      keras_tensor_arguments = [
+          kt for kt in nest.flatten(self.arguments)
+          if hasattr(kt, '_keras_history')
+      ]
+
+      def _get_inbound(keras_tensor):
+        kh = keras_tensor._keras_history
+        return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
+
+      arguments_inbound = nest.map_structure(_get_inbound,
+                                             keras_tensor_arguments)
+
+      return inputs_inbound + arguments_inbound
+    else:
+      return inputs_inbound
 
   def _get_all_node_dependencies(self):
     """Returns all of the nodes this node immediately depends on."""
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index e06a895..c1593c4 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -78,7 +78,7 @@
         loss='mse',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.random.random((batch_size, input_dim))
     y = np.random.random((batch_size, num_classes))
     model.fit(x, y, epochs=1)
@@ -89,7 +89,7 @@
         loss='mse',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     y = np.random.random((batch_size, num_hidden))
     model.fit(x, y, epochs=1)
 
@@ -118,7 +118,7 @@
         optimizer='rmsprop',
         metrics=[keras.metrics.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.layers), 2)
     with self.assertRaisesRegexp(
         ValueError, 'Weights for model .* have not yet been created'):
@@ -146,7 +146,7 @@
         optimizer='rmsprop',
         metrics=[keras.metrics.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.layers), 2)
     with self.assertRaisesRegexp(
         ValueError, 'Weights for model .* have not yet been created'):
@@ -295,7 +295,7 @@
         optimizer='rmsprop',
         metrics=[keras.metrics.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertFalse(model.built)
 
     x = np.random.random((batch_size, input_dim))
@@ -344,7 +344,7 @@
         'rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.outputs), 0)
     model.train_on_batch(np.zeros((1, 2)), np.zeros((1, 5)))
     self.assertEqual(len(model.outputs), 1)
@@ -359,7 +359,7 @@
         loss='mse',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.random.random((2, 6))
     y = np.random.random((2, 5))
     model.fit(x, y, epochs=1)
@@ -385,14 +385,12 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_string_input(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     seq = keras.Sequential([
         keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
         keras.layers.Lambda(lambda x: x[0])
     ])
     seq.run_eagerly = testing_utils.should_run_eagerly()
-    seq._run_distributed = testing_utils.should_run_distributed()
+    seq._experimental_run_tf_function = testing_utils.should_run_tf_function()
     preds = seq.predict([['tensorflow eager']])
     self.assertEqual(preds.shape, (1,))
 
@@ -472,7 +470,7 @@
         loss='mse',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.random.random((2, 6))
     y = np.random.random((2, 5))
@@ -486,7 +484,7 @@
         loss='mse',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.build((None, 6))
 
@@ -505,7 +503,7 @@
         weighted_metrics=['mae'],
         loss='categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = {'dense_input': np.random.random((10, 1))}
     y = np.random.randint(num_classes, size=(10, 1))
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ee898f8..9c88d29 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -26,7 +26,6 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import monitoring
@@ -65,6 +64,7 @@
 from tensorflow.python.util import nest
 from tensorflow.python.util import serialization
 from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import keras_export
 
 try:
@@ -143,9 +143,11 @@
 
   def __init__(self, *args, **kwargs):
     super(Model, self).__init__(*args, **kwargs)
+    _keras_api_gauge.get_cell('model').set(True)
     # initializing _distribution_strategy here since it is possible to call
     # predict on a model without compiling it.
     self._distribution_strategy = None
+    self._compile_time_distribution_strategy = None
 
     # This flag is used to track if the user is using the deprecated path of
     # passing distribution strategy to compile rather than creating the model
@@ -153,7 +155,7 @@
     self._compile_distribution = False
 
     self._run_eagerly = None
-    self._run_distributed = False
+    self._experimental_run_tf_function = False
 
   def get_weights(self):
     """Retrieves the weights of the model.
@@ -161,8 +163,10 @@
     Returns:
         A flat list of Numpy arrays.
     """
-    if self._distribution_strategy:
-      with self._distribution_strategy.scope():
+    strategy = (self._distribution_strategy or
+                self._compile_time_distribution_strategy)
+    if strategy:
+      with strategy.scope():
         return super(Model, self).get_weights()
     return super(Model, self).get_weights()
 
@@ -177,7 +181,7 @@
 
   @trackable.no_automatic_dependency_tracking
   def compile(self,
-              optimizer,
+              optimizer='rmsprop',
               loss=None,
               metrics=None,
               loss_weights=None,
@@ -239,19 +243,31 @@
         ValueError: In case of invalid arguments for
             `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
     """
-    _keras_api_gauge.get_cell('compile').set(True)
     self._run_eagerly = kwargs.pop('run_eagerly', None)
-    self._run_distributed = kwargs.pop('run_distributed', False)
+    self._experimental_run_tf_function = kwargs.pop(
+        'experimental_run_tf_function', True)
+
+    if isinstance(optimizer, (list, tuple)):
+      self.optimizer = [optimizers.get(opt) for opt in optimizer]
+      is_any_optimizer_v1 = any(
+          isinstance(opt, optimizers.Optimizer) for opt in self.optimizer)
+    else:
+      self.optimizer = optimizers.get(optimizer)
+      is_any_optimizer_v1 = isinstance(self.optimizer, optimizers.Optimizer)
 
     if ((sample_weight_mode is not None)
         or (target_tensors is not None)
         or (weighted_metrics is not None)
-        or not context.executing_eagerly()):
+        or is_any_optimizer_v1
+        or not ops.executing_eagerly_outside_functions()):
       # Fallback out of things that aren't supported with v2 loops
-      self._run_distributed = False
+      self._experimental_run_tf_function = False
+
+    self._compile_time_distribution_strategy = (
+        distribution_strategy_context.get_strategy())
 
     if distribute is not None:
-      if tf2.enabled() or self._run_distributed:
+      if tf2.enabled() or self._experimental_run_tf_function:
         raise ValueError(
             'Distribute argument in compile is not available in TF 2.0 please '
             'create the model under the distribution strategy scope.')
@@ -269,12 +285,11 @@
           self._distribution_strategy = (
               distribution_strategy_context.get_strategy())
 
-    if not self._run_distributed:
+    if not self._experimental_run_tf_function:
       self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
                                                              sample_weight_mode,
                                                              target_tensors,
                                                              weighted_metrics)
-    self.optimizer = optimizers.get(optimizer)
     # We've disabled automatic dependency tracking for this method, but do want
     # to add a checkpoint dependency on the optimizer if it's trackable.
     if isinstance(self.optimizer, trackable.Trackable):
@@ -301,6 +316,9 @@
     self._distributed_model_cache = {}
     self._distributed_function_cache = {}
 
+    # Clear any `_eager_losses` that was added.
+    self._clear_losses()
+
     if (not context.executing_eagerly() and
         self._distribution_strategy is not None):
       # Ensures a Session is created and configured correctly for Distribution
@@ -314,6 +332,7 @@
       # time the model gets called on training data.
       return
     self._is_compiled = True
+    _keras_api_gauge.get_cell('compile').set(True)
 
     # Prepare list of loss functions, same size of model outputs.
     self.loss_functions = training_utils.prepare_loss_functions(
@@ -479,9 +498,7 @@
 
     # Experiment training loop with default DS path.
     if (context.executing_eagerly()
-        and self._run_distributed
-        # TODO(scottzhu): Finish getting sequences working with the v2 loops.
-        and not isinstance(inputs, (data_utils.Sequence))
+        and self._experimental_run_tf_function
         and not distributed_training_utils.is_tpu_strategy(
             self._distribution_strategy)):
       try:
@@ -491,12 +508,17 @@
         logging.warning('Falling back from v2 loop because of error: '
                         '%s' % data_failure_exception)
       if valid_adapter:
-        return training_v2.Loop()
+        if self._in_multi_worker_mode():
+          return training_distributed.DistributionMultiWorkerTrainingLoop(
+              training_v2.Loop())
+        else:
+          return training_v2.Loop()
 
     # Case 1: distribution strategy.
     if self._distribution_strategy:
-      if multi_worker_util.in_multi_worker_mode():
-        return training_distributed.DistributionMultiWorkerTrainingLoop()
+      if self._in_multi_worker_mode():
+        return training_distributed.DistributionMultiWorkerTrainingLoop(
+            training_distributed.DistributionSingleWorkerTrainingLoop())
       else:
         return training_distributed.DistributionSingleWorkerTrainingLoop()
 
@@ -644,9 +666,9 @@
             and 'validation_steps' is None, validation
             will run until the `validation_data` dataset is exhausted.
         validation_freq: Only relevant if validation data is provided. Integer
-            or `collections.Container` instance (e.g. list, tuple, etc.). If an
-            integer, specifies how many training epochs to run before a new
-            validation run is performed, e.g. `validation_freq=2` runs
+            or `collections_abc.Container` instance (e.g. list, tuple, etc.).
+            If an integer, specifies how many training epochs to run before a
+            new validation run is performed, e.g. `validation_freq=2` runs
             validation every 2 epochs. If a Container, specifies the epochs on
             which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
             validation at the end of the 1st, 2nd, and 10th epochs.
@@ -677,7 +699,7 @@
         ValueError: In case of mismatch between the provided input data
             and what the model expects.
     """
-    _keras_api_gauge.get_cell('train').set(True)
+    _keras_api_gauge.get_cell('fit').set(True)
     # Legacy support
     if 'nb_epoch' in kwargs:
       logging.warning(
@@ -948,18 +970,20 @@
     Raises:
       ValueError: In case of invalid user-provided arguments.
     """
-    if self._run_distributed:
+    self._assert_compile_was_called()
+    self._check_call_args('train_on_batch')
+    if self._experimental_run_tf_function:
       outputs = training_v2_utils.train_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           class_weight=class_weight, reset_metrics=reset_metrics)
+      outputs = (outputs['total_loss'] + outputs['output_losses'] +
+                 outputs['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
       if len(outputs) == 1:
         outputs = outputs[0]
       return outputs
 
-    self._assert_compile_was_called()
-    self._check_call_args('train_on_batch')
     # If at this point we are in the replica context, then it is okay to execute
     # the Eager code path.  The expected way to get here is to call `fit` that
     # calls `train_on_batch` on each replica.
@@ -977,12 +1001,14 @@
     # for each replica by `self._distribution_strategy` and the same code path
     # as Eager is expected to be taken.
     if self.run_eagerly or self._distribution_strategy:
-      outputs = training_eager.train_on_batch(
+      output_dict = training_eager.train_on_batch(
           self,
           x,
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = (output_dict['total_loss'] + output_dict['output_losses']
+                 + output_dict['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
@@ -1041,18 +1067,20 @@
     Raises:
         ValueError: In case of invalid user-provided arguments.
     """
-    if self._run_distributed:
+    self._assert_compile_was_called()
+    self._check_call_args('test_on_batch')
+    if self._experimental_run_tf_function:
       outputs = training_v2_utils.test_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           reset_metrics=reset_metrics)
+      outputs = (outputs['total_loss'] + outputs['output_losses'] +
+                 outputs['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
       if len(outputs) == 1:
         outputs = outputs[0]
       return outputs
 
-    self._assert_compile_was_called()
-    self._check_call_args('test_on_batch')
     if (self._distribution_strategy and
         distribution_strategy_context.in_cross_replica_context()):
       raise NotImplementedError('`test_on_batch` is not supported for models '
@@ -1064,12 +1092,14 @@
     # If `self._distribution_strategy` is True, then we are in a replica context
     # at this point.
     if self.run_eagerly or self._distribution_strategy:
-      outputs = training_eager.test_on_batch(
+      output_dict = training_eager.test_on_batch(
           self,
           x,
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = (output_dict['total_loss'] + output_dict['output_losses']
+                 + output_dict['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
@@ -1106,7 +1136,7 @@
           expectations of the model.
     """
     self._check_call_args('predict_on_batch')
-    if self._run_distributed:
+    if self._experimental_run_tf_function:
       return training_v2_utils.predict_on_batch(self, x)
 
     if (self._distribution_strategy and
@@ -1121,7 +1151,7 @@
     # at this point.
     if self.run_eagerly or self._distribution_strategy:
       inputs = training_utils.cast_if_floating_dtype(inputs)
-      if isinstance(inputs, collections.Sequence):
+      if isinstance(inputs, collections_abc.Sequence):
         # Unwrap lists with only one input, as we do when training on batch
         if len(inputs) == 1:
           inputs = inputs[0]
@@ -1199,9 +1229,9 @@
             Optional for `Sequence`: if unspecified, will use
             the `len(validation_data)` as a number of steps.
         validation_freq: Only relevant if validation data is provided. Integer
-            or `collections.Container` instance (e.g. list, tuple, etc.). If an
-            integer, specifies how many training epochs to run before a new
-            validation run is performed, e.g. `validation_freq=2` runs
+            or `collections_abc.Container` instance (e.g. list, tuple, etc.).
+            If an integer, specifies how many training epochs to run before a
+            new validation run is performed, e.g. `validation_freq=2` runs
             validation every 2 epochs. If a Container, specifies the epochs on
             which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
             validation at the end of the 1st, 2nd, and 10th epochs.
@@ -1251,7 +1281,7 @@
     if self._distribution_strategy:
       raise NotImplementedError('`fit_generator` is not supported for '
                                 'models compiled with tf.distribute.Strategy.')
-    _keras_api_gauge.get_cell('train').set(True)
+    _keras_api_gauge.get_cell('fit_generator').set(True)
     self._check_call_args('fit_generator')
     return training_generator.fit_generator(
         self,
@@ -1325,8 +1355,9 @@
     if self._distribution_strategy:
       raise NotImplementedError('`evaluate_generator` is not supported for '
                                 'models compiled with tf.distribute.Strategy.')
-    _keras_api_gauge.get_cell('evaluate').set(True)
+    _keras_api_gauge.get_cell('evaluate_generator').set(True)
     self._check_call_args('evaluate_generator')
+
     return training_generator.evaluate_generator(
         self,
         generator,
@@ -1383,8 +1414,7 @@
     if self._distribution_strategy:
       raise NotImplementedError('`predict_generator` is not supported for '
                                 'models compiled with tf.distribute.Strategy.')
-    _keras_api_gauge.get_cell('predict').set(True)
-    self._check_call_args('predict_generator')
+    _keras_api_gauge.get_cell('predict_generator').set(True)
     return training_generator.predict_generator(
         self,
         generator,
@@ -2023,6 +2053,9 @@
   def _make_train_function(self):
     has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
     self._check_trainable_weights_consistency()
+    if isinstance(self.optimizer, list):
+      raise ValueError('The `optimizer` in `compile` should be a single '
+                       'optimizer.')
     # If we have re-compiled the loss/weighted metric sub-graphs then create
     # train function even if one exists already. This is because
     # `_feed_sample_weights` list has been updated on re-copmpile.
@@ -2342,135 +2375,24 @@
     if check_steps:
       training_utils.check_steps_argument(x, steps, steps_name)
 
-    # First, we build/compile the model on the fly if necessary.
-    all_inputs = []
-    is_build_called = False
-    is_compile_called = False
-    # Whether this is a subclassed model that expects dictionary inputs
-    # rather than list inputs (e.g. FeatureColumn-based models).
-    dict_inputs = False
-
+    # First, we build the model on the fly if necessary.
     if not self.inputs:
-      # We need to use `x_input` to set the model inputs.
-
-      # If input data is a dataset iterator in graph mode or if it is an eager
-      # iterator and only one batch of samples is required, we fetch the data
-      # tensors from the iterator and then standardize them.
-      if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
-        x_input, y_input, _ = training_utils.extract_tensors_from_dataset(x)
-      else:
-        x_input = x
-        y_input = y
-
-      # We type-check that `x_input` and `y_input` are either single arrays
-      # or lists of arrays, and extract a flat list of inputs from the passed
-      # structure.
-      if isinstance(x_input, (list, tuple)):
-        if not all(isinstance(v, np.ndarray) or
-                   tensor_util.is_tensor(v) for v in x_input):
-          raise ValueError('Please provide as model inputs either a single '
-                           'array or a list of arrays. You passed: x=' + str(x))
-        all_inputs += list(x_input)
-      elif isinstance(x_input, dict):
-        dict_inputs = True
-        keys = sorted(x_input.keys())
-        all_inputs = [x_input[k] for k in keys]
-      else:
-        if (not isinstance(x_input, np.ndarray) and
-            not tensor_util.is_tensor(x_input)):
-          raise ValueError('Please provide as model inputs either a single '
-                           'array or a list of arrays. You passed: x=' + str(x))
-        all_inputs.append(x_input)
-
-      # Now that we have a flat set of inputs, we make sure that none of them
-      # are CompositeTensors or CompositeTensorValues of any type (or scipy
-      # sparse arrays, which we treat as SparseTensor values). We cannot safely
-      # infer input data from an arbitrary composite tensor, so we don't try -
-      # users should explictly add composite tensor inputs to their subclassed
-      # models.
-      for input_tensor in all_inputs:
-        if (composite_tensor_utils.is_composite_or_composite_value(input_tensor)
-           ):
-          # TODO(b/132691975): Document subclass-model CT input handling.
-          raise ValueError(
-              'All SparseTensor and RaggedTensor inputs must be explicitly '
-              'declared using a keras.Input() with sparse=True or ragged=True. '
-              'We found an undeclared input %s. For Sequential models, please '
-              'add a keras.Input() as your first Layer. For subclassed models, '
-              'please call self._add_inputs() on your input set, which you can '
-              'create using keras.Input() for each input to your model.' %
-              (input_tensor,))
-
-      # Build the model using the retrieved inputs (value or symbolic).
-      # If values or generated from a dataset, then in symbolic-mode
-      # placeholders will be created to match the value shapes.
+      all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
       is_build_called = True
-      if is_dataset:
-        def create_tensor_spec(t):
-          return tensor_spec.TensorSpec(t.shape, t.dtype)
-        cast_inputs = nest.map_structure(create_tensor_spec, x_input)
-      elif training_utils.has_tensors(x_input):
-        cast_inputs = training_utils.cast_if_floating_dtype(x_input)
-      else:
-        cast_inputs = x_input
-
-      self._set_inputs(cast_inputs)
     else:
-      y_input = y
+      all_inputs = []
+      # Whether this is a subclassed model that expects dictionary inputs
+      # rather than list inputs (e.g. FeatureColumn-based models).
       dict_inputs = isinstance(self.inputs, dict)
+      is_build_called = False
+      y_input = y
 
+    # Second, we compile the model on the fly if necessary, mostly for subclass
+    # models.
+    is_compile_called = False
     if not self._is_compiled and self.optimizer:
-      # On-the-fly compilation of the model.
-      if y_input is not None:
-        # We need to use `y` to set the model targets.
-        if training_utils.has_tensors(y_input):
-          y_input = training_utils.cast_if_floating_dtype(y_input)
-        if isinstance(y_input, (list, tuple)):
-          if not all(isinstance(v, np.ndarray) or
-                     tensor_util.is_tensor(v) for v in y_input):
-            raise ValueError('Please provide as model targets either a single '
-                             'array or a list of arrays. '
-                             'You passed: y=' + str(y))
-          all_inputs += list(y_input)
-        elif isinstance(y_input, dict):
-          raise ValueError('You cannot pass a dictionary as model targets.')
-        else:
-          if (not isinstance(y_input, np.ndarray) and
-              not tensor_util.is_tensor(y_input)):
-            raise ValueError('Please provide as model targets either a single '
-                             'array or a list of arrays. '
-                             'You passed: y=' + str(y))
-          all_inputs.append(y_input)
-
-      # Typecheck that all inputs are *either* value *or* symbolic.
-      # TODO(fchollet): this check could be removed in Eager mode?
-      if any(tensor_util.is_tensor(v) for v in all_inputs):
-        if not all(tensor_util.is_tensor(v) for v in all_inputs):
-          raise ValueError('Do not pass inputs that mix Numpy arrays and '
-                           'TensorFlow tensors. '
-                           'You passed: x=' + str(x) + '; y=' + str(y))
-
-      if is_dataset or context.executing_eagerly():
-        target_tensors = None
-      else:
-        # Handle target tensors if any passed.
-        if y_input is not None:
-          if not isinstance(y_input, (list, tuple)):
-            y_input = [y_input]
-          target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
-        else:
-          target_tensors = None
+      self._compile_from_inputs(all_inputs, y_input, x, y)
       is_compile_called = True
-      self.compile(
-          optimizer=self.optimizer,
-          loss=self.loss,
-          metrics=self._compile_metrics,
-          weighted_metrics=self._compile_weighted_metrics,
-          loss_weights=self.loss_weights,
-          target_tensors=target_tensors,
-          sample_weight_mode=self.sample_weight_mode,
-          run_eagerly=self.run_eagerly,
-          run_distributed=self._run_distributed)
 
     # In graph mode, if we had just set inputs and targets as symbolic tensors
     # by invoking build and compile on the model respectively, we do not have to
@@ -2588,7 +2510,7 @@
       y = []
       sample_weights = None
 
-    if self.stateful and batch_size:
+    if self.stateful and batch_size and not is_dataset:
       # Check that for stateful networks, number of samples is a multiple
       # of the static batch size.
       if x[0].shape[0] % batch_size != 0:
@@ -2604,6 +2526,108 @@
       x = dict(zip(feed_input_names, x))
     return x, y, sample_weights
 
+  def _build_model_with_inputs(self, inputs, targets):
+    """Build the model (set model inputs/outputs), mainly for subclass model."""
+    processed_inputs = []
+    is_dict_inputs = False
+    orig_inputs = inputs
+    # We need to use `inputs` to set the model inputs.
+    # If input data is a dataset iterator in graph mode or if it is an eager
+    # iterator and only one batch of samples is required, we fetch the data
+    # tensors from the iterator and then standardize them.
+    if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
+      inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
+    # We type-check that `inputs` and `targets` are either single arrays
+    # or lists of arrays, and extract a flat list of inputs from the passed
+    # structure.
+    training_utils.validate_input_types(inputs, orig_inputs)
+
+    if isinstance(inputs, (list, tuple)):
+      processed_inputs += list(inputs)
+    elif isinstance(inputs, dict):
+      is_dict_inputs = True
+      keys = sorted(inputs.keys())
+      processed_inputs = [inputs[k] for k in keys]
+    else:
+      processed_inputs.append(inputs)
+    # Now that we have a flat set of inputs, we make sure that none of them
+    # are CompositeTensors or CompositeTensorValues of any type (or scipy
+    # sparse arrays, which we treat as SparseTensor values). We cannot safely
+    # infer input data from an arbitrary composite tensor, so we don't try -
+    # users should explicitly add composite tensor inputs to their subclassed
+    # models.
+    for input_tensor in processed_inputs:
+      if composite_tensor_utils.is_composite_or_composite_value(input_tensor):
+        # TODO(b/132691975): Document subclass-model CT input handling.
+        raise ValueError(
+            'All SparseTensor and RaggedTensor inputs must be explicitly '
+            'declared using a keras.Input() with sparse=True or ragged=True. '
+            'We found an undeclared input %s. For Sequential models, please '
+            'add a keras.Input() as your first Layer. For subclassed models, '
+            'please call self._set_inputs() on your input set, which you can '
+            'create using keras.Input() for each input to your model.' %
+            (input_tensor,))
+    # Build the model using the retrieved inputs (value or symbolic).
+    # If values are generated from a dataset, then in symbolic-mode
+    # placeholders will be created to match the value shapes.
+    if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
+                                iterator_ops.Iterator)):
+      def create_tensor_spec(t):
+        return tensor_spec.TensorSpec(t.shape, t.dtype)
+
+      cast_inputs = nest.map_structure(create_tensor_spec, inputs)
+    elif training_utils.has_tensors(inputs):
+      cast_inputs = training_utils.cast_if_floating_dtype(inputs)
+    else:
+      cast_inputs = inputs
+    self._set_inputs(cast_inputs)
+    return processed_inputs, targets, is_dict_inputs
+
+  def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
+    if target is not None:
+      # We need to use `y` to set the model targets.
+      if training_utils.has_tensors(target):
+        target = training_utils.cast_if_floating_dtype_and_mismatch(
+            target, self.outputs)
+      training_utils.validate_input_types(target, orig_target,
+                                          allow_dict=False, field_name='target')
+      if isinstance(target, (list, tuple)):
+        all_inputs += list(target)
+      else:
+        all_inputs.append(target)
+    # Type check that all inputs are *either* value *or* symbolic.
+    # TODO(fchollet): this check could be removed in Eager mode?
+    if any(tensor_util.is_tensor(v) for v in all_inputs):
+      if not all(tensor_util.is_tensor(v) for v in all_inputs):
+        raise ValueError('Do not pass inputs that mix Numpy arrays and '
+                         'TensorFlow tensors. '
+                         'You passed: x=' + str(orig_inputs) +
+                         '; y=' + str(orig_target))
+    is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1,
+                                          dataset_ops.DatasetV2,
+                                          iterator_ops.Iterator))
+    if is_dataset or context.executing_eagerly():
+      target_tensors = None
+    else:
+      # Handle target tensors if any passed.
+      if target is not None:
+        if not isinstance(target, (list, tuple)):
+          target = [target]
+        target_tensors = [v for v in target if _is_symbolic_tensor(v)]
+      else:
+        target_tensors = None
+
+    self.compile(
+        optimizer=self.optimizer,
+        loss=self.loss,
+        metrics=self._compile_metrics,
+        weighted_metrics=self._compile_weighted_metrics,
+        loss_weights=self.loss_weights,
+        target_tensors=target_tensors,
+        sample_weight_mode=self.sample_weight_mode,
+        run_eagerly=self.run_eagerly,
+        experimental_run_tf_function=self._experimental_run_tf_function)
+
   # TODO(omalleyt): Consider changing to a more descriptive function name.
   def _set_inputs(self, inputs, outputs=None, training=None):
     """Set model's input and output specs based on the input data received.
@@ -2823,6 +2847,24 @@
                          'training/testing. '
                          'Use `model.compile(optimizer, loss)`.')
 
+  def _in_multi_worker_mode(self):
+    """Method to infer if this `Model` is working in multi-worker settings.
+
+    Experimental. Signature and implementation are subject to change.
+
+    Returns:
+      Whether this model indicates it's working in multi-worker settings.
+    """
+    # If the model was compiled under the scope of a `tf.distribute.Strategy',
+    # `self._distribution_strategy` would have been set and model should infer
+    # that as the used strategy (even if it's out of strategy scope already).
+    strategy = self._distribution_strategy
+
+    # Otherwise, use the strategy whose scope this is in.
+    if not strategy and distribution_strategy_context.has_strategy():
+      strategy = distribution_strategy_context.get_strategy()
+    return strategy and strategy._in_multi_worker_mode()  # pylint: disable=protected-access
+
 
 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with tf.distribute.Strategy."""
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index c6cc786..f83369a 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -35,6 +35,7 @@
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
 
 try:
   from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
@@ -90,9 +91,10 @@
         declaring one epoch finished and starting the next epoch. Ignored with
         the default value of `None`.
       validation_steps: Number of steps to run validation for (only if doing
-        validation from data tensors). Ignored with the default value of `None`.
+        validation from data tensors). Ignored with the default value of
+        `None`.
       validation_freq: Only relevant if validation data is provided. Integer or
-        `collections.Container` instance (e.g. list, tuple, etc.). If an
+        `collections_abc.Container` instance (e.g. list, tuple, etc.). If an
         integer, specifies how many training epochs to run before a new
         validation run is performed, e.g. `validation_freq=2` runs
         validation every 2 epochs. If a Container, specifies the epochs on
@@ -100,8 +102,8 @@
         validation at the end of the 1st, 2nd, and 10th epochs.
       mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
       validation_in_fit: if true, then this method is invoked from within
-        training iteration (for validation). In the case where `val_inputs` is a
-        dataset, this flag indicates that its iterator and feed values are
+        training iteration (for validation). In the case where `val_inputs` is
+        a dataset, this flag indicates that its iterator and feed values are
         already created so should properly reuse resources.
       prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
         tensors returned from `_prepare_feed_values` call on the validation
@@ -138,7 +140,7 @@
     if steps_per_epoch is None:
       reset_dataset_after_each_epoch = True
       steps_per_epoch = training_utils.infer_steps_for_dataset(
-          inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
+          model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
     input_iterator = _get_iterator(inputs, model._distribution_strategy)
 
   # Enter tf.distribute.Strategy scope.
@@ -196,6 +198,7 @@
       # that determines the number of steps required. To avoid this issue,
       # set validation_steps here if validation_steps is None.
       validation_steps = training_utils.infer_steps_for_dataset(
+          model,
           val_inputs,
           validation_steps,
           epochs=epochs,
@@ -207,7 +210,8 @@
     val_samples_or_steps = validation_steps
   else:
     # Get num samples for printing.
-    val_samples_or_steps = val_inputs and val_inputs[0].shape[0] or None
+    val_samples_or_steps = val_inputs and nest.flatten(
+        val_inputs)[0].shape[0] or None
 
   if mode == ModeKeys.TRAIN and verbose:
     _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset)
@@ -332,7 +336,7 @@
 
         if model._distribution_strategy:
           batch_outs = distributed_training_utils._per_replica_aggregate_batch(
-              batch_outs, model, mode)
+              model._distribution_strategy, batch_outs, model, mode)
 
         # Aggregate results.
         if step == 0:
@@ -429,7 +433,7 @@
           batch_size=batch_size,
           steps_per_epoch=validation_steps,
           callbacks=callbacks,
-          verbose=0,
+          verbose=verbose,
           mode=ModeKeys.TEST,
           validation_in_fit=True,
           prepared_feed_values_from_dataset=(val_iterator is not None),
diff --git a/tensorflow/python/keras/engine/training_arrays_test.py b/tensorflow/python/keras/engine/training_arrays_test.py
index 280c369..097d5ee 100644
--- a/tensorflow/python/keras/engine/training_arrays_test.py
+++ b/tensorflow/python/keras/engine/training_arrays_test.py
@@ -60,10 +60,11 @@
     # from the fit history should be equal to the final element in the output
     # of evaluating the model on the same eval dataset.
     self.assertAlmostEqual(history.history["val_mean_absolute_error"][-1],
-                           evaluation[-1])
+                           evaluation[-1], places=5)
 
 
-class PrintTrainingInfoTest(parameterized.TestCase):
+class PrintTrainingInfoTest(keras_parameterized.TestCase,
+                            parameterized.TestCase):
 
   @test_util.run_v1_only("Only relevant in graph mode.")
   def test_print_info_with_datasets(self):
@@ -110,6 +111,79 @@
     if do_validation:
       self.assertIn(", validate on 50 samples", mock_stdout.getvalue())
 
+  @keras_parameterized.run_all_keras_modes
+  def test_dict_float64_input(self):
+
+    class MyModel(keras.Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__(self)
+        self.dense1 = keras.layers.Dense(10, activation="relu")
+        self.dense2 = keras.layers.Dense(10, activation="relu")
+        self.concat = keras.layers.Concatenate()
+        self.dense3 = keras.layers.Dense(1, activation="sigmoid")
+
+      def call(self, inputs):
+        d1 = self.dense1(inputs["one"])
+        d2 = self.dense2(inputs["two"])
+        concat = self.concat([d1, d2])
+        return self.dense3(concat)
+
+    model = MyModel()
+    model.compile(
+        loss="mae",
+        optimizer="adam",
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    model.fit(
+        x={
+            "one": np.random.rand(100, 10, 1),
+            "two": np.random.rand(100, 10, 1)
+        },
+        y=np.random.rand(100, 10, 1))
+
+  def test_dict_validation_input(self):
+    """Test case for GitHub issue 30122."""
+    train_input_0 = np.random.rand(1000, 1)
+    train_input_1 = np.random.rand(1000, 1)
+    train_labels = np.random.rand(1000, 1)
+    val_input_0 = np.random.rand(1000, 1)
+    val_input_1 = np.random.rand(1000, 1)
+    val_labels = np.random.rand(1000, 1)
+
+    input_0 = keras.Input(shape=(None,), name="input_0")
+    input_1 = keras.Input(shape=(None,), name="input_1")
+
+    class my_model(keras.Model):
+
+      def __init__(self):
+        super(my_model, self).__init__(self)
+        self.hidden_layer_0 = keras.layers.Dense(100, activation="relu")
+        self.hidden_layer_1 = keras.layers.Dense(100, activation="relu")
+        self.concat = keras.layers.Concatenate()
+        self.out_layer = keras.layers.Dense(1, activation="sigmoid")
+
+      def call(self, inputs=[input_0, input_1]):
+        activation_0 = self.hidden_layer_0(inputs["input_0"])
+        activation_1 = self.hidden_layer_1(inputs["input_1"])
+        concat = self.concat([activation_0, activation_1])
+        return self.out_layer(concat)
+
+    model = my_model()
+    model.compile(loss="mae", optimizer="adam")
+
+    model.fit(
+        x={
+            "input_0": train_input_0,
+            "input_1": train_input_1
+        },
+        y=train_labels,
+        validation_data=({
+            "input_0": val_input_0,
+            "input_1": val_input_1
+        }, val_labels))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py
index 145465b..aba2021 100644
--- a/tensorflow/python/keras/engine/training_dataset_test.py
+++ b/tensorflow/python/keras/engine/training_dataset_test.py
@@ -52,10 +52,10 @@
   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
   def test_calling_model_on_same_dataset(self):
-    if ((not testing_utils.should_run_eagerly())
-        and testing_utils.get_model_type() == 'subclass'
-        and context.executing_eagerly()
-        and (not testing_utils.should_run_distributed())):
+    if ((not testing_utils.should_run_eagerly()) and
+        testing_utils.get_model_type() == 'subclass' and
+        context.executing_eagerly() and
+        (not testing_utils.should_run_tf_function())):
       self.skipTest('b/120673224')
 
     model = testing_utils.get_small_mlp(1, 4, input_dim=3)
@@ -67,7 +67,7 @@
         loss,
         metrics=metrics,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((10, 3), np.float32)
     targets = np.zeros((10, 4), np.float32)
@@ -93,7 +93,7 @@
         loss,
         metrics=metrics,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((10, 3), np.float32)
     targets = np.zeros((10, 4), np.float32)
@@ -175,7 +175,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_a_np = np.random.random((10, 3)).astype(dtype=np.float32)
     input_b_np = np.random.random((10, 3)).astype(dtype=np.float32)
@@ -232,7 +232,7 @@
         loss,
         metrics=metrics,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((10, 3), np.float32)
     targets = np.zeros((10, 4), np.float32)
@@ -279,7 +279,7 @@
         optimizer,
         loss='sparse_categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((10, 3), dtype=np.float32)
     targets = np.random.randint(0, 4, size=10, dtype=np.int32)
@@ -304,7 +304,7 @@
         'rmsprop',
         loss='mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((40, 2), dtype=np.float32)
     inputs[10:20, :] = 2
@@ -375,7 +375,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((100, 3), dtype=np.float32)
     targets = np.random.randint(0, 4, size=100, dtype=np.int32)
@@ -399,7 +399,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((100, 3), dtype=np.float32)
     targets = np.random.randint(0, 4, size=100, dtype=np.int32)
@@ -439,7 +439,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((100, 3), dtype=np.float32)
     targets = np.random.randint(0, 4, size=100, dtype=np.int32)
@@ -476,7 +476,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.zeros((100, 3), dtype=np.float32)
     targets = np.random.randint(0, 4, size=100, dtype=np.int32)
@@ -542,7 +542,7 @@
         metrics=['accuracy', metrics_module.BinaryAccuracy()],
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     np.random.seed(123)
     x = np.random.randint(10, size=(100, 4)).astype(np.float32)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 547a4f9..7213af9 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -149,7 +149,7 @@
           (only if doing validation from data tensors).
           Ignored with the default value of `None`.
       validation_freq: Only relevant if validation data is provided. Integer or
-          `collections.Container` instance (e.g. list, tuple, etc.). If an
+          `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
           integer, specifies how many training epochs to run before a new
           validation run is performed, e.g. `validation_freq=2` runs
           validation every 2 epochs. If a Container, specifies the epochs on
@@ -646,7 +646,7 @@
 
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
       steps_per_epoch = training_utils.infer_steps_for_dataset(
-          dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
+          model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
       if steps_per_epoch is None:
         raise ValueError('Number of steps could not be inferred from the data, '
                          'please pass the steps_per_epoch argument.')
@@ -703,7 +703,7 @@
 
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
-          dataset, steps, steps_name='steps')
+          model, dataset, steps, steps_name='steps')
       if steps is None:
         raise ValueError('Number of steps could not be inferred from the data, '
                          'please pass the steps argument.')
@@ -740,7 +740,7 @@
         allow_partial_batch=True)
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
-          dataset, steps, steps_name='steps')
+          model, dataset, steps, steps_name='steps')
       if steps is None:
         raise ValueError('Number of steps could not be inferred from the data, '
                          'please pass the steps argument.')
@@ -756,16 +756,16 @@
         callbacks=callbacks)
 
 
-def train_with_multi_worker(fn):
+def train_with_multi_worker(method):
   """Decorator that handles multi worker training with distribution strategy."""
 
-  def wrapper(instance, model, **kwargs):
-
+  def wrapper(model, **kwargs):
     def _worker_fn(_):
       callbacks = kwargs.pop('callbacks', None)
-      filtered_callbacks = dist_utils.filter_distributed_callbacks(callbacks)
+      filtered_callbacks = dist_utils.filter_distributed_callbacks(
+          callbacks, model)
       kwargs['callbacks'] = filtered_callbacks
-      return fn(instance, model, **kwargs)
+      return method(model, **kwargs)
 
     return dc.run_distribute_coordinator(
         _worker_fn,
@@ -775,10 +775,20 @@
   return wrapper
 
 
-class DistributionMultiWorkerTrainingLoop(DistributionSingleWorkerTrainingLoop):
+class DistributionMultiWorkerTrainingLoop(training_utils.TrainingLoop):
   """Training loop for distribution strategy with multiple worker."""
 
-  fit = train_with_multi_worker(DistributionSingleWorkerTrainingLoop.fit)
-  evaluate = train_with_multi_worker(
-      DistributionSingleWorkerTrainingLoop.evaluate)
-  # Currently predict is still using the single worker implementation.
+  def __init__(self, single_worker_loop):
+    self._single_worker_loop = single_worker_loop
+
+  def fit(self, *args, **kwargs):
+    return train_with_multi_worker(self._single_worker_loop.fit)(
+        *args, **kwargs)
+
+  def evaluate(self, *args, **kwargs):
+    return train_with_multi_worker(self._single_worker_loop.evaluate)(
+        *args, **kwargs)
+
+  def predict(self, *args, **kwargs):
+    # Currently predict is still using the single worker implementation.
+    return self._single_worker_loop.predict(*args, **kwargs)
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index a1470fe..ab16efc 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -19,13 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 import numpy as np
 
 from tensorflow.python.eager.backprop import GradientTape
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -60,6 +57,14 @@
   # Invoke all(weighted and unweighted) metrics.
   metric_results = []
   if targets:
+    # Insert None values corresponding to the targets that need to be skipped
+    # on the model.
+    if len(model._targets) != len(targets):
+      new_targets = [
+          None if t is None else targets.pop(0) for t in model._targets
+      ]
+      targets = new_targets
+
     metric_results = model._handle_metrics(
         outputs,
         targets=targets,
@@ -121,6 +126,16 @@
 
   outs = model(inputs, **kwargs)
   outs = nest.flatten(outs)
+
+  if targets:
+    targets = training_utils.cast_if_floating_dtype_and_mismatch(targets, outs)
+  # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
+  if sample_weights:
+    sample_weights = [
+        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
+        if val is not None else None for val in sample_weights
+    ]
+
   masks = [getattr(t, '_keras_mask', None) for t in outs]
   targets = nest.flatten(targets)
 
@@ -245,10 +260,16 @@
     if training:
       trainable_weights = model._unique_trainable_weights
       if trainable_weights:
-        grads = tape.gradient(scaled_total_loss, trainable_weights)
-        if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
-          grads = model.optimizer.get_unscaled_gradients(grads)
-        model.optimizer.apply_gradients(zip(grads, trainable_weights))
+        # TODO(tanzheny) b/132690565: Provide mechanism for user to override
+        # model.train_on_batch.
+        if hasattr(model, '_backwards'):
+          model._backwards(tape, scaled_total_loss)
+        else:
+          grads = tape.gradient(scaled_total_loss, trainable_weights)
+          if isinstance(model.optimizer,
+                        loss_scale_optimizer.LossScaleOptimizer):
+            grads = model.optimizer.get_unscaled_gradients(grads)
+          model.optimizer.apply_gradients(zip(grads, trainable_weights))
       else:
         logging.warning('The list of trainable weights is empty. Make sure that'
                         ' you are not setting model.trainable to False before '
@@ -273,26 +294,13 @@
         loss values.
 
   Returns:
-      total loss and the loss associated with each output.
+      Dict with three items:
+        'total_loss': list with a single tensor for overall loss,
+        'output_losses': list of tensors for loss corresponding to each of the
+          model output. Could be a empty list when model has only one output.
+        'metrics': list of tensors for metric specified.
   """
-  if isinstance(inputs, collections.Sequence):
-    if len(inputs) and tensor_util.is_tensor(inputs[0]):
-      inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
-                                                                     model)
-      if targets:
-        targets = training_utils.cast_if_floating_dtype(targets)
-    else:
-      inputs = training_utils.cast_if_floating_to_model_input_dtypes(
-          [ops.convert_to_tensor(val) for val in inputs], model)
-      if targets:
-        targets = training_utils.cast_if_floating_dtype(
-            [ops.convert_to_tensor(val) for val in targets])
-  if sample_weights:
-    sample_weights = [
-        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
-        if val is not None else None for val in sample_weights
-    ]
-
+  inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
   outs, total_loss, output_losses, masks = (
       _process_single_batch(
           model,
@@ -306,9 +314,9 @@
   metrics_results = _eager_metrics_fn(
       model, outs, targets, sample_weights=sample_weights, masks=masks)
   total_loss = nest.flatten(total_loss)
-  results = total_loss + output_losses + metrics_results
-
-  return results
+  return {'total_loss': total_loss,
+          'output_losses': output_losses,
+          'metrics': metrics_results}
 
 
 def test_on_batch(model,
@@ -327,25 +335,14 @@
         loss values.
 
   Returns:
-      total loss, loss and metrics associated with each output.
+      Dict with three items:
+        'total_loss': single tensor for overall loss,
+        'output_losses': list of tensors for loss corresponding to each of the
+          model output. Could be a empty list when model has only one output.
+        'metrics': list of tensors for metric specified.
   """
-  if isinstance(inputs, collections.Sequence):
-    if len(inputs) and tensor_util.is_tensor(inputs[0]):
-      inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
-                                                                     model)
-      if targets:
-        targets = training_utils.cast_if_floating_dtype(targets)
-    else:
-      inputs = training_utils.cast_if_floating_to_model_input_dtypes(
-          [ops.convert_to_tensor(val) for val in inputs], model)
-      if targets:
-        targets = training_utils.cast_if_floating_dtype(
-            [ops.convert_to_tensor(val) for val in targets])
-  if sample_weights:
-    sample_weights = [
-        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
-        if val is not None else None for val in sample_weights
-    ]
+  inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
+
   with backend.eager_learning_phase_scope(0):
     outs, total_loss, output_losses, masks = (
         _model_loss(
@@ -360,6 +357,7 @@
   metrics_results = _eager_metrics_fn(
       model, outs, targets, sample_weights=sample_weights, masks=masks)
   total_loss = nest.flatten(total_loss)
-  results = total_loss + output_losses + metrics_results
 
-  return results
+  return {'total_loss': total_loss,
+          'output_losses': output_losses,
+          'metrics': metrics_results}
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index e74c5b6..c60c1af 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -24,7 +24,6 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import testing_utils
@@ -35,7 +34,7 @@
 
 class TrainingTest(keras_parameterized.TestCase):
 
-  @test_util.run_in_graph_and_eager_modes()
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def test_dynamic_model_has_trainable_weights(self):
     if not context.executing_eagerly():
       # Only test Eager modes, as Graph mode is not relevant for dynamic models.
@@ -52,7 +51,10 @@
         return self.dense(inputs)
 
     model = DynamicModel()
-    model.compile('rmsprop', 'mae')
+    model.compile(
+        'rmsprop', 'mae',
+        run_eagerly=True,
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     hist = model.fit(np.zeros((1, 1)), np.zeros((1, 1)))
     self.assertEqual(hist.history['loss'][-1], 1)
     self.assertEqual(len(model.trainable_weights), 2)
@@ -88,7 +90,7 @@
         metrics=metrics,
         loss_weights=loss_weights,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function(),
         sample_weight_mode=None)
 
     input_a = array_ops.zeros(shape=(10, 3))
@@ -159,7 +161,7 @@
         loss,
         metrics=metrics,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = array_ops.zeros(shape=(10, 3))
     targets = array_ops.zeros(shape=(10, 4))
@@ -244,7 +246,7 @@
         loss='sparse_categorical_crossentropy',
         optimizer=rmsprop.RMSprop(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.ones((100, 4))
     np.random.seed(123)
     y = np.random.randint(0, 1, size=(100, 1))
@@ -265,7 +267,7 @@
         loss='sparse_categorical_crossentropy',
         optimizer=rmsprop.RMSprop(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.ones((100, 4), dtype=np.float32)
     np.random.seed(123)
     y = np.random.randint(0, 1, size=(100, 1))
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index b033c98..5194a22 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -80,7 +80,7 @@
       validation_steps: Total number of steps (batches of samples) before
         declaring validation finished.
       validation_freq: Only relevant if validation data is provided. Integer or
-        `collections.Container` instance (e.g. list, tuple, etc.). If an
+        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
         integer, specifies how many training epochs to run before a new
         validation run is performed, e.g. `validation_freq=2` runs
         validation every 2 epochs. If a Container, specifies the epochs on
@@ -133,7 +133,7 @@
     if steps_per_epoch is None:
       reset_dataset_after_each_epoch = True
       steps_per_epoch = training_utils.infer_steps_for_dataset(
-          data, steps_per_epoch, epochs=epochs, steps_name=steps_name)
+          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)
 
   # Convert to a format that supports `next(generator)`.
   generator, steps_per_epoch = convert_to_generator_like(
@@ -318,7 +318,7 @@
           use_multiprocessing=use_multiprocessing,
           max_queue_size=max_queue_size,
           callbacks=callbacks,
-          verbose=0,
+          verbose=verbose,
           mode=ModeKeys.TEST,
           steps_name='validation_steps')
 
diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py
index 6db967c..5362eac 100644
--- a/tensorflow/python/keras/engine/training_generator_test.py
+++ b/tensorflow/python/keras/engine/training_generator_test.py
@@ -152,7 +152,7 @@
         optimizer=rmsprop.RMSprop(1e-3),
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     self._sleep_at_end = True
     model.evaluate_generator(custom_generator(),
@@ -180,7 +180,7 @@
     model = testing_utils.get_small_mlp(
         num_hidden=3, num_classes=4, input_dim=2)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     self._sleep_at_end = True
     model.predict_generator(custom_generator(),
@@ -221,7 +221,7 @@
         optimizer=rmsprop.RMSprop(1e-3),
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.fit_generator(custom_generator(mode=3),
                         steps_per_epoch=5,
@@ -259,7 +259,7 @@
         loss='mse',
         optimizer=rmsprop.RMSprop(1e-3),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     err_msg = 'Output of generator should be a tuple of 1 or 2 or 3 elements'
     with self.assertRaisesRegex(ValueError, err_msg):
@@ -305,7 +305,7 @@
         rmsprop.RMSprop(0.001),
         'binary_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(
         ones_generator(),
         steps_per_epoch=2,
diff --git a/tensorflow/python/keras/engine/training_integration_test.py b/tensorflow/python/keras/engine/training_integration_test.py
new file mode 100644
index 0000000..90a1180
--- /dev/null
+++ b/tensorflow/python/keras/engine/training_integration_test.py
@@ -0,0 +1,207 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""End-to-end tests for a variety of small models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+def _conv2d_filter(**kwargs):
+  """Convolution with non-default strides and dilation rate is not supported."""
+  return kwargs['strides'] <= 1 or kwargs['dilation_rate'] <= 1
+
+
+# Scheme: (layer_class, data_shape, fuzz_dims, constructor_args, filter_fn)
+#   layer_class:
+#     A keras Layer class to be tested.
+#   data_shape:
+#     The shape of the input data. (not including batch dim)
+#   fuzz_dims:
+#     Dimensions which can be unspecified during model construction. For
+#     instance, if data_shape is (2, 5) and fuzz_dims is (False, True), a pass
+#     with model input shape of (2, None) will also be performed.
+#   constructor_args:
+#     An OrderedDict (to ensure consistent test names) with a key and a list
+#     of values to test. Test cases will be generated for the Cartesian product
+#     of all constructor args, so adding more fields can cause the drastically
+#     increase the testing load.
+#   filter_fn:
+#     If not None, this function will be called on each set of generated
+#     constructor args, and prevents generation of contradictory combinations.
+#     A True return value indicates a valid test.
+_LAYERS_TO_TEST = [
+    (keras.layers.Dense, (1,), (False,), collections.OrderedDict([
+        ('units', [1])]), None),
+    (keras.layers.Activation, (2, 2), (True, True), collections.OrderedDict([
+        ('activation', ['relu'])]), None),
+    (keras.layers.Dropout, (16,), (False,), collections.OrderedDict([
+        ('rate', [0.25])]), None),
+    (keras.layers.BatchNormalization, (8, 8, 3), (True, True, False),
+     collections.OrderedDict([
+         ('axis', [3]),
+         ('center', [True, False]),
+         ('scale', [True, False])
+     ]), None),
+    (keras.layers.Conv1D, (8, 8), (False, False), collections.OrderedDict([
+        ('filters', [1]),
+        ('kernel_size', [1, 3]),
+        ('strides', [1, 2]),
+        ('padding', ['valid', 'same']),
+        ('use_bias', [True]),
+        ('kernel_regularizer', ['l2']),
+        ('data_format', ['channels_last'])
+    ]), None),
+    (keras.layers.Conv2D, (8, 8, 3), (True, True, False),
+     collections.OrderedDict([
+         ('filters', [1]),
+         ('kernel_size', [1, 3]),
+         ('strides', [1, 2]),
+         ('padding', ['valid', 'same']),
+         ('use_bias', [True, False]),
+         ('kernel_regularizer', ['l2']),
+         ('dilation_rate', [1, 2]),
+         ('data_format', ['channels_last'])
+     ]), _conv2d_filter),
+    (keras.layers.LSTM, (4, 4), (False, False), collections.OrderedDict([
+        ('units', [1]),
+        ('kernel_regularizer', ['l2']),
+        ('dropout', [0, 0.5]),
+        ('stateful', [True, False]),
+        ('unroll', [True, False]),
+        ('return_sequences', [True, False])
+    ]), None),
+]
+
+
+def _gather_test_cases():
+  cases = []
+  for layer_type, inp_shape, fuzz_dims, arg_dict, filter_fn in _LAYERS_TO_TEST:
+    arg_combinations = [[(k, i) for i in v] for k, v in arg_dict.items()]  # pylint: disable=g-complex-comprehension
+    for arguments in itertools.product(*arg_combinations):
+      layer_kwargs = {k: v for k, v in arguments}
+      if filter_fn is not None and not filter_fn(**layer_kwargs):
+        continue
+
+      name = '_{}_{}'.format(layer_type.__name__,
+                             '_'.join('{}_{}'.format(*i) for i in arguments))
+      cases.append((name, layer_type, inp_shape, fuzz_dims, layer_kwargs))
+  return cases
+
+
+OUTPUT_TEST_CASES = _gather_test_cases()
+
+
+class CoreLayerIntegrationTest(keras_parameterized.TestCase):
+  """Test that layers and models produce the correct tensor types."""
+
+  # In v1 graph there are only symbolic tensors.
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  @parameterized.named_parameters(*OUTPUT_TEST_CASES)
+  def test_layer_output_type(self, layer_to_test, input_shape, _, layer_kwargs):
+    layer = layer_to_test(**layer_kwargs)
+
+    input_data = np.ones(shape=(2,) + input_shape, dtype=np.float32)
+    layer_result = layer(input_data)
+
+    inp = keras.layers.Input(shape=input_shape, batch_size=2)
+    model = keras.models.Model(inp, layer_to_test(**layer_kwargs)(inp))
+    model_result = model(input_data)
+
+    for x in [layer_result, model_result]:
+      if not isinstance(x, ops.Tensor):
+        raise ValueError('Tensor or EagerTensor expected, got type {}'
+                         .format(type(x)))
+
+      if isinstance(x, ops.EagerTensor) != context.executing_eagerly():
+        expected_type = (ops.EagerTensor if context.executing_eagerly()
+                         else ops.Tensor)
+        raise ValueError('Expected type {}, got type {}'
+                         .format(expected_type, type(x)))
+
+  def _run_fit_eval_predict(self, layer_to_test, input_shape, data_shape,
+                            layer_kwargs):
+    batch_size = 2
+    run_eagerly = testing_utils.should_run_eagerly()
+    experimental_run_tf_function = testing_utils.should_run_tf_function()
+
+    def map_fn(_):
+      x = keras.backend.random_uniform(shape=data_shape)
+      y = keras.backend.random_uniform(shape=(1,))
+      return x, y
+
+    dataset = dataset_ops.DatasetV2.range(4).map(map_fn).batch(batch_size)
+
+    inp = keras.layers.Input(shape=input_shape, batch_size=batch_size)
+    layer = layer_to_test(**layer_kwargs)(inp)
+
+    # Condense the output down to a single scalar.
+    layer = keras.layers.Flatten()(layer)
+    layer = keras.layers.Lambda(
+        lambda x: math_ops.reduce_mean(x, keepdims=True))(layer)
+    layer = keras.layers.Dense(1, activation=None)(layer)
+    model = keras.models.Model(inp, layer)
+
+    model.compile(loss='mse', optimizer='sgd', run_eagerly=run_eagerly,
+                  experimental_run_tf_function=experimental_run_tf_function)
+    model.fit(dataset, verbose=2, epochs=2)
+
+    model.compile(loss='mse', optimizer='sgd', run_eagerly=run_eagerly,
+                  experimental_run_tf_function=experimental_run_tf_function)
+    model.fit(dataset.repeat(2), verbose=2, epochs=2, steps_per_epoch=2)
+
+    eval_dataset = dataset_ops.DatasetV2.range(4).map(map_fn).batch(batch_size)
+    model.evaluate(eval_dataset, verbose=2)
+
+    def pred_map_fn(_):
+      return keras.backend.random_uniform(shape=data_shape)
+
+    pred_dataset = dataset_ops.DatasetV2.range(4)
+    pred_dataset = pred_dataset.map(pred_map_fn).batch(batch_size)
+    model.predict(pred_dataset, verbose=2)
+
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=False)
+  @parameterized.named_parameters(*OUTPUT_TEST_CASES)
+  def test_model_loops(self, layer_to_test, input_shape, fuzz_dims,
+                       layer_kwargs):
+    self._run_fit_eval_predict(layer_to_test, input_shape,
+                               input_shape, layer_kwargs)
+
+    if any(fuzz_dims):
+      fuzzed_shape = []
+      for dim, should_fuzz in zip(input_shape, fuzz_dims):
+        fuzzed_shape.append(None if should_fuzz else dim)
+
+      self._run_fit_eval_predict(layer_to_test, fuzzed_shape,
+                                 input_shape, layer_kwargs)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 8672abe..3aaad89 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -70,7 +70,7 @@
         optimizer='adam',
         loss=loss,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(model.loss, loss)
 
     loss = losses.get(loss)
@@ -120,7 +120,7 @@
         optimizer='adam',
         loss=loss,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(model.loss_functions[0].fn, losses.mean_squared_error)
     self.assertEqual(model.loss_functions[1].fn, losses.mean_absolute_error)
     self.assertAllEqual(model._loss_weights_list, [1., 1.])
@@ -131,7 +131,7 @@
         optimizer='adam',
         loss=loss,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(model.loss_functions[0].fn, losses.mean_absolute_error)
     self.assertEqual(model.loss_functions[1].fn, losses.mean_squared_error)
     self.assertAllEqual(model._loss_weights_list, [1., 1.])
@@ -145,7 +145,7 @@
         loss='mse',
         loss_weights=loss_weights,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertAllEqual(model._loss_weights_list, [1., 2.])
 
   def test_compile_with_multi_output_and_loss_weights_dict(self):
@@ -183,7 +183,7 @@
           optimizer='adam',
           loss=['mse', 'mae'],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_compile_with_incorrect_loss_key(self):
@@ -197,7 +197,7 @@
           optimizer='adam',
           loss={'unknown_output': 'mse'},
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_compile_with_incorrect_loss_weights_size(self):
@@ -210,7 +210,7 @@
           loss='mse',
           loss_weights=[1., 2.],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_compile_with_incorrect_loss_weights_key(self):
@@ -225,7 +225,7 @@
           loss='mse',
           loss_weights={'unknown_output': 1.},
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_compile_with_incorrect_sample_weight_mode(self):
@@ -240,7 +240,7 @@
           loss='mse',
           sample_weight_mode={'unknown': 'temporal'},
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
 
 class TrainingTest(keras_parameterized.TestCase):
@@ -262,7 +262,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     hist = model.fit(x=np.array([0.]), y=np.array([0.]))
     self.assertAllClose(hist.history['loss'][0], 10000)
 
@@ -281,7 +281,7 @@
         'sgd',
         loss='mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.ones((40, 2), dtype=np.float32)
     targets = np.ones((40, 1), dtype=np.float32)
@@ -315,7 +315,7 @@
         'sgd',
         loss='mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.ones((40, 2), dtype=np.float32)
     targets = np.ones((40, 1), dtype=np.float32)
@@ -334,6 +334,28 @@
     self.assertAllClose(history.history['val_loss'][0], 1.0)
 
   @keras_parameterized.run_all_keras_modes
+  @keras_parameterized.run_with_all_model_types
+  def test_target_dtype_matches_output(self):
+
+    def _loss_fn(labels, preds):
+      self.assertEqual(labels.dtype, preds.dtype)
+      return labels - preds
+
+    layers = [keras.layers.Dense(10, dtype=np.float64),
+              keras.layers.Dense(10, dtype=np.float64)]
+    model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
+    inputs = np.ones(10, dtype=np.float64)
+    targets = np.ones(10, dtype=np.float64)
+    model.compile(
+        'sgd',
+        loss=_loss_fn,
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.train_on_batch(inputs, targets)
+    model.test_on_batch(inputs, targets)
+    self.assertEqual(model.predict(inputs).dtype, np.float64)
+
+  @keras_parameterized.run_all_keras_modes
   def test_fit_and_validate_nested_training_arg(self):
 
     class NestedReturnTraining(keras.layers.Layer):
@@ -362,7 +384,7 @@
         'sgd',
         loss='mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.ones((40, 2), dtype=np.float32)
     targets = np.ones((40, 1), dtype=np.float32)
@@ -383,8 +405,6 @@
   @keras_parameterized.run_with_all_model_types(exclude_models='sequential')
   @keras_parameterized.run_all_keras_modes
   def test_fit_on_arrays(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     input_a = keras.layers.Input(shape=(3,), name='input_a')
     input_b = keras.layers.Input(shape=(3,), name='input_b')
 
@@ -404,7 +424,7 @@
         metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
         loss_weights=loss_weights,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_a_np = np.random.random((10, 3))
     input_b_np = np.random.random((10, 3))
@@ -430,14 +450,6 @@
         verbose=2)
     model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
 
-    # Test model with input data as a list of lists
-    model.fit(
-        [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
-        [output_d_np, output_e_np],
-        epochs=2,
-        batch_size=5,
-        verbose=2)
-
     # Test with validation data
     model.fit(
         [input_a_np, input_b_np], [output_d_np, output_e_np],
@@ -525,7 +537,7 @@
         loss,
         metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(
         [input_a_np, input_b_np], [output_d_np, output_e_np],
         epochs=1,
@@ -546,7 +558,7 @@
           metrics=metrics,
           loss_weights=loss_weights,
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(
         [input_a_np, input_b_np], [output_d_np, output_e_np],
         epochs=1,
@@ -586,7 +598,7 @@
         optimizer,
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     # This will work
     model.fit([input_a_np], output_d_np, epochs=1)
     # TODO(gsundeep) Test only works in eager, file ticket
@@ -598,8 +610,7 @@
     input_a_np = np.random.random((10, 3))
     input_b_np = np.random.random((10, 4))
 
-    model.fit([np.ndarray.tolist(input_a_np)],
-              [np.ndarray.tolist(input_b_np)],
+    model.fit([np.ndarray.tolist(input_a_np)], [np.ndarray.tolist(input_b_np)],
               epochs=2,
               batch_size=5,
               verbose=2)
@@ -626,7 +637,7 @@
         loss_weights=loss_weights,
         sample_weight_mode=None,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_a_np = np.random.random((10, 3))
     input_b_np = np.random.random((10, 3))
@@ -712,7 +723,7 @@
           optimizer,
           'binary_crossentropy',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, batch_size=2, epochs=5)
       loss[reg] = model.evaluate(x, y)
     self.assertLess(loss[None], loss['l2'])
@@ -733,7 +744,7 @@
         optimizer,
         'binary_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     loss = model.test_on_batch(x, y)
     self.assertAlmostEqual(0.01, loss, places=4)
 
@@ -751,7 +762,7 @@
         optimizer,
         'binary_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones((10, 10), 'float32')
     y = np.ones((10, 1), 'float32')
@@ -794,8 +805,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_training_on_sparse_data_with_dense_placeholders(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     # TODO(kaftan) Test seems to not work, file ticket
     if testing_utils.should_run_eagerly() and context.executing_eagerly():
       self.skipTest('Skipping running model eagerly.')
@@ -821,7 +830,7 @@
         'mse',
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(test_inputs, test_outputs,
               epochs=1, batch_size=2, validation_split=0.5)
     model.evaluate(test_inputs, test_outputs, batch_size=2)
@@ -843,7 +852,7 @@
         optimizer=keras.optimizers.Adam(lr=0.0001),
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_that_trainable_disables_updates(self):
@@ -862,7 +871,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     assert not model.updates
 
     x1 = model.predict(val_a)
@@ -875,7 +884,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     assert model.updates
 
     model.train_on_batch(val_a, val_out)
@@ -887,7 +896,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     assert not model.updates
 
     x1 = model.predict(val_a)
@@ -994,8 +1003,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_mismatched_output_shape_and_target_shape(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = keras.Sequential([
         keras.layers.Dense(2, input_shape=(3, 4)),
         keras.layers.Dense(5),
@@ -1004,7 +1011,7 @@
         RMSPropOptimizer(learning_rate=0.001),
         loss='sparse_categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     # Test with Numpy data
     x_train = np.random.random((10, 3, 4))
     y_train = np.random.randint(0, 5, size=(10, 3))
@@ -1048,7 +1055,7 @@
         RMSPropOptimizer(learning_rate=0.001),
         loss='binary_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     with test.mock.patch.object(sys, 'stdout', mock_stdout):
       model.fit(
           np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10)
@@ -1237,7 +1244,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     class ValCounter(keras.callbacks.Callback):
 
@@ -1266,7 +1273,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     with self.assertRaisesRegexp(
         ValueError, '`validation_steps` should not be specified if '
@@ -1297,7 +1304,7 @@
         keras.optimizer_v2.gradient_descent.SGD(0.025),
         loss=keras.losses.MeanAbsoluteError(),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.array([[0.], [1.], [2.]])
     y = np.array([[0.5], [2.], [3.5]])
@@ -1323,7 +1330,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
     self.assertEqual(loss, 2 * 3)
 
@@ -1382,32 +1389,28 @@
   # TODO(b/131372221): Make this work with subclassed models.
   @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
   @keras_parameterized.run_all_keras_modes
+  @testing_utils.enable_v2_dtype_behavior
   def test_model_dtype(self):
 
     class AssertTypeLayer(keras.layers.Layer):
 
-      def __init__(self, assert_type=None, **kwargs):
-        super(AssertTypeLayer, self).__init__(**kwargs)
-        self.assert_type = assert_type
-
       def call(self, inputs):
-        assert inputs.dtype.name == self.assert_type, (
+        assert inputs.dtype.name == self.dtype, (
             'Input tensor has type %s which does not match assert type %s' %
             (inputs.dtype.name, self.assert_type))
         return inputs + 1.
 
     for dtype in ('float16', 'float32', 'float64'):
-      model = testing_utils.get_model_from_layers([AssertTypeLayer(dtype)],
-                                                  input_shape=(10,),
-                                                  input_dtype=dtype)
+      model = testing_utils.get_model_from_layers(
+          [AssertTypeLayer(dtype=dtype)], input_shape=(10,))
       model.compile(
           'sgd',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
-      x = np.ones((10, 10), dtype=dtype)
-      y = np.ones((10, 10), dtype=dtype)
+      x = np.ones((10, 10))
+      y = np.ones((10, 10))
       model.fit(x, y)
       model.test_on_batch(x, y)
       model(x)
@@ -1437,11 +1440,11 @@
         loss='mse',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, x, epochs=1)
 
     if (testing_utils.should_run_eagerly() or
-        testing_utils.should_run_distributed()):
+        testing_utils.should_run_tf_function()):
       expected_training_arg = True
     else:
       expected_training_arg = keras.backend.symbolic_learning_phase()
@@ -1522,7 +1525,7 @@
             optimizer,
             loss=None,
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_compile_warning_for_loss_missing_output(self):
@@ -1544,7 +1547,7 @@
                 'dense_1': metrics_module.CategoricalAccuracy(),
             },
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
         msg = ('Output dense_1 missing from loss dictionary. We assume this '
                'was done on purpose. The fit and evaluate APIs will not be '
                'expecting any data to be passed to dense_1.')
@@ -1552,8 +1555,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_invalid_steps_per_epoch_usage(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     x = keras.layers.Input(shape=(1,))
     y = keras.layers.Dense(1)(x)
 
@@ -1562,10 +1563,11 @@
         'sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     err_msg = 'When passing input data as arrays, do not specify'
 
-    if testing_utils.should_run_eagerly() and not model._run_distributed:
+    if testing_utils.should_run_eagerly(
+    ) and not model._experimental_run_tf_function:
       with self.assertRaisesRegex(ValueError, err_msg):
         model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
 
@@ -1607,7 +1609,7 @@
         weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
         optimizer=RMSPropOptimizer(learning_rate=learning_rate),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     np.random.seed(1337)
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -1676,7 +1678,7 @@
         weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
         loss='categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     np.random.seed(43)
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -1788,7 +1790,7 @@
           weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
           sample_weight_mode='temporal',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       model.fit(
           temporal_x_train,
@@ -1823,8 +1825,6 @@
   @keras_parameterized.run_all_keras_modes
   @keras_parameterized.run_with_all_model_types(exclude_models='sequential')
   def test_fit_with_incorrect_weights(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     input_a = keras.layers.Input(shape=(3,), name='input_a')
     input_b = keras.layers.Input(shape=(3,), name='input_b')
 
@@ -1838,7 +1838,7 @@
         optimizer='adam',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.random.random((10, 3))
     y = np.random.random((10, 2))
 
@@ -1858,8 +1858,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_class_weight_invalid_use_case(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     num_classes = 5
     train_samples = 1000
     test_samples = 1000
@@ -1879,7 +1877,7 @@
           optimizer,
           loss='binary_crossentropy',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       (x_train, y_train), _ = testing_utils.get_test_data(
           train_samples=train_samples,
@@ -1901,7 +1899,7 @@
             loss='binary_crossentropy',
             sample_weight_mode=[],
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       # Build multi-output model
       x = keras.Input((3,))
@@ -1912,7 +1910,7 @@
           optimizer,
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       x_np = np.random.random((10, 3))
       y_np = np.random.random((10, 4))
       w_np = np.random.random((10,))
@@ -1942,9 +1940,6 @@
   @keras_parameterized.run_all_keras_modes
   def test_default_sample_weight(self):
     """Verifies that fit works without having to set sample_weight."""
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
-
     num_classes = 5
     input_dim = 5
     timesteps = 3
@@ -1967,7 +1962,7 @@
           loss='mse',
           sample_weight_mode=[None],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
       # sample_weight_mode is a list and mode value is `temporal`
@@ -1976,7 +1971,7 @@
           loss='mse',
           sample_weight_mode=['temporal'],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
       # sample_weight_mode is a dict and mode value is None
@@ -1985,7 +1980,7 @@
           loss='mse',
           sample_weight_mode={'time_distributed': None},
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
       # sample_weight_mode is a dict and mode value is `temporal`
@@ -1994,7 +1989,7 @@
           loss='mse',
           sample_weight_mode={'time_distributed': 'temporal'},
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
       # sample_weight_mode is a not a list/dict and mode value is None
@@ -2003,7 +1998,7 @@
           loss='mse',
           sample_weight_mode=None,
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
       # sample_weight_mode is a not a list/dict and mode value is `temporal`
@@ -2012,7 +2007,7 @@
           loss='mse',
           sample_weight_mode='temporal',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, batch_size=10)
 
   def test_sample_weight_tensor(self):
@@ -2092,7 +2087,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   @keras_parameterized.run_with_all_model_types
@@ -2138,7 +2133,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     y = np.random.random((5, 3))
     model.train_on_batch(x, y)
 
@@ -2157,7 +2152,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.trainable = True
     model.train_on_batch(x, y)
     self.assertRaises(Warning)
@@ -2173,7 +2168,7 @@
           'rmsprop',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       out = model.predict(x)
       model.train_on_batch(x, y)
       out_2 = model.predict(x)
@@ -2187,7 +2182,7 @@
           'rmsprop',
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       out = model.predict(x)
       model.train_on_batch(x, y)
       out_2 = model.predict(x)
@@ -2305,7 +2300,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs2 = keras.Input(10)
     outputs2 = shared_layer(inputs2)
@@ -2315,7 +2310,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x, y = np.ones((10, 10)), np.ones((10, 10))
 
@@ -2349,7 +2344,7 @@
         loss,
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = keras.backend.zeros(shape=(10, 3))
     targets = keras.backend.zeros(shape=(10, 4))
@@ -2403,7 +2398,7 @@
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         loss_weights=loss_weights,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_a_tf = keras.backend.zeros(shape=(10, 3))
     input_b_tf = keras.backend.zeros(shape=(10, 3))
@@ -2935,7 +2930,7 @@
         loss='mae',
         metrics=metrics,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     mse_metric = 'mse' if tf2.enabled() else 'mean_squared_error'
     reference_metric_names = [
@@ -2968,7 +2963,7 @@
         metrics=[acc_obj],
         optimizer=RMSPropOptimizer(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x_train = np.random.random((100, 4))
     y_train = np.random.random((100, 1))
@@ -3002,7 +2997,7 @@
         metrics=[keras.metrics.MeanSquaredError()],
         weighted_metrics=[keras.metrics.MeanSquaredError()],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # list of list of metrics.
     model.compile(
@@ -3019,7 +3014,7 @@
              keras.metrics.Accuracy()]
         ],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # dict of metrics.
     model.compile(
@@ -3042,12 +3037,10 @@
             ],
         },
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_invalid_metrics(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     num_classes = 5
     input_dim = 5
 
@@ -3062,7 +3055,7 @@
           loss='categorical_crossentropy',
           metrics=metrics_module.CategoricalAccuracy(),
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inp = keras.layers.Input(shape=(1,))
     x = keras.layers.Dense(3, activation='relu')(inp)
@@ -3087,7 +3080,7 @@
               'output_3': 'mse',
           },
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     with self.assertRaisesRegex(
         ValueError,
@@ -3101,7 +3094,7 @@
               'output_3': 'mse',
           },
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   @keras_parameterized.run_all_keras_modes
   def test_metrics_masking(self):
@@ -3119,7 +3112,7 @@
           loss='mse',
           weighted_metrics=['accuracy'],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       # verify that masking is applied.
       x = np.array([[[1], [1]], [[1], [1]], [[0], [0]]])
@@ -3156,7 +3149,7 @@
         'sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.ones(shape=(10, 1))
     targets = np.ones(shape=(10, 1))
@@ -3178,8 +3171,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_add_metric_in_model_call(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
 
     class TestModel(keras.Model):
 
@@ -3201,7 +3192,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
@@ -3244,7 +3235,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
@@ -3303,7 +3294,7 @@
         loss='mse',
         metrics=[metrics_module.Accuracy('metric_4')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # Verify that the metrics added using `compile` and `add_metric` API are
     # included
@@ -3331,7 +3322,7 @@
         optimizer=RMSPropOptimizer(0.01),
         metrics=[metrics_module.Accuracy('acc')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
     model.fit(x, y, epochs=2, batch_size=5, validation_data=(x, y))
@@ -3341,8 +3332,6 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_multiple_add_metric_calls(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
 
     class TestModel(keras.Model):
 
@@ -3364,7 +3353,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
@@ -3407,7 +3396,7 @@
           loss='mse',
           optimizer=RMSPropOptimizer(0.01),
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=2, batch_size=5, validation_data=(x, y))
 
   @keras_parameterized.run_all_keras_modes
@@ -3430,7 +3419,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
@@ -3458,7 +3447,7 @@
         loss='mse',
         optimizer=RMSPropOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.ones(shape=(10, 1))
     y = np.ones(shape=(10, 2))
 
@@ -3496,7 +3485,7 @@
         optimizer=keras.optimizer_v2.gradient_descent.SGD(0.1),
         metrics=[metrics_module.MeanAbsoluteError(name='mae_3')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.array([[0.], [1.], [2.]])
     y = np.array([[0.5], [2.], [3.5]])
@@ -3533,7 +3522,7 @@
         loss='mse',
         metrics=[metrics_module.Accuracy('acc')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     self.assertEqual([m.name for m in inner_model.metrics],
                      ['acc', 'mean', 'mean1'])
@@ -3549,7 +3538,7 @@
         loss='mse',
         metrics=[metrics_module.Accuracy('acc2')],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual([m.name for m in outer_model.metrics],
                      ['acc2', 'mean', 'mean1', 'mean2'])
 
@@ -3633,7 +3622,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=2, epochs=1)
     self.assertEqual(self.evaluate(layer.counter), 5)
 
@@ -3647,7 +3636,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=2, epochs=1)
     self.assertEqual(self.evaluate(layer.counter), 5)
     layer.trainable = False
@@ -3655,7 +3644,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=2, epochs=1)
     self.assertEqual(self.evaluate(layer.counter), 5)
 
@@ -3669,7 +3658,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=2, epochs=1)
     self.assertEqual(self.evaluate(layer.counter), 5)
 
@@ -3703,7 +3692,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x, y = np.ones((10, 10)), np.ones((10, 1))
     model.fit(x, y, batch_size=2, epochs=1)
     self.assertAllEqual(self.evaluate(bn.moving_mean), np.zeros((10,)))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index f4c2b26..b45bcbc 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -34,7 +34,6 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
-from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor_utils
 from tensorflow.python.framework import dtypes
@@ -54,6 +53,7 @@
 from tensorflow.python.ops.losses import util as tf_losses_utils
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -474,9 +474,14 @@
   Raises:
       ValueError: in case of improperly formatted user-provided data.
   """
+  try:
+    data_len = len(data)
+  except TypeError:
+    # For instance if data is `None` or a symbolic Tensor.
+    data_len = None
+
   if not names:
-    if (data is not None and hasattr(data, '__len__') and len(data) and
-        not isinstance(data, dict)):
+    if data_len and not isinstance(data, dict):
       raise ValueError(
           'Error when checking model ' + exception_prefix + ': '
           'expected no data, but got:', data)
@@ -1068,7 +1073,7 @@
     return loss
 
   # Deserialize loss configuration, if needed.
-  if isinstance(loss, collections.Mapping):
+  if isinstance(loss, collections_abc.Mapping):
     loss = losses.get(loss)
 
   # Custom callable class.
@@ -1126,6 +1131,24 @@
         'Received: x=%s, validation_split=%f' % (x, validation_split))
 
 
+def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
+  """Helper function to validate either inputs or targets."""
+  if isinstance(inp, (list, tuple)):
+    if not all(isinstance(v, np.ndarray) or
+               tensor_util.is_tensor(v) for v in inp):
+      raise ValueError(
+          'Please provide as model inputs either a single array or a list of '
+          'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))
+  elif isinstance(inp, dict):
+    if not allow_dict:
+      raise ValueError(
+          'You cannot pass a dictionary as model {}.'.format(field_name))
+  elif not isinstance(inp, np.ndarray) and not tensor_util.is_tensor(inp):
+    raise ValueError(
+        'Please provide as model inputs either a single array or a list of '
+        'arrays. You passed: {}={}'.format(field_name, orig_inp))
+
+
 def check_generator_arguments(y=None, sample_weight=None,
                               validation_split=None):
   """Validates arguments passed when using a generator."""
@@ -1191,13 +1214,41 @@
 
 
 def cast_single_tensor(x, dtype=None):
-  x = ops.convert_to_tensor(x)
+  if isinstance(x, np.ndarray):
+    x = ops.convert_to_tensor(x)
   dtype = dtype or K.floatx()
   if x.dtype.is_floating:
     return math_ops.cast(x, dtype=dtype)
   return x
 
 
+def cast_if_floating_dtype_and_mismatch(targets, outputs):
+  """Returns target data tensors using correct datatype.
+
+  Checks that each target and output pair are the same datatype. If not, casts
+  the target to the output's datatype.
+
+  Args:
+    targets: tensor or list of targets.
+    outputs: tensor or list of outputs.
+
+  Returns:
+    Targets in appropriate datatype.
+  """
+  if tensor_util.is_tensor(targets):
+    # There is one target, so output[0] should be the only output.
+    return cast_single_tensor(targets, dtype=outputs[0].dtype)
+  new_targets = []
+  for target, out in zip(targets, outputs):
+    if isinstance(target, np.ndarray):
+      target = ops.convert_to_tensor(target)
+    if target.dtype != out.dtype:
+      new_targets.append(cast_single_tensor(target, dtype=out.dtype))
+    else:
+      new_targets.append(target)
+  return new_targets
+
+
 def cast_if_floating_dtype(x):
   """Casts the given data tensors to the default floating point type.
 
@@ -1211,11 +1262,9 @@
   return nest.map_structure(cast_single_tensor, x)
 
 
-def cast_if_floating_to_model_input_dtypes(x, model):
+def cast_to_model_input_dtypes(x, model):
   """Casts the given data tensors to the dtypes of the model inputs.
 
-  Casts only if the input is already a floating point type.
-
   Args:
     x: tensor or list/tuple of tensors.
     model: The model.
@@ -1224,10 +1273,8 @@
     Converted input. Each tensor is casted to the corresponding input in
     `model.inputs`.
   """
-  # TODO(b/131372221): We should probably cast even if the input is not
-  # floating-point.
   input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs)
-  return nest.map_structure(cast_single_tensor, x, input_dtypes)
+  return nest.map_structure(math_ops.cast, x, input_dtypes)
 
 
 def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
@@ -1288,7 +1335,7 @@
       ValueError: If loss is a dict with keys not in model output names,
           or if loss is a list with len not equal to model outputs.
   """
-  if isinstance(loss, collections.Mapping):
+  if isinstance(loss, collections_abc.Mapping):
     generic_utils.check_for_unexpected_keys('loss', loss, output_names)
     loss_functions = []
     for name in output_names:
@@ -1300,7 +1347,7 @@
       loss_functions.append(get_loss_function(loss.get(name, None)))
   elif isinstance(loss, six.string_types):
     loss_functions = [get_loss_function(loss) for _ in output_names]
-  elif isinstance(loss, collections.Sequence):
+  elif isinstance(loss, collections_abc.Sequence):
     if len(loss) != len(output_names):
       raise ValueError('When passing a list as loss, it should have one entry '
                        'per model outputs. The model has {} outputs, but you '
@@ -1568,10 +1615,15 @@
   return x, y, weights
 
 
-def infer_steps_for_dataset(dataset, steps, epochs=1, steps_name='steps'):
+def infer_steps_for_dataset(model,
+                            dataset,
+                            steps,
+                            epochs=1,
+                            steps_name='steps'):
   """Infers steps_per_epoch needed to loop through a dataset.
 
   Arguments:
+      model: Keras model instance.
       dataset: Input data of type tf.data.Dataset.
       steps: Number of steps to draw from the dataset (may be None if unknown).
       epochs: Number of times to iterate over the dataset.
@@ -1581,14 +1633,15 @@
 
   Returns:
     Integer or `None`. Inferred number of steps to loop through the dataset.
-    `None` is returned if the size of the dataset is unknown and `steps` was
-    not specified.
+    `None` is returned if 1) the size of the dataset is unknown and `steps` was
+    not specified, or 2) this is multi-worker training and auto sharding is
+    enabled.
 
   Raises:
     ValueError: In case of invalid argument values.
   """
   assert isinstance(dataset, dataset_ops.DatasetV2)
-  if (multi_worker_util.in_multi_worker_mode() and
+  if (model._in_multi_worker_mode() and
       dataset.options().experimental_distribute.auto_shard):
     # If the dataset would be auto-sharded, we should not infer a local
     # steps_per_epoch due to the possible inbalanced sharding between workers.
@@ -1803,9 +1856,9 @@
       raise ValueError('`validation_freq` can not be less than 1.')
     return one_indexed_epoch % validation_freq == 0
 
-  if not isinstance(validation_freq, collections.Container):
+  if not isinstance(validation_freq, collections_abc.Container):
     raise ValueError('`validation_freq` must be an Integer or '
-                     '`collections.Container` (e.g. list, tuple, etc.)')
+                     '`collections_abc.Container` (e.g. list, tuple, etc.)')
   return one_indexed_epoch in validation_freq
 
 
diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py
index 7e89312..b559d56 100644
--- a/tensorflow/python/keras/engine/training_v2.py
+++ b/tensorflow/python/keras/engine/training_v2.py
@@ -35,6 +35,7 @@
 from tensorflow.python.keras.engine import training_v2_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
 
 
@@ -59,10 +60,10 @@
                   batch_size=None,
                   strategy=None,
                   steps_per_epoch=None,
+                  num_samples=None,
                   mode=ModeKeys.TRAIN,
                   training_context=None,
-                  total_epochs=None,
-                  partical_batch_size=None):
+                  total_epochs=None):
   """Run the execution function with the data from iterator.
 
   Given the dataset iterator and execution function, get the data from iterator
@@ -77,21 +78,18 @@
     batch_size: The size of the current batch.
     strategy: the distribution strategy instance from the model.
     steps_per_epoch: the number of steps to run for the epoch.
+    num_samples: the number of samples for the whole epoch if known. This can be
+      used to calculate the final partial batch, and scale the loss.
     mode: the mode for the current epoch.
     training_context: the context that contains callbacks and progress bar.
     total_epochs: the total number of epochs that will be run.
       Used when throw error when the iterator unexpectedly
       reaches its end.
-    partical_batch_size: the size of the final batch if it is already known. It
-      will be used to scale the loss value for the final batch.
   Returns:
     The loss and metric value from the model.
   """
   # Only use the sample to count if there is a partial batch at the end.
-  use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
-                   steps_per_epoch == dataset_size)
-  num_samples = None if use_steps else batch_size * (steps_per_epoch -
-                                                     1) + partical_batch_size
+  use_steps = num_samples is None
 
   if mode == ModeKeys.PREDICT:
     aggregator = training_utils.OutputsAggregator(
@@ -112,72 +110,74 @@
   step = 0
 
   while step < target_steps:
-    # TODO(scottzhu): Maybe update the training context to take into account
-    #  whether a batch of training happens. Then it could still use a
-    #  context manager
-    batch_logs = {'batch': step, 'size': 1}
-    training_context.callbacks._call_batch_hook(
-        mode, 'begin', step, batch_logs)
-    training_context.progbar.on_batch_begin(step, batch_logs)
-    try:
-      batch_outs = execution_function(iterator)
-    except (StopIteration, errors.OutOfRangeError):
-      # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?
-      # Are there any other C++ errors tf function should recapture?
-      # The only acceptable case here is that the input has a unknown
-      # length, and configured to fully consume it.
-      if (dataset_size is None
-          and steps_per_epoch is None
-          and step > 0):
-        # The input passed by the user ran out of batches.
-        # Now we know the cardinality of the input(dataset or generator).
-        steps_per_epoch = step
-        aggregator.steps = steps_per_epoch
-        progbar.params['steps'] = steps_per_epoch
-        progbar.progbar.target = steps_per_epoch
-      else:
-        callbacks.model.stop_training = True
-        logging.warning(
-            'Your input ran out of data; interrupting training. '
-            'Make sure that your dataset or generator can generate at '
-            'least `steps_per_epoch * epochs` batches (in this case, '
-            '{} batches). You may need to use the repeat() function '
-            'when building your dataset.'.format(
-                total_epochs * steps_per_epoch))
-      # In either case, break out the loop for training batch.
-      break
-
-    if not isinstance(batch_outs, list):
-      batch_outs = [batch_outs]
-    if strategy:
-      batch_outs = dist_utils._per_replica_aggregate_batch(
-          batch_outs, model, mode)
-
-    if step == 0:
-      aggregator.create(batch_outs)
-
     if use_steps:
-      aggregator.aggregate(batch_outs)
+      current_batch_size = 1
+    elif step < target_steps - 1:
+      current_batch_size = batch_size
     else:
-      aggregator.aggregate(
-          batch_outs,
-          batch_start=step * batch_size,
-          batch_end=min((step + 1) * batch_size, num_samples))
-    cbks.make_logs(model, batch_logs, batch_outs, mode)
+      current_batch_size = num_samples - step * batch_size
+    with training_context.on_batch(
+        step=step, mode=mode, size=current_batch_size) as batch_logs:
+      try:
+        batch_outs = execution_function(iterator)
+      except (StopIteration, errors.OutOfRangeError):
+        # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?
+        # Are there any other C++ errors tf function should recapture?
+        # The only acceptable case here is that the input has a unknown
+        # length, and configured to fully consume it.
+        if (dataset_size is None
+            and steps_per_epoch is None
+            and step > 0):
+          # The input passed by the user ran out of batches.
+          # Now we know the cardinality of the input(dataset or generator).
+          steps_per_epoch = step
+          aggregator.steps = steps_per_epoch
+          progbar.params['steps'] = steps_per_epoch
+          progbar.progbar.target = steps_per_epoch
+        else:
+          callbacks.model.stop_training = True
+          logging.warning(
+              'Your input ran out of data; interrupting training. '
+              'Make sure that your dataset or generator can generate at '
+              'least `steps_per_epoch * epochs` batches (in this case, '
+              '{} batches). You may need to use the repeat() function '
+              'when building your dataset.'.format(
+                  total_epochs * steps_per_epoch))
+        # In either case, break out the loop for training batch.
+        # Also note the training_context that data inputs are exhausted, so all
+        # the post batch hooks can be skipped.
+        batch_logs['data_exhausted'] = True
+        break
 
-    training_context.callbacks._call_batch_hook(
-        mode, 'end', step, batch_logs)
-    training_context.progbar.on_batch_end(step, batch_logs)
+      if mode != ModeKeys.PREDICT:
+        data_batch_size = batch_outs['batch_size']
+        batch_outs = (batch_outs['total_loss'] + batch_outs['output_losses']
+                      + batch_outs['metrics'])
+        if current_batch_size != data_batch_size:
+          batch_logs['size'] = data_batch_size
+          current_batch_size = data_batch_size
+      else:
+        batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
 
-    step += 1
+      if step == 0:
+        aggregator.create(batch_outs)
+
+      if use_steps:
+        aggregator.aggregate(batch_outs)
+      else:
+        aggregator.aggregate(
+            batch_outs,
+            batch_start=step * batch_size,
+            batch_end=step * batch_size + current_batch_size)
+      cbks.make_logs(model, batch_logs, batch_outs, mode)
+      step += 1
 
     if callbacks.model.stop_training:
       break
 
   # End of an epoch.
   aggregator.finalize()
-  results = aggregator.results
-  return results
+  return aggregator.results
 
 
 class Loop(training_utils.TrainingLoop):
@@ -216,6 +216,8 @@
           validation_steps=validation_steps,
           distribution_strategy=strategy)
 
+      total_samples = _get_total_number_of_samples(training_data_adapter)
+      use_sample = total_samples is not None
       do_validation = (validation_adapter is not None)
 
       if not steps_per_epoch:
@@ -232,13 +234,15 @@
       # is infinite.
       # TODO(scottzhu): This check should probably happen in the adapter
       training_utils.infer_steps_for_dataset(
-          training_dataset, steps_per_epoch, steps_name='steps_per_epoch',
+          model,
+          training_dataset,
+          steps_per_epoch,
+          steps_name='steps_per_epoch',
           epochs=0)
 
       training_dataset = strategy.experimental_distribute_dataset(
           training_dataset)
 
-      _update_sample_weight_mode(model, ModeKeys.TRAIN, training_dataset)
       training_function = training_v2_utils._get_or_make_execution_function(
           model, ModeKeys.TRAIN)
 
@@ -261,7 +265,10 @@
         # dataset is infinite.
         # TODO(scottzhu): This check should probably happen in the adapter
         training_utils.infer_steps_for_dataset(
-            validation_dataset, validation_steps, steps_name='validation_steps',
+            model,
+            validation_dataset,
+            validation_steps,
+            steps_name='validation_steps',
             epochs=0)
         validation_dataset = strategy.experimental_distribute_dataset(
             validation_dataset)
@@ -273,11 +280,13 @@
           batch_size=batch_size,
           epochs=epochs,
           steps_per_epoch=steps_per_epoch,
-          samples=None,
+          samples=total_samples,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=ModeKeys.TRAIN)
 
-      with training_context.on_start(model, callbacks, verbose, ModeKeys.TRAIN):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
         # TODO(scottzhu): Handle TPUStrategy training loop
         for epoch in range(initial_epoch, epochs):
           if training_context.callbacks.model.stop_training:
@@ -303,10 +312,10 @@
                 batch_size=training_data_adapter.batch_size(),
                 strategy=strategy,
                 steps_per_epoch=steps_per_epoch,
+                num_samples=total_samples,
                 mode=ModeKeys.TRAIN,
                 training_context=training_context,
-                total_epochs=epochs,
-                partical_batch_size=training_data_adapter.partial_batch_size())
+                total_epochs=epochs)
             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
 
             # Evaluation
@@ -321,9 +330,11 @@
               else:
                 eval_data_iter = iter(validation_dataset)
 
+              val_total_samples = _get_total_number_of_samples(
+                  validation_adapter)
               eval_context = TrainingContext()
               with eval_context.on_start(
-                  model, callbacks, verbose=0, mode=ModeKeys.TEST):
+                  model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
                 with eval_context.on_epoch(epoch, ModeKeys.TEST):
                   model.reset_metrics()
                   eval_result = run_one_epoch(
@@ -334,11 +345,10 @@
                       batch_size=validation_adapter.batch_size(),
                       strategy=strategy,
                       steps_per_epoch=validation_steps,
+                      num_samples=val_total_samples,
                       mode=ModeKeys.TEST,
                       training_context=eval_context,
-                      total_epochs=1,
-                      partical_batch_size=validation_adapter.partial_batch_size(
-                      ))
+                      total_epochs=1)
                   cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                                  prefix='val_')
 
@@ -365,6 +375,8 @@
           sample_weights=sample_weight,
           steps=steps,
           distribution_strategy=strategy)
+      total_samples = _get_total_number_of_samples(adapter)
+      use_sample = total_samples is not None
 
       if not steps:
         steps = adapter.get_size()
@@ -377,10 +389,9 @@
       # is infinite.
       # TODO(scottzhu): This check should probably happen in the adapter
       training_utils.infer_steps_for_dataset(
-          dataset, steps, steps_name='steps', epochs=0)
+          model, dataset, steps, steps_name='steps', epochs=0)
       dataset = strategy.experimental_distribute_dataset(dataset)
 
-      _update_sample_weight_mode(model, mode, dataset)
       execution_function = training_v2_utils._get_or_make_execution_function(
           model, mode)
 
@@ -393,11 +404,13 @@
           batch_size=batch_size,
           epochs=1,
           steps_per_epoch=steps,
-          samples=None,
+          samples=use_sample,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=mode)
 
-      with training_context.on_start(model, callbacks, verbose, mode):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, mode):
         # TODO(scottzhu): Handle TPUStrategy training loop
         with training_context.on_epoch(0, mode) as epoch_logs:
           model.reset_metrics()
@@ -409,10 +422,10 @@
               batch_size=adapter.batch_size(),
               strategy=strategy,
               steps_per_epoch=steps,
+              num_samples=total_samples,
               mode=mode,
               training_context=training_context,
-              total_epochs=1,
-              partical_batch_size=adapter.partial_batch_size())
+              total_epochs=1)
           cbks.make_logs(model, epoch_logs, result, mode)
 
     if len(result) == 1:
@@ -435,25 +448,13 @@
 
 def _get_distribution_strategy(model):
   """Get the model's distribution strategy."""
-  if model._distribution_strategy:
-    return model._distribution_strategy
+  if model._compile_time_distribution_strategy:
+    strategy = model._compile_time_distribution_strategy
   else:
-    # Use the default strategy if no strategy was present at compile.
-    # Validate there is no actual strategy scope active at execution
-    # time.
+    # Grab the active strategy if the model was never compiled
+    # but it is now predicting.
     strategy = distribution_strategy_context.get_strategy()
-    if distribution_strategy_context.has_strategy():
-      raise ValueError(
-          'Model was compiled without any active distribution strategy, '
-          'but there is an execution-time distribution '
-          'strategy scope of (%s). '
-          'Try to make sure your code looks similar to the following.\n'
-          'with strategy.scope():\n'
-          '  model=_create_model()\n'
-          '  model.compile(...)\n'
-          '  model.fit(...)'% strategy)
-
-    return strategy
+  return strategy
 
 
 def _process_training_inputs(model, x, y, batch_size=None,
@@ -537,48 +538,47 @@
         batch_size=batch_size,
         check_steps=True,
         steps=steps)
-    # TODO(scottzhu): The generator and keras.sequence does not work with
-    # model._standardize_user_data() so far. However that method is very
-    # important which contains on-fly model build/tensor align for dict input,
-    # etc. We should still call the _standardize_user_data with the peeked data
-    # from generator or sequence, and let model compile.
-  return adapter_cls(x, y, batch_size=batch_size,
-                     sample_weights=sample_weights, shuffle=shuffle,
-                     distribution_strategy=distribution_strategy)
+  adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
+                        sample_weights=sample_weights, shuffle=shuffle,
+                        distribution_strategy=distribution_strategy)
+  # As a fallback for the data type that does not work with
+  # _standardize_user_data, use the _prepare_model_with_inputs.
+  if adapter_cls not in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
+    training_v2_utils._prepare_model_with_inputs(model, adapter.get_dataset())
+  return adapter
 
 
-def _update_sample_weight_mode(model, mode, dataset):
-  """Updates the sample_weight_mode of a given model."""
-  # TODO(kaftan): This won't actually do anything right now because
-  ## dist_utils._update_sample_weight_modes only does things when the model
-  ## is distributed by cloning. We will need to revisit if a method here
-  ## is needed at all, and if so how it should look.
-  # Add a quick return to prevent us from calling model._feed_targets that
-  # accesses certain model properties that may not be set in the `PREDICT` mode.
-  if mode == ModeKeys.PREDICT:
-    return
+def _get_total_number_of_samples(adapter):
+  if not adapter.get_size() or not adapter.batch_size():
+    return None
+  total_sample = adapter.get_size() * adapter.batch_size()
+  if adapter.has_partial_batch():
+    total_sample -= (adapter.batch_size() - adapter.partial_batch_size())
+  return total_sample
 
-  # Get some sample inputs from the data_adapter
-  iterator = iter(dataset)
-  _, _, sample_weights = training_v2_utils._prepare_feed_values(
-      model, iterator, mode)
 
-  # Call the DistributionStrategy specific function to update the
-  # sample_weight_mode on the model.
-  dist_utils._update_sample_weight_modes(model, mode, sample_weights)
-
-  # Force delete the iterator.
-  del iterator
+def _aggregate_predict_results(strategy, batch_outs, model):
+  if not isinstance(batch_outs, list):
+    batch_outs = [batch_outs]
+  total_batch_outs = []
+  for i in range(len(model.outputs)):
+    num_replicas = strategy.num_replicas_in_sync
+    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
+    total_batch_outs.append(
+        dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
+  return total_batch_outs
 
 
 class TrainingContext(object):
   """Utility object that wrap around callbacks and progress bars."""
 
   @tf_contextlib.contextmanager
-  def on_start(self, model, callbacks=None, verbose=0, mode=ModeKeys.TRAIN):
+  def on_start(self, model, callbacks=None, use_samples=False, verbose=0,
+               mode=ModeKeys.TRAIN):
     """Provide a scope for the whole training process."""
     # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
-    progbar = training_utils.get_progbar(model, 'steps')
+    progbar = training_utils.get_progbar(
+        model, 'samples' if use_samples else 'steps')
     progbar.params = callbacks.params
     progbar.params['verbose'] = verbose
     callbacks.model.stop_training = False
@@ -611,15 +611,16 @@
       self.progbar.on_epoch_end(epoch, epoch_logs)
 
   @tf_contextlib.contextmanager
-  def on_batch(self, step=0, mode=ModeKeys.TRAIN):
+  def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
     """Provide a scope for running one batch."""
-    batch_logs = {'batch': step, 'size': 1}
+    batch_logs = {'batch': step, 'size': size}
     self.callbacks._call_batch_hook(
         mode, 'begin', step, batch_logs)
     self.progbar.on_batch_begin(step, batch_logs)
     try:
       yield batch_logs
     finally:
-      self.callbacks._call_batch_hook(
-          mode, 'end', step, batch_logs)
-      self.progbar.on_batch_end(step, batch_logs)
+      if not batch_logs.pop('data_exhausted', False):
+        self.callbacks._call_batch_hook(
+            mode, 'end', step, batch_logs)
+        self.progbar.on_batch_end(step, batch_logs)
diff --git a/tensorflow/python/keras/engine/training_v2_utils.py b/tensorflow/python/keras/engine/training_v2_utils.py
index ec89849..86d3ad8 100644
--- a/tensorflow/python/keras/engine/training_v2_utils.py
+++ b/tensorflow/python/keras/engine/training_v2_utils.py
@@ -29,11 +29,14 @@
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
 from tensorflow.python.keras.engine import training_eager
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
+from tensorflow.python.ops import array_ops
+from tensorflow.python.util import nest
 
 
 def _get_or_make_execution_function(model, mode):
@@ -67,8 +70,8 @@
     outputs = strategy.experimental_run_v2(
         per_replica_function, args=(model, x, y, sample_weights))
     # Out of PerReplica outputs reduce or pick values to return.
-    all_outputs = dist_utils.unwrap_outputs(
-        strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
+    all_outputs = dist_utils.unwrap_output_dict(
+        strategy, outputs, mode)
     return all_outputs
 
   if not model.run_eagerly:
@@ -77,7 +80,8 @@
 
   def execution_function(input_fn):
     # `numpy` translates Tensors to values in Eager mode.
-    return [out.numpy() for out in distributed_function(input_fn)]
+    return nest.map_structure(_non_none_constant_value,
+                              distributed_function(input_fn))
 
   return execution_function
 
@@ -125,7 +129,8 @@
   """Get elements from the iterator and verify the input shape and type."""
   next_element = next(iterator)
 
-  if tensor_util.is_tensor(next_element) or isinstance(next_element, dict):
+  if (tensor_util.is_tensor(next_element) or
+      isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
     next_element = [next_element]
   if len(next_element) == 1:
     x, = next_element
@@ -164,6 +169,29 @@
   return func
 
 
+def _prepare_model_with_inputs(model, dataset):
+  """Use the data from the adapter to config the model.
+
+  Model need to be properly configured before training, eg build with inputs, or
+  compile with inputs for subclass model.
+
+  Args:
+    model: a Keras model object.
+    dataset: a eager dataset instance where the data will be extracted.
+  """
+  if not model.inputs:
+    inputs, target, _ = model._build_model_with_inputs(dataset, targets=None)
+  else:
+    inputs, target, _ = _get_input_from_iterator(iter(dataset))
+
+  if not model._is_compiled and model.optimizer:
+    model._compile_from_inputs(inputs, target, dataset, None)
+
+  if target is not None:
+    training_utils.prepare_sample_weight_modes(model._training_endpoints,
+                                               model.sample_weight_mode)
+
+
 def train_on_batch(
     model,
     x,
@@ -221,7 +249,7 @@
   x, y, sample_weights = model._standardize_user_data(
       x, y, sample_weight=sample_weight, class_weight=class_weight,
       extract_tensors_from_dataset=True)
-
+  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
   # If `model._distribution_strategy` is True, then we are in a replica context
   # at this point because of the check above.  `train_on_batch` is being run
   # for each replica by `model._distribution_strategy` and the same code path
@@ -236,6 +264,7 @@
   if reset_metrics:
     model.reset_metrics()
 
+  outputs['batch_size'] = batch_size
   return outputs
 
 
@@ -287,6 +316,7 @@
   x, y, sample_weights = model._standardize_user_data(
       x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
 
+  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
   outputs = training_eager.test_on_batch(
       model,
       x,
@@ -297,6 +327,7 @@
   if reset_metrics:
     model.reset_metrics()
 
+  outputs['batch_size'] = batch_size
   return outputs
 
 
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index 2bb030d..8510db9 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -56,8 +56,6 @@
 class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
 
   def test_vector_classification(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     np.random.seed(1337)
     (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=100,
@@ -76,7 +74,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(x_train, y_train, epochs=10, batch_size=10,
                         validation_data=(x_train, y_train),
                         verbose=2)
@@ -113,7 +111,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     if not testing_utils.should_run_eagerly():
       self.assertEqual(len(model.get_losses_for(None)), 2)
       self.assertEqual(len(model.get_updates_for(x)), 2)
@@ -154,7 +152,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x_train, y_train, epochs=1, batch_size=10,
               validation_data=(x_train, y_train),
               verbose=2)
@@ -177,7 +175,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(x_train, y_train, epochs=10, batch_size=10,
                         validation_data=(x_train, y_train),
                         verbose=2)
@@ -195,9 +193,6 @@
 
   @keras_parameterized.run_with_all_model_types
   def test_timeseries_classification(self):
-    if testing_utils.should_run_distributed():
-      # Test timeout, seems to be a performance issue.
-      self.skipTest('b/137397816')
     np.random.seed(1337)
     (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=100,
@@ -217,7 +212,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(x_train, y_train, epochs=15, batch_size=10,
                         validation_data=(x_train, y_train),
                         verbose=2)
@@ -228,9 +223,6 @@
     self.assertEqual(predictions.shape, (x_train.shape[0], 2))
 
   def test_timeseries_classification_sequential_tf_rnn(self):
-    if testing_utils.should_run_distributed():
-      # Test timeout, seems to be a performance issue.
-      self.skipTest('b/137397816')
     np.random.seed(1337)
     (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=100,
@@ -250,7 +242,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(x_train, y_train, epochs=15, batch_size=10,
                         validation_data=(x_train, y_train),
                         verbose=2)
@@ -266,8 +258,6 @@
 class ImageClassificationIntegrationTest(keras_parameterized.TestCase):
 
   def test_image_classification(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     np.random.seed(1337)
     (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=100,
@@ -291,7 +281,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     history = model.fit(x_train, y_train, epochs=10, batch_size=10,
                         validation_data=(x_train, y_train),
                         verbose=2)
@@ -336,7 +326,7 @@
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x_train, y_train, epochs=2, batch_size=10,
               validation_data=(x_train, y_train),
               verbose=2)
diff --git a/tensorflow/python/keras/keras_parameterized.py b/tensorflow/python/keras/keras_parameterized.py
index 89fbbff..0beb4d4 100644
--- a/tensorflow/python/keras/keras_parameterized.py
+++ b/tensorflow/python/keras/keras_parameterized.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import functools
 import itertools
 import unittest
@@ -31,6 +30,7 @@
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 
 class TestCase(test.TestCase, parameterized.TestCase):
@@ -204,9 +204,10 @@
       optimizer = RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
       metrics = ['mae']
-      model.compile(optimizer, loss, metrics=metrics,
-                    run_eagerly=testing_utils.should_run_eagerly(),
-                    run_distributed=testing_utils.should_run_distributed())
+      model.compile(
+          optimizer, loss, metrics=metrics,
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       inputs = np.zeros((10, 3))
       targets = np.zeros((10, 4))
@@ -243,12 +244,11 @@
       a target dependency.
   """
 
-  params = [('_v2_function', 'v2_function'),
-            ('_v2_distributed', 'v2_distributed')]
+  params = [('_v2_function', 'v2_function'), ('_v2_funcgraph', 'v2_funcgraph')]
   if not always_skip_eager:
     params.append(('_v2_eager', 'v2_eager'))
   if not (always_skip_v1 or tf2.enabled()):
-    params.append(('_v1_graph', 'v1_graph'))
+    params.append(('_v1_session', 'v1_session'))
 
   def single_method_decorator(f):
     """Decorator that constructs the test cases."""
@@ -258,14 +258,14 @@
     @functools.wraps(f)
     def decorated(self, run_mode, *args, **kwargs):
       """A run of a single test case w/ specified run mode."""
-      if run_mode == 'v1_graph':
-        _v1_graph_test(f, self, config, *args, **kwargs)
-      elif run_mode == 'v2_function':
+      if run_mode == 'v1_session':
+        _v1_session_test(f, self, config, *args, **kwargs)
+      elif run_mode == 'v2_funcgraph':
         _v2_graph_functions_test(f, self, *args, **kwargs)
       elif run_mode == 'v2_eager':
         _v2_eager_test(f, self, *args, **kwargs)
-      elif run_mode == 'v2_distributed':
-        _v2_distributed_test(f, self, *args, **kwargs)
+      elif run_mode == 'v2_function':
+        _v2_function_test(f, self, *args, **kwargs)
       else:
         return ValueError('Unknown run mode %s' % run_mode)
 
@@ -274,9 +274,9 @@
   return _test_or_class_decorator(test_or_class, single_method_decorator)
 
 
-def _v1_graph_test(f, test_or_class, config, *args, **kwargs):
+def _v1_session_test(f, test_or_class, config, *args, **kwargs):
   with context.graph_mode(), testing_utils.run_eagerly_scope(False):
-    with testing_utils.run_distributed_scope(False):
+    with testing_utils.experimental_run_tf_function_scope(False):
       with test_or_class.test_session(use_gpu=True, config=config):
         f(test_or_class, *args, **kwargs)
 
@@ -284,21 +284,21 @@
 def _v2_graph_functions_test(f, test_or_class, *args, **kwargs):
   with context.eager_mode():
     with testing_utils.run_eagerly_scope(False):
-      with testing_utils.run_distributed_scope(False):
+      with testing_utils.experimental_run_tf_function_scope(False):
         f(test_or_class, *args, **kwargs)
 
 
 def _v2_eager_test(f, test_or_class, *args, **kwargs):
   with context.eager_mode():
     with testing_utils.run_eagerly_scope(True):
-      with testing_utils.run_distributed_scope(False):
+      with testing_utils.experimental_run_tf_function_scope(True):
         f(test_or_class, *args, **kwargs)
 
 
-def _v2_distributed_test(f, test_or_class, *args, **kwargs):
+def _v2_function_test(f, test_or_class, *args, **kwargs):
   with context.eager_mode():
     with testing_utils.run_eagerly_scope(False):
-      with testing_utils.run_distributed_scope(True):
+      with testing_utils.experimental_run_tf_function_scope(True):
         f(test_or_class, *args, **kwargs)
 
 
@@ -326,7 +326,7 @@
     The decorated result.
   """
   def _decorate_test_or_class(obj):
-    if isinstance(obj, collections.Iterable):
+    if isinstance(obj, collections_abc.Iterable):
       return itertools.chain.from_iterable(
           single_method_decorator(method) for method in obj)
     if isinstance(obj, type):
diff --git a/tensorflow/python/keras/keras_parameterized_test.py b/tensorflow/python/keras/keras_parameterized_test.py
index d08ef4f..0017fcb 100644
--- a/tensorflow/python/keras/keras_parameterized_test.py
+++ b/tensorflow/python/keras/keras_parameterized_test.py
@@ -210,21 +210,21 @@
       def testBody(self):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed))
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function))
 
     e = ExampleTest()
     if not tf2.enabled():
-      e.testBody_v1_graph()
+      e.testBody_v1_session()
     e.testBody_v2_eager()
+    e.testBody_v2_funcgraph()
     e.testBody_v2_function()
-    e.testBody_v2_distributed()
 
     if not tf2.enabled():
       self.assertLen(l, 4)
       self.assertAllEqual(l, [
           ("graph", False, False),
-          ("eager", True, False),
+          ("eager", True, True),
           ("eager", False, False),
           ("eager", False, True),
       ])
@@ -236,7 +236,7 @@
     else:
       self.assertLen(l, 3)
       self.assertAllEqual(l, [
-          ("eager", True, False),
+          ("eager", True, True),
           ("eager", False, False),
           ("eager", False, True),
       ])
@@ -262,27 +262,27 @@
         mode = "eager" if context.executing_eagerly() else "graph"
         with_brackets = "with_brackets" if with_brackets else "without_brackets"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((with_brackets, mode, should_run_eagerly,
-                  should_run_distributed))
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append(
+            (with_brackets, mode, should_run_eagerly, should_run_tf_function))
 
     e = ExampleTest()
     if not tf2.enabled():
-      e.testBody_0_v1_graph()
-      e.testBody_1_v1_graph()
+      e.testBody_0_v1_session()
+      e.testBody_1_v1_session()
 
     e.testBody_0_v2_eager()
+    e.testBody_0_v2_funcgraph()
     e.testBody_0_v2_function()
-    e.testBody_0_v2_distributed()
     e.testBody_1_v2_eager()
+    e.testBody_1_v2_funcgraph()
     e.testBody_1_v2_function()
-    e.testBody_1_v2_distributed()
 
     expected_combinations = {
-        ("with_brackets", "eager", True, False),
+        ("with_brackets", "eager", True, True),
         ("with_brackets", "eager", False, False),
         ("with_brackets", "eager", False, True),
-        ("without_brackets", "eager", True, False),
+        ("without_brackets", "eager", True, True),
         ("without_brackets", "eager", False, False),
         ("without_brackets", "eager", False, True),
     }
@@ -314,25 +314,26 @@
       def testBody(self):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed))
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function))
 
     e = ExampleTest()
-    if hasattr(e, "testBody_v1_graph"):
-      e.testBody_v1_graph()
+    if hasattr(e, "testBody_v1_session"):
+      e.testBody_v1_session()
     if hasattr(e, "testBody_v2_eager"):
       e.testBody_v2_eager()
+    if hasattr(e, "testBody_v2_funcgraph"):
+      e.testBody_v2_funcgraph()
     if hasattr(e, "testBody_v2_function"):
       e.testBody_v2_function()
-    if hasattr(e, "testBody_v2_distributed"):
-      e.testBody_v2_distributed()
 
     self.assertLen(l, 3)
-    self.assertEqual(set(l), {
-        ("eager", True, False),
-        ("eager", False, False),
-        ("eager", False, True),
-    })
+    self.assertEqual(
+        set(l), {
+            ("eager", True, True),
+            ("eager", False, False),
+            ("eager", False, True),
+        })
 
   def test_run_all_keras_modes_with_all_model_types(self):
     l = []
@@ -347,34 +348,34 @@
       def testBody(self):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed,
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function,
                   testing_utils.get_model_type()))
 
     e = ExampleTest()
     e.testBody_v2_eager_functional()
+    e.testBody_v2_funcgraph_functional()
     e.testBody_v2_function_functional()
-    e.testBody_v2_distributed_functional()
     e.testBody_v2_eager_sequential()
+    e.testBody_v2_funcgraph_sequential()
     e.testBody_v2_function_sequential()
-    e.testBody_v2_distributed_sequential()
     e.testBody_v2_eager_subclass()
+    e.testBody_v2_funcgraph_subclass()
     e.testBody_v2_function_subclass()
-    e.testBody_v2_distributed_subclass()
 
     if not tf2.enabled():
-      e.testBody_v1_graph_functional()
-      e.testBody_v1_graph_sequential()
-      e.testBody_v1_graph_subclass()
+      e.testBody_v1_session_functional()
+      e.testBody_v1_session_sequential()
+      e.testBody_v1_session_subclass()
 
     expected_combinations = {
-        ("eager", True, False, "functional"),
+        ("eager", True, True, "functional"),
         ("eager", False, False, "functional"),
         ("eager", False, True, "functional"),
-        ("eager", True, False, "sequential"),
+        ("eager", True, True, "sequential"),
         ("eager", False, False, "sequential"),
         ("eager", False, True, "sequential"),
-        ("eager", True, False, "subclass"),
+        ("eager", True, True, "subclass"),
         ("eager", False, False, "subclass"),
         ("eager", False, True, "subclass"),
     }
@@ -408,34 +409,34 @@
       def testBody(self):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed,
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function,
                   testing_utils.get_model_type()))
 
     e = ExampleTest()
     e.testBody_functional_v2_eager()
+    e.testBody_functional_v2_funcgraph()
     e.testBody_functional_v2_function()
-    e.testBody_functional_v2_distributed()
     e.testBody_sequential_v2_eager()
+    e.testBody_sequential_v2_funcgraph()
     e.testBody_sequential_v2_function()
-    e.testBody_sequential_v2_distributed()
     e.testBody_subclass_v2_eager()
+    e.testBody_subclass_v2_funcgraph()
     e.testBody_subclass_v2_function()
-    e.testBody_subclass_v2_distributed()
 
     if not tf2.enabled():
-      e.testBody_functional_v1_graph()
-      e.testBody_sequential_v1_graph()
-      e.testBody_subclass_v1_graph()
+      e.testBody_functional_v1_session()
+      e.testBody_sequential_v1_session()
+      e.testBody_subclass_v1_session()
 
     expected_combinations = {
-        ("eager", True, False, "functional"),
+        ("eager", True, True, "functional"),
         ("eager", False, False, "functional"),
         ("eager", False, True, "functional"),
-        ("eager", True, False, "sequential"),
+        ("eager", True, True, "sequential"),
         ("eager", False, False, "sequential"),
         ("eager", False, True, "sequential"),
-        ("eager", True, False, "subclass"),
+        ("eager", True, True, "subclass"),
         ("eager", False, False, "subclass"),
         ("eager", False, True, "subclass"),
     }
@@ -471,34 +472,34 @@
       def testBody(self, arg):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed,
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function,
                   testing_utils.get_model_type()))
 
     e = ExampleTest()
     e.testBody_arg_v2_eager_functional()
+    e.testBody_arg_v2_funcgraph_functional()
     e.testBody_arg_v2_function_functional()
-    e.testBody_arg_v2_distributed_functional()
     e.testBody_arg_v2_eager_sequential()
+    e.testBody_arg_v2_funcgraph_sequential()
     e.testBody_arg_v2_function_sequential()
-    e.testBody_arg_v2_distributed_sequential()
     e.testBody_arg_v2_eager_subclass()
+    e.testBody_arg_v2_funcgraph_subclass()
     e.testBody_arg_v2_function_subclass()
-    e.testBody_arg_v2_distributed_subclass()
 
     if not tf2.enabled():
-      e.testBody_arg_v1_graph_functional()
-      e.testBody_arg_v1_graph_sequential()
-      e.testBody_arg_v1_graph_subclass()
+      e.testBody_arg_v1_session_functional()
+      e.testBody_arg_v1_session_sequential()
+      e.testBody_arg_v1_session_subclass()
 
     expected_combinations = {
-        ("eager", True, False, "functional"),
+        ("eager", True, True, "functional"),
         ("eager", False, False, "functional"),
         ("eager", False, True, "functional"),
-        ("eager", True, False, "sequential"),
+        ("eager", True, True, "sequential"),
         ("eager", False, False, "sequential"),
         ("eager", False, True, "sequential"),
-        ("eager", True, False, "subclass"),
+        ("eager", True, True, "subclass"),
         ("eager", False, False, "subclass"),
         ("eager", False, True, "subclass"),
     }
@@ -534,34 +535,34 @@
       def testBody(self, arg):
         mode = "eager" if context.executing_eagerly() else "graph"
         should_run_eagerly = testing_utils.should_run_eagerly()
-        should_run_distributed = testing_utils.should_run_distributed()
-        l.append((mode, should_run_eagerly, should_run_distributed,
+        should_run_tf_function = testing_utils.should_run_tf_function()
+        l.append((mode, should_run_eagerly, should_run_tf_function,
                   testing_utils.get_model_type()))
 
     e = ExampleTest()
     e.testBody_arg_v2_eager_functional()
+    e.testBody_arg_v2_funcgraph_functional()
     e.testBody_arg_v2_function_functional()
-    e.testBody_arg_v2_distributed_functional()
     e.testBody_arg_v2_eager_sequential()
+    e.testBody_arg_v2_funcgraph_sequential()
     e.testBody_arg_v2_function_sequential()
-    e.testBody_arg_v2_distributed_sequential()
     e.testBody_arg_v2_eager_subclass()
+    e.testBody_arg_v2_funcgraph_subclass()
     e.testBody_arg_v2_function_subclass()
-    e.testBody_arg_v2_distributed_subclass()
 
     if not tf2.enabled():
-      e.testBody_arg_v1_graph_functional()
-      e.testBody_arg_v1_graph_sequential()
-      e.testBody_arg_v1_graph_subclass()
+      e.testBody_arg_v1_session_functional()
+      e.testBody_arg_v1_session_sequential()
+      e.testBody_arg_v1_session_subclass()
 
     expected_combinations = {
-        ("eager", True, False, "functional"),
+        ("eager", True, True, "functional"),
         ("eager", False, False, "functional"),
         ("eager", False, True, "functional"),
-        ("eager", True, False, "sequential"),
+        ("eager", True, True, "sequential"),
         ("eager", False, False, "sequential"),
         ("eager", False, True, "sequential"),
-        ("eager", True, False, "subclass"),
+        ("eager", True, True, "subclass"),
         ("eager", False, False, "subclass"),
         ("eager", False, True, "subclass"),
     }
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 07caba4..34e0ade 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -98,7 +98,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2)
 
 
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index dcece7f..9e06c4c 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -176,25 +176,24 @@
       self.bias = None
     self.input_spec = InputSpec(ndim=self.rank + 2,
                                 axes={channel_axis: input_dim})
-    self._convolution_op = None
-    self.built = True
-
-  def call(self, inputs):
     if self.padding == 'causal':
       op_padding = 'valid'
     else:
       op_padding = self.padding
     if not isinstance(op_padding, (list, tuple)):
       op_padding = op_padding.upper()
-    if self._convolution_op is None:
-      self._convolution_op = nn_ops.Convolution(
-          inputs.shape,
-          filter_shape=self.kernel.shape,
-          dilation_rate=self.dilation_rate,
-          strides=self.strides,
-          padding=op_padding,
-          data_format=conv_utils.convert_data_format(self.data_format,
-                                                     self.rank + 2))
+
+    self._convolution_op = nn_ops.Convolution(
+        input_shape,
+        filter_shape=self.kernel.shape,
+        dilation_rate=self.dilation_rate,
+        strides=self.strides,
+        padding=op_padding,
+        data_format=conv_utils.convert_data_format(self.data_format,
+                                                   self.rank + 2))
+    self.built = True
+
+  def call(self, inputs):
     outputs = self._convolution_op(inputs, self.kernel)
 
     if self.use_bias:
@@ -1786,6 +1785,7 @@
     if len(input_shape) < 4:
       raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
                        'Received input shape:', str(input_shape))
+    input_shape = tensor_shape.TensorShape(input_shape)
     if self.data_format == 'channels_first':
       channel_axis = 1
     else:
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py
index fe2994c..b1e30ac 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent.py
@@ -921,7 +921,8 @@
                           recurrent_constraint=recurrent_constraint,
                           bias_constraint=bias_constraint,
                           dropout=dropout,
-                          recurrent_dropout=recurrent_dropout)
+                          recurrent_dropout=recurrent_dropout,
+                          dtype=kwargs.get('dtype'))
     super(ConvLSTM2D, self).__init__(cell,
                                      return_sequences=return_sequences,
                                      go_backwards=go_backwards,
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 1c2d8ca..df78cff 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -1043,11 +1043,7 @@
         output_shape = shape[:-1] + [self.units]
         outputs.set_shape(output_shape)
     else:
-      # Cast the inputs to self.dtype, which is the variable dtype. We do not
-      # cast if `should_cast_variables` is True, as in that case the variable
-      # will be automatically casted to inputs.dtype.
-      if not self._mixed_precision_policy.should_cast_variables:
-        inputs = math_ops.cast(inputs, self.dtype)
+      inputs = math_ops.cast(inputs, self._compute_dtype)
       if K.is_sparse(inputs):
         outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
       else:
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 992cb41..c943705 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -154,7 +154,7 @@
     def lambda_fn(x):
       return math_ops.matmul(x[0], x[1])
 
-    l = keras.layers.Lambda(lambda_fn)
+    l = keras.layers.Lambda(lambda_fn, dtype=dtypes.float64)
     output_shape = l.compute_output_shape([(10, 10), (10, 20)])
     self.assertAllEqual((10, 20), output_shape)
     output_signature = l.compute_output_signature([
@@ -289,7 +289,7 @@
         keras.optimizer_v2.gradient_descent.SGD(0.1),
         'mae',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x, y = np.ones((10, 10), 'float32'), 2 * np.ones((10, 10), 'float32')
     model.fit(x, y, batch_size=2, epochs=2, validation_data=(x, y))
     self.assertLen(model.trainable_weights, 1)
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 12bde97..bd266b5 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -87,7 +87,7 @@
     self.assertEqual(len(state), num_states)
     model = keras.models.Model(inputs, state[0])
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     inputs = np.random.random((num_samples, timesteps, input_size))
     state = model.predict(inputs)
@@ -146,7 +146,7 @@
         loss='categorical_crossentropy',
         optimizer=RMSprop(learning_rate=0.001),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.random.random((num_samples, timesteps, input_size))
     initial_state = [
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 85285db..fea1901 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -104,6 +104,11 @@
       else:
         kwargs['input_shape'] = (None,)
     dtype = kwargs.pop('dtype', K.floatx())
+    # We set autocast to False, as we do not want to cast floating- point inputs
+    # to self.dtype. In call(), we cast to int32, and casting to self.dtype
+    # before casting to int32 might cause the int32 values to be different due
+    # to a loss of precision.
+    kwargs['autocast'] = False
     super(Embedding, self).__init__(dtype=dtype, **kwargs)
 
     self.input_dim = input_dim
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index 8545941..f49cbe4 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -80,7 +80,7 @@
 
     layer.set_weights([np.array([[1, 1], [2, 2]])])
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
     outputs = model.predict(np.array([[0, 1, 0]], dtype='int32'))
     self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]])
 
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index 0383db0..cf32b80 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -43,6 +43,19 @@
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_float64_GRU(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        keras.layers.GRU,
+        kwargs={'units': units,
+                'return_sequences': True,
+                'dtype': 'float64'},
+        input_shape=(num_samples, timesteps, embedding_dim),
+        input_dtype='float64')
+
   def test_dynamic_behavior_GRU(self):
     num_samples = 2
     timesteps = 3
@@ -55,7 +68,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x = np.random.random((num_samples, timesteps, embedding_dim))
     y = np.random.random((num_samples, units))
     model.train_on_batch(x, y)
@@ -106,7 +119,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     gru_model.fit(x_train, y_train)
     gru_model.predict(x_train)
 
@@ -122,7 +135,7 @@
         loss='categorical_crossentropy',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
 
   def test_statefulness_GRU(self):
@@ -147,7 +160,7 @@
         optimizer='sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     out1 = model.predict(np.ones((num_samples, timesteps)))
     self.assertEqual(out1.shape, (num_samples, units))
 
diff --git a/tensorflow/python/keras/layers/gru_v2_test.py b/tensorflow/python/keras/layers/gru_v2_test.py
index 29c45fc..0a58879 100644
--- a/tensorflow/python/keras/layers/gru_v2_test.py
+++ b/tensorflow/python/keras/layers/gru_v2_test.py
@@ -27,6 +27,7 @@
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -136,7 +137,6 @@
       l2 = layer_class.from_config(l1.get_config())
       assert l1.get_config() == l2.get_config()
 
-  # Due to b/120160788
   @test_util.run_v2_only
   def test_gru_v2_feature_parity_with_canonical_gru(self):
     input_shape = 10
@@ -259,8 +259,8 @@
       canonical_model.set_weights(weights)
       y_3 = canonical_model.predict(x_train)
 
-    self.assertAllClose(y_1, y_2)
-    self.assertAllClose(y_2, y_3)
+    self.assertAllClose(y_1, y_2, rtol=1e-5, atol=1e-5)
+    self.assertAllClose(y_2, y_3, rtol=1e-5, atol=1e-5)
 
   @parameterized.named_parameters(
       # test_name, time_major, go_backwards
@@ -342,6 +342,19 @@
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_float64_GRU(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        rnn.GRU,
+        kwargs={'units': units,
+                'return_sequences': True,
+                'dtype': 'float64'},
+        input_shape=(num_samples, timesteps, embedding_dim),
+        input_dtype='float64')
+
   def test_return_states_GRU(self):
     layer_class = rnn.GRU
     x = np.random.random((2, 3, 4))
@@ -422,8 +435,6 @@
     else:
       self.assertEqual(len(layer.get_losses_for(x)), 1)
 
-  # Run in V2 only due to b/120160788.
-  @test_util.run_v2_only
   def test_statefulness_GRU(self):
     num_samples = 2
     timesteps = 3
@@ -445,7 +456,7 @@
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     out1 = model.predict(np.ones((num_samples, timesteps)))
     self.assertEqual(out1.shape, (num_samples, units))
 
@@ -518,7 +529,7 @@
         optimizer='adam',
         loss='sparse_categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, epochs=1, shuffle=False)
 
   @test_util.run_v2_only
@@ -546,6 +557,23 @@
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
 
+  @test_util.run_deprecated_v1
+  def test_v1_session_behavior(self):
+    # See b/139132348 for more details.
+    x = np.random.uniform(size=(100, 4, 8))
+    y = np.random.uniform(size=(100, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        (x, y)).shuffle(100).batch(32)
+
+    inp = keras.layers.Input(shape=(4, 8))
+    layer = rnn.GRU(1)(inp)
+    layer = keras.layers.Dense(1)(layer)
+
+    model = keras.models.Model(inp, layer)
+
+    model.compile(loss='mse', optimizer='sgd')
+    model.fit(dataset)
+
 
 class GRULayerGradientTapeTest(test.TestCase):
 
@@ -593,9 +621,10 @@
         num_classes=self.output_shape)
     y_train = keras.utils.to_categorical(y_train, self.output_shape)
 
-    model.compile(optimizer='sgd',
-                  loss=['categorical_crossentropy', None],
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        optimizer='sgd',
+        loss=['categorical_crossentropy', None],
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     existing_loss = 0
     for _ in range(self.epoch):
@@ -611,6 +640,7 @@
     else:
       self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
 
+  @test_util.run_v2_only
   def test_GRU_runtime(self):
     layer = rnn.GRU(self.rnn_state_size, return_runtime=True)
 
@@ -626,6 +656,7 @@
     model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
     self._test_runtime_with_model(model)
 
+  @test_util.run_v2_only
   def test_GRU_runtime_with_mask(self):
     # Masking will affect which backend is selected based on whether the mask
     # is strictly right padded.
@@ -650,10 +681,11 @@
         num_classes=self.output_shape)
     y_train = keras.utils.to_categorical(y_train, self.output_shape)
 
-    model.compile(optimizer='sgd',
-                  loss=['categorical_crossentropy', None],
-                  run_eagerly=testing_utils.should_run_eagerly(),
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        optimizer='sgd',
+        loss=['categorical_crossentropy', None],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.fit(x_train, y_train)
 
@@ -680,7 +712,6 @@
     _, runtime_value = model.predict(x_train)
     self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
 
-  # Due to b/120160788.
   @test_util.run_v2_only
   def test_GRU_runtime_with_cond(self):
     # This test is to demonstrate the graph rewrite of grappler plugin under
diff --git a/tensorflow/python/keras/layers/kernelized_test.py b/tensorflow/python/keras/layers/kernelized_test.py
index f9231df..7e45edb 100644
--- a/tensorflow/python/keras/layers/kernelized_test.py
+++ b/tensorflow/python/keras/layers/kernelized_test.py
@@ -32,6 +32,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend as keras_backend
 from tensorflow.python.keras import initializers
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.layers import kernelized as kernel_layers
 from tensorflow.python.keras.utils import kernelized_utils
 from tensorflow.python.ops import array_ops
@@ -213,13 +214,15 @@
     if isinstance(initializer, init_ops.Initializer):
       expected_initializer = initializers.serialize(initializer)
 
+    expected_dtype = (
+        'float32' if base_layer_utils.v2_dtype_behavior_enabled() else None)
     expected_config = {
         'output_dim': output_dim,
         'kernel_initializer': expected_initializer,
         'scale': scale,
         'name': 'random_fourier_features',
         'trainable': trainable,
-        'dtype': None,
+        'dtype': expected_dtype,
     }
     self.assertLen(expected_config, len(rff_layer.get_config()))
     self.assertSameElements(
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index c2c93c0..d940920 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -87,7 +87,7 @@
           the output of the layer (its "activation")..
       kernel_constraint: Constraint function applied to the kernel matrix.
       bias_constraint: Constraint function applied to the bias vector.
-      implementation: implementation mode, either `1` or `2`.
+      implementation: implementation mode, either `1`, `2`, or `3`.
           `1` loops over input spatial locations to perform the forward pass.
           It is memory-efficient but performs a lot of (small) ops.
 
@@ -95,20 +95,30 @@
           and implements the forward pass as a single matrix-multiply. It uses
           a lot of RAM but performs few (large) ops.
 
-          Depending on the inputs, layer parameters, hardware, and
-          `tf.executing_eagerly()` one implementation can be dramatically faster
-          (e.g. 50X) than another.
+          `3` stores layer weights in a sparse tensor and implements the forward
+          pass as a single sparse matrix-multiply.
 
-          It is recommended to benchmark both in the setting of interest to pick
-          the most efficient one (in terms of speed and memory usage).
+          How to choose:
 
-          Following scenarios could benefit from setting `implementation=2`:
-              - eager execution;
-              - inference;
-              - running on CPU;
-              - large amount of RAM available;
-              - small models (few filters, small kernel);
-              - using `padding=same` (only possible with `implementation=2`).
+          `1`: large, dense models,
+          `2`: small models,
+          `3`: large, sparse models,
+
+          where "large" stands for large input/output activations
+          (i.e. many `filters`, `input_filters`, large `input_size`,
+          `output_size`), and "sparse" stands for few connections between inputs
+          and outputs, i.e. small ratio
+          `filters * input_filters * kernel_size / (input_size * strides)`,
+          where inputs to and outputs of the layer are assumed to have shapes
+          `(input_size, input_filters)`, `(output_size, filters)`
+          respectively.
+
+          It is recommended to benchmark each in the setting of interest to pick
+          the most efficient one (in terms of speed and memory usage). Correct
+          choice of implementation can lead to dramatic speed improvements (e.g.
+          50X), potentially at the expense of RAM.
+
+          Also, only `padding="valid"` is supported by `implementation=1`.
 
   Input shape:
       3D tensor with shape: `(batch_size, steps, input_dim)`
@@ -200,9 +210,31 @@
           kernel_shape=self.kernel_size,
           strides=self.strides,
           padding=self.padding,
-          data_format=self.data_format
+          data_format=self.data_format,
       )
 
+    elif self.implementation == 3:
+      self.kernel_shape = (self.output_length * self.filters,
+                           input_length * input_dim)
+
+      self.kernel_idxs = sorted(
+          conv_utils.conv_kernel_idxs(
+              input_shape=(input_length,),
+              kernel_shape=self.kernel_size,
+              strides=self.strides,
+              padding=self.padding,
+              filters_in=input_dim,
+              filters_out=self.filters,
+              data_format=self.data_format)
+      )
+
+      self.kernel = self.add_weight(
+          shape=(len(self.kernel_idxs),),
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
+
     else:
       raise ValueError('Unrecognized implementation mode: %d.'
                        % self.implementation)
@@ -247,6 +279,11 @@
       output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                  self.compute_output_shape(inputs.shape))
 
+    elif self.implementation == 3:
+      output = local_conv_sparse_matmul(inputs, self.kernel, self.kernel_idxs,
+                                        self.kernel_shape,
+                                        self.compute_output_shape(inputs.shape))
+
     else:
       raise ValueError('Unrecognized implementation mode: %d.'
                        % self.implementation)
@@ -355,7 +392,7 @@
           the output of the layer (its "activation").
       kernel_constraint: Constraint function applied to the kernel matrix.
       bias_constraint: Constraint function applied to the bias vector.
-      implementation: implementation mode, either `1` or `2`.
+      implementation: implementation mode, either `1`, `2`, or `3`.
           `1` loops over input spatial locations to perform the forward pass.
           It is memory-efficient but performs a lot of (small) ops.
 
@@ -363,20 +400,30 @@
           and implements the forward pass as a single matrix-multiply. It uses
           a lot of RAM but performs few (large) ops.
 
-          Depending on the inputs, layer parameters, hardware, and
-          `tf.executing_eagerly()` one implementation can be dramatically faster
-          (e.g. 50X) than another.
+          `3` stores layer weights in a sparse tensor and implements the forward
+          pass as a single sparse matrix-multiply.
 
-          It is recommended to benchmark both in the setting of interest to pick
-          the most efficient one (in terms of speed and memory usage).
+          How to choose:
 
-          Following scenarios could benefit from setting `implementation=2`:
-              - eager execution;
-              - inference;
-              - running on CPU;
-              - large amount of RAM available;
-              - small models (few filters, small kernel);
-              - using `padding=same` (only possible with `implementation=2`).
+          `1`: large, dense models,
+          `2`: small models,
+          `3`: large, sparse models,
+
+          where "large" stands for large input/output activations
+          (i.e. many `filters`, `input_filters`, large `np.prod(input_size)`,
+          `np.prod(output_size)`), and "sparse" stands for few connections
+          between inputs and outputs, i.e. small ratio
+          `filters * input_filters * np.prod(kernel_size) / (np.prod(input_size)
+          * np.prod(strides))`, where inputs to and outputs of the layer are
+          assumed to have shapes `input_size + (input_filters,)`,
+          `output_size + (filters,)` respectively.
+
+          It is recommended to benchmark each in the setting of interest to pick
+          the most efficient one (in terms of speed and memory usage). Correct
+          choice of implementation can lead to dramatic speed improvements (e.g.
+          50X), potentially at the expense of RAM.
+
+          Also, only `padding="valid"` is supported by `implementation=1`.
 
   Input shape:
       4D tensor with shape:
@@ -483,9 +530,31 @@
           kernel_shape=self.kernel_size,
           strides=self.strides,
           padding=self.padding,
-          data_format=self.data_format
+          data_format=self.data_format,
       )
 
+    elif self.implementation == 3:
+      self.kernel_shape = (self.output_row * self.output_col * self.filters,
+                           input_row * input_col * input_filter)
+
+      self.kernel_idxs = sorted(
+          conv_utils.conv_kernel_idxs(
+              input_shape=(input_row, input_col),
+              kernel_shape=self.kernel_size,
+              strides=self.strides,
+              padding=self.padding,
+              filters_in=input_filter,
+              filters_out=self.filters,
+              data_format=self.data_format)
+      )
+
+      self.kernel = self.add_weight(
+          shape=(len(self.kernel_idxs),),
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
+
     else:
       raise ValueError('Unrecognized implementation mode: %d.'
                        % self.implementation)
@@ -534,6 +603,11 @@
       output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                  self.compute_output_shape(inputs.shape))
 
+    elif self.implementation == 3:
+      output = local_conv_sparse_matmul(inputs, self.kernel, self.kernel_idxs,
+                                        self.kernel_shape,
+                                        self.compute_output_shape(inputs.shape))
+
     else:
       raise ValueError('Unrecognized implementation mode: %d.'
                        % self.implementation)
@@ -581,10 +655,7 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
-def get_locallyconnected_mask(input_shape,
-                              kernel_shape,
-                              strides,
-                              padding,
+def get_locallyconnected_mask(input_shape, kernel_shape, strides, padding,
                               data_format):
   """Return a mask representing connectivity of a locally-connected operation.
 
@@ -701,6 +772,44 @@
   return output
 
 
+def local_conv_sparse_matmul(inputs, kernel, kernel_idxs, kernel_shape,
+                             output_shape):
+  """Apply N-D convolution with un-shared weights using a single sparse matmul.
+
+  This method outputs `inputs . tf.SparseTensor(indices=kernel_idxs,
+  values=kernel, dense_shape=kernel_shape)`, with `.` standing for
+  matrix-multiply. It also reshapes `inputs` to 2-D and `output` to (N+2)-D.
+
+  Arguments:
+      inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
+        d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
+      kernel: a 1-D tensor with shape `(len(kernel_idxs),)` containing all the
+        weights of the layer.
+      kernel_idxs:  a list of integer tuples representing indices in a sparse
+        matrix performing the un-shared convolution as a matrix-multiply.
+      kernel_shape: a tuple `(input_size, output_size)`, where `input_size =
+        channels_in * d_in1 * ... * d_inN` and `output_size = channels_out *
+        d_out1 * ... * d_outN`.
+      output_shape: a tuple of (N+2) elements representing the output shape:
+        `(batch_size, channels_out, d_out1, ..., d_outN)` or `(batch_size,
+        d_out1, ..., d_outN, channels_out)`, with the ordering of channels and
+        spatial dimensions matching that of the input.
+
+  Returns:
+      Output (N+2)-D dense tensor with shape `output_shape`.
+  """
+  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))
+  output_flat = K.sparse_ops.sparse_tensor_dense_mat_mul(
+      kernel_idxs, kernel, kernel_shape, inputs_flat, adjoint_b=True)
+  output_flat_transpose = K.transpose(output_flat)
+
+  output_reshaped = K.reshape(
+      output_flat_transpose,
+      [K.shape(output_flat_transpose)[0],] + output_shape.as_list()[1:]
+  )
+  return output_reshaped
+
+
 def make_2d(tensor, split_dim):
   """Reshapes an N-dimensional tensor into a 2D tensor.
 
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index c03fb21..2efbd09 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -61,6 +61,22 @@
     'data_format': 'channels_last',
     'padding': 'same',
     'implementation': 2
+}, {
+    'data_format': 'channels_first',
+    'padding': 'valid',
+    'implementation': 3
+}, {
+    'data_format': 'channels_first',
+    'padding': 'same',
+    'implementation': 3
+}, {
+    'data_format': 'channels_last',
+    'padding': 'valid',
+    'implementation': 3
+}, {
+    'data_format': 'channels_last',
+    'padding': 'same',
+    'implementation': 3
 }]
 
 
@@ -219,7 +235,8 @@
         'bias_regularizer': 'l2',
         'activity_regularizer': 'l2',
         'implementation': implementation,
-        'padding': padding
+        'padding': padding,
+        'data_format': data_format
     }
 
     if padding == 'same' and implementation == 1:
@@ -253,8 +270,13 @@
 class LocallyConnectedImplementationModeTest(test.TestCase,
                                              parameterized.TestCase):
 
-  @parameterized.parameters(['channels_first', 'channels_last'])
-  def test_locallyconnected_implementation(self, data_format):
+  @parameterized.parameters([
+      {'width': 1, 'data_format': 'channels_first'},
+      {'width': 1, 'data_format': 'channels_last'},
+      {'width': 6, 'data_format': 'channels_first'},
+      {'width': 6, 'data_format': 'channels_last'},
+  ])
+  def test_locallyconnected_implementation(self, width, data_format):
     with self.cached_session():
       num_samples = 4
       num_classes = 3
@@ -263,58 +285,78 @@
       np.random.seed(1)
       targets = np.random.randint(0, num_classes, (num_samples,))
 
-      for width in [1, 6]:
-        for height in [7]:
-          for filters in [2]:
-            inputs = get_inputs(data_format, filters, height, num_samples,
-                                width)
+      height = 7
+      filters = 2
+      inputs = get_inputs(data_format, filters, height, num_samples, width)
 
-            for kernel_x in [(3,)]:
-              for kernel_y in [()] if width == 1 else [(2,)]:
-                for stride_x in [(1,)]:
-                  for stride_y in [()] if width == 1 else [(3,)]:
-                    for layers in [2]:
-                      kwargs = {
-                          'layers': layers,
-                          'filters': filters,
-                          'kernel_size': kernel_x + kernel_y,
-                          'strides': stride_x + stride_y,
-                          'data_format': data_format,
-                          'num_classes': num_classes
-                      }
-                      model_1 = get_model(implementation=1, **kwargs)
-                      model_2 = get_model(implementation=2, **kwargs)
+      kernel_x = (3,)
+      kernel_y = () if width == 1 else (2,)
+      stride_x = (1,)
+      stride_y = () if width == 1 else (3,)
+      layers = 2
 
-                      # Build models.
-                      model_1.train_on_batch(inputs, targets)
-                      model_2.train_on_batch(inputs, targets)
+      kwargs = {
+          'layers': layers,
+          'filters': filters,
+          'kernel_size': kernel_x + kernel_y,
+          'strides': stride_x + stride_y,
+          'data_format': data_format,
+          'num_classes': num_classes
+      }
 
-                      # Copy weights.
-                      copy_model_weights(model_2, model_1)
+      model_1 = get_model(implementation=1, **kwargs)
+      model_2 = get_model(implementation=2, **kwargs)
+      model_3 = get_model(implementation=3, **kwargs)
 
-                      # Compare outputs at initialization.
-                      out_1 = model_1.call(inputs)
-                      out_2 = model_2.call(inputs)
-                      self.assertAllCloseAccordingToType(
-                          out_1, out_2, rtol=1e-5, atol=1e-5)
+      # Build models.
+      model_1.train_on_batch(inputs, targets)
+      model_2.train_on_batch(inputs, targets)
+      model_3.train_on_batch(inputs, targets)
 
-                      # Train.
-                      model_1.fit(
-                          x=inputs,
-                          y=targets,
-                          epochs=num_epochs,
-                          batch_size=num_samples)
-                      model_2.fit(
-                          x=inputs,
-                          y=targets,
-                          epochs=num_epochs,
-                          batch_size=num_samples)
+      # Copy weights.
+      copy_model_weights(model_from=model_2, model_to=model_1)
+      copy_model_weights(model_from=model_2, model_to=model_3)
 
-                      # Compare outputs after a few training steps.
-                      out_1 = model_1.call(inputs)
-                      out_2 = model_2.call(inputs)
-                      self.assertAllCloseAccordingToType(
-                          out_1, out_2, atol=2e-4)
+      # Compare outputs at initialization.
+      out_1 = model_1.call(inputs)
+      out_2 = model_2.call(inputs)
+      out_3 = model_3.call(inputs)
+
+      self.assertAllCloseAccordingToType(
+          out_2, out_1, rtol=1e-5, atol=1e-5)
+      self.assertAllCloseAccordingToType(
+          out_2, out_3, rtol=1e-5, atol=1e-5)
+      self.assertAllCloseAccordingToType(
+          out_1, out_3, rtol=1e-5, atol=1e-5)
+
+      # Train.
+      model_1.fit(
+          x=inputs,
+          y=targets,
+          epochs=num_epochs,
+          batch_size=num_samples)
+      model_2.fit(
+          x=inputs,
+          y=targets,
+          epochs=num_epochs,
+          batch_size=num_samples)
+      model_3.fit(
+          x=inputs,
+          y=targets,
+          epochs=num_epochs,
+          batch_size=num_samples)
+
+      # Compare outputs after a few training steps.
+      out_1 = model_1.call(inputs)
+      out_2 = model_2.call(inputs)
+      out_3 = model_3.call(inputs)
+
+      self.assertAllCloseAccordingToType(
+          out_2, out_1, atol=2e-4)
+      self.assertAllCloseAccordingToType(
+          out_2, out_3, atol=2e-4)
+      self.assertAllCloseAccordingToType(
+          out_1, out_3, atol=2e-4)
 
   def test_make_2d(self):
     input_shapes = [
@@ -422,7 +464,7 @@
   return model
 
 
-def copy_lc_weights(lc_layer_2_from, lc_layer_1_to):
+def copy_lc_weights_2_to_1(lc_layer_2_from, lc_layer_1_to):
   lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
   lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask
 
@@ -463,20 +505,49 @@
   lc_layer_1_to.set_weights([lc_2_kernel_reshaped, lc_2_bias])
 
 
-def copy_model_weights(model_2_from, model_1_to):
-  for l in range(len(model_2_from.layers)):
-    layer_2_from = model_2_from.layers[l]
-    layer_1_to = model_1_to.layers[l]
+def copy_lc_weights_2_to_3(lc_layer_2_from, lc_layer_3_to):
+  lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
+  lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask
 
-    if isinstance(layer_2_from, (keras.layers.LocallyConnected2D,
-                                 keras.layers.LocallyConnected1D)):
-      copy_lc_weights(layer_2_from, layer_1_to)
+  lc_2_kernel_masked = keras.layers.local.make_2d(
+      lc_2_kernel_masked, split_dim=keras.backend.ndim(lc_2_kernel_masked) // 2)
+  lc_2_kernel_masked = keras.backend.transpose(lc_2_kernel_masked)
+  lc_2_kernel_mask = keras.backend.math_ops.not_equal(lc_2_kernel_masked, 0)
+  lc_2_kernel_flat = keras.backend.array_ops.boolean_mask(
+      lc_2_kernel_masked, lc_2_kernel_mask)
 
-    elif isinstance(layer_2_from, keras.layers.Dense):
-      weights_2, bias_2 = layer_2_from.weights
+  lc_2_kernel_flat = keras.backend.get_value(lc_2_kernel_flat)
+  lc_2_bias = keras.backend.get_value(lc_2_bias)
+
+  lc_layer_3_to.set_weights([lc_2_kernel_flat, lc_2_bias])
+
+
+def copy_model_weights(model_from, model_to):
+  for l in range(len(model_from.layers)):
+    layer_from = model_from.layers[l]
+    layer_to = model_to.layers[l]
+
+    if (isinstance(
+        layer_from,
+        (keras.layers.LocallyConnected2D, keras.layers.LocallyConnected1D)) and
+        isinstance(layer_to, (keras.layers.LocallyConnected2D,
+                              keras.layers.LocallyConnected1D))):
+      if layer_from.implementation == 2:
+        if layer_to.implementation == 1:
+          copy_lc_weights_2_to_1(layer_from, layer_to)
+        elif layer_to.implementation == 3:
+          copy_lc_weights_2_to_3(layer_from, layer_to)
+        else:
+          raise NotImplementedError
+
+      else:
+        raise NotImplementedError
+
+    elif isinstance(layer_from, keras.layers.Dense):
+      weights_2, bias_2 = layer_from.weights
       weights_2 = keras.backend.get_value(weights_2)
       bias_2 = keras.backend.get_value(bias_2)
-      layer_1_to.set_weights([weights_2, bias_2])
+      layer_to.set_weights([weights_2, bias_2])
 
     else:
       continue
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 2859c45..c3708d9 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -44,6 +44,19 @@
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_float64_LSTM(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        keras.layers.LSTM,
+        kwargs={'units': units,
+                'return_sequences': True,
+                'dtype': 'float64'},
+        input_shape=(num_samples, timesteps, embedding_dim),
+        input_dtype='float64')
+
   def test_static_shape_inference_LSTM(self):
     # Github issue: 15165
     timesteps = 3
@@ -71,7 +84,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.random.random((num_samples, timesteps, embedding_dim))
     y = np.random.random((num_samples, units))
@@ -132,7 +145,7 @@
         loss='categorical_crossentropy',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
 
   def test_masking_with_stacking_LSTM(self):
@@ -147,7 +160,7 @@
         loss='categorical_crossentropy',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
 
   def test_from_config_LSTM(self):
@@ -179,7 +192,7 @@
         loss='categorical_crossentropy',
         optimizer=adam.AdamOptimizer(),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.random.random((num_samples, timesteps, embedding_dim))
     initial_state = [np.random.random((num_samples, units))
@@ -207,7 +220,7 @@
         loss='categorical_crossentropy',
         optimizer=adam.AdamOptimizer(),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.random.random((num_samples, timesteps, embedding_dim))
     targets = np.random.random((num_samples, units))
@@ -260,7 +273,7 @@
         loss='categorical_crossentropy',
         optimizer='rmsprop',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     inputs = np.random.random((num_samples, timesteps, embedding_dim))
     initial_state = [np.random.random((num_samples, units))
@@ -324,7 +337,7 @@
         loss='categorical_crossentropy',
         optimizer=adam.AdamOptimizer(),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
     initial_state = [np.random.random((num_samples, units))
@@ -374,7 +387,7 @@
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     out1 = model.predict(np.ones((num_samples, timesteps)))
     self.assertEqual(out1.shape, (num_samples, units))
 
diff --git a/tensorflow/python/keras/layers/lstm_v2_test.py b/tensorflow/python/keras/layers/lstm_v2_test.py
index 5ddbf2d..94e7f35 100644
--- a/tensorflow/python/keras/layers/lstm_v2_test.py
+++ b/tensorflow/python/keras/layers/lstm_v2_test.py
@@ -29,6 +29,7 @@
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import keras
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -305,7 +306,6 @@
     targets = np.random.random((num_samples, units))
     model.train_on_batch([main_inputs] + initial_state, targets)
 
-  # Due to b/120160788.
   @test_util.run_v2_only
   def test_lstm_v2_feature_parity_with_canonical_lstm(self):
     input_shape = 10
@@ -565,6 +565,21 @@
         },
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_float64_LSTM(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        rnn.LSTM,
+        kwargs={
+            'units': units,
+            'return_sequences': True,
+            'dtype': 'float64'
+        },
+        input_shape=(num_samples, timesteps, embedding_dim),
+        input_dtype='float64')
+
   def test_regularizers_LSTM(self):
     embedding_dim = 4
     layer_class = rnn.LSTM
@@ -586,8 +601,6 @@
     else:
       self.assertEqual(len(layer.get_losses_for(x)), 1)
 
-  # Run in V2 only due to b/120160788.
-  @test_util.run_v2_only
   def test_statefulness_LSTM(self):
     num_samples = 2
     timesteps = 3
@@ -609,7 +622,7 @@
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     out1 = model.predict(np.ones((num_samples, timesteps)))
     self.assertEqual(out1.shape, (num_samples, units))
 
@@ -682,7 +695,7 @@
         optimizer='adam',
         loss='sparse_categorical_crossentropy',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, epochs=1, shuffle=False)
 
   def test_dropout_LSTM(self):
@@ -747,6 +760,23 @@
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
 
+  @test_util.run_deprecated_v1
+  def test_v1_session_behavior(self):
+    # See b/139132348 for more details.
+    x = np.random.uniform(size=(100, 4, 8))
+    y = np.random.uniform(size=(100, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        (x, y)).shuffle(100).batch(32)
+
+    inp = keras.layers.Input(shape=(4, 8))
+    layer = rnn.LSTM(1)(inp)
+    layer = keras.layers.Dense(1)(layer)
+
+    model = keras.models.Model(inp, layer)
+
+    model.compile(loss='mse', optimizer='sgd')
+    model.fit(dataset)
+
 
 @keras_parameterized.run_all_keras_modes(config=_config)
 class LSTMGraphRewriteTest(keras_parameterized.TestCase):
@@ -767,10 +797,11 @@
         num_classes=self.output_shape)
     y_train = keras.utils.to_categorical(y_train, self.output_shape)
 
-    model.compile(optimizer='sgd',
-                  loss=['categorical_crossentropy', None],
-                  run_eagerly=testing_utils.should_run_eagerly(),
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        optimizer='sgd',
+        loss=['categorical_crossentropy', None],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     existing_loss = 0
     for _ in range(self.epoch):
@@ -786,6 +817,7 @@
     else:
       self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
 
+  @test_util.run_v2_only
   def test_LSTM_runtime(self):
     layer = rnn.LSTM(self.rnn_state_size, return_runtime=True)
 
@@ -801,6 +833,7 @@
     model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
     self._test_runtime_with_model(model)
 
+  @test_util.run_v2_only
   def test_LSTM_runtime_with_mask(self):
     # Masking will affect which backend is selected based on whether the mask
     # is strictly right padded.
@@ -825,10 +858,11 @@
         num_classes=self.output_shape)
     y_train = keras.utils.to_categorical(y_train, self.output_shape)
 
-    model.compile(optimizer='sgd',
-                  loss=['categorical_crossentropy', None],
-                  run_eagerly=testing_utils.should_run_eagerly(),
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        optimizer='sgd',
+        loss=['categorical_crossentropy', None],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.fit(x_train, y_train)
 
@@ -855,7 +889,6 @@
     _, runtime_value = model.predict(x_train)
     self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
 
-  # Due to b/120160788.
   @test_util.run_v2_only
   def test_LSTM_runtime_with_cond(self):
     # This test is to demonstrate the graph rewrite of grappler plugin under
diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py
index 4d268d3..78db3af 100644
--- a/tensorflow/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/layers/merge_test.py
@@ -41,7 +41,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2, i3], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -75,7 +75,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -109,7 +109,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2, i3], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -125,7 +125,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -140,7 +140,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -155,7 +155,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -171,7 +171,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 8, 5])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
 
     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
@@ -203,7 +203,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 1])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
     _ = keras.layers.Dot(axes=1).get_config()
 
     x1 = np.random.random((2, 4))
@@ -220,7 +220,7 @@
     self.assertListEqual(o.shape.as_list(), [None, 1])
     model = keras.models.Model([i1, i2], o)
     model.run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 1))
     self.assertAllClose(out, expected, atol=1e-4)
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index 8ab0bac..369d010 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -104,7 +104,7 @@
             loss='mse',
             optimizer=gradient_descent.GradientDescentOptimizer(0.01),
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
         # centered on 5.0, variance 10.0
         x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
@@ -126,7 +126,7 @@
         loss='mse',
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # centered on 5.0, variance 10.0
     x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
@@ -175,7 +175,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(np.random.random((100, 3)), np.random.random((100, 3)))
 
     test_data = np.random.random((10, 3))
@@ -187,7 +187,7 @@
         'rmsprop',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     train_loss = model.train_on_batch(test_data, test_targets)
     self.assertAlmostEqual(test_loss, train_loss)
 
@@ -366,7 +366,7 @@
       loss='mse',
       optimizer=gradient_descent.GradientDescentOptimizer(0.01),
       run_eagerly=testing_utils.should_run_eagerly(),
-      run_distributed=testing_utils.should_run_distributed())
+      experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   # centered on 5.0, variance 10.0
   x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
@@ -498,10 +498,11 @@
   model = keras.models.Sequential()
   norm = layer(input_shape=(2, 2, 2))
   model.add(norm)
-  model.compile(loss='mse',
-                optimizer=gradient_descent.GradientDescentOptimizer(0.01),
-                run_eagerly=testing_utils.should_run_eagerly())
-  # TODO(b/137397816): run_distributed=testing_utils.should_run_distributed()
+  model.compile(
+      loss='mse',
+      optimizer=gradient_descent.GradientDescentOptimizer(0.01),
+      run_eagerly=testing_utils.should_run_eagerly(),
+      experimental_run_tf_function=testing_utils.should_run_tf_function())
 
   # centered on 5.0, variance 10.0
   x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
@@ -573,7 +574,7 @@
         loss='mse',
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # centered on 5.0, variance 10.0
     x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_test.py
index 7167c43..abb61f7 100644
--- a/tensorflow/python/keras/layers/preprocessing/normalization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/normalization_test.py
@@ -171,7 +171,7 @@
     output = layer(input_data)
     model = keras.Model(input_data, output)
     model._run_eagerly = testing_utils.should_run_eagerly()
-    model._run_distributed = testing_utils.should_run_distributed()
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
     output_data = model.predict(test_data)
     self.assertAllClose(expected, output_data)
 
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 5b88db6..2ee98a4 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -24,6 +24,7 @@
 import numpy as np
 
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
@@ -35,6 +36,7 @@
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as trackable
@@ -732,6 +734,15 @@
           new_states = [new_states]
         return output, new_states
 
+    # `input_length` is passed as the `maximum_iterations` arg to tf.while_loop.
+    # We only specify that when building for XLA since that causes slowdowns
+    # on GPU in TF.
+    if (not context.executing_eagerly() and
+        control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
+      input_length = timesteps
+    else:
+      input_length = None
+
     last_output, outputs, states = K.rnn(
         step,
         inputs,
@@ -740,7 +751,7 @@
         go_backwards=self.go_backwards,
         mask=mask,
         unroll=self.unroll,
-        input_length=timesteps,
+        input_length=input_length,
         time_major=self.time_major,
         zero_output_for_mask=self.zero_output_for_mask)
     if self.stateful:
@@ -1362,7 +1373,8 @@
         recurrent_constraint=recurrent_constraint,
         bias_constraint=bias_constraint,
         dropout=dropout,
-        recurrent_dropout=recurrent_dropout)
+        recurrent_dropout=recurrent_dropout,
+        dtype=kwargs.get('dtype'))
     super(SimpleRNN, self).__init__(
         cell,
         return_sequences=return_sequences,
@@ -1890,7 +1902,8 @@
         dropout=dropout,
         recurrent_dropout=recurrent_dropout,
         implementation=implementation,
-        reset_after=reset_after)
+        reset_after=reset_after,
+        dtype=kwargs.get('dtype'))
     super(GRU, self).__init__(
         cell,
         return_sequences=return_sequences,
@@ -2516,7 +2529,8 @@
         bias_constraint=bias_constraint,
         dropout=dropout,
         recurrent_dropout=recurrent_dropout,
-        implementation=implementation)
+        implementation=implementation,
+        dtype=kwargs.get('dtype'))
     super(LSTM, self).__init__(
         cell,
         return_sequences=return_sequences,
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index fc2d5b3..37311a7 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -83,7 +83,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacking.
@@ -97,7 +97,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
   def test_minimal_rnn_cell_non_layer_multiple_states(self):
@@ -128,7 +128,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacking.
@@ -144,7 +144,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
   def test_minimal_rnn_cell_layer(self):
@@ -187,7 +187,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test basic case serialization.
@@ -214,7 +214,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacked RNN serialization.
@@ -271,7 +271,7 @@
         optimizer="rmsprop",
         loss="mse",
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacking.
@@ -285,7 +285,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
   def test_rnn_with_time_major(self):
@@ -314,7 +314,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, embedding_dim)),
         np.zeros((batch, time_step, units)))
@@ -335,7 +335,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, embedding_dim)),
         np.zeros((batch, time_step, cell_units[-1])))
@@ -353,7 +353,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, embedding_dim)),
         np.zeros((batch, time_step, units)))
@@ -368,7 +368,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, embedding_dim)),
         np.zeros((batch, time_step, units)))
@@ -403,7 +403,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((6, 5, 5)), np.zeros((6, 3))],
         np.zeros((6, 32))
@@ -444,7 +444,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((6, 5, 5)), np.zeros((6, 3))],
         np.zeros((6, 32))
@@ -461,7 +461,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((6, 5, 5)), np.zeros((6, 3))],
         np.zeros((6, 32))
@@ -494,7 +494,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacking.
@@ -508,7 +508,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
   def test_rnn_cell_with_constants_layer_passing_initial_state(self):
@@ -524,7 +524,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
         np.zeros((6, 32))
@@ -574,7 +574,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
     # Test stacking.
@@ -591,7 +591,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
 
   def test_stacked_rnn_attributes(self):
@@ -693,7 +693,7 @@
           optimizer='rmsprop',
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       # Test basic case serialization.
       x_np = np.random.random((6, 5, 5))
@@ -718,7 +718,7 @@
           optimizer='rmsprop',
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       # Test stacked RNN serialization.
       x_np = np.random.random((6, 5, 5))
@@ -749,7 +749,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x_np = np.random.random((6, 5, 5))
     y_np = np.random.random((6, 3))
     model.train_on_batch(x_np, y_np)
@@ -774,7 +774,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     x_np = np.random.random((6, 5, 5))
     y_np = np.random.random((6, 3))
     model.train_on_batch(x_np, y_np)
@@ -852,7 +852,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, epochs=1, batch_size=1)
 
     # check whether the model variables are present in the
@@ -888,7 +888,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, input_a, input_b)),
         np.zeros((batch, unit_a, unit_b)))
@@ -907,7 +907,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, input_a, input_b)),
         np.zeros((batch, unit_a * 4, unit_b * 4)))
@@ -933,7 +933,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch([
         np.zeros((batch, time_step, input_a, input_b)),
         np.zeros((batch, unit_a, unit_b))
@@ -972,7 +972,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, time_step, input_size)),
         np.zeros((batch, input_size)))
@@ -1030,7 +1030,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)), np.zeros((batch, t, i2, i3))],
         [np.zeros((batch, o1)), np.zeros((batch, o2, o3))])
@@ -1054,7 +1054,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)),
          np.zeros((batch, t, i2, i3))],
@@ -1085,7 +1085,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)),
          np.zeros((batch, t, i2, i3))],
@@ -1112,7 +1112,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)),
          np.zeros((batch, t, i2, i3))],
@@ -1148,7 +1148,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)),
          np.zeros((batch, t, i2, i3)),
@@ -1182,7 +1182,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         [np.zeros((batch, t, i1)),
          np.zeros((batch, t, i2, i3)),
@@ -1260,7 +1260,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # last time step masked
     x_np = np.array([[[1.], [2.], [0.]]])
@@ -1287,7 +1287,7 @@
           optimizer='rmsprop',
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       np_x = np.ones((6, 5, 5))
       result_1 = model.predict(np_x)
@@ -1312,7 +1312,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     np_x = np.ones((6, 1, 5))
     result = model.predict(np_x)
@@ -1368,7 +1368,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, timesteps, input_dim)),
         np.zeros((batch, output_dim)))
@@ -1419,7 +1419,7 @@
         optimizer='rmsprop',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.train_on_batch(
         np.zeros((batch, timesteps, input_dim)),
         np.zeros((batch, output_dim)))
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index 217403a..3f1be45 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -306,7 +306,7 @@
     self.could_use_cudnn = (
         activation == 'tanh' and recurrent_activation == 'sigmoid' and
         recurrent_dropout == 0 and not unroll and use_bias and
-        reset_after)
+        reset_after and ops.executing_eagerly_outside_functions())
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # GRU does not support constants. Ignore it during process.
@@ -486,9 +486,14 @@
 def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
               go_backwards):
   """GRU with CuDNN implementation which is only available for GPU."""
-  if not time_major:
+  if not time_major and mask is None:
     inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
-  init_h = array_ops.expand_dims(init_h, axis=0)
+    seq_axis, batch_axis = (0, 1)
+  else:
+    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+  # For init_h, cuDNN expects one more dim of num_layers before or after batch
+  # dim for time major or batch major inputs respectively
+  init_h = array_ops.expand_dims(init_h, axis=seq_axis)
 
   weights = array_ops.split(kernel, 3, axis=1)
   weights += array_ops.split(recurrent_kernel, 3, axis=1)
@@ -520,15 +525,21 @@
       # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
       # output_from_cudnn = [6, 5, 4, 0, 0]
       # expected_output = [0, 0, 6, 5 ,4]
-      inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
-                                             seq_axis=0, batch_axis=1)
+      inputs = array_ops.reverse_sequence_v2(
+          inputs, sequence_length, seq_axis=seq_axis, batch_axis=batch_axis)
     outputs, h, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
-        inputs, input_h=init_h, input_c=0, params=params, is_training=True,
-        rnn_mode='gru', sequence_lengths=sequence_length)
+        inputs,
+        input_h=init_h,
+        input_c=0,
+        params=params,
+        is_training=True,
+        rnn_mode='gru',
+        sequence_lengths=sequence_length,
+        time_major=time_major)
     if go_backwards:
-      outputs = array_ops.reverse_sequence_v2(outputs, sequence_length,
-                                              seq_axis=0, batch_axis=1)
-      outputs = array_ops.reverse(outputs, axis=[0])
+      outputs = array_ops.reverse_sequence_v2(
+          outputs, sequence_length, seq_axis=seq_axis, batch_axis=batch_axis)
+      outputs = array_ops.reverse(outputs, axis=[seq_axis])
   else:
     if go_backwards:
       # Reverse axis 0 since the input is already convert to time major.
@@ -538,9 +549,9 @@
         rnn_mode='gru')
 
   last_output = outputs[-1]
-  if not time_major:
+  if not time_major and mask is None:
     outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
-  h = h[0]
+  h = array_ops.squeeze(h, axis=seq_axis)
 
   # In the case of variable length input, the cudnn kernel will fill zeros for
   # the output, whereas the default keras behavior is to bring over the previous
@@ -885,7 +896,8 @@
     ]
     self.could_use_cudnn = (
         activation == 'tanh' and recurrent_activation == 'sigmoid' and
-        recurrent_dropout == 0 and not unroll and use_bias)
+        recurrent_dropout == 0 and not unroll and use_bias and
+        ops.executing_eagerly_outside_functions())
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # LSTM does not support constants. Ignore it during process.
@@ -1126,11 +1138,15 @@
     runtime: Constant string tensor which indicate real runtime hardware. This
       value is for testing purpose and should not be used by user.
   """
-  if not time_major:
-    # Cudnn kernel prefer the input to be time major.
+  if not time_major and mask is None:
     inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
-  init_h = array_ops.expand_dims(init_h, axis=0)
-  init_c = array_ops.expand_dims(init_c, axis=0)
+    seq_axis, batch_axis = (0, 1)
+  else:
+    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+  # For init_h and init_c, cuDNN expects one more dim of num_layers before or
+  # after batch dim for time major or batch major inputs respectively
+  init_h = array_ops.expand_dims(init_h, axis=seq_axis)
+  init_c = array_ops.expand_dims(init_c, axis=seq_axis)
 
   weights = array_ops.split(kernel, 4, axis=1)
   weights += array_ops.split(recurrent_kernel, 4, axis=1)
@@ -1152,15 +1168,21 @@
       # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
       # output_from_cudnn = [6, 5, 4, 0, 0]
       # expected_output = [0, 0, 6, 5 ,4]
-      inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
-                                             seq_axis=0, batch_axis=1)
+      inputs = array_ops.reverse_sequence_v2(
+          inputs, sequence_length, seq_axis=seq_axis, batch_axis=batch_axis)
     outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
-        inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
-        rnn_mode='lstm', sequence_lengths=sequence_length)
+        inputs,
+        input_h=init_h,
+        input_c=init_c,
+        params=params,
+        is_training=True,
+        rnn_mode='lstm',
+        sequence_lengths=sequence_length,
+        time_major=time_major)
     if go_backwards:
-      outputs = array_ops.reverse_sequence_v2(outputs, sequence_length,
-                                              seq_axis=0, batch_axis=1)
-      outputs = array_ops.reverse(outputs, axis=[0])
+      outputs = array_ops.reverse_sequence_v2(
+          outputs, sequence_length, seq_axis=seq_axis, batch_axis=batch_axis)
+      outputs = array_ops.reverse(outputs, axis=[seq_axis])
   else:
     # # Fill the array with shape [batch] with value of max timesteps.
     # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
@@ -1173,10 +1195,10 @@
         rnn_mode='lstm')
 
   last_output = outputs[-1]
-  if not time_major:
+  if not time_major and mask is None:
     outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
-  h = h[0]
-  c = c[0]
+  h = array_ops.squeeze(h, axis=seq_axis)
+  c = array_ops.squeeze(c, axis=seq_axis)
 
   # In the case of variable length input, the cudnn kernel will fill zeros for
   # the output, whereas the default keras behavior is to bring over the previous
diff --git a/tensorflow/python/keras/layers/recurrent_v2_test.py b/tensorflow/python/keras/layers/recurrent_v2_test.py
index 2d45e64..487ee81 100644
--- a/tensorflow/python/keras/layers/recurrent_v2_test.py
+++ b/tensorflow/python/keras/layers/recurrent_v2_test.py
@@ -61,7 +61,7 @@
           optimizer='adam',
           loss='sparse_categorical_crossentropy',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model.fit(x, y, epochs=1, shuffle=False)
 
   @parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index 42746de..9ab88b9 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -39,8 +39,8 @@
   @test_util.run_in_graph_and_eager_modes
   def testResidualWrapper(self):
     wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
-    x = ops.convert_to_tensor(np.array([[1., 1., 1.]]))
-    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
+    x = ops.convert_to_tensor(np.array([[1., 1., 1.]]), dtype="float32")
+    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]), dtype="float32")
     base_cell = rnn_cell_impl.GRUCell(
         3, kernel_initializer=init_ops.constant_initializer(0.5),
         bias_initializer=init_ops.constant_initializer(0.5))
@@ -62,8 +62,8 @@
   @test_util.run_in_graph_and_eager_modes
   def testResidualWrapperWithSlice(self):
     wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
-    x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]]))
-    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
+    x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]]), dtype="float32")
+    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]), dtype="float32")
     base_cell = rnn_cell_impl.GRUCell(
         3, kernel_initializer=init_ops.constant_initializer(0.5),
         bias_initializer=init_ops.constant_initializer(0.5))
diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py
index 11cd12c..795c8b2 100644
--- a/tensorflow/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/layers/serialization.py
@@ -77,7 +77,7 @@
   """
   # Prevent circular dependencies.
   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.feature_column import feature_column_v2  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.feature_column import dense_features  # pylint: disable=g-import-not-at-top
 
   globs = globals()  # All layers.
   globs['Network'] = models.Network
@@ -85,7 +85,7 @@
   globs['Sequential'] = models.Sequential
 
   # Prevent circular dependencies with FeatureColumn serialization.
-  globs['DenseFeatures'] = feature_column_v2.DenseFeatures
+  globs['DenseFeatures'] = dense_features.DenseFeatures
 
   layer_class_name = config['class_name']
   if layer_class_name in _DESERIALIZATION_TABLE:
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index d8346b3..731e312 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -42,6 +42,19 @@
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_float64_SimpleRNN(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        keras.layers.SimpleRNN,
+        kwargs={'units': units,
+                'return_sequences': True,
+                'dtype': 'float64'},
+        input_shape=(num_samples, timesteps, embedding_dim),
+        input_dtype='float64')
+
   def test_dynamic_behavior_SimpleRNN(self):
     num_samples = 2
     timesteps = 3
@@ -159,7 +172,7 @@
         optimizer=gradient_descent.GradientDescentOptimizer(0.01),
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     out1 = model.predict(np.ones((num_samples, timesteps)))
     self.assertEqual(out1.shape, (num_samples, units))
 
diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
index a08bf214..a43e983 100644
--- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
+++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
@@ -29,6 +29,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import saving
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.ops import array_ops
@@ -42,7 +43,7 @@
   inputs = keras.Input(shape=(10,))
   x = keras.layers.Dense(10)(inputs)
   outputs = gen_nn_ops.relu(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _single_identity_op_at_end():
@@ -50,7 +51,7 @@
   x = keras.layers.Dense(10)(inputs)
   outputs = array_ops.identity(x)
   assert 'Identity' in outputs.name
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _multiple_ops_at_end():
@@ -58,7 +59,7 @@
   x = keras.layers.Dense(10)(inputs)
   x = gen_nn_ops.relu(x)
   outputs = gen_nn_ops.relu(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _single_op_in_middle():
@@ -66,7 +67,7 @@
   x = keras.layers.Dense(10)(inputs)
   x = gen_nn_ops.relu(x)
   outputs = keras.layers.Dense(10)(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _multiple_ops_in_middle():
@@ -75,21 +76,21 @@
   x = gen_nn_ops.relu(x)
   x = gen_nn_ops.relu(x)
   outputs = keras.layers.Dense(10)(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _single_standalone_branch():
   inputs = keras.Input(shape=(10,))
   x = keras.layers.Dense(10)(inputs)
   outputs = x * 2
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _single_op_with_attrs():
   inputs = keras.Input(shape=(10,))
   x = math_ops.reduce_mean(inputs, axis=1, keepdims=True)
   outputs = keras.layers.Dense(10)(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _multiple_uses():
@@ -98,20 +99,20 @@
   x1 = keras.layers.Dense(10)(x)
   x2 = keras.layers.Dense(10)(x)
   outputs = x1 + x2
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _op_with_tensor_list():
   inputs = keras.Input(shape=(10,))
   x = array_ops.concat([inputs, inputs], axis=1)
   outputs = keras.layers.Dense(10)(x)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _add_n():
   inputs = keras.Input(shape=(10,))
   outputs = math_ops.add_n([inputs, inputs, inputs])
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
 
 
 def _reuse_op():
@@ -122,7 +123,29 @@
   x2 = x * 2
   y2 = keras.layers.Dense(10)(x2)
   outputs = y + y2
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
+
+
+def _float64_op():
+  inputs = keras.Input(shape=(10,))
+  x = keras.layers.Dense(10, dtype='float64')(inputs)
+  x = gen_nn_ops.relu(x)
+  assert x.dtype == 'float64', 'x has dtype: %s' % x.dtype
+  outputs = keras.layers.Dense(10)(x)
+  return keras.Model(inputs, outputs)
+
+
+class MyAdd(keras.layers.Layer):
+
+  def call(self, x, y):
+    return x + y
+
+
+def _layer_with_tensor_arg():
+  inputs = keras.Input(shape=(10,))
+  x = inputs * 2
+  outputs = MyAdd()(inputs, x)
+  return keras.Model(inputs, outputs)
 
 
 class LayerWithLayer(keras.layers.Layer):
@@ -140,7 +163,27 @@
 def _inner_layer():
   inputs = keras.Input(shape=(10,))
   outputs = LayerWithLayer()(inputs)
-  return inputs, outputs
+  return keras.Model(inputs, outputs)
+
+
+def _reuse_ancillary_layer():
+  inputs = (keras.Input(shape=(5,)), keras.Input(shape=(5,)))
+  base_model = keras.Sequential([
+      keras.layers.Dense(3, input_shape=(5,)),
+  ])
+  outputs = base_model(inputs[0])
+  model = keras.Model(inputs, outputs)
+  # The second input is only involved in ancillary layers.
+  outputs_delta = outputs - base_model(0.5 * inputs[1])
+  l2_loss = math_ops.reduce_mean(
+      math_ops.reduce_sum(math_ops.square(outputs_delta), -1))
+  model.add_loss(l2_loss)
+  model.add_metric(l2_loss, aggregation='mean', name='l2_loss')
+  l1_loss = 0.01 * math_ops.reduce_mean(
+      math_ops.reduce_sum(math_ops.abs(outputs_delta), -1))
+  model.add_loss(l1_loss)
+  model.add_metric(l1_loss, aggregation='mean', name='l1_loss')
+  return model
 
 
 @keras_parameterized.run_all_keras_modes
@@ -155,33 +198,47 @@
       ('single_standalone_branch', _single_standalone_branch),
       ('single_op_with_attrs', _single_op_with_attrs),
       ('multiple_uses', _multiple_uses),
-      ('op_with_tensor_list', _op_with_tensor_list), ('add_n', _add_n),
-      ('_reuse_op', _reuse_op), ('_inner_layer', _inner_layer))
+      ('op_with_tensor_list', _op_with_tensor_list),
+      ('add_n', _add_n),
+      ('_reuse_op', _reuse_op),
+      ('_float64_op', _float64_op),
+      ('_inner_layer', _inner_layer),
+      ('_reuse_ancillary_layer', _reuse_ancillary_layer),
+      ('_layer_with_tensor_arg', _layer_with_tensor_arg),
+  )
   def test_autolambda(self, model_fn):
-    inputs, outputs = model_fn()
-    model = keras.Model(inputs, outputs)
+    model = model_fn()
     model.compile(
         adam.Adam(0.001),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
-    np_inputs = nest.map_structure(lambda x: np.ones((10, 10), 'float32'),
-                                   inputs)
-    np_outputs = nest.map_structure(lambda x: np.ones((10, 10), 'float32'),
-                                    outputs)
+    np_inputs = nest.map_structure(
+        lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.inputs)
+    np_outputs = nest.map_structure(
+        lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.outputs)
     model.fit(np_inputs, np_outputs, batch_size=2)
     model(np_inputs)  # Test calling the model directly on inputs.
 
     new_model = keras.Model.from_config(
-        model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer})
+        model.get_config(),
+        custom_objects={
+            'LayerWithLayer': LayerWithLayer,
+            'MyAdd': MyAdd
+        })
     new_model.compile(
         adam.Adam(0.001),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     new_model.fit(np_inputs, np_outputs, batch_size=2)
     new_model(np_inputs)  # Test calling the new model directly on inputs.
+    # Assert that metrics are preserved and in the right order.
+    self.assertAllEqual(model.metrics_names, new_model.metrics_names)
+    # Assert that layer names don't change.
+    self.assertAllEqual([layer.name for layer in model.layers],
+                        [layer.name for layer in new_model.layers])
 
   def test_numerical_correctness_simple(self):
     x = ops.convert_to_tensor([[-1., 0., -2., 1.]])
@@ -205,7 +262,7 @@
     outputs = gen_nn_ops.relu(inputs)
     model1 = keras.Model(inputs, outputs)
     y1 = self.evaluate(model1(x))
-    model2 = model1.from_config(model1.get_config())
+    model2 = keras.Model.from_config(model1.get_config())
     y2 = self.evaluate(model2(x))
     self.assertAllClose(y1, y2)
 
@@ -272,6 +329,15 @@
     # Test something that requires Layers to be built.
     model.summary()
 
+  def test_json_serialization(self):
+    inputs = keras.Input(shape=(4,), dtype='uint8')
+    outputs = math_ops.cast(inputs, 'float32') / 4.
+    model = saving.model_from_json(keras.Model(inputs, outputs).to_json())
+    self.assertAllEqual(
+        self.evaluate(model(np.array([0, 64, 128, 192], np.uint8))),
+        [0., 16., 32., 48.])
+    model.summary()
+
 
 class InputInEagerTest(test.TestCase):
   """Tests ops on graph tensors in Eager runtime.
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 2b2fd4f..8a2ff0d 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -95,21 +95,22 @@
     """Invokes the `Loss` instance.
 
     Args:
-      y_true: Ground truth values.
-      y_pred: The predicted values.
-      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
-        as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
+      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
+      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
+      sample_weight: Optional `sample_weight` acts as a
         coefficient for the loss. If a scalar is provided, then the loss is
         simply scaled by the given value. If `sample_weight` is a tensor of size
         `[batch_size]`, then the total loss for each sample of the batch is
         rescaled by the corresponding element in the `sample_weight` vector. If
-        the shape of `sample_weight` matches the shape of `y_pred`, then the
-        loss of each measurable element of `y_pred` is scaled by the
-        corresponding value of `sample_weight`.
+        the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
+        broadcasted to this shape), then each loss element of `y_pred` is scaled
+        by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss
+        functions reduce by 1 dimension, usually axis=-1.)
 
     Returns:
-      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
-        shape as `y_true`; otherwise, it is scalar.
+      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
+        shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1`
+        because all loss functions reduce by 1 dimension, usually axis=-1.)
 
     Raises:
       ValueError: If the shape of `sample_weight` is invalid.
@@ -163,7 +164,7 @@
           '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
           'size like:\n```\nwith strategy.scope():\n'
           '    loss_obj = tf.keras.losses.CategoricalCrossentropy('
-          'reduction=tf.keras.losses.reduction.None)\n....\n'
+          'reduction=tf.keras.losses.reduction.NONE)\n....\n'
           '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
           '(1. / global_batch_size)\n```\nPlease see '
           'https://www.tensorflow.org/alpha/tutorials/distribute/training_loops'
@@ -419,8 +420,8 @@
   cce = tf.keras.losses.CategoricalCrossentropy()
   loss = cce(
     [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]],
-    [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
-  print('Loss: ', loss.numpy())  # Loss: 0.3239
+    [[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])
+  print('Loss: ', loss.numpy())  # Loss: 0.0945
   ```
 
   Usage with the `compile` API:
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 69ddd17..7246a15 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -30,6 +30,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.losses import binary_crossentropy
 from tensorflow.python.keras.losses import categorical_crossentropy
@@ -136,7 +137,10 @@
     super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
     self.stateful = True  # All metric layers are stateful.
     self.built = True
-    self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name
+    if not base_layer_utils.v2_dtype_behavior_enabled():
+      # We only do this when the V2 behavior is not enabled, as when it is
+      # enabled, the dtype already defaults to floatx.
+      self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name
 
   def __new__(cls, *args, **kwargs):
     obj = super(Metric, cls).__new__(cls)
diff --git a/tensorflow/python/keras/metrics_correctness_test.py b/tensorflow/python/keras/metrics_correctness_test.py
index d2bb508..f372996 100644
--- a/tensorflow/python/keras/metrics_correctness_test.py
+++ b/tensorflow/python/keras/metrics_correctness_test.py
@@ -90,7 +90,7 @@
             metrics.MeanSquaredError(name='mean_squared_error_2')
         ],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def setUp(self):
@@ -201,8 +201,6 @@
       self.assertAllClose(history.history[key], value, 1e-3)
 
   def test_fit_with_sample_weight(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = self._get_compiled_multi_io_model()
     history = model.fit([self.x, self.x], [self.y1, self.y2],
                         sample_weight={
@@ -226,8 +224,6 @@
       self.assertAllClose(history.history[key], value, 1e-3)
 
   def test_fit_with_class_weight(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = self._get_compiled_multi_io_model()
     history = model.fit([self.x, self.x], [self.y1, self.y2],
                         class_weight={
@@ -257,8 +253,6 @@
     self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
 
   def test_eval_with_sample_weight(self):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = self._get_compiled_multi_io_model()
     eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2],
                                  batch_size=2,
@@ -435,7 +429,7 @@
             metrics.MeanSquaredError(name='mean_squared_error_2')
         ],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def _custom_generator(self, sample_weight=None):
@@ -646,7 +640,7 @@
         optimizer='rmsprop',
         loss=loss,
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     return model
 
   def setUp(self):
@@ -764,8 +758,6 @@
     self.assertAllClose(result, expected_values)
 
   def test_fit_generator(self, reduction):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = self._get_compiled_multi_io_model(
         loss=losses.MeanSquaredError(reduction=reduction))
     history = model.fit_generator(
@@ -777,8 +769,6 @@
       self.assertAllClose(history.history[key], value)
 
   def test_eval_generator(self, reduction):
-    if testing_utils.should_run_distributed():
-      self.skipTest('b/137397816')
     model = self._get_compiled_multi_io_model(
         loss=losses.MeanSquaredError(reduction=reduction))
     eval_result = model.evaluate_generator(
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 02ebcda..7329b0b 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -1968,7 +1968,7 @@
       metrics=compile_metrics,
       optimizer='rmsprop',
       run_eagerly=testing_utils.should_run_eagerly(),
-      run_distributed=testing_utils.should_run_distributed())
+      experimental_run_tf_function=testing_utils.should_run_tf_function())
   return model
 
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index f9587b1..31dd12b 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -60,6 +60,7 @@
         ":policy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python/keras",
         "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
         "//tensorflow/python/keras/optimizer_v2",
     ],
@@ -115,6 +116,7 @@
         ":loss_scale_optimizer",
         ":test_util",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python/distribute:mirrored_strategy",
         "//tensorflow/python/distribute:one_device_strategy",
         "//tensorflow/python/keras",
@@ -144,5 +146,6 @@
         "//tensorflow/python/keras",
     ],
     shard_count = 4,
+    tags = ["no_windows"],  # b/139083295: bfloat16 tests fail on Windows
     xla_enable_strict_auto_jit = True,
 )
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
index 60cb1ca..a9fdcfc 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
@@ -123,6 +123,22 @@
         self.assertEqual(x.read_value().dtype, dtypes.float16)
 
   @parameterized.named_parameters(*TESTCASES)
+  def test_dtype_is_not_string(self, distribute):
+    with get_distribute_scope(distribute):
+      x = get_var(1., dtypes.float32)
+      x = get_autocast_var(x, distribute)
+      self.assertEqual(x.dtype, dtypes.float32)
+      self.assertIsInstance(x.dtype, dtypes.DType)
+      self.assertEqual(x.true_dtype, dtypes.float32)
+      self.assertIsInstance(x.true_dtype, dtypes.DType)
+
+      with ops.get_default_graph()._enable_auto_casting_variables('float16'):
+        self.assertEqual(x.dtype, dtypes.float16)
+        self.assertIsInstance(x.dtype, dtypes.DType)
+        self.assertEqual(x.true_dtype, dtypes.float32)
+        self.assertIsInstance(x.true_dtype, dtypes.DType)
+
+  @parameterized.named_parameters(*TESTCASES)
   def test_operator_overloads(self, distribute):
     with get_distribute_scope(distribute):
       for read_dtype in (dtypes.float32, dtypes.float16):
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index ca07a65..cc8fa18 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -38,6 +38,7 @@
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.layers import recurrent
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -173,19 +174,74 @@
 
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
-  def test_variables_in_float32(self, strategy_fn):
+  def test_infer_with_float32_vars(self, strategy_fn):
     x = constant_op.constant([1.], dtype=dtypes.float16)
-    with strategy_fn().scope():
-      with policy.policy_scope('infer_float32_vars'):
-        layer = AddLayer(assert_type=dtypes.float16)
+    with strategy_fn().scope(), policy.policy_scope('infer_float32_vars'):
+      layer = AddLayer(assert_type=dtypes.float16)
+      self.assertEqual(layer.dtype, dtypes.float32)
+      y = layer(x)
+      self.assertEqual(layer.v.dtype, dtypes.float32)
+      self.assertEqual(y.dtype, dtypes.float16)
+      self.assertEqual(layer.dtype, dtypes.float32)
+      self.assertEqual(layer._dtype_policy._name, 'float16_with_float32_vars')
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEqual(self.evaluate(y), 2.)
+
+      if base_layer_utils.v2_dtype_behavior_enabled():
+        # Layer should now cast inputs to float16
+        x = constant_op.constant([1.], dtype=dtypes.float32)
+        y = layer(x)
+        self.assertEqual(y.dtype, dtypes.float16)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_floating_point_policies_with_float32_vars(self, strategy_fn):
+    for dtype in 'bfloat16', 'float16', 'float64':
+      x = constant_op.constant([1.])
+      policy_name = dtype + '_with_float32_vars'
+      with strategy_fn().scope(), policy.policy_scope(policy_name):
+        layer = AddLayer(assert_type=dtype)
+        self.assertEqual(layer.dtype, dtypes.float32)
+        self.assertEqual(layer._dtype_policy._name, policy_name)
         y = layer(x)
         self.assertEqual(layer.v.dtype, dtypes.float32)
-        self.assertEqual(y.dtype, dtypes.float16)
+        self.assertEqual(y.dtype, dtype)
+        self.assertEqual(layer.dtype, dtypes.float32)
+        self.assertEqual(layer._dtype_policy._name, policy_name)
         self.evaluate(variables.global_variables_initializer())
         self.assertEqual(self.evaluate(y), 2.)
 
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_int32_with_float32_vars(self, strategy_fn):
+
+    # The policy int32_with_float32_vars is not useful at all (nor is any other
+    # non-float policy with float32 variables), but we have it for consistency,
+    # and so we test it.
+
+    class IdentityLayerWithVar(base_layer.Layer):
+
+      def build(self, _):
+        self.v = self.add_weight('v', ())
+
+      def call(self, inputs):
+        # Variables are only casted to other floats, not ints
+        assert array_ops.identity(self.v).dtype == 'float32'
+        return array_ops.identity(inputs)
+
+    x = constant_op.constant([1])
+    with strategy_fn().scope(), policy.policy_scope('int32_with_float32_vars'):
+      layer = IdentityLayerWithVar()
+      self.assertEqual(layer.dtype, dtypes.float32)
+      self.assertEqual(layer._dtype_policy._name, 'int32_with_float32_vars')
+      y = layer(x)
+      self.assertEqual(layer.v.dtype, dtypes.float32)
+      self.assertEqual(y.dtype, dtypes.int32)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
   def test_layer_with_non_autocast_variable(self, strategy_fn):
     x = constant_op.constant([1.], dtype=dtypes.float16)
     with strategy_fn().scope():
@@ -212,7 +268,7 @@
 
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
-  def test_layer_regularizer_runs_in_float32(self, strategy_fn):
+  def test_layer_regularizer_runs_in_var_dtype(self, strategy_fn):
     x = constant_op.constant([1.], dtype=dtypes.float16)
     with strategy_fn().scope():
       with policy.policy_scope('infer_float32_vars'):
@@ -256,6 +312,16 @@
         self.assertEqual(layer.v.dtype, dtypes.float16)
         self.assertEqual(layer.dtype, dtypes.float16)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_error_passing_policy_string_to_layer(self):
+    with self.assertRaisesRegexp(
+        TypeError, "Cannot convert value 'float16_with_float32_vars' to a "
+                   "TensorFlow DType"):
+      # This is not allowed, as otherwise a "float16_with_float32_vars" policy
+      # could be created without an API call that has the name "experimental" in
+      # it.
+      AddLayer(dtype='float16_with_float32_vars')
+
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
   def test_gradient(self, strategy_fn):
@@ -304,7 +370,7 @@
     with strategy_fn().scope():
       with policy.policy_scope(save_policy):
         layer = AddLayer(assert_type=save_input_dtype)
-        layer.build(())
+        layer(x)  # Build layer
     layer.set_weights([np.array(100.)])
     self.assertEqual(self.evaluate(layer(x)), 101.)
     checkpoint = trackable_utils.Checkpoint(layer=layer)
@@ -316,7 +382,7 @@
     with strategy_fn().scope():
       with policy.policy_scope(load_policy):
         layer = AddLayer(assert_type=load_input_dtype)
-        layer.build(())
+        layer(x)  # Build layer
     layer.set_weights([np.array(200.)])
     self.assertEqual(self.evaluate(layer(x)), 201.)
     checkpoint = trackable_utils.Checkpoint(layer=layer)
@@ -366,20 +432,26 @@
           'strategy_fn': create_mirrored_strategy,
           'use_regularizer': True
       }, {
+          'testcase_name': 'infer',
+          'strategy_fn': create_mirrored_strategy,
+          'policy_name': 'infer_with_float32_vars'
+      }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
-          'run_distributed': False
+          'experimental_run_tf_function': False
       })
+  @testing_utils.enable_v2_dtype_behavior
   def test_model(self,
                  strategy_fn,
                  use_operator=False,
                  use_regularizer=False,
-                 run_distributed=True):
+                 policy_name='float16_with_float32_vars',
+                 experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn, check_model_type=True):
       return
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
-      with policy.policy_scope('infer_float32_vars'):
+      with policy.policy_scope(policy_name):
         layer_list = []
         if testing_utils.get_model_type() == 'subclass':
           # Subclassed models do not have an Input layer, so the model does not
@@ -410,7 +482,7 @@
             opt,
             loss=loss_fn,
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones((2, 1))
     y = np.ones((2, 1))
@@ -435,9 +507,11 @@
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
-          'run_distributed': False,
+          'experimental_run_tf_function': False,
       })
-  def test_fixed_loss_scaling(self, strategy_fn, run_distributed=True):
+  def test_fixed_loss_scaling(self,
+                              strategy_fn,
+                              experimental_run_tf_function=True):
     # Note: We do not test mixed precision in this method, only loss scaling.
     if not self._is_strategy_supported(strategy_fn):
       return
@@ -467,7 +541,7 @@
           opt,
           loss=loss_fn,
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     self.assertEqual(backend.eval(layer.v), 1)
     x = np.ones((batch_size, 1))
@@ -491,6 +565,7 @@
           'strategy_fn': create_mirrored_strategy,
           'use_loss_scaling': True
       })
+  @testing_utils.enable_v2_dtype_behavior
   def test_advanced_model(self, strategy_fn, use_loss_scaling=False):
     # The advanced model tests mixed-precision-related features that would occur
     # in a resnet50 model. It tests a model that has:
@@ -507,8 +582,8 @@
     learning_rate = 2**-14
 
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('infer_float32_vars')):
-        x = layers.Input(shape=(1,), batch_size=2, dtype=dtypes.float16)
+      with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
+        x = layers.Input(shape=(1,), batch_size=2)
         layer1 = AddLayer(
             assert_type=dtypes.float16,
             regularizer=IdentityRegularizer(),
@@ -549,7 +624,7 @@
             opt,
             loss=loss_fn,
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     x = np.ones((2, 1))
     y = np.ones((2, 1))
@@ -574,9 +649,11 @@
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
-          'run_distributed': False,
+          'experimental_run_tf_function': False,
       })
-  def test_dynamic_loss_scaling(self, strategy_fn, run_distributed=True):
+  def test_dynamic_loss_scaling(self,
+                                strategy_fn,
+                                experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn):
       return
     strategy = strategy_fn()
@@ -616,7 +693,7 @@
             opt,
             loss=loss_fn,
             run_eagerly=testing_utils.should_run_eagerly(),
-            run_distributed=testing_utils.should_run_distributed())
+            experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     self.assertEqual(backend.eval(layer.v), 1)
     x = np.ones((batch_size, 1))
@@ -727,7 +804,7 @@
           optimizer=opt,
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2)
     weights_file = os.path.join(self.get_temp_dir(), 'weights')
@@ -767,7 +844,7 @@
           optimizer=opt,
           loss='mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
     # Run for 3 steps (6 examples with a batch size of 2)
     model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2)
     self.assertEqual(backend.get_value(loss_scale()), 2)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 2d8618c..1b1921f 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -59,6 +60,7 @@
 })
 
 
+@test_util.with_control_flow_v2
 class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
 
   def _run_if_in_graph_mode(self, val):
@@ -192,7 +194,12 @@
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
   def testDynamicLossScaleWithSlots(self, strategy_fn):
-    with strategy_fn().scope() as strategy:
+    strategy_obj = strategy_fn()
+    if (isinstance(strategy_obj, mirrored_strategy.MirroredStrategy) and
+        control_flow_v2_toggles.control_flow_v2_enabled() and
+        not context.executing_eagerly()):
+      self.skipTest('b/138667997')
+    with strategy_obj.scope() as strategy:
       var = variables.Variable([1.0, 2.0])
       # An SGD optimizer with momentum has slot variables.
       opt = gradient_descent.SGD(1.0, momentum=1.)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index d90906f..a4f5f9f 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -19,121 +19,292 @@
 
 import contextlib
 
+import six
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras import backend
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export('keras.mixed_precision.experimental.Policy')
 class Policy(object):
-  """A mixed precision policy for a Keras layer.
+  """A dtype policy for a Keras layer.
 
-  A mixed precision policy determines the floating-point dtype that Keras layers
-  should create variables in. For non-default policies, if the variable dtype
-  does not match the input dtype, variables will automatically be casted to the
-  input dtype to avoid type errors. Policies can be passed to the 'dtype'
+  A dtype policy determines the computation dtype and the variable dtype of a
+  Keras layer. Each layer has a policy. Policies can be passed to the 'dtype'
   argument of layer constructors, or a global policy can be set with
-  'set_policy'.
+  'tf.keras.mixed_precision.experimental.set_policy'. A layer will default to
+  the global policy if no policy is passed to it's constructor.
 
-  In the near future, policies will also determine the computation dtype of
-  layers, as well as the loss scaling algorithm.
+  For most models, each layer will have the same computation dtype and variable
+  dtype, which will typically be float32. However, when mixed precision
+  training is used, most layers will instead have a float16 computation dtype
+  and a float32 variable dtype. See [this
+  link](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
+  for more information on mixed precision training. When the variable dtype does
+  not match the computation dtype, variables will be automatically casted to the
+  computation dtype to avoid type errors.
 
-  Policies are intended to enable mixed precision training, which require using
-  float32 variables and [b]float16 computations for most layers. The term "mixed
-  precision" refers to the use of both float16 (or bfloat16) and float32 in a
-  model. See https://arxiv.org/abs/1710.03740 for more information on mixed
-  precision training.
+  In the near future, policies will also determine the loss scaling algorithm
+  for Keras models.
 
-  Policies are constructed by passing a string to the `name` constructor
-  argument. `name` determines the behavior of the policy. Currently, `name` can
-  be one of the following values.
+  Policies are constructed by passing a string to the constructor, e.g.
+  `tf.keras.mixed_precision.experimental.Policy('float32')`. The string
+  determines the compute and variable dtypes. Currently, it can be one of
+  in one of the following forms:
 
-    * 'infer': Infer the variable and computation dtypes from the input dtype.
-      This is the default behavior.
-    * 'infer_float32_vars': Infer the computation dtypes from the input
-      dtype, but create variables in float32. Variables will be casted to the
-      computation dtype. This is intended to enable mixed precision. Users can
-      cast tensors to float16 before passing them to a layer, which causes the
-      layer to run it's computation in float16 while keeping variables in
-      float32.
+    * Any dtype name, such as 'float32' or 'float64'. Both the variable and
+      compute dtypes will be that dtype.
+    * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
+      will be <dtype>, while the variable dtype is float32. This is intended for
+      the use of mixed precision, which uses float16 or bfloat16 for most
+      computations, and float32 for variables. This policy is only useful if
+      <dtype> is float16 or bfloat16, although <dtype> is allowed to be any
+      dtype. Note we will have a "mixed" policy in the future, which will make
+      it even easier to use mixed  precision by enabling other features such as
+      loss scaling.
 
-  To use mixed precision in a model, the 'infer_float32_vars' policy can be used
-  alongside float16 input tensors, which results in float16 computations and
-  float32 variables. For example:
+  ### How to use mixed precision in layers with Policies
+
+  To use mixed precision in a model, the 'float16_with_float32_vars' policy can
+  be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
+  the default policy for layers if no policy is passed to them. Note loss
+  scaling must also be done, e.g. with a
+  `tf.keras.mixed_precision.experimental.LossScaleOptimizer`. For example
 
   ```python
-  tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')
+  tf.keras.mixed_precision.experimental.set_policy(
+      'float16_with_float32_vars')
   model = tf.keras.models.Sequential(
-      tf.keras.layers.Input((100,), dtype='float16'),
+      tf.keras.layers.Input((100,)),
+      # Dense layers use global policy of 'float16_with_float32_vars'
       tf.keras.layers.Dense(10),
       tf.keras.layers.Dense(10),
-      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
-      tf.keras.layers.Activation('Softmax')
+      # Softmax should be done in float32 for numeric stability. We pass
+      # dtype='float32' to use float32 instead of the global policy.
+      tf.keras.layers.Activation('Softmax', dtype='float32')
   )
+  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
+  ... # Train `model` with `opt`.
   ```
 
   Alternatively, the policy can be passed to individual layers instead of
   setting the global policy with `set_policy`:
 
   ```python
-  policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
+  policy = tf.keras.mixed_precision.experimental.Policy(
+      'float16_with_float32_vars')
   model = tf.keras.models.Sequential(
-      tf.keras.layers.Input((100,), dtype='float16'),
+      tf.keras.layers.Input((100,)),
       tf.keras.layers.Dense(10, dtype=policy),
       tf.keras.layers.Dense(10, dtype=policy),
-      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
-      tf.keras.layers.Activation('Softmax')
+      # Softmax should be done in float32 for numeric stability.
+      tf.keras.layers.Activation('Softmax', dtype='float32')
   )
+  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
+  ... # Train `model` with `opt`.
   ```
 
-  Note that a LossScaleOptimizer should also be used for mixed precision models
-  to avoid numerical underflow. See `LossScaleOptimizer`.
+  As the above example shows, strings can be directly passed to layer
+  constructors in the `dtype` argument instead of policies, but only if the
+  string is convertible to a dtype.
+
+  ### The deprecated "infer" policy
+
+  In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
+  "infer". This Policy is deprecated, and it is not recommended. When a layer
+  has an infer policy, it will infer the computation and variable dtype from
+  the first input the first time the layer is called.
+
+  Once the layer is called for the first time, the layer's policy will change to
+  the dtype of the first input.
+
+  Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
+  that infers the compute dtype, but not the variable dtype.
+
+  In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
+  available.
   """
+  # TODO(reedwm): Replace link in above docstring with a version that is more
+  # TensorFlow-specific, and that also mentions bfloat16.
 
   def __init__(self, name):
+    """Constructs the policy.
+
+    The `name` argument determines the compute and variable dtype, and has no
+    additional effect on the Policy. The compute and variable dtypes can only be
+    specified through `name`, and cannot be specified directly.
+
+    Args:
+      name: A string. Can be one of the following values:
+        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
+          compute dtypes will be that dtype.
+        * <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
+          dtype will be <dtype>, while the variable dtype is float32. This is
+          intended for the use of mixed precision, which uses float16 or
+          bfloat16 for most computations, and float32 for variables. This policy
+          is only useful if <dtype> is float16 or bfloat16, although <dtype> is
+          allowed to be any dtype. Note we will have a "mixed" policy in the
+          future, which will make it even easier to use mixed  precision by
+          enabling other features such as loss scaling.
+        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
+          computation dtype from the input dtype.
+
+    """
+    if isinstance(name, dtypes.DType):
+      raise TypeError("'name' must be a string, not a DType. "
+                      "Instead, pass DType.name. Got: %s" % (name.name,))
+    elif not isinstance(name, six.string_types):
+      raise TypeError("'name' must be a string, but got: %s" % (name,))
+    if name == 'infer_float32_vars':
+      # For backwards compatibility. TODO(reedwm): Remove this.
+      name = 'infer_with_float32_vars'
+    if name == 'float32_with_float32_vars':
+      # Doesn't affect correctness, but causes "float32" instead of
+      # "float32_with_float32_vars" to be printed in __repr__.
+      name = 'float32'
     self._name = name
-    if name == 'infer':
-      self._default_variable_dtype = None
-    elif name == 'infer_float32_vars':
-      self._default_variable_dtype = 'float32'
-    else:
-      raise ValueError('"name" argument to Policy constructor must be "infer" '
-                       'or "infer_float32_vars", but got: %s' % name)
+    self._compute_dtype, self._variable_dtype = self._parse_name(name)
 
-  @property
-  def name(self):
-    """Returns the name of the policy: "infer" or "infer_float32_vars."""
-    return self._name
+  def _parse_name(self, name):
+    """Parses a Policy name into a compute and variable dtype.
 
-  @property
-  def default_variable_dtype(self):
-    """Returns the default variable dtype of this policy.
-
-    This is the dtype layers will create their variables in, unless a layer
-    explicit chooses a different dtype. Layers will cast variables to the
-    appropriate dtype to avoid type errors.
+    Args:
+      name: The name of the policy:
 
     Returns:
-      The default variable dtype of this policy, or None if the default variable
-      dtype should be derived from the inputs.
+      The (compute_dtype, variable_dtype) pair.
     """
-    return self._default_variable_dtype
+    if name.endswith('_with_float32_vars'):
+      base_name = name[:-len('_with_float32_vars')]
+      float32_vars = True
+    else:
+      base_name = name
+      float32_vars = False
+
+    if base_name == 'infer':
+      base_dtype = None
+    else:
+      try:
+        base_dtype = dtypes.as_dtype(base_name).name
+      except TypeError:
+        error = ('Cannot convert value %s to a mixed precision Policy. '
+                 'Valid policies include include those in the form "<dtype>" '
+                 'and "<dtype>_with_float32_vars", where <dtype> is the name '
+                 'of a dtype.' % (name,))
+        if float32_vars:
+          error += (' The value %s ends with _with_float32_vars, but %s cannot '
+                    'be converted to a DType' % (name, base_name))
+        raise ValueError(error)
+
+    if float32_vars:
+      return base_dtype, 'float32'
+    else:
+      return base_dtype, base_dtype
+
+  @property
+  def variable_dtype(self):
+    """The variable dtype of this policy.
+
+    This is the dtype layers will create their variables in, unless a layer
+    explicit chooses a different dtype. If this is different than
+    `Policy.compute_dtype` and both are non-None, Layers will cast variables to
+    the compute dtype to avoid type errors.
+
+    If this is None, the policy is "infer" and the `compute_dtype` is also None.
+    If `compute_dtype` is None, this is either None or float32.
+
+    Returns:
+      The variable dtype of this policy, or None if the variable dtype should be
+      inferred from the inputs.
+    """
+    return self._variable_dtype
+
+  @property
+  def compute_dtype(self):
+    """The compute dtype of this policy.
+
+    This is the dtype layers will do their computations in.
+
+    If this is None, the policy is "infer" or "infer_with_float32_vars" and
+    `variable_dtype` is either None or float32 respectively.
+
+    Note that even if the compute dtype is float16 or bfloat16, hardware devices
+    may not do individual adds, multiplies, and other fundamental operations in
+    [b]float16, but instead may do some of them in float32 for numeric
+    stability. The compute dtype is the dtype of the inputs and outputs of the
+    TensorFlow ops that the layer executes. Internally, many TensorFlow ops will
+    do certain internal calculations in float32, or some other device-internal
+    intermediate format with higher precision than [b]float16, to increase
+    numeric stability.
+
+    For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
+    float16 compute dtype, will pass float16 inputs to tf.matmul. But, tf.matmul
+    will do use float32 intermediate math. The performance benefit of float16 is
+    still apparent, due to increased memory bandwidth and the fact GPUs have
+    specialized hardware for computating matmuls on float16 while still keeping
+    intermediate computations in float32.
+
+    Returns:
+      The variable dtype of this policy, or None if the variable dtype should be
+      inferred from the inputs.
+    """
+    return self._compute_dtype
 
   @property
   def should_cast_variables(self):
-    """Returns true if variables should be casted."""
-    return self.default_variable_dtype is not None
+    """Returns True if variables should be casted.
 
-  # TODO(reedwm): Implement get_config/from_config.
+    This is true if the variable dtype is not the same as the compute dtype.
+
+    Returns:
+      True, if variables should be casted.
+    """
+    return self.variable_dtype != self.compute_dtype
+
+  @property
+  def name(self):
+    """Returns the name of this policy."""
+    return self._name
+
+  def __repr__(self):
+    return '<Policy "%s">' % self._name
 
 
-# The policy in effect when TensorFlow starts. This is constant and never
-# changes.
-_default_policy = Policy('infer')
+def with_input_dtype(policy, dtype):
+  """Copies "infer" `policy`, adding `dtype` to it.
 
-# The current global policy in effect. This starts as the default policy, but
-# can be changed with `set_policy`.
+  Policy must be "infer" or "infer_float32_vars" (i.e., has no compute dtype).
+  Returns a new policy with compute dtype `dtype`. The returned policy's
+  variable dtype is also `dtype` if `policy` is "infer", and is `float32` if
+  `policy` is "infer_with_float32_vars".
+
+  Args:
+    policy: An "infer" or "infer_float32_vars" policy
+    dtype: The dtype of an input to a layer.
+
+  Returns:
+    A new policy copied from `policy`, but with compute dtype and maybe
+    variable_dtype set to `dtype`.
+  """
+  assert not policy.compute_dtype
+  dtype = dtypes.as_dtype(dtype).name
+  if policy.variable_dtype is None:
+    return Policy(dtype)
+  else:
+    # Policies without a compute dtype are either "infer" or
+    # "infer_with_float32_vars", so the variable_dtype must be float32 here.
+    assert policy.variable_dtype == 'float32'
+    return Policy(dtype + '_with_float32_vars')
+
+
+# The current global policy in effect. If None, it means the current value of
+# floatx should be used as the policy if the V2 dtype behavior is enabled,
+# or "infer" otherwise.
 # TODO(reedwm): Make this thread local?
-_global_policy = _default_policy
+_global_policy = None
 
 
 @keras_export('keras.mixed_precision.experimental.global_policy')
@@ -141,15 +312,29 @@
   """Returns the global Policy.
 
   The global policy is the default policy used for layers, if no policy is
-  passed to the layer constructor. When TensorFlow starts, the global policy is
-  set to an "infer" policy, and can be changed with `set_policy`.
+  passed to the layer constructor. If no policy has been set with
+  `keras.mixed_precision.experimental.set_policy`, this will return a policy
+  constructed from `tf.keras.backend.floatx()` in TensorFlow 2, or an "infer"
+  policy in TensorFlow 1.
+
+  See `keras.mixed_precision.experimental.Policy` for more information.
 
   Returns:
     The global Policy.
   """
+  if _global_policy is None:
+    if base_layer_utils.v2_dtype_behavior_enabled():
+      return Policy(backend.floatx())
+    else:
+      return Policy('infer')
   return _global_policy
 
 
+def policy_defaults_to_floatx():
+  """Returns True if `global_policy()` will use the current value of floatx."""
+  return _global_policy is None and base_layer_utils.v2_dtype_behavior_enabled()
+
+
 def _check_if_mixed_precision_graph_rewrite_is_enabled():
   # TODO(reedwm): Update this comment once the Keras API is complete.
   if mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
@@ -170,19 +355,43 @@
 
 @keras_export('keras.mixed_precision.experimental.set_policy')
 def set_policy(policy):
-  """Sets the global Policy."""
+  """Sets the global Policy.
+
+  The global policy is the default policy used for layers, if no policy is
+  passed to the layer constructor. If no global policy is set, layers will
+  instead default to a Policy constructed from `tf.keras.backend.floatx()` in
+  TensorFlow 2. In TensorFlow 1, layers default to an "infer" policy.
+
+  See `keras.mixed_precision.experimental.Policy` for more information.
+
+  Args:
+    policy: A Policy, or a string that will be converted to a Policy..
+  """
   global _global_policy
   _check_if_mixed_precision_graph_rewrite_is_enabled()
-  if not isinstance(policy, Policy):
+  if policy is not None and not isinstance(policy, Policy):
     policy = Policy(policy)
+  if (policy and not base_layer_utils.v2_dtype_behavior_enabled() and
+      policy.compute_dtype):
+    raise ValueError(
+        'The global policy can only be set to a non-infer policy in TensorFlow '
+        '2')
   _global_policy = policy
   mixed_precision_global_state.using_default_mixed_precision_policy = (
-      _global_policy is _default_policy)
+      _global_policy is None)
 
 
 # TODO(reedwm): Make this thread local
 @contextlib.contextmanager
 def policy_scope(policy):
+  """A context manager that sets the global Policy under it.
+
+  Args:
+    policy: A Policy, or a string that will be converted to a Policy..
+
+  Yields:
+    Nothing.
+  """
   old_policy = _global_policy
   try:
     set_policy(policy)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
index a48ecd7..15a237d 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
@@ -18,8 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
@@ -30,43 +33,110 @@
 class PolicyTest(test.TestCase):
   """Tests Policies."""
 
-  def test_infer(self):
+  @testing_utils.enable_v2_dtype_behavior
+  def test_dtype_attributes(self):
     policy = mp_policy.Policy('infer')
-    self.assertEqual(policy.name, 'infer')
-    self.assertEqual(policy.default_variable_dtype, None)
+    self.assertEqual(policy.compute_dtype, None)
+    self.assertEqual(policy.variable_dtype, None)
 
-  def test_infer_float32_vars(self):
     policy = mp_policy.Policy('infer_float32_vars')
-    self.assertEqual(policy.name, 'infer_float32_vars')
-    self.assertEqual(policy.default_variable_dtype, 'float32')
+    self.assertEqual(policy.compute_dtype, None)
+    self.assertEqual(policy.variable_dtype, 'float32')
 
+    for dtype in 'int32', 'bool', 'float16', 'float32':
+      policy = mp_policy.Policy(dtype)
+      self.assertEqual(policy.compute_dtype, dtype)
+      self.assertEqual(policy.variable_dtype, dtype)
+
+      policy = mp_policy.Policy(dtype + '_with_float32_vars')
+      self.assertEqual(policy.compute_dtype, dtype)
+      self.assertEqual(policy.variable_dtype, 'float32')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_repr(self):
+    for policy in ('infer', 'infer_with_float32_vars', 'float32',
+                   'float16_with_float32_vars'):
+      self.assertEqual(repr(mp_policy.Policy(policy)),
+                       '<Policy "%s">' % policy)
+    self.assertEqual(repr(mp_policy.Policy('float32_with_float32_vars')),
+                     '<Policy "float32">')
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_policy_errors(self):
+    # Test passing invalid strings
+    expected_error = 'Cannot convert value %s to a mixed precision Policy.'
+
+    for invalid_policy in ('abc', 'abc_with_float32_vars',
+                           'float32_with_float16_vars'):
+      with self.assertRaisesRegexp(ValueError,
+                                   expected_error % invalid_policy):
+        mp_policy.Policy(invalid_policy)
+
+    # Test passing a DType
+    with self.assertRaisesRegexp(TypeError,
+                                 "'name' must be a string, not a DType. "
+                                 "Instead, pass DType.name. Got: float16"):
+      mp_policy.Policy(dtypes.float16)
+
+    # Test passing a non-DType invalid type
+    with self.assertRaisesRegexp(TypeError,
+                                 "'name' must be a string, but got: 5"):
+      mp_policy.Policy(5)
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_with_input_dtype(self):
+    policy = mp_policy.with_input_dtype(mp_policy.Policy('infer'), 'float16')
+    self.assertEqual(policy.compute_dtype, 'float16')
+    self.assertEqual(policy.variable_dtype, 'float16')
+
+    policy = mp_policy.with_input_dtype(
+        mp_policy.Policy('infer_with_float32_vars'), 'float16')
+    self.assertEqual(policy.compute_dtype, 'float16')
+    self.assertEqual(policy.variable_dtype, 'float32')
+
+    policy = mp_policy.with_input_dtype(
+        mp_policy.Policy('infer_with_float32_vars'), 'float32')
+    self.assertEqual(policy.compute_dtype, 'float32')
+    self.assertEqual(policy.variable_dtype, 'float32')
+
+  @testing_utils.enable_v2_dtype_behavior
   def test_global_policy(self):
-    self.assertEqual(mp_policy.global_policy().name, 'infer')
-    default_policy = mp_policy.global_policy()
+    if base_layer_utils.v2_dtype_behavior_enabled():
+      default_policy = 'float32'
+    else:
+      default_policy = 'infer'
+    self.assertEqual(mp_policy.global_policy().name, default_policy)
     try:
-      mp_policy.set_policy('infer_float32_vars')
-      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
-      self.assertEqual(mp_policy.global_policy().default_variable_dtype,
-                       'float32')
+      mp_policy.set_policy('infer_with_float32_vars')
+      self.assertEqual(mp_policy.global_policy().name,
+                       'infer_with_float32_vars')
       with ops.Graph().as_default():  # Policies are not associated with a graph
-        self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
+        self.assertEqual(mp_policy.global_policy().name,
+                         'infer_with_float32_vars')
       mp_policy.set_policy('infer')
       self.assertEqual(mp_policy.global_policy().name, 'infer')
-      self.assertEqual(mp_policy.global_policy().default_variable_dtype, None)
-      policy = mp_policy.Policy('infer_float32_vars')
+      policy = mp_policy.Policy('infer_with_float32_vars')
       mp_policy.set_policy(policy)
       self.assertIs(mp_policy.global_policy(), policy)
     finally:
-      mp_policy.set_policy(default_policy)
+      mp_policy.set_policy(None)
 
+  @testing_utils.enable_v2_dtype_behavior
   def test_policy_scope(self):
-    with mp_policy.policy_scope('infer_float32_vars'):
-      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
+    if base_layer_utils.v2_dtype_behavior_enabled():
+      default_policy = 'float32'
+    else:
+      default_policy = 'infer'
+    with mp_policy.policy_scope('infer_with_float32_vars'):
+      self.assertEqual(mp_policy.global_policy().name,
+                       'infer_with_float32_vars')
       with mp_policy.policy_scope('infer'):
         self.assertEqual(mp_policy.global_policy().name, 'infer')
-      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
-    self.assertEqual(mp_policy.global_policy().name, 'infer')
+      self.assertEqual(mp_policy.global_policy().name,
+                       'infer_with_float32_vars')
+    self.assertEqual(mp_policy.global_policy().name, default_policy)
 
+  @testing_utils.enable_v2_dtype_behavior
   def test_error_if_graph_rewrite_enabled(self):
     try:
       mixed_precision.enable_mixed_precision_graph_rewrite(
@@ -78,6 +148,27 @@
     finally:
       mixed_precision.disable_mixed_precision_graph_rewrite()
 
+  @testing_utils.disable_v2_dtype_behavior
+  def test_v1_dtype_behavior(self):
+    # These policies are allowed with V1 dtype behavior
+    with mp_policy.policy_scope(mp_policy.Policy('infer')):
+      pass
+    with mp_policy.policy_scope(mp_policy.Policy('infer_float32_vars')):
+      pass
+
+    # These policies are not allowed with V1 dtype behavior
+    with self.assertRaisesRegexp(
+        ValueError,
+        'global policy can only be set to a non-infer policy in TensorFlow 2'):
+      with mp_policy.policy_scope(mp_policy.Policy('float32')):
+        pass
+    with self.assertRaisesRegexp(
+        ValueError,
+        'global policy can only be set to a non-infer policy in TensorFlow 2'):
+      with mp_policy.policy_scope(
+          mp_policy.Policy('float16_with_float32_vars')):
+        pass
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/model_subclassing_compiled_test.py b/tensorflow/python/keras/model_subclassing_compiled_test.py
new file mode 100644
index 0000000..180e8c8
--- /dev/null
+++ b/tensorflow/python/keras/model_subclassing_compiled_test.py
@@ -0,0 +1,475 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for compiled Model subclassing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import model_subclassing_test_util as model_util
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+try:
+  import h5py  # pylint:disable=g-import-not-at-top
+except ImportError:
+  h5py = None
+
+
+@keras_parameterized.run_all_keras_modes
+class ModelSubclassCompiledTest(keras_parameterized.TestCase):
+
+  def test_single_io_workflow_with_np_arrays(self):
+    num_classes = 2
+    num_samples = 100
+    input_dim = 50
+
+    model = model_util.SimpleTestModel(
+        num_classes=num_classes, use_dp=True, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc', keras.metrics.CategoricalAccuracy()],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x = np.ones((num_samples, input_dim))
+    y = np.zeros((num_samples, num_classes))
+
+    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate(x, y, verbose=0)
+
+  def test_multi_io_workflow_with_np_arrays(self):
+    num_classes = (2, 3)
+    num_samples = 1000
+    input_dim = 50
+
+    model = model_util.MultiIOTestModel(
+        num_classes=num_classes, use_dp=True, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc'],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x1 = np.ones((num_samples, input_dim))
+    x2 = np.ones((num_samples, input_dim))
+    y1 = np.zeros((num_samples, num_classes[0]))
+    y2 = np.zeros((num_samples, num_classes[1]))
+
+    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate([x1, x2], [y1, y2], verbose=0)
+
+  def test_single_io_workflow_with_datasets(self):
+    num_classes = 2
+    num_samples = 10
+    input_dim = 50
+
+    with self.cached_session():
+      model = model_util.SimpleTestModel(
+          num_classes=num_classes, use_dp=True, use_bn=True)
+      model.compile(
+          loss='mse',
+          optimizer='rmsprop',
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+      x = np.ones((num_samples, input_dim), dtype=np.float32)
+      y = np.zeros((num_samples, num_classes), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      model.fit(dataset, epochs=2, steps_per_epoch=10, verbose=0)
+      _ = model.evaluate(dataset, steps=10, verbose=0)
+
+  def test_attributes(self):
+    # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs
+
+    num_classes = (2, 3)
+    num_samples = 100
+    input_dim = 50
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+
+    x1 = np.ones((num_samples, input_dim))
+    x2 = np.ones((num_samples, input_dim))
+    y1 = np.zeros((num_samples, num_classes[0]))
+    y2 = np.zeros((num_samples, num_classes[1]))
+
+    self.assertEqual(model.name, 'test_model')
+    self.assertEqual(model.built, False)
+    self.assertEqual(len(model.weights), 0)
+
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.train_on_batch([x1, x2], [y1, y2])
+
+    self.assertEqual(model.built, True)
+    self.assertEqual(len(model.layers), 4)
+    self.assertEqual(len(model.weights), 10)
+    self.assertEqual(len(model.trainable_weights), 8)
+    self.assertEqual(len(model.non_trainable_weights), 2)
+    self.assertEqual(len(model.inputs), 2)
+    self.assertEqual(len(model.outputs), 2)
+
+  def test_updates(self):
+    # test that updates get run during training
+    num_samples = 100
+    input_dim = 50
+
+    class BNNet(keras.Model):
+
+      def __init__(self):
+        super(BNNet, self).__init__()
+        self.bn = keras.layers.BatchNormalization(beta_initializer='ones',
+                                                  gamma_initializer='ones')
+
+      def call(self, inputs):
+        return self.bn(inputs)
+
+    x = np.ones((num_samples, input_dim))
+    y = np.ones((num_samples, input_dim))
+
+    model = BNNet()
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    y_ref = model.predict(x)
+
+    model.train_on_batch(x, y)
+    y_new = model.predict(x)
+    self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1)
+
+  def test_training_and_inference_behavior(self):
+    # test that dropout is applied in training and not inference
+
+    num_samples = 100
+    input_dim = 50
+
+    class DPNet(keras.Model):
+
+      def __init__(self):
+        super(DPNet, self).__init__()
+        self.dp = keras.layers.Dropout(0.5)
+        self.dense = keras.layers.Dense(1,
+                                        use_bias=False,
+                                        kernel_initializer='ones')
+
+      def call(self, inputs):
+        x = self.dp(inputs)
+        return self.dense(x)
+
+    model = DPNet()
+    x = np.ones((num_samples, input_dim))
+    y = model.predict(x)
+    self.assertEqual(np.sum(y), np.sum(x))
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    loss = model.train_on_batch(x, y)
+    self.assertGreater(loss, 0.1)
+
+  def test_training_methods(self):
+    # test fit, train_on_batch
+    # on different input types: list, dict
+
+    num_classes = (2, 3)
+    num_samples = 100
+    input_dim = 50
+
+    x1 = np.ones((num_samples, input_dim))
+    x2 = np.ones((num_samples, input_dim))
+    y1 = np.zeros((num_samples, num_classes[0]))
+    y2 = np.zeros((num_samples, num_classes[1]))
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
+    model.fit({'input_1': x1, 'input_2': x2},
+              {'output_1': y1, 'output_2': y2},
+              epochs=2, batch_size=32)
+    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0,
+              validation_data=([x1, x2], [y1, y2]))
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.train_on_batch([x1, x2], [y1, y2])
+    model.train_on_batch({'input_1': x1, 'input_2': x2},
+                         {'output_1': y1, 'output_2': y2})
+
+  def test_inference_methods(self):
+    # test predict, evaluate, test_on_batch, predict_on_batch
+    # on different input types: list, dict
+    num_classes = (2, 3)
+    num_samples = 100
+    input_dim = 50
+
+    x1 = np.ones((num_samples, input_dim))
+    x2 = np.ones((num_samples, input_dim))
+    y1 = np.zeros((num_samples, num_classes[0]))
+    y2 = np.zeros((num_samples, num_classes[1]))
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.evaluate([x1, x2], [y1, y2])
+    model.test_on_batch([x1, x2], [y1, y2])
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.predict([x1, x2])
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.predict_on_batch([x1, x2])
+
+  def test_saving(self):
+    num_classes = (2, 3)
+    num_samples = 100
+    input_dim = 50
+
+    x1 = np.ones((num_samples, input_dim))
+    x2 = np.ones((num_samples, input_dim))
+    y1 = np.zeros((num_samples, num_classes[0]))
+    y2 = np.zeros((num_samples, num_classes[1]))
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
+    y_ref_1, y_ref_2 = model.predict([x1, x2])
+
+    tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
+    model.save_weights(tf_format_name)
+    if h5py is not None:
+      hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
+      model.save_weights(hdf5_format_name)
+
+    model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True)
+
+    if h5py is not None:
+      with self.assertRaises(ValueError):
+        model.load_weights(hdf5_format_name)
+
+    model.load_weights(tf_format_name)
+
+    y1, y2 = model.predict([x1, x2])
+    self.assertAllClose(y_ref_1, y1, atol=1e-5)
+    self.assertAllClose(y_ref_2, y2, atol=1e-5)
+
+    if h5py is not None:
+      model.load_weights(hdf5_format_name)
+
+      y1, y2 = model.predict([x1, x2])
+      self.assertAllClose(y_ref_1, y1, atol=1e-5)
+      self.assertAllClose(y_ref_2, y2, atol=1e-5)
+
+  def test_subclass_nested_in_subclass(self):
+    num_classes = 2
+    num_samples = 100
+    input_dim = 50
+
+    model = model_util.NestedTestModel1(num_classes=num_classes)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc'],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x = np.ones((num_samples, input_dim))
+    y = np.zeros((num_samples, num_classes))
+
+    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate(x, y, verbose=0)
+
+    self.assertEqual(len(model.weights), 8 + len(model.test_net.weights))
+    self.assertEqual(len(model.non_trainable_weights),
+                     2 + len(model.test_net.non_trainable_weights))
+    self.assertEqual(len(model.trainable_weights),
+                     6 + len(model.test_net.trainable_weights))
+
+  def test_graph_nested_in_subclass(self):
+    num_classes = 2
+    num_samples = 100
+    input_dim = 50
+
+    model = model_util.NestedTestModel2(num_classes=num_classes)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc'],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x = np.ones((num_samples, input_dim))
+    y = np.zeros((num_samples, num_classes))
+
+    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate(x, y, verbose=0)
+
+    self.assertEqual(len(model.weights), 8 + len(model.test_net.weights))
+    self.assertEqual(len(model.non_trainable_weights),
+                     2 + len(model.test_net.non_trainable_weights))
+    self.assertEqual(len(model.trainable_weights),
+                     6 + len(model.test_net.trainable_weights))
+
+  def test_subclass_nested_in_graph(self):
+    num_classes = 2
+    num_samples = 100
+    input_dim = 50
+
+    model = model_util.get_nested_model_3(
+        input_dim=input_dim, num_classes=num_classes)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc'],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x = np.ones((num_samples, input_dim))
+    y = np.zeros((num_samples, num_classes))
+
+    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate(x, y, verbose=0)
+
+    self.assertEqual(len(model.weights), 16)
+    self.assertEqual(len(model.non_trainable_weights), 4)
+    self.assertEqual(len(model.trainable_weights), 12)
+
+  def test_subclass_nested_in_sequential(self):
+    num_classes = 2
+    num_samples = 100
+    input_dim = 50
+
+    class Inner(keras.Model):
+
+      def __init__(self):
+        super(Inner, self).__init__()
+        self.dense1 = keras.layers.Dense(32, activation='relu')
+        self.dense2 = keras.layers.Dense(num_classes, activation='relu')
+        self.bn = keras.layers.BatchNormalization()
+
+      def call(self, inputs):
+        x = self.dense1(inputs)
+        x = self.dense2(x)
+        return self.bn(x)
+
+    model = keras.Sequential([Inner()])
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=['acc'],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    x = np.ones((num_samples, input_dim))
+    y = np.zeros((num_samples, num_classes))
+    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+    _ = model.evaluate(x, y, verbose=0)
+
+    self.assertEqual(len(model.weights), 8)
+    self.assertEqual(len(model.non_trainable_weights), 2)
+    self.assertEqual(len(model.trainable_weights), 6)
+
+  def test_support_for_manual_training_arg(self):
+    # In most cases, the `training` argument is left unspecified, in which
+    # case it defaults to value corresponding to the Model method being used
+    # (fit -> True, predict -> False, etc).
+    # If the user writes their model `call` method to take
+    # an explicit `training` argument, we must check that the correct value
+    # is being passed to the model for each method call.
+
+    class DPNet(keras.Model):
+
+      def __init__(self):
+        super(DPNet, self).__init__()
+        self.dp = keras.layers.Dropout(0.5)
+        self.dense = keras.layers.Dense(1,
+                                        use_bias=False,
+                                        kernel_initializer='ones')
+
+      def call(self, inputs, training=False):
+        x = self.dp(inputs, training=training)
+        return self.dense(x)
+
+    model = DPNet()
+    x = np.ones((10, 10))
+    y = model.predict(x)
+    self.assertEqual(np.sum(y), np.sum(x))
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    loss = model.train_on_batch(x, y)
+    self.assertGreater(loss, 0.1)
+
+  def test_no_loss_in_compile(self):
+
+    class InternalLossModel(keras.Model):
+
+      def __init__(self):
+        super(InternalLossModel, self).__init__()
+        self.dense = keras.layers.Dense(1)
+
+      def call(self, inputs):
+        out = self.dense(inputs)
+        self.add_loss(math_ops.reduce_sum(out))
+        return out
+
+    model = InternalLossModel()
+    x = np.ones((10, 10))
+    model.predict(x)
+    model.compile(
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.fit(x)
+    model.evaluate(x)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 9cf1932..ac9e29d 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -23,15 +23,14 @@
 import numpy as np
 
 from tensorflow.python import keras
-from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import model_subclassing_test_util as model_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -44,150 +43,6 @@
   h5py = None
 
 
-# pylint: disable=not-callable
-class SimpleTestModel(keras.Model):
-
-  def __init__(self, use_bn=False, use_dp=False, num_classes=10):
-    super(SimpleTestModel, self).__init__(name='test_model')
-    self.use_bn = use_bn
-    self.use_dp = use_dp
-    self.num_classes = num_classes
-
-    self.dense1 = keras.layers.Dense(32, activation='relu')
-    self.dense2 = keras.layers.Dense(num_classes, activation='softmax')
-    if self.use_dp:
-      self.dp = keras.layers.Dropout(0.5)
-    if self.use_bn:
-      self.bn = keras.layers.BatchNormalization(axis=-1)
-
-  def call(self, x):
-    x = self.dense1(x)
-    if self.use_dp:
-      x = self.dp(x)
-    if self.use_bn:
-      x = self.bn(x)
-    return self.dense2(x)
-
-
-class SimpleConvTestModel(keras.Model):
-
-  def __init__(self, num_classes=10):
-    super(SimpleConvTestModel, self).__init__(name='test_model')
-    self.num_classes = num_classes
-
-    self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu')
-    self.flatten = keras.layers.Flatten()
-    self.dense1 = keras.layers.Dense(num_classes, activation='softmax')
-
-  def call(self, x):
-    x = self.conv1(x)
-    x = self.flatten(x)
-    return self.dense1(x)
-
-
-class MultiIOTestModel(keras.Model):
-
-  def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)):
-    super(MultiIOTestModel, self).__init__(name='test_model')
-    self.use_bn = use_bn
-    self.use_dp = use_dp
-    self.num_classes = num_classes
-
-    self.dense1 = keras.layers.Dense(32, activation='relu')
-    self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax')
-    self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax')
-    if use_dp:
-      self.dp = keras.layers.Dropout(0.5)
-    if use_bn:
-      self.bn = keras.layers.BatchNormalization()
-
-  def call(self, inputs):
-    x1, x2 = inputs
-    x1 = self.dense1(x1)
-    x2 = self.dense1(x2)
-    if self.use_dp:
-      x1 = self.dp(x1)
-    if self.use_bn:
-      x2 = self.bn(x2)
-    return [self.dense2(x1), self.dense3(x2)]
-
-
-class NestedTestModel1(keras.Model):
-  """A model subclass nested inside a model subclass.
-  """
-
-  def __init__(self, num_classes=2):
-    super(NestedTestModel1, self).__init__(name='nested_model_1')
-    self.num_classes = num_classes
-    self.dense1 = keras.layers.Dense(32, activation='relu')
-    self.dense2 = keras.layers.Dense(num_classes, activation='relu')
-    self.bn = keras.layers.BatchNormalization()
-    self.test_net = SimpleTestModel(num_classes=4,
-                                    use_bn=True,
-                                    use_dp=True)
-
-  def call(self, inputs):
-    x = self.dense1(inputs)
-    x = self.bn(x)
-    x = self.test_net(x)
-    return self.dense2(x)
-
-
-def get_functional_graph_model(input_dim, num_classes):
-  # A simple functional-API model (a.k.a. graph network)
-  inputs = keras.Input(shape=(input_dim,))
-  x = keras.layers.Dense(32, activation='relu')(inputs)
-  x = keras.layers.BatchNormalization()(x)
-  outputs = keras.layers.Dense(num_classes)(x)
-  return keras.Model(inputs, outputs)
-
-
-class NestedTestModel2(keras.Model):
-  """A model subclass with a functional-API graph network inside.
-  """
-
-  def __init__(self, num_classes=2):
-    super(NestedTestModel2, self).__init__(name='nested_model_2')
-    self.num_classes = num_classes
-    self.dense1 = keras.layers.Dense(32, activation='relu')
-    self.dense2 = keras.layers.Dense(num_classes, activation='relu')
-    self.bn = self.bn = keras.layers.BatchNormalization()
-    self.test_net = get_functional_graph_model(32, 4)
-
-  def call(self, inputs):
-    x = self.dense1(inputs)
-    x = self.bn(x)
-    x = self.test_net(x)
-    return self.dense2(x)
-
-
-def get_nested_model_3(input_dim, num_classes):
-  # A functional-API model with a subclassed model inside.
-  # NOTE: this requires the inner subclass to implement `compute_output_shape`.
-
-  inputs = keras.Input(shape=(input_dim,))
-  x = keras.layers.Dense(32, activation='relu')(inputs)
-  x = keras.layers.BatchNormalization()(x)
-
-  class Inner(keras.Model):
-
-    def __init__(self):
-      super(Inner, self).__init__()
-      self.dense1 = keras.layers.Dense(32, activation='relu')
-      self.dense2 = keras.layers.Dense(5, activation='relu')
-      self.bn = keras.layers.BatchNormalization()
-
-    def call(self, inputs):
-      x = self.dense1(inputs)
-      x = self.dense2(x)
-      return self.bn(x)
-
-  test_model = Inner()
-  x = test_model(x)
-  outputs = keras.layers.Dense(num_classes)(x)
-  return keras.Model(inputs, outputs, name='nested_model_3')
-
-
 @keras_parameterized.run_all_keras_modes
 class ModelSubclassingTest(keras_parameterized.TestCase):
 
@@ -242,7 +97,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2, epochs=2)
     self.assertLen(model.layers, 2)
     self.assertLen(model.trainable_variables, 4)
@@ -251,9 +106,8 @@
     num_classes = 2
     input_dim = 50
 
-    model = SimpleTestModel(num_classes=num_classes,
-                            use_dp=True,
-                            use_bn=True)
+    model = model_util.SimpleTestModel(
+        num_classes=num_classes, use_dp=True, use_bn=True)
 
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
@@ -329,9 +183,8 @@
     input_dim = 50
     batch_size = None
 
-    model = SimpleTestModel(num_classes=num_classes,
-                            use_dp=True,
-                            use_bn=True)
+    model = model_util.SimpleTestModel(
+        num_classes=num_classes, use_dp=True, use_bn=True)
 
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
@@ -347,9 +200,8 @@
     input_dim = tensor_shape.Dimension(50)
     batch_size = tensor_shape.Dimension(None)
 
-    model = SimpleTestModel(num_classes=num_classes,
-                            use_dp=True,
-                            use_bn=True)
+    model = model_util.SimpleTestModel(
+        num_classes=num_classes, use_dp=True, use_bn=True)
 
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
@@ -366,7 +218,7 @@
     batch_size = 32
     input_shape = (32, 32, 3)
 
-    model = SimpleConvTestModel(num_classes)
+    model = model_util.SimpleConvTestModel(num_classes)
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -384,7 +236,7 @@
     batch_size = None
     input_shape = (32, 32, 3)
 
-    model = SimpleConvTestModel(num_classes)
+    model = model_util.SimpleConvTestModel(num_classes)
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -402,7 +254,7 @@
     batch_size = None
     input_shape = (32, 32, 3)
 
-    model = SimpleConvTestModel(num_classes)
+    model = model_util.SimpleConvTestModel(num_classes)
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -419,7 +271,7 @@
       hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
       model.save_weights(hdf5_format_name)
 
-    model = SimpleConvTestModel(num_classes)
+    model = model_util.SimpleConvTestModel(num_classes)
     model.build(
         input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
     if h5py is not None:
@@ -432,7 +284,7 @@
     batch_size = None
     num_samples = 1000
     input_dim = 50
-    model = MultiIOTestModel()
+    model = model_util.MultiIOTestModel()
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -457,14 +309,15 @@
         self.contents += msg + '\n'
 
     # Single-io
-    model = SimpleTestModel(num_classes=4, use_bn=True, use_dp=True)
+    model = model_util.SimpleTestModel(num_classes=4, use_bn=True, use_dp=True)
     model._set_inputs(np.ones((3, 4)))  # need to build model first
     print_fn = ToString()
     model.summary(print_fn=print_fn)
     self.assertTrue('Trainable params: 356' in print_fn.contents)
 
     # Multi-io
-    model = MultiIOTestModel(num_classes=(5, 6), use_bn=True, use_dp=True)
+    model = model_util.MultiIOTestModel(
+        num_classes=(5, 6), use_bn=True, use_dp=True)
     model._set_inputs([np.ones((3, 4)),
                        np.ones((3, 4))])  # need to build model first
     print_fn = ToString()
@@ -599,440 +452,6 @@
       self.assertEqual(1, len(model.get_updates_for(x)))
 
 
-@keras_parameterized.run_all_keras_modes
-class ModelSubclassCompiledTest(keras_parameterized.TestCase):
-
-  def test_single_io_workflow_with_np_arrays(self):
-    num_classes = 2
-    num_samples = 100
-    input_dim = 50
-
-    model = SimpleTestModel(num_classes=num_classes,
-                            use_dp=True,
-                            use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc', keras.metrics.CategoricalAccuracy()],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x = np.ones((num_samples, input_dim))
-    y = np.zeros((num_samples, num_classes))
-
-    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate(x, y, verbose=0)
-
-  def test_multi_io_workflow_with_np_arrays(self):
-    num_classes = (2, 3)
-    num_samples = 1000
-    input_dim = 50
-
-    model = MultiIOTestModel(num_classes=num_classes,
-                             use_dp=True,
-                             use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc'],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x1 = np.ones((num_samples, input_dim))
-    x2 = np.ones((num_samples, input_dim))
-    y1 = np.zeros((num_samples, num_classes[0]))
-    y2 = np.zeros((num_samples, num_classes[1]))
-
-    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate([x1, x2], [y1, y2], verbose=0)
-
-  def test_single_io_workflow_with_datasets(self):
-    num_classes = 2
-    num_samples = 10
-    input_dim = 50
-
-    with self.cached_session():
-      model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
-      model.compile(
-          loss='mse',
-          optimizer='rmsprop',
-          run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
-
-      x = np.ones((num_samples, input_dim), dtype=np.float32)
-      y = np.zeros((num_samples, num_classes), dtype=np.float32)
-      dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
-      dataset = dataset.repeat(100)
-      dataset = dataset.batch(10)
-
-      model.fit(dataset, epochs=2, steps_per_epoch=10, verbose=0)
-      _ = model.evaluate(dataset, steps=10, verbose=0)
-
-  def test_attributes(self):
-    # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs
-
-    num_classes = (2, 3)
-    num_samples = 100
-    input_dim = 50
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-
-    x1 = np.ones((num_samples, input_dim))
-    x2 = np.ones((num_samples, input_dim))
-    y1 = np.zeros((num_samples, num_classes[0]))
-    y2 = np.zeros((num_samples, num_classes[1]))
-
-    self.assertEqual(model.name, 'test_model')
-    self.assertEqual(model.built, False)
-    self.assertEqual(len(model.weights), 0)
-
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.train_on_batch([x1, x2], [y1, y2])
-
-    self.assertEqual(model.built, True)
-    self.assertEqual(len(model.layers), 4)
-    self.assertEqual(len(model.weights), 10)
-    self.assertEqual(len(model.trainable_weights), 8)
-    self.assertEqual(len(model.non_trainable_weights), 2)
-    self.assertEqual(len(model.inputs), 2)
-    self.assertEqual(len(model.outputs), 2)
-
-  def test_updates(self):
-    # test that updates get run during training
-    num_samples = 100
-    input_dim = 50
-
-    class BNNet(keras.Model):
-
-      def __init__(self):
-        super(BNNet, self).__init__()
-        self.bn = keras.layers.BatchNormalization(beta_initializer='ones',
-                                                  gamma_initializer='ones')
-
-      def call(self, inputs):
-        return self.bn(inputs)
-
-    x = np.ones((num_samples, input_dim))
-    y = np.ones((num_samples, input_dim))
-
-    model = BNNet()
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    y_ref = model.predict(x)
-
-    model.train_on_batch(x, y)
-    y_new = model.predict(x)
-    self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1)
-
-  def test_training_and_inference_behavior(self):
-    # test that dropout is applied in training and not inference
-
-    num_samples = 100
-    input_dim = 50
-
-    class DPNet(keras.Model):
-
-      def __init__(self):
-        super(DPNet, self).__init__()
-        self.dp = keras.layers.Dropout(0.5)
-        self.dense = keras.layers.Dense(1,
-                                        use_bias=False,
-                                        kernel_initializer='ones')
-
-      def call(self, inputs):
-        x = self.dp(inputs)
-        return self.dense(x)
-
-    model = DPNet()
-    x = np.ones((num_samples, input_dim))
-    y = model.predict(x)
-    self.assertEqual(np.sum(y), np.sum(x))
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    loss = model.train_on_batch(x, y)
-    self.assertGreater(loss, 0.1)
-
-  def test_training_methods(self):
-    # test fit, train_on_batch
-    # on different input types: list, dict
-
-    num_classes = (2, 3)
-    num_samples = 100
-    input_dim = 50
-
-    x1 = np.ones((num_samples, input_dim))
-    x2 = np.ones((num_samples, input_dim))
-    y1 = np.zeros((num_samples, num_classes[0]))
-    y2 = np.zeros((num_samples, num_classes[1]))
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
-    model.fit({'input_1': x1, 'input_2': x2},
-              {'output_1': y1, 'output_2': y2},
-              epochs=2, batch_size=32)
-    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0,
-              validation_data=([x1, x2], [y1, y2]))
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.train_on_batch([x1, x2], [y1, y2])
-    model.train_on_batch({'input_1': x1, 'input_2': x2},
-                         {'output_1': y1, 'output_2': y2})
-
-  def test_inference_methods(self):
-    # test predict, evaluate, test_on_batch, predict_on_batch
-    # on different input types: list, dict
-    num_classes = (2, 3)
-    num_samples = 100
-    input_dim = 50
-
-    x1 = np.ones((num_samples, input_dim))
-    x2 = np.ones((num_samples, input_dim))
-    y1 = np.zeros((num_samples, num_classes[0]))
-    y2 = np.zeros((num_samples, num_classes[1]))
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.evaluate([x1, x2], [y1, y2])
-    model.test_on_batch([x1, x2], [y1, y2])
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.predict([x1, x2])
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.predict_on_batch([x1, x2])
-
-  def test_saving(self):
-    num_classes = (2, 3)
-    num_samples = 100
-    input_dim = 50
-
-    x1 = np.ones((num_samples, input_dim))
-    x2 = np.ones((num_samples, input_dim))
-    y1 = np.zeros((num_samples, num_classes[0]))
-    y2 = np.zeros((num_samples, num_classes[1]))
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
-    y_ref_1, y_ref_2 = model.predict([x1, x2])
-
-    tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
-    model.save_weights(tf_format_name)
-    if h5py is not None:
-      hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
-      model.save_weights(hdf5_format_name)
-
-    model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
-
-    if h5py is not None:
-      with self.assertRaises(ValueError):
-        model.load_weights(hdf5_format_name)
-
-    model.load_weights(tf_format_name)
-
-    y1, y2 = model.predict([x1, x2])
-    self.assertAllClose(y_ref_1, y1, atol=1e-5)
-    self.assertAllClose(y_ref_2, y2, atol=1e-5)
-
-    if h5py is not None:
-      model.load_weights(hdf5_format_name)
-
-      y1, y2 = model.predict([x1, x2])
-      self.assertAllClose(y_ref_1, y1, atol=1e-5)
-      self.assertAllClose(y_ref_2, y2, atol=1e-5)
-
-  def test_subclass_nested_in_subclass(self):
-    num_classes = 2
-    num_samples = 100
-    input_dim = 50
-
-    model = NestedTestModel1(num_classes=num_classes)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc'],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x = np.ones((num_samples, input_dim))
-    y = np.zeros((num_samples, num_classes))
-
-    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate(x, y, verbose=0)
-
-    self.assertEqual(len(model.weights), 8 + len(model.test_net.weights))
-    self.assertEqual(len(model.non_trainable_weights),
-                     2 + len(model.test_net.non_trainable_weights))
-    self.assertEqual(len(model.trainable_weights),
-                     6 + len(model.test_net.trainable_weights))
-
-  def test_graph_nested_in_subclass(self):
-    num_classes = 2
-    num_samples = 100
-    input_dim = 50
-
-    model = NestedTestModel2(num_classes=num_classes)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc'],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x = np.ones((num_samples, input_dim))
-    y = np.zeros((num_samples, num_classes))
-
-    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate(x, y, verbose=0)
-
-    self.assertEqual(len(model.weights), 8 + len(model.test_net.weights))
-    self.assertEqual(len(model.non_trainable_weights),
-                     2 + len(model.test_net.non_trainable_weights))
-    self.assertEqual(len(model.trainable_weights),
-                     6 + len(model.test_net.trainable_weights))
-
-  def test_subclass_nested_in_graph(self):
-    num_classes = 2
-    num_samples = 100
-    input_dim = 50
-
-    model = get_nested_model_3(input_dim=input_dim, num_classes=num_classes)
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc'],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x = np.ones((num_samples, input_dim))
-    y = np.zeros((num_samples, num_classes))
-
-    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate(x, y, verbose=0)
-
-    self.assertEqual(len(model.weights), 16)
-    self.assertEqual(len(model.non_trainable_weights), 4)
-    self.assertEqual(len(model.trainable_weights), 12)
-
-  def test_subclass_nested_in_sequential(self):
-    num_classes = 2
-    num_samples = 100
-    input_dim = 50
-
-    class Inner(keras.Model):
-
-      def __init__(self):
-        super(Inner, self).__init__()
-        self.dense1 = keras.layers.Dense(32, activation='relu')
-        self.dense2 = keras.layers.Dense(num_classes, activation='relu')
-        self.bn = keras.layers.BatchNormalization()
-
-      def call(self, inputs):
-        x = self.dense1(inputs)
-        x = self.dense2(x)
-        return self.bn(x)
-
-    model = keras.Sequential([Inner()])
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        metrics=['acc'],
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    x = np.ones((num_samples, input_dim))
-    y = np.zeros((num_samples, num_classes))
-    model.fit(x, y, epochs=2, batch_size=32, verbose=0)
-    _ = model.evaluate(x, y, verbose=0)
-
-    self.assertEqual(len(model.weights), 8)
-    self.assertEqual(len(model.non_trainable_weights), 2)
-    self.assertEqual(len(model.trainable_weights), 6)
-
-  def test_support_for_manual_training_arg(self):
-    # In most cases, the `training` argument is left unspecified, in which
-    # case it defaults to value corresponding to the Model method being used
-    # (fit -> True, predict -> False, etc).
-    # If the user writes their model `call` method to take
-    # an explicit `training` argument, we must check that the correct value
-    # is being passed to the model for each method call.
-
-    class DPNet(keras.Model):
-
-      def __init__(self):
-        super(DPNet, self).__init__()
-        self.dp = keras.layers.Dropout(0.5)
-        self.dense = keras.layers.Dense(1,
-                                        use_bias=False,
-                                        kernel_initializer='ones')
-
-      def call(self, inputs, training=False):
-        x = self.dp(inputs, training=training)
-        return self.dense(x)
-
-    model = DPNet()
-    x = np.ones((10, 10))
-    y = model.predict(x)
-    self.assertEqual(np.sum(y), np.sum(x))
-    model.compile(
-        loss='mse',
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    loss = model.train_on_batch(x, y)
-    self.assertGreater(loss, 0.1)
-
-  def test_no_loss_in_compile(self):
-
-    class InternalLossModel(keras.Model):
-
-      def __init__(self):
-        super(InternalLossModel, self).__init__()
-        self.dense = keras.layers.Dense(1)
-
-      def call(self, inputs):
-        out = self.dense(inputs)
-        self.add_loss(math_ops.reduce_sum(out))
-        return out
-
-    model = InternalLossModel()
-    x = np.ones((10, 10))
-    model.predict(x)
-    model.compile(
-        optimizer='rmsprop',
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-    model.fit(x)
-    model.evaluate(x)
-
 
 class GraphSpecificModelSubclassingTests(test.TestCase):
 
@@ -1043,9 +462,8 @@
     input_dim = 50
 
     with self.cached_session():
-      model = SimpleTestModel(num_classes=num_classes,
-                              use_dp=True,
-                              use_bn=True)
+      model = model_util.SimpleTestModel(
+          num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer='rmsprop')
 
       x = array_ops.ones((num_samples, input_dim))
@@ -1061,9 +479,8 @@
     input_dim = 50
 
     with self.cached_session():
-      model = MultiIOTestModel(num_classes=num_classes,
-                               use_dp=True,
-                               use_bn=True)
+      model = model_util.MultiIOTestModel(
+          num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer='rmsprop')
 
       x1 = array_ops.ones((num_samples, input_dim))
@@ -1149,9 +566,8 @@
     input_dim = 50
 
     with self.cached_session():
-      model = MultiIOTestModel(num_classes=num_classes,
-                               use_dp=True,
-                               use_bn=True)
+      model = model_util.MultiIOTestModel(
+          num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer='rmsprop')
 
       x1 = np.ones((num_samples, input_dim))
@@ -1167,46 +583,11 @@
       _ = model.evaluate([x1, x2], [y1, y2], verbose=0)
 
 
-class CustomCallModel(keras.Model):
-
-  def __init__(self):
-    super(CustomCallModel, self).__init__()
-    self.dense1 = keras.layers.Dense(1, activation='relu')
-    self.dense2 = keras.layers.Dense(1, activation='softmax')
-
-  def call(self, first, second, fiddle_with_output='no', training=True):
-    combined = self.dense1(first) + self.dense2(second)
-    if fiddle_with_output == 'yes':
-      return 10. * combined
-    else:
-      return combined
-
-
-class TrainingNoDefaultModel(keras.Model):
-
-  def __init__(self):
-    super(TrainingNoDefaultModel, self).__init__()
-    self.dense1 = keras.layers.Dense(1)
-
-  def call(self, x, training):
-    return self.dense1(x)
-
-
-class TrainingMaskingModel(keras.Model):
-
-  def __init__(self):
-    super(TrainingMaskingModel, self).__init__()
-    self.dense1 = keras.layers.Dense(1)
-
-  def call(self, x, training=False, mask=None):
-    return self.dense1(x)
-
-
 @test_util.run_all_in_graph_and_eager_modes
 class CustomCallSignatureTests(test.TestCase):
 
   def test_no_inputs_in_signature(self):
-    model = CustomCallModel()
+    model = model_util.CustomCallModel()
     first = array_ops.ones([2, 3])
     second = array_ops.ones([2, 5])
     output = model(first, second)
@@ -1221,7 +602,7 @@
   def test_training_args_call_build(self):
     input_dim = 2
 
-    model = TrainingNoDefaultModel()
+    model = model_util.TrainingNoDefaultModel()
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -1233,7 +614,7 @@
   def test_training_and_mask_args_call_build(self):
     input_dim = 2
 
-    model = TrainingMaskingModel()
+    model = model_util.TrainingMaskingModel()
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -1246,7 +627,7 @@
     first_input_shape = (2, 3)
     second_input_shape = (2, 5)
 
-    model = CustomCallModel()
+    model = model_util.CustomCallModel()
     self.assertFalse(model.built, 'Model should not have been built')
     self.assertFalse(model.weights, ('Model should have no weights since it '
                                      'has not been built.'))
@@ -1270,14 +651,11 @@
   @test_util.assert_no_new_tensors
   @test_util.assert_no_garbage_created
   def test_training_no_default(self):
-    if context.executing_eagerly():
-      self.skipTest('b/120997007')
-
-    model = TrainingNoDefaultModel()
-
+    if not context.executing_eagerly():
+      self.skipTest('b/138307499')
+    model = model_util.TrainingNoDefaultModel()
     arg = array_ops.ones([1, 1])
     model(arg, True)
-    self.assertEqual(len(model.inputs), 1)
 
   def test_positional_arg_in_call(self):
 
diff --git a/tensorflow/python/keras/model_subclassing_test_util.py b/tensorflow/python/keras/model_subclassing_test_util.py
new file mode 100644
index 0000000..0f07c71
--- /dev/null
+++ b/tensorflow/python/keras/model_subclassing_test_util.py
@@ -0,0 +1,200 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras models for use in Model subclassing tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+
+
+# pylint: disable=missing-docstring,not-callable
+class SimpleTestModel(keras.Model):
+
+  def __init__(self, use_bn=False, use_dp=False, num_classes=10):
+    super(SimpleTestModel, self).__init__(name='test_model')
+    self.use_bn = use_bn
+    self.use_dp = use_dp
+    self.num_classes = num_classes
+
+    self.dense1 = keras.layers.Dense(32, activation='relu')
+    self.dense2 = keras.layers.Dense(num_classes, activation='softmax')
+    if self.use_dp:
+      self.dp = keras.layers.Dropout(0.5)
+    if self.use_bn:
+      self.bn = keras.layers.BatchNormalization(axis=-1)
+
+  def call(self, x):
+    x = self.dense1(x)
+    if self.use_dp:
+      x = self.dp(x)
+    if self.use_bn:
+      x = self.bn(x)
+    return self.dense2(x)
+
+
+class SimpleConvTestModel(keras.Model):
+
+  def __init__(self, num_classes=10):
+    super(SimpleConvTestModel, self).__init__(name='test_model')
+    self.num_classes = num_classes
+
+    self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu')
+    self.flatten = keras.layers.Flatten()
+    self.dense1 = keras.layers.Dense(num_classes, activation='softmax')
+
+  def call(self, x):
+    x = self.conv1(x)
+    x = self.flatten(x)
+    return self.dense1(x)
+
+
+class MultiIOTestModel(keras.Model):
+
+  def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)):
+    super(MultiIOTestModel, self).__init__(name='test_model')
+    self.use_bn = use_bn
+    self.use_dp = use_dp
+    self.num_classes = num_classes
+
+    self.dense1 = keras.layers.Dense(32, activation='relu')
+    self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax')
+    self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax')
+    if use_dp:
+      self.dp = keras.layers.Dropout(0.5)
+    if use_bn:
+      self.bn = keras.layers.BatchNormalization()
+
+  def call(self, inputs):
+    x1, x2 = inputs
+    x1 = self.dense1(x1)
+    x2 = self.dense1(x2)
+    if self.use_dp:
+      x1 = self.dp(x1)
+    if self.use_bn:
+      x2 = self.bn(x2)
+    return [self.dense2(x1), self.dense3(x2)]
+
+
+class NestedTestModel1(keras.Model):
+  """A model subclass nested inside a model subclass.
+  """
+
+  def __init__(self, num_classes=2):
+    super(NestedTestModel1, self).__init__(name='nested_model_1')
+    self.num_classes = num_classes
+    self.dense1 = keras.layers.Dense(32, activation='relu')
+    self.dense2 = keras.layers.Dense(num_classes, activation='relu')
+    self.bn = keras.layers.BatchNormalization()
+    self.test_net = SimpleTestModel(num_classes=4,
+                                    use_bn=True,
+                                    use_dp=True)
+
+  def call(self, inputs):
+    x = self.dense1(inputs)
+    x = self.bn(x)
+    x = self.test_net(x)
+    return self.dense2(x)
+
+
+class NestedTestModel2(keras.Model):
+  """A model subclass with a functional-API graph network inside.
+  """
+
+  def __init__(self, num_classes=2):
+    super(NestedTestModel2, self).__init__(name='nested_model_2')
+    self.num_classes = num_classes
+    self.dense1 = keras.layers.Dense(32, activation='relu')
+    self.dense2 = keras.layers.Dense(num_classes, activation='relu')
+    self.bn = self.bn = keras.layers.BatchNormalization()
+    self.test_net = self.get_functional_graph_model(32, 4)
+
+  @staticmethod
+  def get_functional_graph_model(input_dim, num_classes):
+    # A simple functional-API model (a.k.a. graph network)
+    inputs = keras.Input(shape=(input_dim,))
+    x = keras.layers.Dense(32, activation='relu')(inputs)
+    x = keras.layers.BatchNormalization()(x)
+    outputs = keras.layers.Dense(num_classes)(x)
+    return keras.Model(inputs, outputs)
+
+  def call(self, inputs):
+    x = self.dense1(inputs)
+    x = self.bn(x)
+    x = self.test_net(x)
+    return self.dense2(x)
+
+
+def get_nested_model_3(input_dim, num_classes):
+  # A functional-API model with a subclassed model inside.
+  # NOTE: this requires the inner subclass to implement `compute_output_shape`.
+
+  inputs = keras.Input(shape=(input_dim,))
+  x = keras.layers.Dense(32, activation='relu')(inputs)
+  x = keras.layers.BatchNormalization()(x)
+
+  class Inner(keras.Model):
+
+    def __init__(self):
+      super(Inner, self).__init__()
+      self.dense1 = keras.layers.Dense(32, activation='relu')
+      self.dense2 = keras.layers.Dense(5, activation='relu')
+      self.bn = keras.layers.BatchNormalization()
+
+    def call(self, inputs):
+      x = self.dense1(inputs)
+      x = self.dense2(x)
+      return self.bn(x)
+
+  test_model = Inner()
+  x = test_model(x)
+  outputs = keras.layers.Dense(num_classes)(x)
+  return keras.Model(inputs, outputs, name='nested_model_3')
+
+
+class CustomCallModel(keras.Model):
+
+  def __init__(self):
+    super(CustomCallModel, self).__init__()
+    self.dense1 = keras.layers.Dense(1, activation='relu')
+    self.dense2 = keras.layers.Dense(1, activation='softmax')
+
+  def call(self, first, second, fiddle_with_output='no', training=True):
+    combined = self.dense1(first) + self.dense2(second)
+    if fiddle_with_output == 'yes':
+      return 10. * combined
+    else:
+      return combined
+
+
+class TrainingNoDefaultModel(keras.Model):
+
+  def __init__(self):
+    super(TrainingNoDefaultModel, self).__init__()
+    self.dense1 = keras.layers.Dense(1)
+
+  def call(self, x, training):
+    return self.dense1(x)
+
+
+class TrainingMaskingModel(keras.Model):
+
+  def __init__(self):
+    super(TrainingMaskingModel, self).__init__()
+    self.dense1 = keras.layers.Dense(1)
+
+  def call(self, x, training=False, mask=None):
+    return self.dense1(x)
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 6ce20a7..fd6b083 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -65,10 +65,7 @@
   ancillary_layers = [
       layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
   ] + metric_layers
-  nodes = set(
-      nest.flatten([layer._inbound_nodes for layer in ancillary_layers]))
-  relevant_nodes = list(nodes.intersection(new_nodes))
-  model._insert_layers(ancillary_layers, relevant_nodes=relevant_nodes)
+  model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))
 
 
 def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
@@ -174,11 +171,7 @@
     # Create placeholders to build the model on top of.
     input_tensors = []
     for layer in model._input_layers:
-      input_tensor = Input(
-          batch_shape=layer._batch_input_shape,
-          dtype=layer.dtype,
-          sparse=layer.sparse,
-          name=layer.name)
+      input_tensor = Input(**layer.get_config())
       input_tensors.append(input_tensor)
       # Cache newly created input layer.
       newly_created_input_layer = input_tensor._keras_history.layer
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 0cd79cf..6ee565c 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -177,7 +177,7 @@
         testing_utils.get_v2_optimizer('rmsprop'),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     new_model.train_on_batch([val_a, val_b], val_out)
 
     # On top of new tensors
@@ -190,7 +190,7 @@
         testing_utils.get_v2_optimizer('rmsprop'),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     new_model.train_on_batch([val_a, val_b], val_out)
 
     # On top of new, non-Keras tensors
@@ -205,7 +205,7 @@
           testing_utils.get_v2_optimizer('rmsprop'),
           'mse',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       new_model.train_on_batch(None, val_out)
 
   @keras_parameterized.run_all_keras_modes
@@ -232,7 +232,7 @@
         loss='mse',
         optimizer=testing_utils.get_v2_optimizer('adam'),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     y = np.array([[[1], [1]], [[1], [1]]])
     loss = model.train_on_batch(x, y)
     self.assertEqual(float(loss), 0.)
@@ -297,7 +297,7 @@
         optimizer=opt,
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     model.fit(
         x=np.array([[1., 2., 3., 4.]]),
@@ -327,7 +327,7 @@
         testing_utils.get_v2_optimizer('rmsprop'),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     keras.backend.set_floatx(floatx)
 
@@ -357,7 +357,7 @@
         testing_utils.get_v2_optimizer('rmsprop'),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     new_model.train_on_batch(inp, out)
 
     # Create new tensors for inputs and targets
@@ -374,7 +374,7 @@
         testing_utils.get_v2_optimizer('rmsprop'),
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     new_model.train_on_batch(inp, out)
 
   def _assert_same_compile_params(self, model):
@@ -428,7 +428,7 @@
         'mse',
         metrics=['acc', metrics.categorical_accuracy],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     self._clone_and_build_test_helper(model, testing_utils.get_model_type())
 
@@ -440,7 +440,7 @@
         'mse',
         metrics=['acc', metrics.categorical_accuracy],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self._clone_and_build_test_helper(model, 'sequential')
 
     inp = np.random.random((10, 4))
@@ -455,7 +455,7 @@
         'mse',
         metrics=['acc', metrics.categorical_accuracy],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     global_step = keras.backend.variable(123, dtype=dtypes.int64)
     clone_model = models.clone_and_build_model(
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
index 7c9a17f..b053d51 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend_config
@@ -27,7 +28,9 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -151,6 +154,15 @@
                     or self._fallback_apply_state(var_device, var_dtype))
 
     acc = self.get_slot(var, 'accumulator')
+    if compat.forward_compatible(2019, 8, 20):
+      return training_ops.resource_apply_adagrad_v2(
+          var.handle,
+          acc.handle,
+          coefficients['lr_t'],
+          coefficients['epsilon'],
+          grad,
+          use_locking=self._use_locking)
+
     acc_t = state_ops.assign_add(
         acc, math_ops.square(grad), use_locking=self._use_locking)
     var_update = state_ops.assign_sub(
@@ -164,10 +176,22 @@
                     or self._fallback_apply_state(var_device, var_dtype))
 
     acc = self.get_slot(var, 'accumulator')
-    acc_t = self._resource_scatter_add(acc, indices, math_ops.square(grad))
-    acc_t_slice = array_ops.gather(acc_t, indices, axis=coefficients['zero'])
-    var_update = self._resource_scatter_add(
-        var, indices, coefficients['neg_lr_t'] * grad /
+    if compat.forward_compatible(2019, 8, 20):
+      return training_ops.resource_sparse_apply_adagrad_v2(
+          var.handle,
+          acc.handle,
+          coefficients['lr_t'],
+          coefficients['epsilon'],
+          grad,
+          indices,
+          use_locking=self._use_locking)
+    with ops.control_dependencies([
+        resource_variable_ops.resource_scatter_add(acc.handle, indices,
+                                                   math_ops.square(grad))
+    ]):
+      acc_t_slice = acc.sparse_read(indices)
+    var_update = resource_variable_ops.resource_scatter_add(
+        var.handle, indices, coefficients['neg_lr_t'] * grad /
         (math_ops.sqrt(acc_t_slice) + coefficients['epsilon']))
     return var_update
 
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
index 3ddf985..d3a2ac8 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -161,6 +161,47 @@
           self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
           self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
+  def testBasicWithLargeEpsilon(self):
+    with self.cached_session():
+      var0_np = np.array([1.0, 2.0])
+      var1_np = np.array([3.0, 4.0])
+      grads0_np = np.array([0.1, 0.1])
+      grads1_np = np.array([0.01, 0.01])
+      var0 = resource_variable_ops.ResourceVariable(var0_np)
+      var1 = resource_variable_ops.ResourceVariable(var1_np)
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
+
+      learning_rate = 3.0
+
+      ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.0)
+
+      accum0_np = np.array([0.1, 0.1])
+      accum1_np = np.array([0.1, 0.1])
+
+      if not context.executing_eagerly():
+        ada_update = ada_opt.apply_gradients(
+            zip([grads0, grads1], [var0, var1]))
+        self.evaluate(variables.global_variables_initializer())
+
+      # Fetch params to validate initial values
+      v0_val, v1_val = self.evaluate([var0, var1])
+      self.assertAllClose([1.0, 2.0], v0_val)
+      self.assertAllClose([3.0, 4.0], v1_val)
+
+      # Run 3 steps of adagrad
+      for _ in range(3):
+        if not context.executing_eagerly():
+          self.evaluate(ada_update)
+        else:
+          ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np,
+                                                  3.0, 1.0)
+        var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np,
+                                                  3.0, 1.0)
+        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
   def testBasicWithLearningRateInverseTimeDecay(self):
     for dtype in [dtypes.float32, dtypes.float64]:
       with self.cached_session():
@@ -309,6 +350,41 @@
           self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
   @test_util.run_deprecated_v1
+  def testSparseSingleVarDim(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        var0_np = np.array([1.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+
+        var0 = resource_variable_ops.ResourceVariable(var0_np)
+        grads0_np_indices = np.array([0], dtype=np.int32)
+        grads0 = ops.IndexedSlices(
+            constant_op.constant(grads0_np[grads0_np_indices]),
+            constant_op.constant(grads0_np_indices), constant_op.constant([3]))
+        learning_rate = 3.0
+        ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.)
+        ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0], var0.eval())
+
+        accum0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+
+        # Run 3 step of sgd
+        for _ in range(3):
+          ada_update.run()
+
+          var0_np, accum0_np = sparse_adagrad_update_numpy(
+              var0_np,
+              accum0_np,
+              grads0_np_indices,
+              grads0_np[grads0_np_indices],
+              learning_rate,
+              epsilon=1.)
+          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+
+  @test_util.run_deprecated_v1
   def testSparseRepeatedIndices(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
       with self.cached_session():
diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
index 00b7c44..bd81dfa 100644
--- a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
+++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
@@ -452,7 +452,7 @@
     decay_steps = 1.0
     decay_rate = 0.5
     learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
-      initial_learning_rate, global_step, decay_steps, decay_rate)
+      initial_learning_rate, decay_steps, decay_rate)
 
     model.compile(optimizer=tf.keras.optimizers.SGD(
                       learning_rate=learning_rate_fn),
@@ -549,7 +549,7 @@
     ```python
     decay_steps = 1000
     lr_decayed_fn = tf.keras.experimental.CosineDecay(
-        initial_learning_rate, global_step, decay_steps)
+        initial_learning_rate, decay_steps)
     ```
 
     You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
@@ -640,7 +640,6 @@
     lr_decayed_fn = (
       tf.keras.experimental.CosineDecayRestarts(
           initial_learning_rate,
-          global_step,
           first_decay_steps))
     ```
 
@@ -665,8 +664,6 @@
       A 1-arg callable learning rate schedule that takes the current optimizer
       step and outputs the decayed learning rate, a scalar `Tensor` of the same
       type as `initial_learning_rate`.
-    Raises:
-      ValueError: if `global_step` is not supplied.
     """
     super(CosineDecayRestarts, self).__init__()
 
@@ -779,7 +776,7 @@
     decay_steps = 1000
     lr_decayed_fn = (
       tf.keras.experimental.LinearCosineDecay(
-        initial_learning_rate, global_step, decay_steps))
+        initial_learning_rate, decay_steps))
     ```
 
     You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
@@ -899,7 +896,7 @@
     decay_steps = 1000
     lr_decayed_fn = (
       tf.keras.experimental.NoisyLinearCosineDecay(
-        initial_learning_rate, global_step, decay_steps))
+        initial_learning_rate, decay_steps))
     ```
 
     You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index f053d85..1e3c82f 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -138,15 +138,13 @@
     loss = <call_loss_function>
   vars = <list_of_variables>
   grads = tape.gradient(loss, vars)
+
+  # Process the gradients, for example cap them, etc.
+  # capped_grads = [MyCapper(g) for g in grads]
   processed_grads = [process_gradient(g) for g in grads]
-  grads_and_vars = zip(processed_grads, var_list)
 
-  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
-  # need to the 'gradient' part, for example cap them, etc.
-  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
-
-  # Ask the optimizer to apply the capped gradients.
-  opt.apply_gradients(capped_grads_and_vars)
+  # Ask the optimizer to apply the processed gradients.
+  opt.apply_gradients(zip(processed_grads, var_list))
   ```
 
   ### Use with `tf.distribute.Strategy`.
@@ -1024,7 +1022,7 @@
                      ([v.name for _, v in grads_and_vars],))
   if vars_with_empty_grads:
     logging.warning(
-        ("Gradients does not exist for variables %s when minimizing the loss."),
+        ("Gradients do not exist for variables %s when minimizing the loss."),
         ([v.name for v in vars_with_empty_grads]))
   return filtered
 
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index 04816a8..a0b9702 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -612,12 +612,13 @@
 @keras_parameterized.run_all_keras_modes
 class OptimizersCompatibilityTest(keras_parameterized.TestCase):
 
-  # After run_distributed is turned on, optimizer v1 can no longer work in
-  # eager mode, skipping the test if so.
+  # After experimental_run_tf_function is turned on, optimizer v1 can no longer
+  # work in eager mode, skipping the test if so.
   def _testOptimizersCompatibility(self, opt_v1, opt_v2, test_weights=True):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     np.random.seed(1331)
     with self.cached_session():
       train_samples = 20
@@ -638,7 +639,7 @@
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model_v1.fit(x, y, batch_size=5, epochs=1)
 
       model_v2 = testing_utils.get_small_sequential_mlp(
@@ -649,7 +650,7 @@
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model_v2._make_train_function()
       if test_weights:
         opt_v2.set_weights(opt_v1.get_weights())
@@ -702,9 +703,10 @@
     self._testOptimizersCompatibility(opt_v1, opt_v2, False)
 
   def testNumericEquivalenceForNesterovMomentum(self):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     np.random.seed(1331)
     with self.cached_session():
       train_samples = 20
@@ -737,19 +739,19 @@
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model_k_v2.compile(
           opt_k_v2,
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model_tf.compile(
           opt_tf,
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
       hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)
@@ -762,9 +764,10 @@
       self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])
 
   def testNumericEquivalenceForAmsgrad(self):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     np.random.seed(1331)
     with self.cached_session():
       train_samples = 20
@@ -792,13 +795,13 @@
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       model_k_v2.compile(
           opt_k_v2,
           loss='categorical_crossentropy',
           metrics=[],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
 
       hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
       hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 9eb2c05..8885cea 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -44,12 +44,13 @@
 @keras_parameterized.run_all_keras_modes
 class KerasOptimizersTest(keras_parameterized.TestCase):
 
-  # After run_distributed is turned on, optimizer v1 can no longer work in
-  # eager mode, skipping the test if so.
+  # After experimental_run_tf_function is turned on, optimizer v1 can no longer
+  # work in eager mode, skipping the test if so.
   def _test_optimizer(self, optimizer, target=0.75):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     np.random.seed(1337)
     (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=1000, test_samples=200, input_shape=(10,), num_classes=2)
@@ -60,7 +61,7 @@
         optimizer=optimizer,
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     np.testing.assert_equal(
         keras.backend.get_value(model.optimizer.iterations), 0)
     history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)
@@ -98,7 +99,7 @@
         optimizer=optimizer,
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     np.testing.assert_equal(
         keras.backend.get_value(model.optimizer.iterations),
         126)  # Using same optimizer from before
@@ -164,18 +165,20 @@
           keras.optimizers.SGD(lr=0.01, momentum=0.9, clipvalue=0.5))
 
   def test_tf_optimizer(self):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
     model = keras.models.Sequential()
     model.add(keras.layers.Dense(
         2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
     # This is possible
-    model.compile(loss='mean_squared_error',
-                  optimizer=optimizer,
-                  run_eagerly=testing_utils.should_run_eagerly(),
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        loss='mean_squared_error',
+        optimizer=optimizer,
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     keras.backend.track_tf_optimizer(optimizer)
     model.fit(np.random.random((5, 3)),
               np.random.random((5, 2)),
@@ -191,9 +194,10 @@
       optimizer.from_config(None)
 
   def test_optimizer_garbage_collection(self):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     graph = ops.Graph()
     with graph.as_default():
       optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
@@ -207,9 +211,10 @@
     self.assertIs(optimizer_weak(), None)
 
   def test_tf_optimizer_iterations(self):
-    if testing_utils.should_run_distributed() or context.executing_eagerly():
-      self.skipTest('v1 optimizer does not run in run_distributed mode or '
-                    'eager mode')
+    if testing_utils.should_run_tf_function() or context.executing_eagerly():
+      self.skipTest(
+          'v1 optimizer does not run in experimental_run_tf_function mode or '
+          'eager mode')
     with self.cached_session():
       optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
       model = keras.models.Sequential()
@@ -219,7 +224,7 @@
           loss='mean_squared_error',
           optimizer=optimizer,
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       keras.backend.track_tf_optimizer(optimizer)
       self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 0)
 
diff --git a/tensorflow/python/keras/premade/BUILD b/tensorflow/python/keras/premade/BUILD
index 2350bb4..af8e86b 100644
--- a/tensorflow/python/keras/premade/BUILD
+++ b/tensorflow/python/keras/premade/BUILD
@@ -13,7 +13,9 @@
 py_library(
     name = "premade",
     srcs = [
+        "__init__.py",
         "linear.py",
+        "wide_deep.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -38,3 +40,18 @@
         "//third_party/py/numpy",
     ],
 )
+
+py_test(
+    name = "wide_deep_test",
+    size = "medium",
+    srcs = ["wide_deep_test.py"],
+    python_version = "PY2",
+    shard_count = 2,
+    srcs_version = "PY2AND3",
+    deps = [
+        ":premade",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/python/keras/premade/__init__.py b/tensorflow/python/keras/premade/__init__.py
new file mode 100644
index 0000000..507f7a6
--- /dev/null
+++ b/tensorflow/python/keras/premade/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Premade Model API."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.premade import linear
+from tensorflow.python.keras.premade import wide_deep
diff --git a/tensorflow/python/keras/premade/linear.py b/tensorflow/python/keras/premade/linear.py
index fc9599f..a7e6e09 100644
--- a/tensorflow/python/keras/premade/linear.py
+++ b/tensorflow/python/keras/premade/linear.py
@@ -18,13 +18,16 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import core
 from tensorflow.python.ops import nn
+from tensorflow.python.util.tf_export import keras_export
 
 
+@keras_export('keras.experimental.LinearModel')
 class LinearModel(training.Model):
   r"""Linear Model for regression and classification problems.
 
@@ -58,6 +61,7 @@
 
   def __init__(self,
                units=1,
+               activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
                bias_initializer='zeros',
@@ -68,6 +72,8 @@
 
     Args:
       units: Positive integer, output dimension without the batch size.
+      activation: Activation function to use.
+        If you don't specify anything, no activation is applied.
       use_bias: whether to calculate the bias/intercept for this model. If set
         to False, no bias/intercept will be used in calculations, e.g., the data
         is already centered.
@@ -79,6 +85,7 @@
     """
 
     self.units = units
+    self.activation = activations.get(activation)
     self.use_bias = use_bias
     self.kernel_initializer = initializers.get(kernel_initializer)
     self.bias_initializer = initializers.get(bias_initializer)
@@ -133,4 +140,6 @@
 
     if self.use_bias:
       result = nn.bias_add(result, self.bias)
+    if self.activation is not None:
+      return self.activation(result)  # pylint: disable=not-callable
     return result
diff --git a/tensorflow/python/keras/premade/linear_test.py b/tensorflow/python/keras/premade/linear_test.py
index 842e89e..49d7cc3 100644
--- a/tensorflow/python/keras/premade/linear_test.py
+++ b/tensorflow/python/keras/premade/linear_test.py
@@ -22,13 +22,16 @@
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.feature_column import dense_features_v2
+from tensorflow.python.feature_column import feature_column_v2 as fc
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend
+from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import losses
 from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
@@ -37,8 +40,8 @@
 from tensorflow.python.platform import test
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class LinearModelTest(test.TestCase):
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class LinearModelTest(keras_parameterized.TestCase):
 
   def test_linear_model_with_single_input(self):
     model = linear.LinearModel()
@@ -69,11 +72,6 @@
     model.compile('sgd', 'mse', [])
     model.fit([input_a_np, input_b_np], output_np, epochs=5)
 
-  def test_linear_model_with_int_input(self):
-    inp = input_layer.Input(shape=(1,), dtype=dtypes.int32)
-    with self.assertRaisesRegexp(TypeError, 'Unable to build'):
-      linear.LinearModel()(inp)
-
   def test_linear_model_with_sparse_input(self):
     indices = constant_op.constant([[0, 0], [0, 2], [1, 0], [1, 1]],
                                    dtype=dtypes.int64)
@@ -131,6 +129,33 @@
         grads_and_vars = zip(grads, model.trainable_variables)
         opt.apply_gradients(grads_and_vars)
 
+  # This test is an example for a regression on categorical inputs, i.e.,
+  # the output is 0.4, 0.6, 0.9 when input is 'alpha', 'beta', 'gamma'
+  # separately.
+  def test_linear_model_with_feature_column(self):
+    with context.eager_mode():
+      vocab_list = ['alpha', 'beta', 'gamma']
+      vocab_val = [0.4, 0.6, 0.9]
+      data = np.random.choice(vocab_list, size=256)
+      y = np.zeros_like(data, dtype=np.float32)
+      for vocab, val in zip(vocab_list, vocab_val):
+        indices = np.where(data == vocab)
+        y[indices] = val + np.random.uniform(
+            low=-0.01, high=0.01, size=indices[0].shape)
+      cat_column = fc.categorical_column_with_vocabulary_list(
+          key='symbol', vocabulary_list=vocab_list)
+      ind_column = fc.indicator_column(cat_column)
+      dense_feature_layer = dense_features_v2.DenseFeatures([ind_column])
+      linear_model = linear.LinearModel(
+          use_bias=False, kernel_initializer='zeros')
+      combined = sequential.Sequential([dense_feature_layer, linear_model])
+      opt = gradient_descent.SGD(learning_rate=0.1)
+      combined.compile(opt, 'mse', [])
+      combined.fit(x={'symbol': data}, y=y, batch_size=32, epochs=10)
+      self.assertAllClose([[0.4], [0.6], [0.9]],
+                          combined.layers[1].dense_layers[0].kernel.numpy(),
+                          atol=0.01)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/premade/wide_deep.py b/tensorflow/python/keras/premade/wide_deep.py
new file mode 100644
index 0000000..ff5dd5e
--- /dev/null
+++ b/tensorflow/python/keras/premade/wide_deep.py
@@ -0,0 +1,164 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Built-in WideNDeep model classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import training
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export('keras.experimental.WideDeepModel')
+class WideDeepModel(training.Model):
+  r"""Wide & Deep Model for regression and classification problems.
+
+  This model jointly train a linear and a dnn model.
+
+  Example:
+
+  ```python
+  linear_model = LinearModel()
+  dnn_model = keras.Sequential([keras.layers.Dense(units=64),
+                               keras.layers.Dense(units=1)])
+  combined_model = WideDeepModel(dnn_model, linear_model)
+  combined_model.compile(optimizer=['sgd', 'adam'], 'mse', ['mse'])
+  # define dnn_inputs and linear_inputs as separate numpy arrays or
+  # a single numpy array if dnn_inputs is same as linear_inputs.
+  combined_model.fit([dnn_inputs, linear_inputs], y, epochs)
+  # or define a single `tf.data.Dataset` that contains a single tensor or
+  # separate tensors for dnn_inputs and linear_inputs.
+  dataset = tf.data.Dataset.from_tensors(([dnn_inputs, linear_inputs], y))
+  combined_model.fit(dataset, epochs)
+  ```
+
+  Both linear and dnn model can be pre-compiled and trained separately
+  before jointly training:
+
+  Example:
+  ```python
+  linear_model = LinearModel()
+  linear_model.compile('adagrad', 'mse')
+  linear_model.fit(linear_inputs, y, epochs)
+  dnn_model = keras.Sequential([keras.layers.Dense(units=1)])
+  dnn_model.compile('rmsprop', 'mse')
+  dnn_model.fit(dnn_inputs, y, epochs)
+  combined_model = WideDeepModel(dnn_model, linear_model)
+  combined_model.compile(optimizer=['sgd', 'adam'], 'mse', ['mse'])
+  combined_model.fit([dnn_inputs, linear_inputs], y, epochs)
+  ```
+
+  """
+
+  def __init__(self, linear_model, dnn_model, activation=None, **kwargs):
+    """Create a Wide & Deep Model.
+
+    Args:
+      linear_model: a premade LinearModel, its output must match the output of
+        the dnn model.
+      dnn_model: a `tf.keras.Model`, its output must match the output of the
+        linear model.
+      activation: Activation function. Set it to None to maintain a linear
+        activation.
+      **kwargs: The keyword arguments that are passed on to BaseLayer.__init__.
+        Allowed keyword arguments include `name`.
+    """
+    super(WideDeepModel, self).__init__(**kwargs)
+    self.linear_model = linear_model
+    self.dnn_model = dnn_model
+    self.activation = activation
+
+  def call(self, inputs):
+    if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
+      linear_inputs = dnn_inputs = inputs
+    else:
+      linear_inputs, dnn_inputs = inputs
+    linear_output = self.linear_model(linear_inputs)
+    dnn_output = self.dnn_model(dnn_inputs)
+    output = .5 * (linear_output + dnn_output)
+    if self.activation:
+      return self.activation(output)
+    return output
+
+  def _get_optimizers(self):
+    if isinstance(self.optimizer, (tuple, list)):
+      return (self.optimizer[0], self.optimizer[1])
+    else:
+      return (self.optimizer, self.optimizer)
+
+  # This does not support gradient scaling and LossScaleOptimizer.
+  def _backwards(self, tape, loss):
+    linear_vars = self.linear_model._unique_trainable_weights  # pylint: disable=protected-access
+    dnn_vars = self.dnn_model._unique_trainable_weights  # pylint: disable=protected-access
+    linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))
+    linear_optimizer, dnn_optimizer = self._get_optimizers()
+    linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
+    dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars))
+    return
+
+  def _make_train_function(self):
+    # TODO(tanzheny): This is a direct copy from super to make it work
+    # refactor it so that common logic can be shared.
+    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
+    self._check_trainable_weights_consistency()
+    # If we have re-compiled the loss/weighted metric sub-graphs then create
+    # train function even if one exists already. This is because
+    # `_feed_sample_weights` list has been updated on re-copmpile.
+    if getattr(self, 'train_function', None) is None or has_recompiled:
+      # Restore the compiled trainable state.
+      current_trainable_state = self._get_trainable_state()
+      self._set_trainable_state(self._compiled_trainable_state)
+
+      inputs = (
+          self._feed_inputs + self._feed_targets + self._feed_sample_weights)
+      if not isinstance(K.symbolic_learning_phase(), int):
+        inputs += [K.symbolic_learning_phase()]
+
+      linear_optimizer, dnn_optimizer = self._get_optimizers()
+      with K.get_graph().as_default():
+        with K.name_scope('training'):
+          # Training updates
+          updates = []
+          linear_updates = linear_optimizer.get_updates(
+              params=self.linear_model._unique_trainable_weights,  # pylint: disable=protected-access
+              loss=self.total_loss)
+          updates += linear_updates
+          dnn_updates = dnn_optimizer.get_updates(
+              params=self.dnn_model._unique_trainable_weights,  # pylint: disable=protected-access
+              loss=self.total_loss)
+          updates += dnn_updates
+          # Unconditional updates
+          updates += self.get_updates_for(None)
+          # Conditional updates relevant to this model
+          updates += self.get_updates_for(self.inputs)
+
+        metrics = self._get_training_eval_metrics()
+        metrics_tensors = [
+            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
+        ]
+
+      with K.name_scope('training'):
+        # Gets loss and metrics. Updates weights at each call.
+        fn = K.function(
+            inputs, [self.total_loss] + metrics_tensors,
+            updates=updates,
+            name='train_function',
+            **self._function_kwargs)
+        setattr(self, 'train_function', fn)
+
+      # Restore the current trainable state
+      self._set_trainable_state(current_trainable_state)
diff --git a/tensorflow/python/keras/premade/wide_deep_test.py b/tensorflow/python/keras/premade/wide_deep_test.py
new file mode 100644
index 0000000..c3894cb
--- /dev/null
+++ b/tensorflow/python/keras/premade/wide_deep_test.py
@@ -0,0 +1,239 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras Premade WideNDeep models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.feature_column import dense_features_v2
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.keras.premade import linear
+from tensorflow.python.keras.premade import wide_deep
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class WideDeepModelTest(keras_parameterized.TestCase):
+
+  def test_wide_deep_model(self):
+    linear_model = linear.LinearModel(units=1)
+    dnn_model = sequential.Sequential([core.Dense(units=1, input_dim=3)])
+    wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+    linear_inp = np.random.uniform(low=-5, high=5, size=(64, 2))
+    dnn_inp = np.random.uniform(low=-5, high=5, size=(64, 3))
+    inputs = [linear_inp, dnn_inp]
+    output = .3 * linear_inp[:, 0] + .2 * dnn_inp[:, 1]
+    wide_deep_model.compile(
+        optimizer=['sgd', 'adam'],
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    wide_deep_model.fit(inputs, output, epochs=5)
+    self.assertTrue(wide_deep_model.built)
+
+  def test_wide_deep_model_backprop(self):
+    with self.cached_session():
+      linear_model = linear.LinearModel(units=1, kernel_initializer='zeros')
+      dnn_model = sequential.Sequential(
+          [core.Dense(units=1, kernel_initializer='zeros')])
+      wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+      linear_inp = np.array([1.])
+      dnn_inp = np.array([1.])
+      inputs = [linear_inp, dnn_inp]
+      output = linear_inp + 2 * dnn_inp
+      linear_opt = gradient_descent.SGD(learning_rate=.1)
+      dnn_opt = gradient_descent.SGD(learning_rate=.3)
+      wide_deep_model.compile(
+          optimizer=[linear_opt, dnn_opt],
+          loss='mse',
+          metrics=[],
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
+      self.evaluate(variables.global_variables_initializer())
+      wide_deep_model.fit(inputs, output, epochs=1)
+      self.assertAllClose(
+          [[0.3]],
+          self.evaluate(wide_deep_model.linear_model.dense_layers[0].kernel))
+      self.assertAllClose([[0.9]],
+                          self.evaluate(
+                              wide_deep_model.dnn_model.layers[0].kernel))
+
+  def test_wide_deep_model_with_single_input(self):
+    linear_model = linear.LinearModel(units=1)
+    dnn_model = sequential.Sequential([core.Dense(units=1, input_dim=3)])
+    wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+    inputs = np.random.uniform(low=-5, high=5, size=(64, 3))
+    output = .3 * inputs[:, 0]
+    wide_deep_model.compile(
+        optimizer=['sgd', 'adam'],
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    wide_deep_model.fit(inputs, output, epochs=5)
+
+  def test_wide_deep_model_with_single_optimizer(self):
+    linear_model = linear.LinearModel(units=1)
+    dnn_model = sequential.Sequential([core.Dense(units=1, input_dim=3)])
+    wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+    linear_inp = np.random.uniform(low=-5, high=5, size=(64, 2))
+    dnn_inp = np.random.uniform(low=-5, high=5, size=(64, 3))
+    inputs = [linear_inp, dnn_inp]
+    output = .3 * linear_inp[:, 0] + .2 * dnn_inp[:, 1]
+    wide_deep_model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    wide_deep_model.fit(inputs, output, epochs=5)
+    self.assertTrue(wide_deep_model.built)
+
+  def test_wide_deep_model_as_layer(self):
+    linear_model = linear.LinearModel(units=1)
+    dnn_model = sequential.Sequential([core.Dense(units=1)])
+    linear_input = input_layer.Input(shape=(3,), name='linear')
+    dnn_input = input_layer.Input(shape=(5,), name='dnn')
+    wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+    wide_deep_output = wide_deep_model((linear_input, dnn_input))
+    input_b = input_layer.Input(shape=(1,), name='b')
+    output_b = core.Dense(units=1)(input_b)
+    model = training.Model(
+        inputs=[linear_input, dnn_input, input_b],
+        outputs=[wide_deep_output + output_b])
+    linear_input_np = np.random.uniform(low=-5, high=5, size=(64, 3))
+    dnn_input_np = np.random.uniform(low=-5, high=5, size=(64, 5))
+    input_b_np = np.random.uniform(low=-5, high=5, size=(64,))
+    output_np = linear_input_np[:, 0] + .2 * dnn_input_np[:, 1] + input_b_np
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    model.fit([linear_input_np, dnn_input_np, input_b_np], output_np, epochs=5)
+
+  def test_wide_deep_model_with_sub_model_trained(self):
+    linear_model = linear.LinearModel(units=1)
+    dnn_model = sequential.Sequential([core.Dense(units=1, input_dim=3)])
+    wide_deep_model = wide_deep.WideDeepModel(
+        linear.LinearModel(units=1),
+        sequential.Sequential([core.Dense(units=1, input_dim=3)]))
+    linear_inp = np.random.uniform(low=-5, high=5, size=(64, 2))
+    dnn_inp = np.random.uniform(low=-5, high=5, size=(64, 3))
+    inputs = [linear_inp, dnn_inp]
+    output = .3 * linear_inp[:, 0] + .2 * dnn_inp[:, 1]
+    linear_model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    dnn_model.compile(
+        optimizer='adam',
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    linear_model.fit(linear_inp, output, epochs=50)
+    dnn_model.fit(dnn_inp, output, epochs=50)
+    wide_deep_model.compile(
+        optimizer=['sgd', 'adam'],
+        loss='mse',
+        metrics=[],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    wide_deep_model.fit(inputs, output, epochs=50)
+
+  # This test is an example for cases where linear and dnn model accepts
+  # same raw input and same transformed inputs, i.e., the raw input is
+  # categorical, and both linear and dnn model accept one hot encoding.
+  def test_wide_deep_model_with_single_feature_column(self):
+    vocab_list = ['alpha', 'beta', 'gamma']
+    vocab_val = [0.4, 0.6, 0.9]
+    data = np.random.choice(vocab_list, size=256)
+    y = np.zeros_like(data, dtype=np.float32)
+    for vocab, val in zip(vocab_list, vocab_val):
+      indices = np.where(data == vocab)
+      y[indices] = val + np.random.uniform(
+          low=-0.01, high=0.01, size=indices[0].shape)
+    cat_column = fc.categorical_column_with_vocabulary_list(
+        key='symbol', vocabulary_list=vocab_list)
+    ind_column = fc.indicator_column(cat_column)
+    dense_feature_layer = dense_features_v2.DenseFeatures([ind_column])
+    linear_model = linear.LinearModel(
+        use_bias=False, kernel_initializer='zeros')
+    dnn_model = sequential.Sequential([core.Dense(units=1)])
+    wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
+    combined = sequential.Sequential([dense_feature_layer, wide_deep_model])
+    opt = gradient_descent.SGD(learning_rate=0.1)
+    combined.compile(
+        opt,
+        'mse', [],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    combined.fit(x={'symbol': data}, y=y, batch_size=32, epochs=10)
+
+  # This test is an example for cases where linear and dnn model accepts
+  # same raw input but different transformed inputs, i.e,. the raw input is
+  # categorical, and linear model accepts one hot encoding, while dnn model
+  # accepts embedding encoding.
+  def test_wide_deep_model_with_two_feature_columns(self):
+    vocab_list = ['alpha', 'beta', 'gamma']
+    vocab_val = [0.4, 0.6, 0.9]
+    data = np.random.choice(vocab_list, size=256)
+    y = np.zeros_like(data, dtype=np.float32)
+    for vocab, val in zip(vocab_list, vocab_val):
+      indices = np.where(data == vocab)
+      y[indices] = val + np.random.uniform(
+          low=-0.01, high=0.01, size=indices[0].shape)
+    cat_column = fc.categorical_column_with_vocabulary_list(
+        key='symbol', vocabulary_list=vocab_list)
+    ind_column = fc.indicator_column(cat_column)
+    emb_column = fc.embedding_column(cat_column, dimension=5)
+    linear_feature_layer = dense_features_v2.DenseFeatures([ind_column])
+    linear_model = linear.LinearModel(
+        use_bias=False, kernel_initializer='zeros')
+    combined_linear = sequential.Sequential(
+        [linear_feature_layer, linear_model])
+    dnn_model = sequential.Sequential([core.Dense(units=1)])
+    dnn_feature_layer = dense_features_v2.DenseFeatures([emb_column])
+    combined_dnn = sequential.Sequential([dnn_feature_layer, dnn_model])
+    wide_deep_model = wide_deep.WideDeepModel(combined_linear, combined_dnn)
+    opt = gradient_descent.SGD(learning_rate=0.1)
+    wide_deep_model.compile(
+        opt,
+        'mse', [],
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    wide_deep_model.fit(x={'symbol': data}, y=y, batch_size=32, epochs=10)
+    self.assertEqual(3, linear_model.inputs[0].shape[1])
+    self.assertEqual(5, dnn_model.inputs[0].shape[1])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py
index 1b33e9d..cabefa2 100644
--- a/tensorflow/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/regularizers_test.py
@@ -79,7 +79,7 @@
         loss='categorical_crossentropy',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.losses), 1)
     model.fit(x_train, y_train, batch_size=10, epochs=1, verbose=0)
 
@@ -97,7 +97,7 @@
         loss='categorical_crossentropy',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.losses), 1 if context.executing_eagerly() else 1)
     model.fit(x_train, y_train, batch_size=10, epochs=1, verbose=0)
 
@@ -113,7 +113,7 @@
         'sgd',
         'mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x, y, batch_size=5, epochs=1)
 
   def test_custom_regularizer_saving(self):
@@ -144,7 +144,7 @@
         loss='categorical_crossentropy',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.losses), 5)
 
   @keras_parameterized.run_all_keras_modes
@@ -167,7 +167,7 @@
         loss='categorical_crossentropy',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.losses), 6)
 
   @keras_parameterized.run_all_keras_modes
@@ -195,7 +195,7 @@
         loss='categorical_crossentropy',
         optimizer='sgd',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     self.assertEqual(len(model.losses), 14)
 
 
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index 7391f98..a7450c9 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -48,7 +48,8 @@
                filepath,
                overwrite=True,
                include_optimizer=True,
-               save_format=None):
+               save_format=None,
+               signatures=None):
   """Saves a model as a TensorFlow SavedModel or HDF5 file.
 
   The saved model contains:
@@ -79,6 +80,9 @@
       save_format: Either 'tf' or 'h5', indicating whether to save the model
         to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
         in TF 1.X.
+      signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
+        format only. Please see the `signatures` argument in
+        `tf.saved_model.save` for details.
 
   Raises:
       ImportError: If save format is hdf5, and h5py is not available.
@@ -104,7 +108,8 @@
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    saved_model_save.save(model, filepath, overwrite, include_optimizer)
+    saved_model_save.save(model, filepath, overwrite, include_optimizer,
+                          signatures)
 
 
 @keras_export('keras.models.load_model')
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 094a53d..6b171f3 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -24,7 +24,7 @@
 
 from tensorflow.python import keras
 from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.saving import model_config
@@ -83,16 +83,18 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_saving_with_dense_features(self):
-    cols = [feature_column_v2.numeric_column('a'),
-            feature_column_v2.indicator_column(
-                feature_column_v2.categorical_column_with_vocabulary_list(
-                    'b', ['one', 'two']))]
+    cols = [
+        feature_column_lib.numeric_column('a'),
+        feature_column_lib.indicator_column(
+            feature_column_lib.categorical_column_with_vocabulary_list(
+                'b', ['one', 'two']))
+    ]
     input_layers = {
         'a': keras.layers.Input(shape=(1,), name='a'),
         'b': keras.layers.Input(shape=(1,), name='b', dtype='string')
     }
 
-    fc_layer = feature_column_v2.DenseFeatures(cols)(input_layers)
+    fc_layer = feature_column_lib.DenseFeatures(cols)(input_layers)
     output = keras.layers.Dense(10)(fc_layer)
 
     model = keras.models.Model(input_layers, output)
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 635ad77..75a740c 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -105,6 +105,15 @@
   def _finalize(self):
     # pylint: disable=protected-access
     for node in self._nodes:
+      if isinstance(node, RevivedLayer):
+        if not isinstance(node, RevivedSequential):
+          if hasattr(node.keras_api, 'call_and_return_conditional_losses'):
+            node.call = utils.use_wrapped_call(
+                node, node.keras_api.call_and_return_conditional_losses,
+                return_method=True)
+            node._init_call_fn_args()
+
+    for node in self._nodes:
       if isinstance(node, RevivedModel):
         call_fn = node.keras_api.call_and_return_conditional_losses
         if call_fn.input_signature is None:
@@ -224,7 +233,7 @@
 
   @property
   def keras_api(self):
-    return self._serialized_attributes[constants.KERAS_ATTR]
+    return self._serialized_attributes.get(constants.KERAS_ATTR, None)
 
   def get_config(self):
     if hasattr(self, '_config'):
@@ -232,12 +241,6 @@
     else:
       raise NotImplementedError
 
-  def call(self, inputs, *args, **kwargs):
-    """Calls the revived layer and add conditional losses."""
-    call_fn = utils.use_wrapped_call(
-        self, self.keras_api.call_and_return_conditional_losses)
-    return call_fn(inputs, *args, **kwargs)
-
 
 def recursively_deserialize_keras_object(config, module_objects=None):
   """Deserialize Keras object from a nested structure."""
@@ -281,7 +284,6 @@
   @classmethod
   def _init_from_metadata(cls, metadata):
     """Create revived network from metadata stored in the SavedModel proto."""
-    # TODO(kathywu): Refactor logic here so that RevivedNetwork uses the
     revived_obj = cls(name=metadata['name'])
 
     with trackable.no_automatic_dependency_tracking_scope(revived_obj):
@@ -329,6 +331,3 @@
     """Create revived Sequential model from SavedModel metadata."""
     revived_obj = super(RevivedSequential, cls)._init_from_metadata(metadata)
     return revived_obj
-
-  def call(self, *args, **kwargs):
-    return models_lib.Sequential.call(self, *args, **kwargs)
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index 1bf80e6..b495a03 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -39,6 +39,7 @@
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.lazy_loader import LazyLoader
 
@@ -57,7 +58,7 @@
 # pylint:enable=g-inconsistent-quotes
 
 
-def save(model, filepath, overwrite, include_optimizer):
+def save(model, filepath, overwrite, include_optimizer, signatures=None):
   """Saves a model as a SavedModel to the filepath.
 
   Args:
@@ -65,6 +66,9 @@
     filepath: String path to save the model.
     overwrite: whether to overwrite the existing filepath.
     include_optimizer: If True, save the model's optimizer state.
+    signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
+      format only. Please see the `signatures` argument in `tf.saved_model.save`
+      for details.
 
   Raises:
     ValueError: if the model's inputs have not been defined.
@@ -82,7 +86,10 @@
     orig_optimizer = model.optimizer
     model.optimizer = None
 
-  save_lib.save(model, filepath)
+  # Trace all functions and signatures with `training=0` instead of using the
+  # default learning phase placeholder.
+  with K.learning_phase_scope(0):
+    save_lib.save(model, filepath, signatures)
 
   if not include_optimizer:
     model.optimizer = orig_optimizer
@@ -252,7 +259,8 @@
     fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
     fns['call_and_return_all_conditional_losses'] = (
         call_collection.add_function(
-            _append_activity_regularizer_loss(call_fn_with_losses,
+            _append_activity_regularizer_loss(layer,
+                                              call_fn_with_losses,
                                               fns['activity_regularizer_fn']),
             '{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
             ))
@@ -343,7 +351,8 @@
         # Some layers have an unsettable activity regularizer.
         pass
       child_layer.call = utils.use_wrapped_call(
-          child_layer, layer_fns['call_and_return_conditional_losses'])
+          child_layer, layer_fns['call_and_return_conditional_losses'],
+          default_training_value=False)
   return original_fns
   # pylint: enable=protected-access
 
@@ -380,6 +389,23 @@
 # pylint: enable=protected-access
 
 
+def layer_uses_training_bool(layer):
+  """Returns whether this layer or any of its children uses the training arg."""
+  if layer._expects_training_arg:  # pylint: disable=protected-access
+    return True
+  visited = {layer}
+  to_visit = _list_all_layers(layer)
+  while to_visit:
+    layer = to_visit.pop()
+    if layer in visited:
+      continue
+    if layer._expects_training_arg:  # pylint: disable=protected-access
+      return True
+    visited.add(layer)
+    to_visit.extend(_list_all_layers(layer))
+  return False
+
+
 class LayerCallCollection(object):
   """Groups wrapped layer call functions.
 
@@ -391,8 +417,10 @@
   """
 
   def __init__(self, layer):
-    self._layer = layer
-    self._expects_training_arg = layer._expects_training_arg  # pylint: disable=protected-access
+    self.layer = layer
+    self._expects_training_arg = layer_uses_training_bool(layer)
+    self._training_arg_index = utils.get_training_arg_index(layer.call)
+
     self._input_signature = self._generate_input_signature(layer)
     self._functions = weakref.WeakValueDictionary()
     # Bool indicating whether this object is currently tracing the layer call
@@ -447,24 +475,15 @@
       # TODO(kathywu): Replace arguments with broader shapes defined in the
       # input signature.
       if self._expects_training_arg:
-        arg_list = tf_inspect.getfullargspec(fn.python_function).args
-        if 'training' in arg_list:
-          training_arg_index = arg_list.index('training')
-        else:
-          training_arg_index = -1
+        def trace_with_training(value, fn=fn):
+          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
+          with K.learning_phase_scope(value):
+            fn.get_concrete_function(*args, **kwargs)
 
-        def set_training_arg(training, index=training_arg_index):
-          if index >= 0 and len(args) > index:
-            args[index] = training
-          else:
-            kwargs['training'] = training
-
-        set_training_arg(False)
-        fn.original_get_concrete_function(*args, **kwargs)
-        set_training_arg(True)
-        fn.original_get_concrete_function(*args, **kwargs)
+        trace_with_training(True)
+        trace_with_training(False)
       else:
-        fn.original_get_concrete_function(*args, **kwargs)
+        fn.get_concrete_function(*args, **kwargs)
     self.tracing = False
 
   @property
@@ -480,10 +499,63 @@
       return None
     return self._input_signature
 
-  def add_function(self, python_function, name):
+  def training_arg_was_passed(self, args, kwargs):
+    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
+      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
+              is not None)
+    else:
+      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
+          'training', args, kwargs, inputs_in_args=True)
+
+  def get_training_arg_value(self, args, kwargs):
+    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
+      return utils.get_training_arg(self._training_arg_index, args, kwargs)
+    else:
+      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
+          'training', args, kwargs, inputs_in_args=True)
+
+  def _maybe_wrap_with_training_arg(self, call_fn):
+    """Wraps call function with added training argument if necessary."""
+    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
+      # Add training arg to wrapper function.
+      arg_spec = tf_inspect.getfullargspec(call_fn)
+      args = arg_spec.args + ['training']
+      defaults = list(arg_spec.defaults or [])
+      defaults.append(False)
+      new_arg_spec = tf_inspect.FullArgSpec(
+          args=args,
+          varargs=arg_spec.varargs,
+          varkw=arg_spec.varkw,
+          defaults=defaults,
+          kwonlyargs=arg_spec.kwonlyargs,
+          kwonlydefaults=arg_spec.kwonlydefaults,
+          annotations=arg_spec.annotations)
+
+      # Set new training arg index
+      self._training_arg_index = len(args) - 1
+      if tf_inspect.ismethod(call_fn):
+        self._training_arg_index -= 1
+
+      def wrap_with_training_arg(*args, **kwargs):
+        # Remove the training value, since the original call_fn does not expect
+        # a training arg. Instead, the training value will be propagated using
+        # the call context created in LayerCall.
+        args = list(args)
+        kwargs = kwargs.copy()
+        utils.remove_training_arg(self._training_arg_index, args, kwargs)
+        return call_fn(*args, **kwargs)
+
+      return tf_decorator.make_decorator(
+          target=call_fn,
+          decorator_func=wrap_with_training_arg,
+          decorator_argspec=new_arg_spec)
+
+    return call_fn
+
+  def add_function(self, call_fn, name):
     """Adds a layer call function to the collection."""
     self._functions[name] = fn = LayerCall(
-        self, python_function, name,
+        self, self._maybe_wrap_with_training_arg(call_fn), name,
         input_signature=self.fn_input_signature)
 
     if (None not in nest.flatten(self._input_signature) and
@@ -494,12 +566,36 @@
     return fn
 
 
+def layer_call_wrapper(call_collection, method):
+  """Ensures layer losses are kept the same, and runs method in call context."""
+  def wrapper(*args, **kwargs):
+    """Calls method within call context."""
+    layer = call_collection.layer
+    training = None
+    inputs = None
+    # pylint: disable=protected-access
+    if (args or kwargs) and call_collection.training_arg_was_passed(
+        args, kwargs):
+      inputs = args[0]
+      training = call_collection.get_training_arg_value(args, kwargs)
+    # pylint: enable=protected-access
+    original_losses = _reset_layer_losses(layer)
+    with base_layer_utils.call_context().enter(
+        layer, inputs=inputs, build_graph=False, training=training):
+      ret = method(*args, **kwargs)
+    _restore_layer_losses(original_losses)
+    return ret
+  return tf_decorator.make_decorator(target=method, decorator_func=wrapper)
+
+
 class LayerCall(def_function.Function):
   """Function that triggers traces of other functions in the same collection."""
 
-  def __init__(self, call_collection, *args, **kwargs):
-    super(LayerCall, self).__init__(*args, **kwargs)
+  def __init__(self, call_collection, python_function, *args, **kwargs):
     self.call_collection = call_collection
+    self.original_call = call_collection.layer.call
+    python_function = layer_call_wrapper(call_collection, python_function)
+    super(LayerCall, self).__init__(python_function, *args, **kwargs)
 
   def __call__(self, *args, **kwargs):
     if not self.call_collection.tracing:
@@ -511,9 +607,6 @@
       self.call_collection.add_trace(*args, **kwargs)
     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
 
-  def original_get_concrete_function(self, *args, **kwargs):
-    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
-
 
 def _wrap_call_and_conditional_losses(layer):
   """Wraps call function that returns a tuple of (outputs, losses).
@@ -530,37 +623,38 @@
   """
   # Create function that generates both outputs and losses
   layer_call = layer.call
-  if layer._expects_training_arg:  # pylint: disable=protected-access
-    def call_and_return_conditional_losses(inputs, training=False):
-      return layer_call(inputs, training=training), layer.get_losses_for(inputs)
-  else:
-    def call_and_return_conditional_losses(inputs):
-      K.set_learning_phase(0)
-      return layer_call(inputs), layer.get_losses_for(inputs)
-  return call_and_return_conditional_losses
+  def call_and_return_conditional_losses(inputs, *args, **kwargs):
+    return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
+  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
 
 
 def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
   """Returns a function that returns only call function outputs."""
   if isinstance(layer, keras_load.RevivedLayer):
     return layer.keras_api.__call__  # pylint: disable=protected-access
-  if layer._expects_training_arg:  # pylint: disable=protected-access
-    def call(inputs, training=False):
-      return call_and_return_conditional_losses(inputs, training=training)[0]
-  else:
-    def call(inputs):
-      return call_and_return_conditional_losses(inputs)[0]
-  return call
+  def call(inputs, *args, **kwargs):
+    return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
+  return _create_call_fn_decorator(layer, call)
 
 
 def _append_activity_regularizer_loss(
-    call_fn_with_losses, activity_regularizer_fn):
+    layer, call_fn_with_losses, activity_regularizer_fn):
   """Appends activity regularizer loss to losses returned by the wrapped fn."""
-  def fn(*args, **kwargs):
-    outputs, losses = call_fn_with_losses(*args, **kwargs)
+  def fn(inputs, *args, **kwargs):
+    outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
     losses.append(activity_regularizer_fn(outputs))
     return outputs, losses
-  return fn
+  return _create_call_fn_decorator(layer, fn)
+
+
+def _create_call_fn_decorator(layer, wrapped_call):
+  fn, arg_spec = utils.maybe_add_training_arg(
+      layer.call, wrapped_call, layer._expects_training_arg,  # pylint: disable=protected-access
+      default_training_value=False)
+  return tf_decorator.make_decorator(
+      target=layer.call,
+      decorator_func=fn,
+      decorator_argspec=arg_spec)
 
 
 def _wrap_unconditional_loss(loss_fn, index):
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index 7358f43..829d90d 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -23,22 +23,31 @@
 
 import numpy as np
 
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.saving.saved_model import load as saved_model_load
+from tensorflow.python.keras.saving.saved_model import load as keras_load
+from tensorflow.python.keras.saving.saved_model import save as keras_save
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import save as tf_save
+from tensorflow.python.util import tf_inspect
 
 
 class LayerWithLearningPhase(keras.engine.base_layer.Layer):
@@ -60,6 +69,13 @@
     return input_shape
 
 
+class LayerWithLoss(keras.layers.Layer):
+
+  def call(self, inputs):
+    self.add_loss(math_ops.reduce_sum(inputs), inputs)
+    return inputs
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
 
@@ -86,8 +102,7 @@
     model.add_loss(callable_loss)
     saved_model_dir = self._save_model_dir()
     tf_save.save(model, saved_model_dir)
-
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
     self.evaluate(variables.variables_initializer(loaded.variables))
     self.assertAllClose(self.evaluate(model.weights),
                         self.evaluate(loaded.weights))
@@ -123,7 +138,7 @@
     saved_model_dir = self._save_model_dir()
     self.evaluate(variables.variables_initializer(layer.variables))
     tf_save.save(layer, saved_model_dir)
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
     self.evaluate(variables.variables_initializer(loaded.variables))
 
     equal_attrs = ['name', '_expects_training_arg', 'trainable']
@@ -137,13 +152,6 @@
 
   def test_maintains_losses(self):
     """Tests that the layer losses do not change before and after export."""
-
-    class LayerWithLoss(keras.layers.Layer):
-
-      def call(self, inputs):
-        self.add_loss(math_ops.reduce_sum(inputs), inputs)
-        return inputs
-
     model = keras.models.Sequential([LayerWithLoss()])
     model.compile(
         loss='mse',
@@ -172,7 +180,7 @@
     layer.build([None, None])
     saved_model_dir = self._save_model_dir()
     tf_save.save(layer, saved_model_dir)
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
     input_arr = array_ops.ones((4, 3))
 
     # Run the layer, and use the keras backend learing phase
@@ -214,7 +222,7 @@
     self.assertEqual(expected_layers, len(loaded.keras_api.layers))
     input_arr = array_ops.ones((4, 3))
     self.assertAllClose(self.evaluate(model(input_arr)),
-                        self.evaluate(loaded(input_arr)))
+                        self.evaluate(loaded(input_arr, training=False)))
 
   @keras_parameterized.run_with_all_model_types
   def test_compiled_model(self):
@@ -232,7 +240,7 @@
     # TODO(b/134519980): Issue with model.fit if the model call function uses
     # a tf.function (Graph mode only).
     with context.eager_mode():
-      loaded = saved_model_load.load(saved_model_dir)
+      loaded = keras_load.load(saved_model_dir)
       actual_predict = loaded.predict(input_arr)
       self.assertAllClose(expected_predict, actual_predict)
 
@@ -261,7 +269,7 @@
     layer = LayerWithNestedSpec()
     saved_model_dir = self._save_model_dir()
     tf_save.save(layer, saved_model_dir)
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
     self.assertEqual(3, loaded.input_spec['a'].max_ndim)
     self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
     self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
@@ -274,7 +282,7 @@
     saved_model_dir = self._save_model_dir()
 
     model.save(saved_model_dir, save_format='tf')
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
     input_arr_1 = np.random.random((1, 3)).astype('float32')
     input_arr_2 = np.random.random((1, 5)).astype('float32')
 
@@ -292,7 +300,7 @@
 
     saved_model_dir = self._save_model_dir()
     model.save(saved_model_dir, save_format='tf')
-    loaded = saved_model_load.load(saved_model_dir)
+    loaded = keras_load.load(saved_model_dir)
 
     self.assertLen(loaded.layers, 2)
     self.assertLen(loaded.losses, 2)
@@ -307,5 +315,212 @@
     self.assertLen(loaded.layers, 2)
     self.assertLen(loaded.losses, 2)
 
+  def testBatchNormUpdates(self):
+    model = keras.models.Sequential(
+        keras.layers.BatchNormalization(input_shape=(1,)))
+    self.evaluate(variables.variables_initializer(model.variables))
+    saved_model_dir = self._save_model_dir()
+    model.save(saved_model_dir, save_format='tf')
+    loaded = keras_load.load(saved_model_dir)
+    self.evaluate(variables.variables_initializer(loaded.variables))
+    input_arr_1 = np.array([[11], [12], [13]]).astype('float32')
+    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0])
+    self.evaluate(loaded(input_arr_1, training=True))
+    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
+    self.evaluate(loaded(input_arr_1, training=False))
+    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
+
+  def testSaveWithSignatures(self):
+    model = keras.models.Sequential()
+    model.add(keras.layers.Dense(5, input_shape=(3,),
+                                 kernel_regularizer=regularizers.get('l2')))
+    model.add(keras.layers.Dropout(0.5))
+    model.add(keras.layers.Dense(4, kernel_regularizer=regularizers.get('l2')))
+
+    input_arr = np.random.random((2, 3)).astype(np.float32)
+    target_arr = np.random.random((2, 4)).astype(np.float32)
+
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop')
+    model.train_on_batch(input_arr, target_arr)
+
+    @def_function.function(input_signature=[tensor_spec.TensorSpec((None, 3))])
+    def predict(inputs):
+      return {'predictions': model(inputs)}
+
+    feature_configs = {
+        'inputs': parsing_ops.FixedLenFeature(
+            shape=[2, 3], dtype=dtypes.float32)}
+
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec([None], dtypes.string)])
+    def parse_and_predict(examples):
+      features = parsing_ops.parse_single_example(examples[0], feature_configs)
+      return {'predictions': model(features['inputs']),
+              'layer_1_outputs': model.layers[0](features['inputs'])}
+
+    saved_model_dir = self._save_model_dir()
+    model.save(saved_model_dir, save_format='tf', signatures={
+        'predict': predict,
+        'parse_and_predict': parse_and_predict})
+    model.save('/tmp/saved', save_format='tf', signatures={
+        'predict': predict,
+        'parse_and_predict': parse_and_predict})
+
+    loaded = keras_load.load(saved_model_dir)
+
+    self.assertAllClose(
+        model.predict(input_arr),
+        loaded.signatures['predict'](
+            ops.convert_to_tensor(input_arr))['predictions'])
+
+    feature = {
+        'inputs': feature_pb2.Feature(
+            float_list=feature_pb2.FloatList(value=input_arr.flatten()))}
+    example = example_pb2.Example(
+        features=feature_pb2.Features(feature=feature))
+    outputs = loaded.signatures['parse_and_predict'](
+        ops.convert_to_tensor([example.SerializeToString()]))
+    self.assertAllClose(model.predict(input_arr), outputs['predictions'])
+    self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])
+
+  def testTrainingDefaults(self):
+    def assert_training_default(fn, default_value):
+      arg_spec = tf_inspect.getfullargspec(fn)
+      index = len(arg_spec.args) - arg_spec.args.index('training')
+      self.assertEqual(arg_spec.defaults[-index], default_value)
+
+    class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):
+
+      def call(self, inputs, training):
+        return tf_utils.smart_cond(
+            training, lambda: inputs * 0, lambda: array_ops.identity(inputs))
+
+    class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):
+
+      def call(self, inputs, training=True):
+        return tf_utils.smart_cond(
+            training, lambda: inputs * 0, lambda: array_ops.identity(inputs))
+
+    class Model(keras.models.Model):
+
+      def __init__(self):
+        super(Model, self).__init__()
+        self.layer_with_training_default_none = LayerWithLearningPhase()
+        self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
+        self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()
+
+      def call(self, inputs):
+        x = self.layer_with_training_default_none(inputs)
+        x += self.layer_with_training_default_true(inputs)
+        x += self.layer_with_required_training_arg(inputs, False)
+        return x
+
+    model = Model()
+    # Build and set model inputs
+    model.predict(np.ones([1, 3]).astype('float32'))
+    saved_model_dir = self._save_model_dir()
+    model.save(saved_model_dir, save_format='tf')
+    load = tf_load.load(saved_model_dir)
+
+    assert_training_default(load.__call__, False)
+    assert_training_default(
+        load.layer_with_training_default_none.__call__, False)
+    assert_training_default(
+        load.layer_with_training_default_true.__call__, True)
+
+    # Assert that there are no defaults for layer with required training arg
+    arg_spec = tf_inspect.getfullargspec(
+        load.layer_with_required_training_arg.__call__)
+    self.assertFalse(arg_spec.defaults)  # defaults is None or empty
+
+
+class TestLayerCallTracing(test.TestCase):
+
+  def test_functions_have_same_trace(self):
+
+    class Layer(keras.engine.base_layer.Layer):
+
+      def call(self, inputs):
+        return inputs
+
+      def call2(self, inputs):
+        return inputs * 2
+
+    layer = Layer()
+    call_collection = keras_save.LayerCallCollection(layer)
+    fn = call_collection.add_function(layer.call, 'call')
+    fn2 = call_collection.add_function(layer.call2, 'call2')
+
+    fn(np.ones((2, 3)))
+    fn(np.ones((4, 5)))
+
+    self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)
+    self.assertLen(fn2._list_all_concrete_functions_for_serialization(), 2)
+
+    # Check that the shapes are correct
+    self.assertEqual(
+        {(2, 3), (4, 5)},
+        set(tuple(c.structured_input_signature[0][0].shape.as_list())
+            for c in fn2._list_all_concrete_functions_for_serialization()))
+
+  def test_training_arg_replacement(self):
+
+    def assert_num_traces(layer_cls, training_keyword):
+      layer = layer_cls()
+      call_collection = keras_save.LayerCallCollection(layer)
+      fn = call_collection.add_function(layer.call, 'call')
+
+      fn(np.ones((2, 3)), training=True)
+      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)
+
+      fn(np.ones((2, 4)), training=False)
+      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 4)
+
+      if training_keyword:
+        fn(np.ones((2, 5)), True)
+        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 6)
+        fn(np.ones((2, 6)))
+        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 8)
+
+    class LayerWithTrainingKeyword(keras.engine.base_layer.Layer):
+
+      def call(self, inputs, training=False):
+        return inputs * training
+
+    assert_num_traces(LayerWithTrainingKeyword, training_keyword=True)
+
+    class LayerWithKwargs(keras.engine.base_layer.Layer):
+
+      def call(self, inputs, **kwargs):
+        return inputs * kwargs['training']
+
+    assert_num_traces(LayerWithKwargs, training_keyword=False)
+
+    class LayerWithChildLayer(keras.engine.base_layer.Layer):
+
+      def __init__(self):
+        self.child = LayerWithKwargs()
+        super(LayerWithChildLayer, self).__init__()
+
+      def call(self, inputs):
+        return self.child(inputs)
+
+    assert_num_traces(LayerWithChildLayer, training_keyword=False)
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_maintains_losses(self):
+    layer = LayerWithLoss()
+    layer(np.ones((2, 3)))
+    previous_losses = layer.losses[:]
+
+    call_collection = keras_save.LayerCallCollection(layer)
+    fn = call_collection.add_function(layer.call, 'call')
+    fn(np.ones((2, 3)))
+
+    self.assertAllEqual(previous_losses, layer.losses)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/saving/saved_model/utils.py b/tensorflow/python/keras/saving/saved_model/utils.py
index 960b370..6a52674 100644
--- a/tensorflow/python/keras/saving/saved_model/utils.py
+++ b/tensorflow/python/keras/saving/saved_model/utils.py
@@ -17,35 +17,173 @@
 from __future__ import division
 from __future__ import print_function
 
+import types
+
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
-def use_wrapped_call(layer, call_fn):
+def use_wrapped_call(layer, call_fn, default_training_value=None,
+                     return_method=False):
   """Creates fn that adds the losses returned by call_fn & returns the outputs.
 
   Args:
     layer: A Keras layer object
     call_fn: tf.function that takes layer inputs (and possibly a training arg),
       and returns a tuple of (outputs, list of losses).
+    default_training_value: Default value of the training kwarg. If `None`, the
+      default is `K.learning_phase()`.
+    return_method: Whether to return a method bound to the layer.
 
   Returns:
     function that calls call_fn and returns the outputs. Losses returned by
     call_fn are added to the layer losses.
   """
-  # TODO(kathywu): Support mask argument and multi-input call functions.
-  def wrapped_call(inputs, **kwargs):
+  expects_training_arg = layer._expects_training_arg   # pylint: disable=protected-access
+  if hasattr(call_fn, 'original_call'):
+    original_call = call_fn.original_call
+  else:
+    original_call = call_fn
+  fn, arg_spec = maybe_add_training_arg(
+      original_call, call_fn, expects_training_arg, default_training_value)
+
+  def return_outputs_and_add_losses(*args, **kwargs):
     """Returns the outputs from the call_fn, and adds the losses."""
-    if layer._expects_training_arg:  # pylint: disable=protected-access
-      training = kwargs.pop('training', None)
-      if training is None:
-        training = K.learning_phase()
-      outputs, losses = tf_utils.smart_cond(
-          training,
-          lambda: call_fn(inputs, training=True),
-          lambda: call_fn(inputs, training=False))
-    else:
-      outputs, losses = call_fn(inputs)
+    inputs_arg_index = 1 if return_method else 0
+    inputs = args[inputs_arg_index]
+    args = args[inputs_arg_index + 1:]
+    outputs, losses = fn(inputs, *args, **kwargs)
     layer.add_loss(losses, inputs)
     return outputs
-  return wrapped_call
+
+  decorated = tf_decorator.make_decorator(
+      target=call_fn,
+      decorator_func=return_outputs_and_add_losses,
+      decorator_argspec=arg_spec)
+
+  if return_method:
+    return types.MethodType(decorated, layer)
+  else:
+    return decorated
+
+
+def maybe_add_training_arg(
+    original_call, wrapped_call, expects_training_arg, default_training_value):
+  """Decorate call and optionally adds training argument.
+
+  If a layer expects a training argument, this function ensures that 'training'
+  is present in the layer args or kwonly args, with the default training value.
+
+  Args:
+    original_call: Original call function.
+    wrapped_call: Wrapped call function.
+    expects_training_arg: Whether to include 'training' argument.
+    default_training_value: Default value of the training kwarg to include in
+      the arg spec. If `None`, the default is `K.learning_phase()`.
+
+  Returns:
+    Tuple of (
+      function that calls `wrapped_call` and sets the training arg,
+      Argspec of returned function or `None` if the argspec is unchanged)
+  """
+  if not expects_training_arg:
+    return wrapped_call, None
+
+  def wrap_with_training_arg(*args, **kwargs):
+    """Wrap the `wrapped_call` function, and set training argument."""
+    training_arg_index = get_training_arg_index(original_call)
+    training = get_training_arg(training_arg_index, args, kwargs)
+    if training is None:
+      training = default_training_value or K.learning_phase()
+
+    args = list(args)
+    kwargs = kwargs.copy()
+
+    def replace_training_and_call(training):
+      set_training_arg(training, training_arg_index, args, kwargs)
+      return wrapped_call(*args, **kwargs)
+
+    return tf_utils.smart_cond(
+        training,
+        lambda: replace_training_and_call(True),
+        lambda: replace_training_and_call(False))
+
+  # Create arg spec for decorated function. If 'training' is not defined in the
+  # args of the original arg spec, then add it to kwonlyargs.
+  arg_spec = tf_inspect.getfullargspec(original_call)
+  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []
+
+  kwonlyargs = arg_spec.kwonlyargs
+  kwonlydefaults = arg_spec.kwonlydefaults or {}
+  # Add training arg if it does not exist, or set the default training value.
+  if 'training' not in arg_spec.args:
+    kwonlyargs.append('training')
+    kwonlydefaults['training'] = default_training_value
+  else:
+    index = arg_spec.args.index('training')
+    training_default_index = len(arg_spec.args) - index
+    if (arg_spec.defaults and
+        len(arg_spec.defaults) >= training_default_index and
+        defaults[-training_default_index] is None):
+      defaults[-training_default_index] = default_training_value
+
+  decorator_argspec = tf_inspect.FullArgSpec(
+      args=arg_spec.args,
+      varargs=arg_spec.varargs,
+      varkw=arg_spec.varkw,
+      defaults=defaults,
+      kwonlyargs=kwonlyargs,
+      kwonlydefaults=kwonlydefaults,
+      annotations=arg_spec.annotations)
+  return wrap_with_training_arg, decorator_argspec
+
+
+def get_training_arg_index(call_fn):
+  """Returns the index of 'training' in the layer call function arguments.
+
+  Args:
+    call_fn: Call function.
+
+  Returns:
+    - n: index of 'training' in the call function arguments.
+    - -1: if 'training' is not found in the arguments, but layer.call accepts
+          variable keyword arguments
+    - None: if layer doesn't expect a training argument.
+  """
+  arg_list = tf_inspect.getfullargspec(call_fn).args
+  if tf_inspect.ismethod(call_fn):
+    arg_list = arg_list[1:]
+  if 'training' in arg_list:
+    return arg_list.index('training')
+  else:
+    return -1
+
+
+def set_training_arg(training, index, args, kwargs):
+  if index is None:
+    pass
+  elif index >= 0 and len(args) > index:
+    args[index] = training
+  else:
+    kwargs['training'] = training
+  return args, kwargs
+
+
+def get_training_arg(index, args, kwargs):
+  if index is None:
+    return None
+  elif index >= 0 and len(args) > index:
+    return args[index]
+  else:
+    return kwargs.get('training', None)
+
+
+def remove_training_arg(index, args, kwargs):
+  if index is None:
+    pass
+  elif index >= 0 and len(args) > index:
+    args.pop(index)
+  else:
+    kwargs.pop('training', None)
diff --git a/tensorflow/python/keras/saving/saved_model_experimental_test.py b/tensorflow/python/keras/saving/saved_model_experimental_test.py
index c662a92..11a3ff5 100644
--- a/tensorflow/python/keras/saving/saved_model_experimental_test.py
+++ b/tensorflow/python/keras/saving/saved_model_experimental_test.py
@@ -67,7 +67,7 @@
           metrics=[keras.metrics.categorical_accuracy],
           sample_weight_mode='temporal',
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       x = np.random.random((1, 3))
       y = np.random.random((1, 3, 3))
       model.train_on_batch(x, y)
@@ -111,7 +111,7 @@
           optimizer=rmsprop.RMSprop(lr=0.0001),
           metrics=[keras.metrics.categorical_accuracy],
           run_eagerly=testing_utils.should_run_eagerly(),
-          run_distributed=testing_utils.should_run_distributed())
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
       x = np.random.random((1, 3))
       y = np.random.random((1, 3))
       model.train_on_batch(x, y)
@@ -169,7 +169,7 @@
         optimizer=training_module.RMSPropOptimizer(0.1),
         metrics=['acc'],
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     y = loaded_model.predict(x)
     self.assertAllClose(ref_y, y, atol=1e-05)
 
diff --git a/tensorflow/python/keras/saving/saving_utils_test.py b/tensorflow/python/keras/saving/saving_utils_test.py
index d3e9eae..92bee3d 100644
--- a/tensorflow/python/keras/saving/saving_utils_test.py
+++ b/tensorflow/python/keras/saving/saving_utils_test.py
@@ -29,7 +29,7 @@
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -89,7 +89,7 @@
         optimizer='sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x=np.random.random((8, 5)),
               y=np.random.random((8, 3)), epochs=2)
 
@@ -130,7 +130,7 @@
         optimizer='sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     model.fit(x=[np.random.random((8, input_dim)).astype(np.float32),
                  np.random.random((8, input_dim)).astype(np.float32)],
               y=[np.random.random((8, num_classes)).astype(np.float32),
@@ -147,18 +147,18 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_trace_features_layer(self):
-    columns = [feature_column_v2.numeric_column('x')]
-    model = sequential.Sequential(
-        [feature_column_v2.DenseFeatures(columns)])
+    columns = [feature_column_lib.numeric_column('x')]
+    model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
     model_input = {'x': constant_op.constant([[1.]])}
     model.predict(model_input, steps=1)
     fn = saving_utils.trace_model_call(model)
     self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))
 
-    columns = [feature_column_v2.numeric_column('x'),
-               feature_column_v2.numeric_column('y')]
-    model = sequential.Sequential(
-        [feature_column_v2.DenseFeatures(columns)])
+    columns = [
+        feature_column_lib.numeric_column('x'),
+        feature_column_lib.numeric_column('y')
+    ]
+    model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
     model_input = {'x': constant_op.constant([[1.]]),
                    'y': constant_op.constant([[2.]])}
     model.predict(model_input, steps=1)
@@ -310,7 +310,7 @@
         ],
         optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01),
         run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
     extract_metrics = saving_utils.extract_model_metrics(model)
     self.assertEqual(set(model_metric_names), set(model.metrics_names))
     self.assertEqual(set(extract_metric_names), set(extract_metrics.keys()))
diff --git a/tensorflow/python/keras/temporal_sample_weights_correctness_test.py b/tensorflow/python/keras/temporal_sample_weights_correctness_test.py
index e702951..0d9f77c 100644
--- a/tensorflow/python/keras/temporal_sample_weights_correctness_test.py
+++ b/tensorflow/python/keras/temporal_sample_weights_correctness_test.py
@@ -64,7 +64,7 @@
       weighted_metrics=[metrics.MeanAbsoluteError(name='mae_2')],
       sample_weight_mode=sample_weight_mode,
       run_eagerly=testing_utils.should_run_eagerly(),
-      run_distributed=testing_utils.should_run_distributed())
+      experimental_run_tf_function=testing_utils.should_run_tf_function())
   return model
 
 
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index bc6f844..fc0f847 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -28,6 +28,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
 from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
 from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
@@ -261,7 +262,7 @@
 _thread_local_data = threading.local()
 _thread_local_data.model_type = None
 _thread_local_data.run_eagerly = None
-_thread_local_data.run_distributed = None
+_thread_local_data.experimental_run_tf_function = None
 
 
 @tf_contextlib.contextmanager
@@ -318,7 +319,7 @@
 
 
 @tf_contextlib.contextmanager
-def run_distributed_scope(value):
+def experimental_run_tf_function_scope(value):
   """Provides a scope within which we compile models to run with distribution.
 
   The boolean gets restored to its original value upon exiting the scope.
@@ -330,23 +331,25 @@
   Yields:
     The provided value.
   """
-  previous_value = _thread_local_data.run_distributed
+  previous_value = _thread_local_data.experimental_run_tf_function
   try:
-    _thread_local_data.run_distributed = value
+    _thread_local_data.experimental_run_tf_function = value
     yield value
   finally:
     # Restore model type to initial value.
-    _thread_local_data.run_distributed = previous_value
+    _thread_local_data.experimental_run_tf_function = previous_value
 
 
-def should_run_distributed():
+def should_run_tf_function():
   """Returns whether the models we are testing should be run distributed."""
-  if _thread_local_data.run_distributed is None:
-    raise ValueError('Cannot call `should_run_distributed()` outside of a '
-                     '`run_distributed_scope()` or `run_all_keras_modes` '
-                     'decorator.')
+  if _thread_local_data.experimental_run_tf_function is None:
+    raise ValueError(
+        'Cannot call `should_run_tf_function()` outside of a '
+        '`experimental_run_tf_function_scope()` or `run_all_keras_modes` '
+        'decorator.')
 
-  return _thread_local_data.run_distributed and context.executing_eagerly()
+  return (_thread_local_data.experimental_run_tf_function and
+          context.executing_eagerly())
 
 
 def get_model_type():
@@ -438,8 +441,18 @@
 class _SubclassModel(keras.Model):
   """A Keras subclass model."""
 
-  def __init__(self, layers):
-    super(_SubclassModel, self).__init__()
+  def __init__(self, layers, *args, **kwargs):
+    """Instantiate a model.
+
+    Args:
+      layers: a list of layers to be added to the model.
+      *args: Model's args
+      **kwargs: Model's keyword args, at most one of
+        input_tensor -> the input tensor required for ragged/sparse input.
+    """
+
+    inputs = kwargs.pop('input_tensor', None)
+    super(_SubclassModel, self).__init__(*args, **kwargs)
     # Note that clone and build doesn't support lists of layers in subclassed
     # models. Adding each layer directly here.
     for i, layer in enumerate(layers):
@@ -447,6 +460,9 @@
 
     self.num_layers = len(layers)
 
+    if inputs is not None:
+      self._set_inputs(inputs)
+
   def _layer_name_for_i(self, i):
     return 'layer{}'.format(i)
 
@@ -461,8 +477,8 @@
 class _SubclassModelCustomBuild(keras.Model):
   """A Keras subclass model that uses a custom build method."""
 
-  def __init__(self, layer_generating_func):
-    super(_SubclassModelCustomBuild, self).__init__()
+  def __init__(self, layer_generating_func, *args, **kwargs):
+    super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs)
     self.all_layers = None
     self._layer_generating_func = layer_generating_func
 
@@ -479,21 +495,50 @@
     return x
 
 
-def get_model_from_layers(layers, input_shape=None, input_dtype=None):
-  """Builds a model from a sequence of layers."""
+def get_model_from_layers(layers,
+                          input_shape=None,
+                          input_dtype=None,
+                          name=None,
+                          input_ragged=None,
+                          input_sparse=None):
+  """Builds a model from a sequence of layers.
+
+  Args:
+    layers: The layers used to build the network.
+    input_shape: Shape tuple of the input or 'TensorShape' instance.
+    input_dtype: Datatype of the input.
+    name: Name for the model.
+    input_ragged: Boolean, whether the input data is a ragged tensor.
+    input_sparse: Boolean, whether the input data is a sparse tensor.
+
+  Returns:
+    A Keras model.
+  """
+
   model_type = get_model_type()
   if model_type == 'subclass':
-    return _SubclassModel(layers)
+    inputs = None
+    if input_ragged or input_sparse:
+      inputs = keras.Input(
+          shape=input_shape,
+          dtype=input_dtype,
+          ragged=input_ragged,
+          sparse=input_sparse)
+    return _SubclassModel(layers, name=name, input_tensor=inputs)
 
   if model_type == 'subclass_custom_build':
     layer_generating_func = lambda: layers
-    return _SubclassModelCustomBuild(layer_generating_func)
+    return _SubclassModelCustomBuild(layer_generating_func, name=name)
 
   if model_type == 'sequential':
-    model = keras.models.Sequential()
+    model = keras.models.Sequential(name=name)
     if input_shape:
-      model.add(keras.layers.InputLayer(input_shape=input_shape,
-                                        dtype=input_dtype))
+      model.add(
+          keras.layers.InputLayer(
+              input_shape=input_shape,
+              dtype=input_dtype,
+              ragged=input_ragged,
+              sparse=input_sparse))
     for layer in layers:
       model.add(layer)
     return model
@@ -502,11 +547,15 @@
     if not input_shape:
       raise ValueError('Cannot create a functional model from layers with no '
                        'input shape.')
-    inputs = keras.Input(shape=input_shape, dtype=input_dtype)
+    inputs = keras.Input(
+        shape=input_shape,
+        dtype=input_dtype,
+        ragged=input_ragged,
+        sparse=input_sparse)
     outputs = inputs
     for layer in layers:
       outputs = layer(outputs)
-    return keras.Model(inputs, outputs)
+    return keras.Model(inputs, outputs, name=name)
 
   raise ValueError('Unknown model type {}'.format(model_type))
 
@@ -766,3 +815,26 @@
     return [n + ':0' for n in var_names]
   # In V1 graph mode variable names are made unique using a suffix.
   return [n + name_suffix + ':0' for n in var_names]
+
+
+def enable_v2_dtype_behavior(fn):
+  """Decorator for enabling the layer V2 dtype behavior on a test."""
+  return _set_v2_dtype_behavior(fn, True)
+
+
+def disable_v2_dtype_behavior(fn):
+  """Decorator for disabling the layer V2 dtype behavior on a test."""
+  return _set_v2_dtype_behavior(fn, False)
+
+
+def _set_v2_dtype_behavior(fn, enabled):
+  """Returns version of 'fn' that runs with v2 dtype behavior on or off."""
+  def wrapper(*args, **kwargs):
+    v2_dtype_behavior = base_layer_utils.V2_DTYPE_BEHAVIOR
+    base_layer_utils.V2_DTYPE_BEHAVIOR = enabled
+    try:
+      return fn(*args, **kwargs)
+    finally:
+      base_layer_utils.V2_DTYPE_BEHAVIOR = v2_dtype_behavior
+
+  return wrapper
diff --git a/tensorflow/python/keras/utils/composite_tensor_support_test.py b/tensorflow/python/keras/utils/composite_tensor_support_test.py
index 649a1f8..b5a1d51 100644
--- a/tensorflow/python/keras/utils/composite_tensor_support_test.py
+++ b/tensorflow/python/keras/utils/composite_tensor_support_test.py
@@ -26,6 +26,7 @@
 from tensorflow.python import keras
 
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -152,6 +153,20 @@
   raise ValueError("Unknown model type {}".format(model_type))
 
 
+def get_test_mode_kwargs():
+  run_eagerly = testing_utils.should_run_eagerly()
+  # Certain things weren't supported correctly in the old path, therefore
+  # with these changes, some tests now only pass in the single code path in V2.
+  if run_eagerly or context.executing_eagerly():
+    experimental_run_tf_function = True
+  else:
+    experimental_run_tf_function = testing_utils.should_run_tf_function()
+  return {
+      "run_eagerly": run_eagerly,
+      "experimental_run_tf_function": experimental_run_tf_function
+  }
+
+
 @keras_parameterized.run_with_all_model_types
 @keras_parameterized.run_all_keras_modes
 class CompositeTensorInternalTest(keras_parameterized.TestCase):
@@ -181,9 +196,6 @@
     self.assertAllEqual(expected_output, output)
 
   def test_training_internal_ragged_tensors(self):
-    if testing_utils.should_run_distributed():
-      # Training loop stall without clear reason.
-      self.skipTest("b/137397816")
     # Create a model that implements y=Mx. This is easy to learn and will
     # demonstrate appropriate gradient passing. (We have to use RaggedTensors
     # for this test, as ToSparse() doesn't support gradient propagation through
@@ -194,11 +206,7 @@
     input_data = np.random.rand(1024, 1)
     expected_data = np.concatenate((input_data * 3, input_data * .5), axis=-1)
 
-    model.compile(
-        loss="mse",
-        optimizer="adam",
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
+    model.compile(loss="mse", optimizer="adam", **get_test_mode_kwargs())
     history = model.fit(input_data, expected_data, epochs=10, verbose=0)
 
     # If the model trained, the loss stored at history[0] should be different
@@ -215,6 +223,8 @@
     # converts the ragged tensor back to a dense tensor.
     layers = [ToRagged(padding=0)]
     model = testing_utils.get_model_from_layers(layers, input_shape=(None,))
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
+    model._run_eagerly = testing_utils.should_run_eagerly()
 
     # Define some input data with additional padding.
     input_data = np.array([[1, 0, 0], [2, 3, 0]])
@@ -228,6 +238,8 @@
     # converts the ragged tensor back to a dense tensor.
     layers = [ToRagged(padding=0)]
     model = testing_utils.get_model_from_layers(layers, input_shape=(None,))
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
+    model._run_eagerly = testing_utils.should_run_eagerly()
 
     # Define some input data with additional padding.
     input_data = np.array([[1, 0, 0], [2, 3, 0], [4, 0, 0], [5, 6, 0]])
@@ -241,6 +253,8 @@
     # converts the ragged tensor back to a dense tensor.
     layers = [ToSparse()]
     model = testing_utils.get_model_from_layers(layers, input_shape=(None,))
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
+    model._run_eagerly = testing_utils.should_run_eagerly()
 
     # Define some input data with additional padding.
     input_data = np.array([[1, 0, 0], [2, 3, 0]])
@@ -259,6 +273,8 @@
     # converts the ragged tensor back to a dense tensor.
     layers = [ToSparse()]
     model = testing_utils.get_model_from_layers(layers, input_shape=(None,))
+    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
+    model._run_eagerly = testing_utils.should_run_eagerly()
 
     # Define some input data with additional padding.
     input_data = np.array([[1, 0, 0], [2, 3, 0], [4, 0, 0], [5, 6, 0]])
@@ -284,26 +300,28 @@
     return "test_input_name"
 
 
-def get_steps():
-  # Determine the steps arg (if appropriate)
-  if not testing_utils.should_run_eagerly():
-    # CompositeTensors in graph mode are symbolic and so require a steps arg.
-    return 1
+def get_kwargs(use_dataset, action="predict"):
+  if use_dataset or not context.executing_eagerly():
+    if action == "fit":
+      return {"steps_per_epoch": 1}
+    return {"steps": 1}
   else:
-    return None
+    return {"batch_size": 2}
 
 
 def prepare_inputs(data, use_dict, use_dataset, action, input_name):
   input_data, expected_output = data
+  batch_size = input_data.shape[0]
   # Prepare the input data.
   if use_dict:
     input_data = {input_name: input_data}
   if use_dataset:
     if action == "predict":
-      input_data = dataset_ops.Dataset.from_tensors(input_data)
+      input_data = dataset_ops.DatasetV2.from_tensor_slices(input_data).batch(
+          batch_size)
     else:
-      input_data = dataset_ops.Dataset.from_tensors(
-          (input_data, expected_output))
+      input_data = dataset_ops.DatasetV2.from_tensor_slices(
+          (input_data, expected_output)).batch(batch_size)
       expected_output = None
   return (input_data, expected_output)
 
@@ -332,8 +350,12 @@
         shape=(1, None), sparse=True, name=input_name, dtype=dtypes.int32)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
-    steps = get_steps()
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        **get_test_mode_kwargs())
+    kwargs = get_kwargs(use_dataset, action)
 
     # Prepare the input data
     for data_element in data:
@@ -342,15 +364,14 @@
                                                    input_name)
       # Perform the action.
       if action == "predict":
-        result = model.predict(input_data, steps=steps)
+        result = model.predict(input_data, **kwargs)
         self.assertAllEqual(expected_output, result)
       if action == "evaluate":
-        result = model.evaluate(input_data, expected_output, steps=steps)
+        result = model.evaluate(input_data, expected_output, **kwargs)
         self.assertAllEqual(1.0, result[-1])
       if action == "fit":
         # TODO(momernick): What's the best way of validating that fit happened?
-        _ = model.fit(
-            input_data, expected_output, shuffle=False, steps_per_epoch=steps)
+        _ = model.fit(input_data, expected_output, shuffle=False, **kwargs)
 
 
 @keras_parameterized.run_with_all_model_types
@@ -385,7 +406,11 @@
     model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int64)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
                                          shape=[2, 3])
@@ -443,7 +468,11 @@
         shape=(3,), sparse=True, name=input_name, dtype=dtypes.int64)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     input_data = {
         input_name:
@@ -484,7 +513,11 @@
         shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        **get_test_mode_kwargs())
 
     # Prepare the input data
     for data_element in data:
@@ -524,7 +557,11 @@
         shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        **get_test_mode_kwargs())
 
     for data_element in data:
       input_data, expected_output = prepare_inputs(
@@ -549,11 +586,12 @@
         shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
-
-    # The input is a symbolic tensor in non-Eager modes, so 'steps' is required
-    # for that case only.
-    steps = get_steps()
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        **get_test_mode_kwargs())
+    kwargs = get_kwargs(use_dataset)
 
     for data_element in data:
       input_data, expected_output = prepare_inputs(
@@ -562,7 +600,7 @@
           use_dataset,
           action="predict",
           input_name=input_name)
-      result = model.predict(input_data, steps=steps)
+      result = model.predict(input_data, **kwargs)
       self.assertAllEqual(expected_output, result)
 
   def test_ragged_tensor_input_with_wrong_ragged_rank_fails(
@@ -577,7 +615,11 @@
         shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
     layers = [ToDense(default_value=-1)]
     model = get_model_from_layers_with_input(layers, model_input=model_input)
-    model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
+    model.compile(
+        optimizer="sgd",
+        loss="mse",
+        metrics=["accuracy"],
+        **get_test_mode_kwargs())
 
     # Define some input data with the wrong ragged rank
     for data_element in data:
@@ -618,15 +660,9 @@
     # Define some input data.
     input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
                                             [1, 2, 3], [2, 1, 3])
-    if not testing_utils.should_run_eagerly():
-      # This ragged tensor is actually a standard tensor (as it has no ragged
-      # dimensions). Because of this, graph mode models will expect a steps
-      # arg to be passed (as SparseTensors in graph mode are symbolic).
-      steps = 1
-    else:
-      steps = None
+    kwargs = get_kwargs(use_dataset=False)
     with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
-      _ = model.predict(input_data, steps=steps)
+      _ = model.predict(input_data, **kwargs)
 
   def test_ragged_tensor_input_with_wrong_value_shape(self):
     # Create a model that accepts a ragged input and converts it to dense.
@@ -652,14 +688,14 @@
     # back to a dense tensor.
     layers = [ToDense(default_value=-1)]
     model = testing_utils.get_model_from_layers(layers)
-    steps = get_steps()
 
     # Define some input data.
     input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3],
                                             [2, 3])
+    kwargs = get_kwargs(False)
     with self.assertRaisesRegex(
         ValueError, ".*All SparseTensor and RaggedTensor inputs .*"):
-      _ = model.predict(input_data, steps=steps)
+      _ = model.predict(input_data, **kwargs)
 
   def test_subclass_implicit_sparse_scipy_inputs_fails(self):
     # Create a model that accepts a sparse input and converts the sparse tensor
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index ea7427f..1d6256e 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Utilities used by convolution layers.
-"""
+"""Utilities used by convolution layers."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -52,8 +51,8 @@
   """Transforms a single integer or iterable of integers into an integer tuple.
 
   Arguments:
-    value: The value to validate and convert. Could an int, or any iterable
-      of ints.
+    value: The value to validate and convert. Could an int, or any iterable of
+      ints.
     n: The size of the tuple to be returned.
     name: The name of the argument being validated, e.g. "strides" or
       "kernel_size". This is only used to format error messages.
@@ -137,16 +136,20 @@
   return (output_length - 1) * stride - 2 * pad + filter_size
 
 
-def deconv_output_length(input_length, filter_size, padding,
-                         output_padding=None, stride=0, dilation=1):
+def deconv_output_length(input_length,
+                         filter_size,
+                         padding,
+                         output_padding=None,
+                         stride=0,
+                         dilation=1):
   """Determines output length of a transposed convolution given input length.
 
   Arguments:
       input_length: Integer.
       filter_size: Integer.
       padding: one of `"same"`, `"valid"`, `"full"`.
-      output_padding: Integer, amount of padding along the output dimension.
-          Can be set to `None` in which case the output length is inferred.
+      output_padding: Integer, amount of padding along the output dimension. Can
+        be set to `None` in which case the output length is inferred.
       stride: Integer.
       dilation: Integer.
 
@@ -252,10 +255,10 @@
 
 
   Args:
-    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
-                 spatial shape of the input.
-    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
-                  / receptive field.
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
+      input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
+      receptive field.
     strides: tuple of size N, strides along each spatial dimension.
     padding: type of padding, string `"same"` or `"valid"`.
 
@@ -295,21 +298,106 @@
 
   output_axes_ticks = [range(dim) for dim in output_shape]
   for output_position in itertools.product(*output_axes_ticks):
-    input_axes_ticks = conv_connected_inputs(input_shape,
-                                             kernel_shape,
-                                             output_position,
-                                             strides,
-                                             padding)
+    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
+                                             output_position, strides, padding)
     for input_position in itertools.product(*input_axes_ticks):
       mask[input_position + output_position] = True
 
   return mask
 
 
-def conv_connected_inputs(input_shape,
-                          kernel_shape,
-                          output_position,
-                          strides,
+def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
+                     filters_out, data_format):
+  """Yields output-input tuples of indices in a CNN layer.
+
+  The generator iterates over all `(output_idx, input_idx)` tuples, where
+    `output_idx` is an integer index in a flattened tensor representing a single
+    output image of a convolutional layer that is connected (via the layer
+    weights) to the respective single input image at `input_idx`
+
+  Example:
+    ```python
+        >>> input_shape = (2, 2)
+        >>> kernel_shape = (2, 1)
+        >>> strides = (1, 1)
+        >>> padding = "valid"
+        >>> filters_in = 1
+        >>> filters_out = 1
+        >>> data_format = "channels_last"
+        >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
+        >>>                       filters_in, filters_out, data_format))
+        [(0, 0), (0, 2), (1, 1), (1, 3)]
+    ```
+  Args:
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
+      input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
+      receptive field.
+    strides: tuple of size N, strides along each spatial dimension.
+    padding: type of padding, string `"same"` or `"valid"`.
+    filters_in: `int`, number if filters in the input to the layer.
+    filters_out: `int', number if filters in the output of the layer.
+    data_format: string, "channels_first" or "channels_last".
+
+  Yields:
+    The next tuple `(output_idx, input_idx)`, where
+    `output_idx` is an integer index in a flattened tensor representing a single
+    output image of a convolutional layer that is connected (via the layer
+    weights) to the respective single input image at `input_idx`.
+
+  Raises:
+      ValueError: if `data_format` is neither
+      `"channels_last"` nor `"channels_first"`, or if number of strides, input,
+      and kernel number of dimensions do not match.
+
+      NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
+  """
+  if padding not in ('same', 'valid'):
+    raise NotImplementedError('Padding type %s not supported. '
+                              'Only "valid" and "same" '
+                              'are implemented.' % padding)
+
+  in_dims = len(input_shape)
+  if isinstance(kernel_shape, int):
+    kernel_shape = (kernel_shape,) * in_dims
+  if isinstance(strides, int):
+    strides = (strides,) * in_dims
+
+  kernel_dims = len(kernel_shape)
+  stride_dims = len(strides)
+  if kernel_dims != in_dims or stride_dims != in_dims:
+    raise ValueError('Number of strides, input and kernel dimensions must all '
+                     'match. Received: %d, %d, %d.' %
+                     (stride_dims, in_dims, kernel_dims))
+
+  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
+  output_axes_ticks = [range(dim) for dim in output_shape]
+
+  if data_format == 'channels_first':
+    concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
+  elif data_format == 'channels_last':
+    concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
+  else:
+    raise ValueError('Data format %s not recignized.'
+                     '`data_format` must be "channels_first" or '
+                     '"channels_last".' % data_format)
+
+  for output_position in itertools.product(*output_axes_ticks):
+    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
+                                             output_position, strides, padding)
+    for input_position in itertools.product(*input_axes_ticks):
+      for f_in in range(filters_in):
+        for f_out in range(filters_out):
+          out_idx = np.ravel_multi_index(
+              multi_index=concat_idxs(output_position, f_out),
+              dims=concat_idxs(output_shape, filters_out))
+          in_idx = np.ravel_multi_index(
+              multi_index=concat_idxs(input_position, f_in),
+              dims=concat_idxs(input_shape, filters_in))
+          yield (out_idx, in_idx)
+
+
+def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
                           padding):
   """Return locations of the input connected to an output position.
 
@@ -331,12 +419,12 @@
         [xrange(1, 3), xrange(1, 2)]
     ```
   Args:
-    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
-                 spatial shape of the input.
-    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
-                  / receptive field.
-    output_position: tuple of size N: `(p_out1, ..., p_outN)`,
-                     a single position in the output of the convolution.
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
+      input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
+      receptive field.
+    output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single position
+      in the output of the convolution.
     strides: tuple of size N, strides along each spatial dimension.
     padding: type of padding, string `"same"` or `"valid"`.
 
@@ -371,10 +459,10 @@
   Forces dimensions where input is empty (size 0) to remain empty.
 
   Args:
-    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
-                 spatial shape of the input.
-    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
-                  / receptive field.
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
+      input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
+      receptive field.
     strides: tuple of size N, strides along each spatial dimension.
     padding: type of padding, string `"same"` or `"valid"`.
 
@@ -382,11 +470,10 @@
     tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
   """
   dims = range(len(kernel_shape))
-  output_shape = [conv_output_length(input_shape[d],
-                                     kernel_shape[d],
-                                     padding,
-                                     strides[d])
-                  for d in dims]
-  output_shape = tuple([0 if input_shape[d] == 0 else output_shape[d]
-                        for d in dims])
+  output_shape = [
+      conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
+      for d in dims
+  ]
+  output_shape = tuple(
+      [0 if input_shape[d] == 0 else output_shape[d] for d in dims])
   return output_shape
diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py
index 30e59f9..f67b4df 100644
--- a/tensorflow/python/keras/utils/io_utils_test.py
+++ b/tensorflow/python/keras/utils/io_utils_test.py
@@ -84,9 +84,11 @@
     model = keras.models.Sequential()
     model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
     model.add(keras.layers.Dense(1, activation='sigmoid'))
-    model.compile(loss='binary_crossentropy', optimizer='sgd',
-                  run_eagerly=testing_utils.should_run_eagerly(),
-                  run_distributed=testing_utils.should_run_distributed())
+    model.compile(
+        loss='binary_crossentropy',
+        optimizer='sgd',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
 
     # Note: you have to use shuffle='batch' or False with HDF5Matrix
     model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 4bd65ea..62114ef 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -24,6 +24,7 @@
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.utils.conv_utils import convert_kernel
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -75,7 +76,10 @@
   Returns:
       The total number of scalars composing the weights
   """
-  return int(sum(np.prod(p.shape.as_list()) for p in set(weights)))
+  return int(
+      sum(
+          np.prod(p.shape.as_list())
+          for p in object_identity.ObjectIdentitySet(weights)))
 
 
 def print_summary(model, line_length=None, positions=None, print_fn=None):
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
index 3d34d99..24da4ad 100644
--- a/tensorflow/python/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -29,6 +29,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
 
 
@@ -107,10 +108,9 @@
     A set of tensors reachable from the inputs (includes the inputs themselves).
   """
   inputs = nest.flatten(inputs, expand_composites=True)
-  reachable = set(inputs)
-  if targets and not isinstance(targets, set):
-    targets = nest.flatten(targets)
-    targets = set(targets)
+  reachable = object_identity.ObjectIdentitySet(inputs)
+  if targets:
+    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
   queue = inputs[:]
 
   while queue:
@@ -136,10 +136,13 @@
     for y in outputs:
       if y not in reachable:
         reachable.add(y)
+        if targets:
+          remaining_targets.discard(y)
         queue.insert(0, y)
 
-    if targets and targets.issubset(reachable):
+    if targets and not remaining_targets:
       return reachable
+
   return reachable
 
 
@@ -250,10 +253,7 @@
     Structure of same type as nested, with lists wrapped/unwrapped.
   """
 
-  def _is_atomic_nested(nested):
-    """Returns `True` if `nested` is a list representing node data."""
-    if isinstance(nested, ListWrapper):
-      return True
+  def _is_serialized_node_data(nested):
     # Node data can be of form `[layer_name, node_id, tensor_id]` or
     # `[layer_name, node_id, tensor_id, kwargs]`.
     if (isinstance(nested, list) and (len(nested) in [3, 4]) and
@@ -261,12 +261,22 @@
       return True
     return False
 
+  def _is_atomic_nested(nested):
+    """Returns `True` if `nested` is a list representing node data."""
+    if isinstance(nested, ListWrapper):
+      return True
+    if _is_serialized_node_data(nested):
+      return True
+    return not nest.is_sequence(nested)
+
   def _convert_object_or_list(nested):
     """Convert b/t `ListWrapper` object and list representations."""
     if wrap:
       if isinstance(nested, ListWrapper):
         return nested
-      return ListWrapper(nested)
+      if _is_serialized_node_data(nested):
+        return ListWrapper(nested)
+      return nested
     else:
       if isinstance(nested, ListWrapper):
         return nested.as_list()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3ce1ee7..c2b3c85 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -238,7 +238,6 @@
         "no_rocm",  # TODO(rocm): feature not supported on ROCm platform
         "nomsan",  # TODO(b/131773093): Re-enable.
     ],
-    xla_enable_strict_auto_jit = True,
 )
 
 tf_py_test(
@@ -271,7 +270,6 @@
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variables",
     ],
-    tags = ["nofwdcompat"],  # b/137641346
 )
 
 tf_py_test(
@@ -984,7 +982,6 @@
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_for_generated_wrappers",
     ],
-    tags = ["nofwdcompat"],  # b/137641346
 )
 
 tf_py_test(
@@ -3123,6 +3120,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients",
@@ -3600,8 +3598,7 @@
         "//tensorflow/python:math_ops",
     ],
     shard_count = 20,
-    # TODO(b/134764123): Re-enable this test.
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 
 sycl_py_test(
@@ -3866,6 +3863,7 @@
         "//tensorflow/python:constant_op",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:control_flow_util",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_ops",
@@ -3891,6 +3889,7 @@
         "//tensorflow/python/data/experimental/ops:prefetching_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients",
diff --git a/tensorflow/python/kernel_tests/ackermann_op.cc b/tensorflow/python/kernel_tests/ackermann_op.cc
index d42ca6f..2d885b7 100644
--- a/tensorflow/python/kernel_tests/ackermann_op.cc
+++ b/tensorflow/python/kernel_tests/ackermann_op.cc
@@ -35,7 +35,7 @@
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape(), &output_tensor));
-    auto output = output_tensor->scalar<string>();
+    auto output = output_tensor->scalar<tstring>();
 
     output() = "A(m, 0) == A(m-1, 1)";
   }
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
index bbceb82..fb44c33 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -108,31 +108,31 @@
       self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
       self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
 
-  def testBasicQuantileBucketsMultipleResourcesAddFlushed(self):
+  def testBasicQuantileBucketsSingleResourcesAddFlushed(self):
     with self.cached_session():
-      quantile_accumulator_handle_0 = self.create_resource("floats_0", self.eps,
-                                                           self.max_elements, 2)
-      quantile_accumulator_handle_1 = self.create_resource("floats_1", self.eps,
-                                                           self.max_elements, 2)
+      quantile_accumulator_handle = self.create_resource("floats_0", self.eps,
+                                                         self.max_elements, 2)
       resources.initialize_resources(resources.shared_resources()).run()
       summaries = boosted_trees_ops.make_quantile_summaries(
           [self._feature_0, self._feature_1], self._example_weights,
           epsilon=self.eps)
       summary_op = boosted_trees_ops.quantile_add_summaries(
-          quantile_accumulator_handle_0, summaries)
+          quantile_accumulator_handle, summaries)
       flushed_summaries = flush_quantile_summaries(
-          quantile_accumulator_handle_0, num_features=2)
+          quantile_accumulator_handle, num_features=2)
 
       # We are testing whether the flushed summaries output at the previous step
-      # will give the same expected results by inputting it to add_summaries
+      # will give the same expected results by inputing it to add_summaries
       summary_op_2 = boosted_trees_ops.quantile_add_summaries(
-          quantile_accumulator_handle_1, flushed_summaries)
+          quantile_accumulator_handle, flushed_summaries)
+
       flush_op = boosted_trees_ops.quantile_flush(
-          quantile_accumulator_handle_1, self.num_quantiles)
+          quantile_accumulator_handle, self.num_quantiles)
       buckets = boosted_trees_ops.get_bucket_boundaries(
-          quantile_accumulator_handle_1, num_features=2)
+          quantile_accumulator_handle, num_features=2)
       quantiles = boosted_trees_ops.boosted_trees_bucketize(
           [self._feature_0, self._feature_1], buckets)
+
       self.evaluate(summary_op)
       self.evaluate(summary_op_2)
       self.evaluate(flush_op)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 32e47ef..36cc52a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -30,6 +30,7 @@
 
 _INEQUALITY_DEFAULT_LEFT = 'inequality_default_left'.encode('utf-8')
 _INEQUALITY_DEFAULT_RIGHT = 'inequality_default_right'.encode('utf-8')
+_EQUALITY_DEFAULT_LEFT = 'equality_default_left'.encode('utf-8')
 
 
 class StatsOpsTest(test_util.TensorFlowTestCase):
@@ -208,6 +209,39 @@
     self.assertAllClose([[-.076923], [-.75]], right_node_contribs)
     self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
 
+  def testCalculateBestMultiDimFeatureEqualitySplitsWithoutRegularization(self):
+    """Testing best split calculation without any regularization."""
+    node_id_range = [1, 3]  # node 1 through 2 will be processed.
+    stats_summary = np.asarray(self._get_stats_summary_for_split())
+    # reshape to [max_splits, feature_dim, num_buckets, 2]
+    stats_summary = np.moveaxis(stats_summary, 0, 1)
+
+    (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
+     right_node_contribs, split_types) = self.evaluate(
+         boosted_trees_ops.calculate_best_feature_split(
+             node_id_range,
+             stats_summary,
+             l1=0.0,
+             l2=0.0,
+             tree_complexity=0.0,
+             min_node_weight=0,
+             logits_dimension=1,
+             split_type='equality'))
+
+    self.assertAllEqual([1, 2], node_ids)
+    # 0.116495 = (-0.05)^2/0.06 + 0.36^2/0.57 - 0.31^2/0.63
+    # 0.60429 = (-0.4)^2/0.5 + 0.37^2/0.48 - 0.03^2/0.98
+    self.assertAllClose([0.116495, 0.60429], gains)
+    self.assertAllEqual([2, 2], thresholds)
+    self.assertAllEqual([1, 1], feature_dimensions)
+    # The left node contrib will be later added to the previous node value to
+    # make the left node value, and the same for right node contrib.
+    # left contrib 0.83 = 0.05/0.06, 0.8 = 0.4/0.5
+    self.assertAllClose([[0.833333], [.8]], left_node_contribs)
+    # right contrib -0.6315 = -0.36/0.57, -0.7708 = -0.37/0.48
+    self.assertAllClose([[-0.631579], [-0.770833]], right_node_contribs)
+    self.assertAllEqual([_EQUALITY_DEFAULT_LEFT] * 2, split_types)
+
   def testCalculateBestGainsWithL2(self):
     """Testing Gain calculation with L2."""
     with self.cached_session() as sess:
@@ -267,6 +301,39 @@
     self.assertAllClose([[-.043478], [-.6]], right_node_contribs)
     self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
 
+  def testCalculateMultiDimBestFeatureEqualitySplitsWithL2(self):
+    """Testing best split calculation with L2."""
+    node_id_range = [1, 3]  # node 1 through 2 will be processed.
+    stats_summary = np.asarray(self._get_stats_summary_for_split())
+    # reshape to [max_splits, feature_dim, num_buckets, 2]
+    stats_summary = np.moveaxis(stats_summary, 0, 1)
+
+    (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
+     right_node_contribs, split_types) = self.evaluate(
+         boosted_trees_ops.calculate_best_feature_split(
+             node_id_range,
+             stats_summary,
+             l1=0.0,
+             l2=0.1,
+             tree_complexity=0.0,
+             min_node_weight=0,
+             logits_dimension=1,
+             split_type='equality'))
+
+    self.assertAllEqual([1, 2], node_ids)
+    # 0.077414 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73
+    # 0.501868 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08
+    self.assertAllClose([0.077414, 0.501868], gains)
+    self.assertAllEqual([2, 2], thresholds)
+    self.assertAllEqual([1, 1], feature_dimensions)
+    # # The left node contrib will be later added to the previous node value to
+    # # make the left node value, and the same for right node contrib.
+    # left contrib 0.3125 = 0.05/0.16, 0.6667 = 0.4/0.6
+    self.assertAllClose([[0.3125], [0.666667]], left_node_contribs)
+    # right contrib -0.5373 = -0.36/0.67, -0.6379 = -0.37/0.58
+    self.assertAllClose([[-0.537313], [-0.637931]], right_node_contribs)
+    self.assertAllEqual([_EQUALITY_DEFAULT_LEFT] * 2, split_types)
+
   def testSparseCalculateBestSplitsWithL2(self):
     node_id_range = [1, 3]
     (summary_indices, summary_values,
@@ -357,6 +424,40 @@
     self.assertAllEqual([1, 1], feature_dimensions)
     self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
 
+  def testCalculateBestMultiDimFeatureEqualitySplitsWithL1(self):
+    """Testing best split calculation with L1."""
+    node_id_range = [1, 3]  # node 1 through 2 will be processed.
+    stats_summary = np.asarray(self._get_stats_summary_for_split())
+    # reshape to [max_splits, feature_dim, num_buckets, 2]
+    stats_summary = np.moveaxis(stats_summary, 0, 1)
+
+    l1 = 0.1
+    (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
+     right_node_contribs, split_types) = self.evaluate(
+         boosted_trees_ops.calculate_best_feature_split(
+             node_id_range,
+             stats_summary,
+             l1=l1,
+             l2=0.,
+             tree_complexity=0.0,
+             min_node_weight=0,
+             logits_dimension=1,
+             split_type='equality'))
+
+    self.assertAllEqual([1, 2], node_ids)
+    # 0.048597 = 0 + 0.26^2/0.57 - 0.21^2/0.63
+    # 0.501868 = 0.3^2/0.5 + 0.27^2/0.48 - 0
+    self.assertAllClose([0.048597, 0.331875], gains)
+    self.assertAllEqual([2, 2], thresholds)
+    self.assertAllEqual([1, 1], feature_dimensions)
+    # # The left node contrib will be later added to the previous node value to
+    # # make the left node value, and the same for right node contrib.
+    # left contrib 0 (-0.05>-0.1), 0.6 = 0.3/0.5
+    self.assertAllClose([[0], [0.6]], left_node_contribs)
+    # right contrib -0.45614 = -0.26/0.57, -0.5625 = -0.27/0.48
+    self.assertAllClose([[-0.45614], [-0.5625]], right_node_contribs)
+    self.assertAllEqual([_EQUALITY_DEFAULT_LEFT] * 2, split_types)
+
   def testSparseCalculateBestSplitsWithL1(self):
     node_id_range = [1, 3]
     (summary_indices, summary_values,
@@ -448,6 +549,41 @@
     self.assertAllEqual([1, 0], feature_dimensions)
     self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
 
+  def testCalculateBestMultiDimFeatureEqualitySplitsWithTreeComplexity(self):
+    """Testing best split calculation with tree complexity."""
+    node_id_range = [1, 3]  # node 1 through 2 will be processed.
+    stats_summary = np.asarray(self._get_stats_summary_for_split())
+    # reshape to [max_splits, feature_dim, num_buckets, 2]
+    stats_summary = np.moveaxis(stats_summary, 0, 1)
+
+    l2 = 0.1
+    tree_complexity = 3.
+    (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
+     right_node_contribs, split_types) = self.evaluate(
+         boosted_trees_ops.calculate_best_feature_split(
+             node_id_range,
+             stats_summary,
+             l1=0.,
+             l2=l2,
+             tree_complexity=tree_complexity,
+             min_node_weight=0,
+             logits_dimension=1,
+             split_type='equality'))
+
+    self.assertAllEqual([1, 2], node_ids)
+    # -2.922586 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73 - 3
+    # -2.498132 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08 - 3
+    self.assertAllClose([-2.922586, -2.498132], gains)
+    self.assertAllEqual([2, 2], thresholds)
+    self.assertAllEqual([1, 1], feature_dimensions)
+    # # The left node contrib will be later added to the previous node value to
+    # # make the left node value, and the same for right node contrib.
+    # left contrib 0.3125 = 0.05/0.16, 0.6667 = 0.4/0.6
+    self.assertAllClose([[0.3125], [0.666667]], left_node_contribs)
+    # right contrib -0.5373 = -0.36/0.67, -0.6379 = -0.37/0.58
+    self.assertAllClose([[-0.537313], [-0.637931]], right_node_contribs)
+    self.assertAllEqual([_EQUALITY_DEFAULT_LEFT] * 2, split_types)
+
   def testSparseCalculateBestSplitsWithTreeComplexity(self):
     """Testing best split calculation with tree complexity."""
     node_id_range = [1, 3]
@@ -723,6 +859,58 @@
          logits_dimension=1)
     self.assertAllEqual([], node_ids)
 
+  def testCalculateBestMultiDimFeatureEqualitySplitsWithNoSplitPossible(self):
+    """Testing best split calculation with min node weight and no split."""
+    node_id_range = [1, 3]  # node 1 through 2 will be processed.
+    stats_summary = np.asarray([
+        [
+            [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0; ignored
+            [[0., 0.], [.15, .36], [.06, .7], [.1, .2]],  # node 1
+            [[0., 0.], [-.33, .068], [0., 0.], [.3, .04]],  # node 2
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+        ],  # feature 0
+        [
+            [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0; ignored
+            [[0., 0.], [.3, .5], [-.05, .06], [.06, .7]],  # node 1
+            [[.1, .1], [.2, -.05], [-.4, .05], [.07, .08]],  # node 2
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+        ],  # feature 1
+    ])  # num_features * shape=[max_splits, num_buckets, 2]
+    # reshape to [max_splits, feature_dim, num_buckets, 2]
+    stats_summary = np.moveaxis(stats_summary, 0, 1)
+
+    (node_ids, _, _, _, _, _,
+     _) = boosted_trees_ops.calculate_best_feature_split(
+         node_id_range,
+         stats_summary,
+         l1=0.0,
+         l2=0.0,
+         tree_complexity=0.0,
+         min_node_weight=1,
+         logits_dimension=1,
+         split_type='equality')
+
+    # We can't split either of the nodes on the first feature
+    self.assertAllEqual([1], node_ids)
+
+    # Now check when we can't split on any feature
+    (node_ids, _, _, _, _, _,
+     _) = boosted_trees_ops.calculate_best_feature_split(
+         node_id_range,
+         stats_summary,
+         l1=0.0,
+         l2=0.0,
+         tree_complexity=0.0,
+         min_node_weight=10,
+         logits_dimension=1)
+    self.assertAllEqual([], node_ids)
+
   def testSparseCalculateBestSplitsWithMinNodeWeightNoSplitOnFeature(self):
     """Testing best split calculation with min node weight and no split."""
     node_id_range = [1, 3]  # node 1 through 2 will be processed.
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 53dd065..06dd8d7 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -40,6 +40,9 @@
 from tensorflow.python.platform import test
 
 
+# pylint: disable=g-error-prone-assert-raises
+
+
 class AssertV2Asserts(test.TestCase):
 
   def test_passes_when_it_should(self):
@@ -308,6 +311,15 @@
       out = array_ops.identity(larry)
     self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_noop_when_both_identical(self):
+    larry = constant_op.constant([])
+    check_op = check_ops.assert_equal(larry, larry)
+    if context.executing_eagerly():
+      self.assertIs(check_op, None)
+    else:
+      self.assertEqual(check_op.type, "NoOp")
+
 
 class AssertNoneEqualTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 5a40390..abec3bf 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -56,7 +56,8 @@
     with self.session(graph=ops.get_default_graph()) as sess:
       pred = array_ops.placeholder(dtypes.bool, name="pred")
 
-      expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
+      expected = control_flow_ops.cond(
+          array_ops.squeeze_v2(pred), true_fn, false_fn, name="expected")
       actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")
 
       expected_grad = gradients_impl.gradients(expected, train_vals)
@@ -69,6 +70,13 @@
       self.assertEqual(expected_val, actual_val)
       self.assertEqual(expected_grad_val, actual_grad_val)
 
+      sess_run_args = {pred: [[True]]}
+      sess_run_args.update(feed_dict)
+      expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
+          (expected, actual, expected_grad, actual_grad), sess_run_args)
+      self.assertEqual(expected_val, actual_val)
+      self.assertEqual(expected_grad_val, actual_grad_val)
+
       sess_run_args = {pred: False}
       sess_run_args.update(feed_dict)
       expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
@@ -76,6 +84,13 @@
       self.assertEqual(expected_val, actual_val)
       self.assertEqual(expected_grad_val, actual_grad_val)
 
+      sess_run_args = {pred: [[False]]}
+      sess_run_args.update(feed_dict)
+      expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
+          (expected, actual, expected_grad, actual_grad), sess_run_args)
+      self.assertEqual(expected_val, actual_val)
+      self.assertEqual(expected_grad_val, actual_grad_val)
+
   @test_util.run_deprecated_v1
   def testBasic(self):
     x = constant_op.constant(1.0, name="x")
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 37afb32..148fde1 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -39,47 +39,12 @@
 
 class ConditionalAccumulatorTest(test.TestCase):
 
-  def testConstructor(self):
-    with ops.Graph().as_default():
-      q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
-    self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals(
-        """
-      name:'Q' op:'ConditionalAccumulator'
-      attr { key: 'dtype' value { type: DT_FLOAT } }
-      attr { key: 'shape' value { shape { unknown_rank: true} } }
-      attr { key: 'container' value { s: '' } }
-      attr { key: 'shared_name' value { s: '' } }
-      attr { key: 'reduction_type' value {s: 'MEAN'} }
-      """, q.accumulator_ref.op.node_def)
-
   def testConstructorWithInvalidArg(self):
     with ops.Graph().as_default():
       with self.assertRaises(ValueError):
         data_flow_ops.ConditionalAccumulator(
             dtypes_lib.float32, name="Q", reduction_type="Invalid")
 
-  def testConstructorWithShape(self):
-    with ops.Graph().as_default():
-      q = data_flow_ops.ConditionalAccumulator(
-          dtypes_lib.float32,
-          name="Q",
-          shape=tensor_shape.TensorShape([1, 5, 2, 8]))
-    self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals(
-        """
-      name:'Q' op:'ConditionalAccumulator'
-      attr { key: 'dtype' value { type: DT_FLOAT } }
-      attr { key: 'shape' value { shape { dim {size: 1 }
-                                          dim {size: 5 }
-                                          dim {size: 2 }
-                                          dim {size: 8 }
-      } } }
-      attr { key: 'container' value { s: '' } }
-      attr { key: 'shared_name' value { s: '' } }
-      attr { key: 'reduction_type' value {s: 'MEAN'} }
-      """, q.accumulator_ref.op.node_def)
-
   @test_util.run_deprecated_v1
   def testAccumulatorSizeEmpty(self):
     with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 9bc9f30..007c3f2 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -22,6 +22,7 @@
 
 import collections
 import math
+import re
 import sys
 import time
 
@@ -43,6 +44,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -431,8 +433,8 @@
   @test_util.run_v1_only("b/120545219")
   def testCondIndexedSlices(self):
     with self.cached_session():
-      values = constant_op.constant(10)
-      indices = constant_op.constant(0)
+      values = constant_op.constant([10])
+      indices = constant_op.constant([0])
       x = ops.IndexedSlices(values, indices)
       pred = math_ops.less(1, 2)
       fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
@@ -441,14 +443,14 @@
 
       val = r.values
       ind = r.indices
-    self.assertAllEqual(11, val)
-    self.assertAllEqual(0, ind)
+    self.assertAllEqual([11], val)
+    self.assertAllEqual([0], ind)
 
   def testCondMismatchedIndexedSlices(self):
     @def_function.function
     def foo():
-      values = constant_op.constant(10)
-      indices = constant_op.constant(0)
+      values = constant_op.constant([10])
+      indices = constant_op.constant([0])
       x = ops.IndexedSlices(values, indices)
       with self.assertRaisesRegexp(
           TypeError, "Cannot reconcile tf.cond 0-th outputs"):
@@ -517,9 +519,9 @@
   @test_util.run_v1_only("b/120545219")
   def testCondIndexedSlicesDifferentTypes(self):
     with self.cached_session():
-      values = constant_op.constant(10)
-      i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
-      i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
+      values = constant_op.constant([10])
+      i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32)
+      i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64)
       x = ops.IndexedSlices(values, i_32)
       pred = math_ops.less(1, 2)
       fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
@@ -528,8 +530,8 @@
 
       val = r.values
       ind = r.indices
-    self.assertAllEqual(11, val)
-    self.assertAllEqual(0, ind)
+    self.assertAllEqual([11], val)
+    self.assertAllEqual([0], ind)
     self.assertTrue(ind.dtype == np.int64)
 
   @test_util.run_v1_only("b/120545219")
@@ -1006,6 +1008,79 @@
       self.assertAllEqual(1.0, self.evaluate(grad))
 
   @test_util.run_deprecated_v1
+  @test_util.enable_control_flow_v2
+  def testCondComputeGradAfterSessRunFails(self):
+    with self.cached_session():
+      x = constant_op.constant(10.0, name="x")
+      pred = math_ops.less(1, 2)
+
+      def true_fn():
+        a = x * x
+        return a * a
+
+      def false_fn():
+        return x * x
+
+      r = control_flow_ops.cond(pred, true_fn, false_fn)
+
+      self.assertAllEqual(r, 10000.)
+      grad = gradients_impl.gradients(r, [x])[0]
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          r"Connecting to invalid output 1 of source node cond which has 1 "
+          r"outputs. Try using "
+          "tf.compat.v1.experimental.output_all_intermediates\(True\)."):
+        self.evaluate(grad)
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testCondComputeGradAfterSessRun(self):
+    with self.cached_session():
+      x = constant_op.constant(10.0, name="x")
+      pred = math_ops.less(1, 2)
+
+      def true_fn():
+        a = x * x
+        return a * a
+
+      def false_fn():
+        return x * x
+
+      r = control_flow_ops.cond(pred, true_fn, false_fn)
+
+      self.assertAllEqual(r, 10000.)
+      grad = gradients_impl.gradients(r, [x])[0]
+      self.assertAllEqual(grad, 4000.)
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testNestedCondComputeGradAfterSessRun(self):
+    with self.cached_session():
+      x = constant_op.constant(10.0, name="x")
+      pred = math_ops.less(1, 2)
+
+      def true_fn():
+
+        def inner_true_fn():
+          a = x * x
+          return a * a
+
+        def inner_false_fn():
+          return x * x
+
+        return control_flow_ops.cond(
+            constant_op.constant(True), inner_true_fn, inner_false_fn)
+
+      def false_fn():
+        return x * x
+
+      r = control_flow_ops.cond(pred, true_fn, false_fn)
+
+      self.assertAllEqual(r, 10000.)
+      grad = gradients_impl.gradients(r, [x])[0]
+      self.assertAllEqual(grad, 4000.)
+
+  @test_util.run_deprecated_v1
   def testCondGrad_2(self):
     with self.cached_session():
       c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -1566,8 +1641,8 @@
     if control_flow_util.ENABLE_CONTROL_FLOW_V2:
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
-      with self.assertRaisesRegexp(
-          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
+      with self.assertRaisesRegexp(ValueError,
+                                   r"must be from the same graph.*"):
         loop = create_while_loop()
       xla_context.Exit()
     else:
@@ -1653,14 +1728,17 @@
         for dev in run_metadata_without_xla_context.step_stats.dev_stats:
           if "/device:CPU" in dev.device:
             node_stats = dev.node_stats
-        stack_push_op = "TensorListPushBack"
+        stack_push_count = len([
+            x for x in node_stats
+            if re.match(r".*TensorListPushBack_?\d*", x.node_name)
+        ])
       else:
         for dev in run_metadata.step_stats.dev_stats:
           if "/device:CPU" in dev.device:
             node_stats = dev.node_stats
         stack_push_op = "StackPushV2"
-      stack_push_count = len(
-          [x for x in node_stats if x.node_name.endswith(stack_push_op)])
+        stack_push_count = len(
+            [x for x in node_stats if x.node_name.endswith("StackPushV2")])
       # Pushes to the stack = product of maximum_iterations values;
       # the last two "3"s comes from size(p), when p == [0, 0, 0].
       self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))
@@ -2117,6 +2195,50 @@
     self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None]))
     self.assertTrue(r.values.values.shape.as_list() in ([49], [None]))
 
+  def testWhileShapeInvariantTensorSpec(self):
+    i = constant_op.constant(0)
+    x = constant_op.constant([1])
+    c = lambda i, _: i < 10
+    b = lambda i, x: (i + 1, array_ops.stack([x, x]))
+    shape_invariants = [
+        tensor_spec.TensorSpec([], dtype=dtypes.int32),
+        tensor_spec.TensorSpec(None, dtype=dtypes.int32)]
+    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
+
+  # TODO(b/131265085) Remove this decorator when bug is fixed.
+  @test_util.build_as_function_and_v1_graph
+  def testWhileShapeInvariantWrongTypeSpecType(self):
+    c = lambda i, _: i < 10
+    b = lambda i, x: (i + 1, x)
+    i = constant_op.constant(0)
+    x = sparse_tensor.SparseTensor([[0]], [1.0], [10])
+    shape_invariants = [
+        tensor_spec.TensorSpec([], dtype=dtypes.int32),
+        sparse_tensor.SparseTensorSpec([None])]
+    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
+
+    x2 = constant_op.constant([1])
+    with self.assertRaises(TypeError):
+      control_flow_ops.while_loop(c, b, [i, x2], shape_invariants)
+
+    x3 = ragged_factory_ops.constant([[1, 2], [3]])
+    with self.assertRaises(TypeError):
+      control_flow_ops.while_loop(c, b, [i, x3], shape_invariants)
+
+    i2 = constant_op.constant(0.0)
+    with self.assertRaises(TypeError):
+      control_flow_ops.while_loop(c, b, [i2, x], shape_invariants)
+
+  # TODO(b/131265085) Remove this decorator when bug is fixed.
+  @test_util.build_as_function_and_v1_graph
+  def testWhileShapeInvariantBadType(self):
+    i = constant_op.constant(0)
+    x = constant_op.constant([1])
+    c = lambda i, _: i < 10
+    b = lambda i, x: (i + 1, x)
+    with self.assertRaises((ValueError, TypeError)):
+      control_flow_ops.while_loop(c, b, [i, x], ["foo", "bar"])
+
   def _testNestedWhile_1(self, use_gpu):
     with self.cached_session(use_gpu=use_gpu):
       n = constant_op.constant(0)
@@ -2739,6 +2861,34 @@
       r = gradients_impl.gradients([r, y], x)[0]
       self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
 
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testWhileGradAfterSessionRun(self):
+    v0 = constant_op.constant(2.)
+    r = control_flow_ops.while_loop(
+        lambda _: True, lambda v: v * v, [v0], maximum_iterations=3)
+
+    self.assertAllEqual(r, 256.)
+    grad = gradients_impl.gradients(r, v0)[0]
+    self.assertAllClose(grad, 1024.)
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testNestedWhileGradAfterSessionRun(self):
+    v0 = constant_op.constant(2.)
+
+    def body(v):
+      inner_v0 = constant_op.constant(1.)
+      return control_flow_ops.while_loop(
+          lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2)
+
+    r = control_flow_ops.while_loop(
+        lambda _: True, body, [v0], maximum_iterations=3)
+
+    self.assertAllEqual(r, 256.)
+    grad = gradients_impl.gradients(r, v0)[0]
+    self.assertAllClose(grad, 1024.)
+
   @test_util.run_v1_only("b/120545219")
   def testWhileGrad_MultipleUses(self):
     with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index 608ee57..cfb6088 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import math
 
 import numpy as np
@@ -32,7 +31,7 @@
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
+from tensorflow.python.util.compat import collections_abc
 
 
 def GetTestConfigs():
@@ -82,7 +81,7 @@
       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
       t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
 
-      if isinstance(stride, collections.Iterable):
+      if isinstance(stride, collections_abc.Iterable):
         strides = [1] + list(stride) + [1]
       else:
         strides = [1, stride, stride, stride, 1]
@@ -140,7 +139,7 @@
     with self.cached_session(use_gpu=use_gpu):
       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
       t2 = constant_op.constant(x2, shape=filter_in_sizes)
-      if isinstance(stride, collections.Iterable):
+      if isinstance(stride, collections_abc.Iterable):
         strides = list(stride)
       else:
         strides = [stride, stride, stride]
@@ -317,6 +316,33 @@
         padding="SAME",
         expected=expected_output)
 
+  def _TestConv3DEmptyTensorOutputShape(self):
+    """Verifies the output shape of the Conv3D op when output tensor is empty.
+
+    Args: none
+    """
+    input_shape = [0, 112, 112, 112, 32]
+    filter_shape = [3, 3, 3, 32, 64]
+
+    output_shape = [0, 112, 112, 112, 64]
+    input_data = 1
+    filter_data = 1
+    for data_type in self._DtypesToTest(False):
+      input_tensor = constant_op.constant(
+          input_data, shape=input_shape, dtype=data_type, name="input")
+      filter_tensor = constant_op.constant(
+          filter_data, shape=filter_shape, dtype=data_type, name="filter")
+      conv = nn_ops.conv3d(
+          input_tensor,
+          filter_tensor,
+          strides=[1, 1, 1, 1, 1],
+          dilations=[1, 1, 1, 1, 1],
+          padding="SAME",
+          data_format="NDHWC",
+          name="conv")
+      values = self.evaluate(conv)
+      self.assertEqual(values.shape, tensor_shape.TensorShape(output_shape))
+
   def testKernelSmallerThanStride(self):
     expected_output = [
         0.03703704, 0.11111111, 0.25925926, 0.33333333, 0.7037037, 0.77777778,
@@ -380,7 +406,7 @@
         filter_planes, filter_rows, filter_cols, in_depth, out_depth
     ]
 
-    if isinstance(stride, collections.Iterable):
+    if isinstance(stride, collections_abc.Iterable):
       strides = [1] + list(stride) + [1]
     else:
       strides = [1, stride, stride, stride, 1]
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index b6c2295..b51d8ae 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import os
 import time
 
@@ -47,6 +46,7 @@
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.util.compat import collections_abc
 
 
 def GetShrunkInceptionShapes(shrink=10):
@@ -269,7 +269,7 @@
     with test_util.device(use_gpu):
       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
       t2 = constant_op.constant(x2, shape=filter_in_sizes)
-      if isinstance(stride, collections.Iterable):
+      if isinstance(stride, collections_abc.Iterable):
         strides = list(stride)
       else:
         strides = [stride, stride]
diff --git a/tensorflow/python/kernel_tests/critical_section_test.py b/tensorflow/python/kernel_tests/critical_section_test.py
index 326820f..5e515e1 100644
--- a/tensorflow/python/kernel_tests/critical_section_test.py
+++ b/tensorflow/python/kernel_tests/critical_section_test.py
@@ -30,6 +30,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import critical_section_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -63,10 +64,12 @@
   @parameterized.named_parameters(
       ("Inner%sOuter%s" % (inner, outer), inner, outer)
       for (inner, outer) in itertools.product(*([(False, True)] * 2)))
-  @test_util.disable_control_flow_v2("b/135070612")
   @test_util.run_in_graph_and_eager_modes
   @test_util.xla_allow_fallback("b/128495870")
   def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
+    if (not context.executing_eagerly() and
+        control_flow_v2_toggles.control_flow_v2_enabled()):
+      self.skipTest("b/135070612")
     cs = critical_section_ops.CriticalSection(shared_name="cs")
     v = resource_variable_ops.ResourceVariable(0.0, name="v")
     num_concurrent = 100
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index a9391cd..ec70b5d 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -315,7 +315,7 @@
       self.assertAllEqual(v_diag.eval(), mat)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # {Sub,Super}diagonals.
@@ -343,7 +343,7 @@
       self.assertAllEqual(v_batch_diag.eval(), mat_batch)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # {Sub,Super}diagonals.
@@ -378,7 +378,7 @@
   @test_util.run_deprecated_v1
   def testRectangularBatch(self):
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
       with self.cached_session(use_gpu=True):
@@ -490,7 +490,7 @@
         self.assertLess(error, 1e-4)
 
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
       # {Sub,super}diagonals/band.
@@ -522,7 +522,7 @@
       self.assertAllEqual(mat_set_diag, self.evaluate(output))
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands.
@@ -554,7 +554,7 @@
       self.assertAllEqual(expected, self.evaluate(output))
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands.
@@ -585,7 +585,7 @@
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands.
@@ -622,7 +622,7 @@
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands.
@@ -656,7 +656,7 @@
         array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         d = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -675,7 +675,7 @@
           np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
         y = array_ops.matrix_set_diag(x, x_diag, k=diags)
       else:
@@ -697,7 +697,7 @@
     diag_bands = [(0, 0)]
 
     # LINT.IfChange
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
       diag_bands.append((-1, 1))
     for input_shape, diags in itertools.product(input_shapes, diag_bands):
@@ -740,7 +740,7 @@
       self.assertAllEqual(mat_diag.eval(), v)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
         for offset in [-2, 3]:
           mat = np.diag(v, offset)
@@ -767,7 +767,7 @@
       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands.
@@ -790,7 +790,7 @@
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands with padding_value.
@@ -825,7 +825,7 @@
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
       # LINT.IfChange
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
       # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
 
         # Diagonal bands with padding_value.
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index f23b7d3..031389c 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -343,7 +343,7 @@
     result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
     self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
       result = array_ops.gather(
           params, indices, axis=axis, batch_dims=batch_dims)
 
@@ -443,7 +443,7 @@
     self.assertAllEqual(output_shape, result.shape.as_list())
     self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
       result = array_ops.gather(
           params, indices, axis=axis, batch_dims=batch_dims)
 
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 1d935ee..4b9681a 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -537,13 +537,6 @@
         math_ops.range(
             0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64)
 
-  def testMixedDType(self):
-    # Test case for GitHub issue 29867
-    with self.cached_session(use_gpu=True):
-      tf_ans = math_ops.range(constant_op.constant(5), dtype=dtypes.float32)
-      self.assertAllEqual(
-          self.evaluate(tf_ans), np.arange(np.int32(5), dtype=np.float32))
-
 
 # TODO(vrv): move to sequence_ops_test?
 class LinSpaceTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 4c54ec6..ab0384f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -246,7 +246,7 @@
   # Skip Cholesky since we are explicitly testing non-hermitian
   # spectra.
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     return ["cholesky"]
 
   def operator_and_matrix(
@@ -533,7 +533,7 @@
     return [dtypes.complex64, dtypes.complex128]
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     return ["cholesky"]
 
   def operator_and_matrix(
@@ -682,7 +682,7 @@
       self.assertEqual(operator.dtype, dtypes.complex64)
       matrix = operator.to_dense().eval()
       self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
-      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
+      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)
 
   @test_util.run_deprecated_v1
   def test_defining_spd_operator_by_taking_real_part(self):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
index 2321a8c..ba611a4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
@@ -44,7 +44,7 @@
     self._rtol[dtypes.complex64] = 1e-4
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     # Cholesky not implemented.
     return ["cholesky"]
 
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_householder_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_householder_test.py
index b333dbf..4179d45 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_householder_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_householder_test.py
@@ -46,7 +46,7 @@
         shape_info((2, 1, 4, 4))]
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     # This linear operator is never positive definite.
     return ["cholesky"]
 
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 5c89607..c438187 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -181,7 +181,7 @@
   """A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     return ["cholesky"]
 
   _use_diag_update = True
@@ -224,7 +224,7 @@
   """A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     return ["cholesky"]
 
   _use_diag_update = False
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index 02ce5b8..71d24e3 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -34,7 +34,7 @@
   """Most tests done in the base class LinearOperatorDerivedClassTest."""
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     # Cholesky does not make sense for triangular matrices.
     return ["cholesky"]
 
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_toeplitz_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_toeplitz_test.py
index 22ae26f..dececb8 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_toeplitz_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_toeplitz_test.py
@@ -61,7 +61,7 @@
     self._rtol[dtypes.complex128] = 1e-10
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     # Skip solve tests, as these could have better stability
     # (currently exercises the base class).
     # TODO(srvasude): Enable these when solve is implemented.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
index 49bbc69..086f5ee 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
@@ -37,7 +37,7 @@
   """Most tests done in the base class LinearOperatorDerivedClassTest."""
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     return [
         "cholesky", "log_abs_det", "inverse", "solve", "solve_with_broadcast"]
 
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index a8eda0f..4f5fed9 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -188,6 +188,7 @@
     self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=False)
     self._verifySolve(
         np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=True, batch_dims=[3, 2])
+    self._verifySolve(np.empty([0, 0]), np.empty([0, 0]), lower=True)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
index 347e092..d5331dc 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -205,6 +205,26 @@
         padding="VALID",
         expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
 
+  def _MaxPool3DEmptyTensorOutputShape(self):
+    """Verifies the output shape of the max pooling function when tensor is empty.
+
+    Args: none
+    """
+    input_sizes = [0, 112, 112, 112, 64]
+
+    input_data = 1
+    input_tensor = constant_op.constant(
+        input_data, shape=input_sizes, name="input")
+    max_pool_3d = nn_ops.max_pool3d(
+        input_tensor,
+        ksize=[2, 2, 2],
+        strides=[2, 2, 2],
+        padding="VALID",
+        data_format="NDHWC",
+        name="max_pool_3d")
+    values = self.evaluate(max_pool_3d)
+    self.assertEqual(values.shape, (0, 56, 56, 56, 64))
+
   def _ConstructAndTestGradientForConfig(self,
                                          pool_func,
                                          input_sizes,
diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD
index 5a0f6a9..ff86609 100644
--- a/tensorflow/python/kernel_tests/proto/BUILD
+++ b/tensorflow/python/kernel_tests/proto/BUILD
@@ -2,9 +2,8 @@
 
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 # Placeholder for Google-internal load statements.
 
 package(
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 124c586..6bf60bc 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -118,6 +118,12 @@
       self.assertAllEqual(variable.numpy(), 1.0)
       self.assertAllEqual(variable.initialized_value().numpy(), 1.0)
 
+  def testInitializeVariableUsingInitializedValue(self):
+    var1 = resource_variable_ops.ResourceVariable(1.0, name="var1")
+    var2 = resource_variable_ops.ResourceVariable(var1.initialized_value(),
+                                                  name="var2")
+    self.assertAllEqual(var2.initialized_value(), 1.0)
+
   def testEagerBool(self):
     with context.eager_mode():
       v = resource_variable_ops.ResourceVariable(False, name="bool_test")
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 2b2402d..d29c533 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -1014,11 +1015,15 @@
                 inputs[0]: input_value
             })
 
+      comparison_fn = self.assertAllEqual
+      if (test_util.is_xla_enabled() and
+          control_flow_v2_toggles.control_flow_v2_enabled()):
+        comparison_fn = self.assertAllClose
       if in_graph_mode:
-        self.assertAllEqual(outputs_static, outputs_dynamic)
+        comparison_fn(outputs_static, outputs_dynamic)
       else:
         self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
-      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
+      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))
 
   @test_util.run_in_graph_and_eager_modes
   def testDynamicRNNWithNestedTupleStates(self):
@@ -1101,13 +1106,17 @@
                 inputs[0]: input_value
             })
 
+      comparison_fn = self.assertAllEqual
+      if (test_util.is_xla_enabled() and
+          control_flow_v2_toggles.control_flow_v2_enabled()):
+        comparison_fn = self.assertAllClose
       if in_graph_mode:
-        self.assertAllEqual(outputs_static, outputs_dynamic)
+        comparison_fn(outputs_static, outputs_dynamic)
       else:
         self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
         state_static = nest.flatten(state_static)
         state_dynamic = nest.flatten(state_dynamic)
-      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
+      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))
 
   def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
     time_steps = 8
@@ -1164,10 +1173,6 @@
             cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
 
       if in_graph_mode:
-        # Generate gradients and run sessions to obtain outputs
-        feeds = {concat_inputs: input_values}
-        # Initialize
-        variables_lib.global_variables_initializer().run(feed_dict=feeds)
         # Generate gradients of sum of outputs w.r.t. inputs
         static_gradients = gradients_impl.gradients(
             outputs_static + [state_static], [concat_inputs])
@@ -1186,6 +1191,10 @@
             gradients_impl.gradients(y, trainable_variables)
             for y in [outputs_static[0], outputs_static[-1], state_static]
         ])
+        # Generate gradients and run sessions to obtain outputs
+        feeds = {concat_inputs: input_values}
+        # Initialize
+        variables_lib.global_variables_initializer().run(feed_dict=feeds)
         # Test forward pass
         values_static = sess.run(outputs_static, feed_dict=feeds)
         (state_value_static,) = sess.run((state_static,), feed_dict=feeds)
@@ -1229,10 +1238,6 @@
         split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)
 
       if in_graph_mode:
-        feeds = {concat_inputs: input_values}
-
-        # Initialize
-        variables_lib.global_variables_initializer().run(feed_dict=feeds)
 
         # Generate gradients of sum of outputs w.r.t. inputs
         dynamic_gradients = gradients_impl.gradients(
@@ -1260,6 +1265,11 @@
             ]
         ])
 
+        feeds = {concat_inputs: input_values}
+
+        # Initialize
+        variables_lib.global_variables_initializer().run(feed_dict=feeds)
+
         # Test forward pass
         values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
         (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 8f72452..258b39b 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -348,8 +348,8 @@
     # Tensor from 0 to infinity.  This test ensures that this
     # unintended behavior is prevented.
     c = constant_op.constant(5.0)
-    with self.assertRaisesWithPredicateMatch(
-        TypeError, lambda e: "Tensor objects are only iterable" in str(e)):
+    with self.assertRaisesRegex(errors_impl.OperatorNotAllowedInGraphError,
+                                "iterating over `tf.Tensor`"):
       for _ in c:
         pass
 
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 8f0842f..79a7efa 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -24,6 +24,7 @@
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -36,9 +37,7 @@
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import app
 from tensorflow.python.platform import test
@@ -192,7 +191,7 @@
 
   @test_util.run_deprecated_v1
   def testGradient(self):
-    with self.session(use_gpu=True):
+    with self.session(use_gpu=True) as sess:
       l = constant_op.constant([3, 0, 1], name="l")
       f = constant_op.constant(
           [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
@@ -202,26 +201,48 @@
       x = nn_ops.sparse_softmax_cross_entropy_with_logits(
           labels=l, logits=f, name="xent")
       err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
+
+      # Check that no extra computation performed. When only first derivative is
+      # requested, second derivative must not be computed. So when there is no
+      # second derivative, there is no `BatchMatMul` op in the graph.
+      op_names = [
+          op.op_def.name for op in sess.graph.get_operations() if op.op_def
+      ]
+      self.assertNotIn("BatchMatMul", op_names)
+      self.assertNotIn("BatchMatMulV2", op_names)
+
     print("cross entropy gradient err = ", err)
     self.assertLess(err, 5e-8)
 
   @test_util.run_deprecated_v1
   def testSecondGradient(self):
-    images_placeholder = array_ops.placeholder(dtypes.float32, shape=(3, 2))
-    labels_placeholder = array_ops.placeholder(dtypes.int32, shape=(3))
-    weights = variables.Variable(random_ops.truncated_normal([2], stddev=1.0))
-    weights_with_zeros = array_ops.stack([array_ops.zeros([2]), weights],
-                                         axis=1)
-    logits = math_ops.matmul(images_placeholder, weights_with_zeros)
-    cross_entropy = nn_ops.sparse_softmax_cross_entropy_with_logits(
-        labels=labels_placeholder, logits=logits)
-    loss = math_ops.reduce_mean(cross_entropy)
+    with self.session() as sess:
+      l = constant_op.constant([3, 0, 1], name="l")
+      f = constant_op.constant(
+          [0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
+          shape=[3, 4],
+          dtype=dtypes.float64,
+          name="f")
+      x = nn_ops.sparse_softmax_cross_entropy_with_logits(
+          labels=l, logits=f, name="xent")
 
-    # Taking ths second gradient should fail, since it is not
-    # yet supported.
-    with self.assertRaisesRegexp(LookupError,
-                                 "explicitly disabled"):
-      _ = gradients_impl.hessians(loss, [weights])
+      gradients = gradients_impl.gradients(x, [f])[0]
+      err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
+                                                    [3, 4])
+
+      # Check that second derivative is calculated.
+      # (it is equivalent to being `BatchMatMul` op in the graph because of
+      # implementation of xentropy grad)
+      op_names = [
+          op.op_def.name for op in sess.graph.get_operations() if op.op_def
+      ]
+      if compat.forward_compatible(2019, 4, 25):
+        self.assertIn("BatchMatMulV2", op_names)
+      else:
+        self.assertIn("BatchMatMul", op_names)
+
+    print("cross entropy hessian err = ", err)
+    self.assertLess(err, 5e-8)
 
   def _testHighDim(self, features, labels):
     np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 1cdfdf0..68bf532 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1747,6 +1747,38 @@
       self.assertAllEqual(v0, -3)
       self.assertAllEqual(v1, 100)
 
+  def testInferShapeFalseValid(self):
+    ta = tensor_array_ops.TensorArray(
+        dtypes.float32, size=3, infer_shape=False, element_shape=[None, 10, 20])
+    ta = ta.write(0, array_ops.ones([50, 10, 20]))
+    ta = ta.write(1, array_ops.ones([50, 10, 20]))
+    ta = ta.write(2, array_ops.ones([1, 10, 20]))
+    ta = ta.concat()
+
+    correct = np.ones([101, 10, 20])
+
+    self.assertAllEqual(ta, correct)
+
+  def testInferShapeFalseInvalid(self):
+    ta = tensor_array_ops.TensorArray(
+        dtypes.float32, size=2, infer_shape=False, element_shape=[None, 10, 20])
+    ta = ta.write(0, array_ops.ones([50, 10, 20]))
+
+    with self.assertRaises(ValueError):
+      ta = ta.write(1, array_ops.ones([1, 20, 20]))
+
+  def testInferShapeTrue(self):
+    ta = tensor_array_ops.TensorArray(
+        dtypes.float32, size=3, infer_shape=True, element_shape=[None, 10, 20])
+    self.assertAllEqual((None, 10, 20), ta.element_shape.as_list())
+    ta = ta.write(0, array_ops.ones([50, 10, 20]))
+    self.assertAllEqual((50, 10, 20), ta.element_shape.as_list())
+    ta = ta.write(1, array_ops.ones([50, 10, 20]))
+    with self.assertRaises(ValueError):
+      ta = ta.write(
+          2, array_ops.ones([1, 10, 20])
+      )  # Inconsistent shapes: saw (1, 10, 20) but expected (50, 10, 20)
+
 
 class TensorArrayBenchmark(test.Benchmark):
 
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index f203263..dce5a2a 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -88,6 +88,28 @@
     for i in range(len(x)):
       self.assertEqual(x[i], tf_y[tf_idx[i]])
 
+  def testBool(self):
+    x = np.random.choice([True, False], size=7000)
+    with self.cached_session() as sess:
+      y, idx = array_ops.unique(x)
+      tf_y, tf_idx = self.evaluate([y, idx])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+
+  def testBoolV2(self):
+    x = np.random.choice([True, False], size=7000)
+    with self.cached_session() as sess:
+      y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32))
+      tf_y, tf_idx = self.evaluate([y, idx])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+
 
 class UniqueWithCountsTest(test.TestCase):
 
@@ -166,6 +188,33 @@
     for value, count in zip(tf_y, tf_count):
       self.assertEqual(count, np.sum(x == value))
 
+  def testBool(self):
+    x = np.random.choice([True, False], size=7000)
+    with self.cached_session() as sess:
+      y, idx, count = array_ops.unique_with_counts(x)
+      tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+    for value, count in zip(tf_y, tf_count):
+      self.assertEqual(count, np.sum(x == value))
+
+  def testBoolV2(self):
+    x = np.random.choice([True, False], size=7000)
+    with self.cached_session() as sess:
+      y, idx, count = gen_array_ops.unique_with_counts_v2(
+          x, axis=np.array([], np.int32))
+      tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+    for value, count in zip(tf_y, tf_count):
+      self.assertEqual(count, np.sum(x == value))
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index 7a465dc..4af49bf 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -22,11 +22,13 @@
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_util
 from tensorflow.python.eager import def_function
+from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_toggles
+from tensorflow.python.ops import random_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import meta_graph
@@ -34,6 +36,8 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
@@ -44,6 +48,27 @@
 from tensorflow.python.platform import test
 
 
+def random_gamma(shape):  # pylint: disable=invalid-name
+  return random_ops.random_gamma(shape, 1.0)
+
+
+def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
+  return random_ops.random_gamma(
+      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])
+
+
+def random_poisson_v2(shape):  # pylint: disable=invalid-name
+  return random_ops.random_poisson_v2(shape, 1.0)
+
+
+def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
+  return random_ops.random_poisson_v2(shape, [12.2, 3.3])
+
+
+def fill(shape):  # pylint: disable=invalid-name
+  return array_ops.fill(shape, 1.0)
+
+
 class WhileV2Test(test.TestCase, parameterized.TestCase):
 
   @test_util.run_deprecated_v1
@@ -137,6 +162,27 @@
       self.assertSequenceEqual(self.evaluate(grad), [9.])
 
   @test_util.run_deprecated_v1
+  def testMultipleLoopNonscalarCond(self):
+    x = constant_op.constant([[5.]])
+    y = constant_op.constant(3.)
+
+    # x = 5.
+    # y = 3.
+    # while x < 45.:
+    #   x = x * y
+    ret = while_loop_v2(
+        lambda v, _: v < 45.,
+        lambda v, w: (v * w, w), [x, y],
+        return_same_structure=False)
+    # ret == [x*y^2, y]
+
+    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
+    with self.cached_session():
+      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
+      self.assertSequenceEqual(self.evaluate(grad), [9.])
+
+  @test_util.run_deprecated_v1
   def testMultipleLoopVars(self):
     x = constant_op.constant(5.)
     y = constant_op.constant(3.)
@@ -194,6 +240,115 @@
       self.assertSequenceEqual(self.evaluate(grad), [32.])
       self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
 
+  def testMultipleWhileLoopsWithFunc(self):
+    if compat.forward_compatible(2019, 8, 23):
+      x = constant_op.constant(2.)
+
+      @def_function.function
+      def Fn():
+        ret1 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_1")  # x**2
+        ret2 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * v, [x],
+            return_same_structure=False,
+            name="while_2")  # x**4
+        return ret1, ret2
+
+      concrete_fn = Fn.get_concrete_function()
+      while_1 = concrete_fn.graph.get_operation_by_name("while_1")
+      while_2 = concrete_fn.graph.get_operation_by_name("while_2")
+      self.assertEqual(while_1.type, "StatelessWhile")
+      self.assertEqual(while_2.type, "StatelessWhile")
+      self.assertEmpty(while_1.control_inputs)
+      self.assertEmpty(while_2.control_inputs)
+
+  def testMultipleWhileLoopsWithDeps(self):
+    if compat.forward_compatible(2019, 8, 23):
+      x = variables.Variable(2.)
+      c = constant_op.constant(2.)
+
+      @def_function.function
+      def Fn():
+        ret1 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * x, [c],
+            return_same_structure=False,
+            name="while_1")  # 2x
+        ret2 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * x * x, [c],
+            return_same_structure=False,
+            name="while_2")  # 4x
+        return ret1, ret2
+
+      concrete_fn = Fn.get_concrete_function()
+      while_1 = concrete_fn.graph.get_operation_by_name("while_1")
+      while_2 = concrete_fn.graph.get_operation_by_name("while_2")
+      self.assertEqual(while_1.type, "While")
+      self.assertEqual(while_2.type, "While")
+      self.assertEmpty(while_1.control_inputs)
+      self.assertLen(while_2.control_inputs, 1)
+      self.assertIs(while_2.control_inputs[0], while_1)
+
+  def testMultipleWhileLoopsWithVarsDeps(self):
+    if compat.forward_compatible(2019, 8, 23):
+      x1 = variables.Variable(2.)
+      x2 = variables.Variable(3.)
+      c = constant_op.constant(2.)
+
+      @def_function.function
+      def Fn():
+        ret1 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * x1, [c],
+            return_same_structure=False,
+            name="while_1")  # 2x
+        ret2 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * x1 * x1, [c],
+            return_same_structure=False,
+            name="while_2")  # 4x
+        ret3 = while_loop_v2(
+            lambda v: v < 4.,
+            lambda v: v * x2, [c],
+            return_same_structure=False,
+            name="while_3")  # 3x
+        ret4 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * x2 * x2, [c],
+            return_same_structure=False,
+            name="while_4")  # 9x
+        ret5 = while_loop_v2(
+            lambda v: v < 16.,
+            lambda v: v * v, [c],
+            return_same_structure=False,
+            name="while_stateless")  # x**2
+        return ret1, ret2, ret3, ret4, ret5
+
+      concrete_fn = Fn.get_concrete_function()
+      while_1 = concrete_fn.graph.get_operation_by_name("while_1")
+      while_2 = concrete_fn.graph.get_operation_by_name("while_2")
+      while_3 = concrete_fn.graph.get_operation_by_name("while_3")
+      while_4 = concrete_fn.graph.get_operation_by_name("while_4")
+      while_stateless = concrete_fn.graph.get_operation_by_name(
+          "while_stateless")
+      self.assertEqual(while_1.type, "While")
+      self.assertEqual(while_2.type, "While")
+      self.assertEqual(while_3.type, "While")
+      self.assertEqual(while_4.type, "While")
+      self.assertEqual(while_stateless.type, "StatelessWhile")
+      self.assertEmpty(while_1.control_inputs)
+      self.assertLen(while_2.control_inputs, 1)
+      self.assertIs(while_2.control_inputs[0], while_1)
+      self.assertEmpty(while_3.control_inputs)
+      self.assertLen(while_4.control_inputs, 1)
+      self.assertIs(while_4.control_inputs[0], while_3)
+      self.assertEmpty(while_stateless.control_inputs)
+
   @test_util.run_deprecated_v1
   def testDoubleDerivative(self):
     x = constant_op.constant(2.)
@@ -207,6 +362,45 @@
       self.assertSequenceEqual(self.evaluate(grad), [32.])
       self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
 
+  @test_util.run_v2_only
+  def testMultipleWhileLoopsEager(self):
+
+    @def_function.function
+    def Func():
+      x = constant_op.constant(2.)
+      ret1 = while_loop_v2(
+          lambda v: v < 4., lambda v: v * v, [x],
+          return_same_structure=False)  # x**2
+      ret2 = while_loop_v2(
+          lambda v: v < 16.,
+          lambda v: v * v, [ret1],
+          return_same_structure=False)  # x**4
+      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
+      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
+      return grad, grad_grad
+
+    grad, grad_grad = Func()
+    self.assertEqual(grad.numpy(), 32.)
+    self.assertEqual(grad_grad.numpy(), 48.)
+
+  @test_util.run_v2_only
+  def testDoubleDerivativeEager(self):
+
+    @def_function.function
+    def Func():
+      x = constant_op.constant(2.)
+      ret = while_loop_v2(
+          lambda v: v < 8., lambda v: v**2, [x],
+          return_same_structure=False)  # x**4
+      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
+      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
+      return ret, grad, grad_grad
+
+    ret, grad, grad_grad = Func()
+    self.assertEqual(ret.numpy(), 16.)
+    self.assertEqual(grad.numpy(), 32.)
+    self.assertEqual(grad_grad.numpy(), 48.)
+
   def _testPruning(self):
     x = constant_op.constant(1)
 
@@ -235,6 +429,7 @@
         n for n in g.node if n.op == "Enter" and
         n.attr["T"].type == dtypes.variant.as_datatype_enum
     ])
+    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
 
     stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
     train_op.append(stack)
@@ -248,6 +443,7 @@
         n for n in g.node if n.op == "Enter" and
         n.attr["T"].type == dtypes.variant.as_datatype_enum
     ])
+    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
 
   @test_util.run_deprecated_v1
   def testPruningV1(self):
@@ -258,18 +454,17 @@
   def testPruningV2(self):
     self._testPruning()
 
-  @parameterized.named_parameters(
-      ("V1", control_flow_ops.while_loop, "StackPushV2"),
-      ("V2", while_loop_v2, "TensorListPushBack"),
-  )
-  @test_util.run_deprecated_v1
-  def testDoNotAccumulateInvariants(self, while_loop_fn, push_op):
+  def _testDoNotAccumulateInvariants(self):
+    push_op = ("TensorListPushBack"
+               if control_flow_v2_toggles.control_flow_v2_enabled() else
+               "StackPushV2")
+
     # Tests that loop invariants, i.e., tensors that are "captured" by the
     # while loop and not passed as loop variables are not accumulated in
     # gradient computation.
     v = constant_op.constant(5.0, name="v")
 
-    r = while_loop_fn(
+    r = control_flow_ops.while_loop(
         lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
 
     output = gradients_impl.gradients(r, v)[0]
@@ -283,6 +478,142 @@
     self.assertLen([n for n in g.node if n.op == push_op], 1)
 
   @test_util.run_deprecated_v1
+  def testDoNotAccumulateInvariantsV1(self):
+    self._testDoNotAccumulateInvariants()
+
+  @test_util.run_deprecated_v1
+  @test_util.enable_control_flow_v2
+  def testDoNotAccumulateInvariantsV2(self):
+    self._testDoNotAccumulateInvariants()
+
+  @test_util.enable_control_flow_v2
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testPruningNested(self):
+    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+    x = constant_op.constant(0)
+
+    tensor_list = list_ops.empty_tensor_list(
+        element_dtype=x.dtype, element_shape=x.shape)
+
+    def Cond(x, tl):
+      del tl  # Unused for Cond.
+      return x < 25
+
+    def Body(x, tl):
+
+      def InnerCond(inner_x, unused_outer_x, unused_tl):
+        return inner_x < 5
+
+      def InnerBody(inner_x, outer_x, tl):
+        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)
+
+      inner_x = constant_op.constant(0)
+      return control_flow_ops.while_loop(InnerCond, InnerBody,
+                                         [inner_x, x, tl])[1:]
+
+    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
+
+    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    train_op.append(outputs[0])
+
+    g = GetOptimizedGraph()
+    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
+    # away, causing an extra Enter node.
+    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
+    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
+    # Test that the TensorList is pruned out.
+    self.assertEmpty([
+        n for n in g.node if n.op == "Enter" and
+        n.attr["T"].type == dtypes.variant.as_datatype_enum
+    ])
+    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
+    self.assertEmpty([n for n in g.node if n.op == "_While"])
+
+    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+    train_op.append(stack)
+    g = GetOptimizedGraph()
+    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
+    # away, causing an extra Enter node.
+    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
+    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
+    # Test that the TensorList is not pruned out.
+    self.assertNotEmpty([
+        n for n in g.node if n.op == "Enter" and
+        n.attr["T"].type == dtypes.variant.as_datatype_enum
+    ])
+    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
+
+  @test_util.enable_control_flow_v2
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testPruningNested2(self):
+    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+    v = constant_op.constant(5.0, name="v")
+
+    p = array_ops.placeholder(dtype=dtypes.int32)
+
+    def MidBodyBuilder(iterations):
+
+      def MidBody(i, x):
+        r = control_flow_ops.while_loop(
+            lambda *_: True,
+            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
+            (0, x),
+            maximum_iterations=iterations,
+            name="inner")
+        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
+
+      return MidBody
+
+    def OuterBody(i, x):
+      iterations = array_ops.size(p, name="iterations")
+      return (i + 1, x + control_flow_ops.while_loop(
+          lambda *_: True,
+          MidBodyBuilder(iterations), (0, x),
+          maximum_iterations=iterations,
+          name="mid")[1])
+
+    def CreateWhileLoop():
+      with ops.device("/cpu:0"):
+        r = control_flow_ops.while_loop(
+            lambda *_: True,
+            OuterBody, (0, 1.0),
+            maximum_iterations=5,
+            name="outer")
+        return array_ops.identity(r[1])
+
+    output = CreateWhileLoop()
+    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    train_op.append(output)
+
+    g = GetOptimizedGraph()
+    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
+
+  @test_util.enable_control_flow_v2
+  @test_util.run_deprecated_v1
+  @test_util.enable_output_all_intermediates
+  def testPruningNested3(self):
+    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+    v = constant_op.constant(5.0, name="v")
+
+    def CreateWhileLoop():
+      r = control_flow_ops.while_loop(
+          lambda _: True,
+          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
+          maximum_iterations=5,
+          name="outer")
+      return array_ops.identity(r)
+
+    r = CreateWhileLoop()
+    output = gradients_impl.gradients(r, v)[0]
+    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    train_op.append(output)
+
+    g = GetOptimizedGraph()
+    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
+
+  @test_util.run_deprecated_v1
   def testCaptureExternalTensorInCond(self):
     x = constant_op.constant(2.)
     y = constant_op.constant(1.)
@@ -291,7 +622,7 @@
         lambda v: v * 3., [x],
         return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertEqual(self.evaluate(ret), 18.)
       self.assertSequenceEqual(self.evaluate(grad), [9.])
 
@@ -302,7 +633,7 @@
     ret = while_loop_v2(
         lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertEqual(self.evaluate(ret), 18.)
       self.assertSequenceEqual(self.evaluate(grad), [9.])
 
@@ -350,7 +681,7 @@
         Cond, Body, [x, tensor_list], return_same_structure=False)
 
     for op in ops.get_default_graph().get_operations():
-      if op.type == "While":
+      if op.type == "While" or op.type == "StatelessWhile":
         while_op = op
 
     body_graph = while_v2._get_graph(while_op, "body")
@@ -433,7 +764,8 @@
         lambda i: i + 1, [constant_op.constant(0)],
         return_same_structure=False)
     while_op = output.op.inputs[0].op
-    self.assertEqual(while_op.type, "While")
+    if compat.forward_compatible(2019, 8, 23):
+      self.assertEqual(while_op.type, "StatelessWhile")
     return while_op
 
   def testDefaultName(self):
@@ -514,23 +846,132 @@
 
   @test_util.run_deprecated_v1
   def testForwardPassRewrite(self):
-    x = constant_op.constant(1.0, name="x")
-    output = while_v2.while_loop(lambda x: x < 10.0,
-                                 lambda x: x * 2.0,
-                                 [x])[0]
-    while_op = output.op.inputs[0].op
-    self.assertEqual(while_op.type, "While")
-    # outputs = [loop_counter, max_iters, x]
-    self.assertLen(while_op.outputs, 3)
+    if compat.forward_compatible(2019, 8, 23):
+      x = constant_op.constant(1.0, name="x")
+      output = while_v2.while_loop(lambda x: x < 10.0,
+                                   lambda x: x * 2.0,
+                                   [x])[0]
+      while_op = output.op.inputs[0].op
+      self.assertEqual(while_op.type, "StatelessWhile")
+      # outputs = [loop_counter, max_iters, x]
+      self.assertLen(while_op.outputs, 3)
 
-    gradients_impl.gradients(output, x)
-    # while_op should have been rewritten to output 2.0 intermediate.
-    # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
-    self.assertLen(while_op.outputs, 5)
+      gradients_impl.gradients(output, x)
+      # while_op should have been rewritten to output intermediates.
+      # outputs = [loop_counter, max_iters, x, x_accumulator]
+      self.assertLen(while_op.outputs, 4)
 
-    gradients_impl.gradients(output, x)
-    # Computing the gradient again shouldn't rewrite while_op again.
-    self.assertLen(while_op.outputs, 5)
+      gradients_impl.gradients(output, x)
+      # Computing the gradient again shouldn't rewrite while_op again.
+      self.assertLen(while_op.outputs, 4)
+
+  @parameterized.named_parameters(
+      ("RandomUniform", random_ops.random_uniform, [5, 3]),
+      ("RandomNormal", random_ops.random_normal, [5, 3]),
+      ("ParameterizedTruncatedNormal",
+       random_ops.parameterized_truncated_normal, [5, 3]),
+      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
+      ("RandomGamma", random_gamma, [5, 3]),
+      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
+      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
+      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
+  )
+  @test_util.run_deprecated_v1
+  def testRandomOpsShape(self, random_fn, expected_shape):
+    shape = constant_op.constant([3])
+
+    def Body(i, u):
+      shape_extended = array_ops.concat([[5], shape], axis=0)
+      u = random_fn(shape_extended)
+      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
+      return i + 1, u
+
+    _, _ = while_loop_v2(
+        cond=lambda i, _: i < 3,
+        body=Body,
+        loop_vars=[
+            0,
+            array_ops.zeros(expected_shape, dtype=dtypes.float32),
+        ])
+
+  @test_util.run_deprecated_v1
+  def testReshapeShape(self):
+    shape = constant_op.constant([3, 4])
+
+    def Body(i, u):
+      shape_extended = array_ops.concat([[5], shape], axis=0)
+      u = array_ops.reshape(u, [-1])
+      assert u.shape.as_list() == [60], str(u.shape.as_list())
+      u = array_ops.reshape(u, shape_extended)
+      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
+      return i + 1, u
+
+    _, _ = while_loop_v2(
+        cond=lambda i, _: i < 3,
+        body=Body,
+        loop_vars=[
+            0,
+            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
+        ])
+
+  @parameterized.named_parameters(
+      ("Zeros", array_ops.zeros),
+      ("Ones", array_ops.ones),
+      ("Fill", fill),
+  )
+  @test_util.run_deprecated_v1
+  def testFillOpsShape(self, fill_fn):
+    shape = constant_op.constant([3, 4])
+
+    def Body(i, u):
+      shape_extended = array_ops.concat([[5], shape], axis=0)
+      u = fill_fn(shape_extended)
+      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
+      return i + 1, u
+
+    _, _ = while_loop_v2(
+        cond=lambda i, _: i < 3,
+        body=Body,
+        loop_vars=[
+            0,
+            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
+        ])
+
+  @test_util.run_deprecated_v1
+  def testExternalColocationGrad(self):
+    external_t = constant_op.constant(2.)
+    v0 = constant_op.constant(2.)
+
+    def Body(v):
+      with ops.colocate_with(external_t):
+        return v * v
+
+    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
+    grad = gradients_impl.gradients(ret, [v0])[0]
+    self.assertAllEqual(ret, 16.)
+    self.assertAllEqual(grad, 32.)
+
+  @test_util.run_deprecated_v1
+  def testDoNotAccumulateConstNodes(self):
+
+    def Body(v):
+      return v * 2.0
+
+    v0 = constant_op.constant(2.)
+    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
+    # Gradients computation has the side-effect of updating the forward op
+    # which is what we want to test.
+    unused_grad = gradients_impl.gradients(ret, [v0])[0]
+    # ret is separated from the `While` op by an `Identity` so we skip over
+    # that.
+    forward_while_op = ret.op.inputs[0].op
+    body_graph = while_v2._get_graph(forward_while_op, "body")
+    push_back_nodes = [
+        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
+    ]
+    # Gradient of `Mul` requires accumulating both its inputs. But since one
+    # of those is a Const (2.0), we should have just one accumulator.
+    self.assertLen(push_back_nodes, 1)
 
 
 def ScalarShape():
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 031c8ca..93a0af21 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -242,6 +242,7 @@
           op.op_def.name for op in sess.graph.get_operations() if op.op_def
       ]
       self.assertNotIn("BatchMatMul", op_names)
+      self.assertNotIn("BatchMatMulV2", op_names)
 
     print("cross entropy gradient err = ", err)
     self.assertLess(err, 5e-8)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 7137a3b..f1cc132 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -24,12 +24,14 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
 
@@ -199,6 +201,15 @@
     self._trainable_weights = []
     self.built = False
 
+    if dtype is None:
+      # Indicates to infer dtype from inputs. When the V2 dtype behavior is
+      # enabled, Keras layers default their dtype to floatx instead, so we pass
+      # an "infer" policy to keep the old V1 behavior.
+      dtype = policy.Policy('infer')
+
+    if 'autocast' not in kwargs:
+      kwargs['autocast'] = False
+
     super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                 **kwargs)
 
@@ -577,7 +588,7 @@
   collection_list = nest.flatten(collection_list)
   for name in collection_list:
     collection = ops.get_collection_ref(name)
-    collection_set = set(collection)
+    collection_set = object_identity.ObjectIdentitySet(collection)
     for element in elements:
       if element not in collection_set:
         collection.append(element)
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 3dd09a0..1481ef5 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -20,6 +20,8 @@
 
 import copy
 
+import numpy as np
+
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -638,5 +640,69 @@
     self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
     self.assertEqual(len(layer.get_losses_for([outputs])), 0)
 
+
+class IdentityLayer(base_layers.Layer):
+  """A layer returns the identity of it's input."""
+
+  def call(self, inputs):
+    return inputs
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class DTypeTest(test.TestCase):
+
+  def _const(self, dtype):
+    return array_ops.constant(1, dtype=dtype)
+
+  def test_dtype_inferred_from_input(self):
+    # Test with Tensor input
+    layer = IdentityLayer()
+    self.assertIsNone(layer.dtype)
+    layer(self._const('float64'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    # Test with Numpy input
+    layer = IdentityLayer()
+    self.assertIsNone(layer.dtype)
+    layer(np.array(1., dtype='float64'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    # Test with integer input
+    layer = IdentityLayer()
+    self.assertIsNone(layer.dtype)
+    layer(self._const('int32'))
+    self.assertEqual(layer.dtype, 'int32')
+
+    # Test layer dtype doesn't change when passed a new dtype
+    layer = IdentityLayer()
+    self.assertIsNone(layer.dtype)
+    layer(self._const('float64'))
+    self.assertEqual(layer.dtype, 'float64')
+    layer(self._const('float16'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    # Test layer dtype inferred from first input
+    layer = IdentityLayer()
+    layer([self._const('float32'), self._const('float64')])
+    self.assertEqual(layer.dtype, 'float32')
+
+  def test_passing_dtype_to_constructor(self):
+    layer = IdentityLayer(dtype='float64')
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'float64')
+
+    layer = IdentityLayer(dtype='int32')
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'int32')
+
+    layer = IdentityLayer(dtype=dtypes.float64)
+    layer(self._const('float32'))
+    self.assertEqual(layer.dtype, 'float64')
+
+  def test_inputs_not_casted(self):
+    layer = IdentityLayer(dtype='float32')
+    self.assertEqual(layer(self._const('float64')).dtype, 'float64')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index 3390afa..bce79b1 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -168,6 +168,16 @@
       if (pyarray_type == Bfloat16NumpyType()) {
         *out_tf_datatype = TF_BFLOAT16;
         break;
+      } else if (pyarray_type == NPY_ULONGLONG) {
+        // NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
+        // might be different on certain platforms.
+        *out_tf_datatype = TF_UINT64;
+        break;
+      } else if (pyarray_type == NPY_LONGLONG) {
+        // NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
+        // might be different on certain platforms.
+        *out_tf_datatype = TF_INT64;
+        break;
       }
       return errors::Internal("Unsupported numpy type: ",
                               numpy_type_name(pyarray_type));
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 981d531..112963a 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -56,6 +56,123 @@
 _BaseSlice = slice
 
 
+@tf_export("reshape", v1=["reshape", "manip.reshape"])
+def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
+  r"""Reshapes a tensor.
+
+  Given `tensor`, this operation returns a tensor that has the same values
+  as `tensor` with shape `shape`.
+
+  If one component of `shape` is the special value -1, the size of that
+  dimension is computed so that the total size remains constant.  In particular,
+  a `shape` of `[-1]` flattens into 1-D.  At most one component of `shape` can
+  be -1.
+
+  If `shape` is 1-D or higher, then the operation returns a tensor with shape
+  `shape` filled with the values of `tensor`. In this case, the number of
+  elements implied by `shape` must be the same as the number of elements in
+  `tensor`.
+
+  For example:
+
+  ```
+  # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
+  # tensor 't' has shape [9]
+  reshape(t, [3, 3]) ==> [[1, 2, 3],
+                          [4, 5, 6],
+                          [7, 8, 9]]
+
+  # tensor 't' is [[[1, 1], [2, 2]],
+  #                [[3, 3], [4, 4]]]
+  # tensor 't' has shape [2, 2, 2]
+  reshape(t, [2, 4]) ==> [[1, 1, 2, 2],
+                          [3, 3, 4, 4]]
+
+  # tensor 't' is [[[1, 1, 1],
+  #                 [2, 2, 2]],
+  #                [[3, 3, 3],
+  #                 [4, 4, 4]],
+  #                [[5, 5, 5],
+  #                 [6, 6, 6]]]
+  # tensor 't' has shape [3, 2, 3]
+  # pass '[-1]' to flatten 't'
+  reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
+
+  # -1 can also be used to infer the shape
+
+  # -1 is inferred to be 9:
+  reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
+                           [4, 4, 4, 5, 5, 5, 6, 6, 6]]
+  # -1 is inferred to be 2:
+  reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
+                           [4, 4, 4, 5, 5, 5, 6, 6, 6]]
+  # -1 is inferred to be 3:
+  reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1],
+                                [2, 2, 2],
+                                [3, 3, 3]],
+                               [[4, 4, 4],
+                                [5, 5, 5],
+                                [6, 6, 6]]]
+
+  # tensor 't' is [7]
+  # shape `[]` reshapes to a scalar
+  reshape(t, []) ==> 7
+  ```
+
+  Args:
+    tensor: A `Tensor`.
+    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+      Defines the shape of the output tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `tensor`.
+  """
+  result = gen_array_ops.reshape(tensor, shape, name)
+  tensor_util.maybe_set_static_shape(result, shape)
+  return result
+
+
+@tf_export("fill")
+def fill(dims, value, name=None):
+  r"""Creates a tensor filled with a scalar value.
+
+  This operation creates a tensor of shape `dims` and fills it with `value`.
+
+  For example:
+
+  ```
+  # Output tensor has shape [2, 3].
+  fill([2, 3], 9) ==> [[9, 9, 9]
+                       [9, 9, 9]]
+  ```
+
+  `tf.fill` differs from `tf.constant` in a few ways:
+
+  *   `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+      Tensor values.
+  *   `tf.fill` creates an Op in the computation graph that constructs the
+  actual
+      Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+      the entire Tensor into the graph with a `Const` node.
+  *   Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+      based on other runtime Tensors, unlike `tf.constant`.
+
+  Args:
+    dims: A `Tensor`. Must be one of the following types: `int32`, `int64`. 1-D.
+      Represents the shape of the output tensor.
+    value: A `Tensor`. 0-D (scalar). Value to fill the returned tensor.
+      @compatibility(numpy) Equivalent to np.full @end_compatibility
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `value`.
+  """
+  result = gen_array_ops.fill(dims, value, name=name)
+  tensor_util.maybe_set_static_shape(result, dims)
+  return result
+
+
 @tf_export("identity")
 @dispatch.add_dispatch_support
 def identity(input, name=None):  # pylint: disable=redefined-builtin
@@ -627,13 +744,15 @@
       # python doesn't always use None when constructing ranges
       # for example a[:] gives slice(None,sys.maxsize,None)
       # whereas a[::1] gives slice(None,None,None)
-      if s.start is not None and s.start is not sys.maxsize:
+      if s.start is not None and (isinstance(s.start, ops.Tensor) or
+                                  s.start != sys.maxsize):
         _check_index(s.start)
         begin.append(s.start)
       else:
         begin.append(0)
         begin_mask |= (1 << index)
-      if s.stop is not None and s.stop != sys.maxsize:
+      if s.stop is not None and (isinstance(s.stop, ops.Tensor) or
+                                 s.stop != sys.maxsize):
         _check_index(s.stop)
         end.append(s.stop)
       else:
@@ -1308,7 +1427,7 @@
       ops.convert_to_tensor(
           axis, name="concat_dim",
           dtype=dtypes.int32).get_shape().assert_has_rank(0)
-      return identity(values[0], name=scope)
+      return identity(values[0], name=name)
   return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
 
 
@@ -1951,7 +2070,7 @@
     A Tensor. Has the same type as `diagonal`.
   """
   # LINT.IfChange
-  if compat.forward_compatible(2019, 7, 31):
+  if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
 
     # Special case to sidestep the tf.constant conversion error:
@@ -2063,7 +2182,7 @@
     A Tensor containing diagonals of `input`. Has the same type as `input`.
   """
   # LINT.IfChange
-  if compat.forward_compatible(2019, 7, 31):
+  if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
 
     # Special case to sidestep the tf.constant conversion error:
@@ -2170,7 +2289,7 @@
       and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
   """
   # LINT.IfChange
-  if compat.forward_compatible(2019, 7, 31):
+  if compat.forward_compatible(2019, 8, 31):
     # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
     return gen_array_ops.matrix_set_diag_v2(
         input=input, diagonal=diagonal, k=k, name=name)
@@ -2381,7 +2500,7 @@
     input,  # pylint: disable=redefined-builtin
     dtype=None,
     name=None):
-  """Creates a tensor with all elements set to zero.
+  """Creates a tensor with all elements set to one.
 
   Given a single tensor (`tensor`), this operation returns a tensor of the
   same type and shape as `tensor` with all elements set to 1. Optionally,
@@ -2402,7 +2521,7 @@
     name: A name for the operation (optional).
 
   Returns:
-    A `Tensor` with all elements set to zero.
+    A `Tensor` with all elements set to one.
   """
   return ones_like_impl(input, dtype, name, optimize=True)
 
@@ -2728,11 +2847,11 @@
   if mode == "CONSTANT":
     # TODO(rjryan): Once the forward compatibility period (3 weeks) have passed
     # remove the "Pad" fallback here.
-    if constant_values != 0:
+    if not tensor_util.is_tensor(constant_values) and constant_values == 0:
+      result = gen_array_ops.pad(tensor, paddings, name=name)
+    else:
       result = gen_array_ops.pad_v2(
           tensor, paddings, constant_values, name=name)
-    else:
-      result = gen_array_ops.pad(tensor, paddings, name=name)
   elif mode == "REFLECT":
     result = gen_array_ops.mirror_pad(
         tensor, paddings, mode="REFLECT", name=name)
@@ -3250,6 +3369,7 @@
 
 
 @tf_export("one_hot")
+@dispatch.add_dispatch_support
 def one_hot(indices,
             depth,
             on_value=None,
@@ -3293,6 +3413,11 @@
     depth x batch x features if axis == 0
   ```
 
+  If `indices` is a RaggedTensor, the 'axis' argument must be positive and refer
+  to a non-ragged axis. The output will be equivalent to applying 'one_hot' on
+  the values of the RaggedTensor, and creating a new RaggedTensor from the
+  result.
+
   If `dtype` is not provided, it will attempt to assume the data type of
   `on_value` or `off_value`, if one or both are passed in. If none of
   `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
@@ -3330,6 +3455,13 @@
   #   [0.0, 0.0, 1.0]],  # one_hot(2)
   #  [[0.0, 1.0, 0.0],   # one_hot(1)
   #   [0.0, 0.0, 0.0]]]  # one_hot(-1)
+
+  indices = tf.ragged.constant([[0, 1], [2]])
+  depth = 3
+  tf.one_hot(indices, depth)  # output: [2 x None x 3]
+  # [[[1., 0., 0.],
+  #   [0., 1., 0.]],
+  #  [[0., 0., 1.]]]
   ```
 
   Args:
@@ -3820,7 +3952,7 @@
     A `Tensor`. Has the same type as `params`.
   """
   del validate_indices
-  if compat.forward_compatible(2019, 8, 10):
+  if compat.forward_compatible(2019, 9, 10):
     if axis is None:
       axis = batch_dims
     if axis != 0:
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 3997c40..7d533bc 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -496,6 +496,10 @@
     x = ops.convert_to_tensor(x, name='x')
     y = ops.convert_to_tensor(y, name='y')
 
+    # Short-circuit if x and y are the same tensor.
+    if x is y:
+      return None if context.executing_eagerly() else control_flow_ops.no_op()
+
     if context.executing_eagerly():
       eq = math_ops.equal(x, y)
       condition = math_ops.reduce_all(eq)
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index a247a33..bb88019 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -18,8 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 import six
 
 from tensorflow.python.framework import constant_op
@@ -31,6 +29,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -49,14 +48,14 @@
   correct results.
 
   For example:
-  
+
   ```python
   A = tf.constant([[1, 20, 13], [3, 21, 13]])
   B = tf.clip_by_value(A, clip_value_min=0, clip_value_max=3) # [[1, 3, 3],[3, 3, 3]]
-  C = tf.clip_by_value(A, clip_value_min=0., clip_value_max=3.) # throws `TypeError` 
+  C = tf.clip_by_value(A, clip_value_min=0., clip_value_max=3.) # throws `TypeError`
   as input and clip_values are of different dtype
   ```
-  
+
   Args:
     t: A `Tensor` or `IndexedSlices`.
     clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
@@ -71,8 +70,8 @@
   Raises:
     ValueError: If the clip tensors would trigger array broadcasting
       that would make the returned tensor larger than the input.
-    TypeError: If dtype of the input is `int32` and dtype of 
-    the `clip_value_min' or `clip_value_max` is `float32`  
+    TypeError: If dtype of the input is `int32` and dtype of
+    the `clip_value_min' or `clip_value_max` is `float32`
   """
   with ops.name_scope(name, "clip_by_value",
                       [t, clip_value_min, clip_value_max]) as name:
@@ -208,8 +207,8 @@
   Raises:
     TypeError: If `t_list` is not a sequence.
   """
-  if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, six.string_types)):
+  if (not isinstance(t_list, collections_abc.Sequence) or
+      isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   with ops.name_scope(name, "global_norm", t_list) as name:
@@ -282,8 +281,8 @@
   Raises:
     TypeError: If `t_list` is not a sequence.
   """
-  if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, six.string_types)):
+  if (not isinstance(t_list, collections_abc.Sequence) or
+      isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   if use_norm is None:
diff --git a/tensorflow/python/ops/collective_ops_benchmark.py b/tensorflow/python/ops/collective_ops_benchmark.py
new file mode 100644
index 0000000..870dec5
--- /dev/null
+++ b/tensorflow/python/ops/collective_ops_benchmark.py
@@ -0,0 +1,86 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Local CPU benchmarks for collective ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import numpy as np
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.platform import test
+
+
+class CollectiveOpBenchmark(test.Benchmark):
+  """Benchmarks for local CPU collective op execution."""
+
+  def benchmark_collective(self):
+    """Measures the performance of local CPU collective execution."""
+    shapes = [(10,), (1000,), (1000000,)]
+    devices = [2, 4, 8]
+    collective_key_counter = 0
+
+    for group_size in devices:
+      group_key = collective_key_counter
+      instance_key = collective_key_counter
+      collective_key_counter += 1
+
+      for shape in shapes:
+        config = config_pb2.ConfigProto(device_count={"CPU": group_size})
+        with session.Session(config=config) as sess:
+          # Use a C++ callable to minimize the Python overhead in the benchmark.
+          callable_opts = config_pb2.CallableOptions()
+          reduce_ops = []
+          for device in range(group_size):
+            with ops.device("CPU:{}".format(device)):
+              t = constant_op.constant(np.multiply(range(shape[0]), 1.0))
+              r = collective_ops.all_reduce(t, group_size, group_key,
+                                            instance_key, "Add", "Div")
+              reduce_ops.append(r)
+              callable_opts.target.append(r.name)
+          op_callable = sess._make_callable_from_options(callable_opts)  # pylint: disable=protected-access
+
+          # Run five steps to warm up the session caches and do collective param
+          # resolution before taking the first measurement.
+          for _ in range(5):
+            op_callable()
+          deltas = []
+          overall_start = time.time()
+          # Run at least five repetitions and for at least five seconds.
+          while len(deltas) < 5 or time.time() - overall_start < 5.0:
+            start = time.time()
+            for _ in range(100):
+              op_callable()
+            end = time.time()
+            deltas.append(end - start)
+          del op_callable
+
+        median_wall_time = np.median(deltas) / 100.0
+        iters = len(deltas) * 100
+
+        self.report_benchmark(
+            iters=iters, wall_time=median_wall_time,
+            name="num_elements_{}_num_devices_{}".format(np.prod(shape),
+                                                         group_size))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/collective_ops_gpu_test.py b/tensorflow/python/ops/collective_ops_gpu_test.py
new file mode 100644
index 0000000..4e92281
--- /dev/null
+++ b/tensorflow/python/ops/collective_ops_gpu_test.py
@@ -0,0 +1,163 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Collective Operations that require GPU."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.platform import test
+
+
+class CollectiveOpGPUTest(test.TestCase):
+
+  def _configure(self, group_size):
+    """Set environment variables and return `ConfigProto` for NCCL execution."""
+    # Configure virtual GPU devices
+    virtual_devices = [config_pb2.GPUOptions.Experimental.VirtualDevices(
+        memory_limit_mb=([1 << 10] * group_size))]  # 1 GB per virtual GPU
+    gpu_options = config_pb2.GPUOptions(
+        visible_device_list='0',
+        experimental=config_pb2.GPUOptions.Experimental(
+            virtual_devices=virtual_devices))
+    # Configure NCCL
+    experimental = config_pb2.ConfigProto.Experimental(collective_nccl=True)
+    os.environ['NCCL_DEBUG'] = 'INFO'
+    os.environ['NCCL_LAUNCH_MODE'] = 'PARALLEL'
+    return config_pb2.ConfigProto(gpu_options=gpu_options,
+                                  experimental=experimental)
+
+  @test_util.run_deprecated_v1
+  def testBasicNcclAllReduce(self):
+    inputs = [[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+              [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]]
+    expected = [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]
+    group_size = len(inputs)
+    group_key = 1
+    instance_key = 1
+    devices = ['/GPU:{}'.format(i) for i in range(group_size)]
+
+    with self.session(config=self._configure(group_size)) as sess:
+      if not test_util.is_gpu_available(cuda_only=True):
+        self.skipTest('No GPU available')
+      collectives = []
+      for i in range(group_size):
+        with ops.device(devices[i]):
+          t = constant_op.constant(inputs[i])
+          collectives.append(collective_ops.all_reduce(
+              t, group_size, group_key, instance_key, 'Add', 'Div'))
+      results = sess.run(collectives)
+    for result in results:
+      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
+
+  @test_util.run_deprecated_v1
+  def testBasicNcclBroadcast(self):
+    tensor_value = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
+    group_size = 2
+    group_key = 1
+    instance_key = 1
+    devices = ['/GPU:{}'.format(i) for i in range(group_size)]
+
+    with self.session(config=self._configure(group_size)) as sess:
+      if not test_util.is_gpu_available(cuda_only=True):
+        self.skipTest('No GPU available')
+      collectives = []
+      with ops.device(devices[0]):
+        t = constant_op.constant(tensor_value)
+        collectives.append(collective_ops.broadcast_send(
+            t, t.shape, t.dtype, group_size, group_key, instance_key))
+      with ops.device(devices[1]):
+        t = constant_op.constant(tensor_value)
+        collectives.append(collective_ops.broadcast_recv(
+            t.shape, t.dtype, group_size, group_key, instance_key))
+      results = sess.run(collectives)
+    for result in results:
+      self.assertAllClose(result, tensor_value, rtol=1e-5, atol=1e-5)
+
+  @test_util.run_deprecated_v1
+  def testNcclBroadcastDoubleRecv(self):
+    tensor_value = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
+    group_size = 2
+    group_key = 1
+    instance_key = 1
+    devices = ['/GPU:{}'.format(i) for i in range(group_size)]
+
+    with self.session(config=self._configure(group_size)) as sess:
+      if not test_util.is_gpu_available(cuda_only=True):
+        self.skipTest('No GPU available')
+      collectives = []
+      for device in devices:
+        with ops.device(device):
+          t = constant_op.constant(tensor_value)
+          collectives.append(collective_ops.broadcast_recv(
+              t.shape, t.dtype, group_size, group_key, instance_key))
+      with self.assertRaisesRegexp(errors.InternalError, 'found no source'):
+        sess.run(collectives)
+
+  @test_util.run_deprecated_v1
+  def testNcclBroadcastDoubleSend(self):
+    tensor_value = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
+    group_size = 2
+    group_key = 1
+    instance_key = 1
+    devices = ['/GPU:{}'.format(i) for i in range(group_size)]
+
+    with self.session(config=self._configure(group_size)) as sess:
+      if not test_util.is_gpu_available(cuda_only=True):
+        self.skipTest('No GPU available')
+      collectives = []
+      for device in devices:
+        with ops.device(device):
+          t = constant_op.constant(tensor_value)
+          collectives.append(collective_ops.broadcast_send(
+              t, t.shape, t.dtype, group_size, group_key, instance_key))
+      with self.assertRaisesRegexp(errors.InternalError, 'already has source'):
+        sess.run(collectives)
+
+  @test_util.run_deprecated_v1
+  def testBasicNcclAllGather(self):
+    inputs = [[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+              [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]]
+    expected = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1,
+                0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]
+    group_size = len(inputs)
+    group_key = 1
+    instance_key = 1
+    devices = ['/GPU:{}'.format(i) for i in range(group_size)]
+
+    with self.session(config=self._configure(group_size)) as sess:
+      if not test_util.is_gpu_available(cuda_only=True):
+        self.skipTest('No GPU available')
+      collectives = []
+      for i in range(group_size):
+        with ops.device(devices[i]):
+          t = constant_op.constant(inputs[i])
+          collectives.append(collective_ops.all_gather(t, group_size,
+                                                       group_key, instance_key))
+      results = sess.run(collectives)
+    for result in results:
+      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index c3e8bbf..caa8577 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -32,25 +32,27 @@
 
 class CollectiveOpTest(test.TestCase):
 
-  def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
+  def _testCollectiveReduce(self, inputs, expected, set_graph_key):
     group_key = 1
+    group_size = len(inputs)
     instance_key = 1
-    with self.session(
-        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
-      with ops.device('/CPU:0'):
-        in0 = constant_op.constant(t0)
-        colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key,
-                                            'Add', 'Div')
-      with ops.device('/CPU:1'):
-        in1 = constant_op.constant(t1)
-        colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
-                                            'Add', 'Div')
+    device_type = 'CPU'
+    config = config_pb2.ConfigProto(device_count={device_type: group_size})
+    devices = ['/{}:{}'.format(device_type, i) for i in range(group_size)]
+
+    with self.session(config=config) as sess:
+      colred = []
+      for i in range(group_size):
+        with ops.device(devices[i]):
+          tensor = constant_op.constant(inputs[i])
+          colred.append(collective_ops.all_reduce(tensor, group_size, group_key,
+                                                  instance_key, 'Add', 'Div'))
       run_options = config_pb2.RunOptions()
       if set_graph_key:
         run_options.experimental.collective_graph_key = 1
-      results = sess.run([colred0, colred1], options=run_options)
-    self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
-    self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
+      results = sess.run(colred, options=run_options)
+    for i in range(group_size):
+      self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5)
 
   def _testMultipleConcurrentCollectiveReduce(self, t0, t1, expected):
     group_key = 1
@@ -72,15 +74,19 @@
 
   @test_util.run_deprecated_v1
   def testCollectiveReduce(self):
-    self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
-                               [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
+    self._testCollectiveReduce(
+        inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
+        expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
+        set_graph_key=True)
 
   @test_util.run_deprecated_v1
   def testCollectiveAutoGraphKey(self):
-    self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
-                               [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
+    self._testCollectiveReduce(
+        inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
+        expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
+        set_graph_key=False)
 
   @test_util.run_deprecated_v1
   def testCollectiveMultipleConcurrentReduce(self):
@@ -122,7 +128,8 @@
 
   @test_util.run_deprecated_v1
   def testCollectiveReduceScalar(self):
-    self._testCollectiveReduce(0.1, 0.3, 0.2, True)
+    self._testCollectiveReduce(inputs=[0.1, 0.3], expected=0.2,
+                               set_graph_key=True)
 
   def _testCollectiveBroadcast(self, t0):
     group_key = 1
@@ -154,14 +161,14 @@
         config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
       with ops.device('/CPU:0'):
         in0 = constant_op.constant(t0)
-        colred0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
+        c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
       with ops.device('/CPU:1'):
         in1 = constant_op.constant(t1)
-        colred1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
+        c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
       run_options = config_pb2.RunOptions()
       if set_graph_key:
         run_options.experimental.collective_graph_key = 1
-      results = sess.run([colred0, colred1], options=run_options)
+      results = sess.run([c0, c1], options=run_options)
     self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
     self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
 
@@ -194,18 +201,38 @@
         config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
       with ops.device('/CPU:0'):
         in0 = constant_op.constant(t0)
-        colred0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
+        c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
       with ops.device('/CPU:1'):
         in1 = constant_op.constant(t1)
         in2 = constant_op.constant(t2)
-        colred1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
-        colred2 = collective_ops.all_gather(in2, 2, group_key, instance_key)
+        c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
+        c2 = collective_ops.all_gather(in2, 2, group_key, instance_key)
       run_options = config_pb2.RunOptions()
       run_options.experimental.collective_graph_key = 1
-      sess.run([colred0, colred1], options=run_options)
-      with self.assertRaisesRegexp(errors.InternalError,
-                                   'Inconsistent output shapes'):
-        sess.run([colred0, colred2], options=run_options)
+      sess.run([c0, c1], options=run_options)
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   'Shape mismatch'):
+        sess.run([c0, c2], options=run_options)
+
+  @test_util.run_deprecated_v1
+  def testCollectiveGatherShapeMismatchAcrossDevices(self):
+    group_key = 1
+    instance_key = 1
+    t0 = [1, 2, 3, 4]
+    t1 = [5, 6]
+    with self.session(
+        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
+      with ops.device('/CPU:0'):
+        in0 = constant_op.constant(t0)
+        c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
+      with ops.device('/CPU:1'):
+        in1 = constant_op.constant(t1)
+        c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
+      run_options = config_pb2.RunOptions()
+      run_options.experimental.collective_graph_key = 1
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   'Shape mismatch'):
+        sess.run([c0, c1], options=run_options)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 386aff3..3f45c50 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -30,6 +30,7 @@
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import function_def_to_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_util_v2 as util
@@ -70,6 +71,9 @@
     # graphs. Propagate that behavior here.
     add_control_dependencies = ops.get_default_graph()._add_control_dependencies
     pred = ops.convert_to_tensor(pred)
+    if (tensor_util.is_tensor(pred) and
+        (pred.shape.dims is None or pred.shape.dims)):
+      pred = array_ops.squeeze_v2(pred)
 
     true_graph = func_graph_module.func_graph_from_py_func(
         true_name,
@@ -87,10 +91,14 @@
         op_return_value=pred)
 
     verify_captures(_COND, [true_graph, false_graph])
-    return _build_cond(pred, true_graph, false_graph,
-                       true_graph.external_captures,
-                       false_graph.external_captures,
-                       name=scope)
+    return _build_cond(
+        pred,
+        true_graph,
+        false_graph,
+        true_graph.external_captures,
+        false_graph.external_captures,
+        building_gradient=False,
+        name=scope)
 
 
 @ops.RegisterGradient("StatelessIf")
@@ -162,14 +170,25 @@
   _make_output_composite_tensors_match(_COND,
                                        [true_grad_graph, false_grad_graph])
 
-  outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
-                        true_grad_inputs, false_grad_inputs)
+  outputs = _build_cond(
+      if_op.inputs[0],
+      true_grad_graph,
+      false_grad_graph,
+      true_grad_inputs,
+      false_grad_inputs,
+      building_gradient=True,
+  )
 
   # The predicate has no gradient.
   return [None] + outputs
 
 
-def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
+def _build_cond(pred,
+                true_graph,
+                false_graph,
+                true_inputs,
+                false_inputs,
+                building_gradient,
                 name=None):
   """Creates an If op from the specified predicate, branch functions and inputs.
 
@@ -186,6 +205,7 @@
     false_graph: FuncGraph
     true_inputs: a list of Tensors to be passed to true_graph as input.
     false_inputs: a list of Tensors to be passed to false_graph as input.
+    building_gradient: Whether this is a gradient If op.
     name: the name for the If op.
 
   Returns:
@@ -199,6 +219,33 @@
   # this modifies true_graph and false_graph.
   cond_inputs = _make_inputs_match([true_graph, false_graph],
                                    [true_inputs, false_inputs])
+  # Save the original number of outputs to return to the caller.
+  num_cond_outputs = len(true_graph.outputs)
+  # We do not output intermediates of the gradient If op since this is just
+  # for backwards compatibility with existing code.
+  if not building_gradient and util.output_all_intermediates():
+    # Add all intermediate tensors as function outputs so they're available for
+    # the gradient computation. Since the outputs of the two functions must
+    # match, we wrap all the intermediates in optionals. Each intermediate
+    # output will have a value iff its corresponding branch is taken.
+
+    true_intermediates = _get_intermediates(true_graph)
+    false_intermediates = _get_intermediates(false_graph)
+
+    # Wrap intermediates in optionals.
+    wrapped_true_intermediates = _wrap_intermediates(true_graph,
+                                                     true_intermediates)
+    wrapped_false_intermediates = _wrap_intermediates(false_graph,
+                                                      false_intermediates)
+
+    # Make outputs match by adding none optionals.
+    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
+        [true_graph, false_graph],
+        [wrapped_true_intermediates, wrapped_false_intermediates])
+
+    true_graph.outputs.extend(extra_true_outputs)
+    false_graph.outputs.extend(extra_false_outputs)
+    _check_same_outputs(_COND, [true_graph, false_graph])
 
   # Create the If op.
   with ops.control_dependencies(
@@ -245,7 +292,7 @@
   # Prevent fetching since the variant outputs can't be fetched directly.
   if_op.graph.prevent_fetching(if_op)
   return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
-                                            tensors)
+                                            tensors[:num_cond_outputs])
 
 
 def get_func_graphs(op):
@@ -275,8 +322,7 @@
           fdef, input_shapes)
     for external_t, internal_t in zip(inputs, func_graph.inputs):
       custom_gradient.copy_handle_data(external_t, internal_t)
-    func_graph.captures = collections.OrderedDict(zip(inputs,
-                                                      func_graph.inputs))
+    func_graph.reset_captures(zip(inputs, func_graph.inputs))
     # Link the op so that the gradient code can use it.
     func_graph._forward_cond = op
     return func_graph
@@ -379,11 +425,18 @@
       # `internal_captures` are not treated as intermediates and hence not added
       # to If op outputs. So we get the outer tensor corresponding to those
       # from the list of `external_captures`.
-      try:
-        t = t.graph._forward_cond.outputs[t.graph.outputs.index(t)]
-      except ValueError:
-        index = t.graph.internal_captures.index(t)
-        t = t.graph.external_captures[index]
+      for i, output in enumerate(t.graph.outputs):
+        if output is t:
+          t = t.graph._forward_cond.outputs[i]
+          break
+      else:
+        for i, output in enumerate(t.graph.internal_captures):
+          if output is t:
+            t = t.graph.external_captures[i]
+            break
+        else:
+          raise ValueError("Could not find external tensor capture {tensor} in "
+                           "captures or outputs".format(tensor=t))
 
       # Note: We rely on the capturing logic of the gradient If op graph to
       # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
@@ -395,12 +448,17 @@
 
 
 def _get_intermediates(func_graph):
-  """Returns all tensors in `func_graph` that aren't inputs or outputs."""
+  """Returns intermediate tensors of `func_graph` for gradient computation."""
   intermediates = []
   for op in func_graph.get_operations():
     for t in op.outputs:
       if t in func_graph.inputs: continue
       if t in func_graph.outputs: continue
+      if t.dtype is dtypes.resource:
+        continue
+      # Accumulating mutexes can cause deadlock.
+      if op.type == "MutexLock":
+        continue
       intermediates.append(t)
   return intermediates
 
@@ -465,16 +523,21 @@
     branch_graph. This is a deduped version of `sum(branch_inputs)`.
   """
   assert len(branch_graphs) == len(branch_inputs)
-  new_inputs = set()
+  added_inputs = set()
+  new_inputs = []
   for branch_in in branch_inputs:
-    new_inputs |= set(branch_in)
-  new_inputs = list(new_inputs)
+    for tensor in branch_in:
+      tensor_id = ops.tensor_id(tensor)
+      if tensor_id not in added_inputs:
+        added_inputs.add(tensor_id)
+        new_inputs.append(tensor)
 
   for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
-    branch_input_to_param = dict(zip(branch_in, branch_graph.inputs))
+    input_ids = [ops.tensor_id(t) for t in branch_in]
+    branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
     input_list = []
     for in_t in new_inputs:
-      param = branch_input_to_param.get(in_t, None)
+      param = branch_input_to_param.get(ops.tensor_id(in_t))
       if param is None:
         param = _create_dummy_input(branch_graph, in_t)
       input_list.append(param)
@@ -482,8 +545,7 @@
     branch_graph.inputs = input_list
 
     # Rewrite the FuncGraphs' state to reflect the new inputs.
-    branch_graph.captures = collections.OrderedDict(
-        zip(new_inputs, branch_graph.inputs))
+    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))
 
   return new_inputs
 
@@ -744,19 +806,20 @@
 
   def _capture_helper(self, tensor, name):
     if (tensor.graph is not self._forward_graph or
-        tensor in self._forward_graph.inputs or
-        tensor in self._forward_graph.outputs):
+        any(tensor is t for t in self._forward_graph.inputs) or
+        any(tensor is t for t in self._forward_graph.outputs)):
       return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
 
     if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
       # XLA does not yet support optionals, so capture intermediates directly.
       # TODO(skyewm,jpienaar): can XLA support optionals?
-      if tensor not in self.captures:
+      if tensor not in self.external_captures:
         self.xla_intermediates.append(tensor)
         self.op_needs_rewrite = True
       return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
 
-    captured_tensor = self._indirect_captures.get(tensor)
+    tensor_id = ops.tensor_id(tensor)
+    captured_tensor = self._indirect_captures.get(tensor_id)
     if captured_tensor is not None:
       return captured_tensor
 
@@ -778,7 +841,7 @@
       captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
           self._forward_graph.inputs[index], name)
     else:
-      if tensor not in self._wrapped_intermediates:
+      if tensor_id not in self._wrapped_intermediates:
         # If the gradient has already been computed for this If op, 'tensor' may
         # already be wrapped.
         for consumer in tensor.consumers():
@@ -791,15 +854,15 @@
           with self._forward_graph.as_default():
             optional = gen_dataset_ops.optional_from_value([tensor])
           self.op_needs_rewrite = True
-        self._wrapped_intermediates[tensor] = optional
+        self._wrapped_intermediates[tensor_id] = optional
 
-      optional = self._wrapped_intermediates[tensor]
+      optional = self._wrapped_intermediates[tensor_id]
       captured_optional = super(_CondGradFuncGraph,
                                 self)._capture_helper(optional, name)
       captured_tensor = gen_dataset_ops.optional_get_value(
           captured_optional, [tensor.dtype], [tensor.shape])[0]
 
-    self._indirect_captures[tensor] = captured_tensor
+    self._indirect_captures[tensor_id] = captured_tensor
     return captured_tensor
 
 
@@ -869,7 +932,7 @@
     # NOTE(bjp): if there are any active sessions, this modification to `op`
     # may make them unrunnable!
 
-    if control_flow_util.InXlaContext(ops.get_default_graph()):
+    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
       # XLA does not yet support optionals, so output intermediates directly and
       # make them match via FakeParams, which can be converted to zeros in XLA.
       # TODO(bjp,jpienaar): can XLA support optionals?
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 4f719086..7ad3d76 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -482,6 +482,12 @@
 
   elif shape is None:
     return var.shape
+  elif isinstance(shape, tensor_spec.TensorSpec):
+    if var.dtype != shape.dtype:
+      raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var))
+    return shape.shape
+  elif isinstance(shape, type_spec.TypeSpec):
+    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
   else:
     return shape
 
@@ -498,6 +504,8 @@
     A `TypeSpec` for `var`, consistent with the given shape.
   """
   if isinstance(shape, type_spec.TypeSpec):
+    if not shape.is_compatible_with(var):
+      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
     return shape
   elif not isinstance(shape, tensor_shape.TensorShape):
     raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r"
@@ -752,15 +760,6 @@
       return self._outer_context.GetWhileContext()
     return None
 
-  def _IsInOuterContext(self, op):
-    op_ctxt = util.GetOutputContext(op)
-    outer_ctxt = self.outer_context
-    while outer_ctxt != op_ctxt:
-      if outer_ctxt is None:
-        return False
-      outer_ctxt = outer_ctxt.outer_context
-    return True
-
   def _RemoveExternalControlEdges(self, op):
     """Remove any external control dependency on this op."""
     while_ctxt = self.GetWhileContext()
@@ -2273,9 +2272,22 @@
       for x in xs:
         inp_op = x.op.inputs[0].op
         control_inputs = graph._control_dependencies_for_inputs([inp_op])
-        outer_control_inputs = [
-            op for op in control_inputs if self._IsInOuterContext(op)
-        ]
+        outer_control_inputs = []
+        for op in control_inputs:
+          # We need to keep control inputs that are in any ancestor
+          # ControlFlowContext, and within outer WhileContext.
+          keep_as_control_input = True
+          op_ctxt = util.GetOutputContext(op)
+          outer_ctxt = self.outer_context
+          outer_while_context = (None if outer_ctxt is None else
+                                 outer_ctxt.GetWhileContext())
+          while outer_ctxt != op_ctxt:
+            if outer_ctxt is None or outer_ctxt == outer_while_context:
+              keep_as_control_input = False
+              break
+            outer_ctxt = outer_ctxt.outer_context
+          if keep_as_control_input:
+            outer_control_inputs.append(op)
         x.op._set_control_flow_context(self)
         x.op._add_control_inputs(outer_control_inputs)
         graph._record_op_seen_by_control_dependencies(x.op)
@@ -2640,6 +2652,13 @@
   ```
 
   """
+  if not callable(cond):
+    raise TypeError("cond must be callable.")
+  if not callable(body):
+    raise TypeError("body must be callable.")
+  if parallel_iterations < 1:
+    raise TypeError("parallel_iterations must be a positive integer.")
+
   # Always enable control flow v2 if building a function, regardless of toggle.
   executing_eagerly = context.executing_eagerly()
   if (util.EnableControlFlowV2(ops.get_default_graph()) and
@@ -2652,18 +2671,12 @@
         parallel_iterations=parallel_iterations,
         maximum_iterations=maximum_iterations,
         name=name,
-        return_same_structure=return_same_structure)
+        return_same_structure=return_same_structure,
+        back_prop=back_prop)
 
   with ops.name_scope(name, "while", loop_vars):
     if not loop_vars:
       raise ValueError("No loop variables provided")
-    if not callable(cond):
-      raise TypeError("cond must be callable.")
-    if not callable(body):
-      raise TypeError("body must be callable.")
-    if parallel_iterations < 1:
-      raise TypeError("parallel_iterations must be a positive integer.")
-
     try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
     if maximum_iterations is not None:
       maximum_iterations = ops.convert_to_tensor(
@@ -3265,7 +3278,111 @@
     return cond_v2.indexed_case(branch_index, branch_fns)
 
 
-@tf_export("case")
+@tf_export("case", v1=[])
+def case_v2(pred_fn_pairs,
+            default=None,
+            exclusive=False,
+            strict=False,
+            name="case"):
+  """Create a case operation.
+
+  See also `tf.switch_case`.
+
+  The `pred_fn_pairs` parameter is a list of pairs of size N.
+  Each pair contains a boolean scalar tensor and a python callable that
+  creates the tensors to be returned if the boolean evaluates to True.
+  `default` is a callable generating a list of tensors. All the callables
+  in `pred_fn_pairs` as well as `default` (if provided) should return the same
+  number and types of tensors.
+
+  If `exclusive==True`, all predicates are evaluated, and an exception is
+  thrown if more than one of the predicates evaluates to `True`.
+  If `exclusive==False`, execution stops at the first predicate which
+  evaluates to True, and the tensors generated by the corresponding function
+  are returned immediately. If none of the predicates evaluate to True, this
+  operation returns the tensors generated by `default`.
+
+  `tf.case` supports nested structures as implemented in
+  `tf.contrib.framework.nest`. All of the callables must return the same
+  (possibly nested) value structure of lists, tuples, and/or named tuples.
+  Singleton lists and tuples form the only exceptions to this: when returned by
+  a callable, they are implicitly unpacked to single values. This
+  behavior is disabled by passing `strict=True`.
+
+  @compatibility(v2)
+  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
+  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
+  dictionary.  Please use a list or a tuple instead.
+  @end_compatibility
+
+
+  **Example 1:**
+
+  Pseudocode:
+
+  ```
+  if (x < y) return 17;
+  else return 23;
+  ```
+
+  Expressions:
+
+  ```python
+  f1 = lambda: tf.constant(17)
+  f2 = lambda: tf.constant(23)
+  r = tf.case([(tf.less(x, y), f1)], default=f2)
+  ```
+
+  **Example 2:**
+
+  Pseudocode:
+
+  ```
+  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
+  if (x < y) return 17;
+  else if (x > z) return 23;
+  else return -1;
+  ```
+
+  Expressions:
+
+  ```python
+  def f1(): return tf.constant(17)
+  def f2(): return tf.constant(23)
+  def f3(): return tf.constant(-1)
+  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
+           default=f3, exclusive=True)
+  ```
+
+  Args:
+    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
+      returns a list of tensors.
+    default: Optional callable that returns a list of tensors.
+    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
+    strict: A boolean that enables/disables 'strict' mode; see above.
+    name: A name for this operation (optional).
+
+  Returns:
+    The tensors returned by the first pair whose predicate evaluated to True, or
+    those returned by `default` if none does.
+
+  Raises:
+    TypeError: If `pred_fn_pairs` is not a list/tuple.
+    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+               callable.
+  """
+  return _case_helper(
+      cond,
+      pred_fn_pairs,
+      default,
+      exclusive,
+      name,
+      allow_python_preds=False,
+      strict=strict)
+
+
+@tf_export(v1=["case"])
 def case(pred_fn_pairs,
          default=None,
          exclusive=False,
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 91ce63a..a32f33f 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -39,6 +39,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
@@ -644,7 +645,7 @@
     self._testShape(fn_true, fn_false, shape)
     self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only("b/138741991")
   def test_variable(self):
     shape = tensor_shape.TensorShape([])
     fn_true = lambda: variables.Variable(3.0)
@@ -792,7 +793,7 @@
     fn_false = lambda: ta.read(1)
     self._testShape(fn_true, fn_false, shape)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only("b/138741991")
   def test_list(self):
     shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
              tensor_shape.TensorShape([])]
@@ -1082,6 +1083,8 @@
   @test_util.disable_xla("Wants RunMetadata")
   def testParallelExecution(self):
     """Verify disjoint branches across while iterations are run in parallel."""
+    if control_flow_v2_toggles.control_flow_v2_enabled():
+      self.skipTest("b/138870290")
     if test.is_built_with_rocm():
       self.skipTest(
           "Disable subtest on ROCm due to missing Cholesky op support")
@@ -1288,7 +1291,7 @@
     # Expect a tuple since that is what the body returns.
     self.assertEqual(self.evaluate(r), (10,))
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only("Unsupported in cfv2")
   def testWhileLoopSameReturnShape_False(self):
     i = constant_op.constant(0)
     c = lambda i, _: math_ops.less(i, 10)
@@ -1367,6 +1370,12 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testAssertInFunction(self):
+    # TODO(fishx): Re-enable this test for GPU.
+    # NOTE(fishx): Disable this test for now because, in GPU, multiple errors
+    # will be thrown. But since the root cause error is marked as "derived"
+    # error. So it might be ignored.
+    if test_util.is_gpu_available():
+      self.skipTest("Skip GPU Test")
 
     @def_function.function
     def whiny(value):
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index a2e8a65..0f98418 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -26,9 +26,12 @@
 import os
 import traceback
 
+from tensorflow.python import tf2
 from tensorflow.python.platform import tf_logging as logging
 
-ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
+ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and
+                           os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
+                          os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 70ec9f3..fe953b4 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -20,36 +20,21 @@
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.util import tf_contextlib
 
+_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
 
-class CondBranchFuncGraph(FuncGraph):
-  """FuncGraph for branches of tf.cond().
-
-  This is used to distinguish cond branches from other functions.
-  """
-  pass
-
-
-class WhileCondFuncGraph(FuncGraph):
-  """FuncGraph for the condition of tf.while_loop().
-
-  This is used to distinguish while conditions from other functions.
-  """
-  pass
-
-
-class WhileBodyFuncGraph(FuncGraph):
-  """FuncGraph for the body of tf.while_loop().
-
-  This is used to distinguish while bodies from other functions.
-  """
-  pass
+CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
+WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
+WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph
 
 
 def in_defun():
@@ -226,3 +211,45 @@
     ops.get_default_graph()._set_control_flow_context(control_flow_context)
     yield
   # pylint: enable=protected-access
+
+
+def _is_tpu_strategy(strategy):
+  return (strategy is not None and
+          strategy.__class__.__name__.startswith("TPUStrategy"))
+
+
+def _is_building_keras_layer():
+  return base_layer_utils.call_context().layer is not None
+
+
+def output_all_intermediates():
+  """Whether to output all intermediates of a functional control flow op.
+
+  The default behavior is to output intermediates only when building a Keras
+  Layer in graph mode and that too when certain other conditions are met:
+  1. We do not output intermediates if the functional control flow op
+     is being built inside a FuncGraph which is not a If/While graph. This
+     guards against outputting intermediates in eager mode since keras adds
+     tensors to a FuncGraph named "keras_graph" in that case. Also because we
+     do not output intermediates of tf.function (since this feature is only for
+     backwards compatibility) outputting intermediates of functional control
+     flow ops built inside tf.function is of no value.
+  2. We do not output intermediates when the compilation is using XLA or for a
+     TPU.
+  3. We do not output intermediates when a single threaded executor is used
+     since that does not perform inlining and pruning.
+
+  Returns:
+    A bool telling whether to output all intermediates.
+  """
+  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
+    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+  if in_defun():
+    return False
+  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
+      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
+    return False
+  if (context.context().function_call_options.executor_type ==
+      "SINGLE_THREADED_EXECUTOR"):
+    return False
+  return _is_building_keras_layer()
diff --git a/tensorflow/python/ops/control_flow_v2_disable_test.py b/tensorflow/python/ops/control_flow_v2_disable_test.py
new file mode 100644
index 0000000..f6e3888
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_disable_test.py
@@ -0,0 +1,39 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that TF2_BEHAVIOR=1 and TF_ENABLE_CONTROL_FLOW_V2=0 disables cfv2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+os.environ["TF2_BEHAVIOR"] = "1"
+os.environ["TF_ENABLE_CONTROL_FLOW_V2"] = "0"
+
+from tensorflow.python import tf2  # pylint: disable=g-import-not-at-top
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ControlFlowV2DisableTest(test.TestCase):
+
+  def testIsDisabled(self):
+    self.assertTrue(tf2.enabled())
+    self.assertFalse(control_flow_util.ENABLE_CONTROL_FLOW_V2)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/ops/control_flow_v2_enable_test.py b/tensorflow/python/ops/control_flow_v2_enable_test.py
new file mode 100644
index 0000000..f29d4dc
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_enable_test.py
@@ -0,0 +1,38 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that TF2_BEHAVIOR=1 enables cfv2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+os.environ["TF2_BEHAVIOR"] = "1"
+
+from tensorflow.python import tf2  # pylint: disable=g-import-not-at-top
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ControlFlowV2EnableTest(test.TestCase):
+
+  def testIsEnabled(self):
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(control_flow_util.ENABLE_CONTROL_FLOW_V2)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/ops/control_flow_v2_func_graphs.py b/tensorflow/python/ops/control_flow_v2_func_graphs.py
new file mode 100644
index 0000000..1a96d39
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_func_graphs.py
@@ -0,0 +1,45 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FuncGraphs for V2 control flow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework.func_graph import FuncGraph
+
+
+class CondBranchFuncGraph(FuncGraph):
+  """FuncGraph for branches of tf.cond().
+
+  This is used to distinguish cond branches from other functions.
+  """
+  pass
+
+
+class WhileCondFuncGraph(FuncGraph):
+  """FuncGraph for the condition of tf.while_loop().
+
+  This is used to distinguish while conditions from other functions.
+  """
+  pass
+
+
+class WhileBodyFuncGraph(FuncGraph):
+  """FuncGraph for the body of tf.while_loop().
+
+  This is used to distinguish while bodies from other functions.
+  """
+  pass
diff --git a/tensorflow/python/ops/control_flow_v2_toggles.py b/tensorflow/python/ops/control_flow_v2_toggles.py
index bbd264f..9bae4e3 100644
--- a/tensorflow/python/ops/control_flow_v2_toggles.py
+++ b/tensorflow/python/ops/control_flow_v2_toggles.py
@@ -21,6 +21,7 @@
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -64,3 +65,30 @@
   Note: v2 control flow is always enabled inside of tf.function.
   """
   return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
+
+
+@tf_export(v1=["experimental.output_all_intermediates"])
+def output_all_intermediates(state):  # pylint: disable=invalid-name
+  """Whether to output all intermediates from functional control flow ops.
+
+  The "default" behavior to is to output all intermediates when using v2 control
+  flow inside Keras models in graph mode (possibly inside Estimators). This is
+  needed to support taking gradients of v2 control flow. In graph mode, Keras
+  can sometimes freeze the forward graph before the gradient computation which
+  does not work for v2 control flow since it requires updating the forward ops
+  to output the needed intermediates. We work around this by proactively
+  outputting the needed intermediates when building the forward pass itself.
+  Ideally any such extra tensors should be pruned out at runtime. However, if
+  for any reason this doesn't work for you or if you have an infernce-only model
+  you can turn this behavior off using
+  `tf.compat.v1.experimental.output_all_intermediates(False)`.
+
+  If with the default behavior you are still seeing errors of the form
+  "Connecting to invalid output X of source node Y which has Z outputs" try
+  setting `tf.compat.v1.experimental.output_all_intermediates(True)` and
+  please file an issue at https://github.com/tensorflow/tensorflow/issues.
+
+  Args:
+    state: True, False or None. None restores the default behavior.
+  """
+  control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = state  # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/control_flow_v2_toggles_test.py b/tensorflow/python/ops/control_flow_v2_toggles_test.py
new file mode 100644
index 0000000..78b63af
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_toggles_test.py
@@ -0,0 +1,44 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow_v2_toggles.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_toggles
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ControlFlowV2TogglesTest(test.TestCase):
+
+  def testOutputAllIntermediates(self):
+    self.assertIsNone(
+        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
+    control_flow_v2_toggles.output_all_intermediates(True)
+    self.assertTrue(
+        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
+    control_flow_v2_toggles.output_all_intermediates(False)
+    self.assertFalse(
+        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
+    control_flow_v2_toggles.output_all_intermediates(None)
+    self.assertIsNone(
+        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/ops/critical_section_ops.py b/tensorflow/python/ops/critical_section_ops.py
index 85d828c..16419e4 100644
--- a/tensorflow/python/ops/critical_section_ops.py
+++ b/tensorflow/python/ops/critical_section_ops.py
@@ -31,6 +31,7 @@
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -243,7 +244,7 @@
         # captured_resources is a list of resources that are directly
         # accessed only by ops created during fn(), not by any
         # ancestors of those ops in the graph.
-        captured_resources = set([
+        captured_resources = object_identity.ObjectIdentitySet([
             input_ for op in created_ops
             for input_ in op.inputs
             if input_.dtype == dtypes.resource
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 12b4feb..ec832c8 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -238,10 +238,15 @@
           "with `use_resource=False`.")
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
-  variables_in_tape = frozenset(tape.watched_variables()) - frozenset(args)
-  variables_in_subgraph = frozenset(get_dependent_variables(
-      input_ops=args, output_ops=result))
-  variables = list(variables_in_subgraph.union(variables_in_tape))
+  variables_in_tape = frozenset([
+      v.experimental_ref() for v in tape.watched_variables()
+  ]) - frozenset(v.experimental_ref() for v in args)
+  variables_in_subgraph = frozenset([
+      v.experimental_ref()
+      for v in get_dependent_variables(input_ops=args, output_ops=result)
+  ])
+  variables = list(
+      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])
 
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
   variables_in_signature = ("variables" in grad_argspec.args or
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 3825332..f9e0f23 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import hashlib
 import threading
 
@@ -41,6 +40,7 @@
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_data_flow_ops import *
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 # pylint: enable=wildcard-import
@@ -64,7 +64,7 @@
   """Convert shapes to a list of tuples of int (or None)."""
   del dtypes
   if unknown_dim_allowed:
-    if (not isinstance(shapes, collections.Sequence) or not shapes or
+    if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
         any(shape is None or isinstance(shape, int) for shape in shapes)):
       raise ValueError(
           "When providing partial shapes, a list of shapes must be provided.")
@@ -1518,6 +1518,40 @@
         values=return_val.values,
         dense_shape=return_val.shape)
 
+  # SparseConditionalAccumulator is not switched to resource. Use old kernels.
+  def num_accumulated(self, name=None):
+    """Number of gradients that have currently been aggregated in accumulator.
+
+    Args:
+      name: Optional name for the operation.
+
+    Returns:
+      Number of accumulated gradients currently in accumulator.
+    """
+    if name is None:
+      name = "%s_NumAccumulated" % self._name
+
+    return gen_data_flow_ops.accumulator_num_accumulated(
+        self._accumulator_ref, name=name)
+
+  def set_global_step(self, new_global_step, name=None):
+    """Sets the global time step of the accumulator.
+
+    The operation logs a warning if we attempt to set to a time step that is
+    lower than the accumulator's own time step.
+
+    Args:
+      new_global_step: Value of new time step. Can be a variable or a constant
+      name: Optional name for the operation.
+
+    Returns:
+      Operation that sets the accumulator's time step.
+    """
+    return gen_data_flow_ops.accumulator_set_global_step(
+        self._accumulator_ref,
+        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
+        name=name)
+
 
 class BaseStagingArea(object):
   """Base class for Staging Areas."""
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index e4c7087..9c437d3 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -113,7 +113,9 @@
   Raises:
     ValueError: If `params` is empty.
   """
-  if params is None or params in ((), []):
+  if params is None:
+    raise ValueError("params must be specified")
+  if isinstance(params, (list, tuple)) and not params:
     raise ValueError("Need at least one param")
   if isinstance(params, variables.PartitionedVariable):
     params = list(params)  # Iterate to get the underlying Variables.
diff --git a/tensorflow/python/ops/gradient_checker_test.py b/tensorflow/python/ops/gradient_checker_test.py
index a1e1b7a..92ca9c2 100644
--- a/tensorflow/python/ops/gradient_checker_test.py
+++ b/tensorflow/python/ops/gradient_checker_test.py
@@ -60,7 +60,7 @@
       # checking gradients for x1
       error = gradient_checker.compute_gradient_error(x1, size, y, size)
     tf_logging.info("x1 error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   @test_util.run_deprecated_v1
   def testAddSimpleGPU(self):
@@ -75,7 +75,7 @@
       # checking gradients for x1
       error = gradient_checker.compute_gradient_error(x1, size, y, size)
     tf_logging.info("x1 error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   @test_util.run_deprecated_v1
   def testAddCustomized(self):
@@ -94,7 +94,7 @@
       error = gradient_checker.compute_gradient_error(
           x2, size, y, size, x_init_value=x_init_value, delta=1e-2)
     tf_logging.info("x2 error = %f", error)
-    assert error < 1e-10
+    self.assertLess(error, 1e-10)
 
   @test_util.run_deprecated_v1
   def testGather(self):
@@ -112,7 +112,7 @@
       error = gradient_checker.compute_gradient_error(params, p_shape, y,
                                                       y_shape)
     tf_logging.info("gather error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   @test_util.run_deprecated_v1
   def testNestedGather(self):
@@ -134,7 +134,7 @@
       error = gradient_checker.compute_gradient_error(params, p_shape, y2,
                                                       y2_shape)
     tf_logging.info("nested gather error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   @test_util.run_deprecated_v1
   def testComplexMul(self):
diff --git a/tensorflow/python/ops/gradient_checker_v2_test.py b/tensorflow/python/ops/gradient_checker_v2_test.py
index 191b2b6..d1205c3 100644
--- a/tensorflow/python/ops/gradient_checker_v2_test.py
+++ b/tensorflow/python/ops/gradient_checker_v2_test.py
@@ -54,7 +54,7 @@
     error = gradient_checker.max_error(*gradient_checker.compute_gradient(
         lambda x1: math_ops.add(x1, x2), [x1]))
     tf_logging.info("x1 error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   def testAddCustomized(self):
     size = (2, 3)
@@ -66,7 +66,7 @@
         lambda x2: math_ops.add(x1, x2),
         [x2], delta=1e-2))
     tf_logging.info("x2 error = %f", error)
-    assert error < 1e-10
+    self.assertLess(error, 1e-10)
 
   def testGather(self):
     def f(params):
@@ -80,7 +80,7 @@
     error = gradient_checker.max_error(*gradient_checker.compute_gradient(
         f, [params]))
     tf_logging.info("gather error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   def testNestedGather(self):
     def f(params):
@@ -97,7 +97,7 @@
     error = gradient_checker.max_error(*gradient_checker.compute_gradient(
         f, [params]))
     tf_logging.info("nested gather error = %f", error)
-    assert error < 1e-4
+    self.assertLess(error, 1e-4)
 
   def testComplexMul(self):
     c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 84a21d0..231d958 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -43,6 +43,8 @@
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import object_identity
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -255,7 +257,8 @@
   """
   # While ops have inputs added to them during the gradient computation, so we
   # skip the below check. See while_v2 for details.
-  if op.type == "While": return
+  if op.type == "While" or op.type == "StatelessWhile":
+    return
 
   if len(grads) != len(op.inputs):
     raise ValueError("Num gradients %d generated for op %s do not match num "
@@ -400,7 +403,7 @@
     return func_graph.captures
   else:
     assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
-    return func_graph._captured  # pylint: disable=protected-access
+    return func_graph._captured.items()  # pylint: disable=protected-access
 
 
 def _MaybeCaptured(t):
@@ -415,8 +418,8 @@
   # pylint: disable=protected-access
   if (not isinstance(t, ops.EagerTensor) and
       _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
-    for input_t, placeholder_t in _Captures(t.op.graph).items():
-      if t == placeholder_t:
+    for input_t, placeholder_t in _Captures(t.op.graph):
+      if t is placeholder_t:
         return _MaybeCaptured(input_t)
   # pylint: enable=protected-access
   return t
@@ -452,6 +455,7 @@
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
+  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -460,7 +464,7 @@
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in xs:
+      if t not in tensors:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -481,8 +485,8 @@
   """
   consumers = t.consumers()
   for func in func_graphs:
-    for input_t, placeholder in _Captures(func).items():
-      if input_t == t:
+    for input_t, placeholder in _Captures(func):
+      if input_t is t:
         consumers.extend(_Consumers(placeholder, func_graphs))
   return consumers
 
@@ -728,7 +732,7 @@
   for out_grad in out_grads:
     if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
       return True
-    if out_grad and isinstance(out_grad, collections.Sequence):
+    if out_grad and isinstance(out_grad, collections_abc.Sequence):
       if any(g is not None for g in out_grad):
         return True
   return False
@@ -953,11 +957,10 @@
         assert control_flow_util.IsLoopSwitch(op)
         continue
     # Grads have to be Tensors or IndexedSlices
-    if (isinstance(out_grad, collections.Sequence) and not all(
+    if (isinstance(out_grad, collections_abc.Sequence) and not all(
         isinstance(g, (ops.Tensor, ops.IndexedSlices))
         for g in out_grad
-        if g is not None
-    )):
+        if g is not None)):
       raise TypeError("gradients have to be either all Tensors "
                       "or all IndexedSlices")
     # Aggregate multiple gradients, and convert [] to None.
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index f47afb9..716c5e9 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -2015,6 +2015,16 @@
 
   Returns:
     Adjusted image(s), same shape and DType as `image`.
+  
+  Usage Example:
+    ```python
+    >> import tensorflow as tf
+    >> x = tf.random.normal(shape=(256, 256, 3))
+    >> tf.image.adjust_jpeg_quality(x, 75)
+    ```
+  Raises:
+    InvalidArgumentError: quality must be in [0,100]
+    InvalidArgumentError: image must have 1 or 3 channels
   """
   with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name:
     image = ops.convert_to_tensor(image, name='image')
@@ -3424,7 +3434,40 @@
   Returns:
     Pair of tensors (dy, dx) holding the vertical and horizontal image
     gradients (1-step finite difference).
-
+    
+  Usage Example:
+    ```python
+    BATCH_SIZE = 1
+    IMAGE_HEIGHT = 5
+    IMAGE_WIDTH = 5
+    CHANNELS = 1
+    image = tf.reshape(tf.range(IMAGE_HEIGHT * IMAGE_WIDTH * CHANNELS, 
+      delta=1, dtype=tf.float32), 
+      shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS))
+    dx, dy = tf.image.image_gradients(image)
+    print(image[0, :,:,0])
+    tf.Tensor(
+      [[ 0.  1.  2.  3.  4.]
+      [ 5.  6.  7.  8.  9.]
+      [10. 11. 12. 13. 14.]
+      [15. 16. 17. 18. 19.]
+      [20. 21. 22. 23. 24.]], shape=(5, 5), dtype=float32)
+    print(dx[0, :,:,0])
+    tf.Tensor(
+      [[5. 5. 5. 5. 5.]
+      [5. 5. 5. 5. 5.]
+      [5. 5. 5. 5. 5.]
+      [5. 5. 5. 5. 5.]
+      [0. 0. 0. 0. 0.]], shape=(5, 5), dtype=float32)    
+    print(dy[0, :,:,0])
+    tf.Tensor(
+      [[1. 1. 1. 1. 0.]
+      [1. 1. 1. 1. 0.]
+      [1. 1. 1. 1. 0.]
+      [1. 1. 1. 1. 0.]
+      [1. 1. 1. 1. 0.]], shape=(5, 5), dtype=float32)
+    ```
+    
   Raises:
     ValueError: If `image` is not a 4D tensor.
   """
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index f323d22..28db4ed 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -939,8 +939,6 @@
 
   def _to_dense(self):
     """Generic and often inefficient implementation.  Override often."""
-    logging.warn("Using (possibly slow) default implementation of to_dense."
-                 "  Converts by self.matmul(identity).")
     if self.batch_shape.is_fully_defined():
       batch_shape = self.batch_shape
     else:
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 3d1e1fc..30399bd 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -178,7 +178,7 @@
     raise NotImplementedError("make_x is not defined.")
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     """List of test names to skip."""
     # Subclasses should over-ride if they want to skip some tests.
     # To skip "test_foo", add "foo" to this list.
@@ -569,7 +569,7 @@
   ]
 
   for name, test_template_fn in test_name_dict.items():
-    if name in test_cls.tests_to_skip():
+    if name in test_cls.skip_these_tests():
       continue
 
     for dtype, use_placeholder, shape_info in itertools.product(
@@ -674,7 +674,7 @@
   """
 
   @staticmethod
-  def tests_to_skip():
+  def skip_these_tests():
     """List of test names to skip."""
     return [
         "cholesky",
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index a1f1505..b2ea5c4 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -1025,7 +1025,7 @@
           ids = self._table.lookup(values)
           buckets = math_ops.add(buckets, self._table.size())
           is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
-          ids = array_ops.where(is_id_non_default, ids, buckets)
+          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
         else:
           ids = buckets
     if isinstance(keys, sparse_tensor.SparseTensor):
@@ -1199,7 +1199,7 @@
         ids = self._table.lookup(values)
         buckets = math_ops.add(buckets, self._table.size())
         is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
-        ids = array_ops.where(is_id_non_default, ids, buckets)
+        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
       else:
         ids = buckets
     if isinstance(keys, sparse_tensor.SparseTensor):
diff --git a/tensorflow/python/ops/losses/loss_reduction.py b/tensorflow/python/ops/losses/loss_reduction.py
index 483a325..7fdc791 100644
--- a/tensorflow/python/ops/losses/loss_reduction.py
+++ b/tensorflow/python/ops/losses/loss_reduction.py
@@ -28,10 +28,10 @@
      used with `tf.distribute.Strategy`, outside of built-in training loops such
      as `tf.keras` `compile` and `fit`, we expect reduction value to be
      `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
-  * `NONE`: Un-reduced weighted losses with the same shape as input. When this
-    reduction type used with built-in Keras training loops like
-    `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but
-    the reported loss will be a scalar value.
+  * `NONE`: Weighted losses with one dimension reduced (axis=-1, or axis
+     specified by loss function). When this reduction type used with built-in
+     Keras training loops like `fit`/`evaluate`, the unreduced vector loss is
+     passed to the optimizer but the reported loss will be a scalar value.
   * `SUM`: Scalar sum of weighted losses.
   * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
      This reduction type is not supported when used with
@@ -42,7 +42,7 @@
      ```
      with strategy.scope():
        loss_obj = tf.keras.losses.CategoricalCrossentropy(
-           reduction=tf.keras.losses.Reduction.None)
+           reduction=tf.keras.losses.Reduction.NONE)
        ....
        loss = tf.reduce_sum(loss_object(labels, predictions)) *
            (1. / global_batch_size)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 31e5895..3d6a915 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -52,6 +53,79 @@
 ops.NotDifferentiable("EuclideanNorm")
 
 
+def SmartBroadcastGradientArgs(x, y, grad):
+  """Optimized version of `broadcast_gradient_args` that caches results.
+
+  This implementation avoids creating `broadcast_gradient_args` ops in the case
+  that the input shapes are fully defined, and provides hints to the calling
+  code that can be used to avoid creating reduction and reshaping ops.
+
+  Args:
+    x: The left input tensor to a broadcasting binary op.
+    y: The right input tensor to a broadcasting binary op.
+    grad: The incoming gradient tensor for a broadcasting binary op.
+
+  Returns:
+    A pair of tuples, containing:
+      * A 3-tuple of broadcast information for x, containing:
+        * The shape of x (as a tuple or Tensor).
+        * The reduction indices for x (as a tuple or Tensor).
+        * A boolean, which if True, indicates that x's shape differs from grad's
+          shape (and so x's gradient must be reduced and/or reshaped).
+      * A 3-tuple of broadcast information for y, containing the respective
+        details for y.
+  """
+  # NOTE: It may be productive to apply these optimizations in the eager case
+  # as well.
+  if context.executing_eagerly() or not (
+      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
+      and isinstance(grad, ops.Tensor)):
+    sx = array_ops.shape(x)
+    sy = array_ops.shape(y)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  # pylint: disable=protected-access
+  x_shape_tuple = x._shape_tuple()
+  y_shape_tuple = y._shape_tuple()
+  grad_shape_tuple = grad._shape_tuple()
+  # pylint: enable=protected-access
+
+  if (x_shape_tuple is None or None in x_shape_tuple or
+      y_shape_tuple is None or None in y_shape_tuple):
+    sx = array_ops.shape_internal(x, optimize=False)
+    sy = array_ops.shape_internal(y, optimize=False)
+    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+    return (sx, rx, True), (sy, ry, True)
+
+  x_needs_reduction = x_shape_tuple != grad_shape_tuple
+  y_needs_reduction = y_shape_tuple != grad_shape_tuple
+
+  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
+  # `grad.graph`, because these may be eager tensors.
+  g = ops.get_default_graph()
+
+  try:
+    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
+    return (x_shape_tuple, rx, x_needs_reduction), (
+        y_shape_tuple, ry, y_needs_reduction)
+  except KeyError:
+    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
+    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
+    # `TF_TryEvaluateConstant()`.
+    rx_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        rx.graph._c_graph, rx._as_tf_output()))  # pylint: disable=protected-access
+    assert rx_value is not None
+    ry_value = tuple(c_api.TF_TryEvaluateConstant_wrapper(
+        ry.graph._c_graph, ry._as_tf_output()))  # pylint: disable=protected-access
+    assert ry_value is not None
+    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
+        rx_value, ry_value)
+
+    return (x_shape_tuple, rx_value, x_needs_reduction), (
+        y_shape_tuple, ry_value, y_needs_reduction)
+
+
 _empty_tuple = ()
 
 
@@ -85,6 +159,37 @@
         else:
           input_shape = array_ops.shape(op.inputs[0])
         return [array_ops.tile(grad, input_shape), None]
+      elif None not in input_0_shape and not context.executing_eagerly():
+        # The shape and reduction indices are statically known, so we use a
+        # graph-level cache to avoid recomputing `reduced_shape()` for each
+        # invocation.
+        graph = ops.get_default_graph()
+
+        # Canonicalize `axes` to be a tuple of indices. The incoming
+        # value may be a scalar or a vector, and may include negative indices.
+        axes = tuple(axes.reshape(-1))
+
+        try:
+          output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[  # pylint: disable=protected-access
+              (input_0_shape, axes)]
+        except KeyError:
+
+          # Compute and cache `output_shape_kept_dims` and `tile_scaling`.
+          def EvaluateAsTuple(t):
+            value = c_api.TF_TryEvaluateConstant_wrapper(
+                t.graph._c_graph, t._as_tf_output())  # pylint: disable=protected-access
+            assert value is not None
+            return tuple(value)
+
+          output_shape_kept_dims = EvaluateAsTuple(
+              math_ops.reduced_shape(input_0_shape, axes))
+          tile_scaling = EvaluateAsTuple(
+              _safe_shape_div(input_0_shape, output_shape_kept_dims))
+          graph._reduced_shape_cache[(input_0_shape, axes)] = (  # pylint:disable=protected-access
+              output_shape_kept_dims, tile_scaling)
+
+        grad = array_ops.reshape(grad, output_shape_kept_dims)
+        return [array_ops.tile(grad, tile_scaling), None]
 
   input_shape = array_ops.shape(op.inputs[0])
   # TODO(apassos) remove this once device placement for eager ops makes more
@@ -1000,55 +1105,96 @@
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   if skip_input_indices is not None and 0 in skip_input_indices:
     gx = None
+  elif not must_reduce_x:
+    gx = grad
   else:
     gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
   if skip_input_indices is not None and 1 in skip_input_indices:
     gy = None
+  elif not must_reduce_y:
+    gy = grad
   else:
     gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
   return (gx, gy)
 
 
-
 @ops.RegisterGradient("Sub")
 def _SubGrad(op, grad):
   """Gradient for Sub."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return grad, None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
     return grad, -grad
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
-  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
-          array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = grad
+  else:
+    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = -grad
+  else:
+    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("Mul")
 def _MulGrad(op, grad):
   """The gradient of scalar multiplication."""
-  x = op.inputs[0]
   y = op.inputs[1]
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      return gen_math_ops.mul(grad, math_ops.conj(y)), None
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+  x = op.inputs[0]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad) and
       grad.dtype in (dtypes.int32, dtypes.float32)):
     return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy))
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif not must_reduce_x:
+    gx = gen_math_ops.mul(grad, y)
+  else:
+    gx = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif not must_reduce_y:
+    gy = gen_math_ops.mul(x, grad)
+  else:
+    gy = array_ops.reshape(
+        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
+  return (gx, gy)
 
 
 @ops.RegisterGradient("MulNoNan")
@@ -1181,35 +1327,61 @@
   """Returns grad * (y*x^(y-1), z*log(x))."""
   x = op.inputs[0]
   y = op.inputs[1]
-  z = op.outputs[0]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  use_mul_no_nan = compat.forward_compatible(2019, 9, 14)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
+    # constant `1` into a single constant op.
+    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
+        y):
+      x = math_ops.conj(x)
+      y = math_ops.conj(y)
+      if use_mul_no_nan:
+        return gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), None
+      else:
+        return grad * y * math_ops.pow(x, y - 1), None
+
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  z = math_ops.conj(z)
 
-  if compat.forward_compatible(2019, 9, 14):
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(
-            gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad), rx), sx)
+  if skip_input_indices is None or 0 not in skip_input_indices:
+    if use_mul_no_nan:
+      gx = gen_math_ops.mul_no_nan(y * math_ops.pow(x, y - 1), grad)
+    else:
+      gx = grad * y * math_ops.pow(x, y - 1)
+    if must_reduce_x:
+      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
   else:
-    gx = array_ops.reshape(
-        math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
-  # Avoid false singularity at x = 0
-  if x.dtype.is_complex:
-    # real(x) < 0 is fine for the complex case
-    mask = math_ops.not_equal(x, 0)
+    gx = None
+
+  if skip_input_indices is None or 1 not in skip_input_indices:
+    z = math_ops.conj(op.outputs[0])
+
+    # Avoid false singularity at x = 0
+    if x.dtype.is_complex:
+      # real(x) < 0 is fine for the complex case
+      mask = math_ops.not_equal(x, 0)
+    else:
+      # There's no sensible real value to return if x < 0, so return 0
+      mask = x > 0
+    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
+    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
+    if use_mul_no_nan:
+      gy = gen_math_ops.mul_no_nan(z * log_x, grad)
+    else:
+      gy = grad * z * log_x
+    if must_reduce_y:
+      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
   else:
-    # There's no sensible real value to return if x < 0, so return 0
-    mask = x > 0
-  safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
-  log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
-  if compat.forward_compatible(2019, 9, 14):
-    gy = array_ops.reshape(
-        math_ops.reduce_sum(gen_math_ops.mul_no_nan(z * log_x, grad), ry), sy)
-  else:
-    gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
+    gy = None
+
   return gx, gy
 
 
@@ -1277,15 +1449,39 @@
   """Returns the gradient for (x-y)^2."""
   x = op.inputs[0]
   y = op.inputs[1]
-  sx = array_ops.shape(x)
-  sy = array_ops.shape(y)
-  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+  skip_input_indices = None
+  try:
+    skip_input_indices = op.skip_input_indices
+  except AttributeError:
+    # No gradient skipping, so do the full gradient computation
+    pass
+
   with ops.control_dependencies([grad]):
     # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
     # Tensor (not a number like 2.0) which causes it to convert to Tensor.
     x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
-  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
-          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
+
+  if (isinstance(grad, ops.Tensor) and
+      _ShapesFullySpecifiedAndEqual(x, y, grad)):
+    return x_grad, -x_grad
+
+  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
+      SmartBroadcastGradientArgs(x, y, grad))
+
+  if skip_input_indices is not None and 0 in skip_input_indices:
+    gx = None
+  elif must_reduce_x:
+    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
+  else:
+    gx = x_grad
+
+  if skip_input_indices is not None and 1 in skip_input_indices:
+    gy = None
+  elif must_reduce_y:
+    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
+  else:
+    gy = -x_grad
+  return (gx, gy)
 
 
 # Logical operations have no gradients.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index eb9d440..34eeb54 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1277,7 +1277,9 @@
 
 def tensor_equals(self, other):
   """Compares two tensors element-wise for equality."""
-  if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
+  g = getattr(self, "graph", None)
+  if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
+      (g is None or g._building_function)):  # pylint: disable=protected-access
     return gen_math_ops.equal(self, other)
   else:
     # In legacy graph mode, tensor equality is object equality
@@ -1349,20 +1351,9 @@
     start, limit = 0, start
 
   with ops.name_scope(name, "Range", [start, limit, delta]) as name:
-    # In case dtype is not none, cast start, limit, and delta directly.
-    # Otherwise pass to convert_to_tensor. This is to handle
-    # the situation with:
-    #   tf.range(tf.constant(5), dtype=tf.float32)
-    # which is comparable with:
-    #   np.arange(np.int(5), dtype=np.float32)
-    if dtype is not None:
-      start = cast(start, dtype=dtype, name="start")
-      limit = cast(limit, dtype=dtype, name="limit")
-      delta = cast(delta, dtype=dtype, name="delta")
-    else:
-      start = ops.convert_to_tensor(start, name="start")
-      limit = ops.convert_to_tensor(limit, name="limit")
-      delta = ops.convert_to_tensor(delta, name="delta")
+    start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+    limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
+    delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
 
     # infer dtype if not explicitly provided
     if dtype is None:
@@ -4069,3 +4060,35 @@
     for c in coeffs[1:]:
       p = c + p * x
     return p
+
+
+@tf_export("math.reciprocal_no_nan")
+def reciprocal_no_nan(x, name=None):
+  """Performs a safe reciprocal operation, element wise.
+
+  If a particular element is zero, the reciprocal for that element is
+  also set to zero.
+
+  For example:
+  ```python
+  x = tf.constant([2.0, 0.5, 0, 1], dtype=tf.float32)
+  tf.math.reciprocal_no_nan(x)  # [ 0.5, 2, 0.0, 1.0 ]
+  ```
+
+  Args:
+    x: A `Tensor` of type `float16`, `float32`, `float64` `complex64` or
+      `complex128`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of same shape and type as `x`.
+
+  Raises:
+    TypeError: x must be of a valid dtype.
+
+  """
+
+  with ops.name_scope(name, "reciprocal_no_nan", [x]) as scope:
+    x = ops.convert_to_tensor(x, name="x")
+    one = constant_op.constant(1, dtype=x.dtype.base_dtype, name="one")
+    return gen_math_ops.div_no_nan(one, x, name=scope)
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 68740b6..c8fd977 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -699,5 +699,37 @@
       a = array_ops.ones([1], dtype=dtypes.int32) + 1.0
       self.evaluate(a)
 
+
+class ReciprocalNoNanTest(test_util.TensorFlowTestCase):
+
+  allowed_dtypes = [
+      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
+      dtypes.complex128
+  ]
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBasic(self):
+    for dtype in self.allowed_dtypes:
+      x = constant_op.constant([1.0, 2.0, 0.0, 4.0], dtype=dtype)
+
+      y = math_ops.reciprocal_no_nan(x)
+
+      target = constant_op.constant([1.0, 0.5, 0.0, 0.25], dtype=dtype)
+
+      self.assertAllEqual(y, target)
+      self.assertEqual(y.dtype.base_dtype, target.dtype.base_dtype)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testInverse(self):
+    for dtype in self.allowed_dtypes:
+      x = np.random.choice([0, 1, 2, 4, 5], size=(5, 5, 5))
+      x = constant_op.constant(x, dtype=dtype)
+
+      y = math_ops.reciprocal_no_nan(math_ops.reciprocal_no_nan(x))
+
+      self.assertAllClose(y, x)
+      self.assertEqual(y.dtype.base_dtype, x.dtype.base_dtype)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index e978f1d..5f0616b 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -364,7 +364,7 @@
       if d in set(axes):
         count *= x.shape[d]
     if not keep_dims:
-      shift = np.squeeze(shift, axis=axis)
+      shift = np.asarray(shift)
     return count, m_ss, v_ss, shift
 
   def _opSuffStats(self, x, axes, shift, keep_dims):
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 8e02871..7d31604 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -513,6 +513,24 @@
   return vec * mat
 
 
+def _IsZero(tensor):
+  """Check if tensor contains only zeros.
+
+  Args:
+    tensor: tensor to check
+
+  Returns:
+    True if tensor contains only zeros and False otherwise
+  """
+  if context.executing_eagerly():
+    # TODO(apassos) add an efficient way to detect eager zeros here.
+    return False
+  if tensor.op.type in ("ZerosLike", "Zeros"):
+    return True
+  const_fill_value = tensor_util.constant_value(tensor)
+  return const_fill_value is not None and (const_fill_value == 0).all()
+
+
 @ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
 def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   """Gradient function for SoftmaxCrossEntropyWithLogits."""
@@ -524,18 +542,8 @@
   softmax_grad = op.outputs[1]
   grad = _BroadcastMul(grad_loss, softmax_grad)
 
-  def IsZero(g):
-    # Some introspection to check if the gradient is feeding zeros
-    if context.executing_eagerly():
-      # TODO(apassos) add an efficient way to detect eager zeros here.
-      return False
-    if g.op.type in ("ZerosLike", "Zeros"):
-      return True
-    const_fill_value = tensor_util.constant_value(g)
-    return const_fill_value is not None and (const_fill_value == 0).all()
-
   logits = op.inputs[0]
-  if grad_grad is not None and not IsZero(grad_grad):
+  if grad_grad is not None and not _IsZero(grad_grad):
     softmax = nn_ops.softmax(logits)
 
     grad += ((grad_grad - array_ops.squeeze(
@@ -548,22 +556,28 @@
 
 
 @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
-def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
+def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
-  # grad_0 is the backprop for cost, and we multiply it with the gradients
+  # grad_loss is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
+  # grad_grad is the backprop for softmax gradient.
   # There is no gradient for the labels
   #
-  # Currently there is no way to take the second derivative of this op
-  # due to the fused implementation's interaction with tf.gradients(),
-  # so we make sure we prevent silently incorrect results by raising
-  # an error if the second derivative is requested via prevent_gradient.
-  sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
-      op.outputs[1],
-      message="Currently there is no way to take the second "
-      "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
-      "implementation's interaction with tf.gradients()")
-  return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
+  # Second derivative is just softmax derivative w.r.t. logits.
+  softmax_grad = op.outputs[1]
+  grad = _BroadcastMul(grad_loss, softmax_grad)
+
+  logits = op.inputs[0]
+  if grad_grad is not None and not _IsZero(grad_grad):
+    softmax = nn_ops.softmax(logits)
+
+    grad += ((grad_grad - array_ops.squeeze(
+        math_ops.matmul(
+            array_ops.expand_dims(grad_grad, 1),
+            array_ops.expand_dims(softmax, 2)),
+        axis=1)) * softmax)
+
+  return grad, None
 
 
 @ops.RegisterGradient("Conv2D")
@@ -614,15 +628,17 @@
           array_ops.shape(op.inputs[0]),
           op.inputs[1],
           grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format")),
       nn_ops.depthwise_conv2d_native_backprop_filter(
           op.inputs[0],
           array_ops.shape(op.inputs[1]),
           grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format"))
   ]
 
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 216c575..1435a0c 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -715,6 +715,22 @@
     return array_ops.identity(zero_fraction_float32, "fraction")
 
 
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while context is not None and not isinstance(
+      context, control_flow_ops.XLAControlFlowContext):
+    context = context.outer_context
+  return context
+
+
+# copybara:strip_end
+
+
 # pylint: disable=redefined-builtin
 @tf_export(v1=["nn.depthwise_conv2d"])
 def depthwise_conv2d(input,
@@ -774,6 +790,25 @@
     if rate is None:
       rate = [1, 1]
 
+    # copybara:strip_begin
+    # TODO(b/138808492): Remove code inside copybara
+    # to make TPU code and CPU code consistent.
+    # Use depthwise_conv2d_native if executing on TPU.
+    if _enclosing_tpu_context() is not None:
+      if data_format == "NCHW":
+        dilations = [1, 1, rate[0], rate[1]]
+      else:
+        dilations = [1, rate[0], rate[1], 1]
+      return nn_ops.depthwise_conv2d_native(
+          input=input,
+          filter=filter,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+    # copybara:strip_end
+
     def op(input_converted, _, padding):
       return nn_ops.depthwise_conv2d_native(
           input=input_converted,
diff --git a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
index 3578fb0..cf2a7d2 100644
--- a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
+++ b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
@@ -184,3 +184,7 @@
           RuntimeError, "You are calling `scale_regularization_loss` in "
           "cross replica context"):
         nn_impl.scale_regularization_loss([2, 3])
+
+
+if __name__ == "__main__":
+  test_lib.main()
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f5e9aea..17e6bed 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -33,6 +33,10 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+from tensorflow.python.ops import control_flow_ops
+# copybara:strip_end
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -42,6 +46,7 @@
 # pylint: enable=wildcard-import
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.deprecation import deprecated_args
 from tensorflow.python.util.deprecation import deprecated_argument_lookup
 
@@ -57,7 +62,7 @@
   """Formats a value input for gen_nn_ops."""
   if value is None:
     value = [1]
-  elif not isinstance(value, collections.Sized):
+  elif not isinstance(value, collections_abc.Sized):
     value = [value]
 
   current_n = len(value)
@@ -280,7 +285,7 @@
       tensor. Must be: `[1, stride_height, stride_width, 1]`.
     padding: A `string` from: `"SAME", "VALID"`.
       The type of padding algorithm to use.
-    data_format: A `string`, only `"NCHW"` is currently supported.
+    data_format: A `string`, only `"NHWC"` is currently supported.
     dilations: A list of `ints` that has length `>= 4`.
       The input stride for atrous morphological dilation. Must be:
       `[1, rate_height, rate_width, 1]`.
@@ -289,8 +294,8 @@
   Returns:
     A `Tensor`. Has the same type as `input`.
   """
-  if data_format != "NCHW":
-    raise ValueError("Data formats other than NCHW are not yet supported")
+  if data_format != "NHWC":
+    raise ValueError("Data formats other than NHWC are not yet supported")
 
   return gen_nn_ops.dilation2d(input=input,
                                filter=filters,
@@ -918,6 +923,22 @@
     "filter", "filters")
 
 
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  run_context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while run_context is not None and not isinstance(
+      run_context, control_flow_ops.XLAControlFlowContext):
+    run_context = run_context.outer_context
+  return run_context
+
+
+# copybara:strip_end
+
+
 def convolution_internal(
     input,  # pylint: disable=redefined-builtin
     filters,
@@ -925,40 +946,58 @@
     padding="VALID",
     data_format=None,
     dilations=None,
-    name=None):
+    name=None,
+    call_from_convolution=True):
   """Internal function which performs rank agnostic convolution."""
-  with ops.name_scope(name, "convolution", [input, filters]) as name:
-    if isinstance(input.shape, tensor_shape.TensorShape) and \
+  if isinstance(input.shape, tensor_shape.TensorShape) and \
         input.shape.rank is not None:
-      n = len(input.shape) - 2
-    elif not isinstance(input.shape, tensor_shape.TensorShape) and \
+    n = len(input.shape) - 2
+  elif not isinstance(input.shape, tensor_shape.TensorShape) and \
         input.shape is not None:
-      n = len(input.shape) - 2
-    elif isinstance(filters.shape, tensor_shape.TensorShape) and \
+    n = len(input.shape) - 2
+  elif isinstance(filters.shape, tensor_shape.TensorShape) and \
         filters.shape.rank is not None:
-      n = len(filters.shape) - 2
-    elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
+    n = len(filters.shape) - 2
+  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
         filters.shape is not None:
-      n = len(filters.shape) - 2
-    else:
-      raise ValueError("rank of input or filter must be known")
+    n = len(filters.shape) - 2
+  else:
+    raise ValueError("rank of input or filter must be known")
 
-    if not 1 <= n <= 3:
-      raise ValueError(
-          "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+  if not 1 <= n <= 3:
+    raise ValueError(
+        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
 
-    if data_format is None:
-      channel_index = n + 1
-    else:
-      channel_index = 1 if data_format.startswith("NC") else n + 1
+  if data_format is None:
+    channel_index = n + 1
+  else:
+    channel_index = 1 if data_format.startswith("NC") else n + 1
 
-    strides = _get_sequence(strides, n, channel_index, "strides")
-    dilations = _get_sequence(dilations, n, channel_index, "dilations")
+  strides = _get_sequence(strides, n, channel_index, "strides")
+  dilations = _get_sequence(dilations, n, channel_index, "dilations")
 
+  # copybara:strip_begin
+  # TODO(b/138808492): Remove code inside copybara
+  # to make TPU code and CPU code consistent.
+  scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
+  if not call_from_convolution and _enclosing_tpu_context() is not None:
+    scope = scopes[n]
+  else:
+    scope = "convolution"
+  # copybara:strip_end
+  # copybara:insert scope = "convolution"
+
+  with ops.name_scope(name, scope, [input, filters]) as name:
     conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
 
-    if all(i == 1 for i in dilations):
-      # fast path if no dilation as gradient only supported on GPU for dilations
+    # copybara:strip_begin
+    # TODO(b/138808492): Remove code inside copybara
+    # to make TPU code and CPU code consistent.
+    if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
+      # fast path for TPU or if no dilation as gradient only supported on GPU
+      # for dilations
+    # copybara:strip_end
+    # copybara:insert if all(i == 1 for i in dilations):
       op = conv_ops[n]
       return op(
           input,
@@ -1055,7 +1094,9 @@
     self.filter_shape = filter_shape
     self.data_format = data_format
     self.strides = strides
+    self.padding = padding
     self.name = name
+    self.dilation_rate = dilation_rate
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1075,7 +1116,24 @@
         name=self.name)
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
-    return self.conv_op(inp, filter)
+    # copybara:strip_begin
+    # TODO(b/138808492): Remove code inside copybara
+    # to make TPU code and CPU code consistent.
+    # TPU convolution supports dilations greater than 1.
+    if _enclosing_tpu_context() is not None:
+      return convolution_internal(
+          inp,
+          filter,
+          strides=self.strides,
+          padding=self.padding,
+          data_format=self.data_format,
+          dilations=self.dilation_rate,
+          name=self.name,
+          call_from_convolution=False)
+    else:
+      return self.conv_op(inp, filter)
+    # copybara:strip_end
+    # copybara:insert return self.conv_op(inp, filter)
 
 
 @tf_export(v1=["nn.pool"])
@@ -2601,10 +2659,10 @@
   """
   with ops.name_scope(name, "conv_transpose",
                       [input, filter, output_shape]) as name:
-    if isinstance(output_shape, collections.Sized):
-      n = len(output_shape) - 2
-    elif isinstance(output_shape, ops.Tensor):
+    if tensor_util.is_tensor(output_shape):
       n = output_shape.shape[0] - 2
+    elif isinstance(output_shape, collections.Sized):
+      n = len(output_shape) - 2
     else:
       raise ValueError("output_shape must be a tensor or sized collection.")
 
@@ -2743,7 +2801,7 @@
 def leaky_relu(features, alpha=0.2, name=None):
   """Compute the Leaky ReLU activation function.
 
-  Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models. 
+  Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models.
   AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013](https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf).
 
   Args:
@@ -3595,8 +3653,8 @@
     ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
     strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
 
-    data_format = "NHWC" if data_format == "NWC" else "NCHW"
     expanding_dim = 1 if data_format == "NWC" else 2
+    data_format = "NHWC" if data_format == "NWC" else "NCHW"
 
     input = array_ops.expand_dims_v2(input, expanding_dim)
     result = gen_nn_ops.avg_pool(
@@ -3786,8 +3844,8 @@
     ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
     strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
 
-    data_format = "NHWC" if data_format == "NWC" else "NCHW"
     expanding_dim = 1 if data_format == "NWC" else 2
+    data_format = "NHWC" if data_format == "NWC" else "NCHW"
 
     input = array_ops.expand_dims_v2(input, expanding_dim)
     result = gen_nn_ops.max_pool(
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 71f16de..4763ae0 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1276,6 +1276,17 @@
 
     self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
 
+  def test1DNumpyWithGolden(self):
+    dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64
+    x = np.array([[[3], [6], [5]],
+                  [[1], [0], [1]]], dtype=dtype)
+    ksize = 2
+    strides = 1
+    y = nn_ops.avg_pool1d(x, ksize, strides, "SAME")
+    expected_y = np.array([[[4.5], [5.5], [5.0]],
+                           [[0.5], [0.5], [1.0]]], dtype=dtype)
+    self.assertAllEqual(self.evaluate(y), expected_y)
+
   def test2DTensor(self):
     x = array_ops.ones([3, 6, 6, 5])
     ksize = 2
@@ -1350,6 +1361,17 @@
 
     self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
 
+  def test1DNumpyWithGolden(self):
+    dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64
+    x = np.array([[[3], [6], [5]],
+                  [[1], [0], [1]]], dtype=dtype)
+    ksize = 2
+    strides = 1
+    y = nn_ops.max_pool1d(x, ksize, strides, "SAME")
+    expected_y = np.array([[[6], [6], [5]],
+                           [[1], [1], [1]]], dtype=dtype)
+    self.assertAllEqual(self.evaluate(y), expected_y)
+
   def test2DTensor(self):
     x = array_ops.ones([3, 6, 6, 5])
     ksize = 2
diff --git a/tensorflow/python/ops/op_selector.py b/tensorflow/python/ops/op_selector.py
index 68594a5..1ae43aa 100644
--- a/tensorflow/python/ops/op_selector.py
+++ b/tensorflow/python/ops/op_selector.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 from tensorflow.python.framework import ops
+from tensorflow.python.util import object_identity
 
 
 def is_differentiable(op):
@@ -275,11 +276,11 @@
   else:
     seed_ops = make_list_of_op(seed_ops, allow_graph=False)
 
-  stop_at_ts = frozenset(make_list_of_t(stop_at_ts))
-  seed_ops = frozenset(make_list_of_op(seed_ops))
+  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
+  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
   if within_ops:
     within_ops = make_list_of_op(within_ops, allow_graph=False)
-    within_ops = frozenset(within_ops)
+    within_ops = object_identity.ObjectIdentitySet(within_ops)
     seed_ops &= within_ops
 
   def is_within(op):
@@ -390,7 +391,7 @@
       sources and add_sources is False.
   """
   ops_to_visit = [_as_operation(init_tensor)]
-  extra_sources = set()
+  extra_sources = object_identity.ObjectIdentitySet()
   while ops_to_visit:
     op = ops_to_visit.pop()
     if op in visited_ops:
diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py
index 1d1bb1f..022b7c4 100644
--- a/tensorflow/python/ops/parallel_for/array_test.py
+++ b/tensorflow/python/ops/parallel_for/array_test.py
@@ -314,7 +314,7 @@
 
     def loop_fn(i):
       diagonal = array_ops.gather(x, i)
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
         return array_ops.matrix_diag(diagonal, k=(0, 1), num_rows=4, num_cols=5)
       return array_ops.matrix_diag(diagonal)
 
@@ -325,7 +325,7 @@
 
     def loop_fn(i):
       input = array_ops.gather(x, i)  # pylint: disable=redefined-builtin
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
         return array_ops.matrix_diag_part(input, k=(-2, 0), padding_value=3)
       return array_ops.matrix_diag_part(input)
 
@@ -335,7 +335,7 @@
     matrices = random_ops.random_uniform([3, 4, 4])
     diags = random_ops.random_uniform([3, 4])
     num_outputs = 3
-    if compat.forward_compatible(2019, 7, 31):
+    if compat.forward_compatible(2019, 8, 31):
       bands = random_ops.random_uniform([3, 3, 4])
       num_outputs = 6
 
@@ -347,7 +347,7 @@
           array_ops.matrix_set_diag(matrices[0, ...], diag_i),
           array_ops.matrix_set_diag(matrix_i, diags[0, ...])
       ]
-      if compat.forward_compatible(2019, 7, 31):
+      if compat.forward_compatible(2019, 8, 31):
         k = (-1, 1)
         band_i = array_ops.gather(bands, i)
         results.extend([
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index 7c56956..b515246 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -99,7 +99,12 @@
 
   output = [None if is_none else ta.concat()
             for ta, is_none in zip(ta_list, is_none_list)]
-  return nest.pack_sequence_as(loop_fn_dtypes, output)
+  assert len(output) in (0, len(flat_loop_fn_dtypes))
+  if not output:
+    # This may happen for the case where iters == 0.
+    return None
+  else:
+    return nest.pack_sequence_as(loop_fn_dtypes, output)
 
 
 def _flatten_first_two_dims(x):
@@ -306,42 +311,17 @@
 
 
   This method works similar to tf.map_fn but is optimized to run much faster,
-  but possibly with a much larger memory footprint. The speedups are obtained by
+  possibly with a much larger memory footprint. The speedups are obtained by
   vectorization (see https://arxiv.org/pdf/1903.04243.pdf). The idea behind
   vectorization is to semantically launch all the invocations of `fn` in
   parallel and fuse corresponding operations across all these invocations. This
   fusion is done statically at graph generation time and the generated code is
   often similar in performance to a manually fused version.
 
-
-  For example, let's look at a method that calculates the outer product of a
-  matrix.
-
-  ```python
-  def outer_product(a):
-    return tf.tensordot(a, a, 0)
-
-  # outer_product was designed to not support batching.
-  c = outer_product(tf.ones((2, 3)))
-  # The shape is consistent
-  assert c.shape == (2, 3, 2, 3)
-  ```
-
-  Now suppose we want an efficient batched version of outer_product. We can
-  simply write:
-
-  ```python
-  batch_size = 100
-  a = tf.ones((batch_size, 32, 32))
-  c = tf.vectorized_map(outer_product, a)
-  assert c.shape == (batch_size, 32, 32, 32, 32)
-   ```
-
   Because `tf.vectorized_map` fully parallelizes the batch, this method will
   generally be significantly faster than using `tf.map_fn`, especially in eager
-  mode.
-
-  This is an experimental feature and currently has a lot of limitations:
+  mode. However this is an experimental feature and currently has a lot of
+  limitations:
     - There should be no data dependency between the different semantic
       invocations of `fn`, i.e. it should be safe to map the elements of the
       inputs in any order.
@@ -352,8 +332,8 @@
       particular is not supported.
     - `fn` should return nested structure of Tensors or Operations. However
       if an Operation is returned, it should have zero outputs.
-    - The shape and dtype of `fn` outputs should not depend on the input
-      to `fn`.
+    - The shape and dtype of any intermediate or output tensors in the
+      computation of `fn` should not depend on the input to `fn`.
 
   Args:
     fn: The callable to be performed. It accepts one argument, which will have
@@ -368,6 +348,40 @@
     A tensor or (possibly nested) sequence of tensors. Each tensor packs the
     results of applying fn to tensors unpacked from elems along the first
     dimension, from first to last.
+
+  Examples:
+  ```python
+  def outer_product(a):
+    return tf.tensordot(a, a, 0)
+
+  batch_size = 100
+  a = tf.ones((batch_size, 32, 32))
+  c = tf.vectorized_map(outer_product, a)
+  assert c.shape == (batch_size, 32, 32, 32, 32)
+  ```
+
+  ```python
+  # Computing per-example gradients
+
+  batch_size = 10
+  num_features = 32
+  layer = tf.keras.layers.Dense(1)
+
+  def model_fn(arg):
+    with tf.GradientTape() as g:
+      inp, label = arg
+      inp = tf.expand_dims(inp, 0)
+      label = tf.expand_dims(label, 0)
+      prediction = layer(inp)
+      loss = tf.nn.l2_loss(label - prediction)
+    return g.gradient(loss, (layer.kernel, layer.bias))
+
+  inputs = tf.random_uniform([batch_size, num_features])
+  labels = tf.random_uniform([batch_size, 1])
+  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
+  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
+  assert per_example_gradients[1].shape == (batch_size, 1)
+  ```
   """
   def loop_fn(i):
     gathered_elems = nest.map_structure(lambda x: array_ops.gather(x, i), elems)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index ac45a89..d0de84c 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -37,6 +37,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.layers import core as keras_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import control_flow_ops
@@ -60,6 +61,7 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.with_control_flow_v2
 class PForTest(PForTestCase):
 
   def test_op_conversion_fallback_to_while_loop(self):
@@ -111,6 +113,37 @@
         compute, array_ops.ones((10, 5, 3)))
     self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
 
+  def test_vectorized_map_example_1(self):
+    def outer_product(a):
+      return math_ops.tensordot(a, a, 0)
+
+    batch_size = 100
+    a = array_ops.ones((batch_size, 32, 32))
+    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
+    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)
+
+  def test_vectorized_map_example_2(self):
+    batch_size = 10
+    num_features = 32
+    layer = keras_core.Dense(1)
+
+    def model_fn(arg):
+      with backprop.GradientTape() as g:
+        inp, label = arg
+        inp = array_ops.expand_dims(inp, 0)
+        label = array_ops.expand_dims(label, 0)
+        prediction = layer(inp)
+        loss = nn.l2_loss(label - prediction)
+      return g.gradient(loss, (layer.kernel, layer.bias))
+
+    inputs = random_ops.random_uniform([batch_size, num_features])
+    labels = random_ops.random_uniform([batch_size, 1])
+    per_example_gradients = pfor_control_flow_ops.vectorized_map(
+        model_fn, (inputs, labels))
+    self.assertAllEqual(per_example_gradients[0].shape,
+                        (batch_size, num_features, 1))
+    self.assertAllEqual(per_example_gradients[1].shape, (batch_size, 1))
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class IndexedSlicesTest(PForTestCase):
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index d880ddb..9bbc7c5 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -52,6 +52,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 
 flags.DEFINE_bool(
     "op_conversion_fallback_to_while_loop", False,
@@ -799,12 +800,63 @@
     ...
 
   The above will register conversion function `_foo_converter` for handling
-  conversion of `foo_op_type`. During conversion, the registered functin will be
-  called with a single argument of type `PForInput` which will contain state
-  needed for the conversion.  This registered function should output a list of
-  WrappedTensor object with the same length as the number of outputs of op being
-  converted. If the op had zero outputs, then it should return a ops.Operation
-  object.
+  conversion of `foo_op_type`. These converters are called during vectorization
+  of a `pfor` loop body. For each operation node in this loop body,
+  the vectorization process will call the converter corresponding to the
+  operation type of the node.
+
+  During conversion, the registered function will be called with a single
+  argument `pfor_input`, of type `PForInput`, which will contain state needed
+  for the conversion.  When the converter is called for a node, all its inputs
+  should already have been converted and these converted values are stored in
+  `pfor_input.inputs`.  This registered function should output a list of
+  WrappedTensor objects with the same length as the number of outputs of the
+  node being converted. If the node had zero outputs, then it should return an
+  ops.Operation object.  These new sets of nodes should implement the
+  functionality of running that operation for the number of iterations specified
+  by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
+  iteration are picked from `pfor_inputs.inputs()`.
+
+  One tricky aspect of the conversion process is keeping track of, and
+  leveraging loop invariance of computation. Each converted input is a
+  WrappedTensor which indicates whether the input was loop invariant or not. If
+  the converted value is loop invariant, its rank should match the rank of the
+  corresponding tensor in the loop body, else its rank is larger by 1. The
+  converter should look at the loop invariance of the inputs and generate new
+  nodes based on that. Note that the converter will not be called if all inputs
+  are loop invariant and the operation is not stateful. The converter should
+  determine if its own output is loop invariant and `wrap` its output
+  accordingly.
+
+  Example:
+
+  Here, the converter is trying to convert a Reshape node in the loop body. This
+  node will have two inputs: the tensor to reshape, and the new shape.  The
+  example here only handles the case where the shape is loop invariant.
+
+  @RegisterPFor("Reshape")
+  def _convert_reshape(pfor_input):
+    # We assume that input is not loop invariant. Call to `stacked_input`
+    # asserts that and returns the converted value. This value will have a rank
+    # larger by 1 compared to the rank of the input in the loop body.
+    t = pfor_input.stacked_input(0)
+
+    # We assume that shape input is loop invariant. Call to `unstacked_input`
+    # asserts that and returns the converted value.
+    shape = pfor_input.unstacked_input(1)
+
+    # We compute `new_shape` by prepending the number of iterations to the
+    # original shape.
+    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
+                                 axis=0)
+
+    # The vectorized output involves reshaping the converted input `t` using
+    # `new_shape`.
+    new_output = array_ops.reshape(t, new_shape)
+
+    # The converted output is marked as not loop invariant using the call to
+    # wrap.
+    return wrap(new_output, True)
   """
 
   def __init__(self, op_type):
@@ -1071,7 +1123,7 @@
     self.all_indices = (
         math_ops.range(loop_len) if all_indices is None else all_indices)
 
-    self._conversion_map = {}
+    self._conversion_map = object_identity.ObjectIdentityDictionary()
     self._conversion_map[loop_var] = wrap(self.all_indices, True)
     self._pfor_ops = set(pfor_ops)
     self._pfor_op_ids = set([x._id for x in pfor_ops])
@@ -1316,7 +1368,7 @@
                (not is_stateful and not some_input_converted and
                 not some_control_input_converted)) and
               y.graph == ops.get_default_graph()):
-          if y == y_op:
+          if y is y_op:
             assert not isinstance(y_op, WhileOp)
             new_outputs = y_op
           else:
@@ -1359,7 +1411,7 @@
         logging.vlog(2, "converted %s %s", y_op, new_outputs)
 
         # Insert into self._conversion_map
-        if y == y_op:
+        if y is y_op:
           assert isinstance(new_outputs, ops.Operation)
           self._add_conversion(y_op, new_outputs)
         else:
@@ -1400,6 +1452,10 @@
     """
     return self._all_indices_partitioned
 
+
+# The code below defines converters for different operations. Please see comment
+# for RegisterPFor to see how converters should be defined.
+
 # nn_ops
 
 
@@ -1909,7 +1965,7 @@
     if axis_value is not None:
       axis = axis_value
   if indices_stacked and not param_stacked:
-    if indices == pfor_input.pfor.all_indices and axis == 0:
+    if indices is pfor_input.pfor.all_indices and axis == 0:
       param_shape0 = param.shape.dims[0].value
       indices_shape0 = indices.shape.dims[0].value
       if param_shape0 is not None and indices_shape0 == param_shape0:
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 2e0b688..1aade2c 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -62,6 +62,7 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:sort_ops",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:util",
     ],
@@ -1052,3 +1053,32 @@
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+py_test(
+    name = "ragged_dynamic_partition_op_test",
+    srcs = ["ragged_dynamic_partition_op_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_array_ops",
+        ":ragged_factory_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+py_test(
+    name = "string_ngrams_op_test",
+    size = "small",
+    srcs = ["string_ngrams_op_test.py"],
+    python_version = "PY2",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_string_ops",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/python/ops/ragged/ragged_array_ops.py b/tensorflow/python/ops/ragged/ragged_array_ops.py
index 7714217..18af982 100644
--- a/tensorflow/python/ops/ragged/ragged_array_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_array_ops.py
@@ -22,7 +22,9 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -520,3 +522,132 @@
       return array_ops.rank(input, name)
 
     return input.ragged_rank + array_ops.rank(input.flat_values)
+
+
+#===============================================================================
+# ragged.one_hot
+#===============================================================================
+def ragged_one_hot(indices,
+                   depth,
+                   on_value=None,
+                   off_value=None,
+                   axis=None,
+                   dtype=None,
+                   name=None):
+  """Applies tf.one_hot along the values of a RaggedTensor."""
+  with ops.name_scope(name, 'RaggedOneHot', [indices]):
+    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+        indices, name='indices')
+    if axis is not None:
+      axis = ragged_util.get_positive_axis(axis, indices.shape.ndims)
+      if axis < indices.ragged_rank:
+        raise ValueError('axis may not be less than indices.ragged_rank.')
+    return indices.with_flat_values(
+        array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis,
+                          dtype, name))
+
+
+#===============================================================================
+# ragged.stack_dynamic_partitions
+#===============================================================================
+@tf_export('ragged.stack_dynamic_partitions')
+def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
+  """Stacks dynamic partitions of a Tensor or RaggedTensor.
+
+  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
+  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
+  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
+  order.
+
+  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
+  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.
+
+  ####Example:
+    ```python
+    >>> data           = ['a', 'b', 'c', 'd', 'e']
+    >>> partitions     = [  3,   0,   2,   2,   3]
+    >>> num_partitions = 5
+    >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
+    <RaggedTensor [['b'], [], ['c', 'd'], ['a', 'e'], []]>
+    ```
+
+  Args:
+    data: A `Tensor` or `RaggedTensor` containing the values to stack.
+    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
+      partition that each slice of `data` should be added to.
+      `partitions.shape` must be a prefix of `data.shape`.  Values must be
+      greater than or equal to zero, and less than `num_partitions`.
+      `partitions` is not required to be sorted.
+    num_partitions: An `int32` or `int64` scalar specifying the number of
+      partitions to output.  This determines the number of rows in `output`.
+    name: A name prefix for the returned tensor (optional).
+
+  Returns:
+    A `RaggedTensor` containing the stacked partitions.  The returned tensor
+    has the same dtype as `data`, and its shape is
+    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
+    ragged dimension whose length is the number of data slices stacked for
+    each `partition`.
+  """
+  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
+    # Convert inputs to tensors.
+    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
+    row_splits_dtype = (
+        data.row_splits.dtype
+        if isinstance(data, ragged_tensor.RaggedTensor) else None)
+    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+        partitions, name='partitions', preferred_dtype=row_splits_dtype)
+    num_partitions = ops.convert_to_tensor(
+        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
+    if row_splits_dtype is not None:
+      partitions = math_ops.cast(partitions, row_splits_dtype)
+    num_partitions = math_ops.cast(num_partitions, partitions.dtype)
+
+    # Sanity-checks for shapes.
+    partitions_rank = partitions.shape.ndims
+    if partitions_rank is None:
+      raise ValueError('partitions must have known rank.')
+    num_partitions.shape.assert_has_rank(0)
+    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])
+
+    if partitions_rank == 0:
+      # If partitions is a scalar, then just create a RaggedTensor containing
+      # that single the complete `data` value in the specified row.
+      return ragged_tensor.RaggedTensor.from_value_rowids(
+          values=array_ops.stack([data]),
+          value_rowids=array_ops.stack([partitions]),
+          nrows=num_partitions,
+          validate=False)
+
+    elif partitions_rank == 1:
+      # If partitions is a vector (the typical case): we can just use data and
+      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
+      # as long as we sort them first.
+      permutation = sort_ops.argsort(partitions, stable=True)
+      value_rowids = array_ops.gather(partitions, permutation)
+      values = array_ops.gather(data, permutation)
+      check = check_ops.assert_less(
+          value_rowids[-1:],
+          num_partitions,
+          message='partitions must be less than num_partitions')
+      with ops.control_dependencies([check]):
+        return ragged_tensor.RaggedTensor.from_value_rowids(
+            values, value_rowids, nrows=num_partitions, validate=False)
+
+    else:
+      # Handle higher-dimensional partitions via recursion.
+      if not isinstance(data, ragged_tensor.RaggedTensor):
+        data = ragged_tensor.RaggedTensor.from_tensor(
+            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
+      if not isinstance(partitions, ragged_tensor.RaggedTensor):
+        partitions = ragged_tensor.RaggedTensor.from_tensor(
+            partitions,
+            row_splits_dtype=partitions.dtype,
+            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
+      check = check_ops.assert_equal(
+          data.row_splits,
+          partitions.row_splits,
+          message='data and partitions have incompatible ragged shapes')
+      with ops.control_dependencies([check]):
+        return stack_dynamic_partitions(data.values, partitions.values,
+                                        num_partitions)
diff --git a/tensorflow/python/ops/ragged/ragged_concat_ops.py b/tensorflow/python/ops/ragged/ragged_concat_ops.py
index 30fe753..1372db0 100644
--- a/tensorflow/python/ops/ragged/ragged_concat_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_concat_ops.py
@@ -27,6 +27,7 @@
 from tensorflow.python.ops.ragged import ragged_gather_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_util
+from tensorflow.python.util.tf_export import tf_export
 
 
 def concat(values, axis, name=None):
@@ -70,40 +71,41 @@
     return _ragged_stack_concat_helper(values, axis, stack_values=False)
 
 
+@tf_export('ragged.stack')
 def stack(values, axis=0, name=None):
-  """Stacks potentially ragged tensors along one dimension.
+  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.
 
-  Given a list of tensors with the same rank `K` (`K >= axis`), returns a
-  rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
-  list `[rt[i0...iaxis] for rt in values]`.
-
-  Args:
-    values: A list of potentially ragged tensors.  May not be empty. All
-      `values` must have the same rank and the same dtype; but unlike
-      `tf.concat`, they can have arbitrary shapes.
-    axis: A python integer, indicating the dimension along which to stack.
-      (Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
-        Negative values are supported only if the rank of at least one
-        `values` value is statically known.
-    name: A name prefix for the returned tensor (optional).
-
-  Returns:
-    A `RaggedTensor` with rank `K+1`.
-    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
-
-  Raises:
-    ValueError: If `values` is empty, if `axis` is out of bounds or if
-      the input tensors have different ranks.
+  Given a list of tensors or ragged tensors with the same rank `R`
+  (`R >= axis`), returns a rank-`R+1` `RaggedTensor` `result` such that
+  `result[i0...iaxis]` is `[value[i0...iaxis] for value in values]`.
 
   #### Example:
     ```python
     >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
     >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
-    >>> ragged.stack([t1, t2], axis=0)
+    >>> tf.ragged.stack([t1, t2], axis=0)
     [[[1, 2], [3, 4, 5]], [[6], [7, 9, 0]]]
-    >>> ragged.stack([t1, t2], axis=1)
+    >>> tf.ragged.stack([t1, t2], axis=1)
     [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
     ```
+
+  Args:
+    values: A list of `tf.Tensor` or `tf.RaggedTensor`.  May not be empty. All
+      `values` must have the same rank and the same dtype; but unlike
+      `tf.stack`, they can have arbitrary dimension sizes.
+    axis: A python integer, indicating the dimension along which to stack.
+      (Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
+      Negative values are supported only if the rank of at least one
+      `values` value is statically known.
+    name: A name prefix for the returned tensor (optional).
+
+  Returns:
+    A `RaggedTensor` with rank `R+1`.
+    `result.ragged_rank=1+max(axis, max(rt.ragged_rank for rt in values]))`.
+
+  Raises:
+    ValueError: If `values` is empty, if `axis` is out of bounds or if
+      the input tensors have different ranks.
   """
   if not isinstance(values, (list, tuple)):
     values = [values]
diff --git a/tensorflow/python/ops/ragged/ragged_constant_value_op_test.py b/tensorflow/python/ops/ragged/ragged_constant_value_op_test.py
index fecbb2e..94df661 100644
--- a/tensorflow/python/ops/ragged/ragged_constant_value_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_constant_value_op_test.py
@@ -21,6 +21,7 @@
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor_value
@@ -175,6 +176,8 @@
       dict(
           pylist=[[b'a', b'b'], [b'c'], [b'd', b'e', b'f']],
           dtype=np.dtype('S1')),
+      dict(pylist=[], dtype=dtypes.float32, expected_dtype=np.float32),
+      dict(pylist=[], dtype=dtypes.int32, expected_dtype=np.int32),
   )
   def testRaggedValues(self,
                        pylist,
@@ -190,10 +193,10 @@
     # E.g., [np.array((1,2))] --> [[1,2]]
     pylist = _normalize_pylist(pylist)
     # If dtype was explicitly specified, check it.
-    if dtype is not None:
-      self.assertEqual(rt.dtype, dtype)
     if expected_dtype is not None:
       self.assertEqual(rt.dtype, expected_dtype)
+    elif dtype is not None:
+      self.assertEqual(rt.dtype, dtype)
 
     # If ragged_rank was explicitly specified, check it.
     if ragged_rank is not None:
diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
index 8e06a2d..585e914 100644
--- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
@@ -18,12 +18,22 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_ragged_conversion_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 
 
-def from_tensor(tensor, lengths=None, padding=None, ragged_rank=1,
-                row_splits_dtype=dtypes.int64, name=None):
+def from_tensor(tensor,
+                lengths=None,
+                padding=None,
+                ragged_rank=1,
+                row_splits_dtype=dtypes.int64,
+                name=None):
   if ragged_tensor.is_ragged(tensor):
     return tensor
   else:
@@ -43,6 +53,130 @@
     return rt_input
 
 
+def _get_row_partition_type_tensor_pairs_tail(rt_value):
+  """Gets a list of the row partitions for rt_value.
+
+  If parent_indices are defined, then they are used. Otherwise, row_splits
+  are used.
+
+  This assumes that rt_input is nested inside another RaggedTensor. If it is
+  a tensor, then return an empty list.
+
+  Args:
+    rt_value: a ragged tensor value. May be a tensor.
+
+  Returns:
+    A list of (row_partition_type, row_partition_tensor) pairs.
+  """
+  if isinstance(rt_value, ragged_tensor.RaggedTensor):
+    tail = _get_row_partition_type_tensor_pairs_tail(rt_value.values)
+    if rt_value._cached_value_rowids is not None:  # pylint: disable=protected-access
+      return [("VALUE_ROWIDS", rt_value.value_rowids())] + tail
+    else:
+      return [("ROW_SPLITS", rt_value.row_splits)] + tail
+  return []
+
+
+def _get_row_partition_type_tensor_pairs(rt_input):
+  """Gets a list of the row partitions for rt_input.
+
+  If value_rowids are defined, then they are used. Otherwise, row_splits
+  are used. If the outermost level has value_rowids defind, then nrows is
+  also added.
+
+  Args:
+    rt_input: a ragged tensor.
+
+  Returns:
+    A list of (row_partition_type, row_partition_tensor) pairs.
+  """
+  tail = _get_row_partition_type_tensor_pairs_tail(rt_input.values)
+  if rt_input._cached_value_rowids is not None:  # pylint: disable=protected-access
+    return [("FIRST_DIM_SIZE", rt_input.nrows()),
+            ("VALUE_ROWIDS", rt_input.value_rowids())] + tail
+  else:
+    return [("ROW_SPLITS", rt_input.row_splits)] + tail
+
+
+def _shape_as_tensor(shape, dtype):
+  """Takes shape and coerces it to a shape as a tensor.
+
+  If the object is already a tensor, simply passes it on (result is guaranteed
+  to be int64 or int32, but not necessarily dtype).
+  If not, creates a tensor of type dtype.
+
+  Result is either a scalar equal to -1 if the shape is unknown_rank.
+  Otherwise, it is a vector, where unknown dimensions are represented with a
+  value of -1.
+
+  In C++, see TensorShapeFromTensor for parsing shapes in kernels, and
+  InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for
+  use in the shape inference function.
+
+  Args:
+    shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]],
+      Tuple[Optional[Int]].
+    dtype: tf.int64 or tf.int32
+
+  Returns:
+    a scalar or vector tensor of dtype tf.int32 or tf.int64.
+  """
+  if dtype != dtypes.int64 and dtype != dtypes.int32:
+    raise ValueError("Expected int64 or int32 for dtype: got {}".format(dtype))
+
+  if isinstance(shape, ops.Tensor):
+    if shape.dtype != dtypes.int64 and shape.dtype != dtypes.int32:
+      return math_ops.cast(shape, dtype)
+    return shape
+  shape = tensor_shape.as_shape(shape)
+  if not shape:
+    # Imply rank is unknown using a -1 scalar.
+    return constant_op.constant(-1, dtype=dtype)
+  shape = [(-1 if x is None else x) for x in shape.as_list()]
+  # At this point, shape is List[Int].
+  return constant_op.constant(shape, dtype=dtype)
+
+
+# TODO(martinz): add a gradient for this op.
+# TODO(martinz): this is a replacement for RaggedTensor.to_tensor. Move this
+# after there is a chance for the kernels to propagate.
+def ragged_to_dense(rt_input, default_value=None, shape=None):
+  """Create a dense tensor from a ragged tensor.
+
+  If the shape is None, then the resulting dense tensor is the same size as
+  the maximum length of the ragged tensor in each dimension.
+
+  If the shape is not None, then it must be the same number of dimensions
+  as the ragged tensor. For dimension i, if shape[i] is None, then the maximum
+  length of the ragged tensor in that dimension is the size of the output in
+  that dimension. If shape[i] is an integer, then that is the size of the output
+  in that dimension.
+
+  Args:
+    rt_input: the tensor to densify.
+    default_value: used when a value is missing.
+    shape: the shape of the resulting tensor.
+
+  Returns:
+    a dense tensor.
+  """
+
+  type_tensor_pairs = _get_row_partition_type_tensor_pairs(rt_input)
+  row_partition_types = [x[0] for x in type_tensor_pairs]
+  row_partition_tensors = [x[1] for x in type_tensor_pairs]
+  values = rt_input.flat_values
+  if default_value is None:
+    default_value = array_ops.zeros((), values.dtype)
+
+  shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
+  return gen_ragged_conversion_ops.ragged_tensor_to_tensor(
+      shape=shape_tensor,
+      values=values,
+      default_value=default_value,
+      row_partition_types=row_partition_types,
+      row_partition_tensors=row_partition_tensors)
+
+
 def to_sparse(rt_input, name=None):
   return rt_input.to_sparse(name)
 
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py
index 0f67c8c..871c7ee 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch.py
@@ -26,6 +26,7 @@
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_bitwise_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
@@ -437,6 +438,15 @@
                                                 squeeze_dims)
   return ragged_squeeze_op.squeeze(input, axis, name)
 
+
+def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
+  """RaggedTensor Dispatch override for tf.dynamic_partition."""
+  if not isinstance(num_partitions, int) or num_partitions < 0:
+    raise TypeError('num_partitions must be a non-negative integer')
+  result = ragged_array_ops.stack_dynamic_partitions(data, partitions,
+                                                     num_partitions, name)
+  return [result[i] for i in range(num_partitions)]
+
 # (original_op, ragged_op, ragged_args)
 _RAGGED_DISPATCH_OPS = [
     (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
@@ -449,6 +459,7 @@
     (array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
     (array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
                                                            'indices']),
+    (array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
     (array_ops.rank, ragged_array_ops.rank, ['input']),
     (array_ops.size, _ragged_size_v1, ['input']),
     (array_ops.size_v2, ragged_array_ops.size, ['input']),
@@ -457,6 +468,8 @@
     (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
     (array_ops.tile, ragged_array_ops.tile, ['input']),
     (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
+    (data_flow_ops.dynamic_partition, _ragged_dynamic_partition,
+     ['data', 'partitions']),
     (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
      ['data', 'segment_ids']),
     (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
index 246a025..da95690 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
@@ -29,6 +29,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_bitwise_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
@@ -538,6 +539,20 @@
           },
           expected=ragged_factory_ops.constant_value([8, 9, 7])),
       dict(
+          op=array_ops.one_hot,
+          kwargs={
+              'indices':
+                  ragged_factory_ops.constant_value([[1, 2, 3], [0]],
+                                                    dtype=np.int32),
+              'depth':
+                  4,
+              'axis':
+                  1
+          },
+          expected=ragged_factory_ops.constant_value(
+              [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [[1, 0, 0, 0]]],
+              ragged_rank=1)),
+      dict(
           op=array_ops.stack,
           args=([
               ragged_factory_ops.constant_value([[1, 2, 3], [4]],
@@ -677,10 +692,10 @@
           op=string_ops.reduce_join,
           kwargs={
               'inputs':
-                  ragged_factory_ops.constant_value(
-                      [[b'this', b'is', b'a', b'test', b'for', b'ragged',
-                        b'tensors'],
-                       [b'please', b'do', b'not', b'panic', b'!']]),
+                  ragged_factory_ops.constant_value([[
+                      b'this', b'is', b'a', b'test', b'for', b'ragged',
+                      b'tensors'
+                  ], [b'please', b'do', b'not', b'panic', b'!']]),
               'axis':
                   0,
               'keepdims':
@@ -728,11 +743,30 @@
               'axis': [0]
           },
           expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
+      dict(
+          op=data_flow_ops.dynamic_partition,
+          kwargs={
+              'data': ragged_factory_ops.constant_value([[1], [2, 3, 4], [5]]),
+              'partitions': [2, 1, 1],
+              'num_partitions': 3
+          },
+          expected=[
+              ragged_factory_ops.constant_value([], ragged_rank=1),
+              ragged_factory_ops.constant_value([[2, 3, 4], [5]]),
+              ragged_factory_ops.constant_value([[1]])
+          ],
+          result_is_list=True),
   ])
-  def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
+  def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
+                         kwargs=None):
     if kwargs is None: kwargs = {}
     result = op(*args, **kwargs)
-    self.assertAllEqual(result, expected)
+    if result_is_list:
+      self.assertLen(result, len(expected))
+      for (r, e) in zip(result, expected):
+        self.assertAllEqual(r, e)
+    else:
+      self.assertAllEqual(result, expected)
 
   def test_ragged_op_list(self):
     # Ops that should be listed as supported in both v1 and v2.
@@ -761,14 +795,14 @@
         'math.tan', 'math.truediv', 'math.unsorted_segment_max',
         'math.unsorted_segment_mean', 'math.unsorted_segment_min',
         'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
-        'math.unsorted_segment_sum', 'ones_like', 'rank', 'realdiv',
+        'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
         'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
         'strings.join', 'strings.length', 'strings.reduce_join',
         'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
         'strings.substr', 'strings.to_hash_bucket_fast',
         'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
         'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
-        'truncatemod', 'zeros_like'
+        'truncatemod', 'zeros_like', 'dynamic_partition'
     ]
 
     # Ops that should be listed as supported in v1 only.
diff --git a/tensorflow/python/ops/ragged/ragged_dynamic_partition_op_test.py b/tensorflow/python/ops/ragged/ragged_dynamic_partition_op_test.py
new file mode 100644
index 0000000..790cabda
--- /dev/null
+++ b/tensorflow/python/ops/ragged/ragged_dynamic_partition_op_test.py
@@ -0,0 +1,257 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ragged_array_ops.stack_dynamic_partitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops.ragged import ragged_array_ops
+from tensorflow.python.ops.ragged import ragged_concat_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.platform import googletest
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RaggedSegmentStackOpTest(test_util.TensorFlowTestCase,
+                               parameterized.TestCase):
+
+  @parameterized.parameters([
+      dict(  # empty inputs
+          data=[],
+          partitions=[],
+          num_partitions=0,
+          expected=[],
+          expected_ragged_rank=1),
+      dict(  # empty data, num_partitions>0
+          data=[],
+          partitions=[],
+          num_partitions=3,
+          expected=[[], [], []]),
+      dict(  # 1D data, 1D partitions (docstring example)
+          data=['a', 'b', 'c', 'd', 'e'],
+          partitions=[3, 0, 2, 2, 3],
+          num_partitions=5,
+          expected=[['b'], [], ['c', 'd'], ['a', 'e'], []]),
+      dict(  # 2D data, 1D partitions
+          data=[['a', 'b'], ['c', 'd'], ['e', 'f'], ['g', 'h']],
+          data_ragged_rank=0,
+          partitions=[2, 1, 2, 3],
+          num_partitions=4,
+          expected=[[], [['c', 'd']], [['a', 'b'], ['e', 'f']], [['g', 'h']]],
+          expected_ragged_rank=1),
+      dict(  # 2D ragged data, 1D partitions
+          data=[['a'], ['b', 'c', 'd'], [], ['e', 'f']],
+          data_ragged_rank=1,
+          partitions=[2, 1, 2, 3],
+          num_partitions=4,
+          expected=[[], [['b', 'c', 'd']], [['a'], []], [['e', 'f']]],
+          expected_ragged_rank=2),
+      dict(  # 2D data, 2D partitions
+          data=[['a', 'b'], ['c', 'd'], ['e', 'f'], ['g', 'h']],
+          data_ragged_rank=0,
+          partitions=[[3, 0], [2, 2], [4, 3], [2, 0]],
+          num_partitions=5,
+          expected=[['b', 'h'], [], ['c', 'd', 'g'], ['a', 'f'], ['e']]),
+      dict(  # 2D ragged data, 2D ragged partitions
+          data=[['a', 'b'], ['c', 'd'], ['e', 'f'], ['g', 'h']],
+          data_ragged_rank=0,
+          partitions=[[3, 0], [2, 2], [4, 3], [2, 0]],
+          num_partitions=5,
+          expected=[['b', 'h'], [], ['c', 'd', 'g'], ['a', 'f'], ['e']]),
+      dict(  # 3D data, 1d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f'], ['g', 'h']]],
+          data_ragged_rank=0,
+          partitions=[1, 0],
+          num_partitions=2,
+          expected=[[[['e', 'f'], ['g', 'h']]], [[['a', 'b'], ['c', 'd']]]],
+          expected_ragged_rank=1),
+      dict(  # 3D data (ragged_rank=1), 1d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f']]],
+          data_ragged_rank=1,
+          partitions=[2, 0],
+          num_partitions=3,
+          expected=[[[['e', 'f']]], [], [[['a', 'b'], ['c', 'd']]]],
+          expected_ragged_rank=2),
+      dict(  # 3D data (ragged_rank=2), 1d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f', 'g', 'h']]],
+          data_ragged_rank=2,
+          partitions=[2, 0],
+          num_partitions=3,
+          expected=[[[['e', 'f', 'g', 'h']]], [], [[['a', 'b'], ['c', 'd']]]],
+          expected_ragged_rank=3),
+      dict(  # 3D data, 2d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f'], ['g', 'h']]],
+          data_ragged_rank=0,
+          partitions=[[1, 0], [0, 3]],
+          segment_ids_ragged_rank=0,
+          num_partitions=4,
+          expected=[[['c', 'd'], ['e', 'f']], [['a', 'b']], [], [['g', 'h']]],
+          expected_ragged_rank=1),
+      dict(  # 3D data (ragged_rank=1), 2d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f']]],
+          data_ragged_rank=1,
+          partitions=[[1, 0], [0]],
+          segment_ids_ragged_rank=1,
+          num_partitions=2,
+          expected=[[['c', 'd'], ['e', 'f']], [['a', 'b']]],
+          expected_ragged_rank=1),
+      dict(  # 3D data (ragged_rank=2), 2d partitions
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f', 'g', 'h']]],
+          data_ragged_rank=2,
+          partitions=[[1, 0], [0]],
+          segment_ids_ragged_rank=1,
+          num_partitions=3,
+          expected=[[['c', 'd'], ['e', 'f', 'g', 'h']], [['a', 'b']], []],
+          expected_ragged_rank=2),
+      dict(  # 3D data (ragged_rank=2), 3d partitions (ragged_rank=2)
+          data=[[['a', 'b'], ['c', 'd']], [['e', 'f', 'g', 'h']]],
+          data_ragged_rank=2,
+          partitions=[[[3, 0], [1, 2]], [[1, 1, 0, 1]]],
+          segment_ids_ragged_rank=2,
+          num_partitions=4,
+          expected=[['b', 'g'], ['c', 'e', 'f', 'h'], ['d'], ['a']]),
+      dict(  # 0D data, 0D partitions
+          data='a',
+          partitions=3,
+          num_partitions=5,
+          expected=[[], [], [], ['a'], []]),
+      dict(  # 1D data, 0D partitions
+          data=['a', 'b', 'c'],
+          partitions=3,
+          num_partitions=5,
+          expected=[[], [], [], [['a', 'b', 'c']], []],
+          expected_ragged_rank=1),
+      dict(  # 2D data, 0D partitions
+          data=[['a', 'b'], ['c', 'd']],
+          data_ragged_rank=0,
+          partitions=3,
+          num_partitions=5,
+          expected=[[], [], [], [[['a', 'b'], ['c', 'd']]], []],
+          expected_ragged_rank=1),
+      dict(  # 2D data (ragged_rank=1), 0D partitions
+          data=[['a', 'b'], ['c']],
+          data_ragged_rank=1,
+          partitions=3,
+          num_partitions=5,
+          expected=[[], [], [], [[['a', 'b'], ['c']]], []],
+          expected_ragged_rank=3),
+  ])
+  def testRaggedSegmentStack(self,
+                             data,
+                             partitions,
+                             num_partitions,
+                             expected,
+                             data_ragged_rank=None,
+                             segment_ids_ragged_rank=None,
+                             expected_ragged_rank=None):
+    for seg_dtype in [dtypes.int32, dtypes.int64]:
+      data_tensor = ragged_factory_ops.constant(
+          data, row_splits_dtype=seg_dtype, ragged_rank=data_ragged_rank)
+      segment_ids_tensor = ragged_factory_ops.constant(
+          partitions,
+          dtype=seg_dtype,
+          row_splits_dtype=seg_dtype,
+          ragged_rank=segment_ids_ragged_rank)
+      expected_tensor = ragged_factory_ops.constant(
+          expected,
+          row_splits_dtype=seg_dtype,
+          ragged_rank=expected_ragged_rank)
+      result = ragged_array_ops.stack_dynamic_partitions(
+          data_tensor, segment_ids_tensor, num_partitions)
+      self.assertAllEqual(result, expected_tensor)
+
+      # Check that it's equivalent to tf.stack(dynamic_partition(...)),
+      # where applicable.
+      if (data_ragged_rank == 0 and segment_ids_ragged_rank == 0 and
+          seg_dtype == dtypes.int32):
+        equiv = ragged_concat_ops.stack(
+            data_flow_ops.dynamic_partition(data_tensor, segment_ids_tensor,
+                                            num_partitions))
+        self.assertAllEqual(result, self.evaluate(equiv).to_list())
+
+  @parameterized.parameters([
+      dict(
+          data=['a', 'b', 'c'],
+          partitions=[2, -1, 0],
+          num_partitions=10,
+          error='must be non-negative'),
+      dict(
+          data=['a', 'b', 'c'],
+          partitions=[2, 10, 0],
+          num_partitions=1,
+          error='partitions must be less than num_partitions'),
+      dict(
+          data=['a', 'b', 'c'],
+          partitions=[2, 10, 0],
+          num_partitions=10,
+          error='partitions must be less than num_partitions'),
+      dict(
+          data=[['a', 'b'], ['c']],
+          partitions=[[2], [3, 0]],
+          num_partitions=10,
+          error='data and partitions have incompatible ragged shapes'),
+  ])
+  def testRuntimeError(self, data, partitions, num_partitions, error):
+    data = ragged_factory_ops.constant(data)
+    partitions = ragged_factory_ops.constant(partitions, dtype=dtypes.int64)
+    with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
+                                 error):
+      self.evaluate(
+          ragged_array_ops.stack_dynamic_partitions(data, partitions,
+                                                    num_partitions))
+
+  @parameterized.parameters([
+      dict(
+          data=['a', 'b', 'c'],
+          partitions=[1, 2],
+          num_partitions=10,
+          error=r'Shapes \(2,\) and \(3,\) are incompatible'),
+      dict(
+          data=[['a', 'b'], ['c', 'd']],
+          partitions=[[1, 2, 3], [4, 5, 6]],
+          num_partitions=10,
+          error=r'Shapes \(2, 3\) and \(2, 2\) are incompatible'),
+      dict(
+          data=['a', 'b', 'c'],
+          partitions=[1, 2, 3],
+          num_partitions=[1, 2, 3],
+          error='must have rank 0'),
+  ])
+  def testStaticError(self, data, partitions, num_partitions, error):
+    with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
+                                 error):
+      ragged_array_ops.stack_dynamic_partitions(data, partitions,
+                                                num_partitions)
+
+  def testUnknownRankError(self):
+    if context.executing_eagerly():
+      return
+    partitions = array_ops.placeholder(dtypes.int32, None)
+    with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
+                                 'partitions must have known rank'):
+      ragged_array_ops.stack_dynamic_partitions(['a', 'b', 'c'], partitions, 10)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py
index 5c654c6..7ab450e 100644
--- a/tensorflow/python/ops/ragged/ragged_factory_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py
@@ -133,6 +133,8 @@
     ValueError: If the scalar values in `pylist` have inconsistent nesting
       depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
   """
+  if dtype is not None and isinstance(dtype, dtypes.DType):
+    dtype = dtype.as_numpy_dtype
   row_splits_dtype = dtypes.as_dtype(row_splits_dtype).as_numpy_dtype
   def _ragged_factory(values, row_splits):
     row_splits = np.array(row_splits, dtype=row_splits_dtype)
diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py
index 39bd93e..22b6288 100644
--- a/tensorflow/python/ops/ragged/ragged_math_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_math_ops.py
@@ -159,7 +159,7 @@
                               data,
                               segment_ids,
                               num_segments,
-                              separator='',
+                              separator=None,
                               name=None):
   """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
 
@@ -182,7 +182,7 @@
       `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
       `segment_ids` is not required to be sorted.
     num_segments: An `int32` or `int64` scalar.
-    separator: An optional string. Defaults to "". The separator to
+    separator: An optional string. Defaults to None. The separator to
       use when joining. Only used for string types.
     name: A name prefix for the returned tensor (optional).
 
@@ -195,7 +195,7 @@
   """
   if not (ragged_tensor.is_ragged(data) or
           ragged_tensor.is_ragged(segment_ids)):
-    if data.dtype == dtypes.string:
+    if separator is not None:
       # It uses unsorted_segment_join.
       return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                  name)
@@ -255,37 +255,45 @@
     # Recursively aggregate the values.
     output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                               data_val_to_out_val_index,
-                                              output_splits[-1])
+                                              output_splits[-1], separator)
     return ragged_tensor.RaggedTensor.from_row_splits(
         output_values, output_splits, validate=False)
 
 
 def segment_sum(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(math_ops.unsorted_segment_sum, data,
-                                   segment_ids, num_segments, name or
-                                   'RaggedSegmentSum')
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_sum,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or'RaggedSegmentSum'))
 
 
 def segment_prod(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(math_ops.unsorted_segment_prod, data,
-                                   segment_ids, num_segments, name or
-                                   'RaggedSegmentProd')
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_prod,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentProd'))
 
 
 def segment_min(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(math_ops.unsorted_segment_min, data,
-                                   segment_ids, num_segments, name or
-                                   'RaggedSegmentMin')
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_min,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentMin'))
 
 
 def segment_max(data, segment_ids, num_segments, name=None):
   # For docs, see: _RAGGED_SEGMENT_DOCSTRING
-  return _ragged_segment_aggregate(math_ops.unsorted_segment_max, data,
-                                   segment_ids, num_segments, name or
-                                   'RaggedSegmentMax')
+  return _ragged_segment_aggregate(math_ops.unsorted_segment_max,
+                                   data=data,
+                                   segment_ids=segment_ids,
+                                   num_segments=num_segments,
+                                   name=(name or 'RaggedSegmentMax'))
 
 
 def segment_mean(data, segment_ids, num_segments, name=None):
@@ -421,7 +429,7 @@
                             rt_input,
                             axis,
                             keepdims,
-                            separator='',
+                            separator=None,
                             name=None):
   """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
 
@@ -447,8 +455,9 @@
       given set of axes), or a `Tensor` with a constant value.  Must be in the
       range `[0, rt_input.rank)`.
     keepdims: If true, retains reduced dimensions with length 1.
-    separator: An optional string. Defaults to ''. The separator to use when
-      joining. Used only when input type is string.
+    separator: An optional string. Defaults to None. The separator to use when
+      joining. The separator must not be set for non-string data types. (i.e.
+      if separator is not None then it uses string ops)
     name: A name prefix for the returned tensor (optional).
 
   Returns:
@@ -461,7 +470,12 @@
     ValueError: If `axis` contains a `Tensor` whose value is not constant.
   """
   if not ragged_tensor.is_ragged(rt_input):
-    return reduce_op(rt_input, axis, name=name)
+    if separator is None:
+      return reduce_op(rt_input, axis, name=name)
+    else:
+      # When separator is not None, We infer that dtype is string and
+      # reduce_join will be called.
+      return reduce_op(rt_input, axis, name=name, separator=separator)
 
   if keepdims:
     raise ValueError('keepdims=True is not supported for RaggedTensors.')
@@ -533,30 +547,46 @@
 
 def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  return ragged_reduce_aggregate(math_ops.reduce_sum,
-                                 math_ops.unsorted_segment_sum, input_tensor,
-                                 axis, keepdims, name or 'RaggedReduceSum')
+
+  return ragged_reduce_aggregate(
+      reduce_op=math_ops.reduce_sum,
+      unsorted_segment_op=math_ops.unsorted_segment_sum,
+      rt_input=input_tensor,
+      axis=axis, keepdims=keepdims,
+      name=(name or 'RaggedReduceSum'))
 
 
 def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  return ragged_reduce_aggregate(math_ops.reduce_prod,
-                                 math_ops.unsorted_segment_prod, input_tensor,
-                                 axis, keepdims, name or 'RaggedReduceProd')
+  return ragged_reduce_aggregate(
+      reduce_op=math_ops.reduce_prod,
+      unsorted_segment_op=math_ops.unsorted_segment_prod,
+      rt_input=input_tensor,
+      axis=axis,
+      keepdims=keepdims,
+      name=(name or 'RaggedReduceProd'))
 
 
 def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  return ragged_reduce_aggregate(math_ops.reduce_min,
-                                 math_ops.unsorted_segment_min, input_tensor,
-                                 axis, keepdims, name or 'RaggedReduceMin')
+  return ragged_reduce_aggregate(
+      reduce_op=math_ops.reduce_min,
+      unsorted_segment_op=math_ops.unsorted_segment_min,
+      rt_input=input_tensor,
+      axis=axis,
+      keepdims=keepdims,
+      name=(name or 'RaggedReduceMin'))
 
 
 def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  return ragged_reduce_aggregate(math_ops.reduce_max,
-                                 math_ops.unsorted_segment_max, input_tensor,
-                                 axis, keepdims, name or 'RaggedReduceMax')
+  return ragged_reduce_aggregate(
+      reduce_op=math_ops.reduce_max,
+      unsorted_segment_op=math_ops.unsorted_segment_max,
+      rt_input=input_tensor,
+      axis=axis,
+      keepdims=keepdims,
+      name=(name or 'RaggedReduceMax'))
 
 
 def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py
index ed52e9a..b93b02b 100644
--- a/tensorflow/python/ops/ragged/ragged_string_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_string_ops.py
@@ -26,6 +26,7 @@
 from tensorflow.python.ops.ragged import ragged_array_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import compat as util_compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
@@ -650,3 +651,150 @@
   return ragged_math_ops.ragged_reduce_aggregate(
       string_ops.reduce_join, string_ops.unsorted_segment_join, inputs, axis,
       keepdims, separator, name or "RaggedSegmentJoin")
+
+
+@tf_export("strings.ngrams")
+def ngrams(data,
+           ngram_width,
+           separator=" ",
+           pad_values=None,
+           padding_width=None,
+           preserve_short_sequences=False,
+           name=None):
+  """Create a tensor of n-grams based on `data`.
+
+  Creates a tensor of n-grams based on `data`. The n-grams are created by
+  joining windows of `width` adjacent strings from the inner axis of `data`
+  using `separator`.
+
+  The input data can be padded on both the start and end of the sequence, if
+  desired, using the `pad_values` argument. If set, `pad_values` should contain
+  either a tuple of strings or a single string; the 0th element of the tuple
+  will be used to pad the left side of the sequence and the 1st element of the
+  tuple will be used to pad the right side of the sequence. The `padding_width`
+  arg controls how many padding values are added to each side; it defaults to
+  `ngram_width-1`.
+
+  If this op is configured to not have padding, or if it is configured to add
+  padding with `padding_width` set to less than ngram_width-1, it is possible
+  that a sequence, or a sequence plus padding, is smaller than the ngram
+  width. In that case, no ngrams will be generated for that sequence. This can
+  be prevented by setting `preserve_short_sequences`, which will cause the op
+  to always generate at least one ngram per non-empty sequence.
+
+  Args:
+    data: A Tensor or RaggedTensor containing the source data for the ngrams.
+    ngram_width: The width(s) of the ngrams to create. If this is a list or
+      tuple, the op will return ngrams of all specified arities in list order.
+      Values must be non-Tensor integers greater than 0.
+    separator: The separator string used between ngram elements. Must be a
+      string constant, not a Tensor.
+    pad_values: A tuple of (left_pad_value, right_pad_value), a single string,
+      or None. If None, no padding will be added; if a single string, then that
+      string will be used for both left and right padding. Values must be Python
+      strings.
+    padding_width: If set, `padding_width` pad values will be added to both
+      sides of each sequence. Defaults to `ngram_width`-1. Must be greater than
+      0. (Note that 1-grams are never padded, regardless of this value.)
+    preserve_short_sequences: If true, then ensure that at least one ngram is
+      generated for each input sequence.  In particular, if an input sequence is
+      shorter than `min(ngram_width) + 2*pad_width`, then generate a single
+      ngram containing the entire sequence.  If false, then no ngrams are
+      generated for these short input sequences.
+    name: The op name.
+
+  Returns:
+    A RaggedTensor of ngrams. If `data.shape=[D1...DN, S]`, then
+    `output.shape=[D1...DN, NUM_NGRAMS]`, where
+    `NUM_NGRAMS=S-ngram_width+1+2*padding_width`.
+
+  Raises:
+    TypeError: if `pad_values` is set to an invalid type.
+    ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an
+      invalid value.
+  """
+
+  with ops.name_scope(name, "StringNGrams", [data]):
+    if pad_values is None:
+      left_pad = ""
+      right_pad = ""
+    elif isinstance(pad_values, (list, tuple)):
+      if (not isinstance(pad_values[0], util_compat.bytes_or_text_types) or
+          not isinstance(pad_values[1], util_compat.bytes_or_text_types)):
+        raise TypeError(
+            "pad_values must be a string, tuple of strings, or None.")
+      left_pad = pad_values[0]
+      right_pad = pad_values[1]
+    else:
+      if not isinstance(pad_values, util_compat.bytes_or_text_types):
+        raise TypeError(
+            "pad_values must be a string, tuple of strings, or None.")
+      left_pad = pad_values
+      right_pad = pad_values
+
+    if padding_width is not None and padding_width < 1:
+      raise ValueError("padding_width must be greater than 0.")
+
+    if padding_width is not None and pad_values is None:
+      raise ValueError("pad_values must be provided if padding_width is set.")
+
+    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+        data, name="data", dtype=dtypes.string)
+
+    # preserve the shape of the data if it is a tensor
+    to_tensor = False
+    if isinstance(data, ops.Tensor):
+      dense_shape = array_ops.concat([array_ops.shape(data)[:-1], [-1]], axis=0)
+      to_tensor = True
+
+    if not isinstance(data, ragged_tensor.RaggedTensor):
+      if data.shape.ndims is None:
+        raise ValueError("Rank of data must be known.")
+      elif data.shape.ndims == 0:
+        raise ValueError("Data must have rank>0")
+      elif data.shape.ndims == 1:
+        rt = ragged_tensor.RaggedTensor.from_row_starts(
+            data, [0], validate=False)
+        return ngrams(rt, ngram_width, separator, pad_values, padding_width,
+                      preserve_short_sequences, name)[0]
+      else:
+        data = ragged_tensor.RaggedTensor.from_tensor(
+            data, ragged_rank=data.shape.ndims - 1)
+
+    if data.ragged_rank > 1:
+      output = data.with_values(
+          ngrams(data.values, ngram_width, separator, pad_values, padding_width,
+                 preserve_short_sequences, name))
+      return array_ops.reshape(output.flat_values,
+                               dense_shape) if to_tensor else output
+
+    if pad_values is None:
+      padding_width = 0
+
+    if pad_values is not None and padding_width is None:
+      padding_width = -1
+
+    if not isinstance(ngram_width, (list, tuple)):
+      ngram_widths = [ngram_width]
+    else:
+      ngram_widths = ngram_width
+    for width in ngram_widths:
+      if width < 1:
+        raise ValueError("All ngram_widths must be greater than 0. Got %s" %
+                         ngram_width)
+
+    output, output_splits = gen_string_ops.string_n_grams(
+        data=data.flat_values,
+        data_splits=data.row_splits,
+        separator=separator,
+        ngram_widths=ngram_widths,
+        left_pad=left_pad,
+        right_pad=right_pad,
+        pad_width=padding_width,
+        preserve_short_sequences=preserve_short_sequences)
+
+    # if the input is Dense tensor, the output should also be a dense tensor
+    output = ragged_tensor.RaggedTensor.from_row_splits(
+        values=output, row_splits=output_splits, validate=False)
+    return array_ops.reshape(output.flat_values,
+                             dense_shape) if to_tensor else output
diff --git a/tensorflow/python/ops/ragged/ragged_string_ops_test.py b/tensorflow/python/ops/ragged/ragged_string_ops_test.py
index 52f8805..978d54c 100644
--- a/tensorflow/python/ops/ragged/ragged_string_ops_test.py
+++ b/tensorflow/python/ops/ragged/ragged_string_ops_test.py
@@ -62,6 +62,35 @@
           'truth_shape': [2],
       },
       {
+          'input_array': [[
+              b'this', b'is', b'a', b'test', b'for', b'ragged', b'tensors'
+          ], [b'please', b'do', b'not', b'panic', b'!']],
+          'axis': 1,
+          'keepdims': False,
+          'truth': [
+              b'this|is|a|test|for|ragged|tensors', b'please|do|not|panic|!'
+          ],
+          'truth_shape': [2],
+          'separator': '|',
+      },
+      {
+          'input_array': [[[b'a', b'b'], [b'b', b'c']], [[b'dd', b'ee']]],
+          'axis': -1,
+          'keepdims': False,
+          'truth': [[b'a|b', b'b|c'], [b'dd|ee']],
+          'truth_shape': [2, None],
+          'separator': '|',
+      },
+      {
+          'input_array': [[[[b'a', b'b', b'c'], [b'dd', b'ee']]],
+                          [[[b'f', b'g', b'h'], [b'ii', b'jj']]]],
+          'axis': -2,
+          'keepdims': False,
+          'truth': [[[b'a|dd', b'b|ee', b'c']], [[b'f|ii', b'g|jj', b'h']]],
+          'truth_shape': [2, None, None],
+          'separator': '|',
+      },
+      {
           'input_array': [[[b't', b'h', b'i', b's'], [b'i', b's'], [b'a'],
                            [b't', b'e', b's', b't']],
                           [[b'p', b'l', b'e', b'a', b's', b'e'],
diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py
index d06819c..3556707 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
@@ -788,9 +789,12 @@
                                           name=name)
     else:
       values = ops.convert_to_tensor(values, name="values")
-      partition = ops.convert_to_tensor(
-          partition, preferred_dtype=dtypes.int64,
-          name=name)
+      if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
+        partition = ops.convert_to_tensor(partition, name=name)
+      else:
+        partition = ops.convert_to_tensor(
+            partition, preferred_dtype=dtypes.int64,
+            name=name)
       if partition.dtype not in (dtypes.int32, dtypes.int64):
         raise ValueError("%s must have dtype int32 or int64" % name)
 
@@ -1985,16 +1989,17 @@
       return [value]
 
   def _from_components(self, tensor_list):
-    # Currently, Keras converts tensors to numpy and then calls from_components
-    # with those np.arrays.  So if we see np.ndarrays, convert them to tensors.
-    # TODO(b/133606651) Update Keras to do something different here.  Consider
-    # adding something like TypeSpec.from_numpy_components?
-    if isinstance(tensor_list[0], np.ndarray):
-      tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
-
     result = tensor_list[0]
-    for row_splits in reversed(tensor_list[1:]):
-      result = RaggedTensor(result, row_splits, internal=True)
+    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
+        not tf2.enabled()):
+      for row_splits in reversed(tensor_list[1:]):
+        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
+    else:
+      if isinstance(tensor_list[0], np.ndarray):
+        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
+        result = tensor_list[0]
+      for row_splits in reversed(tensor_list[1:]):
+        result = RaggedTensor(result, row_splits, internal=True)
     return result
 
   # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py
index eb8767b..0633872 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py
@@ -29,12 +29,14 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
+from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
 from tensorflow.python.platform import googletest
 
 
@@ -374,6 +376,24 @@
         rt,
         [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
 
+  def testFromRowSplitsWithDifferentSplitTypes(self):
+    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
+    splits1 = [0, 2, 2, 5, 6, 7]
+    splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
+    splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
+    splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
+    splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
+    rt1 = RaggedTensor.from_row_splits(values, splits1)
+    rt2 = RaggedTensor.from_row_splits(values, splits2)
+    rt3 = RaggedTensor.from_row_splits(values, splits3)
+    rt4 = RaggedTensor.from_row_splits(values, splits4)
+    rt5 = RaggedTensor.from_row_splits(values, splits5)
+    self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
+    self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
+    self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
+    self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
+    self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
+
   def testFromRowSplitsWithEmptySplits(self):
     err_msg = 'row_splits tensor may not be empty'
     with self.assertRaisesRegexp(ValueError, err_msg):
@@ -1547,5 +1567,179 @@
           output_ragged_rank=1,
           input_ragged_rank=1)
 
+
+@test_util.run_all_in_graph_and_eager_modes
+class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
+                           parameterized.TestCase):
+
+  def assertAllTensorsEqual(self, list1, list2):
+    self.assertLen(list1, len(list2))
+    for (t1, t2) in zip(list1, list2):
+      self.assertAllEqual(t1, t2)
+
+  def testConstruction(self):
+    spec1 = RaggedTensorSpec(ragged_rank=1)
+    self.assertEqual(spec1._shape.rank, None)
+    self.assertEqual(spec1._dtype, dtypes.float32)
+    self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
+    self.assertEqual(spec1._ragged_rank, 1)
+
+    spec2 = RaggedTensorSpec(shape=[None, None, None])
+    self.assertEqual(spec2._shape.as_list(), [None, None, None])
+    self.assertEqual(spec2._dtype, dtypes.float32)
+    self.assertEqual(spec2._row_splits_dtype, dtypes.int64)
+    self.assertEqual(spec2._ragged_rank, 2)
+
+    with self.assertRaisesRegexp(ValueError, 'Must specify ragged_rank'):
+      RaggedTensorSpec()
+    with self.assertRaisesRegexp(TypeError, 'ragged_rank must be an int'):
+      RaggedTensorSpec(ragged_rank=constant_op.constant(1))
+    with self.assertRaisesRegexp(ValueError,
+                                 'ragged_rank must be less than rank'):
+      RaggedTensorSpec(ragged_rank=2, shape=[None, None])
+
+  def testValueType(self):
+    spec1 = RaggedTensorSpec(ragged_rank=1)
+    self.assertEqual(spec1.value_type, RaggedTensor)
+    spec2 = RaggedTensorSpec(ragged_rank=0)
+    self.assertEqual(spec2.value_type, ops.Tensor)
+
+  @parameterized.parameters([
+      (RaggedTensorSpec(ragged_rank=1),
+       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64)),
+      (RaggedTensorSpec(shape=[5, None, None]),
+       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
+        2, dtypes.int64)),
+      (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.int32),
+       (tensor_shape.TensorShape([5, None, None]), dtypes.int32, 2,
+        dtypes.int64)),
+      (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
+       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int32)),
+  ])  # pyformat: disable
+  def testSerialize(self, rt_spec, expected):
+    serialization = rt_spec._serialize()
+    # TensorShape has an unconventional definition of equality, so we can't use
+    # assertEqual directly here.  But repr() is deterministic and lossless for
+    # the expected values, so we can use that instead.
+    self.assertEqual(repr(serialization), repr(expected))
+
+  @parameterized.parameters([
+      (RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), [
+          tensor_spec.TensorSpec([5, 3], dtypes.float32),
+      ]),
+      (RaggedTensorSpec(ragged_rank=1), [
+          tensor_spec.TensorSpec(None, dtypes.float32),
+          tensor_spec.TensorSpec([None], dtypes.int64)
+      ]),
+      (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), [
+          tensor_spec.TensorSpec(None, dtypes.float32),
+          tensor_spec.TensorSpec([None], dtypes.int32),
+      ]),
+      (RaggedTensorSpec(ragged_rank=2), [
+          tensor_spec.TensorSpec(None, dtypes.float32),
+          tensor_spec.TensorSpec([None], dtypes.int64),
+          tensor_spec.TensorSpec([None], dtypes.int64),
+      ]),
+      (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.string), [
+          tensor_spec.TensorSpec([None], dtypes.string),
+          tensor_spec.TensorSpec([6], dtypes.int64),
+          tensor_spec.TensorSpec([None], dtypes.int64),
+      ]),
+  ])
+  def testComponentSpecs(self, rt_spec, expected):
+    self.assertEqual(rt_spec._component_specs, expected)
+
+  @parameterized.parameters([
+      {
+          'rt_spec': RaggedTensorSpec(ragged_rank=0),
+          'rt': [1.0, 2.0, 3.0],
+          'components': [[1.0, 2.0, 3.0]]
+      },
+      {
+          'rt_spec': RaggedTensorSpec(ragged_rank=1),
+          'rt': [[1.0, 2.0], [3.0]],
+          'components': [[1.0, 2.0, 3.0], [0, 2, 3]]
+      },
+      {
+          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
+          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]],
+          'components': [[1.0, 2.0, 3.0, 4.0], [0, 2, 4], [0, 2, 3, 3, 4]]
+      },
+  ])
+  def testToFromComponents(self, rt_spec, rt, components):
+    rt = ragged_factory_ops.constant(rt)
+    actual_components = rt_spec._to_components(rt)
+    self.assertAllTensorsEqual(actual_components, components)
+    rt_reconstructed = rt_spec._from_components(actual_components)
+    self.assertAllEqual(rt, rt_reconstructed)
+
+  @test_util.run_v1_only('RaggedTensorValue is deprecated in v2')
+  def testFromNumpyComponents(self):
+    spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32)
+    rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])])
+    self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue)
+    self.assertAllEqual(rt1, [[1, 2], [3]])
+
+    spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
+    rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]),
+                                  np.array([0, 0, 2, 3])])
+    self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
+    self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])
+
+    spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32)
+    rt3 = spec3._from_components([np.array([1, 2, 3])])
+    self.assertIsInstance(rt3, np.ndarray)
+    self.assertAllEqual(rt3, [1, 2, 3])
+
+  @parameterized.parameters([
+      RaggedTensorSpec(ragged_rank=0, shape=[5, 3]),
+      RaggedTensorSpec(ragged_rank=1),
+      RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
+      RaggedTensorSpec(ragged_rank=2, dtype=dtypes.string),
+      RaggedTensorSpec(shape=[5, None, None]),
+  ])
+  def testFlatTensorSpecs(self, rt_spec):
+    self.assertEqual(rt_spec._flat_tensor_specs,
+                     [tensor_spec.TensorSpec(None, dtypes.variant)])
+
+  @parameterized.parameters([
+      {
+          'rt_spec': RaggedTensorSpec(ragged_rank=1),
+          'rt': [[1.0, 2.0], [3.0]]
+      },
+      {
+          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
+          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
+      },
+  ])
+  def testToFromTensorList(self, rt_spec, rt):
+    rt = ragged_factory_ops.constant(rt)
+    tensor_list = rt_spec._to_tensor_list(rt)
+    rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
+    self.assertAllEqual(rt, rt_reconstructed)
+
+  @parameterized.parameters([
+      (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
+       RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
+      (RaggedTensorSpec([4, None], dtypes.float32, 1), None,
+       RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
+      (RaggedTensorSpec([2], dtypes.float32,
+                        -1), 32, RaggedTensorSpec([32, 2], dtypes.float32, 0)),
+  ])
+  def testBatch(self, spec, batch_size, expected):
+    self.assertEqual(spec._batch(batch_size), expected)
+
+  @parameterized.parameters([
+      (RaggedTensorSpec([32, None, None], dtypes.float32, 2),
+       RaggedTensorSpec([None, None], dtypes.float32, 1)),
+      (RaggedTensorSpec([None, None, None], dtypes.float32, 2),
+       RaggedTensorSpec([None, None], dtypes.float32, 1)),
+      (RaggedTensorSpec([32, 2], dtypes.float32, 0),
+       RaggedTensorSpec([2], dtypes.float32, -1)),
+  ])  # pyformat: disable
+  def testUnbatch(self, spec, expected):
+    self.assertEqual(spec._unbatch(), expected)
+
+
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
index 1c589a7..0b994af 100644
--- a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
@@ -21,9 +21,16 @@
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
 from tensorflow.python.platform import googletest
 
 
@@ -132,5 +139,355 @@
       rt.to_tensor(default)
 
 
+# This covers the tests above, but with the new implementation.
+@test_util.run_all_in_graph_and_eager_modes
+class RaggedTensorToTensorOpNewTest(test_util.TensorFlowTestCase,
+                                    parameterized.TestCase):
+
+  def testDocStringExamples(self):
+    """Example from ragged_to_tensor.__doc__."""
+    rt = ragged_factory_ops.constant([[9, 8, 7], [], [6, 5], [4]])
+    dt = ragged_conversion_ops.ragged_to_dense(rt)
+    self.assertAllEqual(dt, [[9, 8, 7], [0, 0, 0], [6, 5, 0], [4, 0, 0]])
+
+  @parameterized.parameters(
+      {
+          'rt_input': [],
+          'ragged_rank': 1,
+          'expected': [],
+          'expected_shape': [0, 0],
+      },
+      {
+          'rt_input': [[1, 2, 3], [], [4], [5, 6]],
+          'expected': [[1, 2, 3], [0, 0, 0], [4, 0, 0], [5, 6, 0]]
+      },
+      {
+          'rt_input': [[1, 2, 3], [], [4], [5, 6]],
+          'default': 9,
+          'expected': [[1, 2, 3], [9, 9, 9], [4, 9, 9], [5, 6, 9]]
+      },
+      {
+          'rt_input': [[[1], [2], [3]], [], [[4]], [[5], [6]]],
+          'ragged_rank':
+              1,
+          'default': [9],
+          'expected': [[[1], [2], [3]], [[9], [9], [9]], [[4], [9], [9]],
+                       [[5], [6], [9]]]
+      },
+      {
+          'rt_input': [[[1, 2], [], [3, 4]], [], [[5]], [[6, 7], [8]]],
+          'expected': [
+              [[1, 2], [0, 0], [3, 4]],  #
+              [[0, 0], [0, 0], [0, 0]],  #
+              [[5, 0], [0, 0], [0, 0]],  #
+              [[6, 7], [8, 0], [0, 0]],  #
+          ]
+      },
+      {
+          'rt_input': [[[1, 2], [], [3, 4]], [], [[5]], [[6, 7], [8]]],
+          'default':
+              9,
+          'expected': [
+              [[1, 2], [9, 9], [3, 4]],  #
+              [[9, 9], [9, 9], [9, 9]],  #
+              [[5, 9], [9, 9], [9, 9]],  #
+              [[6, 7], [8, 9], [9, 9]],  #
+          ]
+      },
+      {
+          'rt_input': [[[1], [2], [3]]],
+          'ragged_rank': 1,
+          'default': 0,
+          'expected': [[[1], [2], [3]]],
+      },
+      {
+          'rt_input': [[[[1], [2]], [], [[3]]]],
+          'default': 9,
+          'expected': [[[[1], [2]], [[9], [9]], [[3], [9]]]],
+      },
+  )
+  def testRaggedTensorToTensor(self,
+                               rt_input,
+                               expected,
+                               ragged_rank=None,
+                               default=None,
+                               expected_shape=None):
+    rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
+    dt = ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
+
+    self.assertIsInstance(dt, ops.Tensor)
+    self.assertEqual(rt.dtype, dt.dtype)
+    self.assertTrue(dt.shape.is_compatible_with(rt.shape))
+    if expected_shape is not None:
+      expected = np.ndarray(expected_shape, buffer=np.array(expected))
+    self.assertAllEqual(dt, expected)
+
+  @parameterized.parameters(
+      {
+          'rt_input': [[1, 2, 3]],
+          'default': 'a',
+          'error': (TypeError, '.*'),
+      }, {
+          'rt_input': [[1, 2, 3]],
+          'default': 'b',
+          'error': (TypeError, '.*'),
+      })
+  def testError(self, rt_input, default, error, ragged_rank=None):
+    rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
+    with self.assertRaisesRegexp(error[0], error[1]):
+      ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RaggedToTensorOpAdditionalTests(test_util.TensorFlowTestCase):
+
+  def _compare_to_reference(self,
+                            ragged_tensor,
+                            expected=None,
+                            default_value=None):
+    treatment = ragged_conversion_ops.ragged_to_dense(
+        ragged_tensor, default_value=default_value)
+    control = ragged_tensor.to_tensor(default_value=default_value)
+    self.assertAllEqual(control, treatment)
+    if expected is not None:
+      self.assertAllEqual(expected, treatment)
+
+  def test_already_dense_simple(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant([6, 7, 8, 9, 10, 11], dtype=dtypes.int64),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1, 1],
+                                          dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(input_data, [[6, 7, 8], [9, 10, 11]])
+
+  def test_already_dense_with_dense_values_and_default(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            [[6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17]],
+            dtype=dtypes.int64),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1, 1],
+                                          dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(
+        input_data,
+        [[[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]],
+        default_value=constant_op.constant([31, 32], dtype=dtypes.int64))
+
+  def test_already_dense_with_dense_values(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            [[6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17]],
+            dtype=dtypes.int64),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1, 1],
+                                          dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(
+        input_data,
+        [[[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]])
+
+  def test_ragged_with_dense_values_and_default(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            [[6, 7], [8, 9], [10, 11], [12, 13], [14, 15]], dtype=dtypes.int64),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1], dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(
+        input_data, [[[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [2, 3]]],
+        default_value=[2, 3])
+
+  def test_ragged_with_dense_values_and_small_default(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            [[6, 7], [8, 9], [10, 11], [12, 13], [14, 15]], dtype=dtypes.int64),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1], dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(
+        input_data, [[[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [2, 2]]],
+        default_value=2)
+
+  def test_already_dense_with_dense_values_string(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            [[b'a', b'b'], [b'c', b'd'], [b'e', b'f'], [b'g', b'jalapeno'],
+             [b'kangaroo', b'llama'], [b'manzana', b'nectar']],
+            dtype=dtypes.string),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1, 1],
+                                          dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(input_data,
+                               [[[b'a', b'b'], [b'c', b'd'], [b'e', b'f']],
+                                [[b'g', b'jalapeno'], [b'kangaroo', b'llama'],
+                                 [b'manzana', b'nectar']]])
+
+  def test_already_dense_with_string(self):
+    """This studies a tensor initialized with value_rowids and nrows."""
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant(
+            ['a', 'b', 'c', 'd', 'e', 'antidisestablishmentarianism'],
+            dtype=dtypes.string),
+        value_rowids=constant_op.constant([0, 0, 0, 1, 1, 1],
+                                          dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(
+        input_data,
+        [[b'a', b'b', b'c'], [b'd', b'e', b'antidisestablishmentarianism']])
+
+  def test_already_dense(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [3, 4, 5]])
+    self._compare_to_reference(input_data, [[0, 1, 2], [3, 4, 5]])
+
+  def test_true_ragged(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [], [3]])
+    self._compare_to_reference(input_data, [[0, 1, 2], [0, 0, 0], [3, 0, 0]])
+
+  def test_true_ragged_default_3(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [], [3]])
+    self._compare_to_reference(
+        input_data, [[0, 1, 2], [3, 3, 3], [3, 3, 3]], default_value=3)
+
+  def test_three_dimensional_ragged(self):
+    input_data = ragged_factory_ops.constant([[[0, 1, 2], []], [], [[3]]])
+    self._compare_to_reference(
+        input_data, [[[0, 1, 2], [3, 3, 3]], [[3, 3, 3], [3, 3, 3]],
+                     [[3, 3, 3], [3, 3, 3]]],
+        default_value=3)
+
+  def test_empty_tensor(self):
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant([], dtype=dtypes.int64),
+        value_rowids=constant_op.constant([], dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(input_data, [[], []], default_value=3)
+
+  def test_empty_last(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [], [3], []])
+    self._compare_to_reference(input_data,
+                               [[0, 1, 2], [0, 0, 0], [3, 0, 0], [0, 0, 0]])
+
+  def test_shape_limit(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(input_data, shape=[2, 3])
+    self.assertAllEqual(actual, [[0, 1, 2], [0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [2, 3])
+
+  def test_shape_limit_tuple(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(input_data, shape=(2, 3))
+    self.assertAllEqual(actual, [[0, 1, 2], [0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [2, 3])
+
+  def test_shape_limit_tensor_shape(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, shape=tensor_shape.TensorShape([2, 3]))
+    self.assertAllEqual(actual, [[0, 1, 2], [0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [2, 3])
+
+  def test_shape_half_limit_tensor_shape(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, shape=tensor_shape.TensorShape([2, None]))
+    self.assertAllEqual(actual, [[0, 1, 2, 3], [0, 0, 0, 0]])
+
+  def test_skip_eager_shape_half_limit_tensor_shape(self):
+    # Eager would produce a shape of [2, 4]
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, shape=tensor_shape.TensorShape([2, None]))
+    result = actual.shape.as_list()
+    # This is equal to [2, 4] in eager, or [2, None] in non-eager.
+    self.assertEqual(result[0], 2)
+
+  def test_shape_limit_shape_is_tensor_int64(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, shape=constant_op.constant([2, 3], dtype=dtypes.int64))
+    self.assertAllEqual(actual, [[0, 1, 2], [0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [2, 3])
+
+  def test_shape_limit_shape_is_tensor_int32(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2, 3], [], [4], []])
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, shape=constant_op.constant([2, 3], dtype=dtypes.int32))
+    self.assertAllEqual(actual, [[0, 1, 2], [0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [2, 3])
+
+  def test_shape_expand_first_dim(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [], [3]])
+    actual = ragged_conversion_ops.ragged_to_dense(input_data, shape=[4, 4])
+    self.assertAllEqual(
+        actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0], [0, 0, 0, 0]])
+    self.assertEqual(actual.shape.as_list(), [4, 4])
+
+  def test_value_transposed(self):
+    # This test tries to get a tensor in columnar format, where I am uncertain
+    # as to whether the underlying op, which copies data in the raw format,
+    # could fail.
+    my_value = array_ops.transpose(
+        constant_op.constant([[0, 1, 2, 3], [4, 5, 6, 7]]))
+    input_data = RaggedTensor.from_value_rowids(
+        values=my_value,
+        value_rowids=constant_op.constant([0, 1, 2, 3], dtype=dtypes.int64),
+        nrows=constant_op.constant(4, dtype=dtypes.int64),
+        validate=True)
+    self._compare_to_reference(input_data,
+                               [[[0, 4]], [[1, 5]], [[2, 6]], [[3, 7]]])
+
+  # This fails on the older version of to_tensor.
+  def test_broadcast_default(self):
+    # This test is commented out. The functionality here is not supported.
+    # The dense dimension here is 2 x 2
+    input_data = ragged_factory_ops.constant([[[[1, 2], [3, 4]]], []],
+                                             ragged_rank=1)
+    # This placeholder has a 2 x 1 dimension.
+    default_value = array_ops.placeholder_with_default([[5], [6]], shape=None)
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, default_value=default_value)
+    expected = [[[[1, 2], [3, 4]]], [[[5, 5], [6, 6]]]]
+    self.assertAllEqual(actual, expected)
+
+  # This fails on the older version of to_tensor.
+  def test_broadcast_default_no_placeholder(self):
+    # Again, this functionality is not supported. It fails more gracefully
+    # when creating the op.
+    input_data = ragged_factory_ops.constant([[[[1, 2], [3, 4]]], []],
+                                             ragged_rank=1)
+    # default_value has a 2 x 1 dimension.
+    default_value = constant_op.constant([[5], [6]], shape=None)
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, default_value=default_value)
+    expected = [[[[1, 2], [3, 4]]], [[[5, 5], [6, 6]]]]
+    self.assertAllEqual(actual, expected)
+
+  def test_shape_expand_second_dim(self):
+    input_data = ragged_factory_ops.constant([[0, 1, 2], [], [3], []])
+    actual = ragged_conversion_ops.ragged_to_dense(input_data, shape=[3, 4])
+    self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]])
+
+  def test_empty_tensor_with_shape(self):
+    input_data = RaggedTensor.from_value_rowids(
+        values=constant_op.constant([], dtype=dtypes.int64),
+        value_rowids=constant_op.constant([], dtype=dtypes.int64),
+        nrows=constant_op.constant(2, dtype=dtypes.int64),
+        validate=True)
+    actual = ragged_conversion_ops.ragged_to_dense(
+        input_data, default_value=3, shape=[2, 3])
+    self.assertAllEqual(actual, [[3, 3, 3], [3, 3, 3]])
+
+
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/python/ops/ragged/string_ngrams_op_test.py b/tensorflow/python/ops/ragged/string_ngrams_op_test.py
new file mode 100644
index 0000000..464eb3b
--- /dev/null
+++ b/tensorflow/python/ops/ragged/string_ngrams_op_test.py
@@ -0,0 +1,278 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the b"License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an b"AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Tensorflow strings.ngrams op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_string_ops
+from tensorflow.python.platform import test
+
+
+class StringNgramsTest(test_util.TensorFlowTestCase):
+
+  def test_unpadded_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_tuple_multi_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(2, 3), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb", b"bb|cc", b"cc|dd", b"aa|bb|cc", b"bb|cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_tuple_multi_ngrams_inverted_order(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(3, 2), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb", b"bb|cc", b"cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_list_multi_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=[2, 3], separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb", b"bb|cc", b"cc|dd", b"aa|bb|cc", b"bb|cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_multi_ngram_ordering(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=[3, 2], separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb", b"bb|cc", b"cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_fully_padded_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|RP", b"a|RP|RP"],  # 0
+        [b"LP|LP|b", b"LP|b|c", b"b|c|d", b"c|d|RP", b"d|RP|RP"],  # 1
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"]  # 2
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ngram_padding_size_cap(self):
+    # Validate that the padding size is never greater than ngram_size - 1.
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=3,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=10)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|RP", b"a|RP|RP"],  # 0
+        [b"LP|LP|b", b"LP|b|c", b"b|c|d", b"c|d|RP", b"d|RP|RP"],  # 1
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"]  # 2
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[], [b"LP|b|c|d|RP"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_ngrams_with_preserve_short(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1,
+        preserve_short_sequences=True)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"LP|a|RP"], [b"LP|b|c|d|RP"], [b"LP|e|f|RP"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_multiple_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=(1, 5),
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"a"], [b"b", b"c", b"d", b"LP|b|c|d|RP"], [b"e", b"f"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_single_padding_string(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=b"[PAD]",
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[], [b"[PAD]|b|c|d|[PAD]"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_explicit_multiply_padded_ngrams(self):
+    data = [[b"a"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=2)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"LP|LP|a|RP|RP"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd"]], [[]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_and_preserve(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=3,
+        separator=b"|",
+        preserve_short_sequences=True)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd"]], [[b"ee|ff"]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_bigrams(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=2, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb", b"bb|cc", b"cc|dd"]], [[b"ee|ff"]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_and_multiple_ngrams(
+      self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(3, 4), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb|cc|dd"]], [[]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_dense_input_rank_3(self):
+    data = [[[b"a", b"z"], [b"b", b""]], [[b"b", b""], [b"e", b"f"]]]
+    data_tensor = constant_op.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"],
+                        [b"LP|LP|b", b"LP|b|", b"b||RP", b"|RP|RP"]],
+                       [[b"LP|LP|b", b"LP|b|", b"b||RP", b"|RP|RP"],
+                        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"]]]
+    self.assertIsInstance(ngram_op, ops.Tensor)
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_dense_input(self):
+    data = [[b"a", b"z"], [b"b", b""], [b"e", b"f"]]
+    data_tensor = constant_op.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"],
+        [b"LP|LP|b", b"LP|b|", b"b||RP", b"|RP|RP"],
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"],
+    ]
+    self.assertIsInstance(ngram_op, ops.Tensor)
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_input_list_input(self):
+    data = [[b"a", b"z"], [b"b", b""], [b"e", b"f"]]
+    ngram_op = ragged_string_ops.ngrams(
+        data, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"],
+        [b"LP|LP|b", b"LP|b|", b"b||RP", b"|RP|RP"],
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"],
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_vector_input(self):
+    data = [b"a", b"z"]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_dense_input_with_multiple_ngrams(self):
+    data = [[b"a", b"b", b"c", b"d"], [b"e", b"f", b"g", b"h"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(1, 2, 3), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[
+        b"a", b"b", b"c", b"d", b"a|b", b"b|c", b"c|d", b"a|b|c", b"b|c|d"
+    ], [b"e", b"f", b"g", b"h", b"e|f", b"f|g", b"g|h", b"e|f|g", b"f|g|h"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 8af6d66..b3414c5 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -73,6 +74,7 @@
         shape_tensor, dtype, seed=seed1, seed2=seed2)
     mul = rnd * stddev_tensor
     value = math_ops.add(mul, mean_tensor, name=name)
+    tensor_util.maybe_set_static_shape(value, shape)
     return value
 
 
@@ -129,6 +131,7 @@
         maxvals_tensor,
         seed=seed1,
         seed2=seed2)
+    tensor_util.maybe_set_static_shape(rnd, shape)
     return rnd
 
 
@@ -172,6 +175,7 @@
         shape_tensor, dtype, seed=seed1, seed2=seed2)
     mul = rnd * stddev_tensor
     value = math_ops.add(mul, mean_tensor, name=name)
+    tensor_util.maybe_set_static_shape(value, shape)
     return value
 
 
@@ -235,11 +239,17 @@
     maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
     seed1, seed2 = random_seed.get_seed(seed)
     if dtype.is_integer:
-      return gen_random_ops.random_uniform_int(
+      result = gen_random_ops.random_uniform_int(
           shape, minval, maxval, seed=seed1, seed2=seed2, name=name)
     else:
       rnd = gen_random_ops.random_uniform(shape, dtype, seed=seed1, seed2=seed2)
-      return math_ops.add(rnd * (maxval - minval), minval, name=name)
+      result = math_ops.add(rnd * (maxval - minval), minval, name=name)
+    # TODO(b/132092188): C++ shape inference inside functional ops does not
+    # cross FuncGraph boundaries since that information is only available in
+    # python. So we manually get the static shape using
+    # `constant_value_as_shape` which *does* cross function boundaries.
+    tensor_util.maybe_set_static_shape(result, shape)
+    return result
 
 
 ops.NotDifferentiable("RandomUniform")
@@ -332,7 +342,7 @@
   ```python
   # samples has shape [1, 5], where each value is either 0 or 1 with equal
   # probability.
-  samples = tf.random.categorical(tf.math.log([[10., 10.]]), 5)
+  samples = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5)
   ```
 
   Args:
@@ -360,7 +370,7 @@
   ```python
   # samples has shape [1, 5], where each value is either 0 or 1 with equal
   # probability.
-  samples = tf.random.categorical(tf.math.log([[10., 10.]]), 5)
+  samples = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5)
   ```
 
   Args:
@@ -390,6 +400,16 @@
 ops.NotDifferentiable("Multinomial")
 
 
+def _maybe_set_static_shape_helper(tensor, shape, postfix_tensor):
+  if (not context.executing_eagerly() and
+      ops.get_default_graph().building_function and
+      not tensor.shape.is_fully_defined()):
+    shape = tensor_util.shape_tensor(shape)
+    const_shape = tensor_util.constant_value_as_shape(shape)
+    postfix_tensor = ops.convert_to_tensor(postfix_tensor)
+    tensor.set_shape(const_shape.concatenate(postfix_tensor.shape))
+
+
 @tf_export("random.gamma", v1=["random.gamma", "random_gamma"])
 @deprecation.deprecated_endpoints("random_gamma")
 def random_gamma(shape,
@@ -468,10 +488,12 @@
         beta if beta is not None else 1, name="beta", dtype=dtype)
     alpha_broadcast = alpha + array_ops.zeros_like(beta)
     seed1, seed2 = random_seed.get_seed(seed)
-    return math_ops.maximum(
+    result = math_ops.maximum(
         np.finfo(dtype.as_numpy_dtype).tiny,
         gen_random_ops.random_gamma(
             shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
+    _maybe_set_static_shape_helper(result, shape, alpha_broadcast)
+    return result
 
 
 @tf_export(v1=["random.poisson", "random_poisson"])
@@ -553,5 +575,7 @@
   with ops.name_scope(name, "random_poisson", [lam, shape]):
     shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
     seed1, seed2 = random_seed.get_seed(seed)
-    return gen_random_ops.random_poisson_v2(
+    result = gen_random_ops.random_poisson_v2(
         shape, lam, dtype=dtype, seed=seed1, seed2=seed2)
+    _maybe_set_static_shape_helper(result, shape, lam)
+    return result
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index a99ddff..c336d2b 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -31,6 +31,7 @@
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import executor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
@@ -101,27 +102,30 @@
   def __call__(self, device, token, args):
     """Passes `args` to `self._func`, which is executed eagerly."""
 
-    with context.eager_mode(), backprop.GradientTape() as tape:
-      # Only watch tensors with a floating dtype.
-      for tensor in args:
-        for t in nest.flatten(tensor):
-          if t.dtype.is_floating:
-            tape.watch(t)
-      ret = self._func(*args)
-      # Use tf.identity to copy the returned tensors to device if neccesary.
-      with ops.device(device):
-        if isinstance(ret, (tuple, list)):
-          outputs = [
-              array_ops.identity(self._convert(x, dtype=dtype))
-              for (x, dtype) in zip(ret, self._out_dtypes)
-          ]
-        elif ret is None:
-          outputs = None
-        else:
-          outputs = array_ops.identity(
-              self._convert(ret, dtype=self._out_dtypes[0]))
-    tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
-    return outputs
+    func_executor = executor.new_executor(context.is_async())
+    with context.executor_scope(func_executor):
+      with context.eager_mode(), backprop.GradientTape() as tape:
+        # Only watch tensors with a floating dtype.
+        for tensor in args:
+          for t in nest.flatten(tensor):
+            if t.dtype.is_floating:
+              tape.watch(t)
+        ret = self._func(*args)
+        # Use tf.identity to copy the returned tensors to device if necessary.
+        with ops.device(device):
+          if isinstance(ret, (tuple, list)):
+            outputs = [
+                array_ops.identity(self._convert(x, dtype=dtype))
+                for (x, dtype) in zip(ret, self._out_dtypes)
+            ]
+          elif ret is None:
+            outputs = None
+          else:
+            outputs = array_ops.identity(
+                self._convert(ret, dtype=self._out_dtypes[0]))
+      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
+      return outputs
+    func_executor.wait()
 
 
 class FuncRegistry(object):
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index f6b26c8..ba86ba3 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -23,7 +23,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import numbers
 
 import numpy as np
@@ -46,6 +45,7 @@
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
 from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
 from tensorflow.python.util.tf_export import tf_export
 
@@ -1430,7 +1430,7 @@
 @tf_export("sparse.to_dense", v1=["sparse.to_dense", "sparse_tensor_to_dense"])
 @deprecation.deprecated_endpoints("sparse_tensor_to_dense")
 def sparse_tensor_to_dense(sp_input,
-                           default_value=0,
+                           default_value=None,
                            validate_indices=True,
                            name=None):
   """Converts a `SparseTensor` into a dense tensor.
@@ -1470,6 +1470,8 @@
     TypeError: If `sp_input` is not a `SparseTensor`.
   """
   sp_input = _convert_to_sparse_tensor(sp_input)
+  if default_value is None:
+    default_value = array_ops.zeros([], dtype=sp_input.dtype)
 
   return gen_sparse_ops.sparse_to_dense(
       sp_input.indices,
@@ -1658,10 +1660,10 @@
                       type(vocab_size))
     vocab_size = [vocab_size]
   else:
-    if not isinstance(sp_ids, collections.Iterable):
+    if not isinstance(sp_ids, collections_abc.Iterable):
       raise TypeError("sp_ids has to be a SparseTensor or list thereof. "
                       "Found %s" % type(sp_ids))
-    if not isinstance(vocab_size, collections.Iterable):
+    if not isinstance(vocab_size, collections_abc.Iterable):
       raise TypeError("vocab_size has to be a list of Tensors or Python ints. "
                       "Found %s" % type(vocab_size))
     for dim in vocab_size:
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index 992a330..f48c544 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -125,6 +125,14 @@
     epsilon = 1e-4
     self.assertLess(gradient_checker.max_error(*grads), epsilon)
 
+  def testSparseTensorToDenseString(self):
+    sp = sparse_tensor.SparseTensor(
+        indices=[[0, 0], [1, 2]], values=['a', 'b'], dense_shape=[2, 3])
+    dense = sparse_ops.sparse_tensor_to_dense(sp)
+    expected_dense = [[b'a', b'', b''], [b'', b'', b'b']]
+    result_dense = self.evaluate(dense)
+    self.assertAllEqual(expected_dense, result_dense)
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 7ad841c..2f350c1 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -533,7 +533,8 @@
   new_shape = tuple(-1 if x is None else x for x in new_shape)
   cur_shape = tuple(x.value for x in tensor.get_shape().dims)
   if (len(new_shape) == len(cur_shape) and
-      all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))):
+      all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
+          for d0, d1 in zip(cur_shape, new_shape))):
     return tensor
   else:
     return array_ops.reshape(tensor, new_shape)
diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py
index dc3f8ff..62f65ce 100644
--- a/tensorflow/python/ops/stateless_random_ops.py
+++ b/tensorflow/python/ops/stateless_random_ops.py
@@ -92,12 +92,14 @@
     minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
     maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
     if dtype.is_integer:
-      return gen_stateless_random_ops.stateless_random_uniform_int(
+      result = gen_stateless_random_ops.stateless_random_uniform_int(
           shape, seed=seed, minval=minval, maxval=maxval, name=name)
     else:
       rnd = gen_stateless_random_ops.stateless_random_uniform(
           shape, seed=seed, dtype=dtype)
-      return math_ops.add(rnd * (maxval - minval), minval, name=name)
+      result = math_ops.add(rnd * (maxval - minval), minval, name=name)
+    tensor_util.maybe_set_static_shape(result, shape)
+    return result
 
 
 @tf_export("random.stateless_normal")
@@ -134,7 +136,9 @@
     mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
     rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
-    return math_ops.add(rnd * stddev, mean, name=name)
+    result = math_ops.add(rnd * stddev, mean, name=name)
+    tensor_util.maybe_set_static_shape(result, shape)
+    return result
 
 
 @tf_export("random.stateless_truncated_normal")
@@ -177,7 +181,9 @@
     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
     rnd = gen_stateless_random_ops.stateless_truncated_normal(
         shape, seed, dtype)
-    return math_ops.add(rnd * stddev, mean, name=name)
+    result = math_ops.add(rnd * stddev, mean, name=name)
+    tensor_util.maybe_set_static_shape(result, shape)
+    return result
 
 
 @tf_export(v1=["random.stateless_multinomial"])
@@ -202,7 +208,7 @@
   # samples has shape [1, 5], where each value is either 0 or 1 with equal
   # probability.
   samples = tf.random.stateless_categorical(
-      tf.math.log([[10., 10.]]), 5, seed=[7, 17])
+      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
   ```
 
   Args:
@@ -241,7 +247,7 @@
   # samples has shape [1, 5], where each value is either 0 or 1 with equal
   # probability.
   samples = tf.random.stateless_categorical(
-      tf.math.log([[10., 10.]]), 5, seed=[7, 17])
+      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
   ```
 
   Args:
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index fab83c6..57fb8f5 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -135,7 +135,7 @@
     # of the first write. If `infer_shape` is true, all writes checks for
     # shape equality.
     self._element_shape = [tensor_shape.as_shape(element_shape)]
-    self._infer_shape = element_shape is not None or infer_shape
+    self._infer_shape = infer_shape
     with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
       if handle is not None:
         self._handle = handle
@@ -179,7 +179,7 @@
   def element_shape(self):
     return self._element_shape[0]
 
-  def _merge_element_shape(self, shape):
+  def _check_element_shape(self, shape):
     """Changes the element shape of the array given a shape to merge with.
 
     Args:
@@ -190,10 +190,10 @@
           element shape of the `TensorArray`.
     """
     if not shape.is_compatible_with(self.element_shape):
-      raise ValueError(
-          "Inconsistent shapes: saw %s but expected %s "
-          "(and infer_shape=True)" % (shape, self.element_shape))
-    self._element_shape[0] = self.element_shape.merge_with(shape)
+      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
+                       (shape, self.element_shape))
+    if self._infer_shape:
+      self._element_shape[0] = self.element_shape.merge_with(shape)
 
   @contextlib.contextmanager
   def _maybe_colocate_with(self, value):
@@ -266,8 +266,7 @@
       value = ops.convert_to_tensor(
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
-      if self._infer_shape:
-        self._merge_element_shape(value.shape)
+      self._check_element_shape(value.shape)
       with self._maybe_colocate_with(value):
         flow_out = gen_data_flow_ops.tensor_array_write_v3(
             handle=self._handle,
@@ -329,8 +328,8 @@
       value = ops.convert_to_tensor(
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
-      if self._infer_shape and not context.executing_eagerly():
-        self._merge_element_shape(value.shape[1:])
+      if not context.executing_eagerly():
+        self._check_element_shape(value.shape[1:])
       with self._maybe_colocate_with(value):
         flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
             handle=self._handle,
@@ -348,11 +347,11 @@
       value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
       with self._maybe_colocate_with(value):
         lengths_64 = math_ops.cast(lengths, dtypes.int64)
-        if self._infer_shape and not context.executing_eagerly():
+        if not context.executing_eagerly():
           clengths = tensor_util.constant_value(lengths_64)
-          if value.shape.dims is not None:
-            if clengths is not None and clengths.max() == clengths.min():
-              self._merge_element_shape(
+          if value.shape.dims is not None and clengths is not None:
+            if clengths.shape and clengths.max() == clengths.min():
+              self._check_element_shape(
                   tensor_shape.TensorShape([clengths[0]]).concatenate(
                       value.shape[1:]))
         flow_out = gen_data_flow_ops.tensor_array_split_v3(
@@ -447,7 +446,7 @@
     # of the first write. If `infer_shape` is true, all writes checks for
     # shape equality.
     self._element_shape = [tensor_shape.as_shape(element_shape)]
-    self._infer_shape = element_shape is not None or infer_shape
+    self._infer_shape = infer_shape
     with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
       if flow is None:
         self._flow = list_ops.tensor_list_reserve(
@@ -480,7 +479,7 @@
     # complain.
     return None
 
-  def _merge_element_shape(self, shape):
+  def _check_element_shape(self, shape):
     """Changes the element shape of the array given a shape to merge with.
 
     Args:
@@ -491,10 +490,10 @@
           element shape of the `TensorArray`.
     """
     if not shape.is_compatible_with(self.element_shape):
-      raise ValueError(
-          "Inconsistent shapes: saw %s but expected %s "
-          "(and infer_shape=True)" % (shape, self.element_shape))
-    self._element_shape[0] = self.element_shape.merge_with(shape)
+      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
+                       (shape, self.element_shape))
+    if self._infer_shape:
+      self._element_shape[0] = self.element_shape.merge_with(shape)
 
   def identity(self):
     """See TensorArray."""
@@ -524,8 +523,7 @@
       value = ops.convert_to_tensor(
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
-      if self._infer_shape:
-        self._merge_element_shape(value.shape)
+      self._check_element_shape(value.shape)
       flow_out = list_ops.tensor_list_set_item(
           input_handle=self._flow,
           index=index,
@@ -575,8 +573,7 @@
       value = ops.convert_to_tensor(
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
-      if self._infer_shape and not context.executing_eagerly():
-        self._merge_element_shape(value.shape[1:])
+      self._check_element_shape(value.shape[1:])
       flow_out = list_ops.tensor_list_from_tensor(
           tensor=value, element_shape=value.shape[1:])
       return build_ta_with_new_flow(self, flow_out)
@@ -590,8 +587,7 @@
       value = ops.convert_to_tensor(
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
-      if self._infer_shape and not context.executing_eagerly():
-        self._merge_element_shape(value.shape[1:])
+      self._check_element_shape(value.shape[1:])
       flow_out = list_ops.tensor_list_scatter(
           tensor=value, indices=indices, element_shape=self.element_shape,
           input_handle=self._flow)
@@ -606,11 +602,11 @@
           value, preferred_dtype=self._dtype, name="value")
       _check_dtypes(value, self._dtype)
       lengths_64 = math_ops.cast(lengths, dtypes.int64)
-      if self._infer_shape and not context.executing_eagerly():
+      if not context.executing_eagerly():
         clengths = tensor_util.constant_value(lengths_64)
-        if value.shape.dims is not None:
-          if clengths is not None and clengths.max() == clengths.min():
-            self._merge_element_shape(
+        if value.shape.dims is not None and clengths is not None:
+          if clengths.shape and clengths.max() == clengths.min():
+            self._check_element_shape(
                 tensor_shape.TensorShape([clengths[0]]).concatenate(
                     value.shape[1:]))
       flow_out = list_ops.tensor_list_split(
@@ -688,7 +684,7 @@
     # we assign a dummy value to _flow in case other code assumes it to be
     # a Tensor
     self._flow = constant_op.constant(0, dtype=dtypes.int32)
-    self._infer_shape = element_shape is not None or infer_shape
+    self._infer_shape = infer_shape
     self._element_shape = tensor_shape.as_shape(element_shape)
     self._colocate_with_first_write_call = colocate_with_first_write_call
 
@@ -804,12 +800,12 @@
           "TensorArray dtype is %s but Op is trying to write dtype %s" %
           (self._dtype.name, value.dtype.name))
 
+    if not self._element_shape.is_compatible_with(value.shape):
+      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
+                       (value.shape, self._element_shape))
+
     if self._infer_shape:
-      if not self._element_shape.is_compatible_with(value.shape):
-        raise ValueError("Incompatible shape for value (%s), expected (%s)" %
-                         (value.shape, self._element_shape))
-      else:
-        self._element_shape = self._element_shape.merge_with(value.shape)
+      self._element_shape = self._element_shape.merge_with(value.shape)
 
     self._tensor_array[index] = value
 
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 39ebc5f..8805a71 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -41,6 +41,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import compat
+from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_should_use
 from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
@@ -1080,7 +1081,11 @@
     setattr(cls, operator, _run_op)
 
   def __hash__(self):
-    return id(self)
+    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
+      raise TypeError("Variable is unhashable if Tensor equality is enabled. "
+                      "Instead, use tensor.experimental_ref() as the key.")
+    else:
+      return id(self)
 
   # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
   def __eq__(self, other):
@@ -1210,6 +1215,59 @@
   def _get_save_slice_info(self):
     return self._save_slice_info
 
+  def experimental_ref(self):
+    # tf.Tensor also has the same experimental_ref() API.  If you update the
+    # documenation here, please update tf.Tensor.experimental_ref() as well.
+    """Returns a hashable reference object to this Variable.
+
+    Warning: Experimental API that could be changed or removed.
+
+    The primary usecase for this API is to put variables in a set/dictionary.
+    We can't put variables in a set/dictionary as `variable.__hash__()` is no
+    longer available starting Tensorflow 2.0.
+
+    ```python
+    import tensorflow as tf
+
+    x = tf.Variable(5)
+    y = tf.Variable(10)
+    z = tf.Variable(10)
+
+    # The followings will raise an exception starting 2.0
+    # TypeError: Variable is unhashable if Variable equality is enabled.
+    variable_set = {x, y, z}
+    variable_dict = {x: 'five', y: 'ten'}
+    ```
+
+    Instead, we can use `variable.experimental_ref()`.
+
+    ```python
+    variable_set = {x.experimental_ref(),
+                    y.experimental_ref(),
+                    z.experimental_ref()}
+
+    print(x.experimental_ref() in variable_set)
+    ==> True
+
+    variable_dict = {x.experimental_ref(): 'five',
+                     y.experimental_ref(): 'ten',
+                     z.experimental_ref(): 'ten'}
+
+    print(variable_dict[y.experimental_ref()])
+    ==> ten
+    ```
+
+    Also, the reference object provides `.deref()` function that returns the
+    original Variable.
+
+    ```python
+    x = tf.Variable(5)
+    print(x.experimental_ref().deref())
+    ==> <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
+    ```
+    """
+    return object_identity.Reference(self)
+
   class SaveSliceInfo(object):
     """Information on how to save this Variable as a slice.
 
@@ -2704,7 +2762,7 @@
   """
   op_type = op.node_def.op
   if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
-                 "ReadVariableOp"):
+                 "ReadVariableOp", "If"):
     return op
 
   # Attempt to find the initialized_value of any variable reference / handles.
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 7527c5c..ea57451 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -23,6 +23,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
@@ -30,6 +32,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -43,6 +46,7 @@
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import while_v2_indexed_slices_rewriter
 from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
 
 # pylint: disable=protected-access
 
@@ -60,7 +64,8 @@
                parallel_iterations=10,
                maximum_iterations=None,
                name=None,
-               return_same_structure=True):
+               return_same_structure=True,
+               back_prop=True):
   """Like tf.while_loop, except emits a single While op."""
   # Keep the original loop_vars around to know which args were TensorArrays.
   orig_loop_vars = loop_vars
@@ -117,19 +122,23 @@
     # graphs. Propagate that behavior here.
     add_control_dependencies = ops.get_default_graph()._add_control_dependencies
 
-    # Build a `cond` wrapper that can handle the extra counter loop_var.
     def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
+      """Extra `cond` wrapper that can handle the extra counter loop_var."""
       # Convert the flow variables in `args` to TensorArrays. `args` should
       # already have the same structure as `orig_loop_vars` but currently there
       # is no nest.zip so we call `_pack_sequence_as` which flattens both
       # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
       # and packs it into the structure of `orig_loop_vars`.
+      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
+      if (tensor_util.is_tensor(pred) and
+          (pred.shape.dims is None or pred.shape.dims)):
+        pred = array_ops.squeeze_v2(pred)
+
       if maximum_iterations is None:
-        return cond(*_pack_sequence_as(orig_loop_vars, args))
+        return pred
       else:
         return math_ops.logical_and(
-            loop_counter < maximum_iterations_arg,
-            cond(*_pack_sequence_as(orig_loop_vars, args)))
+            loop_counter < maximum_iterations_arg, pred)
 
     # NOTE(skyewm): we set collections to the outer graph's collections for
     # compatibility with TPUEstimator.
@@ -202,8 +211,10 @@
       num_cond_captures = len(cond_graph.external_captures)
       assert (cond_graph.external_captures ==
               body_graph.external_captures[:num_cond_captures])
+      cond_graph_captures = object_identity.ObjectIdentitySet(
+          cond_graph.external_captures)
       for body_capture in body_graph.external_captures[num_cond_captures:]:
-        assert body_capture not in cond_graph.captures
+        assert body_capture not in cond_graph_captures
         cond_graph.capture(body_capture)
 
     # Make sure that the shapes of the loop outputs are compatible with the
@@ -221,6 +232,32 @@
                              len_orig_loop_vars], expand_composites=True),
         nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                                len_orig_loop_vars], expand_composites=True))
+
+    num_original_outputs = len(body_graph.outputs)
+    if back_prop and util.output_all_intermediates():
+      # Export all tensors in the loop body that may be needed for gradient
+      # computation. We do this by accumulating the intermediate values in
+      # TensorLists.
+      intermediate_tensors = _get_intermediates(body_graph)
+
+      for intermediate_tensor in intermediate_tensors:
+        tensor_list = list_ops.empty_tensor_list(
+            element_dtype=intermediate_tensor.dtype,
+            element_shape=intermediate_tensor.shape,
+            max_num_elements=maximum_iterations)
+        loop_vars.append(tensor_list)
+        with cond_graph.as_default():
+          # Add a placeholder to cond_graph's inputs corresponding to the
+          # tensor_list.
+          cond_graph.capture(tensor_list)
+        with body_graph.as_default():
+          # Push the intermediate tensor to the tensor list. This captures the
+          # `tensor_list` as well.
+          appended_tensor_list = list_ops.tensor_list_push_back(
+              tensor_list, intermediate_tensor)
+          # Add this modified tensor list to the list of outputs.
+          body_graph.outputs.append(appended_tensor_list)
+
     flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
     _check_num_inputs_outputs(cond_graph, body_graph,
                               len(flattened_loop_vars))
@@ -233,13 +270,32 @@
                                    first_loop_var_index + num_flattened_outputs)
       output_shapes[orig_loop_vars_range] = nest.flatten(
           shape_invariants, expand_composites=True)[orig_loop_vars_range]
-      outputs = gen_functional_ops._while(
+
+      cond_stateful_ops = [
+          op for op in cond_graph.get_operations() if op._is_stateful
+      ]
+      body_stateful_ops = [
+          op for op in body_graph.get_operations() if op._is_stateful
+      ]
+      # TODO(yanhuasun): Remove this after Aug 23, 2019. This is required to
+      # abide by 3-week forward compat window of new TF python op generating
+      # code with stale runtime binaries.
+      if (cond_stateful_ops or body_stateful_ops or
+          not compat.forward_compatible(2019, 8, 23)):
+        op_fn = gen_functional_ops._while
+      else:
+        op_fn = gen_functional_ops.stateless_while
+
+      outputs = op_fn(
           flattened_loop_vars,
           util.create_new_tf_function(cond_graph),
           util.create_new_tf_function(body_graph),
           output_shapes=output_shapes,
           parallel_iterations=parallel_iterations,
           name=scope)
+      # This is needed so we do not compute derivative wrt these extra outputs.
+      outputs[0].op._set_attr("_num_original_outputs",
+                              attr_value_pb2.AttrValue(i=num_original_outputs))
 
     _copy_handle_data(body_graph.outputs, outputs)
     util.maybe_set_lowering_attr(outputs[0].op)
@@ -267,6 +323,7 @@
     return outputs
 
 
+@ops.RegisterGradient("StatelessWhile")
 @ops.RegisterGradient("While")
 def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   """The gradient of a While op produced by while_loop."""
@@ -282,9 +339,19 @@
   maximum_iterations = op.inputs[1]
   parallel_iterations = op.get_attr("parallel_iterations")
 
-  grads = [_preprocess_grad(grad, body_out, while_out)
-           for grad, body_out, while_out
-           in zip(grads, body_graph.outputs, while_op.outputs)]
+  try:
+    num_original_outputs = while_op.get_attr("_num_original_outputs")
+  except:  # pylint: disable=bare-except
+    num_original_outputs = len(while_op.outputs)
+
+  num_intermediates = len(while_op.outputs) - num_original_outputs
+  grads = [
+      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
+      for grad, body_out, while_out in zip(
+          grads[:num_original_outputs],
+          body_graph.outputs[:num_original_outputs],
+          while_op.outputs[:num_original_outputs])
+  ] + [None] * num_intermediates
 
   # We compute the gradient for the sub-graph between trainable ys and xs
   # with non-None incoming gradients. We later pad the None's to the list of
@@ -317,6 +384,11 @@
                           [t.shape for t in new_outputs])
     _copy_handle_data(new_outputs, op.outputs[orig_num_params:])
 
+  # Do not ingore grads wrt extra outputs when computing higher order
+  # derivatives.
+  while_op._set_attr("_num_original_outputs",
+                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))
+
   captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                            while_op)
   loop_vars = args + captured_inputs
@@ -354,6 +426,48 @@
   return _get_structured_grad_output(outputs, grads, body_grad_graph)
 
 
+def _get_intermediates(func_graph):
+  """Returns all tensors in `func_graph` that should be accumulated."""
+  # We currently accumulate output tensors of most ops in the function and rely
+  # on the pruning pass to get rid of the unused accumulators at runtime.
+  # However, this can bloat the GraphDef and make debugging harder so we perform
+  # some optimizations.
+  #
+  # Optimization we currently perform:
+  # 1. We do not accumulate tensors which already have an accumulator
+  #    in the loop body.
+  # 2. We do not accumulate outputs of Identity nodes. When building the
+  #    FuncGraph, we add an Identity node for each output (see
+  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
+  #    Since the gradient of an Identity node does not rely on its forward op's
+  #    input this is safe to do.
+  #
+  # Other possible optimizations:
+  # 1. Only accumulate tensors that will be required by the backward pass.
+  #    This will require running the gradient pass and hence would increase the
+  #    graph building time for the forward pass.
+  # 2. Do not accumulate Const nodes created inside the loop body.
+  # 3. Do not accumulate loop vars that are returned as-is just like captured
+  #    tensors.
+  intermediates = []
+  reverse_captures = dict((v, k) for k, v in func_graph.captures)
+
+  for op in func_graph.get_operations():
+    if op.type == "Identity":
+      continue
+    # Accumulating mutexes can cause deadlock.
+    if op.type == "MutexLock":
+      continue
+    for o in op.outputs:
+      if (o != func_graph.inputs[0] and  # Loop counter.
+          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
+          _get_accumulator(o) is None and  # Has existing accumulator.
+          o not in reverse_captures):  # Captured value, hence loop invariant.
+        intermediates.append(o)
+  return intermediates
+
+
 def _preprocess_grad(grad, body_graph_output, while_op_output):
   """Returns the initial gradient to be used for a given output tensor.
 
@@ -476,8 +590,8 @@
   # the output of `_is_loop_invariant`. Also we would never attempt to capture
   # those accumulators so `_is_loop_invariant` should never receive those new
   # tensors as args.
-  body_graph_inputs = frozenset(body_graph.inputs)
-  body_graph_outputs = frozenset(body_graph.outputs)
+  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
+  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)
 
   args = [counter, maximum_iterations, total_iters] + list(grads)
   # Note: The returned function does not have `args` in the list of
@@ -497,9 +611,10 @@
   #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
   # 2. Resources, which are output as is.
   # 3. Forward graph loop invariants, which are output as is.
-  for external_capture, internal_capture in grad_func_graph.captures.items():
-    if internal_capture in grad_func_graph.popped_tensor_lists:
-      new_output = grad_func_graph.popped_tensor_lists[internal_capture]
+  for external_capture, internal_capture in grad_func_graph.captures:
+    if ops.tensor_id(internal_capture) in grad_func_graph.popped_tensor_lists:
+      new_output = grad_func_graph.popped_tensor_lists[ops.tensor_id(
+          internal_capture)]
     elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
         external_capture, body_graph_inputs, body_graph_outputs)):
       new_output = internal_capture
@@ -582,7 +697,11 @@
     # regular non-captured inputs).
     if t.graph == body_graph:
       # Captured accumulator or loop invariant.
-      t = while_op.outputs[t.graph.outputs.index(t)]
+      for i, output in enumerate(t.graph.outputs):
+        if output is t:
+          t = while_op.outputs[i]
+          break
+
       # Note: We rely on the capturing logic of the gradient While op graph to
       # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
       # and while_v2 handle this while building their gradient functions.
@@ -666,8 +785,9 @@
 
   def get_func_graph_output(t):
     """Returns t or Identity(t) whichever exists in graph outputs else None."""
-    if t in tensor.graph.outputs:
-      return t
+    for output in tensor.graph.outputs:
+      if output is t:
+        return t
     # tf.defun adds an Identity for each output, check whether that is the case.
     identity_op = t.consumers()[0]
     if (identity_op.type == "Identity" and
@@ -678,8 +798,14 @@
   for consumer in tensor.consumers():
     # Find the consumer that is a TensorListPushBack node whose TensorList input
     # is in the list of function inputs.
-    if (consumer.type != "TensorListPushBack" or
-        consumer.inputs[0] not in tensor.graph.inputs):
+    if consumer.type != "TensorListPushBack":
+      continue
+
+    accum_input_idx = -1
+    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
+      if inp is consumer.inputs[0]:
+        break
+    else:
       continue
 
     output = get_func_graph_output(consumer.outputs[0])
@@ -688,10 +814,12 @@
       # outputs.
       continue
 
-    accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
-    accum_output_idx = tensor.graph.outputs.index(output)
-    if accum_input_idx == accum_output_idx:
-      return output
+    for accum_output_idx, out in enumerate(tensor.graph.outputs):
+      if out is output:
+        if accum_input_idx == accum_output_idx:
+          return output
+        break
+
   return None
 
 
@@ -798,10 +926,33 @@
       # the input of the Identity node instead.
       tensor = tensor.op.inputs[0]
 
-    captured_tensor = self._indirect_captures.get(tensor)
+    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
     if captured_tensor is not None:
       return captured_tensor
 
+    # Do not accumulate loop invariants.
+    if (any(tensor is t for t in self._forward_graph.inputs) and
+        any(tensor is t for t in self._forward_graph.outputs)):
+      captured_tensor = super(_WhileBodyGradFuncGraph,
+                              self)._capture_helper(tensor, name)
+      # Add to `popped_tensor_lists` so that this gets added to the list of
+      # outputs.
+      # TODO(srbs): Rename popped_tensor_lists.
+      self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
+      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
+      return captured_tensor
+
+    # Do not accumulate Const nodes. Instead copy them directly in the backward
+    # graph.
+    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
+    # graph compile time consts in general.
+    # TODO(srbs): Consider making this a loop input.
+    if constant_op.is_constant(tensor):
+      real_value = constant_op.constant(
+          tensor_util.constant_value(tensor), dtype=tensor.dtype)
+      self._indirect_captures[ops.tensor_id(tensor)] = real_value
+      return real_value
+
     # Resource tensors are not accumulated and handled specially.
     if tensor.dtype == dtypes.resource:
       return self._resource_capture_helper(tensor)
@@ -865,8 +1016,9 @@
     new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
         captured_accumulator, element_dtype=tensor.dtype)
 
-    self._indirect_captures[tensor] = captured_tensor
-    self.popped_tensor_lists[captured_accumulator] = new_tensor_list
+    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
+    self.popped_tensor_lists[ops.tensor_id(
+        captured_accumulator)] = new_tensor_list
     return captured_tensor
 
   def _resource_capture_helper(self, tensor):
@@ -895,13 +1047,12 @@
     assert input_placeholder.dtype == dtypes.resource
     assert tensor_in_outer_graph.dtype == dtypes.resource
     # This must be a loop invariant.
-    assert input_placeholder == self._forward_graph.outputs[index], (
-        "Resource tensors must be loop invariants %s." %
-        tensor_in_outer_graph)
+    assert input_placeholder is self._forward_graph.outputs[index], (
+        "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
 
-    self._indirect_captures[tensor] = self.capture(
+    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
         tensor_in_outer_graph, whitelisted=True)
-    return self._indirect_captures[tensor]
+    return self._indirect_captures[ops.tensor_id(tensor)]
 
 
 def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
diff --git a/tensorflow/python/platform/remote_utils.py b/tensorflow/python/platform/remote_utils.py
new file mode 100644
index 0000000..9ec2e5e
--- /dev/null
+++ b/tensorflow/python/platform/remote_utils.py
@@ -0,0 +1,22 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Platform-specific helpers for connecting to remote servers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def get_default_communication_protocol():
+  return 'grpc'
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index 2d00714..6dbc235 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -164,3 +164,12 @@
         "@com_google_pprof//:pprof_proto_py",
     ],
 )
+
+py_library(
+    name = "traceme",
+    srcs = ["traceme.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:platform",
+    ],
+)
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD
index 47f2a09..c457bf0 100644
--- a/tensorflow/python/profiler/internal/BUILD
+++ b/tensorflow/python/profiler/internal/BUILD
@@ -67,6 +67,7 @@
         "//tensorflow/python:random_ops",
     ],
     tags = [
+        "no_gpu",  # b/138442728
         "no_pip",
     ],
     xla_enable_strict_auto_jit = False,  # Node names are different with autojit
diff --git a/tensorflow/python/profiler/traceme.py b/tensorflow/python/profiler/traceme.py
new file mode 100644
index 0000000..3bd9a66
--- /dev/null
+++ b/tensorflow/python/profiler/traceme.py
@@ -0,0 +1,39 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TraceMe allows the profiler to trace python events.
+
+Usage:
+    with profiler.TraceMe('name'):
+      ...
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+
+
+class TraceMe(object):
+  """Context manager that generates a trace event in the profiler."""
+
+  def __init__(self, name):
+    self._traceme = pywrap_tensorflow.PythonTraceMe(name)
+
+  def __enter__(self):
+    self._traceme.Enter()
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self._traceme.Exit()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index e9d4bdd..314502e 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -41,17 +41,18 @@
 %rename("%s") TFE_ContextGetMirroringPolicy;
 %rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
 %rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
-%rename("%s") TFE_ContextSetAsyncForThread;
 %rename("%s") TFE_ContextSetServerDef;
-%rename("%s") TFE_ContextAsyncWait;
-%rename("%s") TFE_ContextAsyncClearError;
+%rename("%s") TFE_NewExecutor;
+%rename("%s") TFE_DeleteExecutor;
+%rename("%s") TFE_ExecutorIsAsync;
+%rename("%s") TFE_ExecutorWaitForAllPendingNodes;
+%rename("%s") TFE_ExecutorClearError;
+%rename("%s") TFE_ContextSetExecutorForThread;
+%rename("%s") TFE_ContextGetExecutorForThread;
 %rename("%s") TFE_NewProfiler;
 %rename("%s") TFE_ProfilerIsOk;
 %rename("%s") TFE_DeleteProfiler;
 %rename("%s") TFE_ProfilerSerializeToString;
-%rename("%s") TFE_NewProfilerContext;
-%rename("%s") TFE_ProfilerContextSetEagerContext;
-%rename("%s") TFE_DeleteProfilerContext;
 %rename("%s") TFE_StartProfilerServer;
 %rename("%s") TFE_ProfilerClientStartTracing;
 %rename("%s") TFE_ProfilerClientMonitor;
@@ -165,6 +166,8 @@
 %rename("%s") TFE_CancellationManagerIsCancelled;
 %rename("%s") TFE_CancellationManagerStartCancel;
 %rename("%s") TFE_DeleteCancellationManager;
+%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
+%rename("%s") TFE_ClearScalarCache;
 
 %{
 #include "tensorflow/python/eager/pywrap_tfe.h"
@@ -192,6 +195,16 @@
 %}
 static PyObject* TF_ListPhysicalDevices(TF_Status* status);
 
+%{
+#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
+
+static PyObject* TFE_ClearScalarCache() {
+  tensorflow::TFE_TensorHandleCache::Get()->Clear();
+  Py_RETURN_NONE;
+}
+%}
+static PyObject* TFE_ClearScalarCache();
+
 %typemap(in) (const void* proto) {
   char* c_string;
   Py_ssize_t py_size;
@@ -220,6 +233,7 @@
 }
 
 %typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
+  tmp = 0;
   $1 = &tmp;
 }
 
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 29ce69c..1ca3804 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -191,6 +191,7 @@
     srcs_version = "PY2AND3",
     deps = [
         ":constants",
+        ":nested_structure_coder",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:lib",
@@ -481,6 +482,12 @@
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework",
+        "//tensorflow/python:tensor_array_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:optional_ops",
+        "//tensorflow/python/distribute:values",
+        "//tensorflow/python/ops/ragged",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 65cffc6..29b62a6 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -155,14 +155,14 @@
   def _validate_tensor_info(self, tensor_info):
     """Validates the `TensorInfo` proto.
 
-    Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
-    and are non-empty.
+    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
+    `dtype` fields exist and are non-empty.
 
     Args:
       tensor_info: `TensorInfo` protocol buffer to validate.
 
     Raises:
-      AssertionError: If the `name` or `dtype` fields of the supplied
+      AssertionError: If the `encoding` or `dtype` fields of the supplied
           `TensorInfo` proto are not populated.
     """
     if tensor_info is None:
@@ -175,7 +175,10 @@
           "All TensorInfo protos used in the SignatureDefs must have one of "
           "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
           % tensor_info)
-    if tensor_info.dtype is types_pb2.DT_INVALID:
+    if tensor_info.WhichOneof("encoding") == "composite_tensor":
+      for component in tensor_info.composite_tensor.components:
+        self._validate_tensor_info(component)
+    elif tensor_info.dtype == types_pb2.DT_INVALID:
       raise AssertionError(
           "All TensorInfo protos used in the SignatureDefs must have the dtype "
           "field set: %s" % tensor_info)
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index 97d9989..b02f8b00 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -29,6 +29,7 @@
 from tensorflow.python.framework import function_def_to_graph as function_def_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import nested_structure_coder
@@ -60,9 +61,11 @@
     The structured function output.
   """
   expected_structure = function.graph.structured_input_signature
-  flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
+  flatten_inputs = nest.flatten_up_to(
+      expected_structure, inputs, expand_composites=True)
+  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
   tensor_inputs = []
-  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
+  for arg, expected in zip(flatten_inputs, flatten_expected):
     if isinstance(expected, tensor_spec.TensorSpec):
       tensor_inputs.append(
           ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
@@ -111,6 +114,8 @@
         return False
       if not expected.shape.is_compatible_with(arg.shape):
         return False
+    elif isinstance(expected, type_spec.TypeSpec):
+      return expected.is_compatible_with(arg)
     else:
       if arg != expected:
         return False
@@ -317,7 +322,7 @@
 
     # Also register the gradients in the current root context.
     with ops.init_scope():
-      func._register_gradient()  # pylint: disable=protected-access
+      func._register_delayed_rewrite_gradient()  # pylint: disable=protected-access
 
   return functions
 
@@ -363,8 +368,8 @@
       # function call is the default gradient for the function and not a
       # custom one.
       fname = node_def.attr["f"].func.name
-      node_def.attr["_gradient_op_type"].s = compat.as_bytes(
-          functions[fname]._gradient_name)  # pylint: disable=protected-access
+      gradient_name = functions[fname]._register_delayed_rewrite_gradient()  # pylint: disable=protected-access
+      node_def.attr["_gradient_op_type"].s = compat.as_bytes(gradient_name)
     else:
       logging.warning("Importing a function (%s) with ops with custom "
                       "gradients. Will likely fail if a gradient is "
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index bf62c6b..88f0f81 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -176,22 +176,27 @@
       if bound_inputs:
         for bound_input, internal_capture in zip(
             bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-          concrete_function.graph.captures[bound_input] = internal_capture
-          if internal_capture.dtype == dtypes.resource:
-            if resource_variable_ops.is_resource_variable(bound_input):
-              try:
-                handle = bound_input.handle
-              except ValueError:
-                # For mirrored variables we'll copy handle data for components
-                # as they get captured.
-                pass
+          if ds_values.is_distributed_variable(bound_input):
+            concrete_function.graph.capture_distributed_variable(
+                bound_input, internal_capture)
+          else:
+            concrete_function.graph._captures[ops.tensor_id(bound_input)] = (  # pylint: disable=protected-access
+                bound_input, internal_capture)
+            if internal_capture.dtype == dtypes.resource:
+              if resource_variable_ops.is_resource_variable(bound_input):
+                try:
+                  handle = bound_input.handle
+                except ValueError:
+                  # For mirrored variables we'll copy handle data for components
+                  # as they get captured.
+                  pass
+                else:
+                  custom_gradient.copy_handle_data(handle, internal_capture)
               else:
-                custom_gradient.copy_handle_data(handle, internal_capture)
-            else:
-              custom_gradient.copy_handle_data(bound_input, internal_capture)
-          # Setting "captures" first means "capture" won't create a new
-          # placeholder for this input.
-          concrete_function.graph.capture(bound_input)
+                custom_gradient.copy_handle_data(bound_input, internal_capture)
+            # Setting "captures" first means "capture" won't create a new
+            # placeholder for this input.
+            concrete_function.graph.capture(bound_input)
 
   def _get_tensor_from_node(self, node_id):
     """Resolves a node id into a tensor to be captured for a function."""
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index e28ee4b..102b93e 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -32,7 +32,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.eager import wrap_function
-from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -42,6 +42,7 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
+from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import input_layer
 from tensorflow.python.keras.engine import sequential
@@ -58,6 +59,8 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import tag_constants
@@ -67,34 +70,35 @@
 from tensorflow.python.util import tf_inspect
 
 
+def cycle(obj, cycles, signatures=None):
+  to_save = obj
+  # TODO(vbardiovsky): It would be nice if exported protos reached a fixed
+  # point w.r.t. saving/restoring, ideally after 2nd saving.
+  for _ in range(cycles):
+    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
+    # If available, we'll run the save and restore preferring the GPU. This
+    # just makes sure we aren't throwing errors and have enough
+    # device("CPU") blocks to satisfy the placer.
+    with test_util.use_gpu():
+      save.save(to_save, path, signatures)
+      loaded = load.load(path)
+    to_save = loaded
+  return loaded
+
+
 @parameterized.named_parameters(
     dict(testcase_name="ReloadOnce", cycles=1),
     dict(testcase_name="ReloadTwice", cycles=2),
     dict(testcase_name="ReloadThrice", cycles=3))
 class LoadTest(test.TestCase, parameterized.TestCase):
 
-  def cycle(self, obj, cycles, signatures=None):
-    to_save = obj
-    # TODO(vbardiovsky): It would be nice if exported protos reached a fixed
-    # point w.r.t. saving/restoring, ideally after 2nd saving.
-    for _ in range(cycles):
-      path = tempfile.mkdtemp(prefix=self.get_temp_dir())
-      # If available, we'll run the save and restore preferring the GPU. This
-      # just makes sure we aren't throwing errors and have enough
-      # device("CPU") blocks to satisfy the placer.
-      with test_util.use_gpu():
-        save.save(to_save, path, signatures)
-        loaded = load.load(path)
-      to_save = loaded
-    return loaded
-
   def test_structure_import(self, cycles):
     root = tracking.AutoTrackable()
     root.dep_one = tracking.AutoTrackable()
     root.dep_two = tracking.AutoTrackable()
     root.dep_two.dep = tracking.AutoTrackable()
     root.dep_three = root.dep_two.dep
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertIs(imported.dep_three, imported.dep_two.dep)
     self.assertIsNot(imported.dep_one, imported.dep_two)
 
@@ -102,7 +106,7 @@
     root = tracking.AutoTrackable()
     root.v1 = variables.Variable(1., trainable=True)
     root.v2 = variables.Variable(2., trainable=False)
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(imported.v1.numpy(), 1.0)
     self.assertTrue(imported.v1.trainable)
     self.assertEqual(imported.v2.numpy(), 2.0)
@@ -114,13 +118,13 @@
     # is based on object name and not on variable name.
     root.v1 = variables.Variable(1., trainable=True, name="v1")
     root.v2 = variables.Variable(2., trainable=False, name="v1")
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(imported.v1.numpy(), 1.0)
     self.assertEqual(imported.v2.numpy(), 2.0)
     self.assertEqual(imported.v1.name, root.v1.name)
     self.assertEqual(imported.v2.name, root.v2.name)
     with variable_scope.variable_scope("foo"):
-      imported = self.cycle(root, cycles)
+      imported = cycle(root, cycles)
       self.assertTrue(imported.v1.name.startswith("foo/"))
       self.assertTrue(imported.v2.name.startswith("foo/"))
 
@@ -139,7 +143,7 @@
 
     m = MakeVariable()
     m.make_variable([1, 2, 3])
-    m = self.cycle(m, cycles)
+    m = cycle(m, cycles)
     m.v.assign([1, 2, 3, 4])
     self.assertEqual([None], tensor_shape.as_shape(m.v.shape).as_list())
 
@@ -152,7 +156,7 @@
         lambda x: root.weights * x,
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
     for _ in range(cycles):
-      imported = self.cycle(root, 1)
+      imported = cycle(root, 1)
       self.evaluate(imported.weights.initializer)
     self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.))))
     self.evaluate(imported.weights.assign(4.0))
@@ -165,7 +169,7 @@
     root.f = def_function.function(
         lambda x: captured_constant * x,
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.))))
 
   def test_control_outputs(self, cycles):
@@ -178,7 +182,7 @@
         exported_graph.get_operation_by_name("should_be_control_output"),
         exported_graph.control_outputs)
 
-    imported = self.cycle(exported, cycles)
+    imported = cycle(exported, cycles)
     # Calling get_concrete_function wraps in a second call operation; we want to
     # inspect the original function body for the control output; digging into
     # graph.as_graph_def() and its FunctionDefLibrary is another option.
@@ -244,7 +248,7 @@
 
     root = Adder()
     root.add(constant_op.constant(1.))
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     root.add(constant_op.constant(1.))
 
   def test_capture_assets(self, cycles):
@@ -253,7 +257,7 @@
     root.f = def_function.function(
         lambda: root.vocab.asset_path,
         input_signature=[])
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     original_output = root.f().numpy()
     imported_output = imported.f().numpy()
     self.assertNotEqual(original_output, imported_output)
@@ -270,7 +274,7 @@
     original_output = root.f().numpy()
 
     if cycles > 1:
-      root = self.cycle(root, cycles - 1)
+      root = cycle(root, cycles - 1)
     path = tempfile.mkdtemp(prefix=self.get_temp_dir())
     save.save(root, path)
 
@@ -288,7 +292,7 @@
     root = tracking.AutoTrackable()
     root.asset1 = tracking.TrackableAsset(vocab)
     root.asset2 = tracking.TrackableAsset(vocab)
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(imported.asset1.asset_path.numpy(),
                      imported.asset2.asset_path.numpy())
 
@@ -304,7 +308,7 @@
     root.f(constant_op.constant(1.))
     root.f(constant_op.constant(1))
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
     self.assertEqual(14, imported.f(constant_op.constant(7)).numpy())
@@ -318,7 +322,7 @@
     root = tracking.AutoTrackable()
     root.f = func
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy())
 
   def test_explicit_save_signature(self, cycles):
@@ -329,7 +333,7 @@
     root = tracking.AutoTrackable()
     root.f = func
 
-    imported = self.cycle(
+    imported = cycle(
         root, cycles, {
             "f":
                 root.f.get_concrete_function(
@@ -347,7 +351,7 @@
 
     root = tracking.AutoTrackable()
     root.g = g
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     imported.g(constant_op.constant([1.0]))
 
   def test_function_with_default_bool_input(self, cycles):
@@ -365,7 +369,7 @@
     self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
     self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
     self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
@@ -395,7 +399,7 @@
     concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
     self.assertEqual(4, len(concrete_functions))
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     self.assertAllEqual([0.0, 0.0, 0.0],
                         imported.f(constant_op.constant([1, 2, 3]),
@@ -429,7 +433,7 @@
     obj.increase()
     self.assertEqual(16.0, obj.variable.numpy())
 
-    imported = self.cycle(obj, cycles)
+    imported = cycle(obj, cycles)
 
     imported.increase(constant_op.constant(10.0))
     self.assertEqual(26.0, imported.variable.numpy())
@@ -460,7 +464,7 @@
     # matching signature will be valid on the loaded model.
     self.assertEqual(31, root.f(input1).numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     with self.assertRaisesRegexp(ValueError,
                                  "Could not find matching function to call"):
@@ -489,7 +493,7 @@
     self.assertEqual(3, result[1].numpy())
     self.assertEqual(0.5, result[2]["x"].numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     result = imported.f(constant_op.constant(2), constant_op.constant(5))
     self.assertEqual(7, result[0].a.numpy())
@@ -524,7 +528,7 @@
     train_input = dict(x=constant_op.constant([[1.]]),
                        y=constant_op.constant([[2.]]))
     root.train(**train_input)
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertAllClose(root.optimizer.learning_rate.numpy(),
                         imported.optimizer.learning_rate.numpy())
     self.assertAllClose(root(constant_op.constant([[-0.5]])),
@@ -552,7 +556,7 @@
     self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
     self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
     self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
@@ -572,7 +576,7 @@
     x = constant_op.constant(10)
     self.assertEqual(7, root.f(x, learning_rate=0.5, epochs=3).numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     with self.assertRaisesRegexp(ValueError,
                                  "Could not find matching function to call.*"):
@@ -600,7 +604,7 @@
     self.assertEqual(27, root.f(constant_op.constant(1)).numpy())
     self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
     self.assertEqual(27, imported.f(constant_op.constant(2)).numpy())
@@ -620,7 +624,7 @@
         return x * self.var
 
     m = M()
-    self.cycle(m, cycles)
+    cycle(m, cycles)
     self.assertEqual(4.0, m.f(constant_op.constant(2.0)).numpy())
 
   def test_basic_backprop(self, cycles):
@@ -634,7 +638,7 @@
     root.weight = weight
     root.bias = bias
     root.g = g
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     with backprop.GradientTape() as t:
       x = constant_op.constant([3.5])
       loss = imported.g(x)
@@ -675,7 +679,7 @@
     root.bias = bias
     root.g = h
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     with backprop.GradientTape() as t:
       x = constant_op.constant([3.5])
       loss = imported.g(x)
@@ -696,7 +700,7 @@
     root.m2.__call__ = def_function.function(
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
             lambda x: x*3.0)
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     x = constant_op.constant(1.0)
 
     self.assertTrue(callable(imported.m1))
@@ -720,7 +724,7 @@
     root.__call__.__call__ = tracking.AutoTrackable()
     root.__call__.__call__.__call__ = func
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertTrue(callable(imported))
     x = constant_op.constant(1.0)
     self.assertAllEqual(imported(x).numpy(), 3.0)
@@ -734,7 +738,7 @@
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
 
     if cycles > 1:
-      root = self.cycle(root, cycles - 1)
+      root = cycle(root, cycles - 1)
     path = tempfile.mkdtemp(prefix=self.get_temp_dir())
     save.save(root, path)
 
@@ -764,7 +768,7 @@
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
 
     if cycles > 1:
-      root = self.cycle(root, cycles - 1)
+      root = cycle(root, cycles - 1)
     path = tempfile.mkdtemp(prefix=self.get_temp_dir())
     save.save(root, path)
 
@@ -794,7 +798,7 @@
     concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
     self.assertEqual(1, len(concrete_functions))
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     with self.assertRaisesRegexp(ValueError, "Python inputs incompatible"):
       # We cannot call the function with a constant of shape ().
@@ -824,7 +828,7 @@
     root = tracking.AutoTrackable()
     root.f = func
 
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     concrete = imported.f.get_concrete_function(
         training=True, x=tensor_spec.TensorSpec([None], dtypes.int32))
@@ -852,7 +856,7 @@
     self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())
 
     # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
 
     self.assertAllEqual([2, 4, 6, 8],
                         imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
@@ -874,14 +878,15 @@
 
     root = Root()
     self.assertIn(root.v.handle,
-                  root.use_v.get_concrete_function().graph.captures)
+                  root.use_v.get_concrete_function().graph.external_captures)
     for _ in range(cycles):
-      root = self.cycle(root, 1, signatures=root.use_v.get_concrete_function())
-    func_captures = root.use_v.get_concrete_function().graph.captures
+      root = cycle(root, 1, signatures=root.use_v.get_concrete_function())
+    func_captures = root.use_v.get_concrete_function().graph.external_captures
     self.assertLen(func_captures, 2)
     self.assertIn(root.v.handle, func_captures)
     self.assertIn(root.v1.handle, func_captures)
-    signature_captures = root.signatures["serving_default"].graph.captures
+    signature_captures = root.signatures[
+        "serving_default"].graph.external_captures
     self.assertLen(signature_captures, 2)
     self.assertIn(root.v.handle, signature_captures)
     self.assertIn(root.v1.handle, signature_captures)
@@ -899,7 +904,7 @@
     self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
 
     # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
 
     self.assertAllEqual([2, 4, 6],
                         imported.f(x=constant_op.constant([1, 2, 3])).numpy())
@@ -913,7 +918,7 @@
     root.f = func.get_concrete_function(constant_op.constant([1]))
     self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
     # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
     self.assertAllEqual([6],
                         imported.f(constant_op.constant([3])).numpy())
 
@@ -934,7 +939,7 @@
 
     self.assertEqual(2., _compute_gradient(root.f).numpy())
     # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
     self.assertEqual(2., _compute_gradient(imported.f).numpy())
 
   def test_revived_concrete_function_kwargs(self, cycles):
@@ -949,7 +954,7 @@
     self.assertEqual(8., root.f(y=constant_op.constant(3.),
                                 x=constant_op.constant(2.)).numpy())
     # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
     self.assertEqual(8., imported.f(y=constant_op.constant(3.),
                                     x=constant_op.constant(2.)).numpy())
 
@@ -965,7 +970,7 @@
         tensor_spec.TensorSpec([], dtypes.float32, name="y"))
     self.assertEqual(8., root.f(y=constant_op.constant(3.),
                                 x=constant_op.constant(2.)).numpy())
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
     self.assertEqual(8., imported.f(y=constant_op.constant(3.),
                                     x=constant_op.constant(2.)).numpy())
 
@@ -987,7 +992,7 @@
     root.f(vsave)
     self.assertEqual(2, vsave.numpy())
     self.assertEqual(-1, capture.numpy())
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     vload = variables.Variable(1)
     imported.f(vload)
@@ -1010,7 +1015,7 @@
     one = constant_op.constant(1)
     self.assertEqual(2, root.func(one).numpy())
     self.assertEqual(2, root.concrete_func(one).numpy())
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(2, imported.func(one).numpy())
     self.assertEqual(2, imported.concrete_func(one).numpy())
 
@@ -1022,7 +1027,7 @@
     root.funcs = dict(
         a=def_function.function(lambda: constant_op.constant(100.)))
     root.funcs["conc"] = root.funcs["a"].get_concrete_function()
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(1., imported.variables["a"].numpy())
     self.assertEqual(2., imported.variables["b"].numpy())
     self.assertEqual(set(["a", "b"]), set(imported.variables.keys()))
@@ -1034,7 +1039,7 @@
     root.variables = [variables.Variable(1.)]
     root.variables.append(1)
     root.variables.append(variables.Variable(3.))
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(1., imported.variables[0].numpy())
     self.assertEqual(3., imported.variables[2].numpy())
     self.assertIs(None, imported.variables[1])
@@ -1055,7 +1060,7 @@
 
     root.losses.append(_v2_loss)
     self.assertAllClose([1., 4.], [loss() for loss in root.losses])
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertAllClose([1., 4.], [loss() for loss in imported.losses])
     imported.variables[0].assign(3.)
     imported.variables[1].assign(4.)
@@ -1068,7 +1073,7 @@
     root.g = def_function.function(lambda: const + 2.)
     self.assertAllClose(array_ops.ones([100]), root.f())
     self.assertAllClose(2. * array_ops.ones([100]), root.g())
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertAllClose(array_ops.ones([100]), imported.f())
     self.assertAllClose(2. * array_ops.ones([100]), imported.g())
     # TODO(b/123408994): Use the public get_concrete_function.
@@ -1098,14 +1103,14 @@
         return f
 
     exported = Exported()
-    imported = self.cycle(exported, cycles)
+    imported = cycle(exported, cycles)
     self.assertEqual(0, imported.make_func().numpy())
     self.assertEqual(1, exported.make_func().numpy())
 
   def test_overwritten_signatures_error(self, cycles):
     exported = tracking.AutoTrackable()
     exported.f = def_function.function(lambda: constant_op.constant(1.))
-    imported = self.cycle(
+    imported = cycle(
         exported, cycles,
         signatures={"key": exported.f.get_concrete_function()})
     self.assertEqual(1., imported.signatures["key"]()["output_0"].numpy())
@@ -1125,13 +1130,13 @@
         return self.v * x
 
     exported = Exported()
-    imported = self.cycle(
+    imported = cycle(
         exported,
         cycles=1,
         signatures=exported.do.get_concrete_function(
             tensor_spec.TensorSpec(None, dtypes.float32)))
     for _ in range(cycles - 1):
-      imported = self.cycle(imported, cycles=1, signatures=imported.signatures)
+      imported = cycle(imported, cycles=1, signatures=imported.signatures)
     self.assertEqual(["serving_default"], list(imported.signatures.keys()))
     imported_function = imported.signatures["serving_default"]
     two = constant_op.constant(2.)
@@ -1152,12 +1157,12 @@
         return x + y
 
     exported = Exported()
-    imported = self.cycle(
+    imported = cycle(
         exported, cycles=1, signatures=exported.do.get_concrete_function(
             tensor_spec.TensorSpec(None, dtypes.float32),
             tensor_spec.TensorSpec(None, dtypes.float32)))
     for _ in range(cycles - 1):
-      imported = self.cycle(imported, cycles=1, signatures=imported.signatures)
+      imported = cycle(imported, cycles=1, signatures=imported.signatures)
     with self.assertRaises(TypeError):
       imported.signatures["serving_default"](
           constant_op.constant(1.),
@@ -1193,7 +1198,7 @@
 
   def test_table(self, cycles):
     root = self._make_model_with_tables()
-    imported = self.cycle(root, cycles, signatures={})
+    imported = cycle(root, cycles, signatures={})
     keys = constant_op.constant(["brain", "test", "foo", "surgery"])
     self.assertAllEqual([0, -1, -1, 2], imported.lookup1(keys).numpy())
     self.assertAllEqual([2, 0, 1, -1], imported.lookup2(keys).numpy())
@@ -1212,19 +1217,19 @@
     root = self._make_model_with_tables()
     # Warm up collections to ignore those that don't expand every iteration,
     # e.g. the __varscope collection.
-    self.cycle(root, 1)
+    cycle(root, 1)
     original_collections = _gather_nonempty_collections()
-    self.cycle(root, cycles)
+    cycle(root, cycles)
     self.assertEqual(original_collections, _gather_nonempty_collections())
 
   def test_table_in_graph(self, cycles):
     root = self._make_model_with_tables()
 
     if cycles > 1:
-      root = self.cycle(root, cycles - 1)
+      root = cycle(root, cycles - 1)
     path = tempfile.mkdtemp(prefix=self.get_temp_dir())
     save.save(root, path)
-    imported = self.cycle(root, 1)
+    imported = cycle(root, 1)
 
     with ops.Graph().as_default():
       imported = load.load(path)
@@ -1243,7 +1248,7 @@
 
     root = tracking.AutoTrackable()
     root.f = def_function.function(f)
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
 
     restored_fullargspec = tf_inspect.getfullargspec(imported.f)
     self.assertEqual(original_fullargspec, restored_fullargspec)
@@ -1268,7 +1273,7 @@
 
     root = tracking.AutoTrackable()
     root.f = func
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(), [1.0, 2.0, 3.0, True])
     self.assertAllEqual(root.f(-1.0, training=False), [3.0, 2.0, -1.0, False])
 
@@ -1289,7 +1294,7 @@
 
     root = tracking.AutoTrackable()
     root.f = func
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(2, root.f(2).numpy())
     self.assertAllEqual(4, root.f(3).numpy())
     self.assertAllEqual(3, root.f(constant_op.constant(2)).numpy())
@@ -1306,7 +1311,7 @@
     root.f = func
     self.assertAllEqual(root.f(), [1.0])
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(), [1.0])
 
   def test_partial_with_non_tensor_defaults(self, cycles):
@@ -1320,7 +1325,7 @@
     root.f = func
     self.assertAllEqual(root.f(1), 6)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(1), 6)
 
   def test_partial_with_positional(self, cycles):
@@ -1333,7 +1338,7 @@
     root.f = func
     self.assertAllEqual(root.f(1), 6)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(1), 6)
 
   def test_partial_with_positional_captured_tensors(self, cycles):
@@ -1348,7 +1353,7 @@
     root.f = func
     self.assertAllEqual(root.f(1), 13)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(1), 13)
 
   def test_partial_keyword_hiding_default(self, cycles):
@@ -1366,7 +1371,7 @@
     self.assertEqual(root.f().numpy(), 9)
     self.assertEqual(root.f(training=False).numpy(), 11)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(root.f().numpy(), 9)
     self.assertEqual(root.f(training=False).numpy(), 11)
 
@@ -1385,7 +1390,7 @@
     root.f = func
     self.assertEqual(root.f(constant_op.constant(4)).numpy(), 44)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(root.f(constant_op.constant(5)).numpy(), 45)
 
   def test_partial_bind_only_first_argument(self, cycles):
@@ -1403,7 +1408,7 @@
     root.f = tf_func
     self.assertAllEqual(root.f(y=constant_op.constant(7)), 12)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(root.f(y=constant_op.constant(9)), 14)
 
   def test_partial_with_passed_fn_as_default(self, cycles):
@@ -1420,7 +1425,7 @@
     root.f = func
     self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
 
   def test_partial_with_input_signature(self, cycles):
@@ -1439,7 +1444,7 @@
     a, b, c = root.f(2.0)
     self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 2.0, 4))
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     a, b, c = root.f(3.0)
     self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 3.0, 4))
 
@@ -1453,7 +1458,7 @@
     root = tracking.AutoTrackable()
     root.f = func
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
 
     self.assertEqual([2], root.f([2]).numpy())
 
@@ -1475,7 +1480,7 @@
     if sys.version_info.major == 3 and sys.version_info.minor < 5:
       # TODO(allenl): figure out why this doesn't work in Python3.4
       self.skipTest("Not working in Python 3.4")
-    imported = self.cycle(obj, cycles)
+    imported = cycle(obj, cycles)
     self.assertAllClose(3.,
                         imported(NamedTupleType(a=constant_op.constant(1.),
                                                 b=constant_op.constant(2.))))
@@ -1490,7 +1495,7 @@
 
     obj = tracking.AutoTrackable()
     obj.__call__ = f
-    imported = self.cycle(obj, cycles)
+    imported = cycle(obj, cycles)
 
     self.assertEqual(4.0, imported({"a": 3.0}).numpy())
 
@@ -1510,7 +1515,7 @@
     root = tracking.AutoTrackable()
     root.f = func
 
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
 
     imported_graph = root.f.get_concrete_function().graph
     input_x, input_y = imported_graph.inputs
@@ -1530,7 +1535,7 @@
     v1 = variables.Variable(1.)
     weak_v1 = weakref.ref(v1)
     root = util.Checkpoint(v=v1)
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     del v1
     self.assertIsNone(weak_v1())
     weak_v2 = weakref.ref(root.v)
@@ -1549,7 +1554,7 @@
                      v.aggregation)
     root = tracking.AutoTrackable()
     root.v = v
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(False, root.v.trainable)
     self.assertEqual(variables.VariableSynchronization.NONE,
                      root.v.synchronization)
@@ -1577,62 +1582,18 @@
     self.assertEqual(
         3 * (1 + 4 + 9 + 16),
         root(constant_op.constant(3, dtype=dtypes.int64)).numpy())
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(
         3 * (1 + 4 + 9 + 16),
         root(constant_op.constant(3, dtype=dtypes.int64)).numpy())
 
-  @test_util.run_in_graph_and_eager_modes
-  def test_dense_features_layer(self, cycles):
-    columns = [feature_column_v2.numeric_column("x"),
-               feature_column_v2.numeric_column("y")]
-    layer = feature_column_v2.DenseFeatures(columns)
-    model = sequential.Sequential([layer])
-    model_input = {"x": constant_op.constant([[1.]]),
-                   "y": constant_op.constant([[2.]])}
-    self.assertAllClose([[1., 2.]], model.predict(model_input, steps=1))
-    loaded = self.cycle(model, cycles)
-    output, = loaded._default_save_signature(model_input).values()
-    self.assertAllClose([[1., 2.]], output)
-    signature_output, = loaded.signatures["serving_default"](
-        **model_input).values()
-    self.assertAllClose([[1., 2.]], signature_output)
-
-  def test_dense_features_layer_fit(self, cycles):
-    columns = [feature_column_v2.numeric_column("x")]
-    model = sequential.Sequential(
-        [feature_column_v2.DenseFeatures(columns),
-         core.Dense(1)])
-    model_input = {"x": constant_op.constant([[1.]])}
-    model.compile(optimizer="adam", loss="mse")
-    model.fit(model_input, constant_op.constant([[3.]]))
-    loaded = self.cycle(model, cycles)
-    loaded._default_save_signature(model_input)
-    loaded.signatures["serving_default"](**model_input)
-
-  def test_multi_output_layer(self, cycles):
-
-    inp = input_layer.Input(name="inp", shape=(None,), dtype=dtypes.float32)
-
-    class _MultiOutput(base_layer.Layer):
-
-      def call(self, x):
-        return x + 1., x + 2.
-
-    out = _MultiOutput(name="out")(inp)
-    model = training_lib.Model(inp, out)
-    loaded = self.cycle(model, cycles)
-    self.assertAllClose(
-        dict(out=2., out_1=3.),
-        loaded.signatures["serving_default"](constant_op.constant(1.)))
-
   def test_tuple_signature(self, cycles):
     root = util.Checkpoint()
     root.f = def_function.function(
         lambda: (array_ops.ones([]), array_ops.zeros([])),
         input_signature=())
     for _ in range(cycles):
-      root = self.cycle(root, 1, signatures=root.f)
+      root = cycle(root, 1, signatures=root.f)
     self.assertEqual(({"output_0": 1., "output_1": 0.}),
                      self.evaluate(root.signatures["serving_default"]()))
 
@@ -1646,14 +1607,14 @@
     root.model.traced_call = _use_sequential
 
     original = root.model.traced_call(array_ops.zeros([1, 1])).numpy()
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertAllEqual(
         original,
         root.model.traced_call(array_ops.zeros([1, 1])).numpy())
 
   def test_version_info(self, cycles):
     root = util.Checkpoint()
-    root = self.cycle(root, cycles)
+    root = cycle(root, cycles)
     self.assertEqual(versions.__version__, root.tensorflow_version)
     self.assertEqual(versions.__git_version__, root.tensorflow_git_version)
 
@@ -1669,18 +1630,7 @@
         output = root.g(inp)
         self.assertAllClose(4., output)
       self.assertAllClose(2., tape.gradient(output, inp))
-      root = self.cycle(root, 1)
-
-  def test_functional_model_with_conv(self, cycles):
-    x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
-    conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)
-    model = training_lib.Model([x], conved)
-    model_input = array_ops.ones((1, 10, 10, 3))
-    initial_output = model.predict([model_input])
-    model = self.cycle(model, cycles)
-    self.assertAllClose(
-        [initial_output],
-        list(model.signatures["serving_default"](model_input).values()))
+      root = cycle(root, 1)
 
   def test_destroy_resource(self, cycles):
 
@@ -1728,7 +1678,7 @@
             handle, dtypes.float32)
 
     root = MyModel()
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertEqual(11, imported.increase().numpy())  # Create the resource.
 
     handle = imported.resource.resource_handle
@@ -1756,9 +1706,92 @@
 
     root = module.Module()
     root.f = outer
-    imported = self.cycle(root, cycles)
+    imported = cycle(root, cycles)
     self.assertAllClose(2., imported.f(constant_op.constant(1.)))
 
+  def test_ragged(self, cycles):
+
+    @def_function.function(input_signature=[
+        ragged_tensor.RaggedTensorSpec(shape=[None, None], dtype=dtypes.int32)
+    ])
+    def f(x):
+      return x + 1
+
+    obj = tracking.AutoTrackable()
+    obj.f = f
+
+    imported1 = cycle(obj, cycles, signatures={})
+    rt = ragged_factory_ops.constant([[1, 2], [3]])
+    self.assertAllEqual(imported1.f(rt), [[2, 3], [4]])
+
+    imported2 = cycle(obj, cycles)
+    rt = ragged_factory_ops.constant([[1, 2], [3]])
+    self.assertAllEqual(imported2.f(rt), [[2, 3], [4]])
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+@parameterized.named_parameters(
+    dict(testcase_name="ReloadOnce", cycles=1),
+    dict(testcase_name="ReloadTwice", cycles=2),
+    dict(testcase_name="ReloadThrice", cycles=3))
+class KerasLoadTest(test.TestCase, parameterized.TestCase):
+
+  def test_dense_features_layer(self, cycles):
+    columns = [
+        feature_column_lib.numeric_column("x"),
+        feature_column_lib.numeric_column("y")
+    ]
+    layer = feature_column_lib.DenseFeatures(columns)
+    model = sequential.Sequential([layer])
+    model_input = {"x": constant_op.constant([[1.]]),
+                   "y": constant_op.constant([[2.]])}
+    self.assertAllClose([[1., 2.]], model.predict(model_input, steps=1))
+    loaded = cycle(model, cycles)
+    output, = loaded._default_save_signature(model_input).values()
+    self.assertAllClose([[1., 2.]], output)
+    signature_output, = loaded.signatures["serving_default"](
+        **model_input).values()
+    self.assertAllClose([[1., 2.]], signature_output)
+
+  def test_dense_features_layer_fit(self, cycles):
+    columns = [feature_column_lib.numeric_column("x")]
+    model = sequential.Sequential(
+        [feature_column_lib.DenseFeatures(columns),
+         core.Dense(1)])
+    model_input = {"x": constant_op.constant([[1.]])}
+    model.compile(optimizer="adam", loss="mse", run_eagerly=True,
+                  experimental_run_tf_function=True)
+    model.fit(model_input, constant_op.constant([[3.]]))
+    loaded = cycle(model, cycles)
+    loaded._default_save_signature(model_input)
+    loaded.signatures["serving_default"](**model_input)
+
+  def test_multi_output_layer(self, cycles):
+
+    inp = input_layer.Input(name="inp", shape=(None,), dtype=dtypes.float32)
+
+    class _MultiOutput(base_layer.Layer):
+
+      def call(self, x):
+        return x + 1., x + 2.
+
+    out = _MultiOutput(name="out")(inp)
+    model = training_lib.Model(inp, out)
+    loaded = cycle(model, cycles)
+    self.assertAllClose(
+        dict(out=2., out_1=3.),
+        loaded.signatures["serving_default"](constant_op.constant(1.)))
+
+  def test_functional_model_with_conv(self, cycles):
+    x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
+    conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)
+    model = training_lib.Model([x], conved)
+    model_input = array_ops.ones((1, 10, 10, 3))
+    initial_output = model.predict([model_input])
+    model = cycle(model, cycles)
+    self.assertAllClose(
+        [initial_output],
+        list(model.signatures["serving_default"](model_input).values()))
+
 
 class SingleCycleTests(test.TestCase, parameterized.TestCase):
 
diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py
index c14af7b7..ac40a21 100644
--- a/tensorflow/python/saved_model/load_v1_in_v2.py
+++ b/tensorflow/python/saved_model/load_v1_in_v2.py
@@ -174,8 +174,7 @@
     # we don't have duplicates or name collisions.
     meta_graph_def.graph_def.library.Clear()
     for function in functions.values():
-      meta_graph_def.graph_def.library.function.append(
-          function._inference_function.definition)  # pylint: disable=protected-access
+      meta_graph_def.graph_def.library.function.append(function.function_def)
     # We've renamed functions and shared names. We need the same operation on
     # the GraphDef itself for consistency.
     for node_def in meta_graph_def.graph_def.node:
diff --git a/tensorflow/python/saved_model/model_utils/mode_keys.py b/tensorflow/python/saved_model/model_utils/mode_keys.py
index 2912de7..6f7a787 100644
--- a/tensorflow/python/saved_model/model_utils/mode_keys.py
+++ b/tensorflow/python/saved_model/model_utils/mode_keys.py
@@ -19,7 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
+from tensorflow.python.util.compat import collections_abc
 
 
 class KerasModeKeys(object):
@@ -65,7 +65,7 @@
   return mode in [KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN]
 
 
-class ModeKeyMap(collections.Mapping):
+class ModeKeyMap(collections_abc.Mapping):
   """Map using ModeKeys as keys.
 
   This class creates an immutable mapping from modes to values. For example,
diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py
index ae6c737..3144bbd 100644
--- a/tensorflow/python/saved_model/nested_structure_coder.py
+++ b/tensorflow/python/saved_model/nested_structure_coder.py
@@ -47,6 +47,7 @@
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import compat
+from tensorflow.python.util.compat import collections_abc
 
 
 class NotEncodableError(Exception):
@@ -161,7 +162,7 @@
   if not isinstance(instance, tuple):
     return False
   return (hasattr(instance, "_fields") and
-          isinstance(instance._fields, collections.Sequence) and
+          isinstance(instance._fields, collections_abc.Sequence) and
           all(isinstance(f, six.string_types) for f in instance._fields))
 
 
@@ -483,18 +484,24 @@
     encoded_type_spec = struct_pb2.StructuredValue()
     encoded_type_spec.type_spec_value.CopyFrom(
         struct_pb2.TypeSpecProto(
-            type_spec_class=type_spec_class, type_state=encode_fn(type_state)))
+            type_spec_class=type_spec_class,
+            type_state=encode_fn(type_state),
+            type_spec_class_name=type(type_spec_value).__name__))
     return encoded_type_spec
 
   def can_decode(self, value):
-    return (
-        value.HasField("type_spec_value") and
-        value.type_spec_value.type_spec_class in self.TYPE_SPEC_CLASS_FROM_PROTO
-    )
+    return value.HasField("type_spec_value")
 
   def do_decode(self, value, decode_fn):
+    """Returns the `tf.TypeSpec` encoded by the proto `value`."""
     type_spec_proto = value.type_spec_value
     type_spec_class_enum = type_spec_proto.type_spec_class
+    if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
+      raise ValueError(
+          "The type '%s' is not supported by this version of TensorFlow. "
+          "(The object you are loading must have been created with a newer "
+          "version of TensorFlow.)" % type_spec_proto.type_spec_class_name)
+
     type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]
     # pylint: disable=protected-access
     return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py
index 41d61d8..23c305d 100644
--- a/tensorflow/python/saved_model/nested_structure_coder_test.py
+++ b/tensorflow/python/saved_model/nested_structure_coder_test.py
@@ -201,27 +201,28 @@
         values {
           type_spec_value {
             type_spec_class: RAGGED_TENSOR_SPEC
-              type_state {
-                tuple_value {
-                  # spec._shape
-                  values {
-                    tensor_shape_value {
-                      dim { size: 1 }
-                      dim { size: 2 }
-                      dim { size: 3 }
-                    }
+            type_spec_class_name: 'RaggedTensorSpec'
+            type_state {
+              tuple_value {
+                # spec._shape
+                values {
+                  tensor_shape_value {
+                    dim { size: 1 }
+                    dim { size: 2 }
+                    dim { size: 3 }
                   }
-                  # spec._dtype
-                  values { tensor_dtype_value: DT_INT64 }
-                  # spec._ragged_rank
-                  values { int64_value: 2 }
-                  # spec._row_splits_dtype
-                  values { tensor_dtype_value: DT_INT32 }
                 }
+                # spec._dtype
+                values { tensor_dtype_value: DT_INT64 }
+                # spec._ragged_rank
+                values { int64_value: 2 }
+                # spec._row_splits_dtype
+                values { tensor_dtype_value: DT_INT32 }
               }
             }
           }
         }
+      }
     """
     expected = struct_pb2.StructuredValue()
     text_format.Parse(expected_pbtxt, expected)
@@ -238,22 +239,23 @@
         values {
           type_spec_value {
             type_spec_class: SPARSE_TENSOR_SPEC
-              type_state {
-                tuple_value {
-                  # spec._shape
-                  values {
-                    tensor_shape_value {
-                      dim { size: 10 }
-                      dim { size: 20 }
-                    }
+            type_spec_class_name: 'SparseTensorSpec'
+            type_state {
+              tuple_value {
+                # spec._shape
+                values {
+                  tensor_shape_value {
+                    dim { size: 10 }
+                    dim { size: 20 }
                   }
-                  # spec._dtype
-                  values { tensor_dtype_value: DT_FLOAT }
                 }
+                # spec._dtype
+                values { tensor_dtype_value: DT_FLOAT }
               }
             }
           }
         }
+      }
     """
     expected = struct_pb2.StructuredValue()
     text_format.Parse(expected_pbtxt, expected)
@@ -261,6 +263,14 @@
     decoded = self._coder.decode_proto(encoded)
     self.assertEqual(structure, decoded)
 
+  def testDecodeUnknownTensorSpec(self):
+    encoded = struct_pb2.StructuredValue()
+    encoded.type_spec_value.type_spec_class = 0
+    encoded.type_spec_value.type_spec_class_name = "FutureTensorSpec"
+    with self.assertRaisesRegexp(
+        ValueError, "The type 'FutureTensorSpec' is not supported"):
+      self._coder.decode_proto(encoded)
+
   def testEncodeDataSetSpec(self):
     structure = [dataset_ops.DatasetSpec(
         {"rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index f357ed0..726180b 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -25,6 +25,7 @@
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saved_model_pb2
 from tensorflow.core.protobuf import saved_object_graph_pb2
+from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
@@ -240,6 +241,7 @@
         asset_initializers_by_resource={},
         asset_filename_map={},
         asset_index={})
+
     for node_id, obj in enumerate(self.nodes):
       if isinstance(obj, tracking.CapturableResource):
         # pylint: disable=protected-access
@@ -248,6 +250,20 @@
         # pylint: enable=protected-access
         resource_map[obj.resource_handle] = new_resource
         self.captured_tensor_node_ids[obj.resource_handle] = node_id
+      elif ds_values.is_distributed_variable(obj):
+        # Put both the distributed variable and component variable handles in
+        # `captured_tensor_node_ids`.
+        # Also create a new distributed variable for `object_map` with newly
+        # created component variables.
+        new_vars = []
+        for v in obj.values:
+          new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
+          object_map[v] = new_variable
+          new_vars.append(new_variable)
+          resource_map[v.handle] = new_variable.handle
+          self.captured_tensor_node_ids[v.handle] = node_id
+        object_map[obj] = obj._clone_with_new_values(new_vars)  # pylint: disable=protected-access
+        self.captured_tensor_node_ids[obj] = node_id
       elif resource_variable_ops.is_resource_variable(obj):
         new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
         object_map[obj] = new_variable
@@ -258,6 +274,11 @@
         self.captured_tensor_node_ids[obj.asset_path] = node_id
 
     for concrete_function in self.concrete_functions:
+      if not concrete_function.graph.saveable:
+        raise ValueError(
+            ("Unable to save function {name} for the following reason(s):\n" +
+             "\n".join(concrete_function.graph.saving_errors))
+            .format(name=concrete_function.name))
       for capture in concrete_function.captured_inputs:
         if (tensor_util.is_tensor(capture)
             and capture.dtype not in _UNCOPIABLE_DTYPES
@@ -306,7 +327,7 @@
       `resource_map`.
   """
   export_captures = []
-  for exterior, interior in original_captures.items():
+  for exterior, interior in original_captures:
     mapped_resource = resource_map.get(exterior, None)
     if mapped_resource is None:
       raise AssertionError(
@@ -393,13 +414,12 @@
   """Calls `function` in the exported graph, using mapped resource captures."""
   export_captures = _map_captures_to_created_tensors(
       function.graph.captures, resource_map)
-  mapped_inputs = args + export_captures
   # Calls the function quite directly, since we have new captured resource
   # tensors we need to feed in which weren't part of the original function
   # definition.
   # pylint: disable=protected-access
-  outputs = function._build_call_outputs(
-      function._inference_function.call(context.context(), mapped_inputs))
+  outputs = function._call_flat(args, export_captures)
+  # pylint: enable=protected-access
   return outputs
 
 
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 566c508..a5200b0 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -141,6 +141,22 @@
       save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
                 signatures=root.f)
 
+  def test_unsaveable_func_graph(self):
+    root = module.Module()
+
+    @def_function.function(input_signature=[])
+    def nested_f():
+      ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
+      return 1
+
+    @def_function.function(input_signature=[])
+    def f():
+      return nested_f()
+
+    root.f = f
+    with self.assertRaisesRegexp(ValueError, "ERROR MSG"):
+      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
+
   def test_version_information_included(self):
     root = tracking.AutoTrackable()
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index e36b8b3..7722cd3 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import constants
@@ -43,6 +44,7 @@
 from tensorflow.python.saved_model import main_op
 from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
 from tensorflow.python.training import saver_test_utils
 from tensorflow.python.training import training
 from tensorflow.python.util import compat
@@ -643,6 +645,19 @@
     self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
 
   @test_util.run_deprecated_v1
+  def testSignatureDefValidationSucceedsWithRagged(self):
+    ragged_tensor = ragged_factory_ops.constant([[1, 2], [3]])
+    tensor_with_ragged = utils.build_tensor_info(ragged_tensor)
+
+    export_dir = self._get_export_dir("test_signature_def_validation_ragged_1")
+    builder = saved_model_builder._SavedModelBuilder(export_dir)
+    self._validate_inputs_tensor_info_accept(builder, tensor_with_ragged)
+
+    export_dir = self._get_export_dir("test_signature_def_validation_ragged_2")
+    builder = saved_model_builder._SavedModelBuilder(export_dir)
+    self._validate_outputs_tensor_info_accept(builder, tensor_with_ragged)
+
+  @test_util.run_deprecated_v1
   def testAssets(self):
     export_dir = self._get_export_dir("test_assets")
     builder = saved_model_builder._SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py
index 0e969e1..3f3725f 100644
--- a/tensorflow/python/saved_model/signature_serialization.py
+++ b/tensorflow/python/saved_model/signature_serialization.py
@@ -18,8 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import ops
@@ -29,6 +27,7 @@
 from tensorflow.python.training.tracking import base
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 
 DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
@@ -87,7 +86,7 @@
   """Converts `signatures` into a dictionary of concrete functions."""
   if signatures is None:
     return {}
-  if not isinstance(signatures, collections.Mapping):
+  if not isinstance(signatures, collections_abc.Mapping):
     signatures = {
         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
   concrete_signatures = {}
@@ -146,7 +145,7 @@
 
 def _normalize_outputs(outputs, function_name, signature_key):
   """Construct an output dictionary from unnormalized function outputs."""
-  if isinstance(outputs, collections.Mapping):
+  if isinstance(outputs, collections_abc.Mapping):
     for key, value in outputs.items():
       if not isinstance(value, ops.Tensor):
         raise ValueError(
@@ -158,7 +157,7 @@
     return outputs
   else:
     original_outputs = outputs
-    if not isinstance(outputs, collections.Sequence):
+    if not isinstance(outputs, collections_abc.Sequence):
       outputs = [outputs]
     if not _is_flat(outputs):
       raise ValueError(
@@ -180,7 +179,7 @@
 # saved if they contain a _SignatureMap. A ".signatures" attribute containing
 # any other type (e.g. a regular dict) will raise an exception asking the user
 # to first "del obj.signatures" if they want it overwritten.
-class _SignatureMap(collections.Mapping, base.Trackable):
+class _SignatureMap(collections_abc.Mapping, base.Trackable):
   """A collection of SavedModel signatures."""
 
   def __init__(self):
@@ -234,7 +233,7 @@
     # be more problematic in case future export changes violated these
     # assertions.
     assert isinstance(func, defun.ConcreteFunction)
-    assert isinstance(func.structured_outputs, collections.Mapping)
+    assert isinstance(func.structured_outputs, collections_abc.Mapping)
     # pylint: disable=protected-access
     if len(func._arg_keywords) == 1:
       assert 1 == func._num_positional_args
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 2e7b208..3dd7d6c 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -22,15 +22,19 @@
 
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import struct_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import nested_structure_coder
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -65,6 +69,10 @@
 
 def build_tensor_info_internal(tensor):
   """Utility function to build TensorInfo proto from a Tensor."""
+  if (isinstance(tensor, composite_tensor.CompositeTensor) and
+      not isinstance(tensor, sparse_tensor.SparseTensor)):
+    return _build_composite_tensor_info_internal(tensor)
+
   tensor_info = meta_graph_pb2.TensorInfo(
       dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
       tensor_shape=tensor.get_shape().as_proto())
@@ -77,6 +85,19 @@
   return tensor_info
 
 
+def _build_composite_tensor_info_internal(tensor):
+  """Utility function to build TensorInfo proto from a CompositeTensor."""
+  spec = tensor._type_spec  # pylint: disable=protected-access
+  tensor_info = meta_graph_pb2.TensorInfo()
+  struct_coder = nested_structure_coder.StructureCoder()
+  spec_proto = struct_coder.encode_structure(spec)
+  tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
+  for component in nest.flatten(tensor, expand_composites=True):
+    tensor_info.composite_tensor.components.add().CopyFrom(
+        build_tensor_info_internal(component))
+  return tensor_info
+
+
 def build_tensor_info_from_op(op):
   """Utility function to build TensorInfo proto from an Op.
 
@@ -120,17 +141,19 @@
     "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
     "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
 def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
-  """Returns the Tensor or SparseTensor described by a TensorInfo proto.
+  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.
 
   Args:
-    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
+    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
+      CompositeTensor.
     graph: The tf.Graph in which tensors are looked up. If None, the
         current default graph is used.
     import_scope: If not None, names in `tensor_info` are prefixed with this
         string before lookup.
 
   Returns:
-    The Tensor or SparseTensor in `graph` described by `tensor_info`.
+    The Tensor or SparseTensor or CompositeTensor in `graph` described by
+    `tensor_info`.
 
   Raises:
     KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
@@ -148,6 +171,14 @@
         _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
         _get_tensor(tensor_info.coo_sparse.values_tensor_name),
         _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
+  elif encoding == "composite_tensor":
+    struct_coder = nested_structure_coder.StructureCoder()
+    spec_proto = struct_pb2.StructuredValue(
+        type_spec_value=tensor_info.composite_tensor.type_spec)
+    spec = struct_coder.decode_proto(spec_proto)
+    components = [_get_tensor(component.name) for component in
+                  tensor_info.composite_tensor.components]
+    return spec.from_components(components)
   else:
     raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
 
diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py
index 1e12de9..d176b91 100644
--- a/tensorflow/python/saved_model/utils_test.py
+++ b/tensorflow/python/saved_model/utils_test.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import struct_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
@@ -28,7 +29,9 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import nested_structure_coder
 from tensorflow.python.saved_model import utils
 
 
@@ -82,6 +85,26 @@
     self.assertEqual(42, x_tensor_info.tensor_shape.dim[0].size)
     self.assertEqual(69, x_tensor_info.tensor_shape.dim[1].size)
 
+  @test_util.run_v1_only("b/120545219")
+  def testBuildTensorInfoRagged(self):
+    x = ragged_factory_ops.constant([[1, 2], [3]])
+    x_tensor_info = utils.build_tensor_info(x)
+    # Check components
+    self.assertEqual(x.values.name,
+                     x_tensor_info.composite_tensor.components[0].name)
+    self.assertEqual(types_pb2.DT_INT32,
+                     x_tensor_info.composite_tensor.components[0].dtype)
+    self.assertEqual(x.row_splits.name,
+                     x_tensor_info.composite_tensor.components[1].name)
+    self.assertEqual(types_pb2.DT_INT64,
+                     x_tensor_info.composite_tensor.components[1].dtype)
+    # Check type_spec.
+    struct_coder = nested_structure_coder.StructureCoder()
+    spec_proto = struct_pb2.StructuredValue(
+        type_spec_value=x_tensor_info.composite_tensor.type_spec)
+    spec = struct_coder.decode_proto(spec_proto)
+    self.assertEqual(spec, x._type_spec)
+
   def testBuildTensorInfoEager(self):
     x = constant_op.constant(1, name="x")
     with context.eager_mode(), self.assertRaisesRegexp(
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 82e30be..deb43dd 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -57,3 +57,5 @@
 %include "tensorflow/python/grappler/cost_analyzer.i"
 %include "tensorflow/python/grappler/graph_analyzer.i"
 %include "tensorflow/python/grappler/model_analyzer.i"
+
+%include "tensorflow/python/util/traceme.i"
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index eaaef4e..741c46f 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -101,6 +101,7 @@
     "keras/metrics/__init__.py",
     "keras/mixed_precision/__init__.py",
     "keras/mixed_precision/experimental/__init__.py",
+    "keras/premade/__init__.py",
     "keras/models/__init__.py",
     "keras/optimizers/__init__.py",
     "keras/optimizers/schedules/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index b60a729..94d72c2 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -131,6 +131,7 @@
     "keras/models/__init__.py",
     "keras/optimizers/__init__.py",
     "keras/optimizers/schedules/__init__.py",
+    "keras/premade/__init__.py",
     "keras/preprocessing/__init__.py",
     "keras/preprocessing/image/__init__.py",
     "keras/preprocessing/sequence/__init__.py",
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index cdef42e..abc2f95 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -551,7 +551,7 @@
   input_examples = preprocess_input_examples_arg_string(input_examples_str)
 
   for input_tensor_key, (filename, variable_name) in inputs.items():
-    data = np.load(file_io.FileIO(filename, mode='rb'))
+    data = np.load(file_io.FileIO(filename, mode='rb'), allow_pickle=True)
 
     # When a variable_name key is specified for the input file
     if variable_name:
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index f653d10..703dc7c 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -8,6 +8,7 @@
     "tf_py_test",
 )
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
 
 # Do not add anymore paths here. You do not need to be in the visibility list
 # to use TPU symbols. They are accessible from tf.contrib.tpu in TF 1.x and
@@ -200,6 +201,7 @@
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/compiler/xla",
         "//tensorflow/python/ops/losses",
+        "//tensorflow/python/tpu:tensor_tracer_proto_py",
         "//tensorflow/python/tpu/profiler",
     ],
 )
@@ -400,3 +402,12 @@
     ],
     main = "feature_column_v2_test.py",
 )
+
+tf_proto_library(
+    name = "tensor_tracer_proto",
+    srcs = ["tensor_tracer.proto"],
+    protodeps = [
+        "//tensorflow/core:protos_all_proto",
+    ],
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/python/tpu/feature_column_v2_test.py b/tensorflow/python/tpu/feature_column_v2_test.py
index b879753..f62ba1b 100644
--- a/tensorflow/python/tpu/feature_column_v2_test.py
+++ b/tensorflow/python/tpu/feature_column_v2_test.py
@@ -93,7 +93,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
@@ -249,7 +249,7 @@
         (7., 11.)  # id 2
     )
 
-    def _initializer(shape, dtype, partition_info):
+    def _initializer(shape, dtype, partition_info=None):
       self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
       self.assertEqual(dtypes.float32, dtype)
       self.assertIsNone(partition_info)
diff --git a/tensorflow/python/tpu/tensor_tracer.proto b/tensorflow/python/tpu/tensor_tracer.proto
new file mode 100644
index 0000000..ad5392d
--- /dev/null
+++ b/tensorflow/python/tpu/tensor_tracer.proto
@@ -0,0 +1,74 @@
+syntax = "proto3";
+
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+
+// Tensor Tracer Report proto gives information about the trace including:
+// - TensorTracerConfig: version, device, num replicas, trace mode.
+// - Graphdef, e.g., list of operations, tensors
+// - TracedTensorDef:
+//    * Name of the tensor
+//    * Tracepoint name if provided.
+//    * Index of the tensor in the compact cache if traced.
+//    * Explanation for why the tensor is traced or not.
+message TensorTracerReport {
+  TensorTracerConfig config = 1;
+
+  // Tensorflow graph.
+  tensorflow.GraphDef graphdef = 2;
+
+  // A map from tensor name to its TracedTensorDef.
+  map<string, TracedTensorDef> tensordef = 3;
+
+  message TensorTracerConfig {
+    // Tensor tracer version, e.g. hostcall, outside compilation.
+    string version = 1;
+    // Traced device, CPU, TPU...
+    string device = 2;
+
+    // Trace mode, norm, summary, full-trace.
+    string trace_mode = 3;
+
+    // Number of cores, e.g. TPU cores, in the system.
+    int32 num_cores = 4;
+
+    // Number of hosts, e.g. compute nodes in the system.
+    int32 num_hosts = 5;
+
+    // Keep submode as string for backward compatibility.
+    string submode = 6;
+
+    // Keep num cores per host for backward compatibility.
+    int32 num_cores_per_host = 7;
+
+    // Id of the included cores, if a subset of cores are traced.
+    repeated int32 included_cores = 8;
+
+    // The names of the signatures corresponding to the cache indices.
+    repeated string signatures = 9;
+  }
+
+  message TracedTensorDef {
+    // Name of the tensor as appears in tf graph.
+    string name = 1;
+    // Cache index of the tensor. This may be different than topological index.
+    int32 cache_index = 2;
+    // If trace points are provided, corresponding tracepoint name of the
+    // tensor. Trace points are placed on the edges (tensors) in the tensorflow
+    // graph, and they force tensor tracer to trace the corresponding tensor.
+    // Tracepoints can be added using the programatic interface
+    // tensor_tracer.tensor_tracepoint(tensor, trace_point_name) function.
+    // This will add a trace point with the given trace_point_name for the given
+    // tensor. If a trace_point is provided for the tensor,
+    // trace_point name will be used for the rest of the analysis instead of
+    // tensor names. One can use trace_point_name's to compare two models with
+    // arbitrary tensor names by providing the same trace point name for the
+    // tensors that are comparable.
+    string trace_point_name = 3;
+    // Whether the tensor is traced or not.
+    bool is_traced = 4;
+    // Detailed explanation why the tensor is traced or not.
+    string explanation = 5;
+  }
+}
diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py
index ea4ce94..0b3dc66 100644
--- a/tensorflow/python/tpu/tensor_tracer.py
+++ b/tensorflow/python/tpu/tensor_tracer.py
@@ -22,6 +22,9 @@
 import os.path
 import sys
 
+import numpy as np
+
+from tensorflow.core.framework import summary_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import graph_io
@@ -35,14 +38,18 @@
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2 as summary
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary_iterator
 from tensorflow.python.tpu import tensor_tracer_flags
 from tensorflow.python.tpu import tensor_tracer_report
 from tensorflow.python.tpu import tpu
 from tensorflow.python.tpu.ops import tpu_ops
+from tensorflow.python.training import training_util
 
 _DEVICE_TYPE_TPU = 'tpu'
 _DEVICE_TYPE_CPU = 'cpu'
@@ -70,9 +77,58 @@
 _COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
 _COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
 _TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
-_TENSOR_VALUES_CACHE = 'tensor_values_cache'
+_TT_SNAPSHOT = 'tensor_tracer_snapshot'
 _REPLICA_ID_TAG = '#replica-id: '
 
+_TT_SUMMARY_NORM = 'tensor_tracer_norm'
+_TT_SUMMARY_MAX = 'tensor_tracer_max'
+_TT_SUMMARY_MIN = 'tensor_tracer_min'
+_TT_SUMMARY_MEAN = 'tensor_tracer_mean'
+_TT_SUMMARY_VAR = 'tensor_tracer_var'
+_TT_SUMMARY_SIZE = 'tensor_tracer_size'
+
+_TT_SUMMARY_TAG = 'tensor_tracer_summary'
+_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
+_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
+_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'
+
+_TT_SUMMARY_MAX_QUEUE = 100
+
+
+def read_tensor_tracer_event_file(event_file):
+  """Reads the event file written by tensor tracer.
+
+  Args:
+    event_file: Path to the event file that contains only tensor tracer events.
+  Returns:
+    An event dictionary in the form of
+    {step_number: {tensor_name: tensor_content}}
+  Raises:
+    ValueError: If an unexpected trace is found.
+  """
+  event_dict = {}
+  for trace_event in summary_iterator.summary_iterator(event_file):
+    # First event is an event with file_version: "brain.Event:2"
+    if not trace_event.HasField('summary'):
+      continue
+    step = trace_event.step
+    if step not in event_dict:
+      event_dict[step] = {}
+
+    if len(trace_event.summary.value) != 1:
+      raise ValueError('Single step contains %d summary values,'
+                       ' expected 1.' % len(trace_event.summary.value))
+    tensor_value = trace_event.summary.value[0]
+    tensor_name = tensor_value.tag
+
+    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
+    tensor_content = np.frombuffer(
+        tensor_value.tensor.tensor_content,
+        dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()
+        ).reshape(real_shape)
+    event_dict[step][tensor_name] = tensor_content
+  return event_dict
+
 
 def tensor_tracepoint(tensor, checkpoint_name):
   """Adds a checkpoint with the given checkpoint name for the given tensor.
@@ -144,36 +200,6 @@
   return True
 
 
-def _get_tensor_values_cache(graph=None):
-  """Returns the variable that implements tensor-value caching."""
-
-  graph = graph or ops.get_default_graph()
-  collection = graph.get_collection(_TENSOR_TRACER_STORAGE)
-  if len(collection) == 1:
-    return collection[0]
-  elif not collection:
-    raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE)
-  else:
-    raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE)
-  return None
-
-
-def _create_tensor_values_cache(graph, num_tensors):
-  """Creates a variable as the cache to store intermediate tensor values."""
-  graph = graph or ops.get_default_graph()
-  # Create in proper graph and base name_scope.
-  with graph.as_default() as g, g.name_scope(None):
-    return variable_scope.get_variable(
-        _TENSOR_VALUES_CACHE,
-        shape=[num_tensors],
-        dtype=dtypes.float32,
-        initializer=init_ops.constant_initializer(
-            _COMPACT_TRACE_ENTRY_INIT_VALUE),
-        trainable=False,
-        use_resource=True,
-        collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
-
-
 class TensorTracer(object):
   """A software construct for tracing tensor values in a TF graph on TPU.
 
@@ -202,10 +228,26 @@
   def check_device_type(device_type):
     """Checks if the given device type is valid."""
 
-    if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]:
+    if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
       raise ValueError('Invalid device_type "%s"'%device_type)
 
   @staticmethod
+  def check_trace_mode(device_type, trace_mode):
+    """Checks if the given trace mode work on the given device type.
+
+    Args:
+      device_type: Device type, TPU, GPU, CPU.
+      trace_mode: Tensor tracer trace mode.
+    Raises:
+      ValueError: If the given trace mode is not supported for the device.
+    """
+    if trace_mode in (tensor_tracer_flags.TRACE_MODE_SUMMARY,
+                      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
+      if device_type != _DEVICE_TYPE_TPU:
+        raise ValueError('Device_type "%s" is not yet supported for '
+                         'trace mode "%s"' % (device_type, trace_mode))
+
+  @staticmethod
   def loop_cond_op(op):
     return op.type in ('LoopCond', 'RefLoopCond')
 
@@ -236,7 +278,7 @@
       return True
     # Reasons for not including following op types:
     #    Assign: cause incorrect result with CPU tracing.
-    if op.type in ['Assign']:
+    if op.type == 'Assign':
       return True
     return False
 
@@ -253,17 +295,17 @@
     """Return true if scalar output tensor from Op is not safe to be traced."""
 
     # Tracing the following causes cycle in the graph on TPU.
-    if op.type in ['LoopCond', 'Enter', 'Merge', 'Const',
-                   'Switch', 'Less', 'ReadVariableOp']:
+    if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
+                   'Switch', 'Less', 'ReadVariableOp'):
       return True
     # Tracing the following will cause casting-issue
     # with the norm tracing mode or other compilation issues on CPU.
-    if op.type in ['VarHandleOp', 'IteratorToStringHandle',
+    if op.type in ('VarHandleOp', 'IteratorToStringHandle',
                    'IteratorGetNext', 'OneShotIterator',
                    'IteratorV2', 'MakeIterator',
                    'BatchDatasetV2', 'MapDataset',
                    'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
-                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']:
+                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
       return True
     return False
 
@@ -273,7 +315,7 @@
     if self._parameters.include_less_interesting_ops:
       return False
     # Following ops are highly unlikey to cause bugs.
-    return op.type in ['Const', 'Identity', 'Cast', 'Shape']
+    return op.type in ('Const', 'Identity', 'Cast', 'Shape')
 
   @staticmethod
   def reason(op_idx, details):
@@ -290,6 +332,50 @@
     self._tt_config = tensor_tracer_report.TensorTracerConfig()
     self._parameters = tensor_tracer_flags.TTParameters()
     self._included_op_full_names = set()
+    self._host_call_fn = {}
+    self._cache_variables = {}
+
+  def _get_all_cache_variables(self):
+    return self._cache_variables
+
+  def _create_or_get_tensor_values_cache(self, cache_name, graph=None,
+                                         shape=None, dtype=dtypes.float32,
+                                         num_signatures=None):
+    """Creates a variable as the cache to store intermediate tensor values.
+
+    Args:
+      cache_name: Name to be given to the cache (an instance of tf.variable).
+      graph: Tensorflow graph.
+      shape: A list of dimensions.
+      dtype: Data type of created cache
+    Returns:
+      A ref to newly created or existing cache with the given dimensions.
+    Raises:
+      ValueError: If missing a parameter to create the cache.
+    """
+    def _escape_namescopes(variable_name):
+      # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor"
+      # and "foo_bar/mytensor".
+      return variable_name.replace('/', '_').replace(':', '_')
+
+    if cache_name not in self._cache_variables:
+      if graph is None:
+        raise ValueError('Graph must be provided at cache creation.')
+      if shape is None:
+        raise ValueError('shape must be provided at cache creation.')
+      graph = graph or ops.get_default_graph()
+
+      # Create in proper graph and base name_scope.
+      with graph.as_default() as g, g.name_scope(None):
+        self._cache_variables[cache_name] = variable_scope.get_variable(
+            _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name),
+            shape=shape, dtype=dtype,
+            initializer=init_ops.constant_initializer(
+                _COMPACT_TRACE_ENTRY_INIT_VALUE),
+            trainable=False,
+            use_resource=True,
+            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
+    return self._cache_variables[cache_name]
 
   def _add_replica_id_to_graph(self):
     """Adds nodes for computing the replica ID to the graph."""
@@ -368,26 +454,78 @@
         return True
     return False
 
+  def _signature_types(self):
+    """Returns a dictionary holding the order of signatures in the cache for the selected trace mode."""
+    if self._parameters.trace_mode in set([
+        tensor_tracer_flags.TRACE_MODE_NAN_INF,
+        tensor_tracer_flags.TRACE_MODE_NORM,
+        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
+      return {self._parameters.trace_mode: 0}
+    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
+      return {_TT_SUMMARY_NORM: 0, _TT_SUMMARY_MAX: 1, _TT_SUMMARY_MIN: 2,
+              _TT_SUMMARY_MEAN: 3, _TT_SUMMARY_VAR: 4, _TT_SUMMARY_SIZE: 5}
+    return {}
+
+  def _num_signature_dimensions(self):
+    return len(self._signature_types())
+
   def _use_tensor_values_cache(self):
     """Returns True if immediate tensors should be first saved to a cache."""
 
     if self._parameters.trace_mode not in set([
         tensor_tracer_flags.TRACE_MODE_NAN_INF,
         tensor_tracer_flags.TRACE_MODE_NORM,
-        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
+        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
+        tensor_tracer_flags.TRACE_MODE_SUMMARY
+    ]):
       return False
     if (self._parameters.trace_dir and
         _trace_files_need_precreated(self._parameters.trace_dir)):
       return True
     return self._parameters.use_compact_trace
 
-  def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates):
-    """Returns an Op that will save the given updates to an entry in the cache."""
+  def _use_tensor_buffer(self):
+    """Returns true if the whole tensor needs to be cached/buffered in memory."""
+    return (self._parameters.trace_mode ==
+            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
 
-    cache = _get_tensor_values_cache(graph)
+  def _save_tensor_value_to_cache_op(self, cache_idx, updates):
+    """Returns an op that will save the given updates to an entry in the cache.
+
+    Args:
+      cache_idx: The cache index of the tensor within the cache.
+      updates: A dictionary of the signature updates.
+    Returns:
+      Cache update operation.
+    """
+    # state_ops.scatter_update allows updates only along the first dimension.
+    # Make a compact array by concantating different signatures, and update
+    # them all together.
+    sorted_update = []
+    signature_indices = self._signature_types()
+    for _, val in sorted(updates.items(),
+                         key=lambda item: signature_indices[item[0]]):
+      sorted_update.append(val)
+    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
     indices = constant_op.constant([cache_idx])
+    updates = array_ops.concat(sorted_update, axis=0)
+    updates = array_ops.reshape(updates, [1, self._num_signature_dimensions()])
     return state_ops.scatter_update(cache, indices, updates).op
 
+  def _snapshot_tensor(self, tensor):
+    """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.
+
+    Args:
+      tensor: tensor whose values will be stored in a new tf.Variable.
+    Returns:
+      An assignment operation.
+    """
+
+    snapshot_variable = self._create_or_get_tensor_values_cache(
+        tensor.name, tensor.op.graph,
+        tensor.shape.as_list(), tensor.dtype)
+    return state_ops.assign(snapshot_variable, tensor).op
+
   def _preprocess_traced_tensor(self, tensor):
     """Computes NAN/Norm/Max on TPUs before sending to CPU.
 
@@ -415,13 +553,46 @@
       output_tensor = array_ops.reshape(output_tensor, [1])
       return output_tensor
 
-    def _show_norm(tensor):
-      tensor = math_ops.cast(tensor, dtypes.float32)
-      output_tensor = linalg_ops.norm(tensor)
+    def _compute_signature(tensor, tf_op, cast_to_f32=True):
+      if cast_to_f32:
+        tensor = math_ops.cast(tensor, dtypes.float32)
+      output_tensor = tf_op(tensor)
       # The shape has to be 1. Set it if it does not have the information.
       output_tensor = array_ops.reshape(output_tensor, [1])
       return output_tensor
 
+    def _show_size(tensor):
+      # In order to check the size of a tensor.
+      # Not all sizes are known at the compile time, also, different replicas
+      # sometimes get different sizes of tensors.
+      # Collect it here to be used in merging replica data.
+      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
+      # Cast to float32, so that it can be placed into same cache with other
+      # signatures.
+      return math_ops.cast(tsize, dtypes.float32)
+
+    def _show_max(tensor, cast_to_f32=True):
+      # returns -inf for empty tensor
+      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)
+
+    def _show_min(tensor, cast_to_f32=True):
+      # returns inf for empty tensor
+      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)
+
+    def _show_norm(tensor, cast_to_f32=True):
+      # returns 0 for empty tensor
+      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)
+
+    def _show_mean_and_variance(tensor, cast_to_f32=True):
+      if cast_to_f32:
+        tensor = math_ops.cast(tensor, dtypes.float32)
+      # returns nan for empty tensor
+      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
+      # The shape has to be 1. Set it if it does not have the information.
+      mean = array_ops.reshape(mean, [1])
+      var = array_ops.reshape(var, [1])
+      return mean, var
+
     def _show_max_abs(tensor):
       tensor = math_ops.cast(tensor, dtypes.float32)
       output_tensor = math_ops.reduce_max(math_ops.abs(tensor))
@@ -450,19 +621,31 @@
 
     if (self._parameters.trace_mode ==
         tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN):
-      return _detect_inf_nan_producer(tensor)
+      return {self._parameters.trace_mode: _detect_inf_nan_producer(tensor)}
     if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
-      return _detect_nan_inf(tensor)
+      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
     if (self._parameters.trace_mode ==
         tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
-      return tensor
-    if (self._parameters.trace_mode ==
-        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR):
-      return tensor
+      return {self._parameters.trace_mode: tensor}
+    if (self._parameters.trace_mode in (
+        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
+        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
+      return {self._parameters.trace_mode: tensor}
     if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
-      return _show_norm(tensor)
+      return {self._parameters.trace_mode: _show_norm(tensor)}
     if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
-      return _show_max_abs(tensor)
+      return {self._parameters.trace_mode: _show_max_abs(tensor)}
+    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
+      tensor = math_ops.cast(tensor, dtypes.float32)
+      tsize = _show_size(tensor)
+      tnorm = _show_norm(tensor, cast_to_f32=False)
+      tmax = _show_max(tensor, cast_to_f32=False)
+      tmin = _show_min(tensor, cast_to_f32=False)
+      tmean, tvar = _show_mean_and_variance(tensor, cast_to_f32=False)
+      return {_TT_SUMMARY_NORM: tnorm, _TT_SUMMARY_MAX: tmax,
+              _TT_SUMMARY_MIN: tmin, _TT_SUMMARY_MEAN: tmean,
+              _TT_SUMMARY_VAR: tvar, _TT_SUMMARY_SIZE: tsize}
+
     raise RuntimeError(
         'Tensor trace fun for %s is not yet implemented'
         % self._parameters.trace_mode)
@@ -566,18 +749,29 @@
     # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
     # performed within TPUs and only their results are transferred to CPU.
     # Simply, print the full tensor for these trace modes.
-    if self._parameters.trace_mode in [
+    if self._parameters.trace_mode in (
         tensor_tracer_flags.TRACE_MODE_NAN_INF,
         tensor_tracer_flags.TRACE_MODE_NORM,
         tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
-        tensor_tracer_flags.TRACE_MODE_MAX_ABS]:
+        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
+        tensor_tracer_flags.TRACE_MODE_SUMMARY
+        ):
       return _show_full_tensor
 
     raise RuntimeError('Tensor trace fun for %s is not yet implemented'
                        %self._parameters.trace_mode)
 
   def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
-    """Returns True if we should not trace Op."""
+    """Returns True if we should not trace Op.
+
+    Args:
+      op_id: Topological index of the op.
+      op: tf.Operation
+      ops_in_exec_path: Set of operations that are in the execution path.
+      report_handler: An instance of tensor_tracer_report.TTReportHandle.
+    Returns:
+      True if the op should not be traced, false otherwise.
+    """
     if TensorTracer.while_loop_op(op):
       report_handler.instrument_op(
           op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
@@ -614,7 +808,15 @@
     return False
 
   def _skip_tensor(self, op_id, out_tensor, report_handler):
-    """Returns True if we should not trace out_tensor."""
+    """Returns True if we should not trace out_tensor.
+
+    Args:
+      op_id: Topological index of the op producing tensor.
+      out_tensor: tf.Tensor
+      report_handler: An instance of tensor_tracer_report.TTReportHandle.
+    Returns:
+      True if the tensor should not be traced, false otherwise.
+    """
 
     # Skips a tensor if the tensor has a non-numeric type.
     #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
@@ -644,11 +846,12 @@
     if not out_tensor.get_shape().is_fully_defined():
       # If trace mode is nan-inf, norm or max, then the tensor will be reduced
       # to a scalar before the outside compilation call.
-      if self._parameters.trace_mode in [
+      if self._parameters.trace_mode in (
           tensor_tracer_flags.TRACE_MODE_NAN_INF,
           tensor_tracer_flags.TRACE_MODE_NORM,
-          tensor_tracer_flags.TRACE_MODE_MAX_ABS
-      ]:
+          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
+          tensor_tracer_flags.TRACE_MODE_SUMMARY
+          ):
         report_handler.instrument_tensor(
             out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
         return False
@@ -721,7 +924,20 @@
                                                ops_in_exec_path,
                                                tensor_trace_points,
                                                report_handler):
-    """Determines the tensors to trace and instruments the trace details."""
+    """Determines the tensors to trace and instruments the trace details.
+
+    Args:
+      graph_order: graph_order tuple containing graph (tf.graph), operations
+        (list of operations), op_to_idx (op id mapping), (tensors) list of
+        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
+        there is a cycle in the graph), topological_order_or_cycle (list of ops
+        in topological order or list of ops creating a cycle).
+      ops_in_exec_path: Set of ops in the execution path.
+      tensor_trace_points: Collection of programatic tensor trace points.
+      report_handler: An instance of tensor_tracer_report.TTReportHandle.
+    Returns:
+      List of tensors to be traced.
+    """
 
     traced_tensors = []
     checkpoint_operations = set([tensor.op
@@ -743,6 +959,10 @@
     if not self._parameters.trace_dir:
       # traces will be written to stderr. No need to check trace files.
       return
+    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
+      # Output files are handled by tf.summary operations, no need to precreate
+      # them.
+      return
     if _trace_files_need_precreated(self._parameters.trace_dir):
       for replica_id in range(0, self._tt_config.num_replicas):
         trace_file_path = os.path.join(
@@ -759,7 +979,15 @@
           raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
 
   def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
-    """Work needs to be done prior to TPU or CPU tracing."""
+    """Work needs to be done prior to TPU or CPU tracing.
+
+    Args:
+      graph: tf.graph
+      ops_in_exec_path: Set of operations in the execution path.
+    Returns:
+      An instance of tensor_tracer_report.TensorTraceOrder, containing list of
+      tensors to be traced with their topological order information.
+    """
 
     self._check_trace_files()
 
@@ -772,17 +1000,36 @@
 
     tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                                traced_tensors)
-    if self._use_tensor_values_cache():
-      _create_tensor_values_cache(graph, len(traced_tensors))
-    report_handler.create_report(self._tt_config, self._parameters,
-                                 tensor_trace_order, tensor_trace_points)
+    num_signatures = self._num_signature_dimensions()
+    if num_signatures:
+      self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG,
+                                              graph,
+                                              [len(traced_tensors),
+                                               num_signatures])
+    if self._parameters.trace_mode in (
+        tensor_tracer_flags.TRACE_MODE_SUMMARY,
+        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
+      report_proto = report_handler.create_report_proto(self._tt_config,
+                                                        self._parameters,
+                                                        tensor_trace_order,
+                                                        tensor_trace_points,
+                                                        self._signature_types())
+      report_handler.write_report_proto(report_proto, self._parameters)
+    else:
+      report_handler.create_report(self._tt_config, self._parameters,
+                                   tensor_trace_order, tensor_trace_points)
     return tensor_trace_order
 
-  def _generate_flush_cache_op(self, graph, num_replicas, on_tpu):
+  def _create_host_call(self):
+    return self._parameters.trace_mode in (
+        tensor_tracer_flags.TRACE_MODE_SUMMARY,
+        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
+
+
+  def _generate_flush_cache_op(self, num_replicas, on_tpu):
     """Generates an Op that will flush the cache to file.
 
     Args:
-      graph: the graph of Ops
       num_replicas: total number of replicas.
       on_tpu: if the graph is executed on TPU.
 
@@ -807,12 +1054,14 @@
             output_stream = sys.stderr
 
           new_step_line = _REPLICA_ID_TAG + replica_str
-          print_op = logging_ops.print_v2(
-              new_step_line, '\n',
-              cache, '\n',
-              summarize=-1,
-              output_stream=output_stream)
-          with ops.control_dependencies([print_op]):
+          print_ops = []
+          for i in range(self._num_signature_dimensions()):
+            print_ops.append(logging_ops.print_v2(
+                new_step_line, '\n',
+                cache[:, i], '\n',
+                summarize=-1,
+                output_stream=output_stream))
+          with ops.control_dependencies(print_ops):
             return constant_op.constant(0).op
         return _print_cache
 
@@ -829,7 +1078,7 @@
       # only known during tf runtime, and we cannot create dynamic filenames.
       return control_flow_ops.case(flush_op_cases, exclusive=True)
 
-    cache = _get_tensor_values_cache(graph)
+    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
     if on_tpu:
       flush_op = tpu.outside_compilation(_flush_fun,
                                          cache.value(), self._replica_id)
@@ -844,12 +1093,10 @@
       with ops.control_dependencies([assign_op]):
         return constant_op.constant(0).op
 
-  def _flush_tensor_values_cache(self, graph, tensor_fetches, op_fetches,
-                                 on_tpu):
+  def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu):
     """Flushes the intermediate tensor values in the graph to the cache.
 
     Args:
-      graph: the graph of Ops
       tensor_fetches: list of tensor results returned by the model_fn.
       op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
       on_tpu: if the graph is executed on TPU.
@@ -862,8 +1109,7 @@
     with ops.control_dependencies(op_fetches +
                                   [tensor.op for tensor in tensor_fetches]):
       flush_cache_op = self._generate_flush_cache_op(
-          graph, self._tt_config.num_replicas, on_tpu)
-
+          self._tt_config.num_replicas, on_tpu)
       return control_flow_ops.tuple(tensor_fetches,
                                     control_inputs=[flush_cache_op])
 
@@ -898,6 +1144,8 @@
     for fetch in op_fetches:
       if isinstance(fetch, ops.Operation):
         fetches.append(fetch)
+      elif isinstance(fetch, ops.Tensor):
+        fetches.append(fetch.op)
       else:
         logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
                         fetch)
@@ -934,6 +1182,86 @@
       op_control_flow_context = op_control_flow_context.outer_context
     return op_control_flow_context
 
+  def _prepare_host_call_fn(self, processed_t_fetches, op_fetches):
+    """Creates a host call function that will write the cache as tb summary.
+
+    Args:
+      processed_t_fetches: List of tensor provided to session.run.
+      op_fetches: List of operations provided to session.run.
+    Raises:
+      ValueError if trace_dir is not set.
+    """
+    if self._parameters.trace_dir is None:
+      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
+                       '--trace_dir=/model/dir')
+
+    def _write_cache(step, **kwargs):
+      """Writes the given caches as tensor summary.
+
+      Args:
+        step: Step tensor with dimension [num_cores].
+        **kwargs: The dictionary of tensors that needs to be written as
+          summaries. Key and value pairs within kwargs correspond to the tag
+          name, and tensor content that will be written using summary.write.
+          The trace_modes that use this function are:
+            - summary: In summary mode, kwargs includes a single (tag, content)
+            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
+            variable. The dimension of the signature_cache is:
+              num_cores x num_traced_tensors x num_signatures.
+            - full_tensor_summary: kwargs will include all traced tensors. Tag
+            and content correspond to the name of the tensor, and its actual
+            content.
+      Returns:
+        A tf.Operation that needs to be executed for the host call dependencies.
+      """
+
+      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
+      # less frequently.
+      # Setting max_queue to 100 appears to be safe even when the number of
+      # iterations are much lower, as the destructor of the writer will flushes
+      # it.
+      summary_write_ops = []
+      with summary.create_file_writer_v2(
+          self._parameters.trace_dir,
+          filename_suffix=_TT_EVENT_FILE_SUFFIX,
+          max_queue=_TT_SUMMARY_MAX_QUEUE).as_default():
+        summary_metadata = summary_pb2.SummaryMetadata(
+            plugin_data=summary_pb2.SummaryMetadata.PluginData(
+                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
+        for key, value in kwargs.items():
+          summary_write_ops.append(summary.write(
+              _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata,
+              step=step[0]))
+      return control_flow_ops.group(summary_write_ops)
+
+    step = array_ops.reshape(training_util.get_or_create_global_step(), [1])
+    self._host_call_fn = {}
+
+    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]
+
+    caches_to_write = {}
+    with ops.control_dependencies(host_call_deps):
+      all_caches = self._get_all_cache_variables()
+      for cache_name, cache_variable in all_caches.items():
+        # Increase the cache rank by 1, so that when host call concatenates
+        # tensors from different replicas, we can identify them with [core_id].
+        new_cache_shape = [1]
+        new_cache_shape.extend(cache_variable.shape.as_list())
+        cache = array_ops.reshape(cache_variable.value(), new_cache_shape)
+        caches_to_write[cache_name] = cache
+    # Add step to parameter dictionary.
+    caches_to_write['step'] = step
+    # Other options without adding step to parameter dictionary are
+    #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
+    #    considers caches_to_write as a single parameter, rather than a keyword
+    #    parameters.
+    #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
+    #    a syntax error.
+    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
+
+  def host_call_deps_and_fn(self):
+    return self._host_call_fn
+
   def _trace_execution(self, graph,
                        tensor_fetches,
                        op_fetches=None,
@@ -974,6 +1302,8 @@
       return tensor
 
     TensorTracer.check_device_type(self._tt_config.device_type)
+    TensorTracer.check_trace_mode(self._tt_config.device_type,
+                                  self._parameters.trace_mode)
     # Check in_tensor_fetches, and op_fetches and convert them to lists.
     processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
     op_fetches = self._process_op_fetches(op_fetches)
@@ -1021,16 +1351,25 @@
         # pylint: disable=protected-access
         graph._set_control_flow_context(op_control_flow_context)
         # pylint: enable=protected-access
-        processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
+        processed_tensors = self._preprocess_traced_tensor(out_tensor)
 
         if on_tpu:
-          processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
+          for signature in processed_tensors.keys():
+            processed_tensors[signature] = _cast_unsupported_dtypes(
+                processed_tensors[signature])
 
         if self._use_tensor_values_cache():
+          # Use a small cache to store the characteristics of the tensor.
           cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
-          trace_op = self._save_tensor_value_to_cache_op(graph,
-                                                         cache_idx,
-                                                         processed_out_tensor)
+          trace_op = self._save_tensor_value_to_cache_op(cache_idx,
+                                                         processed_tensors)
+        elif self._use_tensor_buffer():
+          if len(processed_tensors) != 1:
+            raise RuntimeError('Multiple stats are only allowed in compact '
+                               'mode.')
+          processed_out_tensor = processed_tensors.values()[0]
+          # Store the whole tensor in a buffer.
+          trace_op = self._snapshot_tensor(processed_out_tensor)
         else:
 
           def tpu_wrap_trace_fn(tensor, out_tensor_name):
@@ -1049,6 +1388,14 @@
                 predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name),
                 lambda: constant_op.constant(False)).op
 
+          if len(processed_tensors) != 1:
+            raise RuntimeError('Multiple stats are only allowed in compact '
+                               'mode.')
+          # Collecting multiple statistics are only supported in the summary
+          # mode that uses compact format(self._use_tensor_values_cache = true).
+          # Non-compact mode currently allows single stat per tensor.
+          processed_out_tensor = processed_tensors.values()[0]
+
           if self._parameters.is_conditional_trace:
             trace_op = conditional_trace_fn(processed_out_tensor, out_tensor,
                                             tpu_wrap_trace_fn, tensor_name)
@@ -1081,11 +1428,13 @@
       # tracing_ops.
       processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                    control_inputs=tracing_ops)
-    if self._use_tensor_values_cache():
-      processed_t_fetches = self._flush_tensor_values_cache(graph,
-                                                            processed_t_fetches,
-                                                            op_fetches,
-                                                            on_tpu=on_tpu)
+    if self._use_tensor_values_cache() or self._use_tensor_buffer():
+      if self._create_host_call() and on_tpu:
+        self._prepare_host_call_fn(processed_t_fetches, op_fetches)
+      else:
+        processed_t_fetches = self._flush_tensor_values_cache(
+            processed_t_fetches, op_fetches, on_tpu=on_tpu)
+
     # processed_t_fetches is a list at this point. Convert it to the same
     # format as given in tensor_fetches.
     return self._convert_fetches_to_input_format(tensor_fetches,
diff --git a/tensorflow/python/tpu/tensor_tracer_flags.py b/tensorflow/python/tpu/tensor_tracer_flags.py
index bf2752d..2094c11 100644
--- a/tensorflow/python/tpu/tensor_tracer_flags.py
+++ b/tensorflow/python/tpu/tensor_tracer_flags.py
@@ -31,6 +31,12 @@
 TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan'
 TRACE_MODE_NORM = 'norm'
 TRACE_MODE_MAX_ABS = 'max-abs'
+TRACE_MODE_SUMMARY = 'summary'
+# summary mode to collects a finite set of signatures for each traced tensor,
+# (such as norm, max, min, mean) and dumps it using tb summaries.
+TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'
+# Full tensor mode dumps the whole tensor values for the traced tensors without
+# any processing on them; using tb summaries.
 _FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
 _SUBMODE_BRIEF = 'brief'
 _SUBMODE_DETAILED = 'detailed'
@@ -164,7 +170,8 @@
       trace_mode = TRACE_MODE_NORM
     valid_trace_modes = [
         TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
-        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, TRACE_MODE_FULL_IF_NAN
+        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, TRACE_MODE_FULL_IF_NAN,
+        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
     ]
     if trace_mode not in valid_trace_modes:
       raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
diff --git a/tensorflow/python/tpu/tensor_tracer_report.py b/tensorflow/python/tpu/tensor_tracer_report.py
index 4bf76aa..29e4875 100644
--- a/tensorflow/python/tpu/tensor_tracer_report.py
+++ b/tensorflow/python/tpu/tensor_tracer_report.py
@@ -18,10 +18,12 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
 import collections
 
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.tpu import tensor_tracer_pb2
 
 _TRACER_LOG_PREFIX = ' [>>>TT>>>]'
 _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
@@ -48,6 +50,7 @@
 _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
 
 _CURRENT_VERSION = 'use-outside-compilation'
+_TT_REPORT_PROTO = 'tensor_tracer_report.proto'
 
 
 def topological_sort(g):
@@ -219,6 +222,67 @@
   def instrument_tensor(self, tensor, explanation):
     self.instrument(tensor.name, explanation)
 
+  def create_report_proto(self, tt_config, tt_parameters, tensor_trace_order,
+                          tensor_trace_points, collected_signature_types):
+    """Creates and returns a proto that stores tensor tracer configuration.
+
+    Args:
+      tt_config: TensorTracerConfig object holding information about the run
+        environment (device, # cores, # hosts), and tensor tracer version
+        information.
+      tt_parameters: TTParameters objects storing the user provided parameters
+        for tensor tracer.
+      tensor_trace_order: TensorTraceOrder object storing a topological order of
+        the graph.
+      tensor_trace_points: Progromatically added trace_points/checkpoints.
+        collected_signature_types: The signature types collected, e,g, norm,
+        max, min, mean...
+    Returns:
+      TensorTracerReport proto.
+    """
+    report = tensor_tracer_pb2.TensorTracerReport()
+    report.config.version = tt_config.version
+    report.config.device = tt_config.device_type
+    report.config.num_cores = tt_config.num_replicas
+    report.config.num_hosts = tt_config.num_hosts
+    report.config.num_cores_per_host = tt_config.num_replicas_per_host
+    for core in tt_parameters.included_cores:
+      report.config.included_cores.append(core)
+    report.config.submode = tt_parameters.submode
+    report.config.trace_mode = tt_parameters.trace_mode
+
+    for signature_name, _ in sorted(collected_signature_types.items(),
+                                    key=lambda x: x[1]):
+      report.config.signatures.append(signature_name)
+
+    tf_graph = tensor_trace_order.graph_order.graph
+    report.graphdef.CopyFrom(tf_graph.as_graph_def())
+    for tensor in tensor_trace_order.graph_order.tensors:
+      tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef()
+      tensor_def.name = tensor.name
+      if tensor.name in tensor_trace_order.tensorname_to_cache_idx:
+        tensor_def.is_traced = True
+        tensor_def.cache_index = (
+            tensor_trace_order.tensorname_to_cache_idx[tensor.name])
+      else:
+        tensor_def.is_traced = False
+
+      if tensor.name in tensor_trace_points:
+        tensor_def.trace_point_name = tensor_trace_points[tensor.name]
+      if tensor.name in self.instrument_records:
+        tensor_def.explanation = self.instrument_records[tensor.name]
+      elif tensor.op.name in self.instrument_records:
+        tensor_def.explanation = self.instrument_records[tensor.op.name]
+      report.tensordef[tensor.name].CopyFrom(tensor_def)
+    return report
+
+  def write_report_proto(self, report_proto, tt_parameters):
+    """Writes the given report proto under trace_dir."""
+    gfile.MakeDirs(tt_parameters.trace_dir)
+    report_path = os.path.join(tt_parameters.trace_dir, _TT_REPORT_PROTO)
+    with gfile.GFile(report_path, 'wb') as f:
+      f.write(report_proto.SerializeToString())
+
   def create_report(self, tt_config, tt_parameters,
                     tensor_trace_order, tensor_trace_points):
     """Creates a report file and writes the trace information."""
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index c9bcf3a..7667b10 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -19,7 +19,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
@@ -42,6 +41,7 @@
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 ops.NotDifferentiable("TPUReplicatedInput")
@@ -107,6 +107,32 @@
     return tpu_ops.configure_distributed_tpu(embedding_config=config_string)
 
 
+def initialize_system_for_tpu_embedding(embedding_config, job=None):
+  """Initializes a distributed TPU Embedding system for use with TensorFlow.
+
+  The following two are equivalent:
+  1. initialize_system() with embedding_config.
+  2. initialize_system() without embedding_config, then
+     initialize_system_for_tpu_embedding().
+  initialize_system() should not be called with embedding_config if
+  initialize_system_for_tpu_embedding() is meant to be called later.
+
+  Args:
+    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
+      configuration of the hardware embedding lookup tables.
+    job: The job (the XXX in TensorFlow device specification /job:XXX) that
+      contains the TPU devices that will be initialized. If job=None it is
+      assumed there is only one job in the TensorFlow flock, and an error will
+      be returned if this assumption does not hold.
+
+  Returns:
+    A no-op.
+  """
+  config_string = embedding_config.SerializeToString()
+  with ops.device(_tpu_system_device_name(job)):
+    return tpu_ops.configure_tpu_embedding(config=config_string)
+
+
 @tf_export(v1=["tpu.shutdown_system"])
 def shutdown_system(job=None):
   """Shuts down a running a distributed TPU system.
@@ -415,7 +441,7 @@
       for index in xrange(len(op.inputs)):
         x = op.inputs[index]
         real_x = self.AddValue(x)
-        if real_x != x:
+        if real_x is not x:
           op._update_input(index, real_x)  # pylint: disable=protected-access
 
     if external_control_inputs:
@@ -1023,7 +1049,7 @@
   if outputs is None:
     outputs = tuple()
   # If the computation only returned one value, makes it a tuple.
-  if not isinstance(outputs, collections.Sequence):
+  if not isinstance(outputs, collections_abc.Sequence):
     outputs = (outputs,)
 
   # Append `no_op` here so that fetching any return value of this function
@@ -1626,8 +1652,10 @@
   """
   # Scan over the top level graph and all function graphs.
   for graph in [prune_graph] + [
-      f for f in prune_graph._functions.values() if isinstance(f, ops.Graph)  # pylint: disable=protected-access
+      f for f in prune_graph._functions.values()  # pylint: disable=protected-access
   ]:
+    if not isinstance(graph, ops.Graph):
+      continue
     for op in graph.get_operations():
       if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
         continue
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index 1a781bb..9712f0b 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -43,17 +43,23 @@
 INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
 
 
+# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
+#  as AdagradParameters etc instead of learning_rate.
 class TableConfig(
-    collections.namedtuple(
-        'TableConfig',
-        ['vocabulary_size', 'dimension', 'initializer', 'combiner'])):
+    collections.namedtuple('TableConfig', [
+        'vocabulary_size', 'dimension', 'initializer', 'combiner',
+        'hot_id_replication', 'learning_rate', 'learning_rate_key'
+    ])):
   """Embedding table configuration."""
 
   def __new__(cls,
               vocabulary_size,
               dimension,
               initializer=None,
-              combiner='mean'):
+              combiner='mean',
+              hot_id_replication=False,
+              learning_rate=None,
+              learning_rate_key=None):
     """Embedding table configuration.
 
     Args:
@@ -69,6 +75,20 @@
         accuracy, in particular with bag-of-words columns. For more information,
         see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
         than sparse tensors.
+      hot_id_replication: If true, enables hot id replication, which can make
+        embedding lookups faster if there are some hot rows in the table.
+      learning_rate: float, static learning rate for this table. If
+        learning_rate and learning_rate_key are both `None`, global
+        static learning rate as specified in `optimization_parameters` in
+        `TPUEmbedding` constructor will be used. `learning_rate_key` must be
+        `None` if `learning_rate` is not `None.
+      learning_rate_key: string, use dynamic learning rate of
+        `learning_rates[learning_rate_key]` for this table, where
+        `learning_rates` is the second argument of
+        `generate_send_gradients_op()`. If learning_rate and learning_rate_key
+        are both `None`, global static learning rate as specified in
+        `optimization_parameters` in `TPUEmbedding` constructor will be used.
+        `learning_rate` must be `None` if `learning_rate_key` is not `None.
 
     Returns:
       `TableConfig`.
@@ -78,6 +98,8 @@
       ValueError: if `dimension` is not positive integer.
       ValueError: if `initializer` is specified and is not callable.
       ValueError: if `combiner` is not supported.
+      ValueError: if `learning_rate` and `learning_rate_key` are both not
+        `None`.
     """
     if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
       raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
@@ -94,8 +116,14 @@
     if combiner not in ('mean', 'sum', 'sqrtn', None):
       raise ValueError('Invalid combiner {}'.format(combiner))
 
-    return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension,
-                                           initializer, combiner)
+    if learning_rate is not None and learning_rate_key is not None:
+      raise ValueError('At most one of learning_rate and learning_rate_key '
+                       'can be None; got {} and {}'
+                       .format(learning_rate, learning_rate_key))
+
+    return super(TableConfig, cls).__new__(
+        cls, vocabulary_size, dimension, initializer, combiner,
+        hot_id_replication, learning_rate, learning_rate_key)
 
 
 class FeatureConfig(
@@ -656,6 +684,10 @@
 
   def _create_config_proto(self):
     """Create `TPUEmbeddingConfiguration`."""
+    self._learning_rate_keys = list(
+        set(c.learning_rate_key
+            for c in self._table_to_config_dict.values()
+            if c.learning_rate_key is not None))
     config_proto = elc.TPUEmbeddingConfiguration()
     for table in self._table_to_config_dict:
       table_descriptor = config_proto.table_descriptor.add()
@@ -670,18 +702,28 @@
 
       table_descriptor.num_features = self._table_to_num_features_dict[table]
 
-      table_descriptor.optimization_parameters.learning_rate.constant = (
-          self._optimization_parameters.learning_rate)
-      table_descriptor.optimization_parameters.gradient_accumulation_status = (
+      parameters = table_descriptor.optimization_parameters
+      if table_config.learning_rate:
+        parameters.learning_rate.constant = (table_config.learning_rate)
+      elif table_config.learning_rate_key:
+        parameters.learning_rate.dynamic.tag = (
+            self._learning_rate_keys.index(table_config.learning_rate_key))
+      else:
+        parameters.learning_rate.constant = (
+            self._optimization_parameters.learning_rate)
+      parameters.gradient_accumulation_status = (
           optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
           if self._optimization_parameters.use_gradient_accumulation else
           optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
       if self._optimization_parameters.clip_weight_min is not None:
-        table_descriptor.optimization_parameters.clipping_limits.lower.value = (
+        parameters.clipping_limits.lower.value = (
             self._optimization_parameters.clip_weight_min)
       if self._optimization_parameters.clip_weight_max is not None:
-        table_descriptor.optimization_parameters.clipping_limits.upper.value = (
+        parameters.clipping_limits.upper.value = (
             self._optimization_parameters.clip_weight_max)
+      if table_config.hot_id_replication:
+        parameters.hot_id_replication_configuration.status = (
+            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
       self._optimizer_handler.set_optimization_parameters(table_descriptor)
 
     config_proto.mode = self._mode
@@ -960,12 +1002,16 @@
 
     return activations
 
-  def generate_send_gradients_op(self, feature_to_gradient_dict):
+  def generate_send_gradients_op(self,
+                                 feature_to_gradient_dict,
+                                 learning_rates=None):
     """Send gradient to TPU embedding.
 
     Args:
       feature_to_gradient_dict: dict mapping feature names to gradient wrt
         activations.
+      learning_rates: dict mapping from learning rate key to dynamic learning
+        rate. Defaults to `None`.
 
     Returns:
       SendTPUEmbeddingGradients Op.
@@ -977,6 +1023,10 @@
       raise RuntimeError('Only in training mode gradients need to '
                          'be sent to TPU embedding; got mode {}.'
                          .format(self._mode))
+
+    if learning_rates is None:
+      learning_rates = dict()
+
     gradients = []
     for table in self._table_to_features_dict:
       features = self._table_to_features_dict[table]
@@ -991,8 +1041,13 @@
           array_ops.concat(table_gradients, axis=1),
           [-1, array_ops.shape(table_gradients[0])[-1]])
       gradients.append(interleaved_table_grads)
+
     return tpu_ops.send_tpu_embedding_gradients(
-        inputs=gradients, config=self.config_proto.SerializeToString())
+        inputs=gradients,
+        learning_rates=[
+            learning_rates[tag] for tag in self._learning_rate_keys
+        ],
+        config=self.config_proto.SerializeToString())
 
 
 def _validate_table_to_config_dict(table_to_config_dict):
diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py
index aacfe6f..0730618 100644
--- a/tensorflow/python/training/adagrad_da_test.py
+++ b/tensorflow/python/training/adagrad_da_test.py
@@ -63,9 +63,9 @@
         update.run()
 
         v0_val, v1_val = self.evaluate([var0, var1])
-        # Let g to be gradient accumulator, gg to be gradient squared
-        # accumulator, T be the global step, lr is the learning rate, and k the
-        # initial gradient squared accumulator value.
+        # Let g be the gradient accumulator, gg be the gradient squared
+        # accumulator, T be the global step, lr be the learning rate,
+        # and k the initial gradient squared accumulator value.
         # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})}
         # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
         # similarly for others.
diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py
index bbbd0cd..46f52f0 100644
--- a/tensorflow/python/training/experimental/loss_scale.py
+++ b/tensorflow/python/training/experimental/loss_scale.py
@@ -205,7 +205,7 @@
         number as long as no nan or inf is encountered in training.
 
     Raises:
-      ValueError: If loss_scale is less than 1.
+      ValueError: If loss_scale_value is less than 1.
     """
     super(FixedLossScale, self).__init__()
     if not isinstance(loss_scale_value, six.integer_types + (float,)):
@@ -227,6 +227,9 @@
     del grads
     return control_flow_ops.no_op(), True
 
+  def __repr__(self):
+    return 'FixedLossScale(%s)' % self._loss_scale_value
+
   def get_config(self):
     return {'loss_scale_value': self._loss_scale_value}
 
@@ -376,6 +379,17 @@
     should_apply_gradients = is_finite
     return update_op, should_apply_gradients
 
+  def __repr__(self):
+    if context.executing_eagerly():
+      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
+              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
+              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
+               self.initial_loss_scale, self.increment_period, self.multiplier))
+    else:
+      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
+              'multiplier=%s)' %
+              (self.initial_loss_scale, self.increment_period, self.multiplier))
+
   def get_config(self):
     return {
         'initial_loss_scale': self.initial_loss_scale,
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
index c3e18a1..e4a1114 100644
--- a/tensorflow/python/training/experimental/loss_scale_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -92,6 +92,11 @@
     scalar = loss_scale_module.FixedLossScale(123)
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self):
+    loss_scale = loss_scale_module.FixedLossScale(123)
+    self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')
+
 
 def _get_example_iter(inputs):
   dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
@@ -302,5 +307,22 @@
     scalar = loss_scale_module.DynamicLossScale()
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1, increment_period=2, multiplier=3)
+      if context.executing_eagerly():
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(current_loss_scale=1.0, '
+                         'num_good_steps=0, initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+      else:
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py
index 949c498..6fb4de7 100644
--- a/tensorflow/python/training/experimental/mixed_precision.py
+++ b/tensorflow/python/training/experimental/mixed_precision.py
@@ -27,7 +27,7 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
-def _wrap_optimizer(opt, loss_scale):
+def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
   """Wraps an optimizer with a LossScaleOptimizer."""
 
   if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
@@ -67,12 +67,60 @@
     from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
     return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
 
-  raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
-                   'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+  if use_v1_behavior:
+    raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
+                     'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+  else:
+    raise ValueError('"opt" must be an instance of a '
+                     'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+
+
+@tf_export('train.experimental.enable_mixed_precision_graph_rewrite', v1=[])
+def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
+  """Enable mixed precision in `tf.function`s via a graph rewrite.
+
+  Mixed precision is the use of both float16 and float32 when training a model,
+  and is used to make the model run faster. This function will use mixed
+  precision to speed up the execution time of `tf.function`s when run on a GPU.
+  It does this by changing the dtype of certain operations in the function's
+  graph from float32 to float16.
+
+  This function additionally wraps an Optimizer with a LossScaleOptimizer, which
+  is required to prevent underflow in the float16 tensors during the backwards
+  pass. An optimizer must be passed to this function, which will then be wrapped
+  to use loss scaling.
+
+  When this function is used, gradients should only be computed and applied with
+  the returned optimizer through `opt.minimize()`, and not with a
+  `tf.GradientTape`. This is because the returned optimizer will apply loss
+  scaling, and `tf.GradientTape` will not. If you do use a `tf.GradientTape`,
+  your model may train to a worse quality.
+
+  Currently, mixed precision is only enabled on Volta GPUs and above. TPU
+  support is coming soon. CPUs are not supported, as CPUs do not run float16
+  operations faster than float32 operations.
+
+  WARNING: This rewrite silently affects the entire model and can have
+  unintended consequences. One example: If a NaN occurs during dynamic loss
+  scaling, the data for the batch is silently dropped while the
+  LossScaleOptimizer attempts to find the appropriate scaling value on the next
+  batch.
+
+  Args:
+    opt: An instance of a `tf.keras.optimizers.Optimizer`.
+    loss_scale: Either an int/float, the string "dynamic", or an instance of a
+      `tf.train.experimental.LossScale`. The loss scale to use. It is
+      recommended to keep this as its default value of "dynamic".
+
+  Returns:
+    A version of `opt` that will use loss scaling to prevent underflow.
+  """
+  return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                                    use_v1_behavior=False)
 
 
 @tf_export(v1=['train.experimental.enable_mixed_precision_graph_rewrite'])
-def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
+def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
   """Enable mixed precision via a graph rewrite.
 
   Mixed precision is the use of both float16 and float32 when training a model,
@@ -94,11 +142,9 @@
   `tf.gradients`/`tf.GradientTape` will not. If you do directly use
   `tf.gradients` or `tf.GradientTape`, your model may train to a worse quality.
 
-  When eager execution is enabled, the mixed precision graph rewrite is only
-  enabled within `tf.function`s, as outside `tf.function`s, there is no graph.
-
-  When enabled, mixed precision is only used on Volta GPUs and above. The parts
-  of the graph on CPUs and TPUs are untouched by the graph rewrite.
+  Currently, mixed precision is only enabled on Volta GPUs and above. TPU
+  support is coming soon. CPUs are not supported, as CPUs do not run float16
+  operations faster than float32 operations.
 
   Args:
     opt: An instance of a `tf.keras.optimizers.Optimizer` or a
@@ -112,6 +158,13 @@
   """
   # TODO(reedwm): If a ConfigProto is passed to Session, either assert that
   # auto_mixed_precision is on or turn it on for the user.
+  return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                                    use_v1_behavior=True)
+
+
+def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                               use_v1_behavior):
+  """Enables mixed precision. See `enable_mixed_precision_graph_rewrite`."""
   if not mixed_precision_global_state.using_default_mixed_precision_policy:
     raise ValueError(
         'The mixed precision graph rewrite cannot be enabled, because a keras '
@@ -122,10 +175,11 @@
         '  2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
         '(You called this second)\n\n'
         'You called both functions, which is an error, because both functions '
-        'enable you to use mixed precision. The second function enables mixed '
-        'precision in the graph with a graph rewrite. However it is currently '
-        'not very customizable, and does not support eager. The first '
-        'function is for Keras layers, but is not yet fully complete.')
+        'enable you to use mixed precision. If in doubt which function to use, '
+        'use the second, as it is currently more complete and easy to use. The '
+        'second function enables mixed precision in the graph with a graph '
+        'rewrite. However it is currently not very customizable, and does not '
+        'support eager.')
 
   if mixed_precision_global_state.non_mixed_precision_session_created:
     # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
@@ -133,17 +187,41 @@
     tf_logging.warn('You already have existing Sessions that do not use mixed '
                     'precision. enable_mixed_precision_graph_rewrite() will '
                     'not affect these Sessions.')
-  opt = _wrap_optimizer(opt, loss_scale)
+  opt = _wrap_optimizer(opt, loss_scale, use_v1_behavior=use_v1_behavior)
   config.set_optimizer_experimental_options({'auto_mixed_precision': True})
   mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = True
   return opt
 
 
-@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
+@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
 def disable_mixed_precision_graph_rewrite():
   """Disables the mixed precision graph rewrite.
 
   After this is called, the mixed precision graph rewrite will no longer run for
+  tf.functions, and so float32 operations will no longer be converted to
+  float16.
+
+  This does not undo the effects of loss scaling. Any optimizers wrapped with a
+  LossScaleOptimizer will continue to do loss scaling, although this loss
+  scaling will no longer be useful, as the graph rewrite no longer converts
+  tf.functions to use float16.
+
+  This function is useful for unit testing. A unit test can test using the mixed
+  precision graph rewrite, then disable it so future unit tests continue using
+  float32.
+  """
+  if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
+    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
+                    'precision is already disabled.')
+  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
+  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
+
+
+@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
+def disable_mixed_precision_graph_rewrite_v1():
+  """Disables the mixed precision graph rewrite.
+
+  After this is called, the mixed precision graph rewrite will no longer run for
   new Sessions, and so float32 operations will no longer be converted to float16
   in such Sessions. However, any existing Sessions will continue to have the
   graph rewrite enabled if they were created after
@@ -161,8 +239,6 @@
   as `enable_mixed_precision_graph_rewrite` and
   `disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
   """
-  if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
-    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
-                    'precision is already disabled.')
-  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
-  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
+  # We only have a separate V1 version of this function, because the V1
+  # docstring mentions sessions.
+  disable_mixed_precision_graph_rewrite()
diff --git a/tensorflow/python/training/experimental/mixed_precision_test.py b/tensorflow/python/training/experimental/mixed_precision_test.py
index 162aee5..2b03906 100644
--- a/tensorflow/python/training/experimental/mixed_precision_test.py
+++ b/tensorflow/python/training/experimental/mixed_precision_test.py
@@ -21,6 +21,7 @@
 from absl.testing import parameterized
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -40,6 +41,14 @@
 from tensorflow.python.training.experimental import mixed_precision_global_state
 
 
+if tf2.enabled():
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite)
+else:
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite_v1)
+
+
 class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
 
   IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
@@ -64,13 +73,13 @@
   @test_util.run_in_graph_and_eager_modes
   def test_wrap_optimizer(self):
     opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
-    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
     self.assertIsInstance(
         opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
 
     opt = gradient_descent_v2.SGD(1.0)
-    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
     self.assertIsInstance(
         opt, loss_scale_optimizer_v2.LossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
@@ -78,10 +87,14 @@
   @test_util.run_in_graph_and_eager_modes
   def test_optimizer_errors(self):
     opt = 1
-    expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
-                      'a tf.keras.optimizers.Optimizer, but got')
+    if tf2.enabled():
+      expected_regex = ('"opt" must be an instance of a '
+                        'tf.keras.optimizers.Optimizer, but got')
+    else:
+      expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
+                        'a tf.keras.optimizers.Optimizer, but got')
     with self.assertRaisesRegexp(ValueError, expected_regex):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
 
@@ -91,7 +104,7 @@
     with self.assertRaisesRegexp(ValueError,
                                  '"opt" must not already be an instance of a '
                                  'MixedPrecisionLossScaleOptimizer.'):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
 
@@ -100,7 +113,7 @@
     with self.assertRaisesRegexp(ValueError,
                                  '"opt" must not already be an instance of a '
                                  'LossScaleOptimizer.'):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
 
@@ -108,7 +121,7 @@
   @test_util.run_in_graph_and_eager_modes
   def test_grappler_pass_enabled(self):
     opt = gradient_descent_v2.SGD(1.0)
-    mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    enable_mixed_precision_graph_rewrite(opt, 123.)
 
     var = variables.Variable([[1.0]])
 
@@ -153,8 +166,7 @@
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
     with session.Session():
-      mixed_precision.enable_mixed_precision_graph_rewrite(
-          gradient_descent_v2.SGD(1.0))
+      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
       mock_warn.assert_any_call(
           'You already have existing Sessions that do not use mixed precision. '
           'enable_mixed_precision_graph_rewrite() will not affect these '
@@ -166,8 +178,7 @@
     # the warning.
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
-    mixed_precision.enable_mixed_precision_graph_rewrite(
-        gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
     with session.Session():
       # Make sure the "You already have existing Sessions" warning was not
       # issued, since the Session was only created after
@@ -181,11 +192,9 @@
     with policy.policy_scope('infer_float32_vars'):
       with self.assertRaisesRegexp(
           ValueError, 'a keras mixed precision Policy has been set'):
-        mixed_precision.enable_mixed_precision_graph_rewrite(
-            gradient_descent_v2.SGD(1.0))
+        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
     # Test no error is thrown when the policy is current the default.
-    mixed_precision.enable_mixed_precision_graph_rewrite(
-        gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 21408f3..41c8c71 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -24,8 +24,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.eager import context
@@ -46,6 +44,7 @@
 from tensorflow.python.summary import summary
 from tensorflow.python.training import queue_runner
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -600,7 +599,7 @@
 
 def _restore_sparse_tensors(stored_list, sparse_info_list):
   """Restore SparseTensors after dequeue in batch, batch_join, etc."""
-  received_sequence = isinstance(stored_list, collections.Sequence)
+  received_sequence = isinstance(stored_list, collections_abc.Sequence)
   if not received_sequence:
     stored_list = (stored_list,)
   tensors = [
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 81d3d4d..099fcf0 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -28,6 +28,7 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import object_identity
 
 
 # Op names which identify variable reads which should be saved.
@@ -335,7 +336,7 @@
     names_to_saveables = op_list_to_dict(names_to_saveables)
 
   saveables = []
-  seen_ops = set()
+  seen_ops = object_identity.ObjectIdentitySet()
   for name, op in sorted(names_to_saveables.items(),
                          # Avoid comparing ops, sort only by name.
                          key=lambda x: x[0]):
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index 00bb8e6..8efeb71 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -31,7 +31,6 @@
 from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.saving import saveable_object
-from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 
@@ -307,7 +306,7 @@
         value_tensors[serialized_tensor.name] = array_ops.identity(value)
       return value_tensors
 
-  def _gather_ops_or_named_saveables(self):
+  def gather_ops_or_named_saveables(self):
     """Looks up or creates SaveableObjects which don't have cached ops."""
     saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
     # Name saveables based on the name this object had when it was checkpointed.
@@ -391,7 +390,7 @@
       eagerly.
     """
     (restore_ops, tensor_saveables,
-     python_saveables) = self._gather_ops_or_named_saveables()
+     python_saveables) = self.gather_ops_or_named_saveables()
     restore_ops.extend(
         self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
     return restore_ops
@@ -858,13 +857,21 @@
     # traversals will happen later).
     visit_queue = collections.deque([checkpoint_position])
     restore_ops = []
+    tensor_saveables = {}
+    python_saveables = []
     while visit_queue:
       current_position = visit_queue.popleft()
-      restore_ops.extend(
-          nest.flatten(current_position.trackable  # pylint: disable=protected-access
-                       ._single_restoration_from_checkpoint_position(
-                           checkpoint_position=current_position,
-                           visit_queue=visit_queue)))
+      new_restore_ops, new_tensor_saveables, new_python_saveables = (
+          current_position.trackable  # pylint: disable=protected-access
+          ._single_restoration_from_checkpoint_position(
+              checkpoint_position=current_position,
+              visit_queue=visit_queue))
+      restore_ops.extend(new_restore_ops)
+      tensor_saveables.update(new_tensor_saveables)
+      python_saveables.extend(new_python_saveables)
+    restore_ops.extend(
+        current_position.checkpoint.restore_saveables(
+            tensor_saveables, python_saveables))
     return restore_ops
 
   def _single_restoration_from_checkpoint_position(self, checkpoint_position,
@@ -876,10 +883,13 @@
     # need to actually restore the object. However, we should pass the
     # restoration on to our dependencies.
     if checkpoint.restore_uid > self._self_update_uid:
-      restore_ops = checkpoint_position.restore_ops()
+      restore_ops, tensor_saveables, python_saveables = (
+          checkpoint_position.gather_ops_or_named_saveables())
       self._self_update_uid = checkpoint.restore_uid
     else:
       restore_ops = ()
+      tensor_saveables = {}
+      python_saveables = ()
     for child in checkpoint_position.object_proto.children:
       child_position = CheckpointPosition(
           checkpoint=checkpoint, proto_id=child.node_id)
@@ -896,7 +906,7 @@
           # resolution order (shallowest paths first). The caller is responsible
           # for emptying visit_queue.
           visit_queue.append(child_position)
-    return restore_ops
+    return restore_ops, tensor_saveables, python_saveables
 
   def _gather_saveables_for_checkpoint(self):
     """Returns a dictionary of values to checkpoint with this object.
diff --git a/tensorflow/python/training/tracking/benchmarks_test.py b/tensorflow/python/training/tracking/benchmarks_test.py
index a3cec89c..7514d9f 100644
--- a/tensorflow/python/training/tracking/benchmarks_test.py
+++ b/tensorflow/python/training/tracking/benchmarks_test.py
@@ -21,11 +21,13 @@
 import os
 import time
 
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.tracking import base
@@ -112,6 +114,17 @@
 
     self._run(_create_and_call, 3)
 
+  def benchmark_raw_restore(self):
+    checkpoint_path = _save_checkpoint()
+    all_names, all_dtypes = zip(*pywrap_tensorflow.NewCheckpointReader(
+        checkpoint_path).get_variable_to_dtype_map().items())
+
+    def _call_restore_v2():
+      gen_io_ops.restore_v2(checkpoint_path, all_names, [""] * len(all_names),
+                            all_dtypes)
+
+    self._run(_call_restore_v2, 3)
+
 
 if __name__ == "__main__":
   ops.enable_eager_execution()
diff --git a/tensorflow/python/training/tracking/data_structures.py b/tensorflow/python/training/tracking/data_structures.py
index b3f5046..652e9a9 100644
--- a/tensorflow/python/training/tracking/data_structures.py
+++ b/tensorflow/python/training/tracking/data_structures.py
@@ -35,6 +35,7 @@
 from tensorflow.python.saved_model import revived_types
 from tensorflow.python.training.tracking import base
 from tensorflow.python.training.tracking import layer_utils
+from tensorflow.python.util.compat import collections_abc
 
 
 class NoDependency(object):
@@ -249,7 +250,7 @@
     return self is other
 
 
-class List(TrackableDataStructure, collections.Sequence):
+class List(TrackableDataStructure, collections_abc.Sequence):
   """An append-only sequence type which is trackable.
 
   Maintains checkpoint dependencies on its contents (which must also be
@@ -371,9 +372,11 @@
 # TODO(tomhennigan) Update to collections.UserList?
 # TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
 # Python 3.4 support (may still be tricky).
-class ListWrapper(List, collections.MutableSequence,
-                  # Shadowed, but there for isinstance checks.
-                  list):
+class ListWrapper(
+    List,
+    collections_abc.MutableSequence,
+    # Shadowed, but there for isinstance checks.
+    list):
   """Wraps the built-in `list` to support restore-on-create for variables.
 
   Unlike `List`, this sequence type is mutable in the same ways built-in lists
@@ -579,7 +582,7 @@
     }
 
 
-class Mapping(TrackableDataStructure, collections.Mapping):
+class Mapping(TrackableDataStructure, collections_abc.Mapping):
   """An append-only trackable mapping data structure with string keys.
 
   Maintains checkpoint dependencies on its contents (which must also be
diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py
index d3aaf78..8b0bc6e 100644
--- a/tensorflow/python/training/tracking/tracking.py
+++ b/tensorflow/python/training/tracking/tracking.py
@@ -383,6 +383,8 @@
     if output is None:
       cache[item] = output = f(item)
     return output
+
+  wrapped.cache = cache
   return wrapped
 
 
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 510e618..50acc08 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -35,6 +35,7 @@
 from tensorflow.python.training.ftrl import FtrlOptimizer
 from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
 from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite
+from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite_v1
 from tensorflow.python.training.momentum import MomentumOptimizer
 from tensorflow.python.training.moving_averages import ExponentialMovingAverage
 from tensorflow.python.training.optimizer import Optimizer
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 76ba91d..54d1495 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -38,6 +38,12 @@
 
 from tensorflow.python.util.tf_export import tf_export
 
+try:
+  # This import only works on python 3.3 and above.
+  import collections.abc as collections_abc  # pylint: disable=unused-import
+except ImportError:
+  import collections as collections_abc  # pylint: disable=unused-import
+
 
 def as_bytes(bytes_or_text, encoding='utf-8'):
   """Converts `bytearray`, `bytes`, or unicode python input types to `bytes`.
diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py
index 7ca1e17..4478fcb 100644
--- a/tensorflow/python/util/module_wrapper.py
+++ b/tensorflow/python/util/module_wrapper.py
@@ -39,9 +39,10 @@
 
 
 def _call_location():
-  # We want to get stack frame 2 frames up from current frame,
-  # i.e. above _getattr__ and _call_location calls.
-  stack = tf_stack.extract_stack_file_and_line(max_length=3)
+  # We want to get stack frame 3 frames up from current frame,
+  # i.e. above __getattr__, _tfmw_add_deprecation_warning,
+  # and _call_location calls.
+  stack = tf_stack.extract_stack_file_and_line(max_length=4)
   if not stack:  # should never happen as we're in a function
     return 'UNKNOWN'
   frame = stack[0]
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index d43720f..bd6b791 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -40,6 +40,7 @@
 
 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.compat import collections_abc as _collections_abc
 
 
 _SHALLOW_TREE_HAS_INVALID_KEYS = (
@@ -132,7 +133,14 @@
     # ordered and plain dicts (e.g., flattening a dict but using a
     # corresponding `OrderedDict` to pack it back).
     result = dict(zip(_sorted(instance), args))
-    return type(instance)((key, result[key]) for key in instance)
+    instance_type = type(instance)
+    if instance_type == _collections.defaultdict:
+      d = _collections.defaultdict(instance.default_factory)
+      for key in instance:
+        d[key] = result[key]
+      return d
+    else:
+      return instance_type((key, result[key]) for key in instance)
   elif _is_namedtuple(instance) or _is_attrs(instance):
     return type(instance)(*args)
   elif _is_composite_tensor(instance):
@@ -170,7 +178,7 @@
   Yields:
     The iterable's (key, value) pairs, in order of sorted keys.
   """
-  if isinstance(iterable, _collections.Mapping):
+  if isinstance(iterable, _collections_abc.Mapping):
     # Iterate through dictionaries in a deterministic order by sorting the
     # keys. Notice this means that we ignore the original order of `OrderedDict`
     # instances. This is intentional, to avoid potential bugs caused by mixing
@@ -205,14 +213,14 @@
 
 @tf_export("nest.is_nested")
 def is_nested(seq):
-  """Returns true if its input is a collections.Sequence (except strings).
+  """Returns true if its input is a collections.abc.Sequence (except strings).
 
   Args:
     seq: an input sequence.
 
   Returns:
-    True if the sequence is a not a string and is a collections.Sequence or a
-    dict.
+    True if the sequence is a not a string and is a collections.abc.Sequence
+    or a dict.
   """
   return is_sequence(seq)
 
@@ -344,7 +352,7 @@
     ValueError: If any key and value do not have the same structure layout, or
     if keys are not unique.
   """
-  if not isinstance(dictionary, (dict, _collections.Mapping)):
+  if not isinstance(dictionary, (dict, _collections_abc.Mapping)):
     raise TypeError("input must be a dictionary")
   flat_dictionary = {}
   for i, v in _six.iteritems(dictionary):
@@ -714,8 +722,8 @@
             (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))):
         pass  # Compatibility will be checked below.
 
-      elif not (isinstance(shallow_tree, _collections.Mapping)
-                and isinstance(input_tree, _collections.Mapping)):
+      elif not (isinstance(shallow_tree, _collections_abc.Mapping) and
+                isinstance(input_tree, _collections_abc.Mapping)):
         raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
             input_type=type(input_tree),
             shallow_type=type(shallow_tree)))
@@ -753,7 +761,7 @@
             _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
                 input_size=len(input_tree), shallow_size=len(shallow_tree)))
 
-    if isinstance(shallow_tree, _collections.Mapping):
+    if isinstance(shallow_tree, _collections_abc.Mapping):
       absent_keys = set(shallow_tree) - set(input_tree)
       if absent_keys:
         raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
@@ -1315,5 +1323,5 @@
                   flatten(structure, expand_composites=expand_composites)))
 
 
-_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
-_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence)
+_pywrap_tensorflow.RegisterType("Mapping", _collections_abc.Mapping)
+_pywrap_tensorflow.RegisterType("Sequence", _collections_abc.Sequence)
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 73cb178..9ed84a9 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -32,6 +32,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
 
 try:
   import attr  # pylint:disable=g-import-not-at-top
@@ -39,7 +40,7 @@
   attr = None
 
 
-class _CustomMapping(collections.Mapping):
+class _CustomMapping(collections_abc.Mapping):
 
   def __init__(self, *args, **kwargs):
     self._wrapped = dict(*args, **kwargs)
@@ -436,6 +437,16 @@
 
     self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
 
+    structure3 = collections.defaultdict(list)
+    structure3["a"] = [1, 2, 3, 4]
+    structure3["b"] = [2, 3, 4, 5]
+
+    expected_structure3 = collections.defaultdict(list)
+    expected_structure3["a"] = [2, 3, 4, 5]
+    expected_structure3["b"] = [3, 4, 5, 6]
+    self.assertEqual(expected_structure3,
+                     nest.map_structure(lambda x: x + 1, structure3))
+
     # Empty structures
     self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
     self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py
index d4eef5b..2f913dd 100644
--- a/tensorflow/python/util/object_identity.py
+++ b/tensorflow/python/util/object_identity.py
@@ -17,9 +17,10 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import weakref
 
+from tensorflow.python.util.compat import collections_abc
+
 
 class _ObjectIdentityWrapper(object):
   """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
@@ -39,7 +40,10 @@
   def __eq__(self, other):
     if isinstance(other, _ObjectIdentityWrapper):
       return self._wrapped is other._wrapped  # pylint: disable=protected-access
-    return self._wrapped is other
+    return False
+
+  def __ne__(self, other):
+    return not self.__eq__(other)
 
   def __hash__(self):
     # Wrapper id() is also fine for weakrefs. In fact, we rely on
@@ -47,6 +51,9 @@
     # weakref.ref(a) in _WeakObjectIdentityWrapper.
     return id(self._wrapped)
 
+  def __repr__(self):
+    return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped)
+
 
 class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
 
@@ -58,7 +65,41 @@
     return self._wrapped()
 
 
-class ObjectIdentityDictionary(collections.MutableMapping):
+class Reference(_ObjectIdentityWrapper):
+  """Reference that refers an object.
+
+  ```python
+  x = [1]
+  y = [1]
+
+  x_ref1 = Reference(x)
+  x_ref2 = Reference(x)
+  y_ref2 = Reference(y)
+
+  print(x_ref1 == x_ref2)
+  ==> True
+
+  print(x_ref1 == y)
+  ==> False
+  ```
+  """
+
+  # Disabling super class' unwrapped field.
+  unwrapped = property()
+
+  def deref(self):
+    """Returns the referenced object.
+
+    ```python
+    x_ref = Reference(x)
+    print(x is x_ref.deref())
+    ==> True
+    ```
+    """
+    return self._wrapped
+
+
+class ObjectIdentityDictionary(collections_abc.MutableMapping):
   """A mutable mapping data structure which compares using "is".
 
   This is necessary because we have trackable objects (_ListWrapper) which
@@ -109,12 +150,18 @@
         yield unwrapped
 
 
-class ObjectIdentitySet(collections.MutableSet):
+class ObjectIdentitySet(collections_abc.MutableSet):
   """Like the built-in set, but compares objects with "is"."""
 
   def __init__(self, *args):
     self._storage = set([self._wrap_key(obj) for obj in list(*args)])
 
+  @staticmethod
+  def _from_storage(storage):
+    result = ObjectIdentitySet()
+    result._storage = storage  # pylint: disable=protected-access
+    return result
+
   def _wrap_key(self, key):
     return _ObjectIdentityWrapper(key)
 
@@ -127,6 +174,16 @@
   def add(self, key):
     self._storage.add(self._wrap_key(key))
 
+  def update(self, items):
+    self._storage.update([self._wrap_key(item) for item in items])
+
+  def intersection(self, items):
+    return self._storage.intersection([self._wrap_key(item) for item in items])
+
+  def difference(self, items):
+    return ObjectIdentitySet._from_storage(
+        self._storage.difference([self._wrap_key(item) for item in items]))
+
   def __len__(self):
     return len(self._storage)
 
diff --git a/tensorflow/python/util/object_identity_test.py b/tensorflow/python/util/object_identity_test.py
new file mode 100644
index 0000000..5dc8be1
--- /dev/null
+++ b/tensorflow/python/util/object_identity_test.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for object_identity."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import test
+from tensorflow.python.util import object_identity
+
+
+class ObjectIdentityWrapperTest(test.TestCase):
+
+  def testWrapperNotEqualToWrapped(self):
+    o = object()
+    self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o))
+    self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
+
+
+class ObjectIdentitySetTest(test.TestCase):
+
+  def testDifference(self):
+
+    class Element(object):
+      pass
+
+    a = Element()
+    b = Element()
+    c = Element()
+    set1 = object_identity.ObjectIdentitySet([a, b])
+    set2 = object_identity.ObjectIdentitySet([b, c])
+    diff_set = set1.difference(set2)
+    self.assertIn(a, diff_set)
+    self.assertNotIn(b, diff_set)
+    self.assertNotIn(c, diff_set)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index 3a3af4b..6331b42 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -62,7 +62,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import difflib
 
 import six
@@ -72,6 +71,8 @@
 from google.protobuf import message
 from google.protobuf import text_format
 
+from ..compat import collections_abc
+
 
 def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
                      normalize_numbers=False, msg=None):
@@ -186,7 +187,7 @@
 
 
 def _IsMap(value):
-  return isinstance(value, collections.Mapping)
+  return isinstance(value, collections_abc.Mapping)
 
 
 def _IsRepeatedContainer(value):
diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py
index 2164ba4..d9335da 100644
--- a/tensorflow/python/util/serialization.py
+++ b/tensorflow/python/util/serialization.py
@@ -18,11 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 import numpy as np
 
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util.compat import collections_abc
 
 
 def get_json_type(obj):
@@ -63,7 +62,7 @@
   if isinstance(obj, tensor_shape.TensorShape):
     return obj.as_list()
 
-  if isinstance(obj, collections.Mapping):
+  if isinstance(obj, collections_abc.Mapping):
     return dict(obj)
 
   raise TypeError('Not JSON Serializable:', obj)
diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py
index a6ba59e..fb994cb 100644
--- a/tensorflow/python/util/tf_stack.py
+++ b/tensorflow/python/util/tf_stack.py
@@ -24,6 +24,21 @@
 import sys
 import threading
 
+import six
+
+# Generally such lookups should be done using `threading.local()`. See
+# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
+# explanation of why. However the transform stacks are expected to be empty
+# when a thread is joined, so reusing the key does not introduce a correctness
+# issue. Moreover, get_ident is faster than storing and retrieving a unique
+# key in a thread local store.
+if six.PY3:
+  _get_thread_key = threading.get_ident
+else:
+  import thread  # pylint: disable=g-import-not-at-top
+  _get_thread_key = thread.get_ident
+
+
 # Names for indices into TF traceback tuples.
 TB_FILENAME = 0
 TB_LINENO = 1
@@ -31,48 +46,62 @@
 TB_CODEDICT = 3  # Dictionary of Python interpreter state.
 
 
-stacks = threading.local()
+_source_mapper_stacks = collections.defaultdict(list)
+_source_filter_stacks = collections.defaultdict(list)
 
 
-def _source_mappers():
-  if not hasattr(stacks, 'source_mapper'):
-    stacks.source_mapper = []
-  return stacks.source_mapper
+class StackTraceTransform(object):
+  """Base class for stack trace transformation functions."""
 
-
-def _source_filters():
-  if not hasattr(stacks, 'source_filter'):
-    stacks.source_filter = []
-  return stacks.source_filter
-
-
-class StackTraceMapper(object):
-  """Allows remapping traceback information to different source code."""
+  _stack_dict = None  # Subclasses should override
+  _thread_key = None
 
   def __enter__(self):
-    _source_mappers().append(self)
+    self.reset()
+
+    # Any given instance is assumed to be used by a single thread, which reduces
+    # expensive thread local lookups.
+    if self._thread_key is None:
+      self._thread_key = _get_thread_key()
+    else:
+      assert self._thread_key == _get_thread_key(), 'Shared across threads?'
+
+    stack = self._stack_dict[self._thread_key]
+    if stack:
+      self.parent = stack[-1]
+    else:
+      self.parent = None
+    stack.append(self)
     return self
 
   def __exit__(self, unused_type, unused_value, unused_traceback):
-    assert _source_mappers()[-1] is self, 'Concurrent access?'
-    _source_mappers().pop()
+    top = self._stack_dict[self._thread_key].pop()
+    assert top is self, 'Concurrent access?'
 
-  def map(self, filename, lineno, name):
+  def reset(self):
+    pass
+
+
+class StackTraceMapper(StackTraceTransform):
+  """Allows remapping traceback information to different source code."""
+  _stack_dict = _source_mapper_stacks
+
+  def reset(self):
+    self._effective_source_map = None
+
+  def get_effective_source_map(self):
+    """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
     raise NotImplementedError('subclasses need to override this')
 
 
-class StackTraceFilter(object):
+class StackTraceFilter(StackTraceTransform):
   """Allows filtering traceback information by removing superfluous frames."""
+  _stack_dict = _source_filter_stacks
 
-  def __enter__(self):
-    _source_filters().append(self)
-    return self
+  def reset(self):
+    self._filtered_filenames = None
 
-  def __exit__(self, unused_type, unused_value, unused_traceback):
-    assert _source_filters()[-1] is self, 'Concurrent access?'
-    _source_filters().pop()
-
-  def filter(self, filename, lineno, name):
+  def get_filtered_filenames(self):
     raise NotImplementedError('subclasses need to override this')
 
 
@@ -97,9 +126,16 @@
       del f
       del outer_f
 
-  def should_remove(self, filename, lineno, name):
-    del lineno, name
-    return filename == self._filename
+  def get_filtered_filenames(self):
+    if self._filtered_filenames is None:
+      self._filtered_filenames = frozenset((self._filename,))
+      if self.parent is not None:
+        self._filtered_filenames |= self.parent.get_filtered_filenames()
+    return self._filtered_filenames
+
+
+EMPTY_FROZEN_MAP = {}
+EMPTY_FROZEN_SET = frozenset()
 
 
 def extract_stack(limit=None):
@@ -127,6 +163,21 @@
     f = sys.exc_info()[2].tb_frame.f_back
   ret = []
   length = 0
+
+  thread_key = _get_thread_key()
+  source_mappers = _source_mapper_stacks[thread_key]
+  # TODO(mdan): Use sentinels instead.
+  if source_mappers:
+    source_map = source_mappers[-1].get_effective_source_map()
+  else:
+    source_map = EMPTY_FROZEN_MAP
+
+  source_filters = _source_filter_stacks[thread_key]
+  if source_filters:
+    filtered_filenames = source_filters[-1].get_filtered_filenames()
+  else:
+    filtered_filenames = EMPTY_FROZEN_SET
+
   while f is not None and (limit is None or length < limit):
     lineno = f.f_lineno
     co = f.f_code
@@ -135,22 +186,17 @@
     frame_globals = f.f_globals
     func_start_lineno = co.co_firstlineno
 
-    for mapper in _source_mappers():
-      # TODO(mdan): Show some indication that the frame was translated.
-      filename, lineno, name = mapper.map(filename, lineno, name)
+    # TODO(mdan): Show some indication that the frame was translated.
+    filename, lineno, name = source_map.get(
+        (filename, lineno), (filename, lineno, name))
 
-    keep = True
-    if ret:  # Never filter the innermost frame.
-      keep = not any(
-          f.should_remove(filename, lineno, name) for f in _source_filters())
-    if keep:
+    # Note: we never filter the innermost frame.
+    if not (ret and filename in filtered_filenames):
       ret.append((filename, lineno, name, frame_globals, func_start_lineno))
       length += 1
 
     f = f.f_back
 
-  # TODO(mdan): Also add a truncation mechanism.
-
   ret.reverse()
   return ret
 
diff --git a/tensorflow/python/util/traceme.i b/tensorflow/python/util/traceme.i
new file mode 100644
index 0000000..1fd0657
--- /dev/null
+++ b/tensorflow/python/util/traceme.i
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/lib/core/strings.i"
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/profiler/internal/python_traceme.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::profiler;
+%unignore tensorflow::profiler::PythonTraceMe;
+%unignore tensorflow::profiler::PythonTraceMe::PythonTraceMe;
+%unignore tensorflow::profiler::PythonTraceMe::Enter;
+%unignore tensorflow::profiler::PythonTraceMe::Exit;
+%unignore tensorflow::profiler::PythonTraceMe::~PythonTraceMe;
+
+%include "tensorflow/core/profiler/internal/python_traceme.h"
+
+%unignoreall
\ No newline at end of file
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 654e84b2..817a43c 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -5,8 +5,8 @@
 # do not link against restricted binary blobs.
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 load("//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends")
 
 package(
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index f25ed70..faf4a13 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -45,6 +45,7 @@
 
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
 #include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
 
 namespace Eigen {
@@ -1382,6 +1383,8 @@
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
 
+  virtual port::Status GetVersion(string *version) = 0;
+
  protected:
   BlasSupport() {}
 
@@ -2192,7 +2195,8 @@
                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                   uint64 n, std::complex<double> alpha,                        \
                   const DeviceMemory<std::complex<double>> &a, int lda,        \
-                  DeviceMemory<std::complex<double>> *b, int ldb) override;
+                  DeviceMemory<std::complex<double>> *b, int ldb) override;    \
+  port::Status GetVersion(string *version) override;
 
 }  // namespace blas
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index eec6195..27b1364 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -11,12 +11,11 @@
 )
 load("//tensorflow:tensorflow.bzl", "tf_copts")
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
-    "if_static",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
 
@@ -139,8 +138,8 @@
         "//tensorflow/stream_executor/platform:dso_loader",
     ] + tf_additional_cuda_driver_deps()) + select({
         # include dynamic loading implementation only when if_cuda_is_configured and build dynamically
-        "//tensorflow:using_cuda_nvcc_with_dynamic_build": ["cudart_stub"],
-        "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub"],
+        "//tensorflow:build_oss_using_cuda_nvcc": ["cudart_stub"],
+        "//tensorflow:build_oss_using_cuda_clang": ["cudart_stub"],
         "//conditions:default": ["//tensorflow/core:cuda"],
     }) + [
         "@com_google_absl//absl/base:core_headers",
@@ -154,20 +153,20 @@
     name = "cudart_stub",
     srcs = select({
         # include dynamic loading implementation only when if_cuda_is_configured and build dynamically
-        "//tensorflow:using_cuda_nvcc_with_dynamic_build": ["cudart_stub.cc"],
-        "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub.cc"],
+        "//tensorflow:build_oss_using_cuda_nvcc": ["cudart_stub.cc"],
+        "//tensorflow:build_oss_using_cuda_clang": ["cudart_stub.cc"],
         "//conditions:default": [],
     }),
     textual_hdrs = glob(["cuda_runtime_*.inc"]),
     visibility = ["//visibility:public"],
     deps = select({
-        "//tensorflow:using_cuda_nvcc_with_dynamic_build": [
+        "//tensorflow:build_oss_using_cuda_nvcc": [
             ":cuda_stub",
             "@local_config_cuda//cuda:cuda_headers",
             "//tensorflow/stream_executor/lib",
             "//tensorflow/stream_executor/platform:dso_loader",
         ],
-        "//tensorflow:using_cuda_clang_with_dynamic_build": [
+        "//tensorflow:build_oss_using_cuda_clang": [
             ":cuda_stub",
             "@local_config_cuda//cuda:cuda_headers",
             "//tensorflow/stream_executor/lib",
@@ -232,11 +231,11 @@
 
 alias(
     name = "cublas_lib",
-    actual = if_static(
-        "@local_config_cuda//cuda:cublas",
-        ":cublas_stub",
-    ),
-    visibility = ["//visibility:private"],
+    actual = select({
+        "//tensorflow:oss": ":cublas_stub",
+        "//conditions:default": "@local_config_cuda//cuda:cublas",
+    }),
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
@@ -288,11 +287,11 @@
 
 alias(
     name = "cufft_lib",
-    actual = if_static(
-        "@local_config_cuda//cuda:cufft",
-        ":cufft_stub",
-    ),
-    visibility = ["//visibility:private"],
+    actual = select({
+        "//tensorflow:oss": ":cufft_stub",
+        "//conditions:default": "@local_config_cuda//cuda:cufft",
+    }),
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
@@ -333,11 +332,11 @@
 
 alias(
     name = "cudnn_lib",
-    actual = if_static(
-        "@local_config_cuda//cuda:cudnn",
-        ":cudnn_stub",
-    ),
-    visibility = ["//visibility:private"],
+    actual = select({
+        "//tensorflow:oss": ":cudnn_stub",
+        "//conditions:default": "@local_config_cuda//cuda:cudnn",
+    }),
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
@@ -386,11 +385,11 @@
 
 alias(
     name = "curand_lib",
-    actual = if_static(
-        "@local_config_cuda//cuda:curand",
-        ":curand_stub",
-    ),
-    visibility = ["//visibility:private"],
+    actual = select({
+        "//tensorflow:oss": ":curand_stub",
+        "//conditions:default": "@local_config_cuda//cuda:curand",
+    }),
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
@@ -443,6 +442,15 @@
     ]),
 )
 
+alias(
+    name = "cusolver_lib",
+    actual = select({
+        "//tensorflow:oss": ":cusolver_stub",
+        "//conditions:default": "@local_config_cuda//cuda:cusolver",
+    }),
+    visibility = ["//visibility:public"],
+)
+
 cc_library(
     name = "cusparse_stub",
     srcs = if_cuda_is_configured(["cusparse_stub.cc"]),
@@ -454,6 +462,15 @@
     ]),
 )
 
+alias(
+    name = "cusparse_lib",
+    actual = select({
+        "//tensorflow:oss": ":cusparse_stub",
+        "//conditions:default": "@local_config_cuda//cuda:cusparse",
+    }),
+    visibility = ["//visibility:public"],
+)
+
 cc_library(
     name = "cuda_kernel",
     srcs = if_cuda_is_configured(["cuda_kernel.cc"]),
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 421b9b4..aceec62 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2179,11 +2179,11 @@
   // whether a scratch allocator was passed.
   if (scratch_allocator != nullptr) {
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     a = DeviceMemory<CUDA_T *>(a_bytes);
     b = DeviceMemory<CUDA_T *>(b_bytes);
     c = DeviceMemory<CUDA_T *>(c_bytes);
@@ -2794,6 +2794,18 @@
                         GpuComplex(GpuMemoryMutable(b)), ldb);
 }
 
+port::Status CUDABlas::GetVersion(string *version) {
+  absl::MutexLock lock(&mu_);
+
+  int v;
+  auto status = cublasGetVersion(blas_, &v);
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    return port::InternalError(ToString(status));
+  }
+  *version = std::to_string(v);
+  return port::Status::OK();
+}
+
 }  // namespace gpu
 
 void initialize_cublas() {
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 4e900b4..ab1dbd4 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -952,8 +952,8 @@
       size_t state_sizes_in_bytes = 0;
       RETURN_IF_CUDNN_ERROR(
           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
-      SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
-                                            nullptr, state_sizes_in_bytes));
+      SE_ASSIGN_OR_RETURN(state_memory,
+                          state_allocator->AllocateBytes(state_sizes_in_bytes));
     }
     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
@@ -1043,7 +1043,7 @@
       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
       cudnnDataType_t data_type, cudnnDataType_t compute_type,
       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
-      ScratchAllocator* state_allocator) {
+      ScratchAllocator* state_allocator, bool use_padded_io) {
     SE_ASSIGN_OR_RETURN(
         CudnnDropoutDescriptor dropout_desc,
         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
@@ -1079,8 +1079,10 @@
     // But in the future if these APIs are used to process full length arrays,
     // we need to distinguish when to set it.
 #if CUDNN_VERSION >= 7201
-    RETURN_IF_CUDNN_ERROR(
-        cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
+    if (use_padded_io) {
+      RETURN_IF_CUDNN_ERROR(
+          cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
+    }
 #endif
 
     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
@@ -1603,7 +1605,7 @@
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 
 #if CUDNN_VERSION >= 7402
@@ -1628,7 +1630,7 @@
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 
 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
@@ -1652,7 +1654,7 @@
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 #endif
 
@@ -1701,9 +1703,8 @@
         /*sizeInBytes=*/&reserve_space_size_in_bytes));
 
     if (reserve_space_size_in_bytes > 0) {
-      SE_ASSIGN_OR_RETURN(reserve_space,
-                          reserve_space_allocator->AllocateBytes(
-                              stream, reserve_space_size_in_bytes));
+      SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
+                                             reserve_space_size_in_bytes));
     }
   }
 
@@ -1974,7 +1975,8 @@
     int batch_size, dnn::RnnInputMode input_mode,
     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
-    float dropout, uint64 seed, ScratchAllocator* state_allocator) {
+    float dropout, uint64 seed, ScratchAllocator* state_allocator,
+    bool use_padded_io) {
   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
   // not enqueueing anything into a stream, we pass in the null stream.
   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
@@ -1985,7 +1987,7 @@
           ToCudnnRnnInputMode(input_mode),
           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
-          algorithm_config, dropout, seed, state_allocator));
+          algorithm_config, dropout, seed, state_allocator, use_padded_io));
   return std::unique_ptr<dnn::RnnDescriptor>(
       new CudnnRnnDescriptor(std::move(rnn_desc)));
 }
@@ -2401,7 +2403,7 @@
                         "No scratch allocator provided");
   }
 
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 
 port::StatusOr<DeviceMemory<uint8>>
@@ -2446,7 +2448,7 @@
                         "No scratch allocator provided");
   }
 
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 
 port::StatusOr<DeviceMemory<uint8>>
@@ -2491,7 +2493,7 @@
                         "No scratch allocator provided");
   }
 
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 
 static bool TensorOpMathAvailable(int cc_major) {
@@ -2512,7 +2514,7 @@
     bool specify_workspace_limit = scratch_allocator != nullptr;
     auto memory_limit_bytes =
         specify_workspace_limit
-            ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
             : 0ll;
     SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
                         GetCudnnConvolutionForwardAlgo(
@@ -2540,8 +2542,9 @@
   if (!algo_desc.has_value()) {
     return port::Status(
         port::error::INVALID_ARGUMENT,
-        "The primary convolution algorithm failed memory allocation, "
-        "while a secondary algorithm is not provided.");
+        absl::StrCat("The primary convolution algorithm failed, ",
+                     "while a secondary algorithm is not provided. ",
+                     "Returned status: ", scratch_or.status().ToString()));
   }
 
   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
@@ -2564,7 +2567,7 @@
     bool specify_workspace_limit = scratch_allocator != nullptr;
     auto memory_limit_bytes =
         specify_workspace_limit
-            ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
             : 0ll;
     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
                         GetCudnnConvolutionBackwardDataAlgo(
@@ -2616,7 +2619,7 @@
     bool specify_workspace_limit = scratch_allocator != nullptr;
     auto memory_limit_bytes =
         specify_workspace_limit
-            ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
             : 0ll;
     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
                         GetCudnnConvolutionBackwardFilterAlgo(
@@ -3469,9 +3472,8 @@
               /*activationDesc=*/activation_desc.handle(),
               /*xDesc=*/x_descriptor.handle(),
               /*sizeInBytes=*/&reserve_space_size_in_bytes));
-      SE_ASSIGN_OR_RETURN(reserve_space,
-                          reserve_space_allocator->AllocateBytes(
-                              stream, reserve_space_size_in_bytes));
+      SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
+                                             reserve_space_size_in_bytes));
     }
   }
 #endif
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index e3742c0..482e861 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -51,7 +51,8 @@
       int batch_size, dnn::RnnInputMode input_mode,
       dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
       dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
-      float dropout, uint64 seed, ScratchAllocator* state_allocator) override;
+      float dropout, uint64 seed, ScratchAllocator* state_allocator,
+      bool use_padded_io) override;
 
   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index d323b74..f7a69fc 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -539,7 +539,7 @@
   return port::Status::OK();
 }
 
-/* static */ bool GpuDriver::LaunchKernel(
+/* static */ port::Status GpuDriver::LaunchKernel(
     GpuContext* context, CUfunction function, unsigned int grid_dim_x,
     unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x,
     unsigned int block_dim_y, unsigned int block_dim_z,
@@ -554,12 +554,12 @@
                                 block_dim_x, block_dim_y, block_dim_z,
                                 shared_mem_bytes, stream, kernel_params, extra);
   if (res != CUDA_SUCCESS) {
-    LOG(ERROR) << "failed to launch CUDA kernel: " << function
-               << "; result: " << ToString(res);
-    return false;
+    return port::InternalError(absl::StrCat(
+        "Failed to launch CUDA kernel: ", reinterpret_cast<uint64>(function),
+        "; result: ", ToString(res)));
   }
   VLOG(2) << "successfully launched kernel";
-  return true;
+  return port::Status::OK();
 }
 
 /* static */ port::Status GpuDriver::LoadCubin(GpuContext* context,
@@ -575,11 +575,11 @@
   return port::Status::OK();
 }
 
-/* static */ bool GpuDriver::LoadPtx(GpuContext* context,
-                                     const char* ptx_contents,
-                                     CUmodule* module) {
+/* static */ port::Status GpuDriver::LoadPtx(GpuContext* context,
+                                             const char* ptx_contents,
+                                             CUmodule* module) {
   absl::Notification notification;
-  bool ret = true;
+  port::Status ret = port::Status::OK();
   GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret,
                                  &notification]() {
     ScopedActivateContext activation(context);
@@ -629,7 +629,8 @@
                                               : 0] = '\0';
       LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes
                  << " bytes): " << error_log_buffer.data();
-      ret = false;
+      ret = port::InternalError(
+          absl::StrCat("Failed to load PTX text as a module: ", ToString(res)));
       notification.Notify();
     }
 
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc
index 3bf2f5b..79047d9 100644
--- a/tensorflow/stream_executor/cuda/cuda_fft.cc
+++ b/tensorflow/stream_executor/cuda/cuda_fft.cc
@@ -244,8 +244,7 @@
 port::Status CUDAFftPlan::UpdateScratchAllocator(
     Stream *stream, ScratchAllocator *scratch_allocator) {
   if (scratch_size_bytes_ != 0) {
-    auto allocated =
-        scratch_allocator->AllocateBytes(stream, scratch_size_bytes_);
+    auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
     if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "failed to allocate work area.";
       return allocated.status();
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index a9289e3..38d3dc9 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -217,16 +217,13 @@
   return exe_path;
 }
 
-bool GpuExecutor::LoadModuleFromCuBin(const char* cubin, CUmodule* module) {
+port::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin,
+                                              CUmodule* module) {
   uint64_t module_refcount;
   std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin];
 
   if (*module == nullptr) {
-    auto load_status = GpuDriver::LoadCubin(context_, cubin, module);
-    if (!load_status.ok()) {
-      LOG(ERROR) << "failed to load CUBIN: " << load_status;
-      return false;
-    }
+    TF_RETURN_IF_ERROR(GpuDriver::LoadCubin(context_, cubin, module));
     module_refcount = 1;
     VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin)
             << " as module " << *module;
@@ -236,17 +233,15 @@
             << " is already loaded as module " << *module;
   }
   gpu_binary_to_module_[cubin] = {*module, module_refcount};
-  return true;
+  return port::Status::OK();
 }
 
-bool GpuExecutor::LoadModuleFromPtx(const char* ptx, CUmodule* module) {
+port::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, CUmodule* module) {
   uint64_t module_refcount;
   std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx];
 
   if (*module == nullptr) {
-    if (!GpuDriver::LoadPtx(context_, ptx, module)) {
-      return false;
-    }
+    TF_RETURN_IF_ERROR(GpuDriver::LoadPtx(context_, ptx, module));
     VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx) << " as module "
             << *module;
     module_refcount = 1;
@@ -256,7 +251,7 @@
             << " is already loaded as module " << module;
   }
   gpu_binary_to_module_[ptx] = {*module, module_refcount};
-  return true;
+  return port::Status::OK();
 }
 
 bool GpuExecutor::LoadModuleFromHsaco(const char* hsaco, CUmodule* module) {
@@ -264,8 +259,8 @@
   return false;
 }
 
-bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
-                            KernelBase* kernel) {
+port::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
+                                    KernelBase* kernel) {
   GpuKernel* cuda_kernel = AsGpuKernel(kernel);
   CUmodule module;
   const string *kernelname;
@@ -276,15 +271,13 @@
     absl::MutexLock lock{&in_memory_modules_mu_};
     kernelname = &spec.cuda_cubin_in_memory().kernelname();
     const char *cubin = spec.cuda_cubin_in_memory().bytes();
-    if (!LoadModuleFromCuBin(cubin, &module)) {
-      return false;
-    }
+    TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module));
     kernel_to_gpu_binary_[kernel] = cubin;
   } else if (spec.has_cuda_ptx_in_memory()) {
     kernelname = &spec.cuda_ptx_in_memory().kernelname();
 
     if (cc_major_ == 0 && cc_minor_ == 0) {
-      return false;
+      return port::InternalError("Compute capability not set");
     }
 
     const char *ptx = spec.cuda_ptx_in_memory().text(cc_major_, cc_minor_);
@@ -292,23 +285,19 @@
       ptx = spec.cuda_ptx_in_memory().default_text();
     }
     if (ptx == nullptr) {
-      LOG(FATAL) << "loader spec has no ptx for kernel " << *kernelname;
-      return false;
+      LOG(FATAL) << "Loader spec has no ptx for kernel " << *kernelname;
     }
 
     absl::MutexLock lock{&in_memory_modules_mu_};
-    if (!LoadModuleFromPtx(ptx, &module)) {
-      return false;
-    }
+    TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module));
     kernel_to_gpu_binary_[kernel] = ptx;
   } else {
-    LOG(WARNING) << "no method of loading CUDA kernel provided";
-    return false;
+    return port::InternalError("No method of loading CUDA kernel provided");
   }
   VLOG(2) << "getting function " << *kernelname << " from module " << module;
   if (!GpuDriver::GetModuleFunction(context_, module, kernelname->c_str(),
                                     cuda_kernel->gpu_function_ptr())) {
-    return false;
+    return port::InternalError("Could not find the corresponding function");
   }
 
   // We have to trust the kernel loader spec arity because there doesn't appear
@@ -321,7 +310,7 @@
   }
   kernel->set_metadata(kernel_metadata);
   kernel->set_name(*kernelname);
-  return true;
+  return port::Status::OK();
 }
 
 bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) {
@@ -357,40 +346,36 @@
   kernel_to_gpu_binary_.erase(gpu_binary_it);
 }
 
-bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
-                             ModuleHandle* module_handle) {
+port::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
+                                     ModuleHandle* module_handle) {
   // In GpuExecutor we store the pointer to the GPU binary (PTX or CUBIN) as
   // ModuleHandle::id().
   CUmodule cu_module;
   if (spec.has_cuda_cubin_in_memory()) {
     absl::MutexLock lock{&in_memory_modules_mu_};
-    if (!LoadModuleFromCuBin(
-            reinterpret_cast<const char *>(spec.cuda_cubin_in_memory().data()),
-            &cu_module)) {
-      return false;
-    }
+    TF_RETURN_IF_ERROR(LoadModuleFromCuBin(
+        reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()),
+        &cu_module));
     *module_handle = ModuleHandle(const_cast<void *>(
         static_cast<const void *>(spec.cuda_cubin_in_memory().data())));
-    return true;
+    return port::Status::OK();
   } else if (spec.has_cuda_ptx_in_memory()) {
     if (cc_major_ == 0 && cc_minor_ == 0) {
-      return false;
+      return port::InternalError("Compute capability not set");
     }
 
     if (!spec.cuda_ptx_in_memory()) {
-      return false;
+      return port::InternalError("PTX not found in spec");
     }
 
     absl::MutexLock lock{&in_memory_modules_mu_};
-    if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) {
-      return false;
-    }
+    TF_RETURN_IF_ERROR(
+        LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module));
     *module_handle = ModuleHandle(const_cast<void *>(
         static_cast<const void *>(spec.cuda_ptx_in_memory())));
-    return true;
+    return port::Status::OK();
   }
-  LOG(WARNING) << "no method of loading CUDA module provided";
-  return false;
+  return port::InternalError("No method of loading CUDA module provided");
 }
 
 bool GpuExecutor::UnloadModule(ModuleHandle module_handle) {
@@ -417,9 +402,10 @@
   return true;
 }
 
-bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                         const BlockDim& block_dims, const KernelBase& kernel,
-                         const KernelArgsArrayBase& args) {
+port::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
+                                 const BlockDim& block_dims,
+                                 const KernelBase& kernel,
+                                 const KernelArgsArrayBase& args) {
   CHECK_EQ(kernel.Arity(), args.number_of_arguments());
   CUstream custream = AsGpuStreamValue(stream);
   const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
@@ -445,19 +431,10 @@
 
   void **kernel_params = const_cast<void **>(args.argument_addresses().data());
 
-  if (!GpuDriver::LaunchKernel(context_, cufunc, block_dims.x, block_dims.y,
-                               block_dims.z, thread_dims.x, thread_dims.y,
-                               thread_dims.z, args.number_of_shared_bytes(),
-                               custream, kernel_params,
-                               nullptr /* = extra */)) {
-    LOG(ERROR) << "failed to launch CUDA kernel " << kernel.name() << " with "
-               << args.number_of_arguments()
-               << " args; thread dim: " << thread_dims.ToString()
-               << "; block dim: " << block_dims.ToString();
-    return false;
-  }
-
-  return true;
+  return GpuDriver::LaunchKernel(
+      context_, cufunc, block_dims.x, block_dims.y, block_dims.z, thread_dims.x,
+      thread_dims.y, thread_dims.z, args.number_of_shared_bytes(), custream,
+      kernel_params, nullptr /* = extra */);
 }
 
 // This is a non-essential operation; if there's a failure, proceed without
@@ -907,7 +884,7 @@
     }
   }
 
-  LOG(INFO) << "Falied to find symbol in any modules: " << symbol_name;
+  LOG(INFO) << "Failed to find symbol in any modules: " << symbol_name;
   return false;
 }
 
diff --git a/tensorflow/stream_executor/cuda/cudnn_7_6.inc b/tensorflow/stream_executor/cuda/cudnn_7_6.inc
new file mode 100644
index 0000000..030f3ed
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cudnn_7_6.inc
@@ -0,0 +1,3162 @@
+// Auto-generated, do not edit.
+
+extern "C" {
+
+size_t CUDNNWINAPI
+cudnnGetVersion(void) {
+  using FuncPtr = size_t (CUDNNWINAPI *)();
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion");
+  if (!func_ptr) return 0;
+  return func_ptr();
+}
+
+size_t CUDNNWINAPI
+cudnnGetCudartVersion(void) {
+  using FuncPtr = size_t (CUDNNWINAPI *)();
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion");
+  if (!func_ptr) return 0;
+  return func_ptr();
+}
+
+const char *CUDNNWINAPI
+cudnnGetErrorString(cudnnStatus_t status) {
+  using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString");
+  if (!func_ptr) return "cudnnGetErrorString symbol not found.";
+  return func_ptr(status);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rstatus, mode, tag);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetProperty(libraryPropertyType type, int *value) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(type, value);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreate(cudnnHandle_t *handle) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroy(cudnnHandle_t handle) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, streamId);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, streamId);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc,
+                           cudnnTensorFormat_t format,
+                           cudnnDataType_t dataType, /* image data type */
+                           int n,                    /* number of inputs (batch size) */
+                           int c,                    /* number of input feature maps */
+                           int h,                    /* height of input section */
+                           int w) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, format, dataType, n, c, h, w);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
+                             cudnnDataType_t dataType, /* image data type */
+                             int n,                    /* number of inputs (batch size) */
+                             int c,                    /* number of input feature maps */
+                             int h,                    /* height of input section */
+                             int w,                    /* width of input section */
+                             int nStride,
+                             int cStride,
+                             int hStride,
+                             int wStride) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc,
+                           cudnnDataType_t *dataType, /* image data type */
+                           int *n,                    /* number of inputs (batch size) */
+                           int *c,                    /* number of input feature maps  */
+                           int *h,                    /* height of input section */
+                           int *w,                    /* width of input section */
+                           int *nStride,
+                           int *cStride,
+                           int *hStride,
+                           int *wStride) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc,
+                           cudnnDataType_t dataType,
+                           int nbDims,
+                           const int dimA[],
+                           const int strideA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
+                             cudnnTensorFormat_t format,
+                             cudnnDataType_t dataType,
+                             int nbDims,
+                             const int dimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, format, dataType, nbDims, dimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc,
+                           int nbDimsRequested,
+                           cudnnDataType_t *dataType,
+                           int *nbDims,
+                           int dimA[],
+                           int strideA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc, size);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(tensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc,
+                       const cudnnTensorDescriptor_t srcDesc,
+                       cudnnTensorDescriptor_t destDesc,
+                       size_t *destSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorTransformDescriptor_t, const cudnnTensorDescriptor_t, cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnInitTransformDest");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(transformDesc, srcDesc, destDesc, destSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorTransformDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(transformDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
+                                  const uint32_t nbDims,
+                                  const cudnnTensorFormat_t destFormat,
+                                  const int32_t padBeforeA[],
+                                  const int32_t padAfterA[],
+                                  const uint32_t foldA[],
+                                  const cudnnFoldingDirection_t direction) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t, const uint32_t, const cudnnTensorFormat_t, const int32_t [], const int32_t [], const uint32_t [], const cudnnFoldingDirection_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorTransformDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(transformDesc, nbDims, destFormat, padBeforeA, padAfterA, foldA, direction);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
+                                  uint32_t nbDimsRequested,
+                                  cudnnTensorFormat_t *destFormat,
+                                  int32_t padBeforeA[],
+                                  int32_t padAfterA[],
+                                  uint32_t foldA[],
+                                  cudnnFoldingDirection_t *direction) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t, uint32_t, cudnnTensorFormat_t *, int32_t [], int32_t [], uint32_t [], cudnnFoldingDirection_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorTransformDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(transformDesc, nbDimsRequested, destFormat, padBeforeA, padAfterA, foldA, direction);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorTransformDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(transformDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnTransformTensor(cudnnHandle_t handle,
+                     const void *alpha,
+                     const cudnnTensorDescriptor_t xDesc,
+                     const void *x,
+                     const void *beta,
+                     const cudnnTensorDescriptor_t yDesc,
+                     void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnTransformTensorEx(cudnnHandle_t handle,
+                       const cudnnTensorTransformDescriptor_t transDesc,
+                       const void *alpha,
+                       const cudnnTensorDescriptor_t srcDesc,
+                       const void *srcData,
+                       const void *beta,
+                       const cudnnTensorDescriptor_t destDesc,
+                       void *destData) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensorEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, destData);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
+                                          const cudnnFilterDescriptor_t filterDesc,
+                                          const cudnnTensorDescriptor_t diffDesc,
+                                          const cudnnConvolutionDescriptor_t convDesc,
+                                          const cudnnTensorDescriptor_t gradDesc,
+                                          const cudnnTensorFormat_t transformFormat,
+                                          cudnnFilterDescriptor_t foldedFilterDesc,
+                                          cudnnTensorDescriptor_t paddedDiffDesc,
+                                          cudnnConvolutionDescriptor_t foldedConvDesc,
+                                          cudnnTensorDescriptor_t foldedGradDesc,
+                                          cudnnTensorTransformDescriptor_t filterFoldTransDesc,
+                                          cudnnTensorTransformDescriptor_t diffPadTransDesc,
+                                          cudnnTensorTransformDescriptor_t gradFoldTransDesc,
+                                          cudnnTensorTransformDescriptor_t gradUnfoldTransDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorFormat_t, cudnnFilterDescriptor_t, cudnnTensorDescriptor_t, cudnnConvolutionDescriptor_t, cudnnTensorDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFoldedConvBackwardDataDescriptors");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, transformFormat, foldedFilterDesc, paddedDiffDesc, foldedConvDesc, foldedGradDesc, filterFoldTransDesc, diffPadTransDesc, gradFoldTransDesc, gradUnfoldTransDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnAddTensor(cudnnHandle_t handle,
+               const void *alpha,
+               const cudnnTensorDescriptor_t aDesc,
+               const void *A,
+               const void *beta,
+               const cudnnTensorDescriptor_t cDesc,
+               void *C) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(opTensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc,
+                           cudnnOpTensorOp_t opTensorOp,
+                           cudnnDataType_t opTensorCompType,
+                           cudnnNanPropagation_t opTensorNanOpt) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc,
+                           cudnnOpTensorOp_t *opTensorOp,
+                           cudnnDataType_t *opTensorCompType,
+                           cudnnNanPropagation_t *opTensorNanOpt) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(opTensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnOpTensor(cudnnHandle_t handle,
+              const cudnnOpTensorDescriptor_t opTensorDesc,
+              const void *alpha1,
+              const cudnnTensorDescriptor_t aDesc,
+              const void *A,
+              const void *alpha2,
+              const cudnnTensorDescriptor_t bDesc,
+              const void *B,
+              const void *beta,
+              const cudnnTensorDescriptor_t cDesc,
+              void *C) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(reduceTensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc,
+                               cudnnReduceTensorOp_t reduceTensorOp,
+                               cudnnDataType_t reduceTensorCompType,
+                               cudnnNanPropagation_t reduceTensorNanOpt,
+                               cudnnReduceTensorIndices_t reduceTensorIndices,
+                               cudnnIndicesType_t reduceTensorIndicesType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+                               cudnnReduceTensorOp_t *reduceTensorOp,
+                               cudnnDataType_t *reduceTensorCompType,
+                               cudnnNanPropagation_t *reduceTensorNanOpt,
+                               cudnnReduceTensorIndices_t *reduceTensorIndices,
+                               cudnnIndicesType_t *reduceTensorIndicesType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(reduceTensorDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetReductionIndicesSize(cudnnHandle_t handle,
+                             const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+                             const cudnnTensorDescriptor_t aDesc,
+                             const cudnnTensorDescriptor_t cDesc,
+                             size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetReductionWorkspaceSize(cudnnHandle_t handle,
+                               const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+                               const cudnnTensorDescriptor_t aDesc,
+                               const cudnnTensorDescriptor_t cDesc,
+                               size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnReduceTensor(cudnnHandle_t handle,
+                  const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+                  void *indices,
+                  size_t indicesSizeInBytes,
+                  void *workspace,
+                  size_t workspaceSizeInBytes,
+                  const void *alpha,
+                  const cudnnTensorDescriptor_t aDesc,
+                  const void *A,
+                  const void *beta,
+                  const cudnnTensorDescriptor_t cDesc,
+                  void *C) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, yDesc, y, valuePtr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, yDesc, y, alpha);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc,
+                           cudnnDataType_t dataType, /* image data type */
+                           cudnnTensorFormat_t format,
+                           int k,  /* number of output feature maps */
+                           int c,  /* number of input feature maps */
+                           int h,  /* height of each input filter */
+                           int w) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc, dataType, format, k, c, h, w);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc,
+                           cudnnDataType_t *dataType, /* image data type */
+                           cudnnTensorFormat_t *format,
+                           int *k,  /* number of output feature maps */
+                           int *c,  /* number of input feature maps */
+                           int *h,  /* height of each input filter */
+                           int *w) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc, dataType, format, k, c, h, w);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
+                           cudnnDataType_t dataType, /* image data type */
+                           cudnnTensorFormat_t format,
+                           int nbDims,
+                           const int filterDimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc, dataType, format, nbDims, filterDimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc,
+                           int nbDimsRequested,
+                           cudnnDataType_t *dataType, /* image data type */
+                           cudnnTensorFormat_t *format,
+                           int *nbDims,
+                           int filterDimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterSizeInBytes");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc, size);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnTransformFilter(cudnnHandle_t handle,
+                     const cudnnTensorTransformDescriptor_t transDesc,
+                     const void *alpha,
+                     const cudnnFilterDescriptor_t srcDesc,
+                     const void *srcData,
+                     const void *beta,
+                     const cudnnFilterDescriptor_t destDesc,
+                     void *destData) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const void *, const cudnnFilterDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformFilter");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, destData);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(filterDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnReorderFilterAndBias(cudnnHandle_t handle,
+                          const cudnnFilterDescriptor_t filterDesc,
+                          cudnnReorderType_t reorderType,
+                          const void *filterData,
+                          void *reorderedFilterData,
+                          int reorderBias,
+                          const void *biasData,
+                          void *reorderedBiasData) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, cudnnReorderType_t, const void *, void *, int, const void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReorderFilterAndBias");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, filterDesc, reorderType, filterData, reorderedFilterData, reorderBias, biasData, reorderedBiasData);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, mathType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, mathType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, groupCount);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, groupCount);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnReorderType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionReorderType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, reorderType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnReorderType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionReorderType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, reorderType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
+                                int pad_h,      /* zero-padding height */
+                                int pad_w,      /* zero-padding width */
+                                int u,          /* vertical filter stride */
+                                int v,          /* horizontal filter stride */
+                                int dilation_h, /* filter dilation in the vertical dimension */
+                                int dilation_w, /* filter dilation in the horizontal dimension */
+                                cudnnConvolutionMode_t mode,
+                                cudnnDataType_t computeType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
+                                int *pad_h,      /* zero-padding height */
+                                int *pad_w,      /* zero-padding width */
+                                int *u,          /* vertical filter stride */
+                                int *v,          /* horizontal filter stride */
+                                int *dilation_h, /* filter dilation in the vertical dimension */
+                                int *dilation_w, /* filter dilation in the horizontal dimension */
+                                cudnnConvolutionMode_t *mode,
+                                cudnnDataType_t *computeType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
+                                      const cudnnTensorDescriptor_t inputTensorDesc,
+                                      const cudnnFilterDescriptor_t filterDesc,
+                                      int *n,
+                                      int *c,
+                                      int *h,
+                                      int *w) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
+                                int arrayLength, /* nbDims-2 size */
+                                const int padA[],
+                                const int filterStrideA[],
+                                const int dilationA[],
+                                cudnnConvolutionMode_t mode,
+                                cudnnDataType_t computeType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
+                                int arrayLengthRequested,
+                                int *arrayLength,
+                                int padA[],
+                                int strideA[],
+                                int dilationA[],
+                                cudnnConvolutionMode_t *mode,
+                                cudnnDataType_t *computeType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
+                                      const cudnnTensorDescriptor_t inputTensorDesc,
+                                      const cudnnFilterDescriptor_t filterDesc,
+                                      int nbDims,
+                                      int tensorOuputDimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOuputDimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(convDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
+                                     const cudnnTensorDescriptor_t xDesc,
+                                     const cudnnFilterDescriptor_t wDesc,
+                                     const cudnnConvolutionDescriptor_t convDesc,
+                                     const cudnnTensorDescriptor_t yDesc,
+                                     const int requestedAlgoCount,
+                                     int *returnedAlgoCount,
+                                     cudnnConvolutionFwdAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
+                                       const cudnnTensorDescriptor_t xDesc,
+                                       const void *x,
+                                       const cudnnFilterDescriptor_t wDesc,
+                                       const void *w,
+                                       const cudnnConvolutionDescriptor_t convDesc,
+                                       const cudnnTensorDescriptor_t yDesc,
+                                       void *y,
+                                       const int requestedAlgoCount,
+                                       int *returnedAlgoCount,
+                                       cudnnConvolutionFwdAlgoPerf_t *perfResults,
+                                       void *workSpace,
+                                       size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle,
+                                    const cudnnTensorDescriptor_t xDesc,
+                                    const cudnnFilterDescriptor_t wDesc,
+                                    const cudnnConvolutionDescriptor_t convDesc,
+                                    const cudnnTensorDescriptor_t yDesc,
+                                    cudnnConvolutionFwdPreference_t preference,
+                                    size_t memoryLimitInBytes,
+                                    cudnnConvolutionFwdAlgo_t *algo) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
+                                       const cudnnTensorDescriptor_t srcDesc,
+                                       const cudnnFilterDescriptor_t filterDesc,
+                                       const cudnnConvolutionDescriptor_t convDesc,
+                                       const cudnnTensorDescriptor_t destDesc,
+                                       const int requestedAlgoCount,
+                                       int *returnedAlgoCount,
+                                       cudnnConvolutionFwdAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
+                                        const cudnnTensorDescriptor_t xDesc,
+                                        const cudnnFilterDescriptor_t wDesc,
+                                        const cudnnConvolutionDescriptor_t convDesc,
+                                        const cudnnTensorDescriptor_t yDesc,
+                                        cudnnConvolutionFwdAlgo_t algo,
+                                        size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnConvolutionForward(cudnnHandle_t handle,
+                        const void *alpha,
+                        const cudnnTensorDescriptor_t xDesc,
+                        const void *x,
+                        const cudnnFilterDescriptor_t wDesc,
+                        const void *w,
+                        const cudnnConvolutionDescriptor_t convDesc,
+                        cudnnConvolutionFwdAlgo_t algo,
+                        void *workSpace,
+                        size_t workSpaceSizeInBytes,
+                        const void *beta,
+                        const cudnnTensorDescriptor_t yDesc,
+                        void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
+                                      const void *alpha1,
+                                      const cudnnTensorDescriptor_t xDesc,
+                                      const void *x,
+                                      const cudnnFilterDescriptor_t wDesc,
+                                      const void *w,
+                                      const cudnnConvolutionDescriptor_t convDesc,
+                                      cudnnConvolutionFwdAlgo_t algo,
+                                      void *workSpace,
+                                      size_t workSpaceSizeInBytes,
+                                      const void *alpha2,
+                                      const cudnnTensorDescriptor_t zDesc,
+                                      const void *z,
+                                      const cudnnTensorDescriptor_t biasDesc,
+                                      const void *bias,
+                                      const cudnnActivationDescriptor_t activationDesc,
+                                      const cudnnTensorDescriptor_t yDesc,
+                                      void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnConvolutionBackwardBias(cudnnHandle_t handle,
+                             const void *alpha,
+                             const cudnnTensorDescriptor_t dyDesc,
+                             const void *dy,
+                             const void *beta,
+                             const cudnnTensorDescriptor_t dbDesc,
+                             void *db) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
+                                            const cudnnTensorDescriptor_t xDesc,
+                                            const cudnnTensorDescriptor_t dyDesc,
+                                            const cudnnConvolutionDescriptor_t convDesc,
+                                            const cudnnFilterDescriptor_t dwDesc,
+                                            const int requestedAlgoCount,
+                                            int *returnedAlgoCount,
+                                            cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
+                                              const cudnnTensorDescriptor_t xDesc,
+                                              const void *x,
+                                              const cudnnTensorDescriptor_t dyDesc,
+                                              const void *y,
+                                              const cudnnConvolutionDescriptor_t convDesc,
+                                              const cudnnFilterDescriptor_t dwDesc,
+                                              void *dw,
+                                              const int requestedAlgoCount,
+                                              int *returnedAlgoCount,
+                                              cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
+                                              void *workSpace,
+                                              size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
+                                           const cudnnTensorDescriptor_t xDesc,
+                                           const cudnnTensorDescriptor_t dyDesc,
+                                           const cudnnConvolutionDescriptor_t convDesc,
+                                           const cudnnFilterDescriptor_t dwDesc,
+                                           cudnnConvolutionBwdFilterPreference_t preference,
+                                           size_t memoryLimitInBytes,
+                                           cudnnConvolutionBwdFilterAlgo_t *algo) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
+                                              const cudnnTensorDescriptor_t srcDesc,
+                                              const cudnnTensorDescriptor_t diffDesc,
+                                              const cudnnConvolutionDescriptor_t convDesc,
+                                              const cudnnFilterDescriptor_t gradDesc,
+                                              const int requestedAlgoCount,
+                                              int *returnedAlgoCount,
+                                              cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
+                                               const cudnnTensorDescriptor_t xDesc,
+                                               const cudnnTensorDescriptor_t dyDesc,
+                                               const cudnnConvolutionDescriptor_t convDesc,
+                                               const cudnnFilterDescriptor_t gradDesc,
+                                               cudnnConvolutionBwdFilterAlgo_t algo,
+                                               size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
+                               const void *alpha,
+                               const cudnnTensorDescriptor_t xDesc,
+                               const void *x,
+                               const cudnnTensorDescriptor_t dyDesc,
+                               const void *dy,
+                               const cudnnConvolutionDescriptor_t convDesc,
+                               cudnnConvolutionBwdFilterAlgo_t algo,
+                               void *workSpace,
+                               size_t workSpaceSizeInBytes,
+                               const void *beta,
+                               const cudnnFilterDescriptor_t dwDesc,
+                               void *dw) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
+                                          const cudnnFilterDescriptor_t wDesc,
+                                          const cudnnTensorDescriptor_t dyDesc,
+                                          const cudnnConvolutionDescriptor_t convDesc,
+                                          const cudnnTensorDescriptor_t dxDesc,
+                                          const int requestedAlgoCount,
+                                          int *returnedAlgoCount,
+                                          cudnnConvolutionBwdDataAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
+                                            const cudnnFilterDescriptor_t wDesc,
+                                            const void *w,
+                                            const cudnnTensorDescriptor_t dyDesc,
+                                            const void *dy,
+                                            const cudnnConvolutionDescriptor_t convDesc,
+                                            const cudnnTensorDescriptor_t dxDesc,
+                                            void *dx,
+                                            const int requestedAlgoCount,
+                                            int *returnedAlgoCount,
+                                            cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
+                                            void *workSpace,
+                                            size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
+                                         const cudnnFilterDescriptor_t wDesc,
+                                         const cudnnTensorDescriptor_t dyDesc,
+                                         const cudnnConvolutionDescriptor_t convDesc,
+                                         const cudnnTensorDescriptor_t dxDesc,
+                                         cudnnConvolutionBwdDataPreference_t preference,
+                                         size_t memoryLimitInBytes,
+                                         cudnnConvolutionBwdDataAlgo_t *algo) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
+                                            const cudnnFilterDescriptor_t filterDesc,
+                                            const cudnnTensorDescriptor_t diffDesc,
+                                            const cudnnConvolutionDescriptor_t convDesc,
+                                            const cudnnTensorDescriptor_t gradDesc,
+                                            const int requestedAlgoCount,
+                                            int *returnedAlgoCount,
+                                            cudnnConvolutionBwdDataAlgoPerf_t *perfResults) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
+                                             const cudnnFilterDescriptor_t wDesc,
+                                             const cudnnTensorDescriptor_t dyDesc,
+                                             const cudnnConvolutionDescriptor_t convDesc,
+                                             const cudnnTensorDescriptor_t dxDesc,
+                                             cudnnConvolutionBwdDataAlgo_t algo,
+                                             size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnConvolutionBackwardData(cudnnHandle_t handle,
+                             const void *alpha,
+                             const cudnnFilterDescriptor_t wDesc,
+                             const void *w,
+                             const cudnnTensorDescriptor_t dyDesc,
+                             const void *dy,
+                             const cudnnConvolutionDescriptor_t convDesc,
+                             cudnnConvolutionBwdDataAlgo_t algo,
+                             void *workSpace,
+                             size_t workSpaceSizeInBytes,
+                             const void *beta,
+                             const cudnnTensorDescriptor_t dxDesc,
+                             void *dx) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnIm2Col(cudnnHandle_t handle,
+            const cudnnTensorDescriptor_t xDesc,
+            const void *x,
+            const cudnnFilterDescriptor_t wDesc,
+            const cudnnConvolutionDescriptor_t convDesc,
+            void *colBuffer) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSoftmaxForward(cudnnHandle_t handle,
+                    cudnnSoftmaxAlgorithm_t algo,
+                    cudnnSoftmaxMode_t mode,
+                    const void *alpha,
+                    const cudnnTensorDescriptor_t xDesc,
+                    const void *x,
+                    const void *beta,
+                    const cudnnTensorDescriptor_t yDesc,
+                    void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSoftmaxBackward(cudnnHandle_t handle,
+                     cudnnSoftmaxAlgorithm_t algo,
+                     cudnnSoftmaxMode_t mode,
+                     const void *alpha,
+                     const cudnnTensorDescriptor_t yDesc,
+                     const void *y,
+                     const cudnnTensorDescriptor_t dyDesc,
+                     const void *dy,
+                     const void *beta,
+                     const cudnnTensorDescriptor_t dxDesc,
+                     void *dx) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
+                            cudnnPoolingMode_t mode,
+                            cudnnNanPropagation_t maxpoolingNanOpt,
+                            int windowHeight,
+                            int windowWidth,
+                            int verticalPadding,
+                            int horizontalPadding,
+                            int verticalStride,
+                            int horizontalStride) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
+                            cudnnPoolingMode_t *mode,
+                            cudnnNanPropagation_t *maxpoolingNanOpt,
+                            int *windowHeight,
+                            int *windowWidth,
+                            int *verticalPadding,
+                            int *horizontalPadding,
+                            int *verticalStride,
+                            int *horizontalStride) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc,
+                            const cudnnPoolingMode_t mode,
+                            const cudnnNanPropagation_t maxpoolingNanOpt,
+                            int nbDims,
+                            const int windowDimA[],
+                            const int paddingA[],
+                            const int strideA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
+                            int nbDimsRequested,
+                            cudnnPoolingMode_t *mode,
+                            cudnnNanPropagation_t *maxpoolingNanOpt,
+                            int *nbDims,
+                            int windowDimA[],
+                            int paddingA[],
+                            int strideA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
+                                  const cudnnTensorDescriptor_t inputTensorDesc,
+                                  int nbDims,
+                                  int outputTensorDimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
+                                  const cudnnTensorDescriptor_t inputTensorDesc,
+                                  int *n,
+                                  int *c,
+                                  int *h,
+                                  int *w) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(poolingDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnPoolingForward(cudnnHandle_t handle,
+                    const cudnnPoolingDescriptor_t poolingDesc,
+                    const void *alpha,
+                    const cudnnTensorDescriptor_t xDesc,
+                    const void *x,
+                    const void *beta,
+                    const cudnnTensorDescriptor_t yDesc,
+                    void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnPoolingBackward(cudnnHandle_t handle,
+                     const cudnnPoolingDescriptor_t poolingDesc,
+                     const void *alpha,
+                     const cudnnTensorDescriptor_t yDesc,
+                     const void *y,
+                     const cudnnTensorDescriptor_t dyDesc,
+                     const void *dy,
+                     const cudnnTensorDescriptor_t xDesc,
+                     const void *x,
+                     const void *beta,
+                     const cudnnTensorDescriptor_t dxDesc,
+                     void *dx) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(activationDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc,
+                             cudnnActivationMode_t mode,
+                             cudnnNanPropagation_t reluNanOpt,
+                             double coef) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(activationDesc, mode, reluNanOpt, coef);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc,
+                             cudnnActivationMode_t *mode,
+                             cudnnNanPropagation_t *reluNanOpt,
+                             double *coef) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(activationDesc, mode, reluNanOpt, coef);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(activationDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnActivationForward(cudnnHandle_t handle,
+                       cudnnActivationDescriptor_t activationDesc,
+                       const void *alpha,
+                       const cudnnTensorDescriptor_t xDesc,
+                       const void *x,
+                       const void *beta,
+                       const cudnnTensorDescriptor_t yDesc,
+                       void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnActivationBackward(cudnnHandle_t handle,
+                        cudnnActivationDescriptor_t activationDesc,
+                        const void *alpha,
+                        const cudnnTensorDescriptor_t yDesc,
+                        const void *y,
+                        const cudnnTensorDescriptor_t dyDesc,
+                        const void *dy,
+                        const cudnnTensorDescriptor_t xDesc,
+                        const void *x,
+                        const void *beta,
+                        const cudnnTensorDescriptor_t dxDesc,
+                        void *dx) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(normDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(lrnDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnLRNCrossChannelForward(cudnnHandle_t handle,
+                            cudnnLRNDescriptor_t normDesc,
+                            cudnnLRNMode_t lrnMode,
+                            const void *alpha,
+                            const cudnnTensorDescriptor_t xDesc,
+                            const void *x,
+                            const void *beta,
+                            const cudnnTensorDescriptor_t yDesc,
+                            void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
+                             cudnnLRNDescriptor_t normDesc,
+                             cudnnLRNMode_t lrnMode,
+                             const void *alpha,
+                             const cudnnTensorDescriptor_t yDesc,
+                             const void *y,
+                             const cudnnTensorDescriptor_t dyDesc,
+                             const void *dy,
+                             const cudnnTensorDescriptor_t xDesc,
+                             const void *x,
+                             const void *beta,
+                             const cudnnTensorDescriptor_t dxDesc,
+                             void *dx) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDivisiveNormalizationForward(cudnnHandle_t handle,
+                                  cudnnLRNDescriptor_t normDesc,
+                                  cudnnDivNormMode_t mode,
+                                  const void *alpha,
+                                  const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
+                                  const void *x,
+                                  const void *means, /* if NULL, means are assumed to be zero */
+                                  void *temp,
+                                  void *temp2,
+                                  const void *beta,
+                                  const cudnnTensorDescriptor_t yDesc,
+                                  void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
+                                   cudnnLRNDescriptor_t normDesc,
+                                   cudnnDivNormMode_t mode,
+                                   const void *alpha,
+                                   const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
+                                   const void *x,
+                                   const void *means, /* if NULL, means are assumed to be zero */
+                                   const void *dy,
+                                   void *temp,
+                                   void *temp2,
+                                   const void *beta,
+                                   const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
+                                   void *dx,                                   /* output x differential */
+                                   void *dMeans) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc,
+                              const cudnnTensorDescriptor_t xDesc,
+                              cudnnBatchNormMode_t mode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(derivedBnDesc, xDesc, mode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
+                                                         cudnnBatchNormMode_t mode,
+                                                         cudnnBatchNormOps_t bnOps,
+                                                         const cudnnTensorDescriptor_t xDesc,
+                                                         const cudnnTensorDescriptor_t zDesc,
+                                                         const cudnnTensorDescriptor_t yDesc,
+                                                         const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
+                                                         const cudnnActivationDescriptor_t activationDesc,
+                                                         size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, bnOps, xDesc, zDesc, yDesc, bnScaleBiasMeanVarDesc, activationDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
+                                                  cudnnBatchNormMode_t mode,
+                                                  cudnnBatchNormOps_t bnOps,
+                                                  const cudnnTensorDescriptor_t xDesc,
+                                                  const cudnnTensorDescriptor_t yDesc,
+                                                  const cudnnTensorDescriptor_t dyDesc,
+                                                  const cudnnTensorDescriptor_t dzDesc,
+                                                  const cudnnTensorDescriptor_t dxDesc,
+                                                  const cudnnTensorDescriptor_t dBnScaleBiasDesc,
+                                                  const cudnnActivationDescriptor_t activationDesc,
+                                                  size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationBackwardExWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, bnOps, xDesc, yDesc, dyDesc, dzDesc, dxDesc, dBnScaleBiasDesc, activationDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
+                                                     cudnnBatchNormMode_t mode,
+                                                     cudnnBatchNormOps_t bnOps,
+                                                     const cudnnActivationDescriptor_t activationDesc,
+                                                     const cudnnTensorDescriptor_t xDesc,
+                                                     size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationTrainingExReserveSpaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, bnOps, activationDesc, xDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnBatchNormalizationForwardTraining(
+    cudnnHandle_t handle,
+    cudnnBatchNormMode_t mode,
+
+    const void *alpha, /* alpha[0] = result blend factor */
+    const void *beta,  /* beta[0] = dest layer blend factor */
+
+    const cudnnTensorDescriptor_t xDesc,
+    const void *x, /* NxCxHxW */
+    const cudnnTensorDescriptor_t yDesc,
+    void *y, /* NxCxHxW */
+
+    /* Shared desc for the next 6 tensors in the argument list.
+       Data type to be set as follows:
+       type = (typeOf(x) == double) ? double : float
+       Dimensions for this descriptor depend on normalization mode
+       - Spatial Normalization : tensors are expected to have dims 1xCx1x1
+        (normalization is performed across NxHxW)
+       - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
+        (normalization is performed across N) */
+    const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
+
+    /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
+    const void *bnScale,
+    const void *bnBias,
+
+    /* MUST use factor=1 in the very first call of a complete training cycle.
+       Use a factor=1/(1+n) at N-th call to the function to get
+       Cumulative Moving Average (CMA) behavior
+       CMA[n] = (x[1]+...+x[n])/n
+       Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
+       ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
+       CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
+    double exponentialAverageFactor,
+
+    /* Used in Training phase only.
+       runningMean = newMean*factor + runningMean*(1-factor) */
+    void *resultRunningMean,
+    /* Output in training mode, input in inference. Is the moving average
+       of  variance[x] (factor is applied in the same way as for runningMean) */
+    void *resultRunningVariance,
+
+    /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
+    double epsilon,
+
+    /* Optionally save intermediate results from the forward pass here
+       - can be reused to speed up backward pass. NULL if unused */
+    void *resultSaveMean,
+    void *resultSaveInvVariance) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnBatchNormalizationForwardTrainingEx(
+    cudnnHandle_t handle,
+    cudnnBatchNormMode_t mode,
+    cudnnBatchNormOps_t bnOps,
+
+    const void *alpha, /* alpha[0] = result blend factor */
+    const void *beta,  /* beta[0] = dest layer blend factor */
+
+    const cudnnTensorDescriptor_t xDesc,
+    const void *xData,
+    const cudnnTensorDescriptor_t zDesc,
+    const void *zData,
+    const cudnnTensorDescriptor_t yDesc,
+    void *yData,
+
+    const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
+    const void *bnScale,
+    const void *bnBias,
+
+    double exponentialAverageFactor,
+    void *resultRunningMean,
+    void *resultRunningVariance,
+
+    /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
+    double epsilon,
+
+    /* Optionally save intermediate results from the forward pass here
+       - can be reused to speed up backward pass. NULL if unused */
+    void *resultSaveMean,
+    void *resultSaveInvVariance,
+
+    cudnnActivationDescriptor_t activationDesc,
+    void *workspace,
+    size_t workSpaceSizeInBytes,
+    void *reserveSpace,
+    size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTrainingEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, bnOps, alpha, beta, xDesc, xData, zDesc, zData, yDesc, yData, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance, activationDesc, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnBatchNormalizationForwardInference(cudnnHandle_t handle,
+                                        cudnnBatchNormMode_t mode,
+                                        const void *alpha, /* alpha[0] = result blend factor */
+                                        const void *beta,  /* beta[0] = dest layer blend factor */
+                                        const cudnnTensorDescriptor_t xDesc,
+                                        const void *x, /* NxCxHxW */
+                                        const cudnnTensorDescriptor_t yDesc,
+                                        void *y, /* NxCxHxW */
+                                        const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
+                                        const void *bnScale,
+                                        const void *bnBias,
+                                        const void *estimatedMean,
+                                        const void *estimatedVariance,
+                                        double epsilon) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnBatchNormalizationBackward(cudnnHandle_t handle,
+                                cudnnBatchNormMode_t mode,
+                                const void *alphaDataDiff,
+                                const void *betaDataDiff,
+                                const void *alphaParamDiff,
+                                const void *betaParamDiff,
+                                const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
+                                const void *x,
+                                const cudnnTensorDescriptor_t dyDesc,
+                                const void *dy,
+                                const cudnnTensorDescriptor_t dxDesc,
+                                void *dx,
+                                /* Shared tensor desc for the 4 tensors below */
+                                const cudnnTensorDescriptor_t dBnScaleBiasDesc,
+                                const void *bnScale, /* bnBias doesn't affect backpropagation */
+                                /* scale and bias diff are not backpropagated below this layer */
+                                void *dBnScaleResult,
+                                void *dBnBiasResult,
+                                /* Same epsilon as forward pass */
+                                double epsilon,
+
+                                /* Optionally cached intermediate results from
+                                   forward pass */
+                                const void *savedMean,
+                                const void *savedInvVariance) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
+                                  cudnnBatchNormMode_t mode,
+                                  cudnnBatchNormOps_t bnOps,
+
+                                  const void *alphaDataDiff,
+                                  const void *betaDataDiff,
+                                  const void *alphaParamDiff,
+                                  const void *betaParamDiff,
+                                  const cudnnTensorDescriptor_t xDesc,
+                                  const void *xData,
+                                  const cudnnTensorDescriptor_t yDesc,
+                                  const void *yData,
+                                  const cudnnTensorDescriptor_t dyDesc,
+                                  const void *dyData,
+                                  const cudnnTensorDescriptor_t dzDesc,
+                                  void *dzData,
+                                  const cudnnTensorDescriptor_t dxDesc,
+                                  void *dxData,
+
+                                  /* Shared tensor desc for the 4 tensors below */
+                                  const cudnnTensorDescriptor_t dBnScaleBiasDesc,
+                                  const void *bnScaleData,
+                                  const void *bnBiasData, /* needed if there is activation */
+                                  void *dBnScaleData,
+                                  void *dBnBiasData,
+                                  double epsilon, /* Same epsilon as forward pass */
+
+                                  /* Optionally cached intermediate results from
+                                     forward pass */
+                                  const void *savedMean,
+                                  const void *savedInvVariance,
+                                  cudnnActivationDescriptor_t activationDesc,
+                                  void *workSpace,
+                                  size_t workSpaceSizeInBytes,
+                                  void *reserveSpace,
+                                  size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, double, const void *, const void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackwardEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, mode, bnOps, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, xData, yDesc, yData, dyDesc, dyData, dzDesc, dzData, dxDesc, dxData, dBnScaleBiasDesc, bnScaleData, bnBiasData, dBnScaleData, dBnBiasData, epsilon, savedMean, savedInvVariance, activationDesc, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(stDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc,
+                                       cudnnSamplerType_t samplerType,
+                                       cudnnDataType_t dataType,
+                                       const int nbDims,
+                                       const int dimA[]) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(stDesc, samplerType, dataType, nbDims, dimA);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(stDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle,
+                                   const cudnnSpatialTransformerDescriptor_t stDesc,
+                                   const void *theta,
+                                   void *grid) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, stDesc, theta, grid);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
+                                    const cudnnSpatialTransformerDescriptor_t stDesc,
+                                    const void *dgrid,
+                                    void *dtheta) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, stDesc, dgrid, dtheta);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSpatialTfSamplerForward(cudnnHandle_t handle,
+                             cudnnSpatialTransformerDescriptor_t stDesc,
+                             const void *alpha,
+                             const cudnnTensorDescriptor_t xDesc,
+                             const void *x,
+                             const void *grid,
+                             const void *beta,
+                             cudnnTensorDescriptor_t yDesc,
+                             void *y) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
+                              cudnnSpatialTransformerDescriptor_t stDesc,
+                              const void *alpha,
+                              const cudnnTensorDescriptor_t xDesc,
+                              const void *x,
+                              const void *beta,
+                              const cudnnTensorDescriptor_t dxDesc,
+                              void *dx,
+                              const void *alphaDgrid,
+                              const cudnnTensorDescriptor_t dyDesc,
+                              const void *dy,
+                              const void *grid,
+                              const void *betaDgrid,
+                              void *dgrid) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dropoutDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dropoutDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(xdesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+                          cudnnHandle_t handle,
+                          float dropout,
+                          void *states,
+                          size_t stateSizeInBytes,
+                          unsigned long long seed) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+                              cudnnHandle_t handle,
+                              float dropout,
+                              void *states,
+                              size_t stateSizeInBytes,
+                              unsigned long long seed) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+                          cudnnHandle_t handle,
+                          float *dropout,
+                          void **states,
+                          unsigned long long *seed) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(dropoutDesc, handle, dropout, states, seed);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDropoutForward(cudnnHandle_t handle,
+                    const cudnnDropoutDescriptor_t dropoutDesc,
+                    const cudnnTensorDescriptor_t xdesc,
+                    const void *x,
+                    const cudnnTensorDescriptor_t ydesc,
+                    void *y,
+                    void *reserveSpace,
+                    size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDropoutBackward(cudnnHandle_t handle,
+                     const cudnnDropoutDescriptor_t dropoutDesc,
+                     const cudnnTensorDescriptor_t dydesc,
+                     const void *dy,
+                     const cudnnTensorDescriptor_t dxdesc,
+                     void *dx,
+                     void *reserveSpace,
+                     size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNDescriptor(cudnnHandle_t handle,
+                      cudnnRNNDescriptor_t rnnDesc,
+                      const int hiddenSize,
+                      const int numLayers,
+                      cudnnDropoutDescriptor_t dropoutDesc,
+                      cudnnRNNInputMode_t inputMode,
+                      cudnnDirectionMode_t direction,
+                      cudnnRNNMode_t mode,
+                      cudnnRNNAlgo_t algo,
+                      cudnnDataType_t mathPrec) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNDescriptor(cudnnHandle_t handle,
+                      cudnnRNNDescriptor_t rnnDesc,
+                      int *hiddenSize,
+                      int *numLayers,
+                      cudnnDropoutDescriptor_t *dropoutDesc,
+                      cudnnRNNInputMode_t *inputMode,
+                      cudnnDirectionMode_t *direction,
+                      cudnnRNNMode_t *mode,
+                      cudnnRNNAlgo_t *algo,
+                      cudnnDataType_t *mathPrec) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, mType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNMatrixMathType");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, mType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t biasMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNBiasMode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, biasMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t *biasMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBiasMode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, biasMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNSetClip(cudnnHandle_t handle,
+                cudnnRNNDescriptor_t rnnDesc,
+                cudnnRNNClipMode_t clipMode,
+                cudnnNanPropagation_t clipNanOpt,
+                double lclip,
+                double rclip) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, cudnnNanPropagation_t, double, double);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNSetClip");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNGetClip(cudnnHandle_t handle,
+                cudnnRNNDescriptor_t rnnDesc,
+                cudnnRNNClipMode_t *clipMode,
+                cudnnNanPropagation_t *clipNanOpt,
+                double *lclip,
+                double *rclip) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, cudnnNanPropagation_t *, double *, double *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNGetClip");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNProjectionLayers(cudnnHandle_t handle,
+                            cudnnRNNDescriptor_t rnnDesc,
+                            const int recProjSize,
+                            const int outProjSize) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNProjectionLayers");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, recProjSize, outProjSize);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNProjectionLayers(cudnnHandle_t handle,
+                            const cudnnRNNDescriptor_t rnnDesc,
+                            int *recProjSize,
+                            int *outProjSize) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNProjectionLayers");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, recProjSize, outProjSize);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
+                             const int minibatch,
+                             const cudnnDataType_t dataType,
+                             cudnnPersistentRNNPlan_t *plan) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, minibatch, dataType, plan);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(plan);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, plan);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNWorkspaceSize(cudnnHandle_t handle,
+                         const cudnnRNNDescriptor_t rnnDesc,
+                         const int seqLength,
+                         const cudnnTensorDescriptor_t *xDesc,
+                         size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle,
+                               const cudnnRNNDescriptor_t rnnDesc,
+                               const int seqLength,
+                               const cudnnTensorDescriptor_t *xDesc,
+                               size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNParamsSize(cudnnHandle_t handle,
+                      const cudnnRNNDescriptor_t rnnDesc,
+                      const cudnnTensorDescriptor_t xDesc,
+                      size_t *sizeInBytes,
+                      cudnnDataType_t dataType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle,
+                                const cudnnRNNDescriptor_t rnnDesc,
+                                const int pseudoLayer,
+                                const cudnnTensorDescriptor_t xDesc,
+                                const cudnnFilterDescriptor_t wDesc,
+                                const void *w,
+                                const int linLayerID,
+                                cudnnFilterDescriptor_t linLayerMatDesc,
+                                void **linLayerMat) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle,
+                              const cudnnRNNDescriptor_t rnnDesc,
+                              const int pseudoLayer,
+                              const cudnnTensorDescriptor_t xDesc,
+                              const cudnnFilterDescriptor_t wDesc,
+                              const void *w,
+                              const int linLayerID,
+                              cudnnFilterDescriptor_t linLayerBiasDesc,
+                              void **linLayerBias) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNForwardInference(cudnnHandle_t handle,
+                         const cudnnRNNDescriptor_t rnnDesc,
+                         const int seqLength,
+                         const cudnnTensorDescriptor_t *xDesc,
+                         const void *x,
+                         const cudnnTensorDescriptor_t hxDesc,
+                         const void *hx,
+                         const cudnnTensorDescriptor_t cxDesc,
+                         const void *cx,
+                         const cudnnFilterDescriptor_t wDesc,
+                         const void *w,
+                         const cudnnTensorDescriptor_t *yDesc,
+                         void *y,
+                         const cudnnTensorDescriptor_t hyDesc,
+                         void *hy,
+                         const cudnnTensorDescriptor_t cyDesc,
+                         void *cy,
+                         void *workspace,
+                         size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNForwardTraining(cudnnHandle_t handle,
+                        const cudnnRNNDescriptor_t rnnDesc,
+                        const int seqLength,
+                        const cudnnTensorDescriptor_t *xDesc,
+                        const void *x,
+                        const cudnnTensorDescriptor_t hxDesc,
+                        const void *hx,
+                        const cudnnTensorDescriptor_t cxDesc,
+                        const void *cx,
+                        const cudnnFilterDescriptor_t wDesc,
+                        const void *w,
+                        const cudnnTensorDescriptor_t *yDesc,
+                        void *y,
+                        const cudnnTensorDescriptor_t hyDesc,
+                        void *hy,
+                        const cudnnTensorDescriptor_t cyDesc,
+                        void *cy,
+                        void *workspace,
+                        size_t workSpaceSizeInBytes,
+                        void *reserveSpace,
+                        size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNBackwardData(cudnnHandle_t handle,
+                     const cudnnRNNDescriptor_t rnnDesc,
+                     const int seqLength,
+                     const cudnnTensorDescriptor_t *yDesc,
+                     const void *y,
+                     const cudnnTensorDescriptor_t *dyDesc,
+                     const void *dy,
+                     const cudnnTensorDescriptor_t dhyDesc,
+                     const void *dhy,
+                     const cudnnTensorDescriptor_t dcyDesc,
+                     const void *dcy,
+                     const cudnnFilterDescriptor_t wDesc,
+                     const void *w,
+                     const cudnnTensorDescriptor_t hxDesc,
+                     const void *hx,
+                     const cudnnTensorDescriptor_t cxDesc,
+                     const void *cx,
+                     const cudnnTensorDescriptor_t *dxDesc,
+                     void *dx,
+                     const cudnnTensorDescriptor_t dhxDesc,
+                     void *dhx,
+                     const cudnnTensorDescriptor_t dcxDesc,
+                     void *dcx,
+                     void *workspace,
+                     size_t workSpaceSizeInBytes,
+                     void *reserveSpace,
+                     size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNBackwardWeights(cudnnHandle_t handle,
+                        const cudnnRNNDescriptor_t rnnDesc,
+                        const int seqLength,
+                        const cudnnTensorDescriptor_t *xDesc,
+                        const void *x,
+                        const cudnnTensorDescriptor_t hxDesc,
+                        const void *hx,
+                        const cudnnTensorDescriptor_t *yDesc,
+                        const void *y,
+                        const void *workspace,
+                        size_t workSpaceSizeInBytes,
+                        const cudnnFilterDescriptor_t dwDesc,
+                        void *dw,
+                        const void *reserveSpace,
+                        size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNPaddingMode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, paddingMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNPaddingMode");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, paddingMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDataDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDataDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
+                          cudnnDataType_t dataType,
+                          cudnnRNNDataLayout_t layout,
+                          int maxSeqLength,
+                          int batchSize,
+                          int vectorSize,
+                          const int seqLengthArray[], /* length of each sequence in the batch */
+                          void *paddingFill) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, int, const int [], void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, paddingFill);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
+                          cudnnDataType_t *dataType,
+                          cudnnRNNDataLayout_t *layout,
+                          int *maxSeqLength,
+                          int *batchSize,
+                          int *vectorSize,
+                          int arrayLengthRequested,
+                          int seqLengthArray[],
+                          void *paddingFill) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, int *, int *, int *, int, int [], void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, arrayLengthRequested, seqLengthArray, paddingFill);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNForwardTrainingEx(cudnnHandle_t handle,
+                          const cudnnRNNDescriptor_t rnnDesc,
+                          const cudnnRNNDataDescriptor_t xDesc,
+                          const void *x,
+                          const cudnnTensorDescriptor_t hxDesc,
+                          const void *hx,
+                          const cudnnTensorDescriptor_t cxDesc,
+                          const void *cx,
+                          const cudnnFilterDescriptor_t wDesc,
+                          const void *w,
+                          const cudnnRNNDataDescriptor_t yDesc,
+                          void *y,
+                          const cudnnTensorDescriptor_t hyDesc,
+                          void *hy,
+                          const cudnnTensorDescriptor_t cyDesc,
+                          void *cy,
+                          const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */
+                          const void *keys,                     /* reserved, should pass NULL */
+                          const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */
+                          void *cAttn,                          /* reserved, should pass NULL */
+                          const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */
+                          void *iAttn,                          /* reserved, should pass NULL */
+                          const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */
+                          void *queries,                        /* reserved, should pass NULL */
+                          void *workSpace,
+                          size_t workSpaceSizeInBytes,
+                          void *reserveSpace,
+                          size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTrainingEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNForwardInferenceEx(cudnnHandle_t handle,
+                           const cudnnRNNDescriptor_t rnnDesc,
+                           const cudnnRNNDataDescriptor_t xDesc,
+                           const void *x,
+                           const cudnnTensorDescriptor_t hxDesc,
+                           const void *hx,
+                           const cudnnTensorDescriptor_t cxDesc,
+                           const void *cx,
+                           const cudnnFilterDescriptor_t wDesc,
+                           const void *w,
+                           const cudnnRNNDataDescriptor_t yDesc,
+                           void *y,
+                           const cudnnTensorDescriptor_t hyDesc,
+                           void *hy,
+                           const cudnnTensorDescriptor_t cyDesc,
+                           void *cy,
+                           const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */
+                           const void *keys,                     /* reserved, should pass NULL */
+                           const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */
+                           void *cAttn,                          /* reserved, should pass NULL */
+                           const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */
+                           void *iAttn,                          /* reserved, should pass NULL */
+                           const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */
+                           void *queries,                        /* reserved, should pass NULL */
+                           void *workSpace,
+                           size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInferenceEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNBackwardDataEx(cudnnHandle_t handle,
+                       const cudnnRNNDescriptor_t rnnDesc,
+                       const cudnnRNNDataDescriptor_t yDesc,
+                       const void *y,
+                       const cudnnRNNDataDescriptor_t dyDesc,
+                       const void *dy,
+                       const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */
+                       const void *dcAttn,                    /* reserved, should pass NULL */
+                       const cudnnTensorDescriptor_t dhyDesc,
+                       const void *dhy,
+                       const cudnnTensorDescriptor_t dcyDesc,
+                       const void *dcy,
+                       const cudnnFilterDescriptor_t wDesc,
+                       const void *w,
+                       const cudnnTensorDescriptor_t hxDesc,
+                       const void *hx,
+                       const cudnnTensorDescriptor_t cxDesc,
+                       const void *cx,
+                       const cudnnRNNDataDescriptor_t dxDesc,
+                       void *dx,
+                       const cudnnTensorDescriptor_t dhxDesc,
+                       void *dhx,
+                       const cudnnTensorDescriptor_t dcxDesc,
+                       void *dcx,
+                       const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */
+                       void *dkeys,                           /* reserved, should pass NULL */
+                       void *workSpace,
+                       size_t workSpaceSizeInBytes,
+                       void *reserveSpace,
+                       size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardDataEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRNNBackwardWeightsEx(cudnnHandle_t handle,
+                          const cudnnRNNDescriptor_t rnnDesc,
+                          const cudnnRNNDataDescriptor_t xDesc,
+                          const void *x,
+                          const cudnnTensorDescriptor_t hxDesc,
+                          const void *hx,
+                          const cudnnRNNDataDescriptor_t yDesc,
+                          const void *y,
+                          void *workSpace,
+                          size_t workSpaceSizeInBytes,
+                          const cudnnFilterDescriptor_t dwDesc,
+                          void *dw,
+                          void *reserveSpace,
+                          size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, void *, size_t, const cudnnFilterDescriptor_t, void *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeightsEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, algoDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle,
+                                        const cudnnRNNDescriptor_t rnnDesc,
+                                        const int seqLength,
+                                        const cudnnTensorDescriptor_t *xDesc,
+                                        const void *x,
+                                        const cudnnTensorDescriptor_t hxDesc,
+                                        const void *hx,
+                                        const cudnnTensorDescriptor_t cxDesc,
+                                        const void *cx,
+                                        const cudnnFilterDescriptor_t wDesc,
+                                        const void *w,
+                                        const cudnnTensorDescriptor_t *yDesc,
+                                        void *y,
+                                        const cudnnTensorDescriptor_t hyDesc,
+                                        void *hy,
+                                        const cudnnTensorDescriptor_t cyDesc,
+                                        void *cy,
+                                        const float findIntensity,
+                                        const int requestedAlgoCount,
+                                        int *returnedAlgoCount,
+                                        cudnnAlgorithmPerformance_t *perfResults,
+                                        void *workspace,
+                                        size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNForwardTrainingAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindRNNForwardTrainingAlgorithmEx(cudnnHandle_t handle,
+                                       const cudnnRNNDescriptor_t rnnDesc,
+                                       const int seqLength,
+                                       const cudnnTensorDescriptor_t *xDesc,
+                                       const void *x,
+                                       const cudnnTensorDescriptor_t hxDesc,
+                                       const void *hx,
+                                       const cudnnTensorDescriptor_t cxDesc,
+                                       const void *cx,
+                                       const cudnnFilterDescriptor_t wDesc,
+                                       const void *w,
+                                       const cudnnTensorDescriptor_t *yDesc,
+                                       void *y,
+                                       const cudnnTensorDescriptor_t hyDesc,
+                                       void *hy,
+                                       const cudnnTensorDescriptor_t cyDesc,
+                                       void *cy,
+                                       const float findIntensity,
+                                       const int requestedAlgoCount,
+                                       int *returnedAlgoCount,
+                                       cudnnAlgorithmPerformance_t *perfResults,
+                                       void *workspace,
+                                       size_t workSpaceSizeInBytes,
+                                       void *reserveSpace,
+                                       size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindRNNBackwardDataAlgorithmEx(cudnnHandle_t handle,
+                                    const cudnnRNNDescriptor_t rnnDesc,
+                                    const int seqLength,
+                                    const cudnnTensorDescriptor_t *yDesc,
+                                    const void *y,
+                                    const cudnnTensorDescriptor_t *dyDesc,
+                                    const void *dy,
+                                    const cudnnTensorDescriptor_t dhyDesc,
+                                    const void *dhy,
+                                    const cudnnTensorDescriptor_t dcyDesc,
+                                    const void *dcy,
+                                    const cudnnFilterDescriptor_t wDesc,
+                                    const void *w,
+                                    const cudnnTensorDescriptor_t hxDesc,
+                                    const void *hx,
+                                    const cudnnTensorDescriptor_t cxDesc,
+                                    const void *cx,
+                                    const cudnnTensorDescriptor_t *dxDesc,
+                                    void *dx,
+                                    const cudnnTensorDescriptor_t dhxDesc,
+                                    void *dhx,
+                                    const cudnnTensorDescriptor_t dcxDesc,
+                                    void *dcx,
+                                    const float findIntensity,
+                                    const int requestedAlgoCount,
+                                    int *returnedAlgoCount,
+                                    cudnnAlgorithmPerformance_t *perfResults,
+                                    void *workspace,
+                                    size_t workSpaceSizeInBytes,
+                                    void *reserveSpace,
+                                    size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetRNNBackwardWeightsAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, count);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFindRNNBackwardWeightsAlgorithmEx(cudnnHandle_t handle,
+                                       const cudnnRNNDescriptor_t rnnDesc,
+                                       const int seqLength,
+                                       const cudnnTensorDescriptor_t *xDesc,
+                                       const void *x,
+                                       const cudnnTensorDescriptor_t hxDesc,
+                                       const void *hx,
+                                       const cudnnTensorDescriptor_t *yDesc,
+                                       const void *y,
+                                       const float findIntensity,
+                                       const int requestedAlgoCount,
+                                       int *returnedAlgoCount,
+                                       cudnnAlgorithmPerformance_t *perfResults,
+                                       const void *workspace,
+                                       size_t workSpaceSizeInBytes,
+                                       const cudnnFilterDescriptor_t dwDesc,
+                                       void *dw,
+                                       const void *reserveSpace,
+                                       size_t reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSeqDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(seqDataDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySeqDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(seqDataDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
+                          cudnnDataType_t dataType,
+                          int nbDims,
+                          const int dimA[],
+                          const cudnnSeqDataAxis_t axes[],
+                          size_t seqLengthArraySize,
+                          const int seqLengthArray[],
+                          void *paddingFill) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t, cudnnDataType_t, int, const int [], const cudnnSeqDataAxis_t [], size_t, const int [], void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSeqDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(seqDataDesc, dataType, nbDims, dimA, axes, seqLengthArraySize, seqLengthArray, paddingFill);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
+                          cudnnDataType_t *dataType,
+                          int *nbDims,
+                          int nbDimsRequested,
+                          int dimA[],
+                          cudnnSeqDataAxis_t axes[],
+                          size_t *seqLengthArraySize,
+                          size_t seqLengthSizeRequested,
+                          int seqLengthArray[],
+                          void *paddingFill) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnSeqDataDescriptor_t, cudnnDataType_t *, int *, int, int [], cudnnSeqDataAxis_t [], size_t *, size_t, int [], void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetSeqDataDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(seqDataDesc, dataType, nbDims, nbDimsRequested, dimA, axes, seqLengthArraySize, seqLengthSizeRequested, seqLengthArray, paddingFill);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAttnDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(attnDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAttnDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(attnDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
+                       cudnnAttnQueryMap_t queryMap,
+                       int nHeads,
+                       double smScaler,
+                       cudnnDataType_t dataType,
+                       cudnnDataType_t computePrec,
+                       cudnnMathType_t mathType,
+                       cudnnDropoutDescriptor_t attnDropoutDesc,
+                       cudnnDropoutDescriptor_t postDropoutDesc,
+                       int qSize,
+                       int kSize,
+                       int vSize,
+                       int qProjSize,
+                       int kProjSize,
+                       int vProjSize,
+                       int oProjSize,
+                       int qoMaxSeqLength,
+                       int kvMaxSeqLength,
+                       int maxBatchSize,
+                       int maxBeamSize) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t, cudnnAttnQueryMap_t, int, double, cudnnDataType_t, cudnnDataType_t, cudnnMathType_t, cudnnDropoutDescriptor_t, cudnnDropoutDescriptor_t, int, int, int, int, int, int, int, int, int, int, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAttnDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize, qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
+                       cudnnAttnQueryMap_t *queryMap,
+                       int *nHeads,
+                       double *smScaler,
+                       cudnnDataType_t *dataType,
+                       cudnnDataType_t *computePrec,
+                       cudnnMathType_t *mathType,
+                       cudnnDropoutDescriptor_t *attnDropoutDesc,
+                       cudnnDropoutDescriptor_t *postDropoutDesc,
+                       int *qSize,
+                       int *kSize,
+                       int *vSize,
+                       int *qProjSize,
+                       int *kProjSize,
+                       int *vProjSize,
+                       int *oProjSize,
+                       int *qoMaxSeqLength,
+                       int *kvMaxSeqLength,
+                       int *maxBatchSize,
+                       int *maxBeamSize) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t, cudnnAttnQueryMap_t *, int *, double *, cudnnDataType_t *, cudnnDataType_t *, cudnnMathType_t *, cudnnDropoutDescriptor_t *, cudnnDropoutDescriptor_t *, int *, int *, int *, int *, int *, int *, int *, int *, int *, int *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAttnDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize, qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
+                             const cudnnAttnDescriptor_t attnDesc,
+                             size_t *weightSizeInBytes,
+                             size_t *workSpaceSizeInBytes,
+                             size_t *reserveSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, size_t *, size_t *, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetMultiHeadAttnBuffers");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, attnDesc, weightSizeInBytes, workSpaceSizeInBytes, reserveSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
+                             const cudnnAttnDescriptor_t attnDesc,
+                             cudnnMultiHeadAttnWeightKind_t wKind,
+                             size_t weightSizeInBytes,
+                             const void *w,
+                             cudnnTensorDescriptor_t wDesc,
+                             void **wAddr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, cudnnMultiHeadAttnWeightKind_t, size_t, const void *, cudnnTensorDescriptor_t, void **);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetMultiHeadAttnWeights");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, attnDesc, wKind, weightSizeInBytes, w, wDesc, wAddr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnMultiHeadAttnForward(cudnnHandle_t handle,
+                          const cudnnAttnDescriptor_t attnDesc,
+                          int currIdx,
+                          const int *loWinIdx,
+                          const int *hiWinIdx,
+                          const int *seqLengthArrayQRO,
+                          const int *seqLengthArrayKV,
+                          const cudnnSeqDataDescriptor_t qDesc,
+                          const void *queries,
+                          const void *residuals,
+                          const cudnnSeqDataDescriptor_t kDesc,
+                          const void *keys,
+                          const cudnnSeqDataDescriptor_t vDesc,
+                          const void *values,
+                          const cudnnSeqDataDescriptor_t oDesc,
+                          void *out,
+                          size_t weightSizeInBytes,
+                          const void *w,
+                          size_t workSpaceSizeInBytes,
+                          void *workSpace,
+                          size_t reserveSpaceSizeInBytes,
+                          void *reserveSpace) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, int, const int *, const int *, const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, void *, size_t, const void *, size_t, void *, size_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnForward");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, attnDesc, currIdx, loWinIdx, hiWinIdx, seqLengthArrayQRO, seqLengthArrayKV, qDesc, queries, residuals, kDesc, keys, vDesc, values, oDesc, out, weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
+                               const cudnnAttnDescriptor_t attnDesc,
+                               const int *loWinIdx,
+                               const int *hiWinIdx,
+                               const int *seqLengthArrayDQDO,
+                               const int *seqLengthArrayDKDV,
+                               const cudnnSeqDataDescriptor_t doDesc,
+                               const void *dout,
+                               const cudnnSeqDataDescriptor_t dqDesc,
+                               void *dqueries,
+                               const void *queries,
+                               const cudnnSeqDataDescriptor_t dkDesc,
+                               void *dkeys,
+                               const void *keys,
+                               const cudnnSeqDataDescriptor_t dvDesc,
+                               void *dvalues,
+                               const void *values,
+                               size_t weightSizeInBytes,
+                               const void *w,
+                               size_t workSpaceSizeInBytes,
+                               void *workSpace,
+                               size_t reserveSpaceSizeInBytes,
+                               void *reserveSpace) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, const int *, const int *, const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, size_t, const void *, size_t, void *, size_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnBackwardData");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, attnDesc, loWinIdx, hiWinIdx, seqLengthArrayDQDO, seqLengthArrayDKDV, doDesc, dout, dqDesc, dqueries, queries, dkDesc, dkeys, keys, dvDesc, dvalues, values, weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
+                                  const cudnnAttnDescriptor_t attnDesc,
+                                  cudnnWgradMode_t addGrad,
+                                  const cudnnSeqDataDescriptor_t qDesc,
+                                  const void *queries,
+                                  const cudnnSeqDataDescriptor_t kDesc,
+                                  const void *keys,
+                                  const cudnnSeqDataDescriptor_t vDesc,
+                                  const void *values,
+                                  const cudnnSeqDataDescriptor_t doDesc,
+                                  const void *dout,
+                                  size_t weightSizeInBytes,
+                                  const void *w,
+                                  void *dw,
+                                  size_t workSpaceSizeInBytes,
+                                  void *workSpace,
+                                  size_t reserveSpaceSizeInBytes,
+                                  void *reserveSpace) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, cudnnWgradMode_t, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, size_t, const void *, void *, size_t, void *, size_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnBackwardWeights");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, attnDesc, addGrad, qDesc, queries, kDesc, keys, vDesc, values, doDesc, dout, weightSizeInBytes, w, dw, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc, compType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
+                            cudnnDataType_t compType,
+                            cudnnLossNormalizationMode_t normMode,
+                            cudnnNanPropagation_t gradMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t, cudnnLossNormalizationMode_t, cudnnNanPropagation_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptorEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc, compType, normMode, gradMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc, compType);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
+                            cudnnDataType_t *compType,
+                            cudnnLossNormalizationMode_t *normMode,
+                            cudnnNanPropagation_t *gradMode) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *, cudnnLossNormalizationMode_t *, cudnnNanPropagation_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptorEx");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc, compType, normMode, gradMode);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(ctcLossDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCTCLoss(
+    cudnnHandle_t handle,
+    const cudnnTensorDescriptor_t
+        probsDesc,     /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
+                          mini batch size, A is the alphabet size)  */
+    const void *probs, /* probabilities after softmax, in GPU memory */
+    const int *labels, /* labels, in CPU memory */
+    const int *labelLengths,                     /* the length of each label, in CPU memory */
+    const int *inputLengths,                     /* the lengths of timing steps in each batch, in CPU memory */
+    void *costs,                                 /* the returned costs of CTC, in GPU memory */
+    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
+    const void *gradients,   /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
+    cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
+    cudnnCTCLossDescriptor_t ctcLossDesc,
+    void *workspace,              /* pointer to the workspace, in GPU memory */
+    size_t workSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetCTCLossWorkspaceSize(
+    cudnnHandle_t handle,
+    const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
+                                                timing steps, N is the mini batch size, A is the alphabet size) */
+    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
+                                                    dimensions are T,N,A. To compute costs
+                                                    only, set it to NULL */
+    const int *labels,                           /* labels, in CPU memory */
+    const int *labelLengths,                     /* the length of each label, in CPU memory */
+    const int *inputLengths,                     /* the lengths of timing steps in each batch, in CPU memory */
+    cudnnCTCLossAlgo_t algo,                     /* algorithm selected, supported now 0 and 1 */
+    cudnnCTCLossDescriptor_t ctcLossDesc,
+    size_t *sizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateAlgorithmDescriptor(cudnnAlgorithmDescriptor_t *algoDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoDesc, algorithm);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoDesc, algorithm);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithmDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCopyAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(src, dest);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmDescriptor");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmPerformance");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoPerf, numberToCreate);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetAlgorithmPerformance(cudnnAlgorithmPerformance_t algoPerf,
+                             cudnnAlgorithmDescriptor_t algoDesc,
+                             cudnnStatus_t status,
+                             float time,
+                             size_t memory) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t, cudnnStatus_t, float, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmPerformance");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoPerf, algoDesc, status, time, memory);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetAlgorithmPerformance(const cudnnAlgorithmPerformance_t algoPerf,
+                             cudnnAlgorithmDescriptor_t *algoDesc,
+                             cudnnStatus_t *status,
+                             float *time,
+                             size_t *memory) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, cudnnStatus_t *, float *, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmPerformance");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoPerf, algoDesc, status, time, memory);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(algoPerf, numberToDestroy);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetAlgorithmSpaceSize(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, size_t *algoSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmSpaceSize");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, algoDesc, algoSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSaveAlgorithm(cudnnHandle_t handle,
+                   cudnnAlgorithmDescriptor_t algoDesc,
+                   void *algoSpace,
+                   size_t algoSpaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSaveAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, algoDesc, algoSpace, algoSpaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnRestoreAlgorithm(cudnnHandle_t handle,
+                      void *algoSpace,
+                      size_t algoSpaceSizeInBytes,
+                      cudnnAlgorithmDescriptor_t algoDesc) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, cudnnAlgorithmDescriptor_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreAlgorithm");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, algoSpace, algoSpaceSizeInBytes, algoDesc);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCallback");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(mask, udata, fptr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCallback");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(mask, udata, fptr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t *, cudnnFusedOps_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsConstParamPack");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(constPack, ops);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsConstParamPack");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(constPack);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
+                                        cudnnFusedOpsConstParamLabel_t paramLabel,
+                                        const void *param) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t, cudnnFusedOpsConstParamLabel_t, const void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFusedOpsConstParamPackAttribute");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(constPack, paramLabel, param);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
+                                        cudnnFusedOpsConstParamLabel_t paramLabel,
+                                        void *param,
+                                        int *isNULL) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFusedOpsConstParamPack_t, cudnnFusedOpsConstParamLabel_t, void *, int *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFusedOpsConstParamPackAttribute");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(constPack, paramLabel, param, isNULL);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t *, cudnnFusedOps_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsVariantParamPack");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(varPack, ops);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsVariantParamPack");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(varPack);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
+                                          cudnnFusedOpsVariantParamLabel_t paramLabel,
+                                          void *ptr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t, cudnnFusedOpsVariantParamLabel_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFusedOpsVariantParamPackAttribute");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(varPack, paramLabel, ptr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
+                                          cudnnFusedOpsVariantParamLabel_t paramLabel,
+                                          void *ptr) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFusedOpsVariantParamPack_t, cudnnFusedOpsVariantParamLabel_t, void *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFusedOpsVariantParamPackAttribute");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(varPack, paramLabel, ptr);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsPlan_t *, cudnnFusedOps_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(plan, ops);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsPlan_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(plan);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
+                      cudnnFusedOpsPlan_t plan,
+                      const cudnnFusedOpsConstParamPack_t constPack,
+                      size_t *workspaceSizeInBytes) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnFusedOpsPlan_t, const cudnnFusedOpsConstParamPack_t, size_t *);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMakeFusedOpsPlan");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, plan, constPack, workspaceSizeInBytes);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFusedOpsPlan_t, cudnnFusedOpsVariantParamPack_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFusedOpsExecute");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, plan, varPack);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
+                         cudnnRNNDescriptor_t rnnDesc,
+                         const int hiddenSize,
+                         const int numLayers,
+                         cudnnDropoutDescriptor_t dropoutDesc,
+                         cudnnRNNInputMode_t inputMode,
+                         cudnnDirectionMode_t direction,
+                         cudnnRNNMode_t mode,
+                         cudnnRNNAlgo_t algo,
+                         cudnnDataType_t mathPrec) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec);
+}
+
+cudnnStatus_t CUDNNWINAPI
+cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc,
+                         int hiddenSize,
+                         int numLayers,
+                         cudnnDropoutDescriptor_t dropoutDesc,
+                         cudnnRNNInputMode_t inputMode,
+                         cudnnDirectionMode_t direction,
+                         cudnnRNNMode_t mode,
+                         cudnnDataType_t mathPrec) {
+  using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t);
+  static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5");
+  if (!func_ptr) return GetSymbolNotFoundError();
+  return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, mathPrec);
+}
+
+}  // extern "C"
diff --git a/tensorflow/stream_executor/cuda/cudnn_stub.cc b/tensorflow/stream_executor/cuda/cudnn_stub.cc
index 3b567c1..5a05437 100644
--- a/tensorflow/stream_executor/cuda/cudnn_stub.cc
+++ b/tensorflow/stream_executor/cuda/cudnn_stub.cc
@@ -57,6 +57,8 @@
 #include "tensorflow/stream_executor/cuda/cudnn_7_1.inc"
 #elif CUDNN_MINOR < 4
 #include "tensorflow/stream_executor/cuda/cudnn_7_3.inc"
-#else
+#elif CUDNN_MINOR < 6
 #include "tensorflow/stream_executor/cuda/cudnn_7_4.inc"
+#else
+#include "tensorflow/stream_executor/cuda/cudnn_7_6.inc"
 #endif
diff --git a/tensorflow/stream_executor/cuda/redzone_allocator.cc b/tensorflow/stream_executor/cuda/redzone_allocator.cc
index 76ff86c..afd4f57 100644
--- a/tensorflow/stream_executor/cuda/redzone_allocator.cc
+++ b/tensorflow/stream_executor/cuda/redzone_allocator.cc
@@ -45,26 +45,28 @@
 using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus;
 
 RedzoneAllocator::RedzoneAllocator(
-    int device_ordinal, DeviceMemoryAllocator* memory_allocator,
-    cuda::PtxCompilationOptions ptx_compilation_opts, uint64 redzone_size,
-    uint8 redzone_pattern)
-    : device_ordinal_(device_ordinal),
+    Stream* stream, DeviceMemoryAllocator* memory_allocator,
+    cuda::PtxCompilationOptions ptx_compilation_opts, int64 memory_limit,
+    int64 redzone_size, uint8 redzone_pattern)
+    : device_ordinal_(stream->parent()->device_ordinal()),
+      stream_(stream),
+      memory_limit_(memory_limit),
       redzone_size_(RoundUpToNearest(
           redzone_size,
-          static_cast<uint64>(tensorflow::Allocator::kAllocatorAlignment))),
+          static_cast<int64>(tensorflow::Allocator::kAllocatorAlignment))),
       redzone_pattern_(redzone_pattern),
       memory_allocator_(memory_allocator),
       ptx_compilation_opts_(ptx_compilation_opts) {}
 
 port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
-    Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
+  if (byte_size > GetMemoryLimitInBytes()) {
     return port::Status(
         port::error::RESOURCE_EXHAUSTED,
         absl::StrFormat(
             "Allocating %d bytes exceeds the memory limit of %d bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
+            byte_size, GetMemoryLimitInBytes()));
   }
 
   int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
@@ -78,10 +80,10 @@
   static_assert(sizeof(uint8) == 1, "Unexpected size");
   DeviceMemory<uint8> allocated_buffer_memory(*allocated_buffer);
 
-  DeviceMemory<uint8> lhs_redzone = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> lhs_redzone = stream_->parent()->GetSubBuffer(
       &allocated_buffer_memory, 0, redzone_size_);
 
-  DeviceMemory<uint8> data_chunk = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> data_chunk = stream_->parent()->GetSubBuffer(
       &allocated_buffer_memory, redzone_size_, byte_size);
 
   // Split up the RHS redzone into two pieces:
@@ -89,10 +91,10 @@
   //  - redzone_size_ bytes.
   // We do this because Stream::ThenMemset32 requires the buffer address and
   // size to be aligned to 4 bytes.
-  DeviceMemory<uint8> rhs_redzone_slop = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> rhs_redzone_slop = stream_->parent()->GetSubBuffer(
       &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop);
 
-  DeviceMemory<uint8> rhs_redzone_nonslop = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> rhs_redzone_nonslop = stream_->parent()->GetSubBuffer(
       &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop,
       redzone_size_);
 
@@ -100,11 +102,11 @@
                          redzone_pattern_};
   uint32 pattern32;
   std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
-  stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
+  stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
   if (rhs_slop != 0) {
-    stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
+    stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
   }
-  stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
+  stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
 
   allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
   return data_chunk;
@@ -295,9 +297,8 @@
   return RedzoneCheckStatus::OK();
 }
 
-port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones(
-    Stream* stream) const {
-  StreamExecutor* executor = stream->parent();
+port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones() const {
+  StreamExecutor* executor = stream_->parent();
 
   absl::Span<const uint8> compiled_ptx = {};
   port::StatusOr<absl::Span<const uint8>> compiled_ptx_or =
@@ -316,7 +317,7 @@
 
   ScopedDeviceMemory<uint64> out_param =
       executor->AllocateOwnedScalar<uint64>();
-  stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
+  stream_->ThenMemZero(out_param.ptr(), sizeof(uint64));
 
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<ComparisonKernelT> comparison_kernel,
@@ -327,7 +328,7 @@
   for (const auto& buf_and_size : allocated_buffers_) {
     TF_ASSIGN_OR_RETURN(
         RedzoneCheckStatus redzone_status,
-        CheckRedzonesForBuffer(stream, *buf_and_size.first, out_param.cref(),
+        CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(),
                                *comparison_kernel, buf_and_size.second,
                                redzone_size_, redzone_pattern_));
     if (!redzone_status.ok()) {
diff --git a/tensorflow/stream_executor/cuda/redzone_allocator.h b/tensorflow/stream_executor/cuda/redzone_allocator.h
index 42ddd99..d09a5c0 100644
--- a/tensorflow/stream_executor/cuda/redzone_allocator.h
+++ b/tensorflow/stream_executor/cuda/redzone_allocator.h
@@ -39,21 +39,24 @@
 // memory for cudnn convolutions.
 class RedzoneAllocator : public ScratchAllocator {
  public:
-  RedzoneAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator,
+  static const int64 kDefaultMemoryLimit = 1LL << 32;  // 4GB
+  static const int64 kDefaultRedzoneSize =
+      1LL << 23;  // 8MiB per side, 16MiB total.
+  static const uint8 kDefaultRedzonePattern = -1;
+  RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator,
                    cuda::PtxCompilationOptions ptx_compilation_opts,
-                   uint64 redzone_size = 1 << 23,  // 8MiB per side, 16MiB total
-                   uint8 redzone_pattern = -1);
+                   int64 memory_limit = kDefaultMemoryLimit,
+                   int64 redzone_size = kDefaultRedzoneSize,
+                   uint8 redzone_pattern = kDefaultRedzonePattern);
 
   // Redzones don't count towards the memory limit.
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
-    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
-  }
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
+
   int64 TotalAllocatedBytesExcludingRedzones() const {
     return allocated_bytes_excluding_redzones_;
   }
 
-  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                                    int64 byte_size) override;
+  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
 
   // Non-empty redzone check status implies that there was a write into a
   // redzone, with a string communicating the location of the write.
@@ -92,12 +95,16 @@
   //  - RedzoneCheckStatus with a non-empty error message iff a write into a
   //    redzone has been detected.
   //  - A stream error, if loading or launching the kernel has failed.
-  port::StatusOr<RedzoneCheckStatus> CheckRedzones(Stream* stream) const;
+  port::StatusOr<RedzoneCheckStatus> CheckRedzones() const;
 
  private:
   const int device_ordinal_;
+  Stream* stream_;
 
-  // Redzone size on *one side* of allocation.
+  // Memory limit of the allocator in bytes.
+  const int64 memory_limit_;
+
+  // Redzone size on *one side* of allocation in bytes.
   //
   // Must be a multiple of kXlaAllocatedBufferAlignBytes, otherwise the buffers
   // returned to users will be misaligned.
diff --git a/tensorflow/stream_executor/cuda/redzone_allocator_test.cc b/tensorflow/stream_executor/cuda/redzone_allocator_test.cc
index 23fee51..97aa2c9 100644
--- a/tensorflow/stream_executor/cuda/redzone_allocator_test.cc
+++ b/tensorflow/stream_executor/cuda/redzone_allocator_test.cc
@@ -55,15 +55,17 @@
   StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie();
   cuda::PtxCompilationOptions opts;
   StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
-  RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts,
-                             kRedzoneSize, kRedzonePattern);
 
   Stream stream(stream_exec);
   stream.Init();
+  RedzoneAllocator allocator(
+      &stream, &se_allocator, opts,
+      /*memory_limit=*/RedzoneAllocator::kDefaultMemoryLimit,
+      /*redzone_size=*/kRedzoneSize,
+      /*redzone_pattern=*/kRedzonePattern);
   TF_ASSERT_OK_AND_ASSIGN(DeviceMemory<uint8> buf,
-                          allocator.AllocateBytes(&stream,
-                                                  /*byte_size=*/kAllocSize));
-  EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+                          allocator.AllocateBytes(/*byte_size=*/kAllocSize));
+  EXPECT_REDZONE_OK(allocator.CheckRedzones());
 
   char* buf_addr = reinterpret_cast<char*>(buf.opaque());
   DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize);
@@ -100,15 +102,13 @@
     DeviceMemoryBase redzone_at_offset(
         reinterpret_cast<char*>(redzone.opaque()) + offset, 1);
     char old_redzone_value = 0;
-    {
-      EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
-    }
+    { EXPECT_REDZONE_OK(allocator.CheckRedzones()); }
     stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)
         .ThenMemZero(&redzone_at_offset, 1);
-    EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones(&stream));
+    EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones());
 
     // Checking reinitializes the redzone.
-    EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+    EXPECT_REDZONE_OK(allocator.CheckRedzones());
   };
 
   modify_redzone(lhs_redzone, /*offset=*/0, "lhs");
@@ -130,12 +130,15 @@
   StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie();
   cuda::PtxCompilationOptions opts;
   StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
-  RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts,
-                             kRedzoneSize, /*redzone_pattern=*/-1);
   Stream stream(stream_exec);
   stream.Init();
-  (void)allocator.AllocateBytes(&stream, /*byte_size=*/1);
-  EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+  RedzoneAllocator allocator(
+      &stream, &se_allocator, opts,
+      /*memory_limit=*/RedzoneAllocator::kDefaultMemoryLimit,
+      /*redzone_size=*/kRedzoneSize,
+      /*redzone_pattern=*/-1);
+  (void)allocator.AllocateBytes(/*byte_size=*/1);
+  EXPECT_REDZONE_OK(allocator.CheckRedzones());
 }
 
 }  // namespace
diff --git a/tensorflow/stream_executor/device_memory_allocator.h b/tensorflow/stream_executor/device_memory_allocator.h
index c9213cf..9d30969 100644
--- a/tensorflow/stream_executor/device_memory_allocator.h
+++ b/tensorflow/stream_executor/device_memory_allocator.h
@@ -147,9 +147,10 @@
 // Type alias for compatibility with the previous managed memory implementation.
 using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
 
-// Interface for device memory allocators used within the XLA service. An
-// allocator is responsible for allocating memory on all devices of a particular
-// platform.
+// Memory allocator interface for the device.
+//
+// Intended usage is through Allocate() functions which return an owning smart
+// pointer.
 class DeviceMemoryAllocator {
  public:
   // Parameter platform indicates which platform the allocator allocates memory
@@ -186,7 +187,9 @@
     return Allocate(device_ordinal, size, retry_on_failure);
   }
 
-  // Must be a nop for null pointers.
+  // Must be a nop for null pointers. Should not be used.
+  //
+  // TODO(cheshire): Add deprecation notice.
   virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;
 
   // Return the platform that the allocator allocates memory on.
@@ -194,7 +197,7 @@
 
   // Can we call Deallocate() as soon as a computation has been scheduled on
   // a stream, or do we have to wait for the computation to complete first?
-  virtual bool AllowsAsynchronousDeallocation() const = 0;
+  virtual bool AllowsAsynchronousDeallocation() const { return false; }
 
  protected:
   const Platform* platform_;
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 7837c8e..77045ef 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -2095,6 +2095,7 @@
   //  state_allocator: an memory allocator that will be used to store the state
   //    for dropout layer. The user has to maintain the memory until the model
   //    is no longer in use.
+  //  use_padded_io: a bool to specify whether the input is using padded IO.
   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
                       int cell_size, int batch_size,
@@ -2103,7 +2104,7 @@
                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
                       const dnn::AlgorithmConfig& algorithm_config,
                       float dropout, uint64 seed,
-                      ScratchAllocator* state_allocator) {
+                      ScratchAllocator* state_allocator, bool use_padded_io) {
     return port::Status(port::error::UNIMPLEMENTED,
                         "createRnnDescriptor is unimplemented");
   }
diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD
index 1981490..cd598b4 100644
--- a/tensorflow/stream_executor/gpu/BUILD
+++ b/tensorflow/stream_executor/gpu/BUILD
@@ -6,7 +6,7 @@
     "if_gpu_is_configured",
 )
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
diff --git a/tensorflow/stream_executor/gpu/gpu_driver.h b/tensorflow/stream_executor/gpu/gpu_driver.h
index 57a87fb..5de443e 100644
--- a/tensorflow/stream_executor/gpu/gpu_driver.h
+++ b/tensorflow/stream_executor/gpu/gpu_driver.h
@@ -197,19 +197,18 @@
   // TODO(leary) describe the structure of kernel_params and extra in a readable
   // way.
   // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
-  static bool LaunchKernel(GpuContext* context, GpuFunctionHandle function,
-                           unsigned int grid_dim_x, unsigned int grid_dim_y,
-                           unsigned int grid_dim_z, unsigned int block_dim_x,
-                           unsigned int block_dim_y, unsigned int block_dim_z,
-                           unsigned int shared_mem_bytes,
-                           GpuStreamHandle stream, void** kernel_params,
-                           void** extra);
+  static port::Status LaunchKernel(
+      GpuContext* context, GpuFunctionHandle function, unsigned int grid_dim_x,
+      unsigned int grid_dim_y, unsigned int grid_dim_z,
+      unsigned int block_dim_x, unsigned int block_dim_y,
+      unsigned int block_dim_z, unsigned int shared_mem_bytes,
+      GpuStreamHandle stream, void** kernel_params, void** extra);
 
   // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting
   // handle in "module". Any error logs that are produced are logged internally.
   // (supported on CUDA only)
-  static bool LoadPtx(GpuContext* context, const char* ptx_contents,
-                      GpuModuleHandle* module);
+  static port::Status LoadPtx(GpuContext* context, const char* ptx_contents,
+                              GpuModuleHandle* module);
 
   // Loads cubin_bytes with the CUDA driver's blob loading interface and stores
   // the resulting handle in "module".
diff --git a/tensorflow/stream_executor/gpu/gpu_executor.h b/tensorflow/stream_executor/gpu/gpu_executor.h
index 2149f13..c61bd73 100644
--- a/tensorflow/stream_executor/gpu/gpu_executor.h
+++ b/tensorflow/stream_executor/gpu/gpu_executor.h
@@ -61,17 +61,17 @@
 
   port::Status Init(int device_ordinal, DeviceOptions device_options) override;
 
-  bool GetKernel(const MultiKernelLoaderSpec& spec,
-                 KernelBase* kernel) override;
+  port::Status GetKernel(const MultiKernelLoaderSpec& spec,
+                         KernelBase* kernel) override;
   // (supported on CUDA only)
   void UnloadKernel(const KernelBase* kernel) override;
-  bool LoadModule(const MultiModuleLoaderSpec& spec,
-                  ModuleHandle* module_handle) override;
+  port::Status LoadModule(const MultiModuleLoaderSpec& spec,
+                          ModuleHandle* module_handle) override;
   bool UnloadModule(ModuleHandle module_handle) override;
 
-  bool Launch(Stream* stream, const ThreadDim& thread_dims,
-              const BlockDim& block_dims, const KernelBase& k,
-              const KernelArgsArrayBase& args) override;
+  port::Status Launch(Stream* stream, const ThreadDim& thread_dims,
+                      const BlockDim& block_dims, const KernelBase& k,
+                      const KernelArgsArrayBase& args) override;
 
   // (supported on CUDA only)
   int CalculateOccupancy(const DeviceDescription& device_description,
@@ -271,12 +271,12 @@
                          const BlockDim& block_dims);
 
   // (supported on CUDA only)
-  bool LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module)
+  port::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module)
       EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
 
   // Loads the PTX text `ptx` as a CUDA module.  `ptx` must be null terminated.
   // (supported on CUDA only)
-  bool LoadModuleFromPtx(const char* ptx, GpuModuleHandle* module)
+  port::Status LoadModuleFromPtx(const char* ptx, GpuModuleHandle* module)
       EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
 
   // (supported on ROCm only)
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index d24eec6..75f5431 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -254,6 +254,9 @@
       tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
   builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);
 
+  builder.set_name("Host");
+  builder.set_platform_version("Default Version");
+
   return builder.Build();
 }
 
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index dfe43e1..d0cc004 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -50,14 +50,14 @@
     return port::Status::OK();
   }
 
-  bool GetKernel(const MultiKernelLoaderSpec &spec,
-                 KernelBase *kernel) override {
-    return false;
+  port::Status GetKernel(const MultiKernelLoaderSpec &spec,
+                         KernelBase *kernel) override {
+    return port::UnimplementedError("Not Implemented");
   }
-  bool Launch(Stream *stream, const ThreadDim &thread_dims,
-              const BlockDim &block_dims, const KernelBase &kernel,
-              const KernelArgsArrayBase &args) override {
-    return false;
+  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
+                      const BlockDim &block_dims, const KernelBase &kernel,
+                      const KernelArgsArrayBase &args) override {
+    return port::UnimplementedError("Not Implemented");
   }
 
   void *Allocate(uint64 size) override;
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index 9384db6..1e4f375 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -525,16 +525,19 @@
   // structure.
   void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                   Params &... params) const {
-    PackOneParam(args, params...);
+    PackOneParamFromList(args, params...);
   }
 
   template <typename T, typename... RestOfParams>
-  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
-                    const RestOfParams &... rest) const {
+  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args,
+                            const T &arg, const RestOfParams &... rest) const {
     PackOneParam(args, arg);
-    PackOneParam(args, rest...);
+    PackOneParamFromList(args, rest...);
   }
 
+  // Base case for variadic template expansion - nothing to do!
+  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {}
+
   // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
   // The enable_if<> is for excluding DeviceMemoryBase args, which have a
   // separate implementation below.
@@ -581,9 +584,6 @@
     args->add_shared_bytes(arg.size());
   }
 
-  // Base case for variadic template expansion - nothing to do!
-  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}
-
   SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
 };
 
diff --git a/tensorflow/stream_executor/platform/BUILD b/tensorflow/stream_executor/platform/BUILD
index f5071d1..f9540db 100644
--- a/tensorflow/stream_executor/platform/BUILD
+++ b/tensorflow/stream_executor/platform/BUILD
@@ -1,5 +1,4 @@
 load("//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_platform_hdrs")
 
 package(
     default_visibility = [":friends"],
diff --git a/tensorflow/stream_executor/platform/default/BUILD b/tensorflow/stream_executor/platform/default/BUILD
index e039b5e..51170e4 100644
--- a/tensorflow/stream_executor/platform/default/BUILD
+++ b/tensorflow/stream_executor/platform/default/BUILD
@@ -2,7 +2,7 @@
 
 package(default_visibility = ["//tensorflow/stream_executor:__subpackages__"])
 
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 load("//tensorflow:tensorflow.bzl", "tf_copts")
 
 cc_library(
diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index 71139c0..008de9e 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -8,7 +8,7 @@
 )
 load("//tensorflow:tensorflow.bzl", "tf_copts")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static")
 
 package(
     default_visibility = [":friends"],
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 52c6546..a5a588b 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -1849,7 +1849,7 @@
   if (scratch_allocator != nullptr) {
     SE_ASSIGN_OR_RETURN(
         DeviceMemory<uint8> batch_matrix_bytes,
-        scratch_allocator->AllocateBytes(stream, matrix_batch_byte_size));
+        scratch_allocator->AllocateBytes(matrix_batch_byte_size));
     *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
   } else {
     assert(temp_memory != nullptr);
@@ -2407,6 +2407,11 @@
              << "for the \"complex<double>\" dataype";
   return false;
 }
+
+port::Status ROCMBlas::GetVersion(string *version) {
+  return port::UnimplementedError("");
+}
+
 }  // namespace gpu
 
 void initialize_rocblas() {
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc
index efe49dd..52a98b6 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.cc
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc
@@ -1985,7 +1985,7 @@
   // Allocate the workspace.
   if (workspace_size_in_bytes > 0) {
     auto allocated =
-        workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate RNN workspace";
 
@@ -2062,8 +2062,8 @@
     }
 
     if (reserve_space_size_in_bytes > 0) {
-      auto allocated = reserve_space_allocator->AllocateBytes(
-          stream, reserve_space_size_in_bytes);
+      auto allocated =
+          reserve_space_allocator->AllocateBytes(reserve_space_size_in_bytes);
       if (!allocated.ok() ||
           (reserve_space = allocated.ValueOrDie()) == nullptr) {
         LOG(ERROR) << "Fail to allocate RNN reserve space";
@@ -2280,7 +2280,8 @@
     int batch_size, dnn::RnnInputMode input_mode,
     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
-    float dropout, uint64 seed, ScratchAllocator* state_allocator) {
+    float dropout, uint64 seed, ScratchAllocator* state_allocator,
+    bool use_padded_io) {
   // ROCM TODO: cell_size is ignored for now
   // ROCM TODO: batch_size is ignored for now
 
@@ -2575,8 +2576,7 @@
 
 void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
   auto* mac = static_cast<MIOpenAllocatorContext*>(ctx);
-  auto allocated =
-      mac->scratch_allocator_->AllocateBytes(mac->stream_, size_in_bytes);
+  auto allocated = mac->scratch_allocator_->AllocateBytes(size_in_bytes);
 
   DeviceMemory<uint8> scratch;
   if (allocated.ok()) {
@@ -2659,7 +2659,7 @@
     }
 
     if (status == miopenStatusSuccess && size_in_bytes != 0) {
-      auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
+      auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
       if (allocated.ok()) {
         scratch_memory_temp = allocated.ValueOrDie();
       }
@@ -2744,8 +2744,7 @@
           absl::StrCat("An allocator must be specified when scratch memory is "
                        "needed"));
     }
-    auto allocated =
-        scratch_allocator->AllocateBytes(stream, scratch_memory_size);
+    auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
     if (!allocated.ok()) {
       return port::InternalError(absl::StrCat(
           "Failed to allocate scratch memory of size: ", scratch_memory_size));
@@ -3600,7 +3599,7 @@
   if (workspace_size_in_bytes > 0) {
     assert(workspace_allocator);
     auto allocated =
-        workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate backward pooling workspace";
       return false;
@@ -3624,7 +3623,7 @@
 
   if (dest2_size > 0) {
     assert(workspace_allocator);
-    auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
+    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate backward pooling workspace";
       return false;
@@ -3696,7 +3695,7 @@
   if (workspace_size_in_bytes > 0) {
     assert(workspace_allocator);
     auto allocated =
-        workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate backward pooling workspace";
       return false;
@@ -3720,7 +3719,7 @@
 
   if (dest2_size > 0) {
     assert(workspace_allocator);
-    auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
+    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate backward pooling workspace";
       return false;
@@ -3831,7 +3830,7 @@
   if (workspace_size_in_bytes > 0) {
     assert(workspace_allocator);
     auto allocated =
-        workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "Failed to allocate backward pooling workspace";
       return false;
@@ -3856,7 +3855,7 @@
 
   if (dest2_size > 0) {
     assert(workspace_allocator);
-    auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
+    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR)
           << "Failed to allocate tensor to chain forward and backward LRN";
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h
index 5bc0914..b955480 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.h
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.h
@@ -50,7 +50,8 @@
       int batch_size, dnn::RnnInputMode input_mode,
       dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
       dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
-      float dropout, uint64 seed, ScratchAllocator* state_allocator) override;
+      float dropout, uint64 seed, ScratchAllocator* state_allocator,
+      bool use_padded_io) override;
 
   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
   createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc
index 7cd35ff..1aae9f2 100644
--- a/tensorflow/stream_executor/rocm/rocm_driver.cc
+++ b/tensorflow/stream_executor/rocm/rocm_driver.cc
@@ -419,7 +419,7 @@
   return port::Status::OK();
 }
 
-/* static */ bool GpuDriver::LaunchKernel(
+/* static */ port::Status GpuDriver::LaunchKernel(
     GpuContext* context, hipFunction_t function, unsigned int grid_dim_x,
     unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x,
     unsigned int block_dim_y, unsigned int block_dim_z,
@@ -434,19 +434,18 @@
       function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
       block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
   if (res != hipSuccess) {
-    LOG(ERROR) << "failed to launch ROCM kernel: " << function
-               << "; result: " << ToString(res);
-    return false;
+    return port::InternalError(
+        absl::StrCat("Failed to launch ROCM kernel: ", ToString(res)));
   }
   VLOG(2) << "successfully launched kernel";
-  return true;
+  return port::Status::OK();
 }
 
-/* static */ bool GpuDriver::LoadPtx(GpuContext* context,
-                                     const char* ptx_contents,
-                                     hipModule_t* module) {
+/* static */ port::Status GpuDriver::LoadPtx(GpuContext* context,
+                                             const char* ptx_contents,
+                                             hipModule_t* module) {
   LOG(ERROR) << "Feature not supported on ROCm platform (LoadPtx)";
-  return false;
+  return port::InternalError("Not Implemented");
 }
 
 /* static */ port::Status GpuDriver::LoadCubin(GpuContext* context,
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
index d2c542f..82dce9e 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.cc
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -272,8 +272,7 @@
       // TODO(yangzihao): refactor this code and the one with the same function
       // in the batch mode.
       if (size_in_bytes != 0) {
-        auto allocated =
-            scratch_allocator->AllocateBytes(stream, size_in_bytes);
+        auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
         if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
           LOG(ERROR) << "failed to allocate work area.";
           return allocated.status();
@@ -328,8 +327,7 @@
                             "Failed to make rocFFT bacthed plan."};
       }
       if (size_in_bytes != 0) {
-        auto allocated =
-            scratch_allocator->AllocateBytes(stream, size_in_bytes);
+        auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
         if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
           LOG(ERROR) << "failed to allocate work area.";
           return allocated.status();
diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
index e37d6d2..d1ee42e 100644
--- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
+++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
@@ -230,8 +230,8 @@
   return exe_path;
 }
 
-bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
-                            KernelBase* kernel) {
+port::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
+                                    KernelBase* kernel) {
   GpuKernel* rocm_kernel = AsGpuKernel(kernel);
   hipModule_t module = nullptr;
   const string* kernelname;
@@ -243,8 +243,8 @@
   }
 
   if (on_disk_spec != nullptr) {
-    LOG(WARNING) << "loading ROCM kernel from disk is not supported";
-    return false;
+    return port::InternalError(
+        "Loading ROCM kernel from disk is not supported");
   } else if (spec.has_cuda_cubin_in_memory()) {
     kernelname = &spec.cuda_cubin_in_memory().kernelname();
 
@@ -254,20 +254,18 @@
 
     if (module == nullptr) {
       if (!GpuDriver::LoadHsaco(context_, hsaco, &module)) {
-        LOG(ERROR) << "failed to load HSACO\n";
-        return false;
+        return port::InternalError("Failed to load HSACO");
       }
     }
     kernel_to_gpu_binary_[kernel] = hsaco;
   } else {
-    LOG(WARNING) << "no method of loading ROCM kernel provided";
-    return false;
+    return port::InternalError("No method of loading ROCM kernel provided");
   }
 
   VLOG(2) << "getting function " << *kernelname << " from module " << module;
   if (!GpuDriver::GetModuleFunction(context_, module, kernelname->c_str(),
                                     rocm_kernel->gpu_function_ptr())) {
-    return false;
+    return port::InternalError("Failed getting module function");
   }
 
   // We have to trust the kernel loader spec arity because there doesn't appear
@@ -280,7 +278,7 @@
   }
   kernel->set_metadata(kernel_metadata);
   kernel->set_name(*kernelname);
-  return true;
+  return port::Status::OK();
 }
 
 bool GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel,
@@ -295,9 +293,10 @@
   return true;
 }
 
-bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                         const BlockDim& block_dims, const KernelBase& kernel,
-                         const KernelArgsArrayBase& args) {
+port::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
+                                 const BlockDim& block_dims,
+                                 const KernelBase& kernel,
+                                 const KernelArgsArrayBase& args) {
   CHECK_EQ(kernel.Arity(), args.number_of_arguments());
   GpuStreamHandle hipstream = AsGpuStreamValue(stream);
   const GpuKernel* rocm_kernel = AsGpuKernel(&kernel);
@@ -339,18 +338,10 @@
   void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernargs.data(),
                     HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END};
 
-  if (!GpuDriver::LaunchKernel(
-          GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y,
-          block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
-          args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config)) {
-    LOG(ERROR) << "failed to launch ROCM kernel with args: "
-               << args.number_of_arguments()
-               << "; thread dim: " << thread_dims.ToString()
-               << "; block dim: " << block_dims.ToString();
-    return false;
-  }
-
-  return true;
+  return GpuDriver::LaunchKernel(
+      GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y, block_dims.z,
+      thread_dims.x, thread_dims.y, thread_dims.z,
+      args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config);
 }
 
 int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description,
@@ -372,8 +363,8 @@
   return 0;
 }
 
-bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
-                             ModuleHandle* module_handle) {
+port::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
+                                     ModuleHandle* module_handle) {
   // In GpuExecutor we store the pointer to the  HSACO binary  as
   // ModuleHandle::id().
   hipModule_t hip_module = nullptr;
@@ -383,25 +374,24 @@
     if (!LoadModuleFromHsaco(
             reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()),
             &hip_module)) {
-      return false;
+      return port::InternalError("Failed loading module from HSACO");
     }
     *module_handle = ModuleHandle(const_cast<void*>(
         static_cast<const void*>(spec.cuda_cubin_in_memory().data())));
-    return true;
+    return port::Status::OK();
   } else {
-    LOG(ERROR) << "No HSACO binary found \n";
-    return false;
+    return port::InternalError("No HASCO binary found");
   }
 }
 
-bool GpuExecutor::LoadModuleFromCuBin(const char* cubin, hipModule_t* module) {
+port::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin,
+                                              hipModule_t* module) {
   LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromCuBin)";
-  return false;
 }
 
-bool GpuExecutor::LoadModuleFromPtx(const char* ptx, hipModule_t* module) {
+port::Status GpuExecutor::LoadModuleFromPtx(const char* ptx,
+                                            hipModule_t* module) {
   LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromPtx)";
-  return false;
 }
 
 bool GpuExecutor::LoadModuleFromHsaco(const char* hsaco, hipModule_t* module) {
diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc
index 8fc4c4c..520ee8a 100644
--- a/tensorflow/stream_executor/scratch_allocator.cc
+++ b/tensorflow/stream_executor/scratch_allocator.cc
@@ -22,18 +22,17 @@
 
 ScratchAllocator::~ScratchAllocator() {}
 
-OneTimeScratchAllocator::OneTimeScratchAllocator() {}
+OneTimeScratchAllocator::OneTimeScratchAllocator(Stream* stream)
+    : stream_(stream) {}
 OneTimeScratchAllocator::~OneTimeScratchAllocator() {}
 
-int64 OneTimeScratchAllocator::GetMemoryLimitInBytes(Stream* stream) {
-  return -1;
-}
+int64 OneTimeScratchAllocator::GetMemoryLimitInBytes() { return -1; }
 
 port::StatusOr<DeviceMemory<uint8>> OneTimeScratchAllocator::AllocateBytes(
-    Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK(temporary_ == nullptr);
   SE_ASSIGN_OR_RETURN(temporary_,
-                      stream->AllocateTemporaryArray<uint8>(byte_size));
+                      stream_->AllocateTemporaryArray<uint8>(byte_size));
   return temporary_->device_memory();
 }
 
diff --git a/tensorflow/stream_executor/scratch_allocator.h b/tensorflow/stream_executor/scratch_allocator.h
index 2aed2c4..29b4e5a 100644
--- a/tensorflow/stream_executor/scratch_allocator.h
+++ b/tensorflow/stream_executor/scratch_allocator.h
@@ -27,16 +27,12 @@
 
 class Stream;
 
-// Interface that allows stream operations (e.g.
-// Stream::ThenConvolveWithScratch) to optionally request scratch space be
-// allocated in order to speed up the operation being enqueued.
+// Interface for "scratch" allocator for device memory, which deallocates all
+// buffers it has allocated at destruction. Returned memory pointers are not
+// owning.
 //
-// Note that the caller is responsible for deallocating the scratch space at a
-// known-safe point, when all scratch-memory-consuming kernels are known for
-// sure to have finished; e.g. at stream synchronization time. This is different
-// from a traditional C++ object allocator, where the client is responsible for
-// releasing. (Conceptually, scratch memory is a form of "temporary" device
-// memory allocation.)
+// Used by stream operations (e.g. Stream::ThenConvolveWithScratch) to optonally
+// request scratch space to speed up the operation.
 class ScratchAllocator {
  public:
   virtual ~ScratchAllocator();
@@ -45,14 +41,14 @@
   // bytes. This information may be used to help select an algorithm.
   //
   // Returns values < 0 to indicate that there is no recommended limit.
-  virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0;
+  virtual int64 GetMemoryLimitInBytes() = 0;
 
   // Returns an allocation on byte_size bytes for use in an operation on stream.
   //
   // This is a temporary allocation, and the caller is responsible for
   // deallocating at some known-safe point. See the class comment above.
   virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
-      Stream* stream, int64 byte_size) = 0;
+      int64 byte_size) = 0;
 };
 
 // Allocates a single temporary memory allocation -- this memory is deallocated
@@ -64,14 +60,14 @@
 // thread will request the scratch allocation).
 class OneTimeScratchAllocator : public ScratchAllocator {
  public:
-  OneTimeScratchAllocator();
+  explicit OneTimeScratchAllocator(Stream* stream);
   ~OneTimeScratchAllocator() override;
-  int64 GetMemoryLimitInBytes(Stream* stream) override;
-  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                                    int64 byte_size) override;
+  int64 GetMemoryLimitInBytes() override;
+  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
 
  private:
   std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary_;
+  Stream* stream_;
 
   SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator);
 };
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 4619fe1..ca60a09 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -179,20 +179,21 @@
   virtual port::Status Init(int device_ordinal,
                             DeviceOptions device_options) = 0;
 
-  virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
-                         KernelBase *kernel) {
-    return false;
-  }
-  virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
-                          ModuleHandle *module_handle) {
-    return false;
+  virtual port::Status GetKernel(const MultiKernelLoaderSpec &spec,
+                                 KernelBase *kernel) {
+    return port::UnimplementedError("Not Implemented");
   }
   virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
-  virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
-                      const BlockDim &block_dims, const KernelBase &k,
-                      const KernelArgsArrayBase &args) {
-    return false;
+  virtual port::Status LoadModule(const MultiModuleLoaderSpec &spec,
+                                  ModuleHandle *module_handle) {
+    return port::UnimplementedError("Not Implemented");
   }
+  virtual port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
+                              const BlockDim &block_dims, const KernelBase &k,
+                              const KernelArgsArrayBase &args) {
+    return port::UnimplementedError("Not Implemented");
+  }
+
   // Releases any state associated with the kernel.
   virtual void UnloadKernel(const KernelBase *kernel) {}
   virtual void *Allocate(uint64 size) = 0;
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 839f1cd..f8b6655 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -188,8 +188,8 @@
 
 port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }
 
-bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
-                               KernelBase *kernel) {
+port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
+                                       KernelBase *kernel) {
   return implementation_->GetKernel(spec, kernel);
 }
 
@@ -197,8 +197,8 @@
   implementation_->UnloadKernel(kernel);
 }
 
-bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
-                                ModuleHandle *module_handle) {
+port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
+                                        ModuleHandle *module_handle) {
   return implementation_->LoadModule(spec, module_handle);
 }
 
@@ -340,7 +340,8 @@
     int batch_size, dnn::RnnInputMode input_mode,
     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
     dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
-    float dropout, uint64 seed, ScratchAllocator *state_allocator) {
+    float dropout, uint64 seed, ScratchAllocator *state_allocator,
+    bool use_padded_io) {
   dnn::DnnSupport *dnn_support = AsDnn();
   if (!dnn_support) {
     return port::Status(port::error::UNKNOWN,
@@ -349,7 +350,7 @@
   return dnn_support->createRnnDescriptor(
       num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
       direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
-      state_allocator);
+      state_allocator, use_padded_io);
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
@@ -433,10 +434,11 @@
   return rng_.get();
 }
 
-bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
-                            const BlockDim &block_dims,
-                            const KernelBase &kernel,
-                            const KernelArgsArrayBase &args) {
+port::Status StreamExecutor::Launch(Stream *stream,
+                                    const ThreadDim &thread_dims,
+                                    const BlockDim &block_dims,
+                                    const KernelBase &kernel,
+                                    const KernelArgsArrayBase &args) {
   SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
               kernel, args);
 
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index d2f2f59..efa4034 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -100,8 +100,8 @@
   //    instantiation should not be loaded into more than once.
   //
   // If an error occurs, or there is no kernel available for the StreamExecutor
-  // platform, false is returned.
-  bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
+  // platform, error status is returned.
+  port::Status GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
 
   // Releases any state associated with the previously loaded kernel.
   void UnloadKernel(const KernelBase *kernel);
@@ -109,9 +109,10 @@
   // Loads a module for the platform this StreamExecutor is acting upon.
   //
   // `spec` describes the module to be loaded.  On success writes the handle for
-  // the loaded module to `module_handle` and returns true.  Else returns false.
-  bool LoadModule(const MultiModuleLoaderSpec &spec,
-                  ModuleHandle *module_handle);
+  // the loaded module to `module_handle` and returns Status::OK.
+  // Otherwise, returns the error which has occurred.
+  port::Status LoadModule(const MultiModuleLoaderSpec &spec,
+                          ModuleHandle *module_handle);
 
   // Unloads the module with handle `module_handle`.
   bool UnloadModule(ModuleHandle module_handle);
@@ -185,9 +186,6 @@
   //
   // Resets the internal contents of mem to be null-representative, but this
   // null-out effect should not be relied upon in client code.
-  //
-  // TODO(jlebar): Change this to accept a DeviceMemoryBase by value, see
-  // discussion in cl/195744342.
   void Deallocate(DeviceMemoryBase *mem);
 
   // Retrieves a mapping of active opaque device memory pointer to a string
@@ -398,7 +396,8 @@
       int batch_size, dnn::RnnInputMode input_mode,
       dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
       dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
-      float dropout, uint64 seed, ScratchAllocator *state_allocator);
+      float dropout, uint64 seed, ScratchAllocator *state_allocator,
+      bool use_padded_io);
 
   // Create a RNN sequence descriptor that specifies either the input or output
   // sequence. The caller retains the ownership of the returned descriptor.
@@ -451,9 +450,9 @@
   //
   // This is called by Stream::Launch() to delegate to the platform's launch
   // implementation in StreamExecutorInterface::Launch().
-  bool Launch(Stream *stream, const ThreadDim &thread_dims,
-              const BlockDim &block_dims, const KernelBase &kernel,
-              const KernelArgsArrayBase &args);
+  port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
+                      const BlockDim &block_dims, const KernelBase &kernel,
+                      const KernelArgsArrayBase &args);
 
   // Gets-or-creates (creates with memoization) a FftSupport datatype that can
   // be used to execute FFT routines on the current platform.
@@ -473,6 +472,19 @@
   // underlying platform.
   dnn::DnnSupport *AsDnn();
 
+  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
+  // be used to execute BLAS routines on the current platform. This is typically
+  // not user-facing, as users will use the Stream::ThenBlas* family of routines
+  // to entrain BLAS operations. See blas.h for additional details.
+  //
+  // Ownership is not transferred to the caller -- ownership is retained by this
+  // object for memoization. This BLAS interface is also only expected to be
+  // used by a Stream for entraining calls to BLAS functionality.
+  //
+  // Returns null if there was an error initializing the BLAS support for the
+  // underlying platform.
+  blas::BlasSupport *AsBlas();
+
   // Turns StreamExecutor operation tracing on or off.
   void EnableTracing(bool enable);
 
@@ -493,9 +505,6 @@
 
   // Return an allocator which delegates to this stream executor for memory
   // allocation.
-  //
-  // Creates the allocator object on the first access, as the device ordinal
-  // of this stream_executor is not set in constructor.
   StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }
 
  private:
@@ -510,18 +519,10 @@
   template <typename... Args>
   friend struct ThenBlasImpl;
 
-  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
-  // be used to execute BLAS routines on the current platform. This is typically
-  // not user-facing, as users will use the Stream::ThenBlas* family of routines
-  // to entrain BLAS operations. See blas.h for additional details.
-  //
-  // Ownership is not transferred to the caller -- ownership is retained by this
-  // object for memoization. This BLAS interface is also only expected to be
-  // used by a Stream for entraining calls to BLAS functionality.
-  //
-  // Returns null if there was an error initializing the BLAS support for the
-  // underlying platform.
-  blas::BlasSupport *AsBlas();
+  // Synchronously allocates size bytes on the underlying platform and returns
+  // an opaque void* representing that allocation. In the case of failure,
+  // nullptr is returned.
+  void *Allocate(uint64 size);
 
   // Gets-or-creates (creates with memoization) an RngSupport datatype that can
   // be used for random-number-generation routines on the current platform.
@@ -540,11 +541,6 @@
   // Without blocking the device, retrieve the current stream status.
   port::Status GetStatus(Stream *stream);
 
-  // Synchronously allocates size bytes on the underlying platform and returns
-  // an opaque void* representing that allocation. In the case of failure,
-  // nullptr is returned.
-  void *Allocate(uint64 size);
-
   // Finds and retrieves device memory for the symbol on the underlying
   // platform.
   bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
@@ -785,10 +781,7 @@
         reinterpret_cast<const char *>(cubin_data.data()), kernel_name);
   }
 
-  if (!GetKernel(loader_spec, kernel_base.get())) {
-    return port::InternalError("Unable to load kernel");
-  }
-
+  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
   return std::move(kernel_base);
 }
 
@@ -882,7 +875,8 @@
     kernel.PackParams(&kernel_args, args...);
     DCHECK(parent_ != nullptr);
     bool ok =
-        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
+        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
+            .ok();
     if (!ok) {
       SetError();
       LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index eaa73eb..9638244 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -3,7 +3,7 @@
 # Return the options to use for a C++ library or binary build.
 # Uses the ":optmode" config_setting to pick the options.
 load(
-    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "//tensorflow/core/platform:default/build_config_root.bzl",
     "if_dynamic_kernels",
     "if_static",
     "tf_additional_grpc_deps_py",
@@ -18,7 +18,7 @@
     "if_tensorrt",
 )
 load(
-    "//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "//tensorflow/core/platform:default/cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
 load(
@@ -80,7 +80,7 @@
 # i.e. "common_runtime/direct_session_test.cc" becomes
 #      "common_runtime_direct_session_test"
 def src_to_test_name(src):
-    return src.replace("/", "_").split(".")[0]
+    return src.replace("/", "_").replace(":", "_").split(".")[0]
 
 def full_path(relative_paths):
     return [native.package_name() + "/" + relative for relative in relative_paths]
@@ -421,9 +421,21 @@
 
 def tf_binary_pybind_deps():
     return select({
-        clean_dep("//tensorflow:macos"): [clean_dep("//tensorflow/python:lib_pywrap_tensorflow_internal.dylib")],
-        clean_dep("//tensorflow:windows"): [clean_dep("//tensorflow/python:_pywrap_tensorflow_internal.dll")],
-        "//conditions:default": [clean_dep("//tensorflow/python:lib_pywrap_tensorflow_internal.so")],
+        clean_dep("//tensorflow:macos"): [
+            clean_dep(
+                "//tensorflow/python:_pywrap_tensorflow_internal_macos",
+            ),
+        ],
+        clean_dep("//tensorflow:windows"): [
+            clean_dep(
+                "//tensorflow/python:_pywrap_tensorflow_internal_windows",
+            ),
+        ],
+        "//conditions:default": [
+            clean_dep(
+                "//tensorflow/python:_pywrap_tensorflow_internal_linux",
+            ),
+        ],
     })
 
 # Helper function for the per-OS tensorflow libraries and their version symlinks
@@ -1277,10 +1289,10 @@
 def _cuda_copts(opts = []):
     """Gets the appropriate set of copts for (maybe) CUDA compilation.
 
-      If we're doing CUDA compilation, returns copts for our particular CUDA
-      compiler.  If we're not doing CUDA compilation, returns an empty list.
+        If we're doing CUDA compilation, returns copts for our particular CUDA
+        compiler.  If we're not doing CUDA compilation, returns an empty list.
 
-      """
+        """
     return cuda_default_copts() + select({
         "//conditions:default": [],
         "@local_config_cuda//cuda:using_nvcc": ([
@@ -1333,21 +1345,21 @@
 def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
     """Generate a cc_library with a conditional set of CUDA dependencies.
 
-    When the library is built with --config=cuda:
+      When the library is built with --config=cuda:
 
-    - Both deps and cuda_deps are used as dependencies.
-    - The cuda runtime is added as a dependency (if necessary).
-    - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
-    - In addition, when the library is also built with TensorRT enabled, it
-        additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
+      - Both deps and cuda_deps are used as dependencies.
+      - The cuda runtime is added as a dependency (if necessary).
+      - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
+      - In addition, when the library is also built with TensorRT enabled, it
+          additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
 
-    Args:
-    - cuda_deps: BUILD dependencies which will be linked if and only if:
-        '--config=cuda' is passed to the bazel command line.
-    - deps: dependencies which will always be linked.
-    - copts: copts always passed to the cc_library.
-    - kwargs: Any other argument to cc_library.
-    """
+      Args:
+      - cuda_deps: BUILD dependencies which will be linked if and only if:
+          '--config=cuda' is passed to the bazel command line.
+      - deps: dependencies which will always be linked.
+      - copts: copts always passed to the cc_library.
+      - kwargs: Any other argument to cc_library.
+      """
     if not deps:
         deps = []
     if not cuda_deps:
@@ -1393,25 +1405,25 @@
         **kwargs):
     """A rule to build a TensorFlow OpKernel.
 
-    May either specify srcs/hdrs or prefix.  Similar to tf_gpu_library,
-    but with alwayslink=1 by default.  If prefix is specified:
-      * prefix*.cc (except *.cu.cc) is added to srcs
-      * prefix*.h (except *.cu.h) is added to hdrs
-      * prefix*.cu.cc and prefix*.h (including *.cu.h) are added to gpu_srcs.
-    With the exception that test files are excluded.
-    For example, with prefix = "cast_op",
-      * srcs = ["cast_op.cc"]
-      * hdrs = ["cast_op.h"]
-      * gpu_srcs = ["cast_op_gpu.cu.cc", "cast_op.h"]
-      * "cast_op_test.cc" is excluded
-    With prefix = "cwise_op"
-      * srcs = ["cwise_op_abs.cc", ..., "cwise_op_tanh.cc"],
-      * hdrs = ["cwise_ops.h", "cwise_ops_common.h"],
-      * gpu_srcs = ["cwise_op_gpu_abs.cu.cc", ..., "cwise_op_gpu_tanh.cu.cc",
-                    "cwise_ops.h", "cwise_ops_common.h",
-                    "cwise_ops_gpu_common.cu.h"]
-      * "cwise_ops_test.cc" is excluded
-    """
+      May either specify srcs/hdrs or prefix.  Similar to tf_gpu_library,
+      but with alwayslink=1 by default.  If prefix is specified:
+        * prefix*.cc (except *.cu.cc) is added to srcs
+        * prefix*.h (except *.cu.h) is added to hdrs
+        * prefix*.cu.cc and prefix*.h (including *.cu.h) are added to gpu_srcs.
+      With the exception that test files are excluded.
+      For example, with prefix = "cast_op",
+        * srcs = ["cast_op.cc"]
+        * hdrs = ["cast_op.h"]
+        * gpu_srcs = ["cast_op_gpu.cu.cc", "cast_op.h"]
+        * "cast_op_test.cc" is excluded
+      With prefix = "cwise_op"
+        * srcs = ["cwise_op_abs.cc", ..., "cwise_op_tanh.cc"],
+        * hdrs = ["cwise_ops.h", "cwise_ops_common.h"],
+        * gpu_srcs = ["cwise_op_gpu_abs.cu.cc", ..., "cwise_op_gpu_tanh.cu.cc",
+                      "cwise_ops.h", "cwise_ops_common.h",
+                      "cwise_ops_gpu_common.cu.h"]
+        * "cwise_ops_test.cc" is excluded
+      """
     if not srcs:
         srcs = []
     if not hdrs:
@@ -1542,13 +1554,13 @@
 def _get_transitive_headers(hdrs, deps):
     """Obtain the header files for a target and its transitive dependencies.
 
-    Args:
-      hdrs: a list of header files
-      deps: a list of targets that are direct dependencies
+      Args:
+        hdrs: a list of header files
+        deps: a list of targets that are direct dependencies
 
-    Returns:
-      a collection of the transitive headers
-    """
+      Returns:
+        a collection of the transitive headers
+      """
     return depset(
         hdrs,
         transitive = [dep[CcInfo].compilation_context.headers for dep in deps],
@@ -1628,14 +1640,14 @@
 def _get_repository_roots(ctx, files):
     """Returns abnormal root directories under which files reside.
 
-    When running a ctx.action, source files within the main repository are all
-    relative to the current directory; however, files that are generated or exist
-    in remote repositories will have their root directory be a subdirectory,
-    e.g. bazel-out/local-fastbuild/genfiles/external/jpeg_archive. This function
-    returns the set of these devious directories, ranked and sorted by popularity
-    in order to hopefully minimize the number of I/O system calls within the
-    compiler, because includes have quadratic complexity.
-    """
+      When running a ctx.action, source files within the main repository are all
+      relative to the current directory; however, files that are generated or exist
+      in remote repositories will have their root directory be a subdirectory,
+      e.g. bazel-out/local-fastbuild/genfiles/external/jpeg_archive. This function
+      returns the set of these devious directories, ranked and sorted by popularity
+      in order to hopefully minimize the number of I/O system calls within the
+      compiler, because includes have quadratic complexity.
+      """
     result = {}
     for f in files.to_list():
         root = f.root.path
@@ -1763,7 +1775,7 @@
 
 def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [], copts = [], **kwargs):
     """Helper to build a dynamic library (.so) from the sources containing implementations of custom ops and kernels.
-    """
+      """
     cuda_deps = [
         clean_dep("//tensorflow/core:stream_executor_headers_lib"),
         "@local_config_cuda//cuda:cuda_headers",
@@ -2108,11 +2120,12 @@
         shard_count = shard_count,
         srcs_version = "PY2AND3",
         tags = tags,
-        visibility = [clean_dep("//tensorflow:internal")] + additional_visibility,
-        deps = [
+        visibility = [clean_dep("//tensorflow:internal")] +
+                     additional_visibility,
+        deps = depset([
             clean_dep("//tensorflow/python:extra_py_tests_deps"),
             clean_dep("//tensorflow/python:gradient_checker"),
-        ] + additional_deps + xla_test_true_list,
+        ] + additional_deps + xla_test_true_list),
         **kwargs
     )
 
@@ -2378,7 +2391,8 @@
 def tensorflow_opensource_extra_deps():
     return []
 
-def tf_pybind_extension(
+# buildozer: disable=function-docstring-args
+def pybind_extension(
         name,
         srcs,
         module_name,
@@ -2396,7 +2410,7 @@
         compatible_with = None,
         restricted_to = None,
         deprecation = None):
-    """Builds a Python extension module."""
+    """Builds a generic Python extension module."""
     _ignore = [module_name]
     p = name.rfind("/")
     if p == -1:
@@ -2431,8 +2445,8 @@
     )
     native.cc_binary(
         name = so_file,
-        srcs = srcs + hdrs + tf_binary_additional_srcs() + tf_binary_pybind_deps(),
-        data = data + tf_binary_pybind_deps(),
+        srcs = srcs + hdrs,
+        data = data,
         copts = copts,
         nocopts = nocopts,
         linkopts = linkopts + _rpath_linkopts(name) + select({
@@ -2484,29 +2498,54 @@
         compatible_with = compatible_with,
     )
 
+# buildozer: enable=function-docstring-args
+
+def tf_python_pybind_extension(
+        name,
+        srcs,
+        module_name,
+        hdrs = [],
+        features = [],
+        copts = None,
+        deps = []):
+    """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.
+
+    It is used for targets under //third_party/tensorflow/python that link
+    against libtensorflow_framework.so and pywrap_tensorflow_internal.so.
+    """
+    pybind_extension(
+        name,
+        srcs + tf_binary_additional_srcs(),
+        module_name,
+        hdrs = hdrs,
+        features = features,
+        copts = copts,
+        deps = deps + tf_binary_pybind_deps(),
+    )
+
 def if_cuda_or_rocm(if_true, if_false = []):
     """Shorthand for select()'ing whether to build for either CUDA or ROCm.
 
-    Returns a select statement which evaluates to
-       if_true if we're building with either CUDA or ROCm enabled.
-       if_false, otherwise.
+      Returns a select statement which evaluates to
+         if_true if we're building with either CUDA or ROCm enabled.
+         if_false, otherwise.
 
-    Sometimes a target has additional CUDa or ROCm specific dependencies.
-    The `if_cuda` / `if_rocm` functions are used to specify these additional
-    dependencies. For eg, see the `//tensorflow/core/kernels:bias_op` target
+      Sometimes a target has additional CUDa or ROCm specific dependencies.
+      The `if_cuda` / `if_rocm` functions are used to specify these additional
+      dependencies. For eg, see the `//tensorflow/core/kernels:bias_op` target
 
-    If the same additional dependency is needed for both CUDA and ROCm
-    (for eg. `reduction_ops` dependency for the `bias_op` target above),
-    then specifying that dependency in both  both `if_cuda` and `if_rocm` will
-    result in both those functions returning a select statement, which contains
-    the same dependency, which then leads to a duplicate dependency bazel error.
+      If the same additional dependency is needed for both CUDA and ROCm
+      (for eg. `reduction_ops` dependency for the `bias_op` target above),
+      then specifying that dependency in both  both `if_cuda` and `if_rocm` will
+      result in both those functions returning a select statement, which contains
+      the same dependency, which then leads to a duplicate dependency bazel error.
 
-    In order to work around this error, any additional dependency that is common
-    to both the CUDA and ROCm platforms, should be specified using this function.
-    Doing so will eliminate the cause of the bazel error (i.e. the  same
-    dependency showing up in two different select statements)
+      In order to work around this error, any additional dependency that is common
+      to both the CUDA and ROCm platforms, should be specified using this function.
+      Doing so will eliminate the cause of the bazel error (i.e. the  same
+      dependency showing up in two different select statements)
 
-    """
+      """
     return select({
         "@local_config_cuda//cuda:using_nvcc": if_true,
         "@local_config_cuda//cuda:using_clang": if_true,
@@ -2523,8 +2562,9 @@
         "//tensorflow:with_mlir_support": if_true,
     })
 
+# TODO(b/138724071): Remove when build is stable.
 def if_mlir_tflite(if_true, if_false = []):
-    return if_mlir(if_true, if_false)
+    return if_true  # Internally we always build with MLIR.
 
 def tfcompile_extra_flags():
     return ""
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-spec.pbtxt
index 983d34b..80d9853 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-spec.pbtxt
@@ -5,6 +5,14 @@
   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
   is_instance: "<type \'object\'>"
   member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "value_type"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt
new file mode 100644
index 0000000..5fe1b98
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.TensorInfo.CompositeTensor"
+tf_proto {
+  descriptor {
+    name: "CompositeTensor"
+    field {
+      name: "type_spec"
+      number: 1
+      label: LABEL_OPTIONAL
+      type: TYPE_MESSAGE
+      type_name: ".tensorflow.TypeSpecProto"
+    }
+    field {
+      name: "components"
+      number: 2
+      label: LABEL_REPEATED
+      type: TYPE_MESSAGE
+      type_name: ".tensorflow.TensorInfo"
+    }
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt
index 63566c8..48773ea 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt
@@ -18,6 +18,14 @@
       oneof_index: 0
     }
     field {
+      name: "composite_tensor"
+      number: 5
+      label: LABEL_OPTIONAL
+      type: TYPE_MESSAGE
+      type_name: ".tensorflow.TensorInfo.CompositeTensor"
+      oneof_index: 0
+    }
+    field {
       name: "dtype"
       number: 2
       label: LABEL_OPTIONAL
@@ -52,6 +60,23 @@
         type: TYPE_STRING
       }
     }
+    nested_type {
+      name: "CompositeTensor"
+      field {
+        name: "type_spec"
+        number: 1
+        label: LABEL_OPTIONAL
+        type: TYPE_MESSAGE
+        type_name: ".tensorflow.TypeSpecProto"
+      }
+      field {
+        name: "components"
+        number: 2
+        label: LABEL_REPEATED
+        type: TYPE_MESSAGE
+        type_name: ".tensorflow.TensorInfo"
+      }
+    }
     oneof_decl {
       name: "encoding"
     }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
index 4506fcc..9f35e14 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
@@ -48,6 +48,10 @@
     argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "get_shape"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
index fb7af9a..df68721 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
@@ -85,6 +85,10 @@
     argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "experimental_ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "from_proto"
     argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt
index 44afc34..1454a2d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt
@@ -6,7 +6,7 @@
   }
   member_method {
     name: "do_not_convert"
-    argspec: "args=[\'func\', \'run_as\', \'return_dtypes\'], varargs=None, keywords=None, defaults=[\'None\', \'RunMode.GRAPH\', \'None\'], "
+    argspec: "args=[\'func\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "set_loop_options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
index 0c29d7a..cc188a1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
@@ -14,7 +14,7 @@
   }
   member_method {
     name: "experimental_connect_to_cluster"
-    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
+    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
   }
   member_method {
     name: "experimental_connect_to_host"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt
index 0c3f04e..5826676 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt
@@ -4,4 +4,8 @@
     name: "function_executor_type"
     argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "output_all_intermediates"
+    argspec: "args=[\'state\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index a13e20b..5d01d06 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -169,7 +169,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -277,7 +277,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 9ddbdf2..48252ca 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -294,7 +294,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
new file mode 100644
index 0000000..f43bf39
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -0,0 +1,315 @@
+path: "tensorflow.keras.experimental.LinearModel"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dynamic"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_spec"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "layers"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics_names"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "run_eagerly"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "sample_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "state_updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "stateful"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_metric"
+    argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_shape"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_signature"
+    argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'validation_freq\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'1\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_metrics"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
new file mode 100644
index 0000000..f47a393
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -0,0 +1,315 @@
+path: "tensorflow.keras.experimental.WideDeepModel"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dynamic"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_spec"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "layers"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics_names"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "run_eagerly"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "sample_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "state_updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "stateful"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'linear_model\', \'dnn_model\', \'activation\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_metric"
+    argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_shape"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_signature"
+    argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'validation_freq\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'1\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_metrics"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
index bfd169a..4a83b58 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "LinearModel"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "NoisyLinearCosineDecay"
     mtype: "<type \'type\'>"
   }
@@ -24,6 +28,10 @@
     name: "SequenceFeatures"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "WideDeepModel"
+    mtype: "<type \'type\'>"
+  }
   member_method {
     name: "export_saved_model"
     argspec: "args=[\'model\', \'saved_model_path\', \'custom_objects\', \'as_text\', \'input_signature\', \'serving_only\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
index 0e176cb..718b0f7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
@@ -1,6 +1,6 @@
 path: "tensorflow.keras.layers.DenseFeatures"
 tf_class {
-  is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2.DenseFeatures\'>"
+  is_instance: "<class \'tensorflow.python.feature_column.dense_features.DenseFeatures\'>"
   is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
index a2af655..600f11b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
@@ -3,7 +3,7 @@
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
   is_instance: "<type \'object\'>"
   member {
-    name: "default_variable_dtype"
+    name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
   member {
@@ -14,6 +14,10 @@
     name: "should_cast_variables"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "variable_dtype"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 3840d3d..3a9b05d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -169,7 +169,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -277,7 +277,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 3d9f85c..2343dd8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -294,7 +294,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
index 9a3f95c..311142b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
@@ -30,6 +30,6 @@
   }
   member_method {
     name: "save_model"
-    argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
index 5e3376d..bf7812a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
@@ -313,6 +313,10 @@
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reciprocal_no_nan"
+    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "reduce_all"
     argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 32f85a0..47f82fb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1145,6 +1145,10 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "disable_tensor_equality"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "disable_v2_behavior"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
@@ -1193,6 +1197,10 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "enable_tensor_equality"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "enable_v2_behavior"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
@@ -1322,7 +1330,7 @@
   }
   member_method {
     name: "function"
-    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_autograph_options\', \'experimental_relax_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'False\'], "
+    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "gather"
@@ -2214,7 +2222,7 @@
   }
   member_method {
     name: "sparse_tensor_to_dense"
-    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
+    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
   }
   member_method {
     name: "sparse_to_dense"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.ragged.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.ragged.pbtxt
index 6b07759..c37b511 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.ragged.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.ragged.pbtxt
@@ -36,4 +36,12 @@
     name: "segment_ids_to_row_splits"
     argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "stack"
+    argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+  }
+  member_method {
+    name: "stack_dynamic_partitions"
+    argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 473323b..cff4910 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -97,10 +97,18 @@
     argspec: "args=[\'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "AnonymousMemoryCache"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "AnonymousMultiDeviceIterator"
     argspec: "args=[\'devices\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "AnonymousRandomSeedGenerator"
+    argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Any"
     argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -121,6 +129,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ApplyAdam"
     argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
@@ -601,6 +613,10 @@
     argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "CacheDatasetV2"
+    argspec: "args=[\'input_dataset\', \'filename\', \'cache\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Case"
     argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
   }
@@ -705,6 +721,10 @@
     argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'None\'], "
   }
   member_method {
+    name: "ConfigureTPUEmbedding"
+    argspec: "args=[\'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Conj"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -949,10 +969,18 @@
     argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "DeleteMemoryCache"
+    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "DeleteMultiDeviceIterator"
     argspec: "args=[\'multi_device_iterator\', \'iterators\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "DeleteRandomSeedGenerator"
+    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "DeleteSessionTensor"
     argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -1258,7 +1286,7 @@
   }
   member_method {
     name: "ExperimentalRebatchDataset"
-    argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "ExperimentalScanDataset"
@@ -2546,7 +2574,7 @@
   }
   member_method {
     name: "PrefetchDataset"
-    argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "Prelinearize"
@@ -2845,6 +2873,10 @@
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "RaggedTensorToTensor"
+    argspec: "args=[\'shape\', \'values\', \'default_value\', \'row_partition_tensors\', \'row_partition_types\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "RaggedTensorToVariant"
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -2982,7 +3014,7 @@
   }
   member_method {
     name: "RebatchDataset"
-    argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "Reciprocal"
@@ -3149,6 +3181,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ResourceApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ResourceApplyAdam"
     argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
@@ -3269,6 +3305,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ResourceSparseApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ResourceSparseApplyCenteredRMSProp"
     argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -3629,6 +3669,10 @@
     argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
+    name: "ShuffleDatasetV2"
+    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "ShutdownDistributedTPU"
     argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -3745,6 +3789,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "SparseApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "SparseApplyCenteredRMSProp"
     argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -4050,7 +4098,7 @@
   }
   member_method {
     name: "StatelessWhile"
-    argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
   }
   member_method {
     name: "StaticRegexFullMatch"
@@ -4109,6 +4157,10 @@
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
+    name: "StringNGrams"
+    argspec: "args=[\'data\', \'data_splits\', \'separator\', \'ngram_widths\', \'left_pad\', \'right_pad\', \'pad_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "StringSplit"
     argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
index 1fc79d5..27c64f2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
@@ -126,7 +126,7 @@
   }
   member_method {
     name: "to_dense"
-    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
+    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
   }
   member_method {
     name: "to_indicator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 1a73ab6..b500833 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -25,6 +25,10 @@
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
+    name: "ngrams"
+    argspec: "args=[\'data\', \'ngram_width\', \'separator\', \'pad_values\', \'padding_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\' \', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
     name: "reduce_join"
     argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-spec.pbtxt
index 983d34b..80d9853 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-spec.pbtxt
@@ -5,6 +5,14 @@
   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
   is_instance: "<type \'object\'>"
   member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "value_type"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
index 4506fcc..9f35e14 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
@@ -48,6 +48,10 @@
     argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "get_shape"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
index 3585197..f53a5ec 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
@@ -84,6 +84,10 @@
     argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "experimental_ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "from_proto"
     argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt
index 44afc34..1454a2d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt
@@ -6,7 +6,7 @@
   }
   member_method {
     name: "do_not_convert"
-    argspec: "args=[\'func\', \'run_as\', \'return_dtypes\'], varargs=None, keywords=None, defaults=[\'None\', \'RunMode.GRAPH\', \'None\'], "
+    argspec: "args=[\'func\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "set_loop_options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
index 0c29d7a..cc188a1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
@@ -14,7 +14,7 @@
   }
   member_method {
     name: "experimental_connect_to_cluster"
-    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
+    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
   }
   member_method {
     name: "experimental_connect_to_host"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index a13e20b..5d01d06 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -169,7 +169,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -277,7 +277,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 9ddbdf2..48252ca 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -294,7 +294,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
new file mode 100644
index 0000000..f43bf39
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -0,0 +1,315 @@
+path: "tensorflow.keras.experimental.LinearModel"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dynamic"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_spec"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "layers"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics_names"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "run_eagerly"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "sample_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "state_updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "stateful"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_metric"
+    argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_shape"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_signature"
+    argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'validation_freq\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'1\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_metrics"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
new file mode 100644
index 0000000..f47a393
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -0,0 +1,315 @@
+path: "tensorflow.keras.experimental.WideDeepModel"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dynamic"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_spec"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "layers"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "metrics_names"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "run_eagerly"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "sample_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "state_updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "stateful"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'linear_model\', \'dnn_model\', \'activation\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_metric"
+    argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_shape"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_output_signature"
+    argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'validation_freq\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'1\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'validation_freq\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'1\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'callbacks\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_metrics"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\', \'reset_metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt
index bfd169a..4a83b58 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "LinearModel"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "NoisyLinearCosineDecay"
     mtype: "<type \'type\'>"
   }
@@ -24,6 +28,10 @@
     name: "SequenceFeatures"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "WideDeepModel"
+    mtype: "<type \'type\'>"
+  }
   member_method {
     name: "export_saved_model"
     argspec: "args=[\'model\', \'saved_model_path\', \'custom_objects\', \'as_text\', \'input_signature\', \'serving_only\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
index 0e176cb..631012b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.DenseFeatures"
 tf_class {
-  is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2.DenseFeatures\'>"
+  is_instance: "<class \'tensorflow.python.feature_column.dense_features_v2.DenseFeatures\'>"
+  is_instance: "<class \'tensorflow.python.feature_column.dense_features.DenseFeatures\'>"
   is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
index a2af655..600f11b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
@@ -3,7 +3,7 @@
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
   is_instance: "<type \'object\'>"
   member {
-    name: "default_variable_dtype"
+    name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
   member {
@@ -14,6 +14,10 @@
     name: "should_cast_variables"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "variable_dtype"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 3840d3d..3a9b05d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -169,7 +169,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -277,7 +277,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 3d9f85c..2343dd8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -294,7 +294,7 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "save_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
index 9a3f95c..311142b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
@@ -30,6 +30,6 @@
   }
   member_method {
     name: "save_model"
-    argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\'], "
+    argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
index f0f6373..82688f5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
@@ -313,6 +313,10 @@
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reciprocal_no_nan"
+    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "reduce_all"
     argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 33c4610..63c70f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -646,7 +646,7 @@
   }
   member_method {
     name: "function"
-    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_autograph_options\', \'experimental_relax_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'False\'], "
+    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "gather"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.ragged.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.ragged.pbtxt
index d3f70f1..75144f1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.ragged.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.ragged.pbtxt
@@ -24,4 +24,12 @@
     name: "segment_ids_to_row_splits"
     argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "stack"
+    argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+  }
+  member_method {
+    name: "stack_dynamic_partitions"
+    argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 473323b..cff4910 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -97,10 +97,18 @@
     argspec: "args=[\'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "AnonymousMemoryCache"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "AnonymousMultiDeviceIterator"
     argspec: "args=[\'devices\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "AnonymousRandomSeedGenerator"
+    argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Any"
     argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -121,6 +129,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ApplyAdam"
     argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
@@ -601,6 +613,10 @@
     argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "CacheDatasetV2"
+    argspec: "args=[\'input_dataset\', \'filename\', \'cache\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Case"
     argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
   }
@@ -705,6 +721,10 @@
     argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'None\'], "
   }
   member_method {
+    name: "ConfigureTPUEmbedding"
+    argspec: "args=[\'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Conj"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -949,10 +969,18 @@
     argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "DeleteMemoryCache"
+    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "DeleteMultiDeviceIterator"
     argspec: "args=[\'multi_device_iterator\', \'iterators\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "DeleteRandomSeedGenerator"
+    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "DeleteSessionTensor"
     argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -1258,7 +1286,7 @@
   }
   member_method {
     name: "ExperimentalRebatchDataset"
-    argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "ExperimentalScanDataset"
@@ -2546,7 +2574,7 @@
   }
   member_method {
     name: "PrefetchDataset"
-    argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "Prelinearize"
@@ -2845,6 +2873,10 @@
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "RaggedTensorToTensor"
+    argspec: "args=[\'shape\', \'values\', \'default_value\', \'row_partition_tensors\', \'row_partition_types\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "RaggedTensorToVariant"
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -2982,7 +3014,7 @@
   }
   member_method {
     name: "RebatchDataset"
-    argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "Reciprocal"
@@ -3149,6 +3181,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ResourceApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ResourceApplyAdam"
     argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'beta2_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'use_nesterov\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
@@ -3269,6 +3305,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "ResourceSparseApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "ResourceSparseApplyCenteredRMSProp"
     argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -3629,6 +3669,10 @@
     argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
+    name: "ShuffleDatasetV2"
+    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "ShutdownDistributedTPU"
     argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -3745,6 +3789,10 @@
     argspec: "args=[\'var\', \'gradient_accumulator\', \'gradient_squared_accumulator\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'global_step\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "SparseApplyAdagradV2"
+    argspec: "args=[\'var\', \'accum\', \'lr\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'update_slots\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
+  }
+  member_method {
     name: "SparseApplyCenteredRMSProp"
     argspec: "args=[\'var\', \'mg\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -4050,7 +4098,7 @@
   }
   member_method {
     name: "StatelessWhile"
-    argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
   }
   member_method {
     name: "StaticRegexFullMatch"
@@ -4109,6 +4157,10 @@
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
+    name: "StringNGrams"
+    argspec: "args=[\'data\', \'data_splits\', \'separator\', \'ngram_widths\', \'left_pad\', \'right_pad\', \'pad_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "StringSplit"
     argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
index 96e05c6..da31499 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
@@ -102,7 +102,7 @@
   }
   member_method {
     name: "to_dense"
-    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
+    argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
   }
   member_method {
     name: "to_indicator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 6f0cd87..8fc27cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -25,6 +25,10 @@
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
+    name: "ngrams"
+    argspec: "args=[\'data\', \'ngram_width\', \'separator\', \'pad_values\', \'padding_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\' \', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
     name: "reduce_join"
     argspec: "args=[\'inputs\', \'axis\', \'keepdims\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt
index 381cc5a..f532332 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt
@@ -16,4 +16,12 @@
     name: "PythonState"
     mtype: "<type \'type\'>"
   }
+  member_method {
+    name: "disable_mixed_precision_graph_rewrite"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "enable_mixed_precision_graph_rewrite"
+    argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
+  }
 }
diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD
index 4c2c9b8..75cb933 100644
--- a/tensorflow/tools/api/lib/BUILD
+++ b/tensorflow/tools/api/lib/BUILD
@@ -1,7 +1,7 @@
 # Helper libraries for TensorFlow API compatibility test.
 
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library",
 )
 
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8f9f508..716654b 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -48,6 +48,7 @@
     srcs = ["deprecation_test.py"],
     python_version = "PY2",
     srcs_version = "PY2AND3",
+    tags = ["v1only"],
     deps = [
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/tools/api/tests/deprecation_test.py b/tensorflow/tools/api/tests/deprecation_test.py
index 3a5cf0d..962b557 100644
--- a/tensorflow/tools/api/tests/deprecation_test.py
+++ b/tensorflow/tools/api/tests/deprecation_test.py
@@ -39,7 +39,7 @@
     tf.tables_initializer()
     self.assertEqual(1, mock_warning.call_count)
     self.assertRegexpMatches(mock_warning.call_args[0][1],
-                             "module_wrapper.py:")
+                             "deprecation_test.py:")
     self.assertRegexpMatches(
         mock_warning.call_args[0][2], r"tables_initializer")
     self.assertRegexpMatches(
@@ -60,7 +60,7 @@
     tf.ragged.RaggedTensorValue(value, row_splits)
     self.assertEqual(1, mock_warning.call_count)
     self.assertRegexpMatches(mock_warning.call_args[0][1],
-                             "module_wrapper.py:")
+                             "deprecation_test.py:")
     self.assertRegexpMatches(
         mock_warning.call_args[0][2], r"ragged.RaggedTensorValue")
     self.assertRegexpMatches(
@@ -83,7 +83,7 @@
     tf.sparse_mask(array, mask_indices)
     self.assertEqual(1, mock_warning.call_count)
     self.assertRegexpMatches(mock_warning.call_args[0][1],
-                             "module_wrapper.py:")
+                             "deprecation_test.py:")
     self.assertRegexpMatches(
         mock_warning.call_args[0][2], r"sparse_mask")
     self.assertRegexpMatches(
@@ -101,7 +101,7 @@
     tf.VarLenFeature(tf.dtypes.int32)
     self.assertEqual(1, mock_warning.call_count)
     self.assertRegexpMatches(mock_warning.call_args[0][1],
-                             "module_wrapper.py:")
+                             "deprecation_test.py:")
     self.assertRegexpMatches(
         mock_warning.call_args[0][2], r"VarLenFeature")
     self.assertRegexpMatches(
@@ -119,7 +119,7 @@
     tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pylint: disable=pointless-statement
     self.assertEqual(1, mock_warning.call_count)
     self.assertRegexpMatches(mock_warning.call_args[0][1],
-                             "module_wrapper.py:")
+                             "deprecation_test.py:")
     self.assertRegexpMatches(
         mock_warning.call_args[0][2],
         r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY")
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index e5187ab..7ebba43 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -101,7 +101,7 @@
         if (!input.initialization_values.empty()) {
           LOG(FATAL) << "Initialization values are not supported for strings";
         }
-        auto type_tensor = input_tensor.flat<string>();
+        auto type_tensor = input_tensor.flat<tstring>();
         type_tensor = type_tensor.constant("");
         break;
       }
diff --git a/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16 b/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16
new file mode 100644
index 0000000..72348d5
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16
@@ -0,0 +1,70 @@
+# Dockerfile for Ubuntu 16.04 manylinux2010 custom ops with CPU.
+
+FROM ubuntu:16.04 as devtoolset
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+      bzip2 \
+      cpio \
+      file \
+      flex \
+      g++ \
+      make \
+      patch \
+      rpm2cpio \
+      unar \
+      wget \
+      tar \
+      xz-utils \
+      && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7.
+RUN /build_devtoolset.sh devtoolset-7 /dt7
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM ubuntu:16.04
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+COPY --from=devtoolset /dt7 /dt7
+COPY --from=devtoolset /dt8 /dt8
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN /install/install_deb_packages.sh
+RUN /install/install_clang.sh
+RUN /install/install_bazel.sh
+
+# Install golang.
+RUN /install/install_golang.sh
+env GOROOT=/usr/local/go
+env PATH=$GOROOT/bin:$PATH
+
+# Install python 3.6.
+RUN add-apt-repository ppa:jonathonf/python-3.6 && \
+    apt-get update && apt-get install -y \
+    python3.6 python3.6-dev python3-pip python3.6-venv && \
+    rm -rf /var/lib/apt/lists/* && \
+    python3.6 -m pip install pip --upgrade && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 0
+
+RUN /install/install_pip_packages.sh
+
+# TODO(klimek): Figure out a better way to get the right include paths
+# forwarded when we install new packages.
+RUN ln -s "/usr/include/x86_64-linux-gnu/python2.7" "/dt7/usr/include/x86_64-linux-gnu/python2.7"
+RUN ln -s "/usr/include/x86_64-linux-gnu/python2.7" "/dt8/usr/include/x86_64-linux-gnu/python2.7"
+
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt7/usr/include/x86_64-linux-gnu/python3.6m"
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt8/usr/include/x86_64-linux-gnu/python3.6m"
diff --git a/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16_gpu b/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16_gpu
new file mode 100644
index 0000000..c57d28a
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.custom_op_ubuntu_16_gpu
@@ -0,0 +1,67 @@
+# Dockerfile for Ubuntu 16.04 manylinux2010 custom ops with GPU.
+
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 as devtoolset
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+      cpio \
+      file \
+      flex \
+      g++ \
+      make \
+      rpm2cpio \
+      unar \
+      wget \
+      && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7.
+RUN /build_devtoolset.sh devtoolset-7 /dt7
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+COPY --from=devtoolset /dt7 /dt7
+COPY --from=devtoolset /dt8 /dt8
+
+# Install TensorRT.
+RUN apt-get update && apt-get install -y \
+    libnvinfer-dev=5.1.5-1+cuda10.0 \
+    libnvinfer5=5.1.5-1+cuda10.0 \
+      && \
+    rm -rf /var/lib/apt/lists/*
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN /install/install_deb_packages.sh
+RUN /install/install_clang.sh
+RUN /install/install_bazel.sh
+
+ENV TF_NEED_CUDA=1
+
+# Install python 3.6.
+RUN add-apt-repository ppa:jonathonf/python-3.6 && \
+    apt-get update && apt-get install -y \
+    python3.6 python3.6-dev python3-pip python3.6-venv && \
+    rm -rf /var/lib/apt/lists/* && \
+    python3.6 -m pip install pip --upgrade && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 0
+
+RUN /install/install_pip_packages.sh
+
+# TODO(klimek): Figure out a better way to get the right include paths
+# forwarded when we install new packages.
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt7/usr/include/x86_64-linux-gnu/python3.6m"
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt8/usr/include/x86_64-linux-gnu/python3.6m"
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2010 b/tensorflow/tools/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2010
new file mode 100644
index 0000000..93ad40d
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2010
@@ -0,0 +1,73 @@
+# Dockerfile to build a manylinux 2010 compliant cross-compiler.
+#
+# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible
+# glibc (2.12) and system libstdc++ (4.4).
+#
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.ubuntu16.04-manylinux2010 \
+#  --tag "gcr.io/tensorflow-testing/nosla-ubuntu16.04-manylinux2010" .
+# $ docker push gcr.io/tensorflow-testing/nosla-ubuntu16.04-manylinux2010
+
+FROM ubuntu:16.04 as devtoolset
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+      bzip2 \
+      cpio \
+      file \
+      flex \
+      g++ \
+      make \
+      patch \
+      rpm2cpio \
+      unar \
+      wget \
+      tar \
+      xz-utils \
+      && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7.
+RUN /build_devtoolset.sh devtoolset-7 /dt7
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM ubuntu:16.04
+COPY --from=devtoolset /dt7 /dt7
+COPY --from=devtoolset /dt8 /dt8
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN /install/install_deb_packages.sh
+RUN /install/install_clang.sh
+RUN /install/install_bazel.sh
+
+# Install golang.
+RUN /install/install_golang.sh
+env GOROOT=/usr/local/go
+env PATH=$GOROOT/bin:$PATH
+
+# Install python 3.6.
+RUN add-apt-repository ppa:jonathonf/python-3.6 && \
+    apt-get update && apt-get install -y \
+    python3.6 python3.6-dev python3-pip python3.6-venv && \
+    rm -rf /var/lib/apt/lists/* && \
+    python3.6 -m pip install pip --upgrade && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 0
+
+RUN /install/install_pip_packages.sh
+
+# TODO(klimek): Figure out a better way to get the right include paths
+# forwarded when we install new packages.
+RUN ln -s "/usr/include/x86_64-linux-gnu/python2.7" "/dt7/usr/include/x86_64-linux-gnu/python2.7"
+RUN ln -s "/usr/include/x86_64-linux-gnu/python2.7" "/dt8/usr/include/x86_64-linux-gnu/python2.7"
+
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt7/usr/include/x86_64-linux-gnu/python3.6m"
+RUN ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt8/usr/include/x86_64-linux-gnu/python3.6m"
diff --git a/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh b/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh
index 26dcc6a..c87ec29 100755
--- a/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh
+++ b/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh
@@ -37,7 +37,8 @@
       --test_timeout 300,450,1200,3600 \
       --test_output=errors \
       -- //${PIP_TEST_ROOT}/tensorflow/python/... \
-      -//${PIP_TEST_ROOT}/tensorflow/python/keras:training_eager_test \
-      -//${PIP_TEST_ROOT}/tensorflow/python/keras:base_layer_test \
       -//${PIP_TEST_ROOT}/tensorflow/python:virtual_gpu_test \
-      -//${PIP_TEST_ROOT}/tensorflow/python:virtual_gpu_test_gpu
+      -//${PIP_TEST_ROOT}/tensorflow/python:virtual_gpu_test_gpu \
+      -//${PIP_TEST_ROOT}/tensorflow/python:collective_ops_gpu_test \
+      -//${PIP_TEST_ROOT}/tensorflow/python:collective_ops_gpu_test_gpu
+
diff --git a/tensorflow/tools/ci_build/builds/pip_new.sh b/tensorflow/tools/ci_build/builds/pip_new.sh
index 1fb02ca..72f1b58 100755
--- a/tensorflow/tools/ci_build/builds/pip_new.sh
+++ b/tensorflow/tools/ci_build/builds/pip_new.sh
@@ -421,7 +421,7 @@
   echo "PYTHON_BIN_PATH to be used to install the .whl: ${PYTHON_BIN_PATH}"
   echo "PIP_BIN_PATH to be used to install the .whl: ${PIP_BIN_PATH}"
 
-  # Upgrade pip so it supports tags such as cp27mu, manylinux1 etc.
+  # Upgrade pip so it supports tags such as cp27mu, manylinux2010 etc.
   echo "Upgrade pip in virtualenv"
 
   # NOTE: pip install --upgrade pip leads to a documented TLS issue for
@@ -452,6 +452,12 @@
   #   ImportError: cannot import name py31compat
   ${PIP_BIN_PATH} install --upgrade setuptools==39.1.0 || \
     die "Error: setuptools install, upgrade FAILED"
+
+  # Install the future package in the virtualenv. Installing it in user system
+  # packages does not appear to port it over when creating a virtualenv.
+  #   ImportError: No module named builtins
+  ${PIP_BIN_PATH} install --upgrade "future>=0.17.1" || \
+    die "Error: future install, upgrade FAILED"
 }
 
 run_test_with_bazel() {
@@ -613,7 +619,7 @@
 
 WHL_DIR=$(dirname "${WHL_PATH}")
 WHL_BASE_NAME=$(basename "${WHL_PATH}")
-AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux1}")
+AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}")
 
 # Print the size of the wheel file.
 echo "Size of the PIP wheel file built: $(ls -l ${WHL_PATH} | awk '{print $5}')"
@@ -626,25 +632,25 @@
     # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so
     WHL_PATH=${AUDITED_WHL_NAME}
     cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}"
-    echo "Copied manylinux1 wheel file at ${WHL_PATH}"
+    echo "Copied manylinux2010 wheel file at ${WHL_PATH}"
   else
     if [[ ${OS_TYPE} == "ubuntu" ]]; then
       # Avoid Python3.6 abnormality by installing auditwheel here.
       set +e
       pip3 show auditwheel || "pip${PY_MAJOR_MINOR_VER}" show auditwheel
-      pip3 install auditwheel==1.5.0 || "pip${PY_MAJOR_MINOR_VER}" install auditwheel==1.5.0
-      sudo pip3 install auditwheel==1.5.0 || \
-        sudo "pip${PY_MAJOR_MINOR_VER}" install auditwheel==1.5.0
+      pip3 install auditwheel==2.0.0 || "pip${PY_MAJOR_MINOR_VER}" install auditwheel==2.0.0
+      sudo pip3 install auditwheel==2.0.0 || \
+        sudo "pip${PY_MAJOR_MINOR_VER}" install auditwheel==2.0.0
       set -e
       auditwheel --version
 
-      # Repair the wheels for cpu manylinux1
+      # Repair the wheels for cpu manylinux2010
       echo "auditwheel repairing ${WHL_PATH}"
-      auditwheel repair -w "${WHL_DIR}" "${WHL_PATH}"
+      auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}"
 
       if [[ -f ${AUDITED_WHL_NAME} ]]; then
         WHL_PATH=${AUDITED_WHL_NAME}
-        echo "Repaired manylinux1 wheel file at: ${WHL_PATH}"
+        echo "Repaired manylinux2010 wheel file at: ${WHL_PATH}"
       else
         die "WARNING: Cannot find repaired wheel."
       fi
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index b78281d..cb27a59 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -219,7 +219,7 @@
 
   echo ""
   if [[ ${N_ERRORS} != 0 ]]; then
-    echo "FAIL: Found ${N_ERRORS} non-whitelited pylint errors:"
+    echo "FAIL: Found ${N_ERRORS} non-whitelisted pylint errors:"
     cat "${NONWL_ERRORS_FILE}"
     return 1
   else
@@ -363,12 +363,12 @@
 
   # Blacklist
   echo ${MISSING_LICENSES_FILE}
-  grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt
+  grep -e "@bazel_tools//third_party/" -e "@bazel_tools//tools" -e "@local" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt
   mv temp.txt ${MISSING_LICENSES_FILE}
 
   # Whitelist
   echo ${EXTRA_LICENSE_FILE}
-  grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt
+  grep -e "//third_party/mkl_dnn" -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@org_tensorflow//tensorflow" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt
   mv temp.txt ${EXTRA_LICENSES_FILE}
 
 
@@ -551,10 +551,11 @@
 _check_no_deps() {
   TARGET="$1"
   DISALLOWED_DEP="$2"
+  EXTRA_FLAG="$3"
 
   TMP_FILE="$(mktemp)_tmp.log"
   echo "Checking ${TARGET} does not depend on ${DISALLOWED_DEP} ..."
-  bazel cquery "somepath(${TARGET}, ${DISALLOWED_DEP})" --keep_going> "${TMP_FILE}" 2>&1
+  bazel cquery ${EXTRA_FLAG} "somepath(${TARGET}, ${DISALLOWED_DEP})" --keep_going> "${TMP_FILE}" 2>&1
   if cat "${TMP_FILE}" | grep "Empty query results"; then
       echo "Success."
   else
@@ -568,7 +569,8 @@
   rm "${TMP_FILE}"
 }
 
-do_pip_no_cuda_deps_check() {
+_do_pip_no_cuda_deps_check() {
+  EXTRA_FLAG="$1"
   DISALLOWED_CUDA_DEPS=("@local_config_cuda//cuda:cudart"
         "@local_config_cuda//cuda:cublas"
         "@local_config_cuda//cuda:cuda_driver"
@@ -578,7 +580,7 @@
         "@local_config_cuda//cuda:cusparse")
   for cuda_dep in "${DISALLOWED_CUDA_DEPS[@]}"
   do
-   _check_no_deps "//tensorflow/tools/pip_package:build_pip_package" "${cuda_dep}"
+   _check_no_deps "//tensorflow/tools/pip_package:build_pip_package" "${cuda_dep}" "${EXTRA_FLAG}"
    RESULT=$?
 
    if [[ ${RESULT} != "0" ]]; then
@@ -587,6 +589,14 @@
   done
 }
 
+do_pip_no_cuda_deps_check_ubuntu() {
+  _do_pip_no_cuda_deps_check "--define using_cuda=true --define using_cuda_nvcc=true"
+}
+
+do_pip_no_cuda_deps_check_windows() {
+  _do_pip_no_cuda_deps_check "--define using_cuda=true --define using_cuda_nvcc=true --define framework_shared_object=false"
+}
+
 do_configure_test() {
   for WITH_CUDA in 1 0
   do
@@ -602,8 +612,8 @@
 }
 
 # Supply all sanity step commands and descriptions
-SANITY_STEPS=("do_configure_test" "do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_check_file_name_test" "do_pip_no_cuda_deps_check")
-SANITY_STEPS_DESC=("Run ./configure" "Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Check file names for cases" "Check gpu pip package does not depend on cuda shared libraries.")
+SANITY_STEPS=("do_configure_test" "do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_check_file_name_test" "do_pip_no_cuda_deps_check_ubuntu" "do_pip_no_cuda_deps_check_windows")
+SANITY_STEPS_DESC=("Run ./configure" "Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Check file names for cases" "Check Ubuntu gpu pip package does not depend on cuda shared libraries" "Check Windows gpu pip package does not depend on cuda shared libraries")
 
 INCREMENTAL_FLAG=""
 DEFAULT_BAZEL_CONFIGS=""
diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh
index 2157785..0e4ce18 100755
--- a/tensorflow/tools/ci_build/install/install_bazel.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel.sh
@@ -15,7 +15,7 @@
 # ==============================================================================
 
 # Select bazel version.
-BAZEL_VERSION="0.24.1"
+BAZEL_VERSION="0.26.1"
 
 set +e
 local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
index 75de245..6d221a7 100755
--- a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
@@ -18,7 +18,7 @@
 # It will compile bazel from source and install it in /usr/local/bin
 
 # Select bazel version.
-BAZEL_VERSION="0.24.1"
+BAZEL_VERSION="0.26.1"
 
 set +e
 local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 9e641d4..c8fc266 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -16,12 +16,9 @@
 
 set -e
 
-# We don't apt-get install so that we can install a newer version of pip.
-# Only needed for Ubuntu 14.04 and 16.04; not needed for 18.04 and Debian 8,9?
-# Run easy_install after easy_install3, so that the default pip points to pip2,
-# to match the default python version of 2.7.
-easy_install3 -U pip==18.1
-easy_install -U pip==18.1
+# Get the latest version of pip so it recognize manylinux2010
+easy_install3 -U pip
+easy_install -U pip
 
 # Install pip packages from whl files to avoid the time-consuming process of
 # building from source.
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
index 8f8f031..bb12098 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
@@ -31,8 +31,19 @@
 yes "" | $PYTHON_BIN_PATH configure.py
 if [[ "$MODE" == "eigen" ]]; then
     CONFIG=""
+    OMPTHREADS=""
 else
     CONFIG="--config=mkl"
+# Setting OMP_THREADS for low performing benchmarks.
+#   Default value(=core count) degrades perfrmance of some banchmark cases. 
+#   Optimal thread count is case specific. 
+#   An argument can be passed to script, the value of which is used if given.
+#   Otherwise OMP_NUM_THREADS is set to 10
+    if [[ -z $1 ]]; then
+        OMPTHREADS="--action_env=OMP_NUM_THREADS=10"
+    else 
+        OMPTHREADS="--action_env=OMP_NUM_THREADS=$1"
+    fi
 fi
 
 # Run bazel test command. Double test timeouts to avoid flakes.
@@ -41,5 +52,5 @@
 # caused by executing multiple tests concurrently.
 bazel test --test_tag_filters=-no_oss,-no_oss_py2,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \
     --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
-    ${CONFIG} --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \
+    ${CONFIG} --test_env=KMP_BLOCKTIME=0 ${OMPTHREADS} --config=opt --test_output=errors -- \
     //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... -//tensorflow/lite/...
diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
index 9961f44..1398b79 100755
--- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
+++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
@@ -87,12 +87,13 @@
   echo "Building for the Pi One/Zero, with no NEON support"
   WHEEL_ARCH=linux_armv6l
 else
-  PI_COPTS='--copt=-march=armv7-a --copt=-mfpu=neon-vfpv4
+  PI_COPTS="--copt=-march=armv7-a --copt=-mfpu=neon-vfpv4
   --copt=-std=gnu11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR
   --copt=-O3 --copt=-fno-tree-pre
   --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1
   --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2
-  --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8'
+  --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8
+  --define=raspberry_pi_with_neon=true"
   WHEEL_ARCH=linux_armv7l
   echo "Building for the Pi Two/Three, with NEON acceleration"
 fi
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index e1db8ca..3ddcafb 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -169,7 +169,6 @@
 # https://github.com/bazelbuild/bazel/issues/6622
 bazel test --announce_rc --config=opt -k --test_output=errors \
   ${EXTRA_TEST_FLAGS} \
-  --experimental_windows_native_test_wrapper \
   --define=no_tensorflow_py_deps=true --test_lang_filters=py \
   --test_tag_filters=-no_pip,-no_windows,-no_oss,-gpu \
   --build_tag_filters=-no_pip,-no_windows,-no_oss,-gpu --build_tests_only \
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 0277479..bdd70eb 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -170,11 +170,10 @@
 bazel test --announce_rc --config=opt -k --test_output=errors \
   --test_env=TF_GPU_COUNT \
   ${EXTRA_TEST_FLAGS} \
-  --experimental_windows_native_test_wrapper \
   --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
   --define=no_tensorflow_py_deps=true --test_lang_filters=py \
-  --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \
-  --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \
+  --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss,gpu \
+  --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss,gpu --build_tests_only \
   --test_size_filters=small,medium \
   --local_test_jobs=$TF_GPU_COUNT --test_timeout="300,450,1200,3600" \
   --flaky_test_attempts=3 \
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
index 814e19c..fd31c35b 100755
--- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
@@ -60,7 +60,12 @@
 mkdir -p ${DIR}/lib
 cp bazel-bin/tensorflow/tensorflow.dll ${DIR}/lib/tensorflow.dll
 cp bazel-genfiles/tensorflow/tensorflow.lib ${DIR}/lib/tensorflow.lib
-cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
+cp tensorflow/c/c_api.h \
+  tensorflow/c/tf_attrtype.h \
+  tensorflow/c/tf_datatype.h \
+  tensorflow/c/tf_status.h \
+  tensorflow/c/tf_tensor.h \
+  ${DIR}/include/tensorflow/c
 cp tensorflow/c/eager/c_api.h ${DIR}/include/tensorflow/c/eager
 cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
 cd ${DIR}
@@ -69,5 +74,9 @@
   lib/tensorflow.lib \
   include/tensorflow/c/eager/c_api.h \
   include/tensorflow/c/c_api.h \
+  include/tensorflow/c/tf_attrtype.h \
+  include/tensorflow/c/tf_datatype.h \
+  include/tensorflow/c/tf_status.h \
+  include/tensorflow/c/tf_tensor.h \
   include/tensorflow/c/LICENSE
 rm -rf lib include
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
index 29736b2..df5c3e6 100644
--- a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
@@ -60,7 +60,12 @@
 mkdir -p ${DIR}/lib
 cp bazel-bin/tensorflow/tensorflow.dll ${DIR}/lib/tensorflow.dll
 cp bazel-genfiles/tensorflow/tensorflow.lib ${DIR}/lib/tensorflow.lib
-cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
+cp tensorflow/c/c_api.h \
+  tensorflow/c/tf_attrtype.h \
+  tensorflow/c/tf_datatype.h \
+  tensorflow/c/tf_status.h \
+  tensorflow/c/tf_tensor.h \
+  ${DIR}/include/tensorflow/c
 cp tensorflow/c/eager/c_api.h ${DIR}/include/tensorflow/c/eager
 cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
 cd ${DIR}
@@ -69,5 +74,9 @@
   lib/tensorflow.lib \
   include/tensorflow/c/eager/c_api.h \
   include/tensorflow/c/c_api.h \
+  include/tensorflow/c/tf_attrtype.h \
+  include/tensorflow/c/tf_datatype.h \
+  include/tensorflow/c/tf_status.h \
+  include/tensorflow/c/tf_tensor.h \
   include/tensorflow/c/LICENSE
 rm -rf lib include
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 36efc6bf..5a50d77 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -153,6 +153,7 @@
     srcs = ["tf_upgrade_v2_test.py"],
     python_version = "PY2",
     srcs_version = "PY2AND3",
+    tags = ["v1only"],
     deps = [
         ":tf_upgrade_v2_lib",
         "//tensorflow:tensorflow_py",
@@ -225,7 +226,8 @@
     cmd = ("$(location :tf_upgrade_v2)" +
            " --infile $(location testdata/test_file_v1_12.py)" +
            " --outfile $(location test_file_v2_0.py)" +
-           " --reportfile $(location report_v2.txt)"),
+           " --reportfile $(location report_v2.txt) && " +
+           "sed -i'.original' 's/_TEST_VERSION = 1/_TEST_VERSION = 2/g' $(location test_file_v2_0.py)"),
     tools = [":tf_upgrade_v2"],
 )
 
@@ -235,6 +237,7 @@
     srcs = ["testdata/test_file_v1_12.py"],
     python_version = "PY2",
     srcs_version = "PY2AND3",
+    tags = ["v1only"],
     deps = [
         "//tensorflow:tensorflow_py",
     ],
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
index e80bdc4..70ed82d 100644
--- a/tensorflow/tools/compatibility/ast_edits.py
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -1032,10 +1032,25 @@
       output_directory = os.path.dirname(output_path)
       if not os.path.isdir(output_directory):
         os.makedirs(output_directory)
+
+      if os.path.islink(input_path):
+        link_target = os.readlink(input_path)
+        link_target_output = os.path.join(
+            output_root_directory, os.path.relpath(link_target, root_directory))
+        if (link_target, link_target_output) in files_to_process:
+          # Create a link to the new location of the target file
+          os.symlink(link_target_output, output_path)
+        else:
+          report += "Copying symlink %s without modifying its target %s" % (
+              input_path, link_target)
+          os.symlink(link_target, output_path)
+        continue
+
       file_count += 1
       _, l_report, l_errors = self.process_file(input_path, output_path)
       tree_errors[input_path] = l_errors
       report += l_report
+
     for input_path, output_path in files_to_copy:
       output_directory = os.path.dirname(output_path)
       if not os.path.isdir(output_directory):
@@ -1059,6 +1074,9 @@
     report += ("=" * 80) + "\n"
 
     for path in files_to_process:
+      if os.path.islink(path):
+        report += "Skipping symlink %s.\n" % path
+        continue
       file_count += 1
       _, l_report, l_errors = self.process_file(path, path)
       tree_errors[path] = l_errors
diff --git a/tensorflow/tools/compatibility/ast_edits_test.py b/tensorflow/tools/compatibility/ast_edits_test.py
index 0bc87d1..d6a366d 100644
--- a/tensorflow/tools/compatibility/ast_edits_test.py
+++ b/tensorflow/tools/compatibility/ast_edits_test.py
@@ -45,6 +45,7 @@
 from __future__ import print_function
 
 import ast
+import os
 import six
 
 from tensorflow.python.framework import test_util
@@ -605,6 +606,89 @@
     _, new_text = self._upgrade(RenameImports(), text)
     self.assertEqual(expected_text, new_text)
 
+  def testUpgradeInplaceWithSymlink(self):
+    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
+    os.mkdir(upgrade_dir)
+    file_a = os.path.join(upgrade_dir, "a.py")
+    file_b = os.path.join(upgrade_dir, "b.py")
+
+    with open(file_a, "a") as f:
+      f.write("import foo as f")
+    os.symlink(file_a, file_b)
+
+    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
+    upgrader.process_tree_inplace(upgrade_dir)
+
+    self.assertTrue(os.path.islink(file_b))
+    self.assertEqual(file_a, os.readlink(file_b))
+    with open(file_a, "r") as f:
+      self.assertEqual("import bar as f", f.read())
+
+  def testUpgradeInPlaceWithSymlinkInDifferentDir(self):
+    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
+    other_dir = os.path.join(self.get_temp_dir(), "bar")
+    os.mkdir(upgrade_dir)
+    os.mkdir(other_dir)
+    file_c = os.path.join(other_dir, "c.py")
+    file_d = os.path.join(upgrade_dir, "d.py")
+
+    with open(file_c, "a") as f:
+      f.write("import foo as f")
+    os.symlink(file_c, file_d)
+
+    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
+    upgrader.process_tree_inplace(upgrade_dir)
+
+    self.assertTrue(os.path.islink(file_d))
+    self.assertEqual(file_c, os.readlink(file_d))
+    # File pointed to by symlink is in a different directory.
+    # Therefore, it should not be upgraded.
+    with open(file_c, "r") as f:
+      self.assertEqual("import foo as f", f.read())
+
+  def testUpgradeCopyWithSymlink(self):
+    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
+    output_dir = os.path.join(self.get_temp_dir(), "bar")
+    os.mkdir(upgrade_dir)
+    file_a = os.path.join(upgrade_dir, "a.py")
+    file_b = os.path.join(upgrade_dir, "b.py")
+
+    with open(file_a, "a") as f:
+      f.write("import foo as f")
+    os.symlink(file_a, file_b)
+
+    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
+    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)
+
+    new_file_a = os.path.join(output_dir, "a.py")
+    new_file_b = os.path.join(output_dir, "b.py")
+    self.assertTrue(os.path.islink(new_file_b))
+    self.assertEqual(new_file_a, os.readlink(new_file_b))
+    with open(new_file_a, "r") as f:
+      self.assertEqual("import bar as f", f.read())
+
+  def testUpgradeCopyWithSymlinkInDifferentDir(self):
+    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
+    other_dir = os.path.join(self.get_temp_dir(), "bar")
+    output_dir = os.path.join(self.get_temp_dir(), "baz")
+    os.mkdir(upgrade_dir)
+    os.mkdir(other_dir)
+    file_a = os.path.join(other_dir, "a.py")
+    file_b = os.path.join(upgrade_dir, "b.py")
+
+    with open(file_a, "a") as f:
+      f.write("import foo as f")
+    os.symlink(file_a, file_b)
+
+    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
+    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)
+
+    new_file_b = os.path.join(output_dir, "b.py")
+    self.assertTrue(os.path.islink(new_file_b))
+    self.assertEqual(file_a, os.readlink(new_file_b))
+    with open(file_a, "r") as f:
+      self.assertEqual("import foo as f", f.read())
+
 
 if __name__ == "__main__":
   test_lib.main()
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 58f5dff..e5d64a7 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -277,6 +277,8 @@
         'tf.compat.v1.disable_eager_execution',
     'tf.disable_resource_variables':
         'tf.compat.v1.disable_resource_variables',
+    'tf.disable_tensor_equality':
+        'tf.compat.v1.disable_tensor_equality',
     'tf.disable_v2_behavior':
         'tf.compat.v1.disable_v2_behavior',
     'tf.disable_v2_tensorshape':
@@ -331,6 +333,8 @@
         'tf.compat.v1.enable_eager_execution',
     'tf.enable_resource_variables':
         'tf.compat.v1.enable_resource_variables',
+    'tf.enable_tensor_equality':
+        'tf.compat.v1.enable_tensor_equality',
     'tf.enable_v2_behavior':
         'tf.compat.v1.enable_v2_behavior',
     'tf.enable_v2_tensorshape':
@@ -363,6 +367,8 @@
         'tf.compat.v1.estimator.tpu.TPUEstimatorSpec',
     'tf.estimator.tpu.experimental.EmbeddingSpec':
         'tf.compat.v1.estimator.tpu.experimental.EmbeddingSpec',
+    'tf.experimental.output_all_intermediates':
+        'tf.compat.v1.experimental.output_all_intermediates',
     'tf.expm1':
         'tf.math.expm1',
     'tf.fake_quant_with_min_max_args':
@@ -1459,10 +1465,6 @@
         'tf.compat.v1.train.do_quantize_training_on_graphdef',
     'tf.train.experimental.MixedPrecisionLossScaleOptimizer':
         'tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer',
-    'tf.train.experimental.disable_mixed_precision_graph_rewrite':
-        'tf.compat.v1.train.experimental.disable_mixed_precision_graph_rewrite',
-    'tf.train.experimental.enable_mixed_precision_graph_rewrite':
-        'tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite',
     'tf.train.exponential_decay':
         'tf.compat.v1.train.exponential_decay',
     'tf.train.export_meta_graph':
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_12.py b/tensorflow/tools/compatibility/testdata/test_file_v1_12.py
index 42f8cb7..ca33adb 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v1_12.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v1_12.py
@@ -21,10 +21,16 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test as test_lib
 
+_TEST_VERSION = 1
+
 
 class TestUpgrade(test_util.TensorFlowTestCase):
   """Test various APIs that have been changed in 2.0."""
 
+  @classmethod
+  def setUpClass(cls):
+    cls._tf_api_version = 1 if hasattr(tf, 'contrib') else 2
+
   def setUp(self):
     tf.compat.v1.enable_v2_behavior()
 
@@ -74,6 +80,14 @@
     self.assertAllClose(out, 0.40318608)
 
   def testLinearClassifier(self):
+    if _TEST_VERSION == 2 and self._tf_api_version == 1:
+      # Skip if we converted this file to v2 but running with tf v1.
+      # In this case, conversion script adds reference to
+      # tf.keras.losses.Reduction which is not available in v1.
+      self.skipTest(
+          'After converting to 2.0, this test does not work with '
+          'TensorFlow 1.x.')
+      return
     feature_column = tf.feature_column.numeric_column(
         'feature', shape=(1,))
 
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index adc8aa4..221353d 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -1526,6 +1526,10 @@
             "'merge_repeated' argument and behaves as if merge_repeated=False. "
             "This call site specifies something other than "
             "merge_repeated=False, so it was converted to compat.v1."),
+        "tf.nn.dilation2d": functools.partial(
+            _add_argument_transformer,
+            arg_name="data_format",
+            arg_value_ast=ast.Str("NHWC")),
         "tf.nn.erosion2d": functools.partial(
             _add_argument_transformer,
             arg_name="data_format",
@@ -2024,7 +2028,7 @@
 
   Default value for tf.estimator.*Classifier and tf.estimator.*Regressor
   loss_reduction argument changed to SUM_OVER_BATCH_SIZE. So, we update
-  existing calls to use the old default value `tf.losses.Reduction.SUM`.
+  existing calls to use the old default value `tf.keras.losses.Reduction.SUM`.
 
   Note: to apply this transformation, symbol must be added
   to reordered_function_names above.
@@ -2032,9 +2036,7 @@
   for keyword_arg in node.keywords:
     if keyword_arg.arg == "loss_reduction":
       return node
-  # TODO(annarev): this should be updated to tf.keras.losses.Reduction.SUM
-  # once b/125525822 is fixed.
-  default_value = "tf.compat.v1.losses.Reduction.SUM"
+  default_value = "tf.keras.losses.Reduction.SUM"
   # Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
   ast_value = pasta.parse(default_value)
   node.keywords.append(ast.keyword(arg="loss_reduction", value=ast_value))
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 68fe923..4464a2a 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -684,7 +684,7 @@
     for c in classes:
       ns = "tf.estimator." + c
       text = ns + "()"
-      expected_text = ns + "(loss_reduction=tf.compat.v1.losses.Reduction.SUM)"
+      expected_text = ns + "(loss_reduction=tf.keras.losses.Reduction.SUM)"
       _, report, errors, new_text = self._upgrade(text)
       self.assertEqual(expected_text, new_text)
 
@@ -703,7 +703,7 @@
     text = "tf.estimator.BaselineClassifier(model_dir=model_dir)"
     expected_text = ("tf.estimator.BaselineClassifier(" +
                      "model_dir=model_dir, "
-                     "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                     "loss_reduction=tf.keras.losses.Reduction.SUM)")
     _, report, errors, new_text = self._upgrade(text)
     self.assertEqual(expected_text, new_text)
 
@@ -728,7 +728,7 @@
       suffix = "(input_layer_partitioner=TEST)"
       text = ns + suffix
       suffix = ("(input_layer_partitioner=TEST, "
-                "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                "loss_reduction=tf.keras.losses.Reduction.SUM)")
       expected_text = "tf.compat.v1.estimator." + c + suffix
       _, unused_report, unused_errors, new_text = self._upgrade(text)
       self.assertEqual(new_text, expected_text)
@@ -764,7 +764,7 @@
       suffix = "(optimizer=TEST)"
       text = ns + suffix
       suffix = ("(optimizer=TEST, "
-                "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                "loss_reduction=tf.keras.losses.Reduction.SUM)")
       expected_text = "tf.compat.v1.estimator." + c + suffix
       _, unused_report, unused_errors, new_text = self._upgrade(text)
       self.assertEqual(new_text, expected_text)
@@ -779,7 +779,7 @@
       suffix = "(dnn_optimizer=TEST, linear_optimizer=Test)"
       text = ns + suffix
       suffix = ("(dnn_optimizer=TEST, linear_optimizer=Test, "
-                "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                "loss_reduction=tf.keras.losses.Reduction.SUM)")
       expected_text = "tf.compat.v1.estimator." + c + suffix
       _, unused_report, unused_errors, new_text = self._upgrade(text)
       self.assertEqual(new_text, expected_text)
@@ -815,7 +815,7 @@
       suffix = "(input_layer_partitioner=TEST, optimizer=TEST)"
       text = ns + suffix
       suffix = ("(input_layer_partitioner=TEST, optimizer=TEST, "
-                "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                "loss_reduction=tf.keras.losses.Reduction.SUM)")
       expected_text = "tf.compat.v1.estimator." + c + suffix
       _, unused_report, unused_errors, new_text = self._upgrade(text)
       self.assertEqual(new_text, expected_text)
@@ -833,7 +833,7 @@
       text = ns + suffix
       suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
                 "linear_optimizer=TEST, "
-                "loss_reduction=tf.compat.v1.losses.Reduction.SUM)")
+                "loss_reduction=tf.keras.losses.Reduction.SUM)")
       expected_text = "tf.compat.v1.estimator." + c + suffix
       _, unused_report, unused_errors, new_text = self._upgrade(text)
       self.assertEqual(new_text, expected_text)
@@ -2069,6 +2069,12 @@
     _, _, _, new_text = self._upgrade(text)
     self.assertEqual(new_text, expected_text)
 
+  def testNnDilation2d(self):
+    text = "tf.nn.dilation2d(v, k, s, r, p)"
+    expected_text = "tf.nn.dilation2d(v, k, s, r, p, data_format='NHWC')"
+    _, _, _, new_text = self._upgrade(text)
+    self.assertEqual(new_text, expected_text)
+
   def testPywrapTensorflowWarning(self):
     text = "tf.pywrap_tensorflow.foo()"
     expected = "tf.pywrap_tensorflow.foo()"
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
index db17744..4728a4c 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
@@ -98,7 +98,7 @@
     enum34
 
 # Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=0.26.1
 RUN mkdir /bazel && \
     wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
     wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
index 5c3ca61..f4396ca 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
@@ -98,7 +98,7 @@
     enum34
 
 # Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=0.26.1
 RUN mkdir /bazel && \
     wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
     wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
index 02d8f89..80bffe6 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
@@ -32,7 +32,7 @@
 ARG CUDNN_MAJOR_VERSION=7
 ARG LIB_DIR_PREFIX=x86_64
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
@@ -84,6 +84,12 @@
 ARG CHECKOUT_TF_SRC=0
 RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
@@ -130,7 +136,7 @@
     enum34
 
 # Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=0.26.1
 RUN mkdir /bazel && \
     wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
     wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
index 6d00ef3..f45c632 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
@@ -32,7 +32,7 @@
 ARG CUDNN_MAJOR_VERSION=7
 ARG LIB_DIR_PREFIX=x86_64
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
@@ -84,6 +84,12 @@
 ARG CHECKOUT_TF_SRC=0
 RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
@@ -130,7 +136,7 @@
     enum34
 
 # Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=0.26.1
 RUN mkdir /bazel && \
     wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
     wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
index fde7c9e..1a18e64 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
@@ -30,7 +30,7 @@
 ARG CUDA
 ARG CUDNN=7.4.1.5-1
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 # Pick up some TF dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -60,6 +60,12 @@
 # For CUDA profiling, TensorFlow requires CUPTI.
 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
index a6ff1a5..07c775c 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
@@ -30,7 +30,7 @@
 ARG CUDA
 ARG CUDNN=7.4.1.5-1
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 # Pick up some TF dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -60,6 +60,12 @@
 # For CUDA profiling, TensorFlow requires CUPTI.
 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
index a05c718..59768aa 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
@@ -32,7 +32,7 @@
 ARG CUDNN_MAJOR_VERSION=7
 ARG LIB_DIR_PREFIX=x86_64
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
@@ -84,6 +84,12 @@
 ARG CHECKOUT_TF_SRC=0
 RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
index 44d91ad..d4a4c92 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
@@ -32,7 +32,7 @@
 ARG CUDNN_MAJOR_VERSION=7
 ARG LIB_DIR_PREFIX=x86_64
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
@@ -84,6 +84,12 @@
 ARG CHECKOUT_TF_SRC=0
 RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
index b2f1ce1..b265a60 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
@@ -30,7 +30,7 @@
 ARG CUDA
 ARG CUDNN=7.4.1.5-1
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 # Pick up some TF dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -60,6 +60,12 @@
 # For CUDA profiling, TensorFlow requires CUPTI.
 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
index 3422ead..971d765 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
@@ -30,7 +30,7 @@
 ARG CUDA
 ARG CUDNN=7.4.1.5-1
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 # Pick up some TF dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -60,6 +60,12 @@
 # For CUDA profiling, TensorFlow requires CUPTI.
 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
+
 ARG USE_PYTHON_3_NOT_2
 ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
 ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile
index 4f76a1d..7ece3a4 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile
@@ -25,7 +25,7 @@
     enum34
 
 # Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=0.26.1
 RUN mkdir /bazel && \
     wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
     wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
index fc0976b..2ba3a68 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
@@ -9,7 +9,7 @@
 ARG CUDNN_MAJOR_VERSION=7
 ARG LIB_DIR_PREFIX=x86_64
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
@@ -60,3 +60,9 @@
 # Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
 ARG CHECKOUT_TF_SRC=0
 RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
+
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
index b09c645..bb9253a 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
@@ -7,7 +7,7 @@
 ARG CUDA
 ARG CUDNN=7.4.1.5-1
 
-# Needed for string substitution 
+# Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 # Pick up some TF dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -36,3 +36,9 @@
 
 # For CUDA profiling, TensorFlow requires CUPTI.
 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+
+# Link the libcuda stub to the location where tensorflow is searching for it and reconfigure
+# dynamic linker run-time bindings
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 \
+    && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \
+    && ldconfig
diff --git a/tensorflow/tools/dockerfiles/tests/build-gpu.sh b/tensorflow/tools/dockerfiles/tests/build-gpu.sh
index 76b25d5..0e107e3 100755
--- a/tensorflow/tools/dockerfiles/tests/build-gpu.sh
+++ b/tensorflow/tools/dockerfiles/tests/build-gpu.sh
@@ -22,8 +22,6 @@
 
 ln -s $(which ${PYTHON}) /usr/local/bin/python 
 
-ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
-
 LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
 tensorflow/tools/ci_build/builds/configured GPU \
 bazel build -c opt --copt=-mavx --config=cuda \
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index c98c4ee..adafe2a 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -336,4 +336,5 @@
         "//tensorflow/python:variables",
     ],
     main = "python/transform_graph_test.py",
+    tags = ["v1only"],
 )
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index a90916c..34d6305 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -111,7 +111,7 @@
 tool can inspect the model and provide guesses about likely input and output nodes,
 as well as other information that's useful for debugging. Here's an example of
 how to use it on the [Inception V3
-graph](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz):
+graph](https://storage.googleapis.com/download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz):
 
 ```bash
 bazel build tensorflow/tools/graph_transforms:summarize_graph
@@ -124,7 +124,7 @@
 transformation pipelines, aimed at users who want to quickly accomplish one of
 these tasks. A lot of them will use the Inception V3 model for their examples,
 which can be downloaded from
-[http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
+[https://storage.googleapis.com/download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](https://storage.googleapis.com/download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
 
 ### Optimizing for Deployment
 
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 49e5cca..cc4078d 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -126,7 +126,7 @@
     if (node.name() == tensor_names_node) {
       Tensor tensor_names_tensor;
       TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
-      const auto& tensor_names_value = tensor_names_tensor.flat<string>();
+      const auto& tensor_names_value = tensor_names_tensor.flat<tstring>();
       for (int i = 0; i < tensor_names_value.size(); i++) {
         if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
           offset = i;
@@ -144,7 +144,7 @@
       Tensor shape_and_slices_tensor;
       TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
       const auto& shape_and_slices_value =
-          shape_and_slices_tensor.flat<string>();
+          shape_and_slices_tensor.flat<tstring>();
       *shape_slice_string = shape_and_slices_value(offset);
       return Status::OK();
     }
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index b8d6ba0..dfe8fb0 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -116,7 +116,7 @@
       NodeDef* tensor_shapes_slices_node = CreateNode(
           "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
       Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
-      shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
+      shapes_slices_val.flat<tstring>()(0) = "4 1 0,4:0,1";
       SetNodeTensorAttr<string>("value", shapes_slices_val,
                                 tensor_shapes_slices_node);
 
@@ -327,8 +327,8 @@
       NodeDef* tensor_shapes_slices_node = CreateNode(
           "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
       Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
-      shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
-      shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
+      shapes_slices_val.flat<tstring>()(0) = "4 1 0,4:0,1";
+      shapes_slices_val.flat<tstring>()(1) = "4 1 0,4:0,1";
       SetNodeTensorAttr<string>("value", shapes_slices_val,
                                 tensor_shapes_slices_node);
 
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index b8ebdc8..75e01e3 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -8,7 +8,7 @@
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
 load("//tensorflow:tensorflow.bzl", "VERSION", "VERSION_MAJOR", "if_macos")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "tf_additional_license_deps")
 load("//third_party/mkl:build_defs.bzl", "if_mkl")
 
 genrule(
diff --git a/tensorflow/tools/optimization/optimization_pass_runner.cc b/tensorflow/tools/optimization/optimization_pass_runner.cc
index 162d39d..8cd9e32 100644
--- a/tensorflow/tools/optimization/optimization_pass_runner.cc
+++ b/tensorflow/tools/optimization/optimization_pass_runner.cc
@@ -111,8 +111,8 @@
   GraphConstructorOptions graph_opts;
   graph_opts.expect_device_spec = true;
   graph_opts.allow_internal_ops = true;
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(graph_opts, input, options.graph->get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_opts, std::move(input),
+                                            options.graph->get()));
 
   // Add all devices that were previously configured with AddDevice.
   DeviceSet device_set;
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index e9a017a..08ce526 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -13,7 +13,7 @@
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
+load("//tensorflow/core/platform:default/build_config_root.bzl", "tf_additional_license_deps")
 load(
     "//third_party/ngraph:build_defs.bzl",
     "if_ngraph",
@@ -85,6 +85,7 @@
     "//tensorflow/python/debug:debug_pip",
     "//tensorflow/python/distribute:combinations",
     "//tensorflow/python/eager:eager_pip",
+    "//tensorflow/python/keras:model_subclassing_test_util",
     "//tensorflow/python/keras:preprocessing_test_utils",
     "//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
     "//tensorflow/python/keras/distribute:multi_worker_testing_utils",
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 4e5db0c..893a7f3 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -17,7 +17,7 @@
 
 # For platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
+    "//tensorflow/core/platform:default/build_config.bzl",
     "tf_proto_library_cc",
 )
 
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index f888e2d..485fa71 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -7,7 +7,6 @@
 load("//third_party/mkl:build_defs.bzl", "mkl_repository")
 load("//third_party/git:git_configure.bzl", "git_configure")
 load("//third_party/py:python_configure.bzl", "python_configure")
-load("//third_party/mlir:mlir_configure.bzl", "mlir_configure")
 load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
 load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
 load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure")
@@ -74,7 +73,10 @@
     syslibs_configure(name = "local_config_syslibs")
     python_configure(name = "local_config_python")
     rocm_configure(name = "local_config_rocm")
-    mlir_configure(name = "local_config_mlir")
+    native.local_repository(
+        name = "local_config_mlir",
+        path = "third_party/mlir",
+    )
     remote_execution_configure(name = "local_config_remote_execution")
 
     initialize_third_party()
@@ -155,11 +157,11 @@
     tf_http_archive(
         name = "com_google_absl",
         build_file = clean_dep("//third_party:com_google_absl.BUILD"),
-        sha256 = "eee7452846aae8040037234accf9a1cfbeca1d93bb4238b70f0d43d26645a602",
-        strip_prefix = "abseil-cpp-f3840bc5e33ce4932e35986cf3718450c6f02af2",
+        sha256 = "acd93f6baaedc4414ebd08b33bebca7c7a46888916101d8c0b8083573526d070",
+        strip_prefix = "abseil-cpp-43ef2148c0936ebf7cb4be6b19927a9d9d145b8f",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/f3840bc5e33ce4932e35986cf3718450c6f02af2.tar.gz",
-            "https://github.com/abseil/abseil-cpp/archive/f3840bc5e33ce4932e35986cf3718450c6f02af2.tar.gz",
+            "http://mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/43ef2148c0936ebf7cb4be6b19927a9d9d145b8f.tar.gz",
+            "https://github.com/abseil/abseil-cpp/archive/43ef2148c0936ebf7cb4be6b19927a9d9d145b8f.tar.gz",
         ],
     )
 
@@ -167,11 +169,11 @@
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
-        strip_prefix = "eigen-eigen-049af2f56331",
+        sha256 = "7e7a57e33c59280a17a66e521396cd8b1a55d0676c9f807078522fda52114b5c",
+        strip_prefix = "eigen-eigen-8071cda5714d",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
-            "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/8071cda5714d.tar.gz",
+            "https://bitbucket.org/eigen/eigen/get/8071cda5714d.tar.gz",
         ],
     )
 
@@ -387,6 +389,7 @@
         ],
         sha256 = "8ad8c4783bf61ded74527bffb48ed9b54166685e4230386a9ed9b1279e2df5b1",
         build_file = clean_dep("//third_party:enum34.BUILD"),
+        system_build_file = clean_dep("//third_party/systemlibs:enum34.BUILD"),
         strip_prefix = "enum34-1.1.6/enum",
     )
 
@@ -495,12 +498,12 @@
     tf_http_archive(
         name = "curl",
         build_file = clean_dep("//third_party:curl.BUILD"),
-        sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
-        strip_prefix = "curl-7.60.0",
+        sha256 = "821aeb78421375f70e55381c9ad2474bf279fc454b791b7e95fc83562951c690",
+        strip_prefix = "curl-7.65.1",
         system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/curl.haxx.se/download/curl-7.60.0.tar.gz",
-            "https://curl.haxx.se/download/curl-7.60.0.tar.gz",
+            "http://mirror.tensorflow.org/curl.haxx.se/download/curl-7.65.1.tar.gz",
+            "https://curl.haxx.se/download/curl-7.65.1.tar.gz",
         ],
     )
 
@@ -543,11 +546,11 @@
     tf_http_archive(
         name = "llvm",
         build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
-        sha256 = "88012afcd6d8238430d39967b62e5599bc31d9c4cdc6d20281bedf1020b7000b",
-        strip_prefix = "llvm-b7d166cebcf619a3691eed3f994384aab3d80fa6",
+        sha256 = "4aab057172b4b5f6d50abfd6175707d8ca31944c42fbfd08d914ec1503f4b32e",
+        strip_prefix = "llvm-bd17a8c045af512595fab6e255b285496128177c",
         urls = [
-            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/b7d166cebcf619a3691eed3f994384aab3d80fa6.tar.gz",
-            "https://github.com/llvm-mirror/llvm/archive/b7d166cebcf619a3691eed3f994384aab3d80fa6.tar.gz",
+            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bd17a8c045af512595fab6e255b285496128177c.tar.gz",
+            "https://github.com/llvm-mirror/llvm/archive/bd17a8c045af512595fab6e255b285496128177c.tar.gz",
         ],
     )
 
@@ -787,8 +790,8 @@
         build_file = clean_dep("//third_party:tflite_mobilenet_float.BUILD"),
         sha256 = "2fadeabb9968ec6833bee903900dda6e61b3947200535874ce2fe42a8493abc0",
         urls = [
-            "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
-            "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
+            "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
+            "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
         ],
     )
 
@@ -797,8 +800,8 @@
         build_file = clean_dep("//third_party:tflite_mobilenet_quant.BUILD"),
         sha256 = "d32432d28673a936b2d6281ab0600c71cf7226dfe4cdcef3012555f691744166",
         urls = [
-            "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
-            "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+            "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+            "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
         ],
     )
 
@@ -829,7 +832,7 @@
         strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
         urls = [
             "https://storage.googleapis.com/mirror.tensorflow.org/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
-            "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+            "https://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
         ],
     )
 
@@ -921,11 +924,11 @@
     tf_http_archive(
         name = "pybind11",
         urls = [
-            "https://mirror.bazel.build/github.com/pybind/pybind11/archive/v2.2.4.tar.gz",
-            "https://github.com/pybind/pybind11/archive/v2.2.4.tar.gz",
+            "https://mirror.bazel.build/github.com/pybind/pybind11/archive/v2.3.0.tar.gz",
+            "https://github.com/pybind/pybind11/archive/v2.3.0.tar.gz",
         ],
-        sha256 = "b69e83658513215b8d1443544d0549b7d231b9f201f6fc787a2b2218b408181e",
-        strip_prefix = "pybind11-2.2.4",
+        sha256 = "0f34838f2c8024a6765168227ba587b3687729ebf03dc912f88ff75c7aa9cfe8",
+        strip_prefix = "pybind11-2.3.0",
         build_file = clean_dep("//third_party:pybind11.BUILD"),
     )
 
diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD
index a3aa3ce..6688080 100644
--- a/third_party/curl.BUILD
+++ b/third_party/curl.BUILD
@@ -154,8 +154,7 @@
         "lib/parsedate.c",
         "lib/parsedate.h",
         "lib/pingpong.h",
-        "lib/pipeline.c",
-        "lib/pipeline.h",
+        "lib/pingpong.c",
         "lib/pop3.h",
         "lib/progress.c",
         "lib/progress.h",
@@ -217,9 +216,7 @@
         "lib/vauth/vauth.c",
         "lib/vauth/vauth.h",
         "lib/version.c",
-        "lib/vtls/axtls.h",
         "lib/vtls/cyassl.h",
-        "lib/vtls/darwinssl.h",
         "lib/vtls/gskit.h",
         "lib/vtls/gtls.h",
         "lib/vtls/mbedtls.h",
@@ -235,12 +232,25 @@
         "lib/wildcard.c",
         "lib/wildcard.h",
         "lib/x509asn1.h",
+        "lib/psl.h",
+        "lib/psl.c",
+        "lib/vtls/sectransp.h",
+        "lib/vtls/mesalink.h",
+        "lib/vtls/mesalink.c",
+        "lib/curl_get_line.h",
+        "lib/curl_get_line.c",
+        "lib/urlapi-int.h",
+        "lib/urlapi.c",
+        "lib/altsvc.h",
+        "lib/altsvc.c",
+        "lib/doh.h",
+        "lib/doh.c",
     ] + select({
         "@org_tensorflow//tensorflow:macos": [
-            "lib/vtls/darwinssl.c",
+            "lib/vtls/sectransp.c",
         ],
         "@org_tensorflow//tensorflow:ios": [
-            "lib/vtls/darwinssl.c",
+            "lib/vtls/sectransp.c",
         ],
         "@org_tensorflow//tensorflow:windows": CURL_WIN_SRCS,
         "//conditions:default": [
@@ -256,6 +266,7 @@
         "include/curl/stdcheaders.h",
         "include/curl/system.h",
         "include/curl/typecheck-gcc.h",
+        "include/curl/urlapi.h",
     ],
     copts = select({
         "@org_tensorflow//tensorflow:windows": CURL_WIN_COPTS,
@@ -465,7 +476,7 @@
         "#  define HAVE_SYS_FILIO_H 1",
         "#  define HAVE_SYS_SOCKIO_H 1",
         "#  define OS \"x86_64-apple-darwin15.5.0\"",
-        "#  define USE_DARWINSSL 1",
+        "#  define USE_SECTRANSP 1",
         "#else",
         "#  define CURL_CA_BUNDLE \"/etc/ssl/certs/ca-certificates.crt\"",
         "#  define GETSERVBYPORT_R_ARGS 6",
diff --git a/third_party/icu/BUILD.bazel b/third_party/icu/BUILD.bazel
index 36d6b90..6949656 100644
--- a/third_party/icu/BUILD.bazel
+++ b/third_party/icu/BUILD.bazel
@@ -44,7 +44,7 @@
     ]),
     copts = [
         "-DU_COMMON_IMPLEMENTATION",
-        "-DU_HAVE_STD_ATOMICS",
+        "-DU_HAVE_STD_ATOMICS",  # TODO(gunan): Remove when TF is on ICU 64+.
     ] + select({
         ":android": [
             "-fdata-sections",
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index 4003262..3270532 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -2778,6 +2778,7 @@
     deps = [
         ":config",
         ":debug_info_code_view",
+        ":mc",
         ":object",
         ":support",
     ],
diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index efb62a4..8b0fdec 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -354,7 +354,7 @@
         "UNICODE",
         "_UNICODE",
     ],
-    "//conditions:default": ["_DEBUG"],
+    "//conditions:default": [],
 }) + [
     "LLVM_ENABLE_STATS",
     "__STDC_LIMIT_MACROS",
diff --git a/third_party/mlir/.clang-format b/third_party/mlir/.clang-format
new file mode 100644
index 0000000..392e201
--- /dev/null
+++ b/third_party/mlir/.clang-format
@@ -0,0 +1,2 @@
+BasedOnStyle: LLVM
+AlwaysBreakTemplateDeclarations: Yes
\ No newline at end of file
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 9e6a07a..4f6f229 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -13,11 +13,16 @@
     packages = ["//..."],
 )
 
-# Please do not depend on this from any other packages.
+# Before adding a project here, please read go/mlir-sla
+# In particular the OWNERS file of the dependent project should be updated.
 package_group(
     name = "friends",
     includes = ["@org_tensorflow//tensorflow/compiler/mlir:subpackages"],
-    packages = ["//..."],
+    packages = [
+        "//...",
+        "//learning/glassbox/evaluation/compiler/...",
+        "//tensorflow/compiler/xla/service/gpu/mlir/...",
+    ],
 )
 
 exports_files([
@@ -42,6 +47,7 @@
         "lib/IR/Diagnostics.cpp",
         "lib/IR/Dialect.cpp",
         "lib/IR/Function.cpp",
+        "lib/IR/FunctionSupport.cpp",
         "lib/IR/IntegerSet.cpp",
         "lib/IR/IntegerSetDetail.h",
         "lib/IR/Location.cpp",
@@ -55,6 +61,7 @@
         "lib/IR/StandardTypes.cpp",
         "lib/IR/SymbolTable.cpp",
         "lib/IR/TypeDetail.h",
+        "lib/IR/TypeUtilities.cpp",
         "lib/IR/Types.cpp",
         "lib/IR/Value.cpp",
     ],
@@ -72,6 +79,7 @@
         "include/mlir/IR/DialectHooks.h",
         "include/mlir/IR/DialectSymbolRegistry.def",
         "include/mlir/IR/Function.h",
+        "include/mlir/IR/FunctionSupport.h",
         "include/mlir/IR/Identifier.h",
         "include/mlir/IR/IntegerSet.h",
         "include/mlir/IR/Location.h",
@@ -89,6 +97,7 @@
         "include/mlir/IR/StorageUniquerSupport.h",
         "include/mlir/IR/SymbolTable.h",
         "include/mlir/IR/TypeSupport.h",
+        "include/mlir/IR/TypeUtilities.h",
         "include/mlir/IR/Types.h",
         "include/mlir/IR/UseDefLists.h",
         "include/mlir/IR/Value.h",
@@ -168,6 +177,7 @@
     name = "AffineOpsTdFiles",
     srcs = [
         "include/mlir/AffineOps/AffineOps.td",
+        "include/mlir/AffineOps/AffineOpsBase.td",
         ":OpBaseTdFiles",
     ],
 )
@@ -376,6 +386,7 @@
     deps = [
         ":IR",
         ":Support",
+        ":VectorOpsIncGen",
         "@llvm//:support",
     ],
 )
@@ -405,6 +416,7 @@
         "include/mlir/Support/MathExtras.h",
         "include/mlir/Support/STLExtras.h",
         "include/mlir/Support/StorageUniquer.h",
+        "include/mlir/Support/StringExtras.h",
     ],
     includes = ["include"],
     deps = [
@@ -413,21 +425,6 @@
 )
 
 cc_library(
-    name = "TypeUtilities",
-    srcs = [
-        "lib/Support/TypeUtilities.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Support/TypeUtilities.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":IR",
-        "@llvm//:support",
-    ],
-)
-
-cc_library(
     name = "Parser",
     srcs = [
         "lib/Parser/Lexer.cpp",
@@ -474,7 +471,7 @@
 filegroup(
     name = "GPUOpsTdFiles",
     srcs = [
-        "include/mlir/GPU/GPUOps.td",
+        "include/mlir/Dialect/GPU/GPUOps.td",
         ":OpBaseTdFiles",
     ],
 )
@@ -484,15 +481,15 @@
     tbl_outs = [
         (
             "-gen-op-decls",
-            "include/mlir/GPU/GPUOps.h.inc",
+            "include/mlir/Dialect/GPU/GPUOps.h.inc",
         ),
         (
             "-gen-op-defs",
-            "include/mlir/GPU/GPUOps.cpp.inc",
+            "include/mlir/Dialect/GPU/GPUOps.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/GPU/GPUOps.td",
+    td_file = "include/mlir/Dialect/GPU/GPUOps.td",
     td_srcs = [
         ":GPUOpsTdFiles",
     ],
@@ -500,8 +497,8 @@
 
 cc_library(
     name = "GPUDialect",
-    srcs = ["lib/GPU/IR/GPUDialect.cpp"],
-    hdrs = ["include/mlir/GPU/GPUDialect.h"],
+    srcs = ["lib/Dialect/GPU/IR/GPUDialect.cpp"],
+    hdrs = ["include/mlir/Dialect/GPU/GPUDialect.h"],
     includes = ["include"],
     deps = [
         ":GPUOpsIncGen",
@@ -513,7 +510,7 @@
 
 cc_library(
     name = "GPUDialectRegistration",
-    srcs = ["lib/GPU/IR/DialectRegistration.cpp"],
+    srcs = ["lib/Dialect/GPU/IR/DialectRegistration.cpp"],
     includes = ["include"],
     deps = [
         ":GPUDialect",
@@ -523,8 +520,8 @@
 
 cc_library(
     name = "GPUTransforms",
-    srcs = ["lib/GPU/Transforms/KernelOutlining.cpp"],
-    hdrs = ["include/mlir/GPU/Passes.h"],
+    srcs = ["lib/Dialect/GPU/Transforms/KernelOutlining.cpp"],
+    hdrs = ["include/mlir/Dialect/GPU/Passes.h"],
     includes = ["include"],
     deps = [
         ":GPUDialect",
@@ -588,6 +585,23 @@
     alwayslink = 1,
 )
 
+cc_library(
+    name = "GPUToSPIRVTransforms",
+    srcs = [
+        "lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":GPUDialect",
+        ":IR",
+        ":Pass",
+        ":SPIRVConversions",
+        ":SPIRVDialect",
+        ":StandardOps",
+    ],
+    alwayslink = 1,
+)
+
 gentbl(
     name = "LLVMOpsIncGen",
     tbl_outs = [
@@ -739,15 +753,15 @@
 )
 
 gentbl(
-    name = "StdOpsToSPIRVConversionIncGen",
+    name = "StandardToSPIRVGen",
     tbl_outs = [
         (
             "-gen-rewriters",
-            "lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp.inc",
+            "lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.td",
+    td_file = "lib/Conversion/StandardToSPIRV/StandardToSPIRV.td",
     td_srcs = [
         ":SPIRVOpsTdFiles",
         ":StdOpsTdFiles",
@@ -770,10 +784,10 @@
 )
 
 gentbl(
-    name = "SPIRVSerializationIncGen",
+    name = "SPIRVSerializationGen",
     tbl_outs = [
         (
-            "-gen-spirv-serial",
+            "-gen-spirv-serialization",
             "include/mlir/Dialect/SPIRV/SPIRVSerialization.inc",
         ),
     ],
@@ -816,10 +830,12 @@
 cc_library(
     name = "SPIRVConversions",
     srcs = [
-        "lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp",
-        "lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp.inc",
+        "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp",
+        "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp",
+        "lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp.inc",
     ],
     hdrs = [
+        "include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h",
         "include/mlir/Dialect/SPIRV/Passes.h",
     ],
     includes = [
@@ -831,8 +847,9 @@
         ":Pass",
         ":SPIRVDialect",
         ":StandardOps",
-        ":StdOpsToSPIRVConversionIncGen",
+        ":StandardToSPIRVGen",
         ":Support",
+        ":Transforms",
         "@llvm//:support",
     ],
     alwayslink = 1,
@@ -843,17 +860,18 @@
     srcs = [
         "include/mlir/Dialect/SPIRV/SPIRVSerialization.inc",
         "lib/Dialect/SPIRV/Serialization/Deserializer.cpp",
-        "lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.h",
+        "lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp",
         "lib/Dialect/SPIRV/Serialization/Serializer.cpp",
     ],
     hdrs = [
+        "include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h",
         "include/mlir/Dialect/SPIRV/Serialization.h",
     ],
     includes = ["include"],
     deps = [
         ":IR",
         ":SPIRVDialect",
-        ":SPIRVSerializationIncGen",
+        ":SPIRVSerializationGen",
         ":Support",
         "@llvm//:support",
     ],
@@ -937,14 +955,13 @@
 cc_library(
     name = "Transforms",
     srcs = [
+        "lib/Transforms/AffineDataCopyGeneration.cpp",
         "lib/Transforms/CSE.cpp",
         "lib/Transforms/Canonicalizer.cpp",
         "lib/Transforms/DialectConversion.cpp",
-        "lib/Transforms/DmaGeneration.cpp",
         "lib/Transforms/LoopCoalescing.cpp",
         "lib/Transforms/LoopFusion.cpp",
         "lib/Transforms/LoopInvariantCodeMotion.cpp",
-        "lib/Transforms/LoopParametricTiling.cpp",
         "lib/Transforms/LoopTiling.cpp",
         "lib/Transforms/LoopUnroll.cpp",
         "lib/Transforms/LoopUnrollAndJam.cpp",
@@ -1103,6 +1120,7 @@
     deps = [
         ":AffineOps",
         ":IR",
+        ":LoopOps",
         ":Pass",
         ":StandardOps",
         ":Support",
@@ -1244,6 +1262,7 @@
     deps = [
         ":Analysis",
         ":GPUToNVVMTransforms",
+        ":GPUToSPIRVTransforms",
         ":GPUTransforms",
         ":IR",
         ":LLVMDialect",
@@ -1352,9 +1371,9 @@
         ":StandardDialectRegistration",
         ":Transforms",
         ":VectorDialectRegistration",
+        "//test:TestDialect",
+        "//test:TestTransforms",
         "@llvm//:support",
-        "@local_config_mlir//test:TestDialect",
-        "@local_config_mlir//test:TestTransforms",
     ],
 )
 
@@ -1376,6 +1395,7 @@
         ":Support",
         ":Transforms",
         "@llvm//:core",
+        "@llvm//:orc_jit",
         "@llvm//:support",
     ],
     alwayslink = 1,
@@ -1679,6 +1699,7 @@
     srcs = [
         "include/mlir/Linalg/IR/LinalgBase.td",
         "include/mlir/Linalg/IR/LinalgLibraryOps.td",
+        ":AffineOpsTdFiles",
         ":OpBaseTdFiles",
     ],
 )
@@ -1724,6 +1745,7 @@
         "include/mlir/Linalg/Utils/Utils.h",
     ],
     deps = [
+        ":AffineOps",
         ":CFGTransforms",
         ":EDSC",
         ":IR",
@@ -1811,6 +1833,33 @@
     alwayslink = 1,
 )
 
+filegroup(
+    name = "VectorOpsTdFiles",
+    srcs = [
+        "include/mlir/VectorOps/VectorOps.td",
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl(
+    name = "VectorOpsIncGen",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/VectorOps/VectorOps.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/VectorOps/VectorOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/VectorOps/VectorOps.td",
+    td_srcs = [
+        ":VectorOpsTdFiles",
+    ],
+)
+
 # To reference all tablegen files here when checking for updates to them.
 filegroup(
     name = "TdFiles",
diff --git a/third_party/mlir/CMakeLists.txt b/third_party/mlir/CMakeLists.txt
new file mode 100644
index 0000000..86266a2
--- /dev/null
+++ b/third_party/mlir/CMakeLists.txt
@@ -0,0 +1,63 @@
+# MLIR project.
+set(MLIR_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include ) # --src-root
+set(MLIR_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/include ) # --includedir
+set(MLIR_TABLEGEN_EXE mlir-tblgen)
+
+set(MLIR_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
+
+# TODO: Temporary, remove when no longer needed.
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+function(mlir_tablegen ofn)
+  tablegen(MLIR ${ARGV} "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}")
+  set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
+      PARENT_SCOPE)
+endfunction()
+
+# TODO: This is to handle the current static registration, but should be
+# factored out a bit.
+function(whole_archive_link target)
+  if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
+    set(link_flags "-L${CMAKE_BINARY_DIR}/lib ")
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${CMAKE_BINARY_DIR}/lib/lib${LIB}.a ")
+    ENDFOREACH(LIB)
+  elseif(MSVC)
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
+    ENDFOREACH(LIB)
+  else()
+    set(link_flags "-L${CMAKE_BINARY_DIR}/lib -Wl,--whole-archive,")
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "-l${LIB},")
+    ENDFOREACH(LIB)
+    string(CONCAT link_flags ${link_flags} "--no-whole-archive")
+  endif()
+  set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
+endfunction(whole_archive_link)
+
+# Build the CUDA conversions and run according tests if the NVPTX backend
+# is available
+if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
+  set(MLIR_CUDA_CONVERSIONS_ENABLED 1)
+else()
+  set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
+endif()
+
+set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
+
+include_directories( "include")
+include_directories( ${MLIR_INCLUDE_DIR})
+
+add_subdirectory(include/mlir)
+add_subdirectory(lib)
+add_subdirectory(tools)
+add_subdirectory(unittests)
+add_subdirectory(test)
+
+if( LLVM_INCLUDE_EXAMPLES )
+  add_subdirectory(examples)
+endif()
diff --git a/third_party/mlir/CONTRIBUTING.md b/third_party/mlir/CONTRIBUTING.md
new file mode 100644
index 0000000..e21e4b8
--- /dev/null
+++ b/third_party/mlir/CONTRIBUTING.md
@@ -0,0 +1,49 @@
+# How to Contribute
+
+Everyone is welcome to contribute to MLIR. There are several ways of getting involved and contributing including reporting bugs, improving documentation, writing models or tutorials. 
+
+Please read our [Code of Conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md) before participating.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google.com/conduct/).
+
+## How to become a contributor and submit your own code
+
+### Contributor License Agreements
+
+We'd love to accept your patches! Before we can take them, please fill out either the individual or corporate Contributor License Agreement (CLA).
+
+* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
+  * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
+
+Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
+
+***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
+
+### Contributing code
+
+If you have improvements to MLIR, send us your pull requests! For those
+just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/).
+
+MLIR team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, a team member will merge your pull request submitted to our internal repository. After the change has been submitted internally, your pull request will be merged automatically on GitHub.
+
+If you want to contribute, start working through the MLIR codebase, navigate to [Github "issues" tab](https://github.com/tensorflow/mlir/issues) and start looking through interesting issues. If you decide to start on an issue, leave a comment so that other people know that you're working on it. If you want to help out, but not alone, use the issue comment thread to coordinate.
+
+### Contribution guidelines and standards
+
+*   Read the [developer guide](g3doc/DeveloperGuide.md).
+*   Ensure that you use the correct license. Examples are provided below.
+*   Include tests when you contribute new features, as they help to a)
+    prove that your code works correctly, and b) guard against future breaking
+    changes to lower the maintenance cost.
+*   Bug fixes also generally require tests, because the presence of bugs
+    usually indicates insufficient test coverage.
+
+#### License
+
+Include a license at the top of new files.
+
+* [C/C++ license example](https://github.com/tensorflow/mlir/blob/master/examples/toy/Ch1/toyc.cpp)
+* [Python license example](https://github.com/tensorflow/mlir/blob/master/bindings/python/test/test_py2and3.py)
diff --git a/third_party/mlir/LICENSE.TXT b/third_party/mlir/LICENSE.TXT
new file mode 100644
index 0000000..a4b160b
--- /dev/null
+++ b/third_party/mlir/LICENSE.TXT
@@ -0,0 +1,205 @@
+Copyright 2019 The MLIR Authors.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+
diff --git a/third_party/mlir/README.md b/third_party/mlir/README.md
new file mode 100644
index 0000000..104d9ad
--- /dev/null
+++ b/third_party/mlir/README.md
@@ -0,0 +1,132 @@
+# Multi-Level Intermediate Representation Overview
+
+The MLIR project aims to define a common intermediate representation (IR) that
+will unify the infrastructure required to execute high performance machine
+learning models in TensorFlow and similar ML frameworks. This project will
+include the application of HPC techniques, along with integration of search
+algorithms like reinforcement learning. This project aims to reduce the cost to
+bring up new hardware, and improve usability for existing TensorFlow users.
+
+Note that this repository contains the core of the MLIR framework. The
+TensorFlow compilers we are building on top of MLIR will be part of the
+main TensorFlow repository soon.
+
+# How to Contribute
+
+Thank you for your interest in contributing to MLIR! If you want to contribute
+to MLIR, be sure to review the [contribution guidelines](CONTRIBUTING.md).
+
+## More resources
+
+For more information on MLIR, please see:
+
+*   [The MLIR draft specification](g3doc/LangRef.md), which describes the IR
+    itself.
+*   [The MLIR rationale document](g3doc/Rationale.md), covering motivation
+    behind some decisions.
+*   Previous external [talks](#mlir-talks).
+
+Join the [MLIR mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/mlir)
+to hear about announcements and discussions.
+Please be mindful of the [TensorFlow Code of Conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md),
+which pledges to foster an open and welcoming environment.
+
+## What is MLIR for?
+
+MLIR is intended to be a hybrid IR which can support multiple different
+requirements in a unified infrastructure. For example, this includes:
+
+*   The ability to represent all TensorFlow graphs, including dynamic shapes,
+    the user-extensible op ecosystem, TensorFlow variables, etc.
+*   Optimizations and transformations typically done on a TensorFlow graph, e.g.
+    in Grappler.
+*   Quantization and other graph transformations done on a TensorFlow graph or
+    the TF Lite representation.
+*   Representation of kernels for ML operations in a form suitable for
+    optimization.
+*   Ability to host high-performance-computing-style loop optimizations across
+    kernels (fusion, loop interchange, tiling, etc) and to transform memory
+    layouts of data.
+*   Code generation "lowering" transformations such as DMA insertion, explicit
+    cache management, memory tiling, and vectorization for 1D and 2D register
+    architectures.
+*   Ability to represent target-specific operations, e.g. the MXU on TPUs.
+
+MLIR is a common IR that also supports hardware specific operations. Thus,
+any investment into the infrastructure surrounding MLIR (e.g. the compiler
+passes that work on it) should yield good returns; many targets can use that
+infrastructure and will benefit from it.
+
+MLIR is a powerful representation, but it also has non-goals. We do not try to
+support low level machine code generation algorithms (like register allocation
+and instruction scheduling). They are a better fit for lower level optimizers
+(such as LLVM). Also, we do not intend MLIR to be a source language that
+end-users would themselves write kernels in (analogous to CUDA C++). While we
+would love to see a kernel language happen someday, that will be an independent
+project that compiles down to MLIR.
+
+## Compiler infrastructure
+
+We benefited from experience gained from building other IRs (HLO, LLVM and SIL)
+when building MLIR. We will directly adopt existing best practices, e.g. writing
+and maintaining an IR spec, building an IR verifier, providing the ability to
+dump and parse MLIR files to text, writing extensive unit tests with the
+[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tool, and
+building the infrastructure as a set of modular libraries that can be combined
+in new ways. We plan to use the infrastructure developed by the XLA team for
+performance analysis and benchmarking.
+
+Other lessons have been incorporated and integrated into the design in subtle
+ways. For example, LLVM has non-obvious design mistakes that prevent a
+multithreaded compiler from working on multiple functions in an LLVM module at
+the same time. MLIR solves these problems by having per-function constant pools
+and by making references explicit with `function_ref`.
+
+# Getting started with MLIR
+
+The following instructions for compiling and testing MLIR assume that you have
+`git`, [`ninja`](https://ninja-build.org/), and a working C++ toolchain. In the
+future, we aim to align on the same level of platform support as
+[LLVM](https://llvm.org/docs/GettingStarted.html#requirements). For now, MLIR
+has been tested on Linux and macOS, with recent versions of clang and with
+gcc 7.
+
+```sh
+git clone https://github.com/llvm/llvm-project.git
+git clone https://github.com/tensorflow/mlir llvm-project/llvm/projects/mlir
+mkdir llvm-project/build
+cd llvm-project/build
+cmake -G Ninja ../llvm -DLLVM_BUILD_EXAMPLES=ON -DLLVM_ENABLE_CXX1Y=Y -DLLVM_TARGETS_TO_BUILD="host"
+cmake --build . --target check-mlir
+```
+
+To compile and test on Windows using Visual Studio 2017:
+
+```bat
+REM In shell with Visual Studio environment set up, e.g., with command such as
+REM   <visual-studio-install>\Auxiliary\Build\vcvarsall.bat" x64
+REM invoked.
+git clone https://github.com/llvm/llvm-project.git
+git clone https://github.com/tensorflow/mlir llvm-project\llvm\projects\mlir
+mkdir llvm-project\build
+cd llvm-project\build
+cmake ..\llvm -G "Visual Studio 15 2017 Win64" -DLLVM_BUILD_EXAMPLES=ON -DLLVM_ENABLE_CXX1Y=Y -DLLVM_TARGETS_TO_BUILD="host" -DCMAKE_BUILD_TYPE=Release -Thost=x64
+cmake --build . --target check-mlir
+```
+
+As a starter, you may try [the tutorial](g3doc/Tutorials/Toy/Ch-1.md) on
+building a compiler for a Toy language.
+
+# MLIR talks
+
+* "[MLIR Primer: A Compiler Infrastructure for the End of Moore’s Law](https://ai.google/research/pubs/pub48035.pdf)"
+  * Chris Lattner & Jacques Pienaar, Google at
+    [Compilers for Machine Learning](https://www.c4ml.org/) workshop at
+    [CGO 2019](http://cgo.org/cgo2019/)
+* "[MLIR: Multi-Level Intermediate Representation for Compiler
+    Infrastructure](https://llvm.org/devmtg/2019-04/talks.html#Keynote_1)"
+  * Tatiana Shpeisman & Chris Lattner, Google at
+    [EuroLLVM 2019](https://llvm.org/devmtg/2019-04)
+* "[Tutorial: Building a Compiler with MLIR](https://llvm.org/devmtg/2019-04/talks.html#Tutorial_1)"
+  * Mehdi Amini, Jacques Pienaar, Nicolas Vasilache, Google at
+    [EuroLLVM 2019](https://llvm.org/devmtg/2019-04)
diff --git a/third_party/mlir/WORKSPACE b/third_party/mlir/WORKSPACE
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/third_party/mlir/WORKSPACE
diff --git a/third_party/mlir/bindings/python/BUILD b/third_party/mlir/bindings/python/BUILD
index a539d0b..00e896c 100644
--- a/third_party/mlir/bindings/python/BUILD
+++ b/third_party/mlir/bindings/python/BUILD
@@ -27,8 +27,9 @@
     features = ["-use_header_modules"],
     module_name = "pybind",
     deps = [
-        "@llvm//:ir",
-        "@llvm//:support",
+        "//third_party/llvm/llvm:ir",
+        "//third_party/llvm/llvm:support",
+        "//third_party/pybind11",
         "@local_config_mlir//:EDSC",
         "@local_config_mlir//:ExecutionEngine",
         "@local_config_mlir//:IR",
@@ -37,6 +38,5 @@
         "@local_config_mlir//:StandardDialectRegistration",
         "@local_config_mlir//:TargetLLVMIR",
         "@local_config_mlir//:Transforms",
-        "@pybind11",
     ],
 )
diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp
new file mode 100644
index 0000000..5efd08d
--- /dev/null
+++ b/third_party/mlir/bindings/python/pybind.cpp
@@ -0,0 +1,932 @@
+//===- pybind.cpp - MLIR Python bindings ----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
+#include <unordered_map>
+
+#include "mlir-c/Core.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+static bool inited = [] {
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  return true;
+}();
+
+namespace mlir {
+namespace edsc {
+namespace python {
+
+namespace py = pybind11;
+
+struct PythonAttribute;
+struct PythonAttributedType;
+struct PythonBindable;
+struct PythonExpr;
+struct PythonFunctionContext;
+struct PythonStmt;
+struct PythonBlock;
+
+struct PythonType {
+  PythonType() : type{nullptr} {}
+  PythonType(mlir_type_t t) : type{t} {}
+
+  operator mlir_type_t() const { return type; }
+
+  PythonAttributedType attachAttributeDict(
+      const std::unordered_map<std::string, PythonAttribute> &attrs) const;
+
+  std::string str() {
+    mlir::Type f = mlir::Type::getFromOpaquePointer(type);
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    f.print(os);
+    return res;
+  }
+
+  mlir_type_t type;
+};
+
+struct PythonValueHandle {
+  PythonValueHandle(PythonType type)
+      : value(mlir::Type::getFromOpaquePointer(type.type)) {}
+  PythonValueHandle(const PythonValueHandle &other) = default;
+  PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
+  operator ValueHandle() const { return value; }
+  operator ValueHandle &() { return value; }
+
+  std::string str() const {
+    return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
+  }
+
+  PythonValueHandle call(const std::vector<PythonValueHandle> &args) {
+    assert(value.hasType() && value.getType().isa<FunctionType>() &&
+           "can only call function-typed values");
+
+    std::vector<Value *> argValues;
+    argValues.reserve(args.size());
+    for (auto arg : args)
+      argValues.push_back(arg.value.getValue());
+    return ValueHandle::create<CallIndirectOp>(value, argValues);
+  }
+
+  mlir::edsc::ValueHandle value;
+};
+
+struct PythonFunction {
+  PythonFunction() : function{nullptr} {}
+  PythonFunction(mlir_func_t f) : function{f} {}
+  PythonFunction(mlir::FuncOp f)
+      : function(const_cast<void *>(f.getAsOpaquePointer())) {}
+  operator mlir_func_t() { return function; }
+  std::string str() {
+    mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function);
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    f.print(os);
+    return res;
+  }
+
+  // If the function does not yet have an entry block, i.e. if it is a function
+  // declaration, add the entry block, transforming the declaration into a
+  // definition.  Return true if the block was added, false otherwise.
+  bool define() {
+    auto f = mlir::FuncOp::getFromOpaquePointer(function);
+    if (!f.getBlocks().empty())
+      return false;
+
+    f.addEntryBlock();
+    return true;
+  }
+
+  PythonValueHandle arg(unsigned index) {
+    auto f = mlir::FuncOp::getFromOpaquePointer(function);
+    assert(index < f.getNumArguments() && "argument index out of bounds");
+    return PythonValueHandle(ValueHandle(f.getArgument(index)));
+  }
+
+  mlir_func_t function;
+};
+
+/// Trivial C++ wrappers make use of the EDSC C API.
+struct PythonMLIRModule {
+  PythonMLIRModule()
+      : mlirContext(),
+        module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
+        moduleManager(*module) {}
+
+  PythonType makeScalarType(const std::string &mlirElemType,
+                            unsigned bitwidth) {
+    return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(),
+                            bitwidth);
+  }
+  PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
+    return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
+                            int64_list_t{sizes.data(), sizes.size()});
+  }
+  PythonType makeIndexType() {
+    return ::makeIndexType(mlir_context_t{&mlirContext});
+  }
+
+  // Declare a function with the given name, input types and their attributes,
+  // output types, and function attributes, but do not define it.
+  PythonFunction declareFunction(const std::string &name,
+                                 const py::list &inputs,
+                                 const std::vector<PythonType> &outputTypes,
+                                 const py::kwargs &funcAttributes);
+
+  // Declare a function with the given name, input types and their attributes,
+  // output types, and function attributes.
+  PythonFunction makeFunction(const std::string &name, const py::list &inputs,
+                              const std::vector<PythonType> &outputTypes,
+                              const py::kwargs &funcAttributes) {
+    auto declaration =
+        declareFunction(name, inputs, outputTypes, funcAttributes);
+    declaration.define();
+    return declaration;
+  }
+
+  // Create a custom op given its name and arguments.
+  PythonExpr op(const std::string &name, PythonType type,
+                const py::list &arguments, const py::list &successors,
+                py::kwargs attributes);
+
+  // Create an integer attribute.
+  PythonAttribute integerAttr(PythonType type, int64_t value);
+
+  // Create a boolean attribute.
+  PythonAttribute boolAttr(bool value);
+
+  void compile() {
+    PassManager manager;
+    manager.addPass(mlir::createCanonicalizerPass());
+    manager.addPass(mlir::createCSEPass());
+    manager.addPass(mlir::createLowerAffinePass());
+    manager.addPass(mlir::createConvertToLLVMIRPass());
+    if (failed(manager.run(*module))) {
+      llvm::errs() << "conversion to the LLVM IR dialect failed\n";
+      return;
+    }
+
+    auto created = mlir::ExecutionEngine::create(*module);
+    llvm::handleAllErrors(created.takeError(),
+                          [](const llvm::ErrorInfoBase &b) {
+                            b.log(llvm::errs());
+                            assert(false);
+                          });
+    engine = std::move(*created);
+  }
+
+  std::string getIR() {
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    module->print(os);
+    return res;
+  }
+
+  uint64_t getEngineAddress() {
+    assert(engine && "module must be compiled into engine first");
+    return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
+  }
+
+  PythonFunction getNamedFunction(const std::string &name) {
+    return moduleManager.lookupSymbol<FuncOp>(name);
+  }
+
+  PythonFunctionContext
+  makeFunctionContext(const std::string &name, const py::list &inputs,
+                      const std::vector<PythonType> &outputs,
+                      const py::kwargs &attributes);
+
+private:
+  mlir::MLIRContext mlirContext;
+  // One single module in a python-exposed MLIRContext for now.
+  mlir::OwningModuleRef module;
+  mlir::ModuleManager moduleManager;
+  std::unique_ptr<mlir::ExecutionEngine> engine;
+};
+
+struct PythonFunctionContext {
+  PythonFunctionContext(PythonFunction f) : function(f) {}
+  PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
+                        const py::list &inputs,
+                        const std::vector<PythonType> &outputs,
+                        const py::kwargs &attributes) {
+    auto function = module.declareFunction(name, inputs, outputs, attributes);
+    function.define();
+  }
+
+  PythonFunction enter() {
+    assert(function.function && "function is not set up");
+    auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function);
+    contextBuilder.emplace(mlirFunc.getBody());
+    context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
+    return function;
+  }
+
+  void exit(py::object, py::object, py::object) {
+    delete context;
+    context = nullptr;
+    contextBuilder.reset();
+  }
+
+  PythonFunction function;
+  mlir::edsc::ScopedContext *context;
+  llvm::Optional<OpBuilder> contextBuilder;
+};
+
+PythonFunctionContext PythonMLIRModule::makeFunctionContext(
+    const std::string &name, const py::list &inputs,
+    const std::vector<PythonType> &outputs, const py::kwargs &attributes) {
+  auto func = declareFunction(name, inputs, outputs, attributes);
+  func.define();
+  return PythonFunctionContext(func);
+}
+
+struct PythonBlockHandle {
+  PythonBlockHandle() : value(nullptr) {}
+  PythonBlockHandle(const PythonBlockHandle &other) = default;
+  PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
+  operator mlir::edsc::BlockHandle() const { return value; }
+
+  PythonValueHandle arg(int index) { return arguments[index]; }
+
+  std::string str() {
+    std::string s;
+    llvm::raw_string_ostream os(s);
+    value.getBlock()->print(os);
+    return os.str();
+  }
+
+  mlir::edsc::BlockHandle value;
+  std::vector<mlir::edsc::ValueHandle> arguments;
+};
+
+struct PythonLoopContext {
+  PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
+      : lb(lb), ub(ub), step(step) {}
+  PythonLoopContext(const PythonLoopContext &) = delete;
+  PythonLoopContext(PythonLoopContext &&) = default;
+  PythonLoopContext &operator=(const PythonLoopContext &) = delete;
+  PythonLoopContext &operator=(PythonLoopContext &&) = default;
+  ~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
+
+  PythonValueHandle enter() {
+    ValueHandle iv(lb.value.getType());
+    builder = new LoopBuilder(&iv, lb.value, ub.value, step);
+    return iv;
+  }
+
+  void exit(py::object, py::object, py::object) {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  PythonValueHandle lb, ub;
+  int64_t step;
+  LoopBuilder *builder = nullptr;
+};
+
+struct PythonLoopNestContext {
+  PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
+                        const std::vector<PythonValueHandle> &ubs,
+                        const std::vector<int64_t> steps)
+      : lbs(lbs), ubs(ubs), steps(steps) {
+    assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
+           "expected the same number of lower, upper bounds, and steps");
+  }
+  PythonLoopNestContext(const PythonLoopNestContext &) = delete;
+  PythonLoopNestContext(PythonLoopNestContext &&) = default;
+  PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
+  PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
+  ~PythonLoopNestContext() {
+    assert(!builder && "did not exit from the context");
+  }
+
+  std::vector<PythonValueHandle> enter() {
+    if (steps.empty())
+      return {};
+
+    auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
+    std::vector<PythonValueHandle> handles(steps.size(),
+                                           PythonValueHandle(type));
+    std::vector<ValueHandle *> handlePtrs;
+    handlePtrs.reserve(steps.size());
+    for (auto &h : handles)
+      handlePtrs.push_back(&h.value);
+    builder = new LoopNestBuilder(
+        handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
+        std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
+    return handles;
+  }
+
+  void exit(py::object, py::object, py::object) {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  std::vector<PythonValueHandle> lbs;
+  std::vector<PythonValueHandle> ubs;
+  std::vector<int64_t> steps;
+  LoopNestBuilder *builder = nullptr;
+};
+
+struct PythonBlockAppender {
+  PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
+  PythonBlockHandle handle;
+};
+
+struct PythonBlockContext {
+public:
+  PythonBlockContext() {
+    createBlockBuilder();
+    clearBuilder();
+  }
+  PythonBlockContext(const std::vector<PythonType> &argTypes) {
+    handle.arguments.reserve(argTypes.size());
+    for (const auto &t : argTypes) {
+      auto type =
+          Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
+      handle.arguments.emplace_back(type);
+    }
+    createBlockBuilder();
+    clearBuilder();
+  }
+  PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
+  PythonBlockContext(const PythonBlockContext &) = delete;
+  PythonBlockContext(PythonBlockContext &&) = default;
+  PythonBlockContext &operator=(const PythonBlockContext &) = delete;
+  PythonBlockContext &operator=(PythonBlockContext &&) = default;
+  ~PythonBlockContext() {
+    assert(!builder && "did not exit from the block context");
+  }
+
+  // EDSC maintain an implicit stack of builders (mostly for keeping track of
+  // insretion points); every operation gets inserted using the top-of-the-stack
+  // builder.  Creating a new EDSC Builder automatically puts it on the stack,
+  // effectively entering the block for it.
+  void createBlockBuilder() {
+    if (handle.value.getBlock()) {
+      builder = new BlockBuilder(handle.value, mlir::edsc::Append());
+    } else {
+      std::vector<ValueHandle *> args;
+      args.reserve(handle.arguments.size());
+      for (auto &a : handle.arguments)
+        args.push_back(&a);
+      builder = new BlockBuilder(&handle.value, args);
+    }
+  }
+
+  PythonBlockHandle enter() {
+    createBlockBuilder();
+    return handle;
+  }
+
+  void exit(py::object, py::object, py::object) { clearBuilder(); }
+
+  PythonBlockHandle getHandle() { return handle; }
+
+  // EDSC maintain an implicit stack of builders (mostly for keeping track of
+  // insretion points); every operation gets inserted using the top-of-the-stack
+  // builder.  Calling operator() on a builder pops the builder from the stack,
+  // effectively resetting the insertion point to its position before we entered
+  // the block.
+  void clearBuilder() {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  PythonBlockHandle handle;
+  BlockBuilder *builder = nullptr;
+};
+
+struct PythonAttribute {
+  PythonAttribute() : attr(nullptr) {}
+  PythonAttribute(const mlir_attr_t &a) : attr(a) {}
+  PythonAttribute(const PythonAttribute &other) = default;
+  operator mlir_attr_t() { return attr; }
+
+  std::string str() const {
+    if (!attr)
+      return "##null attr##";
+
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(attr))
+        .print(os);
+    return res;
+  }
+
+  mlir_attr_t attr;
+};
+
+struct PythonAttributedType {
+  PythonAttributedType() : type(nullptr) {}
+  PythonAttributedType(mlir_type_t t) : type(t) {}
+  PythonAttributedType(
+      PythonType t,
+      const std::unordered_map<std::string, PythonAttribute> &attributes =
+          std::unordered_map<std::string, PythonAttribute>())
+      : type(t), attrs(attributes) {}
+
+  operator mlir_type_t() const { return type.type; }
+  operator PythonType() const { return type; }
+
+  // Return a vector of named attribute descriptors.  The vector owns the
+  // mlir_named_attr_t objects it contains, but not the names and attributes
+  // those objects point to (names and opaque pointers to attributes are owned
+  // by `this`).
+  std::vector<mlir_named_attr_t> getNamedAttrs() const {
+    std::vector<mlir_named_attr_t> result;
+    result.reserve(attrs.size());
+    for (const auto &namedAttr : attrs)
+      result.push_back({namedAttr.first.c_str(), namedAttr.second.attr});
+    return result;
+  }
+
+  std::string str() {
+    mlir::Type t = mlir::Type::getFromOpaquePointer(type);
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    t.print(os);
+    if (attrs.empty())
+      return os.str();
+
+    os << '{';
+    bool first = true;
+    for (const auto &namedAttr : attrs) {
+      if (first)
+        first = false;
+      else
+        os << ", ";
+      os << namedAttr.first << ": " << namedAttr.second.str();
+    }
+    os << '}';
+
+    return os.str();
+  }
+
+private:
+  PythonType type;
+  std::unordered_map<std::string, PythonAttribute> attrs;
+};
+
+struct PythonIndexedValue {
+  explicit PythonIndexedValue(PythonType type)
+      : indexed(Type::getFromOpaquePointer(type.type)) {}
+  explicit PythonIndexedValue(const IndexedValue &other) : indexed(other) {}
+  PythonIndexedValue(PythonValueHandle handle) : indexed(handle.value) {}
+  PythonIndexedValue(const PythonIndexedValue &other) = default;
+
+  // Create a new indexed value with the same base as this one but with indices
+  // provided as arguments.
+  PythonIndexedValue index(const std::vector<PythonValueHandle> &indices) {
+    std::vector<ValueHandle> handles(indices.begin(), indices.end());
+    return PythonIndexedValue(IndexedValue(indexed(handles)));
+  }
+
+  void store(const std::vector<PythonValueHandle> &indices,
+             PythonValueHandle value) {
+    // Uses the overloaded `opreator=` to emit a store.
+    index(indices).indexed = value.value;
+  }
+
+  PythonValueHandle load(const std::vector<PythonValueHandle> &indices) {
+    // Uses the overloaded cast to `ValueHandle` to emit a load.
+    return static_cast<ValueHandle>(index(indices).indexed);
+  }
+
+  IndexedValue indexed;
+};
+
+template <typename ListTy, typename PythonTy, typename Ty>
+ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
+  for (auto &inp : list) {
+    owning.push_back(Ty{inp.cast<PythonTy>()});
+  }
+  return ListTy{owning.data(), owning.size()};
+}
+
+static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
+                                   const py::list &types) {
+  return makeCList<mlir_type_list_t, PythonType>(owning, types);
+}
+
+PythonFunction
+PythonMLIRModule::declareFunction(const std::string &name,
+                                  const py::list &inputs,
+                                  const std::vector<PythonType> &outputTypes,
+                                  const py::kwargs &funcAttributes) {
+
+  std::vector<PythonAttributedType> attributedInputs;
+  attributedInputs.reserve(inputs.size());
+  for (const auto &in : inputs) {
+    std::string className = in.get_type().str();
+    if (className.find(".Type'") != std::string::npos)
+      attributedInputs.emplace_back(in.cast<PythonType>());
+    else
+      attributedInputs.push_back(in.cast<PythonAttributedType>());
+  }
+
+  // Create the function type.
+  std::vector<mlir_type_t> ins(attributedInputs.begin(),
+                               attributedInputs.end());
+  std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end());
+  auto funcType = ::makeFunctionType(
+      mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
+      mlir_type_list_t{outs.data(), outs.size()});
+
+  // Build the list of function attributes.
+  std::vector<mlir::NamedAttribute> attrs;
+  attrs.reserve(funcAttributes.size());
+  for (const auto &named : funcAttributes)
+    attrs.emplace_back(
+        Identifier::get(std::string(named.first.str()), &mlirContext),
+        mlir::Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(
+            named.second.cast<PythonAttribute>().attr)));
+
+  // Build the list of lists of function argument attributes.
+  std::vector<mlir::NamedAttributeList> inputAttrs;
+  inputAttrs.reserve(attributedInputs.size());
+  for (const auto &in : attributedInputs) {
+    std::vector<mlir::NamedAttribute> inAttrs;
+    for (const auto &named : in.getNamedAttrs())
+      inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
+                           mlir::Attribute::getFromOpaquePointer(
+                               reinterpret_cast<const void *>(named.value)));
+    inputAttrs.emplace_back(inAttrs);
+  }
+
+  // Create the function itself.
+  auto func = mlir::FuncOp::create(
+      UnknownLoc::get(&mlirContext), name,
+      mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
+      inputAttrs);
+  moduleManager.insert(func);
+  return func;
+}
+
+PythonAttributedType PythonType::attachAttributeDict(
+    const std::unordered_map<std::string, PythonAttribute> &attrs) const {
+  return PythonAttributedType(*this, attrs);
+}
+
+PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) {
+  return PythonAttribute(::makeIntegerAttr(type, value));
+}
+
+PythonAttribute PythonMLIRModule::boolAttr(bool value) {
+  return PythonAttribute(::makeBoolAttr(&mlirContext, value));
+}
+
+PYBIND11_MODULE(pybind, m) {
+  m.doc() =
+      "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
+  m.def("version", []() { return "EDSC Python extensions v1.0"; });
+
+  py::class_<PythonLoopContext>(
+      m, "LoopContext", "A context for building the body of a 'for' loop")
+      .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
+      .def("__enter__", &PythonLoopContext::enter)
+      .def("__exit__", &PythonLoopContext::exit);
+
+  py::class_<PythonLoopNestContext>(m, "LoopNestContext",
+                                    "A context for building the body of a the "
+                                    "innermost loop in a nest of 'for' loops")
+      .def(py::init<const std::vector<PythonValueHandle> &,
+                    const std::vector<PythonValueHandle> &,
+                    const std::vector<int64_t> &>())
+      .def("__enter__", &PythonLoopNestContext::enter)
+      .def("__exit__", &PythonLoopNestContext::exit);
+
+  m.def("constant_index", [](int64_t val) -> PythonValueHandle {
+    return ValueHandle(index_t(val));
+  });
+  m.def("constant_int", [](int64_t val, int width) -> PythonValueHandle {
+    return ValueHandle::create<ConstantIntOp>(val, width);
+  });
+  m.def("constant_float", [](double val, PythonType type) -> PythonValueHandle {
+    FloatType floatType =
+        Type::getFromOpaquePointer(type.type).cast<FloatType>();
+    assert(floatType);
+    auto value = APFloat(val);
+    bool lostPrecision;
+    value.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                  &lostPrecision);
+    return ValueHandle::create<ConstantFloatOp>(value, floatType);
+  });
+  m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
+    auto function = FuncOp::getFromOpaquePointer(func.function);
+    auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
+    return ValueHandle::create<ConstantOp>(function.getType(), attr);
+  });
+  m.def("appendTo", [](const PythonBlockHandle &handle) {
+    return PythonBlockAppender(handle);
+  });
+  m.def(
+      "ret",
+      [](const std::vector<PythonValueHandle> &args) {
+        std::vector<ValueHandle> values(args.begin(), args.end());
+        (intrinsics::ret(ArrayRef<ValueHandle>{values})); // vexing parse
+        return PythonValueHandle(nullptr);
+      },
+      py::arg("args") = std::vector<PythonValueHandle>());
+  m.def(
+      "br",
+      [](const PythonBlockHandle &dest,
+         const std::vector<PythonValueHandle> &args) {
+        std::vector<ValueHandle> values(args.begin(), args.end());
+        intrinsics::br(dest, values);
+        return PythonValueHandle(nullptr);
+      },
+      py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
+  m.def(
+      "cond_br",
+      [](PythonValueHandle condition, const PythonBlockHandle &trueDest,
+         const std::vector<PythonValueHandle> &trueArgs,
+         const PythonBlockHandle &falseDest,
+         const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle {
+        std::vector<ValueHandle> trueArguments(trueArgs.begin(),
+                                               trueArgs.end());
+        std::vector<ValueHandle> falseArguments(falseArgs.begin(),
+                                                falseArgs.end());
+        intrinsics::cond_br(condition, trueDest, trueArguments, falseDest,
+                            falseArguments);
+        return PythonValueHandle(nullptr);
+      });
+  m.def("select",
+        [](PythonValueHandle condition, PythonValueHandle trueValue,
+           PythonValueHandle falseValue) -> PythonValueHandle {
+          return ValueHandle::create<SelectOp>(condition.value, trueValue.value,
+                                               falseValue.value);
+        });
+  m.def("op",
+        [](const std::string &name,
+           const std::vector<PythonValueHandle> &operands,
+           const std::vector<PythonType> &resultTypes,
+           const py::kwargs &attributes) -> PythonValueHandle {
+          std::vector<ValueHandle> operandHandles(operands.begin(),
+                                                  operands.end());
+          std::vector<Type> types;
+          types.reserve(resultTypes.size());
+          for (auto t : resultTypes)
+            types.push_back(Type::getFromOpaquePointer(t.type));
+
+          std::vector<NamedAttribute> attrs;
+          attrs.reserve(attributes.size());
+          for (const auto &a : attributes) {
+            std::string name = a.first.str();
+            auto pyAttr = a.second.cast<PythonAttribute>();
+            auto cppAttr = Attribute::getFromOpaquePointer(pyAttr.attr);
+            auto identifier =
+                Identifier::get(name, ScopedContext::getContext());
+            attrs.emplace_back(identifier, cppAttr);
+          }
+
+          return ValueHandle::create(name, operandHandles, types, attrs);
+        });
+
+  py::class_<PythonFunction>(m, "Function", "Wrapping class for mlir::FuncOp.")
+      .def(py::init<PythonFunction>())
+      .def("__str__", &PythonFunction::str)
+      .def("define", &PythonFunction::define,
+           "Adds a body to the function if it does not already have one.  "
+           "Returns true if the body was added")
+      .def("arg", &PythonFunction::arg,
+           "Get the ValueHandle to the indexed argument of the function");
+
+  py::class_<PythonAttribute>(m, "Attribute",
+                              "Wrapping class for mlir::Attribute")
+      .def(py::init<PythonAttribute>())
+      .def("__str__", &PythonAttribute::str);
+
+  py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
+      .def(py::init<PythonType>())
+      .def("__call__", &PythonType::attachAttributeDict,
+           "Attach the attributes to these type, making it suitable for "
+           "constructing functions with argument attributes")
+      .def("__str__", &PythonType::str);
+
+  py::class_<PythonAttributedType>(
+      m, "AttributedType",
+      "A class containing a wrapped mlir::Type and a wrapped "
+      "mlir::NamedAttributeList that are used together, e.g. in function "
+      "argument declaration")
+      .def(py::init<PythonAttributedType>())
+      .def("__str__", &PythonAttributedType::str);
+
+  py::class_<PythonMLIRModule>(
+      m, "MLIRModule",
+      "An MLIRModule is the abstraction that owns the allocations to support "
+      "compilation of a single mlir::ModuleOp into an ExecutionEngine backed "
+      "by "
+      "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
+      "adding functions, compiling the module to obtain an ExecutionEngine on "
+      "which named functions may be called. For now the only means to retrieve "
+      "the ExecutionEngine is by calling `get_engine_address`. This mode of "
+      "execution is limited to passing the pointer to C++ where the function "
+      "is called. Extending the API to allow calling JIT compiled functions "
+      "directly require integration with a tensor library (e.g. numpy). This "
+      "is left as the prerogative of libraries and frameworks for now.")
+      .def(py::init<>())
+      .def("boolAttr", &PythonMLIRModule::boolAttr,
+           "Creates an mlir::BoolAttr with the given value")
+      .def(
+          "integerAttr", &PythonMLIRModule::integerAttr,
+          "Creates an mlir::IntegerAttr of the given type with the given value "
+          "in the context associated with this MLIR module.")
+      .def("declare_function", &PythonMLIRModule::declareFunction,
+           "Declares a new mlir::FuncOp in the current mlir::ModuleOp.  The "
+           "function arguments can have attributes.  The function has no "
+           "definition and can be linked to an external library.")
+      .def("make_function", &PythonMLIRModule::makeFunction,
+           "Defines a new mlir::FuncOp in the current mlir::ModuleOp.")
+      .def("function_context", &PythonMLIRModule::makeFunctionContext,
+           "Defines a new mlir::FuncOp in the mlir::ModuleOp and creates the "
+           "function context for building the body of the function.")
+      .def("get_function", &PythonMLIRModule::getNamedFunction,
+           "Looks up the function with the given name in the module.")
+      .def(
+          "make_scalar_type",
+          [](PythonMLIRModule &instance, const std::string &type,
+             unsigned bitwidth) {
+            return instance.makeScalarType(type, bitwidth);
+          },
+          py::arg("type"), py::arg("bitwidth") = 0,
+          "Returns a scalar mlir::Type using the following convention:\n"
+          "  - makeScalarType(c, \"bf16\") return an "
+          "`mlir::FloatType::getBF16`\n"
+          "  - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
+          "  - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
+          "  - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
+          "  - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
+          "  - makeScalarType(c, \"i\", bitwidth) return an "
+          "`mlir::IntegerType::get(bitwidth)`\n\n"
+          " No other combinations are currently supported.")
+      .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
+           "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
+           "denote symbolic dimensions in the resulting memref shape.")
+      .def("make_index_type", &PythonMLIRModule::makeIndexType,
+           "Returns an mlir::IndexType")
+      .def("compile", &PythonMLIRModule::compile,
+           "Compiles the mlir::ModuleOp to LLVMIR a creates new opaque "
+           "ExecutionEngine backed by the ORC JIT.")
+      .def("get_ir", &PythonMLIRModule::getIR,
+           "Returns a dump of the MLIR representation of the module. This is "
+           "used for serde to support out-of-process execution as well as "
+           "debugging purposes.")
+      .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
+           "Returns the address of the compiled ExecutionEngine. This is used "
+           "for in-process execution.")
+      .def("__str__", &PythonMLIRModule::getIR,
+           "Get the string representation of the module");
+
+  py::class_<PythonFunctionContext>(
+      m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
+      .def(py::init<PythonFunction>())
+      .def("__enter__", &PythonFunctionContext::enter)
+      .def("__exit__", &PythonFunctionContext::exit);
+
+  {
+    using namespace mlir::edsc::op;
+    py::class_<PythonValueHandle>(m, "ValueHandle",
+                                  "A wrapper around mlir::edsc::ValueHandle")
+        .def(py::init<PythonType>())
+        .def(py::init<PythonValueHandle>())
+        .def("__add__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value + rhs.value; })
+        .def("__sub__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value - rhs.value; })
+        .def("__mul__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value * rhs.value; })
+        .def("__div__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value / rhs.value; })
+        .def("__truediv__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value / rhs.value; })
+        .def("__floordiv__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return floorDiv(lhs, rhs); })
+        .def("__mod__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value % rhs.value; })
+        .def("__lt__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__le__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__gt__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__ge__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__eq__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__ne__",
+             [](PythonValueHandle lhs,
+                PythonValueHandle rhs) -> PythonValueHandle {
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
+                                                  rhs.value);
+             })
+        .def("__invert__",
+             [](PythonValueHandle handle) -> PythonValueHandle {
+               return !handle.value;
+             })
+        .def("__and__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value && rhs.value; })
+        .def("__or__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value || rhs.value; })
+        .def("__call__", &PythonValueHandle::call);
+  }
+
+  py::class_<PythonBlockAppender>(
+      m, "BlockAppender",
+      "A dummy class signaling BlockContext to append IR to the given block "
+      "instead of creating a new block")
+      .def(py::init<const PythonBlockHandle &>());
+  py::class_<PythonBlockHandle>(m, "BlockHandle",
+                                "A wrapper around mlir::edsc::BlockHandle")
+      .def(py::init<PythonBlockHandle>())
+      .def("arg", &PythonBlockHandle::arg);
+
+  py::class_<PythonBlockContext>(m, "BlockContext",
+                                 "A wrapper around mlir::edsc::BlockBuilder")
+      .def(py::init<>())
+      .def(py::init<const std::vector<PythonType> &>())
+      .def(py::init<const PythonBlockAppender &>())
+      .def("__enter__", &PythonBlockContext::enter)
+      .def("__exit__", &PythonBlockContext::exit)
+      .def("handle", &PythonBlockContext::getHandle);
+
+  py::class_<PythonIndexedValue>(m, "IndexedValue",
+                                 "A wrapper around mlir::edsc::IndexedValue")
+      .def(py::init<PythonValueHandle>())
+      .def("load", &PythonIndexedValue::load)
+      .def("store", &PythonIndexedValue::store);
+}
+
+} // namespace python
+} // namespace edsc
+} // namespace mlir
diff --git a/third_party/mlir/bindings/python/test/BUILD b/third_party/mlir/bindings/python/test/BUILD
new file mode 100644
index 0000000..36fe5cb
--- /dev/null
+++ b/third_party/mlir/bindings/python/test/BUILD
@@ -0,0 +1,36 @@
+# Description:
+#   BUILD file for the Python wrappers for EDSCs
+
+licenses(["notice"])  # Apache 2.0
+
+# Export the BUILD file so automated tooling can check licenses
+exports_files(["BUILD"])
+
+load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests")
+
+glob_lit_tests(
+    data = [":test_utilities"],
+    driver = "@local_config_mlir//:run_lit.sh",
+    test_file_exts = ["py"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        ":test_edsc",
+        "//third_party/llvm/llvm:FileCheck",
+    ],
+)
+
+py_binary(
+    name = "test_edsc",
+    srcs = ["test_py2and3.py"],
+    main = "test_py2and3.py",
+    python_version = "PY2",
+    deps = [
+        "//testing/pybase",
+        "@local_config_mlir//bindings/python:_pybind",
+    ],
+)
diff --git a/third_party/mlir/bindings/python/test/test_py2and3.py b/third_party/mlir/bindings/python/test/test_py2and3.py
new file mode 100644
index 0000000..c658c94
--- /dev/null
+++ b/third_party/mlir/bindings/python/test/test_py2and3.py
@@ -0,0 +1,486 @@
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# RUN: %p/test_edsc %s | FileCheck %s
+"""Python2 and 3 test for the MLIR EDSC Python bindings"""
+
+import google_mlir.bindings.python.pybind as E
+import inspect
+
+# Prints `str` prefixed by the current test function name so we can use it in
+# Filecheck label directives.
+# This is achieved by inspecting the stack and getting the parent name.
+def printWithCurrentFunctionName(str):
+  print(inspect.stack()[1][3])
+  print(str)
+
+class EdscTest:
+
+  def setUp(self):
+    self.module = E.MLIRModule()
+    self.boolType = self.module.make_scalar_type("i", 1)
+    self.i32Type = self.module.make_scalar_type("i", 32)
+    self.f32Type = self.module.make_scalar_type("f32")
+    self.indexType = self.module.make_index_type()
+
+  def testBlockArguments(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      E.constant_index(42)
+      with E.BlockContext([self.f32Type, self.f32Type]) as b:
+        b.arg(0) + b.arg(1)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBlockArguments
+    #       CHECK: %{{.*}} = constant 42 : index
+    #       CHECK: ^bb{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
+    #       CHECK:   %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+
+  def testBlockContext(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      cst = E.constant_index(42)
+      with E.BlockContext():
+        cst + cst
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBlockContext
+    #       CHECK: %{{.*}} = constant 42 : index
+    #       CHECK: ^bb
+    #       CHECK: %{{.*}} = "affine.apply"() {map = () -> (84)} : () -> index
+
+  def testBlockContextAppend(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      E.constant_index(41)
+      with E.BlockContext() as b:
+        blk = b  # save block handle for later
+        E.constant_index(0)
+      E.constant_index(42)
+      with E.BlockContext(E.appendTo(blk)):
+        E.constant_index(1)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBlockContextAppend
+    #       CHECK: %{{.*}} = constant 41 : index
+    #       CHECK: %{{.*}} = constant 42 : index
+    #       CHECK: ^bb
+    #       CHECK: %{{.*}} = constant 0 : index
+    #       CHECK: %{{.*}} = constant 1 : index
+
+  def testBlockContextStandalone(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      blk1 = E.BlockContext()
+      blk2 = E.BlockContext()
+      with blk1:
+        E.constant_index(0)
+      with blk2:
+        E.constant_index(56)
+        E.constant_index(57)
+      E.constant_index(41)
+      with blk1:
+        E.constant_index(1)
+      E.constant_index(42)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBlockContextStandalone
+    #       CHECK: %{{.*}} = constant 41 : index
+    #       CHECK: %{{.*}} = constant 42 : index
+    #       CHECK: ^bb
+    #       CHECK: %{{.*}} = constant 0 : index
+    #       CHECK: %{{.*}} = constant 1 : index
+    #       CHECK: ^bb
+    #       CHECK: %{{.*}} = constant 56 : index
+    #       CHECK: %{{.*}} = constant 57 : index
+
+  def testBooleanOps(self):
+    self.setUp()
+    with self.module.function_context(
+        "booleans", [self.boolType for _ in range(4)], []) as fun:
+      i, j, k, l = (fun.arg(x) for x in range(4))
+      stmt1 = (i < j) & (j >= k)
+      stmt2 = ~(stmt1 | (k == l))
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBooleanOps
+    #       CHECK: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = constant 1 : i1
+    #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = constant 1 : i1
+    #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = constant 1 : i1
+    #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
+    #       CHECK: %{{.*}} = constant 1 : i1
+    #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
+
+  def testBr(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      with E.BlockContext() as b:
+        blk = b
+        E.ret()
+      E.br(blk)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBr
+    #       CHECK:   br ^bb
+    #       CHECK: ^bb
+    #       CHECK:   return
+
+  def testBrArgs(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      # Create an infinite loop.
+      with E.BlockContext([self.indexType, self.indexType]) as b:
+        E.br(b, [b.arg(1), b.arg(0)])
+      E.br(b, [E.constant_index(0), E.constant_index(1)])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBrArgs
+    #       CHECK:   %{{.*}} = constant 0 : index
+    #       CHECK:   %{{.*}} = constant 1 : index
+    #       CHECK:   br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index):
+    #       CHECK:   br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
+
+  def testBrDeclaration(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      blk = E.BlockContext()
+      E.br(blk.handle())
+      with blk:
+        E.ret()
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testBrDeclaration
+    #       CHECK:   br ^bb
+    #       CHECK: ^bb
+    #       CHECK:   return
+
+  def testCallOp(self):
+    self.setUp()
+    callee = self.module.declare_function("sqrtf", [self.f32Type],
+                                          [self.f32Type])
+    with self.module.function_context("call", [self.f32Type], []) as fun:
+      funCst = E.constant_function(callee)
+      funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
+      printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testCallOp
+    #       CHECK: func @sqrtf(f32) -> f32
+    #       CHECK:   %{{.*}} = constant @sqrtf : (f32) -> f32
+    #       CHECK:   %{{.*}} = call_indirect %{{.*}}(%{{.*}}) : (f32) -> f32
+
+  def testCondBr(self):
+    self.setUp()
+    with self.module.function_context("foo", [self.boolType], []) as fun:
+      with E.BlockContext() as blk1:
+        E.ret([])
+      with E.BlockContext([self.indexType]) as blk2:
+        E.ret([])
+      cst = E.constant_index(0)
+      E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testCondBr
+    #       CHECK:   cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}(%{{.*}} : index)
+
+  def testConstants(self):
+    self.setUp()
+    with self.module.function_context("constants", [], []) as fun:
+      E.constant_float(1.23, self.module.make_scalar_type("bf16"))
+      E.constant_float(1.23, self.module.make_scalar_type("f16"))
+      E.constant_float(1.23, self.module.make_scalar_type("f32"))
+      E.constant_float(1.23, self.module.make_scalar_type("f64"))
+      E.constant_int(1, 1)
+      E.constant_int(123, 8)
+      E.constant_int(123, 16)
+      E.constant_int(123, 32)
+      E.constant_int(123, 64)
+      E.constant_index(123)
+      E.constant_function(fun)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testConstants
+    #       CHECK:  constant 1.230000e+00 : bf16
+    #       CHECK:  constant 1.230470e+00 : f16
+    #       CHECK:  constant 1.230000e+00 : f32
+    #       CHECK:  constant 1.230000e+00 : f64
+    #       CHECK:  constant 1 : i1
+    #       CHECK:  constant 123 : i8
+    #       CHECK:  constant 123 : i16
+    #       CHECK:  constant 123 : i32
+    #       CHECK:  constant 123 : index
+    #       CHECK:  constant @constants : () -> ()
+
+  def testCustom(self):
+    self.setUp()
+    with self.module.function_context("custom", [self.indexType, self.f32Type],
+                                      []) as fun:
+      E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testCustom
+    #       CHECK: %{{.*}} = "foo"(%{{.*}}) : (index) -> f32
+    #       CHECK:  %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+
+  # Create 'addi' using the generic Op interface.  We need an operation known
+  # to the execution engine so that the engine can compile it.
+  def testCustomOpCompilation(self):
+    self.setUp()
+    with self.module.function_context("adder", [self.i32Type], []) as f:
+      c1 = E.op(
+          "std.constant", [], [self.i32Type],
+          value=self.module.integerAttr(self.i32Type, 42))
+      E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
+      E.ret([])
+    self.module.compile()
+    printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
+    # CHECK-LABEL: testCustomOpCompilation
+    #       CHECK: False
+
+  def testDivisions(self):
+    self.setUp()
+    with self.module.function_context(
+        "division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
+      # indices only support floor division
+      fun.arg(0) // E.constant_index(42)
+      # regular values only support regular division
+      fun.arg(1) / fun.arg(2)
+      printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testDivisions
+    #       CHECK:  floordiv 42
+    #       CHECK:  divis %{{.*}}, %{{.*}} : i32
+
+  def testFunctionArgs(self):
+    self.setUp()
+    with self.module.function_context("foo", [self.f32Type, self.f32Type],
+                                      [self.indexType]) as fun:
+      pass
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testFunctionArgs
+    #       CHECK: func @foo(%{{.*}}: f32, %{{.*}}: f32) -> index
+
+  def testFunctionContext(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []):
+      pass
+      printWithCurrentFunctionName(self.module.get_function("foo"))
+    # CHECK-LABEL: testFunctionContext
+    #       CHECK: func @foo() {
+
+  def testFunctionDeclaration(self):
+    self.setUp()
+    boolAttr = self.module.boolAttr(True)
+    t = self.module.make_memref_type(self.f32Type, [10])
+    t_llvm_noalias = t({"llvm.noalias": boolAttr})
+    t_readonly = t({"readonly": boolAttr})
+    f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
+    printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testFunctionDeclaration
+    #       CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
+
+  def testFunctionMultiple(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []):
+      pass
+    with self.module.function_context("foo", [], []):
+      E.constant_index(0)
+    printWithCurrentFunctionName(str(self.module))
+    # CHECK-LABEL: testFunctionMultiple
+    #       CHECK: func @foo()
+    #       CHECK: func @foo_0()
+    #       CHECK: %{{.*}} = constant 0 : index
+
+  def testIndexedValue(self):
+    self.setUp()
+    memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
+    with self.module.function_context("indexed", [memrefType],
+                                      [memrefType]) as fun:
+      A = E.IndexedValue(fun.arg(0))
+      cst = E.constant_float(1., self.f32Type)
+      with E.LoopNestContext(
+          [E.constant_index(0), E.constant_index(0)],
+          [E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
+        A.store([i, j], A.load([i, j]) + cst)
+      E.ret([fun.arg(0)])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testIndexedValue
+    #       CHECK: "affine.for"()
+    #       CHECK: "affine.for"()
+    #       CHECK: "affine.load"
+    #  CHECK-SAME: memref<10x42xf32>
+    #       CHECK:  %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+    #       CHECK:  "affine.store"
+    #  CHECK-SAME:  memref<10x42xf32>
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (10)}
+
+  def testLoopContext(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      lhs = E.constant_index(0)
+      rhs = E.constant_index(42)
+      with E.LoopContext(lhs, rhs, 1) as i:
+        lhs + rhs + i
+        with E.LoopContext(rhs, rhs + rhs, 2) as j:
+          x = i + j
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testLoopContext
+    #       CHECK: "affine.for"() (
+    #       CHECK:   ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: "affine.for"(%{{.*}}, %{{.*}}) (
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: "affine.apply"(%{{.*}}, %{{.*}}) {map = (d0, d1) -> (d0 + d1)} : (index, index) -> index
+    #       CHECK: {lower_bound = (d0) -> (d0), step = 2 : index, upper_bound = (d0) -> (d0)} : (index, index) -> ()
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
+
+  def testLoopNestContext(self):
+    self.setUp()
+    with self.module.function_context("foo", [], []) as fun:
+      lbs = [E.constant_index(i) for i in range(4)]
+      ubs = [E.constant_index(10 * i + 5) for i in range(4)]
+      with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
+        i + j + k + l
+    printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testLoopNestContext
+    #       CHECK: "affine.for"() (
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: "affine.for"() (
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: "affine.for"() (
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: "affine.for"() (
+    #       CHECK: ^bb{{.*}}(%{{.*}}: index):
+    #       CHECK: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {map = (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index
+
+  def testMLIRBooleanCompilation(self):
+    self.setUp()
+    m = self.module.make_memref_type(self.boolType, [10])  # i1 tensor
+    with self.module.function_context("mkbooltensor", [m, m], []) as f:
+      input = E.IndexedValue(f.arg(0))
+      output = E.IndexedValue(f.arg(1))
+      zero = E.constant_index(0)
+      ten = E.constant_index(10)
+      with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
+        b1 = (i < j) & (j < k)
+        b2 = ~b1
+        b3 = b2 | (k < j)
+        output.store([i], input.load([i]) & b3)
+      E.ret([])
+    self.module.compile()
+    printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
+    # CHECK-LABEL: testMLIRBooleanCompilation
+    #       CHECK: False
+
+  def testMLIRFunctionCreation(self):
+    self.setUp()
+    module = E.MLIRModule()
+    t = module.make_scalar_type("f32")
+    m = module.make_memref_type(t, [3, 4, -1, 5])
+    printWithCurrentFunctionName(str(t))
+    print(str(m))
+    print(str(module.make_function("copy", [m, m], [])))
+    print(str(module.make_function("sqrtf", [t], [t])))
+    # CHECK-LABEL: testMLIRFunctionCreation
+    #       CHECK:  f32
+    #       CHECK:  memref<3x4x?x5xf32>
+    #       CHECK: func @copy(%{{.*}}: memref<3x4x?x5xf32>, %{{.*}}: memref<3x4x?x5xf32>) {
+    #       CHECK:  func @sqrtf(%{{.*}}: f32) -> f32
+
+  def testMLIRScalarTypes(self):
+    self.setUp()
+    module = E.MLIRModule()
+    printWithCurrentFunctionName(str(module.make_scalar_type("bf16")))
+    print(str(module.make_scalar_type("f16")))
+    print(str(module.make_scalar_type("f32")))
+    print(str(module.make_scalar_type("f64")))
+    print(str(module.make_scalar_type("i", 1)))
+    print(str(module.make_scalar_type("i", 8)))
+    print(str(module.make_scalar_type("i", 32)))
+    print(str(module.make_scalar_type("i", 123)))
+    print(str(module.make_scalar_type("index")))
+    # CHECK-LABEL: testMLIRScalarTypes
+    #       CHECK:  bf16
+    #       CHECK:  f16
+    #       CHECK:  f32
+    #       CHECK:  f64
+    #       CHECK:  i1
+    #       CHECK:  i8
+    #       CHECK:  i32
+    #       CHECK:  i123
+    #       CHECK:  index
+
+  def testMatrixMultiply(self):
+    self.setUp()
+    memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
+    with self.module.function_context(
+        "matmul", [memrefType, memrefType, memrefType], []) as fun:
+      A = E.IndexedValue(fun.arg(0))
+      B = E.IndexedValue(fun.arg(1))
+      C = E.IndexedValue(fun.arg(2))
+      c0 = E.constant_index(0)
+      c32 = E.constant_index(32)
+      with E.LoopNestContext([c0, c0, c0], [c32, c32, c32], [1, 1, 1]) as (i, j,
+                                                                           k):
+        C.store([i, j], A.load([i, k]) * B.load([k, j]))
+      E.ret([])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testMatrixMultiply
+    #       CHECK: "affine.for"()
+    #       CHECK: "affine.for"()
+    #       CHECK: "affine.for"()
+    #   CHECK-DAG:  %{{.*}} = "affine.load"
+    #   CHECK-DAG:  %{{.*}} = "affine.load"
+    #       CHECK:  %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+    #       CHECK:  "affine.store"
+    #  CHECK-SAME:  memref<32x32xf32>
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
+    #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
+
+  def testRet(self):
+    self.setUp()
+    with self.module.function_context("foo", [],
+                                      [self.indexType, self.indexType]) as fun:
+      c42 = E.constant_index(42)
+      c0 = E.constant_index(0)
+      E.ret([c42, c0])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testRet
+    #       CHECK:    %{{.*}} = constant 42 : index
+    #       CHECK:    %{{.*}} = constant 0 : index
+    #       CHECK:    return %{{.*}}, %{{.*}} : index, index
+
+  def testSelectOp(self):
+    self.setUp()
+    with self.module.function_context("foo", [self.boolType],
+                                      [self.i32Type]) as fun:
+      a = E.constant_int(42, 32)
+      b = E.constant_int(0, 32)
+      E.ret([E.select(fun.arg(0), a, b)])
+      printWithCurrentFunctionName(str(fun))
+    # CHECK-LABEL: testSelectOp
+    #       CHECK:  %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : i32
+
+
+# Until python 3.6 this cannot be used because the order in the dict is not the
+# order of method declaration.
+def runTests():
+  def isTest(attr):
+    return inspect.ismethod(attr) and "EdscTest.setUp " not in str(attr)
+
+  edscTest = EdscTest()
+  tests = sorted(filter(isTest,
+                        (getattr(edscTest, attr) for attr in dir(edscTest))),
+                 key = lambda x : str(x))
+  for test in tests:
+    test()
+
+if __name__ == '__main__':
+  runTests()
diff --git a/third_party/mlir/include/mlir-c/Core.h b/third_party/mlir/include/mlir-c/Core.h
new file mode 100644
index 0000000..918ccdf
--- /dev/null
+++ b/third_party/mlir/include/mlir-c/Core.h
@@ -0,0 +1,119 @@
+/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\
+|*                                                                            *|
+|* Copyright 2019 The MLIR Authors.                                           *|
+|*                                                                            *|
+|* Licensed under the Apache License, Version 2.0 (the "License");            *|
+|* you may not use this file except in compliance with the License.           *|
+|* You may obtain a copy of the License at                                    *|
+|*                                                                            *|
+|*   http://www.apache.org/licenses/LICENSE-2.0                               *|
+|*                                                                            *|
+|* Unless required by applicable law or agreed to in writing, software        *|
+|* distributed under the License is distributed on an "AS IS" BASIS,          *|
+|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *|
+|* See the License for the specific language governing permissions and        *|
+|* limitations under the License.                                             *|
+|*                                                                            *|
+|*===----------------------------------------------------------------------===*|
+|*                                                                            *|
+|* This header declares the C interface to MLIR.                              *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+#ifndef MLIR_C_CORE_H
+#define MLIR_C_CORE_H
+
+#ifdef __cplusplus
+#include <cstdint>
+extern "C" {
+#else
+#include <stdint.h>
+#endif
+
+/// Opaque MLIR types.
+/// Opaque C type for mlir::MLIRContext*.
+typedef void *mlir_context_t;
+/// Opaque C type for mlir::Type.
+typedef const void *mlir_type_t;
+/// Opaque C type for mlir::FuncOp.
+typedef void *mlir_func_t;
+/// Opaque C type for mlir::Attribute.
+typedef const void *mlir_attr_t;
+
+/// Simple C lists for non-owning mlir Opaque C types.
+/// Recommended usage is construction from the `data()` and `size()` of a scoped
+/// owning SmallVectorImpl<...> and passing to one of the C functions declared
+/// later in this file.
+/// Once the function returns and the proper EDSC has been constructed,
+/// resources are freed by exiting the scope.
+typedef struct {
+  int64_t *values;
+  uint64_t n;
+} int64_list_t;
+
+typedef struct {
+  mlir_type_t *types;
+  uint64_t n;
+} mlir_type_list_t;
+
+typedef struct {
+  const char *name;
+  mlir_attr_t value;
+} mlir_named_attr_t;
+
+typedef struct {
+  mlir_named_attr_t *list;
+  uint64_t n;
+} mlir_named_attr_list_t;
+
+/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
+
+/// Returns a simple scalar mlir::Type using the following convention:
+///   - makeScalarType(c, "bf16") return an `mlir::FloatType::getBF16`
+///   - makeScalarType(c, "f16") return an `mlir::FloatType::getF16`
+///   - makeScalarType(c, "f32") return an `mlir::FloatType::getF32`
+///   - makeScalarType(c, "f64") return an `mlir::FloatType::getF64`
+///   - makeScalarType(c, "index") return an `mlir::IndexType::get`
+///   - makeScalarType(c, "i", bitwidth) return an
+///     `mlir::IntegerType::get(bitwidth)`
+///
+/// No other combinations are currently supported.
+mlir_type_t makeScalarType(mlir_context_t context, const char *name,
+                           unsigned bitwidth);
+
+/// Returns an `mlir::MemRefType` of the element type `elemType` and shape
+/// `sizes`.
+mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
+                           int64_list_t sizes);
+
+/// Returns an `mlir::FunctionType` of the element type `elemType` and shape
+/// `sizes`.
+mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
+                             mlir_type_list_t outputs);
+
+/// Returns an `mlir::IndexType`.
+mlir_type_t makeIndexType(mlir_context_t context);
+
+/// Returns an `mlir::IntegerAttr` of the specified type that contains the given
+/// value.
+mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value);
+
+/// Returns an `mlir::BoolAttr` with the given value.
+mlir_attr_t makeBoolAttr(mlir_context_t context, bool value);
+
+/// Returns the arity of `function`.
+unsigned getFunctionArity(mlir_func_t function);
+
+/// Returns the rank of the `function` argument at position `pos`.
+/// If the argument is of MemRefType, this returns the rank of the MemRef.
+/// Otherwise returns `0`.
+/// TODO(ntv): support more than MemRefType and scalar Type.
+unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos);
+
+/// Returns an opaque mlir::Type of the `function` argument at position `pos`.
+mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos);
+
+#ifdef __cplusplus
+} // end extern "C"
+#endif
+
+#endif // MLIR_C_CORE_H
diff --git a/third_party/mlir/include/mlir/AffineOps/AffineOps.h b/third_party/mlir/include/mlir/AffineOps/AffineOps.h
new file mode 100644
index 0000000..59f7fc7
--- /dev/null
+++ b/third_party/mlir/include/mlir/AffineOps/AffineOps.h
@@ -0,0 +1,598 @@
+//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines convenience types for working with Affine operations
+// in the MLIR operation set.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AFFINEOPS_AFFINEOPS_H
+#define MLIR_AFFINEOPS_AFFINEOPS_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+class AffineBound;
+class AffineValueMap;
+class AffineTerminatorOp;
+class FlatAffineConstraints;
+class OpBuilder;
+
+/// A utility function to check if a value is defined at the top level of a
+/// function. A value defined at the top level is always a valid symbol.
+bool isTopLevelSymbol(Value *value);
+
+class AffineOpsDialect : public Dialect {
+public:
+  AffineOpsDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "affine"; }
+};
+
+/// The "affine.apply" operation applies an affine map to a list of operands,
+/// yielding a single result. The operand list must be the same size as the
+/// number of arguments to the affine mapping.  All operands and the result are
+/// of type 'Index'. This operation requires a single affine map attribute named
+/// "map".  For example:
+///
+///   %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } :
+///          (index) -> (index)
+///
+/// equivalently:
+///
+///   #map42 = (d0)->(d0+1)
+///   %y = affine.apply #map42(%x)
+///
+class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
+                                OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+  using Op::Op;
+
+  /// Builds an affine apply op with the specified map and operands.
+  static void build(Builder *builder, OperationState *result, AffineMap map,
+                    ArrayRef<Value *> operands);
+
+  /// Returns the affine map to be applied by this operation.
+  AffineMap getAffineMap() {
+    return getAttrOfType<AffineMapAttr>("map").getValue();
+  }
+
+  /// Returns true if the result of this operation can be used as dimension id.
+  bool isValidDim();
+
+  /// Returns true if the result of this operation is a symbol.
+  bool isValidSymbol();
+
+  static StringRef getOperationName() { return "affine.apply"; }
+
+  // Hooks to customize behavior of this op.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+  OpFoldResult fold(ArrayRef<Attribute> operands);
+
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+};
+
+/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
+/// from a source memref to a destination memref. The source and destination
+/// memref need not be of the same dimensionality, but need to have the same
+/// elemental type. The operands include the source and destination memref's
+/// each followed by its indices, size of the data transfer in terms of the
+/// number of elements (of the elemental type of the memref), a tag memref with
+/// its indices, and optionally at the end, a stride and a
+/// number_of_elements_per_stride arguments. The tag location is used by an
+/// AffineDmaWaitOp to check for completion. The indices of the source memref,
+/// destination memref, and the tag memref have the same restrictions as any
+/// affine.load/store. In particular, index for each memref dimension must be an
+/// affine expression of loop induction variables and symbols.
+/// The optional stride arguments should be of 'index' type, and specify a
+/// stride for the slower memory space (memory space with a lower memory space
+/// id), tranferring chunks of number_of_elements_per_stride every stride until
+/// %num_elements are transferred. Either both or no stride arguments should be
+/// specified. The value of 'num_elements' must be a multiple of
+/// 'number_of_elements_per_stride'.
+//
+// For example, a DmaStartOp operation that transfers 256 elements of a memref
+// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
+// space 1 at indices [%k + 7, %l], would be specified as follows:
+//
+//   %num_elements = constant 256
+//   %idx = constant 0 : index
+//   %tag = alloc() : memref<1xi32, 4>
+//   affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
+//     %num_elements :
+//       memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
+//
+//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
+//   transfer %num_elt_per_stride elements every %stride elements apart from
+//   memory space 0 until %num_elements are transferred.
+//
+//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
+//     %stride, %num_elt_per_stride : ...
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels (possibly using AffineMaps to specify
+// multiple levels of striding).
+// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
+class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
+                                   OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
+                    AffineMap srcMap, ArrayRef<Value *> srcIndices,
+                    Value *destMemRef, AffineMap dstMap,
+                    ArrayRef<Value *> destIndices, Value *tagMemRef,
+                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
+                    Value *numElements, Value *stride = nullptr,
+                    Value *elementsPerStride = nullptr);
+
+  /// Returns the operand index of the src memref.
+  unsigned getSrcMemRefOperandIndex() { return 0; }
+
+  /// Returns the source MemRefType for this DMA operation.
+  Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
+  MemRefType getSrcMemRefType() {
+    return getSrcMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Returns the rank (number of indices) of the source MemRefType.
+  unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
+
+  /// Returns the affine map used to access the src memref.
+  AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
+  AffineMapAttr getSrcMapAttr() {
+    return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  /// Returns the source memref affine map indices for this DMA operation.
+  operand_range getSrcIndices() {
+    return {operand_begin() + getSrcMemRefOperandIndex() + 1,
+            operand_begin() + getSrcMemRefOperandIndex() + 1 +
+                getSrcMap().getNumInputs()};
+  }
+
+  /// Returns the memory space of the src memref.
+  unsigned getSrcMemorySpace() {
+    return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
+  }
+
+  /// Returns the operand index of the dst memref.
+  unsigned getDstMemRefOperandIndex() {
+    return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
+  }
+
+  /// Returns the destination MemRefType for this DMA operations.
+  Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
+  MemRefType getDstMemRefType() {
+    return getDstMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Returns the rank (number of indices) of the destination MemRefType.
+  unsigned getDstMemRefRank() {
+    return getDstMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
+  /// Returns the memory space of the src memref.
+  unsigned getDstMemorySpace() {
+    return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
+  }
+
+  /// Returns the affine map used to access the dst memref.
+  AffineMap getDstMap() { return getDstMapAttr().getValue(); }
+  AffineMapAttr getDstMapAttr() {
+    return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  /// Returns the destination memref indices for this DMA operation.
+  operand_range getDstIndices() {
+    return {operand_begin() + getDstMemRefOperandIndex() + 1,
+            operand_begin() + getDstMemRefOperandIndex() + 1 +
+                getDstMap().getNumInputs()};
+  }
+
+  /// Returns the operand index of the tag memref.
+  unsigned getTagMemRefOperandIndex() {
+    return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
+  }
+
+  /// Returns the Tag MemRef for this DMA operation.
+  Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
+  MemRefType getTagMemRefType() {
+    return getTagMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Returns the rank (number of indices) of the tag MemRefType.
+  unsigned getTagMemRefRank() {
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
+  /// Returns the affine map used to access the tag memref.
+  AffineMap getTagMap() { return getTagMapAttr().getValue(); }
+  AffineMapAttr getTagMapAttr() {
+    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  /// Returns the tag memref indices for this DMA operation.
+  operand_range getTagIndices() {
+    return {operand_begin() + getTagMemRefOperandIndex() + 1,
+            operand_begin() + getTagMemRefOperandIndex() + 1 +
+                getTagMap().getNumInputs()};
+  }
+
+  /// Returns the number of elements being transferred by this DMA operation.
+  Value *getNumElements() {
+    return getOperand(getTagMemRefOperandIndex() + 1 +
+                      getTagMap().getNumInputs());
+  }
+
+  /// Returns the AffineMapAttr associated with 'memref'.
+  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+    if (memref == getSrcMemRef())
+      return {Identifier::get(getSrcMapAttrName(), getContext()),
+              getSrcMapAttr()};
+    else if (memref == getDstMemRef())
+      return {Identifier::get(getDstMapAttrName(), getContext()),
+              getDstMapAttr()};
+    assert(memref == getTagMemRef() &&
+           "DmaStartOp expected source, destination or tag memref");
+    return {Identifier::get(getTagMapAttrName(), getContext()),
+            getTagMapAttr()};
+  }
+
+  /// Returns true if this is a DMA from a faster memory space to a slower one.
+  bool isDestMemorySpaceFaster() {
+    return (getSrcMemorySpace() < getDstMemorySpace());
+  }
+
+  /// Returns true if this is a DMA from a slower memory space to a faster one.
+  bool isSrcMemorySpaceFaster() {
+    // Assumes that a lower number is for a slower memory space.
+    return (getDstMemorySpace() < getSrcMemorySpace());
+  }
+
+  /// Given a DMA start operation, returns the operand position of either the
+  /// source or destination memref depending on the one that is at the higher
+  /// level of the memory hierarchy. Asserts failure if neither is true.
+  unsigned getFasterMemPos() {
+    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
+    return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
+  }
+
+  static StringRef getSrcMapAttrName() { return "src_map"; }
+  static StringRef getDstMapAttrName() { return "dst_map"; }
+  static StringRef getTagMapAttrName() { return "tag_map"; }
+
+  static StringRef getOperationName() { return "affine.dma_start"; }
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+
+  /// Returns true if this DMA operation is strided, returns false otherwise.
+  bool isStrided() {
+    return getNumOperands() !=
+           getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
+  }
+
+  /// Returns the stride value for this DMA operation.
+  Value *getStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1 - 1);
+  }
+
+  /// Returns the number of elements to transfer per stride for this DMA op.
+  Value *getNumElementsPerStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1);
+  }
+};
+
+/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
+/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
+/// an index with the same restrictions as any load/store index. In particular,
+/// index for each memref dimension must be an affine expression of loop
+/// induction variables and symbols. %num_elements is the number of elements
+/// associated with the DMA operation. For example:
+//
+//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
+//     memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
+//   ...
+//   ...
+//   affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
+//
+class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
+                                  OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *tagMemRef,
+                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
+                    Value *numElements);
+
+  static StringRef getOperationName() { return "affine.dma_wait"; }
+
+  // Returns the Tag MemRef associated with the DMA operation being waited on.
+  Value *getTagMemRef() { return getOperand(0); }
+  MemRefType getTagMemRefType() {
+    return getTagMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Returns the affine map used to access the tag memref.
+  AffineMap getTagMap() { return getTagMapAttr().getValue(); }
+  AffineMapAttr getTagMapAttr() {
+    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  // Returns the tag memref index for this DMA operation.
+  operand_range getTagIndices() {
+    return {operand_begin() + 1,
+            operand_begin() + 1 + getTagMap().getNumInputs()};
+  }
+
+  // Returns the rank (number of indices) of the tag memref.
+  unsigned getTagMemRefRank() {
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
+  /// Returns the AffineMapAttr associated with 'memref'.
+  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+    assert(memref == getTagMemRef());
+    return {Identifier::get(getTagMapAttrName(), getContext()),
+            getTagMapAttr()};
+  }
+
+  /// Returns the number of elements transferred in the associated DMA op.
+  Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
+
+  static StringRef getTagMapAttrName() { return "tag_map"; }
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+};
+
+/// The "affine.load" op reads an element from a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The output of 'affine.load' is a new value with the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+//  Example 1:
+//
+//    %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+//  Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+//    %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+//      : memref<100x100xf32>
+//
+class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
+                               OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+  using Op::Op;
+
+  /// Builds an affine load op with the specified map and operands.
+  static void build(Builder *builder, OperationState *result, AffineMap map,
+                    ArrayRef<Value *> operands);
+  /// Builds an affine load op an identify map and operands.
+  static void build(Builder *builder, OperationState *result, Value *memref,
+                    ArrayRef<Value *> indices = {});
+
+  /// Returns the operand index of the memref.
+  unsigned getMemRefOperandIndex() { return 0; }
+
+  /// Get memref operand.
+  Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
+  void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Get affine map operands.
+  operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
+
+  /// Returns the affine map used to index the memref for this operation.
+  AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+  AffineMapAttr getAffineMapAttr() {
+    return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  /// Returns the AffineMapAttr associated with 'memref'.
+  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+    assert(memref == getMemRef());
+    return {Identifier::get(getMapAttrName(), getContext()),
+            getAffineMapAttr()};
+  }
+
+  static StringRef getMapAttrName() { return "map"; }
+  static StringRef getOperationName() { return "affine.load"; }
+
+  // Hooks to customize behavior of this op.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+};
+
+/// The "affine.store" op writes an element to a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The 'affine.store' op stores a new value which is the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+//  Example 1:
+//
+//    affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+//  Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+//    affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+//      : memref<100x100xf32>
+//
+class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
+                                OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+  using Op::Op;
+
+  /// Builds an affine store operation with the specified map and operands.
+  static void build(Builder *builder, OperationState *result,
+                    Value *valueToStore, AffineMap map,
+                    ArrayRef<Value *> operands);
+  /// Builds an affine store operation with an identity map and operands.
+  static void build(Builder *builder, OperationState *result,
+                    Value *valueToStore, Value *memref,
+                    ArrayRef<Value *> operands);
+
+  /// Get value to be stored by store operation.
+  Value *getValueToStore() { return getOperand(0); }
+
+  /// Returns the operand index of the memref.
+  unsigned getMemRefOperandIndex() { return 1; }
+
+  /// Get memref operand.
+  Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
+  void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
+
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Get affine map operands.
+  operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
+
+  /// Returns the affine map used to index the memref for this operation.
+  AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+  AffineMapAttr getAffineMapAttr() {
+    return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+  }
+
+  /// Returns the AffineMapAttr associated with 'memref'.
+  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+    assert(memref == getMemRef());
+    return {Identifier::get(getMapAttrName(), getContext()),
+            getAffineMapAttr()};
+  }
+
+  static StringRef getMapAttrName() { return "map"; }
+  static StringRef getOperationName() { return "affine.store"; }
+
+  // Hooks to customize behavior of this op.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+};
+
+/// Returns true if the given Value can be used as a dimension id.
+bool isValidDim(Value *value);
+
+/// Returns true if the given Value can be used as a symbol.
+bool isValidSymbol(Value *value);
+
+/// Modifies both `map` and `operands` in-place so as to:
+/// 1. drop duplicate operands
+/// 2. drop unused dims and symbols from map
+void canonicalizeMapAndOperands(AffineMap *map,
+                                llvm::SmallVectorImpl<Value *> *operands);
+
+/// Returns a composed AffineApplyOp by composing `map` and `operands` with
+/// other AffineApplyOps supplying those operands. The operands of the resulting
+/// AffineApplyOp do not change the length of  AffineApplyOp chains.
+AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
+                                      llvm::ArrayRef<Value *> operands);
+
+/// Given an affine map `map` and its input `operands`, this method composes
+/// into `map`, maps of AffineApplyOps whose results are the values in
+/// `operands`, iteratively until no more of `operands` are the result of an
+/// AffineApplyOp. When this function returns, `map` becomes the composed affine
+/// map, and each Value in `operands` is guaranteed to be either a loop IV or a
+/// terminal symbol, i.e., a symbol defined at the top level or a block/function
+/// argument.
+void fullyComposeAffineMapAndOperands(AffineMap *map,
+                                      llvm::SmallVectorImpl<Value *> *operands);
+
+#define GET_OP_CLASSES
+#include "mlir/AffineOps/AffineOps.h.inc"
+
+/// Returns if the provided value is the induction variable of a AffineForOp.
+bool isForInductionVar(Value *val);
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+AffineForOp getForInductionVarOwner(Value *val);
+
+/// Extracts the induction variables from a list of AffineForOps and places them
+/// in the output argument `ivs`.
+void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
+                             SmallVectorImpl<Value *> *ivs);
+
+/// AffineBound represents a lower or upper bound in the for operation.
+/// This class does not own the underlying operands. Instead, it refers
+/// to the operands stored in the AffineForOp. Its life span should not exceed
+/// that of the for operation it refers to.
+class AffineBound {
+public:
+  AffineForOp getAffineForOp() { return op; }
+  AffineMap getMap() { return map; }
+
+  /// Returns an AffineValueMap representing this bound.
+  AffineValueMap getAsAffineValueMap();
+
+  unsigned getNumOperands() { return opEnd - opStart; }
+  Value *getOperand(unsigned idx) {
+    return op.getOperation()->getOperand(opStart + idx);
+  }
+
+  using operand_iterator = AffineForOp::operand_iterator;
+  using operand_range = AffineForOp::operand_range;
+
+  operand_iterator operand_begin() { return op.operand_begin() + opStart; }
+  operand_iterator operand_end() { return op.operand_begin() + opEnd; }
+  operand_range getOperands() { return {operand_begin(), operand_end()}; }
+
+private:
+  // 'affine.for' operation that contains this bound.
+  AffineForOp op;
+  // Start and end positions of this affine bound operands in the list of
+  // the containing 'affine.for' operation operands.
+  unsigned opStart, opEnd;
+  // Affine map for this bound.
+  AffineMap map;
+
+  AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map)
+      : op(op), opStart(opStart), opEnd(opEnd), map(map) {}
+
+  friend class AffineForOp;
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/AffineOps/AffineOps.td
new file mode 100644
index 0000000..c517ed02
--- /dev/null
+++ b/third_party/mlir/include/mlir/AffineOps/AffineOps.td
@@ -0,0 +1,259 @@
+//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines MLIR affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef AFFINE_OPS
+#else
+#define AFFINE_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+include "mlir/AffineOps/AffineOpsBase.td"
+
+def Affine_Dialect : Dialect {
+  let name = "affine";
+  let cppNamespace = "";
+}
+
+// Base class for Affine dialect ops.
+class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Affine_Dialect, mnemonic, traits> {
+  // For every affine op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+// Require regions to have affine terminator.
+def ImplicitAffineTerminator
+    : SingleBlockImplicitTerminator<"AffineTerminatorOp">;
+
+def AffineForOp : Affine_Op<"for", [ImplicitAffineTerminator]> {
+  let summary = "for operation";
+  let description = [{
+    The "affine.for" operation represents an affine loop nest, defining an SSA
+    value for its induction variable. It has one region capturing the loop body.
+    The induction variable is represented as a argument of this region. This SSA
+    value always has type index, which is the size of the machine word. The
+    stride, represented by step, is a positive constant integer which defaults
+    to "1" if not present. The lower and upper bounds specify a half-open range:
+    the range includes the lower bound but does not include the upper bound.
+
+    The body region must contain exactly one block that terminates with
+    "affine.terminator".  Calling AffineForOp::build will create such region
+    and insert the terminator, so will the parsing even in cases if it is absent
+    from the custom format.
+
+    The lower and upper bounds of a for operation are represented as an
+    application of an affine mapping to a list of SSA values passed to the map.
+    The same restrictions hold for these SSA values as for all bindings of SSA
+    values to dimensions and symbols. The affine mappings for the bounds may
+    return multiple results, in which case the max/min keywords are required
+    (for the lower/upper bound respectively), and the bound is the
+    maximum/minimum of the returned values.
+
+    Example:
+
+      affine.for %i = 1 to 10 {
+        ...
+      }
+
+  }];
+  let arguments = (ins Variadic<AnyType>);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, "
+              "int64_t lowerBound, int64_t upperBound, int64_t step = 1">,
+    OpBuilder<"Builder *builder, OperationState *result, "
+              "ArrayRef<Value *> lbOperands, AffineMap lbMap, "
+              "ArrayRef<Value *> ubOperands, AffineMap ubMap, "
+              "int64_t step = 1">
+  ];
+
+  let extraClassDeclaration = [{
+    static StringRef getStepAttrName() { return "step"; }
+    static StringRef getLowerBoundAttrName() { return "lower_bound"; }
+    static StringRef getUpperBoundAttrName() { return "upper_bound"; }
+
+    Block *getBody() { return &region().front(); }
+    Value *getInductionVar() { return getBody()->getArgument(0); }
+    OpBuilder getBodyBuilder() {
+      return OpBuilder(getBody(), std::prev(getBody()->end()));
+    }
+
+    // TODO: provide iterators for the lower and upper bound operands
+    // if the current access via getLowerBound(), getUpperBound() is too slow.
+
+    /// Returns operands for the lower bound map.
+    operand_range getLowerBoundOperands();
+
+    /// Returns operands for the upper bound map.
+    operand_range getUpperBoundOperands();
+
+    /// Returns information about the lower bound as a single object.
+    AffineBound getLowerBound();
+
+    /// Returns information about the upper bound as a single object.
+    AffineBound getUpperBound();
+
+    /// Returns loop step.
+    int64_t getStep() {
+      return getAttr(getStepAttrName()).cast<IntegerAttr>().getInt();
+    }
+
+    /// Returns affine map for the lower bound.
+    AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); }
+    AffineMapAttr getLowerBoundMapAttr() {
+      return getAttr(getLowerBoundAttrName()).cast<AffineMapAttr>();
+    }
+    /// Returns affine map for the upper bound. The upper bound is exclusive.
+    AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); }
+    AffineMapAttr getUpperBoundMapAttr() {
+      return getAttr(getUpperBoundAttrName()).cast<AffineMapAttr>();
+    }
+
+    /// Set lower bound. The new bound must have the same number of operands as
+    /// the current bound map. Otherwise, 'replaceForLowerBound' should be used.
+    void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
+    /// Set upper bound. The new bound must not have more operands than the
+    /// current bound map. Otherwise, 'replaceForUpperBound' should be used.
+    void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
+
+    /// Set the lower bound map without changing operands.
+    void setLowerBoundMap(AffineMap map);
+
+    /// Set the upper bound map without changing operands.
+    void setUpperBoundMap(AffineMap map);
+
+    /// Set loop step.
+    void setStep(int64_t step) {
+      assert(step > 0 && "step has to be a positive integer constant");
+      auto *context = getLowerBoundMap().getContext();
+      setAttr(Identifier::get(getStepAttrName(), context),
+              IntegerAttr::get(IndexType::get(context), step));
+    }
+
+    /// Returns true if the lower bound is constant.
+    bool hasConstantLowerBound();
+    /// Returns true if the upper bound is constant.
+    bool hasConstantUpperBound();
+    /// Returns true if both bounds are constant.
+    bool hasConstantBounds() {
+      return hasConstantLowerBound() && hasConstantUpperBound();
+    }
+    /// Returns the value of the constant lower bound.
+    /// Fails assertion if the bound is non-constant.
+    int64_t getConstantLowerBound();
+    /// Returns the value of the constant upper bound. The upper bound is
+    /// exclusive. Fails assertion if the bound is non-constant.
+    int64_t getConstantUpperBound();
+    /// Sets the lower bound to the given constant value.
+    void setConstantLowerBound(int64_t value);
+    /// Sets the upper bound to the given constant value.
+    void setConstantUpperBound(int64_t value);
+
+    /// Returns true if both the lower and upper bound have the same operand 
+    /// lists (same operands in the same order).
+    bool matchingBoundOperandList();
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
+  let summary = "if-then-else operation";
+  let description = [{
+    The "if" operation represents an if-then-else construct for conditionally
+    executing two regions of code. The operands to an if operation are an
+    IntegerSet condition and a set of symbol/dimension operands to the
+    condition set. The operation produces no results. For example:
+
+       affine.if #set(%i)  {
+         ...
+       } else {
+         ...
+       }
+
+    The 'else' blocks to the if operation are optional, and may be omitted. For
+    example:
+
+       affine.if #set(%i)  {
+         ...
+       }
+  }];
+  let arguments = (ins Variadic<AnyType>);
+  let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, "
+              "Value *cond, bool withElseRegion">
+  ];
+
+  let extraClassDeclaration = [{
+    static StringRef getConditionAttrName() { return "condition"; }
+
+    IntegerSet getIntegerSet();
+    void setIntegerSet(IntegerSet newSet);
+
+    OpBuilder getThenBodyBuilder() {
+      assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
+      Block &body = thenRegion().front();
+      return OpBuilder(&body, std::prev(body.end()));
+    }
+    OpBuilder getElseBodyBuilder() {
+      assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
+      Block &body = elseRegion().front();
+      return OpBuilder(&body, std::prev(body.end()));
+    }
+  }];
+}
+
+def AffineTerminatorOp :
+    Affine_Op<"terminator", [Terminator]> {
+  let summary = "affine terminator operation";
+  let description = [{
+    Affine terminator is a special terminator operation for blocks inside affine
+    loops and branches. It unconditionally transmits the control flow to the
+    successor of the operation enclosing the region.
+
+    This operation does _not_ have a custom syntax. However, affine control
+    operations omit the terminator in their custom syntax for brevity.
+  }];
+
+  // No custom parsing/printing form.
+  let parser = ?;
+  let printer = ?;
+
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
+#endif // AFFINE_OPS
diff --git a/third_party/mlir/include/mlir/AffineOps/AffineOpsBase.td b/third_party/mlir/include/mlir/AffineOps/AffineOpsBase.td
new file mode 100644
index 0000000..2ac1d37
--- /dev/null
+++ b/third_party/mlir/include/mlir/AffineOps/AffineOpsBase.td
@@ -0,0 +1,44 @@
+//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines base support for MLIR affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef AFFINE_OPS_BASE
+#else
+#define AFFINE_OPS_BASE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// Attributes containing affine maps.
+def AffineMapAttr : Attr<
+    CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
+  let storageType = [{ AffineMapAttr }];
+  let returnType = [{ AffineMap }];
+  let constBuilderCall = "$_builder.getAffineMapAttr($0)";
+}
+
+def AffineMapArrayAttr : TypedArrayAttrBase<AffineMapAttr,
+                                      "AffineMap array attribute"> {
+  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
+}
+
+#endif // AFFINE_OPS_BASE
diff --git a/third_party/mlir/include/mlir/AffineOps/CMakeLists.txt b/third_party/mlir/include/mlir/AffineOps/CMakeLists.txt
new file mode 100644
index 0000000..6c5a58c
--- /dev/null
+++ b/third_party/mlir/include/mlir/AffineOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AffineOps.td)
+mlir_tablegen(AffineOps.h.inc -gen-op-decls)
+mlir_tablegen(AffineOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRAffineOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h b/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h
new file mode 100644
index 0000000..bb25a65
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -0,0 +1,134 @@
+//===- AffineAnalysis.h - analyses for affine structures --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for methods that perform analysis
+// involving affine structures (AffineExprStorage, AffineMap, IntegerSet, etc.)
+// and other IR structures that in turn use these.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
+#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+class AffineApplyOp;
+class AffineForOp;
+class AffineValueMap;
+class FlatAffineConstraints;
+class Operation;
+class Value;
+
+/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
+/// Operations that are reachable via a search starting from `operands` and
+/// ending at those operands that are not the result of an AffineApplyOp.
+void getReachableAffineApplyOps(
+    llvm::ArrayRef<Value *> operands,
+    llvm::SmallVectorImpl<Operation *> &affineApplyOps);
+
+/// Builds a system of constraints with dimensional identifiers corresponding to
+/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
+/// used to add appropriate inequalities. Any symbols founds in the bound
+/// operands are added as symbols in the system. Returns failure for the yet
+/// unimplemented cases.
+//  TODO(bondhugula): handle non-unit strides.
+LogicalResult getIndexSet(llvm::MutableArrayRef<AffineForOp> forOps,
+                          FlatAffineConstraints *domain);
+
+/// Encapsulates a memref load or store access information.
+struct MemRefAccess {
+  Value *memref;
+  Operation *opInst;
+  llvm::SmallVector<Value *, 4> indices;
+
+  /// Constructs a MemRefAccess from a load or store operation.
+  // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to
+  // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess.
+  explicit MemRefAccess(Operation *opInst);
+
+  // Returns the rank of the memref associated with this access.
+  unsigned getRank() const;
+  // Returns true if this access is of a store op.
+  bool isStore() const;
+
+  /// Populates 'accessMap' with composition of AffineApplyOps reachable from
+  // 'indices'.
+  void getAccessMap(AffineValueMap *accessMap) const;
+};
+
+// DependenceComponent contains state about the direction of a dependence as an
+// interval [lb, ub] for an AffineForOp.
+// Distance vectors components are represented by the interval [lb, ub] with
+// lb == ub.
+// Direction vectors components are represented by the interval [lb, ub] with
+// lb < ub. Note that ub/lb == None means unbounded.
+struct DependenceComponent {
+  // The AffineForOp Operation associated with this dependence component.
+  Operation *op;
+  // The lower bound of the dependence distance.
+  llvm::Optional<int64_t> lb;
+  // The upper bound of the dependence distance (inclusive).
+  llvm::Optional<int64_t> ub;
+  DependenceComponent() : lb(llvm::None), ub(llvm::None) {}
+};
+
+/// Checks whether two accesses to the same memref access the same element.
+/// Each access is specified using the MemRefAccess structure, which contains
+/// the operation, indices and memref associated with the access. Returns
+/// 'NoDependence' if it can be determined conclusively that the accesses do not
+/// access the same memref element. If 'allowRAR' is true, will consider
+/// read-after-read dependences (typically used by applications trying to
+/// optimize input reuse).
+// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
+// a single struct.
+// TODO(andydavis) Make 'dependenceConstraints' optional arg.
+struct DependenceResult {
+  enum ResultEnum {
+    HasDependence, // A dependence exists between 'srcAccess' and 'dstAccess'.
+    NoDependence,  // No dependence exists between 'srcAccess' and 'dstAccess'.
+    Failure,       // Dependence check failed due to unsupported cases.
+  } value;
+  DependenceResult(ResultEnum v) : value(v) {}
+};
+
+DependenceResult checkMemrefAccessDependence(
+    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+    unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
+    llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
+    bool allowRAR = false);
+
+/// Utility function that returns true if the provided DependenceResult
+/// corresponds to a dependence result.
+inline bool hasDependence(DependenceResult result) {
+  return result.value == DependenceResult::HasDependence;
+}
+
+/// Returns in 'depCompsVec', dependence components for dependences between all
+/// load and store ops in loop nest rooted at 'forOp', at loop depths in range
+/// [1, maxLoopDepth].
+void getDependenceComponents(
+    AffineForOp forOp, unsigned maxLoopDepth,
+    std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec);
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H
diff --git a/third_party/mlir/include/mlir/Analysis/AffineStructures.h b/third_party/mlir/include/mlir/Analysis/AffineStructures.h
new file mode 100644
index 0000000..968ffb1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/AffineStructures.h
@@ -0,0 +1,813 @@
+//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Structures for affine/polyhedral analysis of ML functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_AFFINE_STRUCTURES_H
+#define MLIR_ANALYSIS_AFFINE_STRUCTURES_H
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+class AffineApplyOp;
+class AffineBound;
+class AffineCondition;
+class AffineMap;
+class AffineForOp;
+class IntegerSet;
+class MLIRContext;
+class Value;
+class HyperRectangularSet;
+class MemRefType;
+
+/// A mutable affine map. Its affine expressions are however unique.
+struct MutableAffineMap {
+public:
+  MutableAffineMap() {}
+  MutableAffineMap(AffineMap map);
+
+  ArrayRef<AffineExpr> getResults() const { return results; }
+  AffineExpr getResult(unsigned idx) const { return results[idx]; }
+  void setResult(unsigned idx, AffineExpr result) { results[idx] = result; }
+  unsigned getNumResults() const { return results.size(); }
+  unsigned getNumDims() const { return numDims; }
+  void setNumDims(unsigned d) { numDims = d; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  void setNumSymbols(unsigned d) { numSymbols = d; }
+  MLIRContext *getContext() const { return context; }
+
+  /// Returns true if the idx'th result expression is a multiple of factor.
+  bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+  /// Resets this MutableAffineMap with 'map'.
+  void reset(AffineMap map);
+
+  /// Simplify the (result) expressions in this map using analysis (used by
+  //-simplify-affine-expr pass).
+  void simplify();
+  /// Get the AffineMap corresponding to this MutableAffineMap. Note that an
+  /// AffineMap will be uniqued and stored in context, while a mutable one
+  /// isn't.
+  AffineMap getAffineMap() const;
+
+private:
+  // Same meaning as AffineMap's fields.
+  SmallVector<AffineExpr, 8> results;
+  unsigned numDims;
+  unsigned numSymbols;
+  /// A pointer to the IR's context to store all newly created
+  /// AffineExprStorage's.
+  MLIRContext *context;
+};
+
+/// A mutable integer set. Its affine expressions are however unique.
+struct MutableIntegerSet {
+public:
+  MutableIntegerSet(IntegerSet set, MLIRContext *context);
+
+  /// Create a universal set (no constraints).
+  MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                    MLIRContext *context);
+
+  unsigned getNumDims() const { return numDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  unsigned getNumConstraints() const { return constraints.size(); }
+
+  void clear() {
+    constraints.clear();
+    eqFlags.clear();
+  }
+
+private:
+  unsigned numDims;
+  unsigned numSymbols;
+
+  SmallVector<AffineExpr, 8> constraints;
+  SmallVector<bool, 8> eqFlags;
+};
+
+/// An AffineValueMap is an affine map plus its ML value operands and
+/// results for analysis purposes. The structure is still a tree form that is
+/// same as that of an affine map or an AffineApplyOp. However, its operands,
+/// results, and its map can themselves change  as a result of
+/// substitutions, simplifications, and other analysis.
+// An affine value map can readily be constructed from an AffineApplyOp, or an
+// AffineBound of a AffineForOp. It can be further transformed, substituted
+// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and
+// destroyed during analysis. Only the AffineMap expressions that are pointed by
+// them are unique'd. An affine value map, and the operations on it, maintain
+// the invariant that operands are always positionally aligned with the
+// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap.
+// TODO(bondhugula): Some of these classes could go into separate files.
+class AffineValueMap {
+public:
+  // Creates an empty AffineValueMap (users should call 'reset' to reset map
+  // and operands).
+  AffineValueMap() {}
+  AffineValueMap(AffineMap map);
+  AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
+                 ArrayRef<Value *> results = llvm::None);
+
+  explicit AffineValueMap(AffineApplyOp applyOp);
+  explicit AffineValueMap(AffineBound bound);
+
+  ~AffineValueMap();
+
+  // Resets this AffineValueMap with 'map', 'operands', and 'results'.
+  void reset(AffineMap map, ArrayRef<Value *> operands,
+             ArrayRef<Value *> results = llvm::None);
+
+  /// Return true if the idx^th result can be proved to be a multiple of
+  /// 'factor', false otherwise.
+  inline bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+  /// Return true if the idx^th result depends on 'value', false otherwise.
+  bool isFunctionOf(unsigned idx, Value *value) const;
+
+  /// Return true if the result at 'idx' is a constant, false
+  /// otherwise.
+  bool isConstant(unsigned idx) const;
+
+  /// Return true if this is an identity map.
+  bool isIdentity() const;
+
+  inline unsigned getNumOperands() const { return operands.size(); }
+  inline unsigned getNumDims() const { return map.getNumDims(); }
+  inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
+  inline unsigned getNumResults() const { return map.getNumResults(); }
+
+  Value *getOperand(unsigned i) const;
+  ArrayRef<Value *> getOperands() const;
+  AffineMap getAffineMap() const;
+
+private:
+  // A mutable affine map.
+  MutableAffineMap map;
+
+  // TODO: make these trailing objects?
+  /// The SSA operands binding to the dim's and symbols of 'map'.
+  SmallVector<Value *, 4> operands;
+  /// The SSA results binding to the results of 'map'.
+  SmallVector<Value *, 4> results;
+};
+
+/// An IntegerValueSet is an integer set plus its operands.
+// Both, the integer set being pointed to and the operands can change during
+// analysis, simplification, and transformation.
+class IntegerValueSet {
+  /// Constructs an integer value set from an affine value map.
+  // This will lead to a single equality in 'set'.
+  explicit IntegerValueSet(const AffineValueMap &avm);
+
+  /// Returns true if this integer set is determined to be empty. Emptiness is
+  /// checked by by eliminating identifiers successively (through either
+  /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial
+  /// invalid constraint check. Returns 'true' if the constaint system is found
+  /// to be empty; false otherwise. This method is exact for rational spaces but
+  /// not integer spaces - thus, if it returns true, the set is provably integer
+  /// empty as well, but if it returns false, it doesn't necessarily mean an
+  /// integer point exists in it. This method also returns false where an
+  /// explosion of constraints is detected - due to the super-exponential
+  /// worse-case complexity of Fourier-Motzkin elimination (rare for realistic
+  /// problem cases but possible for artificial adversarial or improperly
+  // constructed ones), this method returns false conservatively.
+  bool isEmpty() const;
+
+  bool getNumDims() const { return set.getNumDims(); }
+  bool getNumSymbols() const { return set.getNumSymbols(); }
+
+private:
+  // The set pointed to may itself change unlike in IR structures like
+  // 'AffineCondition'.
+  MutableIntegerSet set;
+  /// The SSA operands binding to the dim's and symbols of 'set'.
+  SmallVector<Value *, 4> operands;
+};
+
+/// A flat list of affine equalities and inequalities in the form.
+/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0
+/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0
+///
+/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer
+/// for equalities and one for inequalities). The size of each buffer is
+/// numReservedCols * number of inequalities (or equalities). The reserved size
+/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A
+/// coefficient (r, c) lives at the location numReservedCols * r + c in the
+/// buffer. The extra space between getNumCols() and numReservedCols exists to
+/// prevent frequent movement of data when adding columns, especially at the
+/// end.
+///
+/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers,
+/// symbolic identifiers, and local identifiers.  The local identifiers
+/// correspond to local/internal variables created when converting from
+/// AffineExpr's containing mod's and div's; they are thus needed to increase
+/// representational power. Each local identifier is always (by construction) a
+/// floordiv of a pure add/mul affine function of dimensional, symbolic, and
+/// other local identifiers, in a non-mutually recursive way. Hence, every local
+/// identifier can ultimately always be recovered as an affine function of
+/// dimensional and symbolic identifiers (involving floordiv's); note however
+/// that some floordiv combinations are converted to mod's by AffineExpr
+/// construction.
+///
+class FlatAffineConstraints {
+public:
+  enum IdKind { Dimension, Symbol, Local };
+
+  /// Constructs a constraint system reserving memory for the specified number
+  /// of constraints and identifiers..
+  FlatAffineConstraints(unsigned numReservedInequalities,
+                        unsigned numReservedEqualities,
+                        unsigned numReservedCols, unsigned numDims = 0,
+                        unsigned numSymbols = 0, unsigned numLocals = 0,
+                        ArrayRef<Optional<Value *>> idArgs = {})
+      : numReservedCols(numReservedCols), numDims(numDims),
+        numSymbols(numSymbols) {
+    assert(numReservedCols >= numDims + numSymbols + 1);
+    assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals);
+    equalities.reserve(numReservedCols * numReservedEqualities);
+    inequalities.reserve(numReservedCols * numReservedInequalities);
+    numIds = numDims + numSymbols + numLocals;
+    ids.reserve(numReservedCols);
+    if (idArgs.empty())
+      ids.resize(numIds, None);
+    else
+      ids.append(idArgs.begin(), idArgs.end());
+  }
+
+  /// Constructs a constraint system with the specified number of
+  /// dimensions and symbols.
+  FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
+                        unsigned numLocals = 0,
+                        ArrayRef<Optional<Value *>> idArgs = {})
+      : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
+        numSymbols(numSymbols) {
+    assert(numReservedCols >= numDims + numSymbols + 1);
+    assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals);
+    numIds = numDims + numSymbols + numLocals;
+    ids.reserve(numIds);
+    if (idArgs.empty())
+      ids.resize(numIds, None);
+    else
+      ids.append(idArgs.begin(), idArgs.end());
+  }
+
+  explicit FlatAffineConstraints(const HyperRectangularSet &set);
+
+  /// Create a flat affine constraint system from an AffineValueMap or a list of
+  /// these. The constructed system will only include equalities.
+  // TODO(bondhugula)
+  explicit FlatAffineConstraints(const AffineValueMap &avm);
+  explicit FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef);
+
+  /// Creates an affine constraint system from an IntegerSet.
+  explicit FlatAffineConstraints(IntegerSet set);
+
+  /// Create an affine constraint system from an IntegerValueSet.
+  // TODO(bondhugula)
+  explicit FlatAffineConstraints(const IntegerValueSet &set);
+
+  FlatAffineConstraints(const FlatAffineConstraints &other);
+
+  FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
+                        IntegerSet set);
+
+  FlatAffineConstraints(const MutableAffineMap &map);
+
+  ~FlatAffineConstraints() {}
+
+  // Clears any existing data and reserves memory for the specified constraints.
+  void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
+             unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
+             unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
+
+  void reset(unsigned numDims = 0, unsigned numSymbols = 0,
+             unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
+
+  /// Appends constraints from 'other' into this. This is equivalent to an
+  /// intersection with no simplification of any sort attempted.
+  void append(const FlatAffineConstraints &other);
+
+  // Checks for emptiness by performing variable elimination on all identifiers,
+  // running the GCD test on each equality constraint, and checking for invalid
+  // constraints.
+  // Returns true if the GCD test fails for any equality, or if any invalid
+  // constraints are discovered on any row. Returns false otherwise.
+  bool isEmpty() const;
+
+  // Runs the GCD test on all equality constraints. Returns 'true' if this test
+  // fails on any equality. Returns 'false' otherwise.
+  // This test can be used to disprove the existence of a solution. If it
+  // returns true, no integer solution to the equality constraints can exist.
+  bool isEmptyByGCDTest() const;
+
+  // Clones this object.
+  std::unique_ptr<FlatAffineConstraints> clone() const;
+
+  /// Returns the value at the specified equality row and column.
+  inline int64_t atEq(unsigned i, unsigned j) const {
+    return equalities[i * numReservedCols + j];
+  }
+  inline int64_t &atEq(unsigned i, unsigned j) {
+    return equalities[i * numReservedCols + j];
+  }
+
+  inline int64_t atIneq(unsigned i, unsigned j) const {
+    return inequalities[i * numReservedCols + j];
+  }
+
+  inline int64_t &atIneq(unsigned i, unsigned j) {
+    return inequalities[i * numReservedCols + j];
+  }
+
+  /// Returns the number of columns in the constraint system.
+  inline unsigned getNumCols() const { return numIds + 1; }
+
+  inline unsigned getNumEqualities() const {
+    assert(equalities.size() % numReservedCols == 0 &&
+           "inconsistent equality buffer size");
+    return equalities.size() / numReservedCols;
+  }
+
+  inline unsigned getNumInequalities() const {
+    assert(inequalities.size() % numReservedCols == 0 &&
+           "inconsistent inequality buffer size");
+    return inequalities.size() / numReservedCols;
+  }
+
+  inline unsigned getNumReservedEqualities() const {
+    return equalities.capacity() / numReservedCols;
+  }
+
+  inline unsigned getNumReservedInequalities() const {
+    return inequalities.capacity() / numReservedCols;
+  }
+
+  inline ArrayRef<int64_t> getEquality(unsigned idx) const {
+    return ArrayRef<int64_t>(&equalities[idx * numReservedCols], getNumCols());
+  }
+
+  inline ArrayRef<int64_t> getInequality(unsigned idx) const {
+    return ArrayRef<int64_t>(&inequalities[idx * numReservedCols],
+                             getNumCols());
+  }
+
+  AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
+
+  /// Adds constraints (lower and upper bounds) for the specified 'affine.for'
+  /// operation's Value using IR information stored in its bound maps. The
+  /// right identifier is first looked up using forOp's Value. Asserts if the
+  /// Value corresponding to the 'affine.for' operation isn't found in the
+  /// constraint system. Returns failure for the yet unimplemented/unsupported
+  /// cases.  Any new identifiers that are found in the bound operands of the
+  /// 'affine.for' operation are added as trailing identifiers (either
+  /// dimensional or symbolic depending on whether the operand is a valid
+  /// symbol).
+  //  TODO(bondhugula): add support for non-unit strides.
+  LogicalResult addAffineForOpDomain(AffineForOp forOp);
+
+  /// Adds a lower or an upper bound for the identifier at the specified
+  /// position with constraints being drawn from the specified bound map and
+  /// operands. If `eq` is true, add a single equality equal to the bound map's
+  /// first result expr.
+  LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
+                                     ArrayRef<Value *> operands, bool eq,
+                                     bool lower = true);
+
+  /// Computes the lower and upper bounds of the first 'num' dimensional
+  /// identifiers (starting at 'offset') as an affine map of the remaining
+  /// identifiers (dimensional and symbolic). This method is able to detect
+  /// identifiers as floordiv's and mod's of affine expressions of other
+  /// identifiers with respect to (positive) constants. Sets bound map to a
+  /// null AffineMap if such a bound can't be found (or yet unimplemented).
+  void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context,
+                      SmallVectorImpl<AffineMap> *lbMaps,
+                      SmallVectorImpl<AffineMap> *ubMaps);
+
+  /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
+  /// bounds in 'ubMaps' to each identifier in the constraint system which has
+  /// a value in 'values'. Note that both lower/upper bounds share the same
+  /// operand list 'operands'.
+  /// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'.
+  /// Note that both lower/upper bounds use operands from 'operands'.
+  LogicalResult addSliceBounds(ArrayRef<Value *> values,
+                               ArrayRef<AffineMap> lbMaps,
+                               ArrayRef<AffineMap> ubMaps,
+                               ArrayRef<Value *> operands);
+
+  // Adds an inequality (>= 0) from the coefficients specified in inEq.
+  void addInequality(ArrayRef<int64_t> inEq);
+  // Adds an equality from the coefficients specified in eq.
+  void addEquality(ArrayRef<int64_t> eq);
+
+  /// Adds a constant lower bound constraint for the specified identifier.
+  void addConstantLowerBound(unsigned pos, int64_t lb);
+  /// Adds a constant upper bound constraint for the specified identifier.
+  void addConstantUpperBound(unsigned pos, int64_t ub);
+
+  /// Adds a new local identifier as the floordiv of an affine function of other
+  /// identifiers, the coefficients of which are provided in 'dividend' and with
+  /// respect to a positive constant 'divisor'. Two constraints are added to the
+  /// system to capture equivalence with the floordiv:
+  /// q = dividend floordiv c    <=>   c*q <= dividend <= c*q + c - 1.
+  void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor);
+
+  /// Adds a constant lower bound constraint for the specified expression.
+  void addConstantLowerBound(ArrayRef<int64_t> expr, int64_t lb);
+  /// Adds a constant upper bound constraint for the specified expression.
+  void addConstantUpperBound(ArrayRef<int64_t> expr, int64_t ub);
+
+  /// Sets the identifier at the specified position to a constant.
+  void setIdToConstant(unsigned pos, int64_t val);
+
+  /// Sets the identifier corresponding to the specified Value id to a
+  /// constant. Asserts if the 'id' is not found.
+  void setIdToConstant(Value &id, int64_t val);
+
+  /// Looks up the position of the identifier with the specified Value. Returns
+  /// true if found (false otherwise). `pos' is set to the (column) position of
+  /// the identifier.
+  bool findId(Value &id, unsigned *pos) const;
+
+  /// Returns true if an identifier with the specified Value exists, false
+  /// otherwise.
+  bool containsId(Value &id) const;
+
+  // Add identifiers of the specified kind - specified positions are relative to
+  // the kind of identifier. The coefficient column corresponding to the added
+  // identifier is initialized to zero. 'id' is the Value corresponding to the
+  // identifier that can optionally be provided.
+  void addDimId(unsigned pos, Value *id = nullptr);
+  void addSymbolId(unsigned pos, Value *id = nullptr);
+  void addLocalId(unsigned pos);
+  void addId(IdKind kind, unsigned pos, Value *id = nullptr);
+
+  /// Add the specified values as a dim or symbol id depending on its nature, if
+  /// it already doesn't exist in the system. `id' has to be either a terminal
+  /// symbol or a loop IV, i.e., it cannot be the result affine.apply of any
+  /// symbols or loop IVs. The identifier is added to the end of the existing
+  /// dims or symbols. Additional information on the identifier is extracted
+  /// from the IR and added to the constraint system.
+  void addInductionVarOrTerminalSymbol(Value *id);
+
+  /// Composes the affine value map with this FlatAffineConstrains, adding the
+  /// results of the map as dimensions at the front [0, vMap->getNumResults())
+  /// and with the dimensions set to the equalities specified by the value map.
+  /// Returns failure if the composition fails (when vMap is a semi-affine map).
+  /// The vMap's operand Value's are used to look up the right positions in
+  /// the FlatAffineConstraints with which to associate. The dimensional and
+  /// symbolic operands of vMap should match 1:1 (in the same order) with those
+  /// of this constraint system, but the latter could have additional trailing
+  /// operands.
+  LogicalResult composeMap(AffineValueMap *vMap);
+
+  /// Projects out (aka eliminates) 'num' identifiers starting at position
+  /// 'pos'. The resulting constraint system is the shadow along the dimensions
+  /// that still exist. This method may not always be integer exact.
+  // TODO(bondhugula): deal with integer exactness when necessary - can return a
+  // value to mark exactness for example.
+  void projectOut(unsigned pos, unsigned num);
+  inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
+
+  /// Projects out the identifier that is associate with Value *.
+  void projectOut(Value *id);
+
+  void removeId(IdKind idKind, unsigned pos);
+  void removeId(unsigned pos);
+
+  void removeDim(unsigned pos);
+
+  void removeEquality(unsigned pos);
+  void removeInequality(unsigned pos);
+
+  /// Changes the partition between dimensions and symbols. Depending on the new
+  /// symbol count, either a chunk of trailing dimensional identifiers becomes
+  /// symbols, or some of the leading symbols become dimensions.
+  void setDimSymbolSeparation(unsigned newSymbolCount);
+
+  /// Changes all symbol identifiers which are loop IVs to dim identifiers.
+  void convertLoopIVSymbolsToDims();
+
+  /// Sets the specified identifier to a constant and removes it.
+  void setAndEliminate(unsigned pos, int64_t constVal);
+
+  /// Tries to fold the specified identifier to a constant using a trivial
+  /// equality detection; if successful, the constant is substituted for the
+  /// identifier everywhere in the constraint system and then removed from the
+  /// system.
+  LogicalResult constantFoldId(unsigned pos);
+
+  /// This method calls constantFoldId for the specified range of identifiers,
+  /// 'num' identifiers starting at position 'pos'.
+  void constantFoldIdRange(unsigned pos, unsigned num);
+
+  /// Returns true if all the identifiers in the specified range [start, limit)
+  /// can only take a single value each if the remaining identifiers are treated
+  /// as symbols/parameters, i.e., for given values of the latter, there only
+  /// exists a unique value for each of the dimensions in the specified range.
+  bool isRangeOneToOne(unsigned start, unsigned limit) const;
+
+  /// Updates the constraints to be the smallest bounding (enclosing) box that
+  /// contains the points of 'this' set and that of 'other', with the symbols
+  /// being treated specially. For each of the dimensions, the min of the lower
+  /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
+  /// to determine such a bounding box. `other' is expected to have the same
+  /// dimensional identifiers as this constraint system (in the same order).
+  ///
+  /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the
+  ///      output is {0 <= d0 <= 192}.
+  /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 +
+  ///     9}, output = {s0 + 1 <= d0 <= s0 + 20}.
+  /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1
+  ///     <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}.
+  LogicalResult unionBoundingBox(const FlatAffineConstraints &other);
+
+  /// Returns 'true' if this constraint system and 'other' are in the same
+  /// space, i.e., if they are associated with the same set of identifiers,
+  /// appearing in the same order. Returns 'false' otherwise.
+  bool areIdsAlignedWithOther(const FlatAffineConstraints &other);
+
+  /// Merge and align the identifiers of 'this' and 'other' starting at
+  /// 'offset', so that both constraint systems get the union of the contained
+  /// identifiers that is dimension-wise and symbol-wise unique; both
+  /// constraint systems are updated so that they have the union of all
+  /// identifiers, with this's original identifiers appearing first followed by
+  /// any of other's identifiers that didn't appear in 'this'. Local
+  /// identifiers of each system are by design separate/local and are placed
+  /// one after other (this's followed by other's).
+  //  Eg: Input: 'this'  has ((%i %j) [%M %N])
+  //             'other' has (%k, %j) [%P, %N, %M])
+  //      Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P]
+  //
+  void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other);
+
+  unsigned getNumConstraints() const {
+    return getNumInequalities() + getNumEqualities();
+  }
+  inline unsigned getNumIds() const { return numIds; }
+  inline unsigned getNumDimIds() const { return numDims; }
+  inline unsigned getNumSymbolIds() const { return numSymbols; }
+  inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
+  inline unsigned getNumLocalIds() const {
+    return numIds - numDims - numSymbols;
+  }
+
+  inline ArrayRef<Optional<Value *>> getIds() const {
+    return {ids.data(), ids.size()};
+  }
+  inline MutableArrayRef<Optional<Value *>> getIds() {
+    return {ids.data(), ids.size()};
+  }
+
+  /// Returns the optional Value corresponding to the pos^th identifier.
+  inline Optional<Value *> getId(unsigned pos) const { return ids[pos]; }
+  inline Optional<Value *> &getId(unsigned pos) { return ids[pos]; }
+
+  /// Returns the Value associated with the pos^th identifier. Asserts if
+  /// no Value identifier was associated.
+  inline Value *getIdValue(unsigned pos) const {
+    assert(ids[pos].hasValue() && "identifier's Value not set");
+    return ids[pos].getValue();
+  }
+
+  /// Returns the Values associated with identifiers in range [start, end).
+  /// Asserts if no Value was associated with one of these identifiers.
+  void getIdValues(unsigned start, unsigned end,
+                   SmallVectorImpl<Value *> *values) const {
+    assert((start < numIds || start == end) && "invalid start position");
+    assert(end <= numIds && "invalid end position");
+    values->clear();
+    values->reserve(end - start);
+    for (unsigned i = start; i < end; i++) {
+      values->push_back(getIdValue(i));
+    }
+  }
+  inline void getAllIdValues(SmallVectorImpl<Value *> *values) const {
+    getIdValues(0, numIds, values);
+  }
+
+  /// Sets Value associated with the pos^th identifier.
+  inline void setIdValue(unsigned pos, Value *val) {
+    assert(pos < numIds && "invalid id position");
+    ids[pos] = val;
+  }
+  /// Sets Values associated with identifiers in the range [start, end).
+  void setIdValues(unsigned start, unsigned end, ArrayRef<Value *> values) {
+    assert((start < numIds || end == start) && "invalid start position");
+    assert(end <= numIds && "invalid end position");
+    assert(values.size() == end - start);
+    for (unsigned i = start; i < end; ++i)
+      ids[i] = values[i - start];
+  }
+
+  /// Clears this list of constraints and copies other into it.
+  void clearAndCopyFrom(const FlatAffineConstraints &other);
+
+  /// Returns the smallest known constant bound for the extent of the specified
+  /// identifier (pos^th), i.e., the smallest known constant that is greater
+  /// than or equal to 'exclusive upper bound' - 'lower bound' of the
+  /// identifier. Returns None if it's not a constant. This method employs
+  /// trivial (low complexity / cost) checks and detection. Symbolic identifiers
+  /// are treated specially, i.e., it looks for constant differences between
+  /// affine expressions involving only the symbolic identifiers. See comments
+  /// at function definition for examples. 'lb' and 'lbDivisor', if provided,
+  /// are used to express the lower bound associated with the constant
+  /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg.,
+  /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three
+  /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32.
+  Optional<int64_t>
+  getConstantBoundOnDimSize(unsigned pos,
+                            SmallVectorImpl<int64_t> *lb = nullptr,
+                            int64_t *lbFloorDivisor = nullptr,
+                            SmallVectorImpl<int64_t> *ub = nullptr) const;
+
+  /// Returns the constant lower bound for the pos^th identifier if there is
+  /// one; None otherwise.
+  Optional<int64_t> getConstantLowerBound(unsigned pos) const;
+
+  /// Returns the constant upper bound for the pos^th identifier if there is
+  /// one; None otherwise.
+  Optional<int64_t> getConstantUpperBound(unsigned pos) const;
+
+  /// Gets the lower and upper bound of the pos^th identifier treating
+  /// [0, offset) U [offset + num, symbStartPos) as dimensions and
+  /// [symStartPos, getNumDimAndSymbolIds) as symbols. The returned
+  /// multi-dimensional maps in the pair represent the max and min of
+  /// potentially multiple affine expressions. The upper bound is exclusive.
+  /// 'localExprs' holds pre-computed AffineExpr's for all local identifiers in
+  /// the system.
+  std::pair<AffineMap, AffineMap>
+  getLowerAndUpperBound(unsigned pos, unsigned offset, unsigned num,
+                        unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
+                        MLIRContext *context);
+
+  /// Returns true if the set can be trivially detected as being
+  /// hyper-rectangular on the specified contiguous set of identifiers.
+  bool isHyperRectangular(unsigned pos, unsigned num) const;
+
+  /// Removes duplicate constraints, trivially true constraints, and constraints
+  /// that can be detected as redundant as a result of differing only in their
+  /// constant term part. A constraint of the form <non-negative constant> >= 0
+  /// is considered trivially true. This method is a linear time method on the
+  /// constraints, does a single scan, and updates in place.
+  void removeTrivialRedundancy();
+
+  /// A more expensive check to detect redundant inequalities thatn
+  /// removeTrivialRedundancy.
+  void removeRedundantInequalities();
+
+  // Removes all equalities and inequalities.
+  void clearConstraints();
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// Returns false if the fields corresponding to various identifier counts, or
+  /// equality/inequality buffer sizes aren't consistent; true otherwise. This
+  /// is meant to be used within an assert internally.
+  bool hasConsistentState() const;
+
+  /// Checks all rows of equality/inequality constraints for trivial
+  /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
+  /// after elimination. Returns 'true' if an invalid constraint is found;
+  /// 'false'otherwise.
+  bool hasInvalidConstraint() const;
+
+  /// Returns the constant lower bound bound if isLower is true, and the upper
+  /// bound if isLower is false.
+  template <bool isLower>
+  Optional<int64_t> computeConstantLowerOrUpperBound(unsigned pos);
+
+  // Eliminates a single identifier at 'position' from equality and inequality
+  // constraints. Returns 'success' if the identifier was eliminated, and
+  // 'failure' otherwise.
+  inline LogicalResult gaussianEliminateId(unsigned position) {
+    return success(gaussianEliminateIds(position, position + 1) == 1);
+  }
+
+  // Eliminates identifiers from equality and inequality constraints
+  // in column range [posStart, posLimit).
+  // Returns the number of variables eliminated.
+  unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit);
+
+  /// Eliminates identifier at the specified position using Fourier-Motzkin
+  /// variable elimination, but uses Gaussian elimination if there is an
+  /// equality involving that identifier. If the result of the elimination is
+  /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is
+  /// set to true, a potential under approximation (subset) of the rational
+  /// shadow / exact integer shadow is computed.
+  // See implementation comments for more details.
+  void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
+                               bool *isResultIntegerExact = nullptr);
+
+  /// Tightens inequalities given that we are dealing with integer spaces. This
+  /// is similar to the GCD test but applied to inequalities. The constant term
+  /// can be reduced to the preceding multiple of the GCD of the coefficients,
+  /// i.e.,
+  ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
+  /// fast method (linear in the number of coefficients).
+  void GCDTightenInequalities();
+
+  /// Normalized each constraints by the GCD of its coefficients.
+  void normalizeConstraintsByGCD();
+
+  /// Removes identifiers in column range [idStart, idLimit), and copies any
+  /// remaining valid data into place, updates member variables, and resizes
+  /// arrays as needed.
+  void removeIdRange(unsigned idStart, unsigned idLimit);
+
+  /// Coefficients of affine equalities (in == 0 form).
+  SmallVector<int64_t, 64> equalities;
+
+  /// Coefficients of affine inequalities (in >= 0 form).
+  SmallVector<int64_t, 64> inequalities;
+
+  /// Number of columns reserved. Actual ones in used are returned by
+  /// getNumCols().
+  unsigned numReservedCols;
+
+  /// Total number of identifiers.
+  unsigned numIds;
+
+  /// Number of identifiers corresponding to real dimensions.
+  unsigned numDims;
+
+  /// Number of identifiers corresponding to symbols (unknown but constant for
+  /// analysis).
+  unsigned numSymbols;
+
+  /// Values corresponding to the (column) identifiers of this constraint
+  /// system appearing in the order the identifiers correspond to columns.
+  /// Temporary ones or those that aren't associated to any Value are set to
+  /// None.
+  SmallVector<Optional<Value *>, 8> ids;
+
+  /// A parameter that controls detection of an unrealistic number of
+  /// constraints. If the number of constraints is this many times the number of
+  /// variables, we consider such a system out of line with the intended use
+  /// case of FlatAffineConstraints.
+  // The rationale for 32 is that in the typical simplest of cases, an
+  // identifier is expected to have one lower bound and one upper bound
+  // constraint. With a level of tiling or a connection to another identifier
+  // through a div or mod, an extra pair of bounds gets added. As a limit, we
+  // don't expect an identifier to have more than 32 lower/upper/equality
+  // constraints. This is conservatively set low and can be raised if needed.
+  constexpr static unsigned kExplosionFactor = 32;
+};
+
+/// Simplify an affine expression by flattening and some amount of
+/// simple analysis. This has complexity linear in the number of nodes in
+/// 'expr'. Returns the simplified expression, which is the same as the input
+///  expression if it can't be simplified.
+AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+                              unsigned numSymbols);
+
+/// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' could not be
+/// flattened (i.e., semi-affine is not yet handled). 'cst' contains constraints
+/// that connect newly introduced local identifiers to existing dimensional and
+/// symbolic identifiers. See documentation for AffineExprFlattener on how
+/// mod's and div's are flattened.
+LogicalResult
+getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                       llvm::SmallVectorImpl<int64_t> *flattenedExpr,
+                       FlatAffineConstraints *cst = nullptr);
+
+/// Flattens the result expressions of the map to their corresponding flattened
+/// forms and set in 'flattenedExprs'. Returns failure if any expression in the
+/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst'
+/// contains constraints that connect newly introduced local identifiers to
+/// existing dimensional and / symbolic identifiers. See documentation for
+/// AffineExprFlattener on how mod's and div's are flattened. For all affine
+/// expressions that share the same operands (like those of an affine map), this
+/// method should be used instead of repeatedly calling getFlattenedAffineExpr
+/// since local variables added to deal with div's and mod's will be reused
+/// across expressions.
+LogicalResult getFlattenedAffineExprs(
+    AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
+    FlatAffineConstraints *cst = nullptr);
+LogicalResult getFlattenedAffineExprs(
+    IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
+    FlatAffineConstraints *cst = nullptr);
+
+} // end namespace mlir.
+
+#endif // MLIR_ANALYSIS_AFFINE_STRUCTURES_H
diff --git a/third_party/mlir/include/mlir/Analysis/Dominance.h b/third_party/mlir/include/mlir/Analysis/Dominance.h
new file mode 100644
index 0000000..d3e5b61
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/Dominance.h
@@ -0,0 +1,144 @@
+//===- Dominance.h - Dominator analysis for CFGs ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_ANALYSIS_DOMINANCE_H
+#define MLIR_ANALYSIS_DOMINANCE_H
+
+#include "mlir/IR/RegionGraphTraits.h"
+#include "llvm/Support/GenericDomTree.h"
+
+extern template class llvm::DominatorTreeBase<mlir::Block, false>;
+extern template class llvm::DominatorTreeBase<mlir::Block, true>;
+
+namespace mlir {
+using DominanceInfoNode = llvm::DomTreeNodeBase<Block>;
+class Operation;
+
+namespace detail {
+template <bool IsPostDom> class DominanceInfoBase {
+  using base = llvm::DominatorTreeBase<Block, IsPostDom>;
+
+public:
+  DominanceInfoBase(Operation *op) { recalculate(op); }
+  DominanceInfoBase(DominanceInfoBase &&) = default;
+  DominanceInfoBase &operator=(DominanceInfoBase &&) = default;
+
+  DominanceInfoBase(const DominanceInfoBase &) = delete;
+  DominanceInfoBase &operator=(const DominanceInfoBase &) = delete;
+
+  /// Recalculate the dominance info.
+  void recalculate(Operation *op);
+
+  /// Get the root dominance node of the given region.
+  DominanceInfoNode *getRootNode(Region *region) {
+    assert(dominanceInfos.count(region) != 0);
+    return dominanceInfos[region]->getRootNode();
+  }
+
+protected:
+  using super = DominanceInfoBase<IsPostDom>;
+
+  /// Return true if the specified block A properly dominates block B.
+  bool properlyDominates(Block *a, Block *b);
+
+  /// A mapping of regions to their base dominator tree.
+  llvm::DenseMap<Region *, std::unique_ptr<base>> dominanceInfos;
+};
+} // end namespace detail
+
+/// A class for computing basic dominance information.
+class DominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/false> {
+public:
+  using super::super;
+
+  /// Return true if operation A properly dominates operation B.
+  bool properlyDominates(Operation *a, Operation *b);
+
+  /// Return true if operation A dominates operation B.
+  bool dominates(Operation *a, Operation *b) {
+    return a == b || properlyDominates(a, b);
+  }
+
+  /// Return true if value A properly dominates operation B.
+  bool properlyDominates(Value *a, Operation *b);
+
+  /// Return true if operation A dominates operation B.
+  bool dominates(Value *a, Operation *b) {
+    return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b);
+  }
+
+  /// Return true if the specified block A dominates block B.
+  bool dominates(Block *a, Block *b) {
+    return a == b || properlyDominates(a, b);
+  }
+
+  /// Return true if the specified block A properly dominates block B.
+  bool properlyDominates(Block *a, Block *b) {
+    return super::properlyDominates(a, b);
+  }
+};
+
+/// A class for computing basic postdominance information.
+class PostDominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/true> {
+public:
+  using super::super;
+
+  /// Return true if operation A properly postdominates operation B.
+  bool properlyPostDominates(Operation *a, Operation *b);
+
+  /// Return true if operation A postdominates operation B.
+  bool postDominates(Operation *a, Operation *b) {
+    return a == b || properlyPostDominates(a, b);
+  }
+
+  /// Return true if the specified block A properly postdominates block B.
+  bool properlyPostDominates(Block *a, Block *b) {
+    return super::properlyDominates(a, b);
+  }
+
+  /// Return true if the specified block A postdominates block B.
+  bool postDominates(Block *a, Block *b) {
+    return a == b || properlyPostDominates(a, b);
+  }
+};
+
+} //  end namespace mlir
+
+namespace llvm {
+
+/// DominatorTree GraphTraits specialization so the DominatorTree can be
+/// iterated by generic graph iterators.
+template <> struct GraphTraits<mlir::DominanceInfoNode *> {
+  using ChildIteratorType = mlir::DominanceInfoNode::iterator;
+  using NodeRef = mlir::DominanceInfoNode *;
+
+  static NodeRef getEntryNode(NodeRef N) { return N; }
+  static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); }
+  static inline ChildIteratorType child_end(NodeRef N) { return N->end(); }
+};
+
+template <> struct GraphTraits<const mlir::DominanceInfoNode *> {
+  using ChildIteratorType = mlir::DominanceInfoNode::const_iterator;
+  using NodeRef = const mlir::DominanceInfoNode *;
+
+  static NodeRef getEntryNode(NodeRef N) { return N; }
+  static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); }
+  static inline ChildIteratorType child_end(NodeRef N) { return N->end(); }
+};
+
+} // end namespace llvm
+#endif
diff --git a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
new file mode 100644
index 0000000..7763a2b
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
@@ -0,0 +1,111 @@
+//===- LoopAnalysis.h - loop analysis methods -------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for methods to analyze loops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_LOOP_ANALYSIS_H
+#define MLIR_ANALYSIS_LOOP_ANALYSIS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+
+class AffineExpr;
+class AffineForOp;
+class AffineMap;
+class Operation;
+class MemRefType;
+class Value;
+
+/// Returns the trip count of the loop as an affine map with its corresponding
+/// operands if the latter is expressible as an affine expression, and nullptr
+/// otherwise. This method always succeeds as long as the lower bound is not a
+/// multi-result map. The trip count expression is simplified before returning.
+/// This method only utilizes map composition to construct lower and upper
+/// bounds before computing the trip count expressions
+// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
+// pure analysis method relying on FlatAffineConstraints
+void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
+                                  SmallVectorImpl<Value *> *operands);
+
+/// Returns the trip count of the loop if it's a constant, None otherwise. This
+/// uses affine expression analysis and is able to determine constant trip count
+/// in non-trivial cases.
+llvm::Optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
+/// Returns the greatest known integral divisor of the trip count. Affine
+/// expression analysis is used (indirectly through getTripCount), and
+/// this method is thus able to determine non-trivial divisors.
+uint64_t getLargestDivisorOfTripCount(AffineForOp forOp);
+
+/// Given an induction variable `iv` of type AffineForOp and an `index` of type
+/// IndexType, returns `true` if `index` is independent of `iv` and false
+/// otherwise.
+/// The determination supports composition with at most one AffineApplyOp.
+/// The at most one AffineApplyOp comes from the fact that composition of
+/// AffineApplyOp need to be canonicalized by construction to avoid writing code
+/// that composes arbitrary numbers of AffineApplyOps everywhere. To achieve
+/// this, at the very least, the compose-affine-apply pass must have been run.
+///
+/// Prerequisites:
+///   1. `iv` and `index` of the proper type;
+///   2. at most one reachable AffineApplyOp from index;
+///
+/// Returns false in cases with more than one AffineApplyOp, this is
+/// conservative.
+bool isAccessInvariant(Value *iv, Value *index);
+
+/// Given an induction variable `iv` of type AffineForOp and `indices` of type
+/// IndexType, returns the set of `indices` that are independent of `iv`.
+///
+/// Prerequisites (inherited from `isAccessInvariant` above):
+///   1. `iv` and `indices` of the proper type;
+///   2. at most one affine.apply is reachable from each index in `indices`;
+///
+/// Emits a note if it encounters a chain of affine.apply and conservatively
+///  those cases.
+llvm::DenseSet<Value *, llvm::DenseMapInfo<Value *>>
+getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices);
+
+using VectorizableLoopFun = std::function<bool(AffineForOp)>;
+
+/// Checks whether the loop is structurally vectorizable; i.e.:
+///   1. no conditionals are nested under the loop;
+///   2. all nested load/stores are to scalar MemRefs.
+/// TODO(ntv): relax the no-conditionals restriction
+bool isVectorizableLoopBody(AffineForOp loop);
+
+/// Checks whether the loop is structurally vectorizable and that all the LoadOp
+/// and StoreOp matched have access indexing functions that are are either:
+///   1. invariant along the loop induction variable created by 'loop';
+///   2. varying along at most one memory dimension. If such a unique dimension
+///      is found, it is written into `memRefDim`.
+bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim);
+
+/// Checks where SSA dominance would be violated if a for op's body
+/// operations are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence
+// violation when we have the support.
+bool isInstwiseShiftValid(AffineForOp forOp, llvm::ArrayRef<uint64_t> shifts);
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
diff --git a/third_party/mlir/include/mlir/Analysis/NestedMatcher.h b/third_party/mlir/include/mlir/Analysis/NestedMatcher.h
new file mode 100644
index 0000000..b07b73a
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -0,0 +1,193 @@
+//===- NestedMacher.h - Nested matcher for MLFunction -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
+#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+
+struct NestedPattern;
+class Operation;
+
+/// An NestedPattern captures nested patterns in the IR.
+/// It is used in conjunction with a scoped NestedPatternContext which is an
+/// llvm::BumpPtrAllocator that handles memory allocations efficiently and
+/// avoids ownership issues.
+///
+/// In order to use NestedPatterns, first create a scoped context.
+/// When the context goes out of scope, everything is freed.
+/// This design simplifies the API by avoiding references to the context and
+/// makes it clear that references to matchers must not escape.
+///
+/// Example:
+///   {
+///      NestedPatternContext context;
+///      auto gemmLike = Doall(Doall(Red(LoadStores())));
+///      auto matches = gemmLike.match(f);
+///      // do work on matches
+///   }  // everything is freed
+///
+///
+/// Nested abstraction for matching results.
+/// Provides access to the nested Operation* captured by a Matcher.
+///
+/// A NestedMatch contains an Operation* and the children NestedMatch and is
+/// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose
+/// lifetime is managed by an RAII NestedPatternContext.
+struct NestedMatch {
+  static NestedMatch build(Operation *operation,
+                           ArrayRef<NestedMatch> nestedMatches);
+  NestedMatch(const NestedMatch &) = default;
+  NestedMatch &operator=(const NestedMatch &) = default;
+
+  explicit operator bool() { return matchedOperation != nullptr; }
+
+  Operation *getMatchedOperation() { return matchedOperation; }
+  ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }
+
+private:
+  friend struct NestedPattern;
+  friend struct NestedPatternContext;
+
+  /// Underlying global bump allocator managed by a NestedPatternContext.
+  static llvm::BumpPtrAllocator *&allocator();
+
+  NestedMatch() = default;
+
+  /// Payload, holds a NestedMatch and all its children along this branch.
+  Operation *matchedOperation;
+  ArrayRef<NestedMatch> matchedChildren;
+};
+
+/// A NestedPattern is a nested operation walker that:
+///   1. recursively matches a substructure in the tree;
+///   2. uses a filter function to refine matches with extra semantic
+///      constraints (passed via a lambda of type FilterFunctionType);
+///   3. TODO(ntv) optionally applies actions (lambda).
+///
+/// Nested patterns are meant to capture imperfectly nested loops while matching
+/// properties over the whole loop nest. For instance, in vectorization we are
+/// interested in capturing all the imperfectly nested loops of a certain type
+/// and such that all the load and stores have certain access patterns along the
+/// loops' induction variables). Such NestedMatches are first captured using the
+/// `match` function and are later processed to analyze properties and apply
+/// transformations in a non-greedy way.
+///
+/// The NestedMatches captured in the IR can grow large, especially after
+/// aggressive unrolling. As experience has shown, it is generally better to use
+/// a plain walk over operations to match flat patterns but the current
+/// implementation is competitive nonetheless.
+using FilterFunctionType = std::function<bool(Operation &)>;
+inline bool defaultFilterFunction(Operation &) { return true; }
+struct NestedPattern {
+  NestedPattern(ArrayRef<NestedPattern> nested,
+                FilterFunctionType filter = defaultFilterFunction);
+  NestedPattern(const NestedPattern &) = default;
+  NestedPattern &operator=(const NestedPattern &) = default;
+
+  /// Returns all the top-level matches in `func`.
+  void match(FuncOp func, SmallVectorImpl<NestedMatch> *matches) {
+    func.walk([&](Operation *op) { matchOne(op, matches); });
+  }
+
+  /// Returns all the top-level matches in `op`.
+  void match(Operation *op, SmallVectorImpl<NestedMatch> *matches) {
+    op->walk([&](Operation *child) { matchOne(child, matches); });
+  }
+
+  /// Returns the depth of the pattern.
+  unsigned getDepth() const;
+
+private:
+  friend struct NestedPatternContext;
+  friend struct NestedMatch;
+  friend struct State;
+
+  /// Underlying global bump allocator managed by a NestedPatternContext.
+  static llvm::BumpPtrAllocator *&allocator();
+
+  /// Matches this pattern against a single `op` and fills matches with the
+  /// result.
+  void matchOne(Operation *op, SmallVectorImpl<NestedMatch> *matches);
+
+  /// Nested patterns to be matched.
+  ArrayRef<NestedPattern> nestedPatterns;
+
+  /// Extra filter function to apply to prune patterns as the IR is walked.
+  FilterFunctionType filter;
+
+  /// skip is an implementation detail needed so that we can implement match
+  /// without switching on the type of the Operation. The idea is that a
+  /// NestedPattern first checks if it matches locally and then recursively
+  /// applies its nested matchers to its elem->nested. Since we want to rely on
+  /// the existing operation walking functionality rather than duplicate
+  /// it, we allow an off-by-one traversal to account for the fact that we
+  /// write:
+  ///
+  ///  void match(Operation *elem) {
+  ///    for (auto &c : getNestedPatterns()) {
+  ///      NestedPattern childPattern(...);
+  ///                                  ^~~~ Needs off-by-one skip.
+  ///
+  Operation *skip;
+};
+
+/// RAII structure to transparently manage the bump allocator for
+/// NestedPattern and NestedMatch classes. This avoids passing a context to
+/// all the API functions.
+struct NestedPatternContext {
+  NestedPatternContext() {
+    assert(NestedMatch::allocator() == nullptr &&
+           "Only a single NestedPatternContext is supported");
+    assert(NestedPattern::allocator() == nullptr &&
+           "Only a single NestedPatternContext is supported");
+    NestedMatch::allocator() = &allocator;
+    NestedPattern::allocator() = &allocator;
+  }
+  ~NestedPatternContext() {
+    NestedMatch::allocator() = nullptr;
+    NestedPattern::allocator() = nullptr;
+  }
+  llvm::BumpPtrAllocator allocator;
+};
+
+namespace matcher {
+// Syntactic sugar NestedPattern builder functions.
+NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
+NestedPattern If(NestedPattern child);
+NestedPattern If(FilterFunctionType filter, NestedPattern child);
+NestedPattern If(ArrayRef<NestedPattern> nested = {});
+NestedPattern If(FilterFunctionType filter,
+                 ArrayRef<NestedPattern> nested = {});
+NestedPattern For(NestedPattern child);
+NestedPattern For(FilterFunctionType filter, NestedPattern child);
+NestedPattern For(ArrayRef<NestedPattern> nested = {});
+NestedPattern For(FilterFunctionType filter,
+                  ArrayRef<NestedPattern> nested = {});
+
+bool isParallelLoop(Operation &op);
+bool isReductionLoop(Operation &op);
+bool isLoadOrStore(Operation &op);
+
+} // end namespace matcher
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
diff --git a/third_party/mlir/include/mlir/Analysis/Passes.h b/third_party/mlir/include/mlir/Analysis/Passes.h
new file mode 100644
index 0000000..9eafcd3
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/Passes.h
@@ -0,0 +1,43 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors in the
+// analysis library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PASSES_H
+#define MLIR_ANALYSIS_PASSES_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class FunctionPassBase;
+
+/// Creates a pass to check memref accesses in a Function.
+FunctionPassBase *createMemRefBoundCheckPass();
+
+/// Creates a pass to check memref access dependences in a Function.
+FunctionPassBase *createTestMemRefDependenceCheckPass();
+
+/// Creates a pass to test parallelism detection; emits note for parallel loops.
+FunctionPassBase *createParallelismDetectionTestPass();
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_PASSES_H
diff --git a/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h b/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h
new file mode 100644
index 0000000..ad6b653
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -0,0 +1,215 @@
+//===- SliceAnalysis.h - Analysis for Transitive UseDef chains --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_ANALYSIS_SLICEANALYSIS_H_
+#define MLIR_ANALYSIS_SLICEANALYSIS_H_
+
+#include <functional>
+#include <vector>
+
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+class Operation;
+
+/// Type of the condition to limit the propagation of transitive use-defs.
+/// This can be used in particular to limit the propagation to a given Scope or
+/// to avoid passing through certain types of operation in a configurable
+/// manner.
+using TransitiveFilter = std::function<bool(Operation *)>;
+
+/// Fills `forwardSlice` with the computed forward slice (i.e. all
+/// the transitive uses of op), **without** including that operation.
+///
+/// This additionally takes a TransitiveFilter which acts as a frontier:
+/// when looking at uses transitively, a operation that does not pass the
+/// filter is never propagated through. This allows in particular to carve out
+/// the scope within a ForInst or the scope within an IfInst.
+///
+/// The implementation traverses the use chains in postorder traversal for
+/// efficiency reasons: if a operation is already in `forwardSlice`, no
+/// need to traverse its uses again. Since use-def chains form a DAG, this
+/// terminates.
+///
+/// Upon return to the root call, `forwardSlice` is filled with a
+/// postorder list of uses (i.e. a reverse topological order). To get a proper
+/// topological order, we just just reverse the order in `forwardSlice` before
+/// returning.
+///
+/// Example starting from node 0
+/// ============================
+///
+///               0
+///    ___________|___________
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
+///    contain:
+///      {9, 7, 8, 5, 1, 2, 6, 3, 4}
+/// 2. reversing the result of 1. gives:
+///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
+///
+void getForwardSlice(
+    Operation *op, llvm::SetVector<Operation *> *forwardSlice,
+    TransitiveFilter filter = /* pass-through*/
+    [](Operation *) { return true; });
+
+/// Fills `backwardSlice` with the computed backward slice (i.e.
+/// all the transitive defs of op), **without** including that operation.
+///
+/// This additionally takes a TransitiveFilter which acts as a frontier:
+/// when looking at defs transitively, a operation that does not pass the
+/// filter is never propagated through. This allows in particular to carve out
+/// the scope within a ForInst or the scope within an IfInst.
+///
+/// The implementation traverses the def chains in postorder traversal for
+/// efficiency reasons: if a operation is already in `backwardSlice`, no
+/// need to traverse its definitions again. Since useuse-def chains form a DAG,
+/// this terminates.
+///
+/// Upon return to the root call, `backwardSlice` is filled with a
+/// postorder list of defs. This happens to be a topological order, from the
+/// point of view of the use-def chains.
+///
+/// Example starting from node 8
+/// ============================
+///
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+///    {1, 2, 5, 3, 4, 6}
+///
+void getBackwardSlice(
+    Operation *op, llvm::SetVector<Operation *> *backwardSlice,
+    TransitiveFilter filter = /* pass-through*/
+    [](Operation *) { return true; });
+
+/// Iteratively computes backward slices and forward slices until
+/// a fixed point is reached. Returns an `llvm::SetVector<Operation *>` which
+/// **includes** the original operation.
+///
+/// This allows building a slice (i.e. multi-root DAG where everything
+/// that is reachable from an Value in forward and backward direction is
+/// contained in the slice).
+/// This is the abstraction we need to materialize all the operations for
+/// supervectorization without worrying about orderings and Value
+/// replacements.
+///
+/// Example starting from any node
+/// ==============================
+///
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |   |
+///    |   5             6___|
+///    |___|_____________|   |
+///      |               |   |
+///      7               8   |
+///      |_______________|   |
+///              |           |
+///              9          10
+///
+/// Return the whole DAG in some topological order.
+///
+/// The implementation works by just filling up a worklist with iterative
+/// alternate calls to `getBackwardSlice` and `getForwardSlice`.
+///
+/// The following section describes some additional implementation
+/// considerations for a potentially more efficient implementation but they are
+/// just an intuition without proof, we still use a worklist for now.
+///
+/// Additional implementation considerations
+/// ========================================
+/// Consider the defs-op-uses hourglass.
+///    ____
+///    \  /  defs (in some topological order)
+///     \/
+///     op
+///     /\
+///    /  \  uses (in some topological order)
+///   /____\
+///
+/// We want to iteratively apply `getSlice` to construct the whole
+/// list of Operation that are reachable by (use|def)+ from op.
+/// We want the resulting slice in topological order.
+/// Ideally we would like the ordering to be maintained in-place to avoid
+/// copying Operation at each step. Keeping this ordering by construction
+/// seems very unclear, so we list invariants in the hope of seeing whether
+/// useful properties pop up.
+///
+/// In the following:
+///   we use |= for set inclusion;
+///   we use << for set topological ordering (i.e. each pair is ordered).
+///
+/// Assumption:
+/// ===========
+/// We wish to maintain the following property by a recursive argument:
+///   """
+///      defs << {op} <<uses are in topological order.
+///   """
+/// The property clearly holds for 0 and 1-sized uses and defs;
+///
+/// Invariants:
+///   2. defs and uses are in topological order internally, by construction;
+///   3. for any {x} |= defs, defs(x) |= defs;    because all go through op
+///   4. for any {x} |= uses,    defs |= defs(x); because all go through op
+///   5. for any {x} |= defs,    uses |= uses(x); because all go through op
+///   6. for any {x} |= uses, uses(x) |= uses;    because all go through op
+///
+/// Intuitively, we should be able to recurse like:
+///   preorder(defs) - op - postorder(uses)
+/// and keep things ordered but this is still hand-wavy and not worth the
+/// trouble for now: punt to a simple worklist-based solution.
+///
+llvm::SetVector<Operation *> getSlice(
+    Operation *op,
+    TransitiveFilter backwardFilter = /* pass-through*/
+    [](Operation *) { return true; },
+    TransitiveFilter forwardFilter = /* pass-through*/
+    [](Operation *) { return true; });
+
+/// Multi-root DAG topological sort.
+/// Performs a topological sort of the Operation in the `toSort` SetVector.
+/// Returns a topologically sorted SetVector.
+llvm::SetVector<Operation *>
+topologicalSort(const llvm::SetVector<Operation *> &toSort);
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_SLICEANALYSIS_H_
diff --git a/third_party/mlir/include/mlir/Analysis/Utils.h b/third_party/mlir/include/mlir/Analysis/Utils.h
new file mode 100644
index 0000000..b012cc1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/Utils.h
@@ -0,0 +1,304 @@
+//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_UTILS_H
+#define MLIR_ANALYSIS_UTILS_H
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+
+namespace mlir {
+
+class AffineForOp;
+class Block;
+class FlatAffineConstraints;
+class Location;
+struct MemRefAccess;
+class Operation;
+class Value;
+
+/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
+/// the outermost 'affine.for' operation to the innermost one.
+//  TODO(bondhugula): handle 'affine.if' ops.
+void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
+
+/// Returns the nesting depth of this operation, i.e., the number of loops
+/// surrounding this operation.
+unsigned getNestingDepth(Operation &op);
+
+/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
+/// at 'forOp'.
+void getSequentialLoops(AffineForOp forOp,
+                        llvm::SmallDenseSet<Value *, 8> *sequentialLoops);
+
+/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
+/// associated operands for a set of loops within a loop nest (typically the
+/// set of loops surrounding a store operation). Loop bound AffineMaps which
+/// are non-null represent slices of that loop's iteration space.
+struct ComputationSliceState {
+  // List of sliced loop IVs (ordered from outermost to innermost).
+  // EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'.
+  SmallVector<Value *, 4> ivs;
+  // List of lower bound AffineMaps.
+  SmallVector<AffineMap, 4> lbs;
+  // List of upper bound AffineMaps.
+  SmallVector<AffineMap, 4> ubs;
+  // List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
+  std::vector<SmallVector<Value *, 4>> lbOperands;
+  // List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
+  std::vector<SmallVector<Value *, 4>> ubOperands;
+  // Slice loop nest insertion point in target loop nest.
+  Block::iterator insertPoint;
+  // Adds to 'cst' with constraints which represent the slice bounds on 'ivs'
+  // in 'this'. Specifically, the values in 'ivs' are added to 'cst' as dim
+  // identifiers and the values in 'lb/ubOperands' are added as symbols.
+  // Constraints are added for all loop IV bounds (dim or symbol), and
+  // constraints are added for slice bounds in 'lbs'/'ubs'.
+  // Returns failure if we cannot add loop bounds because of unsupported cases.
+  LogicalResult getAsConstraints(FlatAffineConstraints *cst);
+
+  // Clears all bounds and operands in slice state.
+  void clearBounds();
+};
+
+/// Computes the computation slice loop bounds for one loop nest as affine maps
+/// of the other loop nest's IVs and symbols, using 'dependenceConstraints'
+/// computed between 'depSourceAccess' and 'depSinkAccess'.
+/// If 'isBackwardSlice' is true, a backwards slice is computed in which the
+/// slice bounds of loop nest surrounding 'depSourceAccess' are computed in
+/// terms of loop IVs and symbols of the loop nest surrounding 'depSinkAccess'
+/// at 'loopDepth'.
+/// If 'isBackwardSlice' is false, a forward slice is computed in which the
+/// slice bounds of loop nest surrounding 'depSinkAccess' are computed in terms
+/// of loop IVs and symbols of the loop nest surrounding 'depSourceAccess' at
+/// 'loopDepth'.
+/// The slice loop bounds and associated operands are returned in 'sliceState'.
+//
+//  Backward slice example:
+//
+//    affine.for %i0 = 0 to 10 {
+//      affine.store %cst, %0[%i0] : memref<100xf32>  // 'depSourceAccess'
+//    }
+//    affine.for %i1 = 0 to 10 {
+//      %v = affine.load %0[%i1] : memref<100xf32>    // 'depSinkAccess'
+//    }
+//
+//    // Backward computation slice of loop nest '%i0'.
+//    affine.for %i0 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 1)(%i1) {
+//      affine.store %cst, %0[%i0] : memref<100xf32>  // 'depSourceAccess'
+//    }
+//
+//  Forward slice example:
+//
+//    affine.for %i0 = 0 to 10 {
+//      affine.store %cst, %0[%i0] : memref<100xf32>  // 'depSourceAccess'
+//    }
+//    affine.for %i1 = 0 to 10 {
+//      %v = affine.load %0[%i1] : memref<100xf32>    // 'depSinkAccess'
+//    }
+//
+//    // Forward computation slice of loop nest '%i1'.
+//    affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) {
+//      %v = affine.load %0[%i1] : memref<100xf32>    // 'depSinkAccess'
+//    }
+//
+void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
+                              FlatAffineConstraints *dependenceConstraints,
+                              unsigned loopDepth, bool isBackwardSlice,
+                              ComputationSliceState *sliceState);
+
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
+/// The parameter 'numCommonLoops' is the number of loops common to the
+/// operations in 'opsA' and 'opsB'.
+/// If 'isBackwardSlice' is true, computes slice bounds for loop nest
+/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest
+/// surrounding ops in 'opsB' at 'loopDepth'.
+/// If 'isBackwardSlice' is false, computes slice bounds for loop nest
+/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest
+/// surrounding ops in 'opsA' at 'loopDepth'.
+/// Returns 'success' if union was computed, 'failure' otherwise.
+// TODO(andydavis) Change this API to take 'forOpA'/'forOpB'.
+LogicalResult computeSliceUnion(ArrayRef<Operation *> opsA,
+                                ArrayRef<Operation *> opsB, unsigned loopDepth,
+                                unsigned numCommonLoops, bool isBackwardSlice,
+                                ComputationSliceState *sliceUnion);
+
+/// Creates a clone of the computation contained in the loop nest surrounding
+/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
+/// in 'sliceState', and inserts the computation slice at the beginning of the
+/// operation block of the loop at 'dstLoopDepth' in the loop nest surrounding
+/// 'dstOpInst'. Returns the top-level loop of the computation slice on
+/// success, returns nullptr otherwise.
+// Loop depth is a crucial optimization choice that determines where to
+// materialize the results of the backward slice - presenting a trade-off b/w
+// storage and redundant computation in several cases.
+// TODO(andydavis) Support computation slices with common surrounding loops.
+AffineForOp insertBackwardComputationSlice(Operation *srcOpInst,
+                                           Operation *dstOpInst,
+                                           unsigned dstLoopDepth,
+                                           ComputationSliceState *sliceState);
+
+/// A region of a memref's data space; this is typically constructed by
+/// analyzing load/store op's on this memref and the index space of loops
+/// surrounding such op's.
+// For example, the memref region for a load operation at loop depth = 1:
+//
+//    affine.for %i = 0 to 32 {
+//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
+//        affine.load %A[%ii]
+//      }
+//    }
+//
+// Region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
+// The last field is a 2-d FlatAffineConstraints symbolic in %i.
+//
+struct MemRefRegion {
+  explicit MemRefRegion(Location loc) : loc(loc) {}
+
+  /// Computes the memory region accessed by this memref with the region
+  /// represented as constraints symbolic/parameteric in 'loopDepth' loops
+  /// surrounding opInst. The computed region's 'cst' field has exactly as many
+  /// dimensional identifiers as the rank of the memref, and *potentially*
+  /// additional symbolic identifiers which could include any of the loop IVs
+  /// surrounding opInst up until 'loopDepth' and another additional Function
+  /// symbols involved with the access (for eg., those appear in affine.apply's,
+  /// loop bounds, etc.). If 'sliceState' is non-null, operands from
+  /// 'sliceState' are added as symbols, and the following constraints are added
+  /// to the system:
+  /// *) Inequality constraints which represent loop bounds for 'sliceState'
+  ///    operands which are loop IVS (these represent the destination loop IVs
+  ///    of the slice, and are added as symbols to MemRefRegion's constraint
+  ///    system).
+  /// *) Inequality constraints for the slice bounds in 'sliceState', which
+  ///    represent the bounds on the loop IVs in this constraint system w.r.t
+  ///    to slice operands (which correspond to symbols).
+  /// If 'addMemRefDimBounds' is true, constant upper/lower bounds
+  /// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'.
+  ///
+  ///  For example, the memref region for this operation at loopDepth = 1 will
+  ///  be:
+  ///
+  ///    affine.for %i = 0 to 32 {
+  ///      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
+  ///        load %A[%ii]
+  ///      }
+  ///    }
+  ///
+  ///   {memref = %A, write = false, {%i <= m0 <= %i + 7} }
+  /// The last field is a 2-d FlatAffineConstraints symbolic in %i.
+  ///
+  LogicalResult compute(Operation *op, unsigned loopDepth,
+                        ComputationSliceState *sliceState = nullptr,
+                        bool addMemRefDimBounds = true);
+
+  FlatAffineConstraints *getConstraints() { return &cst; }
+  const FlatAffineConstraints *getConstraints() const { return &cst; }
+  bool isWrite() const { return write; }
+  void setWrite(bool flag) { write = flag; }
+
+  /// Returns a constant upper bound on the number of elements in this region if
+  /// bounded by a known constant (always possible for static shapes), None
+  /// otherwise. Note that the symbols of the region are treated specially,
+  /// i.e., the returned bounding constant holds for *any given* value of the
+  /// symbol identifiers. The 'shape' vector is set to the corresponding
+  /// dimension-wise bounds major to minor. We use int64_t instead of uint64_t
+  /// since index types can be at most int64_t.
+  Optional<int64_t> getConstantBoundingSizeAndShape(
+      SmallVectorImpl<int64_t> *shape = nullptr,
+      std::vector<SmallVector<int64_t, 4>> *lbs = nullptr,
+      SmallVectorImpl<int64_t> *lbDivisors = nullptr) const;
+
+  /// A wrapper around FlatAffineConstraints::getConstantBoundOnDimSize(). 'pos'
+  /// corresponds to the position of the memref shape's dimension (major to
+  /// minor) which matches 1:1 with the dimensional identifier positions in
+  //'cst'.
+  Optional<int64_t>
+  getConstantBoundOnDimSize(unsigned pos,
+                            SmallVectorImpl<int64_t> *lb = nullptr,
+                            int64_t *lbFloorDivisor = nullptr) const {
+    assert(pos < getRank() && "invalid position");
+    return cst.getConstantBoundOnDimSize(pos, lb);
+  }
+
+  /// Returns the size of this MemRefRegion in bytes.
+  Optional<int64_t> getRegionSize();
+
+  // Wrapper around FlatAffineConstraints::unionBoundingBox.
+  LogicalResult unionBoundingBox(const MemRefRegion &other);
+
+  /// Returns the rank of the memref that this region corresponds to.
+  unsigned getRank() const;
+
+  /// Memref that this region corresponds to.
+  Value *memref;
+
+  /// Read or write.
+  bool write;
+
+  /// If there is more than one load/store op associated with the region, the
+  /// location information would correspond to one of those op's.
+  Location loc;
+
+  /// Region (data space) of the memref accessed. This set will thus have at
+  /// least as many dimensional identifiers as the shape dimensionality of the
+  /// memref, and these are the leading dimensions of the set appearing in that
+  /// order (major to minor / outermost to innermost). There may be additional
+  /// identifiers since getMemRefRegion() is called with a specific loop depth,
+  /// and thus the region is symbolic in the outer surrounding loops at that
+  /// depth.
+  // TODO(bondhugula): Replace this to exploit HyperRectangularSet.
+  FlatAffineConstraints cst;
+};
+
+/// Returns the size of memref data in bytes if it's statically shaped, None
+/// otherwise.
+Optional<uint64_t> getMemRefSizeInBytes(MemRefType memRefType);
+
+/// Checks a load or store op for an out of bound access; returns failure if the
+/// access is out of bounds along any of the dimensions, success otherwise.
+/// Emits a diagnostic error (with location information) if emitError is true.
+template <typename LoadOrStoreOpPointer>
+LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
+                                      bool emitError = true);
+
+/// Returns the number of surrounding loops common to both A and B.
+unsigned getNumCommonSurroundingLoops(Operation &A, Operation &B);
+
+/// Gets the memory footprint of all data touched in the specified memory space
+/// in bytes; if the memory space is unspecified, considers all memory spaces.
+Optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
+                                          int memorySpace = -1);
+
+/// Returns true if `forOp' is a parallel loop.
+bool isLoopParallel(AffineForOp forOp);
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_UTILS_H
diff --git a/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h b/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h
new file mode 100644
index 0000000..8b9992d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h
@@ -0,0 +1,143 @@
+//===- VectorAnalysis.h - Analysis for Vectorization -------*- C++ -*-=======//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_ANALYSIS_VECTORANALYSIS_H_
+#define MLIR_ANALYSIS_VECTORANALYSIS_H_
+
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+
+class AffineApplyOp;
+class AffineForOp;
+class AffineMap;
+class Location;
+class MemRefType;
+class OpBuilder;
+class Operation;
+class Value;
+class VectorType;
+
+/// Computes and returns the multi-dimensional ratio of `superShape` to
+/// `subShape`. This is calculated by performing a traversal from minor to major
+/// dimensions (i.e. in reverse shape order). If integral division is not
+/// possible, returns None.
+/// The ArrayRefs are assumed (and enforced) to only contain > 1 values.
+/// This constraint comes from the fact that they are meant to be used with
+/// VectorTypes, for which the property holds by construction.
+///
+/// Examples:
+///   - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
+///   - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
+///   - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
+llvm::Optional<llvm::SmallVector<unsigned, 4>>
+shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape);
+
+/// Computes and returns the multi-dimensional ratio of the shapes of
+/// `superVector` to `subVector`. If integral division is not possible, returns
+/// None.
+/// Assumes and enforces that the VectorTypes have the same elemental type.
+llvm::Optional<llvm::SmallVector<unsigned, 4>>
+shapeRatio(VectorType superVectorType, VectorType subVectorType);
+
+/// Constructs a permutation map of invariant memref indices to vector
+/// dimension.
+///
+/// If no index is found to be invariant, 0 is added to the permutation_map and
+/// corresponds to a vector broadcast along that dimension.
+///
+/// The implementation uses the knowledge of the mapping of loops to
+/// vector dimension. `loopToVectorDim` carries this information as a map with:
+///   - keys representing "vectorized enclosing loops";
+///   - values representing the corresponding vector dimension.
+/// Note that loopToVectorDim is a whole function map from which only enclosing
+/// loop information is extracted.
+///
+/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
+/// most one invariant index along each AffineForOp of `loopToVectorDim`).
+///
+/// Example 1:
+/// The following MLIR snippet:
+///
+/// ```mlir
+///    affine.for %i3 = 0 to %0 {
+///      affine.for %i4 = 0 to %1 {
+///        affine.for %i5 = 0 to %2 {
+///          %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
+///    }}}
+/// ```
+///
+/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
+///
+/// ```mlir
+///    affine.for %i3 = 0 to %0 step 32 {
+///      affine.for %i4 = 0 to %1 {
+///        affine.for %i5 = 0 to %2 step 256 {
+///          %4 = vector.transfer_read %arg0, %i4, %i5, %i3
+///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+///               (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
+///    }}}
+/// ```
+///
+/// Meaning that vector.transfer_read will be responsible for reading the slice:
+/// `%arg0[%i4, %i5:%15+256, %i3:%i3+32]` into vector<32x256xf32>.
+///
+/// Example 2:
+/// The following MLIR snippet:
+///
+/// ```mlir
+///    %cst0 = constant 0 : index
+///    affine.for %i0 = 0 to %0 {
+///      %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
+///    }
+/// ```
+///
+/// may vectorize with {permutation_map: (d0) -> (0)} into:
+///
+/// ```mlir
+///    affine.for %i0 = 0 to %0 step 128 {
+///      %3 = vector.transfer_read %arg0, %c0_0, %c0_0
+///           {permutation_map: (d0, d1) -> (0)} :
+///           (memref<?x?xf32>, index, index) -> vector<128xf32>
+///    }
+/// ````
+///
+/// Meaning that vector.transfer_read will be responsible of reading the slice
+/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
+///
+AffineMap makePermutationMap(
+    Operation *op, ArrayRef<Value *> indices,
+    const llvm::DenseMap<Operation *, unsigned> &loopToVectorDim);
+
+namespace matcher {
+
+/// Matches vector.transfer_read, vector.transfer_write and ops that return a
+/// vector type that is a multiple of the sub-vector type. This allows passing
+/// over other smaller vector types in the function and avoids interfering with
+/// operations on those.
+/// This is a first approximation, it can easily be extended in the future.
+/// TODO(ntv): this could all be much simpler if we added a bit that a vector
+/// type to mark that a vector is a strict super-vector but it still does not
+/// warrant adding even 1 extra bit in the IR for now.
+bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
+
+} // end namespace matcher
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_VECTORANALYSIS_H_
diff --git a/third_party/mlir/include/mlir/Analysis/Verifier.h b/third_party/mlir/include/mlir/Analysis/Verifier.h
new file mode 100644
index 0000000..daaff57
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/Verifier.h
@@ -0,0 +1,31 @@
+//===- Verifier.h - Verifier analysis for MLIR structures -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_ANALYSIS_VERIFIER_H
+#define MLIR_ANALYSIS_VERIFIER_H
+
+namespace mlir {
+struct LogicalResult;
+class Operation;
+
+/// Perform (potentially expensive) checks of invariants, used to detect
+/// compiler bugs, on this operation and any nested operations. On error, this
+/// reports the error through the MLIRContext and returns failure.
+LogicalResult verify(Operation *op);
+} //  end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/CMakeLists.txt b/third_party/mlir/include/mlir/CMakeLists.txt
new file mode 100644
index 0000000..202b40b
--- /dev/null
+++ b/third_party/mlir/include/mlir/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_subdirectory(AffineOps)
+add_subdirectory(Dialect)
+add_subdirectory(EDSC)
+add_subdirectory(Linalg)
+add_subdirectory(LLVMIR)
+add_subdirectory(StandardOps)
+add_subdirectory(VectorOps)
diff --git a/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
new file mode 100644
index 0000000..78e4356
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
@@ -0,0 +1,45 @@
+//===- ConvertControlFlowToCFG.h - Pass entrypoint --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
+#define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
+
+#include <memory>
+#include <vector>
+
+namespace mlir {
+class FuncOp;
+class FunctionPassBase;
+struct LogicalResult;
+class MLIRContext;
+class RewritePattern;
+
+// Owning list of rewriting patterns.
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to lower from loop.for, loop.if, and
+/// loop.terminator to CFG operations within the Standard dialect, in particular
+/// convert structured control flow into CFG branch-based control flow.
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx);
+
+/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
+FunctionPassBase *createConvertToCFGPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
new file mode 100644
index 0000000..b19fb53
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
@@ -0,0 +1,58 @@
+//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
+#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace mlir {
+
+class ModulePassBase;
+class FuncOp;
+
+using OwnedCubin = std::unique_ptr<std::vector<char>>;
+using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
+
+/// Creates a pass to convert kernel functions into CUBIN blobs.
+///
+/// This transformation takes the body of each function that is annotated with
+/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the
+/// module with help of the nvptx backend to PTX and then invokes the provided
+/// cubinGenerator to produce a binary blob (the cubin). Such blob is then
+/// attached as a string attribute named 'nvvm.cubin' to the kernel function.
+/// After the transformation, the body of the kernel function is removed (i.e.,
+/// it is turned into a declaration).
+ModulePassBase *
+createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
+
+/// Creates a pass to convert a gpu.launch_func operation into a sequence of
+/// CUDA calls.
+///
+/// This pass does not generate code to call CUDA directly but instead uses a
+/// small wrapper library that exports a stable and conveniently typed ABI
+/// ontop of CUDA.
+ModulePassBase *createConvertGpuLaunchFuncToCudaCallsPass();
+
+/// Creates a pass to augment a module with getter functions for all contained
+/// cubins as encoded via the 'nvvm.cubin' attribute.
+ModulePassBase *createGenerateCubinAccessorPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
new file mode 100644
index 0000000..b53549f
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -0,0 +1,28 @@
+//===- GPUToNVMMPass.h - Convert GPU kernel to NVVM dialect -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
+
+namespace mlir {
+struct FunctionPassBase;
+
+/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
+FunctionPassBase *createLowerGpuOpsToNVVMOpsPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
new file mode 100644
index 0000000..973b995
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
@@ -0,0 +1,57 @@
+//===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
+#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
+
+namespace mlir {
+class AffineForOp;
+struct LogicalResult;
+
+namespace loop {
+class ForOp;
+} // end namespace loop
+
+/// Convert a perfect affine loop nest with the outermost loop identified by
+/// `forOp` into a gpu::Launch operation.  Map `numBlockDims` outer loops to
+/// GPU blocks and `numThreadDims` to GPU threads.  The bounds of the loops that
+/// are mapped should be independent of the induction variables of the other
+/// mapped loops.
+///
+/// No check on the size of the block or grid, or on the validity of
+/// parallelization is performed, it is under the responsibility of the caller
+/// to strip-mine the loops and to perform the dependence analysis before
+/// calling the conversion.
+LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
+                                               unsigned numBlockDims,
+                                               unsigned numThreadDims);
+
+/// Convert a perfect linalg loop nest with the outermost loop identified by
+/// `forOp` into a gpu::Launch operation.  Map `numBlockDims` outer loops to
+/// GPU blocks and `numThreadDims` to GPU threads.  The bounds of the loops that
+/// are mapped should be independent of the induction variables of the other
+/// mapped loops.
+///
+/// No check on the size of the block or grid, or on the validity of
+/// parallelization is performed, it is under the responsibility of the caller
+/// to strip-mine the loops and to perform the dependence analysis before
+/// calling the conversion.
+LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp,
+                                         unsigned numBlockDims,
+                                         unsigned numThreadDims);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
diff --git a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
new file mode 100644
index 0000000..52f0dd4
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
@@ -0,0 +1,35 @@
+//===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
+#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
+
+namespace mlir {
+class FunctionPassBase;
+
+/// Create a pass that converts loop nests into GPU kernels.  It considers
+/// top-level affine.for and linalg.for operations as roots of loop nests and
+/// converts them to the gpu.launch operations if possible.
+///
+/// No check on the size of the block or grid, or on the validity of
+/// parallelization is performed, it is under the responsibility of the caller
+/// to strip-mine the loops and to perform the dependence analysis before
+/// calling the conversion.
+FunctionPassBase *createSimpleLoopsToGPUPass(unsigned numBlockDims,
+                                             unsigned numThreadDims);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
new file mode 100644
index 0000000..d5c4c11
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -0,0 +1,129 @@
+//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a dialect conversion targeting the LLVM IR dialect.  By default, it
+// converts Standard ops and types and provides hooks for dialect-specific
+// extensions to the conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
+#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace llvm {
+class IntegerType;
+class LLVMContext;
+class Module;
+class Type;
+} // namespace llvm
+
+namespace mlir {
+namespace LLVM {
+class LLVMDialect;
+class LLVMType;
+} // namespace LLVM
+
+/// Conversion from types in the Standard dialect to the LLVM IR dialect.
+class LLVMTypeConverter : public TypeConverter {
+public:
+  using TypeConverter::convertType;
+
+  LLVMTypeConverter(MLIRContext *ctx);
+
+  /// Convert types to LLVM IR.  This calls `convertAdditionalType` to convert
+  /// non-standard or non-builtin types.
+  Type convertType(Type t) override;
+
+  /// Convert a non-empty list of types to be returned from a function into a
+  /// supported LLVM IR type.  In particular, if more than one values is
+  /// returned, create an LLVM IR structure type with elements that correspond
+  /// to each of the MLIR types converted with `convertType`.
+  Type packFunctionResults(ArrayRef<Type> types);
+
+  /// Returns the LLVM context.
+  llvm::LLVMContext &getLLVMContext();
+
+  /// Returns the LLVM dialect.
+  LLVM::LLVMDialect *getDialect() { return llvmDialect; }
+
+protected:
+  /// LLVM IR module used to parse/create types.
+  llvm::Module *module;
+  LLVM::LLVMDialect *llvmDialect;
+
+private:
+  Type convertStandardType(Type type);
+
+  // Convert a function type.  The arguments and results are converted one by
+  // one.  Additionally, if the function returns more than one value, pack the
+  // results into an LLVM IR structure type so that the converted function type
+  // returns at most one result.
+  Type convertFunctionType(FunctionType type);
+
+  // Convert the index type.  Uses llvmModule data layout to create an integer
+  // of the pointer bitwidth.
+  Type convertIndexType(IndexType type);
+
+  // Convert an integer type `i*` to `!llvm<"i*">`.
+  Type convertIntegerType(IntegerType type);
+
+  // Convert a floating point type: `f16` to `!llvm.half`, `f32` to
+  // `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
+  // by LLVM.
+  Type convertFloatType(FloatType type);
+
+  // Convert a memref type into an LLVM type that captures the relevant data.
+  // For statically-shaped memrefs, the resulting type is a pointer to the
+  // (converted) memref element type. For dynamically-shaped memrefs, the
+  // resulting type is an LLVM structure type that contains:
+  //   1. a pointer to the (converted) memref element type
+  //   2. as many index types as memref has dynamic dimensions.
+  Type convertMemRefType(MemRefType type);
+
+  // Convert a 1D vector type into an LLVM vector type.
+  Type convertVectorType(VectorType type);
+
+  // Get the LLVM representation of the index type based on the bitwidth of the
+  // pointer as defined by the data layout of the module.
+  LLVM::LLVMType getIndexType();
+
+  // Wrap the given LLVM IR type into an LLVM IR dialect type.
+  Type wrap(llvm::Type *llvmType);
+
+  // Extract an LLVM IR dialect type.
+  LLVM::LLVMType unwrap(Type type);
+};
+
+/// Base class for operation conversions targeting the LLVM IR dialect. Provides
+/// conversion patterns with an access to the containing LLVMLowering for the
+/// purpose of type conversions.
+class LLVMOpLowering : public ConversionPattern {
+public:
+  LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+                 LLVMTypeConverter &lowering);
+
+protected:
+  // Back-reference to the lowering class, used to call type and function
+  // conversions accounting for potential extensions.
+  LLVMTypeConverter &lowering;
+};
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
new file mode 100644
index 0000000..941e382
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -0,0 +1,92 @@
+//===- ConvertStandardToLLVMPass.h - Pass entrypoint ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
+#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
+
+#include "llvm/ADT/STLExtras.h"
+#include <memory>
+#include <vector>
+
+namespace llvm {
+class Module;
+} // namespace llvm
+
+namespace mlir {
+class DialectConversion;
+class FuncOp;
+class LLVMTypeConverter;
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class ModulePassBase;
+class RewritePattern;
+class Type;
+
+// Owning list of rewriting patterns.
+class OwningRewritePatternList;
+
+/// Type for a callback constructing the owning list of patterns for the
+/// conversion to the LLVMIR dialect.  The callback is expected to append
+/// patterns to the owning list provided as the second argument.
+using LLVMPatternListFiller =
+    std::function<void(LLVMTypeConverter &, OwningRewritePatternList &)>;
+
+/// Type for a callback constructing the type converter for the conversion to
+/// the LLVMIR dialect.  The callback is expected to return an instance of the
+/// converter.
+using LLVMTypeConverterMaker =
+    std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
+
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
+
+/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
+ModulePassBase *createConvertToLLVMIRPass();
+
+/// Creates a pass to convert operations to the LLVMIR dialect.  The conversion
+/// is defined by a list of patterns and a type converter that will be obtained
+/// during the pass using the provided callbacks.
+ModulePassBase *
+createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller,
+                          LLVMTypeConverterMaker typeConverterMaker);
+
+/// Creates a pass to convert operations to the LLVMIR dialect.  The conversion
+/// is defined by a list of patterns obtained during the pass using the provided
+/// callback and an optional type conversion class, an instance is created
+/// during the pass.
+template <typename TypeConverter = LLVMTypeConverter>
+ModulePassBase *
+createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller) {
+  return createConvertToLLVMIRPass(patternListFiller, [](MLIRContext *context) {
+    return llvm::make_unique<TypeConverter>(context);
+  });
+}
+
+namespace LLVM {
+/// Make argument-taking successors of each block distinct.  PHI nodes in LLVM
+/// IR use the predecessor ID to identify which value to take.  They do not
+/// support different values coming from the same predecessor.  If a block has
+/// another block as a successor more than once with different values, insert
+/// a new dummy block for LLVM PHI nodes to tell the sources apart.
+void ensureDistinctSuccessors(ModuleOp m);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
new file mode 100644
index 0000000..21c2842
--- /dev/null
+++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -0,0 +1,103 @@
+//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides type converters and patterns to convert from standard types/ops to
+// SPIR-V types and operations. Also provides utilities and base classes to use
+// while targeting SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+namespace spirv {
+class SPIRVDialect;
+}
+
+/// Type conversion from Standard Types to SPIR-V Types.
+class SPIRVTypeConverter : public TypeConverter {
+public:
+  explicit SPIRVTypeConverter(MLIRContext *context);
+
+  /// Converts types to SPIR-V supported types.
+  Type convertType(Type t) override;
+
+protected:
+  spirv::SPIRVDialect *spirvDialect;
+};
+
+/// Converts a function type according to the requirements of a SPIR-V entry
+/// function. The arguments need to be converted to spv.Variables of spv.ptr
+/// types so that they could be bound by the runtime.
+class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
+public:
+  using SPIRVTypeConverter::SPIRVTypeConverter;
+
+  /// Method to convert argument of a function. The `type` is converted to
+  /// spv.ptr<type, Uniform>.
+  // TODO(ravishankarm) : Support other storage classes.
+  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
+                                    SignatureConversion &result) override;
+};
+
+/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
+public:
+  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
+                  SPIRVEntryFnTypeConverter &entryFnConverter)
+      : ConversionPattern(OpTy::getOperationName(), 1, context),
+        typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
+
+protected:
+  // Type lowering class.
+  SPIRVTypeConverter &typeConverter;
+
+  // Entry function signature converter.
+  SPIRVEntryFnTypeConverter &entryFnConverter;
+};
+
+/// Base Class for legalize a FuncOp within a spv.module. This class can be
+/// extended to implement a ConversionPattern to lower a FuncOp. It provides
+/// hooks to legalize a FuncOp as a simple function, or as an entry function.
+class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
+public:
+  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+protected:
+  /// Method to legalize the function as a non-entry function.
+  LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                              ConversionPatternRewriter &rewriter,
+                              FuncOp &newFuncOp) const;
+
+  /// Method to legalize the function as an entry function.
+  LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter,
+                                     FuncOp &newFuncOp) const;
+};
+
+/// Appends to a pattern list additional patterns for translating StandardOps to
+/// SPIR-V ops.
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
diff --git a/third_party/mlir/include/mlir/Dialect/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/CMakeLists.txt
new file mode 100644
index 0000000..5ae314a
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_subdirectory(FxpMathOps)
+add_subdirectory(GPU)
+add_subdirectory(LoopOps)
+add_subdirectory(QuantOps)
+add_subdirectory(SPIRV)
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
new file mode 100644
index 0000000..eaf72d2
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS FxpMathOps.td)
+mlir_tablegen(FxpMathOps.h.inc -gen-op-decls)
+mlir_tablegen(FxpMathOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRFxpMathOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
new file mode 100644
index 0000000..88a4234
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
@@ -0,0 +1,40 @@
+//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
+#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace fxpmath {
+
+/// Defines the 'FxpMathOps' dialect.
+class FxpMathOpsDialect : public Dialect {
+public:
+  FxpMathOpsDialect(MLIRContext *context);
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"
+
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
new file mode 100644
index 0000000..46b4293
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
@@ -0,0 +1,290 @@
+//===- FxpMathOps.td - Fixed point ops  --------------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for fixed point ops (and real
+// equivalents).
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef DIALECT_FXPMATHOPS_FXPMATH_OPS_
+#else
+#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+include "mlir/Dialect/QuantOps/QuantPredicates.td"
+
+def fxpmath_Dialect : Dialect {
+  let name = "fxpmath";
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes
+//===----------------------------------------------------------------------===//
+
+// Real value for an (inclusive) min/max clamp limit.
+def fxpmath_ClampValueAttr : OptionalAttr<F64Attr>;
+
+// Element-wise activation function to apply.
+// Note that RELU activations are not here: they are expressed as clamps.
+def fxpmath_EwUnaryFnAttr :
+    StringBasedAttr<CPred<"true">, "element-wise unary function"> {
+  let returnType = [{ StringRef }];
+  let defaultValue = "IDENTITY";
+}
+
+class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
+def fxpmath_EwUnaryFn_Abs     : fxpmath_ConstEwUnaryFn<"ABS">;
+def fxpmath_EwUnaryFn_Exp     : fxpmath_ConstEwUnaryFn<"EXP">;
+def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
+def fxpmath_EwUnaryFn_Log     : fxpmath_ConstEwUnaryFn<"LOG">;
+def fxpmath_EwUnaryFn_Neg     : fxpmath_ConstEwUnaryFn<"NEG">;
+def fxpmath_EwUnaryFn_Rsqrt   : fxpmath_ConstEwUnaryFn<"RSQRT">;
+def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
+def fxpmath_EwUnaryFn_Sign    : fxpmath_ConstEwUnaryFn<"SIGN">;
+def fxpmath_EwUnaryFn_Sin     : fxpmath_ConstEwUnaryFn<"SIN">;
+def fxpmath_EwUnaryFn_Sqrt    : fxpmath_ConstEwUnaryFn<"SQRT">;
+def fxpmath_EwUnaryFn_Square  : fxpmath_ConstEwUnaryFn<"SQUARE">;
+def fxpmath_EwUnaryFn_Tanh    : fxpmath_ConstEwUnaryFn<"TANH">;
+
+//===----------------------------------------------------------------------===//
+// Comparison functions (compares relative to zero on a subtraction result).
+//===----------------------------------------------------------------------===//
+
+def fxpmath_CompareZ    : StrEnumAttrCase<"CMPZ">;
+def fxpmath_CompareNZ   : StrEnumAttrCase<"CMPNZ">;
+def fxpmath_CompareLZ   : StrEnumAttrCase<"CMPLZ">;
+def fxpmath_CompareLZE  : StrEnumAttrCase<"CMPLZE">;
+def fxpmath_CompareGZ   : StrEnumAttrCase<"CMPGZ">;
+def fxpmath_CompareGZE  : StrEnumAttrCase<"CMPGZE">;
+
+def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
+    "Type of subtraction-result comparison to perform.",
+    [
+      fxpmath_CompareZ,
+      fxpmath_CompareNZ,
+      fxpmath_CompareLZ,
+      fxpmath_CompareLZE,
+      fxpmath_CompareGZ,
+      fxpmath_CompareGZE
+    ]>;
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
+    Op<fxpmath_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Fixed-point (fxp) arithmetic ops used by kernels.
+// Some of these are temporary pending inclusion into a more core dialect.
+//===----------------------------------------------------------------------===//
+
+def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary =
+      "Clamps a signed-integer like argument to a min/max range.";
+  let description = [{
+    Element-wise equivalent to:
+      r = std::min(clamp_max, std::max(e, clamp_min))
+  }];
+  let arguments = (ins IntegerLike:$operand,
+                       APIntAttr:$clamp_min,
+                       APIntAttr:$clamp_max);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_ConvertISOp :
+    fxpmath_Op<"convertis",
+               [NoSideEffect, SameOperandsAndResultShape]> {
+  let summary =
+      "Does an element-wise conversion from a signed integer to signed integer";
+  let description = [{
+    Similar to an element-wise static_cast in C++, from a one signed integer
+    element type to another.
+  }];
+  let arguments = (ins IntegerLike:$operand);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_ConvertISToFOp :
+    fxpmath_Op<"convertistof",
+               [NoSideEffect, SameOperandsAndResultShape]> {
+  let summary =
+      "Does an element-wise conversion from a signed integer to a float";
+  let description = [{
+    Similar to an element-wise static_cast in C++, from a signed integer
+    element type to a floating point element type, rounding to the nearest
+    floating point value.
+  }];
+  let arguments = (ins IntegerLike:$operand);
+  let results = (outs FloatLike);
+}
+
+
+def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
+    fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
+               [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
+  let description = [{
+    Equivalent to the ARMv7 NEON VQRDMULH instruction.
+    See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
+    implementation.
+  }];
+  let arguments = (ins IntegerLike:$a, APIntAttr:$b);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_RoundingDivideByPotISOp :
+    fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = [{
+    Computes a rounding arithmetic right shift.
+  }];
+  let description = [{
+    Computes integer division by a power-of-two, correctly rounded-to-nearest.
+    Also known as a rounding arithmetic right shift. See
+    gemmlowp::RoundingDivideByPOT for a reference implementation.
+  }];
+  let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent);
+  let results = (outs IntegerLike:$res);
+  let verifier = [{
+    auto verifyExponent = exponent().getSExtValue();
+    if (verifyExponent < 0 || verifyExponent > 31) {
+      return emitOpError("exponent must be in range [0..31]");
+    }
+    return success();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Real math ops.
+//
+// Math ops on real numbers which may have a representation in quantized
+// arithmetic. It is expected that eligible ops are lowered from a source
+// dialect to this set of ops prior to the process of converting a compuation
+// to a quantized form. It is a non-goal of these ops to preserve enough
+// information to convert back to the higher level, source dialect.
+//
+// These ops support either real/floating point or QuantizedTypes as operands
+// and results. Since not all transformations are supported (globally or
+// sometimes for specific targets), a computation may end up with
+// untransformable RealMathOps, in which case they need to be lowered as is
+// (using floating point math).
+//
+// This op set takes advantage of the fact that it is typically trivial to
+// combine a math function with a compatible bias addition and real-valued
+// clamp (which can be done at a higher accumulation bit depth).
+//
+// In addition, all element-wise unary functions are collapsed into a single
+// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at
+// low bit depths, this makes matching simpler and allows the construction of
+// generic LUT-based implementations. It also allows specific lowering rules
+// to consolidate runs of chained unary ops and fuse them to preceding math
+// ops, potentially allowing them to operate directly on higher precision
+// intermediates without resorting to lots of custom kernels for common
+// formulas that can suffer from insufficient precision at low bit depths.
+//
+// Comparison operators are modeled as element-wise unary functions (i.e.
+// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
+// quantized value. It is expected that lowering rules can fuse them with
+// the preceding sub.
+//===----------------------------------------------------------------------===//
+
+class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
+    fxpmath_Op<mnemonic, traits>,
+    Arguments<!con(args, (ins
+        fxpmath_ClampValueAttr:$clamp_min, fxpmath_ClampValueAttr:$clamp_max))>;
+
+//===----------------------------------------------------------------------===//
+// Element wise binary real math ops.
+//===----------------------------------------------------------------------===//
+
+class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+    fxpmath_RealMathOp<mnemonic, traits,
+                     (ins quant_RealValueType:$lhs,
+                      quant_RealValueType:$rhs)>,
+    Results<(outs quant_RealValueType:$res)>;
+
+class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
+    fxpmath_RealMathOp<mnemonic, traits,
+                     (ins quant_RealValueType:$lhs, quant_RealValueType:$rhs,
+                          quant_RealValueType:$bias)>,
+    Results<(outs quant_RealValueType:$res)>;
+
+def fxpmath_RealAddEwOp :
+    fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
+
+def fxpmath_RealSubEwOp :
+    fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
+
+def fxpmath_RealMulEwOp :
+    fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
+
+def fxpmath_RealDivEwOp :
+    fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
+
+//===----------------------------------------------------------------------===//
+// Element wise unary real math op.
+//===----------------------------------------------------------------------===//
+
+def fxpmath_RealUnaryEwOp :
+    fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
+        (ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>,
+    Results<(outs quant_RealValueType:$res)>;
+
+def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
+    Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>,
+    Results<(outs I1Tensor:$res)> {
+  let description = [{
+    Compares a real value to zero, returning an I1 (boolean) tensor with the
+    result of applying the comparison function.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Dot op with fused bias addition.
+//===----------------------------------------------------------------------===//
+
+def fxpmath_RealMatMulOp :
+    fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> {
+  let summary = "Matmul";
+  let description = [{
+    A matrix multiply of [m, k] and [k, n] -> [m, n] where the bias vector is
+    of shape [n]. Also accepts rank 3 or more input tensors, in which case
+    the leading dimensions are batch dims.
+
+    Many real systems have specific library calls optimized for this precise
+    operation, which is why it is handled explicitly versus purely as a
+    generalized tensor contraction.
+  }];
+}
+
+def fxpmath_RealMatMulBiasOp :
+    fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> {
+  let summary = "Matmul with bias";
+  let description = [{
+    A specialization of a RealMatMulOp that also accepts an [n] dimension
+    bias vector.
+
+    In addition, there is often special support for a fused bias and clamp,
+    which is why they are included.
+  }];
+}
+
+#endif  // DIALECT_FXPMATHOPS_FXPMATH_OPS_
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h
new file mode 100644
index 0000000..74c634a
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h
@@ -0,0 +1,43 @@
+//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines all of the passes owned by the FxpMathOps dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FXPMATHOPS_PASSES_H
+#define MLIR_DIALECT_FXPMATHOPS_PASSES_H
+
+namespace mlir {
+class FunctionPassBase;
+
+namespace fxpmath {
+
+/// Creates a pass that lowers uniform-quantized real math ops to integer
+/// arithmetic. This will leave unrecognized real math ops as-is and is
+/// typically followed by a pass that lowers any unrecognized ops to a pure
+/// floating point form.
+FunctionPassBase *createLowerUniformRealMathPass();
+
+/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent
+/// operations that perform quantize/dequantize.
+FunctionPassBase *createLowerUniformCastsPass();
+
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FXPMATHOPS_PASSES_H
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
new file mode 100644
index 0000000..5ba59a1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS GPUOps.td)
+mlir_tablegen(GPUOps.h.inc -gen-op-decls)
+mlir_tablegen(GPUOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRGPUOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h
new file mode 100644
index 0000000..d034212
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -0,0 +1,174 @@
+//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the GPU kernel-related operations and puts them in the
+// corresponding dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_GPUDIALECT_H
+#define MLIR_DIALECT_GPU_GPUDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class FuncOp;
+
+namespace gpu {
+
+/// The dialect containing GPU kernel launching operations and related
+/// facilities.
+class GPUDialect : public Dialect {
+public:
+  /// Create the dialect in the given `context`.
+  GPUDialect(MLIRContext *context);
+
+  /// Get the canonical string name of the dialect.
+  static StringRef getDialectName();
+
+  /// Get the name of the attribute used to annotate outlined kernel functions.
+  static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
+
+  /// Returns whether the given function is a kernel function, i.e., has the
+  /// 'gpu.kernel' attribute.
+  static bool isKernel(FuncOp function);
+};
+
+/// Utility class for the GPU dialect to represent triples of `Value`s
+/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
+struct KernelDim3 {
+  Value *x;
+  Value *y;
+  Value *z;
+};
+
+/// GPU kernel launch operation.  Takes a 3D grid of thread blocks as leading
+/// operands, followed by kernel data operands.  Has one region representing
+/// the kernel to be executed.  This region is not allowed to use values defined
+/// outside it.
+class LaunchOp : public Op<LaunchOp, OpTrait::AtLeastNOperands<6>::Impl,
+                           OpTrait::ZeroResult, OpTrait::IsIsolatedFromAbove> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *gridSizeX,
+                    Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
+                    Value *blockSizeY, Value *blockSizeZ,
+                    ArrayRef<Value *> operands);
+
+  /// Get the kernel region.
+  Region &getBody();
+
+  /// Get the SSA values corresponding to kernel block identifiers.
+  KernelDim3 getBlockIds();
+  /// Get the SSA values corresponding to kernel thread identifiers.
+  KernelDim3 getThreadIds();
+  /// Get the SSA values corresponding to kernel grid size.
+  KernelDim3 getGridSize();
+  /// Get the SSA values corresponding to kernel block size.
+  KernelDim3 getBlockSize();
+  /// Get the operand values passed as kernel arguments.
+  operand_range getKernelOperandValues();
+  /// Get the operand types passed as kernel arguments.
+  operand_type_range getKernelOperandTypes();
+
+  /// Get the SSA values passed as operands to specify the grid size.
+  KernelDim3 getGridSizeOperandValues();
+  /// Get the SSA values passed as operands to specify the block size.
+  KernelDim3 getBlockSizeOperandValues();
+
+  /// Get the SSA values of the kernel arguments.
+  llvm::iterator_range<Block::args_iterator> getKernelArguments();
+
+  LogicalResult verify();
+
+  /// Custom syntax support.
+  void print(OpAsmPrinter *p);
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+
+  static StringRef getOperationName() { return "gpu.launch"; }
+
+  /// Erase the `index`-th kernel argument.  Both the entry block argument and
+  /// the operand will be dropped.  The block argument must not have any uses.
+  void eraseKernelArgument(unsigned index);
+
+  /// Append canonicalization patterns to `results`.
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+
+private:
+  static StringRef getBlocksKeyword() { return "blocks"; }
+  static StringRef getThreadsKeyword() { return "threads"; }
+  static StringRef getArgsKeyword() { return "args"; }
+
+  /// The number of launch configuration operands, placed at the leading
+  /// positions of the operand list.
+  static constexpr unsigned kNumConfigOperands = 6;
+
+  /// The number of region attributes containing the launch configuration,
+  /// placed in the leading positions of the argument list.
+  static constexpr unsigned kNumConfigRegionAttributes = 12;
+};
+
+/// Operation to launch a kernel given as outlined function.
+class LaunchFuncOp : public Op<LaunchFuncOp, OpTrait::AtLeastNOperands<6>::Impl,
+                               OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, FuncOp kernelFunc,
+                    Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ,
+                    Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ,
+                    ArrayRef<Value *> kernelOperands);
+
+  static void build(Builder *builder, OperationState *result, FuncOp kernelFunc,
+                    KernelDim3 gridSize, KernelDim3 blockSize,
+                    ArrayRef<Value *> kernelOperands);
+
+  /// The kernel function specified by the operation's `kernel` attribute.
+  StringRef kernel();
+  /// The number of operands passed to the kernel function.
+  unsigned getNumKernelOperands();
+  /// The i-th operand passed to the kernel function.
+  Value *getKernelOperand(unsigned i);
+
+  /// Get the SSA values passed as operands to specify the grid size.
+  KernelDim3 getGridSizeOperandValues();
+  /// Get the SSA values passed as operands to specify the block size.
+  KernelDim3 getBlockSizeOperandValues();
+
+  LogicalResult verify();
+
+  static StringRef getOperationName() { return "gpu.launch_func"; }
+
+  /// The number of launch configuration operands, placed at the leading
+  /// positions of the operand list.
+  static constexpr unsigned kNumConfigOperands = 6;
+
+private:
+  /// The name of the function attribute specifying the kernel to launch.
+  static StringRef getKernelAttrName() { return "kernel"; }
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GPU/GPUOps.h.inc"
+
+} // end namespace gpu
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_GPU_GPUDIALECT_H
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
new file mode 100644
index 0000000..b38a597
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -0,0 +1,60 @@
+//===-- GPUOps.td - GPU dialect operation definitions ------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines some operations of the GPU dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef GPU_OPS
+#else
+#define GPU_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def GPU_Dialect : Dialect {
+  let name = "gpu";
+}
+
+class GPU_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<GPU_Dialect, mnemonic, traits>;
+
+class GPU_IndexOp<string mnemonic, list<OpTrait> traits = []> :
+    GPU_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
+    Arguments<(ins StrAttr:$dimension)>, Results<(outs Index)>;
+
+def gpu_BlockDim : GPU_IndexOp<"block_dim">;
+def gpu_BlockId : GPU_IndexOp<"block_id">;
+def gpu_GridDim : GPU_IndexOp<"grid_dim">;
+def gpu_ThreadId : GPU_IndexOp<"thread_id">;
+
+def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
+    Results<(outs)> {
+  let summary = "Terminator for GPU launch regions.";
+  let description = [{
+    A terminator operation for regions that appear in the body of `gpu.launch`
+    operation.  These regions are not expected to return any value so the
+    terminator takes no operands.
+  }];
+
+  let parser = [{ return success(); }];
+  let printer = [{ *p << getOperationName(); }];
+}
+
+#endif // GPU_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/Passes.h b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h
new file mode 100644
index 0000000..f9b569d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_PASSES_H_
+#define MLIR_DIALECT_GPU_PASSES_H_
+
+namespace mlir {
+
+class ModulePassBase;
+
+ModulePassBase *createGpuKernelOutliningPass();
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GPU_PASSES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
new file mode 100644
index 0000000..2d69958
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS LoopOps.td)
+mlir_tablegen(LoopOps.h.inc -gen-op-decls)
+mlir_tablegen(LoopOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLoopOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
new file mode 100644
index 0000000..90cc0b7
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
@@ -0,0 +1,56 @@
+//===- Ops.h - Loop MLIR Operations -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines convenience types for working with loop operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LOOPOPS_OPS_H_
+#define MLIR_LOOPOPS_OPS_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace loop {
+
+class TerminatorOp;
+
+class LoopOpsDialect : public Dialect {
+public:
+  LoopOpsDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "loop"; }
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LoopOps/LoopOps.h.inc"
+
+// Insert `loop.terminator` at the end of the only region's only block if it
+// does not have a terminator already.  If a new `loop.terminator` is inserted,
+// the location is specified by `loc`. If the region is empty, insert a new
+// block first.
+void ensureLoopTerminator(Region &region, Builder &builder, Location loc);
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForOp getForInductionVarOwner(Value *val);
+
+} // end namespace loop
+} // end namespace mlir
+#endif // MLIR_LOOPOPS_OPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
new file mode 100644
index 0000000..8b1b591
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -0,0 +1,158 @@
+//===- Ops.td - Loop operation definitions ---------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines MLIR loop operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LOOP_OPS
+#else
+#define LOOP_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Loop_Dialect : Dialect {
+  let name = "loop";
+  let cppNamespace = "";
+}
+
+// Base class for Loop dialect ops.
+class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Loop_Dialect, mnemonic, traits> {
+  // For every standard op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+def ForOp : Loop_Op<"for",
+      [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "for operation";
+  let description = [{
+    The "loop.for" operation represents a loop nest taking 3 SSA value as
+    operands that represent the lower bound, upper bound and step respectively.
+    The operation defines an SSA value for its induction variable. It has one
+    region capturing the loop body. The induction variable is represented as an
+    argument of this region. This SSA value always has type index, which is the
+    size of the machine word. The step is a value of type index, required to be
+    positive.
+    The lower and upper bounds specify a half-open range: the range includes the
+    lower bound but does not include the upper bound.
+
+    The body region must contain exactly one block that terminates with
+    "loop.terminator".  Calling ForOp::build will create such region and insert
+    the terminator, so will the parsing even in cases when it is absent from the
+    custom format. For example:
+
+       loop.for %iv = %lb to %ub step %step {
+         ... // body
+       }
+  }];
+  let arguments = (ins Index:$lowerBound, Index:$upperBound, Index:$step);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, "
+              "Value *lowerBound, Value *upperBound, Value *step">
+  ];
+
+  let extraClassDeclaration = [{
+    Block *getBody() { return &region().front(); }
+    Value *getInductionVar() { return getBody()->getArgument(0); }
+    OpBuilder getBodyBuilder() {
+      return OpBuilder(getBody(), std::prev(getBody()->end()));
+    }
+    void setLowerBound(Value *bound) { getOperation()->setOperand(0, bound); }
+    void setUpperBound(Value *bound) { getOperation()->setOperand(1, bound); }
+    void setStep(Value *step) { getOperation()->setOperand(2, step); }
+  }];
+}
+
+def IfOp : Loop_Op<"if",
+      [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "if-then-else operation";
+  let description = [{
+    The "loop.if" operation represents an if-then-else construct for
+    conditionally executing two regions of code. The operand to an if operation
+    is a boolean value. The operation produces no results. For example:
+
+       loop.if %b  {
+         ...
+       } else {
+         ...
+       }
+
+    The 'else' block is optional, and may be omitted. For
+    example:
+
+       loop.if %b  {
+         ...
+       }
+  }];
+  let arguments = (ins I1:$condition);
+  let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, "
+              "Value *cond, bool withElseRegion">
+  ];
+
+  let extraClassDeclaration = [{
+    OpBuilder getThenBodyBuilder() {
+      assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
+      Block &body = thenRegion().front();
+      return OpBuilder(&body, std::prev(body.end()));
+    }
+    OpBuilder getElseBodyBuilder() {
+      assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
+      Block &body = elseRegion().front();
+      return OpBuilder(&body, std::prev(body.end()));
+    }
+  }];
+}
+
+def TerminatorOp :
+    Loop_Op<"terminator", [NativeOpTrait<"IsTerminator">]> {
+  let summary = "cf terminator operation";
+  let description = [{
+    "loop.terminator" is a special terminator operation for blocks inside
+    loops. It terminates the region. This operation does _not_ have a custom
+    syntax. However, `std` control operations omit the terminator in their
+    custom syntax for brevity.
+
+       loop.terminator
+  }];
+
+  // No custom parsing/printing form.
+  let parser = ?;
+  let printer = ?;
+
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
+#endif // LOOP_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
new file mode 100644
index 0000000..3e3b946
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS QuantOps.td)
+mlir_tablegen(QuantOps.h.inc -gen-op-decls)
+mlir_tablegen(QuantOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRQuantOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
new file mode 100644
index 0000000..560b632
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
@@ -0,0 +1,68 @@
+//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support utilities for interoperating with FakeQuant* based
+// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note
+// that FakeQuant* operators mix multiple concerns specific to how TFLite
+// originally implemented quantization. As such, utilities here enforce
+// opinions taken by that codebase (vs providing any amount of genericity).
+//
+// Specifically, it combines the following concerns, each of which would be
+// independent variables in a more generic setup:
+//   - numBits and isSigned imply storage data type (uint8, int8, int16)
+//   - numBits < 8 is promoted to uint8 or int8
+//   - "narrow_range" narrows the lower bound of the storage type's range by
+//     1
+//   - the specified min/max values are "nudged" so that the result has a zero
+//     that can be exactly expressed
+//   - min=max=0 implies scale=0 and zero_point=0
+//
+// With the above assumptions applied, every conforming specified FakeQuant op
+// can be represented by a UniformQuantizedType. This scheme is not expected to
+// be generalized further in the future and should be considered to be a
+// legacy set of rules.
+//
+// As canonically used in TensorFlow graphs, the presence of a FakeQuant node
+// is a hint that the specific math represented here has been simulated at
+// training time. As such, it is usually not advised to arbitrarily change
+// quantization parameters derived from FakeQuant.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_
+#define MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+
+namespace mlir {
+namespace quant {
+
+/// Converts per-layer FakeQuant attributes to the corresponding type.
+/// In the event that the parameters cannot be converted, returns a nullptr
+/// convertible Type and issues an appropriate error.
+/// Note that there are multiple variants of a per-layer FakeQuant op, so
+/// this function takes the attributes discretely vs taking a reference to the
+/// originating op.
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+                                          double rmin, double rmax,
+                                          bool narrowRange, Type expressedType,
+                                          bool isSigned = false);
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h
new file mode 100644
index 0000000..6b647a8
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h
@@ -0,0 +1,47 @@
+//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines all of the passes owned by the quantization dialect. As
+// things mature, it is expected that passes specific to certain frontend or
+// backend dialects will move to those dialects directly. For now, they are
+// incubated here.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANTOPS_PASSES_H
+#define MLIR_DIALECT_QUANTOPS_PASSES_H
+
+namespace mlir {
+class FunctionPassBase;
+
+namespace quant {
+
+/// Creates a pass that converts quantization simulation operations (i.e.
+/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes.
+FunctionPassBase *createConvertSimulatedQuantPass();
+
+/// Creates a pass that converts constants followed by a qbarrier to a
+/// constant whose value is quantized. This is typically one of the last
+/// passes done when lowering to express actual quantized arithmetic in a
+/// low level representation. Because it modifies the constant, it is
+/// destructive and cannot be undone.
+FunctionPassBase *createConvertConstPass();
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_PASSES_H
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
new file mode 100644
index 0000000..8753cd2
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
@@ -0,0 +1,50 @@
+//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
+#define MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace quant {
+
+/// Defines the 'Quantization' dialect
+class QuantizationDialect : public Dialect {
+public:
+  QuantizationDialect(MLIRContext *context);
+
+  /// Parse a type registered to this dialect.
+  Type parseType(StringRef spec, Location loc) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, raw_ostream &os) const override;
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/QuantOps/QuantOps.h.inc"
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
new file mode 100644
index 0000000..394d3a1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -0,0 +1,227 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef DIALECT_QUANTOPS_QUANT_OPS_
+#else
+#define DIALECT_QUANTOPS_QUANT_OPS_
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/QuantOps/QuantPredicates.td"
+#endif // OP_BASE
+
+def quant_Dialect : Dialect {
+  let name = "quant";
+}
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<OpTrait> traits> :
+    Op<quant_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization casts
+//===----------------------------------------------------------------------===//
+// A QuantizeCast (qcast) represents a potential type shift from a quantizable
+// type to a quantized type.
+//
+// At runtime, a qcast will apply the transformation expressed by its
+// operand and result type. For flexibility during transformation, it is also
+// possible to have a qcast that performs no transformation (both its
+// operand and result type are quantizable).
+//
+// A qcast will typically originate from either:
+//   a) An expressed or implied constraint in the source dialect which signals
+//      that a certain level of quantization is possible or required.
+//   b) An inference made by a quantization algorithm indicating that a
+//      quantized representation may be acceptable.
+//
+// Especially early in transformation, it is common to have pairs of
+// qcast/dcast at points where a transition to a quantized type is
+// required. In addition, it is also common to have an identity qcast
+// (where the operand and result type are not quantized) at all points where
+// it is legal to use a quantized representation (but is not known to be
+// acceptable).
+def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> {
+  let arguments = (ins quant_RealValueType:$arg);
+  let results = (outs quant_RealValueType);
+}
+
+// A DequantizeCast op (dcast) represents the inverse of a qcast,
+// converting back from a quantized to quantizable (expressed) type.
+//
+// Like qcasts, a dcast is allowed to have both its operand and result
+// as non quantized types. This facilitates transformations and marks edges
+// where the computation must be carried out in the expressed type.
+//
+// Especially early in transformation, it is common to have dcasts on
+// all operands to ops that must operate with the expressed type (typically
+// math ops prior to lowering to target-specific, quantized kernels).
+def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
+  let arguments = (ins quant_RealValueType:$arg);
+  let results = (outs quant_RealValueType);
+}
+
+// A StorageCast (scast) represents a cast from or to a type based on the
+// storage type and a type based on a corresponding quantized type.
+//
+// This op exists to ensure type coherency for between parts of the computation
+// which are operating directly on an underlying storage type and those which
+// operate on quantized values.
+//
+// Examples from storage to quantized type:
+//   i8 -> !quant<"uniform[i8:f32]{1.0}">
+//   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+//   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
+  let arguments = (ins quant_RealOrStorageValueType:$arg);
+  let results = (outs quant_RealOrStorageValueType);
+  let hasCanonicalizer = 0b1;
+}
+
+//===----------------------------------------------------------------------===//
+// Training integration and instrumentation ops
+//===----------------------------------------------------------------------===//
+
+def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
+                                    [SameOperandsAndResultType, NoSideEffect]> {
+  let summary =
+      "Simulates the effect of uniform quantization with const range.";
+
+  let description = [{
+    Given a const min, max, num_bits and narrow_range attribute, applies the
+    same uniform quantization simulation as is done by the TensorFlow
+    fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
+    method and the quant-convert-simulated-quantization pass for futher details.
+  }];
+
+  let arguments = (ins
+    F32Tensor:$inputs,
+    F32Attr:$min,
+    F32Attr:$max,
+    // The bitwidth of the quantization; between 2 and 16, inclusive.
+    I64Attr:$num_bits,
+    // Quantization range starts from 0 or 1; starts from 1 if true.
+    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
+    // The sign of the quantization.
+    DefaultValuedAttr<BoolAttr, "false">:$is_signed
+  );
+
+  let results = (outs
+    F32Tensor:$outputs
+  );
+}
+
+def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
+  let summary =
+      "Indicates that statistics are resolved by reference.";
+
+  let description = [{
+    This op acts as an identity that, when encountered at runtime, should result
+    in statistics being collected about about the value of its operand/result.
+    Such statistics will be stored with the provided key, allowing this node
+    to later be converted to a 'stats' op if statistics with that key have been
+    encountered.
+  }];
+
+  let arguments = (ins
+    quant_RealValueType:$arg,
+    StrAttr:$statsKey
+  );
+  let results = (outs quant_RealValueType);
+}
+
+def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
+  let summary =
+      "Identity op which associates statistics with the value.";
+
+  let description = [{
+    Associates statistics about the runtime ranges of values observed for
+    evaluations of this node.
+
+    Statistics about the entire type are reported in the 'layerStats' attribute
+    and those for each axis, in the (optional) `axisStats` attribute. The
+    interpretation of each is determined by the last dimension of its shape.
+    Currently, only dim=2 is supported, which is interpreted as [min, max].
+
+    `layerStats` must be a rank 1 tensor: [2]
+    `axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
+  }];
+
+  let arguments = (ins
+    quant_RealValueType:$arg,
+    ElementsAttr:$layerStats,
+    OptionalAttr<ElementsAttr>:$axisStats);
+  let results = (outs quant_RealValueType);
+
+  let verifier = [{
+    auto tensorArg = arg()->getType().dyn_cast<TensorType>();
+    auto argRank = tensorArg ? tensorArg.getRank() : 0;
+    // Verify layerStats attribute.
+    {
+      auto layerStatsType = layerStats().getType();
+      if (!layerStatsType.getElementType().isa<FloatType>()) {
+        return emitOpError(
+            "layerStats must have a floating point element type");
+      }
+      if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
+        return emitOpError("layerStats must have shape [2]");
+      }
+    }
+    // Verify axisStats (optional) attribute.
+    if (axisStats()) {
+      auto axisStatsType = axisStats()->getType();
+      if (!axisStatsType.getElementType().isa<FloatType>()) {
+        return emitOpError("axisStats must have a floating point element type");
+      }
+      if (axisStatsType.getRank() != 2 ||
+          axisStatsType.getDimSize(1) != 2 ||
+          axisStatsType.getDimSize(0) != argRank) {
+        return emitOpError("axisStats must have shape [N,2] "
+                           "where N = the argument rank");
+      }
+    }
+    return success();
+  }];
+}
+
+def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> {
+  let summary =
+      "Indicates that one point of the computation is coupled to another.";
+
+  let description = [{
+    Ordinarily, relationships between ops for the purposes of determining
+    compatible quantized types is explicit based on the use-def chain. However,
+    in some situations, a use may be separated from its def by arbitrary
+    external connections. In such a case, during analysis, all coupled_ref
+    nodes in a module which share a coupledKey will be considered to be
+    directly connected as via an identity op for the purpose of type inference.
+  }];
+
+  let arguments = (ins
+    quant_RealValueType:$arg,
+    StrAttr:$coupledKey);
+  let results = (outs quant_RealValueType);
+}
+
+#endif // DIALECT_QUANTOPS_QUANT_OPS_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
new file mode 100644
index 0000000..4940b01
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
@@ -0,0 +1,73 @@
+//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Predicates for types in the Quantization dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef DIALECT_QUANTOPS_QUANT_PREDICATES_
+#else
+#define DIALECT_QUANTOPS_QUANT_PREDICATES_
+
+//===----------------------------------------------------------------------===//
+// Quantization type definitions
+//===----------------------------------------------------------------------===//
+
+class quant_TypedPrimitiveOrContainer<Type etype> :
+    Type<Or<[etype.predicate,
+                TensorOf<[etype]>.predicate,
+                VectorOf<[etype]>.predicate]>,
+         "primitive/tensor/vector of " # etype.description>;
+
+// An implementation of QuantizedType.
+def quant_QuantizedType :
+    Type<CPred<"$_self.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
+
+// A primitive type that can represent a real value. This is either a
+// floating point value or a quantized type.
+def quant_RealPrimitiveType :
+    Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
+    "real valued primitive (float or quantized type)">;
+
+// A primitive type that can represent a storage value. This is either an
+// integer or quantized type.
+def quant_StoragePrimitiveType :
+    Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>,
+    "quantized storage primitive (integer or quantized type)">;
+
+// A primitive or container of RealPrimitiveType.
+def quant_RealValueType :
+    quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
+
+// A primitive or container of StoragePrimitiveType.
+def quant_StorageValueType :
+    quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
+
+// Either a real valued or storage primitive or container type.
+def quant_RealOrStorageValueType :
+    Type<Or<[quant_RealValueType.predicate,
+                quant_StorageValueType.predicate]>>;
+
+// An implementation of UniformQuantizedType.
+def quant_UniformQuantizedType :
+    Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
+
+// Predicate for detecting a container or primitive of UniformQuantizedType.
+def quant_UniformQuantizedValueType :
+    quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
+
+#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
new file mode 100644
index 0000000..803ee4e
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
@@ -0,0 +1,411 @@
+//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
+#define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace quant {
+
+class QuantizedIntegerType;
+
+namespace detail {
+
+struct QuantizedTypeStorage;
+struct AnyQuantizedTypeStorage;
+struct UniformQuantizedTypeStorage;
+struct UniformQuantizedPerAxisTypeStorage;
+
+} // namespace detail
+
+namespace QuantizationTypes {
+enum Kind {
+  Any = Type::FIRST_QUANTIZATION_TYPE,
+  UniformQuantized,
+  UniformQuantizedPerAxis,
+  LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
+};
+} // namespace QuantizationTypes
+
+/// Enumeration of bit-mapped flags related to quantized types.
+namespace QuantizationFlags {
+enum FlagValue {
+  // Indicates that the storage type should be interpreted as a signed
+  // integer. The default is to interpret it as an unsigned value.
+  Signed = 1,
+};
+} // namespace QuantizationFlags
+
+/// Base class for all quantized types known to this dialect.
+/// All quantized types have:
+///   - storageType: The (narrower) numeric type that is being used to
+///     approximate some expressed type.
+///   - expressedType: The type that is being approximated.
+///
+/// The base class provides generic support for manipulating the types based
+/// on these fields.
+class QuantizedType : public Type {
+public:
+  using ImplType = detail::QuantizedTypeStorage;
+  using Type::Type;
+
+  /// The maximum number of bits supported for storage types.
+  static constexpr unsigned MaxStorageBits = 32;
+
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, unsigned flags,
+                               Type storageType, Type expressedType,
+                               int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool classof(Type type) {
+    return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
+           type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
+  }
+
+  /// Gets the minimum possible stored by a storageType. storageTypeMin must
+  /// be greater than or equal to this value.
+  static int64_t getDefaultMininumForInteger(bool isSigned,
+                                             unsigned integralWidth) {
+    if (isSigned) {
+      return llvm::minIntN(integralWidth);
+    }
+    return 0;
+  }
+
+  /// Gets the maximum possible stored by a storageType. storageTypeMax must
+  /// be less than or equal to this value.
+  static int64_t getDefaultMaxinumForInteger(bool isSigned,
+                                             unsigned integralWidth) {
+    if (isSigned) {
+      return llvm::maxIntN(integralWidth);
+    }
+    return llvm::maxUIntN(integralWidth);
+  }
+
+  /// Gets the original expressed type that this quantized type approximates.
+  /// Note that this presumes that the quantized type was always derived from
+  /// a floating point type, which in the broadest definition, is not true (i.e.
+  /// it could be some form of integral, fixed type or affine type in its own
+  /// right); however, at the high level, no examples of such usage are
+  /// presently known and the restriction serves some useful purposes (such as
+  /// always being able to reverse a transformation or measure error). In most
+  /// cases, this will be f32.
+  Type getExpressedType() const;
+
+  /// Gets the flags associated with this type. Typically a more specific
+  /// accessor is appropriate.
+  unsigned getFlags() const;
+
+  // Convenience helpers.
+  /// Whether the storage type should be interpreted as a signed quantity
+  /// (true) or an unsigned value (false).
+  bool isSigned() const {
+    return (getFlags() & QuantizationFlags::Signed) ==
+           QuantizationFlags::Signed;
+  }
+
+  /// Gets the underlying type used for to store values. Note that this may
+  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
+  Type getStorageType() const;
+
+  /// The minimum value that storageType can take.
+  int64_t getStorageTypeMin() const;
+
+  /// The maximum value that storageType can take.
+  int64_t getStorageTypeMax() const;
+
+  /// Gets the integral bit width that the underlying storage type can exactly
+  /// represent. For integral storage types, this will just be their width.
+  unsigned getStorageTypeIntegralWidth() const;
+
+  /// Returns whether the candidateExpressedType is a match for this
+  /// QuantizedType. This will be true if the candidate type is either a
+  /// primitive type or a container type whose element type equals this
+  /// QuantizedType's expressed type.
+  /// Examples of compatible candidateExpressedType:
+  ///   !quant.uniform<i8:f32, 1.0> =~ f32
+  ///   !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
+  bool isCompatibleExpressedType(Type candidateExpressedType);
+
+  /// Returns the element type as a QuantizedType or nullptr if it is not
+  /// a quantized type. If the type is primitive, returns that. If it is a
+  /// container (vector/tensor), return the element type.
+  /// Examples:
+  ///   !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
+  ///   tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0>
+  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
+
+  /// Casts from a type based on the storageType to a corresponding type based
+  /// on this type (returns nullptr if the cast is not valid).
+  /// Examples:
+  ///   i8 -> !quant.uniform<i8:f32, 1.0>
+  ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
+  ///   vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
+  Type castFromStorageType(Type candidateType);
+
+  /// Casts from a type based on a QuantizedType to a corresponding type based
+  /// on the storageType (returns nullptr if the cast is not valid).
+  /// This is the inverse of castFromStorageType().
+  static Type castToStorageType(Type quantizedType);
+
+  /// Casts from a type based on the expressedType to a corresponding type based
+  /// on this type (returns nullptr if the cast is not valid).
+  /// Examples:
+  ///   f32 -> !quant.uniform<i8:f32, 1.0>
+  ///   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
+  ///   vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
+  Type castFromExpressedType(Type candidateType);
+
+  /// Casts from a type based on QuantizedType to a corresponding type based
+  /// on the expressedType (returns nullptr if the cast is not valid).
+  /// This is the inverse of castFromExpressedType.
+  static Type castToExpressedType(Type quantizedType);
+
+  /// Casts from a type based on the expressedType to the equivalent type
+  /// based on storageType by way of this QuantizedType. Equivalent to:
+  ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
+  /// (but with validity checks).
+  /// Example (for this = !quant.uniform<i8:f32, 1.0>):
+  ///   tensor<4xf32> -> tensor<4xi8>
+  Type castExpressedToStorageType(Type candidateType);
+
+private:
+  /// Hide the following methods inherited from `Type`. It is almost certainly
+  /// a bug to call them from a `QuantizedType` object. Users should call
+  /// `getStorageType` or `getExpressedType` to get the underlying types
+  /// they want to inspect.
+  using Type::isBF16;
+  using Type::isF16;
+  using Type::isF32;
+  using Type::isF64;
+  using Type::isIndex;
+  using Type::isInteger;
+};
+
+/// A quantized type that maps storage to/from expressed types in an
+/// unspecified way.
+///
+/// Typical syntax:
+///   quant.any<i8:f32>
+///   quant.any<i8>
+///   quant.any<i8<-16,15>>
+///
+/// Note that for the any type, the expressed type is optional.
+class AnyQuantizedType
+    : public Type::TypeBase<AnyQuantizedType, QuantizedType,
+                            detail::AnyQuantizedTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; }
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static AnyQuantizedType get(unsigned flags, Type storageType,
+                              Type expressedType, int64_t storageTypeMin,
+                              int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static AnyQuantizedType getChecked(unsigned flags, Type storageType,
+                                     Type expressedType, int64_t storageTypeMin,
+                                     int64_t storageTypeMax, Location location);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, unsigned flags,
+                               Type storageType, Type expressedType,
+                               int64_t storageTypeMin, int64_t storageTypeMax);
+};
+
+/// Represents a family of uniform, quantized types.
+///
+/// Each instance of this type expresses a mapping between real values (most
+/// often expressed in floating point f32) and quantized values (either fixed
+/// point or affine).
+///
+/// The relationship is:
+///     real_value = scale * (quantized_value - zero_point)
+///
+/// It is used as part of high level graph transformations that have the goal
+/// of re-expressing parts of a computation in terms of this common form for
+/// more efficient execution at runtime. In addition, it is designed to be
+/// expressive enough to facilitate lowering to precise types and operations
+/// in target hardware.
+///
+/// As a high-level type, focused on intermediate passes, this type holds
+/// opinions consistent with high-level usage. If lowering math kernels below
+/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
+/// instruction sets), it is expected that the information expressed here
+/// will be used to drive low level codegen and target specific type selection,
+/// but this type will likely be erased in the process.
+///
+/// Syntax synopsis:
+///   Per-layer, all parameters expressed:
+///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
+///   Per-layer, optional parameters omitted:
+///     !quant<uniform[StorageType]{Scale}>
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   Scale: A legal double value
+///   ZeroPoint: An integer value
+class UniformQuantizedType
+    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
+                            detail::UniformQuantizedTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedType get(unsigned flags, Type storageType,
+                                  Type expressedType, double scale,
+                                  int64_t zeroPoint, int64_t storageTypeMin,
+                                  int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedType
+  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
+             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
+             Location location);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult verifyConstructionInvariants(
+      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+      Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == QuantizationTypes::UniformQuantized;
+  }
+
+  /// Gets the scale term. The scale designates the difference between the real
+  /// values corresponding to consecutive quantized values differing by 1.
+  double getScale() const;
+
+  /// Gets the storage value corresponding to the real value 0 in the affine
+  /// equation.
+  int64_t getZeroPoint() const;
+
+  // Fixed point values are real numbers divided by a scale.
+  // Currently, only signed storage types are treated as fixed point.
+  // A fixed point value can be obtained from an affine value by subtracting
+  // the zeroPoint.
+  // In the future, this may be explicit versus implied by type and zeroPoint.
+  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
+};
+
+/// Represents per-axis (also known as per-channel quantization).
+///
+/// Syntax synopsis:
+///   Per-axis, all parameters expressed:
+///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
+///   Per-axis, optional parameters omitted:
+///     !quant<uniform[StorageType]{Scale}>
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   QuantizedDim: An integer value
+///   QuantParams: (Scale ':' ZeroPoint)+
+///   Scale: A legal double value
+///   ZeroPoint: An integer value
+class UniformQuantizedPerAxisType
+    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
+                            detail::UniformQuantizedPerAxisTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedPerAxisType
+  get(unsigned flags, Type storageType, Type expressedType,
+      ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+      int32_t quantizedDimension, int64_t storageTypeMin,
+      int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedPerAxisType
+  getChecked(unsigned flags, Type storageType, Type expressedType,
+             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+             int32_t quantizedDimension, int64_t storageTypeMin,
+             int64_t storageTypeMax, Location location);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult verifyConstructionInvariants(
+      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+      Type storageType, Type expressedType, ArrayRef<double> scales,
+      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == QuantizationTypes::UniformQuantizedPerAxis;
+  }
+
+  /// Gets the quantization scales. The scales designate the difference between
+  /// the real values corresponding to consecutive quantized values differing
+  /// by 1. The ith scale corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef<double> getScales() const;
+
+  /// Gets the storage values corresponding to the real value 0 in the affine
+  /// equation. The ith zero point corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef<int64_t> getZeroPoints() const;
+
+  /// Specifies the dimension of the Tensor's shape that the scales and
+  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
+  /// with quantization params:
+  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
+  /// will be quantized across the second dimension of t.
+  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
+  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
+  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
+  int32_t getQuantizedDimension() const;
+
+  /// Fixed point values are real numbers divided by a scale.
+  /// Currently, only signed storage types are treated as fixed point.
+  /// A fixed point value can be obtained from an affine value by subtracting
+  /// the zeroPoint.
+  /// In the future, this may be explicit versus implied by type and zeroPoint.
+  bool isFixedPoint() const {
+    if (!isSigned())
+      return false;
+    return llvm::all_of(getZeroPoints(),
+                        [](int64_t zeroPoint) { return zeroPoint != 0; });
+  }
+};
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h
new file mode 100644
index 0000000..de87ca1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h
@@ -0,0 +1,70 @@
+//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
+#define MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
+
+namespace mlir {
+class Attribute;
+class Type;
+
+namespace quant {
+class QuantizedType;
+class UniformQuantizedType;
+class UniformQuantizedValueConverter;
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType().
+/// Returns nullptr if the conversion is not supported. On success, stores the
+/// converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+///  quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+                       Type &outConvertedType);
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType() and casted to an
+/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On
+/// success, stores the converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+///  quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttrUniform(Attribute realValue,
+                              UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType);
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
new file mode 100644
index 0000000..5d11c76
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -0,0 +1,119 @@
+//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
+#define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APSInt.h"
+
+namespace mlir {
+namespace quant {
+
+/// Performs type conversion from an arbitrary input type to a type
+/// that is expressed by a UniformQuantizedType.
+///
+/// This handles cases where the inputType is a supported primitive type
+/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
+/// elemental type.
+///
+/// Since conversion often involves introspecting some attributes of the
+/// input type in order to determine how to represent it, this is a two step
+/// process.
+struct ExpressedToUniformQuantizedConverter {
+  /// Creates a converter for the given input type.
+  static const ExpressedToUniformQuantizedConverter
+  forInputType(Type inputType);
+
+  /// Converts the inputType to be based on the given elemental type,
+  /// returning the new type (or nullptr and emit an error on failure).
+  Type convert(UniformQuantizedType elementalType) const;
+
+  /// Whether the conversion is legal.
+  explicit operator bool() const { return (bool)expressedType; }
+
+  /// The input type that is being converted from.
+  /// This may be an elemental or composite type.
+  const Type inputType;
+
+  /// Supported, elemental expressed type (i.e. f32).
+  /// Will be nullptr if conversion is not supported.
+  const Type expressedType;
+};
+
+/// Reference implementation of converting between real numbers and values
+/// represented by a UniformQuantizedType.
+/// Note that this is not expected to be speedy and may be superceded eventually
+/// by a more optimal implementation.
+/// Also, the interface assumes that quantization is done per-layer and will
+/// need to be wider for various per-channel schemes. As such, this is a
+/// placeholder.
+class UniformQuantizedValueConverter {
+public:
+  UniformQuantizedValueConverter(UniformQuantizedType uniformType)
+      : scale(uniformType.getScale()),
+        zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
+        clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
+        clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
+        storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
+        isSigned(uniformType.isSigned()) {
+    assert(uniformType.getExpressedType().isa<FloatType>());
+    assert(uniformType.getStorageType().isa<IntegerType>());
+  }
+
+  virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+    bool lossy;
+    expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
+                           &lossy);
+    // fixedpoint = clamp(clampMin, clampMax, (
+    //   roundHalfToEven(expressed / scale) + zeroPoint))
+    APFloat scaled = (expressedValue / scale);
+    scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
+    scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+    APFloat fixedpoint = llvm::minimum(scaled, clampMax);
+    fixedpoint = llvm::maximum(fixedpoint, clampMin);
+
+    llvm::APSInt result(storageBitWidth, !isSigned);
+    fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+
+    return std::move(result);
+  }
+
+  int64_t quantizeFloatToInt64(APFloat expressedValue) const {
+    APInt qValue = quantizeFloatToInt(expressedValue);
+    return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
+  }
+
+  virtual ~UniformQuantizedValueConverter() {}
+
+private:
+  const APFloat scale;
+  const APFloat zeroPoint;
+  const APFloat clampMin;
+  const APFloat clampMax;
+  const uint32_t storageBitWidth;
+  const bool isSigned;
+};
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
new file mode 100644
index 0000000..af4520d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
+mlir_tablegen(SPIRVOps.h.inc -gen-op-decls)
+mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRSPIRVOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
+mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
+mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
+mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
+add_public_tablegen_target(MLIRSPIRVSerializationGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
+mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
+add_public_tablegen_target(MLIRSPIRVOpUtilsGen)
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h b/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h
new file mode 100644
index 0000000..e896da7
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_PASSES_H_
+#define MLIR_DIALECT_SPIRV_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace spirv {
+
+ModulePassBase *createConvertStandardToSPIRVPass();
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_PASSES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
new file mode 100644
index 0000000..1a722f8
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -0,0 +1,741 @@
+//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the base file for SPIR-V operation definition specification.
+// This file defines the SPIR-V dialect, common SPIR-V types, and utilities
+// for facilitating defining SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef SPIRV_BASE
+#else
+#define SPIRV_BASE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+//===----------------------------------------------------------------------===//
+// SPIR-V dialect definitions
+//===----------------------------------------------------------------------===//
+
+def SPV_Dialect : Dialect {
+  let name = "spv";
+
+  let description = [{
+    The SPIR-V dialect in MLIR.
+
+    SPIR-V is the Khronos Group's binary intermediate language for representing
+    graphical-shader stages and compute kernels for multiple Khronos APIs,
+    including OpenCL, OpenGL, and Vulkan.
+    See https://www.khronos.org/registry/spir-v for more details.
+
+    This dialect aims to be a simple proxy for the SPIR-V binary format to
+    enable straightforward and lightweight conversion from/to the binary
+    format. Ops in this dialect should stay at the same semantic level and
+    try to be a mechanical mapping to the corresponding SPIR-V instructions;
+    but they may deviate representationally to allow using MLIR mechanisms.
+    As a convention, if such deviation happens, the op name follows "snake_case"
+    style; otherwise, the op name just follows the SPIR-V mnemonic (by removing
+    the leading `Op` prefix) to use "CamelCase" style.
+  }];
+
+  let cppNamespace = "spirv";
+}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V opcode specification
+//===----------------------------------------------------------------------===//
+
+class SPV_OpCode<string name, int val> {
+  // Name used as reference to retrieve the opcode
+  string opname = name;
+
+  // Opcode associated with the name
+  int opcode = val;
+}
+
+// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+def SPV_OC_OpNop                   : I32EnumAttrCase<"OpNop", 0>;
+def SPV_OC_OpName                  : I32EnumAttrCase<"OpName", 5>;
+def SPV_OC_OpMemoryModel           : I32EnumAttrCase<"OpMemoryModel", 14>;
+def SPV_OC_OpEntryPoint            : I32EnumAttrCase<"OpEntryPoint", 15>;
+def SPV_OC_OpExecutionMode         : I32EnumAttrCase<"OpExecutionMode", 16>;
+def SPV_OC_OpTypeVoid              : I32EnumAttrCase<"OpTypeVoid", 19>;
+def SPV_OC_OpTypeBool              : I32EnumAttrCase<"OpTypeBool", 20>;
+def SPV_OC_OpTypeInt               : I32EnumAttrCase<"OpTypeInt", 21>;
+def SPV_OC_OpTypeFloat             : I32EnumAttrCase<"OpTypeFloat", 22>;
+def SPV_OC_OpTypeVector            : I32EnumAttrCase<"OpTypeVector", 23>;
+def SPV_OC_OpTypeArray             : I32EnumAttrCase<"OpTypeArray", 28>;
+def SPV_OC_OpTypePointer           : I32EnumAttrCase<"OpTypePointer", 32>;
+def SPV_OC_OpTypeFunction          : I32EnumAttrCase<"OpTypeFunction", 33>;
+def SPV_OC_OpConstantTrue          : I32EnumAttrCase<"OpConstantTrue", 41>;
+def SPV_OC_OpConstantFalse         : I32EnumAttrCase<"OpConstantFalse", 42>;
+def SPV_OC_OpConstant              : I32EnumAttrCase<"OpConstant", 43>;
+def SPV_OC_OpConstantComposite     : I32EnumAttrCase<"OpConstantComposite", 44>;
+def SPV_OC_OpConstantNull          : I32EnumAttrCase<"OpConstantNull", 46>;
+def SPV_OC_OpSpecConstantTrue      : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
+def SPV_OC_OpSpecConstantFalse     : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
+def SPV_OC_OpSpecConstant          : I32EnumAttrCase<"OpSpecConstant", 50>;
+def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
+def SPV_OC_OpFunction              : I32EnumAttrCase<"OpFunction", 54>;
+def SPV_OC_OpFunctionParameter     : I32EnumAttrCase<"OpFunctionParameter", 55>;
+def SPV_OC_OpFunctionEnd           : I32EnumAttrCase<"OpFunctionEnd", 56>;
+def SPV_OC_OpVariable              : I32EnumAttrCase<"OpVariable", 59>;
+def SPV_OC_OpLoad                  : I32EnumAttrCase<"OpLoad", 61>;
+def SPV_OC_OpStore                 : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpAccessChain           : I32EnumAttrCase<"OpAccessChain", 65>;
+def SPV_OC_OpDecorate              : I32EnumAttrCase<"OpDecorate", 71>;
+def SPV_OC_OpCompositeExtract      : I32EnumAttrCase<"OpCompositeExtract", 81>;
+def SPV_OC_OpIAdd                  : I32EnumAttrCase<"OpIAdd", 128>;
+def SPV_OC_OpFAdd                  : I32EnumAttrCase<"OpFAdd", 129>;
+def SPV_OC_OpISub                  : I32EnumAttrCase<"OpISub", 130>;
+def SPV_OC_OpFSub                  : I32EnumAttrCase<"OpFSub", 131>;
+def SPV_OC_OpIMul                  : I32EnumAttrCase<"OpIMul", 132>;
+def SPV_OC_OpFMul                  : I32EnumAttrCase<"OpFMul", 133>;
+def SPV_OC_OpUDiv                  : I32EnumAttrCase<"OpUDiv", 134>;
+def SPV_OC_OpSDiv                  : I32EnumAttrCase<"OpSDiv", 135>;
+def SPV_OC_OpFDiv                  : I32EnumAttrCase<"OpFDiv", 136>;
+def SPV_OC_OpUMod                  : I32EnumAttrCase<"OpUMod", 137>;
+def SPV_OC_OpSRem                  : I32EnumAttrCase<"OpSRem", 138>;
+def SPV_OC_OpSMod                  : I32EnumAttrCase<"OpSMod", 139>;
+def SPV_OC_OpFRem                  : I32EnumAttrCase<"OpFRem", 140>;
+def SPV_OC_OpFMod                  : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpIEqual                : I32EnumAttrCase<"OpIEqual", 170>;
+def SPV_OC_OpINotEqual             : I32EnumAttrCase<"OpINotEqual", 171>;
+def SPV_OC_OpUGreaterThan          : I32EnumAttrCase<"OpUGreaterThan", 172>;
+def SPV_OC_OpSGreaterThan          : I32EnumAttrCase<"OpSGreaterThan", 173>;
+def SPV_OC_OpUGreaterThanEqual     : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
+def SPV_OC_OpSGreaterThanEqual     : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
+def SPV_OC_OpULessThan             : I32EnumAttrCase<"OpULessThan", 176>;
+def SPV_OC_OpSLessThan             : I32EnumAttrCase<"OpSLessThan", 177>;
+def SPV_OC_OpULessThanEqual        : I32EnumAttrCase<"OpULessThanEqual", 178>;
+def SPV_OC_OpSLessThanEqual        : I32EnumAttrCase<"OpSLessThanEqual", 179>;
+def SPV_OC_OpReturn                : I32EnumAttrCase<"OpReturn", 253>;
+
+def SPV_OpcodeAttr :
+    I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
+      SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
+      SPV_OC_OpExecutionMode, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
+      SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
+      SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
+      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
+      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
+      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
+      SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
+      SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
+      SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
+      SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
+      ]> {
+    let returnType = "::mlir::spirv::Opcode";
+    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
+    let cppNamespace = "::mlir::spirv";
+}
+
+// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+
+//===----------------------------------------------------------------------===//
+// SPIR-V type definitions
+//===----------------------------------------------------------------------===//
+
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
+def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
+def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
+
+// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
+// for the definition of the following types and type categories.
+
+def SPV_Void : TypeAlias<NoneType, "void type">;
+def SPV_Bool : IntOfWidths<[1]>;
+def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
+def SPV_Float : FloatOfWidths<[16, 32, 64]>;
+def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
+// Component type check is done in the type parser for the following SPIR-V
+// dialect-specific types so we use "Any" here.
+def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
+def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
+def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
+def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
+
+def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
+def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
+def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
+def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyStruct]>;
+def SPV_Type : AnyTypeOf<[
+    SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
+  ]>;
+
+class SPV_ScalarOrVectorOf<Type type> :
+    Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
+         "scalar/vector of " # type.description>;
+
+// TODO(antiagainst): Use a more appropriate way to model optional operands
+class SPV_Optional<Type type> : Variadic<type>;
+
+def SPV_IsEntryPointType :
+    CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">;
+def SPV_EntryPoint : Type<SPV_IsEntryPointType, "SPIR-V entry point type">;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V enum definitions
+//===----------------------------------------------------------------------===//
+
+// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+def SPV_AM_Logical                    : I32EnumAttrCase<"Logical", 0>;
+def SPV_AM_Physical32                 : I32EnumAttrCase<"Physical32", 1>;
+def SPV_AM_Physical64                 : I32EnumAttrCase<"Physical64", 2>;
+def SPV_AM_PhysicalStorageBuffer64EXT : I32EnumAttrCase<"PhysicalStorageBuffer64EXT", 5348>;
+
+def SPV_AddressingModelAttr :
+    I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
+      SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
+      SPV_AM_PhysicalStorageBuffer64EXT
+    ]> {
+  let returnType = "::mlir::spirv::AddressingModel";
+  let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_D_RelaxedPrecision            : I32EnumAttrCase<"RelaxedPrecision", 0>;
+def SPV_D_SpecId                      : I32EnumAttrCase<"SpecId", 1>;
+def SPV_D_Block                       : I32EnumAttrCase<"Block", 2>;
+def SPV_D_BufferBlock                 : I32EnumAttrCase<"BufferBlock", 3>;
+def SPV_D_RowMajor                    : I32EnumAttrCase<"RowMajor", 4>;
+def SPV_D_ColMajor                    : I32EnumAttrCase<"ColMajor", 5>;
+def SPV_D_ArrayStride                 : I32EnumAttrCase<"ArrayStride", 6>;
+def SPV_D_MatrixStride                : I32EnumAttrCase<"MatrixStride", 7>;
+def SPV_D_GLSLShared                  : I32EnumAttrCase<"GLSLShared", 8>;
+def SPV_D_GLSLPacked                  : I32EnumAttrCase<"GLSLPacked", 9>;
+def SPV_D_CPacked                     : I32EnumAttrCase<"CPacked", 10>;
+def SPV_D_BuiltIn                     : I32EnumAttrCase<"BuiltIn", 11>;
+def SPV_D_NoPerspective               : I32EnumAttrCase<"NoPerspective", 13>;
+def SPV_D_Flat                        : I32EnumAttrCase<"Flat", 14>;
+def SPV_D_Patch                       : I32EnumAttrCase<"Patch", 15>;
+def SPV_D_Centroid                    : I32EnumAttrCase<"Centroid", 16>;
+def SPV_D_Sample                      : I32EnumAttrCase<"Sample", 17>;
+def SPV_D_Invariant                   : I32EnumAttrCase<"Invariant", 18>;
+def SPV_D_Restrict                    : I32EnumAttrCase<"Restrict", 19>;
+def SPV_D_Aliased                     : I32EnumAttrCase<"Aliased", 20>;
+def SPV_D_Volatile                    : I32EnumAttrCase<"Volatile", 21>;
+def SPV_D_Constant                    : I32EnumAttrCase<"Constant", 22>;
+def SPV_D_Coherent                    : I32EnumAttrCase<"Coherent", 23>;
+def SPV_D_NonWritable                 : I32EnumAttrCase<"NonWritable", 24>;
+def SPV_D_NonReadable                 : I32EnumAttrCase<"NonReadable", 25>;
+def SPV_D_Uniform                     : I32EnumAttrCase<"Uniform", 26>;
+def SPV_D_UniformId                   : I32EnumAttrCase<"UniformId", 27>;
+def SPV_D_SaturatedConversion         : I32EnumAttrCase<"SaturatedConversion", 28>;
+def SPV_D_Stream                      : I32EnumAttrCase<"Stream", 29>;
+def SPV_D_Location                    : I32EnumAttrCase<"Location", 30>;
+def SPV_D_Component                   : I32EnumAttrCase<"Component", 31>;
+def SPV_D_Index                       : I32EnumAttrCase<"Index", 32>;
+def SPV_D_Binding                     : I32EnumAttrCase<"Binding", 33>;
+def SPV_D_DescriptorSet               : I32EnumAttrCase<"DescriptorSet", 34>;
+def SPV_D_Offset                      : I32EnumAttrCase<"Offset", 35>;
+def SPV_D_XfbBuffer                   : I32EnumAttrCase<"XfbBuffer", 36>;
+def SPV_D_XfbStride                   : I32EnumAttrCase<"XfbStride", 37>;
+def SPV_D_FuncParamAttr               : I32EnumAttrCase<"FuncParamAttr", 38>;
+def SPV_D_FPRoundingMode              : I32EnumAttrCase<"FPRoundingMode", 39>;
+def SPV_D_FPFastMathMode              : I32EnumAttrCase<"FPFastMathMode", 40>;
+def SPV_D_LinkageAttributes           : I32EnumAttrCase<"LinkageAttributes", 41>;
+def SPV_D_NoContraction               : I32EnumAttrCase<"NoContraction", 42>;
+def SPV_D_InputAttachmentIndex        : I32EnumAttrCase<"InputAttachmentIndex", 43>;
+def SPV_D_Alignment                   : I32EnumAttrCase<"Alignment", 44>;
+def SPV_D_MaxByteOffset               : I32EnumAttrCase<"MaxByteOffset", 45>;
+def SPV_D_AlignmentId                 : I32EnumAttrCase<"AlignmentId", 46>;
+def SPV_D_MaxByteOffsetId             : I32EnumAttrCase<"MaxByteOffsetId", 47>;
+def SPV_D_NoSignedWrap                : I32EnumAttrCase<"NoSignedWrap", 4469>;
+def SPV_D_NoUnsignedWrap              : I32EnumAttrCase<"NoUnsignedWrap", 4470>;
+def SPV_D_ExplicitInterpAMD           : I32EnumAttrCase<"ExplicitInterpAMD", 4999>;
+def SPV_D_OverrideCoverageNV          : I32EnumAttrCase<"OverrideCoverageNV", 5248>;
+def SPV_D_PassthroughNV               : I32EnumAttrCase<"PassthroughNV", 5250>;
+def SPV_D_ViewportRelativeNV          : I32EnumAttrCase<"ViewportRelativeNV", 5252>;
+def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>;
+def SPV_D_PerPrimitiveNV              : I32EnumAttrCase<"PerPrimitiveNV", 5271>;
+def SPV_D_PerViewNV                   : I32EnumAttrCase<"PerViewNV", 5272>;
+def SPV_D_PerTaskNV                   : I32EnumAttrCase<"PerTaskNV", 5273>;
+def SPV_D_PerVertexNV                 : I32EnumAttrCase<"PerVertexNV", 5285>;
+def SPV_D_NonUniformEXT               : I32EnumAttrCase<"NonUniformEXT", 5300>;
+def SPV_D_RestrictPointerEXT          : I32EnumAttrCase<"RestrictPointerEXT", 5355>;
+def SPV_D_AliasedPointerEXT           : I32EnumAttrCase<"AliasedPointerEXT", 5356>;
+def SPV_D_CounterBuffer               : I32EnumAttrCase<"CounterBuffer", 5634>;
+def SPV_D_UserSemantic                : I32EnumAttrCase<"UserSemantic", 5635>;
+def SPV_D_UserTypeGOOGLE              : I32EnumAttrCase<"UserTypeGOOGLE", 5636>;
+
+def SPV_DecorationAttr :
+    I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
+      SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
+      SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
+      SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
+      SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample,
+      SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant,
+      SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform,
+      SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location,
+      SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset,
+      SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode,
+      SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction,
+      SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset,
+      SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap,
+      SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV,
+      SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV,
+      SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV,
+      SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT,
+      SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer,
+      SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE
+    ]> {
+  let returnType = "::mlir::spirv::Decoration";
+  let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_D_1D          : I32EnumAttrCase<"1D", 0>;
+def SPV_D_2D          : I32EnumAttrCase<"2D", 1>;
+def SPV_D_3D          : I32EnumAttrCase<"3D", 2>;
+def SPV_D_Cube        : I32EnumAttrCase<"Cube", 3>;
+def SPV_D_Rect        : I32EnumAttrCase<"Rect", 4>;
+def SPV_D_Buffer      : I32EnumAttrCase<"Buffer", 5>;
+def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6>;
+
+def SPV_DimAttr :
+    I32EnumAttr<"Dim", "valid SPIR-V Dim", [
+      SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
+      SPV_D_SubpassData
+    ]> {
+  let returnType = "::mlir::spirv::Dim";
+  let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_EM_Invocations                      : I32EnumAttrCase<"Invocations", 0>;
+def SPV_EM_SpacingEqual                     : I32EnumAttrCase<"SpacingEqual", 1>;
+def SPV_EM_SpacingFractionalEven            : I32EnumAttrCase<"SpacingFractionalEven", 2>;
+def SPV_EM_SpacingFractionalOdd             : I32EnumAttrCase<"SpacingFractionalOdd", 3>;
+def SPV_EM_VertexOrderCw                    : I32EnumAttrCase<"VertexOrderCw", 4>;
+def SPV_EM_VertexOrderCcw                   : I32EnumAttrCase<"VertexOrderCcw", 5>;
+def SPV_EM_PixelCenterInteger               : I32EnumAttrCase<"PixelCenterInteger", 6>;
+def SPV_EM_OriginUpperLeft                  : I32EnumAttrCase<"OriginUpperLeft", 7>;
+def SPV_EM_OriginLowerLeft                  : I32EnumAttrCase<"OriginLowerLeft", 8>;
+def SPV_EM_EarlyFragmentTests               : I32EnumAttrCase<"EarlyFragmentTests", 9>;
+def SPV_EM_PointMode                        : I32EnumAttrCase<"PointMode", 10>;
+def SPV_EM_Xfb                              : I32EnumAttrCase<"Xfb", 11>;
+def SPV_EM_DepthReplacing                   : I32EnumAttrCase<"DepthReplacing", 12>;
+def SPV_EM_DepthGreater                     : I32EnumAttrCase<"DepthGreater", 14>;
+def SPV_EM_DepthLess                        : I32EnumAttrCase<"DepthLess", 15>;
+def SPV_EM_DepthUnchanged                   : I32EnumAttrCase<"DepthUnchanged", 16>;
+def SPV_EM_LocalSize                        : I32EnumAttrCase<"LocalSize", 17>;
+def SPV_EM_LocalSizeHint                    : I32EnumAttrCase<"LocalSizeHint", 18>;
+def SPV_EM_InputPoints                      : I32EnumAttrCase<"InputPoints", 19>;
+def SPV_EM_InputLines                       : I32EnumAttrCase<"InputLines", 20>;
+def SPV_EM_InputLinesAdjacency              : I32EnumAttrCase<"InputLinesAdjacency", 21>;
+def SPV_EM_Triangles                        : I32EnumAttrCase<"Triangles", 22>;
+def SPV_EM_InputTrianglesAdjacency          : I32EnumAttrCase<"InputTrianglesAdjacency", 23>;
+def SPV_EM_Quads                            : I32EnumAttrCase<"Quads", 24>;
+def SPV_EM_Isolines                         : I32EnumAttrCase<"Isolines", 25>;
+def SPV_EM_OutputVertices                   : I32EnumAttrCase<"OutputVertices", 26>;
+def SPV_EM_OutputPoints                     : I32EnumAttrCase<"OutputPoints", 27>;
+def SPV_EM_OutputLineStrip                  : I32EnumAttrCase<"OutputLineStrip", 28>;
+def SPV_EM_OutputTriangleStrip              : I32EnumAttrCase<"OutputTriangleStrip", 29>;
+def SPV_EM_VecTypeHint                      : I32EnumAttrCase<"VecTypeHint", 30>;
+def SPV_EM_ContractionOff                   : I32EnumAttrCase<"ContractionOff", 31>;
+def SPV_EM_Initializer                      : I32EnumAttrCase<"Initializer", 33>;
+def SPV_EM_Finalizer                        : I32EnumAttrCase<"Finalizer", 34>;
+def SPV_EM_SubgroupSize                     : I32EnumAttrCase<"SubgroupSize", 35>;
+def SPV_EM_SubgroupsPerWorkgroup            : I32EnumAttrCase<"SubgroupsPerWorkgroup", 36>;
+def SPV_EM_SubgroupsPerWorkgroupId          : I32EnumAttrCase<"SubgroupsPerWorkgroupId", 37>;
+def SPV_EM_LocalSizeId                      : I32EnumAttrCase<"LocalSizeId", 38>;
+def SPV_EM_LocalSizeHintId                  : I32EnumAttrCase<"LocalSizeHintId", 39>;
+def SPV_EM_PostDepthCoverage                : I32EnumAttrCase<"PostDepthCoverage", 4446>;
+def SPV_EM_DenormPreserve                   : I32EnumAttrCase<"DenormPreserve", 4459>;
+def SPV_EM_DenormFlushToZero                : I32EnumAttrCase<"DenormFlushToZero", 4460>;
+def SPV_EM_SignedZeroInfNanPreserve         : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4461>;
+def SPV_EM_RoundingModeRTE                  : I32EnumAttrCase<"RoundingModeRTE", 4462>;
+def SPV_EM_RoundingModeRTZ                  : I32EnumAttrCase<"RoundingModeRTZ", 4463>;
+def SPV_EM_StencilRefReplacingEXT           : I32EnumAttrCase<"StencilRefReplacingEXT", 5027>;
+def SPV_EM_OutputLinesNV                    : I32EnumAttrCase<"OutputLinesNV", 5269>;
+def SPV_EM_OutputPrimitivesNV               : I32EnumAttrCase<"OutputPrimitivesNV", 5270>;
+def SPV_EM_DerivativeGroupQuadsNV           : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289>;
+def SPV_EM_DerivativeGroupLinearNV          : I32EnumAttrCase<"DerivativeGroupLinearNV", 5290>;
+def SPV_EM_OutputTrianglesNV                : I32EnumAttrCase<"OutputTrianglesNV", 5298>;
+def SPV_EM_PixelInterlockOrderedEXT         : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366>;
+def SPV_EM_PixelInterlockUnorderedEXT       : I32EnumAttrCase<"PixelInterlockUnorderedEXT", 5367>;
+def SPV_EM_SampleInterlockOrderedEXT        : I32EnumAttrCase<"SampleInterlockOrderedEXT", 5368>;
+def SPV_EM_SampleInterlockUnorderedEXT      : I32EnumAttrCase<"SampleInterlockUnorderedEXT", 5369>;
+def SPV_EM_ShadingRateInterlockOrderedEXT   : I32EnumAttrCase<"ShadingRateInterlockOrderedEXT", 5370>;
+def SPV_EM_ShadingRateInterlockUnorderedEXT : I32EnumAttrCase<"ShadingRateInterlockUnorderedEXT", 5371>;
+
+def SPV_ExecutionModeAttr :
+    I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [
+      SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven,
+      SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw,
+      SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft,
+      SPV_EM_EarlyFragmentTests, SPV_EM_PointMode, SPV_EM_Xfb, SPV_EM_DepthReplacing,
+      SPV_EM_DepthGreater, SPV_EM_DepthLess, SPV_EM_DepthUnchanged, SPV_EM_LocalSize,
+      SPV_EM_LocalSizeHint, SPV_EM_InputPoints, SPV_EM_InputLines,
+      SPV_EM_InputLinesAdjacency, SPV_EM_Triangles, SPV_EM_InputTrianglesAdjacency,
+      SPV_EM_Quads, SPV_EM_Isolines, SPV_EM_OutputVertices, SPV_EM_OutputPoints,
+      SPV_EM_OutputLineStrip, SPV_EM_OutputTriangleStrip, SPV_EM_VecTypeHint,
+      SPV_EM_ContractionOff, SPV_EM_Initializer, SPV_EM_Finalizer,
+      SPV_EM_SubgroupSize, SPV_EM_SubgroupsPerWorkgroup,
+      SPV_EM_SubgroupsPerWorkgroupId, SPV_EM_LocalSizeId, SPV_EM_LocalSizeHintId,
+      SPV_EM_PostDepthCoverage, SPV_EM_DenormPreserve, SPV_EM_DenormFlushToZero,
+      SPV_EM_SignedZeroInfNanPreserve, SPV_EM_RoundingModeRTE,
+      SPV_EM_RoundingModeRTZ, SPV_EM_StencilRefReplacingEXT, SPV_EM_OutputLinesNV,
+      SPV_EM_OutputPrimitivesNV, SPV_EM_DerivativeGroupQuadsNV,
+      SPV_EM_DerivativeGroupLinearNV, SPV_EM_OutputTrianglesNV,
+      SPV_EM_PixelInterlockOrderedEXT, SPV_EM_PixelInterlockUnorderedEXT,
+      SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT,
+      SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT
+    ]> {
+  let returnType = "::mlir::spirv::ExecutionMode";
+  let convertFromStorage = "static_cast<::mlir::spirv::ExecutionMode>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_EM_Vertex                 : I32EnumAttrCase<"Vertex", 0>;
+def SPV_EM_TessellationControl    : I32EnumAttrCase<"TessellationControl", 1>;
+def SPV_EM_TessellationEvaluation : I32EnumAttrCase<"TessellationEvaluation", 2>;
+def SPV_EM_Geometry               : I32EnumAttrCase<"Geometry", 3>;
+def SPV_EM_Fragment               : I32EnumAttrCase<"Fragment", 4>;
+def SPV_EM_GLCompute              : I32EnumAttrCase<"GLCompute", 5>;
+def SPV_EM_Kernel                 : I32EnumAttrCase<"Kernel", 6>;
+def SPV_EM_TaskNV                 : I32EnumAttrCase<"TaskNV", 5267>;
+def SPV_EM_MeshNV                 : I32EnumAttrCase<"MeshNV", 5268>;
+def SPV_EM_RayGenerationNV        : I32EnumAttrCase<"RayGenerationNV", 5313>;
+def SPV_EM_IntersectionNV         : I32EnumAttrCase<"IntersectionNV", 5314>;
+def SPV_EM_AnyHitNV               : I32EnumAttrCase<"AnyHitNV", 5315>;
+def SPV_EM_ClosestHitNV           : I32EnumAttrCase<"ClosestHitNV", 5316>;
+def SPV_EM_MissNV                 : I32EnumAttrCase<"MissNV", 5317>;
+def SPV_EM_CallableNV             : I32EnumAttrCase<"CallableNV", 5318>;
+
+def SPV_ExecutionModelAttr :
+    I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [
+      SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation,
+      SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel,
+      SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV,
+      SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV
+    ]> {
+  let returnType = "::mlir::spirv::ExecutionModel";
+  let convertFromStorage = "static_cast<::mlir::spirv::ExecutionModel>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_FC_None       : I32EnumAttrCase<"None", 0x0000>;
+def SPV_FC_Inline     : I32EnumAttrCase<"Inline", 0x0001>;
+def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>;
+def SPV_FC_Pure       : I32EnumAttrCase<"Pure", 0x0004>;
+def SPV_FC_Const      : I32EnumAttrCase<"Const", 0x0008>;
+
+def SPV_FunctionControlAttr :
+    I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
+      SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
+    ]> {
+  let returnType = "::mlir::spirv::FunctionControl";
+  let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_IF_Unknown      : I32EnumAttrCase<"Unknown", 0>;
+def SPV_IF_Rgba32f      : I32EnumAttrCase<"Rgba32f", 1>;
+def SPV_IF_Rgba16f      : I32EnumAttrCase<"Rgba16f", 2>;
+def SPV_IF_R32f         : I32EnumAttrCase<"R32f", 3>;
+def SPV_IF_Rgba8        : I32EnumAttrCase<"Rgba8", 4>;
+def SPV_IF_Rgba8Snorm   : I32EnumAttrCase<"Rgba8Snorm", 5>;
+def SPV_IF_Rg32f        : I32EnumAttrCase<"Rg32f", 6>;
+def SPV_IF_Rg16f        : I32EnumAttrCase<"Rg16f", 7>;
+def SPV_IF_R11fG11fB10f : I32EnumAttrCase<"R11fG11fB10f", 8>;
+def SPV_IF_R16f         : I32EnumAttrCase<"R16f", 9>;
+def SPV_IF_Rgba16       : I32EnumAttrCase<"Rgba16", 10>;
+def SPV_IF_Rgb10A2      : I32EnumAttrCase<"Rgb10A2", 11>;
+def SPV_IF_Rg16         : I32EnumAttrCase<"Rg16", 12>;
+def SPV_IF_Rg8          : I32EnumAttrCase<"Rg8", 13>;
+def SPV_IF_R16          : I32EnumAttrCase<"R16", 14>;
+def SPV_IF_R8           : I32EnumAttrCase<"R8", 15>;
+def SPV_IF_Rgba16Snorm  : I32EnumAttrCase<"Rgba16Snorm", 16>;
+def SPV_IF_Rg16Snorm    : I32EnumAttrCase<"Rg16Snorm", 17>;
+def SPV_IF_Rg8Snorm     : I32EnumAttrCase<"Rg8Snorm", 18>;
+def SPV_IF_R16Snorm     : I32EnumAttrCase<"R16Snorm", 19>;
+def SPV_IF_R8Snorm      : I32EnumAttrCase<"R8Snorm", 20>;
+def SPV_IF_Rgba32i      : I32EnumAttrCase<"Rgba32i", 21>;
+def SPV_IF_Rgba16i      : I32EnumAttrCase<"Rgba16i", 22>;
+def SPV_IF_Rgba8i       : I32EnumAttrCase<"Rgba8i", 23>;
+def SPV_IF_R32i         : I32EnumAttrCase<"R32i", 24>;
+def SPV_IF_Rg32i        : I32EnumAttrCase<"Rg32i", 25>;
+def SPV_IF_Rg16i        : I32EnumAttrCase<"Rg16i", 26>;
+def SPV_IF_Rg8i         : I32EnumAttrCase<"Rg8i", 27>;
+def SPV_IF_R16i         : I32EnumAttrCase<"R16i", 28>;
+def SPV_IF_R8i          : I32EnumAttrCase<"R8i", 29>;
+def SPV_IF_Rgba32ui     : I32EnumAttrCase<"Rgba32ui", 30>;
+def SPV_IF_Rgba16ui     : I32EnumAttrCase<"Rgba16ui", 31>;
+def SPV_IF_Rgba8ui      : I32EnumAttrCase<"Rgba8ui", 32>;
+def SPV_IF_R32ui        : I32EnumAttrCase<"R32ui", 33>;
+def SPV_IF_Rgb10a2ui    : I32EnumAttrCase<"Rgb10a2ui", 34>;
+def SPV_IF_Rg32ui       : I32EnumAttrCase<"Rg32ui", 35>;
+def SPV_IF_Rg16ui       : I32EnumAttrCase<"Rg16ui", 36>;
+def SPV_IF_Rg8ui        : I32EnumAttrCase<"Rg8ui", 37>;
+def SPV_IF_R16ui        : I32EnumAttrCase<"R16ui", 38>;
+def SPV_IF_R8ui         : I32EnumAttrCase<"R8ui", 39>;
+
+def SPV_ImageFormatAttr :
+    I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
+      SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
+      SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
+      SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
+      SPV_IF_R16, SPV_IF_R8, SPV_IF_Rgba16Snorm, SPV_IF_Rg16Snorm, SPV_IF_Rg8Snorm,
+      SPV_IF_R16Snorm, SPV_IF_R8Snorm, SPV_IF_Rgba32i, SPV_IF_Rgba16i, SPV_IF_Rgba8i,
+      SPV_IF_R32i, SPV_IF_Rg32i, SPV_IF_Rg16i, SPV_IF_Rg8i, SPV_IF_R16i, SPV_IF_R8i,
+      SPV_IF_Rgba32ui, SPV_IF_Rgba16ui, SPV_IF_Rgba8ui, SPV_IF_R32ui,
+      SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui,
+      SPV_IF_R8ui
+    ]> {
+  let returnType = "::mlir::spirv::ImageFormat";
+  let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_LT_Export : I32EnumAttrCase<"Export", 0>;
+def SPV_LT_Import : I32EnumAttrCase<"Import", 1>;
+
+def SPV_LinkageTypeAttr :
+    I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
+      SPV_LT_Export, SPV_LT_Import
+    ]> {
+  let returnType = "::mlir::spirv::LinkageType";
+  let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_MA_None                    : I32EnumAttrCase<"None", 0x0000>;
+def SPV_MA_Volatile                : I32EnumAttrCase<"Volatile", 0x0001>;
+def SPV_MA_Aligned                 : I32EnumAttrCase<"Aligned", 0x0002>;
+def SPV_MA_Nontemporal             : I32EnumAttrCase<"Nontemporal", 0x0004>;
+def SPV_MA_MakePointerAvailableKHR : I32EnumAttrCase<"MakePointerAvailableKHR", 0x0008>;
+def SPV_MA_MakePointerVisibleKHR   : I32EnumAttrCase<"MakePointerVisibleKHR", 0x0010>;
+def SPV_MA_NonPrivatePointerKHR    : I32EnumAttrCase<"NonPrivatePointerKHR", 0x0020>;
+
+def SPV_MemoryAccessAttr :
+    I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
+      SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
+      SPV_MA_MakePointerAvailableKHR, SPV_MA_MakePointerVisibleKHR,
+      SPV_MA_NonPrivatePointerKHR
+    ]> {
+  let returnType = "::mlir::spirv::MemoryAccess";
+  let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_MM_Simple    : I32EnumAttrCase<"Simple", 0>;
+def SPV_MM_GLSL450   : I32EnumAttrCase<"GLSL450", 1>;
+def SPV_MM_OpenCL    : I32EnumAttrCase<"OpenCL", 2>;
+def SPV_MM_VulkanKHR : I32EnumAttrCase<"VulkanKHR", 3>;
+
+def SPV_MemoryModelAttr :
+    I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
+      SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_VulkanKHR
+    ]> {
+  let returnType = "::mlir::spirv::MemoryModel";
+  let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_SC_UniformConstant          : I32EnumAttrCase<"UniformConstant", 0>;
+def SPV_SC_Input                    : I32EnumAttrCase<"Input", 1>;
+def SPV_SC_Uniform                  : I32EnumAttrCase<"Uniform", 2>;
+def SPV_SC_Output                   : I32EnumAttrCase<"Output", 3>;
+def SPV_SC_Workgroup                : I32EnumAttrCase<"Workgroup", 4>;
+def SPV_SC_CrossWorkgroup           : I32EnumAttrCase<"CrossWorkgroup", 5>;
+def SPV_SC_Private                  : I32EnumAttrCase<"Private", 6>;
+def SPV_SC_Function                 : I32EnumAttrCase<"Function", 7>;
+def SPV_SC_Generic                  : I32EnumAttrCase<"Generic", 8>;
+def SPV_SC_PushConstant             : I32EnumAttrCase<"PushConstant", 9>;
+def SPV_SC_AtomicCounter            : I32EnumAttrCase<"AtomicCounter", 10>;
+def SPV_SC_Image                    : I32EnumAttrCase<"Image", 11>;
+def SPV_SC_StorageBuffer            : I32EnumAttrCase<"StorageBuffer", 12>;
+def SPV_SC_CallableDataNV           : I32EnumAttrCase<"CallableDataNV", 5328>;
+def SPV_SC_IncomingCallableDataNV   : I32EnumAttrCase<"IncomingCallableDataNV", 5329>;
+def SPV_SC_RayPayloadNV             : I32EnumAttrCase<"RayPayloadNV", 5338>;
+def SPV_SC_HitAttributeNV           : I32EnumAttrCase<"HitAttributeNV", 5339>;
+def SPV_SC_IncomingRayPayloadNV     : I32EnumAttrCase<"IncomingRayPayloadNV", 5342>;
+def SPV_SC_ShaderRecordBufferNV     : I32EnumAttrCase<"ShaderRecordBufferNV", 5343>;
+def SPV_SC_PhysicalStorageBufferEXT : I32EnumAttrCase<"PhysicalStorageBufferEXT", 5349>;
+
+def SPV_StorageClassAttr :
+    I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
+      SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
+      SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
+      SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
+      SPV_SC_StorageBuffer, SPV_SC_CallableDataNV, SPV_SC_IncomingCallableDataNV,
+      SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV,
+      SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBufferEXT
+    ]> {
+  let returnType = "::mlir::spirv::StorageClass";
+  let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
+// End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+// Enums added manually that are not part of SPIRV spec
+
+def SPV_IDI_NoDepth      : I32EnumAttrCase<"NoDepth", 0>;
+def SPV_IDI_IsDepth      : I32EnumAttrCase<"IsDepth", 1>;
+def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
+
+def SPV_DepthAttr :
+    I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
+      [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
+def SPV_IAI_Arrayed    : I32EnumAttrCase<"Arrayed", 1>;
+
+def SPV_ArrayedAttr :
+    I32EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
+      [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
+def SPV_ISI_MultiSampled  : I32EnumAttrCase<"MultiSampled", 1>;
+
+def SPV_SamplingAttr:
+    I32EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
+      [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
+def SPV_ISUI_NeedSampler    : I32EnumAttrCase<"NeedSampler", 1>;
+def SPV_ISUI_NoSampler      : I32EnumAttrCase<"NoSampler", 2>;
+
+def SPV_SamplerUseAttr:
+    I32EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
+      [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V OpTrait definitions
+//===----------------------------------------------------------------------===//
+
+// Check that an op can only be used with SPIR-V ModuleOp
+def IsModuleOnlyPred :
+  CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">;
+
+def ModuleOnly :
+  PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for all SPIR-V ops.
+class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<SPV_Dialect, mnemonic, traits> {
+
+  // For each SPIR-V op, the following static functions need to be defined
+  // in SPVOps.cpp:
+  //
+  // * static ParseResult parse<op-c++-class-name>(OpAsmParser *parser,
+  //                                               OperationState *result)
+  // * static void print(OpAsmPrinter *p, <op-c++-class-name> op)
+  // * static LogicalResult verify(<op-c++-class-name> op)
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+  let printer = [{ return ::print(*this, p); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  // Specifies whether this op has a direct corresponding SPIR-V binary
+  // instruction opcode. The (de)serializer use this field to determine whether
+  // to auto-generate an entry in the (de)serialization dispatch table for this
+  // op. If set, this field also futher enables `autogenSerialization` (see
+  // below for details).
+  bit hasOpcode = 1;
+
+  // Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1.
+  string spirvOpName = "Op" # mnemonic;
+
+  // Controls whether to auto-generate this op's (de)serialization method.
+  // If set, it results in generation of the following methods:
+  //
+  // ```c++
+  // template<typename OpTy> Serializer::processOp(OpTy op);
+  // template<typename OpTy> Deserializer::processOp(ArrayRef<uint32_t>);
+  // ```
+  //
+  // If this field is not set, then manual implementation of a specialization of
+  // these methods is required.
+  //
+  // Note:
+  //
+  // 1) If hasOpcode is set but autogenSerialization is not set, the
+  //    (de)serializer dispatch method still calls the above method for
+  //    (de)serializing this op.
+  //
+  // 2) If hasOpcode is not set, then this field is not interpreted; this op's
+  //    (de)serialization method will not be auto-generated regardless. Neither
+  //    does the handling in the (de)serialization dispatch table. Both
+  //    (de)serializing this op and its dispatch should be handled manually.
+  bit autogenSerialization = 1;
+}
+
+class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
+                   list<OpTrait> traits = []> :
+      SPV_Op<mnemonic, traits> {
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<operandsType>:$operand1,
+    SPV_ScalarOrVectorOf<operandsType>:$operand2
+  );
+  let results = (outs
+    SPV_ScalarOrVectorOf<resultType>:$result
+  );
+  let parser = [{ return impl::parseBinaryOp(parser, result); }];
+  let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
+  // No additional verification needed in addition to the ODS-generated ones.
+  let verifier = [{ return success(); }];
+}
+
+class SPV_ArithmeticOp<string mnemonic, Type type,
+                       list<OpTrait> traits = []> :
+      // Operands type same as result type.
+      SPV_BinaryOp<mnemonic, type, type,
+                   !listconcat(traits,
+                               [NoSideEffect, SameOperandsAndResultType])> {
+}
+
+class SPV_LogicalOp<string mnemonic, Type operandsType,
+                    list<OpTrait> traits = []> :
+      // Result type is SPV_Bool.
+      SPV_BinaryOp<mnemonic, SPV_Bool, operandsType,
+                   !listconcat(traits,
+                               [NoSideEffect, SameTypeOperands,
+                                SameOperandsAndResultShape])> {
+  let parser = [{ return ::parseBinaryLogicalOp(parser, result); }];
+  let printer = [{ return ::printBinaryLogicalOp(getOperation(), p); }];
+}
+
+
+#endif // SPIRV_BASE
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h
new file mode 100644
index 0000000..f255446
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h
@@ -0,0 +1,52 @@
+//===- SPIRVBinaryUtils.cpp - SPIR-V Binary Module Utils --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares common utilities for SPIR-V binary module.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_
+#define MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include <cstdint>
+
+namespace mlir {
+namespace spirv {
+
+/// SPIR-V binary header word count
+constexpr unsigned kHeaderWordCount = 5;
+
+/// SPIR-V magic number
+constexpr uint32_t kMagicNumber = 0x07230203;
+
+/// The serializer tool ID registered to the Khronos Group
+constexpr uint32_t kGeneratorNumber = 22;
+
+/// Auto-generated getOpcode<*Op>() specializations
+#define GET_SPIRV_SERIALIZATION_UTILS
+#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
+
+/// Appends a SPRI-V module header to `header` with the given `idBound`.
+void appendModuleHeader(SmallVectorImpl<uint32_t> &header, uint32_t idBound);
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
new file mode 100644
index 0000000..494adc1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
@@ -0,0 +1,49 @@
+//===- SPIRVDialect.h - MLIR SPIR-V dialect ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the SPIR-V dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
+#define MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace spirv {
+
+class SPIRVDialect : public Dialect {
+public:
+  explicit SPIRVDialect(MLIRContext *context);
+
+  static StringRef getDialectNamespace() { return "spv"; }
+
+  /// Parses a type registered to this dialect.
+  Type parseType(llvm::StringRef spec, Location loc) const override;
+
+  /// Prints a type registered to this dialect.
+  void printType(Type type, llvm::raw_ostream &os) const override;
+
+  /// Checks if a type is valid in SPIR-V dialect.
+  bool isValidSPIRVType(Type t) const;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
new file mode 100644
index 0000000..104a479
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -0,0 +1,48 @@
+//===- SPIRVOps.h - MLIR SPIR-V operations ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
+#define MLIR_DIALECT_SPIRV_SPIRVOPS_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Function.h"
+
+namespace mlir {
+namespace spirv {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
+
+/// Following methods are auto-generated.
+///
+/// Get the name used in the Op to refer to an enum value of the given
+/// `EnumClass`.
+/// template <typename EnumClass> StringRef attributeName();
+///
+/// Get the function that can be used to symbolize an enum value.
+/// template <typename EnumClass>
+/// llvm::Optional<EnumClass> (*)(StringRef) symbolizeEnum();
+#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
new file mode 100644
index 0000000..b833da5
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -0,0 +1,1224 @@
+//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the main operation definition specification file for SPIR-V
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+// Note that for each op in this file, we use a tool to automatically generate
+// certain sections in its definition: basic structure, summary, description.
+// So modifications to these sections will not be respected. Modifications to
+// op traits, arguments, results, and sections after the results are retained.
+// Besides, ops in this file must be separated via the '// -----' marker.
+
+#ifdef SPIRV_OPS
+#else
+#define SPIRV_OPS
+
+#ifdef SPIRV_BASE
+#else
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+#endif // SPIRV_BASE
+
+#ifdef SPIRV_STRUCTURE_OPS
+#else
+// Pull in ops for defining the SPIR-V module structure
+include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
+#endif // SPIRV_STRUCTURE_OPS
+
+// -----
+
+def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
+  let summary = [{
+    Create a pointer into a composite object that can be used with OpLoad
+    and OpStore.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypePointer. Its Type operand must be the type
+    reached by walking the Base’s type hierarchy down to the last provided
+    index in Indexes, and its Storage Class operand must be the same as the
+    Storage Class of Base.
+
+    Base must be a pointer, pointing to the base of a composite object.
+
+    Indexes walk the type hierarchy to the desired depth, potentially down
+    to scalar granularity. The first index in Indexes will select the top-
+    level member/element/component/element of the base composite. All
+    composite constituents use zero-based numbering, as described by their
+    OpType… instruction. The second index will apply similarly to that
+    result, and so on. Once any non-composite type is reached, there must be
+    no remaining (unused) indexes.
+
+     Each index in Indexes
+
+    - must be a scalar integer type,
+
+    - is treated as a signed count, and
+
+    - must be an OpConstant when indexing into a structure.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    access-chain-op ::= ssa-id `=` `spv.AccessChain` ssa-use
+                        `[` ssa-use (',' ssa-use)* `]`
+                        `:` pointer-type
+    ```
+
+    For example:
+
+    ```
+    %0 = "spv.constant"() { value = 1: i32} : () -> i32
+    %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+    %2 = spv.AccessChain %1[%0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+    %3 = spv.Load "Function" %2 ["Volatile"] : !spv.array<4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$base_ptr,
+    Variadic<SPV_Integer>:$indices
+  );
+
+  let results = (outs
+    SPV_AnyPtr:$component_ptr
+  );
+}
+
+// -----
+
+def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
+  let summary = "Extract a part of a composite object.";
+
+  let description = [{
+    Result Type must be the type of object selected by the last provided
+    index.  The instruction result is the extracted object.
+
+    Composite is the composite to extract from.
+
+    Indexes walk the type hierarchy, potentially down to component
+    granularity, to select the part to extract. All indexes must be in
+    bounds.  All composite constituents use zero-based numbering, as
+    described by their OpType… instruction.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use
+                             `[` integer-literal (',' integer-literal)* `]`
+                             `:` composite-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+    %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
+    %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>>
+    ```
+
+  }];
+
+  let arguments = (ins
+    SPV_Composite:$composite,
+    I32ArrayAttr:$indices
+  );
+
+  let results = (outs
+    SPV_Type:$component
+  );
+}
+
+// -----
+
+def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
+  let summary = [{
+    Declare an entry point, its execution model, and its interface.
+  }];
+
+  let description = [{
+    Execution Model is the execution model for the entry point and its
+    static call tree. See Execution Model.
+
+    Entry Point must be the Result <id> of an OpFunction instruction.
+
+    Name is a name string for the entry point. A module cannot have two
+    OpEntryPoint instructions with the same Execution Model and the same
+    Name string.
+
+    Interface is a list of <id> of global OpVariable instructions. These
+    declare the set of global variables from a module that form the
+    interface of this entry point. The set of Interface <id> must be equal
+    to or a superset of the global OpVariable Result <id> referenced by the
+    entry point’s static call tree, within the interface’s storage classes.
+    Before version 1.4, the interface’s storage classes are limited to the
+    Input and Output storage classes. Starting with version 1.4, the
+    interface’s storage classes are all storage classes used in declaring
+    all global variables referenced by the entry point’s call tree.
+
+    Interface <id> are forward references. Before version 1.4, duplication
+    of these <id> is tolerated. Starting with version 1.4, an <id> must not
+    appear more than once.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    execution-model ::= "Vertex" | "TesellationControl" |
+                        <and other SPIR-V execution models...>
+
+    entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name
+                       (ssa-use ( `, ` ssa-use)* ` : `
+                        pointer-type ( `, ` pointer-type)* )?
+    ```
+
+    For example:
+
+    ```
+    spv.EntryPoint "GLCompute" @foo
+    spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
+
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ExecutionModelAttr:$execution_model,
+    SymbolRefAttr:$fn,
+    Variadic<SPV_AnyPtr>:$interface
+  );
+
+  let results = (outs);
+  let autogenSerialization = 0;
+}
+
+// -----
+
+def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
+  let summary = "Declare an execution mode for an entry point.";
+
+  let description = [{
+    Entry Point must be the Entry Point <id> operand of an OpEntryPoint
+    instruction.
+
+    Mode is the execution mode. See Execution Mode.
+
+    This instruction is only valid when the Mode operand is an execution
+    mode that takes no Extra Operands, or takes Extra Operands that are not
+    <id> operands.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    execution-mode ::= "Invocations" | "SpacingEqual" |
+                       <and other SPIR-V execution modes...>
+
+    execution-mode-op ::= `spv.ExecutionMode ` ssa-use execution-mode
+                          (integer-literal (`, ` integer-literal)* )?
+    ```
+
+    For example:
+
+    ```
+    spv.ExecutionMode @foo "ContractionOff"
+    spv.ExecutionMode @bar "LocalSizeHint", 3, 4, 5
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolRefAttr:$fn,
+    SPV_ExecutionModeAttr:$execution_mode,
+    OptionalAttr<I32ArrayAttr>:$values
+  );
+
+  let results = (outs);
+
+  let verifier = [{ return success(); }];
+
+  let autogenSerialization = 0;
+}
+
+// -----
+
+def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> {
+  let summary = "Floating-point addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fadd-op ::= ssa-id `=` `spv.FAdd` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FAdd %0, %1 : f32
+    %5 = spv.FAdd %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> {
+  let summary = "Floating-point division of Operand 1 divided by Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fdiv-op ::= ssa-id `=` `spv.FDiv` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FDiv %0, %1 : f32
+    %5 = spv.FDiv %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 2.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmod-op ::= ssa-id `=` `spv.FMod` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FMod %0, %1 : f32
+    %5 = spv.FMod %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FMulOp : SPV_ArithmeticOp<"FMul", SPV_Float, [Commutative]> {
+  let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmul-op ::= `spv.FMul` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FMul %0, %1 : f32
+    %5 = spv.FMul %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 1.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 1.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    frem-op ::= ssa-id `=` `spv.FRemOp` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FRemOp %0, %1 : f32
+    %5 = spv.FRemOp %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float> {
+  let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fsub-op ::= ssa-id `=` `spv.FRemOp` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FRemOp %0, %1 : f32
+    %5 = spv.FRemOp %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IAddOp : SPV_ArithmeticOp<"IAdd", SPV_Integer, [Commutative]> {
+  let summary = "Integer addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    iadd-op ::= ssa-id `=` `spv.IAdd` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.IAdd %0, %1 : i32
+    %5 = spv.IAdd %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IEqualOp : SPV_LogicalOp<"IEqual", SPV_Integer, [Commutative]> {
+  let summary = "Integer comparison for equality.";
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    iequal-op ::= ssa-id `=` `spv.IEqual` ssa-use, ssa-use
+                             `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.IEqual %0, %1 : i32
+    %5 = spv.IEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_INotEqualOp : SPV_LogicalOp<"INotEqual", SPV_Integer, [Commutative]> {
+  let summary = "Integer comparison for inequality.";
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    inot-equal-op ::= ssa-id `=` `spv.INotEqual` ssa-use, ssa-use
+                                 `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.INotEqual %0, %1 : i32
+    %5 = spv.INotEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IMulOp : SPV_ArithmeticOp<"IMul", SPV_Integer, [Commutative]> {
+  let summary = "Integer multiplication of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    imul-op ::= ssa-id `=` `spv.IMul` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.IMul %0, %1 : i32
+    %5 = spv.IMul %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> {
+  let summary = "Integer subtraction of Operand 2 from Operand 1.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    isub-op ::= `spv.ISub` ssa-use, ssa-use
+                          `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.ISub %0, %1 : i32
+    %5 = spv.ISub %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_LoadOp : SPV_Op<"Load", []> {
+  let summary = "Load through a pointer.";
+
+  let description = [{
+    Result Type is the type of the loaded object. It must be a type with
+    fixed size; i.e., it cannot be, nor include, any OpTypeRuntimeArray
+    types.
+
+    Pointer is the pointer to load through.  Its type must be an
+    OpTypePointer whose Type operand is the same as Result Type.
+
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory operand
+    None.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal
+                    | `"NonTemporal"`
+
+    load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use
+                (`[` memory-access `]`)? ` : ` spirv-element-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Load "Function" %0 : f32
+    %2 = spv.Load "Function" %0 ["Volatile"] : f32
+    %3 = spv.Load "Function" %0 ["Aligned", 4] : f32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs
+    SPV_Type:$value
+  );
+}
+
+// -----
+
+def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
+  let summary = "Return with no value from a function with void return type.";
+
+  let description = [{
+    This instruction must be the last instruction in a block.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    return-op ::= `spv.Return`
+    ```
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+
+  let verifier = [{ return verifyReturn(*this); }];
+}
+
+// -----
+
+def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> {
+  let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    sdiv-op ::= ssa-id `=` `spv.SDiv` ssa-use, ssa-use
+                           `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.SDiv %0, %1 : i32
+    %5 = spv.SDiv %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SGreaterThanOp : SPV_LogicalOp<"SGreaterThan", SPV_Integer, []> {
+  let summary = [{
+    Signed-integer comparison if Operand 1 is greater than  Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    sgreater-than-op ::= ssa-id `=` `spv.SGreaterThan` ssa-use, ssa-use
+                                    `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SGreaterThan %0, %1 : i32
+    %5 = spv.SGreaterThan %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SGreaterThanEqualOp : SPV_LogicalOp<"SGreaterThanEqual", SPV_Integer, []> {
+  let summary = [{
+    Signed-integer comparison if Operand 1 is greater than or equal to
+    Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    sgreater-than-equal-op ::= ssa-id `=` `spv.SGreaterThanEqual` ssa-use, ssa-use
+                                          `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SGreaterThanEqual %0, %1 : i32
+    %5 = spv.SGreaterThanEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SLessThanOp : SPV_LogicalOp<"SLessThan", SPV_Integer, []> {
+  let summary = [{
+    Signed-integer comparison if Operand 1 is less than Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    sless-than-op ::= ssa-id `=` `spv.SLessThan` ssa-use, ssa-use
+                                 `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SLessThan %0, %1 : i32
+    %5 = spv.SLessThan %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> {
+  let summary = [{
+    Signed-integer comparison if Operand 1 is less than or equal to Operand
+    2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    sless-than-equal-op ::= ssa-id `=` `spv.SLessThanEqual` ssa-use, ssa-use
+                                       `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SLessThanEqual %0, %1 : i32
+    %5 = spv.SLessThanEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer> {
+  let summary = [{
+    Signed remainder operation for the remainder whose sign matches the sign
+    of Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 2.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    smod-op ::= ssa-id `=` `spv.SMod` ssa-use, ssa-use
+                           `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SMod %0, %1 : i32
+    %5 = spv.SMod %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer> {
+  let summary = [{
+    Signed remainder operation for the remainder whose sign matches the sign
+    of Operand 1.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.  Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 1.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    srem-op ::= ssa-id `=` `spv.SRem` ssa-use, ssa-use
+                           `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.SRem %0, %1 : i32
+    %5 = spv.SRem %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_StoreOp : SPV_Op<"Store", []> {
+  let summary = "Store through a pointer.";
+
+  let description = [{
+    Pointer is the pointer to store through.  Its type must be an
+    OpTypePointer whose Type operand is the same as the type of Object.
+
+    Object is the object to store.
+
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory operand
+    None.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, `
+                  (`[` memory-access `]`)? `:` spirv-element-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.FMul ... : f32
+    spv.Store "Function" %0, %1 : f32
+    spv.Store "Function" %0, %1 ["Volatile"] : f32
+    spv.Store "Function" %0, %1 ["Aligned", 4] : f32
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    SPV_Type:$value,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs);
+}
+
+// -----
+
+def SPV_UDivOp : SPV_ArithmeticOp<"UDiv", SPV_Integer> {
+  let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type, whose Signedness
+    operand is 0.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    udiv-op ::= ssa-id `=` `spv.UDiv` ssa-use, ssa-use
+                           `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.UDiv %0, %1 : i32
+    %5 = spv.UDiv %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> {
+  let summary = [{
+    Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    ugreater-than-op ::= ssa-id `=` `spv.UGreaterThan` ssa-use, ssa-use
+                                    `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.UGreaterhan %0, %1 : i32
+    %5 = spv.UGreaterThan %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_UGreaterThanEqualOp
+    : SPV_LogicalOp<"UGreaterThanEqual", SPV_Integer, []> {
+  let summary = [{
+    Unsigned-integer comparison if Operand 1 is greater than or equal to
+    Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    ugreater-than-equal-op ::= ssa-id `=` `spv.UGreaterThanEqual` ssa-use, ssa-use
+                                          `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.UGreaterThanEqual %0, %1 : i32
+    %5 = spv.UGreaterThanEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_ULessThanOp : SPV_LogicalOp<"ULessThan", SPV_Integer, []> {
+  let summary = [{
+    Unsigned-integer comparison if Operand 1 is less than Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    uless-than-op ::= ssa-id `=` `spv.ULessThan` ssa-use, ssa-use
+                                 `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.ULessThan %0, %1 : i32
+    %5 = spv.ULessThan %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_ULessThanEqualOp : SPV_LogicalOp<"ULessThanEqual", SPV_Integer, []> {
+  let summary = [{
+    Unsigned-integer comparison if Operand 1 is less than or equal to
+    Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 and Operand 2  must be a scalar or vector of
+    integer type.  They must have the same component width, and they must
+    have the same number of components as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    uless-than-equal-op ::= ssa-id `=` `spv.ULessThanEqual` ssa-use, ssa-use
+                                       `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.ULessThanEqual %0, %1 : i32
+    %5 = spv.ULessThanEqual %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_UModOp : SPV_ArithmeticOp<"UMod", SPV_Integer> {
+  let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type, whose Signedness
+    operand is 0.
+
+     The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+     Results are computed per component.  The resulting value is undefined
+    if Operand 2 is 0.
+
+    ### Custom assembly form
+    ``` {.ebnf}
+    integer-scalar-vector-type ::= integer-type |
+                                 `vector<` integer-literal `x` integer-type `>`
+    umod-op ::= ssa-id `=` `spv.UMod` ssa-use, ssa-use
+                           `:` integer-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.UMod %0, %1 : i32
+    %5 = spv.UMod %2, %3 : vector<4xi32>
+
+    ```
+  }];
+}
+
+// -----
+
+def SPV_VariableOp : SPV_Op<"Variable", []> {
+  let summary = [{
+    Allocate an object in memory, resulting in a pointer to it, which can be
+    used with OpLoad and OpStore.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypePointer. Its Type operand is the type of
+    object in memory.
+
+    Storage Class is the Storage Class of the memory holding the object. It
+    cannot be Generic. It must be the same as the Storage Class operand of
+    the Result Type.
+
+    Initializer is optional.  If Initializer is present, it will be the
+    initial value of the variable’s memory content. Initializer must be an
+    <id> from a constant instruction or a global (module scope) OpVariable
+    instruction. Initializer must have the same type as the type pointed to
+    by Result Type.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)?
+                    (`bind(` integer-literal, integer-literal `)`)?
+                    attribute-dict? `:` spirv-pointer-type
+    ```
+
+    where `init` specifies initializer and `bind` specifies the descriptor set
+    and binding number.
+
+    For example:
+
+    ```
+    %0 = spv.constant ...
+
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    %2 = spv.Variable init(%0): !spv.ptr<f32, Private>
+    %3 = spv.Variable init(%0) bind(1, 2): !spv.ptr<f32, Uniform>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_StorageClassAttr:$storage_class,
+    SPV_Optional<AnyType>:$initializer
+  );
+
+  let results = (outs
+    SPV_AnyPtr:$pointer
+  );
+}
+
+// -----
+
+#endif // SPIRV_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
new file mode 100644
index 0000000..b44d8ef
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -0,0 +1,184 @@
+//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains ops for defining the SPIR-V structure: module, function,
+// and module-level operations. The representational form of these ops deviate
+// from the SPIR-V binary format in order to utilize MLIR mechanisms.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef SPIRV_STRUCTURE_OPS
+#else
+#define SPIRV_STRUCTURE_OPS
+
+#ifdef SPIRV_BASE
+#else
+include "mlir/SPIRV/SPIRVBase.td"
+#endif // SPIRV_BASE
+
+def SPV_ModuleOp : SPV_Op<"module",
+                          [SingleBlockImplicitTerminator<"ModuleEndOp">,
+                           NativeOpTrait<"SymbolTable">]> {
+  let summary = "The top-level op that defines a SPIR-V module";
+
+  let description = [{
+    This op defines a SPIR-V module using a MLIR region. The region contains
+    one block. Module-level operations, including functions definitions,
+    are all placed in this block.
+
+    Using an op with a region to define a SPIR-V module enables "embedding"
+    SPIR-V modules in other dialects in a clean manner: this op guarantees
+    the validaty and serializability of a SPIR-V module and thus serves as
+    a clear-cut boundary.
+
+    This op takes no operands and generates no results. This op should not
+    implicitly capture values from the enclosing environment.
+
+    This op has only one region, which only contains one block. The block
+    must be terminated via the `spv._module_end` op.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    addressing-model ::= `"Logical"` | `"Physical32"` | `"Physical64"`
+    memory-model ::= `"Simple"` | `"GLSL450"` | `"OpenCL"` | `"VulkanKHR"`
+    spv-module-op ::= `spv.module` addressing-model memory-model
+                      region
+                      (`attributes` attribute-dict)?
+    ```
+
+    For example:
+
+    ```
+    spv.module "Logical" "VulkanKHR" { }
+
+    spv.module "Logical" "VulkanKHR" {
+      func @do_nothing() -> () {
+        spv.Return
+      }
+    } attributes {
+      capability = ["Shader"],
+      extension = ["SPV_KHR_16bit_storage"]
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AddressingModelAttr:$addressing_model,
+    SPV_MemoryModelAttr:$memory_model,
+    OptionalAttr<StrArrayAttr>:$capabilities,
+    OptionalAttr<StrArrayAttr>:$extensions,
+    OptionalAttr<StrArrayAttr>:$extended_instruction_sets
+  );
+
+  let results = (outs);
+
+  let regions = (region SizedRegion<1>:$body);
+
+  let builders = [OpBuilder<"Builder *, OperationState *state">,
+                  OpBuilder<[{Builder *, OperationState *state,
+                              IntegerAttr addressing_model,
+                              IntegerAttr memory_model,
+                              /*optional*/ArrayAttr capabilities = nullptr,
+                              /*optional*/ArrayAttr extensions = nullptr,
+                              /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
+
+  // We need to ensure the block inside the region is properly terminated;
+  // the auto-generated builders do not guarantee that.
+  let skipDefaultBuilders = 1;
+
+  let hasOpcode = 0;
+
+  let extraClassDeclaration = [{
+    Block& getBlock() {
+      return this->getOperation()->getRegion(0).front();
+    }
+  }];
+}
+
+def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
+  let summary = "The pseudo op that ends a SPIR-V module";
+
+  let description = [{
+    This op terminates the only block inside a `spv.module`'s only region.
+    This op does not have a corresponding SPIR-V instruction and thus will
+    not be serialized into the binary format; it is used solely to satisfy
+    the structual requirement that an block must be ended with a terminator.
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+
+  let verifier = [{ return success(); }];
+
+  let hasOpcode = 0;
+}
+
+def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
+  let summary = "The op that declares a SPIR-V constant";
+
+  let description = [{
+    This op declares a SPIR-V constant. SPIR-V has multiple constant
+    instructions covering different constant types:
+
+    * `OpConstantTrue` and `OpConstantFalse` for boolean constants
+    * `OpConstant` for scalar constants
+    * `OpConstantComposite` for composite constants
+    * `OpConstantNull` for null constants
+    * ...
+
+    Having such a plethora of constant instructions renders IR transformations
+    more tedious. Therefore, we use a single `spv.constant` op to represent
+    them all. Note that conversion between those SPIR-V constant instructions
+    and this op is purely mechanical; so it can be scoped to the binary
+    (de)serialzation process.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    spv-constant-op ::= ssa-id `=` `spv.constant` (`spec`)? attribute-value
+                        (`:` spirv-type)?
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.constant spec true
+    %1 = spv.constant dense<[2, 3]> : vector<2xf32>
+    %2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
+    ```
+
+    TODO(antiagainst): support constant structs
+  }];
+
+  let arguments = (ins
+    AnyAttr:$value,
+    UnitAttr:$is_spec_const
+  );
+
+  let results = (outs
+    SPV_Type:$constant
+  );
+
+  let hasOpcode = 0;
+}
+
+#endif // SPIRV_STRUCTURE_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
new file mode 100644
index 0000000..264fed3
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -0,0 +1,185 @@
+//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
+#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+// Pull in all enum type definitions and utility function declarations
+#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
+
+#include <tuple>
+
+namespace mlir {
+namespace spirv {
+
+namespace detail {
+struct ArrayTypeStorage;
+struct ImageTypeStorage;
+struct PointerTypeStorage;
+struct RuntimeArrayTypeStorage;
+struct StructTypeStorage;
+} // namespace detail
+
+namespace TypeKind {
+enum Kind {
+  Array = Type::FIRST_SPIRV_TYPE,
+  Image,
+  Pointer,
+  RuntimeArray,
+  Struct,
+};
+}
+
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+class CompositeType : public Type {
+public:
+  using Type::Type;
+
+  static bool classof(Type type) {
+    return (type.getKind() == TypeKind::Array ||
+            type.getKind() == TypeKind::Struct ||
+            type.getKind() == StandardTypes::Vector);
+  }
+
+  unsigned getNumElements() const;
+
+  Type getElementType(unsigned) const;
+};
+
+// SPIR-V array type
+class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
+                                        detail::ArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
+
+  static ArrayType get(Type elementType, unsigned elementCount);
+
+  unsigned getNumElements() const;
+
+  Type getElementType() const;
+};
+
+// SPIR-V image type
+class ImageType
+    : public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
+
+  static ImageType
+  get(Type elementType, Dim dim,
+      ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
+      ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
+      ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
+      ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
+      ImageFormat format = ImageFormat::Unknown) {
+    return ImageType::get(
+        std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+                   ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
+            elementType, dim, depth, arrayed, samplingInfo, samplerUse,
+            format));
+  }
+
+  static ImageType
+      get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+                     ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
+
+  Type getElementType() const;
+  Dim getDim() const;
+  ImageDepthInfo getDepthInfo() const;
+  ImageArrayedInfo getArrayedInfo() const;
+  ImageSamplingInfo getSamplingInfo() const;
+  ImageSamplerUseInfo getSamplerUseInfo() const;
+  ImageFormat getImageFormat() const;
+  // TODO(ravishankarm): Add support for Access qualifier
+};
+
+// SPIR-V pointer type
+class PointerType
+    : public Type::TypeBase<PointerType, Type, detail::PointerTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; }
+
+  static PointerType get(Type pointeeType, StorageClass storageClass);
+
+  Type getPointeeType() const;
+
+  StorageClass getStorageClass() const;
+};
+
+// SPIR-V run-time array type
+class RuntimeArrayType
+    : public Type::TypeBase<RuntimeArrayType, Type,
+                            detail::RuntimeArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; }
+
+  static RuntimeArrayType get(Type elementType);
+
+  Type getElementType() const;
+};
+
+// SPIR-V struct type
+class StructType : public Type::TypeBase<StructType, CompositeType,
+                                         detail::StructTypeStorage> {
+public:
+  using Base::Base;
+
+  // Layout information used for members in a struct in SPIR-V
+  //
+  // TODO(ravishankarm) : For now this only supports the offset type, so uses
+  // uint64_t value to represent the offset, with
+  // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+  // something that can hold all the information needed for different member
+  // types
+  using LayoutInfo = uint64_t;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
+
+  static StructType get(ArrayRef<Type> memberTypes);
+
+  static StructType get(ArrayRef<Type> memberTypes,
+                        ArrayRef<LayoutInfo> layoutInfo);
+
+  unsigned getNumElements() const;
+
+  Type getElementType(unsigned) const;
+
+  bool hasLayout() const;
+
+  uint64_t getOffset(unsigned) const;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h
new file mode 100644
index 0000000..bfc9062
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h
@@ -0,0 +1,49 @@
+//===- Serialization.h - MLIR SPIR-V (De)serialization ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the entry points for serialize and deserialze SPIR-V
+// binary modules.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SERIALIZATION_H_
+#define MLIR_DIALECT_SPIRV_SERIALIZATION_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+
+namespace spirv {
+class ModuleOp;
+
+/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
+/// reports errors to the error handler registered with the MLIR context for
+/// `module`.
+LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary);
+
+/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
+/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
+/// errors to the error handler registered with `context` and returns
+/// llvm::None.
+Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SERIALIZATION_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Traits.h b/third_party/mlir/include/mlir/Dialect/Traits.h
new file mode 100644
index 0000000..8bb5e4b
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/Traits.h
@@ -0,0 +1,89 @@
+//===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares common op traits that are not core to MLIR but can be
+// shared by multiple dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRAITS
+#define MLIR_DIALECT_TRAITS
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+
+// These functions are out-of-line implementations of the methods in the
+// corresponding trait classes.  This avoids them being template
+// instantiated/duplicated.
+namespace impl {
+LogicalResult verifyCompatibleOperandBroadcast(Operation *op);
+} // namespace impl
+
+namespace util {
+/// Returns true and sets `resultShape` to the broadcasted shape from the two
+/// given shapes if they are broadcast compatible. Returns false and clears
+/// `resultShape` otherwise.
+///
+/// The rules for determing the result shape are:
+///
+/// Zip together the dimensions in the two given shapes by prepending the shape
+/// with less dimensions with 1s. For each dimension pair, deduces the result
+/// dimension according to the following order:
+/// - If there are unknown dimensions, follows the TensorFlow behavior:
+///   - If either dimension is greater than 1, we assume that the program is
+///     correct, and the other dimension will be broadcast to match it.
+///   - If either dimension is 1, the other dimension is the result.
+///   - Otherwise, the result dimension is unknown dimension.
+/// - If one of the dimension is 1, the other dimension is the result.
+/// - If two dimensions are the same, that's the result.
+/// - Otherwise, incompatible shape.
+bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
+                         SmallVectorImpl<int64_t> &resultShape);
+
+/// Returns the result broadcast composition type from the two given types by
+/// following NumPy broadcast semantics. Returned type may have dynamic shape if
+/// either of the input types has dynamic shape. Returns null type if the two
+/// given types are not broadcast-compatible.
+Type getBroadcastedType(Type type1, Type type2);
+} // namespace util
+
+/// This class provides the API for ops that are known to have broadcast-
+/// compatible operand and result types. Specifically,  starting from the
+/// most varying dimension, each dimension pair of the two operands' types
+/// should either be the same or one of them is one. Also, the result type
+/// should have the corresponding dimension equal to the larger one, if known.
+/// Shapes are checked partially if ranks or dimensions are not known. For
+/// example, an op with tensor<? x 2 x f32> and tensor <2 x f32> as operand
+/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
+///
+/// Ths trait assumes the op has two operands and one result, and it asserts
+/// if the pre-condition is not satisfied.
+template <typename ConcreteType>
+class BroadcastableTwoOperandsOneResult
+    : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyCompatibleOperandBroadcast(op);
+  }
+};
+
+} // end namespace OpTrait
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_TRAITS
diff --git a/third_party/mlir/include/mlir/EDSC/Builders.h b/third_party/mlir/include/mlir/EDSC/Builders.h
new file mode 100644
index 0000000..c1df3cf
--- /dev/null
+++ b/third_party/mlir/include/mlir/EDSC/Builders.h
@@ -0,0 +1,500 @@
+//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides intuitive composable interfaces for building structured MLIR
+// snippets in a declarative fashion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EDSC_BUILDERS_H_
+#define MLIR_EDSC_BUILDERS_H_
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+namespace mlir {
+
+namespace edsc {
+
+struct index_t {
+  explicit index_t(int64_t v) : v(v) {}
+  explicit operator int64_t() { return v; }
+  int64_t v;
+};
+
+class BlockHandle;
+class CapturableHandle;
+class NestedBuilder;
+class ValueHandle;
+
+/// Helper class to transparently handle builder insertion points by RAII.
+/// As its name indicates, a ScopedContext is means to be used locally in a
+/// scoped fashion. This abstracts away all the boilerplate related to
+/// checking proper usage of captures, NestedBuilders as well as handling the
+/// setting and restoring of insertion points.
+class ScopedContext {
+public:
+  ScopedContext(OpBuilder &builder, Location location);
+
+  /// Sets the insertion point of the builder to 'newInsertPt' for the duration
+  /// of the scope. The existing insertion point of the builder is restored on
+  /// destruction.
+  ScopedContext(OpBuilder &builder, OpBuilder::InsertPoint newInsertPt,
+                Location location);
+  ~ScopedContext();
+
+  static MLIRContext *getContext();
+  static OpBuilder &getBuilder();
+  static Location getLocation();
+
+private:
+  /// Only NestedBuilder (which is used to create an operation with a body)
+  /// may access private members in order to implement scoping.
+  friend class NestedBuilder;
+
+  ScopedContext() = delete;
+  ScopedContext(const ScopedContext &) = delete;
+  ScopedContext &operator=(const ScopedContext &) = delete;
+
+  static ScopedContext *&getCurrentScopedContext();
+
+  /// Top level OpBuilder.
+  OpBuilder &builder;
+  /// The previous insertion point of the builder.
+  llvm::Optional<OpBuilder::InsertPoint> prevBuilderInsertPoint;
+  /// Current location.
+  Location location;
+  /// Parent context we return into.
+  ScopedContext *enclosingScopedContext;
+  /// Defensively keeps track of the current NestedBuilder to ensure proper
+  /// scoping usage.
+  NestedBuilder *nestedBuilder;
+
+  // TODO: Implement scoping of ValueHandles. To do this we need a proper data
+  // structure to hold ValueHandle objects. We can emulate one but there should
+  // already be something available in LLVM for this purpose.
+};
+
+/// A NestedBuilder is a scoping abstraction to create an idiomatic syntax
+/// embedded in C++ that serves the purpose of building nested MLIR.
+/// Nesting and compositionality is obtained by using the strict ordering that
+/// exists between object construction and method invocation on said object (in
+/// our case, the call to `operator()`).
+/// This ordering allows implementing an abstraction that decouples definition
+/// from declaration (in a PL sense) on placeholders of type ValueHandle and
+/// BlockHandle.
+class NestedBuilder {
+protected:
+  NestedBuilder() = default;
+  NestedBuilder(const NestedBuilder &) = delete;
+  NestedBuilder(NestedBuilder &&other) : bodyScope(other.bodyScope) {
+    other.bodyScope = nullptr;
+  }
+
+  NestedBuilder &operator=(const NestedBuilder &) = delete;
+  NestedBuilder &operator=(NestedBuilder &&other) {
+    std::swap(bodyScope, other.bodyScope);
+    return *this;
+  }
+
+  /// Enter an mlir::Block and setup a ScopedContext to insert operations at
+  /// the end of it. Since we cannot use c++ language-level scoping to implement
+  /// scoping itself, we use enter/exit pairs of operations.
+  /// As a consequence we must allocate a new OpBuilder + ScopedContext and
+  /// let the escape.
+  /// Step back "prev" times from the end of the block to set up the insertion
+  /// point, which is useful for non-empty blocks.
+  void enter(mlir::Block *block, int prev = 0) {
+    bodyScope = new ScopedContext(
+        ScopedContext::getBuilder(),
+        OpBuilder::InsertPoint(block, std::prev(block->end(), prev)),
+        ScopedContext::getLocation());
+    bodyScope->nestedBuilder = this;
+  }
+
+  /// Exit the current mlir::Block by explicitly deleting the dynamically
+  /// allocated OpBuilder and ScopedContext.
+  void exit() {
+    // Reclaim now to exit the scope.
+    bodyScope->nestedBuilder = nullptr;
+    delete bodyScope;
+    bodyScope = nullptr;
+  }
+
+  /// Custom destructor does nothing because we already destroyed bodyScope
+  /// manually in `exit`. Insert an assertion to defensively guard against
+  /// improper usage of scoping.
+  ~NestedBuilder() {
+    assert(!bodyScope &&
+           "Illegal use of NestedBuilder; must have called exit()");
+  }
+
+private:
+  ScopedContext *bodyScope = nullptr;
+};
+
+/// A LoopBuilder is a generic NestedBuilder for loop-like MLIR operations.
+/// More specifically it is meant to be used as a temporary object for
+/// representing any nested MLIR construct that is "related to" an mlir::Value*
+/// (for now an induction variable).
+/// This is extensible and will evolve in the future as MLIR evolves, hence
+/// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder).
+class LoopBuilder : public NestedBuilder {
+public:
+  /// Constructs a new AffineForOp and captures the associated induction
+  /// variable. A ValueHandle pointer is passed as the first argument and is the
+  /// *only* way to capture the loop induction variable.
+  LoopBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
+              ArrayRef<ValueHandle> ubHandles, int64_t step);
+  LoopBuilder(const LoopBuilder &) = delete;
+  LoopBuilder(LoopBuilder &&) = default;
+
+  LoopBuilder &operator=(const LoopBuilder &) = delete;
+  LoopBuilder &operator=(LoopBuilder &&) = default;
+
+  /// The only purpose of this operator is to serve as a sequence point so that
+  /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
+  /// scoped within a LoopBuilder.
+  ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
+};
+
+/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
+/// explicitly writing all the loops in a nest. This simple functionality is
+/// also useful to write rank-agnostic custom ops.
+///
+/// Usage:
+///
+/// ```c++
+///    LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})(
+///      [&](){
+///        ...
+///      });
+/// ```
+///
+/// ```c++
+///    LoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){
+///      LoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){
+///        LoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){
+///          ...
+///        }),
+///      }),
+///    });
+/// ```
+class LoopNestBuilder {
+public:
+  LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
+                  ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
+
+  ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
+
+private:
+  SmallVector<LoopBuilder, 4> loops;
+};
+
+// This class exists solely to handle the C++ vexing parse case when
+// trying to enter a Block that has already been constructed.
+class Append {};
+
+/// A BlockBuilder is a NestedBuilder for mlir::Block*.
+/// This exists by opposition to LoopBuilder which is not related to an
+/// mlir::Block* but to a mlir::Value*.
+/// It is meant to be used as a temporary object for representing any nested
+/// MLIR construct that is "related to" an mlir::Block*.
+class BlockBuilder : public NestedBuilder {
+public:
+  /// Enters the mlir::Block* previously captured by `bh` and sets the insertion
+  /// point to its end.
+  BlockBuilder(BlockHandle bh, Append);
+
+  /// Constructs a new mlir::Block with argument types derived from `args`.
+  /// Captures the new block in `bh` and its arguments into `args`.
+  /// Enters the new mlir::Block* and sets the insertion point to its end.
+  ///
+  /// Prerequisites:
+  ///   The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
+  ///   not yet bound to mlir::Value*.
+  BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
+
+  /// The only purpose of this operator is to serve as a sequence point so that
+  /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
+  /// scoped within a BlockBuilder.
+  void operator()(llvm::function_ref<void(void)> fun = nullptr);
+
+private:
+  BlockBuilder(BlockBuilder &) = delete;
+  BlockBuilder &operator=(BlockBuilder &other) = delete;
+};
+
+/// Base class for ValueHandle, OperationHandle and BlockHandle.
+/// Not meant to be used outside of these classes.
+class CapturableHandle {
+protected:
+  CapturableHandle() = default;
+};
+
+/// ValueHandle implements a (potentially "delayed") typed Value abstraction.
+/// ValueHandle should be captured by pointer but otherwise passed by Value
+/// everywhere.
+/// A ValueHandle can have 3 states:
+///   1. null state (empty type and empty value), in which case it does not hold
+///      a value and must never hold a Value (now or in the future). This is
+///      used for MLIR operations with zero returns as well as the result of
+///      calling a NestedBuilder::operator(). In both cases the objective is to
+///      have an object that can be inserted in an ArrayRef<ValueHandle> to
+///      implement nesting;
+///   2. delayed state (empty value), in which case it represents an eagerly
+///      typed "delayed" value that can be hold a Value in the future;
+///   3. constructed state,in which case it holds a Value.
+///
+/// A ValueHandle is meant to capture a single Value* and should be used for
+/// operations that have a single result. For convenience of use, we also
+/// include AffineForOp in this category although it does not return a value.
+/// In the case of AffineForOp, the captured Value* is the loop induction
+/// variable.
+class ValueHandle : public CapturableHandle {
+public:
+  /// A ValueHandle in a null state can never be captured;
+  static ValueHandle null() { return ValueHandle(); }
+
+  /// A ValueHandle that is constructed from a Type represents a typed "delayed"
+  /// Value. A delayed Value can only capture Values of the specified type.
+  /// Such a delayed value represents the declaration (in the PL sense) of a
+  /// placeholder for an mlir::Value* that will be constructed and captured at
+  /// some later point in the program.
+  explicit ValueHandle(Type t) : t(t), v(nullptr) {}
+
+  /// A ValueHandle that is constructed from an mlir::Value* is an "eager"
+  /// Value. An eager Value represents both the declaration and the definition
+  /// (in the PL sense) of a placeholder for an mlir::Value* that has already
+  /// been constructed in the past and that is captured "now" in the program.
+  explicit ValueHandle(Value *v) : t(v->getType()), v(v) {}
+
+  /// Builds a ConstantIndexOp of value `cst`. The constant is created at the
+  /// current insertion point.
+  /// This implicit constructor is provided to each build an eager Value for a
+  /// constant at the current insertion point in the IR. An implicit constructor
+  /// allows idiomatic expressions mixing ValueHandle and literals.
+  ValueHandle(index_t cst);
+
+  /// ValueHandle is a value type, use the default copy constructor.
+  ValueHandle(const ValueHandle &other) = default;
+
+  /// ValueHandle is a value type, the assignment operator typechecks before
+  /// assigning.
+  ValueHandle &operator=(const ValueHandle &other);
+
+  /// Provide a swap operator.
+  void swap(ValueHandle &other) {
+    if (this == &other)
+      return;
+    std::swap(t, other.t);
+    std::swap(v, other.v);
+  }
+
+  /// Implicit conversion useful for automatic conversion to Container<Value*>.
+  operator Value *() const { return getValue(); }
+
+  /// Generic mlir::Op create. This is the key to being extensible to the whole
+  /// of MLIR without duplicating the type system or the op definitions.
+  template <typename Op, typename... Args>
+  static ValueHandle create(Args... args);
+
+  /// Generic mlir::Op create. This is the key to being extensible to the whole
+  /// of MLIR without duplicating the type system or the op definitions.
+  template <typename Op, typename... Args>
+  static ValueHandle create(OperationFolder &folder, Args... args);
+
+  /// Special case to build composed AffineApply operations.
+  // TODO: createOrFold when available and move inside of the `create` method.
+  static ValueHandle createComposedAffineApply(AffineMap map,
+                                               ArrayRef<Value *> operands);
+
+  /// Generic create for a named operation producing a single value.
+  static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
+                            ArrayRef<Type> resultTypes,
+                            ArrayRef<NamedAttribute> attributes = {});
+
+  bool hasValue() const { return v != nullptr; }
+  Value *getValue() const {
+    assert(hasValue() && "Unexpected null value;");
+    return v;
+  }
+  bool hasType() const { return t != Type(); }
+  Type getType() const { return t; }
+
+  Operation *getOperation() const {
+    if (!v)
+      return nullptr;
+    return v->getDefiningOp();
+  }
+
+protected:
+  ValueHandle() : t(), v(nullptr) {}
+
+  Type t;
+  Value *v;
+};
+
+/// An OperationHandle can be used in lieu of ValueHandle to capture the
+/// operation in cases when one does not care about, or cannot extract, a
+/// unique Value* from the operation.
+/// This can be used for capturing zero result operations as well as
+/// multi-result operations that are not supported by ValueHandle.
+/// We do not distinguish further between zero and multi-result operations at
+/// this time.
+struct OperationHandle : public CapturableHandle {
+  OperationHandle() : op(nullptr) {}
+  OperationHandle(Operation *op) : op(op) {}
+
+  OperationHandle(const OperationHandle &) = default;
+  OperationHandle &operator=(const OperationHandle &) = default;
+
+  /// Generic mlir::Op create. This is the key to being extensible to the whole
+  /// of MLIR without duplicating the type system or the op definitions.
+  template <typename Op, typename... Args>
+  static OperationHandle create(Args... args);
+  template <typename Op, typename... Args> static Op createOp(Args... args);
+
+  /// Generic create for a named operation.
+  static OperationHandle create(StringRef name, ArrayRef<ValueHandle> operands,
+                                ArrayRef<Type> resultTypes,
+                                ArrayRef<NamedAttribute> attributes = {});
+
+  operator Operation *() { return op; }
+  Operation *getOperation() const { return op; }
+
+private:
+  Operation *op;
+};
+
+/// Simple wrapper to build a generic operation without successor blocks.
+template <typename HandleType> struct CustomOperation {
+  CustomOperation(StringRef name) : name(name) {
+    static_assert(std::is_same<HandleType, ValueHandle>() ||
+                      std::is_same<HandleType, OperationHandle>(),
+                  "Only CustomOperation<ValueHandle> or "
+                  "CustomOperation<OperationHandle> can be constructed.");
+  }
+  HandleType operator()(ArrayRef<ValueHandle> operands = {},
+                        ArrayRef<Type> resultTypes = {},
+                        ArrayRef<NamedAttribute> attributes = {}) {
+    return HandleType::create(name, operands, resultTypes, attributes);
+  }
+  std::string name;
+};
+
+/// A BlockHandle represents a (potentially "delayed") Block abstraction.
+/// This extra abstraction is necessary because an mlir::Block is not an
+/// mlir::Value.
+/// A BlockHandle should be captured by pointer but otherwise passed by Value
+/// everywhere.
+class BlockHandle : public CapturableHandle {
+public:
+  /// A BlockHandle constructed without an mlir::Block* represents a "delayed"
+  /// Block. A delayed Block represents the declaration (in the PL sense) of a
+  /// placeholder for an mlir::Block* that will be constructed and captured at
+  /// some later point in the program.
+  BlockHandle() : block(nullptr) {}
+
+  /// A BlockHandle constructed with an mlir::Block* represents an "eager"
+  /// Block. An eager Block represents both the declaration and the definition
+  /// (in the PL sense) of a placeholder for an mlir::Block* that has already
+  /// been constructed in the past and that is captured "now" in the program.
+  BlockHandle(mlir::Block *block) : block(block) {}
+
+  /// BlockHandle is a value type, use the default copy constructor and
+  /// assignment operator.
+  BlockHandle(const BlockHandle &) = default;
+  BlockHandle &operator=(const BlockHandle &) = default;
+
+  /// Delegates block creation to MLIR and wrap the resulting mlir::Block.
+  static BlockHandle create(ArrayRef<Type> argTypes);
+
+  operator bool() { return block != nullptr; }
+  operator mlir::Block *() { return block; }
+  mlir::Block *getBlock() { return block; }
+
+private:
+  mlir::Block *block;
+};
+
+template <typename Op, typename... Args>
+OperationHandle OperationHandle::create(Args... args) {
+  return OperationHandle(ScopedContext::getBuilder()
+                             .create<Op>(ScopedContext::getLocation(), args...)
+                             .getOperation());
+}
+
+template <typename Op, typename... Args>
+Op OperationHandle::createOp(Args... args) {
+  return cast<Op>(
+      OperationHandle(ScopedContext::getBuilder()
+                          .create<Op>(ScopedContext::getLocation(), args...)
+                          .getOperation())
+          .getOperation());
+}
+
+template <typename Op, typename... Args>
+ValueHandle ValueHandle::create(Args... args) {
+  Operation *op = ScopedContext::getBuilder()
+                      .create<Op>(ScopedContext::getLocation(), args...)
+                      .getOperation();
+  if (op->getNumResults() == 1) {
+    return ValueHandle(op->getResult(0));
+  } else if (op->getNumResults() == 0) {
+    if (auto f = dyn_cast<AffineForOp>(op)) {
+      return ValueHandle(f.getInductionVar());
+    }
+  }
+  llvm_unreachable("unsupported operation, use an OperationHandle instead");
+}
+
+template <typename Op, typename... Args>
+ValueHandle ValueHandle::create(OperationFolder &folder, Args... args) {
+  return ValueHandle(folder.create<Op>(ScopedContext::getBuilder(),
+                                       ScopedContext::getLocation(), args...));
+}
+
+namespace op {
+
+ValueHandle operator+(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator-(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator*(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator/(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator%(ValueHandle lhs, ValueHandle rhs);
+ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs);
+ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs);
+
+ValueHandle operator!(ValueHandle value);
+ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
+
+} // namespace op
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_EDSC_BUILDERS_H_
diff --git a/third_party/mlir/include/mlir/EDSC/CMakeLists.txt b/third_party/mlir/include/mlir/EDSC/CMakeLists.txt
new file mode 100644
index 0000000..0b6f249
--- /dev/null
+++ b/third_party/mlir/include/mlir/EDSC/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS "${MLIR_SOURCE_DIR}/test/mlir-tblgen/reference-impl.td")
+mlir_tablegen("reference-impl.inc" -gen-reference-implementations)
+add_public_tablegen_target(MLIRReferenceImplementationTestGen)
diff --git a/third_party/mlir/include/mlir/EDSC/Helpers.h b/third_party/mlir/include/mlir/EDSC/Helpers.h
new file mode 100644
index 0000000..69b7290
--- /dev/null
+++ b/third_party/mlir/include/mlir/EDSC/Helpers.h
@@ -0,0 +1,267 @@
+//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides helper classes and syntactic sugar for declarative builders.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EDSC_HELPERS_H_
+#define MLIR_EDSC_HELPERS_H_
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+
+// A TemplatedIndexedValue brings an index notation over the template Load and
+// Store parameters.
+template <typename Load, typename Store> class TemplatedIndexedValue;
+
+// By default, edsc::IndexedValue provides an index notation around the affine
+// load and stores. edsc::StdIndexedValue provides the standard load/store
+// counterpart.
+using IndexedValue =
+    TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
+using StdIndexedValue =
+    TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+// Base class for MemRefView and VectorView.
+class View {
+public:
+  unsigned rank() const { return lbs.size(); }
+  ValueHandle lb(unsigned idx) { return lbs[idx]; }
+  ValueHandle ub(unsigned idx) { return ubs[idx]; }
+  int64_t step(unsigned idx) { return steps[idx]; }
+  std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
+    return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
+  }
+  void swapRanges(unsigned i, unsigned j) {
+    if (i == j)
+      return;
+    lbs[i].swap(lbs[j]);
+    ubs[i].swap(ubs[j]);
+    std::swap(steps[i], steps[j]);
+  }
+
+  ArrayRef<ValueHandle> getLbs() { return lbs; }
+  ArrayRef<ValueHandle> getUbs() { return ubs; }
+  ArrayRef<int64_t> getSteps() { return steps; }
+
+protected:
+  SmallVector<ValueHandle, 8> lbs;
+  SmallVector<ValueHandle, 8> ubs;
+  SmallVector<int64_t, 8> steps;
+};
+
+/// A MemRefView represents the information required to step through a
+/// MemRef. It has placeholders for non-contiguous tensors that fit within the
+/// Fortran subarray model.
+/// At the moment it can only capture a MemRef with an identity layout map.
+// TODO(ntv): Support MemRefs with layoutMaps.
+class MemRefView : public View {
+public:
+  explicit MemRefView(Value *v);
+  MemRefView(const MemRefView &) = default;
+  MemRefView &operator=(const MemRefView &) = default;
+
+  unsigned fastestVarying() const { return rank() - 1; }
+
+private:
+  friend IndexedValue;
+  ValueHandle base;
+};
+
+/// A VectorView represents the information required to step through a
+/// Vector accessing each scalar element at a time. It is the counterpart of
+/// a MemRefView but for vectors. This exists purely for boilerplate avoidance.
+class VectorView : public View {
+public:
+  explicit VectorView(Value *v);
+  VectorView(const VectorView &) = default;
+  VectorView &operator=(const VectorView &) = default;
+
+private:
+  friend IndexedValue;
+  ValueHandle base;
+};
+
+/// A TemplatedIndexedValue brings an index notation over the template Load and
+/// Store parameters. This helper class is an abstraction purely for sugaring
+/// purposes and allows writing compact expressions such as:
+///
+/// ```mlir
+///    // `IndexedValue` provided by default in the mlir::edsc namespace.
+///    using IndexedValue =
+///      TemplatedIndexedValue<intrinsics::load, intrinsics::store>;
+///    IndexedValue A(...), B(...), C(...);
+///    For(ivs, zeros, shapeA, ones, {
+///      C(ivs) = A(ivs) + B(ivs)
+///    });
+/// ```
+///
+/// Assigning to an IndexedValue emits an actual `Store` operation, while
+/// converting an IndexedValue to a ValueHandle emits an actual `Load`
+/// operation.
+template <typename Load, typename Store> class TemplatedIndexedValue {
+public:
+  explicit TemplatedIndexedValue(Type t) : base(t) {}
+  explicit TemplatedIndexedValue(Value *v)
+      : TemplatedIndexedValue(ValueHandle(v)) {}
+  explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
+
+  TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
+
+  TemplatedIndexedValue operator()() { return *this; }
+  /// Returns a new `TemplatedIndexedValue`.
+  TemplatedIndexedValue operator()(ValueHandle index) {
+    TemplatedIndexedValue res(base);
+    res.indices.push_back(index);
+    return res;
+  }
+  template <typename... Args>
+  TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
+    return TemplatedIndexedValue(base, index).append(indices...);
+  }
+  TemplatedIndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
+    return TemplatedIndexedValue(base, indices);
+  }
+  TemplatedIndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
+    return TemplatedIndexedValue(
+        base, llvm::ArrayRef<ValueHandle>(indices.begin(), indices.end()));
+  }
+
+  /// Emits a `store`.
+  // NOLINTNEXTLINE: unconventional-assign-operator
+  OperationHandle operator=(const TemplatedIndexedValue &rhs) {
+    ValueHandle rrhs(rhs);
+    return Store(rrhs, getBase(), {indices.begin(), indices.end()});
+  }
+  // NOLINTNEXTLINE: unconventional-assign-operator
+  OperationHandle operator=(ValueHandle rhs) {
+    return Store(rhs, getBase(), {indices.begin(), indices.end()});
+  }
+
+  /// Emits a `load` when converting to a ValueHandle.
+  operator ValueHandle() const {
+    return Load(getBase(), {indices.begin(), indices.end()});
+  }
+
+  /// Emits a `load` when converting to a Value*.
+  Value *operator*(void)const {
+    return Load(getBase(), {indices.begin(), indices.end()}).getValue();
+  }
+
+  ValueHandle getBase() const { return base; }
+
+  /// Operator overloadings.
+  ValueHandle operator+(ValueHandle e);
+  ValueHandle operator-(ValueHandle e);
+  ValueHandle operator*(ValueHandle e);
+  ValueHandle operator/(ValueHandle e);
+  OperationHandle operator+=(ValueHandle e);
+  OperationHandle operator-=(ValueHandle e);
+  OperationHandle operator*=(ValueHandle e);
+  OperationHandle operator/=(ValueHandle e);
+  ValueHandle operator+(TemplatedIndexedValue e) {
+    return *this + static_cast<ValueHandle>(e);
+  }
+  ValueHandle operator-(TemplatedIndexedValue e) {
+    return *this - static_cast<ValueHandle>(e);
+  }
+  ValueHandle operator*(TemplatedIndexedValue e) {
+    return *this * static_cast<ValueHandle>(e);
+  }
+  ValueHandle operator/(TemplatedIndexedValue e) {
+    return *this / static_cast<ValueHandle>(e);
+  }
+  OperationHandle operator+=(TemplatedIndexedValue e) {
+    return this->operator+=(static_cast<ValueHandle>(e));
+  }
+  OperationHandle operator-=(TemplatedIndexedValue e) {
+    return this->operator-=(static_cast<ValueHandle>(e));
+  }
+  OperationHandle operator*=(TemplatedIndexedValue e) {
+    return this->operator*=(static_cast<ValueHandle>(e));
+  }
+  OperationHandle operator/=(TemplatedIndexedValue e) {
+    return this->operator/=(static_cast<ValueHandle>(e));
+  }
+
+private:
+  TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
+      : base(base), indices(indices.begin(), indices.end()) {}
+
+  TemplatedIndexedValue &append() { return *this; }
+
+  template <typename T, typename... Args>
+  TemplatedIndexedValue &append(T index, Args... indices) {
+    this->indices.push_back(static_cast<ValueHandle>(index));
+    append(indices...);
+    return *this;
+  }
+  ValueHandle base;
+  llvm::SmallVector<ValueHandle, 8> indices;
+};
+
+/// Operator overloadings.
+template <typename Load, typename Store>
+ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
+  using op::operator+;
+  return static_cast<ValueHandle>(*this) + e;
+}
+template <typename Load, typename Store>
+ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
+  using op::operator-;
+  return static_cast<ValueHandle>(*this) - e;
+}
+template <typename Load, typename Store>
+ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
+  using op::operator*;
+  return static_cast<ValueHandle>(*this) * e;
+}
+template <typename Load, typename Store>
+ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
+  using op::operator/;
+  return static_cast<ValueHandle>(*this) / e;
+}
+
+template <typename Load, typename Store>
+OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
+  using op::operator+;
+  return Store(*this + e, getBase(), {indices.begin(), indices.end()});
+}
+template <typename Load, typename Store>
+OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
+  using op::operator-;
+  return Store(*this - e, getBase(), {indices.begin(), indices.end()});
+}
+template <typename Load, typename Store>
+OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
+  using op::operator*;
+  return Store(*this * e, getBase(), {indices.begin(), indices.end()});
+}
+template <typename Load, typename Store>
+OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
+  using op::operator/;
+  return Store(*this / e, getBase(), {indices.begin(), indices.end()});
+}
+
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_EDSC_HELPERS_H_
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
new file mode 100644
index 0000000..98e9cea
--- /dev/null
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -0,0 +1,278 @@
+//===- Intrinsics.h - MLIR Operations for Declarative Builders ---*- C++-*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides intuitive composable intrinsics for building snippets of MLIR
+// declaratively
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EDSC_INTRINSICS_H_
+#define MLIR_EDSC_INTRINSICS_H_
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class MemRefType;
+class Type;
+
+namespace edsc {
+
+/// An IndexHandle is a simple wrapper around a ValueHandle.
+/// IndexHandles are ubiquitous enough to justify a new type to allow simple
+/// declarations without boilerplate such as:
+///
+/// ```c++
+///    IndexHandle i, j, k;
+/// ```
+struct IndexHandle : public ValueHandle {
+  explicit IndexHandle()
+      : ValueHandle(ScopedContext::getBuilder().getIndexType()) {}
+  explicit IndexHandle(index_t v) : ValueHandle(v) {}
+  explicit IndexHandle(Value *v) : ValueHandle(v) {
+    assert(v->getType() == ScopedContext::getBuilder().getIndexType() &&
+           "Expected index type");
+  }
+  explicit IndexHandle(ValueHandle v) : ValueHandle(v) {
+    assert(v.getType() == ScopedContext::getBuilder().getIndexType() &&
+           "Expected index type");
+  }
+  IndexHandle &operator=(const ValueHandle &v) {
+    assert(v.getType() == ScopedContext::getBuilder().getIndexType() &&
+           "Expected index type");
+    /// Creating a new IndexHandle(v) and then std::swap rightly complains the
+    /// binding has already occurred and that we should use another name.
+    this->t = v.getType();
+    this->v = v.getValue();
+    return *this;
+  }
+};
+
+inline SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
+  return SmallVector<IndexHandle, 8>(rank);
+}
+
+inline SmallVector<ValueHandle *, 8>
+makeIndexHandlePointers(MutableArrayRef<IndexHandle> ivs) {
+  SmallVector<ValueHandle *, 8> pivs;
+  pivs.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    pivs.push_back(&iv);
+  }
+  return pivs;
+}
+
+/// Returns a vector of the underlying Value* from `ivs`.
+inline SmallVector<Value *, 8> extractValues(ArrayRef<IndexHandle> ivs) {
+  SmallVector<Value *, 8> vals;
+  vals.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    vals.push_back(iv.getValue());
+  }
+  return vals;
+}
+
+/// Provides a set of first class intrinsics.
+/// In the future, most of intrinsics related to Operation that don't contain
+/// other operations should be Tablegen'd.
+namespace intrinsics {
+namespace detail {
+/// Helper structure to be used with ValueBuilder / OperationBuilder.
+/// It serves the purpose of removing boilerplate specialization for the sole
+/// purpose of implicitly converting ArrayRef<ValueHandle> -> ArrayRef<Value*>.
+class ValueHandleArray {
+public:
+  ValueHandleArray(ArrayRef<ValueHandle> vals) {
+    values.append(vals.begin(), vals.end());
+  }
+  ValueHandleArray(ArrayRef<IndexHandle> vals) {
+    values.append(vals.begin(), vals.end());
+  }
+  ValueHandleArray(ArrayRef<index_t> vals) {
+    llvm::SmallVector<IndexHandle, 8> tmp(vals.begin(), vals.end());
+    values.append(tmp.begin(), tmp.end());
+  }
+  operator ArrayRef<Value *>() { return values; }
+
+private:
+  ValueHandleArray() = default;
+  llvm::SmallVector<Value *, 8> values;
+};
+
+template <typename T> inline T unpack(T value) { return value; }
+
+inline detail::ValueHandleArray unpack(ArrayRef<ValueHandle> values) {
+  return detail::ValueHandleArray(values);
+}
+
+} // namespace detail
+
+/// Helper variadic abstraction to allow extending to any MLIR op without
+/// boilerplate or Tablegen.
+/// Arguably a builder is not a ValueHandle but in practice it is only used as
+/// an alias to a notional ValueHandle<Op>.
+/// Implementing it as a subclass allows it to compose all the way to Value*.
+/// Without subclassing, implicit conversion to Value* would fail when composing
+/// in patterns such as: `select(a, b, select(c, d, e))`.
+template <typename Op> struct ValueBuilder : public ValueHandle {
+  // Builder-based
+  template <typename... Args>
+  ValueBuilder(Args... args)
+      : ValueHandle(ValueHandle::create<Op>(detail::unpack(args)...)) {}
+  ValueBuilder(ArrayRef<ValueHandle> vs)
+      : ValueBuilder(ValueBuilder::create<Op>(detail::unpack(vs))) {}
+  template <typename... Args>
+  ValueBuilder(ArrayRef<ValueHandle> vs, Args... args)
+      : ValueHandle(ValueHandle::create<Op>(detail::unpack(vs),
+                                            detail::unpack(args)...)) {}
+  template <typename T, typename... Args>
+  ValueBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
+      : ValueHandle(ValueHandle::create<Op>(
+            detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
+  template <typename T1, typename T2, typename... Args>
+  ValueBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
+      : ValueHandle(ValueHandle::create<Op>(
+            detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
+            detail::unpack(args)...)) {}
+
+  /// Folder-based
+  template <typename... Args>
+  ValueBuilder(OperationFolder &folder, Args... args)
+      : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
+  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs)
+      : ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
+  template <typename... Args>
+  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args)
+      : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
+                                            detail::unpack(args)...)) {}
+  template <typename T, typename... Args>
+  ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs,
+               Args... args)
+      : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
+                                            detail::unpack(vs),
+                                            detail::unpack(args)...)) {}
+  template <typename T1, typename T2, typename... Args>
+  ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
+               Args... args)
+      : ValueHandle(ValueHandle::create<Op>(
+            folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
+            detail::unpack(args)...)) {}
+
+  ValueBuilder() : ValueHandle(ValueHandle::create<Op>()) {}
+};
+
+template <typename Op> struct OperationBuilder : public OperationHandle {
+  template <typename... Args>
+  OperationBuilder(Args... args)
+      : OperationHandle(OperationHandle::create<Op>(detail::unpack(args)...)) {}
+  OperationBuilder(ArrayRef<ValueHandle> vs)
+      : OperationHandle(OperationHandle::create<Op>(detail::unpack(vs))) {}
+  template <typename... Args>
+  OperationBuilder(ArrayRef<ValueHandle> vs, Args... args)
+      : OperationHandle(OperationHandle::create<Op>(detail::unpack(vs),
+                                                    detail::unpack(args)...)) {}
+  template <typename T, typename... Args>
+  OperationBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
+      : OperationHandle(OperationHandle::create<Op>(
+            detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
+  template <typename T1, typename T2, typename... Args>
+  OperationBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
+      : OperationHandle(OperationHandle::create<Op>(
+            detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
+            detail::unpack(args)...)) {}
+  OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
+};
+
+using alloc = ValueBuilder<AllocOp>;
+using affine_apply = ValueBuilder<AffineApplyOp>;
+using affine_load = ValueBuilder<AffineLoadOp>;
+using affine_store = OperationBuilder<AffineStoreOp>;
+using call = OperationBuilder<mlir::CallOp>;
+using constant_float = ValueBuilder<ConstantFloatOp>;
+using constant_index = ValueBuilder<ConstantIndexOp>;
+using constant_int = ValueBuilder<ConstantIntOp>;
+using dealloc = OperationBuilder<DeallocOp>;
+using dim = ValueBuilder<DimOp>;
+using muli = ValueBuilder<MulIOp>;
+using ret = OperationBuilder<ReturnOp>;
+using select = ValueBuilder<SelectOp>;
+using std_load = ValueBuilder<LoadOp>;
+using std_store = OperationBuilder<StoreOp>;
+using subi = ValueBuilder<SubIOp>;
+using vector_type_cast = ValueBuilder<vector::VectorTypeCastOp>;
+
+/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
+///
+/// Prerequisites:
+///   All Handles have already captured previously constructed IR objects.
+OperationHandle br(BlockHandle bh, ArrayRef<ValueHandle> operands);
+
+/// Creates a new mlir::Block* and branches to it from the current block.
+/// Argument types are specified by `operands`.
+/// Captures the new block in `bh` and the actual `operands` in `captures`. To
+/// insert the new mlir::Block*, a local ScopedContext is constructed and
+/// released to the current block. The branch operation is then added to the
+/// new block.
+///
+/// Prerequisites:
+///   `b` has not yet captured an mlir::Block*.
+///   No `captures` have captured any mlir::Value*.
+///   All `operands` have already captured an mlir::Value*
+///   captures.size() == operands.size()
+///   captures and operands are pairwise of the same type.
+OperationHandle br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
+                   ArrayRef<ValueHandle> operands);
+
+/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with
+/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and
+/// `falseOperand` if `cond` evaluates to `false`).
+///
+/// Prerequisites:
+///   All Handles have captured previouly constructed IR objects.
+OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch,
+                        ArrayRef<ValueHandle> trueOperands,
+                        BlockHandle falseBranch,
+                        ArrayRef<ValueHandle> falseOperands);
+
+/// Eagerly creates new mlir::Block* with argument types specified by
+/// `trueOperands`/`falseOperands`.
+/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in
+/// `trueCaptures/falseCaptures`.
+/// To insert the new mlir::Block*, a local ScopedContext is constructed and
+/// released. The branch operation is then added in the original location and
+/// targeting the eagerly constructed blocks.
+///
+/// Prerequisites:
+///   `trueBranch`/`falseBranch` has not yet captured an mlir::Block*.
+///   No `trueCaptures`/`falseCaptures` have captured any mlir::Value*.
+///   All `trueOperands`/`trueOperands` have already captured an mlir::Value*
+///   `trueCaptures`.size() == `trueOperands`.size()
+///   `falseCaptures`.size() == `falseOperands`.size()
+///   `trueCaptures` and `trueOperands` are pairwise of the same type
+///   `falseCaptures` and `falseOperands` are pairwise of the same type.
+OperationHandle cond_br(ValueHandle cond, BlockHandle *trueBranch,
+                        ArrayRef<ValueHandle *> trueCaptures,
+                        ArrayRef<ValueHandle> trueOperands,
+                        BlockHandle *falseBranch,
+                        ArrayRef<ValueHandle *> falseCaptures,
+                        ArrayRef<ValueHandle> falseOperands);
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_EDSC_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
new file mode 100644
index 0000000..69f6c2e
--- /dev/null
+++ b/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -0,0 +1,111 @@
+//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file provides a JIT-backed execution engine for MLIR modules.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
+#define MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Error.h"
+
+#include <functional>
+#include <memory>
+
+namespace llvm {
+template <typename T> class Expected;
+class Module;
+} // namespace llvm
+
+namespace mlir {
+
+class ModuleOp;
+
+namespace impl {
+class OrcJIT;
+} // end namespace impl
+
+/// JIT-backed execution engine for MLIR modules.  Assumes the module can be
+/// converted to LLVM IR.  For each function, creates a wrapper function with
+/// the fixed interface
+///
+///     void _mlir_funcName(void **)
+///
+/// where the only argument is interpreted as a list of pointers to the actual
+/// arguments of the function, followed by a pointer to the result.  This allows
+/// the engine to provide the caller with a generic function pointer that can
+/// be used to invoke the JIT-compiled function.
+class ExecutionEngine {
+public:
+  ~ExecutionEngine();
+
+  /// Creates an execution engine for the given module.  If `transformer` is
+  /// provided, it will be called on the LLVM module during JIT-compilation and
+  /// can be used, e.g., for reporting or optimization.
+  /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
+  /// and link the shared libraries for symbol resolution.
+  static llvm::Expected<std::unique_ptr<ExecutionEngine>>
+  create(ModuleOp m,
+         std::function<llvm::Error(llvm::Module *)> transformer = {},
+         ArrayRef<StringRef> sharedLibPaths = {});
+
+  /// Looks up a packed-argument function with the given name and returns a
+  /// pointer to it.  Propagates errors in case of failure.
+  llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
+
+  /// Invokes the function with the given name passing it the list of arguments.
+  /// The arguments are accepted by lvalue-reference since the packed function
+  /// interface expects a list of non-null pointers.
+  template <typename... Args>
+  llvm::Error invoke(StringRef name, Args &... args);
+
+  /// Invokes the function with the given name passing it the list of arguments
+  /// as a list of opaque pointers. This is the arity-agnostic equivalent of
+  /// the templated `invoke`.
+  llvm::Error invoke(StringRef name, MutableArrayRef<void *> args);
+
+  /// Set the target triple on the module. This is implicitly done when creating
+  /// the engine.
+  static bool setupTargetTriple(llvm::Module *llvmModule);
+
+private:
+  // Ordering of llvmContext and jit is important for destruction purposes: the
+  // jit must be destroyed before the context.
+  llvm::LLVMContext llvmContext;
+  // Private implementation of the JIT (PIMPL)
+  std::unique_ptr<impl::OrcJIT> jit;
+};
+
+template <typename... Args>
+llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) {
+  auto expectedFPtr = lookup(name);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  auto fptr = *expectedFPtr;
+
+  llvm::SmallVector<void *, 8> packedArgs{static_cast<void *>(&args)...};
+  (*fptr)(packedArgs.data());
+
+  return llvm::Error::success();
+}
+
+} // end namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
diff --git a/third_party/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/third_party/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
new file mode 100644
index 0000000..6946864
--- /dev/null
+++ b/third_party/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -0,0 +1,54 @@
+//===- MemRefUtils.h - MLIR runtime utilities for memrefs -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a set of utilities to working with objects of memref type in an JIT
+// context using the MLIR execution engine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace llvm {
+template <typename T> class Expected;
+}
+
+namespace mlir {
+class FuncOp;
+
+/// Simple memref descriptor class compatible with the ABI of functions emitted
+/// by MLIR to LLVM IR conversion for statically-shaped memrefs of float type.
+struct StaticFloatMemRef {
+  float *data;
+};
+
+/// Given an MLIR function that takes only statically-shaped memrefs with
+/// element type f32, allocate the memref descriptor and the data storage for
+/// each of the arguments, initialize the storage with `initialValue`, and
+/// return a list of type-erased descriptor pointers.
+llvm::Expected<SmallVector<void *, 8>>
+allocateMemRefArguments(FuncOp func, float initialValue = 0.0);
+
+/// Free a list of type-erased descriptors to statically-shaped memrefs with
+/// element type f32.
+void freeMemRefArguments(ArrayRef<void *> args);
+
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
diff --git a/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h
new file mode 100644
index 0000000..8c0249d
--- /dev/null
+++ b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h
@@ -0,0 +1,66 @@
+//===- OptUtils.h - MLIR Execution Engine opt pass utilities ----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the utility functions to trigger LLVM optimizations from
+// MLIR Execution Engine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_OPTUTILS_H_
+#define MLIR_EXECUTIONENGINE_OPTUTILS_H_
+
+#include "llvm/Pass.h"
+
+#include <functional>
+#include <string>
+
+namespace llvm {
+class Module;
+class Error;
+class TargetMachine;
+} // namespace llvm
+
+namespace mlir {
+
+/// Initialize LLVM passes that can be when running MLIR code using
+/// ExecutionEngine.
+void initializeLLVMPasses();
+
+/// Create a module transformer function for MLIR ExecutionEngine that runs
+/// LLVM IR passes corresponding to the given speed and size optimization
+/// levels (e.g. -O2 or -Os). If not null, `targetMachine` is used to
+/// initialize passes that provide target-specific information to the LLVM
+/// optimizer. `targetMachine` must outlive the returned std::function.
+std::function<llvm::Error(llvm::Module *)>
+makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
+                          llvm::TargetMachine *targetMachine);
+
+/// Create a module transformer function for MLIR ExecutionEngine that runs
+/// LLVM IR passes explicitly specified, plus an optional optimization level,
+/// Any optimization passes, if present, will be inserted before the pass at
+/// position optPassesInsertPos. If not null, `targetMachine` is used to
+/// initialize passes that provide target-specific information to the LLVM
+/// optimizer. `targetMachine` must outlive the returned std::function.
+std::function<llvm::Error(llvm::Module *)>
+makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
+                          llvm::Optional<unsigned> mbOptLevel,
+                          llvm::TargetMachine *targetMachine,
+                          unsigned optPassesInsertPos = 0);
+
+} // end namespace mlir
+
+#endif // LIR_EXECUTIONENGINE_OPTUTILS_H_
diff --git a/third_party/mlir/include/mlir/IR/AffineExpr.h b/third_party/mlir/include/mlir/IR/AffineExpr.h
new file mode 100644
index 0000000..58b4fbc
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/AffineExpr.h
@@ -0,0 +1,311 @@
+//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// An affine expression is an affine combination of dimension identifiers and
+// symbols, including ceildiv/floordiv/mod by a constant integer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_AFFINE_EXPR_H
+#define MLIR_IR_AFFINE_EXPR_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/Support/Casting.h"
+#include <type_traits>
+
+namespace mlir {
+
+class MLIRContext;
+class AffineMap;
+class IntegerSet;
+
+namespace detail {
+
+struct AffineExprStorage;
+struct AffineBinaryOpExprStorage;
+struct AffineDimExprStorage;
+struct AffineSymbolExprStorage;
+struct AffineConstantExprStorage;
+
+} // namespace detail
+
+enum class AffineExprKind {
+  Add,
+  /// RHS of mul is always a constant or a symbolic expression.
+  Mul,
+  /// RHS of mod is always a constant or a symbolic expression with a positive
+  /// value.
+  Mod,
+  /// RHS of floordiv is always a constant or a symbolic expression.
+  FloorDiv,
+  /// RHS of ceildiv is always a constant or a symbolic expression.
+  CeilDiv,
+
+  /// This is a marker for the last affine binary op. The range of binary
+  /// op's is expected to be this element and earlier.
+  LAST_AFFINE_BINARY_OP = CeilDiv,
+
+  /// Constant integer.
+  Constant,
+  /// Dimensional identifier.
+  DimId,
+  /// Symbolic identifier.
+  SymbolId,
+};
+
+/// Base type for affine expression.
+/// AffineExpr's are immutable value types with intuitive operators to
+/// operate on chainable, lightweight compositions.
+/// An AffineExpr is an interface to the underlying storage type pointer.
+class AffineExpr {
+public:
+  using ImplType = detail::AffineExprStorage;
+
+  AffineExpr() : expr(nullptr) {}
+  /* implicit */ AffineExpr(const ImplType *expr)
+      : expr(const_cast<ImplType *>(expr)) {}
+
+  AffineExpr(const AffineExpr &other) : expr(other.expr) {}
+  AffineExpr &operator=(AffineExpr other) {
+    expr = other.expr;
+    return *this;
+  }
+
+  bool operator==(AffineExpr other) const { return expr == other.expr; }
+  bool operator!=(AffineExpr other) const { return !(*this == other); }
+  explicit operator bool() const { return expr; }
+
+  bool operator!() const { return expr == nullptr; }
+
+  template <typename U> bool isa() const;
+  template <typename U> U dyn_cast() const;
+  template <typename U> U cast() const;
+
+  MLIRContext *getContext() const;
+
+  /// Return the classification for this type.
+  AffineExprKind getKind() const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Returns true if this expression is made out of only symbols and
+  /// constants, i.e., it does not involve dimensional identifiers.
+  bool isSymbolicOrConstant() const;
+
+  /// Returns true if this is a pure affine expression, i.e., multiplication,
+  /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
+  bool isPureAffine() const;
+
+  /// Returns the greatest known integral divisor of this affine expression.
+  uint64_t getLargestKnownDivisor() const;
+
+  /// Return true if the affine expression is a multiple of 'factor'.
+  bool isMultipleOf(int64_t factor) const;
+
+  /// Return true if the affine expression involves AffineDimExpr `position`.
+  bool isFunctionOfDim(unsigned position) const;
+
+  /// Walk all of the AffineExpr's in this expression in postorder.
+  void walk(std::function<void(AffineExpr)> callback) const;
+
+  /// This method substitutes any uses of dimensions and symbols (e.g.
+  /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
+  AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+                                   ArrayRef<AffineExpr> symReplacements) const;
+
+  AffineExpr operator+(int64_t v) const;
+  AffineExpr operator+(AffineExpr other) const;
+  AffineExpr operator-() const;
+  AffineExpr operator-(int64_t v) const;
+  AffineExpr operator-(AffineExpr other) const;
+  AffineExpr operator*(int64_t v) const;
+  AffineExpr operator*(AffineExpr other) const;
+  AffineExpr floorDiv(uint64_t v) const;
+  AffineExpr floorDiv(AffineExpr other) const;
+  AffineExpr ceilDiv(uint64_t v) const;
+  AffineExpr ceilDiv(AffineExpr other) const;
+  AffineExpr operator%(uint64_t v) const;
+  AffineExpr operator%(AffineExpr other) const;
+
+  /// Compose with an AffineMap.
+  /// Returns the composition of this AffineExpr with `map`.
+  ///
+  /// Prerequisites:
+  /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
+  /// `this` is smaller than the number of results of `map`. If a result of a
+  /// map does not have a corresponding AffineDimExpr, that result simply does
+  /// not appear in the produced AffineExpr.
+  ///
+  /// Example:
+  ///   expr: `d0 + d2`
+  ///   map:  `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
+  ///   returned expr: `d0 * 2 + d1 + d2 + s1`
+  AffineExpr compose(AffineMap map) const;
+
+  friend ::llvm::hash_code hash_value(AffineExpr arg);
+
+protected:
+  ImplType *expr;
+};
+
+/// Affine binary operation expression. An affine binary operation could be an
+/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
+/// represented through a multiply by -1 and add.) These expressions are always
+/// constructed in a simplified form. For eg., the LHS and RHS operands can't
+/// both be constants. There are additional canonicalizing rules depending on
+/// the op type: see checks in the constructor.
+class AffineBinaryOpExpr : public AffineExpr {
+public:
+  using ImplType = detail::AffineBinaryOpExprStorage;
+  /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
+  AffineExpr getLHS() const;
+  AffineExpr getRHS() const;
+};
+
+/// A dimensional identifier appearing in an affine expression.
+class AffineDimExpr : public AffineExpr {
+public:
+  using ImplType = detail::AffineDimExprStorage;
+  /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
+  unsigned getPosition() const;
+};
+
+/// A symbolic identifier appearing in an affine expression.
+class AffineSymbolExpr : public AffineExpr {
+public:
+  using ImplType = detail::AffineDimExprStorage;
+  /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
+  unsigned getPosition() const;
+};
+
+/// An integer constant appearing in affine expression.
+class AffineConstantExpr : public AffineExpr {
+public:
+  using ImplType = detail::AffineConstantExprStorage;
+  /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr);
+  int64_t getValue() const;
+};
+
+/// Make AffineExpr hashable.
+inline ::llvm::hash_code hash_value(AffineExpr arg) {
+  return ::llvm::hash_value(arg.expr);
+}
+
+inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
+inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
+inline AffineExpr operator-(int64_t val, AffineExpr expr) {
+  return expr * (-1) + val;
+}
+
+/// These free functions allow clients of the API to not use classes in detail.
+AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
+AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
+AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
+AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
+                                 AffineExpr rhs);
+
+/// Constructs an affine expression from a flat ArrayRef. If there are local
+/// identifiers (neither dimensional nor symbolic) that appear in the sum of
+/// products expression, 'localExprs' is expected to have the AffineExpr
+/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
+/// format [dims, symbols, locals, constant term].
+AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                        unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
+                        MLIRContext *context);
+
+raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
+
+template <typename U> bool AffineExpr::isa() const {
+  if (std::is_same<U, AffineBinaryOpExpr>::value) {
+    return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
+  }
+  if (std::is_same<U, AffineDimExpr>::value) {
+    return getKind() == AffineExprKind::DimId;
+  }
+  if (std::is_same<U, AffineSymbolExpr>::value) {
+    return getKind() == AffineExprKind::SymbolId;
+  }
+  if (std::is_same<U, AffineConstantExpr>::value) {
+    return getKind() == AffineExprKind::Constant;
+  }
+}
+template <typename U> U AffineExpr::dyn_cast() const {
+  if (isa<U>()) {
+    return U(expr);
+  }
+  return U(nullptr);
+}
+template <typename U> U AffineExpr::cast() const {
+  assert(isa<U>());
+  return U(expr);
+}
+
+/// Simplify an affine expression by flattening and some amount of
+/// simple analysis. This has complexity linear in the number of nodes in
+/// 'expr'. Returns the simplified expression, which is the same as the input
+///  expression if it can't be simplified.
+AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+                              unsigned numSymbols);
+
+/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
+/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled).
+/// 'cst' contains constraints that connect newly introduced local identifiers
+/// to existing dimensional and / symbolic identifiers. See documentation for
+/// AffineExprFlattener on how mod's and div's are flattened.
+bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
+                            unsigned numSymbols,
+                            llvm::SmallVectorImpl<int64_t> *flattenedExpr);
+
+/// Flattens the result expressions of the map to their corresponding flattened
+/// forms and set in 'flattenedExprs'. Returns true on success or false
+/// if any expression in the map could not be flattened (i.e., semi-affine is
+/// not yet handled).  For all affine expressions that share the same operands
+/// (like those of an affine map), this method should be used instead of
+/// repeatedly calling getFlattenedAffineExpr since local variables added to
+/// deal with div's and mod's will be reused across expressions.
+bool getFlattenedAffineExprs(
+    AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
+bool getFlattenedAffineExprs(
+    IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
+
+} // namespace mlir
+
+namespace llvm {
+
+// AffineExpr hash just like pointers
+template <> struct DenseMapInfo<mlir::AffineExpr> {
+  static mlir::AffineExpr getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
+  }
+  static mlir::AffineExpr getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::AffineExpr val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
+    return LHS == RHS;
+  }
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_AFFINE_EXPR_H
diff --git a/third_party/mlir/include/mlir/IR/AffineExprVisitor.h b/third_party/mlir/include/mlir/IR/AffineExprVisitor.h
new file mode 100644
index 0000000..7b14381
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -0,0 +1,334 @@
+//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the AffineExpr visitor class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H
+#define MLIR_IR_AFFINE_EXPR_VISITOR_H
+
+#include "mlir/IR/AffineExpr.h"
+
+namespace mlir {
+
+/// Base class for AffineExpr visitors/walkers.
+///
+/// AffineExpr visitors are used when you want to perform different actions
+/// for different kinds of AffineExprs without having to use lots of casts
+/// and a big switch instruction.
+///
+/// To define your own visitor, inherit from this class, specifying your
+/// new type for the 'SubClass' template parameter, and "override" visitXXX
+/// functions in your class. This class is defined in terms of statically
+/// resolved overloading, not virtual functions.
+///
+/// For example, here is a visitor that counts the number of for AffineDimExprs
+/// in an AffineExpr.
+///
+///  /// Declare the class.  Note that we derive from AffineExprVisitor
+///  /// instantiated with our new subclasses_ type.
+///
+///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
+///    unsigned numDimExprs;
+///    DimExprCounter() : numDimExprs(0) {}
+///    void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
+///  };
+///
+///  And this class would be used like this:
+///    DimExprCounter dec;
+///    dec.visit(affineExpr);
+///    numDimExprs = dec.numDimExprs;
+///
+/// AffineExprVisitor provides visit methods for the following binary affine
+/// op expressions:
+/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
+/// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
+/// AffineBinaryCeilDivOpExpr. Note that default implementations of these
+/// methods will call the general AffineBinaryOpExpr method.
+///
+/// In addition, visit methods are provided for the following affine
+//  expressions: AffineConstantExpr, AffineDimExpr, and
+//  AffineSymbolExpr.
+///
+/// Note that if you don't implement visitXXX for some affine expression type,
+/// the visitXXX method for Instruction superclass will be invoked.
+///
+/// Note that this class is specifically designed as a template to avoid
+/// virtual function call overhead. Defining and using a AffineExprVisitor is
+/// just as efficient as having your own switch instruction over the instruction
+/// opcode.
+
+template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
+  //===--------------------------------------------------------------------===//
+  // Interface code - This is the public interface of the AffineExprVisitor
+  // that you use to visit affine expressions...
+public:
+  // Function to walk an AffineExpr (in post order).
+  RetTy walkPostOrder(AffineExpr expr) {
+    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    switch (expr.getKind()) {
+    case AffineExprKind::Add: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+    }
+    case AffineExprKind::Mul: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+    }
+    case AffineExprKind::Mod: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+    }
+    case AffineExprKind::FloorDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExprKind::CeilDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      walkOperandsPostOrder(binOpExpr);
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExprKind::Constant:
+      return static_cast<SubClass *>(this)->visitConstantExpr(
+          expr.cast<AffineConstantExpr>());
+    case AffineExprKind::DimId:
+      return static_cast<SubClass *>(this)->visitDimExpr(
+          expr.cast<AffineDimExpr>());
+    case AffineExprKind::SymbolId:
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
+          expr.cast<AffineSymbolExpr>());
+    }
+  }
+
+  // Function to visit an AffineExpr.
+  RetTy visit(AffineExpr expr) {
+    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    switch (expr.getKind()) {
+    case AffineExprKind::Add: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+    }
+    case AffineExprKind::Mul: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+    }
+    case AffineExprKind::Mod: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+    }
+    case AffineExprKind::FloorDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExprKind::CeilDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExprKind::Constant:
+      return static_cast<SubClass *>(this)->visitConstantExpr(
+          expr.cast<AffineConstantExpr>());
+    case AffineExprKind::DimId:
+      return static_cast<SubClass *>(this)->visitDimExpr(
+          expr.cast<AffineDimExpr>());
+    case AffineExprKind::SymbolId:
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
+          expr.cast<AffineSymbolExpr>());
+    }
+    llvm_unreachable("Unknown AffineExpr");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Visitation functions... these functions provide default fallbacks in case
+  // the user does not specify what to do for a particular instruction type.
+  // The default behavior is to generalize the instruction type to its subtype
+  // and try visiting the subtype.  All of this should be inlined perfectly,
+  // because there are no virtual functions to get in the way.
+  //
+
+  // Default visit methods. Note that the default op-specific binary op visit
+  // methods call the general visitAffineBinaryOpExpr visit method.
+  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  void visitModExpr(AffineBinaryOpExpr expr) {
+    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  void visitConstantExpr(AffineConstantExpr expr) {}
+  void visitDimExpr(AffineDimExpr expr) {}
+  void visitSymbolExpr(AffineSymbolExpr expr) {}
+
+private:
+  // Walk the operands - each operand is itself walked in post order.
+  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+    walkPostOrder(expr.getLHS());
+    walkPostOrder(expr.getRHS());
+  }
+};
+
+// This class is used to flatten a pure affine expression (AffineExpr,
+// which is in a tree form) into a sum of products (w.r.t constants) when
+// possible, and in that process simplifying the expression. For a modulo,
+// floordiv, or a ceildiv expression, an additional identifier, called a local
+// identifier, is introduced to rewrite the expression as a sum of product
+// affine expression. Each local identifier is always and by construction a
+// floordiv of a pure add/mul affine function of dimensional, symbolic, and
+// other local identifiers, in a non-mutually recursive way. Hence, every local
+// identifier can ultimately always be recovered as an affine function of
+// dimensional and symbolic identifiers (involving floordiv's); note however
+// that by AffineExpr construction, some floordiv combinations are converted to
+// mod's. The result of the flattening is a flattened expression and a set of
+// constraints involving just the local variables.
+//
+// d2 + (d0 + d1) floordiv 4  is flattened to d2 + q where 'q' is the local
+// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
+//
+// The simplification performed includes the accumulation of contributions for
+// each dimensional and symbolic identifier together, the simplification of
+// floordiv/ceildiv/mod expressions and other simplifications that in turn
+// happen as a result. A simplification that this flattening naturally performs
+// is of simplifying the numerator and denominator of floordiv/ceildiv, and
+// folding a modulo expression to a zero, if possible. Three examples are below:
+//
+// (d0 + 3 * d1) + d0) - 2 * d1) - d0    simplified to     d0 + d1
+// (d0 - d0 mod 4 + 4) mod 4             simplified to     0
+// (3*d0 + 2*d1 + d0) floordiv 2 + d1    simplified to     2*d0 + 2*d1
+//
+// The way the flattening works for the second example is as follows: d0 % 4 is
+// replaced by d0 - 4*q with q being introduced: the expression then simplifies
+// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
+// zero. Note that an affine expression may not always be expressible purely as
+// a sum of products involving just the original dimensional and symbolic
+// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
+// may not be eliminated after simplification; in such cases, the final
+// expression can be reconstructed by replacing the local identifiers with their
+// corresponding explicit form stored in 'localExprs' (note that each of the
+// explicit forms itself would have been simplified).
+//
+// The expression walk method here performs a linear time post order walk that
+// performs the above simplifications through visit methods, with partial
+// results being stored in 'operandExprStack'. When a parent expr is visited,
+// the flattened expressions corresponding to its two operands would already be
+// on the stack - the parent expression looks at the two flattened expressions
+// and combines the two. It pops off the operand expressions and pushes the
+// combined result (although this is done in-place on its LHS operand expr).
+// When the walk is completed, the flattened form of the top-level expression
+// would be left on the stack.
+//
+// A flattener can be repeatedly used for multiple affine expressions that bind
+// to the same operands, for example, for all result expressions of an
+// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
+// is more efficient than creating a new flattener for each expression since
+// common idenical div and mod expressions appearing across different
+// expressions are mapped to the same local identifier (same column position in
+// 'localVarCst').
+class SimpleAffineExprFlattener
+    : public AffineExprVisitor<SimpleAffineExprFlattener> {
+public:
+  // Flattend expression layout: [dims, symbols, locals, constant]
+  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
+  // In future, consider adding a prepass to determine how big the SmallVector's
+  // will be, and linearize this to std::vector<int64_t> to prevent
+  // SmallVector moves on re-allocation.
+  std::vector<SmallVector<int64_t, 8>> operandExprStack;
+
+  unsigned numDims;
+  unsigned numSymbols;
+
+  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
+  unsigned numLocals;
+
+  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
+  // which new identifiers were introduced; if the latter do not get canceled
+  // out, these expressions can be readily used to reconstruct the AffineExpr
+  // (tree) form. Note that these expressions themselves would have been
+  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
+  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
+  // ceildiv 2 would be the local expression stored for q.
+  SmallVector<AffineExpr, 4> localExprs;
+
+  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
+
+  virtual ~SimpleAffineExprFlattener() = default;
+
+  // Visitor method overrides.
+  void visitMulExpr(AffineBinaryOpExpr expr);
+  void visitAddExpr(AffineBinaryOpExpr expr);
+  void visitDimExpr(AffineDimExpr expr);
+  void visitSymbolExpr(AffineSymbolExpr expr);
+  void visitConstantExpr(AffineConstantExpr expr);
+  void visitCeilDivExpr(AffineBinaryOpExpr expr);
+  void visitFloorDivExpr(AffineBinaryOpExpr expr);
+
+  //
+  // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
+  //
+  // A mod expression "expr mod c" is thus flattened by introducing a new local
+  // variable q (= expr floordiv c), such that expr mod c is replaced with
+  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+  void visitModExpr(AffineBinaryOpExpr expr);
+
+protected:
+  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
+  // The local identifier added is always a floordiv of a pure add/mul affine
+  // function of other identifiers, coefficients of which are specified in
+  // dividend and with respect to a positive constant divisor. localExpr is the
+  // simplified tree expression (AffineExpr) corresponding to the quantifier.
+  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
+                                  AffineExpr localExpr);
+
+private:
+  // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
+  // A floordiv is thus flattened by introducing a new local variable q, and
+  // replacing that expression with 'q' while adding the constraints
+  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
+  // FlatAffineConstraints::addLocalFloorDiv).
+  //
+  // A ceildiv is similarly flattened:
+  // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
+  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
+
+  int findLocalId(AffineExpr localExpr);
+
+  inline unsigned getNumCols() const {
+    return numDims + numSymbols + numLocals + 1;
+  }
+  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
+  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
+  inline unsigned getSymbolStartIndex() const { return numDims; }
+  inline unsigned getDimStartIndex() const { return 0; }
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_AFFINE_EXPR_VISITOR_H
diff --git a/third_party/mlir/include/mlir/IR/AffineMap.h b/third_party/mlir/include/mlir/IR/AffineMap.h
new file mode 100644
index 0000000..711cfd8
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/AffineMap.h
@@ -0,0 +1,248 @@
+//===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Affine maps are mathematical functions which map a list of dimension
+// identifiers and symbols, to multidimensional affine expressions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_AFFINE_MAP_H
+#define MLIR_IR_AFFINE_MAP_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMapInfo.h"
+
+namespace mlir {
+
+namespace detail {
+struct AffineMapStorage;
+} // end namespace detail
+
+class AffineExpr;
+class Attribute;
+struct LogicalResult;
+class MLIRContext;
+
+/// A multi-dimensional affine map
+/// Affine map's are immutable like Type's, and they are uniqued.
+/// Eg: (d0, d1) -> (d0/128, d0 mod 128, d1)
+/// The names used (d0, d1) don't matter - it's the mathematical function that
+/// is unique to this affine map.
+class AffineMap {
+public:
+  using ImplType = detail::AffineMapStorage;
+
+  AffineMap() : map(nullptr) {}
+  explicit AffineMap(ImplType *map) : map(map) {}
+  AffineMap(const AffineMap &other) : map(other.map) {}
+  AffineMap &operator=(const AffineMap &other) = default;
+
+  /// Returns a zero result affine map with no dimensions or symbols: () -> ().
+  static AffineMap get(MLIRContext *context);
+
+  static AffineMap get(unsigned dimCount, unsigned symbolCount,
+                       ArrayRef<AffineExpr> results);
+
+  /// Returns a single constant result affine map.
+  static AffineMap getConstantMap(int64_t val, MLIRContext *context);
+
+  /// Returns an AffineMap with 'numDims' identity result dim exprs.
+  static AffineMap getMultiDimIdentityMap(unsigned numDims,
+                                          MLIRContext *context);
+
+  MLIRContext *getContext() const;
+
+  explicit operator bool() { return map != nullptr; }
+  bool operator==(AffineMap other) const { return other.map == map; }
+  bool operator!=(AffineMap other) const { return !(other.map == map); }
+
+  /// Returns true if this affine map is an identity affine map.
+  /// An identity affine map corresponds to an identity affine function on the
+  /// dimensional identifiers.
+  bool isIdentity() const;
+
+  /// Returns true if this affine map is a single result constant function.
+  bool isSingleConstant() const;
+
+  /// Returns the constant result of this map. This methods asserts that the map
+  /// has a single constant result.
+  int64_t getSingleConstantResult() const;
+
+  // Prints affine map to 'os'.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  unsigned getNumDims() const;
+  unsigned getNumSymbols() const;
+  unsigned getNumResults() const;
+  unsigned getNumInputs() const;
+
+  ArrayRef<AffineExpr> getResults() const;
+  AffineExpr getResult(unsigned idx) const;
+
+  /// Walk all of the AffineExpr's in this mapping. Each node in an expression
+  /// tree is visited in postorder.
+  void walkExprs(std::function<void(AffineExpr)> callback) const;
+
+  /// This method substitutes any uses of dimensions and symbols (e.g.
+  /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
+  /// expression mapping.  Because this can be used to eliminate dims and
+  /// symbols, the client needs to specify the number of dims and symbols in
+  /// the result.  The returned map always has the same number of results.
+  AffineMap replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+                                  ArrayRef<AffineExpr> symReplacements,
+                                  unsigned numResultDims,
+                                  unsigned numResultSyms);
+
+  /// Folds the results of the application of an affine map on the provided
+  /// operands to a constant if possible.
+  LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
+                             SmallVectorImpl<Attribute> &results) const;
+
+  /// Returns the AffineMap resulting from composing `this` with `map`.
+  /// The resulting AffineMap has as many AffineDimExpr as `map` and as many
+  /// AffineSymbolExpr as the concatenation of `this` and `map` (in which case
+  /// the symbols of `this` map come first).
+  ///
+  /// Prerequisites:
+  /// The maps are composable, i.e. that the number of AffineDimExpr of `this`
+  /// matches the number of results of `map`.
+  ///
+  /// Example:
+  ///   map1: `(d0, d1)[s0, s1] -> (d0 + 1 + s1, d1 - 1 - s0)`
+  ///   map2: `(d0)[s0] -> (d0 + s0, d0 - s0))`
+  ///   map1.compose(map2):
+  ///     `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
+  AffineMap compose(AffineMap map);
+
+  /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
+  /// symbol-less permutation map.
+  bool isProjectedPermutation();
+
+  /// Returns true if the AffineMap represents a symbol-less permutation map.
+  bool isPermutation();
+
+  /// Returns the map consisting of the `resultPos` subset.
+  AffineMap getSubMap(ArrayRef<unsigned> resultPos);
+
+  friend ::llvm::hash_code hash_value(AffineMap arg);
+
+private:
+  ImplType *map;
+
+  static AffineMap getImpl(unsigned dimCount, unsigned symbolCount,
+                           ArrayRef<AffineExpr> results, MLIRContext *context);
+};
+
+// Make AffineExpr hashable.
+inline ::llvm::hash_code hash_value(AffineMap arg) {
+  return ::llvm::hash_value(arg.map);
+}
+
+/// Simplify an affine map by simplifying its underlying AffineExpr results.
+AffineMap simplifyAffineMap(AffineMap map);
+
+/// Returns a map of codomain to domain dimensions such that the first codomain
+/// dimension for a particular domain dimension is selected.
+/// Returns an empty map if the input map is empty or if `map` is not invertible
+/// (i.e. `map` does not contain a subset that is a permutation of full domain
+/// rank).
+///
+/// Prerequisites:
+///   1. `map` has no symbols.
+///
+/// Example 1:
+///
+/// ```{.mlir}
+///    (d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0)
+///                      0       2   3
+/// ```
+///
+/// returns:
+///
+/// ```{.mlir}
+///    (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
+/// ```
+///
+/// Example 2:
+///
+/// ```{.mlir}
+///    (d0, d1, d2) -> (d1, d0 + d1, d0, d2, d1, d2, d1, d0)
+///                      0            2   3
+/// ```
+///
+/// returns:
+///
+/// ```{.mlir}
+///    (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
+/// ```
+AffineMap inversePermutation(AffineMap map);
+
+/// Concatenates a list of `maps` into a single AffineMap, stepping over
+/// potentially empty maps. Assumes each of the underlying map has 0 symbols.
+/// The resulting map has a number of dims equal to the max of `maps`' dims and
+/// the concatenated results as its results.
+/// Returns an empty map if all input `maps` are empty.
+///
+/// Example:
+/// When applied to the following list of 3 affine maps,
+///
+/// ```{.mlir}
+///    {
+///      (i, j, k) -> (i, k),
+///      (i, j, k) -> (k, j),
+///      (i, j, k) -> (i, j)
+///    }
+/// ```
+///
+/// Returns the map:
+///
+/// ```{.mlir}
+///     (i, j, k) -> (i, k, k, j, i, j)
+/// ```
+AffineMap concatAffineMaps(llvm::ArrayRef<AffineMap> maps);
+
+inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
+  map.print(os);
+  return os;
+}
+} // end namespace mlir
+
+namespace llvm {
+
+// AffineExpr hash just like pointers
+template <> struct DenseMapInfo<mlir::AffineMap> {
+  static mlir::AffineMap getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
+  }
+  static mlir::AffineMap getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::AffineMap val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::AffineMap LHS, mlir::AffineMap RHS) {
+    return LHS == RHS;
+  }
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_AFFINE_MAP_H
diff --git a/third_party/mlir/include/mlir/IR/AttributeSupport.h b/third_party/mlir/include/mlir/IR/AttributeSupport.h
new file mode 100644
index 0000000..78b3a27
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/AttributeSupport.h
@@ -0,0 +1,116 @@
+//===- AttributeSupport.h ---------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support types for registering dialect extended attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ATTRIBUTESUPPORT_H
+#define MLIR_IR_ATTRIBUTESUPPORT_H
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StorageUniquerSupport.h"
+#include "llvm/ADT/PointerIntPair.h"
+
+namespace mlir {
+class MLIRContext;
+class Type;
+
+//===----------------------------------------------------------------------===//
+// AttributeStorage
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+class AttributeUniquer;
+} // end namespace detail
+
+/// Base storage class appearing in an attribute. Derived storage classes should
+/// only be constructed within the context of the AttributeUniquer.
+class AttributeStorage : public StorageUniquer::BaseStorage {
+  friend detail::AttributeUniquer;
+  friend StorageUniquer;
+
+public:
+  /// Get the type of this attribute.
+  Type getType() const;
+
+  /// Get the dialect of this attribute.
+  Dialect &getDialect() const {
+    assert(dialect && "Malformed attribute storage object.");
+    return const_cast<Dialect &>(*dialect);
+  }
+
+protected:
+  /// Construct a new attribute storage instance with the given type.
+  /// Note: All attributes require a valid type. If no type is provided here,
+  ///       the type of the attribute will automatically default to NoneType
+  ///       upon initialization in the uniquer.
+  AttributeStorage(Type type);
+  AttributeStorage();
+
+  /// Set the type of this attribute.
+  void setType(Type type);
+
+  // Set the dialect for this storage instance. This is used by the
+  // AttributeUniquer when initializing a newly constructed storage object.
+  void initializeDialect(Dialect &newDialect) { dialect = &newDialect; }
+
+private:
+  /// The dialect for this attribute.
+  Dialect *dialect;
+
+  /// The opaque type of the attribute value.
+  const void *type;
+};
+
+/// Default storage type for attributes that require no additional
+/// initialization or storage.
+using DefaultAttributeStorage = AttributeStorage;
+
+//===----------------------------------------------------------------------===//
+// AttributeStorageAllocator
+//===----------------------------------------------------------------------===//
+
+// This is a utility allocator used to allocate memory for instances of derived
+// Attributes.
+using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
+
+//===----------------------------------------------------------------------===//
+// AttributeUniquer
+//===----------------------------------------------------------------------===//
+namespace detail {
+// A utility class to get, or create, unique instances of attributes within an
+// MLIRContext. This class manages all creation and uniquing of attributes.
+class AttributeUniquer {
+public:
+  /// Get an uniqued instance of attribute T.
+  template <typename T, typename... Args>
+  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+    return ctx->getAttributeUniquer().get<typename T::ImplType>(
+        getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
+  }
+
+private:
+  /// Returns a functor used to initialize new attribute storage instances.
+  static std::function<void(AttributeStorage *)>
+  getInitFn(MLIRContext *ctx, const ClassID *const attrID);
+};
+} // namespace detail
+
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Attributes.h b/third_party/mlir/include/mlir/IR/Attributes.h
new file mode 100644
index 0000000..323473f
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Attributes.h
@@ -0,0 +1,954 @@
+//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_ATTRIBUTES_H
+#define MLIR_IR_ATTRIBUTES_H
+
+#include "mlir/IR/AttributeSupport.h"
+#include "llvm/ADT/APFloat.h"
+
+namespace mlir {
+class AffineMap;
+class Dialect;
+class FunctionType;
+class Identifier;
+class IntegerSet;
+class Location;
+class MLIRContext;
+class ShapedType;
+class Type;
+
+namespace detail {
+
+struct AffineMapAttributeStorage;
+struct ArrayAttributeStorage;
+struct BoolAttributeStorage;
+struct DictionaryAttributeStorage;
+struct IntegerAttributeStorage;
+struct IntegerSetAttributeStorage;
+struct FloatAttributeStorage;
+struct OpaqueAttributeStorage;
+struct StringAttributeStorage;
+struct TypeAttributeStorage;
+
+/// Elements Attributes.
+struct DenseElementsAttributeStorage;
+struct OpaqueElementsAttributeStorage;
+struct SparseElementsAttributeStorage;
+} // namespace detail
+
+/// Attributes are known-constant values of operations and functions.
+///
+/// Instances of the Attribute class are references to immutable, uniqued,
+/// and immortal values owned by MLIRContext. As such, an Attribute is a thin
+/// wrapper around an underlying storage pointer. Attributes are usually passed
+/// by value.
+class Attribute {
+public:
+  /// Integer identifier for all the concrete attribute kinds.
+  enum Kind {
+  // Reserve attribute kinds for dialect specific extensions.
+#define DEFINE_SYM_KIND_RANGE(Dialect)                                         \
+  FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
+#include "DialectSymbolRegistry.def"
+  };
+
+  /// Utility class for implementing attributes.
+  template <typename ConcreteType, typename BaseType = Attribute,
+            typename StorageType = AttributeStorage>
+  using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+                                           detail::AttributeUniquer>;
+
+  using ImplType = AttributeStorage;
+  using ValueType = void;
+
+  Attribute() : impl(nullptr) {}
+  /* implicit */ Attribute(const ImplType *impl)
+      : impl(const_cast<ImplType *>(impl)) {}
+
+  Attribute(const Attribute &other) : impl(other.impl) {}
+  Attribute &operator=(Attribute other) {
+    impl = other.impl;
+    return *this;
+  }
+
+  bool operator==(Attribute other) const { return impl == other.impl; }
+  bool operator!=(Attribute other) const { return !(*this == other); }
+  explicit operator bool() const { return impl; }
+
+  bool operator!() const { return impl == nullptr; }
+
+  template <typename U> bool isa() const;
+  template <typename U> U dyn_cast() const;
+  template <typename U> U dyn_cast_or_null() const;
+  template <typename U> U cast() const;
+
+  // Support dyn_cast'ing Attribute to itself.
+  static bool classof(Attribute) { return true; }
+
+  /// Return the classification for this attribute.
+  unsigned getKind() const { return impl->getKind(); }
+
+  /// Return the type of this attribute.
+  Type getType() const;
+
+  /// Return the context this attribute belongs to.
+  MLIRContext *getContext() const;
+
+  /// Get the dialect this attribute is registered to.
+  Dialect &getDialect() const;
+
+  /// Print the attribute.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Get an opaque pointer to the attribute.
+  const void *getAsOpaquePointer() const { return impl; }
+  /// Construct an attribute from the opaque pointer representation.
+  static Attribute getFromOpaquePointer(const void *ptr) {
+    return Attribute(reinterpret_cast<const ImplType *>(ptr));
+  }
+
+  friend ::llvm::hash_code hash_value(Attribute arg);
+
+protected:
+  ImplType *impl;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
+  attr.print(os);
+  return os;
+}
+
+namespace StandardAttributes {
+enum Kind {
+  AffineMap = Attribute::FIRST_STANDARD_ATTR,
+  Array,
+  Bool,
+  Dictionary,
+  Float,
+  Integer,
+  IntegerSet,
+  Opaque,
+  String,
+  SymbolRef,
+  Type,
+  Unit,
+
+  /// Elements Attributes.
+  DenseElements,
+  OpaqueElements,
+  SparseElements,
+  FIRST_ELEMENTS_ATTR = DenseElements,
+  LAST_ELEMENTS_ATTR = SparseElements,
+
+  /// Locations.
+  CallSiteLocation,
+  FileLineColLocation,
+  FusedLocation,
+  NameLocation,
+  UnknownLocation,
+
+  // Represents a location as a 'void*' pointer to a front-end's opaque
+  // location information, which must live longer than the MLIR objects that
+  // refer to it.  OpaqueLocation's are never serialized.
+  //
+  // TODO: OpaqueLocation,
+
+  // Represents a value inlined through a function call.
+  // TODO: InlinedLocation,
+
+  FIRST_LOCATION_ATTR = CallSiteLocation,
+  LAST_LOCATION_ATTR = UnknownLocation,
+};
+} // namespace StandardAttributes
+
+class AffineMapAttr
+    : public Attribute::AttrBase<AffineMapAttr, Attribute,
+                                 detail::AffineMapAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = AffineMap;
+
+  static AffineMapAttr get(AffineMap value);
+
+  AffineMap getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::AffineMap;
+  }
+};
+
+/// Array attributes are lists of other attributes.  They are not necessarily
+/// type homogenous given that attributes don't, in general, carry types.
+class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
+                                             detail::ArrayAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = ArrayRef<Attribute>;
+
+  static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
+
+  ArrayRef<Attribute> getValue() const;
+
+  /// Support range iteration.
+  using iterator = llvm::ArrayRef<Attribute>::iterator;
+  iterator begin() const { return getValue().begin(); }
+  iterator end() const { return getValue().end(); }
+  size_t size() const { return getValue().size(); }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::Array;
+  }
+};
+
+class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
+                                            detail::BoolAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = bool;
+
+  static BoolAttr get(bool value, MLIRContext *context);
+
+  bool getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
+};
+
+/// NamedAttribute is used for dictionary attributes, it holds an identifier for
+/// the name and a value for the attribute. The attribute pointer should always
+/// be non-null.
+using NamedAttribute = std::pair<Identifier, Attribute>;
+
+/// Dictionary attribute is an attribute that represents a sorted collection of
+/// named attribute values. The elements are sorted by name, and each name must
+/// be unique within the collection.
+class DictionaryAttr
+    : public Attribute::AttrBase<DictionaryAttr, Attribute,
+                                 detail::DictionaryAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = ArrayRef<NamedAttribute>;
+
+  static DictionaryAttr get(ArrayRef<NamedAttribute> value,
+                            MLIRContext *context);
+
+  ArrayRef<NamedAttribute> getValue() const;
+
+  /// Return the specified attribute if present, null otherwise.
+  Attribute get(StringRef name) const;
+  Attribute get(Identifier name) const;
+
+  /// Support range iteration.
+  using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
+  iterator begin() const;
+  iterator end() const;
+  bool empty() const { return size() == 0; }
+  size_t size() const;
+
+  /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::Dictionary;
+  }
+};
+
+class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
+                                             detail::FloatAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = APFloat;
+
+  /// Return a float attribute for the specified value in the specified type.
+  /// These methods should only be used for simple constant values, e.g 1.0/2.0,
+  /// that are known-valid both as host double and the 'type' format.
+  static FloatAttr get(Type type, double value);
+  static FloatAttr getChecked(Type type, double value, Location loc);
+
+  /// Return a float attribute for the specified value in the specified type.
+  static FloatAttr get(Type type, const APFloat &value);
+  static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
+
+  APFloat getValue() const;
+
+  /// This function is used to convert the value to a double, even if it loses
+  /// precision.
+  double getValueAsDouble() const;
+  static double getValueAsDouble(APFloat val);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::Float;
+  }
+
+  /// Verify the construction invariants for a double value.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+                               Type type, double value);
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
+                               Type type, const APFloat &value);
+};
+
+class IntegerAttr
+    : public Attribute::AttrBase<IntegerAttr, Attribute,
+                                 detail::IntegerAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = APInt;
+
+  static IntegerAttr get(Type type, int64_t value);
+  static IntegerAttr get(Type type, const APInt &value);
+
+  APInt getValue() const;
+  // TODO(jpienaar): Change callers to use getValue instead.
+  int64_t getInt() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::Integer;
+  }
+};
+
+class IntegerSetAttr
+    : public Attribute::AttrBase<IntegerSetAttr, Attribute,
+                                 detail::IntegerSetAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = IntegerSet;
+
+  static IntegerSetAttr get(IntegerSet value);
+
+  IntegerSet getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::IntegerSet;
+  }
+};
+
+/// Opaque attributes represent attributes of non-registered dialects. These are
+/// attribute represented in their raw string form, and can only usefully be
+/// tested for attribute equality.
+class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
+                                              detail::OpaqueAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new OpaqueAttr with the provided dialect and string data.
+  static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
+                        MLIRContext *context);
+
+  /// Get or create a new OpaqueAttr with the provided dialect and string data.
+  /// If the given identifier is not a valid namespace for a dialect, then a
+  /// null attribute is returned.
+  static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
+                               Type type, Location location);
+
+  /// Returns the dialect namespace of the opaque attribute.
+  Identifier getDialectNamespace() const;
+
+  /// Returns the raw attribute data of the opaque attribute.
+  StringRef getAttrData() const;
+
+  /// Verify the construction of an opaque attribute.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, Identifier dialect,
+                               StringRef attrData, Type type);
+
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::Opaque;
+  }
+};
+
+class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
+                                              detail::StringAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = StringRef;
+
+  /// Get an instance of a StringAttr with the given string.
+  static StringAttr get(StringRef bytes, MLIRContext *context);
+
+  /// Get an instance of a StringAttr with the given string and Type.
+  static StringAttr get(StringRef bytes, Type type);
+
+  StringRef getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::String;
+  }
+};
+
+/// A symbol reference attribute represents a symbolic reference to another
+/// operation.
+class SymbolRefAttr
+    : public Attribute::AttrBase<SymbolRefAttr, Attribute,
+                                 detail::StringAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = StringRef;
+
+  static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
+
+  /// Returns the name of the held symbol reference.
+  StringRef getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::SymbolRef;
+  }
+};
+
+class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
+                                            detail::TypeAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = Type;
+
+  static TypeAttr get(Type value);
+
+  Type getValue() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
+};
+
+/// Unit attributes are attributes that hold no specific value and are given
+/// meaning by their existence.
+class UnitAttr : public Attribute::AttrBase<UnitAttr> {
+public:
+  using Base::Base;
+
+  static UnitAttr get(MLIRContext *context);
+
+  static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; }
+};
+
+//===----------------------------------------------------------------------===//
+// Elements Attributes
+//===----------------------------------------------------------------------===//
+
+/// A base attribute that represents a reference to a static shaped tensor or
+/// vector constant.
+class ElementsAttr : public Attribute {
+public:
+  using Attribute::Attribute;
+
+  /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
+  /// with static shape.
+  ShapedType getType() const;
+
+  /// Return the value at the given index. If index does not refer to a valid
+  /// element, then a null attribute is returned.
+  Attribute getValue(ArrayRef<uint64_t> index) const;
+
+  /// Generates a new ElementsAttr by mapping each int value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain integers.
+  ElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Generates a new ElementsAttr by mapping each float value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain floats.
+  ElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr) {
+    return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
+           attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
+  }
+};
+
+/// An attribute that represents a reference to a dense vector or tensor object.
+///
+class DenseElementsAttr
+    : public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
+                                 detail::DenseElementsAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr) {
+    return attr.getKind() == StandardAttributes::DenseElements;
+  }
+
+  /// Constructs a dense elements attribute from an array of element values.
+  /// Each element attribute value is expected to be an element of 'type'.
+  /// 'type' must be a vector or tensor with static shape.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
+
+  /// Constructs a dense integer elements attribute from an array of integer
+  /// or floating-point values. Each value is expected to be the same bitwidth
+  /// of the element type of 'type'. 'type' must be a vector or tensor with
+  /// static shape.
+  template <typename T, typename = typename std::enable_if<
+                            std::numeric_limits<T>::is_integer ||
+                            llvm::is_one_of<T, float, double>::value>::type>
+  static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
+    const char *data = reinterpret_cast<const char *>(values.data());
+    return getRawIntOrFloat(
+        type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
+        /*isInt=*/std::numeric_limits<T>::is_integer);
+  }
+
+  /// Constructs a dense integer elements attribute from a single element.
+  template <typename T, typename = typename std::enable_if<
+                            std::numeric_limits<T>::is_integer ||
+                            llvm::is_one_of<T, float, double>::value>::type>
+  static DenseElementsAttr get(const ShapedType &type, T value) {
+    return get(type, llvm::makeArrayRef(value));
+  }
+
+  /// Overload of the above 'get' method that is specialized for boolean values.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
+
+  /// Constructs a dense integer elements attribute from an array of APInt
+  /// values. Each APInt value is expected to have the same bitwidth as the
+  /// element type of 'type'. 'type' must be a vector or tensor with static
+  /// shape.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
+
+  /// Constructs a dense float elements attribute from an array of APFloat
+  /// values. Each APFloat value is expected to have the same bitwidth as the
+  /// element type of 'type'. 'type' must be a vector or tensor with static
+  /// shape.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
+
+  /// Construct a dense elements attribute for an initializer_list of values.
+  /// Each value is expected to be the same bitwidth of the element type of
+  /// 'type'. 'type' must be a vector or tensor with static shape.
+  template <typename T>
+  static DenseElementsAttr get(const ShapedType &type,
+                               const std::initializer_list<T> &list) {
+    return get(type, ArrayRef<T>(list));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Iterators
+  //===--------------------------------------------------------------------===//
+
+  /// A utility iterator that allows walking over the internal Attribute values
+  /// of a DenseElementsAttr.
+  class AttributeElementIterator
+      : public indexed_accessor_iterator<AttributeElementIterator, const void *,
+                                         Attribute, Attribute, Attribute> {
+  public:
+    /// Accesses the Attribute value at this iterator position.
+    Attribute operator*() const;
+
+  private:
+    friend DenseElementsAttr;
+
+    /// Constructs a new iterator.
+    AttributeElementIterator(DenseElementsAttr attr, size_t index);
+  };
+
+  /// A utility iterator that allows walking over the internal raw APInt values.
+  class IntElementIterator
+      : public indexed_accessor_iterator<IntElementIterator, const char *,
+                                         APInt, APInt, APInt> {
+  public:
+    /// Accesses the raw APInt value at this iterator position.
+    APInt operator*() const;
+
+  private:
+    friend DenseElementsAttr;
+
+    /// Constructs a new iterator.
+    IntElementIterator(DenseElementsAttr attr, size_t index);
+
+    /// The bitwidth of the element type.
+    size_t bitWidth;
+  };
+
+  /// Iterator for walking over APFloat values.
+  class FloatElementIterator final
+      : public llvm::mapped_iterator<IntElementIterator,
+                                     std::function<APFloat(const APInt &)>> {
+    friend DenseElementsAttr;
+
+    /// Initializes the float element iterator to the specified iterator.
+    FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
+
+  public:
+    using reference = APFloat;
+  };
+
+  //===--------------------------------------------------------------------===//
+  // Value Querying
+  //===--------------------------------------------------------------------===//
+
+  /// Returns the number of raw elements held by this attribute.
+  size_t rawSize() const;
+
+  /// Returns if this attribute corresponds to a splat, i.e. if all element
+  /// values are the same.
+  bool isSplat() const;
+
+  /// If this attribute corresponds to a splat, then get the splat value.
+  /// Otherwise, return null.
+  Attribute getSplatValue() const;
+
+  /// Return the value at the given index. If index does not refer to a valid
+  /// element, then a null attribute is returned.
+  Attribute getValue(ArrayRef<uint64_t> index) const;
+
+  /// Return the held element values as an array of integer or floating-point
+  /// values.
+  template <typename T, typename = typename std::enable_if<
+                            (!std::is_same<T, bool>::value &&
+                             std::numeric_limits<T>::is_integer) ||
+                            llvm::is_one_of<T, float, double>::value>::type>
+  ArrayRef<T> getValues() const {
+    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
+    auto rawData = getRawData();
+    return ArrayRef<T>(reinterpret_cast<const T *>(rawData.data()),
+                       rawData.size() / sizeof(T));
+  }
+
+  /// Return the held element values as a range of Attributes.
+  llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, Attribute>::value>::type>
+  llvm::iterator_range<AttributeElementIterator> getValues() const {
+    return getAttributeValues();
+  }
+  AttributeElementIterator attr_value_begin() const;
+  AttributeElementIterator attr_value_end() const;
+
+  /// Return the held element values as a range of APInts. The element type of
+  /// this attribute must be of integer type.
+  llvm::iterator_range<IntElementIterator> getIntValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, APInt>::value>::type>
+  llvm::iterator_range<IntElementIterator> getValues() const {
+    return getIntValues();
+  }
+  IntElementIterator int_value_begin() const;
+  IntElementIterator int_value_end() const;
+
+  /// Return the held element values as a range of APFloat. The element type of
+  /// this attribute must be of float type.
+  llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, APFloat>::value>::type>
+  llvm::iterator_range<FloatElementIterator> getValues() const {
+    return getFloatValues();
+  }
+  FloatElementIterator float_value_begin() const;
+  FloatElementIterator float_value_end() const;
+
+  //===--------------------------------------------------------------------===//
+  // Mutation Utilities
+  //===--------------------------------------------------------------------===//
+
+  /// Return a new DenseElementsAttr that has the same data as the current
+  /// attribute, but has been reshaped to 'newType'. The new type must have the
+  /// same total number of elements as well as element type.
+  DenseElementsAttr reshape(ShapedType newType);
+
+  /// Generates a new DenseElementsAttr by mapping each int value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This underlying type must be an DenseIntElementsAttr.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Generates a new DenseElementsAttr by mapping each float value to a new
+  /// underlying APInt. the new values can represent either a integer or float.
+  /// This underlying type must be an DenseFPElementsAttr.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
+protected:
+  /// Return the raw storage data held by this attribute.
+  ArrayRef<char> getRawData() const;
+
+  /// Get iterators to the raw APInt values for each element in this attribute.
+  IntElementIterator raw_int_begin() const {
+    return IntElementIterator(*this, 0);
+  }
+  IntElementIterator raw_int_end() const {
+    return IntElementIterator(*this, rawSize());
+  }
+
+  /// Constructs a dense elements attribute from an array of raw APInt values.
+  /// Each APInt value is expected to have the same bitwidth as the element type
+  /// of 'type'. 'type' must be a vector or tensor with static shape.
+  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
+
+  /// Get or create a new dense elements attribute instance with the given raw
+  /// data buffer. 'type' must be a vector or tensor with static shape.
+  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
+                                  bool isSplat);
+
+  /// Overload of the raw 'get' method that asserts that the given type is of
+  /// integer or floating-point type. This method is used to verify type
+  /// invariants that the templatized 'get' method cannot.
+  static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+                                            ArrayRef<char> data,
+                                            int64_t dataEltSize, bool isInt);
+
+  /// Check the information for a c++ data type, check if this type is valid for
+  /// the current attribute. This method is used to verify specific type
+  /// invariants that the templatized 'getValues' method cannot.
+  bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
+};
+
+/// An attribute that represents a reference to a dense float vector or tensor
+/// object. Each element is stored as a double.
+class DenseFPElementsAttr : public DenseElementsAttr {
+public:
+  using iterator = DenseElementsAttr::FloatElementIterator;
+
+  using DenseElementsAttr::DenseElementsAttr;
+
+  /// Generates a new DenseElementsAttr by mapping each value attribute, and
+  /// constructing the DenseElementsAttr given the new element type.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
+  /// Iterator access to the float element values.
+  iterator begin() const { return float_value_begin(); }
+  iterator end() const { return float_value_end(); }
+
+  /// Method for supporting type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr);
+};
+
+/// An attribute that represents a reference to a dense integer vector or tensor
+/// object.
+class DenseIntElementsAttr : public DenseElementsAttr {
+public:
+  /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
+  /// iterator directly.
+  using iterator = DenseElementsAttr::IntElementIterator;
+
+  using DenseElementsAttr::DenseElementsAttr;
+
+  /// Generates a new DenseElementsAttr by mapping each value attribute, and
+  /// constructing the DenseElementsAttr given the new element type.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Iterator access to the integer element values.
+  iterator begin() const { return raw_int_begin(); }
+  iterator end() const { return raw_int_end(); }
+
+  /// Method for supporting type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr);
+};
+
+/// An opaque attribute that represents a reference to a vector or tensor
+/// constant with opaque content. This respresentation is for tensor constants
+/// which the compiler may not need to interpret. This attribute is always
+/// associated with a particular dialect, which provides a method to convert
+/// tensor representation to a non-opaque format.
+class OpaqueElementsAttr
+    : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
+                                 detail::OpaqueElementsAttributeStorage> {
+public:
+  using Base::Base;
+  using ValueType = StringRef;
+
+  static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
+                                StringRef bytes);
+
+  StringRef getValue() const;
+
+  /// Return the value at the given index. If index does not refer to a valid
+  /// element, then a null attribute is returned.
+  Attribute getValue(ArrayRef<uint64_t> index) const;
+
+  /// Decodes the attribute value using dialect-specific decoding hook.
+  /// Returns false if decoding is successful. If not, returns true and leaves
+  /// 'result' argument unspecified.
+  bool decode(ElementsAttr &result);
+
+  /// Returns dialect associated with this opaque constant.
+  Dialect *getDialect() const;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::OpaqueElements;
+  }
+};
+
+/// An attribute that represents a reference to a sparse vector or tensor
+/// object.
+///
+/// This class uses COO (coordinate list) encoding to represent the sparse
+/// elements in an element attribute. Specifically, the sparse vector/tensor
+/// stores the indices and values as two separate dense elements attributes of
+/// tensor type (even if the sparse attribute is of vector type, in order to
+/// support empty lists). The dense elements attribute indices is a 2-D tensor
+/// of 64-bit integer elements with shape [N, ndims], which specifies the
+/// indices of the elements in the sparse tensor that contains nonzero values.
+/// The dense elements attribute values is a 1-D tensor with shape [N], and it
+/// supplies the corresponding values for the indices.
+///
+/// For example,
+/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
+/// [[1, 0, 0, 0],
+///  [0, 0, 5, 0],
+///  [0, 0, 0, 0]].
+class SparseElementsAttr
+    : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
+                                 detail::SparseElementsAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// 'type' must be a vector or tensor with static shape.
+  static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
+                                DenseElementsAttr values);
+
+  DenseIntElementsAttr getIndices() const;
+
+  DenseElementsAttr getValues() const;
+
+  /// Return the value of the element at the given index.
+  Attribute getValue(ArrayRef<uint64_t> index) const;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::SparseElements;
+  }
+};
+
+/// An attribute that represents a reference to a splat vector or tensor
+/// constant, meaning all of the elements have the same value.
+class SplatElementsAttr : public DenseElementsAttr {
+public:
+  using DenseElementsAttr::DenseElementsAttr;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr) {
+    auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
+    return denseAttr && denseAttr.isSplat();
+  }
+};
+
+template <typename U> bool Attribute::isa() const {
+  assert(impl && "isa<> used on a null attribute.");
+  return U::classof(*this);
+}
+template <typename U> U Attribute::dyn_cast() const {
+  return isa<U>() ? U(impl) : U(nullptr);
+}
+template <typename U> U Attribute::dyn_cast_or_null() const {
+  return (impl && isa<U>()) ? U(impl) : U(nullptr);
+}
+template <typename U> U Attribute::cast() const {
+  assert(isa<U>());
+  return U(impl);
+}
+
+// Make Attribute hashable.
+inline ::llvm::hash_code hash_value(Attribute arg) {
+  return ::llvm::hash_value(arg.impl);
+}
+
+/// A NamedAttributeList is used to manage a list of named attributes. This
+/// provides simple interfaces for adding/removing/finding attributes from
+/// within a DictionaryAttr.
+///
+/// We assume there will be relatively few attributes on a given operation
+/// (maybe a dozen or so, but not hundreds or thousands) so we use linear
+/// searches for everything.
+class NamedAttributeList {
+public:
+  NamedAttributeList(DictionaryAttr attrs = nullptr)
+      : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
+  NamedAttributeList(ArrayRef<NamedAttribute> attributes);
+
+  /// Return the underlying dictionary attribute. This may be null, if this list
+  /// has no attributes.
+  DictionaryAttr getDictionary() const { return attrs; }
+
+  /// Return all of the attributes on this operation.
+  ArrayRef<NamedAttribute> getAttrs() const;
+
+  /// Replace the held attributes with ones provided in 'newAttrs'.
+  void setAttrs(ArrayRef<NamedAttribute> attributes);
+
+  /// Return the specified attribute if present, null otherwise.
+  Attribute get(StringRef name) const;
+  Attribute get(Identifier name) const;
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value.  Otherwise, add a new attribute with the specified name/value.
+  void set(Identifier name, Attribute value);
+
+  enum class RemoveResult { Removed, NotFound };
+
+  /// Remove the attribute with the specified name if it exists.  The return
+  /// value indicates whether the attribute was present or not.
+  RemoveResult remove(Identifier name);
+
+private:
+  DictionaryAttr attrs;
+};
+
+} // end namespace mlir.
+
+namespace llvm {
+
+// Attribute hash just like pointers.
+template <> struct DenseMapInfo<mlir::Attribute> {
+  static mlir::Attribute getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
+  }
+  static mlir::Attribute getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::Attribute val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
+    return LHS == RHS;
+  }
+};
+
+/// Allow LLVM to steal the low bits of Attributes.
+template <> struct PointerLikeTypeTraits<mlir::Attribute> {
+public:
+  static inline void *getAsVoidPointer(mlir::Attribute attr) {
+    return const_cast<void *>(attr.getAsOpaquePointer());
+  }
+  static inline mlir::Attribute getFromVoidPointer(void *ptr) {
+    return mlir::Attribute::getFromOpaquePointer(ptr);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // namespace llvm
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Block.h b/third_party/mlir/include/mlir/IR/Block.h
new file mode 100644
index 0000000..84144b8
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Block.h
@@ -0,0 +1,456 @@
+//===- Block.h - MLIR Block Class -------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the Block class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BLOCK_H
+#define MLIR_IR_BLOCK_H
+
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/ilist.h"
+#include "llvm/ADT/ilist_node.h"
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Operation
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+namespace ilist_detail {
+// Explicitly define the node access for the operation list so that we can
+// break the dependence on the Operation class in this header. This allows for
+// operations to have trailing Regions without a circular include
+// dependence.
+template <>
+struct SpecificNodeAccess<
+    typename compute_node_options<::mlir::Operation>::type> : NodeAccess {
+protected:
+  using OptionsT = typename compute_node_options<mlir::Operation>::type;
+  using pointer = typename OptionsT::pointer;
+  using const_pointer = typename OptionsT::const_pointer;
+  using node_type = ilist_node_impl<OptionsT>;
+
+  static node_type *getNodePtr(pointer N);
+  static const node_type *getNodePtr(const_pointer N);
+
+  static pointer getValuePtr(node_type *N);
+  static const_pointer getValuePtr(const node_type *N);
+};
+} // end namespace ilist_detail
+
+template <> struct ilist_traits<::mlir::Operation> {
+  using Operation = ::mlir::Operation;
+  using op_iterator = simple_ilist<Operation>::iterator;
+
+  static void deleteNode(Operation *op);
+  void addNodeToList(Operation *op);
+  void removeNodeFromList(Operation *op);
+  void transferNodesFromList(ilist_traits<Operation> &otherList,
+                             op_iterator first, op_iterator last);
+
+private:
+  mlir::Block *getContainingBlock();
+};
+} // end namespace llvm
+
+namespace mlir {
+using BlockOperand = IROperandImpl<Block>;
+
+class PredecessorIterator;
+class SuccessorIterator;
+
+/// `Block` represents an ordered list of `Operation`s.
+class Block : public IRObjectWithUseList,
+              public llvm::ilist_node_with_parent<Block, Region> {
+public:
+  explicit Block() {}
+  ~Block();
+
+  void clear() {
+    // Drop all references from within this block.
+    dropAllReferences();
+
+    // Clear operations in the reverse order so that uses are destroyed
+    // before their defs.
+    while (!empty())
+      operations.pop_back();
+  }
+
+  /// Blocks are maintained in a Region.
+  Region *getParent();
+
+  /// Returns the closest surrounding operation that contains this block.
+  Operation *getParentOp();
+
+  /// Return if this block is the entry block in the parent region.
+  bool isEntryBlock();
+
+  /// Insert this block (which must not already be in a function) right before
+  /// the specified block.
+  void insertBefore(Block *block);
+
+  /// Unlink this Block from its parent region and delete it.
+  void erase();
+
+  //===--------------------------------------------------------------------===//
+  // Block argument management
+  //===--------------------------------------------------------------------===//
+
+  // This is the list of arguments to the block.
+  using BlockArgListType = ArrayRef<BlockArgument *>;
+
+  BlockArgListType getArguments() { return arguments; }
+
+  using args_iterator = BlockArgListType::iterator;
+  using reverse_args_iterator = BlockArgListType::reverse_iterator;
+  args_iterator args_begin() { return getArguments().begin(); }
+  args_iterator args_end() { return getArguments().end(); }
+  reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
+  reverse_args_iterator args_rend() { return getArguments().rend(); }
+
+  bool args_empty() { return arguments.empty(); }
+
+  /// Add one value to the argument list.
+  BlockArgument *addArgument(Type type);
+
+  /// Add one argument to the argument list for each type specified in the list.
+  llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
+
+  /// Erase the argument at 'index' and remove it from the argument list. If
+  /// 'updatePredTerms' is set to true, this argument is also removed from the
+  /// terminators of each predecessor to this block.
+  void eraseArgument(unsigned index, bool updatePredTerms = true);
+
+  unsigned getNumArguments() { return arguments.size(); }
+  BlockArgument *getArgument(unsigned i) { return arguments[i]; }
+
+  //===--------------------------------------------------------------------===//
+  // Operation list management
+  //===--------------------------------------------------------------------===//
+
+  /// This is the list of operations in the block.
+  using InstListType = llvm::iplist<Operation>;
+  InstListType &getOperations() { return operations; }
+
+  // Iteration over the operations in the block.
+  using iterator = InstListType::iterator;
+  using reverse_iterator = InstListType::reverse_iterator;
+
+  iterator begin() { return operations.begin(); }
+  iterator end() { return operations.end(); }
+  reverse_iterator rbegin() { return operations.rbegin(); }
+  reverse_iterator rend() { return operations.rend(); }
+
+  bool empty() { return operations.empty(); }
+  void push_back(Operation *op) { operations.push_back(op); }
+  void push_front(Operation *op) { operations.push_front(op); }
+
+  Operation &back() { return operations.back(); }
+  Operation &front() { return operations.front(); }
+
+  /// Returns 'op' if 'op' lies in this block, or otherwise finds the
+  /// ancestor operation of 'op' that lies in this block. Returns nullptr if
+  /// the latter fails.
+  /// TODO: This is very specific functionality that should live somewhere else,
+  /// probably in Dominance.cpp.
+  Operation *findAncestorInstInBlock(Operation &op);
+
+  /// This drops all operand uses from operations within this block, which is
+  /// an essential step in breaking cyclic dependences between references when
+  /// they are to be deleted.
+  void dropAllReferences();
+
+  /// This drops all uses of values defined in this block or in the blocks of
+  /// nested regions wherever the uses are located.
+  void dropAllDefinedValueUses();
+
+  /// Returns true if the ordering of the child operations is valid, false
+  /// otherwise.
+  bool isInstOrderValid();
+
+  /// Invalidates the current ordering of operations.
+  void invalidateInstOrder();
+
+  /// Verifies the current ordering of child operations matches the
+  /// validInstOrder flag. Returns false if the order is valid, true otherwise.
+  bool verifyInstOrder();
+
+  /// Recomputes the ordering of child operations within the block.
+  void recomputeInstOrder();
+
+private:
+  /// A utility iterator that filters out operations that are not 'OpT'.
+  template <typename OpT>
+  class op_filter_iterator
+      : public llvm::filter_iterator<Block::iterator, bool (*)(Operation &)> {
+    static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
+
+  public:
+    op_filter_iterator(Block::iterator it, Block::iterator end)
+        : llvm::filter_iterator<Block::iterator, bool (*)(Operation &)>(
+              it, end, &filter) {}
+
+    /// Allow implict conversion to the underlying block iterator.
+    operator Block::iterator() const { return this->wrapped(); }
+  };
+
+public:
+  /// This class provides iteration over the held instructions of a block for a
+  /// specific operation type.
+  template <typename OpT>
+  class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
+                                                   OpT (*)(Operation &)> {
+    static OpT unwrap(Operation &op) { return llvm::cast<OpT>(op); }
+
+  public:
+    using reference = OpT;
+
+    /// Initializes the iterator to the specified filter iterator.
+    op_iterator(op_filter_iterator<OpT> it)
+        : llvm::mapped_iterator<op_filter_iterator<OpT>, OpT (*)(Operation &)>(
+              it, &unwrap) {}
+
+    /// Allow implict conversion to the underlying block iterator.
+    operator Block::iterator() const { return this->wrapped(); }
+  };
+
+  /// Return an iterator range over the operations within this block that are of
+  /// 'OpT'.
+  template <typename OpT> llvm::iterator_range<op_iterator<OpT>> getOps() {
+    auto endIt = end();
+    return {op_filter_iterator<OpT>(begin(), endIt),
+            op_filter_iterator<OpT>(endIt, endIt)};
+  }
+  template <typename OpT> op_iterator<OpT> op_begin() {
+    return op_filter_iterator<OpT>(begin(), end());
+  }
+  template <typename OpT> op_iterator<OpT> op_end() {
+    return op_filter_iterator<OpT>(end(), end());
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Terminator management
+  //===--------------------------------------------------------------------===//
+
+  /// Get the terminator operation of this block. This function asserts that
+  /// the block has a valid terminator operation.
+  Operation *getTerminator();
+
+  //===--------------------------------------------------------------------===//
+  // Predecessors and successors.
+  //===--------------------------------------------------------------------===//
+
+  // Predecessor iteration.
+  using pred_iterator = PredecessorIterator;
+  pred_iterator pred_begin();
+  pred_iterator pred_end();
+  llvm::iterator_range<pred_iterator> getPredecessors();
+
+  /// Return true if this block has no predecessors.
+  bool hasNoPredecessors();
+
+  /// If this block has exactly one predecessor, return it.  Otherwise, return
+  /// null.
+  ///
+  /// Note that if a block has duplicate predecessors from a single block (e.g.
+  /// if you have a conditional branch with the same block as the true/false
+  /// destinations) is not considered to be a single predecessor.
+  Block *getSinglePredecessor();
+
+  // Indexed successor access.
+  unsigned getNumSuccessors();
+  Block *getSuccessor(unsigned i);
+
+  // Successor iteration.
+  using succ_iterator = SuccessorIterator;
+  succ_iterator succ_begin();
+  succ_iterator succ_end();
+  llvm::iterator_range<succ_iterator> getSuccessors();
+
+  //===--------------------------------------------------------------------===//
+  // Operation Walkers
+  //===--------------------------------------------------------------------===//
+
+  /// Walk the operations in this block in postorder, calling the callback for
+  /// each operation.
+  void walk(llvm::function_ref<void(Operation *)> callback);
+
+  /// Specialization of walk to only visit operations of 'OpTy'.
+  template <typename OpTy> void walk(llvm::function_ref<void(OpTy)> callback) {
+    walk([&](Operation *opInst) {
+      if (auto op = dyn_cast<OpTy>(opInst))
+        callback(op);
+    });
+  }
+
+  /// Walk the operations in the specified [begin, end) range of this block in
+  /// postorder, calling the callback for each operation.
+  void walk(Block::iterator begin, Block::iterator end,
+            llvm::function_ref<void(Operation *)> callback);
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  /// Split the block into two blocks before the specified operation or
+  /// iterator.
+  ///
+  /// Note that all operations BEFORE the specified iterator stay as part of
+  /// the original basic block, and the rest of the operations in the original
+  /// block are moved to the new block, including the old terminator.  The
+  /// original block is left without a terminator.
+  ///
+  /// The newly formed Block is returned, and the specified iterator is
+  /// invalidated.
+  Block *splitBlock(iterator splitBefore);
+  Block *splitBlock(Operation *splitBeforeInst) {
+    return splitBlock(iterator(splitBeforeInst));
+  }
+
+  /// Returns pointer to member of operation list.
+  static InstListType Block::*getSublistAccess(Operation *) {
+    return &Block::operations;
+  }
+
+  void print(raw_ostream &os);
+  void dump();
+
+  /// Print out the name of the block without printing its body.
+  /// NOTE: The printType argument is ignored.  We keep it for compatibility
+  /// with LLVM dominator machinery that expects it to exist.
+  void printAsOperand(raw_ostream &os, bool printType = true);
+
+private:
+  /// Pair of the parent object that owns this block and a bit that signifies if
+  /// the operations within this block have a valid ordering.
+  llvm::PointerIntPair<Region *, /*IntBits=*/1, bool> parentValidInstOrderPair;
+
+  /// This is the list of operations in the block.
+  InstListType operations;
+
+  /// This is the list of arguments to the block.
+  std::vector<BlockArgument *> arguments;
+
+  Block(Block &) = delete;
+  void operator=(Block &) = delete;
+
+  friend struct llvm::ilist_traits<Block>;
+};
+
+} // end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Block
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+
+template <>
+struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> {
+  using Block = ::mlir::Block;
+  using block_iterator = simple_ilist<::mlir::Block>::iterator;
+
+  void addNodeToList(Block *block);
+  void removeNodeFromList(Block *block);
+  void transferNodesFromList(ilist_traits<Block> &otherList,
+                             block_iterator first, block_iterator last);
+
+private:
+  mlir::Region *getParentRegion();
+};
+} // end namespace llvm
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// Predecessors
+//===----------------------------------------------------------------------===//
+
+/// Implement a predecessor iterator for blocks. This works by walking the use
+/// lists of the blocks. The entries on this list are the BlockOperands that
+/// are embedded into terminator operations. From the operand, we can get the
+/// terminator that contains it, and its parent block is the predecessor.
+class PredecessorIterator final
+    : public llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
+                                   Block *(*)(BlockOperand &)> {
+  static Block *unwrap(BlockOperand &value);
+
+public:
+  using reference = Block *;
+
+  /// Initializes the operand type iterator to the specified operand iterator.
+  PredecessorIterator(ValueUseIterator<BlockOperand> it)
+      : llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
+                              Block *(*)(BlockOperand &)>(it, &unwrap) {}
+  explicit PredecessorIterator(BlockOperand *operand)
+      : PredecessorIterator(ValueUseIterator<BlockOperand>(operand)) {}
+
+  /// Get the successor number in the predecessor terminator.
+  unsigned getSuccessorIndex() const;
+};
+
+inline auto Block::pred_begin() -> pred_iterator {
+  return pred_iterator((BlockOperand *)getFirstUse());
+}
+
+inline auto Block::pred_end() -> pred_iterator {
+  return pred_iterator(nullptr);
+}
+
+inline auto Block::getPredecessors() -> llvm::iterator_range<pred_iterator> {
+  return {pred_begin(), pred_end()};
+}
+
+//===----------------------------------------------------------------------===//
+// Successors
+//===----------------------------------------------------------------------===//
+
+/// This template implements the successor iterators for Block.
+class SuccessorIterator final
+    : public indexed_accessor_iterator<SuccessorIterator, Block *, Block *,
+                                       Block *, Block *> {
+public:
+  /// Initializes the result iterator to the specified index.
+  SuccessorIterator(Block *object, unsigned index)
+      : indexed_accessor_iterator<SuccessorIterator, Block *, Block *, Block *,
+                                  Block *>(object, index) {}
+
+  SuccessorIterator(const SuccessorIterator &other)
+      : SuccessorIterator(other.object, other.index) {}
+
+  Block *operator*() const { return this->object->getSuccessor(this->index); }
+
+  /// Get the successor number in the terminator.
+  unsigned getSuccessorIndex() const { return this->index; }
+};
+
+inline auto Block::succ_begin() -> succ_iterator {
+  return succ_iterator(this, 0);
+}
+
+inline auto Block::succ_end() -> succ_iterator {
+  return succ_iterator(this, getNumSuccessors());
+}
+
+inline auto Block::getSuccessors() -> llvm::iterator_range<succ_iterator> {
+  return {succ_begin(), succ_end()};
+}
+
+} // end namespace mlir
+
+#endif // MLIR_IR_BLOCK_H
diff --git a/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h b/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h
new file mode 100644
index 0000000..bd69aa2
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -0,0 +1,93 @@
+//===- BlockAndValueMapping.h -----------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a utility class for maintaining a mapping for multiple
+// value types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BLOCKANDVALUEMAPPING_H
+#define MLIR_IR_BLOCKANDVALUEMAPPING_H
+
+#include "mlir/IR/Block.h"
+
+namespace mlir {
+// This is a utility class for mapping one set of values to another. New
+// mappings can be inserted via 'map'. Existing mappings can be
+// found via the 'lookup*' functions. There are two variants that differ only in
+// return value when an existing is not found for the provided key.
+// 'lookupOrNull' returns nullptr where as 'lookupOrDefault' will return the
+// lookup key.
+class BlockAndValueMapping {
+public:
+  /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping,
+  /// it is overwritten.
+  void map(Block *from, Block *to) { valueMap[from] = to; }
+  void map(Value *from, Value *to) { valueMap[from] = to; }
+
+  /// Erases a mapping for 'from'.
+  void erase(IRObjectWithUseList *from) { valueMap.erase(from); }
+
+  /// Checks to see if a mapping for 'from' exists.
+  bool contains(IRObjectWithUseList *from) const {
+    return valueMap.count(from);
+  }
+
+  /// Lookup a mapped value within the map. If a mapping for the provided value
+  /// does not exist then return nullptr.
+  Block *lookupOrNull(Block *from) const {
+    return lookupOrValue(from, (Block *)nullptr);
+  }
+  Value *lookupOrNull(Value *from) const {
+    return lookupOrValue(from, (Value *)nullptr);
+  }
+
+  /// Lookup a mapped value within the map. If a mapping for the provided value
+  /// does not exist then return the provided value.
+  Block *lookupOrDefault(Block *from) const {
+    return lookupOrValue(from, from);
+  }
+  Value *lookupOrDefault(Value *from) const {
+    return lookupOrValue(from, from);
+  }
+
+  /// Lookup a mapped value within the map. This asserts the provided value
+  /// exists within the map.
+  template <typename T> T *lookup(T *from) const {
+    auto *result = lookupOrNull(from);
+    assert(result && "expected 'from' to be contained within the map");
+    return result;
+  }
+
+  /// Clears all mappings held by the mapper.
+  void clear() { valueMap.clear(); }
+
+private:
+  /// Utility lookupOrValue that looks up an existing key or returns the
+  /// provided value. This function assumes that if a mapping does exist, then
+  /// it is of 'T' type.
+  template <typename T> T *lookupOrValue(T *from, T *value) const {
+    auto it = valueMap.find(from);
+    return it != valueMap.end() ? static_cast<T *>(it->second) : value;
+  }
+
+  llvm::DenseMap<IRObjectWithUseList *, IRObjectWithUseList *> valueMap;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_BLOCKANDVALUEMAPPING_H
diff --git a/third_party/mlir/include/mlir/IR/Builders.h b/third_party/mlir/include/mlir/IR/Builders.h
new file mode 100644
index 0000000..3e4815a
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Builders.h
@@ -0,0 +1,388 @@
+//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_BUILDERS_H
+#define MLIR_IR_BUILDERS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+
+class AffineExpr;
+class BlockAndValueMapping;
+class ModuleOp;
+class UnknownLoc;
+class FileLineColLoc;
+class Type;
+class PrimitiveType;
+class IntegerType;
+class FunctionType;
+class MemRefType;
+class VectorType;
+class RankedTensorType;
+class UnrankedTensorType;
+class TupleType;
+class NoneType;
+class BoolAttr;
+class IntegerAttr;
+class FloatAttr;
+class StringAttr;
+class TypeAttr;
+class ArrayAttr;
+class SymbolRefAttr;
+class ElementsAttr;
+class DenseElementsAttr;
+class DenseIntElementsAttr;
+class AffineMapAttr;
+class AffineMap;
+class UnitAttr;
+
+/// This class is a general helper class for creating context-global objects
+/// like types, attributes, and affine expressions.
+class Builder {
+public:
+  explicit Builder(MLIRContext *context) : context(context) {}
+  explicit Builder(ModuleOp module);
+
+  MLIRContext *getContext() const { return context; }
+
+  Identifier getIdentifier(StringRef str);
+
+  // Locations.
+  Location getUnknownLoc();
+  Location getFileLineColLoc(Identifier filename, unsigned line,
+                             unsigned column);
+  Location getFusedLoc(ArrayRef<Location> locs,
+                       Attribute metadata = Attribute());
+
+  // Types.
+  FloatType getBF16Type();
+  FloatType getF16Type();
+  FloatType getF32Type();
+  FloatType getF64Type();
+
+  IndexType getIndexType();
+
+  IntegerType getI1Type();
+  IntegerType getIntegerType(unsigned width);
+  FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
+  MemRefType getMemRefType(ArrayRef<int64_t> shape, Type elementType,
+                           ArrayRef<AffineMap> affineMapComposition = {},
+                           unsigned memorySpace = 0);
+  VectorType getVectorType(ArrayRef<int64_t> shape, Type elementType);
+  RankedTensorType getTensorType(ArrayRef<int64_t> shape, Type elementType);
+  UnrankedTensorType getTensorType(Type elementType);
+  TupleType getTupleType(ArrayRef<Type> elementTypes);
+  NoneType getNoneType();
+
+  /// Get or construct an instance of the type 'ty' with provided arguments.
+  template <typename Ty, typename... Args> Ty getType(Args... args) {
+    return Ty::get(context, args...);
+  }
+
+  // Attributes.
+  NamedAttribute getNamedAttr(StringRef name, Attribute val);
+
+  UnitAttr getUnitAttr();
+  BoolAttr getBoolAttr(bool value);
+  DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
+  IntegerAttr getIntegerAttr(Type type, int64_t value);
+  IntegerAttr getIntegerAttr(Type type, const APInt &value);
+  FloatAttr getFloatAttr(Type type, double value);
+  FloatAttr getFloatAttr(Type type, const APFloat &value);
+  StringAttr getStringAttr(StringRef bytes);
+  StringAttr getStringAttr(StringRef bytes, Type type);
+  ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
+  AffineMapAttr getAffineMapAttr(AffineMap map);
+  IntegerSetAttr getIntegerSetAttr(IntegerSet set);
+  TypeAttr getTypeAttr(Type type);
+  SymbolRefAttr getSymbolRefAttr(Operation *value);
+  SymbolRefAttr getSymbolRefAttr(StringRef value);
+  ElementsAttr getDenseElementsAttr(ShapedType type,
+                                    ArrayRef<Attribute> values);
+  ElementsAttr getDenseIntElementsAttr(ShapedType type,
+                                       ArrayRef<int64_t> values);
+  ElementsAttr getSparseElementsAttr(ShapedType type,
+                                     DenseIntElementsAttr indices,
+                                     DenseElementsAttr values);
+  ElementsAttr getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
+                                     StringRef bytes);
+  // Returns a 0-valued attribute of the given `type`. This function only
+  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
+  // ranked tensor of them. Returns null attribute otherwise.
+  Attribute getZeroAttr(Type type);
+
+  // Convenience methods for fixed types.
+  FloatAttr getF16FloatAttr(float value);
+  FloatAttr getF32FloatAttr(float value);
+  FloatAttr getF64FloatAttr(double value);
+
+  IntegerAttr getI32IntegerAttr(int32_t value);
+  IntegerAttr getI64IntegerAttr(int64_t value);
+
+  ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
+  ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
+  ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
+  ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
+  ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
+  ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
+
+  // Affine expressions and affine maps.
+  AffineExpr getAffineDimExpr(unsigned position);
+  AffineExpr getAffineSymbolExpr(unsigned position);
+  AffineExpr getAffineConstantExpr(int64_t constant);
+
+  AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount,
+                         ArrayRef<AffineExpr> results);
+
+  // Special cases of affine maps and integer sets
+  /// Returns a zero result affine map with no dimensions or symbols: () -> ().
+  AffineMap getEmptyAffineMap();
+  /// Returns a single constant result affine map with 0 dimensions and 0
+  /// symbols.  One constant result: () -> (val).
+  AffineMap getConstantAffineMap(int64_t val);
+  // One dimension id identity map: (i) -> (i).
+  AffineMap getDimIdentityMap();
+  // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
+  AffineMap getMultiDimIdentityMap(unsigned rank);
+  // One symbol identity map: ()[s] -> (s).
+  AffineMap getSymbolIdentityMap();
+
+  /// Returns a map that shifts its (single) input dimension by 'shift'.
+  /// (d0) -> (d0 + shift)
+  AffineMap getSingleDimShiftAffineMap(int64_t shift);
+
+  /// Returns an affine map that is a translation (shift) of all result
+  /// expressions in 'map' by 'shift'.
+  /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
+  ///   returns:    (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
+  AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
+
+  // Integer set.
+  IntegerSet getIntegerSet(unsigned dimCount, unsigned symbolCount,
+                           ArrayRef<AffineExpr> constraints,
+                           ArrayRef<bool> isEq);
+  // TODO: Helpers for affine map/exprs, etc.
+protected:
+  MLIRContext *context;
+};
+
+/// This class helps build Operations. Operations that are created are
+/// automatically inserted at an insertion point. The builder is copyable.
+class OpBuilder : public Builder {
+public:
+  /// Create a builder with the given context.
+  explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {}
+
+  /// Create a builder and set the insertion point to the start of the region.
+  explicit OpBuilder(Region *region) : Builder(region->getContext()) {
+    if (!region->empty())
+      setInsertionPoint(&region->front(), region->front().begin());
+  }
+  explicit OpBuilder(Region &region) : OpBuilder(&region) {}
+
+  virtual ~OpBuilder();
+
+  /// Create a builder and set insertion point to the given operation, which
+  /// will cause subsequent insertions to go right before it.
+  explicit OpBuilder(Operation *op) : Builder(op->getContext()) {
+    setInsertionPoint(op);
+  }
+
+  explicit OpBuilder(Block *block) : OpBuilder(block, block->end()) {}
+
+  OpBuilder(Block *block, Block::iterator insertPoint)
+      : OpBuilder(block->getParent()) {
+    setInsertionPoint(block, insertPoint);
+  }
+
+  /// This class represents a saved insertion point.
+  class InsertPoint {
+  public:
+    /// Creates a new insertion point which doesn't point to anything.
+    InsertPoint() = default;
+
+    /// Creates a new insertion point at the given location.
+    InsertPoint(Block *insertBlock, Block::iterator insertPt)
+        : block(insertBlock), point(insertPt) {}
+
+    /// Returns true if this insert point is set.
+    bool isSet() const { return (block != nullptr); }
+
+    Block *getBlock() const { return block; }
+    Block::iterator getPoint() const { return point; }
+
+  private:
+    Block *block = nullptr;
+    Block::iterator point;
+  };
+
+  /// Reset the insertion point to no location.  Creating an operation without a
+  /// set insertion point is an error, but this can still be useful when the
+  /// current insertion point a builder refers to is being removed.
+  void clearInsertionPoint() {
+    this->block = nullptr;
+    insertPoint = Block::iterator();
+  }
+
+  /// Return a saved insertion point.
+  InsertPoint saveInsertionPoint() const {
+    return InsertPoint(getInsertionBlock(), getInsertionPoint());
+  }
+
+  /// Restore the insert point to a previously saved point.
+  void restoreInsertionPoint(InsertPoint ip) {
+    if (ip.isSet())
+      setInsertionPoint(ip.getBlock(), ip.getPoint());
+    else
+      clearInsertionPoint();
+  }
+
+  /// Set the insertion point to the specified location.
+  void setInsertionPoint(Block *block, Block::iterator insertPoint) {
+    // TODO: check that insertPoint is in this rather than some other block.
+    this->block = block;
+    this->insertPoint = insertPoint;
+  }
+
+  /// Sets the insertion point to the specified operation, which will cause
+  /// subsequent insertions to go right before it.
+  void setInsertionPoint(Operation *op) {
+    setInsertionPoint(op->getBlock(), Block::iterator(op));
+  }
+
+  /// Sets the insertion point to the start of the specified block.
+  void setInsertionPointToStart(Block *block) {
+    setInsertionPoint(block, block->begin());
+  }
+
+  /// Sets the insertion point to the end of the specified block.
+  void setInsertionPointToEnd(Block *block) {
+    setInsertionPoint(block, block->end());
+  }
+
+  /// Return the block the current insertion point belongs to.  Note that the
+  /// the insertion point is not necessarily the end of the block.
+  Block *getInsertionBlock() const { return block; }
+
+  /// Returns the current insertion point of the builder.
+  Block::iterator getInsertionPoint() const { return insertPoint; }
+
+  /// Add new block and set the insertion point to the end of it. The block is
+  /// inserted at the provided insertion point of 'parent'.
+  Block *createBlock(Region *parent, Region::iterator insertPt = {});
+
+  /// Add new block and set the insertion point to the end of it. The block is
+  /// placed before 'insertBefore'.
+  Block *createBlock(Block *insertBefore);
+
+  /// Returns the current block of the builder.
+  Block *getBlock() const { return block; }
+
+  /// Creates an operation given the fields represented as an OperationState.
+  virtual Operation *createOperation(const OperationState &state);
+
+  /// Create an operation of specific op type at the current insertion point.
+  template <typename OpTy, typename... Args>
+  OpTy create(Location location, Args&&... args) {
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(this, &state, std::forward<Args>(args)...);
+    auto *op = createOperation(state);
+    auto result = dyn_cast<OpTy>(op);
+    assert(result && "Builder didn't return the right type");
+    return result;
+  }
+
+  /// Create an operation of specific op type at the current insertion point,
+  /// and immediately try to fold it. This functions populates 'results' with
+  /// the results after folding the operation.
+  template <typename OpTy, typename... Args>
+  void createOrFold(SmallVectorImpl<Value *> &results, Location location,
+                    Args &&... args) {
+    auto op = create<OpTy>(location, std::forward<Args>(args)...);
+    tryFold(op.getOperation(), results);
+  }
+
+  /// Overload to create or fold a single result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
+                          Value *>::type
+  createOrFold(Location location, Args &&... args) {
+    SmallVector<Value *, 1> results;
+    createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
+    return results.front();
+  }
+
+  /// Overload to create or fold a zero result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
+                          OpTy>::type
+  createOrFold(Location location, Args &&... args) {
+    auto op = create<OpTy>(location, std::forward<Args>(args)...);
+    SmallVector<Value *, 0> unused;
+    tryFold(op.getOperation(), unused);
+
+    // Folding cannot remove a zero-result operation, so for convenience we
+    // continue to return it.
+    return op;
+  }
+
+  /// Creates a deep copy of the specified operation, remapping any operands
+  /// that use values outside of the operation using the map that is provided
+  /// ( leaving them alone if no entry is present).  Replaces references to
+  /// cloned sub-operations to the corresponding operation that is copied,
+  /// and adds those mappings to the map.
+  Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
+    Operation *cloneOp = op.clone(mapper);
+    insert(cloneOp);
+    return cloneOp;
+  }
+  Operation *clone(Operation &op) {
+    Operation *cloneOp = op.clone();
+    insert(cloneOp);
+    return cloneOp;
+  }
+
+  /// Creates a deep copy of this operation but keep the operation regions
+  /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
+  /// updated to contain the results.
+  Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
+    Operation *cloneOp = op.cloneWithoutRegions(mapper);
+    insert(cloneOp);
+    return cloneOp;
+  }
+  Operation *cloneWithoutRegions(Operation &op) {
+    Operation *cloneOp = op.cloneWithoutRegions();
+    insert(cloneOp);
+    return cloneOp;
+  }
+
+private:
+  /// Attempts to fold the given operation and places new results within
+  /// 'results'.
+  void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
+
+  /// Insert the given operation at the current insertion point.
+  void insert(Operation *op);
+
+  Block *block = nullptr;
+  Block::iterator insertPoint;
+};
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Diagnostics.h b/third_party/mlir/include/mlir/IR/Diagnostics.h
new file mode 100644
index 0000000..b9621b6
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Diagnostics.h
@@ -0,0 +1,604 @@
+//===- Diagnostics.h - MLIR Diagnostics -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines utilities for emitting diagnostics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIAGNOSTICS_H
+#define MLIR_IR_DIAGNOSTICS_H
+
+#include "mlir/IR/Location.h"
+#include "mlir/Support/STLExtras.h"
+#include <functional>
+
+namespace llvm {
+class MemoryBuffer;
+class SMLoc;
+class SourceMgr;
+} // end namespace llvm
+
+namespace mlir {
+class DiagnosticEngine;
+class Identifier;
+struct LogicalResult;
+class MLIRContext;
+class Operation;
+class OperationName;
+class Type;
+
+namespace detail {
+struct DiagnosticEngineImpl;
+} // end namespace detail
+
+/// Defines the different supported severity of a diagnostic.
+enum class DiagnosticSeverity {
+  Note,
+  Warning,
+  Error,
+  Remark,
+};
+
+//===----------------------------------------------------------------------===//
+// DiagnosticArgument
+//===----------------------------------------------------------------------===//
+
+/// A variant type that holds a single argument for a diagnostic.
+class DiagnosticArgument {
+public:
+  /// Enum that represents the different kinds of diagnostic arguments
+  /// supported.
+  enum class DiagnosticArgumentKind {
+    Attribute,
+    Double,
+    Integer,
+    Operation,
+    String,
+    Type,
+    Unsigned,
+  };
+
+  /// Outputs this argument to a stream.
+  void print(raw_ostream &os) const;
+
+  /// Returns the kind of this argument.
+  DiagnosticArgumentKind getKind() const { return kind; }
+
+  /// Returns this argument as an Attribute.
+  Attribute getAsAttribute() const;
+
+  /// Returns this argument as a double.
+  double getAsDouble() const {
+    assert(getKind() == DiagnosticArgumentKind::Double);
+    return doubleVal;
+  }
+
+  /// Returns this argument as a signed integer.
+  int64_t getAsInteger() const {
+    assert(getKind() == DiagnosticArgumentKind::Integer);
+    return static_cast<int64_t>(opaqueVal);
+  }
+
+  /// Returns this argument as an operation.
+  Operation &getAsOperation() const {
+    assert(getKind() == DiagnosticArgumentKind::Operation);
+    return *reinterpret_cast<Operation *>(opaqueVal);
+  }
+
+  /// Returns this argument as a string.
+  StringRef getAsString() const {
+    assert(getKind() == DiagnosticArgumentKind::String);
+    return stringVal;
+  }
+
+  /// Returns this argument as a Type.
+  Type getAsType() const;
+
+  /// Returns this argument as an unsigned integer.
+  uint64_t getAsUnsigned() const {
+    assert(getKind() == DiagnosticArgumentKind::Unsigned);
+    return static_cast<uint64_t>(opaqueVal);
+  }
+
+private:
+  friend class Diagnostic;
+
+  // Construct from an Attribute.
+  explicit DiagnosticArgument(Attribute attr);
+
+  // Construct from a floating point number.
+  explicit DiagnosticArgument(double val)
+      : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
+  explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
+
+  // Construct from a signed integer.
+  template <typename T>
+  explicit DiagnosticArgument(
+      T val, typename std::enable_if<std::is_signed<T>::value &&
+                                     std::numeric_limits<T>::is_integer &&
+                                     sizeof(T) <= sizeof(int64_t)>::type * = 0)
+      : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
+
+  // Construct from an unsigned integer.
+  template <typename T>
+  explicit DiagnosticArgument(
+      T val, typename std::enable_if<std::is_unsigned<T>::value &&
+                                     std::numeric_limits<T>::is_integer &&
+                                     sizeof(T) <= sizeof(uint64_t)>::type * = 0)
+      : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
+
+  // Construct from an operation reference.
+  explicit DiagnosticArgument(Operation &val) : DiagnosticArgument(&val) {}
+  explicit DiagnosticArgument(Operation *val)
+      : kind(DiagnosticArgumentKind::Operation),
+        opaqueVal(reinterpret_cast<intptr_t>(val)) {
+    assert(val && "expected valid operation");
+  }
+
+  // Construct from a string reference.
+  explicit DiagnosticArgument(StringRef val)
+      : kind(DiagnosticArgumentKind::String), stringVal(val) {}
+
+  // Construct from a Type.
+  explicit DiagnosticArgument(Type val);
+
+  /// The kind of this argument.
+  DiagnosticArgumentKind kind;
+
+  /// The value of this argument.
+  union {
+    double doubleVal;
+    intptr_t opaqueVal;
+    StringRef stringVal;
+  };
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
+  arg.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// Diagnostic
+//===----------------------------------------------------------------------===//
+
+/// This class contains all of the information necessary to report a diagnostic
+/// to the DiagnosticEngine. It should generally not be constructed directly,
+/// and instead used transitively via InFlightDiagnostic.
+class Diagnostic {
+  using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
+
+  /// This class implements a wrapper iterator around NoteVector::iterator to
+  /// implicitly dereference the unique_ptr.
+  template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
+            typename ResultTy = decltype(**IteratorTy())>
+  class NoteIteratorImpl
+      : public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
+    static ResultTy &unwrap(NotePtrTy note) { return *note; }
+
+  public:
+    NoteIteratorImpl(IteratorTy it)
+        : llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
+                                                                     &unwrap) {}
+  };
+
+public:
+  Diagnostic(Location loc, DiagnosticSeverity severity)
+      : loc(loc), severity(severity) {}
+  Diagnostic(Diagnostic &&) = default;
+  Diagnostic &operator=(Diagnostic &&) = default;
+
+  /// Returns the severity of this diagnostic.
+  DiagnosticSeverity getSeverity() const { return severity; }
+
+  /// Returns the source location for this diagnostic.
+  Location getLocation() const { return loc; }
+
+  /// Returns the current list of diagnostic arguments.
+  MutableArrayRef<DiagnosticArgument> getArguments() { return arguments; }
+  ArrayRef<DiagnosticArgument> getArguments() const { return arguments; }
+
+  /// Stream operator for inserting new diagnostic arguments.
+  template <typename Arg>
+  typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
+                          Diagnostic &>::type
+  operator<<(Arg &&val) {
+    arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
+    return *this;
+  }
+
+  /// Stream in a string literal.
+  Diagnostic &operator<<(const char *val) {
+    arguments.push_back(DiagnosticArgument(val));
+    return *this;
+  }
+
+  /// Stream in a Twine argument.
+  Diagnostic &operator<<(char val);
+  Diagnostic &operator<<(const Twine &val);
+  Diagnostic &operator<<(Twine &&val);
+
+  /// Stream in an Identifier.
+  Diagnostic &operator<<(Identifier val);
+
+  /// Stream in an OperationName.
+  Diagnostic &operator<<(OperationName val);
+
+  /// Stream in a range.
+  template <typename T> Diagnostic &operator<<(llvm::iterator_range<T> range) {
+    return appendRange(range);
+  }
+  template <typename T> Diagnostic &operator<<(llvm::ArrayRef<T> range) {
+    return appendRange(range);
+  }
+
+  /// Append a range to the diagnostic. The default delimiter between elements
+  /// is ','.
+  template <typename T, template <typename> class Container>
+  Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
+    interleave(
+        c, [&](const detail::ValueOfRange<Container<T>> &a) { *this << a; },
+        [&]() { *this << delim; });
+    return *this;
+  }
+
+  /// Append arguments to the diagnostic.
+  template <typename Arg1, typename Arg2, typename... Args>
+  Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&... args) {
+    append(std::forward<Arg1>(arg1));
+    return append(std::forward<Arg2>(arg2), std::forward<Args>(args)...);
+  }
+  /// Append one argument to the diagnostic.
+  template <typename Arg> Diagnostic &append(Arg &&arg) {
+    *this << std::forward<Arg>(arg);
+    return *this;
+  }
+
+  /// Outputs this diagnostic to a stream.
+  void print(raw_ostream &os) const;
+
+  /// Converts the diagnostic to a string.
+  std::string str() const;
+
+  /// Attaches a note to this diagnostic. A new location may be optionally
+  /// provided, if not, then the location defaults to the one specified for this
+  /// diagnostic. Notes may not be attached to other notes.
+  Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None);
+
+  using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
+  using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
+
+  /// Returns the notes held by this diagnostic.
+  llvm::iterator_range<note_iterator> getNotes() {
+    return {notes.begin(), notes.end()};
+  }
+  llvm::iterator_range<const_note_iterator> getNotes() const {
+    return {notes.begin(), notes.end()};
+  }
+
+  /// Allow a diagnostic to be converted to 'failure'.
+  operator LogicalResult() const;
+
+private:
+  Diagnostic(const Diagnostic &rhs) = delete;
+  Diagnostic &operator=(const Diagnostic &rhs) = delete;
+
+  /// The source location.
+  Location loc;
+
+  /// The severity of this diagnostic.
+  DiagnosticSeverity severity;
+
+  /// The current list of arguments.
+  SmallVector<DiagnosticArgument, 4> arguments;
+
+  /// A list of string values used as arguments. This is used to guarantee the
+  /// liveness of non-constant strings used in diagnostics.
+  std::vector<std::unique_ptr<char[]>> strings;
+
+  /// A list of attached notes.
+  NoteVector notes;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const Diagnostic &diag) {
+  diag.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// InFlightDiagnostic
+//===----------------------------------------------------------------------===//
+
+/// This class represents a diagnostic that is inflight and set to be reported.
+/// This allows for last minute modifications of the diagnostic before it is
+/// emitted by a DiagnosticEngine.
+class InFlightDiagnostic {
+public:
+  InFlightDiagnostic() = default;
+  InFlightDiagnostic(InFlightDiagnostic &&rhs)
+      : owner(rhs.owner), impl(std::move(rhs.impl)) {
+    // Reset the rhs diagnostic.
+    rhs.impl.reset();
+    rhs.abandon();
+  }
+  ~InFlightDiagnostic() {
+    if (isInFlight())
+      report();
+  }
+
+  /// Stream operator for new diagnostic arguments.
+  template <typename Arg> InFlightDiagnostic &operator<<(Arg &&arg) & {
+    return append(std::forward<Arg>(arg));
+  }
+  template <typename Arg> InFlightDiagnostic &&operator<<(Arg &&arg) && {
+    return std::move(append(std::forward<Arg>(arg)));
+  }
+
+  /// Append arguments to the diagnostic.
+  template <typename... Args> InFlightDiagnostic &append(Args &&... args) & {
+    assert(isActive() && "diagnostic not active");
+    if (isInFlight())
+      impl->append(std::forward<Args>(args)...);
+    return *this;
+  }
+  template <typename... Args> InFlightDiagnostic &&append(Args &&... args) && {
+    return std::move(append(std::forward<Args>(args)...));
+  }
+
+  /// Attaches a note to this diagnostic.
+  Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None) {
+    assert(isActive() && "diagnostic not active");
+    return impl->attachNote(noteLoc);
+  }
+
+  /// Reports the diagnostic to the engine.
+  void report();
+
+  /// Abandons this diagnostic so that it will no longer be reported.
+  void abandon();
+
+  /// Allow an inflight diagnostic to be converted to 'failure', otherwise
+  /// 'success' if this is an empty diagnostic.
+  operator LogicalResult() const;
+
+private:
+  InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
+  InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
+  InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs)
+      : owner(owner), impl(std::move(rhs)) {}
+
+  /// Returns if the diagnostic is still active, i.e. it has a live diagnostic.
+  bool isActive() const { return impl.hasValue(); }
+
+  /// Returns if the diagnostic is still in flight to be reported.
+  bool isInFlight() const { return owner; }
+
+  // Allow access to the constructor.
+  friend DiagnosticEngine;
+
+  /// The engine that this diagnostic is to report to.
+  DiagnosticEngine *owner;
+
+  /// The raw diagnostic that is inflight to be reported.
+  llvm::Optional<Diagnostic> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// DiagnosticEngine
+//===----------------------------------------------------------------------===//
+
+/// This class is the main interface for diagnostics. The DiagnosticEngine
+/// manages the registration of diagnostic handlers as well as the core API for
+/// diagnostic emission. This class should not be constructed directly, but
+/// instead interfaced with via an MLIRContext instance.
+class DiagnosticEngine {
+public:
+  ~DiagnosticEngine();
+
+  // Diagnostic handler registration and use.  MLIR supports the ability for the
+  // IR to carry arbitrary metadata about operation location information.  If a
+  // problem is detected by the compiler, it can invoke the emitError /
+  // emitWarning / emitRemark method on an Operation and have it get reported
+  // through this interface.
+  //
+  // Tools using MLIR are encouraged to register error handlers and define a
+  // schema for their location information.  If they don't, then warnings and
+  // notes will be dropped and errors will be emitted to errs.
+
+  using HandlerTy = std::function<void(Diagnostic)>;
+
+  /// Set the diagnostic handler for this engine. Note that this replaces any
+  /// existing handler.
+  void setHandler(const HandlerTy &handler);
+
+  /// Return the current diagnostic handler, or null if none is present.
+  HandlerTy getHandler();
+
+  /// Create a new inflight diagnostic with the given location and severity.
+  InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity) {
+    assert(severity != DiagnosticSeverity::Note &&
+           "notes should not be emitted directly");
+    return InFlightDiagnostic(this, Diagnostic(loc, severity));
+  }
+
+  /// Emit a diagnostic using the registered issue handler if present, or with
+  /// the default behavior if not.
+  void emit(Diagnostic diag);
+
+private:
+  friend class MLIRContextImpl;
+  DiagnosticEngine();
+
+  /// The internal implementation of the DiagnosticEngine.
+  std::unique_ptr<detail::DiagnosticEngineImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// ScopedDiagnosticHandler
+//===----------------------------------------------------------------------===//
+
+/// This diagnostic handler is a simple RAII class that saves and restores the
+/// current diagnostic handler registered to a given context. This class can
+/// be either be used directly, or in conjunction with a derived diagnostic
+/// handler.
+class ScopedDiagnosticHandler {
+public:
+  ScopedDiagnosticHandler(MLIRContext *ctx);
+  ScopedDiagnosticHandler(MLIRContext *ctx,
+                          const DiagnosticEngine::HandlerTy &handler);
+  ~ScopedDiagnosticHandler();
+
+  /// Propagate a diagnostic to the existing diagnostic handler.
+  void propagateDiagnostic(Diagnostic diag) {
+    if (existingHandler)
+      existingHandler(std::move(diag));
+  }
+
+private:
+  /// The existing diagnostic handler registered with the context at the time of
+  /// construction.
+  DiagnosticEngine::HandlerTy existingHandler;
+
+  /// The context to register the handler back to.
+  MLIRContext *ctx;
+};
+
+/// Utility method to emit an error message using this location.
+InFlightDiagnostic emitError(Location loc);
+InFlightDiagnostic emitError(Location loc, const Twine &message);
+
+/// Utility method to emit a warning message using this location.
+InFlightDiagnostic emitWarning(Location loc);
+InFlightDiagnostic emitWarning(Location loc, const Twine &message);
+
+/// Utility method to emit a remark message using this location.
+InFlightDiagnostic emitRemark(Location loc);
+InFlightDiagnostic emitRemark(Location loc, const Twine &message);
+
+//===----------------------------------------------------------------------===//
+// SourceMgrDiagnosticHandler
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct SourceMgrDiagnosticHandlerImpl;
+} // end namespace detail
+
+/// This class is a utility diagnostic handler for use with llvm::SourceMgr.
+class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
+public:
+  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
+                             llvm::raw_ostream &os);
+  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
+  ~SourceMgrDiagnosticHandler();
+
+  /// Emit the given diagnostic information with the held source manager.
+  void emitDiagnostic(Location loc, Twine message, DiagnosticSeverity kind);
+
+protected:
+  /// Emit the given diagnostic with the held source manager.
+  void emitDiagnostic(Diagnostic &diag);
+
+  /// Get a memory buffer for the given file, or nullptr if no file is
+  /// available.
+  const llvm::MemoryBuffer *getBufferForFile(StringRef filename);
+
+  /// The source manager that we are wrapping.
+  llvm::SourceMgr &mgr;
+
+  /// The output stream to use when printing diagnostics.
+  llvm::raw_ostream &os;
+
+private:
+  /// Convert a location into the given memory buffer into an SMLoc.
+  llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc);
+
+  /// The maximum depth that a call stack will be printed.
+  /// TODO(riverriddle) This should be a tunable flag.
+  unsigned callStackLimit = 10;
+
+  std::unique_ptr<detail::SourceMgrDiagnosticHandlerImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// SourceMgrDiagnosticVerifierHandler
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct SourceMgrDiagnosticVerifierHandlerImpl;
+} // end namespace detail
+
+/// This class is a utility diagnostic handler for use with llvm::SourceMgr that
+/// verifies that emitted diagnostics match 'expected-*' lines on the
+/// corresponding line of the source file.
+class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
+public:
+  SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
+                                     llvm::raw_ostream &out);
+  SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
+  ~SourceMgrDiagnosticVerifierHandler();
+
+  /// Returns the status of the handler and verifies that all expected
+  /// diagnostics were emitted. This return success if all diagnostics were
+  /// verified correctly, failure otherwise.
+  LogicalResult verify();
+
+private:
+  /// Process a single diagnostic.
+  void process(Diagnostic &diag);
+
+  /// Process a FileLineColLoc diagnostic.
+  void process(FileLineColLoc loc, StringRef msg, DiagnosticSeverity kind);
+
+  std::unique_ptr<detail::SourceMgrDiagnosticVerifierHandlerImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// ParallelDiagnosticHandler
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct ParallelDiagnosticHandlerImpl;
+} // end namespace detail
+
+/// This class is a utility diagnostic handler for use when multi-threading some
+/// part of the compiler where diagnostics may be emitted. This handler ensures
+/// a deterministic ordering to the emitted diagnostics that mirrors that of a
+/// single-threaded compilation.
+class ParallelDiagnosticHandler {
+public:
+  ParallelDiagnosticHandler(MLIRContext *ctx);
+  ~ParallelDiagnosticHandler();
+
+  /// Set the order id for the current thread. This is required to be set by
+  /// each thread that will be emitting diagnostics to this handler. The orderID
+  /// corresponds to the order in which diagnostics would be emitted when
+  /// executing synchronously. For example, if we were processing a list
+  /// of operations [a, b, c] on a single-thread. Diagnostics emitted while
+  /// processing operation 'a' would be emitted before those for 'b' or 'c'.
+  /// This corresponds 1-1 with the 'orderID'. The thread that is processing 'a'
+  /// should set the orderID to '0'; the thread processing 'b' should set it to
+  /// '1'; and so on and so forth. This provides a way for the handler to
+  /// deterministically order the diagnostics that it receives given the thread
+  /// that it is receiving on.
+  void setOrderIDForThread(size_t orderID);
+
+private:
+  std::unique_ptr<detail::ParallelDiagnosticHandlerImpl> impl;
+};
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Dialect.h b/third_party/mlir/include/mlir/IR/Dialect.h
new file mode 100644
index 0000000..eef7711
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Dialect.h
@@ -0,0 +1,299 @@
+//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the 'dialect' abstraction.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECT_H
+#define MLIR_IR_DIALECT_H
+
+#include "mlir/IR/OperationSupport.h"
+
+namespace mlir {
+class OpBuilder;
+class Type;
+
+using DialectConstantDecodeHook =
+    std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
+using DialectConstantFoldHook = std::function<LogicalResult(
+    Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
+using DialectExtractElementHook =
+    std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
+
+/// Dialects are groups of MLIR operations and behavior associated with the
+/// entire group.  For example, hooks into other systems for constant folding,
+/// default named types for asm printing, etc.
+///
+/// Instances of the dialect object are global across all MLIRContext's that may
+/// be active in the process.
+///
+class Dialect {
+public:
+  virtual ~Dialect();
+
+  /// Utility function that returns if the given string is a valid dialect
+  /// namespace.
+  static bool isValidNamespace(StringRef str);
+
+  MLIRContext *getContext() const { return context; }
+
+  StringRef getNamespace() const { return name; }
+
+  /// Returns true if this dialect allows for unregistered operations, i.e.
+  /// operations prefixed with the dialect namespace but not registered with
+  /// addOperation.
+  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
+
+  /// Return true if this dialect allows for unregistered types, i.e., types
+  /// prefixed with the dialect namespace but not registered with addType.
+  /// These are represented with OpaqueType.
+  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
+
+  //===--------------------------------------------------------------------===//
+  // Constant Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Registered fallback constant fold hook for the dialect. Like the constant
+  /// fold hook of each operation, it attempts to constant fold the operation
+  /// with the specified constant operand values - the elements in "operands"
+  /// will correspond directly to the operands of the operation, but may be null
+  /// if non-constant.  If constant folding is successful, this fills in the
+  /// `results` vector.  If not, this returns failure and `results` is
+  /// unspecified.
+  DialectConstantFoldHook constantFoldHook =
+      [](Operation *op, ArrayRef<Attribute> operands,
+         SmallVectorImpl<Attribute> &results) { return failure(); };
+
+  /// Registered hook to decode opaque constants associated with this
+  /// dialect. The hook function attempts to decode an opaque constant tensor
+  /// into a tensor with non-opaque content. If decoding is successful, this
+  /// method returns false and sets 'output' attribute. If not, it returns true
+  /// and leaves 'output' unspecified. The default hook fails to decode.
+  DialectConstantDecodeHook decodeHook =
+      [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
+
+  /// Registered hook to extract an element from an opaque constant associated
+  /// with this dialect. If element has been successfully extracted, this
+  /// method returns that element. If not, it returns an empty attribute.
+  /// The default hook fails to extract an element.
+  DialectExtractElementHook extractElementHook =
+      [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
+        return Attribute();
+      };
+
+  /// Registered hook to materialize a single constant operation from a given
+  /// attribute value with the desired resultant type. This method should use
+  /// the provided builder to create the operation without changing the
+  /// insertion position. The generated operation is expected to be constant
+  /// like, i.e. single result, zero operands, non side-effecting, etc. On
+  /// success, this hook should return the value generated to represent the
+  /// constant value. Otherwise, it should return null on failure.
+  virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
+                                         Type type, Location loc) {
+    return nullptr;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Parsing Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
+  /// refers to the expected type of the attribute.
+  virtual Attribute parseAttribute(StringRef attrData, Type type,
+                                   Location loc) const;
+
+  /// Print an attribute registered to this dialect. Note: The type of the
+  /// attribute need not be printed by this method as it is always printed by
+  /// the caller.
+  virtual void printAttribute(Attribute, raw_ostream &) const {
+    llvm_unreachable("dialect has no registered attribute printing hook");
+  }
+
+  /// Parse a type registered to this dialect.
+  virtual Type parseType(StringRef tyData, Location loc) const;
+
+  /// Print a type registered to this dialect.
+  virtual void printType(Type, raw_ostream &) const {
+    llvm_unreachable("dialect has no registered type printing hook");
+  }
+
+  /// Registered hooks for getting identifier aliases for symbols. The
+  /// identifier is used in place of the symbol when printing textual IR.
+  ///
+  /// Hook for defining Attribute kind aliases. This will generate an alias for
+  /// all attributes of the given kind in the form : <alias>[0-9]+. These
+  /// aliases must not contain `.`.
+  virtual void getAttributeKindAliases(
+      SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) {}
+  /// Hook for defining Attribute aliases. These aliases must not contain `.` or
+  /// end with a numeric digit([0-9]+).
+  virtual void getAttributeAliases(
+      SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) {}
+  /// Hook for defining Type aliases.
+  virtual void
+  getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
+
+  //===--------------------------------------------------------------------===//
+  // Verification Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Verify an attribute from this dialect on the argument at 'argIndex' for
+  /// the region at 'regionIndex' on the given operation. Returns failure if
+  /// the verification failed, success otherwise. This hook may optionally be
+  /// invoked from any operation containing a region.
+  virtual LogicalResult verifyRegionArgAttribute(Operation *,
+                                                 unsigned regionIndex,
+                                                 unsigned argIndex,
+                                                 NamedAttribute);
+
+  /// Verify an attribute from this dialect on the given operation. Returns
+  /// failure if the verification failed, success otherwise.
+  virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
+    return success();
+  }
+
+protected:
+  /// The constructor takes a unique namespace for this dialect as well as the
+  /// context to bind to.
+  /// Note: The namespace must not contain '.' characters.
+  /// Note: All operations belonging to this dialect must have names starting
+  ///       with the namespace followed by '.'.
+  /// Example:
+  ///       - "tf" for the TensorFlow ops like "tf.add".
+  Dialect(StringRef name, MLIRContext *context);
+
+  /// This method is used by derived classes to add their operations to the set.
+  ///
+  template <typename... Args> void addOperations() {
+    VariadicOperationAdder<Args...>::addToSet(*this);
+  }
+
+  // It would be nice to define this as variadic functions instead of a nested
+  // variadic type, but we can't do that: function template partial
+  // specialization is not allowed, and we can't define an overload set because
+  // we don't have any arguments of the types we are pushing around.
+  template <typename First, typename... Rest> class VariadicOperationAdder {
+  public:
+    static void addToSet(Dialect &dialect) {
+      dialect.addOperation(AbstractOperation::get<First>(dialect));
+      VariadicOperationAdder<Rest...>::addToSet(dialect);
+    }
+  };
+
+  template <typename First> class VariadicOperationAdder<First> {
+  public:
+    static void addToSet(Dialect &dialect) {
+      dialect.addOperation(AbstractOperation::get<First>(dialect));
+    }
+  };
+
+  void addOperation(AbstractOperation opInfo);
+
+  /// This method is used by derived classes to add their types to the set.
+  template <typename... Args> void addTypes() {
+    VariadicSymbolAdder<Args...>::addToSet(*this);
+  }
+
+  /// This method is used by derived classes to add their attributes to the set.
+  template <typename... Args> void addAttributes() {
+    VariadicSymbolAdder<Args...>::addToSet(*this);
+  }
+
+  // It would be nice to define this as variadic functions instead of a nested
+  // variadic type, but we can't do that: function template partial
+  // specialization is not allowed, and we can't define an overload set
+  // because we don't have any arguments of the types we are pushing around.
+  template <typename First, typename... Rest> struct VariadicSymbolAdder {
+    static void addToSet(Dialect &dialect) {
+      VariadicSymbolAdder<First>::addToSet(dialect);
+      VariadicSymbolAdder<Rest...>::addToSet(dialect);
+    }
+  };
+
+  template <typename First> struct VariadicSymbolAdder<First> {
+    static void addToSet(Dialect &dialect) {
+      dialect.addSymbol(First::getClassID());
+    }
+  };
+
+  /// Enable support for unregistered operations.
+  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
+
+  /// Enable support for unregistered types.
+  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
+
+private:
+  // Register a symbol(e.g. type) with its given unique class identifier.
+  void addSymbol(const ClassID *const classID);
+
+  Dialect(const Dialect &) = delete;
+  void operator=(Dialect &) = delete;
+
+  /// Register this dialect object with the specified context.  The context
+  /// takes ownership of the heap allocated dialect.
+  void registerDialect(MLIRContext *context);
+
+  /// The namespace of this dialect.
+  StringRef name;
+
+  /// This is the context that owns this Dialect object.
+  MLIRContext *context;
+
+  /// Flag that specifies whether this dialect supports unregistered operations,
+  /// i.e. operations prefixed with the dialect namespace but not registered
+  /// with addOperation.
+  bool unknownOpsAllowed = false;
+
+  /// Flag that specifies whether this dialect allows unregistered types, i.e.
+  /// types prefixed with the dialect namespace but not registered with addType.
+  /// These types are represented with OpaqueType.
+  bool unknownTypesAllowed = false;
+};
+
+using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+
+/// Registers a specific dialect creation function with the system, typically
+/// used through the DialectRegistration template.
+void registerDialectAllocator(const DialectAllocatorFunction &function);
+
+/// Registers all dialects with the specified MLIRContext.
+void registerAllDialects(MLIRContext *context);
+
+/// Utility to register a dialect. Client can register their dialect with the
+/// global registry by calling registerDialect<MyDialect>();
+template <typename ConcreteDialect> void registerDialect() {
+  registerDialectAllocator([](MLIRContext *ctx) {
+    // Just allocate the dialect, the context takes ownership of it.
+    new ConcreteDialect(ctx);
+  });
+}
+
+/// DialectRegistration provides a global initialiser that registers a Dialect
+/// allocation routine.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static DialectRegistration<MyDialect> Unused;
+template <typename ConcreteDialect> struct DialectRegistration {
+  DialectRegistration() { registerDialect<ConcreteDialect>(); }
+};
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/DialectHooks.h b/third_party/mlir/include/mlir/IR/DialectHooks.h
new file mode 100644
index 0000000..f368988
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/DialectHooks.h
@@ -0,0 +1,82 @@
+//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines abstraction and registration mechanism for dialect hooks.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECT_HOOKS_H
+#define MLIR_IR_DIALECT_HOOKS_H
+
+#include "mlir/IR/Dialect.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+using DialectHooksSetter = std::function<void(MLIRContext *)>;
+
+/// Dialect hooks allow external components to register their functions to
+/// be called for specific tasks specialized per dialect, such as decoding
+/// of opaque constants. To register concrete dialect hooks, one should
+/// define a DialectHooks subclass and use it as a template
+/// argument to DialectHooksRegistration. For example,
+///     class MyHooks : public DialectHooks {...};
+///     static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
+/// The subclass should override DialectHook methods for supported hooks.
+class DialectHooks {
+public:
+  // Returns hook to constant fold an operation.
+  DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
+  // Returns hook to decode opaque constant tensor.
+  DialectConstantDecodeHook getDecodeHook() { return nullptr; }
+  // Returns hook to extract an element of an opaque constant tensor.
+  DialectExtractElementHook getExtractElementHook() { return nullptr; }
+};
+
+/// Registers a function that will set hooks in the registered dialects
+/// based on information coming from DialectHooksRegistration.
+void registerDialectHooksSetter(const DialectHooksSetter &function);
+
+/// DialectHooksRegistration provides a global initialiser that registers
+/// a dialect hooks setter routine.
+/// Usage:
+///
+///   // At namespace scope.
+///   static DialectHooksRegistration<MyHooks, MyDialect> unused;
+template <typename ConcreteHooks> struct DialectHooksRegistration {
+  DialectHooksRegistration(StringRef dialectName) {
+    registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
+      Dialect *dialect = ctx->getRegisteredDialect(dialectName);
+      if (!dialect) {
+        llvm::errs() << "error: cannot register hooks for unknown dialect '"
+                     << dialectName << "'\n";
+        abort();
+      }
+      // Set hooks.
+      ConcreteHooks hooks;
+      if (auto h = hooks.getConstantFoldHook())
+        dialect->constantFoldHook = h;
+      if (auto h = hooks.getDecodeHook())
+        dialect->decodeHook = h;
+      if (auto h = hooks.getExtractElementHook())
+        dialect->extractElementHook = h;
+    });
+  }
+};
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def b/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def
new file mode 100644
index 0000000..bf9fc1d
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def
@@ -0,0 +1,48 @@
+//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file enumerates the different dialects that define custom classes
+// within the attribute or type system.
+//
+//===----------------------------------------------------------------------===//
+
+DEFINE_SYM_KIND_RANGE(STANDARD)
+DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
+DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
+DEFINE_SYM_KIND_RANGE(TENSORFLOW)
+DEFINE_SYM_KIND_RANGE(LLVM)
+DEFINE_SYM_KIND_RANGE(QUANTIZATION)
+DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
+DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
+DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
+DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
+DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
+
+// The following ranges are reserved for experimenting with MLIR dialects in a
+// private context without having to register them here.
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
+
+#undef DEFINE_SYM_KIND_RANGE
diff --git a/third_party/mlir/include/mlir/IR/Function.h b/third_party/mlir/include/mlir/IR/Function.h
new file mode 100644
index 0000000..73da52f
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Function.h
@@ -0,0 +1,160 @@
+//===- Function.h - MLIR Function Class -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Functions are the basic unit of composition in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_FUNCTION_H
+#define MLIR_IR_FUNCTION_H
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+//===--------------------------------------------------------------------===//
+// Function Operation.
+//===--------------------------------------------------------------------===//
+
+/// FuncOp represents a function, or an operation containing one region that
+/// forms a CFG(Control Flow Graph). The region of a function is not allowed to
+/// implicitly capture global values, and all external references must use
+/// Function arguments or attributes that establish a symbolic connection(e.g.
+/// symbols referenced by name via a string attribute).
+class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+                         OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike> {
+public:
+  using Op::Op;
+  using Op::print;
+
+  static StringRef getOperationName() { return "func"; }
+
+  static FuncOp create(Location location, StringRef name, FunctionType type,
+                       ArrayRef<NamedAttribute> attrs = {});
+  static FuncOp create(Location location, StringRef name, FunctionType type,
+                       llvm::iterator_range<dialect_attr_iterator> attrs);
+  static FuncOp create(Location location, StringRef name, FunctionType type,
+                       ArrayRef<NamedAttribute> attrs,
+                       ArrayRef<NamedAttributeList> argAttrs);
+
+  static void build(Builder *builder, OperationState *result, StringRef name,
+                    FunctionType type, ArrayRef<NamedAttribute> attrs);
+  static void build(Builder *builder, OperationState *result, StringRef name,
+                    FunctionType type, ArrayRef<NamedAttribute> attrs,
+                    ArrayRef<NamedAttributeList> argAttrs);
+
+  /// Operation hooks.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+
+  /// Returns the type of this function.
+  FunctionType getType() {
+    return getAttrOfType<TypeAttr>(getTypeAttrName())
+        .getValue()
+        .cast<FunctionType>();
+  }
+
+  /// Change the type of this function in place. This is an extremely dangerous
+  /// operation and it is up to the caller to ensure that this is legal for this
+  /// function, and to restore invariants:
+  ///  - the entry block args must be updated to match the function params.
+  ///  - the arguments attributes may need an update: if the new type has less
+  ///    parameters we drop the extra attributes, if there are more parameters
+  ///    they won't have any attributes.
+  void setType(FunctionType newType) {
+    setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  }
+
+  /// Create a deep copy of this function and all of its blocks, remapping
+  /// any operands that use values outside of the function using the map that is
+  /// provided (leaving them alone if no entry is present). If the mapper
+  /// contains entries for function arguments, these arguments are not included
+  /// in the new function. Replaces references to cloned sub-values with the
+  /// corresponding value that is copied, and adds those mappings to the mapper.
+  FuncOp clone(BlockAndValueMapping &mapper);
+  FuncOp clone();
+
+  /// Clone the internal blocks and attributes from this function into dest. Any
+  /// cloned blocks are appended to the back of dest. This function asserts that
+  /// the attributes of the current function and dest are compatible.
+  void cloneInto(FuncOp dest, BlockAndValueMapping &mapper);
+
+  //===--------------------------------------------------------------------===//
+  // Body Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Add an entry block to an empty function, and set up the block arguments
+  /// to match the signature of the function. The newly inserted entry block is
+  /// returned.
+  Block *addEntryBlock();
+
+private:
+  // This trait needs access to `getNumFuncArguments` and `verifyType` hooks
+  // defined below.
+  friend class OpTrait::FunctionLike<FuncOp>;
+
+  /// Returns the number of arguments. This is a hook for OpTrait::FunctionLike.
+  unsigned getNumFuncArguments() { return getType().getInputs().size(); }
+
+  /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
+  /// attribute is present and checks if it holds a function type.  Ensures
+  /// getType and getNumFuncArguments can be called safely.
+  LogicalResult verifyType() {
+    auto type = getTypeAttr().getValue();
+    if (!type.isa<FunctionType>())
+      return emitOpError("requires '" + getTypeAttrName() +
+                         "' attribute of function type");
+    return success();
+  }
+};
+} // end namespace mlir
+
+namespace llvm {
+
+// Functions hash just like pointers.
+template <> struct DenseMapInfo<mlir::FuncOp> {
+  static mlir::FuncOp getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::FuncOp::getFromOpaquePointer(pointer);
+  }
+  static mlir::FuncOp getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::FuncOp::getFromOpaquePointer(pointer);
+  }
+  static unsigned getHashValue(mlir::FuncOp val) {
+    return hash_value(val.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::FuncOp LHS, mlir::FuncOp RHS) { return LHS == RHS; }
+};
+
+/// Allow stealing the low bits of FuncOp.
+template <> struct PointerLikeTypeTraits<mlir::FuncOp> {
+public:
+  static inline void *getAsVoidPointer(mlir::FuncOp I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::FuncOp getFromVoidPointer(void *P) {
+    return mlir::FuncOp::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_FUNCTION_H
diff --git a/third_party/mlir/include/mlir/IR/FunctionSupport.h b/third_party/mlir/include/mlir/IR/FunctionSupport.h
new file mode 100644
index 0000000..75a0a67
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/FunctionSupport.h
@@ -0,0 +1,407 @@
+//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support types for Operations that represent function-like
+// constructs to use.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_FUNCTIONSUPPORT_H
+#define MLIR_IR_FUNCTIONSUPPORT_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallString.h"
+
+namespace mlir {
+
+namespace impl {
+/// Return the name of the attribute used for function types.
+inline StringRef getTypeAttrName() { return "type"; }
+
+/// Return the name of the attribute used for function arguments.
+inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
+  out.clear();
+  return ("arg" + Twine(arg)).toStringRef(out);
+}
+
+/// Returns the dictionary attribute corresponding to the argument at 'index'.
+/// If there are no argument attributes at 'index', a null attribute is
+/// returned.
+inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
+  SmallString<8> nameOut;
+  return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
+}
+
+/// Return all of the attributes for the argument at 'index'.
+inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
+  auto argDict = getArgAttrDict(op, index);
+  return argDict ? argDict.getValue() : llvm::None;
+}
+
+/// A named class for passing around the variadic flag.
+class VariadicFlag {
+public:
+  explicit VariadicFlag(bool variadic) : variadic(variadic) {}
+  bool isVariadic() const { return variadic; }
+
+private:
+  /// Underlying storage.
+  bool variadic;
+};
+
+/// Callback type for `parseFunctionLikeOp`, the callback should produce the
+/// type that will be associated with a function-like operation from lists of
+/// function arguments and results, VariadicFlag indicates whether the function
+/// should have variadic arguments; in case of error, it may populate the last
+/// argument with a message.
+using FuncTypeBuilder = llvm::function_ref<Type(
+    Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
+
+/// Parser implementation for function-like operations.  Uses
+/// `funcTypeBuilder` to construct the custom function type given lists of
+/// input and output types.  If `allowVariadic` is set, the parser will accept
+/// trailing ellipsis in the function signature and indicate to the builder
+/// whether the function is variadic.  If the builder returns a null type,
+/// `result` will not contain the `type` attribute.  The caller can then add a
+/// type, report the error or delegate the reporting to the op's verifier.
+ParseResult parseFunctionLikeOp(OpAsmParser *parser, OperationState *result,
+                                bool allowVariadic,
+                                FuncTypeBuilder funcTypeBuilder);
+
+/// Printer implementation for function-like operations.  Accepts lists of
+/// argument and result types to use while printing.
+void printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
+                         ArrayRef<Type> argTypes, bool isVariadic,
+                         ArrayRef<Type> results);
+
+} // namespace impl
+
+namespace OpTrait {
+
+/// This trait provides APIs for Ops that behave like functions.  In particular:
+/// - Ops can be used with SymbolTable in the parent Op and have names;
+/// - Ops have a single region with multiple blocks that corresponds to the body
+///   of the function;
+/// - the absence of a region corresonds to an external function;
+/// - arguments of the first block of the region are treated as function
+///   arguments;
+/// - they can have argument attributes that are stored in a dictionary
+///   attribute on the Op itself.
+/// This trait does *NOT* provide type support for the functions, meaning that
+/// concrete Ops must handle the type of the declared or defined function.
+/// `getTypeAttrName()` is a convenience function that returns the name of the
+/// attribute that can be used to store the function type, but the trait makes
+/// no assumption based on it.
+///
+/// - Concrete ops *must* define a member function `getNumFuncArguments()` that
+/// returns the number of function arguments based exclusively on type (so that
+/// it can be called on function declarations).
+/// - To verify that the type respects op-specific invariants, concrete ops may
+/// redefine the `verifyType()` hook that will be called after verifying the
+/// presence of the `type` attribute and before any call to
+/// `getNumFuncArguments` from the verifier.
+template <typename ConcreteType>
+class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
+public:
+  /// Verify that all of the argument attributes are dialect attributes.
+  static LogicalResult verifyTrait(Operation *op);
+
+  //===--------------------------------------------------------------------===//
+  // Name Handling.
+  //===--------------------------------------------------------------------===//
+
+  /// Returns the name of this function.
+  StringRef getName() {
+    return this->getOperation()
+        ->template getAttrOfType<StringAttr>(
+            mlir::SymbolTable::getSymbolAttrName())
+        .getValue();
+  }
+
+  /// Set the name of this function.
+  void setName(StringRef name) {
+    this->getOperation()->setAttr(
+        mlir::SymbolTable::getSymbolAttrName(),
+        StringAttr::get(name, this->getOperation()->getContext()));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Body Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Returns true if this function is external, i.e. it has no body.
+  bool isExternal() { return empty(); }
+
+  Region &getBody() { return this->getOperation()->getRegion(0); }
+
+  /// Delete all blocks from this function.
+  void eraseBody() {
+    getBody().dropAllReferences();
+    getBody().getBlocks().clear();
+  }
+
+  /// This is the list of blocks in the function.
+  using RegionType = Region::RegionType;
+  RegionType &getBlocks() { return getBody().getBlocks(); }
+
+  // Iteration over the block in the function.
+  using iterator = RegionType::iterator;
+  using reverse_iterator = RegionType::reverse_iterator;
+
+  iterator begin() { return getBody().begin(); }
+  iterator end() { return getBody().end(); }
+  reverse_iterator rbegin() { return getBody().rbegin(); }
+  reverse_iterator rend() { return getBody().rend(); }
+
+  bool empty() { return getBody().empty(); }
+  void push_back(Block *block) { getBody().push_back(block); }
+  void push_front(Block *block) { getBody().push_front(block); }
+
+  Block &back() { return getBody().back(); }
+  Block &front() { return getBody().front(); }
+
+  //===--------------------------------------------------------------------===//
+  // Type Attribute Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Return the name of the attribute used for function types.
+  static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); }
+
+  TypeAttr getTypeAttr() {
+    return this->getOperation()->template getAttrOfType<TypeAttr>(
+        getTypeAttrName());
+  }
+
+  bool isTypeAttrValid() {
+    auto typeAttr = getTypeAttr();
+    if (!typeAttr)
+      return false;
+    return typeAttr.getValue() != Type{};
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Argument Handling
+  //===--------------------------------------------------------------------===//
+
+  unsigned getNumArguments() {
+    return static_cast<ConcreteType *>(this)->getNumFuncArguments();
+  }
+
+  /// Gets argument.
+  BlockArgument *getArgument(unsigned idx) {
+    return getBlocks().front().getArgument(idx);
+  }
+
+  // Supports non-const operand iteration.
+  using args_iterator = Block::args_iterator;
+  args_iterator args_begin() { return front().args_begin(); }
+  args_iterator args_end() { return front().args_end(); }
+  llvm::iterator_range<args_iterator> getArguments() {
+    return {args_begin(), args_end()};
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Argument Attributes
+  //===--------------------------------------------------------------------===//
+
+  /// FunctionLike operations allow for attaching attributes to each of the
+  /// respective function arguments. These argument attributes are stored as
+  /// DictionaryAttrs in the main operation attribute dictionary. The name of
+  /// these entries is `arg` followed by the index of the argument. These
+  /// argument attribute dictionaries are optional, and will generally only
+  /// exist if they are non-empty.
+
+  /// Return all of the attributes for the argument at 'index'.
+  ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
+    return ::mlir::impl::getArgAttrs(this->getOperation(), index);
+  }
+
+  /// Return all argument attributes of this function.
+  void getAllArgAttrs(SmallVectorImpl<NamedAttributeList> &result) {
+    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+      result.emplace_back(getArgAttrDict(i));
+  }
+
+  /// Return the specified attribute, if present, for the argument at 'index',
+  /// null otherwise.
+  Attribute getArgAttr(unsigned index, Identifier name) {
+    auto argDict = getArgAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+  Attribute getArgAttr(unsigned index, StringRef name) {
+    auto argDict = getArgAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+
+  template <typename AttrClass>
+  AttrClass getArgAttrOfType(unsigned index, Identifier name) {
+    return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
+  }
+  template <typename AttrClass>
+  AttrClass getArgAttrOfType(unsigned index, StringRef name) {
+    return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
+  }
+
+  /// Set the attributes held by the argument at 'index'.
+  void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
+  void setArgAttrs(unsigned index, NamedAttributeList attributes);
+  void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
+    assert(attributes.size() == getNumArguments());
+    for (unsigned i = 0, e = attributes.size(); i != e; ++i)
+      setArgAttrs(i, attributes[i]);
+  }
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value. Otherwise, add a new attribute with the specified name/value.
+  void setArgAttr(unsigned index, Identifier name, Attribute value);
+  void setArgAttr(unsigned index, StringRef name, Attribute value) {
+    setArgAttr(index, Identifier::get(name, this->getOperation()->getContext()),
+               value);
+  }
+
+  /// Remove the attribute 'name' from the argument at 'index'.
+  NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
+                                                 Identifier name);
+
+protected:
+  /// Returns the attribute entry name for the set of argument attributes at
+  /// index 'arg'.
+  static StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
+    return ::mlir::impl::getArgAttrName(arg, out);
+  }
+
+  /// Returns the dictionary attribute corresponding to the argument at 'index'.
+  /// If there are no argument attributes at 'index', a null attribute is
+  /// returned.
+  DictionaryAttr getArgAttrDict(unsigned index) {
+    assert(index < getNumArguments() && "invalid argument number");
+    return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
+  }
+
+  /// Hook for concrete classes to verify that the type attribute respects
+  /// op-specific invariants.  Default implementation always succeeds.
+  LogicalResult verifyType() { return success(); }
+};
+
+template <typename ConcreteType>
+LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
+  MLIRContext *ctx = op->getContext();
+  auto funcOp = cast<ConcreteType>(op);
+
+  if (!funcOp.isTypeAttrValid())
+    return funcOp.emitOpError("requires a type attribute '")
+           << getTypeAttrName() << '\'';
+
+  if (failed(funcOp.verifyType()))
+    return failure();
+
+  for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
+    // Verify that all of the argument attributes are dialect attributes, i.e.
+    // that they contain a dialect prefix in their name.  Call the dialect, if
+    // registered, to verify the attributes themselves.
+    for (auto attr : funcOp.getArgAttrs(i)) {
+      if (!attr.first.strref().contains('.'))
+        return funcOp.emitOpError("arguments may only have dialect attributes");
+      auto dialectNamePair = attr.first.strref().split('.');
+      if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+        if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
+                                                     /*argIndex=*/i, attr)))
+          return failure();
+      }
+    }
+  }
+
+  // Check that the op has exactly one region for the body.
+  if (op->getNumRegions() != 1)
+    return funcOp.emitOpError("expects one region");
+
+  // Check that if the entry block exists, it has the same number of arguments
+  // as the function-like operation.
+  if (funcOp.isExternal())
+    return success();
+
+  unsigned numArguments = funcOp.getNumArguments();
+  if (funcOp.front().getNumArguments() != numArguments)
+    return funcOp.emitOpError("entry block must have ")
+           << numArguments << " arguments to match function signature";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Function Argument Attribute.
+//===----------------------------------------------------------------------===//
+
+/// Set the attributes held by the argument at 'index'.
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setArgAttrs(
+    unsigned index, ArrayRef<NamedAttribute> attributes) {
+  assert(index < getNumArguments() && "invalid argument number");
+  SmallString<8> nameOut;
+  getArgAttrName(index, nameOut);
+  Operation *op = this->getOperation();
+
+  if (attributes.empty())
+    return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
+  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+}
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setArgAttrs(unsigned index,
+                                             NamedAttributeList attributes) {
+  assert(index < getNumArguments() && "invalid argument number");
+  SmallString<8> nameOut;
+  if (auto newAttr = attributes.getDictionary())
+    return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
+                                         newAttr);
+  static_cast<ConcreteType *>(this)->removeAttr(getArgAttrName(index, nameOut));
+}
+
+/// If the an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setArgAttr(unsigned index, Identifier name,
+                                            Attribute value) {
+  auto curAttr = getArgAttrDict(index);
+  NamedAttributeList attrList(curAttr);
+  attrList.set(name, value);
+
+  // If the attribute changed, then set the new arg attribute list.
+  if (curAttr != attrList.getDictionary())
+    setArgAttrs(index, attrList);
+}
+
+/// Remove the attribute 'name' from the argument at 'index'.
+template <typename ConcreteType>
+NamedAttributeList::RemoveResult
+FunctionLike<ConcreteType>::removeArgAttr(unsigned index, Identifier name) {
+  // Build an attribute list and remove the attribute at 'name'.
+  NamedAttributeList attrList(getArgAttrDict(index));
+  auto result = attrList.remove(name);
+
+  // If the attribute was removed, then update the argument dictionary.
+  if (result == NamedAttributeList::RemoveResult::Removed)
+    setArgAttrs(index, attrList);
+  return result;
+}
+
+} // end namespace OpTrait
+
+} // end namespace mlir
+
+#endif // MLIR_IR_FUNCTIONSUPPORT_H
diff --git a/third_party/mlir/include/mlir/IR/Identifier.h b/third_party/mlir/include/mlir/IR/Identifier.h
new file mode 100644
index 0000000..bc84c20
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Identifier.h
@@ -0,0 +1,143 @@
+//===- Identifier.h - MLIR Identifier Class ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_IDENTIFIER_H
+#define MLIR_IR_IDENTIFIER_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+class MLIRContext;
+
+/// This class represents a uniqued string owned by an MLIRContext.  Strings
+/// represented by this type cannot contain nul characters, and may not have a
+/// zero length.
+///
+/// This is a POD type with pointer size, so it should be passed around by
+/// value.  The underlying data is owned by MLIRContext and is thus immortal for
+/// almost all clients.
+class Identifier {
+public:
+  /// Return an identifier for the specified string.
+  static Identifier get(StringRef str, MLIRContext *context);
+  Identifier(const Identifier &) = default;
+  Identifier &operator=(const Identifier &other) = default;
+
+  /// Return a StringRef for the string.
+  StringRef strref() const { return StringRef(pointer, size()); }
+
+  /// Identifiers implicitly convert to StringRefs.
+  operator StringRef() const { return strref(); }
+
+  /// Return an std::string.
+  std::string str() const { return strref().str(); }
+
+  /// Return a null terminated C string.
+  const char *c_str() const { return pointer; }
+
+  /// Return a pointer to the start of the string data.
+  const char *data() const { return pointer; }
+
+  /// Return the number of bytes in this string.
+  unsigned size() const { return ::strlen(pointer); }
+
+  /// Return true if this identifier is the specified string.
+  bool is(StringRef string) const { return strref().equals(string); }
+
+  const char *begin() const { return pointer; }
+  const char *end() const { return pointer + size(); }
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(pointer);
+  }
+  static Identifier getFromOpaquePointer(const void *pointer) {
+    return Identifier((const char *)pointer);
+  }
+
+private:
+  /// These are the bytes of the string, which is a nul terminated string.
+  const char *pointer;
+  explicit Identifier(const char *pointer) : pointer(pointer) {}
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
+  identifier.print(os);
+  return os;
+}
+
+inline bool operator==(Identifier lhs, Identifier rhs) {
+  return lhs.data() == rhs.data();
+}
+
+inline bool operator!=(Identifier lhs, Identifier rhs) {
+  return lhs.data() != rhs.data();
+}
+
+inline bool operator==(Identifier lhs, StringRef rhs) { return lhs.is(rhs); }
+inline bool operator!=(Identifier lhs, StringRef rhs) { return !lhs.is(rhs); }
+inline bool operator==(StringRef lhs, Identifier rhs) { return rhs.is(lhs); }
+inline bool operator!=(StringRef lhs, Identifier rhs) { return !rhs.is(lhs); }
+
+// Make identifiers hashable.
+inline llvm::hash_code hash_value(Identifier arg) {
+  return llvm::hash_value(arg.strref());
+}
+
+} // end namespace mlir
+
+namespace llvm {
+// Identifiers hash just like pointers, there is no need to hash the bytes.
+template <>
+struct DenseMapInfo<mlir::Identifier> {
+  static mlir::Identifier getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
+    return mlir::Identifier::getFromOpaquePointer(pointer);
+  }
+  static mlir::Identifier getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
+    return mlir::Identifier::getFromOpaquePointer(pointer);
+  }
+  static unsigned getHashValue(mlir::Identifier Val) {
+    return DenseMapInfo<const void *>::getHashValue(Val.data());
+  }
+  static bool isEqual(mlir::Identifier LHS, mlir::Identifier RHS) {
+    return LHS == RHS;
+  }
+};
+
+/// The pointer inside of an identifier comes from a StringMap, so its alignment
+/// is always at least 4 and probably 8 (on 64-bit machines).  Allow LLVM to
+/// steal the low bits.
+template <>
+struct PointerLikeTypeTraits<mlir::Identifier> {
+public:
+  static inline void *getAsVoidPointer(mlir::Identifier I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::Identifier getFromVoidPointer(void *P) {
+    return mlir::Identifier::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 2 };
+};
+
+} // end namespace llvm
+#endif
diff --git a/third_party/mlir/include/mlir/IR/IntegerSet.h b/third_party/mlir/include/mlir/IR/IntegerSet.h
new file mode 100644
index 0000000..b7662f0
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/IntegerSet.h
@@ -0,0 +1,137 @@
+//===- IntegerSet.h - MLIR Integer Set Class --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Integer sets are sets of points from the integer lattice constrained by
+// affine equality/inequality constraints. This class is meant to represent
+// integer sets in the IR - for 'affine.if' operations and as attributes of
+// other operations. It is typically expected to contain only a handful of
+// affine constraints, and is immutable like an affine map. Integer sets are not
+// unique'd - although affine expressions that make up its equalities and
+// inequalites are themselves unique.
+
+// This class is not meant for affine analysis and operations like set
+// operations, emptiness checks, or other math operations for analysis and
+// transformation. For the latter, use FlatAffineConstraints.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_INTEGER_SET_H
+#define MLIR_IR_INTEGER_SET_H
+
+#include "mlir/IR/AffineExpr.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+namespace detail {
+struct IntegerSetStorage;
+}
+
+class MLIRContext;
+
+/// An integer set representing a conjunction of one or more affine equalities
+/// and inequalities. An integer set in the IR is immutable like the affine map,
+/// but integer sets are not unique'd. The affine expressions that make up the
+/// equalities and inequalities of an integer set are themselves unique and are
+/// allocated by the bump pointer allocator.
+class IntegerSet {
+public:
+  using ImplType = detail::IntegerSetStorage;
+
+  IntegerSet() : set(nullptr) {}
+  explicit IntegerSet(ImplType *set) : set(set) {}
+  IntegerSet(const IntegerSet &other) : set(other.set) {}
+  IntegerSet &operator=(const IntegerSet &other) = default;
+
+  static IntegerSet get(unsigned dimCount, unsigned symbolCount,
+                        ArrayRef<AffineExpr> constraints,
+                        ArrayRef<bool> eqFlags);
+
+  // Returns the canonical empty IntegerSet (i.e. a set with no integer points).
+  static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols,
+                                MLIRContext *context) {
+    auto one = getAffineConstantExpr(1, context);
+    /* 1 == 0 */
+    return get(numDims, numSymbols, one, true);
+  }
+
+  /// Returns true if this is the canonical integer set.
+  bool isEmptyIntegerSet() const;
+
+  explicit operator bool() { return set; }
+  bool operator==(IntegerSet other) const { return set == other.set; }
+
+  unsigned getNumDims() const;
+  unsigned getNumSymbols() const;
+  unsigned getNumOperands() const;
+  unsigned getNumConstraints() const;
+  unsigned getNumEqualities() const;
+  unsigned getNumInequalities() const;
+
+  ArrayRef<AffineExpr> getConstraints() const;
+
+  AffineExpr getConstraint(unsigned idx) const;
+
+  /// Returns the equality bits, which specify whether each of the constraints
+  /// is an equality or inequality.
+  ArrayRef<bool> getEqFlags() const;
+
+  /// Returns true if the idx^th constraint is an equality, false if it is an
+  /// inequality.
+  bool isEq(unsigned idx) const;
+
+  MLIRContext *getContext() const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  friend ::llvm::hash_code hash_value(IntegerSet arg);
+
+private:
+  ImplType *set;
+  /// Sets with constraints fewer than kUniquingThreshold are uniqued.
+  constexpr static unsigned kUniquingThreshold = 4;
+};
+
+// Make AffineExpr hashable.
+inline ::llvm::hash_code hash_value(IntegerSet arg) {
+  return ::llvm::hash_value(arg.set);
+}
+
+} // end namespace mlir
+namespace llvm {
+
+// IntegerSet hash just like pointers
+template <> struct DenseMapInfo<mlir::IntegerSet> {
+  static mlir::IntegerSet getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
+  }
+  static mlir::IntegerSet getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::IntegerSet val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::IntegerSet LHS, mlir::IntegerSet RHS) {
+    return LHS == RHS;
+  }
+};
+
+} // namespace llvm
+#endif // MLIR_IR_INTEGER_SET_H
diff --git a/third_party/mlir/include/mlir/IR/Location.h b/third_party/mlir/include/mlir/IR/Location.h
new file mode 100644
index 0000000..32fe0f4
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Location.h
@@ -0,0 +1,270 @@
+//===- Location.h - MLIR Location Classes -----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// These classes provide the ability to relate MLIR objects back to source
+// location position information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOCATION_H
+#define MLIR_IR_LOCATION_H
+
+#include "mlir/IR/Attributes.h"
+
+namespace mlir {
+
+class Attribute;
+class MLIRContext;
+class Identifier;
+
+namespace detail {
+
+struct LocationStorage;
+struct UnknownLocationStorage;
+struct FileLineColLocationStorage;
+struct NameLocationStorage;
+struct CallSiteLocationStorage;
+struct FusedLocationStorage;
+
+} // namespace detail
+
+/// Location objects represent source locations information in MLIR.
+/// LocationAttr acts as the anchor for all Location based attributes.
+class LocationAttr : public Attribute {
+public:
+  using Attribute::Attribute;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(Attribute attr) {
+    return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
+           attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
+  }
+};
+
+/// This class defines the main interface for locations in MLIR and acts as a
+/// non-nullable wrapper around a LocationAttr.
+class Location {
+public:
+  Location(LocationAttr loc) : impl(loc) {
+    assert(loc && "location should never be null.");
+  }
+
+  /// Access the impl location attribute.
+  operator LocationAttr() const { return impl; }
+  LocationAttr *operator->() const { return const_cast<LocationAttr *>(&impl); }
+
+  /// Type casting utilities on the underlying location.
+  template <typename U> bool isa() const { return impl.isa<U>(); }
+  template <typename U> U dyn_cast() const { return impl.dyn_cast<U>(); }
+  template <typename U> U cast() const { return impl.cast<U>(); }
+
+  /// Comparison operators.
+  bool operator==(Location rhs) const { return impl == rhs.impl; }
+  bool operator!=(Location rhs) const { return !(*this == rhs); }
+
+  /// Print the location.
+  void print(raw_ostream &os) const { impl.print(os); }
+  void dump() const { impl.dump(); }
+
+  friend ::llvm::hash_code hash_value(Location arg);
+
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const { return impl.getAsOpaquePointer(); }
+  static Location getFromOpaquePointer(const void *pointer) {
+    return LocationAttr(reinterpret_cast<const AttributeStorage *>(pointer));
+  }
+
+protected:
+  /// The internal backing location attribute.
+  LocationAttr impl;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const Location &loc) {
+  loc.print(os);
+  return os;
+}
+
+/// Represents a location as call site. "callee" is the concrete location
+/// (Unknown/NameLocation/FileLineColLoc) and "caller" points to the caller's
+/// location (another CallLocation or a concrete location). Multiple
+/// CallSiteLocs can be chained to form a call stack.
+class CallSiteLoc
+    : public Attribute::AttrBase<CallSiteLoc, LocationAttr,
+                                 detail::CallSiteLocationStorage> {
+public:
+  using Base::Base;
+
+  /// Return a uniqued call location object.
+  static Location get(Location callee, Location caller, MLIRContext *context);
+
+  /// Return a call site location which represents a name reference in one line
+  /// or a stack of frames. The input frames are ordered from innermost to
+  /// outermost.
+  static Location get(Location name, ArrayRef<Location> frames,
+                      MLIRContext *context);
+
+  /// The concrete location information this object presents.
+  Location getCallee() const;
+
+  /// The caller's location.
+  Location getCaller() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::CallSiteLocation;
+  }
+};
+
+/// Represents a location derived from a file/line/column location.  The column
+/// and line may be zero to represent unknown column and/or unknown line/column
+/// information.
+class FileLineColLoc
+    : public Attribute::AttrBase<FileLineColLoc, LocationAttr,
+                                 detail::FileLineColLocationStorage> {
+public:
+  using Base::Base;
+
+  /// Return a uniqued FileLineCol location object.
+  static Location get(Identifier filename, unsigned line, unsigned column,
+                      MLIRContext *context);
+  static Location get(StringRef filename, unsigned line, unsigned column,
+                      MLIRContext *context);
+
+  StringRef getFilename() const;
+
+  unsigned getLine() const;
+  unsigned getColumn() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::FileLineColLocation;
+  }
+};
+
+/// Represents a value composed of multiple source constructs, with an optional
+/// metadata attribute.
+class FusedLoc : public Attribute::AttrBase<FusedLoc, LocationAttr,
+                                            detail::FusedLocationStorage> {
+public:
+  using Base::Base;
+
+  /// Return a uniqued Fused Location object. The first location in the list
+  /// will get precedence during diagnostic emission, with the rest being
+  /// displayed as supplementary "fused from here" style notes.
+  static Location get(ArrayRef<Location> locs, Attribute metadata,
+                      MLIRContext *context);
+  static Location get(ArrayRef<Location> locs, MLIRContext *context) {
+    return get(locs, Attribute(), context);
+  }
+
+  ArrayRef<Location> getLocations() const;
+
+  /// Returns the optional metadata attached to this fused location. Given that
+  /// it is optional, the return value may be a null node.
+  Attribute getMetadata() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::FusedLocation;
+  }
+};
+
+/// Represents an identity name attached to a child location.
+class NameLoc : public Attribute::AttrBase<NameLoc, LocationAttr,
+                                           detail::NameLocationStorage> {
+public:
+  using Base::Base;
+
+  /// Return a uniqued name location object. The child location must not be
+  /// another NameLoc.
+  static Location get(Identifier name, Location child, MLIRContext *context);
+
+  /// Return a uniqued name location object with an unknown child.
+  static Location get(Identifier name, MLIRContext *context);
+
+  /// Return the name identifier.
+  Identifier getName() const;
+
+  /// Return the child location.
+  Location getChildLoc() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::NameLocation;
+  }
+};
+
+/// Represents an unknown location.  This is always a singleton for a given
+/// MLIRContext.
+class UnknownLoc : public Attribute::AttrBase<UnknownLoc, LocationAttr> {
+public:
+  using Base::Base;
+
+  /// Get an instance of the UnknownLoc.
+  static Location get(MLIRContext *context);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::UnknownLocation;
+  }
+};
+
+// Make Location hashable.
+inline ::llvm::hash_code hash_value(Location arg) {
+  return hash_value(arg.impl);
+}
+
+} // end namespace mlir
+
+namespace llvm {
+
+// Type hash just like pointers.
+template <> struct DenseMapInfo<mlir::Location> {
+  static mlir::Location getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::Location::getFromOpaquePointer(pointer);
+  }
+  static mlir::Location getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::Location::getFromOpaquePointer(pointer);
+  }
+  static unsigned getHashValue(mlir::Location val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::Location LHS, mlir::Location RHS) {
+    return LHS == RHS;
+  }
+};
+
+/// We align LocationStorage by 8, so allow LLVM to steal the low bits.
+template <> struct PointerLikeTypeTraits<mlir::Location> {
+public:
+  static inline void *getAsVoidPointer(mlir::Location I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::Location getFromVoidPointer(void *P) {
+    return mlir::Location::getFromOpaquePointer(P);
+  }
+  enum {
+    NumLowBitsAvailable =
+        PointerLikeTypeTraits<mlir::Attribute>::NumLowBitsAvailable
+  };
+};
+
+} // namespace llvm
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/MLIRContext.h b/third_party/mlir/include/mlir/IR/MLIRContext.h
new file mode 100644
index 0000000..a93cb8b
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/MLIRContext.h
@@ -0,0 +1,92 @@
+//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_MLIRCONTEXT_H
+#define MLIR_IR_MLIRCONTEXT_H
+
+#include "mlir/Support/LLVM.h"
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace mlir {
+class AbstractOperation;
+class DiagnosticEngine;
+class Dialect;
+class InFlightDiagnostic;
+class Location;
+class MLIRContextImpl;
+class StorageUniquer;
+
+/// MLIRContext is the top-level object for a collection of MLIR modules.  It
+/// holds immortal uniqued objects like types, and the tables used to unique
+/// them.
+///
+/// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
+/// a very generic name ("Context") and because it is uncommon for clients to
+/// interact with it.
+///
+class MLIRContext {
+public:
+  explicit MLIRContext();
+  ~MLIRContext();
+
+  /// Return information about all registered IR dialects.
+  std::vector<Dialect *> getRegisteredDialects();
+
+  /// Get a registered IR dialect with the given namespace. If an exact match is
+  /// not found, then return nullptr.
+  Dialect *getRegisteredDialect(StringRef name);
+
+  /// Get a registered IR dialect for the given derived dialect type. The
+  /// derived type must provide a static 'getDialectNamespace' method.
+  template <typename T> T *getRegisteredDialect() {
+    return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
+  }
+
+  /// Return information about all registered operations.  This isn't very
+  /// efficient: typically you should ask the operations about their properties
+  /// directly.
+  std::vector<AbstractOperation *> getRegisteredOperations();
+
+  // This is effectively private given that only MLIRContext.cpp can see the
+  // MLIRContextImpl type.
+  MLIRContextImpl &getImpl() { return *impl; }
+
+  /// Returns the diagnostic engine for this context.
+  DiagnosticEngine &getDiagEngine();
+
+  /// Returns the storage uniquer used for creating affine constructs.
+  StorageUniquer &getAffineUniquer();
+
+  /// Returns the storage uniquer used for constructing type storage instances.
+  /// This should not be used directly.
+  StorageUniquer &getTypeUniquer();
+
+  /// Returns the storage uniquer used for constructing attribute storage
+  /// instances. This should not be used directly.
+  StorageUniquer &getAttributeUniquer();
+
+private:
+  const std::unique_ptr<MLIRContextImpl> impl;
+
+  MLIRContext(const MLIRContext &) = delete;
+  void operator=(const MLIRContext &) = delete;
+};
+} // end namespace mlir
+
+#endif // MLIR_IR_MLIRCONTEXT_H
diff --git a/third_party/mlir/include/mlir/IR/Matchers.h b/third_party/mlir/include/mlir/IR/Matchers.h
new file mode 100644
index 0000000..4ea1ce2
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Matchers.h
@@ -0,0 +1,177 @@
+//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
+// include/llvm/IR/PatternMatch.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_MATCHERS_H
+#define MLIR_MATCHERS_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include <type_traits>
+
+namespace mlir {
+
+namespace detail {
+
+/// The matcher that matches a certain kind of Attribute and binds the value
+/// inside the Attribute.
+template <
+    typename AttrClass,
+    // Require AttrClass to be a derived class from Atribute and get its
+    // value type
+    typename ValueType =
+        typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
+                                AttrClass>::type::ValueType,
+    // Require the ValueType is not void
+    typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
+struct attr_value_binder {
+  ValueType *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  attr_value_binder(ValueType *bv) : bind_value(bv) {}
+
+  bool match(const Attribute &attr) {
+    if (auto intAttr = attr.dyn_cast<AttrClass>()) {
+      *bind_value = intAttr.getValue();
+      return true;
+    }
+    return false;
+  }
+};
+
+/// The matcher that matches a constant foldable operation that has no side
+/// effect, no operands and produces a single result.
+template <typename AttrT> struct constant_op_binder {
+  AttrT *bind_value;
+
+  /// Creates a matcher instance that binds the constant attribute value to
+  /// bind_value if match succeeds.
+  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    if (op->getNumOperands() > 0 || op->getNumResults() != 1)
+      return false;
+    if (!op->hasNoSideEffect())
+      return false;
+
+    SmallVector<OpFoldResult, 1> foldedOp;
+    if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
+      if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
+        if ((*bind_value = attr.dyn_cast<AttrT>()))
+          return true;
+      }
+    }
+    return false;
+  }
+};
+
+/// The matcher that matches a constant scalar / vector splat / tensor splat
+/// integer operation and binds the constant integer value.
+struct constant_int_op_binder {
+  IntegerAttr::ValueType *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    Attribute attr;
+    if (!constant_op_binder<Attribute>(&attr).match(op))
+      return false;
+    auto type = op->getResult(0)->getType();
+
+    if (type.isa<IntegerType>()) {
+      return attr_value_binder<IntegerAttr>(bind_value).match(attr);
+    }
+    if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
+      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+        return attr_value_binder<IntegerAttr>(bind_value)
+            .match(splatAttr.getSplatValue());
+      }
+    }
+    return false;
+  }
+};
+
+// The matcher that matches a given target constant scalar / vector splat /
+// tensor splat integer value.
+template <int64_t TargetValue> struct constant_int_value_matcher {
+  bool match(Operation *op) {
+    APInt value;
+
+    return constant_int_op_binder(&value).match(op) && TargetValue == value;
+  }
+};
+
+/// The matcher that matches a certain kind of op.
+template <typename OpClass> struct op_matcher {
+  bool match(Operation *op) { return isa<OpClass>(op); }
+};
+
+} // end namespace detail
+
+/// Entry point for matching a pattern over a Value.
+template <typename Pattern>
+inline bool matchPattern(Value *value, const Pattern &pattern) {
+  // TODO: handle other cases
+  if (auto *op = value->getDefiningOp())
+    return const_cast<Pattern &>(pattern).match(op);
+  return false;
+}
+
+/// Entry point for matching a pattern over an Operation.
+template <typename Pattern>
+inline bool matchPattern(Operation *op, const Pattern &pattern) {
+  return const_cast<Pattern &>(pattern).match(op);
+}
+
+/// Matches a constant holding a scalar/vector/tensor integer (splat) and
+/// writes the integer value to bind_value.
+inline detail::constant_int_op_binder
+m_ConstantInt(IntegerAttr::ValueType *bind_value) {
+  return detail::constant_int_op_binder(bind_value);
+}
+
+/// Matches a value from a constant foldable operation and writes the value to
+/// bind_value.
+template <typename AttrT>
+inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
+  return detail::constant_op_binder<AttrT>(bind_value);
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer one.
+inline detail::constant_int_value_matcher<1> m_One() {
+  return detail::constant_int_value_matcher<1>();
+}
+
+/// Matches the given OpClass.
+template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() {
+  return detail::op_matcher<OpClass>();
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer zero.
+inline detail::constant_int_value_matcher<0> m_Zero() {
+  return detail::constant_int_value_matcher<0>();
+}
+
+} // end namespace mlir
+
+#endif // MLIR_MATCHERS_H
diff --git a/third_party/mlir/include/mlir/IR/Module.h b/third_party/mlir/include/mlir/IR/Module.h
new file mode 100644
index 0000000..147337f
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Module.h
@@ -0,0 +1,216 @@
+//===- Module.h - MLIR Module Class -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Module is the top-level container for code in an MLIR program.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_MODULE_H
+#define MLIR_IR_MODULE_H
+
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+class ModuleTerminatorOp;
+
+//===----------------------------------------------------------------------===//
+// Module Operation.
+//===----------------------------------------------------------------------===//
+
+/// ModuleOp represents a module, or an operation containing one region with a
+/// single block containing opaque operations. The region of a module is not
+/// allowed to implicitly capture global values, and all external references
+/// must use symbolic references via attributes(e.g. via a string name).
+class ModuleOp
+    : public Op<
+          ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+          OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable,
+          OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl> {
+public:
+  using Op::Op;
+  using Op::print;
+
+  static StringRef getOperationName() { return "module"; }
+
+  static void build(Builder *builder, OperationState *result);
+
+  /// Construct a module from the given location.
+  static ModuleOp create(Location loc);
+
+  /// Operation hooks.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+
+  /// Return body of this module.
+  Region &getBodyRegion();
+  Block *getBody();
+
+  /// Print the this module in the custom top-level form.
+  void print(raw_ostream &os);
+  void dump();
+
+  //===--------------------------------------------------------------------===//
+  // Body Management.
+  //===--------------------------------------------------------------------===//
+
+  /// Iteration over the operations in the module.
+  using iterator = Block::iterator;
+
+  iterator begin() { return getBody()->begin(); }
+  iterator end() { return getBody()->end(); }
+  Operation &front() { return *begin(); }
+
+  /// This returns a range of operations of the given type 'T' held within the
+  /// module.
+  template <typename T> llvm::iterator_range<Block::op_iterator<T>> getOps() {
+    return getBody()->getOps<T>();
+  }
+
+  /// Insert the operation into the back of the body, before the terminator.
+  void push_back(Operation *op) {
+    insert(Block::iterator(getBody()->getTerminator()), op);
+  }
+
+  /// Insert the operation at the given insertion point. Note: The operation is
+  /// never inserted after the terminator, even if the insertion point is end().
+  void insert(Operation *insertPt, Operation *op) {
+    insert(Block::iterator(insertPt), op);
+  }
+  void insert(Block::iterator insertPt, Operation *op) {
+    auto *body = getBody();
+    if (insertPt == body->end())
+      insertPt = Block::iterator(body->getTerminator());
+    body->getOperations().insert(insertPt, op);
+  }
+};
+
+/// The ModuleTerminatorOp is a special terminator operation for the body of a
+/// ModuleOp, it has no semantic meaning beyond keeping the body of a ModuleOp
+/// well-formed.
+///
+/// This operation does _not_ have a custom syntax. However, ModuleOp will omit
+/// the terminator in their custom syntax for brevity.
+class ModuleTerminatorOp
+    : public Op<ModuleTerminatorOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+                OpTrait::HasParent<ModuleOp>::Impl, OpTrait::IsTerminator> {
+public:
+  using Op::Op;
+  static StringRef getOperationName() { return "module_terminator"; }
+  static void build(Builder *, OperationState *) {}
+};
+
+//===----------------------------------------------------------------------===//
+// Module Manager.
+//===----------------------------------------------------------------------===//
+
+/// A class used to manage the symbols held by a module. This class handles
+/// ensures that symbols inserted into a module have a unique name, and provides
+/// efficent named lookup to held symbols.
+class ModuleManager {
+public:
+  ModuleManager(ModuleOp module) : module(module), symbolTable(module) {}
+
+  /// Look up a symbol with the specified name, returning null if no such
+  /// name exists. Names must never include the @ on them.
+  template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
+    return symbolTable.lookup<T>(name);
+  }
+
+  /// Insert a new symbol into the module, auto-renaming it as necessary.
+  void insert(Operation *op) {
+    symbolTable.insert(op);
+    module.push_back(op);
+  }
+  void insert(Block::iterator insertPt, Operation *op) {
+    symbolTable.insert(op);
+    module.insert(insertPt, op);
+  }
+
+  /// Remove the given symbol from the module symbol table and then erase it.
+  void erase(Operation *op) {
+    symbolTable.erase(op);
+    op->erase();
+  }
+
+  /// Return the internally held module.
+  ModuleOp getModule() const { return module; }
+
+  /// Return the context of the internal module.
+  MLIRContext *getContext() { return module.getContext(); }
+
+private:
+  ModuleOp module;
+  SymbolTable symbolTable;
+};
+
+/// This class acts as an owning reference to a module, and will automatically
+/// destroy the held module if valid.
+class OwningModuleRef {
+public:
+  OwningModuleRef(std::nullptr_t = nullptr) {}
+  OwningModuleRef(ModuleOp module) : module(module) {}
+  OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
+  ~OwningModuleRef() {
+    if (module)
+      module.erase();
+  }
+
+  // Assign from another module reference.
+  OwningModuleRef &operator=(OwningModuleRef &&other) {
+    if (module)
+      module.erase();
+    module = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal module.
+  ModuleOp get() const { return module; }
+  ModuleOp operator*() const { return module; }
+  ModuleOp *operator->() { return &module; }
+  explicit operator bool() const { return module; }
+
+  /// Release the referenced module.
+  ModuleOp release() {
+    ModuleOp released;
+    std::swap(released, module);
+    return released;
+  }
+
+private:
+  ModuleOp module;
+};
+
+} // end namespace mlir
+
+namespace llvm {
+
+/// Allow stealing the low bits of ModuleOp.
+template <> struct PointerLikeTypeTraits<mlir::ModuleOp> {
+public:
+  static inline void *getAsVoidPointer(mlir::ModuleOp I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::ModuleOp getFromVoidPointer(void *P) {
+    return mlir::ModuleOp::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // end namespace llvm
+
+#endif // MLIR_IR_MODULE_H
diff --git a/third_party/mlir/include/mlir/IR/OpBase.td b/third_party/mlir/include/mlir/IR/OpBase.td
new file mode 100644
index 0000000..3cf3efc
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/OpBase.td
@@ -0,0 +1,1437 @@
+//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the base operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef OP_BASE
+#else
+#define OP_BASE
+
+//===----------------------------------------------------------------------===//
+// Common utilities for defining TableGen mechanisms
+//===----------------------------------------------------------------------===//
+
+// Concatenates a list of strings with a separator (default ", ")
+class StrJoin<list<string> strings, string sep = ", "> {
+  string result =
+      !if(!empty(strings), "",
+          !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
+}
+
+// Concatenates a list of integers into a string with a separator (default ", ")
+class StrJoinInt<list<int> integers, string sep = ", "> :
+    StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
+
+//===----------------------------------------------------------------------===//
+// Predicate definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for logical predicates.
+//
+// Predicates are used to compose constraints (see next section for details).
+// There are two categories of predicates:
+//
+// 1. CPred: the primitive leaf predicate.
+// 2. Compound predicate: a predicate composed from child predicates using
+//    predicate combiners ("conjunction", "disjunction", "negation" or
+//    "substitution").
+class Pred;
+
+// A logical predicate wrapping any C expression.
+//
+// This is the basis for composing more complex predicates. It is the "atom"
+// predicate from the perspective of TableGen and the "interface" between
+// TableGen and C++. What is inside is already C++ code, which will be treated
+// as opaque strings with special placeholders to be substituted.
+//
+// ## Special placeholders
+//
+// Special placeholders can be used to refer to entities in the context where
+// this predicate is used. They serve as "hooks" to the enclosing environment.
+// The following special placeholders are supported in constraints for an op:
+//
+// * `$_builder` will be replaced by a mlir::Builder instance.
+// * `$_op` will be replaced by the current operation.
+// * `$_self` will be replaced with the entity this predicate is attached to.
+//   E.g., `BoolAttr` is an attribute constraint that wraps a
+//   `CPred<"$_self.isa<BoolAttr>()">` (see the following sections for details).
+//   Then for `F32:$attr`,`$_self` will be replaced by `$attr`.
+//   For type constraints, it's a little bit special since we want the
+//   constraints on each type definition reads naturally and we want to attach
+//   type constraints directly to an operand/result, $_self will be replaced
+//   by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
+//   `$_self` will be expanded as `getOperand(...)->getType()`.
+class CPred<code pred> : Pred {
+  code predExpr = "(" # pred # ")";
+}
+
+// Kinds of predicate combiners.  These must closesly match the predicates
+// implemented by the C++ backend (tblgen::PredCombinerKind).
+class PredCombinerKind;
+def PredCombinerAnd : PredCombinerKind;
+def PredCombinerOr : PredCombinerKind;
+def PredCombinerNot : PredCombinerKind;
+def PredCombinerSubstLeaves : PredCombinerKind;
+def PredCombinerConcat : PredCombinerKind;
+
+// A predicate that combines other predicates as defined by PredCombinerKind.
+// Instantiated below.
+class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
+  PredCombinerKind kind = k;
+  list<Pred> children = c;
+}
+
+// Predicate combiners
+
+// A predicate that holds if all of its children hold.  Always holds for zero
+// children.
+class And<list<Pred> children> : CombinedPred<PredCombinerAnd, children>;
+
+// A predicate that holds if any of its children hold.  Never holds for zero
+// children.
+class Or<list<Pred> children> : CombinedPred<PredCombinerOr, children>;
+
+// A predicate that holds if its child does not.
+class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>;
+
+// A predicate that substitutes "pat" with "repl" in predicate calls of the
+// leaves of the predicate tree (i.e., not CombinedPred).
+//
+// This is plain string substitution without regular expressions or captures.
+// New predicates with more complex logical can be introduced should the need
+// arise.
+class SubstLeaves<string pat, string repl, Pred child>
+    : CombinedPred<PredCombinerSubstLeaves, [child]> {
+  string pattern = pat;
+  string replacement = repl;
+}
+
+// A predicate that prepends `pre` and appends `suf` to the final predicate
+// string composed from `child`. This is plain string concatenation and there
+// will be no substitution happening for `pre` and `suf`.
+class Concat<string pre, Pred child, string suf> :
+    CombinedPred<PredCombinerConcat, [child]> {
+  string prefix = pre;
+  string suffix = suf;
+}
+
+//===----------------------------------------------------------------------===//
+// Constraint definitions
+//===----------------------------------------------------------------------===//
+
+// TODO(b/130064155): Merge Constraints into Pred.
+
+// Base class for named constraints.
+//
+// An op's operands/attributes/results can have various requirements, e.g.,
+// having certain types, having values inside a certain range, and so on.
+// Besides, for a graph rewrite rule, the source pattern used to match against
+// the existing graph has conditions, like the op's operand must be of a more
+// constrained subtype, the attribute must have a certain value, and so on.
+//
+// These requirements and conditions are modeled using this class. Records of
+// this class are used to generate verification code in op verifier, and
+// matching code in pattern matcher.
+//
+// Constraints are predicates with descriptive names, to facilitate inspection,
+// provide nice error messages, etc.
+class Constraint<Pred pred, string desc = ""> {
+  // The predicates that this constraint requires.
+  Pred predicate = pred;
+  // User-readable description used in error reporting messages. If empty, a
+  // generic message will be used.
+  string description = desc;
+}
+
+// Subclasses used to differentiate different constraint kinds. These are used
+// as markers for the TableGen backend to handle different constraint kinds
+// differently if needed. Constraints not deriving from the following subclasses
+// are considered as uncategorized constraints.
+
+// Subclass for constraints on a type.
+class TypeConstraint<Pred predicate, string description = ""> :
+    Constraint<predicate, description>;
+
+// Subclass for constraints on an attribute.
+class AttrConstraint<Pred predicate, string description = ""> :
+    Constraint<predicate, description>;
+
+// Subclass for constraints on a region.
+class RegionConstraint<Pred predicate, string description = ""> :
+    Constraint<predicate, description>;
+
+// How to use these constraint categories:
+//
+// * Use TypeConstraint to specify
+//   * Constraints on an op's operand/result definition
+//   * Further constraints to match an op's operand/result in source pattern
+//
+// * Use Attr (a subclass for AttrConstraint) for
+//   * Constraints on an op's attribute definition
+// * Use AttrConstraint to specify
+//   * Further constraints to match an op's attribute in source pattern
+//
+// * Use uncategorized constraint to specify
+//   * Multi-entity constraints in rewrite rules
+
+//===----------------------------------------------------------------------===//
+// Common predicates
+//===----------------------------------------------------------------------===//
+
+// Whether a type is a VectorType.
+def IsVectorTypePred : CPred<"$_self.isa<VectorType>()">;
+
+// Whether a type is a TensorType.
+def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
+
+// Whether a type is a MemRefType.
+def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
+
+// Whether a type is a ShapedType.
+def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
+
+// For a ShapedType, verify that it has a static shape.
+def HasStaticShapePred : CPred<"$_self.cast<ShapedType>().hasStaticShape()">;
+
+// Whether a type is a TupleType.
+def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
+
+//===----------------------------------------------------------------------===//
+// Dialect definitions
+//===----------------------------------------------------------------------===//
+
+class Dialect {
+  // The name of the dialect.
+  string name = ?;
+
+  // Short summary of the dialect.
+  string summary = ?;
+
+  // The description of the dialect.
+  string description = ?;
+
+  // The C++ namespace that ops of this dialect should be placed into.
+  //
+  // By default, uses the name of the dialect as the only namespace. To avoid
+  // placing in any namespace, use "". To specify nested namespaces, use "::"
+  // as the delimiter, e.g., given "A::B", ops will be placed in
+  // `namespace A { namespace B { <ops> } }`.
+  //
+  // Note that this works in conjunction with dialect C++ code. Depending on how
+  // the generated files are included into the dialect, you may want to specify
+  // a full namespace path or a partial one.
+  string cppNamespace = name;
+}
+
+//===----------------------------------------------------------------------===//
+// Type definitions
+//===----------------------------------------------------------------------===//
+
+// A type, carries type constraints.
+class Type<Pred condition, string descr = ""> :
+    TypeConstraint<condition, descr>;
+
+// Allows providing an alternative name and description to an existing type def.
+class TypeAlias<Type t, string description = t.description> :
+    Type<t.predicate, description>;
+
+// A variadic type constraint. It expands to zero or more of the base type. This
+// class is used for supporting variadic operands/results. An op can declare no
+// more than one variadic operand/result, and that operand/result must be the
+// last one in the operand/result list.
+class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
+  Type baseType = type;
+}
+
+// A type that can be constructed using MLIR::Builder.
+// Note that this does not "inherit" from Type because it would require
+// duplicating Type subclasses for buildable and non-buildable cases to avoid
+// diamond "inheritance".
+// TODO(zinenko): we may extend this to a more general 'Buildable' trait,
+// making some Types and some Attrs buildable.
+class BuildableType<code builder> {
+  // The builder call to invoke (if specified) to construct the BuildableType.
+  // Format: this will be affixed to the builder.
+  code builderCall = builder;
+}
+
+// Any type at all.
+def AnyType : Type<CPred<"true">, "any type">;
+
+// None type
+def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
+
+// Any type from the given list
+class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
+    // Satisfy any of the allowed type's condition
+    Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
+    !if(!eq(description, ""),
+        StrJoin<!foreach(t, allowedTypes, t.description), " or ">.result,
+        description)>;
+
+// Integer types.
+// Any integer type irrespective of its width.
+def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
+
+// Index type.
+def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">;
+
+// Integer type of a specific width.
+class I<int width>
+    : Type<CPred<"$_self.isInteger(" # width # ")">,
+                  width # "-bit integer">,
+      BuildableType<"getIntegerType(" # width # ")"> {
+  int bitwidth = width;
+}
+
+class IntOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, I<w>),
+              StrJoinInt<widths, "/">.result # "-bit integer">;
+
+def I1  : I<1>;
+def I8  : I<8>;
+def I16 : I<16>;
+def I32 : I<32>;
+def I64 : I<64>;
+
+// Floating point types.
+
+// Any float type irrespective of its width.
+def AnyFloat : Type<CPred<"$_self.isa<FloatType>()">, "floating-point">;
+
+// Float type of a specific width.
+class F<int width>
+    : Type<CPred<"$_self.isF" # width # "()">,
+                width # "-bit float">,
+      BuildableType<"getF" # width # "Type()"> {
+  int bitwidth = width;
+}
+
+class FloatOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, F<w>),
+              StrJoinInt<widths, "/">.result # "-bit float">;
+
+def F16 : F<16>;
+def F32 : F<32>;
+def F64 : F<64>;
+
+def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
+           BuildableType<"getBF16Type()">;
+
+class OpaqueType<string dialect, string name, string description>
+  : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
+         description>;
+
+// Function Type
+
+// Any function type.
+def FunctionType : Type<CPred<"$_self.isa<FunctionType>()">, "function type">;
+
+// A container type is a type that has another type embedded within it.
+class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
+                    string descr> :
+    // First, check the container predicate.  Then, substitute the extracted
+    // element into the element type checker.
+    Type<And<[containerPred,
+                SubstLeaves<"$_self", !cast<string>(elementTypeCall),
+                etype.predicate>]>,
+         descr # " of " # etype.description # " values"> {
+  // The type of elements in the container.
+  Type elementType = etype;
+
+  // Call to retrieve.
+  code getElementTypeCall = elementTypeCall;
+}
+
+class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
+    ContainerType<AnyTypeOf<allowedTypes>, containerPred,
+                  "$_self.cast<ShapedType>().getElementType()", descr>;
+
+// Vector types.
+
+class VectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
+
+def AnyVector : VectorOf<[AnyType]>;
+
+// Tensor types.
+
+// Any tensor type whose element type is from the given `allowedTypes` list
+class TensorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
+
+def AnyTensor : TensorOf<[AnyType]>;
+
+// TODO(b/130064155) Have an easy way to add another constraint to a type.
+class StaticShapeTensorOf<list<Type> allowedTypes>
+    : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
+           "statically shaped " # TensorOf<allowedTypes>.description>;
+
+def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
+
+def I1Tensor   : TensorOf<[I1]>;
+def I8Tensor   : TensorOf<[I8]>;
+def I16Tensor  : TensorOf<[I16]>;
+def I32Tensor  : TensorOf<[I32]>;
+def I64Tensor  : TensorOf<[I64]>;
+
+def BF16Tensor : TensorOf<[BF16]>;
+def F16Tensor  : TensorOf<[F16]>;
+def F32Tensor  : TensorOf<[F32]>;
+def F64Tensor  : TensorOf<[F64]>;
+
+// Memref type.
+
+// Memrefs are blocks of data with fixed type and rank.
+class MemRefOf<list<Type> allowedTypes> :
+    ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">;
+
+def AnyMemRef : MemRefOf<[AnyType]>;
+
+// Memref declarations handle any memref, independent of rank, size, (static or
+// dynamic), layout, or memory space.
+def I1MemRef  : MemRefOf<[I1]>;
+def I8MemRef  : MemRefOf<[I8]>;
+def I16MemRef : MemRefOf<[I16]>;
+def I32MemRef : MemRefOf<[I32]>;
+def I64MemRef : MemRefOf<[I64]>;
+
+def BF16MemRef : MemRefOf<[BF16]>;
+def F16MemRef  : MemRefOf<[F16]>;
+def F32MemRef  : MemRefOf<[F32]>;
+def F64MemRef  : MemRefOf<[F64]>;
+
+// This represents a generic tuple without any constraints on element type.
+def AnyTuple : Type<IsTupleTypePred, "tuple">;
+
+// A container type that has other types embedded in it, but (unlike
+// ContainerType) can hold elements with a mix of types. Requires a call that
+// produces a list of all elements' types.
+class MixedContainerType<Type etype, Pred containerPred, code elementTypesCall,
+                         string descr> :
+    Type<
+        And<[
+            containerPred,
+            Concat<
+                "llvm::all_of(" # elementTypesCall # ", [](Type t) { return ",
+                SubstLeaves<"$_self", "t", etype.predicate>,
+                "; })"
+            >
+        ]>,
+        descr # " with any combination of " # etype.description # " values"> {
+  // The type of elements in the container.
+  Type elementType = etype;
+
+  // Call to retrieve.
+  code getElementTypesCall = elementTypesCall;
+}
+
+// A Tuple that holds a mix of elements of the allowed types.
+class TupleOf<list<Type> allowedTypes>
+    : MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
+                         "$_self.cast<TupleType>().getTypes()", "tuple">;
+
+// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
+// types.
+class NestedTupleOf<list<Type> allowedTypes> :
+    MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
+                       "getFlattenedTypes($_self.cast<TupleType>())",
+                       "nested tuple">;
+
+//===----------------------------------------------------------------------===//
+// Common type constraints
+//===----------------------------------------------------------------------===//
+
+// Type constraint for bool-like types: bools, vectors of bools, tensors of
+// bools.
+def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
+                                  TensorOf<[I1]>.predicate]>,
+    "bool-like">;
+
+// Type constraint for integer-like types: integers, indices, vectors of
+// integers, tensors of integers.
+def IntegerLike : TypeConstraint<Or<[AnyInteger.predicate, Index.predicate,
+        VectorOf<[AnyInteger]>.predicate, TensorOf<[AnyInteger]>.predicate]>,
+    "integer-like">;
+
+// Type constraint for float-like types: floats, vectors or tensors thereof.
+def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
+        VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
+    "floating-point-like">;
+
+
+//===----------------------------------------------------------------------===//
+// Attribute definitions
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Base attribute definition
+
+// Base class for all attributes.
+class Attr<Pred condition, string descr = ""> :
+    AttrConstraint<condition, descr> {
+  code storageType = ?; // The backing mlir::Attribute type
+  code returnType = ?;  // The underlying C++ value type
+
+  // The call expression to convert from the storage type to the return
+  // type. For example, an enum can be stored as an int but returned as an
+  // enum class.
+  //
+  // Format: $_self will be expanded to the attribute.
+  //
+  // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
+  // expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
+  code convertFromStorage = "$_self.getValue()";
+
+  // The call expression to build an attribute from a constant value.
+  //
+  // Format: $0 will be expanded to the constant value of the attribute.
+  //
+  // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will
+  // expand to `builder.getStringAttr("foo")`.
+  string constBuilderCall = ?;
+
+  // Default value for attribute.
+  // Requires a constBuilderCall defined.
+  string defaultValue = ?;
+
+  // Whether the attribute is optional. Typically requires a custom
+  // convertFromStorage method to handle the case where the attribute is
+  // not present.
+  bit isOptional = 0;
+
+  // What is the base-level Attr instantiation that this Attr is built upon.
+  // Unset means this is a base-level Attr.
+  //
+  // This field is used by attribute wrapper classes (DefaultValuedAttr,
+  // OptionalAttr, etc.) to retrive the base-level attribute definition.
+  // This can be used for getting its name; otherwise, we will see
+  // "anonymous_<number>" as the attribute def name because of template
+  // instantiation.
+  // TOOD(b/132458159): deduplicate the fields in attribute wrapper classes.
+  Attr baseAttr = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute modifier definition
+
+// Decorates an attribute to have an (unvalidated) default value if not present.
+class DefaultValuedAttr<Attr attr, string val> :
+    Attr<attr.predicate, attr.description> {
+  // Construct this attribute with the input attribute and change only
+  // the default value.
+  // Note: this has to be kept up to date with Attr above.
+  let storageType = attr.storageType;
+  let returnType = attr.returnType;
+  let convertFromStorage = attr.convertFromStorage;
+  let constBuilderCall = attr.constBuilderCall;
+  let defaultValue = val;
+
+  let baseAttr = attr;
+}
+
+// Decorates an attribute as optional. The return type of the generated
+// attribute accessor method will be Optional<>.
+class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
+  // Rewrite the attribute to be optional.
+  // Note: this has to be kept up to date with Attr above.
+  let storageType = attr.storageType;
+  let returnType = "Optional<" # attr.returnType #">";
+  let convertFromStorage = "$_self ? " # returnType # "(" #
+                           attr.convertFromStorage # ") : (llvm::None)";
+  let isOptional = 1;
+
+  let baseAttr = attr;
+}
+
+//===----------------------------------------------------------------------===//
+// Primitive attribute kinds
+
+// A generic attribute that must be constructed around a specific type
+// `attrValType`. Backed by MLIR attribute kind `attrKind`.
+class TypedAttrBase<BuildableType attrValType, string attrKind,
+                    Pred condition, string descr> :
+    Attr<condition, descr> {
+  let constBuilderCall = "$_builder.get" # attrKind # "($_builder." #
+                         attrValType.builderCall # ", $0)";
+  let storageType = attrKind;
+}
+
+// Any attribute.
+def AnyAttr : Attr<CPred<"true">, "any attribute"> {
+  let storageType = "Attribute";
+  let returnType = "Attribute";
+  let convertFromStorage = "$_self";
+  let constBuilderCall = "$0";
+}
+
+def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
+  let storageType = [{ BoolAttr }];
+  let returnType = [{ bool }];
+  let constBuilderCall = "$_builder.getBoolAttr($0)";
+}
+
+// Base class for integer attributes of fixed width.
+class IntegerAttrBase<I attrValType, string descr> :
+    TypedAttrBase<
+      attrValType, "IntegerAttr",
+      And<[CPred<"$_self.isa<IntegerAttr>()">,
+           CPred<"$_self.cast<IntegerAttr>().getType()."
+                 "isInteger(" # attrValType.bitwidth # ")">]>,
+      descr> {
+  let returnType = [{ APInt }];
+}
+
+def APIntAttr : Attr<CPred<"$_self.isa<IntegerAttr>()">,
+                     "arbitrary integer attribute"> {
+  let storageType = [{ IntegerAttr }];
+  let returnType = [{ APInt }];
+}
+
+def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">;
+def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;
+
+class NonNegativeIntAttrBase<I attrValType, string descr> :
+    TypedAttrBase<
+      attrValType, "IntegerAttr",
+      And<[IntegerAttrBase<attrValType, "">.predicate,
+           CPred<"!$_self.cast<IntegerAttr>().getValue().isNegative()">]>,
+      descr> {
+  let returnType = [{ APInt }];
+}
+
+def NonNegativeI32Attr : NonNegativeIntAttrBase<
+    I32, "non-negative 32-bit integer attribute">;
+def NonNegativeI64Attr : NonNegativeIntAttrBase<
+    I64, "non-negative 64-bit integer attribute">;
+
+// Base class for float attributes of fixed width.
+class FloatAttrBase<F attrValType, string descr> :
+    TypedAttrBase<attrValType, "FloatAttr",
+              And<[CPred<"$_self.isa<FloatAttr>()">,
+                     CPred<"$_self.cast<FloatAttr>().getType().isF" #
+                           attrValType.bitwidth # "()">]>,
+              descr> {
+  let returnType = [{ APFloat }];
+}
+
+def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
+def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
+
+// An attribute backed by a string type.
+class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
+  let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
+  let storageType = [{ StringAttr }];
+  let returnType = [{ StringRef }];
+}
+
+def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
+                              "string attribute">;
+
+// Base class for attributes containing types. Example:
+//   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
+// defines a type attribute containing an integer type.
+class TypeAttrBase<string retType, string description> :
+    Attr<And<[
+      CPred<"$_self.isa<TypeAttr>()">,
+      CPred<"$_self.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>,
+    description> {
+  let storageType = [{ TypeAttr }];
+  let returnType = retType;
+  let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
+}
+
+def TypeAttr : TypeAttrBase<"Type", "any type attribute">;
+
+// The mere presence of unit attributes has a meaning.  Therefore, unit
+// attributes are always treated as optional and accessors to them return
+// "true" if the attribute is present and "false" otherwise.
+def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
+  let storageType = [{ UnitAttr }];
+  let constBuilderCall = "$_builder.getUnitAttr()";
+  let convertFromStorage = "$_self != nullptr";
+  let returnType = "bool";
+  let isOptional = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Enum attribute kinds
+
+// Additional information for an enum attribute case.
+class EnumAttrCaseInfo<string sym, int val> {
+  // The C++ enumerant symbol
+  string symbol = sym;
+
+  // The C++ enumerant value
+  // If less than zero, there will be no explicit discriminator values assigned
+  // to enumerators in the generated enum class.
+  int value = val;
+}
+
+// An enum attribute case stored with StringAttr.
+class StrEnumAttrCase<string sym, int val = -1> :
+    EnumAttrCaseInfo<sym, val>,
+    StringBasedAttr<
+      CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
+      "case " # sym>;
+
+// An enum attribute case stored with IntegerAttr.
+class IntEnumAttrCaseBase<I intType, string sym, int val> :
+    EnumAttrCaseInfo<sym, val>,
+    IntegerAttrBase<intType, "case " # sym> {
+  let predicate =
+    CPred<"$_self.cast<IntegerAttr>().getInt() == " # val>;
+}
+
+class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
+class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
+
+// Additional information for an enum attribute.
+class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
+  // The C++ enum class name
+  string className = name;
+
+  // List of all accepted cases
+  list<EnumAttrCaseInfo> enumerants = cases;
+
+  // The following fields are only used by the EnumsGen backend to generate
+  // an enum class definition and conversion utility functions.
+
+  // The underlying type for the C++ enum class. An empty string mean the
+  // underlying type is not explicitly specified.
+  string underlyingType = "";
+
+  // The C++ namespaces that the enum class definition and utility functions
+  // should be placed into.
+  //
+  // Normally you want to place the full namespace path here. If it is nested,
+  // use "::" as the delimiter, e.g., given "A::B", generated code will be
+  // placed in `namespace A { namespace B { ... } }`. To avoid placing in any
+  // namespace, use "".
+  // TODO(b/134741431): use dialect to provide the namespace.
+  string cppNamespace = "";
+
+  // The name of the utility function that converts a value of the underlying
+  // type to the corresponding symbol. It will have the following signature:
+  //
+  // ```c++
+  // llvm::Optional<<qualified-enum-class-name>> <fn-name>(<underlying-type>);
+  // ```
+  string underlyingToSymbolFnName = "symbolize" # name;
+
+  // The name of the utility function that converts a string to the
+  // corresponding symbol. It will have the following signature:
+  //
+  // ```c++
+  // llvm::Optional<<qualified-enum-class-name>> <fn-name>(llvm::StringRef);
+  // ```
+  string stringToSymbolFnName = "symbolize" # name;
+
+  // The name of the utility function that converts a symbol to the
+  // corresponding string. It will have the following signature:
+  //
+  // ```c++
+  // llvm::StringRef <fn-name>(<qualified-enum-class-name>);
+  // ```
+  string symbolToStringFnName = "stringify" # name;
+
+  // The name of the utility function that returns the max enum value used
+  // within the enum class. It will have the following signature:
+  //
+  // ```c++
+  // static constexpr unsigned <fn-name>();
+  // ```
+  string maxEnumValFnName = "getMaxEnumValFor" # name;
+}
+
+// An enum attribute backed by StringAttr.
+//
+// Op attributes of this kind are stored as StringAttr. Extra verification will
+// be generated on the string though: only the symbols of the allowed cases are
+// permitted as the string value.
+class StrEnumAttr<string name, string description,
+                  list<StrEnumAttrCase> cases> :
+    EnumAttrInfo<name, cases>,
+    StringBasedAttr<
+      And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
+      !if(!empty(description), "allowed string cases: " #
+          StrJoin<!foreach(case, cases, "'" # case.symbol # "'")>.result,
+          description)>;
+
+// An enum attribute backed by IntegerAttr.
+//
+// Op attributes of this kind are stored as IntegerAttr. Extra verification will
+// be generated on the integer though: only the values of the allowed cases are
+// permitted as the integer value.
+class IntEnumAttr<I intType, string name, string description,
+                  list<IntEnumAttrCaseBase> cases> :
+    EnumAttrInfo<name, cases>,
+    IntegerAttrBase<intType,
+      !if(!empty(description), "allowed " # intType.description # " cases: " #
+          StrJoinInt<!foreach(case, cases, case.value)>.result, description)> {
+  let predicate = And<[
+    IntegerAttrBase<intType, "">.predicate,
+    Or<!foreach(case, cases, case.predicate)>]>;
+}
+
+class I32EnumAttr<string name, string description,
+                  list<I32EnumAttrCase> cases> :
+    IntEnumAttr<I32, name, description, cases> {
+  let underlyingType = "uint32_t";
+}
+class I64EnumAttr<string name, string description,
+                  list<I64EnumAttrCase> cases> :
+    IntEnumAttr<I64, name, description, cases> {
+  let underlyingType = "uint64_t";
+}
+
+//===----------------------------------------------------------------------===//
+// Composite attribute kinds
+
+class ElementsAttrBase<Pred condition, string description> :
+    Attr<condition, description> {
+  let storageType = [{ ElementsAttr }];
+  let returnType = [{ ElementsAttr }];
+  let convertFromStorage = "$_self";
+}
+
+def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
+                                   "constant vector/tensor attribute">;
+
+// Base class for array attributes.
+class ArrayAttrBase<Pred condition, string description> :
+    Attr<condition, description> {
+  let storageType = [{ ArrayAttr }];
+  let returnType = [{ ArrayAttr }];
+  let convertFromStorage = "$_self";
+}
+
+def ArrayAttr : ArrayAttrBase<CPred<"$_self.isa<ArrayAttr>()">,
+                              "array attribute">;
+
+// Base class for array attributes whose elements are of the same kind.
+// `element` specifies the element attribute kind stored in this array.
+class TypedArrayAttrBase<Attr element, string description>: ArrayAttrBase<
+    And<[
+      // Guranatee this is an ArrayAttr first
+      CPred<"$_self.isa<ArrayAttr>()">,
+      // Guarantee all elements satisfy the constraints from `element`
+      Concat<"llvm::all_of($_self.cast<ArrayAttr>(), "
+                          "[](Attribute attr) { return ",
+                             SubstLeaves<"$_self", "attr", element.predicate>,
+                          "; })">]>,
+    description> {
+  let constBuilderCall = "$_builder.getArrayAttr($0)";
+}
+
+def I32ArrayAttr : TypedArrayAttrBase<I32Attr,
+                                      "32-bit integer array attribute"> {
+  let constBuilderCall = "$_builder.getI32ArrayAttr($0)";
+}
+def I64ArrayAttr : TypedArrayAttrBase<I64Attr,
+                                      "64-bit integer array attribute"> {
+  let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
+}
+def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute"> {
+  let constBuilderCall = "$_builder.getF32ArrayAttr($0)";
+}
+def F64ArrayAttr : TypedArrayAttrBase<F64Attr, "64-bit float array attribute"> {
+  let constBuilderCall = "$_builder.getF64ArrayAttr($0)";
+}
+def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
+  let constBuilderCall = "$_builder.getStrArrayAttr($0)";
+}
+def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
+  let constBuilderCall = ?;
+}
+
+// Attributes containing symbol references.
+def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
+                        "symbol reference attribute"> {
+  let storageType = [{ SymbolRefAttr }];
+  let returnType = [{ StringRef }];
+  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+}
+
+//===----------------------------------------------------------------------===//
+// Derive attribute kinds
+
+// DerivedAttr are attributes whose value is computed from properties
+// of the operation. They do not require additional storage and are
+// materialized as needed.
+class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> {
+  let returnType = ret;
+  code body = b;
+}
+
+// Derived attribute that returns a mlir::Type.
+class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
+
+//===----------------------------------------------------------------------===//
+// Constant attribute kinds
+
+// Represents a constant attribute of specific Attr type. A constant
+// attribute can be specified only of attributes that have a constant
+// builder call defined. The constant value is specified as a string.
+//
+// If used as a constraint, it generates a matcher on a constant attribute by
+// using the constant value builder of the attribute and the value.
+class ConstantAttr<Attr attribute, string val> : AttrConstraint<
+    CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>,
+    "constant attribute " # val> {
+  Attr attr = attribute;
+  string value = val;
+}
+
+class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>;
+def ConstBoolAttrFalse : ConstantAttr<BoolAttr, "false">;
+def ConstBoolAttrTrue : ConstantAttr<BoolAttr, "true">;
+def ConstUnitAttr : ConstantAttr<UnitAttr, "unit">;
+
+//===----------------------------------------------------------------------===//
+// Common attribute constraints
+//===----------------------------------------------------------------------===//
+
+// A general mechanism to further confine the given `attr` with all the
+// `constraints`. This allows to compose complex constraints out of a series
+// of more primitive ones.
+class Confined<Attr attr, list<AttrConstraint> constraints> : Attr<
+    And<!listconcat([attr.predicate],
+                      !foreach(pred, constraints, pred.predicate))>,
+    !foldl(/*init*/attr.description, /*list*/constraints,
+           prev, cur, prev # " " # cur.description)> {
+  let storageType = attr.storageType;
+  let returnType = attr.returnType;
+  let convertFromStorage = attr.convertFromStorage;
+  let constBuilderCall = attr.constBuilderCall;
+  let defaultValue = attr.defaultValue;
+  let isOptional = attr.isOptional;
+
+  let baseAttr = attr;
+}
+
+// An AttrConstraint that holds if all attr constraints specified in
+// 'constraints' hold.
+class AllAttrConstraintsOf<list<AttrConstraint> constraints> : AttrConstraint<
+    And<!listconcat([!head(constraints).predicate],
+                      !foreach(pred, !tail(constraints), pred.predicate))>,
+    !foldl(/*init*/!head(constraints).description, /*list*/!tail(constraints),
+           prev, cur, prev # " and " # cur.description)> {
+}
+
+class IntMinValue<int n> : AttrConstraint<
+    CPred<"$_self.cast<IntegerAttr>().getInt() >= " # n>,
+    "whose minimal value is " # n>;
+
+class ArrayMinCount<int n> : AttrConstraint<
+    CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
+    "with at least " # n # " elements">;
+
+class IntArrayNthElemEq<int index, int value> : AttrConstraint<
+    And<[
+      CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
+      CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
+        ".cast<IntegerAttr>().getInt() == " # value>
+       ]>,
+    "whose " # index # "-th element must be " # value>;
+
+class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
+    And<[
+      CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
+      CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
+        ".cast<IntegerAttr>().getInt() >= " # min>
+        ]>,
+    "whose " # index # "-th element must be at least " # min>;
+
+def IsNullAttr : AttrConstraint<
+    CPred<"!$_self">, "empty attribute (for optional attributes)">;
+
+//===----------------------------------------------------------------------===//
+// Region definitions
+//===----------------------------------------------------------------------===//
+
+class Region<Pred condition, string descr = ""> :
+    RegionConstraint<condition, descr>;
+
+// Any region.
+def AnyRegion : Region<CPred<"true">, "any region">;
+
+// A region with the given number of blocks.
+class SizedRegion<int numBlocks> : Region<
+  CPred<"$_self.getBlocks().size() == " # numBlocks>,
+  "region with " # numBlocks # " blocks">;
+
+//===----------------------------------------------------------------------===//
+// OpTrait definitions
+//===----------------------------------------------------------------------===//
+
+// OpTrait represents a trait regarding an op.
+class OpTrait;
+
+// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The
+// purpose to wrap around C++ symbol string with this class is to make
+// traits specified for ops in TableGen less alien and more integrated.
+class NativeOpTrait<string prop> : OpTrait {
+  string trait = prop;
+}
+
+// ParamNativeOpTrait corresponds to the template-parameterized traits in the
+// C++ implementation.  MLIR uses nested class templates to implement such
+// traits leading to constructs of the form "TraitName<Parameters>::Impl". Use
+// the value in `prop` as the trait name and the value in `params` as
+// parameters to construct the native trait class name.
+class ParamNativeOpTrait<string prop, string params>
+    : NativeOpTrait<prop # "<" # params # ">::Impl"> {
+}
+
+// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
+// affects op definition generator internals, like how op builders and
+// operand/attribute/result getters are generated.
+class GenInternalOpTrait<string prop> : OpTrait {
+  string trait = prop;
+}
+
+// PredOpTrait is an op trait implemented by way of a predicate on the op.
+class PredOpTrait<string descr, Pred pred> : OpTrait {
+  string description = descr;
+  Pred predicate = pred;
+}
+
+// Op supports operand broadcast behavior.
+def Broadcastable    : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+// X op Y == Y op X
+def Commutative      : NativeOpTrait<"IsCommutative">;
+// Op results are float or vectors/tensors thereof.
+def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
+// Op has no side effect.
+def NoSideEffect     : NativeOpTrait<"HasNoSideEffect">;
+// Op has the same operand type.
+def SameTypeOperands  : NativeOpTrait<"SameTypeOperands">;
+// Op has same operand and result shape.
+def SameOperandsAndResultShape   : NativeOpTrait<"SameOperandsAndResultShape">;
+// Op has the same operand and result type.
+def SameOperandsAndResultType    : NativeOpTrait<"SameOperandsAndResultType">;
+// Op has the same operand and result element type.
+def SameOperandsAndResultElementType :
+  NativeOpTrait<"SameOperandsAndResultElementType">;
+// Op is a terminator.
+def Terminator       : NativeOpTrait<"IsTerminator">;
+
+// Op's regions have a single block with the specified terminator.
+class SingleBlockImplicitTerminator<string op>
+    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
+
+// Op's parent operation is the provided one.
+class HasParent<string op>
+    : ParamNativeOpTrait<"HasParent", op>;
+
+// Op result type is derived from the first attribute. If the attribute is an
+// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
+// attribute content is used.
+def FirstAttrDerivedResultType :
+  GenInternalOpTrait<"FirstAttrDerivedResultType">;
+
+// TODO(antiagainst): Turn the following into normal traits and generate
+// verification for them.
+
+// All variadic operands of the op have the same number of values.
+// A variadic operand contains an array of values whose array size is only
+// known at runtime. This trait requires all variadic operands of an op
+// to have the same array size.
+def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
+// All variadic results of the op have the same number of values.
+// A variadic result contains an array of values whose array size is only
+// known at runtime. This trait requires all variadic results of an op
+// to have the same array size.
+def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
+
+//===----------------------------------------------------------------------===//
+// Op definitions
+//===----------------------------------------------------------------------===//
+
+// Marker used to identify the argument list for an op.
+def ins;
+
+// Marker used to identify the result list for an op.
+def outs;
+
+// Marker used to identify the region list for an op.
+def region;
+
+// Class for defining a custom builder.
+//
+// TableGen generates several generic builders for each op by default (see
+// comment in the `Op` class). If the default generated ones cannot cover
+// some use case, custom builders can be defined using instances of this class.
+//
+// The signature of the builder is always
+//
+// ```c++
+// static void build(Builder *builder, OperationState *state,
+//                   <other-parameters>...) {
+//   <body>...
+// }
+// ```
+//
+// To define a custom builder, the parameter list (*including* the `Builder
+// *builder, OperationState *state` part) and body should be passed in
+// as separate template arguments to this class. This is because we generate
+// op declaration and definition into separate files. If an empty string is
+// passed in for `body`, then *only* the builder declaration will be
+// generated; this provides a way to define complicated builders entirely
+// in C++.
+class OpBuilder<string p, code b = ""> {
+  string params = p;
+  code body = b;
+}
+
+// Base class for all ops.
+class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
+  // The dialect of the op.
+  Dialect opDialect = dialect;
+
+  // The mnemonic of the op.
+  string opName = mnemonic;
+
+  // One-line human-readable description of what the op does.
+  string summary = "";
+
+  // Additional, longer human-readable description of what the op does.
+  string description = "";
+
+  // Dag containting the arguments of the op. Default to 0 arguments.
+  dag arguments = (ins);
+
+  // The list of results of the op. Default to 0 results.
+  dag results = (outs);
+
+  // The list of regions of the op. Default to 0 regions.
+  dag regions = (region);
+
+  // Attribute getters can be added to the op by adding an Attr member
+  // with the name and type of the attribute. E.g., adding int attribute
+  // with name "value" and type "i32":
+  //   I32Attr value;
+
+  // Define the hooks used for building, parsing, printing, verification.
+
+  // Custom builder.
+  // In addition to the custom builder provided here, and unless
+  // skipDefaultBuilders is set, two default builders are generated, with the
+  // following signatures:
+  //
+  // ```c++
+  // static void build(Builder *, OperationState *tblgen_state,
+  //                   Type <result0-name>, Type <result1-name>, ...,
+  //                   Value <arg0-name>, Value <arg1-name>, ...,
+  //                   Attribute <attr0-name>, Attribute <attr1-name>, ...);
+  // ```
+  // * where the attributes follow the same declaration order as in the op.
+  //
+  // ```c++
+  // static void build(Builder *, OperationState *tblgen_state,
+  //                   ArrayRef<Type> resultTypes,
+  //                   ArrayRef<Value> operands,
+  //                   ArrayRef<NamedAttribute> attributes);
+  // ```
+  list<OpBuilder> builders = ?;
+
+  // Avoid generating default build functions.  Custom builders must be
+  // provided.
+  bit skipDefaultBuilders = 0;
+
+  // Custom parser.
+  code parser = ?;
+
+  // Custom printer.
+  code printer = ?;
+
+  // Custom verifier.
+  code verifier = ?;
+
+  // Whether this op has associated canonicalization patterns.
+  // TODO(b/120163349): figure out a better way to write canonicalization
+  // patterns in TableGen rules directly instead of using this marker
+  // and C++ implementations.
+  bit hasCanonicalizer = 0;
+
+  // Whether this op has a folder.
+  bit hasFolder = 0;
+
+  // Op traits.
+  list<OpTrait> traits = props;
+
+  // Additional code that will be added to the public part of the generated
+  // C++ code of the op declaration.
+  code extraClassDeclaration = ?;
+}
+
+// The arguments of an op.
+class Arguments<dag args> {
+  dag arguments = args;
+}
+
+// The results of an op.
+class Results<dag rets> {
+  dag results = rets;
+}
+
+//===----------------------------------------------------------------------===//
+// Common value constraints
+//===----------------------------------------------------------------------===//
+
+def HasNoUseOf: Constraint<
+    CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">;
+
+//===----------------------------------------------------------------------===//
+// Common op type constraints
+//===----------------------------------------------------------------------===//
+
+// These traits are for verifying properties of an op that require knowledge of
+// multiple arguments or results. For verifying properties of a single argument
+// or result, prefer operand type constraints.
+
+// These traits often require including "mlir/IR/TypeUtilities.h".
+
+// TODO(b/135033717): Improve the autogenerated error messages.
+
+// Type Constraint operand `idx`'s Element type is `type`.
+class TCopVTEtIs<int idx, Type type> : And<[
+   CPred<"$_op.getNumOperands() > " # idx>,
+   SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()",
+     IsShapedTypePred>,
+   SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # idx # "))",
+     type.predicate>]>;
+
+// Predicate to verify that a named argument or result's element type matches a
+// given type.
+class TypeIsPred<string name, Type type> :
+   SubstLeaves<"$_self", "$" # name # ".getType()", type.predicate>;
+class TypeIs<string name, Type type> : PredOpTrait<
+  "'" # name # "' is " # type.description, TypeIsPred<name, type>>;
+
+// Predicate to verify that a named argument or result's element type matches a
+// given type.
+class ElementTypeIsPred<string name, Type type> : And<[
+   SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>,
+   SubstLeaves<"$_self", "getElementTypeOrSelf($" # name # ")",
+     type.predicate>]>;
+class ElementTypeIs<string name, Type type> : PredOpTrait<
+  "'" # name # "' is " # type.description, ElementTypeIsPred<name, type>>;
+
+// TODO(b/135032064): Only works for non-variadic.
+class AllMatchPred<list<string> names, string operator> :
+    CPred<"llvm::is_splat(llvm::makeArrayRef({" #
+          StrJoin<!foreach(n, names,
+                           !subst("$_self", "$" # n, operator))>.result
+          # "}))">;
+
+class AllMatchTrait<list<string> names, string operator, string description> :
+    PredOpTrait<
+        "all of {" # StrJoin<names>.result # "} have same " # description,
+        AllMatchPred<names, operator>>;
+
+class AllElementTypesMatch<list<string> names> :
+    AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
+
+class AllTypesMatch<list<string> names> :
+    AllMatchTrait<names, "$_self.getType()", "type">;
+
+// Predicate to verify that the i'th operand and the j'th operand have the same
+// elemental type.
+// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element
+// type.
+class TCopVTEtIsSameAs<int i, int j> : And<[
+    CPred<"$_op.getNumOperands() > std::max(" # i # "u," # j # "u)">,
+    SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()",
+      IsShapedTypePred>,
+    SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
+      IsShapedTypePred>,
+    CPred<"mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == "
+          "mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>;
+
+// Predicate to verify that the i'th result and the j'th operand exist and has
+// shaped types.
+class TCOpResIsShapedTypePred<int i, int j> : And<[
+    CPred<"$_op.getNumResults() > " # i>,
+    CPred<"$_op.getNumOperands() > " # j>,
+    SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()",
+      IsShapedTypePred>,
+    SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
+      IsShapedTypePred>]>;
+
+// Basic Predicate to verify that the i'th result and the j'th operand have the
+// same elemental type.
+class TCresVTEtIsSameAsOpBase<int i, int j> :
+    CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")) == "
+          "getElementTypeOrSelf($_op.getOperand(" # j # "))">;
+
+// Predicate to verify that the i'th result and the j'th operand have the same
+// elemental type.
+// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element
+// type.
+class TCresVTEtIsSameAsOp<int i, int j> : And<[
+    TCOpResIsShapedTypePred<i, j>,
+    TCresVTEtIsSameAsOpBase<i, j>]>;
+
+// Predicate to verify that the opId'th operand can be broadcasted to the type
+// of the resId'th result.
+class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
+    TCOpResIsShapedTypePred<opId, resId>,
+    CPred<"OpTrait::util::getBroadcastedType("
+              "$_op.getOperand(" # opId # ")->getType(), "
+              "$_op.getResult(" # resId # ")->getType())">]>;
+
+// Predicate to verify that all the operands at the given `indices`
+// have the same element type.
+// Type Constraint operands' Element type are all Same At the given `indices`.
+// We query the operands' types into a list and check they are all the same.
+// Precondition:
+// 1) all operands involved are of shaped type and
+// 2) the indices are not out of range.
+class TCopVTEtAreSameAt<list<int> indices> : CPred<
+  "llvm::is_splat(mlir::functional::map("
+    "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }, "
+    "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
+
+//===----------------------------------------------------------------------===//
+// Pattern definitions
+//===----------------------------------------------------------------------===//
+
+// Marker used to identify the delta value added to the default benefit value.
+def addBenefit;
+
+// Base class for op+ -> op+ rewrite rules. These allow declaratively
+// specifying rewrite rules.
+//
+// A rewrite rule contains two components: a source pattern and one or more
+// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
+// in the form of `(node arg0, arg1, ...)`.
+//
+// The `node` are normally MLIR ops, but it can also be one of the directives
+// listed later in this section.
+//
+// ## Symbol binding
+//
+// In the source pattern, `argN` can be used to specify matchers (e.g., using
+// type/attribute type constraints, etc.) and bound to a name for later use.
+// We can also bound names to op instances to reference them later in
+// multi-entity constraints.
+//
+// In the result pattern, `argN` can be used to refer to a previously bound
+// name, with potential transformations (e.g., using tAttr, etc.). `argN` can
+// itself be nested DAG node. We can also bound names to ops to reference
+// them later in other result patterns.
+//
+// For example,
+//
+// ```
+// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1),
+//               [(OneResultOp2:$op2 $arg0, $arg1),
+//                (OneResultOp3 $op2 (OneResultOp4))],
+//               [(HasStaticShapePred $op1)]>;
+// ```
+//
+// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to
+// build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to
+// check whether the result's shape is static. `$op2` is bound to
+// `OneResultOp2` and used to build `OneResultOp3`.
+//
+// ## Multi-result op
+//
+// To create multi-result ops in result pattern, you can use a syntax similar
+// to uni-result op, and it will act as a value pack for all results:
+//
+// ```
+// def : Pattern<(ThreeResultOp ...),
+//               [(TwoResultOp ...), (OneResultOp ...)]>;
+// ```
+//
+// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`.
+//
+// You can also use `$<name>__N` to explicitly access the N-th reusult.
+// ```
+// def : Pattern<(FiveResultOp ...),
+//               [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0),
+//                (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>;
+// ```
+//
+// Then the values generated by `FiveResultOp` will be replaced by
+//
+// * `FiveResultOp`#0: `TwoResultOp1`#1
+// * `FiveResultOp`#1: `TwoResultOp1`#0
+// * `FiveResultOp`#2: `TwoResultOp2`#0
+// * `FiveResultOp`#3: `TwoResultOp2`#1
+// * `FiveResultOp`#4: `TwoResultOp2`#1
+class Pattern<dag source, list<dag> results, list<dag> preds = [],
+  dag benefitAdded = (addBenefit 0)> {
+  dag sourcePattern = source;
+  // Result patterns. Each result pattern is expected to replace one result
+  // of the root op in the source pattern. In the case of more result patterns
+  // than needed to replace the source op, only the last N results generated
+  // by the last N result pattern is used to replace a N-result source op.
+  // So that the beginning result patterns can be used to generate additional
+  // ops to aid building the results used for replacement.
+  list<dag> resultPatterns = results;
+  // Multi-entity constraints. Each constraint here involves multiple entities
+  // matched in source pattern and places further constraints on them as a
+  // whole.
+  list<dag> constraints = preds;
+  // The delta value added to the default benefit value. The default value is
+  // the number of ops in the source pattern. The rule with the highest final
+  // benefit value will be applied first if there are multiple rules matches.
+  // This delta value can be either positive or negative.
+  dag benefitDelta = benefitAdded;
+}
+
+// Form of a pattern which produces a single result.
+class Pat<dag pattern, dag result, list<dag> preds = [],
+  dag benefitAdded = (addBenefit 0)> :
+  Pattern<pattern, [result], preds, benefitAdded>;
+
+// Native code call wrapper. This allows invoking an arbitrary C++ expression
+// to create an op operand/attribute or replace an op result.
+//
+// ## Placeholders
+//
+// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`,
+// the wrapped expression can take special placeholders listed below:
+//
+// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
+// * `$_self` will be replaced with the entity this transformer is attached to.
+//   E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in
+//   `transform:$attr` will be replaced by  the value for `$att`.
+//
+// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
+// then positional placeholders are also supported; placeholder `$N` in the
+// wrapped C++ expression will be replaced by `<argN>`.
+
+class NativeCodeCall<string expr> {
+  string expression = expr;
+}
+
+//===----------------------------------------------------------------------===//
+// Common directives
+//===----------------------------------------------------------------------===//
+
+// Directive used in result pattern to indicate that no new op are generated,
+// so to replace the matched DAG with an existing SSA value.
+def replaceWithValue;
+
+#endif // OP_BASE
diff --git a/third_party/mlir/include/mlir/IR/OpDefinition.h b/third_party/mlir/include/mlir/IR/OpDefinition.h
new file mode 100644
index 0000000..ed68936
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/OpDefinition.h
@@ -0,0 +1,1026 @@
+//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements helper classes for implementing the "Op" types.  This
+// includes the Op type, which is the base class for Op class definitions,
+// as well as number of traits in the OpTrait namespace that provide a
+// declarative way to specify properties of Ops.
+//
+// The purpose of these types are to allow light-weight implementation of
+// concrete ops (like DimOp) with very little boilerplate.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPDEFINITION_H
+#define MLIR_IR_OPDEFINITION_H
+
+#include "mlir/IR/Operation.h"
+#include <type_traits>
+
+namespace mlir {
+class Builder;
+
+namespace OpTrait {
+template <typename ConcreteType> class OneResult;
+}
+
+/// This class represents success/failure for operation parsing. It is
+/// essentially a simple wrapper class around LogicalResult that allows for
+/// explicit conversion to bool. This allows for the parser to chain together
+/// parse rules without the clutter of "failed/succeeded".
+class ParseResult : public LogicalResult {
+public:
+  ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
+
+  // Allow diagnostics emitted during parsing to be converted to failure.
+  ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
+  ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
+
+  /// Failure is true in a boolean context.
+  explicit operator bool() const { return failed(*this); }
+};
+
+// These functions are out-of-line utilities, which avoids them being template
+// instantiated/duplicated.
+namespace impl {
+/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
+/// region's only block if it does not have a terminator already. If the region
+/// is empty, insert a new block first. `buildTerminatorOp` should return the
+/// terminator operation to insert.
+void ensureRegionTerminator(
+    Region &region, Location loc,
+    llvm::function_ref<Operation *()> buildTerminatorOp);
+/// Templated version that fills the generates the provided operation type.
+template <typename OpTy>
+void ensureRegionTerminator(Region &region, Builder &builder, Location loc) {
+  ensureRegionTerminator(region, loc, [&] {
+    OperationState state(loc, OpTy::getOperationName());
+    OpTy::build(&builder, &state);
+    return Operation::create(state);
+  });
+}
+} // namespace impl
+
+/// This is the concrete base class that holds the operation pointer and has
+/// non-generic methods that only depend on State (to avoid having them
+/// instantiated on template types that don't affect them.
+///
+/// This also has the fallback implementations of customization hooks for when
+/// they aren't customized.
+class OpState {
+public:
+  /// Ops are pointer-like, so we allow implicit conversion to bool.
+  operator bool() { return getOperation() != nullptr; }
+
+  /// This implicitly converts to Operation*.
+  operator Operation *() const { return state; }
+
+  /// Return the operation that this refers to.
+  Operation *getOperation() { return state; }
+
+  /// Returns the closest surrounding operation that contains this operation
+  /// or nullptr if this is a top-level operation.
+  Operation *getParentOp() { return getOperation()->getParentOp(); }
+
+  /// Return the closest surrounding parent operation that is of type 'OpTy'.
+  template <typename OpTy> OpTy getParentOfType() {
+    return getOperation()->getParentOfType<OpTy>();
+  }
+
+  /// Return the context this operation belongs to.
+  MLIRContext *getContext() { return getOperation()->getContext(); }
+
+  /// Print the operation to the given stream.
+  void print(raw_ostream &os) { state->print(os); }
+
+  /// Dump this operation.
+  void dump() { state->dump(); }
+
+  /// The source location the operation was defined or derived from.
+  Location getLoc() { return state->getLoc(); }
+  void setLoc(Location loc) { state->setLoc(loc); }
+
+  /// Return all of the attributes on this operation.
+  ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
+
+  /// A utility iterator that filters out non-dialect attributes.
+  using dialect_attr_iterator = Operation::dialect_attr_iterator;
+  using dialect_attr_range = Operation::dialect_attr_range;
+
+  /// Return a range corresponding to the dialect attributes for this operation.
+  dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); }
+  dialect_attr_iterator dialect_attr_begin() {
+    return state->dialect_attr_begin();
+  }
+  dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); }
+
+  /// Return an attribute with the specified name.
+  Attribute getAttr(StringRef name) { return state->getAttr(name); }
+
+  /// If the operation has an attribute of the specified type, return it.
+  template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
+    return getAttr(name).dyn_cast_or_null<AttrClass>();
+  }
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value.  Otherwise, add a new attribute with the specified name/value.
+  void setAttr(Identifier name, Attribute value) {
+    state->setAttr(name, value);
+  }
+  void setAttr(StringRef name, Attribute value) {
+    setAttr(Identifier::get(name, getContext()), value);
+  }
+
+  /// Set the attributes held by this operation.
+  void setAttrs(ArrayRef<NamedAttribute> attributes) {
+    state->setAttrs(attributes);
+  }
+  void setAttrs(NamedAttributeList newAttrs) { state->setAttrs(newAttrs); }
+
+  /// Set the dialect attributes for this operation, and preserve all dependent.
+  template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
+    state->setDialectAttrs(std::move(attrs));
+  }
+
+  /// Remove the attribute with the specified name if it exists.  The return
+  /// value indicates whether the attribute was present or not.
+  NamedAttributeList::RemoveResult removeAttr(Identifier name) {
+    return state->removeAttr(name);
+  }
+  NamedAttributeList::RemoveResult removeAttr(StringRef name) {
+    return state->removeAttr(Identifier::get(name, getContext()));
+  }
+
+  /// Return true if there are no users of any results of this operation.
+  bool use_empty() { return state->use_empty(); }
+
+  /// Remove this operation from its parent block and delete it.
+  void erase() { state->erase(); }
+
+  /// Emit an error with the op name prefixed, like "'dim' op " which is
+  /// convenient for verifiers.
+  InFlightDiagnostic emitOpError(const Twine &message = {});
+
+  /// Emit an error about fatal conditions with this operation, reporting up to
+  /// any diagnostic handlers that may be listening.
+  InFlightDiagnostic emitError(const Twine &message = {});
+
+  /// Emit a warning about this operation, reporting up to any diagnostic
+  /// handlers that may be listening.
+  InFlightDiagnostic emitWarning(const Twine &message = {});
+
+  /// Emit a remark about this operation, reporting up to any diagnostic
+  /// handlers that may be listening.
+  InFlightDiagnostic emitRemark(const Twine &message = {});
+
+  /// Walk the operation in postorder, calling the callback for each nested
+  /// operation(including this one).
+  void walk(llvm::function_ref<void(Operation *)> callback) {
+    state->walk(callback);
+  }
+
+  /// Specialization of walk to only visit operations of 'OpTy'.
+  template <typename OpTy> void walk(llvm::function_ref<void(OpTy)> callback) {
+    walk([&](Operation *opInst) {
+      if (auto op = dyn_cast<OpTy>(opInst))
+        callback(op);
+    });
+  }
+
+  // These are default implementations of customization hooks.
+public:
+  /// This hook returns any canonicalization pattern rewrites that the operation
+  /// supports, for use by the canonicalization pass.
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {}
+
+protected:
+  /// If the concrete type didn't implement a custom verifier hook, just fall
+  /// back to this one which accepts everything.
+  LogicalResult verify() { return success(); }
+
+  /// Unless overridden, the custom assembly form of an op is always rejected.
+  /// Op implementations should implement this to return failure.
+  /// On success, they should fill in result with the fields to use.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+
+  // The fallback for the printer is to print it the generic assembly form.
+  void print(OpAsmPrinter *p);
+
+  /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
+  /// so we can cast it away here.
+  explicit OpState(Operation *state) : state(state) {}
+
+private:
+  Operation *state;
+};
+
+// Allow comparing operators.
+inline bool operator==(OpState lhs, OpState rhs) {
+  return lhs.getOperation() == rhs.getOperation();
+}
+inline bool operator!=(OpState lhs, OpState rhs) {
+  return lhs.getOperation() != rhs.getOperation();
+}
+
+/// This class represents a single result from folding an operation.
+class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
+  using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
+};
+
+/// This template defines the foldHook as used by AbstractOperation.
+///
+/// The default implementation uses a general fold method that can be defined on
+/// custom ops which can return multiple results.
+template <typename ConcreteType, bool isSingleResult, typename = void>
+class FoldingHook {
+public:
+  /// This is an implementation detail of the constant folder hook for
+  /// AbstractOperation.
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    return cast<ConcreteType>(op).fold(operands, results);
+  }
+
+  /// This hook implements a generalized folder for this operation.  Operations
+  /// can implement this to provide simplifications rules that are applied by
+  /// the Builder::createOrFold API and the canonicalization pass.
+  ///
+  /// This is an intentionally limited interface - implementations of this hook
+  /// can only perform the following changes to the operation:
+  ///
+  ///  1. They can leave the operation alone and without changing the IR, and
+  ///     return failure.
+  ///  2. They can mutate the operation in place, without changing anything else
+  ///     in the IR.  In this case, return success.
+  ///  3. They can return a list of existing values that can be used instead of
+  ///     the operation.  In this case, fill in the results list and return
+  ///     success.  The caller will remove the operation and use those results
+  ///     instead.
+  ///
+  /// This allows expression of some simple in-place canonicalizations (e.g.
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
+  ///
+  /// If not overridden, this fallback implementation always fails to fold.
+  ///
+  LogicalResult fold(ArrayRef<Attribute> operands,
+                     SmallVectorImpl<OpFoldResult> &results) {
+    return failure();
+  }
+};
+
+/// This template specialization defines the foldHook as used by
+/// AbstractOperation for single-result operations.  This gives the hook a nicer
+/// signature that is easier to implement.
+template <typename ConcreteType, bool isSingleResult>
+class FoldingHook<ConcreteType, isSingleResult,
+                  typename std::enable_if<isSingleResult>::type> {
+public:
+  /// If the operation returns a single value, then the Op can be implicitly
+  /// converted to an Value*.  This yields the value of the only result.
+  operator Value *() {
+    return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
+  }
+
+  /// This is an implementation detail of the constant folder hook for
+  /// AbstractOperation.
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    auto result = cast<ConcreteType>(op).fold(operands);
+    if (!result)
+      return failure();
+
+    // Check if the operation was folded in place. In this case, the operation
+    // returns itself.
+    if (result.template dyn_cast<Value *>() != op->getResult(0))
+      results.push_back(result);
+    return success();
+  }
+
+  /// This hook implements a generalized folder for this operation.  Operations
+  /// can implement this to provide simplifications rules that are applied by
+  /// the Builder::createOrFold API and the canonicalization pass.
+  ///
+  /// This is an intentionally limited interface - implementations of this hook
+  /// can only perform the following changes to the operation:
+  ///
+  ///  1. They can leave the operation alone and without changing the IR, and
+  ///     return nullptr.
+  ///  2. They can mutate the operation in place, without changing anything else
+  ///     in the IR.  In this case, return the operation itself.
+  ///  3. They can return an existing SSA value that can be used instead of
+  ///     the operation.  In this case, return that value.  The caller will
+  ///     remove the operation and use that result instead.
+  ///
+  /// This allows expression of some simple in-place canonicalizations (e.g.
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
+  ///
+  /// If not overridden, this fallback implementation always fails to fold.
+  ///
+  OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
+};
+
+//===----------------------------------------------------------------------===//
+// Operation Trait Types
+//===----------------------------------------------------------------------===//
+
+namespace OpTrait {
+
+// These functions are out-of-line implementations of the methods in the
+// corresponding trait classes.  This avoids them being template
+// instantiated/duplicated.
+namespace impl {
+LogicalResult verifyZeroOperands(Operation *op);
+LogicalResult verifyOneOperand(Operation *op);
+LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyOperandsAreFloatLike(Operation *op);
+LogicalResult verifyOperandsAreIntegerLike(Operation *op);
+LogicalResult verifySameTypeOperands(Operation *op);
+LogicalResult verifyZeroResult(Operation *op);
+LogicalResult verifyOneResult(Operation *op);
+LogicalResult verifyNResults(Operation *op, unsigned numOperands);
+LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
+LogicalResult verifySameOperandsAndResultShape(Operation *op);
+LogicalResult verifySameOperandsAndResultElementType(Operation *op);
+LogicalResult verifySameOperandsAndResultType(Operation *op);
+LogicalResult verifyResultsAreBoolLike(Operation *op);
+LogicalResult verifyResultsAreFloatLike(Operation *op);
+LogicalResult verifyResultsAreIntegerLike(Operation *op);
+LogicalResult verifyIsTerminator(Operation *op);
+} // namespace impl
+
+/// Helper class for implementing traits.  Clients are not expected to interact
+/// with this directly, so its members are all protected.
+template <typename ConcreteType, template <typename> class TraitType>
+class TraitBase {
+protected:
+  /// Return the ultimate Operation being worked on.
+  Operation *getOperation() {
+    // We have to cast up to the trait type, then to the concrete type, then to
+    // the BaseState class in explicit hops because the concrete type will
+    // multiply derive from the (content free) TraitBase class, and we need to
+    // be able to disambiguate the path for the C++ compiler.
+    auto *trait = static_cast<TraitType<ConcreteType> *>(this);
+    auto *concrete = static_cast<ConcreteType *>(trait);
+    auto *base = static_cast<OpState *>(concrete);
+    return base->getOperation();
+  }
+
+  /// Provide default implementations of trait hooks.  This allows traits to
+  /// provide exactly the overrides they care about.
+  static LogicalResult verifyTrait(Operation *op) { return success(); }
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return 0;
+  }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple operands.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
+  using operand_iterator = Operation::operand_iterator;
+  using operand_range = Operation::operand_range;
+  using operand_type_iterator = Operation::operand_type_iterator;
+  using operand_type_range = Operation::operand_type_range;
+
+  /// Return the number of operands.
+  unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
+
+  /// Return the operand at index 'i'.
+  Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
+
+  /// Set the operand at index 'i' to 'value'.
+  void setOperand(unsigned i, Value *value) {
+    this->getOperation()->setOperand(i, value);
+  }
+
+  /// Operand iterator access.
+  operand_iterator operand_begin() {
+    return this->getOperation()->operand_begin();
+  }
+  operand_iterator operand_end() { return this->getOperation()->operand_end(); }
+  operand_range getOperands() { return this->getOperation()->getOperands(); }
+
+  /// Operand type access.
+  operand_type_iterator operand_type_begin() {
+    return this->getOperation()->operand_type_begin();
+  }
+  operand_type_iterator operand_type_end() {
+    return this->getOperation()->operand_type_end();
+  }
+  operand_type_range getOperandTypes() {
+    return this->getOperation()->getOperandTypes();
+  }
+};
+} // end namespace detail
+
+/// This class provides the API for ops that are known to have no
+/// SSA operand.
+template <typename ConcreteType>
+class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyZeroOperands(op);
+  }
+
+private:
+  // Disable these.
+  void getOperand() {}
+  void setOperand() {}
+};
+
+/// This class provides the API for ops that are known to have exactly one
+/// SSA operand.
+template <typename ConcreteType>
+class OneOperand : public TraitBase<ConcreteType, OneOperand> {
+public:
+  Value *getOperand() { return this->getOperation()->getOperand(0); }
+
+  void setOperand(Value *value) { this->getOperation()->setOperand(0, value); }
+
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOneOperand(op);
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of operands.  This is used as a trait like this:
+///
+///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
+///
+template <unsigned N> class NOperands {
+public:
+  static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
+
+  template <typename ConcreteType>
+  class Impl
+      : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyNOperands(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops that are known to have a at least a
+/// specified number of operands.  This is used as a trait like this:
+///
+///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
+///
+template <unsigned N> class AtLeastNOperands {
+public:
+  template <typename ConcreteType>
+  class Impl : public detail::MultiOperandTraitBase<ConcreteType,
+                                                    AtLeastNOperands<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyAtLeastNOperands(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// SSA operands.
+template <typename ConcreteType>
+class VariadicOperands
+    : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
+
+/// This class provides return value APIs for ops that are known to have
+/// zero results.
+template <typename ConcreteType>
+class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyZeroResult(op);
+  }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple results.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
+  using result_iterator = Operation::result_iterator;
+  using result_range = Operation::result_range;
+  using result_type_iterator = Operation::result_type_iterator;
+  using result_type_range = Operation::result_type_range;
+
+  /// Return the number of results.
+  unsigned getNumResults() { return this->getOperation()->getNumResults(); }
+
+  /// Return the result at index 'i'.
+  Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
+
+  /// Replace all uses of results of this operation with the provided 'values'.
+  /// 'values' may correspond to an existing operation, or a range of 'Value'.
+  template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
+    this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
+  }
+
+  /// Return the type of the `i`-th result.
+  Type getType(unsigned i) { return getResult(i)->getType(); }
+
+  /// Result iterator access.
+  result_iterator result_begin() {
+    return this->getOperation()->result_begin();
+  }
+  result_iterator result_end() { return this->getOperation()->result_end(); }
+  result_range getResults() { return this->getOperation()->getResults(); }
+
+  /// Result type access.
+  result_type_iterator result_type_begin() {
+    return this->getOperation()->result_type_begin();
+  }
+  result_type_iterator result_type_end() {
+    return this->getOperation()->result_type_end();
+  }
+  result_type_range getResultTypes() {
+    return this->getOperation()->getResultTypes();
+  }
+};
+} // end namespace detail
+
+/// This class provides return value APIs for ops that are known to have a
+/// single result.
+template <typename ConcreteType>
+class OneResult : public TraitBase<ConcreteType, OneResult> {
+public:
+  Value *getResult() { return this->getOperation()->getResult(0); }
+  Type getType() { return getResult()->getType(); }
+
+  /// Replace all uses of 'this' value with the new value, updating anything in
+  /// the IR that uses 'this' to use the other value instead.  When this returns
+  /// there are zero uses of 'this'.
+  void replaceAllUsesWith(Value *newValue) {
+    getResult()->replaceAllUsesWith(newValue);
+  }
+
+  /// Replace all uses of 'this' value with the result of 'op'.
+  void replaceAllUsesWith(Operation *op) {
+    this->getOperation()->replaceAllUsesWith(op);
+  }
+
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOneResult(op);
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of results.  This is used as a trait like this:
+///
+///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
+///
+template <unsigned N> class NResults {
+public:
+  static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
+
+  template <typename ConcreteType>
+  class Impl
+      : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyNResults(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops that are known to have at least a
+/// specified number of results.  This is used as a trait like this:
+///
+///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
+///
+template <unsigned N> class AtLeastNResults {
+public:
+  template <typename ConcreteType>
+  class Impl : public detail::MultiResultTraitBase<ConcreteType,
+                                                   AtLeastNResults<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyAtLeastNResults(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// results.
+template <typename ConcreteType>
+class VariadicResults
+    : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
+
+/// This class provides verification for ops that are known to have the same
+/// operand and result shape: both are scalars, vectors/tensors of the same
+/// shape.
+template <typename ConcreteType>
+class SameOperandsAndResultShape
+    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultShape(op);
+  }
+};
+
+/// This class provides verification for ops that are known to have the same
+/// operand and result element type.
+///
+template <typename ConcreteType>
+class SameOperandsAndResultElementType
+    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultElementType(op);
+  }
+};
+
+/// This class provides verification for ops that are known to have the same
+/// operand and result type.
+///
+/// Note: this trait subsumes the SameOperandsAndResultShape and
+/// SameOperandsAndResultElementType traits.
+template <typename ConcreteType>
+class SameOperandsAndResultType
+    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultType(op);
+  }
+};
+
+/// This class verifies that any results of the specified op have a boolean
+/// type, a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyResultsAreBoolLike(op);
+  }
+};
+
+/// This class verifies that any results of the specified op have a floating
+/// point type, a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class ResultsAreFloatLike
+    : public TraitBase<ConcreteType, ResultsAreFloatLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyResultsAreFloatLike(op);
+  }
+};
+
+/// This class verifies that any results of the specified op have an integer or
+/// index type, a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class ResultsAreIntegerLike
+    : public TraitBase<ConcreteType, ResultsAreIntegerLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyResultsAreIntegerLike(op);
+  }
+};
+
+/// This class adds property that the operation is commutative.
+template <typename ConcreteType>
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::Commutative);
+  }
+};
+
+/// This class adds property that the operation has no side effects.
+template <typename ConcreteType>
+class HasNoSideEffect : public TraitBase<ConcreteType, HasNoSideEffect> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::NoSideEffect);
+  }
+};
+
+/// This class verifies that all operands of the specified op have a float type,
+/// a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class OperandsAreFloatLike
+    : public TraitBase<ConcreteType, OperandsAreFloatLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOperandsAreFloatLike(op);
+  }
+};
+
+/// This class verifies that all operands of the specified op have an integer or
+/// index type, a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class OperandsAreIntegerLike
+    : public TraitBase<ConcreteType, OperandsAreIntegerLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOperandsAreIntegerLike(op);
+  }
+};
+
+/// This class verifies that all operands of the specified op have the same
+/// type.
+template <typename ConcreteType>
+class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameTypeOperands(op);
+  }
+};
+
+/// This class provides the API for ops that are known to be terminators.
+template <typename ConcreteType>
+class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::Terminator);
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyIsTerminator(op);
+  }
+
+  unsigned getNumSuccessors() {
+    return this->getOperation()->getNumSuccessors();
+  }
+  unsigned getNumSuccessorOperands(unsigned index) {
+    return this->getOperation()->getNumSuccessorOperands(index);
+  }
+
+  Block *getSuccessor(unsigned index) {
+    return this->getOperation()->getSuccessor(index);
+  }
+
+  void setSuccessor(Block *block, unsigned index) {
+    return this->getOperation()->setSuccessor(block, index);
+  }
+
+  void addSuccessorOperand(unsigned index, Value *value) {
+    return this->getOperation()->addSuccessorOperand(index, value);
+  }
+  void addSuccessorOperands(unsigned index, ArrayRef<Value *> values) {
+    return this->getOperation()->addSuccessorOperand(index, values);
+  }
+};
+
+/// This class provides the API for ops that are known to be isolated from
+/// above.
+template <typename ConcreteType>
+class IsIsolatedFromAbove
+    : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::IsolatedFromAbove);
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    for (auto &region : op->getRegions())
+      if (!region.isIsolatedFromAbove(op->getLoc()))
+        return failure();
+    return success();
+  }
+};
+
+/// This class provides APIs and verifiers for ops with regions having a single
+/// block that must terminate with `TerminatorOpType`.
+template <typename TerminatorOpType> struct SingleBlockImplicitTerminator {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
+        Region &region = op->getRegion(i);
+
+        // Empty regions are fine.
+        if (region.empty())
+          continue;
+
+        // Non-empty regions must contain a single basic block.
+        if (std::next(region.begin()) != region.end())
+          return op->emitOpError("expects region #")
+                 << i << " to have 0 or 1 blocks";
+
+        Block &block = region.front();
+        if (block.empty())
+          return op->emitOpError() << "expects a non-empty block";
+        Operation &terminator = block.back();
+        if (isa<TerminatorOpType>(terminator))
+          continue;
+
+        return op->emitOpError("expects regions to end with '" +
+                               TerminatorOpType::getOperationName() +
+                               "', found '" +
+                               terminator.getName().getStringRef() + "'")
+                   .attachNote()
+               << "in custom textual format, the absence of terminator implies "
+                  "'"
+               << TerminatorOpType::getOperationName() << '\'';
+      }
+
+      return success();
+    }
+
+    /// Ensure that the given region has the terminator required by this trait.
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc) {
+      ::mlir::impl::template ensureRegionTerminator<TerminatorOpType>(
+          region, builder, loc);
+    }
+  };
+};
+
+/// This class provides a verifier for ops that are expecting a specific parent.
+template <typename ParentOpType> struct HasParent {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (isa<ParentOpType>(op->getParentOp()))
+        return success();
+      return op->emitOpError() << "expects parent op '"
+                               << ParentOpType::getOperationName() << "'";
+    }
+  };
+};
+
+} // end namespace OpTrait
+
+//===----------------------------------------------------------------------===//
+// Operation Definition classes
+//===----------------------------------------------------------------------===//
+
+/// This provides public APIs that all operations should have.  The template
+/// argument 'ConcreteType' should be the concrete type by CRTP and the others
+/// are base classes by the policy pattern.
+template <typename ConcreteType, template <typename T> class... Traits>
+class Op : public OpState,
+           public Traits<ConcreteType>...,
+           public FoldingHook<ConcreteType,
+                              llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
+                                              Traits<ConcreteType>...>::value> {
+public:
+  /// Return if this operation contains the provided trait.
+  template <template <typename T> class Trait>
+  static constexpr bool hasTrait() {
+    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
+  }
+
+  /// Return the operation that this refers to.
+  Operation *getOperation() { return OpState::getOperation(); }
+
+  /// Return the dialect that this refers to.
+  Dialect *getDialect() { return getOperation()->getDialect(); }
+
+  /// Return the parent Region of this operation.
+  Region *getParentRegion() { return getOperation()->getParentRegion(); }
+
+  /// Return true if this "op class" can match against the specified operation.
+  /// This hook can be overridden with a more specific implementation in
+  /// the subclass of Base.
+  ///
+  static bool classof(Operation *op) {
+    return op->getName().getStringRef() == ConcreteType::getOperationName();
+  }
+
+  /// This is the hook used by the AsmParser to parse the custom form of this
+  /// op from an .mlir file.  Op implementations should provide a parse method,
+  /// which returns failure.  On success, they should return fill in result with
+  /// the fields to use.
+  static ParseResult parseAssembly(OpAsmParser *parser,
+                                   OperationState *result) {
+    return ConcreteType::parse(parser, result);
+  }
+
+  /// This is the hook used by the AsmPrinter to emit this to the .mlir file.
+  /// Op implementations should provide a print method.
+  static void printAssembly(Operation *op, OpAsmPrinter *p) {
+    auto opPointer = dyn_cast<ConcreteType>(op);
+    assert(opPointer &&
+           "op's name does not match name of concrete type instantiated with");
+    opPointer.print(p);
+  }
+
+  /// This is the hook that checks whether or not this operation is well
+  /// formed according to the invariants of its opcode.  It delegates to the
+  /// Traits for their policy implementations, and allows the user to specify
+  /// their own verify() method.
+  ///
+  /// On success this returns false; on failure it emits an error to the
+  /// diagnostic subsystem and returns true.
+  static LogicalResult verifyInvariants(Operation *op) {
+    return failure(
+        failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
+        failed(cast<ConcreteType>(op).verify()));
+  }
+
+  // Returns the properties of an operation by combining the properties of the
+  // traits of the op.
+  static AbstractOperation::OperationProperties getOperationProperties() {
+    return BaseProperties<Traits<ConcreteType>...>::getTraitProperties();
+  }
+
+  /// Expose the type we are instantiated on to template machinery that may want
+  /// to introspect traits on this operation.
+  using ConcreteOpType = ConcreteType;
+
+  /// This is a public constructor.  Any op can be initialized to null.
+  explicit Op() : OpState(nullptr) {}
+  Op(std::nullptr_t) : OpState(nullptr) {}
+
+  /// This is a public constructor to enable access via the llvm::cast family of
+  /// methods. This should not be used directly.
+  explicit Op(Operation *state) : OpState(state) {}
+
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>((Operation *)*this);
+  }
+  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
+    return ConcreteOpType(
+        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
+  }
+
+private:
+  template <typename... Types> struct BaseVerifier;
+
+  template <typename First, typename... Rest>
+  struct BaseVerifier<First, Rest...> {
+    static LogicalResult verifyTrait(Operation *op) {
+      return failure(failed(First::verifyTrait(op)) ||
+                     failed(BaseVerifier<Rest...>::verifyTrait(op)));
+    }
+  };
+
+  template <typename...> struct BaseVerifier {
+    static LogicalResult verifyTrait(Operation *op) { return success(); }
+  };
+
+  template <typename... Types> struct BaseProperties;
+
+  template <typename First, typename... Rest>
+  struct BaseProperties<First, Rest...> {
+    static AbstractOperation::OperationProperties getTraitProperties() {
+      return First::getTraitProperties() |
+             BaseProperties<Rest...>::getTraitProperties();
+    }
+  };
+
+  template <typename...> struct BaseProperties {
+    static AbstractOperation::OperationProperties getTraitProperties() {
+      return 0;
+    }
+  };
+
+  /// Returns true if this operation contains the trait for the given classID.
+  static bool hasTrait(ClassID *traitID) {
+    return llvm::is_contained(llvm::makeArrayRef({ClassID::getID<Traits>()...}),
+                              traitID);
+  }
+
+  /// Allow access to 'hasTrait'.
+  friend AbstractOperation;
+};
+
+// These functions are out-of-line implementations of the methods in BinaryOp,
+// which avoids them being template instantiated/duplicated.
+namespace impl {
+void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
+                   Value *rhs);
+ParseResult parseBinaryOp(OpAsmParser *parser, OperationState *result);
+// Prints the given binary `op` in custom assembly form if both the two operands
+// and the result have the same time. Otherwise, prints the generic assembly
+// form.
+void printBinaryOp(Operation *op, OpAsmPrinter *p);
+} // namespace impl
+
+// These functions are out-of-line implementations of the methods in CastOp,
+// which avoids them being template instantiated/duplicated.
+namespace impl {
+void buildCastOp(Builder *builder, OperationState *result, Value *source,
+                 Type destType);
+ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
+void printCastOp(Operation *op, OpAsmPrinter *p);
+Value *foldCastOp(Operation *op);
+} // namespace impl
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/OpImplementation.h b/third_party/mlir/include/mlir/IR/OpImplementation.h
new file mode 100644
index 0000000..49a5314
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/OpImplementation.h
@@ -0,0 +1,527 @@
+//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This classes used by the implementation details of Op types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPIMPLEMENTATION_H
+#define MLIR_IR_OPIMPLEMENTATION_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+class Builder;
+
+//===----------------------------------------------------------------------===//
+// OpAsmPrinter
+//===----------------------------------------------------------------------===//
+
+/// This is a pure-virtual base class that exposes the asmprinter hooks
+/// necessary to implement a custom print() method.
+class OpAsmPrinter {
+public:
+  OpAsmPrinter() {}
+  virtual ~OpAsmPrinter();
+  virtual raw_ostream &getStream() const = 0;
+
+  /// Print implementations for various things an operation contains.
+  virtual void printOperand(Value *value) = 0;
+
+  /// Print a comma separated list of operands.
+  template <typename ContainerType>
+  void printOperands(const ContainerType &container) {
+    printOperands(container.begin(), container.end());
+  }
+
+  /// Print a comma separated list of operands.
+  template <typename IteratorType>
+  void printOperands(IteratorType it, IteratorType end) {
+    if (it == end)
+      return;
+    printOperand(*it);
+    for (++it; it != end; ++it) {
+      getStream() << ", ";
+      printOperand(*it);
+    }
+  }
+  virtual void printType(Type type) = 0;
+  virtual void printAttribute(Attribute attr) = 0;
+
+  /// Print a successor, and use list, of a terminator operation given the
+  /// terminator and the successor index.
+  virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;
+
+  /// If the specified operation has attributes, print out an attribute
+  /// dictionary with their values.  elidedAttrs allows the client to ignore
+  /// specific well known attributes, commonly used if the attribute value is
+  /// printed some other way (like as a fixed operand).
+  virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;
+
+  /// Print the entire operation with the default generic assembly form.
+  virtual void printGenericOp(Operation *op) = 0;
+
+  /// Prints a region.
+  virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
+                           bool printBlockTerminators = true) = 0;
+
+  /// Prints an affine map of SSA ids, where SSA id names are used in place
+  /// of dims/symbols.
+  /// Operand values must come from single-result sources, and be valid
+  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
+  virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+                                      ArrayRef<Value *> operands) = 0;
+
+  /// Print an optional arrow followed by a type list.
+  void printOptionalArrowTypeList(ArrayRef<Type> types) {
+    if (types.empty())
+      return;
+    auto &os = getStream() << " -> ";
+    bool wrapped = types.size() != 1 || types[0].isa<FunctionType>();
+    if (wrapped)
+      os << '(';
+    interleaveComma(types, *this);
+    if (wrapped)
+      os << ')';
+  }
+
+  /// Print the complete type of an operation in functional form.
+  void printFunctionalType(Operation *op) {
+    auto &os = getStream();
+    os << "(";
+    interleaveComma(op->getNonSuccessorOperands(), os,
+                    [&](Value *operand) { printType(operand->getType()); });
+    os << ") -> ";
+    if (op->getNumResults() == 1 &&
+        !op->getResult(0)->getType().isa<FunctionType>()) {
+      printType(op->getResult(0)->getType());
+    } else {
+      os << '(';
+      interleaveComma(op->getResultTypes(), os);
+      os << ')';
+    }
+  }
+
+private:
+  OpAsmPrinter(const OpAsmPrinter &) = delete;
+  void operator=(const OpAsmPrinter &) = delete;
+};
+
+// Make the implementations convenient to use.
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
+  p.printOperand(&value);
+  return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
+  p.printType(type);
+  return p;
+}
+
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
+  p.printAttribute(attr);
+  return p;
+}
+
+// Support printing anything that isn't convertible to one of the above types,
+// even if it isn't exactly one of them.  For example, we want to print
+// FunctionType with the Type version above, not have it match this.
+template <typename T, typename std::enable_if<
+                          !std::is_convertible<T &, Value &>::value &&
+                              !std::is_convertible<T &, Type &>::value &&
+                              !std::is_convertible<T &, Attribute &>::value,
+                          T>::type * = nullptr>
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
+  p.getStream() << other;
+  return p;
+}
+
+//===----------------------------------------------------------------------===//
+// OpAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The OpAsmParser has methods for interacting with the asm parser: parsing
+/// things from it, emitting errors etc.  It has an intentionally high-level API
+/// that is designed to reduce/constrain syntax innovation in individual
+/// operations.
+///
+/// For example, consider an op like this:
+///
+///    %x = load %p[%1, %2] : memref<...>
+///
+/// The "%x = load" tokens are already parsed and therefore invisible to the
+/// custom op parser.  This can be supported by calling `parseOperandList` to
+/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
+/// parse the indices, then calling `parseColonTypeList` to parse the result
+/// type.
+///
+class OpAsmParser {
+public:
+  virtual ~OpAsmParser();
+
+  /// Emit a diagnostic at the specified location and return failure.
+  virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
+                                       const Twine &message = {}) = 0;
+
+  /// Return a builder which provides useful access to MLIRContext, global
+  /// objects like types and attributes.
+  virtual Builder &getBuilder() const = 0;
+
+  /// Get the location of the next token and store it into the argument.  This
+  /// always succeeds.
+  virtual llvm::SMLoc getCurrentLocation() = 0;
+  ParseResult getCurrentLocation(llvm::SMLoc *loc) {
+    *loc = getCurrentLocation();
+    return success();
+  }
+
+  /// Return the location of the original name token.
+  virtual llvm::SMLoc getNameLoc() const = 0;
+
+  // These methods emit an error and return failure or success. This allows
+  // these to be chained together into a linear sequence of || expressions in
+  // many cases.
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a '->' token.
+  virtual ParseResult parseArrow() = 0;
+
+  /// Parse a '->' token if present
+  virtual ParseResult parseOptionalArrow() = 0;
+
+  /// Parse a `:` token.
+  virtual ParseResult parseColon() = 0;
+
+  /// Parse a `:` token if present.
+  virtual ParseResult parseOptionalColon() = 0;
+
+  /// Parse a `,` token.
+  virtual ParseResult parseComma() = 0;
+
+  /// Parse a `,` token if present.
+  virtual ParseResult parseOptionalComma() = 0;
+
+  /// Parse a `=` token.
+  virtual ParseResult parseEqual() = 0;
+
+  /// Parse a keyword.
+  ParseResult parseKeyword(const char *keyword, const Twine &msg = "") {
+    if (parseOptionalKeyword(keyword))
+      return emitError(getNameLoc(), "expected '") << keyword << "'" << msg;
+    return success();
+  }
+
+  /// Parse a keyword if present.
+  virtual ParseResult parseOptionalKeyword(const char *keyword) = 0;
+
+  /// Parse a `(` token.
+  virtual ParseResult parseLParen() = 0;
+
+  /// Parse a `(` token if present.
+  virtual ParseResult parseOptionalLParen() = 0;
+
+  /// Parse a `)` token.
+  virtual ParseResult parseRParen() = 0;
+
+  /// Parse a `)` token if present.
+  virtual ParseResult parseOptionalRParen() = 0;
+
+  /// Parse a `[` token.
+  virtual ParseResult parseLSquare() = 0;
+
+  /// Parse a `[` token if present.
+  virtual ParseResult parseOptionalLSquare() = 0;
+
+  /// Parse a `]` token.
+  virtual ParseResult parseRSquare() = 0;
+
+  /// Parse a `]` token if present.
+  virtual ParseResult parseOptionalRSquare() = 0;
+
+  /// Parse a `...` token if present;
+  virtual ParseResult parseOptionalEllipsis() = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Attribute Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an arbitrary attribute and return it in result.  This also adds the
+  /// attribute to the specified attribute list with the specified name.
+  ParseResult parseAttribute(Attribute &result, StringRef attrName,
+                             SmallVectorImpl<NamedAttribute> &attrs) {
+    return parseAttribute(result, Type(), attrName, attrs);
+  }
+
+  /// Parse an arbitrary attribute of a given type and return it in result. This
+  /// also adds the attribute to the specified attribute list with the specified
+  /// name.
+  virtual ParseResult
+  parseAttribute(Attribute &result, Type type, StringRef attrName,
+                 SmallVectorImpl<NamedAttribute> &attrs) = 0;
+
+  /// Parse an attribute of a specific kind and type.
+  template <typename AttrType>
+  ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
+                             SmallVectorImpl<NamedAttribute> &attrs) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of attribute.
+    Attribute attr;
+    if (parseAttribute(attr, type, attrName, attrs))
+      return failure();
+
+    // Check for the right kind of attribute.
+    result = attr.dyn_cast<AttrType>();
+    if (!result)
+      return emitError(loc, "invalid kind of constant specified");
+
+    return success();
+  }
+
+  /// Parse a named dictionary into 'result' if it is present.
+  virtual ParseResult
+  parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Identifier Parsing
+  //===--------------------------------------------------------------------===//
+
+  virtual ParseResult
+  parseSymbolName(StringAttr &result, StringRef attrName,
+                  SmallVectorImpl<NamedAttribute> &attrs) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Operand Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// This is the representation of an operand reference.
+  struct OperandType {
+    llvm::SMLoc location; // Location of the token.
+    StringRef name;       // Value name, e.g. %42 or %abc
+    unsigned number;      // Number, e.g. 12 for an operand like %xyz#12
+  };
+
+  /// Parse a single operand.
+  virtual ParseResult parseOperand(OperandType &result) = 0;
+
+  /// These are the supported delimiters around operand lists and region
+  /// argument lists, used by parseOperandList and parseRegionArgumentList.
+  enum class Delimiter {
+    /// Zero or more operands with no delimiters.
+    None,
+    /// Parens surrounding zero or more operands.
+    Paren,
+    /// Square brackets surrounding zero or more operands.
+    Square,
+    /// Parens supporting zero or more operands, or nothing.
+    OptionalParen,
+    /// Square brackets supporting zero or more ops, or nothing.
+    OptionalSquare,
+  };
+
+  /// Parse zero or more SSA comma-separated operand references with a specified
+  /// surrounding delimiter, and an optional required operand count.
+  virtual ParseResult
+  parseOperandList(SmallVectorImpl<OperandType> &result,
+                   int requiredOperandCount = -1,
+                   Delimiter delimiter = Delimiter::None) = 0;
+  ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
+                               Delimiter delimiter) {
+    return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
+  }
+
+  /// Parse zero or more trailing SSA comma-separated trailing operand
+  /// references with a specified surrounding delimiter, and an optional
+  /// required operand count. A leading comma is expected before the operands.
+  virtual ParseResult
+  parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                           int requiredOperandCount = -1,
+                           Delimiter delimiter = Delimiter::None) = 0;
+  ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                                       Delimiter delimiter) {
+    return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
+                                    delimiter);
+  }
+
+  /// Resolve an operand to an SSA value, emitting an error on failure.
+  virtual ParseResult resolveOperand(const OperandType &operand, Type type,
+                                     SmallVectorImpl<Value *> &result) = 0;
+
+  /// Resolve a list of operands to SSA values, emitting an error on failure, or
+  /// appending the results to the list on success. This method should be used
+  /// when all operands have the same type.
+  ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
+                              SmallVectorImpl<Value *> &result) {
+    for (auto elt : operands)
+      if (resolveOperand(elt, type, result))
+        return failure();
+    return success();
+  }
+
+  /// Resolve a list of operands and a list of operand types to SSA values,
+  /// emitting an error and returning failure, or appending the results
+  /// to the list on success.
+  ParseResult resolveOperands(ArrayRef<OperandType> operands,
+                              ArrayRef<Type> types, llvm::SMLoc loc,
+                              SmallVectorImpl<Value *> &result) {
+    if (operands.size() != types.size())
+      return emitError(loc)
+             << operands.size() << " operands present, but expected "
+             << types.size();
+
+    for (unsigned i = 0, e = operands.size(); i != e; ++i)
+      if (resolveOperand(operands[i], types[i], result))
+        return failure();
+    return success();
+  }
+
+  /// Parses an affine map attribute where dims and symbols are SSA operands.
+  /// Operand values must come from single-result sources, and be valid
+  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
+  virtual ParseResult
+  parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
+                         StringRef attrName,
+                         SmallVectorImpl<NamedAttribute> &attrs) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Region Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parses a region. Any parsed blocks are appended to "region" and must be
+  /// moved to the op regions after the op is created. The first block of the
+  /// region takes "arguments" of types "argTypes".
+  virtual ParseResult parseRegion(Region &region,
+                                  ArrayRef<OperandType> arguments,
+                                  ArrayRef<Type> argTypes) = 0;
+
+  /// Parses a region if present.
+  virtual ParseResult parseOptionalRegion(Region &region,
+                                          ArrayRef<OperandType> arguments,
+                                          ArrayRef<Type> argTypes) = 0;
+
+  /// Parse a region argument.  Region arguments define new values; so this also
+  /// checks if values with the same name have not been defined yet.
+  virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
+
+  /// Parse zero or more region arguments with a specified surrounding
+  /// delimiter, and an optional required argument count. Region arguments
+  /// define new values; so this also checks if values with the same names have
+  /// not been defined yet.
+  virtual ParseResult
+  parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
+                          int requiredOperandCount = -1,
+                          Delimiter delimiter = Delimiter::None) = 0;
+  virtual ParseResult
+  parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
+                          Delimiter delimiter) {
+    return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
+                                   delimiter);
+  }
+
+  /// Parse a region argument if present.
+  virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Successor Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a single operation successor and its operand list.
+  virtual ParseResult
+  parseSuccessorAndUseList(Block *&dest,
+                           SmallVectorImpl<Value *> &operands) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Type Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a type.
+  virtual ParseResult parseType(Type &result) = 0;
+
+  /// Parse an optional arrow followed by a type list.
+  virtual ParseResult
+  parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
+
+  /// Parse a colon followed by a type.
+  virtual ParseResult parseColonType(Type &result) = 0;
+
+  /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
+  template <typename TypeType> ParseResult parseColonType(TypeType &result) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of type.
+    Type type;
+    if (parseColonType(type))
+      return failure();
+
+    // Check for the right kind of attribute.
+    result = type.dyn_cast<TypeType>();
+    if (!result)
+      return emitError(loc, "invalid kind of type specified");
+
+    return success();
+  }
+
+  /// Parse a colon followed by a type list, which must have at least one type.
+  virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
+
+  /// Parse an optional colon followed by a type list, which if present must
+  /// have at least one type.
+  virtual ParseResult
+  parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
+
+  /// Parse a keyword followed by a type.
+  ParseResult parseKeywordType(const char *keyword, Type &result) {
+    return failure(parseKeyword(keyword) || parseType(result));
+  }
+
+  /// Add the specified type to the end of the specified type list and return
+  /// success.  This is a helper designed to allow parse methods to be simple
+  /// and chain through || operators.
+  ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
+    result.push_back(type);
+    return success();
+  }
+
+  /// Add the specified types to the end of the specified type list and return
+  /// success.  This is a helper designed to allow parse methods to be simple
+  /// and chain through || operators.
+  ParseResult addTypesToList(ArrayRef<Type> types,
+                             SmallVectorImpl<Type> &result) {
+    result.append(types.begin(), types.end());
+    return success();
+  }
+
+private:
+  /// Parse either an operand list or a region argument list depending on
+  /// whether isOperandList is true.
+  ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
+                                          bool isOperandList,
+                                          int requiredOperandCount,
+                                          Delimiter delimiter);
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Operation.h b/third_party/mlir/include/mlir/IR/Operation.h
new file mode 100644
index 0000000..db10a1a
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Operation.h
@@ -0,0 +1,728 @@
+//===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the Operation class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPERATION_H
+#define MLIR_IR_OPERATION_H
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Region.h"
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+class BlockAndValueMapping;
+class Location;
+class MLIRContext;
+class OperandIterator;
+class OperandTypeIterator;
+struct OperationState;
+class ResultIterator;
+class ResultTypeIterator;
+
+/// Terminator operations can have Block operands to represent successors.
+using BlockOperand = IROperandImpl<Block>;
+
+/// Operation is a basic unit of execution within a function. Operations can
+/// be nested within other operations effectively forming a tree. Child
+/// operations are organized into operation blocks represented by a 'Block'
+/// class.
+class Operation final
+    : public llvm::ilist_node_with_parent<Operation, Block>,
+      private llvm::TrailingObjects<Operation, OpResult, BlockOperand, unsigned,
+                                    Region, detail::OperandStorage> {
+public:
+  /// Create a new Operation with the specific fields.
+  static Operation *create(Location location, OperationName name,
+                           ArrayRef<Value *> operands,
+                           ArrayRef<Type> resultTypes,
+                           ArrayRef<NamedAttribute> attributes,
+                           ArrayRef<Block *> successors, unsigned numRegions,
+                           bool resizableOperandList, MLIRContext *context);
+
+  /// Overload of create that takes an existing NamedAttributeList to avoid
+  /// unnecessarily uniquing a list of attributes.
+  static Operation *create(Location location, OperationName name,
+                           ArrayRef<Value *> operands,
+                           ArrayRef<Type> resultTypes,
+                           const NamedAttributeList &attributes,
+                           ArrayRef<Block *> successors, unsigned numRegions,
+                           bool resizableOperandList, MLIRContext *context);
+
+  /// Create a new Operation from the fields stored in `state`.
+  static Operation *create(const OperationState &state);
+
+  /// The name of an operation is the key identifier for it.
+  OperationName getName() { return name; }
+
+  /// If this operation has a registered operation description, return it.
+  /// Otherwise return null.
+  const AbstractOperation *getAbstractOperation() {
+    return getName().getAbstractOperation();
+  }
+
+  /// Returns true if this operation has a registered operation description,
+  /// otherwise false.
+  bool isRegistered() { return getAbstractOperation(); }
+
+  /// Remove this operation from its parent block and delete it.
+  void erase();
+
+  /// Create a deep copy of this operation, remapping any operands that use
+  /// values outside of the operation using the map that is provided (leaving
+  /// them alone if no entry is present).  Replaces references to cloned
+  /// sub-operations to the corresponding operation that is copied, and adds
+  /// those mappings to the map.
+  Operation *clone(BlockAndValueMapping &mapper);
+  Operation *clone();
+
+  /// Create a deep copy of this operation but keep the operation regions empty.
+  /// Operands are remapped using `mapper` (if present), and `mapper` is updated
+  /// to contain the results.
+  Operation *cloneWithoutRegions(BlockAndValueMapping &mapper);
+  Operation *cloneWithoutRegions();
+
+  /// Returns the operation block that contains this operation.
+  Block *getBlock() { return block; }
+
+  /// Return the context this operation is associated with.
+  MLIRContext *getContext();
+
+  /// Return the dialact this operation is associated with, or nullptr if the
+  /// associated dialect is not registered.
+  Dialect *getDialect();
+
+  /// The source location the operation was defined or derived from.
+  Location getLoc() { return location; }
+
+  /// Set the source location the operation was defined or derived from.
+  void setLoc(Location loc) { location = loc; }
+
+  /// Returns the region to which the instruction belongs. Returns nullptr if
+  /// the instruction is unlinked.
+  Region *getParentRegion();
+
+  /// Returns the closest surrounding operation that contains this operation
+  /// or nullptr if this is a top-level operation.
+  Operation *getParentOp();
+
+  /// Return the closest surrounding parent operation that is of type 'OpTy'.
+  template <typename OpTy> OpTy getParentOfType() {
+    auto *op = this;
+    while ((op = op->getParentOp()))
+      if (auto parentOp = llvm::dyn_cast<OpTy>(op))
+        return parentOp;
+    return OpTy();
+  }
+
+  /// Replace any uses of 'from' with 'to' within this operation.
+  void replaceUsesOfWith(Value *from, Value *to);
+
+  /// Replace all uses of results of this operation with the provided 'values'.
+  template <typename ValuesT,
+            typename = decltype(std::declval<ValuesT>().begin())>
+  void replaceAllUsesWith(ValuesT &&values) {
+    assert(std::distance(values.begin(), values.end()) == getNumResults() &&
+           "expected 'values' to correspond 1-1 with the number of results");
+
+    auto valueIt = values.begin();
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      getResult(i)->replaceAllUsesWith(*(valueIt++));
+  }
+
+  /// Replace all uses of results of this operation with results of 'op'.
+  void replaceAllUsesWith(Operation *op) {
+    assert(getNumResults() == op->getNumResults());
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      getResult(i)->replaceAllUsesWith(op->getResult(i));
+  }
+
+  /// Destroys this operation and its subclass data.
+  void destroy();
+
+  /// This drops all operand uses from this operation, which is an essential
+  /// step in breaking cyclic dependences between references when they are to
+  /// be deleted.
+  void dropAllReferences();
+
+  /// Drop uses of all values defined by this operation or its nested regions.
+  void dropAllDefinedValueUses();
+
+  /// Unlink this operation from its current block and insert it right before
+  /// `existingInst` which may be in the same or another block in the same
+  /// function.
+  void moveBefore(Operation *existingInst);
+
+  /// Unlink this operation from its current block and insert it right before
+  /// `iterator` in the specified block.
+  void moveBefore(Block *block, llvm::iplist<Operation>::iterator iterator);
+
+  /// Given an operation 'other' that is within the same parent block, return
+  /// whether the current operation is before 'other' in the operation list
+  /// of the parent block.
+  /// Note: This function has an average complexity of O(1), but worst case may
+  /// take O(N) where N is the number of operations within the parent block.
+  bool isBeforeInBlock(Operation *other);
+
+  void print(raw_ostream &os);
+  void dump();
+
+  //===--------------------------------------------------------------------===//
+  // Operands
+  //===--------------------------------------------------------------------===//
+
+  /// Returns if the operation has a resizable operation list, i.e. operands can
+  /// be added.
+  bool hasResizableOperandsList() { return getOperandStorage().isResizable(); }
+
+  /// Replace the current operands of this operation with the ones provided in
+  /// 'operands'. If the operands list is not resizable, the size of 'operands'
+  /// must be less than or equal to the current number of operands.
+  void setOperands(ArrayRef<Value *> operands) {
+    getOperandStorage().setOperands(this, operands);
+  }
+
+  unsigned getNumOperands() { return getOperandStorage().size(); }
+
+  Value *getOperand(unsigned idx) { return getOpOperand(idx).get(); }
+  void setOperand(unsigned idx, Value *value) {
+    return getOpOperand(idx).set(value);
+  }
+
+  // Support operand iteration.
+  using operand_iterator = OperandIterator;
+  using operand_range = llvm::iterator_range<operand_iterator>;
+
+  operand_iterator operand_begin();
+  operand_iterator operand_end();
+
+  /// Returns an iterator on the underlying Value's (Value *).
+  operand_range getOperands();
+
+  /// Erase the operand at position `idx`.
+  void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
+
+  MutableArrayRef<OpOperand> getOpOperands() {
+    return getOperandStorage().getOperands();
+  }
+
+  OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
+
+  // Support operand type iteration.
+  using operand_type_iterator = OperandTypeIterator;
+  using operand_type_range = llvm::iterator_range<operand_type_iterator>;
+  operand_type_iterator operand_type_begin();
+  operand_type_iterator operand_type_end();
+  operand_type_range getOperandTypes();
+
+  //===--------------------------------------------------------------------===//
+  // Results
+  //===--------------------------------------------------------------------===//
+
+  /// Return true if there are no users of any results of this operation.
+  bool use_empty();
+
+  unsigned getNumResults() { return numResults; }
+
+  Value *getResult(unsigned idx) { return &getOpResult(idx); }
+
+  // Support result iteration.
+  using result_iterator = ResultIterator;
+  using result_range = llvm::iterator_range<result_iterator>;
+
+  result_iterator result_begin();
+  result_iterator result_end();
+
+  result_range getResults();
+
+  MutableArrayRef<OpResult> getOpResults() {
+    return {getTrailingObjects<OpResult>(), numResults};
+  }
+
+  OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; }
+
+  // Support result type iteration.
+  using result_type_iterator = ResultTypeIterator;
+  using result_type_range = llvm::iterator_range<result_type_iterator>;
+  result_type_iterator result_type_begin();
+  result_type_iterator result_type_end();
+  result_type_range getResultTypes();
+
+  //===--------------------------------------------------------------------===//
+  // Attributes
+  //===--------------------------------------------------------------------===//
+
+  // Operations may optionally carry a list of attributes that associate
+  // constants to names.  Attributes may be dynamically added and removed over
+  // the lifetime of an operation.
+
+  /// Return all of the attributes on this operation.
+  ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
+
+  /// Return the internal attribute list on this operation.
+  NamedAttributeList &getAttrList() { return attrs; }
+
+  /// Set the attribute list on this operation.
+  /// Using a NamedAttributeList is more efficient as it does not require new
+  /// uniquing in the MLIRContext.
+  void setAttrs(NamedAttributeList newAttrs) { attrs = newAttrs; }
+
+  /// Return the specified attribute if present, null otherwise.
+  Attribute getAttr(Identifier name) { return attrs.get(name); }
+  Attribute getAttr(StringRef name) { return attrs.get(name); }
+
+  template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
+    return getAttr(name).dyn_cast_or_null<AttrClass>();
+  }
+
+  template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
+    return getAttr(name).dyn_cast_or_null<AttrClass>();
+  }
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value.  Otherwise, add a new attribute with the specified name/value.
+  void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
+  void setAttr(StringRef name, Attribute value) {
+    setAttr(Identifier::get(name, getContext()), value);
+  }
+
+  /// Remove the attribute with the specified name if it exists.  The return
+  /// value indicates whether the attribute was present or not.
+  NamedAttributeList::RemoveResult removeAttr(Identifier name) {
+    return attrs.remove(name);
+  }
+
+  /// A utility iterator that filters out non-dialect attributes.
+  class dialect_attr_iterator
+      : public llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator,
+                                     bool (*)(NamedAttribute)> {
+    static bool filter(NamedAttribute attr) {
+      // Dialect attributes are prefixed by the dialect name, like operations.
+      return attr.first.strref().count('.');
+    }
+
+    explicit dialect_attr_iterator(ArrayRef<NamedAttribute>::iterator it,
+                                   ArrayRef<NamedAttribute>::iterator end)
+        : llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator,
+                                bool (*)(NamedAttribute)>(it, end, &filter) {}
+
+    // Allow access to the constructor.
+    friend Operation;
+  };
+  using dialect_attr_range = llvm::iterator_range<dialect_attr_iterator>;
+
+  /// Return a range corresponding to the dialect attributes for this operation.
+  dialect_attr_range getDialectAttrs() {
+    auto attrs = getAttrs();
+    return {dialect_attr_iterator(attrs.begin(), attrs.end()),
+            dialect_attr_iterator(attrs.end(), attrs.end())};
+  }
+  dialect_attr_iterator dialect_attr_begin() {
+    auto attrs = getAttrs();
+    return dialect_attr_iterator(attrs.begin(), attrs.end());
+  }
+  dialect_attr_iterator dialect_attr_end() {
+    auto attrs = getAttrs();
+    return dialect_attr_iterator(attrs.end(), attrs.end());
+  }
+
+  /// Set the dialect attributes for this operation, and preserve all dependent.
+  template <typename DialectAttrT>
+  void setDialectAttrs(DialectAttrT &&dialectAttrs) {
+    SmallVector<NamedAttribute, 16> attrs;
+    attrs.assign(std::begin(dialectAttrs), std::end(dialectAttrs));
+    for (auto attr : getAttrs())
+      if (!attr.first.strref().count('.'))
+        attrs.push_back(attr);
+    setAttrs(llvm::makeArrayRef(attrs));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Blocks
+  //===--------------------------------------------------------------------===//
+
+  /// Returns the number of regions held by this operation.
+  unsigned getNumRegions() { return numRegions; }
+
+  /// Returns the regions held by this operation.
+  MutableArrayRef<Region> getRegions() {
+    auto *regions = getTrailingObjects<Region>();
+    return {regions, numRegions};
+  }
+
+  /// Returns the region held by this operation at position 'index'.
+  Region &getRegion(unsigned index) {
+    assert(index < numRegions && "invalid region index");
+    return getRegions()[index];
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Terminators
+  //===--------------------------------------------------------------------===//
+
+  MutableArrayRef<BlockOperand> getBlockOperands() {
+    return {getTrailingObjects<BlockOperand>(), numSuccs};
+  }
+
+  /// Return the operands of this operation that are *not* successor arguments.
+  operand_range getNonSuccessorOperands();
+
+  operand_range getSuccessorOperands(unsigned index);
+
+  Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
+    assert(!isKnownNonTerminator() && "only terminators may have successors");
+    assert(opIndex < getNumSuccessorOperands(succIndex));
+    return getOperand(getSuccessorOperandIndex(succIndex) + opIndex);
+  }
+
+  bool hasSuccessors() { return numSuccs != 0; }
+  unsigned getNumSuccessors() { return numSuccs; }
+  unsigned getNumSuccessorOperands(unsigned index) {
+    assert(!isKnownNonTerminator() && "only terminators may have successors");
+    assert(index < getNumSuccessors());
+    return getTrailingObjects<unsigned>()[index];
+  }
+
+  Block *getSuccessor(unsigned index) {
+    assert(index < getNumSuccessors());
+    return getBlockOperands()[index].get();
+  }
+  void setSuccessor(Block *block, unsigned index);
+
+  /// Erase a specific operand from the operand list of the successor at
+  /// 'index'.
+  void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
+    assert(succIndex < getNumSuccessors());
+    assert(opIndex < getNumSuccessorOperands(succIndex));
+    getOperandStorage().eraseOperand(getSuccessorOperandIndex(succIndex) +
+                                     opIndex);
+    --getTrailingObjects<unsigned>()[succIndex];
+  }
+
+  /// Get the index of the first operand of the successor at the provided
+  /// index.
+  unsigned getSuccessorOperandIndex(unsigned index);
+
+  //===--------------------------------------------------------------------===//
+  // Accessors for various properties of operations
+  //===--------------------------------------------------------------------===//
+
+  /// Returns whether the operation is commutative.
+  bool isCommutative() {
+    if (auto *absOp = getAbstractOperation())
+      return absOp->hasProperty(OperationProperty::Commutative);
+    return false;
+  }
+
+  /// Returns whether the operation has side-effects.
+  bool hasNoSideEffect() {
+    if (auto *absOp = getAbstractOperation())
+      return absOp->hasProperty(OperationProperty::NoSideEffect);
+    return false;
+  }
+
+  /// Represents the status of whether an operation is a terminator. We
+  /// represent an 'unknown' status because we want to support unregistered
+  /// terminators.
+  enum class TerminatorStatus { Terminator, NonTerminator, Unknown };
+
+  /// Returns the status of whether this operation is a terminator or not.
+  TerminatorStatus getTerminatorStatus() {
+    if (auto *absOp = getAbstractOperation()) {
+      return absOp->hasProperty(OperationProperty::Terminator)
+                 ? TerminatorStatus::Terminator
+                 : TerminatorStatus::NonTerminator;
+    }
+    return TerminatorStatus::Unknown;
+  }
+
+  /// Returns if the operation is known to be a terminator.
+  bool isKnownTerminator() {
+    return getTerminatorStatus() == TerminatorStatus::Terminator;
+  }
+
+  /// Returns if the operation is known to *not* be a terminator.
+  bool isKnownNonTerminator() {
+    return getTerminatorStatus() == TerminatorStatus::NonTerminator;
+  }
+
+  /// Returns if the operation is known to be completely isolated from enclosing
+  /// regions, i.e. no internal regions reference values defined above this
+  /// operation.
+  bool isKnownIsolatedFromAbove() {
+    if (auto *absOp = getAbstractOperation())
+      return absOp->hasProperty(OperationProperty::IsolatedFromAbove);
+    return false;
+  }
+
+  /// Attempt to fold this operation with the specified constant operand values
+  /// - the elements in "operands" will correspond directly to the operands of
+  /// the operation, but may be null if non-constant. If folding is successful,
+  /// this fills in the `results` vector. If not, `results` is unspecified.
+  LogicalResult fold(ArrayRef<Attribute> operands,
+                     SmallVectorImpl<OpFoldResult> &results);
+
+  /// Returns if the operation was registered with a particular trait, e.g.
+  /// hasTrait<OperandsAreIntegerLike>().
+  template <template <typename T> class Trait> bool hasTrait() {
+    auto *absOp = getAbstractOperation();
+    return absOp ? absOp->hasTrait<Trait>() : false;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Operation Walkers
+  //===--------------------------------------------------------------------===//
+
+  /// Walk this operation in postorder, calling the callback for each operation
+  /// including this one.
+  void walk(llvm::function_ref<void(Operation *)> callback);
+
+  /// Specialization of walk to only visit operations of 'T'.
+  template <typename T> void walk(llvm::function_ref<void(T)> callback) {
+    walk([&](Operation *op) {
+      if (auto derivedOp = dyn_cast<T>(op))
+        callback(derivedOp);
+    });
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  /// Emit an error with the op name prefixed, like "'dim' op " which is
+  /// convenient for verifiers.
+  InFlightDiagnostic emitOpError(const Twine &message = {});
+
+  /// Emit an error about fatal conditions with this operation, reporting up to
+  /// any diagnostic handlers that may be listening.
+  InFlightDiagnostic emitError(const Twine &message = {});
+
+  /// Emit a warning about this operation, reporting up to any diagnostic
+  /// handlers that may be listening.
+  InFlightDiagnostic emitWarning(const Twine &message = {});
+
+  /// Emit a remark about this operation, reporting up to any diagnostic
+  /// handlers that may be listening.
+  InFlightDiagnostic emitRemark(const Twine &message = {});
+
+private:
+  Operation(Location location, OperationName name, unsigned numResults,
+            unsigned numSuccessors, unsigned numRegions,
+            const NamedAttributeList &attributes, MLIRContext *context);
+
+  // Operations are deleted through the destroy() member because they are
+  // allocated with malloc.
+  ~Operation();
+
+  /// Returns the operand storage object.
+  detail::OperandStorage &getOperandStorage() {
+    return *getTrailingObjects<detail::OperandStorage>();
+  }
+
+  /// Provide a 'getParent' method for ilist_node_with_parent methods.
+  /// We mark it as const function because ilist_node_with_parent specifically
+  /// requires a 'getParent() const' method. Once ilist_node removes this
+  /// constraint, we should drop the const to fit the rest of the MLIR const
+  /// model.
+  Block *getParent() const { return block; }
+
+  /// The operation block that containts this operation.
+  Block *block = nullptr;
+
+  /// This holds information about the source location the operation was defined
+  /// or derived from.
+  Location location;
+
+  /// Relative order of this operation in its parent block. Used for
+  /// O(1) local dominance checks between operations.
+  mutable unsigned orderIndex = 0;
+
+  const unsigned numResults, numSuccs, numRegions;
+
+  /// This holds the name of the operation.
+  OperationName name;
+
+  /// This holds general named attributes for the operation.
+  NamedAttributeList attrs;
+
+  // allow ilist_traits access to 'block' field.
+  friend struct llvm::ilist_traits<Operation>;
+
+  // allow block to access the 'orderIndex' field.
+  friend class Block;
+
+  // allow ilist_node_with_parent to access the 'getParent' method.
+  friend class llvm::ilist_node_with_parent<Operation, Block>;
+
+  // This stuff is used by the TrailingObjects template.
+  friend llvm::TrailingObjects<Operation, OpResult, BlockOperand, unsigned,
+                               Region, detail::OperandStorage>;
+  size_t numTrailingObjects(OverloadToken<OpResult>) const {
+    return numResults;
+  }
+  size_t numTrailingObjects(OverloadToken<BlockOperand>) const {
+    return numSuccs;
+  }
+  size_t numTrailingObjects(OverloadToken<Region>) const { return numRegions; }
+  size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, Operation &op) {
+  op.print(os);
+  return os;
+}
+
+/// This class implements the const/non-const operand iterators for the
+/// Operation class in terms of getOperand(idx).
+class OperandIterator final
+    : public indexed_accessor_iterator<OperandIterator, Operation *, Value *,
+                                       Value *, Value *> {
+public:
+  /// Initializes the operand iterator to the specified operand index.
+  OperandIterator(Operation *object, unsigned index)
+      : indexed_accessor_iterator<OperandIterator, Operation *, Value *,
+                                  Value *, Value *>(object, index) {}
+
+  Value *operator*() const { return this->object->getOperand(this->index); }
+};
+
+/// This class implements the operand type iterators for the Operation
+/// class in terms of operand_iterator->getType().
+class OperandTypeIterator final
+    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+  static Type unwrap(Value *value) { return value->getType(); }
+
+public:
+  using reference = Type;
+
+  /// Initializes the operand type iterator to the specified operand iterator.
+  OperandTypeIterator(OperandIterator it)
+      : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
+  }
+};
+
+// Implement the inline operand iterator methods.
+inline auto Operation::operand_begin() -> operand_iterator {
+  return operand_iterator(this, 0);
+}
+
+inline auto Operation::operand_end() -> operand_iterator {
+  return operand_iterator(this, getNumOperands());
+}
+
+inline auto Operation::getOperands() -> operand_range {
+  return {operand_begin(), operand_end()};
+}
+
+inline auto Operation::operand_type_begin() -> operand_type_iterator {
+  return operand_type_iterator(operand_begin());
+}
+
+inline auto Operation::operand_type_end() -> operand_type_iterator {
+  return operand_type_iterator(operand_end());
+}
+
+inline auto Operation::getOperandTypes() -> operand_type_range {
+  return {operand_type_begin(), operand_type_end()};
+}
+
+/// This class implements the result iterators for the Operation class
+/// in terms of getResult(idx).
+class ResultIterator final
+    : public indexed_accessor_iterator<ResultIterator, Operation *, Value *,
+                                       Value *, Value *> {
+public:
+  /// Initializes the result iterator to the specified index.
+  ResultIterator(Operation *object, unsigned index)
+      : indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
+                                  Value *>(object, index) {}
+
+  Value *operator*() const { return this->object->getResult(this->index); }
+};
+
+/// This class implements the result type iterators for the Operation
+/// class in terms of result_iterator->getType().
+class ResultTypeIterator final
+    : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
+  static Type unwrap(Value *value) { return value->getType(); }
+
+public:
+  using reference = Type;
+
+  /// Initializes the result type iterator to the specified result iterator.
+  ResultTypeIterator(ResultIterator it)
+      : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
+};
+
+// Implement the inline result iterator methods.
+inline auto Operation::result_begin() -> result_iterator {
+  return result_iterator(this, 0);
+}
+
+inline auto Operation::result_end() -> result_iterator {
+  return result_iterator(this, getNumResults());
+}
+
+inline auto Operation::getResults() -> llvm::iterator_range<result_iterator> {
+  return {result_begin(), result_end()};
+}
+
+inline auto Operation::result_type_begin() -> result_type_iterator {
+  return result_type_iterator(result_begin());
+}
+
+inline auto Operation::result_type_end() -> result_type_iterator {
+  return result_type_iterator(result_end());
+}
+
+inline auto Operation::getResultTypes() -> result_type_range {
+  return {result_type_begin(), result_type_end()};
+}
+
+} // end namespace mlir
+
+namespace llvm {
+/// Provide isa functionality for operation casts.
+template <typename T> struct isa_impl<T, ::mlir::Operation> {
+  static inline bool doit(const ::mlir::Operation &op) {
+    return T::classof(const_cast<::mlir::Operation *>(&op));
+  }
+};
+
+/// Provide specializations for operation casts as the resulting T is value
+/// typed.
+template <typename T> struct cast_retty_impl<T, ::mlir::Operation *> {
+  using ret_type = T;
+};
+template <typename T> struct cast_retty_impl<T, ::mlir::Operation> {
+  using ret_type = T;
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Operation, ::mlir::Operation> {
+  static T doit(::mlir::Operation &val) { return T(&val); }
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Operation *, ::mlir::Operation *> {
+  static T doit(::mlir::Operation *val) { return T(val); }
+};
+} // end namespace llvm
+
+#endif // MLIR_IR_OPERATION_H
diff --git a/third_party/mlir/include/mlir/IR/OperationSupport.h b/third_party/mlir/include/mlir/IR/OperationSupport.h
new file mode 100644
index 0000000..204da29
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/OperationSupport.h
@@ -0,0 +1,481 @@
+//===- OperationSupport.h ---------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a number of support types that Operation and related
+// classes build on top of.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPERATION_SUPPORT_H
+#define MLIR_IR_OPERATION_SUPPORT_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/Support/TrailingObjects.h"
+#include <memory>
+
+namespace mlir {
+class Block;
+class Dialect;
+class Operation;
+struct OperationState;
+class OpAsmParser;
+class OpAsmParserResult;
+class OpAsmPrinter;
+class OpFoldResult;
+class ParseResult;
+class Pattern;
+class Region;
+class RewritePattern;
+class Type;
+class Value;
+
+/// This is an adaptor from a list of values to named operands of OpTy.  In a
+/// generic operation context, e.g., in dialect conversions, an ordered array of
+/// `Value`s is treated as operands of `OpTy`.  This adaptor takes a reference
+/// to the array and provides accessors with the same names as `OpTy` for
+/// operands.  This makes possible to create function templates that operate on
+/// either OpTy or OperandAdaptor<OpTy> seamlessly.
+template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
+
+class OwningRewritePatternList;
+
+enum class OperationProperty {
+  /// This bit is set for an operation if it is a commutative operation: that
+  /// is a binary operator (two inputs) where "a op b" and "b op a" produce the
+  /// same results.
+  Commutative = 0x1,
+
+  /// This bit is set for operations that have no side effects: that means that
+  /// they do not read or write memory, or access any hidden state.
+  NoSideEffect = 0x2,
+
+  /// This bit is set for an operation if it is a terminator: that means
+  /// an operation at the end of a block.
+  Terminator = 0x4,
+
+  /// This bit is set for operations that are completely isolated from above.
+  /// This is used for operations whose regions are explicit capture only, i.e.
+  /// they are never allowed to implicitly reference values defined above the
+  /// parent operation.
+  IsolatedFromAbove = 0x8,
+};
+
+/// This is a "type erased" representation of a registered operation.  This
+/// should only be used by things like the AsmPrinter and other things that need
+/// to be parameterized by generic operation hooks.  Most user code should use
+/// the concrete operation types.
+class AbstractOperation {
+public:
+  using OperationProperties = uint32_t;
+
+  /// This is the name of the operation.
+  const StringRef name;
+
+  /// This is the dialect that this operation belongs to.
+  Dialect &dialect;
+
+  /// Return true if this "op class" can match against the specified operation.
+  bool (&classof)(Operation *op);
+
+  /// Use the specified object to parse this ops custom assembly format.
+  ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result);
+
+  /// This hook implements the AsmPrinter for this operation.
+  void (&printAssembly)(Operation *op, OpAsmPrinter *p);
+
+  /// This hook implements the verifier for this operation.  It should emits an
+  /// error message and returns failure if a problem is detected, or returns
+  /// success if everything is ok.
+  LogicalResult (&verifyInvariants)(Operation *op);
+
+  /// This hook implements a generalized folder for this operation.  Operations
+  /// can implement this to provide simplifications rules that are applied by
+  /// the Builder::createOrFold API and the canonicalization pass.
+  ///
+  /// This is an intentionally limited interface - implementations of this hook
+  /// can only perform the following changes to the operation:
+  ///
+  ///  1. They can leave the operation alone and without changing the IR, and
+  ///     return failure.
+  ///  2. They can mutate the operation in place, without changing anything else
+  ///     in the IR.  In this case, return success.
+  ///  3. They can return a list of existing values that can be used instead of
+  ///     the operation.  In this case, fill in the results list and return
+  ///     success.  The caller will remove the operation and use those results
+  ///     instead.
+  ///
+  /// This allows expression of some simple in-place canonicalizations (e.g.
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
+  LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results);
+
+  /// This hook returns any canonicalization pattern rewrites that the operation
+  /// supports, for use by the canonicalization pass.
+  void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
+                                      MLIRContext *context);
+
+  /// Returns whether the operation has a particular property.
+  bool hasProperty(OperationProperty property) const {
+    return opProperties & static_cast<OperationProperties>(property);
+  }
+
+  /// Returns if the operation has a particular trait.
+  template <template <typename T> class Trait> bool hasTrait() const {
+    return hasRawTrait(ClassID::getID<Trait>());
+  }
+
+  /// Look up the specified operation in the specified MLIRContext and return a
+  /// pointer to it if present.  Otherwise, return a null pointer.
+  static const AbstractOperation *lookup(StringRef opName,
+                                         MLIRContext *context);
+
+  /// This constructor is used by Dialect objects when they register the list of
+  /// operations they contain.
+  template <typename T> static AbstractOperation get(Dialect &dialect) {
+    return AbstractOperation(
+        T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
+        T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
+        T::getCanonicalizationPatterns, T::hasTrait);
+  }
+
+private:
+  AbstractOperation(
+      StringRef name, Dialect &dialect, OperationProperties opProperties,
+      bool (&classof)(Operation *op),
+      ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
+      void (&printAssembly)(Operation *op, OpAsmPrinter *p),
+      LogicalResult (&verifyInvariants)(Operation *op),
+      LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results),
+      void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
+                                          MLIRContext *context),
+      bool (&hasTrait)(ClassID *traitID))
+      : name(name), dialect(dialect), classof(classof),
+        parseAssembly(parseAssembly), printAssembly(printAssembly),
+        verifyInvariants(verifyInvariants), foldHook(foldHook),
+        getCanonicalizationPatterns(getCanonicalizationPatterns),
+        opProperties(opProperties), hasRawTrait(hasTrait) {}
+
+  /// The properties of the operation.
+  const OperationProperties opProperties;
+
+  /// This hook returns if the operation contains the trait corresponding
+  /// to the given ClassID.
+  bool (&hasRawTrait)(ClassID *traitID);
+};
+
+class OperationName {
+public:
+  using RepresentationUnion =
+      llvm::PointerUnion<Identifier, const AbstractOperation *>;
+
+  OperationName(AbstractOperation *op) : representation(op) {}
+  OperationName(StringRef name, MLIRContext *context);
+
+  /// Return the name of the dialect this operation is registered to.
+  StringRef getDialect() const;
+
+  /// Return the name of this operation.  This always succeeds.
+  StringRef getStringRef() const;
+
+  /// If this operation has a registered operation description, return it.
+  /// Otherwise return null.
+  const AbstractOperation *getAbstractOperation() const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  void *getAsOpaquePointer() const {
+    return static_cast<void *>(representation.getOpaqueValue());
+  }
+  static OperationName getFromOpaquePointer(void *pointer);
+
+private:
+  RepresentationUnion representation;
+  OperationName(RepresentationUnion representation)
+      : representation(representation) {}
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, OperationName identifier) {
+  identifier.print(os);
+  return os;
+}
+
+inline bool operator==(OperationName lhs, OperationName rhs) {
+  return lhs.getAsOpaquePointer() == rhs.getAsOpaquePointer();
+}
+
+inline bool operator!=(OperationName lhs, OperationName rhs) {
+  return lhs.getAsOpaquePointer() != rhs.getAsOpaquePointer();
+}
+
+// Make operation names hashable.
+inline llvm::hash_code hash_value(OperationName arg) {
+  return llvm::hash_value(arg.getAsOpaquePointer());
+}
+
+/// This represents an operation in an abstracted form, suitable for use with
+/// the builder APIs.  This object is a large and heavy weight object meant to
+/// be used as a temporary object on the stack.  It is generally unwise to put
+/// this in a collection.
+struct OperationState {
+  MLIRContext *const context;
+  Location location;
+  OperationName name;
+  SmallVector<Value *, 4> operands;
+  /// Types of the results of this operation.
+  SmallVector<Type, 4> types;
+  SmallVector<NamedAttribute, 4> attributes;
+  /// Successors of this operation and their respective operands.
+  SmallVector<Block *, 1> successors;
+  /// Regions that the op will hold.
+  SmallVector<std::unique_ptr<Region>, 1> regions;
+  /// If the operation has a resizable operand list.
+  bool resizableOperandList = false;
+
+public:
+  OperationState(Location location, StringRef name);
+
+  OperationState(Location location, OperationName name);
+
+  OperationState(Location location, StringRef name, ArrayRef<Value *> operands,
+                 ArrayRef<Type> types, ArrayRef<NamedAttribute> attributes,
+                 ArrayRef<Block *> successors = {},
+                 MutableArrayRef<std::unique_ptr<Region>> regions = {},
+                 bool resizableOperandList = false);
+
+  void addOperands(ArrayRef<Value *> newOperands) {
+    assert(successors.empty() &&
+           "Non successor operands should be added first.");
+    operands.append(newOperands.begin(), newOperands.end());
+  }
+
+  void addTypes(ArrayRef<Type> newTypes) {
+    types.append(newTypes.begin(), newTypes.end());
+  }
+
+  /// Add an attribute with the specified name.
+  void addAttribute(StringRef name, Attribute attr) {
+    addAttribute(Identifier::get(name, getContext()), attr);
+  }
+
+  /// Add an attribute with the specified name.
+  void addAttribute(Identifier name, Attribute attr) {
+    attributes.push_back({name, attr});
+  }
+
+  /// Add an array of named attributes.
+  void addAttributes(ArrayRef<NamedAttribute> newAttributes) {
+    attributes.append(newAttributes.begin(), newAttributes.end());
+  }
+
+  void addSuccessor(Block *successor, ArrayRef<Value *> succOperands) {
+    successors.push_back(successor);
+    // Insert a sentinal operand to mark a barrier between successor operands.
+    operands.push_back(nullptr);
+    operands.append(succOperands.begin(), succOperands.end());
+  }
+
+  /// Create a region that should be attached to the operation.  These regions
+  /// can be filled in immediately without waiting for Operation to be
+  /// created.  When it is, the region bodies will be transferred.
+  Region *addRegion();
+
+  /// Take a region that should be attached to the Operation.  The body of the
+  /// region will be transferred when the Operation is constructed.  If the
+  /// region is null, a new empty region will be attached to the Operation.
+  void addRegion(std::unique_ptr<Region> &&region);
+
+  /// Sets the operand list of the operation as resizable.
+  void setOperandListToResizable(bool isResizable = true) {
+    resizableOperandList = isResizable;
+  }
+
+  /// Get the context held by this operation state.
+  MLIRContext *getContext() { return location->getContext(); }
+};
+
+namespace detail {
+/// A utility class holding the information necessary to dynamically resize
+/// operands.
+struct ResizableStorage {
+  ResizableStorage(OpOperand *opBegin, unsigned numOperands)
+      : firstOpAndIsDynamic(opBegin, false), capacity(numOperands) {}
+
+  ~ResizableStorage() { cleanupStorage(); }
+
+  /// Cleanup any allocated storage.
+  void cleanupStorage() {
+    // If the storage is dynamic, then we need to free the storage.
+    if (isStorageDynamic())
+      free(firstOpAndIsDynamic.getPointer());
+  }
+
+  /// Sets the storage pointer to a new dynamically allocated block.
+  void setDynamicStorage(OpOperand *opBegin) {
+    /// Cleanup the old storage if necessary.
+    cleanupStorage();
+    firstOpAndIsDynamic.setPointerAndInt(opBegin, true);
+  }
+
+  /// Returns the current storage pointer.
+  OpOperand *getPointer() { return firstOpAndIsDynamic.getPointer(); }
+
+  /// Returns if the current storage of operands is in the trailing objects is
+  /// in a dynamically allocated memory block.
+  bool isStorageDynamic() const { return firstOpAndIsDynamic.getInt(); }
+
+  /// A pointer to the first operand element. This is either to the trailing
+  /// objects storage, or a dynamically allocated block of memory.
+  llvm::PointerIntPair<OpOperand *, 1, bool> firstOpAndIsDynamic;
+
+  // The maximum number of operands that can be currently held by the storage.
+  unsigned capacity;
+};
+
+/// This class handles the management of operation operands. Operands are
+/// stored similarly to the elements of a SmallVector except for two key
+/// differences. The first is the inline storage, which is a trailing objects
+/// array. The second is that being able to dynamically resize the operand list
+/// is optional.
+class OperandStorage final
+    : private llvm::TrailingObjects<OperandStorage, ResizableStorage,
+                                    OpOperand> {
+public:
+  OperandStorage(unsigned numOperands, bool resizable)
+      : numOperands(numOperands), resizable(resizable) {
+    // Initialize the resizable storage.
+    if (resizable) {
+      new (&getResizableStorage())
+          ResizableStorage(getTrailingObjects<OpOperand>(), numOperands);
+    }
+  }
+
+  ~OperandStorage() {
+    // Manually destruct the operands.
+    for (auto &operand : getOperands())
+      operand.~OpOperand();
+
+    // If the storage is resizable then destruct the utility.
+    if (resizable)
+      getResizableStorage().~ResizableStorage();
+  }
+
+  /// Replace the operands contained in the storage with the ones provided in
+  /// 'operands'.
+  void setOperands(Operation *owner, ArrayRef<Value *> operands);
+
+  /// Erase an operand held by the storage.
+  void eraseOperand(unsigned index);
+
+  /// Get the operation operands held by the storage.
+  MutableArrayRef<OpOperand> getOperands() {
+    return {getRawOperands(), size()};
+  }
+
+  /// Return the number of operands held in the storage.
+  unsigned size() const { return numOperands; }
+
+  /// Returns the additional size necessary for allocating this object.
+  static size_t additionalAllocSize(unsigned numOperands, bool resizable) {
+    return additionalSizeToAlloc<ResizableStorage, OpOperand>(resizable ? 1 : 0,
+                                                              numOperands);
+  }
+
+  /// Returns if this storage is resizable.
+  bool isResizable() const { return resizable; }
+
+private:
+  /// Clear the storage and destroy the current operands held by the storage.
+  void clear() { numOperands = 0; }
+
+  /// Returns the current pointer for the raw operands array.
+  OpOperand *getRawOperands() {
+    return resizable ? getResizableStorage().getPointer()
+                     : getTrailingObjects<OpOperand>();
+  }
+
+  /// Returns the resizable operand utility class.
+  ResizableStorage &getResizableStorage() {
+    assert(resizable);
+    return *getTrailingObjects<ResizableStorage>();
+  }
+
+  /// Grow the internal resizable operand storage.
+  void grow(ResizableStorage &resizeUtil, size_t minSize);
+
+  /// The current number of operands, and the current max operand capacity.
+  unsigned numOperands : 31;
+
+  /// Whether this storage is resizable or not.
+  bool resizable : 1;
+
+  // This stuff is used by the TrailingObjects template.
+  friend llvm::TrailingObjects<OperandStorage, ResizableStorage, OpOperand>;
+  size_t numTrailingObjects(OverloadToken<ResizableStorage>) const {
+    return resizable ? 1 : 0;
+  }
+};
+} // end namespace detail
+} // end namespace mlir
+
+namespace llvm {
+// Identifiers hash just like pointers, there is no need to hash the bytes.
+template <> struct DenseMapInfo<mlir::OperationName> {
+  static mlir::OperationName getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::OperationName::getFromOpaquePointer(pointer);
+  }
+  static mlir::OperationName getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::OperationName::getFromOpaquePointer(pointer);
+  }
+  static unsigned getHashValue(mlir::OperationName Val) {
+    return DenseMapInfo<void *>::getHashValue(Val.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::OperationName LHS, mlir::OperationName RHS) {
+    return LHS == RHS;
+  }
+};
+
+/// The pointer inside of an identifier comes from a StringMap, so its alignment
+/// is always at least 4 and probably 8 (on 64-bit machines).  Allow LLVM to
+/// steal the low bits.
+template <> struct PointerLikeTypeTraits<mlir::OperationName> {
+public:
+  static inline void *getAsVoidPointer(mlir::OperationName I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::OperationName getFromVoidPointer(void *P) {
+    return mlir::OperationName::getFromOpaquePointer(P);
+  }
+  enum {
+    NumLowBitsAvailable = PointerLikeTypeTraits<
+        mlir::OperationName::RepresentationUnion>::NumLowBitsAvailable
+  };
+};
+
+} // end namespace llvm
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/PatternMatch.h b/third_party/mlir/include/mlir/IR/PatternMatch.h
new file mode 100644
index 0000000..4f1e50b
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/PatternMatch.h
@@ -0,0 +1,464 @@
+//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PATTERNMATCHER_H
+#define MLIR_PATTERNMATCHER_H
+
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+
+class PatternRewriter;
+
+//===----------------------------------------------------------------------===//
+// PatternBenefit class
+//===----------------------------------------------------------------------===//
+
+/// This class represents the benefit of a pattern match in a unitless scheme
+/// that ranges from 0 (very little benefit) to 65K.  The most common unit to
+/// use here is the "number of operations matched" by the pattern.
+///
+/// This also has a sentinel representation that can be used for patterns that
+/// fail to match.
+///
+class PatternBenefit {
+  enum { ImpossibleToMatchSentinel = 65535 };
+
+public:
+  /*implicit*/ PatternBenefit(unsigned benefit);
+  PatternBenefit(const PatternBenefit &) = default;
+  PatternBenefit &operator=(const PatternBenefit &) = default;
+
+  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
+  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
+
+  /// If the corresponding pattern can match, return its benefit.  If the
+  // corresponding pattern isImpossibleToMatch() then this aborts.
+  unsigned short getBenefit() const;
+
+  bool operator==(const PatternBenefit &rhs) const {
+    return representation == rhs.representation;
+  }
+  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
+  bool operator<(const PatternBenefit &rhs) const {
+    return representation < rhs.representation;
+  }
+
+private:
+  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
+  unsigned short representation;
+};
+
+/// Pattern state is used by patterns that want to maintain state between their
+/// match and rewrite phases.  Patterns can define a pattern-specific subclass
+/// of this.
+class PatternState {
+public:
+  virtual ~PatternState() {}
+
+protected:
+  // Must be subclassed.
+  PatternState() {}
+};
+
+/// This is the type returned by a pattern match.  A match failure returns a
+/// None value.  A match success returns a Some value with any state the pattern
+/// may need to maintain (but may also be null).
+using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
+
+//===----------------------------------------------------------------------===//
+// Pattern class
+//===----------------------------------------------------------------------===//
+
+/// Instances of Pattern can be matched against SSA IR.  These matches get used
+/// in ways dependent on their subclasses and the driver doing the matching.
+/// For example, RewritePatterns implement a rewrite from one matched pattern
+/// to a replacement DAG tile.
+class Pattern {
+public:
+  /// Return the benefit (the inverse of "cost") of matching this pattern.  The
+  /// benefit of a Pattern is always static - rewrites that may have dynamic
+  /// benefit can be instantiated multiple times (different Pattern instances)
+  /// for each benefit that they may return, and be guarded by different match
+  /// condition predicates.
+  PatternBenefit getBenefit() const { return benefit; }
+
+  /// Return the root node that this pattern matches.  Patterns that can
+  /// match multiple root types are instantiated once per root.
+  OperationName getRootKind() const { return rootKind; }
+
+  //===--------------------------------------------------------------------===//
+  // Implementation hooks for patterns to implement.
+  //===--------------------------------------------------------------------===//
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().  On failure, this
+  /// returns a None value.  On success it returns a (possibly null)
+  /// pattern-specific state wrapped in an Optional.
+  virtual PatternMatchResult match(Operation *op) const = 0;
+
+  virtual ~Pattern() {}
+
+  //===--------------------------------------------------------------------===//
+  // Helper methods to simplify pattern implementations
+  //===--------------------------------------------------------------------===//
+
+  /// This method indicates that no match was found.
+  static PatternMatchResult matchFailure() { return None; }
+
+  /// This method indicates that a match was found and has the specified cost.
+  PatternMatchResult
+  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
+    return PatternMatchResult(std::move(state));
+  }
+
+protected:
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching.
+  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
+
+private:
+  const OperationName rootKind;
+  const PatternBenefit benefit;
+
+  virtual void anchor();
+};
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// There are two possible usages of this class:
+///   * Multi-step RewritePattern with "match" and "rewrite"
+///     - By overloading the "match" and "rewrite" functions, the user can
+///       separate the concerns of matching and rewriting.
+///   * Single-step RewritePattern with "matchAndRewrite"
+///     - By overloading the "matchAndRewrite" function, the user can perform
+///       the rewrite in the same call as the match. This removes the need for
+///       any PatternState.
+///
+class RewritePattern : public Pattern {
+public:
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
+  virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+                       PatternRewriter &rewriter) const;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// builder.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
+  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().  On failure, this
+  /// returns a None value.  On success, it returns a (possibly null)
+  /// pattern-specific state wrapped in an Optional.  This state is passed back
+  /// into the rewrite function if this match is selected.
+  PatternMatchResult match(Operation *op) const override;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind(). If successful, this
+  /// function will automatically perform the rewrite.
+  virtual PatternMatchResult matchAndRewrite(Operation *op,
+                                             PatternRewriter &rewriter) const {
+    if (auto matchResult = match(op)) {
+      rewrite(op, std::move(*matchResult), rewriter);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+
+  /// Return a list of operations that may be generated when rewriting an
+  /// operation instance with this pattern.
+  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
+
+protected:
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching.
+  RewritePattern(StringRef rootName, PatternBenefit benefit,
+                 MLIRContext *context)
+      : Pattern(rootName, benefit, context) {}
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching. They can also specify
+  /// the names of operations that may be generated during a successful rewrite.
+  RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
+                 PatternBenefit benefit, MLIRContext *context);
+
+  /// A list of the potential operations that may be generated when rewriting
+  /// an op with this pattern.
+  llvm::SmallVector<OperationName, 2> generatedOps;
+};
+
+/// OpRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
+template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching.
+  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
+
+  /// Wrappers around the RewritePattern methods that pass the derived op type.
+  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+               PatternRewriter &rewriter) const final {
+    rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
+  }
+  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
+    rewrite(llvm::cast<SourceOp>(op), rewriter);
+  }
+  PatternMatchResult match(Operation *op) const final {
+    return match(llvm::cast<SourceOp>(op));
+  }
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
+  }
+
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
+                       PatternRewriter &rewriter) const {
+    rewrite(op, rewriter);
+  }
+  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
+    llvm_unreachable("must override matchAndRewrite or a rewrite method");
+  }
+  virtual PatternMatchResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
+  virtual PatternMatchResult matchAndRewrite(SourceOp op,
+                                             PatternRewriter &rewriter) const {
+    if (auto matchResult = match(op)) {
+      rewrite(op, std::move(*matchResult), rewriter);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter class
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates the application of a pattern to the current function,
+/// providing a way to create operations and keep track of what gets deleted.
+///
+/// These class serves two purposes:
+///  1) it is the interface that patterns interact with to make mutations to the
+///     IR they are being applied to.
+///  2) It is a base class that clients of the PatternMatcher use when they want
+///     to apply patterns and observe their effects (e.g. to keep worklists or
+///     other data structures up to date).
+///
+class PatternRewriter : public OpBuilder {
+public:
+  /// Create operation of specific op type at the current insertion point
+  /// without verifying to see if it is valid.
+  template <typename OpTy, typename... Args>
+  OpTy create(Location location, Args... args) {
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(this, &state, args...);
+    auto *op = createOperation(state);
+    auto result = dyn_cast<OpTy>(op);
+    assert(result && "Builder didn't return the right type");
+    return result;
+  }
+
+  /// Creates an operation of specific op type at the current insertion point.
+  /// If the result is an invalid op (the verifier hook fails), emit an error
+  /// and return null.
+  template <typename OpTy, typename... Args>
+  OpTy createChecked(Location location, Args... args) {
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(this, &state, args...);
+    auto *op = createOperation(state);
+
+    // If the Operation we produce is valid, return it.
+    if (!OpTy::verifyInvariants(op)) {
+      auto result = dyn_cast<OpTy>(op);
+      assert(result && "Builder didn't return the right type");
+      return result;
+    }
+
+    // Otherwise, the error message got emitted.  Just remove the operation
+    // we made.
+    op->erase();
+    return OpTy();
+  }
+
+  /// This is implemented to create the specified operations and serves as a
+  /// notification hook for rewriters that want to know about new operations.
+  virtual Operation *createOperation(const OperationState &state) = 0;
+
+  /// Move the blocks that belong to "region" before the given position in
+  /// another region "parent".  The two regions must be different.  The caller
+  /// is responsible for creating or updating the operation transferring flow
+  // of control to the region and pass it the correct block arguments.
+  virtual void inlineRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before);
+  void inlineRegionBefore(Region &region, Block *before);
+
+  /// This method performs the final replacement for a pattern, where the
+  /// results of the operation are updated to use the specified list of SSA
+  /// values.  In addition to replacing and removing the specified operation,
+  /// clients can specify a list of other nodes that this replacement may make
+  /// (perhaps transitively) dead.  If any of those values are dead, this will
+  /// remove them as well.
+  virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                         ArrayRef<Value *> valuesToRemoveIfDead);
+  void replaceOp(Operation *op, ArrayRef<Value *> newValues) {
+    replaceOp(op, newValues, llvm::None);
+  }
+
+  /// Replaces the result op with a new op that is created without verification.
+  /// The result values of the two ops must be the same types.
+  template <typename OpTy, typename... Args>
+  void replaceOpWithNewOp(Operation *op, Args &&... args) {
+    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
+  }
+
+  /// Replaces the result op with a new op that is created without verification.
+  /// The result values of the two ops must be the same types.  This allows
+  /// specifying a list of ops that may be removed if dead.
+  template <typename OpTy, typename... Args>
+  void replaceOpWithNewOp(ArrayRef<Value *> valuesToRemoveIfDead, Operation *op,
+                          Args &&... args) {
+    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
+                                    valuesToRemoveIfDead);
+  }
+
+  /// Split the operations starting at "before" (inclusive) out of the given
+  /// block into a new block, and return it.
+  virtual Block *splitBlock(Block *block, Block::iterator before) {
+    return block->splitBlock(before);
+  }
+
+  /// This method is used as the final notification hook for patterns that end
+  /// up modifying the pattern root in place, by changing its operands.  This is
+  /// a minor efficiency win (it avoids creating a new operation and removing
+  /// the old one) but also often allows simpler code in the client.
+  ///
+  /// The valuesToRemoveIfDead list is an optional list of values that the
+  /// rewriter should remove if they are dead at this point.
+  ///
+  void updatedRootInPlace(Operation *op,
+                          ArrayRef<Value *> valuesToRemoveIfDead = {});
+
+protected:
+  explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
+  virtual ~PatternRewriter();
+
+  // These are the callback methods that subclasses can choose to implement if
+  // they would like to be notified about certain types of mutations.
+
+  /// Notify the pattern rewriter that the specified operation has been mutated
+  /// in place.  This is called after the mutation is done.
+  virtual void notifyRootUpdated(Operation *op) {}
+
+  /// Notify the pattern rewriter that the specified operation is about to be
+  /// replaced with another set of operations.  This is called before the uses
+  /// of the operation have been changed.
+  virtual void notifyRootReplaced(Operation *op) {}
+
+  /// This is called on an operation that a pattern match is removing, right
+  /// before the operation is deleted.  At this point, the operation has zero
+  /// uses.
+  virtual void notifyOperationRemoved(Operation *op) {}
+
+private:
+  /// op and newOp are known to have the same number of results, replace the
+  /// uses of op with uses of newOp
+  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
+                                       ArrayRef<Value *> valuesToRemoveIfDead);
+};
+
+//===----------------------------------------------------------------------===//
+// Pattern-driven rewriters
+//===----------------------------------------------------------------------===//
+
+class OwningRewritePatternList {
+  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+
+public:
+  PatternListT::iterator begin() { return patterns.begin(); }
+  PatternListT::iterator end() { return patterns.end(); }
+  PatternListT::const_iterator begin() const { return patterns.begin(); }
+  PatternListT::const_iterator end() const { return patterns.end(); }
+  void clear() { patterns.clear(); }
+
+  //===--------------------------------------------------------------------===//
+  // Pattern Insertion
+  //===--------------------------------------------------------------------===//
+
+  void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
+
+  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
+  /// the given arguments.
+  // Note: ConstructorArg is necessary here to separate the two variadic lists.
+  template <typename... Ts, typename ConstructorArg,
+            typename... ConstructorArgs>
+  void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
+    // The following expands a call to emplace_back for each of the pattern
+    // types 'Ts'. This magic is necessary due to a limitation in the places
+    // that a parameter pack can be expanded in c++11.
+    // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+    using dummy = int[];
+    (void)dummy{
+        0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
+  }
+
+private:
+  PatternListT patterns;
+};
+
+/// This class manages optimization and execution of a group of rewrite
+/// patterns, providing an API for finding and applying, the best match against
+/// a given node.
+///
+class RewritePatternMatcher {
+public:
+  /// Create a RewritePatternMatcher with the specified set of patterns.
+  explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
+
+  /// Try to match the given operation to a pattern and rewrite it. Return
+  /// true if any pattern matches.
+  bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
+
+private:
+  RewritePatternMatcher(const RewritePatternMatcher &) = delete;
+  void operator=(const RewritePatternMatcher &) = delete;
+
+  /// The group of patterns that are matched for optimization through this
+  /// matcher.
+  std::vector<RewritePattern *> patterns;
+};
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
+///
+bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &patterns);
+
+} // end namespace mlir
+
+#endif // MLIR_PATTERN_MATCH_H
diff --git a/third_party/mlir/include/mlir/IR/Region.h b/third_party/mlir/include/mlir/IR/Region.h
new file mode 100644
index 0000000..c6f97c3
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Region.h
@@ -0,0 +1,147 @@
+//===- Region.h - MLIR Region Class -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the Region class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_REGION_H
+#define MLIR_IR_REGION_H
+
+#include "mlir/IR/Block.h"
+
+namespace mlir {
+class BlockAndValueMapping;
+
+/// This class contains a list of basic blocks and a link to the parent
+/// operation it is attached to.
+class Region {
+public:
+  Region() = default;
+  explicit Region(Operation *container);
+  ~Region();
+
+  /// Return the context this region is inserted in.  The region must have a
+  /// valid parent container.
+  MLIRContext *getContext();
+
+  /// Return a location for this region. This is the location attached to the
+  /// parent container. The region must have a valid parent container.
+  Location getLoc();
+
+  using RegionType = llvm::iplist<Block>;
+  RegionType &getBlocks() { return blocks; }
+
+  // Iteration over the block in the function.
+  using iterator = RegionType::iterator;
+  using reverse_iterator = RegionType::reverse_iterator;
+
+  iterator begin() { return blocks.begin(); }
+  iterator end() { return blocks.end(); }
+  reverse_iterator rbegin() { return blocks.rbegin(); }
+  reverse_iterator rend() { return blocks.rend(); }
+
+  bool empty() { return blocks.empty(); }
+  void push_back(Block *block) { blocks.push_back(block); }
+  void push_front(Block *block) { blocks.push_front(block); }
+
+  Block &back() { return blocks.back(); }
+  Block &front() { return blocks.front(); }
+
+  /// getSublistAccess() - Returns pointer to member of region.
+  static RegionType Region::*getSublistAccess(Block *) {
+    return &Region::blocks;
+  }
+
+  /// Return the region containing this region or nullptr if the region is
+  /// attached to a top-level operation.
+  Region *getParentRegion();
+
+  /// Return the parent operation this region is attached to.
+  Operation *getParentOp();
+
+  /// Find the first parent operation of the given type, or nullptr if there is
+  /// no ancestor operation.
+  template <typename ParentT> ParentT getParentOfType() {
+    auto *region = this;
+    do {
+      if (auto parent = dyn_cast_or_null<ParentT>(region->container))
+        return parent;
+    } while ((region = region->getParentRegion()));
+    return ParentT();
+  }
+
+  /// Return the number of this region in the parent operation.
+  unsigned getRegionNumber();
+
+  /// Return true if this region is a proper ancestor of the `other` region.
+  bool isProperAncestor(Region *other);
+
+  /// Return true if this region is ancestor of the `other` region.  A region
+  /// is considered as its own ancestor, use `isProperAncestor` to avoid this.
+  bool isAncestor(Region *other) {
+    return this == other || isProperAncestor(other);
+  }
+
+  /// Clone the internal blocks from this region into dest. Any
+  /// cloned blocks are appended to the back of dest. If the mapper
+  /// contains entries for block arguments, these arguments are not included
+  /// in the respective cloned block.
+  void cloneInto(Region *dest, BlockAndValueMapping &mapper);
+  /// Clone this region into 'dest' before the given position in 'dest'.
+  void cloneInto(Region *dest, Region::iterator destPos,
+                 BlockAndValueMapping &mapper);
+
+  /// Takes body of another region (that region will have no body after this
+  /// operation completes).  The current body of this region is cleared.
+  void takeBody(Region &other) {
+    blocks.clear();
+    blocks.splice(blocks.end(), other.getBlocks());
+  }
+
+  /// Check that this does not use any value defined outside it.
+  /// Emit errors if `noteLoc` is provided; this location is used to point
+  /// to the operation containing the region, the actual error is reported at
+  /// the operation with an offending use.
+  bool isIsolatedFromAbove(llvm::Optional<Location> noteLoc = llvm::None);
+
+  /// Drop all operand uses from operations within this region, which is
+  /// an essential step in breaking cyclic dependences between references when
+  /// they are to be deleted.
+  void dropAllReferences();
+
+  /// Walk the operations in this block in postorder, calling the callback for
+  /// each operation.
+  void walk(llvm::function_ref<void(Operation *)> callback);
+
+  /// Displays the CFG in a window. This is for use from the debugger and
+  /// depends on Graphviz to generate the graph.
+  /// This function is defined in ViewRegionGraph and only works with that
+  /// target linked.
+  void viewGraph(const llvm::Twine &regionName);
+  void viewGraph();
+
+private:
+  RegionType blocks;
+
+  /// This is the object we are part of.
+  Operation *container;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_REGION_H
diff --git a/third_party/mlir/include/mlir/IR/RegionGraphTraits.h b/third_party/mlir/include/mlir/IR/RegionGraphTraits.h
new file mode 100644
index 0000000..f45dcc4
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/RegionGraphTraits.h
@@ -0,0 +1,94 @@
+//===- RegionGraphTraits.h - llvm::GraphTraits for CFGs ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements specializations of llvm::GraphTraits for various MLIR
+// CFG data types.  This allows the generic LLVM graph algorithms to be applied
+// to CFGs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_REGIONGRAPHTRAITS_H
+#define MLIR_IR_REGIONGRAPHTRAITS_H
+
+#include "mlir/IR/Region.h"
+#include "llvm/ADT/GraphTraits.h"
+
+namespace llvm {
+template <> struct GraphTraits<mlir::Block *> {
+  using ChildIteratorType = mlir::Block::succ_iterator;
+  using Node = mlir::Block;
+  using NodeRef = Node *;
+
+  static NodeRef getEntryNode(NodeRef bb) { return bb; }
+
+  static ChildIteratorType child_begin(NodeRef node) {
+    return node->succ_begin();
+  }
+  static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
+};
+
+template <> struct GraphTraits<Inverse<mlir::Block *>> {
+  using ChildIteratorType = mlir::Block::pred_iterator;
+  using Node = mlir::Block;
+  using NodeRef = Node *;
+  static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
+    return inverseGraph.Graph;
+  }
+  static inline ChildIteratorType child_begin(NodeRef node) {
+    return node->pred_begin();
+  }
+  static inline ChildIteratorType child_end(NodeRef node) {
+    return node->pred_end();
+  }
+};
+
+template <>
+struct GraphTraits<mlir::Region *> : public GraphTraits<mlir::Block *> {
+  using GraphType = mlir::Region *;
+  using NodeRef = mlir::Block *;
+
+  static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
+
+  using nodes_iterator = pointer_iterator<mlir::Region::iterator>;
+  static nodes_iterator nodes_begin(GraphType fn) {
+    return nodes_iterator(fn->begin());
+  }
+  static nodes_iterator nodes_end(GraphType fn) {
+    return nodes_iterator(fn->end());
+  }
+};
+
+template <>
+struct GraphTraits<Inverse<mlir::Region *>>
+    : public GraphTraits<Inverse<mlir::Block *>> {
+  using GraphType = Inverse<mlir::Region *>;
+  using NodeRef = NodeRef;
+
+  static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
+
+  using nodes_iterator = pointer_iterator<mlir::Region::iterator>;
+  static nodes_iterator nodes_begin(GraphType fn) {
+    return nodes_iterator(fn.Graph->begin());
+  }
+  static nodes_iterator nodes_end(GraphType fn) {
+    return nodes_iterator(fn.Graph->end());
+  }
+};
+
+} // namespace llvm
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/StandardTypes.h b/third_party/mlir/include/mlir/IR/StandardTypes.h
new file mode 100644
index 0000000..4666e58
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/StandardTypes.h
@@ -0,0 +1,496 @@
+//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_STANDARDTYPES_H
+#define MLIR_IR_STANDARDTYPES_H
+
+#include "mlir/IR/Types.h"
+
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class AffineMap;
+class FloatType;
+class IndexType;
+class IntegerType;
+class Location;
+class MLIRContext;
+
+namespace detail {
+
+struct IntegerTypeStorage;
+struct ShapedTypeStorage;
+struct VectorTypeStorage;
+struct RankedTensorTypeStorage;
+struct UnrankedTensorTypeStorage;
+struct MemRefTypeStorage;
+struct ComplexTypeStorage;
+struct TupleTypeStorage;
+
+} // namespace detail
+
+namespace StandardTypes {
+enum Kind {
+  // Floating point.
+  BF16 = Type::Kind::FIRST_STANDARD_TYPE,
+  F16,
+  F32,
+  F64,
+  FIRST_FLOATING_POINT_TYPE = BF16,
+  LAST_FLOATING_POINT_TYPE = F64,
+
+  // Target pointer sized integer, used (e.g.) in affine mappings.
+  Index,
+
+  // Derived types.
+  Integer,
+  Vector,
+  RankedTensor,
+  UnrankedTensor,
+  MemRef,
+  Complex,
+  Tuple,
+  None,
+};
+
+} // namespace StandardTypes
+
+inline bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
+inline bool Type::isF16() { return getKind() == StandardTypes::F16; }
+inline bool Type::isF32() { return getKind() == StandardTypes::F32; }
+inline bool Type::isF64() { return getKind() == StandardTypes::F64; }
+
+inline bool Type::isIndex() { return getKind() == StandardTypes::Index; }
+
+/// Index is a special integer-like type with unknown platform-dependent bit
+/// width.
+class IndexType : public Type::TypeBase<IndexType, Type> {
+public:
+  using Base::Base;
+
+  /// Get an instance of the IndexType.
+  static IndexType get(MLIRContext *context);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
+};
+
+/// Integer types can have arbitrary bitwidth up to a large fixed limit.
+class IntegerType
+    : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new IntegerType of the given width within the context.
+  /// Assume the width is within the allowed range and assert on failures.
+  /// Use getChecked to handle failures gracefully.
+  static IntegerType get(unsigned width, MLIRContext *context);
+
+  /// Get or create a new IntegerType of the given width within the context,
+  /// defined at the given, potentially unknown, location.  If the width is
+  /// outside the allowed range, emit errors and return a null type.
+  static IntegerType getChecked(unsigned width, MLIRContext *context,
+                                Location location);
+
+  /// Verify the construction of an integer type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, unsigned width);
+
+  /// Return the bitwidth of this integer type.
+  unsigned getWidth() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
+
+  /// Integer representation maximal bitwidth.
+  static constexpr unsigned kMaxWidth = 4096;
+};
+
+/// Return true if this is an integer type with the specified width.
+inline bool Type::isInteger(unsigned width) {
+  if (auto intTy = dyn_cast<IntegerType>())
+    return intTy.getWidth() == width;
+  return false;
+}
+
+inline bool Type::isIntOrIndex() {
+  return isa<IndexType>() || isa<IntegerType>();
+}
+
+inline bool Type::isIntOrIndexOrFloat() {
+  return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
+}
+
+inline bool Type::isIntOrFloat() {
+  return isa<IntegerType>() || isa<FloatType>();
+}
+
+class FloatType : public Type::TypeBase<FloatType, Type> {
+public:
+  using Base::Base;
+
+  static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
+
+  // Convenience factories.
+  static FloatType getBF16(MLIRContext *ctx) {
+    return get(StandardTypes::BF16, ctx);
+  }
+  static FloatType getF16(MLIRContext *ctx) {
+    return get(StandardTypes::F16, ctx);
+  }
+  static FloatType getF32(MLIRContext *ctx) {
+    return get(StandardTypes::F32, ctx);
+  }
+  static FloatType getF64(MLIRContext *ctx) {
+    return get(StandardTypes::F64, ctx);
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
+           kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
+  }
+
+  /// Return the bitwidth of this float type.
+  unsigned getWidth();
+
+  /// Return the floating semantics of this float type.
+  const llvm::fltSemantics &getFloatSemantics();
+};
+
+/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
+/// and MemRef types because they share behavior and semantics around shape,
+/// rank, and fixed element type. Any type with these semantics should inherit
+/// from ShapedType.
+class ShapedType : public Type {
+public:
+  using ImplType = detail::ShapedTypeStorage;
+  using Type::Type;
+
+  /// Return the element type.
+  Type getElementType() const;
+
+  /// If an element type is an integer or a float, return its width. Otherwise,
+  /// abort.
+  unsigned getElementTypeBitWidth() const;
+
+  /// If it has static shape, return the number of elements. Otherwise, abort.
+  int64_t getNumElements() const;
+
+  /// If this is a ranked type, return the rank. Otherwise, abort.
+  int64_t getRank() const;
+
+  /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
+  /// have a rank, while unranked tensors do not.
+  bool hasRank() const;
+
+  /// If this is a ranked type, return the shape. Otherwise, abort.
+  ArrayRef<int64_t> getShape() const;
+
+  /// If this is unranked type or any dimension has unknown size (<0), it
+  /// doesn't have static shape. If all dimensions have known size (>= 0), it
+  /// has static shape.
+  bool hasStaticShape() const;
+
+  /// If this is a ranked type, return the number of dimensions with dynamic
+  /// size. Otherwise, abort.
+  int64_t getNumDynamicDims() const;
+
+  /// If this is ranked type, return the size of the specified dimension.
+  /// Otherwise, abort.
+  int64_t getDimSize(int64_t i) const;
+
+  /// Get the total amount of bits occupied by a value of this type.  This does
+  /// not take into account any memory layout or widening constraints, e.g. a
+  /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
+  /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
+  /// if the size cannot be computed statically, i.e. if the type has a dynamic
+  /// shape or if its elemental type does not have a known bit width.
+  int64_t getSizeInBits() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(Type type) {
+    return type.getKind() == StandardTypes::Vector ||
+           type.getKind() == StandardTypes::RankedTensor ||
+           type.getKind() == StandardTypes::UnrankedTensor ||
+           type.getKind() == StandardTypes::MemRef;
+  }
+
+  /// Whether the given dimension size indicates a dynamic dimension.
+  static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
+};
+
+/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
+/// known constant shape with one or more dimension.
+class VectorType
+    : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new VectorType of the provided shape and element type.
+  /// Assumes the arguments define a well-formed VectorType.
+  static VectorType get(ArrayRef<int64_t> shape, Type elementType);
+
+  /// Get or create a new VectorType of the provided shape and element type
+  /// declared at the given, potentially unknown, location.  If the VectorType
+  /// defined by the arguments would be ill-formed, emit errors and return
+  /// nullptr-wrapping type.
+  static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
+                               Location location);
+
+  /// Verify the construction of a vector type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, ArrayRef<int64_t> shape,
+                               Type elementType);
+
+  /// Returns true of the given type can be used as an element of a vector type.
+  /// In particular, vectors can consist of integer or float primitives.
+  static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
+
+  ArrayRef<int64_t> getShape() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
+};
+
+/// Tensor types represent multi-dimensional arrays, and have two variants:
+/// RankedTensorType and UnrankedTensorType.
+class TensorType : public ShapedType {
+public:
+  using ShapedType::ShapedType;
+
+  /// Return true if the specified element type is ok in a tensor.
+  static bool isValidElementType(Type type) {
+    // Note: Non standard/builtin types are allowed to exist within tensor
+    // types. Dialects are expected to verify that tensor types have a valid
+    // element type within that dialect.
+    return type.isIntOrFloat() || type.isa<VectorType>() ||
+           type.isa<OpaqueType>() ||
+           (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(Type type) {
+    return type.getKind() == StandardTypes::RankedTensor ||
+           type.getKind() == StandardTypes::UnrankedTensor;
+  }
+};
+
+/// Ranked tensor types represent multi-dimensional arrays that have a shape
+/// with a fixed number of dimensions. Each shape element can be a positive
+/// integer or unknown (represented -1).
+class RankedTensorType
+    : public Type::TypeBase<RankedTensorType, TensorType,
+                            detail::RankedTensorTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new RankedTensorType of the provided shape and element
+  /// type. Assumes the arguments define a well-formed type.
+  static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
+
+  /// Get or create a new RankedTensorType of the provided shape and element
+  /// type declared at the given, potentially unknown, location.  If the
+  /// RankedTensorType defined by the arguments would be ill-formed, emit errors
+  /// and return a nullptr-wrapping type.
+  static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
+                                     Location location);
+
+  /// Verify the construction of a ranked tensor type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, ArrayRef<int64_t> shape,
+                               Type elementType);
+
+  ArrayRef<int64_t> getShape() const;
+
+  static bool kindof(unsigned kind) {
+    return kind == StandardTypes::RankedTensor;
+  }
+};
+
+/// Unranked tensor types represent multi-dimensional arrays that have an
+/// unknown shape.
+class UnrankedTensorType
+    : public Type::TypeBase<UnrankedTensorType, TensorType,
+                            detail::UnrankedTensorTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new UnrankedTensorType of the provided shape and element
+  /// type. Assumes the arguments define a well-formed type.
+  static UnrankedTensorType get(Type elementType);
+
+  /// Get or create a new UnrankedTensorType of the provided shape and element
+  /// type declared at the given, potentially unknown, location.  If the
+  /// UnrankedTensorType defined by the arguments would be ill-formed, emit
+  /// errors and return a nullptr-wrapping type.
+  static UnrankedTensorType getChecked(Type elementType, Location location);
+
+  /// Verify the construction of a unranked tensor type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, Type elementType);
+
+  ArrayRef<int64_t> getShape() const { return llvm::None; }
+
+  static bool kindof(unsigned kind) {
+    return kind == StandardTypes::UnrankedTensor;
+  }
+};
+
+/// MemRef types represent a region of memory that have a shape with a fixed
+/// number of dimensions. Each shape element can be a non-negative integer or
+/// unknown (represented by any negative integer). MemRef types also have an
+/// affine map composition, represented as an array AffineMap pointers.
+class MemRefType
+    : public Type::TypeBase<MemRefType, ShapedType, detail::MemRefTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new MemRefType based on shape, element type, affine
+  /// map composition, and memory space.  Assumes the arguments define a
+  /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
+  /// construction failures.
+  static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
+                        ArrayRef<AffineMap> affineMapComposition = {},
+                        unsigned memorySpace = 0);
+
+  /// Get or create a new MemRefType based on shape, element type, affine
+  /// map composition, and memory space declared at the given location.
+  /// If the location is unknown, the last argument should be an instance of
+  /// UnknownLoc.  If the MemRefType defined by the arguments would be
+  /// ill-formed, emits errors (to the handler registered with the context or to
+  /// the error stream) and returns nullptr.
+  static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
+                               ArrayRef<AffineMap> affineMapComposition,
+                               unsigned memorySpace, Location location);
+
+  ArrayRef<int64_t> getShape() const;
+
+  /// Returns an array of affine map pointers representing the memref affine
+  /// map composition.
+  ArrayRef<AffineMap> getAffineMaps() const;
+
+  /// Returns the memory space in which data referred to by this memref resides.
+  unsigned getMemorySpace() const;
+
+  static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
+
+private:
+  /// Get or create a new MemRefType defined by the arguments.  If the resulting
+  /// type would be ill-formed, return nullptr.  If the location is provided,
+  /// emit detailed error messages.
+  static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
+                            ArrayRef<AffineMap> affineMapComposition,
+                            unsigned memorySpace, Optional<Location> location);
+  using Base::getImpl;
+};
+
+/// The 'complex' type represents a complex number with a parameterized element
+/// type, which is composed of a real and imaginary value of that element type.
+///
+/// The element must be a floating point or integer scalar type.
+///
+class ComplexType
+    : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a ComplexType with the provided element type.
+  static ComplexType get(Type elementType);
+
+  /// Get or create a ComplexType with the provided element type.  This emits
+  /// and error at the specified location and returns null if the element type
+  /// isn't supported.
+  static ComplexType getChecked(Type elementType, Location location);
+
+  /// Verify the construction of an integer type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, Type elementType);
+
+  Type getElementType();
+
+  static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
+
+private:
+  static ComplexType getCheckedImpl(Type elementType,
+                                    Optional<Location> location);
+};
+
+/// Tuple types represent a collection of other types. Note: This type merely
+/// provides a common mechanism for representing tuples in MLIR. It is up to
+/// dialect authors to provides operations for manipulating them, e.g.
+/// extract_tuple_element. When possible, users should prefer multi-result
+/// operations in the place of tuples.
+class TupleType
+    : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new TupleType with the provided element types. Assumes the
+  /// arguments define a well-formed type.
+  static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
+
+  /// Get or create an empty tuple type.
+  static TupleType get(MLIRContext *context) { return get({}, context); }
+
+  /// Return the elements types for this tuple.
+  ArrayRef<Type> getTypes() const;
+
+  /// Accumulate the types contained in this tuple and tuples nested within it.
+  /// Note that this only flattens nested tuples, not any other container type,
+  /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
+  /// (i32, tensor<i32>, f32, i64)
+  void getFlattenedTypes(SmallVectorImpl<Type> &types);
+
+  /// Return the number of held types.
+  size_t size() const;
+
+  /// Iterate over the held elements.
+  using iterator = ArrayRef<Type>::iterator;
+  iterator begin() const { return getTypes().begin(); }
+  iterator end() const { return getTypes().end(); }
+
+  /// Return the element type at index 'index'.
+  Type getType(size_t index) const {
+    assert(index < size() && "invalid index for tuple type");
+    return getTypes()[index];
+  }
+
+  static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
+};
+
+/// NoneType is a unit type, i.e. a type with exactly one possible value, where
+/// its value does not have a defined dynamic representation.
+class NoneType : public Type::TypeBase<NoneType, Type> {
+public:
+  using Base::Base;
+
+  /// Get an instance of the NoneType.
+  static NoneType get(MLIRContext *context);
+
+  static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
+};
+} // end namespace mlir
+
+#endif // MLIR_IR_STANDARDTYPES_H
diff --git a/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h b/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h
new file mode 100644
index 0000000..1a73073
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -0,0 +1,94 @@
+//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines utility classes for interfacing with StorageUniquer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
+#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
+
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+class Location;
+class MLIRContext;
+
+namespace detail {
+/// Utility class for implementing users of storage classes uniqued by a
+/// StorageUniquer. Clients are not expected to interact with this class
+/// directly.
+template <typename ConcreteT, typename BaseT, typename StorageT,
+          typename UniquerT>
+class StorageUserBase : public BaseT {
+public:
+  using BaseT::BaseT;
+
+  /// Utility declarations for the concrete attribute class.
+  using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT>;
+  using ImplType = StorageT;
+
+  /// Return a unique identifier for the concrete type.
+  static ClassID *getClassID() { return ClassID::getID<ConcreteT>(); }
+
+  /// Provide a default implementation of 'classof' that invokes a 'kindof'
+  /// method on the concrete type.
+  template <typename T> static bool classof(T val) {
+    static_assert(std::is_convertible<ConcreteT, T>::value,
+                  "casting from a non-convertible type");
+    return ConcreteT::kindof(val.getKind());
+  }
+
+protected:
+  /// Get or create a new ConcreteT instance within the ctx. This
+  /// function is guaranteed to return a non null object and will assert if
+  /// the arguments provided are invalid.
+  template <typename... Args>
+  static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
+    // Ensure that the invariants are correct for construction.
+    assert(succeeded(
+        ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...)));
+    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+  }
+
+  /// Get or create a new ConcreteT instance within the ctx, defined at
+  /// the given, potentially unknown, location. If the arguments provided are
+  /// invalid then emit errors and return a null object.
+  template <typename... Args>
+  static ConcreteT getChecked(const Location &loc, MLIRContext *ctx,
+                              unsigned kind, Args... args) {
+    // If the construction invariants fail then we return a null attribute.
+    if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...)))
+      return ConcreteT();
+    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+  }
+
+  /// Default implementation that just returns success.
+  template <typename... Args>
+  static LogicalResult verifyConstructionInvariants(Args... args) {
+    return success();
+  }
+
+  /// Utility for easy access to the storage instance.
+  ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
+};
+} // namespace detail
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/SymbolTable.h b/third_party/mlir/include/mlir/IR/SymbolTable.h
new file mode 100644
index 0000000..8826809
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/SymbolTable.h
@@ -0,0 +1,109 @@
+//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_SYMBOLTABLE_H
+#define MLIR_IR_SYMBOLTABLE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+class Identifier;
+class MLIRContext;
+class Operation;
+
+/// This class allows for representing and managing the symbol table used by
+/// operations with the 'SymbolTable' trait.
+class SymbolTable {
+public:
+  /// Build a symbol table with the symbols within the given operation.
+  SymbolTable(Operation *op);
+
+  /// Look up a symbol with the specified name, returning null if no such
+  /// name exists. Names never include the @ on them.
+  Operation *lookup(StringRef name) const;
+  template <typename T> T lookup(StringRef name) const {
+    return dyn_cast_or_null<T>(lookup(name));
+  }
+
+  /// Erase the given symbol from the table.
+  void erase(Operation *symbol);
+
+  /// Insert a new symbol into the table, and rename it as necessary to avoid
+  /// collisions.
+  void insert(Operation *symbol);
+
+  /// Returns the context held by this symbol table.
+  MLIRContext *getContext() const { return context; }
+
+  /// Return the name of the attribute used for symbol names.
+  static StringRef getSymbolAttrName() { return "sym_name"; }
+
+private:
+  MLIRContext *context;
+
+  /// This is a mapping from a name to the symbol with that name.
+  llvm::StringMap<Operation *> symbolTable;
+
+  /// This is used when name conflicts are detected.
+  unsigned uniquingCounter = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// SymbolTable Trait Types
+//===----------------------------------------------------------------------===//
+
+namespace OpTrait {
+namespace impl {
+LogicalResult verifySymbolTable(Operation *op);
+} // namespace impl
+
+/// A trait used to provide symbol table functionalities to a region operation.
+/// This operation must hold exactly 1 region. Once attached, all operations
+/// that are directly within the region, i.e not including those within child
+/// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
+/// be verified to ensure that the names are uniqued.
+template <typename ConcreteType>
+class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySymbolTable(op);
+  }
+
+  /// Look up a symbol with the specified name, returning null if no such
+  /// name exists. Symbol names never include the @ on them. Note: This
+  /// performs a linear scan of held symbols.
+  Operation *lookupSymbol(StringRef name) {
+    // Look for a symbol with the given name.
+    for (auto &block : this->getOperation()->getRegion(0)) {
+      for (auto &op : block) {
+        auto nameAttr = op.template getAttrOfType<StringAttr>(
+            mlir::SymbolTable::getSymbolAttrName());
+        if (nameAttr && nameAttr.getValue() == name)
+          return &op;
+      }
+    }
+    return nullptr;
+  }
+  template <typename T> T lookupSymbol(StringRef name) {
+    return dyn_cast_or_null<T>(lookupSymbol(name));
+  }
+};
+} // end namespace OpTrait
+} // end namespace mlir
+
+#endif // MLIR_IR_SYMBOLTABLE_H
diff --git a/third_party/mlir/include/mlir/IR/TypeSupport.h b/third_party/mlir/include/mlir/IR/TypeSupport.h
new file mode 100644
index 0000000..86620da
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/TypeSupport.h
@@ -0,0 +1,121 @@
+//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support types for registering dialect extended types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TYPE_SUPPORT_H
+#define MLIR_IR_TYPE_SUPPORT_H
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StorageUniquerSupport.h"
+
+namespace mlir {
+struct ClassID;
+class Dialect;
+class MLIRContext;
+
+//===----------------------------------------------------------------------===//
+// TypeStorage
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+class TypeUniquer;
+} // end namespace detail
+
+/// Base storage class appearing in a Type.
+class TypeStorage : public StorageUniquer::BaseStorage {
+  friend detail::TypeUniquer;
+  friend StorageUniquer;
+
+protected:
+  /// This constructor is used by derived classes as part of the TypeUniquer.
+  /// When using this constructor, the initializeTypeInfo function must be
+  /// invoked afterwards for the storage to be valid.
+  TypeStorage(unsigned subclassData = 0)
+      : dialect(nullptr), subclassData(subclassData) {}
+
+public:
+  /// Get the dialect that this type is registered to.
+  Dialect &getDialect() {
+    assert(dialect && "Malformed type storage object.");
+    return *dialect;
+  }
+  /// Get the subclass data.
+  unsigned getSubclassData() const { return subclassData; }
+
+  /// Set the subclass data.
+  void setSubclassData(unsigned val) { subclassData = val; }
+
+private:
+  // Set the dialect for this storage instance. This is used by the TypeUniquer
+  // when initializing a newly constructed type storage object.
+  void initializeDialect(Dialect &newDialect) { dialect = &newDialect; }
+
+  /// The dialect for this type.
+  Dialect *dialect;
+
+  /// Space for subclasses to store data.
+  unsigned subclassData;
+};
+
+/// Default storage type for types that require no additional initialization or
+/// storage.
+using DefaultTypeStorage = TypeStorage;
+
+//===----------------------------------------------------------------------===//
+// TypeStorageAllocator
+//===----------------------------------------------------------------------===//
+
+// This is a utility allocator used to allocate memory for instances of derived
+// Types.
+using TypeStorageAllocator = StorageUniquer::StorageAllocator;
+
+//===----------------------------------------------------------------------===//
+// TypeUniquer
+//===----------------------------------------------------------------------===//
+namespace detail {
+// A utility class to get, or create, unique instances of types within an
+// MLIRContext. This class manages all creation and uniquing of types.
+class TypeUniquer {
+public:
+  /// Get an uniqued instance of a type T.
+  template <typename T, typename... Args>
+  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+    return ctx->getTypeUniquer().get<typename T::ImplType>(
+        [&](TypeStorage *storage) {
+          storage->initializeDialect(lookupDialectForType<T>(ctx));
+        },
+        kind, std::forward<Args>(args)...);
+  }
+
+private:
+  /// Get the dialect that the type 'T' was registered with.
+  template <typename T> static Dialect &lookupDialectForType(MLIRContext *ctx) {
+    return lookupDialectForType(ctx, T::getClassID());
+  }
+
+  /// Get the dialect that registered the type with the provided typeid.
+  static Dialect &lookupDialectForType(MLIRContext *ctx,
+                                       const ClassID *const typeID);
+};
+} // namespace detail
+
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/TypeUtilities.h b/third_party/mlir/include/mlir/IR/TypeUtilities.h
new file mode 100644
index 0000000..ce0169f
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/TypeUtilities.h
@@ -0,0 +1,94 @@
+//===- TypeUtilities.h - Helper function for type queries -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines generic type utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_TYPEUTILITIES_H
+#define MLIR_SUPPORT_TYPEUTILITIES_H
+
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+
+class Attribute;
+class TupleType;
+class Type;
+class Value;
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Return the element type or return the type itself.
+Type getElementTypeOrSelf(Type type);
+
+/// Return the element type or return the type itself.
+Type getElementTypeOrSelf(Attribute attr);
+Type getElementTypeOrSelf(Value *val);
+Type getElementTypeOrSelf(Value &val);
+
+/// Get the types within a nested Tuple. A helper for the class method that
+/// handles storage concerns, which is tricky to do in tablegen.
+SmallVector<Type, 10> getFlattenedTypes(TupleType t);
+
+/// Return true if the specified type is an opaque type with the specified
+/// dialect and typeData.
+bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
+
+//===----------------------------------------------------------------------===//
+// Utility Iterators
+//===----------------------------------------------------------------------===//
+
+// An iterator for the element types of an op's operands of shaped types.
+class OperandElementTypeIterator final
+    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+public:
+  using reference = Type;
+
+  /// Initializes the result element type iterator to the specified operand
+  /// iterator.
+  explicit OperandElementTypeIterator(OperandIterator it);
+
+private:
+  static Type unwrap(Value *value);
+};
+
+using OperandElementTypeRange =
+    llvm::iterator_range<OperandElementTypeIterator>;
+
+// An iterator for the tensor element types of an op's results of shaped types.
+class ResultElementTypeIterator final
+    : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
+public:
+  using reference = Type;
+
+  /// Initializes the result element type iterator to the specified result
+  /// iterator.
+  explicit ResultElementTypeIterator(ResultIterator it);
+
+private:
+  static Type unwrap(Value *value);
+};
+
+using ResultElementTypeRange = llvm::iterator_range<ResultElementTypeIterator>;
+
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_TYPEUTILITIES_H
diff --git a/third_party/mlir/include/mlir/IR/Types.h b/third_party/mlir/include/mlir/IR/Types.h
new file mode 100644
index 0000000..48c7cb3
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Types.h
@@ -0,0 +1,313 @@
+//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_TYPES_H
+#define MLIR_IR_TYPES_H
+
+#include "mlir/IR/TypeSupport.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMapInfo.h"
+
+namespace mlir {
+class FloatType;
+class Identifier;
+class IndexType;
+class IntegerType;
+class MLIRContext;
+class TypeStorage;
+
+namespace detail {
+struct FunctionTypeStorage;
+struct OpaqueTypeStorage;
+} // namespace detail
+
+/// Instances of the Type class are immutable and uniqued.  They wrap a pointer
+/// to the storage object owned by MLIRContext.  Therefore, instances of Type
+/// are passed around by value.
+///
+/// Some types are "primitives" meaning they do not have any parameters, for
+/// example the Index type.  Parametric types have additional information that
+/// differentiates the types of the same kind between them, for example the
+/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
+/// different instances of the IntegerType.
+///
+/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
+///
+/// Derived type classes are expected to implement several required
+/// implementaiton hooks:
+///  * Required:
+///    - static bool kindof(unsigned kind);
+///      * Returns if the provided type kind corresponds to an instance of the
+///        current type. Used for isa/dyn_cast casting functionality.
+///
+///  * Optional:
+///    - static LogicalResult verifyConstructionInvariants(
+///                                               llvm::Optional<Location> loc,
+///                                               MLIRContext *context,
+///                                               Args... args)
+///      * This method is invoked when calling the 'TypeBase::get/getChecked'
+///        methods to ensure that the arguments passed in are valid to construct
+///        a type instance with.
+///      * This method is expected to return failure if a type cannot be
+///        constructed with 'args', success otherwise.
+///      * 'args' must correspond with the arguments passed into the
+///        'TypeBase::get' call after the type kind.
+///
+///
+/// Type storage objects inherit from TypeStorage and contain the following:
+///    - The type kind (for LLVM-style RTTI).
+///    - The dialect that defined the type.
+///    - Any parameters of the type.
+/// For non-parametric types, a convenience DefaultTypeStorage is provided.
+/// Parametric storage types must derive TypeStorage and respect the following:
+///    - Define a type alias, KeyTy, to a type that uniquely identifies the
+///      instance of the type within its kind.
+///      * The key type must be constructible from the values passed into the
+///        detail::TypeUniquer::get call after the type kind.
+///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+///        storage class must define a hashing method:
+///         'static unsigned hashKey(const KeyTy &)'
+///
+///    - Provide a method, 'bool operator==(const KeyTy &) const', to
+///      compare the storage instance against an instance of the key type.
+///
+///    - Provide a construction method:
+///        'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
+///      that builds a unique instance of the derived storage. The arguments to
+///      this function are an allocator to store any uniqued data within the
+///      context and the key type for this storage.
+class Type {
+public:
+  /// Integer identifier for all the concrete type kinds.
+  /// Note: This is not an enum class as each dialect will likely define a
+  /// separate enumeration for the specific types that they define. Not being an
+  /// enum class also simplifies the handling of type kinds by not requiring
+  /// casts for each use.
+  enum Kind {
+    // Builtin types.
+    Function,
+    Opaque,
+    LAST_BUILTIN_TYPE = Opaque,
+
+  // Reserve type kinds for dialect specific type system extensions.
+#define DEFINE_SYM_KIND_RANGE(Dialect)                                         \
+  FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
+#include "DialectSymbolRegistry.def"
+  };
+
+  /// Utility class for implementing types.
+  template <typename ConcreteType, typename BaseType,
+            typename StorageType = DefaultTypeStorage>
+  using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
+                                           detail::TypeUniquer>;
+
+  using ImplType = TypeStorage;
+
+  Type() : impl(nullptr) {}
+  /* implicit */ Type(const ImplType *impl)
+      : impl(const_cast<ImplType *>(impl)) {}
+
+  Type(const Type &other) : impl(other.impl) {}
+  Type &operator=(Type other) {
+    impl = other.impl;
+    return *this;
+  }
+
+  bool operator==(Type other) const { return impl == other.impl; }
+  bool operator!=(Type other) const { return !(*this == other); }
+  explicit operator bool() const { return impl; }
+
+  bool operator!() const { return impl == nullptr; }
+
+  template <typename U> bool isa() const;
+  template <typename U> U dyn_cast() const;
+  template <typename U> U dyn_cast_or_null() const;
+  template <typename U> U cast() const;
+
+  // Support type casting Type to itself.
+  static bool classof(Type) { return true; }
+
+  /// Return the classification for this type.
+  unsigned getKind() const;
+
+  /// Return the LLVMContext in which this type was uniqued.
+  MLIRContext *getContext() const;
+
+  /// Get the dialect this type is registered to.
+  Dialect &getDialect() const;
+
+  // Convenience predicates.  This is only for floating point types,
+  // derived types should use isa/dyn_cast.
+  bool isIndex();
+  bool isBF16();
+  bool isF16();
+  bool isF32();
+  bool isF64();
+
+  /// Return true if this is an integer type with the specified width.
+  bool isInteger(unsigned width);
+
+  /// Return the bit width of an integer or a float type, assert failure on
+  /// other types.
+  unsigned getIntOrFloatBitWidth();
+
+  /// Return true if this is an integer or index type.
+  bool isIntOrIndex();
+  /// Return true if this is an integer, index, or float type.
+  bool isIntOrIndexOrFloat();
+  /// Return true of this is an integer or a float type.
+  bool isIntOrFloat();
+
+  /// Print the current type.
+  void print(raw_ostream &os);
+  void dump();
+
+  friend ::llvm::hash_code hash_value(Type arg);
+
+  unsigned getSubclassData() const;
+  void setSubclassData(unsigned val);
+
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(impl);
+  }
+  static Type getFromOpaquePointer(const void *pointer) {
+    return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
+  }
+
+protected:
+  ImplType *impl;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, Type type) {
+  type.print(os);
+  return os;
+}
+
+/// Function types map from a list of inputs to a list of results.
+class FunctionType
+    : public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
+public:
+  using Base::Base;
+
+  static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+                          MLIRContext *context);
+
+  // Input types.
+  unsigned getNumInputs() const { return getSubclassData(); }
+
+  Type getInput(unsigned i) const { return getInputs()[i]; }
+
+  ArrayRef<Type> getInputs() const;
+
+  // Result types.
+  unsigned getNumResults() const;
+
+  Type getResult(unsigned i) const { return getResults()[i]; }
+
+  ArrayRef<Type> getResults() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == Kind::Function; }
+};
+
+/// Opaque types represent types of non-registered dialects. These are types
+/// represented in their raw string form, and can only usefully be tested for
+/// type equality.
+class OpaqueType
+    : public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new OpaqueType with the provided dialect and string data.
+  static OpaqueType get(Identifier dialect, StringRef typeData,
+                        MLIRContext *context);
+
+  /// Get or create a new OpaqueType with the provided dialect and string data.
+  /// If the given identifier is not a valid namespace for a dialect, then a
+  /// null type is returned.
+  static OpaqueType getChecked(Identifier dialect, StringRef typeData,
+                               MLIRContext *context, Location location);
+
+  /// Returns the dialect namespace of the opaque type.
+  Identifier getDialectNamespace() const;
+
+  /// Returns the raw type data of the opaque type.
+  StringRef getTypeData() const;
+
+  /// Verify the construction of an opaque type.
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, Identifier dialect,
+                               StringRef typeData);
+
+  static bool kindof(unsigned kind) { return kind == Kind::Opaque; }
+};
+
+// Make Type hashable.
+inline ::llvm::hash_code hash_value(Type arg) {
+  return ::llvm::hash_value(arg.impl);
+}
+
+template <typename U> bool Type::isa() const {
+  assert(impl && "isa<> used on a null type.");
+  return U::classof(*this);
+}
+template <typename U> U Type::dyn_cast() const {
+  return isa<U>() ? U(impl) : U(nullptr);
+}
+template <typename U> U Type::dyn_cast_or_null() const {
+  return (impl && isa<U>()) ? U(impl) : U(nullptr);
+}
+template <typename U> U Type::cast() const {
+  assert(isa<U>());
+  return U(impl);
+}
+
+} // end namespace mlir
+
+namespace llvm {
+
+// Type hash just like pointers.
+template <> struct DenseMapInfo<mlir::Type> {
+  static mlir::Type getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
+  }
+  static mlir::Type getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
+  static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
+};
+
+/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
+template <> struct PointerLikeTypeTraits<mlir::Type> {
+public:
+  static inline void *getAsVoidPointer(mlir::Type I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::Type getFromVoidPointer(void *P) {
+    return mlir::Type::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_TYPES_H
diff --git a/third_party/mlir/include/mlir/IR/UseDefLists.h b/third_party/mlir/include/mlir/IR/UseDefLists.h
new file mode 100644
index 0000000..fe0e9e0
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/UseDefLists.h
@@ -0,0 +1,282 @@
+//===- UseDefLists.h --------------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines generic use/def list machinery and manipulation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_USEDEFLISTS_H
+#define MLIR_IR_USEDEFLISTS_H
+
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/iterator_range.h"
+
+namespace mlir {
+
+class IROperand;
+class Operation;
+template <typename OperandType> class ValueUseIterator;
+template <typename OperandType> class ValueUserIterator;
+
+class IRObjectWithUseList {
+public:
+  ~IRObjectWithUseList() {
+    assert(use_empty() && "Cannot destroy a value that still has uses!");
+  }
+
+  /// Returns true if this value has no uses.
+  bool use_empty() const { return firstUse == nullptr; }
+
+  /// Returns true if this value has exactly one use.
+  inline bool hasOneUse() const;
+
+  using use_iterator = ValueUseIterator<IROperand>;
+  using use_range = llvm::iterator_range<use_iterator>;
+
+  inline use_iterator use_begin() const;
+  inline use_iterator use_end() const;
+
+  /// Returns a range of all uses, which is useful for iterating over all uses.
+  inline use_range getUses() const;
+
+  using user_iterator = ValueUserIterator<IROperand>;
+  using user_range = llvm::iterator_range<user_iterator>;
+
+  inline user_iterator user_begin() const;
+  inline user_iterator user_end() const;
+
+  /// Returns a range of all users.
+  inline user_range getUsers() const;
+
+  /// Replace all uses of 'this' value with the new value, updating anything in
+  /// the IR that uses 'this' to use the other value instead.  When this returns
+  /// there are zero uses of 'this'.
+  void replaceAllUsesWith(IRObjectWithUseList *newValue);
+
+  /// Drop all uses of this object from their respective owners.
+  void dropAllUses();
+
+protected:
+  IRObjectWithUseList() {}
+
+  /// Return the first IROperand that is using this value, for use by custom
+  /// use/def iterators.
+  IROperand *getFirstUse() { return firstUse; }
+  const IROperand *getFirstUse() const { return firstUse; }
+
+private:
+  friend class IROperand;
+  IROperand *firstUse = nullptr;
+};
+
+/// A reference to a value, suitable for use as an operand of an operation.
+class IROperand {
+public:
+  IROperand(Operation *owner) : owner(owner) {}
+  IROperand(Operation *owner, IRObjectWithUseList *value)
+      : value(value), owner(owner) {
+    insertIntoCurrent();
+  }
+
+  /// Return the current value being used by this operand.
+  IRObjectWithUseList *get() const { return value; }
+
+  /// Set the current value being used by this operand.
+  void set(IRObjectWithUseList *newValue) {
+    // It isn't worth optimizing for the case of switching operands on a single
+    // value.
+    removeFromCurrent();
+    value = newValue;
+    insertIntoCurrent();
+  }
+
+  /// Return the owner of this operand.
+  Operation *getOwner() { return owner; }
+  Operation *getOwner() const { return owner; }
+
+  /// \brief Remove this use of the operand.
+  void drop() {
+    removeFromCurrent();
+    value = nullptr;
+    nextUse = nullptr;
+    back = nullptr;
+  }
+
+  ~IROperand() { removeFromCurrent(); }
+
+  /// Return the next operand on the use-list of the value we are referring to.
+  /// This should generally only be used by the internal implementation details
+  /// of the SSA machinery.
+  IROperand *getNextOperandUsingThisValue() { return nextUse; }
+
+  /// We support a move constructor so IROperand's can be in vectors, but this
+  /// shouldn't be used by general clients.
+  IROperand(IROperand &&other) : owner(other.owner) {
+    *this = std::move(other);
+  }
+  IROperand &operator=(IROperand &&other) {
+    removeFromCurrent();
+    other.removeFromCurrent();
+    value = other.value;
+    other.value = nullptr;
+    other.back = nullptr;
+    nextUse = nullptr;
+    back = nullptr;
+    insertIntoCurrent();
+    return *this;
+  }
+
+private:
+  /// The value used as this operand.  This can be null when in a
+  /// "dropAllUses" state.
+  IRObjectWithUseList *value = nullptr;
+
+  /// The next operand in the use-chain.
+  IROperand *nextUse = nullptr;
+
+  /// This points to the previous link in the use-chain.
+  IROperand **back = nullptr;
+
+  /// The operation owner of this operand.
+  Operation *const owner;
+
+  /// Operands are not copyable or assignable.
+  IROperand(const IROperand &use) = delete;
+  IROperand &operator=(const IROperand &use) = delete;
+
+  void removeFromCurrent() {
+    if (!back)
+      return;
+    *back = nextUse;
+    if (nextUse)
+      nextUse->back = back;
+  }
+
+  void insertIntoCurrent() {
+    back = &value->firstUse;
+    nextUse = value->firstUse;
+    if (nextUse)
+      nextUse->back = &nextUse;
+    value->firstUse = this;
+  }
+};
+
+/// A reference to a value, suitable for use as an operand of an operation,
+/// operation, etc.  IRValueTy is the root type to use for values this tracks,
+/// and SSAUserTy is the type that will contain operands.
+template <typename IRValueTy> class IROperandImpl : public IROperand {
+public:
+  IROperandImpl(Operation *owner) : IROperand(owner) {}
+  IROperandImpl(Operation *owner, IRValueTy *value) : IROperand(owner, value) {}
+
+  /// Return the current value being used by this operand.
+  IRValueTy *get() { return (IRValueTy *)IROperand::get(); }
+
+  /// Set the current value being used by this operand.
+  void set(IRValueTy *newValue) { IROperand::set(newValue); }
+
+  /// Return which operand this is in the operand list of the User.
+  unsigned getOperandNumber();
+};
+
+/// An iterator over all uses of a ValueBase.
+template <typename OperandType>
+class ValueUseIterator
+    : public std::iterator<std::forward_iterator_tag, OperandType> {
+public:
+  ValueUseIterator() = default;
+  explicit ValueUseIterator(OperandType *current) : current(current) {}
+  OperandType *operator->() const { return current; }
+  OperandType &operator*() const { return *current; }
+
+  Operation *getUser() const { return current->getOwner(); }
+
+  ValueUseIterator &operator++() {
+    assert(current && "incrementing past end()!");
+    current = (OperandType *)current->getNextOperandUsingThisValue();
+    return *this;
+  }
+
+  ValueUseIterator operator++(int unused) {
+    ValueUseIterator copy = *this;
+    ++*this;
+    return copy;
+  }
+
+  friend bool operator==(ValueUseIterator lhs, ValueUseIterator rhs) {
+    return lhs.current == rhs.current;
+  }
+
+  friend bool operator!=(ValueUseIterator lhs, ValueUseIterator rhs) {
+    return !(lhs == rhs);
+  }
+
+private:
+  OperandType *current;
+};
+
+inline auto IRObjectWithUseList::use_begin() const -> use_iterator {
+  return use_iterator(firstUse);
+}
+
+inline auto IRObjectWithUseList::use_end() const -> use_iterator {
+  return use_iterator(nullptr);
+}
+
+inline auto IRObjectWithUseList::getUses() const -> use_range {
+  return {use_begin(), use_end()};
+}
+
+/// Returns true if this value has exactly one use.
+inline bool IRObjectWithUseList::hasOneUse() const {
+  return firstUse && firstUse->getNextOperandUsingThisValue() == nullptr;
+}
+
+/// An iterator over all users of a ValueBase.
+template <typename OperandType>
+class ValueUserIterator final
+    : public llvm::mapped_iterator<ValueUseIterator<OperandType>,
+                                   Operation *(*)(OperandType &)> {
+  static Operation *unwrap(OperandType &value) { return value.getOwner(); }
+
+public:
+  using pointer = Operation *;
+  using reference = Operation *;
+
+  /// Initializes the result type iterator to the specified result iterator.
+  ValueUserIterator(ValueUseIterator<OperandType> it)
+      : llvm::mapped_iterator<ValueUseIterator<OperandType>,
+                              Operation *(*)(OperandType &)>(it, &unwrap) {}
+  Operation *operator->() { return **this; }
+};
+
+inline auto IRObjectWithUseList::user_begin() const -> user_iterator {
+  return user_iterator(use_begin());
+}
+
+inline auto IRObjectWithUseList::user_end() const -> user_iterator {
+  return user_iterator(use_end());
+}
+
+inline auto IRObjectWithUseList::getUsers() const -> user_range {
+  return {user_begin(), user_end()};
+}
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/IR/Value.h b/third_party/mlir/include/mlir/IR/Value.h
new file mode 100644
index 0000000..110c74f
--- /dev/null
+++ b/third_party/mlir/include/mlir/IR/Value.h
@@ -0,0 +1,166 @@
+//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines generic Value type and manipulation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_VALUE_H
+#define MLIR_IR_VALUE_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/IR/UseDefLists.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Block;
+class Operation;
+class Region;
+class Value;
+
+/// Operands contain a Value.
+using OpOperand = IROperandImpl<Value>;
+
+/// This is the common base class for all SSA values in the MLIR system,
+/// representing a computable value that has a type and a set of users.
+///
+class Value : public IRObjectWithUseList {
+public:
+  /// This enumerates all of the SSA value kinds in the MLIR system.
+  enum class Kind {
+    BlockArgument, // block argument
+    OpResult,      // operation result
+  };
+
+  ~Value() {}
+
+  Kind getKind() const { return typeAndKind.getInt(); }
+
+  Type getType() const { return typeAndKind.getPointer(); }
+
+  /// Utility to get the associated MLIRContext that this value is defined in.
+  MLIRContext *getContext() const { return getType().getContext(); }
+
+  /// Mutate the type of this Value to be of the specified type.
+  ///
+  /// Note that this is an extremely dangerous operation which can create
+  /// completely invalid IR very easily.  It is strongly recommended that you
+  /// recreate IR objects with the right types instead of mutating them in
+  /// place.
+  void setType(Type newType) { typeAndKind.setPointer(newType); }
+
+  /// Replace all uses of 'this' value with the new value, updating anything in
+  /// the IR that uses 'this' to use the other value instead.  When this returns
+  /// there are zero uses of 'this'.
+  void replaceAllUsesWith(Value *newValue) {
+    IRObjectWithUseList::replaceAllUsesWith(newValue);
+  }
+
+  /// If this value is the result of an operation, return the operation that
+  /// defines it.
+  Operation *getDefiningOp();
+
+  /// If this value is the result of an operation, use it as a location,
+  /// otherwise return an unknown location.
+  Location getLoc();
+
+  /// Return the Region in which this Value is defined.
+  Region *getParentRegion();
+
+  using use_iterator = ValueUseIterator<OpOperand>;
+  using use_range = llvm::iterator_range<use_iterator>;
+
+  inline use_iterator use_begin();
+  inline use_iterator use_end();
+
+  /// Returns a range of all uses, which is useful for iterating over all uses.
+  inline use_range getUses();
+
+  void print(raw_ostream &os);
+  void dump();
+
+protected:
+  Value(Kind kind, Type type) : typeAndKind(type, kind) {}
+
+private:
+  llvm::PointerIntPair<Type, 1, Kind> typeAndKind;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, Value &value) {
+  value.print(os);
+  return os;
+}
+
+// Utility functions for iterating through Value uses.
+inline auto Value::use_begin() -> use_iterator {
+  return use_iterator((OpOperand *)getFirstUse());
+}
+
+inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); }
+
+inline auto Value::getUses() -> llvm::iterator_range<use_iterator> {
+  return {use_begin(), use_end()};
+}
+
+/// Block arguments are values.
+class BlockArgument : public Value {
+public:
+  static bool classof(const Value *value) {
+    return const_cast<Value *>(value)->getKind() == Kind::BlockArgument;
+  }
+
+  Block *getOwner() { return owner; }
+
+  /// Returns the number of this argument.
+  unsigned getArgNumber();
+
+private:
+  friend class Block; // For access to private constructor.
+  BlockArgument(Type type, Block *owner)
+      : Value(Value::Kind::BlockArgument, type), owner(owner) {}
+
+  /// The owner of this operand.
+  /// TODO: can encode this more efficiently to avoid the space hit of this
+  /// through bitpacking shenanigans.
+  Block *const owner;
+};
+
+/// This is a value defined by a result of an operation.
+class OpResult : public Value {
+public:
+  OpResult(Type type, Operation *owner)
+      : Value(Value::Kind::OpResult, type), owner(owner) {}
+
+  static bool classof(const Value *value) {
+    return const_cast<Value *>(value)->getKind() == Kind::OpResult;
+  }
+
+  Operation *getOwner() { return owner; }
+
+  /// Returns the number of this result.
+  unsigned getResultNumber();
+
+private:
+  /// The owner of this operand.
+  /// TODO: can encode this more efficiently to avoid the space hit of this
+  /// through bitpacking shenanigans.
+  Operation *const owner;
+};
+
+} // namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/LLVMIR/CMakeLists.txt b/third_party/mlir/include/mlir/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000..1d7d06b
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
+mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
+mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
+mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRLLVMOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
+mlir_tablegen(NVVMOps.h.inc -gen-op-decls)
+mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRNVVMOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
+mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMConversionsIncGen)
+set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
+mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRNVVMConversionsIncGen)
diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h b/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h
new file mode 100644
index 0000000..00f5be4d
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h
@@ -0,0 +1,180 @@
+//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and
+// LLVM type system.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMDIALECT_H_
+#define MLIR_TARGET_LLVMDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+
+#include "mlir/LLVMIR/LLVMOpsEnums.h.inc"
+
+namespace llvm {
+class Type;
+class LLVMContext;
+} // end namespace llvm
+
+namespace mlir {
+namespace LLVM {
+class LLVMDialect;
+
+namespace detail {
+struct LLVMTypeStorage;
+struct LLVMDialectImpl;
+} // namespace detail
+
+class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
+                                             detail::LLVMTypeStorage> {
+public:
+  enum Kind {
+    LLVM_TYPE = FIRST_LLVM_TYPE,
+  };
+
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
+
+  LLVMDialect &getDialect();
+  llvm::Type *getUnderlyingType() const;
+
+  /// Array type utilities.
+  LLVMType getArrayElementType();
+  unsigned getArrayNumElements();
+
+  /// Vector type utilities.
+  LLVMType getVectorElementType();
+
+  /// Function type utilities.
+  LLVMType getFunctionParamType(unsigned argIdx);
+  unsigned getFunctionNumParams();
+  LLVMType getFunctionResultType();
+
+  /// Pointer type utilities.
+  LLVMType getPointerTo(unsigned addrSpace = 0);
+  LLVMType getPointerElementTy();
+
+  /// Struct type utilities.
+  LLVMType getStructElementType(unsigned i);
+
+  /// Utilities used to generate floating point types.
+  static LLVMType getDoubleTy(LLVMDialect *dialect);
+  static LLVMType getFloatTy(LLVMDialect *dialect);
+  static LLVMType getHalfTy(LLVMDialect *dialect);
+
+  /// Utilities used to generate integer types.
+  static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
+  static LLVMType getInt1Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/1);
+  }
+  static LLVMType getInt8Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/8);
+  }
+  static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
+    return getInt8Ty(dialect).getPointerTo();
+  }
+  static LLVMType getInt16Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/16);
+  }
+  static LLVMType getInt32Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/32);
+  }
+  static LLVMType getInt64Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/64);
+  }
+
+  /// Utilities used to generate other miscellaneous types.
+  static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
+  static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+                                bool isVarArg);
+  static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
+    return getFunctionTy(result, llvm::None, isVarArg);
+  }
+  static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+                              bool isPacked = false);
+  static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
+    return getStructTy(dialect, llvm::None, isPacked);
+  }
+  template <typename... Args>
+  static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
+                                 LLVMType>::type
+  getStructTy(LLVMType elt1, Args... elts) {
+    SmallVector<LLVMType, 8> fields({elt1, elts...});
+    return getStructTy(&elt1.getDialect(), fields);
+  }
+  static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
+  static LLVMType getVoidTy(LLVMDialect *dialect);
+
+private:
+  friend LLVMDialect;
+
+  /// Get an LLVMType with a pre-existing llvm type.
+  static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
+
+  /// Get an LLVMType with an llvm type that may cause changes to the underlying
+  /// llvm context when constructed.
+  static LLVMType getLocked(LLVMDialect *dialect,
+                            llvm::function_ref<llvm::Type *()> typeBuilder);
+};
+
+///// Ops /////
+#define GET_OP_CLASSES
+#include "mlir/LLVMIR/LLVMOps.h.inc"
+
+class LLVMDialect : public Dialect {
+public:
+  explicit LLVMDialect(MLIRContext *context);
+  ~LLVMDialect();
+  static StringRef getDialectNamespace() { return "llvm"; }
+
+  llvm::LLVMContext &getLLVMContext();
+  llvm::Module &getLLVMModule();
+
+  /// Parse a type registered to this dialect.
+  Type parseType(StringRef tyData, Location loc) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, raw_ostream &os) const override;
+
+  /// Verify a region argument attribute registered to this dialect.
+  /// Returns failure if the verification failed, success otherwise.
+  LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
+                                         unsigned argIdx,
+                                         NamedAttribute argAttr) override;
+
+private:
+  friend LLVMType;
+
+  std::unique_ptr<detail::LLVMDialectImpl> impl;
+};
+
+} // end namespace LLVM
+} // end namespace mlir
+
+#endif // MLIR_TARGET_LLVMDIALECT_H_
diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMOpBase.td b/third_party/mlir/include/mlir/LLVMIR/LLVMOpBase.td
new file mode 100644
index 0000000..a68cdbf
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/LLVMOpBase.td
@@ -0,0 +1,59 @@
+//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains shared definitions for the LLVM IR dialect and its
+// subdialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LLVMIR_OP_BASE
+#else
+#define LLVMIR_OP_BASE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def LLVM_Dialect : Dialect {
+  let name = "llvm";
+  let cppNamespace = "LLVM";
+}
+
+// LLVM IR type wrapped in MLIR.
+def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
+                     "LLVM dialect type">;
+
+// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
+// used to translate to LLVM IR proper.
+class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :
+    Op<dialect, mnemonic, traits> {
+  // A pattern for constructing the LLVM IR Instruction (or other Value) that
+  // corresponds to this op.  This pattern can use `builder` to refer to an
+  // `llvm::IRBuilder<>` instance, $-names of arguments and results and the
+  // following special variable names:
+  //   - $_resultType - substituted with the LLVM IR type of the result;
+  //   - $_numOperands - substituted with the number of operands (including
+  //                     the variadic ones);
+  //   - $_hasResult - substituted with a check that a variadic-result op does
+  //                   have a result (LLVM ops can have 0 or 1 result);
+  //   - $_location - mlir::Location object of the instruction.
+  // Additionally, `$$` can be used to produce the dollar character.
+  string llvmBuilder = "";
+}
+
+#endif  // LLVMIR_OP_BASE
diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td
new file mode 100644
index 0000000..b626836
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td
@@ -0,0 +1,475 @@
+//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the LLVM IR operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LLVMIR_OPS
+#else
+#define LLVMIR_OPS
+
+include "mlir/LLVMIR/LLVMOpBase.td"
+
+// Base class for LLVM operations.  All operations get an "llvm." prefix in
+// their name automatically.  LLVM operations have either zero or one result,
+// this class is specialized below for both cases and should not be used
+// directly.
+class LLVM_Op<string mnemonic, list<OpTrait> traits = []> :
+    LLVM_OpBase<LLVM_Dialect, mnemonic, traits> {
+}
+
+class LLVM_Builder<string builder> {
+  string llvmBuilder = builder;
+}
+
+def LLVM_OneResultOpBuilder : OpBuilder<
+  "Builder *, OperationState *result, Type resultType, "
+  "ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
+  [{
+    if (resultType) result->addTypes(resultType);
+    result->addOperands(operands);
+    for (auto namedAttr : attributes) {
+      result->addAttribute(namedAttr.first, namedAttr.second);
+    }
+  }]>;
+
+def LLVM_ZeroResultOpBuilder : OpBuilder<
+  "Builder *, OperationState *result, ArrayRef<Value *> operands, "
+  "ArrayRef<NamedAttribute> attributes = {}",
+  [{
+    result->addOperands(operands);
+    for (auto namedAttr : attributes) {
+      result->addAttribute(namedAttr.first, namedAttr.second);
+    }
+  }]>;
+
+class LLVM_TwoBuilders<OpBuilder b1, OpBuilder b2> {
+  list<OpBuilder> builders = [b1, b2];
+}
+
+// Base class for LLVM operations with one result.
+class LLVM_OneResultOp<string mnemonic, list<OpTrait> traits = []> :
+    LLVM_Op<mnemonic, traits>, Results<(outs LLVM_Type:$res)> {
+  let builders = [LLVM_OneResultOpBuilder];
+}
+
+// Compatibility builder that takes an instance of wrapped llvm::VoidType
+// to indicate no result.
+def LLVM_VoidResultTypeOpBuilder : OpBuilder<
+  "Builder *builder, OperationState *result, Type resultType, "
+  "ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
+  [{
+    auto llvmType = resultType.dyn_cast<LLVM::LLVMType>(); (void)llvmType;
+    assert(llvmType && "result must be an LLVM type");
+    assert(llvmType.getUnderlyingType() &&
+            llvmType.getUnderlyingType()->isVoidTy() &&
+            "for zero-result operands, only 'void' is accepted as result type");
+    build(builder, result, operands, attributes);
+  }]>;
+
+// Base class for LLVM operations with zero results.
+class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
+    LLVM_Op<mnemonic, traits>, Results<(outs)>,
+    LLVM_TwoBuilders<LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder>;
+
+// Base class for LLVM terminator operations.  All terminator operations have
+// zero results and an optional list of successors.
+class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
+    LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
+    Arguments<(ins Variadic<LLVM_Type>:$args)>, Results<(outs)> {
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, "
+    "ArrayRef<Value *> properOperands, "
+    "ArrayRef<Block *> destinations, "
+    "ArrayRef<ArrayRef<Value *>> operands = {}, "
+    "ArrayRef<NamedAttribute> attributes = {}",
+    [{
+      result->addOperands(properOperands);
+      for (auto kvp : llvm::zip(destinations, operands)) {
+        result->addSuccessor(std::get<0>(kvp), std::get<1>(kvp));
+      }
+      for (auto namedAttr : attributes) {
+        result->addAttribute(namedAttr.first, namedAttr.second);
+      }
+    }]
+  >];
+}
+
+// Class for arithmetic binary operations.
+class LLVM_ArithmeticOp<string mnemonic, string builderFunc,
+                        list<OpTrait> traits = []> :
+    LLVM_OneResultOp<mnemonic,
+           !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
+    Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>,
+    LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
+  let parser = [{ return impl::parseBinaryOp(parser, result); }];
+  let printer = [{ mlir::impl::printBinaryOp(this->getOperation(), p); }];
+}
+
+// Integer binary operations.
+def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>;
+def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">;
+def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>;
+def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">;
+def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">;
+def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">;
+def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">;
+def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">;
+def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">;
+def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">;
+
+// Predicate for integer comparisons.
+def ICmpPredicateEQ  : I64EnumAttrCase<"eq", 0>;
+def ICmpPredicateNE  : I64EnumAttrCase<"ne", 1>;
+def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>;
+def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>;
+def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>;
+def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>;
+def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>;
+def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>;
+def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>;
+def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>;
+def ICmpPredicate : I64EnumAttr<
+    "ICmpPredicate",
+    "llvm.icmp comparison predicate",
+    [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE,
+     ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE,
+     ICmpPredicateUGT, ICmpPredicateUGE]> {
+  let cppNamespace = "mlir::LLVM";
+
+  let returnType = "ICmpPredicate";
+  let convertFromStorage =
+      "static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
+}
+
+// Other integer operations.
+def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
+                  Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs,
+                             LLVM_Type:$rhs)> {
+  let llvmBuilder = [{
+    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
+  }];
+  let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
+  let printer = [{ printICmpOp(p, *this); }];
+}
+
+// Predicate for float comparisons
+def FCmpPredicateFALSE  : I64EnumAttrCase<"_false", 0>;
+def FCmpPredicateOEQ    : I64EnumAttrCase<"oeq", 1>;
+def FCmpPredicateOGT    : I64EnumAttrCase<"ogt", 2>;
+def FCmpPredicateOGE    : I64EnumAttrCase<"oge", 3>;
+def FCmpPredicateOLT    : I64EnumAttrCase<"olt", 4>;
+def FCmpPredicateOLE    : I64EnumAttrCase<"ole", 5>;
+def FCmpPredicateONE    : I64EnumAttrCase<"one", 6>;
+def FCmpPredicateORD    : I64EnumAttrCase<"ord", 7>;
+def FCmpPredicateUEQ    : I64EnumAttrCase<"ueq", 8>;
+def FCmpPredicateUGT    : I64EnumAttrCase<"ugt", 9>;
+def FCmpPredicateUGE    : I64EnumAttrCase<"uge", 10>;
+def FCmpPredicateULT    : I64EnumAttrCase<"ult", 11>;
+def FCmpPredicateULE    : I64EnumAttrCase<"ule", 12>;
+def FCmpPredicateUNE    : I64EnumAttrCase<"une", 13>;
+def FCmpPredicateUNO    : I64EnumAttrCase<"uno", 14>;
+def FCmpPredicateTRUE   : I64EnumAttrCase<"_true", 15>;
+
+def FCmpPredicate : I64EnumAttr<
+    "FCmpPredicate",
+    "llvm.fcmp comparison predicate",
+    [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE,
+     FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD,
+     FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
+     FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE 
+    ]> {
+  let cppNamespace = "mlir::LLVM";
+
+  let returnType = "FCmpPredicate";
+  let convertFromStorage =
+      "static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
+}
+
+// Other integer operations.
+def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
+                  Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs,
+                             LLVM_Type:$rhs)> {
+  let llvmBuilder = [{
+    $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
+  }];
+  let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
+  let printer = [{ printFCmpOp(p, *this); }];
+}
+
+// Floating point binary operations.
+def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">;
+def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">;
+def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">;
+def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">;
+def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
+
+// Memory-related operations.
+def LLVM_AllocaOp : LLVM_OneResultOp<"alloca">,
+                    Arguments<(ins LLVM_Type:$arraySize)> {
+  string llvmBuilder = [{
+    $res = builder.CreateAlloca($_resultType->getPointerElementType(),
+                                $arraySize);
+  }];
+  let parser = [{ return parseAllocaOp(parser, result); }];
+  let printer = [{ printAllocaOp(p, *this); }];
+}
+def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
+                 Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
+                 LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> {
+  let parser = [{ return parseGEPOp(parser, result); }];
+  let printer = [{ printGEPOp(p, *this); }];
+}
+def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
+                  LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
+  let parser = [{ return parseLoadOp(parser, result); }];
+  let printer = [{ printLoadOp(p, *this); }];
+}
+def LLVM_StoreOp : LLVM_ZeroResultOp<"store">,
+                   Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>,
+                   LLVM_Builder<"builder.CreateStore($value, $addr);"> {
+  let parser = [{ return parseStoreOp(parser, result); }];
+  let printer = [{ printStoreOp(p, *this); }];
+}
+
+// Casts.
+class LLVM_CastOp<string mnemonic, string builderFunc,
+                  list<OpTrait> traits = []> :
+    LLVM_OneResultOp<mnemonic,
+           !listconcat([NoSideEffect], traits)>,
+    Arguments<(ins LLVM_Type:$arg)>,
+    LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> {
+  let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
+  let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+}
+def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">;
+def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">;
+def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">;
+def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">;
+def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">;
+def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">;
+def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">;
+
+// Call-related operations.
+def LLVM_CallOp : LLVM_Op<"call">,
+                  Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
+                             // TODO(b/133216756): fix test failure and
+                             // change to LLVM_Type
+                             Variadic<AnyType>)>,
+                  Results<(outs Variadic<LLVM_Type>)>,
+                  LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
+                                   LLVM_ZeroResultOpBuilder> {
+  let verifier = [{
+    if (getNumResults() > 1)
+      return emitOpError("must have 0 or 1 result");
+    return success();
+  }];
+  let parser = [{ return parseCallOp(parser, result); }];
+  let printer = [{ printCallOp(p, *this); }];
+}
+def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>,
+                          Arguments<(ins LLVM_Type:$vector,
+                                     LLVM_Type:$position)> {
+  string llvmBuilder = [{
+    $res = builder.CreateExtractElement($vector, $position);
+  }];
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *vector, Value *position,"
+    "ArrayRef<NamedAttribute> attrs = {}">];
+  let parser = [{ return parseExtractElementOp(parser, result); }];
+  let printer = [{ printExtractElementOp(p, *this); }];
+}
+def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>,
+                          Arguments<(ins LLVM_Type:$container,
+                                     ArrayAttr:$position)> {
+  string llvmBuilder = [{
+    $res = builder.CreateExtractValue($container, extractPosition($position));
+  }];
+  let parser = [{ return parseExtractValueOp(parser, result); }];
+  let printer = [{ printExtractValueOp(p, *this); }];
+}
+def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>,
+                         Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value,
+                                    LLVM_Type:$position)> {
+  string llvmBuilder = [{
+    $res = builder.CreateInsertElement($vector, $value, $position);
+  }];
+  let parser = [{ return parseInsertElementOp(parser, result); }];
+  let printer = [{ printInsertElementOp(p, *this); }];
+}
+def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
+                         Arguments<(ins LLVM_Type:$container, LLVM_Type:$value,
+                                    ArrayAttr:$position)> {
+  string llvmBuilder = [{
+    $res = builder.CreateInsertValue($container, $value,
+                                     extractPosition($position));
+  }];
+  let parser = [{ return parseInsertValueOp(parser, result); }];
+  let printer = [{ printInsertValueOp(p, *this); }];
+}
+def LLVM_ShuffleVectorOp
+    : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>,
+      Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, I32ArrayAttr:$mask)>,
+      LLVM_Builder<
+      "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> {
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *v1, Value *v2, "
+    "ArrayAttr mask, ArrayRef<NamedAttribute> attrs = {}">];
+  let verifier = [{
+    auto wrappedVectorType1 = v1()->getType().cast<LLVM::LLVMType>();
+    auto wrappedVectorType2 = v2()->getType().cast<LLVM::LLVMType>();
+    if (!wrappedVectorType2.getUnderlyingType()->isVectorTy())
+      return emitOpError("expected LLVM IR Dialect vector type for operand #2");
+    if (wrappedVectorType1.getVectorElementType() !=
+        wrappedVectorType2.getVectorElementType())
+      return emitOpError("expected matching LLVM IR Dialect element types");
+    return success();
+  }];
+  let parser = [{ return parseShuffleVectorOp(parser, result); }];
+  let printer = [{ printShuffleVectorOp(p, *this); }];
+}
+
+// Misc operations.
+def LLVM_SelectOp
+    : LLVM_OneResultOp<"select", [NoSideEffect]>,
+      Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue,
+                 LLVM_Type:$falseValue)>,
+      LLVM_Builder<
+          "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
+  let parser = [{ return parseSelectOp(parser, result); }];
+  let printer = [{ printSelectOp(p, *this); }];
+}
+
+// Terminators.
+def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+  let parser = [{ return parseBrOp(parser, result); }];
+  let printer = [{ printBrOp(p, *this); }];
+}
+def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
+  let verifier = [{
+    if (getNumSuccessors() != 2)
+      return emitOpError("expected exactly two successors");
+    return success();
+  }];
+  let parser = [{ return parseCondBrOp(parser, result); }];
+  let printer = [{ printCondBrOp(p, *this); }];
+}
+def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
+  string llvmBuilder = [{
+    if ($_numOperands != 0)
+      builder.CreateRet($args[0]);
+    else
+      builder.CreateRetVoid();
+  }];
+
+  let verifier = [{
+    if (getNumOperands() > 1)
+      return emitOpError("expects at most 1 operand");
+    return success();
+  }];
+
+  let parser = [{ return parseReturnOp(parser, result); }];
+  let printer = [{ printReturnOp(p, *this); }];
+}
+def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
+  string llvmBuilder = [{ builder.CreateUnreachable(); }];
+  let parser = [{ return success(); }];
+  let printer = [{ *p << getOperationName(); }];
+}
+
+// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to
+// work correctly).
+def LLVM_GlobalOp
+    : LLVM_ZeroResultOp<"global">,
+      Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,
+                 AnyAttr:$value)> {
+
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, LLVMType type, "
+              "bool isConstant, StringRef name, Attribute value, "
+              "ArrayRef<NamedAttribute> attrs = {}">
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the LLVM type of the global.
+    LLVMType getType() {
+      return type().cast<LLVMType>();
+    }
+  }];
+
+  let printer = "printGlobalOp(p, *this);";
+  let parser = "return parseGlobalOp(parser, result);";
+  let verifier = "return ::verify(*this);";
+}
+
+def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
+      [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">]> {
+  let summary = "LLVM dialect function, has wrapped LLVM IR function type";
+
+  let regions = (region AnyRegion:$body);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, StringRef name, "
+              "LLVMType type, ArrayRef<NamedAttribute> attrs, "
+              "ArrayRef<NamedAttributeList> argAttrs = {}">
+  ];
+
+  let extraClassDeclaration = [{
+    LLVMType getType() {
+      return getAttrOfType<TypeAttr>(getTypeAttrName())
+          .getValue().cast<LLVMType>();
+    }
+    bool isVarArg() {
+      return getType().getUnderlyingType()->isFunctionVarArg();
+    }
+
+    // Hook for OpTrait::FunctionLike, returns the number of function arguments.
+    // Depends on the type attribute being correct as checked by verifyType.
+    unsigned getNumFuncArguments();
+
+    // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
+    // attribute is present.  This can check for preconditions of the
+    // getNumArguments hook not failing.
+    LogicalResult verifyType();
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+  let printer = [{ printLLVMFuncOp(p, *this); }];
+  let parser = [{
+    return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true,
+                                     buildLLVMFunctionType);
+  }];
+}
+
+def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>,
+                   LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> {
+  let parser = [{ return parseUndefOp(parser, result); }];
+  let printer = [{ printUndefOp(p, *this); }];
+}
+def LLVM_ConstantOp
+    : LLVM_OneResultOp<"constant", [NoSideEffect]>,
+      Arguments<(ins AnyAttr:$value)>,
+      LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
+{
+  let parser = [{ return parseConstantOp(parser, result); }];
+  let printer = [{ printConstantOp(p, *this); }];
+}
+
+#endif // LLVMIR_OPS
diff --git a/third_party/mlir/include/mlir/LLVMIR/NVVMDialect.h b/third_party/mlir/include/mlir/LLVMIR/NVVMDialect.h
new file mode 100644
index 0000000..206f868
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/NVVMDialect.h
@@ -0,0 +1,43 @@
+//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and
+// NVVM specific extensions to the LLVM type system.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LLVMIR_NVVMDIALECT_H_
+#define MLIR_LLVMIR_NVVMDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+namespace mlir {
+namespace NVVM {
+
+///// Ops /////
+#define GET_OP_CLASSES
+#include "mlir/LLVMIR/NVVMOps.h.inc"
+
+class NVVMDialect : public Dialect {
+public:
+  explicit NVVMDialect(MLIRContext *context);
+};
+
+} // namespace NVVM
+} // namespace mlir
+
+#endif /* MLIR_LLVMIR_NVVMDIALECT_H_ */
diff --git a/third_party/mlir/include/mlir/LLVMIR/NVVMOps.td b/third_party/mlir/include/mlir/LLVMIR/NVVMOps.td
new file mode 100644
index 0000000..18be599
--- /dev/null
+++ b/third_party/mlir/include/mlir/LLVMIR/NVVMOps.td
@@ -0,0 +1,60 @@
+//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the NVVM IR operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef NVVMIR_OPS
+#else
+#define NVVMIR_OPS
+
+include "mlir/LLVMIR/LLVMOpBase.td"
+
+def NVVM_Dialect : Dialect {
+  let name = "nvvm";
+  let cppNamespace = "NVVM";
+}
+
+class NVVM_Op<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
+}
+
+class NVVM_SpecialRegisterOp<string mnemonic,
+    list<OpTrait> traits = []> :
+  NVVM_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
+  Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
+  string llvmBuilder = "$res = createIntrinsicCall(builder,"
+    # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
+  let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
+  let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }];
+}
+
+def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
+def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
+def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
+def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
+def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
+def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
+def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
+def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
+def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
+def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
+def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
+def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+
+#endif // NVVMIR_OPS
diff --git a/third_party/mlir/include/mlir/Linalg/Analysis/DependenceAnalysis.h b/third_party/mlir/include/mlir/Linalg/Analysis/DependenceAnalysis.h
new file mode 100644
index 0000000..de5a28d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/Analysis/DependenceAnalysis.h
@@ -0,0 +1,137 @@
+//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
+#define MLIR_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace linalg {
+
+class LinalgOp;
+
+/// A very primitive alias analysis which just records for each view, either:
+///   1. The base buffer, or
+///   2. The block argument view
+/// that it indexes into.
+/// This does not perform inter-block or inter-procedural analysis and assumes
+/// that different block argument views do not alias.
+class Aliases {
+public:
+  /// Returns true if v1 and v2 alias.
+  bool alias(Value *v1, Value *v2) { return find(v1) == find(v2); }
+
+private:
+  /// Returns the base buffer or block argument into which the view `v` aliases.
+  /// This lazily records the new aliases discovered while walking back the
+  /// use-def chain.
+  Value *find(Value *v);
+
+  DenseMap<Value *, Value *> aliases;
+};
+
+/// Data structure for holding a dependence graph that operates on LinalgOp and
+/// views as SSA values.
+class LinalgDependenceGraph {
+public:
+  struct LinalgOpView {
+    Operation *op;
+    Value *view;
+  };
+  struct LinalgDependenceGraphElem {
+    // dependentOpView may be either:
+    //   1. src in the case of dependencesIntoGraphs.
+    //   2. dst in the case of dependencesFromDstGraphs.
+    LinalgOpView dependentOpView;
+    // View in the op that is used to index in the graph:
+    //   1. src in the case of dependencesFromDstGraphs.
+    //   2. dst in the case of dependencesIntoGraphs.
+    Value *indexingView;
+  };
+  using LinalgDependences = llvm::SmallVector<LinalgDependenceGraphElem, 8>;
+  using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
+  using dependence_iterator = LinalgDependences::iterator;
+  using dependence_range = llvm::iterator_range<dependence_iterator>;
+
+  enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
+
+  LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
+
+  /// Returns the X such that op -> X is a dependence of type dt.
+  dependence_range getDependencesFrom(Operation *src, DependenceType dt);
+  dependence_range getDependencesFrom(LinalgOp src, DependenceType dt);
+
+  /// Returns the X such that X -> op is a dependence of type dt.
+  dependence_range getDependencesInto(Operation *dst, DependenceType dt);
+  dependence_range getDependencesInto(LinalgOp dst, DependenceType dt);
+
+  /// Returns the operations that are interleaved between `srcLinalgOp` and
+  /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
+  /// relation with `srcLinalgOp`, on any view.
+  /// Any such operation prevents reordering.
+  SmallVector<Operation *, 8> findCoveringDependences(LinalgOp srcLinalgOp,
+                                                      LinalgOp dstLinalgOp);
+
+  /// Returns the operations that are interleaved between `srcLinalgOp` and
+  /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
+  /// Dependences are restricted to views aliasing `view`.
+  SmallVector<Operation *, 8>
+  findCoveringReads(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view);
+
+  /// Returns the operations that are interleaved between `srcLinalgOp` and
+  /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
+  /// Dependences are restricted to views aliasing `view`.
+  SmallVector<Operation *, 8>
+  findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view);
+
+private:
+  // Keep dependences in both directions, this is not just a performance gain
+  // but it also reduces usage errors.
+  // Dependence information is stored as a map of:
+  //   (source operation -> LinalgDependenceGraphElem)
+  DependenceGraph dependencesFromGraphs[DependenceType::NumTypes];
+  // Reverse dependence information is stored as a map of:
+  //   (destination operation -> LinalgDependenceGraphElem)
+  DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes];
+
+  /// Analyses the aliasing views between `src` and `dst` and inserts the proper
+  /// dependences in the graph.
+  void addDependencesBetween(LinalgOp src, LinalgOp dst);
+
+  // Adds an new dependence unit in the proper graph.
+  // Uses std::pair to keep operations and view together and avoid usage errors
+  // related to src/dst and producer/consumer terminology in the context of
+  // dependences.
+  void addDependenceElem(DependenceType dt, LinalgOpView indexingOpView,
+                         LinalgOpView dependentOpView);
+
+  /// Implementation detail for findCoveringxxx.
+  SmallVector<Operation *, 8>
+  findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
+                                        LinalgOp dstLinalgOp, Value *view,
+                                        ArrayRef<DependenceType> types);
+
+  Aliases &aliases;
+  SmallVector<Operation *, 8> linalgOps;
+  DenseMap<Operation *, unsigned> linalgOpPositions;
+};
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
diff --git a/third_party/mlir/include/mlir/Linalg/CMakeLists.txt b/third_party/mlir/include/mlir/Linalg/CMakeLists.txt
new file mode 100644
index 0000000..f33061b
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/third_party/mlir/include/mlir/Linalg/IR/CMakeLists.txt b/third_party/mlir/include/mlir/Linalg/IR/CMakeLists.txt
new file mode 100644
index 0000000..b0c7266
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
+mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
+mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgBase.td b/third_party/mlir/include/mlir/Linalg/IR/LinalgBase.td
new file mode 100644
index 0000000..4ea6651
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgBase.td
@@ -0,0 +1,47 @@
+//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the definition file for base linear algebra support.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+#ifdef LINALG_BASE
+#else
+#define LINALG_BASE
+
+def Linalg_Dialect : Dialect {
+  let name = "linalg";
+}
+
+// Whether a type is a BufferType.
+def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
+def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
+
+// Whether a type is a RangeType.
+def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
+def Range : Type<LinalgIsRangeTypePred, "range">;
+
+// Whether a type is a ViewType.
+def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
+def View : Type<LinalgIsViewTypePred, "view">;
+
+#endif // LINALG_BASE
\ No newline at end of file
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
new file mode 100644
index 0000000..998d68b
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
@@ -0,0 +1,415 @@
+//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for linear algebra operations that
+// correspond to underlying library calls (e.g. BLAS).
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_LIBRARY_OPS
+#else
+#define LINALG_LIBRARY_OPS
+
+include "mlir/AffineOps/AffineOpsBase.td"
+include "mlir/Linalg/IR/LinalgBase.td"
+
+class LinalgParametricNativeOpTrait<string prop, string parameters> :
+  NativeOpTrait<"linalg::" # prop # parameters>
+{}
+
+class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
+  LinalgParametricNativeOpTrait<
+    prop,
+    !strconcat("<",
+               !cast<string>(!head(parameters)),
+               !foldl("",
+                      !tail(parameters),
+                      sum,
+                      param,
+                      sum # "," # !cast<string>(param)),
+               ">::Impl")>
+{}
+
+// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
+// to have a specified number of inputs and outputs, all passed as operands.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NInputsAndOutputs<int n_ins, int n_outs> :
+  LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+{}
+
+// The linalg `NLoopTypes` trait provides the API for ops that are known to have
+// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
+// loops.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NLoopTypes<int n_par, int n_red, int n_win> :
+LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+{}
+
+// The linalg `ViewRanks` trait the API for ops that are known to have a
+// specified list of view ranks.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class ViewRanks<list<int> ranks> :
+LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
+{}
+
+def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
+
+// Base Tablegen class for Linalg ops.
+// Linalg ops that correspond to library calls operate on linalg::View as their
+// first operands. These may be optionally followed by non-view operands
+// depending on the specific Linalg op.
+class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
+  : Op<Linalg_Dialect, mnemonic, !listconcat(props, [ViewTraits])> {
+  let parser = [{ return parseLinalgLibraryOp(parser, result); }];
+  let printer = [{ printLinalgLibraryOp(p, *this); }];
+}
+
+class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
+  : LinalgLibraryBase_Op<mnemonic, props> {
+  code libraryCallName = [{
+    std::string getLibraryCallName() {
+      return generateLibraryCallName(getOperation());
+    }
+  }];
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Concrete Linalg ops.
+////////////////////////////////////////////////////////////////////////////////
+def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
+  let description = [{
+    Copies the data in the input view into the output view.
+
+    Usage:
+      linalg.copy(%arg0, %arg1) : !linalg.view<?xf32>, !linalg.view<?xf32>
+
+    One possible lowering to loop form is:
+      %0 = linalg.dim %arg0, 0 : index
+      loop.for %i0 = %c0 to %0 step %c1 {
+        %1 = linalg.load %arg0[%i0] : !linalg.view<?xf32>
+        linalg.store %1, %arg1[%i0] : !linalg.view<?xf32>
+      }
+
+    Optionally, can take `input_permutation` and `output_permutation` attributes
+    to reorder the dimensions of the input and output views.
+
+    Usage:
+      linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j),
+                                 outputPermutation : (i, j, k) -> (k, j, i)} :
+        !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+
+    One possible lowering to loop form is:
+      %0 = linalg.dim %arg0, 0
+      %1 = linalg.dim %arg0, 1
+      %2 = linalg.dim %arg0, 2
+      loop.for %i0 = %c0 to %{{.*}} step %c1 {
+        loop.for %i1 = %c0 to %{{.*}} step %c1 {
+          loop.for %i2 = %c0 to %{{.*}} step %c1 {
+            %3 = linalg.load %arg0[%i0, %i2, %i1] : !linalg.view<?x?x?xf32>
+            linalg.store %3, %arg1[%i2, %i1, %i0] : !linalg.view<?x?x?xf32>
+
+    The views are expected to be compatible for correctness but this is not
+    enforced at the moment.
+  }];
+  let arguments = (ins
+    View,
+    View,
+    OptionalAttr<AffineMapAttr>:$inputPermutation,
+    OptionalAttr<AffineMapAttr>:$outputPermutation);
+  // TODO(ntv) this should go away once the usage of OptionalAttr triggers
+  // emission of builders with default arguments left unspecified.
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *input, Value *output", [{
+    return build(
+      builder, result, input, output, AffineMapAttr(), AffineMapAttr());
+  }]>];
+  let extraClassDeclaration = libraryCallName # [{
+    unsigned getNumParallelLoops() {
+      auto *view = *(getOperands().begin());
+      return view->getType().cast<ViewType>().getRank();
+    }
+    unsigned getNumReductionLoops() { return 0; }
+    unsigned getNumWindowLoops() { return 0; }
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
+  let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
+  let extraClassDeclaration = libraryCallName # [{
+    unsigned getNumParallelLoops() {
+      auto *view = *(getOperands().begin());
+      return view->getType().cast<ViewType>().getRank();
+    }
+    unsigned getNumReductionLoops() { return 0; }
+    unsigned getNumWindowLoops() { return 0; }
+    Value *getValue() {
+      return *(getOperands().begin() + getNumInputsAndOutputs());
+    }
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def DotOp : LinalgLibrary_Op<"dot",
+                            [NInputsAndOutputs<2, 1>,
+                             NLoopTypes<0, 1, 0>,
+                             ViewRanks<[1, 1, 0]>]> {
+  let arguments = (ins View, View, View);
+  let extraClassDeclaration = libraryCallName;
+}
+
+def MatvecOp : LinalgLibrary_Op<"matvec",
+                                  [NInputsAndOutputs<2, 1>,
+                                   NLoopTypes<1, 1, 0>,
+                                   ViewRanks<[2, 1, 1]>]> {
+  let arguments = (ins View, View, View);
+  let extraClassDeclaration = libraryCallName;
+}
+
+def MatmulOp : LinalgLibrary_Op<"matmul",
+                                  [NInputsAndOutputs<2, 1>,
+                                   NLoopTypes<2, 1, 0>,
+                                   ViewRanks<[2, 2, 2]>]> {
+  let arguments = (ins View, View, View);
+  let extraClassDeclaration = libraryCallName;
+}
+
+def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
+  let description = [{
+    Generic n-D convolution as described in the TF documentation:
+    https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
+
+    ```
+      output[b, x[0], ..., x[N-1], k] =
+      sum_{z[0], ..., z[N-1], q}
+          filter[z[0], ..., z[N-1], q, k] *
+          padded_input[b,
+                       x[0] * strides[0] + dilation_rate[0] * z[0],
+                       ...,
+                       x[N-1] * strides[N-1] + dilation_rate[N-1] * z[N-1],
+                       q]
+    ```
+  }];
+  // TODO(ntv) padding.
+  // Following the TF source of truth above, strides and dilations are integer
+  // attributes of the same rank as the number of window dimensions.
+  let arguments = (ins View:$filter, View:$input, View:$output,
+                   OptionalAttr<I64ArrayAttr>:$strides,
+                   OptionalAttr<I64ArrayAttr>:$dilations);
+  let extraClassDeclaration = libraryCallName # [{
+    // TODO(ntv) extend to support more than 1 dimensions and potentially
+    // grouping too.
+    unsigned getNumBatchDimensions() { return 1; }
+    unsigned getNumInputFeatureDimensions() { return 1; }
+    unsigned getNumOutputFeatureDimensions() { return 1; }
+
+    // Outer parallel loops are always the number of output dimensions; i.e.
+    // [ b, xs, q] in the TF notation above.
+    unsigned getNumParallelLoops() { return getOutputViewType(0).getRank(); }
+
+    // Window loops are a special kind of reduction that is neither tiled or
+    // parallelized across; i.e. [zs] in the TF notation above whose number
+    // match `xs` (i.e. 1 window loop per "image" dimension).
+    unsigned getNumWindowLoops() {
+      return getNumParallelLoops() - getNumBatchDimensions() -
+             getNumInputFeatureDimensions(); }
+
+    // Reduction loops are exactly the non-parallel, non-window loops (i.e. `q`)
+    // We distinguish between reduction loops and convolution window loops for
+    // now. That distinction may disappear in the future.
+    unsigned getNumReductionLoops() { return getNumInputFeatureDimensions(); }
+
+    int64_t getStride(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!strides().hasValue()) return 1;
+      return strides()->getValue()[i]
+        .cast<IntegerAttr>().getValue().getSExtValue();
+    }
+
+    int64_t getDilation(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!dilations().hasValue()) return 1;
+      return dilations()->getValue()[i]
+        .cast<IntegerAttr>().getValue().getSExtValue();
+    }
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GenericOp : LinalgLibraryBase_Op<"generic", []> {
+  let description = [{
+    Generic Linalg op form where the key properties of the computation are
+    specified as attributes. In pretty form, a linalg.generic op is written as:
+
+      ```
+        linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+          !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      ```
+
+    Where #trait_attributes is an alias of a dictionary attribute containing:
+      - doc [optional]: a documentation string
+      - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+        To support inplace updates in a generic fashion, the signature of the
+        function must be:
+        ```
+          fun([input views element types], [output views element types])
+            -> ([output views element types])
+        ```
+      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
+        and output view. Such AffineMapAttr specifies the mapping between the
+        loops and the indexing within each view.
+      - library_call [optional]: a StringAttr containing the name of an
+        external library function that the linalg.generic operation maps to.
+        The external library is assumed to be dynamically linked and no strong
+        compile-time guarantees are provided. In the absence of such a library
+        call, linalg.generic will always lower to loops.
+      - n_loops: a triple of I64Attr representing the number of enclosing
+        [parallel, reduction, window] loops respectively.
+      - n_views: a pair of I64Attr representing the number of input (readonly)
+        and output (readwrite) views.
+
+    Example:
+    Defining a #matmul_trait attribute in MLIR can be done as follows:
+      ```
+        func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          return %e: f32
+        }
+        #matmul_accesses = [
+          (m, n, k) -> (m, k),
+          (m, n, k) -> (k, n),
+          (m, n, k) -> (m, n)
+        ]
+        #matmul_trait = {
+          doc = "C(m, n) += A(m, k) * B(k, n)",
+          fun = @fma,
+          indexing_maps = #matmul_accesses,
+          library_call = "linalg_matmul",
+          n_views = [2, 1],
+          n_loop_types = [2, 1, 0]
+        }
+      ```
+
+    And can be reused in multiple places as:
+      ```
+        linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
+          !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      ```
+
+    This may lower to either:
+      ```
+        call @linalg_matmul(%A, %B, %C) :
+          (!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>)
+          -> ()
+      ```
+
+    or IR resembling:
+    ```
+    loop.for %m = %c0 to %M step %c1 {
+      loop.for %n = %c0 to %N step %c1 {
+        loop.for %k = %c0 to %K step %c1 {
+          %a = linalg.load %A[%m, %k] : !linalg.view<?x?xf32>
+          %b = linalg.load %B[%k, %n] : !linalg.view<?x?xf32>
+          %c = linalg.load %C[%m, %n] : !linalg.view<?x?xf32>
+          %d = call @mac(%a, %b, %c) : (f32, f32, f32) -> (f32)
+          linalg.store %d, %C[%m, %n] : !linalg.view<?x?x?xf32>
+        }
+      }
+    }
+    ```
+  }];
+  let arguments = (ins Variadic<View>:$views,
+                   AffineMapArrayAttr:$indexing_maps,
+                   I64ArrayAttr:$n_loop_types,
+                   I64ArrayAttr:$n_views,
+                   OptionalAttr<StrAttr>:$doc,
+                   OptionalAttr<SymbolRefAttr>:$fun,
+                   OptionalAttr<StrAttr>:$library_call);
+  let regions = (region AnyRegion:$region);
+  let extraClassDeclaration = [{
+    SmallVector<StringRef, 8> linalgTraitAttrNames() {
+      return SmallVector<StringRef, 8>{
+        "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
+      };
+    }
+    unsigned getNumInputs() {
+      if (!getAttr("n_views") || n_views().getValue().size() != 2)
+        return 0;
+      auto val = n_views().getValue()[0].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumOutputs() {
+      if (!getAttr("n_views") || n_views().getValue().size() != 2)
+        return 0;
+      auto val = n_views().getValue()[1].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumParallelLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumReductionLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumWindowLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumLoops() {
+      return getNumParallelLoops() + getNumReductionLoops() +
+        getNumWindowLoops();
+    }
+    FuncOp getFunction() {
+      auto moduleOp = getParentOfType<ModuleOp>();
+      return fun().hasValue() ?
+        moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
+    }
+    StringRef getLibraryCallName() {
+      return library_call().hasValue() ? library_call().getValue() : "";
+    }
+    AffineMap getIndexingMap(unsigned i) {
+      assert(i < getNumInputsAndOutputs());
+      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+    }
+    AffineMap getInputIndexingMap(unsigned i) {
+      assert(i < getNumInputs());
+      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+    }
+    AffineMap getOutputIndexingMap(unsigned i) {
+      assert(i < getNumOutputs());
+      return indexing_maps().getValue()[i + getNumInputs()]
+          .cast<AffineMapAttr>().getValue();
+    }
+  }];
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+#endif // LINALG_LIBRARY_OPS
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h
new file mode 100644
index 0000000..3187f4f
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h
@@ -0,0 +1,425 @@
+//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_LINALGOPS_H_
+#define MLIR_LINALG_LINALGOPS_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Linalg/IR/LinalgTraits.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class OperationFolder;
+
+namespace linalg {
+
+/// A linalg.LoadOp is the counterpart of load but operating on ViewType
+/// instead of MemRefType.
+///
+/// ```{.mlir}
+///    %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+/// ```
+class LoadOp
+    : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.load"; }
+  static void build(Builder *b, OperationState *result, Value *view,
+                    ArrayRef<Value *> indices = {});
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+  Value *getView() { return getOperand(0); }
+  Operation::operand_range getIndices() {
+    return {operand_begin() + 1, operand_end()};
+  }
+};
+
+/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
+/// that represent the min, max and step values of the range.
+///
+/// ```{.mlir}
+///    %3 = linalg.range %0:%1:%2 : !linalg.range
+/// ```
+class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
+                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.range"; }
+  static void build(Builder *b, OperationState *result, Value *min, Value *max,
+                    Value *step);
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  Value *min() { return getOperand(0); }
+  Value *max() { return getOperand(1); }
+  Value *step() { return getOperand(2); }
+};
+
+/// The "linalg.slice" op produces a linalg.view which is a subview of a given
+/// base view. This allows defining a subregion within the underlying buffer to
+/// operate on only a subset of the buffer.
+///
+/// A "linalg.slice" op takes a base view and a variadic number of indexings and
+/// produces a linalg.view of the same elemental type as the buffer. An indexing
+/// is either:
+///   1. a linalg.range, in which case it does not reduce the rank of the parent
+///      view.
+///   2. an index, in which case it reduces the rank of the parent view by one.
+///
+/// The parent view must be a base view (i.e. either a function argument or has
+/// been produced by a linalg.view op). In other words, chains of
+/// linalg.slice operations cannot be constructed in the IR. This defines away
+/// problems related to keeping track of which dimensions of the base view have
+/// been rank-reduced.
+///
+/// Examples:
+///   1. rank-preserving slice:
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
+///    !linalg.range, !linalg.view<?x?xf32>
+/// ```
+///
+///   2. rank-reducing slice (from 2-D to 1-D):
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
+///    !linalg.range, !linalg.view<?xf32>
+/// ```
+///
+///   3. rank-reducing slice (from 2-D to 0-D):
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
+///    !linalg.view<f32>
+/// ```
+class ViewOp;
+class SliceOp : public Op<SliceOp, OpTrait::VariadicOperands,
+                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+  enum { FirstIndexingOperand = 1 };
+
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.slice"; }
+  static void build(Builder *b, OperationState *result, Value *base,
+                    llvm::ArrayRef<Value *> indexings);
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  Type getElementType() { return getViewType().getElementType(); }
+  ViewType getViewType() { return getType().cast<ViewType>(); }
+  Value *getBaseView() { return getOperand(0); }
+  ViewOp getBaseViewOp();
+  ViewType getBaseViewType();
+  unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
+  // Get the underlying indexing at a given rank.
+  Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
+  // Get all the indexings in this view.
+  Operation::operand_range getIndexings() {
+    return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
+  }
+  // Get the subset of indexings that are of RangeType.
+  SmallVector<Value *, 8> getRanges();
+};
+
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
+///
+/// ```{.mlir}
+///    linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+/// ```
+class StoreOp
+    : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.store"; }
+  static void build(Builder *b, OperationState *result, Value *valueToStore,
+                    Value *view, ArrayRef<Value *> indices = {});
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+  Value *getValueToStore() { return getOperand(0); }
+  Value *getView() { return getOperand(1); }
+  Operation::operand_range getIndices() {
+    return {operand_begin() + 2, operand_end()};
+  }
+};
+
+/// Returns the name mangled library call name to disambiguate between different
+/// overloads at the C level. The name mangling scheme is basic and uses MLIR
+/// type names:
+///   1. form a string which is the concatenation of the linalg op name with all
+///      the operand type names, separate by underscores;
+///   2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type.
+/// Assumes `op` is a LinalgOp.
+///
+/// Examples:
+///
+/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
+///   name mangles into `linalg_fill_viewf32_f32_impl`
+///
+/// 2. linalg.dot(%A, %B, %C) :
+///      !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+///   name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
+///
+/// 3. linalg.matmul(...) :
+///      !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+///   name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
+std::string generateLibraryCallName(Operation *op);
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
+
+/// Returns the list of maps that map loops to operands of a Linalg op.
+/// The i-th affine map identifies loop indices to subscripts that are used when
+/// accessing the i-th operand.
+/// For instance, a matmul that can be written in index notation as:
+/// `A(i, k) * B(k, j) -> C(i, j)` will have the following, ordered, list of
+/// affine maps:
+///
+/// ```{.mlir}
+///    (
+///      (i, j, k) -> (i, k),
+///      (i, j, k) -> (k, j),
+///      (i, j, k) -> (i, j)
+///    )
+/// ```
+///
+/// Only permutation maps are currently supported.
+SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
+
+/// A LinalgOp behaves like a base class for the Linalg operations that are
+/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
+/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
+/// method and dispatches to the appropriate LinalgLibraryOp.
+/// This allows writing generic passes, like tiling, for all current and future
+/// LinalgOps without requiring templating and dispatch in multiple places.
+class LinalgOp : public Op<LinalgOp> {
+public:
+  using Op::Op;
+
+  LinalgOp(Operation *op) : Op<LinalgOp>(op) {
+    impl = ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+        >::dispatch(op);
+  }
+
+  static bool classof(Operation *op) {
+    return ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+        >::classof(op);
+  }
+
+  unsigned getNumParallelLoops() {
+    return impl->getNumParallelLoops(getOperation());
+  }
+  unsigned getNumReductionLoops() {
+    return impl->getNumReductionLoops(getOperation());
+  }
+  unsigned getNumWindowLoops() {
+    return impl->getNumWindowLoops(getOperation());
+  }
+  unsigned getNumLoops() {
+    return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
+  }
+  unsigned getNumInputs() { return impl->getNumInputs(getOperation()); }
+  unsigned getNumOutputs() { return impl->getNumOutputs(getOperation()); }
+  unsigned getNumInputsAndOutputs() {
+    return impl->getNumInputsAndOutputs(getOperation());
+  }
+  Value *getInput(unsigned i) { return impl->getInput(getOperation(), i); }
+  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
+    return impl->getIndexOfInput(getOperation(), view);
+  }
+  ViewType getInputViewType(unsigned i) {
+    return impl->getInputViewType(getOperation(), i);
+  }
+  Operation::operand_range getInputs() {
+    return impl->getInputs(getOperation());
+  }
+  Value *getOutput(unsigned i) { return impl->getOutput(getOperation(), i); }
+  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
+    return impl->getIndexOfOutput(getOperation(), view);
+  }
+  ViewType getOutputViewType(unsigned i) {
+    return impl->getOutputViewType(getOperation(), i);
+  }
+  Operation::operand_range getOutputs() {
+    return impl->getOutputs(getOperation());
+  }
+  Operation::operand_range getInputsAndOutputs() {
+    return impl->getInputsAndOutputs(getOperation());
+  }
+  LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
+                  ArrayRef<NamedAttribute> attributes) {
+    return LinalgOp(impl->create(builder, loc, operands, attributes));
+  }
+
+private:
+  struct Concept {
+    virtual ~Concept() = default;
+    virtual unsigned getNumInputs(Operation *op) = 0;
+    virtual unsigned getNumOutputs(Operation *op) = 0;
+    virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
+    virtual unsigned getNumParallelLoops(Operation *op) = 0;
+    virtual unsigned getNumReductionLoops(Operation *op) = 0;
+    virtual unsigned getNumWindowLoops(Operation *op) = 0;
+    virtual Value *getInput(Operation *op, unsigned i) = 0;
+    virtual llvm::Optional<unsigned> getIndexOfInput(Operation *op,
+                                                     Value *view) = 0;
+    virtual ViewType getInputViewType(Operation *op, unsigned i) = 0;
+    virtual Operation::operand_range getInputs(Operation *op) = 0;
+    virtual Value *getOutput(Operation *op, unsigned i) = 0;
+    virtual llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
+                                                      Value *view) = 0;
+    virtual ViewType getOutputViewType(Operation *op, unsigned i) = 0;
+    virtual Operation::operand_range getOutputs(Operation *op) = 0;
+    virtual Operation::operand_range getInputsAndOutputs(Operation *op) = 0;
+    virtual Operation *create(OpBuilder &builder, Location loc,
+                              ArrayRef<Value *> operands,
+                              ArrayRef<NamedAttribute> attributes) = 0;
+  };
+
+  /// The implementation is inspired from Sean Parent's concept-based
+  /// polymorphism. A key difference is that the set of classes erased is
+  /// statically known, which alleviates the need for using dynamic memory
+  /// allocation.
+  /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
+  /// virtual table and generate a singleton object for each instantiation of
+  /// this class.
+  /// We pay the cost of initialization once on construction (find which class
+  /// to dispatch to) and then a virtual dispatch on every call.
+  template <typename ConcreteOp> struct Model : public Concept {
+    static Model<ConcreteOp> &instance() {
+      static Model<ConcreteOp> singleton;
+      return singleton;
+    }
+    unsigned getNumInputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumInputs();
+    }
+    unsigned getNumOutputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumOutputs();
+    }
+    unsigned getNumInputsAndOutputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumInputsAndOutputs();
+    }
+    unsigned getNumParallelLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumParallelLoops();
+    }
+    unsigned getNumReductionLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumReductionLoops();
+    }
+    unsigned getNumWindowLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumWindowLoops();
+    }
+    Value *getInput(Operation *op, unsigned i) override {
+      return cast<ConcreteOp>(op).getInput(i);
+    }
+    llvm::Optional<unsigned> getIndexOfInput(Operation *op,
+                                             Value *view) override {
+      return cast<ConcreteOp>(op).getIndexOfInput(view);
+    }
+    ViewType getInputViewType(Operation *op, unsigned i) override {
+      return cast<ConcreteOp>(op).getInputViewType(i);
+    }
+    Operation::operand_range getInputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getInputs();
+    }
+    Value *getOutput(Operation *op, unsigned i) override {
+      return cast<ConcreteOp>(op).getOutput(i);
+    }
+    llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
+                                              Value *view) override {
+      return cast<ConcreteOp>(op).getIndexOfOutput(view);
+    }
+    ViewType getOutputViewType(Operation *op, unsigned i) override {
+      return cast<ConcreteOp>(op).getOutputViewType(i);
+    }
+    Operation::operand_range getOutputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getOutputs();
+    }
+    Operation::operand_range getInputsAndOutputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getInputsAndOutputs();
+    }
+    Operation *create(OpBuilder &builder, Location loc,
+                      ArrayRef<Value *> operands,
+                      ArrayRef<NamedAttribute> attributes) override {
+      return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
+                                        attributes);
+    }
+  };
+  Concept *impl;
+
+  template <typename... Types> struct ModelDispatch;
+
+  template <typename First, typename... Rest>
+  struct ModelDispatch<First, Rest...> {
+    static bool classof(Operation *op) {
+      return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
+    }
+    static Concept *dispatch(Operation *op) {
+      return isa<First>(op) ? &Model<First>::instance()
+                            : ModelDispatch<Rest...>::dispatch(op);
+    }
+  };
+
+  template <typename...> struct ModelDispatch {
+    static bool classof(Operation *op) { return false; }
+    static Concept *dispatch(Operation *op) {
+      llvm_unreachable("Invalid LinalgOp");
+    }
+  };
+};
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_LINALGOPS_H_
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.td b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.td
new file mode 100644
index 0000000..f7a07fc
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.td
@@ -0,0 +1,274 @@
+//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for linear algebra operations.
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Linalg/IR/LinalgBase.td"
+
+#ifdef LINALG_OPS
+#else
+#define LINALG_OPS
+
+// Base class for Linalg dialect ops that do not correspond to library calls.
+class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Linalg_Dialect, mnemonic, traits> {
+  // For every linalg op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+def BufferAllocOp :
+    Linalg_Op<"buffer_alloc">,
+    Arguments<(ins Variadic<Index>:$size)>,
+    Results<(outs Buffer)> {
+  let summary = "buffer allocation operation";
+  let description = [{
+    The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
+    upon which a base view can be laid out to give it indexing semantics.
+    "buffer_alloc" takes a single argument, the size of the buffer to allocate
+    (in number of elements).
+
+    ```{.mlir}
+        %0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
+    ```
+
+    The size argument may be omitted if it is statically known, in which case it
+    must be reflected in the type.
+
+    ```{.mlir}
+        %0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, BufferType bufferType", [{
+       result->types.push_back(bufferType);
+     }]
+  >];
+  let extraClassDeclaration = [{
+    BufferType getBufferType() { return getType().cast<BufferType>(); }
+    Type getElementType() { return getBufferType().getElementType(); }
+  }];
+}
+
+def BufferDeallocOp :
+    Linalg_Op<"buffer_dealloc">,
+    Arguments<(ins Buffer:$buffer)>,
+    Results<(outs)> {
+  let summary = "buffer allocation operation";
+  let description = [{
+    The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
+
+    ```{.mlir}
+        linalg.buffer_dealloc %0 : !linalg.buffer<f32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, BufferType bufferType", [{
+       result->types.push_back(bufferType);
+     }]
+  >];
+  let extraClassDeclaration = [{
+    BufferType getBufferType() {
+      return getOperand()->getType().cast<BufferType>();
+    }
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
+def BufferSizeOp :
+    Linalg_Op<"buffer_size", [NoSideEffect]>,
+    Arguments<(ins Buffer)>,
+    Results<(outs Index)> {
+  let summary = "buffer size operation";
+  let description = [{
+    The "linalg.buffer_size" operation takes a linalg.buffer and returns an
+    "index". For example:
+
+       %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
+def DimOp : Linalg_Op<"dim", [NoSideEffect]>,
+    Arguments<(ins View:$view, APIntAttr:$index)>,
+    Results<(outs Index)> {
+  let summary = "dimension index operation";
+  let description = [{
+    The "linalg.dim" operation takes a linalg.view and returns an
+    "index". It requires a single integer attribute named "index". It
+     returns the size of the specified dimension. For example:
+
+      %1 = linalg.dim %0, 2 : view<?x?x?xf32>
+  }];
+
+  let verifier = [{
+    if (getIndex() >= getViewType().getRank())
+      return emitOpError("index is out of range");
+    return success();
+  }];
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *view, unsigned index",
+    [{
+      result->addOperands(view);
+      result->addAttribute(
+        "index", builder->getIntegerAttr(builder->getIndexType(), index));
+      result->types.push_back(builder->getIndexType());
+    }]>];
+
+  let extraClassDeclaration = [{
+    unsigned getIndex() {
+      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+    }
+    ViewType getViewType() { return getOperand()->getType().cast<ViewType>(); }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
+    Arguments<(ins View:$view, Variadic<Index>:$ranges)>,
+    Results<(outs View)> {
+  let summary = "subview operation";
+  let description = [{
+    The "linalg.subview" operation takes a linalg.view, a list of indices and
+    returns a new linalg.view of the same type that is contained within the
+    operand view.
+    This operation is equivalent to a non-rank-reducing slice operation. The
+    main difference is the operands are all of type `index` and no intermediate
+    linalg.range operations are required. A "linalg.subview" is thus a
+    specialized linalg.slice with a higher level of abstraction.
+
+      %1 = linalg.subview %0[%1, %2, %3, %4, %5, %6] : view<?x?xf32>
+
+  }];
+  // TODO(ntv) evolve syntax towards:
+  //   linalg.subview %0[%1:%2:%3][%4:%5:%6] : view<?x?xf32>
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *view, "
+    "ArrayRef<Value *> ranges",
+    [{
+      result->addOperands(view);
+      result->addOperands(ranges);
+      result->types.push_back(view->getType());
+    }]>];
+
+  let verifier = [{
+    auto numRanges = (getNumOperands() - 1) / 3;
+    if (getNumOperands() != 3 * numRanges + 1 ||
+        numRanges != getViewType().getRank())
+      return emitOpError("expected a view followed by 3 indices specifying ") <<
+        "a range for each dimension";
+    return success();
+  }];
+
+  let extraClassDeclaration = [{
+    Value *getView() { return getOperand(0); }
+    ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+    struct Range { Value *min; Value *max; Value *step; };
+    Range getRange(unsigned i) {
+      return Range{
+        getOperand(1 + 3*i), getOperand(1 + 3*i + 1), getOperand(1 + 3*i + 2)};
+    }
+    SmallVector<Range, 8> getRanges() {
+      SmallVector<Range, 8> res;
+      unsigned rank = getViewType().getRank();
+      res.reserve(rank);
+      for (unsigned i = 0; i < rank; ++i)
+        res.push_back(getRange(i));
+      return res;
+    }
+    // This requires `SubViewOp` to be declared, in the future it should be
+    // folded into the builders.
+    static void build(Builder *builder, OperationState *result, Value *view,
+        ArrayRef<SubViewOp::Range> ranges) {
+      result->addOperands(view);
+      for (auto r : ranges)
+        result->addOperands({r.min, r.max, r.step});
+      result->types.push_back(view->getType());
+    }
+  }];
+}
+
+def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
+    Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
+    Results<(outs View)> {
+  let summary = "view operation";
+  let description = [{
+    The "linalg.view" op produces a linalg.view which is a multi-dimensional
+    range abstraction on top of an underlying linalg.buffer. This gives an
+    indexing structure to an otherwise non-indexable linalg.buffer.
+
+    A "linalg.view" takes a buffer and a variadic number of ranges and produces
+    a `view` of rank the number of ranges. The elemental type may not match the
+    buffer element type:
+
+    Examples:
+    ```
+       %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+       %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+       %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
+    ```
+  }];
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *buffer, "
+    "ArrayRef<Value *> ranges, Type resultType = Type(), "
+    "ArrayRef<NamedAttribute> attrs = {}">];
+
+  let verifier = [{
+    if (getViewType().getRank() != llvm::size(ranges()))
+      return emitOpError("the view rank must be the number of its ranges");
+    return success();
+  }];
+
+  let extraClassDeclaration = [{
+    enum { FirstIndexingOperand = 1 };
+    unsigned getRank() { return getViewType().getRank(); }
+    Type getElementType() { return getViewType().getElementType(); }
+    ViewType getViewType() { return getType().cast<ViewType>(); }
+    /// Get the underlying indexing at a given rank.
+    Value *getRange(unsigned rank) {
+      assert(rank < getRank() && "rank overflow");
+      return *(ranges().begin() + rank);
+    }
+  }];
+}
+
+def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
+    Arguments<(ins Variadic<AnyType>:$values)> {
+  let summary = "Linalg yield operation";
+  let description = [{
+    "linalg.yield" is a special terminator operation for blocks inside regions
+    in linalg ops. It returns values to the immediately enclosing linalg op.
+
+       linalg.yield %f0, %f1 : f32, f32
+  }];
+}
+
+#endif // LINALG_OPS
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgTraits.h b/third_party/mlir/include/mlir/Linalg/IR/LinalgTraits.h
new file mode 100644
index 0000000..34f7043
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgTraits.h
@@ -0,0 +1,193 @@
+//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_LINALGTRAITS_H_
+#define MLIR_LINALG_LINALGTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace linalg {
+
+/// This class provides the API for ops that are known to have a specified
+/// number of inputs and outputs, all passed as operands. This is used as a
+/// trait like this:
+///
+///   class DotOp : public Op<DotOp, OpTrait::NInputsAndOutputs<2, 1>::Impl> {
+///
+template <unsigned NInputs, unsigned NOutputs> class NInputsAndOutputs {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public OpTrait::TraitBase<ConcreteType,
+                                  NInputsAndOutputs<NInputs, NOutputs>::Impl> {
+  public:
+    static unsigned getNumInputs() { return NInputs; }
+    static unsigned getNumOutputs() { return NOutputs; }
+    static LogicalResult verifyTrait(Operation *op) {
+      return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
+    }
+  };
+};
+
+/// This class provides the API for ops that are known to operate on views. This
+/// trait must be used in conjunction with an op definition or a trait that
+/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
+/// trait like this:
+///
+///   class DotOp : public Op<DotOp, OpTrait::ViewTrait> {
+///
+template <typename ConcreteType>
+class ViewTraits : public OpTrait::TraitBase<ConcreteType, ViewTraits> {
+private:
+  /// Return the number of input views. For internal use only.
+  unsigned nInputs() {
+    return cast<ConcreteType>(this->getOperation()).getNumInputs();
+  }
+  /// Return the number of input views. For internal use only.
+  unsigned nOutputs() {
+    return cast<ConcreteType>(this->getOperation()).getNumOutputs();
+  }
+
+public:
+  /// Return the `i`-th input view.
+  Value *getInput(unsigned i) {
+    assert(i < nInputs());
+    return this->getOperation()->getOperand(i);
+  }
+  /// Return the index of `view` in the list of input views if found, llvm::None
+  /// otherwise.
+  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
+    auto it = llvm::find(getInputs(), view);
+    if (it != getInputs().end())
+      return it - getInputs().begin();
+    return llvm::None;
+  }
+  /// Return the `i`-th input view type.
+  mlir::linalg::ViewType getInputViewType(unsigned i) {
+    return getInput(i)->getType().template cast<mlir::linalg::ViewType>();
+  }
+  /// Return the range over input views.
+  Operation::operand_range getInputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin(), range.begin() + nInputs()};
+  }
+  /// Return the `i`-th output view.
+  Value *getOutput(unsigned i) {
+    return this->getOperation()->getOperand(nInputs() + i);
+  }
+  /// Return the index of `view` in the list of output views if found,
+  /// llvm::None otherwise.
+  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
+    auto it = llvm::find(getOutputs(), view);
+    if (it != getOutputs().end())
+      return it - getOutputs().begin();
+    return llvm::None;
+  }
+  /// Return the `i`-th output view type.
+  mlir::linalg::ViewType getOutputViewType(unsigned i) {
+    return getOutput(i)->getType().template cast<mlir::linalg::ViewType>();
+  }
+  /// Return the range over output views.
+  Operation::operand_range getOutputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin() + nInputs(),
+            range.begin() + getNumInputsAndOutputs()};
+  }
+  /// Return the number of input and output views.
+  unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
+  /// Return the `i`-th view type.
+  mlir::linalg::ViewType getViewType(unsigned i) {
+    return (i < nInputs()) ? getInputViewType(i)
+                           : getOutputViewType(i - nInputs());
+  }
+  /// Return the range over input and output views.
+  Operation::operand_range getInputsAndOutputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
+    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
+      return failure();
+    for (unsigned i = 0, e = nViews; i < e; ++i) {
+      if (!op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>())
+        return op->emitOpError("operand ") << i << " must have view type ";
+    }
+    return success();
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of parallel, reduction and window loops. This is used as a trait like
+/// this:
+///
+///   class MatmulOp : public Op<MatmulOp, OpTrait::NLoopTypes<2, 1, 0>::Impl> {
+///
+template <unsigned NParallel, unsigned NReduction, unsigned NWindow = 0>
+class NLoopTypes {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public OpTrait::TraitBase<
+            ConcreteType, NLoopTypes<NParallel, NReduction, NWindow>::Impl> {
+  public:
+    static unsigned getNumParallelLoops() { return NParallel; }
+    static unsigned getNumReductionLoops() { return NReduction; }
+    static unsigned getNumWindowLoops() { return NWindow; }
+    static unsigned getNumLoops() { return NParallel + NReduction + NWindow; }
+  };
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// list of view ranks. This is used as a trait like this:
+///
+///   class MatvecOp : public Op<MatvecOp, OpTrait::ViewRanks<2, 1, 1>::Impl> {
+///
+template <unsigned... Ranks> class ViewRanks {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public OpTrait::TraitBase<ConcreteType, ViewRanks<Ranks...>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (op->getNumOperands() != sizeof...(Ranks))
+        return op->emitError("expected ") << sizeof...(Ranks) << " operands";
+
+      unsigned ranks[]{Ranks...};
+      for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
+        auto viewType =
+            op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>();
+        if (!viewType)
+          return op->emitOpError("operand ") << i << " must have view type ";
+        if (ranks[i] != viewType.getRank())
+          return op->emitOpError("operand ")
+                 << i << " must have rank " << ranks[i];
+      }
+      return success();
+    }
+  };
+};
+
+} // namespace linalg
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_LINALG_LINALGTRAITS_H_
diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgTypes.h b/third_party/mlir/include/mlir/Linalg/IR/LinalgTypes.h
new file mode 100644
index 0000000..b1ce221
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgTypes.h
@@ -0,0 +1,121 @@
+//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_LINALGTYPES_H_
+#define MLIR_LINALG_LINALGTYPES_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class MLIRContext;
+
+namespace linalg {
+enum LinalgTypes {
+  Buffer = Type::FIRST_LINALG_TYPE,
+  Range,
+  View,
+  LAST_USED_LINALG_TYPE = View,
+};
+
+class LinalgDialect : public Dialect {
+public:
+  explicit LinalgDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "linalg"; }
+
+  /// Parse a type registered to this dialect.
+  Type parseType(llvm::StringRef spec, Location loc) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, llvm::raw_ostream &os) const override;
+};
+
+/// A BufferType represents a contiguous block of memory that can be allocated
+/// and deallocated. A buffer cannot be indexed directly, a view must be
+/// laid out on a buffer to give it indexing semantics.
+struct BufferTypeStorage;
+class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+  /// Construction hook.
+  static BufferType get(MLIRContext *context, Type elementType,
+                        int64_t bufferSize = -1);
+  /// Used to implement llvm-style cast.
+  static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
+
+  // Type-specific functionality.
+  Type getElementType();
+  bool hasConstantSize();
+  Optional<int64_t> getBufferSize();
+};
+
+/// A RangeType represents a minimal range abstraction (min, max, step).
+/// It is constructed by calling the linalg.range op with three values index of
+/// index type:
+///
+/// ```{.mlir}
+///    func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
+///      %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+///    }
+/// ```
+class RangeType : public Type::TypeBase<RangeType, Type> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+  /// Construction hook.
+  static RangeType get(MLIRContext *context) {
+    /// Custom, uniq'ed construction in the MLIRContext.
+    return Base::get(context, LinalgTypes::Range);
+  }
+  /// Used to implement llvm-style cast.
+  static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
+};
+
+/// A ViewType represents a multi-dimensional range abstraction on top of an
+/// underlying storage type. It is parameterizable by the underlying element
+/// type and the rank of the view.
+/// A new value of ViewType is constructed from a buffer with a view op and
+/// passing it ranges:
+///
+/// ```{.mlir}
+///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+/// ```
+struct ViewTypeStorage;
+class ViewType : public Type::TypeBase<ViewType, Type, ViewTypeStorage> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+  /// Construction hook.
+  static ViewType get(MLIRContext *context, Type elementType, unsigned rank);
+  // Used to implement llvm-style cast.
+  static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
+
+  // Type-specific functionality.
+  /// Return the underlying elemental type.
+  Type getElementType();
+  /// Return the rank of the view.
+  /// This is the number of indexings needed to reach an underlying element.
+  unsigned getRank();
+};
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_LINALGTYPES_H_
diff --git a/third_party/mlir/include/mlir/Linalg/Passes.h b/third_party/mlir/include/mlir/Linalg/Passes.h
new file mode 100644
index 0000000..0294149
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/Passes.h
@@ -0,0 +1,44 @@
+//===- Passes.h - Linalg pass entry points ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LINALG_PASSES_H_
+#define MLIR_LINALG_PASSES_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+class FunctionPassBase;
+class ModulePassBase;
+
+namespace linalg {
+FunctionPassBase *createLinalgFusionPass(ArrayRef<int64_t> tileSizes = {});
+
+FunctionPassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {},
+                                         bool promoteViews = false);
+
+FunctionPassBase *createLowerLinalgToLoopsPass();
+
+ModulePassBase *createLowerLinalgToLLVMPass();
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_PASSES_H_
diff --git a/third_party/mlir/include/mlir/Linalg/Utils/Intrinsics.h b/third_party/mlir/include/mlir/Linalg/Utils/Intrinsics.h
new file mode 100644
index 0000000..eabec69
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/Utils/Intrinsics.h
@@ -0,0 +1,51 @@
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_INTRINSICS_H_
+#define MLIR_LINALG_INTRINSICS_H_
+
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace linalg {
+class BufferAllocOp;
+class BufferDeallocOp;
+class CopyOp;
+class DimOp;
+class FillOp;
+class LoadOp;
+class RangeOp;
+class SliceOp;
+class StoreOp;
+class ViewOp;
+namespace intrinsics {
+using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
+using buffer_dealloc =
+    mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
+using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
+using dim = mlir::edsc::intrinsics::ValueBuilder<linalg::DimOp>;
+using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
+using linalg_load = mlir::edsc::intrinsics::ValueBuilder<linalg::LoadOp>;
+using linalg_store = mlir::edsc::intrinsics::OperationBuilder<linalg::StoreOp>;
+using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
+using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
+using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
+} // namespace intrinsics
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/Linalg/Utils/Utils.h b/third_party/mlir/include/mlir/Linalg/Utils/Utils.h
new file mode 100644
index 0000000..68d71a8
--- /dev/null
+++ b/third_party/mlir/include/mlir/Linalg/Utils/Utils.h
@@ -0,0 +1,156 @@
+//===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_UTILS_H_
+#define MLIR_LINALG_UTILS_H_
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class AffineExpr;
+class AffineMap;
+class OperationFolder;
+
+namespace edsc {
+
+/// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations.
+/// More specifically it is meant to be used as a temporary object for
+/// representing any nested MLIR construct that is "related to" an mlir::Value*
+/// (for now an induction variable).
+class LoopRangeBuilder : public NestedBuilder {
+public:
+  /// Constructs a new loop.for and captures the associated induction
+  /// variable. A ValueHandle pointer is passed as the first argument and is the
+  /// *only* way to capture the loop induction variable.
+  LoopRangeBuilder(ValueHandle *iv, ValueHandle range);
+  LoopRangeBuilder(ValueHandle *iv, Value *range);
+  LoopRangeBuilder(ValueHandle *iv, linalg::SubViewOp::Range range);
+
+  LoopRangeBuilder(const LoopRangeBuilder &) = delete;
+  LoopRangeBuilder(LoopRangeBuilder &&) = default;
+
+  LoopRangeBuilder &operator=(const LoopRangeBuilder &) = delete;
+  LoopRangeBuilder &operator=(LoopRangeBuilder &&) = default;
+
+  /// The only purpose of this operator is to serve as a sequence point so that
+  /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
+  /// scoped within a LoopRangeBuilder.
+  ValueHandle operator()(std::function<void(void)> fun = nullptr);
+};
+
+/// Helper class to sugar building loop.for loop nests from ranges.
+/// This is similar to edsc::LoopNestBuilder except it works on ranges directly.
+/// In the current implementation it produces loop.for operations.
+class LoopNestRangeBuilder {
+public:
+  LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
+                       llvm::ArrayRef<edsc::ValueHandle> ranges);
+  LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
+                       llvm::ArrayRef<Value *> ranges);
+  LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
+                       llvm::ArrayRef<linalg::SubViewOp::Range> ranges);
+  edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
+
+private:
+  llvm::SmallVector<LoopRangeBuilder, 4> loops;
+};
+
+} // namespace edsc
+
+namespace linalg {
+
+/// Returns the linearized list of all view dimensions in a linalgOp. Applying
+/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
+/// derivation of loop ranges for any linalgOp.
+template <typename ConcreteOp>
+SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
+  SmallVector<Value *, 8> res;
+  for (auto v : linalgOp.getInputsAndOutputs()) {
+    ViewType t = v->getType().template cast<ViewType>();
+    for (unsigned i = 0; i < t.getRank(); ++i)
+      res.push_back(intrinsics::dim(v, i));
+  }
+  return res;
+}
+
+/// Returns the values obtained by applying `map` to the list of values.
+/// Performs simplifications and foldings where possible.
+SmallVector<Value *, 4> applyMapToValues(OpBuilder &b, Location loc,
+                                         AffineMap map,
+                                         ArrayRef<Value *> values,
+                                         OperationFolder &state);
+
+struct TiledLinalgOp {
+  LinalgOp op;
+  SmallVector<loop::ForOp, 8> loops;
+};
+
+/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Inserts scoped local buffers and copies tiled views into/from those buffers
+/// when the corresponding entry in `viewsToPromote` is true.
+/// Returns a struct containing the tiled loops and the cloned op if successful,
+/// llvm::None otherwise.
+// TODO(ntv) implement a heuristic for view promotion.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(LinalgOp op,
+                                           ArrayRef<Value *> tileSizes,
+                                           OperationFolder &folder,
+                                           ArrayRef<bool> viewsToPromote = {});
+
+/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
+/// Inserts scoped local buffers and copies tiled views into/from those buffers
+/// when the corresponding entry in `viewsToPromote` is true.
+/// Returns a struct containing the tiled loops and the cloned op if successful,
+/// llvm::None otherwise.
+// TODO(ntv) implement a heuristic for view promotion.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(LinalgOp op,
+                                           ArrayRef<int64_t> tileSizes,
+                                           OperationFolder &folder,
+                                           ArrayRef<bool> viewsToPromote = {});
+
+struct PromotionInfo {
+  Value *buffer;
+  Value *fullLocalView;
+  Value *partialLocalView;
+};
+
+/// Promotes the `views` into a new buffer allocated at the insertion point `b`.
+/// For now, promotion occurs in 3 steps:
+///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
+///   2. Take a full view on the buffer and `linalg.fill` it with zeros (use
+///      float zero for now).
+///   3. Take a partial slice of the full view in step 2. and copy into it.
+///
+/// Returns a list of PromotionInfo which hold the promoted buffer and the
+/// full and partial views indexing into the buffer.
+llvm::SmallVector<PromotionInfo, 8> promoteLinalgViews(OpBuilder &b,
+                                                       Location loc,
+                                                       ArrayRef<Value *> views,
+                                                       OperationFolder &folder);
+
+/// Returns all the operands of `linalgOp` that are not views.
+/// Asserts that these operands are value types to allow transformations like
+/// tiling to just use the values when cloning `linalgOp`.
+llvm::SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_UTILS_H_
diff --git a/third_party/mlir/include/mlir/Parser.h b/third_party/mlir/include/mlir/Parser.h
new file mode 100644
index 0000000..71babe7
--- /dev/null
+++ b/third_party/mlir/include/mlir/Parser.h
@@ -0,0 +1,70 @@
+//===- Parser.h - MLIR Parser Library Interface -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file is contains the interface to the MLIR parser library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PARSER_H
+#define MLIR_PARSER_H
+
+namespace llvm {
+class SourceMgr;
+class SMDiagnostic;
+class StringRef;
+} // end namespace llvm
+
+namespace mlir {
+class Location;
+class ModuleOp;
+class MLIRContext;
+class Type;
+
+/// This parses the file specified by the indicated SourceMgr and returns an
+/// MLIR module if it was valid.  If not, the error message is emitted through
+/// the error handler registered in the context, and a null pointer is returned.
+ModuleOp parseSourceFile(const llvm::SourceMgr &sourceMgr,
+                         MLIRContext *context);
+
+/// This parses the file specified by the indicated filename and returns an
+/// MLIR module if it was valid.  If not, the error message is emitted through
+/// the error handler registered in the context, and a null pointer is returned.
+ModuleOp parseSourceFile(llvm::StringRef filename, MLIRContext *context);
+
+/// This parses the file specified by the indicated filename using the provided
+/// SourceMgr and returns an MLIR module if it was valid.  If not, the error
+/// message is emitted through the error handler registered in the context, and
+/// a null pointer is returned.
+ModuleOp parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
+                         MLIRContext *context);
+
+/// This parses the module string to a MLIR module if it was valid.  If not, the
+/// error message is emitted through the error handler registered in the
+/// context, and a null pointer is returned.
+ModuleOp parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
+
+/// This parses a single MLIR type to an MLIR context if it was valid.  If not,
+/// an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
+/// `typeStr`. If the passed `typeStr` has additional tokens that were not part
+/// of the type, an error is emitted.
+// TODO(ntv) Improve diagnostic reporting.
+Type parseType(llvm::StringRef typeStr, MLIRContext *context);
+
+} // end namespace mlir
+
+#endif // MLIR_PARSER_H
diff --git a/third_party/mlir/include/mlir/Pass/AnalysisManager.h b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
new file mode 100644
index 0000000..1f44515
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
@@ -0,0 +1,293 @@
+//===- AnalysisManager.h - Analysis Management Infrastructure ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PASS_ANALYSISMANAGER_H
+#define MLIR_PASS_ANALYSISMANAGER_H
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/PassInstrumentation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/TypeName.h"
+
+namespace mlir {
+/// A special type used by analyses to provide an address that identifies a
+/// particular analysis set or a concrete analysis type.
+using AnalysisID = ClassID;
+
+//===----------------------------------------------------------------------===//
+// Analysis Preservation and Concept Modeling
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// A utility class to represent the analyses that are known to be preserved.
+class PreservedAnalyses {
+public:
+  /// Mark all analyses as preserved.
+  void preserveAll() { preservedIDs.insert(&allAnalysesID); }
+
+  /// Returns true if all analyses were marked preserved.
+  bool isAll() const { return preservedIDs.count(&allAnalysesID); }
+
+  /// Returns true if no analyses were marked preserved.
+  bool isNone() const { return preservedIDs.empty(); }
+
+  /// Preserve the given analyses.
+  template <typename AnalysisT> void preserve() {
+    preserve(AnalysisID::getID<AnalysisT>());
+  }
+  template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
+  void preserve() {
+    preserve<AnalysisT>();
+    preserve<AnalysisT2, OtherAnalysesT...>();
+  }
+  void preserve(const AnalysisID *id) { preservedIDs.insert(id); }
+
+  /// Returns if the given analysis has been marked as preserved. Note that this
+  /// simply checks for the presence of a given analysis ID and should not be
+  /// used as a general preservation checker.
+  template <typename AnalysisT> bool isPreserved() const {
+    return isPreserved(AnalysisID::getID<AnalysisT>());
+  }
+  bool isPreserved(const AnalysisID *id) const {
+    return preservedIDs.count(id);
+  }
+
+private:
+  /// An identifier used to represent all potential analyses.
+  constexpr static AnalysisID allAnalysesID = {};
+
+  /// The set of analyses that are known to be preserved.
+  SmallPtrSet<const void *, 2> preservedIDs;
+};
+
+/// The abstract polymorphic base class representing an analysis.
+struct AnalysisConcept {
+  virtual ~AnalysisConcept() = default;
+};
+
+/// A derived analysis model used to hold a specific analysis object.
+template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
+  template <typename... Args>
+  explicit AnalysisModel(Args &&... args)
+      : analysis(std::forward<Args>(args)...) {}
+
+  AnalysisT analysis;
+};
+
+/// This class represents a cache of analyses for a single IR unit. All
+/// computation, caching, and invalidation of analyses takes place here.
+template <typename IRUnitT> class AnalysisMap {
+  /// A mapping between an analysis id and an existing analysis instance.
+  using ConceptMap =
+      llvm::DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
+
+  /// Utility to return the name of the given analysis class.
+  template <typename AnalysisT> static llvm::StringRef getAnalysisName() {
+    StringRef name = llvm::getTypeName<AnalysisT>();
+    if (!name.consume_front("mlir::"))
+      name.consume_front("(anonymous namespace)::");
+    return name;
+  }
+
+public:
+  explicit AnalysisMap(IRUnitT ir) : ir(ir) {}
+
+  /// Get an analysis for the current IR unit, computing it if necessary.
+  template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
+    auto *id = AnalysisID::getID<AnalysisT>();
+
+    typename ConceptMap::iterator it;
+    bool wasInserted;
+    std::tie(it, wasInserted) = analyses.try_emplace(id);
+
+    // If we don't have a cached analysis for this function, compute it directly
+    // and add it to the cache.
+    if (wasInserted) {
+      if (pi)
+        pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+
+      it->second = llvm::make_unique<AnalysisModel<AnalysisT>>(ir);
+
+      if (pi)
+        pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+    }
+    return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
+  }
+
+  /// Get a cached analysis instance if one exists, otherwise return null.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
+    auto res = analyses.find(AnalysisID::getID<AnalysisT>());
+    if (res == analyses.end())
+      return llvm::None;
+    return {static_cast<AnalysisModel<AnalysisT> &>(*res->second).analysis};
+  }
+
+  /// Returns the IR unit that this analysis map represents.
+  IRUnitT getIRUnit() { return ir; }
+  const IRUnitT getIRUnit() const { return ir; }
+
+  /// Clear any held analyses.
+  void clear() { analyses.clear(); }
+
+  /// Invalidate any cached analyses based upon the given set of preserved
+  /// analyses.
+  void invalidate(const detail::PreservedAnalyses &pa) {
+    // Remove any analyses not marked as preserved.
+    for (auto it = analyses.begin(), e = analyses.end(); it != e;) {
+      auto curIt = it++;
+      if (!pa.isPreserved(curIt->first))
+        analyses.erase(curIt);
+    }
+  }
+
+private:
+  IRUnitT ir;
+  ConceptMap analyses;
+};
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// Analysis Management
+//===----------------------------------------------------------------------===//
+class ModuleAnalysisManager;
+
+/// An analysis manager for a specific function instance. This class can only be
+/// constructed from a ModuleAnalysisManager instance.
+class FunctionAnalysisManager {
+public:
+  // Query for a cached analysis on the parent Module. The analysis may not
+  // exist and if it does it may be stale.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedModuleAnalysis() const;
+
+  // Query for the given analysis for the current function.
+  template <typename AnalysisT> AnalysisT &getAnalysis() {
+    return impl->getAnalysis<AnalysisT>(getPassInstrumentor());
+  }
+
+  // Query for a cached entry of the given analysis on the current function.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
+    return impl->getCachedAnalysis<AnalysisT>();
+  }
+
+  /// Invalidate any non preserved analyses,
+  void invalidate(const detail::PreservedAnalyses &pa) {
+    // If all analyses were preserved, then there is nothing to do here.
+    if (pa.isAll())
+      return;
+    impl->invalidate(pa);
+  }
+
+  /// Clear any held analyses.
+  void clear() { impl->clear(); }
+
+  /// Returns a pass instrumentation object for the current function. This value
+  /// may be null.
+  PassInstrumentor *getPassInstrumentor() const;
+
+private:
+  FunctionAnalysisManager(const ModuleAnalysisManager *parent,
+                          detail::AnalysisMap<FuncOp> *impl)
+      : parent(parent), impl(impl) {}
+
+  /// A reference to the parent analysis manager.
+  const ModuleAnalysisManager *parent;
+
+  /// A reference to the impl analysis map within the owning analysis manager.
+  detail::AnalysisMap<FuncOp> *impl;
+
+  /// Allow access to the constructor.
+  friend class ModuleAnalysisManager;
+};
+
+/// An analysis manager for a specific module instance.
+class ModuleAnalysisManager {
+public:
+  ModuleAnalysisManager(ModuleOp module, PassInstrumentor *passInstrumentor)
+      : moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
+  ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
+  ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
+
+  /// Query for the analysis of a function. The analysis is computed if it does
+  /// not exist.
+  template <typename AnalysisT>
+  AnalysisT &getFunctionAnalysis(FuncOp function) {
+    return slice(function).getAnalysis<AnalysisT>();
+  }
+
+  /// Query for a cached analysis of a child function, or return null.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedFunctionAnalysis(FuncOp function) const {
+    auto it = functionAnalyses.find(function);
+    if (it == functionAnalyses.end())
+      return llvm::None;
+    return it->second->getCachedAnalysis<AnalysisT>();
+  }
+
+  /// Query for the analysis for the module. The analysis is computed if it does
+  /// not exist.
+  template <typename AnalysisT> AnalysisT &getAnalysis() {
+    return moduleAnalyses.getAnalysis<AnalysisT>(getPassInstrumentor());
+  }
+
+  /// Query for a cached analysis for the module, or return null.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
+    return moduleAnalyses.getCachedAnalysis<AnalysisT>();
+  }
+
+  /// Create an analysis slice for the given child function.
+  FunctionAnalysisManager slice(FuncOp function);
+
+  /// Invalidate any non preserved analyses.
+  void invalidate(const detail::PreservedAnalyses &pa);
+
+  /// Returns a pass instrumentation object for the current module. This value
+  /// may be null.
+  PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
+
+private:
+  /// The cached analyses for functions within the current module.
+  llvm::DenseMap<FuncOp, std::unique_ptr<detail::AnalysisMap<FuncOp>>>
+      functionAnalyses;
+
+  /// The analyses for the owning module.
+  detail::AnalysisMap<ModuleOp> moduleAnalyses;
+
+  /// An optional instrumentation object.
+  PassInstrumentor *passInstrumentor;
+};
+
+// Query for a cached analysis on the parent Module. The analysis may not exist
+// and if it does it may be stale.
+template <typename AnalysisT>
+llvm::Optional<std::reference_wrapper<AnalysisT>>
+FunctionAnalysisManager::getCachedModuleAnalysis() const {
+  return parent->getCachedAnalysis<AnalysisT>();
+}
+
+} // end namespace mlir
+
+#endif // MLIR_PASS_ANALYSISMANAGER_H
diff --git a/third_party/mlir/include/mlir/Pass/Pass.h b/third_party/mlir/include/mlir/Pass/Pass.h
new file mode 100644
index 0000000..b1531a3
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/Pass.h
@@ -0,0 +1,289 @@
+//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PASS_PASS_H
+#define MLIR_PASS_PASS_H
+
+#include "mlir/Pass/AnalysisManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/PointerIntPair.h"
+
+namespace mlir {
+/// The abstract base pass class. This class contains information describing the
+/// derived pass object, e.g its kind and abstract PassInfo.
+class Pass {
+public:
+  enum class Kind { FunctionPass, ModulePass };
+
+  virtual ~Pass() = default;
+
+  /// Returns the unique identifier that corresponds to this pass.
+  const PassID *getPassID() const { return passIDAndKind.getPointer(); }
+
+  /// Returns the pass info for the specified pass class or null if unknown.
+  static const PassInfo *lookupPassInfo(const PassID *passID);
+  template <typename PassT> static const PassInfo *lookupPassInfo() {
+    return lookupPassInfo(PassID::getID<PassT>());
+  }
+
+  /// Returns the pass info for this pass.
+  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
+
+  /// Return the kind of this pass.
+  Kind getKind() const { return passIDAndKind.getInt(); }
+
+  /// Returns the derived pass name.
+  virtual StringRef getName() = 0;
+
+protected:
+  Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}
+
+private:
+  /// Out of line virtual method to ensure vtables and metadata are emitted to a
+  /// single .o file.
+  virtual void anchor();
+
+  /// Represents a unique identifier for the pass and its kind.
+  llvm::PointerIntPair<const PassID *, 1, Kind> passIDAndKind;
+};
+
+namespace detail {
+class FunctionPassExecutor;
+class ModulePassExecutor;
+
+/// The state for a single execution of a pass. This provides a unified
+/// interface for accessing and initializing necessary state for pass execution.
+template <typename IRUnitT, typename AnalysisManagerT>
+struct PassExecutionState {
+  PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager)
+      : irAndPassFailed(ir, false), analysisManager(analysisManager) {}
+
+  /// The current IR unit being transformed and a bool for if the pass signaled
+  /// a failure.
+  llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
+
+  /// The analysis manager for the IR unit.
+  AnalysisManagerT &analysisManager;
+
+  /// The set of preserved analyses for the current execution.
+  detail::PreservedAnalyses preservedAnalyses;
+};
+} // namespace detail
+
+/// Pass to transform a specific function within a module. Derived passes should
+/// not inherit from this class directly, and instead should use the CRTP
+/// FunctionPass class.
+class FunctionPassBase : public Pass {
+  using PassStateT =
+      detail::PassExecutionState<FuncOp, FunctionAnalysisManager>;
+
+public:
+  static bool classof(const Pass *pass) {
+    return pass->getKind() == Kind::FunctionPass;
+  }
+
+protected:
+  explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {}
+
+  /// The polymorphic API that runs the pass over the currently held function.
+  virtual void runOnFunction() = 0;
+
+  /// A clone method to create a copy of this pass.
+  virtual FunctionPassBase *clone() const = 0;
+
+  /// Return the current function being transformed.
+  FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); }
+
+  /// Return the MLIR context for the current function being transformed.
+  MLIRContext &getContext() { return *getFunction().getContext(); }
+
+  /// Returns the current pass state.
+  PassStateT &getPassState() {
+    assert(passState && "pass state was never initialized");
+    return *passState;
+  }
+
+  /// Returns the current analysis manager.
+  FunctionAnalysisManager &getAnalysisManager() {
+    return getPassState().analysisManager;
+  }
+
+private:
+  /// Forwarding function to execute this pass.
+  LLVM_NODISCARD
+  LogicalResult run(FuncOp fn, FunctionAnalysisManager &fam);
+
+  /// The current execution state for the pass.
+  llvm::Optional<PassStateT> passState;
+
+  /// Allow access to 'run'.
+  friend detail::FunctionPassExecutor;
+};
+
+/// Pass to transform a module. Derived passes should not inherit from this
+/// class directly, and instead should use the CRTP ModulePass class.
+class ModulePassBase : public Pass {
+  using PassStateT =
+      detail::PassExecutionState<ModuleOp, ModuleAnalysisManager>;
+
+public:
+  static bool classof(const Pass *pass) {
+    return pass->getKind() == Kind::ModulePass;
+  }
+
+protected:
+  explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {}
+
+  /// The polymorphic API that runs the pass over the currently held module.
+  virtual void runOnModule() = 0;
+
+  /// Return the current module being transformed.
+  ModuleOp getModule() { return getPassState().irAndPassFailed.getPointer(); }
+
+  /// Return the MLIR context for the current module being transformed.
+  MLIRContext &getContext() { return *getModule().getContext(); }
+
+  /// Returns the current pass state.
+  PassStateT &getPassState() {
+    assert(passState && "pass state was never initialized");
+    return *passState;
+  }
+
+  /// Returns the current analysis manager.
+  ModuleAnalysisManager &getAnalysisManager() {
+    return getPassState().analysisManager;
+  }
+
+private:
+  /// Forwarding function to execute this pass.
+  LLVM_NODISCARD
+  LogicalResult run(ModuleOp module, ModuleAnalysisManager &mam);
+
+  /// The current execution state for the pass.
+  llvm::Optional<PassStateT> passState;
+
+  /// Allow access to 'run'.
+  friend detail::ModulePassExecutor;
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Model Definitions
+//===----------------------------------------------------------------------===//
+namespace detail {
+/// The opaque CRTP model of a pass. This class provides utilities for derived
+/// pass execution and handles all of the necessary polymorphic API.
+template <typename IRUnitT, typename PassT, typename BasePassT>
+class PassModel : public BasePassT {
+public:
+  /// Support isa/dyn_cast functionality for the derived pass class.
+  static bool classof(const Pass *pass) {
+    return pass->getPassID() == PassID::getID<PassT>();
+  }
+
+protected:
+  PassModel() : BasePassT(PassID::getID<PassT>()) {}
+
+  /// Signal that some invariant was broken when running. The IR is allowed to
+  /// be in an invalid state.
+  void signalPassFailure() {
+    this->getPassState().irAndPassFailed.setInt(true);
+  }
+
+  /// Query an analysis for the current ir unit.
+  template <typename AnalysisT> AnalysisT &getAnalysis() {
+    return this->getAnalysisManager().template getAnalysis<AnalysisT>();
+  }
+
+  /// Query a cached instance of an analysis for the current ir unit if one
+  /// exists.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() {
+    return this->getAnalysisManager().template getCachedAnalysis<AnalysisT>();
+  }
+
+  /// Mark all analyses as preserved.
+  void markAllAnalysesPreserved() {
+    this->getPassState().preservedAnalyses.preserveAll();
+  }
+
+  /// Mark the provided analyses as preserved.
+  template <typename... AnalysesT> void markAnalysesPreserved() {
+    this->getPassState().preservedAnalyses.template preserve<AnalysesT...>();
+  }
+  void markAnalysesPreserved(const AnalysisID *id) {
+    this->getPassState().preservedAnalyses.preserve(id);
+  }
+
+  /// Returns the derived pass name.
+  StringRef getName() override {
+    StringRef name = llvm::getTypeName<PassT>();
+    if (!name.consume_front("mlir::"))
+      name.consume_front("(anonymous namespace)::");
+    return name;
+  }
+};
+} // end namespace detail
+
+/// A model for providing function pass specific utilities.
+///
+/// Function passes must not:
+///   - read or modify any other functions within the parent module, as
+///     other threads may be manipulating them concurrently.
+///   - modify any state within the parent module, this includes adding
+///     additional functions.
+///
+/// Derived function passes are expected to provide the following:
+///   - A 'void runOnFunction()' method.
+template <typename T>
+struct FunctionPass : public detail::PassModel<FuncOp, T, FunctionPassBase> {
+  /// Returns the analysis for the parent module if it exists.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
+    return this->getAnalysisManager()
+        .template getCachedModuleAnalysis<AnalysisT>();
+  }
+
+  /// A clone method to create a copy of this pass.
+  FunctionPassBase *clone() const override {
+    return new T(*static_cast<const T *>(this));
+  }
+};
+
+/// A model for providing module pass specific utilities.
+///
+/// Derived module passes are expected to provide the following:
+///   - A 'void runOnModule()' method.
+template <typename T>
+struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
+  /// Returns the analysis for a child function.
+  template <typename AnalysisT> AnalysisT &getFunctionAnalysis(FuncOp f) {
+    return this->getAnalysisManager().template getFunctionAnalysis<AnalysisT>(
+        f);
+  }
+
+  /// Returns an existing analysis for a child function if it exists.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedFunctionAnalysis(FuncOp f) {
+    return this->getAnalysisManager()
+        .template getCachedFunctionAnalysis<AnalysisT>(f);
+  }
+};
+} // end namespace mlir
+
+#endif // MLIR_PASS_PASS_H
diff --git a/third_party/mlir/include/mlir/Pass/PassInstrumentation.h b/third_party/mlir/include/mlir/Pass/PassInstrumentation.h
new file mode 100644
index 0000000..4035832
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -0,0 +1,133 @@
+//===- PassInstrumentation.h ------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
+#define MLIR_PASS_PASSINSTRUMENTATION_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/Any.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+using AnalysisID = ClassID;
+class Pass;
+
+namespace detail {
+struct PassInstrumentorImpl;
+} // end namespace detail
+
+/// PassInstrumentation provdes several entry points into the pass manager
+/// infrastructure. Instrumentations should be added directly to a PassManager
+/// before running a pipeline.
+class PassInstrumentation {
+public:
+  virtual ~PassInstrumentation() = 0;
+
+  /// A callback to run before a pass is executed. This function takes a pointer
+  /// to the pass to be executed, as well as an llvm::Any holding a pointer to
+  /// the IR unit being transformed on.
+  virtual void runBeforePass(Pass *pass, const llvm::Any &ir) {}
+
+  /// A callback to run after a pass is successfully executed. This function
+  /// takes a pointer to the pass to be executed, as well as an llvm::Any
+  /// holding a pointer to the IR unit being transformed on.
+  virtual void runAfterPass(Pass *pass, const llvm::Any &ir) {}
+
+  /// A callback to run when a pass execution fails. This function takes a
+  /// pointer to the pass that was being executed, as well as an llvm::Any
+  /// holding a pointer to the IR unit that was being transformed. Note
+  /// that the ir unit may be in an invalid state.
+  virtual void runAfterPassFailed(Pass *pass, const llvm::Any &ir) {}
+
+  /// A callback to run before an analysis is computed. This function takes the
+  /// name of the analysis to be computed, its AnalysisID, as well as an
+  /// llvm::Any holding a pointer to the IR unit being analyzed on.
+  virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
+                                 const llvm::Any &ir) {}
+
+  /// A callback to run before an analysis is computed. This function takes the
+  /// name of the analysis that was computed, its AnalysisID, as well as an
+  /// llvm::Any holding a pointer to the IR unit that was analyzed.
+  virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
+                                const llvm::Any &ir) {}
+};
+
+/// This class holds a collection of PassInstrumentation objects, and invokes
+/// their respective call backs.
+class PassInstrumentor {
+public:
+  PassInstrumentor();
+  PassInstrumentor(PassInstrumentor &&) = delete;
+  PassInstrumentor(const PassInstrumentor &) = delete;
+  ~PassInstrumentor();
+
+  /// See PassInstrumentation::runBeforePass for details.
+  template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
+    runBeforePass(pass, llvm::Any(ir));
+  }
+
+  /// See PassInstrumentation::runAfterPass for details.
+  template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
+    runAfterPass(pass, llvm::Any(ir));
+  }
+
+  /// See PassInstrumentation::runAfterPassFailed for details.
+  template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
+    runAfterPassFailed(pass, llvm::Any(ir));
+  }
+
+  /// See PassInstrumentation::runBeforeAnalysis for details.
+  template <typename IRUnitT>
+  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
+    runBeforeAnalysis(name, id, llvm::Any(ir));
+  }
+
+  /// See PassInstrumentation::runAfterAnalysis for details.
+  template <typename IRUnitT>
+  void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
+    runAfterAnalysis(name, id, llvm::Any(ir));
+  }
+
+  /// Add the given instrumentation to the collection. This takes ownership over
+  /// the given pointer.
+  void addInstrumentation(PassInstrumentation *pi);
+
+private:
+  /// See PassInstrumentation::runBeforePass for details.
+  void runBeforePass(Pass *pass, const llvm::Any &ir);
+
+  /// See PassInstrumentation::runAfterPass for details.
+  void runAfterPass(Pass *pass, const llvm::Any &ir);
+
+  /// See PassInstrumentation::runAfterPassFailed for details.
+  void runAfterPassFailed(Pass *pass, const llvm::Any &ir);
+
+  /// See PassInstrumentation::runBeforeAnalysis for details.
+  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
+                         const llvm::Any &ir);
+
+  /// See PassInstrumentation::runAfterAnalysis for details.
+  void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
+                        const llvm::Any &ir);
+
+  std::unique_ptr<detail::PassInstrumentorImpl> impl;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_PASS_PASSINSTRUMENTATION_H_
diff --git a/third_party/mlir/include/mlir/Pass/PassManager.h b/third_party/mlir/include/mlir/Pass/PassManager.h
new file mode 100644
index 0000000..68dfeb0
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/PassManager.h
@@ -0,0 +1,142 @@
+//===- PassManager.h - Pass Management Interface ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PASS_PASSMANAGER_H
+#define MLIR_PASS_PASSMANAGER_H
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace llvm {
+class Any;
+} // end namespace llvm
+
+namespace mlir {
+class FunctionPassBase;
+class ModuleOp;
+class ModulePassBase;
+class Pass;
+class PassInstrumentation;
+class PassInstrumentor;
+
+namespace detail {
+class PassExecutor;
+class ModulePassExecutor;
+} // end namespace detail
+
+/// An enum describing the different display modes for the pass timing
+/// information within the pass manager.
+enum class PassTimingDisplayMode {
+  // In this mode the results are displayed in a list sorted by total time,
+  // with each pass/analysis instance aggregated into one unique result.
+  List,
+
+  // In this mode the results are displayed in a nested pipeline view that
+  // mirrors the internal pass pipeline that is being executed in the pass
+  // manager.
+  Pipeline,
+};
+
+/// The main pass manager and pipeline builder.
+class PassManager {
+public:
+  // If verifyPasses is true, the verifier is run after each pass.
+  PassManager(bool verifyPasses = true);
+  ~PassManager();
+
+  /// Run the passes within this manager on the provided module.
+  LLVM_NODISCARD
+  LogicalResult run(ModuleOp module);
+
+  /// Disable support for multi-threading within the pass manager.
+  void disableMultithreading(bool disable = true);
+
+  //===--------------------------------------------------------------------===//
+  // Pipeline Building
+  //===--------------------------------------------------------------------===//
+
+  /// Add an opaque pass pointer to the current manager. This takes ownership
+  /// over the provided pass pointer.
+  void addPass(Pass *pass);
+
+  /// Add a module pass to the current manager. This takes ownership over the
+  /// provided pass pointer.
+  void addPass(ModulePassBase *pass);
+
+  /// Add a function pass to the current manager. This takes ownership over the
+  /// provided pass pointer. This will automatically create a function pass
+  /// executor if necessary.
+  void addPass(FunctionPassBase *pass);
+
+  //===--------------------------------------------------------------------===//
+  // Instrumentations
+  //===--------------------------------------------------------------------===//
+
+  /// Add the provided instrumentation to the pass manager. This takes ownership
+  /// over the given pointer.
+  void addInstrumentation(PassInstrumentation *pi);
+
+  /// Add an instrumentation to print the IR before and after pass execution.
+  /// * 'shouldPrintBeforePass' and 'shouldPrintAfterPass' correspond to filter
+  ///   functions that take a 'Pass *'. These function should return true if the
+  ///   IR should be printed or not.
+  /// * 'printModuleScope' signals if the module IR should be printed, even for
+  ///   non module passes.
+  /// * 'out' corresponds to the stream to output the printed IR to.
+  void enableIRPrinting(std::function<bool(Pass *)> shouldPrintBeforePass,
+                        std::function<bool(Pass *)> shouldPrintAfterPass,
+                        bool printModuleScope, raw_ostream &out);
+
+  /// Add an instrumentation to time the execution of passes and the computation
+  /// of analyses.
+  /// Note: Timing should be enabled after all other instrumentations to avoid
+  /// any potential "ghost" timing from other instrumentations being
+  /// unintentionally included in the timing results.
+  void enableTiming(
+      PassTimingDisplayMode displayMode = PassTimingDisplayMode::Pipeline);
+
+private:
+  /// A stack of nested pass executors on sub-module IR units, e.g. function.
+  llvm::SmallVector<detail::PassExecutor *, 1> nestedExecutorStack;
+
+  /// The top level module pass executor.
+  std::unique_ptr<detail::ModulePassExecutor> mpe;
+
+  /// Flag that specifies if the IR should be verified after each pass has run.
+  bool verifyPasses : 1;
+
+  /// Flag that specifies if pass timing is enabled.
+  bool passTiming : 1;
+
+  /// Flag that specifies if multi-threading is disabled.
+  bool disableThreads : 1;
+
+  /// A manager for pass instrumentations.
+  std::unique_ptr<PassInstrumentor> instrumentor;
+};
+
+/// Register a set of useful command-line options that can be used to configure
+/// a pass manager. The values of these options can be applied via the
+/// 'applyPassManagerCLOptions' method below.
+void registerPassManagerCLOptions();
+
+/// Apply any values provided to the pass manager options that were registered
+/// with 'registerPassManagerOptions'.
+void applyPassManagerCLOptions(PassManager &pm);
+} // end namespace mlir
+
+#endif // MLIR_PASS_PASSMANAGER_H
diff --git a/third_party/mlir/include/mlir/Pass/PassRegistry.h b/third_party/mlir/include/mlir/Pass/PassRegistry.h
new file mode 100644
index 0000000..ea0fbbe
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/PassRegistry.h
@@ -0,0 +1,165 @@
+//===- PassRegistry.h - Pass Registration Utilities -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains utilities for registering information about compiler
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PASS_PASSREGISTRY_H_
+#define MLIR_PASS_PASSREGISTRY_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include <functional>
+
+namespace mlir {
+class Pass;
+class PassManager;
+
+/// A registry function that adds passes to the given pass manager.
+using PassRegistryFunction = std::function<void(PassManager &)>;
+
+using PassAllocatorFunction = std::function<Pass *()>;
+
+/// A special type used by transformation passes to provide an address that can
+/// act as a unique identifier during pass registration.
+using PassID = ClassID;
+
+/// Structure to group information about a passes and pass pipelines (argument
+/// to invoke via mlir-opt, description, pass pipeline builder).
+class PassRegistryEntry {
+public:
+  /// Adds this pass registry entry to the given pass manager.
+  void addToPipeline(PassManager &pm) const {
+    assert(builder &&
+           "Cannot call addToPipeline on PassRegistryEntry without builder");
+    builder(pm);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-opt' that will
+  /// cause this pass to run or null if there is no such argument.
+  StringRef getPassArgument() const { return arg; }
+
+  /// Returns a description for the pass, this never returns null.
+  StringRef getPassDescription() const { return description; }
+
+protected:
+  PassRegistryEntry(StringRef arg, StringRef description,
+                    PassRegistryFunction builder)
+      : arg(arg), description(description), builder(builder) {}
+
+private:
+  // The argument with which to invoke the pass via mlir-opt.
+  StringRef arg;
+
+  // Description of the pass.
+  StringRef description;
+
+  // Function to register this entry to a pass manager pipeline.
+  PassRegistryFunction builder;
+};
+
+/// A structure to represent the information of a registered pass pipeline.
+class PassPipelineInfo : public PassRegistryEntry {
+public:
+  PassPipelineInfo(StringRef arg, StringRef description,
+                   PassRegistryFunction builder)
+      : PassRegistryEntry(arg, description, builder) {}
+};
+
+/// A structure to represent the information for a derived pass class.
+class PassInfo : public PassRegistryEntry {
+public:
+  /// PassInfo constructor should not be invoked directly, instead use
+  /// PassRegistration or registerPass.
+  PassInfo(StringRef arg, StringRef description, const PassID *passID,
+           PassAllocatorFunction allocator);
+};
+
+/// Register a specific dialect pipeline registry function with the system,
+/// typically used through the PassPipelineRegistration template.
+void registerPassPipeline(StringRef arg, StringRef description,
+                          const PassRegistryFunction &function);
+
+/// Register a specific dialect pass allocator function with the system,
+/// typically used through the PassRegistration template.
+void registerPass(StringRef arg, StringRef description, const PassID *passID,
+                  const PassAllocatorFunction &function);
+
+/// PassRegistration provides a global initializer that registers a Pass
+/// allocation routine for a concrete pass instance.  The third argument is
+/// optional and provides a callback to construct a pass that does not have
+/// a default constructor.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static PassRegistration<MyPass> Unused("unused", "Unused pass");
+template <typename ConcretePass> struct PassRegistration {
+  PassRegistration(StringRef arg, StringRef description,
+                   const PassAllocatorFunction &constructor) {
+    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
+  }
+
+  PassRegistration(StringRef arg, StringRef description) {
+    PassAllocatorFunction constructor = [] { return new ConcretePass(); };
+    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
+  }
+};
+
+/// PassPipelineRegistration provides a global initializer that registers a Pass
+/// pipeline builder routine.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   void pipelineBuilder(PassManager &pm) {
+///      pm.addPass(new MyPass());
+///      pm.addPass(new MyOtherPass());
+///   }
+///
+///   static PassPipelineRegistration Unused("unused", "Unused pass",
+///                                          pipelineBuilder);
+struct PassPipelineRegistration {
+  PassPipelineRegistration(StringRef arg, StringRef description,
+                           PassRegistryFunction builder) {
+    registerPassPipeline(arg, description, builder);
+  }
+
+  /// Constructor that accepts a pass allocator function instead of the standard
+  /// registry function. This is useful for registering specializations of
+  /// existing passes.
+  PassPipelineRegistration(StringRef arg, StringRef description,
+                           PassAllocatorFunction allocator);
+};
+
+/// Adds command line option for each registered pass.
+struct PassNameParser : public llvm::cl::parser<const PassRegistryEntry *> {
+  PassNameParser(llvm::cl::Option &opt);
+
+  void initialize();
+
+  void printOptionInfo(const llvm::cl::Option &O,
+                       size_t GlobalWidth) const override;
+};
+} // end namespace mlir
+
+#endif // MLIR_PASS_PASSREGISTRY_H_
diff --git a/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h b/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
new file mode 100644
index 0000000..467512f
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
@@ -0,0 +1,50 @@
+//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a TargetConfiguration for reference fixed-point math
+// quantization scheme based on the FxpMathOps (plus a small category of
+// extension ops that can be added from other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
+#define MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
+
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+
+namespace mlir {
+namespace quantizer {
+
+/// Target configuration for a reference affine/fixed-point quantization
+/// scheme defined in terms of the FxpMathOps dialect. This can be extended
+/// with select ops from other dialects by way of the following public
+/// methods:
+///   - addValueIdentityOp
+class FxpMathTargetConfig : public TargetConfiguration {
+public:
+  /// Creates an FxpMathTargetConfig instance which can be further customized.
+  static std::unique_ptr<FxpMathTargetConfig> create(SolverContext &context);
+
+protected:
+  FxpMathTargetConfig(SolverContext &context) : TargetConfiguration(context) {}
+};
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h b/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h
new file mode 100644
index 0000000..a260824
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h
@@ -0,0 +1,155 @@
+//===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// The quantizer is relatively agnostic to source and target dialects, with
+// the specific represented by configuration policy objects derived from
+// classes in this file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
+#define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
+
+#include <functional>
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Rules.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace mlir {
+class Operation;
+
+namespace quantizer {
+
+class CAGSlice;
+
+/// Defines quantization configuration for the target.
+/// The settings here depend on a variety of details about the deployment
+/// environment, although, where we have control over such things, we do
+/// try to standardize as possible.
+///
+/// Non-const methods are used to setup the configuration. It is expected that
+/// const instances/references are used post-build.
+class TargetConfiguration {
+public:
+  static constexpr size_t MaxSchemeIndex = 31;
+  using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>;
+
+  TargetConfiguration(SolverContext &context);
+  virtual ~TargetConfiguration() = default;
+
+  /// Adds a candidate type, returning its ordinal.
+  unsigned addCandidateType(quant::AnyQuantizedType quantizedType,
+                            CandidateQuantizedType::Scheme scheme) {
+    unsigned ordinal = candidateTypes.size();
+    assert(allCandidateTypesMask.size() == ordinal);
+    CandidateQuantizedType ct{ordinal, quantizedType, scheme};
+    candidateTypes.push_back(ct);
+    allCandidateTypesMask.push_back(true);
+    return ordinal;
+  }
+
+  /// Gets a prototype scheme by index.
+  const CandidateQuantizedType &getCandidateType(unsigned index) const {
+    assert(index < candidateTypes.size());
+    return candidateTypes[index];
+  }
+
+  llvm::ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
+    return candidateTypes;
+  }
+
+  /// Gets a mask of all enabled candidate types by ordinal.
+  llvm::SmallBitVector getAllCandidateTypesMask() const {
+    return allCandidateTypesMask;
+  }
+
+  /// Gets a mask with every candidate type except those in the given mask.
+  llvm::SmallBitVector getCandidateTypeDisabledExceptMask(
+      llvm::ArrayRef<unsigned> exceptOrdinals) const {
+    llvm::SmallBitVector disabled(allCandidateTypesMask);
+    for (unsigned ordinal : exceptOrdinals) {
+      disabled.reset(ordinal);
+    }
+    return disabled;
+  }
+
+  /// Adds an op handler.
+  template <typename OpTy>
+  void addOpHandler(OpHandlerFn fn) {
+    addOpHandlerByName(OpTy::getOperationName(), fn);
+  }
+
+  /// Adds an operation which requires statistics at its result nodes for
+  /// best quantization performance. Note that the opName StringRef is
+  /// expected to come from getOperationName() and be static.
+  template <typename OpTy>
+  void addRequireStatsOp() {
+    addRequireStatsOpByName(OpTy::getOperationName());
+  }
+
+  /// Returns whether opName is a RequireStatsOp.
+  bool isRequireStatsOp(Operation *op) const;
+
+  /// Adds an op which does not mutate its values but may mutate its shape
+  /// or combine its operands in an arbitrary way.
+  /// Such ops are expected to have the same types for operands and results
+  /// and must be capable of operating on storage types.
+  template <typename OpTy>
+  void addValueIdentityOp() {
+    addValueIdentityOpByName(OpTy::getOperationName());
+  }
+
+  /// Handles the operation if a handler is defined for it.
+  void handleOp(Operation *op, CAGSlice &cag) const;
+
+  /// Finalizes the CAG after all anchors have been added.
+  virtual void finalizeAnchors(CAGSlice &cag) const {}
+
+  /// Whether an operand or result type is subject to analysis by this config.
+  virtual bool isHandledType(Type t) const = 0;
+
+protected:
+  virtual void addValueIdentityOpByName(StringRef opName) = 0;
+  void addOpHandlerByName(StringRef name, OpHandlerFn fn);
+
+private:
+  void addRequireStatsOpByName(StringRef opName);
+
+  /// Vector of all candidate type constraints, indexed by ordinal.
+  std::vector<CandidateQuantizedType> candidateTypes;
+
+  // A SmallBoolVector with bits set for all known candidate types.
+  llvm::SmallBitVector allCandidateTypesMask;
+
+  /// Map of all op handlers.
+  llvm::StringMap<OpHandlerFn> opHandlers;
+
+  /// Names of operations which should have their results annotated with
+  /// statistics.
+  llvm::StringSet<> requireStatsOpNames;
+};
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h
new file mode 100644
index 0000000..8f2a0e5
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h
@@ -0,0 +1,374 @@
+//===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file provides graph-based data structures for representing anchors
+// and constraints between them.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
+#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
+
+#include <utility>
+#include <vector>
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+namespace quantizer {
+
+class CAGNode;
+class CAGSlice;
+class TargetConfiguration;
+
+/// A node in the Constraint Analysis Graph.
+/// Nodes are either anchors (representing results and operands) or constraints.
+/// Anchor nodes are connected to other anchor nodes via constraints.
+/// Nodes exist within graph slices, which are typically analyses attached to
+/// the function or module. Slices can contain other slices, which mirrors
+/// the nesting of analyses.
+///
+/// Nodes have directed relationships which propagate successor-ward when dirty.
+/// Relationships can be bi-directional, in which case, the constraint's
+/// propagation mechanism must ensure convergence.
+class CAGNode {
+public:
+  enum class Kind {
+    /// Anchors.
+    Anchor,
+    OperandAnchor,
+    ResultAnchor,
+    LastAnchor = ResultAnchor,
+
+    /// Constraints.
+    Constraint,
+    SolveUniformConstraint,
+    UniformPropagateExplicitScale,
+    LastConstraint = UniformPropagateExplicitScale,
+  };
+
+  // Vector and iterator over nodes.
+  using node_vector = llvm::SmallVector<CAGNode *, 1>;
+  using iterator = node_vector::iterator;
+  using const_iterator = node_vector::const_iterator;
+
+  virtual ~CAGNode() = default;
+
+  Kind getKind() const { return kind; }
+
+  /// Unique id of the node within the slice.
+  int getNodeId() const { return nodeId; }
+
+  /// Whether the node is dirty, requiring one or more calls to propagate().
+  bool isDirty() const { return dirty; }
+  void markDirty() { dirty = true; }
+  void clearDirty() { dirty = false; }
+
+  /// Iterator over this node's children (outgoing) nodes.
+  const_iterator begin() const { return outgoing.begin(); }
+  const_iterator end() const { return outgoing.end(); }
+  iterator begin() { return outgoing.begin(); }
+  iterator end() { return outgoing.end(); }
+
+  /// Iterator over this parents (incoming) nodes.
+  const_iterator incoming_begin() const { return incoming.begin(); }
+  const_iterator incoming_end() const { return incoming.end(); }
+  iterator incoming_begin() { return incoming.begin(); }
+  iterator incoming_end() { return incoming.end(); }
+
+  virtual void propagate(SolverContext &solverContext,
+                         const TargetConfiguration &config) {}
+
+  /// Prints the node label, suitable for one-line display.
+  virtual void printLabel(llvm::raw_ostream &os) const;
+
+  template <typename T>
+  void findChildrenOfKind(llvm::SmallVectorImpl<T *> &found) {
+    for (CAGNode *child : *this) {
+      T *ofKind = llvm::dyn_cast<T>(child);
+      if (ofKind) {
+        found.push_back(ofKind);
+      }
+    }
+  }
+
+  /// Replaces this node by rerouting any parent nodes to have otherNode
+  /// as a child.
+  void replaceIncoming(CAGNode *otherNode);
+
+  /// Adds an outgoing connection to this node (and corresponding back
+  /// incoming connection).
+  void addOutgoing(CAGNode *toNode);
+
+  /// Whether this node is an orphan (has no incoming or outgoing connections).
+  bool isOrphan() const { return incoming.empty() && outgoing.empty(); }
+
+protected:
+  CAGNode(Kind kind) : kind(kind) {}
+
+private:
+  Kind kind;
+  int nodeId = -1;
+  node_vector outgoing;
+  node_vector incoming;
+  bool dirty = false;
+
+  friend class CAGSlice;
+};
+
+/// Anchor nodes represent points in the source IR where we may choose to
+/// introduce a type transition. These include operands, results, arguments
+/// returns, etc.
+class CAGAnchorNode : public CAGNode {
+public:
+  enum class TypeTransformRule {
+    /// The owning op directly supports all transformed types. In practice,
+    /// this means that the op supports QuantizedType for this anchor.
+    Direct,
+
+    /// The type of this anchor should be set to the QuantizedType storage
+    /// type. This will only be valid if constraints are such that all
+    /// inputs/outputs converge to the same storage type (i.e. coupled).
+    DirectStorage,
+
+    /// The anchor must only be typed based on the expressed type. This is
+    /// used for ops that do not natively support quantization, and suitable
+    /// casts will be inserted.
+    ExpressedOnly,
+  };
+
+  /// Metadata for solving uniform quantization params.
+  CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; }
+  const CAGUniformMetadata &getUniformMetadata() const {
+    return uniformMetadata;
+  }
+
+  virtual Operation *getOp() const = 0;
+  virtual Value *getValue() const = 0;
+
+  static bool classof(const CAGNode *n) {
+    return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
+  }
+
+  void propagate(SolverContext &solverContext,
+                 const TargetConfiguration &config) override;
+
+  void printLabel(llvm::raw_ostream &os) const override;
+
+  /// Given the anchor metadata and resolved solutions, chooses the most
+  /// salient and returns an appropriate type to represent it.
+  Type getTransformedType();
+
+  TypeTransformRule getTypeTransformRule() const { return typeTransformRule; }
+
+  void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; }
+
+  /// Gets the Type that was defined for this anchor at the time of
+  /// construction.
+  Type getOriginalType() const { return originalType; }
+
+protected:
+  CAGAnchorNode(Kind kind, Type originalType)
+      : CAGNode(kind), originalType(originalType) {}
+
+private:
+  CAGUniformMetadata uniformMetadata;
+  Type originalType;
+  TypeTransformRule typeTransformRule = TypeTransformRule::Direct;
+};
+
+/// An anchor tied to a specific operand.
+/// Since operand anchors can be rewritten so that the operand refers to
+/// a new result, they are maintained by reference (to the op and index).
+class CAGOperandAnchor : public CAGAnchorNode {
+public:
+  CAGOperandAnchor(Operation *op, unsigned operandIdx);
+
+  Operation *getOp() const final { return op; }
+  unsigned getOperandIdx() const { return operandIdx; }
+
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
+  }
+
+  Value *getValue() const final { return op->getOperand(operandIdx); }
+
+  void printLabel(llvm::raw_ostream &os) const override;
+
+private:
+  Operation *op;
+  unsigned operandIdx;
+};
+
+/// An anchor tied to a specific result.
+/// Since a result is already anchored to its defining op, result anchors refer
+/// directly to the underlying Value*.
+class CAGResultAnchor : public CAGAnchorNode {
+public:
+  CAGResultAnchor(Operation *op, unsigned resultIdx);
+
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
+  }
+
+  Operation *getOp() const final { return resultValue->getDefiningOp(); }
+  Value *getValue() const final { return resultValue; }
+
+  void printLabel(llvm::raw_ostream &os) const override;
+
+private:
+  Value *resultValue;
+};
+
+/// Base class for constraint nodes.
+class CAGConstraintNode : public CAGNode {
+public:
+  CAGConstraintNode(Kind kind) : CAGNode(kind) {}
+
+  static bool classof(const CAGNode *n) {
+    return n->getKind() >= Kind::Constraint &&
+           n->getKind() <= Kind::LastConstraint;
+  }
+};
+
+/// A slice of a CAG (which may be the whole graph).
+class CAGSlice {
+public:
+  CAGSlice(SolverContext &context);
+  ~CAGSlice();
+
+  using node_vector = std::vector<CAGNode *>;
+  using iterator = node_vector::iterator;
+  using const_iterator = node_vector::const_iterator;
+
+  iterator begin() { return allNodes.begin(); }
+  iterator end() { return allNodes.end(); }
+  const_iterator begin() const { return allNodes.begin(); }
+  const_iterator end() const { return allNodes.end(); }
+
+  /// Gets an operand anchor node.
+  CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx);
+
+  /// Gets a result anchor node.
+  CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx);
+
+  /// Adds a relation constraint with incoming 'from' anchors and outgoing 'to'
+  /// anchors.
+  template <typename T, typename... Args>
+  T *addUniqueConstraint(llvm::ArrayRef<CAGAnchorNode *> anchors,
+                         Args... args) {
+    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
+                  "T must be a CAGConstraingNode");
+    T *constraintNode = addNode(llvm::make_unique<T>(args...));
+    for (auto *anchor : anchors)
+      anchor->addOutgoing(constraintNode);
+    return constraintNode;
+  }
+
+  /// Adds a unidirectional constraint from a node to an array of target nodes.
+  template <typename T, typename... Args>
+  T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
+                                 llvm::ArrayRef<CAGAnchorNode *> toAnchors,
+                                 Args... args) {
+    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
+                  "T must be a CAGConstraingNode");
+    T *constraintNode = addNode(llvm::make_unique<T>(args...));
+    fromAnchor->addOutgoing(constraintNode);
+    for (auto *toAnchor : toAnchors) {
+      constraintNode->addOutgoing(toAnchor);
+    }
+    return constraintNode;
+  }
+
+  template <typename T>
+  T *addClusteredConstraint(llvm::ArrayRef<CAGAnchorNode *> anchors) {
+    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
+                  "T must be a CAGConstraingNode");
+    llvm::SmallVector<T *, 8> cluster;
+    for (auto *anchor : anchors) {
+      anchor->findChildrenOfKind<T>(cluster);
+    }
+
+    T *constraintNode;
+    if (cluster.empty()) {
+      // Create new.
+      constraintNode = addNode(llvm::make_unique<T>());
+    } else {
+      // Merge existing.
+      constraintNode = cluster[0];
+      for (size_t i = 1, e = cluster.size(); i < e; ++i) {
+        cluster[i]->replaceIncoming(constraintNode);
+      }
+    }
+    for (auto *anchor : anchors) {
+      anchor->addOutgoing(constraintNode);
+    }
+    return constraintNode;
+  }
+
+  /// Enumerates all implied connections in the slice.
+  /// An implied connection is any two nodes that physically refer to the
+  /// same value in the IR, such as result->operand.
+  /// Typically this will be modeled with some kind of strong or weak
+  /// identity constraint such that types propagate.
+  /// This is usually called when the slice has been fully constructed in
+  /// order to add final constraints.
+  /// It is legal for the callback to modify the graph by adding constraints.
+  void enumerateImpliedConnections(
+      std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback);
+
+  /// Performs one round of propagation, returning the number of nodes
+  /// propagates. If returns > 0, then additional propagate() rounds are
+  /// required.
+  unsigned propagate(const TargetConfiguration &config);
+
+private:
+  /// Adds a node to the graph.
+  /// The node should be a subclass of TransformNode.
+  /// Returns the raw pointer to the node.
+  template <typename T>
+  T *addNode(std::unique_ptr<T> node) {
+    node->nodeId = allNodes.size();
+    T *unownedNode = node.release();
+    allNodes.push_back(unownedNode);
+    return unownedNode;
+  }
+
+  SolverContext &context;
+  std::vector<CAGNode *> allNodes;
+  llvm::DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *>
+      operandAnchors;
+  llvm::DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *>
+      resultAnchors;
+};
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     const CAGNode &node) {
+  node.printLabel(os);
+  return os;
+}
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
new file mode 100644
index 0000000..7e2b61d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
@@ -0,0 +1,58 @@
+//===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides graph traits for constraint analysis graphs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
+#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
+
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "llvm/ADT/GraphTraits.h"
+
+namespace llvm {
+
+template <>
+struct GraphTraits<const mlir::quantizer::CAGNode *> {
+  using NodeRef = const mlir::quantizer::CAGNode *;
+
+  static NodeRef getEntryNode(NodeRef node) { return node; }
+
+  // Successors.
+  using ChildIteratorType = mlir::quantizer::CAGNode::const_iterator;
+  static ChildIteratorType child_begin(NodeRef node) { return node->begin(); }
+  static ChildIteratorType child_end(NodeRef node) { return node->end(); }
+};
+
+template <>
+struct GraphTraits<const mlir::quantizer::CAGSlice *>
+    : public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
+  using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
+  static mlir::quantizer::CAGSlice::const_iterator
+  nodes_begin(const mlir::quantizer::CAGSlice *G) {
+    return G->begin();
+  }
+  static mlir::quantizer::CAGSlice::const_iterator
+  nodes_end(const mlir::quantizer::CAGSlice *G) {
+    return G->end();
+  }
+};
+
+} // end namespace llvm
+
+#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h b/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h
new file mode 100644
index 0000000..a2ed681
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h
@@ -0,0 +1,110 @@
+//===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains top level types needed to construct constraint graphs,
+// including context/allocator support and concrete metadata structs for
+// different quantization schemes (which must be attached to anchor nodes).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_METADATA_H
+#define MLIR_QUANTIZER_SUPPORT_METADATA_H
+
+#include <limits>
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Quantizer/Support/Rules.h"
+#include "llvm/ADT/SmallBitVector.h"
+
+namespace mlir {
+namespace quantizer {
+
+class SolverContext {
+public:
+  SolverContext(MLIRContext &mlirContext) : mlirContext(mlirContext) {}
+
+  MLIRContext &getMlirContext() { return mlirContext; }
+
+  llvm::BumpPtrAllocator &getAllocator() { return allocator; }
+
+  // Optional path to write a debug DOT file for the CAG.
+  StringRef getDebugCAGDotPath() const { return debugCAGDotPath; }
+  void setDebugCAGDotPath(StringRef p) { debugCAGDotPath = p; }
+
+private:
+  MLIRContext &mlirContext;
+  llvm::BumpPtrAllocator allocator;
+  std::string debugCAGDotPath;
+};
+
+/// Candidate for a quantized type conversion.
+struct CandidateQuantizedType {
+  // Note that scheme encodes more than just the target type: it also encodes
+  // additional constraints.
+  enum class Scheme {
+    // Uses aggregate range information for all nodes in the cluster to
+    // solve for uniform scale and zero point.
+    UniformPerLayer,
+    // Uses aggregate per-axis range information for all nodes in the cluster
+    // to solve for per-axis uniform scale and zero point.
+    UniformPerAxisFixedPoint,
+    // Uses the |explicitScaleZeroPoint| to set the scale (and zero point = 0)
+    // for the uniform type. This typically overrides all other constraints
+    // and is used for wide accumulator types (i.e. i32 bias vectors).
+    UniformExplicitFixedPointScale,
+  };
+  unsigned ordinal;
+  quant::AnyQuantizedType quantizedType;
+  Scheme scheme;
+};
+
+struct CAGUniformMetadata {
+  /// Default salience for facts that are derived from data either statically
+  /// discovered in the computation or observed from an outside source.
+  static constexpr int SalienceDefault = 0;
+
+  /// Highest salience level for facts derived from overrides provided
+  /// explicitly.
+  static constexpr int SalienceForced = 100;
+
+  /// Salience for facts derived from constraints in how the math is
+  /// expressed which must be satisfied.
+  static constexpr int SalienceRequired = 200;
+
+  /// The range that the scheme must represent in order to accomadate the
+  /// underlying data.
+  ExpandingMinMaxFact requiredRange;
+
+  /// Bool vector of scheme ordinals that are disabled.
+  llvm::SmallBitVector disabledCandidateTypes;
+
+  /// If set, then a solution has converged for the given per-layer scheme.
+  quant::QuantizedType selectedType;
+
+  /// Optional scale and zero point to be used by types which solve via the
+  /// UniformExplicitFixedPointScale scheme.
+  DiscreteScaleZeroPointFact explicitScaleZeroPoint;
+
+  /// Prints a summary of the metadata suitable for display in a graph label.
+  void printSummary(llvm::raw_ostream &os) const;
+};
+
+} // end namespace quantizer
+} // end namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_METADATA_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Rules.h b/third_party/mlir/include/mlir/Quantizer/Support/Rules.h
new file mode 100644
index 0000000..9d1e53d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Rules.h
@@ -0,0 +1,209 @@
+//===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines helper classes and functions for managing state (facts),
+// merging and tracking modification for various data types important for
+// quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
+#define MLIR_QUANTIZER_SUPPORT_RULES_H
+
+#include "llvm/ADT/Optional.h"
+
+#include <algorithm>
+#include <limits>
+#include <utility>
+
+namespace mlir {
+namespace quantizer {
+
+/// Typed indicator of whether a mutator produces a modification.
+struct ModificationResult {
+  enum ModificationEnum { Retained, Modified } value;
+  ModificationResult(ModificationEnum v) : value(v) {}
+
+  ModificationResult operator|(ModificationResult other) {
+    if (value == Modified || other.value == Modified) {
+      return ModificationResult(Modified);
+    } else {
+      return ModificationResult(Retained);
+    }
+  }
+
+  ModificationResult operator|=(ModificationResult other) {
+    value =
+        (value == Modified || other.value == Modified) ? Modified : Retained;
+    return *this;
+  }
+};
+
+inline ModificationResult modify(bool isModified = true) {
+  return ModificationResult{isModified ? ModificationResult::Modified
+                                       : ModificationResult::Retained};
+}
+
+inline bool modified(ModificationResult m) {
+  return m.value == ModificationResult::Modified;
+}
+
+/// A fact that can converge through forward propagation alone without the
+/// need to track ownership or individual assertions. In practice, this works
+/// for static assertions that are either minimized or maximized and do not
+/// vary dynamically.
+///
+/// It is expected that ValueTy is appropriate to pass by value and has an
+/// operator==. The BinaryReducer type should have two static methods:
+///   using ValueTy : Type of the value.
+///   ValueTy initialValue() : Returns the initial value of the fact.
+///   ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
+template <typename BinaryReducer>
+class BasePropagatedFact {
+public:
+  using ValueTy = typename BinaryReducer::ValueTy;
+  using ThisTy = BasePropagatedFact<BinaryReducer>;
+  BasePropagatedFact()
+      : value(BinaryReducer::initialValue()),
+        salience(std::numeric_limits<int>::min()) {}
+
+  int getSalience() const { return salience; }
+  bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
+  ValueTy getValue() const { return value; }
+  ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
+    if (assertSalience > salience) {
+      // New salience band.
+      value = assertValue;
+      salience = assertSalience;
+      return modify(true);
+    } else if (assertSalience < salience) {
+      // Lower salience - ignore.
+      return modify(false);
+    }
+    // Merge within same salience band.
+    ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
+    auto mod = modify(value != updatedValue);
+    value = updatedValue;
+    return mod;
+  }
+  ModificationResult mergeFrom(const ThisTy &other) {
+    if (other.hasValue()) {
+      return assertValue(other.getSalience(), other.getValue());
+    }
+    return modify(false);
+  }
+
+private:
+  ValueTy value;
+  int salience;
+};
+
+/// A binary reducer that expands a min/max range represented by a pair
+/// of doubles such that it represents the largest of all inputs.
+/// The initial value is (Inf, -Inf).
+struct ExpandingMinMaxReducer {
+  using ValueTy = std::pair<double, double>;
+  static ValueTy initialValue() {
+    return std::make_pair(std::numeric_limits<double>::infinity(),
+                          -std::numeric_limits<double>::infinity());
+  }
+  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
+    return std::make_pair(std::min(lhs.first, rhs.first),
+                          std::max(lhs.second, rhs.second));
+  }
+};
+using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;
+
+/// A binary reducer that minimizing a numeric type.
+template <typename T>
+struct MinimizingNumericReducer {
+  using ValueTy = T;
+  static ValueTy initialValue() {
+    if (std::numeric_limits<T>::has_infinity()) {
+      return std::numeric_limits<T>::infinity();
+    } else {
+      return std::numeric_limits<T>::max();
+    }
+  }
+  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
+};
+using MinimizingDoubleFact =
+    BasePropagatedFact<MinimizingNumericReducer<double>>;
+using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;
+
+/// A binary reducer that maximizes a numeric type.
+template <typename T>
+struct MaximizingNumericReducer {
+  using ValueTy = T;
+  static ValueTy initialValue() {
+    if (std::numeric_limits<T>::has_infinity()) {
+      return -std::numeric_limits<T>::infinity();
+    } else {
+      return std::numeric_limits<T>::min();
+    }
+  }
+  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
+};
+using MaximizingDoubleFact =
+    BasePropagatedFact<MaximizingNumericReducer<double>>;
+using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;
+
+/// A fact and reducer for tracking agreement of discrete values. The value
+/// type consists of a |T| value and a flag indicating whether there is a
+/// conflict (in which case, the preserved value is arbitrary).
+template <typename T>
+struct DiscreteReducer {
+  struct ValueTy {
+    ValueTy() : conflict(false) {}
+    ValueTy(T value) : value(value), conflict(false) {}
+    ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
+    llvm::Optional<T> value;
+    bool conflict;
+    bool operator==(const ValueTy &other) const {
+      if (conflict != other.conflict)
+        return false;
+      if (value && other.value) {
+        return *value == *other.value;
+      } else {
+        return !value && !other.value;
+      }
+    }
+    bool operator!=(const ValueTy &other) const { return !(*this == other); }
+  };
+  static ValueTy initialValue() { return ValueTy(); }
+  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
+    if (!lhs.value && !rhs.value)
+      return lhs;
+    else if (!lhs.value)
+      return rhs;
+    else if (!rhs.value)
+      return lhs;
+    else
+      return ValueTy(*lhs.value, *lhs.value != *rhs.value);
+  }
+};
+
+template <typename T>
+using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;
+
+/// Discrete scale/zeroPoint fact.
+using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;
+
+} // end namespace quantizer
+} // end namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_RULES_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h b/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h
new file mode 100644
index 0000000..c6f059e
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h
@@ -0,0 +1,94 @@
+//===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines adapters for extracting various (per layer and per axis)
+// statistics over tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_STATISTICS_H
+#define MLIR_QUANTIZER_SUPPORT_STATISTICS_H
+
+#include "mlir/IR/Attributes.h"
+
+namespace mlir {
+namespace quantizer {
+
+/// Statistics about a tensor axis (or the whole tensor).
+struct TensorAxisStatistics {
+  int64_t sampleSize = 0;
+  double minValue = 0;
+  double maxValue = 0;
+  double mean = 0;
+  double variance = 0;
+
+  TensorAxisStatistics() {}
+  TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
+                       double mean, double variance)
+      : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
+        mean(mean), variance(variance) {}
+  void clear() { *this = TensorAxisStatistics(); }
+};
+
+/// Base class for querying statistics about a tensor.
+class AbstractTensorStatistics {
+public:
+  virtual ~AbstractTensorStatistics() = default;
+
+  /// Gets statistics across the whole tensor.
+  /// Returns true if statistics are valid and were populated.
+  virtual bool get(TensorAxisStatistics &stats) const { return false; }
+
+  /// Whether this instance supports querying per axis statistics. If true,
+  /// then getForAxis(...) can be used.
+  virtual bool supportsPerAxis() const { return false; }
+
+  /// Count of axises supported in a per-axis query.
+  virtual unsigned getAxisCount() const { return 0; }
+
+  /// Gets statistics for a specific axis (0..getAxisCount() - 1).
+  /// Returns true if statistics are valid and were populated.
+  virtual bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const {
+    return false;
+  }
+};
+
+/// Wraps an MLIR Attribte and returns statistics about it.
+/// It is expected that the attribute be one of:
+///   FloatAttr (scalar)
+///   DenseFPElementsAttr
+///   OpaqueElementsAttr (with Float based type)
+///   SparseElementAttr  (with Float based type)
+class AttributeTensorStatistics : public AbstractTensorStatistics {
+public:
+  AttributeTensorStatistics(Attribute attr) : attr(attr) {}
+
+  bool get(TensorAxisStatistics &stats) const override;
+
+  // TODO: Implement per-axis.
+
+private:
+  Attribute attr;
+};
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TensorAxisStatistics &stats);
+
+} // end namespace quantizer
+} // end namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_STATISTICS_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h b/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h
new file mode 100644
index 0000000..074f8b9
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h
@@ -0,0 +1,40 @@
+//===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines various helper functions for manipulating types. The
+// process of quantizing typically involves a number of type manipulations
+// that are not very common elsewhere, and it is best to name them and define
+// them here versus inline in the rest of the tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
+#define THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace quantizer {
+
+/// Given an arbitrary container or primitive type, returns the element type,
+/// where the element type is just the type for non-containers.
+Type getElementOrPrimitiveType(Type t);
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
new file mode 100644
index 0000000..90b5fe1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
@@ -0,0 +1,69 @@
+//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a builder that lets you attach constraints necessary to
+// perform a variety of uniform quantization conversions to CAG anchors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+#define MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+
+#include "mlir/Quantizer/Support/Statistics.h"
+
+namespace mlir {
+namespace quantizer {
+
+class CAGAnchorNode;
+class CAGSlice;
+
+/// Factory methods for adding CAG constraints of various kinds suitable
+/// for solving for uniform quantization.
+class UniformConstraintsBuilder {
+public:
+  UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}
+
+  /// Adds a coupling constraint between two nodes, effectively treating
+  /// them as a hard identity relationship.
+  void coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b);
+
+  /// Applies statistics constraints to the given anchor, such that the solver
+  /// ensures that the statistics are representable by chosen types.
+  void applyStats(CAGAnchorNode *a, TensorAxisStatistics stats);
+
+  /// Applies a constraint to a node which allows solutions that do not extend
+  /// beyond given min/max bounds (this is a hint that the tensor will not
+  /// take values outside of these bounds). If either minValue or maxValue is
+  /// NAN, then that side is considered open.
+  void clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue);
+
+  /// Propagates an explicit scale from an anchor that may have a uniform
+  /// |selectedType| to the |explicitScaleZeroPoint| field of the to node.
+  /// This is typically used with a to node that has a candidate quantized
+  /// type of |UniformExplicitFixedPointScale|, indicating that it can be
+  /// an arbitrary (signed) type that is expected to share the same scale
+  /// as the originating node.
+  void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);
+
+private:
+  CAGSlice &slice;
+};
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h b/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h
new file mode 100644
index 0000000..0759758
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h
@@ -0,0 +1,95 @@
+//===- UniformSolvers.h - Uniform type solver algorithms --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines algorithms for solving uniform type parameters for various
+// conditions (i.e. fixed-point, affine, scale matching, etc).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
+#define MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
+
+#include <cstdint>
+#include <limits>
+
+namespace llvm {
+class raw_ostream;
+} // end namespace llvm
+
+namespace mlir {
+namespace quantizer {
+
+struct UniformStorageParams {
+  static UniformStorageParams getQuint8() { return {255, 0}; }
+  static UniformStorageParams getQuint8SymmetricRight() { return {254, 1}; }
+  static UniformStorageParams getQuint16() { return {32767, 0}; }
+
+  uint64_t numLevels;
+  int64_t minValue;
+};
+
+/// Solves for the uniform quantization scheme paramers delta and z given
+/// bounding min/max.
+class UniformParamsFromMinMaxSolver {
+public:
+  UniformParamsFromMinMaxSolver(const UniformStorageParams &storageParams,
+                                double boundingMin, double boundingMax)
+      : storageParams(storageParams), boundingMin(boundingMin),
+        boundingMax(boundingMax) {}
+
+  /// Performs the computation, returning whether satisfied.
+  bool compute();
+
+  // Params.
+  double getBoundingMin() const { return boundingMin; }
+  double getBoundingMax() const { return boundingMax; }
+  bool isSatisfied() const { return satisfied; }
+  double getAdjMin() const { return adjMin; }
+  double getAdjMax() const { return adjMax; }
+  double getScale() const { return delta; }
+  int64_t getZp() const { return zp; }
+  int getStepCount() const { return stepCount; }
+
+  // Quantize and dequantize.
+  int64_t quantize(double x) const;
+  double dequantize(int64_t xq) const;
+
+private:
+  const UniformStorageParams storageParams;
+  const double boundingMin;
+  const double boundingMax;
+
+  // Results
+  int stepCount = 0;
+  double adjMin = std::numeric_limits<double>::quiet_NaN();
+  double adjMax = std::numeric_limits<double>::quiet_NaN();
+  double delta = std::numeric_limits<double>::quiet_NaN();
+  int64_t zp = 0;
+
+  bool satisfied = false;
+};
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const UniformStorageParams &p);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const UniformParamsFromMinMaxSolver &s);
+
+} // end namespace quantizer
+} // end namespace mlir
+
+#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
diff --git a/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h b/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h
new file mode 100644
index 0000000..0d7b4cb
--- /dev/null
+++ b/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h
@@ -0,0 +1,51 @@
+//===- Passes.h - Quantizer passes  -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines entry points to create passes to perform various kinds
+// of quantization related transforms.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quantizer {
+
+class SolverContext;
+class TargetConfiguration;
+
+/// Creates a pass that infers quantized types based on metadata discovered
+/// in the computation.
+ModulePassBase *
+createInferQuantizedTypesPass(SolverContext &solverContext,
+                              const TargetConfiguration &config);
+
+/// Creates a pass which removes any instrumentation and hint ops which have
+/// no effect on final runtime.
+FunctionPassBase *createRemoveInstrumentationPass();
+
+/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
+/// Meant for testing.
+FunctionPassBase *createAddDefaultStatsPass();
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H
diff --git a/third_party/mlir/include/mlir/SDBM/SDBM.h b/third_party/mlir/include/mlir/SDBM/SDBM.h
new file mode 100644
index 0000000..b1c2723
--- /dev/null
+++ b/third_party/mlir/include/mlir/SDBM/SDBM.h
@@ -0,0 +1,206 @@
+//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
+// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef INCLUDE_MLIR_IR_SDBM_H
+#define INCLUDE_MLIR_IR_SDBM_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+
+class MLIRContext;
+class SDBMDialect;
+class SDBMExpr;
+class SDBMPositiveExpr;
+
+/// A utility class for SDBM to represent an integer with potentially infinite
+/// positive value. This uses the largest value of int64_t to represent infinity
+/// and redefines the arithmetic operators so that the infinity "saturates":
+///   inf + x = inf,
+///   inf - x = inf.
+/// If a sum of two finite values reaches the largest value of int64_t, the
+/// behavior of IntInfty is undefined (in practice, it asserts), similarly to
+/// regular signed integer overflow.
+class IntInfty {
+public:
+  constexpr static int64_t infty = std::numeric_limits<int64_t>::max();
+
+  /*implicit*/ IntInfty(int64_t v) : value(v) {}
+
+  IntInfty &operator=(int64_t v) {
+    value = v;
+    return *this;
+  }
+
+  static IntInfty infinity() { return IntInfty(infty); }
+
+  int64_t getValue() const { return value; }
+  explicit operator int64_t() const { return value; }
+
+  bool isFinite() { return value != infty; }
+
+private:
+  int64_t value;
+};
+
+inline IntInfty operator+(IntInfty lhs, IntInfty rhs) {
+  if (!lhs.isFinite() || !rhs.isFinite())
+    return IntInfty::infty;
+
+  // Check for overflows, treating the sum of two values adding up to INT_MAX as
+  // overflow.  Convert values to unsigned to get an extra bit and avoid the
+  // undefined behavior of signed integer overflows.
+  assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 ||
+          static_cast<uint64_t>(lhs.getValue()) +
+                  static_cast<uint64_t>(rhs.getValue()) <
+              static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) &&
+         "IntInfty overflow");
+  // Check for underflows by converting values to unsigned to avoid undefined
+  // behavior of signed integers perform the addition (bitwise result is same
+  // because numbers are required to be two's complement in C++) and check if
+  // the sign bit remains negative.
+  assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 ||
+          ((static_cast<uint64_t>(lhs.getValue()) +
+            static_cast<uint64_t>(rhs.getValue())) >>
+           63) == 1) &&
+         "IntInfty underflow");
+
+  return lhs.getValue() + rhs.getValue();
+}
+
+inline bool operator<(IntInfty lhs, IntInfty rhs) {
+  return lhs.getValue() < rhs.getValue();
+}
+
+inline bool operator<=(IntInfty lhs, IntInfty rhs) {
+  return lhs.getValue() <= rhs.getValue();
+}
+
+inline bool operator==(IntInfty lhs, IntInfty rhs) {
+  return lhs.getValue() == rhs.getValue();
+}
+
+inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); }
+
+/// Striped difference-bound matrix is a representation of an integer set bound
+/// by a system of SDBMExprs interpreted as inequalities "expr <= 0".
+class SDBM {
+public:
+  /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and
+  /// equalities with zero.
+  static SDBM get(ArrayRef<SDBMExpr> inequalities,
+                  ArrayRef<SDBMExpr> equalities);
+
+  void getSDBMExpressions(SDBMDialect *dialect,
+                          SmallVectorImpl<SDBMExpr> &inequalities,
+                          SmallVectorImpl<SDBMExpr> &equalities);
+
+  void print(llvm::raw_ostream &os);
+  void dump();
+
+  IntInfty operator()(int i, int j) { return at(i, j); }
+
+private:
+  /// Get the given element of the difference bounds matrix.  First index
+  /// corresponds to the negative term of the difference, second index
+  /// corresponds to the positive term of the difference.
+  IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; }
+
+  /// Populate `inequalities` and `equalities` based on the values at(row,col)
+  /// and at(col,row) of the DBM.  Depending on the values being finite and
+  /// being subsumed by stripe expressions, this may or may not add elements to
+  /// the lists of equalities and inequalities.
+  void convertDBMElement(unsigned row, unsigned col, SDBMPositiveExpr rowExpr,
+                         SDBMPositiveExpr colExpr,
+                         SmallVectorImpl<SDBMExpr> &inequalities,
+                         SmallVectorImpl<SDBMExpr> &equalities);
+
+  /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only
+  /// adds new inequalities if the inequality is not trivially true.
+  void convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr,
+                                 SmallVectorImpl<SDBMExpr> &inequalities);
+
+  /// Get the total number of elements in the matrix.
+  unsigned getNumVariables() const {
+    return 1 + numDims + numSymbols + numTemporaries;
+  }
+
+  /// Get the position in the matrix that corresponds to the given dimension.
+  unsigned getDimPosition(unsigned position) const { return 1 + position; }
+
+  /// Get the position in the matrix that corresponds to the given symbol.
+  unsigned getSymbolPosition(unsigned position) const {
+    return 1 + numDims + position;
+  }
+
+  /// Get the position in the matrix that corresponds to the given temporary.
+  unsigned getTemporaryPosition(unsigned position) const {
+    return 1 + numDims + numSymbols + position;
+  }
+
+  /// Number of dimensions in the system,
+  unsigned numDims;
+  /// Number of symbols in the system.
+  unsigned numSymbols;
+  /// Number of temporary variables in the system.
+  unsigned numTemporaries;
+
+  /// Difference bounds matrix, stored as a linearized row-major vector.
+  /// Each value in this matrix corresponds to an inequality
+  ///
+  ///   v@col - v@row <= at(row, col)
+  ///
+  /// where v@col and v@row are the variables that correspond to the linearized
+  /// position in the matrix.  The positions correspond to
+  ///
+  ///   - constant 0 (producing constraints v@col <= X and -v@row <= Y);
+  ///   - SDBM expression dimensions (d0, d1, ...);
+  ///   - SDBM expression symbols (s0, s1, ...);
+  ///   - temporary variables (t0, t1, ...).
+  ///
+  /// Temporary variables are introduced to represent expressions that are not
+  /// trivially a difference between two variables.  For example, if one side of
+  /// a difference expression is itself a stripe expression, it will be replaced
+  /// with a temporary variable assigned equal to this expression.
+  ///
+  /// Infinite entries in the matrix correspond correspond to an absence of a
+  /// constraint:
+  ///
+  ///   v@col - v@row <= infinity
+  ///
+  /// is trivially true.  Negated values at symmetric positions in the matrix
+  /// allow one to couple two inequalities into a single equality.
+  std::vector<IntInfty> matrix;
+
+  /// The mapping between the indices of variables in the DBM and the stripe
+  /// expressions they are equal to.  These expressions are stored as they
+  /// appeared when constructing an SDBM from a SDBMExprs, in particular no
+  /// temporaries can appear in these expressions.  This removes the need to
+  /// iteratively substitute definitions of the temporaries in the reverse
+  /// conversion.
+  llvm::DenseMap<unsigned, SDBMExpr> stripeToPoint;
+};
+
+} // namespace mlir
+
+#endif // INCLUDE_MLIR_IR_SDBM_H
diff --git a/third_party/mlir/include/mlir/SDBM/SDBMDialect.h b/third_party/mlir/include/mlir/SDBM/SDBMDialect.h
new file mode 100644
index 0000000..12086dc
--- /dev/null
+++ b/third_party/mlir/include/mlir/SDBM/SDBMDialect.h
@@ -0,0 +1,41 @@
+//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SDBM_SDBMDIALECT_H
+#define MLIR_SDBM_SDBMDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+class MLIRContext;
+
+class SDBMDialect : public Dialect {
+public:
+  SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
+
+  static StringRef getDialectNamespace() { return "sdbm"; }
+
+  /// Get the uniquer for SDBM expressions. This should not be used directly.
+  StorageUniquer &getUniquer() { return uniquer; }
+
+private:
+  StorageUniquer uniquer;
+};
+} // namespace mlir
+
+#endif // MLIR_SDBM_SDBMDIALECT_H
diff --git a/third_party/mlir/include/mlir/SDBM/SDBMExpr.h b/third_party/mlir/include/mlir/SDBM/SDBMExpr.h
new file mode 100644
index 0000000..afbeda1
--- /dev/null
+++ b/third_party/mlir/include/mlir/SDBM/SDBMExpr.h
@@ -0,0 +1,530 @@
+//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPR_H
+#define MLIR_IR_SDBMEXPR_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+
+namespace mlir {
+
+class AffineExpr;
+class MLIRContext;
+
+enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
+
+namespace detail {
+struct SDBMExprStorage;
+struct SDBMBinaryExprStorage;
+struct SDBMDiffExprStorage;
+struct SDBMPositiveExprStorage;
+struct SDBMConstantExprStorage;
+struct SDBMNegExprStorage;
+} // namespace detail
+
+class SDBMConstantExpr;
+class SDBMDialect;
+class SDBMDimExpr;
+class SDBMSymbolExpr;
+
+/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
+/// expression for the SDBM framework.  SDBM expressions are a subset of affine
+/// expressions supporting low-complexity algorithms for the operations used in
+/// loop transformations.  In particular, are supported:
+///   - constant expressions;
+///   - single variables (dimensions and symbols) with +1 or -1 coefficient;
+///   - stripe expressions: "x # C", where "x" is a single variable or another
+///     stripe expression, "#" is the stripe operator, and "C" is a constant
+///     expression; "#" is defined as x - x mod C.
+///   - sum expressions between single variable/stripe expressions and constant
+///     expressions;
+///   - difference expressions between single variable/stripe expressions.
+/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
+/// and operating on SDBM expressions.  For example, it requires the LHS of a
+/// sum expression to be a single variable or a stripe expression.  These
+/// restrictions are intended to force the caller to perform the necessary
+/// simplifications to stay within the SDBM domain, because SDBM expressions do
+/// not combine in more cases than they do.  This choice may be reconsidered in
+/// the future.
+///
+/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
+/// an MLIRContext, and should be used by-value.  They are uniqued in the
+/// MLIRContext and immortal.
+class SDBMExpr {
+public:
+  using ImplType = detail::SDBMExprStorage;
+  SDBMExpr() : impl(nullptr) {}
+  /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
+
+  /// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
+  /// which makes them trivially assignable and trivially copyable.
+  SDBMExpr(const SDBMExpr &) = default;
+  SDBMExpr &operator=(const SDBMExpr &) = default;
+
+  /// SDBM expressions can be compared straight-forwardly.
+  bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
+  bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
+
+  /// SDBM expressions are convertible to `bool`: null expressions are converted
+  /// to false, non-null expressions are converted to true.
+  explicit operator bool() const { return impl != nullptr; }
+  bool operator!() const { return !static_cast<bool>(*this); }
+
+  /// Negate the given SDBM expression.
+  SDBMExpr operator-();
+
+  /// Prints the SDBM expression.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// LLVM-style casts.
+  template <typename U> bool isa() const { return U::isClassFor(*this); }
+  template <typename U> U dyn_cast() const {
+    if (!isa<U>())
+      return {};
+    return U(const_cast<SDBMExpr *>(this)->impl);
+  }
+  template <typename U> U cast() const {
+    assert(isa<U>() && "cast to incorrect subtype");
+    return U(const_cast<SDBMExpr *>(this)->impl);
+  }
+
+  /// Support for LLVM hashing.
+  ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
+
+  /// Returns the kind of the SDBM expression.
+  SDBMExprKind getKind() const;
+
+  /// Returns the MLIR context in which this expression lives.
+  MLIRContext *getContext() const;
+
+  /// Returns the SDBM dialect instance.
+  SDBMDialect *getDialect() const;
+
+  /// Convert the SDBM expression into an Affine expression.  This always
+  /// succeeds because SDBM are a subset of affine.
+  AffineExpr getAsAffineExpr() const;
+
+  /// Try constructing an SDBM expression from the given affine expression.
+  /// This may fail if the affine expression is not representable as SDBM, in
+  /// which case llvm::None is returned.  The conversion procedure recognizes
+  /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B)
+  /// patterns for the stripe expression.
+  static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
+
+protected:
+  ImplType *impl;
+};
+
+/// SDBM constant expression, wraps a 64-bit integer.
+class SDBMConstantExpr : public SDBMExpr {
+public:
+  using ImplType = detail::SDBMConstantExprStorage;
+
+  using SDBMExpr::SDBMExpr;
+
+  /// Obtain or create a constant expression unique'ed in the given dialect
+  /// (which belongs to a context).
+  static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Constant;
+  }
+
+  int64_t getValue() const;
+};
+
+/// SDBM varying expression can be one of:
+///   - input variable expression;
+///   - stripe expression;
+///   - negation (product with -1) of either of the above.
+///   - sum of a varying and a constant expression
+///   - difference between varying expressions
+class SDBMVaryingExpr : public SDBMExpr {
+public:
+  using ImplType = detail::SDBMExprStorage;
+  using SDBMExpr::SDBMExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId ||
+           expr.getKind() == SDBMExprKind::Neg ||
+           expr.getKind() == SDBMExprKind::Stripe ||
+           expr.getKind() == SDBMExprKind::Add ||
+           expr.getKind() == SDBMExprKind::Diff;
+  }
+};
+
+/// SDBM positive variable expression can be one of:
+///  - single variable expression;
+///  - stripe expression.
+class SDBMPositiveExpr : public SDBMVaryingExpr {
+public:
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId ||
+           expr.getKind() == SDBMExprKind::Stripe;
+  }
+};
+
+/// SDBM sum expression.  LHS is a varying expression and RHS is always a
+/// constant expression.
+class SDBMSumExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMBinaryExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a sum expression unique'ed in the given context.
+  static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    SDBMExprKind kind = expr.getKind();
+    return kind == SDBMExprKind::Add;
+  }
+
+  SDBMVaryingExpr getLHS() const;
+  SDBMConstantExpr getRHS() const;
+};
+
+/// SDBM difference expression.  Both LHS and RHS are positive variable
+/// expressions.
+class SDBMDiffExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMDiffExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a difference expression unique'ed in the given context.
+  static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Diff;
+  }
+
+  SDBMPositiveExpr getLHS() const;
+  SDBMPositiveExpr getRHS() const;
+};
+
+/// SDBM stripe expression "x # C" where "x" is a positive variable expression,
+/// "C" is a constant expression and "#" is the stripe operator defined as:
+///   x # C = x - x mod C.
+class SDBMStripeExpr : public SDBMPositiveExpr {
+public:
+  using ImplType = detail::SDBMBinaryExprStorage;
+  using SDBMPositiveExpr::SDBMPositiveExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Stripe;
+  }
+
+  static SDBMStripeExpr get(SDBMPositiveExpr var,
+                            SDBMConstantExpr stripeFactor);
+
+  SDBMPositiveExpr getVar() const;
+  SDBMConstantExpr getStripeFactor() const;
+};
+
+/// SDBM "input" variable expression can be either a dimension identifier or
+/// a symbol identifier.  When used to define SDBM functions, dimensions are
+/// interpreted as function arguments while symbols are treated as unknown but
+/// constant values, hence the name.
+class SDBMInputExpr : public SDBMPositiveExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMPositiveExpr::SDBMPositiveExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId;
+  }
+
+  unsigned getPosition() const;
+};
+
+/// SDBM dimension expression.  Dimensions correspond to function arguments
+/// when defining functions using SDBM expressions.
+class SDBMDimExpr : public SDBMInputExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMInputExpr::SDBMInputExpr;
+
+  /// Obtain or create a dimension expression unique'ed in the given dialect
+  /// (which belongs to a context).
+  static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId;
+  }
+};
+
+/// SDBM symbol expression.  Symbols correspond to symbolic constants when
+/// defining functions using SDBM expressions.
+class SDBMSymbolExpr : public SDBMInputExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMInputExpr::SDBMInputExpr;
+
+  /// Obtain or create a symbol expression unique'ed in the given dialect (which
+  /// belongs to a context).
+  static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::SymbolId;
+  }
+};
+
+/// Negation of an SDBM variable expression.  Equivalent to multiplying the
+/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
+class SDBMNegExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMNegExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a negation expression unique'ed in the given context.
+  static SDBMNegExpr get(SDBMPositiveExpr var);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Neg;
+  }
+
+  SDBMPositiveExpr getVar() const;
+};
+
+/// A visitor class for SDBM expressions.  Calls the kind-specific function
+/// depending on the kind of expression it visits.
+template <typename Derived, typename Result = void> class SDBMVisitor {
+public:
+  /// Visit the given SDBM expression, dispatching to kind-specific functions.
+  Result visit(SDBMExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    switch (expr.getKind()) {
+    case SDBMExprKind::Add:
+    case SDBMExprKind::Diff:
+    case SDBMExprKind::DimId:
+    case SDBMExprKind::SymbolId:
+    case SDBMExprKind::Neg:
+    case SDBMExprKind::Stripe:
+      return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
+    case SDBMExprKind::Constant:
+      return derived->visitConstant(expr.cast<SDBMConstantExpr>());
+    }
+
+    llvm_unreachable("unsupported SDBM expression kind");
+  }
+
+  /// Traverse the SDBM expression tree calling `visit` on each node
+  /// in depth-first preorder.
+  void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
+
+  /// Traverse the SDBM expression tree calling `visit` on each node in
+  /// depth-first postorder.
+  void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
+
+protected:
+  /// Default visitors do nothing.
+  void visitSum(SDBMSumExpr) {}
+  void visitDiff(SDBMDiffExpr) {}
+  void visitStripe(SDBMStripeExpr) {}
+  void visitDim(SDBMDimExpr) {}
+  void visitSymbol(SDBMSymbolExpr) {}
+  void visitNeg(SDBMNegExpr) {}
+  void visitConstant(SDBMConstantExpr) {}
+
+  /// Default implementation of visitPositive dispatches to the special
+  /// functions for stripes and other variables.  Concrete visitors can override
+  /// it.
+  Result visitPositive(SDBMPositiveExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (expr.getKind() == SDBMExprKind::Stripe)
+      return derived->visitStripe(expr.cast<SDBMStripeExpr>());
+    else
+      return derived->visitInput(expr.cast<SDBMInputExpr>());
+  }
+
+  /// Default implementation of visitInput dispatches to the special
+  /// functions for dimensions or symbols.  Concrete visitors can override it to
+  /// visit all variables instead.
+  Result visitInput(SDBMInputExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (expr.getKind() == SDBMExprKind::DimId)
+      return derived->visitDim(expr.cast<SDBMDimExpr>());
+    else
+      return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
+  }
+
+  /// Default implementation of visitVarying dispatches to the special
+  /// functions for variables and negations thereof.  Concerete visitors can
+  /// override it to visit all variables and negations instead.
+  Result visitVarying(SDBMVaryingExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
+      return derived->visitPositive(var);
+    else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
+      return derived->visitNeg(neg);
+    else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
+      return derived->visitSum(sum);
+    else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
+      return derived->visitDiff(diff);
+
+    llvm_unreachable("unhandled subtype of varying SDBM expression");
+  }
+
+  template <bool isPreorder> void walk(SDBMExpr expr) {
+    if (isPreorder)
+      visit(expr);
+    if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
+      walk<isPreorder>(sumExpr.getLHS());
+      walk<isPreorder>(sumExpr.getRHS());
+    } else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
+      walk<isPreorder>(diffExpr.getLHS());
+      walk<isPreorder>(diffExpr.getRHS());
+    } else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
+      walk<isPreorder>(stripeExpr.getVar());
+      walk<isPreorder>(stripeExpr.getStripeFactor());
+    } else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
+      walk<isPreorder>(negExpr.getVar());
+    }
+    if (!isPreorder)
+      visit(expr);
+  }
+};
+
+/// Overloaded arithmetic operators for SDBM expressions asserting that their
+/// arguments have the proper SDBM expression subtype.  Perform canonicalization
+/// and constant folding on these expressions.
+namespace ops_assertions {
+
+/// Add two SDBM expressions.  At least one of the expressions must be a
+/// constant or a negation, but both expressions cannot be negations
+/// simultaneously.
+SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
+inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
+  return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
+}
+inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
+  return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
+}
+
+/// Subtract an SDBM expression from another SDBM expression.  Both expressions
+/// must not be difference expressions.
+SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
+inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
+  return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
+}
+inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
+  return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
+}
+
+/// Construct a stripe expression from a positive expression and a positive
+/// constant stripe factor.
+SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
+inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
+  return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
+}
+} // namespace ops_assertions
+
+} // end namespace mlir
+
+namespace llvm {
+// SDBMExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMExpr> {
+  static mlir::SDBMExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
+    return lhs == rhs;
+  }
+};
+
+// SDBMVaryingExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
+  static mlir::SDBMVaryingExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMVaryingExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMVaryingExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMVaryingExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
+    return lhs == rhs;
+  }
+};
+
+// SDBMPositiveExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMPositiveExpr> {
+  static mlir::SDBMPositiveExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMPositiveExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMPositiveExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMPositiveExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMPositiveExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) {
+    return lhs == rhs;
+  }
+};
+
+// SDBMConstantExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
+  static mlir::SDBMConstantExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMConstantExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMConstantExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMConstantExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
+    return lhs == rhs;
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_IR_SDBMEXPR_H
diff --git a/third_party/mlir/include/mlir/StandardOps/CMakeLists.txt b/third_party/mlir/include/mlir/StandardOps/CMakeLists.txt
new file mode 100644
index 0000000..670676f
--- /dev/null
+++ b/third_party/mlir/include/mlir/StandardOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRStandardOpsIncGen)
diff --git a/third_party/mlir/include/mlir/StandardOps/Ops.h b/third_party/mlir/include/mlir/StandardOps/Ops.h
new file mode 100644
index 0000000..fbd6462
--- /dev/null
+++ b/third_party/mlir/include/mlir/StandardOps/Ops.h
@@ -0,0 +1,363 @@
+//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines convenience types for working with standard operations
+// in the MLIR operation set.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_STANDARDOPS_OPS_H
+#define MLIR_STANDARDOPS_OPS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+class AffineMap;
+class Builder;
+class FuncOp;
+class OpBuilder;
+
+class StandardOpsDialect : public Dialect {
+public:
+  StandardOpsDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "std"; }
+};
+
+/// The predicate indicates the type of the comparison to perform:
+/// (in)equality; (un)signed less/greater than (or equal to).
+enum class CmpIPredicate {
+  FirstValidValue,
+  // (In)equality comparisons.
+  EQ = FirstValidValue,
+  NE,
+  // Signed comparisons.
+  SLT,
+  SLE,
+  SGT,
+  SGE,
+  // Unsigned comparisons.
+  ULT,
+  ULE,
+  UGT,
+  UGE,
+  // Number of predicates.
+  NumPredicates
+};
+
+/// The predicate indicates the type of the comparison to perform:
+/// (un)orderedness, (in)equality and less/greater than (or equal to) as
+/// well as predicates that are always true or false.
+enum class CmpFPredicate {
+  FirstValidValue,
+  // Always false
+  AlwaysFalse = FirstValidValue,
+  // Ordered comparisons
+  OEQ,
+  OGT,
+  OGE,
+  OLT,
+  OLE,
+  ONE,
+  // Both ordered
+  ORD,
+  // Unordered comparisons
+  UEQ,
+  UGT,
+  UGE,
+  ULT,
+  ULE,
+  UNE,
+  // Any unordered
+  UNO,
+  // Always true
+  AlwaysTrue,
+  // Number of predicates.
+  NumPredicates
+};
+
+#define GET_OP_CLASSES
+#include "mlir/StandardOps/Ops.h.inc"
+
+/// This is a refinement of the "constant" op for the case where it is
+/// returning a float value of FloatType.
+///
+///   %1 = "std.constant"(){value: 42.0} : bf16
+///
+class ConstantFloatOp : public ConstantOp {
+public:
+  using ConstantOp::ConstantOp;
+
+  /// Builds a constant float op producing a float of the specified type.
+  static void build(Builder *builder, OperationState *result,
+                    const APFloat &value, FloatType type);
+
+  APFloat getValue() { return getAttrOfType<FloatAttr>("value").getValue(); }
+
+  static bool classof(Operation *op);
+};
+
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value of IntegerType.
+///
+///   %1 = "std.constant"(){value: 42} : i32
+///
+class ConstantIntOp : public ConstantOp {
+public:
+  using ConstantOp::ConstantOp;
+  /// Build a constant int op producing an integer of the specified width.
+  static void build(Builder *builder, OperationState *result, int64_t value,
+                    unsigned width);
+
+  /// Build a constant int op producing an integer with the specified type,
+  /// which must be an integer type.
+  static void build(Builder *builder, OperationState *result, int64_t value,
+                    Type type);
+
+  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
+
+  static bool classof(Operation *op);
+};
+
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value of Index type.
+///
+///   %1 = "std.constant"(){value: 99} : () -> index
+///
+class ConstantIndexOp : public ConstantOp {
+public:
+  using ConstantOp::ConstantOp;
+
+  /// Build a constant int op producing an index.
+  static void build(Builder *builder, OperationState *result, int64_t value);
+
+  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
+
+  static bool classof(Operation *op);
+};
+
+// DmaStartOp starts a non-blocking DMA operation that transfers data from a
+// source memref to a destination memref. The source and destination memref need
+// not be of the same dimensionality, but need to have the same elemental type.
+// The operands include the source and destination memref's each followed by its
+// indices, size of the data transfer in terms of the number of elements (of the
+// elemental type of the memref), a tag memref with its indices, and optionally
+// at the end, a stride and a number_of_elements_per_stride arguments. The tag
+// location is used by a DmaWaitOp to check for completion. The indices of the
+// source memref, destination memref, and the tag memref have the same
+// restrictions as any load/store. The optional stride arguments should be of
+// 'index' type, and specify a stride for the slower memory space (memory space
+// with a lower memory space id), tranferring chunks of
+// number_of_elements_per_stride every stride until %num_elements are
+// transferred. Either both or no stride arguments should be specified.
+//
+// For example, a DmaStartOp operation that transfers 256 elements of a memref
+// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
+// 1 at indices [%k, %l], would be specified as follows:
+//
+//   %num_elements = constant 256
+//   %idx = constant 0 : index
+//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+//     memref<40 x 128 x f32>, (d0) -> (d0), 0>,
+//     memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
+//     memref<1 x i32>, (d0) -> (d0), 2>
+//
+//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
+//   transfer %num_elt_per_stride elements every %stride elements apart from
+//   memory space 0 until %num_elements are transferred.
+//
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+//             %num_elt_per_stride :
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels.
+// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
+class DmaStartOp
+    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
+                    ArrayRef<Value *> srcIndices, Value *destMemRef,
+                    ArrayRef<Value *> destIndices, Value *numElements,
+                    Value *tagMemRef, ArrayRef<Value *> tagIndices,
+                    Value *stride = nullptr,
+                    Value *elementsPerStride = nullptr);
+
+  // Returns the source MemRefType for this DMA operation.
+  Value *getSrcMemRef() { return getOperand(0); }
+  // Returns the rank (number of indices) of the source MemRefType.
+  unsigned getSrcMemRefRank() {
+    return getSrcMemRef()->getType().cast<MemRefType>().getRank();
+  }
+  // Returns the source memerf indices for this DMA operation.
+  operand_range getSrcIndices() {
+    return {getOperation()->operand_begin() + 1,
+            getOperation()->operand_begin() + 1 + getSrcMemRefRank()};
+  }
+
+  // Returns the destination MemRefType for this DMA operations.
+  Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+  // Returns the rank (number of indices) of the destination MemRefType.
+  unsigned getDstMemRefRank() {
+    return getDstMemRef()->getType().cast<MemRefType>().getRank();
+  }
+  unsigned getSrcMemorySpace() {
+    return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
+  }
+  unsigned getDstMemorySpace() {
+    return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
+  }
+
+  // Returns the destination memref indices for this DMA operation.
+  operand_range getDstIndices() {
+    return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1,
+            getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+                getDstMemRefRank()};
+  }
+
+  // Returns the number of elements being transferred by this DMA operation.
+  Value *getNumElements() {
+    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
+  }
+
+  // Returns the Tag MemRef for this DMA operation.
+  Value *getTagMemRef() {
+    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+  }
+  // Returns the rank (number of indices) of the tag MemRefType.
+  unsigned getTagMemRefRank() {
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
+  // Returns the tag memref index for this DMA operation.
+  operand_range getTagIndices() {
+    unsigned tagIndexStartPos =
+        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
+    return {getOperation()->operand_begin() + tagIndexStartPos,
+            getOperation()->operand_begin() + tagIndexStartPos +
+                getTagMemRefRank()};
+  }
+
+  /// Returns true if this is a DMA from a faster memory space to a slower one.
+  bool isDestMemorySpaceFaster() {
+    return (getSrcMemorySpace() < getDstMemorySpace());
+  }
+
+  /// Returns true if this is a DMA from a slower memory space to a faster one.
+  bool isSrcMemorySpaceFaster() {
+    // Assumes that a lower number is for a slower memory space.
+    return (getDstMemorySpace() < getSrcMemorySpace());
+  }
+
+  /// Given a DMA start operation, returns the operand position of either the
+  /// source or destination memref depending on the one that is at the higher
+  /// level of the memory hierarchy. Asserts failure if neither is true.
+  unsigned getFasterMemPos() {
+    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
+    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
+  }
+
+  static StringRef getOperationName() { return "std.dma_start"; }
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+
+  bool isStrided() {
+    return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
+                                   1 + 1 + getTagMemRefRank();
+  }
+
+  Value *getStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1 - 1);
+  }
+
+  Value *getNumElementsPerStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1);
+  }
+};
+
+// DmaWaitOp blocks until the completion of a DMA operation associated with the
+// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
+// with the same restrictions as any load/store index. %num_elements is the
+// number of elements associated with the DMA operation. For example:
+//
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
+//     memref<2048 x f32>, (d0) -> (d0), 0>,
+//     memref<256 x f32>, (d0) -> (d0), 1>
+//     memref<1 x i32>, (d0) -> (d0), 2>
+//   ...
+//   ...
+//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
+//
+class DmaWaitOp
+    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *tagMemRef,
+                    ArrayRef<Value *> tagIndices, Value *numElements);
+
+  static StringRef getOperationName() { return "std.dma_wait"; }
+
+  // Returns the Tag MemRef associated with the DMA operation being waited on.
+  Value *getTagMemRef() { return getOperand(0); }
+
+  // Returns the tag memref index for this DMA operation.
+  operand_range getTagIndices() {
+    return {getOperation()->operand_begin() + 1,
+            getOperation()->operand_begin() + 1 + getTagMemRefRank()};
+  }
+
+  // Returns the rank (number of indices) of the tag memref.
+  unsigned getTagMemRefRank() {
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
+  // Returns the number of elements transferred in the associated DMA operation.
+  Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
+
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+};
+
+/// Prints dimension and symbol list.
+void printDimAndSymbolList(Operation::operand_iterator begin,
+                           Operation::operand_iterator end, unsigned numDims,
+                           OpAsmPrinter *p);
+
+/// Parses dimension and symbol list and returns true if parsing failed.
+ParseResult parseDimAndSymbolList(OpAsmParser *parser,
+                                  SmallVector<Value *, 4> &operands,
+                                  unsigned &numDims);
+
+} // end namespace mlir
+
+#endif // MLIR_STANDARDOPS_OPS_H
diff --git a/third_party/mlir/include/mlir/StandardOps/Ops.td b/third_party/mlir/include/mlir/StandardOps/Ops.td
new file mode 100644
index 0000000..b6bf2cf
--- /dev/null
+++ b/third_party/mlir/include/mlir/StandardOps/Ops.td
@@ -0,0 +1,905 @@
+//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines some MLIR standard operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef STANDARD_OPS
+#else
+#define STANDARD_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Std_Dialect : Dialect {
+  let name = "std";
+  let cppNamespace = "";
+}
+
+// Base class for Standard dialect ops.
+class Std_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Std_Dialect, mnemonic, traits> {
+  // For every standard op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+// Base class for standard cast operations. Requires single operand and result,
+// but does not constrain them to specific types.
+class CastOp<string mnemonic, list<OpTrait> traits = []> :
+    Std_Op<mnemonic, !listconcat(traits, [NoSideEffect])> {
+
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *source, Type destType", [{
+       impl::buildCastOp(builder, result, source, destType);
+  }]>];
+
+  let parser = [{
+    return impl::parseCastOp(parser, result);
+  }];
+  let printer = [{
+    return printStandardCastOp(this->getOperation(), p);
+  }];
+  let verifier = [{ return ::verifyCastOp(*this); }];
+
+  let hasFolder = 1;
+}
+
+// Base class for standard arithmetic operations.  Requires operands and
+// results to be of the same type, but does not constrain them to specific
+// types.  Individual classes will have `lhs` and `rhs` accessor to operands.
+class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<Std_Dialect, mnemonic,
+       !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
+
+  let results = (outs AnyType);
+
+  let parser = [{
+    return impl::parseBinaryOp(parser, result);
+  }];
+
+  let printer = [{
+    return printStandardBinaryOp(this->getOperation(), p);
+  }];
+}
+
+// Base class for standard arithmetic operations on integers, vectors and
+// tensors thereof.  This operation takes two operands and returns one result,
+// each of these is required to be of the same type.  This type may be an
+// integer scalar type, a vector whose element type is an integer type, or an
+// integer tensor.  The custom assembly form of the operaton is as follows
+//
+//     <op>i %0, %1 : i32
+class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits>,
+    Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>;
+
+// Base class for standard arithmetic binary operations on floats, vectors and
+// tensors thereof.  This operation has two operands and returns one result,
+// each of these is required to be of the same type.  This type may be a
+// floating point scalar type, a vector whose element type is a floating point
+// type, or a floating point tensor.  The custom assembly form of the operation
+// is as follows
+//
+//     <op>f %0, %1 : f32
+class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
+
+def AddFOp : FloatArithmeticOp<"addf"> {
+  let summary = "floating point addition operation";
+  let hasFolder = 1;
+}
+
+def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
+  let summary = "integer addition operation";
+  let hasFolder = 1;
+}
+
+def AllocOp : Std_Op<"alloc"> {
+  let summary = "memory allocation operation";
+  let description = [{
+    The "alloc" operation allocates a region of memory, as specified by its
+    memref type. For example:
+
+      %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+
+    The optional list of dimension operands are bound to the dynamic dimensions
+    specified in its memref type. In the example below, the ssa value '%d' is
+    bound to the second dimension of the memref (which is dynamic).
+
+      %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1>
+
+    The optional list of symbol operands are bound to the symbols of the
+    memrefs affine map. In the example below, the ssa value '%s' is bound to
+    the symbol 's0' in the affine map specified in the allocs memref type.
+
+      %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+
+    This operation returns a single ssa value of memref type, which can be used
+    by subsequent load and store operations.
+  }];
+
+  let arguments = (ins Variadic<Index>:$value);
+  let results = (outs AnyMemRef);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, MemRefType memrefType", [{
+       result->types.push_back(memrefType);
+     }]
+  >];
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def AndOp : IntArithmeticOp<"and", [Commutative]> {
+  let summary = "integer binary and";
+  let hasFolder = 1;
+}
+
+def BranchOp : Std_Op<"br", [Terminator]> {
+  let summary = "branch operation";
+  let description = [{
+    The "br" operation represents a branch operation in a function.
+    The operation takes variable number of operands and produces no results.
+    The operand number and types for each successor must match the arguments of
+    the block successor. For example:
+
+      ^bb2:
+        %2 = call @someFn()
+        br ^bb3(%2 : tensor<*xf32>)
+      ^bb3(%3: tensor<*xf32>):
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Block *dest,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->addSuccessor(dest, operands);
+  }]>];
+
+  // BranchOp is fully verified by traits.
+  let verifier = ?;
+
+  let extraClassDeclaration = [{
+    Block *getDest();
+    void setDest(Block *block);
+
+    /// Erase the operand at 'index' from the operand list.
+    void eraseOperand(unsigned index);
+  }];
+}
+
+def CallOp : Std_Op<"call"> {
+  let summary = "call operation";
+  let description = [{
+    The "call" operation represents a direct call to a function.  The operands
+    and result types of the call must match the specified function type.  The
+    callee is encoded as a function attribute named "callee".
+
+      %2 = call @my_add(%0, %1) : (f32, f32) -> f32
+  }];
+
+  let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, FuncOp callee,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->addOperands(operands);
+      result->addAttribute("callee", builder->getSymbolRefAttr(callee));
+      result->addTypes(callee.getType().getResults());
+  }]>, OpBuilder<
+    "Builder *builder, OperationState *result, StringRef callee,"
+    "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+      result->addOperands(operands);
+      result->addAttribute("callee", builder->getSymbolRefAttr(callee));
+      result->addTypes(results);
+  }]>];
+
+  let extraClassDeclaration = [{
+    StringRef getCallee() { return callee(); }
+    FunctionType getCalleeType();
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+  }];
+}
+
+def CallIndirectOp : Std_Op<"call_indirect"> {
+  let summary = "indirect call operation";
+  let description = [{
+    The "call_indirect" operation represents an indirect call to a value of
+    function type.  Functions are first class types in MLIR, and may be passed
+    as arguments and merged together with block arguments.  The operands
+    and result types of the call must match the specified function type.
+
+      %3 = call_indirect %2(%0, %1) : (f32, f32) -> f32
+  }];
+
+  let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *callee,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->operands.push_back(callee);
+      result->addOperands(operands);
+      result->addTypes(callee->getType().cast<FunctionType>().getResults());
+  }]>];
+
+  let extraClassDeclaration = [{
+    Value *getCallee() { return getOperand(0); }
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return ++operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+  let summary = "integer comparison operation";
+  let description = [{
+    The "cmpi" operation compares its two operands according to the integer
+    comparison rules and the predicate specified by the respective attribute.
+    The predicate defines the type of comparison: (in)equality, (un)signed
+    less/greater than (or equal to).  The operands must have the same type, and
+    this type must be an integer type, a vector or a tensor thereof.  The result
+    is an i1, or a vector/tensor thereof having the same shape as the inputs.
+    Since integers are signless, the predicate also explicitly indicates
+    whether to interpret the operands as signed or unsigned integers for
+    less/greater than comparisons.  For the sake of readability by humans,
+    custom assembly form for the operation uses a string-typed attribute for
+    the predicate.  The value of this attribute corresponds to lower-cased name
+    of the predicate constant, e.g., "slt" means "signed less than".  The string
+    representation of the attribute is merely a syntactic sugar and is converted
+    to an integer attribute by the parser.
+
+      %r1 = cmpi "eq" %0, %1 : i32
+      %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
+      %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
+  }];
+
+  let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
+  let results = (outs BoolLike);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, CmpIPredicate predicate,"
+    "Value *lhs, Value *rhs", [{
+      ::buildCmpIOp(builder, result, predicate, lhs, rhs);
+  }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpIPredicate getPredicateByName(StringRef name);
+
+    CmpIPredicate getPredicate() {
+      return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
+          .getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+  let summary = "floating-point comparison operation";
+  let description = [{
+    The "cmpf" operation compares its two operands according to the float
+    comparison rules and the predicate specified by the respective attribute.
+    The predicate defines the type of comparison: (un)orderedness, (in)equality
+    and signed less/greater than (or equal to) as well as predicates that are
+    always true or false.  The operands must have the same type, and this type
+    must be a float type, or a vector or tensor thereof.  The result is an i1,
+    or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
+    the operands are always treated as signed. The u prefix indicates
+    *unordered* comparison, not unsigned comparison, so "une" means unordered or
+    not equal. For the sake of readability by humans, custom assembly form for
+    the operation uses a string-typed attribute for the predicate.  The value of
+    this attribute corresponds to lower-cased name of the predicate constant,
+    e.g., "one" means "ordered not equal".  The string representation of the
+    attribute is merely a syntactic sugar and is converted to an integer
+    attribute by the parser.
+
+      %r1 = cmpf "oeq" %0, %1 : f32
+      %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64>
+      %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
+  }];
+
+  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+  let results = (outs BoolLike);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, CmpFPredicate predicate,"
+    "Value *lhs, Value *rhs", [{
+      ::buildCmpFOp(builder, result, predicate, lhs, rhs);
+  }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpFPredicate getPredicateByName(StringRef name);
+
+    CmpFPredicate getPredicate() {
+      return (CmpFPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
+          .getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
+  let summary = "conditional branch operation";
+  let description = [{
+    The "cond_br" operation represents a conditional branch operation in a
+    function. The operation takes variable number of operands and produces
+    no results. The operand number and types for each successor must match the
+    arguments of the block successor. For example:
+
+      ^bb0:
+         %0 = extract_element %arg0[] : tensor<i1>
+         cond_br %0, ^bb1, ^bb2
+      ^bb1:
+         ...
+      ^bb2:
+         ...
+  }];
+
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *condition,"
+    "Block *trueDest, ArrayRef<Value *> trueOperands,"
+    "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+      result->addOperands(condition);
+      result->addSuccessor(trueDest, trueOperands);
+      result->addSuccessor(falseDest, falseOperands);
+  }]>];
+
+  // CondBranchOp is fully verified by traits.
+  let verifier = ?;
+
+  let extraClassDeclaration = [{
+    // These are the indices into the dests list.
+    enum { trueIndex = 0, falseIndex = 1 };
+
+    // The condition operand is the first operand in the list.
+    Value *getCondition() { return getOperand(0); }
+
+    /// Return the destination if the condition is true.
+    Block *getTrueDest() {
+      return getOperation()->getSuccessor(trueIndex);
+    }
+
+    /// Return the destination if the condition is false.
+    Block *getFalseDest() {
+      return getOperation()->getSuccessor(falseIndex);
+    }
+
+    // Accessors for operands to the 'true' destination.
+    Value *getTrueOperand(unsigned idx) {
+      assert(idx < getNumTrueOperands());
+      return getOperand(getTrueDestOperandIndex() + idx);
+    }
+
+    void setTrueOperand(unsigned idx, Value *value) {
+      assert(idx < getNumTrueOperands());
+      setOperand(getTrueDestOperandIndex() + idx, value);
+    }
+
+    operand_iterator true_operand_begin() {
+      return operand_begin() + getTrueDestOperandIndex();
+    }
+    operand_iterator true_operand_end() {
+      return true_operand_begin() + getNumTrueOperands();
+    }
+    operand_range getTrueOperands() {
+      return {true_operand_begin(), true_operand_end()};
+    }
+
+    unsigned getNumTrueOperands()  {
+      return getOperation()->getNumSuccessorOperands(trueIndex);
+    }
+
+    /// Erase the operand at 'index' from the true operand list.
+    void eraseTrueOperand(unsigned index)  {
+      getOperation()->eraseSuccessorOperand(trueIndex, index);
+    }
+
+    // Accessors for operands to the 'false' destination.
+    Value *getFalseOperand(unsigned idx) {
+      assert(idx < getNumFalseOperands());
+      return getOperand(getFalseDestOperandIndex() + idx);
+    }
+    void setFalseOperand(unsigned idx, Value *value) {
+      assert(idx < getNumFalseOperands());
+      setOperand(getFalseDestOperandIndex() + idx, value);
+    }
+
+    operand_iterator false_operand_begin() { return true_operand_end(); }
+    operand_iterator false_operand_end() {
+      return false_operand_begin() + getNumFalseOperands();
+    }
+    operand_range getFalseOperands() {
+      return {false_operand_begin(), false_operand_end()};
+    }
+
+    unsigned getNumFalseOperands() {
+      return getOperation()->getNumSuccessorOperands(falseIndex);
+    }
+
+    /// Erase the operand at 'index' from the false operand list.
+    void eraseFalseOperand(unsigned index) {
+      getOperation()->eraseSuccessorOperand(falseIndex, index);
+    }
+
+  private:
+    /// Get the index of the first true destination operand.
+    unsigned getTrueDestOperandIndex() { return 1; }
+
+    /// Get the index of the first false destination operand.
+    unsigned getFalseDestOperandIndex() {
+      return getTrueDestOperandIndex() + getNumTrueOperands();
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
+  let summary = "constant";
+
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Attribute value",
+    [{ build(builder, result, value.getType(), value); }]>];
+
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+
+    /// Returns true if a constant operation can be built with the given value
+    /// and result type.
+    static bool isBuildableWith(Attribute value, Type type);
+  }];
+
+  let hasFolder = 1;
+}
+
+def DeallocOp : Std_Op<"dealloc"> {
+  let summary = "memory deallocation operation";
+  let description = [{
+    The "dealloc" operation frees the region of memory referenced by a memref
+    which was originally created by the "alloc" operation.
+    The "dealloc" operation should not be called on memrefs which alias an
+    alloc'd memref (i.e. memrefs returned by the "view" and "reshape"
+    operations).
+
+      %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+      dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+  }];
+
+  let arguments = (ins AnyMemRef:$memref);
+
+  let hasCanonicalizer = 1;
+}
+
+def DimOp : Std_Op<"dim", [NoSideEffect]> {
+  let summary = "dimension index operation";
+  let description = [{
+    The "dim" operation takes a memref or tensor operand and returns an "index".
+    It requires a single integer attribute named "index". It returns the size
+    of the specified dimension. For example:
+
+      %1 = dim %0, 2 : tensor<?x?x?xf32>
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+                                 "any tensor or memref type">:$memrefOrTensor,
+                       APIntAttr:$index);
+  let results = (outs Index);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *memrefOrTensor,"
+    "unsigned index", [{
+      auto indexType = builder->getIndexType();
+      auto indexAttr = builder->getIntegerAttr(indexType, index);
+      build(builder, result, indexType, memrefOrTensor, indexAttr);
+    }]>];
+
+  let extraClassDeclaration = [{
+    unsigned getIndex() {
+      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+def DivFOp : FloatArithmeticOp<"divf"> {
+  let summary = "floating point division operation";
+}
+
+def DivISOp : IntArithmeticOp<"divis"> {
+  let summary = "signed integer division operation";
+  let hasFolder = 1;
+}
+
+def DivIUOp : IntArithmeticOp<"diviu"> {
+  let summary = "unsigned integer division operation";
+  let hasFolder = 1;
+}
+
+def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
+  let summary = "element extract operation";
+  let description = [{
+    The "extract_element" op reads a tensor or vector and returns one element
+    from it specified by an index list. The output of extract is a new value
+    with the same type as the elements of the tensor or vector. The arity of
+    indices matches the rank of the accessed value (i.e., if a tensor is of rank
+    3, then 3 indices are required for the extract).  The indices should all be
+    of affine_int type. For example:
+
+      %0 = extract_element %0[%1, %2] : vector<4x4xi32>
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *aggregate,"
+    "ArrayRef<Value *> indices = {}", [{
+      auto resType = aggregate->getType().cast<ShapedType>()
+                                         .getElementType();
+      build(builder, result, resType, aggregate, indices);
+    }]>];
+
+  let extraClassDeclaration = [{
+    Value *getAggregate() { return getOperand(0); }
+
+    operand_range getIndices() {
+      return {getOperation()->operand_begin() + 1,
+              getOperation()->operand_end()};
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
+  let summary = "cast between index and integer types";
+  let description = [{
+    Casts between integer scalars and 'index' scalars.  Index is an integer of
+    platform-specific bit width.  If casting to a wider integer, the value is
+    sign-extended.  If casting to a narrower integer, the value is truncated.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+  }];
+
+  let hasFolder = 0;
+}
+
+def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
+  let summary = "cast from integer type to floating-point";
+  let description = [{
+    Cast from a value interpreted as signed integer to the corresponding
+    floating-point value. If the value cannot be exactly represented, it is
+    rounded using the default rounding mode. Only scalars are currently
+    supported.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+  }];
+
+  let hasFolder = 0;
+}
+
+def LoadOp : Std_Op<"load"> {
+  let summary = "load operation";
+  let description = [{
+    The "load" op reads an element from a memref specified by an index list. The
+    output of load is a new value with the same type as the elements of the
+    memref. The arity of indices is the rank of the memref (i.e., if the memref
+    loaded from is of rank 3, then 3 indices are required for the load following
+    the memref identifier). For example:
+
+      %3 = load %0[%1, %1] : memref<4x4xi32>
+  }];
+
+  let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *memref,"
+    "ArrayRef<Value *> indices = {}", [{
+      auto memrefType = memref->getType().cast<MemRefType>();
+      result->addOperands(memref);
+      result->addOperands(indices);
+      result->types.push_back(memrefType.getElementType());
+  }]>];
+
+  let extraClassDeclaration = [{
+    Value *getMemRef() { return getOperand(0); }
+    void setMemRef(Value *value) { setOperand(0, value); }
+    MemRefType getMemRefType() {
+      return getMemRef()->getType().cast<MemRefType>();
+    }
+
+    operand_range getIndices() {
+      return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def MemRefCastOp : CastOp<"memref_cast"> {
+  let summary = "memref cast operation";
+  let description = [{
+    The "memref_cast" operation converts a memref from one type to an equivalent
+    type with a compatible shape. The source and destination types are
+    when both are memref types with the same element type, affine mappings,
+    address space, and rank but where the individual dimensions may add or
+    remove constant dimensions from the memref type.
+
+    If the cast converts any dimensions from an unknown to a known size, then it
+    acts as an assertion that fails at runtime of the dynamic dimensions
+    disagree with resultant destination size.
+
+    Assert that the input dynamic shape matches the destination static shape.
+       %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
+    Erase static shape information, replacing it with dynamic information.
+       %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
+  }];
+
+  let arguments = (ins AnyMemRef:$source);
+  let results = (outs AnyMemRef);
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a memref_cast is always a memref.
+    MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
+  }];
+}
+
+def MulFOp : FloatArithmeticOp<"mulf"> {
+  let summary = "foating point multiplication operation";
+  let hasFolder = 1;
+}
+
+def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
+  let summary = "integer multiplication operation";
+  let hasFolder = 1;
+}
+
+def OrOp : IntArithmeticOp<"or", [Commutative]> {
+  let summary = "integer binary or";
+  let hasFolder = 1;
+}
+
+def RankOp : Std_Op<"rank", [NoSideEffect]> {
+  let summary = "rank operation";
+  let description = [{
+    The "rank" operation takes a tensor operand and returns its rank.
+
+      %1 = rank %0 : index
+  }];
+
+  let arguments = (ins AnyTensor);
+  let results = (outs Index);
+  let verifier = ?;
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *tensor", [{
+      auto indexType = builder->getIndexType();
+      build(builder, result, indexType, tensor);
+    }]>];
+
+  let hasFolder = 1;
+}
+
+def RemFOp : FloatArithmeticOp<"remf"> {
+  let summary = "floating point division remainder operation";
+}
+
+def RemISOp : IntArithmeticOp<"remis"> {
+  let summary = "signed integer division remainder operation";
+  let hasFolder = 1;
+}
+
+def RemIUOp : IntArithmeticOp<"remiu"> {
+  let summary = "unsigned integer division remainder operation";
+  let hasFolder = 1;
+}
+
+def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+  let summary = "return operation";
+  let description = [{
+    The "return" operation represents a return operation within a function.
+    The operation takes variable number of operands and produces no results.
+    The operand number and types must match the signature of the function
+    that contains the operation. For example:
+
+      func @foo() : (i32, f8) {
+      ...
+      return %0, %1 : i32, f8
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result", [{ build(b, result, llvm::None); }]
+  >];
+}
+
+def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
+  let summary = "select operation";
+  let description = [{
+    The "select" operation chooses one value based on a binary condition
+    supplied as its first operand. If the value of the first operand is 1, the
+    second operand is chosen, otherwise the third operand is chosen. The second
+    and the third operand must have the same type. The operation applies
+    elementwise to vectors and tensors.  The shape of all arguments must be
+    identical. For example, the maximum operation is obtained by combining
+    "select" with "cmpi" as follows.
+
+      %2 = cmpi "gt" %0, %1 : i32         // %2 is i1
+      %3 = select %2, %0, %1 : i32
+  }];
+
+  let arguments = (ins BoolLike:$condition, AnyType:$true_value,
+                       AnyType:$false_value);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *condition,"
+    "Value *trueValue, Value *falseValue", [{
+      result->addOperands({condition, trueValue, falseValue});
+      result->addTypes(trueValue->getType());
+  }]>];
+
+  let extraClassDeclaration = [{
+      Value *getCondition() { return condition(); }
+      Value *getTrueValue() { return true_value(); }
+      Value *getFalseValue() { return false_value(); }
+  }];
+
+  let hasFolder = 1;
+}
+def ShlISOp : IntArithmeticOp<"shlis"> {
+  let summary = "signed integer shift left";
+}
+
+def SubFOp : FloatArithmeticOp<"subf"> {
+  let summary = "floating point subtraction operation";
+  let hasFolder = 1;
+}
+
+def SubIOp : IntArithmeticOp<"subi"> {
+  let summary = "integer subtraction operation";
+  let hasFolder = 1;
+}
+
+def StoreOp : Std_Op<"store"> {
+  let summary = "store operation";
+  let description = [{
+    The "store" op writes an element to a memref specified by an index list.
+    The arity of indices is the rank of the memref (i.e. if the memref being
+    stored to is of rank 3, then 3 indices are required for the store following
+    the memref identifier). The store operation does not produce a result.
+
+    In the following example, the ssa value '%v' is stored in memref '%A' at
+    indices [%i, %j]:
+      store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
+  }];
+
+  let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic<Index>:$indices);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{
+      result->addOperands(valueToStore);
+      result->addOperands(memref);
+  }]>];
+
+  let extraClassDeclaration = [{
+      Value *getValueToStore() { return getOperand(0); }
+
+      Value *getMemRef() { return getOperand(1); }
+      void setMemRef(Value *value) { setOperand(1, value); }
+      MemRefType getMemRefType() {
+        return getMemRef()->getType().cast<MemRefType>();
+      }
+
+      operand_range getIndices() {
+        return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
+      }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def TensorCastOp : CastOp<"tensor_cast"> {
+  let summary = "tensor cast operation";
+  let description = [{
+    The "tensor_cast" operation converts a tensor from one type to an equivalent
+    type without changing any data elements.  The source and destination types
+    must both be tensor types with the same element type.  If both are ranked
+    then the rank should be the same and static dimensions should match.  The
+    operation is invalid if converting to a mismatching constant dimension.
+
+    Convert from unknown rank to rank 2 with unknown dimension sizes.
+       %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
+  }];
+
+  let arguments = (ins AnyTensor);
+  let results = (outs AnyTensor);
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a tensor_cast is always a tensor.
+    TensorType getType() { return getResult()->getType().cast<TensorType>(); }
+  }];
+}
+
+def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
+  let summary = "integer binary xor";
+  let hasFolder = 1;
+}
+
+#endif // STANDARD_OPS
diff --git a/third_party/mlir/include/mlir/Support/DebugStringHelper.h b/third_party/mlir/include/mlir/Support/DebugStringHelper.h
new file mode 100644
index 0000000..230ed23
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/DebugStringHelper.h
@@ -0,0 +1,51 @@
+//===- DebugStringHelper.h - helpers to generate debug strings --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Convenience functions to make it easier to get a string representation for
+// ops that have a print method. For use in debugging output and errors
+// returned.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEBUGSTRINGHELPER_H_
+#define MLIR_DEBUGSTRINGHELPER_H_
+
+#include <string>
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_os_ostream.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+// Simple helper function that returns a string as printed from a op.
+template <typename T> static std::string debugString(T &op) {
+  std::string instr_str;
+  llvm::raw_string_ostream os(instr_str);
+  op.print(os);
+  return os.str();
+}
+
+} // namespace mlir
+
+inline std::ostream &operator<<(std::ostream &out, const llvm::Twine &twine) {
+  llvm::raw_os_ostream rout(out);
+  rout << twine;
+  return out;
+}
+
+#endif // MLIR_DEBUGSTRINGHELPER_H_
diff --git a/third_party/mlir/include/mlir/Support/FileUtilities.h b/third_party/mlir/include/mlir/Support/FileUtilities.h
new file mode 100644
index 0000000..5ce9722
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/FileUtilities.h
@@ -0,0 +1,50 @@
+//===- FileUtilities.h - utilities for working with files -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Common utilities for working with files.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_FILEUTILITIES_H_
+#define MLIR_SUPPORT_FILEUTILITIES_H_
+
+#include <memory>
+#include <string>
+
+namespace llvm {
+class MemoryBuffer;
+class ToolOutputFile;
+class StringRef;
+} // namespace llvm
+
+namespace mlir {
+
+/// Open the file specified by its name for reading. Write the error message to
+/// `errorMessage` if errors occur and `errorMessage` is not nullptr.
+std::unique_ptr<llvm::MemoryBuffer>
+openInputFile(llvm::StringRef inputFilename,
+              std::string *errorMessage = nullptr);
+
+/// Open the file specified by its name for writing. Write the error message to
+/// `errorMessage` if errors occur and `errorMessage` is not nullptr.
+std::unique_ptr<llvm::ToolOutputFile>
+openOutputFile(llvm::StringRef outputFilename,
+               std::string *errorMessage = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_FILEUTILITIES_H_
diff --git a/third_party/mlir/include/mlir/Support/Functional.h b/third_party/mlir/include/mlir/Support/Functional.h
new file mode 100644
index 0000000..edc5e1d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/Functional.h
@@ -0,0 +1,122 @@
+//===- Functional.h - Helpers for functional-style Combinators --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SUPPORT_FUNCTIONAL_H_
+#define MLIR_SUPPORT_FUNCTIONAL_H_
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+/// This file provides some simple template functional-style sugar to operate
+/// on **value** types. Make sure when using that the stored type is cheap to
+/// copy!
+///
+/// TODO(ntv): add some static_assert but we need proper traits for this.
+
+namespace mlir {
+namespace functional {
+
+/// Map with iterators.
+template <typename Fn, typename IterType>
+auto map(Fn fun, IterType begin, IterType end)
+    -> llvm::SmallVector<typename std::result_of<Fn(decltype(*begin))>::type,
+                         8> {
+  using R = typename std::result_of<Fn(decltype(*begin))>::type;
+  llvm::SmallVector<R, 8> res;
+  // auto i works with both pointer types and value types with an operator*.
+  // auto *i only works for pointer types.
+  for (auto i = begin; i != end; ++i) {
+    res.push_back(fun(*i));
+  }
+  return res;
+}
+
+/// Map with templated container.
+template <typename Fn, typename ContainerType>
+auto map(Fn fun, ContainerType input)
+    -> decltype(map(fun, std::begin(input), std::end(input))) {
+  return map(fun, std::begin(input), std::end(input));
+}
+
+/// Zip map with 2 templated container, iterates to the min of the sizes of
+/// the 2 containers.
+/// TODO(ntv): make variadic when needed.
+template <typename Fn, typename ContainerType1, typename ContainerType2>
+auto zipMap(Fn fun, ContainerType1 input1, ContainerType2 input2)
+    -> llvm::SmallVector<
+        typename std::result_of<Fn(decltype(*input1.begin()),
+                                   decltype(*input2.begin()))>::type,
+        8> {
+  using R = typename std::result_of<Fn(decltype(*input1.begin()),
+                                       decltype(*input2.begin()))>::type;
+  llvm::SmallVector<R, 8> res;
+  auto zipIter = llvm::zip(input1, input2);
+  for (auto it : zipIter) {
+    res.push_back(fun(std::get<0>(it), std::get<1>(it)));
+  }
+  return res;
+}
+
+/// Apply with iterators.
+template <typename Fn, typename IterType>
+void apply(Fn fun, IterType begin, IterType end) {
+  // auto i works with both pointer types and value types with an operator*.
+  // auto *i only works for pointer types.
+  for (auto i = begin; i != end; ++i) {
+    fun(*i);
+  }
+}
+
+/// Apply with templated container.
+template <typename Fn, typename ContainerType>
+void apply(Fn fun, ContainerType input) {
+  return apply(fun, std::begin(input), std::end(input));
+}
+
+/// Zip apply with 2 templated container, iterates to the min of the sizes of
+/// the 2 containers.
+/// TODO(ntv): make variadic when needed.
+template <typename Fn, typename ContainerType1, typename ContainerType2>
+void zipApply(Fn fun, ContainerType1 input1, ContainerType2 input2) {
+  auto zipIter = llvm::zip(input1, input2);
+  for (auto it : zipIter) {
+    fun(std::get<0>(it), std::get<1>(it));
+  }
+}
+
+/// Unwraps a pointer type to another type (possibly the same).
+/// Used in particular to allow easier compositions of
+///   Operation::operand_range types.
+template <typename T, typename ToType = T>
+inline std::function<ToType *(T *)> makePtrDynCaster() {
+  return [](T *val) { return llvm::dyn_cast<ToType>(val); };
+}
+
+/// Simple ScopeGuard.
+struct ScopeGuard {
+  explicit ScopeGuard(std::function<void(void)> destruct)
+      : destruct(destruct) {}
+  ~ScopeGuard() { destruct(); }
+
+private:
+  std::function<void(void)> destruct;
+};
+
+} // namespace functional
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_FUNCTIONAL_H_
diff --git a/third_party/mlir/include/mlir/Support/JitRunner.h b/third_party/mlir/include/mlir/Support/JitRunner.h
new file mode 100644
index 0000000..14b66a8
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/JitRunner.h
@@ -0,0 +1,47 @@
+//===- JitRunner.h - MLIR CPU Execution Driver Library ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a library that provides a shared implementation for command line
+// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
+// IR before JIT-compiling and executing the latter.
+//
+// The translation can be customized by providing an MLIR to MLIR
+// transformation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_JITRUNNER_H_
+#define MLIR_SUPPORT_JITRUNNER_H_
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+
+class ModuleOp;
+struct LogicalResult;
+
+// Entry point for all CPU runners. Expects the common argc/argv arguments for
+// standard C++ main functions and an mlirTransformer.
+// The latter is applied after parsing the input into MLIR IR and before passing
+// the MLIR module to the ExecutionEngine.
+int JitRunnerMain(
+    int argc, char **argv,
+    llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer);
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_JITRUNNER_H_
diff --git a/third_party/mlir/include/mlir/Support/LLVM.h b/third_party/mlir/include/mlir/Support/LLVM.h
new file mode 100644
index 0000000..f0dd121
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/LLVM.h
@@ -0,0 +1,103 @@
+//===- LLVM.h - Import and forward declare core LLVM types ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file forward declares and imports various common LLVM datatypes that
+// MLIR wants to use unqualified.
+//
+// Note that most of these are forward declared and then imported into the MLIR
+// namespace with using decls, rather than being #included.  This is because we
+// want clients to explicitly #include the files they need.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_LLVM_H
+#define MLIR_SUPPORT_LLVM_H
+
+// We include these two headers because they cannot be practically forward
+// declared, and are effectively language features.
+#include "llvm/ADT/None.h"
+#include "llvm/Support/Casting.h"
+
+// Forward declarations.
+namespace llvm {
+// Containers.
+class StringRef;
+class StringLiteral;
+class Twine;
+template <typename T> class SmallPtrSetImpl;
+template <typename T, unsigned N> class SmallPtrSet;
+template <typename T> class SmallVectorImpl;
+template <typename T, unsigned N> class SmallVector;
+template <unsigned N> class SmallString;
+template <typename T> class ArrayRef;
+template <typename T> class MutableArrayRef;
+template <typename T> class TinyPtrVector;
+template <typename T> class Optional;
+template <typename... PT> class PointerUnion;
+namespace detail {
+template <typename KeyT, typename ValueT> struct DenseMapPair;
+}
+template <typename T> struct DenseMapInfo;
+template <typename ValueT, typename ValueInfoT> class DenseSet;
+template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT>
+class DenseMap;
+
+// Other common classes.
+class raw_ostream;
+class APInt;
+class APFloat;
+} // end namespace llvm
+
+namespace mlir {
+// Casting operators.
+using llvm::cast;
+using llvm::cast_or_null;
+using llvm::dyn_cast;
+using llvm::dyn_cast_or_null;
+using llvm::isa;
+using llvm::isa_and_nonnull;
+
+// Containers.
+using llvm::ArrayRef;
+using llvm::DenseMapInfo;
+template <typename KeyT, typename ValueT,
+          typename KeyInfoT = DenseMapInfo<KeyT>,
+          typename BucketT = llvm::detail::DenseMapPair<KeyT, ValueT>>
+using DenseMap = llvm::DenseMap<KeyT, ValueT, KeyInfoT, BucketT>;
+template <typename ValueT, typename ValueInfoT = DenseMapInfo<ValueT>>
+using DenseSet = llvm::DenseSet<ValueT, ValueInfoT>;
+using llvm::MutableArrayRef;
+using llvm::None;
+using llvm::Optional;
+using llvm::PointerUnion;
+using llvm::SmallPtrSet;
+using llvm::SmallPtrSetImpl;
+using llvm::SmallString;
+using llvm::SmallVector;
+using llvm::SmallVectorImpl;
+using llvm::StringLiteral;
+using llvm::StringRef;
+using llvm::TinyPtrVector;
+using llvm::Twine;
+
+// Other common classes.
+using llvm::APFloat;
+using llvm::APInt;
+using llvm::raw_ostream;
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_LLVM_H
diff --git a/third_party/mlir/include/mlir/Support/LogicalResult.h b/third_party/mlir/include/mlir/Support/LogicalResult.h
new file mode 100644
index 0000000..a9fc77c
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/LogicalResult.h
@@ -0,0 +1,60 @@
+//===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SUPPORT_LOGICAL_RESULT_H
+#define MLIR_SUPPORT_LOGICAL_RESULT_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+// Values that can be used to signal success/failure. This should be used in
+// conjunction with the utility functions below.
+struct LogicalResult {
+  enum ResultEnum { Success, Failure } value;
+  LogicalResult(ResultEnum v) : value(v) {}
+};
+
+/// Utility function to generate a LogicalResult. If isSuccess is true a
+/// `success` result is generated, otherwise a 'failure' result is generated.
+inline LogicalResult success(bool isSuccess = true) {
+  return LogicalResult{isSuccess ? LogicalResult::Success
+                                 : LogicalResult::Failure};
+}
+
+/// Utility function to generate a LogicalResult. If isFailure is true a
+/// `failure` result is generated, otherwise a 'success' result is generated.
+inline LogicalResult failure(bool isFailure = true) {
+  return LogicalResult{isFailure ? LogicalResult::Failure
+                                 : LogicalResult::Success};
+}
+
+/// Utility function that returns true if the provided LogicalResult corresponds
+/// to a success value.
+inline bool succeeded(LogicalResult result) {
+  return result.value == LogicalResult::Success;
+}
+
+/// Utility function that returns true if the provided LogicalResult corresponds
+/// to a failure value.
+inline bool failed(LogicalResult result) {
+  return result.value == LogicalResult::Failure;
+}
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_LOGICAL_RESULT_H
diff --git a/third_party/mlir/include/mlir/Support/MathExtras.h b/third_party/mlir/include/mlir/Support/MathExtras.h
new file mode 100644
index 0000000..767677f
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/MathExtras.h
@@ -0,0 +1,65 @@
+//===- MathExtras.h - Math functions relevant to MLIR -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains math functions relevant to MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_MATHEXTRAS_H_
+#define MLIR_SUPPORT_MATHEXTRAS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/APInt.h"
+
+namespace mlir {
+
+/// Returns the result of MLIR's ceildiv operation on constants. The RHS is
+/// expected to be positive.
+inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
+  assert(rhs >= 1);
+  // C/C++'s integer division rounds towards 0.
+  return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs;
+}
+
+/// Returns the result of MLIR's floordiv operation on constants. The RHS is
+/// expected to be positive.
+inline int64_t floorDiv(int64_t lhs, int64_t rhs) {
+  assert(rhs >= 1);
+  // C/C++'s integer division rounds towards 0.
+  return lhs % rhs < 0 ? lhs / rhs - 1 : lhs / rhs;
+}
+
+/// Returns MLIR's mod operation on constants. MLIR's mod operation yields the
+/// remainder of the Euclidean division of 'lhs' by 'rhs', and is therefore not
+/// C's % operator.  The RHS is always expected to be positive, and the result
+/// is always non-negative.
+inline int64_t mod(int64_t lhs, int64_t rhs) {
+  assert(rhs >= 1);
+  return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
+}
+
+/// Returns the least common multiple of 'a' and 'b'.
+inline int64_t lcm(int64_t a, int64_t b) {
+  uint64_t x = std::abs(a);
+  uint64_t y = std::abs(b);
+  int64_t lcm = (x * y) / llvm::GreatestCommonDivisor64(x, y);
+  assert((lcm >= a && lcm >= b) && "LCM overflow");
+  return lcm;
+}
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_MATHEXTRAS_H_
diff --git a/third_party/mlir/include/mlir/Support/MlirOptMain.h b/third_party/mlir/include/mlir/Support/MlirOptMain.h
new file mode 100644
index 0000000..00a1e48
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/MlirOptMain.h
@@ -0,0 +1,38 @@
+//===- MlirOptMain.h - MLIR Optimizer Driver main ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Main entry function for mlir-opt for when built as standalone binary.
+//
+//===----------------------------------------------------------------------===//
+
+#include <memory>
+#include <vector>
+
+namespace llvm {
+class raw_ostream;
+class MemoryBuffer;
+} // end namespace llvm
+namespace mlir {
+struct LogicalResult;
+class PassRegistryEntry;
+
+LogicalResult
+MlirOptMain(llvm::raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> buffer,
+            const std::vector<const PassRegistryEntry *> &passList,
+            bool splitInputFile, bool verifyDiagnostics, bool verifyPasses);
+
+} // end namespace mlir
diff --git a/third_party/mlir/include/mlir/Support/STLExtras.h b/third_party/mlir/include/mlir/Support/STLExtras.h
new file mode 100644
index 0000000..3448b08
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/STLExtras.h
@@ -0,0 +1,239 @@
+//===- STLExtras.h - STL-like extensions that are used by MLIR --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains stuff that should be arguably sunk down to the LLVM
+// Support/STLExtras.h file over time.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STLEXTRAS_H
+#define MLIR_SUPPORT_STLEXTRAS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/iterator.h"
+#include <tuple>
+
+namespace mlir {
+
+namespace detail {
+template <typename RangeT>
+using ValueOfRange = typename std::remove_reference<decltype(
+    *std::begin(std::declval<RangeT &>()))>::type;
+} // end namespace detail
+
+/// An STL-style algorithm similar to std::for_each that applies a second
+/// functor between every pair of elements.
+///
+/// This provides the control flow logic to, for example, print a
+/// comma-separated list:
+/// \code
+///   interleave(names.begin(), names.end(),
+///              [&](StringRef name) { os << name; },
+///              [&] { os << ", "; });
+/// \endcode
+template <typename ForwardIterator, typename UnaryFunctor,
+          typename NullaryFunctor>
+inline void interleave(ForwardIterator begin, ForwardIterator end,
+                       UnaryFunctor each_fn, NullaryFunctor between_fn) {
+  if (begin == end)
+    return;
+  each_fn(*begin);
+  ++begin;
+  for (; begin != end; ++begin) {
+    between_fn();
+    each_fn(*begin);
+  }
+}
+
+template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
+inline void interleave(const Container &c, UnaryFunctor each_fn,
+                       NullaryFunctor between_fn) {
+  interleave(c.begin(), c.end(), each_fn, between_fn);
+}
+
+template <typename Container, typename UnaryFunctor, typename raw_ostream,
+          typename T = detail::ValueOfRange<Container>>
+inline void interleaveComma(const Container &c, raw_ostream &os,
+                            UnaryFunctor each_fn) {
+  interleave(c.begin(), c.end(), each_fn, [&] { os << ", "; });
+}
+template <typename Container, typename raw_ostream,
+          typename T = detail::ValueOfRange<Container>>
+inline void interleaveComma(const Container &c, raw_ostream &os) {
+  interleaveComma(c, os, [&](const T &a) { os << a; });
+}
+
+/// A special type used to provide an address for a given class that can act as
+/// a unique identifier during pass registration.
+/// Note: We specify an explicit alignment here to allow use with PointerIntPair
+/// and other utilities/data structures that require a known pointer alignment.
+struct alignas(8) ClassID {
+  template <typename T> static ClassID *getID() {
+    static ClassID id;
+    return &id;
+  }
+  template <template <typename T> class Trait> static ClassID *getID() {
+    static ClassID id;
+    return &id;
+  }
+};
+
+/// Utilities for detecting if a given trait holds for some set of arguments
+/// 'Args'. For example, the given trait could be used to detect if a given type
+/// has a copy assignment operator:
+///   template<class T>
+///   using has_copy_assign_t = decltype(std::declval<T&>()
+///                                                 = std::declval<const T&>());
+///   bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
+namespace detail {
+template <typename...> using void_t = void;
+template <class, template <class...> class Op, class... Args> struct detector {
+  using value_t = std::false_type;
+};
+template <template <class...> class Op, class... Args>
+struct detector<void_t<Op<Args...>>, Op, Args...> {
+  using value_t = std::true_type;
+};
+} // end namespace detail
+
+template <template <class...> class Op, class... Args>
+using is_detected = typename detail::detector<void, Op, Args...>::value_t;
+
+/// Check if a Callable type can be invoked with the given set of arg types.
+namespace detail {
+template <typename Callable, typename... Args>
+using is_invocable =
+    decltype(std::declval<Callable &>()(std::declval<Args>()...));
+} // namespace detail
+
+template <typename Callable, typename... Args>
+using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
+
+//===----------------------------------------------------------------------===//
+//     Extra additions to <iterator>
+//===----------------------------------------------------------------------===//
+
+/// A utility class used to implement an iterator that contains some object and
+/// an index. The iterator moves the index but keeps the object constant.
+template <typename DerivedT, typename ObjectType, typename T,
+          typename PointerT = T *, typename ReferenceT = T &>
+class indexed_accessor_iterator
+    : public llvm::iterator_facade_base<DerivedT,
+                                        std::random_access_iterator_tag, T,
+                                        std::ptrdiff_t, PointerT, ReferenceT> {
+public:
+  ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
+    assert(object == rhs.object && "incompatible iterators");
+    return index - rhs.index;
+  }
+  bool operator==(const indexed_accessor_iterator &rhs) const {
+    return object == rhs.object && index == rhs.index;
+  }
+  bool operator<(const indexed_accessor_iterator &rhs) const {
+    assert(object == rhs.object && "incompatible iterators");
+    return index < rhs.index;
+  }
+
+  DerivedT &operator+=(ptrdiff_t offset) {
+    this->index += offset;
+    return static_cast<DerivedT &>(*this);
+  }
+  DerivedT &operator-=(ptrdiff_t offset) {
+    this->index -= offset;
+    return static_cast<DerivedT &>(*this);
+  }
+
+protected:
+  indexed_accessor_iterator(ObjectType object, ptrdiff_t index)
+      : object(object), index(index) {}
+  ObjectType object;
+  ptrdiff_t index;
+};
+
+} // end namespace mlir
+
+// Allow tuples to be usable as DenseMap keys.
+// TODO: Move this to upstream LLVM.
+
+/// Simplistic combination of 32-bit hash values into 32-bit hash values.
+/// This function is taken from llvm/ADT/DenseMapInfo.h.
+static inline unsigned llvm_combineHashValue(unsigned a, unsigned b) {
+  uint64_t key = (uint64_t)a << 32 | (uint64_t)b;
+  key += ~(key << 32);
+  key ^= (key >> 22);
+  key += ~(key << 13);
+  key ^= (key >> 8);
+  key += (key << 3);
+  key ^= (key >> 15);
+  key += ~(key << 27);
+  key ^= (key >> 31);
+  return (unsigned)key;
+}
+
+namespace llvm {
+template <typename... Ts> struct DenseMapInfo<std::tuple<Ts...>> {
+  using Tuple = std::tuple<Ts...>;
+
+  static inline Tuple getEmptyKey() {
+    return Tuple(DenseMapInfo<Ts>::getEmptyKey()...);
+  }
+
+  static inline Tuple getTombstoneKey() {
+    return Tuple(DenseMapInfo<Ts>::getTombstoneKey()...);
+  }
+
+  template <unsigned I>
+  static unsigned getHashValueImpl(const Tuple &values, std::false_type) {
+    using EltType = typename std::tuple_element<I, Tuple>::type;
+    std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
+    return llvm_combineHashValue(
+        DenseMapInfo<EltType>::getHashValue(std::get<I>(values)),
+        getHashValueImpl<I + 1>(values, atEnd));
+  }
+
+  template <unsigned I>
+  static unsigned getHashValueImpl(const Tuple &values, std::true_type) {
+    return 0;
+  }
+
+  static unsigned getHashValue(const std::tuple<Ts...> &values) {
+    std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
+    return getHashValueImpl<0>(values, atEnd);
+  }
+
+  template <unsigned I>
+  static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type) {
+    using EltType = typename std::tuple_element<I, Tuple>::type;
+    std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
+    return DenseMapInfo<EltType>::isEqual(std::get<I>(lhs), std::get<I>(rhs)) &&
+           isEqualImpl<I + 1>(lhs, rhs, atEnd);
+  }
+
+  template <unsigned I>
+  static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type) {
+    return true;
+  }
+
+  static bool isEqual(const Tuple &lhs, const Tuple &rhs) {
+    std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
+    return isEqualImpl<0>(lhs, rhs, atEnd);
+  }
+};
+
+} // end namespace llvm
+
+#endif // MLIR_SUPPORT_STLEXTRAS_H
diff --git a/third_party/mlir/include/mlir/Support/StorageUniquer.h b/third_party/mlir/include/mlir/Support/StorageUniquer.h
new file mode 100644
index 0000000..1873df1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/StorageUniquer.h
@@ -0,0 +1,270 @@
+//===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
+#define MLIR_SUPPORT_STORAGEUNIQUER_H
+
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace mlir {
+namespace detail {
+struct StorageUniquerImpl;
+
+/// Trait to check if ImplTy provides a 'getKey' method with types 'Args'.
+template <typename ImplTy, typename... Args>
+using has_impltype_getkey_t = decltype(ImplTy::getKey(std::declval<Args>()...));
+
+/// Trait to check if ImplTy provides a 'hashKey' method for 'T'.
+template <typename ImplTy, typename T>
+using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
+} // namespace detail
+
+/// A utility class to get, or create instances of storage classes. These
+/// storage classes must respect the following constraints:
+///    - Derive from StorageUniquer::BaseStorage.
+///    - Provide an unsigned 'kind' value to be used as part of the unique'ing
+///      process.
+///
+/// For non-parametric storage classes, i.e. those that are solely uniqued by
+/// their kind, nothing else is needed. Instances of these classes can be
+/// created by calling `get` without trailing arguments.
+///
+/// Otherwise, the parametric storage classes may be created with `get`,
+/// and must respect the following:
+///    - Define a type alias, KeyTy, to a type that uniquely identifies the
+///      instance of the storage class within its kind.
+///      * The key type must be constructible from the values passed into the
+///        getComplex call after the kind.
+///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+///        storage class must define a hashing method:
+///         'static unsigned hashKey(const KeyTy &)'
+///
+///    - Provide a method, 'bool operator==(const KeyTy &) const', to
+///      compare the storage instance against an instance of the key type.
+///
+///    - Provide a static construction method:
+///        'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
+///      that builds a unique instance of the derived storage. The arguments to
+///      this function are an allocator to store any uniqued data and the key
+///      type for this storage.
+///
+///    - Provide a cleanup method:
+///        'void cleanup()'
+///      that is called when erasing a storage instance. This should cleanup any
+///      fields of the storage as necessary and not attempt to free the memory
+///      of the storage itself.
+class StorageUniquer {
+public:
+  StorageUniquer();
+  ~StorageUniquer();
+
+  /// This class acts as the base storage that all storage classes must derived
+  /// from.
+  class BaseStorage {
+  public:
+    /// Get the kind classification of this storage.
+    unsigned getKind() const { return kind; }
+
+  protected:
+    BaseStorage() : kind(0) {}
+
+  private:
+    /// Allow access to the kind field.
+    friend detail::StorageUniquerImpl;
+
+    /// Classification of the subclass, used for type checking.
+    unsigned kind;
+  };
+
+  /// This is a utility allocator used to allocate memory for instances of
+  /// derived types.
+  class StorageAllocator {
+  public:
+    /// Copy the specified array of elements into memory managed by our bump
+    /// pointer allocator.  This assumes the elements are all PODs.
+    template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+      if (elements.empty())
+        return llvm::None;
+      auto result = allocator.Allocate<T>(elements.size());
+      std::uninitialized_copy(elements.begin(), elements.end(), result);
+      return ArrayRef<T>(result, elements.size());
+    }
+
+    /// Copy the provided string into memory managed by our bump pointer
+    /// allocator.
+    StringRef copyInto(StringRef str) {
+      auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
+      return StringRef(result.data(), str.size());
+    }
+
+    /// Allocate an instance of the provided type.
+    template <typename T> T *allocate() { return allocator.Allocate<T>(); }
+
+    /// Allocate 'size' bytes of 'alignment' aligned memory.
+    void *allocate(size_t size, size_t alignment) {
+      return allocator.Allocate(size, alignment);
+    }
+
+  private:
+    /// The raw allocator for type storage objects.
+    llvm::BumpPtrAllocator allocator;
+  };
+
+  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+  /// that can be used to initialize a newly inserted storage instance. This
+  /// function is used for derived types that have complex storage or uniquing
+  /// constraints.
+  template <typename Storage, typename Arg, typename... Args>
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind, Arg &&arg,
+               Args &&... args) {
+    // Construct a value of the derived key type.
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
+
+    // Create a hash of the kind and the derived key.
+    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+    // Generate an equality function for the derived storage.
+    std::function<bool(const BaseStorage *)> isEqual =
+        [&derivedKey](const BaseStorage *existing) {
+          return static_cast<const Storage &>(*existing) == derivedKey;
+        };
+
+    // Generate a constructor function for the derived storage.
+    std::function<BaseStorage *(StorageAllocator &)> ctorFn =
+        [&](StorageAllocator &allocator) {
+          auto *storage = Storage::construct(allocator, derivedKey);
+          if (initFn)
+            initFn(storage);
+          return storage;
+        };
+
+    // Get an instance for the derived storage.
+    return static_cast<Storage *>(getImpl(kind, hashValue, isEqual, ctorFn));
+  }
+
+  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+  /// that can be used to initialize a newly inserted storage instance. This
+  /// function is used for derived types that use no additional storage or
+  /// uniquing outside of the kind.
+  template <typename Storage>
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind) {
+    auto ctorFn = [&](StorageAllocator &allocator) {
+      auto *storage = new (allocator.allocate<Storage>()) Storage();
+      if (initFn)
+        initFn(storage);
+      return storage;
+    };
+    return static_cast<Storage *>(getImpl(kind, ctorFn));
+  }
+
+  /// Erases a uniqued instance of 'Storage'. This function is used for derived
+  /// types that have complex storage or uniquing constraints.
+  template <typename Storage, typename Arg, typename... Args>
+  void erase(unsigned kind, Arg &&arg, Args &&... args) {
+    // Construct a value of the derived key type.
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
+
+    // Create a hash of the kind and the derived key.
+    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+    // Generate an equality function for the derived storage.
+    std::function<bool(const BaseStorage *)> isEqual =
+        [&derivedKey](const BaseStorage *existing) {
+          return static_cast<const Storage &>(*existing) == derivedKey;
+        };
+
+    // Attempt to erase the storage instance.
+    eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) {
+      static_cast<Storage *>(storage)->cleanup();
+    });
+  }
+
+private:
+  /// Implementation for getting/creating an instance of a derived type with
+  /// complex storage.
+  BaseStorage *getImpl(unsigned kind, unsigned hashValue,
+                       llvm::function_ref<bool(const BaseStorage *)> isEqual,
+                       std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+  /// Implementation for getting/creating an instance of a derived type with
+  /// default storage.
+  BaseStorage *getImpl(unsigned kind,
+                       std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+  /// Implementation for erasing an instance of a derived type with complex
+  /// storage.
+  void eraseImpl(unsigned kind, unsigned hashValue,
+                 llvm::function_ref<bool(const BaseStorage *)> isEqual,
+                 std::function<void(BaseStorage *)> cleanupFn);
+
+  /// The internal implementation class.
+  std::unique_ptr<detail::StorageUniquerImpl> impl;
+
+  //===--------------------------------------------------------------------===//
+  // Key Construction
+  //===--------------------------------------------------------------------===//
+
+  /// Used to construct an instance of 'ImplTy::KeyTy' if there is an
+  /// 'ImplTy::getKey' function for the provided arguments.
+  template <typename ImplTy, typename... Args>
+  static typename std::enable_if<
+      is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
+      typename ImplTy::KeyTy>::type
+  getKey(Args &&... args) {
+    return ImplTy::getKey(args...);
+  }
+  /// If there is no 'ImplTy::getKey' method, then we try to directly construct
+  /// the 'ImplTy::KeyTy' with the provided arguments.
+  template <typename ImplTy, typename... Args>
+  static typename std::enable_if<
+      !is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
+      typename ImplTy::KeyTy>::type
+  getKey(Args &&... args) {
+    return typename ImplTy::KeyTy(args...);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Key and Kind Hashing
+  //===--------------------------------------------------------------------===//
+
+  /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
+  /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
+  template <typename ImplTy, typename DerivedKey>
+  static typename std::enable_if<
+      is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
+      ::llvm::hash_code>::type
+  getHash(unsigned kind, const DerivedKey &derivedKey) {
+    return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
+  }
+  /// If there is no 'ImplTy::hashKey' default to using the
+  /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
+  template <typename ImplTy, typename DerivedKey>
+  static typename std::enable_if<
+      !is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
+      ::llvm::hash_code>::type
+  getHash(unsigned kind, const DerivedKey &derivedKey) {
+    return llvm::hash_combine(
+        kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
+  }
+};
+} // end namespace mlir
+
+#endif
diff --git a/third_party/mlir/include/mlir/Support/StringExtras.h b/third_party/mlir/include/mlir/Support/StringExtras.h
new file mode 100644
index 0000000..9948d15
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/StringExtras.h
@@ -0,0 +1,83 @@
+//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains string utility functions used within MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STRINGEXTRAS_H
+#define MLIR_SUPPORT_STRINGEXTRAS_H
+
+#include "llvm/ADT/StringExtras.h"
+
+#include <cctype>
+
+namespace mlir {
+/// Converts a string to snake-case from camel-case by replacing all uppercase
+/// letters with '_' followed by the letter in lowercase, except if the
+/// uppercase letter is the first character of the string.
+inline std::string convertToSnakeCase(llvm::StringRef input) {
+  std::string snakeCase;
+  snakeCase.reserve(input.size());
+  for (auto c : input) {
+    if (std::isupper(c)) {
+      if (!snakeCase.empty() && snakeCase.back() != '_') {
+        snakeCase.push_back('_');
+      }
+      snakeCase.push_back(llvm::toLower(c));
+    } else {
+      snakeCase.push_back(c);
+    }
+  }
+  return snakeCase;
+}
+
+/// Converts a string from camel-case to snake_case by replacing all occurences
+/// of '_' followed by a lowercase letter with the letter in
+/// uppercase. Optionally allow capitalization of the first letter (if it is a
+/// lowercase letter)
+inline std::string convertToCamelCase(llvm::StringRef input,
+                                      bool capitalizeFirst = false) {
+  if (input.empty()) {
+    return "";
+  }
+  std::string output;
+  output.reserve(input.size());
+  size_t pos = 0;
+  if (capitalizeFirst && std::islower(input[pos])) {
+    output.push_back(llvm::toUpper(input[pos]));
+    pos++;
+  }
+  while (pos < input.size()) {
+    auto cur = input[pos];
+    if (cur == '_') {
+      if (pos && (pos + 1 < input.size())) {
+        if (std::islower(input[pos + 1])) {
+          output.push_back(llvm::toUpper(input[pos + 1]));
+          pos += 2;
+          continue;
+        }
+      }
+    }
+    output.push_back(cur);
+    pos++;
+  }
+  return output;
+}
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_STRINGEXTRAS_H
diff --git a/third_party/mlir/include/mlir/Support/TranslateClParser.h b/third_party/mlir/include/mlir/Support/TranslateClParser.h
new file mode 100644
index 0000000..d81dd83
--- /dev/null
+++ b/third_party/mlir/include/mlir/Support/TranslateClParser.h
@@ -0,0 +1,50 @@
+//===- TranslateClParser.h - Translations command line parser ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains custom command line parser for translations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_TRANSLATE_CL_PARSER_H_
+#define MLIR_SUPPORT_TRANSLATE_CL_PARSER_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/CommandLine.h"
+#include <functional>
+
+namespace mlir {
+
+struct LogicalResult;
+class MLIRContext;
+
+/// Common interface for source-to-source translation functions.
+using TranslateFunction = std::function<LogicalResult(
+    StringRef inputFilename, StringRef outputFilename, MLIRContext *)>;
+
+/// Custom parser for TranslateFunction.
+/// Wraps TranslateToMLIRFunctions and TranslateFromMLIRFunctions into
+/// TranslateFunctions before registering them as options.
+struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
+  TranslationParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &O,
+                       size_t GlobalWidth) const override;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_TRANSLATE_CL_PARSER_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Argument.h b/third_party/mlir/include/mlir/TableGen/Argument.h
new file mode 100644
index 0000000..8390939
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Argument.h
@@ -0,0 +1,68 @@
+//===- Argument.h - Argument definitions ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file contains definitions for TableGen operation's arguments.
+// Operation arguments fall into two categories:
+//
+// 1. Operands: SSA values operated on by the operation
+// 2. Attributes: compile-time known properties that have influence over
+//    the operation's behavior
+//
+// These two categories are modelled with the unified argument concept in
+// TableGen because we need similar pattern matching mechanisms for them.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ARGUMENT_H_
+#define MLIR_TABLEGEN_ARGUMENT_H_
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Type.h"
+#include "llvm/ADT/PointerUnion.h"
+#include <string>
+
+namespace llvm {
+class StringRef;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// A struct wrapping an op attribute and its name together
+struct NamedAttribute {
+  llvm::StringRef name;
+  Attribute attr;
+};
+
+// A struct wrapping an op operand/result's constraint and its name together
+struct NamedTypeConstraint {
+  // Returns true if this operand/result has constraint to be satisfied.
+  bool hasPredicate() const;
+  // Returns true if this operand/result is variadic.
+  bool isVariadic() const;
+
+  llvm::StringRef name;
+  TypeConstraint constraint;
+};
+
+// Operation argument: either attribute or operand
+using Argument = llvm::PointerUnion<NamedAttribute *, NamedTypeConstraint *>;
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_ARGUMENT_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Attribute.h b/third_party/mlir/include/mlir/TableGen/Attribute.h
new file mode 100644
index 0000000..2f137a2
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Attribute.h
@@ -0,0 +1,186 @@
+//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Attribute wrapper to simplify using TableGen Record defining a MLIR
+// Attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_
+#define MLIR_TABLEGEN_ATTRIBUTE_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing attribute constraints defined
+// in TableGen.
+class AttrConstraint : public Constraint {
+public:
+  explicit AttrConstraint(const llvm::Record *record);
+
+  static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
+};
+
+// Wrapper class providing helper methods for accessing MLIR Attribute defined
+// in TableGen. This class should closely reflect what is defined as class
+// `Attr` in TableGen.
+class Attribute : public AttrConstraint {
+public:
+  explicit Attribute(const llvm::Record *record);
+  explicit Attribute(const llvm::DefInit *init);
+
+  // Returns the storage type if set. Returns the default storage type
+  // ("Attribute") otherwise.
+  StringRef getStorageType() const;
+
+  // Returns the return type for this attribute.
+  StringRef getReturnType() const;
+
+  // Returns the template getter method call which reads this attribute's
+  // storage and returns the value as of the desired return type.
+  // The call will contain a `{0}` which will be expanded to this attribute.
+  StringRef getConvertFromStorageCall() const;
+
+  // Returns true if this attribute can be built from a constant value.
+  bool isConstBuildable() const;
+
+  // Returns the template that can be used to produce an instance of the
+  // attribute.
+  // Syntax: {0} should be replaced with a builder, {1} should be replaced with
+  // the constant value.
+  StringRef getConstBuilderTemplate() const;
+
+  // Returns the base-level attribute that this attribute constraint is
+  // built upon.
+  Attribute getBaseAttr() const;
+
+  // Returns whether this attribute has a default value's initializer.
+  bool hasDefaultValueInitializer() const;
+  // Returns the default value's initializer for this attribute.
+  StringRef getDefaultValueInitializer() const;
+
+  // Returns whether this attribute is optional.
+  bool isOptional() const;
+
+  // Returns true if this attribute is a derived attribute (i.e., a subclass
+  // of `DerivedAttr`).
+  bool isDerivedAttr() const;
+
+  // Returns true if this attribute is a type attribute (i.e., a subclass
+  // of `TypeAttrBase`).
+  bool isTypeAttr() const;
+
+  // Returns true if this attribute is an enum attribute (i.e., a subclass of
+  // `EnumAttrInfo`)
+  bool isEnumAttr() const;
+
+  // Returns this attribute's TableGen def name. If this is an `OptionalAttr`
+  // or `DefaultValuedAttr` without explicit name, returns the base attribute's
+  // name.
+  StringRef getAttrDefName() const;
+
+  // Returns the code body for derived attribute. Aborts if this is not a
+  // derived attribute.
+  StringRef getDerivedCodeBody() const;
+};
+
+// Wrapper class providing helper methods for accessing MLIR constant attribute
+// defined in TableGen. This class should closely reflect what is defined as
+// class `ConstantAttr` in TableGen.
+class ConstantAttr {
+public:
+  explicit ConstantAttr(const llvm::DefInit *init);
+
+  // Returns the attribute kind.
+  Attribute getAttribute() const;
+
+  // Returns the constant value.
+  StringRef getConstantValue() const;
+
+private:
+  // The TableGen definition of this constant attribute.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enum attribute cases
+// defined in TableGen. This is used for enum attribute case backed by both
+// StringAttr and IntegerAttr.
+class EnumAttrCase : public Attribute {
+public:
+  explicit EnumAttrCase(const llvm::DefInit *init);
+
+  // Returns true if this EnumAttrCase is backed by a StringAttr.
+  bool isStrCase() const;
+
+  // Returns the symbol of this enum attribute case.
+  StringRef getSymbol() const;
+
+  // Returns the value of this enum attribute case.
+  int64_t getValue() const;
+};
+
+// Wrapper class providing helper methods for accessing enum attributes defined
+// in TableGen.This is used for enum attribute case backed by both StringAttr
+// and IntegerAttr.
+class EnumAttr : public Attribute {
+public:
+  explicit EnumAttr(const llvm::Record *record);
+  explicit EnumAttr(const llvm::Record &record);
+  explicit EnumAttr(const llvm::DefInit *init);
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the name of the utilit function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum attribute.
+  std::vector<EnumAttrCase> getAllCases() const;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_ATTRIBUTE_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Constraint.h b/third_party/mlir/include/mlir/TableGen/Constraint.h
new file mode 100644
index 0000000..17b60da
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Constraint.h
@@ -0,0 +1,90 @@
+//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Constraint wrapper to simplify using TableGen Record for constraints.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_CONSTRAINT_H_
+#define MLIR_TABLEGEN_CONSTRAINT_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing Constraint defined in
+// TableGen.
+class Constraint {
+public:
+  Constraint(const llvm::Record *record);
+
+  bool operator==(const Constraint &that) { return def == that.def; }
+  bool operator!=(const Constraint &that) { return def != that.def; }
+
+  // Returns the predicate for this constraint.
+  Pred getPredicate() const;
+
+  // Returns the condition template that can be used to check if a type or
+  // attribute satisfies this constraint.  The template may contain "{0}" that
+  // must be substituted with an expression returning an mlir::Type or
+  // mlir::Attribute.
+  std::string getConditionTemplate() const;
+
+  // Returns the user-readable description of this constraint. If the
+  // description is not provided, returns the TableGen def name.
+  StringRef getDescription() const;
+
+  // Constraint kind
+  enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
+
+  Kind getKind() const { return kind; }
+
+protected:
+  Constraint(Kind kind, const llvm::Record *record);
+
+  // The TableGen definition of this constraint.
+  const llvm::Record *def;
+
+private:
+  // What kind of constraint this is.
+  Kind kind;
+};
+
+// An constraint and the concrete entities to place the constraint on.
+struct AppliedConstraint {
+  AppliedConstraint(Constraint &&constraint, StringRef self,
+                    std::vector<std::string> &&entities);
+
+  Constraint constraint;
+  // The symbol to replace `$_self` special placeholder in the constraint.
+  std::string self;
+  // The symbols to replace `$N` positional placeholders in the constraint.
+  std::vector<std::string> entities;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_CONSTRAINT_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Dialect.h b/third_party/mlir/include/mlir/TableGen/Dialect.h
new file mode 100644
index 0000000..0005ad1
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Dialect.h
@@ -0,0 +1,50 @@
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_DIALECT_H_
+#define MLIR_TABLEGEN_DIALECT_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace llvm {
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+// Wrapper class that contains a MLIR dialect's information defined in TableGen
+// and provides helper methods for accessing them.
+class Dialect {
+public:
+  explicit Dialect(const llvm::Record *def) : def(*def) {}
+
+  // Returns the name of this dialect.
+  StringRef getName() const;
+
+  // Returns the C++ namespaces that ops of this dialect should be placed into.
+  StringRef getCppNamespace() const;
+
+private:
+  const llvm::Record &def;
+};
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_DIALECT_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Format.h b/third_party/mlir/include/mlir/TableGen/Format.h
new file mode 100644
index 0000000..75ace15
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Format.h
@@ -0,0 +1,248 @@
+//===- Format.h - Utilities for String Format -------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares utilities for formatting strings. They are specially
+// tailored to the needs of TableGen'ing op definitions and rewrite rules,
+// so they are not expected to be used as widely applicable utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_FORMAT_H_
+#define MLIR_TABLEGEN_FORMAT_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace tblgen {
+
+/// Format context containing substitutions for special placeholders.
+///
+/// This context divides special placeholders into two categories: builtin ones
+/// and custom ones.
+///
+/// Builtin placeholders are baked into `FmtContext` and each one of them has a
+/// dedicated setter. They can be used in all dialects. Their names follow the
+/// convention of `$_<name>`. The rationale of the leading underscore is to
+/// avoid confusion and name collision: op arguments/attributes/results are
+/// named as $<name>, and we can potentially support referencing those entities
+/// directly in the format template in the future.
+//
+/// Custom ones are registered by dialect-specific TablGen backends and use the
+/// same unified setter.
+class FmtContext {
+public:
+  // Placeholder kinds
+  enum class PHKind : char {
+    None,
+    Custom,  // For custom placeholders
+    Builder, // For the $_builder placeholder
+    Op,      // For the $_op placeholder
+    Self,    // For the $_self placeholder
+  };
+
+  FmtContext() = default;
+
+  // Setter for custom placeholders
+  FmtContext &addSubst(StringRef placeholder, Twine subst);
+
+  // Setters for builtin placeholders
+  FmtContext &withBuilder(Twine subst);
+  FmtContext &withOp(Twine subst);
+  FmtContext &withSelf(Twine subst);
+
+  Optional<StringRef> getSubstFor(PHKind placeholder) const;
+  Optional<StringRef> getSubstFor(StringRef placeholder) const;
+
+  static PHKind getPlaceHolderKind(StringRef str);
+
+private:
+  struct PHKindInfo : DenseMapInfo<PHKind> {
+    using CharInfo = DenseMapInfo<char>;
+
+    static inline PHKind getEmptyKey() {
+      return static_cast<PHKind>(CharInfo::getEmptyKey());
+    }
+    static inline PHKind getTombstoneKey() {
+      return static_cast<PHKind>(CharInfo::getTombstoneKey());
+    }
+    static unsigned getHashValue(const PHKind &val) {
+      return CharInfo::getHashValue(static_cast<char>(val));
+    }
+
+    static bool isEqual(const PHKind &lhs, const PHKind &rhs) {
+      return lhs == rhs;
+    }
+  };
+
+  llvm::SmallDenseMap<PHKind, std::string, 4, PHKindInfo> builtinSubstMap;
+  llvm::StringMap<std::string> customSubstMap;
+};
+
+/// Struct representing a replacement segment for the formatted string. It can
+/// be a segment of the formatting template (for `Literal`) or a replacement
+/// parameter (for `PositionalPH` and `SpecialPH`).
+struct FmtReplacement {
+  enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+
+  FmtReplacement() = default;
+  explicit FmtReplacement(StringRef literal)
+      : type(Type::Literal), spec(literal) {}
+  FmtReplacement(StringRef spec, size_t index)
+      : type(Type::PositionalPH), spec(spec), index(index) {}
+  FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
+      : type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
+
+  Type type = Type::Empty;
+  StringRef spec;
+  size_t index = 0;
+  FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+};
+
+class FmtObjectBase {
+private:
+  static std::pair<FmtReplacement, StringRef> splitFmtSegment(StringRef fmt);
+  static std::vector<FmtReplacement> parseFormatString(StringRef fmt);
+
+protected:
+  // The parameters are stored in a std::tuple, which does not provide runtime
+  // indexing capabilities.  In order to enable runtime indexing, we use this
+  // structure to put the parameters into a std::vector.  Since the parameters
+  // are not all the same type, we use some type-erasure by wrapping the
+  // parameters in a template class that derives from a non-template superclass.
+  // Essentially, we are converting a std::tuple<Derived<Ts...>> to a
+  // std::vector<Base*>.
+  struct CreateAdapters {
+    template <typename... Ts>
+    std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
+      return std::vector<llvm::detail::format_adapter *>{&items...};
+    }
+  };
+
+  StringRef fmt;
+  const FmtContext *context;
+  std::vector<llvm::detail::format_adapter *> adapters;
+  std::vector<FmtReplacement> replacements;
+
+public:
+  FmtObjectBase(StringRef fmt, const FmtContext *ctx, size_t numParams)
+      : fmt(fmt), context(ctx), replacements(parseFormatString(fmt)) {}
+
+  FmtObjectBase(const FmtObjectBase &that) = delete;
+
+  FmtObjectBase(FmtObjectBase &&that)
+      : fmt(std::move(that.fmt)), context(that.context),
+        adapters(), // adapters are initialized by FmtObject
+        replacements(std::move(that.replacements)) {}
+
+  void format(llvm::raw_ostream &s) const;
+
+  std::string str() const {
+    std::string result;
+    llvm::raw_string_ostream s(result);
+    format(s);
+    return s.str();
+  }
+
+  template <unsigned N> SmallString<N> sstr() const {
+    SmallString<N> result;
+    llvm::raw_svector_ostream s(result);
+    format(s);
+    return result;
+  }
+
+  template <unsigned N> operator SmallString<N>() const { return sstr<N>(); }
+
+  operator std::string() const { return str(); }
+};
+
+template <typename Tuple> class FmtObject : public FmtObjectBase {
+  // Storage for the parameter adapters.  Since the base class erases the type
+  // of the parameters, we have to own the storage for the parameters here, and
+  // have the base class store type-erased pointers into this tuple.
+  Tuple parameters;
+
+public:
+  FmtObject(StringRef fmt, const FmtContext *ctx, Tuple &&params)
+      : FmtObjectBase(fmt, ctx, std::tuple_size<Tuple>::value),
+        parameters(std::move(params)) {
+    adapters.reserve(std::tuple_size<Tuple>::value);
+    adapters = llvm::apply_tuple(CreateAdapters(), parameters);
+  }
+
+  FmtObject(FmtObject const &that) = delete;
+
+  FmtObject(FmtObject &&that)
+      : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+    adapters.reserve(that.adapters.size());
+    adapters = llvm::apply_tuple(CreateAdapters(), parameters);
+  }
+};
+
+/// Formats text by substituting placeholders in format string with replacement
+/// parameters.
+///
+/// There are two categories of placeholders accepted, both led by a '$' sign:
+///
+/// 1. Positional placeholder: $[0-9]+
+/// 2. Special placeholder:    $[a-zA-Z_][a-zA-Z0-9_]*
+///
+/// Replacement parameters for positional placeholders are supplied as the
+/// `vals` parameter pack with 1:1 mapping. That is, $0 will be replaced by the
+/// first parameter in `vals`, $1 by the second one, and so on. Note that you
+/// can use the positional placeholders in any order and repeat any times, for
+/// example, "$2 $1 $1 $0" is accepted.
+///
+/// Replacement parameters for special placeholders are supplied using the `ctx`
+/// format context.
+///
+/// The `fmt` is recorded as a `StringRef` inside the returned `FmtObject`.
+/// The caller needs to make sure the underlying data is available when the
+/// `FmtObject` is used.
+///
+/// `ctx` accepts a nullptr if there is no special placeholder is used.
+///
+/// If no substitution is provided for a placeholder or any error happens during
+/// format string parsing or replacement, the placeholder will be outputted
+/// as-is with an additional marker '<no-subst-found>', to aid debugging.
+///
+/// To print a '$' literally, escape it with '$$'.
+///
+/// This utility function is inspired by LLVM formatv(), with modifications
+/// specially tailored for TableGen C++ generation usage:
+///
+/// 1. This utility use '$' instead of '{' and '}' for denoting the placeholder
+///    because '{' and '}' are frequently used in C++ code.
+/// 2. This utility does not support format layout because it is rarely needed
+///    in C++ code generation.
+template <typename... Ts>
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+    -> FmtObject<decltype(std::make_tuple(
+        llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
+  using ParamTuple = decltype(std::make_tuple(
+      llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
+  return FmtObject<ParamTuple>(
+      fmt, ctx,
+      std::make_tuple(
+          llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
+}
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_FORMAT_H_
diff --git a/third_party/mlir/include/mlir/TableGen/GenInfo.h b/third_party/mlir/include/mlir/TableGen/GenInfo.h
new file mode 100644
index 0000000..0b0bd19
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/GenInfo.h
@@ -0,0 +1,81 @@
+//===- GenInfo.h - Generator info -------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TABLEGEN_GENINFO_H_
+#define MLIR_TABLEGEN_GENINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include <functional>
+
+namespace llvm {
+class RecordKeeper;
+} // end namespace llvm
+
+namespace mlir {
+
+/// Generator function to invoke.
+using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
+                                       raw_ostream &os)>;
+
+/// Structure to group information about a generator (argument to invoke via
+/// mlir-tblgen, description, and generator function).
+class GenInfo {
+public:
+  /// GenInfo constructor should not be invoked directly, instead use
+  /// GenRegistration or registerGen.
+  GenInfo(StringRef arg, StringRef description, GenFunction generator)
+      : arg(arg), description(description), generator(generator) {}
+
+  /// Invokes the generator and returns whether the generator failed.
+  bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
+    assert(generator && "Cannot call generator with null generator");
+    return generator(recordKeeper, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-tblgen' to
+  /// invoke this generator.
+  StringRef getGenArgument() const { return arg; }
+
+  /// Returns a description for the generator.
+  StringRef getGenDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the generator via mlir-tblgen.
+  StringRef arg;
+
+  // Description of the generator.
+  StringRef description;
+
+  // Generator function.
+  GenFunction generator;
+};
+
+/// GenRegistration provides a global initializer that registers a generator
+/// function.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static GenRegistration Print("print", "Print records", [](...){...});
+struct GenRegistration {
+  GenRegistration(StringRef arg, StringRef description, GenFunction function);
+};
+
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_GENINFO_H_
diff --git a/third_party/mlir/include/mlir/TableGen/GenNameParser.h b/third_party/mlir/include/mlir/TableGen/GenNameParser.h
new file mode 100644
index 0000000..7b1e8a3
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/GenNameParser.h
@@ -0,0 +1,40 @@
+//===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// The GenNameParser class adds all passes linked in to the system that are
+// creatable to the tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENNAMEPARSER_H_
+#define MLIR_TABLEGEN_GENNAMEPARSER_H_
+
+#include "llvm/Support/CommandLine.h"
+
+namespace mlir {
+class GenInfo;
+
+/// Adds command line option for each registered generator.
+struct GenNameParser : public llvm::cl::parser<const GenInfo *> {
+  GenNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &O,
+                       size_t GlobalWidth) const override;
+};
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_GENNAMEPARSER_H_
diff --git a/third_party/mlir/include/mlir/TableGen/OpTrait.h b/third_party/mlir/include/mlir/TableGen/OpTrait.h
new file mode 100644
index 0000000..8a3463d
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/OpTrait.h
@@ -0,0 +1,98 @@
+//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_OPTRAIT_H_
+#define MLIR_TABLEGEN_OPTRAIT_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Init;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing OpTrait constraints defined
+// in TableGen.
+class OpTrait {
+public:
+  // Discriminator for kinds of op traits.
+  enum class Kind {
+    // OpTrait corresponding to C++ class.
+    Native,
+    // OpTrait corresponding to predicate on operation.
+    Pred,
+    // OpTrait controlling op definition generator internals.
+    Internal
+  };
+
+  explicit OpTrait(Kind kind, const llvm::Record *def);
+
+  // Returns an OpTrait corresponding to the init provided.
+  static OpTrait create(const llvm::Init *init);
+
+  Kind getKind() const { return kind; }
+
+protected:
+  // The TableGen definition of this trait.
+  const llvm::Record *def;
+  Kind kind;
+};
+
+// OpTrait corresponding to a native C++ OpTrait.
+class NativeOpTrait : public OpTrait {
+public:
+  // Returns the trait corresponding to a C++ trait class.
+  StringRef getTrait() const;
+
+  static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; }
+};
+
+// OpTrait corresponding to a predicate on the operation.
+class PredOpTrait : public OpTrait {
+public:
+  // Returns the template for constructing the predicate.
+  std::string getPredTemplate() const;
+
+  // Returns the description of what the predicate is verifying.
+  StringRef getDescription() const;
+
+  static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
+};
+
+// OpTrait controlling op definition generator internals.
+class InternalOpTrait : public OpTrait {
+public:
+  // Returns the trait controlling op definition generator internals.
+  StringRef getTrait() const;
+
+  static bool classof(const OpTrait *t) {
+    return t->getKind() == Kind::Internal;
+  }
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_OPTRAIT_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Operator.h b/third_party/mlir/include/mlir/TableGen/Operator.h
new file mode 100644
index 0000000..d9b60d2
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Operator.h
@@ -0,0 +1,206 @@
+//===- Operator.h - Operator class ------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_OPERATOR_H_
+#define MLIR_TABLEGEN_OPERATOR_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Region.h"
+#include "mlir/TableGen/Type.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace llvm {
+class CodeInit;
+class DefInit;
+class Record;
+class StringInit;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class that contains a MLIR op's information (e.g., operands,
+// atributes) defined in TableGen and provides helper methods for
+// accessing them.
+class Operator {
+public:
+  explicit Operator(const llvm::Record &def);
+  explicit Operator(const llvm::Record *def) : Operator(*def) {}
+
+  // Returns this op's dialect name.
+  StringRef getDialectName() const;
+
+  // Returns the operation name. The name will follow the "<dialect>.<op-name>"
+  // format if its dialect name is not empty.
+  std::string getOperationName() const;
+
+  // Returns this op's C++ class name.
+  StringRef getCppClassName() const;
+
+  // Returns this op's C++ class name prefixed with namespaces.
+  std::string getQualCppClassName() const;
+
+  using value_iterator = NamedTypeConstraint *;
+  using value_range = llvm::iterator_range<value_iterator>;
+
+  // Returns true if this op has variadic operands or results.
+  bool isVariadic() const;
+
+  // Returns true if default builders should not be generated.
+  bool skipDefaultBuilders() const;
+
+  // Op result iterators.
+  value_iterator result_begin();
+  value_iterator result_end();
+  value_range getResults();
+
+  // Returns the number of results this op produces.
+  int getNumResults() const;
+
+  // Returns the op result at the given `index`.
+  NamedTypeConstraint &getResult(int index) { return results[index]; }
+  const NamedTypeConstraint &getResult(int index) const {
+    return results[index];
+  }
+
+  // Returns the `index`-th result's type constraint.
+  TypeConstraint getResultTypeConstraint(int index) const;
+  // Returns the `index`-th result's name.
+  StringRef getResultName(int index) const;
+
+  // Returns the number of variadic results in this operation.
+  unsigned getNumVariadicResults() const;
+
+  // Op attribute interators.
+  using attribute_iterator = const NamedAttribute *;
+  attribute_iterator attribute_begin() const;
+  attribute_iterator attribute_end() const;
+  llvm::iterator_range<attribute_iterator> getAttributes() const;
+
+  int getNumAttributes() const { return attributes.size(); }
+
+  // Op attribute accessors.
+  NamedAttribute &getAttribute(int index) { return attributes[index]; }
+
+  // Op operand iterators.
+  value_iterator operand_begin();
+  value_iterator operand_end();
+  value_range getOperands();
+
+  int getNumOperands() const { return operands.size(); }
+  NamedTypeConstraint &getOperand(int index) { return operands[index]; }
+  const NamedTypeConstraint &getOperand(int index) const {
+    return operands[index];
+  }
+
+  // Returns the number of variadic operands in this operation.
+  unsigned getNumVariadicOperands() const;
+
+  // Returns the total number of arguments.
+  int getNumArgs() const { return arguments.size(); }
+
+  // Op argument (attribute or operand) accessors.
+  Argument getArg(int index) const;
+  StringRef getArgName(int index) const;
+
+  // Returns true if this op has the given MLIR C++ `trait`.
+  // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
+  // requiring the raw MLIR trait here.
+  bool hasTrait(llvm::StringRef trait) const;
+
+  // Returns the number of regions.
+  unsigned getNumRegions() const;
+  // Returns the `index`-th region.
+  const NamedRegion &getRegion(unsigned index) const;
+
+  // Trait.
+  using const_trait_iterator = const OpTrait *;
+  const_trait_iterator trait_begin() const;
+  const_trait_iterator trait_end() const;
+  llvm::iterator_range<const_trait_iterator> getTraits() const;
+
+  ArrayRef<llvm::SMLoc> getLoc() const;
+
+  // Query functions for the documentation of the operator.
+  bool hasDescription() const;
+  StringRef getDescription() const;
+  bool hasSummary() const;
+  StringRef getSummary() const;
+
+  // Returns this op's extra class declaration code.
+  StringRef getExtraClassDeclaration() const;
+
+  // Returns the Tablegen definition this operator was constructed from.
+  // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
+  // temporary solution to OpEmitter requiring a Record because Operator does
+  // not provide enough methods.
+  const llvm::Record &getDef() const;
+
+private:
+  // Populates the vectors containing operands, attributes, results and traits.
+  void populateOpStructure();
+
+  // The dialect of this op.
+  Dialect dialect;
+
+  // The unqualified C++ class name of the op.
+  StringRef cppClassName;
+
+  // The operands of the op.
+  SmallVector<NamedTypeConstraint, 4> operands;
+
+  // The attributes of the op.  Contains native attributes (corresponding to the
+  // actual stored attributed of the operation) followed by derived attributes
+  // (corresponding to dynamic properties of the operation that are computed
+  // upon request).
+  SmallVector<NamedAttribute, 4> attributes;
+
+  // The arguments of the op (operands and native attributes).
+  SmallVector<Argument, 4> arguments;
+
+  // The results of the op.
+  SmallVector<NamedTypeConstraint, 4> results;
+
+  // The traits of the op.
+  SmallVector<OpTrait, 4> traits;
+
+  // The regions of this op.
+  SmallVector<NamedRegion, 1> regions;
+
+  // The number of native attributes stored in the leading positions of
+  // `attributes`.
+  int numNativeAttributes;
+
+  // The TableGen definition of this op.
+  const llvm::Record &def;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_OPERATOR_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Pattern.h b/third_party/mlir/include/mlir/TableGen/Pattern.h
new file mode 100644
index 0000000..efe6494
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Pattern.h
@@ -0,0 +1,392 @@
+//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Pattern wrapper class to simplify using TableGen Record defining a MLIR
+// Pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_PATTERN_H_
+#define MLIR_TABLEGEN_PATTERN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace llvm {
+class DagInit;
+class Init;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Mapping from TableGen Record to Operator wrapper object.
+//
+// We allocate each wrapper object in heap to make sure the pointer to it is
+// valid throughout the lifetime of this map. This is important because this map
+// is shared among multiple patterns to avoid creating the wrapper object for
+// the same op again and again. But this map will continuously grow.
+using RecordOperatorMap =
+    llvm::DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
+
+class Pattern;
+
+// Wrapper class providing helper methods for accessing TableGen DAG leaves
+// used inside Patterns. This class is lightweight and designed to be used like
+// values.
+//
+// A TableGen DAG construct is of the syntax
+//   `(operator, arg0, arg1, ...)`.
+//
+// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
+// for handy helper methods. It only works on `arg*`s that are not nested DAG
+// constructs.
+class DagLeaf {
+public:
+  explicit DagLeaf(const llvm::Init *def) : def(def) {}
+
+  // Returns true if this DAG leaf is not specified in the pattern. That is, it
+  // places no further constraints/transforms and just carries over the original
+  // value.
+  bool isUnspecified() const;
+
+  // Returns true if this DAG leaf is matching an operand. That is, it specifies
+  // a type constraint.
+  bool isOperandMatcher() const;
+
+  // Returns true if this DAG leaf is matching an attribute. That is, it
+  // specifies an attribute constraint.
+  bool isAttrMatcher() const;
+
+  // Returns true if this DAG leaf is wrapping native code call.
+  bool isNativeCodeCall() const;
+
+  // Returns true if this DAG leaf is specifying a constant attribute.
+  bool isConstantAttr() const;
+
+  // Returns true if this DAG leaf is specifying an enum attribute case.
+  bool isEnumAttrCase() const;
+
+  // Returns this DAG leaf as a constraint. Asserts if fails.
+  Constraint getAsConstraint() const;
+
+  // Returns this DAG leaf as an constant attribute. Asserts if fails.
+  ConstantAttr getAsConstantAttr() const;
+
+  // Returns this DAG leaf as an enum attribute case.
+  // Precondition: isEnumAttrCase()
+  EnumAttrCase getAsEnumAttrCase() const;
+
+  // Returns the matching condition template inside this DAG leaf. Assumes the
+  // leaf is an operand/attribute matcher and asserts otherwise.
+  std::string getConditionTemplate() const;
+
+  // Returns the native code call template inside this DAG leaf.
+  // Precondition: isNativeCodeCall()
+  StringRef getNativeCodeTemplate() const;
+
+private:
+  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
+  // also a subclass of the given `superclass`.
+  bool isSubClassOf(StringRef superclass) const;
+
+  const llvm::Init *def;
+};
+
+// Wrapper class providing helper methods for accessing TableGen DAG constructs
+// used inside Patterns. This class is lightweight and designed to be used like
+// values.
+//
+// A TableGen DAG construct is of the syntax
+//   `(operator, arg0, arg1, ...)`.
+//
+// When used inside Patterns, `operator` corresponds to some dialect op, or
+// a known list of verbs that defines special transformation actions. This
+// `arg*` can be a nested DAG construct. This class provides getters to
+// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
+// methods.
+//
+// A null DagNode contains a nullptr and converts to false implicitly.
+class DagNode {
+public:
+  explicit DagNode(const llvm::DagInit *node) : node(node) {}
+
+  // Implicit bool converter that returns true if this DagNode is not a null
+  // DagNode.
+  operator bool() const { return node != nullptr; }
+
+  // Returns the symbol bound to this DAG node.
+  StringRef getSymbol() const;
+
+  // Returns the operator wrapper object corresponding to the dialect op matched
+  // by this DAG. The operator wrapper will be queried from the given `mapper`
+  // and created in it if not existing.
+  Operator &getDialectOp(RecordOperatorMap *mapper) const;
+
+  // Returns the number of operations recursively involved in the DAG tree
+  // rooted from this node.
+  int getNumOps() const;
+
+  // Returns the number of immediate arguments to this DAG node.
+  int getNumArgs() const;
+
+  // Returns true if the `index`-th argument is a nested DAG construct.
+  bool isNestedDagArg(unsigned index) const;
+
+  // Gets the `index`-th argument as a nested DAG construct if possible. Returns
+  // null DagNode otherwise.
+  DagNode getArgAsNestedDag(unsigned index) const;
+
+  // Gets the `index`-th argument as a DAG leaf.
+  DagLeaf getArgAsLeaf(unsigned index) const;
+
+  // Returns the specified name of the `index`-th argument.
+  StringRef getArgName(unsigned index) const;
+
+  // Returns true if this DAG construct means to replace with an existing SSA
+  // value.
+  bool isReplaceWithValue() const;
+
+  // Returns true if this DAG node is wrapping native code call.
+  bool isNativeCodeCall() const;
+
+  // Returns true if this DAG node is an operation.
+  bool isOperation() const;
+
+  // Returns the native code call template inside this DAG node.
+  // Precondition: isNativeCodeCall()
+  StringRef getNativeCodeTemplate() const;
+
+private:
+  const llvm::DagInit *node; // nullptr means null DagNode
+};
+
+// A class for maintaining information for symbols bound in patterns and
+// provides methods for resolving them according to specific use cases.
+//
+// Symbols can be bound to
+//
+// * Op arguments and op results in the source pattern and
+// * Op results in result patterns.
+//
+// Symbols can be referenced in result patterns and additional constraints to
+// the pattern.
+//
+// For example, in
+//
+// ```
+// def : Pattern<
+//     (SrcOp:$results1 $arg0, %arg1),
+//     [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
+// ```
+//
+// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
+// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
+// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
+//
+// If a symbol binds to a multi-result op and it does not have the `__N`
+// suffix, the symbol is expanded to represent all results generated by the
+// multi-result op. If the symbol has a `__N` suffix, then it will expand to
+// only the N-th *static* result as declared in ODS, and that can still
+// corresponds to multiple *dynamic* values if the N-th *static* result is
+// variadic.
+//
+// This class keeps track of such symbols and resolves them into their bound
+// values in a suitable way.
+class SymbolInfoMap {
+public:
+  explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
+
+  // Class for information regarding a symbol.
+  class SymbolInfo {
+  public:
+    // Returns a string for defining a variable named as `name` to store the
+    // value bound by this symbol.
+    std::string getVarDecl(StringRef name) const;
+
+  private:
+    // Allow SymbolInfoMap to access private methods.
+    friend class SymbolInfoMap;
+
+    // What kind of entity this symbol represents:
+    // * Attr: op attribute
+    // * Operand: op operand
+    // * Result: op result
+    // * Value: a value not attached to an op (e.g., from NativeCodeCall)
+    enum class Kind : uint8_t { Attr, Operand, Result, Value };
+
+    // Creates a SymbolInfo instance. `index` is only used for `Attr` and
+    // `Operand` so should be negative for `Result` and `Value` kind.
+    SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
+
+    // Static methods for creating SymbolInfo.
+    static SymbolInfo getAttr(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Attr, index);
+    }
+    static SymbolInfo getOperand(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Operand, index);
+    }
+    static SymbolInfo getResult(const Operator *op) {
+      return SymbolInfo(op, Kind::Result, llvm::None);
+    }
+    static SymbolInfo getValue() {
+      return SymbolInfo(nullptr, Kind::Value, llvm::None);
+    }
+
+    // Returns the number of static values this symbol corresponds to.
+    // A static value is an operand/result declared in ODS. Normally a symbol
+    // only represents one static value, but symbols bound to op results can
+    // represent more than one if the op is a multi-result op.
+    int getStaticValueCount() const;
+
+    // Returns a string containing the C++ expression for referencing this
+    // symbol as a value (if this symbol represents one static value) or a value
+    // range (if this symbol represents multiple static values). `name` is the
+    // name of the C++ variable that this symbol bounds to. `index` should only
+    // be used for indexing results.
+    std::string getValueAndRangeUse(StringRef name, int index) const;
+
+    const Operator *op; // The op where the bound entity belongs
+    Kind kind;          // The kind of the bound entity
+    // The argument index (for `Attr` and `Operand` only)
+    Optional<int> argIndex;
+  };
+
+  using BaseT = llvm::StringMap<SymbolInfo>;
+
+  // Iterators for accessing all symbols.
+  using iterator = BaseT::iterator;
+  iterator begin() { return symbolInfoMap.begin(); }
+  iterator end() { return symbolInfoMap.end(); }
+
+  // Const iterators for accessing all symbols.
+  using const_iterator = BaseT::const_iterator;
+  const_iterator begin() const { return symbolInfoMap.begin(); }
+  const_iterator end() const { return symbolInfoMap.end(); }
+
+  // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
+  // Returns false if `symbol` is already bound.
+  bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
+
+  // Binds the given `symbol` to the results the given `op`. Returns false if
+  // `symbol` is already bound.
+  bool bindOpResult(StringRef symbol, const Operator &op);
+
+  // Registers the given `symbol` as bound to a value. Returns false if `symbol`
+  // is already bound.
+  bool bindValue(StringRef symbol);
+
+  // Returns true if the given `symbol` is bound.
+  bool contains(StringRef symbol) const;
+
+  // Returns an interator to the information of the given symbol named as `key`.
+  const_iterator find(StringRef key) const;
+
+  // Returns the number of static values of the given `symbol` corresponds to.
+  // A static value is a operand/result declared in ODS. Normally a symbol only
+  // represents one static value, but symbols bound to op results can represent
+  // more than one if the op is a multi-result op.
+  int getStaticValueCount(StringRef symbol) const;
+
+  // Returns a string containing the C++ expression for referencing this
+  // symbol as a value (if this symbol represents one static value) or a value
+  // range (if this symbol represents multiple static values).
+  std::string getValueAndRangeUse(StringRef symbol) const;
+
+  // Splits the given `symbol` into a value pack name and an index. Returns the
+  // value pack name and writes the index to `index` on sucess. Returns `symbol`
+  // itself if it does not contain an index.
+  //
+  // We can use `name__N` to access the `N`-th value in the value pack bound to
+  // `name`. `name` is typically the results of an multi-result op.
+  static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
+
+private:
+  llvm::StringMap<SymbolInfo> symbolInfoMap;
+
+  // Pattern instantiation location. This is intended to be used as parameter
+  // to PrintFatalError() to report errors.
+  ArrayRef<llvm::SMLoc> loc;
+};
+
+// Wrapper class providing helper methods for accessing MLIR Pattern defined
+// in TableGen. This class should closely reflect what is defined as class
+// `Pattern` in TableGen. This class contains maps so it is not intended to be
+// used as values.
+class Pattern {
+public:
+  explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
+
+  // Returns the source pattern to match.
+  DagNode getSourcePattern() const;
+
+  // Returns the number of result patterns generated by applying this rewrite
+  // rule.
+  int getNumResultPatterns() const;
+
+  // Returns the DAG tree root node of the `index`-th result pattern.
+  DagNode getResultPattern(unsigned index) const;
+
+  // Collects all symbols bound in the source pattern into `infoMap`.
+  void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
+
+  // Collects all symbols bound in result patterns into `infoMap`.
+  void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
+
+  // Returns the op that the root node of the source pattern matches.
+  const Operator &getSourceRootOp();
+
+  // Returns the operator wrapper object corresponding to the given `node`'s DAG
+  // operator.
+  Operator &getDialectOp(DagNode node);
+
+  // Returns the constraints.
+  std::vector<AppliedConstraint> getConstraints() const;
+
+  // Returns the benefit score of the pattern.
+  int getBenefit() const;
+
+  using IdentifierLine = std::pair<StringRef, unsigned>;
+
+  // Returns the file location of the pattern (buffer identifier + line number
+  // pair).
+  std::vector<IdentifierLine> getLocation() const;
+
+private:
+  // Recursively collects all bound symbols inside the DAG tree rooted
+  // at `tree` and updates the given `infoMap`.
+  void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
+                           bool isSrcPattern);
+
+  // The TableGen definition of this pattern.
+  const llvm::Record &def;
+
+  // All operators.
+  // TODO(antiagainst): we need a proper context manager, like MLIRContext,
+  // for managing the lifetime of shared entities.
+  RecordOperatorMap *recordOpMap;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_PATTERN_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Predicate.h b/third_party/mlir/include/mlir/TableGen/Predicate.h
new file mode 100644
index 0000000..49f7ebc
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Predicate.h
@@ -0,0 +1,128 @@
+//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Wrapper around predicates defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_PREDICATE_H_
+#define MLIR_TABLEGEN_PREDICATE_H_
+
+#include "mlir/Support/LLVM.h"
+
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Init;
+class ListInit;
+class Record;
+class SMLoc;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// A logical predicate.  This class must closely follow the definition of
+// TableGen class 'Pred'.
+class Pred {
+public:
+  // Constructs the null Predicate (e.g., always true).
+  explicit Pred() : def(nullptr) {}
+  // Construct a Predicate from a record.
+  explicit Pred(const llvm::Record *record);
+  // Construct a Predicate from an initializer.
+  explicit Pred(const llvm::Init *init);
+
+  // Check if the predicate is defined.  Callers may use this to interpret the
+  // missing predicate as either true (e.g. in filters) or false (e.g. in
+  // precondition verification).
+  bool isNull() const { return def == nullptr; }
+
+  // Get the predicate condition.  This may dispatch to getConditionImpl() of
+  // the underlying predicate type.
+  std::string getCondition() const;
+
+  // Whether the predicate is a combination of other predicates, i.e. an
+  // record of type CombinedPred.
+  bool isCombined() const;
+
+  // Records are pointer-comparable.
+  bool operator==(const Pred &other) const { return def == other.def; }
+
+  // Get the location of the predicate.
+  ArrayRef<llvm::SMLoc> getLoc() const;
+
+protected:
+  // The TableGen definition of this predicate.
+  const llvm::Record *def;
+};
+
+// A logical predicate wrapping a C expression.  This class must closely follow
+// the definition of TableGen class 'CPred'.
+class CPred : public Pred {
+public:
+  // Construct a CPred from a record.
+  explicit CPred(const llvm::Record *record);
+  // Construct a CPred an initializer.
+  explicit CPred(const llvm::Init *init);
+
+  // Get the predicate condition.
+  std::string getConditionImpl() const;
+};
+
+// A logical predicate that is a combination of other predicates.  This class
+// must closely follow the definition of TableGen class 'CombinedPred'.
+class CombinedPred : public Pred {
+public:
+  // Construct a CombinedPred from a record.
+  explicit CombinedPred(const llvm::Record *record);
+  // Construct a CombinedPred from an initializer.
+  explicit CombinedPred(const llvm::Init *init);
+
+  // Get the predicate condition.
+  std::string getConditionImpl() const;
+
+  // Get the definition of the combiner used in this predicate.
+  const llvm::Record *getCombinerDef() const;
+
+  // Get the predicates that are combined by this predicate.
+  const std::vector<llvm::Record *> getChildren() const;
+};
+
+// A combined predicate that requires all child predicates of 'CPred' type to
+// have their expression rewritten with a simple string substitution rule.
+class SubstLeavesPred : public CombinedPred {
+public:
+  // Get the replacement pattern.
+  StringRef getPattern() const;
+  // Get the string used to replace the pattern.
+  StringRef getReplacement() const;
+};
+
+// A combined predicate that prepends a prefix and appends a suffix to the
+// predicate string composed from a child predicate.
+class ConcatPred : public CombinedPred {
+public:
+  StringRef getPrefix() const;
+  StringRef getSuffix() const;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_PREDICATE_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Region.h b/third_party/mlir/include/mlir/TableGen/Region.h
new file mode 100644
index 0000000..21dffe6
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Region.h
@@ -0,0 +1,45 @@
+//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TABLEGEN_REGION_H_
+#define MLIR_TABLEGEN_REGION_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing helper methods for accessing Region defined in
+// TableGen.
+class Region : public Constraint {
+public:
+  using Constraint::Constraint;
+
+  static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
+};
+
+// A struct bundling a region's constraint and its name.
+struct NamedRegion {
+  StringRef name;
+  Region constraint;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_REGION_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Type.h b/third_party/mlir/include/mlir/TableGen/Type.h
new file mode 100644
index 0000000..c7f92e4
--- /dev/null
+++ b/third_party/mlir/include/mlir/TableGen/Type.h
@@ -0,0 +1,52 @@
+//===- Type.h - Type class --------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Type wrapper to simplify using TableGen Record defining a MLIR Type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_TYPE_H_
+#define MLIR_TABLEGEN_TYPE_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing Type constraints defined in
+// TableGen.
+class TypeConstraint : public Constraint {
+public:
+  explicit TypeConstraint(const llvm::Record *record);
+  explicit TypeConstraint(const llvm::DefInit *init);
+
+  static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
+
+  // Returns true if this is a variadic type constraint.
+  bool isVariadic() const;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_TYPE_H_
diff --git a/third_party/mlir/include/mlir/Target/LLVMIR.h b/third_party/mlir/include/mlir/Target/LLVMIR.h
new file mode 100644
index 0000000..4176490
--- /dev/null
+++ b/third_party/mlir/include/mlir/Target/LLVMIR.h
@@ -0,0 +1,45 @@
+//===- LLVMIR.h - MLIR to LLVM IR conversion --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the entry point for the MLIR to LLVM IR conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_H
+#define MLIR_TARGET_LLVMIR_H
+
+#include <memory>
+
+// Forward-declare LLVM classses.
+namespace llvm {
+class LLVMContext;
+class Module;
+} // namespace llvm
+
+namespace mlir {
+
+class ModuleOp;
+
+/// Convert the given MLIR module into LLVM IR.  The LLVM context is extracted
+/// from the registered LLVM IR dialect.  In case of error, report it
+/// to the error handler registered with the MLIR context, if any (obtained from
+/// the MLIR module), and return `nullptr`.
+std::unique_ptr<llvm::Module> translateModuleToLLVMIR(ModuleOp m);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_H
diff --git a/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
new file mode 100644
index 0000000..04651b8
--- /dev/null
+++ b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -0,0 +1,102 @@
+//===- ModuleTranslation.h - MLIR to LLVM conversion ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the translation between an MLIR LLVM dialect module and
+// the corresponding LLVMIR module. It only handles core LLVM IR operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
+#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+
+namespace mlir {
+class Attribute;
+class FuncOp;
+class Location;
+class ModuleOp;
+class Operation;
+
+namespace LLVM {
+
+// Implementation class for module translation.  Holds a reference to the module
+// being translated, and the mappings between the original and the translated
+// functions, basic blocks and values.  It is practically easier to hold these
+// mappings in one class since the conversion of control flow operations
+// needs to look up block and function mappings.
+class ModuleTranslation {
+public:
+  template <typename T = ModuleTranslation>
+  static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
+    auto llvmModule = prepareLLVMModule(m);
+
+    T translator(m);
+    translator.llvmModule = std::move(llvmModule);
+    translator.convertGlobals();
+    if (failed(translator.convertFunctions()))
+      return nullptr;
+
+    return std::move(translator.llvmModule);
+  }
+
+protected:
+  // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
+  // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
+  // LLVMContext, the LLVM IR module will be created in that context.
+  explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
+  virtual ~ModuleTranslation() {}
+
+  virtual LogicalResult convertOperation(Operation &op,
+                                         llvm::IRBuilder<> &builder);
+  static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
+
+private:
+  LogicalResult convertFunctions();
+  void convertGlobals();
+  LogicalResult convertOneFunction(FuncOp func);
+  void connectPHINodes(FuncOp func);
+  LogicalResult convertBlock(Block &bb, bool ignoreArguments);
+
+  template <typename Range>
+  SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
+
+  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
+                                  Location loc);
+
+  // Original and translated module.
+  ModuleOp mlirModule;
+  std::unique_ptr<llvm::Module> llvmModule;
+
+protected:
+  // Mappings between original and translated values, used for lookups.
+  llvm::StringMap<llvm::Function *> functionMapping;
+  llvm::DenseMap<Value *, llvm::Value *> valueMapping;
+  llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff --git a/third_party/mlir/include/mlir/Target/NVVMIR.h b/third_party/mlir/include/mlir/Target/NVVMIR.h
new file mode 100644
index 0000000..d3e24db
--- /dev/null
+++ b/third_party/mlir/include/mlir/Target/NVVMIR.h
@@ -0,0 +1,44 @@
+//===- NVVMIR.h - MLIR to LLVM + NVVM IR conversion -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the entry point for the MLIR to LLVM + NVVM IR conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_NVVMIR_H
+#define MLIR_TARGET_NVVMIR_H
+
+#include <memory>
+
+// Forward-declare LLVM classses.
+namespace llvm {
+class Module;
+} // namespace llvm
+
+namespace mlir {
+class ModuleOp;
+
+/// Convert the given MLIR module into NVVM IR. This conversion requires the
+/// registration of the LLVM IR dialect and will extract the LLVM context
+/// from the registered LLVM IR dialect.  In case of error, report it
+/// to the error handler registered with the MLIR context, if any (obtained from
+/// the MLIR module), and return `nullptr`.
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_NVVMIR_H
diff --git a/third_party/mlir/include/mlir/Transforms/DialectConversion.h b/third_party/mlir/include/mlir/Transforms/DialectConversion.h
new file mode 100644
index 0000000..3a62c5f
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/DialectConversion.h
@@ -0,0 +1,510 @@
+//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a generic pass for converting between MLIR dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
+#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+
+// Forward declarations.
+class Block;
+class ConversionPatternRewriter;
+class FuncOp;
+class MLIRContext;
+class Operation;
+class Type;
+class Value;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+/// Base class for type conversion interface. Specific converters must
+/// derive this class and implement the pure virtual functions.
+class TypeConverter {
+public:
+  virtual ~TypeConverter() = default;
+
+  /// This class provides all of the information necessary to convert a type
+  /// signature.
+  class SignatureConversion {
+  public:
+    SignatureConversion(unsigned numOrigInputs)
+        : remappedInputs(numOrigInputs) {}
+
+    /// This struct represents a range of new types that remap an existing
+    /// signature input.
+    struct InputMapping {
+      size_t inputNo, size;
+    };
+
+    /// Return the argument types for the new signature.
+    ArrayRef<Type> getConvertedTypes() const { return argTypes; }
+
+    /// Get the input mapping for the given argument.
+    llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
+      return remappedInputs[input];
+    }
+
+    //===------------------------------------------------------------------===//
+    // Conversion Hooks
+    //===------------------------------------------------------------------===//
+
+    /// Remap an input of the original signature with a new set of types. The
+    /// new types are appended to the new signature conversion.
+    void addInputs(unsigned origInputNo, ArrayRef<Type> types);
+
+    /// Append new input types to the signature conversion, this should only be
+    /// used if the new types are not intended to remap an existing input.
+    void addInputs(ArrayRef<Type> types);
+
+    /// Remap an input of the original signature with a range of types in the
+    /// new signature.
+    void remapInput(unsigned origInputNo, unsigned newInputNo,
+                    unsigned newInputCount = 1);
+
+  private:
+    /// The remapping information for each of the original arguments.
+    SmallVector<llvm::Optional<InputMapping>, 4> remappedInputs;
+
+    /// The set of new argument types.
+    SmallVector<Type, 4> argTypes;
+  };
+
+  /// This hooks allows for converting a type. This function should return
+  /// failure if no valid conversion exists, success otherwise. If the new set
+  /// of types is empty, the type is removed and any usages of the existing
+  /// value are expected to be removed during conversion.
+  virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+
+  /// This hook simplifies defining 1-1 type conversions. This function returns
+  /// the type convert to on success, and a null type on failure.
+  virtual Type convertType(Type t) { return t; }
+
+  /// Convert the given set of types, filling 'results' as necessary. This
+  /// returns failure if the conversion of any of the types fails, success
+  /// otherwise.
+  LogicalResult convertTypes(ArrayRef<Type> types,
+                             SmallVectorImpl<Type> &results);
+
+  /// Return true if the given type is legal for this type converter, i.e. the
+  /// type converts to itself.
+  bool isLegal(Type type);
+
+  /// Return true if the inputs and outputs of the given function type are
+  /// legal.
+  bool isSignatureLegal(FunctionType funcType);
+
+  /// This hook allows for converting a specific argument of a signature. It
+  /// takes as inputs the original argument input number, type.
+  /// On success, this function should populate 'result' with any new mappings.
+  virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
+                                            SignatureConversion &result);
+
+  /// This function converts the type signature of the given block, by invoking
+  /// 'convertSignatureArg' for each argument. This function should return a
+  /// valid conversion for the signature on success, None otherwise.
+  llvm::Optional<SignatureConversion> convertBlockSignature(Block *block);
+
+  /// This hook allows for materializing a conversion from a set of types into
+  /// one result type by generating a cast operation of some kind. The generated
+  /// operation should produce one result, of 'resultType', with the provided
+  /// 'inputs' as operands. This hook must be overridden when a type conversion
+  /// results in more than one type, or if a type conversion may persist after
+  /// the conversion has finished.
+  virtual Operation *materializeConversion(PatternRewriter &rewriter,
+                                           Type resultType,
+                                           ArrayRef<Value *> inputs,
+                                           Location loc) {
+    llvm_unreachable("expected 'materializeConversion' to be overridden");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+/// Base class for the conversion patterns that require type changes. Specific
+/// conversions must derive this class and implement least one `rewrite` method.
+/// NOTE: These conversion patterns can only be used with the 'apply*' methods
+/// below.
+class ConversionPattern : public RewritePattern {
+public:
+  /// Construct an ConversionPattern.  `rootName` must correspond to the
+  /// canonical name of the first operation matched by the pattern.
+  ConversionPattern(StringRef rootName, PatternBenefit benefit,
+                    MLIRContext *ctx)
+      : RewritePattern(rootName, benefit, ctx) {}
+
+  /// Hook for derived classes to implement rewriting. `op` is the (first)
+  /// operation matched by the pattern, `operands` is a list of rewritten values
+  /// that are passed to this operation, `rewriter` can be used to emit the new
+  /// operations. This function must be reimplemented if the
+  /// ConversionPattern ever needs to replace an operation that does not
+  /// have successors. This function should not fail. If some specific cases of
+  /// the operation are not supported, these cases should not be matched.
+  virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("unimplemented rewrite");
+  }
+
+  /// Hook for derived classes to implement rewriting. `op` is the (first)
+  /// operation matched by the pattern, `properOperands` is a list of rewritten
+  /// values that are passed to the operation itself, `destinations` is a list
+  /// of (potentially rewritten) successor blocks, `operands` is a list of lists
+  /// of rewritten values passed to each of the successors, co-indexed with
+  /// `destinations`, `rewriter` can be used to emit the new operations. It must
+  /// be reimplemented if the ConversionPattern ever needs to replace a
+  /// terminator operation that has successors. This function should not fail
+  /// the pass. If some specific cases of the operation are not supported,
+  /// these cases should not be matched.
+  virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
+                       ArrayRef<Block *> destinations,
+                       ArrayRef<ArrayRef<Value *>> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("unimplemented rewrite for terminators");
+  }
+
+  /// Hook for derived classes to implement combined matching and rewriting.
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+                  ArrayRef<Block *> destinations,
+                  ArrayRef<ArrayRef<Value *>> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    if (!match(op))
+      return matchFailure();
+    rewrite(op, properOperands, destinations, operands, rewriter);
+    return matchSuccess();
+  }
+
+  /// Hook for derived classes to implement combined matching and rewriting.
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    if (!match(op))
+      return matchFailure();
+    rewrite(op, operands, rewriter);
+    return matchSuccess();
+  }
+
+  /// Attempt to match and rewrite the IR root at the specified operation.
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final;
+
+private:
+  using RewritePattern::rewrite;
+};
+
+/// Add a pattern to the given pattern list to convert the signature of a FuncOp
+/// with the given type converter.
+void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter);
+
+//===----------------------------------------------------------------------===//
+// Conversion PatternRewriter
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct ConversionPatternRewriterImpl;
+} // end namespace detail
+
+/// This class implements a pattern rewriter for use with ConversionPatterns. It
+/// extends the base PatternRewriter and provides special conversion specific
+/// hooks.
+class ConversionPatternRewriter final : public PatternRewriter {
+public:
+  ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
+  ~ConversionPatternRewriter() override;
+
+  /// Apply a signature conversion to the entry block of the given region.
+  void applySignatureConversion(Region *region,
+                                TypeConverter::SignatureConversion &conversion);
+
+  /// Clone the given operation without cloning its regions.
+  Operation *cloneWithoutRegions(Operation *op);
+  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+    return cast<OpT>(cloneWithoutRegions(op.getOperation()));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // PatternRewriter Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// PatternRewriter hook for replacing the results of an operation.
+  void replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                 ArrayRef<Value *> valuesToRemoveIfDead) override;
+  using PatternRewriter::replaceOp;
+
+  /// PatternRewriter hook for splitting a block into two parts.
+  Block *splitBlock(Block *block, Block::iterator before) override;
+
+  /// PatternRewriter hook for moving blocks out of a region.
+  void inlineRegionBefore(Region &region, Region &parent,
+                          Region::iterator before) override;
+  using PatternRewriter::inlineRegionBefore;
+
+  /// PatternRewriter hook for creating a new operation.
+  Operation *createOperation(const OperationState &state) override;
+
+  /// PatternRewriter hook for updating the root operation in-place.
+  void notifyRootUpdated(Operation *op) override;
+
+  /// Return a reference to the internal implementation.
+  detail::ConversionPatternRewriterImpl &getImpl();
+
+private:
+  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// ConversionTarget
+//===----------------------------------------------------------------------===//
+
+/// This class describes a specific conversion target.
+class ConversionTarget {
+public:
+  /// This enumeration corresponds to the specific action to take when
+  /// considering an operation legal for this conversion target.
+  enum class LegalizationAction {
+    /// The target supports this operation.
+    Legal,
+
+    /// This operation has dynamic legalization constraints that must be checked
+    /// by the target.
+    Dynamic,
+
+    /// The target explicitly does not support this operation.
+    Illegal,
+  };
+
+  /// The type used to store operation legality information.
+  using LegalityMapTy = llvm::MapVector<OperationName, LegalizationAction>;
+
+  /// The signature of the callback used to determine if an operation is
+  /// dynamically legal on the target.
+  using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+
+  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
+  virtual ~ConversionTarget() = default;
+
+  //===--------------------------------------------------------------------===//
+  // Legality Registration
+  //===--------------------------------------------------------------------===//
+
+  /// Register a legality action for the given operation.
+  void setOpAction(OperationName op, LegalizationAction action);
+  template <typename OpT> void setOpAction(LegalizationAction action) {
+    setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
+  }
+
+  /// Register the given operations as legal.
+  template <typename OpT> void addLegalOp() {
+    setOpAction<OpT>(LegalizationAction::Legal);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+    addLegalOp<OpT>();
+    addLegalOp<OpT2, OpTs...>();
+  }
+
+  /// Register the given operation as dynamically legal, i.e. requiring custom
+  /// handling by the target via 'isDynamicallyLegal'.
+  template <typename OpT> void addDynamicallyLegalOp() {
+    setOpAction<OpT>(LegalizationAction::Dynamic);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addDynamicallyLegalOp() {
+    addDynamicallyLegalOp<OpT>();
+    addDynamicallyLegalOp<OpT2, OpTs...>();
+  }
+
+  /// Register the given operation as dynamically legal and set the dynamic
+  /// legalization callback to the one provided.
+  template <typename OpT>
+  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+    OperationName opName(OpT::getOperationName(), &ctx);
+    setOpAction(opName, LegalizationAction::Dynamic);
+    setLegalityCallback(opName, callback);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+    addDynamicallyLegalOp<OpT>(callback);
+    addDynamicallyLegalOp<OpT2, OpTs...>(callback);
+  }
+  template <typename OpT, class Callable>
+  typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
+  addDynamicallyLegalOp(Callable &&callback) {
+    addDynamicallyLegalOp<OpT>(
+        [=](Operation *op) { return callback(cast<OpT>(op)); });
+  }
+
+  /// Register the given operation as illegal, i.e. this operation is known to
+  /// not be supported by this target.
+  template <typename OpT> void addIllegalOp() {
+    setOpAction<OpT>(LegalizationAction::Illegal);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
+    addIllegalOp<OpT>();
+    addIllegalOp<OpT2, OpTs...>();
+  }
+
+  /// Register a legality action for the given dialects.
+  void setDialectAction(ArrayRef<StringRef> dialectNames,
+                        LegalizationAction action);
+
+  /// Register the operations of the given dialects as legal.
+  template <typename... Names>
+  void addLegalDialect(StringRef name, Names... names) {
+    SmallVector<StringRef, 2> dialectNames({name, names...});
+    setDialectAction(dialectNames, LegalizationAction::Legal);
+  }
+  template <typename... Args> void addLegalDialect() {
+    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+    setDialectAction(dialectNames, LegalizationAction::Legal);
+  }
+
+  /// Register the operations of the given dialects as dynamically legal, i.e.
+  /// requiring custom handling by the target via 'isDynamicallyLegal'.
+  template <typename... Names>
+  void addDynamicallyLegalDialect(StringRef name, Names... names) {
+    SmallVector<StringRef, 2> dialectNames({name, names...});
+    setDialectAction(dialectNames, LegalizationAction::Dynamic);
+  }
+  template <typename... Args>
+  void addDynamicallyLegalDialect(
+      llvm::Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
+    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+    setDialectAction(dialectNames, LegalizationAction::Dynamic);
+    if (callback)
+      setLegalityCallback(dialectNames, *callback);
+  }
+
+  /// Register the operations of the given dialects as illegal, i.e.
+  /// operations of this dialect are not supported by the target.
+  template <typename... Names>
+  void addIllegalDialect(StringRef name, Names... names) {
+    SmallVector<StringRef, 2> dialectNames({name, names...});
+    setDialectAction(dialectNames, LegalizationAction::Illegal);
+  }
+  template <typename... Args> void addIllegalDialect() {
+    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+    setDialectAction(dialectNames, LegalizationAction::Illegal);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Legality Querying
+  //===--------------------------------------------------------------------===//
+
+  /// Get the legality action for the given operation.
+  llvm::Optional<LegalizationAction> getOpAction(OperationName op) const;
+
+  /// Return true if the given operation instance is legal on this target.
+  bool isLegal(Operation *op) const;
+
+protected:
+  /// Runs a custom legalization query for the given operation. This should
+  /// return true if the given operation is legal, otherwise false.
+  virtual bool isDynamicallyLegal(Operation *op) const {
+    llvm_unreachable(
+        "targets with custom legalization must override 'isDynamicallyLegal'");
+  }
+
+private:
+  /// Set the dynamic legality callback for the given operation.
+  void setLegalityCallback(OperationName name,
+                           const DynamicLegalityCallbackFn &callback);
+
+  /// Set the dynamic legality callback for the given dialects.
+  void setLegalityCallback(ArrayRef<StringRef> dialects,
+                           const DynamicLegalityCallbackFn &callback);
+
+  /// A deterministic mapping of operation name to the specific legality action
+  /// to take.
+  LegalityMapTy legalOperations;
+
+  /// A set of dynamic legality callbacks for given operation names.
+  DenseMap<OperationName, DynamicLegalityCallbackFn> opLegalityFns;
+
+  /// A deterministic mapping of dialect name to the specific legality action to
+  /// take.
+  llvm::StringMap<LegalizationAction> legalDialects;
+
+  /// A set of dynamic legality callbacks for given dialect names.
+  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
+
+  /// The current context this target applies to.
+  MLIRContext &ctx;
+};
+
+//===----------------------------------------------------------------------===//
+// Op Conversion Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Apply a partial conversion on the given operations, and all nested
+/// operations. This method converts as many operations to the target as
+/// possible, ignoring operations that failed to legalize. This method only
+/// returns failure if there are unreachable blocks in any of the regions nested
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &patterns, TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+    Operation *op, ConversionTarget &target, OwningRewritePatternList &patterns,
+    TypeConverter *converter = nullptr);
+
+/// Apply a complete conversion on the given operations, and all nested
+/// operations. This method returns failure if the conversion of any operation
+/// fails, or if there are unreachable blocks in any of the regions nested
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
+LLVM_NODISCARD LogicalResult applyFullConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &patterns, TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyFullConversion(
+    Operation *op, ConversionTarget &target, OwningRewritePatternList &patterns,
+    TypeConverter *converter = nullptr);
+
+/// Apply an analysis conversion on the given operations, and all nested
+/// operations. This method analyzes which operations would be successfully
+/// converted to the target if a conversion was applied. All operations that
+/// were found to be legalizable to the given 'target' are placed within the
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+/// This method only returns failure if there are unreachable blocks in any of
+/// the regions nested within 'ops', or if a type conversion failed. If
+/// 'converter' is provided, the signatures of blocks and regions are also
+/// considered for conversion.
+LLVM_NODISCARD LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &patterns, DenseSet<Operation *> &convertedOps,
+    TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target, OwningRewritePatternList &patterns,
+    DenseSet<Operation *> &convertedOps, TypeConverter *converter = nullptr);
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/third_party/mlir/include/mlir/Transforms/FoldUtils.h b/third_party/mlir/include/mlir/Transforms/FoldUtils.h
new file mode 100644
index 0000000..87a3e13
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/FoldUtils.h
@@ -0,0 +1,123 @@
+//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file declares various operation folding utilities. These
+// utilities are intended to be used by passes to unify and simply their logic.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
+#define MLIR_TRANSFORMS_FOLDUTILS_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+class Operation;
+class Value;
+
+/// A utility class for folding operations, and unifying duplicated constants
+/// generated along the way.
+class OperationFolder {
+public:
+  /// Tries to perform folding on the given `op`, including unifying
+  /// deduplicated constants. If successful, replaces `op`'s uses with
+  /// folded results, and returns success. `preReplaceAction` is invoked on `op`
+  /// before it is replaced. 'processGeneratedConstants' is invoked for any new
+  /// operations generated when folding. If the op was completely folded it is
+  /// erased.
+  LogicalResult tryToFold(
+      Operation *op,
+      llvm::function_ref<void(Operation *)> processGeneratedConstants = nullptr,
+      llvm::function_ref<void(Operation *)> preReplaceAction = nullptr);
+
+  /// Notifies that the given constant `op` should be remove from this
+  /// OperationFolder's internal bookkeeping.
+  ///
+  /// Note: this method must be called if a constant op is to be deleted
+  /// externally to this OperationFolder. `op` must be a constant op.
+  void notifyRemoval(Operation *op);
+
+  /// Create an operation of specific op type with the given builder,
+  /// and immediately try to fold it. This function populates 'results' with
+  /// the results after folding the operation.
+  template <typename OpTy, typename... Args>
+  void create(OpBuilder &builder, SmallVectorImpl<Value *> &results,
+              Location location, Args &&... args) {
+    Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
+    if (failed(tryToFold(op, results)))
+      results.assign(op->result_begin(), op->result_end());
+    else if (op->getNumResults() != 0)
+      op->erase();
+  }
+
+  /// Overload to create or fold a single result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
+                          Value *>::type
+  create(OpBuilder &builder, Location location, Args &&... args) {
+    SmallVector<Value *, 1> results;
+    create<OpTy>(builder, results, location, std::forward<Args>(args)...);
+    return results.front();
+  }
+
+  /// Overload to create or fold a zero result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
+                          OpTy>::type
+  create(OpBuilder &builder, Location location, Args &&... args) {
+    auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
+    SmallVector<Value *, 0> unused;
+    (void)tryToFold(op.getOperation(), unused);
+
+    // Folding cannot remove a zero-result operation, so for convenience we
+    // continue to return it.
+    return op;
+  }
+
+private:
+  /// This map keeps track of uniqued constants by dialect, attribute, and type.
+  /// A constant operation materializes an attribute with a type. Dialects may
+  /// generate different constants with the same input attribute and type, so we
+  /// also need to track per-dialect.
+  using ConstantMap =
+      DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
+
+  /// Tries to perform folding on the given `op`. If successful, populates
+  /// `results` with the results of the folding.
+  LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
+                          llvm::function_ref<void(Operation *)>
+                              processGeneratedConstants = nullptr);
+
+  /// Try to get or create a new constant entry. On success this returns the
+  /// constant operation, nullptr otherwise.
+  Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
+                                    Dialect *dialect, OpBuilder &builder,
+                                    Attribute value, Type type, Location loc);
+
+  /// A mapping between an insertion region and the constants that have been
+  /// created within it.
+  DenseMap<Region *, ConstantMap> foldScopes;
+
+  /// This map tracks all of the dialects that an operation is referenced by;
+  /// given that many dialects may generate the same constant.
+  DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_FOLDUTILS_H
diff --git a/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h b/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h
new file mode 100644
index 0000000..b6d1ea4
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -0,0 +1,100 @@
+//===- LoopFusionUtils.h - Loop fusion utilities ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various loop fusion utility
+// methods: these are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
+#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class AffineForOp;
+struct ComputationSliceState;
+class Operation;
+
+// TODO(andydavis) Extend this module to include utility functions for querying
+// fusion cost/storage reduction, and for performing the loop fusion
+// transformation.
+
+struct FusionResult {
+  enum ResultEnum {
+    Success,
+    FailPrecondition,     // Failed precondition for fusion. (e.g. same block).
+    FailBlockDependence,  // Fusion would violate another dependence in block.
+    FailFusionDependence, // Fusion would reverse dependences between loops.
+    FailComputationSlice, // Unable to compute src loop computation slice.
+  } value;
+  FusionResult(ResultEnum v) : value(v) {}
+};
+
+/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
+/// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult
+/// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are
+/// in the same block and dependences would not be violated). Otherwise
+/// returns a FusionResult explaining why fusion is not feasible.
+/// NOTE: This function is not feature complete and should only be used in
+/// testing.
+/// TODO(andydavis) Update comments when this function is fully implemented.
+FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
+                          unsigned dstLoopDepth,
+                          ComputationSliceState *srcSlice);
+
+/// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
+/// and operation count) for a loop nest up until (and including) the innermost
+/// loop body.
+struct LoopNestStats {
+  /// Map from AffineForOp to immediate child AffineForOps in its loop body.
+  llvm::DenseMap<Operation *, llvm::SmallVector<AffineForOp, 2>> loopMap;
+  /// Map from AffineForOp to count of operations in its loop body.
+  llvm::DenseMap<Operation *, uint64_t> opCountMap;
+  /// Map from AffineForOp to its constant trip count.
+  llvm::DenseMap<Operation *, uint64_t> tripCountMap;
+};
+
+/// Collect loop nest statistics (eg. loop trip count and operation count)
+/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
+/// returns false otherwise.
+// TODO(andydavis) Consider moving this to LoopUtils.
+bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats);
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+// TODO(andydavis) Improve this cost model.
+int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats);
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+/// Returns true on success, failure otherwise (e.g. non-constant trip counts).
+// TODO(andydavis) Improve this cost model.
+bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+                          AffineForOp dstForOp, LoopNestStats &dstStats,
+                          ComputationSliceState *slice, int64_t *computeCost);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
diff --git a/third_party/mlir/include/mlir/Transforms/LoopUtils.h b/third_party/mlir/include/mlir/Transforms/LoopUtils.h
new file mode 100644
index 0000000..3bc76ba
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/LoopUtils.h
@@ -0,0 +1,208 @@
+//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various loop transformation utility
+// methods: these are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
+#define MLIR_TRANSFORMS_LOOP_UTILS_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class AffineMap;
+class AffineForOp;
+class FuncOp;
+class OpBuilder;
+class Value;
+
+namespace loop {
+class ForOp;
+} // end namespace loop
+
+/// Unrolls this for operation completely if the trip count is known to be
+/// constant. Returns failure otherwise.
+LogicalResult loopUnrollFull(AffineForOp forOp);
+
+/// Unrolls this for operation by the specified unroll factor. Returns failure
+/// if the loop cannot be unrolled either due to restrictions or due to invalid
+/// unroll factors.
+LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor);
+
+/// Unrolls this loop by the specified unroll factor or its trip count,
+/// whichever is lower.
+LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
+
+/// Get perfectly nested sequence of loops starting at root of loop nest
+/// (the first op being another AffineFor, and the second op - a terminator).
+/// A loop is perfectly nested iff: the first op in the loop's body is another
+/// AffineForOp, and the second op is a terminator).
+void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
+                             AffineForOp root);
+void getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops,
+                             loop::ForOp root);
+
+/// Unrolls and jams this loop by the specified factor. Returns success if the
+/// loop is successfully unroll-jammed.
+LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
+                                    uint64_t unrollJamFactor);
+
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant), whichever is lower.
+LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
+                                      uint64_t unrollJamFactor);
+
+/// Promotes the loop body of a AffineForOp to its containing block if the
+/// AffineForOp was known to have a single iteration.
+LogicalResult promoteIfSingleIteration(AffineForOp forOp);
+
+/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
+/// their body into the containing Block.
+void promoteSingleIterationLoops(FuncOp f);
+
+/// Computes the cleanup loop lower bound of the loop being unrolled with
+/// the specified unroll factor; this bound will also be upper bound of the main
+/// part of the unrolled loop. Computes the bound as an AffineMap with its
+/// operands or a null map when the trip count can't be expressed as an affine
+/// expression.
+void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
+                              AffineMap *map,
+                              SmallVectorImpl<Value *> *operands,
+                              OpBuilder &builder);
+
+/// Skew the operations in the body of a 'affine.for' operation with the
+/// specified operation-wise shifts. The shifts are with respect to the
+/// original execution order, and are multiplied by the loop 'step' before being
+/// applied.
+LLVM_NODISCARD
+LogicalResult instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
+                           bool unrollPrologueEpilogue = false);
+
+/// Tiles the specified band of perfectly nested loops creating tile-space loops
+/// and intra-tile loops. A band is a contiguous set of loops.
+LLVM_NODISCARD
+LogicalResult tileCodeGen(MutableArrayRef<AffineForOp> band,
+                          ArrayRef<unsigned> tileSizes);
+
+/// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
+/// and 'forOpB' are part of a perfectly nested sequence of loops.
+void interchangeLoops(AffineForOp forOpA, AffineForOp forOpB);
+
+/// Checks if the loop interchange permutation 'loopPermMap', of the perfectly
+/// nested sequence of loops in 'loops', would violate dependences (loop 'i' in
+/// 'loops' is mapped to location 'j = 'loopPermMap[i]' in the interchange).
+bool isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
+                                       ArrayRef<unsigned> loopPermMap);
+
+/// Performs a sequence of loop interchanges on perfectly nested 'loops', as
+/// specified by permutation 'loopPermMap' (loop 'i' in 'loops' is mapped to
+/// location 'j = 'loopPermMap[i]' after the loop interchange).
+unsigned interchangeLoops(ArrayRef<AffineForOp> loops,
+                          ArrayRef<unsigned> loopPermMap);
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the
+// outermost (while again preserving relative order among them).
+// Returns AffineForOp of the root of the new loop nest after loop interchanges.
+AffineForOp sinkSequentialLoops(AffineForOp forOp);
+
+/// Sinks 'forOp' by 'loopDepth' levels by performing a series of loop
+/// interchanges. Requires that 'forOp' is part of a perfect nest with
+/// 'loopDepth' AffineForOps consecutively nested under it.
+void sinkLoop(AffineForOp forOp, unsigned loopDepth);
+
+/// Performs tiling fo imperfectly nested loops (with interchange) by
+/// strip-mining the `forOps` by `sizes` and sinking them, in their order of
+/// occurrence in `forOps`, under each of the `targets`.
+/// Returns the new AffineForOps, one per each of (`forOps`, `targets`) pair,
+/// nested immediately under each of `targets`.
+using Loops = SmallVector<loop::ForOp, 8>;
+using TileLoops = std::pair<Loops, Loops>;
+SmallVector<SmallVector<AffineForOp, 8>, 8> tile(ArrayRef<AffineForOp> forOps,
+                                                 ArrayRef<uint64_t> sizes,
+                                                 ArrayRef<AffineForOp> targets);
+SmallVector<Loops, 8> tile(ArrayRef<loop::ForOp> forOps,
+                           ArrayRef<Value *> sizes,
+                           ArrayRef<loop::ForOp> targets);
+
+/// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes`
+/// and sinking them, in their order of occurrence in `forOps`, under `target`.
+/// Returns the new AffineForOps, one per `forOps`, nested immediately under
+/// `target`.
+SmallVector<AffineForOp, 8> tile(ArrayRef<AffineForOp> forOps,
+                                 ArrayRef<uint64_t> sizes, AffineForOp target);
+Loops tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value *> sizes,
+           loop::ForOp target);
+
+/// Tile a nest of loop::ForOp loops rooted at `rootForOp` with the given
+/// (parametric) sizes. Sizes are expected to be strictly positive values at
+/// runtime.  If more sizes than loops are provided, discard the trailing values
+/// in sizes.  Assumes the loop nest is permutable.
+/// Returns the newly created intra-tile loops.
+Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value *> sizes);
+
+/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
+/// parametric tile sizes that the outer loops have a fixed number of iterations
+/// as defined in `sizes`.
+TileLoops extractFixedOuterLoops(loop::ForOp rootFOrOp,
+                                 ArrayRef<int64_t> sizes);
+
+/// Replace a perfect nest of "for" loops with a single linearized loop. Assumes
+/// `loops` contains a list of perfectly nested loops with bounds and steps
+/// independent of any loop induction variable involved in the nest.
+void coalesceLoops(MutableArrayRef<loop::ForOp> loops);
+
+/// Maps `forOp` for execution on a parallel grid of virtual `processorIds` of
+/// size given by `numProcessors`. This is achieved by embedding the SSA values
+/// corresponding to `processorIds` and `numProcessors` into the bounds and step
+/// of the `forOp`. No check is performed on the legality of the rewrite, it is
+/// the caller's responsibility to ensure legality.
+///
+/// Requires that `processorIds` and `numProcessors` have the same size and that
+/// for each idx, `processorIds`[idx] takes, at runtime, all values between 0
+/// and `numProcessors`[idx] - 1. This corresponds to traditional use cases for:
+///   1. GPU (threadIdx, get_local_id(), ...)
+///   2. MPI (MPI_Comm_rank)
+///   3. OpenMP (omp_get_thread_num)
+///
+/// Example:
+/// Assuming a 2-d grid with processorIds = [blockIdx.x, threadIdx.x] and
+/// numProcessors = [gridDim.x, blockDim.x], the loop:
+///
+/// ```
+///    loop.for %i = %lb to %ub step %step {
+///      ...
+///    }
+/// ```
+///
+/// is rewritten into a version resembling the following pseudo-IR:
+///
+/// ```
+///    loop.for %i = %lb + threadIdx.x + blockIdx.x * blockDim.x to %ub
+///       step %gridDim.x * blockDim.x {
+///      ...
+///    }
+/// ```
+void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value *> processorId,
+                           ArrayRef<Value *> numProcessors);
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/third_party/mlir/include/mlir/Transforms/LowerAffine.h b/third_party/mlir/include/mlir/Transforms/LowerAffine.h
new file mode 100644
index 0000000..5fae476
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/LowerAffine.h
@@ -0,0 +1,58 @@
+//===- LowerAffine.h - Convert Affine to Standard dialect -------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TRANSFORMS_LOWERAFFINE_H
+#define MLIR_TRANSFORMS_LOWERAFFINE_H
+
+#include "mlir/Support/LLVM.h"
+#include <vector>
+
+namespace mlir {
+class AffineExpr;
+class AffineForOp;
+class Location;
+struct LogicalResult;
+class MLIRContext;
+class OpBuilder;
+class RewritePattern;
+class Value;
+
+// Owning list of rewriting patterns.
+class OwningRewritePatternList;
+
+/// Emit code that computes the given affine expression using standard
+/// arithmetic operations applied to the provided dimension and symbol values.
+Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
+                        ArrayRef<Value *> dimValues,
+                        ArrayRef<Value *> symbolValues);
+
+/// Collect a set of patterns to convert from the Affine dialect to the Standard
+/// dialect, in particular convert structured affine control flow into CFG
+/// branch-based control flow.
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx);
+
+/// Emit code that computes the lower bound of the given affine loop using
+/// standard arithmetic operations.
+Value *lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
+
+/// Emit code that computes the upper bound of the given affine loop using
+/// standard arithmetic operations.
+Value *lowerAffineUpperBound(AffineForOp op, OpBuilder &builder);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOWERAFFINE_H
diff --git a/third_party/mlir/include/mlir/Transforms/Passes.h b/third_party/mlir/include/mlir/Transforms/Passes.h
new file mode 100644
index 0000000..ee36517
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/Passes.h
@@ -0,0 +1,135 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors in the loop
+// transformation library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_PASSES_H
+#define MLIR_TRANSFORMS_PASSES_H
+
+#include "mlir/Support/LLVM.h"
+#include <functional>
+#include <limits>
+
+namespace mlir {
+
+class AffineForOp;
+class FunctionPassBase;
+class ModulePassBase;
+
+/// Creates a constant folding pass. Note that this pass solely provides simple
+/// top-down constant folding functionality; it is intended to be used for
+/// testing purpose. Use Canonicalizer pass, which exploits more simplification
+/// opportunties exposed by constant folding, for the general cases.
+FunctionPassBase *createTestConstantFoldPass();
+
+/// Creates an instance of the Canonicalizer pass.
+FunctionPassBase *createCanonicalizerPass();
+
+/// Creates a pass to perform common sub expression elimination.
+FunctionPassBase *createCSEPass();
+
+/// Creates a pass to vectorize loops, operations and data types using a
+/// target-independent, n-D super-vector abstraction.
+FunctionPassBase *
+createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize);
+
+/// Creates a pass to allow independent testing of vectorizer functionality with
+/// FileCheck.
+FunctionPassBase *createVectorizerTestPass();
+
+/// Creates a pass to lower super-vectors to target-dependent HW vectors.
+FunctionPassBase *
+createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize);
+
+/// Creates a loop unrolling pass with the provided parameters.
+/// 'getUnrollFactor' is a function callback for clients to supply a function
+/// that computes an unroll factor - the callback takes precedence over unroll
+/// factors supplied through other means. If -1 is passed as the unrollFactor
+/// and no callback is provided, anything passed from the command-line (if at
+/// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
+FunctionPassBase *createLoopUnrollPass(
+    int unrollFactor = -1, int unrollFull = -1,
+    const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr);
+
+/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
+/// factor of -1 lets the pass use the default factor or the one on the command
+/// line if provided.
+FunctionPassBase *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
+
+/// Creates an simplification pass for affine structures.
+FunctionPassBase *createSimplifyAffineStructuresPass();
+
+/// Creates a loop fusion pass which fuses loops. Buffers of size less than or
+/// equal to `localBufSizeThreshold` are promoted to memory space
+/// `fastMemorySpace'.
+FunctionPassBase *createLoopFusionPass(unsigned fastMemorySpace = 0,
+                                       uint64_t localBufSizeThreshold = 0,
+                                       bool maximalFusion = false);
+
+/// Creates a loop invariant code motion pass that hoists loop invariant
+/// instructions out of the loop.
+FunctionPassBase *createLoopInvariantCodeMotionPass();
+
+/// Creates a pass to pipeline explicit movement of data across levels of the
+/// memory hierarchy.
+FunctionPassBase *createPipelineDataTransferPass();
+
+/// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp)
+/// to equivalent lower-level constructs (flow of basic blocks and arithmetic
+/// primitives).
+FunctionPassBase *createLowerAffinePass();
+
+/// Creates a pass to perform tiling on loop nests.
+FunctionPassBase *createLoopTilingPass(uint64_t cacheSizeBytes);
+
+/// Creates a pass that performs parametric tiling so that the outermost loops
+/// have the given fixed number of iterations.  Assumes outermost loop nests
+/// are permutable.
+FunctionPassBase *
+createSimpleParametricTilingPass(ArrayRef<int64_t> outerLoopSizes);
+
+/// Creates a pass that transforms perfectly nested loops with independent
+/// bounds into a single loop.
+FunctionPassBase *createLoopCoalescingPass();
+
+/// Performs packing (or explicit copying) of accessed memref regions into
+/// buffers in the specified faster memory space through either pointwise copies
+/// or DMA operations.
+FunctionPassBase *createAffineDataCopyGenerationPass(
+    unsigned slowMemorySpace, unsigned fastMemorySpace,
+    unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024,
+    uint64_t fastMemCapacityBytes = std::numeric_limits<uint64_t>::max());
+
+/// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp.
+FunctionPassBase *createLowerVectorTransfersPass();
+
+/// Creates a pass to perform optimizations relying on memref dataflow such as
+/// store to load forwarding, elimination of dead stores, and dead allocs.
+FunctionPassBase *createMemRefDataFlowOptPass();
+
+/// Creates a pass to strip debug information from a function.
+FunctionPassBase *createStripDebugInfoPass();
+
+/// Creates a pass which tests loop fusion utilities.
+FunctionPassBase *createTestLoopFusionPass();
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_PASSES_H
diff --git a/third_party/mlir/include/mlir/Transforms/RegionUtils.h b/third_party/mlir/include/mlir/Transforms/RegionUtils.h
new file mode 100644
index 0000000..a00ddc6
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/RegionUtils.h
@@ -0,0 +1,50 @@
+//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TRANSFORMS_REGIONUTILS_H_
+#define MLIR_TRANSFORMS_REGIONUTILS_H_
+
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+/// Check if all values in the provided range are defined above the `limit`
+/// region.  That is, if they are defined in a region that is a proper ancestor
+/// of `limit`.
+template <typename Range>
+bool areValuesDefinedAbove(Range values, Region &limit) {
+  for (Value *v : values)
+    if (!v->getParentRegion()->isProperAncestor(&limit))
+      return false;
+  return true;
+}
+
+/// Replace all uses of `orig` within the given region with `replacement`.
+void replaceAllUsesInRegionWith(Value *orig, Value *replacement,
+                                Region &region);
+
+/// Fill `values` with a list of values defined at the ancestors of the `limit`
+/// region and used within `region` or its descendants.
+void getUsedValuesDefinedAbove(Region &region, Region &limit,
+                               llvm::SetVector<Value *> &values);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/third_party/mlir/include/mlir/Transforms/Utils.h b/third_party/mlir/include/mlir/Transforms/Utils.h
new file mode 100644
index 0000000..ff48a90
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/Utils.h
@@ -0,0 +1,122 @@
+//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_UTILS_H
+#define MLIR_TRANSFORMS_UTILS_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+
+class AffineApplyOp;
+class AffineForOp;
+class Location;
+class OpBuilder;
+
+/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally
+/// remapping the old memref's indices using the supplied affine map,
+/// 'indexRemap'. The new memref could be of a different shape or rank.
+/// 'extraIndices' provides additional access indices to be added to the start.
+///
+/// 'indexRemap' remaps indices of the old memref access to a new set of indices
+/// that are used to index the memref. Additional input operands to indexRemap
+/// can be optionally provided, and they are added at the start of its input
+/// list. 'indexRemap' is expected to have only dimensional inputs, and the
+/// number of its inputs equal to extraOperands.size() plus rank of the memref.
+/// 'extraOperands' is an optional argument that corresponds to additional
+/// operands (inputs) for indexRemap at the beginning of its input list.
+///
+/// 'domInstFilter', if non-null, restricts the replacement to only those
+/// operations that are dominated by the former; similarly, `postDomInstFilter`
+/// restricts replacement to only those operations that are postdominated by it.
+///
+/// Returns true on success and false if the replacement is not possible,
+/// whenever a memref is used as an operand in a non-deferencing context, except
+/// for dealloc's on the memref which are left untouched. See comments at
+/// function definition for an example.
+//
+//  Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
+//  The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
+//  index remap will perform (%i, %j) -> (%ii - %i, %j), i.e., indexRemap = (d0,
+//  d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any
+//  extra operands, note that 'indexRemap' would just be applied to existing
+//  indices (%i, %j).
+//  TODO(bondhugula): allow extraIndices to be added at any position.
+bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                              ArrayRef<Value *> extraIndices = {},
+                              AffineMap indexRemap = AffineMap(),
+                              ArrayRef<Value *> extraOperands = {},
+                              Operation *domInstFilter = nullptr,
+                              Operation *postDomInstFilter = nullptr);
+
+/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
+/// its results equal to the number of operands, as a composition
+/// of all other AffineApplyOps reachable from input parameter 'operands'. If
+/// different operands were drawing results from multiple affine apply ops,
+/// these will also be collected into a single (multi-result) affine apply op.
+/// The final results of the composed AffineApplyOp are returned in output
+/// parameter 'results'. Returns the affine apply op created.
+Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc,
+                                       ArrayRef<Value *> operands,
+                                       ArrayRef<Operation *> affineApplyOps,
+                                       SmallVectorImpl<Value *> *results);
+
+/// Given an operation, inserts one or more single result affine apply
+/// operations, results of which are exclusively used by this operation.
+/// The operands of these newly created affine apply ops are
+/// guaranteed to be loop iterators or terminal symbols of a function.
+///
+/// Before
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   send %A[%idx], ...
+///   %v = "compute"(%idx, ...)
+///
+/// After
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   send %A[%idx], ...
+///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
+///   %v = "compute"(%idx_, ...)
+
+/// This allows the application of different transformations on send and
+/// compute (for eg. different shifts/delays)
+///
+/// Fills `sliceOps` with the list of affine.apply operations.
+/// In the following cases, `sliceOps` remains empty:
+///   1. If none of opInst's operands were the result of an affine.apply
+///      (i.e., there was no affine computation slice to create).
+///   2. If all the affine.apply op's supplying operands to this opInst did not
+///      have any uses other than those in this opInst.
+void createAffineComputationSlice(Operation *opInst,
+                                  SmallVectorImpl<AffineApplyOp> *sliceOps);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_UTILS_H
diff --git a/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h
new file mode 100644
index 0000000..61da9f1
--- /dev/null
+++ b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h
@@ -0,0 +1,49 @@
+//===- ViewRegionGraph.h - View/write graphviz graphs -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines interface to produce Graphviz outputs of MLIR Regions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_
+#define MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+class FunctionPassBase;
+class Region;
+
+/// Displays the CFG in a window. This is for use from the debugger and
+/// depends on Graphviz to generate the graph.
+void viewGraph(Region &region, const Twine &name, bool shortNames = false,
+               const Twine &title = "",
+               llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
+
+llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Region &region,
+                              bool shortNames = false, const Twine &title = "");
+
+/// Creates a pass to print CFG graphs.
+FunctionPassBase *createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(),
+                                          bool shortNames = false,
+                                          const llvm::Twine &title = "");
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_
diff --git a/third_party/mlir/include/mlir/Translation.h b/third_party/mlir/include/mlir/Translation.h
new file mode 100644
index 0000000..b0cb930
--- /dev/null
+++ b/third_party/mlir/include/mlir/Translation.h
@@ -0,0 +1,71 @@
+//===- Translation.h - Translation registry ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Registry for user-provided translations.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TRANSLATION_H
+#define MLIR_TRANSLATION_H
+
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class OwningModuleRef;
+
+/// Interface of the function that translates a file to MLIR.  The
+/// implementation should create a new MLIR ModuleOp in the given context and
+/// return a pointer to it, or a nullptr in case of any error.
+using TranslateToMLIRFunction =
+    std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
+/// Interface of the function that translates MLIR to a different format and
+/// outputs the result to a file. It is allowed to modify the module.
+using TranslateFromMLIRFunction =
+    std::function<LogicalResult(ModuleOp, llvm::StringRef)>;
+
+/// Use Translate[To|From]MLIRRegistration as a global initialiser that
+/// registers a function and associates it with name. This requires that a
+/// translation has not been registered to a given name.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static TranslateToMLIRRegistration Unused(&MySubCommand, [] { ... });
+///
+/// \{
+struct TranslateToMLIRRegistration {
+  TranslateToMLIRRegistration(llvm::StringRef name,
+                              const TranslateToMLIRFunction &function);
+};
+
+struct TranslateFromMLIRRegistration {
+  TranslateFromMLIRRegistration(llvm::StringRef name,
+                                const TranslateFromMLIRFunction &function);
+};
+/// \}
+
+/// Get a read-only reference to the translator registry.
+const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
+const llvm::StringMap<TranslateFromMLIRFunction> &
+getTranslationFromMLIRRegistry();
+
+} // namespace mlir
+
+#endif // MLIR_TRANSLATION_H
diff --git a/third_party/mlir/include/mlir/VectorOps/CMakeLists.txt b/third_party/mlir/include/mlir/VectorOps/CMakeLists.txt
new file mode 100644
index 0000000..6cc7e44
--- /dev/null
+++ b/third_party/mlir/include/mlir/VectorOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS VectorOps.td)
+mlir_tablegen(VectorOps.h.inc -gen-op-decls)
+mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRVectorOpsIncGen)
diff --git a/third_party/mlir/include/mlir/VectorOps/VectorOps.h b/third_party/mlir/include/mlir/VectorOps/VectorOps.h
new file mode 100644
index 0000000..47cd8a1
--- /dev/null
+++ b/third_party/mlir/include/mlir/VectorOps/VectorOps.h
@@ -0,0 +1,212 @@
+//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines convenience types for working with super-vectorization
+// operations, in particular super-vector loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_VECTOROPS_VECTOROPS_H
+#define MLIR_VECTOROPS_VECTOROPS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace vector {
+
+/// Dialect for super-vectorization Ops.
+class VectorOpsDialect : public Dialect {
+public:
+  VectorOpsDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "vector"; }
+};
+
+/// VectorTransferReadOp performs a blocking read from a scalar memref
+/// location into a super-vector of the same elemental type. This operation is
+/// called 'read' by opposition to 'load' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
+/// that supports super-vectorization with non-effecting padding for full-tile
+/// only code.
+//
+/// A vector transfer read has semantics similar to a vector load, with
+/// additional support for:
+///   1. an optional value of the elemental type of the MemRef. This value
+///      supports non-effecting padding and is inserted in places where the
+///      vector read exceeds the MemRef bounds. If the value is not specified,
+///      the access is statically guaranteed to be within bounds;
+///   2. an attribute of type AffineMap to specify a slice of the original
+///      MemRef access and its transposition into the super-vector shape.
+///      The permutation_map is an AffineMap that must represent a permutation
+///      from the MemRef dim space projected onto the vector dim space.
+///      This permutation_map has as many output dimensions as the vector rank.
+///      However, it is not necessarily full rank on the target space to signify
+///      that broadcast operations will be needed along certain vector
+///      dimensions.
+///      In the limit, one may load a 0-D slice of a memref (i.e. a single
+///      value) into a vector, which corresponds to broadcasting that value in
+///      the whole vector (i.e. a non-constant splat).
+///
+/// Example with full rank permutation_map:
+/// ```mlir
+///   %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
+///   ...
+///   %val = `ssa-value` : f32
+///   // let %i, %j, %k, %l be ssa-values of type index
+///   %v0 = vector.transfer_read %src[%i, %j, %k, %l]
+///          {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///         memref<?x?x?x?xf32>, vector<16x32x64xf32>
+///   %v1 = vector.transfer_read %src[%i, %j, %k, %l], (%val)
+///          {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///         memref<?x?x?x?xf32>, vector<16x32x64xf32>
+/// ```
+///
+/// Example with partial rank permutation_map:
+/// ```mlir
+///   %c0 = constant 0 : index
+///   %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
+///   ...
+///   // let %i, %j be ssa-values of type index
+///   %v0 = vector.transfer_read %src[%i, %c0, %c0, %c0]
+///          {permutation_map: (d0, d1, d2, d3) -> (0, d1, 0)} :
+///         memref<?x?x?x?xf32>, vector<16x32x64xf32>
+class VectorTransferReadOp
+    : public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
+                OpTrait::OneResult> {
+  enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 };
+
+public:
+  using Op::Op;
+
+  static StringRef getOperationName() { return "vector.transfer_read"; }
+  static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+  static void build(Builder *builder, OperationState *result,
+                    VectorType vectorType, Value *srcMemRef,
+                    ArrayRef<Value *> srcIndices, AffineMap permutationMap,
+                    Optional<Value *> paddingValue = None);
+  VectorType getResultType() {
+    return getResult()->getType().cast<VectorType>();
+  }
+  Value *getVector() { return getResult(); }
+  Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+  VectorType getVectorType() { return getResultType(); }
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+  operand_range getIndices();
+  Optional<Value *> getPaddingValue();
+  AffineMap getPermutationMap();
+
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+};
+
+/// VectorTransferWriteOp performs a blocking write from a super-vector to
+/// a scalar memref of the same elemental type. This operation is
+/// called 'write' by opposition to 'store' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
+/// abstraction that supports super-vectorization with non-effecting padding for
+/// full-tile only code.
+///
+/// A vector transfer write has semantics similar to a vector store, with
+/// additional support for handling out-of-bounds situations. It is the
+/// responsibility of vector.transfer_write's implementation to ensure the
+/// memory writes are valid. Different implementations may be pertinent
+/// depending on the hardware support including:
+/// 1. predication;
+/// 2. explicit control-flow;
+/// 3. Read-Modify-Write;
+/// 4. writing out of bounds of the memref when the allocation allows it.
+///
+/// Example:
+/// ```mlir
+///   %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
+///   %val = `ssa-value` : vector<16x32x64xf32>
+///   // let %i, %j, %k, %l be ssa-values of type index
+///   vector.transfer_write %val, %src[%i, %j, %k, %l]
+///     {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///   vector<16x32x64xf32>, memref<?x?x?x?xf32>
+/// ```
+class VectorTransferWriteOp
+    : public Op<VectorTransferWriteOp, OpTrait::VariadicOperands,
+                OpTrait::ZeroResult> {
+  enum Offsets : unsigned {
+    VectorOffset = 0,
+    MemRefOffset = 1,
+    FirstIndexOffset = 2
+  };
+
+public:
+  using Op::Op;
+
+  static StringRef getOperationName() { return "vector.transfer_write"; }
+  static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+  static void build(Builder *builder, OperationState *result, Value *srcVector,
+                    Value *dstMemRef, ArrayRef<Value *> dstIndices,
+                    AffineMap permutationMap);
+  Value *getVector() { return getOperand(Offsets::VectorOffset); }
+  VectorType getVectorType() {
+    return getVector()->getType().cast<VectorType>();
+  }
+  Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+  operand_range getIndices();
+  AffineMap getPermutationMap();
+
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+};
+
+/// VectorTypeCastOp performs a conversion from a memref with scalar element to
+/// memref with vector element, copying the shape of the memref to the vector.
+///
+/// Example:
+///
+/// ```mlir
+///  %A  = alloc() : memref<5x4x3xf32>
+///  %VA = vector.type_cast %A : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
+/// ```
+class VectorTypeCastOp
+    : public Op<VectorTypeCastOp, OpTrait::OneOperand, OpTrait::OneResult> {
+public:
+  using Op::Op;
+
+  static StringRef getOperationName() { return "vector.type_cast"; }
+  static void build(Builder *builder, OperationState *result, Value *srcVector,
+                    Type dstType);
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+};
+
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.h.inc"
+
+} // end namespace vector
+} // end namespace mlir
+
+#endif // MLIR_VECTOROPS_VECTOROPS_H
diff --git a/third_party/mlir/include/mlir/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/VectorOps/VectorOps.td
new file mode 100644
index 0000000..962e53b
--- /dev/null
+++ b/third_party/mlir/include/mlir/VectorOps/VectorOps.td
@@ -0,0 +1,99 @@
+//===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines MLIR vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef VECTOR_OPS
+#else
+#define VECTOR_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Vector_Dialect : Dialect {
+  let name = "vector";
+  let cppNamespace = "vector";
+}
+
+// Base class for Vector dialect ops.
+class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Vector_Dialect, mnemonic, traits> {
+  // For every vector op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+def ExtractElementOp :
+  Vector_Op<"extractelement", [NoSideEffect,
+     PredOpTrait<"operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>,
+    Results<(outs AnyType)> {
+  let summary = "extractelement operation";
+  let description = [{
+    Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
+    the proper position. Degenerates to an element type in the 0-D case.
+
+    Examples:
+    ```
+      %1 = vector.extractelement %0[3]: vector<4x8x16xf32>
+      %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector()->getType().cast<VectorType>();
+    }
+  }];
+}
+def OuterProductOp :
+  Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
+    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+    Results<(outs AnyVector)> {
+  let summary = "outerproduct operation";
+  let description = [{
+    Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
+
+    Example:
+    ```
+      %2 = vector.extractelement %0, %1: vector<4xf32>, vector<8xf32>
+      return %2: vector<4x8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getOperandVectorTypeLHS() {
+      return lhs()->getType().cast<VectorType>();
+    }
+    VectorType getOperandVectorTypeRHS() {
+      return rhs()->getType().cast<VectorType>();
+    }
+    VectorType getVectorType() {
+      return getResult()->getType().cast<VectorType>();
+    }
+  }];
+}
+#endif // VECTOR_OPS
diff --git a/third_party/mlir/lib/AffineOps/AffineOps.cpp b/third_party/mlir/lib/AffineOps/AffineOps.cpp
new file mode 100644
index 0000000..51a6ec2
--- /dev/null
+++ b/third_party/mlir/lib/AffineOps/AffineOps.cpp
@@ -0,0 +1,1760 @@
+//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/Debug.h"
+using namespace mlir;
+using llvm::dbgs;
+
+#define DEBUG_TYPE "affine-analysis"
+
+//===----------------------------------------------------------------------===//
+// AffineOpsDialect
+//===----------------------------------------------------------------------===//
+
+AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp,
+                AffineStoreOp,
+#define GET_OP_LIST
+#include "mlir/AffineOps/AffineOps.cpp.inc"
+                >();
+}
+
+/// A utility function to check if a given region is attached to a function.
+static bool isFunctionRegion(Region *region) {
+  return llvm::isa<FuncOp>(region->getParentOp());
+}
+
+/// A utility function to check if a value is defined at the top level of a
+/// function. A value defined at the top level is always a valid symbol.
+bool mlir::isTopLevelSymbol(Value *value) {
+  if (auto *arg = dyn_cast<BlockArgument>(value))
+    return isFunctionRegion(arg->getOwner()->getParent());
+  return isFunctionRegion(value->getDefiningOp()->getParentRegion());
+}
+
+// Value can be used as a dimension id if it is valid as a symbol, or
+// it is an induction variable, or it is a result of affine apply operation
+// with dimension id arguments.
+bool mlir::isValidDim(Value *value) {
+  // The value must be an index type.
+  if (!value->getType().isIndex())
+    return false;
+
+  if (auto *op = value->getDefiningOp()) {
+    // Top level operation or constant operation is ok.
+    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+      return true;
+    // Affine apply operation is ok if all of its operands are ok.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+      return applyOp.isValidDim();
+    // The dim op is okay if its operand memref/tensor is defined at the top
+    // level.
+    if (auto dimOp = dyn_cast<DimOp>(op))
+      return isTopLevelSymbol(dimOp.getOperand());
+    return false;
+  }
+  // This value is a block argument (which also includes 'affine.for' loop IVs).
+  return true;
+}
+
+// Value can be used as a symbol if it is a constant, or it is defined at
+// the top level, or it is a result of affine apply operation with symbol
+// arguments.
+bool mlir::isValidSymbol(Value *value) {
+  // The value must be an index type.
+  if (!value->getType().isIndex())
+    return false;
+
+  if (auto *op = value->getDefiningOp()) {
+    // Top level operation or constant operation is ok.
+    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+      return true;
+    // Affine apply operation is ok if all of its operands are ok.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+      return applyOp.isValidSymbol();
+    // The dim op is okay if its operand memref/tensor is defined at the top
+    // level.
+    if (auto dimOp = dyn_cast<DimOp>(op))
+      return isTopLevelSymbol(dimOp.getOperand());
+    return false;
+  }
+  // Otherwise, check that the value is a top level symbol.
+  return isTopLevelSymbol(value);
+}
+
+// Returns true if 'value' is a valid index to an affine operation (e.g.
+// affine.load, affine.store, affine.dma_start, affine.dma_wait).
+// Returns false otherwise.
+static bool isValidAffineIndexOperand(Value *value) {
+  return isValidDim(value) || isValidSymbol(value);
+}
+
+/// Utility function to verify that a set of operands are valid dimension and
+/// symbol identifiers. The operands should be layed out such that the dimension
+/// operands are before the symbol operands. This function returns failure if
+/// there was an invalid operand. An operation is provided to emit any necessary
+/// errors.
+template <typename OpTy>
+static LogicalResult
+verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
+                              unsigned numDims) {
+  unsigned opIt = 0;
+  for (auto *operand : operands) {
+    if (opIt++ < numDims) {
+      if (!isValidDim(operand))
+        return op.emitOpError("operand cannot be used as a dimension id");
+    } else if (!isValidSymbol(operand)) {
+      return op.emitOpError("operand cannot be used as a symbol");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::build(Builder *builder, OperationState *result,
+                          AffineMap map, ArrayRef<Value *> operands) {
+  result->addOperands(operands);
+  result->types.append(map.getNumResults(), builder->getIndexType());
+  result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto &builder = parser->getBuilder();
+  auto affineIntTy = builder.getIndexType();
+
+  AffineMapAttr mapAttr;
+  unsigned numDims;
+  if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+      parseDimAndSymbolList(parser, result->operands, numDims) ||
+      parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+  auto map = mapAttr.getValue();
+
+  if (map.getNumDims() != numDims ||
+      numDims + map.getNumSymbols() != result->operands.size()) {
+    return parser->emitError(parser->getNameLoc(),
+                             "dimension or symbol index mismatch");
+  }
+
+  result->types.append(map.getNumResults(), affineIntTy);
+  return success();
+}
+
+void AffineApplyOp::print(OpAsmPrinter *p) {
+  *p << "affine.apply " << getAttr("map");
+  printDimAndSymbolList(operand_begin(), operand_end(),
+                        getAffineMap().getNumDims(), p);
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
+}
+
+LogicalResult AffineApplyOp::verify() {
+  // Check that affine map attribute was specified.
+  auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
+  if (!affineMapAttr)
+    return emitOpError("requires an affine map");
+
+  // Check input and output dimensions match.
+  auto map = affineMapAttr.getValue();
+
+  // Verify that operand count matches affine map dimension and symbol count.
+  if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
+    return emitOpError(
+        "operand count and affine map dimension and symbol count must match");
+
+  // Verify that all operands are of `index` type.
+  for (Type t : getOperandTypes()) {
+    if (!t.isIndex())
+      return emitOpError("operands must be of type 'index'");
+  }
+
+  if (!getResult()->getType().isIndex())
+    return emitOpError("result must be of type 'index'");
+
+  // Verify that the operands are valid dimension and symbol identifiers.
+  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
+                                           map.getNumDims())))
+    return failure();
+
+  // Verify that the map only produces one result.
+  if (map.getNumResults() != 1)
+    return emitOpError("mapping must produce one value");
+
+  return success();
+}
+
+// The result of the affine apply operation can be used as a dimension id if it
+// is a CFG value or if it is an Value, and all the operands are valid
+// dimension ids.
+bool AffineApplyOp::isValidDim() {
+  return llvm::all_of(getOperands(),
+                      [](Value *op) { return mlir::isValidDim(op); });
+}
+
+// The result of the affine apply operation can be used as a symbol if it is
+// a CFG value or if it is an Value, and all the operands are symbols.
+bool AffineApplyOp::isValidSymbol() {
+  return llvm::all_of(getOperands(),
+                      [](Value *op) { return mlir::isValidSymbol(op); });
+}
+
+OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
+  auto map = getAffineMap();
+
+  // Fold dims and symbols to existing values.
+  auto expr = map.getResult(0);
+  if (auto dim = expr.dyn_cast<AffineDimExpr>())
+    return getOperand(dim.getPosition());
+  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+    return getOperand(map.getNumDims() + sym.getPosition());
+
+  // Otherwise, default to folding the map.
+  SmallVector<Attribute, 1> result;
+  if (failed(map.constantFold(operands, result)))
+    return {};
+  return result[0];
+}
+
+namespace {
+/// An `AffineApplyNormalizer` is a helper class that is not visible to the user
+/// and supports renumbering operands of AffineApplyOp. This acts as a
+/// reindexing map of Value* to positional dims or symbols and allows
+/// simplifications such as:
+///
+/// ```mlir
+///    %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0)
+/// ```
+///
+/// into:
+///
+/// ```mlir
+///    %1 = affine.apply () -> (0)
+/// ```
+struct AffineApplyNormalizer {
+  AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands);
+
+  /// Returns the AffineMap resulting from normalization.
+  AffineMap getAffineMap() { return affineMap; }
+
+  SmallVector<Value *, 8> getOperands() {
+    SmallVector<Value *, 8> res(reorderedDims);
+    res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
+    return res;
+  }
+
+private:
+  /// Helper function to insert `v` into the coordinate system of the current
+  /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
+  /// renumbered position.
+  AffineDimExpr renumberOneDim(Value *v);
+
+  /// Given an `other` normalizer, this rewrites `other.affineMap` in the
+  /// coordinate system of the current AffineApplyNormalizer.
+  /// Returns the rewritten AffineMap and updates the dims and symbols of
+  /// `this`.
+  AffineMap renumber(const AffineApplyNormalizer &other);
+
+  /// Maps of Value* to position in `affineMap`.
+  DenseMap<Value *, unsigned> dimValueToPosition;
+
+  /// Ordered dims and symbols matching positional dims and symbols in
+  /// `affineMap`.
+  SmallVector<Value *, 8> reorderedDims;
+  SmallVector<Value *, 8> concatenatedSymbols;
+
+  AffineMap affineMap;
+
+  /// Used with RAII to control the depth at which AffineApply are composed
+  /// recursively. Only accepts depth 1 for now to allow a behavior where a
+  /// newly composed AffineApplyOp does not increase the length of the chain of
+  /// AffineApplyOps. Full composition is implemented iteratively on top of
+  /// this behavior.
+  static unsigned &affineApplyDepth() {
+    static thread_local unsigned depth = 0;
+    return depth;
+  }
+  static constexpr unsigned kMaxAffineApplyDepth = 1;
+
+  AffineApplyNormalizer() { affineApplyDepth()++; }
+
+public:
+  ~AffineApplyNormalizer() { affineApplyDepth()--; }
+};
+} // end anonymous namespace.
+
+AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
+  DenseMap<Value *, unsigned>::iterator iterPos;
+  bool inserted = false;
+  std::tie(iterPos, inserted) =
+      dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
+  if (inserted) {
+    reorderedDims.push_back(v);
+  }
+  return getAffineDimExpr(iterPos->second, v->getContext())
+      .cast<AffineDimExpr>();
+}
+
+AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
+  SmallVector<AffineExpr, 8> dimRemapping;
+  for (auto *v : other.reorderedDims) {
+    auto kvp = other.dimValueToPosition.find(v);
+    if (dimRemapping.size() <= kvp->second)
+      dimRemapping.resize(kvp->second + 1);
+    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
+  }
+  unsigned numSymbols = concatenatedSymbols.size();
+  unsigned numOtherSymbols = other.concatenatedSymbols.size();
+  SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
+  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
+    symRemapping[idx] =
+        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
+  }
+  concatenatedSymbols.insert(concatenatedSymbols.end(),
+                             other.concatenatedSymbols.begin(),
+                             other.concatenatedSymbols.end());
+  auto map = other.affineMap;
+  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
+                                   dimRemapping.size(), symRemapping.size());
+}
+
+// Gather the positions of the operands that are produced by an AffineApplyOp.
+static llvm::SetVector<unsigned>
+indicesFromAffineApplyOp(ArrayRef<Value *> operands) {
+  llvm::SetVector<unsigned> res;
+  for (auto en : llvm::enumerate(operands))
+    if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
+      res.insert(en.index());
+  return res;
+}
+
+// Support the special case of a symbol coming from an AffineApplyOp that needs
+// to be composed into the current AffineApplyOp.
+// This case is handled by rewriting all such symbols into dims for the purpose
+// of allowing mathematical AffineMap composition.
+// Returns an AffineMap where symbols that come from an AffineApplyOp have been
+// rewritten as dims and are ordered after the original dims.
+// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
+// symbols are represented as dims. This loss is static but can still be
+// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
+// semi-affine map case. A dynamic canonicalization of all dims that are valid
+// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
+// results in better simplifications and foldings. But we should evaluate
+// whether this behavior is what we really want after using more.
+static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
+                                              ArrayRef<Value *> symbols) {
+  if (symbols.empty()) {
+    return map;
+  }
+
+  // Sanity check on symbols.
+  for (auto *sym : symbols) {
+    assert(isValidSymbol(sym) && "Expected only valid symbols");
+    (void)sym;
+  }
+
+  // Extract the symbol positions that come from an AffineApplyOp and
+  // needs to be rewritten as dims.
+  auto symPositions = indicesFromAffineApplyOp(symbols);
+  if (symPositions.empty()) {
+    return map;
+  }
+
+  // Create the new map by replacing each symbol at pos by the next new dim.
+  unsigned numDims = map.getNumDims();
+  unsigned numSymbols = map.getNumSymbols();
+  unsigned numNewDims = 0;
+  unsigned numNewSymbols = 0;
+  SmallVector<AffineExpr, 8> symReplacements(numSymbols);
+  for (unsigned i = 0; i < numSymbols; ++i) {
+    symReplacements[i] =
+        symPositions.count(i) > 0
+            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
+            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
+  }
+  assert(numSymbols >= numNewDims);
+  AffineMap newMap = map.replaceDimsAndSymbols(
+      {}, symReplacements, numDims + numNewDims, numNewSymbols);
+
+  return newMap;
+}
+
+/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
+/// keep a correspondence between the mathematical `map` and the `operands` of
+/// a given AffineApplyOp. This correspondence is maintained by iterating over
+/// the operands and forming an `auxiliaryMap` that can be composed
+/// mathematically with `map`. To keep this correspondence in cases where
+/// symbols are produced by affine.apply operations, we perform a local rewrite
+/// of symbols as dims.
+///
+/// Rationale for locally rewriting symbols as dims:
+/// ================================================
+/// The mathematical composition of AffineMap must always concatenate symbols
+/// because it does not have enough information to do otherwise. For example,
+/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
+/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
+///
+/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
+/// applied to the same mlir::Value* for both s0 and s1.
+/// As a consequence mathematical composition of AffineMap always concatenates
+/// symbols.
+///
+/// When AffineMaps are used in AffineApplyOp however, they may specify
+/// composition via symbols, which is ambiguous mathematically. This corner case
+/// is handled by locally rewriting such symbols that come from AffineApplyOp
+/// into dims and composing through dims.
+/// TODO(andydavis, ntv): Composition via symbols comes at a significant code
+/// complexity. Alternatively we should investigate whether we want to
+/// explicitly disallow symbols coming from affine.apply and instead force the
+/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
+/// extra API calls for such uses, which haven't popped up until now) and the
+/// benefit potentially big: simpler and more maintainable code for a
+/// non-trivial, recursive, procedure.
+AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
+                                             ArrayRef<Value *> operands)
+    : AffineApplyNormalizer() {
+  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
+  assert(map.getNumInputs() == operands.size() &&
+         "number of operands does not match the number of map inputs");
+
+  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));
+
+  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
+  // map to always refer to:
+  //   (dims, symbols coming from AffineApplyOp, other symbols).
+  // The order of operands can remain unchanged.
+  // This is a simplification that relies on 2 ordering properties:
+  //   1. rewritten symbols always appear after the original dims in the map;
+  //   2. operands are traversed in order and either dispatched to:
+  //      a. auxiliaryExprs (dims and symbols rewritten as dims);
+  //      b. concatenatedSymbols (all other symbols)
+  // This allows operand order to remain unchanged.
+  unsigned numDimsBeforeRewrite = map.getNumDims();
+  map = promoteComposedSymbolsAsDims(map,
+                                     operands.take_back(map.getNumSymbols()));
+
+  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));
+
+  SmallVector<AffineExpr, 8> auxiliaryExprs;
+  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
+  // We fully spell out the 2 cases below. In this particular instance a little
+  // code duplication greatly improves readability.
+  // Note that the first branch would disappear if we only supported full
+  // composition (i.e. infinite kMaxAffineApplyDepth).
+  if (!furtherCompose) {
+    // 1. Only dispatch dims or symbols.
+    for (auto en : llvm::enumerate(operands)) {
+      auto *t = en.value();
+      assert(t->getType().isIndex());
+      bool isDim = (en.index() < map.getNumDims());
+      if (isDim) {
+        // a. The mathematical composition of AffineMap composes dims.
+        auxiliaryExprs.push_back(renumberOneDim(t));
+      } else {
+        // b. The mathematical composition of AffineMap concatenates symbols.
+        //    We do the same for symbol operands.
+        concatenatedSymbols.push_back(t);
+      }
+    }
+  } else {
+    assert(numDimsBeforeRewrite <= operands.size());
+    // 2. Compose AffineApplyOps and dispatch dims or symbols.
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      auto *t = operands[i];
+      auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
+      if (affineApply) {
+        // a. Compose affine.apply operations.
+        LLVM_DEBUG(affineApply.getOperation()->print(
+            dbgs() << "\nCompose AffineApplyOp recursively: "));
+        AffineMap affineApplyMap = affineApply.getAffineMap();
+        SmallVector<Value *, 8> affineApplyOperands(
+            affineApply.getOperands().begin(), affineApply.getOperands().end());
+        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
+
+        LLVM_DEBUG(normalizer.affineMap.print(
+            dbgs() << "\nRenumber into current normalizer: "));
+
+        auto renumberedMap = renumber(normalizer);
+
+        LLVM_DEBUG(
+            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));
+
+        auxiliaryExprs.push_back(renumberedMap.getResult(0));
+      } else {
+        if (i < numDimsBeforeRewrite) {
+          // b. The mathematical composition of AffineMap composes dims.
+          auxiliaryExprs.push_back(renumberOneDim(t));
+        } else {
+          // c. The mathematical composition of AffineMap concatenates symbols.
+          //    We do the same for symbol operands.
+          concatenatedSymbols.push_back(t);
+        }
+      }
+    }
+  }
+
+  // Early exit if `map` is already composed.
+  if (auxiliaryExprs.empty()) {
+    affineMap = map;
+    return;
+  }
+
+  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
+         "Unexpected number of concatenated symbols");
+  auto numDims = dimValueToPosition.size();
+  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
+  auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);
+
+  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
+  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
+  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));
+
+  // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
+  // Another option is to cache the results as it is expected a lot of redundant
+  // work is performed in practice.
+  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));
+
+  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
+  LLVM_DEBUG(dbgs() << "\n");
+}
+
+/// Implements `map` and `operands` composition and simplification to support
+/// `makeComposedAffineApply`. This can be called to achieve the same effects
+/// on `map` and `operands` without creating an AffineApplyOp that needs to be
+/// immediately deleted.
+static void composeAffineMapAndOperands(AffineMap *map,
+                                        SmallVectorImpl<Value *> *operands) {
+  AffineApplyNormalizer normalizer(*map, *operands);
+  auto normalizedMap = normalizer.getAffineMap();
+  auto normalizedOperands = normalizer.getOperands();
+  canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
+  *map = normalizedMap;
+  *operands = normalizedOperands;
+  assert(*map);
+}
+
+void mlir::fullyComposeAffineMapAndOperands(
+    AffineMap *map, SmallVectorImpl<Value *> *operands) {
+  while (llvm::any_of(*operands, [](Value *v) {
+    return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
+  })) {
+    composeAffineMapAndOperands(map, operands);
+  }
+}
+
+AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
+                                            AffineMap map,
+                                            ArrayRef<Value *> operands) {
+  AffineMap normalizedMap = map;
+  SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end());
+  composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
+  assert(normalizedMap);
+  return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
+}
+
+// A symbol may appear as a dim in affine.apply operations. This function
+// canonicalizes dims that are valid symbols into actual symbols.
+static void
+canonicalizePromotedSymbols(AffineMap *map,
+                            llvm::SmallVectorImpl<Value *> *operands) {
+  if (!map || operands->empty())
+    return;
+
+  assert(map->getNumInputs() == operands->size() &&
+         "map inputs must match number of operands");
+
+  auto *context = map->getContext();
+  SmallVector<Value *, 8> resultOperands;
+  resultOperands.reserve(operands->size());
+  SmallVector<Value *, 8> remappedSymbols;
+  remappedSymbols.reserve(operands->size());
+  unsigned nextDim = 0;
+  unsigned nextSym = 0;
+  unsigned oldNumSyms = map->getNumSymbols();
+  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+  for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
+    if (i < map->getNumDims()) {
+      if (isValidSymbol((*operands)[i])) {
+        // This is a valid symbols that appears as a dim, canonicalize it.
+        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
+        remappedSymbols.push_back((*operands)[i]);
+      } else {
+        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
+        resultOperands.push_back((*operands)[i]);
+      }
+    } else {
+      resultOperands.push_back((*operands)[i]);
+    }
+  }
+
+  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
+  *operands = resultOperands;
+  *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
+                                    oldNumSyms + nextSym);
+
+  assert(map->getNumInputs() == operands->size() &&
+         "map inputs must match number of operands");
+}
+
+void mlir::canonicalizeMapAndOperands(
+    AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
+  if (!map || operands->empty())
+    return;
+
+  assert(map->getNumInputs() == operands->size() &&
+         "map inputs must match number of operands");
+
+  canonicalizePromotedSymbols(map, operands);
+
+  // Check to see what dims are used.
+  llvm::SmallBitVector usedDims(map->getNumDims());
+  llvm::SmallBitVector usedSyms(map->getNumSymbols());
+  map->walkExprs([&](AffineExpr expr) {
+    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+      usedDims[dimExpr.getPosition()] = true;
+    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+      usedSyms[symExpr.getPosition()] = true;
+  });
+
+  auto *context = map->getContext();
+
+  SmallVector<Value *, 8> resultOperands;
+  resultOperands.reserve(operands->size());
+
+  llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
+  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+  unsigned nextDim = 0;
+  for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
+    if (usedDims[i]) {
+      auto it = seenDims.find((*operands)[i]);
+      if (it == seenDims.end()) {
+        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
+        resultOperands.push_back((*operands)[i]);
+        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
+      } else {
+        dimRemapping[i] = it->second;
+      }
+    }
+  }
+  llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
+  SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
+  unsigned nextSym = 0;
+  for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
+    if (usedSyms[i]) {
+      auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
+      if (it == seenSymbols.end()) {
+        symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
+        resultOperands.push_back((*operands)[i + map->getNumDims()]);
+        seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
+                                          symRemapping[i]));
+      } else {
+        symRemapping[i] = it->second;
+      }
+    }
+  }
+  *map =
+      map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
+  *operands = resultOperands;
+}
+
+namespace {
+/// Simplify AffineApply operations.
+///
+struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
+  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineApplyOp apply,
+                                     PatternRewriter &rewriter) const override {
+    auto map = apply.getAffineMap();
+
+    AffineMap oldMap = map;
+    SmallVector<Value *, 8> resultOperands(apply.getOperands());
+    composeAffineMapAndOperands(&map, &resultOperands);
+    if (map == oldMap)
+      return matchFailure();
+
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace.
+
+void AffineApplyOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyAffineApply>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a common class used for patterns of the form
+/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
+/// into the root operation directly.
+struct MemRefCastFolder : public RewritePattern {
+  /// The rootOpName is the name of the root operation to match against.
+  MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
+      : RewritePattern(rootOpName, 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    for (auto *operand : op->getOperands())
+      if (matchPattern(operand, m_Op<MemRefCastOp>()))
+        return matchSuccess();
+
+    return matchFailure();
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+      if (auto *memref = op->getOperand(i)->getDefiningOp())
+        if (auto cast = dyn_cast<MemRefCastOp>(memref))
+          op->setOperand(i, cast.getOperand());
+    rewriter.updatedRootInPlace(op);
+  }
+};
+
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// AffineDmaStartOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133776335) Check that map operands are loop IVs or symbols.
+void AffineDmaStartOp::build(Builder *builder, OperationState *result,
+                             Value *srcMemRef, AffineMap srcMap,
+                             ArrayRef<Value *> srcIndices, Value *destMemRef,
+                             AffineMap dstMap, ArrayRef<Value *> destIndices,
+                             Value *tagMemRef, AffineMap tagMap,
+                             ArrayRef<Value *> tagIndices, Value *numElements,
+                             Value *stride, Value *elementsPerStride) {
+  result->addOperands(srcMemRef);
+  result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
+  result->addOperands(srcIndices);
+  result->addOperands(destMemRef);
+  result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
+  result->addOperands(destIndices);
+  result->addOperands(tagMemRef);
+  result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
+  result->addOperands(tagIndices);
+  result->addOperands(numElements);
+  if (stride) {
+    result->addOperands({stride, elementsPerStride});
+  }
+}
+
+void AffineDmaStartOp::print(OpAsmPrinter *p) {
+  *p << "affine.dma_start " << *getSrcMemRef() << '[';
+  SmallVector<Value *, 8> operands(getSrcIndices());
+  p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
+  *p << "], " << *getDstMemRef() << '[';
+  operands.assign(getDstIndices().begin(), getDstIndices().end());
+  p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
+  *p << "], " << *getTagMemRef() << '[';
+  operands.assign(getTagIndices().begin(), getTagIndices().end());
+  p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
+  *p << "], " << *getNumElements();
+  if (isStrided()) {
+    *p << ", " << *getStride();
+    *p << ", " << *getNumElementsPerStride();
+  }
+  *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
+     << getTagMemRefType();
+}
+
+// Parse AffineDmaStartOp.
+// Ex:
+//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
+//     %stride, %num_elt_per_stride
+//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
+//
+ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
+                                    OperationState *result) {
+  OpAsmParser::OperandType srcMemRefInfo;
+  AffineMapAttr srcMapAttr;
+  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
+  OpAsmParser::OperandType dstMemRefInfo;
+  AffineMapAttr dstMapAttr;
+  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
+  OpAsmParser::OperandType tagMemRefInfo;
+  AffineMapAttr tagMapAttr;
+  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
+  OpAsmParser::OperandType numElementsInfo;
+  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
+
+  SmallVector<Type, 3> types;
+  auto indexType = parser->getBuilder().getIndexType();
+
+  // Parse and resolve the following list of operands:
+  // *) dst memref followed by its affine maps operands (in square brackets).
+  // *) src memref followed by its affine map operands (in square brackets).
+  // *) tag memref followed by its affine map operands (in square brackets).
+  // *) number of elements transferred by DMA operation.
+  if (parser->parseOperand(srcMemRefInfo) ||
+      parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
+                                     getSrcMapAttrName(), result->attributes) ||
+      parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
+      parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
+                                     getDstMapAttrName(), result->attributes) ||
+      parser->parseComma() || parser->parseOperand(tagMemRefInfo) ||
+      parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
+                                     getTagMapAttrName(), result->attributes) ||
+      parser->parseComma() || parser->parseOperand(numElementsInfo))
+    return failure();
+
+  // Parse optional stride and elements per stride.
+  if (parser->parseTrailingOperandList(strideInfo)) {
+    return failure();
+  }
+  if (!strideInfo.empty() && strideInfo.size() != 2) {
+    return parser->emitError(parser->getNameLoc(),
+                             "expected two stride related operands");
+  }
+  bool isStrided = strideInfo.size() == 2;
+
+  if (parser->parseColonTypeList(types))
+    return failure();
+
+  if (types.size() != 3)
+    return parser->emitError(parser->getNameLoc(), "expected three types");
+
+  if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
+      parser->resolveOperands(srcMapOperands, indexType, result->operands) ||
+      parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
+      parser->resolveOperands(dstMapOperands, indexType, result->operands) ||
+      parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
+      parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
+      parser->resolveOperand(numElementsInfo, indexType, result->operands))
+    return failure();
+
+  if (isStrided) {
+    if (parser->resolveOperands(strideInfo, indexType, result->operands))
+      return failure();
+  }
+
+  // Check that src/dst/tag operand counts match their map.numInputs.
+  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
+      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
+      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
+    return parser->emitError(parser->getNameLoc(),
+                             "memref operand count not equal to map.numInputs");
+  return success();
+}
+
+LogicalResult AffineDmaStartOp::verify() {
+  if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
+    return emitOpError("expected DMA source to be of memref type");
+  if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
+    return emitOpError("expected DMA destination to be of memref type");
+  if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
+    return emitOpError("expected DMA tag to be of memref type");
+
+  // DMAs from different memory spaces supported.
+  if (getSrcMemorySpace() == getDstMemorySpace()) {
+    return emitOpError("DMA should be between different memory spaces");
+  }
+  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
+                              getDstMap().getNumInputs() +
+                              getTagMap().getNumInputs();
+  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
+      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
+    return emitOpError("incorrect number of operands");
+  }
+
+  for (auto *idx : getSrcIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("src index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("src index must be a dimension or symbol identifier");
+  }
+  for (auto *idx : getDstIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("dst index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("dst index must be a dimension or symbol identifier");
+  }
+  for (auto *idx : getTagIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("tag index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("tag index must be a dimension or symbol identifier");
+  }
+  return success();
+}
+
+void AffineDmaStartOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// dma_start(memrefcast) -> dma_start
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineDmaWaitOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133776335) Check that map operands are loop IVs or symbols.
+void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
+                            Value *tagMemRef, AffineMap tagMap,
+                            ArrayRef<Value *> tagIndices, Value *numElements) {
+  result->addOperands(tagMemRef);
+  result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
+  result->addOperands(tagIndices);
+  result->addOperands(numElements);
+}
+
+void AffineDmaWaitOp::print(OpAsmPrinter *p) {
+  *p << "affine.dma_wait " << *getTagMemRef() << '[';
+  SmallVector<Value *, 2> operands(getTagIndices());
+  p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
+  *p << "], ";
+  p->printOperand(getNumElements());
+  *p << " : " << getTagMemRef()->getType();
+}
+
+// Parse AffineDmaWaitOp.
+// Eg:
+//   affine.dma_wait %tag[%index], %num_elements
+//     : memref<1 x i32, (d0) -> (d0), 4>
+//
+ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
+                                   OperationState *result) {
+  OpAsmParser::OperandType tagMemRefInfo;
+  AffineMapAttr tagMapAttr;
+  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
+  Type type;
+  auto indexType = parser->getBuilder().getIndexType();
+  OpAsmParser::OperandType numElementsInfo;
+
+  // Parse tag memref, its map operands, and dma size.
+  if (parser->parseOperand(tagMemRefInfo) ||
+      parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
+                                     getTagMapAttrName(), result->attributes) ||
+      parser->parseComma() || parser->parseOperand(numElementsInfo) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
+      parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
+      parser->resolveOperand(numElementsInfo, indexType, result->operands))
+    return failure();
+
+  if (!type.isa<MemRefType>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected tag to be of memref type");
+
+  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
+    return parser->emitError(parser->getNameLoc(),
+                             "tag memref operand count != to map.numInputs");
+  return success();
+}
+
+LogicalResult AffineDmaWaitOp::verify() {
+  if (!getOperand(0)->getType().isa<MemRefType>())
+    return emitOpError("expected DMA tag to be of memref type");
+  for (auto *idx : getTagIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("index to dma_wait must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
+  return success();
+}
+
+void AffineDmaWaitOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// dma_wait(memrefcast) -> dma_wait
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineForOp
+//===----------------------------------------------------------------------===//
+
+void AffineForOp::build(Builder *builder, OperationState *result,
+                        ArrayRef<Value *> lbOperands, AffineMap lbMap,
+                        ArrayRef<Value *> ubOperands, AffineMap ubMap,
+                        int64_t step) {
+  assert(((!lbMap && lbOperands.empty()) ||
+          lbOperands.size() == lbMap.getNumInputs()) &&
+         "lower bound operand count does not match the affine map");
+  assert(((!ubMap && ubOperands.empty()) ||
+          ubOperands.size() == ubMap.getNumInputs()) &&
+         "upper bound operand count does not match the affine map");
+  assert(step > 0 && "step has to be a positive integer constant");
+
+  // Add an attribute for the step.
+  result->addAttribute(getStepAttrName(),
+                       builder->getIntegerAttr(builder->getIndexType(), step));
+
+  // Add the lower bound.
+  result->addAttribute(getLowerBoundAttrName(),
+                       builder->getAffineMapAttr(lbMap));
+  result->addOperands(lbOperands);
+
+  // Add the upper bound.
+  result->addAttribute(getUpperBoundAttrName(),
+                       builder->getAffineMapAttr(ubMap));
+  result->addOperands(ubOperands);
+
+  // Create a region and a block for the body.  The argument of the region is
+  // the loop induction variable.
+  Region *bodyRegion = result->addRegion();
+  Block *body = new Block();
+  body->addArgument(IndexType::get(builder->getContext()));
+  bodyRegion->push_back(body);
+  ensureTerminator(*bodyRegion, *builder, result->location);
+
+  // Set the operands list as resizable so that we can freely modify the bounds.
+  result->setOperandListToResizable();
+}
+
+void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
+                        int64_t ub, int64_t step) {
+  auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
+  auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
+  return build(builder, result, {}, lbMap, {}, ubMap, step);
+}
+
+static LogicalResult verify(AffineForOp op) {
+  // Check that the body defines as single block argument for the induction
+  // variable.
+  auto *body = op.getBody();
+  if (body->getNumArguments() != 1 ||
+      !body->getArgument(0)->getType().isIndex())
+    return op.emitOpError(
+        "expected body to have a single index argument for the "
+        "induction variable");
+
+  // Verify that there are enough operands for the bounds.
+  AffineMap lowerBoundMap = op.getLowerBoundMap(),
+            upperBoundMap = op.getUpperBoundMap();
+  if (op.getNumOperands() !=
+      (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
+    return op.emitOpError(
+        "operand count must match with affine map dimension and symbol count");
+
+  // Verify that the bound operands are valid dimension/symbols.
+  /// Lower bound.
+  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
+                                           op.getLowerBoundMap().getNumDims())))
+    return failure();
+  /// Upper bound.
+  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
+                                           op.getUpperBoundMap().getNumDims())))
+    return failure();
+  return success();
+}
+
+/// Parse a for operation loop bounds.
+static ParseResult parseBound(bool isLower, OperationState *result,
+                              OpAsmParser *p) {
+  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
+  // the map has multiple results.
+  bool failedToParsedMinMax =
+      failed(p->parseOptionalKeyword(isLower ? "max" : "min"));
+
+  auto &builder = p->getBuilder();
+  auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
+                               : AffineForOp::getUpperBoundAttrName();
+
+  // Parse ssa-id as identity map.
+  SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
+  if (p->parseOperandList(boundOpInfos))
+    return failure();
+
+  if (!boundOpInfos.empty()) {
+    // Check that only one operand was parsed.
+    if (boundOpInfos.size() > 1)
+      return p->emitError(p->getNameLoc(),
+                          "expected only one loop bound operand");
+
+    // TODO: improve error message when SSA value is not an affine integer.
+    // Currently it is 'use of value ... expects different type than prior uses'
+    if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
+                          result->operands))
+      return failure();
+
+    // Create an identity map using symbol id. This representation is optimized
+    // for storage. Analysis passes may expand it into a multi-dimensional map
+    // if desired.
+    AffineMap map = builder.getSymbolIdentityMap();
+    result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
+    return success();
+  }
+
+  // Get the attribute location.
+  llvm::SMLoc attrLoc = p->getCurrentLocation();
+
+  Attribute boundAttr;
+  if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
+                        result->attributes))
+    return failure();
+
+  // Parse full form - affine map followed by dim and symbol list.
+  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+    unsigned currentNumOperands = result->operands.size();
+    unsigned numDims;
+    if (parseDimAndSymbolList(p, result->operands, numDims))
+      return failure();
+
+    auto map = affineMapAttr.getValue();
+    if (map.getNumDims() != numDims)
+      return p->emitError(
+          p->getNameLoc(),
+          "dim operand count and integer set dim count must match");
+
+    unsigned numDimAndSymbolOperands =
+        result->operands.size() - currentNumOperands;
+    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+      return p->emitError(
+          p->getNameLoc(),
+          "symbol operand count and integer set symbol count must match");
+
+    // If the map has multiple results, make sure that we parsed the min/max
+    // prefix.
+    if (map.getNumResults() > 1 && failedToParsedMinMax) {
+      if (isLower) {
+        return p->emitError(attrLoc, "lower loop bound affine map with "
+                                     "multiple results requires 'max' prefix");
+      }
+      return p->emitError(attrLoc, "upper loop bound affine map with multiple "
+                                   "results requires 'min' prefix");
+    }
+    return success();
+  }
+
+  // Parse custom assembly form.
+  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+    result->attributes.pop_back();
+    result->addAttribute(
+        boundAttrName, builder.getAffineMapAttr(
+                           builder.getConstantAffineMap(integerAttr.getInt())));
+    return success();
+  }
+
+  return p->emitError(
+      p->getNameLoc(),
+      "expected valid affine map representation for loop bounds");
+}
+
+ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) {
+  auto &builder = parser->getBuilder();
+  OpAsmParser::OperandType inductionVariable;
+  // Parse the induction variable followed by '='.
+  if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
+    return failure();
+
+  // Parse loop bounds.
+  if (parseBound(/*isLower=*/true, result, parser) ||
+      parser->parseKeyword("to", " between bounds") ||
+      parseBound(/*isLower=*/false, result, parser))
+    return failure();
+
+  // Parse the optional loop step, we default to 1 if one is not present.
+  if (parser->parseOptionalKeyword("step")) {
+    result->addAttribute(
+        AffineForOp::getStepAttrName(),
+        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
+  } else {
+    llvm::SMLoc stepLoc = parser->getCurrentLocation();
+    IntegerAttr stepAttr;
+    if (parser->parseAttribute(stepAttr, builder.getIndexType(),
+                               AffineForOp::getStepAttrName().data(),
+                               result->attributes))
+      return failure();
+
+    if (stepAttr.getValue().getSExtValue() < 0)
+      return parser->emitError(
+          stepLoc,
+          "expected step to be representable as a positive signed integer");
+  }
+
+  // Parse the body region.
+  Region *body = result->addRegion();
+  if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
+    return failure();
+
+  AffineForOp::ensureTerminator(*body, builder, result->location);
+
+  // Parse the optional attribute list.
+  if (parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  // Set the operands list as resizable so that we can freely modify the bounds.
+  result->setOperandListToResizable();
+  return success();
+}
+
+static void printBound(AffineMapAttr boundMap,
+                       Operation::operand_range boundOperands,
+                       const char *prefix, OpAsmPrinter *p) {
+  AffineMap map = boundMap.getValue();
+
+  // Check if this bound should be printed using custom assembly form.
+  // The decision to restrict printing custom assembly form to trivial cases
+  // comes from the will to roundtrip MLIR binary -> text -> binary in a
+  // lossless way.
+  // Therefore, custom assembly form parsing and printing is only supported for
+  // zero-operand constant maps and single symbol operand identity maps.
+  if (map.getNumResults() == 1) {
+    AffineExpr expr = map.getResult(0);
+
+    // Print constant bound.
+    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
+      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+        *p << constExpr.getValue();
+        return;
+      }
+    }
+
+    // Print bound that consists of a single SSA symbol if the map is over a
+    // single symbol.
+    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
+      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+        p->printOperand(*boundOperands.begin());
+        return;
+      }
+    }
+  } else {
+    // Map has multiple results. Print 'min' or 'max' prefix.
+    *p << prefix << ' ';
+  }
+
+  // Print the map and its operands.
+  *p << boundMap;
+  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
+                        map.getNumDims(), p);
+}
+
+void print(OpAsmPrinter *p, AffineForOp op) {
+  *p << "affine.for ";
+  p->printOperand(op.getBody()->getArgument(0));
+  *p << " = ";
+  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
+  *p << " to ";
+  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
+
+  if (op.getStep() != 1)
+    *p << " step " << op.getStep();
+  p->printRegion(op.region(),
+                 /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/{op.getLowerBoundAttrName(),
+                                            op.getUpperBoundAttrName(),
+                                            op.getStepAttrName()});
+}
+
+namespace {
+/// This is a pattern to fold constant loop bounds.
+struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
+  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineForOp forOp,
+                                     PatternRewriter &rewriter) const override {
+    auto foldLowerOrUpperBound = [&forOp](bool lower) {
+      // Check to see if each of the operands is the result of a constant.  If
+      // so, get the value.  If not, ignore it.
+      SmallVector<Attribute, 8> operandConstants;
+      auto boundOperands =
+          lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
+      for (auto *operand : boundOperands) {
+        Attribute operandCst;
+        matchPattern(operand, m_Constant(&operandCst));
+        operandConstants.push_back(operandCst);
+      }
+
+      AffineMap boundMap =
+          lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
+      assert(boundMap.getNumResults() >= 1 &&
+             "bound maps should have at least one result");
+      SmallVector<Attribute, 4> foldedResults;
+      if (failed(boundMap.constantFold(operandConstants, foldedResults)))
+        return failure();
+
+      // Compute the max or min as applicable over the results.
+      assert(!foldedResults.empty() &&
+             "bounds should have at least one result");
+      auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+      for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
+        auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+        maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
+                         : llvm::APIntOps::smin(maxOrMin, foldedResult);
+      }
+      lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
+            : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
+      return success();
+    };
+
+    // Try to fold the lower bound.
+    bool folded = false;
+    if (!forOp.hasConstantLowerBound())
+      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
+
+    // Try to fold the upper bound.
+    if (!forOp.hasConstantUpperBound())
+      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
+
+    // If any of the bounds were folded we return success.
+    if (!folded)
+      return matchFailure();
+    rewriter.updatedRootInPlace(forOp);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<AffineForLoopBoundFolder>(context);
+}
+
+AffineBound AffineForOp::getLowerBound() {
+  auto lbMap = getLowerBoundMap();
+  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
+}
+
+AffineBound AffineForOp::getUpperBound() {
+  auto lbMap = getLowerBoundMap();
+  auto ubMap = getUpperBoundMap();
+  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
+                     ubMap);
+}
+
+void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+  assert(lbOperands.size() == map.getNumInputs());
+  assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+  SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
+
+  auto ubOperands = getUpperBoundOperands();
+  newOperands.append(ubOperands.begin(), ubOperands.end());
+  getOperation()->setOperands(newOperands);
+
+  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+  assert(ubOperands.size() == map.getNumInputs());
+  assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+  SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
+  newOperands.append(ubOperands.begin(), ubOperands.end());
+  getOperation()->setOperands(newOperands);
+
+  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setLowerBoundMap(AffineMap map) {
+  auto lbMap = getLowerBoundMap();
+  assert(lbMap.getNumDims() == map.getNumDims() &&
+         lbMap.getNumSymbols() == map.getNumSymbols());
+  assert(map.getNumResults() >= 1 && "bound map has at least one result");
+  (void)lbMap;
+  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBoundMap(AffineMap map) {
+  auto ubMap = getUpperBoundMap();
+  assert(ubMap.getNumDims() == map.getNumDims() &&
+         ubMap.getNumSymbols() == map.getNumSymbols());
+  assert(map.getNumResults() >= 1 && "bound map has at least one result");
+  (void)ubMap;
+  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
+}
+
+bool AffineForOp::hasConstantLowerBound() {
+  return getLowerBoundMap().isSingleConstant();
+}
+
+bool AffineForOp::hasConstantUpperBound() {
+  return getUpperBoundMap().isSingleConstant();
+}
+
+int64_t AffineForOp::getConstantLowerBound() {
+  return getLowerBoundMap().getSingleConstantResult();
+}
+
+int64_t AffineForOp::getConstantUpperBound() {
+  return getUpperBoundMap().getSingleConstantResult();
+}
+
+void AffineForOp::setConstantLowerBound(int64_t value) {
+  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
+}
+
+void AffineForOp::setConstantUpperBound(int64_t value) {
+  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
+}
+
+AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
+  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
+  return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+bool AffineForOp::matchingBoundOperandList() {
+  auto lbMap = getLowerBoundMap();
+  auto ubMap = getUpperBoundMap();
+  if (lbMap.getNumDims() != ubMap.getNumDims() ||
+      lbMap.getNumSymbols() != ubMap.getNumSymbols())
+    return false;
+
+  unsigned numOperands = lbMap.getNumInputs();
+  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
+    // Compare Value *'s.
+    if (getOperand(i) != getOperand(numOperands + i))
+      return false;
+  }
+  return true;
+}
+
+/// Returns if the provided value is the induction variable of a AffineForOp.
+bool mlir::isForInductionVar(Value *val) {
+  return getForInductionVarOwner(val) != AffineForOp();
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+AffineForOp mlir::getForInductionVarOwner(Value *val) {
+  auto *ivArg = dyn_cast<BlockArgument>(val);
+  if (!ivArg || !ivArg->getOwner())
+    return AffineForOp();
+  auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
+  return dyn_cast<AffineForOp>(containingInst);
+}
+
+/// Extracts the induction variables from a list of AffineForOps and returns
+/// them.
+void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
+                                   SmallVectorImpl<Value *> *ivs) {
+  ivs->reserve(forInsts.size());
+  for (auto forInst : forInsts)
+    ivs->push_back(forInst.getInductionVar());
+}
+
+//===----------------------------------------------------------------------===//
+// AffineIfOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AffineIfOp op) {
+  // Verify that we have a condition attribute.
+  auto conditionAttr =
+      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
+  if (!conditionAttr)
+    return op.emitOpError(
+        "requires an integer set attribute named 'condition'");
+
+  // Verify that there are enough operands for the condition.
+  IntegerSet condition = conditionAttr.getValue();
+  if (op.getNumOperands() != condition.getNumOperands())
+    return op.emitOpError(
+        "operand count and condition integer set dimension and "
+        "symbol count must match");
+
+  // Verify that the operands are valid dimension/symbols.
+  if (failed(verifyDimAndSymbolIdentifiers(
+          op, op.getOperation()->getNonSuccessorOperands(),
+          condition.getNumDims())))
+    return failure();
+
+  // Verify that the entry of each child region does not have arguments.
+  for (auto &region : op.getOperation()->getRegions()) {
+    for (auto &b : region)
+      if (b.getNumArguments() != 0)
+        return op.emitOpError(
+            "requires that child entry blocks have no arguments");
+  }
+  return success();
+}
+
+ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) {
+  // Parse the condition attribute set.
+  IntegerSetAttr conditionAttr;
+  unsigned numDims;
+  if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
+                             result->attributes) ||
+      parseDimAndSymbolList(parser, result->operands, numDims))
+    return failure();
+
+  // Verify the condition operands.
+  auto set = conditionAttr.getValue();
+  if (set.getNumDims() != numDims)
+    return parser->emitError(
+        parser->getNameLoc(),
+        "dim operand count and integer set dim count must match");
+  if (numDims + set.getNumSymbols() != result->operands.size())
+    return parser->emitError(
+        parser->getNameLoc(),
+        "symbol operand count and integer set symbol count must match");
+
+  // Create the regions for 'then' and 'else'.  The latter must be created even
+  // if it remains empty for the validity of the operation.
+  result->regions.reserve(2);
+  Region *thenRegion = result->addRegion();
+  Region *elseRegion = result->addRegion();
+
+  // Parse the 'then' region.
+  if (parser->parseRegion(*thenRegion, {}, {}))
+    return failure();
+  AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(),
+                               result->location);
+
+  // If we find an 'else' keyword then parse the 'else' region.
+  if (!parser->parseOptionalKeyword("else")) {
+    if (parser->parseRegion(*elseRegion, {}, {}))
+      return failure();
+    AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(),
+                                 result->location);
+  }
+
+  // Parse the optional attribute list.
+  if (parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  return success();
+}
+
+void print(OpAsmPrinter *p, AffineIfOp op) {
+  auto conditionAttr =
+      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
+  *p << "affine.if " << conditionAttr;
+  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
+                        conditionAttr.getValue().getNumDims(), p);
+  p->printRegion(op.thenRegion(),
+                 /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+
+  // Print the 'else' regions if it has any blocks.
+  auto &elseRegion = op.elseRegion();
+  if (!elseRegion.empty()) {
+    *p << " else";
+    p->printRegion(elseRegion,
+                   /*printEntryBlockArgs=*/false,
+                   /*printBlockTerminators=*/false);
+  }
+
+  // Print the attribute list.
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/op.getConditionAttrName());
+}
+
+IntegerSet AffineIfOp::getIntegerSet() {
+  return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
+}
+void AffineIfOp::setIntegerSet(IntegerSet newSet) {
+  setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
+}
+
+//===----------------------------------------------------------------------===//
+// AffineLoadOp
+//===----------------------------------------------------------------------===//
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+                         AffineMap map, ArrayRef<Value *> operands) {
+  result->addOperands(operands);
+  if (map)
+    result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+  auto memrefType = operands[0]->getType().cast<MemRefType>();
+  result->types.push_back(memrefType.getElementType());
+}
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+                         Value *memref, ArrayRef<Value *> indices) {
+  result->addOperands(memref);
+  result->addOperands(indices);
+  auto memrefType = memref->getType().cast<MemRefType>();
+  auto rank = memrefType.getRank();
+  // Create identity map for memrefs with at least one dimension or () -> ()
+  // for zero-dimensional memrefs.
+  auto map = rank ? builder->getMultiDimIdentityMap(rank)
+                  : builder->getEmptyAffineMap();
+  result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+  result->types.push_back(memrefType.getElementType());
+}
+
+ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto &builder = parser->getBuilder();
+  auto affineIntTy = builder.getIndexType();
+
+  MemRefType type;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser->parseOperand(memrefInfo) ||
+      parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
+                                     result->attributes) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(memrefInfo, type, result->operands) ||
+      parser->resolveOperands(mapOperands, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+void AffineLoadOp::print(OpAsmPrinter *p) {
+  *p << "affine.load " << *getMemRef() << '[';
+  AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+  if (mapAttr) {
+    SmallVector<Value *, 2> operands(getIndices());
+    p->printAffineMapOfSSAIds(mapAttr, operands);
+  }
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
+  *p << " : " << getMemRefType();
+}
+
+LogicalResult AffineLoadOp::verify() {
+  if (getType() != getMemRefType().getElementType())
+    return emitOpError("result type must match element type of memref");
+
+  auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+  if (mapAttr) {
+    AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue();
+    if (map.getNumResults() != getMemRefType().getRank())
+      return emitOpError("affine.load affine map num results must equal"
+                         " memref rank");
+    if (map.getNumInputs() != getNumOperands() - 1)
+      return emitOpError("expects as many subscripts as affine map inputs");
+  } else {
+    if (getMemRefType().getRank() != getNumOperands() - 1)
+      return emitOpError(
+          "expects the number of subscripts to be equal to memref rank");
+  }
+
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
+  return success();
+}
+
+void AffineLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// load(memrefcast) -> load
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineStoreOp
+//===----------------------------------------------------------------------===//
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+                          Value *valueToStore, AffineMap map,
+                          ArrayRef<Value *> operands) {
+  result->addOperands(valueToStore);
+  result->addOperands(operands);
+  if (map)
+    result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+}
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+                          Value *valueToStore, Value *memref,
+                          ArrayRef<Value *> operands) {
+  result->addOperands(valueToStore);
+  result->addOperands(memref);
+  result->addOperands(operands);
+  auto memrefType = memref->getType().cast<MemRefType>();
+  auto map = builder->getMultiDimIdentityMap(memrefType.getRank());
+  result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto affineIntTy = parser->getBuilder().getIndexType();
+
+  MemRefType type;
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(memrefInfo) ||
+      parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
+                                     result->attributes) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(storeValueInfo, type.getElementType(),
+                             result->operands) ||
+      parser->resolveOperand(memrefInfo, type, result->operands) ||
+      parser->resolveOperands(mapOperands, affineIntTy, result->operands));
+}
+
+void AffineStoreOp::print(OpAsmPrinter *p) {
+  *p << "affine.store " << *getValueToStore();
+  *p << ", " << *getMemRef() << '[';
+  AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+  if (mapAttr) {
+    SmallVector<Value *, 2> operands(getIndices());
+    p->printAffineMapOfSSAIds(mapAttr, operands);
+  }
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
+  *p << " : " << getMemRefType();
+}
+
+LogicalResult AffineStoreOp::verify() {
+  // First operand must have same type as memref element type.
+  if (getValueToStore()->getType() != getMemRefType().getElementType())
+    return emitOpError("first operand must have same type memref element type");
+
+  auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+  if (mapAttr) {
+    AffineMap map = mapAttr.getValue();
+    if (map.getNumResults() != getMemRefType().getRank())
+      return emitOpError("affine.store affine map num results must equal"
+                         " memref rank");
+    if (map.getNumInputs() != getNumOperands() - 2)
+      return emitOpError("expects as many subscripts as affine map inputs");
+  } else {
+    if (getMemRefType().getRank() != getNumOperands() - 2)
+      return emitOpError(
+          "expects the number of subscripts to be equal to memref rank");
+  }
+
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("index to store must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
+  return success();
+}
+
+void AffineStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// load(memrefcast) -> load
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+#define GET_OP_CLASSES
+#include "mlir/AffineOps/AffineOps.cpp.inc"
diff --git a/third_party/mlir/lib/AffineOps/CMakeLists.txt b/third_party/mlir/lib/AffineOps/CMakeLists.txt
new file mode 100644
index 0000000..a8cf24e
--- /dev/null
+++ b/third_party/mlir/lib/AffineOps/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRAffineOps
+  AffineOps.cpp
+  DialectRegistration.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/AffineOps
+  )
+add_dependencies(MLIRAffineOps MLIRAffineOpsIncGen MLIRIR MLIRStandardOps)
+target_link_libraries(MLIRAffineOps MLIRIR MLIRStandardOps)
+
diff --git a/third_party/mlir/lib/AffineOps/DialectRegistration.cpp b/third_party/mlir/lib/AffineOps/DialectRegistration.cpp
new file mode 100644
index 0000000..0afb32c
--- /dev/null
+++ b/third_party/mlir/lib/AffineOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+using namespace mlir;
+
+// Static initialization for Affine op dialect registration.
+static DialectRegistration<AffineOpsDialect> StandardOps;
diff --git a/third_party/mlir/lib/Analysis/AffineAnalysis.cpp b/third_party/mlir/lib/Analysis/AffineAnalysis.cpp
new file mode 100644
index 0000000..28c4eae
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -0,0 +1,896 @@
+//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous analysis routines for affine structures
+// (expressions, maps, sets), and other utilities relying on such analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "affine-analysis"
+
+using namespace mlir;
+
+using llvm::dbgs;
+
+/// Returns the sequence of AffineApplyOp Operations operation in
+/// 'affineApplyOps', which are reachable via a search starting from 'operands',
+/// and ending at operands which are not defined by AffineApplyOps.
+// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
+// the AffineApplyOp into any user AffineApplyOps.
+void mlir::getReachableAffineApplyOps(
+    ArrayRef<Value *> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
+  struct State {
+    // The ssa value for this node in the DFS traversal.
+    Value *value;
+    // The operand index of 'value' to explore next during DFS traversal.
+    unsigned operandIndex;
+  };
+  SmallVector<State, 4> worklist;
+  for (auto *operand : operands) {
+    worklist.push_back({operand, 0});
+  }
+
+  while (!worklist.empty()) {
+    State &state = worklist.back();
+    auto *opInst = state.value->getDefiningOp();
+    // Note: getDefiningOp will return nullptr if the operand is not an
+    // Operation (i.e. block argument), which is a terminator for the search.
+    if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
+      worklist.pop_back();
+      continue;
+    }
+
+    if (state.operandIndex == 0) {
+      // Pre-Visit: Add 'opInst' to reachable sequence.
+      affineApplyOps.push_back(opInst);
+    }
+    if (state.operandIndex < opInst->getNumOperands()) {
+      // Visit: Add next 'affineApplyOp' operand to worklist.
+      // Get next operand to visit at 'operandIndex'.
+      auto *nextOperand = opInst->getOperand(state.operandIndex);
+      // Increment 'operandIndex' in 'state'.
+      ++state.operandIndex;
+      // Add 'nextOperand' to worklist.
+      worklist.push_back({nextOperand, 0});
+    } else {
+      // Post-visit: done visiting operands AffineApplyOp, pop off stack.
+      worklist.pop_back();
+    }
+  }
+}
+
+// Builds a system of constraints with dimensional identifiers corresponding to
+// the loop IVs of the forOps appearing in that order. Any symbols founds in
+// the bound operands are added as symbols in the system. Returns failure for
+// the yet unimplemented cases.
+// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
+// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
+// step = 0 and/or by introducing a method in FlatAffineConstraints
+// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
+LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
+                                FlatAffineConstraints *domain) {
+  SmallVector<Value *, 4> indices;
+  extractForInductionVars(forOps, &indices);
+  // Reset while associated Values in 'indices' to the domain.
+  domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+  for (auto forOp : forOps) {
+    // Add constraints from forOp's bounds.
+    if (failed(domain->addAffineForOpDomain(forOp)))
+      return failure();
+  }
+  return success();
+}
+
+// Computes the iteration domain for 'opInst' and populates 'indexSet', which
+// encapsulates the constraints involving loops surrounding 'opInst' and
+// potentially involving any Function symbols. The dimensional identifiers in
+// 'indexSet' correspond to the loops surounding 'op' from outermost to
+// innermost.
+// TODO(andydavis) Add support to handle IfInsts surrounding 'op'.
+static LogicalResult getInstIndexSet(Operation *op,
+                                     FlatAffineConstraints *indexSet) {
+  // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
+  // factoring it out into a utility function.
+  SmallVector<AffineForOp, 4> loops;
+  getLoopIVs(*op, &loops);
+  return getIndexSet(loops, indexSet);
+}
+
+// ValuePositionMap manages the mapping from Values which represent dimension
+// and symbol identifiers from 'src' and 'dst' access functions to positions
+// in new space where some Values are kept separate (using addSrc/DstValue)
+// and some Values are merged (addSymbolValue).
+// Position lookups return the absolute position in the new space which
+// has the following format:
+//
+//   [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifers]
+//
+// Note: access function non-IV dimension identifiers (that have 'dimension'
+// positions in the access function position space) are assigned as symbols
+// in the output position space. Convienience access functions which lookup
+// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
+// the common case of resolving positions for all access function operands.
+//
+// TODO(andydavis) Generalize this: could take a template parameter for
+// the number of maps (3 in the current case), and lookups could take indices
+// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
+class ValuePositionMap {
+public:
+  void addSrcValue(Value *value) {
+    if (addValueAt(value, &srcDimPosMap, numSrcDims))
+      ++numSrcDims;
+  }
+  void addDstValue(Value *value) {
+    if (addValueAt(value, &dstDimPosMap, numDstDims))
+      ++numDstDims;
+  }
+  void addSymbolValue(Value *value) {
+    if (addValueAt(value, &symbolPosMap, numSymbols))
+      ++numSymbols;
+  }
+  unsigned getSrcDimOrSymPos(Value *value) const {
+    return getDimOrSymPos(value, srcDimPosMap, 0);
+  }
+  unsigned getDstDimOrSymPos(Value *value) const {
+    return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
+  }
+  unsigned getSymPos(Value *value) const {
+    auto it = symbolPosMap.find(value);
+    assert(it != symbolPosMap.end());
+    return numSrcDims + numDstDims + it->second;
+  }
+
+  unsigned getNumSrcDims() const { return numSrcDims; }
+  unsigned getNumDstDims() const { return numDstDims; }
+  unsigned getNumDims() const { return numSrcDims + numDstDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+
+private:
+  bool addValueAt(Value *value, DenseMap<Value *, unsigned> *posMap,
+                  unsigned position) {
+    auto it = posMap->find(value);
+    if (it == posMap->end()) {
+      (*posMap)[value] = position;
+      return true;
+    }
+    return false;
+  }
+  unsigned getDimOrSymPos(Value *value,
+                          const DenseMap<Value *, unsigned> &dimPosMap,
+                          unsigned dimPosOffset) const {
+    auto it = dimPosMap.find(value);
+    if (it != dimPosMap.end()) {
+      return dimPosOffset + it->second;
+    }
+    it = symbolPosMap.find(value);
+    assert(it != symbolPosMap.end());
+    return numSrcDims + numDstDims + it->second;
+  }
+
+  unsigned numSrcDims = 0;
+  unsigned numDstDims = 0;
+  unsigned numSymbols = 0;
+  DenseMap<Value *, unsigned> srcDimPosMap;
+  DenseMap<Value *, unsigned> dstDimPosMap;
+  DenseMap<Value *, unsigned> symbolPosMap;
+};
+
+// Builds a map from Value to identifier position in a new merged identifier
+// list, which is the result of merging dim/symbol lists from src/dst
+// iteration domains, the format of which is as follows:
+//
+//   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
+//
+// This method populates 'valuePosMap' with mappings from operand Values in
+// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
+// to the position of these values in the merged list.
+static void buildDimAndSymbolPositionMaps(
+    const FlatAffineConstraints &srcDomain,
+    const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
+    const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
+    FlatAffineConstraints *dependenceConstraints) {
+  auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
+    for (unsigned i = 0, e = values.size(); i < e; ++i) {
+      auto *value = values[i];
+      if (!isForInductionVar(values[i])) {
+        assert(isValidSymbol(values[i]) &&
+               "access operand has to be either a loop IV or a symbol");
+        valuePosMap->addSymbolValue(value);
+      } else if (isSrc) {
+        valuePosMap->addSrcValue(value);
+      } else {
+        valuePosMap->addDstValue(value);
+      }
+    }
+  };
+
+  SmallVector<Value *, 4> srcValues, destValues;
+  srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
+  dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
+  // Update value position map with identifiers from src iteration domain.
+  updateValuePosMap(srcValues, /*isSrc=*/true);
+  // Update value position map with identifiers from dst iteration domain.
+  updateValuePosMap(destValues, /*isSrc=*/false);
+  // Update value position map with identifiers from src access function.
+  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
+  // Update value position map with identifiers from dst access function.
+  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
+}
+
+// Sets up dependence constraints columns appropriately, in the format:
+// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
+void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
+                               const FlatAffineConstraints &dstDomain,
+                               const AffineValueMap &srcAccessMap,
+                               const AffineValueMap &dstAccessMap,
+                               const ValuePositionMap &valuePosMap,
+                               FlatAffineConstraints *dependenceConstraints) {
+  // Calculate number of equalities/inequalities and columns required to
+  // initialize FlatAffineConstraints for 'dependenceDomain'.
+  unsigned numIneq =
+      srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
+  AffineMap srcMap = srcAccessMap.getAffineMap();
+  assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
+  unsigned numEq = srcMap.getNumResults();
+  unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
+  unsigned numSymbols = valuePosMap.getNumSymbols();
+  unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
+  unsigned numIds = numDims + numSymbols + numLocals;
+  unsigned numCols = numIds + 1;
+
+  // Set flat affine constraints sizes and reserving space for constraints.
+  dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
+                               numLocals);
+
+  // Set values corresponding to dependence constraint identifiers.
+  SmallVector<Value *, 4> srcLoopIVs, dstLoopIVs;
+  srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
+  dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
+
+  dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
+  dependenceConstraints->setIdValues(
+      srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
+
+  // Set values for the symbolic identifier dimensions.
+  auto setSymbolIds = [&](ArrayRef<Value *> values) {
+    for (auto *value : values) {
+      if (!isForInductionVar(value)) {
+        assert(isValidSymbol(value) && "expected symbol");
+        dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
+      }
+    }
+  };
+
+  setSymbolIds(srcAccessMap.getOperands());
+  setSymbolIds(dstAccessMap.getOperands());
+
+  SmallVector<Value *, 8> srcSymbolValues, dstSymbolValues;
+  srcDomain.getIdValues(srcDomain.getNumDimIds(),
+                        srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
+  dstDomain.getIdValues(dstDomain.getNumDimIds(),
+                        dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+  setSymbolIds(srcSymbolValues);
+  setSymbolIds(dstSymbolValues);
+
+  for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
+       i < e; i++)
+    assert(dependenceConstraints->getIds()[i].hasValue());
+}
+
+// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
+// 'dependenceDomain'.
+// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
+// srcDomain/dstDomain Value maps.
+static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
+                                 const FlatAffineConstraints &dstDomain,
+                                 const ValuePositionMap &valuePosMap,
+                                 FlatAffineConstraints *dependenceDomain) {
+  unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
+
+  SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
+
+  auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
+    const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain;
+    unsigned numCsts =
+        isEq ? domain.getNumEqualities() : domain.getNumInequalities();
+    unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
+    auto at = [&](unsigned i, unsigned j) -> int64_t {
+      return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
+    };
+    auto map = [&](unsigned i) -> int64_t {
+      return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i))
+                   : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i));
+    };
+
+    for (unsigned i = 0; i < numCsts; ++i) {
+      // Zero fill.
+      std::fill(cst.begin(), cst.end(), 0);
+      // Set coefficients for identifiers corresponding to domain.
+      for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
+        cst[map(j)] = at(i, j);
+      // Local terms.
+      for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
+        cst[depNumDimsAndSymbolIds + localOffset + j] =
+            at(i, numDimAndSymbolIds + j);
+      // Set constant term.
+      cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
+      // Add constraint.
+      if (isEq)
+        dependenceDomain->addEquality(cst);
+      else
+        dependenceDomain->addInequality(cst);
+    }
+  };
+
+  // Add equalities from src domain.
+  addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
+  // Add inequalities from src domain.
+  addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
+  // Add equalities from dst domain.
+  addDomain(/*isSrc=*/false, /*isEq=*/true,
+            /*localOffset=*/srcDomain.getNumLocalIds());
+  // Add inequalities from dst domain.
+  addDomain(/*isSrc=*/false, /*isEq=*/false,
+            /*localOffset=*/srcDomain.getNumLocalIds());
+}
+
+// Adds equality constraints that equate src and dst access functions
+// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
+// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
+// For example, given the following two accesses functions to a 2D memref:
+//
+//   Source access function:
+//     (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
+//
+//   Destination acceses function:
+//     (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
+//
+// This method constructs the following equality constraints in
+// 'dependenceDomain', by equating the access functions for each result
+// (i.e. each memref dim). Notice that 'd0' for the destination access function
+// is mapped into 'd0' in the equality constraint:
+//
+//   d0      d1      s0         c
+//   --      --      --         --
+//   a0     -c0      (a1 - c1)  (a1 - c2) = 0
+//   b0     -f0      (b1 - f1)  (b1 - f2) = 0
+//
+// Returns failure if any AffineExpr cannot be flattened (due to it being
+// semi-affine). Returns success otherwise.
+static LogicalResult
+addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
+                           const AffineValueMap &dstAccessMap,
+                           const ValuePositionMap &valuePosMap,
+                           FlatAffineConstraints *dependenceDomain) {
+  AffineMap srcMap = srcAccessMap.getAffineMap();
+  AffineMap dstMap = dstAccessMap.getAffineMap();
+  assert(srcMap.getNumResults() == dstMap.getNumResults());
+  unsigned numResults = srcMap.getNumResults();
+
+  unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
+  ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
+
+  unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
+  ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
+
+  std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
+  std::vector<SmallVector<int64_t, 8>> destFlatExprs;
+  FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
+  // Get flattened expressions for the source destination maps.
+  if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
+      failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
+    return failure();
+
+  unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
+  unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
+  unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
+  unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
+  for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
+    dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
+  }
+
+  unsigned numDims = dependenceDomain->getNumDimIds();
+  unsigned numSymbols = dependenceDomain->getNumSymbolIds();
+  unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
+  unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
+
+  // Equality to add.
+  SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
+  for (unsigned i = 0; i < numResults; ++i) {
+    // Zero fill.
+    std::fill(eq.begin(), eq.end(), 0);
+
+    // Flattened AffineExpr for src result 'i'.
+    const auto &srcFlatExpr = srcFlatExprs[i];
+    // Set identifier coefficients from src access function.
+    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
+      eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
+    // Local terms.
+    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
+      eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
+    // Set constant term.
+    eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
+
+    // Flattened AffineExpr for dest result 'i'.
+    const auto &destFlatExpr = destFlatExprs[i];
+    // Set identifier coefficients from dst access function.
+    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
+      eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
+    // Local terms.
+    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
+      eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
+    // Set constant term.
+    eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
+
+    // Add equality constraint.
+    dependenceDomain->addEquality(eq);
+  }
+
+  // Add equality constraints for any operands that are defined by constant ops.
+  auto addEqForConstOperands = [&](ArrayRef<Value *> operands) {
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (isForInductionVar(operands[i]))
+        continue;
+      auto *symbol = operands[i];
+      assert(isValidSymbol(symbol));
+      // Check if the symbol is a constant.
+      if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp()))
+        dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
+                                          cOp.getValue());
+    }
+  };
+
+  // Add equality constraints for any src symbols defined by constant ops.
+  addEqForConstOperands(srcOperands);
+  // Add equality constraints for any dst symbols defined by constant ops.
+  addEqForConstOperands(dstOperands);
+
+  // By construction (see flattener), local var constraints will not have any
+  // equalities.
+  assert(srcLocalVarCst.getNumEqualities() == 0 &&
+         destLocalVarCst.getNumEqualities() == 0);
+  // Add inequalities from srcLocalVarCst and destLocalVarCst into the
+  // dependence domain.
+  SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
+  for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
+    std::fill(ineq.begin(), ineq.end(), 0);
+
+    // Set identifier coefficients from src local var constraints.
+    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
+      ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
+          srcLocalVarCst.atIneq(r, j);
+    // Local terms.
+    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
+      ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
+    // Set constant term.
+    ineq[ineq.size() - 1] =
+        srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
+    dependenceDomain->addInequality(ineq);
+  }
+
+  for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
+    std::fill(ineq.begin(), ineq.end(), 0);
+    // Set identifier coefficients from dest local var constraints.
+    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
+      ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
+          destLocalVarCst.atIneq(r, j);
+    // Local terms.
+    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
+      ineq[newLocalIdOffset + numSrcLocalIds + j] =
+          destLocalVarCst.atIneq(r, dstNumIds + j);
+    // Set constant term.
+    ineq[ineq.size() - 1] =
+        destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
+
+    dependenceDomain->addInequality(ineq);
+  }
+  return success();
+}
+
+// Returns the number of outer loop common to 'src/dstDomain'.
+// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
+static unsigned
+getNumCommonLoops(const FlatAffineConstraints &srcDomain,
+                  const FlatAffineConstraints &dstDomain,
+                  SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
+  // Find the number of common loops shared by src and dst accesses.
+  unsigned minNumLoops =
+      std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
+  unsigned numCommonLoops = 0;
+  for (unsigned i = 0; i < minNumLoops; ++i) {
+    if (!isForInductionVar(srcDomain.getIdValue(i)) ||
+        !isForInductionVar(dstDomain.getIdValue(i)) ||
+        srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
+      break;
+    if (commonLoops != nullptr)
+      commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
+    ++numCommonLoops;
+  }
+  if (commonLoops != nullptr)
+    assert(commonLoops->size() == numCommonLoops);
+  return numCommonLoops;
+}
+
+// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+static Block *getCommonBlock(const MemRefAccess &srcAccess,
+                             const MemRefAccess &dstAccess,
+                             const FlatAffineConstraints &srcDomain,
+                             unsigned numCommonLoops) {
+  if (numCommonLoops == 0) {
+    auto *block = srcAccess.opInst->getBlock();
+    while (!llvm::isa<FuncOp>(block->getParentOp())) {
+      block = block->getParentOp()->getBlock();
+    }
+    return block;
+  }
+  auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
+  auto forOp = getForInductionVarOwner(commonForValue);
+  assert(forOp && "commonForValue was not an induction variable");
+  return forOp.getBody();
+}
+
+// Returns true if the ancestor operation of 'srcAccess' appears before the
+// ancestor operation of 'dstAccess' in the common ancestral block. Returns
+// false otherwise.
+// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
+// the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that
+// 'numCommonLoops' is the number of contiguous surrounding outer loops.
+static bool srcAppearsBeforeDstInAncestralBlock(
+    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+    const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
+  // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+  auto *commonBlock =
+      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
+  // Check the dominance relationship between the respective ancestors of the
+  // src and dst in the Block of the innermost among the common loops.
+  auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
+  assert(srcInst != nullptr);
+  auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
+  assert(dstInst != nullptr);
+
+  // Determine whether dstInst comes after srcInst.
+  return srcInst->isBeforeInBlock(dstInst);
+}
+
+// Adds ordering constraints to 'dependenceDomain' based on number of loops
+// common to 'src/dstDomain' and requested 'loopDepth'.
+// Note that 'loopDepth' cannot exceed the number of common loops plus one.
+// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
+// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
+// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' > j + 1
+// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
+static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
+                                   const FlatAffineConstraints &dstDomain,
+                                   unsigned loopDepth,
+                                   FlatAffineConstraints *dependenceDomain) {
+  unsigned numCols = dependenceDomain->getNumCols();
+  SmallVector<int64_t, 4> eq(numCols);
+  unsigned numSrcDims = srcDomain.getNumDimIds();
+  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
+  unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
+  for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
+    std::fill(eq.begin(), eq.end(), 0);
+    eq[i] = -1;
+    eq[i + numSrcDims] = 1;
+    if (i == loopDepth - 1) {
+      eq[numCols - 1] = -1;
+      dependenceDomain->addInequality(eq);
+    } else {
+      dependenceDomain->addEquality(eq);
+    }
+  }
+}
+
+// Computes distance and direction vectors in 'dependences', by adding
+// variables to 'dependenceDomain' which represent the difference of the IVs,
+// eliminating all other variables, and reading off distance vectors from
+// equality constraints (if possible), and direction vectors from inequalities.
+static void computeDirectionVector(
+    const FlatAffineConstraints &srcDomain,
+    const FlatAffineConstraints &dstDomain, unsigned loopDepth,
+    FlatAffineConstraints *dependenceDomain,
+    llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
+  // Find the number of common loops shared by src and dst accesses.
+  SmallVector<AffineForOp, 4> commonLoops;
+  unsigned numCommonLoops =
+      getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
+  if (numCommonLoops == 0)
+    return;
+  // Compute direction vectors for requested loop depth.
+  unsigned numIdsToEliminate = dependenceDomain->getNumIds();
+  // Add new variables to 'dependenceDomain' to represent the direction
+  // constraints for each shared loop.
+  for (unsigned j = 0; j < numCommonLoops; ++j) {
+    dependenceDomain->addDimId(j);
+  }
+
+  // Add equality contraints for each common loop, setting newly introduced
+  // variable at column 'j' to the 'dst' IV minus the 'src IV.
+  SmallVector<int64_t, 4> eq;
+  eq.resize(dependenceDomain->getNumCols());
+  unsigned numSrcDims = srcDomain.getNumDimIds();
+  // Constraint variables format:
+  // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
+  for (unsigned j = 0; j < numCommonLoops; ++j) {
+    std::fill(eq.begin(), eq.end(), 0);
+    eq[j] = 1;
+    eq[j + numCommonLoops] = 1;
+    eq[j + numCommonLoops + numSrcDims] = -1;
+    dependenceDomain->addEquality(eq);
+  }
+
+  // Eliminate all variables other than the direction variables just added.
+  dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);
+
+  // Scan each common loop variable column and set direction vectors based
+  // on eliminated constraint system.
+  dependenceComponents->resize(numCommonLoops);
+  for (unsigned j = 0; j < numCommonLoops; ++j) {
+    (*dependenceComponents)[j].op = commonLoops[j].getOperation();
+    auto lbConst = dependenceDomain->getConstantLowerBound(j);
+    (*dependenceComponents)[j].lb =
+        lbConst.getValueOr(std::numeric_limits<int64_t>::min());
+    auto ubConst = dependenceDomain->getConstantUpperBound(j);
+    (*dependenceComponents)[j].ub =
+        ubConst.getValueOr(std::numeric_limits<int64_t>::max());
+  }
+}
+
+// Populates 'accessMap' with composition of AffineApplyOps reachable from
+// indices of MemRefAccess.
+void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
+  // Get affine map from AffineLoad/Store.
+  AffineMap map;
+  if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
+    map = loadOp.getAffineMap();
+  else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
+    map = storeOp.getAffineMap();
+  SmallVector<Value *, 8> operands(indices.begin(), indices.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+  canonicalizeMapAndOperands(&map, &operands);
+  accessMap->reset(map, operands);
+}
+
+// Builds a flat affine constraint system to check if there exists a dependence
+// between memref accesses 'srcAccess' and 'dstAccess'.
+// Returns 'NoDependence' if the accesses can be definitively shown not to
+// access the same element.
+// Returns 'HasDependence' if the accesses do access the same element.
+// Returns 'Failure' if an error or unsupported case was encountered.
+// If a dependence exists, returns in 'dependenceComponents' a direction
+// vector for the dependence, with a component for each loop IV in loops
+// common to both accesses (see Dependence in AffineAnalysis.h for details).
+//
+// The memref access dependence check is comprised of the following steps:
+// *) Compute access functions for each access. Access functions are computed
+//    using AffineValueMaps initialized with the indices from an access, then
+//    composed with AffineApplyOps reachable from operands of that access,
+//    until operands of the AffineValueMap are loop IVs or symbols.
+// *) Build iteration domain constraints for each access. Iteration domain
+//    constraints are pairs of inequality contraints representing the
+//    upper/lower loop bounds for each AffineForOp in the loop nest associated
+//    with each access.
+// *) Build dimension and symbol position maps for each access, which map
+//    Values from access functions and iteration domains to their position
+//    in the merged constraint system built by this method.
+//
+// This method builds a constraint system with the following column format:
+//
+//  [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
+//
+// For example, given the following MLIR code with with "source" and
+// "destination" accesses to the same memref labled, and symbols %M, %N, %K:
+//
+//   affine.for %i0 = 0 to 100 {
+//     affine.for %i1 = 0 to 50 {
+//       %a0 = affine.apply
+//         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
+//       // Source memref access.
+//       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
+//     }
+//   }
+//
+//   affine.for %i2 = 0 to 100 {
+//     affine.for %i3 = 0 to 50 {
+//       %a1 = affine.apply
+//         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
+//       // Destination memref access.
+//       %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
+//     }
+//   }
+//
+// The access functions would be the following:
+//
+//   src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
+//   dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
+//
+// The iteration domains for the src/dst accesses would be the following:
+//
+//   src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
+//   dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
+//
+// The symbols by both accesses would be assigned to a canonical position order
+// which will be used in the dependence constraint system:
+//
+//   symbol name: %M  %N  %K
+//   symbol  pos:  0   1   2
+//
+// Equality constraints are built by equating each result of src/destination
+// access functions. For this example, the following two equality constraints
+// will be added to the dependence constraint system:
+//
+//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+//      2         -4        -7        -9       1      1     0     0    = 0
+//      0          3         0        -11     -1      0     1     0    = 0
+//
+// Inequality constraints from the iteration domain will be meged into
+// the dependence constraint system
+//
+//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+//       1         0         0         0        0     0     0     0    >= 0
+//      -1         0         0         0        0     0     0     100  >= 0
+//       0         1         0         0        0     0     0     0    >= 0
+//       0        -1         0         0        0     0     0     50   >= 0
+//       0         0         1         0        0     0     0     0    >= 0
+//       0         0        -1         0        0     0     0     100  >= 0
+//       0         0         0         1        0     0     0     0    >= 0
+//       0         0         0        -1        0     0     0     50   >= 0
+//
+//
+// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
+DependenceResult mlir::checkMemrefAccessDependence(
+    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+    unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
+    llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
+    bool allowRAR) {
+  LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
+                          << Twine(loopDepth) << " between:\n";);
+  LLVM_DEBUG(srcAccess.opInst->dump(););
+  LLVM_DEBUG(dstAccess.opInst->dump(););
+
+  // Return 'NoDependence' if these accesses do not access the same memref.
+  if (srcAccess.memref != dstAccess.memref)
+    return DependenceResult::NoDependence;
+
+  // Return 'NoDependence' if one of these accesses is not an AffineStoreOp.
+  if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) &&
+      !isa<AffineStoreOp>(dstAccess.opInst))
+    return DependenceResult::NoDependence;
+
+  // Get composed access function for 'srcAccess'.
+  AffineValueMap srcAccessMap;
+  srcAccess.getAccessMap(&srcAccessMap);
+
+  // Get composed access function for 'dstAccess'.
+  AffineValueMap dstAccessMap;
+  dstAccess.getAccessMap(&dstAccessMap);
+
+  // Get iteration domain for the 'srcAccess' operation.
+  FlatAffineConstraints srcDomain;
+  if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
+    return DependenceResult::Failure;
+
+  // Get iteration domain for 'dstAccess' operation.
+  FlatAffineConstraints dstDomain;
+  if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
+    return DependenceResult::Failure;
+
+  // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
+  // operation of 'srcAccess' does not properly dominate the ancestor
+  // operation of 'dstAccess' in the same common operation block.
+  // Note: this check is skipped if 'allowRAR' is true, because because RAR
+  // deps can exist irrespective of lexicographic ordering b/w src and dst.
+  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
+  assert(loopDepth <= numCommonLoops + 1);
+  if (!allowRAR && loopDepth > numCommonLoops &&
+      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
+                                           numCommonLoops)) {
+    return DependenceResult::NoDependence;
+  }
+  // Build dim and symbol position maps for each access from access operand
+  // Value to position in merged contstraint system.
+  ValuePositionMap valuePosMap;
+  buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
+                                dstAccessMap, &valuePosMap,
+                                dependenceConstraints);
+
+  initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
+                            valuePosMap, dependenceConstraints);
+
+  assert(valuePosMap.getNumDims() ==
+         srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
+
+  // Create memref access constraint by equating src/dst access functions.
+  // Note that this check is conservative, and will fail in the future when
+  // local variables for mod/div exprs are supported.
+  if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
+                                        dependenceConstraints)))
+    return DependenceResult::Failure;
+
+  // Add 'src' happens before 'dst' ordering constraints.
+  addOrderingConstraints(srcDomain, dstDomain, loopDepth,
+                         dependenceConstraints);
+  // Add src and dst domain constraints.
+  addDomainConstraints(srcDomain, dstDomain, valuePosMap,
+                       dependenceConstraints);
+
+  // Return 'NoDependence' if the solution space is empty: no dependence.
+  if (dependenceConstraints->isEmpty()) {
+    return DependenceResult::NoDependence;
+  }
+
+  // Compute dependence direction vector and return true.
+  if (dependenceComponents != nullptr) {
+    computeDirectionVector(srcDomain, dstDomain, loopDepth,
+                           dependenceConstraints, dependenceComponents);
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
+  LLVM_DEBUG(dependenceConstraints->dump());
+  return DependenceResult::HasDependence;
+}
+
+/// Gathers dependence components for dependences between all ops in loop nest
+/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
+void mlir::getDependenceComponents(
+    AffineForOp forOp, unsigned maxLoopDepth,
+    std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) {
+  // Collect all load and store ops in loop nest rooted at 'forOp'.
+  SmallVector<Operation *, 8> loadAndStoreOpInsts;
+  forOp.getOperation()->walk([&](Operation *opInst) {
+    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+      loadAndStoreOpInsts.push_back(opInst);
+  });
+
+  unsigned numOps = loadAndStoreOpInsts.size();
+  for (unsigned d = 1; d <= maxLoopDepth; ++d) {
+    for (unsigned i = 0; i < numOps; ++i) {
+      auto *srcOpInst = loadAndStoreOpInsts[i];
+      MemRefAccess srcAccess(srcOpInst);
+      for (unsigned j = 0; j < numOps; ++j) {
+        auto *dstOpInst = loadAndStoreOpInsts[j];
+        MemRefAccess dstAccess(dstOpInst);
+
+        FlatAffineConstraints dependenceConstraints;
+        llvm::SmallVector<DependenceComponent, 2> depComps;
+        // TODO(andydavis,bondhugula) Explore whether it would be profitable
+        // to pre-compute and store deps instead of repeatedly checking.
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
+        if (hasDependence(result))
+          depCompsVec->push_back(depComps);
+      }
+    }
+  }
+}
diff --git a/third_party/mlir/lib/Analysis/AffineStructures.cpp b/third_party/mlir/lib/Analysis/AffineStructures.cpp
new file mode 100644
index 0000000..46e4535
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/AffineStructures.cpp
@@ -0,0 +1,2806 @@
+//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Structures for affine/polyhedral analysis of MLIR functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "affine-structures"
+
+using namespace mlir;
+using llvm::SmallDenseMap;
+using llvm::SmallDenseSet;
+using llvm::SmallPtrSet;
+
+namespace {
+
+// See comments for SimpleAffineExprFlattener.
+// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
+// constraint information associated with mod's, floordiv's, and ceildiv's
+// in FlatAffineConstraints 'localVarCst'.
+struct AffineExprFlattener : public SimpleAffineExprFlattener {
+public:
+  // Constraints connecting newly introduced local variables (for mod's and
+  // div's) to existing (dimensional and symbolic) ones. These are always
+  // inequalities.
+  FlatAffineConstraints localVarCst;
+
+  AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
+      : SimpleAffineExprFlattener(nDims, nSymbols) {
+    localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
+  }
+
+private:
+  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
+  // The local identifier added is always a floordiv of a pure add/mul affine
+  // function of other identifiers, coefficients of which are specified in
+  // `dividend' and with respect to the positive constant `divisor'. localExpr
+  // is the simplified tree expression (AffineExpr) corresponding to the
+  // quantifier.
+  void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
+                          AffineExpr localExpr) override {
+    SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
+    // Update localVarCst.
+    localVarCst.addLocalFloorDiv(dividend, divisor);
+  }
+};
+
+} // end anonymous namespace
+
+// Flattens the expressions in map. Returns failure if 'expr' was unable to be
+// flattened (i.e., semi-affine expressions not handled yet).
+static LogicalResult getFlattenedAffineExprs(
+    ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
+    std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
+    FlatAffineConstraints *localVarCst) {
+  if (exprs.empty()) {
+    localVarCst->reset(numDims, numSymbols);
+    return success();
+  }
+
+  AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
+  // Use the same flattener to simplify each expression successively. This way
+  // local identifiers / expressions are shared.
+  for (auto expr : exprs) {
+    if (!expr.isPureAffine())
+      return failure();
+
+    flattener.walkPostOrder(expr);
+  }
+
+  assert(flattener.operandExprStack.size() == exprs.size());
+  flattenedExprs->clear();
+  flattenedExprs->assign(flattener.operandExprStack.begin(),
+                         flattener.operandExprStack.end());
+
+  if (localVarCst) {
+    localVarCst->clearAndCopyFrom(flattener.localVarCst);
+  }
+
+  return success();
+}
+
+// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
+// be flattened (semi-affine expressions not handled yet).
+LogicalResult
+mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
+                             unsigned numSymbols,
+                             llvm::SmallVectorImpl<int64_t> *flattenedExpr,
+                             FlatAffineConstraints *localVarCst) {
+  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
+  LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
+                                                &flattenedExprs, localVarCst);
+  *flattenedExpr = flattenedExprs[0];
+  return ret;
+}
+
+/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
+/// flattened (i.e., semi-affine expressions not handled yet).
+LogicalResult mlir::getFlattenedAffineExprs(
+    AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
+    FlatAffineConstraints *localVarCst) {
+  if (map.getNumResults() == 0) {
+    localVarCst->reset(map.getNumDims(), map.getNumSymbols());
+    return success();
+  }
+  return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
+                                   map.getNumSymbols(), flattenedExprs,
+                                   localVarCst);
+}
+
+LogicalResult mlir::getFlattenedAffineExprs(
+    IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
+    FlatAffineConstraints *localVarCst) {
+  if (set.getNumConstraints() == 0) {
+    localVarCst->reset(set.getNumDims(), set.getNumSymbols());
+    return success();
+  }
+  return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
+                                   set.getNumSymbols(), flattenedExprs,
+                                   localVarCst);
+}
+
+//===----------------------------------------------------------------------===//
+// MutableAffineMap.
+//===----------------------------------------------------------------------===//
+
+MutableAffineMap::MutableAffineMap(AffineMap map)
+    : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
+      // A map always has at least 1 result by construction
+      context(map.getResult(0).getContext()) {
+  for (auto result : map.getResults())
+    results.push_back(result);
+}
+
+void MutableAffineMap::reset(AffineMap map) {
+  results.clear();
+  numDims = map.getNumDims();
+  numSymbols = map.getNumSymbols();
+  // A map always has at least 1 result by construction
+  context = map.getResult(0).getContext();
+  for (auto result : map.getResults())
+    results.push_back(result);
+}
+
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+  if (results[idx].isMultipleOf(factor))
+    return true;
+
+  // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
+  // complete this (for a more powerful analysis).
+  return false;
+}
+
+// Simplifies the result affine expressions of this map. The expressions have to
+// be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+  // Simplify each of the results if possible.
+  // TODO(ntv): functional-style map
+  for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+    results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
+  }
+}
+
+AffineMap MutableAffineMap::getAffineMap() const {
+  return AffineMap::get(numDims, numSymbols, results);
+}
+
+MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
+    : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) {
+  // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                                     MLIRContext *context)
+    : numDims(numDims), numSymbols(numSymbols) {}
+
+//===----------------------------------------------------------------------===//
+// AffineValueMap.
+//===----------------------------------------------------------------------===//
+
+AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
+                               ArrayRef<Value *> results)
+    : map(map), operands(operands.begin(), operands.end()),
+      results(results.begin(), results.end()) {}
+
+AffineValueMap::AffineValueMap(AffineApplyOp applyOp)
+    : map(applyOp.getAffineMap()),
+      operands(applyOp.operand_begin(), applyOp.operand_end()) {
+  results.push_back(applyOp.getResult());
+}
+
+AffineValueMap::AffineValueMap(AffineBound bound)
+    : map(bound.getMap()),
+      operands(bound.operand_begin(), bound.operand_end()) {}
+
+void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands,
+                           ArrayRef<Value *> results) {
+  this->map.reset(map);
+  this->operands.assign(operands.begin(), operands.end());
+  this->results.assign(results.begin(), results.end());
+}
+
+// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
+// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
+static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
+                      unsigned indexStart, unsigned *indexOfMatch) {
+  unsigned size = valuesToSearch.size();
+  for (unsigned i = indexStart; i < size; ++i) {
+    if (valueToMatch == valuesToSearch[i]) {
+      *indexOfMatch = i;
+      return true;
+    }
+  }
+  return false;
+}
+
+inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
+  return map.isMultipleOf(idx, factor);
+}
+
+/// This method uses the invariant that operands are always positionally aligned
+/// with the AffineDimExpr in the underlying AffineMap.
+bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
+  unsigned index;
+  if (!findIndex(value, operands, /*indexStart=*/0, &index)) {
+    return false;
+  }
+  auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
+  // TODO(ntv): this is better implemented on a flattened representation.
+  // At least for now it is conservative.
+  return expr.isFunctionOfDim(index);
+}
+
+Value *AffineValueMap::getOperand(unsigned i) const {
+  return static_cast<Value *>(operands[i]);
+}
+
+ArrayRef<Value *> AffineValueMap::getOperands() const {
+  return ArrayRef<Value *>(operands);
+}
+
+AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
+
+AffineValueMap::~AffineValueMap() {}
+
+//===----------------------------------------------------------------------===//
+// FlatAffineConstraints.
+//===----------------------------------------------------------------------===//
+
+// Copy constructor.
+FlatAffineConstraints::FlatAffineConstraints(
+    const FlatAffineConstraints &other) {
+  numReservedCols = other.numReservedCols;
+  numDims = other.getNumDimIds();
+  numSymbols = other.getNumSymbolIds();
+  numIds = other.getNumIds();
+
+  auto otherIds = other.getIds();
+  ids.reserve(numReservedCols);
+  ids.append(otherIds.begin(), otherIds.end());
+
+  unsigned numReservedEqualities = other.getNumReservedEqualities();
+  unsigned numReservedInequalities = other.getNumReservedInequalities();
+
+  equalities.reserve(numReservedEqualities * numReservedCols);
+  inequalities.reserve(numReservedInequalities * numReservedCols);
+
+  for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
+    addInequality(other.getInequality(r));
+  }
+  for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
+    addEquality(other.getEquality(r));
+  }
+}
+
+// Clones this object.
+std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
+  return llvm::make_unique<FlatAffineConstraints>(*this);
+}
+
+// Construct from an IntegerSet.
+FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
+    : numReservedCols(set.getNumOperands() + 1),
+      numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
+      numSymbols(set.getNumSymbols()) {
+  equalities.reserve(set.getNumEqualities() * numReservedCols);
+  inequalities.reserve(set.getNumInequalities() * numReservedCols);
+  ids.resize(numIds, None);
+
+  // Flatten expressions and add them to the constraint system.
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  FlatAffineConstraints localVarCst;
+  if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
+    assert(false && "flattening unimplemented for semi-affine integer sets");
+    return;
+  }
+  assert(flatExprs.size() == set.getNumConstraints());
+  for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
+    addLocalId(getNumLocalIds());
+  }
+
+  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
+    const auto &flatExpr = flatExprs[i];
+    assert(flatExpr.size() == getNumCols());
+    if (set.getEqFlags()[i]) {
+      addEquality(flatExpr);
+    } else {
+      addInequality(flatExpr);
+    }
+  }
+  // Add the other constraints involving local id's from flattening.
+  append(localVarCst);
+}
+
+void FlatAffineConstraints::reset(unsigned numReservedInequalities,
+                                  unsigned numReservedEqualities,
+                                  unsigned newNumReservedCols,
+                                  unsigned newNumDims, unsigned newNumSymbols,
+                                  unsigned newNumLocals,
+                                  ArrayRef<Value *> idArgs) {
+  assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
+         "minimum 1 column");
+  numReservedCols = newNumReservedCols;
+  numDims = newNumDims;
+  numSymbols = newNumSymbols;
+  numIds = numDims + numSymbols + newNumLocals;
+  assert(idArgs.empty() || idArgs.size() == numIds);
+
+  clearConstraints();
+  if (numReservedEqualities >= 1)
+    equalities.reserve(newNumReservedCols * numReservedEqualities);
+  if (numReservedInequalities >= 1)
+    inequalities.reserve(newNumReservedCols * numReservedInequalities);
+  if (idArgs.empty()) {
+    ids.resize(numIds, None);
+  } else {
+    ids.assign(idArgs.begin(), idArgs.end());
+  }
+}
+
+void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
+                                  unsigned newNumLocals,
+                                  ArrayRef<Value *> idArgs) {
+  reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
+        newNumSymbols, newNumLocals, idArgs);
+}
+
+void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
+  assert(other.getNumCols() == getNumCols());
+  assert(other.getNumDimIds() == getNumDimIds());
+  assert(other.getNumSymbolIds() == getNumSymbolIds());
+
+  inequalities.reserve(inequalities.size() +
+                       other.getNumInequalities() * numReservedCols);
+  equalities.reserve(equalities.size() +
+                     other.getNumEqualities() * numReservedCols);
+
+  for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
+    addInequality(other.getInequality(r));
+  }
+  for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
+    addEquality(other.getEquality(r));
+  }
+}
+
+void FlatAffineConstraints::addLocalId(unsigned pos) {
+  addId(IdKind::Local, pos);
+}
+
+void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
+  addId(IdKind::Dimension, pos, id);
+}
+
+void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
+  addId(IdKind::Symbol, pos, id);
+}
+
+/// Adds a dimensional identifier. The added column is initialized to
+/// zero.
+void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
+  if (kind == IdKind::Dimension) {
+    assert(pos <= getNumDimIds());
+  } else if (kind == IdKind::Symbol) {
+    assert(pos <= getNumSymbolIds());
+  } else {
+    assert(pos <= getNumLocalIds());
+  }
+
+  unsigned oldNumReservedCols = numReservedCols;
+
+  // Check if a resize is necessary.
+  if (getNumCols() + 1 > numReservedCols) {
+    equalities.resize(getNumEqualities() * (getNumCols() + 1));
+    inequalities.resize(getNumInequalities() * (getNumCols() + 1));
+    numReservedCols++;
+  }
+
+  int absolutePos;
+
+  if (kind == IdKind::Dimension) {
+    absolutePos = pos;
+    numDims++;
+  } else if (kind == IdKind::Symbol) {
+    absolutePos = pos + getNumDimIds();
+    numSymbols++;
+  } else {
+    absolutePos = pos + getNumDimIds() + getNumSymbolIds();
+  }
+  numIds++;
+
+  // Note that getNumCols() now will already return the new size, which will be
+  // at least one.
+  int numInequalities = static_cast<int>(getNumInequalities());
+  int numEqualities = static_cast<int>(getNumEqualities());
+  int numCols = static_cast<int>(getNumCols());
+  for (int r = numInequalities - 1; r >= 0; r--) {
+    for (int c = numCols - 2; c >= 0; c--) {
+      if (c < absolutePos)
+        atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
+      else
+        atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
+    }
+    atIneq(r, absolutePos) = 0;
+  }
+
+  for (int r = numEqualities - 1; r >= 0; r--) {
+    for (int c = numCols - 2; c >= 0; c--) {
+      // All values in column absolutePositions < absolutePos have the same
+      // coordinates in the 2-d view of the coefficient buffer.
+      if (c < absolutePos)
+        atEq(r, c) = equalities[r * oldNumReservedCols + c];
+      else
+        // Those at absolutePosition >= absolutePos, get a shifted
+        // absolutePosition.
+        atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
+    }
+    // Initialize added dimension to zero.
+    atEq(r, absolutePos) = 0;
+  }
+
+  // If an 'id' is provided, insert it; otherwise use None.
+  if (id) {
+    ids.insert(ids.begin() + absolutePos, id);
+  } else {
+    ids.insert(ids.begin() + absolutePos, None);
+  }
+  assert(ids.size() == getNumIds());
+}
+
+/// Checks if two constraint systems are in the same space, i.e., if they are
+/// associated with the same set of identifiers, appearing in the same order.
+static bool areIdsAligned(const FlatAffineConstraints &A,
+                          const FlatAffineConstraints &B) {
+  return A.getNumDimIds() == B.getNumDimIds() &&
+         A.getNumSymbolIds() == B.getNumSymbolIds() &&
+         A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
+}
+
+/// Calls areIdsAligned to check if two constraint systems have the same set
+/// of identifiers in the same order.
+bool FlatAffineConstraints::areIdsAlignedWithOther(
+    const FlatAffineConstraints &other) {
+  return areIdsAligned(*this, other);
+}
+
+/// Checks if the SSA values associated with `cst''s identifiers are unique.
+static bool LLVM_ATTRIBUTE_UNUSED
+areIdsUnique(const FlatAffineConstraints &cst) {
+  SmallPtrSet<Value *, 8> uniqueIds;
+  for (auto id : cst.getIds()) {
+    if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
+      return false;
+  }
+  return true;
+}
+
+// Swap the posA^th identifier with the posB^th identifier.
+static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
+  assert(posA < A->getNumIds() && "invalid position A");
+  assert(posB < A->getNumIds() && "invalid position B");
+
+  if (posA == posB)
+    return;
+
+  for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
+    std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
+  }
+  for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
+    std::swap(A->atEq(r, posA), A->atEq(r, posB));
+  }
+  std::swap(A->getId(posA), A->getId(posB));
+}
+
+/// Merge and align the identifiers of A and B starting at 'offset', so that
+/// both constraint systems get the union of the contained identifiers that is
+/// dimension-wise and symbol-wise unique; both constraint systems are updated
+/// so that they have the union of all identifiers, with A's original
+/// identifiers appearing first followed by any of B's identifiers that didn't
+/// appear in A. Local identifiers of each system are by design separate/local
+/// and are placed one after other (A's followed by B's).
+//  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
+//      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
+//
+static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
+                             FlatAffineConstraints *B) {
+  assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
+  // A merge/align isn't meaningful if a cst's ids aren't distinct.
+  assert(areIdsUnique(*A) && "A's id values aren't unique");
+  assert(areIdsUnique(*B) && "B's id values aren't unique");
+
+  assert(std::all_of(A->getIds().begin() + offset,
+                     A->getIds().begin() + A->getNumDimAndSymbolIds(),
+                     [](Optional<Value *> id) { return id.hasValue(); }));
+
+  assert(std::all_of(B->getIds().begin() + offset,
+                     B->getIds().begin() + B->getNumDimAndSymbolIds(),
+                     [](Optional<Value *> id) { return id.hasValue(); }));
+
+  // Place local id's of A after local id's of B.
+  for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
+    B->addLocalId(0);
+  }
+  for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
+       t++) {
+    A->addLocalId(A->getNumLocalIds());
+  }
+
+  SmallVector<Value *, 4> aDimValues, aSymValues;
+  A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
+  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
+  {
+    // Merge dims from A into B.
+    unsigned d = offset;
+    for (auto *aDimValue : aDimValues) {
+      unsigned loc;
+      if (B->findId(*aDimValue, &loc)) {
+        assert(loc >= offset && "A's dim appears in B's aligned range");
+        assert(loc < B->getNumDimIds() &&
+               "A's dim appears in B's non-dim position");
+        swapId(B, d, loc);
+      } else {
+        B->addDimId(d);
+        B->setIdValue(d, aDimValue);
+      }
+      d++;
+    }
+
+    // Dimensions that are in B, but not in A, are added at the end.
+    for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
+      A->addDimId(A->getNumDimIds());
+      A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
+    }
+  }
+  {
+    // Merge symbols: merge A's symbols into B first.
+    unsigned s = B->getNumDimIds();
+    for (auto *aSymValue : aSymValues) {
+      unsigned loc;
+      if (B->findId(*aSymValue, &loc)) {
+        assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
+               "A's symbol appears in B's non-symbol position");
+        swapId(B, s, loc);
+      } else {
+        B->addSymbolId(s - B->getNumDimIds());
+        B->setIdValue(s, aSymValue);
+      }
+      s++;
+    }
+    // Symbols that are in B, but not in A, are added at the end.
+    for (unsigned t = A->getNumDimAndSymbolIds(),
+                  e = B->getNumDimAndSymbolIds();
+         t < e; t++) {
+      A->addSymbolId(A->getNumSymbolIds());
+      A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
+    }
+  }
+  assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
+}
+
+// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
+void FlatAffineConstraints::mergeAndAlignIdsWithOther(
+    unsigned offset, FlatAffineConstraints *other) {
+  mergeAndAlignIds(offset, this, other);
+}
+
+// This routine may add additional local variables if the flattened expression
+// corresponding to the map has such variables due to mod's, ceildiv's, and
+// floordiv's in it.
+LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  FlatAffineConstraints localCst;
+  if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
+                                     &localCst))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "composition unimplemented for semi-affine maps\n");
+    return failure();
+  }
+  assert(flatExprs.size() == vMap->getNumResults());
+
+  // Add localCst information.
+  if (localCst.getNumLocalIds() > 0) {
+    SmallVector<Value *, 8> values(vMap->getOperands().begin(),
+                                   vMap->getOperands().end());
+    localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values);
+    // Align localCst and this.
+    mergeAndAlignIds(/*offset=*/0, &localCst, this);
+    // Finally, append localCst to this constraint set.
+    append(localCst);
+  }
+
+  // Add dimensions corresponding to the map's results.
+  for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
+    // TODO: Consider using a batched version to add a range of IDs.
+    addDimId(0);
+  }
+
+  // We add one equality for each result connecting the result dim of the map to
+  // the other identifiers.
+  // For eg: if the expression is 16*i0 + i1, and this is the r^th
+  // iteration/result of the value map, we are adding the equality:
+  //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
+  //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
+  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
+    const auto &flatExpr = flatExprs[r];
+    assert(flatExpr.size() >= vMap->getNumOperands() + 1);
+
+    // eqToAdd is the equality corresponding to the flattened affine expression.
+    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
+    // Set the coefficient for this result to one.
+    eqToAdd[r] = 1;
+
+    // Dims and symbols.
+    for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
+      unsigned loc;
+      bool ret = findId(*vMap->getOperand(i), &loc);
+      assert(ret && "value map's id can't be found");
+      (void)ret;
+      // Negate 'eq[r]' since the newly added dimension will be set to this one.
+      eqToAdd[loc] = -flatExpr[i];
+    }
+    // Local vars common to eq and localCst are at the beginning.
+    unsigned j = getNumDimIds() + getNumSymbolIds();
+    unsigned end = flatExpr.size() - 1;
+    for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
+      eqToAdd[j] = -flatExpr[i];
+    }
+
+    // Constant term.
+    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
+
+    // Add the equality connecting the result of the map to this constraint set.
+    addEquality(eqToAdd);
+  }
+
+  return success();
+}
+
+// Turn a dimension into a symbol.
+static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) {
+  unsigned pos;
+  if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
+    swapId(cst, pos, cst->getNumDimIds() - 1);
+    cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
+  }
+}
+
+// Turn a symbol into a dimension.
+static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) {
+  unsigned pos;
+  if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
+      pos < cst->getNumDimAndSymbolIds()) {
+    swapId(cst, pos, cst->getNumDimIds());
+    cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
+  }
+}
+
+// Changes all symbol identifiers which are loop IVs to dim identifiers.
+void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
+  // Gather all symbols which are loop IVs.
+  SmallVector<Value *, 4> loopIVs;
+  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
+    if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
+      loopIVs.push_back(ids[i].getValue());
+  }
+  // Turn each symbol in 'loopIVs' into a dim identifier.
+  for (auto *iv : loopIVs) {
+    turnSymbolIntoDim(this, *iv);
+  }
+}
+
+void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
+  if (containsId(*id))
+    return;
+
+  // Caller is expected to fully compose map/operands if necessary.
+  assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
+         "non-terminal symbol / loop IV expected");
+  // Outer loop IVs could be used in forOp's bounds.
+  if (auto loop = getForInductionVarOwner(id)) {
+    addDimId(getNumDimIds(), id);
+    if (failed(this->addAffineForOpDomain(loop)))
+      LLVM_DEBUG(
+          loop.emitWarning("failed to add domain info to constraint system"));
+    return;
+  }
+  // Add top level symbol.
+  addSymbolId(getNumSymbolIds(), id);
+  // Check if the symbol is a constant.
+  if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp()))
+    setIdToConstant(*id, constOp.getValue());
+}
+
+LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
+  unsigned pos;
+  // Pre-condition for this method.
+  if (!findId(*forOp.getInductionVar(), &pos)) {
+    assert(false && "Value not found");
+    return failure();
+  }
+
+  int64_t step = forOp.getStep();
+  if (step != 1) {
+    if (!forOp.hasConstantLowerBound())
+      forOp.emitWarning("domain conservatively approximated");
+    else {
+      // Add constraints for the stride.
+      // (iv - lb) % step = 0 can be written as:
+      // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
+      // Add local variable 'q' and add the above equality.
+      // The first constraint is q = (iv - lb) floordiv step
+      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
+      int64_t lb = forOp.getConstantLowerBound();
+      dividend[pos] = 1;
+      dividend.back() -= lb;
+      addLocalFloorDiv(dividend, step);
+      // Second constraint: (iv - lb) - step * q = 0.
+      SmallVector<int64_t, 8> eq(getNumCols(), 0);
+      eq[pos] = 1;
+      eq.back() -= lb;
+      // For the local var just added above.
+      eq[getNumCols() - 2] = -step;
+      addEquality(eq);
+    }
+  }
+
+  if (forOp.hasConstantLowerBound()) {
+    addConstantLowerBound(pos, forOp.getConstantLowerBound());
+  } else {
+    // Non-constant lower bound case.
+    SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands().begin(),
+                                       forOp.getLowerBoundOperands().end());
+    if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands,
+                                    /*eq=*/false, /*lower=*/true)))
+      return failure();
+  }
+
+  if (forOp.hasConstantUpperBound()) {
+    addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
+    return success();
+  }
+  // Non-constant upper bound case.
+  SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands().begin(),
+                                     forOp.getUpperBoundOperands().end());
+  return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands,
+                              /*eq=*/false, /*lower=*/false);
+}
+
+// Searches for a constraint with a non-zero coefficient at 'colIdx' in
+// equality (isEq=true) or inequality (isEq=false) constraints.
+// Returns true and sets row found in search in 'rowIdx'.
+// Returns false otherwise.
+static bool
+findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints,
+                            unsigned colIdx, bool isEq, unsigned *rowIdx) {
+  auto at = [&](unsigned rowIdx) -> int64_t {
+    return isEq ? constraints.atEq(rowIdx, colIdx)
+                : constraints.atIneq(rowIdx, colIdx);
+  };
+  unsigned e =
+      isEq ? constraints.getNumEqualities() : constraints.getNumInequalities();
+  for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
+    if (at(*rowIdx) != 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Normalizes the coefficient values across all columns in 'rowIDx' by their
+// GCD in equality or inequality contraints as specified by 'isEq'.
+template <bool isEq>
+static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
+                                     unsigned rowIdx) {
+  auto at = [&](unsigned colIdx) -> int64_t {
+    return isEq ? constraints->atEq(rowIdx, colIdx)
+                : constraints->atIneq(rowIdx, colIdx);
+  };
+  uint64_t gcd = std::abs(at(0));
+  for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
+    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
+  }
+  if (gcd > 0 && gcd != 1) {
+    for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
+      int64_t v = at(j) / static_cast<int64_t>(gcd);
+      isEq ? constraints->atEq(rowIdx, j) = v
+           : constraints->atIneq(rowIdx, j) = v;
+    }
+  }
+}
+
+void FlatAffineConstraints::normalizeConstraintsByGCD() {
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+    normalizeConstraintByGCD</*isEq=*/true>(this, i);
+  }
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+    normalizeConstraintByGCD</*isEq=*/false>(this, i);
+  }
+}
+
+bool FlatAffineConstraints::hasConsistentState() const {
+  if (inequalities.size() != getNumInequalities() * numReservedCols)
+    return false;
+  if (equalities.size() != getNumEqualities() * numReservedCols)
+    return false;
+  if (ids.size() != getNumIds())
+    return false;
+
+  // Catches errors where numDims, numSymbols, numIds aren't consistent.
+  if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
+    return false;
+
+  return true;
+}
+
+/// Checks all rows of equality/inequality constraints for trivial
+/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
+/// after elimination. Returns 'true' if an invalid constraint is found;
+/// 'false' otherwise.
+bool FlatAffineConstraints::hasInvalidConstraint() const {
+  assert(hasConsistentState());
+  auto check = [&](bool isEq) -> bool {
+    unsigned numCols = getNumCols();
+    unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
+    for (unsigned i = 0, e = numRows; i < e; ++i) {
+      unsigned j;
+      for (j = 0; j < numCols - 1; ++j) {
+        int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
+        // Skip rows with non-zero variable coefficients.
+        if (v != 0)
+          break;
+      }
+      if (j < numCols - 1) {
+        continue;
+      }
+      // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
+      // Example invalid constraints include: '1 == 0' or '-1 >= 0'
+      int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
+      if ((isEq && v != 0) || (!isEq && v < 0)) {
+        return true;
+      }
+    }
+    return false;
+  };
+  if (check(/*isEq=*/true))
+    return true;
+  return check(/*isEq=*/false);
+}
+
+// Eliminate identifier from constraint at 'rowIdx' based on coefficient at
+// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
+// updated as they have already been eliminated.
+static void eliminateFromConstraint(FlatAffineConstraints *constraints,
+                                    unsigned rowIdx, unsigned pivotRow,
+                                    unsigned pivotCol, unsigned elimColStart,
+                                    bool isEq) {
+  // Skip if equality 'rowIdx' if same as 'pivotRow'.
+  if (isEq && rowIdx == pivotRow)
+    return;
+  auto at = [&](unsigned i, unsigned j) -> int64_t {
+    return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
+  };
+  int64_t leadCoeff = at(rowIdx, pivotCol);
+  // Skip if leading coefficient at 'rowIdx' is already zero.
+  if (leadCoeff == 0)
+    return;
+  int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
+  int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
+  int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
+  int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
+  int64_t rowMultiplier = lcm / std::abs(leadCoeff);
+
+  unsigned numCols = constraints->getNumCols();
+  for (unsigned j = 0; j < numCols; ++j) {
+    // Skip updating column 'j' if it was just eliminated.
+    if (j >= elimColStart && j < pivotCol)
+      continue;
+    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
+                rowMultiplier * at(rowIdx, j);
+    isEq ? constraints->atEq(rowIdx, j) = v
+         : constraints->atIneq(rowIdx, j) = v;
+  }
+}
+
+// Remove coefficients in column range [colStart, colLimit) in place.
+// This removes in data in the specified column range, and copies any
+// remaining valid data into place.
+static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
+                               unsigned colStart, unsigned colLimit,
+                               bool isEq) {
+  assert(colLimit <= constraints->getNumIds());
+  if (colLimit <= colStart)
+    return;
+
+  unsigned numCols = constraints->getNumCols();
+  unsigned numRows = isEq ? constraints->getNumEqualities()
+                          : constraints->getNumInequalities();
+  unsigned numToEliminate = colLimit - colStart;
+  for (unsigned r = 0, e = numRows; r < e; ++r) {
+    for (unsigned c = colLimit; c < numCols; ++c) {
+      if (isEq) {
+        constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
+      } else {
+        constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
+      }
+    }
+  }
+}
+
+// Removes identifiers in column range [idStart, idLimit), and copies any
+// remaining valid data into place, and updates member variables.
+void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
+  assert(idLimit < getNumCols() && "invalid id limit");
+
+  if (idStart >= idLimit)
+    return;
+
+  // We are going to be removing one or more identifiers from the range.
+  assert(idStart < numIds && "invalid idStart position");
+
+  // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
+  // Remove eliminated identifiers from equalities.
+  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
+
+  // Remove eliminated identifiers from inequalities.
+  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
+
+  // Update members numDims, numSymbols and numIds.
+  unsigned numDimsEliminated = 0;
+  unsigned numLocalsEliminated = 0;
+  unsigned numColsEliminated = idLimit - idStart;
+  if (idStart < numDims) {
+    numDimsEliminated = std::min(numDims, idLimit) - idStart;
+  }
+  // Check how many local id's were removed. Note that our identifier order is
+  // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
+  if (idLimit > numDims + numSymbols) {
+    numLocalsEliminated = std::min(
+        idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
+  }
+  unsigned numSymbolsEliminated =
+      numColsEliminated - numDimsEliminated - numLocalsEliminated;
+
+  numDims -= numDimsEliminated;
+  numSymbols -= numSymbolsEliminated;
+  numIds = numIds - numColsEliminated;
+
+  ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
+
+  // No resize necessary. numReservedCols remains the same.
+}
+
+/// Returns the position of the identifier that has the minimum <number of lower
+/// bounds> times <number of upper bounds> from the specified range of
+/// identifiers [start, end). It is often best to eliminate in the increasing
+/// order of these counts when doing Fourier-Motzkin elimination since FM adds
+/// that many new constraints.
+static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
+                                     unsigned start, unsigned end) {
+  assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
+
+  auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
+    unsigned numLb = 0;
+    unsigned numUb = 0;
+    for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+      if (cst.atIneq(r, pos) > 0) {
+        ++numLb;
+      } else if (cst.atIneq(r, pos) < 0) {
+        ++numUb;
+      }
+    }
+    return numLb * numUb;
+  };
+
+  unsigned minLoc = start;
+  unsigned min = getProductOfNumLowerUpperBounds(start);
+  for (unsigned c = start + 1; c < end; c++) {
+    unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
+    if (numLbUbProduct < min) {
+      min = numLbUbProduct;
+      minLoc = c;
+    }
+  }
+  return minLoc;
+}
+
+// Checks for emptiness of the set by eliminating identifiers successively and
+// using the GCD test (on all equality constraints) and checking for trivially
+// invalid constraints. Returns 'true' if the constraint system is found to be
+// empty; false otherwise.
+bool FlatAffineConstraints::isEmpty() const {
+  if (isEmptyByGCDTest() || hasInvalidConstraint())
+    return true;
+
+  // First, eliminate as many identifiers as possible using Gaussian
+  // elimination.
+  FlatAffineConstraints tmpCst(*this);
+  unsigned currentPos = 0;
+  while (currentPos < tmpCst.getNumIds()) {
+    tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
+    ++currentPos;
+    // We check emptiness through trivial checks after eliminating each ID to
+    // detect emptiness early. Since the checks isEmptyByGCDTest() and
+    // hasInvalidConstraint() are linear time and single sweep on the constraint
+    // buffer, this appears reasonable - but can optimize in the future.
+    if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
+      return true;
+  }
+
+  // Eliminate the remaining using FM.
+  for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
+    tmpCst.FourierMotzkinEliminate(
+        getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
+    // Check for a constraint explosion. This rarely happens in practice, but
+    // this check exists as a safeguard against improperly constructed
+    // constraint systems or artifically created arbitrarily complex systems
+    // that aren't the intended use case for FlatAffineConstraints. This is
+    // needed since FM has a worst case exponential complexity in theory.
+    if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
+      LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
+      return false;
+    }
+
+    // FM wouldn't have modified the equalities in any way. So no need to again
+    // run GCD test. Check for trivial invalid constraints.
+    if (tmpCst.hasInvalidConstraint())
+      return true;
+  }
+  return false;
+}
+
+// Runs the GCD test on all equality constraints. Returns 'true' if this test
+// fails on any equality. Returns 'false' otherwise.
+// This test can be used to disprove the existence of a solution. If it returns
+// true, no integer solution to the equality constraints can exist.
+//
+// GCD test definition:
+//
+// The equality constraint:
+//
+//  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
+//
+// has an integer solution iff:
+//
+//  GCD of c_1, c_2, ..., c_n divides c_0.
+//
+bool FlatAffineConstraints::isEmptyByGCDTest() const {
+  assert(hasConsistentState());
+  unsigned numCols = getNumCols();
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+    uint64_t gcd = std::abs(atEq(i, 0));
+    for (unsigned j = 1; j < numCols - 1; ++j) {
+      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
+    }
+    int64_t v = std::abs(atEq(i, numCols - 1));
+    if (gcd > 0 && (v % gcd != 0)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Tightens inequalities given that we are dealing with integer spaces. This is
+/// analogous to the GCD test but applied to inequalities. The constant term can
+/// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
+///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
+/// fast method - linear in the number of coefficients.
+// Example on how this affects practical cases: consider the scenario:
+// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
+// j >= 100 instead of the tighter (exact) j >= 128.
+void FlatAffineConstraints::GCDTightenInequalities() {
+  unsigned numCols = getNumCols();
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+    uint64_t gcd = std::abs(atIneq(i, 0));
+    for (unsigned j = 1; j < numCols - 1; ++j) {
+      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
+    }
+    if (gcd > 0 && gcd != 1) {
+      int64_t gcdI = static_cast<int64_t>(gcd);
+      // Tighten the constant term and normalize the constraint by the GCD.
+      atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
+      for (unsigned j = 0, e = numCols - 1; j < e; ++j)
+        atIneq(i, j) /= gcdI;
+    }
+  }
+}
+
+// Eliminates all identifer variables in column range [posStart, posLimit).
+// Returns the number of variables eliminated.
+unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
+                                                     unsigned posLimit) {
+  // Return if identifier positions to eliminate are out of range.
+  assert(posLimit <= numIds);
+  assert(hasConsistentState());
+
+  if (posStart >= posLimit)
+    return 0;
+
+  GCDTightenInequalities();
+
+  unsigned pivotCol = 0;
+  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
+    // Find a row which has a non-zero coefficient in column 'j'.
+    unsigned pivotRow;
+    if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
+                                     &pivotRow)) {
+      // No pivot row in equalities with non-zero at 'pivotCol'.
+      if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
+                                       &pivotRow)) {
+        // If inequalities are also non-zero in 'pivotCol', it can be
+        // eliminated.
+        continue;
+      }
+      break;
+    }
+
+    // Eliminate identifier at 'pivotCol' from each equality row.
+    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
+                              /*isEq=*/true);
+      normalizeConstraintByGCD</*isEq=*/true>(this, i);
+    }
+
+    // Eliminate identifier at 'pivotCol' from each inequality row.
+    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
+                              /*isEq=*/false);
+      normalizeConstraintByGCD</*isEq=*/false>(this, i);
+    }
+    removeEquality(pivotRow);
+    GCDTightenInequalities();
+  }
+  // Update position limit based on number eliminated.
+  posLimit = pivotCol;
+  // Remove eliminated columns from all constraints.
+  removeIdRange(posStart, posLimit);
+  return posLimit - posStart;
+}
+
+// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
+// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
+// could be detected as the floordiv of n. For eg:
+// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
+//                          id_r = id_n mod 4, id_q = id_n floordiv 4.
+// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
+// pre-detected at the caller.
+static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
+                        int64_t lbConst, int64_t ubConst,
+                        SmallVectorImpl<AffineExpr> *memo) {
+  assert(pos < cst.getNumIds() && "invalid position");
+
+  // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
+  // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
+  // and id_q the quotient when dividing id_n by the divisor.
+
+  if (lbConst != 0 || ubConst < 1)
+    return false;
+
+  int64_t divisor = ubConst + 1;
+
+  // Now check for: id_r =  id_n - divisor * id_q. As an example, we
+  // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
+  unsigned seenQuotient = 0, seenDividend = 0;
+  int quotientPos = -1, dividendPos = -1;
+  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+    // id_n should have coeff 1 or -1.
+    if (std::abs(cst.atEq(r, pos)) != 1)
+      continue;
+    // constant term should be 0.
+    if (cst.atEq(r, cst.getNumCols() - 1) != 0)
+      continue;
+    unsigned c, f;
+    int quotientSign = 1, dividendSign = 1;
+    for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
+      if (c == pos)
+        continue;
+      // The coefficient of the quotient should be +/-divisor.
+      // TODO(bondhugula): could be extended to detect an affine function for
+      // the quotient (i.e., the coeff could be a non-zero multiple of divisor).
+      int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
+      if (v == divisor || v == -divisor) {
+        seenQuotient++;
+        quotientPos = c;
+        quotientSign = v > 0 ? 1 : -1;
+      }
+      // The coefficient of the dividend should be +/-1.
+      // TODO(bondhugula): could be extended to detect an affine function of
+      // the other identifiers as the dividend.
+      else if (v == -1 || v == 1) {
+        seenDividend++;
+        dividendPos = c;
+        dividendSign = v < 0 ? 1 : -1;
+      } else if (cst.atEq(r, c) != 0) {
+        // Cannot be inferred as a mod since the constraint has a coefficient
+        // for an identifier that's neither a unit nor the divisor (see TODOs
+        // above).
+        break;
+      }
+    }
+    if (c < f)
+      // Cannot be inferred as a mod since the constraint has a coefficient for
+      // an identifier that's neither a unit nor the divisor (see TODOs above).
+      continue;
+
+    // We are looking for exactly one identifier as the dividend.
+    if (seenDividend == 1 && seenQuotient >= 1) {
+      if (!(*memo)[dividendPos])
+        return false;
+      // Successfully detected a mod.
+      (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
+      auto ub = cst.getConstantUpperBound(dividendPos);
+      if (ub.hasValue() && ub.getValue() < divisor)
+        // The mod can be optimized away.
+        (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
+      else
+        (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
+
+      if (seenQuotient == 1 && !(*memo)[quotientPos])
+        // Successfully detected a floordiv as well.
+        (*memo)[quotientPos] =
+            (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
+      return true;
+    }
+  }
+  return false;
+}
+
+// Gather lower and upper bounds for the pos^th identifier.
+static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
+                                         unsigned pos,
+                                         SmallVectorImpl<unsigned> *lbIndices,
+                                         SmallVectorImpl<unsigned> *ubIndices) {
+  assert(pos < cst.getNumIds() && "invalid position");
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    if (cst.atIneq(r, pos) >= 1) {
+      // Lower bound.
+      lbIndices->push_back(r);
+    } else if (cst.atIneq(r, pos) <= -1) {
+      // Upper bound.
+      ubIndices->push_back(r);
+    }
+  }
+}
+
+// Check if the pos^th identifier can be expressed as a floordiv of an affine
+// function of other identifiers (where the divisor is a positive constant).
+// For eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4.
+bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
+                      SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
+  assert(pos < cst.getNumIds() && "invalid position");
+
+  SmallVector<unsigned, 4> lbIndices, ubIndices;
+  getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
+
+  // Check if any lower bound, upper bound pair is of the form:
+  // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
+  // divisor * id <=  expr                    <-- Upper bound for 'id'
+  // Then, 'id' is equivalent to 'expr floordiv divisor'.  (where divisor > 1).
+  //
+  // For example, if -32*k + 16*i + j >= 0
+  //                  32*k - 16*i - j + 31 >= 0   <=>
+  //             k = ( 16*i + j ) floordiv 32
+  unsigned seenDividends = 0;
+  for (auto ubPos : ubIndices) {
+    for (auto lbPos : lbIndices) {
+      // Check if lower bound's constant term is 'divisor - 1'. The 'divisor'
+      // here is cst.atIneq(lbPos, pos) and we already know that it's positive
+      // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'.
+      if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1)
+        continue;
+      // Check if upper bound's constant term is 0.
+      if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
+        continue;
+      // For the remaining part, check if the lower bound expr's coeff's are
+      // negations of corresponding upper bound ones'.
+      unsigned c, f;
+      for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
+        if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
+          break;
+        if (c != pos && cst.atIneq(lbPos, c) != 0)
+          seenDividends++;
+      }
+      // Lb coeff's aren't negative of ub coeff's (for the non constant term
+      // part).
+      if (c < f)
+        continue;
+      if (seenDividends >= 1) {
+        // The divisor is the constant term of the lower bound expression.
+        // We already know that cst.atIneq(lbPos, pos) > 0.
+        int64_t divisor = cst.atIneq(lbPos, pos);
+        // Construct the dividend expression.
+        auto dividendExpr = getAffineConstantExpr(0, context);
+        unsigned c, f;
+        for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
+          if (c == pos)
+            continue;
+          int64_t ubVal = cst.atIneq(ubPos, c);
+          if (ubVal == 0)
+            continue;
+          if (!(*memo)[c])
+            break;
+          dividendExpr = dividendExpr + ubVal * (*memo)[c];
+        }
+        // Expression can't be constructed as it depends on a yet unknown
+        // identifier.
+        // TODO(mlir-team): Visit/compute the identifiers in an order so that
+        // this doesn't happen. More complex but much more efficient.
+        if (c < f)
+          continue;
+        // Successfully detected the floordiv.
+        (*memo)[pos] = dividendExpr.floorDiv(divisor);
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Fills an inequality row with the value 'val'.
+static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
+                                  int64_t val) {
+  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+    cst->atIneq(r, c) = val;
+  }
+}
+
+// Negates an inequality.
+static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
+  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+    cst->atIneq(r, c) = -cst->atIneq(r, c);
+  }
+}
+
+// A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
+// to check if a constraint is redundant.
+void FlatAffineConstraints::removeRedundantInequalities() {
+  SmallVector<bool, 32> redun(getNumInequalities(), false);
+  // To check if an inequality is redundant, we replace the inequality by its
+  // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
+  // system is empty. If it is, the inequality is redundant.
+  FlatAffineConstraints tmpCst(*this);
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    // Change the inequality to its complement.
+    negateInequality(&tmpCst, r);
+    tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
+    if (tmpCst.isEmpty()) {
+      redun[r] = true;
+      // Zero fill the redundant inequality.
+      fillInequality(this, r, /*val=*/0);
+      fillInequality(&tmpCst, r, /*val=*/0);
+    } else {
+      // Reverse the change (to avoid recreating tmpCst each time).
+      tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
+      negateInequality(&tmpCst, r);
+    }
+  }
+
+  // Scan to get rid of all rows marked redundant, in-place.
+  auto copyRow = [&](unsigned src, unsigned dest) {
+    if (src == dest)
+      return;
+    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
+      atIneq(dest, c) = atIneq(src, c);
+    }
+  };
+  unsigned pos = 0;
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (!redun[r])
+      copyRow(r, pos++);
+  }
+  inequalities.resize(numReservedCols * pos);
+}
+
+std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
+    unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
+    ArrayRef<AffineExpr> localExprs, MLIRContext *context) {
+  assert(pos + offset < getNumDimIds() && "invalid dim start pos");
+  assert(symStartPos >= (pos + offset) && "invalid sym start pos");
+  assert(getNumLocalIds() == localExprs.size() &&
+         "incorrect local exprs count");
+
+  SmallVector<unsigned, 4> lbIndices, ubIndices;
+  getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices);
+
+  /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
+  auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
+    b.clear();
+    for (unsigned i = 0, e = a.size(); i < e; ++i) {
+      if (i < offset || i >= offset + num)
+        b.push_back(a[i]);
+    }
+  };
+
+  SmallVector<int64_t, 8> lb, ub;
+  SmallVector<AffineExpr, 4> exprs;
+  unsigned dimCount = symStartPos - num;
+  unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
+  exprs.reserve(lbIndices.size());
+  // Lower bound expressions.
+  for (auto idx : lbIndices) {
+    auto ineq = getInequality(idx);
+    // Extract the lower bound (in terms of other coeff's + const), i.e., if
+    // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
+    // - 1.
+    addCoeffs(ineq, lb);
+    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
+    auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
+    exprs.push_back(expr);
+  }
+  auto lbMap =
+      exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
+
+  exprs.clear();
+  exprs.reserve(ubIndices.size());
+  // Upper bound expressions.
+  for (auto idx : ubIndices) {
+    auto ineq = getInequality(idx);
+    // Extract the upper bound (in terms of other coeff's + const).
+    addCoeffs(ineq, ub);
+    auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
+    // Upper bound is exclusive.
+    exprs.push_back(expr + 1);
+  }
+  auto ubMap =
+      exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
+
+  return {lbMap, ubMap};
+}
+
+/// Computes the lower and upper bounds of the first 'num' dimensional
+/// identifiers (starting at 'offset') as affine maps of the remaining
+/// identifiers (dimensional and symbolic identifiers). Local identifiers are
+/// themselves explicitly computed as affine functions of other identifiers in
+/// this process if needed.
+void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
+                                           MLIRContext *context,
+                                           SmallVectorImpl<AffineMap> *lbMaps,
+                                           SmallVectorImpl<AffineMap> *ubMaps) {
+  assert(num < getNumDimIds() && "invalid range");
+
+  // Basic simplification.
+  normalizeConstraintsByGCD();
+
+  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
+                          << " identifiers\n");
+  LLVM_DEBUG(dump());
+
+  // Record computed/detected identifiers.
+  SmallVector<AffineExpr, 8> memo(getNumIds());
+  // Initialize dimensional and symbolic identifiers.
+  for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
+    if (i < offset)
+      memo[i] = getAffineDimExpr(i, context);
+    else if (i >= offset + num)
+      memo[i] = getAffineDimExpr(i - num, context);
+  }
+  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
+    memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
+
+  bool changed;
+  do {
+    changed = false;
+    // Identify yet unknown identifiers as constants or mod's / floordiv's of
+    // other identifiers if possible.
+    for (unsigned pos = 0; pos < getNumIds(); pos++) {
+      if (memo[pos])
+        continue;
+
+      auto lbConst = getConstantLowerBound(pos);
+      auto ubConst = getConstantUpperBound(pos);
+      if (lbConst.hasValue() && ubConst.hasValue()) {
+        // Detect equality to a constant.
+        if (lbConst.getValue() == ubConst.getValue()) {
+          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
+          changed = true;
+          continue;
+        }
+
+        // Detect an identifier as modulo of another identifier w.r.t a
+        // constant.
+        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
+                        &memo)) {
+          changed = true;
+          continue;
+        }
+      }
+
+      // Detect an identifier as floordiv of another identifier w.r.t a
+      // constant.
+      if (detectAsFloorDiv(*this, pos, &memo, context)) {
+        changed = true;
+        continue;
+      }
+
+      // Detect an identifier as an expression of other identifiers.
+      unsigned idx;
+      if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
+        continue;
+      }
+
+      // Build AffineExpr solving for identifier 'pos' in terms of all others.
+      auto expr = getAffineConstantExpr(0, context);
+      unsigned j, e;
+      for (j = 0, e = getNumIds(); j < e; ++j) {
+        if (j == pos)
+          continue;
+        int64_t c = atEq(idx, j);
+        if (c == 0)
+          continue;
+        // If any of the involved IDs hasn't been found yet, we can't proceed.
+        if (!memo[j])
+          break;
+        expr = expr + memo[j] * c;
+      }
+      if (j < e)
+        // Can't construct expression as it depends on a yet uncomputed
+        // identifier.
+        continue;
+
+      // Add constant term to AffineExpr.
+      expr = expr + atEq(idx, getNumIds());
+      int64_t vPos = atEq(idx, pos);
+      assert(vPos != 0 && "expected non-zero here");
+      if (vPos > 0)
+        expr = (-expr).floorDiv(vPos);
+      else
+        // vPos < 0.
+        expr = expr.floorDiv(-vPos);
+      // Successfully constructed expression.
+      memo[pos] = expr;
+      changed = true;
+    }
+    // This loop is guaranteed to reach a fixed point - since once an
+    // identifier's explicit form is computed (in memo[pos]), it's not updated
+    // again.
+  } while (changed);
+
+  // Set the lower and upper bound maps for all the identifiers that were
+  // computed as affine expressions of the rest as the "detected expr" and
+  // "detected expr + 1" respectively; set the undetected ones to null.
+  Optional<FlatAffineConstraints> tmpClone;
+  for (unsigned pos = 0; pos < num; pos++) {
+    unsigned numMapDims = getNumDimIds() - num;
+    unsigned numMapSymbols = getNumSymbolIds();
+    AffineExpr expr = memo[pos + offset];
+    if (expr)
+      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
+
+    AffineMap &lbMap = (*lbMaps)[pos];
+    AffineMap &ubMap = (*ubMaps)[pos];
+
+    if (expr) {
+      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
+      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
+    } else {
+      // TODO(bondhugula): Whenever there are local identifiers in the
+      // dependence constraints, we'll conservatively over-approximate, since we
+      // don't always explicitly compute them above (in the while loop).
+      if (getNumLocalIds() == 0) {
+        // Work on a copy so that we don't update this constraint system.
+        if (!tmpClone) {
+          tmpClone.emplace(FlatAffineConstraints(*this));
+          // Removing redudnant inequalities is necessary so that we don't get
+          // redundant loop bounds.
+          tmpClone->removeRedundantInequalities();
+        }
+        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
+            pos, offset, num, getNumDimIds(), {}, context);
+      }
+
+      // If the above fails, we'll just use the constant lower bound and the
+      // constant upper bound (if they exist) as the slice bounds.
+      // TODO(b/126426796): being conservative for the moment in cases that
+      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
+      // fixed (b/126426796).
+      if (!lbMap || lbMap.getNumResults() > 1) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice lb\n");
+        auto lbConst = getConstantLowerBound(pos + offset);
+        if (lbConst.hasValue()) {
+          lbMap = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(lbConst.getValue(), context));
+        }
+      }
+      if (!ubMap || ubMap.getNumResults() > 1) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice ub\n");
+        auto ubConst = getConstantUpperBound(pos + offset);
+        if (ubConst.hasValue()) {
+          (ubMap) = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(ubConst.getValue() + 1, context));
+        }
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
+    LLVM_DEBUG(lbMap.dump(););
+    LLVM_DEBUG(llvm::dbgs()
+               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
+    LLVM_DEBUG(ubMap.dump(););
+  }
+}
+
+LogicalResult
+FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
+                                            ArrayRef<Value *> boundOperands,
+                                            bool eq, bool lower) {
+  assert(pos < getNumDimAndSymbolIds() && "invalid position");
+  // Equality follows the logic of lower bound except that we add an equality
+  // instead of an inequality.
+  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
+  if (eq)
+    lower = true;
+
+  // Fully commpose map and operands; canonicalize and simplify so that we
+  // transitively get to terminal symbols or loop IVs.
+  auto map = boundMap;
+  SmallVector<Value *, 4> operands(boundOperands.begin(), boundOperands.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+  canonicalizeMapAndOperands(&map, &operands);
+  for (auto *operand : operands)
+    addInductionVarOrTerminalSymbol(operand);
+
+  FlatAffineConstraints localVarCst;
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
+    LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
+    return failure();
+  }
+
+  // Merge and align with localVarCst.
+  if (localVarCst.getNumLocalIds() > 0) {
+    // Set values for localVarCst.
+    localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
+    for (auto *operand : operands) {
+      unsigned pos;
+      if (findId(*operand, &pos)) {
+        if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
+          // If the local var cst has this as a dim, turn it into its symbol.
+          turnDimIntoSymbol(&localVarCst, *operand);
+        } else if (pos < getNumDimIds()) {
+          // Or vice versa.
+          turnSymbolIntoDim(&localVarCst, *operand);
+        }
+      }
+    }
+    mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
+    append(localVarCst);
+  }
+
+  // Record positions of the operands in the constraint system. Need to do
+  // this here since the constraint system changes after a bound is added.
+  SmallVector<unsigned, 8> positions;
+  unsigned numOperands = operands.size();
+  for (auto *operand : operands) {
+    unsigned pos;
+    if (!findId(*operand, &pos))
+      assert(0 && "expected to be found");
+    positions.push_back(pos);
+  }
+
+  for (const auto &flatExpr : flatExprs) {
+    SmallVector<int64_t, 4> ineq(getNumCols(), 0);
+    ineq[pos] = lower ? 1 : -1;
+    // Dims and symbols.
+    for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
+      ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
+    }
+    // Copy over the local id coefficients.
+    unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
+    for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
+         jj++, j++) {
+      ineq[j] =
+          lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
+    }
+    // Constant term.
+    ineq[getNumCols() - 1] =
+        lower ? -flatExpr[flatExpr.size() - 1]
+              // Upper bound in flattenedExpr is an exclusive one.
+              : flatExpr[flatExpr.size() - 1] - 1;
+    eq ? addEquality(ineq) : addInequality(ineq);
+  }
+  return success();
+}
+
+// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
+// bounds in 'ubMaps' to each value in `values' that appears in the constraint
+// system. Note that both lower/upper bounds share the same operand list
+// 'operands'.
+// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
+// skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
+// Note that both lower/upper bounds use operands from 'operands'.
+// Returns failure for unimplemented cases such as semi-affine expressions or
+// expressions with mod/floordiv.
+LogicalResult FlatAffineConstraints::addSliceBounds(
+    ArrayRef<Value *> values, ArrayRef<AffineMap> lbMaps,
+    ArrayRef<AffineMap> ubMaps, ArrayRef<Value *> operands) {
+  assert(values.size() == lbMaps.size());
+  assert(lbMaps.size() == ubMaps.size());
+
+  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
+    unsigned pos;
+    if (!findId(*values[i], &pos))
+      continue;
+
+    AffineMap lbMap = lbMaps[i];
+    AffineMap ubMap = ubMaps[i];
+    assert(!lbMap || lbMap.getNumInputs() == operands.size());
+    assert(!ubMap || ubMap.getNumInputs() == operands.size());
+
+    // Check if this slice is just an equality along this dimension.
+    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
+        ubMap.getNumResults() == 1 &&
+        lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
+      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
+                                      /*lower=*/true)))
+        return failure();
+      continue;
+    }
+
+    if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
+                                             /*lower=*/true)))
+      return failure();
+
+    if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
+                                             /*lower=*/false)))
+      return failure();
+  }
+  return success();
+}
+
+void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
+  assert(eq.size() == getNumCols());
+  unsigned offset = equalities.size();
+  equalities.resize(equalities.size() + numReservedCols);
+  std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
+}
+
+void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
+  assert(inEq.size() == getNumCols());
+  unsigned offset = inequalities.size();
+  inequalities.resize(inequalities.size() + numReservedCols);
+  std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
+}
+
+void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
+  assert(pos < getNumCols());
+  unsigned offset = inequalities.size();
+  inequalities.resize(inequalities.size() + numReservedCols);
+  std::fill(inequalities.begin() + offset,
+            inequalities.begin() + offset + getNumCols(), 0);
+  inequalities[offset + pos] = 1;
+  inequalities[offset + getNumCols() - 1] = -lb;
+}
+
+void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
+  assert(pos < getNumCols());
+  unsigned offset = inequalities.size();
+  inequalities.resize(inequalities.size() + numReservedCols);
+  std::fill(inequalities.begin() + offset,
+            inequalities.begin() + offset + getNumCols(), 0);
+  inequalities[offset + pos] = -1;
+  inequalities[offset + getNumCols() - 1] = ub;
+}
+
+void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
+                                                  int64_t lb) {
+  assert(expr.size() == getNumCols());
+  unsigned offset = inequalities.size();
+  inequalities.resize(inequalities.size() + numReservedCols);
+  std::fill(inequalities.begin() + offset,
+            inequalities.begin() + offset + getNumCols(), 0);
+  std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
+  inequalities[offset + getNumCols() - 1] += -lb;
+}
+
+void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
+                                                  int64_t ub) {
+  assert(expr.size() == getNumCols());
+  unsigned offset = inequalities.size();
+  inequalities.resize(inequalities.size() + numReservedCols);
+  std::fill(inequalities.begin() + offset,
+            inequalities.begin() + offset + getNumCols(), 0);
+  for (unsigned i = 0, e = getNumCols(); i < e; i++) {
+    inequalities[offset + i] = -expr[i];
+  }
+  inequalities[offset + getNumCols() - 1] += ub;
+}
+
+/// Adds a new local identifier as the floordiv of an affine function of other
+/// identifiers, the coefficients of which are provided in 'dividend' and with
+/// respect to a positive constant 'divisor'. Two constraints are added to the
+/// system to capture equivalence with the floordiv.
+///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
+void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
+                                             int64_t divisor) {
+  assert(dividend.size() == getNumCols() && "incorrect dividend size");
+  assert(divisor > 0 && "positive divisor expected");
+
+  addLocalId(getNumLocalIds());
+
+  // Add two constraints for this new identifier 'q'.
+  SmallVector<int64_t, 8> bound(dividend.size() + 1);
+
+  // dividend - q * divisor >= 0
+  std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
+            bound.begin());
+  bound.back() = dividend.back();
+  bound[getNumIds() - 1] = -divisor;
+  addInequality(bound);
+
+  // -dividend +qdivisor * q + divisor - 1 >= 0
+  std::transform(bound.begin(), bound.end(), bound.begin(),
+                 std::negate<int64_t>());
+  bound[bound.size() - 1] += divisor - 1;
+  addInequality(bound);
+}
+
+bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const {
+  unsigned i = 0;
+  for (const auto &mayBeId : ids) {
+    if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
+      *pos = i;
+      return true;
+    }
+    i++;
+  }
+  return false;
+}
+
+bool FlatAffineConstraints::containsId(Value &id) const {
+  return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) {
+    return mayBeId.hasValue() && mayBeId.getValue() == &id;
+  });
+}
+
+void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
+  assert(newSymbolCount <= numDims + numSymbols &&
+         "invalid separation position");
+  numDims = numDims + numSymbols - newSymbolCount;
+  numSymbols = newSymbolCount;
+}
+
+/// Sets the specified identifer to a constant value.
+void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
+  unsigned offset = equalities.size();
+  equalities.resize(equalities.size() + numReservedCols);
+  std::fill(equalities.begin() + offset,
+            equalities.begin() + offset + getNumCols(), 0);
+  equalities[offset + pos] = 1;
+  equalities[offset + getNumCols() - 1] = -val;
+}
+
+/// Sets the specified identifer to a constant value; asserts if the id is not
+/// found.
+void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) {
+  unsigned pos;
+  if (!findId(id, &pos))
+    // This is a pre-condition for this method.
+    assert(0 && "id not found");
+  setIdToConstant(pos, val);
+}
+
+void FlatAffineConstraints::removeEquality(unsigned pos) {
+  unsigned numEqualities = getNumEqualities();
+  assert(pos < numEqualities);
+  unsigned outputIndex = pos * numReservedCols;
+  unsigned inputIndex = (pos + 1) * numReservedCols;
+  unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
+  std::copy(equalities.begin() + inputIndex,
+            equalities.begin() + inputIndex + numElemsToCopy,
+            equalities.begin() + outputIndex);
+  equalities.resize(equalities.size() - numReservedCols);
+}
+
+/// Finds an equality that equates the specified identifier to a constant.
+/// Returns the position of the equality row. If 'symbolic' is set to true,
+/// symbols are also treated like a constant, i.e., an affine function of the
+/// symbols is also treated like a constant.
+static int findEqualityToConstant(const FlatAffineConstraints &cst,
+                                  unsigned pos, bool symbolic = false) {
+  assert(pos < cst.getNumIds() && "invalid position");
+  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+    int64_t v = cst.atEq(r, pos);
+    if (v * v != 1)
+      continue;
+    unsigned c;
+    unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
+    // This checks for zeros in all positions other than 'pos' in [0, f)
+    for (c = 0; c < f; c++) {
+      if (c == pos)
+        continue;
+      if (cst.atEq(r, c) != 0) {
+        // Dependent on another identifier.
+        break;
+      }
+    }
+    if (c == f)
+      // Equality is free of other identifiers.
+      return r;
+  }
+  return -1;
+}
+
+void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
+  assert(pos < getNumIds() && "invalid position");
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
+  }
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
+  }
+  removeId(pos);
+}
+
+LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
+  assert(pos < getNumIds() && "invalid position");
+  int rowIdx;
+  if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
+    return failure();
+
+  // atEq(rowIdx, pos) is either -1 or 1.
+  assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
+  int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
+  setAndEliminate(pos, constVal);
+  return success();
+}
+
+void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
+  for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
+    if (failed(constantFoldId(t)))
+      t++;
+  }
+}
+
+/// Returns the extent (upper bound - lower bound) of the specified
+/// identifier if it is found to be a constant; returns None if it's not a
+/// constant. This methods treats symbolic identifiers specially, i.e.,
+/// it looks for constant differences between affine expressions involving
+/// only the symbolic identifiers. See comments at function definition for
+/// example. 'lb', if provided, is set to the lower bound associated with the
+/// constant difference. Note that 'lb' is purely symbolic and thus will contain
+/// the coefficients of the symbolic identifiers and the constant coefficient.
+//  Egs: 0 <= i <= 15, return 16.
+//       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
+//       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
+//       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
+//       ceil(s0 - 7 / 8) = floor(s0 / 8)).
+Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
+    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor,
+    SmallVectorImpl<int64_t> *ub) const {
+  assert(pos < getNumDimIds() && "Invalid identifier position");
+  assert(getNumLocalIds() == 0);
+
+  // TODO(bondhugula): eliminate all remaining dimensional identifiers (other
+  // than the one at 'pos' to make this more powerful. Not needed for
+  // hyper-rectangular spaces.
+
+  // Find an equality for 'pos'^th identifier that equates it to some function
+  // of the symbolic identifiers (+ constant).
+  int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true);
+  if (eqRow != -1) {
+    // This identifier can only take a single value.
+    if (lb) {
+      // Set lb to the symbolic value.
+      lb->resize(getNumSymbolIds() + 1);
+      if (ub)
+        ub->resize(getNumSymbolIds() + 1);
+      for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
+        int64_t v = atEq(eqRow, pos);
+        // atEq(eqRow, pos) is either -1 or 1.
+        assert(v * v == 1);
+        (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v
+                         : -atEq(eqRow, getNumDimIds() + c) / v;
+        // Since this is an equality, ub = lb.
+        if (ub)
+          (*ub)[c] = (*lb)[c];
+      }
+      assert(lbFloorDivisor &&
+             "both lb and divisor or none should be provided");
+      *lbFloorDivisor = 1;
+    }
+    return 1;
+  }
+
+  // Check if the identifier appears at all in any of the inequalities.
+  unsigned r, e;
+  for (r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, pos) != 0)
+      break;
+  }
+  if (r == e)
+    // If it doesn't, there isn't a bound on it.
+    return None;
+
+  // Positions of constraints that are lower/upper bounds on the variable.
+  SmallVector<unsigned, 4> lbIndices, ubIndices;
+
+  // Gather all symbolic lower bounds and upper bounds of the variable. Since
+  // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a
+  // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    unsigned c, f;
+    for (c = 0, f = getNumDimIds(); c < f; c++) {
+      if (c != pos && atIneq(r, c) != 0)
+        break;
+    }
+    if (c < getNumDimIds())
+      // Not a pure symbolic bound.
+      continue;
+    if (atIneq(r, pos) >= 1)
+      // Lower bound.
+      lbIndices.push_back(r);
+    else if (atIneq(r, pos) <= -1)
+      // Upper bound.
+      ubIndices.push_back(r);
+  }
+
+  // TODO(bondhugula): eliminate other dimensional identifiers to make this more
+  // powerful. Not needed for hyper-rectangular iteration spaces.
+
+  Optional<int64_t> minDiff = None;
+  unsigned minLbPosition, minUbPosition;
+  for (auto ubPos : ubIndices) {
+    for (auto lbPos : lbIndices) {
+      // Look for a lower bound and an upper bound that only differ by a
+      // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
+      // For example, if ii is the pos^th variable, we are looking for
+      // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
+      // minimum among all such constant differences is kept since that's the
+      // constant bounding the extent of the pos^th variable.
+      unsigned j, e;
+      for (j = 0, e = getNumCols() - 1; j < e; j++)
+        if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
+          break;
+        }
+      if (j < getNumCols() - 1)
+        continue;
+      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
+                                 atIneq(lbPos, getNumCols() - 1) + 1,
+                             atIneq(lbPos, pos));
+      if (minDiff == None || diff < minDiff) {
+        minDiff = diff;
+        minLbPosition = lbPos;
+        minUbPosition = ubPos;
+      }
+    }
+  }
+  if (lb && minDiff.hasValue()) {
+    // Set lb to the symbolic lower bound.
+    lb->resize(getNumSymbolIds() + 1);
+    if (ub)
+      ub->resize(getNumSymbolIds() + 1);
+    // The lower bound is the ceildiv of the lb constraint over the coefficient
+    // of the variable at 'pos'. We express the ceildiv equivalently as a floor
+    // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
+    // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
+    *lbFloorDivisor = atIneq(minLbPosition, pos);
+    assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
+    for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
+      (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
+    }
+    if (ub) {
+      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
+        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
+    }
+    // The lower bound leads to a ceildiv while the upper bound is a floordiv
+    // whenever the cofficient at pos != 1. ceildiv (val / d) = floordiv (val +
+    // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
+    // the constant term for the lower bound.
+    (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
+  }
+  return minDiff;
+}
+
+template <bool isLower>
+Optional<int64_t>
+FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
+  assert(pos < getNumIds() && "invalid position");
+  // Project to 'pos'.
+  projectOut(0, pos);
+  projectOut(1, getNumIds() - 1);
+  // Check if there's an equality equating the '0'^th identifier to a constant.
+  int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
+  if (eqRowIdx != -1)
+    // atEq(rowIdx, 0) is either -1 or 1.
+    return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
+
+  // Check if the identifier appears at all in any of the inequalities.
+  unsigned r, e;
+  for (r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, 0) != 0)
+      break;
+  }
+  if (r == e)
+    // If it doesn't, there isn't a bound on it.
+    return None;
+
+  Optional<int64_t> minOrMaxConst = None;
+
+  // Take the max across all const lower bounds (or min across all constant
+  // upper bounds).
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (isLower) {
+      if (atIneq(r, 0) <= 0)
+        // Not a lower bound.
+        continue;
+    } else if (atIneq(r, 0) >= 0) {
+      // Not an upper bound.
+      continue;
+    }
+    unsigned c, f;
+    for (c = 0, f = getNumCols() - 1; c < f; c++)
+      if (c != 0 && atIneq(r, c) != 0)
+        break;
+    if (c < getNumCols() - 1)
+      // Not a constant bound.
+      continue;
+
+    int64_t boundConst =
+        isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
+                : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
+    if (isLower) {
+      if (minOrMaxConst == None || boundConst > minOrMaxConst)
+        minOrMaxConst = boundConst;
+    } else {
+      if (minOrMaxConst == None || boundConst < minOrMaxConst)
+        minOrMaxConst = boundConst;
+    }
+  }
+  return minOrMaxConst;
+}
+
+Optional<int64_t>
+FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
+  FlatAffineConstraints tmpCst(*this);
+  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+}
+
+Optional<int64_t>
+FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
+  FlatAffineConstraints tmpCst(*this);
+  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+}
+
+// A simple (naive and conservative) check for hyper-rectangularlity.
+bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
+                                               unsigned num) const {
+  assert(pos < getNumCols() - 1);
+  // Check for two non-zero coefficients in the range [pos, pos + sum).
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    unsigned sum = 0;
+    for (unsigned c = pos; c < pos + num; c++) {
+      if (atIneq(r, c) != 0)
+        sum++;
+    }
+    if (sum > 1)
+      return false;
+  }
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    unsigned sum = 0;
+    for (unsigned c = pos; c < pos + num; c++) {
+      if (atEq(r, c) != 0)
+        sum++;
+    }
+    if (sum > 1)
+      return false;
+  }
+  return true;
+}
+
+void FlatAffineConstraints::print(raw_ostream &os) const {
+  assert(hasConsistentState());
+  os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
+     << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
+     << " constraints)\n";
+  os << "(";
+  for (unsigned i = 0, e = getNumIds(); i < e; i++) {
+    if (ids[i] == None)
+      os << "None ";
+    else
+      os << "Value ";
+  }
+  os << " const)\n";
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+    for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
+      os << atEq(i, j) << " ";
+    }
+    os << "= 0\n";
+  }
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+    for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
+      os << atIneq(i, j) << " ";
+    }
+    os << ">= 0\n";
+  }
+  os << '\n';
+}
+
+void FlatAffineConstraints::dump() const { print(llvm::errs()); }
+
+/// Removes duplicate constraints, trivially true constraints, and constraints
+/// that can be detected as redundant as a result of differing only in their
+/// constant term part. A constraint of the form <non-negative constant> >= 0 is
+/// considered trivially true.
+//  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
+//  remove duplicates in place.
+void FlatAffineConstraints::removeTrivialRedundancy() {
+  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
+
+  // A map used to detect redundancy stemming from constraints that only differ
+  // in their constant term. The value stored is <row position, const term>
+  // for a given row.
+  SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
+      rowsWithoutConstTerm;
+
+  // Check if constraint is of the form <non-negative-constant> >= 0.
+  auto isTriviallyValid = [&](unsigned r) -> bool {
+    for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
+      if (atIneq(r, c) != 0)
+        return false;
+    }
+    return atIneq(r, getNumCols() - 1) >= 0;
+  };
+
+  // Detect and mark redundant constraints.
+  SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    int64_t *rowStart = inequalities.data() + numReservedCols * r;
+    auto row = ArrayRef<int64_t>(rowStart, getNumCols());
+    if (isTriviallyValid(r) || !rowSet.insert(row).second) {
+      redunIneq[r] = true;
+      continue;
+    }
+
+    // Among constraints that only differ in the constant term part, mark
+    // everything other than the one with the smallest constant term redundant.
+    // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
+    // former two are redundant).
+    int64_t constTerm = atIneq(r, getNumCols() - 1);
+    auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
+    const auto &ret =
+        rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
+    if (!ret.second) {
+      // Check if the other constraint has a higher constant term.
+      auto &val = ret.first->second;
+      if (val.second > constTerm) {
+        // The stored row is redundant. Mark it so, and update with this one.
+        redunIneq[val.first] = true;
+        val = {r, constTerm};
+      } else {
+        // The one stored makes this one redundant.
+        redunIneq[r] = true;
+      }
+    }
+  }
+
+  auto copyRow = [&](unsigned src, unsigned dest) {
+    if (src == dest)
+      return;
+    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
+      atIneq(dest, c) = atIneq(src, c);
+    }
+  };
+
+  // Scan to get rid of all rows marked redundant, in-place.
+  unsigned pos = 0;
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (!redunIneq[r])
+      copyRow(r, pos++);
+  }
+  inequalities.resize(numReservedCols * pos);
+
+  // TODO(bondhugula): consider doing this for equalities as well, but probably
+  // not worth the savings.
+}
+
+void FlatAffineConstraints::clearAndCopyFrom(
+    const FlatAffineConstraints &other) {
+  FlatAffineConstraints copy(other);
+  std::swap(*this, copy);
+  assert(copy.getNumIds() == copy.getIds().size());
+}
+
+void FlatAffineConstraints::removeId(unsigned pos) {
+  removeIdRange(pos, pos + 1);
+}
+
+static std::pair<unsigned, unsigned>
+getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
+  unsigned numDims = cst.getNumDimIds();
+  unsigned numSymbols = cst.getNumSymbolIds();
+  unsigned newNumDims, newNumSymbols;
+  if (pos < numDims) {
+    newNumDims = numDims - 1;
+    newNumSymbols = numSymbols;
+  } else if (pos < numDims + numSymbols) {
+    assert(numSymbols >= 1);
+    newNumDims = numDims;
+    newNumSymbols = numSymbols - 1;
+  } else {
+    newNumDims = numDims;
+    newNumSymbols = numSymbols;
+  }
+  return {newNumDims, newNumSymbols};
+}
+
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "fm"
+
+/// Eliminates identifier at the specified position using Fourier-Motzkin
+/// variable elimination. This technique is exact for rational spaces but
+/// conservative (in "rare" cases) for integer spaces. The operation corresponds
+/// to a projection operation yielding the (convex) set of integer points
+/// contained in the rational shadow of the set. An emptiness test that relies
+/// on this method will guarantee emptiness, i.e., it disproves the existence of
+/// a solution if it says it's empty.
+/// If a non-null isResultIntegerExact is passed, it is set to true if the
+/// result is also integer exact. If it's set to false, the obtained solution
+/// *may* not be exact, i.e., it may contain integer points that do not have an
+/// integer pre-image in the original set.
+///
+/// Eg:
+/// j >= 0, j <= i + 1
+/// i >= 0, i <= N + 1
+/// Eliminating i yields,
+///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
+///
+/// If darkShadow = true, this method computes the dark shadow on elimination;
+/// the dark shadow is a convex integer subset of the exact integer shadow. A
+/// non-empty dark shadow proves the existence of an integer solution. The
+/// elimination in such a case could however be an under-approximation, and thus
+/// should not be used for scanning sets or used by itself for dependence
+/// checking.
+///
+/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
+///            ^
+///            |
+///            | * * * * o o
+///         i  | * * o o o o
+///            | o * * * * *
+///            --------------->
+///                 j ->
+///
+/// Eliminating i from this system (projecting on the j dimension):
+/// rational shadow / integer light shadow:  1 <= j <= 6
+/// dark shadow:                             3 <= j <= 6
+/// exact integer shadow:                    j = 1 \union  3 <= j <= 6
+/// holes/splinters:                         j = 2
+///
+/// darkShadow = false, isResultIntegerExact = nullptr are default values.
+// TODO(bondhugula): a slight modification to yield dark shadow version of FM
+// (tightened), which can prove the existence of a solution if there is one.
+void FlatAffineConstraints::FourierMotzkinEliminate(
+    unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
+  LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
+  LLVM_DEBUG(dump());
+  assert(pos < getNumIds() && "invalid position");
+  assert(hasConsistentState());
+
+  // Check if this identifier can be eliminated through a substitution.
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    if (atEq(r, pos) != 0) {
+      // Use Gaussian elimination here (since we have an equality).
+      LogicalResult ret = gaussianEliminateId(pos);
+      (void)ret;
+      assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
+      LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
+      LLVM_DEBUG(dump());
+      return;
+    }
+  }
+
+  // A fast linear time tightening.
+  GCDTightenInequalities();
+
+  // Check if the identifier appears at all in any of the inequalities.
+  unsigned r, e;
+  for (r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, pos) != 0)
+      break;
+  }
+  if (r == getNumInequalities()) {
+    // If it doesn't appear, just remove the column and return.
+    // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
+    removeId(pos);
+    LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+    LLVM_DEBUG(dump());
+    return;
+  }
+
+  // Positions of constraints that are lower bounds on the variable.
+  SmallVector<unsigned, 4> lbIndices;
+  // Positions of constraints that are lower bounds on the variable.
+  SmallVector<unsigned, 4> ubIndices;
+  // Positions of constraints that do not involve the variable.
+  std::vector<unsigned> nbIndices;
+  nbIndices.reserve(getNumInequalities());
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, pos) == 0) {
+      // Id does not appear in bound.
+      nbIndices.push_back(r);
+    } else if (atIneq(r, pos) >= 1) {
+      // Lower bound.
+      lbIndices.push_back(r);
+    } else {
+      // Upper bound.
+      ubIndices.push_back(r);
+    }
+  }
+
+  // Set the number of dimensions, symbols in the resulting system.
+  const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
+  unsigned newNumDims = dimsSymbols.first;
+  unsigned newNumSymbols = dimsSymbols.second;
+
+  SmallVector<Optional<Value *>, 8> newIds;
+  newIds.reserve(numIds - 1);
+  newIds.append(ids.begin(), ids.begin() + pos);
+  newIds.append(ids.begin() + pos + 1, ids.end());
+
+  /// Create the new system which has one identifier less.
+  FlatAffineConstraints newFac(
+      lbIndices.size() * ubIndices.size() + nbIndices.size(),
+      getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
+      /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
+
+  assert(newFac.getIds().size() == newFac.getNumIds());
+
+  // This will be used to check if the elimination was integer exact.
+  unsigned lcmProducts = 1;
+
+  // Let x be the variable we are eliminating.
+  // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
+  // that c_l, c_u >= 1) we have:
+  // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
+  // We thus generate a constraint:
+  // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
+  // Note if c_l = c_u = 1, all integer points captured by the resulting
+  // constraint correspond to integer points in the original system (i.e., they
+  // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
+  // integer exact.
+  for (auto ubPos : ubIndices) {
+    for (auto lbPos : lbIndices) {
+      SmallVector<int64_t, 4> ineq;
+      ineq.reserve(newFac.getNumCols());
+      int64_t lbCoeff = atIneq(lbPos, pos);
+      // Note that in the comments above, ubCoeff is the negation of the
+      // coefficient in the canonical form as the view taken here is that of the
+      // term being moved to the other size of '>='.
+      int64_t ubCoeff = -atIneq(ubPos, pos);
+      // TODO(bondhugula): refactor this loop to avoid all branches inside.
+      for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+        if (l == pos)
+          continue;
+        assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
+        int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
+        ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
+                       atIneq(lbPos, l) * (lcm / lbCoeff));
+        lcmProducts *= lcm;
+      }
+      if (darkShadow) {
+        // The dark shadow is a convex subset of the exact integer shadow. If
+        // there is a point here, it proves the existence of a solution.
+        ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
+      }
+      // TODO: we need to have a way to add inequalities in-place in
+      // FlatAffineConstraints instead of creating and copying over.
+      newFac.addInequality(ineq);
+    }
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
+                          << "\n");
+  if (lcmProducts == 1 && isResultIntegerExact)
+    *isResultIntegerExact = 1;
+
+  // Copy over the constraints not involving this variable.
+  for (auto nbPos : nbIndices) {
+    SmallVector<int64_t, 4> ineq;
+    ineq.reserve(getNumCols() - 1);
+    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+      if (l == pos)
+        continue;
+      ineq.push_back(atIneq(nbPos, l));
+    }
+    newFac.addInequality(ineq);
+  }
+
+  assert(newFac.getNumConstraints() ==
+         lbIndices.size() * ubIndices.size() + nbIndices.size());
+
+  // Copy over the equalities.
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    SmallVector<int64_t, 4> eq;
+    eq.reserve(newFac.getNumCols());
+    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+      if (l == pos)
+        continue;
+      eq.push_back(atEq(r, l));
+    }
+    newFac.addEquality(eq);
+  }
+
+  // GCD tightening and normalization allows detection of more trivially
+  // redundant constraints.
+  newFac.GCDTightenInequalities();
+  newFac.normalizeConstraintsByGCD();
+  newFac.removeTrivialRedundancy();
+  clearAndCopyFrom(newFac);
+  LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+  LLVM_DEBUG(dump());
+}
+
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "affine-structures"
+
+void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
+  if (num == 0)
+    return;
+
+  // 'pos' can be at most getNumCols() - 2 if num > 0.
+  assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
+  assert(pos + num < getNumCols() && "invalid range");
+
+  // Eliminate as many identifiers as possible using Gaussian elimination.
+  unsigned currentPos = pos;
+  unsigned numToEliminate = num;
+  unsigned numGaussianEliminated = 0;
+
+  while (currentPos < getNumIds()) {
+    unsigned curNumEliminated =
+        gaussianEliminateIds(currentPos, currentPos + numToEliminate);
+    ++currentPos;
+    numToEliminate -= curNumEliminated + 1;
+    numGaussianEliminated += curNumEliminated;
+  }
+
+  // Eliminate the remaining using Fourier-Motzkin.
+  for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
+    unsigned numToEliminate = num - numGaussianEliminated - i;
+    FourierMotzkinEliminate(
+        getBestIdToEliminate(*this, pos, pos + numToEliminate));
+  }
+
+  // Fast/trivial simplifications.
+  GCDTightenInequalities();
+  // Normalize constraints after tightening since the latter impacts this, but
+  // not the other way round.
+  normalizeConstraintsByGCD();
+}
+
+void FlatAffineConstraints::projectOut(Value *id) {
+  unsigned pos;
+  bool ret = findId(*id, &pos);
+  assert(ret);
+  (void)ret;
+  FourierMotzkinEliminate(pos);
+}
+
+bool FlatAffineConstraints::isRangeOneToOne(unsigned start,
+                                            unsigned limit) const {
+  assert(start <= getNumIds() - 1 && "invalid start position");
+  assert(limit > start && limit <= getNumIds() && "invalid limit");
+
+  FlatAffineConstraints tmpCst(*this);
+
+  if (start != 0) {
+    // Move [start, limit) to the left.
+    for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
+      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
+        if (c >= start && c < limit)
+          tmpCst.atIneq(r, c - start) = atIneq(r, c);
+        else if (c < start)
+          tmpCst.atIneq(r, c + limit - start) = atIneq(r, c);
+        else
+          tmpCst.atIneq(r, c) = atIneq(r, c);
+      }
+    }
+    for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) {
+      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
+        if (c >= start && c < limit)
+          tmpCst.atEq(r, c - start) = atEq(r, c);
+        else if (c < start)
+          tmpCst.atEq(r, c + limit - start) = atEq(r, c);
+        else
+          tmpCst.atEq(r, c) = atEq(r, c);
+      }
+    }
+  }
+
+  // Mark everything to the right as symbols so that we can check the extents in
+  // a symbolic way below.
+  tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start));
+
+  // Check if the extents of all the specified dimensions are just one (when
+  // treating the rest as symbols).
+  for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) {
+    auto extent = tmpCst.getConstantBoundOnDimSize(pos);
+    if (!extent.hasValue() || extent.getValue() != 1)
+      return false;
+  }
+  return true;
+}
+
+void FlatAffineConstraints::clearConstraints() {
+  equalities.clear();
+  inequalities.clear();
+}
+
+namespace {
+
+enum BoundCmpResult { Greater, Less, Equal, Unknown };
+
+/// Compares two affine bounds whose coefficients are provided in 'first' and
+/// 'second'. The last coefficient is the constant term.
+static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  assert(a.size() == b.size());
+
+  // For the bounds to be comparable, their corresponding identifier
+  // coefficients should be equal; the constant terms are then compared to
+  // determine less/greater/equal.
+
+  if (!std::equal(a.begin(), a.end() - 1, b.begin()))
+    return Unknown;
+
+  if (a.back() == b.back())
+    return Equal;
+
+  return a.back() < b.back() ? Less : Greater;
+}
+} // namespace
+
+// Computes the bounding box with respect to 'other' by finding the min of the
+// lower bounds and the max of the upper bounds along each of the dimensions.
+LogicalResult
+FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
+  assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
+  assert(otherCst.getIds()
+             .slice(0, getNumDimIds())
+             .equals(getIds().slice(0, getNumDimIds())) &&
+         "dim values mismatch");
+  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
+  assert(getNumLocalIds() == 0 && "local ids not supported yet here");
+
+  Optional<FlatAffineConstraints> otherCopy;
+  if (!areIdsAligned(*this, otherCst)) {
+    otherCopy.emplace(FlatAffineConstraints(otherCst));
+    mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
+  }
+
+  const auto &other = otherCopy ? *otherCopy : otherCst;
+
+  std::vector<SmallVector<int64_t, 8>> boundingLbs;
+  std::vector<SmallVector<int64_t, 8>> boundingUbs;
+  boundingLbs.reserve(2 * getNumDimIds());
+  boundingUbs.reserve(2 * getNumDimIds());
+
+  // To hold lower and upper bounds for each dimension.
+  SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
+  // To compute min of lower bounds and max of upper bounds for each dimension.
+  SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
+  SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
+  // To compute final new lower and upper bounds for the union.
+  SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
+
+  int64_t lbFloorDivisor, otherLbFloorDivisor;
+  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
+    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
+    if (!extent.hasValue())
+      // TODO(bondhugula): symbolic extents when necessary.
+      // TODO(bondhugula): handle union if a dimension is unbounded.
+      return failure();
+
+    auto otherExtent = other.getConstantBoundOnDimSize(
+        d, &otherLb, &otherLbFloorDivisor, &otherUb);
+    if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
+      // TODO(bondhugula): symbolic extents when necessary.
+      return failure();
+
+    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
+
+    auto res = compareBounds(lb, otherLb);
+    // Identify min.
+    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
+      minLb = lb;
+      // Since the divisor is for a floordiv, we need to convert to ceildiv,
+      // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
+      // div * i >= expr - div + 1.
+      minLb.back() -= lbFloorDivisor - 1;
+    } else if (res == BoundCmpResult::Greater) {
+      minLb = otherLb;
+      minLb.back() -= otherLbFloorDivisor - 1;
+    } else {
+      // Uncomparable - check for constant lower/upper bounds.
+      auto constLb = getConstantLowerBound(d);
+      auto constOtherLb = other.getConstantLowerBound(d);
+      if (!constLb.hasValue() || !constOtherLb.hasValue())
+        return failure();
+      std::fill(minLb.begin(), minLb.end(), 0);
+      minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
+    }
+
+    // Do the same for ub's but max of upper bounds. Identify max.
+    auto uRes = compareBounds(ub, otherUb);
+    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
+      maxUb = ub;
+    } else if (uRes == BoundCmpResult::Less) {
+      maxUb = otherUb;
+    } else {
+      // Uncomparable - check for constant lower/upper bounds.
+      auto constUb = getConstantUpperBound(d);
+      auto constOtherUb = other.getConstantUpperBound(d);
+      if (!constUb.hasValue() || !constOtherUb.hasValue())
+        return failure();
+      std::fill(maxUb.begin(), maxUb.end(), 0);
+      maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
+    }
+
+    std::fill(newLb.begin(), newLb.end(), 0);
+    std::fill(newUb.begin(), newUb.end(), 0);
+
+    // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
+    // and so it's the divisor for newLb and newUb as well.
+    newLb[d] = lbFloorDivisor;
+    newUb[d] = -lbFloorDivisor;
+    // Copy over the symbolic part + constant term.
+    std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
+    std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
+                   newLb.begin() + getNumDimIds(), std::negate<int64_t>());
+    std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
+
+    boundingLbs.push_back(newLb);
+    boundingUbs.push_back(newUb);
+  }
+
+  // Clear all constraints and add the lower/upper bounds for the bounding box.
+  clearConstraints();
+  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
+    addInequality(boundingLbs[d]);
+    addInequality(boundingUbs[d]);
+  }
+  // TODO(mlir-team): copy over pure symbolic constraints from this and 'other'
+  // over to the union (since the above are just the union along dimensions); we
+  // shouldn't be discarding any other constraints on the symbols.
+
+  return success();
+}
diff --git a/third_party/mlir/lib/Analysis/CMakeLists.txt b/third_party/mlir/lib/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..e2b1d12
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_llvm_library(MLIRAnalysis STATIC
+  AffineAnalysis.cpp
+  AffineStructures.cpp
+  Dominance.cpp
+  LoopAnalysis.cpp
+  MemRefBoundCheck.cpp
+  NestedMatcher.cpp
+  OpStats.cpp
+  SliceAnalysis.cpp
+  TestMemRefDependenceCheck.cpp
+  TestParallelismDetection.cpp
+  Utils.cpp
+  VectorAnalysis.cpp
+  Verifier.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
+  )
+add_dependencies(MLIRAnalysis MLIRAffineOps MLIRLoopOps)
+target_link_libraries(MLIRAnalysis MLIRAffineOps MLIRLoopOps)
diff --git a/third_party/mlir/lib/Analysis/Dominance.cpp b/third_party/mlir/lib/Analysis/Dominance.cpp
new file mode 100644
index 0000000..e384a56
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/Dominance.cpp
@@ -0,0 +1,164 @@
+//===- Dominance.cpp - Dominator analysis for CFGs ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Implementation of dominance related classes and instantiations of extern
+// templates.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/GenericDomTreeConstruction.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/false>;
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>;
+template class llvm::DomTreeNodeBase<Block>;
+
+//===----------------------------------------------------------------------===//
+// DominanceInfoBase
+//===----------------------------------------------------------------------===//
+
+template <bool IsPostDom>
+void DominanceInfoBase<IsPostDom>::recalculate(Operation *op) {
+  dominanceInfos.clear();
+
+  /// Build the dominance for each of the operation regions.
+  op->walk([&](Operation *op) {
+    for (auto &region : op->getRegions()) {
+      // Don't compute dominance if the region is empty.
+      if (region.empty())
+        continue;
+      auto opDominance = llvm::make_unique<base>();
+      opDominance->recalculate(region);
+      dominanceInfos.try_emplace(&region, std::move(opDominance));
+    }
+  });
+}
+
+/// Return true if the specified block A properly dominates block B.
+template <bool IsPostDom>
+bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) {
+  // A block dominates itself but does not properly dominate itself.
+  if (a == b)
+    return false;
+
+  // If either a or b are null, then conservatively return false.
+  if (!a || !b)
+    return false;
+
+  // If both blocks are not in the same region, 'a' properly dominates 'b' if
+  // 'b' is defined in an operation region that (recursively) ends up being
+  // dominated by 'a'. Walk up the list of containers enclosing B.
+  auto *regionA = a->getParent(), *regionB = b->getParent();
+  if (regionA != regionB) {
+    Operation *bAncestor;
+    do {
+      bAncestor = regionB->getParentOp();
+      // If 'bAncestor' is the top level region, then 'a' is a block that post
+      // dominates 'b'.
+      if (!bAncestor || !bAncestor->getBlock())
+        return IsPostDom;
+
+      regionB = bAncestor->getBlock()->getParent();
+    } while (regionA != regionB);
+
+    // Check to see if the ancestor of 'b' is the same block as 'a'.
+    b = bAncestor->getBlock();
+    if (a == b)
+      return true;
+  }
+
+  // Otherwise, use the standard dominance functionality.
+
+  // If we don't have a dominance information for this region, assume that b is
+  // dominated by anything.
+  auto baseInfoIt = dominanceInfos.find(regionA);
+  if (baseInfoIt == dominanceInfos.end())
+    return true;
+  return baseInfoIt->second->properlyDominates(a, b);
+}
+
+template class mlir::detail::DominanceInfoBase</*IsPostDom=*/true>;
+template class mlir::detail::DominanceInfoBase</*IsPostDom=*/false>;
+
+//===----------------------------------------------------------------------===//
+// DominanceInfo
+//===----------------------------------------------------------------------===//
+
+/// Return true if operation A properly dominates operation B.
+bool DominanceInfo::properlyDominates(Operation *a, Operation *b) {
+  auto *aBlock = a->getBlock(), *bBlock = b->getBlock();
+
+  // If a or b are not within a block, then a does not dominate b.
+  if (!aBlock || !bBlock)
+    return false;
+
+  // If the blocks are the same, then check if b is before a in the block.
+  if (aBlock == bBlock)
+    return a->isBeforeInBlock(b);
+
+  // Traverse up b's hierarchy to check if b's block is contained in a's.
+  if (auto *bAncestor = aBlock->findAncestorInstInBlock(*b)) {
+    // Since we already know that aBlock != bBlock, here bAncestor != b.
+    // a and bAncestor are in the same block; check if 'a' dominates
+    // bAncestor.
+    return dominates(a, bAncestor);
+  }
+
+  // If the blocks are different, check if a's block dominates b's.
+  return properlyDominates(aBlock, bBlock);
+}
+
+/// Return true if value A properly dominates operation B.
+bool DominanceInfo::properlyDominates(Value *a, Operation *b) {
+  if (auto *aInst = a->getDefiningOp())
+    return properlyDominates(aInst, b);
+
+  // block arguments properly dominate all operations in their own block, so
+  // we use a dominates check here, not a properlyDominates check.
+  return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());
+}
+
+//===----------------------------------------------------------------------===//
+// PostDominanceInfo
+//===----------------------------------------------------------------------===//
+
+/// Returns true if statement 'a' properly postdominates statement b.
+bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b) {
+  auto *aBlock = a->getBlock(), *bBlock = b->getBlock();
+
+  // If a or b are not within a block, then a does not post dominate b.
+  if (!aBlock || !bBlock)
+    return false;
+
+  // If the blocks are the same, check if b is before a in the block.
+  if (aBlock == bBlock)
+    return b->isBeforeInBlock(a);
+
+  // Traverse up b's hierarchy to check if b's block is contained in a's.
+  if (auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b))
+    // Since we already know that aBlock != bBlock, here bAncestor != b.
+    // a and bAncestor are in the same block; check if 'a' postdominates
+    // bAncestor.
+    return postDominates(a, bAncestor);
+
+  // If the blocks are different, check if a's block post dominates b's.
+  return properlyDominates(aBlock, bBlock);
+}
diff --git a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
new file mode 100644
index 0000000..743907b
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -0,0 +1,402 @@
+//===- LoopAnalysis.cpp - Misc loop analysis routines //-------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop analysis routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopAnalysis.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallString.h"
+#include <type_traits>
+
+using namespace mlir;
+
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
+// pure analysis method relying on FlatAffineConstraints; the latter will also
+// be more powerful (since both inequalities and equalities will be considered).
+void mlir::buildTripCountMapAndOperands(
+    AffineForOp forOp, AffineMap *map,
+    SmallVectorImpl<Value *> *tripCountOperands) {
+  int64_t loopSpan;
+
+  int64_t step = forOp.getStep();
+  OpBuilder b(forOp.getOperation());
+
+  if (forOp.hasConstantBounds()) {
+    int64_t lb = forOp.getConstantLowerBound();
+    int64_t ub = forOp.getConstantUpperBound();
+    loopSpan = ub - lb;
+    if (loopSpan < 0)
+      loopSpan = 0;
+    *map = b.getConstantAffineMap(ceilDiv(loopSpan, step));
+    tripCountOperands->clear();
+    return;
+  }
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  if (lbMap.getNumResults() != 1) {
+    *map = AffineMap();
+    return;
+  }
+  SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+  SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+  auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+  SmallVector<Value *, 4> ubs;
+  ubs.reserve(ubMap.getNumResults());
+  for (auto ubExpr : ubMap.getResults())
+    ubs.push_back(b.create<AffineApplyOp>(
+        forOp.getLoc(),
+        b.getAffineMap(ubMap.getNumDims(), ubMap.getNumSymbols(), {ubExpr}),
+        ubOperands));
+
+  tripCountOperands->clear();
+  tripCountOperands->reserve(1 + ubs.size());
+  tripCountOperands->push_back(lb);
+  tripCountOperands->append(ubs.begin(), ubs.end());
+
+  SmallVector<AffineExpr, 4> tripCountExprs(ubs.size());
+  for (unsigned i = 0, e = ubs.size(); i < e; i++)
+    tripCountExprs[i] =
+        (b.getAffineDimExpr(1 + i) - b.getAffineDimExpr(0)).ceilDiv(step);
+  *map = b.getAffineMap(1 + ubs.size(), 0, tripCountExprs);
+
+  fullyComposeAffineMapAndOperands(map, tripCountOperands);
+  *map = simplifyAffineMap(*map);
+  canonicalizeMapAndOperands(map, tripCountOperands);
+  // Remove any affine.apply's that became dead as a result of composition,
+  // simplification, and canonicalization above.
+  for (auto *v : ubs)
+    if (v->use_empty())
+      v->getDefiningOp()->erase();
+  if (lb.use_empty())
+    lb.erase();
+}
+
+/// Returns the trip count of the loop if it's a constant, None otherwise. This
+/// method uses affine expression analysis (in turn using getTripCount) and is
+/// able to determine constant trip count in non-trivial cases.
+// FIXME(mlir-team): this is really relying on buildTripCountMapAndOperands;
+// being an analysis utility, it shouldn't. Replace with a version that just
+// works with analysis structures (FlatAffineConstraints) and thus doesn't
+// update the IR.
+llvm::Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
+  SmallVector<Value *, 4> operands;
+  AffineMap map;
+  buildTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return None;
+
+  // Take the min if all trip counts are constant.
+  Optional<uint64_t> tripCount;
+  for (auto resultExpr : map.getResults()) {
+    if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+      if (tripCount.hasValue())
+        tripCount = std::min(tripCount.getValue(),
+                             static_cast<uint64_t>(constExpr.getValue()));
+      else
+        tripCount = constExpr.getValue();
+    } else
+      return None;
+  }
+  return tripCount;
+}
+
+/// Returns the greatest known integral divisor of the trip count. Affine
+/// expression analysis is used (indirectly through getTripCount), and
+/// this method is thus able to determine non-trivial divisors.
+uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) {
+  SmallVector<Value *, 4> operands;
+  AffineMap map;
+  buildTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return 1;
+
+  // The largest divisor of the trip count is the GCD of the individual largest
+  // divisors.
+  assert(map.getNumResults() >= 1 && "expected one or more results");
+  Optional<uint64_t> gcd;
+  for (auto resultExpr : map.getResults()) {
+    uint64_t thisGcd;
+    if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+      uint64_t tripCount = constExpr.getValue();
+      // 0 iteration loops (greatest divisor is 2^64 - 1).
+      if (tripCount == 0)
+        thisGcd = std::numeric_limits<uint64_t>::max();
+      else
+        // The greatest divisor is the trip count.
+        thisGcd = tripCount;
+    } else {
+      // Trip count is not a known constant; return its largest known divisor.
+      thisGcd = resultExpr.getLargestKnownDivisor();
+    }
+    if (gcd.hasValue())
+      gcd = llvm::GreatestCommonDivisor64(gcd.getValue(), thisGcd);
+    else
+      gcd = thisGcd;
+  }
+  assert(gcd.hasValue() && "value expected per above logic");
+  return gcd.getValue();
+}
+
+bool mlir::isAccessInvariant(Value *iv, Value *index) {
+  assert(isForInductionVar(iv) && "iv must be a AffineForOp");
+  assert(index->getType().isa<IndexType>() && "index must be of IndexType");
+  SmallVector<Operation *, 4> affineApplyOps;
+  getReachableAffineApplyOps({index}, affineApplyOps);
+
+  if (affineApplyOps.empty()) {
+    // Pointer equality test because of Value pointer semantics.
+    return index != iv;
+  }
+
+  if (affineApplyOps.size() > 1) {
+    affineApplyOps[0]->emitRemark(
+        "CompositionAffineMapsPass must have been run: there should be at most "
+        "one AffineApplyOp, returning false conservatively.");
+    return false;
+  }
+
+  auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
+  // We need yet another level of indirection because the `dim` index of the
+  // access may not correspond to the `dim` index of composeOp.
+  return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
+}
+
+llvm::DenseSet<Value *>
+mlir::getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices) {
+  llvm::DenseSet<Value *> res;
+  for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
+    auto *val = indices[idx];
+    if (isAccessInvariant(iv, val)) {
+      res.insert(val);
+    }
+  }
+  return res;
+}
+
+/// Given:
+///   1. an induction variable `iv` of type AffineForOp;
+///   2. a `memoryOp` of type const LoadOp& or const StoreOp&;
+/// determines whether `memoryOp` has a contiguous access along `iv`. Contiguous
+/// is defined as either invariant or varying only along a unique MemRef dim.
+/// Upon success, the unique MemRef dim is written in `memRefDim` (or -1 to
+/// convey the memRef access is invariant along `iv`).
+///
+/// Prerequisites:
+///   1. `memRefDim` ~= nullptr;
+///   2. `iv` of the proper type;
+///   3. the MemRef accessed by `memoryOp` has no layout map or at most an
+///      identity layout map.
+///
+/// Currently only supports no layoutMap or identity layoutMap in the MemRef.
+/// Returns false if the MemRef has a non-identity layoutMap or more than 1
+/// layoutMap. This is conservative.
+///
+// TODO(ntv): check strides.
+template <typename LoadOrStoreOp>
+static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
+                               int *memRefDim) {
+  static_assert(std::is_same<LoadOrStoreOp, AffineLoadOp>::value ||
+                    std::is_same<LoadOrStoreOp, AffineStoreOp>::value,
+                "Must be called on either const LoadOp & or const StoreOp &");
+  assert(memRefDim && "memRefDim == nullptr");
+  auto memRefType = memoryOp.getMemRefType();
+
+  auto layoutMap = memRefType.getAffineMaps();
+  // TODO(ntv): remove dependence on Builder once we support non-identity
+  // layout map.
+  Builder b(memoryOp.getContext());
+  if (layoutMap.size() >= 2 ||
+      (layoutMap.size() == 1 &&
+       !(layoutMap[0] ==
+         b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
+    return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
+  }
+
+  int uniqueVaryingIndexAlongIv = -1;
+  auto accessMap = memoryOp.getAffineMap();
+  SmallVector<Value *, 4> mapOperands(memoryOp.getIndices());
+  unsigned numDims = accessMap.getNumDims();
+  for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
+    // Gather map operands used result expr 'i' in 'exprOperands'.
+    SmallVector<Value *, 4> exprOperands;
+    auto resultExpr = accessMap.getResult(i);
+    resultExpr.walk([&](AffineExpr expr) {
+      if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+        exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
+      else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+        exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
+    });
+    // Check access invariance of each operand in 'exprOperands'.
+    for (auto *exprOperand : exprOperands) {
+      if (!isAccessInvariant(iv, exprOperand)) {
+        if (uniqueVaryingIndexAlongIv != -1) {
+          // 2+ varying indices -> do not vectorize along iv.
+          return false;
+        }
+        uniqueVaryingIndexAlongIv = i;
+      }
+    }
+  }
+
+  if (uniqueVaryingIndexAlongIv == -1)
+    *memRefDim = -1;
+  else
+    *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
+  return true;
+}
+
+template <typename LoadOrStoreOpPointer>
+static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
+  auto memRefType = memoryOp.getMemRefType();
+  return memRefType.getElementType().template isa<VectorType>();
+}
+
+static bool isVectorTransferReadOrWrite(Operation &op) {
+  return isa<vector::VectorTransferReadOp>(op) ||
+         isa<vector::VectorTransferWriteOp>(op);
+}
+
+using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
+
+static bool
+isVectorizableLoopBodyWithOpCond(AffineForOp loop,
+                                 VectorizableOpFun isVectorizableOp) {
+  auto *forOp = loop.getOperation();
+
+  // No vectorization across conditionals for now.
+  auto conditionals = matcher::If();
+  SmallVector<NestedMatch, 8> conditionalsMatched;
+  conditionals.match(forOp, &conditionalsMatched);
+  if (!conditionalsMatched.empty()) {
+    return false;
+  }
+
+  // No vectorization across unknown regions.
+  auto regions = matcher::Op([](Operation &op) -> bool {
+    return op.getNumRegions() != 0 &&
+           !(isa<AffineIfOp>(op) || isa<AffineForOp>(op));
+  });
+  SmallVector<NestedMatch, 8> regionsMatched;
+  regions.match(forOp, &regionsMatched);
+  if (!regionsMatched.empty()) {
+    return false;
+  }
+
+  auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
+  SmallVector<NestedMatch, 8> vectorTransfersMatched;
+  vectorTransfers.match(forOp, &vectorTransfersMatched);
+  if (!vectorTransfersMatched.empty()) {
+    return false;
+  }
+
+  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
+  SmallVector<NestedMatch, 8> loadAndStoresMatched;
+  loadAndStores.match(forOp, &loadAndStoresMatched);
+  for (auto ls : loadAndStoresMatched) {
+    auto *op = ls.getMatchedOperation();
+    auto load = dyn_cast<AffineLoadOp>(op);
+    auto store = dyn_cast<AffineStoreOp>(op);
+    // Only scalar types are considered vectorizable, all load/store must be
+    // vectorizable for a loop to qualify as vectorizable.
+    // TODO(ntv): ponder whether we want to be more general here.
+    bool vector = load ? isVectorElement(load) : isVectorElement(store);
+    if (vector) {
+      return false;
+    }
+    if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) {
+  VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
+    auto load = dyn_cast<AffineLoadOp>(op);
+    auto store = dyn_cast<AffineStoreOp>(op);
+    return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim)
+                : isContiguousAccess(loop.getInductionVar(), store, memRefDim);
+  });
+  return isVectorizableLoopBodyWithOpCond(loop, fun);
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop) {
+  return isVectorizableLoopBodyWithOpCond(loop, nullptr);
+}
+
+/// Checks whether SSA dominance would be violated if a for op's body
+/// operations are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence violation
+// when we have the support.
+bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
+  auto *forBody = forOp.getBody();
+  assert(shifts.size() == forBody->getOperations().size());
+
+  // Work backwards over the body of the block so that the shift of a use's
+  // ancestor operation in the block gets recorded before it's looked up.
+  DenseMap<Operation *, uint64_t> forBodyShift;
+  for (auto it : llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
+    auto &op = it.value();
+
+    // Get the index of the current operation, note that we are iterating in
+    // reverse so we need to fix it up.
+    size_t index = shifts.size() - it.index() - 1;
+
+    // Remember the shift of this operation.
+    uint64_t shift = shifts[index];
+    forBodyShift.try_emplace(&op, shift);
+
+    // Validate the results of this operation if it were to be shifted.
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+      Value *result = op.getResult(i);
+      for (auto *user : result->getUsers()) {
+        // If an ancestor operation doesn't lie in the block of forOp,
+        // there is no shift to check.
+        if (auto *ancInst = forBody->findAncestorInstInBlock(*user)) {
+          assert(forBodyShift.count(ancInst) > 0 && "ancestor expected in map");
+          if (shift != forBodyShift[ancInst])
+            return false;
+        }
+      }
+    }
+  }
+  return true;
+}
diff --git a/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp b/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp
new file mode 100644
index 0000000..b043d47
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -0,0 +1,63 @@
+//===- MemRefBoundCheck.cpp - MLIR Affine Structures Class ----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to check memref accessses for out of bound
+// accesses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "memref-bound-check"
+
+using namespace mlir;
+
+namespace {
+
+/// Checks for out of bound memef access subscripts..
+struct MemRefBoundCheck : public FunctionPass<MemRefBoundCheck> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createMemRefBoundCheckPass() {
+  return new MemRefBoundCheck();
+}
+
+void MemRefBoundCheck::runOnFunction() {
+  getFunction().walk([](Operation *opInst) {
+    if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+      boundCheckLoadOrStoreOp(loadOp);
+    } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+      boundCheckLoadOrStoreOp(storeOp);
+    }
+    // TODO(bondhugula): do this for DMA ops as well.
+  });
+}
+
+static PassRegistration<MemRefBoundCheck>
+    memRefBoundCheck("memref-bound-check",
+                     "Check memref access bounds in a Function");
diff --git a/third_party/mlir/lib/Analysis/NestedMatcher.cpp b/third_party/mlir/lib/Analysis/NestedMatcher.cpp
new file mode 100644
index 0000000..18be6cf
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/NestedMatcher.cpp
@@ -0,0 +1,161 @@
+//===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/StandardOps/Ops.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+llvm::BumpPtrAllocator *&NestedMatch::allocator() {
+  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+  return allocator;
+}
+
+NestedMatch NestedMatch::build(Operation *operation,
+                               ArrayRef<NestedMatch> nestedMatches) {
+  auto *result = allocator()->Allocate<NestedMatch>();
+  auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
+  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
+  new (result) NestedMatch();
+  result->matchedOperation = operation;
+  result->matchedChildren =
+      ArrayRef<NestedMatch>(children, nestedMatches.size());
+  return *result;
+}
+
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+  return allocator;
+}
+
+NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
+                             FilterFunctionType filter)
+    : nestedPatterns(), filter(filter), skip(nullptr) {
+  if (!nested.empty()) {
+    auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
+    std::uninitialized_copy(nested.begin(), nested.end(), newNested);
+    nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
+  }
+}
+
+unsigned NestedPattern::getDepth() const {
+  if (nestedPatterns.empty()) {
+    return 1;
+  }
+  unsigned depth = 0;
+  for (auto &c : nestedPatterns) {
+    depth = std::max(depth, c.getDepth());
+  }
+  return depth + 1;
+}
+
+/// Matches a single operation in the following way:
+///   1. checks the kind of operation against the matcher, if different then
+///      there is no match;
+///   2. calls the customizable filter function to refine the single operation
+///      match with extra semantic constraints;
+///   3. if all is good, recursivey matches the nested patterns;
+///   4. if all nested match then the single operation matches too and is
+///      appended to the list of matches;
+///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
+///      want to traverse in post-order DFS to avoid invalidating iterators.
+void NestedPattern::matchOne(Operation *op,
+                             SmallVectorImpl<NestedMatch> *matches) {
+  if (skip == op) {
+    return;
+  }
+  // Local custom filter function
+  if (!filter(*op)) {
+    return;
+  }
+
+  if (nestedPatterns.empty()) {
+    SmallVector<NestedMatch, 8> nestedMatches;
+    matches->push_back(NestedMatch::build(op, nestedMatches));
+    return;
+  }
+  // Take a copy of each nested pattern so we can match it.
+  for (auto nestedPattern : nestedPatterns) {
+    SmallVector<NestedMatch, 8> nestedMatches;
+    // Skip elem in the walk immediately following. Without this we would
+    // essentially need to reimplement walk here.
+    nestedPattern.skip = op;
+    nestedPattern.match(op, &nestedMatches);
+    // If we could not match even one of the specified nestedPattern, early exit
+    // as this whole branch is not a match.
+    if (nestedMatches.empty()) {
+      return;
+    }
+    matches->push_back(NestedMatch::build(op, nestedMatches));
+  }
+}
+
+static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
+
+static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
+
+namespace mlir {
+namespace matcher {
+
+NestedPattern Op(FilterFunctionType filter) {
+  return NestedPattern({}, filter);
+}
+
+NestedPattern If(NestedPattern child) {
+  return NestedPattern(child, isAffineIfOp);
+}
+NestedPattern If(FilterFunctionType filter, NestedPattern child) {
+  return NestedPattern(child, [filter](Operation &op) {
+    return isAffineIfOp(op) && filter(op);
+  });
+}
+NestedPattern If(ArrayRef<NestedPattern> nested) {
+  return NestedPattern(nested, isAffineIfOp);
+}
+NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+  return NestedPattern(nested, [filter](Operation &op) {
+    return isAffineIfOp(op) && filter(op);
+  });
+}
+
+NestedPattern For(NestedPattern child) {
+  return NestedPattern(child, isAffineForOp);
+}
+NestedPattern For(FilterFunctionType filter, NestedPattern child) {
+  return NestedPattern(
+      child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
+}
+NestedPattern For(ArrayRef<NestedPattern> nested) {
+  return NestedPattern(nested, isAffineForOp);
+}
+NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+  return NestedPattern(
+      nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
+}
+
+bool isLoadOrStore(Operation &op) {
+  return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op);
+}
+
+} // end namespace matcher
+} // end namespace mlir
diff --git a/third_party/mlir/lib/Analysis/OpStats.cpp b/third_party/mlir/lib/Analysis/OpStats.cpp
new file mode 100644
index 0000000..f01ec56
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/OpStats.cpp
@@ -0,0 +1,93 @@
+//===- OpStats.cpp - Prints stats of operations in module -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
+  explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {}
+
+  // Prints the resultant operation statistics post iterating over the module.
+  void runOnModule() override;
+
+  // Print summary of op stats.
+  void printSummary();
+
+private:
+  llvm::StringMap<int64_t> opCount;
+  llvm::raw_ostream &os;
+};
+} // namespace
+
+void PrintOpStatsPass::runOnModule() {
+  opCount.clear();
+
+  // Compute the operation statistics for each function in the module.
+  for (auto &op : getModule())
+    op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  printSummary();
+}
+
+void PrintOpStatsPass::printSummary() {
+  os << "Operations encountered:\n";
+  os << "-----------------------\n";
+  SmallVector<StringRef, 64> sorted(opCount.keys());
+  llvm::sort(sorted);
+
+  // Split an operation name from its dialect prefix.
+  auto splitOperationName = [](StringRef opName) {
+    auto splitName = opName.split('.');
+    return splitName.second.empty() ? std::make_pair("", splitName.first)
+                                    : splitName;
+  };
+
+  // Compute the largest dialect and operation name.
+  StringRef dialectName, opName;
+  size_t maxLenOpName = 0, maxLenDialect = 0;
+  for (const auto &key : sorted) {
+    std::tie(dialectName, opName) = splitOperationName(key);
+    maxLenDialect = std::max(maxLenDialect, dialectName.size());
+    maxLenOpName = std::max(maxLenOpName, opName.size());
+  }
+
+  for (const auto &key : sorted) {
+    std::tie(dialectName, opName) = splitOperationName(key);
+
+    // Left-align the names (aligning on the dialect) and right-align the count
+    // below. The alignment is for readability and does not affect CSV/FileCheck
+    // parsing.
+    if (dialectName.empty())
+      os.indent(maxLenDialect + 3);
+    else
+      os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
+
+    // Left justify the operation name.
+    os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
+       << '\n';
+  }
+}
+
+static PassRegistration<PrintOpStatsPass>
+    pass("print-op-stats", "Print statistics of operations");
diff --git a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
new file mode 100644
index 0000000..c240d77
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -0,0 +1,223 @@
+//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements Analysis functions specific to slicing in Function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+///
+/// Implements Analysis functions specific to slicing in Function.
+///
+
+using namespace mlir;
+
+using llvm::SetVector;
+
+static void getForwardSliceImpl(Operation *op,
+                                SetVector<Operation *> *forwardSlice,
+                                TransitiveFilter filter) {
+  if (!op) {
+    return;
+  }
+
+  // Evaluate whether we should keep this use.
+  // This is useful in particular to implement scoping; i.e. return the
+  // transitive forwardSlice in the current scope.
+  if (!filter(op)) {
+    return;
+  }
+
+  if (auto forOp = dyn_cast<AffineForOp>(op)) {
+    for (auto *ownerInst : forOp.getInductionVar()->getUsers())
+      if (forwardSlice->count(ownerInst) == 0)
+        getForwardSliceImpl(ownerInst, forwardSlice, filter);
+  } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
+    for (auto *ownerInst : forOp.getInductionVar()->getUsers())
+      if (forwardSlice->count(ownerInst) == 0)
+        getForwardSliceImpl(ownerInst, forwardSlice, filter);
+  } else {
+    assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
+    assert(op->getNumResults() <= 1 && "unexpected multiple results");
+    if (op->getNumResults() > 0) {
+      for (auto *ownerInst : op->getResult(0)->getUsers())
+        if (forwardSlice->count(ownerInst) == 0)
+          getForwardSliceImpl(ownerInst, forwardSlice, filter);
+    }
+  }
+
+  forwardSlice->insert(op);
+}
+
+void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
+                           TransitiveFilter filter) {
+  getForwardSliceImpl(op, forwardSlice, filter);
+  // Don't insert the top level operation, we just queried on it and don't
+  // want it in the results.
+  forwardSlice->remove(op);
+
+  // Reverse to get back the actual topological order.
+  // std::reverse does not work out of the box on SetVector and I want an
+  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
+  std::vector<Operation *> v(forwardSlice->takeVector());
+  forwardSlice->insert(v.rbegin(), v.rend());
+}
+
+static void getBackwardSliceImpl(Operation *op,
+                                 SetVector<Operation *> *backwardSlice,
+                                 TransitiveFilter filter) {
+  if (!op)
+    return;
+
+  assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
+          isa<loop::ForOp>(op)) &&
+         "unexpected generic op with regions");
+
+  // Evaluate whether we should keep this def.
+  // This is useful in particular to implement scoping; i.e. return the
+  // transitive forwardSlice in the current scope.
+  if (!filter(op)) {
+    return;
+  }
+
+  for (auto en : llvm::enumerate(op->getOperands())) {
+    auto *operand = en.value();
+    if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
+      if (auto affIv = getForInductionVarOwner(operand)) {
+        auto *affOp = affIv.getOperation();
+        if (backwardSlice->count(affOp) == 0)
+          getBackwardSliceImpl(affOp, backwardSlice, filter);
+      } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
+        auto *loopOp = loopIv.getOperation();
+        if (backwardSlice->count(loopOp) == 0)
+          getBackwardSliceImpl(loopOp, backwardSlice, filter);
+      } else if (blockArg->getOwner() !=
+                 &op->getParentOfType<FuncOp>().getBody().front()) {
+        op->emitError("Unsupported CF for operand ") << en.index();
+        llvm_unreachable("Unsupported control flow");
+      }
+      continue;
+    }
+    auto *op = operand->getDefiningOp();
+    if (backwardSlice->count(op) == 0) {
+      getBackwardSliceImpl(op, backwardSlice, filter);
+    }
+  }
+
+  backwardSlice->insert(op);
+}
+
+void mlir::getBackwardSlice(Operation *op,
+                            SetVector<Operation *> *backwardSlice,
+                            TransitiveFilter filter) {
+  getBackwardSliceImpl(op, backwardSlice, filter);
+
+  // Don't insert the top level operation, we just queried on it and don't
+  // want it in the results.
+  backwardSlice->remove(op);
+}
+
+SetVector<Operation *> mlir::getSlice(Operation *op,
+                                      TransitiveFilter backwardFilter,
+                                      TransitiveFilter forwardFilter) {
+  SetVector<Operation *> slice;
+  slice.insert(op);
+
+  unsigned currentIndex = 0;
+  SetVector<Operation *> backwardSlice;
+  SetVector<Operation *> forwardSlice;
+  while (currentIndex != slice.size()) {
+    auto *currentInst = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentInst.
+    backwardSlice.clear();
+    getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
+    slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+    // Compute and insert the forwardSlice starting from currentInst.
+    forwardSlice.clear();
+    getForwardSlice(currentInst, &forwardSlice, forwardFilter);
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    ++currentIndex;
+  }
+  return topologicalSort(slice);
+}
+
+namespace {
+/// DFS post-order implementation that maintains a global count to work across
+/// multiple invocations, to help implement topological sort on multi-root DAGs.
+/// We traverse all operations but only record the ones that appear in
+/// `toSort` for the final result.
+struct DFSState {
+  DFSState(const SetVector<Operation *> &set)
+      : toSort(set), topologicalCounts(), seen() {}
+  const SetVector<Operation *> &toSort;
+  SmallVector<Operation *, 16> topologicalCounts;
+  DenseSet<Operation *> seen;
+};
+} // namespace
+
+static void DFSPostorder(Operation *current, DFSState *state) {
+  assert(current->getNumResults() <= 1 && "NYI: multi-result");
+  if (current->getNumResults() > 0) {
+    for (auto &u : current->getResult(0)->getUses()) {
+      auto *op = u.getOwner();
+      DFSPostorder(op, state);
+    }
+  }
+  bool inserted;
+  using IterTy = decltype(state->seen.begin());
+  IterTy iter;
+  std::tie(iter, inserted) = state->seen.insert(current);
+  if (inserted) {
+    if (state->toSort.count(current) > 0) {
+      state->topologicalCounts.push_back(current);
+    }
+  }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+  if (toSort.empty()) {
+    return toSort;
+  }
+
+  // Run from each root with global count and `seen` set.
+  DFSState state(toSort);
+  for (auto *s : toSort) {
+    assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
+    DFSPostorder(s, &state);
+  }
+
+  // Reorder and return.
+  SetVector<Operation *> res;
+  for (auto it = state.topologicalCounts.rbegin(),
+            eit = state.topologicalCounts.rend();
+       it != eit; ++it) {
+    res.insert(*it);
+  }
+  return res;
+}
diff --git a/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
new file mode 100644
index 0000000..1802b73
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
@@ -0,0 +1,129 @@
+//===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to run pair-wise memref access dependence checks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "test-memref-dependence-check"
+
+using namespace mlir;
+
+namespace {
+
+// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
+/// Checks dependences between all pairs of memref accesses in a Function.
+struct TestMemRefDependenceCheck
+    : public FunctionPass<TestMemRefDependenceCheck> {
+  SmallVector<Operation *, 4> loadsAndStores;
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createTestMemRefDependenceCheckPass() {
+  return new TestMemRefDependenceCheck();
+}
+
+// Returns a result string which represents the direction vector (if there was
+// a dependence), returns the string "false" otherwise.
+static std::string
+getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
+                      ArrayRef<DependenceComponent> dependenceComponents) {
+  if (!ret)
+    return "false";
+  if (dependenceComponents.empty() || loopNestDepth > numCommonLoops)
+    return "true";
+  std::string result;
+  for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) {
+    std::string lbStr = "-inf";
+    if (dependenceComponents[i].lb.hasValue() &&
+        dependenceComponents[i].lb.getValue() !=
+            std::numeric_limits<int64_t>::min())
+      lbStr = std::to_string(dependenceComponents[i].lb.getValue());
+
+    std::string ubStr = "+inf";
+    if (dependenceComponents[i].ub.hasValue() &&
+        dependenceComponents[i].ub.getValue() !=
+            std::numeric_limits<int64_t>::max())
+      ubStr = std::to_string(dependenceComponents[i].ub.getValue());
+
+    result += "[" + lbStr + ", " + ubStr + "]";
+  }
+  return result;
+}
+
+// For each access in 'loadsAndStores', runs a depence check between this
+// "source" access and all subsequent "destination" accesses in
+// 'loadsAndStores'. Emits the result of the dependence check as a note with
+// the source access.
+static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
+  for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
+    auto *srcOpInst = loadsAndStores[i];
+    MemRefAccess srcAccess(srcOpInst);
+    for (unsigned j = 0; j < e; ++j) {
+      auto *dstOpInst = loadsAndStores[j];
+      MemRefAccess dstAccess(dstOpInst);
+
+      unsigned numCommonLoops =
+          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
+      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+        FlatAffineConstraints dependenceConstraints;
+        llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints,
+            &dependenceComponents);
+        assert(result.value != DependenceResult::Failure);
+        bool ret = hasDependence(result);
+        // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
+        // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
+        // vectors from ([1, 1], [3, 3]) to (1, 3).
+        srcOpInst->emitRemark("dependence from ")
+            << i << " to " << j << " at depth " << d << " = "
+            << getDirectionVectorStr(ret, numCommonLoops, d,
+                                     dependenceComponents);
+      }
+    }
+  }
+}
+
+// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
+// Runs pair-wise dependence checks.
+void TestMemRefDependenceCheck::runOnFunction() {
+  // Collect the loads and stores within the function.
+  loadsAndStores.clear();
+  getFunction().walk([&](Operation *op) {
+    if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+      loadsAndStores.push_back(op);
+  });
+
+  checkDependences(loadsAndStores);
+}
+
+static PassRegistration<TestMemRefDependenceCheck>
+    pass("test-memref-dependence-check",
+         "Checks dependences between all pairs of memref accesses.");
diff --git a/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp b/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp
new file mode 100644
index 0000000..246cfbe
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp
@@ -0,0 +1,57 @@
+//===- ParallelismDetection.cpp - Parallelism Detection pass ------------*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to detect parallel affine 'affine.for' ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestParallelismDetection
+    : public FunctionPass<TestParallelismDetection> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createParallelismDetectionTestPass() {
+  return new TestParallelismDetection();
+}
+
+// Walks the function and emits a note for all 'affine.for' ops detected as
+// parallel.
+void TestParallelismDetection::runOnFunction() {
+  FuncOp f = getFunction();
+  OpBuilder b(f.getBody());
+  f.walk<AffineForOp>([&](AffineForOp forOp) {
+    if (isLoopParallel(forOp))
+      forOp.emitRemark("parallel loop");
+    else
+      forOp.emitRemark("sequential loop");
+  });
+}
+
+static PassRegistration<TestParallelismDetection>
+    pass("test-detect-parallel", "Test parallelism detection ");
diff --git a/third_party/mlir/lib/Analysis/Utils.cpp b/third_party/mlir/lib/Analysis/Utils.cpp
new file mode 100644
index 0000000..fc36cc5
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/Utils.cpp
@@ -0,0 +1,1002 @@
+//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous analysis routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "analysis-utils"
+
+using namespace mlir;
+
+using llvm::SmallDenseMap;
+
+/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
+/// the outermost 'affine.for' operation to the innermost one.
+void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
+  auto *currOp = op.getParentOp();
+  AffineForOp currAffineForOp;
+  // Traverse up the hierarchy collecing all 'affine.for' operation while
+  // skipping over 'affine.if' operations.
+  while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) ||
+                    isa<AffineIfOp>(currOp))) {
+    if (currAffineForOp)
+      loops->push_back(currAffineForOp);
+    currOp = currOp->getParentOp();
+  }
+  std::reverse(loops->begin(), loops->end());
+}
+
+// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
+LogicalResult
+ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
+  assert(!lbOperands.empty());
+  // Adds src 'ivs' as dimension identifiers in 'cst'.
+  unsigned numDims = ivs.size();
+  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
+  unsigned numSymbols = lbOperands[0].size();
+
+  SmallVector<Value *, 4> values(ivs);
+  // Append 'ivs' then 'operands' to 'values'.
+  values.append(lbOperands[0].begin(), lbOperands[0].end());
+  cst->reset(numDims, numSymbols, 0, values);
+
+  // Add loop bound constraints for values which are loop IVs and equality
+  // constraints for symbols which are constants.
+  for (const auto &value : values) {
+    assert(cst->containsId(*value) && "value expected to be present");
+    if (isValidSymbol(value)) {
+      // Check if the symbol is a constant.
+
+      if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
+        cst->setIdToConstant(*value, cOp.getValue());
+    } else if (auto loop = getForInductionVarOwner(value)) {
+      if (failed(cst->addAffineForOpDomain(loop)))
+        return failure();
+    }
+  }
+
+  // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'
+  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
+  assert(succeeded(ret) &&
+         "should not fail as we never have semi-affine slice maps");
+  (void)ret;
+  return success();
+}
+
+// Clears state bounds and operand state.
+void ComputationSliceState::clearBounds() {
+  lbs.clear();
+  ubs.clear();
+  lbOperands.clear();
+  ubOperands.clear();
+}
+
+unsigned MemRefRegion::getRank() const {
+  return memref->getType().cast<MemRefType>().getRank();
+}
+
+Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
+    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
+    SmallVectorImpl<int64_t> *lbDivisors) const {
+  auto memRefType = memref->getType().cast<MemRefType>();
+  unsigned rank = memRefType.getRank();
+  if (shape)
+    shape->reserve(rank);
+
+  assert(rank == cst.getNumDimIds() && "inconsistent memref region");
+
+  // Find a constant upper bound on the extent of this memref region along each
+  // dimension.
+  int64_t numElements = 1;
+  int64_t diffConstant;
+  int64_t lbDivisor;
+  for (unsigned d = 0; d < rank; d++) {
+    SmallVector<int64_t, 4> lb;
+    Optional<int64_t> diff = cst.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
+    if (diff.hasValue()) {
+      diffConstant = diff.getValue();
+      assert(lbDivisor > 0);
+    } else {
+      // If no constant bound is found, then it can always be bound by the
+      // memref's dim size if the latter has a constant size along this dim.
+      auto dimSize = memRefType.getDimSize(d);
+      if (dimSize == -1)
+        return None;
+      diffConstant = dimSize;
+      // Lower bound becomes 0.
+      lb.resize(cst.getNumSymbolIds() + 1, 0);
+      lbDivisor = 1;
+    }
+    numElements *= diffConstant;
+    if (lbs) {
+      lbs->push_back(lb);
+      assert(lbDivisors && "both lbs and lbDivisor or none");
+      lbDivisors->push_back(lbDivisor);
+    }
+    if (shape) {
+      shape->push_back(diffConstant);
+    }
+  }
+  return numElements;
+}
+
+LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
+  assert(memref == other.memref);
+  return cst.unionBoundingBox(*other.getConstraints());
+}
+
+/// Computes the memory region accessed by this memref with the region
+/// represented as constraints symbolic/parameteric in 'loopDepth' loops
+/// surrounding opInst and any additional Function symbols.
+//  For example, the memref region for this load operation at loopDepth = 1 will
+//  be as below:
+//
+//    affine.for %i = 0 to 32 {
+//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
+//        load %A[%ii]
+//      }
+//    }
+//
+// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
+// The last field is a 2-d FlatAffineConstraints symbolic in %i.
+//
+// TODO(bondhugula): extend this to any other memref dereferencing ops
+// (dma_start, dma_wait).
+LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
+                                    ComputationSliceState *sliceState,
+                                    bool addMemRefDimBounds) {
+  assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) &&
+         "affine load/store op expected");
+
+  MemRefAccess access(op);
+  memref = access.memref;
+  write = access.isStore();
+
+  unsigned rank = access.getRank();
+
+  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
+                          << "depth: " << loopDepth << "\n";);
+
+  if (rank == 0) {
+    SmallVector<AffineForOp, 4> ivs;
+    getLoopIVs(*op, &ivs);
+    SmallVector<Value *, 8> regionSymbols;
+    extractForInductionVars(ivs, &regionSymbols);
+    // A rank 0 memref has a 0-d region.
+    cst.reset(rank, loopDepth, 0, regionSymbols);
+    return success();
+  }
+
+  // Build the constraints for this region.
+  AffineValueMap accessValueMap;
+  access.getAccessMap(&accessValueMap);
+  AffineMap accessMap = accessValueMap.getAffineMap();
+
+  unsigned numDims = accessMap.getNumDims();
+  unsigned numSymbols = accessMap.getNumSymbols();
+  unsigned numOperands = accessValueMap.getNumOperands();
+  // Merge operands with slice operands.
+  SmallVector<Value *, 4> operands;
+  operands.resize(numOperands);
+  for (unsigned i = 0; i < numOperands; ++i)
+    operands[i] = accessValueMap.getOperand(i);
+
+  if (sliceState != nullptr) {
+    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
+    // Append slice operands to 'operands' as symbols.
+    for (auto extraOperand : sliceState->lbOperands[0]) {
+      if (!llvm::is_contained(operands, extraOperand)) {
+        operands.push_back(extraOperand);
+        numSymbols++;
+      }
+    }
+  }
+  // We'll first associate the dims and symbols of the access map to the dims
+  // and symbols resp. of cst. This will change below once cst is
+  // fully constructed out.
+  cst.reset(numDims, numSymbols, 0, operands);
+
+  // Add equality constraints.
+  // Add inequalties for loop lower/upper bounds.
+  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
+    auto *operand = operands[i];
+    if (auto loop = getForInductionVarOwner(operand)) {
+      // Note that cst can now have more dimensions than accessMap if the
+      // bounds expressions involve outer loops or other symbols.
+      // TODO(bondhugula): rewrite this to use getInstIndexSet; this way
+      // conditionals will be handled when the latter supports it.
+      if (failed(cst.addAffineForOpDomain(loop)))
+        return failure();
+    } else {
+      // Has to be a valid symbol.
+      auto *symbol = operand;
+      assert(isValidSymbol(symbol));
+      // Check if the symbol is a constant.
+      if (auto *op = symbol->getDefiningOp()) {
+        if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
+          cst.setIdToConstant(*symbol, constOp.getValue());
+        }
+      }
+    }
+  }
+
+  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
+  if (sliceState != nullptr) {
+    // Add dim and symbol slice operands.
+    for (auto operand : sliceState->lbOperands[0]) {
+      cst.addInductionVarOrTerminalSymbol(operand);
+    }
+    // Add upper/lower bounds from 'sliceState' to 'cst'.
+    LogicalResult ret =
+        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
+                           sliceState->lbOperands[0]);
+    assert(succeeded(ret) &&
+           "should not fail as we never have semi-affine slice maps");
+    (void)ret;
+  }
+
+  // Add access function equalities to connect loop IVs to data dimensions.
+  if (failed(cst.composeMap(&accessValueMap))) {
+    op->emitError("getMemRefRegion: compose affine map failed");
+    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
+    return failure();
+  }
+
+  // Set all identifiers appearing after the first 'rank' identifiers as
+  // symbolic identifiers - so that the ones corresponding to the memref
+  // dimensions are the dimensional identifiers for the memref region.
+  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
+
+  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
+  // this memref region is symbolic.
+  SmallVector<AffineForOp, 4> enclosingIVs;
+  getLoopIVs(*op, &enclosingIVs);
+  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
+  enclosingIVs.resize(loopDepth);
+  SmallVector<Value *, 4> ids;
+  cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
+  for (auto *id : ids) {
+    AffineForOp iv;
+    if ((iv = getForInductionVarOwner(id)) &&
+        llvm::is_contained(enclosingIVs, iv) == false) {
+      cst.projectOut(id);
+    }
+  }
+
+  // Project out any local variables (these would have been added for any
+  // mod/divs).
+  cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
+
+  // Constant fold any symbolic identifiers.
+  cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
+                          /*num=*/cst.getNumSymbolIds());
+
+  assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");
+
+  // Add upper/lower bounds for each memref dimension with static size
+  // to guard against potential over-approximation from projection.
+  // TODO(andydavis) Support dynamic memref dimensions.
+  if (addMemRefDimBounds) {
+    auto memRefType = memref->getType().cast<MemRefType>();
+    for (unsigned r = 0; r < rank; r++) {
+      cst.addConstantLowerBound(r, 0);
+      int64_t dimSize = memRefType.getDimSize(r);
+      if (ShapedType::isDynamic(dimSize))
+        continue;
+      cst.addConstantUpperBound(r, dimSize - 1);
+    }
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
+  LLVM_DEBUG(cst.dump());
+  return success();
+}
+
+//  TODO(mlir-team): improve/complete this when we have target data.
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Returns the size of the region.
+Optional<int64_t> MemRefRegion::getRegionSize() {
+  auto memRefType = memref->getType().cast<MemRefType>();
+
+  auto layoutMaps = memRefType.getAffineMaps();
+  if (layoutMaps.size() > 1 ||
+      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+    return false;
+  }
+
+  // Indices to use for the DmaStart op.
+  // Indices for the original memref being DMAed from/to.
+  SmallVector<Value *, 4> memIndices;
+  // Indices for the faster buffer being DMAed into/from.
+  SmallVector<Value *, 4> bufIndices;
+
+  // Compute the extents of the buffer.
+  Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
+  if (!numElements.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+    return None;
+  }
+  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
+}
+
+/// Returns the size of memref data in bytes if it's statically shaped, None
+/// otherwise.  If the element of the memref has vector type, takes into account
+/// size of the vector as well.
+//  TODO(mlir-team): improve/complete this when we have target data.
+Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
+  if (!memRefType.hasStaticShape())
+    return None;
+  auto elementType = memRefType.getElementType();
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+    return None;
+
+  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
+    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
+  }
+  return sizeInBytes;
+}
+
+template <typename LoadOrStoreOpPointer>
+LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
+                                            bool emitError) {
+  static_assert(std::is_same<LoadOrStoreOpPointer, AffineLoadOp>::value ||
+                    std::is_same<LoadOrStoreOpPointer, AffineStoreOp>::value,
+                "argument should be either a AffineLoadOp or a AffineStoreOp");
+
+  Operation *opInst = loadOrStoreOp.getOperation();
+  MemRefRegion region(opInst->getLoc());
+  if (failed(region.compute(opInst, /*loopDepth=*/0, /*sliceState=*/nullptr,
+                            /*addMemRefDimBounds=*/false)))
+    return success();
+
+  LLVM_DEBUG(llvm::dbgs() << "Memory region");
+  LLVM_DEBUG(region.getConstraints()->dump());
+
+  bool outOfBounds = false;
+  unsigned rank = loadOrStoreOp.getMemRefType().getRank();
+
+  // For each dimension, check for out of bounds.
+  for (unsigned r = 0; r < rank; r++) {
+    FlatAffineConstraints ucst(*region.getConstraints());
+
+    // Intersect memory region with constraint capturing out of bounds (both out
+    // of upper and out of lower), and check if the constraint system is
+    // feasible. If it is, there is at least one point out of bounds.
+    SmallVector<int64_t, 4> ineq(rank + 1, 0);
+    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
+    // TODO(bondhugula): handle dynamic dim sizes.
+    if (dimSize == -1)
+      continue;
+
+    // Check for overflow: d_i >= memref dim size.
+    ucst.addConstantLowerBound(r, dimSize);
+    outOfBounds = !ucst.isEmpty();
+    if (outOfBounds && emitError) {
+      loadOrStoreOp.emitOpError()
+          << "memref out of upper bound access along dimension #" << (r + 1);
+    }
+
+    // Check for a negative index.
+    FlatAffineConstraints lcst(*region.getConstraints());
+    std::fill(ineq.begin(), ineq.end(), 0);
+    // d_i <= -1;
+    lcst.addConstantUpperBound(r, -1);
+    outOfBounds = !lcst.isEmpty();
+    if (outOfBounds && emitError) {
+      loadOrStoreOp.emitOpError()
+          << "memref out of lower bound access along dimension #" << (r + 1);
+    }
+  }
+  return failure(outOfBounds);
+}
+
+// Explicitly instantiate the template so that the compiler knows we need them!
+template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp,
+                                                     bool emitError);
+template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp,
+                                                     bool emitError);
+
+// Returns in 'positions' the Block positions of 'op' in each ancestor
+// Block from the Block containing operation, stopping at 'limitBlock'.
+static void findInstPosition(Operation *op, Block *limitBlock,
+                             SmallVectorImpl<unsigned> *positions) {
+  Block *block = op->getBlock();
+  while (block != limitBlock) {
+    // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
+    // rely on linear scans.
+    int instPosInBlock = std::distance(block->begin(), op->getIterator());
+    positions->push_back(instPosInBlock);
+    op = block->getParentOp();
+    block = op->getBlock();
+  }
+  std::reverse(positions->begin(), positions->end());
+}
+
+// Returns the Operation in a possibly nested set of Blocks, where the
+// position of the operation is represented by 'positions', which has a
+// Block position for each level of nesting.
+static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
+                                    unsigned level, Block *block) {
+  unsigned i = 0;
+  for (auto &op : *block) {
+    if (i != positions[level]) {
+      ++i;
+      continue;
+    }
+    if (level == positions.size() - 1)
+      return &op;
+    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
+      return getInstAtPosition(positions, level + 1,
+                               childAffineForOp.getBody());
+
+    for (auto &region : op.getRegions()) {
+      for (auto &b : region)
+        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
+          return ret;
+    }
+    return nullptr;
+  }
+  return nullptr;
+}
+
+// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
+LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value *, 8> &ivs,
+                                     FlatAffineConstraints *cst) {
+  for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
+    auto *value = cst->getIdValue(i);
+    if (ivs.count(value) == 0) {
+      assert(isForInductionVar(value));
+      auto loop = getForInductionVarOwner(value);
+      if (failed(cst->addAffineForOpDomain(loop)))
+        return failure();
+    }
+  }
+  return success();
+}
+
+// Returns the innermost common loop depth for the set of operations in 'ops'.
+// TODO(andydavis) Move this to LoopUtils.
+static unsigned
+getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
+                            SmallVectorImpl<AffineForOp> &surroundingLoops) {
+  unsigned numOps = ops.size();
+  assert(numOps > 0);
+
+  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
+  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+  for (unsigned i = 0; i < numOps; ++i) {
+    getLoopIVs(*ops[i], &loops[i]);
+    loopDepthLimit =
+        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
+  }
+
+  unsigned loopDepth = 0;
+  for (unsigned d = 0; d < loopDepthLimit; ++d) {
+    unsigned i;
+    for (i = 1; i < numOps; ++i) {
+      if (loops[i - 1][d] != loops[i][d])
+        return loopDepth;
+    }
+    surroundingLoops.push_back(loops[i - 1][d]);
+    ++loopDepth;
+  }
+  return loopDepth;
+}
+
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
+/// Returns 'Success' if union was computed, 'failure' otherwise.
+LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
+                                      ArrayRef<Operation *> opsB,
+                                      unsigned loopDepth,
+                                      unsigned numCommonLoops,
+                                      bool isBackwardSlice,
+                                      ComputationSliceState *sliceUnion) {
+  // Compute the union of slice bounds between all pairs in 'opsA' and
+  // 'opsB' in 'sliceUnionCst'.
+  FlatAffineConstraints sliceUnionCst;
+  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
+  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
+  for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
+    MemRefAccess srcAccess(opsA[i]);
+    for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
+      MemRefAccess dstAccess(opsB[j]);
+      if (srcAccess.memref != dstAccess.memref)
+        continue;
+      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
+      if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) ||
+          (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) {
+        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n.");
+        return failure();
+      }
+
+      bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) &&
+                              isa<AffineLoadOp>(dstAccess.opInst);
+      FlatAffineConstraints dependenceConstraints;
+      // Check dependence between 'srcAccess' and 'dstAccess'.
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
+          &dependenceConstraints, /*dependenceComponents=*/nullptr,
+          /*allowRAR=*/readReadAccesses);
+      if (result.value == DependenceResult::Failure) {
+        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n.");
+        return failure();
+      }
+      if (result.value == DependenceResult::NoDependence)
+        continue;
+      dependentOpPairs.push_back({opsA[i], opsB[j]});
+
+      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
+      ComputationSliceState tmpSliceState;
+      mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
+                                     loopDepth, isBackwardSlice,
+                                     &tmpSliceState);
+
+      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
+        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
+        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Unable to compute slice bound constraints\n.");
+          return failure();
+        }
+        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
+        continue;
+      }
+
+      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
+      FlatAffineConstraints tmpSliceCst;
+      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Unable to compute slice bound constraints\n.");
+        return failure();
+      }
+
+      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
+      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
+
+        // Pre-constraint id alignment: record loop IVs used in each constraint
+        // system.
+        SmallPtrSet<Value *, 8> sliceUnionIVs;
+        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
+          sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
+        SmallPtrSet<Value *, 8> tmpSliceIVs;
+        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
+          tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
+
+        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
+
+        // Post-constraint id alignment: add loop IV bounds missing after
+        // id alignment to constraint systems. This can occur if one constraint
+        // system uses an loop IV that is not used by the other. The call
+        // to unionBoundingBox below expects constraints for each Loop IV, even
+        // if they are the unsliced full loop bounds added here.
+        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
+          return failure();
+        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
+          return failure();
+      }
+      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
+      if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Unable to compute union bounding box of slice bounds."
+                      "\n.");
+        return failure();
+      }
+    }
+  }
+
+  // Empty union.
+  if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
+    return failure();
+
+  // Gather loops surrounding ops from loop nest where slice will be inserted.
+  SmallVector<Operation *, 4> ops;
+  for (auto &dep : dependentOpPairs) {
+    ops.push_back(isBackwardSlice ? dep.second : dep.first);
+  }
+  SmallVector<AffineForOp, 4> surroundingLoops;
+  unsigned innermostCommonLoopDepth =
+      getInnermostCommonLoopDepth(ops, surroundingLoops);
+  if (loopDepth > innermostCommonLoopDepth) {
+    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n.");
+    return failure();
+  }
+
+  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
+  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();
+
+  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
+  sliceUnionCst.convertLoopIVSymbolsToDims();
+  sliceUnion->clearBounds();
+  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
+  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
+
+  // Get slice bounds from slice union constraints 'sliceUnionCst'.
+  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
+                               opsA[0]->getContext(), &sliceUnion->lbs,
+                               &sliceUnion->ubs);
+
+  // Add slice bound operands of union.
+  SmallVector<Value *, 4> sliceBoundOperands;
+  sliceUnionCst.getIdValues(numSliceLoopIVs,
+                            sliceUnionCst.getNumDimAndSymbolIds(),
+                            &sliceBoundOperands);
+
+  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
+  sliceUnion->ivs.clear();
+  sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);
+
+  // Set loop nest insertion point to block start at 'loopDepth'.
+  sliceUnion->insertPoint =
+      isBackwardSlice
+          ? surroundingLoops[loopDepth - 1].getBody()->begin()
+          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
+
+  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+  // canonicalization.
+  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+  return success();
+}
+
+const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
+// Computes slice bounds by projecting out any loop IVs from
+// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
+// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
+// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
+void mlir::getComputationSliceState(
+    Operation *depSourceOp, Operation *depSinkOp,
+    FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
+    bool isBackwardSlice, ComputationSliceState *sliceState) {
+  // Get loop nest surrounding src operation.
+  SmallVector<AffineForOp, 4> srcLoopIVs;
+  getLoopIVs(*depSourceOp, &srcLoopIVs);
+  unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+  // Get loop nest surrounding dst operation.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*depSinkOp, &dstLoopIVs);
+  unsigned numDstLoopIVs = dstLoopIVs.size();
+
+  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
+         (isBackwardSlice && loopDepth <= numDstLoopIVs));
+
+  // Project out dimensions other than those up to 'loopDepth'.
+  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
+  unsigned num =
+      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
+  dependenceConstraints->projectOut(pos, num);
+
+  // Add slice loop IV values to 'sliceState'.
+  unsigned offset = isBackwardSlice ? 0 : loopDepth;
+  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
+  dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
+                                     &sliceState->ivs);
+
+  // Set up lower/upper bound affine maps for the slice.
+  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
+  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
+
+  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
+  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
+                                        depSourceOp->getContext(),
+                                        &sliceState->lbs, &sliceState->ubs);
+
+  // Set up bound operands for the slice's lower and upper bounds.
+  SmallVector<Value *, 4> sliceBoundOperands;
+  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
+  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
+    if (i < offset || i >= offset + numSliceLoopIVs) {
+      sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
+    }
+  }
+
+  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+  // canonicalization.
+  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+
+  // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
+  sliceState->insertPoint =
+      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
+                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
+
+  llvm::SmallDenseSet<Value *, 8> sequentialLoops;
+  if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
+    // For read-read access pairs, clear any slice bounds on sequential loops.
+    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
+    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
+                       &sequentialLoops);
+  }
+  // Clear all sliced loop bounds beginning at the first sequential loop, or
+  // first loop with a slice fusion barrier attribute..
+  // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of
+  // using 'kSliceFusionBarrierAttrName'.
+  auto getSliceLoop = [&](unsigned i) {
+    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
+  };
+  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
+    Value *iv = getSliceLoop(i).getInductionVar();
+    if (sequentialLoops.count(iv) == 0 &&
+        getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr)
+      continue;
+    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
+      sliceState->lbs[j] = AffineMap();
+      sliceState->ubs[j] = AffineMap();
+    }
+    break;
+  }
+}
+
+/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
+/// updates the slice loop bounds with any non-null bound maps specified in
+/// 'sliceState', and inserts this slice into the loop nest surrounding
+/// 'dstOpInst' at loop depth 'dstLoopDepth'.
+// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
+// aren't necessarily a one-to-one relation b/w the source and destination. The
+// relation between the source and destination could be many-to-many in general.
+// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases
+// where the dependence from the source to the destination does not cover the
+// entire destination index set. Subtract out the dependent destination
+// iterations from destination index set and check for emptiness --- this is one
+// solution.
+AffineForOp
+mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
+                                     unsigned dstLoopDepth,
+                                     ComputationSliceState *sliceState) {
+  // Get loop nest surrounding src operation.
+  SmallVector<AffineForOp, 4> srcLoopIVs;
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
+  unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+  // Get loop nest surrounding dst operation.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*dstOpInst, &dstLoopIVs);
+  unsigned dstLoopIVsSize = dstLoopIVs.size();
+  if (dstLoopDepth > dstLoopIVsSize) {
+    dstOpInst->emitError("invalid destination loop depth");
+    return AffineForOp();
+  }
+
+  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
+  SmallVector<unsigned, 4> positions;
+  // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
+  findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(),
+                   &positions);
+
+  // Clone src loop nest and insert it a the beginning of the operation block
+  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
+  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
+  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
+  auto sliceLoopNest =
+      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
+
+  Operation *sliceInst =
+      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
+  // Get loop nest surrounding 'sliceInst'.
+  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
+  getLoopIVs(*sliceInst, &sliceSurroundingLoops);
+
+  // Sanity check.
+  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
+  (void)sliceSurroundingLoopsSize;
+  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
+  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
+  (void)sliceLoopLimit;
+  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
+
+  // Update loop bounds for loops in 'sliceLoopNest'.
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
+    if (AffineMap lbMap = sliceState->lbs[i])
+      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
+    if (AffineMap ubMap = sliceState->ubs[i])
+      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
+  }
+  return sliceLoopNest;
+}
+
+// Constructs  MemRefAccess populating it with the memref, its indices and
+// opinst from 'loadOrStoreOpInst'.
+MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
+  if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) {
+    memref = loadOp.getMemRef();
+    opInst = loadOrStoreOpInst;
+    auto loadMemrefType = loadOp.getMemRefType();
+    indices.reserve(loadMemrefType.getRank());
+    for (auto *index : loadOp.getIndices()) {
+      indices.push_back(index);
+    }
+  } else {
+    assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected");
+    auto storeOp = dyn_cast<AffineStoreOp>(loadOrStoreOpInst);
+    opInst = loadOrStoreOpInst;
+    memref = storeOp.getMemRef();
+    auto storeMemrefType = storeOp.getMemRefType();
+    indices.reserve(storeMemrefType.getRank());
+    for (auto *index : storeOp.getIndices()) {
+      indices.push_back(index);
+    }
+  }
+}
+
+unsigned MemRefAccess::getRank() const {
+  return memref->getType().cast<MemRefType>().getRank();
+}
+
+bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
+
+/// Returns the nesting depth of this statement, i.e., the number of loops
+/// surrounding this statement.
+unsigned mlir::getNestingDepth(Operation &op) {
+  Operation *currOp = &op;
+  unsigned depth = 0;
+  while ((currOp = currOp->getParentOp())) {
+    if (isa<AffineForOp>(currOp))
+      depth++;
+  }
+  return depth;
+}
+
+/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
+/// where each lists loops from outer-most to inner-most in loop nest.
+unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
+  SmallVector<AffineForOp, 4> loopsA, loopsB;
+  getLoopIVs(A, &loopsA);
+  getLoopIVs(B, &loopsB);
+
+  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
+  unsigned numCommonLoops = 0;
+  for (unsigned i = 0; i < minNumLoops; ++i) {
+    if (loopsA[i].getOperation() != loopsB[i].getOperation())
+      break;
+    ++numCommonLoops;
+  }
+  return numCommonLoops;
+}
+
+static Optional<int64_t> getMemoryFootprintBytes(Block &block,
+                                                 Block::iterator start,
+                                                 Block::iterator end,
+                                                 int memorySpace) {
+  SmallDenseMap<Value *, std::unique_ptr<MemRefRegion>, 4> regions;
+
+  // Walk this 'affine.for' operation to gather all memory regions.
+  bool error = false;
+  block.walk(start, end, [&](Operation *opInst) {
+    if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
+      // Neither load nor a store op.
+      return;
+    }
+
+    // Compute the memref region symbolic in any IVs enclosing this block.
+    auto region = llvm::make_unique<MemRefRegion>(opInst->getLoc());
+    if (failed(
+            region->compute(opInst,
+                            /*loopDepth=*/getNestingDepth(*block.begin())))) {
+      opInst->emitError("Error obtaining memory region\n");
+      error = true;
+      return;
+    }
+    auto it = regions.find(region->memref);
+    if (it == regions.end()) {
+      regions[region->memref] = std::move(region);
+    } else if (failed(it->second->unionBoundingBox(*region))) {
+      opInst->emitWarning(
+          "getMemoryFootprintBytes: unable to perform a union on a memory "
+          "region");
+      error = true;
+      return;
+    }
+  });
+
+  if (error)
+    return None;
+
+  int64_t totalSizeInBytes = 0;
+  for (const auto &region : regions) {
+    Optional<int64_t> size = region.second->getRegionSize();
+    if (!size.hasValue())
+      return None;
+    totalSizeInBytes += size.getValue();
+  }
+  return totalSizeInBytes;
+}
+
+Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
+                                                int memorySpace) {
+  auto *forInst = forOp.getOperation();
+  return ::getMemoryFootprintBytes(
+      *forInst->getBlock(), Block::iterator(forInst),
+      std::next(Block::iterator(forInst)), memorySpace);
+}
+
+/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
+/// at 'forOp'.
+void mlir::getSequentialLoops(
+    AffineForOp forOp, llvm::SmallDenseSet<Value *, 8> *sequentialLoops) {
+  forOp.getOperation()->walk([&](Operation *op) {
+    if (auto innerFor = dyn_cast<AffineForOp>(op))
+      if (!isLoopParallel(innerFor))
+        sequentialLoops->insert(innerFor.getInductionVar());
+  });
+}
+
+/// Returns true if 'forOp' is parallel.
+bool mlir::isLoopParallel(AffineForOp forOp) {
+  // Collect all load and store ops in loop nest rooted at 'forOp'.
+  SmallVector<Operation *, 8> loadAndStoreOpInsts;
+  bool hasSideEffectingOps = false;
+  forOp.getOperation()->walk([&](Operation *opInst) {
+    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+      return loadAndStoreOpInsts.push_back(opInst);
+    if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
+        !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect()) {
+      hasSideEffectingOps = true;
+    }
+  });
+  // Stop early if the loop has unknown ops with side effects.
+  if (hasSideEffectingOps)
+    return false;
+
+  // Dep check depth would be number of enclosing loops + 1.
+  unsigned depth = getNestingDepth(*forOp.getOperation()) + 1;
+
+  // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
+  for (auto *srcOpInst : loadAndStoreOpInsts) {
+    MemRefAccess srcAccess(srcOpInst);
+    for (auto *dstOpInst : loadAndStoreOpInsts) {
+      MemRefAccess dstAccess(dstOpInst);
+      FlatAffineConstraints dependenceConstraints;
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, depth, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (result.value != DependenceResult::NoDependence)
+        return false;
+    }
+  }
+  return true;
+}
diff --git a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
new file mode 100644
index 0000000..2306156
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -0,0 +1,241 @@
+//===- VectorAnalysis.cpp - Analysis for Vectorization --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+
+///
+/// Implements Analysis functions specific to vectors which support
+/// the vectorization and vectorization materialization passes.
+///
+
+using namespace mlir;
+
+using llvm::SetVector;
+
+Optional<SmallVector<unsigned, 4>>
+mlir::shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape) {
+  if (superShape.size() < subShape.size()) {
+    return Optional<SmallVector<unsigned, 4>>();
+  }
+
+  // Starting from the end, compute the integer divisors.
+  // Set the boolean `divides` if integral division is not possible.
+  std::vector<unsigned> result;
+  result.reserve(superShape.size());
+  bool divides = true;
+  auto divide = [&divides, &result](int superSize, int subSize) {
+    assert(superSize > 0 && "superSize must be > 0");
+    assert(subSize > 0 && "subSize must be > 0");
+    divides &= (superSize % subSize == 0);
+    result.push_back(superSize / subSize);
+  };
+  functional::zipApply(
+      divide, SmallVector<int64_t, 8>{superShape.rbegin(), superShape.rend()},
+      SmallVector<int64_t, 8>{subShape.rbegin(), subShape.rend()});
+
+  // If integral division does not occur, return and let the caller decide.
+  if (!divides) {
+    return None;
+  }
+
+  // At this point we computed the ratio (in reverse) for the common
+  // size. Fill with the remaining entries from the super-vector shape (still in
+  // reverse).
+  int commonSize = subShape.size();
+  std::copy(superShape.rbegin() + commonSize, superShape.rend(),
+            std::back_inserter(result));
+
+  assert(result.size() == superShape.size() &&
+         "super to sub shape ratio is not of the same size as the super rank");
+
+  // Reverse again to get it back in the proper order and return.
+  return SmallVector<unsigned, 4>{result.rbegin(), result.rend()};
+}
+
+Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                    VectorType subVectorType) {
+  assert(superVectorType.getElementType() == subVectorType.getElementType() &&
+         "vector types must be of the same elemental type");
+  return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
+}
+
+/// Constructs a permutation map from memref indices to vector dimension.
+///
+/// The implementation uses the knowledge of the mapping of enclosing loop to
+/// vector dimension. `enclosingLoopToVectorDim` carries this information as a
+/// map with:
+///   - keys representing "vectorized enclosing loops";
+///   - values representing the corresponding vector dimension.
+/// The algorithm traverses "vectorized enclosing loops" and extracts the
+/// at-most-one MemRef index that is invariant along said loop. This index is
+/// guaranteed to be at most one by construction: otherwise the MemRef is not
+/// vectorizable.
+/// If this invariant index is found, it is added to the permutation_map at the
+/// proper vector dimension.
+/// If no index is found to be invariant, 0 is added to the permutation_map and
+/// corresponds to a vector broadcast along that dimension.
+///
+/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
+/// signalling that no permutation map can be constructed given
+/// `enclosingLoopToVectorDim`.
+///
+/// Examples can be found in the documentation of `makePermutationMap`, in the
+/// header file.
+static AffineMap makePermutationMap(
+    ArrayRef<Value *> indices,
+    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
+  if (enclosingLoopToVectorDim.empty())
+    return AffineMap();
+  MLIRContext *context =
+      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
+  using functional::makePtrDynCaster;
+  using functional::map;
+  SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
+                                  getAffineConstantExpr(0, context));
+
+  for (auto kvp : enclosingLoopToVectorDim) {
+    assert(kvp.second < perm.size());
+    auto invariants = getInvariantAccesses(
+        cast<AffineForOp>(kvp.first).getInductionVar(), indices);
+    unsigned numIndices = indices.size();
+    unsigned countInvariantIndices = 0;
+    for (unsigned dim = 0; dim < numIndices; ++dim) {
+      if (!invariants.count(indices[dim])) {
+        assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
+               "permutationMap already has an entry along dim");
+        perm[kvp.second] = getAffineDimExpr(dim, context);
+      } else {
+        ++countInvariantIndices;
+      }
+    }
+    assert((countInvariantIndices == numIndices ||
+            countInvariantIndices == numIndices - 1) &&
+           "Vectorization prerequisite violated: at most 1 index may be "
+           "invariant wrt a vectorized loop");
+  }
+  return AffineMap::get(indices.size(), 0, perm);
+}
+
+/// Implementation detail that walks up the parents and records the ones with
+/// the specified type.
+/// TODO(ntv): could also be implemented as a collect parents followed by a
+/// filter and made available outside this file.
+template <typename T>
+static SetVector<Operation *> getParentsOfType(Operation *op) {
+  SetVector<Operation *> res;
+  auto *current = op;
+  while (auto *parent = current->getParentOp()) {
+    if (auto typedParent = dyn_cast<T>(parent)) {
+      assert(res.count(parent) == 0 && "Already inserted");
+      res.insert(parent);
+    }
+    current = parent;
+  }
+  return res;
+}
+
+/// Returns the enclosing AffineForOp, from closest to farthest.
+static SetVector<Operation *> getEnclosingforOps(Operation *op) {
+  return getParentsOfType<AffineForOp>(op);
+}
+
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value *> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+  DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
+  auto enclosingLoops = getEnclosingforOps(op);
+  for (auto *forInst : enclosingLoops) {
+    auto it = loopToVectorDim.find(forInst);
+    if (it != loopToVectorDim.end()) {
+      enclosingLoopToVectorDim.insert(*it);
+    }
+  }
+  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
+}
+
+bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op,
+                                             VectorType subVectorType) {
+  // First, extract the vector type and ditinguish between:
+  //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
+  //      vector.transfer_write); and
+  //   b. ops that *may* lower a super-vector (all other ops).
+  // The ops that *may* lower a super-vector only do so if the super-vector to
+  // sub-vector ratio exists. The ops that *must* lower a super-vector are
+  // explicitly checked for this property.
+  /// TODO(ntv): there should be a single function for all ops to do this so we
+  /// do not have to special case. Maybe a trait, or just a method, unclear atm.
+  bool mustDivide = false;
+  (void)mustDivide;
+  VectorType superVectorType;
+  if (auto read = dyn_cast<vector::VectorTransferReadOp>(op)) {
+    superVectorType = read.getResultType();
+    mustDivide = true;
+  } else if (auto write = dyn_cast<vector::VectorTransferWriteOp>(op)) {
+    superVectorType = write.getVectorType();
+    mustDivide = true;
+  } else if (op.getNumResults() == 0) {
+    if (!isa<ReturnOp>(op)) {
+      op.emitError("NYI: assuming only return operations can have 0 "
+                   " results at this point");
+    }
+    return false;
+  } else if (op.getNumResults() == 1) {
+    if (auto v = op.getResult(0)->getType().dyn_cast<VectorType>()) {
+      superVectorType = v;
+    } else {
+      // Not a vector type.
+      return false;
+    }
+  } else {
+    // Not a vector.transfer and has more than 1 result, fail hard for now to
+    // wake us up when something changes.
+    op.emitError("NYI: operation has more than 1 result");
+    return false;
+  }
+
+  // Get the ratio.
+  auto ratio = shapeRatio(superVectorType, subVectorType);
+
+  // Sanity check.
+  assert((ratio.hasValue() || !mustDivide) &&
+         "vector.transfer operation in which super-vector size is not an"
+         " integer multiple of sub-vector size");
+
+  // This catches cases that are not strictly necessary to have multiplicity but
+  // still aren't divisible by the sub-vector shape.
+  // This could be useful information if we wanted to reshape at the level of
+  // the vector type (but we would have to look at the compute and distinguish
+  // between parallel, reduction and possibly other cases.
+  if (!ratio.hasValue()) {
+    return false;
+  }
+
+  return true;
+}
diff --git a/third_party/mlir/lib/Analysis/Verifier.cpp b/third_party/mlir/lib/Analysis/Verifier.cpp
new file mode 100644
index 0000000..d250996
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/Verifier.cpp
@@ -0,0 +1,273 @@
+//===- Verifier.cpp - MLIR Verifier Implementation ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the verify() methods on the various IR types, performing
+// (potentially expensive) checks on the holistic structure of the code.  This
+// can be used for detecting bugs in compiler transformations and hand written
+// .mlir files.
+//
+// The checks in this file are only for things that can occur as part of IR
+// transformations: e.g. violation of dominance information, malformed operation
+// attributes, etc.  MLIR supports transformations moving IR through locally
+// invalid states (e.g. unlinking an operation from a block before re-inserting
+// it in a new place), but each transformation must complete with the IR in a
+// valid form.
+//
+// This should not check for things that are always wrong by construction (e.g.
+// attributes or other immutable structures that are incorrect), because those
+// are not mutable and can be checked at time of construction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
+using namespace mlir;
+
+namespace {
+/// This class encapsulates all the state used to verify an operation region.
+class OperationVerifier {
+public:
+  explicit OperationVerifier(MLIRContext *ctx)
+      : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
+
+  /// Verify the given operation.
+  LogicalResult verify(Operation &op);
+
+  /// Returns the registered dialect for a dialect-specific attribute.
+  Dialect *getDialectForAttribute(const NamedAttribute &attr) {
+    assert(attr.first.strref().contains('.') && "expected dialect attribute");
+    auto dialectNamePair = attr.first.strref().split('.');
+    return ctx->getRegisteredDialect(dialectNamePair.first);
+  }
+
+  /// Returns if the given string is valid to use as an identifier name.
+  bool isValidName(StringRef name) { return identifierRegex.match(name); }
+
+private:
+  /// Verify the given potentially nested region or block.
+  LogicalResult verifyRegion(Region &region);
+  LogicalResult verifyBlock(Block &block);
+  LogicalResult verifyOperation(Operation &op);
+
+  /// Verify the dominance within the given IR unit.
+  LogicalResult verifyDominance(Region &region);
+  LogicalResult verifyDominance(Operation &op);
+
+  /// Emit an error for the given block.
+  InFlightDiagnostic emitError(Block &bb, const Twine &message) {
+    // Take the location information for the first operation in the block.
+    if (!bb.empty())
+      return bb.front().emitError(message);
+
+    // Worst case, fall back to using the parent's location.
+    return mlir::emitError(bb.getParent()->getLoc(), message);
+  }
+
+  /// The current context for the verifier.
+  MLIRContext *ctx;
+
+  /// Dominance information for this operation, when checking dominance.
+  DominanceInfo *domInfo = nullptr;
+
+  /// Regex checker for attribute names.
+  llvm::Regex identifierRegex;
+
+  /// Mapping between dialect namespace and if that dialect supports
+  /// unregistered operations.
+  llvm::StringMap<bool> dialectAllowsUnknownOps;
+};
+} // end anonymous namespace
+
+/// Verify the given operation.
+LogicalResult OperationVerifier::verify(Operation &op) {
+  // Verify the operation first.
+  if (failed(verifyOperation(op)))
+    return failure();
+
+  // Since everything looks structurally ok to this point, we do a dominance
+  // check for any nested regions. We do this as a second pass since malformed
+  // CFG's can cause dominator analysis constructure to crash and we want the
+  // verifier to be resilient to malformed code.
+  DominanceInfo theDomInfo(&op);
+  domInfo = &theDomInfo;
+  for (auto &region : op.getRegions())
+    if (failed(verifyDominance(region)))
+      return failure();
+
+  domInfo = nullptr;
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyRegion(Region &region) {
+  if (region.empty())
+    return success();
+
+  // Verify the first block has no predecessors.
+  auto *firstBB = &region.front();
+  if (!firstBB->hasNoPredecessors())
+    return mlir::emitError(region.getLoc(),
+                           "entry block of region may not have predecessors");
+
+  // Verify each of the blocks within the region.
+  for (auto &block : region)
+    if (failed(verifyBlock(block)))
+      return failure();
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyBlock(Block &block) {
+  for (auto *arg : block.getArguments())
+    if (arg->getOwner() != &block)
+      return emitError(block, "block argument not owned by block");
+
+  // Verify that this block has a terminator.
+  if (block.empty())
+    return emitError(block, "block with no terminator");
+
+  // Verify the non-terminator operations separately so that we can verify
+  // they has no successors.
+  for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
+    if (op.getNumSuccessors() != 0)
+      return op.emitError(
+          "operation with block successors must terminate its parent block");
+
+    if (failed(verifyOperation(op)))
+      return failure();
+  }
+
+  // Verify the terminator.
+  if (failed(verifyOperation(block.back())))
+    return failure();
+  if (block.back().isKnownNonTerminator())
+    return emitError(block, "block with no terminator");
+
+  // Verify that this block is not branching to a block of a different
+  // region.
+  for (Block *successor : block.getSuccessors())
+    if (successor->getParent() != block.getParent())
+      return block.back().emitOpError(
+          "branching to block of a different region");
+
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyOperation(Operation &op) {
+  // Check that operands are non-nil and structurally ok.
+  for (auto *operand : op.getOperands())
+    if (!operand)
+      return op.emitError("null operand found");
+
+  /// Verify that all of the attributes are okay.
+  for (auto attr : op.getAttrs()) {
+    if (!identifierRegex.match(attr.first))
+      return op.emitError("invalid attribute name '") << attr.first << "'";
+
+    // Check for any optional dialect specific attributes.
+    if (!attr.first.strref().contains('.'))
+      continue;
+    if (auto *dialect = getDialectForAttribute(attr))
+      if (failed(dialect->verifyOperationAttribute(&op, attr)))
+        return failure();
+  }
+
+  // If we can get operation info for this, check the custom hook.
+  auto *opInfo = op.getAbstractOperation();
+  if (opInfo && failed(opInfo->verifyInvariants(&op)))
+    return failure();
+
+  // Verify that all child regions are ok.
+  for (auto &region : op.getRegions())
+    if (failed(verifyRegion(region)))
+      return failure();
+
+  // If this is a registered operation, there is nothing left to do.
+  if (opInfo)
+    return success();
+
+  // Otherwise, verify that the parent dialect allows un-registered operations.
+  auto dialectPrefix = op.getName().getDialect();
+
+  // Check for an existing answer for the operation dialect.
+  auto it = dialectAllowsUnknownOps.find(dialectPrefix);
+  if (it == dialectAllowsUnknownOps.end()) {
+    // If the operation dialect is registered, query it directly.
+    if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
+      it = dialectAllowsUnknownOps
+               .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
+               .first;
+    // Otherwise, conservatively allow unknown operations.
+    else
+      it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
+  }
+
+  if (!it->second) {
+    return op.emitError("unregistered operation '")
+           << op.getName() << "' found in dialect ('" << dialectPrefix
+           << "') that does not allow unknown operations";
+  }
+
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Region &region) {
+  // Verify the dominance of each of the held operations.
+  for (auto &block : region)
+    for (auto &op : block)
+      if (failed(verifyDominance(op)))
+        return failure();
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Operation &op) {
+  // Check that operands properly dominate this use.
+  for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e;
+       ++operandNo) {
+    auto *operand = op.getOperand(operandNo);
+    if (domInfo->properlyDominates(operand, &op))
+      continue;
+
+    auto diag = op.emitError("operand #")
+                << operandNo << " does not dominate this use";
+    if (auto *useOp = operand->getDefiningOp())
+      diag.attachNote(useOp->getLoc()) << "operand defined here";
+    return failure();
+  }
+
+  // Verify the dominance of each of the nested blocks within this operation.
+  for (auto &region : op.getRegions())
+    if (failed(verifyDominance(region)))
+      return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Entrypoint
+//===----------------------------------------------------------------------===//
+
+/// Perform (potentially expensive) checks of invariants, used to detect
+/// compiler bugs.  On error, this reports the error through the MLIRContext and
+/// returns failure.
+LogicalResult mlir::verify(Operation *op) {
+  return OperationVerifier(op->getContext()).verify(*op);
+}
diff --git a/third_party/mlir/lib/CMakeLists.txt b/third_party/mlir/lib/CMakeLists.txt
new file mode 100644
index 0000000..fece5cb
--- /dev/null
+++ b/third_party/mlir/lib/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_subdirectory(AffineOps)
+add_subdirectory(Analysis)
+add_subdirectory(Conversion)
+add_subdirectory(Dialect)
+add_subdirectory(EDSC)
+add_subdirectory(ExecutionEngine)
+add_subdirectory(IR)
+add_subdirectory(LLVMIR)
+add_subdirectory(Linalg)
+add_subdirectory(Parser)
+add_subdirectory(Pass)
+add_subdirectory(Quantizer)
+add_subdirectory(SDBM)
+add_subdirectory(StandardOps)
+add_subdirectory(Support)
+add_subdirectory(TableGen)
+add_subdirectory(Target)
+add_subdirectory(Transforms)
+add_subdirectory(Translation)
+add_subdirectory(VectorOps)
diff --git a/third_party/mlir/lib/Conversion/CMakeLists.txt b/third_party/mlir/lib/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..1ddd103
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_subdirectory(LoopsToGPU)
+add_subdirectory(ControlFlowToCFG)
+add_subdirectory(GPUToCUDA)
+add_subdirectory(GPUToNVVM)
+add_subdirectory(GPUToSPIRV)
+add_subdirectory(StandardToLLVM)
+add_subdirectory(StandardToSPIRV)
diff --git a/third_party/mlir/lib/Conversion/ControlFlowToCFG/CMakeLists.txt b/third_party/mlir/lib/Conversion/ControlFlowToCFG/CMakeLists.txt
new file mode 100644
index 0000000..d8793c2
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/ControlFlowToCFG/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_llvm_library(MLIRControlFlowToCFG
+  ConvertControlFlowToCFG.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToCFG
+)
+add_dependencies(
+  MLIRControlFlowToCFG
+
+  MLIRLoopOps
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+)
+target_link_libraries(
+  MLIRControlFlowToCFG
+
+  MLIRLoopOps
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+)
diff --git a/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
new file mode 100644
index 0000000..9535dc7
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
@@ -0,0 +1,278 @@
+//===- ConvertControlFlowToCFG.cpp - ControlFlow to CFG conversion --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert loop.for, loop.if and loop.terminator
+// ops into standard CFG ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+namespace {
+
+struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
+  void runOnFunction() override;
+};
+
+// Create a CFG subgraph for the loop around its body blocks (if the body
+// contained other loops, they have been already lowered to a flow of blocks).
+// Maintain the invariants that a CFG subgraph created for any loop has a single
+// entry and a single exit, and that the entry/exit blocks are respectively
+// first/last blocks in the parent region.  The original loop operation is
+// replaced by the initialization operations that set up the initial value of
+// the loop induction variable (%iv) and computes the loop bounds that are loop-
+// invariant for affine loops.  The operations following the original loop.for
+// are split out into a separate continuation (exit) block. A condition block is
+// created before the continuation block. It checks the exit condition of the
+// loop and branches either to the continuation block, or to the first block of
+// the body. Induction variable modification is appended to the last block of
+// the body (which is the exit block from the body subgraph thanks to the
+// invariant we maintain) along with a branch that loops back to the condition
+// block.
+//
+//      +---------------------------------+
+//      |   <code before the ForOp>       |
+//      |   <compute initial %iv value>   |
+//      |   br cond(%iv)                  |
+//      +---------------------------------+
+//             |
+//  -------|   |
+//  |      v   v
+//  |   +--------------------------------+
+//  |   | cond(%iv):                     |
+//  |   |   <compare %iv to upper bound> |
+//  |   |   cond_br %r, body, end        |
+//  |   +--------------------------------+
+//  |          |               |
+//  |          |               -------------|
+//  |          v                            |
+//  |   +--------------------------------+  |
+//  |   | body-first:                    |  |
+//  |   |   <body contents>              |  |
+//  |   +--------------------------------+  |
+//  |                   |                   |
+//  |                  ...                  |
+//  |                   |                   |
+//  |   +--------------------------------+  |
+//  |   | body-last:                     |  |
+//  |   |   <body contents>              |  |
+//  |   |   %new_iv =<add step to %iv>   |  |
+//  |   |   br cond(%new_iv)             |  |
+//  |   +--------------------------------+  |
+//  |          |                            |
+//  |-----------        |--------------------
+//                      v
+//      +--------------------------------+
+//      | end:                           |
+//      |   <code after the ForOp> |
+//      +--------------------------------+
+//
+struct ForLowering : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ForOp forOp,
+                                     PatternRewriter &rewriter) const override;
+};
+
+// Create a CFG subgraph for the loop.if operation (including its "then" and
+// optional "else" operation blocks).  We maintain the invariants that the
+// subgraph has a single entry and a single exit point, and that the entry/exit
+// blocks are respectively the first/last block of the enclosing region. The
+// operations following the loop.if are split into a continuation (subgraph
+// exit) block. The condition is lowered to a chain of blocks that implement the
+// short-circuit scheme.  Condition blocks are created by splitting out an empty
+// block from the block that contains the loop.if operation.  They
+// conditionally branch to either the first block of the "then" region, or to
+// the first block of the "else" region.  If the latter is absent, they branch
+// to the continuation block instead.  The last blocks of "then" and "else"
+// regions (which are known to be exit blocks thanks to the invariant we
+// maintain).
+//
+//      +--------------------------------+
+//      | <code before the IfOp>         |
+//      | cond_br %cond, %then, %else    |
+//      +--------------------------------+
+//             |              |
+//             |              --------------|
+//             v                            |
+//      +--------------------------------+  |
+//      | then:                          |  |
+//      |   <then contents>              |  |
+//      |   br continue                  |  |
+//      +--------------------------------+  |
+//             |                            |
+//   |----------               |-------------
+//   |                         V
+//   |  +--------------------------------+
+//   |  | else:                          |
+//   |  |   <else contents>              |
+//   |  |   br continue                  |
+//   |  +--------------------------------+
+//   |         |
+//   ------|   |
+//         v   v
+//      +--------------------------------+
+//      | continue:                      |
+//      |   <code after the IfOp>  |
+//      +--------------------------------+
+//
+struct IfLowering : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(IfOp ifOp,
+                                     PatternRewriter &rewriter) const override;
+};
+
+struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
+  using OpRewritePattern<TerminatorOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TerminatorOp op,
+                                     PatternRewriter &rewriter) const override {
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+};
+} // namespace
+
+PatternMatchResult
+ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+
+  // Start by splitting the block containing the 'loop.for' into two parts.
+  // The part before will get the init code, the part after will be the end
+  // point.
+  auto *initBlock = rewriter.getInsertionBlock();
+  auto initPosition = rewriter.getInsertionPoint();
+  auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
+
+  // Use the first block of the loop body as the condition block since it is
+  // the block that has the induction variable as its argument.  Split out
+  // all operations from the first block into a new block.  Move all body
+  // blocks from the loop body region to the region containing the loop.
+  auto *conditionBlock = &forOp.region().front();
+  auto *firstBodyBlock =
+      rewriter.splitBlock(conditionBlock, conditionBlock->begin());
+  auto *lastBodyBlock = &forOp.region().back();
+  rewriter.inlineRegionBefore(forOp.region(), endBlock);
+  auto *iv = conditionBlock->getArgument(0);
+
+  // Append the induction variable stepping logic to the last body block and
+  // branch back to the condition block.  Construct an expression f :
+  // (x -> x+step) and apply this expression to the induction variable.
+  rewriter.setInsertionPointToEnd(lastBodyBlock);
+  auto *step = forOp.step();
+  auto *stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
+  if (!stepped)
+    return matchFailure();
+  rewriter.create<BranchOp>(loc, conditionBlock, stepped);
+
+  // Compute loop bounds before branching to the condition.
+  rewriter.setInsertionPointToEnd(initBlock);
+  Value *lowerBound = forOp.lowerBound();
+  Value *upperBound = forOp.upperBound();
+  if (!lowerBound || !upperBound)
+    return matchFailure();
+  rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
+
+  // With the body block done, we can fill in the condition block.
+  rewriter.setInsertionPointToEnd(conditionBlock);
+  auto comparison =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
+
+  rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
+                                ArrayRef<Value *>(), endBlock,
+                                ArrayRef<Value *>());
+  // Ok, we're done!
+  rewriter.replaceOp(forOp, {});
+  return matchSuccess();
+}
+
+PatternMatchResult
+IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
+  auto loc = ifOp.getLoc();
+
+  // Start by splitting the block containing the 'loop.if' into two parts.
+  // The part before will contain the condition, the part after will be the
+  // continuation point.
+  auto *condBlock = rewriter.getInsertionBlock();
+  auto opPosition = rewriter.getInsertionPoint();
+  auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
+
+  // Move blocks from the "then" region to the region containing 'loop.if',
+  // place it before the continuation block, and branch to it.
+  auto &thenRegion = ifOp.thenRegion();
+  auto *thenBlock = &thenRegion.front();
+  rewriter.setInsertionPointToEnd(&thenRegion.back());
+  rewriter.create<BranchOp>(loc, continueBlock);
+  rewriter.inlineRegionBefore(thenRegion, continueBlock);
+
+  // Move blocks from the "else" region (if present) to the region containing
+  // 'loop.if', place it before the continuation block and branch to it.  It
+  // will be placed after the "then" regions.
+  auto *elseBlock = continueBlock;
+  auto &elseRegion = ifOp.elseRegion();
+  if (!elseRegion.empty()) {
+    elseBlock = &elseRegion.front();
+    rewriter.setInsertionPointToEnd(&elseRegion.back());
+    rewriter.create<BranchOp>(loc, continueBlock);
+    rewriter.inlineRegionBefore(elseRegion, continueBlock);
+  }
+
+  rewriter.setInsertionPointToEnd(condBlock);
+  rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
+                                /*trueArgs=*/ArrayRef<Value *>(), elseBlock,
+                                /*falseArgs=*/ArrayRef<Value *>());
+
+  // Ok, we're done!
+  rewriter.replaceOp(ifOp, {});
+  return matchSuccess();
+}
+
+void mlir::populateLoopToStdConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
+}
+
+void ControlFlowToCFGPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  ConversionTarget target(getContext());
+  target.addLegalDialect<StandardOpsDialect>();
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+FunctionPassBase *mlir::createConvertToCFGPass() {
+  return new ControlFlowToCFGPass();
+}
+
+static PassRegistration<ControlFlowToCFGPass>
+    pass("lower-to-cfg", "Convert control flow operations to ");
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/third_party/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
new file mode 100644
index 0000000..fbaf36c
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
@@ -0,0 +1,17 @@
+if(MLIR_CUDA_CONVERSIONS_ENABLED)
+  llvm_map_components_to_libnames(nvptx "NVPTX")
+
+  add_llvm_library(MLIRGPUtoCUDATransforms
+    ConvertKernelFuncToCubin.cpp
+    ConvertLaunchFuncToCudaCalls.cpp
+    GenerateCubinAccessors.cpp
+  )
+  target_link_libraries(MLIRGPUtoCUDATransforms
+    MLIRGPU
+    MLIRLLVMIR
+    MLIRNVVMIR
+    MLIRPass
+    MLIRTargetNVVMIR
+    ${nvptx}
+  )
+endif()
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
new file mode 100644
index 0000000..7663775
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -0,0 +1,173 @@
+//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert gpu kernel functions into a
+// corresponding binary blob that can be executed on a CUDA GPU. Currently
+// only translates the function itself but no dependencies.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/NVVMIR.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace mlir;
+
+namespace {
+// TODO(herhut): Move to shared location.
+static constexpr const char *kCubinAnnotation = "nvvm.cubin";
+
+/// A pass converting tagged kernel functions to cubin blobs.
+class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> {
+public:
+  GpuKernelToCubinPass(
+      CubinGenerator cubinGenerator = compilePtxToCubinForTesting)
+      : cubinGenerator(cubinGenerator) {}
+
+  // Run the dialect converter on the module.
+  void runOnModule() override {
+    // Make sure the NVPTX target is initialized.
+    LLVMInitializeNVPTXTarget();
+    LLVMInitializeNVPTXTargetInfo();
+    LLVMInitializeNVPTXTargetMC();
+    LLVMInitializeNVPTXAsmPrinter();
+
+    for (auto function : getModule().getOps<FuncOp>()) {
+      if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
+        continue;
+      }
+      if (failed(translateGpuKernelToCubinAnnotation(function)))
+        signalPassFailure();
+    }
+  }
+
+private:
+  static OwnedCubin compilePtxToCubinForTesting(const std::string &ptx,
+                                                FuncOp &function);
+
+  std::string translateModuleToPtx(llvm::Module &module,
+                                   llvm::TargetMachine &target_machine);
+  OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, FuncOp &function);
+  LogicalResult translateGpuKernelToCubinAnnotation(FuncOp &function);
+
+  CubinGenerator cubinGenerator;
+};
+
+} // anonymous namespace
+
+std::string GpuKernelToCubinPass::translateModuleToPtx(
+    llvm::Module &module, llvm::TargetMachine &target_machine) {
+  std::string ptx;
+  {
+    llvm::raw_string_ostream stream(ptx);
+    llvm::buffer_ostream pstream(stream);
+    llvm::legacy::PassManager codegen_passes;
+    target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr,
+                                       llvm::TargetMachine::CGFT_AssemblyFile);
+    codegen_passes.run(module);
+  }
+
+  return ptx;
+}
+
+OwnedCubin
+GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx,
+                                                  FuncOp &function) {
+  const char data[] = "CUBIN";
+  return llvm::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
+}
+
+OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
+                                                      FuncOp &function) {
+  std::unique_ptr<llvm::TargetMachine> targetMachine;
+  {
+    std::string error;
+    // TODO(herhut): Make triple configurable.
+    constexpr const char *cudaTriple = "nvptx64-nvidia-cuda";
+    llvm::Triple triple(cudaTriple);
+    const llvm::Target *target =
+        llvm::TargetRegistry::lookupTarget("", triple, error);
+    if (target == nullptr) {
+      function.emitError("Cannot initialize target triple");
+      return {};
+    }
+    targetMachine.reset(
+        target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {}));
+  }
+
+  // Set the data layout of the llvm module to match what the ptx target needs.
+  llvmModule.setDataLayout(targetMachine->createDataLayout());
+
+  auto ptx = translateModuleToPtx(llvmModule, *targetMachine);
+
+  return cubinGenerator(ptx, function);
+}
+
+LogicalResult
+GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) {
+  Builder builder(function.getContext());
+
+  OwningModuleRef module = ModuleOp::create(function.getLoc());
+
+  // TODO(herhut): Also handle called functions.
+  module->push_back(function.clone());
+
+  auto llvmModule = translateModuleToNVVMIR(*module);
+  auto cubin = convertModuleToCubin(*llvmModule, function);
+
+  if (!cubin) {
+    return function.emitError("Translation to CUDA binary failed.");
+  }
+
+  function.setAttr(kCubinAnnotation,
+                   builder.getStringAttr({cubin->data(), cubin->size()}));
+
+  // Remove the body of the kernel function now that it has been translated.
+  // The main reason to do this is so that the resulting module no longer
+  // contains the NVVM instructions (typically contained in the kernel bodies)
+  // and hence can be compiled into host code by a separate pass.
+  function.eraseBody();
+
+  return success();
+}
+
+ModulePassBase *
+mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) {
+  return new GpuKernelToCubinPass(cubinGenerator);
+}
+
+static PassRegistration<GpuKernelToCubinPass>
+    pass("test-kernel-to-cubin",
+         "Convert all kernel functions to CUDA cubin blobs");
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
new file mode 100644
index 0000000..bf75778
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -0,0 +1,391 @@
+//===- ConvertLaunchFuncToCudaCalls.cpp - MLIR CUDA lowering passes -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert gpu.launch_func op into a sequence of
+// CUDA runtime calls. As the CUDA runtime does not have a stable published ABI,
+// this pass uses a slim runtime layer that builds on top of the public API from
+// the CUDA headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Error.h"
+
+using namespace mlir;
+
+// To avoid name mangling, these are defined in the mini-runtime file.
+static constexpr const char *cuModuleLoadName = "mcuModuleLoad";
+static constexpr const char *cuModuleGetFunctionName = "mcuModuleGetFunction";
+static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel";
+static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper";
+static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize";
+
+static constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
+
+namespace {
+
+/// A pass to convert gpu.launch_func operations into a sequence of CUDA
+/// runtime calls.
+///
+/// In essence, a gpu.launch_func operations gets compiled into the following
+/// sequence of runtime calls:
+///
+/// * mcuModuleLoad        -- loads the module given the cubin data
+/// * mcuModuleGetFunction -- gets a handle to the actual kernel function
+/// * mcuGetStreamHelper   -- initializes a new CUDA stream
+/// * mcuLaunchKernelName  -- launches the kernel on a stream
+/// * mcuStreamSynchronize -- waits for operations on the stream to finish
+///
+/// Intermediate data structures are allocated on the stack.
+class GpuLaunchFuncToCudaCallsPass
+    : public ModulePass<GpuLaunchFuncToCudaCallsPass> {
+private:
+  LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
+
+  llvm::LLVMContext &getLLVMContext() {
+    return getLLVMDialect()->getLLVMContext();
+  }
+
+  void initializeCachedTypes() {
+    const llvm::Module &module = llvmDialect->getLLVMModule();
+    llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    llvmPointerPointerType = llvmPointerType.getPointerTo();
+    llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
+    llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
+    llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    llvmIntPtrType = LLVM::LLVMType::getIntNTy(
+        llvmDialect, module.getDataLayout().getPointerSizeInBits());
+  }
+
+  LLVM::LLVMType getPointerType() { return llvmPointerType; }
+
+  LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
+
+  LLVM::LLVMType getInt8Type() { return llvmInt8Type; }
+
+  LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
+
+  LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+
+  LLVM::LLVMType getIntPtrType() {
+    const llvm::Module &module = getLLVMDialect()->getLLVMModule();
+    return LLVM::LLVMType::getIntNTy(
+        getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+  }
+
+  LLVM::LLVMType getCUResultType() {
+    // This is declared as an enum in CUDA but helpers use i32.
+    return getInt32Type();
+  }
+
+  // Allocate a void pointer on the stack.
+  Value *allocatePointer(OpBuilder &builder, Location loc) {
+    auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+                                                builder.getI32IntegerAttr(1));
+    return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one);
+  }
+
+  void declareCudaFunctions(Location loc);
+  Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
+  Value *generateKernelNameConstant(FuncOp kernelFunction, Location &loc,
+                                    OpBuilder &builder);
+  void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
+
+public:
+  // Run the dialect converter on the module.
+  void runOnModule() override {
+    // Cache the LLVMDialect for the current module.
+    llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+    // Cache the used LLVM types.
+    initializeCachedTypes();
+
+    for (auto func : getModule().getOps<FuncOp>()) {
+      func.walk<mlir::gpu::LaunchFuncOp>(
+          [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
+    }
+  }
+
+private:
+  LLVM::LLVMDialect *llvmDialect;
+  LLVM::LLVMType llvmPointerType;
+  LLVM::LLVMType llvmPointerPointerType;
+  LLVM::LLVMType llvmInt8Type;
+  LLVM::LLVMType llvmInt32Type;
+  LLVM::LLVMType llvmInt64Type;
+  LLVM::LLVMType llvmIntPtrType;
+};
+
+} // anonymous namespace
+
+// Adds declarations for the needed helper functions from the CUDA wrapper.
+// The types in comments give the actual types expected/returned but the API
+// uses void pointers. This is fine as they have the same linkage in C.
+void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
+  ModuleOp module = getModule();
+  Builder builder(module);
+  if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
+    module.push_back(
+        FuncOp::create(loc, cuModuleLoadName,
+                       builder.getFunctionType(
+                           {
+                               getPointerPointerType(), /* CUmodule *module */
+                               getPointerType()         /* void *cubin */
+                           },
+                           getCUResultType())));
+  }
+  if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
+    // The helper uses void* instead of CUDA's opaque CUmodule and
+    // CUfunction.
+    module.push_back(
+        FuncOp::create(loc, cuModuleGetFunctionName,
+                       builder.getFunctionType(
+                           {
+                               getPointerPointerType(), /* void **function */
+                               getPointerType(),        /* void *module */
+                               getPointerType()         /* char *name */
+                           },
+                           getCUResultType())));
+  }
+  if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
+    // Other than the CUDA api, the wrappers use uintptr_t to match the
+    // LLVM type if MLIR's index type, which the GPU dialect uses.
+    // Furthermore, they use void* instead of CUDA's opaque CUfunction and
+    // CUstream.
+    module.push_back(FuncOp::create(
+        loc, cuLaunchKernelName,
+        builder.getFunctionType(
+            {
+                getPointerType(),        /* void* f */
+                getIntPtrType(),         /* intptr_t gridXDim */
+                getIntPtrType(),         /* intptr_t gridyDim */
+                getIntPtrType(),         /* intptr_t gridZDim */
+                getIntPtrType(),         /* intptr_t blockXDim */
+                getIntPtrType(),         /* intptr_t blockYDim */
+                getIntPtrType(),         /* intptr_t blockZDim */
+                getInt32Type(),          /* unsigned int sharedMemBytes */
+                getPointerType(),        /* void *hstream */
+                getPointerPointerType(), /* void **kernelParams */
+                getPointerPointerType()  /* void **extra */
+            },
+            getCUResultType())));
+  }
+  if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
+    // Helper function to get the current CUDA stream. Uses void* instead of
+    // CUDAs opaque CUstream.
+    module.push_back(FuncOp::create(
+        loc, cuGetStreamHelperName,
+        builder.getFunctionType({}, getPointerType() /* void *stream */)));
+  }
+  if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
+    module.push_back(
+        FuncOp::create(loc, cuStreamSynchronizeName,
+                       builder.getFunctionType(
+                           {
+                               getPointerType() /* CUstream stream */
+                           },
+                           getCUResultType())));
+  }
+}
+
+// Generates a parameters array to be used with a CUDA kernel launch call. The
+// arguments are extracted from the launchOp.
+// The generated code is essentially as follows:
+//
+// %array = alloca(numparams * sizeof(void *))
+// for (i : [0, NumKernelOperands))
+//   %array[i] = cast<void*>(KernelOperand[i])
+// return %array
+Value *
+GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
+                                               OpBuilder &builder) {
+  Location loc = launchOp.getLoc();
+  auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+                                              builder.getI32IntegerAttr(1));
+  auto arraySize = builder.create<LLVM::ConstantOp>(
+      loc, getInt32Type(),
+      builder.getI32IntegerAttr(launchOp.getNumKernelOperands()));
+  auto array =
+      builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), arraySize);
+  for (int idx = 0, e = launchOp.getNumKernelOperands(); idx < e; ++idx) {
+    auto operand = launchOp.getKernelOperand(idx);
+    auto llvmType = operand->getType().cast<LLVM::LLVMType>();
+    auto memLocation =
+        builder.create<LLVM::AllocaOp>(loc, llvmType.getPointerTo(), one);
+    builder.create<LLVM::StoreOp>(loc, operand, memLocation);
+    auto casted =
+        builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+    auto index = builder.create<LLVM::ConstantOp>(
+        loc, getInt32Type(), builder.getI32IntegerAttr(idx));
+    auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
+                                           ArrayRef<Value *>{index});
+    builder.create<LLVM::StoreOp>(loc, casted, gep);
+  }
+  return array;
+}
+
+// Generates LLVM IR that produces a value representing the name of the
+// given kernel function. The generated IR consists essentially of the
+// following:
+//
+// %0 = alloca(strlen(name) + 1)
+// %0[0] = constant name[0]
+// ...
+// %0[n] = constant name[n]
+// %0[n+1] = 0
+Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
+    FuncOp kernelFunction, Location &loc, OpBuilder &builder) {
+  // TODO(herhut): Make this a constant once this is supported.
+  auto kernelNameSize = builder.create<LLVM::ConstantOp>(
+      loc, getInt32Type(),
+      builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
+  auto kernelName =
+      builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
+  for (auto byte : llvm::enumerate(kernelFunction.getName())) {
+    auto index = builder.create<LLVM::ConstantOp>(
+        loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
+    auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
+                                           ArrayRef<Value *>{index});
+    auto value = builder.create<LLVM::ConstantOp>(
+        loc, getInt8Type(),
+        builder.getIntegerAttr(builder.getIntegerType(8), byte.value()));
+    builder.create<LLVM::StoreOp>(loc, value, gep);
+  }
+  // Add trailing zero to terminate string.
+  auto index = builder.create<LLVM::ConstantOp>(
+      loc, getInt32Type(),
+      builder.getI32IntegerAttr(kernelFunction.getName().size()));
+  auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
+                                         ArrayRef<Value *>{index});
+  auto value = builder.create<LLVM::ConstantOp>(
+      loc, getInt8Type(), builder.getIntegerAttr(builder.getIntegerType(8), 0));
+  builder.create<LLVM::StoreOp>(loc, value, gep);
+  return kernelName;
+}
+
+// Emits LLVM IR to launch a kernel function. Expects the module that contains
+// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute of the
+// kernel function in the IR.
+// While MLIR has no global constants, also expects a cubin getter function in
+// an 'nvvm.cubingetter' attribute. Such function is expected to return a
+// pointer to the cubin blob when invoked.
+// With these given, the generated code in essence is
+//
+// %0 = call %cubingetter
+// %1 = alloca sizeof(void*)
+// call %mcuModuleLoad(%2, %1)
+// %2 = alloca sizeof(void*)
+// %3 = load %1
+// %4 = <see generateKernelNameConstant>
+// call %mcuModuleGetFunction(%2, %3, %4)
+// %5 = call %mcuGetStreamHelper()
+// %6 = load %2
+// %7 = <see setupParamsArray>
+// call %mcuLaunchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr)
+// call %mcuStreamSynchronize(%5)
+void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
+    mlir::gpu::LaunchFuncOp launchOp) {
+  OpBuilder builder(launchOp);
+  Location loc = launchOp.getLoc();
+  declareCudaFunctions(loc);
+
+  auto zero = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+                                               builder.getI32IntegerAttr(0));
+  // Emit a call to the cubin getter to retrieve a pointer to the data that
+  // represents the cubin at runtime.
+  // TODO(herhut): This should rather be a static global once supported.
+  auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
+  auto cubinGetter =
+      kernelFunction.getAttrOfType<SymbolRefAttr>(kCubinGetterAnnotation);
+  if (!cubinGetter) {
+    kernelFunction.emitError("Missing ")
+        << kCubinGetterAnnotation << " attribute.";
+    return signalPassFailure();
+  }
+  auto data = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()}, cubinGetter, ArrayRef<Value *>{});
+  // Emit the load module call to load the module data. Error checking is done
+  // in the called helper function.
+  auto cuModule = allocatePointer(builder, loc);
+  FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
+  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
+                               builder.getSymbolRefAttr(cuModuleLoad),
+                               ArrayRef<Value *>{cuModule, data.getResult(0)});
+  // Get the function from the module. The name corresponds to the name of
+  // the kernel function.
+  auto cuOwningModuleRef =
+      builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
+  auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
+  auto cuFunction = allocatePointer(builder, loc);
+  FuncOp cuModuleGetFunction =
+      getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
+  builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getCUResultType()},
+      builder.getSymbolRefAttr(cuModuleGetFunction),
+      ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
+  // Grab the global stream needed for execution.
+  FuncOp cuGetStreamHelper =
+      getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
+  auto cuStream = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
+      builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
+  // Invoke the function with required arguments.
+  auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
+  auto cuFunctionRef =
+      builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
+  auto paramsArray = setupParamsArray(launchOp, builder);
+  auto nullpointer =
+      builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
+  builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getCUResultType()},
+      builder.getSymbolRefAttr(cuLaunchKernel),
+      ArrayRef<Value *>{cuFunctionRef, launchOp.getOperand(0),
+                        launchOp.getOperand(1), launchOp.getOperand(2),
+                        launchOp.getOperand(3), launchOp.getOperand(4),
+                        launchOp.getOperand(5), zero, /* sharedMemBytes */
+                        cuStream.getResult(0),        /* stream */
+                        paramsArray,                  /* kernel params */
+                        nullpointer /* extra */});
+  // Sync on the stream to make it synchronous.
+  auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
+  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
+                               builder.getSymbolRefAttr(cuStreamSync),
+                               ArrayRef<Value *>(cuStream.getResult(0)));
+  launchOp.erase();
+}
+
+mlir::ModulePassBase *mlir::createConvertGpuLaunchFuncToCudaCallsPass() {
+  return new GpuLaunchFuncToCudaCallsPass();
+}
+
+static PassRegistration<GpuLaunchFuncToCudaCallsPass>
+    pass("launch-func-to-cuda",
+         "Convert all launch_func ops to CUDA runtime calls");
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
new file mode 100644
index 0000000..813a3be
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
@@ -0,0 +1,152 @@
+//===- GenerateCubinAccessors.cpp - MLIR GPU lowering passes --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to generate LLVMIR functions that return the
+// data stored in nvvm.cubin char* blob.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+namespace {
+
+// TODO(herhut): Move to shared location.
+constexpr const char *kCubinAnnotation = "nvvm.cubin";
+constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
+constexpr const char *kCubinGetterSuffix = "_cubin";
+constexpr const char *kMallocHelperName = "malloc";
+
+/// A pass generating getter functions for all cubin blobs annotated on
+/// functions via the nvvm.cubin attribute.
+///
+/// The functions allocate memory using the system malloc call with signature
+/// void *malloc(size_t size). This function has to be provided by the actual
+/// runner that executes the generated code.
+///
+/// This is a stop-gap measure until MLIR supports global constants.
+class GpuGenerateCubinAccessorsPass
+    : public ModulePass<GpuGenerateCubinAccessorsPass> {
+private:
+  LLVM::LLVMType getIndexType() {
+    unsigned bits =
+        llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+    return LLVM::LLVMType::getIntNTy(llvmDialect, bits);
+  }
+
+  FuncOp getMallocHelper(Location loc, Builder &builder) {
+    FuncOp result = getModule().lookupSymbol<FuncOp>(kMallocHelperName);
+    if (!result) {
+      result = FuncOp::create(
+          loc, kMallocHelperName,
+          builder.getFunctionType(ArrayRef<Type>{getIndexType()},
+                                  LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
+      getModule().push_back(result);
+    }
+    return result;
+  }
+
+  // Generates a function that returns a char array at runtime that contains the
+  // data from blob. As there are currently no global constants, this uses a
+  // sequence of store operations.
+  // TODO(herhut): Use global constants instead.
+  FuncOp generateCubinAccessor(Builder &builder, FuncOp &orig,
+                               StringAttr blob) {
+    Location loc = orig.getLoc();
+    SmallString<128> nameBuffer(orig.getName());
+    nameBuffer.append(kCubinGetterSuffix);
+    // Generate a function that returns void*.
+    FuncOp result = FuncOp::create(
+        loc, mlir::Identifier::get(nameBuffer, &getContext()),
+        builder.getFunctionType(ArrayRef<Type>{},
+                                LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
+    // Insert a body block that just returns the constant.
+    OpBuilder ob(result.getBody());
+    ob.createBlock(&result.getBody());
+    auto sizeConstant = ob.create<LLVM::ConstantOp>(
+        loc, getIndexType(),
+        builder.getIntegerAttr(builder.getIndexType(), blob.getValue().size()));
+    auto memory =
+        ob.create<LLVM::CallOp>(
+              loc, ArrayRef<Type>{LLVM::LLVMType::getInt8PtrTy(llvmDialect)},
+              builder.getSymbolRefAttr(getMallocHelper(loc, builder)),
+              ArrayRef<Value *>{sizeConstant})
+            .getResult(0);
+    for (auto byte : llvm::enumerate(blob.getValue().bytes())) {
+      auto index = ob.create<LLVM::ConstantOp>(
+          loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
+          builder.getI32IntegerAttr(byte.index()));
+      auto gep =
+          ob.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect),
+                                 memory, ArrayRef<Value *>{index});
+      auto value = ob.create<LLVM::ConstantOp>(
+          loc, LLVM::LLVMType::getInt8Ty(llvmDialect),
+          builder.getIntegerAttr(builder.getIntegerType(8), byte.value()));
+      ob.create<LLVM::StoreOp>(loc, value, gep);
+    }
+    ob.create<LLVM::ReturnOp>(loc, ArrayRef<Value *>{memory});
+    // Store the name of the getter on the function for easier lookup.
+    orig.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result));
+    return result;
+  }
+
+public:
+  // Run the dialect converter on the module.
+  void runOnModule() override {
+    llvmDialect =
+        getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    auto module = getModule();
+    Builder builder(&getContext());
+
+    auto functions = module.getOps<FuncOp>();
+    for (auto it = functions.begin(); it != functions.end();) {
+      // Move iterator to after the current function so that potential insertion
+      // of the accessor is after the kernel with cubin iself.
+      FuncOp orig = *it++;
+      StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation);
+      if (!cubinBlob)
+        continue;
+      module.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
+    }
+  }
+
+private:
+  LLVM::LLVMDialect *llvmDialect;
+};
+
+} // anonymous namespace
+
+ModulePassBase *createGenerateCubinAccessorPass() {
+  return new GpuGenerateCubinAccessorsPass();
+}
+
+static PassRegistration<GpuGenerateCubinAccessorsPass>
+    pass("generate-cubin-accessors",
+         "Generate LLVMIR functions that give access to cubin data");
+
+} // namespace mlir
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/third_party/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
new file mode 100644
index 0000000..492f3a1
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRGPUtoNVVMTransforms
+  LowerGpuOpsToNVVMOps.cpp
+  )
+target_link_libraries(MLIRGPUtoNVVMTransforms
+  LLVMSupport
+  MLIRGPU
+  MLIRLLVMIR
+  MLIRNVVMIR
+  MLIRPass
+  )
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
new file mode 100644
index 0000000..e4a6f96
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -0,0 +1,139 @@
+//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to generate NVVMIR operations for higher-level
+// GPU operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/NVVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+#include "llvm/ADT/StringSwitch.h"
+
+namespace mlir {
+namespace {
+
+// A pass that replaces all occurences of GPU operations with their
+// corresponding NVVM equivalent.
+//
+// This pass does not handle launching of kernels. Instead, it is meant to be
+// used on the body region of a launch or the body region of a kernel
+// function.
+class LowerGpuOpsToNVVMOpsPass : public FunctionPass<LowerGpuOpsToNVVMOpsPass> {
+private:
+  enum dimension { X = 0, Y = 1, Z = 2, invalid };
+
+  template <typename T> dimension dimensionToIndex(T op) {
+    return llvm::StringSwitch<dimension>(op.dimension())
+        .Case("x", X)
+        .Case("y", Y)
+        .Case("z", Z)
+        .Default(invalid);
+  }
+
+  // Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
+  // that Op operates on.  Op is assumed to return an `std.index` value and
+  // XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
+  // `indexBitwidth`, sign-extend or truncate the resulting value to match the
+  // bitwidth expected by the consumers of the value.
+  template <typename XOp, typename YOp, typename ZOp, class Op>
+  void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
+                            unsigned indexBitwidth) {
+    assert(operation.getType().isIndex() &&
+           "expected an operation returning index");
+    OpBuilder builder(operation);
+    auto loc = operation.getLoc();
+    Value *newOp;
+    switch (dimensionToIndex(operation)) {
+    case X:
+      newOp = builder.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    case Y:
+      newOp = builder.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    case Z:
+      newOp = builder.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    default:
+      operation.emitError("Illegal dimension: " + operation.dimension());
+      signalPassFailure();
+      return;
+    }
+
+    if (indexBitwidth > 32) {
+      newOp = builder.create<LLVM::SExtOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = builder.create<LLVM::TruncOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    }
+    operation.replaceAllUsesWith(newOp);
+    operation.erase();
+  }
+
+public:
+  void runOnFunction() {
+    LLVM::LLVMDialect *llvmDialect =
+        getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+    unsigned indexBitwidth =
+        llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+    getFunction().walk([&](Operation *opInst) {
+      if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
+        replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
+                             NVVM::ThreadIdZOp>(threadId, llvmDialect,
+                                                indexBitwidth);
+        return;
+      }
+      if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
+        replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
+                             NVVM::BlockDimZOp>(blockDim, llvmDialect,
+                                                indexBitwidth);
+        return;
+      }
+      if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
+        replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
+                             NVVM::BlockIdZOp>(blockId, llvmDialect,
+                                               indexBitwidth);
+        return;
+      }
+      if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
+        replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
+                             NVVM::GridDimZOp>(gridDim, llvmDialect,
+                                               indexBitwidth);
+        return;
+      }
+    });
+  }
+};
+
+} // anonymous namespace
+
+FunctionPassBase *createLowerGpuOpsToNVVMOpsPass() {
+  return new LowerGpuOpsToNVVMOpsPass();
+}
+
+static PassRegistration<LowerGpuOpsToNVVMOpsPass>
+    pass("lower-gpu-ops-to-nvvm-ops",
+         "Generate NVVM operations for gpu operations");
+
+} // namespace mlir
diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/third_party/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000..8426420
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_llvm_library(MLIRGPUtoSPIRVTransforms
+  GPUToSPIRV.cpp
+  )
+
+target_link_libraries(MLIRGPUtoSPIRVTransforms
+  MLIRGPU
+  MLIRIR
+  MLIRPass
+  MLIRSPIRV
+  MLIRStandardOps
+  MLIRSPIRVConversion
+  MLIRTransforms
+  )
diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
new file mode 100644
index 0000000..c36aee5
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -0,0 +1,124 @@
+//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
+/// attribute gpu.kernel) within a spv.module.
+class KernelFnConversion final : public SPIRVFnLowering {
+public:
+  using SPIRVFnLowering::SPIRVFnLowering;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+PatternMatchResult
+KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  auto funcOp = cast<FuncOp>(op);
+  FuncOp newFuncOp;
+  if (!gpu::GPUDialect::isKernel(funcOp)) {
+    return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
+               ? matchSuccess()
+               : matchFailure();
+  }
+
+  if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
+    return matchFailure();
+  }
+  newFuncOp.getOperation()->removeAttr(Identifier::get(
+      gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
+  return matchSuccess();
+}
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
+/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
+/// replace it).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+  auto context = &getContext();
+  auto module = getModule();
+
+  SmallVector<Operation *, 4> spirvModules;
+  for (auto funcOp : module.getOps<FuncOp>()) {
+    if (gpu::GPUDialect::isKernel(funcOp)) {
+      OpBuilder builder(module.getBodyRegion());
+      // Create a new spirv::ModuleOp for this function, and clone the
+      // function into it.
+      // TODO : Generalize this to account for different extensions,
+      // capabilities, extended_instruction_sets, other addressing models
+      // and memory models.
+      auto spvModule = builder.create<spirv::ModuleOp>(
+          funcOp.getLoc(),
+          builder.getI32IntegerAttr(
+              static_cast<int32_t>(spirv::AddressingModel::Logical)),
+          builder.getI32IntegerAttr(
+              static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
+      OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
+      moduleBuilder.clone(*funcOp.getOperation());
+      spirvModules.push_back(spvModule);
+    }
+  }
+
+  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
+  SPIRVTypeConverter typeConverter(context);
+  SPIRVEntryFnTypeConverter entryFnConverter(context);
+  OwningRewritePatternList patterns;
+  patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
+  populateStandardToSPIRVPatterns(context, patterns);
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
+
+  if (failed(applyFullConversion(spirvModules, target, patterns,
+                                 &typeConverter))) {
+    return signalPassFailure();
+  }
+}
+
+ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
+
+static PassRegistration<GPUToSPIRVPass>
+    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
diff --git a/third_party/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt b/third_party/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt
new file mode 100644
index 0000000..2dacc80
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LIBS
+  MLIRAffineOps
+  MLIRGPU
+  MLIRIR
+  MLIRLinalg
+  MLIRPass
+  MLIRStandardOps
+  MLIRSupport
+  MLIRTransforms
+  LLVMSupport
+)
+
+add_llvm_library(MLIRLoopsToGPU
+  LoopsToGPU.cpp
+  LoopsToGPUPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopsToGPU
+)
+add_dependencies(MLIRLoopsToGPU ${LIBS})
+target_link_libraries(MLIRLoopsToGPU ${LIBS})
diff --git a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
new file mode 100644
index 0000000..6ca4cb3
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -0,0 +1,337 @@
+//===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This implements a straightforward conversion of an loop nest into a GPU
+// kernel.  The caller is expected to guarantee that the conversion is correct
+// or to further transform the kernel to ensure correctness.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "loops-to-gpu"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+// Extract an indexed value from KernelDim3.
+static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
+  switch (pos) {
+  case 0:
+    return dim3.x;
+  case 1:
+    return dim3.y;
+  case 2:
+    return dim3.z;
+  default:
+    llvm_unreachable("dim3 position out of bounds");
+  }
+  return nullptr;
+}
+
+// Get the lower bound-related operands of a loop operation.
+static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) {
+  return forOp.getLowerBoundOperands();
+}
+static SmallVector<Value *, 1> getLowerBoundOperands(ForOp forOp) {
+  SmallVector<Value *, 1> bounds(1, forOp.lowerBound());
+  return bounds;
+}
+
+// Get the upper bound-related operands of a loop operation.
+static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) {
+  return forOp.getUpperBoundOperands();
+}
+static SmallVector<Value *, 1> getUpperBoundOperands(ForOp forOp) {
+  SmallVector<Value *, 1> bounds(1, forOp.upperBound());
+  return bounds;
+}
+
+// Get a Value that corresponds to the loop step.  If the step is an attribute,
+// materialize a corresponding constant using builder.
+static Value *getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
+  return builder.create<ConstantIndexOp>(forOp.getLoc(), forOp.getStep());
+}
+static Value *getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); }
+
+// Get a Value for the loop lower bound.  If the value requires computation,
+// materialize the instructions using builder.
+static Value *getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) {
+  return lowerAffineLowerBound(forOp, builder);
+}
+static Value *getOrEmitLowerBound(ForOp forOp, OpBuilder &) {
+  return forOp.lowerBound();
+}
+
+// Get a Value for the loop upper bound.  If the value requires computation,
+// materialize the instructions using builder.
+static Value *getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) {
+  return lowerAffineUpperBound(forOp, builder);
+}
+static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
+  return forOp.upperBound();
+}
+
+// Check the structure of the loop nest:
+//   - there are enough loops to map to numBlockDims + numThreadDims;
+//   - the loops are perfectly nested;
+//   - the loop bounds can be computed above the outermost loop.
+// This roughly corresponds to the "matcher" part of the pattern-based
+// rewriting infrastructure.
+template <typename OpTy>
+LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
+                                    unsigned numThreadDims) {
+  if (numBlockDims < 1 || numThreadDims < 1) {
+    LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+    return success();
+  }
+
+  OpBuilder builder(forOp.getOperation());
+  if (numBlockDims > 3) {
+    return emitError(builder.getUnknownLoc(),
+                     "cannot map to more than 3 block dimensions");
+  }
+  if (numThreadDims > 3) {
+    return emitError(builder.getUnknownLoc(),
+                     "cannot map to more than 3 thread dimensions");
+  }
+
+  OpTy currentLoop = forOp;
+  Region &limit = forOp.region();
+  for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) {
+    Operation *nested = &currentLoop.getBody()->front();
+    if (!areValuesDefinedAbove(getLowerBoundOperands(currentLoop), limit) ||
+        !areValuesDefinedAbove(getUpperBoundOperands(currentLoop), limit))
+      return currentLoop.emitError(
+          "loops with bounds depending on other mapped loops "
+          "are not supported");
+
+    // The innermost loop can have an arbitrary body, skip the perfect nesting
+    // check for it.
+    if (i == e - 1)
+      break;
+
+    auto begin = currentLoop.getBody()->begin(),
+         end = currentLoop.getBody()->end();
+    if (currentLoop.getBody()->empty() || std::next(begin, 2) != end)
+      return currentLoop.emitError(
+          "expected perfectly nested loops in the body");
+
+    if (!(currentLoop = dyn_cast<OpTy>(nested)))
+      return nested->emitError("expected a nested loop");
+  }
+
+  return success();
+}
+
+namespace {
+// Helper structure that holds common state of the loop to GPU kernel
+// conversion.
+struct LoopToGpuConverter {
+  template <typename OpTy>
+  Optional<OpTy> collectBounds(OpTy forOp, unsigned numLoops);
+
+  template <typename OpTy>
+  void createLaunch(OpTy rootForOp, OpTy innermostForOp, unsigned numBlockDims,
+                    unsigned numThreadDims);
+
+  // Ranges of the loops mapped to blocks or threads.
+  SmallVector<Value *, 6> dims;
+  // Lower bounds of the loops mapped to blocks or threads.
+  SmallVector<Value *, 6> lbs;
+  // Induction variables of the loops mapped to blocks or threads.
+  SmallVector<Value *, 6> ivs;
+  // Steps of the loops mapped to blocks or threads.
+  SmallVector<Value *, 6> steps;
+};
+} // namespace
+
+// Return true if the value is obviously a constant "one".
+static bool isConstantOne(Value *value) {
+  if (auto def = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
+    return def.getValue() == 1;
+  return false;
+}
+
+// Collect ranges, bounds, steps and induction variables in preparation for
+// mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel.
+// This may fail if the IR for computing loop bounds cannot be constructed, for
+// example if an affine loop uses semi-affine maps. Return the last loop to be
+// mapped on success, llvm::None on failure.
+template <typename OpTy>
+Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp,
+                                                 unsigned numLoops) {
+  OpBuilder builder(forOp.getOperation());
+  dims.reserve(numLoops);
+  lbs.reserve(numLoops);
+  ivs.reserve(numLoops);
+  steps.reserve(numLoops);
+  OpTy currentLoop = forOp;
+  for (unsigned i = 0; i < numLoops; ++i) {
+    Value *lowerBound = getOrEmitLowerBound(currentLoop, builder);
+    Value *upperBound = getOrEmitUpperBound(currentLoop, builder);
+    if (!lowerBound || !upperBound) {
+      return llvm::None;
+    }
+
+    Value *range =
+        builder.create<SubIOp>(currentLoop.getLoc(), upperBound, lowerBound);
+    Value *step = getOrCreateStep(currentLoop, builder);
+    if (!isConstantOne(step))
+      range = builder.create<DivISOp>(currentLoop.getLoc(), range, step);
+    dims.push_back(range);
+
+    lbs.push_back(lowerBound);
+    ivs.push_back(currentLoop.getInductionVar());
+    steps.push_back(step);
+
+    if (i != numLoops - 1)
+      currentLoop = cast<OpTy>(&currentLoop.getBody()->front());
+  }
+  return currentLoop;
+}
+
+// Replace the rooted at "rootForOp" with a GPU launch operation.  This expects
+// "innermostForOp" to point to the last loop to be transformed to the kernel,
+// and to have (numBlockDims + numThreadDims) perfectly nested loops between
+// "rootForOp" and "innermostForOp".
+template <typename OpTy>
+void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
+                                      unsigned numBlockDims,
+                                      unsigned numThreadDims) {
+  OpBuilder builder(rootForOp.getOperation());
+  // Prepare the grid and block sizes for the launch operation.  If there is
+  // no loop mapped to a specific dimension, use constant "1" as its size.
+  Value *constOne = (numBlockDims < 3 || numThreadDims < 3)
+                        ? builder.create<ConstantIndexOp>(rootForOp.getLoc(), 1)
+                        : nullptr;
+  Value *gridSizeX = dims[0];
+  Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
+  Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne;
+  Value *blockSizeX = dims[numBlockDims];
+  Value *blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne;
+  Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
+
+  // Create a launch op and move the body region of the innermost loop to the
+  // launch op.  Pass the values defined outside the outermost loop and used
+  // inside the innermost loop and loop lower bounds as kernel data arguments.
+  // Still assuming perfect nesting so there are no values other than induction
+  // variables that are defined in one loop and used in deeper loops.
+  llvm::SetVector<Value *> valuesToForwardSet;
+  getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(),
+                            valuesToForwardSet);
+  auto valuesToForward = valuesToForwardSet.takeVector();
+  auto originallyForwardedValues = valuesToForward.size();
+  valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end());
+  valuesToForward.insert(valuesToForward.end(), steps.begin(), steps.end());
+  auto launchOp = builder.create<gpu::LaunchOp>(
+      rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX,
+      blockSizeY, blockSizeZ, valuesToForward);
+  valuesToForward.resize(originallyForwardedValues);
+
+  // Replace the loop terminator (loops contain only a single block) with the
+  // gpu return and move the operations from the loop body block to the gpu
+  // launch body block.  Do not move the entire block because of the difference
+  // in block arguments.
+  Operation &terminator = innermostForOp.getBody()->back();
+  Location terminatorLoc = terminator.getLoc();
+  terminator.erase();
+  builder.setInsertionPointToEnd(innermostForOp.getBody());
+  builder.create<gpu::Return>(terminatorLoc);
+  launchOp.getBody().front().getOperations().splice(
+      launchOp.getBody().front().begin(),
+      innermostForOp.getBody()->getOperations());
+
+  // Remap the loop iterators to use block/thread identifiers instead.  Loops
+  // may iterate from LB with step S whereas GPU thread/block ids always iterate
+  // from 0 to N with step 1.  Therefore, loop induction variables are replaced
+  // with (gpu-thread/block-id * S) + LB.
+  builder.setInsertionPointToStart(&launchOp.getBody().front());
+  auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(),
+                                originallyForwardedValues);
+  auto stepArgumentIt = std::next(lbArgumentIt, lbs.size());
+  for (auto en : llvm::enumerate(ivs)) {
+    Value *id =
+        en.index() < numBlockDims
+            ? getDim3Value(launchOp.getBlockIds(), en.index())
+            : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
+    Value *step = steps[en.index()];
+    if (!isConstantOne(step))
+      id = builder.create<MulIOp>(rootForOp.getLoc(), step, id);
+
+    Value *ivReplacement =
+        builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
+    en.value()->replaceAllUsesWith(ivReplacement);
+    replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt,
+                               launchOp.getBody());
+    std::advance(lbArgumentIt, 1);
+    std::advance(stepArgumentIt, 1);
+  }
+
+  // Remap the values defined outside the body to use kernel arguments instead.
+  // The list of kernel arguments also contains the lower bounds for loops at
+  // trailing positions, make sure we don't touch those.
+  for (const auto &pair :
+       llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
+    Value *from = std::get<0>(pair);
+    Value *to = std::get<1>(pair);
+    replaceAllUsesInRegionWith(from, to, launchOp.getBody());
+  }
+
+  // We are done and can erase the original outermost loop.
+  rootForOp.erase();
+}
+
+// Generic loop to GPU kernel conversion function.
+template <typename OpTy>
+static LogicalResult convertLoopNestToGPULaunch(OpTy forOp,
+                                                unsigned numBlockDims,
+                                                unsigned numThreadDims) {
+  if (failed(checkLoopNestMappable(forOp, numBlockDims, numThreadDims)))
+    return failure();
+
+  LoopToGpuConverter converter;
+  auto maybeInnerLoop =
+      converter.collectBounds(forOp, numBlockDims + numThreadDims);
+  if (!maybeInnerLoop)
+    return failure();
+  converter.createLaunch(forOp, *maybeInnerLoop, numBlockDims, numThreadDims);
+
+  return success();
+}
+
+LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp,
+                                                     unsigned numBlockDims,
+                                                     unsigned numThreadDims) {
+  return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims);
+}
+
+LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp,
+                                               unsigned numBlockDims,
+                                               unsigned numThreadDims) {
+  return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims);
+}
diff --git a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
new file mode 100644
index 0000000..7c785b5
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -0,0 +1,78 @@
+//===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/Support/CommandLine.h"
+
+#define PASS_NAME "convert-loops-to-gpu"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options");
+static llvm::cl::opt<unsigned>
+    clNumBlockDims("gpu-block-dims",
+                   llvm::cl::desc("Number of GPU block dimensions for mapping"),
+                   llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+static llvm::cl::opt<unsigned> clNumThreadDims(
+    "gpu-thread-dims",
+    llvm::cl::desc("Number of GPU thread dimensions for mapping"),
+    llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+
+namespace {
+// A pass that traverses top-level loops in the function and converts them to
+// GPU launch operations.  Nested launches are not allowed, so this does not
+// walk the function recursively to avoid considering nested loops.
+struct ForLoopMapper : public FunctionPass<ForLoopMapper> {
+  ForLoopMapper(unsigned numBlockDims, unsigned numThreadDims)
+      : numBlockDims(numBlockDims), numThreadDims(numThreadDims) {}
+
+  void runOnFunction() override {
+    for (Block &block : getFunction())
+      for (Operation &op : llvm::make_early_inc_range(block)) {
+        if (auto forOp = dyn_cast<AffineForOp>(&op)) {
+          if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
+                                                      numThreadDims)))
+            signalPassFailure();
+        } else if (auto forOp = dyn_cast<ForOp>(&op)) {
+          if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims,
+                                                numThreadDims)))
+            signalPassFailure();
+        }
+      }
+  }
+
+  unsigned numBlockDims;
+  unsigned numThreadDims;
+};
+} // namespace
+
+FunctionPassBase *mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims,
+                                                   unsigned numThreadDims) {
+  return new ForLoopMapper(numBlockDims, numThreadDims);
+}
+
+static PassRegistration<ForLoopMapper>
+    registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] {
+      return new ForLoopMapper(clNumBlockDims.getValue(),
+                               clNumThreadDims.getValue());
+    });
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/third_party/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
new file mode 100644
index 0000000..3f3a334
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_llvm_library(MLIRStandardToLLVM
+  ConvertStandardToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToLLVM
+)
+add_dependencies(
+  MLIRStandardToLLVM
+
+  MLIRControlFlowToCFG
+  MLIRLLVMIR
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+)
+target_link_libraries(
+  MLIRStandardToLLVM
+
+  MLIRControlFlowToCFG
+  MLIRLLVMIR
+  MLIRTransforms
+  LLVMCore
+  LLVMSupport
+)
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
new file mode 100644
index 0000000..5bb2811
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -0,0 +1,1144 @@
+//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+using namespace mlir;
+
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
+    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
+  assert(llvmDialect && "LLVM IR dialect is not registered");
+  module = &llvmDialect->getLLVMModule();
+}
+
+// Get the LLVM context.
+llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
+  return module->getContext();
+}
+
+// Extract an LLVM IR type from the LLVM IR dialect type.
+LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  auto *mlirContext = type.getContext();
+  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedLLVMType)
+    emitError(UnknownLoc::get(mlirContext),
+              "conversion resulted in a non-LLVM type");
+  return wrappedLLVMType;
+}
+
+LLVM::LLVMType LLVMTypeConverter::getIndexType() {
+  return LLVM::LLVMType::getIntNTy(
+      llvmDialect, module->getDataLayout().getPointerSizeInBits());
+}
+
+Type LLVMTypeConverter::convertIndexType(IndexType type) {
+  return getIndexType();
+}
+
+Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
+  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+}
+
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+  switch (type.getKind()) {
+  case mlir::StandardTypes::F32:
+    return LLVM::LLVMType::getFloatTy(llvmDialect);
+  case mlir::StandardTypes::F64:
+    return LLVM::LLVMType::getDoubleTy(llvmDialect);
+  case mlir::StandardTypes::F16:
+    return LLVM::LLVMType::getHalfTy(llvmDialect);
+  case mlir::StandardTypes::BF16: {
+    auto *mlirContext = llvmDialect->getContext();
+    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
+           Type();
+  }
+  default:
+    llvm_unreachable("non-float type in convertFloatType");
+  }
+}
+
+// Function types are converted to LLVM Function types by recursively converting
+// argument and result types.  If MLIR Function has zero results, the LLVM
+// Function has one VoidType result.  If MLIR Function has more than one result,
+// they are into an LLVM StructType in their order of appearance.
+Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
+  // Convert argument types one by one and check for errors.
+  SmallVector<LLVM::LLVMType, 8> argTypes;
+  for (auto t : type.getInputs()) {
+    auto converted = convertType(t);
+    if (!converted)
+      return {};
+    argTypes.push_back(unwrap(converted));
+  }
+
+  // If function does not return anything, create the void result type,
+  // if it returns on element, convert it, otherwise pack the result types into
+  // a struct.
+  LLVM::LLVMType resultType =
+      type.getNumResults() == 0
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
+      .getPointerTo();
+}
+
+// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
+// we return a pointer to the converted element type. Otherwise we return an
+// LLVM stucture type, where the first element of the structure type is a
+// pointer to the elemental type of the MemRef and the following N elements are
+// values of the Index type, one for each of N dynamic dimensions of the MemRef.
+Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  auto ptrType = elementType.getPointerTo();
+
+  // Extra value for the memory space.
+  unsigned numDynamicSizes = type.getNumDynamicDims();
+  // If memref is statically-shaped we return the underlying pointer type.
+  if (numDynamicSizes == 0)
+    return ptrType;
+
+  SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
+  types.front() = ptrType;
+
+  return LLVM::LLVMType::getStructTy(llvmDialect, types);
+}
+
+// Convert a 1D vector type to an LLVM vector type.
+Type LLVMTypeConverter::convertVectorType(VectorType type) {
+  if (type.getRank() != 1) {
+    auto *mlirContext = llvmDialect->getContext();
+    emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported");
+    return {};
+  }
+
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  return elementType
+             ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
+             : Type();
+}
+
+// Dispatch based on the actual type.  Return null type on error.
+Type LLVMTypeConverter::convertStandardType(Type type) {
+  if (auto funcType = type.dyn_cast<FunctionType>())
+    return convertFunctionType(funcType);
+  if (auto intType = type.dyn_cast<IntegerType>())
+    return convertIntegerType(intType);
+  if (auto floatType = type.dyn_cast<FloatType>())
+    return convertFloatType(floatType);
+  if (auto indexType = type.dyn_cast<IndexType>())
+    return convertIndexType(indexType);
+  if (auto memRefType = type.dyn_cast<MemRefType>())
+    return convertMemRefType(memRefType);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return convertVectorType(vectorType);
+  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
+    return llvmType;
+
+  return {};
+}
+
+// Convert the element type of the memref `t` to to an LLVM type using
+// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
+// into the MLIR LLVM dialect type and return.
+static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
+  auto elementType = t.getElementType();
+  auto converted = lowering.convertType(elementType);
+  if (!converted)
+    return {};
+  return converted.cast<LLVM::LLVMType>().getPointerTo();
+}
+
+LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+                               LLVMTypeConverter &lowering_)
+    : ConversionPattern(rootOpName, /*benefit=*/1, context),
+      lowering(lowering_) {}
+
+namespace {
+// Base class for Standard to LLVM IR op conversions.  Matches the Op type
+// provided as template argument.  Carries a reference to the LLVM dialect in
+// case it is necessary for rewriters.
+template <typename SourceOp>
+class LLVMLegalizationPattern : public LLVMOpLowering {
+public:
+  // Construct a conversion pattern.
+  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
+                                   LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
+                       lowering_),
+        dialect(dialect_) {}
+
+  // Get the LLVM IR dialect.
+  LLVM::LLVMDialect &getDialect() const { return dialect; }
+  // Get the LLVM context.
+  llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
+  // Get the LLVM module in which the types are constructed.
+  llvm::Module &getModule() const { return dialect.getLLVMModule(); }
+
+  // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
+  // by the pointer size used in the LLVM module.
+  LLVM::LLVMType getIndexType() const {
+    return LLVM::LLVMType::getIntNTy(
+        &dialect, getModule().getDataLayout().getPointerSizeInBits());
+  }
+
+  // Get the MLIR type wrapping the LLVM i8* type.
+  LLVM::LLVMType getVoidPtrType() const {
+    return LLVM::LLVMType::getInt8PtrTy(&dialect);
+  }
+
+  // Create an LLVM IR pseudo-operation defining the given index constant.
+  Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
+                             uint64_t value) const {
+    auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
+    return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
+  }
+
+  // Get the array attribute named "position" containing the given list of
+  // integers as integer attribute elements.
+  static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
+                                       ArrayRef<int64_t> values) {
+    SmallVector<Attribute, 4> attrs;
+    attrs.reserve(values.size());
+    for (int64_t pos : values)
+      attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
+    return builder.getArrayAttr(attrs);
+  }
+
+  // Extract raw data pointer value from a value representing a memref.
+  static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
+                                        Location loc,
+                                        Value *convertedMemRefValue,
+                                        Type elementTypePtr,
+                                        bool hasStaticShape) {
+    Value *buffer;
+    if (hasStaticShape)
+      return convertedMemRefValue;
+    else
+      return builder.create<LLVM::ExtractValueOp>(
+          loc, elementTypePtr, convertedMemRefValue,
+          getIntegerArrayAttr(builder, 0));
+    return buffer;
+  }
+
+protected:
+  LLVM::LLVMDialect &dialect;
+};
+
+struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
+  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+    FunctionType type = funcOp.getType();
+
+    // Convert the original function arguments.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+      if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
+        return matchFailure();
+
+    // Pack the result types into a struct.
+    Type packedResult;
+    if (type.getNumResults() != 0) {
+      if (!(packedResult = lowering.packFunctionResults(type.getResults())))
+        return matchFailure();
+    }
+
+    // Create a new function with an updated signature.
+    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    newFuncOp.setType(FunctionType::get(
+        result.getConvertedTypes(),
+        packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
+        funcOp.getContext()));
+
+    // Tell the rewriter to convert the region signature.
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+
+// Basic lowering implementation for one-to-one rewriting from Standard Ops to
+// LLVM Dialect Ops.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+  using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
+
+  // Convert the type of the result to an LLVM type, pass operands as is,
+  // preserve attributes.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    unsigned numResults = op->getNumResults();
+
+    Type packedType;
+    if (numResults != 0) {
+      packedType = this->lowering.packFunctionResults(
+          llvm::to_vector<4>(op->getResultTypes()));
+      assert(packedType && "type conversion failed, such operation should not "
+                           "have been matched");
+    }
+
+    auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+                                           op->getAttrs());
+
+    // If the operation produced 0 or 1 result, return them immediately.
+    if (numResults == 0)
+      return rewriter.replaceOp(op, llvm::None), this->matchSuccess();
+    if (numResults == 1)
+      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+             this->matchSuccess();
+
+    // Otherwise, it had been converted to an operation producing a structure.
+    // Extract individual results from the structure and return them as list.
+    SmallVector<Value *, 4> results;
+    results.reserve(numResults);
+    for (unsigned i = 0; i < numResults; ++i) {
+      auto type = this->lowering.convertType(op->getResult(i)->getType());
+      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          op->getLoc(), type, newOp.getOperation()->getResult(0),
+          this->getIntegerArrayAttr(rewriter, i)));
+    }
+    rewriter.replaceOp(op, results);
+    return this->matchSuccess();
+  }
+};
+
+// Specific lowerings.
+// FIXME: this should be tablegen'ed.
+struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> {
+  using Super::Super;
+};
+struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> {
+  using Super::Super;
+};
+struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> {
+  using Super::Super;
+};
+struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> {
+  using Super::Super;
+};
+struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
+  using Super::Super;
+};
+struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> {
+  using Super::Super;
+};
+struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> {
+  using Super::Super;
+};
+struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> {
+  using Super::Super;
+};
+struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> {
+  using Super::Super;
+};
+struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+  using Super::Super;
+};
+struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+  using Super::Super;
+};
+struct SubFOpLowering : public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+  using Super::Super;
+};
+struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+  using Super::Super;
+};
+struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+  using Super::Super;
+};
+struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+  using Super::Super;
+};
+struct SelectOpLowering
+    : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
+  using Super::Super;
+};
+struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
+  using Super::Super;
+};
+struct CallIndirectOpLowering
+    : public OneToOneLLVMOpLowering<CallIndirectOp, LLVM::CallOp> {
+  using Super::Super;
+};
+struct ConstLLVMOpLowering
+    : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
+  using Super::Super;
+};
+
+// Check if the MemRefType `type` is supported by the lowering. We currently do
+// not support memrefs with affine maps and non-default memory spaces.
+static bool isSupportedMemRefType(MemRefType type) {
+  if (!type.getAffineMaps().empty())
+    return false;
+  if (type.getMemorySpace() != 0)
+    return false;
+  return true;
+}
+
+// An `alloc` is converted into a definition of a memref descriptor value and
+// a call to `malloc` to allocate the underlying data buffer.  The memref
+// descriptor is of the LLVM structure type where the first element is a pointer
+// to the (typed) data buffer, and the remaining elements serve to store
+// dynamic sizes of the memref using LLVM-converted `index` type.
+struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
+  using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult match(Operation *op) const override {
+    MemRefType type = cast<AllocOp>(op).getType();
+    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
+  }
+
+  void rewrite(Operation *op, ArrayRef<Value *> operands,
+               ConversionPatternRewriter &rewriter) const override {
+    auto allocOp = cast<AllocOp>(op);
+    MemRefType type = allocOp.getType();
+
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value *, 4> sizes;
+    auto numOperands = allocOp.getNumOperands();
+    sizes.reserve(numOperands);
+    unsigned i = 0;
+    for (int64_t s : type.getShape())
+      sizes.push_back(s == -1 ? operands[i++]
+                              : createIndexConstant(rewriter, op->getLoc(), s));
+    if (sizes.empty())
+      sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
+
+    // Compute the total number of memref elements.
+    Value *cumulativeSize = sizes.front();
+    for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+      cumulativeSize = rewriter.create<LLVM::MulOp>(
+          op->getLoc(), getIndexType(),
+          ArrayRef<Value *>{cumulativeSize, sizes[i]});
+
+    // Compute the total amount of bytes to allocate.
+    auto elementType = type.getElementType();
+    assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
+           "invalid memref element type");
+    uint64_t elementSize = 0;
+    if (auto vectorType = elementType.dyn_cast<VectorType>())
+      elementSize = vectorType.getNumElements() *
+                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+    else
+      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+    cumulativeSize = rewriter.create<LLVM::MulOp>(
+        op->getLoc(), getIndexType(),
+        ArrayRef<Value *>{
+            cumulativeSize,
+            createIndexConstant(rewriter, op->getLoc(), elementSize)});
+
+    // Insert the `malloc` declaration if it is not already present.
+    auto module = op->getParentOfType<ModuleOp>();
+    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
+    if (!mallocFunc) {
+      auto mallocType =
+          rewriter.getFunctionType(getIndexType(), getVoidPtrType());
+      mallocFunc =
+          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+      module.push_back(mallocFunc);
+    }
+
+    // Allocate the underlying buffer and store a pointer to it in the MemRef
+    // descriptor.
+    Value *allocated =
+        rewriter
+            .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
+                                  rewriter.getSymbolRefAttr(mallocFunc),
+                                  cumulativeSize)
+            .getResult(0);
+    auto structElementType = lowering.convertType(elementType);
+    auto elementPtrType =
+        structElementType.cast<LLVM::LLVMType>().getPointerTo();
+    allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
+                                                 ArrayRef<Value *>(allocated));
+
+    // Deal with static memrefs
+    if (numOperands == 0)
+      return rewriter.replaceOp(op, allocated);
+
+    // Create the MemRef descriptor.
+    auto structType = lowering.convertType(type);
+    Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
+        op->getLoc(), structType, ArrayRef<Value *>{});
+
+    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), structType, memRefDescriptor, allocated,
+        getIntegerArrayAttr(rewriter, 0));
+
+    // Store dynamically allocated sizes in the descriptor.  Dynamic sizes are
+    // passed in as operands.
+    for (auto indexedSize : llvm::enumerate(operands)) {
+      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
+          getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
+    }
+
+    // Return the final value of the descriptor.
+    rewriter.replaceOp(op, memRefDescriptor);
+  }
+};
+
+// A `dealloc` is converted into a call to `free` on the underlying data buffer.
+// The memref descriptor being an SSA value, there is no need to clean it up
+// in any way.
+struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
+  using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1 && "dealloc takes one operand");
+    OperandAdaptor<DeallocOp> transformed(operands);
+
+    // Insert the `free` declaration if it is not already present.
+    FuncOp freeFunc =
+        op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
+    if (!freeFunc) {
+      auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
+      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
+      op->getParentOfType<ModuleOp>().push_back(freeFunc);
+    }
+
+    auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
+    auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
+    Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
+    Value *bufferPtr =
+        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
+                                elementPtrType, hasStaticShape);
+    Value *casted = rewriter.create<LLVM::BitcastOp>(
+        op->getLoc(), getVoidPtrType(), bufferPtr);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
+    return matchSuccess();
+  }
+};
+
+struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
+  using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult match(Operation *op) const override {
+    auto memRefCastOp = cast<MemRefCastOp>(op);
+    MemRefType sourceType =
+        memRefCastOp.getOperand()->getType().cast<MemRefType>();
+    MemRefType targetType = memRefCastOp.getType();
+    return (isSupportedMemRefType(targetType) &&
+            isSupportedMemRefType(sourceType))
+               ? matchSuccess()
+               : matchFailure();
+  }
+
+  void rewrite(Operation *op, ArrayRef<Value *> operands,
+               ConversionPatternRewriter &rewriter) const override {
+    auto memRefCastOp = cast<MemRefCastOp>(op);
+    OperandAdaptor<MemRefCastOp> transformed(operands);
+    auto targetType = memRefCastOp.getType();
+    auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
+
+    // Copy the data buffer pointer.
+    auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
+    Value *buffer =
+        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
+                                elementTypePtr, sourceType.hasStaticShape());
+    // Account for static memrefs as target types
+    if (targetType.hasStaticShape())
+      return rewriter.replaceOp(op, buffer);
+
+    // Create the new MemRef descriptor.
+    auto structType = lowering.convertType(targetType);
+    Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
+        op->getLoc(), structType, ArrayRef<Value *>{});
+    // Otherwise target type is dynamic memref, so create a proper descriptor.
+    newDescriptor = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), structType, newDescriptor, buffer,
+        getIntegerArrayAttr(rewriter, 0));
+
+    // Fill in the dynamic sizes of the new descriptor.  If the size was
+    // dynamic, copy it from the old descriptor.  If the size was static, insert
+    // the constant.  Note that the positions of dynamic sizes in the
+    // descriptors start from 1 (the buffer pointer is at position zero).
+    int64_t sourceDynamicDimIdx = 1;
+    int64_t targetDynamicDimIdx = 1;
+    for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
+      // Ignore new static sizes (they will be known from the type).  If the
+      // size was dynamic, update the index of dynamic types.
+      if (targetType.getShape()[i] != -1) {
+        if (sourceType.getShape()[i] == -1)
+          ++sourceDynamicDimIdx;
+        continue;
+      }
+
+      auto sourceSize = sourceType.getShape()[i];
+      Value *size =
+          sourceSize == -1
+              ? rewriter.create<LLVM::ExtractValueOp>(
+                    op->getLoc(), getIndexType(),
+                    transformed.source(), // NB: dynamic memref
+                    getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
+              : createIndexConstant(rewriter, op->getLoc(), sourceSize);
+      newDescriptor = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), structType, newDescriptor, size,
+          getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
+    }
+    assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
+           "source dynamic dimensions were not processed");
+    assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
+           "target dynamic dimensions were not set up");
+
+    rewriter.replaceOp(op, newDescriptor);
+  }
+};
+
+// A `dim` is converted to a constant for static sizes and to an access to the
+// size stored in the memref descriptor for dynamic sizes.
+struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
+  using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult match(Operation *op) const override {
+    auto dimOp = cast<DimOp>(op);
+    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
+    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
+  }
+
+  void rewrite(Operation *op, ArrayRef<Value *> operands,
+               ConversionPatternRewriter &rewriter) const override {
+    auto dimOp = cast<DimOp>(op);
+    OperandAdaptor<DimOp> transformed(operands);
+    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
+
+    auto shape = type.getShape();
+    uint64_t index = dimOp.getIndex();
+    // Extract dynamic size from the memref descriptor and define static size
+    // as a constant.
+    if (shape[index] == -1) {
+      // Find the position of the dynamic dimension in the list of dynamic sizes
+      // by counting the number of preceding dynamic dimensions.  Start from 1
+      // because the buffer pointer is at position zero.
+      int64_t position = 1;
+      for (uint64_t i = 0; i < index; ++i) {
+        if (shape[i] == -1)
+          ++position;
+      }
+      rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
+          op, getIndexType(), transformed.memrefOrTensor(),
+          getIntegerArrayAttr(rewriter, position));
+    } else {
+      rewriter.replaceOp(
+          op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
+    }
+  }
+};
+
+// Common base for load and store operations on MemRefs.  Restricts the match
+// to supported MemRef types.  Provides functionality to emit code accessing a
+// specific element of the underlying data buffer.
+template <typename Derived>
+struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
+  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
+  using Base = LoadStoreOpLowering<Derived>;
+
+  PatternMatchResult match(Operation *op) const override {
+    MemRefType type = cast<Derived>(op).getMemRefType();
+    return isSupportedMemRefType(type) ? this->matchSuccess()
+                                       : this->matchFailure();
+  }
+
+  // Given subscript indices and array sizes in row-major order,
+  //   i_n, i_{n-1}, ..., i_1
+  //   s_n, s_{n-1}, ..., s_1
+  // obtain a value that corresponds to the linearized subscript
+  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
+  // by accumulating the running linearized value.
+  // Note that `indices` and `allocSizes` are passed in the same order as they
+  // appear in load/store operations and memref type declarations.
+  Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
+                             ArrayRef<Value *> indices,
+                             ArrayRef<Value *> allocSizes) const {
+    assert(indices.size() == allocSizes.size() &&
+           "mismatching number of indices and allocation sizes");
+    assert(!indices.empty() && "cannot linearize a 0-dimensional access");
+
+    Value *linearized = indices.front();
+    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
+      linearized = builder.create<LLVM::MulOp>(
+          loc, this->getIndexType(),
+          ArrayRef<Value *>{linearized, allocSizes[i]});
+      linearized = builder.create<LLVM::AddOp>(
+          loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
+    }
+    return linearized;
+  }
+
+  // Given the MemRef type, a descriptor and a list of indices, extract the data
+  // buffer pointer from the descriptor, convert multi-dimensional subscripts
+  // into a linearized index (using dynamic size data from the descriptor if
+  // necessary) and get the pointer to the buffer element identified by the
+  // indices.
+  Value *getElementPtr(Location loc, Type elementTypePtr,
+                       ArrayRef<int64_t> shape, Value *memRefDescriptor,
+                       ArrayRef<Value *> indices,
+                       ConversionPatternRewriter &rewriter) const {
+    // Get the list of MemRef sizes.  Static sizes are defined as constants.
+    // Dynamic sizes are extracted from the MemRef descriptor, where they start
+    // from the position 1 (the buffer is at position 0).
+    SmallVector<Value *, 4> sizes;
+    unsigned dynamicSizeIdx = 1;
+    for (int64_t s : shape) {
+      if (s == -1) {
+        Value *size = rewriter.create<LLVM::ExtractValueOp>(
+            loc, this->getIndexType(), memRefDescriptor,
+            this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
+        sizes.push_back(size);
+      } else {
+        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
+      }
+    }
+
+    // The second and subsequent operands are access subscripts.  Obtain the
+    // linearized address in the buffer.
+    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
+
+    Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, elementTypePtr, memRefDescriptor,
+        this->getIntegerArrayAttr(rewriter, 0));
+    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
+                                        ArrayRef<Value *>{dataPtr, subscript},
+                                        ArrayRef<NamedAttribute>{});
+  }
+  // This is a getElementPtr variant, where the value is a direct raw pointer.
+  // If a shape is empty, we are dealing with a zero-dimensional memref. Return
+  // the pointer unmodified in this case.  Otherwise, linearize subscripts to
+  // obtain the offset with respect to the base pointer.  Use this offset to
+  // compute and return the element pointer.
+  Value *getRawElementPtr(Location loc, Type elementTypePtr,
+                          ArrayRef<int64_t> shape, Value *rawDataPtr,
+                          ArrayRef<Value *> indices,
+                          ConversionPatternRewriter &rewriter) const {
+    if (shape.empty())
+      return rawDataPtr;
+
+    SmallVector<Value *, 4> sizes;
+    for (int64_t s : shape) {
+      sizes.push_back(this->createIndexConstant(rewriter, loc, s));
+    }
+
+    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
+    return rewriter.create<LLVM::GEPOp>(
+        loc, elementTypePtr, ArrayRef<Value *>{rawDataPtr, subscript},
+        ArrayRef<NamedAttribute>{});
+  }
+
+  Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
+                    ArrayRef<Value *> indices,
+                    ConversionPatternRewriter &rewriter,
+                    llvm::Module &module) const {
+    auto ptrType = getMemRefElementPtrType(type, this->lowering);
+    auto shape = type.getShape();
+    if (type.hasStaticShape()) {
+      // NB: If memref was statically-shaped, dataPtr is pointer to raw data.
+      return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
+    }
+    return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
+  }
+};
+
+// Load operation is lowered to obtaining a pointer to the indexed element
+// and loading it.
+struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loadOp = cast<LoadOp>(op);
+    OperandAdaptor<LoadOp> transformed(operands);
+    auto type = loadOp.getMemRefType();
+
+    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+                                transformed.indices(), rewriter, getModule());
+    auto elementType = lowering.convertType(type.getElementType());
+
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
+                                              ArrayRef<Value *>{dataPtr});
+    return matchSuccess();
+  }
+};
+
+// Store opreation is lowered to obtaining a pointer to the indexed element,
+// and storing the given value to it.
+struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = cast<StoreOp>(op).getMemRefType();
+    OperandAdaptor<StoreOp> transformed(operands);
+
+    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+                                transformed.indices(), rewriter, getModule());
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
+                                               dataPtr);
+    return matchSuccess();
+  }
+};
+
+// The lowering of index_cast becomes an integer conversion since index becomes
+// an integer.  If the bit width of the source and target integer types is the
+// same, just erase the cast.  If the target type is wider, sign-extend the
+// value, otherwise truncate it.
+struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
+  using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    IndexCastOpOperandAdaptor transformed(operands);
+    auto indexCastOp = cast<IndexCastOp>(op);
+
+    auto targetType =
+        this->lowering.convertType(indexCastOp.getResult()->getType())
+            .cast<LLVM::LLVMType>();
+    auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
+    unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
+    unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
+
+    if (targetBits == sourceBits)
+      rewriter.replaceOp(op, transformed.in());
+    else if (targetBits < sourceBits)
+      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+                                                 transformed.in());
+    else
+      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+                                                transformed.in());
+    return matchSuccess();
+  }
+};
+
+// Convert std.cmp predicate into the LLVM dialect CmpPredicate.  The two
+// enums share the numerical values so just cast.
+template <typename LLVMPredType, typename StdPredType>
+static LLVMPredType convertCmpPredicate(StdPredType pred) {
+  return static_cast<LLVMPredType>(pred);
+}
+
+struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
+  using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto cmpiOp = cast<CmpIOp>(op);
+    CmpIOpOperandAdaptor transformed(operands);
+
+    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+        op, lowering.convertType(cmpiOp.getResult()->getType()),
+        rewriter.getI64IntegerAttr(static_cast<int64_t>(
+            convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
+        transformed.lhs(), transformed.rhs());
+
+    return matchSuccess();
+  }
+};
+
+struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
+  using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto cmpfOp = cast<CmpFOp>(op);
+    CmpFOpOperandAdaptor transformed(operands);
+
+    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+        op, lowering.convertType(cmpfOp.getResult()->getType()),
+        rewriter.getI64IntegerAttr(static_cast<int64_t>(
+            convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
+        transformed.lhs(), transformed.rhs());
+
+    return matchSuccess();
+  }
+};
+
+struct SIToFPLowering
+    : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
+  using Super::Super;
+};
+
+// Base class for LLVM IR lowering terminator operations with successors.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneLLVMTerminatorLowering
+    : public LLVMLegalizationPattern<SourceOp> {
+  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+  using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+                  ArrayRef<Block *> destinations,
+                  ArrayRef<ArrayRef<Value *>> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
+                                          operands, op->getAttrs());
+    return this->matchSuccess();
+  }
+};
+
+// Special lowering pattern for `ReturnOps`.  Unlike all other operations,
+// `ReturnOp` interacts with the function signature and must have as many
+// operands as the function has return values.  Because in LLVM IR, functions
+// can only return 0 or 1 value, we pack multiple values into a structure type.
+// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
+// necessary before returning it
+struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
+  using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    unsigned numArguments = op->getNumOperands();
+
+    // If ReturnOp has 0 or 1 operand, create it and return immediately.
+    if (numArguments == 0) {
+      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+          op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
+          llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
+      return matchSuccess();
+    }
+    if (numArguments == 1) {
+      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+          op, llvm::ArrayRef<Value *>(operands.front()),
+          llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
+          op->getAttrs());
+      return matchSuccess();
+    }
+
+    // Otherwise, we need to pack the arguments into an LLVM struct type before
+    // returning.
+    auto packedType =
+        lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
+
+    Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
+    for (unsigned i = 0; i < numArguments; ++i) {
+      packed = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), packedType, packed, operands[i],
+          getIntegerArrayAttr(rewriter, i));
+    }
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+        op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
+        llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
+    return matchSuccess();
+  }
+};
+
+// FIXME: this should be tablegen'ed as well.
+struct BranchOpLowering
+    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
+  using Super::Super;
+};
+struct CondBranchOpLowering
+    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
+  using Super::Super;
+};
+
+} // namespace
+
+static void ensureDistinctSuccessors(Block &bb) {
+  auto *terminator = bb.getTerminator();
+
+  // Find repeated successors with arguments.
+  llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
+  for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
+    Block *successor = terminator->getSuccessor(i);
+    // Blocks with no arguments are safe even if they appear multiple times
+    // because they don't need PHI nodes.
+    if (successor->getNumArguments() == 0)
+      continue;
+    successorPositions[successor].push_back(i);
+  }
+
+  // If a successor appears for the second or more time in the terminator,
+  // create a new dummy block that unconditionally branches to the original
+  // destination, and retarget the terminator to branch to this new block.
+  // There is no need to pass arguments to the dummy block because it will be
+  // dominated by the original block and can therefore use any values defined in
+  // the original block.
+  for (const auto &successor : successorPositions) {
+    const auto &positions = successor.second;
+    // Start from the second occurrence of a block in the successor list.
+    for (auto position = std::next(positions.begin()), end = positions.end();
+         position != end; ++position) {
+      auto *dummyBlock = new Block();
+      bb.getParent()->push_back(dummyBlock);
+      auto builder = OpBuilder(dummyBlock);
+      SmallVector<Value *, 8> operands(
+          terminator->getSuccessorOperands(*position));
+      builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
+      terminator->setSuccessor(dummyBlock, *position);
+      for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
+           ++i)
+        terminator->eraseSuccessorOperand(*position, i);
+    }
+  }
+}
+
+void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
+  for (auto f : m.getOps<FuncOp>()) {
+    for (auto &bb : f.getBlocks()) {
+      ::ensureDistinctSuccessors(bb);
+    }
+  }
+}
+
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void mlir::populateStdToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // FIXME: this should be tablegen'ed
+  patterns.insert<
+      AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
+      BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
+      CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
+      DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
+      DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
+      MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
+      RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
+      SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
+      SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
+}
+
+// Convert types using the stored LLVM IR module.
+Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
+
+// Create an LLVM IR structure type if there is more than one result.
+Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
+  assert(!types.empty() && "expected non-empty list of type");
+
+  if (types.size() == 1)
+    return convertType(types.front());
+
+  SmallVector<LLVM::LLVMType, 8> resultTypes;
+  resultTypes.reserve(types.size());
+  for (auto t : types) {
+    auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
+    if (!converted)
+      return {};
+    resultTypes.push_back(converted);
+  }
+
+  return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
+}
+
+/// Create an instance of LLVMTypeConverter in the given context.
+static std::unique_ptr<LLVMTypeConverter>
+makeStandardToLLVMTypeConverter(MLIRContext *context) {
+  return llvm::make_unique<LLVMTypeConverter>(context);
+}
+
+namespace {
+/// A pass converting MLIR operations into the LLVM IR dialect.
+struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
+  // By default, the patterns are those converting Standard operations to the
+  // LLVMIR dialect.
+  explicit LLVMLoweringPass(
+      LLVMPatternListFiller patternListFiller =
+          populateStdToLLVMConversionPatterns,
+      LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
+      : patternListFiller(patternListFiller),
+        typeConverterMaker(converterBuilder) {}
+
+  // Run the dialect converter on the module.
+  void runOnModule() override {
+    if (!typeConverterMaker || !patternListFiller)
+      return signalPassFailure();
+
+    ModuleOp m = getModule();
+    LLVM::ensureDistinctSuccessors(m);
+    std::unique_ptr<LLVMTypeConverter> typeConverter =
+        typeConverterMaker(&getContext());
+    if (!typeConverter)
+      return signalPassFailure();
+
+    OwningRewritePatternList patterns;
+    populateLoopToStdConversionPatterns(patterns, m.getContext());
+    patternListFiller(*typeConverter, patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter->isSignatureLegal(op.getType());
+    });
+    if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
+      signalPassFailure();
+  }
+
+  // Callback for creating a list of patterns.  It is called every time in
+  // runOnModule since applyPartialConversion consumes the list.
+  LLVMPatternListFiller patternListFiller;
+
+  // Callback for creating an instance of type converter.  The converter
+  // constructor needs an MLIRContext, which is not available until runOnModule.
+  LLVMTypeConverterMaker typeConverterMaker;
+};
+} // end namespace
+
+ModulePassBase *mlir::createConvertToLLVMIRPass() {
+  return new LLVMLoweringPass;
+}
+
+ModulePassBase *
+mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller,
+                                LLVMTypeConverterMaker typeConverterMaker) {
+  return new LLVMLoweringPass(patternListFiller, typeConverterMaker);
+}
+
+static PassRegistration<LLVMLoweringPass>
+    pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000..be53112
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -0,0 +1,25 @@
+set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
+mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
+
+add_llvm_library(MLIRSPIRVConversion
+  ConvertStandardToSPIRV.cpp
+  ConvertStandardToSPIRVPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+
+add_dependencies(MLIRSPIRVConversion
+  MLIRStandardToSPIRVIncGen)
+
+target_link_libraries(MLIRSPIRVConversion
+  MLIRIR
+  MLIRPass
+  MLIRSPIRV
+  MLIRSupport
+  MLIRTransformUtils
+  MLIRSPIRV
+  MLIRStandardOps
+  )
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
new file mode 100644
index 0000000..067f2ae
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -0,0 +1,206 @@
+//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
+    : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
+
+Type SPIRVTypeConverter::convertType(Type t) {
+  // Check if the type is SPIR-V supported. If so return the type.
+  if (spirvDialect->isValidSPIRVType(t)) {
+    return t;
+  }
+
+  if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    if (memRefType.hasStaticShape()) {
+      // Convert MemrefType to spv.array if size is known.
+      // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+      // to support other Storage Classes.
+      return spirv::PointerType::get(
+          spirv::ArrayType::get(memRefType.getElementType(),
+                                memRefType.getNumElements()),
+          spirv::StorageClass::StorageBuffer);
+    }
+  }
+  return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Function signature Conversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                               SignatureConversion &result) {
+  // Try to convert the given input type.
+  auto convertedType = convertType(type);
+  // TODO(ravishankarm) : Vulkan spec requires these to be a
+  // spirv::StructType. This is not a SPIR-V requirement, so just making this a
+  // pointer type for now.
+  if (!convertedType)
+    return failure();
+  // For arguments to entry functions, convert the type into a pointer type if
+  // it is already not one.
+  if (!convertedType.isa<spirv::PointerType>()) {
+    // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+    // to support other Storage classes.
+    convertedType = spirv::PointerType::get(convertedType,
+                                            spirv::StorageClass::StorageBuffer);
+  }
+
+  // Add the new inputs.
+  result.addInputs(inputNo, convertedType);
+  return success();
+}
+
+template <typename Converter>
+static LogicalResult
+lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter, Converter &typeConverter,
+                  TypeConverter::SignatureConversion &signatureConverter,
+                  FuncOp &newFuncOp) {
+  auto fnType = funcOp.getType();
+
+  if (fnType.getNumResults()) {
+    return funcOp.emitError("SPIR-V dialect only supports functions with no "
+                            "return values right now");
+  }
+
+  for (auto &argType : enumerate(fnType.getInputs())) {
+    // Get the type of the argument
+    if (failed(typeConverter.convertSignatureArg(
+            argType.index(), argType.value(), signatureConverter))) {
+      return funcOp.emitError("unable to convert argument type ")
+             << argType.value() << " to SPIR-V type";
+    }
+  }
+
+  // Create a new function with an updated signature.
+  newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
+                                      llvm::None, funcOp.getContext()));
+
+  // Tell the rewriter to convert the region signature.
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  return success();
+}
+
+LogicalResult
+SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                               ConversionPatternRewriter &rewriter,
+                               FuncOp &newFuncOp) const {
+  auto fnType = funcOp.getType();
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
+                           signatureConverter, newFuncOp);
+}
+
+LogicalResult
+SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      FuncOp &newFuncOp) const {
+  auto fnType = funcOp.getType();
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
+                               signatureConverter, newFuncOp))) {
+    return failure();
+  }
+  // Create spv.Variable ops for each of the arguments. These need to be bound
+  // by the runtime. For now use descriptor_set 0, and arg number as the binding
+  // number.
+  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return funcOp.emitError("expected op to be within a spv.module");
+  }
+  OpBuilder builder(module.getOperation()->getRegion(0));
+  SmallVector<Value *, 4> interface;
+  for (auto &convertedArgType :
+       llvm::enumerate(signatureConverter.getConvertedTypes())) {
+    auto variableOp = builder.create<spirv::VariableOp>(
+        funcOp.getLoc(), convertedArgType.value(),
+        builder.getI32IntegerAttr(
+            static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
+        llvm::None);
+    variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+    variableOp.setAttr("binding",
+                       builder.getI32IntegerAttr(convertedArgType.index()));
+    interface.push_back(variableOp.getResult());
+  }
+  // Create an entry point instruction for this function.
+  // TODO(ravishankarm) : Add execution mode for the entry function
+  builder.setInsertionPoint(&(module.getBlock().back()));
+  builder.create<spirv::EntryPointOp>(
+      funcOp.getLoc(),
+      builder.getI32IntegerAttr(
+          static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
+      builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Convert return -> spv.Return.
+class ReturnToSPIRVConversion : public ConversionPattern {
+public:
+  ReturnToSPIRVConversion(MLIRContext *context)
+      : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op->getNumOperands()) {
+      return matchFailure();
+    }
+    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
+    return matchSuccess();
+  }
+};
+
+} // namespace
+
+namespace {
+/// Import the Standard Ops to SPIR-V Patterns.
+#include "StandardToSPIRV.cpp.inc"
+} // namespace
+
+namespace mlir {
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns) {
+  populateWithGenerated(context, &patterns);
+  // Add the return op conversion.
+  patterns.insert<ReturnToSPIRVConversion>(context);
+}
+} // namespace mlir
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
new file mode 100644
index 0000000..ad2c4b5
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -0,0 +1,56 @@
+//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// ops. It does not legalize FuncOps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class ConvertStandardToSPIRVPass
+    : public ModulePass<ConvertStandardToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void ConvertStandardToSPIRVPass::runOnModule() {
+  OwningRewritePatternList patterns;
+  auto module = getModule();
+
+  populateStandardToSPIRVPatterns(module.getContext(), patterns);
+  ConversionTarget target(*(module.getContext()));
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  target.addLegalOp<FuncOp>();
+
+  if (failed(applyPartialConversion(module, target, patterns))) {
+    return signalPassFailure();
+  }
+}
+
+ModulePassBase *mlir::spirv::createConvertStandardToSPIRVPass() {
+  return new ConvertStandardToSPIRVPass();
+}
+
+static PassRegistration<ConvertStandardToSPIRVPass>
+    pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/third_party/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
new file mode 100644
index 0000000..9198e85
--- /dev/null
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -0,0 +1,48 @@
+//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
+
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower standard ops to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef MLIR_CONVERSION_STANDARDTOSPIRV_TD
+#else
+#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
+
+#ifdef STANDARD_OPS
+#else
+include "mlir/StandardOps/Ops.td"
+#endif // STANDARD_OPS
+
+#ifdef SPIRV_OPS
+#else
+include "mlir/Dialect/SPIRV/SPIRVOps.td"
+#endif // SPIRV_OPS
+
+def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
+
+class IsVectorLengthPred<int vecLength> :
+      CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
+            "$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;
+
+class IsVectorOfLength<int vecLength>:
+    TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
+                   vecLength # "-element vector">;
+
+multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
+  def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
+  foreach vecLength = [2, 3, 4] in {
+    def : Pat<(src IsVectorOfLength<vecLength>:$l,
+                   IsVectorOfLength<vecLength>:$r),
+              (tgt $l, $r)>;
+  }
+}
+
+defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
diff --git a/third_party/mlir/lib/Dialect/CMakeLists.txt b/third_party/mlir/lib/Dialect/CMakeLists.txt
new file mode 100644
index 0000000..8898c43
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_subdirectory(FxpMathOps)
+add_subdirectory(GPU)
+add_subdirectory(LoopOps)
+add_subdirectory(QuantOps)
+add_subdirectory(SPIRV)
+
+add_llvm_library(MLIRDialect
+  Traits.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect
+  )
+target_link_libraries(MLIRDialect MLIRIR)
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt b/third_party/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
new file mode 100644
index 0000000..9eddc55
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRFxpMathOps
+  IR/FxpMathOps.cpp
+  IR/DialectRegistration.cpp
+  Transforms/LowerUniformRealMath.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/FxpMathOps
+  )
+add_dependencies(MLIRFxpMathOps
+                 MLIRFxpMathOpsIncGen
+                 MLIRQuantOps
+                 MLIRIR
+                 MLIRPass
+                 MLIRSupport
+                 MLIRStandardOps)
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
new file mode 100644
index 0000000..aa6782e
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
@@ -0,0 +1,24 @@
+//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+// Static initialization for the fxpmath ops dialect registration.
+static mlir::DialectRegistration<FxpMathOpsDialect> FxpMathOps;
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
new file mode 100644
index 0000000..18c07b0
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
@@ -0,0 +1,38 @@
+//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+
+FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
+    : Dialect(/*name=*/"fxpmath", context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+      >();
+}
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
new file mode 100644
index 0000000..e6c351b
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -0,0 +1,401 @@
+//===- LowerUniformRealMath.cpp  ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "UniformKernelUtils.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/FxpMathOps/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+    : public FunctionPass<LowerUniformRealMathPass> {
+  void runOnFunction() override;
+};
+
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+                                            UniformQuantizedType elementType,
+                                            PatternRewriter &rewriter) {
+  // Pre-conditions.
+  if (!elementType.isSigned()) {
+    // TODO: Support unsigned storage type.
+    emitWarning(loc, "unimplemented: dequantize signed uniform");
+    return nullptr;
+  }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+
+  // Cast to storage type.
+  input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+  // Promote to intermediate type.
+  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+  // Apply zero-point offset.
+  if (elementType.getZeroPoint() != 0) {
+    Value *negZeroPointConst = rewriter.create<ConstantOp>(
+        loc, broadcastScalarConstIntValue(intermediateType,
+                                          -elementType.getZeroPoint()));
+    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+  }
+
+  // Convert to float.
+  input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+  // Mul by scale.
+  Value *scaleConst = rewriter.create<ConstantOp>(
+      loc, broadcastScalarConstFloatValue(realType,
+                                          APFloat(elementType.getScale())));
+  return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+                             UniformQuantizedPerAxisType elementType,
+                             PatternRewriter &rewriter) {
+  // TODO: Support per-axis dequantize.
+  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
+      << "unimplemented: per-axis uniform dequantization";
+  return nullptr;
+}
+
+static Value *emitDequantize(Location loc, Value *input,
+                             PatternRewriter &rewriter) {
+  Type inputType = input->getType();
+  QuantizedType qElementType =
+      QuantizedType::getQuantizedElementType(inputType);
+  if (auto uperLayerElementType =
+          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+    return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
+                                         rewriter);
+  } else if (auto uperAxisElementType =
+                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+    return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
+                                        rewriter);
+  } else {
+    return nullptr;
+  }
+}
+
+namespace {
+
+struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
+  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
+                                     PatternRewriter &rewriter) const {
+    Type inputType = op.arg()->getType();
+    Type outputType = op.getResult()->getType();
+
+    QuantizedType inputElementType =
+        QuantizedType::getQuantizedElementType(inputType);
+    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+    if (expressedOutputType != outputType) {
+      // Not a valid uniform cast.
+      return matchFailure();
+    }
+
+    Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
+    if (!dequantizedValue) {
+      return matchFailure();
+    }
+
+    rewriter.replaceOp(op, dequantizedValue);
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+                                      PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+      info.rhsType != info.resultType) {
+    return failure();
+  }
+
+  // Choose a byte aligned intermediate width big enough to perform the
+  // calculation without overflow.
+  // TODO: This should probably be made just big enough to avoid overflow and
+  // leave the downstream tooling to decide how to align that to machine
+  // word sizes.
+  unsigned intermediateWidth =
+      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
+
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
+
+  // Add.
+  Value *resultValue =
+      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Zero point offset adjustment.
+  // result = (lhs - zp) + (rhs - zp) + zp
+  // zpOffset = -zp
+  int zpOffset = -1 * info.resultType.getZeroPoint();
+  if (zpOffset != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType, zpOffset));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+  }
+
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
+
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      info.op, info.getQuantizedResultType(), resultValue);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise mul
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+                            PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned()) {
+    return failure();
+  }
+
+  double outputMultiplierReal = info.lhsType.getScale() *
+                                info.rhsType.getScale() /
+                                info.resultType.getScale();
+  if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning("unimplemented: cannot multiply with multipler > 1.0");
+    return failure();
+  }
+
+  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+  // avoid overflow.
+  unsigned intermediateWidth = 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
+
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
+
+  // Apply argument zeroPoints.
+  if (info.lhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.lhsType.getZeroPoint()));
+    lhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
+  }
+
+  if (info.rhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.rhsType.getZeroPoint()));
+    rhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
+  }
+
+  // Mul.
+  Value *resultValue =
+      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Scale output.
+  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+  resultValue = rewriter.create<RoundingDivideByPotISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
+
+  // Zero point offset adjustment.
+  if (info.resultType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType,
+                                     info.resultType.getZeroPoint()));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+  }
+
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
+
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      info.op, info.getQuantizedResultType(), resultValue);
+
+  return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
+  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(RealAddEwOp op,
+                                     PatternRewriter &rewriter) const {
+    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+                                   op.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
+  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(RealMulEwOp op,
+                                     PatternRewriter &rewriter) const {
+    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+                                   op.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformRealMathPass::runOnFunction() {
+  auto fn = getFunction();
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
+  applyPatternsGreedily(fn, patterns);
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
+  return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+    "fxpmath-lower-uniform-real-math",
+    "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformCastsPass::runOnFunction() {
+  auto fn = getFunction();
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  patterns.insert<UniformDequantizePattern>(context);
+  applyPatternsGreedily(fn, patterns);
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
+  return new LowerUniformCastsPass();
+}
+
+static PassRegistration<LowerUniformCastsPass>
+    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+                          "Lowers uniform-quantized casts.");
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
new file mode 100644
index 0000000..f0eeba0
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -0,0 +1,236 @@
+//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Operation.h"
+
+#include <cmath>
+
+namespace mlir {
+namespace fxpmath {
+namespace detail {
+
+inline quant::UniformQuantizedType getUniformElementType(Type t) {
+  return quant::QuantizedType::getQuantizedElementType(t)
+      .dyn_cast_or_null<quant::UniformQuantizedType>();
+}
+
+inline bool hasStorageBitWidth(quant::QuantizedType t,
+                               llvm::ArrayRef<unsigned> checkWidths) {
+  unsigned w = t.getStorageType().getIntOrFloatBitWidth();
+  for (unsigned checkWidth : checkWidths) {
+    if (w == checkWidth)
+      return true;
+  }
+  return false;
+}
+
+/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
+/// be considered an exact integral value.
+template <typename F> bool integralLog2(F x, int &log2Result) {
+  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+  const F xLog2Rounded = std::round(xLog2);
+  const F xLog2Frac = xLog2 - xLog2Rounded;
+  log2Result = static_cast<int>(xLog2Rounded);
+  // Allow small comparison slop below the level that would make a difference
+  // for 2^16 levels.
+  return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct UniformBinaryOpInfo {
+  UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+                      Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+        lhsType(getUniformElementType(lhs->getType())),
+        rhsType(getUniformElementType(rhs->getType())),
+        resultType(getUniformElementType(*op->result_type_begin())),
+        lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
+        rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
+        resultStorageType(
+            quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
+  }
+
+  /// Returns whether this info is valid (all types defined, etc).
+  bool isValid() const {
+    return lhsType && rhsType && resultType && lhsStorageType &&
+           rhsStorageType && resultStorageType;
+  }
+
+  /// Gets the final quantized result type of the result.
+  Type getQuantizedResultType() const { return *op->result_type_begin(); }
+
+  /// Returns whether the storage type of all operands is identical.
+  bool isSameStorageType() const {
+    return lhsType.getStorageType() == rhsType.getStorageType() &&
+           lhsType.getStorageType() == resultType.getStorageType();
+  }
+
+  /// Returns whether all operands and result are considered fixedpoint power
+  /// of two, setting the lhs, rhs, and result log2 scale references.
+  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+                       int &resultLog2Scale) const {
+    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+        !resultType.isFixedPoint()) {
+      return false;
+    }
+
+    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+        !integralLog2(resultType.getScale(), resultLog2Scale)) {
+      return false;
+    }
+
+    return true;
+  }
+
+  /// Gets the result integer clamp range given the result quantized type
+  // and any explicit clamp provided as attributes.
+  std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
+    int64_t typeMin = resultType.getStorageTypeMin();
+    int64_t typeMax = resultType.getStorageTypeMax();
+
+    if (clampMin || clampMax) {
+      quant::UniformQuantizedValueConverter conv(resultType);
+      if (clampMin) {
+        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+      }
+      if (clampMax) {
+        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+      }
+    }
+
+    // The quantized, integral ops expect clamps as 32bit ints.
+    return {
+        IntegerAttr::get(ty, typeMin),
+        IntegerAttr::get(ty, typeMax),
+    };
+  }
+
+  Operation *op;
+  Value *lhs;
+  Value *rhs;
+  Optional<APFloat> clampMin;
+  Optional<APFloat> clampMax;
+
+  // Element UniformQuantizedType for operands/result.
+  quant::UniformQuantizedType lhsType;
+  quant::UniformQuantizedType rhsType;
+  quant::UniformQuantizedType resultType;
+
+  // Full storage-based types.
+  Type lhsStorageType;
+  Type rhsStorageType;
+  Type resultStorageType;
+};
+
+/// Derives a quantized multiplier and shift from a real valued multiplier
+/// less than 1.
+struct QuantizedMultiplierSmallerThanOneExp {
+  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
+    assert(realMultiplier < 1.0);
+    assert(realMultiplier > 0.0);
+
+    const double q = std::frexp(realMultiplier, &exponent);
+    auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
+    assert(qFixed <= (1ll << 31));
+    if (qFixed == (1ll << 31)) {
+      qFixed /= 2;
+      ++exponent;
+    }
+    assert(qFixed <= std::numeric_limits<int32_t>::max());
+    multiplier = static_cast<int32_t>(qFixed);
+  }
+
+  int32_t multiplier;
+  int exponent;
+};
+
+/// Casts an integer or floating point based shaped type to a new element type.
+inline Type castElementType(Type t, Type newElementType) {
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    switch (st.getKind()) {
+    case StandardTypes::Kind::Vector:
+      return VectorType::get(st.getShape(), newElementType);
+    case StandardTypes::Kind::RankedTensor:
+      return RankedTensorType::get(st.getShape(), newElementType);
+    case StandardTypes::Kind::UnrankedTensor:
+      return UnrankedTensorType::get(newElementType);
+    case StandardTypes::Kind::MemRef:
+      return MemRefType::get(st.getShape(), newElementType,
+                             st.cast<MemRefType>().getAffineMaps());
+    }
+  }
+  assert(t.isIntOrFloat());
+  return newElementType;
+}
+
+/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
+/// be a scalar primitive or a shaped type).
+inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    assert(st.getElementType().isa<IntegerType>());
+    return DenseElementsAttr::get(st,
+                                  IntegerAttr::get(st.getElementType(), value));
+  }
+
+  auto integerType = t.cast<IntegerType>();
+  assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
+  return IntegerAttr::get(integerType, value);
+}
+
+/// Given an APFloat, converts it to the float semantics that matches the
+/// given FloatType, silently ignoring inexact conversions.
+inline APFloat convertFloatToType(FloatType ft, APFloat value) {
+  bool losesInfo;
+  auto status = value.convert(ft.getFloatSemantics(),
+                              APFloat::rmNearestTiesToEven, &losesInfo);
+  (void)status; // unused in opt mode
+  assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
+         "could not convert to float const");
+  return value;
+}
+
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can be
+/// a scalar primitive or a shaped type).
+inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
+    assert(floatElementType &&
+           "float broadcast element type must be float like");
+    APFloat apValue = convertFloatToType(floatElementType, value);
+    return DenseElementsAttr::get(st,
+                                  FloatAttr::get(st.getElementType(), apValue));
+  } else {
+    auto floatType = t.dyn_cast<FloatType>();
+    assert(floatType && "float broadcast must be of float type");
+    APFloat apValue = convertFloatToType(floatType, value);
+    return FloatAttr::get(floatType, apValue);
+  }
+}
+
+} // namespace detail
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
diff --git a/third_party/mlir/lib/Dialect/GPU/CMakeLists.txt b/third_party/mlir/lib/Dialect/GPU/CMakeLists.txt
new file mode 100644
index 0000000..09da5cc
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRGPU
+  IR/GPUDialect.cpp
+  IR/DialectRegistration.cpp
+  Transforms/KernelOutlining.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
+)
+add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR LLVMSupport)
+target_link_libraries(MLIRGPU MLIRIR MLIRStandardOps LLVMSupport)
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp
new file mode 100644
index 0000000..af50d02
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp
@@ -0,0 +1,21 @@
+//===- DialectRegistration.cpp - MLIR GPU dialect registration ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+
+// Static initialization for GPU dialect registration.
+static mlir::DialectRegistration<mlir::gpu::GPUDialect> kernelDialect;
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
new file mode 100644
index 0000000..2fbaa49
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -0,0 +1,454 @@
+//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the GPU kernel-related dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+StringRef GPUDialect::getDialectName() { return "gpu"; }
+
+bool GPUDialect::isKernel(FuncOp function) {
+  UnitAttr isKernelAttr =
+      function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
+  return static_cast<bool>(isKernelAttr);
+}
+
+GPUDialect::GPUDialect(MLIRContext *context)
+    : Dialect(getDialectName(), context) {
+  addOperations<LaunchOp, LaunchFuncOp,
+#define GET_OP_LIST
+#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
+                >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// LaunchOp
+//===----------------------------------------------------------------------===//
+
+static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
+  SmallVector<Type, 4> types;
+  types.reserve(values.size());
+  for (Value *v : values)
+    types.push_back(v->getType());
+  return types;
+}
+
+void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX,
+                     Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
+                     Value *blockSizeY, Value *blockSizeZ,
+                     ArrayRef<Value *> operands) {
+  // Add grid and block sizes as op operands, followed by the data operands.
+  result->addOperands(
+      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
+  result->addOperands(operands);
+
+  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
+  // where the first kNumConfigRegionAttributes arguments have `index` type and
+  // the rest have the same types as the data operands.
+  Region *kernelRegion = result->addRegion();
+  Block *body = new Block();
+  body->addArguments(
+      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
+  body->addArguments(getValueTypes(operands));
+  kernelRegion->push_back(body);
+}
+
+Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }
+
+KernelDim3 LaunchOp::getBlockIds() {
+  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
+  auto args = getBody().getBlocks().front().getArguments();
+  return KernelDim3{args[0], args[1], args[2]};
+}
+
+KernelDim3 LaunchOp::getThreadIds() {
+  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
+  auto args = getBody().getBlocks().front().getArguments();
+  return KernelDim3{args[3], args[4], args[5]};
+}
+
+KernelDim3 LaunchOp::getGridSize() {
+  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
+  auto args = getBody().getBlocks().front().getArguments();
+  return KernelDim3{args[6], args[7], args[8]};
+}
+
+KernelDim3 LaunchOp::getBlockSize() {
+  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
+  auto args = getBody().getBlocks().front().getArguments();
+  return KernelDim3{args[9], args[10], args[11]};
+}
+
+LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
+  return llvm::drop_begin(getOperands(), kNumConfigOperands);
+}
+
+LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
+  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
+}
+
+KernelDim3 LaunchOp::getGridSizeOperandValues() {
+  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
+}
+
+KernelDim3 LaunchOp::getBlockSizeOperandValues() {
+  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
+}
+
+llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
+  auto args = getBody().getBlocks().front().getArguments();
+  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
+}
+
+LogicalResult LaunchOp::verify() {
+  // Kernel launch takes kNumConfigOperands leading operands for grid/block
+  // sizes and transforms them into kNumConfigRegionAttributes region arguments
+  // for block/thread identifiers and grid/block sizes.
+  if (!getBody().empty()) {
+    Block &entryBlock = getBody().front();
+    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
+      return emitError("unexpected number of region arguments");
+  }
+
+  // Block terminators without successors are expected to exit the kernel region
+  // and must be `gpu.launch`.
+  for (Block &block : getBody()) {
+    if (block.empty())
+      continue;
+    if (block.back().getNumSuccessors() != 0)
+      continue;
+    if (!isa<gpu::Return>(&block.back())) {
+      return block.back()
+                 .emitError("expected 'gpu.terminator' or a terminator with "
+                            "successors")
+                 .attachNote(getLoc())
+             << "in '" << getOperationName() << "' body region";
+    }
+  }
+
+  return success();
+}
+
+// Pretty-print the kernel grid/block size assignment as
+//   (%iter-x, %iter-y, %iter-z) in
+//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
+// where %size-* and %iter-* will correspond to the body region arguments.
+static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size,
+                                ArrayRef<Value *> operands, KernelDim3 ids) {
+  *p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
+  *p << *size.x << " = " << *operands[0] << ", ";
+  *p << *size.y << " = " << *operands[1] << ", ";
+  *p << *size.z << " = " << *operands[2] << ')';
+}
+
+void LaunchOp::print(OpAsmPrinter *p) {
+  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
+  ArrayRef<Value *> operands(operandContainer);
+
+  // Print the launch configuration.
+  *p << getOperationName() << ' ' << getBlocksKeyword();
+  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
+  *p << ' ' << getThreadsKeyword();
+  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());
+
+  // From now on, the first kNumConfigOperands operands corresponding to grid
+  // and block sizes are irrelevant, so we can drop them.
+  operands = operands.drop_front(kNumConfigOperands);
+
+  // Print the data argument remapping.
+  if (!getBody().empty() && !operands.empty()) {
+    *p << ' ' << getArgsKeyword() << '(';
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (i != 0)
+        *p << ", ";
+      *p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
+         << " = " << *operands[i];
+    }
+    *p << ") ";
+  }
+
+  // Print the types of data arguments.
+  if (!operands.empty()) {
+    *p << ": ";
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (i != 0)
+        *p << ", ";
+      *p << operands[i]->getType();
+    }
+  }
+
+  p->printRegion(getBody(), /*printEntryBlockArgs=*/false);
+  p->printOptionalAttrDict(getAttrs());
+}
+
+// Parse the size assignment blocks for blocks and threads.  These have the form
+//   (%region_arg, %region_arg, %region_arg) in
+//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
+// where %region_arg are percent-identifiers for the region arguments to be
+// introduced futher (SSA defs), and %operand are percent-identifiers for the
+// SSA value uses.
+static ParseResult
+parseSizeAssignment(OpAsmParser *parser,
+                    MutableArrayRef<OpAsmParser::OperandType> sizes,
+                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
+                    MutableArrayRef<OpAsmParser::OperandType> indices) {
+  assert(indices.size() == 3 && "space for three indices expected");
+  SmallVector<OpAsmParser::OperandType, 3> args;
+  if (parser->parseRegionArgumentList(args, /*requiredOperandCount=*/3,
+                                      OpAsmParser::Delimiter::Paren) ||
+      parser->parseKeyword("in") || parser->parseLParen())
+    return failure();
+  std::move(args.begin(), args.end(), indices.begin());
+
+  for (int i = 0; i < 3; ++i) {
+    if (i != 0 && parser->parseComma())
+      return failure();
+    if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() ||
+        parser->parseOperand(sizes[i]))
+      return failure();
+  }
+
+  return parser->parseRParen();
+}
+
+// Parses a Launch operation.
+// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
+//                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+//                             (`args` ssa-reassignment `:` type-list)?
+//                             region attr-dict?
+// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
+ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
+  // Sizes of the grid and block.
+  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
+      kNumConfigOperands);
+  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);
+
+  // Actual (data) operands passed to the kernel.
+  SmallVector<OpAsmParser::OperandType, 4> dataOperands;
+
+  // Region arguments to be created.
+  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
+      kNumConfigRegionAttributes);
+  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);
+
+  // Parse the size assignment segments: the first segment assigns grid siezs
+  // and defines values for block identifiers; the second segment assigns block
+  // sies and defines values for thread identifiers.  In the region argument
+  // list, identifiers preceed sizes, and block-related values preceed
+  // thread-related values.
+  if (parser->parseKeyword(getBlocksKeyword().data()) ||
+      parseSizeAssignment(parser, sizesRef.take_front(3),
+                          regionArgsRef.slice(6, 3),
+                          regionArgsRef.slice(0, 3)) ||
+      parser->parseKeyword(getThreadsKeyword().data()) ||
+      parseSizeAssignment(parser, sizesRef.drop_front(3),
+                          regionArgsRef.slice(9, 3),
+                          regionArgsRef.slice(3, 3)) ||
+      parser->resolveOperands(sizes, parser->getBuilder().getIndexType(),
+                              result->operands))
+    return failure();
+
+  // If kernel argument renaming segment is present, parse it.  When present,
+  // the segment should have at least one element.  If this segment is present,
+  // so is the trailing type list.  Parse it as well and use the parsed types
+  // to resolve the operands passed to the kernel arguments.
+  SmallVector<Type, 4> dataTypes;
+  if (!parser->parseOptionalKeyword(getArgsKeyword().data())) {
+    llvm::SMLoc argsLoc = parser->getCurrentLocation();
+
+    regionArgs.push_back({});
+    dataOperands.push_back({});
+    if (parser->parseLParen() ||
+        parser->parseRegionArgument(regionArgs.back()) ||
+        parser->parseEqual() || parser->parseOperand(dataOperands.back()))
+      return failure();
+
+    while (!parser->parseOptionalComma()) {
+      regionArgs.push_back({});
+      dataOperands.push_back({});
+      if (parser->parseRegionArgument(regionArgs.back()) ||
+          parser->parseEqual() || parser->parseOperand(dataOperands.back()))
+        return failure();
+    }
+
+    if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) ||
+        parser->resolveOperands(dataOperands, dataTypes, argsLoc,
+                                result->operands))
+      return failure();
+  }
+
+  // Introduce the body region and parse it.  The region has
+  // kNumConfigRegionAttributes leading arguments that correspond to
+  // block/thread identifiers and grid/block sizes, all of the `index` type.
+  // Follow the actual kernel arguments.
+  Type index = parser->getBuilder().getIndexType();
+  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
+  Region *body = result->addRegion();
+  return failure(parser->parseRegion(*body, regionArgs, dataTypes) ||
+                 parser->parseOptionalAttributeDict(result->attributes));
+}
+
+void LaunchOp::eraseKernelArgument(unsigned index) {
+  Block &entryBlock = getBody().front();
+  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
+         "kernel argument index overflow");
+  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
+  getOperation()->eraseOperand(kNumConfigOperands + index);
+}
+
+namespace {
+// Clone any known constants passed as operands to the kernel into its body.
+class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
+  using OpRewritePattern<LaunchOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
+                                     PatternRewriter &rewriter) const override {
+    auto oringInsertionPoint = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPointToStart(&launchOp.getBody().front());
+
+    // Traverse operands passed to kernel and check if some of them are known
+    // constants.  If so, clone the constant operation inside the kernel region
+    // and use it instead of passing the value from the parent region.  Perform
+    // the traversal in the inverse order to simplify index arithmetics when
+    // dropping arguments.
+    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
+                                     launchOp.getKernelOperandValues().end());
+    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
+                                       launchOp.getKernelArguments().end());
+    bool found = false;
+    for (unsigned i = operands.size(); i > 0; --i) {
+      unsigned index = i - 1;
+      Value *operand = operands[index];
+      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
+        continue;
+      }
+
+      found = true;
+      Value *internalConstant =
+          rewriter.clone(*operand->getDefiningOp())->getResult(0);
+      Value *kernelArg = kernelArgs[index];
+      kernelArg->replaceAllUsesWith(internalConstant);
+      launchOp.eraseKernelArgument(index);
+    }
+    rewriter.restoreInsertionPoint(oringInsertionPoint);
+
+    if (!found)
+      return matchFailure();
+
+    rewriter.updatedRootInPlace(launchOp);
+    return matchSuccess();
+  }
+};
+} // end namespace
+
+void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<PropagateConstantBounds>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// LaunchFuncOp
+//===----------------------------------------------------------------------===//
+
+void LaunchFuncOp::build(Builder *builder, OperationState *result,
+                         FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
+                         Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
+                         Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
+  // Add grid and block sizes as op operands, followed by the data operands.
+  result->addOperands(
+      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
+  result->addOperands(kernelOperands);
+  result->addAttribute(getKernelAttrName(),
+                       builder->getSymbolRefAttr(kernelFunc));
+}
+
+void LaunchFuncOp::build(Builder *builder, OperationState *result,
+                         FuncOp kernelFunc, KernelDim3 gridSize,
+                         KernelDim3 blockSize,
+                         ArrayRef<Value *> kernelOperands) {
+  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
+        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
+}
+
+StringRef LaunchFuncOp::kernel() {
+  return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
+}
+
+unsigned LaunchFuncOp::getNumKernelOperands() {
+  return getNumOperands() - kNumConfigOperands;
+}
+
+Value *LaunchFuncOp::getKernelOperand(unsigned i) {
+  return getOperation()->getOperand(i + kNumConfigOperands);
+}
+
+KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
+  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
+}
+
+KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
+  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
+}
+
+LogicalResult LaunchFuncOp::verify() {
+  auto kernelAttr = this->getAttr(getKernelAttrName());
+  if (!kernelAttr) {
+    return emitOpError("attribute 'kernel' must be specified");
+  } else if (!kernelAttr.isa<SymbolRefAttr>()) {
+    return emitOpError("attribute 'kernel' must be a function");
+  }
+
+  auto module = getParentOfType<ModuleOp>();
+  FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
+  if (!kernelFunc)
+    return emitError() << "kernel function '" << kernelAttr << "' is undefined";
+
+  if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
+          GPUDialect::getKernelFuncAttrName())) {
+    return emitError("kernel function is missing the '")
+           << GPUDialect::getKernelFuncAttrName() << "' attribute";
+  }
+  unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
+  if (getNumKernelOperands() != numKernelFuncArgs) {
+    return emitOpError("got ")
+           << getNumKernelOperands() << " kernel operands but expected "
+           << numKernelFuncArgs;
+  }
+  auto functionType = kernelFunc.getType();
+  for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
+    if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
+      return emitOpError("type of function argument ")
+             << i << " does not match";
+    }
+  }
+  return success();
+}
diff --git a/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
new file mode 100644
index 0000000..01decce
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -0,0 +1,118 @@
+//===- KernelOutlining.cpp - Implementation of GPU kernel outling ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the GPU dialect kernel outlining pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+template <typename OpTy>
+static void createForAllDimensions(OpBuilder &builder, Location loc,
+                                   SmallVectorImpl<Value *> &values) {
+  for (StringRef dim : {"x", "y", "z"}) {
+    Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
+                                    builder.getStringAttr(dim));
+    values.push_back(v);
+  }
+}
+
+// Add operations generating block/thread ids and gird/block dimensions at the
+// beginning of `kernelFunc` and replace uses of the respective function args.
+static void injectGpuIndexOperations(Location loc, FuncOp kernelFunc) {
+  OpBuilder OpBuilder(kernelFunc.getBody());
+  SmallVector<Value *, 12> indexOps;
+  createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
+  createForAllDimensions<gpu::ThreadId>(OpBuilder, loc, indexOps);
+  createForAllDimensions<gpu::GridDim>(OpBuilder, loc, indexOps);
+  createForAllDimensions<gpu::BlockDim>(OpBuilder, loc, indexOps);
+  // Replace the leading 12 function args with the respective thread/block index
+  // operations. Iterate backwards since args are erased and indices change.
+  for (int i = 11; i >= 0; --i) {
+    auto &firstBlock = kernelFunc.front();
+    firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
+    firstBlock.eraseArgument(i);
+  }
+}
+
+// Outline the `gpu.launch` operation body into a kernel function. Replace
+// `gpu.return` operations by `std.return` in the generated functions.
+static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
+  Location loc = launchOp.getLoc();
+  SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
+  FunctionType type =
+      FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
+  std::string kernelFuncName =
+      Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
+  FuncOp outlinedFunc = FuncOp::create(loc, kernelFuncName, type);
+  outlinedFunc.getBody().takeBody(launchOp.getBody());
+  Builder builder(launchOp.getContext());
+  outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+                       builder.getUnitAttr());
+  injectGpuIndexOperations(loc, outlinedFunc);
+  outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
+    OpBuilder replacer(op);
+    replacer.create<ReturnOp>(op.getLoc());
+    op.erase();
+  });
+  return outlinedFunc;
+}
+
+// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
+// `kernelFunc`.
+static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) {
+  OpBuilder builder(launchOp);
+  SmallVector<Value *, 4> kernelOperandValues(
+      launchOp.getKernelOperandValues());
+  builder.create<gpu::LaunchFuncOp>(
+      launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
+      launchOp.getBlockSizeOperandValues(), kernelOperandValues);
+  launchOp.erase();
+}
+
+namespace {
+
+class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
+public:
+  void runOnModule() override {
+    ModuleManager moduleManager(getModule());
+    for (auto func : getModule().getOps<FuncOp>()) {
+      func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
+        FuncOp outlinedFunc = outlineKernelFunc(op);
+        moduleManager.insert(outlinedFunc);
+        convertToLaunchFuncOp(op, outlinedFunc);
+      });
+    }
+  }
+};
+
+} // namespace
+
+ModulePassBase *mlir::createGpuKernelOutliningPass() {
+  return new GpuKernelOutliningPass();
+}
+
+static PassRegistration<GpuKernelOutliningPass>
+    pass("gpu-kernel-outlining",
+         "Outline gpu.launch bodies to kernel functions.");
diff --git a/third_party/mlir/lib/Dialect/LoopOps/CMakeLists.txt b/third_party/mlir/lib/Dialect/LoopOps/CMakeLists.txt
new file mode 100644
index 0000000..ce4a666
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/LoopOps/CMakeLists.txt
@@ -0,0 +1,9 @@
+file(GLOB globbed *.c *.cpp)
+add_llvm_library(MLIRLoopOps
+  ${globbed}
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/LoopOps
+  )
+add_dependencies(MLIRLoopOps MLIRLoopOpsIncGen MLIRStandardOps LLVMSupport)
+target_link_libraries(MLIRLoopOps LLVMSupport)
diff --git a/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
new file mode 100644
index 0000000..5724402
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register loop dialect --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+using namespace mlir;
+
+// Static initialization for loop dialect registration.
+static DialectRegistration<loop::LoopOpsDialect> LoopOps;
diff --git a/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
new file mode 100644
index 0000000..13dc35e
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -0,0 +1,208 @@
+//===- Ops.cpp - Loop MLIR Operations -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+//===----------------------------------------------------------------------===//
+// LoopOpsDialect
+//===----------------------------------------------------------------------===//
+
+LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+void ForOp::build(Builder *builder, OperationState *result, Value *lb,
+                  Value *ub, Value *step) {
+  result->addOperands({lb, ub, step});
+  Region *bodyRegion = result->addRegion();
+  ForOp::ensureTerminator(*bodyRegion, *builder, result->location);
+  bodyRegion->front().addArgument(builder->getIndexType());
+}
+
+LogicalResult verify(ForOp op) {
+  if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp()))
+    if (cst.getValue() <= 0)
+      return op.emitOpError("constant step operand must be nonnegative");
+
+  // Check that the body defines as single block argument for the induction
+  // variable.
+  auto *body = op.getBody();
+  if (body->getNumArguments() != 1 ||
+      !body->getArgument(0)->getType().isIndex())
+    return op.emitOpError("expected body to have a single index argument for "
+                          "the induction variable");
+  return success();
+}
+
+static void print(OpAsmPrinter *p, ForOp op) {
+  *p << op.getOperationName() << " " << *op.getInductionVar() << " = "
+     << *op.lowerBound() << " to " << *op.upperBound() << " step "
+     << *op.step();
+  p->printRegion(op.region(),
+                 /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+  p->printOptionalAttrDict(op.getAttrs());
+}
+
+static ParseResult parseForOp(OpAsmParser *parser, OperationState *result) {
+  auto &builder = parser->getBuilder();
+  OpAsmParser::OperandType inductionVariable, lb, ub, step;
+  // Parse the induction variable followed by '='.
+  if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
+    return failure();
+
+  // Parse loop bounds.
+  Type indexType = builder.getIndexType();
+  if (parser->parseOperand(lb) ||
+      parser->resolveOperand(lb, indexType, result->operands) ||
+      parser->parseKeyword("to") || parser->parseOperand(ub) ||
+      parser->resolveOperand(ub, indexType, result->operands) ||
+      parser->parseKeyword("step") || parser->parseOperand(step) ||
+      parser->resolveOperand(step, indexType, result->operands))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result->addRegion();
+  if (parser->parseRegion(*body, inductionVariable, indexType))
+    return failure();
+
+  ForOp::ensureTerminator(*body, builder, result->location);
+
+  // Parse the optional attribute list.
+  if (parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  return success();
+}
+
+ForOp mlir::loop::getForInductionVarOwner(Value *val) {
+  auto *ivArg = dyn_cast<BlockArgument>(val);
+  if (!ivArg)
+    return ForOp();
+  assert(ivArg->getOwner() && "unlinked block argument");
+  auto *containingInst = ivArg->getOwner()->getParentOp();
+  return dyn_cast_or_null<ForOp>(containingInst);
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+void IfOp::build(Builder *builder, OperationState *result, Value *cond,
+                 bool withElseRegion) {
+  result->addOperands(cond);
+  Region *thenRegion = result->addRegion();
+  Region *elseRegion = result->addRegion();
+  IfOp::ensureTerminator(*thenRegion, *builder, result->location);
+  if (withElseRegion)
+    IfOp::ensureTerminator(*elseRegion, *builder, result->location);
+}
+
+static LogicalResult verify(IfOp op) {
+  // Verify that the entry of each child region does not have arguments.
+  for (auto &region : op.getOperation()->getRegions()) {
+    if (region.empty())
+      continue;
+
+    for (auto &b : region)
+      if (b.getNumArguments() != 0)
+        return op.emitOpError(
+            "requires that child entry blocks have no arguments");
+  }
+  return success();
+}
+
+static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) {
+  // Create the regions for 'then'.
+  result->regions.reserve(2);
+  Region *thenRegion = result->addRegion();
+  Region *elseRegion = result->addRegion();
+
+  auto &builder = parser->getBuilder();
+  OpAsmParser::OperandType cond;
+  Type i1Type = builder.getIntegerType(1);
+  if (parser->parseOperand(cond) ||
+      parser->resolveOperand(cond, i1Type, result->operands))
+    return failure();
+
+  // Parse the 'then' region.
+  if (parser->parseRegion(*thenRegion, {}, {}))
+    return failure();
+  IfOp::ensureTerminator(*thenRegion, parser->getBuilder(), result->location);
+
+  // If we find an 'else' keyword then parse the 'else' region.
+  if (!parser->parseOptionalKeyword("else")) {
+    if (parser->parseRegion(*elseRegion, {}, {}))
+      return failure();
+    IfOp::ensureTerminator(*elseRegion, parser->getBuilder(), result->location);
+  }
+
+  // Parse the optional attribute list.
+  if (parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter *p, IfOp op) {
+  *p << IfOp::getOperationName() << " " << *op.condition();
+  p->printRegion(op.thenRegion(),
+                 /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+
+  // Print the 'else' regions if it exists and has a block.
+  auto &elseRegion = op.elseRegion();
+  if (!elseRegion.empty()) {
+    *p << " else";
+    p->printRegion(elseRegion,
+                   /*printEntryBlockArgs=*/false,
+                   /*printBlockTerminators=*/false);
+  }
+
+  p->printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/CMakeLists.txt b/third_party/mlir/lib/Dialect/QuantOps/CMakeLists.txt
new file mode 100644
index 0000000..74b3f3c
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_llvm_library(MLIRQuantOps
+  IR/DialectRegistration.cpp
+  IR/QuantOps.cpp
+  IR/QuantTypes.cpp
+  IR/TypeDetail.h
+  IR/TypeParser.cpp
+  Transforms/ConvertConst.cpp
+  Transforms/ConvertSimQuant.cpp
+  Utils/QuantizeUtils.cpp
+  Utils/UniformSupport.cpp
+  Utils/FakeQuantSupport.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/QuantOps
+  )
+add_dependencies(MLIRQuantOps
+                 MLIRIR
+                 MLIRPass
+                 MLIRQuantOpsIncGen
+                 MLIRSupport
+                 MLIRStandardOps)
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
new file mode 100644
index 0000000..b071248
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
@@ -0,0 +1,24 @@
+//===- DialectRegistration.cpp - Register Quantization dialect ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+// Static initialization for Quantization dialect registration.
+static mlir::DialectRegistration<QuantizationDialect> QuantizationOps;
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
new file mode 100644
index 0000000..3bd49d4
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -0,0 +1,74 @@
+//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "TypeDetail.h"
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+using namespace mlir::quant::detail;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
+
+namespace {
+
+/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
+/// value of x if the casts invert each other.
+class RemoveRedundantStorageCastsRewrite
+    : public OpRewritePattern<StorageCastOp> {
+public:
+  using OpRewritePattern<StorageCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(StorageCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
+      return matchFailure();
+    auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
+    if (srcScastOp.arg()->getType() != op.getType())
+      return matchFailure();
+
+    rewriter.replaceOp(op, srcScastOp.arg());
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void StorageCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
+}
+
+QuantizationDialect::QuantizationDialect(MLIRContext *context)
+    : Dialect(/*name=*/"quant", context) {
+  addTypes<AnyQuantizedType, UniformQuantizedType,
+           UniformQuantizedPerAxisType>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
+      >();
+}
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
new file mode 100644
index 0000000..6cc8ab0
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
@@ -0,0 +1,412 @@
+//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "TypeDetail.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+using namespace mlir::quant::detail;
+
+unsigned QuantizedType::getFlags() const {
+  return static_cast<ImplType *>(impl)->flags;
+}
+
+LogicalResult QuantizedType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  // Verify that the storage type is integral.
+  // This restriction may be lifted at some point in favor of using bf16
+  // or f16 as exact representations on hardware where that is advantageous.
+  auto intStorageType = storageType.dyn_cast<IntegerType>();
+  if (!intStorageType) {
+    if (loc) {
+      emitError(*loc, "storage type must be integral");
+    }
+    return failure();
+  }
+  unsigned integralWidth = intStorageType.getWidth();
+
+  // Verify storage width.
+  if (integralWidth == 0 || integralWidth > MaxStorageBits) {
+    if (loc) {
+      emitError(*loc, "illegal storage type size: ") << integralWidth;
+    }
+    return failure();
+  }
+
+  // Verify storageTypeMin and storageTypeMax.
+  bool isSigned =
+      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
+  int64_t defaultIntegerMin =
+      getDefaultMininumForInteger(isSigned, integralWidth);
+  int64_t defaultIntegerMax =
+      getDefaultMaxinumForInteger(isSigned, integralWidth);
+  if (storageTypeMax - storageTypeMin <= 0 ||
+      storageTypeMin < defaultIntegerMin ||
+      storageTypeMax > defaultIntegerMax) {
+    if (loc) {
+      emitError(*loc, "illegal storage min and storage max: (")
+          << storageTypeMin << ":" << storageTypeMax << ")";
+    }
+    return failure();
+  }
+  return success();
+}
+
+Type QuantizedType::getStorageType() const {
+  return static_cast<ImplType *>(impl)->storageType;
+}
+
+int64_t QuantizedType::getStorageTypeMin() const {
+  return static_cast<ImplType *>(impl)->storageTypeMin;
+}
+
+int64_t QuantizedType::getStorageTypeMax() const {
+  return static_cast<ImplType *>(impl)->storageTypeMax;
+}
+
+unsigned QuantizedType::getStorageTypeIntegralWidth() const {
+  // NOTE: If ever supporting non-integral storage types, some other scheme
+  // for determining the width will be needed.
+  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+}
+
+Type QuantizedType::getExpressedType() const {
+  return static_cast<ImplType *>(impl)->expressedType;
+}
+
+bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
+  if (candidateExpressedType.isa<ShapedType>()) {
+    return candidateExpressedType.cast<ShapedType>().getElementType() ==
+           getExpressedType();
+  }
+  return candidateExpressedType == getExpressedType();
+}
+
+QuantizedType
+QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
+  if (primitiveOrContainerType.isa<ShapedType>()) {
+    Type elementType =
+        primitiveOrContainerType.cast<ShapedType>().getElementType();
+    return elementType.dyn_cast<QuantizedType>();
+  }
+  return primitiveOrContainerType.dyn_cast<QuantizedType>();
+}
+
+Type QuantizedType::castFromStorageType(Type candidateType) {
+  if (candidateType == getStorageType()) {
+    // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
+    return *this;
+  } else if (candidateType.isa<RankedTensorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    return RankedTensorType::get(
+        candidateType.cast<RankedTensorType>().getShape(), getStorageType());
+  } else if (candidateType.isa<UnrankedTensorType>()) {
+    // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
+    return UnrankedTensorType::get(getStorageType());
+  } else if (candidateType.isa<VectorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    return VectorType::get(candidateType.cast<VectorType>().getShape(),
+                           getStorageType());
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castToStorageType(Type quantizedType) {
+  if (quantizedType.isa<QuantizedType>()) {
+    // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
+    return quantizedType.cast<QuantizedType>().getStorageType();
+  } else if (quantizedType.isa<ShapedType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    ShapedType sType = quantizedType.cast<ShapedType>();
+    if (!sType.getElementType().isa<QuantizedType>()) {
+      return nullptr;
+    }
+    Type storageType =
+        sType.getElementType().cast<QuantizedType>().getStorageType();
+    if (quantizedType.isa<RankedTensorType>()) {
+      return RankedTensorType::get(sType.getShape(), storageType);
+    } else if (quantizedType.isa<UnrankedTensorType>()) {
+      return UnrankedTensorType::get(storageType);
+    } else if (quantizedType.isa<VectorType>()) {
+      return VectorType::get(sType.getShape(), storageType);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castFromExpressedType(Type candidateType) {
+  if (candidateType == getExpressedType()) {
+    // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
+    return *this;
+  } else if (candidateType.isa<ShapedType>()) {
+    ShapedType candidateShapedType = candidateType.cast<ShapedType>();
+    if (candidateShapedType.getElementType() != getExpressedType()) {
+      return nullptr;
+    }
+
+    if (candidateType.isa<RankedTensorType>()) {
+      // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+      return RankedTensorType::get(candidateShapedType.getShape(), *this);
+    } else if (candidateType.isa<UnrankedTensorType>()) {
+      // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
+      return UnrankedTensorType::get(*this);
+    } else if (candidateType.isa<VectorType>()) {
+      // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+      return VectorType::get(candidateShapedType.getShape(), *this);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castToExpressedType(Type quantizedType) {
+  if (quantizedType.isa<QuantizedType>()) {
+    // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
+    return quantizedType.cast<QuantizedType>().getExpressedType();
+  } else if (quantizedType.isa<ShapedType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    ShapedType sType = quantizedType.cast<ShapedType>();
+    if (!sType.getElementType().isa<QuantizedType>()) {
+      return nullptr;
+    }
+    Type expressedType =
+        sType.getElementType().cast<QuantizedType>().getExpressedType();
+    if (quantizedType.isa<RankedTensorType>()) {
+      return RankedTensorType::get(sType.getShape(), expressedType);
+    } else if (quantizedType.isa<UnrankedTensorType>()) {
+      return UnrankedTensorType::get(expressedType);
+    } else if (quantizedType.isa<VectorType>()) {
+      return VectorType::get(sType.getShape(), expressedType);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castExpressedToStorageType(Type candidateType) {
+  Type expressedQuantizedType = castFromExpressedType(candidateType);
+  if (!expressedQuantizedType) {
+    return nullptr;
+  }
+  return QuantizedType::castToStorageType(expressedQuantizedType);
+}
+
+AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
+                                       Type expressedType,
+                                       int64_t storageTypeMin,
+                                       int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
+                   storageType, expressedType, storageTypeMin, storageTypeMax);
+}
+
+AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
+                                              Type expressedType,
+                                              int64_t storageTypeMin,
+                                              int64_t storageTypeMax,
+                                              Location location) {
+  return Base::getChecked(location, storageType.getContext(),
+                          QuantizationTypes::Any, flags, storageType,
+                          expressedType, storageTypeMin, storageTypeMax);
+}
+
+LogicalResult AnyQuantizedType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyConstructionInvariants(
+          loc, context, flags, storageType, expressedType, storageTypeMin,
+          storageTypeMax))) {
+    return failure();
+  }
+
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (expressedType && !expressedType.isa<FloatType>()) {
+    if (loc) {
+      emitError(*loc, "expressed type must be floating point");
+    }
+    return failure();
+  }
+
+  return success();
+}
+
+UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
+                                               Type expressedType, double scale,
+                                               int64_t zeroPoint,
+                                               int64_t storageTypeMin,
+                                               int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(),
+                   QuantizationTypes::UniformQuantized, flags, storageType,
+                   expressedType, scale, zeroPoint, storageTypeMin,
+                   storageTypeMax);
+}
+
+UniformQuantizedType
+UniformQuantizedType::getChecked(unsigned flags, Type storageType,
+                                 Type expressedType, double scale,
+                                 int64_t zeroPoint, int64_t storageTypeMin,
+                                 int64_t storageTypeMax, Location location) {
+  return Base::getChecked(location, storageType.getContext(),
+                          QuantizationTypes::UniformQuantized, flags,
+                          storageType, expressedType, scale, zeroPoint,
+                          storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyConstructionInvariants(
+          loc, context, flags, storageType, expressedType, storageTypeMin,
+          storageTypeMax))) {
+    return failure();
+  }
+
+  // Uniform quantization requires fully expressed parameters, including
+  // expressed type.
+  if (!expressedType) {
+    if (loc) {
+      emitError(*loc, "uniform quantization requires expressed type");
+    }
+    return failure();
+  }
+
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!expressedType.isa<FloatType>()) {
+    if (loc) {
+      emitError(*loc, "expressed type must be floating point");
+    }
+    return failure();
+  }
+
+  // Verify scale.
+  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
+    if (loc) {
+      emitError(*loc) << "illegal scale: " << scale;
+    }
+    return failure();
+  }
+
+  return success();
+}
+
+double UniformQuantizedType::getScale() const { return getImpl()->scale; }
+
+int64_t UniformQuantizedType::getZeroPoint() const {
+  return getImpl()->zeroPoint;
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
+    unsigned flags, Type storageType, Type expressedType,
+    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(),
+                   QuantizationTypes::UniformQuantizedPerAxis, flags,
+                   storageType, expressedType, scales, zeroPoints,
+                   quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
+    unsigned flags, Type storageType, Type expressedType,
+    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
+    Location location) {
+  return Base::getChecked(location, storageType.getContext(),
+                          QuantizationTypes::UniformQuantizedPerAxis, flags,
+                          storageType, expressedType, scales, zeroPoints,
+                          quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, ArrayRef<double> scales,
+    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyConstructionInvariants(
+          loc, context, flags, storageType, expressedType, storageTypeMin,
+          storageTypeMax))) {
+    return failure();
+  }
+
+  // Uniform quantization requires fully expressed parameters, including
+  // expressed type.
+  if (!expressedType) {
+    if (loc) {
+      emitError(*loc, "uniform quantization requires expressed type");
+    }
+    return failure();
+  }
+
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!expressedType.isa<FloatType>()) {
+    if (loc) {
+      emitError(*loc, "expressed type must be floating point");
+    }
+    return failure();
+  }
+
+  // Ensure that the number of scales and zeroPoints match.
+  if (scales.size() != zeroPoints.size()) {
+    if (loc) {
+      emitError(*loc, "illegal number of scales and zeroPoints: ")
+          << scales.size() << ", " << zeroPoints.size();
+    }
+    return failure();
+  }
+
+  // Verify scale.
+  for (double scale : scales) {
+    if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
+      if (loc) {
+        emitError(*loc) << "illegal scale: " << scale;
+      }
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
+  return getImpl()->getScales();
+}
+
+ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
+  return getImpl()->getZeroPoints();
+}
+
+int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
+  return getImpl()->quantizedDimension;
+}
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
new file mode 100644
index 0000000..4949b12
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
@@ -0,0 +1,269 @@
+//===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TYPE_DETAIL_H_
+#define TYPE_DETAIL_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/bit.h"
+
+namespace mlir {
+namespace quant {
+namespace detail {
+
+struct QuantizedTypeStorage : public mlir::TypeStorage {
+  QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
+                       int64_t storageTypeMin, int64_t storageTypeMax)
+      : flags(flags), storageType(storageType), expressedType(expressedType),
+        storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+
+  /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+  unsigned flags;
+
+  // Integral type for the storage point representation.
+  Type storageType;
+
+  // Floating point type that the quantized type approximates.
+  Type expressedType;
+
+  // The minimum value storageType can take.
+  int64_t storageTypeMin;
+
+  // The maximum value storageType can take.
+  int64_t storageTypeMax;
+};
+
+struct AnyQuantizedTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType,
+          int64_t storageTypeMin, int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+    unsigned flags;
+    Type storageType;
+    Type expressedType;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      return llvm::hash_combine(flags, storageType, expressedType,
+                                storageTypeMin, storageTypeMax);
+    }
+  };
+
+  AnyQuantizedTypeStorage(const KeyTy &key)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    return new (allocator.allocate<AnyQuantizedTypeStorage>())
+        AnyQuantizedTypeStorage(key);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+};
+
+struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
+          int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
+          storageTypeMax(storageTypeMax) {}
+    /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+    unsigned flags;
+
+    // Integral type for the storage point representation.
+    Type storageType;
+
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    double scale;
+    int64_t zeroPoint;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
+             lhs.zeroPoint == rhs.zeroPoint &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      int64_t scaleBits = llvm::bit_cast<int64_t>(scale);
+      return llvm::hash_combine(flags, storageType, expressedType, scaleBits,
+                                zeroPoint, storageTypeMin, storageTypeMax);
+    }
+  };
+
+  UniformQuantizedTypeStorage(const KeyTy &key)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax),
+        scale(key.scale), zeroPoint(key.zeroPoint) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
+                                                const KeyTy &key) {
+    return new (allocator.allocate<UniformQuantizedTypeStorage>())
+        UniformQuantizedTypeStorage(key);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  double scale;
+  int64_t zeroPoint;
+};
+
+struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType,
+          ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+          int32_t quantizedDimension, int64_t storageTypeMin,
+          int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          scales(scales), zeroPoints(zeroPoints),
+          quantizedDimension(quantizedDimension),
+          storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+    /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+    unsigned flags;
+
+    // Integral type for the storage point representation.
+    Type storageType;
+
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    ArrayRef<double> scales;
+    ArrayRef<int64_t> zeroPoints;
+    int32_t quantizedDimension;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    ArrayRef<double> getScales() const { return scales; }
+
+    ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType &&
+             lhs.getScales() == rhs.getScales() &&
+             lhs.getZeroPoints() == rhs.getZeroPoints() &&
+             lhs.quantizedDimension == rhs.quantizedDimension &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data());
+      ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
+      return llvm::hash_combine(
+          flags, storageType, expressedType,
+          llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()),
+          llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()),
+          storageTypeMin, storageTypeMax);
+    }
+  };
+
+  // We pass scales and zeroPoints in directly rather than relying on KeyTy
+  // because we have to create new reallocated versions in `constrcut` below.
+  UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
+                                     ArrayRef<int64_t> zeroPoints)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax),
+        scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
+        quantParamsSize(scales.size()),
+        quantizedDimension(key.quantizedDimension) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static UniformQuantizedPerAxisTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    ArrayRef<double> scales = allocator.copyInto(key.scales);
+    ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints);
+    return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
+        UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  ArrayRef<double> getScales() const {
+    return ArrayRef<double>(scaleElements, quantParamsSize);
+  }
+
+  ArrayRef<int64_t> getZeroPoints() const {
+    return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
+  }
+
+  const double *scaleElements;
+  const int64_t *zeroPointElements;
+  unsigned quantParamsSize;
+  int32_t quantizedDimension;
+};
+
+} // namespace detail
+} // namespace quant
+} // namespace mlir
+
+#endif // TYPE_DETAIL_H_
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
new file mode 100644
index 0000000..b3fbad8
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -0,0 +1,744 @@
+//===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace quant {
+
+/// Print a floating point value in a way that the parser will be able to
+/// round-trip losslessly.
+static void printStabilizedFloat(const APFloat &apValue, raw_ostream &os) {
+  // We would like to output the FP constant value in exponential notation,
+  // but we cannot do this if doing so will lose precision.  Check here to
+  // make sure that we only output it in exponential format if we can parse
+  // the value back and get the same value.
+  bool isInf = apValue.isInfinity();
+  bool isNaN = apValue.isNaN();
+  if (!isInf && !isNaN) {
+    SmallString<128> strValue;
+    apValue.toString(strValue, 6, 0, false);
+
+    // Check to make sure that the stringized number is not some string like
+    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
+    // that the string matches the "[-+]?[0-9]" regex.
+    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
+            ((strValue[0] == '-' || strValue[0] == '+') &&
+             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
+           "[-+]?[0-9] regex does not match!");
+    // Reparse stringized version!
+    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
+      os << strValue;
+      return;
+    }
+  }
+
+  SmallVector<char, 16> str;
+  apValue.toString(str);
+  os << str;
+}
+
+namespace {
+
+enum class TokenKind {
+  error,
+  eof,
+  l_brace,
+  r_brace,
+  l_angle,
+  r_angle,
+  colon,
+  comma,
+  alpha_ident,
+  integer_literal,
+  float_literal,
+};
+
+struct Token {
+  TokenKind kind;
+  StringRef spelling;
+};
+
+class Lexer {
+public:
+  Lexer(StringRef source) : curBuffer(source), curPtr(curBuffer.begin()) {}
+
+  Token lexToken();
+
+private:
+  Token formToken(TokenKind kind, const char *tokStart) {
+    return Token{kind, StringRef(tokStart, curPtr - tokStart)};
+  }
+
+  Token emitError(const char *loc, const Twine &message) {
+    return formToken(TokenKind::error, loc);
+  }
+
+  bool isEnd() const { return curPtr == curBuffer.end(); }
+
+  // Lexer implementation methods
+  Token lexalpha_ident(const char *tokStart);
+  Token lexNumber(const char *tokStart);
+
+  StringRef curBuffer;
+  const char *curPtr;
+};
+
+} // namespace
+
+Token Lexer::lexToken() {
+  // Ignore whitespace.
+  while (!isEnd()) {
+    switch (*curPtr) {
+    case ' ':
+    case '\t':
+    case '\n':
+    case '\r':
+      ++curPtr;
+      continue;
+    default:
+      break;
+    }
+    break;
+  }
+
+  if (isEnd()) {
+    return Token{TokenKind::eof, ""};
+  }
+
+  const char *tokStart = curPtr;
+  switch (*curPtr++) {
+  default:
+    if (isalpha(*tokStart)) {
+      return lexalpha_ident(tokStart);
+    }
+    if (isdigit(*tokStart)) {
+      return lexNumber(tokStart);
+    }
+
+    return emitError(tokStart, "unexpected character");
+
+  case '<':
+    return formToken(TokenKind::l_angle, tokStart);
+  case '>':
+    return formToken(TokenKind::r_angle, tokStart);
+  case '{':
+    return formToken(TokenKind::l_brace, tokStart);
+  case '}':
+    return formToken(TokenKind::r_brace, tokStart);
+  case ':':
+    return formToken(TokenKind::colon, tokStart);
+  case ',':
+    return formToken(TokenKind::comma, tokStart);
+  case '-':
+    return lexNumber(tokStart);
+  case '+':
+    return lexNumber(tokStart);
+  }
+}
+
+/// Lex a bare alpha identifier. Since this DSL often contains identifiers with
+/// trailing numeric components, this only matches alphas. It is up to the
+/// parser to handle identifiers that can be mixed alphanum.
+///
+///   alpha-ident ::= (letter)(letter)*
+Token Lexer::lexalpha_ident(const char *tokStart) {
+  while (!isEnd() && isalpha(*curPtr)) {
+    ++curPtr;
+  }
+  return formToken(TokenKind::alpha_ident, tokStart);
+}
+
+/// Lex a number.
+///
+///   integer-literal ::= [-+]?digit+
+///   float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+Token Lexer::lexNumber(const char *tokStart) {
+  // Leading '+', '-' or digit has already been consumed.
+  while (!isEnd() && isdigit(*curPtr)) {
+    ++curPtr;
+  }
+  // If not a decimal point, treat as integer.
+  if (isEnd() || *curPtr != '.') {
+    return formToken(TokenKind::integer_literal, tokStart);
+  }
+  ++curPtr;
+
+  // Skip over [0-9]*([eE][-+]?[0-9]+)?
+  // Leading digits.
+  while (!isEnd() && isdigit(*curPtr)) {
+    ++curPtr;
+  }
+
+  // [eE][-+]?[0-9]+
+  if (!isEnd() && (*curPtr == 'e' || *curPtr == 'E')) {
+    auto remaining = curBuffer.end() - curPtr;
+    if (remaining > 2 && isdigit(curPtr[1])) {
+      // Lookahead 2 for digit.
+      curPtr += 2;
+      while (!isEnd() && isdigit(*curPtr)) {
+        ++curPtr;
+      }
+    } else if (remaining > 3 && (curPtr[1] == '-' || curPtr[1] == '+') &&
+               isdigit(curPtr[2])) {
+      // Lookahead 3 for [+-] digit.
+      curPtr += 3;
+      while (!isEnd() && isdigit(*curPtr)) {
+        ++curPtr;
+      }
+    }
+  }
+  return formToken(TokenKind::float_literal, tokStart);
+} // end namespace
+
+// --- TypeParser ---
+namespace {
+
+class TypeParser {
+public:
+  TypeParser(StringRef source, MLIRContext *context, Location location)
+      : context(context), location(location), lexer(source),
+        curToken(lexer.lexToken()) {}
+
+  /// Attempts to parse the source as a type, returning the unknown
+  /// type on error.
+  Type parseType();
+
+private:
+  /// Unconditionally consumes the current token.
+  void consumeToken() {
+    assert(curToken.kind != TokenKind::eof &&
+           "should not advance past EOF or errors");
+    curToken = lexer.lexToken();
+  }
+
+  /// Unconditionally consumes the current token, asserting that it is of the
+  /// specified kind.
+  void consumeToken(TokenKind kind) {
+    assert(curToken.kind == kind && "consumed an unexpected token");
+    consumeToken();
+  }
+
+  /// Conditionally consumes a token if of the specified kind.
+  /// Returns true if consumed.
+  bool consumeIf(TokenKind kind) {
+    if (curToken.kind == kind) {
+      consumeToken();
+      return true;
+    }
+    return false;
+  }
+
+  /// Emits an error at the current location with a message.
+  void emitError(const Twine &message) {
+    // TODO: All errors show up at the beginning of the extended type location.
+    // Figure out how to make this location relative to where the error occurred
+    // in this instance.
+    mlir::emitError(location, message);
+  }
+
+  // Parsers.
+  Type parseAnyType();
+  Type parseUniformType();
+  IntegerType parseStorageType(bool &isSigned);
+  bool parseStorageRange(IntegerType storageType, bool isSigned,
+                         int64_t &storageTypeMin, int64_t &storageTypeMax);
+  FloatType parseExpressedType();
+  bool parseQuantParams(double &scale, int64_t &zeroPoint);
+
+  MLIRContext *context;
+  Location location;
+  Lexer lexer;
+
+  // The next token that has not yet been consumed.
+  Token curToken;
+};
+
+} // namespace
+
+Type TypeParser::parseType() {
+  // All types start with an identifier that we switch on.
+  if (curToken.kind == TokenKind::alpha_ident) {
+    StringRef typeNameSpelling = curToken.spelling;
+    consumeToken();
+
+    Type result;
+    if (typeNameSpelling == "uniform") {
+      result = parseUniformType();
+      if (!result) {
+        return nullptr;
+      }
+    } else if (typeNameSpelling == "any") {
+      result = parseAnyType();
+      if (!result) {
+        return nullptr;
+      }
+    } else {
+      return (emitError("unknown quantized type " + typeNameSpelling), nullptr);
+    }
+
+    // Make sure the entire input was consumed.
+    if (curToken.kind != TokenKind::eof) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+
+    return result;
+  } else {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+}
+
+/// Parses a UniformQuantizedType.
+///
+///   uniform_per_layer ::= `any<` storage-spec (expressed-type-spec)?`>`
+///   storage-spec ::= storage-type (`<` storage-range `>`)?
+///   storage-range ::= integer-literal `:` integer-literal
+///   storage-type ::= (`i` | `u`) integer-literal
+///   expressed-type-spec ::= `:` `f` integer-literal
+Type TypeParser::parseAnyType() {
+  IntegerType storageType;
+  FloatType expressedType;
+  unsigned typeFlags = 0;
+  int64_t storageTypeMin;
+  int64_t storageTypeMax;
+
+  // Type specification.
+  if (!consumeIf(TokenKind::l_angle)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Storage type.
+  bool isSigned = false;
+  storageType = parseStorageType(isSigned);
+  if (!storageType) {
+    return nullptr;
+  }
+  if (isSigned) {
+    typeFlags |= QuantizationFlags::Signed;
+  }
+
+  // Storage type range.
+  if (parseStorageRange(storageType, isSigned, storageTypeMin,
+                        storageTypeMax)) {
+    return nullptr;
+  }
+
+  // Optional expressed type.
+  if (consumeIf(TokenKind::colon)) {
+    expressedType = parseExpressedType();
+    if (!expressedType) {
+      return nullptr;
+    }
+  }
+
+  if (!consumeIf(TokenKind::r_angle)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
+                                      storageTypeMin, storageTypeMax, location);
+}
+
+/// Parses a UniformQuantizedType.
+///
+///   uniform_type ::= uniform_per_layer
+///                  | uniform_per_axis
+///   uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
+///                          `,` scale-zero `>`
+///   uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
+///                        axis-spec `,` scale-zero-list `>`
+///   storage-spec ::= storage-type (`<` storage-range `>`)?
+///   storage-range ::= integer-literal `:` integer-literal
+///   storage-type ::= (`i` | `u`) integer-literal
+///   expressed-type-spec ::= `:` `f` integer-literal
+///   axis-spec ::= `:` integer-literal
+///   scale-zero ::= float-literal `:` integer-literal
+///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
+Type TypeParser::parseUniformType() {
+  IntegerType storageType;
+  FloatType expressedType;
+  unsigned typeFlags = 0;
+  int64_t storageTypeMin;
+  int64_t storageTypeMax;
+  bool isPerAxis = false;
+  int32_t quantizedDimension;
+  SmallVector<double, 1> scales;
+  SmallVector<int64_t, 1> zeroPoints;
+
+  // Type specification.
+  if (!consumeIf(TokenKind::l_angle)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Storage type.
+  bool isSigned = false;
+  storageType = parseStorageType(isSigned);
+  if (!storageType) {
+    return nullptr;
+  }
+  if (isSigned) {
+    typeFlags |= QuantizationFlags::Signed;
+  }
+
+  // Storage type range.
+  if (parseStorageRange(storageType, isSigned, storageTypeMin,
+                        storageTypeMax)) {
+    return nullptr;
+  }
+
+  // Expressed type.
+  if (!consumeIf(TokenKind::colon)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+  expressedType = parseExpressedType();
+  if (!expressedType) {
+    return nullptr;
+  }
+
+  // Optionally parse quantized dimension for per-axis quantization.
+  if (consumeIf(TokenKind::colon)) {
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected quantized dimension"), nullptr);
+    }
+    if (curToken.spelling.getAsInteger(10, quantizedDimension)) {
+      return (emitError("illegal quantized dimension: " + curToken.spelling),
+              nullptr);
+    }
+    consumeToken(TokenKind::integer_literal);
+    isPerAxis = true;
+  }
+
+  // Comma leading into range_spec.
+  if (!consumeIf(TokenKind::comma)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Parameter specification.
+  // For per-axis, ranges are in a {} delimitted list.
+  if (isPerAxis) {
+    if (!consumeIf(TokenKind::l_brace)) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+  }
+
+  // Parse scales/zeroPoints.
+  do {
+    scales.resize(scales.size() + 1);
+    zeroPoints.resize(zeroPoints.size() + 1);
+    if (parseQuantParams(scales.back(), zeroPoints.back())) {
+      return nullptr;
+    }
+  } while (isPerAxis && consumeIf(TokenKind::comma));
+
+  if (isPerAxis) {
+    if (!consumeIf(TokenKind::r_brace)) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+  }
+
+  if (!consumeIf(TokenKind::r_angle)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  if (!isPerAxis && scales.size() > 1) {
+    return (emitError("multiple scales/zeroPoints provided, but "
+                      "quantizedDimension wasn't specified"),
+            nullptr);
+  }
+
+  if (isPerAxis) {
+    ArrayRef<double> scalesRef(scales.begin(), scales.end());
+    ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
+    return UniformQuantizedPerAxisType::getChecked(
+        typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+        quantizedDimension, storageTypeMin, storageTypeMax, location);
+  }
+
+  return UniformQuantizedType::getChecked(
+      typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
+      storageTypeMin, storageTypeMax, location);
+}
+
+IntegerType TypeParser::parseStorageType(bool &isSigned) {
+  // Parse storage type (alpha_ident, integer_literal).
+  StringRef storageTypePrefix = curToken.spelling;
+  unsigned storageTypeWidth;
+  if (curToken.kind != TokenKind::alpha_ident) {
+    return (emitError("expected storage type prefix"), nullptr);
+  }
+  consumeToken();
+  if (curToken.kind != TokenKind::integer_literal) {
+    return (emitError("expected storage type width"), nullptr);
+  }
+  if (curToken.spelling.getAsInteger(10, storageTypeWidth) ||
+      storageTypeWidth == 0 ||
+      storageTypeWidth > QuantizedType::MaxStorageBits) {
+    return (emitError("illegal storage type size: " + Twine(curToken.spelling)),
+            nullptr);
+  }
+  consumeToken();
+
+  if (storageTypePrefix == "i") {
+    isSigned = true;
+    return IntegerType::get(storageTypeWidth, context);
+  } else if (storageTypePrefix == "u") {
+    isSigned = false;
+    return IntegerType::get(storageTypeWidth, context);
+  } else {
+    return (
+        emitError("illegal storage type prefix: " + Twine(storageTypePrefix)),
+        nullptr);
+  }
+}
+
+bool TypeParser::parseStorageRange(IntegerType storageType, bool isSigned,
+                                   int64_t &storageTypeMin,
+                                   int64_t &storageTypeMax) {
+
+  int64_t defaultIntegerMin = QuantizedType::getDefaultMininumForInteger(
+      isSigned, storageType.getWidth());
+  int64_t defaultIntegerMax = QuantizedType::getDefaultMaxinumForInteger(
+      isSigned, storageType.getWidth());
+  if (consumeIf(TokenKind::l_angle)) {
+    // Explicit storage min and storage max.
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected storage type minimum"), true);
+    }
+    if (curToken.spelling.getAsInteger(10, storageTypeMin) ||
+        storageTypeMin < defaultIntegerMin) {
+      return (emitError("illegal storage type minimum: " + curToken.spelling),
+              true);
+    }
+    consumeToken(TokenKind::integer_literal);
+
+    if (!consumeIf(TokenKind::colon)) {
+      return (emitError("unrecognized token: " + curToken.spelling), true);
+    }
+
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected storage type maximum"), true);
+    }
+    if (curToken.spelling.getAsInteger(10, storageTypeMax) ||
+        storageTypeMax > defaultIntegerMax) {
+      return (emitError("illegal storage type maximum: " + curToken.spelling),
+              true);
+    }
+    consumeToken(TokenKind::integer_literal);
+
+    if (!consumeIf(TokenKind::r_angle)) {
+      return (emitError("unrecognized token: " + curToken.spelling), true);
+    }
+  } else {
+    storageTypeMin = defaultIntegerMin;
+    storageTypeMax = defaultIntegerMax;
+  }
+
+  return false;
+}
+
+FloatType TypeParser::parseExpressedType() {
+  // Expect an alpha_ident followed by integer literal that we concat back
+  // together.
+  StringRef prefix = curToken.spelling;
+  if (!consumeIf(TokenKind::alpha_ident)) {
+    return (emitError("expected expressed type"), nullptr);
+  }
+  StringRef suffix = curToken.spelling;
+  if (!consumeIf(TokenKind::integer_literal)) {
+    return (emitError("expected expressed type"), nullptr);
+  }
+
+  SmallVector<char, 4> holder;
+  StringRef typeName = (Twine(prefix) + Twine(suffix)).toStringRef(holder);
+  if (typeName == "f32")
+    return FloatType::getF32(context);
+  if (typeName == "f16")
+    return FloatType::getF16(context);
+  if (typeName == "bf16")
+    return FloatType::getBF16(context);
+  if (typeName == "f64")
+    return FloatType::getF64(context);
+
+  return (emitError("unrecognized expressed type: " + typeName), nullptr);
+}
+
+bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
+  // scale[:zeroPoint]?
+  // scale.
+  StringRef scaleSpelling = curToken.spelling;
+  if (!consumeIf(TokenKind::float_literal) ||
+      scaleSpelling.getAsDouble(scale)) {
+    return (
+        emitError("expected valid uniform scale. got: " + Twine(scaleSpelling)),
+        true);
+  }
+
+  // zero point.
+  zeroPoint = 0;
+  if (!consumeIf(TokenKind::colon)) {
+    // Default zero point.
+    return false;
+  }
+  StringRef zeroPointSpelling = curToken.spelling;
+  if (!consumeIf(TokenKind::integer_literal) ||
+      zeroPointSpelling.getAsInteger(10, zeroPoint)) {
+    return (emitError("expected integer uniform zero point. got: " +
+                      Twine(zeroPointSpelling)),
+            true);
+  }
+
+  return false;
+}
+
+/// Parse a type registered to this dialect.
+Type QuantizationDialect::parseType(StringRef spec, Location loc) const {
+  TypeParser parser(spec, getContext(), loc);
+  Type parsedType = parser.parseType();
+  if (parsedType == nullptr) {
+    // Error.
+    // TODO(laurenzo): Do something?
+    return parsedType;
+  }
+
+  return parsedType;
+}
+
+static void printStorageType(QuantizedType type, raw_ostream &out) {
+  // storage type
+  unsigned storageWidth = type.getStorageTypeIntegralWidth();
+  bool isSigned = type.isSigned();
+  if (isSigned) {
+    out << "i" << storageWidth;
+  } else {
+    out << "u" << storageWidth;
+  }
+
+  // storageTypeMin and storageTypeMax if not default.
+  int64_t defaultIntegerMin =
+      QuantizedType::getDefaultMininumForInteger(isSigned, storageWidth);
+  int64_t defaultIntegerMax =
+      QuantizedType::getDefaultMaxinumForInteger(isSigned, storageWidth);
+  if (defaultIntegerMin != type.getStorageTypeMin() ||
+      defaultIntegerMax != type.getStorageTypeMax()) {
+    out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
+        << ">";
+  }
+}
+
+static void printExpressedType(QuantizedType type, raw_ostream &out) {
+  // repr type
+  Type expressedType = type.getExpressedType();
+  if (expressedType.isF32()) {
+    out << "f32";
+  } else if (expressedType.isF64()) {
+    out << "f64";
+  } else if (expressedType.isF16()) {
+    out << "f16";
+  } else if (expressedType.isBF16()) {
+    out << "bf16";
+  } else {
+    out << "unknown";
+  }
+}
+
+static void printQuantParams(double scale, int64_t zeroPoint,
+                             raw_ostream &out) {
+  printStabilizedFloat(APFloat(scale), out);
+  if (zeroPoint != 0) {
+    out << ":" << zeroPoint;
+  }
+}
+
+/// Helper that prints a UniformQuantizedType.
+static void printAnyQuantizedType(AnyQuantizedType type, raw_ostream &out) {
+  out << "any<";
+  printStorageType(type, out);
+  if (type.getExpressedType()) {
+    out << ":";
+    printExpressedType(type, out);
+  }
+  out << ">";
+}
+
+/// Helper that prints a UniformQuantizedType.
+static void printUniformQuantizedType(UniformQuantizedType type,
+                                      raw_ostream &out) {
+  out << "uniform<";
+  printStorageType(type, out);
+  out << ":";
+  printExpressedType(type, out);
+  out << ", ";
+
+  // scheme specific parameters
+  printQuantParams(type.getScale(), type.getZeroPoint(), out);
+  out << ">";
+}
+
+/// Helper that prints a UniformQuantizedPerAxisType.
+static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
+                                             raw_ostream &out) {
+  out << "uniform<";
+  printStorageType(type, out);
+  out << ":";
+  printExpressedType(type, out);
+  out << ":";
+  out << type.getQuantizedDimension();
+  out << ", ";
+
+  // scheme specific parameters
+  ArrayRef<double> scales = type.getScales();
+  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
+  out << "{";
+  for (unsigned i = 0; i < scales.size(); ++i) {
+    printQuantParams(scales[i], zeroPoints[i], out);
+    if (i != scales.size() - 1) {
+      out << ",";
+    }
+  }
+  out << "}>";
+}
+
+/// Print a type registered to this dialect.
+void QuantizationDialect::printType(Type type, raw_ostream &os) const {
+  switch (type.getKind()) {
+  default:
+    llvm_unreachable("Unhandled quantized type");
+  case QuantizationTypes::Any:
+    printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
+    break;
+  case QuantizationTypes::UniformQuantized:
+    printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
+    break;
+  case QuantizationTypes::UniformQuantizedPerAxis:
+    printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
+                                     os);
+    break;
+  }
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
new file mode 100644
index 0000000..120d0cf
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -0,0 +1,121 @@
+//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/Passes.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertConstPass : public FunctionPass<ConvertConstPass> {
+public:
+  void runOnFunction() override;
+};
+
+struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
+  using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
+                                     PatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
+/// quantized and the operand type is quantizable.
+
+PatternMatchResult
+QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
+                                       PatternRewriter &rewriter) const {
+  Attribute value;
+
+  // Is the operand a constant?
+  if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
+    return matchFailure();
+  }
+
+  // Does the qbarrier convert to a quantized type. This will not be true
+  // if a quantized type has not yet been chosen or if the cast to an equivalent
+  // storage type is not supported.
+  Type qbarrierResultType = qbarrier.getResult()->getType();
+  QuantizedType quantizedElementType =
+      QuantizedType::getQuantizedElementType(qbarrierResultType);
+  if (!quantizedElementType) {
+    return matchFailure();
+  }
+  if (!QuantizedType::castToStorageType(qbarrierResultType)) {
+    return matchFailure();
+  }
+
+  // Is the operand type compatible with the expressed type of the quantized
+  // type? This will not be true if the qbarrier is superfluous (converts
+  // from and to a quantized type).
+  if (!quantizedElementType.isCompatibleExpressedType(
+          qbarrier.arg()->getType())) {
+    return matchFailure();
+  }
+
+  // Is the constant value a type expressed in a way that we support?
+  if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
+      !value.isa<SparseElementsAttr>()) {
+    return matchFailure();
+  }
+
+  Type newConstValueType;
+  auto newConstValue =
+      quantizeAttr(value, quantizedElementType, newConstValueType);
+  if (!newConstValue) {
+    return matchFailure();
+  }
+
+  // When creating the new const op, use a fused location that combines the
+  // original const and the qbarrier that led to the quantization.
+  auto fusedLoc = FusedLoc::get(
+      {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+      rewriter.getContext());
+  auto newConstOp =
+      rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
+  rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
+                                             qbarrier.getType(), newConstOp);
+  return matchSuccess();
+}
+
+void ConvertConstPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto func = getFunction();
+  auto *context = &getContext();
+  patterns.insert<QuantizedConstRewrite>(context);
+  applyPatternsGreedily(func, patterns);
+}
+
+FunctionPassBase *mlir::quant::createConvertConstPass() {
+  return new ConvertConstPass();
+}
+
+static PassRegistration<ConvertConstPass>
+    pass("quant-convert-const",
+         "Converts constants followed by qbarrier to actual quantized values");
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
new file mode 100644
index 0000000..dfdce89
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -0,0 +1,113 @@
+//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
+#include "mlir/Dialect/QuantOps/Passes.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertSimulatedQuantPass
+    : public FunctionPass<ConvertSimulatedQuantPass> {
+public:
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
+class ConstFakeQuantRewrite : public RewritePattern {
+public:
+  bool *hadFailure;
+
+  ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
+      : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
+        hadFailure(hadFailure) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    // TODO: If this pattern comes up more frequently, consider adding core
+    // support for failable rewrites.
+    if (failableRewrite(op, rewriter)) {
+      *hadFailure = true;
+      return matchFailure();
+    }
+
+    return matchSuccess();
+  }
+
+  bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
+    auto fqOp = cast<ConstFakeQuant>(op);
+
+    auto converter =
+        ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+    if (!converter) {
+      return (op->emitError("unsupported quantized type conversion"), true);
+    }
+
+    UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
+        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+        fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
+
+    if (!uniformElementType) {
+      // Note that the fakeQuantAttrsToType will have emitted the error.
+      return true;
+    }
+
+    Type quantizedType = converter.convert(uniformElementType);
+    assert(quantizedType &&
+           "Converter accepted a type that it did not convert");
+
+    // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
+    // this is a forced/hard-coded constraint.
+    auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
+                                                    fqOp.inputs());
+    rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
+                                                  qbarrier.getResult());
+
+    return false;
+  }
+};
+
+void ConvertSimulatedQuantPass::runOnFunction() {
+  bool hadFailure = false;
+  OwningRewritePatternList patterns;
+  auto func = getFunction();
+  auto *context = &getContext();
+  patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
+  applyPatternsGreedily(func, patterns);
+  if (hadFailure)
+    signalPassFailure();
+}
+
+FunctionPassBase *mlir::quant::createConvertSimulatedQuantPass() {
+  return new ConvertSimulatedQuantPass();
+}
+
+static PassRegistration<ConvertSimulatedQuantPass>
+    pass("quant-convert-simulated-quantization",
+         "Converts training-time simulated quantization ops to corresponding "
+         "quantize/dequantize casts.");
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
new file mode 100644
index 0000000..2667da9
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -0,0 +1,120 @@
+//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+UniformQuantizedType
+mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
+                                  double rmax, bool narrowRange,
+                                  Type expressedType, bool isSigned) {
+  MLIRContext *ctx = expressedType.getContext();
+  Type storageType;
+  unsigned flags;
+  int64_t qmin;
+  int64_t qmax;
+
+  // Hard-coded type mapping from TFLite.
+  if (numBits <= 8) {
+    storageType = IntegerType::get(8, ctx);
+    if (isSigned) {
+      flags = QuantizationFlags::Signed;
+      qmin = -128;
+      qmax = 127;
+    } else {
+      flags = 0;
+      qmin = 0;
+      qmax = 255;
+    }
+  } else if (numBits <= 16) {
+    storageType = IntegerType::get(16, ctx);
+    if (isSigned) {
+      flags = QuantizationFlags::Signed;
+      qmin = -32768;
+      qmax = 32767;
+    } else {
+      flags = 0;
+      qmin = 0;
+      qmax = 65535;
+    }
+  } else {
+    emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
+    return nullptr;
+  }
+
+  // Handle narrowRange.
+  if (narrowRange) {
+    qmin += 1;
+  }
+
+  // Range must straddle zero.
+  if (rmin > 0.0 || rmax < 0.0) {
+    return (emitError(loc, "FakeQuant range must straddle zero: [")
+                << rmin << "," << rmax << "]",
+            nullptr);
+  }
+
+  // Special case where min/max is a point. Must be 0.
+  if (rmin == rmax) {
+    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                            0.0, 0, qmin, qmax, loc);
+  }
+
+  // Determine the scale.
+  const double qminDouble = qmin;
+  const double qmaxDouble = qmax;
+  const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+
+  // Zero point computation.
+  // In float, solve the affine equation for any known pair
+  // (real value, corresponding quantized value), of which, two such pairs
+  // are known: (rmin, qmin), (rmax, qmax).
+  // The arithmetic error on the zero point computed from either pair will be
+  // roughly machine_epsilon * (sum of absolute values of terms).
+  // Use the variant that adds the smaller error.
+  const double zeroPointFromMin = qminDouble - rmin / scale;
+  const double zeroPointFromMinError =
+      std::abs(qminDouble) + std::abs(rmin / scale);
+  const double zeroPointFromMax = qmaxDouble - rmax / scale;
+  const double zeroPointFromMaxError =
+      std::abs(qmaxDouble) + std::abs(rmax / scale);
+
+  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
+                                     ? zeroPointFromMin
+                                     : zeroPointFromMax;
+
+  // Now nudge the zero point to be an integer.
+  int64_t nudgedZeroPoint = 0;
+  if (zeroPointDouble < qminDouble) {
+    nudgedZeroPoint = qmin;
+  } else if (zeroPointDouble > qmaxDouble) {
+    nudgedZeroPoint = qmax;
+  } else {
+    nudgedZeroPoint = round(zeroPointDouble);
+  }
+
+  // By construction, the nudged zero point should always be in range.
+  assert(nudgedZeroPoint >= qmin);
+  assert(nudgedZeroPoint <= qmax);
+
+  return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                          scale, nudgedZeroPoint, qmin, qmax,
+                                          loc);
+}
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
new file mode 100644
index 0000000..7cfedf9
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -0,0 +1,146 @@
+//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace quant {
+/// Converts a possible primitive, real expressed value attribute to a
+/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
+/// quantizedElementType is the QuantizedType that describes the expressed
+/// origValue.
+/// Returns a converter Attribute or nullptr if conversion is not possible.
+static Attribute convertPrimitiveValueAttr(
+    Attribute origRealValue, QuantizedType quantizedElementType,
+    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
+  if (origRealValue.isa<FloatAttr>()) {
+    FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
+    outConvertedType = quantizedElementType.getStorageType();
+    return IntegerAttr::get(quantizedElementType.getStorageType(),
+                            converter.quantizeFloatToInt(floatAttr.getValue()));
+  }
+
+  return nullptr;
+}
+
+/// Converts a real expressed DenseFPElementsAttr to a corresponding
+/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
+/// storage values assuming the given quantizedElementType and converter.
+static DenseElementsAttr
+convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
+                           QuantizedType quantizedElementType,
+                           const UniformQuantizedValueConverter &converter) {
+  // Convert to corresponding quantized value attributes.
+  SmallVector<APInt, 8> quantValues;
+  quantValues.reserve(realFPElementsAttr.rawSize());
+  for (APFloat realVal : realFPElementsAttr) {
+    quantValues.push_back(converter.quantizeFloatToInt(realVal));
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  ShapedType newDenseType =
+      quantizedElementType
+          .castExpressedToStorageType(realFPElementsAttr.getType())
+          .dyn_cast_or_null<ShapedType>();
+  if (!newDenseType) {
+    return nullptr;
+  }
+  return DenseIntElementsAttr::get(newDenseType, quantValues);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SparseElementsAttr
+convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
+                          QuantizedType quantizedElementType,
+                          const UniformQuantizedValueConverter &converter) {
+  DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
+  if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
+    return nullptr;
+  }
+  DenseElementsAttr quantDenseAttr =
+      convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
+                                 quantizedElementType, converter);
+  if (!quantDenseAttr) {
+    return nullptr;
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  ShapedType newSparseType =
+      quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
+          .dyn_cast_or_null<ShapedType>();
+  if (!newSparseType) {
+    return nullptr;
+  }
+  return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
+                                 quantDenseAttr);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType and
+/// converter.
+Attribute quantizeAttrUniform(Attribute realValue,
+                              UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType) {
+  // Fork to handle different variants of constants supported.
+  if (realValue.isa<DenseFPElementsAttr>()) {
+    // Dense tensor or vector constant.
+    auto converted = convertDenseFPElementsAttr(
+        realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+    outConvertedType = converted.getType();
+    return converted;
+  } else if (realValue.isa<SparseElementsAttr>()) {
+    // Sparse tensor or vector constant.
+    auto converted = convertSparseElementsAttr(
+        realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
+    outConvertedType = converted.getType();
+    return converted;
+  } else {
+    // Nothing else matched: try to convert a primitive.
+    return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+                                     outConvertedType);
+  }
+}
+
+/// Convert an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType().
+/// Returns nullptr if the conversion is not supported.
+/// On success, stores the converted type in outConvertedType.
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+                       Type &outConvertedType) {
+  // Hard-coded to just support UniformQuantizedType. This will need to
+  // be generalized when there is more than one.
+  auto uniformQuantizedType =
+      quantizedElementType.dyn_cast<UniformQuantizedType>();
+  if (!uniformQuantizedType) {
+    return nullptr;
+  }
+  UniformQuantizedValueConverter converter(uniformQuantizedType);
+  return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
+                             outConvertedType);
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
new file mode 100644
index 0000000..db8a584
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -0,0 +1,73 @@
+//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+static bool isQuantizablePrimitiveType(Type inputType) {
+  return inputType.isa<FloatType>();
+}
+
+const ExpressedToUniformQuantizedConverter
+ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
+  switch (inputType.getKind()) {
+  default:
+    if (isQuantizablePrimitiveType(inputType)) {
+      // Supported primitive type (which just is the expressed type).
+      return ExpressedToUniformQuantizedConverter{inputType, inputType};
+    }
+    // Unsupported.
+    return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+  case StandardTypes::RankedTensor:
+  case StandardTypes::UnrankedTensor:
+  case StandardTypes::Vector: {
+    Type elementType = inputType.cast<ShapedType>().getElementType();
+    if (!isQuantizablePrimitiveType(elementType)) {
+      // Unsupported.
+      return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+    }
+    return ExpressedToUniformQuantizedConverter{
+        inputType, inputType.cast<ShapedType>().getElementType()};
+  }
+  }
+}
+
+Type ExpressedToUniformQuantizedConverter::convert(
+    UniformQuantizedType elementalType) const {
+  assert(expressedType && "convert() on unsupported conversion");
+
+  switch (inputType.getKind()) {
+  default:
+    if (isQuantizablePrimitiveType(elementalType)) {
+      // For primitives, just use the new elemental type.
+      return elementalType;
+    }
+    // Unsupported.
+    return nullptr;
+  case StandardTypes::RankedTensor:
+    return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
+                                 elementalType);
+  case StandardTypes::UnrankedTensor:
+    return UnrankedTensorType::get(elementalType);
+  case StandardTypes::Vector:
+    return VectorType::get(inputType.cast<VectorType>().getShape(),
+                           elementalType);
+  }
+}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/third_party/mlir/lib/Dialect/SPIRV/CMakeLists.txt
new file mode 100644
index 0000000..2803b90
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_llvm_library(MLIRSPIRV
+  DialectRegistration.cpp
+  SPIRVDialect.cpp
+  SPIRVOps.cpp
+  SPIRVTypes.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
+  )
+
+add_dependencies(MLIRSPIRV
+  MLIRSPIRVOpsIncGen
+  MLIRSPIRVEnumsIncGen
+  MLIRSPIRVOpUtilsGen)
+
+target_link_libraries(MLIRSPIRV
+  MLIRIR
+  MLIRParser
+  MLIRSupport)
+
+add_subdirectory(Serialization)
diff --git a/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
new file mode 100644
index 0000000..63e9e81
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
@@ -0,0 +1,21 @@
+//===- DialectRegistration.cpp - MLIR SPIR-V dialect registration ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+// Static initialization for SPIR-V dialect registration.
+static mlir::DialectRegistration<mlir::spirv::SPIRVDialect> spirvDialect;
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
new file mode 100644
index 0000000..622bb22
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -0,0 +1,588 @@
+//===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPIR-V dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Parser.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace spirv {
+#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
+
+SPIRVDialect::SPIRVDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
+
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
+      >();
+
+  // Allow unknown operations because SPIR-V is extensible.
+  allowUnknownOperations();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+// Parses "<number> x" from the beginning of `spec`.
+static bool parseNumberX(StringRef &spec, int64_t &number) {
+  spec = spec.ltrim();
+  if (spec.empty() || !llvm::isDigit(spec.front()))
+    return false;
+
+  number = 0;
+  do {
+    number = number * 10 + spec.front() - '0';
+    spec = spec.drop_front();
+  } while (!spec.empty() && llvm::isDigit(spec.front()));
+
+  spec = spec.ltrim();
+  if (!spec.consume_front("x"))
+    return false;
+
+  return true;
+}
+
+static bool isValidSPIRVScalarType(Type type) {
+  if (type.isa<FloatType>()) {
+    return !type.isBF16();
+  }
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
+                              intType.getWidth());
+  }
+  return false;
+}
+
+bool SPIRVDialect::isValidSPIRVType(Type type) const {
+  // Allow SPIR-V dialect types
+  if (&type.getDialect() == this) {
+    return true;
+  }
+  if (isValidSPIRVScalarType(type)) {
+    return true;
+  }
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    return (isValidSPIRVScalarType(vectorType.getElementType()) &&
+            vectorType.getNumElements() >= 2 &&
+            vectorType.getNumElements() <= 4);
+  }
+  return false;
+}
+
+static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
+                               Location loc) {
+  spec = spec.trim();
+  auto *context = dialect.getContext();
+  auto type = mlir::parseType(spec.trim(), context);
+  if (!type) {
+    emitError(loc, "cannot parse type: ") << spec;
+    return Type();
+  }
+
+  // Allow SPIR-V dialect types
+  if (&type.getDialect() == &dialect)
+    return type;
+
+  // Check other allowed types
+  if (auto t = type.dyn_cast<FloatType>()) {
+    if (type.isBF16()) {
+      emitError(loc, "cannot use 'bf16' to compose SPIR-V types");
+      return Type();
+    }
+  } else if (auto t = type.dyn_cast<IntegerType>()) {
+    if (!llvm::is_contained(llvm::ArrayRef<unsigned>({8, 16, 32, 64}),
+                            t.getWidth())) {
+      emitError(loc, "only 8/16/32/64-bit integer type allowed but found ")
+          << type;
+      return Type();
+    }
+  } else if (auto t = type.dyn_cast<VectorType>()) {
+    if (t.getRank() != 1) {
+      emitError(loc, "only 1-D vector allowed but found ") << t;
+      return Type();
+    }
+    if (t.getNumElements() > 4) {
+      emitError(loc,
+                "vector length has to be less than or equal to 4 but found ")
+          << t.getNumElements();
+      return Type();
+    }
+  } else {
+    emitError(loc, "cannot use ") << type << " to compose SPIR-V types";
+    return Type();
+  }
+
+  return type;
+}
+
+// element-type ::= integer-type
+//                | floating-point-type
+//                | vector-type
+//                | spirv-type
+//
+// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
+static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
+                           Location loc) {
+  if (!spec.consume_front("array<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.array delimiter <...> mismatch");
+    return Type();
+  }
+
+  int64_t count = 0;
+  spec = spec.trim();
+  if (!parseNumberX(spec, count)) {
+    emitError(loc, "expected array element count followed by 'x' but found '")
+        << spec << "'";
+    return Type();
+  }
+
+  if (spec.trim().empty()) {
+    emitError(loc, "expected element type");
+    return Type();
+  }
+
+  Type elementType = parseAndVerifyType(dialect, spec, loc);
+  if (!elementType)
+    return Type();
+
+  return ArrayType::get(elementType, count);
+}
+
+// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
+// methods in alphabetical order
+//
+// storage-class ::= `UniformConstant`
+//                 | `Uniform`
+//                 | `Workgroup`
+//                 | <and other storage classes...>
+//
+// pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
+static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec,
+                             Location loc) {
+  if (!spec.consume_front("ptr<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.ptr delimiter <...> mismatch");
+    return Type();
+  }
+
+  // Split into pointee type and storage class
+  StringRef scSpec, ptSpec;
+  std::tie(ptSpec, scSpec) = spec.rsplit(',');
+  if (scSpec.empty()) {
+    emitError(loc,
+              "expected comma to separate pointee type and storage class in '")
+        << spec << "'";
+    return Type();
+  }
+
+  scSpec = scSpec.trim();
+  auto storageClass = symbolizeStorageClass(scSpec);
+  if (!storageClass) {
+    emitError(loc, "unknown storage class: ") << scSpec;
+    return Type();
+  }
+
+  if (ptSpec.trim().empty()) {
+    emitError(loc, "expected pointee type");
+    return Type();
+  }
+
+  auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc);
+  if (!pointeeType)
+    return Type();
+
+  return PointerType::get(pointeeType, *storageClass);
+}
+
+// runtime-array-type ::= `!spv.rtarray<` element-type `>`
+static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec,
+                                  Location loc) {
+  if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.rtarray delimiter <...> mismatch");
+    return Type();
+  }
+
+  if (spec.trim().empty()) {
+    emitError(loc, "expected element type");
+    return Type();
+  }
+
+  Type elementType = parseAndVerifyType(dialect, spec, loc);
+  if (!elementType)
+    return Type();
+
+  return RuntimeArrayType::get(elementType);
+}
+
+// Specialize this function to parse each of the parameters that define an
+// ImageType. By default it assumes this is an enum type.
+template <typename ValTy>
+static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+                                      StringRef spec) {
+  auto val = spirv::symbolizeEnum<ValTy>()(spec);
+  if (!val) {
+    emitError(loc, "unknown attribute: '") << spec << "'";
+  }
+  return val;
+}
+
+template <>
+Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
+                                    StringRef spec) {
+  // TODO(ravishankarm): Further verify that the element type can be sampled
+  auto ty = parseAndVerifyType(dialect, spec, loc);
+  if (!ty) {
+    return llvm::None;
+  }
+  return ty;
+}
+
+template <>
+Optional<spirv::StructType::LayoutInfo>
+parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
+  uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
+  if (!spec.consume_front("[")) {
+    emitError(loc, "expected '[' while parsing layout specification in '")
+        << spec << "'";
+    return llvm::None;
+  }
+  if (spec.consumeInteger(10, offsetVal)) {
+    emitError(
+        loc,
+        "expected unsigned integer to specify offset of member in struct: '")
+        << spec << "'";
+    return llvm::None;
+  }
+  spec = spec.trim();
+  if (!spec.consume_front("]")) {
+    emitError(loc, "missing ']' in decorations spec: '") << spec << "'";
+    return llvm::None;
+  }
+  if (spec != "") {
+    emitError(loc, "unexpected extra tokens in layout information: '")
+        << spec << "'";
+    return llvm::None;
+  }
+  return spirv::StructType::LayoutInfo{offsetVal};
+}
+
+// Functor object to parse a comma separated list of specs. The function
+// parseAndVerify does the actual parsing and verification of individual
+// elements. This is a functor since parsing the last element of the list
+// (termination condition) needs partial specialization.
+template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
+  Optional<std::tuple<ParseType, Args...>>
+  operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const {
+    auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
+    StringRef parseSpec, restSpec;
+    std::tie(parseSpec, restSpec) = spec.split(',');
+
+    parseSpec = parseSpec.trim();
+    if (numArgs != 0 && restSpec.empty()) {
+      emitError(loc, "expected more parameters for image type '")
+          << parseSpec << "'";
+      return llvm::None;
+    }
+
+    auto parseVal = parseAndVerify<ParseType>(dialect, loc, parseSpec);
+    if (!parseVal) {
+      return llvm::None;
+    }
+
+    auto remainingValues =
+        parseCommaSeparatedList<Args...>{}(dialect, loc, restSpec);
+    if (!remainingValues) {
+      return llvm::None;
+    }
+    return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
+                          remainingValues.getValue());
+  }
+};
+
+// Partial specialization of the function to parse a comma separated list of
+// specs to parse the last element of the list.
+template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
+  Optional<std::tuple<ParseType>>
+  operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const {
+    spec = spec.trim();
+    auto value = parseAndVerify<ParseType>(dialect, loc, spec);
+    if (!value) {
+      return llvm::None;
+    }
+    return std::tuple<ParseType>(value.getValue());
+  }
+};
+
+// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
+//
+// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
+//
+// arrayed-info ::= `NonArrayed` | `Arrayed`
+//
+// sampling-info ::= `SingleSampled` | `MultiSampled`
+//
+// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
+//
+// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
+//
+// image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
+//                              arrayed-info `,` sampling-info `,`
+//                              sampler-use-info `,` format `>`
+static Type parseImageType(SPIRVDialect const &dialect, StringRef spec,
+                           Location loc) {
+  if (!spec.consume_front("image<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.image delimiter <...> mismatch");
+    return Type();
+  }
+
+  auto value =
+      parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+                              ImageSamplingInfo, ImageSamplerUseInfo,
+                              ImageFormat>{}(dialect, loc, spec);
+  if (!value) {
+    return Type();
+  }
+
+  return ImageType::get(value.getValue());
+}
+
+// Method to parse one member of a struct (including Layout information)
+static ParseResult
+parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc,
+                   SmallVectorImpl<Type> &memberTypes,
+                   SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+  // Check for a '[' <layoutInfo> ']'
+  auto lastLSquare = spec.rfind('[');
+  auto typeSpec = spec.substr(0, lastLSquare);
+  auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("")
+                                                    : spec.substr(lastLSquare));
+  auto type = parseAndVerify<Type>(dialect, loc, typeSpec);
+  if (!type) {
+    return failure();
+  }
+  memberTypes.push_back(type.getValue());
+  if (layoutSpec.empty()) {
+    return success();
+  }
+  if (layoutInfo.size() != memberTypes.size() - 1) {
+    emitError(loc, "layout specification must be given for all members");
+    return failure();
+  }
+  auto layout =
+      parseAndVerify<StructType::LayoutInfo>(dialect, loc, layoutSpec);
+  if (!layout) {
+    return failure();
+  }
+  layoutInfo.push_back(layout.getValue());
+  return success();
+}
+
+// Helper method to record the position of the corresponding '>' for every '<'
+// encountered when parsing the string left to right. The relative position of
+// '>' w.r.t to the '<' is recorded.
+static bool
+computeMatchingRAngles(Location loc, StringRef const &spec,
+                       SmallVectorImpl<size_t> &matchingRAngleOffset) {
+  SmallVector<size_t, 4> openBrackets;
+  for (size_t i = 0, e = spec.size(); i != e; ++i) {
+    if (spec[i] == '<') {
+      openBrackets.push_back(i);
+    } else if (spec[i] == '>') {
+      if (openBrackets.empty()) {
+        emitError(loc, "unbalanced '<' in '") << spec << "'";
+        return false;
+      }
+      matchingRAngleOffset.push_back(i - openBrackets.pop_back_val());
+    }
+  }
+  return true;
+}
+
+static ParseResult
+parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc,
+                  ArrayRef<size_t> matchingRAngleOffset,
+                  SmallVectorImpl<Type> &memberTypes,
+                  SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+  // Check if the occurrence of ',' or '<' is before. If former, split using
+  // ','. If latter, split using matching '>' to get the entire type
+  // description
+  auto firstComma = spec.find(',');
+  auto firstLAngle = spec.find('<');
+  if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) {
+    return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo);
+  }
+  if (firstLAngle == StringRef::npos || firstComma < firstLAngle) {
+    // Parse the type before the ','
+    if (parseStructElement(dialect, spec.substr(0, firstComma), loc,
+                           memberTypes, layoutInfo)) {
+      return failure();
+    }
+    return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc,
+                             matchingRAngleOffset, memberTypes, layoutInfo);
+  }
+  auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle;
+  // Find the next ',' or '>'
+  auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size());
+  if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes,
+                         layoutInfo)) {
+    return failure();
+  }
+  auto rest = spec.substr(endLoc + 1).ltrim();
+  if (rest.empty()) {
+    return success();
+  }
+  if (rest.front() == ',') {
+    return parseStructHelper(
+        dialect, rest.drop_front().trim(), loc,
+        ArrayRef<size_t>(std::next(matchingRAngleOffset.begin()),
+                         matchingRAngleOffset.end()),
+        memberTypes, layoutInfo);
+  }
+  emitError(loc, "unexpected string : '") << rest << "'";
+  return failure();
+}
+
+// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)?
+//                 (`, ` spirv-type ( ` [` integer-literal `] ` )? )*
+static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
+                            Location loc) {
+  if (!spec.consume_front("struct<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.struct delimiter <...> mismatch");
+    return Type();
+  }
+
+  if (spec.trim().empty()) {
+    emitError(loc, "expected SPIR-V type");
+    return Type();
+  }
+
+  SmallVector<Type, 4> memberTypes;
+  SmallVector<StructType::LayoutInfo, 4> layoutInfo;
+  SmallVector<size_t, 4> matchingRAngleOffset;
+  if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) ||
+      parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes,
+                        layoutInfo)) {
+    return Type();
+  }
+  if (layoutInfo.empty()) {
+    return StructType::get(memberTypes);
+  }
+  if (memberTypes.size() != layoutInfo.size()) {
+    emitError(loc, "layout specification must be given for all members");
+    return Type();
+  }
+  return StructType::get(memberTypes, layoutInfo);
+}
+
+// spirv-type ::= array-type
+//              | element-type
+//              | image-type
+//              | pointer-type
+//              | runtime-array-type
+//              | struct-type
+Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
+  if (spec.startswith("array"))
+    return parseArrayType(*this, spec, loc);
+  if (spec.startswith("image"))
+    return parseImageType(*this, spec, loc);
+  if (spec.startswith("ptr"))
+    return parsePointerType(*this, spec, loc);
+  if (spec.startswith("rtarray"))
+    return parseRuntimeArrayType(*this, spec, loc);
+  if (spec.startswith("struct"))
+    return parseStructType(*this, spec, loc);
+
+  emitError(loc, "unknown SPIR-V type: ") << spec;
+  return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Printing
+//===----------------------------------------------------------------------===//
+
+static void print(ArrayType type, llvm::raw_ostream &os) {
+  os << "array<" << type.getNumElements() << " x " << type.getElementType()
+     << ">";
+}
+
+static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
+  os << "rtarray<" << type.getElementType() << ">";
+}
+
+static void print(PointerType type, llvm::raw_ostream &os) {
+  os << "ptr<" << type.getPointeeType() << ", "
+     << stringifyStorageClass(type.getStorageClass()) << ">";
+}
+
+static void print(ImageType type, llvm::raw_ostream &os) {
+  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
+     << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
+     << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
+     << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
+     << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
+     << stringifyImageFormat(type.getImageFormat()) << ">";
+}
+
+static void print(StructType type, llvm::raw_ostream &os) {
+  os << "struct<";
+  auto printMember = [&](unsigned i) {
+    os << type.getElementType(i);
+    if (type.hasLayout()) {
+      os << " [" << type.getOffset(i) << "]";
+    }
+  };
+  mlir::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
+                        printMember);
+  os << ">";
+}
+
+void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
+  switch (type.getKind()) {
+  case TypeKind::Array:
+    print(type.cast<ArrayType>(), os);
+    return;
+  case TypeKind::Pointer:
+    print(type.cast<PointerType>(), os);
+    return;
+  case TypeKind::RuntimeArray:
+    print(type.cast<RuntimeArrayType>(), os);
+    return;
+  case TypeKind::Image:
+    print(type.cast<ImageType>(), os);
+    return;
+  case TypeKind::Struct:
+    print(type.cast<StructType>(), os);
+    return;
+  default:
+    llvm_unreachable("unhandled SPIR-V type");
+  }
+}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
new file mode 100644
index 0000000..cdd1013
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -0,0 +1,1020 @@
+//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StringExtras.h"
+
+using namespace mlir;
+
+// TODO(antiagainst): generate these strings using ODS.
+static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kIndicesAttrName[] = "indices";
+static constexpr const char kIsSpecConstName[] = "is_spec_const";
+static constexpr const char kValueAttrName[] = "value";
+static constexpr const char kValuesAttrName[] = "values";
+static constexpr const char kFnNameAttrName[] = "fn";
+
+//===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+template <typename Dst, typename Src>
+inline Dst bitwiseCast(Src source) noexcept {
+  Dst dest;
+  static_assert(sizeof(source) == sizeof(dest),
+                "bitwiseCast requires same source and destination bitwidth");
+  std::memcpy(&dest, &source, sizeof(dest));
+  return dest;
+}
+
+static LogicalResult extractValueFromConstOp(Operation *op,
+                                             int32_t &indexValue) {
+  auto constOp = llvm::dyn_cast<spirv::ConstantOp>(op);
+  if (!constOp) {
+    return failure();
+  }
+  auto valueAttr = constOp.value();
+  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
+  if (!integerValueAttr) {
+    return failure();
+  }
+  indexValue = integerValueAttr.getInt();
+  return success();
+}
+
+static ParseResult parseBinaryLogicalOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  Type type;
+  if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
+      parser->resolveOperands(ops, type, result->operands)) {
+    return failure();
+  }
+  // Result must be a scalar or vector of boolean type.
+  Type resultType = parser->getBuilder().getIntegerType(1);
+  if (auto opsType = type.dyn_cast<VectorType>()) {
+    resultType = VectorType::get(opsType.getNumElements(), resultType);
+  }
+  result->addTypes(resultType);
+  return success();
+}
+
+template <typename EnumClass>
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) {
+  Attribute attrVal;
+  SmallVector<NamedAttribute, 1> attr;
+  auto loc = parser->getCurrentLocation();
+  if (parser->parseAttribute(attrVal, parser->getBuilder().getNoneType(),
+                             spirv::attributeName<EnumClass>(), attr)) {
+    return failure();
+  }
+  if (!attrVal.isa<StringAttr>()) {
+    return parser->emitError(loc, "expected ")
+           << spirv::attributeName<EnumClass>()
+           << " attribute specified as string";
+  }
+  auto attrOptional =
+      spirv::symbolizeEnum<EnumClass>()(attrVal.cast<StringAttr>().getValue());
+  if (!attrOptional) {
+    return parser->emitError(loc, "invalid ")
+           << spirv::attributeName<EnumClass>()
+           << " attribute specification: " << attrVal;
+  }
+  value = attrOptional.getValue();
+  return success();
+}
+
+template <typename EnumClass>
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
+                                      OperationState *state) {
+  if (parseEnumAttribute(value, parser)) {
+    return failure();
+  }
+  state->addAttribute(
+      spirv::attributeName<EnumClass>(),
+      parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value)));
+  return success();
+}
+
+static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser,
+                                               OperationState *state) {
+  // Parse an optional list of attributes staring with '['
+  if (parser->parseOptionalLSquare()) {
+    // Nothing to do
+    return success();
+  }
+
+  spirv::MemoryAccess memoryAccessAttr;
+  if (parseEnumAttribute(memoryAccessAttr, parser, state)) {
+    return failure();
+  }
+
+  if (memoryAccessAttr == spirv::MemoryAccess::Aligned) {
+    // Parse integer attribute for alignment.
+    Attribute alignmentAttr;
+    Type i32Type = parser->getBuilder().getIntegerType(32);
+    if (parser->parseComma() ||
+        parser->parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
+                               state->attributes)) {
+      return failure();
+    }
+  }
+  return parser->parseRSquare();
+}
+
+// Parses an op that has no inputs and no outputs.
+static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) {
+  if (parser->parseOptionalAttributeDict(state->attributes))
+    return failure();
+  return success();
+}
+
+static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter *printer) {
+  *printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
+           << *logicalOp->getOperand(1);
+  *printer << " : " << logicalOp->getOperand(0)->getType();
+}
+
+template <typename LoadStoreOpTy>
+static void
+printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer,
+                           SmallVectorImpl<StringRef> &elidedAttrs) {
+  // Print optional memory access attribute.
+  if (auto memAccess = loadStoreOp.memory_access()) {
+    elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
+    *printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+
+    // Print integer alignment attribute.
+    if (auto alignment = loadStoreOp.alignment()) {
+      elidedAttrs.push_back(kAlignmentAttrName);
+      *printer << ", " << alignment;
+    }
+    *printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+}
+
+template <typename LoadStoreOpTy>
+static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = loadStoreOp.getOperation();
+  auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kAlignmentAttrName)) {
+      return loadStoreOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+
+  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
+  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+
+  if (!memAccess) {
+    return loadStoreOp.emitOpError("invalid memory access specifier: ")
+           << memAccessVal;
+  }
+
+  if (*memAccess == spirv::MemoryAccess::Aligned) {
+    if (!op->getAttr(kAlignmentAttrName)) {
+      return loadStoreOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kAlignmentAttrName)) {
+      return loadStoreOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+}
+
+template <typename LoadStoreOpTy>
+static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr,
+                                                   Value *val) {
+  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
+  // type of the pointer and the type of the value are the same
+  //
+  // TODO(ravishankarm): Check that the value type satisfies restrictions of
+  // SPIR-V OpLoad/OpStore operations
+  if (val->getType() !=
+      ptr->getType().cast<spirv::PointerType>().getPointeeType()) {
+    return op.emitOpError("mismatch in result type and pointer type");
+  }
+  return success();
+}
+
+// Prints an op that has no inputs and no outputs.
+static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
+  *printer << op->getName();
+  printer->printOptionalAttrDict(op->getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
+                              Location baseLoc) {
+  if (!indices.size()) {
+    emitError(baseLoc, "'spv.AccessChain' op expected at least "
+                       "one index ");
+    return nullptr;
+  }
+
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType) {
+    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
+                       "to composite type, but provided ")
+        << type;
+    return nullptr;
+  }
+
+  auto resultType = ptrType.getPointeeType();
+  auto resultStorageClass = ptrType.getStorageClass();
+  int32_t index = 0;
+
+  for (auto indexSSA : indices) {
+    auto cType = resultType.dyn_cast<spirv::CompositeType>();
+    if (!cType) {
+      emitError(baseLoc,
+                "'spv.AccessChain' op cannot extract from non-composite type ")
+          << resultType << " with index " << index;
+      return nullptr;
+    }
+    index = 0;
+    if (resultType.isa<spirv::StructType>()) {
+      Operation *op = indexSSA->getDefiningOp();
+      if (!op) {
+        emitError(baseLoc, "'spv.AccessChain' op index must be an "
+                           "integer spv.constant to access "
+                           "element of spv.struct");
+        return nullptr;
+      }
+
+      // TODO(denis0x0D): this should be relaxed to allow
+      // integer literals of other bitwidths.
+      if (failed(extractValueFromConstOp(op, index))) {
+        emitError(baseLoc,
+                  "'spv.AccessChain' index must be an integer spv.constant to "
+                  "access element of spv.struct, but provided ")
+            << op->getName();
+        return nullptr;
+      }
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        emitError(baseLoc, "'spv.AccessChain' op index ")
+            << index << " out of bounds for " << resultType;
+        return nullptr;
+      }
+    }
+    resultType = cType.getElementType(index);
+  }
+  return spirv::PointerType::get(resultType, resultStorageClass);
+}
+
+static ParseResult parseAccessChainOp(OpAsmParser *parser,
+                                      OperationState *state) {
+  OpAsmParser::OperandType ptrInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
+  Type type;
+  // TODO(denis0x0D): regarding to the spec an index must be any integer type,
+  // figure out how to use resolveOperand with a range of types and do not
+  // fail on first attempt.
+  Type indicesType = parser->getBuilder().getIntegerType(32);
+
+  if (parser->parseOperand(ptrInfo) ||
+      parser->parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(ptrInfo, type, state->operands) ||
+      parser->resolveOperands(indicesInfo, indicesType, state->operands)) {
+    return failure();
+  }
+
+  Location baseLoc = state->operands.front()->getLoc();
+  auto resultType = getElementPtrType(
+      type, llvm::makeArrayRef(state->operands).drop_front(), baseLoc);
+  if (!resultType) {
+    return failure();
+  }
+
+  state->addTypes(resultType);
+  return success();
+}
+
+static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) {
+  *printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
+           << '[';
+  printer->printOperands(op.indices());
+  *printer << "] : " << op.base_ptr()->getType();
+}
+
+static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
+  SmallVector<Value *, 4> indices(accessChainOp.indices().begin(),
+                                  accessChainOp.indices().end());
+  auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
+                                      indices, accessChainOp.getLoc());
+  if (!resultType) {
+    return failure();
+  }
+
+  auto providedResultType =
+      accessChainOp.getType().dyn_cast<spirv::PointerType>();
+  if (!providedResultType) {
+    return accessChainOp.emitOpError(
+               "result type must be a pointer, but provided")
+           << providedResultType;
+  }
+
+  if (resultType != providedResultType) {
+    return accessChainOp.emitOpError("invalid result type: expected ")
+           << resultType << ", but provided " << providedResultType;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtractOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCompositeExtractOp(OpAsmParser *parser,
+                                           OperationState *state) {
+  OpAsmParser::OperandType compositeInfo;
+  Attribute indicesAttr;
+  Type compositeType;
+  llvm::SMLoc attrLocation;
+  int32_t index;
+
+  if (parser->parseOperand(compositeInfo) ||
+      parser->getCurrentLocation(&attrLocation) ||
+      parser->parseAttribute(indicesAttr, kIndicesAttrName,
+                             state->attributes) ||
+      parser->parseColonType(compositeType) ||
+      parser->resolveOperand(compositeInfo, compositeType, state->operands)) {
+    return failure();
+  }
+
+  auto indicesArrayAttr = indicesAttr.dyn_cast<ArrayAttr>();
+  if (!indicesArrayAttr) {
+    return parser->emitError(
+        attrLocation,
+        "expected an 32-bit integer array attribute for 'indices'");
+  }
+
+  if (!indicesArrayAttr.size()) {
+    return parser->emitError(
+        attrLocation, "expected at least one index for spv.CompositeExtract");
+  }
+
+  Type resultType = compositeType;
+  for (auto indexAttr : indicesArrayAttr) {
+    if (auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>()) {
+      index = indexIntAttr.getInt();
+    } else {
+      return parser->emitError(
+                 attrLocation,
+                 "expexted an 32-bit integer for index, but found '")
+             << indexAttr << "'";
+    }
+
+    if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) {
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        return parser->emitError(attrLocation, "index ")
+               << index << " out of bounds for " << resultType;
+      }
+      resultType = cType.getElementType(index);
+    } else {
+      return parser->emitError(attrLocation,
+                               "cannot extract from non-composite type ")
+             << resultType << " with index " << index;
+    }
+  }
+
+  state->addTypes(resultType);
+  return success();
+}
+
+static void print(spirv::CompositeExtractOp compositeExtractOp,
+                  OpAsmPrinter *printer) {
+  *printer << spirv::CompositeExtractOp::getOperationName() << ' '
+           << *compositeExtractOp.composite() << compositeExtractOp.indices()
+           << " : " << compositeExtractOp.composite()->getType();
+}
+
+static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
+  auto resultType = compExOp.composite()->getType();
+  auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
+
+  if (!indicesArrayAttr.size()) {
+    return compExOp.emitOpError(
+        "expexted at least one index for spv.CompositeExtractOp");
+  }
+
+  int32_t index;
+  for (auto indexAttr : indicesArrayAttr) {
+    index = indexAttr.dyn_cast<IntegerAttr>().getInt();
+    if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) {
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        return compExOp.emitOpError("index ")
+               << index << " out of bounds for " << resultType;
+      }
+      resultType = cType.getElementType(index);
+    } else {
+      return compExOp.emitError("cannot extract from non-composite type ")
+             << resultType << " with index " << index;
+    }
+  }
+
+  if (resultType != compExOp.getType()) {
+    return compExOp.emitOpError("invalid result type: expected ")
+           << resultType << " but provided " << compExOp.getType();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
+  if (succeeded(parser->parseOptionalKeyword("spec")))
+    state->addAttribute(kIsSpecConstName, parser->getBuilder().getUnitAttr());
+
+  Attribute value;
+  if (parser->parseAttribute(value, kValueAttrName, state->attributes))
+    return failure();
+
+  Type type;
+  if (value.getType().isa<NoneType>()) {
+    if (parser->parseColonType(type))
+      return failure();
+  } else {
+    type = value.getType();
+  }
+
+  return parser->addTypeToList(type, state->types);
+}
+
+static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
+  *printer << spirv::ConstantOp::getOperationName()
+           << (constOp.is_spec_const() ? " spec " : " ") << constOp.value();
+  if (constOp.getType().isa<spirv::ArrayType>()) {
+    *printer << " : " << constOp.getType();
+  }
+}
+
+static LogicalResult verify(spirv::ConstantOp constOp) {
+  auto opType = constOp.getType();
+  auto value = constOp.value();
+  auto valueType = value.getType();
+
+  // ODS already generates checks to make sure the result type is valid. We just
+  // need to additionally check that the value's attribute type is consistent
+  // with the result type.
+  switch (value.getKind()) {
+  case StandardAttributes::Bool:
+  case StandardAttributes::Integer:
+  case StandardAttributes::Float:
+  case StandardAttributes::DenseElements:
+  case StandardAttributes::SparseElements: {
+    if (valueType != opType)
+      return constOp.emitOpError("result type (")
+             << opType << ") does not match value type (" << valueType << ")";
+    return success();
+  } break;
+  case StandardAttributes::Array: {
+    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
+    if (!arrayType)
+      return constOp.emitOpError(
+          "must have spv.array result type for array value");
+    auto elemType = arrayType.getElementType();
+    for (auto element : value.cast<ArrayAttr>().getValue()) {
+      if (element.getType() != elemType)
+        return constOp.emitOpError(
+            "has array element that are not of result array element type");
+    }
+  } break;
+  default:
+    return constOp.emitOpError("cannot have value of type ") << valueType;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.EntryPoint
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseEntryPointOp(OpAsmParser *parser,
+                                     OperationState *state) {
+  spirv::ExecutionModel execModel;
+  SmallVector<OpAsmParser::OperandType, 0> identifiers;
+  SmallVector<Type, 0> idTypes;
+
+  Attribute fn;
+  auto loc = parser->getCurrentLocation();
+
+  if (parseEnumAttribute(execModel, parser, state) ||
+      parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
+      parser->parseTrailingOperandList(identifiers) ||
+      parser->parseOptionalColonTypeList(idTypes) ||
+      parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
+    return failure();
+  }
+  if (!fn.isa<SymbolRefAttr>()) {
+    return parser->emitError(loc, "expected symbol reference attribute");
+  }
+  return success();
+}
+
+static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) {
+  *printer << spirv::EntryPointOp::getOperationName() << " \""
+           << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
+           << entryPointOp.fn();
+  if (!entryPointOp.getNumOperands()) {
+    return;
+  }
+  *printer << ", ";
+  mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
+                        [&](Value *a) { printer->printOperand(a); });
+  *printer << " : ";
+  mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
+                        [&](const Value *a) { *printer << a->getType(); });
+}
+
+static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
+  // Verify that all the interface ops are created from VariableOp
+  for (auto interface : entryPointOp.interface()) {
+    if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) {
+      return entryPointOp.emitOpError("interface operands to entry point must "
+                                      "be generated from a variable op");
+    }
+    // TODO:  Before version 1.4 the variables can only have storage_class of
+    // Input or Output. That needs to be verified.
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ExecutionMode
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseExecutionModeOp(OpAsmParser *parser,
+                                        OperationState *state) {
+  spirv::ExecutionMode execMode;
+  Attribute fn;
+  if (parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
+      parseEnumAttribute(execMode, parser, state)) {
+    return failure();
+  }
+
+  SmallVector<int32_t, 4> values;
+  Type i32Type = parser->getBuilder().getIntegerType(32);
+  while (!parser->parseOptionalComma()) {
+    SmallVector<NamedAttribute, 1> attr;
+    Attribute value;
+    if (parser->parseAttribute(value, i32Type, "value", attr)) {
+      return failure();
+    }
+    values.push_back(value.cast<IntegerAttr>().getInt());
+  }
+  state->addAttribute(kValuesAttrName,
+                      parser->getBuilder().getI32ArrayAttr(values));
+  return success();
+}
+
+static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
+  *printer << spirv::ExecutionModeOp::getOperationName() << " @"
+           << execModeOp.fn() << " \""
+           << stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
+  auto values = execModeOp.values();
+  if (!values) {
+    return;
+  }
+  *printer << ", ";
+  mlir::interleaveComma(
+      values.getValue().cast<ArrayAttr>(), printer->getStream(),
+      [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LoadOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  OpAsmParser::OperandType ptrInfo;
+  Type elementType;
+  if (parseEnumAttribute(storageClass, parser) ||
+      parser->parseOperand(ptrInfo) ||
+      parseMemoryAccessAttributes(parser, state) ||
+      parser->parseOptionalAttributeDict(state->attributes) ||
+      parser->parseColon() || parser->parseType(elementType)) {
+    return failure();
+  }
+
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (parser->resolveOperand(ptrInfo, ptrType, state->operands)) {
+    return failure();
+  }
+
+  state->addTypes(elementType);
+  return success();
+}
+
+static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) {
+  auto *op = loadOp.getOperation();
+  SmallVector<StringRef, 4> elidedAttrs;
+  StringRef sc = stringifyStorageClass(
+      loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+  *printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
+  // Print the pointer operand.
+  printer->printOperand(loadOp.ptr());
+
+  printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
+
+  printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+  *printer << " : " << loadOp.getType();
+}
+
+static LogicalResult verify(spirv::LoadOp loadOp) {
+  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
+  // type with fixed size; i.e., it cannot be, nor include, any
+  // OpTypeRuntimeArray types."
+  if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
+                                           loadOp.value()))) {
+    return failure();
+  }
+  return verifyMemoryAccessAttribute(loadOp);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.module
+//===----------------------------------------------------------------------===//
+
+void spirv::ModuleOp::build(Builder *builder, OperationState *state) {
+  ensureTerminator(*state->addRegion(), *builder, state->location);
+}
+
+void spirv::ModuleOp::build(Builder *builder, OperationState *state,
+                            IntegerAttr addressing_model,
+                            IntegerAttr memory_model, ArrayAttr capabilities,
+                            ArrayAttr extensions,
+                            ArrayAttr extended_instruction_sets) {
+  state->addAttribute("addressing_model", addressing_model);
+  state->addAttribute("memory_model", memory_model);
+  if (capabilities)
+    state->addAttribute("capabilities", capabilities);
+  if (extensions)
+    state->addAttribute("extensions", extensions);
+  if (extended_instruction_sets)
+    state->addAttribute("extended_instruction_sets", extended_instruction_sets);
+  ensureTerminator(*state->addRegion(), *builder, state->location);
+}
+
+static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
+  Region *body = state->addRegion();
+
+  // Parse attributes
+  spirv::AddressingModel addrModel;
+  spirv::MemoryModel memoryModel;
+  if (parseEnumAttribute(addrModel, parser, state) ||
+      parseEnumAttribute(memoryModel, parser, state)) {
+    return failure();
+  }
+
+  if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  if (succeeded(parser->parseOptionalKeyword("attributes"))) {
+    if (parser->parseOptionalAttributeDict(state->attributes))
+      return failure();
+  }
+
+  spirv::ModuleOp::ensureTerminator(*body, parser->getBuilder(),
+                                    state->location);
+  return success();
+}
+
+static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
+  auto *op = moduleOp.getOperation();
+
+  // Only print out addressing model and memory model in a nicer way if both
+  // presents. Otherwise, print them in the general form. This helps debugging
+  // ill-formed ModuleOp.
+  SmallVector<StringRef, 2> elidedAttrs;
+  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
+  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
+  if (op->getAttr(addressingModelAttrName) &&
+      op->getAttr(memoryModelAttrName)) {
+    *printer << spirv::ModuleOp::getOperationName() << " \""
+             << spirv::stringifyAddressingModel(moduleOp.addressing_model())
+             << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
+             << '"';
+    elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
+  }
+
+  printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+                       /*printBlockTerminators=*/false);
+
+  bool printAttrDict =
+      elidedAttrs.size() != 2 ||
+      llvm::any_of(op->getAttrs(), [&addressingModelAttrName,
+                                    &memoryModelAttrName](NamedAttribute attr) {
+        return attr.first != addressingModelAttrName &&
+               attr.first != memoryModelAttrName;
+      });
+
+  if (printAttrDict) {
+    *printer << " attributes";
+    printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+  }
+}
+
+static LogicalResult verify(spirv::ModuleOp moduleOp) {
+  auto &op = *moduleOp.getOperation();
+  auto *dialect = op.getDialect();
+  auto &body = op.getRegion(0).front();
+  llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
+      entryPoints;
+  SymbolTable table(moduleOp);
+
+  for (auto &op : body) {
+    if (op.getDialect() == dialect) {
+      // For EntryPoint op, check that the function and execution model is not
+      // duplicated in EntryPointOps
+      if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
+        auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
+        if (!funcOp) {
+          return entryPointOp.emitError("function '")
+                 << entryPointOp.fn() << "' not found in 'spv.module'";
+        }
+        auto key = std::pair<FuncOp, spirv::ExecutionModel>(
+            funcOp, entryPointOp.execution_model());
+        auto entryPtIt = entryPoints.find(key);
+        if (entryPtIt != entryPoints.end()) {
+          return entryPointOp.emitError("duplicate of a previous EntryPointOp");
+        }
+        entryPoints[key] = entryPointOp;
+      }
+      continue;
+    }
+
+    auto funcOp = llvm::dyn_cast<FuncOp>(op);
+    if (!funcOp)
+      return op.emitError("'spv.module' can only contain func and spv.* ops");
+
+    if (funcOp.isExternal())
+      return op.emitError("'spv.module' cannot contain external functions");
+
+    for (auto &block : funcOp)
+      for (auto &op : block) {
+        if (op.getDialect() == dialect)
+          continue;
+
+        if (llvm::isa<FuncOp>(op))
+          return op.emitError("'spv.module' cannot contain nested functions");
+
+        return op.emitError(
+            "functions in 'spv.module' can only contain spv.* ops");
+      }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
+  auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp());
+  if (!funcOp)
+    return returnOp.emitOpError("must appear in a 'func' op");
+
+  auto numOutputs = funcOp.getType().getNumResults();
+  if (numOutputs != 0)
+    return returnOp.emitOpError("cannot be used in functions returning value")
+           << (numOutputs > 1 ? "s" : "");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.StoreOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  auto loc = parser->getCurrentLocation();
+  Type elementType;
+  if (parseEnumAttribute(storageClass, parser) ||
+      parser->parseOperandList(operandInfo, 2) ||
+      parseMemoryAccessAttributes(parser, state) || parser->parseColon() ||
+      parser->parseType(elementType)) {
+    return failure();
+  }
+
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (parser->resolveOperands(operandInfo, {ptrType, elementType}, loc,
+                              state->operands)) {
+    return failure();
+  }
+  return success();
+}
+
+static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) {
+  auto *op = storeOp.getOperation();
+  SmallVector<StringRef, 4> elidedAttrs;
+  StringRef sc = stringifyStorageClass(
+      storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+  *printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
+  // Print the pointer operand
+  printer->printOperand(storeOp.ptr());
+  *printer << ", ";
+  // Print the value operand
+  printer->printOperand(storeOp.value());
+
+  printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
+
+  *printer << " : " << storeOp.value()->getType();
+
+  printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+}
+
+static LogicalResult verify(spirv::StoreOp storeOp) {
+  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
+  // OpTypePointer whose Type operand is the same as the type of Object."
+  if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
+                                           storeOp.value()))) {
+    return failure();
+  }
+  return verifyMemoryAccessAttribute(storeOp);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.Variable
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
+  // Parse optional initializer
+  Optional<OpAsmParser::OperandType> initInfo;
+  if (succeeded(parser->parseOptionalKeyword("init"))) {
+    initInfo = OpAsmParser::OperandType();
+    if (parser->parseLParen() || parser->parseOperand(*initInfo) ||
+        parser->parseRParen())
+      return failure();
+  }
+
+  // Parse optional descriptor binding
+  Attribute set, binding;
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  if (succeeded(parser->parseOptionalKeyword("bind"))) {
+    Type i32Type = parser->getBuilder().getIntegerType(32);
+    if (parser->parseLParen() ||
+        parser->parseAttribute(set, i32Type, descriptorSetName,
+                               state->attributes) ||
+        parser->parseComma() ||
+        parser->parseAttribute(binding, i32Type, bindingName,
+                               state->attributes) ||
+        parser->parseRParen())
+      return failure();
+  }
+
+  // Parse other attributes
+  if (parser->parseOptionalAttributeDict(state->attributes))
+    return failure();
+
+  // Parse result pointer type
+  Type type;
+  if (parser->parseColon())
+    return failure();
+  auto loc = parser->getCurrentLocation();
+  if (parser->parseType(type))
+    return failure();
+
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType)
+    return parser->emitError(loc, "expected spv.ptr type");
+  state->addTypes(ptrType);
+
+  // Resolve the initializer operand
+  SmallVector<Value *, 1> init;
+  if (initInfo) {
+    if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init))
+      return failure();
+    state->addOperands(init);
+  }
+
+  auto attr = parser->getBuilder().getI32IntegerAttr(
+      bitwiseCast<int32_t>(ptrType.getStorageClass()));
+  state->addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
+
+  return success();
+}
+
+static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
+  auto *op = varOp.getOperation();
+  SmallVector<StringRef, 4> elidedAttrs{
+      spirv::attributeName<spirv::StorageClass>()};
+  *printer << spirv::VariableOp::getOperationName();
+
+  // Print optional initializer
+  if (op->getNumOperands() > 0) {
+    *printer << " init(";
+    printer->printOperands(varOp.initializer());
+    *printer << ")";
+  }
+
+  // Print optional descriptor binding
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
+  if (descriptorSet && binding) {
+    elidedAttrs.push_back(descriptorSetName);
+    elidedAttrs.push_back(bindingName);
+    *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+             << ")";
+  }
+
+  printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+  *printer << " : " << varOp.getType();
+}
+
+static LogicalResult verify(spirv::VariableOp varOp) {
+  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+  // object. It cannot be Generic. It must be the same as the Storage Class
+  // operand of the Result Type."
+  if (varOp.storage_class() == spirv::StorageClass::Generic)
+    return varOp.emitOpError("storage class cannot be 'Generic'");
+
+  auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
+  if (varOp.storage_class() != pointerType.getStorageClass())
+    return varOp.emitOpError(
+        "storage class must match result pointer's storage class");
+
+  if (varOp.getNumOperands() != 0) {
+    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
+    // a global (module scope) OpVariable instruction".
+    bool valid = false;
+    if (auto *initOp = varOp.getOperand(0)->getDefiningOp()) {
+      if (llvm::isa<spirv::ConstantOp>(initOp)) {
+        valid = true;
+      } else if (llvm::isa<spirv::VariableOp>(initOp)) {
+        valid = llvm::isa_and_nonnull<spirv::ModuleOp>(initOp->getParentOp());
+      }
+    }
+    if (!valid)
+      return varOp.emitOpError("initializer must be the result of a "
+                               "spv.Constant or module-level spv.Variable op");
+  }
+
+  return success();
+}
+
+namespace mlir {
+namespace spirv {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
+
+} // namespace spirv
+} // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
new file mode 100644
index 0000000..345d13d
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -0,0 +1,428 @@
+//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+// Pull in all enum utility function definitions
+#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::ArrayTypeStorage : public TypeStorage {
+  using KeyTy = std::pair<Type, unsigned>;
+
+  static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+                                     const KeyTy &key) {
+    return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(elementType, getSubclassData());
+  }
+
+  ArrayTypeStorage(const KeyTy &key)
+      : TypeStorage(key.second), elementType(key.first) {}
+
+  Type elementType;
+};
+
+ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
+  return Base::get(elementType.getContext(), TypeKind::Array, elementType,
+                   elementCount);
+}
+
+unsigned ArrayType::getNumElements() const {
+  return getImpl()->getSubclassData();
+}
+
+Type ArrayType::getElementType() const { return getImpl()->elementType; }
+
+//===----------------------------------------------------------------------===//
+// CompositeType
+//===----------------------------------------------------------------------===//
+
+Type CompositeType::getElementType(unsigned index) const {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    return cast<ArrayType>().getElementType();
+  case spirv::TypeKind::Struct:
+    return cast<StructType>().getElementType(index);
+  case StandardTypes::Vector:
+    return cast<VectorType>().getElementType();
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
+unsigned CompositeType::getNumElements() const {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    return cast<ArrayType>().getNumElements();
+  case spirv::TypeKind::Struct:
+    return cast<StructType>().getNumElements();
+  case StandardTypes::Vector:
+    return cast<VectorType>().getNumElements();
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ImageType
+//===----------------------------------------------------------------------===//
+
+template <typename T> static constexpr unsigned getNumBits() { return 0; }
+template <> constexpr unsigned getNumBits<Dim>() {
+  static_assert((1 << 3) > getMaxEnumValForDim(),
+                "Not enough bits to encode Dim value");
+  return 3;
+}
+template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
+  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
+                "Not enough bits to encode ImageDepthInfo value");
+  return 2;
+}
+template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
+  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
+                "Not enough bits to encode ImageArrayedInfo value");
+  return 1;
+}
+template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
+  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
+                "Not enough bits to encode ImageSamplingInfo value");
+  return 1;
+}
+template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
+  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
+                "Not enough bits to encode ImageSamplerUseInfo value");
+  return 2;
+}
+template <> constexpr unsigned getNumBits<ImageFormat>() {
+  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
+                "Not enough bits to encode ImageFormat value");
+  return 6;
+}
+
+struct spirv::detail::ImageTypeStorage : public TypeStorage {
+private:
+  /// Define a bit-field struct to pack the enum values
+  union EnumPack {
+    struct {
+      unsigned dimEncoding : getNumBits<Dim>();
+      unsigned depthInfoEncoding : getNumBits<ImageDepthInfo>();
+      unsigned arrayedInfoEncoding : getNumBits<ImageArrayedInfo>();
+      unsigned samplingInfoEncoding : getNumBits<ImageSamplingInfo>();
+      unsigned samplerUseInfoEncoding : getNumBits<ImageSamplerUseInfo>();
+      unsigned formatEncoding : getNumBits<ImageFormat>();
+    } data;
+    unsigned storage;
+  };
+
+public:
+  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
+
+  static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
+                                     const KeyTy &key) {
+    return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(),
+                        getSamplingInfo(), getSamplerUseInfo(),
+                        getImageFormat());
+  }
+
+  Dim getDim() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<Dim>(v.data.dimEncoding);
+  }
+  void setDim(Dim dim) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.dimEncoding = static_cast<unsigned>(dim);
+    setSubclassData(v.storage);
+  }
+
+  ImageDepthInfo getDepthInfo() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<ImageDepthInfo>(v.data.depthInfoEncoding);
+  }
+  void setDepthInfo(ImageDepthInfo depthInfo) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.depthInfoEncoding = static_cast<unsigned>(depthInfo);
+    setSubclassData(v.storage);
+  }
+
+  ImageArrayedInfo getArrayedInfo() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<ImageArrayedInfo>(v.data.arrayedInfoEncoding);
+  }
+  void setArrayedInfo(ImageArrayedInfo arrayedInfo) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.arrayedInfoEncoding = static_cast<unsigned>(arrayedInfo);
+    setSubclassData(v.storage);
+  }
+
+  ImageSamplingInfo getSamplingInfo() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<ImageSamplingInfo>(v.data.samplingInfoEncoding);
+  }
+  void setSamplingInfo(ImageSamplingInfo samplingInfo) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.samplingInfoEncoding = static_cast<unsigned>(samplingInfo);
+    setSubclassData(v.storage);
+  }
+
+  ImageSamplerUseInfo getSamplerUseInfo() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<ImageSamplerUseInfo>(v.data.samplerUseInfoEncoding);
+  }
+  void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.samplerUseInfoEncoding = static_cast<unsigned>(samplerUseInfo);
+    setSubclassData(v.storage);
+  }
+
+  ImageFormat getImageFormat() const {
+    EnumPack v;
+    v.storage = getSubclassData();
+    return static_cast<ImageFormat>(v.data.formatEncoding);
+  }
+  void setImageFormat(ImageFormat format) {
+    EnumPack v;
+    v.storage = getSubclassData();
+    v.data.formatEncoding = static_cast<unsigned>(format);
+    setSubclassData(v.storage);
+  }
+
+  ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) {
+    static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()),
+                  "EnumPack size greater than subClassData type size");
+    setDim(std::get<1>(key));
+    setDepthInfo(std::get<2>(key));
+    setArrayedInfo(std::get<3>(key));
+    setSamplingInfo(std::get<4>(key));
+    setSamplerUseInfo(std::get<5>(key));
+    setImageFormat(std::get<6>(key));
+  }
+
+  Type elementType;
+};
+
+ImageType
+ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+                          ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
+                   value) {
+  return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
+}
+
+Type ImageType::getElementType() const { return getImpl()->elementType; }
+
+Dim ImageType::getDim() const { return getImpl()->getDim(); }
+
+ImageDepthInfo ImageType::getDepthInfo() const {
+  return getImpl()->getDepthInfo();
+}
+
+ImageArrayedInfo ImageType::getArrayedInfo() const {
+  return getImpl()->getArrayedInfo();
+}
+
+ImageSamplingInfo ImageType::getSamplingInfo() const {
+  return getImpl()->getSamplingInfo();
+}
+
+ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
+  return getImpl()->getSamplerUseInfo();
+}
+
+ImageFormat ImageType::getImageFormat() const {
+  return getImpl()->getImageFormat();
+}
+
+//===----------------------------------------------------------------------===//
+// PointerType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::PointerTypeStorage : public TypeStorage {
+  // (Type, StorageClass) as the key: Type stored in this struct, and
+  // StorageClass stored as TypeStorage's subclass data.
+  using KeyTy = std::pair<Type, StorageClass>;
+
+  static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
+                                       const KeyTy &key) {
+    return new (allocator.allocate<PointerTypeStorage>())
+        PointerTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(pointeeType, getStorageClass());
+  }
+
+  PointerTypeStorage(const KeyTy &key)
+      : TypeStorage(static_cast<unsigned>(key.second)), pointeeType(key.first) {
+  }
+
+  StorageClass getStorageClass() const {
+    return static_cast<StorageClass>(getSubclassData());
+  }
+
+  Type pointeeType;
+};
+
+PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
+  return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
+                   storageClass);
+}
+
+Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
+
+StorageClass PointerType::getStorageClass() const {
+  return getImpl()->getStorageClass();
+}
+
+//===----------------------------------------------------------------------===//
+// RuntimeArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
+  using KeyTy = Type;
+
+  static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    return new (allocator.allocate<RuntimeArrayTypeStorage>())
+        RuntimeArrayTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const { return elementType == key; }
+
+  RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {}
+
+  Type elementType;
+};
+
+RuntimeArrayType RuntimeArrayType::get(Type elementType) {
+  return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
+                   elementType);
+}
+
+Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
+
+//===----------------------------------------------------------------------===//
+// StructType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::StructTypeStorage : public TypeStorage {
+  StructTypeStorage(unsigned numMembers, Type const *memberTypes,
+                    StructType::LayoutInfo const *layoutInfo)
+      : TypeStorage(numMembers), memberTypes(memberTypes),
+        layoutInfo(layoutInfo) {}
+
+  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getMemberTypes(), getLayoutInfo());
+  }
+
+  static StructTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    ArrayRef<Type> keyTypes = key.first;
+
+    // Copy the member type and layout information into the bump pointer
+    auto typesList = allocator.copyInto(keyTypes).data();
+
+    const StructType::LayoutInfo *layoutInfoList = nullptr;
+    if (!key.second.empty()) {
+      ArrayRef<StructType::LayoutInfo> keyLayoutInfo = key.second;
+      assert(keyLayoutInfo.size() == keyTypes.size() &&
+             "size of layout information must be same as the size of number of "
+             "elements");
+      layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>())
+        StructTypeStorage(keyTypes.size(), typesList, layoutInfoList);
+  }
+
+  ArrayRef<Type> getMemberTypes() const {
+    return ArrayRef<Type>(memberTypes, getSubclassData());
+  }
+
+  ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+    if (layoutInfo) {
+      return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
+    }
+    return ArrayRef<StructType::LayoutInfo>(nullptr, size_t(0));
+  }
+
+  Type const *memberTypes;
+  StructType::LayoutInfo const *layoutInfo;
+};
+
+StructType StructType::get(ArrayRef<Type> memberTypes) {
+  assert(!memberTypes.empty() && "Struct needs at least one member type");
+  ArrayRef<StructType::LayoutInfo> noLayout(nullptr, size_t(0));
+  return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes,
+                   noLayout);
+}
+
+StructType StructType::get(ArrayRef<Type> memberTypes,
+                           ArrayRef<StructType::LayoutInfo> layoutInfo) {
+  assert(!memberTypes.empty() && "Struct needs at least one member type");
+  return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
+                   memberTypes, layoutInfo);
+}
+
+unsigned StructType::getNumElements() const {
+  return getImpl()->getSubclassData();
+}
+
+Type StructType::getElementType(unsigned index) const {
+  assert(
+      getNumElements() > index &&
+      "element index is more than number of members of the SPIR-V StructType");
+  return getImpl()->memberTypes[index];
+}
+
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
+
+uint64_t StructType::getOffset(unsigned index) const {
+  assert(
+      getNumElements() > index &&
+      "element index is more than number of members of the SPIR-V StructType");
+  return getImpl()->layoutInfo[index];
+}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/CMakeLists.txt b/third_party/mlir/lib/Dialect/SPIRV/Serialization/CMakeLists.txt
new file mode 100644
index 0000000..e652bf3
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_llvm_library(MLIRSPIRVSerialization
+  ConvertFromBinary.cpp
+  ConvertToBinary.cpp
+  Deserializer.cpp
+  Serializer.cpp
+  SPIRVBinaryUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
+  )
+
+add_dependencies(MLIRSPIRVSerialization
+  MLIRSPIRVSerializationGen)
+
+target_link_libraries(MLIRSPIRVSerialization
+  MLIRIR
+  MLIRSPIRV
+  MLIRSupport)
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp
new file mode 100644
index 0000000..38e8d93
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp
@@ -0,0 +1,95 @@
+//===- ConvertFromBinary.cpp - MLIR SPIR-V binary to module conversion ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a translation from SPIR-V binary module to MLIR SPIR-V
+// ModuleOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Translation.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+using namespace mlir;
+
+// Adds a one-block function named as `spirv_module` to `module` and returns the
+// block. The created block will be terminated by `std.return`.
+Block *createOneBlockFunction(Builder builder, ModuleOp module) {
+  auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
+  auto fn = FuncOp::create(builder.getUnknownLoc(), "spirv_module", fnType);
+  module.push_back(fn);
+
+  auto *block = fn.addEntryBlock();
+  OpBuilder(block).create<ReturnOp>(builder.getUnknownLoc());
+
+  return block;
+}
+
+// Deserializes the SPIR-V binary module stored in the file named as
+// `inputFilename` and returns a module containing the SPIR-V module.
+OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
+                                  MLIRContext *context) {
+  Builder builder(context);
+
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(context), errorMessage);
+    return {};
+  }
+
+  // Make sure the input stream can be treated as a stream of SPIR-V words
+  auto start = file->getBufferStart();
+  auto size = file->getBufferSize();
+  if (size % sizeof(uint32_t) != 0) {
+    emitError(UnknownLoc::get(context))
+        << "SPIR-V binary module must contain integral number of 32-bit words";
+    return {};
+  }
+
+  auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
+                                   size / sizeof(uint32_t));
+
+  auto spirvModule = spirv::deserialize(binary, context);
+  if (!spirvModule)
+    return {};
+
+  // TODO(antiagainst): due to the restriction of the current translation
+  // infrastructure, we must return a MLIR module here. So we are wrapping the
+  // converted SPIR-V ModuleOp inside a MLIR module. This should be changed to
+  // return the SPIR-V ModuleOp directly after module and function are migrated
+  // to be general ops.
+  OwningModuleRef module(ModuleOp::create(
+      FileLineColLoc::get(inputFilename, /*line=*/0, /*column=*/0, context)));
+  Block *block = createOneBlockFunction(builder, module.get());
+  block->push_front(spirvModule->getOperation());
+
+  return module;
+}
+
+static TranslateToMLIRRegistration
+    registration("deserialize-spirv",
+                 [](StringRef inputFilename, MLIRContext *context) {
+                   return deserializeModule(inputFilename, context);
+                 });
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertToBinary.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertToBinary.cpp
new file mode 100644
index 0000000..5e8c663
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/ConvertToBinary.cpp
@@ -0,0 +1,79 @@
+//===- ConvertToBinary.cpp - MLIR SPIR-V module to binary conversion ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a translation from MLIR SPIR-V ModuleOp to SPIR-V
+// binary module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Translation.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+LogicalResult serializeModule(ModuleOp module, StringRef outputFilename) {
+  if (!module)
+    return failure();
+
+  SmallVector<uint32_t, 0> binary;
+  bool done = false;
+  auto result = failure();
+
+  // TODO(antiagainst): we are checking there is only one SPIR-V ModuleOp in
+  // this module and serialize it. This is due to the restriction of the current
+  // translation infrastructure; we must take in a MLIR module here. So we are
+  // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
+  // to take in the SPIR-V ModuleOp directly after module and function are
+  // migrated to be general ops.
+  for (auto fn : module.getOps<FuncOp>()) {
+    fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
+      if (done) {
+        spirvModule.emitError("found more than one 'spv.module' op");
+        return;
+      }
+
+      done = true;
+      result = spirv::serialize(spirvModule, binary);
+    });
+  }
+
+  if (failed(result))
+    return failure();
+
+  auto file = openOutputFile(outputFilename);
+  if (!file)
+    return failure();
+
+  file->os().write(reinterpret_cast<char *>(binary.data()),
+                   binary.size() * sizeof(uint32_t));
+  file->keep();
+
+  return mlir::success();
+}
+
+static TranslateFromMLIRRegistration
+    registration("serialize-spirv",
+                 [](ModuleOp module, StringRef outputFilename) {
+                   return serializeModule(module, outputFilename);
+                 });
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
new file mode 100644
index 0000000..1fd9758
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -0,0 +1,997 @@
+//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the SPIR-V binary to MLIR SPIR-V module deseralization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Serialization.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/bit.h"
+
+using namespace mlir;
+
+// Decodes a string literal in `words` starting at `wordIndex`. Update the
+// latter to point to the position in words after the string literal.
+static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
+                                            unsigned &wordIndex) {
+  StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
+  wordIndex += str.size() / 4 + 1;
+  return str;
+}
+
+namespace {
+/// A SPIR-V module serializer.
+///
+/// A SPIR-V binary module is a single linear stream of instructions; each
+/// instruction is composed of 32-bit words. The first word of an instruction
+/// records the total number of words of that instruction using the 16
+/// higher-order bits. So this deserializer uses that to get instruction
+/// boundary and parse instructions and build a SPIR-V ModuleOp gradually.
+///
+// TODO(antiagainst): clean up created ops on errors
+class Deserializer {
+public:
+  /// Creates a deserializer for the given SPIR-V `binary` module.
+  /// The SPIR-V ModuleOp will be created into `context.
+  explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
+
+  /// Deserializes the remembered SPIR-V binary module.
+  LogicalResult deserialize();
+
+  /// Collects the final SPIR-V ModuleOp.
+  Optional<spirv::ModuleOp> collect();
+
+private:
+  //===--------------------------------------------------------------------===//
+  // Module structure
+  //===--------------------------------------------------------------------===//
+
+  /// Initializes the `module` ModuleOp in this deserializer instance.
+  spirv::ModuleOp createModuleOp();
+
+  /// Processes SPIR-V module header in `binary`.
+  LogicalResult processHeader();
+
+  /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
+  LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
+
+  /// Process SPIR-V OpName with `operands`
+  LogicalResult processName(ArrayRef<uint32_t> operands);
+
+  /// Method to process an OpDecorate instruction.
+  LogicalResult processDecoration(ArrayRef<uint32_t> words);
+
+  /// Processes the SPIR-V function at the current `offset` into `binary`.
+  /// The operands to the OpFunction instruction is passed in as ``operands`.
+  /// This method processes each instruction inside the function and dispatches
+  /// them to their handler method accordingly.
+  LogicalResult processFunction(ArrayRef<uint32_t> operands);
+
+  /// Get the FuncOp associated with a result <id> of OpFunction.
+  FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+
+  //===--------------------------------------------------------------------===//
+  // Type
+  //===--------------------------------------------------------------------===//
+
+  /// Gets type for a given result <id>.
+  Type getType(uint32_t id) { return typeMap.lookup(id); }
+
+  /// Returns true if the given `type` is for SPIR-V void type.
+  bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+
+  /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
+  /// registers the type into `module`.
+  LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
+
+  LogicalResult processArrayType(ArrayRef<uint32_t> operands);
+
+  LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+
+  //===--------------------------------------------------------------------===//
+  // Constant
+  //===--------------------------------------------------------------------===//
+
+  /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
+  /// `operands`. `isSpec` indicates whether this is a specialization constant.
+  LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
+
+  /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
+  /// given `operands`. `isSpec` indicates whether this is a specialization
+  /// constant.
+  LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
+                                    bool isSpec);
+
+  /// Processes a SPIR-V Op{|Spec}ConstantComposite instruction with the given
+  /// `operands`. `isSpec` indicates whether this is a specialization constant.
+  LogicalResult processConstantComposite(ArrayRef<uint32_t> operands,
+                                         bool isSpec);
+
+  /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
+  LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+
+  //===--------------------------------------------------------------------===//
+  // Instruction
+  //===--------------------------------------------------------------------===//
+
+  /// Get the Value associated with a result <id>.
+  Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+
+  /// Slices the first instruction out of `binary` and returns its opcode and
+  /// operands via `opcode` and `operands` respectively. Returns failure if
+  /// there is no more remaining instructions (`expectedOpcode` will be used to
+  /// compose the error message) or the next instruction is malformed.
+  LogicalResult
+  sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
+                   Optional<spirv::Opcode> expectedOpcode = llvm::None);
+
+  /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
+  /// This method is the main entrance for handling SPIR-V instruction; it
+  /// checks the instruction opcode and dispatches to the corresponding handler.
+  /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode)
+  /// might need to be defered, since they contain forward references to <id>s
+  /// in the deserialized binary, but module in SPIR-V dialect expects these to
+  /// be ssa-uses.
+  LogicalResult processInstruction(spirv::Opcode opcode,
+                                   ArrayRef<uint32_t> operands,
+                                   bool deferInstructions = true);
+
+  /// Method to dispatch to the specialized deserialization function for an
+  /// operation in SPIR-V dialect that is a mirror of an instruction in the
+  /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
+  /// all operations in SPIR-V dialect that have hasOpcode == 1.
+  LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
+                                                 ArrayRef<uint32_t> words);
+
+  /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
+  /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
+  /// == 1 and autogenSerialization == 1 in ODS.
+  template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
+    return emitError(unknownLoc, "unsupported deserialization for ")
+           << OpTy::getOperationName() << " op";
+  }
+
+private:
+  /// The SPIR-V binary module.
+  ArrayRef<uint32_t> binary;
+
+  /// The current word offset into the binary module.
+  unsigned curOffset = 0;
+
+  /// MLIRContext to create SPIR-V ModuleOp into.
+  MLIRContext *context;
+
+  // TODO(antiagainst): create Location subclass for binary blob
+  Location unknownLoc;
+
+  /// The SPIR-V ModuleOp.
+  Optional<spirv::ModuleOp> module;
+
+  OpBuilder opBuilder;
+
+  // Result <id> to type mapping.
+  DenseMap<uint32_t, Type> typeMap;
+
+  // Result <id> to function mapping.
+  DenseMap<uint32_t, FuncOp> funcMap;
+
+  // Result <id> to value mapping.
+  DenseMap<uint32_t, Value *> valueMap;
+
+  // Result <id> to name mapping.
+  DenseMap<uint32_t, StringRef> nameMap;
+
+  // Result <id> to decorations mapping.
+  DenseMap<uint32_t, NamedAttributeList> decorations;
+
+  // List of instructions that are processed in a defered fashion (after an
+  // initial processing of the entire binary). Some operations like
+  // OpEntryPoint, and OpExecutionMode use forward references to function
+  // <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and
+  // spv.ExecutionMode) need these references resolved. So these instructions
+  // are deserialized and stored for processing once the entire binary is
+  // processed.
+  SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4>
+      deferedInstructions;
+};
+} // namespace
+
+Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
+    : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
+      module(createModuleOp()),
+      opBuilder(module->getOperation()->getRegion(0)) {}
+
+LogicalResult Deserializer::deserialize() {
+  if (failed(processHeader()))
+    return failure();
+
+  spirv::Opcode opcode = spirv::Opcode::OpNop;
+  ArrayRef<uint32_t> operands;
+  auto binarySize = binary.size();
+  while (curOffset < binarySize) {
+    // Slice the next instruction out and populate `opcode` and `operands`.
+    // Interally this also updates `curOffset`.
+    if (failed(sliceInstruction(opcode, operands)))
+      return failure();
+
+    if (failed(processInstruction(opcode, operands)))
+      return failure();
+  }
+
+  assert(curOffset == binarySize &&
+         "deserializer should never index beyond the binary end");
+
+  for (auto &defered : deferedInstructions) {
+    if (failed(processInstruction(defered.first, defered.second, false))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
+
+//===----------------------------------------------------------------------===//
+// Module structure
+//===----------------------------------------------------------------------===//
+
+spirv::ModuleOp Deserializer::createModuleOp() {
+  Builder builder(context);
+  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
+  // TODO(antiagainst): use target environment to select the version
+  state.addAttribute("major_version", builder.getI32IntegerAttr(1));
+  state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
+  spirv::ModuleOp::build(&builder, &state);
+  return cast<spirv::ModuleOp>(Operation::create(state));
+}
+
+LogicalResult Deserializer::processHeader() {
+  if (binary.size() < spirv::kHeaderWordCount)
+    return emitError(unknownLoc,
+                     "SPIR-V binary module must have a 5-word header");
+
+  if (binary[0] != spirv::kMagicNumber)
+    return emitError(unknownLoc, "incorrect magic number");
+
+  // TODO(antiagainst): generator number, bound, schema
+  curOffset = spirv::kHeaderWordCount;
+  return success();
+}
+
+LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2)
+    return emitError(unknownLoc, "OpMemoryModel must have two operands");
+
+  module->setAttr(
+      "addressing_model",
+      opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
+  module->setAttr(
+      "memory_model",
+      opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
+
+  return success();
+}
+
+LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+  // TODO : This function should also be auto-generated. For now, since only a
+  // few decorations are processed/handled in a meaningful manner, going with a
+  // manual implementation.
+  if (words.size() < 2) {
+    return emitError(
+        unknownLoc, "OpDecorate must have at least result <id> and Decoration");
+  }
+  auto decorationName =
+      stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
+  if (decorationName.empty()) {
+    return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
+  }
+  auto attrName = convertToSnakeCase(decorationName);
+  switch (static_cast<spirv::Decoration>(words[1])) {
+  case spirv::Decoration::DescriptorSet:
+  case spirv::Decoration::Binding:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        opBuilder.getIdentifier(attrName),
+        opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
+    break;
+  default:
+    return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
+  }
+  return success();
+}
+
+LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
+  // Get the result type
+  if (operands.size() != 4) {
+    return emitError(unknownLoc, "OpFunction must have 4 parameters");
+  }
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+  if (funcMap.count(operands[1])) {
+    return emitError(unknownLoc, "duplicate function definition/declaration");
+  }
+  auto functionControl = spirv::symbolizeFunctionControl(operands[2]);
+  if (!functionControl) {
+    return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
+  }
+  if (functionControl.getValue() != spirv::FunctionControl::None) {
+    /// TODO : Handle different function controls
+    return emitError(unknownLoc, "unhandled Function Control: '")
+           << spirv::stringifyFunctionControl(functionControl.getValue())
+           << "'";
+  }
+  Type fnType = getType(operands[3]);
+  if (!fnType || !fnType.isa<FunctionType>()) {
+    return emitError(unknownLoc, "unknown function type from <id> ")
+           << operands[3];
+  }
+  auto functionType = fnType.cast<FunctionType>();
+  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
+      (functionType.getNumResults() == 1 &&
+       functionType.getResult(0) != resultType)) {
+    return emitError(unknownLoc, "mismatch in function type ")
+           << functionType << " and return type " << resultType << " specified";
+  }
+
+  std::string fnName = nameMap.lookup(operands[1]).str();
+  if (fnName.empty()) {
+    fnName = "spirv_fn_" + std::to_string(operands[2]);
+  }
+  auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
+                                         ArrayRef<NamedAttribute>());
+  funcMap[operands[1]] = funcOp;
+  funcOp.addEntryBlock();
+
+  // Parse the op argument instructions
+  if (functionType.getNumInputs()) {
+    for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
+      auto argType = functionType.getInput(i);
+      spirv::Opcode opcode = spirv::Opcode::OpNop;
+      ArrayRef<uint32_t> operands;
+      if (failed(sliceInstruction(opcode, operands,
+                                  spirv::Opcode::OpFunctionParameter))) {
+        return failure();
+      }
+      if (opcode != spirv::Opcode::OpFunctionParameter) {
+        return emitError(
+                   unknownLoc,
+                   "missing OpFunctionParameter instruction for argument ")
+               << i;
+      }
+      if (operands.size() != 2) {
+        return emitError(
+            unknownLoc,
+            "expected result type and result <id> for OpFunctionParameter");
+      }
+      auto argDefinedType = getType(operands[0]);
+      if (!argDefinedType || argDefinedType != argType) {
+        return emitError(unknownLoc,
+                         "mismatch in argument type between function type "
+                         "definition ")
+               << functionType << " and argument type definition "
+               << argDefinedType << " at argument " << i;
+      }
+      if (getValue(operands[1])) {
+        return emitError(unknownLoc, "duplicate definition of result <id> '")
+               << operands[1];
+      }
+      auto argValue = funcOp.getArgument(i);
+      valueMap[operands[1]] = argValue;
+    }
+  }
+
+  // Create a new builder for building the body
+  OpBuilder funcBody(funcOp.getBody());
+  std::swap(funcBody, opBuilder);
+
+  spirv::Opcode opcode = spirv::Opcode::OpNop;
+  ArrayRef<uint32_t> instOperands;
+  while (succeeded(sliceInstruction(opcode, instOperands,
+                                    spirv::Opcode::OpFunctionEnd)) &&
+         opcode != spirv::Opcode::OpFunctionEnd) {
+    if (failed(processInstruction(opcode, instOperands))) {
+      return failure();
+    }
+  }
+  if (opcode != spirv::Opcode::OpFunctionEnd) {
+    return failure();
+  }
+  if (!instOperands.empty()) {
+    return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
+  }
+  std::swap(funcBody, opBuilder);
+  return success();
+}
+
+LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpName needs at least 2 operands");
+  }
+  if (!nameMap.lookup(operands[0]).empty()) {
+    return emitError(unknownLoc, "duplicate name found for result <id> ")
+           << operands[0];
+  }
+  unsigned wordIndex = 1;
+  StringRef name = decodeStringLiteral(operands, wordIndex);
+  if (wordIndex != operands.size()) {
+    return emitError(unknownLoc,
+                     "unexpected trailing words in OpName instruction");
+  }
+  nameMap[operands[0]] = name;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processType(spirv::Opcode opcode,
+                                        ArrayRef<uint32_t> operands) {
+  if (operands.empty()) {
+    return emitError(unknownLoc, "type instruction with opcode ")
+           << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
+  }
+
+  /// TODO: Types might be forward declared in some instructions and need to be
+  /// handled appropriately.
+  if (typeMap.count(operands[0])) {
+    return emitError(unknownLoc, "duplicate definition for result <id> ")
+           << operands[0];
+  }
+
+  switch (opcode) {
+  case spirv::Opcode::OpTypeVoid:
+    if (operands.size() != 1) {
+      return emitError(unknownLoc, "OpTypeVoid must have no parameters");
+    }
+    typeMap[operands[0]] = opBuilder.getNoneType();
+    break;
+  case spirv::Opcode::OpTypeBool:
+    if (operands.size() != 1) {
+      return emitError(unknownLoc, "OpTypeBool must have no parameters");
+    }
+    typeMap[operands[0]] = opBuilder.getI1Type();
+    break;
+  case spirv::Opcode::OpTypeInt:
+    if (operands.size() != 3) {
+      return emitError(
+          unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
+    }
+    if (operands[2] == 0) {
+      return emitError(unknownLoc, "unhandled unsigned OpTypeInt");
+    }
+    typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
+    break;
+  case spirv::Opcode::OpTypeFloat: {
+    if (operands.size() != 2) {
+      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+    }
+    Type floatTy;
+    switch (operands[1]) {
+    case 16:
+      floatTy = opBuilder.getF16Type();
+      break;
+    case 32:
+      floatTy = opBuilder.getF32Type();
+      break;
+    case 64:
+      floatTy = opBuilder.getF64Type();
+      break;
+    default:
+      return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ")
+             << operands[1];
+    }
+    typeMap[operands[0]] = floatTy;
+  } break;
+  case spirv::Opcode::OpTypeVector: {
+    if (operands.size() != 3) {
+      return emitError(
+          unknownLoc,
+          "OpTypeVector must have element type and count parameters");
+    }
+    Type elementTy = getType(operands[1]);
+    if (!elementTy) {
+      return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
+             << operands[1];
+    }
+    typeMap[operands[0]] = opBuilder.getVectorType({operands[2]}, elementTy);
+  } break;
+  case spirv::Opcode::OpTypePointer: {
+    if (operands.size() != 3) {
+      return emitError(unknownLoc, "OpTypePointer must have two parameters");
+    }
+    auto pointeeType = getType(operands[2]);
+    if (!pointeeType) {
+      return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
+             << operands[2];
+    }
+    auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
+    typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
+  } break;
+  case spirv::Opcode::OpTypeArray:
+    return processArrayType(operands);
+  case spirv::Opcode::OpTypeFunction:
+    return processFunctionType(operands);
+  default:
+    return emitError(unknownLoc, "unhandled type instruction");
+  }
+  return success();
+}
+
+LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpTypeArray must have element type and count parameters");
+  }
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
+           << operands[1];
+  }
+
+  unsigned count = 0;
+  auto *countValue = getValue(operands[2]);
+  if (!countValue) {
+    return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
+           << operands[2];
+  }
+
+  auto *defOp = countValue->getDefiningOp();
+  if (auto constOp = dyn_cast<spirv::ConstantOp>(defOp)) {
+    if (auto intVal = constOp.value().dyn_cast<IntegerAttr>()) {
+      count = intVal.getInt();
+    } else {
+      return emitError(unknownLoc, "OpTypeArray count must come from a "
+                                   "scalar integer constant instruction");
+    }
+  } else {
+    return emitError(unknownLoc,
+                     "unsupported OpTypeArray count generated from ")
+           << defOp->getName();
+  }
+
+  typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count);
+  return success();
+}
+
+LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
+  assert(!operands.empty() && "No operands for processing function type");
+  if (operands.size() == 1) {
+    return emitError(unknownLoc, "missing return type for OpTypeFunction");
+  }
+  auto returnType = getType(operands[1]);
+  if (!returnType) {
+    return emitError(unknownLoc, "unknown return type in OpTypeFunction");
+  }
+  SmallVector<Type, 1> argTypes;
+  for (size_t i = 2, e = operands.size(); i < e; ++i) {
+    auto ty = getType(operands[i]);
+    if (!ty) {
+      return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
+    }
+    argTypes.push_back(ty);
+  }
+  ArrayRef<Type> returnTypes;
+  if (!isVoidType(returnType)) {
+    returnTypes = llvm::makeArrayRef(returnType);
+  }
+  typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Constant
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
+                                            bool isSpec) {
+  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
+
+  if (operands.size() < 2) {
+    return emitError(unknownLoc)
+           << opname << " must have type <id> and result <id>";
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc)
+           << opname << " must have at least 1 more parameter";
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
+    if (bitwidth == 64) {
+      if (operands.size() == 4) {
+        return success();
+      }
+      return emitError(unknownLoc)
+             << opname << " should have 2 parameters for 64-bit values";
+    }
+    if (bitwidth <= 32) {
+      if (operands.size() == 3) {
+        return success();
+      }
+
+      return emitError(unknownLoc)
+             << opname
+             << " should have 1 parameter for values with no more than 32 bits";
+    }
+    return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
+           << bitwidth;
+  };
+
+  spirv::ConstantOp op;
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
+  if (auto intType = resultType.dyn_cast<IntegerType>()) {
+    auto bitwidth = intType.getWidth();
+    if (failed(checkOperandSizeForBitwidth(bitwidth))) {
+      return failure();
+    }
+
+    APInt value;
+    if (bitwidth == 64) {
+      // 64-bit integers are represented with two SPIR-V words. According to
+      // SPIR-V spec: "When the type’s bit width is larger than one word, the
+      // literal’s low-order words appear first."
+      struct DoubleWord {
+        uint32_t word1;
+        uint32_t word2;
+      } words = {operands[2], operands[3]};
+      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+    } else if (bitwidth <= 32) {
+      value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+    }
+
+    auto attr = opBuilder.getIntegerAttr(intType, value);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr,
+                                             isSpecConst);
+  } else if (auto floatType = resultType.dyn_cast<FloatType>()) {
+    auto bitwidth = floatType.getWidth();
+    if (failed(checkOperandSizeForBitwidth(bitwidth))) {
+      return failure();
+    }
+
+    APFloat value(0.f);
+    if (floatType.isF64()) {
+      // Double values are represented with two SPIR-V words. According to
+      // SPIR-V spec: "When the type’s bit width is larger than one word, the
+      // literal’s low-order words appear first."
+      struct DoubleWord {
+        uint32_t word1;
+        uint32_t word2;
+      } words = {operands[2], operands[3]};
+      value = APFloat(llvm::bit_cast<double>(words));
+    } else if (floatType.isF32()) {
+      value = APFloat(llvm::bit_cast<float>(operands[2]));
+    } else if (floatType.isF16()) {
+      APInt data(16, operands[2]);
+      value = APFloat(APFloat::IEEEhalf(), data);
+    }
+
+    auto attr = opBuilder.getFloatAttr(floatType, value);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr,
+                                             isSpecConst);
+  } else {
+    return emitError(unknownLoc, "OpConstant can only generate values of "
+                                 "scalar integer or floating-point type");
+  }
+
+  valueMap[operands[1]] = op.getResult();
+  return success();
+}
+
+LogicalResult Deserializer::processConstantBool(bool isTrue,
+                                                ArrayRef<uint32_t> operands,
+                                                bool isSpec) {
+  if (operands.size() != 2) {
+    return emitError(unknownLoc, "Op")
+           << (isSpec ? "Spec" : "") << "Constant"
+           << (isTrue ? "True" : "False")
+           << " must have type <id> and result <id>";
+  }
+
+  auto attr = opBuilder.getBoolAttr(isTrue);
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
+  auto op = opBuilder.create<spirv::ConstantOp>(
+      unknownLoc, opBuilder.getI1Type(), attr, isSpecConst);
+
+  valueMap[operands[1]] = op.getResult();
+  return success();
+}
+
+LogicalResult
+Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
+                                       bool isSpec) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "OpConstantComposite must have type <id> and result <id>");
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc,
+                     "OpConstantComposite must have at least 1 parameter");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  SmallVector<Attribute, 4> elements;
+  elements.reserve(operands.size() - 2);
+  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
+    Value *value = getValue(operands[i]);
+    if (!value) {
+      return emitError(unknownLoc,
+                       "OpConstantComposite references undefined <id> ")
+             << operands[i];
+    }
+    auto *defOp = value->getDefiningOp();
+    if (auto elementOp = dyn_cast<spirv::ConstantOp>(defOp)) {
+      elements.push_back(elementOp.value());
+    } else {
+      return emitError(
+                 unknownLoc,
+                 "unsupported OpConstantComposite component generated from ")
+             << defOp->getName();
+    }
+  }
+
+  spirv::ConstantOp op;
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
+  if (auto vectorType = resultType.dyn_cast<VectorType>()) {
+    auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
+  } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
+    auto attr = opBuilder.getArrayAttr(elements);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
+  } else {
+    return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
+           << resultType;
+  }
+
+  valueMap[operands[1]] = op.getResult();
+  return success();
+}
+
+LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(unknownLoc,
+                     "OpConstantNull must have type <id> and result <id>");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  spirv::ConstantOp op;
+  if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
+      resultType.isa<VectorType>()) {
+    auto attr = opBuilder.getZeroAttr(resultType);
+    UnitAttr isSpecConst;
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
+  } else {
+    return emitError(unknownLoc, "unsupported OpConstantNull type: ")
+           << resultType;
+  }
+
+  valueMap[operands[1]] = op.getResult();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+Deserializer::sliceInstruction(spirv::Opcode &opcode,
+                               ArrayRef<uint32_t> &operands,
+                               Optional<spirv::Opcode> expectedOpcode) {
+  auto binarySize = binary.size();
+  if (curOffset >= binarySize) {
+    return emitError(unknownLoc, "expected ")
+           << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
+                              : "more")
+           << " instruction";
+  }
+
+  // For each instruction, get its word count from the first word to slice it
+  // from the stream properly, and then dispatch to the instruction handler.
+
+  uint32_t wordCount = binary[curOffset] >> 16;
+
+  if (wordCount == 0)
+    return emitError(unknownLoc, "word count cannot be zero");
+
+  uint32_t nextOffset = curOffset + wordCount;
+  if (nextOffset > binarySize)
+    return emitError(unknownLoc, "insufficient words for the last instruction");
+
+  opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
+  operands = binary.slice(curOffset + 1, wordCount - 1);
+  curOffset = nextOffset;
+  return success();
+}
+
+LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
+                                               ArrayRef<uint32_t> operands,
+                                               bool deferInstructions) {
+  // First dispatch all the instructions whose opcode does not correspond to
+  // those that have a direct mirror in the SPIR-V dialect
+  switch (opcode) {
+  case spirv::Opcode::OpMemoryModel:
+    return processMemoryModel(operands);
+  case spirv::Opcode::OpEntryPoint:
+  case spirv::Opcode::OpExecutionMode:
+    if (deferInstructions) {
+      deferedInstructions.emplace_back(opcode, operands);
+      return success();
+    }
+    break;
+  case spirv::Opcode::OpName:
+    return processName(operands);
+  case spirv::Opcode::OpTypeVoid:
+  case spirv::Opcode::OpTypeBool:
+  case spirv::Opcode::OpTypeInt:
+  case spirv::Opcode::OpTypeFloat:
+  case spirv::Opcode::OpTypeVector:
+  case spirv::Opcode::OpTypeArray:
+  case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypePointer:
+    return processType(opcode, operands);
+  case spirv::Opcode::OpConstant:
+    return processConstant(operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstant:
+    return processConstant(operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantComposite:
+    return processConstantComposite(operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processConstantComposite(operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantTrue:
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantTrue:
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantFalse:
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantFalse:
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantNull:
+    return processConstantNull(operands);
+  case spirv::Opcode::OpDecorate:
+    return processDecoration(operands);
+  case spirv::Opcode::OpFunction:
+    return processFunction(operands);
+  default:
+    break;
+  }
+  return dispatchToAutogenDeserialization(opcode, operands);
+}
+
+namespace {
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc,
+                     "missing Execution Model specification in OpEntryPoint");
+  }
+  auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing <id> in OpEntryPoint");
+  }
+  // Get the function <id>
+  auto fnID = words[wordIndex++];
+  // Get the function name
+  auto fnName = decodeStringLiteral(words, wordIndex);
+  // Verify that the function <id> matches the fnName
+  auto parsedFunc = getFunction(fnID);
+  if (!parsedFunc) {
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+  }
+  if (parsedFunc.getName() != fnName) {
+    return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
+                                 "and OpFunction with <id> ")
+           << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
+  }
+  SmallVector<Value *, 4> interface;
+  while (wordIndex < words.size()) {
+    auto arg = getValue(words[wordIndex]);
+    if (!arg) {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << words[wordIndex] << " while decoding OpEntryPoint";
+    }
+    interface.push_back(arg);
+    wordIndex++;
+  }
+  opBuilder.create<spirv::EntryPointOp>(
+      unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
+  return success();
+}
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc,
+                     "missing function result <id> in OpExecutionMode");
+  }
+  // Get the function <id> to get the name of the function
+  auto fnID = words[wordIndex++];
+  auto fn = getFunction(fnID);
+  if (!fn) {
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+  }
+  // Get the Execution mode
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+  }
+  auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+
+  // Get the values
+  SmallVector<Attribute, 4> attrListElems;
+  while (wordIndex < words.size()) {
+    attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
+  }
+  auto values = opBuilder.getArrayAttr(attrListElems);
+  opBuilder.create<spirv::ExecutionModeOp>(
+      unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
+  return success();
+}
+
+// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
+// various Deserializer::processOp<...>() specializations.
+#define GET_DESERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
+} // namespace
+
+Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
+                                             MLIRContext *context) {
+  Deserializer deserializer(binary, context);
+
+  if (failed(deserializer.deserialize()))
+    return llvm::None;
+
+  return deserializer.collect();
+}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
new file mode 100644
index 0000000..1e432b3
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
@@ -0,0 +1,53 @@
+//===- SPIRVBinaryUtils.cpp - MLIR SPIR-V Binary Module Utilities ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines common utilities for SPIR-V binary module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+
+using namespace mlir;
+
+void spirv::appendModuleHeader(SmallVectorImpl<uint32_t> &header,
+                               uint32_t idBound) {
+  // The major and minor version number for the generated SPIR-V binary.
+  // TODO(antiagainst): use target environment to select the version
+  constexpr uint8_t kMajorVersion = 1;
+  constexpr uint8_t kMinorVersion = 0;
+
+  // See "2.3. Physical Layout of a SPIR-V Module and Instruction" in the SPIR-V
+  // spec for the definition of the binary module header.
+  //
+  // The first five words of a SPIR-V module must be:
+  // +-------------------------------------------------------------------------+
+  // | Magic number                                                            |
+  // +-------------------------------------------------------------------------+
+  // | Version number (bytes: 0 | major number | minor number | 0)             |
+  // +-------------------------------------------------------------------------+
+  // | Generator magic number                                                  |
+  // +-------------------------------------------------------------------------+
+  // | Bound (all result <id>s in the module guaranteed to be less than it)    |
+  // +-------------------------------------------------------------------------+
+  // | 0 (reserved for instruction schema)                                     |
+  // +-------------------------------------------------------------------------+
+  header.push_back(spirv::kMagicNumber);
+  header.push_back((kMajorVersion << 16) | (kMinorVersion << 8));
+  header.push_back(kGeneratorNumber);
+  header.push_back(idBound); // <id> bound
+  header.push_back(0);       // Schema (reserved word)
+}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
new file mode 100644
index 0000000..188b08d
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -0,0 +1,975 @@
+//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the MLIR SPIR-V module to SPIR-V binary seralization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Serialization.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
+static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
+                                         spirv::Opcode opcode) {
+  assert(((wordCount >> 16) == 0) && "word count out of range!");
+  return (wordCount << 16) | static_cast<uint32_t>(opcode);
+}
+
+/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
+/// the given `binary` vector.
+static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
+                                           spirv::Opcode op,
+                                           ArrayRef<uint32_t> operands) {
+  uint32_t wordCount = 1 + operands.size();
+  binary.push_back(getPrefixedOpcode(wordCount, op));
+  if (!operands.empty()) {
+    binary.append(operands.begin(), operands.end());
+  }
+  return success();
+}
+
+/// Encodes an SPIR-V `literal` string into the given `binary` vector.
+static LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
+                                             StringRef literal) {
+  // We need to encode the literal and the null termination.
+  auto encodingSize = literal.size() / 4 + 1;
+  auto bufferStartSize = binary.size();
+  binary.resize(bufferStartSize + encodingSize, 0);
+  std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
+  return success();
+}
+
+namespace {
+
+/// A SPIR-V module serializer.
+///
+/// A SPIR-V binary module is a single linear stream of instructions; each
+/// instruction is composed of 32-bit words with the layout:
+///
+///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
+///   | <------ word -------> | <-- word --> | <-- word --> | ... |
+///
+/// For the first word, the 16 high-order bits are the word count of the
+/// instruction, the 16 low-order bits are the opcode enumerant. The
+/// instructions then belong to different sections, which must be laid out in
+/// the particular order as specified in "2.4 Logical Layout of a Module" of
+/// the SPIR-V spec.
+class Serializer {
+public:
+  /// Creates a serializer for the given SPIR-V `module`.
+  explicit Serializer(spirv::ModuleOp module);
+
+  /// Serializes the remembered SPIR-V module.
+  LogicalResult serialize();
+
+  /// Collects the final SPIR-V `binary`.
+  void collect(SmallVectorImpl<uint32_t> &binary);
+
+private:
+  // Note that there are two main categories of methods in this class:
+  // * process*() methods are meant to fully serialize a SPIR-V module entity
+  //   (header, type, op, etc.). They update internal vectors containing
+  //   different binary sections. They are not meant to be called except the
+  //   top-level serialization loop.
+  // * prepare*() methods are meant to be helpers that prepare for serializing
+  //   certain entity. They may or may not update internal vectors containing
+  //   different binary sections. They are meant to be called among themselves
+  //   or by other process*() methods for subtasks.
+
+  //===--------------------------------------------------------------------===//
+  // <id>
+  //===--------------------------------------------------------------------===//
+
+  // Note that it is illegal to use id <0> in SPIR-V binary module. Various
+  // methods in this class, if using SPIR-V word (uint32_t) as interface,
+  // check or return id <0> to indicate error in processing.
+
+  /// Consumes the next unused <id>. This method will never return 0.
+  uint32_t getNextID() { return nextID++; }
+
+  //===--------------------------------------------------------------------===//
+  // Module structure
+  //===--------------------------------------------------------------------===//
+
+  LogicalResult processMemoryModel();
+
+  LogicalResult processConstantOp(spirv::ConstantOp op);
+
+  uint32_t findFunctionID(StringRef fnName) const {
+    return funcIDMap.lookup(fnName);
+  }
+
+  /// Processes a SPIR-V function op.
+  LogicalResult processFuncOp(FuncOp op);
+
+  /// Process attributes that translate to decorations on the result <id>
+  LogicalResult processDecoration(Location loc, uint32_t resultID,
+                                  NamedAttribute attr);
+
+  //===--------------------------------------------------------------------===//
+  // Types
+  //===--------------------------------------------------------------------===//
+
+  uint32_t findTypeID(Type type) const { return typeIDMap.lookup(type); }
+
+  Type getVoidType() { return mlirBuilder.getNoneType(); }
+
+  bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+
+  /// Main dispatch method for serializing a type. The result <id> of the
+  /// serialized type will be returned as `typeID`.
+  LogicalResult processType(Location loc, Type type, uint32_t &typeID);
+
+  /// Method for preparing basic SPIR-V type serialization. Returns the type's
+  /// opcode and operands for the instruction via `typeEnum` and `operands`.
+  LogicalResult prepareBasicType(Location loc, Type type,
+                                 spirv::Opcode &typeEnum,
+                                 SmallVectorImpl<uint32_t> &operands);
+
+  LogicalResult prepareFunctionType(Location loc, FunctionType type,
+                                    spirv::Opcode &typeEnum,
+                                    SmallVectorImpl<uint32_t> &operands);
+
+  //===--------------------------------------------------------------------===//
+  // Constant
+  //===--------------------------------------------------------------------===//
+
+  uint32_t findConstantID(Attribute value) const {
+    return constIDMap.lookup(value);
+  }
+
+  /// Main dispatch method for processing a constant with the given `constType`
+  /// and `valueAttr`. `constType` is needed here because we can interpret the
+  /// `valueAttr` as a different type than the type of `valueAttr` itself; for
+  /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
+  /// constants. If `isSpec` is true, then the constant will be serialized as
+  /// a specialization constant.
+  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
+                           bool isSpec);
+
+  /// Prepares bool ElementsAttr serialization. This method updates `opcode`
+  /// with a proper OpConstant* instruction and pushes literal values for the
+  /// constant to `operands`.
+  LogicalResult prepareBoolVectorConstant(Location loc,
+                                          DenseIntElementsAttr elementsAttr,
+                                          bool isSpec, spirv::Opcode &opcode,
+                                          SmallVectorImpl<uint32_t> &operands);
+
+  /// Prepares int ElementsAttr serialization. This method updates `opcode` with
+  /// a proper OpConstant* instruction and pushes literal values for the
+  /// constant to `operands`.
+  LogicalResult prepareIntVectorConstant(Location loc,
+                                         DenseIntElementsAttr elementsAttr,
+                                         bool isSpec, spirv::Opcode &opcode,
+                                         SmallVectorImpl<uint32_t> &operands);
+
+  /// Prepares float ElementsAttr serialization. This method updates `opcode`
+  /// with a proper OpConstant* instruction and pushes literal values for the
+  /// constant to `operands`.
+  LogicalResult prepareFloatVectorConstant(Location loc,
+                                           DenseFPElementsAttr elementsAttr,
+                                           bool isSpec, spirv::Opcode &opcode,
+                                           SmallVectorImpl<uint32_t> &operands);
+
+  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
+
+  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
+
+  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
+
+  //===--------------------------------------------------------------------===//
+  // Operations
+  //===--------------------------------------------------------------------===//
+
+  uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
+
+  /// Main dispatch method for serializing an operation.
+  LogicalResult processOperation(Operation *op);
+
+  /// Method to dispatch to the serialization function for an operation in
+  /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
+  /// This is auto-generated from ODS. Dispatch is handled for all operations
+  /// in SPIR-V dialect that have hasOpcode == 1.
+  LogicalResult dispatchToAutogenSerialization(Operation *op);
+
+  /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
+  /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
+  /// 1 and autogenSerialization == 1 in ODS.
+  template <typename OpTy> LogicalResult processOp(OpTy op) {
+    return op.emitError("unsupported op serialization");
+  }
+
+private:
+  /// The SPIR-V module to be serialized.
+  spirv::ModuleOp module;
+
+  /// An MLIR builder for getting MLIR constructs.
+  mlir::Builder mlirBuilder;
+
+  /// The next available result <id>.
+  uint32_t nextID = 1;
+
+  // The following are for different SPIR-V instruction sections. They follow
+  // the logical layout of a SPIR-V module.
+
+  SmallVector<uint32_t, 4> capabilities;
+  SmallVector<uint32_t, 0> extensions;
+  SmallVector<uint32_t, 0> extendedSets;
+  SmallVector<uint32_t, 3> memoryModel;
+  SmallVector<uint32_t, 0> entryPoints;
+  SmallVector<uint32_t, 4> executionModes;
+  // TODO(antiagainst): debug instructions
+  SmallVector<uint32_t, 0> names;
+  SmallVector<uint32_t, 0> decorations;
+  SmallVector<uint32_t, 0> typesGlobalValues;
+  SmallVector<uint32_t, 0> functions;
+
+  /// Map from type used in SPIR-V module to their <id>s
+  DenseMap<Type, uint32_t> typeIDMap;
+
+  /// Map from constant values to their <id>s
+  DenseMap<Attribute, uint32_t> constIDMap;
+
+  /// Map from FuncOps name to <id>s.
+  llvm::StringMap<uint32_t> funcIDMap;
+
+  /// Map from results of normal operations to their <id>s
+  DenseMap<Value *, uint32_t> valueIDMap;
+};
+} // namespace
+
+Serializer::Serializer(spirv::ModuleOp module)
+    : module(module), mlirBuilder(module.getContext()) {}
+
+LogicalResult Serializer::serialize() {
+  if (failed(module.verify()))
+    return failure();
+
+  // TODO(antiagainst): handle the other sections
+  processMemoryModel();
+
+  // Iterate over the module body to serialze it. Assumptions are that there is
+  // only one basic block in the moduleOp
+  for (auto &op : module.getBlock()) {
+    if (failed(processOperation(&op))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
+  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
+                    extensions.size() + extendedSets.size() +
+                    memoryModel.size() + entryPoints.size() +
+                    executionModes.size() + decorations.size() +
+                    typesGlobalValues.size() + functions.size();
+
+  binary.clear();
+  binary.reserve(moduleSize);
+
+  spirv::appendModuleHeader(binary, nextID);
+  binary.append(capabilities.begin(), capabilities.end());
+  binary.append(extensions.begin(), extensions.end());
+  binary.append(extendedSets.begin(), extendedSets.end());
+  binary.append(memoryModel.begin(), memoryModel.end());
+  binary.append(entryPoints.begin(), entryPoints.end());
+  binary.append(executionModes.begin(), executionModes.end());
+  binary.append(names.begin(), names.end());
+  binary.append(decorations.begin(), decorations.end());
+  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
+  binary.append(functions.begin(), functions.end());
+}
+//===----------------------------------------------------------------------===//
+// Module structure
+//===----------------------------------------------------------------------===//
+
+LogicalResult Serializer::processMemoryModel() {
+  uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
+  uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();
+
+  return encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
+                               {am, mm});
+}
+
+LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
+  if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
+                                      op.is_spec_const())) {
+    valueIDMap[op.getResult()] = resultID;
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+                                            NamedAttribute attr) {
+  auto attrName = attr.first.strref();
+  auto decorationName = mlir::convertToCamelCase(attrName, true);
+  auto decoration = spirv::symbolizeDecoration(decorationName);
+  if (!decoration) {
+    return emitError(
+               loc, "non-argument attributes expected to have snake-case-ified "
+                    "decoration name, unhandled attribute with name : ")
+           << attrName;
+  }
+  SmallVector<uint32_t, 1> args;
+  args.push_back(resultID);
+  args.push_back(static_cast<uint32_t>(decoration.getValue()));
+  switch (decoration.getValue()) {
+  case spirv::Decoration::DescriptorSet:
+  case spirv::Decoration::Binding:
+    if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
+      args.push_back(intAttr.getValue().getZExtValue());
+      break;
+    }
+    return emitError(loc, "expected integer attribute for ") << attrName;
+  default:
+    return emitError(loc, "unhandled decoration ") << decorationName;
+  }
+  return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
+}
+
+LogicalResult Serializer::processFuncOp(FuncOp op) {
+  uint32_t fnTypeID = 0;
+  // Generate type of the function.
+  processType(op.getLoc(), op.getType(), fnTypeID);
+
+  // Add the function definition.
+  SmallVector<uint32_t, 4> operands;
+  uint32_t resTypeID = 0;
+  auto resultTypes = op.getType().getResults();
+  if (resultTypes.size() > 1) {
+    return emitError(op.getLoc(),
+                     "cannot serialize function with multiple return types");
+  }
+  if (failed(processType(op.getLoc(),
+                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
+                         resTypeID))) {
+    return failure();
+  }
+  operands.push_back(resTypeID);
+  auto funcID = getNextID();
+  funcIDMap[op.getName()] = funcID;
+  operands.push_back(funcID);
+  // TODO : Support other function control options.
+  operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
+  operands.push_back(fnTypeID);
+  encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
+
+  // Add function name.
+  SmallVector<uint32_t, 4> nameOperands;
+  nameOperands.push_back(funcID);
+  encodeStringLiteralInto(nameOperands, op.getName());
+  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+
+  // Declare the parameters.
+  for (auto arg : op.getArguments()) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
+      return failure();
+    }
+    auto argValueID = getNextID();
+    valueIDMap[arg] = argValueID;
+    encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter,
+                          {argTypeID, argValueID});
+  }
+
+  // Process the body.
+  if (op.isExternal()) {
+    return emitError(op.getLoc(), "external function is unhandled");
+  }
+
+  for (auto &b : op) {
+    for (auto &op : b) {
+      if (failed(processOperation(&op))) {
+        return failure();
+      }
+    }
+  }
+
+  // Insert Function End.
+  return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
+}
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+LogicalResult Serializer::processType(Location loc, Type type,
+                                      uint32_t &typeID) {
+  typeID = findTypeID(type);
+  if (typeID) {
+    return success();
+  }
+  typeID = getNextID();
+  SmallVector<uint32_t, 4> operands;
+  operands.push_back(typeID);
+  auto typeEnum = spirv::Opcode::OpTypeVoid;
+  if ((type.isa<FunctionType>() &&
+       succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
+                                     operands))) ||
+      succeeded(prepareBasicType(loc, type, typeEnum, operands))) {
+    typeIDMap[type] = typeID;
+    return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
+  }
+  return failure();
+}
+
+LogicalResult
+Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
+                             SmallVectorImpl<uint32_t> &operands) {
+  if (isVoidType(type)) {
+    typeEnum = spirv::Opcode::OpTypeVoid;
+    return success();
+  }
+
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    if (intType.getWidth() == 1) {
+      typeEnum = spirv::Opcode::OpTypeBool;
+      return success();
+    }
+
+    typeEnum = spirv::Opcode::OpTypeInt;
+    operands.push_back(intType.getWidth());
+    // TODO(antiagainst): support unsigned integers
+    operands.push_back(1);
+    return success();
+  }
+
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    typeEnum = spirv::Opcode::OpTypeFloat;
+    operands.push_back(floatType.getWidth());
+    return success();
+  }
+
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    uint32_t elementTypeID = 0;
+    if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeVector;
+    operands.push_back(elementTypeID);
+    operands.push_back(vectorType.getNumElements());
+    return success();
+  }
+
+  if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+    typeEnum = spirv::Opcode::OpTypeArray;
+    uint32_t elementTypeID = 0;
+    if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) {
+      return failure();
+    }
+    operands.push_back(elementTypeID);
+    if (auto elementCountID = prepareConstantInt(
+            loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
+            /*isSpec=*/false)) {
+      operands.push_back(elementCountID);
+      return success();
+    }
+    return failure();
+  }
+
+  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+    uint32_t pointeeTypeID = 0;
+    if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypePointer;
+    operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
+    operands.push_back(pointeeTypeID);
+    return success();
+  }
+
+  // TODO(ravishankarm) : Handle other types.
+  return emitError(loc, "unhandled type in serialization: ") << type;
+}
+
+LogicalResult
+Serializer::prepareFunctionType(Location loc, FunctionType type,
+                                spirv::Opcode &typeEnum,
+                                SmallVectorImpl<uint32_t> &operands) {
+  typeEnum = spirv::Opcode::OpTypeFunction;
+  assert(type.getNumResults() <= 1 &&
+         "Serialization supports only a single return value");
+  uint32_t resultID = 0;
+  if (failed(processType(
+          loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
+          resultID))) {
+    return failure();
+  }
+  operands.push_back(resultID);
+  for (auto &res : type.getInputs()) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(loc, res, argTypeID))) {
+      return failure();
+    }
+    operands.push_back(argTypeID);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Constant
+//===----------------------------------------------------------------------===//
+
+uint32_t Serializer::prepareConstant(Location loc, Type constType,
+                                     Attribute valueAttr, bool isSpec) {
+  if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
+    return prepareConstantFp(loc, floatAttr, isSpec);
+  }
+  if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
+    return prepareConstantInt(loc, intAttr, isSpec);
+  }
+  if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
+    return prepareConstantBool(loc, boolAttr, isSpec);
+  }
+
+  // This is a composite literal. We need to handle each component separately
+  // and then emit an OpConstantComposite for the whole.
+
+  if (auto id = findConstantID(valueAttr)) {
+    return id;
+  }
+
+  uint32_t typeID = 0;
+  if (failed(processType(loc, constType, typeID))) {
+    return 0;
+  }
+  auto resultID = getNextID();
+
+  spirv::Opcode opcode = spirv::Opcode::OpNop;
+  SmallVector<uint32_t, 4> operands;
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+
+  if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
+    if (vectorAttr.getType().getElementType().isInteger(1)) {
+      if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                           operands)))
+        return 0;
+    } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                               operands)))
+      return 0;
+  } else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
+    if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                          operands)))
+      return 0;
+  } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
+    opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                    : spirv::Opcode::OpConstantComposite;
+    operands.reserve(arrayAttr.size() + 2);
+
+    auto elementType = constType.cast<spirv::ArrayType>().getElementType();
+    for (Attribute elementAttr : arrayAttr)
+      if (auto elementID =
+              prepareConstant(loc, elementType, elementAttr, isSpec)) {
+        operands.push_back(elementID);
+      } else {
+        return 0;
+      }
+  } else {
+    emitError(loc, "cannot serialize attribute: ") << valueAttr;
+    return 0;
+  }
+
+  encodeInstructionInto(typesGlobalValues, opcode, operands);
+  constIDMap[valueAttr] = resultID;
+  return resultID;
+}
+
+LogicalResult Serializer::prepareBoolVectorConstant(
+    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+  auto type = elementsAttr.getType();
+  assert(type.hasRank() && type.getRank() == 1 &&
+         "spv.constant should have verified only vector literal uses "
+         "ElementsAttr");
+  assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr");
+  auto count = type.getNumElements();
+
+  // Operands for constructing the SPIR-V OpConstant* instruction
+  operands.reserve(count + 2);
+
+  // For splat cases, we don't need to loop over all elements, especially when
+  // the splat value is zero.
+  if (Attribute splatAttr = elementsAttr.getSplatValue()) {
+    // We can use OpConstantNull if this bool ElementsAttr is splatting false.
+    if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) {
+      opcode = spirv::Opcode::OpConstantNull;
+      return success();
+    }
+
+    if (auto id =
+            prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
+      operands.append(count, id);
+      return success();
+    }
+
+    return failure();
+  }
+
+  // Otherwise, we need to process each element and compose them with
+  // OpConstantComposite.
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
+  for (APInt intValue : elementsAttr) {
+    // We are constructing an BoolAttr for each APInt here. But given that
+    // we only use ElementsAttr for vectors with no more than 4 elements, it
+    // should be fine here.
+    auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue());
+    if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
+      operands.push_back(elementID);
+    } else {
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult Serializer::prepareIntVectorConstant(
+    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+  auto type = elementsAttr.getType();
+  assert(type.hasRank() && type.getRank() == 1 &&
+         "spv.constant should have verified only vector literal uses "
+         "ElementsAttr");
+  auto elementType = type.getElementType();
+  assert(!elementType.isInteger(1) && "must be non-bool ElementsAttr");
+  auto count = type.getNumElements();
+
+  // Operands for constructing the SPIR-V OpConstant* instruction
+  operands.reserve(count + 2);
+
+  // For splat cases, we don't need to loop over all elements, especially when
+  // the splat value is zero.
+  if (Attribute splatAttr = elementsAttr.getSplatValue()) {
+    // We can use OpConstantNull if this int ElementsAttr is splatting 0.
+    if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
+      opcode = spirv::Opcode::OpConstantNull;
+      return success();
+    }
+
+    if (auto id =
+            prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
+      operands.append(count, id);
+      return success();
+    }
+    return failure();
+  }
+
+  // Otherwise, we need to process each element and compose them with
+  // OpConstantComposite.
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
+  for (APInt intValue : elementsAttr) {
+    // We are constructing an IntegerAttr for each APInt here. But given that
+    // we only use ElementsAttr for vectors with no more than 4 elements, it
+    // should be fine here.
+    // TODO(antiagainst): revisit this if special extensions enabling large
+    // vectors are supported.
+    auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue);
+    if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
+      operands.push_back(elementID);
+    } else {
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult Serializer::prepareFloatVectorConstant(
+    Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+  auto type = elementsAttr.getType();
+  assert(type.hasRank() && type.getRank() == 1 &&
+         "spv.constant should have verified only vector literal uses "
+         "ElementsAttr");
+  auto count = type.getNumElements();
+  auto elementType = type.getElementType();
+
+  operands.reserve(count + 2);
+
+  if (Attribute splatAttr = elementsAttr.getSplatValue()) {
+    if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) {
+      opcode = spirv::Opcode::OpConstantNull;
+      return success();
+    }
+
+    if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
+      operands.append(count, id);
+      return success();
+    }
+
+    return failure();
+  }
+
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
+  for (APFloat floatValue : elementsAttr) {
+    auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue);
+    if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
+      operands.push_back(elementID);
+    } else {
+      return failure();
+    }
+  }
+  return success();
+}
+
+uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
+                                         bool isSpec) {
+  if (auto id = findConstantID(boolAttr)) {
+    return id;
+  }
+
+  // Process the type for this bool literal
+  uint32_t typeID = 0;
+  if (failed(processType(loc, boolAttr.getType(), typeID))) {
+    return 0;
+  }
+
+  auto resultID = getNextID();
+  auto opcode = boolAttr.getValue()
+                    ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
+                              : spirv::Opcode::OpConstantTrue)
+                    : (isSpec ? spirv::Opcode::OpSpecConstantFalse
+                              : spirv::Opcode::OpConstantFalse);
+  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
+
+  return constIDMap[boolAttr] = resultID;
+}
+
+uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
+                                        bool isSpec) {
+  if (auto id = findConstantID(intAttr)) {
+    return id;
+  }
+
+  // Process the type for this integer literal
+  uint32_t typeID = 0;
+  if (failed(processType(loc, intAttr.getType(), typeID))) {
+    return 0;
+  }
+
+  auto resultID = getNextID();
+  APInt value = intAttr.getValue();
+  unsigned bitwidth = value.getBitWidth();
+  bool isSigned = value.isSignedIntN(bitwidth);
+
+  auto opcode =
+      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
+  // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
+  // the literal's value appears in the low-order bits of the word, and the
+  // high-order bits must be 0 for a floating-point type, or 0 for an integer
+  // type with Signedness of 0, or sign extended when Signedness is 1."
+  if (bitwidth == 32 || bitwidth == 16) {
+    uint32_t word = 0;
+    if (isSigned) {
+      word = static_cast<int32_t>(value.getSExtValue());
+    } else {
+      word = static_cast<uint32_t>(value.getZExtValue());
+    }
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
+  }
+  // According to SPIR-V spec: "When the type's bit width is larger than one
+  // word, the literal’s low-order words appear first."
+  else if (bitwidth == 64) {
+    struct DoubleWord {
+      uint32_t word1;
+      uint32_t word2;
+    } words;
+    if (isSigned) {
+      words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
+    } else {
+      words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
+    }
+    encodeInstructionInto(typesGlobalValues, opcode,
+                          {typeID, resultID, words.word1, words.word2});
+  } else {
+    std::string valueStr;
+    llvm::raw_string_ostream rss(valueStr);
+    value.print(rss, /*isSigned*/ false);
+
+    emitError(loc, "cannot serialize ")
+        << bitwidth << "-bit integer literal: " << rss.str();
+    return 0;
+  }
+
+  return constIDMap[intAttr] = resultID;
+}
+
+uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
+                                       bool isSpec) {
+  if (auto id = findConstantID(floatAttr)) {
+    return id;
+  }
+
+  // Process the type for this float literal
+  uint32_t typeID = 0;
+  if (failed(processType(loc, floatAttr.getType(), typeID))) {
+    return 0;
+  }
+
+  auto resultID = getNextID();
+  APFloat value = floatAttr.getValue();
+  APInt intValue = value.bitcastToAPInt();
+
+  auto opcode =
+      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
+  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+    uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
+  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+    struct DoubleWord {
+      uint32_t word1;
+      uint32_t word2;
+    } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
+    encodeInstructionInto(typesGlobalValues, opcode,
+                          {typeID, resultID, words.word1, words.word2});
+  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+    uint32_t word =
+        static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
+  } else {
+    std::string valueStr;
+    llvm::raw_string_ostream rss(valueStr);
+    value.print(rss);
+
+    emitError(loc, "cannot serialize ")
+        << floatAttr.getType() << "-typed float literal: " << rss.str();
+    return 0;
+  }
+
+  return constIDMap[floatAttr] = resultID;
+}
+
+//===----------------------------------------------------------------------===//
+// Operation
+//===----------------------------------------------------------------------===//
+
+LogicalResult Serializer::processOperation(Operation *op) {
+  // First dispatch the methods that do not directly mirror an operation from
+  // the SPIR-V spec
+  if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) {
+    return processConstantOp(constOp);
+  }
+  if (auto fnOp = dyn_cast<FuncOp>(op)) {
+    return processFuncOp(fnOp);
+  }
+  if (isa<spirv::ModuleEndOp>(op)) {
+    return success();
+  }
+  return dispatchToAutogenSerialization(op);
+}
+
+namespace {
+template <>
+LogicalResult
+Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
+  SmallVector<uint32_t, 4> operands;
+  // Add the ExectionModel.
+  operands.push_back(static_cast<uint32_t>(op.execution_model()));
+  // Add the function <id>.
+  auto funcID = findFunctionID(op.fn());
+  if (!funcID) {
+    return op.emitError("missing <id> for function ")
+           << op.fn()
+           << "; function needs to be defined before spv.EntryPoint is "
+              "serialized";
+  }
+  operands.push_back(funcID);
+  // Add the name of the function.
+  encodeStringLiteralInto(operands, op.fn());
+
+  // Add the interface values.
+  for (auto val : op.interface()) {
+    auto id = findValueID(val);
+    if (!id) {
+      return op.emitError("referencing unintialized variable <id>. "
+                          "spv.EntryPoint is at the end of spv.module. All "
+                          "referenced variables should already be defined");
+    }
+    operands.push_back(id);
+  }
+  return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
+                               operands);
+}
+
+template <>
+LogicalResult
+Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
+  SmallVector<uint32_t, 4> operands;
+  // Add the function <id>.
+  auto funcID = findFunctionID(op.fn());
+  if (!funcID) {
+    return op.emitError("missing <id> for function ")
+           << op.fn()
+           << "; function needs to be serialized before ExecutionModeOp is "
+              "serialized";
+  }
+  operands.push_back(funcID);
+  // Add the ExecutionMode.
+  operands.push_back(static_cast<uint32_t>(op.execution_mode()));
+
+  // Serialize values if any.
+  auto values = op.values();
+  if (values) {
+    for (auto &intVal : values.getValue()) {
+      operands.push_back(static_cast<uint32_t>(
+          intVal.cast<IntegerAttr>().getValue().getZExtValue()));
+    }
+  }
+  return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
+                               operands);
+}
+
+// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
+// various Serializer::processOp<...>() specializations.
+#define GET_SERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
+} // namespace
+
+LogicalResult spirv::serialize(spirv::ModuleOp module,
+                               SmallVectorImpl<uint32_t> &binary) {
+  Serializer serializer(module);
+
+  if (failed(serializer.serialize()))
+    return failure();
+
+  serializer.collect(binary);
+  return success();
+}
diff --git a/third_party/mlir/lib/Dialect/Traits.cpp b/third_party/mlir/lib/Dialect/Traits.cpp
new file mode 100644
index 0000000..9945b6a
--- /dev/null
+++ b/third_party/mlir/lib/Dialect/Traits.cpp
@@ -0,0 +1,221 @@
+//===- Traits.cpp - Common op traits shared by dialects -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
+                                        ArrayRef<int64_t> shape2,
+                                        SmallVectorImpl<int64_t> &resultShape) {
+  // To compute the result broadcasted shape, we compare operand shapes
+  // element-wise: starting with the trailing dimensions, and working the
+  // way backward. Two dimensions are compatible when
+  //   1. they are equal, or
+  //   2. one of them is 1
+  // The result shape has the maximum among the two inputs at every
+  // dimension index.
+
+  resultShape.clear();
+  if (shape1.size() > shape2.size()) {
+    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
+  } else {
+    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
+  }
+
+  auto i1 = shape1.rbegin(), e1 = shape1.rend();
+  auto i2 = shape2.rbegin(), e2 = shape2.rend();
+  auto iR = resultShape.rbegin();
+
+  // Check each dimension is consistent.
+  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
+    if (*i1 == -1 || *i2 == -1) {
+      // One or both dimensions is unknown. Follow TensorFlow behavior:
+      // - If either dimension is greater than 1, we assume that the program is
+      //   correct, and the other dimension will be broadcast to match it.
+      // - If either dimension is 1, the other dimension is the output.
+      if (*i1 > 1) {
+        *iR = *i1;
+      } else if (*i2 > 1) {
+        *iR = *i2;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else if (*i2 == 1) {
+        *iR = *i1;
+      } else {
+        *iR = -1;
+      }
+    } else {
+      if (*i1 == *i2 || *i2 == 1) {
+        *iR = *i1;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else {
+        // This dimension of the two operand types is incompatible.
+        resultShape.clear();
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+/// Returns the shape of the given type. Scalars will be considered as having a
+/// shape with zero dimensions.
+static ArrayRef<int64_t> getShape(Type type) {
+  if (auto sType = type.dyn_cast<ShapedType>())
+    return sType.getShape();
+  return {};
+}
+
+/// Returns the result broadcast composition type from the two given types by
+/// following NumPy broadcast semantics. Returned type may have dynamic shape if
+/// either of the input types has dynamic shape. Returns null type if the two
+/// given types are not broadcast-compatible.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
+  // Returns the scalar type out of the given type.
+  auto getScalarType = [](Type type) -> Type {
+    if (auto shapedType = type.dyn_cast<ShapedType>())
+      return shapedType.getElementType();
+    return type;
+  };
+
+  // Make sure underlying scalar type is the same.
+  auto scalarType = getScalarType(type1);
+  if (scalarType != getScalarType(type2))
+    return {};
+
+  // If one of the types is unranked tensor, then the other type shouldn't be
+  // vector and the result should have unranked tensor type.
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (type1.isa<VectorType>() || type2.isa<VectorType>())
+      return {};
+    return UnrankedTensorType::get(scalarType);
+  }
+
+  // Returns the type kind if the given type is a vector or ranked tensor type.
+  // Returns llvm::None otherwise.
+  auto getCompositeTypeKind =
+      [](Type type) -> llvm::Optional<StandardTypes::Kind> {
+    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
+      return static_cast<StandardTypes::Kind>(type.getKind());
+    return llvm::None;
+  };
+
+  // Make sure the composite type, if has, is consistent.
+  auto compositeKind1 = getCompositeTypeKind(type1);
+  auto compositeKind2 = getCompositeTypeKind(type2);
+  llvm::Optional<StandardTypes::Kind> resultCompositeKind;
+
+  if (compositeKind1 && compositeKind2) {
+    // Disallow mixing vector and tensor.
+    if (compositeKind1 != compositeKind2)
+      return {};
+    resultCompositeKind = compositeKind1;
+  } else if (compositeKind1) {
+    resultCompositeKind = compositeKind1;
+  } else if (compositeKind2) {
+    resultCompositeKind = compositeKind2;
+  }
+
+  // Get the shape of each type.
+  SmallVector<int64_t, 4> resultShape;
+  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return {};
+
+  // Compose the final broadcasted type
+  if (resultCompositeKind == StandardTypes::Vector)
+    return VectorType::get(resultShape, scalarType);
+  if (resultCompositeKind == StandardTypes::RankedTensor)
+    return RankedTensorType::get(resultShape, scalarType);
+  return scalarType;
+}
+
+/// Returns true if the given types has both vector types and tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
+}
+
+static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
+                                ArrayRef<int64_t> shape2) {
+  auto isCompatible = [](int64_t dim1, int64_t dim2) {
+    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
+  };
+  if (shape1.size() != shape2.size())
+    return false;
+  for (const auto &p : llvm::zip(shape1, shape2))
+    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+      return false;
+  return true;
+}
+
+LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
+  assert(op->getNumOperands() == 2 &&
+         "only support broadcast check on two operands");
+  assert(op->getNumResults() == 1 &&
+         "only support broadcast check on one result");
+
+  auto type1 = op->getOperand(0)->getType();
+  auto type2 = op->getOperand(1)->getType();
+  auto retType = op->getResult(0)->getType();
+
+  // We forbid broadcasting vector and tensor.
+  if (hasBothVectorAndTensorType({type1, type2, retType}))
+    return op->emitError("cannot broadcast vector with tensor");
+
+  if (retType.isa<UnrankedTensorType>())
+    return success();
+
+  bool isUnranked1 = type1.isa<UnrankedTensorType>();
+  bool isUnranked2 = type2.isa<UnrankedTensorType>();
+
+  // If both operands are unranked, then all result shapes are possible.
+  if (isUnranked1 && isUnranked2)
+    return success();
+
+  // If one of the operands is unranked, then the known dimensions in the result
+  // should be compatible with the other shaped operand.
+  if (isUnranked1 || isUnranked2) {
+    // Result should have higher rank than the shaped operand's rank and then
+    // the result's trailing dimensions should be compatible with the operand
+    // shape.
+    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
+    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
+    if (!areCompatibleShapes(actualSuffix, shape))
+      return op->emitOpError()
+             << "result type " << retType
+             << " has shape incompatible with a ranked operand type";
+    return success();
+  }
+
+  // If both operands are shaped, then the computed broadcasted shape should be
+  // compatible with the result shape.
+  SmallVector<int64_t, 4> resultShape;
+  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+  if (!areCompatibleShapes(resultShape, getShape(retType)))
+    return op->emitOpError() << "result type " << retType
+                             << " does not have shape compatible with the one "
+                                "computed from the operand types";
+
+  return success();
+}
diff --git a/third_party/mlir/lib/EDSC/Builders.cpp b/third_party/mlir/lib/EDSC/Builders.cpp
new file mode 100644
index 0000000..d524900
--- /dev/null
+++ b/third_party/mlir/lib/EDSC/Builders.cpp
@@ -0,0 +1,459 @@
+//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/StandardOps/Ops.h"
+
+#include "llvm/ADT/Optional.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+
+mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location)
+    : builder(builder), location(location),
+      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
+      nestedBuilder(nullptr) {
+  getCurrentScopedContext() = this;
+}
+
+/// Sets the insertion point of the builder to 'newInsertPt' for the duration
+/// of the scope. The existing insertion point of the builder is restored on
+/// destruction.
+mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder,
+                                         OpBuilder::InsertPoint newInsertPt,
+                                         Location location)
+    : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()),
+      location(location),
+      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
+      nestedBuilder(nullptr) {
+  getCurrentScopedContext() = this;
+  builder.restoreInsertionPoint(newInsertPt);
+}
+
+mlir::edsc::ScopedContext::~ScopedContext() {
+  assert(!nestedBuilder &&
+         "Active NestedBuilder must have been exited at this point!");
+  if (prevBuilderInsertPoint)
+    builder.restoreInsertionPoint(*prevBuilderInsertPoint);
+  getCurrentScopedContext() = enclosingScopedContext;
+}
+
+ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() {
+  thread_local ScopedContext *context = nullptr;
+  return context;
+}
+
+OpBuilder &mlir::edsc::ScopedContext::getBuilder() {
+  assert(ScopedContext::getCurrentScopedContext() &&
+         "Unexpected Null ScopedContext");
+  return ScopedContext::getCurrentScopedContext()->builder;
+}
+
+Location mlir::edsc::ScopedContext::getLocation() {
+  assert(ScopedContext::getCurrentScopedContext() &&
+         "Unexpected Null ScopedContext");
+  return ScopedContext::getCurrentScopedContext()->location;
+}
+
+MLIRContext *mlir::edsc::ScopedContext::getContext() {
+  return getBuilder().getContext();
+}
+
+mlir::edsc::ValueHandle::ValueHandle(index_t cst) {
+  auto &b = ScopedContext::getBuilder();
+  auto loc = ScopedContext::getLocation();
+  v = b.create<ConstantIndexOp>(loc, cst.v).getResult();
+  t = v->getType();
+}
+
+ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
+  assert(t == other.t && "Wrong type capture");
+  assert(!v && "ValueHandle has already been captured, use a new name!");
+  v = other.v;
+  return *this;
+}
+
+ValueHandle
+mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map,
+                                                   ArrayRef<Value *> operands) {
+  Operation *op =
+      makeComposedAffineApply(ScopedContext::getBuilder(),
+                              ScopedContext::getLocation(), map, operands)
+          .getOperation();
+  assert(op->getNumResults() == 1 && "Not a single result AffineApply");
+  return ValueHandle(op->getResult(0));
+}
+
+ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
+                                ArrayRef<Type> resultTypes,
+                                ArrayRef<NamedAttribute> attributes) {
+  Operation *op =
+      OperationHandle::create(name, operands, resultTypes, attributes);
+  if (op->getNumResults() == 1) {
+    return ValueHandle(op->getResult(0));
+  }
+  if (auto f = dyn_cast<AffineForOp>(op)) {
+    return ValueHandle(f.getInductionVar());
+  }
+  llvm_unreachable("unsupported operation, use an OperationHandle instead");
+}
+
+OperationHandle OperationHandle::create(StringRef name,
+                                        ArrayRef<ValueHandle> operands,
+                                        ArrayRef<Type> resultTypes,
+                                        ArrayRef<NamedAttribute> attributes) {
+  OperationState state(ScopedContext::getLocation(), name);
+  SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+  state.addOperands(ops);
+  state.addTypes(resultTypes);
+  for (const auto &attr : attributes) {
+    state.addAttribute(attr.first, attr.second);
+  }
+  return OperationHandle(ScopedContext::getBuilder().createOperation(state));
+}
+
+BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
+  auto &currentB = ScopedContext::getBuilder();
+  auto *ib = currentB.getInsertionBlock();
+  auto ip = currentB.getInsertionPoint();
+  BlockHandle res;
+  res.block = ScopedContext::getBuilder().createBlock(ib->getParent());
+  // createBlock sets the insertion point inside the block.
+  // We do not want this behavior when using declarative builders with nesting.
+  currentB.setInsertionPoint(ib, ip);
+  for (auto t : argTypes) {
+    res.block->addArgument(t);
+  }
+  return res;
+}
+
+static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
+                                                 ArrayRef<ValueHandle> ubs,
+                                                 int64_t step) {
+  if (lbs.size() != 1 || ubs.size() != 1)
+    return llvm::Optional<ValueHandle>();
+
+  auto *lbDef = lbs.front().getValue()->getDefiningOp();
+  auto *ubDef = ubs.front().getValue()->getDefiningOp();
+  if (!lbDef || !ubDef)
+    return llvm::Optional<ValueHandle>();
+
+  auto lbConst = dyn_cast<ConstantIndexOp>(lbDef);
+  auto ubConst = dyn_cast<ConstantIndexOp>(ubDef);
+  if (!lbConst || !ubConst)
+    return llvm::Optional<ValueHandle>();
+
+  return ValueHandle::create<AffineForOp>(lbConst.getValue(),
+                                          ubConst.getValue(), step);
+}
+
+mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
+                                     ArrayRef<ValueHandle> lbHandles,
+                                     ArrayRef<ValueHandle> ubHandles,
+                                     int64_t step) {
+  if (auto res = emitStaticFor(lbHandles, ubHandles, step)) {
+    *iv = res.getValue();
+  } else {
+    SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
+    SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
+    *iv = ValueHandle::create<AffineForOp>(
+        lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()),
+        ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()),
+        step);
+  }
+  auto *body = getForInductionVarOwner(iv->getValue()).getBody();
+  enter(body, /*prev=*/1);
+}
+
+ValueHandle
+mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
+  // Call to `exit` must be explicit and asymmetric (cannot happen in the
+  // destructor) because of ordering wrt comma operator.
+  /// The particular use case concerns nested blocks:
+  ///
+  /// ```c++
+  ///    For (&i, lb, ub, 1)({
+  ///      /--- destructor for this `For` is not always called before ...
+  ///      V
+  ///      For (&j1, lb, ub, 1)({
+  ///        some_op_1,
+  ///      }),
+  ///      /--- ... this scope is entered, resulting in improperly nested IR.
+  ///      V
+  ///      For (&j2, lb, ub, 1)({
+  ///        some_op_2,
+  ///      }),
+  ///    });
+  /// ```
+  if (fun)
+    fun();
+  exit();
+  return ValueHandle::null();
+}
+
+mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
+                                             ArrayRef<ValueHandle> lbs,
+                                             ArrayRef<ValueHandle> ubs,
+                                             ArrayRef<int64_t> steps) {
+  assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
+  assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
+  assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
+  for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
+    loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it),
+                       std::get<3>(it));
+  }
+}
+
+ValueHandle
+mlir::edsc::LoopNestBuilder::operator()(llvm::function_ref<void(void)> fun) {
+  if (fun)
+    fun();
+  // Iterate on the calling operator() on all the loops in the nest.
+  // The iteration order is from innermost to outermost because enter/exit needs
+  // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
+  // occurs on calling operator()). The asymmetry is required for properly
+  // nesting imperfectly nested regions (see LoopBuilder::operator()).
+  for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
+    (*lit)();
+  }
+  return ValueHandle::null();
+}
+
+mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
+  assert(bh && "Expected already captured BlockHandle");
+  enter(bh.getBlock());
+}
+
+mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
+                                       ArrayRef<ValueHandle *> args) {
+  assert(!*bh && "BlockHandle already captures a block, use "
+                 "the explicit BockBuilder(bh, Append())({}) syntax instead.");
+  llvm::SmallVector<Type, 8> types;
+  for (auto *a : args) {
+    assert(!a->hasValue() &&
+           "Expected delayed ValueHandle that has not yet captured.");
+    types.push_back(a->getType());
+  }
+  *bh = BlockHandle::create(types);
+  for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
+    *(std::get<0>(it)) = ValueHandle(std::get<1>(it));
+  }
+  enter(bh->getBlock());
+}
+
+/// Only serves as an ordering point between entering nested block and creating
+/// stmts.
+void mlir::edsc::BlockBuilder::operator()(llvm::function_ref<void(void)> fun) {
+  // Call to `exit` must be explicit and asymmetric (cannot happen in the
+  // destructor) because of ordering wrt comma operator.
+  if (fun)
+    fun();
+  exit();
+}
+
+template <typename Op>
+static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) {
+  return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue());
+}
+
+static std::pair<AffineExpr, Value *>
+categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims,
+                            unsigned &numSymbols) {
+  AffineExpr d;
+  Value *resultVal = nullptr;
+  if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val->getDefiningOp())) {
+    d = getAffineConstantExpr(constant.getValue(), context);
+  } else if (isValidSymbol(val) && !isValidDim(val)) {
+    d = getAffineSymbolExpr(numSymbols++, context);
+    resultVal = val;
+  } else {
+    assert(isValidDim(val) && "Must be a valid Dim");
+    d = getAffineDimExpr(numDims++, context);
+    resultVal = val;
+  }
+  return std::make_pair(d, resultVal);
+}
+
+static ValueHandle createBinaryIndexHandle(
+    ValueHandle lhs, ValueHandle rhs,
+    llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
+  MLIRContext *context = ScopedContext::getContext();
+  unsigned numDims = 0, numSymbols = 0;
+  AffineExpr d0, d1;
+  Value *v0, *v1;
+  std::tie(d0, v0) =
+      categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
+  std::tie(d1, v1) =
+      categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);
+  SmallVector<Value *, 2> operands;
+  if (v0) {
+    operands.push_back(v0);
+  }
+  if (v1) {
+    operands.push_back(v1);
+  }
+  auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)});
+  // TODO: createOrFold when available.
+  return ValueHandle::createComposedAffineApply(map, operands);
+}
+
+template <typename IOp, typename FOp>
+static ValueHandle createBinaryHandle(
+    ValueHandle lhs, ValueHandle rhs,
+    llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
+  auto thisType = lhs.getValue()->getType();
+  auto thatType = rhs.getValue()->getType();
+  assert(thisType == thatType && "cannot mix types in operators");
+  (void)thisType;
+  (void)thatType;
+  if (thisType.isIndex()) {
+    return createBinaryIndexHandle(lhs, rhs, affCombiner);
+  } else if (thisType.isa<IntegerType>()) {
+    return createBinaryHandle<IOp>(lhs, rhs);
+  } else if (thisType.isa<FloatType>()) {
+    return createBinaryHandle<FOp>(lhs, rhs);
+  } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
+    auto aggregateType = thisType.cast<ShapedType>();
+    if (aggregateType.getElementType().isa<IntegerType>())
+      return createBinaryHandle<IOp>(lhs, rhs);
+    else if (aggregateType.getElementType().isa<FloatType>())
+      return createBinaryHandle<FOp>(lhs, rhs);
+  }
+  llvm_unreachable("failed to create a ValueHandle");
+}
+
+ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryHandle<AddIOp, AddFOp>(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
+}
+
+ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryHandle<SubIOp, SubFOp>(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
+}
+
+ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryHandle<MulIOp, MulFOp>(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
+}
+
+ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryHandle<DivISOp, DivFOp>(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
+        llvm_unreachable("only exprs of non-index type support operator/");
+      });
+}
+
+ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryHandle<RemISOp, RemFOp>(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
+}
+
+ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryIndexHandle(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
+}
+
+ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) {
+  return createBinaryIndexHandle(
+      lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
+}
+
+ValueHandle mlir::edsc::op::operator!(ValueHandle value) {
+  assert(value.getType().isInteger(1) && "expected boolean expression");
+  return ValueHandle::create<ConstantIntOp>(1, 1) - value;
+}
+
+ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
+  assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
+  assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
+  return lhs * rhs;
+}
+
+ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
+  return !(!lhs && !rhs);
+}
+
+static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
+                                         ValueHandle lhs, ValueHandle rhs) {
+  auto lhsType = lhs.getType();
+  auto rhsType = rhs.getType();
+  (void)lhsType;
+  (void)rhsType;
+  assert(lhsType == rhsType && "cannot mix types in operators");
+  assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
+         "only integer comparisons are supported");
+
+  auto op = ScopedContext::getBuilder().create<CmpIOp>(
+      ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
+  return ValueHandle(op.getResult());
+}
+
+static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
+                                         ValueHandle lhs, ValueHandle rhs) {
+  auto lhsType = lhs.getType();
+  auto rhsType = rhs.getType();
+  (void)lhsType;
+  (void)rhsType;
+  assert(lhsType == rhsType && "cannot mix types in operators");
+  assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
+
+  auto op = ScopedContext::getBuilder().create<CmpFOp>(
+      ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
+  return ValueHandle(op.getResult());
+}
+
+// All floating point comparison are ordered through EDSL
+ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
+}
+ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
+}
+ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
+             :
+             // TODO(ntv,zinenko): signed by default, how about unsigned?
+             createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
+}
+ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
+}
+ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
+}
+ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
+}
diff --git a/third_party/mlir/lib/EDSC/CMakeLists.txt b/third_party/mlir/lib/EDSC/CMakeLists.txt
new file mode 100644
index 0000000..d910480
--- /dev/null
+++ b/third_party/mlir/lib/EDSC/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_llvm_library(MLIREDSC
+  Builders.cpp
+  CoreAPIs.cpp
+  Helpers.cpp
+  Intrinsics.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/EDSC
+  )
+add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen)
+target_link_libraries(MLIREDSC
+  PUBLIC
+    MLIRAffineOps
+    MLIRStandardOps
+    MLIRTransformUtils
+    MLIRVectorOps
+    )
diff --git a/third_party/mlir/lib/EDSC/CoreAPIs.cpp b/third_party/mlir/lib/EDSC/CoreAPIs.cpp
new file mode 100644
index 0000000..8b18313
--- /dev/null
+++ b/third_party/mlir/lib/EDSC/CoreAPIs.cpp
@@ -0,0 +1,103 @@
+//===- Types.cpp - Implementations of MLIR Core C APIs --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir-c/Core.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+
+mlir_type_t makeScalarType(mlir_context_t context, const char *name,
+                           unsigned bitwidth) {
+  mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
+  mlir_type_t res =
+      llvm::StringSwitch<mlir_type_t>(name)
+          .Case("bf16",
+                mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()})
+          .Case("f16",
+                mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()})
+          .Case("f32",
+                mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()})
+          .Case("f64",
+                mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()})
+          .Case("index",
+                mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()})
+          .Case("i",
+                mlir_type_t{
+                    mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()})
+          .Default(mlir_type_t{nullptr});
+  if (!res) {
+    llvm_unreachable("Invalid type specifier");
+  }
+  return res;
+}
+
+mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
+                           int64_list_t sizes) {
+  auto t = mlir::MemRefType::get(
+      llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
+      mlir::Type::getFromOpaquePointer(elemType),
+      {mlir::AffineMap::getMultiDimIdentityMap(
+          sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
+      0);
+  return mlir_type_t{t.getAsOpaquePointer()};
+}
+
+mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
+                             mlir_type_list_t outputs) {
+  llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
+  for (unsigned i = 0; i < inputs.n; ++i) {
+    ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
+  }
+  for (unsigned i = 0; i < outputs.n; ++i) {
+    outs[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
+  }
+  auto ft = mlir::FunctionType::get(
+      ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
+  return mlir_type_t{ft.getAsOpaquePointer()};
+}
+
+mlir_type_t makeIndexType(mlir_context_t context) {
+  auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+  auto type = mlir::IndexType::get(ctx);
+  return mlir_type_t{type.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value) {
+  auto ty = Type::getFromOpaquePointer(reinterpret_cast<const void *>(type));
+  auto attr = IntegerAttr::get(ty, value);
+  return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
+  auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+  auto attr = BoolAttr::get(value, ctx);
+  return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+unsigned getFunctionArity(mlir_func_t function) {
+  auto f = mlir::FuncOp::getFromOpaquePointer(function);
+  return f.getNumArguments();
+}
diff --git a/third_party/mlir/lib/EDSC/Helpers.cpp b/third_party/mlir/lib/EDSC/Helpers.cpp
new file mode 100644
index 0000000..e6266d3
--- /dev/null
+++ b/third_party/mlir/lib/EDSC/Helpers.cpp
@@ -0,0 +1,64 @@
+//===- Helpers.cpp - MLIR Declarative Helper Functionality ----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+
+static SmallVector<ValueHandle, 8> getMemRefSizes(Value *memRef) {
+  MemRefType memRefType = memRef->getType().cast<MemRefType>();
+
+  auto maps = memRefType.getAffineMaps();
+  (void)maps;
+  assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) &&
+         "Layout maps not supported");
+  SmallVector<ValueHandle, 8> res;
+  res.reserve(memRefType.getShape().size());
+  const auto &shape = memRefType.getShape();
+  for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
+    if (shape[idx] == -1) {
+      res.push_back(ValueHandle::create<DimOp>(memRef, idx));
+    } else {
+      res.push_back(static_cast<index_t>(shape[idx]));
+    }
+  }
+  return res;
+}
+
+mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
+  assert(v->getType().isa<MemRefType>() && "MemRefType expected");
+
+  auto memrefSizeValues = getMemRefSizes(v);
+  for (auto &size : memrefSizeValues) {
+    lbs.push_back(static_cast<index_t>(0));
+    ubs.push_back(size);
+    steps.push_back(1);
+  }
+}
+
+mlir::edsc::VectorView::VectorView(Value *v) : base(v) {
+  auto vectorType = v->getType().cast<VectorType>();
+
+  for (auto s : vectorType.getShape()) {
+    lbs.push_back(static_cast<index_t>(0));
+    ubs.push_back(static_cast<index_t>(s));
+    steps.push_back(1);
+  }
+}
diff --git a/third_party/mlir/lib/EDSC/Intrinsics.cpp b/third_party/mlir/lib/EDSC/Intrinsics.cpp
new file mode 100644
index 0000000..421cadc
--- /dev/null
+++ b/third_party/mlir/lib/EDSC/Intrinsics.cpp
@@ -0,0 +1,86 @@
+//===- Intrinsics.cpp - MLIR Operations for Declarative Builders ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+
+OperationHandle mlir::edsc::intrinsics::br(BlockHandle bh,
+                                           ArrayRef<ValueHandle> operands) {
+  assert(bh && "Expected already captured BlockHandle");
+  for (auto &o : operands) {
+    (void)o;
+    assert(o && "Expected already captured ValueHandle");
+  }
+  SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+  return OperationHandle::create<BranchOp>(bh.getBlock(), ops);
+}
+static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
+                                              ArrayRef<ValueHandle> operands) {
+  assert(captures.size() == operands.size() &&
+         "Expected same number of captures as operands");
+  for (auto it : llvm::zip(captures, operands)) {
+    (void)it;
+    assert(!std::get<0>(it)->hasValue() &&
+           "Unexpected already captured ValueHandle");
+    assert(std::get<1>(it) && "Expected already captured ValueHandle");
+    assert(std::get<0>(it)->getType() == std::get<1>(it).getType() &&
+           "Expected the same type for capture and operand");
+  }
+}
+
+OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh,
+                                           ArrayRef<ValueHandle *> captures,
+                                           ArrayRef<ValueHandle> operands) {
+  assert(!*bh && "Unexpected already captured BlockHandle");
+  enforceEmptyCapturesMatchOperands(captures, operands);
+  BlockBuilder(bh, captures)(/* no body */);
+  SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+  return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
+}
+
+OperationHandle
+mlir::edsc::intrinsics::cond_br(ValueHandle cond, BlockHandle trueBranch,
+                                ArrayRef<ValueHandle> trueOperands,
+                                BlockHandle falseBranch,
+                                ArrayRef<ValueHandle> falseOperands) {
+  SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
+  SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
+  return OperationHandle::create<CondBranchOp>(
+      cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps);
+}
+
+OperationHandle mlir::edsc::intrinsics::cond_br(
+    ValueHandle cond, BlockHandle *trueBranch,
+    ArrayRef<ValueHandle *> trueCaptures, ArrayRef<ValueHandle> trueOperands,
+    BlockHandle *falseBranch, ArrayRef<ValueHandle *> falseCaptures,
+    ArrayRef<ValueHandle> falseOperands) {
+  assert(!*trueBranch && "Unexpected already captured BlockHandle");
+  assert(!*falseBranch && "Unexpected already captured BlockHandle");
+  enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
+  enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
+  BlockBuilder(trueBranch, trueCaptures)(/* no body */);
+  BlockBuilder(falseBranch, falseCaptures)(/* no body */);
+  SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
+  SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
+  return OperationHandle::create<CondBranchOp>(
+      cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps);
+}
diff --git a/third_party/mlir/lib/ExecutionEngine/CMakeLists.txt b/third_party/mlir/lib/ExecutionEngine/CMakeLists.txt
new file mode 100644
index 0000000..fd856a7
--- /dev/null
+++ b/third_party/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -0,0 +1,10 @@
+llvm_map_components_to_libnames(outlibs "nativecodegen" "IPO")
+add_llvm_library(MLIRExecutionEngine
+  ExecutionEngine.cpp
+  MemRefUtils.cpp
+  OptUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/ExecutionEngine
+  )
+target_link_libraries(MLIRExecutionEngine MLIRLLVMIR MLIRTargetLLVMIR LLVMExecutionEngine LLVMOrcJIT LLVMSupport ${outlibs})
diff --git a/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
new file mode 100644
index 0000000..b830c53
--- /dev/null
+++ b/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -0,0 +1,372 @@
+//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the execution engine for MLIR modules based on LLVM Orc
+// JIT engine.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Target/LLVMIR.h"
+
+#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
+#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
+#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
+#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
+#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
+#include "llvm/ExecutionEngine/SectionMemoryManager.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/TargetRegistry.h"
+
+using namespace mlir;
+using llvm::Error;
+using llvm::Expected;
+
+namespace {
+// Memory manager for the JIT's objectLayer.  Its main goal is to fallback to
+// resolving functions in the current process if they cannot be resolved in the
+// JIT-compiled modules.
+class MemoryManager : public llvm::SectionMemoryManager {
+public:
+  MemoryManager(llvm::orc::ExecutionSession &execSession)
+      : session(execSession) {}
+
+  // Resolve the named symbol.  First, try looking it up in the main library of
+  // the execution session.  If there is no such symbol, try looking it up in
+  // the current process (for example, if it is a standard library function).
+  // Return `nullptr` if lookup fails.
+  llvm::JITSymbol findSymbol(const std::string &name) override {
+    auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name);
+    if (mainLibSymbol)
+      return mainLibSymbol.get();
+    auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
+    if (!address) {
+      llvm::errs() << "Could not look up: " << name << '\n';
+      return nullptr;
+    }
+    return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported);
+  }
+
+private:
+  llvm::orc::ExecutionSession &session;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace impl {
+
+/// Wrapper class around DynamicLibrarySearchGenerator to allow searching
+/// in-process symbols that have not been explicitly exported.
+/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator.
+/// For symbols that are not found this way, it then uses
+///   `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols
+/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`,
+/// previously.
+class SearchGenerator {
+public:
+  SearchGenerator(char GlobalPrefix)
+      : defaultGenerator(cantFail(
+            llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
+                GlobalPrefix))) {}
+
+  // This function forwards to DynamicLibrarySearchGenerator::operator() and
+  // adds an extra resolution for names explicitly registered via
+  // `llvm::sys::DynamicLibrary::AddSymbol`.
+  Expected<llvm::orc::SymbolNameSet>
+  operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) {
+    auto res = defaultGenerator(JD, Names);
+    if (!res)
+      return res;
+    llvm::orc::SymbolMap newSymbols;
+    for (auto &Name : Names) {
+      if (res.get().count(Name) > 0)
+        continue;
+      res.get().insert(Name);
+      auto addedSymbolAddress =
+          llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name);
+      if (!addedSymbolAddress)
+        continue;
+      llvm::JITEvaluatedSymbol Sym(
+          reinterpret_cast<uintptr_t>(addedSymbolAddress),
+          llvm::JITSymbolFlags::Exported);
+      newSymbols[Name] = Sym;
+    }
+    if (!newSymbols.empty())
+      cantFail(JD.define(absoluteSymbols(std::move(newSymbols))));
+    return res;
+  }
+
+private:
+  llvm::orc::DynamicLibrarySearchGenerator defaultGenerator;
+};
+
+// Simple layered Orc JIT compilation engine.
+class OrcJIT {
+public:
+  using IRTransformer = std::function<Error(llvm::Module *)>;
+
+  // Construct a JIT engine for the target host defined by `machineBuilder`,
+  // using the data layout provided as `dataLayout`.
+  // Setup the object layer to use our custom memory manager in order to
+  // resolve calls to library functions present in the process.
+  OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
+         llvm::DataLayout layout, IRTransformer transform,
+         ArrayRef<StringRef> sharedLibPaths)
+      : irTransformer(transform),
+        objectLayer(
+            session,
+            [this]() { return llvm::make_unique<MemoryManager>(session); }),
+        compileLayer(
+            session, objectLayer,
+            llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))),
+        transformLayer(session, compileLayer, makeIRTransformFunction()),
+        dataLayout(layout), mangler(session, this->dataLayout),
+        threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
+    session.getMainJITDylib().setGenerator(
+        SearchGenerator(layout.getGlobalPrefix()));
+    loadLibraries(sharedLibPaths);
+  }
+
+  // Create a JIT engine for the current host.
+  static Expected<std::unique_ptr<OrcJIT>>
+  createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) {
+    auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
+    if (!machineBuilder)
+      return machineBuilder.takeError();
+
+    auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget();
+    if (!dataLayout)
+      return dataLayout.takeError();
+
+    return llvm::make_unique<OrcJIT>(std::move(*machineBuilder),
+                                     std::move(*dataLayout), transformer,
+                                     sharedLibPaths);
+  }
+
+  // Add an LLVM module to the main library managed by the JIT engine.
+  Error addModule(std::unique_ptr<llvm::Module> M) {
+    return transformLayer.add(
+        session.getMainJITDylib(),
+        llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx));
+  }
+
+  // Lookup a symbol in the main library managed by the JIT engine.
+  Expected<llvm::JITEvaluatedSymbol> lookup(StringRef Name) {
+    return session.lookup({&session.getMainJITDylib()}, mangler(Name.str()));
+  }
+
+private:
+  // Wrap the `irTransformer` into a function that can be called by the
+  // IRTranformLayer.  If `irTransformer` is not set up, return the module as
+  // is without errors.
+  llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() {
+    return [this](llvm::orc::ThreadSafeModule module,
+                  const llvm::orc::MaterializationResponsibility &resp)
+               -> Expected<llvm::orc::ThreadSafeModule> {
+      (void)resp;
+      if (!irTransformer)
+        return std::move(module);
+      Error err = module.withModuleDo(
+          [this](llvm::Module &module) { return irTransformer(&module); });
+      if (err)
+        return std::move(err);
+      return std::move(module);
+    };
+  }
+
+  // Iterate over shareLibPaths and load the corresponding libraries for symbol
+  // resolution.
+  void loadLibraries(ArrayRef<StringRef> sharedLibPaths);
+
+  IRTransformer irTransformer;
+  llvm::orc::ExecutionSession session;
+  llvm::orc::RTDyldObjectLinkingLayer objectLayer;
+  llvm::orc::IRCompileLayer compileLayer;
+  llvm::orc::IRTransformLayer transformLayer;
+  llvm::DataLayout dataLayout;
+  llvm::orc::MangleAndInterner mangler;
+  llvm::orc::ThreadSafeContext threadSafeCtx;
+};
+} // end namespace impl
+} // namespace mlir
+
+void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) {
+  for (auto libPath : sharedLibPaths) {
+    auto mb = llvm::MemoryBuffer::getFile(libPath);
+    if (!mb) {
+      llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " "
+                   << mb.getError().message() << "\n";
+      continue;
+    }
+    auto &JD = session.createJITDylib(libPath);
+    auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
+        libPath.data(), dataLayout.getGlobalPrefix());
+    if (!loaded) {
+      llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError()
+                   << "\n";
+      continue;
+    }
+    JD.setGenerator(loaded.get());
+    auto res = objectLayer.add(JD, std::move(mb.get()));
+    if (res)
+      llvm::errs() << "Could not add: " << libPath << " " << res << "\n";
+  }
+}
+
+// Wrap a string into an llvm::StringError.
+static inline Error make_string_error(const llvm::Twine &message) {
+  return llvm::make_error<llvm::StringError>(message.str(),
+                                             llvm::inconvertibleErrorCode());
+}
+
+// Setup LLVM target triple from the current machine.
+bool ExecutionEngine::setupTargetTriple(llvm::Module *llvmModule) {
+  // Setup the machine properties from the current architecture.
+  auto targetTriple = llvm::sys::getDefaultTargetTriple();
+  std::string errorMessage;
+  auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
+  if (!target) {
+    llvm::errs() << "NO target: " << errorMessage << "\n";
+    return true;
+  }
+  auto machine =
+      target->createTargetMachine(targetTriple, "generic", "", {}, {});
+  llvmModule->setDataLayout(machine->createDataLayout());
+  llvmModule->setTargetTriple(targetTriple);
+  return false;
+}
+
+static std::string makePackedFunctionName(StringRef name) {
+  return "_mlir_" + name.str();
+}
+
+// For each function in the LLVM module, define an interface function that wraps
+// all the arguments of the original function and all its results into an i8**
+// pointer to provide a unified invocation interface.
+void packFunctionArguments(llvm::Module *module) {
+  auto &ctx = module->getContext();
+  llvm::IRBuilder<> builder(ctx);
+  llvm::DenseSet<llvm::Function *> interfaceFunctions;
+  for (auto &func : module->getFunctionList()) {
+    if (func.isDeclaration()) {
+      continue;
+    }
+    if (interfaceFunctions.count(&func)) {
+      continue;
+    }
+
+    // Given a function `foo(<...>)`, define the interface function
+    // `mlir_foo(i8**)`.
+    auto newType = llvm::FunctionType::get(
+        builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
+        /*isVarArg=*/false);
+    auto newName = makePackedFunctionName(func.getName());
+    auto funcCst = module->getOrInsertFunction(newName, newType);
+    llvm::Function *interfaceFunc =
+        llvm::cast<llvm::Function>(funcCst.getCallee());
+    interfaceFunctions.insert(interfaceFunc);
+
+    // Extract the arguments from the type-erased argument list and cast them to
+    // the proper types.
+    auto bb = llvm::BasicBlock::Create(ctx);
+    bb->insertInto(interfaceFunc);
+    builder.SetInsertPoint(bb);
+    llvm::Value *argList = interfaceFunc->arg_begin();
+    llvm::SmallVector<llvm::Value *, 8> args;
+    args.reserve(llvm::size(func.args()));
+    for (auto &indexedArg : llvm::enumerate(func.args())) {
+      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
+          builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
+      llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex);
+      llvm::Value *argPtr = builder.CreateLoad(argPtrPtr);
+      argPtr = builder.CreateBitCast(
+          argPtr, indexedArg.value().getType()->getPointerTo());
+      llvm::Value *arg = builder.CreateLoad(argPtr);
+      args.push_back(arg);
+    }
+
+    // Call the implementation function with the extracted arguments.
+    llvm::Value *result = builder.CreateCall(&func, args);
+
+    // Assuming the result is one value, potentially of type `void`.
+    if (!result->getType()->isVoidTy()) {
+      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
+          builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
+      llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex);
+      llvm::Value *retPtr = builder.CreateLoad(retPtrPtr);
+      retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
+      builder.CreateStore(result, retPtr);
+    }
+
+    // The interface function returns void.
+    builder.CreateRetVoid();
+  }
+}
+
+// Out of line for PIMPL unique_ptr.
+ExecutionEngine::~ExecutionEngine() = default;
+
+Expected<std::unique_ptr<ExecutionEngine>>
+ExecutionEngine::create(ModuleOp m,
+                        std::function<llvm::Error(llvm::Module *)> transformer,
+                        ArrayRef<StringRef> sharedLibPaths) {
+  auto engine = llvm::make_unique<ExecutionEngine>();
+  auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths);
+  if (!expectedJIT)
+    return expectedJIT.takeError();
+
+  auto llvmModule = translateModuleToLLVMIR(m);
+  if (!llvmModule)
+    return make_string_error("could not convert to LLVM IR");
+  // FIXME: the triple should be passed to the translation or dialect conversion
+  // instead of this.  Currently, the LLVM module created above has no triple
+  // associated with it.
+  setupTargetTriple(llvmModule.get());
+  packFunctionArguments(llvmModule.get());
+
+  if (auto err = (*expectedJIT)->addModule(std::move(llvmModule)))
+    return std::move(err);
+  engine->jit = std::move(*expectedJIT);
+
+  return std::move(engine);
+}
+
+Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
+  auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
+  if (!expectedSymbol)
+    return expectedSymbol.takeError();
+  auto rawFPtr = expectedSymbol->getAddress();
+  auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
+  if (!fptr)
+    return make_string_error("looked up function is null");
+  return fptr;
+}
+
+llvm::Error ExecutionEngine::invoke(StringRef name,
+                                    MutableArrayRef<void *> args) {
+  auto expectedFPtr = lookup(name);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  auto fptr = *expectedFPtr;
+
+  (*fptr)(args.data());
+
+  return llvm::Error::success();
+}
diff --git a/third_party/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/third_party/mlir/lib/ExecutionEngine/MemRefUtils.cpp
new file mode 100644
index 0000000..e34bf44
--- /dev/null
+++ b/third_party/mlir/lib/ExecutionEngine/MemRefUtils.cpp
@@ -0,0 +1,107 @@
+//===- MemRefUtils.cpp - MLIR runtime utilities for memrefs ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a set of utilities to working with objects of memref type in an JIT
+// context using the MLIR execution engine.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/MemRefUtils.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/Support/Error.h"
+#include <numeric>
+
+using namespace mlir;
+
+static inline llvm::Error make_string_error(const llvm::Twine &message) {
+  return llvm::make_error<llvm::StringError>(message.str(),
+                                             llvm::inconvertibleErrorCode());
+}
+
+static llvm::Expected<StaticFloatMemRef *>
+allocMemRefDescriptor(Type type, bool allocateData = true,
+                      float initialValue = 0.0) {
+  auto memRefType = type.dyn_cast<MemRefType>();
+  if (!memRefType)
+    return make_string_error("non-memref argument not supported");
+  if (!memRefType.hasStaticShape())
+    return make_string_error("memref with dynamic shapes not supported");
+
+  auto elementType = memRefType.getElementType();
+  if (!elementType.isF32())
+    return make_string_error(
+        "memref with element other than f32 not supported");
+
+  auto *descriptor =
+      reinterpret_cast<StaticFloatMemRef *>(malloc(sizeof(StaticFloatMemRef)));
+  if (!allocateData) {
+    descriptor->data = nullptr;
+    return descriptor;
+  }
+
+  auto shape = memRefType.getShape();
+  int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int64_t>());
+  descriptor->data = reinterpret_cast<float *>(malloc(sizeof(float) * size));
+  for (int64_t i = 0; i < size; ++i) {
+    descriptor->data[i] = initialValue;
+  }
+  return descriptor;
+}
+
+llvm::Expected<SmallVector<void *, 8>>
+mlir::allocateMemRefArguments(FuncOp func, float initialValue) {
+  SmallVector<void *, 8> args;
+  args.reserve(func.getNumArguments());
+  for (const auto &arg : func.getArguments()) {
+    auto descriptor =
+        allocMemRefDescriptor(arg->getType(),
+                              /*allocateData=*/true, initialValue);
+    if (!descriptor)
+      return descriptor.takeError();
+    args.push_back(*descriptor);
+  }
+
+  if (func.getType().getNumResults() > 1)
+    return make_string_error("functions with more than 1 result not supported");
+
+  for (Type resType : func.getType().getResults()) {
+    auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
+    if (!descriptor)
+      return descriptor.takeError();
+    args.push_back(*descriptor);
+  }
+
+  return args;
+}
+
+// Because the function can return the same descriptor as passed in arguments,
+// we check that we don't attempt to free the underlying data twice.
+void mlir::freeMemRefArguments(ArrayRef<void *> args) {
+  llvm::DenseSet<void *> dataPointers;
+  for (void *arg : args) {
+    float *dataPtr = reinterpret_cast<StaticFloatMemRef *>(arg)->data;
+    if (dataPointers.count(dataPtr) == 0) {
+      free(dataPtr);
+      dataPointers.insert(dataPtr);
+    }
+    free(arg);
+  }
+}
diff --git a/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp
new file mode 100644
index 0000000..e8c6652
--- /dev/null
+++ b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp
@@ -0,0 +1,151 @@
+//===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the utility functions to trigger LLVM optimizations from
+// MLIR Execution Engine.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/OptUtils.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/LegacyPassNameParser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include <climits>
+#include <mutex>
+
+// Run the module and function passes managed by the module manager.
+static void runPasses(llvm::legacy::PassManager &modulePM,
+                      llvm::legacy::FunctionPassManager &funcPM,
+                      llvm::Module &m) {
+  funcPM.doInitialization();
+  for (auto &func : m) {
+    funcPM.run(func);
+  }
+  funcPM.doFinalization();
+  modulePM.run(m);
+}
+
+// Initialize basic LLVM transformation passes under lock.
+void mlir::initializeLLVMPasses() {
+  static std::mutex mutex;
+  std::lock_guard<std::mutex> lock(mutex);
+
+  auto &registry = *llvm::PassRegistry::getPassRegistry();
+  llvm::initializeCore(registry);
+  llvm::initializeTransformUtils(registry);
+  llvm::initializeScalarOpts(registry);
+  llvm::initializeIPO(registry);
+  llvm::initializeInstCombine(registry);
+  llvm::initializeAggressiveInstCombine(registry);
+  llvm::initializeAnalysis(registry);
+  llvm::initializeVectorization(registry);
+}
+
+// Populate pass managers according to the optimization and size levels.
+// This behaves similarly to LLVM opt.
+static void populatePassManagers(llvm::legacy::PassManager &modulePM,
+                                 llvm::legacy::FunctionPassManager &funcPM,
+                                 unsigned optLevel, unsigned sizeLevel,
+                                 llvm::TargetMachine *targetMachine) {
+  llvm::PassManagerBuilder builder;
+  builder.OptLevel = optLevel;
+  builder.SizeLevel = sizeLevel;
+  builder.Inliner = llvm::createFunctionInliningPass(
+      optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
+  builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
+  builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
+  builder.DisableUnrollLoops = (optLevel == 0);
+
+  if (targetMachine) {
+    // Add pass to initialize TTI for this specific target. Otherwise, TTI will
+    // be initialized to NoTTIImpl by defaul.
+    modulePM.add(createTargetTransformInfoWrapperPass(
+        targetMachine->getTargetIRAnalysis()));
+    funcPM.add(createTargetTransformInfoWrapperPass(
+        targetMachine->getTargetIRAnalysis()));
+  }
+
+  builder.populateModulePassManager(modulePM);
+  builder.populateFunctionPassManager(funcPM);
+}
+
+// Create and return a lambda that uses LLVM pass manager builder to set up
+// optimizations based on the given level.
+std::function<llvm::Error(llvm::Module *)>
+mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
+                                llvm::TargetMachine *targetMachine) {
+  return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
+    llvm::legacy::PassManager modulePM;
+    llvm::legacy::FunctionPassManager funcPM(m);
+    populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
+    runPasses(modulePM, funcPM, *m);
+
+    return llvm::Error::success();
+  };
+}
+
+// Create and return a lambda that is given a set of passes to run, plus an
+// optional optimization level to pre-populate the pass manager.
+std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
+    llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
+    llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
+    unsigned optPassesInsertPos) {
+  return [llvmPasses, mbOptLevel, optPassesInsertPos,
+          targetMachine](llvm::Module *m) -> llvm::Error {
+    llvm::legacy::PassManager modulePM;
+    llvm::legacy::FunctionPassManager funcPM(m);
+
+    bool insertOptPasses = mbOptLevel.hasValue();
+    for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
+      const auto *passInfo = llvmPasses[i];
+      if (!passInfo->getNormalCtor())
+        continue;
+
+      if (insertOptPasses && optPassesInsertPos == i) {
+        populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
+                             targetMachine);
+        insertOptPasses = false;
+      }
+
+      auto *pass = passInfo->createPass();
+      if (!pass)
+        return llvm::make_error<llvm::StringError>(
+            "could not create pass " + passInfo->getPassName(),
+            llvm::inconvertibleErrorCode());
+      modulePM.add(pass);
+    }
+
+    if (insertOptPasses)
+      populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
+                           targetMachine);
+
+    runPasses(modulePM, funcPM, *m);
+    return llvm::Error::success();
+  };
+}
diff --git a/third_party/mlir/lib/IR/AffineExpr.cpp b/third_party/mlir/lib/IR/AffineExpr.cpp
new file mode 100644
index 0000000..10aed66
--- /dev/null
+++ b/third_party/mlir/lib/IR/AffineExpr.cpp
@@ -0,0 +1,896 @@
+//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/AffineExpr.h"
+#include "AffineExprDetail.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+MLIRContext *AffineExpr::getContext() const { return expr->context; }
+
+AffineExprKind AffineExpr::getKind() const {
+  return static_cast<AffineExprKind>(expr->getKind());
+}
+
+/// Walk all of the AffineExprs in this subgraph in postorder.
+void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
+  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
+    std::function<void(AffineExpr)> callback;
+
+    AffineExprWalker(std::function<void(AffineExpr)> callback)
+        : callback(callback) {}
+
+    void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
+    void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
+    void visitDimExpr(AffineDimExpr expr) { callback(expr); }
+    void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+  };
+
+  AffineExprWalker(callback).walkPostOrder(*this);
+}
+
+// Dispatch affine expression construction based on kind.
+AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
+                                       AffineExpr rhs) {
+  if (kind == AffineExprKind::Add)
+    return lhs + rhs;
+  if (kind == AffineExprKind::Mul)
+    return lhs * rhs;
+  if (kind == AffineExprKind::FloorDiv)
+    return lhs.floorDiv(rhs);
+  if (kind == AffineExprKind::CeilDiv)
+    return lhs.ceilDiv(rhs);
+  if (kind == AffineExprKind::Mod)
+    return lhs % rhs;
+
+  llvm_unreachable("unknown binary operation on affine expressions");
+}
+
+/// This method substitutes any uses of dimensions and symbols (e.g.
+/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
+AffineExpr
+AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+                                  ArrayRef<AffineExpr> symReplacements) const {
+  switch (getKind()) {
+  case AffineExprKind::Constant:
+    return *this;
+  case AffineExprKind::DimId: {
+    unsigned dimId = cast<AffineDimExpr>().getPosition();
+    if (dimId >= dimReplacements.size())
+      return *this;
+    return dimReplacements[dimId];
+  }
+  case AffineExprKind::SymbolId: {
+    unsigned symId = cast<AffineSymbolExpr>().getPosition();
+    if (symId >= symReplacements.size())
+      return *this;
+    return symReplacements[symId];
+  }
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod:
+    auto binOp = cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
+    auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
+    if (newLHS == lhs && newRHS == rhs)
+      return *this;
+    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Returns true if this expression is made out of only symbols and
+/// constants (no dimensional identifiers).
+bool AffineExpr::isSymbolicOrConstant() const {
+  switch (getKind()) {
+  case AffineExprKind::Constant:
+    return true;
+  case AffineExprKind::DimId:
+    return false;
+  case AffineExprKind::SymbolId:
+    return true;
+
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    auto expr = this->cast<AffineBinaryOpExpr>();
+    return expr.getLHS().isSymbolicOrConstant() &&
+           expr.getRHS().isSymbolicOrConstant();
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Returns true if this is a pure affine expression, i.e., multiplication,
+/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
+bool AffineExpr::isPureAffine() const {
+  switch (getKind()) {
+  case AffineExprKind::SymbolId:
+  case AffineExprKind::DimId:
+  case AffineExprKind::Constant:
+    return true;
+  case AffineExprKind::Add: {
+    auto op = cast<AffineBinaryOpExpr>();
+    return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
+  }
+
+  case AffineExprKind::Mul: {
+    // TODO: Canonicalize the constants in binary operators to the RHS when
+    // possible, allowing this to merge into the next case.
+    auto op = cast<AffineBinaryOpExpr>();
+    return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
+           (op.getLHS().template isa<AffineConstantExpr>() ||
+            op.getRHS().template isa<AffineConstantExpr>());
+  }
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    auto op = cast<AffineBinaryOpExpr>();
+    return op.getLHS().isPureAffine() &&
+           op.getRHS().template isa<AffineConstantExpr>();
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+// Returns the greatest known integral divisor of this affine expression.
+uint64_t AffineExpr::getLargestKnownDivisor() const {
+  AffineBinaryOpExpr binExpr(nullptr);
+  switch (getKind()) {
+  case AffineExprKind::SymbolId:
+    LLVM_FALLTHROUGH;
+  case AffineExprKind::DimId:
+    return 1;
+  case AffineExprKind::Constant:
+    return std::abs(this->cast<AffineConstantExpr>().getValue());
+  case AffineExprKind::Mul: {
+    binExpr = this->cast<AffineBinaryOpExpr>();
+    return binExpr.getLHS().getLargestKnownDivisor() *
+           binExpr.getRHS().getLargestKnownDivisor();
+  }
+  case AffineExprKind::Add:
+    LLVM_FALLTHROUGH;
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    binExpr = cast<AffineBinaryOpExpr>();
+    return llvm::GreatestCommonDivisor64(
+        binExpr.getLHS().getLargestKnownDivisor(),
+        binExpr.getRHS().getLargestKnownDivisor());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+bool AffineExpr::isMultipleOf(int64_t factor) const {
+  AffineBinaryOpExpr binExpr(nullptr);
+  uint64_t l, u;
+  switch (getKind()) {
+  case AffineExprKind::SymbolId:
+    LLVM_FALLTHROUGH;
+  case AffineExprKind::DimId:
+    return factor * factor == 1;
+  case AffineExprKind::Constant:
+    return cast<AffineConstantExpr>().getValue() % factor == 0;
+  case AffineExprKind::Mul: {
+    binExpr = cast<AffineBinaryOpExpr>();
+    // It's probably not worth optimizing this further (to not traverse the
+    // whole sub-tree under - it that would require a version of isMultipleOf
+    // that on a 'false' return also returns the largest known divisor).
+    return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
+           (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
+           (l * u) % factor == 0;
+  }
+  case AffineExprKind::Add:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    binExpr = cast<AffineBinaryOpExpr>();
+    return llvm::GreatestCommonDivisor64(
+               binExpr.getLHS().getLargestKnownDivisor(),
+               binExpr.getRHS().getLargestKnownDivisor()) %
+               factor ==
+           0;
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+bool AffineExpr::isFunctionOfDim(unsigned position) const {
+  if (getKind() == AffineExprKind::DimId) {
+    return *this == mlir::getAffineDimExpr(position, getContext());
+  }
+  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
+    return expr.getLHS().isFunctionOfDim(position) ||
+           expr.getRHS().isFunctionOfDim(position);
+  }
+  return false;
+}
+
+AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
+    : AffineExpr(ptr) {}
+AffineExpr AffineBinaryOpExpr::getLHS() const {
+  return static_cast<ImplType *>(expr)->lhs;
+}
+AffineExpr AffineBinaryOpExpr::getRHS() const {
+  return static_cast<ImplType *>(expr)->rhs;
+}
+
+AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
+unsigned AffineDimExpr::getPosition() const {
+  return static_cast<ImplType *>(expr)->position;
+}
+
+static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
+                                       MLIRContext *context) {
+  auto assignCtx = [context](AffineDimExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getAffineUniquer();
+  return uniquer.get<AffineDimExprStorage>(
+      assignCtx, static_cast<unsigned>(kind), position);
+}
+
+AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
+  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
+}
+
+AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
+    : AffineExpr(ptr) {}
+unsigned AffineSymbolExpr::getPosition() const {
+  return static_cast<ImplType *>(expr)->position;
+}
+
+AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
+  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
+  ;
+}
+
+AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
+    : AffineExpr(ptr) {}
+int64_t AffineConstantExpr::getValue() const {
+  return static_cast<ImplType *>(expr)->constant;
+}
+
+AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
+  auto assignCtx = [context](AffineConstantExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getAffineUniquer();
+  return uniquer.get<AffineConstantExprStorage>(
+      assignCtx, static_cast<unsigned>(AffineExprKind::Constant), constant);
+}
+
+/// Simplify add expression. Return nullptr if it can't be simplified.
+static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+  // Fold if both LHS, RHS are a constant.
+  if (lhsConst && rhsConst)
+    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
+                                 lhs.getContext());
+
+  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
+  // If only one of them is a symbolic expressions, make it the RHS.
+  if (lhs.isa<AffineConstantExpr>() ||
+      (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
+    return rhs + lhs;
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+
+  // Addition with a zero is a noop, return the other input.
+  if (rhsConst) {
+    if (rhsConst.getValue() == 0)
+      return lhs;
+  }
+  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+      return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
+  }
+
+  // When doing successive additions, bring constant to the right: turn (d0 + 2)
+  // + d1 into (d0 + d1) + 2.
+  if (lBin && lBin.getKind() == AffineExprKind::Add) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      return lBin.getLHS() + rhs + lrhs;
+    }
+  }
+
+  // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
+  // leads to a much more efficient form when 'c' is a power of two, and in
+  // general a more compact and readable form.
+
+  // Process '(expr floordiv c) * (-c)'.
+  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+  if (!rBinOpExpr)
+    return nullptr;
+
+  auto lrhs = rBinOpExpr.getLHS();
+  auto rrhs = rBinOpExpr.getRHS();
+
+  // Process lrhs, which is 'expr floordiv c'.
+  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+  if (!lrBinOpExpr)
+    return nullptr;
+
+  auto llrhs = lrBinOpExpr.getLHS();
+  auto rlrhs = lrBinOpExpr.getRHS();
+
+  if (lhs == llrhs && rlrhs == -rrhs) {
+    return lhs % rlrhs;
+  }
+  return nullptr;
+}
+
+AffineExpr AffineExpr::operator+(int64_t v) const {
+  return *this + getAffineConstantExpr(v, getContext());
+}
+AffineExpr AffineExpr::operator+(AffineExpr other) const {
+  if (auto simplified = simplifyAdd(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
+}
+
+/// Simplify a multiply expression. Return nullptr if it can't be simplified.
+static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (lhsConst && rhsConst)
+    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
+                                 lhs.getContext());
+
+  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
+
+  // Canonicalize the mul expression so that the constant/symbolic term is the
+  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
+  // constant. (Note that a constant is trivially symbolic).
+  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
+    // At least one of them has to be symbolic.
+    return rhs * lhs;
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+
+  // Multiplication with a one is a noop, return the other input.
+  if (rhsConst) {
+    if (rhsConst.getValue() == 1)
+      return lhs;
+    // Multiplication with zero.
+    if (rhsConst.getValue() == 0)
+      return rhsConst;
+  }
+
+  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
+  }
+
+  // When doing successive multiplication, bring constant to the right: turn (d0
+  // * 2) * d1 into (d0 * d1) * 2.
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      return (lBin.getLHS() * rhs) * lrhs;
+    }
+  }
+
+  return nullptr;
+}
+
+AffineExpr AffineExpr::operator*(int64_t v) const {
+  return *this * getAffineConstantExpr(v, getContext());
+}
+AffineExpr AffineExpr::operator*(AffineExpr other) const {
+  if (auto simplified = simplifyMul(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
+}
+
+// Unary minus, delegate to operator*.
+AffineExpr AffineExpr::operator-() const {
+  return *this * getAffineConstantExpr(-1, getContext());
+}
+
+// Delegate to operator+.
+AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
+AffineExpr AffineExpr::operator-(AffineExpr other) const {
+  return *this + (-other);
+}
+
+static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(
+        floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
+
+  // Fold floordiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
+  if (rhsConst.getValue() == 1)
+    return lhs;
+
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      // rhsConst is known to be positive if a constant.
+      if (lrhs.getValue() % rhsConst.getValue() == 0)
+        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
+    }
+  }
+
+  return nullptr;
+}
+
+AffineExpr AffineExpr::floorDiv(uint64_t v) const {
+  return floorDiv(getAffineConstantExpr(v, getContext()));
+}
+AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
+  if (auto simplified = simplifyFloorDiv(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
+      other);
+}
+
+static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(
+        ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
+
+  // Fold ceildiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
+  if (rhsConst.getValue() == 1)
+    return lhs;
+
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      // rhsConst is known to be positive if a constant.
+      if (lrhs.getValue() % rhsConst.getValue() == 0)
+        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
+    }
+  }
+
+  return nullptr;
+}
+
+AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
+  return ceilDiv(getAffineConstantExpr(v, getContext()));
+}
+AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
+  if (auto simplified = simplifyCeilDiv(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
+      other);
+}
+
+static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
+                                 lhs.getContext());
+
+  // Fold modulo of an expression that is known to be a multiple of a constant
+  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
+  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
+  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
+    return getAffineConstantExpr(0, lhs.getContext());
+
+  return nullptr;
+  // TODO(bondhugula): In general, this can be simplified more by using the GCD
+  // test, or in general using quantifier elimination (add two new variables q
+  // and r, and eliminate all variables from the linear system other than r. All
+  // of this can be done through mlir/Analysis/'s FlatAffineConstraints.
+}
+
+AffineExpr AffineExpr::operator%(uint64_t v) const {
+  return *this % getAffineConstantExpr(v, getContext());
+}
+AffineExpr AffineExpr::operator%(AffineExpr other) const {
+  if (auto simplified = simplifyMod(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
+}
+
+AffineExpr AffineExpr::compose(AffineMap map) const {
+  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
+                                             map.getResults().end());
+  return replaceDimsAndSymbols(dimReplacements, {});
+}
+raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr &expr) {
+  expr.print(os);
+  return os;
+}
+
+/// Constructs an affine expression from a flat ArrayRef. If there are local
+/// identifiers (neither dimensional nor symbolic) that appear in the sum of
+/// products expression, 'localExprs' is expected to have the AffineExpr
+/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
+/// format [dims, symbols, locals, constant term].
+AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                              unsigned numSymbols,
+                              ArrayRef<AffineExpr> localExprs,
+                              MLIRContext *context) {
+  // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
+  assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
+         "unexpected number of local expressions");
+
+  auto expr = getAffineConstantExpr(0, context);
+  // Dimensions and symbols.
+  for (unsigned j = 0; j < numDims + numSymbols; j++) {
+    if (eq[j] == 0) {
+      continue;
+    }
+    auto id = j < numDims ? getAffineDimExpr(j, context)
+                          : getAffineSymbolExpr(j - numDims, context);
+    expr = expr + id * eq[j];
+  }
+
+  // Local identifiers.
+  for (unsigned j = numDims + numSymbols, e = eq.size() - 1; j < e; j++) {
+    if (eq[j] == 0) {
+      continue;
+    }
+    auto term = localExprs[j - numDims - numSymbols] * eq[j];
+    expr = expr + term;
+  }
+
+  // Constant term.
+  int64_t constTerm = eq[eq.size() - 1];
+  if (constTerm != 0)
+    expr = expr + constTerm;
+  return expr;
+}
+
+SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
+                                                     unsigned numSymbols)
+    : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
+  operandExprStack.reserve(8);
+}
+
+void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
+  assert(operandExprStack.size() >= 2);
+  // This is a pure affine expr; the RHS will be a constant.
+  assert(expr.getRHS().isa<AffineConstantExpr>());
+  // Get the RHS constant.
+  auto rhsConst = operandExprStack.back()[getConstantIndex()];
+  operandExprStack.pop_back();
+  // Update the LHS in place instead of pop and push.
+  auto &lhs = operandExprStack.back();
+  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
+    lhs[i] *= rhsConst;
+  }
+}
+
+void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
+  assert(operandExprStack.size() >= 2);
+  const auto &rhs = operandExprStack.back();
+  auto &lhs = operandExprStack[operandExprStack.size() - 2];
+  assert(lhs.size() == rhs.size());
+  // Update the LHS in place.
+  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
+    lhs[i] += rhs[i];
+  }
+  // Pop off the RHS.
+  operandExprStack.pop_back();
+}
+
+//
+// t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
+//
+// A mod expression "expr mod c" is thus flattened by introducing a new local
+// variable q (= expr floordiv c), such that expr mod c is replaced with
+// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
+  assert(operandExprStack.size() >= 2);
+  // This is a pure affine expr; the RHS will be a constant.
+  assert(expr.getRHS().isa<AffineConstantExpr>());
+  auto rhsConst = operandExprStack.back()[getConstantIndex()];
+  operandExprStack.pop_back();
+  auto &lhs = operandExprStack.back();
+  // TODO(bondhugula): handle modulo by zero case when this issue is fixed
+  // at the other places in the IR.
+  assert(rhsConst > 0 && "RHS constant has to be positive");
+
+  // Check if the LHS expression is a multiple of modulo factor.
+  unsigned i, e;
+  for (i = 0, e = lhs.size(); i < e; i++)
+    if (lhs[i] % rhsConst != 0)
+      break;
+  // If yes, modulo expression here simplifies to zero.
+  if (i == lhs.size()) {
+    std::fill(lhs.begin(), lhs.end(), 0);
+    return;
+  }
+
+  // Add a local variable for the quotient, i.e., expr % c is replaced by
+  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
+  // the GCD of expr and c.
+  SmallVector<int64_t, 8> floorDividend(lhs);
+  uint64_t gcd = rhsConst;
+  for (unsigned i = 0, e = lhs.size(); i < e; i++)
+    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
+  // Simplify the numerator and the denominator.
+  if (gcd != 1) {
+    for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
+      floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
+  }
+  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
+
+  // Construct the AffineExpr form of the floordiv to store in localExprs.
+  MLIRContext *context = expr.getContext();
+  auto dividendExpr =
+      toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context);
+  auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
+  auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
+  int loc;
+  if ((loc = findLocalId(floorDivExpr)) == -1) {
+    addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
+    // Set result at top of stack to "lhs - rhsConst * q".
+    lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
+  } else {
+    // Reuse the existing local id.
+    lhs[getLocalVarStartIndex() + loc] = -rhsConst;
+  }
+}
+
+void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
+  visitDivExpr(expr, /*isCeil=*/true);
+}
+void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
+  visitDivExpr(expr, /*isCeil=*/false);
+}
+
+void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
+  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+  auto &eq = operandExprStack.back();
+  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
+  eq[getDimStartIndex() + expr.getPosition()] = 1;
+}
+
+void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
+  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+  auto &eq = operandExprStack.back();
+  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
+  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
+}
+
+void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
+  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+  auto &eq = operandExprStack.back();
+  eq[getConstantIndex()] = expr.getValue();
+}
+
+// t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
+// A floordiv is thus flattened by introducing a new local variable q, and
+// replacing that expression with 'q' while adding the constraints
+// c * q <= expr <= c * q + c - 1 to localVarCst (done by
+// FlatAffineConstraints::addLocalFloorDiv).
+//
+// A ceildiv is similarly flattened:
+// t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
+void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
+                                             bool isCeil) {
+  assert(operandExprStack.size() >= 2);
+  assert(expr.getRHS().isa<AffineConstantExpr>());
+
+  // This is a pure affine expr; the RHS is a positive constant.
+  int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
+  // TODO(bondhugula): handle division by zero at the same time the issue is
+  // fixed at other places.
+  assert(rhsConst > 0 && "RHS constant has to be positive");
+  operandExprStack.pop_back();
+  auto &lhs = operandExprStack.back();
+
+  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
+  // common divisors of the numerator and denominator.
+  uint64_t gcd = std::abs(rhsConst);
+  for (unsigned i = 0, e = lhs.size(); i < e; i++)
+    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
+  // Simplify the numerator and the denominator.
+  if (gcd != 1) {
+    for (unsigned i = 0, e = lhs.size(); i < e; i++)
+      lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
+  }
+  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
+  // If the divisor becomes 1, the updated LHS is the result. (The
+  // divisor can't be negative since rhsConst is positive).
+  if (divisor == 1)
+    return;
+
+  // If the divisor cannot be simplified to one, we will have to retain
+  // the ceil/floor expr (simplified up until here). Add an existential
+  // quantifier to express its result, i.e., expr1 div expr2 is replaced
+  // by a new identifier, q.
+  MLIRContext *context = expr.getContext();
+  auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context);
+  auto b = getAffineConstantExpr(divisor, context);
+
+  int loc;
+  auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+  if ((loc = findLocalId(divExpr)) == -1) {
+    if (!isCeil) {
+      SmallVector<int64_t, 8> dividend(lhs);
+      addLocalFloorDivId(dividend, divisor, divExpr);
+    } else {
+      // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
+      SmallVector<int64_t, 8> dividend(lhs);
+      dividend.back() += divisor - 1;
+      addLocalFloorDivId(dividend, divisor, divExpr);
+    }
+  }
+  // Set the expression on stack to the local var introduced to capture the
+  // result of the division (floor or ceil).
+  std::fill(lhs.begin(), lhs.end(), 0);
+  if (loc == -1)
+    lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
+  else
+    lhs[getLocalVarStartIndex() + loc] = 1;
+}
+
+// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
+// The local identifier added is always a floordiv of a pure add/mul affine
+// function of other identifiers, coefficients of which are specified in
+// dividend and with respect to a positive constant divisor. localExpr is the
+// simplified tree expression (AffineExpr) corresponding to the quantifier.
+void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
+                                                   int64_t divisor,
+                                                   AffineExpr localExpr) {
+  assert(divisor > 0 && "positive constant divisor expected");
+  for (auto &subExpr : operandExprStack)
+    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
+  localExprs.push_back(localExpr);
+  numLocals++;
+  // dividend and divisor are not used here; an override of this method uses it.
+}
+
+int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
+  SmallVectorImpl<AffineExpr>::iterator it;
+  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
+    return -1;
+  return it - localExprs.begin();
+}
+
+/// Simplify the affine expression by flattening it and reconstructing it.
+AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+                                    unsigned numSymbols) {
+  // TODO(bondhugula): only pure affine for now. The simplification here can
+  // be extended to semi-affine maps in the future.
+  if (!expr.isPureAffine())
+    return expr;
+
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
+                                     flattener.localExprs, expr.getContext());
+  flattener.operandExprStack.pop_back();
+  assert(flattener.operandExprStack.empty());
+
+  return simplifiedExpr;
+}
+
+// Flattens the expressions in map. Returns true on success or false
+// if 'expr' was unable to be flattened (i.e., semi-affine expressions not
+// handled yet).
+static bool getFlattenedAffineExprs(
+    ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
+    std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) {
+  if (exprs.empty()) {
+    return true;
+  }
+
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  // Use the same flattener to simplify each expression successively. This way
+  // local identifiers / expressions are shared.
+  for (auto expr : exprs) {
+    if (!expr.isPureAffine())
+      return false;
+
+    flattener.walkPostOrder(expr);
+  }
+
+  flattenedExprs->clear();
+  assert(flattener.operandExprStack.size() == exprs.size());
+  flattenedExprs->assign(flattener.operandExprStack.begin(),
+                         flattener.operandExprStack.end());
+
+  return true;
+}
+
+// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
+// if 'expr' was unable to be flattened (semi-affine expressions not handled
+// yet).
+bool mlir::getFlattenedAffineExpr(
+    AffineExpr expr, unsigned numDims, unsigned numSymbols,
+    llvm::SmallVectorImpl<int64_t> *flattenedExpr) {
+  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
+  bool ret =
+      ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs);
+  *flattenedExpr = flattenedExprs[0];
+  return ret;
+}
+
+/// Flattens the expressions in map. Returns true on success or false
+/// if 'expr' was unable to be flattened (i.e., semi-affine expressions not
+/// handled yet).
+bool mlir::getFlattenedAffineExprs(
+    AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) {
+  if (map.getNumResults() == 0) {
+    return true;
+  }
+  return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
+                                   map.getNumSymbols(), flattenedExprs);
+}
+
+bool mlir::getFlattenedAffineExprs(
+    IntegerSet set,
+    std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) {
+  if (set.getNumConstraints() == 0) {
+    return true;
+  }
+  return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
+                                   set.getNumSymbols(), flattenedExprs);
+}
diff --git a/third_party/mlir/lib/IR/AffineExprDetail.h b/third_party/mlir/lib/IR/AffineExprDetail.h
new file mode 100644
index 0000000..214fee6
--- /dev/null
+++ b/third_party/mlir/lib/IR/AffineExprDetail.h
@@ -0,0 +1,98 @@
+//===- AffineExprDetail.h - MLIR Affine Expr storage details ----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of AffineExpr. Ideally it would not be
+// exposed and would be kept local to AffineExpr.cpp however, MLIRContext.cpp
+// needs to know the sizes for placement-new style Allocation.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_IR_AFFINEEXPRDETAIL_H_
+#define MLIR_IR_AFFINEEXPRDETAIL_H_
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace detail {
+
+/// Base storage class appearing in an affine expression.
+struct AffineExprStorage : public StorageUniquer::BaseStorage {
+  MLIRContext *context;
+};
+
+/// A binary operation appearing in an affine expression.
+struct AffineBinaryOpExprStorage : public AffineExprStorage {
+  using KeyTy = std::pair<AffineExpr, AffineExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return key.first == lhs && key.second == rhs;
+  }
+
+  static AffineBinaryOpExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
+    result->lhs = key.first;
+    result->rhs = key.second;
+    result->context = result->lhs.getContext();
+    return result;
+  }
+
+  AffineExpr lhs;
+  AffineExpr rhs;
+};
+
+/// A dimensional or symbolic identifier appearing in an affine expression.
+struct AffineDimExprStorage : public AffineExprStorage {
+  using KeyTy = unsigned;
+
+  bool operator==(const KeyTy &key) const { return position == key; }
+
+  static AffineDimExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineDimExprStorage>();
+    result->position = key;
+    return result;
+  }
+
+  /// Position of this identifier in the argument list.
+  unsigned position;
+};
+
+/// An integer constant appearing in affine expression.
+struct AffineConstantExprStorage : public AffineExprStorage {
+  using KeyTy = int64_t;
+
+  bool operator==(const KeyTy &key) const { return constant == key; }
+
+  static AffineConstantExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineConstantExprStorage>();
+    result->constant = key;
+    return result;
+  }
+
+  // The constant.
+  int64_t constant;
+};
+
+} // end namespace detail
+} // end namespace mlir
+#endif // MLIR_IR_AFFINEEXPRDETAIL_H_
diff --git a/third_party/mlir/lib/IR/AffineMap.cpp b/third_party/mlir/lib/IR/AffineMap.cpp
new file mode 100644
index 0000000..1b6bbe5
--- /dev/null
+++ b/third_party/mlir/lib/IR/AffineMap.cpp
@@ -0,0 +1,319 @@
+//===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/AffineMap.h"
+#include "AffineMapDetail.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+
+// AffineExprConstantFolder evaluates an affine expression using constant
+// operands passed in 'operandConsts'. Returns an IntegerAttr attribute
+// representing the constant value of the affine expression evaluated on
+// constant 'operandConsts', or nullptr if it can't be folded.
+class AffineExprConstantFolder {
+public:
+  AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
+      : numDims(numDims), operandConsts(operandConsts) {}
+
+  /// Attempt to constant fold the specified affine expr, or return null on
+  /// failure.
+  IntegerAttr constantFold(AffineExpr expr) {
+    if (auto result = constantFoldImpl(expr))
+      return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
+    return nullptr;
+  }
+
+private:
+  llvm::Optional<int64_t> constantFoldImpl(AffineExpr expr) {
+    switch (expr.getKind()) {
+    case AffineExprKind::Add:
+      return constantFoldBinExpr(
+          expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
+    case AffineExprKind::Mul:
+      return constantFoldBinExpr(
+          expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
+    case AffineExprKind::Mod:
+      return constantFoldBinExpr(
+          expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+    case AffineExprKind::FloorDiv:
+      return constantFoldBinExpr(
+          expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+    case AffineExprKind::CeilDiv:
+      return constantFoldBinExpr(
+          expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+    case AffineExprKind::Constant:
+      return expr.cast<AffineConstantExpr>().getValue();
+    case AffineExprKind::DimId:
+      if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
+                          .dyn_cast_or_null<IntegerAttr>())
+        return attr.getInt();
+      return llvm::None;
+    case AffineExprKind::SymbolId:
+      if (auto attr = operandConsts[numDims +
+                                    expr.cast<AffineSymbolExpr>().getPosition()]
+                          .dyn_cast_or_null<IntegerAttr>())
+        return attr.getInt();
+      return llvm::None;
+    }
+    llvm_unreachable("Unknown AffineExpr");
+  }
+
+  // TODO: Change these to operate on APInts too.
+  llvm::Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
+                                              int64_t (*op)(int64_t, int64_t)) {
+    auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+    if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
+      if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
+        return op(*lhs, *rhs);
+    return llvm::None;
+  }
+
+  // The number of dimension operands in AffineMap containing this expression.
+  unsigned numDims;
+  // The constant valued operands used to evaluate this AffineExpr.
+  ArrayRef<Attribute> operandConsts;
+};
+
+} // end anonymous namespace
+
+/// Returns a single constant result affine map.
+AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
+  return get(/*dimCount=*/0, /*symbolCount=*/0,
+             {getAffineConstantExpr(val, context)});
+}
+
+AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
+                                            MLIRContext *context) {
+  SmallVector<AffineExpr, 4> dimExprs;
+  dimExprs.reserve(numDims);
+  for (unsigned i = 0; i < numDims; ++i)
+    dimExprs.push_back(mlir::getAffineDimExpr(i, context));
+  return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs);
+}
+
+MLIRContext *AffineMap::getContext() const { return getResult(0).getContext(); }
+
+bool AffineMap::isIdentity() const {
+  if (getNumDims() != getNumResults())
+    return false;
+  ArrayRef<AffineExpr> results = getResults();
+  for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
+    auto expr = results[i].dyn_cast<AffineDimExpr>();
+    if (!expr || expr.getPosition() != i)
+      return false;
+  }
+  return true;
+}
+
+bool AffineMap::isSingleConstant() const {
+  return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
+}
+
+int64_t AffineMap::getSingleConstantResult() const {
+  assert(isSingleConstant() && "map must have a single constant result");
+  return getResult(0).cast<AffineConstantExpr>().getValue();
+}
+
+unsigned AffineMap::getNumDims() const {
+  assert(map && "uninitialized map storage");
+  return map->numDims;
+}
+unsigned AffineMap::getNumSymbols() const {
+  assert(map && "uninitialized map storage");
+  return map->numSymbols;
+}
+unsigned AffineMap::getNumResults() const {
+  assert(map && "uninitialized map storage");
+  return map->results.size();
+}
+unsigned AffineMap::getNumInputs() const {
+  assert(map && "uninitialized map storage");
+  return map->numDims + map->numSymbols;
+}
+
+ArrayRef<AffineExpr> AffineMap::getResults() const {
+  assert(map && "uninitialized map storage");
+  return map->results;
+}
+AffineExpr AffineMap::getResult(unsigned idx) const {
+  assert(map && "uninitialized map storage");
+  return map->results[idx];
+}
+
+/// Folds the results of the application of an affine map on the provided
+/// operands to a constant if possible. Returns false if the folding happens,
+/// true otherwise.
+LogicalResult
+AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
+                        SmallVectorImpl<Attribute> &results) const {
+  assert(getNumInputs() == operandConstants.size());
+
+  // Fold each of the result expressions.
+  AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
+  // Constant fold each AffineExpr in AffineMap and add to 'results'.
+  for (auto expr : getResults()) {
+    auto folded = exprFolder.constantFold(expr);
+    // If we didn't fold to a constant, then folding fails.
+    if (!folded)
+      return failure();
+
+    results.push_back(folded);
+  }
+  assert(results.size() == getNumResults() &&
+         "constant folding produced the wrong number of results");
+  return success();
+}
+
+/// Walk all of the AffineExpr's in this mapping. Each node in an expression
+/// tree is visited in postorder.
+void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
+  for (auto expr : getResults())
+    expr.walk(callback);
+}
+
+/// This method substitutes any uses of dimensions and symbols (e.g.
+/// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
+/// expression mapping.  Because this can be used to eliminate dims and
+/// symbols, the client needs to specify the number of dims and symbols in
+/// the result.  The returned map always has the same number of results.
+AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+                                           ArrayRef<AffineExpr> symReplacements,
+                                           unsigned numResultDims,
+                                           unsigned numResultSyms) {
+  SmallVector<AffineExpr, 8> results;
+  results.reserve(getNumResults());
+  for (auto expr : getResults())
+    results.push_back(
+        expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  return get(numResultDims, numResultSyms, results);
+}
+
+AffineMap AffineMap::compose(AffineMap map) {
+  assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
+  // Prepare `map` by concatenating the symbols and rewriting its exprs.
+  unsigned numDims = map.getNumDims();
+  unsigned numSymbolsThisMap = getNumSymbols();
+  unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
+  SmallVector<AffineExpr, 8> newDims(numDims);
+  for (unsigned idx = 0; idx < numDims; ++idx) {
+    newDims[idx] = getAffineDimExpr(idx, getContext());
+  }
+  SmallVector<AffineExpr, 8> newSymbols(numSymbols);
+  for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
+    newSymbols[idx - numSymbolsThisMap] =
+        getAffineSymbolExpr(idx, getContext());
+  }
+  auto newMap =
+      map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
+  SmallVector<AffineExpr, 8> exprs;
+  exprs.reserve(getResults().size());
+  for (auto expr : getResults())
+    exprs.push_back(expr.compose(newMap));
+  return AffineMap::get(numDims, numSymbols, exprs);
+}
+
+bool AffineMap::isProjectedPermutation() {
+  if (getNumSymbols() > 0)
+    return false;
+  SmallVector<bool, 8> seen(getNumInputs(), false);
+  for (auto expr : getResults()) {
+    if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
+      if (seen[dim.getPosition()])
+        return false;
+      seen[dim.getPosition()] = true;
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+bool AffineMap::isPermutation() {
+  if (getNumDims() != getNumResults())
+    return false;
+  return isProjectedPermutation();
+}
+
+AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
+  SmallVector<AffineExpr, 4> exprs;
+  exprs.reserve(resultPos.size());
+  for (auto idx : resultPos) {
+    exprs.push_back(getResult(idx));
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), exprs);
+}
+
+AffineMap mlir::simplifyAffineMap(AffineMap map) {
+  SmallVector<AffineExpr, 8> exprs;
+  for (auto e : map.getResults()) {
+    exprs.push_back(
+        simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
+  }
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs);
+}
+
+AffineMap mlir::inversePermutation(AffineMap map) {
+  if (!map)
+    return map;
+  assert(map.getNumSymbols() == 0 && "expected map without symbols");
+  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
+  for (auto en : llvm::enumerate(map.getResults())) {
+    auto expr = en.value();
+    // Skip non-permutations.
+    if (auto d = expr.dyn_cast<AffineDimExpr>()) {
+      if (exprs[d.getPosition()])
+        continue;
+      exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
+    }
+  }
+  SmallVector<AffineExpr, 4> seenExprs;
+  seenExprs.reserve(map.getNumDims());
+  for (auto expr : exprs)
+    if (expr)
+      seenExprs.push_back(expr);
+  if (seenExprs.size() != map.getNumInputs())
+    return AffineMap();
+  return AffineMap::get(map.getNumResults(), 0, seenExprs);
+}
+
+AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
+  unsigned numResults = 0;
+  for (auto m : maps)
+    numResults += m ? m.getNumResults() : 0;
+  unsigned numDims = 0;
+  llvm::SmallVector<AffineExpr, 8> results;
+  results.reserve(numResults);
+  for (auto m : maps) {
+    if (!m)
+      continue;
+    assert(m.getNumSymbols() == 0 && "expected map without symbols");
+    results.append(m.getResults().begin(), m.getResults().end());
+    numDims = std::max(m.getNumDims(), numDims);
+  }
+  return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results);
+}
diff --git a/third_party/mlir/lib/IR/AffineMapDetail.h b/third_party/mlir/lib/IR/AffineMapDetail.h
new file mode 100644
index 0000000..af1d89c
--- /dev/null
+++ b/third_party/mlir/lib/IR/AffineMapDetail.h
@@ -0,0 +1,44 @@
+//===- AffineMapDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of AffineMap.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AFFINEMAPDETAIL_H_
+#define AFFINEMAPDETAIL_H_
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace detail {
+
+struct AffineMapStorage {
+  unsigned numDims;
+  unsigned numSymbols;
+
+  /// The affine expressions for this (multi-dimensional) map.
+  /// TODO: use trailing objects for this.
+  ArrayRef<AffineExpr> results;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // AFFINEMAPDETAIL_H_
diff --git a/third_party/mlir/lib/IR/AsmPrinter.cpp b/third_party/mlir/lib/IR/AsmPrinter.cpp
new file mode 100644
index 0000000..a137f26
--- /dev/null
+++ b/third_party/mlir/lib/IR/AsmPrinter.cpp
@@ -0,0 +1,1777 @@
+//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the MLIR AsmPrinter class, which is used to implement
+// the various print() methods on the core IR objects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
+using namespace mlir;
+
+void Identifier::print(raw_ostream &os) const { os << str(); }
+
+void Identifier::dump() const { print(llvm::errs()); }
+
+void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
+
+void OperationName::dump() const { print(llvm::errs()); }
+
+OpAsmPrinter::~OpAsmPrinter() {}
+
+//===----------------------------------------------------------------------===//
+// ModuleState
+//===----------------------------------------------------------------------===//
+
+// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
+// info or when we have a system for printer flags.
+static llvm::cl::opt<bool>
+    shouldPrintDebugInfoOpt("mlir-print-debuginfo",
+                            llvm::cl::desc("Print debug info in MLIR output"),
+                            llvm::cl::init(false));
+
+static llvm::cl::opt<bool> printPrettyDebugInfo(
+    "mlir-pretty-debuginfo",
+    llvm::cl::desc("Print pretty debug info in MLIR output"),
+    llvm::cl::init(false));
+
+// Use the generic op output form in the operation printer even if the custom
+// form is defined.
+static llvm::cl::opt<bool>
+    printGenericOpForm("mlir-print-op-generic",
+                       llvm::cl::desc("Print the generic op form"),
+                       llvm::cl::init(false), llvm::cl::Hidden);
+
+namespace {
+/// A special index constant used for non-kind attribute aliases.
+static constexpr int kNonAttrKindAlias = -1;
+
+class ModuleState {
+public:
+  /// This is the current context if it is knowable, otherwise this is null.
+  MLIRContext *const context;
+
+  explicit ModuleState(MLIRContext *context) : context(context) {}
+
+  // Initializes module state, populating affine map state.
+  void initialize(Operation *op);
+
+  Twine getAttributeAlias(Attribute attr) const {
+    auto alias = attrToAlias.find(attr);
+    if (alias == attrToAlias.end())
+      return Twine();
+
+    // Return the alias for this attribute, along with the index if this was
+    // generated by a kind alias.
+    int kindIndex = alias->second.second;
+    return alias->second.first +
+           (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
+  }
+
+  void printAttributeAliases(raw_ostream &os) const {
+    auto printAlias = [&](StringRef alias, Attribute attr, int index) {
+      os << '#' << alias;
+      if (index != kNonAttrKindAlias)
+        os << index;
+      os << " = " << attr << '\n';
+    };
+
+    // Print all of the attribute kind aliases.
+    for (auto &kindAlias : attrKindToAlias) {
+      for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
+        printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
+      os << "\n";
+    }
+
+    // In a second pass print all of the remaining attribute aliases that aren't
+    // kind aliases.
+    for (Attribute attr : usedAttributes) {
+      auto alias = attrToAlias.find(attr);
+      if (alias != attrToAlias.end() &&
+          alias->second.second == kNonAttrKindAlias)
+        printAlias(alias->second.first, attr, alias->second.second);
+    }
+  }
+
+  StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+
+  void printTypeAliases(raw_ostream &os) const {
+    for (Type type : usedTypes) {
+      auto alias = typeToAlias.find(type);
+      if (alias != typeToAlias.end())
+        os << '!' << alias->second << " = type " << type << '\n';
+    }
+  }
+
+private:
+  void recordAttributeReference(Attribute attr) {
+    // Don't recheck attributes that have already been seen or those that
+    // already have an alias.
+    if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
+      return;
+
+    // If this attribute kind has an alias, then record one for this attribute.
+    auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+    if (alias == attrKindToAlias.end())
+      return;
+    std::pair<StringRef, int> attrAlias(alias->second.first,
+                                        alias->second.second.size());
+    attrToAlias.insert({attr, attrAlias});
+    alias->second.second.push_back(attr);
+  }
+
+  void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
+  // Visit functions.
+  void visitOperation(Operation *op);
+  void visitType(Type type);
+  void visitAttribute(Attribute attr);
+
+  // Initialize symbol aliases.
+  void initializeSymbolAliases();
+
+  /// Set of attributes known to be used within the module.
+  llvm::SetVector<Attribute> usedAttributes;
+
+  /// Mapping between attribute and a pair comprised of a base alias name and a
+  /// count suffix. If the suffix is set to -1, it is not displayed.
+  llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
+
+  /// Mapping between attribute kind and a pair comprised of a base alias name
+  /// and a unique list of attributes belonging to this kind sorted by location
+  /// seen in the module.
+  llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
+      attrKindToAlias;
+
+  /// Set of types known to be used within the module.
+  llvm::SetVector<Type> usedTypes;
+
+  /// A mapping between a type and a given alias.
+  DenseMap<Type, StringRef> typeToAlias;
+};
+} // end anonymous namespace
+
+// TODO Support visiting other types/operations when implemented.
+void ModuleState::visitType(Type type) {
+  recordTypeReference(type);
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
+    // Visit input and result types for functions.
+    for (auto input : funcType.getInputs())
+      visitType(input);
+    for (auto result : funcType.getResults())
+      visitType(result);
+    return;
+  }
+  if (auto memref = type.dyn_cast<MemRefType>()) {
+    // Visit affine maps in memref type.
+    for (auto map : memref.getAffineMaps())
+      recordAttributeReference(AffineMapAttr::get(map));
+  }
+  if (auto shapedType = type.dyn_cast<ShapedType>()) {
+    visitType(shapedType.getElementType());
+  }
+}
+
+void ModuleState::visitAttribute(Attribute attr) {
+  recordAttributeReference(attr);
+  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+    for (auto elt : arrayAttr.getValue())
+      visitAttribute(elt);
+  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+    visitType(typeAttr.getValue());
+  }
+}
+
+void ModuleState::visitOperation(Operation *op) {
+  // Visit all the types used in the operation.
+  for (auto type : op->getOperandTypes())
+    visitType(type);
+  for (auto type : op->getResultTypes())
+    visitType(type);
+  for (auto &region : op->getRegions())
+    for (auto &block : region)
+      for (auto *arg : block.getArguments())
+        visitType(arg->getType());
+
+  // Visit each of the attributes.
+  for (auto elt : op->getAttrs())
+    visitAttribute(elt.second);
+}
+
+// Utility to generate a function to register a symbol alias.
+static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
+  assert(!name.empty() && "expected alias name to be non-empty");
+  // TODO(riverriddle) Assert that the provided alias name can be lexed as
+  // an identifier.
+
+  // Check that the alias doesn't contain a '.' character and the name is not
+  // already in use.
+  return !name.contains('.') && usedAliases.insert(name).second;
+}
+
+void ModuleState::initializeSymbolAliases() {
+  // Track the identifiers in use for each symbol so that the same identifier
+  // isn't used twice.
+  llvm::StringSet<> usedAliases;
+
+  // Get the currently registered dialects.
+  auto dialects = context->getRegisteredDialects();
+
+  // Collect the set of aliases from each dialect.
+  SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
+  SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
+  SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
+
+  // AffineMap/Integer set have specific kind aliases.
+  attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
+  attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
+
+  for (auto *dialect : dialects) {
+    dialect->getAttributeKindAliases(attributeKindAliases);
+    dialect->getAttributeAliases(attributeAliases);
+    dialect->getTypeAliases(typeAliases);
+  }
+
+  // Setup the attribute kind aliases.
+  StringRef alias;
+  unsigned attrKind;
+  for (auto &attrAliasPair : attributeKindAliases) {
+    std::tie(attrKind, alias) = attrAliasPair;
+    assert(!alias.empty() && "expected non-empty alias string");
+    if (!usedAliases.count(alias) && !alias.contains('.'))
+      attrKindToAlias.insert({attrKind, {alias, {}}});
+  }
+
+  // Clear the set of used identifiers so that the attribute kind aliases are
+  // just a prefix and not the full alias, i.e. there may be some overlap.
+  usedAliases.clear();
+
+  // Register the attribute aliases.
+  // Create a regex for the attribute kind alias names, these have a prefix with
+  // a counter appended to the end. We prevent normal aliases from having these
+  // names to avoid collisions.
+  llvm::Regex reservedAttrNames("[0-9]+$");
+
+  // Attribute value aliases.
+  Attribute attr;
+  for (auto &attrAliasPair : attributeAliases) {
+    std::tie(attr, alias) = attrAliasPair;
+    if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
+      attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
+  }
+
+  // Clear the set of used identifiers as types can have the same identifiers as
+  // affine structures.
+  usedAliases.clear();
+
+  // Type aliases.
+  for (auto &typeAliasPair : typeAliases)
+    if (canRegisterAlias(typeAliasPair.second, usedAliases))
+      typeToAlias.insert(typeAliasPair);
+}
+
+// Initializes module state, populating affine map and integer set state.
+void ModuleState::initialize(Operation *op) {
+  // Initialize the symbol aliases.
+  initializeSymbolAliases();
+
+  // Visit each of the nested operations.
+  op->walk([&](Operation *op) { visitOperation(op); });
+}
+
+//===----------------------------------------------------------------------===//
+// ModulePrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModulePrinter {
+public:
+  ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
+  explicit ModulePrinter(ModulePrinter &printer)
+      : os(printer.os), state(printer.state) {}
+
+  template <typename Container, typename UnaryFunctor>
+  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
+    interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
+  }
+
+  void print(ModuleOp module);
+
+  /// Print the given attribute. If 'mayElideType' is true, some attributes are
+  /// printed without the type when the type matches the default used in the
+  /// parser (for example i64 is the default for integer attributes).
+  void printAttribute(Attribute attr, bool mayElideType = false);
+
+  void printType(Type type);
+  void printLocation(LocationAttr loc);
+
+  void printAffineMap(AffineMap map);
+  void printAffineExpr(
+      AffineExpr expr,
+      llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
+  void printAffineConstraint(AffineExpr expr, bool isEq);
+  void printIntegerSet(IntegerSet set);
+
+protected:
+  raw_ostream &os;
+  ModuleState &state;
+
+  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+                             ArrayRef<StringRef> elidedAttrs = {});
+  void printTrailingLocation(Location loc);
+  void printLocationInternal(LocationAttr loc, bool pretty = false);
+  void printDenseElementsAttr(DenseElementsAttr attr);
+
+  /// This enum is used to represent the binding stength of the enclosing
+  /// context that an AffineExprStorage is being printed in, so we can
+  /// intelligently produce parens.
+  enum class BindingStrength {
+    Weak,   // + and -
+    Strong, // All other binary operators.
+  };
+  void printAffineExprInternal(
+      AffineExpr expr, BindingStrength enclosingTightness,
+      llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
+};
+} // end anonymous namespace
+
+void ModulePrinter::printTrailingLocation(Location loc) {
+  // Check to see if we are printing debug information.
+  if (!shouldPrintDebugInfoOpt)
+    return;
+
+  os << " ";
+  printLocation(loc);
+}
+
+void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
+  switch (loc.getKind()) {
+  case StandardAttributes::UnknownLocation:
+    if (pretty)
+      os << "[unknown]";
+    else
+      os << "unknown";
+    break;
+  case StandardAttributes::FileLineColLocation: {
+    auto fileLoc = loc.cast<FileLineColLoc>();
+    auto mayQuote = pretty ? "" : "\"";
+    os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
+       << fileLoc.getLine() << ':' << fileLoc.getColumn();
+    break;
+  }
+  case StandardAttributes::NameLocation: {
+    auto nameLoc = loc.cast<NameLoc>();
+    os << '\"' << nameLoc.getName() << '\"';
+
+    // Print the child if it isn't unknown.
+    auto childLoc = nameLoc.getChildLoc();
+    if (!childLoc.isa<UnknownLoc>()) {
+      os << '(';
+      printLocationInternal(childLoc, pretty);
+      os << ')';
+    }
+    break;
+  }
+  case StandardAttributes::CallSiteLocation: {
+    auto callLocation = loc.cast<CallSiteLoc>();
+    auto caller = callLocation.getCaller();
+    auto callee = callLocation.getCallee();
+    if (!pretty)
+      os << "callsite(";
+    printLocationInternal(callee, pretty);
+    if (pretty) {
+      if (callee.isa<NameLoc>()) {
+        if (caller.isa<FileLineColLoc>()) {
+          os << " at ";
+        } else {
+          os << "\n at ";
+        }
+      } else {
+        os << "\n at ";
+      }
+    } else {
+      os << " at ";
+    }
+    printLocationInternal(caller, pretty);
+    if (!pretty)
+      os << ")";
+    break;
+  }
+  case StandardAttributes::FusedLocation: {
+    auto fusedLoc = loc.cast<FusedLoc>();
+    if (!pretty)
+      os << "fused";
+    if (auto metadata = fusedLoc.getMetadata())
+      os << '<' << metadata << '>';
+    os << '[';
+    interleave(
+        fusedLoc.getLocations(),
+        [&](Location loc) { printLocationInternal(loc, pretty); },
+        [&]() { os << ", "; });
+    os << ']';
+    break;
+  }
+  }
+}
+
+/// Print a floating point value in a way that the parser will be able to
+/// round-trip losslessly.
+static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
+  // We would like to output the FP constant value in exponential notation,
+  // but we cannot do this if doing so will lose precision.  Check here to
+  // make sure that we only output it in exponential format if we can parse
+  // the value back and get the same value.
+  bool isInf = apValue.isInfinity();
+  bool isNaN = apValue.isNaN();
+  if (!isInf && !isNaN) {
+    SmallString<128> strValue;
+    apValue.toString(strValue, 6, 0, false);
+
+    // Check to make sure that the stringized number is not some string like
+    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
+    // that the string matches the "[-+]?[0-9]" regex.
+    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
+            ((strValue[0] == '-' || strValue[0] == '+') &&
+             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
+           "[-+]?[0-9] regex does not match!");
+
+    // Parse back the stringized version and check that the value is equal
+    // (i.e., there is no precision loss). If it is not, use the default format
+    // of APFloat instead of the exponential notation.
+    if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
+      strValue.clear();
+      apValue.toString(strValue);
+    }
+    os << strValue;
+    return;
+  }
+
+  // Print special values in hexadecimal format.  The sign bit should be
+  // included in the literal.
+  SmallVector<char, 16> str;
+  APInt apInt = apValue.bitcastToAPInt();
+  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
+                 /*formatAsCLiteral=*/true);
+  os << str;
+}
+
+void ModulePrinter::printLocation(LocationAttr loc) {
+  if (printPrettyDebugInfo) {
+    printLocationInternal(loc, /*pretty=*/true);
+  } else {
+    os << "loc(";
+    printLocationInternal(loc);
+    os << ')';
+  }
+}
+
+/// Returns if the given dialect symbol data is simple enough to print in the
+/// pretty form, i.e. without the enclosing "".
+static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
+  // The name must start with an identifier.
+  if (symName.empty() || !isalpha(symName.front()))
+    return false;
+
+  // Ignore all the characters that are valid in an identifier in the symbol
+  // name.
+  symName =
+      symName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; });
+  if (symName.empty())
+    return true;
+
+  // If we got to an unexpected character, then it must be a <>.  Check those
+  // recursively.
+  if (symName.front() != '<' || symName.back() != '>')
+    return false;
+
+  SmallVector<char, 8> nestedPunctuation;
+  do {
+    // If we ran out of characters, then we had a punctuation mismatch.
+    if (symName.empty())
+      return false;
+
+    auto c = symName.front();
+    symName = symName.drop_front();
+
+    switch (c) {
+    // We never allow null characters. This is an EOF indicator for the lexer
+    // which we could handle, but isn't important for any known dialect.
+    case '\0':
+      return false;
+    case '<':
+    case '[':
+    case '(':
+    case '{':
+      nestedPunctuation.push_back(c);
+      continue;
+    // Reject types with mismatched brackets.
+    case '>':
+      if (nestedPunctuation.pop_back_val() != '<')
+        return false;
+      break;
+    case ']':
+      if (nestedPunctuation.pop_back_val() != '[')
+        return false;
+      break;
+    case ')':
+      if (nestedPunctuation.pop_back_val() != '(')
+        return false;
+      break;
+    case '}':
+      if (nestedPunctuation.pop_back_val() != '{')
+        return false;
+      break;
+    default:
+      continue;
+    }
+
+    // We're done when the punctuation is fully matched.
+  } while (!nestedPunctuation.empty());
+
+  // If there were extra characters, then we failed.
+  return symName.empty();
+}
+
+/// Print the given dialect symbol to the stream.
+static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
+                               StringRef dialectName, StringRef symString) {
+  os << symPrefix << dialectName;
+
+  // If this symbol name is simple enough, print it directly in pretty form,
+  // otherwise, we print it as an escaped string.
+  if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
+    os << '.' << symString;
+    return;
+  }
+
+  // TODO: escape the symbol name, it could contain " characters.
+  os << "<\"" << symString << "\">";
+}
+
+void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
+  if (!attr) {
+    os << "<<NULL ATTRIBUTE>>";
+    return;
+  }
+
+  // Check for an alias for this attribute.
+  Twine alias = state.getAttributeAlias(attr);
+  if (!alias.isTriviallyEmpty()) {
+    os << '#' << alias;
+    return;
+  }
+
+  switch (attr.getKind()) {
+  default: {
+    auto &dialect = attr.getDialect();
+
+    // Ask the dialect to serialize the attribute to a string.
+    std::string attrName;
+    {
+      llvm::raw_string_ostream attrNameStr(attrName);
+      dialect.printAttribute(attr, attrNameStr);
+    }
+
+    printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+    break;
+  }
+  case StandardAttributes::Opaque: {
+    auto opaqueAttr = attr.cast<OpaqueAttr>();
+    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
+                       opaqueAttr.getAttrData());
+    break;
+  }
+  case StandardAttributes::Unit:
+    os << "unit";
+    break;
+  case StandardAttributes::Bool:
+    os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
+
+    // BoolAttr always elides the type.
+    return;
+  case StandardAttributes::Dictionary:
+    os << '{';
+    interleaveComma(attr.cast<DictionaryAttr>().getValue(),
+                    [&](NamedAttribute attr) {
+                      os << attr.first << " = ";
+                      printAttribute(attr.second);
+                    });
+    os << '}';
+    break;
+  case StandardAttributes::Integer: {
+    auto intAttr = attr.cast<IntegerAttr>();
+    // Print all integer attributes as signed unless i1.
+    bool isSigned = intAttr.getType().isIndex() ||
+                    intAttr.getType().getIntOrFloatBitWidth() != 1;
+    intAttr.getValue().print(os, isSigned);
+
+    // IntegerAttr elides the type if I64.
+    if (mayElideType && intAttr.getType().isInteger(64))
+      return;
+    break;
+  }
+  case StandardAttributes::Float: {
+    auto floatAttr = attr.cast<FloatAttr>();
+    printFloatValue(floatAttr.getValue(), os);
+
+    // FloatAttr elides the type if F64.
+    if (mayElideType && floatAttr.getType().isF64())
+      return;
+    break;
+  }
+  case StandardAttributes::String:
+    os << '"';
+    printEscapedString(attr.cast<StringAttr>().getValue(), os);
+    os << '"';
+    break;
+  case StandardAttributes::Array:
+    os << '[';
+    interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
+      printAttribute(attr, /*mayElideType=*/true);
+    });
+    os << ']';
+    break;
+  case StandardAttributes::AffineMap:
+    attr.cast<AffineMapAttr>().getValue().print(os);
+
+    // AffineMap always elides the type.
+    return;
+  case StandardAttributes::IntegerSet:
+    attr.cast<IntegerSetAttr>().getValue().print(os);
+    break;
+  case StandardAttributes::Type:
+    printType(attr.cast<TypeAttr>().getValue());
+    break;
+  case StandardAttributes::SymbolRef:
+    os << '@' << attr.cast<SymbolRefAttr>().getValue();
+    break;
+  case StandardAttributes::OpaqueElements: {
+    auto eltsAttr = attr.cast<OpaqueElementsAttr>();
+    os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
+    os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
+    break;
+  }
+  case StandardAttributes::DenseElements: {
+    auto eltsAttr = attr.cast<DenseElementsAttr>();
+    os << "dense<";
+    printDenseElementsAttr(eltsAttr);
+    os << '>';
+    break;
+  }
+  case StandardAttributes::SparseElements: {
+    auto elementsAttr = attr.cast<SparseElementsAttr>();
+    os << "sparse<";
+    printDenseElementsAttr(elementsAttr.getIndices());
+    os << ", ";
+    printDenseElementsAttr(elementsAttr.getValues());
+    os << '>';
+    break;
+  }
+
+  // Location attributes.
+  case StandardAttributes::CallSiteLocation:
+  case StandardAttributes::FileLineColLocation:
+  case StandardAttributes::FusedLocation:
+  case StandardAttributes::NameLocation:
+  case StandardAttributes::UnknownLocation:
+    printLocation(attr.cast<LocationAttr>());
+    break;
+  }
+
+  // Print the type if it isn't a 'none' type.
+  auto attrType = attr.getType();
+  if (!attrType.isa<NoneType>()) {
+    os << " : ";
+    printType(attrType);
+  }
+}
+
+/// Print the integer element of the given DenseElementsAttr at 'index'.
+static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
+                                 unsigned index) {
+  APInt value = *std::next(attr.getIntValues().begin(), index);
+  if (value.getBitWidth() == 1)
+    os << (value.getBoolValue() ? "true" : "false");
+  else
+    value.print(os, /*isSigned=*/true);
+}
+
+/// Print the float element of the given DenseElementsAttr at 'index'.
+static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
+                                   unsigned index) {
+  APFloat value = *std::next(attr.getFloatValues().begin(), index);
+  printFloatValue(value, os);
+}
+
+void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
+  auto type = attr.getType();
+  auto shape = type.getShape();
+  auto rank = type.getRank();
+
+  // The function used to print elements of this attribute.
+  auto printEltFn = type.getElementType().isa<IntegerType>()
+                        ? printDenseIntElement
+                        : printDenseFloatElement;
+
+  // Special case for 0-d and splat tensors.
+  if (attr.isSplat()) {
+    printEltFn(attr, os, 0);
+    return;
+  }
+
+  // Special case for degenerate tensors.
+  auto numElements = type.getNumElements();
+  if (numElements == 0) {
+    for (int i = 0; i < rank; ++i)
+      os << '[';
+    for (int i = 0; i < rank; ++i)
+      os << ']';
+    return;
+  }
+
+  // We use a mixed-radix counter to iterate through the shape. When we bump a
+  // non-least-significant digit, we emit a close bracket. When we next emit an
+  // element we re-open all closed brackets.
+
+  // The mixed-radix counter, with radices in 'shape'.
+  SmallVector<unsigned, 4> counter(rank, 0);
+  // The number of brackets that have been opened and not closed.
+  unsigned openBrackets = 0;
+
+  auto bumpCounter = [&]() {
+    // Bump the least significant digit.
+    ++counter[rank - 1];
+    // Iterate backwards bubbling back the increment.
+    for (unsigned i = rank - 1; i > 0; --i)
+      if (counter[i] >= shape[i]) {
+        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
+        counter[i] = 0;
+        ++counter[i - 1];
+        --openBrackets;
+        os << ']';
+      }
+  };
+
+  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
+    if (idx != 0)
+      os << ", ";
+    while (openBrackets++ < rank)
+      os << '[';
+    openBrackets = rank;
+    printEltFn(attr, os, idx);
+    bumpCounter();
+  }
+  while (openBrackets-- > 0)
+    os << ']';
+}
+
+void ModulePrinter::printType(Type type) {
+  // Check for an alias for this type.
+  StringRef alias = state.getTypeAlias(type);
+  if (!alias.empty()) {
+    os << '!' << alias;
+    return;
+  }
+
+  switch (type.getKind()) {
+  default: {
+    auto &dialect = type.getDialect();
+
+    // Ask the dialect to serialize the type to a string.
+    std::string typeName;
+    {
+      llvm::raw_string_ostream typeNameStr(typeName);
+      dialect.printType(type, typeNameStr);
+    }
+
+    printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+    return;
+  }
+  case Type::Kind::Opaque: {
+    auto opaqueTy = type.cast<OpaqueType>();
+    printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
+                       opaqueTy.getTypeData());
+    return;
+  }
+  case StandardTypes::Index:
+    os << "index";
+    return;
+  case StandardTypes::BF16:
+    os << "bf16";
+    return;
+  case StandardTypes::F16:
+    os << "f16";
+    return;
+  case StandardTypes::F32:
+    os << "f32";
+    return;
+  case StandardTypes::F64:
+    os << "f64";
+    return;
+
+  case StandardTypes::Integer: {
+    auto integer = type.cast<IntegerType>();
+    os << 'i' << integer.getWidth();
+    return;
+  }
+  case Type::Kind::Function: {
+    auto func = type.cast<FunctionType>();
+    os << '(';
+    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
+    os << ") -> ";
+    auto results = func.getResults();
+    if (results.size() == 1 && !results[0].isa<FunctionType>())
+      os << results[0];
+    else {
+      os << '(';
+      interleaveComma(results, [&](Type type) { printType(type); });
+      os << ')';
+    }
+    return;
+  }
+  case StandardTypes::Vector: {
+    auto v = type.cast<VectorType>();
+    os << "vector<";
+    for (auto dim : v.getShape())
+      os << dim << 'x';
+    os << v.getElementType() << '>';
+    return;
+  }
+  case StandardTypes::RankedTensor: {
+    auto v = type.cast<RankedTensorType>();
+    os << "tensor<";
+    for (auto dim : v.getShape()) {
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    os << v.getElementType() << '>';
+    return;
+  }
+  case StandardTypes::UnrankedTensor: {
+    auto v = type.cast<UnrankedTensorType>();
+    os << "tensor<*x";
+    printType(v.getElementType());
+    os << '>';
+    return;
+  }
+  case StandardTypes::MemRef: {
+    auto v = type.cast<MemRefType>();
+    os << "memref<";
+    for (auto dim : v.getShape()) {
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    printType(v.getElementType());
+    for (auto map : v.getAffineMaps()) {
+      os << ", ";
+      printAttribute(AffineMapAttr::get(map));
+    }
+    // Only print the memory space if it is the non-default one.
+    if (v.getMemorySpace())
+      os << ", " << v.getMemorySpace();
+    os << '>';
+    return;
+  }
+  case StandardTypes::Complex:
+    os << "complex<";
+    printType(type.cast<ComplexType>().getElementType());
+    os << '>';
+    return;
+  case StandardTypes::Tuple: {
+    auto tuple = type.cast<TupleType>();
+    os << "tuple<";
+    interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
+    os << '>';
+    return;
+  }
+  case StandardTypes::None:
+    os << "none";
+    return;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Affine expressions and maps
+//===----------------------------------------------------------------------===//
+
+void ModulePrinter::printAffineExpr(
+    AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
+  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
+}
+
+void ModulePrinter::printAffineExprInternal(
+    AffineExpr expr, BindingStrength enclosingTightness,
+    llvm::function_ref<void(unsigned, bool)> printValueName) {
+  const char *binopSpelling = nullptr;
+  switch (expr.getKind()) {
+  case AffineExprKind::SymbolId: {
+    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
+    if (printValueName)
+      printValueName(pos, /*isSymbol=*/true);
+    else
+      os << 's' << pos;
+    return;
+  }
+  case AffineExprKind::DimId: {
+    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+    if (printValueName)
+      printValueName(pos, /*isSymbol=*/false);
+    else
+      os << 'd' << pos;
+    return;
+  }
+  case AffineExprKind::Constant:
+    os << expr.cast<AffineConstantExpr>().getValue();
+    return;
+  case AffineExprKind::Add:
+    binopSpelling = " + ";
+    break;
+  case AffineExprKind::Mul:
+    binopSpelling = " * ";
+    break;
+  case AffineExprKind::FloorDiv:
+    binopSpelling = " floordiv ";
+    break;
+  case AffineExprKind::CeilDiv:
+    binopSpelling = " ceildiv ";
+    break;
+  case AffineExprKind::Mod:
+    binopSpelling = " mod ";
+    break;
+  }
+
+  auto binOp = expr.cast<AffineBinaryOpExpr>();
+  AffineExpr lhsExpr = binOp.getLHS();
+  AffineExpr rhsExpr = binOp.getRHS();
+
+  // Handle tightly binding binary operators.
+  if (binOp.getKind() != AffineExprKind::Add) {
+    if (enclosingTightness == BindingStrength::Strong)
+      os << '(';
+
+    // Pretty print multiplication with -1.
+    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
+    if (rhsConst && rhsConst.getValue() == -1) {
+      os << "-";
+      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
+      return;
+    }
+
+    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
+
+    os << binopSpelling;
+    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
+
+    if (enclosingTightness == BindingStrength::Strong)
+      os << ')';
+    return;
+  }
+
+  // Print out special "pretty" forms for add.
+  if (enclosingTightness == BindingStrength::Strong)
+    os << '(';
+
+  // Pretty print addition to a product that has a negative operand as a
+  // subtraction.
+  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
+    if (rhs.getKind() == AffineExprKind::Mul) {
+      AffineExpr rrhsExpr = rhs.getRHS();
+      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
+        if (rrhs.getValue() == -1) {
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
+                                  printValueName);
+          os << " - ";
+          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
+            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+                                    printValueName);
+          } else {
+            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
+                                    printValueName);
+          }
+
+          if (enclosingTightness == BindingStrength::Strong)
+            os << ')';
+          return;
+        }
+
+        if (rrhs.getValue() < -1) {
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
+                                  printValueName);
+          os << " - ";
+          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+                                  printValueName);
+          os << " * " << -rrhs.getValue();
+          if (enclosingTightness == BindingStrength::Strong)
+            os << ')';
+          return;
+        }
+      }
+    }
+  }
+
+  // Pretty print addition to a negative number as a subtraction.
+  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
+    if (rhsConst.getValue() < 0) {
+      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
+      os << " - " << -rhsConst.getValue();
+      if (enclosingTightness == BindingStrength::Strong)
+        os << ')';
+      return;
+    }
+  }
+
+  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
+
+  os << " + ";
+  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
+
+  if (enclosingTightness == BindingStrength::Strong)
+    os << ')';
+}
+
+void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
+  printAffineExprInternal(expr, BindingStrength::Weak);
+  isEq ? os << " == 0" : os << " >= 0";
+}
+
+void ModulePrinter::printAffineMap(AffineMap map) {
+  // Dimension identifiers.
+  os << '(';
+  for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
+    os << 'd' << i << ", ";
+  if (map.getNumDims() >= 1)
+    os << 'd' << map.getNumDims() - 1;
+  os << ')';
+
+  // Symbolic identifiers.
+  if (map.getNumSymbols() != 0) {
+    os << '[';
+    for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
+      os << 's' << i << ", ";
+    if (map.getNumSymbols() >= 1)
+      os << 's' << map.getNumSymbols() - 1;
+    os << ']';
+  }
+
+  // AffineMap should have at least one result.
+  assert(!map.getResults().empty());
+  // Result affine expressions.
+  os << " -> (";
+  interleaveComma(map.getResults(),
+                  [&](AffineExpr expr) { printAffineExpr(expr); });
+  os << ')';
+}
+
+void ModulePrinter::printIntegerSet(IntegerSet set) {
+  // Dimension identifiers.
+  os << '(';
+  for (unsigned i = 1; i < set.getNumDims(); ++i)
+    os << 'd' << i - 1 << ", ";
+  if (set.getNumDims() >= 1)
+    os << 'd' << set.getNumDims() - 1;
+  os << ')';
+
+  // Symbolic identifiers.
+  if (set.getNumSymbols() != 0) {
+    os << '[';
+    for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
+      os << 's' << i << ", ";
+    if (set.getNumSymbols() >= 1)
+      os << 's' << set.getNumSymbols() - 1;
+    os << ']';
+  }
+
+  // Print constraints.
+  os << " : (";
+  int numConstraints = set.getNumConstraints();
+  for (int i = 1; i < numConstraints; ++i) {
+    printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
+    os << ", ";
+  }
+  if (numConstraints >= 1)
+    printAffineConstraint(set.getConstraint(numConstraints - 1),
+                          set.isEq(numConstraints - 1));
+  os << ')';
+}
+
+//===----------------------------------------------------------------------===//
+// Operation printing
+//===----------------------------------------------------------------------===//
+
+void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+                                          ArrayRef<StringRef> elidedAttrs) {
+  // If there are no attributes, then there is nothing to be done.
+  if (attrs.empty())
+    return;
+
+  // Filter out any attributes that shouldn't be included.
+  SmallVector<NamedAttribute, 8> filteredAttrs;
+  for (auto attr : attrs) {
+    // If the caller has requested that this attribute be ignored, then drop it.
+    if (llvm::any_of(elidedAttrs,
+                     [&](StringRef elided) { return attr.first.is(elided); }))
+      continue;
+
+    // Otherwise add it to our filteredAttrs list.
+    filteredAttrs.push_back(attr);
+  }
+
+  // If there are no attributes left to print after filtering, then we're done.
+  if (filteredAttrs.empty())
+    return;
+
+  // Otherwise, print them all out in braces.
+  os << " {";
+  interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+    os << attr.first;
+
+    // Pretty printing elides the attribute value for unit attributes.
+    if (attr.second.isa<UnitAttr>())
+      return;
+
+    os << " = ";
+    printAttribute(attr.second);
+  });
+  os << '}';
+}
+
+namespace {
+
+// OperationPrinter contains common functionality for printing operations.
+class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
+public:
+  OperationPrinter(Operation *op, ModulePrinter &other);
+  OperationPrinter(Region *region, ModulePrinter &other);
+
+  // Methods to print operations.
+  void print(Operation *op);
+  void print(Block *block, bool printBlockArgs = true,
+             bool printBlockTerminator = true);
+
+  void printOperation(Operation *op);
+  void printGenericOp(Operation *op) override;
+
+  // Implement OpAsmPrinter.
+  raw_ostream &getStream() const override { return os; }
+  void printType(Type type) override { ModulePrinter::printType(type); }
+  void printAttribute(Attribute attr) override {
+    ModulePrinter::printAttribute(attr);
+  }
+  void printOperand(Value *value) override { printValueID(value); }
+
+  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+                             ArrayRef<StringRef> elidedAttrs = {}) override {
+    return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
+  };
+
+  enum { nameSentinel = ~0U };
+
+  void printBlockName(Block *block) {
+    auto id = getBlockID(block);
+    if (id != ~0U)
+      os << "^bb" << id;
+    else
+      os << "^INVALIDBLOCK";
+  }
+
+  unsigned getBlockID(Block *block) {
+    auto it = blockIDs.find(block);
+    return it != blockIDs.end() ? it->second : ~0U;
+  }
+
+  void printSuccessorAndUseList(Operation *term, unsigned index) override;
+
+  /// Print a region.
+  void printRegion(Region &blocks, bool printEntryBlockArgs,
+                   bool printBlockTerminators) override {
+    os << " {\n";
+    if (!blocks.empty()) {
+      auto *entryBlock = &blocks.front();
+      print(entryBlock,
+            printEntryBlockArgs && entryBlock->getNumArguments() != 0,
+            printBlockTerminators);
+      for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
+        print(&b);
+    }
+    os.indent(currentIndent) << "}";
+  }
+
+  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+                              ArrayRef<Value *> operands) override {
+    AffineMap map = mapAttr.getValue();
+    unsigned numDims = map.getNumDims();
+    auto printValueName = [&](unsigned pos, bool isSymbol) {
+      unsigned index = isSymbol ? numDims + pos : pos;
+      assert(index < operands.size());
+      if (isSymbol)
+        os << "symbol(";
+      printValueID(operands[index]);
+      if (isSymbol)
+        os << ')';
+    };
+
+    interleaveComma(map.getResults(), [&](AffineExpr expr) {
+      printAffineExpr(expr, printValueName);
+    });
+  }
+
+  // Number of spaces used for indenting nested operations.
+  const static unsigned indentWidth = 2;
+
+protected:
+  void numberValueID(Value *value);
+  void numberValuesInRegion(Region &region);
+  void numberValuesInBlock(Block &block);
+  void printValueID(Value *value, bool printResultNo = true) const;
+
+private:
+  /// Uniques the given value name within the printer. If the given name
+  /// conflicts, it is automatically renamed.
+  StringRef uniqueValueName(StringRef name);
+
+  /// This is the value ID for each SSA value. If this returns ~0, then the
+  /// valueID has an entry in valueNames.
+  DenseMap<Value *, unsigned> valueIDs;
+  DenseMap<Value *, StringRef> valueNames;
+
+  /// This is the block ID for each block in the current.
+  DenseMap<Block *, unsigned> blockIDs;
+
+  /// This keeps track of all of the non-numeric names that are in flight,
+  /// allowing us to check for duplicates.
+  /// Note: the value of the map is unused.
+  llvm::ScopedHashTable<StringRef, char> usedNames;
+  llvm::BumpPtrAllocator usedNameAllocator;
+
+  // This is the current indentation level for nested structures.
+  unsigned currentIndent = 0;
+
+  /// This is the next value ID to assign in numbering.
+  unsigned nextValueID = 0;
+  /// This is the next ID to assign to a region entry block argument.
+  unsigned nextArgumentID = 0;
+  /// This is the next ID to assign when a name conflict is detected.
+  unsigned nextConflictID = 0;
+};
+} // end anonymous namespace
+
+OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
+    : ModulePrinter(other) {
+  if (op->getNumResults() != 0)
+    numberValueID(op->getResult(0));
+  for (auto &region : op->getRegions())
+    numberValuesInRegion(region);
+}
+
+OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
+    : ModulePrinter(other) {
+  numberValuesInRegion(*region);
+}
+
+/// Number all of the SSA values in the specified region.
+void OperationPrinter::numberValuesInRegion(Region &region) {
+  // Save the current value ids to allow for numbering values in sibling regions
+  // the same.
+  unsigned curValueID = nextValueID;
+  unsigned curArgumentID = nextArgumentID;
+  unsigned curConflictID = nextConflictID;
+
+  // Push a new used names scope.
+  llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
+
+  // Number the values within this region in a breadth-first order.
+  unsigned nextBlockID = 0;
+  for (auto &block : region) {
+    // Each block gets a unique ID, and all of the operations within it get
+    // numbered as well.
+    blockIDs[&block] = nextBlockID++;
+    numberValuesInBlock(block);
+  }
+
+  // After that we traverse the nested regions.
+  // TODO: Rework this loop to not use recursion.
+  for (auto &block : region) {
+    for (auto &op : block)
+      for (auto &nestedRegion : op.getRegions())
+        numberValuesInRegion(nestedRegion);
+  }
+
+  // Restore the original value ids.
+  nextValueID = curValueID;
+  nextArgumentID = curArgumentID;
+  nextConflictID = curConflictID;
+}
+
+/// Number all of the SSA values in the specified block, without traversing
+/// nested regions.
+void OperationPrinter::numberValuesInBlock(Block &block) {
+  // Number the block arguments.
+  for (auto *arg : block.getArguments())
+    numberValueID(arg);
+
+  // We number operation that have results, and we only number the first result.
+  for (auto &op : block)
+    if (op.getNumResults() != 0)
+      numberValueID(op.getResult(0));
+}
+
+void OperationPrinter::numberValueID(Value *value) {
+  assert(!valueIDs.count(value) && "Value numbered multiple times");
+
+  SmallString<32> specialNameBuffer;
+  llvm::raw_svector_ostream specialName(specialNameBuffer);
+
+  // Give constant integers special names.
+  if (auto *op = value->getDefiningOp()) {
+    Attribute cst;
+    if (m_Constant(&cst).match(op)) {
+      Type type = op->getResult(0)->getType();
+      if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
+        if (type.isIndex()) {
+          specialName << 'c' << intCst.getInt();
+        } else if (type.cast<IntegerType>().isInteger(1)) {
+          // i1 constants get special names.
+          specialName << (intCst.getInt() ? "true" : "false");
+        } else {
+          specialName << 'c' << intCst.getInt() << '_' << type;
+        }
+      } else if (type.isa<FunctionType>()) {
+        specialName << 'f';
+      } else {
+        specialName << "cst";
+      }
+    }
+  }
+
+  if (specialNameBuffer.empty()) {
+    switch (value->getKind()) {
+    case Value::Kind::BlockArgument:
+      // If this is an argument to the entry block of a region, give it an 'arg'
+      // name.
+      if (auto *block = cast<BlockArgument>(value)->getOwner()) {
+        auto *parentRegion = block->getParent();
+        if (parentRegion && block == &parentRegion->front()) {
+          specialName << "arg" << nextArgumentID++;
+          break;
+        }
+      }
+      // Otherwise number it normally.
+      valueIDs[value] = nextValueID++;
+      return;
+    case Value::Kind::OpResult:
+      // This is an uninteresting result, give it a boring number and be
+      // done with it.
+      valueIDs[value] = nextValueID++;
+      return;
+    }
+  }
+
+  // Ok, this value had an interesting name.  Remember it with a sentinel.
+  valueIDs[value] = nameSentinel;
+  valueNames[value] = uniqueValueName(specialName.str());
+}
+
+/// Uniques the given value name within the printer. If the given name
+/// conflicts, it is automatically renamed.
+StringRef OperationPrinter::uniqueValueName(StringRef name) {
+  // Check to see if this name is already unique.
+  if (!usedNames.count(name)) {
+    name = name.copy(usedNameAllocator);
+  } else {
+    // Otherwise, we had a conflict - probe until we find a unique name. This
+    // is guaranteed to terminate (and usually in a single iteration) because it
+    // generates new names by incrementing nextConflictID.
+    SmallString<64> probeName(name);
+    probeName.push_back('_');
+    while (1) {
+      probeName.resize(name.size() + 1);
+      probeName += llvm::utostr(nextConflictID++);
+      if (!usedNames.count(probeName)) {
+        name = StringRef(probeName).copy(usedNameAllocator);
+        break;
+      }
+    }
+  }
+
+  usedNames.insert(name, char());
+  return name;
+}
+
+void OperationPrinter::print(Block *block, bool printBlockArgs,
+                             bool printBlockTerminator) {
+  // Print the block label and argument list if requested.
+  if (printBlockArgs) {
+    os.indent(currentIndent);
+    printBlockName(block);
+
+    // Print the argument list if non-empty.
+    if (!block->args_empty()) {
+      os << '(';
+      interleaveComma(block->getArguments(), [&](BlockArgument *arg) {
+        printValueID(arg);
+        os << ": ";
+        printType(arg->getType());
+      });
+      os << ')';
+    }
+    os << ':';
+
+    // Print out some context information about the predecessors of this block.
+    if (!block->getParent()) {
+      os << "\t// block is not in a region!";
+    } else if (block->hasNoPredecessors()) {
+      os << "\t// no predecessors";
+    } else if (auto *pred = block->getSinglePredecessor()) {
+      os << "\t// pred: ";
+      printBlockName(pred);
+    } else {
+      // We want to print the predecessors in increasing numeric order, not in
+      // whatever order the use-list is in, so gather and sort them.
+      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
+      for (auto *pred : block->getPredecessors())
+        predIDs.push_back({getBlockID(pred), pred});
+      llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+
+      os << "\t// " << predIDs.size() << " preds: ";
+
+      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
+        printBlockName(pred.second);
+      });
+    }
+    os << '\n';
+  }
+
+  currentIndent += indentWidth;
+  auto range = llvm::make_range(
+      block->getOperations().begin(),
+      std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1));
+  for (auto &op : range) {
+    print(&op);
+    os << '\n';
+  }
+  currentIndent -= indentWidth;
+}
+
+void OperationPrinter::print(Operation *op) {
+  os.indent(currentIndent);
+  printOperation(op);
+  printTrailingLocation(op->getLoc());
+}
+
+void OperationPrinter::printValueID(Value *value, bool printResultNo) const {
+  int resultNo = -1;
+  auto lookupValue = value;
+
+  // If this is a reference to the result of a multi-result operation or
+  // operation, print out the # identifier and make sure to map our lookup
+  // to the first result of the operation.
+  if (auto *result = dyn_cast<OpResult>(value)) {
+    if (result->getOwner()->getNumResults() != 1) {
+      resultNo = result->getResultNumber();
+      lookupValue = result->getOwner()->getResult(0);
+    }
+  }
+
+  auto it = valueIDs.find(lookupValue);
+  if (it == valueIDs.end()) {
+    os << "<<INVALID SSA VALUE>>";
+    return;
+  }
+
+  os << '%';
+  if (it->second != nameSentinel) {
+    os << it->second;
+  } else {
+    auto nameIt = valueNames.find(lookupValue);
+    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
+    os << nameIt->second;
+  }
+
+  if (resultNo != -1 && printResultNo)
+    os << '#' << resultNo;
+}
+
+void OperationPrinter::printOperation(Operation *op) {
+  if (size_t numResults = op->getNumResults()) {
+    printValueID(op->getResult(0), /*printResultNo=*/false);
+    if (numResults > 1)
+      os << ':' << numResults;
+    os << " = ";
+  }
+
+  // TODO(riverriddle): FuncOp cannot be round-tripped currently, as
+  // FunctionType cannot be used in a TypeAttr.
+  if (printGenericOpForm && !isa<FuncOp>(op))
+    return printGenericOp(op);
+
+  // Check to see if this is a known operation.  If so, use the registered
+  // custom printer hook.
+  if (auto *opInfo = op->getAbstractOperation()) {
+    opInfo->printAssembly(op, this);
+    return;
+  }
+
+  // Otherwise print with the generic assembly form.
+  printGenericOp(op);
+}
+
+void OperationPrinter::printGenericOp(Operation *op) {
+  os << '"';
+  printEscapedString(op->getName().getStringRef(), os);
+  os << "\"(";
+
+  // Get the list of operands that are not successor operands.
+  unsigned totalNumSuccessorOperands = 0;
+  unsigned numSuccessors = op->getNumSuccessors();
+  for (unsigned i = 0; i < numSuccessors; ++i)
+    totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
+  unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
+  SmallVector<Value *, 8> properOperands(
+      op->operand_begin(), std::next(op->operand_begin(), numProperOperands));
+
+  interleaveComma(properOperands, [&](Value *value) { printValueID(value); });
+
+  os << ')';
+
+  // For terminators, print the list of successors and their operands.
+  if (numSuccessors != 0) {
+    os << '[';
+    for (unsigned i = 0; i < numSuccessors; ++i) {
+      if (i != 0)
+        os << ", ";
+      printSuccessorAndUseList(op, i);
+    }
+    os << ']';
+  }
+
+  // Print regions.
+  if (op->getNumRegions() != 0) {
+    os << " (";
+    interleaveComma(op->getRegions(), [&](Region &region) {
+      printRegion(region, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/true);
+    });
+    os << ')';
+  }
+
+  auto attrs = op->getAttrs();
+  printOptionalAttrDict(attrs);
+
+  // Print the type signature of the operation.
+  os << " : ";
+  printFunctionalType(op);
+}
+
+void OperationPrinter::printSuccessorAndUseList(Operation *term,
+                                                unsigned index) {
+  printBlockName(term->getSuccessor(index));
+
+  auto succOperands = term->getSuccessorOperands(index);
+  if (succOperands.begin() == succOperands.end())
+    return;
+
+  os << '(';
+  interleaveComma(succOperands,
+                  [this](Value *operand) { printValueID(operand); });
+  os << " : ";
+  interleaveComma(succOperands,
+                  [this](Value *operand) { printType(operand->getType()); });
+  os << ')';
+}
+
+void ModulePrinter::print(ModuleOp module) {
+  // Output the aliases at the top level.
+  state.printAttributeAliases(os);
+  state.printTypeAliases(os);
+
+  // Print the module.
+  OperationPrinter(module, *this).print(module);
+  os << '\n';
+}
+
+//===----------------------------------------------------------------------===//
+// print and dump methods
+//===----------------------------------------------------------------------===//
+
+void Attribute::print(raw_ostream &os) const {
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter(os, state).printAttribute(*this);
+}
+
+void Attribute::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+void Type::print(raw_ostream &os) {
+  ModuleState state(getContext());
+  ModulePrinter(os, state).printType(*this);
+}
+
+void Type::dump() { print(llvm::errs()); }
+
+void AffineMap::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+void IntegerSet::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+void AffineExpr::print(raw_ostream &os) const {
+  if (expr == nullptr) {
+    os << "null affine expr";
+    return;
+  }
+  ModuleState state(getContext());
+  ModulePrinter(os, state).printAffineExpr(*this);
+}
+
+void AffineExpr::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+void AffineMap::print(raw_ostream &os) const {
+  if (map == nullptr) {
+    os << "null affine map";
+    return;
+  }
+  ModuleState state(getContext());
+  ModulePrinter(os, state).printAffineMap(*this);
+}
+
+void IntegerSet::print(raw_ostream &os) const {
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter(os, state).printIntegerSet(*this);
+}
+
+void Value::print(raw_ostream &os) {
+  switch (getKind()) {
+  case Value::Kind::BlockArgument:
+    // TODO: Improve this.
+    os << "<block argument>\n";
+    return;
+  case Value::Kind::OpResult:
+    return getDefiningOp()->print(os);
+  }
+}
+
+void Value::dump() { print(llvm::errs()); }
+
+void Operation::print(raw_ostream &os) {
+  // Handle top-level operations.
+  if (!getParent()) {
+    ModuleState state(getContext());
+    ModulePrinter modulePrinter(os, state);
+    OperationPrinter(this, modulePrinter).print(this);
+    return;
+  }
+
+  auto region = getParentRegion();
+  if (!region) {
+    os << "<<UNLINKED INSTRUCTION>>\n";
+    return;
+  }
+
+  // Get the top-level region.
+  while (auto *nextRegion = region->getParentRegion())
+    region = nextRegion;
+
+  ModuleState state(getContext());
+  ModulePrinter modulePrinter(os, state);
+  OperationPrinter(region, modulePrinter).print(this);
+}
+
+void Operation::dump() {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+void Block::print(raw_ostream &os) {
+  auto region = getParent();
+  if (!region) {
+    os << "<<UNLINKED BLOCK>>\n";
+    return;
+  }
+
+  // Get the top-level region.
+  while (auto *nextRegion = region->getParentRegion())
+    region = nextRegion;
+
+  ModuleState state(region->getContext());
+  ModulePrinter modulePrinter(os, state);
+  OperationPrinter(region, modulePrinter).print(this);
+}
+
+void Block::dump() { print(llvm::errs()); }
+
+/// Print out the name of the block without printing its body.
+void Block::printAsOperand(raw_ostream &os, bool printType) {
+  auto region = getParent();
+  if (!region) {
+    os << "<<UNLINKED BLOCK>>\n";
+    return;
+  }
+
+  // Get the top-level region.
+  while (auto *nextRegion = region->getParentRegion())
+    region = nextRegion;
+
+  ModuleState state(region->getContext());
+  ModulePrinter modulePrinter(os, state);
+  OperationPrinter(region, modulePrinter).printBlockName(this);
+}
+
+void ModuleOp::print(raw_ostream &os) {
+  ModuleState state(getContext());
+  state.initialize(*this);
+  ModulePrinter(os, state).print(*this);
+}
+
+void ModuleOp::dump() { print(llvm::errs()); }
diff --git a/third_party/mlir/lib/IR/AttributeDetail.h b/third_party/mlir/lib/IR/AttributeDetail.h
new file mode 100644
index 0000000..21f8b68
--- /dev/null
+++ b/third_party/mlir/lib/IR/AttributeDetail.h
@@ -0,0 +1,567 @@
+//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ATTRIBUTEDETAIL_H_
+#define ATTRIBUTEDETAIL_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StorageUniquer.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/Support/TrailingObjects.h"
+
+namespace mlir {
+namespace detail {
+// An attribute representing a reference to an affine map.
+struct AffineMapAttributeStorage : public AttributeStorage {
+  using KeyTy = AffineMap;
+
+  AffineMapAttributeStorage(AffineMap value)
+      : AttributeStorage(IndexType::get(value.getContext())), value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static AffineMapAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<AffineMapAttributeStorage>())
+        AffineMapAttributeStorage(key);
+  }
+
+  AffineMap value;
+};
+
+/// An attribute representing an array of other attributes.
+struct ArrayAttributeStorage : public AttributeStorage {
+  using KeyTy = ArrayRef<Attribute>;
+
+  ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                          const KeyTy &key) {
+    return new (allocator.allocate<ArrayAttributeStorage>())
+        ArrayAttributeStorage(allocator.copyInto(key));
+  }
+
+  ArrayRef<Attribute> value;
+};
+
+/// An attribute representing a boolean value.
+struct BoolAttributeStorage : public AttributeStorage {
+  using KeyTy = std::pair<MLIRContext *, bool>;
+
+  BoolAttributeStorage(Type type, bool value)
+      : AttributeStorage(type), value(value) {}
+
+  /// We only check equality for and hash with the boolean key parameter.
+  bool operator==(const KeyTy &key) const { return key.second == value; }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_value(key.second);
+  }
+
+  static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    return new (allocator.allocate<BoolAttributeStorage>())
+        BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
+  }
+
+  bool value;
+};
+
+/// An attribute representing a dictionary of sorted named attributes.
+struct DictionaryAttributeStorage final
+    : public AttributeStorage,
+      private llvm::TrailingObjects<DictionaryAttributeStorage,
+                                    NamedAttribute> {
+  using KeyTy = ArrayRef<NamedAttribute>;
+
+  /// Given a list of NamedAttribute's, canonicalize the list (sorting
+  /// by name) and return the unique'd result.
+  static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs);
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == getElements(); }
+
+  /// Construct a new storage instance.
+  static DictionaryAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>(
+        key.size());
+    auto rawMem = allocator.allocate(size, alignof(NamedAttribute));
+
+    // Initialize the storage and trailing attribute list.
+    auto result = ::new (rawMem) DictionaryAttributeStorage(key.size());
+    std::uninitialized_copy(key.begin(), key.end(),
+                            result->getTrailingObjects<NamedAttribute>());
+    return result;
+  }
+
+  /// Return the elements of this dictionary attribute.
+  ArrayRef<NamedAttribute> getElements() const {
+    return {getTrailingObjects<NamedAttribute>(), numElements};
+  }
+
+private:
+  friend class llvm::TrailingObjects<DictionaryAttributeStorage,
+                                     NamedAttribute>;
+
+  // This is used by the llvm::TrailingObjects base class.
+  size_t numTrailingObjects(OverloadToken<NamedAttribute>) const {
+    return numElements;
+  }
+  DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {}
+
+  /// This is the number of attributes.
+  const unsigned numElements;
+};
+
+/// An attribute representing a floating point value.
+struct FloatAttributeStorage final
+    : public AttributeStorage,
+      public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
+  using KeyTy = std::pair<Type, APFloat>;
+
+  FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
+                        size_t numObjects)
+      : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key.first == getType() && key.second.bitwiseIsEqual(getValue());
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+  }
+
+  /// Construct a key with a type and double.
+  static KeyTy getKey(Type type, double value) {
+    // Treat BF16 as double because it is not supported in LLVM's APFloat.
+    // TODO(b/121118307): add BF16 support to APFloat?
+    if (type.isBF16() || type.isF64())
+      return KeyTy(type, APFloat(value));
+
+    // This handles, e.g., F16 because there is no APFloat constructor for it.
+    bool unused;
+    APFloat val(value);
+    val.convert(type.cast<FloatType>().getFloatSemantics(),
+                APFloat::rmNearestTiesToEven, &unused);
+    return KeyTy(type, val);
+  }
+
+  /// Construct a new storage instance.
+  static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                          const KeyTy &key) {
+    const auto &apint = key.second.bitcastToAPInt();
+
+    // Here one word's bitwidth equals to that of uint64_t.
+    auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
+
+    auto byteSize =
+        FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+    auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
+    auto result = ::new (rawMem) FloatAttributeStorage(
+        key.second.getSemantics(), key.first, elements.size());
+    std::uninitialized_copy(elements.begin(), elements.end(),
+                            result->getTrailingObjects<uint64_t>());
+    return result;
+  }
+
+  /// Returns an APFloat representing the stored value.
+  APFloat getValue() const {
+    auto val = APInt(APFloat::getSizeInBits(semantics),
+                     {getTrailingObjects<uint64_t>(), numObjects});
+    return APFloat(semantics, val);
+  }
+
+  const llvm::fltSemantics &semantics;
+  size_t numObjects;
+};
+
+/// An attribute representing a integral value.
+struct IntegerAttributeStorage final
+    : public AttributeStorage,
+      public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
+  using KeyTy = std::pair<Type, APInt>;
+
+  IntegerAttributeStorage(Type type, size_t numObjects)
+      : AttributeStorage(type), numObjects(numObjects) {
+    assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
+  }
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getType(), getValue());
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+  }
+
+  /// Construct a new storage instance.
+  static IntegerAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    Type type;
+    APInt value;
+    std::tie(type, value) = key;
+
+    auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
+    auto size =
+        IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+    auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
+    auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
+    std::uninitialized_copy(elements.begin(), elements.end(),
+                            result->getTrailingObjects<uint64_t>());
+    return result;
+  }
+
+  /// Returns an APInt representing the stored value.
+  APInt getValue() const {
+    if (getType().isIndex())
+      return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
+    return APInt(getType().getIntOrFloatBitWidth(),
+                 {getTrailingObjects<uint64_t>(), numObjects});
+  }
+
+  size_t numObjects;
+};
+
+// An attribute representing a reference to an integer set.
+struct IntegerSetAttributeStorage : public AttributeStorage {
+  using KeyTy = IntegerSet;
+
+  IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static IntegerSetAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<IntegerSetAttributeStorage>())
+        IntegerSetAttributeStorage(key);
+  }
+
+  IntegerSet value;
+};
+
+/// Opaque Attribute Storage and Uniquing.
+struct OpaqueAttributeStorage : public AttributeStorage {
+  OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData,
+                         Type type)
+      : AttributeStorage(type), dialectNamespace(dialectNamespace),
+        attrData(attrData) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::tuple<Identifier, StringRef, Type>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(dialectNamespace, attrData, getType());
+  }
+
+  static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    return new (allocator.allocate<OpaqueAttributeStorage>())
+        OpaqueAttributeStorage(std::get<0>(key),
+                               allocator.copyInto(std::get<1>(key)),
+                               std::get<2>(key));
+  }
+
+  // The dialect namespace.
+  Identifier dialectNamespace;
+
+  // The parser attribute data for this opaque attribute.
+  StringRef attrData;
+};
+
+/// An attribute representing a string value.
+struct StringAttributeStorage : public AttributeStorage {
+  using KeyTy = std::pair<StringRef, Type>;
+
+  StringAttributeStorage(StringRef value, Type type)
+      : AttributeStorage(type), value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(value, getType());
+  }
+
+  /// Construct a new storage instance.
+  static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    return new (allocator.allocate<StringAttributeStorage>())
+        StringAttributeStorage(allocator.copyInto(key.first), key.second);
+  }
+
+  StringRef value;
+};
+
+/// An attribute representing a reference to a type.
+struct TypeAttributeStorage : public AttributeStorage {
+  using KeyTy = Type;
+
+  TypeAttributeStorage(Type value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                         KeyTy key) {
+    return new (allocator.allocate<TypeAttributeStorage>())
+        TypeAttributeStorage(key);
+  }
+
+  Type value;
+};
+
+//===----------------------------------------------------------------------===//
+// Elements Attributes
+//===----------------------------------------------------------------------===//
+
+/// An attribute representing a reference to a dense vector or tensor object.
+struct DenseElementsAttributeStorage : public AttributeStorage {
+  struct KeyTy {
+    KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
+          bool isSplat = false)
+        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+
+    /// The type of the dense elements.
+    ShapedType type;
+
+    /// The raw buffer for the data storage.
+    ArrayRef<char> data;
+
+    /// The computed hash code for the storage data.
+    llvm::hash_code hashCode;
+
+    /// A boolean that indicates if this data is a splat or not.
+    bool isSplat;
+  };
+
+  DenseElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
+                                bool isSplat = false)
+      : AttributeStorage(ty), data(data), isSplat(isSplat) {}
+
+  /// Compare this storage instance with the provided key.
+  bool operator==(const KeyTy &key) const {
+    if (key.type != getType())
+      return false;
+
+    // For boolean splats we need to explicitly check that the first bit is the
+    // same. Boolean values are packed at the bit level, and even though a splat
+    // is detected the rest of the bits in the first byte may differ from the
+    // splat value.
+    if (key.type.getElementTypeBitWidth() == 1) {
+      if (key.isSplat != isSplat)
+        return false;
+      if (isSplat)
+        return (key.data.front() & 1) == data.front();
+    }
+
+    // Otherwise, we can default to just checking the data.
+    return key.data == data;
+  }
+
+  /// Construct a key from a shaped type, raw data buffer, and a flag that
+  /// signals if the data is already known to be a splat. Callers to this
+  /// function are expected to tag preknown splat values when possible, e.g. one
+  /// element shapes.
+  static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
+    // Handle an empty storage instance.
+    if (data.empty())
+      return KeyTy(ty, data, 0);
+
+    // If the data is already known to be a splat, the key hash value is
+    // directly the data buffer.
+    if (isKnownSplat)
+      return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+
+    // Otherwise, we need to check if the data corresponds to a splat or not.
+
+    // Handle the simple case of only one element.
+    size_t numElements = ty.getNumElements();
+    assert(numElements != 1 && "splat of 1 element should already be detected");
+
+    // Handle boolean values directly as they are packed to 1-bit.
+    size_t elementWidth = ty.getElementTypeBitWidth();
+    if (elementWidth == 1)
+      return getKeyForBoolData(ty, data, numElements);
+
+    // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+    // with double semantics.
+    if (ty.getElementType().isBF16())
+      elementWidth = 64;
+
+    // Non 1-bit dense elements are padded to 8-bits.
+    size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
+    assert(((data.size() / storageSize) == numElements) &&
+           "data does not hold expected number of elements");
+
+    // Create the initial hash value with just the first element.
+    auto firstElt = data.take_front(storageSize);
+    auto hashVal = llvm::hash_value(firstElt);
+
+    // Check to see if this storage represents a splat. If it doesn't then
+    // combine the hash for the data starting with the first non splat element.
+    for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
+      if (memcmp(data.data(), &data[i], storageSize))
+        return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
+
+    // Otherwise, this is a splat so just return the hash of the first element.
+    return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
+  }
+
+  /// Construct a key with a set of boolean data.
+  static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
+                                 size_t numElements) {
+    ArrayRef<char> splatData = data;
+    bool splatValue = splatData.front() & 1;
+
+    // Helper functor to generate a KeyTy for a boolean splat value.
+    auto generateSplatKey = [=] {
+      return KeyTy(ty, data.take_front(1),
+                   llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
+                   /*isSplat=*/true);
+    };
+
+    // Handle the case where the potential splat value is 1 and the number of
+    // elements is non 8-bit aligned.
+    size_t numOddElements = numElements % CHAR_BIT;
+    if (splatValue && numOddElements != 0) {
+      // Check that all bits are set in the last value.
+      char lastElt = splatData.back();
+      if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
+        return KeyTy(ty, data, llvm::hash_value(data));
+
+      // If this is the only element, the data is known to be a splat.
+      if (splatData.size() == 1)
+        return generateSplatKey();
+      splatData = splatData.drop_back();
+    }
+
+    // Check that the data buffer corresponds to a splat of the proper mask.
+    char mask = splatValue ? ~0 : 0;
+    return llvm::all_of(splatData, [mask](char c) { return c == mask; })
+               ? generateSplatKey()
+               : KeyTy(ty, data, llvm::hash_value(data));
+  }
+
+  /// Hash the key for the storage.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.type, key.hashCode);
+  }
+
+  /// Construct a new storage instance.
+  static DenseElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    // If the data buffer is non-empty, we copy it into the allocator with a
+    // 64-bit alignment.
+    ArrayRef<char> copy, data = key.data;
+    if (!data.empty()) {
+      char *rawData = reinterpret_cast<char *>(
+          allocator.allocate(data.size(), alignof(uint64_t)));
+      std::memcpy(rawData, data.data(), data.size());
+
+      // If this is a boolean splat, make sure only the first bit is used.
+      if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
+        rawData[0] &= 1;
+      copy = ArrayRef<char>(rawData, data.size());
+    }
+
+    return new (allocator.allocate<DenseElementsAttributeStorage>())
+        DenseElementsAttributeStorage(key.type, copy, key.isSplat);
+  }
+
+  ArrayRef<char> data;
+  bool isSplat;
+};
+
+/// An attribute representing a reference to a tensor constant with opaque
+/// content.
+struct OpaqueElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Type, Dialect *, StringRef>;
+
+  OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
+      : AttributeStorage(type), dialect(dialect), bytes(bytes) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == std::make_tuple(getType(), dialect, bytes);
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              std::get<2>(key));
+  }
+
+  /// Construct a new storage instance.
+  static OpaqueElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    // TODO(b/131468830): Provide a way to avoid copying content of large opaque
+    // tensors This will likely require a new reference attribute kind.
+    return new (allocator.allocate<OpaqueElementsAttributeStorage>())
+        OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+                                       allocator.copyInto(std::get<2>(key)));
+  }
+
+  Dialect *dialect;
+  StringRef bytes;
+};
+
+/// An attribute representing a reference to a sparse vector or tensor object.
+struct SparseElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
+
+  SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
+                                 DenseElementsAttr values)
+      : AttributeStorage(type), indices(indices), values(values) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == std::make_tuple(getType(), indices, values);
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              std::get<2>(key));
+  }
+
+  /// Construct a new storage instance.
+  static SparseElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<SparseElementsAttributeStorage>())
+        SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+                                       std::get<2>(key));
+  }
+
+  DenseIntElementsAttr indices;
+  DenseElementsAttr values;
+};
+} // namespace detail
+} // namespace mlir
+
+#endif // ATTRIBUTEDETAIL_H_
diff --git a/third_party/mlir/lib/IR/Attributes.cpp b/third_party/mlir/lib/IR/Attributes.cpp
new file mode 100644
index 0000000..e2a401c
--- /dev/null
+++ b/third_party/mlir/lib/IR/Attributes.cpp
@@ -0,0 +1,1041 @@
+//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "AttributeDetail.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Twine.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// AttributeStorage
+//===----------------------------------------------------------------------===//
+
+AttributeStorage::AttributeStorage(Type type)
+    : type(type.getAsOpaquePointer()) {}
+AttributeStorage::AttributeStorage() : type(nullptr) {}
+
+Type AttributeStorage::getType() const {
+  return Type::getFromOpaquePointer(type);
+}
+void AttributeStorage::setType(Type newType) {
+  type = newType.getAsOpaquePointer();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute
+//===----------------------------------------------------------------------===//
+
+/// Return the type of this attribute.
+Type Attribute::getType() const { return impl->getType(); }
+
+/// Return the context this attribute belongs to.
+MLIRContext *Attribute::getContext() const { return getType().getContext(); }
+
+/// Get the dialect this attribute is registered to.
+Dialect &Attribute::getDialect() const { return impl->getDialect(); }
+
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
+AffineMapAttr AffineMapAttr::get(AffineMap value) {
+  return Base::get(value.getResult(0).getContext(),
+                   StandardAttributes::AffineMap, value);
+}
+
+AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+  return Base::get(context, StandardAttributes::Array, value);
+}
+
+ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// BoolAttr
+//===----------------------------------------------------------------------===//
+
+bool BoolAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// Perform a three-way comparison between the names of the specified
+/// NamedAttributes.
+static int compareNamedAttributes(const NamedAttribute *lhs,
+                                  const NamedAttribute *rhs) {
+  return lhs->first.str().compare(rhs->first.str());
+}
+
+DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
+                                   MLIRContext *context) {
+  assert(llvm::all_of(value,
+                      [](const NamedAttribute &attr) { return attr.second; }) &&
+         "value cannot have null entries");
+
+  // We need to sort the element list to canonicalize it, but we also don't want
+  // to do a ton of work in the super common case where the element list is
+  // already sorted.
+  SmallVector<NamedAttribute, 8> storage;
+  switch (value.size()) {
+  case 0:
+    break;
+  case 1:
+    // A single element is already sorted.
+    break;
+  case 2:
+    assert(value[0].first != value[1].first &&
+           "DictionaryAttr element names must be unique");
+
+    // Don't invoke a general sort for two element case.
+    if (value[0].first.strref() > value[1].first.strref()) {
+      storage.push_back(value[1]);
+      storage.push_back(value[0]);
+      value = storage;
+    }
+    break;
+  default:
+    // Check to see they are sorted already.
+    bool isSorted = true;
+    for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
+      if (value[i].first.strref() > value[i + 1].first.strref()) {
+        isSorted = false;
+        break;
+      }
+    }
+    // If not, do a general sort.
+    if (!isSorted) {
+      storage.append(value.begin(), value.end());
+      llvm::array_pod_sort(storage.begin(), storage.end(),
+                           compareNamedAttributes);
+      value = storage;
+    }
+
+    // Ensure that the attribute elements are unique.
+    assert(std::adjacent_find(value.begin(), value.end(),
+                              [](NamedAttribute l, NamedAttribute r) {
+                                return l.first == r.first;
+                              }) == value.end() &&
+           "DictionaryAttr element names must be unique");
+  }
+
+  return Base::get(context, StandardAttributes::Dictionary, value);
+}
+
+ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
+  return getImpl()->getElements();
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute DictionaryAttr::get(StringRef name) const {
+  for (auto elt : getValue())
+    if (elt.first.is(name))
+      return elt.second;
+  return nullptr;
+}
+Attribute DictionaryAttr::get(Identifier name) const {
+  for (auto elt : getValue())
+    if (elt.first == name)
+      return elt.second;
+  return nullptr;
+}
+
+DictionaryAttr::iterator DictionaryAttr::begin() const {
+  return getValue().begin();
+}
+DictionaryAttr::iterator DictionaryAttr::end() const {
+  return getValue().end();
+}
+size_t DictionaryAttr::size() const { return getValue().size(); }
+
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
+FloatAttr FloatAttr::get(Type type, double value) {
+  return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
+  return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
+                          type, value);
+}
+
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+  return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
+  return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
+                          type, value);
+}
+
+APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
+
+double FloatAttr::getValueAsDouble() const {
+  return getValueAsDouble(getValue());
+}
+double FloatAttr::getValueAsDouble(APFloat value) {
+  if (&value.getSemantics() != &APFloat::IEEEdouble()) {
+    bool losesInfo = false;
+    value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+  }
+  return value.convertToDouble();
+}
+
+/// Verify construction invariants.
+static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
+                                               Type type) {
+  if (!type.isa<FloatType>()) {
+    if (loc)
+      emitError(*loc, "expected floating point type");
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
+  return verifyFloatTypeInvariants(loc, type);
+}
+
+LogicalResult
+FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
+                                        MLIRContext *ctx, Type type,
+                                        const APFloat &value) {
+  // Verify that the type is correct.
+  if (failed(verifyFloatTypeInvariants(loc, type)))
+    return failure();
+
+  // Verify that the type semantics match that of the value.
+  if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+    if (loc)
+      emitError(*loc,
+                "FloatAttr type doesn't match the type implied by its value");
+    return failure();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+  return Base::get(ctx, StandardAttributes::SymbolRef, value,
+                   NoneType::get(ctx));
+}
+
+StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
+  return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
+}
+
+IntegerAttr IntegerAttr::get(Type type, int64_t value) {
+  // This uses 64 bit APInts by default for index type.
+  if (type.isIndex())
+    return get(type, APInt(64, value));
+
+  auto intType = type.cast<IntegerType>();
+  return get(type, APInt(intType.getWidth(), value));
+}
+
+APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
+
+int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
+
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
+IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
+  return Base::get(value.getConstraint(0).getContext(),
+                   StandardAttributes::IntegerSet, value);
+}
+
+IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
+                           MLIRContext *context) {
+  return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
+                   type);
+}
+
+OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
+                                  Type type, Location location) {
+  return Base::getChecked(location, type.getContext(),
+                          StandardAttributes::Opaque, dialect, attrData, type);
+}
+
+/// Returns the dialect namespace of the opaque attribute.
+Identifier OpaqueAttr::getDialectNamespace() const {
+  return getImpl()->dialectNamespace;
+}
+
+/// Returns the raw attribute data of the opaque attribute.
+StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
+
+/// Verify the construction of an opaque attribute.
+LogicalResult OpaqueAttr::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
+    StringRef attrData, Type type) {
+  if (!Dialect::isValidNamespace(dialect.strref())) {
+    if (loc)
+      emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
+    return failure();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+  return get(bytes, NoneType::get(context));
+}
+
+/// Get an instance of a StringAttr with the given string and Type.
+StringAttr StringAttr::get(StringRef bytes, Type type) {
+  return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
+}
+
+StringRef StringAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// TypeAttr
+//===----------------------------------------------------------------------===//
+
+TypeAttr TypeAttr::get(Type value) {
+  return Base::get(value.getContext(), StandardAttributes::Type, value);
+}
+
+Type TypeAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType ElementsAttr::getType() const {
+  return Attribute::getType().cast<ShapedType>();
+}
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+  switch (getKind()) {
+  case StandardAttributes::DenseElements:
+    return cast<DenseElementsAttr>().getValue(index);
+  case StandardAttributes::OpaqueElements:
+    return cast<OpaqueElementsAttr>().getValue(index);
+  case StandardAttributes::SparseElements:
+    return cast<SparseElementsAttr>().getValue(index);
+  default:
+    llvm_unreachable("unknown ElementsAttr kind");
+  }
+}
+
+ElementsAttr ElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  switch (getKind()) {
+  case StandardAttributes::DenseElements:
+    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+  default:
+    llvm_unreachable("unsupported ElementsAttr subtype");
+  }
+}
+
+ElementsAttr ElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  switch (getKind()) {
+  case StandardAttributes::DenseElements:
+    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+  default:
+    llvm_unreachable("unsupported ElementsAttr subtype");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementAttr Utilities
+//===----------------------------------------------------------------------===//
+
+static size_t getDenseElementBitwidth(Type eltType) {
+  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+  // with double semantics.
+  return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+}
+
+/// Get the bitwidth of a dense element type within the buffer.
+/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+static size_t getDenseElementStorageWidth(size_t origWidth) {
+  return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+}
+
+/// Set a bit to a specific value.
+static void setBit(char *rawData, size_t bitPos, bool value) {
+  if (value)
+    rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
+  else
+    rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
+}
+
+/// Return the value of the specified bit.
+static bool getBit(const char *rawData, size_t bitPos) {
+  return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
+}
+
+/// Writes value to the bit position `bitPos` in array `rawData`.
+static void writeBits(char *rawData, size_t bitPos, APInt value) {
+  size_t bitWidth = value.getBitWidth();
+
+  // If the bitwidth is 1 we just toggle the specific bit.
+  if (bitWidth == 1)
+    return setBit(rawData, bitPos, value.isOneValue());
+
+  // Otherwise, the bit position is guaranteed to be byte aligned.
+  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
+              llvm::divideCeil(bitWidth, CHAR_BIT),
+              rawData + (bitPos / CHAR_BIT));
+}
+
+/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
+/// `rawData`.
+static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
+  // Handle a boolean bit position.
+  if (bitWidth == 1)
+    return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
+
+  // Otherwise, the bit position must be 8-bit aligned.
+  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+  APInt result(bitWidth, 0);
+  std::copy_n(
+      rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
+      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
+  return result;
+}
+
+/// Returns if 'values' corresponds to a splat, i.e. one element, or has the
+/// same element count as 'type'.
+template <typename Values>
+static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
+  return (values.size() == 1) ||
+         (type.getNumElements() == static_cast<int64_t>(values.size()));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementAttr Iterators
+//===----------------------------------------------------------------------===//
+
+/// Constructs a new iterator.
+DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
+    DenseElementsAttr attr, size_t index)
+    : indexed_accessor_iterator<AttributeElementIterator, const void *,
+                                Attribute, Attribute, Attribute>(
+          attr.getAsOpaquePointer(), index) {}
+
+/// Accesses the Attribute value at this iterator position.
+Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
+  auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
+  Type eltTy = owner.getType().getElementType();
+  if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
+    if (intEltTy.getWidth() == 1)
+      return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
+                           owner.getContext());
+    return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
+  }
+  if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
+    IntElementIterator intIt(owner, index);
+    FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
+    return FloatAttr::get(eltTy, *floatIt);
+  }
+  llvm_unreachable("unexpected element type");
+}
+
+/// Constructs a new iterator.
+DenseElementsAttr::IntElementIterator::IntElementIterator(
+    DenseElementsAttr attr, size_t index)
+    : indexed_accessor_iterator<IntElementIterator, const char *, APInt, APInt,
+                                APInt>(attr.getRawData().data(), index),
+      bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
+
+/// Accesses the raw APInt value at this iterator position.
+APInt DenseElementsAttr::IntElementIterator::operator*() const {
+  return readBits(object, index * getDenseElementStorageWidth(bitWidth),
+                  bitWidth);
+}
+
+DenseElementsAttr::FloatElementIterator::FloatElementIterator(
+    const llvm::fltSemantics &smt, IntElementIterator it)
+    : llvm::mapped_iterator<IntElementIterator,
+                            std::function<APFloat(const APInt &)>>(
+          it, [&](const APInt &val) { return APFloat(smt, val); }) {}
+
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<Attribute> values) {
+  assert(type.getElementType().isIntOrFloat() &&
+         "expected int or float element type");
+  assert(hasSameElementsOrSplat(type, values));
+
+  auto eltType = type.getElementType();
+  size_t bitWidth = getDenseElementBitwidth(eltType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+  // Compress the attribute values into a character buffer.
+  SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                            values.size());
+  APInt intVal;
+  for (unsigned i = 0, e = values.size(); i < e; ++i) {
+    assert(eltType == values[i].getType() &&
+           "expected attribute value to have element type");
+
+    switch (eltType.getKind()) {
+    case StandardTypes::BF16:
+    case StandardTypes::F16:
+    case StandardTypes::F32:
+    case StandardTypes::F64:
+      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
+      break;
+    case StandardTypes::Integer:
+      intVal = values[i].isa<BoolAttr>()
+                   ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
+                   : values[i].cast<IntegerAttr>().getValue();
+      break;
+    default:
+      llvm_unreachable("unexpected element type");
+    }
+    assert(intVal.getBitWidth() == bitWidth &&
+           "expected value to have same bitwidth as element type");
+    writeBits(data.data(), i * storageBitWidth, intVal);
+  }
+  return getRaw(type, data, /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<bool> values) {
+  assert(hasSameElementsOrSplat(type, values));
+  assert(type.getElementType().isInteger(1));
+
+  std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
+  for (int i = 0, e = values.size(); i != e; ++i)
+    setBit(buff.data(), i, values[i]);
+  return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
+}
+
+/// Constructs a dense integer elements attribute from an array of APInt
+/// values. Each APInt value is expected to have the same bitwidth as the
+/// element type of 'type'.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<APInt> values) {
+  assert(type.getElementType().isa<IntegerType>());
+  return getRaw(type, values);
+}
+
+// Constructs a dense float elements attribute from an array of APFloat
+// values. Each APFloat value is expected to have the same bitwidth as the
+// element type of 'type'.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<APFloat> values) {
+  assert(type.getElementType().isa<FloatType>());
+
+  // Convert the APFloat values to APInt and create a dense elements attribute.
+  std::vector<APInt> intValues(values.size());
+  for (unsigned i = 0, e = values.size(); i != e; ++i)
+    intValues[i] = values[i].bitcastToAPInt();
+  return getRaw(type, intValues);
+}
+
+// Constructs a dense elements attribute from an array of raw APInt values.
+// Each APInt value is expected to have the same bitwidth as the element type
+// of 'type'.
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+                                            ArrayRef<APInt> values) {
+  assert(hasSameElementsOrSplat(type, values));
+
+  size_t bitWidth = getDenseElementBitwidth(type.getElementType());
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+  std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                                values.size());
+  for (unsigned i = 0, e = values.size(); i != e; ++i) {
+    assert(values[i].getBitWidth() == bitWidth);
+    writeBits(elementData.data(), i * storageBitWidth, values[i]);
+  }
+  return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+                                            ArrayRef<char> data, bool isSplat) {
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
+  return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
+                   data, isSplat);
+}
+
+/// Check the information for a c++ data type, check if this type is valid for
+/// the current attribute. This method is used to verify specific type
+/// invariants that the templatized 'getValues' method cannot.
+static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
+                              bool isInt) {
+  // Make sure that the data element size is the same as the type element width.
+  if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
+    return false;
+
+  // Check that the element type is valid.
+  return isInt ? type.getElementType().isa<IntegerType>()
+               : type.getElementType().isa<FloatType>();
+}
+
+/// Overload of the 'getRaw' method that asserts that the given type is of
+/// integer type. This method is used to verify type invariants that the
+/// templatized 'get' method cannot.
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+                                                      ArrayRef<char> data,
+                                                      int64_t dataEltSize,
+                                                      bool isInt) {
+  assert(::isValidIntOrFloat(type, dataEltSize, isInt));
+
+  int64_t numElements = data.size() / dataEltSize;
+  assert(numElements == 1 || numElements == type.getNumElements());
+  return getRaw(type, data, /*isSplat=*/numElements == 1);
+}
+
+/// A method used to verify specific type invariants that the templatized 'get'
+/// method cannot.
+bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
+                                          bool isInt) const {
+  return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
+}
+
+/// Return the raw storage data held by this attribute.
+ArrayRef<char> DenseElementsAttr::getRawData() const {
+  return static_cast<ImplType *>(impl)->data;
+}
+
+/// Returns the number of raw elements held by this attribute.
+size_t DenseElementsAttr::rawSize() const {
+  return isSplat() ? 1 : getType().getNumElements();
+}
+
+/// Returns if this attribute corresponds to a splat, i.e. if all element
+/// values are the same.
+bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
+
+/// If this attribute corresponds to a splat, then get the splat value.
+/// Otherwise, return null.
+Attribute DenseElementsAttr::getSplatValue() const {
+  return isSplat() ? *attr_value_begin() : Attribute();
+}
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+  auto type = getType();
+
+  // Verify that the rank of the indices matches the held type.
+  auto rank = type.getRank();
+  if (rank != static_cast<int64_t>(index.size()))
+    return Attribute();
+
+  // Verify that all of the indices are within the shape dimensions.
+  auto shape = type.getShape();
+  for (unsigned i = 0; i != rank; ++i)
+    if (shape[i] <= static_cast<int64_t>(index[i]))
+      return Attribute();
+
+  // If this is a splat, return the splat value directly.
+  if (isSplat())
+    return getSplatValue();
+
+  // Reduce the provided multidimensional index into a 1D index.
+  uint64_t valueIndex = 0;
+  uint64_t dimMultiplier = 1;
+  for (int i = rank - 1; i >= 0; --i) {
+    valueIndex += index[i] * dimMultiplier;
+    dimMultiplier *= shape[i];
+  }
+
+  // Return the element stored at the 1D index.
+  auto elementType = getType().getElementType();
+  size_t bitWidth = getDenseElementBitwidth(elementType);
+  size_t storageWidth = getDenseElementStorageWidth(bitWidth);
+  APInt rawValueData =
+      readBits(getRawData().data(), valueIndex * storageWidth, bitWidth);
+
+  // Convert the raw value data to an attribute value.
+  if (elementType.isa<IntegerType>())
+    return IntegerAttr::get(elementType, rawValueData);
+  if (auto fType = elementType.dyn_cast<FloatType>())
+    return FloatAttr::get(elementType,
+                          APFloat(fType.getFloatSemantics(), rawValueData));
+  llvm_unreachable("unexpected element type");
+}
+
+/// Return the held element values as a range of Attributes.
+auto DenseElementsAttr::getAttributeValues() const
+    -> llvm::iterator_range<AttributeElementIterator> {
+  return {attr_value_begin(), attr_value_end()};
+}
+auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
+  return AttributeElementIterator(*this, 0);
+}
+auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
+  return AttributeElementIterator(*this, rawSize());
+}
+
+/// Return the held element values as a range of APInts. The element type of
+/// this attribute must be of integer type.
+auto DenseElementsAttr::getIntValues() const
+    -> llvm::iterator_range<IntElementIterator> {
+  assert(getType().getElementType().isa<IntegerType>() &&
+         "expected integer type");
+  return {raw_int_begin(), raw_int_end()};
+}
+auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
+  assert(getType().getElementType().isa<IntegerType>() &&
+         "expected integer type");
+  return raw_int_begin();
+}
+auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
+  assert(getType().getElementType().isa<IntegerType>() &&
+         "expected integer type");
+  return raw_int_end();
+}
+
+/// Return the held element values as a range of APFloat. The element type of
+/// this attribute must be of float type.
+auto DenseElementsAttr::getFloatValues() const
+    -> llvm::iterator_range<FloatElementIterator> {
+  auto elementType = getType().getElementType().cast<FloatType>();
+  assert(elementType.isa<FloatType>() && "expected float type");
+  const auto &elementSemantics = elementType.getFloatSemantics();
+  return {FloatElementIterator(elementSemantics, raw_int_begin()),
+          FloatElementIterator(elementSemantics, raw_int_end())};
+}
+auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
+  return getFloatValues().begin();
+}
+auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
+  return getFloatValues().end();
+}
+
+/// Return a new DenseElementsAttr that has the same data as the current
+/// attribute, but has been reshaped to 'newType'. The new type must have the
+/// same total number of elements as well as element type.
+DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+  ShapedType curType = getType();
+  if (curType == newType)
+    return *this;
+
+  (void)curType;
+  assert(newType.getElementType() == curType.getElementType() &&
+         "expected the same element type");
+  assert(newType.getNumElements() == curType.getNumElements() &&
+         "expected the same number of elements");
+  return getRaw(newType, getRawData(), isSplat());
+}
+
+DenseElementsAttr DenseElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+}
+
+DenseElementsAttr DenseElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+}
+
+//===----------------------------------------------------------------------===//
+// DenseFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+template <typename Fn, typename Attr>
+static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
+                                Type newElementType,
+                                llvm::SmallVectorImpl<char> &data) {
+  size_t bitWidth = getDenseElementBitwidth(newElementType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType>())
+    newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+  else if (inType.isa<UnrankedTensorType>())
+    newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+  else if (inType.isa<VectorType>())
+    newArrayType = VectorType::get(inType.getShape(), newElementType);
+  else
+    assert(newArrayType && "Unhandled tensor type");
+
+  data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * attr.rawSize());
+
+  uint64_t elementIdx = 0;
+  for (auto value : attr) {
+    auto newInt = mapping(value);
+    assert(newInt.getBitWidth() == bitWidth);
+    writeBits(data.data(), elementIdx * storageBitWidth, newInt);
+    ++elementIdx;
+  }
+
+  return newArrayType;
+}
+
+DenseElementsAttr DenseFPElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  llvm::SmallVector<char, 8> elementData;
+  auto newArrayType =
+      mappingHelper(mapping, *this, getType(), newElementType, elementData);
+
+  return getRaw(newArrayType, elementData, isSplat());
+}
+
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseFPElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseElementsAttr>() &&
+         attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr DenseIntElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  llvm::SmallVector<char, 8> elementData;
+  auto newArrayType =
+      mappingHelper(mapping, *this, getType(), newElementType, elementData);
+
+  return getRaw(newArrayType, elementData, isSplat());
+}
+
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseIntElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseElementsAttr>() &&
+         attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+}
+
+//===----------------------------------------------------------------------===//
+// OpaqueElementsAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
+                                           StringRef bytes) {
+  assert(TensorType::isValidElementType(type.getElementType()) &&
+         "Input element type should be a valid tensor element type");
+  return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
+                   dialect, bytes);
+}
+
+StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+  if (Dialect *dialect = getDialect())
+    return dialect->extractElementHook(*this, index);
+  return Attribute();
+}
+
+Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
+
+bool OpaqueElementsAttr::decode(ElementsAttr &result) {
+  if (auto *d = getDialect())
+    return d->decodeHook(*this, result);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+SparseElementsAttr SparseElementsAttr::get(ShapedType type,
+                                           DenseElementsAttr indices,
+                                           DenseElementsAttr values) {
+  assert(indices.getType().getElementType().isInteger(64) &&
+         "expected sparse indices to be 64-bit integer values");
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
+  return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
+                   indices.cast<DenseIntElementsAttr>(), values);
+}
+
+DenseIntElementsAttr SparseElementsAttr::getIndices() const {
+  return getImpl()->indices;
+}
+
+DenseElementsAttr SparseElementsAttr::getValues() const {
+  return getImpl()->values;
+}
+
+/// Return the value of the element at the given index.
+Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+  auto type = getType();
+
+  // Verify that the rank of the indices matches the held type.
+  size_t rank = type.getRank();
+  if (rank != index.size())
+    return Attribute();
+
+  /// Return an attribute corresponding to '0' for the element type.
+  auto getZeroAttr = [=]() -> Attribute {
+    auto eltType = type.getElementType();
+    if (eltType.isa<FloatType>())
+      return FloatAttr::get(eltType, 0);
+    assert(eltType.isa<IntegerType>() && "unexpected element type");
+    return IntegerAttr::get(eltType, 0);
+  };
+
+  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
+  // as a 1-D index array.
+  auto sparseIndices = getIndices();
+  ArrayRef<uint64_t> sparseIndexValues = sparseIndices.getValues<uint64_t>();
+
+  // Check to see if the indices are a splat.
+  if (sparseIndices.isSplat()) {
+    // If the index is also not a splat of the index value, we know that the
+    // value is zero.
+    auto splatIndex = sparseIndexValues.front();
+    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
+      return getZeroAttr();
+
+    // If the indices are a splat, we also expect the values to be a splat.
+    assert(getValues().isSplat() && "expected splat values");
+    return getValues().getSplatValue();
+  }
+
+  // Build a mapping between known indices and the offset of the stored element.
+  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
+  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
+  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
+    mappedIndices.try_emplace({&sparseIndexValues[i * rank], rank}, i);
+
+  // Look for the provided index key within the mapped indices. If the provided
+  // index is not found, then return a zero attribute.
+  auto it = mappedIndices.find(index);
+  if (it == mappedIndices.end())
+    return getZeroAttr();
+
+  // Otherwise, return the held sparse value element.
+  return getValues().getValue(it->second);
+}
+
+//===----------------------------------------------------------------------===//
+// NamedAttributeList
+//===----------------------------------------------------------------------===//
+
+NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
+  setAttrs(attributes);
+}
+
+ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
+  return attrs ? attrs.getValue() : llvm::None;
+}
+
+/// Replace the held attributes with ones provided in 'newAttrs'.
+void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
+  // Don't create an attribute list if there are no attributes.
+  if (attributes.empty())
+    attrs = nullptr;
+  else
+    attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute NamedAttributeList::get(StringRef name) const {
+  return attrs ? attrs.get(name) : nullptr;
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute NamedAttributeList::get(Identifier name) const {
+  return attrs ? attrs.get(name) : nullptr;
+}
+
+/// If the an attribute exists with the specified name, change it to the new
+/// value.  Otherwise, add a new attribute with the specified name/value.
+void NamedAttributeList::set(Identifier name, Attribute value) {
+  assert(value && "attributes may never be null");
+
+  // If we already have this attribute, replace it.
+  auto origAttrs = getAttrs();
+  SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
+  for (auto &elt : newAttrs)
+    if (elt.first == name) {
+      elt.second = value;
+      attrs = DictionaryAttr::get(newAttrs, value.getContext());
+      return;
+    }
+
+  // Otherwise, add it.
+  newAttrs.push_back({name, value});
+  attrs = DictionaryAttr::get(newAttrs, value.getContext());
+}
+
+/// Remove the attribute with the specified name if it exists.  The return
+/// value indicates whether the attribute was present or not.
+auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
+  auto origAttrs = getAttrs();
+  for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
+    if (origAttrs[i].first == name) {
+      // Handle the simple case of removing the only attribute in the list.
+      if (e == 1) {
+        attrs = nullptr;
+        return RemoveResult::Removed;
+      }
+
+      SmallVector<NamedAttribute, 8> newAttrs;
+      newAttrs.reserve(origAttrs.size() - 1);
+      newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
+      newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
+      attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
+      return RemoveResult::Removed;
+    }
+  }
+  return RemoveResult::NotFound;
+}
diff --git a/third_party/mlir/lib/IR/Block.cpp b/third_party/mlir/lib/IR/Block.cpp
new file mode 100644
index 0000000..28614ca
--- /dev/null
+++ b/third_party/mlir/lib/IR/Block.cpp
@@ -0,0 +1,281 @@
+//===- Block.cpp - MLIR Block Class ---------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BlockArgument
+//===----------------------------------------------------------------------===//
+
+/// Returns the number of this argument.
+unsigned BlockArgument::getArgNumber() {
+  // Arguments are not stored in place, so we have to find it within the list.
+  auto argList = getOwner()->getArguments();
+  return std::distance(argList.begin(), llvm::find(argList, this));
+}
+
+//===----------------------------------------------------------------------===//
+// Block
+//===----------------------------------------------------------------------===//
+
+Block::~Block() {
+  assert(!verifyInstOrder() && "Expected valid operation ordering.");
+  clear();
+
+  for (auto *arg : arguments)
+    if (!arg->use_empty())
+      arg->user_begin()->dump();
+
+  llvm::DeleteContainerPointers(arguments);
+}
+
+Region *Block::getParent() { return parentValidInstOrderPair.getPointer(); }
+
+/// Returns the closest surrounding operation that contains this block or
+/// nullptr if this block is unlinked.
+Operation *Block::getParentOp() {
+  return getParent() ? getParent()->getParentOp() : nullptr;
+}
+
+/// Return if this block is the entry block in the parent region.
+bool Block::isEntryBlock() { return this == &getParent()->front(); }
+
+/// Insert this block (which must not already be in a region) right before the
+/// specified block.
+void Block::insertBefore(Block *block) {
+  assert(!getParent() && "already inserted into a block!");
+  assert(block->getParent() && "cannot insert before a block without a parent");
+  block->getParent()->getBlocks().insert(Region::iterator(block), this);
+}
+
+/// Unlink this Block from its parent Region and delete it.
+void Block::erase() {
+  assert(getParent() && "Block has no parent");
+  getParent()->getBlocks().erase(this);
+}
+
+/// Returns 'op' if 'op' lies in this block, or otherwise finds the
+/// ancestor operation of 'op' that lies in this block. Returns nullptr if
+/// the latter fails.
+Operation *Block::findAncestorInstInBlock(Operation &op) {
+  // Traverse up the operation hierarchy starting from the owner of operand to
+  // find the ancestor operation that resides in the block of 'forInst'.
+  auto *currInst = &op;
+  while (currInst->getBlock() != this) {
+    currInst = currInst->getParentOp();
+    if (!currInst)
+      return nullptr;
+  }
+  return currInst;
+}
+
+/// This drops all operand uses from operations within this block, which is
+/// an essential step in breaking cyclic dependences between references when
+/// they are to be deleted.
+void Block::dropAllReferences() {
+  for (Operation &i : *this)
+    i.dropAllReferences();
+}
+
+void Block::dropAllDefinedValueUses() {
+  for (auto *arg : getArguments())
+    arg->dropAllUses();
+  for (auto &op : *this)
+    op.dropAllDefinedValueUses();
+  dropAllUses();
+}
+
+/// Returns true if the ordering of the child operations is valid, false
+/// otherwise.
+bool Block::isInstOrderValid() { return parentValidInstOrderPair.getInt(); }
+
+/// Invalidates the current ordering of operations.
+void Block::invalidateInstOrder() {
+  // Validate the current ordering.
+  assert(!verifyInstOrder());
+  parentValidInstOrderPair.setInt(false);
+}
+
+/// Verifies the current ordering of child operations. Returns false if the
+/// order is valid, true otherwise.
+bool Block::verifyInstOrder() {
+  // The order is already known to be invalid.
+  if (!isInstOrderValid())
+    return false;
+  // The order is valid if there are less than 2 operations.
+  if (operations.empty() || std::next(operations.begin()) == operations.end())
+    return false;
+
+  Operation *prev = nullptr;
+  for (auto &i : *this) {
+    // The previous operation must have a smaller order index than the next as
+    // it appears earlier in the list.
+    if (prev && prev->orderIndex >= i.orderIndex)
+      return true;
+    prev = &i;
+  }
+  return false;
+}
+
+/// Recomputes the ordering of child operations within the block.
+void Block::recomputeInstOrder() {
+  parentValidInstOrderPair.setInt(true);
+
+  // TODO(riverriddle) Have non-congruent indices to reduce the number of times
+  // an insert invalidates the list.
+  unsigned orderIndex = 0;
+  for (auto &op : *this)
+    op.orderIndex = orderIndex++;
+}
+
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+BlockArgument *Block::addArgument(Type type) {
+  auto *arg = new BlockArgument(type, this);
+  arguments.push_back(arg);
+  return arg;
+}
+
+/// Add one argument to the argument list for each type specified in the list.
+auto Block::addArguments(ArrayRef<Type> types)
+    -> llvm::iterator_range<args_iterator> {
+  arguments.reserve(arguments.size() + types.size());
+  auto initialSize = arguments.size();
+  for (auto type : types) {
+    addArgument(type);
+  }
+  return {arguments.data() + initialSize, arguments.data() + arguments.size()};
+}
+
+void Block::eraseArgument(unsigned index, bool updatePredTerms) {
+  assert(index < arguments.size());
+
+  // Delete the argument.
+  delete arguments[index];
+  arguments.erase(arguments.begin() + index);
+
+  // If we aren't updating predecessors, there is nothing left to do.
+  if (!updatePredTerms)
+    return;
+
+  // Erase this argument from each of the predecessor's terminator.
+  for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
+       ++predIt) {
+    auto *predTerminator = (*predIt)->getTerminator();
+    predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management
+//===----------------------------------------------------------------------===//
+
+/// Get the terminator operation of this block. This function asserts that
+/// the block has a valid terminator operation.
+Operation *Block::getTerminator() {
+  assert(!empty() && !back().isKnownNonTerminator());
+  return &back();
+}
+
+/// Return true if this block has no predecessors.
+bool Block::hasNoPredecessors() { return pred_begin() == pred_end(); }
+
+// Indexed successor access.
+unsigned Block::getNumSuccessors() {
+  return empty() ? 0 : back().getNumSuccessors();
+}
+
+Block *Block::getSuccessor(unsigned i) {
+  assert(i < getNumSuccessors());
+  return getTerminator()->getSuccessor(i);
+}
+
+/// If this block has exactly one predecessor, return it.  Otherwise, return
+/// null.
+///
+/// Note that multiple edges from a single block (e.g. if you have a cond
+/// branch with the same block as the true/false destinations) is not
+/// considered to be a single predecessor.
+Block *Block::getSinglePredecessor() {
+  auto it = pred_begin();
+  if (it == pred_end())
+    return nullptr;
+  auto *firstPred = *it;
+  ++it;
+  return it == pred_end() ? firstPred : nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Operation Walkers
+//===----------------------------------------------------------------------===//
+
+void Block::walk(llvm::function_ref<void(Operation *)> callback) {
+  walk(begin(), end(), callback);
+}
+
+/// Walk the operations in the specified [begin, end) range of this block,
+/// calling the callback for each operation.
+void Block::walk(Block::iterator begin, Block::iterator end,
+                 llvm::function_ref<void(Operation *)> callback) {
+  for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
+    op.walk(callback);
+}
+
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
+/// Split the block into two blocks before the specified operation or
+/// iterator.
+///
+/// Note that all operations BEFORE the specified iterator stay as part of
+/// the original basic block, and the rest of the operations in the original
+/// block are moved to the new block, including the old terminator.  The
+/// original block is left without a terminator.
+///
+/// The newly formed Block is returned, and the specified iterator is
+/// invalidated.
+Block *Block::splitBlock(iterator splitBefore) {
+  // Start by creating a new basic block, and insert it immediate after this
+  // one in the containing region.
+  auto newBB = new Block();
+  getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
+
+  // Move all of the operations from the split point to the end of the region
+  // into the new block.
+  newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore,
+                                end());
+  return newBB;
+}
+
+//===----------------------------------------------------------------------===//
+// Predecessors
+//===----------------------------------------------------------------------===//
+
+Block *PredecessorIterator::unwrap(BlockOperand &value) {
+  return value.getOwner()->getBlock();
+}
+
+/// Get the successor number in the predecessor terminator.
+unsigned PredecessorIterator::getSuccessorIndex() const {
+  return I->getOperandNumber();
+}
diff --git a/third_party/mlir/lib/IR/Builders.cpp b/third_party/mlir/lib/IR/Builders.cpp
new file mode 100644
index 0000000..2ade7b9
--- /dev/null
+++ b/third_party/mlir/lib/IR/Builders.cpp
@@ -0,0 +1,404 @@
+//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/Functional.h"
+using namespace mlir;
+
+Builder::Builder(ModuleOp module) : context(module.getContext()) {}
+
+Identifier Builder::getIdentifier(StringRef str) {
+  return Identifier::get(str, context);
+}
+
+//===----------------------------------------------------------------------===//
+// Locations.
+//===----------------------------------------------------------------------===//
+
+Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
+
+Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
+                                    unsigned column) {
+  return FileLineColLoc::get(filename, line, column, context);
+}
+
+Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
+  return FusedLoc::get(locs, metadata, context);
+}
+
+//===----------------------------------------------------------------------===//
+// Types.
+//===----------------------------------------------------------------------===//
+
+FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
+
+FloatType Builder::getF16Type() { return FloatType::getF16(context); }
+
+FloatType Builder::getF32Type() { return FloatType::getF32(context); }
+
+FloatType Builder::getF64Type() { return FloatType::getF64(context); }
+
+IndexType Builder::getIndexType() { return IndexType::get(context); }
+
+IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
+
+IntegerType Builder::getIntegerType(unsigned width) {
+  return IntegerType::get(width, context);
+}
+
+FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
+                                      ArrayRef<Type> results) {
+  return FunctionType::get(inputs, results, context);
+}
+
+MemRefType Builder::getMemRefType(ArrayRef<int64_t> shape, Type elementType,
+                                  ArrayRef<AffineMap> affineMapComposition,
+                                  unsigned memorySpace) {
+  return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
+}
+
+VectorType Builder::getVectorType(ArrayRef<int64_t> shape, Type elementType) {
+  return VectorType::get(shape, elementType);
+}
+
+RankedTensorType Builder::getTensorType(ArrayRef<int64_t> shape,
+                                        Type elementType) {
+  return RankedTensorType::get(shape, elementType);
+}
+
+UnrankedTensorType Builder::getTensorType(Type elementType) {
+  return UnrankedTensorType::get(elementType);
+}
+
+TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
+  return TupleType::get(elementTypes, context);
+}
+
+NoneType Builder::getNoneType() { return NoneType::get(context); }
+
+//===----------------------------------------------------------------------===//
+// Attributes.
+//===----------------------------------------------------------------------===//
+
+NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
+  return NamedAttribute(getIdentifier(name), val);
+}
+
+UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
+
+BoolAttr Builder::getBoolAttr(bool value) {
+  return BoolAttr::get(value, context);
+}
+
+DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
+  return DictionaryAttr::get(value, context);
+}
+
+IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
+  return IntegerAttr::get(getIntegerType(64), APInt(64, value));
+}
+
+IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
+  return IntegerAttr::get(getIntegerType(32), APInt(32, value));
+}
+
+IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
+  if (type.isIndex())
+    return IntegerAttr::get(type, APInt(64, value));
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value));
+}
+
+IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
+  return IntegerAttr::get(type, value);
+}
+
+FloatAttr Builder::getF64FloatAttr(double value) {
+  return FloatAttr::get(getF64Type(), APFloat(value));
+}
+
+FloatAttr Builder::getF32FloatAttr(float value) {
+  return FloatAttr::get(getF32Type(), APFloat(value));
+}
+
+FloatAttr Builder::getF16FloatAttr(float value) {
+  return FloatAttr::get(getF16Type(), value);
+}
+
+FloatAttr Builder::getFloatAttr(Type type, double value) {
+  return FloatAttr::get(type, value);
+}
+
+FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
+  return FloatAttr::get(type, value);
+}
+
+StringAttr Builder::getStringAttr(StringRef bytes) {
+  return StringAttr::get(bytes, context);
+}
+
+StringAttr Builder::getStringAttr(StringRef bytes, Type type) {
+  return StringAttr::get(bytes, type);
+}
+
+ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
+  return ArrayAttr::get(value, context);
+}
+
+AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
+  return AffineMapAttr::get(map);
+}
+
+IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
+  return IntegerSetAttr::get(set);
+}
+
+TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
+
+SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
+  auto symName =
+      value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+  assert(symName && "value does not have a valid symbol name");
+  return getSymbolRefAttr(symName.getValue());
+}
+SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
+  return SymbolRefAttr::get(value, getContext());
+}
+
+ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
+                                           ArrayRef<Attribute> values) {
+  return DenseElementsAttr::get(type, values);
+}
+
+ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type,
+                                              ArrayRef<int64_t> values) {
+  return DenseIntElementsAttr::get(type, values);
+}
+
+ElementsAttr Builder::getSparseElementsAttr(ShapedType type,
+                                            DenseIntElementsAttr indices,
+                                            DenseElementsAttr values) {
+  return SparseElementsAttr::get(type, indices, values);
+}
+
+ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
+                                            StringRef bytes) {
+  return OpaqueElementsAttr::get(dialect, type, bytes);
+}
+
+ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
+  auto attrs = functional::map(
+      [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
+  auto attrs = functional::map(
+      [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
+  auto attrs = functional::map(
+      [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
+  auto attrs = functional::map(
+      [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
+  auto attrs = functional::map(
+      [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
+  auto attrs = functional::map(
+      [this](AffineMap v) -> Attribute { return getAffineMapAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+Attribute Builder::getZeroAttr(Type type) {
+  switch (type.getKind()) {
+  case StandardTypes::F16:
+    return getF16FloatAttr(0);
+  case StandardTypes::F32:
+    return getF32FloatAttr(0);
+  case StandardTypes::F64:
+    return getF64FloatAttr(0);
+  case StandardTypes::Integer: {
+    auto width = type.cast<IntegerType>().getWidth();
+    if (width == 1)
+      return getBoolAttr(false);
+    return getIntegerAttr(type, APInt(width, 0));
+  }
+  case StandardTypes::Vector:
+  case StandardTypes::RankedTensor: {
+    auto vtType = type.cast<ShapedType>();
+    auto element = getZeroAttr(vtType.getElementType());
+    if (!element)
+      return {};
+    return getDenseElementsAttr(vtType, element);
+  }
+  default:
+    break;
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// Affine Expressions, Affine Maps, and Integet Sets.
+//===----------------------------------------------------------------------===//
+
+AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
+                                ArrayRef<AffineExpr> results) {
+  return AffineMap::get(dimCount, symbolCount, results);
+}
+
+AffineExpr Builder::getAffineDimExpr(unsigned position) {
+  return mlir::getAffineDimExpr(position, context);
+}
+
+AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
+  return mlir::getAffineSymbolExpr(position, context);
+}
+
+AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
+  return mlir::getAffineConstantExpr(constant, context);
+}
+
+IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
+                                  ArrayRef<AffineExpr> constraints,
+                                  ArrayRef<bool> isEq) {
+  return IntegerSet::get(dimCount, symbolCount, constraints, isEq);
+}
+
+AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
+
+AffineMap Builder::getConstantAffineMap(int64_t val) {
+  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+                        {getAffineConstantExpr(val)});
+}
+
+AffineMap Builder::getDimIdentityMap() {
+  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                        {getAffineDimExpr(0)});
+}
+
+AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
+  SmallVector<AffineExpr, 4> dimExprs;
+  dimExprs.reserve(rank);
+  for (unsigned i = 0; i < rank; ++i)
+    dimExprs.push_back(getAffineDimExpr(i));
+  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs);
+}
+
+AffineMap Builder::getSymbolIdentityMap() {
+  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+                        {getAffineSymbolExpr(0)});
+}
+
+AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
+  // expr = d0 + shift.
+  auto expr = getAffineDimExpr(0) + shift;
+  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr});
+}
+
+AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
+  SmallVector<AffineExpr, 4> shiftedResults;
+  shiftedResults.reserve(map.getNumResults());
+  for (auto resultExpr : map.getResults()) {
+    shiftedResults.push_back(resultExpr + shift);
+  }
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults);
+}
+
+//===----------------------------------------------------------------------===//
+// OpBuilder.
+//===----------------------------------------------------------------------===//
+
+OpBuilder::~OpBuilder() {}
+
+/// Add new block and set the insertion point to the end of it. The block is
+/// inserted at the provided insertion point of 'parent'.
+Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
+  assert(parent && "expected valid parent region");
+  if (insertPt == Region::iterator())
+    insertPt = parent->end();
+
+  Block *b = new Block();
+  parent->getBlocks().insert(insertPt, b);
+  setInsertionPointToEnd(b);
+  return b;
+}
+
+/// Add new block and set the insertion point to the end of it.  The block is
+/// placed before 'insertBefore'.
+Block *OpBuilder::createBlock(Block *insertBefore) {
+  assert(insertBefore && "expected valid insertion block");
+  return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
+}
+
+/// Create an operation given the fields represented as an OperationState.
+Operation *OpBuilder::createOperation(const OperationState &state) {
+  assert(block && "createOperation() called without setting builder's block");
+  auto *op = Operation::create(state);
+  insert(op);
+  return op;
+}
+
+/// Attempts to fold the given operation and places new results within
+/// 'results'.
+void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+  results.reserve(op->getNumResults());
+  SmallVector<OpFoldResult, 4> foldResults;
+
+  // Returns if the given fold result corresponds to a valid existing value.
+  auto isValidValue = [](OpFoldResult result) {
+    return result.dyn_cast<Value *>();
+  };
+
+  // Check if the fold failed, or did not result in only existing values.
+  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
+  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
+      !llvm::all_of(foldResults, isValidValue)) {
+    // Simply return the existing operation results.
+    results.assign(op->result_begin(), op->result_end());
+    return;
+  }
+
+  // Populate the results with the folded results and remove the original op.
+  llvm::transform(foldResults, std::back_inserter(results),
+                  [](OpFoldResult result) { return result.get<Value *>(); });
+  op->erase();
+}
+
+/// Insert the given operation at the current insertion point.
+void OpBuilder::insert(Operation *op) {
+  if (block)
+    block->getOperations().insert(insertPoint, op);
+}
diff --git a/third_party/mlir/lib/IR/CMakeLists.txt b/third_party/mlir/lib/IR/CMakeLists.txt
new file mode 100644
index 0000000..6bb1265
--- /dev/null
+++ b/third_party/mlir/lib/IR/CMakeLists.txt
@@ -0,0 +1,9 @@
+file(GLOB globbed *.c *.cpp)
+add_llvm_library(MLIRIR
+  ${globbed}
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+add_dependencies(MLIRIR MLIRSupport LLVMSupport)
+target_link_libraries(MLIRIR MLIRSupport LLVMSupport)
diff --git a/third_party/mlir/lib/IR/Diagnostics.cpp b/third_party/mlir/lib/IR/Diagnostics.cpp
new file mode 100644
index 0000000..076a9b2
--- /dev/null
+++ b/third_party/mlir/lib/IR/Diagnostics.cpp
@@ -0,0 +1,862 @@
+//===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// DiagnosticArgument
+//===----------------------------------------------------------------------===//
+
+// Construct from an Attribute.
+DiagnosticArgument::DiagnosticArgument(Attribute attr)
+    : kind(DiagnosticArgumentKind::Attribute),
+      opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
+
+// Construct from a Type.
+DiagnosticArgument::DiagnosticArgument(Type val)
+    : kind(DiagnosticArgumentKind::Type),
+      opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
+
+/// Returns this argument as an Attribute.
+Attribute DiagnosticArgument::getAsAttribute() const {
+  assert(getKind() == DiagnosticArgumentKind::Attribute);
+  return Attribute::getFromOpaquePointer(
+      reinterpret_cast<const void *>(opaqueVal));
+}
+
+/// Returns this argument as a Type.
+Type DiagnosticArgument::getAsType() const {
+  assert(getKind() == DiagnosticArgumentKind::Type);
+  return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
+}
+
+/// Outputs this argument to a stream.
+void DiagnosticArgument::print(raw_ostream &os) const {
+  switch (kind) {
+  case DiagnosticArgumentKind::Attribute:
+    os << getAsAttribute();
+    break;
+  case DiagnosticArgumentKind::Double:
+    os << getAsDouble();
+    break;
+  case DiagnosticArgumentKind::Integer:
+    os << getAsInteger();
+    break;
+  case DiagnosticArgumentKind::Operation:
+    os << getAsOperation();
+    break;
+  case DiagnosticArgumentKind::String:
+    os << getAsString();
+    break;
+  case DiagnosticArgumentKind::Type:
+    os << '\'' << getAsType() << '\'';
+    break;
+  case DiagnosticArgumentKind::Unsigned:
+    os << getAsUnsigned();
+    break;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Diagnostic
+//===----------------------------------------------------------------------===//
+
+/// Convert a Twine to a StringRef. Memory used for generating the StringRef is
+/// stored in 'strings'.
+static StringRef twineToStrRef(const Twine &val,
+                               std::vector<std::unique_ptr<char[]>> &strings) {
+  // Allocate memory to hold this string.
+  llvm::SmallString<64> data;
+  auto strRef = val.toStringRef(data);
+  strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
+  memcpy(&strings.back()[0], strRef.data(), strRef.size());
+
+  // Return a reference to the new string.
+  return StringRef(&strings.back()[0], strRef.size());
+}
+
+/// Stream in a Twine argument.
+Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
+Diagnostic &Diagnostic::operator<<(const Twine &val) {
+  arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
+  return *this;
+}
+Diagnostic &Diagnostic::operator<<(Twine &&val) {
+  arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
+  return *this;
+}
+
+/// Stream in an Identifier.
+Diagnostic &Diagnostic::operator<<(Identifier val) {
+  // An identifier is stored in the context, so we don't need to worry about the
+  // lifetime of its data.
+  arguments.push_back(DiagnosticArgument(val.strref()));
+  return *this;
+}
+
+/// Stream in an OperationName.
+Diagnostic &Diagnostic::operator<<(OperationName val) {
+  // An OperationName is stored in the context, so we don't need to worry about
+  // the lifetime of its data.
+  arguments.push_back(DiagnosticArgument(val.getStringRef()));
+  return *this;
+}
+
+/// Outputs this diagnostic to a stream.
+void Diagnostic::print(raw_ostream &os) const {
+  for (auto &arg : getArguments())
+    arg.print(os);
+}
+
+/// Convert the diagnostic to a string.
+std::string Diagnostic::str() const {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  print(os);
+  return os.str();
+}
+
+/// Attaches a note to this diagnostic. A new location may be optionally
+/// provided, if not, then the location defaults to the one specified for this
+/// diagnostic. Notes may not be attached to other notes.
+Diagnostic &Diagnostic::attachNote(llvm::Optional<Location> noteLoc) {
+  // We don't allow attaching notes to notes.
+  assert(severity != DiagnosticSeverity::Note &&
+         "cannot attach a note to a note");
+
+  // If a location wasn't provided then reuse our location.
+  if (!noteLoc)
+    noteLoc = loc;
+
+  /// Append and return a new note.
+  notes.push_back(
+      llvm::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
+  return *notes.back();
+}
+
+/// Allow a diagnostic to be converted to 'failure'.
+Diagnostic::operator LogicalResult() const { return failure(); }
+
+//===----------------------------------------------------------------------===//
+// InFlightDiagnostic
+//===----------------------------------------------------------------------===//
+
+/// Allow an inflight diagnostic to be converted to 'failure', otherwise
+/// 'success' if this is an empty diagnostic.
+InFlightDiagnostic::operator LogicalResult() const {
+  return failure(isActive());
+}
+
+/// Reports the diagnostic to the engine.
+void InFlightDiagnostic::report() {
+  // If this diagnostic is still inflight and it hasn't been abandoned, then
+  // report it.
+  if (isInFlight()) {
+    owner->emit(std::move(*impl));
+    owner = nullptr;
+  }
+  impl.reset();
+}
+
+/// Abandons this diagnostic.
+void InFlightDiagnostic::abandon() { owner = nullptr; }
+
+//===----------------------------------------------------------------------===//
+// DiagnosticEngineImpl
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+struct DiagnosticEngineImpl {
+  /// Emit a diagnostic using the registered issue handle if present, or with
+  /// the default behavior if not.
+  void emit(Diagnostic diag);
+
+  /// A mutex to ensure that diagnostics emission is thread-safe.
+  llvm::sys::SmartMutex<true> mutex;
+
+  /// This is the handler to use to report diagnostics, or null if not
+  /// registered.
+  DiagnosticEngine::HandlerTy handler;
+};
+} // namespace detail
+} // namespace mlir
+
+/// Emit a diagnostic using the registered issue handle if present, or with
+/// the default behavior if not.
+void DiagnosticEngineImpl::emit(Diagnostic diag) {
+  llvm::sys::SmartScopedLock<true> lock(mutex);
+
+  // If we had a handler registered, emit the diagnostic using it.
+  if (handler)
+    return handler(std::move(diag));
+
+  // Otherwise, if this is an error we emit it to stderr.
+  if (diag.getSeverity() != DiagnosticSeverity::Error)
+    return;
+
+  auto &os = llvm::errs();
+  if (!diag.getLocation().isa<UnknownLoc>())
+    os << diag.getLocation() << ": ";
+  os << "error: ";
+
+  // The default behavior for errors is to emit them to stderr.
+  os << diag << '\n';
+  os.flush();
+}
+
+//===----------------------------------------------------------------------===//
+// DiagnosticEngine
+//===----------------------------------------------------------------------===//
+
+DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
+DiagnosticEngine::~DiagnosticEngine() {}
+
+/// Set the diagnostic handler for this engine.  The handler is passed
+/// location information if present (nullptr if not) along with a message and
+/// a severity that indicates whether this is an error, warning, etc. Note
+/// that this replaces any existing handler.
+void DiagnosticEngine::setHandler(const HandlerTy &handler) {
+  impl->handler = handler;
+}
+
+/// Return the current diagnostic handler, or null if none is present.
+auto DiagnosticEngine::getHandler() -> HandlerTy {
+  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+  return impl->handler;
+}
+
+/// Emit a diagnostic using the registered issue handler if present, or with
+/// the default behavior if not.
+void DiagnosticEngine::emit(Diagnostic diag) {
+  assert(diag.getSeverity() != DiagnosticSeverity::Note &&
+         "notes should not be emitted directly");
+  impl->emit(std::move(diag));
+}
+
+/// Helper function used to emit a diagnostic with an optionally empty twine
+/// message. If the message is empty, then it is not inserted into the
+/// diagnostic.
+static InFlightDiagnostic emitDiag(Location location,
+                                   DiagnosticSeverity severity,
+                                   const llvm::Twine &message) {
+  auto &diagEngine = location->getContext()->getDiagEngine();
+  auto diag = diagEngine.emit(location, severity);
+  if (!message.isTriviallyEmpty())
+    diag << message;
+  return diag;
+}
+
+/// Emit an error message using this location.
+InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
+InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
+  return emitDiag(loc, DiagnosticSeverity::Error, message);
+}
+
+/// Emit a warning message using this location.
+InFlightDiagnostic mlir::emitWarning(Location loc) {
+  return emitWarning(loc, {});
+}
+InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
+  return emitDiag(loc, DiagnosticSeverity::Warning, message);
+}
+
+/// Emit a remark message using this location.
+InFlightDiagnostic mlir::emitRemark(Location loc) {
+  return emitRemark(loc, {});
+}
+InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
+  return emitDiag(loc, DiagnosticSeverity::Remark, message);
+}
+
+//===----------------------------------------------------------------------===//
+// ScopedDiagnosticHandler
+//===----------------------------------------------------------------------===//
+
+ScopedDiagnosticHandler::ScopedDiagnosticHandler(MLIRContext *ctx)
+    : existingHandler(ctx->getDiagEngine().getHandler()), ctx(ctx) {}
+ScopedDiagnosticHandler::ScopedDiagnosticHandler(
+    MLIRContext *ctx, const DiagnosticEngine::HandlerTy &handler)
+    : ScopedDiagnosticHandler(ctx) {
+  ctx->getDiagEngine().setHandler(handler);
+}
+ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
+  ctx->getDiagEngine().setHandler(existingHandler);
+}
+
+//===----------------------------------------------------------------------===//
+// SourceMgrDiagnosticHandler
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace detail {
+struct SourceMgrDiagnosticHandlerImpl {
+  /// Get a memory buffer for the given file, or nullptr if one is not found.
+  const llvm::MemoryBuffer *getBufferForFile(llvm::SourceMgr &mgr,
+                                             StringRef filename) {
+    // Check for an existing mapping to the buffer id for this file.
+    auto bufferIt = filenameToBuf.find(filename);
+    if (bufferIt != filenameToBuf.end())
+      return bufferIt->second;
+
+    // Look for a buffer in the manager that has this filename.
+    for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
+      auto *buf = mgr.getMemoryBuffer(i);
+      if (buf->getBufferIdentifier() == filename)
+        return filenameToBuf[filename] = buf;
+    }
+
+    // Otherwise, try to load the source file.
+    const llvm::MemoryBuffer *newBuf = nullptr;
+    std::string ignored;
+    if (auto newBufID = mgr.AddIncludeFile(filename, llvm::SMLoc(), ignored))
+      newBuf = mgr.getMemoryBuffer(newBufID);
+    return filenameToBuf[filename] = newBuf;
+  }
+
+  /// Mapping between file name and buffer pointer.
+  llvm::StringMap<const llvm::MemoryBuffer *> filenameToBuf;
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Return a processable FileLineColLoc from the given location.
+static llvm::Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
+  switch (loc->getKind()) {
+  case StandardAttributes::NameLocation:
+    return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
+  case StandardAttributes::FileLineColLocation:
+    return loc.cast<FileLineColLoc>();
+  case StandardAttributes::CallSiteLocation:
+    // Process the callee of a callsite location.
+    return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
+  default:
+    return llvm::None;
+  }
+}
+
+/// Given a diagnostic kind, returns the LLVM DiagKind.
+static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
+  switch (kind) {
+  case DiagnosticSeverity::Note:
+    return llvm::SourceMgr::DK_Note;
+  case DiagnosticSeverity::Warning:
+    return llvm::SourceMgr::DK_Warning;
+  case DiagnosticSeverity::Error:
+    return llvm::SourceMgr::DK_Error;
+  case DiagnosticSeverity::Remark:
+    return llvm::SourceMgr::DK_Remark;
+  }
+  llvm_unreachable("Unknown DiagnosticSeverity");
+}
+
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
+                                                       MLIRContext *ctx,
+                                                       llvm::raw_ostream &os)
+    : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
+      impl(new SourceMgrDiagnosticHandlerImpl()) {
+  // Register a simple diagnostic handler.
+  ctx->getDiagEngine().setHandler(
+      [this](Diagnostic diag) { emitDiagnostic(diag); });
+}
+
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
+                                                       MLIRContext *ctx)
+    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {}
+
+SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {}
+
+void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
+                                                DiagnosticSeverity kind) {
+  // Extract a file location from this loc.
+  auto fileLoc = getFileLineColLoc(loc);
+
+  // If one doesn't exist, then print the raw message without a source location.
+  if (!fileLoc) {
+    std::string str;
+    llvm::raw_string_ostream strOS(str);
+    if (!loc.isa<UnknownLoc>())
+      strOS << loc << ": ";
+    strOS << message;
+    return mgr.PrintMessage(os, llvm::SMLoc(), getDiagKind(kind), strOS.str());
+  }
+
+  // Otherwise, try to convert the file location to an SMLoc.
+  auto smloc = convertLocToSMLoc(*fileLoc);
+  if (smloc.isValid())
+    return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
+
+  // If the conversion was unsuccessful, create a diagnostic with the source
+  // location information directly.
+  llvm::SMDiagnostic diag(mgr, llvm::SMLoc(), fileLoc->getFilename(),
+                          fileLoc->getLine(), fileLoc->getColumn(),
+                          getDiagKind(kind), message.str(), /*LineStr=*/"",
+                          /*Ranges=*/llvm::None);
+  diag.print(nullptr, os);
+}
+
+/// Emit the given diagnostic with the held source manager.
+void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
+  // Emit the diagnostic.
+  auto loc = diag.getLocation();
+  emitDiagnostic(loc, diag.str(), diag.getSeverity());
+
+  // If the diagnostic location was a call site location, then print the call
+  // stack as well.
+  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+    // Print the call stack while valid, or until the limit is reached.
+    Location callerLoc = callLoc.getCaller();
+    for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
+      emitDiagnostic(callerLoc, "called from", DiagnosticSeverity::Note);
+      if ((callLoc = callerLoc.dyn_cast<CallSiteLoc>()))
+        callerLoc = callLoc.getCaller();
+      else
+        break;
+    }
+  }
+
+  // Emit each of the notes.
+  for (auto &note : diag.getNotes())
+    emitDiagnostic(note.getLocation(), note.str(), note.getSeverity());
+}
+
+/// Get a memory buffer for the given file, or nullptr if one is not found.
+const llvm::MemoryBuffer *
+SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
+  return impl->getBufferForFile(mgr, filename);
+}
+
+/// Get a memory buffer for the given file, or the main file of the source
+/// manager if one doesn't exist. This always returns non-null.
+llvm::SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
+  // Get the buffer for this filename.
+  auto *membuf = getBufferForFile(loc.getFilename());
+  if (!membuf)
+    return llvm::SMLoc();
+
+  // TODO: This should really be upstreamed to be a method on llvm::SourceMgr.
+  // Doing so would allow it to use the offset cache that is already maintained
+  // by SrcBuffer, making this more efficient.
+  unsigned lineNo = loc.getLine();
+  unsigned columnNo = loc.getColumn();
+
+  // Scan for the correct line number.
+  const char *position = membuf->getBufferStart();
+  const char *end = membuf->getBufferEnd();
+
+  // We start counting line and column numbers from 1.
+  if (lineNo != 0)
+    --lineNo;
+  if (columnNo != 0)
+    --columnNo;
+
+  while (position < end && lineNo) {
+    auto curChar = *position++;
+
+    // Scan for newlines.  If this isn't one, ignore it.
+    if (curChar != '\r' && curChar != '\n')
+      continue;
+
+    // We saw a line break, decrement our counter.
+    --lineNo;
+
+    // Check for \r\n and \n\r and treat it as a single escape.  We know that
+    // looking past one character is safe because MemoryBuffer's are always nul
+    // terminated.
+    if (*position != curChar && (*position == '\r' || *position == '\n'))
+      ++position;
+  }
+
+  // If the line/column counter was invalid, return a pointer to the start of
+  // the buffer.
+  if (lineNo || position + columnNo > end)
+    return llvm::SMLoc::getFromPointer(membuf->getBufferStart());
+
+  // If the column is zero, try to skip to the first non-whitespace character.
+  if (columnNo == 0) {
+    auto isNewline = [](char c) { return c == '\n' || c == '\r'; };
+    auto isWhitespace = [](char c) { return c == ' ' || c == '\t'; };
+
+    // Look for a valid non-whitespace character before the next line.
+    for (auto *newPos = position; newPos < end && !isNewline(*newPos); ++newPos)
+      if (!isWhitespace(*newPos))
+        return llvm::SMLoc::getFromPointer(newPos);
+  }
+
+  // Otherwise return the right pointer.
+  return llvm::SMLoc::getFromPointer(position + columnNo);
+}
+
+//===----------------------------------------------------------------------===//
+// SourceMgrDiagnosticVerifierHandler
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+// Record the expected diagnostic's position, substring and whether it was
+// seen.
+struct ExpectedDiag {
+  DiagnosticSeverity kind;
+  unsigned lineNo;
+  StringRef substring;
+  llvm::SMLoc fileLoc;
+  bool matched;
+};
+
+struct SourceMgrDiagnosticVerifierHandlerImpl {
+  SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
+
+  /// Returns the expected diagnostics for the given source file.
+  llvm::Optional<MutableArrayRef<ExpectedDiag>>
+  getExpectedDiags(StringRef bufName);
+
+  /// Computes the expected diagnostics for the given source buffer.
+  MutableArrayRef<ExpectedDiag>
+  computeExpectedDiags(const llvm::MemoryBuffer *buf);
+
+  /// The current status of the verifier.
+  LogicalResult status;
+
+  /// A list of expected diagnostics for each buffer of the source manager.
+  llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
+
+  /// Regex to match the expected diagnostics format.
+  llvm::Regex expected = llvm::Regex(
+      "expected-(error|note|remark|warning) *(@[+-][0-9]+)? *{{(.*)}}");
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Given a diagnostic kind, return a human readable string for it.
+static StringRef getDiagKindStr(DiagnosticSeverity kind) {
+  switch (kind) {
+  case DiagnosticSeverity::Note:
+    return "note";
+  case DiagnosticSeverity::Warning:
+    return "warning";
+  case DiagnosticSeverity::Error:
+    return "error";
+  case DiagnosticSeverity::Remark:
+    return "remark";
+  }
+  llvm_unreachable("Unknown DiagnosticSeverity");
+}
+
+/// Returns the expected diagnostics for the given source file.
+llvm::Optional<MutableArrayRef<ExpectedDiag>>
+SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
+  auto expectedDiags = expectedDiagsPerFile.find(bufName);
+  if (expectedDiags != expectedDiagsPerFile.end())
+    return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
+  return llvm::None;
+}
+
+/// Computes the expected diagnostics for the given source buffer.
+MutableArrayRef<ExpectedDiag>
+SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
+    const llvm::MemoryBuffer *buf) {
+  // If the buffer is invalid, return an empty list.
+  if (!buf)
+    return llvm::None;
+  auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
+
+  // Scan the file for expected-* designators.
+  SmallVector<StringRef, 100> lines;
+  buf->getBuffer().split(lines, '\n');
+  for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
+    SmallVector<StringRef, 3> matches;
+    if (!expected.match(lines[lineNo], &matches))
+      continue;
+    // Point to the start of expected-*.
+    auto expectedStart = llvm::SMLoc::getFromPointer(matches[0].data());
+
+    DiagnosticSeverity kind;
+    if (matches[1] == "error")
+      kind = DiagnosticSeverity::Error;
+    else if (matches[1] == "warning")
+      kind = DiagnosticSeverity::Warning;
+    else if (matches[1] == "remark")
+      kind = DiagnosticSeverity::Remark;
+    else {
+      assert(matches[1] == "note");
+      kind = DiagnosticSeverity::Note;
+    }
+
+    ExpectedDiag record{kind, lineNo + 1, matches[3], expectedStart, false};
+    auto offsetMatch = matches[2];
+    if (!offsetMatch.empty()) {
+      int offset;
+      // Get the integer value without the @ and +/- prefix.
+      if (!offsetMatch.drop_front(2).getAsInteger(0, offset)) {
+        if (offsetMatch[1] == '+')
+          record.lineNo += offset;
+        else
+          record.lineNo -= offset;
+      }
+    }
+    expectedDiags.push_back(record);
+  }
+  return expectedDiags;
+}
+
+SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx, llvm::raw_ostream &out)
+    : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
+      impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
+  // Compute the expected diagnostics for each of the current files in the
+  // source manager.
+  for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
+    (void)impl->computeExpectedDiags(mgr.getMemoryBuffer(i + 1));
+
+  // Register a handler to verfy the diagnostics.
+  ctx->getDiagEngine().setHandler([&](Diagnostic diag) {
+    // Process the main diagnostics.
+    process(diag);
+
+    // Process each of the notes.
+    for (auto &note : diag.getNotes())
+      process(note);
+  });
+}
+
+SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx)
+    : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
+
+SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
+  // Ensure that all expected diagnosics were handled.
+  (void)verify();
+}
+
+/// Returns the status of the verifier and verifies that all expected
+/// diagnostics were emitted. This return success if all diagnostics were
+/// verified correctly, failure otherwise.
+LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
+  // Verify that all expected errors were seen.
+  for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
+    for (auto &err : expectedDiagsPair.second) {
+      if (err.matched)
+        continue;
+      llvm::SMRange range(err.fileLoc,
+                          llvm::SMLoc::getFromPointer(err.fileLoc.getPointer() +
+                                                      err.substring.size()));
+      mgr.PrintMessage(os, err.fileLoc, llvm::SourceMgr::DK_Error,
+                       "expected " + getDiagKindStr(err.kind) + " \"" +
+                           err.substring + "\" was not produced",
+                       range);
+      impl->status = failure();
+    }
+  }
+  impl->expectedDiagsPerFile.clear();
+  return impl->status;
+}
+
+/// Process a single diagnostic.
+void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
+  auto kind = diag.getSeverity();
+
+  // Process a FileLineColLoc.
+  if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
+    return process(*fileLoc, diag.str(), kind);
+
+  emitDiagnostic(diag.getLocation(),
+                 "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
+                 DiagnosticSeverity::Error);
+  impl->status = failure();
+}
+
+/// Process a FileLineColLoc diagnostic.
+void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
+                                                 StringRef msg,
+                                                 DiagnosticSeverity kind) {
+  // Get the expected diagnostics for this file.
+  auto diags = impl->getExpectedDiags(loc.getFilename());
+  if (!diags)
+    diags = impl->computeExpectedDiags(getBufferForFile(loc.getFilename()));
+
+  // Search for a matching expected diagnostic.
+  // If we find something that is close then emit a more specific error.
+  ExpectedDiag *nearMiss = nullptr;
+
+  // If this was an expected error, remember that we saw it and return.
+  unsigned line = loc.getLine();
+  for (auto &e : *diags) {
+    if (line == e.lineNo && msg.contains(e.substring)) {
+      if (e.kind == kind) {
+        e.matched = true;
+        return;
+      }
+
+      // If this only differs based on the diagnostic kind, then consider it
+      // to be a near miss.
+      nearMiss = &e;
+    }
+  }
+
+  // Otherwise, emit an error for the near miss.
+  if (nearMiss)
+    mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
+                     "'" + getDiagKindStr(kind) +
+                         "' diagnostic emitted when expecting a '" +
+                         getDiagKindStr(nearMiss->kind) + "'");
+  else
+    emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
+                   DiagnosticSeverity::Error);
+  impl->status = failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelDiagnosticHandler
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
+  struct ThreadDiagnostic {
+    ThreadDiagnostic(size_t id, Diagnostic diag)
+        : id(id), diag(std::move(diag)) {}
+    bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
+
+    /// The id for this diagnostic, this is used for ordering.
+    /// Note: This id corresponds to the ordered position of the current element
+    ///       being processed by a given thread.
+    size_t id;
+
+    /// The diagnostic.
+    Diagnostic diag;
+  };
+
+  ParallelDiagnosticHandlerImpl(MLIRContext *ctx)
+      : prevHandler(ctx->getDiagEngine().getHandler()), context(ctx) {
+    ctx->getDiagEngine().setHandler([this](Diagnostic diag) {
+      uint64_t tid = llvm::get_threadid();
+      llvm::sys::SmartScopedLock<true> lock(mutex);
+      assert(threadToOrderID.count(tid) &&
+             "current thread does not have a valid orderID");
+
+      // Append a new diagnostic.
+      diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
+    });
+  }
+
+  ~ParallelDiagnosticHandlerImpl() {
+    // Restore the previous diagnostic handler.
+    context->getDiagEngine().setHandler(prevHandler);
+
+    // Early exit if there are no diagnostics, this is the common case.
+    if (diagnostics.empty())
+      return;
+
+    // Emit the diagnostics back to the context.
+    emitDiagnostics([&](Diagnostic diag) {
+      return context->getDiagEngine().emit(std::move(diag));
+    });
+  }
+
+  /// Utility method to emit any held diagnostics.
+  void emitDiagnostics(std::function<void(Diagnostic)> emitFn) {
+    // Stable sort all of the diagnostics that were emitted. This creates a
+    // deterministic ordering for the diagnostics based upon which order id they
+    // were emitted for.
+    std::stable_sort(diagnostics.begin(), diagnostics.end());
+
+    // Emit each diagnostic to the context again.
+    for (ThreadDiagnostic &diag : diagnostics)
+      emitFn(std::move(diag.diag));
+  }
+
+  /// Set the order id for the current thread.
+  void setOrderIDForThread(size_t orderID) {
+    uint64_t tid = llvm::get_threadid();
+    llvm::sys::SmartScopedLock<true> lock(mutex);
+    threadToOrderID[tid] = orderID;
+  }
+
+  /// Dump the current diagnostics that were inflight.
+  void print(raw_ostream &os) const override {
+    // Early exit if there are no diagnostics, this is the common case.
+    if (diagnostics.empty())
+      return;
+
+    os << "In-Flight Diagnostics:\n";
+    const_cast<ParallelDiagnosticHandlerImpl *>(this)->emitDiagnostics(
+        [&](Diagnostic diag) {
+          os.indent(4);
+
+          // Print each diagnostic with the format:
+          //   "<location>: <kind>: <msg>"
+          if (!diag.getLocation().isa<UnknownLoc>())
+            os << diag.getLocation() << ": ";
+          switch (diag.getSeverity()) {
+          case DiagnosticSeverity::Error:
+            os << "error: ";
+            break;
+          case DiagnosticSeverity::Warning:
+            os << "warning: ";
+            break;
+          case DiagnosticSeverity::Note:
+            os << "note: ";
+            break;
+          case DiagnosticSeverity::Remark:
+            os << "remark: ";
+            break;
+          }
+          os << diag << '\n';
+        });
+  }
+
+  /// The previous context diagnostic handler.
+  DiagnosticEngine::HandlerTy prevHandler;
+
+  /// A smart mutex to lock access to the internal state.
+  llvm::sys::SmartMutex<true> mutex;
+
+  /// A mapping between the thread id and the current order id.
+  DenseMap<uint64_t, size_t> threadToOrderID;
+
+  /// An unordered list of diagnostics that were emitted.
+  std::vector<ThreadDiagnostic> diagnostics;
+
+  /// The context to emit the diagnostics to.
+  MLIRContext *context;
+};
+} // end namespace detail
+} // end namespace mlir
+
+ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
+    : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
+ParallelDiagnosticHandler::~ParallelDiagnosticHandler() {}
+
+/// Set the order id for the current thread.
+void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
+  impl->setOrderIDForThread(orderID);
+}
diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp
new file mode 100644
index 0000000..1170e06
--- /dev/null
+++ b/third_party/mlir/lib/IR/Dialect.cpp
@@ -0,0 +1,109 @@
+//===- Dialect.cpp - Dialect implementation -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/Regex.h"
+using namespace mlir;
+
+// Registry for all dialect allocation functions.
+static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
+    dialectRegistry;
+
+// Registry for functions that set dialect hooks.
+static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
+    dialectHooksRegistry;
+
+/// Registers a specific dialect creation function with the system, typically
+/// used through the DialectRegistration template.
+void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
+  assert(function &&
+         "Attempting to register an empty dialect initialize function");
+  dialectRegistry->push_back(function);
+}
+
+/// Registers a function to set specific hooks for a specific dialect, typically
+/// used through the DialectHooksRegistreation template.
+void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
+  assert(
+      function &&
+      "Attempting to register an empty dialect hooks initialization function");
+
+  dialectHooksRegistry->push_back(function);
+}
+
+/// Registers all dialects and their const folding hooks with the specified
+/// MLIRContext.
+void mlir::registerAllDialects(MLIRContext *context) {
+  for (const auto &fn : *dialectRegistry)
+    fn(context);
+  for (const auto &fn : *dialectHooksRegistry) {
+    fn(context);
+  }
+}
+
+Dialect::Dialect(StringRef name, MLIRContext *context)
+    : name(name), context(context) {
+  assert(isValidNamespace(name) && "invalid dialect namespace");
+  registerDialect(context);
+}
+
+Dialect::~Dialect() {}
+
+/// Verify an attribute from this dialect on the argument at 'argIndex' for
+/// the region at 'regionIndex' on the given operation. Returns failure if
+/// the verification failed, success otherwise. This hook may optionally be
+/// invoked from any operation containing a region.
+LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
+                                                NamedAttribute) {
+  return success();
+}
+
+/// Parse an attribute registered to this dialect.
+Attribute Dialect::parseAttribute(StringRef attrData, Type type,
+                                  Location loc) const {
+  emitError(loc) << "dialect '" << getNamespace()
+                 << "' provides no attribute parsing hook";
+  return Attribute();
+}
+
+/// Parse a type registered to this dialect.
+Type Dialect::parseType(StringRef tyData, Location loc) const {
+  // If this dialect allows unknown types, then represent this with OpaqueType.
+  if (allowsUnknownTypes()) {
+    auto ns = Identifier::get(getNamespace(), getContext());
+    return OpaqueType::get(ns, tyData, getContext());
+  }
+
+  emitError(loc) << "dialect '" << getNamespace()
+                 << "' provides no type parsing hook";
+  return Type();
+}
+
+/// Utility function that returns if the given string is a valid dialect
+/// namespace.
+bool Dialect::isValidNamespace(StringRef str) {
+  if (str.empty())
+    return true;
+  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
+  return dialectNameRegex.match(str);
+}
diff --git a/third_party/mlir/lib/IR/Function.cpp b/third_party/mlir/lib/IR/Function.cpp
new file mode 100644
index 0000000..fb54f85
--- /dev/null
+++ b/third_party/mlir/lib/IR/Function.cpp
@@ -0,0 +1,178 @@
+//===- Function.cpp - MLIR Function Classes -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/Twine.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Function Operation.
+//===----------------------------------------------------------------------===//
+
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs) {
+  OperationState state(location, "func");
+  Builder builder(location->getContext());
+  FuncOp::build(&builder, &state, name, type, attrs);
+  return llvm::cast<FuncOp>(Operation::create(state));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      llvm::iterator_range<dialect_attr_iterator> attrs) {
+  SmallVector<NamedAttribute, 8> attrRef(attrs);
+  return create(location, name, type, llvm::makeArrayRef(attrRef));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs,
+                      ArrayRef<NamedAttributeList> argAttrs) {
+  FuncOp func = create(location, name, type, attrs);
+  func.setAllArgAttrs(argAttrs);
+  return func;
+}
+
+void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
+  result->addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder->getStringAttr(name));
+  result->addAttribute(getTypeAttrName(), builder->getTypeAttr(type));
+  result->attributes.append(attrs.begin(), attrs.end());
+  result->addRegion();
+}
+
+void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<NamedAttributeList> argAttrs) {
+  build(builder, result, name, type, attrs);
+  assert(type.getNumInputs() == argAttrs.size());
+  SmallString<8> argAttrName;
+  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+    if (auto argDict = argAttrs[i].getDictionary())
+      result->addAttribute(getArgAttrName(i, argAttrName), argDict);
+}
+
+/// Parsing/Printing methods.
+
+ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+                          ArrayRef<Type> results, impl::VariadicFlag,
+                          std::string &) {
+    return builder.getFunctionType(argTypes, results);
+  };
+
+  return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false,
+                                   buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter *p) {
+  FunctionType fnType = getType();
+  impl::printFunctionLikeOp(p, *this, fnType.getInputs(), /*isVariadic=*/false,
+                            fnType.getResults());
+}
+
+LogicalResult FuncOp::verify() {
+  // If this function is external there is nothing to do.
+  if (isExternal())
+    return success();
+
+  // Verify that the argument list of the function and the arg list of the entry
+  // block line up.  The trait already verified that the number of arguments is
+  // the same between the signature and the block.
+  auto fnInputTypes = getType().getInputs();
+  Block &entryBlock = front();
+  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
+    if (fnInputTypes[i] != entryBlock.getArgument(i)->getType())
+      return emitOpError("type of entry block argument #")
+             << i << '(' << entryBlock.getArgument(i)->getType()
+             << ") must match the type of the corresponding argument in "
+             << "function signature(" << fnInputTypes[i] << ')';
+
+  return success();
+}
+
+/// Add an entry block to an empty function, and set up the block arguments
+/// to match the signature of the function.
+Block *FuncOp::addEntryBlock() {
+  assert(empty() && "function already has an entry block");
+  auto *entry = new Block();
+  push_back(entry);
+  entry->addArguments(getType().getInputs());
+  return entry;
+}
+
+/// Clone the internal blocks from this function into dest and all attributes
+/// from this function to dest.
+void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
+  // Add the attributes of this function to dest.
+  llvm::MapVector<Identifier, Attribute> newAttrs;
+  for (auto &attr : dest.getAttrs())
+    newAttrs.insert(attr);
+  for (auto &attr : getAttrs())
+    newAttrs.insert(attr);
+  dest.getOperation()->setAttrs(
+      DictionaryAttr::get(newAttrs.takeVector(), getContext()));
+
+  // Clone the body.
+  getBody().cloneInto(&dest.getBody(), mapper);
+}
+
+/// Create a deep copy of this function and all of its blocks, remapping
+/// any operands that use values outside of the function using the map that is
+/// provided (leaving them alone if no entry is present). Replaces references
+/// to cloned sub-values with the corresponding value that is copied, and adds
+/// those mappings to the mapper.
+FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
+  FunctionType newType = getType();
+
+  // If the function has a body, then the user might be deleting arguments to
+  // the function by specifying them in the mapper. If so, we don't add the
+  // argument to the input type vector.
+  bool isExternalFn = isExternal();
+  if (!isExternalFn) {
+    SmallVector<Type, 4> inputTypes;
+    inputTypes.reserve(newType.getNumInputs());
+    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+      if (!mapper.contains(getArgument(i)))
+        inputTypes.push_back(newType.getInput(i));
+    newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
+  }
+
+  // Create the new function.
+  FuncOp newFunc = llvm::cast<FuncOp>(getOperation()->cloneWithoutRegions());
+  newFunc.setType(newType);
+
+  /// Set the argument attributes for arguments that aren't being replaced.
+  for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
+    if (isExternalFn || !mapper.contains(getArgument(i)))
+      newFunc.setArgAttrs(destI++, getArgAttrs(i));
+
+  /// Clone the current function into the new one and return it.
+  cloneInto(newFunc, mapper);
+  return newFunc;
+}
+FuncOp FuncOp::clone() {
+  BlockAndValueMapping mapper;
+  return clone(mapper);
+}
diff --git a/third_party/mlir/lib/IR/FunctionSupport.cpp b/third_party/mlir/lib/IR/FunctionSupport.cpp
new file mode 100644
index 0000000..064e438
--- /dev/null
+++ b/third_party/mlir/lib/IR/FunctionSupport.cpp
@@ -0,0 +1,234 @@
+//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+static ParseResult
+parseArgumentList(OpAsmParser *parser, bool allowVariadic,
+                  SmallVectorImpl<Type> &argTypes,
+                  SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+                  SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs,
+                  bool &isVariadic) {
+  if (parser->parseLParen())
+    return failure();
+
+  // The argument list either has to consistently have ssa-id's followed by
+  // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
+  // sometimes not.
+  auto parseArgument = [&]() -> ParseResult {
+    llvm::SMLoc loc = parser->getCurrentLocation();
+
+    // Parse argument name if present.
+    OpAsmParser::OperandType argument;
+    Type argumentType;
+    if (succeeded(parser->parseOptionalRegionArgument(argument)) &&
+        !argument.name.empty()) {
+      // Reject this if the preceding argument was missing a name.
+      if (argNames.empty() && !argTypes.empty())
+        return parser->emitError(loc,
+                                 "expected type instead of SSA identifier");
+      argNames.push_back(argument);
+
+      if (parser->parseColonType(argumentType))
+        return failure();
+    } else if (allowVariadic && succeeded(parser->parseOptionalEllipsis())) {
+      isVariadic = true;
+      return success();
+    } else if (!argNames.empty()) {
+      // Reject this if the preceding argument had a name.
+      return parser->emitError(loc, "expected SSA identifier");
+    } else if (parser->parseType(argumentType)) {
+      return failure();
+    }
+
+    // Add the argument type.
+    argTypes.push_back(argumentType);
+
+    // Parse any argument attributes.
+    SmallVector<NamedAttribute, 2> attrs;
+    if (parser->parseOptionalAttributeDict(attrs))
+      return failure();
+    argAttrs.push_back(attrs);
+    return success();
+  };
+
+  // Parse the function arguments.
+  if (parser->parseOptionalRParen()) {
+    do {
+      unsigned numTypedArguments = argTypes.size();
+      if (parseArgument())
+        return failure();
+
+      llvm::SMLoc loc = parser->getCurrentLocation();
+      if (argTypes.size() == numTypedArguments &&
+          succeeded(parser->parseOptionalComma()))
+        return parser->emitError(
+            loc, "variadic arguments must be in the end of the argument list");
+    } while (succeeded(parser->parseOptionalComma()));
+    parser->parseRParen();
+  }
+
+  return success();
+}
+
+/// Parse a function signature, starting with a name and including the
+/// parameter list.
+static ParseResult parseFunctionSignature(
+    OpAsmParser *parser, bool allowVariadic,
+    SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+    SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
+    SmallVectorImpl<Type> &results) {
+  if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
+                        isVariadic))
+    return failure();
+  // Parse the return types if present.
+  return parser->parseOptionalArrowTypeList(results);
+}
+
+/// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
+/// to construct the custom function type given lists of input and output types.
+ParseResult
+mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result,
+                                bool allowVariadic,
+                                mlir::impl::FuncTypeBuilder funcTypeBuilder) {
+  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
+  SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
+  SmallVector<Type, 4> argTypes;
+  SmallVector<Type, 4> results;
+  auto &builder = parser->getBuilder();
+
+  // Parse the name as a symbol reference attribute.
+  SymbolRefAttr nameAttr;
+  if (parser->parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
+                             result->attributes))
+    return failure();
+  // Convert the parsed function attr into a string attr.
+  result->attributes.back().second = builder.getStringAttr(nameAttr.getValue());
+
+  // Parse the function signature.
+  auto signatureLocation = parser->getCurrentLocation();
+  bool isVariadic = false;
+  if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
+                             argAttrs, isVariadic, results))
+    return failure();
+
+  std::string errorMessage;
+  if (auto type = funcTypeBuilder(builder, argTypes, results,
+                                  impl::VariadicFlag(isVariadic), errorMessage))
+    result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type));
+  else
+    return parser->emitError(signatureLocation)
+           << "failed to construct function type"
+           << (errorMessage.empty() ? "" : ": ") << errorMessage;
+
+  // If function attributes are present, parse them.
+  if (succeeded(parser->parseOptionalKeyword("attributes")))
+    if (parser->parseOptionalAttributeDict(result->attributes))
+      return failure();
+
+  // Add the attributes to the function arguments.
+  SmallString<8> argAttrName;
+  for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+    if (!argAttrs[i].empty())
+      result->addAttribute(getArgAttrName(i, argAttrName),
+                           builder.getDictionaryAttr(argAttrs[i]));
+
+  // Parse the optional function body.
+  auto *body = result->addRegion();
+  if (parser->parseOptionalRegion(*body, entryArgs,
+                                  entryArgs.empty() ? llvm::ArrayRef<Type>()
+                                                    : argTypes))
+    return failure();
+
+  return success();
+}
+
+/// Print the signature of the function-like operation `op`.  Assumes `op` has
+/// the FunctionLike trait and passed the verification.
+static void printSignature(OpAsmPrinter *p, Operation *op,
+                           ArrayRef<Type> argTypes, bool isVariadic,
+                           ArrayRef<Type> results) {
+  Region &body = op->getRegion(0);
+  bool isExternal = body.empty();
+
+  *p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      *p << ", ";
+
+    if (!isExternal) {
+      p->printOperand(body.front().getArgument(i));
+      *p << ": ";
+    }
+
+    p->printType(argTypes[i]);
+    p->printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
+  }
+
+  if (isVariadic) {
+    if (!argTypes.empty())
+      *p << ", ";
+    *p << "...";
+  }
+
+  *p << ')';
+  p->printOptionalArrowTypeList(results);
+}
+
+/// Printer implementation for function-like operations.  Accepts lists of
+/// argument and result types to use while printing.
+void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
+                                     ArrayRef<Type> argTypes, bool isVariadic,
+                                     ArrayRef<Type> results) {
+  // Print the operation and the function name.
+  auto funcName =
+      op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
+          .getValue();
+  *p << op->getName() << " @" << funcName;
+
+  // Print the signature.
+  printSignature(p, op, argTypes, isVariadic, results);
+
+  // Print out function attributes, if present.
+  SmallVector<StringRef, 2> ignoredAttrs = {
+      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
+
+  // Ignore any argument attributes.
+  std::vector<SmallString<8>> argAttrStorage;
+  SmallString<8> argAttrName;
+  for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+    if (op->getAttr(getArgAttrName(i, argAttrName)))
+      argAttrStorage.emplace_back(argAttrName);
+  ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
+
+  auto attrs = op->getAttrs();
+  if (attrs.size() > ignoredAttrs.size()) {
+    *p << "\n  attributes ";
+    p->printOptionalAttrDict(attrs, ignoredAttrs);
+  }
+
+  // Print the body if this is not an external function.
+  Region &body = op->getRegion(0);
+  if (!body.empty())
+    p->printRegion(body, /*printEntryBlockArgs=*/false,
+                   /*printBlockTerminators=*/true);
+}
diff --git a/third_party/mlir/lib/IR/IntegerSet.cpp b/third_party/mlir/lib/IR/IntegerSet.cpp
new file mode 100644
index 0000000..74a1297
--- /dev/null
+++ b/third_party/mlir/lib/IR/IntegerSet.cpp
@@ -0,0 +1,72 @@
+//===- IntegerSet.cpp - MLIR Integer Set class ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/IntegerSet.h"
+#include "IntegerSetDetail.h"
+#include "mlir/IR/AffineExpr.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+unsigned IntegerSet::getNumDims() const { return set->dimCount; }
+unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
+unsigned IntegerSet::getNumOperands() const {
+  return set->dimCount + set->symbolCount;
+}
+
+unsigned IntegerSet::getNumConstraints() const {
+  return set->constraints.size();
+}
+
+unsigned IntegerSet::getNumEqualities() const {
+  unsigned numEqualities = 0;
+  for (unsigned i = 0, e = getNumConstraints(); i < e; i++)
+    if (isEq(i))
+      ++numEqualities;
+  return numEqualities;
+}
+
+unsigned IntegerSet::getNumInequalities() const {
+  return getNumConstraints() - getNumEqualities();
+}
+
+bool IntegerSet::isEmptyIntegerSet() const {
+  // This will only work if uniqui'ing is on.
+  static_assert(kUniquingThreshold >= 1,
+                "uniquing threshold should be at least one");
+  return *this == getEmptySet(set->dimCount, set->symbolCount, getContext());
+}
+
+ArrayRef<AffineExpr> IntegerSet::getConstraints() const {
+  return set->constraints;
+}
+
+AffineExpr IntegerSet::getConstraint(unsigned idx) const {
+  return getConstraints()[idx];
+}
+
+/// Returns the equality bits, which specify whether each of the constraints
+/// is an equality or inequality.
+ArrayRef<bool> IntegerSet::getEqFlags() const { return set->eqFlags; }
+
+/// Returns true if the idx^th constraint is an equality, false if it is an
+/// inequality.
+bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
+
+MLIRContext *IntegerSet::getContext() const {
+  return getConstraint(0).getContext();
+}
diff --git a/third_party/mlir/lib/IR/IntegerSetDetail.h b/third_party/mlir/lib/IR/IntegerSetDetail.h
new file mode 100644
index 0000000..b3eda52
--- /dev/null
+++ b/third_party/mlir/lib/IR/IntegerSetDetail.h
@@ -0,0 +1,45 @@
+//===- IntegerSetDetail.h - MLIR IntegerSet storage details -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of IntegerSet.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef INTEGERSETDETAIL_H_
+#define INTEGERSETDETAIL_H_
+
+#include "mlir/IR/AffineExpr.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace detail {
+
+struct IntegerSetStorage {
+  unsigned dimCount;
+  unsigned symbolCount;
+
+  /// Array of affine constraints: a constraint is either an equality
+  /// (affine_expr == 0) or an inequality (affine_expr >= 0).
+  ArrayRef<AffineExpr> constraints;
+
+  // Bits to check whether a constraint is an equality or an inequality.
+  ArrayRef<bool> eqFlags;
+};
+
+} // end namespace detail
+} // end namespace mlir
+#endif // INTEGERSETDETAIL_H_
diff --git a/third_party/mlir/lib/IR/Location.cpp b/third_party/mlir/lib/IR/Location.cpp
new file mode 100644
index 0000000..83b579c
--- /dev/null
+++ b/third_party/mlir/lib/IR/Location.cpp
@@ -0,0 +1,126 @@
+//===- Location.cpp - MLIR Location Classes -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Location.h"
+#include "LocationDetail.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// CallSiteLoc
+//===----------------------------------------------------------------------===//
+
+Location CallSiteLoc::get(Location callee, Location caller,
+                          MLIRContext *context) {
+  return Base::get(context, StandardAttributes::CallSiteLocation, callee,
+                   caller);
+}
+
+Location CallSiteLoc::get(Location name, ArrayRef<Location> frames,
+                          MLIRContext *context) {
+  assert(!frames.empty() && "required at least 1 frames");
+  Location caller = frames.back();
+  for (auto frame : llvm::reverse(frames.drop_back()))
+    caller = CallSiteLoc::get(frame, caller, context);
+  return CallSiteLoc::get(name, caller, context);
+}
+
+Location CallSiteLoc::getCallee() const { return getImpl()->callee; }
+
+Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
+
+//===----------------------------------------------------------------------===//
+// FileLineColLoc
+//===----------------------------------------------------------------------===//
+
+Location FileLineColLoc::get(Identifier filename, unsigned line,
+                             unsigned column, MLIRContext *context) {
+  return Base::get(context, StandardAttributes::FileLineColLocation, filename,
+                   line, column);
+}
+
+Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
+                             MLIRContext *context) {
+  return get(Identifier::get(filename.empty() ? "-" : filename, context), line,
+             column, context);
+}
+
+StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; }
+unsigned FileLineColLoc::getLine() const { return getImpl()->line; }
+unsigned FileLineColLoc::getColumn() const { return getImpl()->column; }
+
+//===----------------------------------------------------------------------===//
+// FusedLoc
+//===----------------------------------------------------------------------===//
+
+Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
+                       MLIRContext *context) {
+  // Unique the set of locations to be fused.
+  llvm::SmallSetVector<Location, 4> decomposedLocs;
+  for (auto loc : locs) {
+    // If the location is a fused location we decompose it if it has no
+    // metadata or the metadata is the same as the top level metadata.
+    if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
+      if (fusedLoc.getMetadata() == metadata) {
+        // UnknownLoc's have already been removed from FusedLocs so we can
+        // simply add all of the internal locations.
+        decomposedLocs.insert(fusedLoc.getLocations().begin(),
+                              fusedLoc.getLocations().end());
+        continue;
+      }
+    }
+    // Otherwise, only add known locations to the set.
+    if (!loc.isa<UnknownLoc>())
+      decomposedLocs.insert(loc);
+  }
+  locs = decomposedLocs.getArrayRef();
+
+  // Handle the simple cases of less than two locations.
+  if (locs.empty())
+    return UnknownLoc::get(context);
+  if (locs.size() == 1)
+    return locs.front();
+  return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
+}
+
+ArrayRef<Location> FusedLoc::getLocations() const {
+  return getImpl()->getLocations();
+}
+
+Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
+
+//===----------------------------------------------------------------------===//
+// NameLoc
+//===----------------------------------------------------------------------===//
+
+Location NameLoc::get(Identifier name, Location child, MLIRContext *context) {
+  assert(!child.isa<NameLoc>() &&
+         "a NameLoc cannot be used as a child of another NameLoc");
+  return Base::get(context, StandardAttributes::NameLocation, name, child);
+}
+
+Location NameLoc::get(Identifier name, MLIRContext *context) {
+  return get(name, UnknownLoc::get(context), context);
+}
+
+/// Return the name identifier.
+Identifier NameLoc::getName() const { return getImpl()->name; }
+
+/// Return the child location.
+Location NameLoc::getChildLoc() const { return getImpl()->child; }
diff --git a/third_party/mlir/lib/IR/LocationDetail.h b/third_party/mlir/lib/IR/LocationDetail.h
new file mode 100644
index 0000000..2076eb7
--- /dev/null
+++ b/third_party/mlir/lib/IR/LocationDetail.h
@@ -0,0 +1,140 @@
+//===- LocationDetail.h - MLIR Location storage details ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of the location attributes.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_IR_LOCATIONDETAIL_H_
+#define MLIR_IR_LOCATIONDETAIL_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/TrailingObjects.h"
+
+namespace mlir {
+
+namespace detail {
+
+struct CallSiteLocationStorage : public AttributeStorage {
+  CallSiteLocationStorage(Location callee, Location caller)
+      : callee(callee), caller(caller) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<Location, Location>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(callee, caller);
+  }
+
+  /// Construct a new storage instance.
+  static CallSiteLocationStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<CallSiteLocationStorage>())
+        CallSiteLocationStorage(key.first, key.second);
+  }
+
+  Location callee, caller;
+};
+
+struct FileLineColLocationStorage : public AttributeStorage {
+  FileLineColLocationStorage(Identifier filename, unsigned line,
+                             unsigned column)
+      : filename(filename), line(line), column(column) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::tuple<Identifier, unsigned, unsigned>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(filename, line, column);
+  }
+
+  /// Construct a new storage instance.
+  static FileLineColLocationStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<FileLineColLocationStorage>())
+        FileLineColLocationStorage(std::get<0>(key), std::get<1>(key),
+                                   std::get<2>(key));
+  }
+
+  Identifier filename;
+  unsigned line, column;
+};
+
+struct FusedLocationStorage final
+    : public AttributeStorage,
+      public llvm::TrailingObjects<FusedLocationStorage, Location> {
+  FusedLocationStorage(unsigned numLocs, Attribute metadata)
+      : numLocs(numLocs), metadata(metadata) {}
+
+  ArrayRef<Location> getLocations() const {
+    return ArrayRef<Location>(getTrailingObjects<Location>(), numLocs);
+  }
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<ArrayRef<Location>, Attribute>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getLocations(), metadata);
+  }
+
+  /// Construct a new storage instance.
+  static FusedLocationStorage *construct(AttributeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    ArrayRef<Location> locs = key.first;
+
+    auto byteSize = totalSizeToAlloc<Location>(locs.size());
+    auto rawMem = allocator.allocate(byteSize, alignof(FusedLocationStorage));
+    auto result = new (rawMem) FusedLocationStorage(locs.size(), key.second);
+
+    std::uninitialized_copy(locs.begin(), locs.end(),
+                            result->getTrailingObjects<Location>());
+    return result;
+  }
+
+  // This stuff is used by the TrailingObjects template.
+  friend llvm::TrailingObjects<FusedLocationStorage, Location>;
+  size_t numTrailingObjects(OverloadToken<Location>) const { return numLocs; }
+
+  /// Number of trailing location objects.
+  unsigned numLocs;
+
+  /// Metadata used to reason about the generation of this fused location.
+  Attribute metadata;
+};
+
+struct NameLocationStorage : public AttributeStorage {
+  NameLocationStorage(Identifier name, Location child)
+      : name(name), child(child) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<Identifier, Location>;
+  bool operator==(const KeyTy &key) const { return key == KeyTy(name, child); }
+
+  /// Construct a new storage instance.
+  static NameLocationStorage *construct(AttributeStorageAllocator &allocator,
+                                        const KeyTy &key) {
+    return new (allocator.allocate<NameLocationStorage>())
+        NameLocationStorage(key.first, key.second);
+  }
+
+  Identifier name;
+  Location child;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_IR_LOCATIONDETAIL_H_
diff --git a/third_party/mlir/lib/IR/MLIRContext.cpp b/third_party/mlir/lib/IR/MLIRContext.cpp
new file mode 100644
index 0000000..f2f4b2c
--- /dev/null
+++ b/third_party/mlir/lib/IR/MLIRContext.cpp
@@ -0,0 +1,641 @@
+//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLIRContext.h"
+#include "AffineExprDetail.h"
+#include "AffineMapDetail.h"
+#include "AttributeDetail.h"
+#include "IntegerSetDetail.h"
+#include "LocationDetail.h"
+#include "TypeDetail.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RWMutex.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+using llvm::hash_combine;
+using llvm::hash_combine_range;
+
+/// A utility function to safely get or create a uniqued instance within the
+/// given set container.
+template <typename ValueT, typename DenseInfoT, typename KeyT,
+          typename ConstructorFn>
+static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
+                              KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
+                              ConstructorFn &&constructorFn) {
+  { // Check for an existing instance in read-only mode.
+    llvm::sys::SmartScopedReader<true> instanceLock(mutex);
+    auto it = container.find_as(key);
+    if (it != container.end())
+      return *it;
+  }
+
+  // Aquire a writer-lock so that we can safely create the new instance.
+  llvm::sys::SmartScopedWriter<true> instanceLock(mutex);
+
+  // Check for an existing instance again here, because another writer thread
+  // may have already created one.
+  auto existing = container.insert_as(ValueT(), key);
+  if (!existing.second)
+    return *existing.first;
+
+  // Otherwise, construct a new instance of the value.
+  return *existing.first = constructorFn();
+}
+
+namespace {
+/// A builtin dialect to define types/etc that are necessary for the validity of
+/// the IR.
+struct BuiltinDialect : public Dialect {
+  BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
+    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
+                  DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
+                  IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+                  SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
+    addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, UnknownLoc>();
+
+    addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
+             MemRefType, NoneType, OpaqueType, RankedTensorType, TupleType,
+             UnrankedTensorType, VectorType>();
+
+    // TODO: These operations should be moved to a different dialect when they
+    // have been fully decoupled from the core.
+    addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
+  }
+};
+
+struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
+  // Affine maps are uniqued based on their dim/symbol counts and affine
+  // expressions.
+  using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
+  using DenseMapInfo<AffineMap>::isEqual;
+
+  static unsigned getHashValue(const AffineMap &key) {
+    return getHashValue(
+        KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
+  }
+
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(
+        std::get<0>(key), std::get<1>(key),
+        hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
+  }
+
+  static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
+                                  rhs.getResults());
+  }
+};
+
+struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
+  // Integer sets are uniqued based on their dim/symbol counts, affine
+  // expressions appearing in the LHS of constraints, and eqFlags.
+  using KeyTy =
+      std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
+  using DenseMapInfo<IntegerSet>::isEqual;
+
+  static unsigned getHashValue(const IntegerSet &key) {
+    return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
+                              key.getConstraints(), key.getEqFlags()));
+  }
+
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(
+        std::get<0>(key), std::get<1>(key),
+        hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
+        hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
+  }
+
+  static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
+                                  rhs.getConstraints(), rhs.getEqFlags());
+  }
+};
+} // end anonymous namespace.
+
+namespace mlir {
+/// This is the implementation of the MLIRContext class, using the pImpl idiom.
+/// This class is completely private to this file, so everything is public.
+class MLIRContextImpl {
+public:
+  //===--------------------------------------------------------------------===//
+  // Identifier uniquing
+  //===--------------------------------------------------------------------===//
+
+  // Identifier allocator and mutex for thread safety.
+  llvm::BumpPtrAllocator identifierAllocator;
+  llvm::sys::SmartRWMutex<true> identifierMutex;
+
+  //===--------------------------------------------------------------------===//
+  // Diagnostics
+  //===--------------------------------------------------------------------===//
+  DiagnosticEngine diagEngine;
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  /// A general purpose mutex to lock access to parts of the context that do not
+  /// have a more specific mutex, e.g. registry operations.
+  llvm::sys::SmartRWMutex<true> contextMutex;
+
+  /// This is a list of dialects that are created referring to this context.
+  /// The MLIRContext owns the objects.
+  std::vector<std::unique_ptr<Dialect>> dialects;
+
+  /// This is a mapping from operation name to AbstractOperation for registered
+  /// operations.
+  llvm::StringMap<AbstractOperation> registeredOperations;
+
+  /// This is a mapping from class identifier to Dialect for registered
+  /// attributes and types.
+  DenseMap<const ClassID *, Dialect *> registeredDialectSymbols;
+
+  /// These are identifiers uniqued into this MLIRContext.
+  llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
+
+  //===--------------------------------------------------------------------===//
+  // Affine uniquing
+  //===--------------------------------------------------------------------===//
+
+  // Affine allocator and mutex for thread safety.
+  llvm::BumpPtrAllocator affineAllocator;
+  llvm::sys::SmartRWMutex<true> affineMutex;
+
+  // Affine map uniquing.
+  using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
+  AffineMapSet affineMaps;
+
+  // Integer set uniquing.
+  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
+  IntegerSets integerSets;
+
+  // Affine expression uniqui'ing.
+  StorageUniquer affineUniquer;
+
+  //===--------------------------------------------------------------------===//
+  // Type uniquing
+  //===--------------------------------------------------------------------===//
+  StorageUniquer typeUniquer;
+
+  /// Cached Type Instances.
+  FloatType bf16Ty, f16Ty, f32Ty, f64Ty;
+  IndexType indexTy;
+  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
+  NoneType noneType;
+
+  //===--------------------------------------------------------------------===//
+  // Attribute uniquing
+  //===--------------------------------------------------------------------===//
+  StorageUniquer attributeUniquer;
+
+  /// Cached Attribute Instances.
+  BoolAttr falseAttr, trueAttr;
+  UnitAttr unitAttr;
+  UnknownLoc unknownLocAttr;
+
+public:
+  MLIRContextImpl() : identifiers(identifierAllocator) {}
+};
+} // end namespace mlir
+
+MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+  new BuiltinDialect(this);
+  registerAllDialects(this);
+
+  // Initialize several common attributes and types to avoid the need to lock
+  // the context when accessing them.
+
+  //// Types.
+  /// Floating-point Types.
+  impl->bf16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::BF16);
+  impl->f16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F16);
+  impl->f32Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F32);
+  impl->f64Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F64);
+  /// Index Type.
+  impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
+  /// Integer Types.
+  impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1);
+  impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8);
+  impl->int16Ty =
+      TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 16);
+  impl->int32Ty =
+      TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 32);
+  impl->int64Ty =
+      TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 64);
+  impl->int128Ty =
+      TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 128);
+  /// None Type.
+  impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
+
+  //// Attributes.
+  //// Note: These must be registered after the types as they may generate one
+  //// of the above types internally.
+  /// Bool Attributes.
+  // Note: The context is also used within the BoolAttrStorage.
+  impl->falseAttr = AttributeUniquer::get<BoolAttr>(
+      this, StandardAttributes::Bool, this, false);
+  impl->trueAttr = AttributeUniquer::get<BoolAttr>(
+      this, StandardAttributes::Bool, this, true);
+  /// Unit Attribute.
+  impl->unitAttr =
+      AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
+  /// Unknown Location Attribute.
+  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
+      this, StandardAttributes::UnknownLocation);
+}
+
+MLIRContext::~MLIRContext() {}
+
+/// Copy the specified array of elements into memory managed by the provided
+/// bump pointer allocator.  This assumes the elements are all PODs.
+template <typename T>
+static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
+                                    ArrayRef<T> elements) {
+  auto result = allocator.Allocate<T>(elements.size());
+  std::uninitialized_copy(elements.begin(), elements.end(), result);
+  return ArrayRef<T>(result, elements.size());
+}
+
+//===----------------------------------------------------------------------===//
+// Diagnostic Handlers
+//===----------------------------------------------------------------------===//
+
+/// Returns the diagnostic engine for this context.
+DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
+
+//===----------------------------------------------------------------------===//
+// Dialect and Operation Registration
+//===----------------------------------------------------------------------===//
+
+/// Return information about all registered IR dialects.
+std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
+
+  std::vector<Dialect *> result;
+  result.reserve(getImpl().dialects.size());
+  for (auto &dialect : getImpl().dialects)
+    result.push_back(dialect.get());
+  return result;
+}
+
+/// Get a registered IR dialect with the given namespace. If none is found,
+/// then return nullptr.
+Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
+  for (auto &dialect : getImpl().dialects)
+    if (name == dialect->getNamespace())
+      return dialect.get();
+  return nullptr;
+}
+
+/// Register this dialect object with the specified context.  The context
+/// takes ownership of the heap allocated dialect.
+void Dialect::registerDialect(MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  // Abort if dialect with namespace has already been registered.
+  if (llvm::any_of(impl.dialects, [this](std::unique_ptr<Dialect> &dialect) {
+        return dialect->getNamespace() == getNamespace();
+      })) {
+    llvm::report_fatal_error("a dialect with namespace '" +
+                             Twine(getNamespace()) +
+                             "' has already been registered");
+  }
+  impl.dialects.push_back(std::unique_ptr<Dialect>(this));
+}
+
+/// Return information about all registered operations.  This isn't very
+/// efficient, typically you should ask the operations about their properties
+/// directly.
+std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
+  std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
+
+  { // Lock access to the context registry.
+    llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
+
+    // We just have the operations in a non-deterministic hash table order. Dump
+    // into a temporary array, then sort it by operation name to get a stable
+    // ordering.
+    llvm::StringMap<AbstractOperation> &registeredOps =
+        getImpl().registeredOperations;
+
+    opsToSort.reserve(registeredOps.size());
+    for (auto &elt : registeredOps)
+      opsToSort.push_back({elt.first(), &elt.second});
+  }
+
+  llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
+
+  std::vector<AbstractOperation *> result;
+  result.reserve(opsToSort.size());
+  for (auto &elt : opsToSort)
+    result.push_back(elt.second);
+  return result;
+}
+
+void Dialect::addOperation(AbstractOperation opInfo) {
+  assert((getNamespace().empty() ||
+          opInfo.name.split('.').first == getNamespace()) &&
+         "op name doesn't start with dialect namespace");
+  assert(&opInfo.dialect == this && "Dialect object mismatch");
+  auto &impl = context->getImpl();
+
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
+    llvm::errs() << "error: operation named '" << opInfo.name
+                 << "' is already registered.\n";
+    abort();
+  }
+}
+
+/// Register a dialect-specific symbol(e.g. type) with the current context.
+void Dialect::addSymbol(const ClassID *const classID) {
+  auto &impl = context->getImpl();
+
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  if (!impl.registeredDialectSymbols.insert({classID, this}).second) {
+    llvm::errs() << "error: dialect symbol already registered.\n";
+    abort();
+  }
+}
+
+/// Look up the specified operation in the operation set and return a pointer
+/// to it if present.  Otherwise, return a null pointer.
+const AbstractOperation *AbstractOperation::lookup(StringRef opName,
+                                                   MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  // Lock access to the context registry.
+  llvm::sys::SmartScopedReader<true> registryLock(impl.contextMutex);
+  auto it = impl.registeredOperations.find(opName);
+  if (it != impl.registeredOperations.end())
+    return &it->second;
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Identifier uniquing
+//===----------------------------------------------------------------------===//
+
+/// Return an identifier for the specified string.
+Identifier Identifier::get(StringRef str, MLIRContext *context) {
+  assert(!str.empty() && "Cannot create an empty identifier");
+  assert(str.find('\0') == StringRef::npos &&
+         "Cannot create an identifier with a nul character");
+
+  auto &impl = context->getImpl();
+
+  { // Check for an existing identifier in read-only mode.
+    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
+    auto it = impl.identifiers.find(str);
+    if (it != impl.identifiers.end())
+      return Identifier(it->getKeyData());
+  }
+
+  // Aquire a writer-lock so that we can safely create the new instance.
+  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
+  auto it = impl.identifiers.insert({str, char()}).first;
+  return Identifier(it->getKeyData());
+}
+
+//===----------------------------------------------------------------------===//
+// Type uniquing
+//===----------------------------------------------------------------------===//
+
+static Dialect &lookupDialectForSymbol(MLIRContext *ctx,
+                                       const ClassID *const classID) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredDialectSymbols.find(classID);
+  assert(it != impl.registeredDialectSymbols.end() &&
+         "symbol is not registered.");
+  return *it->second;
+}
+
+/// Returns the storage unqiuer used for constructing type storage instances.
+/// This should not be used directly.
+StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
+
+/// Get the dialect that registered the type with the provided typeid.
+Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
+                                           const ClassID *const typeID) {
+  return lookupDialectForSymbol(ctx, typeID);
+}
+
+FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
+  assert(kindof(kind) && "Not a FP kind.");
+  switch (kind) {
+  case StandardTypes::BF16:
+    return context->getImpl().bf16Ty;
+  case StandardTypes::F16:
+    return context->getImpl().f16Ty;
+  case StandardTypes::F32:
+    return context->getImpl().f32Ty;
+  case StandardTypes::F64:
+    return context->getImpl().f64Ty;
+  default:
+    llvm_unreachable("unexpected floating-point kind");
+  }
+}
+
+/// Get an instance of the IndexType.
+IndexType IndexType::get(MLIRContext *context) {
+  return context->getImpl().indexTy;
+}
+
+/// Return an existing integer type instance if one is cached within the
+/// context.
+static IntegerType getCachedIntegerType(unsigned width, MLIRContext *context) {
+  switch (width) {
+  case 1:
+    return context->getImpl().int1Ty;
+  case 8:
+    return context->getImpl().int8Ty;
+  case 16:
+    return context->getImpl().int16Ty;
+  case 32:
+    return context->getImpl().int32Ty;
+  case 64:
+    return context->getImpl().int64Ty;
+  case 128:
+    return context->getImpl().int128Ty;
+  default:
+    return IntegerType();
+  }
+}
+
+IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
+  if (auto cached = getCachedIntegerType(width, context))
+    return cached;
+  return Base::get(context, StandardTypes::Integer, width);
+}
+
+IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
+                                    Location location) {
+  if (auto cached = getCachedIntegerType(width, context))
+    return cached;
+  return Base::getChecked(location, context, StandardTypes::Integer, width);
+}
+
+/// Get an instance of the NoneType.
+NoneType NoneType::get(MLIRContext *context) {
+  return context->getImpl().noneType;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute uniquing
+//===----------------------------------------------------------------------===//
+
+/// Returns the storage uniquer used for constructing attribute storage
+/// instances. This should not be used directly.
+StorageUniquer &MLIRContext::getAttributeUniquer() {
+  return getImpl().attributeUniquer;
+}
+
+/// Returns a functor used to initialize new attribute storage instances.
+std::function<void(AttributeStorage *)>
+AttributeUniquer::getInitFn(MLIRContext *ctx, const ClassID *const attrID) {
+  return [ctx, attrID](AttributeStorage *storage) {
+    storage->initializeDialect(lookupDialectForSymbol(ctx, attrID));
+
+    // If the attribute did not provide a type, then default to NoneType.
+    if (!storage->getType())
+      storage->setType(NoneType::get(ctx));
+  };
+}
+
+BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+  return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
+}
+
+UnitAttr UnitAttr::get(MLIRContext *context) {
+  return context->getImpl().unitAttr;
+}
+
+Location UnknownLoc::get(MLIRContext *context) {
+  return context->getImpl().unknownLocAttr;
+}
+
+//===----------------------------------------------------------------------===//
+// AffineMap uniquing
+//===----------------------------------------------------------------------===//
+
+StorageUniquer &MLIRContext::getAffineUniquer() {
+  return getImpl().affineUniquer;
+}
+
+AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
+                             ArrayRef<AffineExpr> results,
+                             MLIRContext *context) {
+  auto &impl = context->getImpl();
+  auto key = std::make_tuple(dimCount, symbolCount, results);
+
+  // Safely get or create an AffineMap instance.
+  return safeGetOrCreate(impl.affineMaps, key, impl.affineMutex, [&] {
+    auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
+
+    // Copy the results into the bump pointer.
+    results = copyArrayRefInto(impl.affineAllocator, results);
+
+    // Initialize the memory using placement new.
+    new (res) detail::AffineMapStorage{dimCount, symbolCount, results};
+    return AffineMap(res);
+  });
+}
+
+AffineMap AffineMap::get(MLIRContext *context) {
+  return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
+}
+
+AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
+                         ArrayRef<AffineExpr> results) {
+  // The number of results can't be zero.
+  assert(!results.empty());
+  return getImpl(dimCount, symbolCount, results, results[0].getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Integer Sets: these are allocated into the bump pointer, and are immutable.
+// Unlike AffineMap's, these are uniqued only if they are small.
+//===----------------------------------------------------------------------===//
+
+IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
+                           ArrayRef<AffineExpr> constraints,
+                           ArrayRef<bool> eqFlags) {
+  // The number of constraints can't be zero.
+  assert(!constraints.empty());
+  assert(constraints.size() == eqFlags.size());
+
+  auto &impl = constraints[0].getContext()->getImpl();
+
+  // A utility function to construct a new IntegerSetStorage instance.
+  auto constructorFn = [&] {
+    auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();
+
+    // Copy the results and equality flags into the bump pointer.
+    constraints = copyArrayRefInto(impl.affineAllocator, constraints);
+    eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);
+
+    // Initialize the memory using placement new.
+    new (res)
+        detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
+    return IntegerSet(res);
+  };
+
+  // If this instance is uniqued, then we handle it separately so that multiple
+  // threads may simulatenously access existing instances.
+  if (constraints.size() < IntegerSet::kUniquingThreshold) {
+    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
+    return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
+                           constructorFn);
+  }
+
+  // Otherwise, aquire a writer-lock so that we can safely create the new
+  // instance.
+  llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
+  return constructorFn();
+}
diff --git a/third_party/mlir/lib/IR/Module.cpp b/third_party/mlir/lib/IR/Module.cpp
new file mode 100644
index 0000000..b1c56c2
--- /dev/null
+++ b/third_party/mlir/lib/IR/Module.cpp
@@ -0,0 +1,88 @@
+//===- Module.cpp - MLIR Module Operation ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Module Operation.
+//===----------------------------------------------------------------------===//
+
+void ModuleOp::build(Builder *builder, OperationState *result) {
+  ensureTerminator(*result->addRegion(), *builder, result->location);
+}
+
+/// Construct a module from the given context.
+ModuleOp ModuleOp::create(Location loc) {
+  OperationState state(loc, "module");
+  Builder builder(loc->getContext());
+  ModuleOp::build(&builder, &state);
+  return llvm::cast<ModuleOp>(Operation::create(state));
+}
+
+ParseResult ModuleOp::parse(OpAsmParser *parser, OperationState *result) {
+  // If module attributes are present, parse them.
+  if (succeeded(parser->parseOptionalKeyword("attributes")))
+    if (parser->parseOptionalAttributeDict(result->attributes))
+      return failure();
+
+  // Parse the module body.
+  auto *body = result->addRegion();
+  if (parser->parseRegion(*body, llvm::None, llvm::None))
+    return failure();
+
+  // Ensure that this module has a valid terminator.
+  ensureTerminator(*body, parser->getBuilder(), result->location);
+  return success();
+}
+
+void ModuleOp::print(OpAsmPrinter *p) {
+  *p << "module";
+
+  // Print the module attributes.
+  auto attrs = getAttrs();
+  if (!attrs.empty()) {
+    *p << " attributes";
+    p->printOptionalAttrDict(attrs, {});
+  }
+
+  // Print the region.
+  p->printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+}
+
+LogicalResult ModuleOp::verify() {
+  auto &bodyRegion = getOperation()->getRegion(0);
+
+  // The body must contain a single basic block.
+  if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end())
+    return emitOpError("expected body region to have a single block");
+
+  // Check that the body has no block arguments.
+  auto *body = &bodyRegion.front();
+  if (body->getNumArguments() != 0)
+    return emitOpError("expected body to have no arguments");
+
+  return success();
+}
+
+/// Return body of this module.
+Region &ModuleOp::getBodyRegion() { return getOperation()->getRegion(0); }
+Block *ModuleOp::getBody() { return &getBodyRegion().front(); }
diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp
new file mode 100644
index 0000000..fa2ce8c
--- /dev/null
+++ b/third_party/mlir/lib/IR/Operation.cpp
@@ -0,0 +1,1010 @@
+//===- Operation.cpp - Operation support code -----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include <numeric>
+using namespace mlir;
+
+/// Form the OperationName for an op with the specified string.  This either is
+/// a reference to an AbstractOperation if one is known, or a uniqued Identifier
+/// if not.
+OperationName::OperationName(StringRef name, MLIRContext *context) {
+  if (auto *op = AbstractOperation::lookup(name, context))
+    representation = op;
+  else
+    representation = Identifier::get(name, context);
+}
+
+/// Return the name of the dialect this operation is registered to.
+StringRef OperationName::getDialect() const {
+  return getStringRef().split('.').first;
+}
+
+/// Return the name of this operation.  This always succeeds.
+StringRef OperationName::getStringRef() const {
+  if (auto *op = representation.dyn_cast<const AbstractOperation *>())
+    return op->name;
+  return representation.get<Identifier>().strref();
+}
+
+const AbstractOperation *OperationName::getAbstractOperation() const {
+  return representation.dyn_cast<const AbstractOperation *>();
+}
+
+OperationName OperationName::getFromOpaquePointer(void *pointer) {
+  return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
+}
+
+OpAsmParser::~OpAsmParser() {}
+
+//===----------------------------------------------------------------------===//
+// OpResult
+//===----------------------------------------------------------------------===//
+
+/// Return the result number of this result.
+unsigned OpResult::getResultNumber() {
+  // Results are always stored consecutively, so use pointer subtraction to
+  // figure out what number this is.
+  return this - &getOwner()->getOpResults()[0];
+}
+
+//===----------------------------------------------------------------------===//
+// OpOperand
+//===----------------------------------------------------------------------===//
+
+// TODO: This namespace is only required because of a bug in GCC<7.0.
+namespace mlir {
+/// Return which operand this is in the operand list.
+template <> unsigned OpOperand::getOperandNumber() {
+  return this - &getOwner()->getOpOperands()[0];
+}
+} // end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// BlockOperand
+//===----------------------------------------------------------------------===//
+
+// TODO: This namespace is only required because of a bug in GCC<7.0.
+namespace mlir {
+/// Return which operand this is in the operand list.
+template <> unsigned BlockOperand::getOperandNumber() {
+  return this - &getOwner()->getBlockOperands()[0];
+}
+} // end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Operation
+//===----------------------------------------------------------------------===//
+
+/// Create a new Operation with the specific fields.
+Operation *Operation::create(Location location, OperationName name,
+                             ArrayRef<Value *> operands,
+                             ArrayRef<Type> resultTypes,
+                             ArrayRef<NamedAttribute> attributes,
+                             ArrayRef<Block *> successors, unsigned numRegions,
+                             bool resizableOperandList, MLIRContext *context) {
+  return create(location, name, operands, resultTypes,
+                NamedAttributeList(attributes), successors, numRegions,
+                resizableOperandList, context);
+}
+
+/// Create a new Operation from operation state.
+Operation *Operation::create(const OperationState &state) {
+  unsigned numRegions = state.regions.size();
+  Operation *op = create(state.location, state.name, state.operands,
+                         state.types, state.attributes, state.successors,
+                         numRegions, state.resizableOperandList, state.context);
+  for (unsigned i = 0; i < numRegions; ++i)
+    if (state.regions[i])
+      op->getRegion(i).takeBody(*state.regions[i]);
+  return op;
+}
+
+/// Overload of create that takes an existing NamedAttributeList to avoid
+/// unnecessarily uniquing a list of attributes.
+Operation *Operation::create(Location location, OperationName name,
+                             ArrayRef<Value *> operands,
+                             ArrayRef<Type> resultTypes,
+                             const NamedAttributeList &attributes,
+                             ArrayRef<Block *> successors, unsigned numRegions,
+                             bool resizableOperandList, MLIRContext *context) {
+  unsigned numSuccessors = successors.size();
+
+  // Input operands are nullptr-separated for each successor, the null operands
+  // aren't actually stored.
+  unsigned numOperands = operands.size() - numSuccessors;
+
+  // Compute the byte size for the operation and the operand storage.
+  auto byteSize = totalSizeToAlloc<OpResult, BlockOperand, unsigned, Region,
+                                   detail::OperandStorage>(
+      resultTypes.size(), numSuccessors, numSuccessors, numRegions,
+      /*detail::OperandStorage*/ 1);
+  byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
+                                numOperands, resizableOperandList),
+                            alignof(Operation));
+  void *rawMem = malloc(byteSize);
+
+  // Create the new Operation.
+  auto op =
+      ::new (rawMem) Operation(location, name, resultTypes.size(),
+                               numSuccessors, numRegions, attributes, context);
+
+  assert((numSuccessors == 0 || !op->isKnownNonTerminator()) &&
+         "unexpected successors in a non-terminator operation");
+
+  // Initialize the regions.
+  for (unsigned i = 0; i != numRegions; ++i)
+    new (&op->getRegion(i)) Region(op);
+
+  // Initialize the results and operands.
+  new (&op->getOperandStorage())
+      detail::OperandStorage(numOperands, resizableOperandList);
+
+  auto instResults = op->getOpResults();
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    new (&instResults[i]) OpResult(resultTypes[i], op);
+
+  auto opOperands = op->getOpOperands();
+
+  // Initialize normal operands.
+  unsigned operandIt = 0, operandE = operands.size();
+  unsigned nextOperand = 0;
+  for (; operandIt != operandE; ++operandIt) {
+    // Null operands are used as sentinels between successor operand lists. If
+    // we encounter one here, break and handle the successor operands lists
+    // separately below.
+    if (!operands[operandIt])
+      break;
+    new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
+  }
+
+  unsigned currentSuccNum = 0;
+  if (operandIt == operandE) {
+    // Verify that the amount of sentinel operands is equivalent to the number
+    // of successors.
+    assert(currentSuccNum == numSuccessors);
+    return op;
+  }
+
+  assert(!op->isKnownNonTerminator() &&
+         "Unexpected nullptr in operand list when creating non-terminator.");
+  auto instBlockOperands = op->getBlockOperands();
+  unsigned *succOperandCountIt = op->getTrailingObjects<unsigned>();
+  unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
+  (void)succOperandCountE;
+
+  for (; operandIt != operandE; ++operandIt) {
+    // If we encounter a sentinel branch to the next operand update the count
+    // variable.
+    if (!operands[operandIt]) {
+      assert(currentSuccNum < numSuccessors);
+
+      // After the first iteration update the successor operand count
+      // variable.
+      if (currentSuccNum != 0) {
+        ++succOperandCountIt;
+        assert(succOperandCountIt != succOperandCountE &&
+               "More sentinel operands than successors.");
+      }
+
+      new (&instBlockOperands[currentSuccNum])
+          BlockOperand(op, successors[currentSuccNum]);
+      *succOperandCountIt = 0;
+      ++currentSuccNum;
+      continue;
+    }
+    new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
+    ++(*succOperandCountIt);
+  }
+
+  // Verify that the amount of sentinel operands is equivalent to the number of
+  // successors.
+  assert(currentSuccNum == numSuccessors);
+
+  return op;
+}
+
+Operation::Operation(Location location, OperationName name, unsigned numResults,
+                     unsigned numSuccessors, unsigned numRegions,
+                     const NamedAttributeList &attributes, MLIRContext *context)
+    : location(location), numResults(numResults), numSuccs(numSuccessors),
+      numRegions(numRegions), name(name), attrs(attributes) {}
+
+// Operations are deleted through the destroy() member because they are
+// allocated via malloc.
+Operation::~Operation() {
+  assert(block == nullptr && "operation destroyed but still in a block");
+
+  // Explicitly run the destructors for the operands and results.
+  getOperandStorage().~OperandStorage();
+
+  for (auto &result : getOpResults())
+    result.~OpResult();
+
+  // Explicitly run the destructors for the successors.
+  for (auto &successor : getBlockOperands())
+    successor.~BlockOperand();
+
+  // Explicitly destroy the regions.
+  for (auto &region : getRegions())
+    region.~Region();
+}
+
+/// Destroy this operation or one of its subclasses.
+void Operation::destroy() {
+  this->~Operation();
+  free(this);
+}
+
+/// Return the context this operation is associated with.
+MLIRContext *Operation::getContext() { return location->getContext(); }
+
+/// Return the dialact this operation is associated with, or nullptr if the
+/// associated dialect is not registered.
+Dialect *Operation::getDialect() {
+  if (auto *abstractOp = getAbstractOperation())
+    return &abstractOp->dialect;
+
+  // If this operation hasn't been registered or doesn't have abstract
+  // operation, try looking up the dialect name in the context.
+  return getContext()->getRegisteredDialect(getName().getDialect());
+}
+
+Region *Operation::getParentRegion() {
+  return block ? block->getParent() : nullptr;
+}
+
+Operation *Operation::getParentOp() {
+  return block ? block->getParentOp() : nullptr;
+}
+
+/// Replace any uses of 'from' with 'to' within this operation.
+void Operation::replaceUsesOfWith(Value *from, Value *to) {
+  if (from == to)
+    return;
+  for (auto &operand : getOpOperands())
+    if (operand.get() == from)
+      operand.set(to);
+}
+
+//===----------------------------------------------------------------------===//
+// Operation Walkers
+//===----------------------------------------------------------------------===//
+
+void Operation::walk(llvm::function_ref<void(Operation *)> callback) {
+  // Visit any internal operations.
+  for (auto &region : getRegions())
+    region.walk(callback);
+
+  // Visit the current operation.
+  callback(this);
+}
+
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
+/// Emit an error about fatal conditions with this operation, reporting up to
+/// any diagnostic handlers that may be listening.
+InFlightDiagnostic Operation::emitError(const Twine &message) {
+  return mlir::emitError(getLoc(), message);
+}
+
+/// Emit a warning about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+InFlightDiagnostic Operation::emitWarning(const Twine &message) {
+  return mlir::emitWarning(getLoc(), message);
+}
+
+/// Emit a remark about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+InFlightDiagnostic Operation::emitRemark(const Twine &message) {
+  return mlir::emitRemark(getLoc(), message);
+}
+
+/// Given an operation 'other' that is within the same parent block, return
+/// whether the current operation is before 'other' in the operation list
+/// of the parent block.
+/// Note: This function has an average complexity of O(1), but worst case may
+/// take O(N) where N is the number of operations within the parent block.
+bool Operation::isBeforeInBlock(Operation *other) {
+  assert(block && "Operations without parent blocks have no order.");
+  assert(other && other->block == block &&
+         "Expected other operation to have the same parent block.");
+  // Recompute the parent ordering if necessary.
+  if (!block->isInstOrderValid())
+    block->recomputeInstOrder();
+  return orderIndex < other->orderIndex;
+}
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Operation
+//===----------------------------------------------------------------------===//
+
+auto llvm::ilist_detail::SpecificNodeAccess<
+    typename llvm::ilist_detail::compute_node_options<
+        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
+  return NodeAccess::getNodePtr<OptionsT>(N);
+}
+
+auto llvm::ilist_detail::SpecificNodeAccess<
+    typename llvm::ilist_detail::compute_node_options<
+        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
+    -> const node_type * {
+  return NodeAccess::getNodePtr<OptionsT>(N);
+}
+
+auto llvm::ilist_detail::SpecificNodeAccess<
+    typename llvm::ilist_detail::compute_node_options<
+        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
+  return NodeAccess::getValuePtr<OptionsT>(N);
+}
+
+auto llvm::ilist_detail::SpecificNodeAccess<
+    typename llvm::ilist_detail::compute_node_options<
+        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
+    -> const_pointer {
+  return NodeAccess::getValuePtr<OptionsT>(N);
+}
+
+void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
+  op->destroy();
+}
+
+Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
+  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
+  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
+  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
+}
+
+/// This is a trait method invoked when a operation is added to a block.  We
+/// keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
+  assert(!op->getBlock() && "already in a operation block!");
+  op->block = getContainingBlock();
+
+  // Invalidate the block ordering.
+  op->block->invalidateInstOrder();
+}
+
+/// This is a trait method invoked when a operation is removed from a block.
+/// We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
+  assert(op->block && "not already in a operation block!");
+  op->block = nullptr;
+}
+
+/// This is a trait method invoked when a operation is moved from one block
+/// to another.  We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
+    ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
+  Block *curParent = getContainingBlock();
+
+  // Invalidate the ordering of the parent block.
+  curParent->invalidateInstOrder();
+
+  // If we are transferring operations within the same block, the block
+  // pointer doesn't need to be updated.
+  if (curParent == otherList.getContainingBlock())
+    return;
+
+  // Update the 'block' member of each operation.
+  for (; first != last; ++first)
+    first->block = curParent;
+}
+
+/// Remove this operation (and its descendants) from its Block and delete
+/// all of them.
+void Operation::erase() {
+  if (auto *parent = getBlock())
+    parent->getOperations().erase(this);
+  else
+    destroy();
+}
+
+/// Unlink this operation from its current block and insert it right before
+/// `existingInst` which may be in the same or another block in the same
+/// function.
+void Operation::moveBefore(Operation *existingInst) {
+  moveBefore(existingInst->getBlock(), existingInst->getIterator());
+}
+
+/// Unlink this operation from its current basic block and insert it right
+/// before `iterator` in the specified basic block.
+void Operation::moveBefore(Block *block,
+                           llvm::iplist<Operation>::iterator iterator) {
+  block->getOperations().splice(iterator, getBlock()->getOperations(),
+                                getIterator());
+}
+
+/// This drops all operand uses from this operation, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void Operation::dropAllReferences() {
+  for (auto &op : getOpOperands())
+    op.drop();
+
+  for (auto &region : getRegions())
+    region.dropAllReferences();
+
+  for (auto &dest : getBlockOperands())
+    dest.drop();
+}
+
+/// This drops all uses of any values defined by this operation or its nested
+/// regions, wherever they are located.
+void Operation::dropAllDefinedValueUses() {
+  for (auto &val : getOpResults())
+    val.dropAllUses();
+
+  for (auto &region : getRegions())
+    for (auto &block : region)
+      block.dropAllDefinedValueUses();
+}
+
+/// Return true if there are no users of any results of this operation.
+bool Operation::use_empty() {
+  for (auto *result : getResults())
+    if (!result->use_empty())
+      return false;
+  return true;
+}
+
+void Operation::setSuccessor(Block *block, unsigned index) {
+  assert(index < getNumSuccessors());
+  getBlockOperands()[index].set(block);
+}
+
+auto Operation::getNonSuccessorOperands() -> operand_range {
+  return {operand_iterator(this, 0),
+          operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0)
+                                                 : getNumOperands())};
+}
+
+/// Get the index of the first operand of the successor at the provided
+/// index.
+unsigned Operation::getSuccessorOperandIndex(unsigned index) {
+  assert(!isKnownNonTerminator() && "only terminators may have successors");
+  assert(index < getNumSuccessors());
+
+  // Count the number of operands for each of the successors after, and
+  // including, the one at 'index'. This is based upon the assumption that all
+  // non successor operands are placed at the beginning of the operand list.
+  auto *successorOpCountBegin = getTrailingObjects<unsigned>();
+  unsigned postSuccessorOpCount =
+      std::accumulate(successorOpCountBegin + index,
+                      successorOpCountBegin + getNumSuccessors(), 0u);
+  return getNumOperands() - postSuccessorOpCount;
+}
+
+auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
+  unsigned succOperandIndex = getSuccessorOperandIndex(index);
+  return {operand_iterator(this, succOperandIndex),
+          operand_iterator(this,
+                           succOperandIndex + getNumSuccessorOperands(index))};
+}
+
+/// Attempt to fold this operation using the Op's registered foldHook.
+LogicalResult Operation::fold(ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results) {
+  // If we have a registered operation definition matching this one, use it to
+  // try to constant fold the operation.
+  auto *abstractOp = getAbstractOperation();
+  if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
+    return success();
+
+  // Otherwise, fall back on the dialect hook to handle it.
+  Dialect *dialect = getDialect();
+  if (!dialect)
+    return failure();
+
+  SmallVector<Attribute, 8> constants;
+  if (failed(dialect->constantFoldHook(this, operands, constants)))
+    return failure();
+  results.assign(constants.begin(), constants.end());
+  return success();
+}
+
+/// Emit an error with the op name prefixed, like "'dim' op " which is
+/// convenient for verifiers.
+InFlightDiagnostic Operation::emitOpError(const Twine &message) {
+  return emitError() << "'" << getName() << "' op " << message;
+}
+
+//===----------------------------------------------------------------------===//
+// Operation Cloning
+//===----------------------------------------------------------------------===//
+
+/// Create a deep copy of this operation but keep the operation regions empty.
+/// Operands are remapped using `mapper` (if present), and `mapper` is updated
+/// to contain the results.
+Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
+  SmallVector<Value *, 8> operands;
+  SmallVector<Block *, 2> successors;
+
+  operands.reserve(getNumOperands() + getNumSuccessors());
+
+  if (getNumSuccessors() == 0) {
+    // Non-branching operations can just add all the operands.
+    for (auto *opValue : getOperands())
+      operands.push_back(mapper.lookupOrDefault(opValue));
+  } else {
+    // We add the operands separated by nullptr's for each successor.
+    unsigned firstSuccOperand =
+        getNumSuccessors() ? getSuccessorOperandIndex(0) : getNumOperands();
+    auto opOperands = getOpOperands();
+
+    unsigned i = 0;
+    for (; i != firstSuccOperand; ++i)
+      operands.push_back(mapper.lookupOrDefault(opOperands[i].get()));
+
+    successors.reserve(getNumSuccessors());
+    for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) {
+      successors.push_back(mapper.lookupOrDefault(getSuccessor(succ)));
+
+      // Add sentinel to delineate successor operands.
+      operands.push_back(nullptr);
+
+      // Remap the successors operands.
+      for (auto *operand : getSuccessorOperands(succ))
+        operands.push_back(mapper.lookupOrDefault(operand));
+    }
+  }
+
+  SmallVector<Type, 8> resultTypes(getResultTypes());
+  unsigned numRegions = getNumRegions();
+  auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes,
+                                  attrs, successors, numRegions,
+                                  hasResizableOperandsList(), getContext());
+
+  // Remember the mapping of any results.
+  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+    mapper.map(getResult(i), newOp->getResult(i));
+
+  return newOp;
+}
+
+Operation *Operation::cloneWithoutRegions() {
+  BlockAndValueMapping mapper;
+  return cloneWithoutRegions(mapper);
+}
+
+/// Create a deep copy of this operation, remapping any operands that use
+/// values outside of the operation using the map that is provided (leaving
+/// them alone if no entry is present).  Replaces references to cloned
+/// sub-operations to the corresponding operation that is copied, and adds
+/// those mappings to the map.
+Operation *Operation::clone(BlockAndValueMapping &mapper) {
+  auto *newOp = cloneWithoutRegions(mapper);
+
+  // Clone the regions.
+  for (unsigned i = 0; i != numRegions; ++i)
+    getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
+
+  return newOp;
+}
+
+Operation *Operation::clone() {
+  BlockAndValueMapping mapper;
+  return clone(mapper);
+}
+
+//===----------------------------------------------------------------------===//
+// OpState trait class.
+//===----------------------------------------------------------------------===//
+
+// The fallback for the parser is to reject the custom assembly form.
+ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) {
+  return parser->emitError(parser->getNameLoc(), "has no custom assembly form");
+}
+
+// The fallback for the printer is to print in the generic assembly form.
+void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getOperation()); }
+
+/// Emit an error about fatal conditions with this operation, reporting up to
+/// any diagnostic handlers that may be listening.
+InFlightDiagnostic OpState::emitError(const Twine &message) {
+  return getOperation()->emitError(message);
+}
+
+/// Emit an error with the op name prefixed, like "'dim' op " which is
+/// convenient for verifiers.
+InFlightDiagnostic OpState::emitOpError(const Twine &message) {
+  return getOperation()->emitOpError(message);
+}
+
+/// Emit a warning about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+InFlightDiagnostic OpState::emitWarning(const Twine &message) {
+  return getOperation()->emitWarning(message);
+}
+
+/// Emit a remark about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+InFlightDiagnostic OpState::emitRemark(const Twine &message) {
+  return getOperation()->emitRemark(message);
+}
+
+//===----------------------------------------------------------------------===//
+// Op Trait implementations
+//===----------------------------------------------------------------------===//
+
+LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
+  if (op->getNumOperands() != 0)
+    return op->emitOpError() << "requires zero operands";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
+  if (op->getNumOperands() != 1)
+    return op->emitOpError() << "requires a single operand";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
+                                             unsigned numOperands) {
+  if (op->getNumOperands() != numOperands) {
+    return op->emitOpError() << "expected " << numOperands
+                             << " operands, but found " << op->getNumOperands();
+  }
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
+                                                    unsigned numOperands) {
+  if (op->getNumOperands() < numOperands)
+    return op->emitOpError()
+           << "expected " << numOperands << " or more operands";
+  return success();
+}
+
+/// If this is a vector type, or a tensor type, return the scalar element type
+/// that it is built around, otherwise return the type unmodified.
+static Type getTensorOrVectorElementType(Type type) {
+  if (auto vec = type.dyn_cast<VectorType>())
+    return vec.getElementType();
+
+  // Look through tensor<vector<...>> to find the underlying element type.
+  if (auto tensor = type.dyn_cast<TensorType>())
+    return getTensorOrVectorElementType(tensor.getElementType());
+  return type;
+}
+
+LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
+  for (auto opType : op->getOperandTypes()) {
+    auto type = getTensorOrVectorElementType(opType);
+    if (!type.isIntOrIndex())
+      return op->emitOpError() << "requires an integer or index type";
+  }
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
+  for (auto opType : op->getOperandTypes()) {
+    auto type = getTensorOrVectorElementType(opType);
+    if (!type.isa<FloatType>())
+      return op->emitOpError("requires a float type");
+  }
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
+  // Zero or one operand always have the "same" type.
+  unsigned nOperands = op->getNumOperands();
+  if (nOperands < 2)
+    return success();
+
+  auto type = op->getOperand(0)->getType();
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
+    if (opType != type)
+      return op->emitOpError() << "requires all operands to have the same type";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
+  if (op->getNumResults() != 0)
+    return op->emitOpError() << "requires zero results";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
+  if (op->getNumResults() != 1)
+    return op->emitOpError() << "requires one result";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyNResults(Operation *op,
+                                            unsigned numOperands) {
+  if (op->getNumResults() != numOperands)
+    return op->emitOpError() << "expected " << numOperands << " results";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
+                                                   unsigned numOperands) {
+  if (op->getNumResults() < numOperands)
+    return op->emitOpError()
+           << "expected " << numOperands << " or more results";
+  return success();
+}
+
+/// Returns success if the given two types have the same shape. That is,
+/// they are both scalars (not shaped), or they are both shaped types and at
+/// least one is unranked or they have the same shape. The element type does not
+/// matter.
+static LogicalResult verifyShapeMatch(Type type1, Type type2) {
+  auto sType1 = type1.dyn_cast<ShapedType>();
+  auto sType2 = type2.dyn_cast<ShapedType>();
+
+  // Either both or neither type should be shaped.
+  if (!sType1)
+    return success(!sType2);
+  if (!sType2)
+    return failure();
+
+  if (!sType1.hasRank() || !sType2.hasRank())
+    return success();
+
+  return success(sType1.getShape() == sType2.getShape());
+}
+
+LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return failure();
+
+  auto type = op->getOperand(0)->getType();
+  for (auto resultType : op->getResultTypes()) {
+    if (failed(verifyShapeMatch(resultType, type)))
+      return op->emitOpError()
+             << "requires the same shape for all operands and results";
+  }
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+    if (failed(verifyShapeMatch(opType, type)))
+      return op->emitOpError()
+             << "requires the same shape for all operands and results";
+  }
+  return success();
+}
+
+LogicalResult
+OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return failure();
+
+  auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
+  if (!type)
+    return op->emitOpError("requires shaped type results");
+  auto elementType = type.getElementType();
+
+  // Verify result element type matches first result's element type.
+  for (auto result : drop_begin(op->getResults(), 1)) {
+    auto resultType = result->getType().dyn_cast<ShapedType>();
+    if (!resultType)
+      return op->emitOpError("requires shaped type results");
+    if (resultType.getElementType() != elementType)
+      return op->emitOpError(
+          "requires the same element type for all operands and results");
+  }
+
+  // Verify operand's element type matches first result's element type.
+  for (auto operand : op->getOperands()) {
+    auto operandType = operand->getType().dyn_cast<ShapedType>();
+    if (!operandType)
+      return op->emitOpError("requires shaped type operands");
+    if (operandType.getElementType() != elementType)
+      return op->emitOpError(
+          "requires the same element type for all operands and results");
+  }
+
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return failure();
+
+  auto type = op->getResult(0)->getType();
+  for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+    if (resultType != type)
+      return op->emitOpError()
+             << "requires the same type for all operands and results";
+  }
+  for (auto opType : op->getOperandTypes()) {
+    if (opType != type)
+      return op->emitOpError()
+             << "requires the same type for all operands and results";
+  }
+  return success();
+}
+
+static LogicalResult verifyBBArguments(Operation::operand_range operands,
+                                       Block *destBB, Operation *op) {
+  unsigned operandCount = std::distance(operands.begin(), operands.end());
+  if (operandCount != destBB->getNumArguments())
+    return op->emitError() << "branch has " << operandCount
+                           << " operands, but target block has "
+                           << destBB->getNumArguments();
+
+  auto operandIt = operands.begin();
+  for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
+    if ((*operandIt)->getType() != destBB->getArgument(i)->getType())
+      return op->emitError() << "type mismatch in bb argument #" << i;
+  }
+
+  return success();
+}
+
+static LogicalResult verifyTerminatorSuccessors(Operation *op) {
+  auto *parent = op->getParentRegion();
+
+  // Verify that the operands lines up with the BB arguments in the successor.
+  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
+    auto *succ = op->getSuccessor(i);
+    if (succ->getParent() != parent)
+      return op->emitError("reference to block defined in another region");
+    if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
+  Block *block = op->getBlock();
+  // Verify that the operation is at the end of the respective parent block.
+  if (!block || &block->back() != op)
+    return op->emitOpError("must be the last operation in the parent block");
+
+  // Verify the state of the successor blocks.
+  if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op)))
+    return failure();
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
+  for (auto resultType : op->getResultTypes()) {
+    auto elementType = getTensorOrVectorElementType(resultType);
+    bool isBoolType = elementType.isInteger(1);
+    if (!isBoolType)
+      return op->emitOpError() << "requires a bool result type";
+  }
+
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
+  for (auto resultType : op->getResultTypes())
+    if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
+      return op->emitOpError() << "requires a floating point type";
+
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
+  for (auto resultType : op->getResultTypes())
+    if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
+      return op->emitOpError() << "requires an integer or index type";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BinaryOp implementation
+//===----------------------------------------------------------------------===//
+
+// These functions are out-of-line implementations of the methods in BinaryOp,
+// which avoids them being template instantiated/duplicated.
+
+void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
+                         Value *rhs) {
+  assert(lhs->getType() == rhs->getType());
+  result->addOperands({lhs, rhs});
+  result->types.push_back(lhs->getType());
+}
+
+ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  Type type;
+  return failure(parser->parseOperandList(ops, 2) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperands(ops, type, result->operands) ||
+                 parser->addTypeToList(type, result->types));
+}
+
+void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
+  assert(op->getNumOperands() == 2 && "binary op should have two operands");
+  assert(op->getNumResults() == 1 && "binary op should have one result");
+
+  // If not all the operand and result types are the same, just use the
+  // generic assembly form to avoid omitting information in printing.
+  auto resultType = op->getResult(0)->getType();
+  if (op->getOperand(0)->getType() != resultType ||
+      op->getOperand(1)->getType() != resultType) {
+    p->printGenericOp(op);
+    return;
+  }
+
+  *p << op->getName() << ' ' << *op->getOperand(0) << ", "
+     << *op->getOperand(1);
+  p->printOptionalAttrDict(op->getAttrs());
+  // Now we can output only one type for all operands and the result.
+  *p << " : " << op->getResult(0)->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// CastOp implementation
+//===----------------------------------------------------------------------===//
+
+void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
+                       Type destType) {
+  result->addOperands(source);
+  result->addTypes(destType);
+}
+
+ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType srcInfo;
+  Type srcType, dstType;
+  return failure(parser->parseOperand(srcInfo) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(srcType) ||
+                 parser->resolveOperand(srcInfo, srcType, result->operands) ||
+                 parser->parseKeywordType("to", dstType) ||
+                 parser->addTypeToList(dstType, result->types));
+}
+
+void impl::printCastOp(Operation *op, OpAsmPrinter *p) {
+  *p << op->getName() << ' ' << *op->getOperand(0);
+  p->printOptionalAttrDict(op->getAttrs());
+  *p << " : " << op->getOperand(0)->getType() << " to "
+     << op->getResult(0)->getType();
+}
+
+Value *impl::foldCastOp(Operation *op) {
+  // Identity cast
+  if (op->getOperand(0)->getType() == op->getResult(0)->getType())
+    return op->getOperand(0);
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// CastOp implementation
+//===----------------------------------------------------------------------===//
+
+/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
+/// region's only block if it does not have a terminator already. If the region
+/// is empty, insert a new block first. `buildTerminatorOp` should return the
+/// terminator operation to insert.
+void impl::ensureRegionTerminator(
+    Region &region, Location loc,
+    llvm::function_ref<Operation *()> buildTerminatorOp) {
+  if (region.empty())
+    region.push_back(new Block);
+
+  Block &block = region.back();
+  if (!block.empty() && block.back().isKnownTerminator())
+    return;
+
+  block.push_back(buildTerminatorOp());
+}
diff --git a/third_party/mlir/lib/IR/OperationSupport.cpp b/third_party/mlir/lib/IR/OperationSupport.cpp
new file mode 100644
index 0000000..fdc9c03
--- /dev/null
+++ b/third_party/mlir/lib/IR/OperationSupport.cpp
@@ -0,0 +1,137 @@
+//===- OperationSupport.cpp -----------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains out-of-line implementations of the support types that
+// Operation and related classes build on top of.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// OperationState
+//===----------------------------------------------------------------------===//
+
+OperationState::OperationState(Location location, StringRef name)
+    : context(location->getContext()), location(location),
+      name(name, location->getContext()) {}
+
+OperationState::OperationState(Location location, OperationName name)
+    : context(location->getContext()), location(location), name(name) {}
+
+OperationState::OperationState(Location location, StringRef name,
+                               ArrayRef<Value *> operands, ArrayRef<Type> types,
+                               ArrayRef<NamedAttribute> attributes,
+                               ArrayRef<Block *> successors,
+                               MutableArrayRef<std::unique_ptr<Region>> regions,
+                               bool resizableOperandList)
+    : context(location->getContext()), location(location),
+      name(name, location->getContext()),
+      operands(operands.begin(), operands.end()),
+      types(types.begin(), types.end()),
+      attributes(attributes.begin(), attributes.end()),
+      successors(successors.begin(), successors.end()) {
+  for (std::unique_ptr<Region> &r : regions) {
+    this->regions.push_back(std::move(r));
+  }
+}
+
+Region *OperationState::addRegion() {
+  regions.emplace_back(new Region);
+  return regions.back().get();
+}
+
+void OperationState::addRegion(std::unique_ptr<Region> &&region) {
+  regions.push_back(std::move(region));
+}
+
+//===----------------------------------------------------------------------===//
+// OperandStorage
+//===----------------------------------------------------------------------===//
+
+/// Replace the operands contained in the storage with the ones provided in
+/// 'operands'.
+void detail::OperandStorage::setOperands(Operation *owner,
+                                         ArrayRef<Value *> operands) {
+  // If the number of operands is less than or equal to the current amount, we
+  // can just update in place.
+  if (operands.size() <= numOperands) {
+    auto opOperands = getOperands();
+
+    // If the number of new operands is less than the current count, then remove
+    // any extra operands.
+    for (unsigned i = operands.size(); i != numOperands; ++i)
+      opOperands[i].~OpOperand();
+
+    // Set the operands in place.
+    numOperands = operands.size();
+    for (unsigned i = 0; i != numOperands; ++i)
+      opOperands[i].set(operands[i]);
+    return;
+  }
+
+  // Otherwise, we need to be resizable.
+  assert(resizable && "Only resizable operations may add operands");
+
+  // Grow the capacity if necessary.
+  auto &resizeUtil = getResizableStorage();
+  if (resizeUtil.capacity < operands.size())
+    grow(resizeUtil, operands.size());
+
+  // Set the operands.
+  OpOperand *opBegin = getRawOperands();
+  for (unsigned i = 0; i != numOperands; ++i)
+    opBegin[i].set(operands[i]);
+  for (unsigned e = operands.size(); numOperands != e; ++numOperands)
+    new (&opBegin[numOperands]) OpOperand(owner, operands[numOperands]);
+}
+
+/// Erase an operand held by the storage.
+void detail::OperandStorage::eraseOperand(unsigned index) {
+  assert(index < size());
+  auto operands = getOperands();
+  --numOperands;
+
+  // Shift all operands down by 1 if the operand to remove is not at the end.
+  auto indexIt = std::next(operands.begin(), index);
+  if (index != numOperands)
+    std::rotate(indexIt, std::next(indexIt), operands.end());
+  operands[numOperands].~OpOperand();
+}
+
+/// Grow the internal operand storage.
+void detail::OperandStorage::grow(ResizableStorage &resizeUtil,
+                                  size_t minSize) {
+  // Allocate a new storage array.
+  resizeUtil.capacity =
+      std::max(size_t(llvm::NextPowerOf2(resizeUtil.capacity + 2)), minSize);
+  OpOperand *newStorage = static_cast<OpOperand *>(
+      llvm::safe_malloc(resizeUtil.capacity * sizeof(OpOperand)));
+
+  // Move the current operands to the new storage.
+  auto operands = getOperands();
+  std::uninitialized_copy(std::make_move_iterator(operands.begin()),
+                          std::make_move_iterator(operands.end()), newStorage);
+
+  // Destroy the original operands and update the resizable storage pointer.
+  for (auto &operand : operands)
+    operand.~OpOperand();
+  resizeUtil.setDynamicStorage(newStorage);
+}
diff --git a/third_party/mlir/lib/IR/PatternMatch.cpp b/third_party/mlir/lib/IR/PatternMatch.cpp
new file mode 100644
index 0000000..b575abe
--- /dev/null
+++ b/third_party/mlir/lib/IR/PatternMatch.cpp
@@ -0,0 +1,177 @@
+//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+using namespace mlir;
+
+PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
+  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
+         "This pattern match benefit is too large to represent");
+}
+
+unsigned short PatternBenefit::getBenefit() const {
+  assert(representation != ImpossibleToMatchSentinel &&
+         "Pattern doesn't match");
+  return representation;
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern implementation
+//===----------------------------------------------------------------------===//
+
+Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
+                 MLIRContext *context)
+    : rootKind(OperationName(rootName, context)), benefit(benefit) {}
+
+// Out-of-line vtable anchor.
+void Pattern::anchor() {}
+
+//===----------------------------------------------------------------------===//
+// RewritePattern and PatternRewriter implementation
+//===----------------------------------------------------------------------===//
+
+void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+                             PatternRewriter &rewriter) const {
+  rewrite(op, rewriter);
+}
+
+void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+  llvm_unreachable("need to implement either matchAndRewrite or one of the "
+                   "rewrite functions!");
+}
+
+PatternMatchResult RewritePattern::match(Operation *op) const {
+  llvm_unreachable("need to implement either match or matchAndRewrite!");
+}
+
+/// Patterns must specify the root operation name they match against, and can
+/// also specify the benefit of the pattern matching. They can also specify the
+/// names of operations that may be generated during a successful rewrite.
+RewritePattern::RewritePattern(StringRef rootName,
+                               ArrayRef<StringRef> generatedNames,
+                               PatternBenefit benefit, MLIRContext *context)
+    : Pattern(rootName, benefit, context) {
+  generatedOps.reserve(generatedNames.size());
+  std::transform(generatedNames.begin(), generatedNames.end(),
+                 std::back_inserter(generatedOps), [context](StringRef name) {
+                   return OperationName(name, context);
+                 });
+}
+
+PatternRewriter::~PatternRewriter() {
+  // Out of line to provide a vtable anchor for the class.
+}
+
+/// This method performs the final replacement for a pattern, where the
+/// results of the operation are updated to use the specified list of SSA
+/// values.  In addition to replacing and removing the specified operation,
+/// clients can specify a list of other nodes that this replacement may make
+/// (perhaps transitively) dead.  If any of those ops are dead, this will
+/// remove them as well.
+void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                                ArrayRef<Value *> valuesToRemoveIfDead) {
+  // Notify the rewriter subclass that we're about to replace this root.
+  notifyRootReplaced(op);
+
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect # of replacement values");
+  op->replaceAllUsesWith(newValues);
+
+  notifyOperationRemoved(op);
+  op->erase();
+
+  // TODO: Process the valuesToRemoveIfDead list, removing things and calling
+  // the notifyOperationRemoved hook in the process.
+}
+
+/// op and newOp are known to have the same number of results, replace the
+/// uses of op with uses of newOp
+void PatternRewriter::replaceOpWithResultsOfAnotherOp(
+    Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
+  assert(op->getNumResults() == newOp->getNumResults() &&
+         "replacement op doesn't match results of original op");
+  if (op->getNumResults() == 1)
+    return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
+
+  SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
+                                     newOp->getResults().end());
+  return replaceOp(op, newResults, valuesToRemoveIfDead);
+}
+
+/// Move the blocks that belong to "region" before the given position in
+/// another region.  The two regions must be different.  The caller is in
+/// charge to update create the operation transferring the control flow to the
+/// region and pass it the correct block arguments.
+void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
+                                         Region::iterator before) {
+  parent.getBlocks().splice(before, region.getBlocks());
+}
+void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
+  inlineRegionBefore(region, *before->getParent(), before->getIterator());
+}
+
+/// This method is used as the final notification hook for patterns that end
+/// up modifying the pattern root in place, by changing its operands.  This is
+/// a minor efficiency win (it avoids creating a new operation and removing
+/// the old one) but also often allows simpler code in the client.
+///
+/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
+/// should remove if they are dead at this point.
+///
+void PatternRewriter::updatedRootInPlace(
+    Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
+  // Notify the rewriter subclass that we're about to replace this root.
+  notifyRootUpdated(op);
+
+  // TODO: Process the valuesToRemoveIfDead list, removing things and calling
+  // the notifyOperationRemoved hook in the process.
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatcher implementation
+//===----------------------------------------------------------------------===//
+
+RewritePatternMatcher::RewritePatternMatcher(
+    OwningRewritePatternList &patterns) {
+  for (auto &pattern : patterns)
+    this->patterns.push_back(pattern.get());
+
+  // Sort the patterns by benefit to simplify the matching logic.
+  std::stable_sort(this->patterns.begin(), this->patterns.end(),
+                   [](RewritePattern *l, RewritePattern *r) {
+                     return r->getBenefit() < l->getBenefit();
+                   });
+}
+
+/// Try to match the given operation to a pattern and rewrite it.
+bool RewritePatternMatcher::matchAndRewrite(Operation *op,
+                                            PatternRewriter &rewriter) {
+  for (auto *pattern : patterns) {
+    // Ignore patterns that are for the wrong root or are impossible to match.
+    if (pattern->getRootKind() != op->getName() ||
+        pattern->getBenefit().isImpossibleToMatch())
+      continue;
+
+    // Try to match and rewrite this pattern. The patterns are sorted by
+    // benefit, so if we match we can immediately rewrite and return.
+    if (pattern->matchAndRewrite(op, rewriter))
+      return true;
+  }
+  return false;
+}
diff --git a/third_party/mlir/lib/IR/Region.cpp b/third_party/mlir/lib/IR/Region.cpp
new file mode 100644
index 0000000..0947ddd
--- /dev/null
+++ b/third_party/mlir/lib/IR/Region.cpp
@@ -0,0 +1,212 @@
+//===- Region.cpp - MLIR Region Class -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Region.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Operation.h"
+using namespace mlir;
+
+Region::Region(Operation *container) : container(container) {}
+
+Region::~Region() {
+  // Operations may have cyclic references, which need to be dropped before we
+  // can start deleting them.
+  dropAllReferences();
+}
+
+/// Return the context this region is inserted in. The region must have a valid
+/// parent container.
+MLIRContext *Region::getContext() {
+  assert(container && "region is not attached to a container");
+  return container->getContext();
+}
+
+/// Return a location for this region. This is the location attached to the
+/// parent container. The region must have a valid parent container.
+Location Region::getLoc() {
+  assert(container && "region is not attached to a container");
+  return container->getLoc();
+}
+
+Region *Region::getParentRegion() {
+  assert(container && "region is not attached to a container");
+  return container->getParentRegion();
+}
+
+Operation *Region::getParentOp() { return container; }
+
+bool Region::isProperAncestor(Region *other) {
+  if (this == other)
+    return false;
+
+  while ((other = other->getParentRegion())) {
+    if (this == other)
+      return true;
+  }
+  return false;
+}
+
+/// Return the number of this region in the parent operation.
+unsigned Region::getRegionNumber() {
+  // Regions are always stored consecutively, so use pointer subtraction to
+  // figure out what number this is.
+  return this - &getParentOp()->getRegions()[0];
+}
+
+/// Clone the internal blocks from this region into `dest`. Any
+/// cloned blocks are appended to the back of dest.
+void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper) {
+  assert(dest && "expected valid region to clone into");
+  cloneInto(dest, dest->end(), mapper);
+}
+
+/// Clone this region into 'dest' before the given position in 'dest'.
+void Region::cloneInto(Region *dest, Region::iterator destPos,
+                       BlockAndValueMapping &mapper) {
+  assert(dest && "expected valid region to clone into");
+
+  // If the list is empty there is nothing to clone.
+  if (empty())
+    return;
+
+  for (Block &block : *this) {
+    Block *newBlock = new Block();
+    mapper.map(&block, newBlock);
+
+    // Clone the block arguments. The user might be deleting arguments to the
+    // block by specifying them in the mapper. If so, we don't add the
+    // argument to the cloned block.
+    for (auto *arg : block.getArguments())
+      if (!mapper.contains(arg))
+        mapper.map(arg, newBlock->addArgument(arg->getType()));
+
+    // Clone and remap the operations within this block.
+    for (auto &op : block)
+      newBlock->push_back(op.clone(mapper));
+
+    dest->getBlocks().insert(destPos, newBlock);
+  }
+
+  // Now that each of the blocks have been cloned, go through and remap the
+  // operands of each of the operations.
+  auto remapOperands = [&](Operation *op) {
+    for (auto &operand : op->getOpOperands())
+      if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
+        operand.set(mappedOp);
+    for (auto &succOp : op->getBlockOperands())
+      if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
+        succOp.set(mappedOp);
+  };
+
+  for (iterator it(mapper.lookup(&front())); it != destPos; ++it)
+    it->walk(remapOperands);
+}
+
+void Region::dropAllReferences() {
+  for (Block &b : *this)
+    b.dropAllReferences();
+}
+
+/// Check if there are any values used by operations in `region` defined
+/// outside its ancestor region `limit`.  That is, given `A{B{C{}}}` with region
+/// `C` and limit `B`, the values defined in `B` can be used but the values
+/// defined in `A` cannot.  Emit errors if `noteLoc` is provided; this location
+/// is used to point to the operation containing the region, the actual error is
+/// reported at the operation with an offending use.
+static bool isIsolatedAbove(Region &region, Region &limit,
+                            llvm::Optional<Location> noteLoc) {
+  assert(limit.isAncestor(&region) &&
+         "expected isolation limit to be an ancestor of the given region");
+
+  // List of regions to analyze.  Each region is processed independently, with
+  // respect to the common `limit` region, so we can look at them in any order.
+  // Therefore, use a simple vector and push/pop back the current region.
+  SmallVector<Region *, 8> pendingRegions;
+  pendingRegions.push_back(&region);
+
+  // Traverse all operations in the region.
+  while (!pendingRegions.empty()) {
+    for (Block &block : *pendingRegions.pop_back_val()) {
+      for (Operation &op : block) {
+        for (Value *operand : op.getOperands()) {
+          // Check that any value that is used by an operation is defined in the
+          // same region as either an operation result or a block argument.
+          if (operand->getParentRegion()->isProperAncestor(&limit)) {
+            if (noteLoc) {
+              op.emitOpError("using value defined outside the region")
+                      .attachNote(noteLoc)
+                  << "required by region isolation constraints";
+            }
+            return false;
+          }
+        }
+        // Schedule any regions the operations contain for further checking.
+        pendingRegions.reserve(pendingRegions.size() + op.getNumRegions());
+        for (Region &subRegion : op.getRegions())
+          pendingRegions.push_back(&subRegion);
+      }
+    }
+  }
+  return true;
+}
+
+bool Region::isIsolatedFromAbove(llvm::Optional<Location> noteLoc) {
+  return isIsolatedAbove(*this, *this, noteLoc);
+}
+
+/// Walk the operations in this block in postorder, calling the callback for
+/// each operation.
+void Region::walk(llvm::function_ref<void(Operation *)> callback) {
+  for (auto &block : *this)
+    block.walk(callback);
+}
+
+Region *llvm::ilist_traits<::mlir::Block>::getParentRegion() {
+  size_t Offset(
+      size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr))));
+  iplist<Block> *Anchor(static_cast<iplist<Block> *>(this));
+  return reinterpret_cast<Region *>(reinterpret_cast<char *>(Anchor) - Offset);
+}
+
+/// This is a trait method invoked when a basic block is added to a region.
+/// We keep the region pointer up to date.
+void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) {
+  assert(!block->getParent() && "already in a region!");
+  block->parentValidInstOrderPair.setPointer(getParentRegion());
+}
+
+/// This is a trait method invoked when an operation is removed from a
+/// region.  We keep the region pointer up to date.
+void llvm::ilist_traits<::mlir::Block>::removeNodeFromList(Block *block) {
+  assert(block->getParent() && "not already in a region!");
+  block->parentValidInstOrderPair.setPointer(nullptr);
+}
+
+/// This is a trait method invoked when an operation is moved from one block
+/// to another.  We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
+    ilist_traits<Block> &otherList, block_iterator first, block_iterator last) {
+  // If we are transferring operations within the same function, the parent
+  // pointer doesn't need to be updated.
+  auto *curParent = getParentRegion();
+  if (curParent == otherList.getParentRegion())
+    return;
+
+  // Update the 'parent' member of each Block.
+  for (; first != last; ++first)
+    first->parentValidInstOrderPair.setPointer(curParent);
+}
diff --git a/third_party/mlir/lib/IR/StandardTypes.cpp b/third_party/mlir/lib/IR/StandardTypes.cpp
new file mode 100644
index 0000000..6077e4d
--- /dev/null
+++ b/third_party/mlir/lib/IR/StandardTypes.cpp
@@ -0,0 +1,423 @@
+//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/StandardTypes.h"
+#include "TypeDetail.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// Integer Type
+//===----------------------------------------------------------------------===//
+
+// static constexpr must have a definition (until in C++17 and inline variable).
+constexpr unsigned IntegerType::kMaxWidth;
+
+/// Verify the construction of an integer type.
+LogicalResult IntegerType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
+  if (width > IntegerType::kMaxWidth) {
+    if (loc)
+      emitError(*loc) << "integer bitwidth is limited to "
+                      << IntegerType::kMaxWidth << " bits";
+    return failure();
+  }
+  return success();
+}
+
+unsigned IntegerType::getWidth() const { return getImpl()->width; }
+
+//===----------------------------------------------------------------------===//
+// Float Type
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+  switch (getKind()) {
+  case StandardTypes::BF16:
+  case StandardTypes::F16:
+    return 16;
+  case StandardTypes::F32:
+    return 32;
+  case StandardTypes::F64:
+    return 64;
+  default:
+    llvm_unreachable("unexpected type");
+  }
+}
+
+/// Returns the floating semantics for the given type.
+const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (isBF16())
+    // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
+    // not defined in LLVM.
+    // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
+    // else one could add it.
+    //  static const fltSemantics semBF16 = {127, -126, 8, 16};
+    return APFloat::IEEEdouble();
+  if (isF16())
+    return APFloat::IEEEhalf();
+  if (isF32())
+    return APFloat::IEEEsingle();
+  if (isF64())
+    return APFloat::IEEEdouble();
+  llvm_unreachable("non-floating point type used");
+}
+
+unsigned Type::getIntOrFloatBitWidth() {
+  assert(isIntOrFloat() && "only ints and floats have a bitwidth");
+  if (auto intType = dyn_cast<IntegerType>()) {
+    return intType.getWidth();
+  }
+
+  auto floatType = cast<FloatType>();
+  return floatType.getWidth();
+}
+
+//===----------------------------------------------------------------------===//
+// ShapedType
+//===----------------------------------------------------------------------===//
+
+Type ShapedType::getElementType() const {
+  return static_cast<ImplType *>(impl)->elementType;
+}
+
+unsigned ShapedType::getElementTypeBitWidth() const {
+  return getElementType().getIntOrFloatBitWidth();
+}
+
+int64_t ShapedType::getNumElements() const {
+  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+  auto shape = getShape();
+  int64_t num = 1;
+  for (auto dim : shape)
+    num *= dim;
+  return num;
+}
+
+int64_t ShapedType::getRank() const { return getShape().size(); }
+
+bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
+
+int64_t ShapedType::getDimSize(int64_t i) const {
+  assert(i >= 0 && i < getRank() && "invalid index for shaped type");
+  return getShape()[i];
+}
+
+/// Get the number of bits require to store a value of the given shaped type.
+/// Compute the value recursively since tensors are allowed to have vectors as
+/// elements.
+int64_t ShapedType::getSizeInBits() const {
+  assert(hasStaticShape() &&
+         "cannot get the bit size of an aggregate with a dynamic shape");
+
+  auto elementType = getElementType();
+  if (elementType.isIntOrFloat())
+    return elementType.getIntOrFloatBitWidth() * getNumElements();
+
+  // Tensors can have vectors and other tensors as elements, other shaped types
+  // cannot.
+  assert(isa<TensorType>() && "unsupported element type");
+  assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
+         "unsupported tensor element type");
+  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
+}
+
+ArrayRef<int64_t> ShapedType::getShape() const {
+  switch (getKind()) {
+  case StandardTypes::Vector:
+    return cast<VectorType>().getShape();
+  case StandardTypes::RankedTensor:
+    return cast<RankedTensorType>().getShape();
+  case StandardTypes::MemRef:
+    return cast<MemRefType>().getShape();
+  default:
+    llvm_unreachable("not a ShapedType or not ranked");
+  }
+}
+
+int64_t ShapedType::getNumDynamicDims() const {
+  return llvm::count_if(getShape(), isDynamic);
+}
+
+bool ShapedType::hasStaticShape() const {
+  return hasRank() && llvm::none_of(getShape(), isDynamic);
+}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
+                   elementType);
+}
+
+VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
+                                  Location location) {
+  return Base::getChecked(location, elementType.getContext(),
+                          StandardTypes::Vector, shape, elementType);
+}
+
+LogicalResult VectorType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
+    Type elementType) {
+  if (shape.empty()) {
+    if (loc)
+      emitError(*loc, "vector types must have at least one dimension");
+    return failure();
+  }
+
+  if (!isValidElementType(elementType)) {
+    if (loc)
+      emitError(*loc, "vector elements must be int or float type");
+    return failure();
+  }
+
+  if (any_of(shape, [](int64_t i) { return i <= 0; })) {
+    if (loc)
+      emitError(*loc, "vector types must have positive constant sizes");
+    return failure();
+  }
+  return success();
+}
+
+ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
+
+//===----------------------------------------------------------------------===//
+// TensorType
+//===----------------------------------------------------------------------===//
+
+// Check if "elementType" can be an element type of a tensor. Emit errors if
+// location is not nullptr.  Returns failure if check failed.
+static inline LogicalResult checkTensorElementType(Optional<Location> location,
+                                                   MLIRContext *context,
+                                                   Type elementType) {
+  if (!TensorType::isValidElementType(elementType)) {
+    if (location)
+      emitError(*location, "invalid tensor element type");
+    return failure();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
+                                       Type elementType) {
+  return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
+                   elementType);
+}
+
+RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
+                                              Type elementType,
+                                              Location location) {
+  return Base::getChecked(location, elementType.getContext(),
+                          StandardTypes::RankedTensor, shape, elementType);
+}
+
+LogicalResult RankedTensorType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
+    Type elementType) {
+  for (int64_t s : shape) {
+    if (s < -1) {
+      if (loc)
+        emitError(*loc, "invalid tensor dimension size");
+      return failure();
+    }
+  }
+  return checkTensorElementType(loc, context, elementType);
+}
+
+ArrayRef<int64_t> RankedTensorType::getShape() const {
+  return getImpl()->getShape();
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+//===----------------------------------------------------------------------===//
+
+UnrankedTensorType UnrankedTensorType::get(Type elementType) {
+  return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
+                   elementType);
+}
+
+UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
+                                                  Location location) {
+  return Base::getChecked(location, elementType.getContext(),
+                          StandardTypes::UnrankedTensor, elementType);
+}
+
+LogicalResult UnrankedTensorType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
+  return checkTensorElementType(loc, context, elementType);
+}
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+//===----------------------------------------------------------------------===//
+
+/// Get or create a new MemRefType based on shape, element type, affine
+/// map composition, and memory space.  Assumes the arguments define a
+/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
+/// construction failures.
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+                           ArrayRef<AffineMap> affineMapComposition,
+                           unsigned memorySpace) {
+  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
+                        /*location=*/llvm::None);
+  assert(result && "Failed to construct instance of MemRefType.");
+  return result;
+}
+
+/// Get or create a new MemRefType based on shape, element type, affine
+/// map composition, and memory space declared at the given location.
+/// If the location is unknown, the last argument should be an instance of
+/// UnknownLoc.  If the MemRefType defined by the arguments would be
+/// ill-formed, emits errors (to the handler registered with the context or to
+/// the error stream) and returns nullptr.
+MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
+                                  ArrayRef<AffineMap> affineMapComposition,
+                                  unsigned memorySpace, Location location) {
+  return getImpl(shape, elementType, affineMapComposition, memorySpace,
+                 location);
+}
+
+/// Get or create a new MemRefType defined by the arguments.  If the resulting
+/// type would be ill-formed, return nullptr.  If the location is provided,
+/// emit detailed error messages.  To emit errors when the location is unknown,
+/// pass in an instance of UnknownLoc.
+MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
+                               ArrayRef<AffineMap> affineMapComposition,
+                               unsigned memorySpace,
+                               Optional<Location> location) {
+  auto *context = elementType.getContext();
+
+  for (int64_t s : shape) {
+    // Negative sizes are not allowed except for `-1` that means dynamic size.
+    if (s < -1) {
+      if (location)
+        emitError(*location, "invalid memref size");
+      return {};
+    }
+  }
+
+  // Check that the structure of the composition is valid, i.e. that each
+  // subsequent affine map has as many inputs as the previous map has results.
+  // Take the dimensionality of the MemRef for the first map.
+  auto dim = shape.size();
+  unsigned i = 0;
+  for (const auto &affineMap : affineMapComposition) {
+    if (affineMap.getNumDims() != dim) {
+      if (location)
+        emitError(*location)
+            << "memref affine map dimension mismatch between "
+            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
+            << " and affine map" << i + 1 << ": " << dim
+            << " != " << affineMap.getNumDims();
+      return nullptr;
+    }
+
+    dim = affineMap.getNumResults();
+    ++i;
+  }
+
+  // Drop identity maps from the composition.
+  // This may lead to the composition becoming empty, which is interpreted as an
+  // implicit identity.
+  llvm::SmallVector<AffineMap, 2> cleanedAffineMapComposition;
+  for (const auto &map : affineMapComposition) {
+    if (map.isIdentity())
+      continue;
+    cleanedAffineMapComposition.push_back(map);
+  }
+
+  return Base::get(context, StandardTypes::MemRef, shape, elementType,
+                   cleanedAffineMapComposition, memorySpace);
+}
+
+ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
+
+ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
+  return getImpl()->getAffineMaps();
+}
+
+unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
+
+//===----------------------------------------------------------------------===//
+/// ComplexType
+//===----------------------------------------------------------------------===//
+
+ComplexType ComplexType::get(Type elementType) {
+  return Base::get(elementType.getContext(), StandardTypes::Complex,
+                   elementType);
+}
+
+ComplexType ComplexType::getChecked(Type elementType, Location location) {
+  return Base::getChecked(location, elementType.getContext(),
+                          StandardTypes::Complex, elementType);
+}
+
+/// Verify the construction of an integer type.
+LogicalResult ComplexType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
+  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
+    if (loc)
+      emitError(*loc, "invalid element type for complex");
+    return failure();
+  }
+  return success();
+}
+
+Type ComplexType::getElementType() { return getImpl()->elementType; }
+
+//===----------------------------------------------------------------------===//
+/// TupleType
+//===----------------------------------------------------------------------===//
+
+/// Get or create a new TupleType with the provided element types. Assumes the
+/// arguments define a well-formed type.
+TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
+  return Base::get(context, StandardTypes::Tuple, elementTypes);
+}
+
+/// Return the elements types for this tuple.
+ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
+
+/// Accumulate the types contained in this tuple and tuples nested within it.
+/// Note that this only flattens nested tuples, not any other container type,
+/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
+/// (i32, tensor<i32>, f32, i64)
+void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
+  for (Type type : getTypes()) {
+    if (auto nestedTuple = type.dyn_cast<TupleType>())
+      nestedTuple.getFlattenedTypes(types);
+    else
+      types.push_back(type);
+  }
+}
+
+/// Return the number of element types.
+size_t TupleType::size() const { return getImpl()->size(); }
diff --git a/third_party/mlir/lib/IR/SymbolTable.cpp b/third_party/mlir/lib/IR/SymbolTable.cpp
new file mode 100644
index 0000000..62dd6b0
--- /dev/null
+++ b/third_party/mlir/lib/IR/SymbolTable.cpp
@@ -0,0 +1,114 @@
+//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallString.h"
+
+using namespace mlir;
+
+/// Build a symbol table with the symbols within the given operation.
+SymbolTable::SymbolTable(Operation *op) : context(op->getContext()) {
+  assert(op->hasTrait<OpTrait::SymbolTable>() &&
+         "expected operation to have SymbolTable trait");
+  assert(op->getNumRegions() == 1 &&
+         "expected operation to have a single region");
+
+  for (auto &block : op->getRegion(0)) {
+    for (auto &op : block) {
+      auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
+      if (!nameAttr)
+        continue;
+
+      auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
+      (void)inserted;
+      assert(inserted.second &&
+             "expected region to contain uniquely named symbol operations");
+    }
+  }
+}
+
+/// Look up a symbol with the specified name, returning null if no such name
+/// exists. Names never include the @ on them.
+Operation *SymbolTable::lookup(StringRef name) const {
+  return symbolTable.lookup(name);
+}
+
+/// Erase the given symbol from the table.
+void SymbolTable::erase(Operation *symbol) {
+  auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
+  assert(nameAttr && "expected valid 'name' attribute");
+
+  auto it = symbolTable.find(nameAttr.getValue());
+  if (it != symbolTable.end() && it->second == symbol)
+    symbolTable.erase(it);
+}
+
+/// Insert a new symbol into the table, and rename it as necessary to avoid
+/// collisions.
+void SymbolTable::insert(Operation *symbol) {
+  auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
+  assert(nameAttr && "expected valid 'name' attribute");
+
+  // Add this symbol to the symbol table, uniquing the name if a conflict is
+  // detected.
+  if (symbolTable.insert({nameAttr.getValue(), symbol}).second)
+    return;
+
+  // If a conflict was detected, then the symbol will not have been added to
+  // the symbol table. Try suffixes until we get to a unique name that works.
+  SmallString<128> nameBuffer(nameAttr.getValue());
+  unsigned originalLength = nameBuffer.size();
+
+  // Iteratively try suffixes until we find one that isn't used.
+  do {
+    nameBuffer.resize(originalLength);
+    nameBuffer += '_';
+    nameBuffer += std::to_string(uniquingCounter++);
+  } while (!symbolTable.insert({nameBuffer, symbol}).second);
+  symbol->setAttr(getSymbolAttrName(), StringAttr::get(nameBuffer, context));
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolTable Trait Types
+//===----------------------------------------------------------------------===//
+
+LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return op->emitOpError()
+           << "Operations with a 'SymbolTable' must have exactly one region";
+
+  // Check that all symboles are uniquely named within child regions.
+  llvm::StringMap<Location> nameToOrigLoc;
+  for (auto &block : op->getRegion(0)) {
+    for (auto &op : block) {
+      // Check for a symbol name attribute.
+      auto nameAttr =
+          op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
+      if (!nameAttr)
+        continue;
+
+      // Try to insert this symbol into the table.
+      auto it = nameToOrigLoc.try_emplace(nameAttr.getValue(), op.getLoc());
+      if (!it.second)
+        return op.emitError()
+            .append("redefinition of symbol named '", nameAttr.getValue(), "'")
+            .attachNote(it.first->second)
+            .append("see existing symbol definition here");
+    }
+  }
+  return success();
+}
diff --git a/third_party/mlir/lib/IR/TypeDetail.h b/third_party/mlir/lib/IR/TypeDetail.h
new file mode 100644
index 0000000..0e7edf0
--- /dev/null
+++ b/third_party/mlir/lib/IR/TypeDetail.h
@@ -0,0 +1,308 @@
+//===- TypeDetail.h - MLIR Type storage details -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Type.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TYPEDETAIL_H_
+#define TYPEDETAIL_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/TrailingObjects.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace detail {
+
+/// Opaque Type Storage and Uniquing.
+struct OpaqueTypeStorage : public TypeStorage {
+  OpaqueTypeStorage(Identifier dialectNamespace, StringRef typeData)
+      : dialectNamespace(dialectNamespace), typeData(typeData) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<Identifier, StringRef>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(dialectNamespace, typeData);
+  }
+
+  static OpaqueTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    StringRef tyData = allocator.copyInto(key.second);
+    return new (allocator.allocate<OpaqueTypeStorage>())
+        OpaqueTypeStorage(key.first, tyData);
+  }
+
+  // The dialect namespace.
+  Identifier dialectNamespace;
+
+  // The parser type data for this opaque type.
+  StringRef typeData;
+};
+
+/// Integer Type Storage and Uniquing.
+struct IntegerTypeStorage : public TypeStorage {
+  IntegerTypeStorage(unsigned width) : width(width) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = unsigned;
+  bool operator==(const KeyTy &key) const { return key == width; }
+
+  static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
+                                       KeyTy bitwidth) {
+    return new (allocator.allocate<IntegerTypeStorage>())
+        IntegerTypeStorage(bitwidth);
+  }
+
+  unsigned width;
+};
+
+/// Function Type Storage and Uniquing.
+struct FunctionTypeStorage : public TypeStorage {
+  FunctionTypeStorage(unsigned numInputs, unsigned numResults,
+                      Type const *inputsAndResults)
+      : TypeStorage(numInputs), numResults(numResults),
+        inputsAndResults(inputsAndResults) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getInputs(), getResults());
+  }
+
+  /// Construction.
+  static FunctionTypeStorage *construct(TypeStorageAllocator &allocator,
+                                        const KeyTy &key) {
+    ArrayRef<Type> inputs = key.first, results = key.second;
+
+    // Copy the inputs and results into the bump pointer.
+    SmallVector<Type, 16> types;
+    types.reserve(inputs.size() + results.size());
+    types.append(inputs.begin(), inputs.end());
+    types.append(results.begin(), results.end());
+    auto typesList = allocator.copyInto(ArrayRef<Type>(types));
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<FunctionTypeStorage>())
+        FunctionTypeStorage(inputs.size(), results.size(), typesList.data());
+  }
+
+  ArrayRef<Type> getInputs() const {
+    return ArrayRef<Type>(inputsAndResults, getSubclassData());
+  }
+  ArrayRef<Type> getResults() const {
+    return ArrayRef<Type>(inputsAndResults + getSubclassData(), numResults);
+  }
+
+  unsigned numResults;
+  Type const *inputsAndResults;
+};
+
+/// VectorOrTensor Type Storage.
+struct ShapedTypeStorage : public TypeStorage {
+  ShapedTypeStorage(Type elementType, unsigned subclassData = 0)
+      : TypeStorage(subclassData), elementType(elementType) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = Type;
+  bool operator==(const KeyTy &key) const { return key == elementType; }
+
+  Type elementType;
+};
+
+/// Vector Type Storage and Uniquing.
+struct VectorTypeStorage : public ShapedTypeStorage {
+  VectorTypeStorage(unsigned shapeSize, Type elementTy,
+                    const int64_t *shapeElements)
+      : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType);
+  }
+
+  /// Construction.
+  static VectorTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    // Copy the shape into the bump pointer.
+    ArrayRef<int64_t> shape = allocator.copyInto(key.first);
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<VectorTypeStorage>())
+        VectorTypeStorage(shape.size(), key.second, shape.data());
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+  }
+
+  const int64_t *shapeElements;
+};
+
+struct RankedTensorTypeStorage : public ShapedTypeStorage {
+  RankedTensorTypeStorage(unsigned shapeSize, Type elementTy,
+                          const int64_t *shapeElements)
+      : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType);
+  }
+
+  /// Construction.
+  static RankedTensorTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    // Copy the shape into the bump pointer.
+    ArrayRef<int64_t> shape = allocator.copyInto(key.first);
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<RankedTensorTypeStorage>())
+        RankedTensorTypeStorage(shape.size(), key.second, shape.data());
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+  }
+
+  const int64_t *shapeElements;
+};
+
+struct UnrankedTensorTypeStorage : public ShapedTypeStorage {
+  using ShapedTypeStorage::KeyTy;
+  using ShapedTypeStorage::ShapedTypeStorage;
+
+  /// Construction.
+  static UnrankedTensorTypeStorage *construct(TypeStorageAllocator &allocator,
+                                              Type elementTy) {
+    return new (allocator.allocate<UnrankedTensorTypeStorage>())
+        UnrankedTensorTypeStorage(elementTy);
+  }
+};
+
+struct MemRefTypeStorage : public ShapedTypeStorage {
+  MemRefTypeStorage(unsigned shapeSize, Type elementType,
+                    const int64_t *shapeElements, const unsigned numAffineMaps,
+                    AffineMap const *affineMapList, const unsigned memorySpace)
+      : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements),
+        numAffineMaps(numAffineMaps), affineMapList(affineMapList),
+        memorySpace(memorySpace) {}
+
+  /// The hash key used for uniquing.
+  // MemRefs are uniqued based on their shape, element type, affine map
+  // composition, and memory space.
+  using KeyTy =
+      std::tuple<ArrayRef<int64_t>, Type, ArrayRef<AffineMap>, unsigned>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace);
+  }
+
+  /// Construction.
+  static MemRefTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    // Copy the shape into the bump pointer.
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+
+    // Copy the affine map composition into the bump pointer.
+    ArrayRef<AffineMap> affineMapComposition =
+        allocator.copyInto(std::get<2>(key));
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<MemRefTypeStorage>())
+        MemRefTypeStorage(shape.size(), std::get<1>(key), shape.data(),
+                          affineMapComposition.size(),
+                          affineMapComposition.data(), std::get<3>(key));
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+  }
+
+  ArrayRef<AffineMap> getAffineMaps() const {
+    return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+  }
+
+  /// An array of integers which stores the shape dimension sizes.
+  const int64_t *shapeElements;
+  /// The number of affine maps in the 'affineMapList' array.
+  const unsigned numAffineMaps;
+  /// List of affine maps in the memref's layout/index map composition.
+  AffineMap const *affineMapList;
+  /// Memory space in which data referenced by memref resides.
+  const unsigned memorySpace;
+};
+
+/// Complex Type Storage.
+struct ComplexTypeStorage : public TypeStorage {
+  ComplexTypeStorage(Type elementType) : elementType(elementType) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = Type;
+  bool operator==(const KeyTy &key) const { return key == elementType; }
+
+  /// Construction.
+  static ComplexTypeStorage *construct(TypeStorageAllocator &allocator,
+                                       Type elementType) {
+    return new (allocator.allocate<ComplexTypeStorage>())
+        ComplexTypeStorage(elementType);
+  }
+
+  Type elementType;
+};
+
+/// A type representing a collection of other types.
+struct TupleTypeStorage final
+    : public TypeStorage,
+      public llvm::TrailingObjects<TupleTypeStorage, Type> {
+  using KeyTy = ArrayRef<Type>;
+
+  TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
+
+  /// Construction.
+  static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
+                                     ArrayRef<Type> key) {
+    // Allocate a new storage instance.
+    auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
+    auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));
+    auto result = ::new (rawMem) TupleTypeStorage(key.size());
+
+    // Copy in the element types into the trailing storage.
+    std::uninitialized_copy(key.begin(), key.end(),
+                            result->getTrailingObjects<Type>());
+    return result;
+  }
+
+  bool operator==(const KeyTy &key) const { return key == getTypes(); }
+
+  /// Return the number of held types.
+  unsigned size() const { return getSubclassData(); }
+
+  /// Return the held types.
+  ArrayRef<Type> getTypes() const {
+    return {getTrailingObjects<Type>(), size()};
+  }
+};
+
+} // namespace detail
+} // namespace mlir
+#endif // TYPEDETAIL_H_
diff --git a/third_party/mlir/lib/IR/TypeUtilities.cpp b/third_party/mlir/lib/IR/TypeUtilities.cpp
new file mode 100644
index 0000000..95895af
--- /dev/null
+++ b/third_party/mlir/lib/IR/TypeUtilities.cpp
@@ -0,0 +1,76 @@
+//===- TypeUtilities.cpp - Helper function for type queries ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines generic type utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+Type mlir::getElementTypeOrSelf(Type type) {
+  if (auto st = type.dyn_cast<ShapedType>())
+    return st.getElementType();
+  return type;
+}
+
+Type mlir::getElementTypeOrSelf(Value *val) {
+  return getElementTypeOrSelf(val->getType());
+}
+
+Type mlir::getElementTypeOrSelf(Value &val) {
+  return getElementTypeOrSelf(val.getType());
+}
+
+Type mlir::getElementTypeOrSelf(Attribute attr) {
+  return getElementTypeOrSelf(attr.getType());
+}
+
+SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
+  SmallVector<Type, 10> fTypes;
+  t.getFlattenedTypes(fTypes);
+  return fTypes;
+}
+
+/// Return true if the specified type is an opaque type with the specified
+/// dialect and typeData.
+bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
+                                StringRef typeData) {
+  if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
+    return opaque.getDialectNamespace().is(dialect) &&
+           opaque.getTypeData() == typeData;
+  return false;
+}
+
+OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
+    : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type OperandElementTypeIterator::unwrap(Value *value) {
+  return value->getType().cast<ShapedType>().getElementType();
+}
+
+ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
+    : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type ResultElementTypeIterator::unwrap(Value *value) {
+  return value->getType().cast<ShapedType>().getElementType();
+}
diff --git a/third_party/mlir/lib/IR/Types.cpp b/third_party/mlir/lib/IR/Types.cpp
new file mode 100644
index 0000000..cd75176
--- /dev/null
+++ b/third_party/mlir/lib/IR/Types.cpp
@@ -0,0 +1,84 @@
+//===- Types.cpp - MLIR Type Classes --------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Types.h"
+#include "TypeDetail.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "llvm/ADT/Twine.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+unsigned Type::getKind() const { return impl->getKind(); }
+
+/// Get the dialect this type is registered to.
+Dialect &Type::getDialect() const { return impl->getDialect(); }
+
+MLIRContext *Type::getContext() const { return getDialect().getContext(); }
+
+unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
+
+/// Function Type.
+
+FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+                               MLIRContext *context) {
+  return Base::get(context, Type::Kind::Function, inputs, results);
+}
+
+ArrayRef<Type> FunctionType::getInputs() const {
+  return getImpl()->getInputs();
+}
+
+unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
+
+ArrayRef<Type> FunctionType::getResults() const {
+  return getImpl()->getResults();
+}
+
+/// OpaqueType
+
+OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
+                           MLIRContext *context) {
+  return Base::get(context, Type::Kind::Opaque, dialect, typeData);
+}
+
+OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
+                                  MLIRContext *context, Location location) {
+  return Base::getChecked(location, context, Kind::Opaque, dialect, typeData);
+}
+
+/// Returns the dialect namespace of the opaque type.
+Identifier OpaqueType::getDialectNamespace() const {
+  return getImpl()->dialectNamespace;
+}
+
+/// Returns the raw type data of the opaque type.
+StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
+
+/// Verify the construction of an opaque type.
+LogicalResult OpaqueType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
+    StringRef typeData) {
+  if (!Dialect::isValidNamespace(dialect.strref())) {
+    if (loc)
+      emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
+    return failure();
+  }
+  return success();
+}
diff --git a/third_party/mlir/lib/IR/Value.cpp b/third_party/mlir/lib/IR/Value.cpp
new file mode 100644
index 0000000..4ad1460
--- /dev/null
+++ b/third_party/mlir/lib/IR/Value.cpp
@@ -0,0 +1,67 @@
+//===- Value.cpp - MLIR Value Classes -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+using namespace mlir;
+
+/// If this value is the result of an Operation, return the operation that
+/// defines it.
+Operation *Value::getDefiningOp() {
+  if (auto *result = dyn_cast<OpResult>(this))
+    return result->getOwner();
+  return nullptr;
+}
+
+Location Value::getLoc() {
+  if (auto *op = getDefiningOp())
+    return op->getLoc();
+  return UnknownLoc::get(getContext());
+}
+
+/// Return the Region in which this Value is defined.
+Region *Value::getParentRegion() {
+  switch (getKind()) {
+  case Value::Kind::BlockArgument:
+    return cast<BlockArgument>(this)->getOwner()->getParent();
+  case Value::Kind::OpResult:
+    return getDefiningOp()->getParentRegion();
+  }
+  llvm_unreachable("Unknown Value Kind");
+}
+
+//===----------------------------------------------------------------------===//
+// IRObjectWithUseList implementation.
+//===----------------------------------------------------------------------===//
+
+/// Replace all uses of 'this' value with the new value, updating anything in
+/// the IR that uses 'this' to use the other value instead.  When this returns
+/// there are zero uses of 'this'.
+void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
+  assert(this != newValue && "cannot RAUW a value with itself");
+  while (!use_empty()) {
+    use_begin()->set(newValue);
+  }
+}
+
+/// Drop all uses of this object from their respective owners.
+void IRObjectWithUseList::dropAllUses() {
+  while (!use_empty()) {
+    use_begin()->drop();
+  }
+}
diff --git a/third_party/mlir/lib/LLVMIR/CMakeLists.txt b/third_party/mlir/lib/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000..5e21850
--- /dev/null
+++ b/third_party/mlir/lib/LLVMIR/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_llvm_library(MLIRLLVMIR
+  IR/LLVMDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/LLVMIR
+  )
+add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport)
+target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport)
+
+add_llvm_library(MLIRNVVMIR
+  IR/NVVMDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/LLVMIR
+  )
+add_dependencies(MLIRNVVMIR MLIRNVVMOpsIncGen MLIRNVVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport)
+target_link_libraries(MLIRNVVMIR LLVMAsmParser LLVMCore LLVMSupport)
diff --git a/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
new file mode 100644
index 0000000..378907e
--- /dev/null
+++ b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
@@ -0,0 +1,1348 @@
+//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types and operation details for the LLVM IR dialect in
+// MLIR, and the LLVM IR dialect.  It also registers the dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+#include "mlir/LLVMIR/LLVMOpsEnums.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CmpOp.
+//===----------------------------------------------------------------------===//
+static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
+  *p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
+     << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+  p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
+  *p << " : " << op.lhs()->getType();
+}
+
+static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) {
+  *p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
+     << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+  p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
+  *p << " : " << op.lhs()->getType();
+}
+
+// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
+//                 attribute-dict? `:` type
+// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
+//                 attribute-dict? `:` type
+template <typename CmpPredicateType>
+static ParseResult parseCmpOp(OpAsmParser *parser, OperationState *result) {
+  Builder &builder = parser->getBuilder();
+
+  Attribute predicate;
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType lhs, rhs;
+  Type type;
+  llvm::SMLoc predicateLoc, trailingTypeLoc;
+  if (parser->getCurrentLocation(&predicateLoc) ||
+      parser->parseAttribute(predicate, "predicate", attrs) ||
+      parser->parseOperand(lhs) || parser->parseComma() ||
+      parser->parseOperand(rhs) || parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+      parser->parseType(type) ||
+      parser->resolveOperand(lhs, type, result->operands) ||
+      parser->resolveOperand(rhs, type, result->operands))
+    return failure();
+
+  // Replace the string attribute `predicate` with an integer attribute.
+  auto predicateStr = predicate.dyn_cast<StringAttr>();
+  if (!predicateStr)
+    return parser->emitError(predicateLoc,
+                             "expected 'predicate' attribute of string type");
+
+  int64_t predicateValue = 0;
+  if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
+    Optional<ICmpPredicate> predicate =
+        symbolizeICmpPredicate(predicateStr.getValue());
+    if (!predicate)
+      return parser->emitError(predicateLoc)
+             << "'" << predicateStr.getValue()
+             << "' is an incorrect value of the 'predicate' attribute";
+    predicateValue = static_cast<int64_t>(predicate.getValue());
+  } else {
+    Optional<FCmpPredicate> predicate =
+        symbolizeFCmpPredicate(predicateStr.getValue());
+    if (!predicate)
+      return parser->emitError(predicateLoc)
+             << "'" << predicateStr.getValue()
+             << "' is an incorrect value of the 'predicate' attribute";
+    predicateValue = static_cast<int64_t>(predicate.getValue());
+  }
+
+  attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue);
+
+  // The result type is either i1 or a vector type <? x i1> if the inputs are
+  // vectors.
+  auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
+  auto resultType = LLVMType::getInt1Ty(dialect);
+  auto argType = type.dyn_cast<LLVM::LLVMType>();
+  if (!argType)
+    return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
+  if (argType.getUnderlyingType()->isVectorTy())
+    resultType = LLVMType::getVectorTy(
+        resultType, argType.getUnderlyingType()->getVectorNumElements());
+
+  result->attributes = attrs;
+  result->addTypes({resultType});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::AllocaOp.
+//===----------------------------------------------------------------------===//
+
+static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
+  auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
+
+  auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
+                                  op.getContext());
+
+  *p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << funcTy;
+}
+
+// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
+//                 `:` type `,` type
+static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType arraySize;
+  Type type, elemType;
+  llvm::SMLoc trailingTypeLoc;
+  if (parser->parseOperand(arraySize) || parser->parseKeyword("x") ||
+      parser->parseType(elemType) ||
+      parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+      parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+    return failure();
+
+  // Extract the result type from the trailing function type.
+  auto funcType = type.dyn_cast<FunctionType>();
+  if (!funcType || funcType.getNumInputs() != 1 ||
+      funcType.getNumResults() != 1)
+    return parser->emitError(
+        trailingTypeLoc,
+        "expected trailing function type with one argument and one result");
+
+  if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands))
+    return failure();
+
+  result->attributes = attrs;
+  result->addTypes({funcType.getResult(0)});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::GEPOp.
+//===----------------------------------------------------------------------===//
+
+static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
+  SmallVector<Type, 8> types(op.getOperandTypes());
+  auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
+
+  *p << op.getOperationName() << ' ' << *op.base() << '[';
+  p->printOperands(std::next(op.operand_begin()), op.operand_end());
+  *p << ']';
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << funcTy;
+}
+
+// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
+//                 attribute-dict? `:` type
+static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType base;
+  SmallVector<OpAsmParser::OperandType, 8> indices;
+  Type type;
+  llvm::SMLoc trailingTypeLoc;
+  if (parser->parseOperand(base) ||
+      parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+      parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+    return failure();
+
+  // Deconstruct the trailing function type to extract the types of the base
+  // pointer and result (same type) and the types of the indices.
+  auto funcType = type.dyn_cast<FunctionType>();
+  if (!funcType || funcType.getNumResults() != 1 ||
+      funcType.getNumInputs() == 0)
+    return parser->emitError(trailingTypeLoc,
+                             "expected trailing function type with at least "
+                             "one argument and one result");
+
+  if (parser->resolveOperand(base, funcType.getInput(0), result->operands) ||
+      parser->resolveOperands(indices, funcType.getInputs().drop_front(),
+                              parser->getNameLoc(), result->operands))
+    return failure();
+
+  result->attributes = attrs;
+  result->addTypes(funcType.getResults());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::LoadOp.
+//===----------------------------------------------------------------------===//
+
+static void printLoadOp(OpAsmPrinter *p, LoadOp &op) {
+  *p << op.getOperationName() << ' ' << *op.addr();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.addr()->getType();
+}
+
+// Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
+// the resulting type wrapped in MLIR, or nullptr on error.
+static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
+                                    llvm::SMLoc trailingTypeLoc) {
+  auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
+  if (!llvmTy)
+    return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
+           nullptr;
+  if (!llvmTy.getUnderlyingType()->isPointerTy())
+    return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
+           nullptr;
+  return llvmTy.getPointerElementTy();
+}
+
+// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType addr;
+  Type type;
+  llvm::SMLoc trailingTypeLoc;
+
+  if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+      parser->parseType(type) ||
+      parser->resolveOperand(addr, type, result->operands))
+    return failure();
+
+  Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
+
+  result->attributes = attrs;
+  result->addTypes(elemTy);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::StoreOp.
+//===----------------------------------------------------------------------===//
+
+static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
+  *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.addr()->getType();
+}
+
+// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType addr, value;
+  Type type;
+  llvm::SMLoc trailingTypeLoc;
+
+  if (parser->parseOperand(value) || parser->parseComma() ||
+      parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+      parser->parseType(type))
+    return failure();
+
+  Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
+  if (!elemTy)
+    return failure();
+
+  if (parser->resolveOperand(value, elemTy, result->operands) ||
+      parser->resolveOperand(addr, type, result->operands))
+    return failure();
+
+  result->attributes = attrs;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CallOp.
+//===----------------------------------------------------------------------===//
+
+static void printCallOp(OpAsmPrinter *p, CallOp &op) {
+  auto callee = op.callee();
+  bool isDirect = callee.hasValue();
+
+  // Print the direct callee if present as a function attribute, or an indirect
+  // callee (first operand) otherwise.
+  *p << op.getOperationName() << ' ';
+  if (isDirect)
+    *p << '@' << callee.getValue();
+  else
+    *p << *op.getOperand(0);
+
+  *p << '(';
+  p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
+  *p << ')';
+
+  p->printOptionalAttrDict(op.getAttrs(), {"callee"});
+
+  // Reconstruct the function MLIR function type from operand and result types.
+  SmallVector<Type, 1> resultTypes(op.getResultTypes());
+  SmallVector<Type, 8> argTypes(
+      llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
+
+  *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+}
+
+// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
+//                 attribute-dict? `:` function-type
+static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  SmallVector<OpAsmParser::OperandType, 8> operands;
+  Type type;
+  SymbolRefAttr funcAttr;
+  llvm::SMLoc trailingTypeLoc;
+
+  // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
+  // case of an indirect call, there will be 1 operand before `(`.  In case of a
+  // direct call, there will be no operands and the parser will stop at the
+  // function identifier without complaining.
+  if (parser->parseOperandList(operands))
+    return failure();
+  bool isDirect = operands.empty();
+
+  // Optionally parse a function identifier.
+  if (isDirect)
+    if (parser->parseAttribute(funcAttr, "callee", attrs))
+      return failure();
+
+  if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+      parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+    return failure();
+
+  auto funcType = type.dyn_cast<FunctionType>();
+  if (!funcType)
+    return parser->emitError(trailingTypeLoc, "expected function type");
+  if (isDirect) {
+    // Make sure types match.
+    if (parser->resolveOperands(operands, funcType.getInputs(),
+                                parser->getNameLoc(), result->operands))
+      return failure();
+    result->addTypes(funcType.getResults());
+  } else {
+    // Construct the LLVM IR Dialect function type that the first operand
+    // should match.
+    if (funcType.getNumResults() > 1)
+      return parser->emitError(trailingTypeLoc,
+                               "expected function with 0 or 1 result");
+
+    Builder &builder = parser->getBuilder();
+    auto *llvmDialect =
+        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    LLVM::LLVMType llvmResultType;
+    if (funcType.getNumResults() == 0) {
+      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+    } else {
+      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
+      if (!llvmResultType)
+        return parser->emitError(trailingTypeLoc,
+                                 "expected result to have LLVM type");
+    }
+
+    SmallVector<LLVM::LLVMType, 8> argTypes;
+    argTypes.reserve(funcType.getNumInputs());
+    for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
+      auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
+      if (!argType)
+        return parser->emitError(trailingTypeLoc,
+                                 "expected LLVM types as inputs");
+      argTypes.push_back(argType);
+    }
+    auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
+                                                      /*isVarArg=*/false);
+    auto wrappedFuncType = llvmFuncType.getPointerTo();
+
+    auto funcArguments =
+        ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
+
+    // Make sure that the first operand (indirect callee) matches the wrapped
+    // LLVM IR function type, and that the types of the other call operands
+    // match the types of the function arguments.
+    if (parser->resolveOperand(operands[0], wrappedFuncType,
+                               result->operands) ||
+        parser->resolveOperands(funcArguments, funcType.getInputs(),
+                                parser->getNameLoc(), result->operands))
+      return failure();
+
+    result->addTypes(llvmResultType);
+  }
+
+  result->attributes = attrs;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ExtractElementOp.
+//===----------------------------------------------------------------------===//
+// Expects vector to be of wrapped LLVM vector type and position to be of
+// wrapped LLVM i32 type.
+void LLVM::ExtractElementOp::build(Builder *b, OperationState *result,
+                                   Value *vector, Value *position,
+                                   ArrayRef<NamedAttribute> attrs) {
+  auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>();
+  auto llvmType = wrappedVectorType.getVectorElementType();
+  build(b, result, llvmType, vector, position);
+  result->addAttributes(attrs);
+}
+
+static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) {
+  *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.vector()->getType();
+}
+
+// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
+//                 attribute-dict? `:` type
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+                                         OperationState *result) {
+  llvm::SMLoc loc;
+  OpAsmParser::OperandType vector, position;
+  auto *llvmDialect = parser->getBuilder()
+                          .getContext()
+                          ->getRegisteredDialect<LLVM::LLVMDialect>();
+  Type type, i32Type = LLVMType::getInt32Ty(llvmDialect);
+  if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) ||
+      parser->parseComma() || parser->parseOperand(position) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(vector, type, result->operands) ||
+      parser->resolveOperand(position, i32Type, result->operands))
+    return failure();
+  auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedVectorType ||
+      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+    return parser->emitError(
+        loc, "expected LLVM IR dialect vector type for operand #1");
+  result->addTypes(wrappedVectorType.getVectorElementType());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ExtractValueOp.
+//===----------------------------------------------------------------------===//
+
+static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) {
+  *p << op.getOperationName() << ' ' << *op.container() << op.position();
+  p->printOptionalAttrDict(op.getAttrs(), {"position"});
+  *p << " : " << op.container()->getType();
+}
+
+// Extract the type at `position` in the wrapped LLVM IR aggregate type
+// `containerType`.  Position is an integer array attribute where each value
+// is a zero-based position of the element in the aggregate type.  Return the
+// resulting type wrapped in MLIR, or nullptr on error.
+static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
+                                                       Type containerType,
+                                                       Attribute positionAttr,
+                                                       llvm::SMLoc attributeLoc,
+                                                       llvm::SMLoc typeLoc) {
+  auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedContainerType)
+    return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
+
+  auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>();
+  if (!positionArrayAttr)
+    return parser->emitError(attributeLoc, "expected an array attribute"),
+           nullptr;
+
+  // Infer the element type from the structure type: iteratively step inside the
+  // type by taking the element type, indexed by the position attribute for
+  // stuctures.  Check the position index before accessing, it is supposed to be
+  // in bounds.
+  for (Attribute subAttr : positionArrayAttr) {
+    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+    if (!positionElementAttr)
+      return parser->emitError(attributeLoc,
+                               "expected an array of integer literals"),
+             nullptr;
+    int position = positionElementAttr.getInt();
+    auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
+    if (llvmContainerType->isArrayTy()) {
+      if (position < 0 || static_cast<unsigned>(position) >=
+                              llvmContainerType->getArrayNumElements())
+        return parser->emitError(attributeLoc, "position out of bounds"),
+               nullptr;
+      wrappedContainerType = wrappedContainerType.getArrayElementType();
+    } else if (llvmContainerType->isStructTy()) {
+      if (position < 0 || static_cast<unsigned>(position) >=
+                              llvmContainerType->getStructNumElements())
+        return parser->emitError(attributeLoc, "position out of bounds"),
+               nullptr;
+      wrappedContainerType =
+          wrappedContainerType.getStructElementType(position);
+    } else {
+      return parser->emitError(typeLoc,
+                               "expected wrapped LLVM IR structure/array type"),
+             nullptr;
+    }
+  }
+  return wrappedContainerType;
+}
+
+// <operation> ::= `llvm.extractvalue` ssa-use
+//                 `[` integer-literal (`,` integer-literal)* `]`
+//                 attribute-dict? `:` type
+static ParseResult parseExtractValueOp(OpAsmParser *parser,
+                                       OperationState *result) {
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType container;
+  Type containerType;
+  Attribute positionAttr;
+  llvm::SMLoc attributeLoc, trailingTypeLoc;
+
+  if (parser->parseOperand(container) ||
+      parser->getCurrentLocation(&attributeLoc) ||
+      parser->parseAttribute(positionAttr, "position", attrs) ||
+      parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+      parser->getCurrentLocation(&trailingTypeLoc) ||
+      parser->parseType(containerType) ||
+      parser->resolveOperand(container, containerType, result->operands))
+    return failure();
+
+  auto elementType = getInsertExtractValueElementType(
+      parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
+  if (!elementType)
+    return failure();
+
+  result->attributes = attrs;
+  result->addTypes(elementType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::InsertElementOp.
+//===----------------------------------------------------------------------===//
+
+static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) {
+  *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value()
+     << ", " << *op.position();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.vector()->getType();
+}
+
+// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
+//                 attribute-dict? `:` type
+static ParseResult parseInsertElementOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  llvm::SMLoc loc;
+  OpAsmParser::OperandType vector, value, position;
+  auto *llvmDialect = parser->getBuilder()
+                          .getContext()
+                          ->getRegisteredDialect<LLVM::LLVMDialect>();
+  Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect);
+  if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) ||
+      parser->parseComma() || parser->parseOperand(value) ||
+      parser->parseComma() || parser->parseOperand(position) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(vectorType))
+    return failure();
+
+  auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedVectorType ||
+      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+    return parser->emitError(
+        loc, "expected LLVM IR dialect vector type for operand #1");
+  auto valueType = wrappedVectorType.getVectorElementType();
+  if (!valueType)
+    return failure();
+
+  if (parser->resolveOperand(vector, vectorType, result->operands) ||
+      parser->resolveOperand(value, valueType, result->operands) ||
+      parser->resolveOperand(position, i32Type, result->operands))
+    return failure();
+
+  result->addTypes(vectorType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::InsertValueOp.
+//===----------------------------------------------------------------------===//
+
+static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
+  *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
+     << op.position();
+  p->printOptionalAttrDict(op.getAttrs(), {"position"});
+  *p << " : " << op.container()->getType();
+}
+
+// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
+//                 `[` integer-literal (`,` integer-literal)* `]`
+//                 attribute-dict? `:` type
+static ParseResult parseInsertValueOp(OpAsmParser *parser,
+                                      OperationState *result) {
+  OpAsmParser::OperandType container, value;
+  Type containerType;
+  Attribute positionAttr;
+  llvm::SMLoc attributeLoc, trailingTypeLoc;
+
+  if (parser->parseOperand(value) || parser->parseComma() ||
+      parser->parseOperand(container) ||
+      parser->getCurrentLocation(&attributeLoc) ||
+      parser->parseAttribute(positionAttr, "position", result->attributes) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+      parser->parseType(containerType))
+    return failure();
+
+  auto valueType = getInsertExtractValueElementType(
+      parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
+  if (!valueType)
+    return failure();
+
+  if (parser->resolveOperand(container, containerType, result->operands) ||
+      parser->resolveOperand(value, valueType, result->operands))
+    return failure();
+
+  result->addTypes(containerType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::SelectOp.
+//===----------------------------------------------------------------------===//
+
+static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
+  *p << op.getOperationName() << ' ' << *op.condition() << ", "
+     << *op.trueValue() << ", " << *op.falseValue();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
+}
+
+// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
+//                 attribute-dict? `:` type, type
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType condition, trueValue, falseValue;
+  Type conditionType, argType;
+
+  if (parser->parseOperand(condition) || parser->parseComma() ||
+      parser->parseOperand(trueValue) || parser->parseComma() ||
+      parser->parseOperand(falseValue) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(conditionType) || parser->parseComma() ||
+      parser->parseType(argType))
+    return failure();
+
+  if (parser->resolveOperand(condition, conditionType, result->operands) ||
+      parser->resolveOperand(trueValue, argType, result->operands) ||
+      parser->resolveOperand(falseValue, argType, result->operands))
+    return failure();
+
+  result->addTypes(argType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::BrOp.
+//===----------------------------------------------------------------------===//
+
+static void printBrOp(OpAsmPrinter *p, BrOp &op) {
+  *p << op.getOperationName() << ' ';
+  p->printSuccessorAndUseList(op.getOperation(), 0);
+  p->printOptionalAttrDict(op.getAttrs());
+}
+
+// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
+// attribute-dict?
+static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) {
+  Block *dest;
+  SmallVector<Value *, 4> operands;
+  if (parser->parseSuccessorAndUseList(dest, operands) ||
+      parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  result->addSuccessor(dest, operands);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CondBrOp.
+//===----------------------------------------------------------------------===//
+
+static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
+  *p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
+  p->printSuccessorAndUseList(op.getOperation(), 0);
+  *p << ", ";
+  p->printSuccessorAndUseList(op.getOperation(), 1);
+  p->printOptionalAttrDict(op.getAttrs());
+}
+
+// <operation> ::= `llvm.cond_br` ssa-use `,`
+//                  bb-id (`[` ssa-use-and-type-list `]`)? `,`
+//                  bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
+static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
+  Block *trueDest;
+  Block *falseDest;
+  SmallVector<Value *, 4> trueOperands;
+  SmallVector<Value *, 4> falseOperands;
+  OpAsmParser::OperandType condition;
+
+  Builder &builder = parser->getBuilder();
+  auto *llvmDialect =
+      builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+  auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
+
+  if (parser->parseOperand(condition) || parser->parseComma() ||
+      parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
+      parser->parseComma() ||
+      parser->parseSuccessorAndUseList(falseDest, falseOperands) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->resolveOperand(condition, i1Type, result->operands))
+    return failure();
+
+  result->addSuccessor(trueDest, trueOperands);
+  result->addSuccessor(falseDest, falseOperands);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ReturnOp.
+//===----------------------------------------------------------------------===//
+
+static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
+  *p << op.getOperationName();
+  p->printOptionalAttrDict(op.getAttrs());
+  assert(op.getNumOperands() <= 1);
+
+  if (op.getNumOperands() == 0)
+    return;
+
+  *p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
+}
+
+// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
+//                 type-list-no-parens
+static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 1> operands;
+  Type type;
+
+  if (parser->parseOperandList(operands) ||
+      parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+  if (operands.empty())
+    return success();
+
+  if (parser->parseColonType(type) ||
+      parser->resolveOperand(operands[0], type, result->operands))
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::UndefOp.
+//===----------------------------------------------------------------------===//
+
+static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
+  *p << op.getOperationName();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.res()->getType();
+}
+
+// <operation> ::= `llvm.undef` attribute-dict? : type
+static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
+  Type type;
+
+  if (parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return failure();
+
+  result->addTypes(type);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ConstantOp.
+//===----------------------------------------------------------------------===//
+
+static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
+  *p << op.getOperationName() << '(' << op.value() << ')';
+  p->printOptionalAttrDict(op.getAttrs(), {"value"});
+  *p << " : " << op.res()->getType();
+}
+
+// <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type
+static ParseResult parseConstantOp(OpAsmParser *parser,
+                                   OperationState *result) {
+  Attribute valueAttr;
+  Type type;
+
+  if (parser->parseLParen() ||
+      parser->parseAttribute(valueAttr, "value", result->attributes) ||
+      parser->parseRParen() ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return failure();
+
+  result->addTypes(type);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder, printer and verifier for LLVM::GlobalOp.
+//===----------------------------------------------------------------------===//
+
+void GlobalOp::build(Builder *builder, OperationState *result, LLVMType type,
+                     bool isConstant, StringRef name, Attribute value,
+                     ArrayRef<NamedAttribute> attrs) {
+  result->addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder->getStringAttr(name));
+  result->addAttribute("type", builder->getTypeAttr(type));
+  if (isConstant)
+    result->addAttribute("constant", builder->getUnitAttr());
+  result->addAttribute("value", value);
+  result->attributes.append(attrs.begin(), attrs.end());
+}
+
+static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) {
+  *p << op.getOperationName() << ' ';
+  if (op.constant())
+    *p << "constant ";
+  *p << '@' << op.sym_name() << '(';
+  p->printAttribute(op.value());
+  *p << ')';
+  p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(),
+                                           "type", "constant", "value"});
+
+  // Print the trailing type unless it's a string global.
+  if (op.value().isa<StringAttr>())
+    return;
+  *p << " : ";
+  p->printType(op.type());
+}
+
+// <operation> ::= `llvm.global` `constant`? `@` identifier `(` attribute `)`
+//                  attribute-list? (`:` type)?
+//
+// The type can be omitted for string attributes, in which case it will be
+// inferred from the value of the string as [strlen(value) x i8].
+static ParseResult parseGlobalOp(OpAsmParser *parser, OperationState *result) {
+  if (succeeded(parser->parseOptionalKeyword("constant")))
+    result->addAttribute("constant", parser->getBuilder().getUnitAttr());
+
+  Attribute value;
+  StringAttr name;
+  SmallVector<Type, 1> types;
+  if (parser->parseSymbolName(name, SymbolTable::getSymbolAttrName(),
+                              result->attributes) ||
+      parser->parseLParen() ||
+      parser->parseAttribute(value, "value", result->attributes) ||
+      parser->parseRParen() ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseOptionalColonTypeList(types))
+    return failure();
+
+  if (types.size() > 1)
+    return parser->emitError(parser->getNameLoc(), "expected zero or one type");
+
+  if (types.empty()) {
+    if (auto strAttr = value.dyn_cast<StringAttr>()) {
+      MLIRContext *context = parser->getBuilder().getContext();
+      auto *dialect = context->getRegisteredDialect<LLVMDialect>();
+      auto arrayType = LLVM::LLVMType::getArrayTy(
+          LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
+      types.push_back(arrayType);
+    } else {
+      return parser->emitError(parser->getNameLoc(),
+                               "type can only be omitted for string globals");
+    }
+  }
+
+  result->addAttribute("type", parser->getBuilder().getTypeAttr(types[0]));
+  return success();
+}
+
+static LogicalResult verify(GlobalOp op) {
+  if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
+    return op.emitOpError(
+        "expects type to be a valid element type for an LLVM pointer");
+  if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
+    return op.emitOpError("must appear at the module level");
+  if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
+    auto type = op.getType();
+    if (!type.getUnderlyingType()->isArrayTy() ||
+        !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
+        type.getArrayNumElements() != strAttr.getValue().size())
+      return op.emitOpError(
+          "requires an i8 array type of the length equal to that of the string "
+          "attribute");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ShuffleVectorOp.
+//===----------------------------------------------------------------------===//
+// Expects vector to be of wrapped LLVM vector type and position to be of
+// wrapped LLVM i32 type.
+void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1,
+                                  Value *v2, ArrayAttr mask,
+                                  ArrayRef<NamedAttribute> attrs) {
+  auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>();
+  auto vType = LLVMType::getVectorTy(
+      wrappedContainerType1.getVectorElementType(), mask.size());
+  build(b, result, vType, v1, v2, mask);
+  result->addAttributes(attrs);
+}
+
+static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) {
+  *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
+     << op.mask();
+  p->printOptionalAttrDict(op.getAttrs(), {"mask"});
+  *p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
+}
+
+// <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
+//                 `[` integer-literal (`,` integer-literal)* `]`
+//                 attribute-dict? `:` type
+static ParseResult parseShuffleVectorOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  llvm::SMLoc loc;
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType v1, v2;
+  Attribute maskAttr;
+  Type typeV1, typeV2;
+  if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) ||
+      parser->parseComma() || parser->parseOperand(v2) ||
+      parser->parseAttribute(maskAttr, "mask", attrs) ||
+      parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColonType(typeV1) || parser->parseComma() ||
+      parser->parseType(typeV2) ||
+      parser->resolveOperand(v1, typeV1, result->operands) ||
+      parser->resolveOperand(v2, typeV2, result->operands))
+    return failure();
+  auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedContainerType1 ||
+      !wrappedContainerType1.getUnderlyingType()->isVectorTy())
+    return parser->emitError(
+        loc, "expected LLVM IR dialect vector type for operand #1");
+  auto vType =
+      LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(),
+                            maskAttr.cast<ArrayAttr>().size());
+  result->attributes = attrs;
+  result->addTypes(vType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder, printer and verifier for LLVM::LLVMFuncOp.
+//===----------------------------------------------------------------------===//
+
+void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name,
+                       LLVMType type, ArrayRef<NamedAttribute> attrs,
+                       ArrayRef<NamedAttributeList> argAttrs) {
+  result->addRegion();
+  result->addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder->getStringAttr(name));
+  result->addAttribute("type", builder->getTypeAttr(type));
+  result->attributes.append(attrs.begin(), attrs.end());
+  if (argAttrs.empty())
+    return;
+
+  unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
+  assert(numInputs == argAttrs.size() &&
+         "expected as many argument attribute lists as arguments");
+  SmallString<8> argAttrName;
+  for (unsigned i = 0; i < numInputs; ++i)
+    if (auto argDict = argAttrs[i].getDictionary())
+      result->addAttribute(getArgAttrName(i, argAttrName), argDict);
+}
+
+// Build an LLVM function type from the given lists of input and output types.
+// Returns a null type if any of the types provided are non-LLVM types, or if
+// there is more than one output type.
+static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
+                                  ArrayRef<Type> outputs,
+                                  impl::VariadicFlag variadicFlag,
+                                  std::string &errorMessage) {
+  if (outputs.size() > 1) {
+    errorMessage = "expected zero or one function result";
+    return {};
+  }
+
+  // Convert inputs to LLVM types, exit early on error.
+  SmallVector<LLVMType, 4> llvmInputs;
+  for (auto t : inputs) {
+    auto llvmTy = t.dyn_cast<LLVMType>();
+    if (!llvmTy) {
+      errorMessage = "expected LLVM type for function arguments";
+      return {};
+    }
+    llvmInputs.push_back(llvmTy);
+  }
+
+  // Get the dialect from the input type, if any exist.  Look it up in the
+  // context otherwise.
+  LLVMDialect *dialect =
+      llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
+                         : &llvmInputs.front().getDialect();
+
+  // No output is denoted as "void" in LLVM type system.
+  LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
+                                        : outputs.front().dyn_cast<LLVMType>();
+  if (!llvmOutput) {
+    errorMessage = "expected LLVM type for function results";
+    return {};
+  }
+  return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
+                                 variadicFlag.isVariadic());
+}
+
+// Print the LLVMFuncOp.  Collects argument and result types and passes them
+// to the trait printer.  Drops "void" result since it cannot be parsed back.
+static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) {
+  LLVMType fnType = op.getType();
+  SmallVector<Type, 8> argTypes;
+  SmallVector<Type, 1> resTypes;
+  argTypes.reserve(fnType.getFunctionNumParams());
+  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
+    argTypes.push_back(fnType.getFunctionParamType(i));
+
+  LLVMType returnType = fnType.getFunctionResultType();
+  if (!returnType.getUnderlyingType()->isVoidTy())
+    resTypes.push_back(returnType);
+
+  impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes);
+}
+
+// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
+// attribute is present.  This can check for preconditions of the
+// getNumArguments hook not failing.
+LogicalResult LLVMFuncOp::verifyType() {
+  auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
+  if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
+    return emitOpError("requires '" + getTypeAttrName() +
+                       "' attribute of wrapped LLVM function type");
+
+  return success();
+}
+
+// Hook for OpTrait::FunctionLike, returns the number of function arguments.
+// Depends on the type attribute being correct as checked by verifyType
+unsigned LLVMFuncOp::getNumFuncArguments() {
+  return getType().getUnderlyingType()->getFunctionNumParams();
+}
+
+static LogicalResult verify(LLVMFuncOp op) {
+  if (op.isExternal())
+    return success();
+
+  if (op.isVarArg())
+    return op.emitOpError("only external functions can be variadic");
+
+  auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
+  unsigned numArguments = funcType->getNumParams();
+  Block &entryBlock = op.front();
+  for (unsigned i = 0; i < numArguments; ++i) {
+    Type argType = entryBlock.getArgument(i)->getType();
+    auto argLLVMType = argType.dyn_cast<LLVMType>();
+    if (!argLLVMType)
+      return op.emitOpError("entry block argument #")
+             << i << " is not of LLVM type";
+    if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
+      return op.emitOpError("the type of entry block argument #")
+             << i << " does not match the function signature";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LLVMDialect initialization, type parsing, and registration.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMDialectImpl {
+  LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
+
+  llvm::LLVMContext llvmContext;
+  llvm::Module module;
+
+  /// A set of LLVMTypes that are cached on construction to avoid any lookups or
+  /// locking.
+  LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
+  LLVMType doubleTy, floatTy, halfTy;
+  LLVMType voidTy;
+
+  /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
+  /// multi-threaded and requires locked access to prevent race conditions.
+  llvm::sys::SmartMutex<true> mutex;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMDialect::LLVMDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context),
+      impl(new detail::LLVMDialectImpl()) {
+  addTypes<LLVMType>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/LLVMIR/LLVMOps.cpp.inc"
+      >();
+
+  // Support unknown operations because not all LLVM operations are registered.
+  allowUnknownOperations();
+
+  // Cache some of the common LLVM types to avoid the need for lookups/locking.
+  auto &llvmContext = impl->llvmContext;
+  /// Integer Types.
+  impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
+  impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
+  impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
+  impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
+  impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
+  impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
+  /// Float Types.
+  impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
+  impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+  impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
+  /// Other Types.
+  impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
+}
+
+LLVMDialect::~LLVMDialect() {}
+
+#define GET_OP_CLASSES
+#include "mlir/LLVMIR/LLVMOps.cpp.inc"
+
+llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
+llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
+
+/// Parse a type registered to this dialect.
+Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+  // LLVM is not thread-safe, so lock access to it.
+  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+
+  llvm::SMDiagnostic errorMessage;
+  llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
+  if (!type)
+    return (emitError(loc, errorMessage.getMessage()), nullptr);
+  return LLVMType::get(getContext(), type);
+}
+
+/// Print a type registered to this dialect.
+void LLVMDialect::printType(Type type, raw_ostream &os) const {
+  auto llvmType = type.dyn_cast<LLVMType>();
+  assert(llvmType && "printing wrong type");
+  assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
+  llvmType.getUnderlyingType()->print(os);
+}
+
+/// Verify LLVMIR function argument attributes.
+LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIdx,
+                                                    unsigned argIdx,
+                                                    NamedAttribute argAttr) {
+  // Check that llvm.noalias is a boolean attribute.
+  if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
+    return op->emitError()
+           << "llvm.noalias argument attribute of non boolean type";
+  return success();
+}
+
+static DialectRegistration<LLVMDialect> llvmDialect;
+
+//===----------------------------------------------------------------------===//
+// LLVMType.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMTypeStorage : public ::mlir::TypeStorage {
+  LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
+
+  // LLVM types are pointer-unique.
+  using KeyTy = llvm::Type *;
+  bool operator==(const KeyTy &key) const { return key == underlyingType; }
+
+  static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
+                                    llvm::Type *ty) {
+    return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
+  }
+
+  llvm::Type *underlyingType;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
+  return Base::get(context, FIRST_LLVM_TYPE, llvmType);
+}
+
+/// Get an LLVMType with an llvm type that may cause changes to the underlying
+/// llvm context when constructed.
+LLVMType LLVMType::getLocked(LLVMDialect *dialect,
+                             llvm::function_ref<llvm::Type *()> typeBuilder) {
+  // Lock access to the llvm context and build the type.
+  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+  return get(dialect->getContext(), typeBuilder());
+}
+
+LLVMDialect &LLVMType::getDialect() {
+  return static_cast<LLVMDialect &>(Type::getDialect());
+}
+
+llvm::Type *LLVMType::getUnderlyingType() const {
+  return getImpl()->underlyingType;
+}
+
+/// Array type utilities.
+LLVMType LLVMType::getArrayElementType() {
+  return get(getContext(), getUnderlyingType()->getArrayElementType());
+}
+unsigned LLVMType::getArrayNumElements() {
+  return getUnderlyingType()->getArrayNumElements();
+}
+
+/// Vector type utilities.
+LLVMType LLVMType::getVectorElementType() {
+  return get(getContext(), getUnderlyingType()->getVectorElementType());
+}
+
+/// Function type utilities.
+LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
+  return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx));
+}
+unsigned LLVMType::getFunctionNumParams() {
+  return getUnderlyingType()->getFunctionNumParams();
+}
+LLVMType LLVMType::getFunctionResultType() {
+  return get(
+      getContext(),
+      llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
+}
+
+/// Pointer type utilities.
+LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(&getDialect(), [=] {
+    return getUnderlyingType()->getPointerTo(addrSpace);
+  });
+}
+LLVMType LLVMType::getPointerElementTy() {
+  return get(getContext(), getUnderlyingType()->getPointerElementType());
+}
+
+/// Struct type utilities.
+LLVMType LLVMType::getStructElementType(unsigned i) {
+  return get(getContext(), getUnderlyingType()->getStructElementType(i));
+}
+
+/// Utilities used to generate floating point types.
+LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
+  return dialect->impl->doubleTy;
+}
+LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
+  return dialect->impl->floatTy;
+}
+LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
+  return dialect->impl->halfTy;
+}
+
+/// Utilities used to generate integer types.
+LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+  switch (numBits) {
+  case 1:
+    return dialect->impl->int1Ty;
+  case 8:
+    return dialect->impl->int8Ty;
+  case 16:
+    return dialect->impl->int16Ty;
+  case 32:
+    return dialect->impl->int32Ty;
+  case 64:
+    return dialect->impl->int64Ty;
+  case 128:
+    return dialect->impl->int128Ty;
+  default:
+    break;
+  }
+
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(dialect, [=] {
+    return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits);
+  });
+}
+
+/// Utilities used to generate other miscellaneous types.
+LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(&elementType.getDialect(), [=] {
+    return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements);
+  });
+}
+LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+                                 bool isVarArg) {
+  SmallVector<llvm::Type *, 8> llvmParams;
+  for (auto param : params)
+    llvmParams.push_back(param.getUnderlyingType());
+
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(&result.getDialect(), [=] {
+    return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
+                                   isVarArg);
+  });
+}
+LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+                               ArrayRef<LLVMType> elements, bool isPacked) {
+  SmallVector<llvm::Type *, 8> llvmElements;
+  for (auto elt : elements)
+    llvmElements.push_back(elt.getUnderlyingType());
+
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(dialect, [=] {
+    return llvm::StructType::get(dialect->getLLVMContext(), llvmElements,
+                                 isPacked);
+  });
+}
+LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  return getLocked(&elementType.getDialect(), [=] {
+    return llvm::VectorType::get(elementType.getUnderlyingType(), numElements);
+  });
+}
+LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
+  return dialect->impl->voidTy;
+}
diff --git a/third_party/mlir/lib/LLVMIR/IR/NVVMDialect.cpp b/third_party/mlir/lib/LLVMIR/IR/NVVMDialect.cpp
new file mode 100644
index 0000000..f586f0e
--- /dev/null
+++ b/third_party/mlir/lib/LLVMIR/IR/NVVMDialect.cpp
@@ -0,0 +1,88 @@
+//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types and operation details for the NVVM IR dialect in
+// MLIR, and the LLVM IR dialect.  It also registers the dialect.
+//
+// The NVVM dialect only contains GPU specific additions on top of the general
+// LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/LLVMIR/NVVMDialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace NVVM {
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for NVVM ops
+//===----------------------------------------------------------------------===//
+
+static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
+  *p << op->getName() << " : ";
+  if (op->getNumResults() == 1) {
+    *p << op->getResult(0)->getType();
+  } else {
+    *p << "###invalid type###";
+  }
+}
+
+// <operation> ::= `llvm.nvvm.XYZ` : type
+static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser,
+                                              OperationState *result) {
+  Type type;
+  if (parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return failure();
+
+  result->addTypes(type);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NVVMDialect initialization, type parsing, and registration.
+//===----------------------------------------------------------------------===//
+
+// TODO(herhut): This should be the llvm.nvvm dialect once this is supported.
+NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/LLVMIR/NVVMOps.cpp.inc"
+      >();
+
+  // Support unknown operations because not all NVVM operations are registered.
+  allowUnknownOperations();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/LLVMIR/NVVMOps.cpp.inc"
+
+static DialectRegistration<NVVMDialect> nvvmDialect;
+
+} // namespace NVVM
+} // namespace mlir
diff --git a/third_party/mlir/lib/Linalg/Analysis/DependenceAnalysis.cpp b/third_party/mlir/lib/Linalg/Analysis/DependenceAnalysis.cpp
new file mode 100644
index 0000000..5a272a4
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Analysis/DependenceAnalysis.cpp
@@ -0,0 +1,212 @@
+//===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements view-based alias and dependence analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-dependence-analysis"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+using llvm::dbgs;
+
+Value *Aliases::find(Value *v) {
+  if (isa<BlockArgument>(v))
+    return v;
+
+  auto it = aliases.find(v);
+  if (it != aliases.end()) {
+    assert(((isa<BlockArgument>(it->getSecond()) &&
+             it->getSecond()->getType().isa<ViewType>()) ||
+            it->getSecond()->getType().isa<BufferType>()) &&
+           "Buffer or block argument expected");
+    return it->getSecond();
+  }
+
+  while (true) {
+    if (isa<BlockArgument>(v))
+      return v;
+    if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, find(slice.getBaseView())));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, view.buffer()));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
+      v = view.getView();
+      continue;
+    }
+    llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
+    llvm_unreachable("unsupported view alias case");
+  }
+}
+
+LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
+                                             ArrayRef<Operation *> ops)
+    : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
+  for (auto en : llvm::enumerate(linalgOps)) {
+    assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
+    linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
+  }
+  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
+    for (unsigned j = i + 1; j < e; ++j) {
+      addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
+    }
+  }
+}
+
+void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
+                                              LinalgOpView indexingOpView,
+                                              LinalgOpView dependentOpView) {
+  LLVM_DEBUG(dbgs() << "\nAdd dep type " << dt << ":\t" << *indexingOpView.op
+                    << " -> " << *dependentOpView.op);
+  dependencesFromGraphs[dt][indexingOpView.op].push_back(
+      LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
+  dependencesIntoGraphs[dt][dependentOpView.op].push_back(
+      LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
+}
+
+LinalgDependenceGraph::dependence_range
+LinalgDependenceGraph::getDependencesFrom(
+    LinalgOp src, LinalgDependenceGraph::DependenceType dt) {
+  return getDependencesFrom(src.getOperation(), dt);
+}
+
+LinalgDependenceGraph::dependence_range
+LinalgDependenceGraph::getDependencesFrom(
+    Operation *src, LinalgDependenceGraph::DependenceType dt) {
+  auto &vec = dependencesFromGraphs[dt][src];
+  return llvm::make_range(vec.begin(), vec.end());
+}
+
+LinalgDependenceGraph::dependence_range
+LinalgDependenceGraph::getDependencesInto(
+    LinalgOp dst, LinalgDependenceGraph::DependenceType dt) {
+  return getDependencesInto(dst.getOperation(), dt);
+}
+
+LinalgDependenceGraph::dependence_range
+LinalgDependenceGraph::getDependencesInto(
+    Operation *dst, LinalgDependenceGraph::DependenceType dt) {
+  auto &vec = dependencesIntoGraphs[dt][dst];
+  return llvm::make_range(vec.begin(), vec.end());
+}
+
+void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+  for (auto *srcView : src.getOutputs()) { // W
+    // RAW graph
+    for (auto *dstView : dst.getInputs()) {  // R
+      if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
+        addDependenceElem(DependenceType::RAW,
+                          LinalgOpView{src.getOperation(), srcView},
+                          LinalgOpView{dst.getOperation(), dstView});
+      }
+    }
+    // WAW graph
+    for (auto *dstView : dst.getOutputs()) { // W
+      if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
+        addDependenceElem(DependenceType::WAW,
+                          LinalgOpView{src.getOperation(), srcView},
+                          LinalgOpView{dst.getOperation(), dstView});
+      }
+    }
+  }
+  for (auto *srcView : src.getInputs()) { // R
+    // RAR graph
+    for (auto *dstView : dst.getInputs()) {  // R
+      if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
+        addDependenceElem(DependenceType::RAR,
+                          LinalgOpView{src.getOperation(), srcView},
+                          LinalgOpView{dst.getOperation(), dstView});
+      }
+    }
+    // WAR graph
+    for (auto *dstView : dst.getOutputs()) { // W
+      if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
+        addDependenceElem(DependenceType::WAR,
+                          LinalgOpView{src.getOperation(), srcView},
+                          LinalgOpView{dst.getOperation(), dstView});
+      }
+    }
+  }
+}
+
+SmallVector<Operation *, 8>
+LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
+                                               LinalgOp dstLinalgOp) {
+  return findOperationsWithCoveringDependences(
+      srcLinalgOp, dstLinalgOp, nullptr,
+      {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
+}
+
+SmallVector<Operation *, 8>
+LinalgDependenceGraph::findCoveringWrites(LinalgOp srcLinalgOp,
+                                          LinalgOp dstLinalgOp, Value *view) {
+  return findOperationsWithCoveringDependences(
+      srcLinalgOp, dstLinalgOp, view,
+      {DependenceType::WAW, DependenceType::WAR});
+}
+
+SmallVector<Operation *, 8>
+LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp,
+                                         LinalgOp dstLinalgOp, Value *view) {
+  return findOperationsWithCoveringDependences(
+      srcLinalgOp, dstLinalgOp, view,
+      {DependenceType::RAR, DependenceType::RAW});
+}
+
+SmallVector<Operation *, 8>
+LinalgDependenceGraph::findOperationsWithCoveringDependences(
+    LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view,
+    ArrayRef<DependenceType> types) {
+  auto *src = srcLinalgOp.getOperation();
+  auto *dst = dstLinalgOp.getOperation();
+  auto srcPos = linalgOpPositions[src];
+  auto dstPos = linalgOpPositions[dst];
+  assert(srcPos < dstPos && "expected dst after src in IR traversal order");
+
+  SmallVector<Operation *, 8> res;
+  // Consider an intermediate interleaved `interim` op, look for any dependence
+  // to an aliasing view on a src -> op -> dst path.
+  // TODO(ntv) we are not considering paths yet, just interleaved positions.
+  for (auto dt : types) {
+    for (auto dependence : getDependencesFrom(src, dt)) {
+      auto interimPos = linalgOpPositions[dependence.dependentOpView.op];
+      // Skip if not interleaved.
+      if (interimPos >= dstPos || interimPos <= srcPos)
+        continue;
+      if (view && !aliases.alias(view, dependence.indexingView))
+        continue;
+      auto *op = dependence.dependentOpView.op;
+      LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << dt
+                        << ": " << *src << " -> " << *op << " on "
+                        << *dependence.indexingView);
+      res.push_back(op);
+    }
+  }
+  return res;
+}
diff --git a/third_party/mlir/lib/Linalg/CMakeLists.txt b/third_party/mlir/lib/Linalg/CMakeLists.txt
new file mode 100644
index 0000000..b37bdaa
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_llvm_library(MLIRLinalg
+  LinalgRegistration.cpp
+  Analysis/DependenceAnalysis.cpp
+  IR/LinalgOps.cpp
+  IR/LinalgTypes.cpp
+  Transforms/Fusion.cpp
+  Transforms/LowerToLLVMDialect.cpp
+  Transforms/LowerToLoops.cpp
+  Transforms/Tiling.cpp
+  Utils/Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
+  DEPENDS
+  intrinsics_gen
+  )
+
+add_dependencies(MLIRLinalg
+
+  MLIRAffineOps
+  MLIRLinalgOpsIncGen
+  MLIRLinalgLibraryOpsIncGen
+  MLIRStandardToLLVM
+  )
diff --git a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
new file mode 100644
index 0000000..bce2b32
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
@@ -0,0 +1,1120 @@
+//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a the Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+
+namespace {
+/// Fold constant dimensions into an alloc operation.
+struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
+  using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
+                                     PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+PatternMatchResult
+SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
+                               PatternRewriter &rewriter) const {
+  auto *viewProducingOp = dimOp.view()->getDefiningOp();
+  auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
+  auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
+  auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
+  if (!subView && !slice && !view)
+    return matchFailure();
+
+  unsigned dim = dimOp.getIndex();
+  Value *min, *max, *step;
+  if (view) {
+    // Cannot traverse block arguments, fail.
+    if (isa<BlockArgument>(view.getRange(dim)))
+      return matchFailure();
+    // Record min, max, step for further processing.
+    auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
+    std::tie(min, max, step) =
+        std::make_tuple(range.min(), range.max(), range.step());
+  } else if (subView) {
+    // Record min, max, step for further processing.
+    auto range = subView.getRange(dim);
+    std::tie(min, max, step) =
+        std::make_tuple(range.min, range.max, range.step);
+  } else {
+    // Taking the dim of a slice must take a range (since other dims have been
+    // rank-reduced).
+    auto *rangeValue = slice.getRanges()[dim];
+    // Cannot traverse block arguments, fail.
+    if (isa<BlockArgument>(rangeValue))
+      return matchFailure();
+    auto range = cast<RangeOp>(rangeValue->getDefiningOp());
+    // Record min, max, step for further processing.
+    std::tie(min, max, step) =
+        std::make_tuple(range.min(), range.max(), range.step());
+  }
+
+  // Only support constant steps of 1 atm.
+  auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
+  if (!constant || constant.getValue() != 1)
+    return matchFailure();
+
+  // Circumvent affine constraints:
+  //   emit an affine_apply when possible, otherwise emit a `subi`.
+  bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
+                        isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
+  bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
+                        isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
+
+  OpBuilder b(dimOp);
+  ScopedContext scope(b, dimOp.getLoc());
+  // Emit `subi`.
+  if (!validAffineMin || !validAffineMax) {
+    rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
+    return matchSuccess();
+  }
+
+  // Emit affine_apply.
+  using edsc::op::operator-;
+  rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
+                     {dimOp.view()});
+  return matchSuccess();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// LoadOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::LoadOp::build(Builder *b, OperationState *result,
+                                 Value *view, ArrayRef<Value *> indices) {
+  auto viewType = view->getType().cast<ViewType>();
+  result->addOperands(view);
+  result->addOperands(indices);
+  result->addTypes(viewType.getElementType());
+}
+
+// A LoadOp prints as:
+//
+// ```{.mlir}
+//    %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::LoadOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType viewInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  ViewType type;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(viewInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(viewInfo, type, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+LogicalResult mlir::linalg::LoadOp::verify() {
+  if (getNumOperands() == 0)
+    return emitOpError("expected a view to load from");
+
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("first operand must be a view");
+
+  if (getType() != viewType.getElementType())
+    return emitOpError("result type must match element type of the view");
+
+  if (getRank() != getNumOperands() - 1)
+    return emitOpError("incorrect number of indices for load");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+
+  return success();
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// RangeOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::RangeOp::build(Builder *b, OperationState *result,
+                                  Value *min, Value *max, Value *step) {
+  result->addOperands({min, max, step});
+  result->addTypes({RangeType::get(b->getContext())});
+}
+
+// Verification is simply that a RangeOp takes 3 index ssa-value.
+LogicalResult mlir::linalg::RangeOp::verify() {
+  if (!min() || !min()->getType().isa<IndexType>())
+    return emitOpError("first operand should be of type index");
+  if (!max() || !max()->getType().isa<IndexType>())
+    return emitOpError("second operand should be of type index");
+  if (!step() || !step()->getType().isa<IndexType>())
+    return emitOpError("third operand should be of type index");
+  return success();
+}
+
+// A RangeOp prints as:
+//
+// ```{.mlir}
+//   linalg.range %0:%1:%2 : !linalg.range
+// ```
+void mlir::linalg::RangeOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
+     << " : " << getType();
+}
+
+ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
+  RangeType type;
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
+      parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
+      parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
+      parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
+      parser->addTypeToList(type, result->types));
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// SliceOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
+                                  Value *base, ArrayRef<Value *> indexings) {
+  result->addOperands({base});
+  result->addOperands(indexings);
+
+  ViewType viewType = base->getType().cast<ViewType>();
+  unsigned rank = viewType.getRank();
+  for (auto *i : indexings)
+    if (!i->getType().isa<RangeType>())
+      rank--;
+  Type elementType = viewType.getElementType();
+  result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
+}
+
+LogicalResult mlir::linalg::SliceOp::verify() {
+  if (llvm::empty(getOperands()))
+    return emitOpError(
+        "requires at least a view operand followed by 'rank' indices");
+  unsigned rank = getBaseViewRank();
+  if (llvm::size(getIndexings()) != rank) {
+    return emitOpError("requires at least a view operand followed by ")
+           << rank << " indexings";
+  }
+  unsigned index = 0;
+  for (auto indexing : getIndexings()) {
+    if (!indexing->getType().isa<RangeType>() &&
+        !indexing->getType().isa<IndexType>()) {
+      return emitOpError() << index
+                           << "^th index must be of range or index type";
+    }
+    if (indexing->getType().isa<IndexType>())
+      --rank;
+    ++index;
+  }
+  if (getRank() != rank) {
+    return emitOpError()
+           << "the rank of the view must be the number of its range indices ("
+           << rank << ") but got: " << getRank();
+  }
+  return success();
+}
+
+ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
+  OpAsmParser::OperandType baseInfo;
+  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
+  SmallVector<Type, 8> types;
+  if (parser->parseOperand(baseInfo) ||
+      parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(types))
+    return failure();
+
+  if (types.size() != 2 + indexingsInfo.size())
+    return parser->emitError(parser->getNameLoc(),
+                             "unexpected number of types ");
+  ViewType baseViewType = types[0].dyn_cast<ViewType>();
+  if (!baseViewType)
+    return parser->emitError(parser->getNameLoc(),
+                             "view type expected for first type");
+  if (indexingsInfo.size() != baseViewType.getRank())
+    return parser->emitError(parser->getNameLoc(), "expected ")
+           << baseViewType.getRank() << " indexings";
+  ViewType viewType = types.back().dyn_cast<ViewType>();
+  if (!viewType)
+    return parser->emitError(parser->getNameLoc(), "view type expected");
+
+  ArrayRef<Type> indexingTypes =
+      ArrayRef<Type>(types).drop_front(1).drop_back(1);
+  if (indexingTypes.size() != baseViewType.getRank())
+    return parser->emitError(parser->getNameLoc(), "expected ")
+           << baseViewType.getRank() << " indexing types";
+  return failure(
+      parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
+      (!indexingsInfo.empty() &&
+       parser->resolveOperands(indexingsInfo, indexingTypes,
+                               indexingsInfo.front().location,
+                               result->operands)) ||
+      parser->addTypeToList(viewType, result->types));
+}
+
+// A SliceOp prints as:
+//
+// ```{.mlir}
+//   linalg.slice %0[%1, %2] :
+//     !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
+// ssa-value each holding a range.
+void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getBaseView() << "[";
+  interleave(
+      getIndexings().begin(), getIndexings().end(), [p](Value *v) { *p << *v; },
+      [p]() { *p << ", "; });
+  *p << "] : " << getBaseViewType();
+  for (auto indexing : getIndexings()) {
+    *p << ", " << indexing->getType();
+  }
+  *p << ", " << getType();
+}
+
+ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
+  return cast<ViewOp>(getOperand(0)->getDefiningOp());
+}
+
+ViewType mlir::linalg::SliceOp::getBaseViewType() {
+  return getOperand(0)->getType().cast<ViewType>();
+}
+
+SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
+  llvm::SmallVector<Value *, 8> res;
+  for (auto *operand : getIndexings()) {
+    if (!operand->getType().isa<IndexType>()) {
+      res.push_back(operand);
+    }
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// StoreOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::StoreOp::build(Builder *b, OperationState *result,
+                                  Value *valueToStore, Value *view,
+                                  ArrayRef<Value *> indices) {
+  result->addOperands(valueToStore);
+  result->addOperands(view);
+  result->addOperands(indices);
+}
+
+// A StoreOp prints as:
+//
+// ```{.mlir}
+//    linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::StoreOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getValueToStore();
+  *p << ", " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType viewInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  ViewType viewType;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(viewInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(viewType) ||
+      parser->resolveOperand(storeValueInfo, viewType.getElementType(),
+                             result->operands) ||
+      parser->resolveOperand(viewInfo, viewType, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands));
+}
+
+LogicalResult mlir::linalg::StoreOp::verify() {
+  if (getNumOperands() < 2)
+    return emitOpError("expected a value to store and a view");
+
+  // Second operand is a memref type.
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("second operand must be a view");
+
+  // First operand must have same type as memref element type.
+  if (getValueToStore()->getType() != viewType.getElementType())
+    return emitOpError("first operand must have same element type as the view");
+
+  if (getNumOperands() != 2 + viewType.getRank())
+    return emitOpError("store index operand count not equal to view rank");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to store must have 'index' type");
+
+  return success();
+}
+
+///////////////////// Operations defined with Tablegen /////////////////////////
+// For such operations that do not correspond to library calls (i.e. defined in
+// LinalgOps.td), we define an overloaded `print` function and a
+// parse`className` function.
+
+//===----------------------------------------------------------------------===//
+// BufferAllocOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, BufferAllocOp op) {
+  *p << op.getOperationName() << " ";
+  if (!llvm::empty(op.size()))
+    *p << *op.getOperand(0);
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferAllocOp(OpAsmParser *parser,
+                                      OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
+  BufferType bufferType;
+  auto indexTy = parser->getBuilder().getIndexType();
+  if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
+    return failure();
+  if (sizeInfo.empty())
+    return parser->addTypeToList(bufferType, result->types);
+  return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
+                 parser->addTypeToList(bufferType, result->types));
+}
+
+static LogicalResult verify(BufferAllocOp op) {
+  if (!op.getBufferType().hasConstantSize()) {
+    if (llvm::size(op.size()) != 1 ||
+        !op.getOperand(0)->getType().isa<IndexType>())
+      return op.emitOpError(
+          "one operand of type index expected for dynamic buffer");
+  } else { // op.getBufferType().hasConstantSize()
+    if (!llvm::empty(op.size()))
+      return op.emitOpError("unexpected static buffer operand");
+    if (op.getBufferType().getBufferSize().getValue() <= 0)
+      return op.emitOpError("expected nonnegative static buffer size");
+  }
+  if (!VectorType::isValidElementType(op.getElementType()) &&
+      !op.getElementType().isa<VectorType>())
+    return op.emitOpError("unsupported buffer element type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, BufferDeallocOp op) {
+  *p << op.getOperationName() << " " << *op.buffer();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType bufferInfo;
+  BufferType bufferType;
+  if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
+    return failure();
+  return parser->resolveOperands(bufferInfo, bufferType, result->operands);
+}
+
+static void print(OpAsmPrinter *p, BufferSizeOp op) {
+  *p << op.getOperationName() << " " << *op.getOperand();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getOperand()->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferSizeOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBufferSizeOp(OpAsmParser *parser,
+                                     OperationState *result) {
+  OpAsmParser::OperandType op;
+  Type type;
+  return failure(parser->parseOperand(op) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(op, type, result->operands) ||
+                 parser->addTypeToList(parser->getBuilder().getIndexType(),
+                                       result->types));
+}
+
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::DimOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyDimOp>(context);
+}
+
+static void print(OpAsmPrinter *p, linalg::DimOp op) {
+  *p << op.getOperationName() << " " << *op.getOperand() << ", "
+     << op.getIndex();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+  *p << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType operandInfo;
+  IntegerAttr indexAttr;
+  Type type;
+  Type indexType = parser->getBuilder().getIndexType();
+  return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
+                 parser->parseAttribute(indexAttr, indexType, "index",
+                                        result->attributes) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(operandInfo, type, result->operands) ||
+                 parser->addTypeToList(indexType, result->types));
+}
+
+//===----------------------------------------------------------------------===//
+// GenericOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, GenericOp op) {
+  auto attrNames = op.linalgTraitAttrNames();
+  llvm::StringSet<> linalgTraitAttrsSet;
+  linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
+  SmallVector<NamedAttribute, 8> attrs;
+  for (auto attr : op.getAttrs()) {
+    if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
+      attrs.push_back(attr);
+  }
+  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
+  *p << op.getOperationName() << " " << dictAttr << " ";
+  p->printOperands(op.getOperands());
+  if (!op.region().empty())
+    p->printRegion(op.region());
+  p->printOptionalAttrDict(op.getAttrs(), attrNames);
+  *p << ": ";
+  interleaveComma(op.getOperandTypes(), *p);
+}
+
+static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo;
+  DictionaryAttr dictAttr;
+  // Parse the core linalg traits that must check into a dictAttr.
+  // The name is unimportant as we will overwrite result->attributes.
+  // The core linalg traits must contain the information necessary to pass the
+  // verifier.
+  if (parser->parseAttribute(dictAttr, "_", result->attributes) ||
+      parser->parseOperandList(operandsInfo))
+    return failure();
+  result->attributes.assign(dictAttr.getValue().begin(),
+                            dictAttr.getValue().end());
+
+  Region &region = *result->addRegion();
+  SmallVector<Type, 8> operandTypes, regionTypes;
+  // Optional attributes may be added.
+  // Either Optional "fun" attribute or region must be specified.
+  if (!dictAttr.get("fun") &&
+      parser->parseOptionalRegion(region, regionOperandsInfo, regionTypes))
+    return failure();
+  if (parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(operandTypes))
+    return failure();
+  return parser->resolveOperands(operandsInfo, operandTypes,
+                                 parser->getCurrentLocation(),
+                                 result->operands);
+}
+
+static LogicalResult verify(GenericOp op) {
+  auto nInputViews = op.getNumInputs();
+  auto nViews = op.getNumInputsAndOutputs();
+  if (nViews != llvm::size(op.views()))
+    return op.emitError("op expected exactly ") << nViews << " view operands";
+
+  auto &region = op.region();
+  auto funOp = op.getFunction();
+  auto funType = funOp ? funOp.getType() : FunctionType();
+  if (!region.empty()) {
+    if (region.getBlocks().size() != 1)
+      return op.emitError("op expected region with 1 block");
+
+    auto &block = region.getBlocks().front();
+    if (block.getNumArguments() != nViews)
+      return op.emitError(
+          "op expected number of block arguments to match number of views");
+
+    for (unsigned i = 0; i < nViews; ++i) {
+      auto viewType = op.getViewType(i);
+      if (viewType.getElementType() != block.getArgument(i)->getType())
+        return op.emitError("op expected block argument ")
+               << i << " of the same type as elemental type of "
+               << ((i < nInputViews) ? "input " : "output ")
+               << "view: " << viewType;
+    }
+  } else {
+    if (!funOp || !funOp.getType())
+      return op.emitError(
+          "op expected fun attribute to refer to a defined symbol");
+    if (funType.getNumInputs() != nViews)
+      return op.emitError("op expected fun arguments to match number of views");
+    if (funType.getNumResults() != op.getNumOutputs())
+      return op.emitError(
+          "op expected fun results to match number of output views");
+  }
+
+  auto nLoops = op.getNumLoops();
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.reserve(op.indexing_maps().size());
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto m = en.value().cast<AffineMapAttr>().getValue();
+    indexingMaps.push_back(m); // Save reference to map for further checks.
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+
+    if (m.getNumSymbols() != 0)
+      return op.emitError("op expected indexing_map #")
+             << idx << " to have no symbols";
+
+    if (m.getNumDims() != nLoops)
+      return op.emitError("op expected indexing_map #")
+             << idx << " to have " << nLoops
+             << " dim(s) to match the number of loops";
+
+    if (m.getNumResults() == 1 && view.getRank() == 0) {
+      auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>();
+      if (!cst || cst.getValue() != 0)
+        return op.emitError("op expected indexing_map #")
+               << idx << " to be 0 to match 0-D view: " << view;
+    }
+
+    if (m.getNumResults() != view.getRank())
+      return op.emitError("op expected indexing_map #")
+             << idx << " results to match view rank: " << view;
+
+    if (funType) {
+      if (funType.getInput(idx) != view.getElementType())
+        return op.emitError("op expected fun argument ")
+               << idx
+               << " to match view element type: " << view.getElementType();
+
+      if (idx >= nInputViews)
+        if (funType.getResult(idx - nInputViews) != view.getElementType())
+          return op.emitError("op expected fun result ")
+                 << idx << " to match output view element type: "
+                 << view.getElementType();
+    }
+  }
+
+  auto concatMap = concatAffineMaps(indexingMaps);
+  auto aggregateMap = inversePermutation(concatMap);
+  if (!aggregateMap)
+    return op.emitError("op expected the concatenation of maps in indexing_map "
+                        "to be invertible");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
+                                 Value *buffer, ArrayRef<Value *> ranges,
+                                 Type resultType,
+                                 ArrayRef<NamedAttribute> attrs) {
+  if (!resultType) {
+    Type elementType = buffer->getType().cast<BufferType>().getElementType();
+    resultType = ViewType::get(b->getContext(), elementType, ranges.size());
+  }
+  build(b, result, resultType, buffer, ranges);
+  result->addAttributes(attrs);
+}
+
+static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType bufferInfo;
+  SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
+  Type bType, vType;
+  if (parser->parseOperand(bufferInfo) ||
+      parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColon() || parser->parseType(bType) ||
+      parser->parseArrow() || parser->parseType(vType)) {
+    return failure();
+  }
+
+  BufferType bufferType = bType.dyn_cast<BufferType>();
+  if (!bufferType) {
+    return parser->emitError(parser->getNameLoc(), "buffer type expected");
+  }
+
+  ViewType viewType = vType.dyn_cast<ViewType>();
+  if (!viewType)
+    return parser->emitError(parser->getNameLoc(), "view type expected");
+  if (viewType.getRank() != rangesInfo.size())
+    return parser->emitError(parser->getNameLoc(), "expected")
+           << viewType.getRank() << " range ranges";
+  return failure(
+      parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
+      (!rangesInfo.empty() &&
+       parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
+                               result->operands)) ||
+      parser->addTypeToList(viewType, result->types));
+}
+
+// A ViewOp prints as:
+//
+// ```{.mlir}
+//   linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
+// holding a range.
+static void print(OpAsmPrinter *p, ViewOp op) {
+  *p << op.getOperationName() << " " << *op.buffer() << "[";
+  interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
+  *p << "] : " << op.buffer()->getType() << " -> " << op.getType();
+}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseYieldOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> opInfo;
+  SmallVector<Type, 2> types;
+  llvm::SMLoc loc = parser->getCurrentLocation();
+  return failure(parser->parseOperandList(opInfo) ||
+                 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
+                 parser->resolveOperands(opInfo, types, loc, result->operands));
+}
+
+static void print(OpAsmPrinter *p, YieldOp op) {
+  *p << op.getOperationName();
+  if (op.getNumOperands() > 0) {
+    *p << ' ';
+    p->printOperands(op.operand_begin(), op.operand_end());
+    *p << " : ";
+    interleaveComma(op.getOperands(), *p,
+                    [&](Value *e) { p->printType(e->getType()); });
+  }
+}
+
+static LogicalResult verify(YieldOp op) {
+  auto *parentOp = op.getParentOp();
+  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
+    return op.emitOpError("op expected single non-empty parent region");
+
+  auto genericOp = dyn_cast<GenericOp>(parentOp);
+  if (!genericOp)
+    return op.emitOpError("op expected '")
+           << GenericOp::getOperationName() << "' parent op";
+
+  // The operand number and types must match the view element types.
+  auto nOutputViews = genericOp.getNumOutputs();
+  if (op.getNumOperands() != nOutputViews)
+    return op.emitOpError("op expected ")
+           << nOutputViews << " operand to match enclosing linalg.generic op";
+
+  for (unsigned i = 0; i != nOutputViews; ++i) {
+    auto elementType = genericOp.getOutputViewType(i).getElementType();
+    if (op.getOperand(i)->getType() != elementType)
+      return op.emitError("type of return operand ")
+             << i << " (" << op.getOperand(i)->getType()
+             << ") doesn't match view element type (" << elementType << ")";
+  }
+  return success();
+}
+
+static void print(OpAsmPrinter *p, SubViewOp op) {
+  *p << op.getOperationName() << " " << *op.getOperand(0) << "[";
+  auto ranges = op.getRanges();
+  interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
+    *p << *i.min << ", " << *i.max << ", " << *i.step;
+  });
+  *p << "]";
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getViewType();
+}
+
+//===----------------------------------------------------------------------===//
+// SubViewOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType inputView, resultView;
+  Type viewType;
+  if (parser->parseOperand(inputView))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 12> ops;
+  // TODO(ntv) evolve parsing from
+  //    linalg.subview %0[%1, %2, %3, %4, %5, %6]
+  // to something resembling
+  //    linalg.subview %0[%1:%2:%3][%4:%5:%6]
+  if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(viewType))
+    return failure();
+
+  auto indexTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->resolveOperand(inputView, viewType, result->operands) ||
+      parser->resolveOperands(ops, indexTy, result->operands) ||
+      parser->addTypeToList(viewType, result->types));
+}
+
+/////// Operations corresponding to library calls defined with Tablegen ////////
+// For such operations correspond to library calls (i.e. defined in
+// LinalgLibraryOps.td), we define an overloaded `print` function and a
+// parse`className` function.
+
+// A LinalgLibraryOp prints as:
+//
+// ```{.mlir}
+//   concrete_op_name (ssa-inputs, ssa-outputs) : view-types
+// ```
+//
+// for example:
+//
+// ```
+//   linalg.matmul(%0, %1, %2) :
+//     !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// ```
+//
+// Where %0, %1 and %2 are ssa-values of type ViewType.
+static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
+  assert(op->getAbstractOperation() && "unregistered operation");
+  *p << op->getName().getStringRef() << "(";
+  interleave(
+      op->getOperands().begin(), op->getOperands().end(),
+      [&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
+  *p << ")";
+  p->printOptionalAttrDict(op->getAttrs());
+  *p << " : ";
+  interleave(
+      op->getOperands().begin(), op->getOperands().end(),
+      [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
+}
+
+static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<Type, 3> types;
+  return failure(parser->parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonTypeList(types) ||
+                 parser->resolveOperands(ops, types, parser->getNameLoc(),
+                                         result->operands));
+}
+
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputViewType(0);
+  auto fillType = op.getValue()->getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  return success();
+}
+
+static LogicalResult verify(CopyOp op) {
+  auto outputViewType = op.getOutputViewType(0);
+  auto inputViewType = op.getInputViewType(0);
+  if (inputViewType.getElementType() != outputViewType.getElementType())
+    return op.emitOpError("expects views of the same type");
+  if (inputViewType.getRank() != outputViewType.getRank())
+    return op.emitOpError("expects views of the same rank");
+  auto rank = op.getNumParallelLoops();
+  auto inputPermutationMap = op.inputPermutation();
+  if (inputPermutationMap) {
+    if (inputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional input_permutation map of rank ")
+             << rank;
+    if (!inputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional input_permutation map to be a permutation");
+  }
+  auto outputPermutationMap = op.outputPermutation();
+  if (outputPermutationMap) {
+    if (outputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional output_permutation map of rank ")
+             << rank;
+    if (!outputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional output_permutation map to be a permutation");
+  }
+  if (rank == 0 && inputPermutationMap)
+    return op.emitOpError("expected no input permutation when rank == 0");
+  if (rank == 0 && outputPermutationMap)
+    return op.emitOpError("expected no output permutation when rank == 0");
+  return success();
+}
+
+static LogicalResult
+verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
+  auto strideOrDilation = isStride ? "stride" : "dilation";
+  if (attrs.size() != op.getNumWindowLoops())
+    return op.emitOpError("expects num ")
+           << strideOrDilation
+           << "s equal to number of window dimensions: " << attrs.size()
+           << " vs " << op.getNumWindowLoops();
+  return success();
+}
+
+static LogicalResult verify(ConvOp op) {
+  auto oType = op.output()->getType().cast<ViewType>();
+  auto fType = op.filter()->getType().cast<ViewType>();
+  auto iType = op.input()->getType().cast<ViewType>();
+  if (oType.getElementType() != iType.getElementType() ||
+      oType.getElementType() != fType.getElementType())
+    return op.emitOpError("expects view elemental types to match");
+  if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
+    return op.emitOpError("expects view ranks to match");
+  if (auto strides = op.strides()) {
+    if (failed(
+            verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
+      return failure();
+  }
+  if (auto dilations = op.dilations()) {
+    if (failed(verifyStrideOrDilation(op, dilations->getValue(),
+                                      /*isStride=*/false)))
+      return failure();
+  }
+  return success();
+}
+
+llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os,
+                                            SubViewOp::Range &range) {
+  return os << "range " << *range.min << ":" << *range.max << ":"
+            << *range.step;
+}
+
+namespace mlir {
+namespace linalg {
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+
+} // namespace linalg
+} // namespace mlir
+
+static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap,
+                                      unsigned rank, MLIRContext *context) {
+  if (maybeMap)
+    return maybeMap.getValue();
+  if (rank == 0)
+    return AffineMap();
+  return AffineMap::getMultiDimIdentityMap(rank, context);
+}
+
+// Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num)
+// and increments `curIdx` to `curIdx + num`.
+static SmallVector<AffineExpr, 4>
+makeAffineDimExprs(unsigned num, unsigned &curIdx, MLIRContext *context) {
+  SmallVector<AffineExpr, 4> res;
+  res.reserve(num);
+  for (unsigned i = 0; i < num; ++i)
+    res.push_back(getAffineDimExpr(curIdx++, context));
+  return res;
+}
+
+static SmallVector<AffineExpr, 4>
+weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> a,
+                       ArrayRef<AffineExpr> b) {
+  assert(a.size() == b.size());
+  SmallVector<AffineExpr, 4> res;
+  res.reserve(a.size());
+  for (unsigned i = 0, e = a.size(); i < e; ++i) {
+    res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]);
+  }
+  return res;
+}
+
+static SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
+                                         ArrayRef<AffineExpr> b) {
+  SmallVector<AffineExpr, 4> res;
+  res.reserve(a.size() + b.size());
+  res.assign(a.begin(), a.end());
+  res.append(b.begin(), b.end());
+  return res;
+}
+
+// Note: both functions below would completely disappear with a simple tensor
+// kernel language.
+//
+// Ideally this should all be Tablegen'd but there is no good story for
+// AffineMap for now.
+SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
+  MLIRContext *context = op->getContext();
+  if (auto copyOp = dyn_cast<CopyOp>(op)) {
+    // I(input_perm(ivs)) -> O(output_perm(ivs))
+    auto maybeInputMap = copyOp.inputPermutation();
+    auto maybeOutputMap = copyOp.outputPermutation();
+    unsigned inputRank = copyOp.getInputViewType(0).getRank();
+    unsigned outputRank = copyOp.getOutputViewType(0).getRank();
+    return SmallVector<AffineMap, 4>{
+        extractOrIdentityMap(maybeInputMap, inputRank, context),
+        extractOrIdentityMap(maybeOutputMap, outputRank, context)};
+  }
+  if (auto fillOp = dyn_cast<FillOp>(op)) {
+    // filling_value -> O(ivs)
+    unsigned rank = fillOp.getNumParallelLoops();
+    return SmallVector<AffineMap, 4>{
+        extractOrIdentityMap(llvm::None, rank, context)};
+  }
+  auto i = getAffineDimExpr(0, context);
+  auto j = getAffineDimExpr(1, context);
+  auto k = getAffineDimExpr(2, context);
+  if (isa<DotOp>(op))
+    // A(r_i) * B(r_i) -> C()
+    return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
+                                     AffineMap::get(1, 0, {i}), AffineMap()};
+  if (isa<MatvecOp>(op))
+    //   A(i, r_j) * B(r_j) -> C(i)
+    return SmallVector<AffineMap, 4>{AffineMap::get(2, 0, {i, j}),
+                                     AffineMap::get(2, 0, {j}),
+                                     AffineMap::get(2, 0, {i})};
+  if (isa<MatmulOp>(op))
+    //   A(i, r_k) * B(r_k, j) -> C(i, j)
+    return SmallVector<AffineMap, 4>{AffineMap::get(3, 0, {i, k}),
+                                     AffineMap::get(3, 0, {k, j}),
+                                     AffineMap::get(3, 0, {i, j})};
+  if (auto convOp = dyn_cast<ConvOp>(op)) {
+    //   F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
+    //     O(b, x0, ..., xN-1, k)
+    // for N equal to `nWindow`.
+    auto nWin = convOp.getNumWindowLoops();
+    assert(nWin > 0 && "expected at least one window dimension");
+    unsigned idx = 0;
+    // In the following, AffineDimExprs are indexed in loop order:
+    //   [ b, xs, k,           q,                     zs]
+    //    parallels     non-window reductions     windows
+    //
+    // Parallel dims are exactly the dimensions indexing `output`:
+    //     output[b, x[0], ..., x[N-1], k]; i.e.
+    //  * batch dimensions (bs with #bs = 1 for now)
+    //  * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks)
+    //  * output filter dimensions (ks with #ks = 1 for now)
+    auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context);
+    auto xs = makeAffineDimExprs(nWin, idx, context);
+    auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx,
+                                 context);
+    // Non-window reduction dim: sum_{z[0], ..., z[N-1], q}
+    auto qs =
+        makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx, context);
+    // Window reduction dims: sum_{z[0], ..., z[N-1], q}
+    auto zs = makeAffineDimExprs(nWin, idx, context);
+    // Construct the weighedSum expression.
+    auto ws = weightedConvInputIndex(convOp, xs, zs);
+    return SmallVector<AffineMap, 4>{
+        // filter[z[0], ..., z[N-1], q, k]
+        AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
+        // input[b,
+        //       x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
+        //       q]
+        AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
+        // output[b, x[0], ..., x[N-1], k]
+        AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
+  } else if (auto genericOp = dyn_cast<GenericOp>(op)) {
+    SmallVector<AffineMap, 4> res;
+    unsigned nViews = genericOp.getNumInputsAndOutputs();
+    res.reserve(nViews);
+    for (unsigned i = 0, e = nViews; i < e; ++i) {
+      res.push_back(genericOp.getIndexingMap(i));
+    }
+    return res;
+  }
+  llvm_unreachable("Missing loopToOperandRangesMaps for op");
+}
+
+static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
+  if (auto view = t.dyn_cast<ViewType>()) {
+    ss << "view";
+    for (unsigned i = 0, e = view.getRank(); i < e; ++i)
+      ss << "x";
+    appendMangledType(ss, view.getElementType());
+  } else if (auto vec = t.dyn_cast<VectorType>()) {
+    ss << "vector";
+    interleave(
+        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
+    appendMangledType(ss, vec.getElementType());
+  } else if (t.isIntOrIndexOrFloat()) {
+    ss << t;
+  } else {
+    llvm_unreachable("Invalid type for linalg library name mangling");
+  }
+}
+
+std::string mlir::linalg::generateLibraryCallName(Operation *op) {
+  assert(isa<LinalgOp>(op));
+  std::string name(op->getName().getStringRef().str());
+  name.reserve(128);
+  std::replace(name.begin(), name.end(), '.', '_');
+  llvm::raw_string_ostream ss(name);
+  ss << "_";
+  auto types = op->getOperandTypes();
+  interleave(
+      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
+      [&]() { ss << "_"; });
+  return ss.str();
+}
diff --git a/third_party/mlir/lib/Linalg/IR/LinalgTypes.cpp b/third_party/mlir/lib/Linalg/IR/LinalgTypes.cpp
new file mode 100644
index 0000000..ca54c33
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/IR/LinalgTypes.cpp
@@ -0,0 +1,268 @@
+//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the Linalg dialect types and dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addTypes<BufferType, RangeType, ViewType>();
+  addOperations<LoadOp, RangeOp, StoreOp, SliceOp>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+      >();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+      >();
+}
+
+struct mlir::linalg::BufferTypeStorage : public TypeStorage {
+  /// Underlying Key type to transport the payload needed to construct a custom
+  /// type in a generic way.
+  struct Key {
+    Key(Type elementType, int64_t bufferSize = -1)
+        : elementType(elementType), bufferSize(bufferSize) {}
+    Type elementType;
+    int64_t bufferSize;
+  };
+  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
+  using KeyTy = Key;
+
+  /// Construction in the llvm::BumpPtrAllocator given a key.
+  static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const Key &key) {
+    return new (allocator.allocate<BufferTypeStorage>()) BufferTypeStorage(key);
+  }
+
+  /// Equality operator for hashing.
+  bool operator==(const Key &key) const {
+    return elementType == key.elementType && bufferSize == key.bufferSize;
+  }
+
+  /// Hashing for unique'ing.
+  static unsigned hashKey(const Key &key) {
+    return llvm::hash_combine(key.elementType, key.bufferSize);
+  }
+
+  Type getElementType() { return elementType; }
+  bool hasConstantSize() { return bufferSize >= 0; }
+  Optional<int64_t> getBufferSize() {
+    if (hasConstantSize()) {
+      return bufferSize;
+    }
+    return llvm::None;
+  }
+
+private:
+  BufferTypeStorage(const Key &key)
+      : elementType(key.elementType), bufferSize(key.bufferSize) {}
+
+  Type elementType;
+  int64_t bufferSize;
+};
+
+BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType,
+                                         int64_t bufferSize) {
+  return Base::get(context, LinalgTypes::Buffer, elementType, bufferSize);
+}
+
+Type mlir::linalg::BufferType::getElementType() {
+  return getImpl()->getElementType();
+}
+
+bool mlir::linalg::BufferType::hasConstantSize() {
+  return getImpl()->hasConstantSize();
+}
+
+Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
+  return getImpl()->getBufferSize();
+}
+
+Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
+                                            Location loc) const {
+  StringRef origSpec = spec;
+  MLIRContext *context = getContext();
+  if (spec == "range")
+    return RangeType::get(getContext());
+  else if (spec.consume_front("buffer")) {
+    if (spec.consume_front("<") && spec.consume_back(">")) {
+      StringRef sizeSpec, typeSpec;
+      std::tie(sizeSpec, typeSpec) = spec.split('x');
+      if (typeSpec.empty()) {
+        emitError(loc, "expected 'x' followed by element type");
+        return Type();
+      }
+      // Check for '?'
+      int64_t bufferSize = -1;
+      if (!sizeSpec.consume_front("?")) {
+        if (sizeSpec.consumeInteger(10, bufferSize)) {
+          emitError(loc, "expected buffer size to be an unsigned integer");
+          return Type();
+        }
+      }
+      if (!sizeSpec.empty()) {
+        emitError(loc, "unexpected token '") << sizeSpec << "'";
+      }
+
+      typeSpec = typeSpec.trim();
+      auto t = mlir::parseType(typeSpec, context);
+      if (!t) {
+        emitError(loc, "invalid type specification: '") << typeSpec << "'";
+        return Type();
+      }
+      return (bufferSize == -1 ? BufferType::get(getContext(), t)
+                               : BufferType::get(getContext(), t, bufferSize));
+    }
+  } else if (spec.consume_front("view")) {
+    if (spec.consume_front("<") && spec.consume_back(">")) {
+      // Just count the number of ? to get the rank.
+      unsigned rank = 0;
+      for (unsigned i = 0, e = spec.size(); i < e; ++i) {
+        if (spec.consume_front("?")) {
+          ++rank;
+          if (!spec.consume_front("x")) {
+            emitError(loc, "expected a list of '?x' dimension specifiers: ")
+                << spec;
+            return Type();
+          }
+        }
+      }
+      if (auto t = mlir::parseType(spec, context))
+        return ViewType::get(context, t, rank);
+    }
+  }
+  return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
+}
+
+struct mlir::linalg::ViewTypeStorage : public TypeStorage {
+  /// Underlying Key type to transport the payload needed to construct a custom
+  /// type in a generic way.
+  struct Key {
+    Key(Type elementType, unsigned rank)
+        : elementType(elementType), rank(rank) {}
+    Type elementType;
+    unsigned rank;
+  };
+  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
+  using KeyTy = Key;
+
+  /// Construction in the llvm::BumpPtrAllocator given a key.
+  static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
+                                    const Key &key) {
+    return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
+  }
+
+  /// Equality operator for hashing.
+  bool operator==(const Key &key) const {
+    return elementType == key.elementType && rank == key.rank;
+  }
+
+  /// Hashing for unique'ing.
+  static unsigned hashKey(const Key &key) {
+    return llvm::hash_combine(key.elementType, key.rank);
+  }
+
+  unsigned getRank() { return rank; };
+  Type getElementType() { return elementType; };
+
+private:
+  ViewTypeStorage(const Key &key)
+      : elementType(key.elementType), rank(key.rank) {}
+
+  Type elementType;
+  unsigned rank;
+};
+
+ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType,
+                                     unsigned rank) {
+  return Base::get(context, LinalgTypes::View, elementType, rank);
+}
+
+Type mlir::linalg::ViewType::getElementType() {
+  return getImpl()->getElementType();
+}
+
+unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); }
+
+/// BufferType prints as "buffer<element_type>".
+static void print(BufferType bt, raw_ostream &os) {
+  os << "buffer<";
+  auto bs = bt.getBufferSize();
+  if (bs) {
+    os << bs.getValue();
+  } else {
+    os << "?";
+  }
+  os << "x" << bt.getElementType() << ">";
+}
+
+/// RangeType prints as just "range".
+static void print(RangeType rt, raw_ostream &os) { os << "range"; }
+
+/// ViewType prints as:
+///
+/// ```{.mlir}
+///   view<?x?xf32>
+/// ```
+///
+/// or
+///
+/// ```{.mlir}
+///   view<?xf32>
+/// ```
+///
+/// for 0-D views (a.k.a pointer to a scalar value).
+static void print(mlir::linalg::ViewType rt, raw_ostream &os) {
+  os << "view<";
+  for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
+    os << "?x";
+  }
+  os << rt.getElementType();
+  os << ">";
+}
+
+void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
+  switch (type.getKind()) {
+  default:
+    llvm_unreachable("Unhandled Linalg type");
+  case LinalgTypes::Buffer:
+    print(type.cast<BufferType>(), os);
+    break;
+  case LinalgTypes::Range:
+    print(type.cast<RangeType>(), os);
+    break;
+  case LinalgTypes::View:
+    print(type.cast<ViewType>(), os);
+    break;
+  }
+}
diff --git a/third_party/mlir/lib/Linalg/LinalgRegistration.cpp b/third_party/mlir/lib/Linalg/LinalgRegistration.cpp
new file mode 100644
index 0000000..cf5bd8f
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/LinalgRegistration.cpp
@@ -0,0 +1,25 @@
+//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+// Static initialization for LinalgOps dialect registration.
+static DialectRegistration<LinalgDialect> LinalgOps;
diff --git a/third_party/mlir/lib/Linalg/Transforms/Fusion.cpp b/third_party/mlir/lib/Linalg/Transforms/Fusion.cpp
new file mode 100644
index 0000000..4864f39
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Transforms/Fusion.cpp
@@ -0,0 +1,363 @@
+//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the linalg dialect Fusion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-fusion"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using llvm::dbgs;
+
+/// Implements a simple high-level fusion pass of linalg library operations.
+///
+/// In each block, linalg ops are processed in reverse textual order.
+/// Given a linalg op, fusion occurs by:
+///   1. tiling the op by a given multi-dimensional tile size;
+///   2. inspecting the linalg ops that write into the views read by the op in
+///      step 1. This uses the SSA value of the views to determine producer-
+///      consumer dependences: only identical SSA views are considered for
+///      fusion at this point;
+///   3. greedily fuse the producing linalg ops into the consuming loop tiles;
+///   4. inspect the fused ops and determine whether they have other remaining
+///      LinalgOp uses. If not, then erase the original producing linalg op.
+///
+/// More advanced use cases, analyses as well as profitability heuristics are
+/// left for future work.
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::list<unsigned> clTileSizes(
+    "linalg-fusion-tile-sizes",
+    llvm::cl::desc(
+        "Tile sizes by which to tile linalg operations during linalg fusion"),
+    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+    llvm::cl::cat(clOptionsCategory));
+
+// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
+// a subset of the original loop ranges of `op`.
+// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
+// to the `loopRanges` in order to obtain view ranges.
+static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
+                                    ArrayRef<SubViewOp::Range> loopRanges,
+                                    OperationFolder &state) {
+  ScopedContext scope(b, loc);
+
+  auto maps = loopToOperandRangesMaps(op);
+  SmallVector<Value *, 8> clonedViews;
+  clonedViews.reserve(op.getNumInputsAndOutputs());
+  // Iterate over the inputs and outputs in order.
+  // Extract the subranges from the linearized ranges.
+  SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
+  for (auto en : llvm::enumerate(ios)) {
+    unsigned idx = en.index();
+    auto map = maps[idx];
+    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+    Value *view = en.value();
+    SmallVector<SubViewOp::Range, 8> viewRanges(map.getNumResults());
+    for (auto en2 : llvm::enumerate(map.getResults())) {
+      unsigned d = en2.index();
+      // loopToOperandRangesMaps are permutations-only.
+      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
+      viewRanges[d] = loopRanges[loopPos];
+      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+                        << "\t"
+                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
+    }
+    // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
+    clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges));
+  }
+  auto operands = getAssumedNonViewOperands(op);
+  clonedViews.append(operands.begin(), operands.end());
+  return op.create(b, loc, clonedViews, op.getAttrs());
+}
+
+struct ViewDimension {
+  Value *view;
+  unsigned dimension;
+};
+
+static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
+  auto maps = loopToOperandRangesMaps(op);
+  SmallVector<Value *, 8> clonedViews;
+  clonedViews.reserve(op.getNumInputsAndOutputs());
+  // Iterate over the inputs and outputs in order.
+  // Extract the subranges from the linearized ranges.
+  SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
+  for (auto en : llvm::enumerate(ios)) {
+    unsigned idx = en.index();
+    auto map = maps[idx];
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
+    Value *view = en.value();
+    SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
+    for (auto en2 : llvm::enumerate(map.getResults())) {
+      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+                          << "\n");
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+                          << "\n");
+        return ViewDimension{view, static_cast<unsigned>(en2.index())};
+      }
+    }
+  }
+  llvm_unreachable("Expect to be able to extract a view defining loop range");
+}
+
+static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
+                               LinalgOp consumer, LinalgOp tiledConsumer,
+                               OperationFolder &state) {
+  auto maybeConsumerIdx = consumer.getIndexOfInput(producedView);
+  if (!maybeConsumerIdx.hasValue())
+    return llvm::None;
+  unsigned consumerIdx = maybeConsumerIdx.getValue();
+
+  auto maybeProducerIdx = producer.getIndexOfOutput(producedView);
+  if (!maybeProducerIdx.hasValue())
+    return llvm::None;
+  unsigned producerIdx = maybeProducerIdx.getValue();
+
+  // If the view is the same between consumer and tiledConsumer, this means we
+  // don't have loops and the producer cannot be fused at this level.
+  if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
+    return llvm::None;
+
+  auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
+      tiledConsumer.getInput(consumerIdx)->getDefiningOp());
+
+  // If we don't have a slice, this also means we don't have loops and the
+  // producer cannot be fused at this level.
+  if (!tiledConsumerSubView)
+    return llvm::None;
+
+  // loopToOperandRangesMaps are permutations-only by construction:
+  //   we can always identify a data dimension with a (at least one) loop
+  //   dimension.
+  AffineMap producerMap =
+      loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
+  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
+                    << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
+  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
+                    << ", producer map: " << producerMap << "\n");
+
+  unsigned nPar = producer.getNumParallelLoops();
+  unsigned nRed = producer.getNumReductionLoops();
+  unsigned nWin = producer.getNumWindowLoops();
+  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
+
+  // Iterate over dimensions identified by the producer map for `producerIdx`.
+  // This defines a subset of the loop ranges that we need to complete later.
+  for (auto en : llvm::enumerate(producerMap.getResults())) {
+    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+    loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
+  }
+
+  OpBuilder b(tiledConsumer.getOperation());
+  auto loc = tiledConsumer.getLoc();
+  // Iterate over all dimensions. For the dimensions not identified by the
+  // producer map for `producerIdx`, we need to explicitly compute the view that
+  // defines the loop ranges using the `producer`.
+  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+    if (loopRanges[i].min)
+      LLVM_DEBUG(llvm::dbgs()
+                 << "existing LoopRange: " << loopRanges[i] << "\n");
+    else {
+      auto viewDim = getViewDefiningLoopRange(producer, i);
+      loopRanges[i] = SubViewOp::Range{
+          state.create<ConstantIndexOp>(b, loc, 0),
+          linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
+          state.create<ConstantIndexOp>(b, loc, 1)};
+      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+    }
+  }
+
+  return cloneWithLoopRanges(b, loc, producer, loopRanges, state);
+}
+
+// Encode structural fusion safety preconditions.
+// Some of these will be lifted in the future with better analysis.
+static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
+                                          LinalgOp consumer) {
+  // If a producer has multiple outputs, the analysis needs to take the tiling
+  // of other outputs into account.
+  if (producer.getNumOutputs() != 1)
+    return false;
+  // Until subview analysis is available, same SSA value is required for fusion.
+  if (producer.getOutput(0) != readView)
+    return false;
+  // No control-flow divergence supported. Only straightline op fusion allowed.
+  // TODO(ntv) allow fusion when a dominance relation exists.
+  if (producer.getOperation()->getBlock() !=
+      consumer.getOperation()->getBlock())
+    return false;
+  return true;
+}
+
+static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
+  OperationFolder state;
+  DenseSet<Operation *> eraseSet;
+
+  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
+  // 1. Record the linalg ops so we can traverse them in reverse order.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk<LinalgOp>(
+      [&](LinalgOp op) { linalgOps.push_back(op.getOperation()); });
+
+  // 2. Setup the dependences graph, aliases are populated lazily.
+  Aliases aliases;
+  LinalgDependenceGraph G(aliases, linalgOps);
+
+  // 2. For each original linalg op (in reverse order to allow chained
+  // fusions).
+  for (auto *op : llvm::reverse(linalgOps)) {
+    auto consumer = cast<LinalgOp>(op);
+    LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op);
+    // 3. If marked for erasure, it has already been fused. Skip fusing op.
+    if (eraseSet.count(op) > 0) {
+      LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip.");
+      continue;
+    }
+
+    // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op.
+    auto tiledOp = tileLinalgOp(op, tileSizes, state);
+    if (!tiledOp) {
+      LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip.");
+      continue;
+    }
+
+    // 5. For now, we only fuse RAW dependences.
+    SmallVector<Operation *, 8> fusedProducers;
+    SmallVector<Value *, 8> fusedViews;
+    for (auto dependence : G.getDependencesInto(
+             consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+      LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
+                        << *producer.getOperation() << "\n");
+
+      // a. For now we require fusion on identical SSA values, this allows us to
+      // not worry about partial writes etc.
+      // TODO(ntv) support more elaborate fusion with non identical SSA values.
+      auto *view = dependence.indexingView;
+      if (view != dependence.dependentOpView.view) {
+        LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip.");
+        continue;
+      }
+      // b. Make some simple structural checks that alleviate the need for more
+      // complex analyses.
+      if (!isStructurallyFusableProducer(producer, view, op)) {
+        LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
+        continue;
+      }
+      // c. Check for fusion-preventing write that would violate dependences.
+      // `view` is a producer write that cannot bypass any other write or read.
+      bool preventFusion = false;
+      for (auto *op : G.findCoveringDependences(producer, consumer))
+        if (eraseSet.count(op) == 0) {
+          preventFusion = true;
+          LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op);
+          break;
+        }
+      if (preventFusion)
+        continue;
+
+      // 6. Try to fuse `producer` just before `tiledOp`.
+      LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
+
+      auto tOp = tiledOp->op;
+      OpBuilder builder(tOp.getOperation());
+      ScopedContext scope(builder, tOp.getLoc());
+      LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
+      auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
+      if (!maybeFusedProducer) {
+        LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
+        continue;
+      }
+
+      fusedProducers.push_back(producer.getOperation());
+      fusedViews.push_back(view);
+    }
+
+    // 7. If no fusion occurred, or a drop the outer tiled loop which undoes
+    // everything we did.
+    if (fusedProducers.empty()) {
+      tiledOp->loops[0].erase();
+      continue;
+    }
+
+    eraseSet.insert(op);
+    eraseSet.insert(fusedProducers.begin(), fusedProducers.end());
+  }
+
+  for (auto *op : eraseSet)
+    op->erase();
+
+  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
+}
+
+namespace {
+struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
+  LinalgFusionPass() = default;
+  LinalgFusionPass(ArrayRef<int64_t> sizes);
+
+  void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); }
+
+  SmallVector<int64_t, 8> tileSizes;
+};
+} // namespace
+
+LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
+    : LinalgFusionPass() {
+  if (!sizes.empty())
+    this->tileSizes.assign(sizes.begin(), sizes.end());
+}
+
+FunctionPassBase *
+mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) {
+  return new LinalgFusionPass(tileSizes);
+}
+
+static PassRegistration<LinalgFusionPass>
+    pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
+      auto *pass = new LinalgFusionPass();
+      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+      return pass;
+    });
diff --git a/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
new file mode 100644
index 0000000..84452a2
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -0,0 +1,750 @@
+//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::LLVM;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using add = ValueBuilder<mlir::LLVM::AddOp>;
+using addi = ValueBuilder<mlir::AddIOp>;
+using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
+using cmpi = ValueBuilder<mlir::CmpIOp>;
+using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
+using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
+using gep = ValueBuilder<mlir::LLVM::GEPOp>;
+using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
+using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
+using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
+using llvm_load = ValueBuilder<LLVM::LoadOp>;
+using llvm_store = OperationBuilder<LLVM::StoreOp>;
+using llvm_select = ValueBuilder<LLVM::SelectOp>;
+using mul = ValueBuilder<mlir::LLVM::MulOp>;
+using sub = ValueBuilder<mlir::LLVM::SubOp>;
+using undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
+using llvm_return = OperationBuilder<LLVM::ReturnOp>;
+
+template <typename T>
+static LLVMType getPtrToElementType(T containerType,
+                                    LLVMTypeConverter &lowering) {
+  return lowering.convertType(containerType.getElementType())
+      .template cast<LLVMType>()
+      .getPointerTo();
+}
+
+// Convert the given type to the LLVM IR Dialect type.  The following
+// conversions are supported:
+//   - an Index type is converted into an LLVM integer type with pointer
+//     bitwidth (analogous to intptr_t in C);
+//   - an Integer type is converted into an LLVM integer type of the same width;
+//   - an F32 type is converted into an LLVM float type
+//   - a Buffer, Range or View is converted into an LLVM structure type
+//     containing the respective dynamic values.
+static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
+  auto *context = t.getContext();
+  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+                     .cast<LLVM::LLVMType>();
+
+  // A buffer descriptor contains the pointer to a flat region of storage and
+  // the size of the region.
+  //
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t size;
+  // };
+  if (auto bufferType = t.dyn_cast<BufferType>()) {
+    auto ptrTy = getPtrToElementType(bufferType, lowering);
+    return LLVMType::getStructTy(ptrTy, int64Ty);
+  }
+
+  // Range descriptor contains the range bounds and the step as 64-bit integers.
+  //
+  // struct {
+  //   int64_t min;
+  //   int64_t max;
+  //   int64_t step;
+  // };
+  if (t.isa<RangeType>())
+    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
+
+  // View descriptor contains the pointer to the data buffer, followed by a
+  // 64-bit integer containing the distance between the beginning of the buffer
+  // and the first element to be accessed through the view, followed by two
+  // arrays, each containing as many 64-bit integers as the rank of the View.
+  // The first array represents the size, in number of original elements, of the
+  // view along the given dimension.  When taking the view, the size is the
+  // difference between the upper and the lower bound of the range.  The second
+  // array represents the "stride" (in tensor abstraction sense), i.e. the
+  // number of consecutive elements of the underlying buffer that separate two
+  // consecutive elements addressable through the view along the given
+  // dimension.  When taking the view, the strides are constructed as products
+  // of the original sizes along the trailing dimensions, multiplied by the view
+  // step.  For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
+  // i.e. the view of a complete memref, will have strides N and 1.  A view with
+  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
+  //
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t offset;
+  //   int64_t sizes[Rank];
+  //   int64_t strides[Rank];
+  // };
+  if (auto viewType = t.dyn_cast<ViewType>()) {
+    auto ptrTy = getPtrToElementType(viewType, lowering);
+    auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
+    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
+  }
+
+  return Type();
+}
+
+// Create an array attribute containing integer attributes with values provided
+// in `position`.
+static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
+  SmallVector<Attribute, 4> attrs;
+  attrs.reserve(position.size());
+  for (auto p : position)
+    attrs.push_back(builder.getI64IntegerAttr(p));
+  return builder.getArrayAttr(attrs);
+}
+
+// BufferAllocOp creates a new `!linalg.buffer` value.
+class BufferAllocOpConversion : public LLVMOpLowering {
+public:
+  explicit BufferAllocOpConversion(MLIRContext *context,
+                                   LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto indexType = IndexType::get(op->getContext());
+    auto voidPtrTy =
+        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+    // Insert the `malloc` declaration if it is not already present.
+    auto module = op->getParentOfType<ModuleOp>();
+    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
+    if (!mallocFunc) {
+      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
+      mallocFunc =
+          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+      module.push_back(mallocFunc);
+    }
+
+    // Get MLIR types for injecting element pointer.
+    auto allocOp = cast<BufferAllocOp>(op);
+    auto elementType = allocOp.getElementType();
+    uint64_t elementSize = 0;
+    if (auto vectorType = elementType.dyn_cast<VectorType>())
+      elementSize = vectorType.getNumElements() *
+                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+    else
+      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+    auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
+    auto elementPtrType = getPtrToElementType(bufferType, lowering);
+    auto bufferDescriptorType =
+        convertLinalgType(allocOp.getResult()->getType(), lowering);
+
+    // Emit IR for creating a new buffer descriptor with an underlying malloc.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    auto constantSize = bufferType.getBufferSize();
+    Value *size =
+        constantSize
+            ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
+                  .getValue()
+            : operands[0];
+    Value *allocSize =
+        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
+    Value *allocated =
+        llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
+            .getOperation()
+            ->getResult(0);
+    allocated = bitcast(elementPtrType, allocated);
+    Value *desc = undef(bufferDescriptorType);
+    desc = insertvalue(bufferDescriptorType, desc, allocated,
+                       positionAttr(rewriter, 0));
+    desc = insertvalue(bufferDescriptorType, desc, size,
+                       positionAttr(rewriter, 1));
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+// BufferDeallocOp creates no value.
+class BufferDeallocOpConversion : public LLVMOpLowering {
+public:
+  explicit BufferDeallocOpConversion(MLIRContext *context,
+                                     LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
+                       lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto voidPtrTy =
+        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
+    // Insert the `free` declaration if it is not already present.
+    auto module = op->getParentOfType<ModuleOp>();
+    FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
+    if (!freeFunc) {
+      auto freeType = rewriter.getFunctionType(voidPtrTy, {});
+      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
+      module.push_back(freeFunc);
+    }
+
+    // Get MLIR types for extracting element pointer.
+    auto deallocOp = cast<BufferDeallocOp>(op);
+    auto elementPtrTy = getPtrToElementType(
+        deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
+
+    // Emit MLIR for buffer_dealloc.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
+                                                    positionAttr(rewriter, 0)));
+    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+
+// BufferSizeOp creates a new `index` value.
+class BufferSizeOpConversion : public LLVMOpLowering {
+public:
+  BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    rewriter.replaceOp(
+        op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
+    return matchSuccess();
+  }
+};
+
+// DimOp creates a new `index` value.
+class DimOpConversion : public LLVMOpLowering {
+public:
+  explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dimOp = cast<linalg::DimOp>(op);
+    auto indexTy = lowering.convertType(rewriter.getIndexType());
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    rewriter.replaceOp(
+        op,
+        {extractvalue(
+            indexTy, operands[0],
+            positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
+    return matchSuccess();
+  }
+};
+
+namespace {
+// Common functionality for Linalg LoadOp and StoreOp conversion to the
+// LLVM IR Dialect.
+template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
+public:
+  explicit LoadStoreOpConversion(MLIRContext *context,
+                                 LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
+  using Base = LoadStoreOpConversion<Op>;
+
+  // Compute the pointer to an element of the buffer underlying the view given
+  // current view indices.  Use the base offset and strides stored in the view
+  // descriptor to emit IR iteratively computing the actual offset, followed by
+  // a getelementptr. This must be called under an edsc::ScopedContext.
+  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
+                       ArrayRef<Value *> indices,
+                       ConversionPatternRewriter &rewriter) const {
+    auto loadOp = cast<Op>(op);
+    auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return positionAttr(rewriter, values);
+    };
+
+    // Linearize subscripts as:
+    //   base_offset + SUM_i index_i * stride_i.
+    Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
+    Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
+    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
+      Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
+      Value *additionalOffset = mul(indices[i], stride);
+      offset = add(offset, additionalOffset);
+    }
+    return gep(elementTy, base, offset);
+  }
+};
+} // namespace
+
+// A load is converted into the actual address computation, getelementptr and
+// an LLVM IR load.
+class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
+  using Base::Base;
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    edsc::ScopedContext edscContext(rewriter, op->getLoc());
+    auto elementTy = lowering.convertType(*op->result_type_begin());
+    Value *viewDescriptor = operands[0];
+    ArrayRef<Value *> indices = operands.drop_front();
+    auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+    rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
+    return matchSuccess();
+  }
+};
+
+// RangeOp creates a new range descriptor.
+class RangeOpConversion : public LLVMOpLowering {
+public:
+  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto rangeOp = cast<RangeOp>(op);
+    auto rangeDescriptorTy =
+        convertLinalgType(rangeOp.getResult()->getType(), lowering);
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+
+    // Fill in an aggregate value of the descriptor.
+    Value *desc = undef(rangeDescriptorTy);
+    desc = insertvalue(rangeDescriptorTy, desc, operands[0],
+                       positionAttr(rewriter, 0));
+    desc = insertvalue(rangeDescriptorTy, desc, operands[1],
+                       positionAttr(rewriter, 1));
+    desc = insertvalue(rangeDescriptorTy, desc, operands[2],
+                       positionAttr(rewriter, 2));
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+class SliceOpConversion : public LLVMOpLowering {
+public:
+  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sliceOp = cast<SliceOp>(op);
+    auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
+    auto viewType = sliceOp.getBaseViewType();
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+
+    // Helper function to create an integer array attribute out of a list of
+    // values.
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return positionAttr(rewriter, values);
+    };
+    // Helper function to obtain the ptr of the given `view`.
+    auto getViewPtr = [pos, this](ViewType type, Value *view) -> Value * {
+      auto elementPtrTy = getPtrToElementType(type, lowering);
+      return extractvalue(elementPtrTy, view, pos(0));
+    };
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    // Declare the view descriptor and insert data ptr.
+    Value *desc = undef(viewDescriptorTy);
+    desc = insertvalue(viewDescriptorTy, desc,
+                       getViewPtr(viewType, operands[0]), pos(0));
+
+    // TODO(ntv): extract sizes and emit asserts.
+    SmallVector<Value *, 4> strides(viewType.getRank());
+    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
+      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
+    }
+
+    // Compute and insert base offset.
+    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
+    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
+      Value *indexing = operands[1 + j];
+      Value *min =
+          sliceOp.getIndexing(j)->getType().isa<RangeType>()
+              ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
+              : indexing;
+      Value *product = mul(min, strides[j]);
+      baseOffset = add(baseOffset, product);
+    }
+    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+
+    // Compute and insert view sizes (max - min along the range).  Skip the
+    // non-range operands as they will be projected away from the view.
+    int i = 0;
+    for (Value *index : sliceOp.getIndexings()) {
+      if (!index->getType().isa<RangeType>())
+        continue;
+
+      Value *rangeDescriptor = operands[1 + i];
+      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *size = sub(max, min);
+
+      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
+      ++i;
+    }
+
+    // Compute and insert view strides.  Step over the strides that correspond
+    // to non-range operands as they are projected away from the view.
+    i = 0;
+    for (int j = 0, e = strides.size(); j < e; ++j) {
+      if (!sliceOp.getIndexing(j)->getType().isa<RangeType>())
+        continue;
+      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
+      Value *stride = mul(strides[j], step);
+      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
+      ++i;
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+// A store is converted into the actual address computation, getelementptr and
+// an LLVM IR store.
+class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
+  using Base::Base;
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    edsc::ScopedContext edscContext(rewriter, op->getLoc());
+    Value *data = operands[0];
+    Value *viewDescriptor = operands[1];
+    ArrayRef<Value *> indices = operands.drop_front(2);
+    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+    llvm_store(data, ptr);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+
+class ViewOpConversion : public LLVMOpLowering {
+public:
+  explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto viewOp = cast<ViewOp>(op);
+    auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
+    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return positionAttr(rewriter, values);
+    };
+
+    // First operand to `view` is the buffer descriptor.
+    Value *bufferDescriptor = operands[0];
+
+    // Declare the descriptor of the view.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    Value *desc = undef(viewDescriptorTy);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
+    desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
+
+    // Zero base offset.
+    auto indexTy = rewriter.getIndexType();
+    Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
+    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+
+    // Compute and insert view sizes (max - min along the range).
+    int numRanges = llvm::size(viewOp.ranges());
+    Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
+    for (int i = numRanges - 1; i >= 0; --i) {
+      // Update stride.
+      Value *rangeDescriptor = operands[1 + i];
+      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+      Value *stride = mul(runningStride, step);
+      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
+      // Update size.
+      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *size = sub(max, min);
+      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
+      // Update stride for the next dimension.
+      if (i > 0)
+        runningStride = mul(runningStride, max);
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+// Create a function definition which takes as argument pointers to the input
+// types and returns pointers to the output types.
+static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
+  auto implFnName = (libFn.getName().str() + "_impl");
+  auto module = libFn.getParentOfType<ModuleOp>();
+  if (auto f = module.lookupSymbol<FuncOp>(implFnName)) {
+    return f;
+  }
+  SmallVector<Type, 4> fnArgTypes;
+  for (auto t : libFn.getType().getInputs()) {
+    assert(t && t.isa<LLVMType>() &&
+           "Expected LLVM Type for argument while generating library Call "
+           "Implementation Definition");
+    fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
+  }
+  auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
+
+  // Insert the implementation function definition.
+  auto implFnDefn = FuncOp::create(libFn.getLoc(), implFnName, implFnType);
+  module.push_back(implFnDefn);
+  return implFnDefn;
+}
+
+// Get function definition for the LinalgOp. If it doesn't exist, insert a
+// definition.
+template <typename LinalgOp>
+static FuncOp
+getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
+                              ConversionPatternRewriter &rewriter) {
+  auto linalgOp = cast<LinalgOp>(op);
+  auto fnName = linalgOp.getLibraryCallName();
+  if (fnName.empty()) {
+    op->emitWarning("No library call defined for: ") << *op;
+    return FuncOp();
+  }
+  auto module = op->getParentOfType<ModuleOp>();
+  if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
+    return f;
+  }
+
+  // Get the Function type consistent with LLVM Lowering.
+  SmallVector<Type, 4> inputTypes;
+  for (auto operand : op->getOperands())
+    inputTypes.push_back(lowering.convertType(operand->getType()));
+  assert(op->getNumResults() == 0 &&
+         "Library call for linalg operation can be generated only for ops that "
+         "have void return types");
+  auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
+  auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType);
+  module.push_back(libFn);
+  // Return after creating the function definition. The body will be created
+  // later.
+  return libFn;
+}
+
+static void getLLVMLibraryCallDefinition(FuncOp fn,
+                                         LLVMTypeConverter &lowering) {
+  // Generate the implementation function definition.
+  auto implFn = getLLVMLibraryCallImplDefinition(fn);
+
+  // Generate the function body.
+  OpBuilder builder(fn.addEntryBlock());
+  edsc::ScopedContext scope(builder, fn.getLoc());
+  SmallVector<Value *, 4> implFnArgs;
+
+  // Create a constant 1.
+  auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
+                      IntegerAttr::get(IndexType::get(fn.getContext()), 1));
+  for (auto arg : fn.getArguments()) {
+    // Allocate a stack for storing the argument value. The stack is passed to
+    // the implementation function.
+    auto alloca =
+        llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
+            .getValue();
+    implFnArgs.push_back(alloca);
+    llvm_store(arg, alloca);
+  }
+  llvm_call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
+  llvm_return{ArrayRef<Value *>()};
+}
+
+namespace {
+// The conversion class from Linalg to LLVMIR.
+class LinalgTypeConverter : public LLVMTypeConverter {
+  using LLVMTypeConverter::LLVMTypeConverter;
+
+public:
+  Type convertType(Type t) override {
+    if (auto result = LLVMTypeConverter::convertType(t))
+      return result;
+    return convertLinalgType(t, *this);
+  }
+
+  void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }
+
+  ArrayRef<FuncOp> getLibraryFnDeclarations() {
+    return libraryFnDeclarations.getArrayRef();
+  }
+
+private:
+  /// List of library functions declarations needed during dialect conversion
+  llvm::SetVector<FuncOp> libraryFnDeclarations;
+};
+} // end anonymous namespace
+
+// LinalgOpConversion<LinalgOp> creates a new call to the
+// `LinalgOp::getLibraryCallName()` function.
+// The implementation of the function can be either in the same module or in an
+// externally linked library.
+template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
+public:
+  explicit LinalgOpConversion(MLIRContext *context,
+                              LinalgTypeConverter &lowering_)
+      : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only emit library call declaration. Fill in the body later.
+    auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+    if (!f)
+      return matchFailure();
+    static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
+
+    auto fAttr = rewriter.getSymbolRefAttr(f);
+    auto named = rewriter.getNamedAttr("callee", fAttr);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
+                                              ArrayRef<NamedAttribute>{named});
+    return matchSuccess();
+  }
+};
+
+/// Populate the given list with patterns that convert from Linalg to LLVM.
+static void
+populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
+                                       OwningRewritePatternList &patterns,
+                                       MLIRContext *ctx) {
+  patterns
+      .insert<BufferAllocOpConversion, BufferDeallocOpConversion,
+              BufferSizeOpConversion, DimOpConversion,
+              LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
+              LinalgOpConversion<MatmulOp>, LoadOpConversion, RangeOpConversion,
+              SliceOpConversion, StoreOpConversion, ViewOpConversion>(
+          ctx, converter);
+}
+
+namespace {
+struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
+  void runOnModule();
+};
+} // namespace
+
+// This is currently written as a standalone function because the lowering to
+// affine will look different than lowering to LLVM and it is still unclear how
+// everything will be eventually structured.
+static void lowerLinalgSubViewOps(FuncOp &f) {
+  f.walk<SubViewOp>([&](SubViewOp op) {
+    OpBuilder b(op);
+    ScopedContext scope(b, op.getLoc());
+    auto *view = op.getView();
+    SmallVector<Value *, 8> ranges;
+    for (auto en : llvm::enumerate(op.getRanges())) {
+      using edsc::op::operator<;
+      using linalg::intrinsics::dim;
+      unsigned rank = en.index();
+      auto sliceRange = en.value();
+      auto size = dim(view, rank);
+      ValueHandle ub(sliceRange.max);
+      auto max = edsc::intrinsics::select(size < ub, size, ub);
+      ranges.push_back(range(sliceRange.min, max, sliceRange.step));
+    }
+    op.replaceAllUsesWith(slice(view, ranges));
+    op.erase();
+  });
+}
+
+void LowerLinalgToLLVMPass::runOnModule() {
+  auto module = getModule();
+
+  for (auto f : module.getOps<FuncOp>())
+    lowerLinalgSubViewOps(f);
+
+  // Convert to the LLVM IR dialect using the converter defined above.
+  OwningRewritePatternList patterns;
+  LinalgTypeConverter converter(&getContext());
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+  if (failed(applyPartialConversion(module, target, patterns, &converter))) {
+    signalPassFailure();
+  }
+
+  // Emit the function body of any Library function that was declared.
+  for (auto fn : converter.getLibraryFnDeclarations()) {
+    getLLVMLibraryCallDefinition(fn, converter);
+  }
+}
+
+ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
+  return new LowerLinalgToLLVMPass();
+}
+
+static PassRegistration<LowerLinalgToLLVMPass>
+    pass("linalg-lower-to-llvm-dialect",
+         "Lower the operations from the linalg dialect into the LLVM dialect");
diff --git a/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
new file mode 100644
index 0000000..afeb5c4
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
@@ -0,0 +1,399 @@
+//===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
+using edsc::op::operator+;
+using edsc::op::operator==;
+
+static SmallVector<ValueHandle, 8>
+foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+                    ArrayRef<Value *> vals, OperationFolder &folder) {
+  assert(map.getNumSymbols() == 0);
+  assert(map.getNumInputs() == vals.size());
+  SmallVector<ValueHandle, 8> res;
+  res.reserve(map.getNumResults());
+  auto dims = map.getNumDims();
+  for (auto e : map.getResults()) {
+    auto exprMap = AffineMap::get(dims, 0, e);
+    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
+    canonicalizeMapAndOperands(&exprMap, &operands);
+    res.push_back(affine_apply(folder, exprMap, operands));
+  }
+  return res;
+}
+
+static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
+                                          Optional<AffineMap> permutation,
+                                          OperationFolder &state) {
+  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
+                                        ScopedContext::getLocation(),
+                                        permutation.getValue(), ivs, state)
+                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
+}
+
+// Creates a number of ranges equal to the number of results in `map`.
+// The returned ranges correspond to the loop ranges, in the proper order, for
+// which new loops will be created.
+static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
+                                              AffineMap map,
+                                              ArrayRef<Value *> allViewSizes,
+                                              OperationFolder &folder) {
+  // Apply `map` to get view sizes in loop order.
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+  // Create a new range with the applied tile sizes.
+  ScopedContext scope(b, loc);
+  SmallVector<Value *, 4> res;
+  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
+    res.push_back(range(constant_index(folder, 0), sizes[idx],
+                        constant_index(folder, 1)));
+  }
+  return res;
+}
+
+template <typename LinalgOpType> class LinalgScopedEmitter {};
+
+template <> class LinalgScopedEmitter<CopyOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
+                                       OperationFolder &folder) {
+    auto nPar = copyOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto inputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+    auto outputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
+    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
+    IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    // clang-format off
+    nPar > 0 ? O(oivs) = I(iivs) :
+               O() = I();
+    // clang-format on
+  }
+};
+
+template <> class LinalgScopedEmitter<FillOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
+                                       OperationFolder &folder) {
+    auto nPar = fillOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto ivs =
+        SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
+    IndexedLinalgValue O(fillOp.getOutput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
+             : O() = ValueHandle(fillOp.getValue());
+  }
+};
+
+template <> class LinalgScopedEmitter<DotOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 1);
+    IndexHandle r_i(allIvs[0]);
+    IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+        C(dotOp.getOutput(0));
+    // Emit scalar form.
+    C() = C() + A(r_i) * B(r_i);
+  }
+};
+
+template <> class LinalgScopedEmitter<MatvecOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatvecOp matvecOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 2);
+    IndexHandle i(allIvs[0]), r_j(allIvs[1]);
+    IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+        C(matvecOp.getOutput(0));
+    // Emit scalar form.
+    C(i) = C(i) + A(i, r_j) * B(r_j);
+  }
+};
+
+template <> class LinalgScopedEmitter<MatmulOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatmulOp matmulOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 3);
+    IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
+    IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutput(0));
+    // Emit scalar form.
+    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
+  }
+};
+
+template <> class LinalgScopedEmitter<ConvOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    auto maps = loopToOperandRangesMaps(convOp);
+    SmallVector<ValueHandle, 8> fIdx(
+        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+    SmallVector<ValueHandle, 8> imIdx(
+        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+    SmallVector<ValueHandle, 8> oIdx(
+        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+    IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
+        O(convOp.output());
+    // Emit scalar form.
+    O(oIdx) += F(fIdx) * I(imIdx);
+  }
+};
+
+// Emits the MLIR for the scalar part of the generic op by:
+//   1. Emitting linalg_load and linalg_store ops for each input and output
+//      view in order. This is achieved by applying the appropriate input or
+//      output map to the enclosing induction variables.
+//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+//      from point 1. above.
+//   3. Emitting linalg_store to store the results of 2. to the output
+//      views.
+//
+// An example output may resemble:
+//
+// ```
+//    loop.for %i = %c0 to %0 step %c1 {
+//      loop.for %j = %c0 to %1 step %c1 {
+//        loop.for %k = %c0 to %4 step %c1 {
+//          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
+//          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+//          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//       }
+//      }
+//    }
+// ```
+template <> class LinalgScopedEmitter<GenericOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       GenericOp genericOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = genericOp.getNumInputs();
+    unsigned nOutputs = genericOp.getNumOutputs();
+    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
+
+    // 1.a. Emit linalg_load from input views.
+    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
+    }
+
+    // 1.b. Emit linalg_load from output views.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nInputs + i] =
+          linalg_load(genericOp.getOutput(i), indexing);
+    }
+
+    auto funcOp = genericOp.getFunction();
+    if (funcOp) {
+      // 2. Emit call.
+      Operation *callOp = call(funcOp, indexedValues);
+      assert(callOp->getNumResults() == genericOp.getNumOutputs());
+
+      // 3. Emit linalg_store.
+      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+        ValueHandleArray indexing(foldedAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+        linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+      }
+    } else {
+      // TODO(ntv): When a region inliner exists, use it.
+      // 2. Inline region, currently only works for a single basic block.
+      BlockAndValueMapping map;
+      auto &block = genericOp.region().front();
+      for (auto it : llvm::zip(block.getArguments(), indexedValues))
+        map.map(std::get<0>(it), std::get<1>(it));
+      for (auto &op : block) {
+        // Skip terminator.
+        if (&op == &block.back())
+          continue;
+        assert(op.getNumRegions() == 0);
+        auto *newOp = b.clone(op, map);
+        for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+          map.map(std::get<0>(it), std::get<1>(it));
+      }
+
+      // 3. Emit linalg_store.
+      auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+      assert(yieldOp->getNumOperands() == nOutputs);
+      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+        ValueHandleArray indexing(foldedAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+        linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
+                     indexing);
+      }
+    }
+  }
+};
+
+template <typename ConcreteOp>
+class LinalgRewritePattern : public RewritePattern {
+public:
+  explicit LinalgRewritePattern(MLIRContext *context)
+      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
+  }
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    OpBuilder b(op);
+    ScopedContext scope(b, op->getLoc());
+
+    // The flattened loopToOperandRangesMaps is expected to be an invertible
+    // permutation map (which is asserted in the inverse calculation).
+    auto linalgOp = cast<ConcreteOp>(op);
+    auto invertedMap =
+        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+    if (!invertedMap) {
+      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
+                                                                folder);
+      rewriter.replaceOp(op, {});
+      return matchSuccess();
+    }
+
+    auto nPar = linalgOp.getNumParallelLoops();
+    auto nRed = linalgOp.getNumReductionLoops();
+    auto nWin = linalgOp.getNumWindowLoops();
+    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
+    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
+    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
+                    .take_front(nPar + nRed)
+                    .take_back(nRed);
+    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
+
+    auto loopRanges =
+        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
+                       getViewSizes(linalgOp), folder);
+    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
+
+    // clang-format off
+    ArrayRef<Value *> ranges(loopRanges);
+    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
+      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
+        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
+          [&linalgOp, &allIvs, this] {
+            auto allIvValues = extractValues(allIvs);
+            LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
+                allIvValues, linalgOp, folder);
+        });
+      });
+    });
+    // clang-format on
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+
+  mutable OperationFolder folder;
+};
+
+// Helper classes for type list expansion.
+template <typename... LinalgOps> class ConversionList;
+
+template <> class ConversionList<> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+};
+
+template <typename ConcreteOp, typename... LinalgOps>
+class ConversionList<ConcreteOp, LinalgOps...> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
+    ConversionList<LinalgOps...>::build(patterns, ctx);
+  }
+};
+
+/// Populate the given list with patterns that convert from Linalg to LLVM.
+static void
+populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *ctx) {
+  ConversionList<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+      >::build(patterns, ctx);
+}
+
+namespace {
+struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
+  void runOnFunction();
+};
+} // namespace
+
+void LowerLinalgToLoopsPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateLinalgToLoopRewritePatterns(patterns, &getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<AffineOpsDialect>();
+  target.addLegalDialect<loop::LoopOpsDialect>();
+  target.addLegalDialect<StandardOpsDialect>();
+  if (failed(applyPartialConversion(getFunction(), target, patterns))) {
+    signalPassFailure();
+  }
+}
+
+FunctionPassBase *mlir::linalg::createLowerLinalgToLoopsPass() {
+  return new LowerLinalgToLoopsPass();
+}
+
+static PassRegistration<LowerLinalgToLoopsPass>
+    pass("linalg-lower-to-loops",
+         "Lower the operations from the linalg dialect into loops");
diff --git a/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp b/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp
new file mode 100644
index 0000000..8090a58
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -0,0 +1,542 @@
+//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the linalg dialect Tiling pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+using namespace mlir::loop;
+
+#define DEBUG_TYPE "linalg-tiling"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::list<unsigned>
+    clTileSizes("linalg-tile-sizes",
+                llvm::cl::desc("Tile sizes by which to tile linalg operations"),
+                llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+                llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clPromoteFullTileViews(
+    "linalg-tile-promote-full-tile-views",
+    llvm::cl::desc("Create scoped local buffers for tiled views "),
+    llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
+
+static bool isZero(Value *v) {
+  return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
+         cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
+}
+
+// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
+// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
+// one entry per surrounding loop. It uses zero as the convention that a
+// particular loop is not tiled. This convention simplifies implementations by
+// avoiding affine map manipulations.
+// The returned ranges correspond to the loop ranges, in the proper order, that
+// are tiled and for which new loops will be created.
+static SmallVector<SubViewOp::Range, 4>
+makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+                    ArrayRef<Value *> allViewSizes,
+                    ArrayRef<Value *> allTileSizes, OperationFolder &folder) {
+  assert(allTileSizes.size() == map.getNumResults());
+  // Apply `map` to get view sizes in loop order.
+  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+  SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
+
+  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
+  for (int idx = tileSizes.size() - 1; idx >= 0; --idx) {
+    if (isZero(tileSizes[idx])) {
+      viewSizes.erase(viewSizes.begin() + idx);
+      tileSizes.erase(tileSizes.begin() + idx);
+    }
+  }
+
+  // Create a new range with the applied tile sizes.
+  SmallVector<SubViewOp::Range, 4> res;
+  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
+    res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
+                                   tileSizes[idx]});
+  }
+  return res;
+}
+
+namespace {
+// Helper visitor to determine whether an AffineExpr is tiled.
+// This is achieved by traversing every AffineDimExpr with position `pos` and
+// checking whether the corresponding `tileSizes[pos]` is non-zero.
+// This also enforces only positive coefficients occur in multiplications.
+//
+// Example:
+//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
+//
+struct TileCheck : public AffineExprVisitor<TileCheck> {
+  TileCheck(ArrayRef<Value *> tileSizes)
+      : isTiled(false), tileSizes(tileSizes) {}
+
+  void visitDimExpr(AffineDimExpr expr) {
+    isTiled |= !isZero(tileSizes[expr.getPosition()]);
+  }
+  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+    visit(expr.getLHS());
+    visit(expr.getRHS());
+    if (expr.getKind() == mlir::AffineExprKind::Mul)
+      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
+             "nonpositive multipliying coefficient");
+  }
+  bool isTiled;
+  ArrayRef<Value *> tileSizes;
+};
+} // namespace
+
+static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
+  if (!expr)
+    return false;
+  TileCheck t(tileSizes);
+  t.visit(expr);
+  return t.isTiled;
+}
+
+// Checks whether the view with index `viewIndex` within `linalgOp` varies with
+// respect to a non-zero `tileSize`.
+static bool isTiled(AffineMap map, ArrayRef<Value *> tileSizes) {
+  if (!map)
+    return false;
+  for (unsigned r = 0; r < map.getNumResults(); ++r)
+    if (isTiled(map.getResult(r), tileSizes))
+      return true;
+  return false;
+}
+
+static SmallVector<Value *, 4>
+makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
+               ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes,
+               ArrayRef<Value *> viewSizes, OperationFolder &folder) {
+  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
+                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
+                           [](Value *v) { return !isZero(v); })) &&
+         "expected as many ivs as non-zero sizes");
+
+  using edsc::intrinsics::select;
+  using edsc::op::operator+;
+  using edsc::op::operator<;
+
+  // Construct (potentially temporary) mins and maxes on which to apply maps
+  // that define tile subviews.
+  SmallVector<Value *, 8> mins, maxes;
+  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+    if (isZero(tileSizes[idx])) {
+      mins.push_back(constant_index(folder, 0));
+      maxes.push_back(viewSizes[idx]);
+    } else {
+      ValueHandle lb(ivs[idxIvs++]), step(tileSizes[idx]);
+      mins.push_back(lb);
+      maxes.push_back(lb + step);
+    }
+  }
+
+  auto *op = linalgOp.getOperation();
+
+  SmallVector<Value *, 4> res;
+  res.reserve(op->getNumOperands());
+  auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
+  for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
+       ++viewIndex) {
+    Value *view = *(viewIteratorBegin + viewIndex);
+    unsigned viewRank = view->getType().cast<ViewType>().getRank();
+    auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
+    // If the view is not tiled, we can use it as is.
+    if (!isTiled(map, tileSizes)) {
+      res.push_back(view);
+      continue;
+    }
+
+    // Construct a new subview for the tile.
+    SmallVector<SubViewOp::Range, 4> subViewOperands;
+    subViewOperands.reserve(viewRank * 3);
+    for (unsigned r = 0; r < viewRank; ++r) {
+      if (!isTiled(map.getSubMap({r}), tileSizes)) {
+        subViewOperands.push_back(SubViewOp::Range{
+            constant_index(folder, 0), linalg::intrinsics::dim(view, r),
+            constant_index(folder, 1)});
+        continue;
+      }
+
+      auto m = map.getSubMap({r});
+      auto *min = applyMapToValues(b, loc, m, mins, folder).front();
+      auto *max = applyMapToValues(b, loc, m, maxes, folder).front();
+      // Tiling creates a new slice at the proper index, the slice step is 1
+      // (i.e. the slice view does not subsample, stepping occurs in the loop).
+      subViewOperands.push_back(
+          SubViewOp::Range{min, max, constant_index(folder, 1)});
+    }
+    res.push_back(b.create<SubViewOp>(loc, view, subViewOperands));
+  }
+
+  // Traverse the mins/maxes and erase those that don't have uses left.
+  mins.append(maxes.begin(), maxes.end());
+  for (auto *v : mins)
+    if (v->use_empty())
+      v->getDefiningOp()->erase();
+
+  return res;
+}
+
+static AffineMap getAffineDifferenceMap(MLIRContext *context) {
+  AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
+  return AffineMap::get(2, 0, {d0 - d1});
+}
+
+static Value *allocBuffer(Type elementType, Value *size) {
+  if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
+    return buffer_alloc(
+        BufferType::get(size->getContext(), elementType, cst.getValue()));
+  return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
+}
+
+// Performs promotion of a `subView` into a local buffer of the size of the
+// *ranges* of the `subView`. This produces a buffer whose size may be bigger
+// than the actual size of the `subView` at the boundaries.
+// This is related to the full/partial tile problem.
+// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
+// `partialLocalView` such that:
+//   * `buffer` is always the size of the full tile.
+//   * `fullLocalView` is a dense contiguous view into that buffer.
+//   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
+//     that corresponds to the size of `subView` and accounting for boundary
+//     effects.
+// The point of the full tile buffer is that constant static tile sizes are
+// folded and result in a buffer type with statically known size and alignment
+// properties.
+// To account for general boundary effects, padding must be performed on the
+// boundary tiles. For now this is done with an unconditional `fill` op followed
+// by a partial `copy` op.
+static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
+                                           SubViewOp subView,
+                                           OperationFolder &folder) {
+  auto zero = constant_index(folder, 0);
+  auto one = constant_index(folder, 1);
+
+  auto viewType = subView.getViewType();
+  auto rank = viewType.getRank();
+  Value *allocSize = one;
+  SmallVector<Value *, 8> fullRanges, partialRanges;
+  fullRanges.reserve(rank);
+  partialRanges.reserve(rank);
+  for (auto en : llvm::enumerate(subView.getRanges())) {
+    auto rank = en.index();
+    auto rangeValue = en.value();
+    Value *d =
+        isa<linalg::DimOp>(rangeValue.max->getDefiningOp())
+            ? rangeValue.max
+            : applyMapToValues(b, loc, getAffineDifferenceMap(b.getContext()),
+                               {rangeValue.max, rangeValue.min}, folder)
+                  .front();
+    allocSize = muli(folder, allocSize, d).getValue();
+    fullRanges.push_back(range(folder, zero, d, one));
+    partialRanges.push_back(
+        range(folder, zero, linalg::intrinsics::dim(subView, rank), one));
+  }
+  auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
+  auto fullLocalView = view(buffer, fullRanges);
+  auto partialLocalView = slice(fullLocalView, partialRanges);
+  return PromotionInfo{buffer, fullLocalView, partialLocalView};
+}
+
+// Performs promotion of a view `v` into a local buffer of the size of the
+// view. This produces a buffer whose size is exactky the size of `v`.
+// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
+// `partialLocalView` such that:
+//   * `buffer` is always the size of the view.
+//   * `partialLocalView` is a dense contiguous view into that buffer.
+//   * `fullLocalView` is equal to `partialLocalView`.
+// The point of the full tile buffer is that constant static tile sizes are
+// folded and result in a buffer type with statically known size and alignment
+// properties.
+static PromotionInfo promotePartialTileBuffer(OpBuilder &b, Location loc,
+                                              Value *v,
+                                              OperationFolder &folder) {
+  auto zero = constant_index(folder, 0);
+  auto one = constant_index(folder, 1);
+
+  auto viewType = v->getType().cast<ViewType>();
+  auto rank = viewType.getRank();
+  Value *allocSize = one;
+  SmallVector<Value *, 8> partialRanges;
+  partialRanges.reserve(rank);
+  for (unsigned r = 0; r < rank; ++r) {
+    Value *d = linalg::intrinsics::dim(v, r);
+    allocSize = muli(folder, allocSize, d).getValue();
+    partialRanges.push_back(range(folder, zero, d, one));
+  }
+  auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
+  auto partialLocalView = view(folder, buffer, partialRanges);
+  return PromotionInfo{buffer, partialLocalView, partialLocalView};
+}
+
+SmallVector<PromotionInfo, 8>
+mlir::linalg::promoteLinalgViews(OpBuilder &b, Location loc,
+                                 ArrayRef<Value *> views,
+                                 OperationFolder &folder) {
+  if (views.empty())
+    return {};
+
+  ScopedContext scope(b, loc);
+  SmallVector<PromotionInfo, 8> res;
+  res.reserve(views.size());
+  DenseMap<Value *, PromotionInfo> promotionInfo;
+  for (auto *v : views) {
+    PromotionInfo pi;
+    if (auto subView = dyn_cast<SubViewOp>(v->getDefiningOp()))
+      pi = promoteFullTileBuffer(b, loc, subView, folder);
+    else
+      pi = promotePartialTileBuffer(b, loc, v, folder);
+    promotionInfo.insert(std::make_pair(v, pi));
+    res.push_back(pi);
+  }
+
+  for (auto *v : views) {
+    auto info = promotionInfo.find(v);
+    if (info == promotionInfo.end())
+      continue;
+    auto viewType = v->getType().cast<ViewType>();
+    // TODO(ntv): value to fill with should be related to the operation.
+    // For now, just use APFloat(0.0f).
+    auto t = viewType.getElementType().cast<FloatType>();
+    Value *fillVal = constant_float(folder, APFloat(0.0f), t);
+    // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
+    // view that is different from the partial local view and we are on the
+    // boundary.
+    fill(info->second.fullLocalView, fillVal);
+  }
+
+  for (auto *v : views) {
+    auto info = promotionInfo.find(v);
+    if (info == promotionInfo.end())
+      continue;
+    copy(v, info->second.partialLocalView);
+  }
+  return res;
+}
+
+llvm::Optional<TiledLinalgOp>
+mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
+                           OperationFolder &folder,
+                           ArrayRef<bool> viewsToPromote) {
+  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
+                 op.getNumWindowLoops() ==
+             tileSizes.size() &&
+         "expected matching number of tile sizes and loops");
+
+  OpBuilder builder(op.getOperation());
+  ScopedContext scope(builder, op.getLoc());
+  // 2. Build the tiled loop ranges.
+  auto viewSizes = getViewSizes(op);
+  // The flattened loopToOperandRangesMaps is expected to be an invertible
+  // permutation map (asserted in the inverse calculation).
+  auto viewSizesToLoopsMap =
+      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
+  assert(viewSizesToLoopsMap && "expected invertible map");
+  auto loopRanges =
+      makeTiledLoopRanges(scope.getBuilder(), scope.getLocation(),
+                          viewSizesToLoopsMap, viewSizes, tileSizes, folder);
+
+  // 3. Create the tiled loops.
+  LinalgOp res = op;
+  SmallVector<IndexHandle, 4> ivs(loopRanges.size());
+  auto pivs = makeIndexHandlePointers(ivs);
+  LoopNestRangeBuilder(pivs, loopRanges)([&] {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
+    auto views =
+        makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
+
+    // If no promotion, we are done.
+    auto promote = !viewsToPromote.empty() &&
+                   llvm::any_of(llvm::make_range(viewsToPromote.begin(),
+                                                 viewsToPromote.end()),
+                                [](bool b) { return b; });
+    if (!promote) {
+      auto operands = getAssumedNonViewOperands(op);
+      views.append(operands.begin(), operands.end());
+      res = op.create(b, loc, views, op.getAttrs());
+      return;
+    }
+
+    // 4. Filter the subset of views that need to be promoted.
+    SmallVector<Value *, 8> filteredViews;
+    filteredViews.reserve(views.size());
+    assert((viewsToPromote.empty() || views.size() == viewsToPromote.size()) &&
+           "expected viewsToPromote to be empty or of the same size as view");
+    for (auto it : llvm::zip(views, viewsToPromote)) {
+      if (!std::get<1>(it))
+        continue;
+      filteredViews.push_back(std::get<0>(it));
+    }
+
+    // 5. Promote the specified views and use them in the new op.
+    auto promotedBufferAndViews =
+        promoteLinalgViews(b, loc, filteredViews, folder);
+    SmallVector<Value *, 8> opViews(views.size(), nullptr);
+    SmallVector<Value *, 8> writebackViews(views.size(), nullptr);
+    for (unsigned i = 0, promotedIdx = 0, e = opViews.size(); i < e; ++i) {
+      if (viewsToPromote[i]) {
+        opViews[i] = promotedBufferAndViews[promotedIdx].fullLocalView;
+        writebackViews[i] =
+            promotedBufferAndViews[promotedIdx].partialLocalView;
+        promotedIdx++;
+      } else {
+        opViews[i] = views[i];
+      }
+    }
+    auto operands = getAssumedNonViewOperands(op);
+    opViews.append(operands.begin(), operands.end());
+    res = op.create(b, loc, opViews, op.getAttrs());
+
+    // 6. Emit write-back for the promoted output views: copy the partial view.
+    for (unsigned i = 0, e = writebackViews.size(); i < e; ++i) {
+      bool isOutput = res.getIndexOfOutput(opViews[i]).hasValue();
+      if (writebackViews[i] && isOutput)
+        copy(writebackViews[i], views[i]);
+    }
+
+    // 7. Dealloc local buffers.
+    for (const auto &pi : promotedBufferAndViews)
+      buffer_dealloc(pi.buffer);
+  });
+
+  // 8. Gather the newly created loops and return them with the new op.
+  SmallVector<ForOp, 8> loops;
+  loops.reserve(ivs.size());
+  for (auto iv : ivs)
+    loops.push_back(loop::getForInductionVarOwner(iv));
+
+  return TiledLinalgOp{res, loops};
+}
+
+llvm::Optional<TiledLinalgOp>
+mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
+                           OperationFolder &folder,
+                           ArrayRef<bool> viewsToPromote) {
+  if (tileSizes.empty())
+    return llvm::None;
+
+  // The following uses the convention that "tiling by zero" skips tiling a
+  // particular dimension. This convention is significantly simpler to handle
+  // instead of adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
+                op.getNumWindowLoops();
+  tileSizes = tileSizes.take_front(nLoops);
+  // If only 0 tilings are left, then return.
+  if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
+    return llvm::None;
+
+  // Create a builder for tile size constants.
+  OpBuilder builder(op);
+  ScopedContext scope(builder, op.getLoc());
+
+  // Materialize concrete tile size values to pass the generic tiling function.
+  SmallVector<Value *, 8> tileSizeValues;
+  tileSizeValues.reserve(tileSizes.size());
+  for (auto ts : tileSizes)
+    tileSizeValues.push_back(constant_index(folder, ts));
+  // Pad tile sizes with zero values to enforce our convention.
+  if (tileSizeValues.size() < nLoops) {
+    for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
+      tileSizeValues.push_back(constant_index(folder, 0));
+  }
+
+  return tileLinalgOp(op, tileSizeValues, folder, viewsToPromote);
+}
+
+static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes,
+                          bool promoteViews) {
+  OperationFolder folder;
+  f.walk<LinalgOp>([promoteViews, tileSizes, &folder](LinalgOp op) {
+    // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
+    // nothing.
+    SmallVector<bool, 8> viewsToPromote(op.getNumInputsAndOutputs(),
+                                        promoteViews);
+    auto opLoopsPair = tileLinalgOp(op, tileSizes, folder, viewsToPromote);
+    // If tiling occurred successfully, erase old op.
+    if (opLoopsPair)
+      op.erase();
+  });
+  f.walk<LinalgOp>([](LinalgOp op) {
+    if (!op.getOperation()->hasNoSideEffect())
+      return;
+    if (op.getOperation()->use_empty())
+      op.erase();
+  });
+}
+
+namespace {
+struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
+  LinalgTilingPass() = default;
+  LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews);
+
+  void runOnFunction() {
+    tileLinalgOps(getFunction(), tileSizes, promoteViews);
+  }
+
+  SmallVector<int64_t, 8> tileSizes;
+  bool promoteViews;
+};
+} // namespace
+
+LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews) {
+  this->tileSizes.assign(sizes.begin(), sizes.end());
+  this->promoteViews = promoteViews;
+}
+
+FunctionPassBase *
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
+                                     bool promoteViews) {
+  return new LinalgTilingPass(tileSizes, promoteViews);
+}
+
+static PassRegistration<LinalgTilingPass>
+    pass("linalg-tile", "Tile operations in the linalg dialect", [] {
+      auto *pass = new LinalgTilingPass();
+      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+      pass->promoteViews = clPromoteFullTileViews;
+      return pass;
+    });
diff --git a/third_party/mlir/lib/Linalg/Utils/Utils.cpp b/third_party/mlir/lib/Linalg/Utils/Utils.cpp
new file mode 100644
index 0000000..d31fe0d
--- /dev/null
+++ b/third_party/mlir/lib/Linalg/Utils/Utils.cpp
@@ -0,0 +1,155 @@
+//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utilities for the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+using namespace mlir::loop;
+
+mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
+                                               ValueHandle range) {
+  assert(range.getType() && "expected !linalg.range type");
+  assert(range.getValue()->getDefiningOp() &&
+         "need operations to extract range parts");
+  auto rangeOp = cast<RangeOp>(range.getValue()->getDefiningOp());
+  auto lb = rangeOp.min();
+  auto ub = rangeOp.max();
+  auto step = rangeOp.step();
+  auto forOp = OperationHandle::createOp<ForOp>(lb, ub, step);
+  *iv = ValueHandle(forOp.getInductionVar());
+  auto *body = forOp.getBody();
+  enter(body, /*prev=*/1);
+}
+
+mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
+                                               SubViewOp::Range range) {
+  auto forOp =
+      OperationHandle::createOp<ForOp>(range.min, range.max, range.step);
+  *iv = ValueHandle(forOp.getInductionVar());
+  auto *body = forOp.getBody();
+  enter(body, /*prev=*/1);
+}
+
+ValueHandle
+mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
+  if (fun)
+    fun();
+  exit();
+  return ValueHandle::null();
+}
+
+mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
+    ArrayRef<ValueHandle *> ivs, ArrayRef<SubViewOp::Range> ranges) {
+  loops.reserve(ranges.size());
+  for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
+    loops.emplace_back(ivs[i], ranges[i]);
+  }
+  assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
+}
+
+mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
+    ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> ranges) {
+  loops.reserve(ranges.size());
+  for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
+    loops.emplace_back(ivs[i], ranges[i]);
+  }
+  assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
+}
+
+mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
+    ArrayRef<ValueHandle *> ivs, ArrayRef<Value *> ranges)
+    : LoopNestRangeBuilder(
+          ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
+
+ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
+    std::function<void(void)> fun) {
+  if (fun)
+    fun();
+  for (auto &lit : reverse(loops)) {
+    lit({});
+  }
+  return ValueHandle::null();
+}
+
+static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
+                                            AffineMap map,
+                                            ArrayRef<Value *> operandsRef,
+                                            OperationFolder &state) {
+  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  canonicalizeMapAndOperands(&map, &operands);
+  return state.create<AffineApplyOp>(b, loc, map, operands);
+}
+
+SmallVector<Value *, 4> mlir::linalg::applyMapToValues(OpBuilder &b,
+                                                       Location loc,
+                                                       AffineMap map,
+                                                       ArrayRef<Value *> values,
+                                                       OperationFolder &state) {
+  SmallVector<Value *, 4> res;
+  res.reserve(map.getNumResults());
+  unsigned numDims = map.getNumDims();
+  // For each `expr` in `map`, applies the `expr` to the values extracted from
+  // ranges. If the resulting application can be folded into a Value*, the
+  // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+  for (auto expr : map.getResults()) {
+    AffineMap map = AffineMap::get(numDims, 0, expr);
+    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state));
+  }
+  return res;
+}
+
+/// Returns all the operands of `linalgOp` that are not views.
+/// Asserts that these operands are value types to allow transformations like
+/// tiling to just use the values when cloning `linalgOp`.
+SmallVector<Value *, 4>
+mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
+  auto *op = linalgOp.getOperation();
+  unsigned numViews = linalgOp.getNumInputsAndOutputs();
+  unsigned nOperands = op->getNumOperands() - numViews;
+  SmallVector<Value *, 4> res;
+  res.reserve(nOperands);
+  for (unsigned i = 0; i < nOperands; ++i) {
+    res.push_back(op->getOperand(numViews + i));
+    auto t = res.back()->getType();
+    (void)t;
+    assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) &&
+           "expected scalar or vector type");
+  }
+  return res;
+}
diff --git a/third_party/mlir/lib/Parser/CMakeLists.txt b/third_party/mlir/lib/Parser/CMakeLists.txt
new file mode 100644
index 0000000..9fd29ae
--- /dev/null
+++ b/third_party/mlir/lib/Parser/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRParser
+  Lexer.cpp
+  Parser.cpp
+  Token.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Parser
+  )
+add_dependencies(MLIRParser MLIRIR MLIRAnalysis)
+target_link_libraries(MLIRParser MLIRIR MLIRAnalysis)
diff --git a/third_party/mlir/lib/Parser/Lexer.cpp b/third_party/mlir/lib/Parser/Lexer.cpp
new file mode 100644
index 0000000..f63b7fc
--- /dev/null
+++ b/third_party/mlir/lib/Parser/Lexer.cpp
@@ -0,0 +1,400 @@
+//===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the lexer for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Lexer.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/Support/SourceMgr.h"
+using namespace mlir;
+
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+// Returns true if 'c' is an allowable puncuation character: [$._-]
+// Returns false otherwise.
+static bool isPunct(char c) {
+  return c == '$' || c == '.' || c == '_' || c == '-';
+}
+
+Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context)
+    : sourceMgr(sourceMgr), context(context) {
+  auto bufferID = sourceMgr.getMainFileID();
+  curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
+  curPtr = curBuffer.begin();
+}
+
+/// Encode the specified source location information into an attribute for
+/// attachment to the IR.
+Location Lexer::getEncodedSourceLocation(llvm::SMLoc loc) {
+  auto &sourceMgr = getSourceMgr();
+  unsigned mainFileID = sourceMgr.getMainFileID();
+  auto lineAndColumn = sourceMgr.getLineAndColumn(loc, mainFileID);
+  auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
+
+  return FileLineColLoc::get(buffer->getBufferIdentifier(), lineAndColumn.first,
+                             lineAndColumn.second, context);
+}
+
+/// emitError - Emit an error message and return an Token::error token.
+Token Lexer::emitError(const char *loc, const Twine &message) {
+  mlir::emitError(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
+                  message);
+  return formToken(Token::error, loc);
+}
+
+Token Lexer::lexToken() {
+  // Ignore whitespace.
+  while (true) {
+    switch (*curPtr) {
+    case ' ':
+    case '\t':
+    case '\n':
+    case '\r':
+      ++curPtr;
+      continue;
+    default:
+      // Terminate loop on non-whitespace, including either an embedded or
+      // final terminating nul character that llvm::MemoryBuffer guarantees
+      // will be there.
+      break;
+    }
+    break;
+  }
+
+  const char *tokStart = curPtr;
+  switch (*curPtr++) {
+  default:
+    // Handle bare identifiers.
+    if (isalpha(curPtr[-1]))
+      return lexBareIdentifierOrKeyword(tokStart);
+
+    // Unknown character, emit an error.
+    return emitError(tokStart, "unexpected character");
+
+  case '_':
+    // Handle bare identifiers.
+    return lexBareIdentifierOrKeyword(tokStart);
+
+  case 0:
+    // This may either be a nul character in the source file or may be the EOF
+    // marker that llvm::MemoryBuffer guarantees will be there.
+    if (curPtr - 1 == curBuffer.end())
+      return formToken(Token::eof, tokStart);
+
+    LLVM_FALLTHROUGH;
+  case ':':
+    return formToken(Token::colon, tokStart);
+  case ',':
+    return formToken(Token::comma, tokStart);
+  case '.':
+    return lexEllipsis(tokStart);
+  case '(':
+    return formToken(Token::l_paren, tokStart);
+  case ')':
+    return formToken(Token::r_paren, tokStart);
+  case '{':
+    return formToken(Token::l_brace, tokStart);
+  case '}':
+    return formToken(Token::r_brace, tokStart);
+  case '[':
+    return formToken(Token::l_square, tokStart);
+  case ']':
+    return formToken(Token::r_square, tokStart);
+  case '<':
+    return formToken(Token::less, tokStart);
+  case '>':
+    return formToken(Token::greater, tokStart);
+  case '=':
+    return formToken(Token::equal, tokStart);
+
+  case '+':
+    return formToken(Token::plus, tokStart);
+  case '*':
+    return formToken(Token::star, tokStart);
+  case '-':
+    if (*curPtr == '>') {
+      ++curPtr;
+      return formToken(Token::arrow, tokStart);
+    }
+    return formToken(Token::minus, tokStart);
+
+  case '?':
+    return formToken(Token::question, tokStart);
+
+  case '/':
+    if (*curPtr == '/')
+      return lexComment();
+    return emitError(tokStart, "unexpected character");
+
+  case '@':
+    return lexAtIdentifier(tokStart);
+
+  case '!':
+    LLVM_FALLTHROUGH;
+  case '^':
+    LLVM_FALLTHROUGH;
+  case '#':
+    LLVM_FALLTHROUGH;
+  case '%':
+    return lexPrefixedIdentifier(tokStart);
+  case '"':
+    return lexString(tokStart);
+
+  case '0':
+  case '1':
+  case '2':
+  case '3':
+  case '4':
+  case '5':
+  case '6':
+  case '7':
+  case '8':
+  case '9':
+    return lexNumber(tokStart);
+  }
+}
+
+/// Lex an '@foo' identifier.
+///
+///   symbol-ref-id ::= `@` bare-id
+///
+Token Lexer::lexAtIdentifier(const char *tokStart) {
+  // These always start with a letter or underscore.
+  auto cur = *curPtr++;
+  if (!isalpha(cur) && cur != '_')
+    return emitError(curPtr - 1,
+                     "@ identifier expected to start with letter or '_'");
+
+  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
+         *curPtr == '$' || *curPtr == '.')
+    ++curPtr;
+  return formToken(Token::at_identifier, tokStart);
+}
+
+/// Lex a bare identifier or keyword that starts with a letter.
+///
+///   bare-id ::= (letter|[_]) (letter|digit|[_$.])*
+///   integer-type ::= `i[1-9][0-9]*`
+///
+Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
+  // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
+  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
+         *curPtr == '$' || *curPtr == '.')
+    ++curPtr;
+
+  // Check to see if this identifier is a keyword.
+  StringRef spelling(tokStart, curPtr - tokStart);
+
+  // Check for i123.
+  if (tokStart[0] == 'i') {
+    bool allDigits = true;
+    for (auto c : spelling.drop_front())
+      allDigits &= isdigit(c) != 0;
+    if (allDigits && spelling.size() != 1)
+      return Token(Token::inttype, spelling);
+  }
+
+  Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
+#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
+#include "TokenKinds.def"
+                         .Default(Token::bare_identifier);
+
+  return Token(kind, spelling);
+}
+
+/// Lex a comment line, starting with a semicolon.
+///
+///   TODO: add a regex for comments here and to the spec.
+///
+Token Lexer::lexComment() {
+  // Advance over the second '/' in a '//' comment.
+  assert(*curPtr == '/');
+  ++curPtr;
+
+  while (true) {
+    switch (*curPtr++) {
+    case '\n':
+    case '\r':
+      // Newline is end of comment.
+      return lexToken();
+    case 0:
+      // If this is the end of the buffer, end the comment.
+      if (curPtr - 1 == curBuffer.end()) {
+        --curPtr;
+        return lexToken();
+      }
+      LLVM_FALLTHROUGH;
+    default:
+      // Skip over other characters.
+      break;
+    }
+  }
+}
+
+/// Lex an ellipsis.
+///
+///   ellipsis ::= '...'
+///
+Token Lexer::lexEllipsis(const char *tokStart) {
+  assert(curPtr[-1] == '.');
+
+  if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.')
+    return emitError(curPtr, "expected three consecutive dots for an ellipsis");
+
+  curPtr += 2;
+  return formToken(Token::ellipsis, tokStart);
+}
+
+/// Lex a number literal.
+///
+///   integer-literal ::= digit+ | `0x` hex_digit+
+///   float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+///
+Token Lexer::lexNumber(const char *tokStart) {
+  assert(isdigit(curPtr[-1]));
+
+  // Handle the hexadecimal case.
+  if (curPtr[-1] == '0' && *curPtr == 'x') {
+    // If we see stuff like 0xi32, this is a literal `0` follwed by an
+    // identifier `xi32`, stop after `0`.
+    if (!isxdigit(curPtr[1]))
+      return formToken(Token::integer, tokStart);
+
+    curPtr += 2;
+    while (isxdigit(*curPtr))
+      ++curPtr;
+
+    return formToken(Token::integer, tokStart);
+  }
+
+  // Handle the normal decimal case.
+  while (isdigit(*curPtr))
+    ++curPtr;
+
+  if (*curPtr != '.')
+    return formToken(Token::integer, tokStart);
+  ++curPtr;
+
+  // Skip over [0-9]*([eE][-+]?[0-9]+)?
+  while (isdigit(*curPtr))
+    ++curPtr;
+
+  if (*curPtr == 'e' || *curPtr == 'E') {
+    if (isdigit(static_cast<unsigned char>(curPtr[1])) ||
+        ((curPtr[1] == '-' || curPtr[1] == '+') &&
+         isdigit(static_cast<unsigned char>(curPtr[2])))) {
+      curPtr += 2;
+      while (isdigit(*curPtr))
+        ++curPtr;
+    }
+  }
+  return formToken(Token::floatliteral, tokStart);
+}
+
+/// Lex an identifier that starts with a prefix followed by suffix-id.
+///
+///   affine-map-id ::= `#` suffix-id
+///   ssa-id        ::= '%' suffix-id
+///   block-id      ::= '^' suffix-id
+///   type-id       ::= '!' suffix-id
+///   suffix-id     ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
+///
+Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
+  Token::Kind kind;
+  StringRef errorKind;
+  switch (*tokStart) {
+  case '#':
+    kind = Token::hash_identifier;
+    errorKind = "invalid attribute name";
+    break;
+  case '%':
+    kind = Token::percent_identifier;
+    errorKind = "invalid SSA name";
+    break;
+  case '^':
+    kind = Token::caret_identifier;
+    errorKind = "invalid block name";
+    break;
+  case '!':
+    kind = Token::exclamation_identifier;
+    errorKind = "invalid type identifier";
+    break;
+  default:
+    llvm_unreachable("invalid caller");
+  }
+
+  // Parse suffix-id.
+  if (isdigit(*curPtr)) {
+    // If suffix-id starts with a digit, the rest must be digits.
+    while (isdigit(*curPtr)) {
+      ++curPtr;
+    }
+  } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
+    do {
+      ++curPtr;
+    } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
+  } else {
+    return emitError(curPtr - 1, errorKind);
+  }
+
+  return formToken(kind, tokStart);
+}
+
+/// Lex a string literal.
+///
+///   string-literal ::= '"' [^"\n\f\v\r]* '"'
+///
+/// TODO: define escaping rules.
+Token Lexer::lexString(const char *tokStart) {
+  assert(curPtr[-1] == '"');
+
+  while (1) {
+    switch (*curPtr++) {
+    case '"':
+      return formToken(Token::string, tokStart);
+    case 0:
+      // If this is a random nul character in the middle of a string, just
+      // include it.  If it is the end of file, then it is an error.
+      if (curPtr - 1 != curBuffer.end())
+        continue;
+      LLVM_FALLTHROUGH;
+    case '\n':
+    case '\v':
+    case '\f':
+      return emitError(curPtr - 1, "expected '\"' in string literal");
+    case '\\':
+      // Handle explicitly a few escapes.
+      if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
+        ++curPtr;
+      else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
+        // Support \xx for two hex digits.
+        curPtr += 2;
+      else
+        return emitError(curPtr - 1, "unknown escape in string literal");
+      continue;
+
+    default:
+      continue;
+    }
+  }
+}
diff --git a/third_party/mlir/lib/Parser/Lexer.h b/third_party/mlir/lib/Parser/Lexer.h
new file mode 100644
index 0000000..896c26c
--- /dev/null
+++ b/third_party/mlir/lib/Parser/Lexer.h
@@ -0,0 +1,77 @@
+//===- Lexer.h - MLIR Lexer Interface ---------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the MLIR Lexer class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_LEXER_H
+#define MLIR_LIB_PARSER_LEXER_H
+
+#include "Token.h"
+#include "mlir/Parser.h"
+
+namespace mlir {
+class Location;
+
+/// This class breaks up the current file into a token stream.
+class Lexer {
+public:
+  explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+
+  const llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
+
+  Token lexToken();
+
+  /// Encode the specified source location information into a Location object
+  /// for attachment to the IR or error reporting.
+  Location getEncodedSourceLocation(llvm::SMLoc loc);
+
+  /// Change the position of the lexer cursor.  The next token we lex will start
+  /// at the designated point in the input.
+  void resetPointer(const char *newPointer) { curPtr = newPointer; }
+
+private:
+  // Helpers.
+  Token formToken(Token::Kind kind, const char *tokStart) {
+    return Token(kind, StringRef(tokStart, curPtr - tokStart));
+  }
+
+  Token emitError(const char *loc, const Twine &message);
+
+  // Lexer implementation methods.
+  Token lexAtIdentifier(const char *tokStart);
+  Token lexBareIdentifierOrKeyword(const char *tokStart);
+  Token lexComment();
+  Token lexEllipsis(const char *tokStart);
+  Token lexNumber(const char *tokStart);
+  Token lexPrefixedIdentifier(const char *tokStart);
+  Token lexString(const char *tokStart);
+
+  const llvm::SourceMgr &sourceMgr;
+  MLIRContext *context;
+
+  StringRef curBuffer;
+  const char *curPtr;
+
+  Lexer(const Lexer &) = delete;
+  void operator=(const Lexer &) = delete;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_LEXER_H
diff --git a/third_party/mlir/lib/Parser/Parser.cpp b/third_party/mlir/lib/Parser/Parser.cpp
new file mode 100644
index 0000000..09f3052
--- /dev/null
+++ b/third_party/mlir/lib/Parser/Parser.cpp
@@ -0,0 +1,4160 @@
+//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Parser.h"
+#include "Lexer.h"
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/SourceMgr.h"
+#include <algorithm>
+using namespace mlir;
+using llvm::MemoryBuffer;
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+namespace {
+class Parser;
+
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
+/// This class refers to all of the state maintained globally by the parser,
+/// such as the current lexer position etc. The Parser base class provides
+/// methods to access this.
+class ParserState {
+public:
+  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
+
+  // A map from attribute alias identifier to Attribute.
+  llvm::StringMap<Attribute> attributeAliasDefinitions;
+
+  // A map from type alias identifier to Type.
+  llvm::StringMap<Type> typeAliasDefinitions;
+
+private:
+  ParserState(const ParserState &) = delete;
+  void operator=(const ParserState &) = delete;
+
+  friend class Parser;
+
+  // The context we're parsing into.
+  MLIRContext *const context;
+
+  // The lexer for the source file we're parsing.
+  Lexer lex;
+
+  // This is the next token that hasn't been consumed yet.
+  Token curToken;
+};
+
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// This class implement support for parsing global entities like types and
+/// shared entities like SSA names.  It is intended to be subclassed by
+/// specialized subparsers that include state, e.g. when a local symbol table.
+class Parser {
+public:
+  Builder builder;
+
+  Parser(ParserState &state) : builder(state.context), state(state) {}
+
+  // Helper methods to get stuff from the parser-global state.
+  ParserState &getState() const { return state; }
+  MLIRContext *getContext() const { return state.context; }
+  const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+
+  /// Parse a comma-separated list of elements up until the specified end token.
+  ParseResult
+  parseCommaSeparatedListUntil(Token::Kind rightToken,
+                               const std::function<ParseResult()> &parseElement,
+                               bool allowEmptyList = true);
+
+  /// Parse a comma separated list of elements that must have at least one entry
+  /// in it.
+  ParseResult
+  parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
+
+  ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
+
+  // We have two forms of parsing methods - those that return a non-null
+  // pointer on success, and those that return a ParseResult to indicate whether
+  // they returned a failure.  The second class fills in by-reference arguments
+  // as the results of their action.
+
+  //===--------------------------------------------------------------------===//
+  // Error Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Emit an error and return failure.
+  InFlightDiagnostic emitError(const Twine &message = {}) {
+    return emitError(state.curToken.getLoc(), message);
+  }
+  InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
+
+  /// Encode the specified source location information into an attribute for
+  /// attachment to the IR.
+  Location getEncodedSourceLocation(llvm::SMLoc loc) {
+    return state.lex.getEncodedSourceLocation(loc);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Return the current token the parser is inspecting.
+  const Token &getToken() const { return state.curToken; }
+  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+
+  /// If the current token has the specified kind, consume it and return true.
+  /// If not, return false.
+  bool consumeIf(Token::Kind kind) {
+    if (state.curToken.isNot(kind))
+      return false;
+    consumeToken(kind);
+    return true;
+  }
+
+  /// Advance the current lexer onto the next token.
+  void consumeToken() {
+    assert(state.curToken.isNot(Token::eof, Token::error) &&
+           "shouldn't advance past EOF or errors");
+    state.curToken = state.lex.lexToken();
+  }
+
+  /// Advance the current lexer onto the next token, asserting what the expected
+  /// current token is.  This is preferred to the above method because it leads
+  /// to more self-documenting code with better checking.
+  void consumeToken(Token::Kind kind) {
+    assert(state.curToken.is(kind) && "consumed an unexpected token");
+    consumeToken();
+  }
+
+  /// Consume the specified token if present and return success.  On failure,
+  /// output a diagnostic and return failure.
+  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
+  //===--------------------------------------------------------------------===//
+  // Type Parsing
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
+
+  /// Parse an arbitrary type.
+  Type parseType();
+
+  /// Parse a complex type.
+  Type parseComplexType();
+
+  /// Parse an extended type.
+  Type parseExtendedType();
+
+  /// Parse a function type.
+  Type parseFunctionType();
+
+  /// Parse a memref type.
+  Type parseMemRefType();
+
+  /// Parse a non function type.
+  Type parseNonFunctionType();
+
+  /// Parse a tensor type.
+  Type parseTensorType();
+
+  /// Parse a tuple type.
+  Type parseTupleType();
+
+  /// Parse a vector type.
+  VectorType parseVectorType();
+  ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+                                       bool allowDynamic = true);
+  ParseResult parseXInDimensionList();
+
+  //===--------------------------------------------------------------------===//
+  // Attribute Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an arbitrary attribute with an optional type.
+  Attribute parseAttribute(Type type = {});
+
+  /// Parse an attribute dictionary.
+  ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
+
+  /// Parse an extended attribute.
+  Attribute parseExtendedAttr(Type type);
+
+  /// Parse a float attribute.
+  Attribute parseFloatAttr(Type type, bool isNegative);
+
+  /// Parse a decimal or a hexadecimal literal, which can be either an integer
+  /// or a float attribute.
+  Attribute parseDecOrHexAttr(Type type, bool isNegative);
+
+  /// Parse an opaque elements attribute.
+  Attribute parseOpaqueElementsAttr();
+
+  /// Parse a dense elements attribute.
+  Attribute parseDenseElementsAttr();
+  ShapedType parseElementsLiteralType();
+
+  /// Parse a sparse elements attribute.
+  Attribute parseSparseElementsAttr();
+
+  //===--------------------------------------------------------------------===//
+  // Location Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an inline location.
+  ParseResult parseLocation(LocationAttr &loc);
+
+  /// Parse a raw location instance.
+  ParseResult parseLocationInstance(LocationAttr &loc);
+
+  /// Parse an optional trailing location.
+  ///
+  ///   trailing-location     ::= location?
+  ///
+  template <typename Owner>
+  ParseResult parseOptionalTrailingLocation(Owner *owner) {
+    // If there is a 'loc' we parse a trailing location.
+    if (!getToken().is(Token::kw_loc))
+      return success();
+
+    // Parse the location.
+    LocationAttr directLoc;
+    if (parseLocation(directLoc))
+      return failure();
+    owner->setLoc(directLoc);
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Affine Parsing
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
+                                                  IntegerSet &set);
+
+  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+  ParseResult
+  parseAffineMapOfSSAIds(AffineMap &map,
+                         llvm::function_ref<ParseResult(bool)> parseElement);
+
+private:
+  /// The Parser is subclassed and reinstantiated.  Do not add additional
+  /// non-trivial state here, add it to the ParserState class.
+  ParserState &state;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Parse a comma separated list of elements that must have at least one entry
+/// in it.
+ParseResult Parser::parseCommaSeparatedList(
+    const std::function<ParseResult()> &parseElement) {
+  // Non-empty case starts with an element.
+  if (parseElement())
+    return failure();
+
+  // Otherwise we have a list of comma separated elements.
+  while (consumeIf(Token::comma)) {
+    if (parseElement())
+      return failure();
+  }
+  return success();
+}
+
+/// Parse a comma-separated list of elements, terminated with an arbitrary
+/// token.  This allows empty lists if allowEmptyList is true.
+///
+///   abstract-list ::= rightToken                  // if allowEmptyList == true
+///   abstract-list ::= element (',' element)* rightToken
+///
+ParseResult Parser::parseCommaSeparatedListUntil(
+    Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
+    bool allowEmptyList) {
+  // Handle the empty case.
+  if (getToken().is(rightToken)) {
+    if (!allowEmptyList)
+      return emitError("expected list element");
+    consumeToken(rightToken);
+    return success();
+  }
+
+  if (parseCommaSeparatedList(parseElement) ||
+      parseToken(rightToken, "expected ',' or '" +
+                                 Token::getTokenSpelling(rightToken) + "'"))
+    return failure();
+
+  return success();
+}
+
+/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
+/// and may be recursive.  Return with the 'prettyName' StringRef encompasing
+/// the entire pretty name.
+///
+///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
+///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
+///                                  | '(' pretty-dialect-sym-contents+ ')'
+///                                  | '[' pretty-dialect-sym-contents+ ']'
+///                                  | '{' pretty-dialect-sym-contents+ '}'
+///                                  | '[^[<({>\])}\0]+'
+///
+ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
+  // Pretty symbol names are a relatively unstructured format that contains a
+  // series of properly nested punctuation, with anything else in the middle.
+  // Scan ahead to find it and consume it if successful, otherwise emit an
+  // error.
+  auto *curPtr = getTokenSpelling().data();
+
+  SmallVector<char, 8> nestedPunctuation;
+
+  // Scan over the nested punctuation, bailing out on error and consuming until
+  // we find the end.  We know that we're currently looking at the '<', so we
+  // can go until we find the matching '>' character.
+  assert(*curPtr == '<');
+  do {
+    char c = *curPtr++;
+    switch (c) {
+    case '\0':
+      // This also handles the EOF case.
+      return emitError("unexpected nul or EOF in pretty dialect name");
+    case '<':
+    case '[':
+    case '(':
+    case '{':
+      nestedPunctuation.push_back(c);
+      continue;
+
+    case '>':
+      if (nestedPunctuation.pop_back_val() != '<')
+        return emitError("unbalanced '>' character in pretty dialect name");
+      break;
+    case ']':
+      if (nestedPunctuation.pop_back_val() != '[')
+        return emitError("unbalanced ']' character in pretty dialect name");
+      break;
+    case ')':
+      if (nestedPunctuation.pop_back_val() != '(')
+        return emitError("unbalanced ')' character in pretty dialect name");
+      break;
+    case '}':
+      if (nestedPunctuation.pop_back_val() != '{')
+        return emitError("unbalanced '}' character in pretty dialect name");
+      break;
+
+    default:
+      continue;
+    }
+  } while (!nestedPunctuation.empty());
+
+  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
+  // consuming all this stuff, and return.
+  state.lex.resetPointer(curPtr);
+
+  unsigned length = curPtr - prettyName.begin();
+  prettyName = StringRef(prettyName.begin(), length);
+  consumeToken();
+  return success();
+}
+
+/// Parse an extended dialect symbol.
+template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
+static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
+                                  SymbolAliasMap &aliases,
+                                  CreateFn &&createSymbol) {
+  // Parse the dialect namespace.
+  StringRef identifier = p.getTokenSpelling().drop_front();
+  auto loc = p.getToken().getLoc();
+  p.consumeToken(identifierTok);
+
+  // If there is no '<' token following this, and if the typename contains no
+  // dot, then we are parsing a symbol alias.
+  if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
+    // Check for an alias for this type.
+    auto aliasIt = aliases.find(identifier);
+    if (aliasIt == aliases.end())
+      return (p.emitError("undefined symbol alias id '" + identifier + "'"),
+              nullptr);
+    return aliasIt->second;
+  }
+
+  // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
+  // a dot, then this is the "pretty" form.  If not, it is the verbose form that
+  // looks like <"...">.
+  std::string symbolData;
+  auto dialectName = identifier;
+
+  // Handle the verbose form, where "identifier" is a simple dialect name.
+  if (!identifier.contains('.')) {
+    // Consume the '<'.
+    if (p.parseToken(Token::less, "expected '<' in dialect type"))
+      return nullptr;
+
+    // Parse the symbol specific data.
+    if (p.getToken().isNot(Token::string))
+      return (p.emitError("expected string literal data in dialect symbol"),
+              nullptr);
+    symbolData = p.getToken().getStringValue();
+    loc = p.getToken().getLoc();
+    p.consumeToken(Token::string);
+
+    // Consume the '>'.
+    if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
+      return nullptr;
+  } else {
+    // Ok, the dialect name is the part of the identifier before the dot, the
+    // part after the dot is the dialect's symbol, or the start thereof.
+    auto dotHalves = identifier.split('.');
+    dialectName = dotHalves.first;
+    auto prettyName = dotHalves.second;
+
+    // If the dialect's symbol is followed immediately by a <, then lex the body
+    // of it into prettyName.
+    if (p.getToken().is(Token::less) &&
+        prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
+      if (p.parsePrettyDialectSymbolName(prettyName))
+        return nullptr;
+    }
+
+    symbolData = prettyName.str();
+  }
+
+  // Call into the provided symbol construction function.
+  auto encodedLoc = p.getEncodedSourceLocation(loc);
+  return createSymbol(dialectName, symbolData, encodedLoc);
+}
+
+//===----------------------------------------------------------------------===//
+// Error Handling
+//===----------------------------------------------------------------------===//
+
+InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
+  auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
+
+  // If we hit a parse error in response to a lexer error, then the lexer
+  // already reported the error.
+  if (getToken().is(Token::error))
+    diag.abandon();
+  return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// Token Parsing
+//===----------------------------------------------------------------------===//
+
+/// Consume the specified token if present and return success.  On failure,
+/// output a diagnostic and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+                               const Twine &message) {
+  if (consumeIf(expectedToken))
+    return success();
+  return emitError(message);
+}
+
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+/// Parse an arbitrary type.
+///
+///   type ::= function-type
+///          | non-function-type
+///
+Type Parser::parseType() {
+  if (getToken().is(Token::l_paren))
+    return parseFunctionType();
+  return parseNonFunctionType();
+}
+
+/// Parse a function result type.
+///
+///   function-result-type ::= type-list-parens
+///                          | non-function-type
+///
+ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
+  if (getToken().is(Token::l_paren))
+    return parseTypeListParens(elements);
+
+  Type t = parseNonFunctionType();
+  if (!t)
+    return failure();
+  elements.push_back(t);
+  return success();
+}
+
+/// Parse a list of types without an enclosing parenthesis.  The list must have
+/// at least one member.
+///
+///   type-list-no-parens ::=  type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseType();
+    elements.push_back(elt);
+    return elt ? success() : failure();
+  };
+
+  return parseCommaSeparatedList(parseElt);
+}
+
+/// Parse a parenthesized list of types.
+///
+///   type-list-parens ::= `(` `)`
+///                      | `(` type-list-no-parens `)`
+///
+ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return failure();
+
+  // Handle empty lists.
+  if (getToken().is(Token::r_paren))
+    return consumeToken(), success();
+
+  if (parseTypeListNoParens(elements) ||
+      parseToken(Token::r_paren, "expected ')'"))
+    return failure();
+  return success();
+}
+
+/// Parse a complex type.
+///
+///   complex-type ::= `complex` `<` type `>`
+///
+Type Parser::parseComplexType() {
+  consumeToken(Token::kw_complex);
+
+  // Parse the '<'.
+  if (parseToken(Token::less, "expected '<' in complex type"))
+    return nullptr;
+
+  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  auto elementType = parseType();
+  if (!elementType ||
+      parseToken(Token::greater, "expected '>' in complex type"))
+    return nullptr;
+
+  return ComplexType::getChecked(elementType, typeLocation);
+}
+
+/// Parse an extended type.
+///
+///   extended-type ::= (dialect-type | type-alias)
+///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
+///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
+///   type-alias    ::= `!` alias-name
+///
+Type Parser::parseExtendedType() {
+  return parseExtendedSymbol<Type>(
+      *this, Token::exclamation_identifier, state.typeAliasDefinitions,
+      [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type {
+        // If we found a registered dialect, then ask it to parse the type.
+        if (auto *dialect = state.context->getRegisteredDialect(dialectName))
+          return dialect->parseType(symbolData, loc);
+
+        // Otherwise, form a new opaque type.
+        return OpaqueType::getChecked(
+            Identifier::get(dialectName, state.context), symbolData,
+            state.context, loc);
+      });
+}
+
+/// Parse a function type.
+///
+///   function-type ::= type-list-parens `->` type-list
+///
+Type Parser::parseFunctionType() {
+  assert(getToken().is(Token::l_paren));
+
+  SmallVector<Type, 4> arguments, results;
+  if (parseTypeListParens(arguments) ||
+      parseToken(Token::arrow, "expected '->' in function type") ||
+      parseFunctionResultTypes(results))
+    return nullptr;
+
+  return builder.getFunctionType(arguments, results);
+}
+
+/// Parse a memref type.
+///
+///   memref-type ::= `memref` `<` dimension-list-ranked element-type
+///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+Type Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (parseToken(Token::less, "expected '<' in memref type"))
+    return nullptr;
+
+  SmallVector<int64_t, 4> dimensions;
+  if (parseDimensionListRanked(dimensions))
+    return nullptr;
+
+  // Parse the element type.
+  auto typeLoc = getToken().getLoc();
+  auto elementType = parseType();
+  if (!elementType)
+    return nullptr;
+
+  // Parse semi-affine-map-composition.
+  SmallVector<AffineMap, 2> affineMapComposition;
+  unsigned memorySpace = 0;
+  bool parsedMemorySpace = false;
+
+  auto parseElt = [&]() -> ParseResult {
+    if (getToken().is(Token::integer)) {
+      // Parse memory space.
+      if (parsedMemorySpace)
+        return emitError("multiple memory spaces specified in memref type");
+      auto v = getToken().getUnsignedIntegerValue();
+      if (!v.hasValue())
+        return emitError("invalid memory space in memref type");
+      memorySpace = v.getValue();
+      consumeToken(Token::integer);
+      parsedMemorySpace = true;
+    } else {
+      // Parse affine map.
+      if (parsedMemorySpace)
+        return emitError("affine map after memory space in memref type");
+      auto affineMap = parseAttribute();
+      if (!affineMap)
+        return failure();
+
+      // Verify that the parsed attribute is an affine map.
+      if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
+        affineMapComposition.push_back(affineMapAttr.getValue());
+      else
+        return emitError("expected affine map in memref type");
+    }
+    return success();
+  };
+
+  // Parse a list of mappings and address space if present.
+  if (consumeIf(Token::comma)) {
+    // Parse comma separated list of affine maps, followed by memory space.
+    if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+                                     /*allowEmptyList=*/false)) {
+      return nullptr;
+    }
+  } else {
+    if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
+      return nullptr;
+  }
+
+  return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
+                                memorySpace, getEncodedSourceLocation(typeLoc));
+}
+
+/// Parse any type except the function type.
+///
+///   non-function-type ::= integer-type
+///                       | index-type
+///                       | float-type
+///                       | extended-type
+///                       | vector-type
+///                       | tensor-type
+///                       | memref-type
+///                       | complex-type
+///                       | tuple-type
+///                       | none-type
+///
+///   index-type ::= `index`
+///   float-type ::= `f16` | `bf16` | `f32` | `f64`
+///   none-type ::= `none`
+///
+Type Parser::parseNonFunctionType() {
+  switch (getToken().getKind()) {
+  default:
+    return (emitError("expected non-function type"), nullptr);
+  case Token::kw_memref:
+    return parseMemRefType();
+  case Token::kw_tensor:
+    return parseTensorType();
+  case Token::kw_complex:
+    return parseComplexType();
+  case Token::kw_tuple:
+    return parseTupleType();
+  case Token::kw_vector:
+    return parseVectorType();
+  // integer-type
+  case Token::inttype: {
+    auto width = getToken().getIntTypeBitwidth();
+    if (!width.hasValue())
+      return (emitError("invalid integer width"), nullptr);
+    auto loc = getEncodedSourceLocation(getToken().getLoc());
+    consumeToken(Token::inttype);
+    return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
+  }
+
+  // float-type
+  case Token::kw_bf16:
+    consumeToken(Token::kw_bf16);
+    return builder.getBF16Type();
+  case Token::kw_f16:
+    consumeToken(Token::kw_f16);
+    return builder.getF16Type();
+  case Token::kw_f32:
+    consumeToken(Token::kw_f32);
+    return builder.getF32Type();
+  case Token::kw_f64:
+    consumeToken(Token::kw_f64);
+    return builder.getF64Type();
+
+  // index-type
+  case Token::kw_index:
+    consumeToken(Token::kw_index);
+    return builder.getIndexType();
+
+  // none-type
+  case Token::kw_none:
+    consumeToken(Token::kw_none);
+    return builder.getNoneType();
+
+  // extended type
+  case Token::exclamation_identifier:
+    return parseExtendedType();
+  }
+}
+
+/// Parse a tensor type.
+///
+///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
+///   dimension-list ::= dimension-list-ranked | `*x`
+///
+Type Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+
+  if (parseToken(Token::less, "expected '<' in tensor type"))
+    return nullptr;
+
+  bool isUnranked;
+  SmallVector<int64_t, 4> dimensions;
+
+  if (consumeIf(Token::star)) {
+    // This is an unranked tensor type.
+    isUnranked = true;
+
+    if (parseXInDimensionList())
+      return nullptr;
+
+  } else {
+    isUnranked = false;
+    if (parseDimensionListRanked(dimensions))
+      return nullptr;
+  }
+
+  // Parse the element type.
+  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  auto elementType = parseType();
+  if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
+    return nullptr;
+
+  if (isUnranked)
+    return UnrankedTensorType::getChecked(elementType, typeLocation);
+  return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
+}
+
+/// Parse a tuple type.
+///
+///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
+///
+Type Parser::parseTupleType() {
+  consumeToken(Token::kw_tuple);
+
+  // Parse the '<'.
+  if (parseToken(Token::less, "expected '<' in tuple type"))
+    return nullptr;
+
+  // Check for an empty tuple by directly parsing '>'.
+  if (consumeIf(Token::greater))
+    return TupleType::get(getContext());
+
+  // Parse the element types and the '>'.
+  SmallVector<Type, 4> types;
+  if (parseTypeListNoParens(types) ||
+      parseToken(Token::greater, "expected '>' in tuple type"))
+    return nullptr;
+
+  return TupleType::get(types, getContext());
+}
+
+/// Parse a vector type.
+///
+///   vector-type ::= `vector` `<` static-dimension-list primitive-type `>`
+///   static-dimension-list ::= (decimal-literal `x`)+
+///
+VectorType Parser::parseVectorType() {
+  consumeToken(Token::kw_vector);
+
+  if (parseToken(Token::less, "expected '<' in vector type"))
+    return nullptr;
+
+  SmallVector<int64_t, 4> dimensions;
+  if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
+    return nullptr;
+  if (dimensions.empty())
+    return (emitError("expected dimension size in vector type"), nullptr);
+
+  // Parse the element type.
+  auto typeLoc = getToken().getLoc();
+  auto elementType = parseType();
+  if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+    return nullptr;
+
+  return VectorType::getChecked(dimensions, elementType,
+                                getEncodedSourceLocation(typeLoc));
+}
+
+/// Parse a dimension list of a tensor or memref type.  This populates the
+/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
+/// errors out on `?` otherwise.
+///
+///   dimension-list-ranked ::= (dimension `x`)*
+///   dimension ::= `?` | decimal-literal
+///
+/// When `allowDynamic` is not set, this can be also used to parse
+///
+///   static-dimension-list ::= (decimal-literal `x`)*
+ParseResult
+Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+                                 bool allowDynamic) {
+  while (getToken().isAny(Token::integer, Token::question)) {
+    if (consumeIf(Token::question)) {
+      if (!allowDynamic)
+        return emitError("expected static shape");
+      dimensions.push_back(-1);
+    } else {
+      // Hexadecimal integer literals (starting with `0x`) are not allowed in
+      // aggregate type declarations.  Therefore, `0xf32` should be processed as
+      // a sequence of separate elements `0`, `x`, `f32`.
+      if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
+        // We can get here only if the token is an integer literal.  Hexadecimal
+        // integer literals can only start with `0x` (`1x` wouldn't lex as a
+        // literal, just `1` would, at which point we don't get into this
+        // branch).
+        assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
+        dimensions.push_back(0);
+        state.lex.resetPointer(getTokenSpelling().data() + 1);
+        consumeToken();
+      } else {
+        // Make sure this integer value is in bound and valid.
+        auto dimension = getToken().getUnsignedIntegerValue();
+        if (!dimension.hasValue())
+          return emitError("invalid dimension");
+        dimensions.push_back((int64_t)dimension.getValue());
+        consumeToken(Token::integer);
+      }
+    }
+
+    // Make sure we have an 'x' or something like 'xbf32'.
+    if (parseXInDimensionList())
+      return failure();
+  }
+
+  return success();
+}
+
+/// Parse an 'x' token in a dimension list, handling the case where the x is
+/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
+/// token.
+ParseResult Parser::parseXInDimensionList() {
+  if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
+    return emitError("expected 'x' in dimension list");
+
+  // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+  if (getTokenSpelling().size() != 1)
+    state.lex.resetPointer(getTokenSpelling().data() + 1);
+
+  // Consume the 'x'.
+  consumeToken(Token::bare_identifier);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute parsing.
+//===----------------------------------------------------------------------===//
+
+/// Parse an arbitrary attribute.
+///
+///  attribute-value ::= `unit`
+///                    | bool-literal
+///                    | integer-literal (`:` (index-type | integer-type))?
+///                    | float-literal (`:` float-type)?
+///                    | string-literal (`:` type)?
+///                    | type
+///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
+///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
+///                    | symbol-ref-id
+///                    | `dense` `<` attribute-value `>` `:`
+///                      (tensor-type | vector-type)
+///                    | `sparse` `<` attribute-value `,` attribute-value `>`
+///                      `:` (tensor-type | vector-type)
+///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
+///                      `>` `:` (tensor-type | vector-type)
+///                    | extended-attribute
+///
+Attribute Parser::parseAttribute(Type type) {
+  switch (getToken().getKind()) {
+  // Parse an AffineMap or IntegerSet attribute.
+  case Token::l_paren: {
+    // Try to parse an affine map or an integer set reference.
+    AffineMap map;
+    IntegerSet set;
+    if (parseAffineMapOrIntegerSetReference(map, set))
+      return nullptr;
+    if (map)
+      return builder.getAffineMapAttr(map);
+    assert(set);
+    return builder.getIntegerSetAttr(set);
+  }
+
+  // Parse an array attribute.
+  case Token::l_square: {
+    consumeToken(Token::l_square);
+
+    SmallVector<Attribute, 4> elements;
+    auto parseElt = [&]() -> ParseResult {
+      elements.push_back(parseAttribute());
+      return elements.back() ? success() : failure();
+    };
+
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+      return nullptr;
+    return builder.getArrayAttr(elements);
+  }
+
+  // Parse a boolean attribute.
+  case Token::kw_false:
+    consumeToken(Token::kw_false);
+    return builder.getBoolAttr(false);
+  case Token::kw_true:
+    consumeToken(Token::kw_true);
+    return builder.getBoolAttr(true);
+
+  // Parse a dense elements attribute.
+  case Token::kw_dense:
+    return parseDenseElementsAttr();
+
+  // Parse a dictionary attribute.
+  case Token::l_brace: {
+    SmallVector<NamedAttribute, 4> elements;
+    if (parseAttributeDict(elements))
+      return nullptr;
+    return builder.getDictionaryAttr(elements);
+  }
+
+  // Parse an extended attribute, i.e. alias or dialect attribute.
+  case Token::hash_identifier:
+    return parseExtendedAttr(type);
+
+  // Parse floating point and integer attributes.
+  case Token::floatliteral:
+    return parseFloatAttr(type, /*isNegative=*/false);
+  case Token::integer:
+    return parseDecOrHexAttr(type, /*isNegative=*/false);
+  case Token::minus: {
+    consumeToken(Token::minus);
+    if (getToken().is(Token::integer))
+      return parseDecOrHexAttr(type, /*isNegative=*/true);
+    if (getToken().is(Token::floatliteral))
+      return parseFloatAttr(type, /*isNegative=*/true);
+
+    return (emitError("expected constant integer or floating point value"),
+            nullptr);
+  }
+
+  // Parse a location attribute.
+  case Token::kw_loc: {
+    LocationAttr attr;
+    return failed(parseLocation(attr)) ? Attribute() : attr;
+  }
+
+  // Parse an opaque elements attribute.
+  case Token::kw_opaque:
+    return parseOpaqueElementsAttr();
+
+  // Parse a sparse elements attribute.
+  case Token::kw_sparse:
+    return parseSparseElementsAttr();
+
+  // Parse a string attribute.
+  case Token::string: {
+    auto val = getToken().getStringValue();
+    consumeToken(Token::string);
+    // Parse the optional trailing colon type if one wasn't explicitly provided.
+    if (!type && consumeIf(Token::colon) && !(type = parseType()))
+      return Attribute();
+
+    return type ? StringAttr::get(val, type)
+                : StringAttr::get(val, getContext());
+  }
+
+  // Parse a symbol reference attribute.
+  case Token::at_identifier: {
+    auto nameStr = getTokenSpelling();
+    consumeToken(Token::at_identifier);
+    return builder.getSymbolRefAttr(nameStr.drop_front());
+  }
+
+  // Parse a 'unit' attribute.
+  case Token::kw_unit:
+    consumeToken(Token::kw_unit);
+    return builder.getUnitAttr();
+
+  default:
+    // Parse a type attribute.
+    if (Type type = parseType())
+      return builder.getTypeAttr(type);
+    return nullptr;
+  }
+}
+
+/// Attribute dictionary.
+///
+///   attribute-dict ::= `{` `}`
+///                    | `{` attribute-entry (`,` attribute-entry)* `}`
+///   attribute-entry ::= bare-id `=` attribute-value
+///
+ParseResult
+Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
+  if (!consumeIf(Token::l_brace))
+    return failure();
+
+  auto parseElt = [&]() -> ParseResult {
+    // We allow keywords as attribute names.
+    if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
+        !getToken().isKeyword())
+      return emitError("expected attribute name");
+    Identifier nameId = builder.getIdentifier(getTokenSpelling());
+    consumeToken();
+
+    // Try to parse the '=' for the attribute value.
+    if (!consumeIf(Token::equal)) {
+      // If there is no '=', we treat this as a unit attribute.
+      attributes.push_back({nameId, builder.getUnitAttr()});
+      return success();
+    }
+
+    auto attr = parseAttribute();
+    if (!attr)
+      return failure();
+
+    attributes.push_back({nameId, attr});
+    return success();
+  };
+
+  if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+    return failure();
+
+  return success();
+}
+
+/// Parse an extended attribute.
+///
+///   extended-attribute ::= (dialect-attribute | attribute-alias)
+///   dialect-attribute  ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
+///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
+///   attribute-alias    ::= `#` alias-name
+///
+Attribute Parser::parseExtendedAttr(Type type) {
+  Attribute attr = parseExtendedSymbol<Attribute>(
+      *this, Token::hash_identifier, state.attributeAliasDefinitions,
+      [&](StringRef dialectName, StringRef symbolData,
+          Location loc) -> Attribute {
+        // Parse an optional trailing colon type.
+        Type attrType = type;
+        if (consumeIf(Token::colon) && !(attrType = parseType()))
+          return Attribute();
+
+        // If we found a registered dialect, then ask it to parse the attribute.
+        if (auto *dialect = state.context->getRegisteredDialect(dialectName))
+          return dialect->parseAttribute(symbolData, attrType, loc);
+
+        // Otherwise, form a new opaque attribute.
+        return OpaqueAttr::getChecked(
+            Identifier::get(dialectName, state.context), symbolData,
+            attrType ? attrType : NoneType::get(state.context), loc);
+      });
+
+  // Ensure that the attribute has the same type as requested.
+  if (attr && type && attr.getType() != type) {
+    emitError("attribute type different than expected: expected ")
+        << type << ", but got " << attr.getType();
+    return nullptr;
+  }
+  return attr;
+}
+
+/// Parse a float attribute.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+  auto val = getToken().getFloatingPointValue();
+  if (!val.hasValue())
+    return (emitError("floating point value too large for attribute"), nullptr);
+  consumeToken(Token::floatliteral);
+  if (!type) {
+    // Default to F64 when no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getF64Type();
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+  if (!type.isa<FloatType>())
+    return (emitError("floating point value not valid for specified type"),
+            nullptr);
+  return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
+}
+
+/// Construct a float attribute bitwise equivalent to the integer literal.
+static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+                                              uint64_t value) {
+  int width = type.getIntOrFloatBitWidth();
+  APInt apInt(width, value);
+  if (apInt != value) {
+    p->emitError("hexadecimal float constant out of range for type");
+    return nullptr;
+  }
+  APFloat apFloat(type.getFloatSemantics(), apInt);
+  return p->builder.getFloatAttr(type, apFloat);
+}
+
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
+  auto val = getToken().getUInt64IntegerValue();
+  if (!val.hasValue())
+    return (emitError("integer constant out of range for attribute"), nullptr);
+
+  // Remember if the literal is hexadecimal.
+  StringRef spelling = getToken().getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+  consumeToken(Token::integer);
+  if (!type) {
+    // Default to i64 if not type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getIntegerType(64);
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+
+  // Hexadecimal representation of float literals is not supported for bfloat16.
+  // When supported, the literal should be unsigned.
+  auto floatType = type.dyn_cast<FloatType>();
+  if (floatType && !type.isBF16()) {
+    if (isNegative) {
+      emitError("hexadecimal float literal should not have a leading minus");
+      return nullptr;
+    }
+    if (!isHex) {
+      emitError("unexpected decimal integer literal for a float attribute")
+              .attachNote()
+          << "add a trailing dot to make the literal a float";
+      return nullptr;
+    }
+
+    // Construct a float attribute bitwise equivalent to the integer literal.
+    return buildHexadecimalFloatLiteral(this, floatType, *val);
+  }
+
+  if (!type.isIntOrIndex())
+    return (emitError("integer literal not valid for specified type"), nullptr);
+
+  // Parse the integer literal.
+  int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+  APInt apInt(width, *val, isNegative);
+  if (apInt != *val)
+    return (emitError("integer constant out of range for attribute"), nullptr);
+
+  // Otherwise construct an integer attribute.
+  if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)
+    return (emitError("integer constant out of range for attribute"), nullptr);
+
+  return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
+}
+
+/// Parse an opaque elements attribute.
+Attribute Parser::parseOpaqueElementsAttr() {
+  consumeToken(Token::kw_opaque);
+  if (parseToken(Token::less, "expected '<' after 'opaque'"))
+    return nullptr;
+
+  if (getToken().isNot(Token::string))
+    return (emitError("expected dialect namespace"), nullptr);
+
+  auto name = getToken().getStringValue();
+  auto *dialect = builder.getContext()->getRegisteredDialect(name);
+  // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+  // attribute. Otherwise, it can't be roundtripped without having the dialect
+  // registered.
+  if (!dialect)
+    return (emitError("no registered dialect with namespace '" + name + "'"),
+            nullptr);
+
+  consumeToken(Token::string);
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  if (getToken().getKind() != Token::string)
+    return (emitError("opaque string should start with '0x'"), nullptr);
+
+  auto val = getToken().getStringValue();
+  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+    return (emitError("opaque string should start with '0x'"), nullptr);
+
+  val = val.substr(2);
+  if (!llvm::all_of(val, llvm::isHexDigit))
+    return (emitError("opaque string only contains hex digits"), nullptr);
+
+  consumeToken(Token::string);
+  if (parseToken(Token::greater, "expected '>'") ||
+      parseToken(Token::colon, "expected ':'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+
+  return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
+}
+
+namespace {
+class TensorLiteralParser {
+public:
+  TensorLiteralParser(Parser &p) : p(p) {}
+
+  ParseResult parse() {
+    if (p.getToken().is(Token::l_square))
+      return parseList(shape);
+    return parseElement();
+  }
+
+  /// Build a dense attribute instance with the parsed elements and the given
+  /// shaped type.
+  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
+
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  enum class ElementKind { Boolean, Integer, Float };
+
+  /// Return a string to represent the given element kind.
+  const char *getElementKindStr(ElementKind kind) {
+    switch (kind) {
+    case ElementKind::Boolean:
+      return "'boolean'";
+    case ElementKind::Integer:
+      return "'integer'";
+    case ElementKind::Float:
+      return "'float'";
+    }
+    llvm_unreachable("unknown element kind");
+  }
+
+  /// Build a Dense Integer attribute for the given type.
+  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
+                               IntegerType eltTy);
+
+  /// Build a Dense Float attribute for the given type.
+  DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
+                                 FloatType eltTy);
+
+  /// Parse a single element, returning failure if it isn't a valid element
+  /// literal. For example:
+  /// parseElement(1) -> Success, 1
+  /// parseElement([1]) -> Failure
+  ParseResult parseElement();
+
+  /// Parse a list of either lists or elements, returning the dimensions of the
+  /// parsed sub-tensors in dims. For example:
+  ///   parseList([1, 2, 3]) -> Success, [3]
+  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+  ///   parseList([[1, 2], 3]) -> Failure
+  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+  ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims);
+
+  Parser &p;
+
+  /// The shape inferred from the parsed elements.
+  SmallVector<int64_t, 4> shape;
+
+  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
+  std::vector<std::pair<bool, Token>> storage;
+
+  /// A flag that indicates the type of elements that have been parsed.
+  llvm::Optional<ElementKind> knownEltKind;
+};
+} // namespace
+
+/// Build a dense attribute instance with the parsed elements and the given
+/// shaped type.
+DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
+                                               ShapedType type) {
+  // Check that the parsed storage size has the same number of elements to the
+  // type, or is a known splat.
+  if (!shape.empty() && getShape() != type.getShape()) {
+    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
+                     << "]) does not match type ([" << type.getShape() << "])";
+    return nullptr;
+  }
+
+  // If the type is an integer, build a set of APInt values from the storage
+  // with the correct bitwidth.
+  if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
+    return getIntAttr(loc, type, intTy);
+
+  // Otherwise, this must be a floating point type.
+  auto floatTy = type.getElementType().dyn_cast<FloatType>();
+  if (!floatTy) {
+    p.emitError(loc) << "expected floating-point or integer element type, got "
+                     << type.getElementType();
+    return nullptr;
+  }
+  return getFloatAttr(loc, type, floatTy);
+}
+
+/// Build a Dense Integer attribute for the given type.
+DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
+                                                  ShapedType type,
+                                                  IntegerType eltTy) {
+  std::vector<APInt> intElements;
+  intElements.reserve(storage.size());
+  for (const auto &signAndToken : storage) {
+    bool isNegative = signAndToken.first;
+    const Token &token = signAndToken.second;
+
+    // Check to see if floating point values were parsed.
+    if (token.is(Token::floatliteral)) {
+      p.emitError() << "expected integer elements, but parsed floating-point";
+      return nullptr;
+    }
+
+    assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
+           "unexpected token type");
+    if (token.isAny(Token::kw_true, Token::kw_false)) {
+      if (!eltTy.isInteger(1))
+        p.emitError() << "expected i1 type for 'true' or 'false' values";
+      APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
+                  /*isSigned=*/false);
+      intElements.push_back(apInt);
+      continue;
+    }
+
+    // Create APInt values for each element with the correct bitwidth.
+    auto val = token.getUInt64IntegerValue();
+    if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0
+                                       : (int64_t)val.getValue() < 0)) {
+      p.emitError(token.getLoc(),
+                  "integer constant out of range for attribute");
+      return nullptr;
+    }
+    APInt apInt(eltTy.getWidth(), val.getValue(), isNegative);
+    if (apInt != val.getValue())
+      return (p.emitError("integer constant out of range for type"), nullptr);
+    intElements.push_back(isNegative ? -apInt : apInt);
+  }
+
+  return DenseElementsAttr::get(type, intElements);
+}
+
+/// Build a Dense Float attribute for the given type.
+DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
+                                                    ShapedType type,
+                                                    FloatType eltTy) {
+  std::vector<Attribute> floatValues;
+  floatValues.reserve(storage.size());
+  for (const auto &signAndToken : storage) {
+    bool isNegative = signAndToken.first;
+    const Token &token = signAndToken.second;
+
+    // Handle hexadecimal float literals.
+    if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
+      if (isNegative) {
+        p.emitError(token.getLoc())
+            << "hexadecimal float literal should not have a leading minus";
+        return nullptr;
+      }
+      auto val = token.getUInt64IntegerValue();
+      if (!val.hasValue()) {
+        p.emitError("hexadecimal float constant out of range for attribute");
+        return nullptr;
+      }
+      FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+      if (!attr)
+        return nullptr;
+      floatValues.push_back(attr);
+      continue;
+    }
+
+    // Check to see if any decimal integers or booleans were parsed.
+    if (!token.is(Token::floatliteral)) {
+      p.emitError() << "expected floating-point elements, but parsed integer";
+      return nullptr;
+    }
+
+    // Build the float values from tokens.
+    auto val = token.getFloatingPointValue();
+    if (!val.hasValue()) {
+      p.emitError("floating point value too large for attribute");
+      return nullptr;
+    }
+    floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
+  }
+
+  return DenseElementsAttr::get(type, floatValues);
+}
+
+ParseResult TensorLiteralParser::parseElement() {
+  switch (p.getToken().getKind()) {
+  // Parse a boolean element.
+  case Token::kw_true:
+  case Token::kw_false:
+  case Token::floatliteral:
+  case Token::integer:
+    storage.emplace_back(/*isNegative=*/false, p.getToken());
+    p.consumeToken();
+    break;
+
+  // Parse a signed integer or a negative floating-point element.
+  case Token::minus:
+    p.consumeToken(Token::minus);
+    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+      return p.emitError("expected integer or floating point literal");
+    storage.emplace_back(/*isNegative=*/true, p.getToken());
+    p.consumeToken();
+    break;
+
+  default:
+    return p.emitError("expected element literal of primitive type");
+  }
+
+  return success();
+}
+
+/// Parse a list of either lists or elements, returning the dimensions of the
+/// parsed sub-tensors in dims. For example:
+///   parseList([1, 2, 3]) -> Success, [3]
+///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+///   parseList([[1, 2], 3]) -> Failure
+///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ParseResult
+TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
+  p.consumeToken(Token::l_square);
+
+  auto checkDims =
+      [&](const llvm::SmallVectorImpl<int64_t> &prevDims,
+          const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+    if (prevDims == newDims)
+      return success();
+    return p.emitError("tensor literal is invalid; ranks are not consistent "
+                       "between elements");
+  };
+
+  bool first = true;
+  llvm::SmallVector<int64_t, 4> newDims;
+  unsigned size = 0;
+  auto parseCommaSeparatedList = [&]() -> ParseResult {
+    llvm::SmallVector<int64_t, 4> thisDims;
+    if (p.getToken().getKind() == Token::l_square) {
+      if (parseList(thisDims))
+        return failure();
+    } else if (parseElement()) {
+      return failure();
+    }
+    ++size;
+    if (!first)
+      return checkDims(newDims, thisDims);
+    newDims = thisDims;
+    first = false;
+    return success();
+  };
+  if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+    return failure();
+
+  // Return the sublists' dimensions with 'size' prepended.
+  dims.clear();
+  dims.push_back(size);
+  dims.append(newDims.begin(), newDims.end());
+  return success();
+}
+
+/// Parse a dense elements attribute.
+Attribute Parser::parseDenseElementsAttr() {
+  consumeToken(Token::kw_dense);
+  if (parseToken(Token::less, "expected '<' after 'dense'"))
+    return nullptr;
+
+  // Parse the literal data.
+  TensorLiteralParser literalParser(*this);
+  if (literalParser.parse())
+    return nullptr;
+
+  if (parseToken(Token::greater, "expected '>'") ||
+      parseToken(Token::colon, "expected ':'"))
+    return nullptr;
+
+  auto typeLoc = getToken().getLoc();
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+  return literalParser.getAttr(typeLoc, type);
+}
+
+/// Shaped type for elements attribute.
+///
+///   elements-literal-type ::= vector-type | ranked-tensor-type
+///
+/// This method also checks the type has static shape.
+ShapedType Parser::parseElementsLiteralType() {
+  auto type = parseType();
+  if (!type)
+    return nullptr;
+
+  if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+    emitError("elements literal must be a ranked tensor or vector type");
+    return nullptr;
+  }
+
+  auto sType = type.cast<ShapedType>();
+  if (!sType.hasStaticShape())
+    return (emitError("elements literal type must have static shape"), nullptr);
+
+  return sType;
+}
+
+/// Parse a sparse elements attribute.
+Attribute Parser::parseSparseElementsAttr() {
+  consumeToken(Token::kw_sparse);
+  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+    return nullptr;
+
+  /// Parse indices
+  auto indicesLoc = getToken().getLoc();
+  TensorLiteralParser indiceParser(*this);
+  if (indiceParser.parse())
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  /// Parse values.
+  auto valuesLoc = getToken().getLoc();
+  TensorLiteralParser valuesParser(*this);
+  if (valuesParser.parse())
+    return nullptr;
+
+  if (parseToken(Token::greater, "expected '>'") ||
+      parseToken(Token::colon, "expected ':'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+
+  // If the indices are a splat, i.e. the literal parser parsed an element and
+  // not a list, we set the shape explicitly. The indices are represented by a
+  // 2-dimensional shape where the second dimension is the rank of the type.
+  // Given that the parsed indices is a splat, we know that we only have one
+  // indice and thus one for the first dimension.
+  auto indiceEltType = builder.getIntegerType(64);
+  ShapedType indicesType;
+  if (indiceParser.getShape().empty()) {
+    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+  } else {
+    // Otherwise, set the shape to the one parsed by the literal parser.
+    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
+  }
+  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
+
+  // If the values are a splat, set the shape explicitly based on the number of
+  // indices. The number of indices is encoded in the first dimension of the
+  // indice shape type.
+  auto valuesEltType = type.getElementType();
+  ShapedType valuesType =
+      valuesParser.getShape().empty()
+          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
+          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
+  auto values = valuesParser.getAttr(valuesLoc, valuesType);
+
+  /// Sanity check.
+  if (valuesType.getRank() != 1)
+    return (emitError("expected 1-d tensor for values"), nullptr);
+
+  auto sameShape = (indicesType.getRank() == 1) ||
+                   (type.getRank() == indicesType.getDimSize(1));
+  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+  if (!sameShape || !sameElementNum) {
+    emitError() << "expected shape ([" << type.getShape()
+                << "]); inferred shape of indices literal (["
+                << indicesType.getShape()
+                << "]); inferred shape of values literal (["
+                << valuesType.getShape() << "])";
+    return nullptr;
+  }
+
+  // Build the sparse elements attribute by the indices and values.
+  return SparseElementsAttr::get(type, indices, values);
+}
+
+//===----------------------------------------------------------------------===//
+// Location parsing.
+//===----------------------------------------------------------------------===//
+
+/// Parse a location.
+///
+///   location           ::= `loc` inline-location
+///   inline-location    ::= '(' location-inst ')'
+///
+ParseResult Parser::parseLocation(LocationAttr &loc) {
+  // Check for 'loc' identifier.
+  if (parseToken(Token::kw_loc, "expected 'loc' keyword"))
+    return emitError();
+
+  // Parse the inline-location.
+  if (parseToken(Token::l_paren, "expected '(' in inline location") ||
+      parseLocationInstance(loc) ||
+      parseToken(Token::r_paren, "expected ')' in inline location"))
+    return failure();
+  return success();
+}
+
+/// Specific location instances.
+///
+/// location-inst ::= filelinecol-location |
+///                   name-location |
+///                   callsite-location |
+///                   fused-location |
+///                   unknown-location
+/// filelinecol-location ::= string-literal ':' integer-literal
+///                                         ':' integer-literal
+/// name-location ::= string-literal
+/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')'
+/// fused-location ::= fused ('<' attribute-value '>')?
+///                    '[' location-inst (location-inst ',')* ']'
+/// unknown-location ::= 'unknown'
+///
+ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
+  auto *ctx = getContext();
+
+  // Handle either name or filelinecol locations.
+  if (getToken().is(Token::string)) {
+    auto str = getToken().getStringValue();
+    consumeToken(Token::string);
+
+    // If the next token is ':' this is a filelinecol location.
+    if (consumeIf(Token::colon)) {
+      // Parse the line number.
+      if (getToken().isNot(Token::integer))
+        return emitError("expected integer line number in FileLineColLoc");
+      auto line = getToken().getUnsignedIntegerValue();
+      if (!line.hasValue())
+        return emitError("expected integer line number in FileLineColLoc");
+      consumeToken(Token::integer);
+
+      // Parse the ':'.
+      if (parseToken(Token::colon, "expected ':' in FileLineColLoc"))
+        return failure();
+
+      // Parse the column number.
+      if (getToken().isNot(Token::integer))
+        return emitError("expected integer column number in FileLineColLoc");
+      auto column = getToken().getUnsignedIntegerValue();
+      if (!column.hasValue())
+        return emitError("expected integer column number in FileLineColLoc");
+      consumeToken(Token::integer);
+
+      loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx);
+      return success();
+    }
+
+    // Otherwise, this is a NameLoc.
+
+    // Check for a child location.
+    if (consumeIf(Token::l_paren)) {
+      auto childSourceLoc = getToken().getLoc();
+
+      // Parse the child location.
+      LocationAttr childLoc;
+      if (parseLocationInstance(childLoc))
+        return failure();
+
+      // The child must not be another NameLoc.
+      if (childLoc.isa<NameLoc>())
+        return emitError(childSourceLoc,
+                         "child of NameLoc cannot be another NameLoc");
+      loc = NameLoc::get(Identifier::get(str, ctx), childLoc, ctx);
+
+      // Parse the closing ')'.
+      if (parseToken(Token::r_paren,
+                     "expected ')' after child location of NameLoc"))
+        return failure();
+    } else {
+      loc = NameLoc::get(Identifier::get(str, ctx), ctx);
+    }
+
+    return success();
+  }
+
+  // Check for a 'unknown' for an unknown location.
+  if (getToken().is(Token::bare_identifier) &&
+      getToken().getSpelling() == "unknown") {
+    consumeToken(Token::bare_identifier);
+    loc = UnknownLoc::get(ctx);
+    return success();
+  }
+
+  // If the token is 'fused', then this is a fused location.
+  if (getToken().is(Token::bare_identifier) &&
+      getToken().getSpelling() == "fused") {
+    consumeToken(Token::bare_identifier);
+
+    // Try to parse the optional metadata.
+    Attribute metadata;
+    if (consumeIf(Token::less)) {
+      metadata = parseAttribute();
+      if (!metadata)
+        return emitError("expected valid attribute metadata");
+      // Parse the '>' token.
+      if (parseToken(Token::greater,
+                     "expected '>' after fused location metadata"))
+        return failure();
+    }
+
+    llvm::SmallVector<Location, 4> locations;
+    auto parseElt = [&] {
+      LocationAttr newLoc;
+      if (parseLocationInstance(newLoc))
+        return failure();
+      locations.push_back(newLoc);
+      return success();
+    };
+
+    if (parseToken(Token::l_square, "expected '[' in fused location") ||
+        parseCommaSeparatedList(parseElt) ||
+        parseToken(Token::r_square, "expected ']' in fused location"))
+      return failure();
+
+    // Return the fused location.
+    loc = FusedLoc::get(locations, metadata, getContext());
+    return success();
+  }
+
+  // Check for the 'callsite' signifying a callsite location.
+  if (getToken().is(Token::bare_identifier) &&
+      getToken().getSpelling() == "callsite") {
+    consumeToken(Token::bare_identifier);
+
+    // Parse the '('.
+    if (parseToken(Token::l_paren, "expected '(' in callsite location"))
+      return failure();
+
+    // Parse the callee location.
+    LocationAttr calleeLoc;
+    if (parseLocationInstance(calleeLoc))
+      return failure();
+
+    // Parse the 'at'.
+    if (getToken().isNot(Token::bare_identifier) ||
+        getToken().getSpelling() != "at")
+      return emitError("expected 'at' in callsite location");
+    consumeToken(Token::bare_identifier);
+
+    // Parse the caller location.
+    LocationAttr callerLoc;
+    if (parseLocationInstance(callerLoc))
+      return failure();
+
+    // Parse the ')'.
+    if (parseToken(Token::r_paren, "expected ')' in callsite location"))
+      return failure();
+
+    // Return the callsite location.
+    loc = CallSiteLoc::get(calleeLoc, callerLoc, ctx);
+    return success();
+  }
+
+  return emitError("expected location instance");
+}
+
+//===----------------------------------------------------------------------===//
+// Affine parsing.
+//===----------------------------------------------------------------------===//
+
+/// Lower precedence ops (all at the same precedence level). LNoOp is false in
+/// the boolean sense.
+enum AffineLowPrecOp {
+  /// Null value.
+  LNoOp,
+  Add,
+  Sub
+};
+
+/// Higher precedence ops - all at the same precedence level. HNoOp is false
+/// in the boolean sense.
+enum AffineHighPrecOp {
+  /// Null value.
+  HNoOp,
+  Mul,
+  FloorDiv,
+  CeilDiv,
+  Mod
+};
+
+namespace {
+/// This is a specialized parser for affine structures (affine maps, affine
+/// expressions, and integer sets), maintaining the state transient to their
+/// bodies.
+class AffineParser : public Parser {
+public:
+  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+               llvm::function_ref<ParseResult(bool)> parseElement = nullptr)
+      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+        parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}
+
+  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
+  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOfSSAIds(AffineMap &map);
+  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+                              unsigned &numDims);
+
+private:
+  // Binary affine op parsing.
+  AffineLowPrecOp consumeIfLowPrecOp();
+  AffineHighPrecOp consumeIfHighPrecOp();
+
+  // Identifier lists for polyhedral structures.
+  ParseResult parseDimIdList(unsigned &numDims);
+  ParseResult parseSymbolIdList(unsigned &numSymbols);
+  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols);
+  ParseResult parseIdentifierDefinition(AffineExpr idExpr);
+
+  AffineExpr parseAffineExpr();
+  AffineExpr parseParentheticalExpr();
+  AffineExpr parseNegateExpression(AffineExpr lhs);
+  AffineExpr parseIntegerExpr();
+  AffineExpr parseBareIdExpr();
+  AffineExpr parseSSAIdExpr(bool isSymbol);
+  AffineExpr parseSymbolSSAIdExpr();
+
+  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs, SMLoc opLoc);
+  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs);
+  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
+  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
+  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
+                                       SMLoc llhsOpLoc);
+  AffineExpr parseAffineConstraint(bool *isEq);
+
+private:
+  bool allowParsingSSAIds;
+  llvm::function_ref<ParseResult(bool)> parseElement;
+  unsigned numDimOperands;
+  unsigned numSymbolOperands;
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+};
+} // end anonymous namespace
+
+/// Create an affine binary high precedence op expression (mul's, div's, mod).
+/// opLoc is the location of the op token to be used to report errors
+/// for non-conforming expressions.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs,
+                                               SMLoc opLoc) {
+  // TODO: make the error location info accurate.
+  switch (op) {
+  case Mul:
+    if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: at least one of the multiply "
+                       "operands has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs * rhs;
+  case FloorDiv:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of floordiv "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs.floorDiv(rhs);
+  case CeilDiv:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of ceildiv "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs.ceilDiv(rhs);
+  case Mod:
+    if (!rhs.isSymbolicOrConstant()) {
+      emitError(opLoc, "non-affine expression: right operand of mod "
+                       "has to be either a constant or symbolic");
+      return nullptr;
+    }
+    return lhs % rhs;
+  case HNoOp:
+    llvm_unreachable("can't create affine expression for null high prec op");
+    return nullptr;
+  }
+  llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Create an affine binary low precedence op expression (add, sub).
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs) {
+  switch (op) {
+  case AffineLowPrecOp::Add:
+    return lhs + rhs;
+  case AffineLowPrecOp::Sub:
+    return lhs - rhs;
+  case AffineLowPrecOp::LNoOp:
+    llvm_unreachable("can't create affine expression for null low prec op");
+    return nullptr;
+  }
+  llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// Consume this token if it is a lower precedence affine op (there are only
+/// two precedence levels).
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
+  switch (getToken().getKind()) {
+  case Token::plus:
+    consumeToken(Token::plus);
+    return AffineLowPrecOp::Add;
+  case Token::minus:
+    consumeToken(Token::minus);
+    return AffineLowPrecOp::Sub;
+  default:
+    return AffineLowPrecOp::LNoOp;
+  }
+}
+
+/// Consume this token if it is a higher precedence affine op (there are only
+/// two precedence levels)
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
+  switch (getToken().getKind()) {
+  case Token::star:
+    consumeToken(Token::star);
+    return Mul;
+  case Token::kw_floordiv:
+    consumeToken(Token::kw_floordiv);
+    return FloorDiv;
+  case Token::kw_ceildiv:
+    consumeToken(Token::kw_ceildiv);
+    return CeilDiv;
+  case Token::kw_mod:
+    consumeToken(Token::kw_mod);
+    return Mod;
+  default:
+    return HNoOp;
+  }
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+///   expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token that will be used to
+/// report an error for non-conforming expressions.
+AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
+                                                   AffineHighPrecOp llhsOp,
+                                                   SMLoc llhsOpLoc) {
+  AffineExpr lhs = parseAffineOperandExpr(llhs);
+  if (!lhs)
+    return nullptr;
+
+  // Found an LHS. Parse the remaining expression.
+  auto opLoc = getToken().getLoc();
+  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+    if (llhs) {
+      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
+      if (!expr)
+        return nullptr;
+      return parseAffineHighPrecOpExpr(expr, op, opLoc);
+    }
+    // No LLHS, get RHS
+    return parseAffineHighPrecOpExpr(lhs, op, opLoc);
+  }
+
+  // This is the last operand in this expression.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse an affine expression inside parentheses.
+///
+///   affine-expr ::= `(` affine-expr `)`
+AffineExpr AffineParser::parseParentheticalExpr() {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return nullptr;
+  if (getToken().is(Token::r_paren))
+    return (emitError("no expression inside parentheses"), nullptr);
+
+  auto expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')'"))
+    return nullptr;
+
+  return expr;
+}
+
+/// Parse the negation expression.
+///
+///   affine-expr ::= `-` affine-expr
+AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
+  if (parseToken(Token::minus, "expected '-'"))
+    return nullptr;
+
+  AffineExpr operand = parseAffineOperandExpr(lhs);
+  // Since negation has the highest precedence of all ops (including high
+  // precedence ops) but lower than parentheses, we are only going to use
+  // parseAffineOperandExpr instead of parseAffineExpr here.
+  if (!operand)
+    // Extra error message although parseAffineOperandExpr would have
+    // complained. Leads to a better diagnostic.
+    return (emitError("missing operand of negation"), nullptr);
+  return (-1) * operand;
+}
+
+/// Parse a bare id that may appear in an affine expression.
+///
+///   affine-expr ::= bare-id
+AffineExpr AffineParser::parseBareIdExpr() {
+  if (getToken().isNot(Token::bare_identifier))
+    return (emitError("expected bare identifier"), nullptr);
+
+  StringRef sRef = getTokenSpelling();
+  for (auto entry : dimsAndSymbols) {
+    if (entry.first == sRef) {
+      consumeToken(Token::bare_identifier);
+      return entry.second;
+    }
+  }
+
+  return (emitError("use of undeclared identifier"), nullptr);
+}
+
+/// Parse an SSA id which may appear in an affine expression.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+  if (!allowParsingSSAIds)
+    return (emitError("unexpected ssa identifier"), nullptr);
+  if (getToken().isNot(Token::percent_identifier))
+    return (emitError("expected ssa identifier"), nullptr);
+  auto name = getTokenSpelling();
+  // Check if we already parsed this SSA id.
+  for (auto entry : dimsAndSymbols) {
+    if (entry.first == name) {
+      consumeToken(Token::percent_identifier);
+      return entry.second;
+    }
+  }
+  // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
+  if (parseElement(isSymbol))
+    return (emitError("failed to parse ssa identifier"), nullptr);
+  auto idExpr = isSymbol
+                    ? getAffineSymbolExpr(numSymbolOperands++, getContext())
+                    : getAffineDimExpr(numDimOperands++, getContext());
+  dimsAndSymbols.push_back({name, idExpr});
+  return idExpr;
+}
+
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+  if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+      parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+    return nullptr;
+  AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+  if (!symbolExpr)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+    return nullptr;
+  return symbolExpr;
+}
+
+/// Parse a positive integral constant appearing in an affine expression.
+///
+///   affine-expr ::= integer-literal
+AffineExpr AffineParser::parseIntegerExpr() {
+  auto val = getToken().getUInt64IntegerValue();
+  if (!val.hasValue() || (int64_t)val.getValue() < 0)
+    return (emitError("constant too large for index"), nullptr);
+
+  consumeToken(Token::integer);
+  return builder.getAffineConstantExpr((int64_t)val.getValue());
+}
+
+/// Parses an expression that can be a valid operand of an affine expression.
+/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
+/// operator, the rhs of which is being parsed. This is used to determine
+/// whether an error should be emitted for a missing right operand.
+//  Eg: for an expression without parentheses (like i + j + k + l), each
+//  of the four identifiers is an operand. For i + j*k + l, j*k is not an
+//  operand expression, it's an op expression and will be parsed via
+//  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
+//  -l are valid operands that will be parsed by this function.
+AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
+  switch (getToken().getKind()) {
+  case Token::bare_identifier:
+    return parseBareIdExpr();
+  case Token::kw_symbol:
+    return parseSymbolSSAIdExpr();
+  case Token::percent_identifier:
+    return parseSSAIdExpr(/*isSymbol=*/false);
+  case Token::integer:
+    return parseIntegerExpr();
+  case Token::l_paren:
+    return parseParentheticalExpr();
+  case Token::minus:
+    return parseNegateExpression(lhs);
+  case Token::kw_ceildiv:
+  case Token::kw_floordiv:
+  case Token::kw_mod:
+  case Token::plus:
+  case Token::star:
+    if (lhs)
+      emitError("missing right operand of binary operator");
+    else
+      emitError("missing left operand of binary operator");
+    return nullptr;
+  default:
+    if (lhs)
+      emitError("missing right operand of binary operator");
+    else
+      emitError("expected affine expression");
+    return nullptr;
+  }
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
+///
+/// All binary op's associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, div, and mod}.
+///
+/// Add, sub'are themselves at the same precedence level. Mul, floordiv,
+/// ceildiv, and mod are at the same higher precedence level. Negation has
+/// higher precedence than any binary op.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
+/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
+/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
+/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
+/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
+AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
+                                                  AffineLowPrecOp llhsOp) {
+  AffineExpr lhs;
+  if (!(lhs = parseAffineOperandExpr(llhs)))
+    return nullptr;
+
+  // Found an LHS. Deal with the ops.
+  if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
+    if (llhs) {
+      AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+      return parseAffineLowPrecOpExpr(sum, lOp);
+    }
+    // No LLHS, get RHS and form the expression.
+    return parseAffineLowPrecOpExpr(lhs, lOp);
+  }
+  auto opLoc = getToken().getLoc();
+  if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+    // We have a higher precedence op here. Get the rhs operand for the llhs
+    // through parseAffineHighPrecOpExpr.
+    AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+    if (!highRes)
+      return nullptr;
+
+    // If llhs is null, the product forms the first operand of the yet to be
+    // found expression. If non-null, the op to associate with llhs is llhsOp.
+    AffineExpr expr =
+        llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
+
+    // Recurse for subsequent low prec op's after the affine high prec op
+    // expression.
+    if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+      return parseAffineLowPrecOpExpr(expr, nextOp);
+    return expr;
+  }
+  // Last operand in the expression list.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse an affine expression.
+///  affine-expr ::= `(` affine-expr `)`
+///                | `-` affine-expr
+///                | affine-expr `+` affine-expr
+///                | affine-expr `-` affine-expr
+///                | affine-expr `*` affine-expr
+///                | affine-expr `floordiv` affine-expr
+///                | affine-expr `ceildiv` affine-expr
+///                | affine-expr `mod` affine-expr
+///                | bare-id
+///                | integer-literal
+///
+/// Additional conditions are checked depending on the production. For eg.,
+/// one of the operands for `*` has to be either constant/symbolic; the second
+/// operand for floordiv, ceildiv, and mod has to be a positive integer.
+AffineExpr AffineParser::parseAffineExpr() {
+  return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
+}
+
+/// Parse a dim or symbol from the lists appearing before the actual
+/// expressions of the affine map. Update our state to store the
+/// dimensional/symbolic identifier.
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
+  if (getToken().isNot(Token::bare_identifier))
+    return emitError("expected bare identifier");
+
+  auto name = getTokenSpelling();
+  for (auto entry : dimsAndSymbols) {
+    if (entry.first == name)
+      return emitError("redefinition of identifier '" + name + "'");
+  }
+  consumeToken(Token::bare_identifier);
+
+  dimsAndSymbols.push_back({name, idExpr});
+  return success();
+}
+
+/// Parse the list of dimensional identifiers to an affine map.
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of dimensional identifiers list")) {
+    return failure();
+  }
+
+  auto parseElt = [&]() -> ParseResult {
+    auto dimension = getAffineDimExpr(numDims++, getContext());
+    return parseIdentifierDefinition(dimension);
+  };
+  return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+}
+
+/// Parse the list of symbolic identifiers to an affine map.
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
+  consumeToken(Token::l_square);
+  auto parseElt = [&]() -> ParseResult {
+    auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
+    return parseIdentifierDefinition(symbol);
+  };
+  return parseCommaSeparatedListUntil(Token::r_square, parseElt);
+}
+
+/// Parse the list of symbolic identifiers to an affine map.
+ParseResult
+AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols) {
+  if (parseDimIdList(numDims)) {
+    return failure();
+  }
+  if (!getToken().is(Token::l_square)) {
+    numSymbols = 0;
+    return success();
+  }
+  return parseSymbolIdList(numSymbols);
+}
+
+/// Parses an ambiguous affine map or integer set definition inline.
+ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
+                                                           IntegerSet &set) {
+  unsigned numDims = 0, numSymbols = 0;
+
+  // List of dimensional and optional symbol identifiers.
+  if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) {
+    return failure();
+  }
+
+  // This is needed for parsing attributes as we wouldn't know whether we would
+  // be parsing an integer set attribute or an affine map attribute.
+  bool isArrow = getToken().is(Token::arrow);
+  bool isColon = getToken().is(Token::colon);
+  if (!isArrow && !isColon) {
+    return emitError("expected '->' or ':'");
+  } else if (isArrow) {
+    parseToken(Token::arrow, "expected '->' or '['");
+    map = parseAffineMapRange(numDims, numSymbols);
+    return map ? success() : failure();
+  } else if (parseToken(Token::colon, "expected ':' or '['")) {
+    return failure();
+  }
+
+  if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
+    return success();
+
+  return failure();
+}
+
+/// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) {
+  if (parseToken(Token::l_square, "expected '['"))
+    return failure();
+
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseAffineExpr();
+    exprs.push_back(elt);
+    return elt ? success() : failure();
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions); the list cannot be empty. Grammar:
+  // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+  if (parseCommaSeparatedListUntil(Token::r_square, parseElt,
+                                   /*allowEmptyList=*/true))
+    return failure();
+  // Parsed a valid affine map.
+  if (exprs.empty())
+    map = AffineMap();
+  else
+    map = builder.getAffineMap(numDimOperands,
+                               dimsAndSymbols.size() - numDimOperands, exprs);
+  return success();
+}
+
+/// Parse the range and sizes affine map definition inline.
+///
+///  affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///
+///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
+                                            unsigned numSymbols) {
+  parseToken(Token::l_paren, "expected '(' at start of affine map range");
+
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseAffineExpr();
+    ParseResult res = elt ? success() : failure();
+    exprs.push_back(elt);
+    return res;
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions); the list cannot be empty. Grammar:
+  // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
+    return AffineMap();
+
+  // Parsed a valid affine map.
+  return builder.getAffineMap(numDims, numSymbols, exprs);
+}
+
+/// Parse an affine constraint.
+///  affine-constraint ::= affine-expr `>=` `0`
+///                      | affine-expr `==` `0`
+///
+/// isEq is set to true if the parsed constraint is an equality, false if it
+/// is an inequality (greater than or equal).
+///
+AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExpr expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+
+  if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
+      getToken().is(Token::integer)) {
+    auto dim = getToken().getUnsignedIntegerValue();
+    if (dim.hasValue() && dim.getValue() == 0) {
+      consumeToken(Token::integer);
+      *isEq = false;
+      return expr;
+    }
+    return (emitError("expected '0' after '>='"), nullptr);
+  }
+
+  if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
+      getToken().is(Token::integer)) {
+    auto dim = getToken().getUnsignedIntegerValue();
+    if (dim.hasValue() && dim.getValue() == 0) {
+      consumeToken(Token::integer);
+      *isEq = true;
+      return expr;
+    }
+    return (emitError("expected '0' after '=='"), nullptr);
+  }
+
+  return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+          nullptr);
+}
+
+/// Parse the constraints that are part of an integer set definition.
+///  integer-set-inline
+///                ::= dim-and-symbol-id-lists `:`
+///                '(' affine-constraint-conjunction? ')'
+///  affine-constraint-conjunction ::= affine-constraint (`,`
+///                                       affine-constraint)*
+///
+IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
+                                                    unsigned numSymbols) {
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of integer set constraint list"))
+    return IntegerSet();
+
+  SmallVector<AffineExpr, 4> constraints;
+  SmallVector<bool, 4> isEqs;
+  auto parseElt = [&]() -> ParseResult {
+    bool isEq;
+    auto elt = parseAffineConstraint(&isEq);
+    ParseResult res = elt ? success() : failure();
+    if (elt) {
+      constraints.push_back(elt);
+      isEqs.push_back(isEq);
+    }
+    return res;
+  };
+
+  // Parse a list of affine constraints (comma-separated).
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+    return IntegerSet();
+
+  // If no constraints were parsed, then treat this as a degenerate 'true' case.
+  if (constraints.empty()) {
+    /* 0 == 0 */
+    auto zero = getAffineConstantExpr(0, getContext());
+    return builder.getIntegerSet(numDims, numSymbols, zero, true);
+  }
+
+  // Parsed a valid integer set.
+  return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
+}
+
+/// Parse an ambiguous reference to either and affine map or an integer set.
+ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
+                                                        IntegerSet &set) {
+  return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
+}
+
+/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
+/// parse SSA value uses encountered while parsing affine expressions.
+ParseResult Parser::parseAffineMapOfSSAIds(
+    AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) {
+  return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
+      .parseAffineMapOfSSAIds(map);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides support for parsing operations and regions of
+/// operations.
+class OperationParser : public Parser {
+public:
+  OperationParser(ParserState &state, ModuleOp moduleOp)
+      : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) {
+  }
+
+  ~OperationParser();
+
+  /// After parsing is finished, this function must be called to see if there
+  /// are any remaining issues.
+  ParseResult finalize();
+
+  //===--------------------------------------------------------------------===//
+  // SSA Value Handling
+  //===--------------------------------------------------------------------===//
+
+  /// This represents a use of an SSA value in the program.  The first two
+  /// entries in the tuple are the name and result number of a reference.  The
+  /// third is the location of the reference, which is used in case this ends
+  /// up being a use of an undefined value.
+  struct SSAUseInfo {
+    StringRef name;  // Value name, e.g. %42 or %abc
+    unsigned number; // Number, specified with #12
+    SMLoc loc;       // Location of first definition or use.
+  };
+
+  /// Push a new SSA name scope to the parser.
+  void pushSSANameScope();
+
+  /// Pop the last SSA name scope from the parser.
+  ParseResult popSSANameScope();
+
+  /// Register a definition of a value with the symbol table.
+  ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
+
+  /// Parse an optional list of SSA uses into 'results'.
+  ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
+
+  /// Parse a single SSA use into 'result'.
+  ParseResult parseSSAUse(SSAUseInfo &result);
+
+  /// Given a reference to an SSA value and its type, return a reference. This
+  /// returns null on failure.
+  Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
+
+  ParseResult parseSSADefOrUseAndType(
+      const std::function<ParseResult(SSAUseInfo, Type)> &action);
+
+  ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results);
+
+  /// Return the location of the value identified by its name and number if it
+  /// has been already defined.  Placeholder values are considered undefined.
+  llvm::Optional<SMLoc> getDefinitionLoc(StringRef name, unsigned number) {
+    if (!values.count(name) || number >= values[name].size())
+      return {};
+    Value *value = values[name][number].first;
+    if (value && !isForwardRefPlaceholder(value))
+      return values[name][number].second;
+    return {};
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Operation Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an operation instance.
+  ParseResult parseOperation();
+
+  /// Parse a single operation successor and its operand list.
+  ParseResult parseSuccessorAndUseList(Block *&dest,
+                                       SmallVectorImpl<Value *> &operands);
+
+  /// Parse a comma-separated list of operation successors in brackets.
+  ParseResult
+  parseSuccessors(SmallVectorImpl<Block *> &destinations,
+                  SmallVectorImpl<SmallVector<Value *, 4>> &operands);
+
+  /// Parse an operation instance that is in the generic form.
+  Operation *parseGenericOperation();
+
+  /// Parse an operation instance that is in the op-defined custom form.
+  Operation *parseCustomOperation();
+
+  //===--------------------------------------------------------------------===//
+  // Region Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a region into 'region' with the provided entry block arguments.
+  ParseResult parseRegion(Region &region,
+                          ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments);
+
+  /// Parse a region body into 'region'.
+  ParseResult parseRegionBody(Region &region);
+
+  //===--------------------------------------------------------------------===//
+  // Block Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a new block into 'block'.
+  ParseResult parseBlock(Block *&block);
+
+  /// Parse a list of operations into 'block'.
+  ParseResult parseBlockBody(Block *block);
+
+  /// Parse a (possibly empty) list of block arguments.
+  ParseResult
+  parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
+                            Block *owner);
+
+  /// Get the block with the specified name, creating it if it doesn't
+  /// already exist.  The location specified is the point of use, which allows
+  /// us to diagnose references to blocks that are not defined precisely.
+  Block *getBlockNamed(StringRef name, SMLoc loc);
+
+  /// Define the block with the specified name. Returns the Block* or nullptr in
+  /// the case of redefinition.
+  Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
+
+private:
+  /// Returns the info for a block at the current scope for the given name.
+  std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
+    return blocksByName.back()[name];
+  }
+
+  /// Insert a new forward reference to the given block.
+  void insertForwardRef(Block *block, SMLoc loc) {
+    forwardRef.back().try_emplace(block, loc);
+  }
+
+  /// Erase any forward reference to the given block.
+  bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); }
+
+  /// Record that a definition was added at the current scope.
+  void recordDefinition(StringRef def) {
+    definitionsPerScope.back().insert(def);
+  }
+
+  /// Create a forward reference placeholder value with the given location and
+  /// result type.
+  Value *createForwardRefPlaceholder(SMLoc loc, Type type);
+
+  /// Return true if this is a forward reference.
+  bool isForwardRefPlaceholder(Value *value) {
+    return forwardRefPlaceholders.count(value);
+  }
+
+  /// This keeps track of the block names as well as the location of the first
+  /// reference for each nested name scope. This is used to diagnose invalid
+  /// block references and memoize them.
+  SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName;
+  SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef;
+
+  /// This keeps track of all of the SSA values we are tracking for each name
+  /// scope, indexed by their name. This has one entry per result number.
+  llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
+
+  /// This keeps track of all of the values defined by a specific name scope.
+  SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
+
+  /// These are all of the placeholders we've made along with the location of
+  /// their first reference, to allow checking for use of undefined values.
+  DenseMap<Value *, SMLoc> forwardRefPlaceholders;
+
+  /// The builder used when creating parsed operation instances.
+  OpBuilder opBuilder;
+
+  /// The top level module operation.
+  ModuleOp moduleOp;
+};
+} // end anonymous namespace
+
+OperationParser::~OperationParser() {
+  for (auto &fwd : forwardRefPlaceholders) {
+    // Drop all uses of undefined forward declared reference and destroy
+    // defining operation.
+    fwd.first->dropAllUses();
+    fwd.first->getDefiningOp()->destroy();
+  }
+}
+
+/// After parsing is finished, this function must be called to see if there are
+/// any remaining issues.
+ParseResult OperationParser::finalize() {
+  // Check for any forward references that are left.  If we find any, error
+  // out.
+  if (!forwardRefPlaceholders.empty()) {
+    SmallVector<std::pair<const char *, Value *>, 4> errors;
+    // Iteration over the map isn't deterministic, so sort by source location.
+    for (auto entry : forwardRefPlaceholders)
+      errors.push_back({entry.second.getPointer(), entry.first});
+    llvm::array_pod_sort(errors.begin(), errors.end());
+
+    for (auto entry : errors) {
+      auto loc = SMLoc::getFromPointer(entry.first);
+      emitError(loc, "use of undeclared SSA value name");
+    }
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SSA Value Handling
+//===----------------------------------------------------------------------===//
+
+void OperationParser::pushSSANameScope() {
+  blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>());
+  forwardRef.push_back(DenseMap<Block *, SMLoc>());
+  definitionsPerScope.push_back({});
+}
+
+ParseResult OperationParser::popSSANameScope() {
+  auto forwardRefInCurrentScope = forwardRef.pop_back_val();
+
+  // Verify that all referenced blocks were defined.
+  if (!forwardRefInCurrentScope.empty()) {
+    SmallVector<std::pair<const char *, Block *>, 4> errors;
+    // Iteration over the map isn't deterministic, so sort by source location.
+    for (auto entry : forwardRefInCurrentScope) {
+      errors.push_back({entry.second.getPointer(), entry.first});
+      // Add this block to the top-level region to allow for automatic cleanup.
+      moduleOp.getOperation()->getRegion(0).push_back(entry.first);
+    }
+    llvm::array_pod_sort(errors.begin(), errors.end());
+
+    for (auto entry : errors) {
+      auto loc = SMLoc::getFromPointer(entry.first);
+      emitError(loc, "reference to an undefined block");
+    }
+    return failure();
+  }
+
+  // Drop any values defined in this scope from the value map.
+  for (auto &def : definitionsPerScope.pop_back_val())
+    values.erase(def.getKey());
+  blocksByName.pop_back();
+
+  return success();
+}
+
+/// Register a definition of a value with the symbol table.
+ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) {
+  auto &entries = values[useInfo.name];
+
+  // Make sure there is a slot for this value.
+  if (entries.size() <= useInfo.number)
+    entries.resize(useInfo.number + 1);
+
+  // If we already have an entry for this, check to see if it was a definition
+  // or a forward reference.
+  if (auto *existing = entries[useInfo.number].first) {
+    if (!isForwardRefPlaceholder(existing)) {
+      return emitError(useInfo.loc)
+          .append("redefinition of SSA value '", useInfo.name, "'")
+          .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+          .append("previously defined here");
+    }
+
+    // If it was a forward reference, update everything that used it to use
+    // the actual definition instead, delete the forward ref, and remove it
+    // from our set of forward references we track.
+    existing->replaceAllUsesWith(value);
+    existing->getDefiningOp()->destroy();
+    forwardRefPlaceholders.erase(existing);
+  }
+
+  /// Record this definition for the current scope.
+  entries[useInfo.number] = {value, useInfo.loc};
+  recordDefinition(useInfo.name);
+  return success();
+}
+
+/// Parse a (possibly empty) list of SSA operands.
+///
+///   ssa-use-list ::= ssa-use (`,` ssa-use)*
+///   ssa-use-list-opt ::= ssa-use-list?
+///
+ParseResult
+OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
+  if (getToken().isNot(Token::percent_identifier))
+    return success();
+  return parseCommaSeparatedList([&]() -> ParseResult {
+    SSAUseInfo result;
+    if (parseSSAUse(result))
+      return failure();
+    results.push_back(result);
+    return success();
+  });
+}
+
+/// Parse a SSA operand for an operation.
+///
+///   ssa-use ::= ssa-id
+///
+ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) {
+  result.name = getTokenSpelling();
+  result.number = 0;
+  result.loc = getToken().getLoc();
+  if (parseToken(Token::percent_identifier, "expected SSA operand"))
+    return failure();
+
+  // If we have an attribute ID, it is a result number.
+  if (getToken().is(Token::hash_identifier)) {
+    if (auto value = getToken().getHashIdentifierNumber())
+      result.number = value.getValue();
+    else
+      return emitError("invalid SSA value result number");
+    consumeToken(Token::hash_identifier);
+  }
+
+  return success();
+}
+
+/// Given an unbound reference to an SSA value and its type, return the value
+/// it specifies.  This returns null on failure.
+Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
+  auto &entries = values[useInfo.name];
+
+  // If we have already seen a value of this name, return it.
+  if (useInfo.number < entries.size() && entries[useInfo.number].first) {
+    auto *result = entries[useInfo.number].first;
+    // Check that the type matches the other uses.
+    if (result->getType() == type)
+      return result;
+
+    emitError(useInfo.loc, "use of value '")
+        .append(useInfo.name,
+                "' expects different type than prior uses: ", type, " vs ",
+                result->getType())
+        .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+        .append("prior use here");
+    return nullptr;
+  }
+
+  // Make sure we have enough slots for this.
+  if (entries.size() <= useInfo.number)
+    entries.resize(useInfo.number + 1);
+
+  // If the value has already been defined and this is an overly large result
+  // number, diagnose that.
+  if (entries[0].first && !isForwardRefPlaceholder(entries[0].first))
+    return (emitError(useInfo.loc, "reference to invalid result number"),
+            nullptr);
+
+  // Otherwise, this is a forward reference.  Create a placeholder and remember
+  // that we did so.
+  auto *result = createForwardRefPlaceholder(useInfo.loc, type);
+  entries[useInfo.number].first = result;
+  entries[useInfo.number].second = useInfo.loc;
+  return result;
+}
+
+/// Parse an SSA use with an associated type.
+///
+///   ssa-use-and-type ::= ssa-use `:` type
+ParseResult OperationParser::parseSSADefOrUseAndType(
+    const std::function<ParseResult(SSAUseInfo, Type)> &action) {
+  SSAUseInfo useInfo;
+  if (parseSSAUse(useInfo) ||
+      parseToken(Token::colon, "expected ':' and type for SSA operand"))
+    return failure();
+
+  auto type = parseType();
+  if (!type)
+    return failure();
+
+  return action(useInfo, type);
+}
+
+/// Parse a (possibly empty) list of SSA operands, followed by a colon, then
+/// followed by a type list.
+///
+///   ssa-use-and-type-list
+///     ::= ssa-use-list ':' type-list-no-parens
+///
+ParseResult OperationParser::parseOptionalSSAUseAndTypeList(
+    SmallVectorImpl<Value *> &results) {
+  SmallVector<SSAUseInfo, 4> valueIDs;
+  if (parseOptionalSSAUseList(valueIDs))
+    return failure();
+
+  // If there were no operands, then there is no colon or type lists.
+  if (valueIDs.empty())
+    return success();
+
+  SmallVector<Type, 4> types;
+  if (parseToken(Token::colon, "expected ':' in operand list") ||
+      parseTypeListNoParens(types))
+    return failure();
+
+  if (valueIDs.size() != types.size())
+    return emitError("expected ")
+           << valueIDs.size() << " types to match operand list";
+
+  results.reserve(valueIDs.size());
+  for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+    if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+      results.push_back(value);
+    else
+      return failure();
+  }
+
+  return success();
+}
+
+/// Create and remember a new placeholder for a forward reference.
+Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
+  // Forward references are always created as operations, because we just need
+  // something with a def/use chain.
+  //
+  // We create these placeholders as having an empty name, which we know
+  // cannot be created through normal user input, allowing us to distinguish
+  // them.
+  auto name = OperationName("placeholder", getContext());
+  auto *op = Operation::create(
+      getEncodedSourceLocation(loc), name, /*operands=*/{}, type,
+      /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+      /*resizableOperandList=*/false, getContext());
+  forwardRefPlaceholders[op->getResult(0)] = loc;
+  return op->getResult(0);
+}
+
+//===----------------------------------------------------------------------===//
+// Operation Parsing
+//===----------------------------------------------------------------------===//
+
+/// Parse an operation.
+///
+///  operation ::=
+///    operation-result? string '(' ssa-use-list? ')' attribute-dict?
+///    `:` function-type trailing-location?
+///  operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=`
+///
+ParseResult OperationParser::parseOperation() {
+  auto loc = getToken().getLoc();
+  SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs;
+  size_t numExpectedResults;
+  if (getToken().is(Token::percent_identifier)) {
+    // Parse the first result id.
+    resultIDs.emplace_back(getTokenSpelling(), loc);
+    consumeToken(Token::percent_identifier);
+
+    // If the next token is a ':', we parse the expected result count.
+    if (consumeIf(Token::colon)) {
+      // Check that the next token is an integer.
+      if (!getToken().is(Token::integer))
+        return emitError("expected integer number of results");
+
+      // Check that number of results is > 0.
+      auto val = getToken().getUInt64IntegerValue();
+      if (!val.hasValue() || val.getValue() < 1)
+        return emitError("expected named operation to have atleast 1 result");
+      consumeToken(Token::integer);
+      numExpectedResults = *val;
+    } else {
+      // Otherwise, this is a comma separated list of result ids.
+      if (consumeIf(Token::comma)) {
+        auto parseNextResult = [&]() -> ParseResult {
+          // Parse the next result id.
+          if (!getToken().is(Token::percent_identifier))
+            return emitError("expected valid ssa identifier");
+
+          resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc());
+          consumeToken(Token::percent_identifier);
+          return success();
+        };
+
+        if (parseCommaSeparatedList(parseNextResult))
+          return failure();
+      }
+      numExpectedResults = resultIDs.size();
+    }
+
+    if (parseToken(Token::equal, "expected '=' after SSA name"))
+      return failure();
+  }
+
+  Operation *op;
+  if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+    op = parseCustomOperation();
+  else if (getToken().is(Token::string))
+    op = parseGenericOperation();
+  else
+    return emitError("expected operation name in quotes");
+
+  // If parsing of the basic operation failed, then this whole thing fails.
+  if (!op)
+    return failure();
+
+  // If the operation had a name, register it.
+  if (!resultIDs.empty()) {
+    if (op->getNumResults() == 0)
+      return emitError(loc, "cannot name an operation with no results");
+    if (numExpectedResults != op->getNumResults())
+      return emitError(loc, "operation defines ")
+             << op->getNumResults() << " results but was provided "
+             << numExpectedResults << " to bind";
+
+    // If the number of result names matches the number of operation results, we
+    // can directly use the provided names.
+    if (resultIDs.size() == op->getNumResults()) {
+      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+        if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second},
+                          op->getResult(i)))
+          return failure();
+    } else {
+      // Otherwise, we use the same name for all results.
+      StringRef name = resultIDs.front().first;
+      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+        if (addDefinition({name, i, loc}, op->getResult(i)))
+          return failure();
+    }
+  }
+
+  // Try to parse the optional trailing location.
+  if (parseOptionalTrailingLocation(op))
+    return failure();
+
+  return success();
+}
+
+/// Parse a single operation successor and its operand list.
+///
+///   successor ::= block-id branch-use-list?
+///   branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
+///
+ParseResult
+OperationParser::parseSuccessorAndUseList(Block *&dest,
+                                          SmallVectorImpl<Value *> &operands) {
+  // Verify branch is identifier and get the matching block.
+  if (!getToken().is(Token::caret_identifier))
+    return emitError("expected block name");
+  dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
+  consumeToken();
+
+  // Handle optional arguments.
+  if (consumeIf(Token::l_paren) &&
+      (parseOptionalSSAUseAndTypeList(operands) ||
+       parseToken(Token::r_paren, "expected ')' to close argument list"))) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Parse a comma-separated list of operation successors in brackets.
+///
+///   successor-list ::= `[` successor (`,` successor )* `]`
+///
+ParseResult OperationParser::parseSuccessors(
+    SmallVectorImpl<Block *> &destinations,
+    SmallVectorImpl<SmallVector<Value *, 4>> &operands) {
+  if (parseToken(Token::l_square, "expected '['"))
+    return failure();
+
+  auto parseElt = [this, &destinations, &operands]() {
+    Block *dest;
+    SmallVector<Value *, 4> destOperands;
+    auto res = parseSuccessorAndUseList(dest, destOperands);
+    destinations.push_back(dest);
+    operands.push_back(destOperands);
+    return res;
+  };
+  return parseCommaSeparatedListUntil(Token::r_square, parseElt,
+                                      /*allowEmptyList=*/false);
+}
+
+namespace {
+// RAII-style guard for cleaning up the regions in the operation state before
+// deleting them.  Within the parser, regions may get deleted if parsing failed,
+// and other errors may be present, in praticular undominated uses.  This makes
+// sure such uses are deleted.
+struct CleanupOpStateRegions {
+  ~CleanupOpStateRegions() {
+    SmallVector<Region *, 4> regionsToClean;
+    regionsToClean.reserve(state.regions.size());
+    for (auto &region : state.regions)
+      if (region)
+        for (auto &block : *region)
+          block.dropAllDefinedValueUses();
+  }
+  OperationState &state;
+};
+} // namespace
+
+Operation *OperationParser::parseGenericOperation() {
+  // Get location information for the operation.
+  auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
+
+  auto name = getToken().getStringValue();
+  if (name.empty())
+    return (emitError("empty operation name is invalid"), nullptr);
+  if (name.find('\0') != StringRef::npos)
+    return (emitError("null character not allowed in operation name"), nullptr);
+
+  consumeToken(Token::string);
+
+  OperationState result(srcLocation, name);
+
+  // Generic operations have a resizable operation list.
+  result.setOperandListToResizable();
+
+  // Parse the operand list.
+  SmallVector<SSAUseInfo, 8> operandInfos;
+
+  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+      parseOptionalSSAUseList(operandInfos) ||
+      parseToken(Token::r_paren, "expected ')' to end operand list")) {
+    return nullptr;
+  }
+
+  // Parse the successor list but don't add successors to the result yet to
+  // avoid messing up with the argument order.
+  SmallVector<Block *, 2> successors;
+  SmallVector<SmallVector<Value *, 4>, 2> successorOperands;
+  if (getToken().is(Token::l_square)) {
+    // Check if the operation is a known terminator.
+    const AbstractOperation *abstractOp = result.name.getAbstractOperation();
+    if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator))
+      return emitError("successors in non-terminator"), nullptr;
+    if (parseSuccessors(successors, successorOperands))
+      return nullptr;
+  }
+
+  // Parse the region list.
+  CleanupOpStateRegions guard{result};
+  if (consumeIf(Token::l_paren)) {
+    do {
+      // Create temporary regions with the top level region as parent.
+      result.regions.emplace_back(new Region(moduleOp));
+      if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
+        return nullptr;
+    } while (consumeIf(Token::comma));
+    if (parseToken(Token::r_paren, "expected ')' to end region list"))
+      return nullptr;
+  }
+
+  if (getToken().is(Token::l_brace)) {
+    if (parseAttributeDict(result.attributes))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' followed by operation type"))
+    return nullptr;
+
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return nullptr;
+  auto fnType = type.dyn_cast<FunctionType>();
+  if (!fnType)
+    return (emitError(typeLoc, "expected function type"), nullptr);
+
+  result.addTypes(fnType.getResults());
+
+  // Check that we have the right number of types for the operands.
+  auto operandTypes = fnType.getInputs();
+  if (operandTypes.size() != operandInfos.size()) {
+    auto plural = "s"[operandInfos.size() == 1];
+    return (emitError(typeLoc, "expected ")
+                << operandInfos.size() << " operand type" << plural
+                << " but had " << operandTypes.size(),
+            nullptr);
+  }
+
+  // Resolve all of the operands.
+  for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+    result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+    if (!result.operands.back())
+      return nullptr;
+  }
+
+  // Add the sucessors, and their operands after the proper operands.
+  for (const auto &succ : llvm::zip(successors, successorOperands)) {
+    Block *successor = std::get<0>(succ);
+    const SmallVector<Value *, 4> &operands = std::get<1>(succ);
+    result.addSuccessor(successor, operands);
+  }
+
+  return opBuilder.createOperation(result);
+}
+
+namespace {
+class CustomOpAsmParser : public OpAsmParser {
+public:
+  CustomOpAsmParser(SMLoc nameLoc, StringRef opName, OperationParser &parser)
+      : nameLoc(nameLoc), opName(opName), parser(parser) {}
+
+  /// Parse an instance of the operation described by 'opDefinition' into the
+  /// provided operation state.
+  ParseResult parseOperation(const AbstractOperation *opDefinition,
+                             OperationState *opState) {
+    if (opDefinition->parseAssembly(this, opState))
+      return failure();
+
+    // Check that none of the operands of the current operation reference an
+    // entry block argument for any of the region.
+    for (auto *entryArg : parsedRegionEntryArgumentPlaceholders)
+      if (llvm::is_contained(opState->operands, entryArg))
+        return emitError(nameLoc, "operand use before it's defined");
+
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Utilities
+  //===--------------------------------------------------------------------===//
+
+  /// Return if any errors were emitted during parsing.
+  bool didEmitError() const { return emittedError; }
+
+  /// Emit a diagnostic at the specified location and return failure.
+  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+    emittedError = true;
+    return parser.emitError(loc, "custom op '" + opName + "' " + message);
+  }
+
+  llvm::SMLoc getCurrentLocation() override {
+    return parser.getToken().getLoc();
+  }
+
+  Builder &getBuilder() const override { return parser.builder; }
+
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a `->` token.
+  ParseResult parseArrow() override {
+    return parser.parseToken(Token::arrow, "expected '->'");
+  }
+
+  /// Parses a `->` if present.
+  ParseResult parseOptionalArrow() override {
+    return success(parser.consumeIf(Token::arrow));
+  }
+
+  /// Parse a `:` token.
+  ParseResult parseColon() override {
+    return parser.parseToken(Token::colon, "expected ':'");
+  }
+
+  /// Parse a `:` token if present.
+  ParseResult parseOptionalColon() override {
+    return success(parser.consumeIf(Token::colon));
+  }
+
+  /// Parse a `,` token.
+  ParseResult parseComma() override {
+    return parser.parseToken(Token::comma, "expected ','");
+  }
+
+  /// Parse a `,` token if present.
+  ParseResult parseOptionalComma() override {
+    return success(parser.consumeIf(Token::comma));
+  }
+
+  /// Parses a `...` if present.
+  ParseResult parseOptionalEllipsis() override {
+    return success(parser.consumeIf(Token::ellipsis));
+  }
+
+  /// Parse a `=` token.
+  ParseResult parseEqual() override {
+    return parser.parseToken(Token::equal, "expected '='");
+  }
+
+  /// Parse a keyword if present.
+  ParseResult parseOptionalKeyword(const char *keyword) override {
+    // Check that the current token is a bare identifier or keyword.
+    if (parser.getToken().isNot(Token::bare_identifier) &&
+        !parser.getToken().isKeyword())
+      return failure();
+
+    if (parser.getTokenSpelling() == keyword) {
+      parser.consumeToken();
+      return success();
+    }
+    return failure();
+  }
+
+  /// Parse a `(` token.
+  ParseResult parseLParen() override {
+    return parser.parseToken(Token::l_paren, "expected '('");
+  }
+
+  /// Parses a '(' if present.
+  ParseResult parseOptionalLParen() override {
+    return success(parser.consumeIf(Token::l_paren));
+  }
+
+  /// Parse a `)` token.
+  ParseResult parseRParen() override {
+    return parser.parseToken(Token::r_paren, "expected ')'");
+  }
+
+  /// Parses a ')' if present.
+  ParseResult parseOptionalRParen() override {
+    return success(parser.consumeIf(Token::r_paren));
+  }
+
+  /// Parse a `[` token.
+  ParseResult parseLSquare() override {
+    return parser.parseToken(Token::l_square, "expected '['");
+  }
+
+  /// Parses a '[' if present.
+  ParseResult parseOptionalLSquare() override {
+    return success(parser.consumeIf(Token::l_square));
+  }
+
+  /// Parse a `]` token.
+  ParseResult parseRSquare() override {
+    return parser.parseToken(Token::r_square, "expected ']'");
+  }
+
+  /// Parses a ']' if present.
+  ParseResult parseOptionalRSquare() override {
+    return success(parser.consumeIf(Token::r_square));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Attribute Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an arbitrary attribute of a given type and return it in result. This
+  /// also adds the attribute to the specified attribute list with the specified
+  /// name.
+  ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
+                             SmallVectorImpl<NamedAttribute> &attrs) override {
+    result = parser.parseAttribute(type);
+    if (!result)
+      return failure();
+
+    attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+    return success();
+  }
+
+  /// Parse a named dictionary into 'result' if it is present.
+  ParseResult
+  parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override {
+    if (parser.getToken().isNot(Token::l_brace))
+      return success();
+    return parser.parseAttributeDict(result);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Identifier Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an @-identifier and store it (without the '@' symbol) in a string
+  /// attribute named 'attrName'.
+  ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
+                              SmallVectorImpl<NamedAttribute> &attrs) override {
+    if (parser.getToken().isNot(Token::at_identifier))
+      return failure();
+    result = getBuilder().getStringAttr(parser.getTokenSpelling().drop_front());
+    attrs.push_back(getBuilder().getNamedAttr(attrName, result));
+    parser.consumeToken();
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Operand Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a single operand.
+  ParseResult parseOperand(OperandType &result) override {
+    OperationParser::SSAUseInfo useInfo;
+    if (parser.parseSSAUse(useInfo))
+      return failure();
+
+    result = {useInfo.loc, useInfo.name, useInfo.number};
+    return success();
+  }
+
+  /// Parse zero or more SSA comma-separated operand references with a specified
+  /// surrounding delimiter, and an optional required operand count.
+  ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
+                               int requiredOperandCount = -1,
+                               Delimiter delimiter = Delimiter::None) override {
+    return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
+                                       requiredOperandCount, delimiter);
+  }
+
+  /// Parse zero or more SSA comma-separated operand or region arguments with
+  ///  optional surrounding delimiter and required operand count.
+  ParseResult
+  parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
+                              bool isOperandList, int requiredOperandCount = -1,
+                              Delimiter delimiter = Delimiter::None) {
+    auto startLoc = parser.getToken().getLoc();
+
+    // Handle delimiters.
+    switch (delimiter) {
+    case Delimiter::None:
+      // Don't check for the absence of a delimiter if the number of operands
+      // is unknown (and hence the operand list could be empty).
+      if (requiredOperandCount == -1)
+        break;
+      // Token already matches an identifier and so can't be a delimiter.
+      if (parser.getToken().is(Token::percent_identifier))
+        break;
+      // Test against known delimiters.
+      if (parser.getToken().is(Token::l_paren) ||
+          parser.getToken().is(Token::l_square))
+        return emitError(startLoc, "unexpected delimiter");
+      return emitError(startLoc, "invalid operand");
+    case Delimiter::OptionalParen:
+      if (parser.getToken().isNot(Token::l_paren))
+        return success();
+      LLVM_FALLTHROUGH;
+    case Delimiter::Paren:
+      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+        return failure();
+      break;
+    case Delimiter::OptionalSquare:
+      if (parser.getToken().isNot(Token::l_square))
+        return success();
+      LLVM_FALLTHROUGH;
+    case Delimiter::Square:
+      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+        return failure();
+      break;
+    }
+
+    // Check for zero operands.
+    if (parser.getToken().is(Token::percent_identifier)) {
+      do {
+        OperandType operandOrArg;
+        if (isOperandList ? parseOperand(operandOrArg)
+                          : parseRegionArgument(operandOrArg))
+          return failure();
+        result.push_back(operandOrArg);
+      } while (parser.consumeIf(Token::comma));
+    }
+
+    // Handle delimiters.   If we reach here, the optional delimiters were
+    // present, so we need to parse their closing one.
+    switch (delimiter) {
+    case Delimiter::None:
+      break;
+    case Delimiter::OptionalParen:
+    case Delimiter::Paren:
+      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+        return failure();
+      break;
+    case Delimiter::OptionalSquare:
+    case Delimiter::Square:
+      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+        return failure();
+      break;
+    }
+
+    if (requiredOperandCount != -1 &&
+        result.size() != static_cast<size_t>(requiredOperandCount))
+      return emitError(startLoc, "expected ")
+             << requiredOperandCount << " operands";
+    return success();
+  }
+
+  /// Parse zero or more trailing SSA comma-separated trailing operand
+  /// references with a specified surrounding delimiter, and an optional
+  /// required operand count. A leading comma is expected before the operands.
+  ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                                       int requiredOperandCount,
+                                       Delimiter delimiter) override {
+    if (parser.getToken().is(Token::comma)) {
+      parseComma();
+      return parseOperandList(result, requiredOperandCount, delimiter);
+    }
+    if (requiredOperandCount != -1)
+      return emitError(parser.getToken().getLoc(), "expected ")
+             << requiredOperandCount << " operands";
+    return success();
+  }
+
+  /// Resolve an operand to an SSA value, emitting an error on failure.
+  ParseResult resolveOperand(const OperandType &operand, Type type,
+                             SmallVectorImpl<Value *> &result) override {
+    OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+                                               operand.location};
+    if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
+      result.push_back(value);
+      return success();
+    }
+    return failure();
+  }
+
+  /// Parse an AffineMap of SSA ids.
+  ParseResult
+  parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
+                         Attribute &mapAttr, StringRef attrName,
+                         SmallVectorImpl<NamedAttribute> &attrs) override {
+    SmallVector<OperandType, 2> dimOperands;
+    SmallVector<OperandType, 1> symOperands;
+
+    auto parseElement = [&](bool isSymbol) -> ParseResult {
+      OperandType operand;
+      if (parseOperand(operand))
+        return failure();
+      if (isSymbol)
+        symOperands.push_back(operand);
+      else
+        dimOperands.push_back(operand);
+      return success();
+    };
+
+    AffineMap map;
+    if (parser.parseAffineMapOfSSAIds(map, parseElement))
+      return failure();
+    // Add AffineMap attribute.
+    if (map) {
+      mapAttr = parser.builder.getAffineMapAttr(map);
+      attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr));
+    }
+
+    // Add dim operands before symbol operands in 'operands'.
+    operands.assign(dimOperands.begin(), dimOperands.end());
+    operands.append(symOperands.begin(), symOperands.end());
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Region Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a region that takes `arguments` of `argTypes` types.  This
+  /// effectively defines the SSA values of `arguments` and assignes their type.
+  ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
+                          ArrayRef<Type> argTypes) override {
+    assert(arguments.size() == argTypes.size() &&
+           "mismatching number of arguments and types");
+
+    SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2>
+        regionArguments;
+    for (const auto &pair : llvm::zip(arguments, argTypes)) {
+      const OperandType &operand = std::get<0>(pair);
+      Type type = std::get<1>(pair);
+      OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+                                                 operand.location};
+      regionArguments.emplace_back(operandInfo, type);
+
+      // Create a placeholder for this argument so that we can detect invalid
+      // references to region arguments.
+      Value *value = parser.resolveSSAUse(operandInfo, type);
+      if (!value)
+        return failure();
+      parsedRegionEntryArgumentPlaceholders.emplace_back(value);
+    }
+
+    return parser.parseRegion(region, regionArguments);
+  }
+
+  /// Parses a region if present.
+  ParseResult parseOptionalRegion(Region &region,
+                                  ArrayRef<OperandType> arguments,
+                                  ArrayRef<Type> argTypes) override {
+    if (parser.getToken().isNot(Token::l_brace))
+      return success();
+    return parseRegion(region, arguments, argTypes);
+  }
+
+  /// Parse a region argument.  Region arguments define new values, so this also
+  /// checks if the values with the same name has not been defined yet.  The
+  /// type of the argument will be resolved later by a call to `parseRegion`.
+  ParseResult parseRegionArgument(OperandType &argument) override {
+    // Use parseOperand to fill in the OperandType structure.
+    if (parseOperand(argument))
+      return failure();
+    if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) {
+      parser.emitError(argument.location,
+                       "redefinition of SSA value '" + argument.name + "'")
+              .attachNote(parser.getEncodedSourceLocation(*defLoc))
+          << "previously defined here";
+      return failure();
+    }
+    return success();
+  }
+
+  /// Parse a region argument if present.
+  ParseResult parseOptionalRegionArgument(OperandType &argument) override {
+    if (parser.getToken().isNot(Token::percent_identifier))
+      return success();
+    return parseRegionArgument(argument);
+  }
+
+  ParseResult
+  parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
+                          int requiredOperandCount = -1,
+                          Delimiter delimiter = Delimiter::None) override {
+    return parseOperandOrRegionArgList(result, /*isOperandList=*/false,
+                                       requiredOperandCount, delimiter);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Successor Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a single operation successor and its operand list.
+  ParseResult
+  parseSuccessorAndUseList(Block *&dest,
+                           SmallVectorImpl<Value *> &operands) override {
+    return parser.parseSuccessorAndUseList(dest, operands);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Type Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a type.
+  ParseResult parseType(Type &result) override {
+    return failure(!(result = parser.parseType()));
+  }
+
+  /// Parse an optional arrow followed by a type list.
+  ParseResult
+  parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
+    if (!parser.consumeIf(Token::arrow))
+      return success();
+    return parser.parseFunctionResultTypes(result);
+  }
+
+  /// Parse a colon followed by a type.
+  ParseResult parseColonType(Type &result) override {
+    return failure(parser.parseToken(Token::colon, "expected ':'") ||
+                   !(result = parser.parseType()));
+  }
+
+  /// Parse a colon followed by a type list, which must have at least one type.
+  ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
+    if (parser.parseToken(Token::colon, "expected ':'"))
+      return failure();
+    return parser.parseTypeListNoParens(result);
+  }
+
+  /// Parse an optional colon followed by a type list, which if present must
+  /// have at least one type.
+  ParseResult
+  parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
+    if (!parser.consumeIf(Token::colon))
+      return success();
+    return parser.parseTypeListNoParens(result);
+  }
+
+private:
+  /// A set of placeholder value definitions for parsed region arguments.
+  SmallVector<Value *, 2> parsedRegionEntryArgumentPlaceholders;
+
+  /// The source location of the operation name.
+  SMLoc nameLoc;
+
+  /// The name of the operation.
+  StringRef opName;
+
+  /// The main operation parser.
+  OperationParser &parser;
+
+  /// A flag that indicates if any errors were emitted during parsing.
+  bool emittedError = false;
+};
+} // end anonymous namespace.
+
+Operation *OperationParser::parseCustomOperation() {
+  auto opLoc = getToken().getLoc();
+  auto opName = getTokenSpelling();
+  CustomOpAsmParser opAsmParser(opLoc, opName, *this);
+
+  auto *opDefinition = AbstractOperation::lookup(opName, getContext());
+  if (!opDefinition && !opName.contains('.')) {
+    // If the operation name has no namespace prefix we treat it as a standard
+    // operation and prefix it with "std".
+    // TODO: Would it be better to just build a mapping of the registered
+    // operations in the standard dialect?
+    opDefinition =
+        AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
+  }
+
+  if (!opDefinition) {
+    opAsmParser.emitError(opLoc, "is unknown");
+    return nullptr;
+  }
+
+  consumeToken();
+
+  // If the custom op parser crashes, produce some indication to help
+  // debugging.
+  std::string opNameStr = opName.str();
+  llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
+                                   opNameStr.c_str());
+
+  // Get location information for the operation.
+  auto srcLocation = getEncodedSourceLocation(opLoc);
+
+  // Have the op implementation take a crack and parsing this.
+  OperationState opState(srcLocation, opDefinition->name);
+  CleanupOpStateRegions guard{opState};
+  if (opAsmParser.parseOperation(opDefinition, &opState))
+    return nullptr;
+
+  // If it emitted an error, we failed.
+  if (opAsmParser.didEmitError())
+    return nullptr;
+
+  // Otherwise, we succeeded.  Use the state it parsed as our op information.
+  return opBuilder.createOperation(opState);
+}
+
+//===----------------------------------------------------------------------===//
+// Region Parsing
+//===----------------------------------------------------------------------===//
+
+/// Region.
+///
+///   region ::= '{' region-body
+///
+ParseResult OperationParser::parseRegion(
+    Region &region,
+    ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments) {
+  // Parse the '{'.
+  if (parseToken(Token::l_brace, "expected '{' to begin a region"))
+    return failure();
+
+  // Check for an empty region.
+  if (entryArguments.empty() && consumeIf(Token::r_brace))
+    return success();
+  auto currentPt = opBuilder.saveInsertionPoint();
+
+  // Push a new named value scope.
+  pushSSANameScope();
+
+  // Parse the first block directly to allow for it to be unnamed.
+  Block *block = new Block();
+
+  // Add arguments to the entry block.
+  if (!entryArguments.empty()) {
+    for (auto &placeholderArgPair : entryArguments)
+      if (addDefinition(placeholderArgPair.first,
+                        block->addArgument(placeholderArgPair.second))) {
+        delete block;
+        return failure();
+      }
+
+    // If we had named arguments, then don't allow a block name.
+    if (getToken().is(Token::caret_identifier))
+      return emitError("invalid block name in region with named arguments");
+  }
+
+  if (parseBlock(block)) {
+    delete block;
+    return failure();
+  }
+
+  // Verify that no other arguments were parsed.
+  if (!entryArguments.empty() &&
+      block->getNumArguments() > entryArguments.size()) {
+    delete block;
+    return emitError("entry block arguments were already defined");
+  }
+
+  // Parse the rest of the region.
+  region.push_back(block);
+  if (parseRegionBody(region))
+    return failure();
+
+  // Pop the SSA value scope for this region.
+  if (popSSANameScope())
+    return failure();
+
+  // Reset the original insertion point.
+  opBuilder.restoreInsertionPoint(currentPt);
+  return success();
+}
+
+/// Region.
+///
+///   region-body ::= block* '}'
+///
+ParseResult OperationParser::parseRegionBody(Region &region) {
+  // Parse the list of blocks.
+  while (!consumeIf(Token::r_brace)) {
+    Block *newBlock = nullptr;
+    if (parseBlock(newBlock))
+      return failure();
+    region.push_back(newBlock);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Block Parsing
+//===----------------------------------------------------------------------===//
+
+/// Block declaration.
+///
+///   block ::= block-label? operation*
+///   block-label    ::= block-id block-arg-list? `:`
+///   block-id       ::= caret-id
+///   block-arg-list ::= `(` ssa-id-and-type-list? `)`
+///
+ParseResult OperationParser::parseBlock(Block *&block) {
+  // The first block of a region may already exist, if it does the caret
+  // identifier is optional.
+  if (block && getToken().isNot(Token::caret_identifier))
+    return parseBlockBody(block);
+
+  SMLoc nameLoc = getToken().getLoc();
+  auto name = getTokenSpelling();
+  if (parseToken(Token::caret_identifier, "expected block name"))
+    return failure();
+
+  block = defineBlockNamed(name, nameLoc, block);
+
+  // Fail if the block was already defined.
+  if (!block)
+    return emitError(nameLoc, "redefinition of block '") << name << "'";
+
+  // If an argument list is present, parse it.
+  if (consumeIf(Token::l_paren)) {
+    SmallVector<BlockArgument *, 8> bbArgs;
+    if (parseOptionalBlockArgList(bbArgs, block) ||
+        parseToken(Token::r_paren, "expected ')' to end argument list"))
+      return failure();
+  }
+
+  if (parseToken(Token::colon, "expected ':' after block name"))
+    return failure();
+
+  return parseBlockBody(block);
+}
+
+ParseResult OperationParser::parseBlockBody(Block *block) {
+  // Set the insertion point to the end of the block to parse.
+  opBuilder.setInsertionPointToEnd(block);
+
+  // Parse the list of operations that make up the body of the block.
+  while (getToken().isNot(Token::caret_identifier, Token::r_brace))
+    if (parseOperation())
+      return failure();
+
+  return success();
+}
+
+/// Get the block with the specified name, creating it if it doesn't already
+/// exist.  The location specified is the point of use, which allows
+/// us to diagnose references to blocks that are not defined precisely.
+Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) {
+  auto &blockAndLoc = getBlockInfoByName(name);
+  if (!blockAndLoc.first) {
+    blockAndLoc = {new Block(), loc};
+    insertForwardRef(blockAndLoc.first, loc);
+  }
+
+  return blockAndLoc.first;
+}
+
+/// Define the block with the specified name. Returns the Block* or nullptr in
+/// the case of redefinition.
+Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc,
+                                         Block *existing) {
+  auto &blockAndLoc = getBlockInfoByName(name);
+  if (!blockAndLoc.first) {
+    // If the caller provided a block, use it.  Otherwise create a new one.
+    if (!existing)
+      existing = new Block();
+    blockAndLoc.first = existing;
+    blockAndLoc.second = loc;
+    return blockAndLoc.first;
+  }
+
+  // Forward declarations are removed once defined, so if we are defining a
+  // existing block and it is not a forward declaration, then it is a
+  // redeclaration.
+  if (!eraseForwardRef(blockAndLoc.first))
+    return nullptr;
+  return blockAndLoc.first;
+}
+
+/// Parse a (possibly empty) list of SSA operands with types as block arguments.
+///
+///   ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
+///
+ParseResult OperationParser::parseOptionalBlockArgList(
+    SmallVectorImpl<BlockArgument *> &results, Block *owner) {
+  if (getToken().is(Token::r_brace))
+    return success();
+
+  // If the block already has arguments, then we're handling the entry block.
+  // Parse and register the names for the arguments, but do not add them.
+  bool definingExistingArgs = owner->getNumArguments() != 0;
+  unsigned nextArgument = 0;
+
+  return parseCommaSeparatedList([&]() -> ParseResult {
+    return parseSSADefOrUseAndType(
+        [&](SSAUseInfo useInfo, Type type) -> ParseResult {
+          // If this block did not have existing arguments, define a new one.
+          if (!definingExistingArgs)
+            return addDefinition(useInfo, owner->addArgument(type));
+
+          // Otherwise, ensure that this argument has already been created.
+          if (nextArgument >= owner->getNumArguments())
+            return emitError("too many arguments specified in argument list");
+
+          // Finally, make sure the existing argument has the correct type.
+          auto *arg = owner->getArgument(nextArgument++);
+          if (arg->getType() != type)
+            return emitError("argument and block argument type mismatch");
+          return addDefinition(useInfo, arg);
+        });
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// Top-level entity parsing.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This parser handles entities that are only valid at the top level of the
+/// file.
+class ModuleParser : public Parser {
+public:
+  explicit ModuleParser(ParserState &state) : Parser(state) {}
+
+  ParseResult parseModule(ModuleOp module);
+
+private:
+  /// Parse an attribute alias declaration.
+  ParseResult parseAttributeAliasDef();
+
+  /// Parse an attribute alias declaration.
+  ParseResult parseTypeAliasDef();
+};
+} // end anonymous namespace
+
+/// Parses an attribute alias declaration.
+///
+///   attribute-alias-def ::= '#' alias-name `=` attribute-value
+///
+ParseResult ModuleParser::parseAttributeAliasDef() {
+  assert(getToken().is(Token::hash_identifier));
+  StringRef aliasName = getTokenSpelling().drop_front();
+
+  // Check for redefinitions.
+  if (getState().attributeAliasDefinitions.count(aliasName) > 0)
+    return emitError("redefinition of attribute alias id '" + aliasName + "'");
+
+  // Make sure this isn't invading the dialect attribute namespace.
+  if (aliasName.contains('.'))
+    return emitError("attribute names with a '.' are reserved for "
+                     "dialect-defined names");
+
+  consumeToken(Token::hash_identifier);
+
+  // Parse the '='.
+  if (parseToken(Token::equal, "expected '=' in attribute alias definition"))
+    return failure();
+
+  // Parse the attribute value.
+  Attribute attr = parseAttribute();
+  if (!attr)
+    return failure();
+
+  getState().attributeAliasDefinitions[aliasName] = attr;
+  return success();
+}
+
+/// Parse a type alias declaration.
+///
+///   type-alias-def ::= '!' alias-name `=` 'type' type
+///
+ParseResult ModuleParser::parseTypeAliasDef() {
+  assert(getToken().is(Token::exclamation_identifier));
+  StringRef aliasName = getTokenSpelling().drop_front();
+
+  // Check for redefinitions.
+  if (getState().typeAliasDefinitions.count(aliasName) > 0)
+    return emitError("redefinition of type alias id '" + aliasName + "'");
+
+  // Make sure this isn't invading the dialect type namespace.
+  if (aliasName.contains('.'))
+    return emitError("type names with a '.' are reserved for "
+                     "dialect-defined names");
+
+  consumeToken(Token::exclamation_identifier);
+
+  // Parse the '=' and 'type'.
+  if (parseToken(Token::equal, "expected '=' in type alias definition") ||
+      parseToken(Token::kw_type, "expected 'type' in type alias definition"))
+    return failure();
+
+  // Parse the type.
+  Type aliasedType = parseType();
+  if (!aliasedType)
+    return failure();
+
+  // Register this alias with the parser state.
+  getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+  return success();
+}
+
+/// This is the top-level module parser.
+ParseResult ModuleParser::parseModule(ModuleOp module) {
+  OperationParser opParser(getState(), module);
+
+  // Module itself is a name scope.
+  opParser.pushSSANameScope();
+
+  while (1) {
+    switch (getToken().getKind()) {
+    default:
+      // Parse a top-level operation.
+      if (opParser.parseOperation())
+        return failure();
+      break;
+
+    // If we got to the end of the file, then we're done.
+    case Token::eof: {
+      if (opParser.finalize())
+        return failure();
+
+      // Handle the case where the top level module was explicitly defined.
+      auto &bodyBlocks = module.getBodyRegion().getBlocks();
+      auto &operations = bodyBlocks.front().getOperations();
+      assert(!operations.empty() && "expected a valid module terminator");
+
+      // Check that the first operation is a module, and it is the only
+      // non-terminator operation.
+      ModuleOp nested = dyn_cast<ModuleOp>(operations.front());
+      if (nested && std::next(operations.begin(), 2) == operations.end()) {
+        // Merge the data of the nested module operation into 'module'.
+        module.setLoc(nested.getLoc());
+        module.setAttrs(nested.getOperation()->getAttrList());
+        bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks());
+
+        // Erase the original module body.
+        bodyBlocks.pop_front();
+      }
+
+      return opParser.popSSANameScope();
+    }
+
+    // If we got an error token, then the lexer already emitted an error, just
+    // stop.  Someday we could introduce error recovery if there was demand
+    // for it.
+    case Token::error:
+      return failure();
+
+    // Parse an attribute alias.
+    case Token::hash_identifier:
+      if (parseAttributeAliasDef())
+        return failure();
+      break;
+
+    // Parse a type alias.
+    case Token::exclamation_identifier:
+      if (parseTypeAliasDef())
+        return failure();
+      break;
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+
+/// This parses the file specified by the indicated SourceMgr and returns an
+/// MLIR module if it was valid.  If not, it emits diagnostics and returns
+/// null.
+ModuleOp mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
+                               MLIRContext *context) {
+  auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+
+  // This is the result module we are parsing into.
+  OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
+      sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
+
+  ParserState state(sourceMgr, context);
+  if (ModuleParser(state).parseModule(*module))
+    return nullptr;
+
+  // Make sure the parse module has no other structural problems detected by
+  // the verifier.
+  if (failed(verify(*module)))
+    return nullptr;
+
+  return module.release();
+}
+
+/// This parses the file specified by the indicated filename and returns an
+/// MLIR module if it was valid.  If not, the error message is emitted through
+/// the error handler registered in the context, and a null pointer is returned.
+ModuleOp mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
+  llvm::SourceMgr sourceMgr;
+  return parseSourceFile(filename, sourceMgr, context);
+}
+
+/// This parses the file specified by the indicated filename using the provided
+/// SourceMgr and returns an MLIR module if it was valid.  If not, the error
+/// message is emitted through the error handler registered in the context, and
+/// a null pointer is returned.
+ModuleOp mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
+                               MLIRContext *context) {
+  if (sourceMgr.getNumBuffers() != 0) {
+    // TODO(b/136086478): Extend to support multiple buffers.
+    emitError(mlir::UnknownLoc::get(context),
+              "only main buffer parsed at the moment");
+    return nullptr;
+  }
+  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code error = file_or_err.getError()) {
+    emitError(mlir::UnknownLoc::get(context),
+              "could not open input file " + filename);
+    return nullptr;
+  }
+
+  // Load the MLIR module.
+  sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
+  return parseSourceFile(sourceMgr, context);
+}
+
+/// This parses the program string to a MLIR module if it was valid. If not,
+/// it emits diagnostics and returns null.
+ModuleOp mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
+  auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
+  if (!memBuffer)
+    return nullptr;
+
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  return parseSourceFile(sourceMgr, context);
+}
+
+Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
+  SourceMgr sourceMgr;
+  auto memBuffer =
+      MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>",
+                                 /*RequiresNullTerminator=*/false);
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
+  ParserState state(sourceMgr, context);
+  Parser parser(state);
+  auto start = parser.getToken().getLoc();
+  auto ty = parser.parseType();
+  if (!ty)
+    return Type();
+
+  auto end = parser.getToken().getLoc();
+  auto read = end.getPointer() - start.getPointer();
+  // Make sure that the parsing of type consumes the entire string
+  if (static_cast<size_t>(read) < typeStr.size()) {
+    parser.emitError("unexpected additional tokens: '")
+        << typeStr.substr(read) << "' after parsing type: " << ty;
+    return Type();
+  }
+  return ty;
+}
diff --git a/third_party/mlir/lib/Parser/Token.cpp b/third_party/mlir/lib/Parser/Token.cpp
new file mode 100644
index 0000000..f944d69
--- /dev/null
+++ b/third_party/mlir/lib/Parser/Token.cpp
@@ -0,0 +1,161 @@
+//===- Token.cpp - MLIR Token Implementation ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the Token class for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Token.h"
+#include "llvm/ADT/StringExtras.h"
+using namespace mlir;
+using llvm::SMLoc;
+using llvm::SMRange;
+
+SMLoc Token::getLoc() const { return SMLoc::getFromPointer(spelling.data()); }
+
+SMLoc Token::getEndLoc() const {
+  return SMLoc::getFromPointer(spelling.data() + spelling.size());
+}
+
+SMRange Token::getLocRange() const { return SMRange(getLoc(), getEndLoc()); }
+
+/// For an integer token, return its value as an unsigned.  If it doesn't fit,
+/// return None.
+Optional<unsigned> Token::getUnsignedIntegerValue() const {
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+  unsigned result = 0;
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
+    return None;
+  return result;
+}
+
+/// For an integer token, return its value as a uint64_t.  If it doesn't fit,
+/// return None.
+Optional<uint64_t> Token::getUInt64IntegerValue() const {
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+  uint64_t result = 0;
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
+    return None;
+  return result;
+}
+
+/// For a floatliteral, return its value as a double. Return None if the value
+/// underflows or overflows.
+Optional<double> Token::getFloatingPointValue() const {
+  double result = 0;
+  if (spelling.getAsDouble(result))
+    return None;
+  return result;
+}
+
+/// For an inttype token, return its bitwidth.
+Optional<unsigned> Token::getIntTypeBitwidth() const {
+  unsigned result = 0;
+  if (spelling[1] == '0' || spelling.drop_front().getAsInteger(10, result) ||
+      result == 0)
+    return None;
+  return result;
+}
+
+/// Given a 'string' token, return its value, including removing the quote
+/// characters and unescaping the contents of the string.  The lexer has already
+/// verified that this token is valid.
+std::string Token::getStringValue() const {
+  assert(getKind() == string);
+  // Start by dropping the quotes.
+  StringRef bytes = getSpelling().drop_front().drop_back();
+
+  std::string result;
+  result.reserve(bytes.size());
+  for (unsigned i = 0, e = bytes.size(); i != e;) {
+    auto c = bytes[i++];
+    if (c != '\\') {
+      result.push_back(c);
+      continue;
+    }
+
+    assert(i + 1 < e && "invalid string should be caught by lexer");
+    auto c1 = bytes[i++];
+    switch (c1) {
+    case '"':
+    case '\\':
+      result.push_back(c1);
+      continue;
+    case 'n':
+      result.push_back('\n');
+      continue;
+    case 't':
+      result.push_back('\t');
+      continue;
+    default:
+      break;
+    }
+
+    assert(i + 1 <= e && "invalid string should be caught by lexer");
+    auto c2 = bytes[i++];
+
+    assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
+    result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
+  }
+
+  return result;
+}
+
+/// Given a hash_identifier token like #123, try to parse the number out of
+/// the identifier, returning None if it is a named identifier like #x or
+/// if the integer doesn't fit.
+Optional<unsigned> Token::getHashIdentifierNumber() const {
+  assert(getKind() == hash_identifier);
+  unsigned result = 0;
+  if (spelling.drop_front().getAsInteger(10, result))
+    return None;
+  return result;
+}
+
+/// Given a punctuation or keyword token kind, return the spelling of the
+/// token as a string.  Warning: This will abort on markers, identifiers and
+/// literal tokens since they have no fixed spelling.
+StringRef Token::getTokenSpelling(Kind kind) {
+  switch (kind) {
+  default:
+    llvm_unreachable("This token kind has no fixed spelling");
+#define TOK_PUNCTUATION(NAME, SPELLING)                                        \
+  case NAME:                                                                   \
+    return SPELLING;
+#define TOK_OPERATOR(NAME, SPELLING)                                           \
+  case NAME:                                                                   \
+    return SPELLING;
+#define TOK_KEYWORD(SPELLING)                                                  \
+  case kw_##SPELLING:                                                          \
+    return #SPELLING;
+#include "TokenKinds.def"
+  }
+}
+
+/// Return true if this is one of the keyword token kinds (e.g. kw_if).
+bool Token::isKeyword() const {
+  switch (kind) {
+  default:
+    return false;
+#define TOK_KEYWORD(SPELLING)                                                  \
+  case kw_##SPELLING:                                                          \
+    return true;
+#include "TokenKinds.def"
+  }
+}
diff --git a/third_party/mlir/lib/Parser/Token.h b/third_party/mlir/lib/Parser/Token.h
new file mode 100644
index 0000000..69c3207
--- /dev/null
+++ b/third_party/mlir/lib/Parser/Token.h
@@ -0,0 +1,116 @@
+//===- Token.h - MLIR Token Interface ---------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LIB_PARSER_TOKEN_H
+#define MLIR_LIB_PARSER_TOKEN_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace mlir {
+
+/// This represents a token in the MLIR syntax.
+class Token {
+public:
+  enum Kind {
+#define TOK_MARKER(NAME) NAME,
+#define TOK_IDENTIFIER(NAME) NAME,
+#define TOK_LITERAL(NAME) NAME,
+#define TOK_PUNCTUATION(NAME, SPELLING) NAME,
+#define TOK_OPERATOR(NAME, SPELLING) NAME,
+#define TOK_KEYWORD(SPELLING) kw_##SPELLING,
+#include "TokenKinds.def"
+  };
+
+  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+  // Return the bytes that make up this token.
+  StringRef getSpelling() const { return spelling; }
+
+  // Token classification.
+  Kind getKind() const { return kind; }
+  bool is(Kind K) const { return kind == K; }
+
+  bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
+
+  /// Return true if this token is one of the specified kinds.
+  template <typename... T>
+  bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
+    if (is(k1))
+      return true;
+    return isAny(k2, k3, others...);
+  }
+
+  bool isNot(Kind k) const { return kind != k; }
+
+  /// Return true if this token isn't one of the specified kinds.
+  template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
+    return !isAny(k1, k2, others...);
+  }
+
+  /// Return true if this is one of the keyword token kinds (e.g. kw_if).
+  bool isKeyword() const;
+
+  // Helpers to decode specific sorts of tokens.
+
+  /// For an integer token, return its value as an unsigned.  If it doesn't fit,
+  /// return None.
+  Optional<unsigned> getUnsignedIntegerValue() const;
+
+  /// For an integer token, return its value as an uint64_t.  If it doesn't fit,
+  /// return None.
+  Optional<uint64_t> getUInt64IntegerValue() const;
+
+  /// For a floatliteral token, return its value as a double. Returns None in
+  /// the case of underflow or overflow.
+  Optional<double> getFloatingPointValue() const;
+
+  /// For an inttype token, return its bitwidth.
+  Optional<unsigned> getIntTypeBitwidth() const;
+
+  /// Given a hash_identifier token like #123, try to parse the number out of
+  /// the identifier, returning None if it is a named identifier like #x or
+  /// if the integer doesn't fit.
+  Optional<unsigned> getHashIdentifierNumber() const;
+
+  /// Given a 'string' token, return its value, including removing the quote
+  /// characters and unescaping the contents of the string.
+  std::string getStringValue() const;
+
+  // Location processing.
+  llvm::SMLoc getLoc() const;
+  llvm::SMLoc getEndLoc() const;
+  llvm::SMRange getLocRange() const;
+
+  /// Given a punctuation or keyword token kind, return the spelling of the
+  /// token as a string.  Warning: This will abort on markers, identifiers and
+  /// literal tokens since they have no fixed spelling.
+  static StringRef getTokenSpelling(Kind kind);
+
+private:
+  /// Discriminator that indicates the sort of token this is.
+  Kind kind;
+
+  /// A reference to the entire token contents; this is always a pointer into
+  /// a memory buffer owned by the source manager.
+  StringRef spelling;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_TOKEN_H
diff --git a/third_party/mlir/lib/Parser/TokenKinds.def b/third_party/mlir/lib/Parser/TokenKinds.def
new file mode 100644
index 0000000..32e9b12
--- /dev/null
+++ b/third_party/mlir/lib/Parser/TokenKinds.def
@@ -0,0 +1,131 @@
+//===- TokenKinds.def - MLIR Token Description ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file is intended to be #include'd multiple times to extract information
+// about tokens for various clients in the lexer.
+//
+//===----------------------------------------------------------------------===//
+
+#if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && !defined(TOK_LITERAL)&&\
+    !defined(TOK_PUNCTUATION) && !defined(TOK_OPERATOR) && !defined(TOK_KEYWORD)
+#  error Must define one of the TOK_ macros.
+#endif
+
+#ifndef TOK_MARKER
+#define TOK_MARKER(X)
+#endif
+#ifndef TOK_IDENTIFIER
+#define TOK_IDENTIFIER(NAME)
+#endif
+#ifndef TOK_LITERAL
+#define TOK_LITERAL(NAME)
+#endif
+#ifndef TOK_PUNCTUATION
+#define TOK_PUNCTUATION(NAME, SPELLING)
+#endif
+#ifndef TOK_OPERATOR
+#define TOK_OPERATOR(NAME, SPELLING)
+#endif
+#ifndef TOK_KEYWORD
+#define TOK_KEYWORD(SPELLING)
+#endif
+
+
+// Markers
+TOK_MARKER(eof)
+TOK_MARKER(error)
+
+// Identifiers.
+TOK_IDENTIFIER(bare_identifier)        // foo
+TOK_IDENTIFIER(at_identifier)          // @foo
+TOK_IDENTIFIER(hash_identifier)        // #foo
+TOK_IDENTIFIER(percent_identifier)     // %foo
+TOK_IDENTIFIER(caret_identifier)       // ^foo
+TOK_IDENTIFIER(exclamation_identifier) // !foo
+
+// Literals
+TOK_LITERAL(floatliteral)               // 2.0
+TOK_LITERAL(integer)                    // 42
+TOK_LITERAL(string)                     // "foo"
+TOK_LITERAL(inttype)                    // i421
+
+// Punctuation.
+TOK_PUNCTUATION(arrow,            "->")
+TOK_PUNCTUATION(at,               "@")
+TOK_PUNCTUATION(colon,            ":")
+TOK_PUNCTUATION(comma,            ",")
+TOK_PUNCTUATION(question,         "?")
+TOK_PUNCTUATION(l_paren,          "(")
+TOK_PUNCTUATION(r_paren,          ")")
+TOK_PUNCTUATION(l_brace,          "{")
+TOK_PUNCTUATION(r_brace,          "}")
+TOK_PUNCTUATION(l_square,         "[")
+TOK_PUNCTUATION(r_square,         "]")
+TOK_PUNCTUATION(less,             "<")
+TOK_PUNCTUATION(greater,          ">")
+TOK_PUNCTUATION(equal,            "=")
+TOK_PUNCTUATION(ellipsis,         "...")
+// TODO: More punctuation.
+
+// Operators.
+TOK_OPERATOR(plus,               "+")
+TOK_OPERATOR(minus,              "-")
+TOK_OPERATOR(star,               "*")
+// TODO: More operator tokens
+
+// Keywords.  These turn "foo" into Token::kw_foo enums.
+
+// NOTE: Please key these alphabetized to make it easier to find something in
+// this list and to cater to OCD.
+TOK_KEYWORD(attributes)
+TOK_KEYWORD(bf16)
+TOK_KEYWORD(ceildiv)
+TOK_KEYWORD(complex)
+TOK_KEYWORD(dense)
+TOK_KEYWORD(f16)
+TOK_KEYWORD(f32)
+TOK_KEYWORD(f64)
+TOK_KEYWORD(false)
+TOK_KEYWORD(floordiv)
+TOK_KEYWORD(for)
+TOK_KEYWORD(func)
+TOK_KEYWORD(index)
+TOK_KEYWORD(loc)
+TOK_KEYWORD(max)
+TOK_KEYWORD(memref)
+TOK_KEYWORD(min)
+TOK_KEYWORD(mod)
+TOK_KEYWORD(none)
+TOK_KEYWORD(opaque)
+TOK_KEYWORD(size)
+TOK_KEYWORD(sparse)
+TOK_KEYWORD(step)
+TOK_KEYWORD(symbol)
+TOK_KEYWORD(tensor)
+TOK_KEYWORD(to)
+TOK_KEYWORD(true)
+TOK_KEYWORD(tuple)
+TOK_KEYWORD(type)
+TOK_KEYWORD(unit)
+TOK_KEYWORD(vector)
+
+#undef TOK_MARKER
+#undef TOK_IDENTIFIER
+#undef TOK_LITERAL
+#undef TOK_PUNCTUATION
+#undef TOK_OPERATOR
+#undef TOK_KEYWORD
diff --git a/third_party/mlir/lib/Pass/CMakeLists.txt b/third_party/mlir/lib/Pass/CMakeLists.txt
new file mode 100644
index 0000000..05122f5
--- /dev/null
+++ b/third_party/mlir/lib/Pass/CMakeLists.txt
@@ -0,0 +1,9 @@
+file(GLOB globbed *.c *.cpp)
+add_llvm_library(MLIRPass
+  ${globbed}
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Pass
+  )
+add_dependencies(MLIRPass MLIRAnalysis MLIRIR LLVMSupport)
+target_link_libraries(MLIRPass MLIRAnalysis MLIRIR LLVMSupport)
diff --git a/third_party/mlir/lib/Pass/IRPrinting.cpp b/third_party/mlir/lib/Pass/IRPrinting.cpp
new file mode 100644
index 0000000..2de4b05
--- /dev/null
+++ b/third_party/mlir/lib/Pass/IRPrinting.cpp
@@ -0,0 +1,136 @@
+//===- IRPrinting.cpp -----------------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "PassDetail.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+class IRPrinterInstrumentation : public PassInstrumentation {
+public:
+  /// A filter function to decide if the given ir should be printed. Returns
+  /// true if the ir should be printed, false otherwise.
+  using ShouldPrintFn = std::function<bool(Pass *)>;
+
+  IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass,
+                           ShouldPrintFn &&shouldPrintAfterPass,
+                           bool printModuleScope, raw_ostream &out)
+      : shouldPrintBeforePass(shouldPrintBeforePass),
+        shouldPrintAfterPass(shouldPrintAfterPass),
+        printModuleScope(printModuleScope), out(out) {
+    assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
+           "expected atleast one valid filter function");
+  }
+
+private:
+  /// Instrumentation hooks.
+  void runBeforePass(Pass *pass, const llvm::Any &ir) override;
+  void runAfterPass(Pass *pass, const llvm::Any &ir) override;
+  void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override;
+
+  /// Filter functions for before and after pass execution.
+  ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass;
+
+  /// Flag to toggle if the printer should always print at module scope.
+  bool printModuleScope;
+
+  /// The stream to output to.
+  raw_ostream &out;
+};
+} // end anonymous namespace
+
+/// Returns true if the given pass is hidden from IR printing.
+static bool isHiddenPass(Pass *pass) {
+  return isAdaptorPass(pass) || isVerifierPass(pass);
+}
+
+static void printIR(const llvm::Any &ir, bool printModuleScope,
+                    raw_ostream &out) {
+  // Check for printing at module scope.
+  if (printModuleScope && llvm::any_isa<FuncOp>(ir)) {
+    FuncOp function = llvm::any_cast<FuncOp>(ir);
+
+    // Print the function name and a newline before the Module.
+    out << " (function: " << function.getName() << ")\n";
+    function.getParentOfType<ModuleOp>().print(out);
+    return;
+  }
+
+  // Print a newline before the IR.
+  out << "\n";
+
+  // Print the given function.
+  if (llvm::any_isa<FuncOp>(ir)) {
+    llvm::any_cast<FuncOp>(ir).print(out);
+    return;
+  }
+
+  // Print the given module.
+  assert(llvm::any_isa<ModuleOp>(ir) && "unexpected IR unit");
+  llvm::any_cast<ModuleOp>(ir).print(out);
+}
+
+/// Instrumentation hooks.
+void IRPrinterInstrumentation::runBeforePass(Pass *pass, const llvm::Any &ir) {
+  // Skip adaptor passes and passes that the user filtered out.
+  if (!shouldPrintBeforePass || isHiddenPass(pass) ||
+      !shouldPrintBeforePass(pass))
+    return;
+  out << formatv("*** IR Dump Before {0} ***", pass->getName());
+  printIR(ir, printModuleScope, out);
+  out << "\n\n";
+}
+
+void IRPrinterInstrumentation::runAfterPass(Pass *pass, const llvm::Any &ir) {
+  // Skip adaptor passes and passes that the user filtered out.
+  if (!shouldPrintAfterPass || isHiddenPass(pass) ||
+      !shouldPrintAfterPass(pass))
+    return;
+  out << formatv("*** IR Dump After {0} ***", pass->getName());
+  printIR(ir, printModuleScope, out);
+  out << "\n\n";
+}
+
+void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass,
+                                                  const llvm::Any &ir) {
+  // Skip adaptor passes and passes that the user filtered out.
+  if (!shouldPrintAfterPass || isAdaptorPass(pass) ||
+      !shouldPrintAfterPass(pass))
+    return;
+  out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
+  printIR(ir, printModuleScope, out);
+  out << "\n\n";
+}
+
+//===----------------------------------------------------------------------===//
+// PassManager
+//===----------------------------------------------------------------------===//
+
+/// Add an instrumentation to print the IR before and after pass execution.
+void PassManager::enableIRPrinting(
+    std::function<bool(Pass *)> shouldPrintBeforePass,
+    std::function<bool(Pass *)> shouldPrintAfterPass, bool printModuleScope,
+    raw_ostream &out) {
+  addInstrumentation(new IRPrinterInstrumentation(
+      std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
+      printModuleScope, out));
+}
diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp
new file mode 100644
index 0000000..3ed7b24
--- /dev/null
+++ b/third_party/mlir/lib/Pass/Pass.cpp
@@ -0,0 +1,439 @@
+//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements common pass infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "PassDetail.h"
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/Parallel.h"
+#include "llvm/Support/Threading.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+/// Out of line virtual method to ensure vtables and metadata are emitted to a
+/// single .o file.
+void Pass::anchor() {}
+
+/// Forwarding function to execute this pass.
+LogicalResult FunctionPassBase::run(FuncOp fn, FunctionAnalysisManager &fam) {
+  // Initialize the pass state.
+  passState.emplace(fn, fam);
+
+  // Instrument before the pass has run.
+  auto pi = fam.getPassInstrumentor();
+  if (pi)
+    pi->runBeforePass(this, fn);
+
+  // Invoke the virtual runOnFunction function.
+  runOnFunction();
+
+  // Invalidate any non preserved analyses.
+  fam.invalidate(passState->preservedAnalyses);
+
+  // Instrument after the pass has run.
+  bool passFailed = passState->irAndPassFailed.getInt();
+  if (pi) {
+    if (passFailed)
+      pi->runAfterPassFailed(this, fn);
+    else
+      pi->runAfterPass(this, fn);
+  }
+
+  // Return if the pass signaled a failure.
+  return failure(passFailed);
+}
+
+/// Forwarding function to execute this pass.
+LogicalResult ModulePassBase::run(ModuleOp module, ModuleAnalysisManager &mam) {
+  // Initialize the pass state.
+  passState.emplace(module, mam);
+
+  // Instrument before the pass has run.
+  auto pi = mam.getPassInstrumentor();
+  if (pi)
+    pi->runBeforePass(this, module);
+
+  // Invoke the virtual runOnModule function.
+  runOnModule();
+
+  // Invalidate any non preserved analyses.
+  mam.invalidate(passState->preservedAnalyses);
+
+  // Instrument after the pass has run.
+  bool passFailed = passState->irAndPassFailed.getInt();
+  if (pi) {
+    if (passFailed)
+      pi->runAfterPassFailed(this, module);
+    else
+      pi->runAfterPass(this, module);
+  }
+
+  // Return if the pass signaled a failure.
+  return failure(passFailed);
+}
+
+//===----------------------------------------------------------------------===//
+// PassExecutor
+//===----------------------------------------------------------------------===//
+
+FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
+    : PassExecutor(Kind::FunctionExecutor) {
+  for (auto &pass : rhs.passes)
+    addPass(pass->clone());
+}
+
+/// Run all of the passes in this manager over the current function.
+LogicalResult detail::FunctionPassExecutor::run(FuncOp function,
+                                                FunctionAnalysisManager &fam) {
+  // Run each of the held passes.
+  for (auto &pass : passes)
+    if (failed(pass->run(function, fam)))
+      return failure();
+  return success();
+}
+
+/// Run all of the passes in this manager over the current module.
+LogicalResult detail::ModulePassExecutor::run(ModuleOp module,
+                                              ModuleAnalysisManager &mam) {
+  // Run each of the held passes.
+  for (auto &pass : passes)
+    if (failed(pass->run(module, mam)))
+      return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleToFunctionPassAdaptor
+//===----------------------------------------------------------------------===//
+
+/// Utility to run the given function and analysis manager on a provided
+/// function pass executor.
+static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, FuncOp func,
+                                         FunctionAnalysisManager &fam) {
+  // Run the function pipeline over the provided function.
+  auto result = fpe.run(func, fam);
+
+  // Clear out any computed function analyses. These analyses won't be used
+  // any more in this pipeline, and this helps reduce the current working set
+  // of memory. If preserving these analyses becomes important in the future
+  // we can re-evalutate this.
+  fam.clear();
+  return result;
+}
+
+/// Run the held function pipeline over all non-external functions within the
+/// module.
+void ModuleToFunctionPassAdaptor::runOnModule() {
+  ModuleAnalysisManager &mam = getAnalysisManager();
+  for (auto func : getModule().getOps<FuncOp>()) {
+    // Skip external functions.
+    if (func.isExternal())
+      continue;
+
+    // Run the held function pipeline over the current function.
+    auto fam = mam.slice(func);
+    if (failed(runFunctionPipeline(fpe, func, fam)))
+      return signalPassFailure();
+
+    // Clear out any computed function analyses. These analyses won't be used
+    // any more in this pipeline, and this helps reduce the current working set
+    // of memory. If preserving these analyses becomes important in the future
+    // we can re-evalutate this.
+    fam.clear();
+  }
+}
+
+// Run the held function pipeline synchronously across the functions within
+// the module.
+void ModuleToFunctionPassAdaptorParallel::runOnModule() {
+  ModuleAnalysisManager &mam = getAnalysisManager();
+
+  // Create the async executors if they haven't been created, or if the main
+  // function pipeline has changed.
+  if (asyncExecutors.empty() || asyncExecutors.front().size() != fpe.size())
+    asyncExecutors = {llvm::hardware_concurrency(), fpe};
+
+  // Run a prepass over the module to collect the functions to execute a over.
+  // This ensures that an analysis manager exists for each function, as well as
+  // providing a queue of functions to execute over.
+  std::vector<std::pair<FuncOp, FunctionAnalysisManager>> funcAMPairs;
+  for (auto func : getModule().getOps<FuncOp>())
+    if (!func.isExternal())
+      funcAMPairs.emplace_back(func, mam.slice(func));
+
+  // A parallel diagnostic handler that provides deterministic diagnostic
+  // ordering.
+  ParallelDiagnosticHandler diagHandler(&getContext());
+
+  // An index for the current function/analysis manager pair.
+  std::atomic<unsigned> funcIt(0);
+
+  // An atomic failure variable for the async executors.
+  std::atomic<bool> passFailed(false);
+  llvm::parallel::for_each(
+      llvm::parallel::par, asyncExecutors.begin(),
+      std::next(asyncExecutors.begin(),
+                std::min(asyncExecutors.size(), funcAMPairs.size())),
+      [&](FunctionPassExecutor &executor) {
+        for (auto e = funcAMPairs.size(); !passFailed && funcIt < e;) {
+          // Get the next available function index.
+          unsigned nextID = funcIt++;
+          if (nextID >= e)
+            break;
+
+          // Set the function id for this thread in the diagnostic handler.
+          diagHandler.setOrderIDForThread(nextID);
+
+          // Run the executor over the current function.
+          auto &it = funcAMPairs[nextID];
+          if (failed(runFunctionPipeline(executor, it.first, it.second))) {
+            passFailed = true;
+            break;
+          }
+        }
+      });
+
+  // Signal a failure if any of the executors failed.
+  if (passFailed)
+    signalPassFailure();
+}
+
+//===----------------------------------------------------------------------===//
+// Verifier Passes
+//===----------------------------------------------------------------------===//
+
+void FunctionVerifierPass::runOnFunction() {
+  if (failed(verify(getFunction())))
+    signalPassFailure();
+  markAllAnalysesPreserved();
+}
+
+void ModuleVerifierPass::runOnModule() {
+  if (failed(verify(getModule())))
+    signalPassFailure();
+  markAllAnalysesPreserved();
+}
+
+//===----------------------------------------------------------------------===//
+// PassManager
+//===----------------------------------------------------------------------===//
+
+PassManager::PassManager(bool verifyPasses)
+    : mpe(new ModulePassExecutor()), verifyPasses(verifyPasses),
+      passTiming(false), disableThreads(false) {}
+
+PassManager::~PassManager() {}
+
+/// Run the passes within this manager on the provided module.
+LogicalResult PassManager::run(ModuleOp module) {
+  ModuleAnalysisManager mam(module, instrumentor.get());
+  return mpe->run(module, mam);
+}
+
+/// Disable support for multi-threading within the pass manager.
+void PassManager::disableMultithreading(bool disable) {
+  disableThreads = disable;
+}
+
+/// Add an opaque pass pointer to the current manager. This takes ownership
+/// over the provided pass pointer.
+void PassManager::addPass(Pass *pass) {
+  switch (pass->getKind()) {
+  case Pass::Kind::FunctionPass:
+    addPass(cast<FunctionPassBase>(pass));
+    break;
+  case Pass::Kind::ModulePass:
+    addPass(cast<ModulePassBase>(pass));
+    break;
+  }
+}
+
+/// Add a module pass to the current manager. This takes ownership over the
+/// provided pass pointer.
+void PassManager::addPass(ModulePassBase *pass) {
+  nestedExecutorStack.clear();
+  mpe->addPass(pass);
+
+  // Add a verifier run if requested.
+  if (verifyPasses)
+    mpe->addPass(new ModuleVerifierPass());
+}
+
+/// Add a function pass to the current manager. This takes ownership over the
+/// provided pass pointer. This will automatically create a function pass
+/// executor if necessary.
+void PassManager::addPass(FunctionPassBase *pass) {
+  detail::FunctionPassExecutor *fpe;
+  if (nestedExecutorStack.empty()) {
+    /// Create an executor adaptor for this pass.
+    if (disableThreads || !llvm::llvm_is_multithreaded()) {
+      // If multi-threading is disabled, then create a synchronous adaptor.
+      auto *adaptor = new ModuleToFunctionPassAdaptor();
+      addPass(adaptor);
+      fpe = &adaptor->getFunctionExecutor();
+    } else {
+      auto *adaptor = new ModuleToFunctionPassAdaptorParallel();
+      addPass(adaptor);
+      fpe = &adaptor->getFunctionExecutor();
+    }
+
+    /// Add the executor to the stack.
+    nestedExecutorStack.push_back(fpe);
+  } else {
+    fpe = cast<detail::FunctionPassExecutor>(nestedExecutorStack.back());
+  }
+  fpe->addPass(pass);
+
+  // Add a verifier run if requested.
+  if (verifyPasses)
+    fpe->addPass(new FunctionVerifierPass());
+}
+
+/// Add the provided instrumentation to the pass manager. This takes ownership
+/// over the given pointer.
+void PassManager::addInstrumentation(PassInstrumentation *pi) {
+  if (!instrumentor)
+    instrumentor.reset(new PassInstrumentor());
+
+  instrumentor->addInstrumentation(pi);
+}
+
+//===----------------------------------------------------------------------===//
+// AnalysisManager
+//===----------------------------------------------------------------------===//
+
+/// Returns a pass instrumentation object for the current function.
+PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
+  return parent->getPassInstrumentor();
+}
+
+/// Create an analysis slice for the given child function.
+FunctionAnalysisManager ModuleAnalysisManager::slice(FuncOp func) {
+  assert(func.getOperation()->getParentOp() == moduleAnalyses.getIRUnit() &&
+         "function has a different parent module");
+  auto it = functionAnalyses.find(func);
+  if (it == functionAnalyses.end()) {
+    it =
+        functionAnalyses.try_emplace(func, new AnalysisMap<FuncOp>(func)).first;
+  }
+  return {this, it->second.get()};
+}
+
+/// Invalidate any non preserved analyses.
+void ModuleAnalysisManager::invalidate(const detail::PreservedAnalyses &pa) {
+  // If all analyses were preserved, then there is nothing to do here.
+  if (pa.isAll())
+    return;
+
+  // Invalidate the module analyses directly.
+  moduleAnalyses.invalidate(pa);
+
+  // If no analyses were preserved, then just simply clear out the function
+  // analysis results.
+  if (pa.isNone()) {
+    functionAnalyses.clear();
+    return;
+  }
+
+  // Otherwise, invalidate each function analyses.
+  for (auto &analysisPair : functionAnalyses)
+    analysisPair.second->invalidate(pa);
+}
+
+//===----------------------------------------------------------------------===//
+// PassInstrumentation
+//===----------------------------------------------------------------------===//
+
+PassInstrumentation::~PassInstrumentation() {}
+
+//===----------------------------------------------------------------------===//
+// PassInstrumentor
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+struct PassInstrumentorImpl {
+  /// Mutex to keep instrumentation access thread-safe.
+  llvm::sys::SmartMutex<true> mutex;
+
+  /// Set of registered instrumentations.
+  std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
+};
+} // end namespace detail
+} // end namespace mlir
+
+PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
+PassInstrumentor::~PassInstrumentor() {}
+
+/// See PassInstrumentation::runBeforePass for details.
+void PassInstrumentor::runBeforePass(Pass *pass, const llvm::Any &ir) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  for (auto &instr : impl->instrumentations)
+    instr->runBeforePass(pass, ir);
+}
+
+/// See PassInstrumentation::runAfterPass for details.
+void PassInstrumentor::runAfterPass(Pass *pass, const llvm::Any &ir) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  for (auto &instr : llvm::reverse(impl->instrumentations))
+    instr->runAfterPass(pass, ir);
+}
+
+/// See PassInstrumentation::runAfterPassFailed for details.
+void PassInstrumentor::runAfterPassFailed(Pass *pass, const llvm::Any &ir) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  for (auto &instr : llvm::reverse(impl->instrumentations))
+    instr->runAfterPassFailed(pass, ir);
+}
+
+/// See PassInstrumentation::runBeforeAnalysis for details.
+void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
+                                         const llvm::Any &ir) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  for (auto &instr : impl->instrumentations)
+    instr->runBeforeAnalysis(name, id, ir);
+}
+
+/// See PassInstrumentation::runAfterAnalysis for details.
+void PassInstrumentor::runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
+                                        const llvm::Any &ir) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  for (auto &instr : llvm::reverse(impl->instrumentations))
+    instr->runAfterAnalysis(name, id, ir);
+}
+
+/// Add the given instrumentation to the collection. This takes ownership over
+/// the given pointer.
+void PassInstrumentor::addInstrumentation(PassInstrumentation *pi) {
+  llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
+  impl->instrumentations.emplace_back(pi);
+}
+
+constexpr AnalysisID mlir::detail::PreservedAnalyses::allAnalysesID;
diff --git a/third_party/mlir/lib/Pass/PassDetail.h b/third_party/mlir/lib/Pass/PassDetail.h
new file mode 100644
index 0000000..0b41c44
--- /dev/null
+++ b/third_party/mlir/lib/Pass/PassDetail.h
@@ -0,0 +1,170 @@
+//===- PassDetail.h - MLIR Pass details -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_PASS_PASSDETAIL_H_
+#define MLIR_PASS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// Verifier Passes
+//===----------------------------------------------------------------------===//
+
+/// Pass to verify a function and signal failure if necessary.
+class FunctionVerifierPass : public FunctionPass<FunctionVerifierPass> {
+  void runOnFunction() override;
+};
+
+/// Pass to verify a module and signal failure if necessary.
+class ModuleVerifierPass : public ModulePass<ModuleVerifierPass> {
+  void runOnModule() override;
+};
+
+//===----------------------------------------------------------------------===//
+// PassExecutor
+//===----------------------------------------------------------------------===//
+
+/// The abstract base pass executor class.
+class PassExecutor {
+public:
+  enum Kind { FunctionExecutor, ModuleExecutor };
+  explicit PassExecutor(Kind kind) : kind(kind) {}
+
+  /// Get the kind of this executor.
+  Kind getKind() const { return kind; }
+
+private:
+  /// The kind of executor this object is.
+  Kind kind;
+};
+
+/// A pass executor that contains a list of passes over a function.
+class FunctionPassExecutor : public PassExecutor {
+public:
+  FunctionPassExecutor() : PassExecutor(Kind::FunctionExecutor) {}
+  FunctionPassExecutor(FunctionPassExecutor &&) = default;
+  FunctionPassExecutor(const FunctionPassExecutor &rhs);
+
+  /// Run the executor on the given function.
+  LogicalResult run(FuncOp function, FunctionAnalysisManager &fam);
+
+  /// Add a pass to the current executor. This takes ownership over the provided
+  /// pass pointer.
+  void addPass(FunctionPassBase *pass) { passes.emplace_back(pass); }
+
+  /// Returns the number of passes held by this executor.
+  size_t size() const { return passes.size(); }
+
+  static bool classof(const PassExecutor *pe) {
+    return pe->getKind() == Kind::FunctionExecutor;
+  }
+
+private:
+  std::vector<std::unique_ptr<FunctionPassBase>> passes;
+};
+
+/// A pass executor that contains a list of passes over a module unit.
+class ModulePassExecutor : public PassExecutor {
+public:
+  ModulePassExecutor() : PassExecutor(Kind::ModuleExecutor) {}
+  ModulePassExecutor(ModulePassExecutor &&) = default;
+
+  // Don't allow copying.
+  ModulePassExecutor(const ModulePassExecutor &) = delete;
+  ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
+
+  /// Run the executor on the given module.
+  LogicalResult run(ModuleOp module, ModuleAnalysisManager &mam);
+
+  /// Add a pass to the current executor. This takes ownership over the provided
+  /// pass pointer.
+  void addPass(ModulePassBase *pass) { passes.emplace_back(pass); }
+
+  static bool classof(const PassExecutor *pe) {
+    return pe->getKind() == Kind::ModuleExecutor;
+  }
+
+private:
+  /// Set of passes to run on the given module.
+  std::vector<std::unique_ptr<ModulePassBase>> passes;
+};
+
+//===----------------------------------------------------------------------===//
+// ModuleToFunctionPassAdaptor
+//===----------------------------------------------------------------------===//
+
+/// An adaptor module pass used to run function passes over all of the
+/// non-external functions of a module synchronously on a single thread.
+class ModuleToFunctionPassAdaptor
+    : public ModulePass<ModuleToFunctionPassAdaptor> {
+public:
+  /// Run the held function pipeline over all non-external functions within the
+  /// module.
+  void runOnModule() override;
+
+  /// Returns the function pass executor for this adaptor.
+  FunctionPassExecutor &getFunctionExecutor() { return fpe; }
+
+private:
+  FunctionPassExecutor fpe;
+};
+
+/// An adaptor module pass used to run function passes over all of the
+/// non-external functions of a module asynchronously across multiple threads.
+class ModuleToFunctionPassAdaptorParallel
+    : public ModulePass<ModuleToFunctionPassAdaptorParallel> {
+public:
+  /// Run the held function pipeline over all non-external functions within the
+  /// module.
+  void runOnModule() override;
+
+  /// Returns the function pass executor for this adaptor.
+  FunctionPassExecutor &getFunctionExecutor() { return fpe; }
+
+private:
+  // The main function pass executor for this adaptor.
+  FunctionPassExecutor fpe;
+
+  // A set of executors, cloned from the main executor, that run asynchronously
+  // on different threads.
+  std::vector<FunctionPassExecutor> asyncExecutors;
+};
+
+/// Utility function to return if a pass refers to an
+/// ModuleToFunctionPassAdaptor instance.
+inline bool isModuleToFunctionAdaptorPass(Pass *pass) {
+  return isa<ModuleToFunctionPassAdaptorParallel>(pass) ||
+         isa<ModuleToFunctionPassAdaptor>(pass);
+}
+
+/// Utility function to return if a pass refers to an adaptor pass. Adaptor
+/// passes are those that internally execute a pipeline, such as the
+/// ModuleToFunctionPassAdaptor.
+inline bool isAdaptorPass(Pass *pass) {
+  return isModuleToFunctionAdaptorPass(pass);
+}
+
+/// Utility function to return if a pass refers to a verifier pass.
+inline bool isVerifierPass(Pass *pass) {
+  return isa<FunctionVerifierPass>(pass) || isa<ModuleVerifierPass>(pass);
+}
+
+} // end namespace detail
+} // end namespace mlir
+#endif // MLIR_PASS_PASSDETAIL_H_
diff --git a/third_party/mlir/lib/Pass/PassManagerOptions.cpp b/third_party/mlir/lib/Pass/PassManagerOptions.cpp
new file mode 100644
index 0000000..055e81c
--- /dev/null
+++ b/third_party/mlir/lib/Pass/PassManagerOptions.cpp
@@ -0,0 +1,170 @@
+//===- PassManagerOptions.cpp - PassManager Command Line Options ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ManagedStatic.h"
+
+using namespace mlir;
+
+namespace {
+struct PassManagerOptions {
+  typedef llvm::cl::list<const mlir::PassRegistryEntry *, bool, PassNameParser>
+      PassOptionList;
+
+  PassManagerOptions();
+
+  //===--------------------------------------------------------------------===//
+  // Multi-threading
+  //===--------------------------------------------------------------------===//
+  llvm::cl::opt<bool> disableThreads;
+
+  //===--------------------------------------------------------------------===//
+  // IR Printing
+  //===--------------------------------------------------------------------===//
+  PassOptionList printBefore;
+  PassOptionList printAfter;
+  llvm::cl::opt<bool> printBeforeAll;
+  llvm::cl::opt<bool> printAfterAll;
+  llvm::cl::opt<bool> printModuleScope;
+
+  /// Add an IR printing instrumentation if enabled by any 'print-ir' flags.
+  void addPrinterInstrumentation(PassManager &pm);
+
+  //===--------------------------------------------------------------------===//
+  // Pass Timing
+  //===--------------------------------------------------------------------===//
+  llvm::cl::opt<bool> passTiming;
+  llvm::cl::opt<PassTimingDisplayMode> passTimingDisplayMode;
+
+  /// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
+  void addTimingInstrumentation(PassManager &pm);
+};
+} // end anonymous namespace
+
+static llvm::ManagedStatic<llvm::Optional<PassManagerOptions>> options;
+
+PassManagerOptions::PassManagerOptions()
+    //===------------------------------------------------------------------===//
+    // Multi-threading
+    //===------------------------------------------------------------------===//
+    : disableThreads(
+          "disable-pass-threading",
+          llvm::cl::desc("Disable multithreading in the pass manager"),
+          llvm::cl::init(false)),
+
+      //===----------------------------------------------------------------===//
+      // IR Printing
+      //===----------------------------------------------------------------===//
+      printBefore("print-ir-before",
+                  llvm::cl::desc("Print IR before specified passes")),
+      printAfter("print-ir-after",
+                 llvm::cl::desc("Print IR after specified passes")),
+      printBeforeAll("print-ir-before-all",
+                     llvm::cl::desc("Print IR before each pass"),
+                     llvm::cl::init(false)),
+      printAfterAll("print-ir-after-all",
+                    llvm::cl::desc("Print IR after each pass"),
+                    llvm::cl::init(false)),
+      printModuleScope(
+          "print-ir-module-scope",
+          llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
+                         "always print "
+                         "a module IR"),
+          llvm::cl::init(false)),
+
+      //===----------------------------------------------------------------===//
+      // Pass Timing
+      //===----------------------------------------------------------------===//
+      passTiming("pass-timing",
+                 llvm::cl::desc("Display the execution times of each pass")),
+      passTimingDisplayMode(
+          "pass-timing-display",
+          llvm::cl::desc("Display method for pass timing data"),
+          llvm::cl::init(PassTimingDisplayMode::Pipeline),
+          llvm::cl::values(
+              clEnumValN(PassTimingDisplayMode::List, "list",
+                         "display the results in a list sorted by total time"),
+              clEnumValN(PassTimingDisplayMode::Pipeline, "pipeline",
+                         "display the results with a nested pipeline view"))) {}
+
+/// Add an IR printing instrumentation if enabled by any 'print-ir' flags.
+void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
+  std::function<bool(Pass *)> shouldPrintBeforePass, shouldPrintAfterPass;
+
+  // Handle print-before.
+  if (printBeforeAll) {
+    // If we are printing before all, then just return true for the filter.
+    shouldPrintBeforePass = [](Pass *) { return true; };
+  } else if (printBefore.getNumOccurrences() != 0) {
+    // Otherwise if there are specific passes to print before, then check to see
+    // if the pass info for the current pass is included in the list.
+    shouldPrintBeforePass = [&](Pass *pass) {
+      auto *passInfo = pass->lookupPassInfo();
+      return passInfo && llvm::is_contained(printBefore, passInfo);
+    };
+  }
+
+  // Handle print-after.
+  if (printAfterAll) {
+    // If we are printing after all, then just return true for the filter.
+    shouldPrintAfterPass = [](Pass *) { return true; };
+  } else if (printAfter.getNumOccurrences() != 0) {
+    // Otherwise if there are specific passes to print after, then check to see
+    // if the pass info for the current pass is included in the list.
+    shouldPrintAfterPass = [&](Pass *pass) {
+      auto *passInfo = pass->lookupPassInfo();
+      return passInfo && llvm::is_contained(printAfter, passInfo);
+    };
+  }
+
+  // If there are no valid printing filters, then just return.
+  if (!shouldPrintBeforePass && !shouldPrintAfterPass)
+    return;
+
+  // Otherwise, add the IR printing instrumentation.
+  pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                      printModuleScope, llvm::errs());
+}
+
+/// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
+void PassManagerOptions::addTimingInstrumentation(PassManager &pm) {
+  if (passTiming)
+    pm.enableTiming(passTimingDisplayMode);
+}
+
+void mlir::registerPassManagerCLOptions() {
+  // Reset the options instance if it hasn't been enabled yet.
+  if (!options->hasValue())
+    options->emplace();
+}
+
+void mlir::applyPassManagerCLOptions(PassManager &pm) {
+  // Disable multi-threading.
+  if ((*options)->disableThreads)
+    pm.disableMultithreading();
+
+  // Add the IR printing instrumentation.
+  (*options)->addPrinterInstrumentation(pm);
+
+  // Note: The pass timing instrumentation should be added last to avoid any
+  // potential "ghost" timing from other instrumentations being unintentionally
+  // included in the timing results.
+  (*options)->addTimingInstrumentation(pm);
+}
diff --git a/third_party/mlir/lib/Pass/PassRegistry.cpp b/third_party/mlir/lib/Pass/PassRegistry.cpp
new file mode 100644
index 0000000..0d85761
--- /dev/null
+++ b/third_party/mlir/lib/Pass/PassRegistry.cpp
@@ -0,0 +1,117 @@
+//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
+
+using namespace mlir;
+
+/// Static mapping of all of the registered passes.
+static llvm::ManagedStatic<llvm::DenseMap<const PassID *, PassInfo>>
+    passRegistry;
+
+/// Static mapping of all of the registered pass pipelines.
+static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
+    passPipelineRegistry;
+
+/// Utility to create a default registry function from a pass instance.
+static PassRegistryFunction
+buildDefaultRegistryFn(PassAllocatorFunction allocator) {
+  return [=](PassManager &pm) { pm.addPass(allocator()); };
+}
+
+//===----------------------------------------------------------------------===//
+// PassPipelineInfo
+//===----------------------------------------------------------------------===//
+
+/// Constructor that accepts a pass allocator function instead of the standard
+/// registry function. This is useful for registering specializations of
+/// existing passes.
+PassPipelineRegistration::PassPipelineRegistration(
+    StringRef arg, StringRef description, PassAllocatorFunction allocator) {
+  registerPassPipeline(arg, description, buildDefaultRegistryFn(allocator));
+}
+
+void mlir::registerPassPipeline(StringRef arg, StringRef description,
+                                const PassRegistryFunction &function) {
+  PassPipelineInfo pipelineInfo(arg, description, function);
+  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
+  assert(inserted && "Pass pipeline registered multiple times");
+  (void)inserted;
+}
+
+//===----------------------------------------------------------------------===//
+// PassInfo
+//===----------------------------------------------------------------------===//
+
+PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
+                   PassAllocatorFunction allocator)
+    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
+
+void mlir::registerPass(StringRef arg, StringRef description,
+                        const PassID *passID,
+                        const PassAllocatorFunction &function) {
+  PassInfo passInfo(arg, description, passID, function);
+  bool inserted = passRegistry->try_emplace(passID, passInfo).second;
+  assert(inserted && "Pass registered multiple times");
+  (void)inserted;
+}
+
+/// Returns the pass info for the specified pass class or null if unknown.
+const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
+  auto it = passRegistry->find(passID);
+  if (it == passRegistry->end())
+    return nullptr;
+  return &it->getSecond();
+}
+
+//===----------------------------------------------------------------------===//
+// PassNameParser
+//===----------------------------------------------------------------------===//
+
+PassNameParser::PassNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const PassRegistryEntry *>(opt) {}
+
+void PassNameParser::initialize() {
+  llvm::cl::parser<const PassRegistryEntry *>::initialize();
+
+  /// Add the pass entries.
+  for (const auto &kv : *passRegistry) {
+    addLiteralOption(kv.second.getPassArgument(), &kv.second,
+                     kv.second.getPassDescription());
+  }
+  /// Add the pass pipeline entries.
+  for (const auto &kv : *passPipelineRegistry) {
+    addLiteralOption(kv.second.getPassArgument(), &kv.second,
+                     kv.second.getPassDescription());
+  }
+}
+
+void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
+                                     size_t GlobalWidth) const {
+  PassNameParser *TP = const_cast<PassNameParser *>(this);
+  llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
+                       [](const PassNameParser::OptionInfo *VT1,
+                          const PassNameParser::OptionInfo *VT2) {
+                         return VT1->Name.compare(VT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const PassRegistryEntry *>::printOptionInfo(O, GlobalWidth);
+}
diff --git a/third_party/mlir/lib/Pass/PassTiming.cpp b/third_party/mlir/lib/Pass/PassTiming.cpp
new file mode 100644
index 0000000..b4f3756
--- /dev/null
+++ b/third_party/mlir/lib/Pass/PassTiming.cpp
@@ -0,0 +1,401 @@
+//===- PassTiming.cpp -----------------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "PassDetail.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Threading.h"
+#include <chrono>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+constexpr llvm::StringLiteral kPassTimingDescription =
+    "... Pass execution timing report ...";
+
+namespace {
+/// Simple record class to record timing information.
+struct TimeRecord {
+  TimeRecord(double wall = 0.0, double user = 0.0) : wall(wall), user(user) {}
+
+  TimeRecord &operator+=(const TimeRecord &other) {
+    wall += other.wall;
+    user += other.user;
+    return *this;
+  }
+
+  /// Print the current time record to 'os', with a breakdown showing
+  /// contributions to the give 'total' time record.
+  void print(raw_ostream &os, const TimeRecord &total) {
+    if (total.user != total.wall)
+      os << llvm::format("  %7.4f (%5.1f%%)  ", user,
+                         100.0 * user / total.user);
+    os << llvm::format("  %7.4f (%5.1f%%)  ", wall, 100.0 * wall / total.wall);
+  }
+
+  double wall, user;
+};
+
+struct Timer {
+  explicit Timer(std::string &&name) : name(std::move(name)) {}
+
+  /// Start the timer.
+  void start() { startTime = std::chrono::system_clock::now(); }
+
+  /// Stop the timer.
+  void stop() {
+    auto newTime = std::chrono::system_clock::now() - startTime;
+    wallTime += newTime;
+    userTime += newTime;
+  }
+
+  /// Get or create a child timer with the provided name and id.
+  Timer *getChildTimer(const void *id,
+                       std::function<std::string()> &&nameBuilder) {
+    auto &child = children[id];
+    if (!child)
+      child.reset(new Timer(nameBuilder()));
+    return child.get();
+  }
+
+  /// Returns the total time for this timer in seconds.
+  TimeRecord getTotalTime() {
+    // If we have a valid wall time, then we directly compute the seconds.
+    if (wallTime.count()) {
+      return TimeRecord(
+          std::chrono::duration_cast<std::chrono::duration<double>>(wallTime)
+              .count(),
+          std::chrono::duration_cast<std::chrono::duration<double>>(userTime)
+              .count());
+    }
+
+    // Otheriwse, accumulate the timing from each of the children.
+    TimeRecord totalTime;
+    for (auto &child : children)
+      totalTime += child.second->getTotalTime();
+    return totalTime;
+  }
+
+  /// A map of unique identifiers to child timers.
+  using ChildrenMap = llvm::MapVector<const void *, std::unique_ptr<Timer>>;
+
+  /// Merge the timing data from 'other' into this timer.
+  void merge(Timer &&other) {
+    if (wallTime < other.wallTime)
+      wallTime = other.wallTime;
+    userTime += other.userTime;
+    mergeChildren(std::move(other.children), /*isStructural=*/false);
+  }
+
+  /// Merge the timer chilren in 'otherChildren' with the children of this
+  /// timer. If 'isStructural' is true, the children are merged lexographically
+  /// and 'otherChildren' must have the same number of elements as the children
+  /// of this timer. Otherwise, the timer children are merged based upon the
+  /// given timer key.
+  void mergeChildren(ChildrenMap &&otherChildren, bool isStructural) {
+    // Check for an empty children list.
+    if (children.empty()) {
+      children = std::move(otherChildren);
+      return;
+    }
+
+    if (isStructural) {
+      // If this is a structural merge, the number of children must be the same.
+      assert(children.size() == otherChildren.size() &&
+             "structural merge requires the same number of children");
+      auto it = children.begin(), otherIt = otherChildren.begin();
+      for (auto e = children.end(); it != e; ++it, ++otherIt)
+        it->second->merge(std::move(*otherIt->second));
+      return;
+    }
+
+    // Otherwise, we merge based upon the child timers key.
+    for (auto &otherChild : otherChildren) {
+      auto &child = children[otherChild.first];
+      if (!child)
+        child = std::move(otherChild.second);
+      else
+        child->merge(std::move(*otherChild.second));
+    }
+  }
+
+  /// Raw timing information.
+  std::chrono::time_point<std::chrono::system_clock> startTime;
+  std::chrono::nanoseconds wallTime = std::chrono::nanoseconds(0);
+  std::chrono::nanoseconds userTime = std::chrono::nanoseconds(0);
+
+  /// A map of unique identifiers to child timers.
+  ChildrenMap children;
+
+  /// A descriptive name for this timer.
+  std::string name;
+};
+
+struct PassTiming : public PassInstrumentation {
+  PassTiming(PassTimingDisplayMode displayMode) : displayMode(displayMode) {}
+  ~PassTiming() { print(); }
+
+  /// Setup the instrumentation hooks.
+  void runBeforePass(Pass *pass, const llvm::Any &) override {
+    startPassTimer(pass);
+  }
+  void runAfterPass(Pass *pass, const llvm::Any &) override;
+  void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override {
+    runAfterPass(pass, ir);
+  }
+  void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
+                         const llvm::Any &) override {
+    startAnalysisTimer(name, id);
+  }
+  void runAfterAnalysis(llvm::StringRef, AnalysisID *,
+                        const llvm::Any &) override;
+
+  /// Print and clear the timing results.
+  void print();
+
+  /// Start a new timer for the given pass.
+  void startPassTimer(Pass *pass);
+
+  /// Start a new timer for the given analysis.
+  void startAnalysisTimer(llvm::StringRef name, AnalysisID *id);
+
+  /// Stop a pass timer.
+  void stopPassTimer(Pass *pass);
+
+  /// Stop the last active timer.
+  void stopTimer();
+
+  /// Print the timing result in list mode.
+  void printResultsAsList(raw_ostream &os, Timer *root, TimeRecord totalTime);
+
+  /// Print the timing result in pipeline mode.
+  void printResultsAsPipeline(raw_ostream &os, Timer *root,
+                              TimeRecord totalTime);
+
+  /// Returns a timer for the provided identifier and name.
+  Timer *getTimer(const void *id, std::function<std::string()> &&nameBuilder) {
+    auto tid = llvm::get_threadid();
+
+    // If there is no active timer then add to the root timer.
+    auto &activeTimers = activeThreadTimers[tid];
+    if (activeTimers.empty()) {
+      auto &rootTimer = rootTimers[tid];
+      if (!rootTimer)
+        rootTimer.reset(new Timer("root"));
+      auto *timer = rootTimer->getChildTimer(id, std::move(nameBuilder));
+      activeTimers.push_back(timer);
+      return timer;
+    }
+
+    // Otherwise, add this to the active timer.
+    auto timer = activeTimers.back()->getChildTimer(id, std::move(nameBuilder));
+    activeTimers.push_back(timer);
+    return timer;
+  }
+
+  /// The root top level timers for each thread.
+  DenseMap<uint64_t, std::unique_ptr<Timer>> rootTimers;
+
+  /// A stack of the currently active pass timers per thread.
+  DenseMap<uint64_t, SmallVector<Timer *, 4>> activeThreadTimers;
+
+  /// The display mode to use when printing the timing results.
+  PassTimingDisplayMode displayMode;
+};
+} // end anonymous namespace
+
+/// Start a new timer for the given pass.
+void PassTiming::startPassTimer(Pass *pass) {
+  Timer *timer = getTimer(pass, [pass] {
+    if (isModuleToFunctionAdaptorPass(pass))
+      return StringRef("Function Pipeline");
+    return pass->getName();
+  });
+
+  // We don't actually want to time the adaptor passes, they gather their total
+  // from their held passes.
+  if (!isAdaptorPass(pass))
+    timer->start();
+}
+
+/// Start a new timer for the given analysis.
+void PassTiming::startAnalysisTimer(llvm::StringRef name, AnalysisID *id) {
+  Timer *timer = getTimer(id, [name] { return "(A) " + name.str(); });
+  timer->start();
+}
+
+/// Stop a pass timer.
+void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) {
+  auto tid = llvm::get_threadid();
+  auto &activeTimers = activeThreadTimers[tid];
+  assert(!activeTimers.empty() && "expected active timer");
+  Timer *timer = activeTimers.pop_back_val();
+
+  // If this is an ModuleToFunctionPassAdaptorParallel, then we need to merge in
+  // the timing data for the other threads.
+  if (isa<ModuleToFunctionPassAdaptorParallel>(pass)) {
+    // The asychronous pipeline timers should exist as children of root timers
+    // for other threads.
+    for (auto &rootTimer : llvm::make_early_inc_range(rootTimers)) {
+      // Skip the current thread.
+      if (rootTimer.first == tid)
+        continue;
+      // Check that this thread has no active timers.
+      assert(activeThreadTimers[tid].empty() && "expected no active timers");
+
+      // Structurally merge this timers children into the parallel
+      // module-to-function pass timer.
+      timer->mergeChildren(std::move(rootTimer.second->children),
+                           /*isStructural=*/true);
+      rootTimers.erase(rootTimer.first);
+    }
+    return;
+  }
+
+  // Adapator passes aren't timed directly, so we don't need to stop their
+  // timers.
+  if (!isAdaptorPass(pass))
+    timer->stop();
+}
+
+/// Stop a timer.
+void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *,
+                                  const llvm::Any &) {
+  auto &activeTimers = activeThreadTimers[llvm::get_threadid()];
+  assert(!activeTimers.empty() && "expected active timer");
+  Timer *timer = activeTimers.pop_back_val();
+  timer->stop();
+}
+
+/// Utility to print the timer heading information.
+static void printTimerHeader(llvm::raw_ostream &os, TimeRecord total) {
+  os << "===" << std::string(73, '-') << "===\n";
+  // Figure out how many spaces to description name.
+  unsigned Padding = (80 - kPassTimingDescription.size()) / 2;
+  os.indent(Padding) << kPassTimingDescription << '\n';
+  os << "===" << std::string(73, '-') << "===\n";
+
+  // Print the total time followed by the section headers.
+  os << llvm::format("  Total Execution Time: %5.4f seconds\n\n", total.wall);
+  if (total.user != total.wall)
+    os << "   ---User Time---";
+  os << "   ---Wall Time---  --- Name ---\n";
+}
+
+/// Utility to print a single line entry in the timer output.
+static void printTimeEntry(raw_ostream &os, unsigned indent, StringRef name,
+                           TimeRecord time, TimeRecord totalTime) {
+  time.print(os, totalTime);
+  os.indent(indent) << name << "\n";
+}
+
+/// Print out the current timing information.
+void PassTiming::print() {
+  // Don't print anything if there is no timing data.
+  if (rootTimers.empty())
+    return;
+
+  assert(rootTimers.size() == 1 && "expected one remaining root timer");
+  auto &rootTimer = rootTimers.begin()->second;
+  auto os = llvm::CreateInfoOutputFile();
+
+  // Print the timer header.
+  TimeRecord totalTime = rootTimer->getTotalTime();
+  printTimerHeader(*os, totalTime);
+
+  // Defer to a specialized printer for each display mode.
+  switch (displayMode) {
+  case PassTimingDisplayMode::List:
+    printResultsAsList(*os, rootTimer.get(), totalTime);
+    break;
+  case PassTimingDisplayMode::Pipeline:
+    printResultsAsPipeline(*os, rootTimer.get(), totalTime);
+    break;
+  }
+  printTimeEntry(*os, 0, "Total", totalTime, totalTime);
+  os->flush();
+
+  // Reset root timers.
+  rootTimers.clear();
+  activeThreadTimers.clear();
+}
+
+/// Print the timing result in list mode.
+void PassTiming::printResultsAsList(raw_ostream &os, Timer *root,
+                                    TimeRecord totalTime) {
+  llvm::StringMap<TimeRecord> mergedTimings;
+
+  std::function<void(Timer *)> addTimer = [&](Timer *timer) {
+    // Check for timing information.
+    if (timer->wallTime.count())
+      mergedTimings[timer->name] += timer->getTotalTime();
+    for (auto &children : timer->children)
+      addTimer(children.second.get());
+  };
+
+  // Add each of the top level timers.
+  for (auto &topLevelTimer : root->children)
+    addTimer(topLevelTimer.second.get());
+
+  // Sort the timing information by wall time.
+  std::vector<std::pair<StringRef, TimeRecord>> timerNameAndTime;
+  for (auto &it : mergedTimings)
+    timerNameAndTime.emplace_back(it.first(), it.second);
+  llvm::array_pod_sort(timerNameAndTime.begin(), timerNameAndTime.end(),
+                       [](const std::pair<StringRef, TimeRecord> *lhs,
+                          const std::pair<StringRef, TimeRecord> *rhs) {
+                         return llvm::array_pod_sort_comparator<double>(
+                             &rhs->second.wall, &lhs->second.wall);
+                       });
+
+  // Print the timing information sequentially.
+  for (auto &timeData : timerNameAndTime)
+    printTimeEntry(os, 0, timeData.first, timeData.second, totalTime);
+}
+
+/// Print the timing result in pipeline mode.
+void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root,
+                                        TimeRecord totalTime) {
+  std::function<void(unsigned, Timer *)> printTimer = [&](unsigned indent,
+                                                          Timer *timer) {
+    printTimeEntry(os, indent, timer->name, timer->getTotalTime(), totalTime);
+    for (auto &children : timer->children)
+      printTimer(indent + 2, children.second.get());
+  };
+
+  // Print each of the top level timers.
+  for (auto &topLevelTimer : root->children)
+    printTimer(0, topLevelTimer.second.get());
+}
+
+//===----------------------------------------------------------------------===//
+// PassManager
+//===----------------------------------------------------------------------===//
+
+/// Add an instrumentation to time the execution of passes and the computation
+/// of analyses.
+void PassManager::enableTiming(PassTimingDisplayMode displayMode) {
+  // Check if pass timing is already enabled.
+  if (passTiming)
+    return;
+  addInstrumentation(new PassTiming(displayMode));
+  passTiming = true;
+}
diff --git a/third_party/mlir/lib/Quantizer/CMakeLists.txt b/third_party/mlir/lib/Quantizer/CMakeLists.txt
new file mode 100644
index 0000000..bc157d0
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/CMakeLists.txt
@@ -0,0 +1,44 @@
+# Support.
+add_llvm_library(MLIRQuantizerSupport
+  Support/Configuration.cpp
+  Support/ConstraintAnalysisGraph.cpp
+  Support/Metadata.cpp
+  Support/Statistics.cpp
+  Support/TypeUtils.cpp
+  Support/UniformConstraints.cpp
+  Support/UniformSolvers.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerSupport
+                 MLIRIR
+                 MLIRQuantOps
+                 MLIRSupport
+                 MLIRStandardOps)
+
+# Configurations.
+add_llvm_library(MLIRQuantizerFxpMathConfig
+  Configurations/FxpMathConfig.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerFxpMathConfig
+                 MLIRFxpMathOpsIncGen
+                 MLIRQuantizerSupport)
+
+# Transforms.
+add_llvm_library(MLIRQuantizerTransforms
+  Transforms/AddDefaultStatsTestPass.cpp
+  Transforms/InferQuantizedTypesPass.cpp
+  Transforms/RemoveInstrumentationPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
+target_link_libraries(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
diff --git a/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
new file mode 100644
index 0000000..d0eda55
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
@@ -0,0 +1,287 @@
+//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a TargetConfiguration for reference fixed-point math
+// quantization scheme based on the FxpMathOps (plus a small category of
+// extension ops that can be added from other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Statistics.h"
+#include "mlir/Quantizer/Support/UniformConstraints.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::fxpmath;
+using namespace mlir::quant;
+using namespace std::placeholders;
+
+namespace {
+
+struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
+  FxpMathTargetConfigImpl(SolverContext &context)
+      : FxpMathTargetConfig(context) {
+    Builder b(&context.getMlirContext());
+    IntegerType i8Type = b.getIntegerType(8);
+    IntegerType i16Type = b.getIntegerType(16);
+    IntegerType i32Type = b.getIntegerType(32);
+
+    q8 = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
+                              std::numeric_limits<int8_t>::min(),
+                              std::numeric_limits<int8_t>::max()),
+        CandidateQuantizedType::Scheme::UniformPerLayer);
+    q16 = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
+                              std::numeric_limits<int16_t>::min(),
+                              std::numeric_limits<int16_t>::max()),
+        CandidateQuantizedType::Scheme::UniformPerLayer);
+    q32ExplicitFixedPoint = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
+                              std::numeric_limits<int32_t>::min(),
+                              std::numeric_limits<int32_t>::max()),
+        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);
+
+    // Op handlers.
+    addOpHandler<ConstantOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
+    addOpHandler<ReturnOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
+    addOpHandler<quant::StatisticsOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
+
+    // FxpMathOps.
+    addOpHandler<RealAddEwOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
+    addOpHandler<RealMulEwOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
+    addOpHandler<RealMatMulOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
+    addOpHandler<RealMatMulBiasOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));
+
+    // Require stats ops.
+    addRequireStatsOp<RealAddEwOp>();
+    addRequireStatsOp<RealSubEwOp>();
+    addRequireStatsOp<RealDivEwOp>();
+    addRequireStatsOp<RealMulEwOp>();
+    addRequireStatsOp<RealMatMulOp>();
+    addRequireStatsOp<RealMatMulBiasOp>();
+  }
+
+  bool isHandledType(Type t) const final {
+    if (t.isa<FloatType>())
+      return true;
+    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
+           t.cast<ShapedType>().getElementType().isa<FloatType>();
+  }
+
+  void finalizeAnchors(CAGSlice &cag) const override {
+    cag.enumerateImpliedConnections(
+        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
+          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
+        });
+  }
+
+  void addValueIdentityOpByName(StringRef opName) override {
+    addOpHandlerByName(
+        opName,
+        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
+  }
+
+  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
+    assert(op->getNumResults() == 1);
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    resultNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::DirectStorage);
+
+    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
+      if (!isHandledType(op->getOperand(opIdx)->getType()))
+        continue;
+      auto operandNode = cag.getOperandAnchor(op, opIdx);
+      operandNode->setTypeTransformRule(
+          CAGAnchorNode::TypeTransformRule::DirectStorage);
+      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
+    }
+  }
+
+  void handleConstant(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    resultNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+    Attribute valueAttr;
+    if (!matchPattern(op, m_Constant(&valueAttr))) {
+      return;
+    }
+
+    AttributeTensorStatistics stats(valueAttr);
+    TensorAxisStatistics layerStats;
+    if (!stats.get(layerStats)) {
+      op->emitOpError("could not compute statistics");
+      return;
+    }
+
+    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
+  }
+
+  void handleTerminal(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getOperand(0)->getType()))
+      return;
+    auto operandNode = cag.getOperandAnchor(op, 0);
+    operandNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+  }
+
+  void handleStats(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto argNode = cag.getOperandAnchor(op, 0);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);
+
+    TensorAxisStatistics layerStats;
+    auto statsOp = cast<quant::StatisticsOp>(op);
+    auto layerStatsAttr = statsOp.layerStats();
+    layerStats.minValue =
+        layerStatsAttr.getValue({0}).cast<FloatAttr>().getValueAsDouble();
+    layerStats.maxValue =
+        layerStatsAttr.getValue({1}).cast<FloatAttr>().getValueAsDouble();
+    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
+  }
+
+  void handleAdd(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Add supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    // NOTE: We couple the add such that the scale/zeroPoint match between
+    // both args and the result. This is overly constrained in that it is
+    // possible to write efficient add kernels with a bit more freedom (i.e.
+    // zeroPoints can vary, scales can differ by a power of two, etc).
+    // However, fully coupled yields the simples solutions on the fast path.
+    // Further efficiency can be had by constraining the zeroPoint to 0, but
+    // there isn't a constraint for this yet (and there are tradeoffs).
+    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
+    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMul(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMatMul(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto bias = cag.getOperandAnchor(op, 2);
+    bias->getUniformMetadata().disabledCandidateTypes =
+        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);
+
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
+                                      CAGSlice &cag) const {
+    // TODO: It would be nice if these all extended some base trait instead
+    // of requiring name lookup.
+    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
+    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");
+
+    if (clampMinAttr || clampMaxAttr) {
+      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
+      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
+      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
+      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
+    }
+  }
+
+  unsigned q8;
+  unsigned q16;
+  unsigned q32ExplicitFixedPoint;
+};
+
+} // anonymous namespace
+
+std::unique_ptr<FxpMathTargetConfig>
+FxpMathTargetConfig::create(SolverContext &context) {
+  return llvm::make_unique<FxpMathTargetConfigImpl>(context);
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/Configuration.cpp b/third_party/mlir/lib/Quantizer/Support/Configuration.cpp
new file mode 100644
index 0000000..78a7451
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/Configuration.cpp
@@ -0,0 +1,48 @@
+//===- Configuration.cpp - Configuration object base classes --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/Configuration.h"
+
+#include <limits>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/MLIRContext.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+TargetConfiguration::TargetConfiguration(SolverContext &context) {}
+
+void TargetConfiguration::addOpHandlerByName(StringRef name, OpHandlerFn fn) {
+  opHandlers[name] = fn;
+}
+
+void TargetConfiguration::addRequireStatsOpByName(StringRef opName) {
+  requireStatsOpNames.insert(opName);
+}
+
+bool TargetConfiguration::isRequireStatsOp(Operation *op) const {
+  return requireStatsOpNames.find(op->getName().getStringRef()) !=
+         requireStatsOpNames.end();
+}
+
+void TargetConfiguration::handleOp(Operation *op, CAGSlice &cag) const {
+  auto found_it = opHandlers.find(op->getName().getStringRef());
+  if (found_it != opHandlers.end())
+    found_it->second(op, cag);
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp
new file mode 100644
index 0000000..b4d48b7
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp
@@ -0,0 +1,181 @@
+//===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+void CAGNode::replaceIncoming(CAGNode *otherNode) {
+  if (this == otherNode)
+    return;
+  for (CAGNode *parentNode : incoming) {
+    for (CAGNode *&it : parentNode->outgoing) {
+      if (it == this) {
+        it = otherNode;
+        otherNode->incoming.push_back(parentNode);
+      }
+    }
+  }
+  incoming.clear();
+}
+
+void CAGNode::addOutgoing(CAGNode *toNode) {
+  if (!llvm::is_contained(outgoing, toNode)) {
+    outgoing.push_back(toNode);
+    toNode->incoming.push_back(this);
+  }
+}
+
+CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
+    : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx)->getType()),
+      op(op), operandIdx(operandIdx) {}
+
+CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
+    : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx)->getType()),
+      resultValue(op->getResult(resultIdx)) {}
+
+CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
+CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }
+
+CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
+                                             unsigned operandIdx) {
+  assert(operandIdx < op->getNumOperands() && "illegal operand index");
+
+  // Dedup.
+  auto key = std::make_pair(op, operandIdx);
+  auto foundIt = operandAnchors.find(key);
+  if (foundIt != operandAnchors.end()) {
+    return foundIt->second;
+  }
+
+  // Create.
+  auto anchor = llvm::make_unique<CAGOperandAnchor>(op, operandIdx);
+  auto *unowned = anchor.release();
+  unowned->nodeId = allNodes.size();
+  allNodes.push_back(unowned);
+  operandAnchors.insert(std::make_pair(key, unowned));
+  return unowned;
+}
+
+CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
+  assert(resultIdx < op->getNumResults() && "illegal result index");
+
+  // Dedup.
+  auto key = std::make_pair(op, resultIdx);
+  auto foundIt = resultAnchors.find(key);
+  if (foundIt != resultAnchors.end()) {
+    return foundIt->second;
+  }
+
+  // Create.
+  auto anchor = llvm::make_unique<CAGResultAnchor>(op, resultIdx);
+  auto *unowned = anchor.release();
+  unowned->nodeId = allNodes.size();
+  allNodes.push_back(unowned);
+  resultAnchors.insert(std::make_pair(key, unowned));
+  return unowned;
+}
+
+void CAGSlice::enumerateImpliedConnections(
+    std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
+  // Discover peer identity pairs (i.e. implied edges from Result->Operand and
+  // Arg->Call). Use an intermediate vector so that the callback can modify.
+  std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
+  for (auto &resultAnchorPair : resultAnchors) {
+    CAGResultAnchor *resultAnchor = resultAnchorPair.second;
+    Value *resultValue = resultAnchor->getValue();
+    for (auto &use : resultValue->getUses()) {
+      Operation *operandOp = use.getOwner();
+      unsigned operandIdx = use.getOperandNumber();
+      auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
+      if (foundIt != operandAnchors.end()) {
+        impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
+      }
+    }
+  }
+
+  // Callback for each pair.
+  for (auto &impliedPair : impliedPairs) {
+    callback(impliedPair.first, impliedPair.second);
+  }
+}
+
+unsigned CAGSlice::propagate(const TargetConfiguration &config) {
+  std::vector<CAGNode *> dirtyNodes;
+  dirtyNodes.reserve(allNodes.size());
+  // Note that because iteration happens in nodeId order, there is no need
+  // to sort in order to make deterministic. If the selection method changes,
+  // a sort should be explicitly done.
+  for (CAGNode *child : *this) {
+    if (child->isDirty()) {
+      dirtyNodes.push_back(child);
+    }
+  }
+
+  if (dirtyNodes.empty()) {
+    return 0;
+  }
+  for (auto dirtyNode : dirtyNodes) {
+    dirtyNode->clearDirty();
+    dirtyNode->propagate(context, config);
+  }
+
+  return dirtyNodes.size();
+}
+
+void CAGAnchorNode::propagate(SolverContext &solverContext,
+                              const TargetConfiguration &config) {
+  for (CAGNode *child : *this) {
+    child->markDirty();
+  }
+}
+
+Type CAGAnchorNode::getTransformedType() {
+  if (!getUniformMetadata().selectedType) {
+    return nullptr;
+  }
+  return getUniformMetadata().selectedType.castFromExpressedType(
+      getOriginalType());
+}
+
+void CAGNode::printLabel(llvm::raw_ostream &os) const {
+  os << "Node<" << static_cast<const void *>(this) << ">";
+}
+
+void CAGAnchorNode::printLabel(llvm::raw_ostream &os) const {
+  getUniformMetadata().printSummary(os);
+}
+
+void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const {
+  os << "Operand<";
+  op->getName().print(os);
+  os << "," << operandIdx;
+  os << ">";
+  CAGAnchorNode::printLabel(os);
+}
+
+void CAGResultAnchor::printLabel(llvm::raw_ostream &os) const {
+  os << "Result<";
+  getOp()->getName().print(os);
+  os << ">";
+  CAGAnchorNode::printLabel(os);
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/Metadata.cpp b/third_party/mlir/lib/Quantizer/Support/Metadata.cpp
new file mode 100644
index 0000000..3661f52
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/Metadata.cpp
@@ -0,0 +1,42 @@
+//===- Metadata.cpp - Top level types and metadata ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/Metadata.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+void CAGUniformMetadata::printSummary(llvm::raw_ostream &os) const {
+  if (requiredRange.hasValue()) {
+    os << "\n[" << requiredRange.getValue().first << ","
+       << requiredRange.getValue().second << "]";
+  }
+
+  if (disabledCandidateTypes.any()) {
+    os << "\n![";
+    mlir::interleaveComma(disabledCandidateTypes.set_bits(), os);
+    os << "]";
+  }
+
+  if (selectedType) {
+    os << "\n" << selectedType;
+  }
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/Statistics.cpp b/third_party/mlir/lib/Quantizer/Support/Statistics.cpp
new file mode 100644
index 0000000..058d31f
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/Statistics.cpp
@@ -0,0 +1,111 @@
+//===- Statistics.cpp - Collects statistics over tensors ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/Statistics.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+//===----------------------------------------------------------------------===//
+// AttributeTensorStatistics implementation
+//===----------------------------------------------------------------------===//
+
+static void
+collectElementsStatisticsDim(ElementsAttr attr, unsigned numElements,
+                             ArrayRef<int64_t> shape,
+                             llvm::SmallVector<uint64_t, 4> &indices,
+                             uint64_t dim, TensorAxisStatistics &statistics) {
+  // Recursive terminating condition.
+  if (dim >= shape.size())
+    return;
+
+  if (dim < (shape.size() - 1)) {
+    // Recurse past dim.
+    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
+      indices[dim] = i;
+      collectElementsStatisticsDim(attr, numElements, shape, indices, dim + 1,
+                                   statistics);
+    }
+    return;
+  }
+
+  // Collection dim.
+  for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
+    indices[dim] = i;
+    double value = attr.getValue(llvm::makeArrayRef(indices))
+                       .cast<FloatAttr>()
+                       .getValueAsDouble();
+    statistics.minValue = std::min(statistics.minValue, value);
+    statistics.maxValue = std::max(statistics.maxValue, value);
+    statistics.mean += value / numElements;
+    // TODO: Calculate a running variance.
+  }
+}
+
+static bool getElementsStatistics(ElementsAttr attr,
+                                  TensorAxisStatistics &statistics) {
+  statistics.clear();
+  statistics.minValue = std::numeric_limits<double>::infinity();
+  statistics.maxValue = -std::numeric_limits<double>::infinity();
+
+  ShapedType sType = attr.getType();
+  if (!sType.hasStaticShape())
+    return false;
+  Type elementTy = sType.getElementType();
+  if (!elementTy.isa<FloatType>())
+    return false;
+
+  llvm::SmallVector<uint64_t, 4> indices;
+  indices.resize(sType.getRank());
+  ArrayRef<int64_t> shape = sType.getShape();
+
+  auto numElements = sType.getNumElements();
+  collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
+                               statistics);
+  statistics.sampleSize = numElements;
+
+  return true;
+}
+
+bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
+  if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
+    double value = floatAttr.getValueAsDouble();
+    stats = TensorAxisStatistics(1, value, value, value, 0);
+    return true;
+  } else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
+    return getElementsStatistics(eltAttr, stats);
+  }
+  return false;
+}
+
+namespace mlir {
+namespace quantizer {
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TensorAxisStatistics &stats) {
+  os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue
+     << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean
+     << ", variance=" << stats.variance << "]";
+  return os;
+}
+
+} // end namespace quantizer
+} // end namespace mlir
diff --git a/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp b/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp
new file mode 100644
index 0000000..fab4e56
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp
@@ -0,0 +1,31 @@
+//===- TypeUtils.cpp - Helper function for manipulating types -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/TypeUtils.h"
+
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+Type mlir::quantizer::getElementOrPrimitiveType(Type t) {
+  if (auto sType = t.dyn_cast<ShapedType>()) {
+    return sType.getElementType();
+  } else {
+    return t;
+  }
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp
new file mode 100644
index 0000000..c43ecdf
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp
@@ -0,0 +1,267 @@
+//===- UniformConstraints.cpp - Constraints for uniform quant -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/UniformConstraints.h"
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Rules.h"
+#include "mlir/Quantizer/Support/TypeUtils.h"
+#include "mlir/Quantizer/Support/UniformSolvers.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+struct ClusteredFacts {
+  ExpandingMinMaxFact requiredRange;
+  DiscreteScaleZeroPointFact explicitScaleZeroPoint;
+};
+
+} // end anonymous namespace
+
+static QuantizedType solveUniformType(SolverContext &solverContext,
+                                      const ClusteredFacts &clusteredFacts,
+                                      const CandidateQuantizedType &ct,
+                                      Type originalElementType, Location loc) {
+  switch (ct.scheme) {
+  default:
+    emitError(loc, "unsupported scheme for uniform type conversion");
+    return nullptr;
+
+  case CandidateQuantizedType::Scheme::UniformPerLayer: {
+    if (!clusteredFacts.requiredRange.hasValue()) {
+      // TODO: Issue some kind of diagnostic. This is not an error.
+      return nullptr;
+    }
+
+    uint64_t numLevels = ct.quantizedType.getStorageTypeMax() -
+                         ct.quantizedType.getStorageTypeMin();
+    UniformStorageParams params{numLevels,
+                                ct.quantizedType.getStorageTypeMin()};
+    UniformParamsFromMinMaxSolver solver(
+        params, clusteredFacts.requiredRange.getValue().first,
+        clusteredFacts.requiredRange.getValue().second);
+    if (!solver.compute()) {
+      emitWarning(loc) << "unable to solve uniform type with "
+                       << "UniformParamsFromMinMaxSolver";
+      return nullptr;
+    }
+
+    return UniformQuantizedType::getChecked(
+        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
+        originalElementType, solver.getScale(), solver.getZp(),
+        ct.quantizedType.getStorageTypeMin(),
+        ct.quantizedType.getStorageTypeMax(), loc);
+  }
+  case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: {
+    if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) {
+      emitRemark(loc)
+          << "unable to solve uniform type with UniformExplicitFixedPointScale "
+          << "(no explicitScaleZeroPoint)";
+      return nullptr;
+    }
+
+    const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue();
+    assert(scaleZp.value && "optional value not set on fact");
+
+    if (scaleZp.conflict) {
+      emitWarning(loc)
+          << "conflicting explicit scale/zeroPoint on node cluster: "
+          << "an arbitrary scale/zeroPoint will be used";
+    }
+
+    return UniformQuantizedType::getChecked(
+        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
+        originalElementType,
+        scaleZp.value->first, // scale
+        0, // zeroPoint (fixed point solutions only for this scheme)
+        ct.quantizedType.getStorageTypeMin(),
+        ct.quantizedType.getStorageTypeMax(), loc);
+
+    return nullptr;
+  }
+  }
+}
+
+namespace {
+
+class PropagateExplicitScale : public CAGConstraintNode {
+public:
+  PropagateExplicitScale()
+      : CAGConstraintNode(Kind::UniformPropagateExplicitScale) {}
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Constraint ||
+           n->getKind() == Kind::UniformPropagateExplicitScale;
+  }
+
+private:
+  void printLabel(llvm::raw_ostream &os) const override {
+    os << "PropagateExplicitScale";
+  }
+  void propagate(SolverContext &solverContext,
+                 const TargetConfiguration &config) override {
+    DiscreteScaleZeroPointFact scaleZp;
+
+    // Get scale/zp from all parents.
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto selectedType = parentAnchor->getUniformMetadata().selectedType;
+      if (auto uqType = selectedType.dyn_cast_or_null<UniformQuantizedType>()) {
+        scaleZp.assertValue(
+            CAGUniformMetadata::SalienceRequired,
+            std::make_pair(uqType.getScale(), static_cast<int64_t>(0)));
+      }
+    }
+
+    // Propagate to children.
+    if (scaleZp.hasValue()) {
+      for (auto it = begin(), e = end(); it != e; ++it) {
+        auto childAnchor = llvm::cast<CAGAnchorNode>(*it);
+        if (modified(childAnchor->getUniformMetadata()
+                         .explicitScaleZeroPoint.mergeFrom(scaleZp))) {
+          childAnchor->markDirty();
+        }
+      }
+    }
+  }
+};
+
+/// A constraint node which will solve uniform quantization for all parents
+/// of the constraint, assuming that they are coupled.
+class SolveUniformConstraintNode : public CAGConstraintNode {
+public:
+  SolveUniformConstraintNode()
+      : CAGConstraintNode(Kind::SolveUniformConstraint) {
+    markDirty();
+  }
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Constraint ||
+           n->getKind() == Kind::SolveUniformConstraint;
+  }
+
+private:
+  void printLabel(llvm::raw_ostream &os) const override {
+    os << "SolveUniform";
+  }
+
+  void propagate(SolverContext &solverContext,
+                 const TargetConfiguration &config) override {
+    // First determine the required min/max range and type constraints.
+    Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext());
+    llvm::SmallBitVector enabledCandidateTypesMask(
+        config.getAllCandidateTypesMask());
+    ClusteredFacts clusteredFacts;
+    Type originalElementType;
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto metadata = parentAnchor->getUniformMetadata();
+      // TODO: Possibly use a location that fuses all involved parents.
+      fusedLoc = parentAnchor->getOp()->getLoc();
+
+      // Shared element type.
+      auto parentOriginalElementType =
+          getElementOrPrimitiveType(parentAnchor->getOriginalType());
+      if (!originalElementType) {
+        originalElementType = parentOriginalElementType;
+      } else {
+        if (originalElementType != parentOriginalElementType) {
+          parentAnchor->getOp()->emitError()
+              << "cannot compute uniform type: parent element types mismatch";
+          return;
+        }
+      }
+      // Range.
+      clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange);
+
+      // Explicit scale and zero point.
+      clusteredFacts.explicitScaleZeroPoint.mergeFrom(
+          metadata.explicitScaleZeroPoint);
+
+      // Shared candidate types.
+      enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes);
+    }
+
+    // Find the first enabled candidate type.
+    const CandidateQuantizedType *bestCandidateType = nullptr;
+    for (auto &ct : config.getCandidateTypes()) {
+      if (enabledCandidateTypesMask.test(ct.ordinal)) {
+        bestCandidateType = &ct;
+        break;
+      }
+    }
+
+    if (!bestCandidateType || !originalElementType) {
+      emitRemark(fusedLoc)
+          << "not solving uniform type (no viable candidate type)";
+      return;
+    }
+
+    // Solve for the type.
+    QuantizedType selectedType =
+        solveUniformType(solverContext, clusteredFacts, *bestCandidateType,
+                         originalElementType, fusedLoc);
+
+    // Apply it to all parents.
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto &metadata = parentAnchor->getUniformMetadata();
+      if (metadata.selectedType != selectedType) {
+        metadata.selectedType = selectedType;
+        // And mark all children of the parent dirty (except us).
+        for (auto child : *parentAnchor) {
+          if (child != this) {
+            child->markDirty();
+          }
+        }
+      }
+    }
+  }
+};
+
+} // end anonymous namespace
+
+void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a,
+                                              CAGAnchorNode *b) {
+  slice.addClusteredConstraint<SolveUniformConstraintNode>({a, b});
+}
+
+void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a,
+                                           TensorAxisStatistics stats) {
+  a->getUniformMetadata().requiredRange.assertValue(
+      CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue});
+}
+
+void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue,
+                                      APFloat maxValue) {
+  a->getUniformMetadata().requiredRange.assertValue(
+      CAGUniformMetadata::SalienceDefault,
+      {minValue.convertToDouble(), maxValue.convertToDouble()});
+}
+
+void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from,
+                                                       CAGAnchorNode *to) {
+  slice.addUnidirectionalConstraint<PropagateExplicitScale>(from, {to});
+}
diff --git a/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp
new file mode 100644
index 0000000..b4c14ca
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp
@@ -0,0 +1,158 @@
+//===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/UniformSolvers.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <cmath>
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+bool UniformParamsFromMinMaxSolver::compute() {
+  // Compute adjMin, adjMax, clamping to ensure that they straddle zero.
+  if (boundingMin > 0 && boundingMax >= boundingMin) {
+    // Lop-sided to the positive.
+    adjMin = 0;
+    adjMax = boundingMax;
+  } else if (boundingMax < 0 && boundingMin <= boundingMax) {
+    // Lop-sided to the negative.
+    adjMin = boundingMin;
+    adjMax = 0;
+  } else if (boundingMin <= 0 && boundingMax >= 0) {
+    adjMin = boundingMin;
+    adjMax = boundingMax;
+  } else {
+    // Illegal bounds.
+    return satisfied = false;
+  }
+
+  const double origMinAdj = adjMin;
+  const double origMaxAdj = adjMax;
+  const double numLevelsDouble = storageParams.numLevels;
+
+  struct fns {
+    static std::pair<double, double>
+    computeMinMax(double boundingMin, double numLevels, double delta) {
+      double adjMin = delta * std::floor(boundingMin / delta);
+      return std::make_pair(adjMin, adjMin + numLevels * delta);
+    }
+    static double overshoot(double boundingMin, double boundingMax,
+                            double numLevels, double delta) {
+      auto adjMinMax = computeMinMax(boundingMin, numLevels, delta);
+      double maxOvershoot = adjMinMax.second - boundingMax;
+      double minOvershoot = boundingMin - adjMinMax.first;
+      // If undershooting on the min or max end, return that because it is
+      // to be unconditionally avoided. Otherwise return the end with the
+      // greateast magnitude of overshoot.
+      if (maxOvershoot < 0)
+        return maxOvershoot;
+      if (minOvershoot < 0)
+        return minOvershoot;
+      return std::max(maxOvershoot, minOvershoot);
+    }
+  };
+
+  // Bisect to find a suitable delta, starting with bounds of deltaInit
+  // and deltaMax.
+  double deltaInit = (adjMax - adjMin) / numLevelsDouble;
+  double deltaMax =
+      ((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble;
+  double deltaMid;
+  double prevDeltaMid = 0.0;
+  for (stepCount = 0; stepCount < 60; ++stepCount) {
+    deltaMid = (deltaInit + deltaMax) / 2.0;
+    auto fInit =
+        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit);
+    auto fMid =
+        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid);
+    if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) {
+      // Solution found (or step size is infinitessimal and an overshoot).
+      // Empirically, this seems to terminate around 30-50 steps or so.
+      // This will find a zero point for exactly representable ranges and
+      // will terminate on a small step size for inexact, biasing towards
+      // overshooting.
+      delta = deltaMid;
+      break;
+    }
+    bool signMid = fMid > 0;
+    bool signInit = fInit > 0;
+    if (signMid == signInit) {
+      deltaInit = deltaMid;
+    } else {
+      deltaMax = deltaMid;
+    }
+    prevDeltaMid = deltaMid;
+  }
+  delta = deltaMid;
+
+  // Recalculate adjMin/adjMax based on new delta.
+  auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta);
+  adjMin = adjMinMax.first;
+  adjMax = adjMinMax.second;
+
+  satisfied = false;
+  zp = 0;
+
+  if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) {
+    satisfied = true;
+    // Finally, scale and zeroPoint. Since it casts to integer, only valid
+    // if the inputs are valid.
+    zp = std::round(storageParams.minValue - adjMin / delta);
+  }
+
+  return satisfied;
+}
+
+int64_t UniformParamsFromMinMaxSolver::quantize(double x) const {
+  int64_t xq = std::round(x / delta + zp);
+  return std::max<int64_t>(0, std::min<int64_t>(storageParams.numLevels, xq));
+}
+
+double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const {
+  return (xq - zp) * delta;
+}
+
+namespace mlir {
+namespace quantizer {
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const UniformStorageParams &p) {
+  os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}";
+  return os;
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const UniformParamsFromMinMaxSolver &s) {
+  os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){";
+  os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> ";
+  if (!s.isSatisfied()) {
+    os << "unsat}";
+    return os;
+  }
+
+  os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")";
+  os << ", scale = " << s.getScale();
+  os << ", zp = " << s.getZp();
+  os << "}";
+
+  return os;
+}
+
+} // end namespace quantizer
+} // end namespace mlir
diff --git a/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
new file mode 100644
index 0000000..3f26bf0
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
@@ -0,0 +1,128 @@
+//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a testing pass to add default statistics nodes to every
+// quantization eligible op. Useful for unit testing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
+public:
+  AddDefaultStatsPass() = default;
+  AddDefaultStatsPass(SolverContext &solverContext,
+                      const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+
+  void runOnFunction() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+void AddDefaultStatsPass::runOnFunction() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getFunction().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
+                                        const TargetConfiguration &config) {
+  auto func = getFunction();
+
+  // Insert stats for each argument.
+  for (auto *arg : func.getArguments()) {
+    if (!config.isHandledType(arg->getType()))
+      continue;
+    OpBuilder b(func.getBody());
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp =
+        b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+    arg->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use to 'arg' so make sure to reset it after replacing
+    // all of the uses of 'arg'.
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
+  }
+
+  // Walk the ops and insert stats.
+  func.walk([&](Operation *op) {
+    if (!config.isRequireStatsOp(op)) {
+      return;
+    }
+    assert(op->getNumResults() == 1);
+
+    auto originalResult = op->getResult(0);
+    if (!config.isHandledType(originalResult->getType()))
+      return;
+
+    OpBuilder b(op->getBlock(), ++op->getIterator());
+
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
+                                          layerStats, nullptr);
+    originalResult->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use to 'op' so make sure to reset it after replacing
+    // all of the uses of 'op'.
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
+  });
+}
+
+FunctionPassBase *mlir::quantizer::createAddDefaultStatsPass() {
+  return new AddDefaultStatsPass();
+}
+
+static PassRegistration<AddDefaultStatsPass> pass(
+    "quantizer-add-default-stats-test",
+    "Adds default (dummy) statistics to all ops that can benefit from "
+    "runtime statistics. This is meant to help in early stage bootstrapping.");
diff --git a/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
new file mode 100644
index 0000000..765a36e
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
@@ -0,0 +1,296 @@
+//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the primary pass for instantiating a CAG, running it to
+// convergence on a module to determine eligible quantized type transforms, and
+// applying those transforms to the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace llvm {
+
+template <>
+struct DOTGraphTraits<const CAGSlice *>
+    : public DOTGraphTraits<const CAGNode *> {
+  DOTGraphTraits(bool isSimple = false)
+      : DOTGraphTraits<const CAGNode *>(isSimple) {}
+
+  std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
+    std::string s;
+    llvm::raw_string_ostream out(s);
+    node->printLabel(out);
+    return out.str();
+  }
+
+  static std::string getGraphProperties(const CAGSlice *) {
+    return "rankdir=LR;";
+  }
+
+  static bool isNodeHidden(const CAGNode *node) {
+    // Filter constraint nodes with no incoming or outgoing connections.
+    // These orphans are often created as part of graph merging operations.
+    return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
+  }
+
+  std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
+    switch (node->getKind()) {
+    default:
+      return std::string();
+    case CAGNode::Kind::OperandAnchor:
+      return "shape=record,color=yellow,style=filled";
+    case CAGNode::Kind::ResultAnchor:
+      return "shape=record,color=lightblue,style=filled";
+    case CAGNode::Kind::Constraint:
+      return "shape=record,style=dotted";
+    }
+  }
+};
+
+} // end namespace llvm
+
+namespace {
+
+class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
+public:
+  InferQuantizedTypesPass() = default;
+  InferQuantizedTypesPass(SolverContext &solverContext,
+                          const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+  void runOnModule() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+  void transformOperandType(CAGOperandAnchor *anchor, Type newType);
+  void transformResultType(CAGResultAnchor *anchor, Type newType);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Maximum number of propagation rounds to run to converge the CAG before
+/// signalling an error.
+static const int kMaximumPropagationRounds = 1000;
+
+static LogicalResult validateTypeConversion(Type newType, Type origType,
+                                            Operation *op) {
+  if (!newType) {
+    return op->emitOpError() << "unsupported type conversion from " << newType;
+  }
+  return success();
+}
+
+void InferQuantizedTypesPass::runOnModule() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getModule().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
+                                            const TargetConfiguration &config) {
+  CAGSlice cag(solverContext);
+  for (auto f : getModule().getOps<FuncOp>()) {
+    f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
+  }
+  config.finalizeAnchors(cag);
+
+  // Propagate.
+  int propRound;
+  for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
+    auto propCount = cag.propagate(config);
+    if (propCount == 0)
+      break;
+  }
+  if (propRound == 0) {
+    emitError(UnknownLoc::get(&getContext()),
+              "exceeded maximum number of solver iterations (infinite loop?)");
+    return;
+  }
+
+  // TODO: Only dump the GraphViz if a flag is set and move to a utility.
+  // GraphViz.
+  if (!solverContext.getDebugCAGDotPath().empty()) {
+    auto actFileName =
+        llvm::WriteGraph(const_cast<const CAGSlice *>(&cag), "CAG",
+                         /*ShortNames=*/false,
+                         /*Title=*/"CAG",
+                         /*Filename=*/solverContext.getDebugCAGDotPath());
+    llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
+  }
+
+  // Start transforming the types in order of anchor type (results, then
+  // operands).
+  // Apply result types.
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGResultAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformResultType(anchorNode, newType);
+  }
+
+  // Apply operand types.
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGOperandAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformOperandType(anchorNode, newType);
+  }
+}
+
+void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
+                                                   Type newType) {
+  Value *inputValue = anchor->getValue();
+  Operation *op = anchor->getOp();
+  OpBuilder b(op->getBlock(), Block::iterator(op));
+
+  SmallVector<Value *, 1> removeValuesIfDead;
+
+  // Because we've already run the result transforms at this phase, it is
+  // very likely that inputValue points to a dcast op whose input matches
+  // our type. We detect that situation and route around just to save some
+  // bulk in the IR.
+  Value *newTypedInputValue = inputValue;
+  auto inputDcastOp =
+      dyn_cast_or_null<DequantizeCastOp>(inputValue->getDefiningOp());
+  if (inputDcastOp && inputDcastOp.arg()->getType() == newType) {
+    // Can just use the dcast's input value.
+    newTypedInputValue = inputDcastOp.arg();
+    removeValuesIfDead.push_back(inputDcastOp);
+  } else {
+    // Need to synthesize a qcast.
+    newTypedInputValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
+  }
+
+  switch (anchor->getTypeTransformRule()) {
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
+                                   newTypedInputValue));
+    break;
+  }
+
+  for (Value *removeValueIfDead : removeValuesIfDead) {
+    if (removeValueIfDead->use_empty()) {
+      removeValueIfDead->getDefiningOp()->erase();
+    }
+  }
+}
+
+void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
+                                                  Type newType) {
+  Value *origResultValue = anchor->getValue();
+  Operation *op = origResultValue->getDefiningOp();
+  OpBuilder b(op->getBlock(), ++Block::iterator(op));
+
+  Value *replacedResultValue = nullptr;
+  Value *newResultValue = nullptr;
+  switch (anchor->getTypeTransformRule()) {
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    origResultValue->setType(newType);
+    replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), origResultValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    origResultValue->setType(storageType);
+    replacedResultValue =
+        b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
+    replacedResultValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  if (replacedResultValue) {
+    // Transform:
+    //   origResultValue -->  replaceResultValue -> newResultValue
+    //                   \->  [original uses]
+    // To:
+    //   origResultValue -> replaceResultValue ->
+    //                      newResultValue -> [original uses]
+    // Note that replaceResultValue may equal newResultValue or there may
+    // be operands between the two.
+    origResultValue->replaceAllUsesWith(newResultValue);
+    replacedResultValue->getDefiningOp()->replaceUsesOfWith(newResultValue,
+                                                            origResultValue);
+  }
+}
+
+ModulePassBase *mlir::quantizer::createInferQuantizedTypesPass(
+    SolverContext &solverContext, const TargetConfiguration &config) {
+  return new InferQuantizedTypesPass(solverContext, config);
+}
+
+static PassRegistration<InferQuantizedTypesPass>
+    pass("quantizer-infer-quantized-types",
+         "Infers quantized types for a module");
diff --git a/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
new file mode 100644
index 0000000..d5fb284
--- /dev/null
+++ b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -0,0 +1,76 @@
+//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a pass to remove any instrumentation ops. It is often one
+// of the final steps when performing quantization and is run after any
+// decisions requiring instrumentation have been made.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class RemoveInstrumentationPass
+    : public FunctionPass<RemoveInstrumentationPass> {
+  void runOnFunction() override;
+};
+
+template <typename OpTy>
+class RemoveIdentityOpRewrite : public RewritePattern {
+public:
+  RemoveIdentityOpRewrite(MLIRContext *context)
+      : RewritePattern(OpTy::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    assert(op->getNumOperands() == 1);
+    assert(op->getNumResults() == 1);
+
+    rewriter.replaceOp(op, op->getOperand(0));
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void RemoveInstrumentationPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto func = getFunction();
+  auto *context = &getContext();
+  patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
+                  RemoveIdentityOpRewrite<StatisticsRefOp>,
+                  RemoveIdentityOpRewrite<CoupledRefOp>>(context);
+  applyPatternsGreedily(func, patterns);
+}
+
+FunctionPassBase *mlir::quantizer::createRemoveInstrumentationPass() {
+  return new RemoveInstrumentationPass();
+}
+
+static PassRegistration<RemoveInstrumentationPass>
+    pass("quantizer-remove-instrumentation",
+         "Removes instrumentation and hints which have no effect on final "
+         "execution");
diff --git a/third_party/mlir/lib/SDBM/CMakeLists.txt b/third_party/mlir/lib/SDBM/CMakeLists.txt
new file mode 100644
index 0000000..30b2f641a
--- /dev/null
+++ b/third_party/mlir/lib/SDBM/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRSDBM
+  SDBM.cpp
+  SDBMExpr.cpp
+  SDBMDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/SDBM
+)
+add_dependencies(MLIRSDBM MLIRIR)
+target_link_libraries(MLIRSDBM MLIRIR)
diff --git a/third_party/mlir/lib/SDBM/SDBM.cpp b/third_party/mlir/lib/SDBM/SDBM.cpp
new file mode 100644
index 0000000..13932c6
--- /dev/null
+++ b/third_party/mlir/lib/SDBM/SDBM.cpp
@@ -0,0 +1,561 @@
+//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
+// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/SDBM/SDBM.h"
+#include "mlir/SDBM/SDBMExpr.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+// Helper function for SDBM construction that collects information necessary to
+// start building an SDBM in one sweep.  In particular, it records the largest
+// position of a dimension in `dim`, that of a symbol in `symbol` as well as
+// collects all unique stripe expressions in `stripes`.  Uses SetVector to
+// ensure these expressions always have the same order.
+static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol,
+                                 llvm::SmallSetVector<SDBMExpr, 8> &stripes) {
+  struct Visitor : public SDBMVisitor<Visitor> {
+    void visitDim(SDBMDimExpr dimExpr) {
+      int p = dimExpr.getPosition();
+      if (p > maxDimPosition)
+        maxDimPosition = p;
+    }
+    void visitSymbol(SDBMSymbolExpr symbExpr) {
+      int p = symbExpr.getPosition();
+      if (p > maxSymbPosition)
+        maxSymbPosition = p;
+    }
+    void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); }
+
+    Visitor(llvm::SmallSetVector<SDBMExpr, 8> &stripes) : stripes(stripes) {}
+
+    int maxDimPosition = -1;
+    int maxSymbPosition = -1;
+    llvm::SmallSetVector<SDBMExpr, 8> &stripes;
+  };
+
+  Visitor visitor(stripes);
+  visitor.walkPostorder(expr);
+  dim = std::max(dim, visitor.maxDimPosition);
+  symbol = std::max(symbol, visitor.maxSymbPosition);
+}
+
+namespace {
+// Utility class for SDBMBuilder.  Represents a value that can be inserted in
+// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is
+// any combination of the positive and negative positions.  Since multiple
+// variables can be declared equal to the same stripe expression, the
+// constraints on this expression must be reflected to all these variables.  For
+// example, if
+//   d0 = s0 # 42
+//   d1 = s0 # 42
+//   d2 = s1 # 2
+//   d3 = s1 # 2
+// the constraint
+//   s0 # 42 - s1 # 2 <= C
+// should be reflected in the DB matrix as
+//   d0 - d2 <= C
+//   d1 - d2 <= C
+//   d0 - d3 <= C
+//   d1 - d3 <= C
+// since the DB matrix has no knowledge of the transitive equality between d0,
+// d1 and s0 # 42 as well as between d2, d3 and s1 # 2.  This knowledge can be
+// obtained by computing a transitive closure, which is impossible until the
+// DBM is actually built.
+struct SDBMBuilderResult {
+  // Positions in the matrix of the variables taken with the "+" sign in the
+  // difference expression, 0 if it is a constant rather than a variable.
+  llvm::SmallVector<unsigned, 2> positivePos;
+
+  // Positions in the matrix of the variables taken with the "-" sign in the
+  // difference expression, 0 if it is a constant rather than a variable.
+  llvm::SmallVector<unsigned, 2> negativePos;
+
+  // Constant value in the difference expression.
+  int64_t value = 0;
+};
+
+// Visitor for building an SDBM from SDBM expressions.  After traversing an SDBM
+// expression, produces an update to the SDB matrix specifying the positions in
+// the matrix and the negated value that should be stored.  Both the positive
+// and the negative positions may be lists of indices in cases where multiple
+// variables are equal to the same stripe expression.  In such cases, the update
+// applies to the cross product of positions because elements involved in the
+// update are (transitively) equal and should have the same constraints, but we
+// may not have an explicit equality for them.
+struct SDBMBuilder : public SDBMVisitor<SDBMBuilder, SDBMBuilderResult> {
+public:
+  // A difference expression produces both the positive and the negative
+  // coordinate in the matrix, recursively traversing the LHS and the RHS. The
+  // value is the difference between values obtained from LHS and RHS.
+  SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) {
+    auto lhs = visit(diffExpr.getLHS());
+    auto rhs = visit(diffExpr.getRHS());
+    assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
+           "unexpected negative expression in a difference expression");
+    assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
+           "unexpected negative expression in a difference expression");
+
+    SDBMBuilderResult result;
+    result.positivePos = lhs.positivePos;
+    result.negativePos = rhs.positivePos;
+    result.value = lhs.value - rhs.value;
+    return result;
+  }
+
+  // An input expression is always taken with the "+" sign and therefore
+  // produces a positive coordinate keeping the negative coordinate zero for an
+  // eventual constant.
+  SDBMBuilderResult visitInput(SDBMInputExpr expr) {
+    SDBMBuilderResult r;
+    r.positivePos.push_back(linearPosition(expr));
+    r.negativePos.push_back(0);
+    return r;
+  }
+
+  // A stripe expression is always equal to one or more variables, which may be
+  // temporaries, and appears with a "+" sign in the SDBM expression tree. Take
+  // the positions of the corresponding variables as positive coordinates.
+  SDBMBuilderResult visitStripe(SDBMStripeExpr expr) {
+    SDBMBuilderResult r;
+    assert(pointExprToStripe.count(expr));
+    r.positivePos = pointExprToStripe[expr];
+    r.negativePos.push_back(0);
+    return r;
+  }
+
+  // A constant expression has both coordinates at zero.
+  SDBMBuilderResult visitConstant(SDBMConstantExpr expr) {
+    SDBMBuilderResult r;
+    r.positivePos.push_back(0);
+    r.negativePos.push_back(0);
+    r.value = expr.getValue();
+    return r;
+  }
+
+  // A negation expression swaps the positive and the negative coordinates
+  // and also negates the constant value.
+  SDBMBuilderResult visitNeg(SDBMNegExpr expr) {
+    SDBMBuilderResult result = visit(expr.getVar());
+    std::swap(result.positivePos, result.negativePos);
+    result.value = -result.value;
+    return result;
+  }
+
+  // The RHS of a sum expression must be a constant and therefore must have both
+  // positive and negative coordinates at zero.  Take the sum of the values
+  // between LHS and RHS and keep LHS coordinates.
+  SDBMBuilderResult visitSum(SDBMSumExpr expr) {
+    auto lhs = visit(expr.getLHS());
+    auto rhs = visit(expr.getRHS());
+    for (auto pos : rhs.negativePos) {
+      (void)pos;
+      assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
+    }
+    for (auto pos : rhs.positivePos) {
+      (void)pos;
+      assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
+    }
+
+    lhs.value += rhs.value;
+    return lhs;
+  }
+
+  SDBMBuilder(llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>>
+                  &pointExprToStripe,
+              llvm::function_ref<unsigned(SDBMInputExpr)> callback)
+      : pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
+
+  llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> &pointExprToStripe;
+  llvm::function_ref<unsigned(SDBMInputExpr)> linearPosition;
+};
+} // namespace
+
+SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
+  SDBM result;
+
+  // TODO(zinenko): consider detecting equalities in the list of inequalities.
+  // This is potentially expensive and requires to
+  //   - create a list of negated inequalities (may allocate under lock);
+  //   - perform a pairwise comparison of direct and negated inequalities;
+  //   - copy the lists of equalities and inequalities, and move entries between
+  //     them;
+  // only for the purpose of sparing a temporary variable in cases where an
+  // implicit equality between a variable and a stripe expression is present in
+  // the input.
+
+  // Do the first sweep over (in)equalities to collect the information necessary
+  // to allocate the SDB matrix (number of dimensions, symbol and temporary
+  // variables required for stripe expressions).
+  llvm::SmallSetVector<SDBMExpr, 8> stripes;
+  int maxDim = -1;
+  int maxSymbol = -1;
+  for (auto expr : inequalities)
+    collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
+  for (auto expr : equalities)
+    collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
+  // Indexing of dimensions starts with 0, obtain the number of dimensions by
+  // incrementing the maximal position of the dimension seen in expressions.
+  result.numDims = maxDim + 1;
+  result.numSymbols = maxSymbol + 1;
+  result.numTemporaries = 0;
+
+  // Helper function that returns the position of the variable represented by
+  // an SDBM input expression.
+  auto linearPosition = [result](SDBMInputExpr expr) {
+    if (expr.isa<SDBMDimExpr>())
+      return result.getDimPosition(expr.getPosition());
+    return result.getSymbolPosition(expr.getPosition());
+  };
+
+  // Check if some stripe expressions are equal to another variable. In
+  // particular, look for the equalities of the form
+  //   d0 - stripe-expression = 0, or
+  //   stripe-expression - d0 = 0.
+  // There may be multiple variables that are equal to the same stripe
+  // expression.  Keep track of those in pointExprToStripe.
+  // There may also be multiple stripe expressions equal to the same variable.
+  // Introduce a temporary variable for each of those.
+  llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> pointExprToStripe;
+  unsigned numTemporaries = 0;
+
+  auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
+                                linearPosition](SDBMInputExpr input,
+                                                SDBMExpr expr) {
+    unsigned position = linearPosition(input);
+    if (result.stripeToPoint.count(position) &&
+        result.stripeToPoint[position] != expr) {
+      position = result.getNumVariables() + numTemporaries++;
+    }
+    pointExprToStripe[expr].push_back(position);
+    result.stripeToPoint.insert(std::make_pair(position, expr));
+  };
+
+  for (auto eq : equalities) {
+    auto diffExpr = eq.dyn_cast<SDBMDiffExpr>();
+    if (!diffExpr)
+      continue;
+
+    auto lhs = diffExpr.getLHS();
+    auto rhs = diffExpr.getRHS();
+    auto lhsInput = lhs.dyn_cast<SDBMInputExpr>();
+    auto rhsInput = rhs.dyn_cast<SDBMInputExpr>();
+
+    if (lhsInput && stripes.count(rhs))
+      updateStripePointMaps(lhsInput, rhs);
+    if (rhsInput && stripes.count(lhs))
+      updateStripePointMaps(rhsInput, lhs);
+  }
+
+  // Assign the remaining stripe expressions to temporary variables.  These
+  // expressions are the ones that could not be associated with an existing
+  // variable in the previous step.
+  for (auto expr : stripes) {
+    if (pointExprToStripe.count(expr))
+      continue;
+    unsigned position = result.getNumVariables() + numTemporaries++;
+    pointExprToStripe[expr].push_back(position);
+    result.stripeToPoint.insert(std::make_pair(position, expr));
+  }
+
+  // Create the DBM matrix, initialized to infinity values for the least tight
+  // possible bound (x - y <= infinity is always true).
+  result.numTemporaries = numTemporaries;
+  result.matrix.resize(result.getNumVariables() * result.getNumVariables(),
+                       IntInfty::infinity());
+
+  SDBMBuilder builder(pointExprToStripe, linearPosition);
+
+  // Only keep the tightest constraint.  Since we transform everything into
+  // less-than-or-equals-to inequalities, keep the smallest constant.  For
+  // example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter.
+  // Note that the input expressions are in the shape of d0 - d1 + -42 <= 0
+  // so we negate the value before storing it.
+  // In case where the positive and the negative positions are equal, the
+  // corresponding expression has the form d0 - d0 + -42 <= 0.  If the constant
+  // value is positive, the set defined by SDBM is trivially empty.  We store
+  // this value anyway and continue processing to maintain the correspondence
+  // between the matrix form and the list-of-SDBMExpr form.
+  // TODO(zinenko): we may want to reconsider this once we have canonicalization
+  // or simplification in place
+  auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) {
+    for (auto positivePos : r.positivePos) {
+      for (auto negativePos : r.negativePos) {
+        auto &m = sdbm.at(negativePos, positivePos);
+        m = m < -r.value ? m : -r.value;
+      }
+    }
+  };
+
+  // Do the second sweep on (in)equalities, updating the SDB matrix to reflect
+  // the constraints.
+  for (auto ineq : inequalities)
+    updateMatrix(result, builder.visit(ineq));
+
+  // An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0;
+  // f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}.
+  for (auto eq : equalities) {
+    updateMatrix(result, builder.visit(eq));
+    updateMatrix(result, builder.visit(-eq));
+  }
+
+  // Add the inequalities induced by stripe equalities.
+  //   t = x # C  =>  t <= x <= t + C - 1
+  // which is equivalent to
+  //   {t - x <= 0;
+  //    x - t - (C - 1) <= 0}.
+  for (const auto &pair : result.stripeToPoint) {
+    auto stripe = pair.second.cast<SDBMStripeExpr>();
+    SDBMBuilderResult update = builder.visit(stripe.getVar());
+    assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
+           "unexpected negated variable in stripe expression");
+    assert(update.value == 0 &&
+           "unexpected non-zero value in stripe expression");
+    update.negativePos.clear();
+    update.negativePos.push_back(pair.first);
+    update.value = -(stripe.getStripeFactor().getValue() - 1);
+    updateMatrix(result, update);
+
+    std::swap(update.negativePos, update.positivePos);
+    update.value = 0;
+    updateMatrix(result, update);
+  }
+
+  return result;
+}
+
+// Given a row and a column position in the square DBM, insert one equality
+// or up to two inequalities that correspond the entries (col, row) and (row,
+// col) in the DBM.  `rowExpr` and `colExpr` contain the expressions such that
+// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM.
+// If one of the expressions is derived from another using a stripe operation,
+// check if the inequalities induced by the stripe operation subsume the
+// inequalities defined in the DBM and if so, elide these inequalities.
+void SDBM::convertDBMElement(unsigned row, unsigned col,
+                             SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr,
+                             SmallVectorImpl<SDBMExpr> &inequalities,
+                             SmallVectorImpl<SDBMExpr> &equalities) {
+  using ops_assertions::operator+;
+  using ops_assertions::operator-;
+
+  auto diffIJValue = at(col, row);
+  auto diffJIValue = at(row, col);
+
+  // If symmetric entries are opposite, the corresponding expressions are equal.
+  if (diffIJValue.isFinite() &&
+      diffIJValue.getValue() == -diffJIValue.getValue()) {
+    equalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
+    return;
+  }
+
+  // Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived
+  // from x1: x0 = x1 # B.  If so, it would imply the constraints
+  // x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1).
+  // Therefore, if A >= 0, this inequality is subsumed by that implied
+  // by the stripe equality and thus can be elided.
+  // Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C.
+  // If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=>
+  // <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1).  Therefore, if A >= (C - 1), this
+  // inequality can be elided.
+  //
+  // Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe
+  // expressions being stored without temporaries on the RHS and being passed
+  // into this function as is.
+  auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr,
+                         SDBMExpr x1Expr, int64_t value) {
+    if (stripeToPoint.count(x0)) {
+      auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
+      SDBMPositiveExpr var = stripe.getVar();
+      if (x1Expr == var && value >= 0)
+        return true;
+    }
+    if (stripeToPoint.count(x1)) {
+      auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
+      SDBMPositiveExpr var = stripe.getVar();
+      if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
+        return true;
+    }
+    return false;
+  };
+
+  // Check row - col.
+  if (diffIJValue.isFinite() &&
+      !canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) {
+    inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
+  }
+  // Check col - row.
+  if (diffJIValue.isFinite() &&
+      !canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) {
+    inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue());
+  }
+}
+
+// The values on the main diagonal correspond to the upper bound on the
+// difference between a variable and itself: d0 - d0 <= C, or alternatively
+// to -C <= 0.  Only construct the inequalities when C is negative, which
+// are trivially false but necessary for the returned system of inequalities
+// to indicate that the set it defines is empty.
+void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr,
+                                     SmallVectorImpl<SDBMExpr> &inequalities) {
+  auto selfDifference = at(pos, pos);
+  if (selfDifference.isFinite() && selfDifference < 0) {
+    auto selfDifferenceValueExpr =
+        SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue());
+    inequalities.push_back(selfDifferenceValueExpr);
+  }
+}
+
+void SDBM::getSDBMExpressions(SDBMDialect *dialect,
+                              SmallVectorImpl<SDBMExpr> &inequalities,
+                              SmallVectorImpl<SDBMExpr> &equalities) {
+  using ops_assertions::operator-;
+  using ops_assertions::operator+;
+
+  // Helper function that creates an SDBMInputExpr given the linearized position
+  // of variable in the DBM.
+  auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr {
+    if (matrixPos < numDims)
+      return SDBMDimExpr::get(dialect, matrixPos);
+    return SDBMSymbolExpr::get(dialect, matrixPos - numDims);
+  };
+
+  // The top-left value corresponds to inequality 0 <= C.  If C is negative, the
+  // set defined by SDBM is trivially empty and we add the constraint -C <= 0 to
+  // the list of inequalities.  Otherwise, the constraint is trivially true and
+  // we ignore it.
+  auto difference = at(0, 0);
+  if (difference.isFinite() && difference < 0) {
+    inequalities.push_back(
+        SDBMConstantExpr::get(dialect, -difference.getValue()));
+  }
+
+  // Traverse the segment of the matrix that involves non-temporary variables.
+  unsigned numTrueVariables = numDims + numSymbols;
+  for (unsigned i = 0; i < numTrueVariables; ++i) {
+    // The first row and column represent numerical upper and lower bound on
+    // each variable.  Transform them into inequalities if they are finite.
+    auto upperBound = at(0, 1 + i);
+    auto lowerBound = at(1 + i, 0);
+    auto inputExpr = getInput(i);
+    if (upperBound.isFinite() &&
+        upperBound.getValue() == -lowerBound.getValue()) {
+      equalities.push_back(inputExpr - upperBound.getValue());
+    } else if (upperBound.isFinite()) {
+      inequalities.push_back(inputExpr - upperBound.getValue());
+    } else if (lowerBound.isFinite()) {
+      inequalities.push_back(-inputExpr - lowerBound.getValue());
+    }
+
+    // Introduce trivially false inequalities if required by diagonal elements.
+    convertDBMDiagonalElement(1 + i, inputExpr, inequalities);
+
+    // Introduce equalities or inequalities between non-temporary variables.
+    for (unsigned j = 0; j < i; ++j) {
+      convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities,
+                        equalities);
+    }
+  }
+
+  // Add equalities for stripe expressions that define non-temporary
+  // variables.  Temporary variables will be substituted into their uses and
+  // should not appear in the resulting equalities.
+  for (const auto &stripePair : stripeToPoint) {
+    unsigned position = stripePair.first;
+    if (position < 1 + numTrueVariables) {
+      equalities.push_back(getInput(position - 1) - stripePair.second);
+    }
+  }
+
+  // Add equalities / inequalities involving temporaries by replacing the
+  // temporaries with stripe expressions that define them.
+  for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) {
+    // Mixed constraints involving one temporary (j) and one non-temporary (i)
+    // variable.
+    for (unsigned j = 0; j < numTrueVariables; ++j) {
+      convertDBMElement(i, 1 + j, stripeToPoint[i].cast<SDBMStripeExpr>(),
+                        getInput(j), inequalities, equalities);
+    }
+
+    // Constraints involving only temporary variables.
+    for (unsigned j = 1 + numTrueVariables; j < i; ++j) {
+      convertDBMElement(i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
+                        stripeToPoint[j].cast<SDBMStripeExpr>(), inequalities,
+                        equalities);
+    }
+
+    // Introduce trivially false inequalities if required by diagonal elements.
+    convertDBMDiagonalElement(i, stripeToPoint[i].cast<SDBMStripeExpr>(),
+                              inequalities);
+  }
+}
+
+void SDBM::print(llvm::raw_ostream &os) {
+  unsigned numVariables = getNumVariables();
+
+  // Helper function that prints the name of the variable given its linearized
+  // position in the DBM.
+  auto getVarName = [this](unsigned matrixPos) -> std::string {
+    if (matrixPos == 0)
+      return "cst";
+    matrixPos -= 1;
+    if (matrixPos < numDims)
+      return llvm::formatv("d{0}", matrixPos);
+    matrixPos -= numDims;
+    if (matrixPos < numSymbols)
+      return llvm::formatv("s{0}", matrixPos);
+    matrixPos -= numSymbols;
+    return llvm::formatv("t{0}", matrixPos);
+  };
+
+  // Header row.
+  os << "      cst";
+  for (unsigned i = 1; i < numVariables; ++i) {
+    os << llvm::formatv(" {0,4}", getVarName(i));
+  }
+  os << '\n';
+
+  // Data rows.
+  for (unsigned i = 0; i < numVariables; ++i) {
+    os << llvm::formatv("{0,-4}", getVarName(i));
+    for (unsigned j = 0; j < numVariables; ++j) {
+      IntInfty value = operator()(i, j);
+      if (!value.isFinite())
+        os << "  inf";
+      else
+        os << llvm::formatv(" {0,4}", value.getValue());
+    }
+    os << '\n';
+  }
+
+  // Explanation of temporaries.
+  for (const auto &pair : stripeToPoint) {
+    os << getVarName(pair.first) << " = ";
+    pair.second.print(os);
+    os << '\n';
+  }
+}
+
+void SDBM::dump() { print(llvm::errs()); }
diff --git a/third_party/mlir/lib/SDBM/SDBMDialect.cpp b/third_party/mlir/lib/SDBM/SDBMDialect.cpp
new file mode 100644
index 0000000..e000209
--- /dev/null
+++ b/third_party/mlir/lib/SDBM/SDBMDialect.cpp
@@ -0,0 +1,20 @@
+//===- SDBMDialect.cpp - Dialect for striped difference-bound matrices ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/SDBM/SDBMDialect.h"
+
+static mlir::DialectRegistration<mlir::SDBMDialect> SDBMDialect;
diff --git a/third_party/mlir/lib/SDBM/SDBMExpr.cpp b/third_party/mlir/lib/SDBM/SDBMExpr.cpp
new file mode 100644
index 0000000..5757ebe
--- /dev/null
+++ b/third_party/mlir/lib/SDBM/SDBMExpr.cpp
@@ -0,0 +1,647 @@
+//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/SDBM/SDBMExpr.h"
+#include "SDBMExprDetail.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/SDBM/SDBMDialect.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+/// A simple compositional matcher for AffineExpr
+///
+/// Example usage:
+///
+/// ```c++
+///    AffineExprMatcher x, C, m;
+///    AffineExprMatcher pattern1 = ((x % C) * m) + x;
+///    AffineExprMatcher pattern2 = x + ((x % C) * m);
+///    if (pattern1.match(expr) || pattern2.match(expr)) {
+///      ...
+///    }
+/// ```
+class AffineExprMatcherStorage;
+class AffineExprMatcher {
+public:
+  AffineExprMatcher();
+  AffineExprMatcher(const AffineExprMatcher &other);
+
+  AffineExprMatcher operator+(AffineExprMatcher other) {
+    return AffineExprMatcher(AffineExprKind::Add, *this, other);
+  }
+  AffineExprMatcher operator*(AffineExprMatcher other) {
+    return AffineExprMatcher(AffineExprKind::Mul, *this, other);
+  }
+  AffineExprMatcher floorDiv(AffineExprMatcher other) {
+    return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
+  }
+  AffineExprMatcher ceilDiv(AffineExprMatcher other) {
+    return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
+  }
+  AffineExprMatcher operator%(AffineExprMatcher other) {
+    return AffineExprMatcher(AffineExprKind::Mod, *this, other);
+  }
+
+  AffineExpr match(AffineExpr expr);
+  AffineExpr matched();
+  Optional<int> getMatchedConstantValue();
+
+private:
+  AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
+  AffineExprKind kind; // only used to match in binary op cases.
+  // A shared_ptr allows multiple references to same matcher storage without
+  // worrying about ownership or dealing with an arena. To be cleaned up if we
+  // go with this.
+  std::shared_ptr<AffineExprMatcherStorage> storage;
+};
+
+class AffineExprMatcherStorage {
+public:
+  AffineExprMatcherStorage() {}
+  AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
+      : subExprs(other.subExprs.begin(), other.subExprs.end()),
+        matched(other.matched) {}
+  AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
+      : subExprs(exprs.begin(), exprs.end()) {}
+  AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
+      : subExprs({a, b}) {}
+  llvm::SmallVector<AffineExprMatcher, 0> subExprs;
+  AffineExpr matched;
+};
+} // namespace
+
+AffineExprMatcher::AffineExprMatcher()
+    : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
+
+AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
+    : kind(other.kind), storage(other.storage) {}
+
+Optional<int> AffineExprMatcher::getMatchedConstantValue() {
+  if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
+    return cst.getValue();
+  return None;
+}
+
+AffineExpr AffineExprMatcher::match(AffineExpr expr) {
+  if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
+    if (storage->matched)
+      if (storage->matched != expr)
+        return AffineExpr();
+    storage->matched = expr;
+    return storage->matched;
+  }
+  if (kind != expr.getKind()) {
+    return AffineExpr();
+  }
+  if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    if (!storage->subExprs.empty() &&
+        !storage->subExprs[0].match(bin.getLHS())) {
+      return AffineExpr();
+    }
+    if (!storage->subExprs.empty() &&
+        !storage->subExprs[1].match(bin.getRHS())) {
+      return AffineExpr();
+    }
+    if (storage->matched)
+      if (storage->matched != expr)
+        return AffineExpr();
+    storage->matched = expr;
+    return storage->matched;
+  }
+  llvm_unreachable("binary expected");
+}
+
+AffineExpr AffineExprMatcher::matched() { return storage->matched; }
+
+AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
+                                     AffineExprMatcher b)
+    : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
+  storage->subExprs.push_back(a);
+  storage->subExprs.push_back(b);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMExpr
+//===----------------------------------------------------------------------===//
+
+SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
+
+MLIRContext *SDBMExpr::getContext() const {
+  return impl->dialect->getContext();
+}
+
+SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
+
+void SDBMExpr::print(raw_ostream &os) const {
+  struct Printer : public SDBMVisitor<Printer> {
+    Printer(raw_ostream &ostream) : prn(ostream) {}
+
+    void visitSum(SDBMSumExpr expr) {
+      visitVarying(expr.getLHS());
+      prn << " + ";
+      visitConstant(expr.getRHS());
+    }
+    void visitDiff(SDBMDiffExpr expr) {
+      visitPositive(expr.getLHS());
+      prn << " - ";
+      visitPositive(expr.getRHS());
+    }
+    void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
+    void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
+    void visitStripe(SDBMStripeExpr expr) {
+      visitPositive(expr.getVar());
+      prn << " # ";
+      visitConstant(expr.getStripeFactor());
+    }
+    void visitNeg(SDBMNegExpr expr) {
+      prn << '-';
+      visitPositive(expr.getVar());
+    }
+    void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
+
+    raw_ostream &prn;
+  };
+  Printer printer(os);
+  printer.visit(*this);
+}
+
+void SDBMExpr::dump() const {
+  print(llvm::errs());
+  llvm::errs() << '\n';
+}
+
+namespace {
+// Helper class to perform negation of an SDBM expression.
+struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
+  // Any positive expression is wrapped into a negation expression.
+  //  -(x) = -x
+  SDBMExpr visitPositive(SDBMPositiveExpr expr) {
+    return SDBMNegExpr::get(expr);
+  }
+  // A negation expression is unwrapped.
+  //  -(-x) = x
+  SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
+  // The value of the constant is negated.
+  SDBMExpr visitConstant(SDBMConstantExpr expr) {
+    return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
+  }
+  // Both terms of the sum are negated recursively.
+  SDBMExpr visitSum(SDBMSumExpr expr) {
+    return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(),
+                            visit(expr.getRHS()).cast<SDBMConstantExpr>());
+  }
+  // Terms of a difference are interchanged.
+  //  -(x - y) = y - x
+  SDBMExpr visitDiff(SDBMDiffExpr expr) {
+    return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS());
+  }
+};
+} // namespace
+
+SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
+
+//===----------------------------------------------------------------------===//
+// SDBMSumExpr
+//===----------------------------------------------------------------------===//
+
+SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
+  assert(lhs && "expected SDBM variable expression");
+  assert(rhs && "expected SDBM constant");
+
+  // If LHS of a sum is another sum, fold the constant RHS parts.
+  if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
+    lhs = lhsSum.getLHS();
+    rhs = SDBMConstantExpr::get(rhs.getDialect(),
+                                rhs.getValue() + lhsSum.getRHS().getValue());
+  }
+
+  StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
+  return uniquer.get<detail::SDBMBinaryExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
+}
+
+SDBMVaryingExpr SDBMSumExpr::getLHS() const {
+  return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMConstantExpr SDBMSumExpr::getRHS() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+AffineExpr SDBMExpr::getAsAffineExpr() const {
+  struct Converter : public SDBMVisitor<Converter, AffineExpr> {
+    AffineExpr visitSum(SDBMSumExpr expr) {
+      AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      return lhs + rhs;
+    }
+
+    AffineExpr visitStripe(SDBMStripeExpr expr) {
+      AffineExpr lhs = visit(expr.getVar()),
+                 rhs = visit(expr.getStripeFactor());
+      return lhs - (lhs % rhs);
+    }
+
+    AffineExpr visitDiff(SDBMDiffExpr expr) {
+      AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      return lhs - rhs;
+    }
+
+    AffineExpr visitDim(SDBMDimExpr expr) {
+      return getAffineDimExpr(expr.getPosition(), expr.getContext());
+    }
+
+    AffineExpr visitSymbol(SDBMSymbolExpr expr) {
+      return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
+    }
+
+    AffineExpr visitNeg(SDBMNegExpr expr) {
+      return getAffineBinaryOpExpr(AffineExprKind::Mul,
+                                   getAffineConstantExpr(-1, expr.getContext()),
+                                   visit(expr.getVar()));
+    }
+
+    AffineExpr visitConstant(SDBMConstantExpr expr) {
+      return getAffineConstantExpr(expr.getValue(), expr.getContext());
+    }
+  } converter;
+  return converter.visit(*this);
+}
+
+Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
+  struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
+    SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
+      // Attempt to recover a stripe expression.  Because AffineExprs don't have
+      // a first-class difference kind, we check for both x + -1 * (x mod C) and
+      // -1 * (x mod C) + x cases.
+      AffineExprMatcher x, C, m;
+      AffineExprMatcher pattern1 = ((x % C) * m) + x;
+      AffineExprMatcher pattern2 = x + ((x % C) * m);
+      if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
+          (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
+        if (auto convertedLHS = visit(x.matched())) {
+          // TODO(ntv): return convertedLHS.stripe(C);
+          return SDBMStripeExpr::get(
+              convertedLHS.cast<SDBMPositiveExpr>(),
+              visit(C.matched()).cast<SDBMConstantExpr>());
+        }
+      }
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // In a "add" AffineExpr, the constant always appears on the right.  If
+      // there were two constants, they would have been folded away.
+      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+      // SDBM accepts LHS variables and RHS constants in a sum.
+      auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
+      auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
+      if (rhsConstant && lhsVar)
+        return SDBMSumExpr::get(lhsVar, rhsConstant);
+
+      // The sum of a negated variable and a non-negated variable is a
+      // difference, supported as a special kind in SDBM.  Because AffineExprs
+      // don't have first-class difference kind, check both LHS and RHS for
+      // negation.
+      auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
+      auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
+      auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+      auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+      if (lhsNeg && rhsVar)
+        return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
+      if (rhsNeg && lhsVar)
+        return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
+
+      // Other cases don't fit into SDBM.
+      return {};
+    }
+
+    SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
+      // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
+      AffineExprMatcher x, C;
+      AffineExprMatcher pattern = (x.floorDiv(C)) * C;
+      if (pattern.match(expr)) {
+        if (SDBMExpr converted = visit(x.matched())) {
+          if (auto varConverted = converted.dyn_cast<SDBMPositiveExpr>())
+            // TODO(ntv): return varConverted.stripe(C.getConstantValue());
+            return SDBMStripeExpr::get(
+                varConverted,
+                SDBMConstantExpr::get(dialect,
+                                      C.getMatchedConstantValue().getValue()));
+        }
+      }
+
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // In a "mul" AffineExpr, the constant always appears on the right.  If
+      // there were two constants, they would have been folded away.
+      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+      if (!rhsConstant)
+        return {};
+
+      // The only supported "multiplication" expression is an SDBM is dimension
+      // negation, that is a product of dimension and constant -1.
+      auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
+      if (lhsVar && rhsConstant.getValue() == -1)
+        return SDBMNegExpr::get(lhsVar);
+
+      // Other multiplications are not allowed in SDBM.
+      return {};
+    }
+
+    SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // 'mod' can only be converted to SDBM if its LHS is a variable
+      // and its RHS is a constant.  Then it `x mod c = x - x stripe c`.
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+      auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
+      if (!lhsVar || !rhsConstant)
+        return {};
+      return SDBMDiffExpr::get(lhsVar,
+                               SDBMStripeExpr::get(lhsVar, rhsConstant));
+    }
+
+    // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
+    SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
+    SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
+
+    // Dimensions, symbols and constants are converted trivially.
+    SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
+      return SDBMConstantExpr::get(dialect, expr.getValue());
+    }
+    SDBMExpr visitDimExpr(AffineDimExpr expr) {
+      return SDBMDimExpr::get(dialect, expr.getPosition());
+    }
+    SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
+      return SDBMSymbolExpr::get(dialect, expr.getPosition());
+    }
+
+    SDBMDialect *dialect;
+  } converter;
+  converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
+
+  if (auto result = converter.visit(affine))
+    return result;
+  return None;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDiffExpr
+//===----------------------------------------------------------------------===//
+
+SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
+  assert(lhs && "expected SDBM dimension");
+  assert(rhs && "expected SDBM dimension");
+
+  StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
+  return uniquer.get<detail::SDBMDiffExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
+  return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMStripeExpr
+//===----------------------------------------------------------------------===//
+
+SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
+                                   SDBMConstantExpr stripeFactor) {
+  assert(var && "expected SDBM variable expression");
+  assert(stripeFactor && "expected non-null stripe factor");
+  if (stripeFactor.getValue() <= 0)
+    llvm::report_fatal_error("non-positive stripe factor");
+
+  StorageUniquer &uniquer = var.getDialect()->getUniquer();
+  return uniquer.get<detail::SDBMBinaryExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
+      stripeFactor);
+}
+
+SDBMPositiveExpr SDBMStripeExpr::getVar() const {
+  if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
+    return lhs.cast<SDBMPositiveExpr>();
+  return {};
+}
+
+SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMInputExpr
+//===----------------------------------------------------------------------===//
+
+unsigned SDBMInputExpr::getPosition() const {
+  return static_cast<ImplType *>(impl)->position;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDimExpr
+//===----------------------------------------------------------------------===//
+
+SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
+  assert(dialect && "expected non-null dialect");
+
+  auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+    storage->dialect = dialect;
+  };
+
+  StorageUniquer &uniquer = dialect->getUniquer();
+  return uniquer.get<detail::SDBMPositiveExprStorage>(
+      assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMSymbolExpr
+//===----------------------------------------------------------------------===//
+
+SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
+  assert(dialect && "expected non-null dialect");
+
+  auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+    storage->dialect = dialect;
+  };
+
+  StorageUniquer &uniquer = dialect->getUniquer();
+  return uniquer.get<detail::SDBMPositiveExprStorage>(
+      assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMConstantExpr
+//===----------------------------------------------------------------------===//
+
+SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
+  assert(dialect && "expected non-null dialect");
+
+  auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
+    storage->dialect = dialect;
+  };
+
+  StorageUniquer &uniquer = dialect->getUniquer();
+  return uniquer.get<detail::SDBMConstantExprStorage>(
+      assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
+}
+
+int64_t SDBMConstantExpr::getValue() const {
+  return static_cast<ImplType *>(impl)->constant;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMNegExpr
+//===----------------------------------------------------------------------===//
+
+SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
+  assert(var && "expected non-null SDBM variable expression");
+
+  StorageUniquer &uniquer = var.getDialect()->getUniquer();
+  return uniquer.get<detail::SDBMNegExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
+}
+
+SDBMPositiveExpr SDBMNegExpr::getVar() const {
+  return static_cast<ImplType *>(impl)->dim;
+}
+
+namespace mlir {
+namespace ops_assertions {
+
+SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
+  // If one of the operands is a negation, take a difference rather than a sum.
+  auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+  auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+  assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
+                                "a sum of variables and not a correct SDBM");
+  if (lhsNeg)
+    return rhs - lhsNeg.getVar();
+  if (rhsNeg)
+    return lhs - rhsNeg.getVar();
+
+  // If LHS is a constant and RHS is not, swap the order to get into a supported
+  // sum case.  From now on, RHS must be a constant.
+  auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
+  auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+  if (!rhsConstant && lhsConstant) {
+    std::swap(lhs, rhs);
+    std::swap(lhsConstant, rhsConstant);
+  }
+  assert(rhsConstant && "at least one operand must be a constant");
+
+  // If LHS is another sum, first compute the sum of its variable
+  // part with the other argument and then add the constant part to enable
+  // constant folding (the variable part may, e.g., be a negation that requires
+  // to enter this function again).
+  auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
+  if (lhsSum)
+    return lhsSum.getLHS() +
+           (lhsSum.getRHS().getValue() + rhsConstant.getValue());
+
+  // Constant-fold if LHS is a constant.
+  if (lhsConstant)
+    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
+                                                       rhsConstant.getValue());
+
+  // Fold x + 0 == x.
+  if (rhsConstant.getValue() == 0)
+    return lhs;
+
+  return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
+                          rhs.cast<SDBMConstantExpr>());
+}
+
+SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
+  // Fold x - x == 0.
+  if (lhs == rhs)
+    return SDBMConstantExpr::get(lhs.getDialect(), 0);
+
+  // LHS and RHS may be constants.
+  auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
+  auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+  // Constant fold if both LHS and RHS are constants.
+  if (lhsConstant && rhsConstant)
+    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
+                                                       rhsConstant.getValue());
+
+  // Replace a difference with a sum with a negated value if one of LHS and RHS
+  // is a constant:
+  //   x - C == x + (-C);
+  //   C - x == -x + C.
+  // This calls into operator+ for further simplification.
+  if (rhsConstant)
+    return lhs + (-rhsConstant);
+  if (lhsConstant)
+    return -rhs + lhsConstant;
+
+  // Hoist constant factors outside the difference if any of sides is a sum:
+  //   (x + A) - (y - B) == x - y + (A - B).
+  // If either LHS or RHS is a sum, collect the constant values separately and
+  // update LHS and RHS to point to the variable part of the sum.
+  auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
+  auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
+  int64_t value = 0;
+  if (lhsSum) {
+    value += lhsSum.getRHS().getValue();
+    lhs = lhsSum.getLHS();
+  }
+  if (rhsSum) {
+    value -= rhsSum.getRHS().getValue();
+    rhs = rhsSum.getLHS();
+  }
+
+  // This calls into operator+ for futher simplification in case value == 0.
+  return SDBMDiffExpr::get(lhs.cast<SDBMPositiveExpr>(),
+                           rhs.cast<SDBMPositiveExpr>()) +
+         value;
+}
+
+SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
+  auto constantFactor = factor.cast<SDBMConstantExpr>();
+  assert(constantFactor.getValue() > 0 && "non-positive stripe");
+
+  // Fold x # 1 = x.
+  if (constantFactor.getValue() == 1)
+    return expr;
+
+  return SDBMStripeExpr::get(expr.cast<SDBMPositiveExpr>(), constantFactor);
+}
+
+} // namespace ops_assertions
+} // namespace mlir
diff --git a/third_party/mlir/lib/SDBM/SDBMExprDetail.h b/third_party/mlir/lib/SDBM/SDBMExprDetail.h
new file mode 100644
index 0000000..d2c241e
--- /dev/null
+++ b/third_party/mlir/lib/SDBM/SDBMExprDetail.h
@@ -0,0 +1,138 @@
+//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of SDBMExpr, in particular underlying
+// storage types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPRDETAIL_H
+#define MLIR_IR_SDBMEXPRDETAIL_H
+
+#include "mlir/SDBM/SDBMExpr.h"
+#include "mlir/Support/StorageUniquer.h"
+
+namespace mlir {
+
+class SDBMDialect;
+
+namespace detail {
+
+// Base storage class for SDBMExpr.
+struct SDBMExprStorage : public StorageUniquer::BaseStorage {
+  SDBMExprKind getKind() {
+    return static_cast<SDBMExprKind>(BaseStorage::getKind());
+  }
+
+  SDBMDialect *dialect;
+};
+
+// Storage class for SDBM sum and stripe expressions.
+struct SDBMBinaryExprStorage : public SDBMExprStorage {
+  using KeyTy = std::pair<SDBMVaryingExpr, SDBMConstantExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+  }
+
+  static SDBMBinaryExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMBinaryExprStorage>();
+    result->lhs = std::get<0>(key);
+    result->rhs = std::get<1>(key);
+    result->dialect = result->lhs.getDialect();
+    return result;
+  }
+
+  SDBMVaryingExpr lhs;
+  SDBMConstantExpr rhs;
+};
+
+// Storage class for SDBM difference expressions.
+struct SDBMDiffExprStorage : public SDBMExprStorage {
+  using KeyTy = std::pair<SDBMPositiveExpr, SDBMPositiveExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+  }
+
+  static SDBMDiffExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMDiffExprStorage>();
+    result->lhs = std::get<0>(key);
+    result->rhs = std::get<1>(key);
+    result->dialect = result->lhs.getDialect();
+    return result;
+  }
+
+  SDBMPositiveExpr lhs;
+  SDBMPositiveExpr rhs;
+};
+
+// Storage class for SDBM constant expressions.
+struct SDBMConstantExprStorage : public SDBMExprStorage {
+  using KeyTy = int64_t;
+
+  bool operator==(const KeyTy &key) const { return constant == key; }
+
+  static SDBMConstantExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMConstantExprStorage>();
+    result->constant = key;
+    return result;
+  }
+
+  int64_t constant;
+};
+
+// Storage class for SDBM dimension and symbol expressions.
+struct SDBMPositiveExprStorage : public SDBMExprStorage {
+  using KeyTy = unsigned;
+
+  bool operator==(const KeyTy &key) const { return position == key; }
+
+  static SDBMPositiveExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMPositiveExprStorage>();
+    result->position = key;
+    return result;
+  }
+
+  unsigned position;
+};
+
+// Storage class for SDBM negation expressions.
+struct SDBMNegExprStorage : public SDBMExprStorage {
+  using KeyTy = SDBMPositiveExpr;
+
+  bool operator==(const KeyTy &key) const { return key == dim; }
+
+  static SDBMNegExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMNegExprStorage>();
+    result->dim = key;
+    result->dialect = key.getDialect();
+    return result;
+  }
+
+  SDBMPositiveExpr dim;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_IR_SDBMEXPRDETAIL_H
diff --git a/third_party/mlir/lib/StandardOps/CMakeLists.txt b/third_party/mlir/lib/StandardOps/CMakeLists.txt
new file mode 100644
index 0000000..e9fce2b
--- /dev/null
+++ b/third_party/mlir/lib/StandardOps/CMakeLists.txt
@@ -0,0 +1,9 @@
+file(GLOB globbed *.c *.cpp)
+add_llvm_library(MLIRStandardOps
+  ${globbed}
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/StandardOps
+  )
+add_dependencies(MLIRStandardOps MLIRStandardOpsIncGen LLVMSupport)
+target_link_libraries(MLIRStandardOps LLVMSupport)
diff --git a/third_party/mlir/lib/StandardOps/DialectRegistration.cpp b/third_party/mlir/lib/StandardOps/DialectRegistration.cpp
new file mode 100644
index 0000000..1f71a3d
--- /dev/null
+++ b/third_party/mlir/lib/StandardOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register standard Op dialect -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/StandardOps/Ops.h"
+using namespace mlir;
+
+// Static initialization for standard op dialect registration.
+static DialectRegistration<StandardOpsDialect> StandardOps;
diff --git a/third_party/mlir/lib/StandardOps/Ops.cpp b/third_party/mlir/lib/StandardOps/Ops.cpp
new file mode 100644
index 0000000..9ecd99a
--- /dev/null
+++ b/third_party/mlir/lib/StandardOps/Ops.cpp
@@ -0,0 +1,2119 @@
+//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/StandardOps/Ops.h"
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// StandardOpsDialect
+//===----------------------------------------------------------------------===//
+
+/// A custom binary operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
+  assert(op->getNumOperands() == 2 && "binary op should have two operands");
+  assert(op->getNumResults() == 1 && "binary op should have one result");
+
+  // If not all the operand and result types are the same, just use the
+  // generic assembly form to avoid omitting information in printing.
+  auto resultType = op->getResult(0)->getType();
+  if (op->getOperand(0)->getType() != resultType ||
+      op->getOperand(1)->getType() != resultType) {
+    p->printGenericOp(op);
+    return;
+  }
+
+  *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+     << *op->getOperand(0) << ", " << *op->getOperand(1);
+  p->printOptionalAttrDict(op->getAttrs());
+
+  // Now we can output only one type for all operands and the result.
+  *p << " : " << op->getResult(0)->getType();
+}
+
+/// A custom cast operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
+  *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+     << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
+     << op->getResult(0)->getType();
+}
+
+/// A custom cast operation verifier.
+template <typename T> static LogicalResult verifyCastOp(T op) {
+  auto opType = op.getOperand()->getType();
+  auto resType = op.getType();
+  if (!T::areCastCompatible(opType, resType))
+    return op.emitError("operand type ") << opType << " and result type "
+                                         << resType << " are cast incompatible";
+
+  return success();
+}
+
+StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<DmaStartOp, DmaWaitOp,
+#define GET_OP_LIST
+#include "mlir/StandardOps/Ops.cpp.inc"
+                >();
+}
+
+void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
+                                 Operation::operand_iterator end,
+                                 unsigned numDims, OpAsmPrinter *p) {
+  *p << '(';
+  p->printOperands(begin, begin + numDims);
+  *p << ')';
+
+  if (begin + numDims != end) {
+    *p << '[';
+    p->printOperands(begin + numDims, end);
+    *p << ']';
+  }
+}
+
+// Parses dimension and symbol list, and sets 'numDims' to the number of
+// dimension operands parsed.
+// Returns 'false' on success and 'true' on error.
+ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser,
+                                        SmallVector<Value *, 4> &operands,
+                                        unsigned &numDims) {
+  SmallVector<OpAsmParser::OperandType, 8> opInfos;
+  if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
+    return failure();
+  // Store number of dimensions for validation by caller.
+  numDims = opInfos.size();
+
+  // Parse the optional symbol operands.
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  if (parser->parseOperandList(opInfos,
+                               OpAsmParser::Delimiter::OptionalSquare) ||
+      parser->resolveOperands(opInfos, affineIntTy, operands))
+    return failure();
+  return success();
+}
+
+/// Matches a ConstantIndexOp.
+/// TODO: This should probably just be a general matcher that uses m_Constant
+/// and checks the operation for an index type.
+static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
+  return detail::op_matcher<ConstantIndexOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a common class used for patterns of the form
+/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
+/// into the root operation directly.
+struct MemRefCastFolder : public RewritePattern {
+  /// The rootOpName is the name of the root operation to match against.
+  MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
+      : RewritePattern(rootOpName, 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    for (auto *operand : op->getOperands())
+      if (matchPattern(operand, m_Op<MemRefCastOp>()))
+        return matchSuccess();
+
+    return matchFailure();
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+      if (auto *memref = op->getOperand(i)->getDefiningOp())
+        if (auto cast = dyn_cast<MemRefCastOp>(memref))
+          op->setOperand(i, cast.getOperand());
+    rewriter.updatedRootInPlace(op);
+  }
+};
+
+/// Performs const folding `calculate` with element-wise behavior on the two
+/// attributes in `operands` and returns the result if possible.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT =
+              std::function<ElementValueT(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+                            const CalculationT &calculate) {
+  assert(operands.size() == 2 && "binary op takes two operands");
+
+  if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
+    auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
+    if (!rhs || lhs.getType() != rhs.getType())
+      return {};
+
+    return AttrElementT::get(lhs.getType(),
+                             calculate(lhs.getValue(), rhs.getValue()));
+  } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+    auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
+    if (!rhs || lhs.getType() != rhs.getType())
+      return {};
+
+    auto elementResult = constFoldBinaryOp<AttrElementT>(
+        {lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
+    if (!elementResult)
+      return {};
+
+    return DenseElementsAttr::get(lhs.getType(), elementResult);
+  }
+  return {};
+}
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
+  /// addi(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AllocOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, AllocOp op) {
+  *p << "alloc";
+
+  // Print dynamic dimension operands.
+  MemRefType type = op.getType();
+  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
+                        type.getNumDynamicDims(), p);
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
+  *p << " : " << type;
+}
+
+static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) {
+  MemRefType type;
+
+  // Parse the dimension operands and optional symbol operands, followed by a
+  // memref type.
+  unsigned numDimOperands;
+  if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return failure();
+
+  // Check numDynamicDims against number of question marks in memref type.
+  // Note: this check remains here (instead of in verify()), because the
+  // partition between dim operands and symbol operands is lost after parsing.
+  // Verification still checks that the total number of operands matches
+  // the number of symbols in the affine map, plus the number of dynamic
+  // dimensions in the memref.
+  if (numDimOperands != type.getNumDynamicDims())
+    return parser->emitError(parser->getNameLoc())
+           << "dimension operand count does not equal memref dynamic dimension "
+              "count";
+  result->types.push_back(type);
+  return success();
+}
+
+static LogicalResult verify(AllocOp op) {
+  auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>();
+  if (!memRefType)
+    return op.emitOpError("result must be a memref");
+
+  unsigned numSymbols = 0;
+  if (!memRefType.getAffineMaps().empty()) {
+    AffineMap affineMap = memRefType.getAffineMaps()[0];
+    // Store number of symbols used in affine map (used in subsequent check).
+    numSymbols = affineMap.getNumSymbols();
+  }
+  unsigned numDynamicDims = memRefType.getNumDynamicDims();
+  // Check that the total number of operands matches the number of symbols in
+  // the affine map, plus the number of dynamic dimensions specified in the
+  // memref type.
+  if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols)
+    return op.emitOpError(
+        "operand count does not equal dimension plus symbol operand count");
+
+  // Verify that all operands are of type Index.
+  for (auto operandType : op.getOperandTypes())
+    if (!operandType.isIndex())
+      return op.emitOpError("requires operands to be of type Index");
+  return success();
+}
+
+namespace {
+/// Fold constant dimensions into an alloc operation.
+struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AllocOp alloc,
+                                     PatternRewriter &rewriter) const override {
+    // Check to see if any dimensions operands are constants.  If so, we can
+    // substitute and drop them.
+    if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return matchFailure();
+
+    auto memrefType = alloc.getType();
+
+    // Ok, we have one or more constant operands.  Collect the non-constant ones
+    // and keep track of the resultant memref type to build.
+    SmallVector<int64_t, 4> newShapeConstants;
+    newShapeConstants.reserve(memrefType.getRank());
+    SmallVector<Value *, 4> newOperands;
+    SmallVector<Value *, 4> droppedOperands;
+
+    unsigned dynamicDimPos = 0;
+    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+      int64_t dimSize = memrefType.getDimSize(dim);
+      // If this is already static dimension, keep it.
+      if (dimSize != -1) {
+        newShapeConstants.push_back(dimSize);
+        continue;
+      }
+      auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
+      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+        // Dynamic shape dimension will be folded.
+        newShapeConstants.push_back(constantIndexOp.getValue());
+        // Record to check for zero uses later below.
+        droppedOperands.push_back(constantIndexOp);
+      } else {
+        // Dynamic shape dimension not folded; copy operand from old memref.
+        newShapeConstants.push_back(-1);
+        newOperands.push_back(alloc.getOperand(dynamicDimPos));
+      }
+      dynamicDimPos++;
+    }
+
+    // Create new memref type (which will have fewer dynamic dimensions).
+    auto newMemRefType = MemRefType::get(
+        newShapeConstants, memrefType.getElementType(),
+        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    assert(static_cast<int64_t>(newOperands.size()) ==
+           newMemRefType.getNumDynamicDims());
+
+    // Create and insert the alloc op for the new memref.
+    auto newAlloc =
+        rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
+    // Insert a cast so we have the same type as the old alloc.
+    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
+                                                    alloc.getType());
+
+    rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
+    return matchSuccess();
+  }
+};
+
+/// Fold alloc operations with no uses. Alloc has side effects on the heap,
+/// but can still be deleted if it has zero uses.
+struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AllocOp alloc,
+                                     PatternRewriter &rewriter) const override {
+    // Check if the alloc'ed value has any uses.
+    if (!alloc.use_empty())
+      return matchFailure();
+
+    // If it doesn't, we can eliminate it.
+    alloc.erase();
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace.
+
+void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// BranchOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) {
+  Block *dest;
+  SmallVector<Value *, 4> destOperands;
+  if (parser->parseSuccessorAndUseList(dest, destOperands))
+    return failure();
+  result->addSuccessor(dest, destOperands);
+  return success();
+}
+
+static void print(OpAsmPrinter *p, BranchOp op) {
+  *p << "br ";
+  p->printSuccessorAndUseList(op.getOperation(), 0);
+}
+
+Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
+
+void BranchOp::setDest(Block *block) {
+  return getOperation()->setSuccessor(block, 0);
+}
+
+void BranchOp::eraseOperand(unsigned index) {
+  getOperation()->eraseSuccessorOperand(0, index);
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
+  SymbolRefAttr calleeAttr;
+  FunctionType calleeType;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  auto calleeLoc = parser->getNameLoc();
+  if (parser->parseAttribute(calleeAttr, "callee", result->attributes) ||
+      parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(calleeType) ||
+      parser->addTypesToList(calleeType.getResults(), result->types) ||
+      parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
+                              result->operands))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter *p, CallOp op) {
+  *p << "call " << op.getAttr("callee") << '(';
+  p->printOperands(op.getOperands());
+  *p << ')';
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  *p << " : ";
+  p->printType(op.getCalleeType());
+}
+
+static LogicalResult verify(CallOp op) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
+  if (!fnAttr)
+    return op.emitOpError("requires a 'callee' symbol reference attribute");
+  auto fn =
+      op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
+  if (!fn)
+    return op.emitOpError() << "'" << fnAttr.getValue()
+                            << "' does not reference a valid function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getType();
+  if (fnType.getNumInputs() != op.getNumOperands())
+    return op.emitOpError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (op.getOperand(i)->getType() != fnType.getInput(i))
+      return op.emitOpError("operand type mismatch");
+
+  if (fnType.getNumResults() != op.getNumResults())
+    return op.emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (op.getResult(i)->getType() != fnType.getResult(i))
+      return op.emitOpError("result type mismatch");
+
+  return success();
+}
+
+FunctionType CallOp::getCalleeType() {
+  SmallVector<Type, 4> resultTypes(getResultTypes());
+  SmallVector<Type, 8> argTypes(getOperandTypes());
+  return FunctionType::get(argTypes, resultTypes, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// CallIndirectOp
+//===----------------------------------------------------------------------===//
+namespace {
+/// Fold indirect calls that have a constant function as the callee operand.
+struct SimplifyIndirectCallWithKnownCallee
+    : public OpRewritePattern<CallIndirectOp> {
+  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
+                                     PatternRewriter &rewriter) const override {
+    // Check that the callee is a constant callee.
+    SymbolRefAttr calledFn;
+    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
+      return matchFailure();
+
+    // Replace with a direct call.
+    SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
+    SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
+    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
+                                        callResults, callOperands);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace.
+
+static ParseResult parseCallIndirectOp(OpAsmParser *parser,
+                                       OperationState *result) {
+  FunctionType calleeType;
+  OpAsmParser::OperandType callee;
+  llvm::SMLoc operandsLoc;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  return failure(
+      parser->parseOperand(callee) ||
+      parser->getCurrentLocation(&operandsLoc) ||
+      parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(calleeType) ||
+      parser->resolveOperand(callee, calleeType, result->operands) ||
+      parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
+                              result->operands) ||
+      parser->addTypesToList(calleeType.getResults(), result->types));
+}
+
+static void print(OpAsmPrinter *p, CallIndirectOp op) {
+  *p << "call_indirect ";
+  p->printOperand(op.getCallee());
+  *p << '(';
+  auto operandRange = op.getOperands();
+  p->printOperands(++operandRange.begin(), operandRange.end());
+  *p << ')';
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  *p << " : " << op.getCallee()->getType();
+}
+
+static LogicalResult verify(CallIndirectOp op) {
+  // The callee must be a function.
+  auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
+  if (!fnType)
+    return op.emitOpError("callee must have function type");
+
+  // Verify that the operand and result types match the callee.
+  if (fnType.getNumInputs() != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
+      return op.emitOpError("operand type mismatch");
+
+  if (fnType.getNumResults() != op.getNumResults())
+    return op.emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (op.getResult(i)->getType() != fnType.getResult(i))
+      return op.emitOpError("result type mismatch");
+
+  return success();
+}
+
+void CallIndirectOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyIndirectCallWithKnownCallee>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// General helpers for comparison ops
+//===----------------------------------------------------------------------===//
+
+// Return the type of the same shape (scalar, vector or tensor) containing i1.
+static Type getCheckedI1SameShape(Builder *build, Type type) {
+  auto i1Type = build->getI1Type();
+  if (type.isIntOrIndexOrFloat())
+    return i1Type;
+  if (auto tensorType = type.dyn_cast<RankedTensorType>())
+    return build->getTensorType(tensorType.getShape(), i1Type);
+  if (type.isa<UnrankedTensorType>())
+    return build->getTensorType(i1Type);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return build->getVectorType(vectorType.getShape(), i1Type);
+  return Type();
+}
+
+static Type getI1SameShape(Builder *build, Type type) {
+  Type res = getCheckedI1SameShape(build, type);
+  assert(res && "expected type with valid i1 shape");
+  return res;
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
+static inline const char *const *getCmpIPredicateNames() {
+  static const char *predicateNames[]{
+      /*EQ*/ "eq",
+      /*NE*/ "ne",
+      /*SLT*/ "slt",
+      /*SLE*/ "sle",
+      /*SGT*/ "sgt",
+      /*SGE*/ "sge",
+      /*ULT*/ "ult",
+      /*ULE*/ "ule",
+      /*UGT*/ "ugt",
+      /*UGE*/ "uge",
+  };
+  static_assert(std::extent<decltype(predicateNames)>::value ==
+                    (size_t)CmpIPredicate::NumPredicates,
+                "wrong number of predicate names");
+  return predicateNames;
+}
+
+// Returns a value of the predicate corresponding to the given mnemonic.
+// Returns NumPredicates (one-past-end) if there is no such mnemonic.
+CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
+  return llvm::StringSwitch<CmpIPredicate>(name)
+      .Case("eq", CmpIPredicate::EQ)
+      .Case("ne", CmpIPredicate::NE)
+      .Case("slt", CmpIPredicate::SLT)
+      .Case("sle", CmpIPredicate::SLE)
+      .Case("sgt", CmpIPredicate::SGT)
+      .Case("sge", CmpIPredicate::SGE)
+      .Case("ult", CmpIPredicate::ULT)
+      .Case("ule", CmpIPredicate::ULE)
+      .Case("ugt", CmpIPredicate::UGT)
+      .Case("uge", CmpIPredicate::UGE)
+      .Default(CmpIPredicate::NumPredicates);
+}
+
+static void buildCmpIOp(Builder *build, OperationState *result,
+                        CmpIPredicate predicate, Value *lhs, Value *rhs) {
+  result->addOperands({lhs, rhs});
+  result->types.push_back(getI1SameShape(build, lhs->getType()));
+  result->addAttribute(
+      CmpIOp::getPredicateAttrName(),
+      build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
+static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  SmallVector<NamedAttribute, 4> attrs;
+  Attribute predicateNameAttr;
+  Type type;
+  if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
+                             attrs) ||
+      parser->parseComma() || parser->parseOperandList(ops, 2) ||
+      parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperands(ops, type, result->operands))
+    return failure();
+
+  if (!predicateNameAttr.isa<StringAttr>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected string comparison predicate attribute");
+
+  // Rewrite string attribute to an enum value.
+  StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
+  auto predicate = CmpIOp::getPredicateByName(predicateName);
+  if (predicate == CmpIPredicate::NumPredicates)
+    return parser->emitError(parser->getNameLoc())
+           << "unknown comparison predicate \"" << predicateName << "\"";
+
+  auto builder = parser->getBuilder();
+  Type i1Type = getCheckedI1SameShape(&builder, type);
+  if (!i1Type)
+    return parser->emitError(parser->getNameLoc(),
+                             "expected type with valid i1 shape");
+
+  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+  result->attributes = attrs;
+
+  result->addTypes({i1Type});
+  return success();
+}
+
+static void print(OpAsmPrinter *p, CmpIOp op) {
+  *p << "cmpi ";
+
+  auto predicateValue =
+      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
+  assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
+         predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
+         "unknown predicate index");
+  Builder b(op.getContext());
+  auto predicateStringAttr =
+      b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
+  p->printAttribute(predicateStringAttr);
+
+  *p << ", ";
+  p->printOperand(op.lhs());
+  *p << ", ";
+  p->printOperand(op.rhs());
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
+  *p << " : " << op.lhs()->getType();
+}
+
+static LogicalResult verify(CmpIOp op) {
+  auto predicateAttr =
+      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
+  if (!predicateAttr)
+    return op.emitOpError("requires an integer attribute named 'predicate'");
+  auto predicate = predicateAttr.getInt();
+  if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
+      predicate >= (int64_t)CmpIPredicate::NumPredicates)
+    return op.emitOpError("'predicate' attribute value out of range");
+
+  return success();
+}
+
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
+// comparison predicates.
+static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
+                              const APInt &rhs) {
+  switch (predicate) {
+  case CmpIPredicate::EQ:
+    return lhs.eq(rhs);
+  case CmpIPredicate::NE:
+    return lhs.ne(rhs);
+  case CmpIPredicate::SLT:
+    return lhs.slt(rhs);
+  case CmpIPredicate::SLE:
+    return lhs.sle(rhs);
+  case CmpIPredicate::SGT:
+    return lhs.sgt(rhs);
+  case CmpIPredicate::SGE:
+    return lhs.sge(rhs);
+  case CmpIPredicate::ULT:
+    return lhs.ult(rhs);
+  case CmpIPredicate::ULE:
+    return lhs.ule(rhs);
+  case CmpIPredicate::UGT:
+    return lhs.ugt(rhs);
+  case CmpIPredicate::UGE:
+    return lhs.uge(rhs);
+  default:
+    llvm_unreachable("unknown comparison predicate");
+  }
+}
+
+// Constant folding hook for comparisons.
+OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpi takes two arguments");
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+// Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
+static inline const char *const *getCmpFPredicateNames() {
+  static const char *predicateNames[] = {
+      /*AlwaysFalse*/ "false",
+      /*OEQ*/ "oeq",
+      /*OGT*/ "ogt",
+      /*OGE*/ "oge",
+      /*OLT*/ "olt",
+      /*OLE*/ "ole",
+      /*ONE*/ "one",
+      /*ORD*/ "ord",
+      /*UEQ*/ "ueq",
+      /*UGT*/ "ugt",
+      /*UGE*/ "uge",
+      /*ULT*/ "ult",
+      /*ULE*/ "ule",
+      /*UNE*/ "une",
+      /*UNO*/ "uno",
+      /*AlwaysTrue*/ "true",
+  };
+  static_assert(std::extent<decltype(predicateNames)>::value ==
+                    (size_t)CmpFPredicate::NumPredicates,
+                "wrong number of predicate names");
+  return predicateNames;
+}
+
+// Returns a value of the predicate corresponding to the given mnemonic.
+// Returns NumPredicates (one-past-end) if there is no such mnemonic.
+CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
+  return llvm::StringSwitch<CmpFPredicate>(name)
+      .Case("false", CmpFPredicate::AlwaysFalse)
+      .Case("oeq", CmpFPredicate::OEQ)
+      .Case("ogt", CmpFPredicate::OGT)
+      .Case("oge", CmpFPredicate::OGE)
+      .Case("olt", CmpFPredicate::OLT)
+      .Case("ole", CmpFPredicate::OLE)
+      .Case("one", CmpFPredicate::ONE)
+      .Case("ord", CmpFPredicate::ORD)
+      .Case("ueq", CmpFPredicate::UEQ)
+      .Case("ugt", CmpFPredicate::UGT)
+      .Case("uge", CmpFPredicate::UGE)
+      .Case("ult", CmpFPredicate::ULT)
+      .Case("ule", CmpFPredicate::ULE)
+      .Case("une", CmpFPredicate::UNE)
+      .Case("uno", CmpFPredicate::UNO)
+      .Case("true", CmpFPredicate::AlwaysTrue)
+      .Default(CmpFPredicate::NumPredicates);
+}
+
+static void buildCmpFOp(Builder *build, OperationState *result,
+                        CmpFPredicate predicate, Value *lhs, Value *rhs) {
+  result->addOperands({lhs, rhs});
+  result->types.push_back(getI1SameShape(build, lhs->getType()));
+  result->addAttribute(
+      CmpFOp::getPredicateAttrName(),
+      build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
+static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  SmallVector<NamedAttribute, 4> attrs;
+  Attribute predicateNameAttr;
+  Type type;
+  if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
+                             attrs) ||
+      parser->parseComma() || parser->parseOperandList(ops, 2) ||
+      parser->parseOptionalAttributeDict(attrs) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperands(ops, type, result->operands))
+    return failure();
+
+  if (!predicateNameAttr.isa<StringAttr>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected string comparison predicate attribute");
+
+  // Rewrite string attribute to an enum value.
+  StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
+  auto predicate = CmpFOp::getPredicateByName(predicateName);
+  if (predicate == CmpFPredicate::NumPredicates)
+    return parser->emitError(parser->getNameLoc(),
+                             "unknown comparison predicate \"" + predicateName +
+                                 "\"");
+
+  auto builder = parser->getBuilder();
+  Type i1Type = getCheckedI1SameShape(&builder, type);
+  if (!i1Type)
+    return parser->emitError(parser->getNameLoc(),
+                             "expected type with valid i1 shape");
+
+  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+  result->attributes = attrs;
+
+  result->addTypes({i1Type});
+  return success();
+}
+
+static void print(OpAsmPrinter *p, CmpFOp op) {
+  *p << "cmpf ";
+
+  auto predicateValue =
+      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
+  assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
+         predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
+         "unknown predicate index");
+  Builder b(op.getContext());
+  auto predicateStringAttr =
+      b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
+  p->printAttribute(predicateStringAttr);
+
+  *p << ", ";
+  p->printOperand(op.lhs());
+  *p << ", ";
+  p->printOperand(op.rhs());
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
+  *p << " : " << op.lhs()->getType();
+}
+
+static LogicalResult verify(CmpFOp op) {
+  auto predicateAttr =
+      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
+  if (!predicateAttr)
+    return op.emitOpError("requires an integer attribute named 'predicate'");
+  auto predicate = predicateAttr.getInt();
+  if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
+      predicate >= (int64_t)CmpFPredicate::NumPredicates)
+    return op.emitOpError("'predicate' attribute value out of range");
+
+  return success();
+}
+
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+// comparison predicates.
+static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
+                              const APFloat &rhs) {
+  auto cmpResult = lhs.compare(rhs);
+  switch (predicate) {
+  case CmpFPredicate::AlwaysFalse:
+    return false;
+  case CmpFPredicate::OEQ:
+    return cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::OGT:
+    return cmpResult == APFloat::cmpGreaterThan;
+  case CmpFPredicate::OGE:
+    return cmpResult == APFloat::cmpGreaterThan ||
+           cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::OLT:
+    return cmpResult == APFloat::cmpLessThan;
+  case CmpFPredicate::OLE:
+    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::ONE:
+    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
+  case CmpFPredicate::ORD:
+    return cmpResult != APFloat::cmpUnordered;
+  case CmpFPredicate::UEQ:
+    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::UGT:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpGreaterThan;
+  case CmpFPredicate::UGE:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpGreaterThan ||
+           cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::ULT:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpLessThan;
+  case CmpFPredicate::ULE:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+  case CmpFPredicate::UNE:
+    return cmpResult != APFloat::cmpEqual;
+  case CmpFPredicate::UNO:
+    return cmpResult == APFloat::cmpUnordered;
+  case CmpFPredicate::AlwaysTrue:
+    return true;
+  default:
+    llvm_unreachable("unknown comparison predicate");
+  }
+}
+
+// Constant folding hook for comparisons.
+OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpf takes two arguments");
+
+  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+  if (!lhs || !rhs ||
+      // TODO(b/122019992) Implement and test constant folding for nan/inf when
+      // it is possible to have constant nan/inf
+      !lhs.getValue().isFinite() || !rhs.getValue().isFinite())
+    return {};
+
+  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// cond_br true, ^bb1, ^bb2 -> br ^bb1
+/// cond_br false, ^bb1, ^bb2 -> br ^bb2
+///
+struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(CondBranchOp condbr,
+                                     PatternRewriter &rewriter) const override {
+    // Check that the condition is a constant.
+    if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
+      return matchFailure();
+
+    Block *foldedDest;
+    SmallVector<Value *, 4> branchArgs;
+
+    // If the condition is known to evaluate to false we fold to a branch to the
+    // false destination. Otherwise, we fold to a branch to the true
+    // destination.
+    if (matchPattern(condbr.getCondition(), m_Zero())) {
+      foldedDest = condbr.getFalseDest();
+      branchArgs.assign(condbr.false_operand_begin(),
+                        condbr.false_operand_end());
+    } else {
+      foldedDest = condbr.getTrueDest();
+      branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
+    }
+
+    rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace.
+
+static ParseResult parseCondBranchOp(OpAsmParser *parser,
+                                     OperationState *result) {
+  SmallVector<Value *, 4> destOperands;
+  Block *dest;
+  OpAsmParser::OperandType condInfo;
+
+  // Parse the condition.
+  Type int1Ty = parser->getBuilder().getI1Type();
+  if (parser->parseOperand(condInfo) || parser->parseComma() ||
+      parser->resolveOperand(condInfo, int1Ty, result->operands)) {
+    return parser->emitError(parser->getNameLoc(),
+                             "expected condition type was boolean (i1)");
+  }
+
+  // Parse the true successor.
+  if (parser->parseSuccessorAndUseList(dest, destOperands))
+    return failure();
+  result->addSuccessor(dest, destOperands);
+
+  // Parse the false successor.
+  destOperands.clear();
+  if (parser->parseComma() ||
+      parser->parseSuccessorAndUseList(dest, destOperands))
+    return failure();
+  result->addSuccessor(dest, destOperands);
+
+  return success();
+}
+
+static void print(OpAsmPrinter *p, CondBranchOp op) {
+  *p << "cond_br ";
+  p->printOperand(op.getCondition());
+  *p << ", ";
+  p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
+  *p << ", ";
+  p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
+}
+
+void CondBranchOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyConstCondBranchPred>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, ConstantOp &op) {
+  *p << "constant ";
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
+
+  if (op.getAttrs().size() > 1)
+    *p << ' ';
+  p->printAttribute(op.getValue());
+
+  // If the value is a symbol reference, print a trailing type.
+  if (op.getValue().isa<SymbolRefAttr>()) {
+    *p << " : ";
+    p->printType(op.getType());
+  }
+}
+
+static ParseResult parseConstantOp(OpAsmParser *parser,
+                                   OperationState *result) {
+  Attribute valueAttr;
+  if (parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseAttribute(valueAttr, "value", result->attributes))
+    return failure();
+
+  // If the attribute is a symbol reference, then we expect a trailing type.
+  Type type;
+  if (!valueAttr.isa<SymbolRefAttr>())
+    type = valueAttr.getType();
+  else if (parser->parseColonType(type))
+    return failure();
+
+  // Add the attribute type to the list.
+  return parser->addTypeToList(type, result->types);
+}
+
+/// The constant op requires an attribute, and furthermore requires that it
+/// matches the return type.
+static LogicalResult verify(ConstantOp &op) {
+  auto value = op.getValue();
+  if (!value)
+    return op.emitOpError("requires a 'value' attribute");
+
+  auto type = op.getType();
+  if (!value.getType().isa<NoneType>() && type != value.getType())
+    return op.emitOpError() << "requires attribute's type (" << value.getType()
+                            << ") to match op's return type (" << type << ")";
+
+  if (type.isa<IndexType>() || value.isa<BoolAttr>())
+    return success();
+
+  if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
+    // If the type has a known bitwidth we verify that the value can be
+    // represented with the given bitwidth.
+    auto bitwidth = type.cast<IntegerType>().getWidth();
+    auto intVal = intAttr.getValue();
+    if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
+      return op.emitOpError("requires 'value' to be an integer within the "
+                            "range of the integer result type");
+    return success();
+  }
+
+  if (type.isa<FloatType>()) {
+    if (!value.isa<FloatAttr>())
+      return op.emitOpError("requires 'value' to be a floating point constant");
+    return success();
+  }
+
+  if (type.isa<ShapedType>()) {
+    if (!value.isa<ElementsAttr>())
+      return op.emitOpError("requires 'value' to be a shaped constant");
+    return success();
+  }
+
+  if (type.isa<FunctionType>()) {
+    auto fnAttr = value.dyn_cast<SymbolRefAttr>();
+    if (!fnAttr)
+      return op.emitOpError("requires 'value' to be a function reference");
+
+    // Try to find the referenced function.
+    auto fn =
+        op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
+    if (!fn)
+      return op.emitOpError("reference to undefined function 'bar'");
+
+    // Check that the referenced function has the correct type.
+    if (fn.getType() != type)
+      return op.emitOpError("reference to function with mismatched type");
+
+    return success();
+  }
+
+  if (type.isa<NoneType>() && value.isa<UnitAttr>())
+    return success();
+
+  return op.emitOpError("unsupported 'value' attribute: ") << value;
+}
+
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return getValue();
+}
+
+/// Returns true if a constant operation can be built with the given value and
+/// result type.
+bool ConstantOp::isBuildableWith(Attribute value, Type type) {
+  // SymbolRefAttr can only be used with a function type.
+  if (value.isa<SymbolRefAttr>())
+    return type.isa<FunctionType>();
+  // Otherwise, the attribute must have the same type as 'type'.
+  if (value.getType() != type)
+    return false;
+  // Finally, check that the attribute kind is handled.
+  return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
+         value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
+         value.isa<UnitAttr>();
+}
+
+void ConstantFloatOp::build(Builder *builder, OperationState *result,
+                            const APFloat &value, FloatType type) {
+  ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
+}
+
+bool ConstantFloatOp::classof(Operation *op) {
+  return ConstantOp::classof(op) &&
+         op->getResult(0)->getType().isa<FloatType>();
+}
+
+/// ConstantIntOp only matches values whose result type is an IntegerType.
+bool ConstantIntOp::classof(Operation *op) {
+  return ConstantOp::classof(op) &&
+         op->getResult(0)->getType().isa<IntegerType>();
+}
+
+void ConstantIntOp::build(Builder *builder, OperationState *result,
+                          int64_t value, unsigned width) {
+  Type type = builder->getIntegerType(width);
+  ConstantOp::build(builder, result, type,
+                    builder->getIntegerAttr(type, value));
+}
+
+/// Build a constant int op producing an integer with the specified type,
+/// which must be an integer type.
+void ConstantIntOp::build(Builder *builder, OperationState *result,
+                          int64_t value, Type type) {
+  assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
+  ConstantOp::build(builder, result, type,
+                    builder->getIntegerAttr(type, value));
+}
+
+/// ConstantIndexOp only matches values whose result type is Index.
+bool ConstantIndexOp::classof(Operation *op) {
+  return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex();
+}
+
+void ConstantIndexOp::build(Builder *builder, OperationState *result,
+                            int64_t value) {
+  Type type = builder->getIndexType();
+  ConstantOp::build(builder, result, type,
+                    builder->getIntegerAttr(type, value));
+}
+
+//===----------------------------------------------------------------------===//
+// DeallocOp
+//===----------------------------------------------------------------------===//
+namespace {
+/// Fold Dealloc operations that are deallocating an AllocOp that is only used
+/// by other Dealloc operations.
+struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(DeallocOp dealloc,
+                                     PatternRewriter &rewriter) const override {
+    // Check that the memref operand's defining operation is an AllocOp.
+    Value *memref = dealloc.memref();
+    if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
+      return matchFailure();
+
+    // Check that all of the uses of the AllocOp are other DeallocOps.
+    for (auto *user : memref->getUsers())
+      if (!isa<DeallocOp>(user))
+        return matchFailure();
+
+    // Erase the dealloc operation.
+    rewriter.replaceOp(dealloc, llvm::None);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace.
+
+static void print(OpAsmPrinter *p, DeallocOp op) {
+  *p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
+}
+
+static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType memrefInfo;
+  MemRefType type;
+
+  return failure(parser->parseOperand(memrefInfo) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(memrefInfo, type, result->operands));
+}
+
+static LogicalResult verify(DeallocOp op) {
+  if (!op.memref()->getType().isa<MemRefType>())
+    return op.emitOpError("operand must be a memref");
+  return success();
+}
+
+void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  /// dealloc(memrefcast) -> dealloc
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+  results.insert<SimplifyDeadDealloc>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, DimOp op) {
+  *p << "dim " << *op.getOperand() << ", " << op.getIndex();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+  *p << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType operandInfo;
+  IntegerAttr indexAttr;
+  Type type;
+  Type indexType = parser->getBuilder().getIndexType();
+
+  return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
+                 parser->parseAttribute(indexAttr, indexType, "index",
+                                        result->attributes) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(operandInfo, type, result->operands) ||
+                 parser->addTypeToList(indexType, result->types));
+}
+
+static LogicalResult verify(DimOp op) {
+  // Check that we have an integer index operand.
+  auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
+  if (!indexAttr)
+    return op.emitOpError("requires an integer attribute named 'index'");
+  int64_t index = indexAttr.getValue().getSExtValue();
+
+  auto type = op.getOperand()->getType();
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    if (index >= tensorType.getRank())
+      return op.emitOpError("index is out of range");
+  } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+    if (index >= memrefType.getRank())
+      return op.emitOpError("index is out of range");
+
+  } else if (type.isa<UnrankedTensorType>()) {
+    // ok, assumed to be in-range.
+  } else {
+    return op.emitOpError("requires an operand with tensor or memref type");
+  }
+
+  return success();
+}
+
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold dim when the size along the index referred to is a constant.
+  auto opType = getOperand()->getType();
+  int64_t indexSize = -1;
+  if (auto tensorType = opType.dyn_cast<RankedTensorType>())
+    indexSize = tensorType.getShape()[getIndex()];
+  else if (auto memrefType = opType.dyn_cast<MemRefType>())
+    indexSize = memrefType.getShape()[getIndex()];
+
+  if (indexSize >= 0)
+    return IntegerAttr::get(IndexType::get(getContext()), indexSize);
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// DivISOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  // Don't fold if it requires division by zero.
+  if (rhs.getValue().isNullValue())
+    return {};
+
+  // Don't fold if it would overflow.
+  bool overflow;
+  auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow);
+  return overflow ? IntegerAttr{} : IntegerAttr::get(lhs.getType(), result);
+}
+
+//===----------------------------------------------------------------------===//
+// DivIUOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  // Don't fold if it requires division by zero.
+  if (rhs.getValue().isNullValue()) {
+    return {};
+  }
+
+  return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhs.getValue()));
+}
+
+// ---------------------------------------------------------------------------
+// DmaStartOp
+// ---------------------------------------------------------------------------
+
+void DmaStartOp::build(Builder *builder, OperationState *result,
+                       Value *srcMemRef, ArrayRef<Value *> srcIndices,
+                       Value *destMemRef, ArrayRef<Value *> destIndices,
+                       Value *numElements, Value *tagMemRef,
+                       ArrayRef<Value *> tagIndices, Value *stride,
+                       Value *elementsPerStride) {
+  result->addOperands(srcMemRef);
+  result->addOperands(srcIndices);
+  result->addOperands(destMemRef);
+  result->addOperands(destIndices);
+  result->addOperands(numElements);
+  result->addOperands(tagMemRef);
+  result->addOperands(tagIndices);
+  if (stride) {
+    result->addOperands(stride);
+    result->addOperands(elementsPerStride);
+  }
+}
+
+void DmaStartOp::print(OpAsmPrinter *p) {
+  *p << "dma_start " << *getSrcMemRef() << '[';
+  p->printOperands(getSrcIndices());
+  *p << "], " << *getDstMemRef() << '[';
+  p->printOperands(getDstIndices());
+  *p << "], " << *getNumElements();
+  *p << ", " << *getTagMemRef() << '[';
+  p->printOperands(getTagIndices());
+  *p << ']';
+  if (isStrided()) {
+    *p << ", " << *getStride();
+    *p << ", " << *getNumElementsPerStride();
+  }
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getSrcMemRef()->getType();
+  *p << ", " << getDstMemRef()->getType();
+  *p << ", " << getTagMemRef()->getType();
+}
+
+// Parse DmaStartOp.
+// Ex:
+//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
+//                       %tag[%index], %stride, %num_elt_per_stride :
+//                     : memref<3076 x f32, 0>,
+//                       memref<1024 x f32, 2>,
+//                       memref<1 x i32>
+//
+ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType srcMemRefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
+  OpAsmParser::OperandType dstMemRefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
+  OpAsmParser::OperandType numElementsInfo;
+  OpAsmParser::OperandType tagMemrefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
+  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
+
+  SmallVector<Type, 3> types;
+  auto indexType = parser->getBuilder().getIndexType();
+
+  // Parse and resolve the following list of operands:
+  // *) source memref followed by its indices (in square brackets).
+  // *) destination memref followed by its indices (in square brackets).
+  // *) dma size in KiB.
+  if (parser->parseOperand(srcMemRefInfo) ||
+      parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
+      parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
+      parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
+      parser->parseComma() || parser->parseOperand(numElementsInfo) ||
+      parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
+      parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
+    return failure();
+
+  // Parse optional stride and elements per stride.
+  if (parser->parseTrailingOperandList(strideInfo)) {
+    return failure();
+  }
+  if (!strideInfo.empty() && strideInfo.size() != 2) {
+    return parser->emitError(parser->getNameLoc(),
+                             "expected two stride related operands");
+  }
+  bool isStrided = strideInfo.size() == 2;
+
+  if (parser->parseColonTypeList(types))
+    return failure();
+
+  if (types.size() != 3)
+    return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
+
+  if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
+      parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
+      parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
+      parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
+      // size should be an index.
+      parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
+      parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
+      // tag indices should be index.
+      parser->resolveOperands(tagIndexInfos, indexType, result->operands))
+    return failure();
+
+  if (!types[0].isa<MemRefType>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected source to be of memref type");
+
+  if (!types[1].isa<MemRefType>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected destination to be of memref type");
+
+  if (!types[2].isa<MemRefType>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected tag to be of memref type");
+
+  if (isStrided) {
+    if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
+        parser->resolveOperand(strideInfo[1], indexType, result->operands))
+      return failure();
+  }
+
+  // Check that source/destination index list size matches associated rank.
+  if (static_cast<int64_t>(srcIndexInfos.size()) !=
+          types[0].cast<MemRefType>().getRank() ||
+      static_cast<int64_t>(dstIndexInfos.size()) !=
+          types[1].cast<MemRefType>().getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "memref rank not equal to indices count");
+
+  if (static_cast<int64_t>(tagIndexInfos.size()) !=
+      types[2].cast<MemRefType>().getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "tag memref rank not equal to indices count");
+
+  return success();
+}
+
+LogicalResult DmaStartOp::verify() {
+  // DMAs from different memory spaces supported.
+  if (getSrcMemorySpace() == getDstMemorySpace()) {
+    return emitOpError("DMA should be between different memory spaces");
+  }
+
+  if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+                              getDstMemRefRank() + 3 + 1 &&
+      getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+                              getDstMemRefRank() + 3 + 1 + 2) {
+    return emitOpError("incorrect number of operands");
+  }
+  return success();
+}
+
+void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  /// dma_start(memrefcast) -> dma_start
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+// ---------------------------------------------------------------------------
+// DmaWaitOp
+// ---------------------------------------------------------------------------
+
+void DmaWaitOp::build(Builder *builder, OperationState *result,
+                      Value *tagMemRef, ArrayRef<Value *> tagIndices,
+                      Value *numElements) {
+  result->addOperands(tagMemRef);
+  result->addOperands(tagIndices);
+  result->addOperands(numElements);
+}
+
+void DmaWaitOp::print(OpAsmPrinter *p) {
+  *p << "dma_wait ";
+  // Print operands.
+  p->printOperand(getTagMemRef());
+  *p << '[';
+  p->printOperands(getTagIndices());
+  *p << "], ";
+  p->printOperand(getNumElements());
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getTagMemRef()->getType();
+}
+
+// Parse DmaWaitOp.
+// Eg:
+//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
+//
+ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType tagMemrefInfo;
+  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
+  Type type;
+  auto indexType = parser->getBuilder().getIndexType();
+  OpAsmParser::OperandType numElementsInfo;
+
+  // Parse tag memref, its indices, and dma size.
+  if (parser->parseOperand(tagMemrefInfo) ||
+      parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
+      parser->parseComma() || parser->parseOperand(numElementsInfo) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
+      parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
+      parser->resolveOperand(numElementsInfo, indexType, result->operands))
+    return failure();
+
+  if (!type.isa<MemRefType>())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected tag to be of memref type");
+
+  if (static_cast<int64_t>(tagIndexInfos.size()) !=
+      type.cast<MemRefType>().getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "tag memref rank not equal to indices count");
+
+  return success();
+}
+
+void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  /// dma_wait(memrefcast) -> dma_wait
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, ExtractElementOp op) {
+  *p << "extract_element " << *op.getAggregate() << '[';
+  p->printOperands(op.getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getAggregate()->getType();
+}
+
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+                                         OperationState *result) {
+  OpAsmParser::OperandType aggregateInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  ShapedType type;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(aggregateInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(aggregateInfo, type, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+static LogicalResult verify(ExtractElementOp op) {
+  auto aggregateType = op.getAggregate()->getType().cast<ShapedType>();
+
+  // This should be possible with tablegen type constraints
+  if (op.getType() != aggregateType.getElementType())
+    return op.emitOpError("result type must match element type of aggregate");
+
+  // Verify the # indices match if we have a ranked type.
+  if (aggregateType.hasRank() &&
+      aggregateType.getRank() != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of indices for extract_element");
+
+  return success();
+}
+
+OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
+  assert(!operands.empty() && "extract_element takes atleast one operand");
+
+  // The aggregate operand must be a known constant.
+  Attribute aggregate = operands.front();
+  if (!aggregate)
+    return {};
+
+  // If this is a splat elements attribute, simply return the value. All of the
+  // elements of a splat attribute are the same.
+  if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
+    return splatAggregate.getSplatValue();
+
+  // Otherwise, collect the constant indices into the aggregate.
+  SmallVector<uint64_t, 8> indices;
+  for (Attribute indice : llvm::drop_begin(operands, 1)) {
+    if (!indice || !indice.isa<IntegerAttr>())
+      return {};
+    indices.push_back(indice.cast<IntegerAttr>().getInt());
+  }
+
+  // If this is an elements attribute, query the value at the given indices.
+  if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>())
+    return elementsAttr.getValue(indices);
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+// Index cast is applicable from index to integer and backwards.
+bool IndexCastOp::areCastCompatible(Type a, Type b) {
+  return (a.isIndex() && b.isa<IntegerType>()) ||
+         (a.isa<IntegerType>() && b.isIndex());
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, LoadOp op) {
+  *p << "load " << *op.getMemRef() << '[';
+  p->printOperands(op.getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getMemRefType();
+}
+
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType memrefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  MemRefType type;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(memrefInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(memrefInfo, type, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+static LogicalResult verify(LoadOp op) {
+  if (op.getType() != op.getMemRefType().getElementType())
+    return op.emitOpError("result type must match element type of memref");
+
+  if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of indices for load");
+
+  for (auto *idx : op.getIndices())
+    if (!idx->getType().isIndex())
+      return op.emitOpError("index to load must have 'index' type");
+
+  // TODO: Verify we have the right number of indices.
+
+  // TODO: in Function verify that the indices are parameters, IV's, or the
+  // result of an affine.apply.
+  return success();
+}
+
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  /// load(memrefcast) -> load
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// MemRefCastOp
+//===----------------------------------------------------------------------===//
+
+bool MemRefCastOp::areCastCompatible(Type a, Type b) {
+  auto aT = a.dyn_cast<MemRefType>();
+  auto bT = b.dyn_cast<MemRefType>();
+
+  if (!aT || !bT)
+    return false;
+  if (aT.getElementType() != bT.getElementType())
+    return false;
+  if (aT.getAffineMaps() != bT.getAffineMaps())
+    return false;
+  if (aT.getMemorySpace() != bT.getMemorySpace())
+    return false;
+
+  // They must have the same rank, and any specified dimensions must match.
+  if (aT.getRank() != bT.getRank())
+    return false;
+
+  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
+    int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
+    if (aDim != -1 && bDim != -1 && aDim != bDim)
+      return false;
+  }
+
+  return true;
+}
+
+OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// MulFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
+  /// muli(x, 0) -> 0
+  if (matchPattern(rhs(), m_Zero()))
+    return rhs();
+  /// muli(x, 1) -> x
+  if (matchPattern(rhs(), m_One()))
+    return getOperand(0);
+
+  // TODO: Handle the overflow case.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, RankOp op) {
+  *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType operandInfo;
+  Type type;
+  Type indexType = parser->getBuilder().getIndexType();
+
+  return failure(parser->parseOperand(operandInfo) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(operandInfo, type, result->operands) ||
+                 parser->addTypeToList(indexType, result->types));
+}
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold rank when the rank of the tensor is known.
+  auto type = getOperand()->getType();
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    int64_t rank = tensorType.getRank();
+    return IntegerAttr::get(IndexType::get(getContext()), rank);
+  }
+  return IntegerAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// RemISOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "remis takes two operands");
+
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return {};
+
+  // x % 1 = 0
+  if (rhs.getValue().isOneValue())
+    return IntegerAttr::get(rhs.getType(),
+                            APInt(rhs.getValue().getBitWidth(), 0));
+
+  // Don't fold if it requires division by zero.
+  if (rhs.getValue().isNullValue())
+    return {};
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return {};
+
+  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhs.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// RemIUOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "remiu takes two operands");
+
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return {};
+
+  // x % 1 = 0
+  if (rhs.getValue().isOneValue())
+    return IntegerAttr::get(rhs.getType(),
+                            APInt(rhs.getValue().getBitWidth(), 0));
+
+  // Don't fold if it requires division by zero.
+  if (rhs.getValue().isNullValue())
+    return {};
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return {};
+
+  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> opInfo;
+  SmallVector<Type, 2> types;
+  llvm::SMLoc loc = parser->getCurrentLocation();
+  return failure(parser->parseOperandList(opInfo) ||
+                 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
+                 parser->resolveOperands(opInfo, types, loc, result->operands));
+}
+
+static void print(OpAsmPrinter *p, ReturnOp op) {
+  *p << "return";
+  if (op.getNumOperands() > 0) {
+    *p << ' ';
+    p->printOperands(op.operand_begin(), op.operand_end());
+    *p << " : ";
+    interleave(
+        op.operand_begin(), op.operand_end(),
+        [&](Value *e) { p->printType(e->getType()); }, [&]() { *p << ", "; });
+  }
+}
+
+static LogicalResult verify(ReturnOp op) {
+  auto function = cast<FuncOp>(op.getParentOp());
+
+  // The operand number and types must match the function signature.
+  const auto &results = function.getType().getResults();
+  if (op.getNumOperands() != results.size())
+    return op.emitOpError("has ")
+           << op.getNumOperands()
+           << " operands, but enclosing function returns " << results.size();
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (op.getOperand(i)->getType() != results[i])
+      return op.emitError()
+             << "type of return operand " << i << " ("
+             << op.getOperand(i)->getType()
+             << ") doesn't match function result type (" << results[i] << ")";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SIToFPOp
+//===----------------------------------------------------------------------===//
+
+// sitofp is applicable from integer types to float types.
+bool SIToFPOp::areCastCompatible(Type a, Type b) {
+  return a.isa<IntegerType>() && b.isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<NamedAttribute, 4> attrs;
+  Type type;
+
+  if (parser->parseOperandList(ops, 3) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return failure();
+
+  auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type);
+  if (!i1Type)
+    return parser->emitError(parser->getNameLoc(),
+                             "expected type with valid i1 shape");
+
+  SmallVector<Type, 3> types = {i1Type, type, type};
+  return failure(parser->resolveOperands(ops, types, parser->getNameLoc(),
+                                         result->operands) ||
+                 parser->addTypeToList(type, result->types));
+}
+
+static void print(OpAsmPrinter *p, SelectOp op) {
+  *p << "select ";
+  p->printOperands(op.getOperands());
+  *p << " : " << op.getTrueValue()->getType();
+  p->printOptionalAttrDict(op.getAttrs());
+}
+
+static LogicalResult verify(SelectOp op) {
+  auto trueType = op.getTrueValue()->getType();
+  auto falseType = op.getFalseValue()->getType();
+
+  if (trueType != falseType)
+    return op.emitOpError(
+        "requires 'true' and 'false' arguments to be of the same type");
+
+  return success();
+}
+
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+  auto *condition = getCondition();
+
+  // select true, %0, %1 => %0
+  if (matchPattern(condition, m_One()))
+    return getTrueValue();
+
+  // select false, %0, %1 => %1
+  if (matchPattern(condition, m_Zero()))
+    return getFalseValue();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, StoreOp op) {
+  *p << "store " << *op.getValueToStore();
+  *p << ", " << *op.getMemRef() << '[';
+  p->printOperands(op.getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getMemRefType();
+}
+
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  MemRefType memrefType;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(memrefInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(memrefType) ||
+      parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
+                             result->operands) ||
+      parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands));
+}
+
+static LogicalResult verify(StoreOp op) {
+  // First operand must have same type as memref element type.
+  if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
+    return op.emitOpError(
+        "first operand must have same type memref element type");
+
+  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
+    return op.emitOpError("store index operand count not equal to memref rank");
+
+  for (auto *idx : op.getIndices())
+    if (!idx->getType().isIndex())
+      return op.emitOpError("index to load must have 'index' type");
+
+  // TODO: Verify we have the right number of indices.
+
+  // TODO: in Function verify that the indices are parameters, IV's, or the
+  // result of an affine.apply.
+  return success();
+}
+
+void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  /// store(memrefcast) -> store
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// SubFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
+  // subi(x,x) -> 0
+  if (getOperand(0) == getOperand(1))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+  /// and(x, 0) -> 0
+  if (matchPattern(rhs(), m_Zero()))
+    return rhs();
+  /// and(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+  /// or(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+  /// or(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// XOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+  /// xor(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+  /// xor(x,x) -> 0
+  if (lhs() == rhs())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a ^ b; });
+}
+
+//===----------------------------------------------------------------------===//
+// TensorCastOp
+//===----------------------------------------------------------------------===//
+
+bool TensorCastOp::areCastCompatible(Type a, Type b) {
+  auto aT = a.dyn_cast<TensorType>();
+  auto bT = b.dyn_cast<TensorType>();
+  if (!aT || !bT)
+    return false;
+
+  if (aT.getElementType() != bT.getElementType())
+    return false;
+
+  // If the either are unranked, then the cast is valid.
+  auto aRType = aT.dyn_cast<RankedTensorType>();
+  auto bRType = bT.dyn_cast<RankedTensorType>();
+  if (!aRType || !bRType)
+    return true;
+
+  // If they are both ranked, they have to have the same rank, and any specified
+  // dimensions must match.
+  if (aRType.getRank() != bRType.getRank())
+    return false;
+
+  for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) {
+    int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i);
+    if (aDim != -1 && bDim != -1 && aDim != bDim)
+      return false;
+  }
+
+  return true;
+}
+
+OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/StandardOps/Ops.cpp.inc"
diff --git a/third_party/mlir/lib/Support/CMakeLists.txt b/third_party/mlir/lib/Support/CMakeLists.txt
new file mode 100644
index 0000000..a927fc6
--- /dev/null
+++ b/third_party/mlir/lib/Support/CMakeLists.txt
@@ -0,0 +1,48 @@
+set(LLVM_OPTIONAL_SOURCES
+  FileUtilities.cpp
+  JitRunner.cpp
+  MlirOptMain.cpp
+  StorageUniquer.cpp
+  TranslateClParser.cpp
+)
+
+add_llvm_library(MLIRSupport
+  FileUtilities.cpp
+  StorageUniquer.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
+  )
+target_link_libraries(MLIRSupport LLVMSupport)
+
+add_llvm_library(MLIROptMain
+  MlirOptMain.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
+  )
+target_link_libraries(MLIROptMain LLVMSupport)
+
+add_llvm_library(MLIRTranslateClParser
+  TranslateClParser.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
+  )
+target_link_libraries(MLIRTranslateClParser LLVMSupport)
+
+add_llvm_library(MLIRJitRunner
+  JitRunner.cpp
+)
+target_link_libraries(MLIRJitRunner PRIVATE
+  MLIRExecutionEngine
+  MLIRIR
+  MLIRParser
+  MLIRStandardOps
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRStandardToLLVM
+  MLIRSupport
+  LLVMCore
+  LLVMSupport
+)
diff --git a/third_party/mlir/lib/Support/FileUtilities.cpp b/third_party/mlir/lib/Support/FileUtilities.cpp
new file mode 100644
index 0000000..fb9f5cf
--- /dev/null
+++ b/third_party/mlir/lib/Support/FileUtilities.cpp
@@ -0,0 +1,56 @@
+//===- FileUtilities.cpp - utilities for working with files ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Definitions of common utilities for working with files.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+std::unique_ptr<llvm::MemoryBuffer>
+mlir::openInputFile(StringRef inputFilename, std::string *errorMessage) {
+  auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code error = fileOrErr.getError()) {
+    if (errorMessage)
+      *errorMessage = "cannot open input file '" + inputFilename.str() +
+                      "': " + error.message();
+    return nullptr;
+  }
+
+  return std::move(*fileOrErr);
+}
+
+std::unique_ptr<llvm::ToolOutputFile>
+mlir::openOutputFile(StringRef outputFilename, std::string *errorMessage) {
+  std::error_code error;
+  auto result = llvm::make_unique<llvm::ToolOutputFile>(outputFilename, error,
+                                                        llvm::sys::fs::F_None);
+  if (error) {
+    if (errorMessage)
+      *errorMessage = "cannot open output file '" + outputFilename.str() +
+                      "': " + error.message();
+    return nullptr;
+  }
+
+  return result;
+}
diff --git a/third_party/mlir/lib/Support/JitRunner.cpp b/third_party/mlir/lib/Support/JitRunner.cpp
new file mode 100644
index 0000000..919f829
--- /dev/null
+++ b/third_party/mlir/lib/Support/JitRunner.cpp
@@ -0,0 +1,340 @@
+//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a library that provides a shared implementation for command line
+// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
+// IR before JIT-compiling and executing the latter.
+//
+// The translation can be customized by providing an MLIR to MLIR
+// transformation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/JitRunner.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassNameParser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <numeric>
+
+using namespace mlir;
+using llvm::Error;
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc("<input file>"),
+                                                llvm::cl::init("-"));
+static llvm::cl::opt<std::string>
+    initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
+              llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
+static llvm::cl::opt<std::string>
+    mainFuncName("e", llvm::cl::desc("The function to be called"),
+                 llvm::cl::value_desc("<function name>"),
+                 llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+    "entry-point-result",
+    llvm::cl::desc("Textual description of the function type to be called"),
+    llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+
+static llvm::cl::OptionCategory optFlags("opt-like flags");
+
+// CLI list of pass information
+static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
+    llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
+               llvm::cl::cat(optFlags));
+
+// CLI variables for -On options.
+static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
+                                 llvm::cl::cat(optFlags));
+
+static llvm::cl::OptionCategory clOptionsCategory("linking options");
+static llvm::cl::list<std::string>
+    clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
+                 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+                 llvm::cl::cat(clOptionsCategory));
+
+static OwningModuleRef parseMLIRInput(StringRef inputFilename,
+                                      MLIRContext *context) {
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+  return OwningModuleRef(parseSourceFile(sourceMgr, context));
+}
+
+// Initialize the relevant subsystems of LLVM.
+static void initializeLLVM() {
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+}
+
+static inline Error make_string_error(const llvm::Twine &message) {
+  return llvm::make_error<llvm::StringError>(message.str(),
+                                             llvm::inconvertibleErrorCode());
+}
+
+static void printOneMemRef(Type t, void *val) {
+  auto memRefType = t.cast<MemRefType>();
+  auto shape = memRefType.getShape();
+  int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int64_t>());
+  for (int64_t i = 0; i < size; ++i) {
+    llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
+  }
+  llvm::outs() << '\n';
+}
+
+static void printMemRefArguments(ArrayRef<Type> argTypes,
+                                 ArrayRef<Type> resTypes,
+                                 ArrayRef<void *> args) {
+  auto properArgs = args.take_front(argTypes.size());
+  for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
+    auto type = std::get<0>(kvp);
+    auto val = std::get<1>(kvp);
+    printOneMemRef(type, val);
+  }
+
+  auto results = args.drop_front(argTypes.size());
+  for (const auto &kvp : llvm::zip(resTypes, results)) {
+    auto type = std::get<0>(kvp);
+    auto val = std::get<1>(kvp);
+    printOneMemRef(type, val);
+  }
+}
+
+// Calls the passes necessary to convert affine and standard dialects to the
+// LLVM IR dialect.
+// Currently, these passes are:
+// - CSE
+// - canonicalization
+// - affine to standard lowering
+// - standard to llvm lowering
+static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
+  PassManager manager;
+  manager.addPass(mlir::createCanonicalizerPass());
+  manager.addPass(mlir::createCSEPass());
+  manager.addPass(mlir::createLowerAffinePass());
+  manager.addPass(mlir::createConvertToLLVMIRPass());
+  return manager.run(module);
+}
+
+static Error compileAndExecuteFunctionWithMemRefs(
+    ModuleOp module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+  if (!mainFunction || mainFunction.getBlocks().empty()) {
+    return make_string_error("entry point not found");
+  }
+
+  // Store argument and result types of the original function necessary to
+  // pretty print the results, because the function itself will be rewritten
+  // to use the LLVM dialect.
+  SmallVector<Type, 8> argTypes =
+      llvm::to_vector<8>(mainFunction.getType().getInputs());
+  SmallVector<Type, 8> resTypes =
+      llvm::to_vector<8>(mainFunction.getType().getResults());
+
+  float init = std::stof(initValue.getValue());
+
+  auto expectedArguments = allocateMemRefArguments(mainFunction, init);
+  if (!expectedArguments)
+    return expectedArguments.takeError();
+
+  if (failed(convertAffineStandardToLLVMIR(module)))
+    return make_string_error("conversion to the LLVM IR dialect failed");
+
+  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+  auto expectedEngine =
+      mlir::ExecutionEngine::create(module, transformer, libs);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+  (*fptr)(expectedArguments->data());
+  printMemRefArguments(argTypes, resTypes, *expectedArguments);
+  freeMemRefArguments(*expectedArguments);
+
+  return Error::success();
+}
+
+static Error compileAndExecuteSingleFloatReturnFunction(
+    ModuleOp module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+  if (!mainFunction || mainFunction.isExternal()) {
+    return make_string_error("entry point not found");
+  }
+
+  if (!mainFunction.getType().getInputs().empty())
+    return make_string_error("function inputs not supported");
+
+  if (mainFunction.getType().getResults().size() != 1)
+    return make_string_error("only single f32 function result supported");
+
+  auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+  if (!t)
+    return make_string_error("only single llvm.f32 function result supported");
+  auto *llvmTy = t.getUnderlyingType();
+  if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+    return make_string_error("only single llvm.f32 function result supported");
+
+  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+  auto expectedEngine =
+      mlir::ExecutionEngine::create(module, transformer, libs);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+
+  float res;
+  struct {
+    void *data;
+  } data;
+  data.data = &res;
+  (*fptr)((void **)&data);
+
+  // Intentional printing of the output so we can test.
+  llvm::outs() << res;
+
+  return Error::success();
+}
+
+// Entry point for all CPU runners. Expects the common argc/argv arguments for
+// standard C++ main functions and an mlirTransformer.
+// The latter is applied after parsing the input into MLIR IR and before passing
+// the MLIR module to the ExecutionEngine.
+int mlir::JitRunnerMain(
+    int argc, char **argv,
+    llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
+  llvm::PrettyStackTraceProgram x(argc, argv);
+  llvm::InitLLVM y(argc, argv);
+
+  initializeLLVM();
+  mlir::initializeLLVMPasses();
+
+  llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
+      optO0, optO1, optO2, optO3};
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
+
+  llvm::SmallVector<const llvm::PassInfo *, 4> passes;
+  llvm::Optional<unsigned> optLevel;
+  unsigned optCLIPosition = 0;
+  // Determine if there is an optimization flag present, and its CLI position
+  // (optCLIPosition).
+  for (unsigned j = 0; j < 4; ++j) {
+    auto &flag = optFlags[j].get();
+    if (flag) {
+      optLevel = j;
+      optCLIPosition = flag.getPosition();
+      break;
+    }
+  }
+  // Generate vector of pass information, plus the index at which we should
+  // insert any optimization passes in that vector (optPosition).
+  unsigned optPosition = 0;
+  for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
+    passes.push_back(llvmPasses[i]);
+    if (optCLIPosition < llvmPasses.getPosition(i)) {
+      optPosition = i;
+      optCLIPosition = UINT_MAX; // To ensure we never insert again
+    }
+  }
+
+  MLIRContext context;
+  auto m = parseMLIRInput(inputFilename, &context);
+  if (!m) {
+    llvm::errs() << "could not parse the input IR\n";
+    return 1;
+  }
+
+  if (mlirTransformer)
+    if (failed(mlirTransformer(m.get())))
+      return EXIT_FAILURE;
+
+  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
+  if (!tmBuilderOrError) {
+    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
+    return EXIT_FAILURE;
+  }
+  auto tmOrError = tmBuilderOrError->createTargetMachine();
+  if (!tmOrError) {
+    llvm::errs() << "Failed to create a TargetMachine for the host\n";
+    return EXIT_FAILURE;
+  }
+
+  auto transformer = mlir::makeLLVMPassesTransformer(
+      passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
+  auto error = mainFuncType.getValue() == "f32"
+                   ? compileAndExecuteSingleFloatReturnFunction(
+                         m.get(), mainFuncName.getValue(), transformer)
+                   : compileAndExecuteFunctionWithMemRefs(
+                         m.get(), mainFuncName.getValue(), transformer);
+  int exitCode = EXIT_SUCCESS;
+  llvm::handleAllErrors(std::move(error),
+                        [&exitCode](const llvm::ErrorInfoBase &info) {
+                          llvm::errs() << "Error: ";
+                          info.log(llvm::errs());
+                          llvm::errs() << '\n';
+                          exitCode = EXIT_FAILURE;
+                        });
+
+  return exitCode;
+}
diff --git a/third_party/mlir/lib/Support/MlirOptMain.cpp b/third_party/mlir/lib/Support/MlirOptMain.cpp
new file mode 100644
index 0000000..80cba5a
--- /dev/null
+++ b/third_party/mlir/lib/Support/MlirOptMain.cpp
@@ -0,0 +1,155 @@
+//===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a utility that runs an optimization pass and prints the result back
+// out. It is designed to support unit testing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace llvm;
+using llvm::SMLoc;
+
+/// Perform the actions on the input file indicated by the command line flags
+/// within the specified context.
+///
+/// This typically parses the main source file, runs zero or more optimization
+/// passes, then prints the output.
+///
+static LogicalResult
+performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
+               SourceMgr &sourceMgr, MLIRContext *context,
+               const std::vector<const mlir::PassRegistryEntry *> &passList) {
+  OwningModuleRef module(parseSourceFile(sourceMgr, context));
+  if (!module)
+    return failure();
+
+  // Apply any pass manager command line options.
+  PassManager pm(verifyPasses);
+  applyPassManagerCLOptions(pm);
+
+  // Run each of the passes that were selected.
+  for (const auto *passEntry : passList)
+    passEntry->addToPipeline(pm);
+
+  // Run the pipeline.
+  if (failed(pm.run(*module)))
+    return failure();
+
+  // Print the output.
+  module->print(os);
+  return success();
+}
+
+/// Parses the memory buffer.  If successfully, run a series of passes against
+/// it and print the result.
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+              bool verifyDiagnostics, bool verifyPasses,
+              const std::vector<const mlir::PassRegistryEntry *> &passList) {
+  // Tell sourceMgr about this buffer, which is what the parser will pick up.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+
+  // Parse the input file.
+  MLIRContext context;
+
+  // If we are in verify diagnostics mode then we have a lot of work to do,
+  // otherwise just perform the actions without worrying about it.
+  if (!verifyDiagnostics) {
+    SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+    return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
+                          &context, passList);
+  }
+
+  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
+
+  // Do any processing requested by command line flags.  We don't care whether
+  // these actions succeed or fail, we only care what diagnostics they produce
+  // and whether they match our expectations.
+  performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
+                 passList);
+
+  // Verify the diagnostic handler to make sure that each of the diagnostics
+  // matched.
+  return sourceMgrHandler.verify();
+}
+
+/// Split the specified file on a marker and process each chunk independently
+/// according to the normal processBuffer logic.  This is primarily used to
+/// allow a large number of small independent parser tests to be put into a
+/// single test, but could be used for other purposes as well.
+static LogicalResult splitAndProcessFile(
+    raw_ostream &os, std::unique_ptr<MemoryBuffer> originalBuffer,
+    bool verifyDiagnostics, bool verifyPasses,
+    const std::vector<const mlir::PassRegistryEntry *> &passList) {
+  const char marker[] = "// -----";
+  auto *origMemBuffer = originalBuffer.get();
+  SmallVector<StringRef, 8> sourceBuffers;
+  origMemBuffer->getBuffer().split(sourceBuffers, marker);
+
+  // Add the original buffer to the source manager.
+  SourceMgr fileSourceMgr;
+  fileSourceMgr.AddNewSourceBuffer(std::move(originalBuffer), SMLoc());
+
+  bool hadUnexpectedResult = false;
+
+  // Process each chunk in turn.  If any fails, then return a failure of the
+  // tool.
+  for (auto &subBuffer : sourceBuffers) {
+    auto splitLoc = SMLoc::getFromPointer(subBuffer.data());
+    unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first;
+    auto subMemBuffer = MemoryBuffer::getMemBufferCopy(
+        subBuffer, origMemBuffer->getBufferIdentifier() +
+                       Twine(" split at line #") + Twine(splitLine));
+    if (failed(processBuffer(os, std::move(subMemBuffer), verifyDiagnostics,
+                             verifyPasses, passList)))
+      hadUnexpectedResult = true;
+  }
+
+  return failure(hadUnexpectedResult);
+}
+
+LogicalResult
+mlir::MlirOptMain(raw_ostream &os, std::unique_ptr<MemoryBuffer> buffer,
+                  const std::vector<const mlir::PassRegistryEntry *> &passList,
+                  bool splitInputFile, bool verifyDiagnostics,
+                  bool verifyPasses) {
+  // The split-input-file mode is a very specific mode that slices the file
+  // up into small pieces and checks each independently.
+  if (splitInputFile)
+    return splitAndProcessFile(os, std::move(buffer), verifyDiagnostics,
+                               verifyPasses, passList);
+
+  return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses,
+                       passList);
+}
diff --git a/third_party/mlir/lib/Support/StorageUniquer.cpp b/third_party/mlir/lib/Support/StorageUniquer.cpp
new file mode 100644
index 0000000..c004b61
--- /dev/null
+++ b/third_party/mlir/lib/Support/StorageUniquer.cpp
@@ -0,0 +1,208 @@
+//===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Support/StorageUniquer.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RWMutex.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace mlir {
+namespace detail {
+/// This is the implementation of the StorageUniquer class.
+struct StorageUniquerImpl {
+  using BaseStorage = StorageUniquer::BaseStorage;
+  using StorageAllocator = StorageUniquer::StorageAllocator;
+
+  /// A lookup key for derived instances of storage objects.
+  struct LookupKey {
+    /// The known derived kind for the storage.
+    unsigned kind;
+
+    /// The known hash value of the key.
+    unsigned hashValue;
+
+    /// An equality function for comparing with an existing storage instance.
+    llvm::function_ref<bool(const BaseStorage *)> isEqual;
+  };
+
+  /// A utility wrapper object representing a hashed storage object. This class
+  /// contains a storage object and an existing computed hash value.
+  struct HashedStorage {
+    unsigned hashValue;
+    BaseStorage *storage;
+  };
+
+  /// Get or create an instance of a complex derived type.
+  BaseStorage *
+  getOrCreate(unsigned kind, unsigned hashValue,
+              llvm::function_ref<bool(const BaseStorage *)> isEqual,
+              llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    LookupKey lookupKey{kind, hashValue, isEqual};
+
+    // Check for an existing instance in read-only mode.
+    {
+      llvm::sys::SmartScopedReader<true> typeLock(mutex);
+      auto it = storageTypes.find_as(lookupKey);
+      if (it != storageTypes.end())
+        return it->storage;
+    }
+
+    // Acquire a writer-lock so that we can safely create the new type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+    // Check for an existing instance again here, because another writer thread
+    // may have already created one.
+    auto existing = storageTypes.insert_as({}, lookupKey);
+    if (!existing.second)
+      return existing.first->storage;
+
+    // Otherwise, construct and initialize the derived storage for this type
+    // instance.
+    BaseStorage *storage = initializeStorage(kind, ctorFn);
+    *existing.first = HashedStorage{hashValue, storage};
+    return storage;
+  }
+
+  /// Get or create an instance of a simple derived type.
+  BaseStorage *
+  getOrCreate(unsigned kind,
+              llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    // Check for an existing instance in read-only mode.
+    {
+      llvm::sys::SmartScopedReader<true> typeLock(mutex);
+      auto it = simpleTypes.find(kind);
+      if (it != simpleTypes.end())
+        return it->second;
+    }
+
+    // Acquire a writer-lock so that we can safely create the new type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+    // Check for an existing instance again here, because another writer thread
+    // may have already created one.
+    auto &result = simpleTypes[kind];
+    if (result)
+      return result;
+
+    // Otherwise, create and return a new storage instance.
+    return result = initializeStorage(kind, ctorFn);
+  }
+
+  /// Erase an instance of a complex derived type.
+  void erase(unsigned kind, unsigned hashValue,
+             llvm::function_ref<bool(const BaseStorage *)> isEqual,
+             llvm::function_ref<void(BaseStorage *)> cleanupFn) {
+    LookupKey lookupKey{kind, hashValue, isEqual};
+
+    // Acquire a writer-lock so that we can safely erase the type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+    auto existing = storageTypes.find_as(lookupKey);
+    if (existing == storageTypes.end())
+      return;
+
+    // Cleanup the storage and remove it from the map.
+    cleanupFn(existing->storage);
+    storageTypes.erase(existing);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Instance Storage
+  //===--------------------------------------------------------------------===//
+
+  /// Utility to create and initialize a storage instance.
+  BaseStorage *initializeStorage(
+      unsigned kind,
+      llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    BaseStorage *storage = ctorFn(allocator);
+    storage->kind = kind;
+    return storage;
+  }
+
+  /// Storage info for derived TypeStorage objects.
+  struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
+    static HashedStorage getEmptyKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+    }
+    static HashedStorage getTombstoneKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+    }
+
+    static unsigned getHashValue(const HashedStorage &key) {
+      return key.hashValue;
+    }
+    static unsigned getHashValue(LookupKey key) { return key.hashValue; }
+
+    static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
+      return lhs.storage == rhs.storage;
+    }
+    static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
+      if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
+        return false;
+      // If the lookup kind matches the kind of the storage, then invoke the
+      // equality function on the lookup key.
+      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+    }
+  };
+
+  // Unique types with specific hashing or storage constraints.
+  using StorageTypeSet = llvm::DenseSet<HashedStorage, StorageKeyInfo>;
+  StorageTypeSet storageTypes;
+
+  // Unique types with just the kind.
+  llvm::DenseMap<unsigned, BaseStorage *> simpleTypes;
+
+  // Allocator to use when constructing derived type instances.
+  StorageUniquer::StorageAllocator allocator;
+
+  // A mutex to keep type uniquing thread-safe.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+} // end namespace detail
+} // namespace mlir
+
+StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
+StorageUniquer::~StorageUniquer() {}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// complex storage.
+auto StorageUniquer::getImpl(
+    unsigned kind, unsigned hashValue,
+    llvm::function_ref<bool(const BaseStorage *)> isEqual,
+    std::function<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
+  return impl->getOrCreate(kind, hashValue, isEqual, ctorFn);
+}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// default storage.
+auto StorageUniquer::getImpl(
+    unsigned kind, std::function<BaseStorage *(StorageAllocator &)> ctorFn)
+    -> BaseStorage * {
+  return impl->getOrCreate(kind, ctorFn);
+}
+
+/// Implementation for erasing an instance of a derived type with complex
+/// storage.
+void StorageUniquer::eraseImpl(
+    unsigned kind, unsigned hashValue,
+    llvm::function_ref<bool(const BaseStorage *)> isEqual,
+    std::function<void(BaseStorage *)> cleanupFn) {
+  impl->erase(kind, hashValue, isEqual, cleanupFn);
+}
diff --git a/third_party/mlir/lib/Support/TranslateClParser.cpp b/third_party/mlir/lib/Support/TranslateClParser.cpp
new file mode 100644
index 0000000..8a7367f
--- /dev/null
+++ b/third_party/mlir/lib/Support/TranslateClParser.cpp
@@ -0,0 +1,105 @@
+//===- TranslateClParser.h - Translations command line parser -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains custom command line parser for translations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/TranslateClParser.h"
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Translation.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+// Storage for the translation function wrappers that survive the parser.
+static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
+
+static LogicalResult printMLIROutput(ModuleOp module,
+                                     llvm::StringRef outputFilename) {
+  if (failed(verify(module)))
+    return failure();
+  auto file = openOutputFile(outputFilename);
+  if (!file)
+    return failure();
+  module.print(file->os());
+  file->keep();
+  return success();
+}
+
+TranslationParser::TranslationParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const TranslateFunction *>(opt) {
+  const auto &toMLIRRegistry = getTranslationToMLIRRegistry();
+  const auto &fromMLIRRegistry = getTranslationFromMLIRRegistry();
+
+  // Reserve the required capacity upfront so that pointers are not
+  // invalidated on reallocation.
+  wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size());
+  for (const auto &kv : toMLIRRegistry) {
+    TranslateToMLIRFunction function = kv.second;
+    TranslateFunction wrapper = [function](StringRef inputFilename,
+                                           StringRef outputFilename,
+                                           MLIRContext *context) {
+      OwningModuleRef module = function(inputFilename, context);
+      if (!module)
+        return failure();
+      return printMLIROutput(*module, outputFilename);
+    };
+    wrapperStorage.emplace_back(std::move(wrapper));
+
+    addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
+  }
+
+  for (const auto &kv : fromMLIRRegistry) {
+    TranslateFromMLIRFunction function = kv.second;
+    TranslateFunction wrapper = [function](StringRef inputFilename,
+                                           StringRef outputFilename,
+                                           MLIRContext *context) {
+      llvm::SourceMgr sourceMgr;
+      SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
+      auto module =
+          OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context));
+      if (!module)
+        return failure();
+      return function(module.get(), outputFilename);
+    };
+    wrapperStorage.emplace_back(std::move(wrapper));
+
+    addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
+  }
+}
+
+void TranslationParser::printOptionInfo(const llvm::cl::Option &O,
+                                        size_t GlobalWidth) const {
+  TranslationParser *TP = const_cast<TranslationParser *>(this);
+  llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
+                       [](const TranslationParser::OptionInfo *VT1,
+                          const TranslationParser::OptionInfo *VT2) {
+                         return VT1->Name.compare(VT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const TranslateFunction *>::printOptionInfo(O, GlobalWidth);
+}
diff --git a/third_party/mlir/lib/TableGen/Argument.cpp b/third_party/mlir/lib/TableGen/Argument.cpp
new file mode 100644
index 0000000..17dba05
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Argument.cpp
@@ -0,0 +1,29 @@
+//===- Argument.cpp - Argument definitions --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/TableGen/Argument.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+bool tblgen::NamedTypeConstraint::hasPredicate() const {
+  return !constraint.getPredicate().isNull();
+}
+
+bool tblgen::NamedTypeConstraint::isVariadic() const {
+  return constraint.isVariadic();
+}
diff --git a/third_party/mlir/lib/TableGen/Attribute.cpp b/third_party/mlir/lib/TableGen/Attribute.cpp
new file mode 100644
index 0000000..b42bb94
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Attribute.cpp
@@ -0,0 +1,212 @@
+//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Attribute wrapper to simplify using TableGen Record defining a MLIR
+// Attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+using llvm::CodeInit;
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+using llvm::StringInit;
+
+// Returns the initializer's value as string if the given TableGen initializer
+// is a code or string initializer. Returns the empty StringRef otherwise.
+static StringRef getValueAsString(const Init *init) {
+  if (const auto *code = dyn_cast<CodeInit>(init))
+    return code->getValue().trim();
+  else if (const auto *str = dyn_cast<StringInit>(init))
+    return str->getValue().trim();
+  return {};
+}
+
+tblgen::AttrConstraint::AttrConstraint(const Record *record)
+    : Constraint(Constraint::CK_Attr, record) {
+  assert(def->isSubClassOf("AttrConstraint") &&
+         "must be subclass of TableGen 'AttrConstraint' class");
+}
+
+tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) {
+  assert(record->isSubClassOf("Attr") &&
+         "must be subclass of TableGen 'Attr' class");
+}
+
+tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
+
+bool tblgen::Attribute::isDerivedAttr() const {
+  return def->isSubClassOf("DerivedAttr");
+}
+
+bool tblgen::Attribute::isTypeAttr() const {
+  return def->isSubClassOf("TypeAttrBase");
+}
+
+bool tblgen::Attribute::isEnumAttr() const {
+  return def->isSubClassOf("EnumAttrInfo");
+}
+
+StringRef tblgen::Attribute::getStorageType() const {
+  const auto *init = def->getValueInit("storageType");
+  auto type = getValueAsString(init);
+  if (type.empty())
+    return "Attribute";
+  return type;
+}
+
+StringRef tblgen::Attribute::getReturnType() const {
+  const auto *init = def->getValueInit("returnType");
+  return getValueAsString(init);
+}
+
+StringRef tblgen::Attribute::getConvertFromStorageCall() const {
+  const auto *init = def->getValueInit("convertFromStorage");
+  return getValueAsString(init);
+}
+
+bool tblgen::Attribute::isConstBuildable() const {
+  const auto *init = def->getValueInit("constBuilderCall");
+  return !getValueAsString(init).empty();
+}
+
+StringRef tblgen::Attribute::getConstBuilderTemplate() const {
+  const auto *init = def->getValueInit("constBuilderCall");
+  return getValueAsString(init);
+}
+
+tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
+  if (const auto *defInit =
+          llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
+    return Attribute(defInit).getBaseAttr();
+  }
+  return *this;
+}
+
+bool tblgen::Attribute::hasDefaultValueInitializer() const {
+  const auto *init = def->getValueInit("defaultValue");
+  return !getValueAsString(init).empty();
+}
+
+StringRef tblgen::Attribute::getDefaultValueInitializer() const {
+  const auto *init = def->getValueInit("defaultValue");
+  return getValueAsString(init);
+}
+
+bool tblgen::Attribute::isOptional() const {
+  return def->getValueAsBit("isOptional");
+}
+
+StringRef tblgen::Attribute::getAttrDefName() const {
+  if (def->isAnonymous()) {
+    return getBaseAttr().def->getName();
+  }
+  return def->getName();
+}
+
+StringRef tblgen::Attribute::getDerivedCodeBody() const {
+  assert(isDerivedAttr() && "only derived attribute has 'body' field");
+  return def->getValueAsString("body");
+}
+
+tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
+  assert(def->isSubClassOf("ConstantAttr") &&
+         "must be subclass of TableGen 'ConstantAttr' class");
+}
+
+tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
+  return Attribute(def->getValueAsDef("attr"));
+}
+
+StringRef tblgen::ConstantAttr::getConstantValue() const {
+  return def->getValueAsString("value");
+}
+
+tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
+    : Attribute(init) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+bool tblgen::EnumAttrCase::isStrCase() const {
+  return def->isSubClassOf("StrEnumAttrCase");
+}
+
+StringRef tblgen::EnumAttrCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+int64_t tblgen::EnumAttrCase::getValue() const {
+  return def->getValueAsInt("value");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
+  assert(def->isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttr' class");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
+
+tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
+    : EnumAttr(init->getDef()) {}
+
+StringRef tblgen::EnumAttr::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef tblgen::EnumAttr::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef tblgen::EnumAttr::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef tblgen::EnumAttr::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef tblgen::EnumAttr::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef tblgen::EnumAttr::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<tblgen::EnumAttrCase> cases;
+  cases.reserve(inits->size());
+
+  for (const llvm::Init *init : *inits) {
+    cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
+  }
+
+  return cases;
+}
diff --git a/third_party/mlir/lib/TableGen/CMakeLists.txt b/third_party/mlir/lib/TableGen/CMakeLists.txt
new file mode 100644
index 0000000..48ad446
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_llvm_library(LLVMMLIRTableGen
+  Argument.cpp
+  Attribute.cpp
+  Constraint.cpp
+  Dialect.cpp
+  Format.cpp
+  Operator.cpp
+  OpTrait.cpp
+  Pattern.cpp
+  Predicate.cpp
+  Type.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/TableGen
+  )
+target_link_libraries(LLVMMLIRTableGen LLVMSupport LLVMTableGen)
diff --git a/third_party/mlir/lib/TableGen/Constraint.cpp b/third_party/mlir/lib/TableGen/Constraint.cpp
new file mode 100644
index 0000000..ef3fa52
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Constraint.cpp
@@ -0,0 +1,69 @@
+//===- Constraint.cpp - Constraint class ----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Constraint wrapper to simplify using TableGen Record for constraints.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Constraint.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir::tblgen;
+
+Constraint::Constraint(const llvm::Record *record)
+    : def(record), kind(CK_Uncategorized) {
+  if (record->isSubClassOf("TypeConstraint")) {
+    kind = CK_Type;
+  } else if (record->isSubClassOf("AttrConstraint")) {
+    kind = CK_Attr;
+  } else if (record->isSubClassOf("RegionConstraint")) {
+    kind = CK_Region;
+  } else {
+    assert(record->isSubClassOf("Constraint"));
+  }
+}
+
+Constraint::Constraint(Kind kind, const llvm::Record *record)
+    : def(record), kind(kind) {}
+
+Pred Constraint::getPredicate() const {
+  auto *val = def->getValue("predicate");
+
+  // If no predicate is specified, then return the null predicate (which
+  // corresponds to true).
+  if (!val)
+    return Pred();
+
+  const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
+  return Pred(pred);
+}
+
+std::string Constraint::getConditionTemplate() const {
+  return getPredicate().getCondition();
+}
+
+llvm::StringRef Constraint::getDescription() const {
+  auto doc = def->getValueAsString("description");
+  if (doc.empty())
+    return def->getName();
+  return doc;
+}
+
+AppliedConstraint::AppliedConstraint(Constraint &&constraint,
+                                     llvm::StringRef self,
+                                     std::vector<std::string> &&entities)
+    : constraint(constraint), self(self), entities(std::move(entities)) {}
diff --git a/third_party/mlir/lib/TableGen/Dialect.cpp b/third_party/mlir/lib/TableGen/Dialect.cpp
new file mode 100644
index 0000000..d4a7e4f
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Dialect.cpp
@@ -0,0 +1,37 @@
+//===- Dialect.cpp - Dialect wrapper class --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Dialect.h"
+#include "llvm/TableGen/Record.h"
+
+namespace mlir {
+namespace tblgen {
+
+StringRef tblgen::Dialect::getName() const {
+  return def.getValueAsString("name");
+}
+
+StringRef tblgen::Dialect::getCppNamespace() const {
+  return def.getValueAsString("cppNamespace");
+}
+
+} // end namespace tblgen
+} // end namespace mlir
diff --git a/third_party/mlir/lib/TableGen/Format.cpp b/third_party/mlir/lib/TableGen/Format.cpp
new file mode 100644
index 0000000..967d51a
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Format.cpp
@@ -0,0 +1,185 @@
+//===- Format.cpp - Utilities for String Format ---------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines utilities for formatting strings. They are specially
+// tailored to the needs of TableGen'ing op definitions and rewrite rules,
+// so they are not expected to be used as widely applicable utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Format.h"
+#include <cctype>
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Marker to indicate an error happened when replacing a placeholder.
+const char *const kMarkerForNoSubst = "<no-subst-found>";
+
+FmtContext &tblgen::FmtContext::addSubst(StringRef placeholder, Twine subst) {
+  customSubstMap[placeholder] = subst.str();
+  return *this;
+}
+
+FmtContext &tblgen::FmtContext::withBuilder(Twine subst) {
+  builtinSubstMap[PHKind::Builder] = subst.str();
+  return *this;
+}
+
+FmtContext &tblgen::FmtContext::withOp(Twine subst) {
+  builtinSubstMap[PHKind::Op] = subst.str();
+  return *this;
+}
+
+FmtContext &tblgen::FmtContext::withSelf(Twine subst) {
+  builtinSubstMap[PHKind::Self] = subst.str();
+  return *this;
+}
+
+Optional<StringRef>
+tblgen::FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
+  if (placeholder == FmtContext::PHKind::None ||
+      placeholder == FmtContext::PHKind::Custom)
+    return {};
+  auto it = builtinSubstMap.find(placeholder);
+  if (it == builtinSubstMap.end())
+    return {};
+  return StringRef(it->second);
+}
+
+Optional<StringRef>
+tblgen::FmtContext::getSubstFor(StringRef placeholder) const {
+  auto it = customSubstMap.find(placeholder);
+  if (it == customSubstMap.end())
+    return {};
+  return StringRef(it->second);
+}
+
+FmtContext::PHKind tblgen::FmtContext::getPlaceHolderKind(StringRef str) {
+  return llvm::StringSwitch<FmtContext::PHKind>(str)
+      .Case("_builder", FmtContext::PHKind::Builder)
+      .Case("_op", FmtContext::PHKind::Op)
+      .Case("_self", FmtContext::PHKind::Self)
+      .Case("", FmtContext::PHKind::None)
+      .Default(FmtContext::PHKind::Custom);
+}
+
+std::pair<FmtReplacement, StringRef>
+tblgen::FmtObjectBase::splitFmtSegment(StringRef fmt) {
+  size_t begin = fmt.find_first_of('$');
+  if (begin == StringRef::npos) {
+    // No placeholders: the whole format string should be returned as a
+    // literal string.
+    return {FmtReplacement{fmt}, StringRef()};
+  }
+  if (begin != 0) {
+    // The first placeholder is not at the beginning: we can split the format
+    // string into a literal string and the rest.
+    return {FmtReplacement{fmt.substr(0, begin)}, fmt.substr(begin)};
+  }
+
+  // The first placeholder is at the beginning
+
+  if (fmt.size() == 1) {
+    // The whole format string just contains '$': treat as literal.
+    return {FmtReplacement{fmt}, StringRef()};
+  }
+
+  // Allow escaping dollar with '$$'
+  if (fmt[1] == '$') {
+    return {FmtReplacement{fmt.substr(0, 1)}, fmt.substr(2)};
+  }
+
+  // First try to see if it's a positional placeholder, and then handle special
+  // placeholders.
+
+  size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
+  if (end != 1) {
+    // We have a positional placeholder. Parse the index.
+    size_t index = 0;
+    if (fmt.substr(1, end - 1).consumeInteger(0, index)) {
+      llvm_unreachable("invalid replacement sequence index");
+    }
+
+    if (end == StringRef::npos) {
+      // All the remaining characters are part of the positional placeholder.
+      return {FmtReplacement{fmt, index}, StringRef()};
+    }
+    return {FmtReplacement{fmt.substr(0, end), index}, fmt.substr(end)};
+  }
+
+  end = fmt.find_if_not([](char c) { return std::isalnum(c) || c == '_'; }, 1);
+  auto placeholder = FmtContext::getPlaceHolderKind(fmt.substr(1, end - 1));
+  if (end == StringRef::npos) {
+    // All the remaining characters are part of the special placeholder.
+    return {FmtReplacement{fmt, placeholder}, StringRef()};
+  }
+  return {FmtReplacement{fmt.substr(0, end), placeholder}, fmt.substr(end)};
+}
+
+std::vector<FmtReplacement> FmtObjectBase::parseFormatString(StringRef fmt) {
+  std::vector<FmtReplacement> replacements;
+  FmtReplacement repl;
+  while (!fmt.empty()) {
+    std::tie(repl, fmt) = splitFmtSegment(fmt);
+    if (repl.type != FmtReplacement::Type::Empty)
+      replacements.push_back(repl);
+  }
+  return replacements;
+}
+
+void FmtObjectBase::format(raw_ostream &s) const {
+  for (auto &repl : replacements) {
+    if (repl.type == FmtReplacement::Type::Empty)
+      continue;
+
+    if (repl.type == FmtReplacement::Type::Literal) {
+      s << repl.spec;
+      continue;
+    }
+
+    if (repl.type == FmtReplacement::Type::SpecialPH) {
+      if (repl.placeholder == FmtContext::PHKind::None) {
+        s << repl.spec;
+      } else if (!context) {
+        // We need the context to replace special placeholders.
+        s << repl.spec << kMarkerForNoSubst;
+      } else {
+        Optional<StringRef> subst;
+        if (repl.placeholder == FmtContext::PHKind::Custom) {
+          // Skip the leading '$' sign for the custom placeholder
+          subst = context->getSubstFor(repl.spec.substr(1));
+        } else {
+          subst = context->getSubstFor(repl.placeholder);
+        }
+        if (subst)
+          s << *subst;
+        else
+          s << repl.spec << kMarkerForNoSubst;
+      }
+      continue;
+    }
+
+    assert(repl.type == FmtReplacement::Type::PositionalPH);
+
+    if (repl.index >= adapters.size()) {
+      s << repl.spec << kMarkerForNoSubst;
+      continue;
+    }
+    adapters[repl.index]->format(s, /*Options=*/"");
+  }
+}
diff --git a/third_party/mlir/lib/TableGen/OpTrait.cpp b/third_party/mlir/lib/TableGen/OpTrait.cpp
new file mode 100644
index 0000000..0a357ac
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/OpTrait.cpp
@@ -0,0 +1,59 @@
+//===- OpTrait.cpp - OpTrait class ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) {
+  auto def = cast<llvm::DefInit>(init)->getDef();
+  if (def->isSubClassOf("PredOpTrait"))
+    return OpTrait(Kind::Pred, def);
+  if (def->isSubClassOf("GenInternalOpTrait"))
+    return OpTrait(Kind::Internal, def);
+  assert(def->isSubClassOf("NativeOpTrait"));
+  return OpTrait(Kind::Native, def);
+}
+
+mlir::tblgen::OpTrait::OpTrait(Kind kind, const llvm::Record *def)
+    : def(def), kind(kind) {}
+
+llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const {
+  return def->getValueAsString("trait");
+}
+
+llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const {
+  return def->getValueAsString("trait");
+}
+
+std::string mlir::tblgen::PredOpTrait::getPredTemplate() const {
+  auto pred = tblgen::Pred(def->getValueInit("predicate"));
+  return pred.getCondition();
+}
+
+llvm::StringRef mlir::tblgen::PredOpTrait::getDescription() const {
+  return def->getValueAsString("description");
+}
diff --git a/third_party/mlir/lib/TableGen/Operator.cpp b/third_party/mlir/lib/TableGen/Operator.cpp
new file mode 100644
index 0000000..60fecf7
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Operator.cpp
@@ -0,0 +1,307 @@
+//===- Operator.cpp - Operator class --------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Type.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+using llvm::DagInit;
+using llvm::DefInit;
+using llvm::Record;
+
+tblgen::Operator::Operator(const llvm::Record &def)
+    : dialect(def.getValueAsDef("opDialect")), def(def) {
+  // The first `_` in the op's TableGen def name is treated as separating the
+  // dialect prefix and the op class name. The dialect prefix will be ignored if
+  // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
+  // as part of the class name.
+  StringRef prefix;
+  std::tie(prefix, cppClassName) = def.getName().split('_');
+  if (prefix.empty()) {
+    // Class name with a leading underscore and without dialect prefix
+    cppClassName = def.getName();
+  } else if (cppClassName.empty()) {
+    // Class name without dialect prefix
+    cppClassName = prefix;
+  }
+
+  populateOpStructure();
+}
+
+std::string tblgen::Operator::getOperationName() const {
+  auto prefix = dialect.getName();
+  auto opName = def.getValueAsString("opName");
+  if (prefix.empty())
+    return opName;
+  return llvm::formatv("{0}.{1}", prefix, opName);
+}
+
+StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }
+
+StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
+
+std::string tblgen::Operator::getQualCppClassName() const {
+  auto prefix = dialect.getCppNamespace();
+  if (prefix.empty())
+    return cppClassName;
+  return llvm::formatv("{0}::{1}", prefix, cppClassName);
+}
+
+int tblgen::Operator::getNumResults() const {
+  DagInit *results = def.getValueAsDag("results");
+  return results->getNumArgs();
+}
+
+StringRef tblgen::Operator::getExtraClassDeclaration() const {
+  constexpr auto attr = "extraClassDeclaration";
+  if (def.isValueUnset(attr))
+    return {};
+  return def.getValueAsString(attr);
+}
+
+const llvm::Record &tblgen::Operator::getDef() const { return def; }
+
+bool tblgen::Operator::isVariadic() const {
+  return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
+}
+
+bool tblgen::Operator::skipDefaultBuilders() const {
+  return def.getValueAsBit("skipDefaultBuilders");
+}
+
+auto tblgen::Operator::result_begin() -> value_iterator {
+  return results.begin();
+}
+
+auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
+
+auto tblgen::Operator::getResults() -> value_range {
+  return {result_begin(), result_end()};
+}
+
+tblgen::TypeConstraint
+tblgen::Operator::getResultTypeConstraint(int index) const {
+  DagInit *results = def.getValueAsDag("results");
+  return TypeConstraint(cast<DefInit>(results->getArg(index)));
+}
+
+StringRef tblgen::Operator::getResultName(int index) const {
+  DagInit *results = def.getValueAsDag("results");
+  return results->getArgNameStr(index);
+}
+
+unsigned tblgen::Operator::getNumVariadicResults() const {
+  return std::count_if(
+      results.begin(), results.end(),
+      [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
+}
+
+unsigned tblgen::Operator::getNumVariadicOperands() const {
+  return std::count_if(
+      operands.begin(), operands.end(),
+      [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
+}
+
+StringRef tblgen::Operator::getArgName(int index) const {
+  DagInit *argumentValues = def.getValueAsDag("arguments");
+  return argumentValues->getArgName(index)->getValue();
+}
+
+bool tblgen::Operator::hasTrait(StringRef trait) const {
+  for (auto t : getTraits()) {
+    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
+      if (opTrait->getTrait() == trait)
+        return true;
+    } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
+      if (opTrait->getTrait() == trait)
+        return true;
+    }
+  }
+  return false;
+}
+
+unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
+
+const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
+  return regions[index];
+}
+
+auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
+  return traits.begin();
+}
+auto tblgen::Operator::trait_end() const -> const_trait_iterator {
+  return traits.end();
+}
+auto tblgen::Operator::getTraits() const
+    -> llvm::iterator_range<const_trait_iterator> {
+  return {trait_begin(), trait_end()};
+}
+
+auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
+  return attributes.begin();
+}
+auto tblgen::Operator::attribute_end() const -> attribute_iterator {
+  return attributes.end();
+}
+auto tblgen::Operator::getAttributes() const
+    -> llvm::iterator_range<attribute_iterator> {
+  return {attribute_begin(), attribute_end()};
+}
+
+auto tblgen::Operator::operand_begin() -> value_iterator {
+  return operands.begin();
+}
+auto tblgen::Operator::operand_end() -> value_iterator {
+  return operands.end();
+}
+auto tblgen::Operator::getOperands() -> value_range {
+  return {operand_begin(), operand_end()};
+}
+
+auto tblgen::Operator::getArg(int index) const -> Argument {
+  return arguments[index];
+}
+
+void tblgen::Operator::populateOpStructure() {
+  auto &recordKeeper = def.getRecords();
+  auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
+  auto attrClass = recordKeeper.getClass("Attr");
+  auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
+  numNativeAttributes = 0;
+
+  // The argument ordering is operands, native attributes, derived
+  // attributes.
+  DagInit *argumentValues = def.getValueAsDag("arguments");
+  unsigned i = 0;
+  // Handle operands and native attributes.
+  for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
+    auto arg = argumentValues->getArg(i);
+    auto givenName = argumentValues->getArgNameStr(i);
+    auto argDefInit = dyn_cast<DefInit>(arg);
+    if (!argDefInit)
+      PrintFatalError(def.getLoc(),
+                      Twine("undefined type for argument #") + Twine(i));
+    Record *argDef = argDefInit->getDef();
+
+    if (argDef->isSubClassOf(typeConstraintClass)) {
+      operands.push_back(
+          NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
+      arguments.emplace_back(&operands.back());
+    } else if (argDef->isSubClassOf(attrClass)) {
+      if (givenName.empty())
+        PrintFatalError(argDef->getLoc(), "attributes must be named");
+      if (argDef->isSubClassOf(derivedAttrClass))
+        PrintFatalError(argDef->getLoc(),
+                        "derived attributes not allowed in argument list");
+      attributes.push_back({givenName, Attribute(argDef)});
+      arguments.emplace_back(&attributes.back());
+      ++numNativeAttributes;
+    } else {
+      PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
+                                    "from TypeConstraint or Attr are allowed");
+    }
+  }
+
+  // Handle derived attributes.
+  for (const auto &val : def.getValues()) {
+    if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
+      if (!record->isSubClassOf(attrClass))
+        continue;
+      if (!record->isSubClassOf(derivedAttrClass))
+        PrintFatalError(def.getLoc(),
+                        "unexpected Attr where only DerivedAttr is allowed");
+
+      if (record->getClasses().size() != 1) {
+        PrintFatalError(
+            def.getLoc(),
+            "unsupported attribute modelling, only single class expected");
+      }
+      attributes.push_back(
+          {cast<llvm::StringInit>(val.getNameInit())->getValue(),
+           Attribute(cast<DefInit>(val.getValue()))});
+    }
+  }
+
+  auto *resultsDag = def.getValueAsDag("results");
+  auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
+  if (!outsOp || outsOp->getDef()->getName() != "outs") {
+    PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
+  }
+
+  // Handle results.
+  for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
+    auto name = resultsDag->getArgNameStr(i);
+    auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
+    if (!resultDef) {
+      PrintFatalError(def.getLoc(),
+                      Twine("undefined type for result #") + Twine(i));
+    }
+    results.push_back({name, TypeConstraint(resultDef)});
+  }
+
+  auto traitListInit = def.getValueAsListInit("traits");
+  if (!traitListInit)
+    return;
+  traits.reserve(traitListInit->size());
+  for (auto traitInit : *traitListInit)
+    traits.push_back(OpTrait::create(traitInit));
+
+  // Handle regions
+  auto *regionsDag = def.getValueAsDag("regions");
+  auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
+  if (!regionsOp || regionsOp->getDef()->getName() != "region") {
+    PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
+  }
+
+  for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
+    auto name = regionsDag->getArgNameStr(i);
+    auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
+    if (!regionInit) {
+      PrintFatalError(def.getLoc(),
+                      Twine("undefined kind for region #") + Twine(i));
+    }
+    regions.push_back({name, Region(regionInit->getDef())});
+  }
+}
+
+ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
+
+bool tblgen::Operator::hasDescription() const {
+  return def.getValue("description") != nullptr;
+}
+
+StringRef tblgen::Operator::getDescription() const {
+  return def.getValueAsString("description");
+}
+
+bool tblgen::Operator::hasSummary() const {
+  return def.getValue("summary") != nullptr;
+}
+
+StringRef tblgen::Operator::getSummary() const {
+  return def.getValueAsString("summary");
+}
diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp
new file mode 100644
index 0000000..51e4c3b
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Pattern.cpp
@@ -0,0 +1,445 @@
+//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Pattern wrapper class to simplify using TableGen Record defining a MLIR
+// Pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Pattern.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+using llvm::formatv;
+using mlir::tblgen::Operator;
+
+//===----------------------------------------------------------------------===//
+// DagLeaf
+//===----------------------------------------------------------------------===//
+
+bool tblgen::DagLeaf::isUnspecified() const {
+  return dyn_cast_or_null<llvm::UnsetInit>(def);
+}
+
+bool tblgen::DagLeaf::isOperandMatcher() const {
+  // Operand matchers specify a type constraint.
+  return isSubClassOf("TypeConstraint");
+}
+
+bool tblgen::DagLeaf::isAttrMatcher() const {
+  // Attribute matchers specify an attribute constraint.
+  return isSubClassOf("AttrConstraint");
+}
+
+bool tblgen::DagLeaf::isNativeCodeCall() const {
+  return isSubClassOf("NativeCodeCall");
+}
+
+bool tblgen::DagLeaf::isConstantAttr() const {
+  return isSubClassOf("ConstantAttr");
+}
+
+bool tblgen::DagLeaf::isEnumAttrCase() const {
+  return isSubClassOf("EnumAttrCaseInfo");
+}
+
+tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
+  assert((isOperandMatcher() || isAttrMatcher()) &&
+         "the DAG leaf must be operand or attribute");
+  return Constraint(cast<llvm::DefInit>(def)->getDef());
+}
+
+tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
+  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
+  return ConstantAttr(cast<llvm::DefInit>(def));
+}
+
+tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
+  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
+  return EnumAttrCase(cast<llvm::DefInit>(def));
+}
+
+std::string tblgen::DagLeaf::getConditionTemplate() const {
+  return getAsConstraint().getConditionTemplate();
+}
+
+llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
+}
+
+bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
+    return defInit->getDef()->isSubClassOf(superclass);
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// DagNode
+//===----------------------------------------------------------------------===//
+
+bool tblgen::DagNode::isNativeCodeCall() const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
+    return defInit->getDef()->isSubClassOf("NativeCodeCall");
+  return false;
+}
+
+bool tblgen::DagNode::isOperation() const {
+  return !(isNativeCodeCall() || isReplaceWithValue());
+}
+
+llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(node->getOperator())
+      ->getDef()
+      ->getValueAsString("expression");
+}
+
+llvm::StringRef tblgen::DagNode::getSymbol() const {
+  return node->getNameStr();
+}
+
+Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
+  llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  auto it = mapper->find(opDef);
+  if (it != mapper->end())
+    return *it->second;
+  return *mapper->try_emplace(opDef, llvm::make_unique<Operator>(opDef))
+              .first->second;
+}
+
+int tblgen::DagNode::getNumOps() const {
+  int count = isReplaceWithValue() ? 0 : 1;
+  for (int i = 0, e = getNumArgs(); i != e; ++i) {
+    if (auto child = getArgAsNestedDag(i))
+      count += child.getNumOps();
+  }
+  return count;
+}
+
+int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
+
+bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
+  return isa<llvm::DagInit>(node->getArg(index));
+}
+
+tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
+  return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
+}
+
+tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
+  assert(!isNestedDagArg(index));
+  return DagLeaf(node->getArg(index));
+}
+
+StringRef tblgen::DagNode::getArgName(unsigned index) const {
+  return node->getArgNameStr(index);
+}
+
+bool tblgen::DagNode::isReplaceWithValue() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "replaceWithValue";
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolInfoMap
+//===----------------------------------------------------------------------===//
+
+StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
+                                                  int *index) {
+  StringRef name, indexStr;
+  int idx = -1;
+  std::tie(name, indexStr) = symbol.rsplit("__");
+
+  if (indexStr.consumeInteger(10, idx)) {
+    // The second part is not an index; we return the whole symbol as-is.
+    return symbol;
+  }
+  if (index) {
+    *index = idx;
+  }
+  return name;
+}
+
+tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
+                                              SymbolInfo::Kind kind,
+                                              Optional<int> index)
+    : op(op), kind(kind), argIndex(index) {}
+
+int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
+  switch (kind) {
+  case Kind::Attr:
+  case Kind::Operand:
+  case Kind::Value:
+    return 1;
+  case Kind::Result:
+    return op->getNumResults();
+  }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+  switch (kind) {
+  case Kind::Attr: {
+    auto type =
+        op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+    return formatv("{0} {1};\n", type, name);
+  }
+  case Kind::Operand:
+  case Kind::Value: {
+    return formatv("Value *{0};\n", name);
+  }
+  case Kind::Result: {
+    // Use the op itself for the results.
+    return formatv("{0} {1};\n", op->getQualCppClassName(), name);
+  }
+  }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
+                                                       int index) const {
+  switch (kind) {
+  case Kind::Attr:
+  case Kind::Operand: {
+    assert(index < 0 && "only allowed for symbol bound to result");
+    return name;
+  }
+  case Kind::Result: {
+    // TODO(b/133341698): The following is incorrect for variadic results. We
+    // should use getODSResults().
+    if (index >= 0) {
+      return formatv("{0}.getOperation()->getResult({1})", name, index);
+    }
+
+    // If referencing multiple results, compose a comma-separated list.
+    SmallVector<std::string, 4> values;
+    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
+      values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
+    }
+    return llvm::join(values, ", ");
+  }
+  case Kind::Value: {
+    assert(index < 0 && "only allowed for symbol bound to result");
+    assert(op == nullptr);
+    return name;
+  }
+  }
+}
+
+bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
+                                           int argIndex) {
+  StringRef name = getValuePackName(symbol);
+  if (name != symbol) {
+    auto error = formatv(
+        "symbol '{0}' with trailing index cannot bind to op argument", symbol);
+    PrintFatalError(loc, error);
+  }
+
+  auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
+                     ? SymbolInfo::getAttr(&op, argIndex)
+                     : SymbolInfo::getOperand(&op, argIndex);
+
+  return symbolInfoMap.insert({symbol, symInfo}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
+  StringRef name = getValuePackName(symbol);
+  return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
+  return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+}
+
+bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
+  return find(symbol) != symbolInfoMap.end();
+}
+
+tblgen::SymbolInfoMap::const_iterator
+tblgen::SymbolInfoMap::find(StringRef key) const {
+  StringRef name = getValuePackName(key);
+  return symbolInfoMap.find(name);
+}
+
+int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
+  StringRef name = getValuePackName(symbol);
+  if (name != symbol) {
+    // If there is a trailing index inside symbol, it references just one
+    // static value.
+    return 1;
+  }
+  // Otherwise, find how many it represents by querying the symbol's info.
+  return find(name)->getValue().getStaticValueCount();
+}
+
+std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
+  int index = -1;
+  StringRef name = getValuePackName(symbol, &index);
+
+  auto it = symbolInfoMap.find(name);
+  if (it == symbolInfoMap.end()) {
+    auto error = formatv("referencing unbound symbol '{0}'", symbol);
+    PrintFatalError(loc, error);
+  }
+
+  return it->getValue().getValueAndRangeUse(name, index);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern
+//==----------------------------------------------------------------------===//
+
+tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
+    : def(*def), recordOpMap(mapper) {}
+
+tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
+  return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
+}
+
+int tblgen::Pattern::getNumResultPatterns() const {
+  auto *results = def.getValueAsListInit("resultPatterns");
+  return results->size();
+}
+
+tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
+  auto *results = def.getValueAsListInit("resultPatterns");
+  return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
+}
+
+void tblgen::Pattern::collectSourcePatternBoundSymbols(
+    tblgen::SymbolInfoMap &infoMap) {
+  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
+}
+
+void tblgen::Pattern::collectResultPatternBoundSymbols(
+    tblgen::SymbolInfoMap &infoMap) {
+  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
+    auto pattern = getResultPattern(i);
+    collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
+  }
+}
+
+const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
+  return getSourcePattern().getDialectOp(recordOpMap);
+}
+
+tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
+  return node.getDialectOp(recordOpMap);
+}
+
+std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
+  auto *listInit = def.getValueAsListInit("constraints");
+  std::vector<tblgen::AppliedConstraint> ret;
+  ret.reserve(listInit->size());
+
+  for (auto it : *listInit) {
+    auto *dagInit = dyn_cast<llvm::DagInit>(it);
+    if (!dagInit)
+      PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
+                                    "constraints should be DAG nodes");
+
+    std::vector<std::string> entities;
+    entities.reserve(dagInit->arg_size());
+    for (auto *argName : dagInit->getArgNames())
+      entities.push_back(argName->getValue());
+
+    ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
+                     dagInit->getNameStr(), std::move(entities));
+  }
+  return ret;
+}
+
+int tblgen::Pattern::getBenefit() const {
+  // The initial benefit value is a heuristic with number of ops in the source
+  // pattern.
+  int initBenefit = getSourcePattern().getNumOps();
+  llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
+  if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
+    PrintFatalError(def.getLoc(),
+                    "The 'addBenefit' takes and only takes one integer value");
+  }
+  return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
+}
+
+std::vector<tblgen::Pattern::IdentifierLine>
+tblgen::Pattern::getLocation() const {
+  std::vector<std::pair<StringRef, unsigned>> result;
+  result.reserve(def.getLoc().size());
+  for (auto loc : def.getLoc()) {
+    unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
+    assert(buf && "invalid source location");
+    result.emplace_back(
+        llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
+        llvm::SrcMgr.getLineAndColumn(loc, buf).first);
+  }
+  return result;
+}
+
+void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
+                                          bool isSrcPattern) {
+  auto treeName = tree.getSymbol();
+  if (!tree.isOperation()) {
+    if (!treeName.empty()) {
+      PrintFatalError(
+          def.getLoc(),
+          formatv("binding symbol '{0}' to non-operation unsupported right now",
+                  treeName));
+    }
+    return;
+  }
+
+  auto &op = getDialectOp(tree);
+  auto numOpArgs = op.getNumArgs();
+  auto numTreeArgs = tree.getNumArgs();
+
+  if (numOpArgs != numTreeArgs) {
+    auto err = formatv("op '{0}' argument number mismatch: "
+                       "{1} in pattern vs. {2} in definition",
+                       op.getOperationName(), numTreeArgs, numOpArgs);
+    PrintFatalError(def.getLoc(), err);
+  }
+
+  // The name attached to the DAG node's operator is for representing the
+  // results generated from this op. It should be remembered as bound results.
+  if (!treeName.empty()) {
+    if (!infoMap.bindOpResult(treeName, op))
+      PrintFatalError(def.getLoc(),
+                      formatv("symbol '{0}' bound more than once", treeName));
+  }
+
+  for (int i = 0; i != numTreeArgs; ++i) {
+    if (auto treeArg = tree.getArgAsNestedDag(i)) {
+      // This DAG node argument is a DAG node itself. Go inside recursively.
+      collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+    } else if (isSrcPattern) {
+      // We can only bind symbols to op arguments in source pattern. Those
+      // symbols are referenced in result patterns.
+      auto treeArgName = tree.getArgName(i);
+      if (!treeArgName.empty()) {
+        if (!infoMap.bindOpArgument(treeArgName, op, i)) {
+          auto err = formatv("symbol '{0}' bound more than once", treeArgName);
+          PrintFatalError(def.getLoc(), err);
+        }
+      }
+    }
+  }
+}
diff --git a/third_party/mlir/lib/TableGen/Predicate.cpp b/third_party/mlir/lib/TableGen/Predicate.cpp
new file mode 100644
index 0000000..bc2b424
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Predicate.cpp
@@ -0,0 +1,374 @@
+//===- Predicate.cpp - Predicate class ------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Wrapper around predicates defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+// Construct a Predicate from a record.
+tblgen::Pred::Pred(const llvm::Record *record) : def(record) {
+  assert(def->isSubClassOf("Pred") &&
+         "must be a subclass of TableGen 'Pred' class");
+}
+
+// Construct a Predicate from an initializer.
+tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
+  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
+    def = defInit->getDef();
+}
+
+std::string tblgen::Pred::getCondition() const {
+  // Static dispatch to subclasses.
+  if (def->isSubClassOf("CombinedPred"))
+    return static_cast<const CombinedPred *>(this)->getConditionImpl();
+  if (def->isSubClassOf("CPred"))
+    return static_cast<const CPred *>(this)->getConditionImpl();
+  llvm_unreachable("Pred::getCondition must be overridden in subclasses");
+}
+
+bool tblgen::Pred::isCombined() const {
+  return def && def->isSubClassOf("CombinedPred");
+}
+
+ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); }
+
+tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) {
+  assert(def->isSubClassOf("CPred") &&
+         "must be a subclass of Tablegen 'CPred' class");
+}
+
+tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
+  assert((!def || def->isSubClassOf("CPred")) &&
+         "must be a subclass of Tablegen 'CPred' class");
+}
+
+// Get condition of the C Predicate.
+std::string tblgen::CPred::getConditionImpl() const {
+  assert(!isNull() && "null predicate does not have a condition");
+  return def->getValueAsString("predExpr");
+}
+
+tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
+  assert(def->isSubClassOf("CombinedPred") &&
+         "must be a subclass of Tablegen 'CombinedPred' class");
+}
+
+tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
+  assert((!def || def->isSubClassOf("CombinedPred")) &&
+         "must be a subclass of Tablegen 'CombinedPred' class");
+}
+
+const llvm::Record *tblgen::CombinedPred::getCombinerDef() const {
+  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
+  return def->getValueAsDef("kind");
+}
+
+const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const {
+  assert(def->getValue("children") &&
+         "CombinedPred must have a value 'children'");
+  return def->getValueAsListOfDefs("children");
+}
+
+namespace {
+// Kinds of nodes in a logical predicate tree.
+enum class PredCombinerKind {
+  Leaf,
+  And,
+  Or,
+  Not,
+  SubstLeaves,
+  Concat,
+  // Special kinds that are used in simplification.
+  False,
+  True
+};
+
+// A node in a logical predicate tree.
+struct PredNode {
+  PredCombinerKind kind;
+  const tblgen::Pred *predicate;
+  SmallVector<PredNode *, 4> children;
+  std::string expr;
+
+  // Prefix and suffix are used by ConcatPred.
+  std::string prefix;
+  std::string suffix;
+};
+} // end anonymous namespace
+
+// Get a predicate tree node kind based on the kind used in the predicate
+// TableGen record.
+static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) {
+  if (!pred.isCombined())
+    return PredCombinerKind::Leaf;
+
+  const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred);
+  return llvm::StringSwitch<PredCombinerKind>(
+             combinedPred.getCombinerDef()->getName())
+      .Case("PredCombinerAnd", PredCombinerKind::And)
+      .Case("PredCombinerOr", PredCombinerKind::Or)
+      .Case("PredCombinerNot", PredCombinerKind::Not)
+      .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
+      .Case("PredCombinerConcat", PredCombinerKind::Concat);
+}
+
+namespace {
+// Substitution<pattern, replacement>.
+using Subst = std::pair<StringRef, StringRef>;
+} // end anonymous namespace
+
+// Build the predicate tree starting from the top-level predicate, which may
+// have children, and perform leaf substitutions inplace.  Note that after
+// substitution, nodes are still pointing to the original TableGen record.
+// All nodes are created within "allocator".
+static PredNode *buildPredicateTree(const tblgen::Pred &root,
+                                    llvm::BumpPtrAllocator &allocator,
+                                    ArrayRef<Subst> substitutions) {
+  auto *rootNode = allocator.Allocate<PredNode>();
+  new (rootNode) PredNode;
+  rootNode->kind = getPredCombinerKind(root);
+  rootNode->predicate = &root;
+  if (!root.isCombined()) {
+    rootNode->expr = root.getCondition();
+    // Apply all parent substitutions from innermost to outermost.
+    for (const auto &subst : llvm::reverse(substitutions)) {
+      auto pos = rootNode->expr.find(subst.first);
+      while (pos != std::string::npos) {
+        rootNode->expr.replace(pos, subst.first.size(), subst.second);
+        // Skip the newly inserted substring, which itself may consider the
+        // pattern to match.
+        pos += subst.second.size();
+        // Find the next possible match position.
+        pos = rootNode->expr.find(subst.first, pos);
+      }
+    }
+    return rootNode;
+  }
+
+  // If the current combined predicate is a leaf substitution, append it to the
+  // list before contiuing.
+  auto allSubstitutions = llvm::to_vector<4>(substitutions);
+  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
+    const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root);
+    allSubstitutions.push_back(
+        {substPred.getPattern(), substPred.getReplacement()});
+  }
+  // If the current predicate is a ConcatPred, record the prefix and suffix.
+  else if (rootNode->kind == PredCombinerKind::Concat) {
+    const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root);
+    rootNode->prefix = concatPred.getPrefix();
+    rootNode->suffix = concatPred.getSuffix();
+  }
+
+  // Build child subtrees.
+  auto combined = static_cast<const tblgen::CombinedPred &>(root);
+  for (const auto *record : combined.getChildren()) {
+    auto childTree =
+        buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions);
+    rootNode->children.push_back(childTree);
+  }
+  return rootNode;
+}
+
+// Simplify a predicate tree rooted at "node" using the predicates that are
+// known to be true(false).  For AND(OR) combined predicates, if any of the
+// children is known to be false(true), the result is also false(true).
+// Furthermore, for AND(OR) combined predicates, children that are known to be
+// true(false) don't have to be checked dynamically.
+static PredNode *propagateGroundTruth(
+    PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds,
+    const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) {
+  // If the current predicate is known to be true or false, change the kind of
+  // the node and return immediately.
+  if (knownTruePreds.count(node->predicate) != 0) {
+    node->kind = PredCombinerKind::True;
+    node->children.clear();
+    return node;
+  }
+  if (knownFalsePreds.count(node->predicate) != 0) {
+    node->kind = PredCombinerKind::False;
+    node->children.clear();
+    return node;
+  }
+
+  // If the current node is a substitution, stop recursion now.
+  // The expressions in the leaves below this node were rewritten, but the nodes
+  // still point to the original predicate records.  While the original
+  // predicate may be known to be true or false, it is not necessarily the case
+  // after rewriting.
+  // TODO(zinenko,jpienaar): we can support ground truth for rewritten
+  // predicates by either (a) having our own unique'ing of the predicates
+  // instead of relying on TableGen record pointers or (b) taking ground truth
+  // values optinally prefixed with a list of substitutions to apply, e.g.
+  // "predX is true by itself as well as predSubY leaf substitution had been
+  // applied to it".
+  if (node->kind == PredCombinerKind::SubstLeaves) {
+    return node;
+  }
+
+  // Otherwise, look at child nodes.
+
+  // Move child nodes into some local variable so that they can be optimized
+  // separately and re-added if necessary.
+  llvm::SmallVector<PredNode *, 4> children;
+  std::swap(node->children, children);
+
+  for (auto &child : children) {
+    // First, simplify the child.  This maintains the predicate as it was.
+    auto simplifiedChild =
+        propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
+
+    // Just add the child if we don't know how to simplify the current node.
+    if (node->kind != PredCombinerKind::And &&
+        node->kind != PredCombinerKind::Or) {
+      node->children.push_back(simplifiedChild);
+      continue;
+    }
+
+    // Second, based on the type define which known values of child predicates
+    // immediately collapse this predicate to a known value, and which others
+    // may be safely ignored.
+    //   OR(..., True, ...) = True
+    //   OR(..., False, ...) = OR(..., ...)
+    //   AND(..., False, ...) = False
+    //   AND(..., True, ...) = AND(..., ...)
+    auto collapseKind = node->kind == PredCombinerKind::And
+                            ? PredCombinerKind::False
+                            : PredCombinerKind::True;
+    auto eraseKind = node->kind == PredCombinerKind::And
+                         ? PredCombinerKind::True
+                         : PredCombinerKind::False;
+    const auto &collapseList =
+        node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
+    const auto &eraseList =
+        node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
+    if (simplifiedChild->kind == collapseKind ||
+        collapseList.count(simplifiedChild->predicate) != 0) {
+      node->kind = collapseKind;
+      node->children.clear();
+      return node;
+    } else if (simplifiedChild->kind == eraseKind ||
+               eraseList.count(simplifiedChild->predicate) != 0) {
+      continue;
+    }
+    node->children.push_back(simplifiedChild);
+  }
+  return node;
+}
+
+// Combine a list of predicate expressions using a binary combiner.  If a list
+// is empty, return "init".
+static std::string combineBinary(ArrayRef<std::string> children,
+                                 std::string combiner, std::string init) {
+  if (children.empty())
+    return init;
+
+  auto size = children.size();
+  if (size == 1)
+    return children.front();
+
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  os << '(' << children.front() << ')';
+  for (unsigned i = 1; i < size; ++i) {
+    os << ' ' << combiner << " (" << children[i] << ')';
+  }
+  return os.str();
+}
+
+// Prepend negation to the only condition in the predicate expression list.
+static std::string combineNot(ArrayRef<std::string> children) {
+  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
+  return (Twine("!(") + children.front() + Twine(')')).str();
+}
+
+// Recursively traverse the predicate tree in depth-first post-order and build
+// the final expression.
+static std::string getCombinedCondition(const PredNode &root) {
+  // Immediately return for non-combiner predicates that don't have children.
+  if (root.kind == PredCombinerKind::Leaf)
+    return root.expr;
+  if (root.kind == PredCombinerKind::True)
+    return "true";
+  if (root.kind == PredCombinerKind::False)
+    return "false";
+
+  // Recurse into children.
+  llvm::SmallVector<std::string, 4> childExpressions;
+  childExpressions.reserve(root.children.size());
+  for (const auto &child : root.children)
+    childExpressions.push_back(getCombinedCondition(*child));
+
+  // Combine the expressions based on the predicate node kind.
+  if (root.kind == PredCombinerKind::And)
+    return combineBinary(childExpressions, "&&", "true");
+  if (root.kind == PredCombinerKind::Or)
+    return combineBinary(childExpressions, "||", "false");
+  if (root.kind == PredCombinerKind::Not)
+    return combineNot(childExpressions);
+  if (root.kind == PredCombinerKind::Concat) {
+    assert(childExpressions.size() == 1 &&
+           "ConcatPred should only have one child");
+    return root.prefix + childExpressions.front() + root.suffix;
+  }
+
+  // Substitutions were applied before so just ignore them.
+  if (root.kind == PredCombinerKind::SubstLeaves) {
+    assert(childExpressions.size() == 1 &&
+           "substitution predicate must have one child");
+    return childExpressions[0];
+  }
+
+  llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
+}
+
+std::string tblgen::CombinedPred::getConditionImpl() const {
+  llvm::BumpPtrAllocator allocator;
+  auto predicateTree = buildPredicateTree(*this, allocator, {});
+  predicateTree = propagateGroundTruth(
+      predicateTree,
+      /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(),
+      /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>());
+
+  return getCombinedCondition(*predicateTree);
+}
+
+StringRef tblgen::SubstLeavesPred::getPattern() const {
+  return def->getValueAsString("pattern");
+}
+
+StringRef tblgen::SubstLeavesPred::getReplacement() const {
+  return def->getValueAsString("replacement");
+}
+
+StringRef tblgen::ConcatPred::getPrefix() const {
+  return def->getValueAsString("prefix");
+}
+
+StringRef tblgen::ConcatPred::getSuffix() const {
+  return def->getValueAsString("suffix");
+}
diff --git a/third_party/mlir/lib/TableGen/Type.cpp b/third_party/mlir/lib/TableGen/Type.cpp
new file mode 100644
index 0000000..340fb4b
--- /dev/null
+++ b/third_party/mlir/lib/TableGen/Type.cpp
@@ -0,0 +1,38 @@
+//===- Type.cpp - Type class ----------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Type wrapper to simplify using TableGen Record defining a MLIR Type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Type.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+
+tblgen::TypeConstraint::TypeConstraint(const llvm::Record *record)
+    : Constraint(Constraint::CK_Type, record) {
+  assert(def->isSubClassOf("TypeConstraint") &&
+         "must be subclass of TableGen 'TypeConstraint' class");
+}
+
+tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit *init)
+    : TypeConstraint(init->getDef()) {}
+
+bool tblgen::TypeConstraint::isVariadic() const {
+  return def->isSubClassOf("Variadic");
+}
diff --git a/third_party/mlir/lib/Target/CMakeLists.txt b/third_party/mlir/lib/Target/CMakeLists.txt
new file mode 100644
index 0000000..9f49b81
--- /dev/null
+++ b/third_party/mlir/lib/Target/CMakeLists.txt
@@ -0,0 +1,30 @@
+add_llvm_library(MLIRTargetLLVMIRModuleTranslation
+  LLVMIR/ModuleTranslation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  DEPENDS
+  intrinsics_gen
+  )
+target_link_libraries(MLIRTargetLLVMIRModuleTranslation MLIRLLVMIR LLVMCore LLVMSupport LLVMTransformUtils MLIRTranslation)
+add_llvm_library(MLIRTargetLLVMIR
+  LLVMIR/ConvertToLLVMIR.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  )
+target_link_libraries(MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation)
+add_llvm_library(MLIRTargetNVVMIR
+  LLVMIR/ConvertToNVVMIR.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  DEPENDS
+  intrinsics_gen
+  )
+target_link_libraries(MLIRTargetNVVMIR
+  MLIRGPU
+  MLIRIR
+  MLIRNVVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  )
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
new file mode 100644
index 0000000..0ba1581
--- /dev/null
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -0,0 +1,54 @@
+//===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR.h"
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(ModuleOp m) {
+  return LLVM::ModuleTranslation::translateModule<>(m);
+}
+
+static TranslateFromMLIRRegistration registration(
+    "mlir-to-llvmir", [](ModuleOp module, llvm::StringRef outputFilename) {
+      if (!module)
+        return failure();
+
+      auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module);
+      if (!llvmModule)
+        return failure();
+
+      auto file = openOutputFile(outputFilename);
+      if (!file)
+        return failure();
+
+      llvmModule->print(file->os(), nullptr);
+      file->keep();
+      return success();
+    });
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
new file mode 100644
index 0000000..a1e09fd
--- /dev/null
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -0,0 +1,109 @@
+//===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a translation between the MLIR LLVM + NVVM dialects and
+// LLVM IR with NVVM intrinsics and metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/NVVMIR.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/LLVMIR/NVVMDialect.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace {
+static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
+                                        llvm::Intrinsic::ID intrinsic) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
+  return builder.CreateCall(fn);
+}
+
+class ModuleTranslation : public LLVM::ModuleTranslation {
+
+public:
+  explicit ModuleTranslation(ModuleOp module)
+      : LLVM::ModuleTranslation(module) {}
+  ~ModuleTranslation() override {}
+
+protected:
+  LogicalResult convertOperation(Operation &opInst,
+                                 llvm::IRBuilder<> &builder) override {
+
+#include "mlir/LLVMIR/NVVMConversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+} // namespace
+
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
+  ModuleTranslation translation(m);
+  auto llvmModule =
+      LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
+
+  // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
+  // function as a kernel.
+  for (FuncOp func : m.getOps<FuncOp>()) {
+    if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
+      continue;
+
+    auto *llvmFunc = llvmModule->getFunction(func.getName());
+
+    llvm::Metadata *llvmMetadata[] = {
+        llvm::ValueAsMetadata::get(llvmFunc),
+        llvm::MDString::get(llvmModule->getContext(), "kernel"),
+        llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+            llvm::Type::getInt32Ty(llvmModule->getContext()), 1))};
+    llvm::MDNode *llvmMetadataNode =
+        llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
+    llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
+        ->addOperand(llvmMetadataNode);
+  }
+
+  return llvmModule;
+}
+
+static TranslateFromMLIRRegistration
+    registration("mlir-to-nvvmir",
+                 [](ModuleOp module, llvm::StringRef outputFilename) {
+                   if (!module)
+                     return failure();
+
+                   auto llvmModule = mlir::translateModuleToNVVMIR(module);
+                   if (!llvmModule)
+                     return failure();
+
+                   auto file = openOutputFile(outputFilename);
+                   if (!file)
+                     return failure();
+
+                   llvmModule->print(file->os(), nullptr);
+                   file->keep();
+                   return success();
+                 });
diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
new file mode 100644
index 0000000..5e1109b
--- /dev/null
+++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -0,0 +1,484 @@
+//===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the translation between an MLIR LLVM dialect module and
+// the corresponding LLVMIR module. It only handles core LLVM IR operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+namespace mlir {
+namespace LLVM {
+
+// Convert an MLIR function type to LLVM IR.  Arguments of the function must of
+// MLIR LLVM IR dialect types.  Use `loc` as a location when reporting errors.
+// Return nullptr on errors.
+static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
+                                               FunctionType type, Location loc,
+                                               bool isVarArgs) {
+  assert(type && "expected non-null type");
+  if (type.getNumResults() > 1)
+    return emitError(loc, "LLVM functions can only have 0 or 1 result"),
+           nullptr;
+
+  SmallVector<llvm::Type *, 8> argTypes;
+  argTypes.reserve(type.getNumInputs());
+  for (auto t : type.getInputs()) {
+    auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
+    if (!wrappedLLVMType)
+      return emitError(loc, "non-LLVM function argument type"), nullptr;
+    argTypes.push_back(wrappedLLVMType.getUnderlyingType());
+  }
+
+  if (type.getNumResults() == 0)
+    return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
+                                   isVarArgs);
+
+  auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
+  if (!wrappedResultType)
+    return emitError(loc, "non-LLVM function result"), nullptr;
+
+  return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
+                                 argTypes, isVarArgs);
+}
+
+// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
+// This currently supports integer, floating point, splat and dense element
+// attributes and combinations thereof.  In case of error, report it to `loc`
+// and return nullptr.
+llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
+                                                   Attribute attr,
+                                                   Location loc) {
+  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+    return llvm::ConstantInt::get(llvmType, intAttr.getValue());
+  if (auto floatAttr = attr.dyn_cast<FloatAttr>())
+    return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
+  if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
+    return functionMapping.lookup(funcAttr.getValue());
+  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+    auto *vectorType = cast<llvm::VectorType>(llvmType);
+    auto *child = getLLVMConstant(vectorType->getElementType(),
+                                  splatAttr.getSplatValue(), loc);
+    return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
+  }
+  if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
+    auto *vectorType = cast<llvm::VectorType>(llvmType);
+    SmallVector<llvm::Constant *, 8> constants;
+    uint64_t numElements = vectorType->getNumElements();
+    constants.reserve(numElements);
+    for (auto n : denseAttr.getAttributeValues()) {
+      constants.push_back(
+          getLLVMConstant(vectorType->getElementType(), n, loc));
+      if (!constants.back())
+        return nullptr;
+    }
+    return llvm::ConstantVector::get(constants);
+  }
+  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+    return llvm::ConstantDataArray::get(
+        llvmModule->getContext(), ArrayRef<char>{stringAttr.getValue().data(),
+                                                 stringAttr.getValue().size()});
+  }
+  emitError(loc, "unsupported constant value");
+  return nullptr;
+}
+
+// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
+  switch (p) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::CmpInst::Predicate::ICMP_EQ;
+  case LLVM::ICmpPredicate::ne:
+    return llvm::CmpInst::Predicate::ICMP_NE;
+  case LLVM::ICmpPredicate::slt:
+    return llvm::CmpInst::Predicate::ICMP_SLT;
+  case LLVM::ICmpPredicate::sle:
+    return llvm::CmpInst::Predicate::ICMP_SLE;
+  case LLVM::ICmpPredicate::sgt:
+    return llvm::CmpInst::Predicate::ICMP_SGT;
+  case LLVM::ICmpPredicate::sge:
+    return llvm::CmpInst::Predicate::ICMP_SGE;
+  case LLVM::ICmpPredicate::ult:
+    return llvm::CmpInst::Predicate::ICMP_ULT;
+  case LLVM::ICmpPredicate::ule:
+    return llvm::CmpInst::Predicate::ICMP_ULE;
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::CmpInst::Predicate::ICMP_UGT;
+  case LLVM::ICmpPredicate::uge:
+    return llvm::CmpInst::Predicate::ICMP_UGE;
+  default:
+    llvm_unreachable("incorrect comparison predicate");
+  }
+}
+
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
+  switch (p) {
+  case LLVM::FCmpPredicate::_false:
+    return llvm::CmpInst::Predicate::FCMP_FALSE;
+  case LLVM::FCmpPredicate::oeq:
+    return llvm::CmpInst::Predicate::FCMP_OEQ;
+  case LLVM::FCmpPredicate::ogt:
+    return llvm::CmpInst::Predicate::FCMP_OGT;
+  case LLVM::FCmpPredicate::oge:
+    return llvm::CmpInst::Predicate::FCMP_OGE;
+  case LLVM::FCmpPredicate::olt:
+    return llvm::CmpInst::Predicate::FCMP_OLT;
+  case LLVM::FCmpPredicate::ole:
+    return llvm::CmpInst::Predicate::FCMP_OLE;
+  case LLVM::FCmpPredicate::one:
+    return llvm::CmpInst::Predicate::FCMP_ONE;
+  case LLVM::FCmpPredicate::ord:
+    return llvm::CmpInst::Predicate::FCMP_ORD;
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::CmpInst::Predicate::FCMP_UEQ;
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::CmpInst::Predicate::FCMP_UGT;
+  case LLVM::FCmpPredicate::uge:
+    return llvm::CmpInst::Predicate::FCMP_UGE;
+  case LLVM::FCmpPredicate::ult:
+    return llvm::CmpInst::Predicate::FCMP_ULT;
+  case LLVM::FCmpPredicate::ule:
+    return llvm::CmpInst::Predicate::FCMP_ULE;
+  case LLVM::FCmpPredicate::une:
+    return llvm::CmpInst::Predicate::FCMP_UNE;
+  case LLVM::FCmpPredicate::uno:
+    return llvm::CmpInst::Predicate::FCMP_UNO;
+  case LLVM::FCmpPredicate::_true:
+    return llvm::CmpInst::Predicate::FCMP_TRUE;
+  default:
+    llvm_unreachable("incorrect comparison predicate");
+  }
+}
+
+// A helper to look up remapped operands in the value remapping table.
+template <typename Range>
+SmallVector<llvm::Value *, 8> ModuleTranslation::lookupValues(Range &&values) {
+  SmallVector<llvm::Value *, 8> remapped;
+  remapped.reserve(llvm::size(values));
+  for (Value *v : values) {
+    remapped.push_back(valueMapping.lookup(v));
+  }
+  return remapped;
+}
+
+// Given a single MLIR operation, create the corresponding LLVM IR operation
+// using the `builder`.  LLVM IR Builder does not have a generic interface so
+// this has to be a long chain of `if`s calling different functions with a
+// different number of arguments.
+LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
+                                                  llvm::IRBuilder<> &builder) {
+  auto extractPosition = [](ArrayAttr attr) {
+    SmallVector<unsigned, 4> position;
+    position.reserve(attr.size());
+    for (Attribute v : attr)
+      position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
+    return position;
+  };
+
+#include "mlir/LLVMIR/LLVMConversions.inc"
+
+  // Emit function calls.  If the "callee" attribute is present, this is a
+  // direct function call and we also need to look up the remapped function
+  // itself.  Otherwise, this is an indirect call and the callee is the first
+  // operand, look it up as a normal value.  Return the llvm::Value representing
+  // the function result, which may be of llvm::VoidTy type.
+  auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
+    auto operands = lookupValues(op.getOperands());
+    ArrayRef<llvm::Value *> operandsRef(operands);
+    if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
+      return builder.CreateCall(functionMapping.lookup(attr.getValue()),
+                                operandsRef);
+    } else {
+      return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
+    }
+  };
+
+  // Emit calls.  If the called function has a result, remap the corresponding
+  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
+  if (isa<LLVM::CallOp>(opInst)) {
+    llvm::Value *result = convertCall(opInst);
+    if (opInst.getNumResults() != 0) {
+      valueMapping[opInst.getResult(0)] = result;
+      return success();
+    }
+    // Check that LLVM call returns void for 0-result functions.
+    return success(result->getType()->isVoidTy());
+  }
+
+  // Emit branches.  We need to look up the remapped blocks and ignore the block
+  // arguments that were transformed into PHI nodes.
+  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
+    builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
+    return success();
+  }
+  if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+    builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
+                         blockMapping[condbrOp.getSuccessor(0)],
+                         blockMapping[condbrOp.getSuccessor(1)]);
+    return success();
+  }
+
+  return opInst.emitError("unsupported or non-LLVM operation: ")
+         << opInst.getName();
+}
+
+// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
+// to define values corresponding to the MLIR block arguments.  These nodes
+// are not connected to the source basic blocks, which may not exist yet.
+LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
+  llvm::IRBuilder<> builder(blockMapping[&bb]);
+
+  // Before traversing operations, make block arguments available through
+  // value remapping and PHI nodes, but do not add incoming edges for the PHI
+  // nodes just yet: those values may be defined by this or following blocks.
+  // This step is omitted if "ignoreArguments" is set.  The arguments of the
+  // first block have been already made available through the remapping of
+  // LLVM function arguments.
+  if (!ignoreArguments) {
+    auto predecessors = bb.getPredecessors();
+    unsigned numPredecessors =
+        std::distance(predecessors.begin(), predecessors.end());
+    for (auto *arg : bb.getArguments()) {
+      auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>();
+      if (!wrappedType)
+        return emitError(bb.front().getLoc(),
+                         "block argument does not have an LLVM type");
+      llvm::Type *type = wrappedType.getUnderlyingType();
+      llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
+      valueMapping[arg] = phi;
+    }
+  }
+
+  // Traverse operations.
+  for (auto &op : bb) {
+    if (failed(convertOperation(op, builder)))
+      return failure();
+  }
+
+  return success();
+}
+
+// Create named global variables that correspond to llvm.global definitions.
+void ModuleTranslation::convertGlobals() {
+  for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
+    // String attributes are treated separately because they cannot appear as
+    // in-function constants and are thus not supported by getLLVMConstant.
+    if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
+      llvm::Constant *cst = llvm::ConstantDataArray::getString(
+          llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
+      new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(),
+                               llvm::GlobalValue::InternalLinkage, cst,
+                               op.sym_name());
+      return;
+    }
+
+    llvm::Type *type = op.getType().getUnderlyingType();
+    new llvm::GlobalVariable(
+        *llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
+        getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name());
+  }
+}
+
+// Get the SSA value passed to the current block from the terminator operation
+// of its predecessor.
+static Value *getPHISourceValue(Block *current, Block *pred,
+                                unsigned numArguments, unsigned index) {
+  auto &terminator = *pred->getTerminator();
+  if (isa<LLVM::BrOp>(terminator)) {
+    return terminator.getOperand(index);
+  }
+
+  // For conditional branches, we need to check if the current block is reached
+  // through the "true" or the "false" branch and take the relevant operands.
+  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
+  assert(condBranchOp &&
+         "only branch operations can be terminators of a block that "
+         "has successors");
+  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+         "successors with arguments in LLVM conditional branches must be "
+         "different blocks");
+
+  return condBranchOp.getSuccessor(0) == current
+             ? terminator.getSuccessorOperand(0, index)
+             : terminator.getSuccessorOperand(1, index);
+}
+
+void ModuleTranslation::connectPHINodes(FuncOp func) {
+  // Skip the first block, it cannot be branched to and its arguments correspond
+  // to the arguments of the LLVM function.
+  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+    Block *bb = &*it;
+    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+    auto phis = llvmBB->phis();
+    auto numArguments = bb->getNumArguments();
+    assert(numArguments == std::distance(phis.begin(), phis.end()));
+    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
+      auto &phiNode = numberedPhiNode.value();
+      unsigned index = numberedPhiNode.index();
+      for (auto *pred : bb->getPredecessors()) {
+        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
+                                bb, pred, numArguments, index)),
+                            blockMapping.lookup(pred));
+      }
+    }
+  }
+}
+
+// TODO(mlir-team): implement an iterative version
+static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
+  blocks.insert(b);
+  for (Block *bb : b->getSuccessors()) {
+    if (blocks.count(bb) == 0)
+      topologicalSortImpl(blocks, bb);
+  }
+}
+
+// Sort function blocks topologically.
+static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
+  // For each blocks that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list and traverse its successors in DFS
+  // preorder.
+  llvm::SetVector<Block *> blocks;
+  for (Block &b : f.getBlocks()) {
+    if (blocks.count(&b) == 0)
+      topologicalSortImpl(blocks, &b);
+  }
+  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+
+  return blocks;
+}
+
+LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
+  // Clear the block and value mappings, they are only relevant within one
+  // function.
+  blockMapping.clear();
+  valueMapping.clear();
+  llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
+  // Add function arguments to the value remapping table.
+  // If there was noalias info then we decorate each argument accordingly.
+  unsigned int argIdx = 0;
+  for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
+    llvm::Argument &llvmArg = std::get<1>(kvp);
+    BlockArgument *mlirArg = std::get<0>(kvp);
+
+    if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) {
+      // NB: Attribute already verified to be boolean, so check if we can indeed
+      // attach the attribute to this argument, based on its type.
+      auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>();
+      if (!argTy.getUnderlyingType()->isPointerTy())
+        return func.emitError(
+            "llvm.noalias attribute attached to LLVM non-pointer argument");
+      if (attr.getValue())
+        llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
+    }
+    valueMapping[mlirArg] = &llvmArg;
+    argIdx++;
+  }
+
+  // First, create all blocks so we can jump to them.
+  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+  for (auto &bb : func) {
+    auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
+    llvmBB->insertInto(llvmFunc);
+    blockMapping[&bb] = llvmBB;
+  }
+
+  // Then, convert blocks one by one in topological order to ensure defs are
+  // converted before uses.
+  auto blocks = topologicalSort(func);
+  for (auto indexedBB : llvm::enumerate(blocks)) {
+    auto *bb = indexedBB.value();
+    if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+      return failure();
+  }
+
+  // Finally, after all blocks have been traversed and values mapped, connect
+  // the PHI nodes to the results of preceding blocks.
+  connectPHINodes(func);
+  return success();
+}
+
+LogicalResult ModuleTranslation::convertFunctions() {
+  // Declare all functions first because there may be function calls that form a
+  // call graph with cycles.
+  for (FuncOp function : mlirModule.getOps<FuncOp>()) {
+    mlir::BoolAttr isVarArgsAttr =
+        function.getAttrOfType<BoolAttr>("std.varargs");
+    bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
+    llvm::FunctionType *functionType =
+        convertFunctionType(llvmModule->getContext(), function.getType(),
+                            function.getLoc(), isVarArgs);
+    if (!functionType)
+      return failure();
+    llvm::FunctionCallee llvmFuncCst =
+        llvmModule->getOrInsertFunction(function.getName(), functionType);
+    assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
+    functionMapping[function.getName()] =
+        cast<llvm::Function>(llvmFuncCst.getCallee());
+  }
+
+  // Convert functions.
+  for (FuncOp function : mlirModule.getOps<FuncOp>()) {
+    // Ignore external functions.
+    if (function.isExternal())
+      continue;
+
+    if (failed(convertOneFunction(function)))
+      return failure();
+  }
+
+  return success();
+}
+
+std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
+  auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+  assert(dialect && "LLVM dialect must be registered");
+
+  auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());
+  if (!llvmModule)
+    return nullptr;
+
+  llvm::LLVMContext &llvmContext = llvmModule->getContext();
+  llvm::IRBuilder<> builder(llvmContext);
+
+  // Inject declarations for `malloc` and `free` functions that can be used in
+  // memref allocation/deallocation coming from standard ops lowering.
+  llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
+                                  builder.getInt64Ty());
+  llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
+                                  builder.getInt8PtrTy());
+
+  return llvmModule;
+}
+
+} // namespace LLVM
+} // namespace mlir
diff --git a/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
new file mode 100644
index 0000000..522ed4a
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
@@ -0,0 +1,892 @@
+//===- AffineDataCopyGeneration.cpp - Explicit memref copying pass ------*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to automatically promote accessed memref regions
+// to buffers in a faster memory space that is explicitly managed, with the
+// necessary data movement operations performed through either regular
+// point-wise load/store's or DMAs. Such explicit copying (also referred to as
+// array packing/unpacking in the literature), when done on arrays that exhibit
+// reuse, results in near elimination of conflict misses, TLB misses, reduced
+// use of hardware prefetch streams, and reduced false sharing. It is also
+// necessary for hardware that explicitly managed levels in the memory
+// hierarchy, and where DMAs may have to be used. This optimization is often
+// performed on already tiled code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "affine-data-copy-generate"
+
+using namespace mlir;
+using llvm::SmallMapVector;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<unsigned long long> clFastMemoryCapacity(
+    "affine-data-copy-generate-fast-mem-capacity",
+    llvm::cl::desc(
+        "Set fast memory space capacity in KiB (default: unlimited)"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool>
+    clDma("affine-data-copy-generate-dma",
+          llvm::cl::desc("Generate DMA instead of point-wise copy"),
+          llvm::cl::cat(clOptionsCategory),
+          llvm::cl::init(true));
+
+static llvm::cl::opt<unsigned> clFastMemorySpace(
+    "affine-data-copy-generate-fast-mem-space", llvm::cl::init(0),
+    llvm::cl::desc(
+        "Fast memory space identifier for copy generation (default: 1)"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool> clSkipNonUnitStrideLoop(
+    "affine-data-copy-generate-skip-non-unit-stride-loops", llvm::cl::Hidden,
+    llvm::cl::init(false),
+    llvm::cl::desc("Testing purposes: avoid non-unit stride loop choice depths "
+                   "for copy placement"),
+    llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// Replaces all loads and stores on memref's living in 'slowMemorySpace' by
+/// introducing copy operations to transfer data into `fastMemorySpace` and
+/// rewriting the original load's/store's to instead load/store from the
+/// allocated fast memory buffers. Additional options specify the identifier
+/// corresponding to the fast memory space and the amount of fast memory space
+/// available. The pass traverses through the nesting structure, recursing to
+/// inner levels if necessary to determine at what depth copies need to be
+/// placed so that the allocated buffers fit within the memory capacity
+/// provided.
+// TODO(bondhugula): We currently can't generate copies correctly when stores
+// are strided. Check for strided stores.
+struct AffineDataCopyGeneration
+    : public FunctionPass<AffineDataCopyGeneration> {
+  explicit AffineDataCopyGeneration(
+      unsigned slowMemorySpace = 0,
+      unsigned fastMemorySpace = clFastMemorySpace, unsigned tagMemorySpace = 0,
+      int minDmaTransferSize = 1024,
+      uint64_t fastMemCapacityBytes =
+          (clFastMemoryCapacity.getNumOccurrences() > 0
+               ? clFastMemoryCapacity * 1024 // cl-provided size is in KiB
+               : std::numeric_limits<uint64_t>::max()),
+      bool generateDma = clDma,
+      bool skipNonUnitStrideLoops = clSkipNonUnitStrideLoop)
+      : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace),
+        tagMemorySpace(tagMemorySpace), minDmaTransferSize(minDmaTransferSize),
+        fastMemCapacityBytes(fastMemCapacityBytes), generateDma(generateDma),
+        skipNonUnitStrideLoops(skipNonUnitStrideLoops) {}
+
+  explicit AffineDataCopyGeneration(const AffineDataCopyGeneration &other)
+      : slowMemorySpace(other.slowMemorySpace),
+        fastMemorySpace(other.fastMemorySpace),
+        tagMemorySpace(other.tagMemorySpace),
+        minDmaTransferSize(other.minDmaTransferSize),
+        fastMemCapacityBytes(other.fastMemCapacityBytes),
+        generateDma(other.generateDma),
+        skipNonUnitStrideLoops(other.skipNonUnitStrideLoops) {}
+
+  void runOnFunction() override;
+  LogicalResult runOnBlock(Block *block);
+  uint64_t runOnBlock(Block::iterator begin, Block::iterator end);
+
+  LogicalResult generateCopy(const MemRefRegion &region, Block *block,
+                             Block::iterator begin, Block::iterator end,
+                             uint64_t *sizeInBytes, Block::iterator *nBegin,
+                             Block::iterator *nEnd);
+
+  // List of memory regions to copy for. We need a map vector to have a
+  // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
+  // since the alloc's for example are identical except for the SSA id.
+  SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4> readRegions;
+  SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4> writeRegions;
+
+  // Nests that are copy in's or copy out's; the root AffineForOp of that
+  // nest is stored herein.
+  DenseSet<Operation *> copyNests;
+
+  // Map from original memref's to the fast buffers that their accesses are
+  // replaced with.
+  DenseMap<Value *, Value *> fastBufferMap;
+
+  // Slow memory space associated with copies.
+  const unsigned slowMemorySpace;
+  // Fast memory space associated with copies.
+  unsigned fastMemorySpace;
+  // Memory space associated with DMA tags.
+  unsigned tagMemorySpace;
+  // Minimum DMA transfer size supported by the target in bytes.
+  const int minDmaTransferSize;
+  // Capacity of the faster memory space.
+  uint64_t fastMemCapacityBytes;
+
+  // If set, generate DMA operations instead of read/write.
+  bool generateDma;
+
+  // If set, ignore loops with steps other than 1.
+  bool skipNonUnitStrideLoops;
+
+  // Constant zero index to avoid too many duplicates.
+  Value *zeroIndex = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Generates copies for memref's living in 'slowMemorySpace' into newly created
+/// buffers in 'fastMemorySpace', and replaces memory operations to the former
+/// by the latter. Only load op's handled for now.
+/// TODO(bondhugula): extend this to store op's.
+FunctionPassBase *mlir::createAffineDataCopyGenerationPass(
+    unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace,
+    int minDmaTransferSize, uint64_t fastMemCapacityBytes) {
+  return new AffineDataCopyGeneration(slowMemorySpace, fastMemorySpace,
+                                      tagMemorySpace, minDmaTransferSize,
+                                      fastMemCapacityBytes);
+}
+
+// Info comprising stride and number of elements transferred every stride.
+struct StrideInfo {
+  int64_t stride;
+  int64_t numEltPerStride;
+};
+
+/// Returns striding information for a copy/transfer of this region with
+/// potentially multiple striding levels from outermost to innermost. For an
+/// n-dimensional region, there can be at most n-1 levels of striding
+/// successively nested.
+//  TODO(bondhugula): make this work with non-identity layout maps.
+static void getMultiLevelStrides(const MemRefRegion &region,
+                                 ArrayRef<int64_t> bufferShape,
+                                 SmallVectorImpl<StrideInfo> *strideInfos) {
+  if (bufferShape.size() <= 1)
+    return;
+
+  int64_t numEltPerStride = 1;
+  int64_t stride = 1;
+  for (int d = bufferShape.size() - 1; d >= 1; d--) {
+    int64_t dimSize = region.memref->getType().cast<MemRefType>().getDimSize(d);
+    stride *= dimSize;
+    numEltPerStride *= bufferShape[d];
+    // A stride is needed only if the region has a shorter extent than the
+    // memref along the dimension *and* has an extent greater than one along the
+    // next major dimension.
+    if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
+      strideInfos->push_back({stride, numEltPerStride});
+    }
+  }
+}
+
+/// Construct the memref region to just include the entire memref. Returns false
+/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
+/// enclosing loop IVs of opInst (starting from the outermost) that the region
+/// is parametric on.
+static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
+                                  MemRefRegion *region) {
+  unsigned rank;
+  if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+    rank = loadOp.getMemRefType().getRank();
+    region->memref = loadOp.getMemRef();
+    region->setWrite(false);
+  } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+    rank = storeOp.getMemRefType().getRank();
+    region->memref = storeOp.getMemRef();
+    region->setWrite(true);
+  } else {
+    assert(false && "expected load or store op");
+    return false;
+  }
+  auto memRefType = region->memref->getType().cast<MemRefType>();
+  if (!memRefType.hasStaticShape())
+    return false;
+
+  auto *regionCst = region->getConstraints();
+
+  // Just get the first numSymbols IVs, which the memref region is parametric
+  // on.
+  SmallVector<AffineForOp, 4> ivs;
+  getLoopIVs(*opInst, &ivs);
+  ivs.resize(numParamLoopIVs);
+  SmallVector<Value *, 4> symbols;
+  extractForInductionVars(ivs, &symbols);
+  regionCst->reset(rank, numParamLoopIVs, 0);
+  regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
+
+  // Memref dim sizes provide the bounds.
+  for (unsigned d = 0; d < rank; d++) {
+    auto dimSize = memRefType.getDimSize(d);
+    assert(dimSize > 0 && "filtered dynamic shapes above");
+    regionCst->addConstantLowerBound(d, 0);
+    regionCst->addConstantUpperBound(d, dimSize - 1);
+  }
+  return true;
+}
+
+static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
+emitRemarkForBlock(Block &block) {
+  return block.getParentOp()->emitRemark();
+}
+
+/// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
+/// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart'
+/// holds the lower coordinates of the region in the original memref to copy
+/// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in.
+static AffineForOp generatePointWiseCopy(Location loc, Value *memref,
+                                         Value *fastMemRef,
+                                         ArrayRef<Value *> memIndicesStart,
+                                         ArrayRef<int64_t> fastBufferShape,
+                                         bool isCopyOut, OpBuilder b) {
+  assert(!memIndicesStart.empty() && "only 1-d or more memrefs");
+
+  // The copy-in nest is generated as follows as an example for a 2-d region:
+  // for x = ...
+  //   for y = ...
+  //     fast_buf[x][y] = buf[mem_x + x][mem_y + y]
+
+  SmallVector<Value *, 4> fastBufIndices, memIndices;
+  AffineForOp copyNestRoot;
+  for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) {
+    auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]);
+    if (d == 0)
+      copyNestRoot = forOp;
+    b = forOp.getBodyBuilder();
+    fastBufIndices.push_back(forOp.getInductionVar());
+    // Construct the subscript for the slow memref being copied.
+    SmallVector<Value *, 2> operands = {memIndicesStart[d], forOp.getInductionVar()};
+    auto memIndex = b.create<AffineApplyOp>(
+        loc,
+        b.getAffineMap(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)),
+        operands);
+    memIndices.push_back(memIndex);
+  }
+
+  if (!isCopyOut) {
+    // Copy in.
+    auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
+    b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufIndices);
+    return copyNestRoot;
+  }
+
+  // Copy out.
+  auto load = b.create<AffineLoadOp>(loc, fastMemRef, fastBufIndices);
+  b.create<AffineStoreOp>(loc, load, memref, memIndices);
+  return copyNestRoot;
+}
+
+/// Creates a buffer in the faster memory space for the specified region;
+/// generates a copy from the lower memory space to this one, and replaces all
+/// loads to load from that buffer. Returns failure if copies could not be
+/// generated due to yet unimplemented cases. `begin` and `end` specify the
+/// insertion points where the incoming copies and outgoing copies,
+/// respectively, should be inserted (the insertion happens right before the
+/// insertion point). Since `begin` can itself be invalidated due to the memref
+/// rewriting done from this method, the output argument `nBegin` is set to its
+/// replacement (set to `begin` if no invalidation happens). Since outgoing
+/// copies are inserted at `end`, the output argument `nEnd` is set to the one
+/// following the original end (since the latter could have been
+/// invalidated/replaced). `sizeInBytes` is set to the size of the fast buffer
+/// allocated.
+LogicalResult AffineDataCopyGeneration::generateCopy(
+    const MemRefRegion &region, Block *block, Block::iterator begin,
+    Block::iterator end, uint64_t *sizeInBytes, Block::iterator *nBegin,
+    Block::iterator *nEnd) {
+  *nBegin = begin;
+  *nEnd = end;
+
+  if (begin == end)
+    return success();
+
+  // Copies for read regions are going to be inserted at 'begin'.
+  OpBuilder prologue(block, begin);
+  // Copies for write regions are going to be inserted at 'end'.
+  OpBuilder epilogue(block, end);
+  OpBuilder &b = region.isWrite() ? epilogue : prologue;
+
+  // Builder to create constants at the top level.
+  auto func = block->getParent()->getParentOfType<FuncOp>();
+  OpBuilder top(func.getBody());
+
+  auto loc = region.loc;
+  auto *memref = region.memref;
+  auto memRefType = memref->getType().cast<MemRefType>();
+
+  auto layoutMaps = memRefType.getAffineMaps();
+  if (layoutMaps.size() > 1 ||
+      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+    return failure();
+  }
+
+  // Indices to use for the copying.
+  // Indices for the original memref being copied from/to.
+  SmallVector<Value *, 4> memIndices;
+  // Indices for the faster buffer being copied into/from.
+  SmallVector<Value *, 4> bufIndices;
+
+  unsigned rank = memRefType.getRank();
+  SmallVector<int64_t, 4> fastBufferShape;
+
+  // Compute the extents of the buffer.
+  std::vector<SmallVector<int64_t, 4>> lbs;
+  SmallVector<int64_t, 8> lbDivisors;
+  lbs.reserve(rank);
+  Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
+      &fastBufferShape, &lbs, &lbDivisors);
+  if (!numElements.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
+    return failure();
+  }
+
+  if (numElements.getValue() == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
+    *sizeInBytes = 0;
+    return success();
+  }
+
+  const FlatAffineConstraints *cst = region.getConstraints();
+  // 'regionSymbols' hold values that this memory region is symbolic/paramteric
+  // on; these typically include loop IVs surrounding the level at which the
+  // copy generation is being done or other valid symbols in MLIR.
+  SmallVector<Value *, 8> regionSymbols;
+  cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
+
+  // Construct the index expressions for the fast memory buffer. The index
+  // expression for a particular dimension of the fast buffer is obtained by
+  // subtracting out the lower bound on the original memref's data region
+  // along the corresponding dimension.
+
+  // Index start offsets for faster memory buffer relative to the original.
+  SmallVector<AffineExpr, 4> offsets;
+  offsets.reserve(rank);
+  for (unsigned d = 0; d < rank; d++) {
+    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
+
+    AffineExpr offset = top.getAffineConstantExpr(0);
+    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
+      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
+    }
+    assert(lbDivisors[d] > 0);
+    offset =
+        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
+
+    // Set copy start location for this dimension in the lower memory space
+    // memref.
+    if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
+      auto indexVal = caf.getValue();
+      if (indexVal == 0) {
+        memIndices.push_back(zeroIndex);
+      } else {
+        memIndices.push_back(
+            top.create<ConstantIndexOp>(loc, indexVal).getResult());
+      }
+    } else {
+      // The coordinate for the start location is just the lower bound along the
+      // corresponding dimension on the memory region (stored in 'offset').
+      auto map = top.getAffineMap(
+          cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
+      memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
+    }
+    // The fast buffer is copied into at location zero; addressing is relative.
+    bufIndices.push_back(zeroIndex);
+
+    // Record the offsets since they are needed to remap the memory accesses of
+    // the original memref further below.
+    offsets.push_back(offset);
+  }
+
+  // The faster memory space buffer.
+  Value *fastMemRef;
+
+  // Check if a buffer was already created.
+  bool existingBuf = fastBufferMap.count(memref) > 0;
+  if (!existingBuf) {
+    auto fastMemRefType = top.getMemRefType(
+        fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace);
+
+    // Create the fast memory space buffer just before the 'affine.for'
+    // operation.
+    fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult();
+    // Record it.
+    fastBufferMap[memref] = fastMemRef;
+    // fastMemRefType is a constant shaped memref.
+    *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
+    LLVM_DEBUG(emitRemarkForBlock(*block)
+               << "Creating fast buffer of type " << fastMemRefType
+               << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
+               << " KiB\n");
+  } else {
+    // Reuse the one already created.
+    fastMemRef = fastBufferMap[memref];
+    *sizeInBytes = 0;
+  }
+
+  auto numElementsSSA =
+      top.create<ConstantIndexOp>(loc, numElements.getValue());
+
+  SmallVector<StrideInfo, 4> strideInfos;
+  getMultiLevelStrides(region, fastBufferShape, &strideInfos);
+
+  // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
+  // multi-level strides.
+  if (strideInfos.size() > 1) {
+    LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+    return failure();
+  }
+
+  Value *stride = nullptr;
+  Value *numEltPerStride = nullptr;
+  if (!strideInfos.empty()) {
+    stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
+    numEltPerStride =
+        top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
+  }
+
+  // Record the last operation just before the point where we insert the
+  // copy out's. We later do the memref replacement later only in [begin,
+  // postDomFilter] so that the original memref's in the data movement code
+  // themselves don't get replaced.
+  auto postDomFilter = std::prev(end);
+
+  // Create fully composed affine maps for each memref.
+  auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
+  fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
+  auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
+  fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
+
+  if (!generateDma) {
+    auto copyNest = generatePointWiseCopy(loc, memref, fastMemRef, memIndices,
+                                          fastBufferShape,
+                                          /*isCopyOut=*/region.isWrite(), b);
+
+    // Record this so that we can skip it from yet another copy.
+    copyNests.insert(copyNest);
+
+    if (region.isWrite())
+      // Since new ops are being appended (for copy out's), adjust the end to
+      // mark end of block range being processed.
+      *nEnd = Block::iterator(copyNest.getOperation());
+  } else {
+    // Create a tag (single element 1-d memref) for the DMA.
+    auto tagMemRefType =
+        top.getMemRefType({1}, top.getIntegerType(32), {}, tagMemorySpace);
+    auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
+
+    SmallVector<Value *, 4> tagIndices({zeroIndex});
+    auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
+    fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
+    if (!region.isWrite()) {
+      // DMA non-blocking read from original buffer to fast buffer.
+      b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
+                                 fastMemRef, bufAffineMap, bufIndices,
+                                 tagMemRef, tagAffineMap, tagIndices,
+                                 numElementsSSA, stride, numEltPerStride);
+    } else {
+      // DMA non-blocking write from fast buffer to the original memref.
+      auto op = b.create<AffineDmaStartOp>(
+          loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
+          memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
+          stride, numEltPerStride);
+      // Since new ops are being appended (for outgoing DMAs), adjust the end to
+      // mark end of block range being processed.
+      *nEnd = Block::iterator(op.getOperation());
+    }
+
+    // Matching DMA wait to block on completion; tag always has a 0 index.
+    b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
+                              numElementsSSA);
+
+    // Generate dealloc for the tag.
+    auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
+    if (*nEnd == end)
+      // Since new ops are being appended (for outgoing DMAs), adjust the end to
+      // mark end of range of the original.
+      *nEnd = Block::iterator(tagDeallocOp.getOperation());
+  }
+
+  // Generate dealloc for the buffer.
+  if (!existingBuf) {
+    auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
+    // When generating pointwise copies, `nEnd' has to be set to deallocOp on
+    // the fast buffer (since it marks the new end insertion point).
+    if (!generateDma && *nEnd == end)
+      *nEnd = Block::iterator(bufDeallocOp.getOperation());
+  }
+
+  // Replace all uses of the old memref with the faster one while remapping
+  // access indices (subtracting out lower bound offsets for each dimension).
+  // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
+  // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
+  // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
+  // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
+  // d2, d3 correspond to the original indices (%i, %j).
+  SmallVector<AffineExpr, 4> remapExprs;
+  remapExprs.reserve(rank);
+  for (unsigned i = 0; i < rank; i++) {
+    // The starting operands of indexRemap will be regionSymbols (the symbols on
+    // which the memref region is parametric); then those corresponding to
+    // the memref's original indices follow.
+    auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
+    remapExprs.push_back(dimExpr - offsets[i]);
+  }
+  auto indexRemap = b.getAffineMap(regionSymbols.size() + rank, 0, remapExprs);
+
+  // Record the begin since it may be invalidated by memref replacement.
+  Block::iterator prev;
+  bool wasAtStartOfBlock = (begin == block->begin());
+  if (!wasAtStartOfBlock)
+    prev = std::prev(begin);
+
+  // *Only* those uses within the range [begin, end) of 'block' are replaced.
+  replaceAllMemRefUsesWith(memref, fastMemRef,
+                           /*extraIndices=*/{}, indexRemap,
+                           /*extraOperands=*/regionSymbols,
+                           /*domInstFilter=*/&*begin,
+                           /*postDomInstFilter=*/&*postDomFilter);
+
+  *nBegin = wasAtStartOfBlock ? block->begin() : std::next(prev);
+
+  return success();
+}
+
+/// Generate copies for this block. The block is partitioned into separate
+/// ranges: each range is either a sequence of one or more operations starting
+/// and ending with an affine load or store op, or just an affine.forop (which
+/// could have other affine for op's nested within).
+LogicalResult AffineDataCopyGeneration::runOnBlock(Block *block) {
+  if (block->empty())
+    return success();
+
+  copyNests.clear();
+
+  // Every affine.forop in the block starts and ends a block range for copying.
+  // A contiguous sequence of operations starting and ending with a load/store
+  // op is also identified as a copy block range. Straightline code (a
+  // contiguous chunk of operations excluding AffineForOp's) are always assumed
+  // to not exhaust memory. As a result, this approach is conservative in some
+  // cases at the moment; we do a check later and report an error with location
+  // info.
+  // TODO(bondhugula): An 'affine.if' operation is being treated similar to an
+  // operation. 'affine.if''s could have 'affine.for's in them;
+  // treat them separately.
+
+  // Get to the first load, store, or for op (that is not a copy nest itself).
+  auto curBegin =
+      std::find_if(block->begin(), block->end(), [&](Operation &op) {
+        return (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+                isa<AffineForOp>(op)) &&
+               copyNests.count(&op) == 0;
+      });
+
+  for (auto it = curBegin; it != block->end(); ++it) {
+    AffineForOp forOp;
+    if ((forOp = dyn_cast<AffineForOp>(&*it)) && copyNests.count(forOp) == 0) {
+      // Returns true if the footprint is known to exceed capacity.
+      auto exceedsCapacity = [&](AffineForOp forOp) {
+        Optional<int64_t> footprint =
+            getMemoryFootprintBytes(forOp,
+                                    /*memorySpace=*/0);
+        return (footprint.hasValue() &&
+                static_cast<uint64_t>(footprint.getValue()) >
+                    fastMemCapacityBytes);
+      };
+
+      // If the memory footprint of the 'affine.for' loop is higher than fast
+      // memory capacity (when provided), we recurse to copy at an inner level
+      // until we find a depth at which footprint fits in fast mem capacity. If
+      // the footprint can't be calculated, we assume for now it fits. Recurse
+      // inside if footprint for 'forOp' exceeds capacity, or when
+      // skipNonUnitStrideLoops is set and the step size is not one.
+      bool recurseInner = skipNonUnitStrideLoops ? forOp.getStep() != 1
+                                                 : exceedsCapacity(forOp);
+      if (recurseInner) {
+        // We'll recurse and do the copies at an inner level for 'forInst'.
+        runOnBlock(/*begin=*/curBegin, /*end=*/it);
+        // Recurse onto the body of this loop.
+        runOnBlock(forOp.getBody());
+        // The next block range starts right after the 'affine.for' operation.
+        curBegin = std::next(it);
+      } else {
+        // We have enough capacity, i.e., copies will be computed for the
+        // portion of the block until 'it', and for 'it', which is 'forOp'. Note
+        // that for the latter, the copies are placed just before this loop (for
+        // incoming copies) and right after (for outgoing ones).
+        runOnBlock(/*begin=*/curBegin, /*end=*/it);
+
+        // Inner loop copies have their own scope - we don't thus update
+        // consumed capacity. The footprint check above guarantees this inner
+        // loop's footprint fits.
+        runOnBlock(/*begin=*/it, /*end=*/std::next(it));
+        curBegin = std::next(it);
+      }
+    } else if (!isa<AffineLoadOp>(&*it) && !isa<AffineStoreOp>(&*it)) {
+      runOnBlock(/*begin=*/curBegin, /*end=*/it);
+      curBegin = std::next(it);
+    }
+  }
+
+  // Generate the copy for the final block range.
+  if (curBegin != block->end()) {
+    // Can't be a terminator because it would have been skipped above.
+    assert(!curBegin->isKnownTerminator() && "can't be a terminator");
+    runOnBlock(/*begin=*/curBegin, /*end=*/block->end());
+  }
+
+  return success();
+}
+
+/// Given a memref region, determine the lowest depth at which transfers can be
+/// placed for it, and return the corresponding block, start and end positions
+/// in the block for placing incoming (read) and outgoing (write) copies
+/// respectively. The lowest depth depends on whether the region being accessed
+/// is invariant with respect to one or more immediately surrounding loops.
+static void
+findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
+                             Block::iterator &begin, Block::iterator &end,
+                             Block **copyPlacementBlock,
+                             Block::iterator *copyPlacementReadStart,
+                             Block::iterator *copyPlacementWriteStart) {
+  const auto *cst = region.getConstraints();
+  SmallVector<Value *, 4> symbols;
+  cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
+
+  SmallVector<AffineForOp, 4> enclosingFors;
+  getLoopIVs(*block.begin(), &enclosingFors);
+  // Walk up loop parents till we find an IV on which this region is
+  // symbolic/variant.
+  auto it = enclosingFors.rbegin();
+  for (auto e = enclosingFors.rend(); it != e; ++it) {
+    // TODO(bondhugula): also need to be checking this for regions symbols that
+    // aren't loop IVs, whether we are within their resp. defs' dominance scope.
+    if (llvm::is_contained(symbols, it->getInductionVar()))
+      break;
+  }
+
+  if (it != enclosingFors.rbegin()) {
+    auto lastInvariantIV = *std::prev(it);
+    *copyPlacementReadStart = Block::iterator(lastInvariantIV.getOperation());
+    *copyPlacementWriteStart = std::next(*copyPlacementReadStart);
+    *copyPlacementBlock = lastInvariantIV.getOperation()->getBlock();
+  } else {
+    *copyPlacementReadStart = begin;
+    *copyPlacementWriteStart = end;
+    *copyPlacementBlock = &block;
+  }
+}
+
+/// Generates copies for a contiguous sequence of operations in `block` in the
+/// iterator range [begin, end). Returns the total size of the fast buffers
+/// used.
+//  Since we generate alloc's and dealloc's for all fast buffers (before and
+//  after the range of operations resp.), all of the fast memory capacity is
+//  assumed to be available for processing this block range.
+uint64_t AffineDataCopyGeneration::runOnBlock(Block::iterator begin,
+                                              Block::iterator end) {
+  if (begin == end)
+    return 0;
+
+  assert(begin->getBlock() == std::prev(end)->getBlock() &&
+         "Inconsistent args");
+
+  Block *block = begin->getBlock();
+
+  // Copies will be generated for this depth, i.e., symbolic in all loops
+  // surrounding the this block range.
+  unsigned copyDepth = getNestingDepth(*begin);
+
+  LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
+                          << "\n");
+
+  readRegions.clear();
+  writeRegions.clear();
+  fastBufferMap.clear();
+
+  // To check for errors when walking the block.
+  bool error = false;
+
+  // Walk this range of operations  to gather all memory regions.
+  block->walk(begin, end, [&](Operation *opInst) {
+    // Gather regions to allocate to buffers in faster memory space.
+    if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+      if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace)
+        return;
+    } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+      if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace)
+        return;
+    } else {
+      // Neither load nor a store op.
+      return;
+    }
+
+    // Compute the MemRefRegion accessed.
+    auto region = llvm::make_unique<MemRefRegion>(opInst->getLoc());
+    if (failed(region->compute(opInst, copyDepth))) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Error obtaining memory region: semi-affine maps?\n");
+      LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
+      if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
+        LLVM_DEBUG(
+            opInst->emitError("Non-constant memref sizes not yet supported"));
+        error = true;
+        return;
+      }
+    }
+
+    // Each memref has a single buffer associated with it irrespective of how
+    // many load's and store's happen on it.
+    // TODO(bondhugula): in the future, when regions don't intersect and satisfy
+    // other properties (based on load/store regions), we could consider
+    // multiple buffers per memref.
+
+    // Add to the appropriate region if it's not already in it, or take a
+    // bounding box union with the existing one if it's already in there.
+    // Note that a memref may have both read and write regions - so update the
+    // region in the other list if one exists (write in case of read and vice
+    // versa) since there is a single bounding box for a memref across all reads
+    // and writes that happen on it.
+
+    // Attempts to update; returns true if 'region' exists in targetRegions.
+    auto updateRegion =
+        [&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
+                &targetRegions) {
+          auto it = targetRegions.find(region->memref);
+          if (it == targetRegions.end())
+            return false;
+
+          // Perform a union with the existing region.
+          if (failed(it->second->unionBoundingBox(*region))) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Memory region bounding box failed; "
+                          "over-approximating to the entire memref\n");
+            // If the union fails, we will overapproximate.
+            if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
+              LLVM_DEBUG(opInst->emitError(
+                  "Non-constant memref sizes not yet supported"));
+              error = true;
+              return true;
+            }
+            it->second->getConstraints()->clearAndCopyFrom(
+                *region->getConstraints());
+          } else {
+            // Union was computed and stored in 'it->second': copy to 'region'.
+            region->getConstraints()->clearAndCopyFrom(
+                *it->second->getConstraints());
+          }
+          return true;
+        };
+
+    bool existsInRead = updateRegion(readRegions);
+    if (error)
+      return;
+    bool existsInWrite = updateRegion(writeRegions);
+    if (error)
+      return;
+
+    // Finally add it to the region list.
+    if (region->isWrite() && !existsInWrite) {
+      writeRegions[region->memref] = std::move(region);
+    } else if (!region->isWrite() && !existsInRead) {
+      readRegions[region->memref] = std::move(region);
+    }
+  });
+
+  if (error) {
+    begin->emitError(
+        "copy generation failed for one or more memref's in this block\n");
+    return 0;
+  }
+
+  uint64_t totalCopyBuffersSizeInBytes = 0;
+  bool ret = true;
+  auto processRegions =
+      [&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
+              &regions) {
+        for (const auto &regionEntry : regions) {
+          // For each region, hoist copy in/out past all invariant
+          // 'affine.for's.
+          Block::iterator copyPlacementReadStart, copyPlacementWriteStart;
+          Block *copyPlacementBlock;
+          findHighestBlockForPlacement(
+              *regionEntry.second, *block, begin, end, &copyPlacementBlock,
+              &copyPlacementReadStart, &copyPlacementWriteStart);
+
+          uint64_t sizeInBytes;
+          Block::iterator nBegin, nEnd;
+          LogicalResult iRet = generateCopy(
+              *regionEntry.second, copyPlacementBlock, copyPlacementReadStart,
+              copyPlacementWriteStart, &sizeInBytes, &nBegin, &nEnd);
+          if (succeeded(iRet)) {
+            // copyPlacmentStart/End (or begin/end) may be invalidated; use
+            // nBegin, nEnd to reset.
+            if (copyPlacementBlock == block) {
+              begin = nBegin;
+              end = nEnd;
+            }
+            totalCopyBuffersSizeInBytes += sizeInBytes;
+          }
+          ret = ret & succeeded(iRet);
+        }
+      };
+  processRegions(readRegions);
+  processRegions(writeRegions);
+
+  if (!ret) {
+    begin->emitError(
+        "copy generation failed for one or more memref's in this block\n");
+    return totalCopyBuffersSizeInBytes;
+  }
+
+  // For a range of operations, a note will be emitted at the caller.
+  AffineForOp forOp;
+  uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024);
+  if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
+    forOp.emitRemark()
+        << sizeInKib
+        << " KiB of copy buffers in fast memory space for this block\n";
+  }
+
+  if (totalCopyBuffersSizeInBytes > fastMemCapacityBytes) {
+    StringRef str = "Total size of all copy buffers' for this block "
+                    "exceeds fast memory capacity\n";
+    block->getParentOp()->emitError(str);
+  }
+
+  return totalCopyBuffersSizeInBytes;
+}
+
+void AffineDataCopyGeneration::runOnFunction() {
+  FuncOp f = getFunction();
+  OpBuilder topBuilder(f.getBody());
+  zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
+
+  for (auto &block : f)
+    runOnBlock(&block);
+}
+
+static PassRegistration<AffineDataCopyGeneration>
+    pass("affine-data-copy-generate",
+         "Generate explicit copying for memory operations");
diff --git a/third_party/mlir/lib/Transforms/CMakeLists.txt b/third_party/mlir/lib/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..e256c28
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,36 @@
+add_subdirectory(Utils)
+
+add_llvm_library(MLIRTransforms
+  AffineDataCopyGeneration.cpp
+  Canonicalizer.cpp
+  CSE.cpp
+  DialectConversion.cpp
+  LoopCoalescing.cpp
+  LoopFusion.cpp
+  LoopInvariantCodeMotion.cpp
+  LoopTiling.cpp
+  LoopUnrollAndJam.cpp
+  LoopUnroll.cpp
+  LowerAffine.cpp
+  LowerVectorTransfers.cpp
+  MaterializeVectors.cpp
+  MemRefDataFlowOpt.cpp
+  PipelineDataTransfer.cpp
+  SimplifyAffineStructures.cpp
+  StripDebugInfo.cpp
+  Vectorize.cpp
+  ViewRegionGraph.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+  )
+
+add_dependencies(MLIRTransforms MLIRStandardOpsIncGen)
+target_link_libraries(MLIRTransforms
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIRLoopOps
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorOps
+  )
diff --git a/third_party/mlir/lib/Transforms/CSE.cpp b/third_party/mlir/lib/Transforms/CSE.cpp
new file mode 100644
index 0000000..eeb63e7
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/CSE.cpp
@@ -0,0 +1,264 @@
+//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This transformation pass performs a simple common sub-expression elimination
+// algorithm on operations within a function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+using namespace mlir;
+
+namespace {
+// TODO(riverriddle) Handle commutative operations.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+  static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    // Hash the operations based upon their:
+    //   - Operation Name
+    //   - Attributes
+    //   - Result Types
+    //   - Operands
+    return hash_combine(
+        op->getName(), op->getAttrs(),
+        hash_combine_range(op->result_type_begin(), op->result_type_end()),
+        hash_combine_range(op->operand_begin(), op->operand_end()));
+  }
+  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+    auto *lhs = const_cast<Operation *>(lhsC);
+    auto *rhs = const_cast<Operation *>(rhsC);
+    if (lhs == rhs)
+      return true;
+    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+        rhs == getTombstoneKey() || rhs == getEmptyKey())
+      return false;
+
+    // Compare the operation name.
+    if (lhs->getName() != rhs->getName())
+      return false;
+    // Check operand and result type counts.
+    if (lhs->getNumOperands() != rhs->getNumOperands() ||
+        lhs->getNumResults() != rhs->getNumResults())
+      return false;
+    // Compare attributes.
+    if (lhs->getAttrs() != rhs->getAttrs())
+      return false;
+    // Compare operands.
+    if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
+                    rhs->operand_begin()))
+      return false;
+    // Compare result types.
+    return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
+                      rhs->result_type_begin());
+  }
+};
+} // end anonymous namespace
+
+namespace {
+/// Simple common sub-expression elimination.
+struct CSE : public FunctionPass<CSE> {
+  CSE() = default;
+  CSE(const CSE &) {}
+
+  /// Shared implementation of operation elimination and scoped map definitions.
+  using AllocatorTy = llvm::RecyclingAllocator<
+      llvm::BumpPtrAllocator,
+      llvm::ScopedHashTableVal<Operation *, Operation *>>;
+  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+                                            SimpleOperationInfo, AllocatorTy>;
+
+  /// Represents a single entry in the depth first traversal of a CFG.
+  struct CFGStackNode {
+    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+        : scope(knownValues), node(node), childIterator(node->begin()),
+          processed(false) {}
+
+    /// Scope for the known values.
+    ScopedMapTy::ScopeTy scope;
+
+    DominanceInfoNode *node;
+    DominanceInfoNode::iterator childIterator;
+
+    /// If this node has been fully processed yet or not.
+    bool processed;
+  };
+
+  /// Attempt to eliminate a redundant operation. Returns success if the
+  /// operation was marked for removal, failure otherwise.
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op);
+
+  void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                     Block *bb);
+  void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                      Region &region);
+
+  void runOnFunction() override;
+
+private:
+  /// Operations marked as dead and to be erased.
+  std::vector<Operation *> opsToErase;
+};
+} // end anonymous namespace
+
+/// Attempt to eliminate a redundant operation.
+LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
+  // Don't simplify operations with nested blocks. We don't currently model
+  // equality comparisons correctly among other things. It is also unclear
+  // whether we would want to CSE such operations.
+  if (op->getNumRegions() != 0)
+    return failure();
+
+  // TODO(riverriddle) We currently only eliminate non side-effecting
+  // operations.
+  if (!op->hasNoSideEffect())
+    return failure();
+
+  // If the operation is already trivially dead just add it to the erase list.
+  if (op->use_empty()) {
+    opsToErase.push_back(op);
+    return success();
+  }
+
+  // Look for an existing definition for the operation.
+  if (auto *existing = knownValues.lookup(op)) {
+    // If we find one then replace all uses of the current operation with the
+    // existing one and mark it for deletion.
+    op->replaceAllUsesWith(existing);
+    opsToErase.push_back(op);
+
+    // If the existing operation has an unknown location and the current
+    // operation doesn't, then set the existing op's location to that of the
+    // current op.
+    if (existing->getLoc().isa<UnknownLoc>() &&
+        !op->getLoc().isa<UnknownLoc>()) {
+      existing->setLoc(op->getLoc());
+    }
+    return success();
+  }
+
+  // Otherwise, we add this operation to the known values map.
+  knownValues.insert(op, op);
+  return failure();
+}
+
+void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                        Block *bb) {
+  for (auto &inst : *bb) {
+    // If the operation is simplified, we don't process any held regions.
+    if (succeeded(simplifyOperation(knownValues, &inst)))
+      continue;
+
+    // If this operation is isolated above, we can't process nested regions with
+    // the given 'knownValues' map. This would cause the insertion of implicit
+    // captures in explicit capture only regions.
+    if (!inst.isRegistered() || inst.isKnownIsolatedFromAbove()) {
+      ScopedMapTy nestedKnownValues;
+      for (auto &region : inst.getRegions())
+        simplifyRegion(nestedKnownValues, domInfo, region);
+      continue;
+    }
+
+    // Otherwise, process nested regions normally.
+    for (auto &region : inst.getRegions())
+      simplifyRegion(knownValues, domInfo, region);
+  }
+}
+
+void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                         Region &region) {
+  // If the region is empty there is nothing to do.
+  if (region.empty())
+    return;
+
+  // If the region only contains one block, then simplify it directly.
+  if (std::next(region.begin()) == region.end()) {
+    ScopedMapTy::ScopeTy scope(knownValues);
+    simplifyBlock(knownValues, domInfo, &region.front());
+    return;
+  }
+
+  // Note, deque is being used here because there was significant performance
+  // gains over vector when the container becomes very large due to the
+  // specific access patterns. If/when these performance issues are no
+  // longer a problem we can change this to vector. For more information see
+  // the llvm mailing list discussion on this:
+  // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+  std::deque<std::unique_ptr<CFGStackNode>> stack;
+
+  // Process the nodes of the dom tree for this region.
+  stack.emplace_back(llvm::make_unique<CFGStackNode>(
+      knownValues, domInfo.getRootNode(&region)));
+
+  while (!stack.empty()) {
+    auto &currentNode = stack.back();
+
+    // Check to see if we need to process this node.
+    if (!currentNode->processed) {
+      currentNode->processed = true;
+      simplifyBlock(knownValues, domInfo, currentNode->node->getBlock());
+    }
+
+    // Otherwise, check to see if we need to process a child node.
+    if (currentNode->childIterator != currentNode->node->end()) {
+      auto *childNode = *(currentNode->childIterator++);
+      stack.emplace_back(
+          llvm::make_unique<CFGStackNode>(knownValues, childNode));
+    } else {
+      // Finally, if the node and all of its children have been processed
+      // then we delete the node.
+      stack.pop_back();
+    }
+  }
+}
+
+void CSE::runOnFunction() {
+  /// A scoped hash table of defining operations within a function.
+  ScopedMapTy knownValues;
+  simplifyRegion(knownValues, getAnalysis<DominanceInfo>(),
+                 getFunction().getBody());
+
+  // If no operations were erased, then we mark all analyses as preserved.
+  if (opsToErase.empty())
+    return markAllAnalysesPreserved();
+
+  /// Erase any operations that were marked as dead during simplification.
+  for (auto *op : opsToErase)
+    op->erase();
+  opsToErase.clear();
+
+  // We currently don't remove region operations, so mark dominance as
+  // preserved.
+  markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
+}
+
+FunctionPassBase *mlir::createCSEPass() { return new CSE(); }
+
+static PassRegistration<CSE>
+    pass("cse", "Eliminate common sub-expressions in functions");
diff --git a/third_party/mlir/lib/Transforms/Canonicalizer.cpp b/third_party/mlir/lib/Transforms/Canonicalizer.cpp
new file mode 100644
index 0000000..80d8ea9
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Canonicalizer.cpp
@@ -0,0 +1,61 @@
+//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This transformation pass converts operations into their canonical forms by
+// folding constants, applying operation identity transformations etc.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// The actual Canonicalizer Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Canonicalize operations in functions.
+struct Canonicalizer : public FunctionPass<Canonicalizer> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void Canonicalizer::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto func = getFunction();
+
+  // TODO: Instead of adding all known patterns from the whole system lazily add
+  // and cache the canonicalization patterns for ops we see in practice when
+  // building the worklist.  For now, we just grab everything.
+  auto *context = &getContext();
+  for (auto *op : context->getRegisteredOperations())
+    op->getCanonicalizationPatterns(patterns, context);
+
+  applyPatternsGreedily(func, patterns);
+}
+
+/// Create a Canonicalizer pass.
+FunctionPassBase *mlir::createCanonicalizerPass() {
+  return new Canonicalizer();
+}
+
+static PassRegistration<Canonicalizer> pass("canonicalize",
+                                            "Canonicalize operations");
diff --git a/third_party/mlir/lib/Transforms/DialectConversion.cpp b/third_party/mlir/lib/Transforms/DialectConversion.cpp
new file mode 100644
index 0000000..cfb85be
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/DialectConversion.cpp
@@ -0,0 +1,1391 @@
+//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+#define DEBUG_TYPE "dialect-conversion"
+
+//===----------------------------------------------------------------------===//
+// ArgConverter
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class provides a simple interface for converting the types of block
+/// arguments. This is done by inserting fake cast operations that map from the
+/// illegal type to the original type to allow for undoing pending rewrites in
+/// the case of failure.
+struct ArgConverter {
+  ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
+      : castOpName(kCastName, rewriter.getContext()),
+        loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+        rewriter(rewriter) {}
+
+  /// Erase any rewrites registered for arguments to blocks within the given
+  /// region. This function is called when the given region is to be destroyed.
+  void cancelPendingRewrites(Block *block);
+
+  /// Cleanup and undo any generated conversions for the arguments of block.
+  /// This method differs from 'cancelPendingRewrites' in that it returns the
+  /// block signature to its original state.
+  void discardPendingRewrites(Block *block);
+
+  /// Replace usages of the cast operations with the argument directly.
+  void applyRewrites();
+
+  /// Return if the signature of the given block has already been converted.
+  bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
+
+  /// Attempt to convert the signature of the given block.
+  LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
+
+  /// Apply the given signature conversion on the given block.
+  void applySignatureConversion(
+      Block *block, TypeConverter::SignatureConversion &signatureConversion,
+      BlockAndValueMapping &mapping);
+
+  /// Convert the given block argument given the provided set of new argument
+  /// values that are to replace it. This function returns the operation used
+  /// to perform the conversion.
+  Operation *convertArgument(BlockArgument *origArg,
+                             ArrayRef<Value *> newValues,
+                             BlockAndValueMapping &mapping);
+
+  /// A utility function used to create a conversion cast operation with the
+  /// given input and result types.
+  Operation *createCast(ArrayRef<Value *> inputs, Type outputType);
+
+  /// This is an operation name for a fake operation that is inserted during the
+  /// conversion process. Operations of this type are guaranteed to never escape
+  /// the converter.
+  static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
+  OperationName castOpName;
+
+  /// This is a collection of cast operations that were generated during the
+  /// conversion process when converting the types of block arguments.
+  llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping;
+
+  /// An instance of the unknown location that is used when generating
+  /// producers.
+  Location loc;
+
+  /// The type converter to use when changing types.
+  TypeConverter *typeConverter;
+
+  /// The pattern rewriter to use when materializing conversions.
+  PatternRewriter &rewriter;
+};
+} // end anonymous namespace
+
+constexpr StringLiteral ArgConverter::kCastName;
+
+/// Erase any rewrites registered for arguments to the given block.
+void ArgConverter::cancelPendingRewrites(Block *block) {
+  auto it = argMapping.find(block);
+  if (it == argMapping.end())
+    return;
+  for (auto *op : it->second) {
+    op->dropAllDefinedValueUses();
+    op->erase();
+  }
+  argMapping.erase(it);
+}
+
+/// Cleanup and undo any generated conversions for the arguments of block.
+/// This method differs from 'cancelPendingRewrites' in that it returns the
+/// block signature to its original state.
+void ArgConverter::discardPendingRewrites(Block *block) {
+  auto it = argMapping.find(block);
+  if (it == argMapping.end())
+    return;
+
+  // Erase all of the new arguments.
+  for (int i = block->getNumArguments() - 1; i >= 0; --i) {
+    block->getArgument(i)->dropAllUses();
+    block->eraseArgument(i, /*updatePredTerms=*/false);
+  }
+
+  // Re-instate the old arguments.
+  auto &mapping = it->second;
+  for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
+    auto *op = mapping[i];
+    auto *arg = block->addArgument(op->getResult(0)->getType());
+    op->getResult(0)->replaceAllUsesWith(arg);
+
+    // If this operation is within a block, it will be cleaned up automatically.
+    if (!op->getBlock())
+      op->erase();
+  }
+  argMapping.erase(it);
+}
+
+/// Replace usages of the cast operations with the argument directly.
+void ArgConverter::applyRewrites() {
+  Block *block;
+  ArrayRef<Operation *> argOps;
+  for (auto &mapping : argMapping) {
+    std::tie(block, argOps) = mapping;
+
+    // Process the remapping for each of the original arguments.
+    for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
+      auto *op = argOps[i];
+
+      // Handle the case of a 1->N value mapping.
+      if (op->getNumOperands() > 1) {
+        // If all of the uses were removed, we can drop this op. Otherwise,
+        // keep the operation alive and let the user handle any remaining
+        // usages.
+        if (op->use_empty())
+          op->erase();
+        continue;
+      }
+
+      // If mapping is 1-1, replace the remaining uses and drop the cast
+      // operation.
+      // FIXME(riverriddle) This should check that the result type and operand
+      // type are the same, otherwise it should force a conversion to be
+      // materialized. This works around a current limitation with regards to
+      // region entry argument type conversion.
+      if (op->getNumOperands() == 1) {
+        op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
+        op->destroy();
+        continue;
+      }
+
+      // Otherwise, if there are any dangling uses then replace the fake
+      // conversion operation with one generated by the type converter. This
+      // is necessary as the cast must persist in the IR after conversion.
+      auto *opResult = op->getResult(0);
+      if (!opResult->use_empty()) {
+        rewriter.setInsertionPointToStart(block);
+        SmallVector<Value *, 1> operands(op->getOperands());
+        auto *newOp = typeConverter->materializeConversion(
+            rewriter, opResult->getType(), operands, op->getLoc());
+        opResult->replaceAllUsesWith(newOp->getResult(0));
+      }
+      op->destroy();
+    }
+  }
+}
+
+/// Converts the signature of the given entry block.
+LogicalResult ArgConverter::convertSignature(Block *block,
+                                             BlockAndValueMapping &mapping) {
+  if (auto conversion = typeConverter->convertBlockSignature(block))
+    return applySignatureConversion(block, *conversion, mapping), success();
+  return failure();
+}
+
+/// Apply the given signature conversion on the given block.
+void ArgConverter::applySignatureConversion(
+    Block *block, TypeConverter::SignatureConversion &signatureConversion,
+    BlockAndValueMapping &mapping) {
+  unsigned origArgCount = block->getNumArguments();
+  auto convertedTypes = signatureConversion.getConvertedTypes();
+  if (origArgCount == 0 && convertedTypes.empty())
+    return;
+
+  SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes));
+  ArrayRef<Value *> newArgRef(newArgRange);
+
+  // Remap each of the original arguments as determined by the signature
+  // conversion.
+  auto &newArgMapping = argMapping[block];
+  rewriter.setInsertionPointToStart(block);
+  for (unsigned i = 0; i != origArgCount; ++i) {
+    ArrayRef<Value *> remappedValues;
+    if (auto inputMap = signatureConversion.getInputMapping(i))
+      remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
+
+    BlockArgument *arg = block->getArgument(i);
+    newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
+  }
+
+  // Erase all of the original arguments.
+  for (unsigned i = 0; i != origArgCount; ++i)
+    block->eraseArgument(0, /*updatePredTerms=*/false);
+}
+
+/// Convert the given block argument given the provided set of new argument
+/// values that are to replace it. This function returns the operation used
+/// to perform the conversion.
+Operation *ArgConverter::convertArgument(BlockArgument *origArg,
+                                         ArrayRef<Value *> newValues,
+                                         BlockAndValueMapping &mapping) {
+  // Handle the cases of 1->0 or 1->1 mappings.
+  if (newValues.size() < 2) {
+    // Create a temporary producer for the argument during the conversion
+    // process.
+    auto *cast = createCast(newValues, origArg->getType());
+    origArg->replaceAllUsesWith(cast->getResult(0));
+
+    // Insert a mapping between this argument and the one that is replacing
+    // it.
+    if (!newValues.empty())
+      mapping.map(cast->getResult(0), newValues[0]);
+    return cast;
+  }
+
+  // Otherwise, this is a 1->N mapping. Call into the provided type converter
+  // to pack the new values.
+  auto *cast = typeConverter->materializeConversion(
+      rewriter, origArg->getType(), newValues, loc);
+  assert(cast->getNumResults() == 1 &&
+         cast->getNumOperands() == newValues.size());
+  origArg->replaceAllUsesWith(cast->getResult(0));
+  return cast;
+}
+
+/// A utility function used to create a conversion cast operation with the
+/// given input and result types.
+Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
+  return Operation::create(loc, castOpName, inputs, outputType, llvm::None,
+                           llvm::None, 0, false, outputType.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriterImpl
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class contains a snapshot of the current conversion rewriter state.
+/// This is useful when saving and undoing a set of rewrites.
+struct RewriterState {
+  RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
+                unsigned numBlockActions)
+      : numCreatedOperations(numCreatedOperations),
+        numReplacements(numReplacements), numBlockActions(numBlockActions) {}
+
+  /// The current number of created operations.
+  unsigned numCreatedOperations;
+
+  /// The current number of replacements queued.
+  unsigned numReplacements;
+
+  /// The current number of block actions performed.
+  unsigned numBlockActions;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
+  /// This class represents one requested operation replacement via 'replaceOp'.
+  struct OpReplacement {
+    OpReplacement() = default;
+    OpReplacement(Operation *op, ArrayRef<Value *> newValues)
+        : op(op), newValues(newValues.begin(), newValues.end()) {}
+
+    Operation *op;
+    SmallVector<Value *, 2> newValues;
+  };
+
+  /// The kind of the block action performed during the rewrite.  Actions can be
+  /// undone if the conversion fails.
+  enum class BlockActionKind { Split, Move, TypeConversion };
+
+  /// Original position of the given block in its parent region.  We cannot use
+  /// a region iterator because it could have been invalidated by other region
+  /// operations since the position was stored.
+  struct BlockPosition {
+    Region *region;
+    Region::iterator::difference_type position;
+  };
+
+  /// The storage class for an undoable block action (one of BlockActionKind),
+  /// contains the information necessary to undo this action.
+  struct BlockAction {
+    static BlockAction getSplit(Block *block, Block *originalBlock) {
+      BlockAction action{BlockActionKind::Split, block, {}};
+      action.originalBlock = originalBlock;
+      return action;
+    }
+    static BlockAction getMove(Block *block, BlockPosition originalPos) {
+      return {BlockActionKind::Move, block, {originalPos}};
+    }
+    static BlockAction getTypeConversion(Block *block) {
+      return BlockAction{BlockActionKind::TypeConversion, block, {}};
+    }
+
+    // The action kind.
+    BlockActionKind kind;
+
+    // A pointer to the block that was created by the action.
+    Block *block;
+
+    union {
+      // In use if kind == BlockActionKind::Move and contains a pointer to the
+      // region that originally contained the block as well as the position of
+      // the block in that region.
+      BlockPosition originalPosition;
+      // In use if kind == BlockActionKind::Split and contains a pointer to the
+      // block that was split into two parts.
+      Block *originalBlock;
+    };
+  };
+
+  ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+                                TypeConverter *converter)
+      : argConverter(converter, rewriter) {}
+
+  /// Return the current state of the rewriter.
+  RewriterState getCurrentState();
+
+  /// Reset the state of the rewriter to a previously saved point.
+  void resetState(RewriterState state);
+
+  /// Undo the block actions (motions, splits) one by one in reverse order until
+  /// "numActionsToKeep" actions remains.
+  void undoBlockActions(unsigned numActionsToKeep = 0);
+
+  /// Cleanup and destroy any generated rewrite operations. This method is
+  /// invoked when the conversion process fails.
+  void discardRewrites();
+
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
+  /// Convert the signature of the given block.
+  LogicalResult convertBlockSignature(Block *block);
+
+  /// Apply a signature conversion on the given region.
+  void applySignatureConversion(Region *region,
+                                TypeConverter::SignatureConversion &conversion);
+
+  /// PatternRewriter hook for replacing the results of an operation.
+  void replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                 ArrayRef<Value *> valuesToRemoveIfDead);
+
+  /// Notifies that a block was split.
+  void notifySplitBlock(Block *block, Block *continuation);
+
+  /// Notifies that the blocks of a region are about to be moved.
+  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
+                                        Region::iterator before);
+
+  /// Remap the given operands to those with potentially different types.
+  void remapValues(Operation::operand_range operands,
+                   SmallVectorImpl<Value *> &remapped);
+
+  // Mapping between replaced values that differ in type. This happens when
+  // replacing a value with one of a different type.
+  BlockAndValueMapping mapping;
+
+  /// Utility used to convert block arguments.
+  ArgConverter argConverter;
+
+  /// Ordered vector of all of the newly created operations during conversion.
+  SmallVector<Operation *, 4> createdOps;
+
+  /// Ordered vector of any requested operation replacements.
+  SmallVector<OpReplacement, 4> replacements;
+
+  /// Ordered list of block operations (creations, splits, motions).
+  SmallVector<BlockAction, 4> blockActions;
+};
+} // end namespace detail
+} // end namespace mlir
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+  return RewriterState(createdOps.size(), replacements.size(),
+                       blockActions.size());
+}
+
+void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+  // Undo any block actions.
+  undoBlockActions(state.numBlockActions);
+
+  // Reset any replaced operations and undo any saved mappings.
+  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+    for (auto *result : repl.op->getResults())
+      mapping.erase(result);
+  replacements.resize(state.numReplacements);
+
+  // Pop all of the newly created operations.
+  while (createdOps.size() != state.numCreatedOperations)
+    createdOps.pop_back_val()->erase();
+}
+
+void ConversionPatternRewriterImpl::undoBlockActions(
+    unsigned numActionsToKeep) {
+  for (auto &action :
+       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
+    switch (action.kind) {
+    // Merge back the block that was split out.
+    case BlockActionKind::Split: {
+      action.originalBlock->getOperations().splice(
+          action.originalBlock->end(), action.block->getOperations());
+      action.block->erase();
+      break;
+    }
+    // Move the block back to its original position.
+    case BlockActionKind::Move: {
+      Region *originalRegion = action.originalPosition.region;
+      originalRegion->getBlocks().splice(
+          std::next(originalRegion->begin(), action.originalPosition.position),
+          action.block->getParent()->getBlocks(), action.block);
+      break;
+    }
+    // Undo the type conversion.
+    case BlockActionKind::TypeConversion: {
+      argConverter.discardPendingRewrites(action.block);
+      break;
+    }
+    }
+  }
+  blockActions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::discardRewrites() {
+  undoBlockActions();
+
+  // Remove any newly created ops.
+  for (auto *op : createdOps) {
+    op->dropAllDefinedValueUses();
+    op->erase();
+  }
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+  // Apply all of the rewrites replacements requested during conversion.
+  for (auto &repl : replacements) {
+    for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
+      repl.op->getResult(i)->replaceAllUsesWith(
+          mapping.lookupOrDefault(repl.newValues[i]));
+
+    // If this operation defines any regions, drop any pending argument
+    // rewrites.
+    if (argConverter.typeConverter && repl.op->getNumRegions()) {
+      for (auto &region : repl.op->getRegions())
+        for (auto &block : region)
+          argConverter.cancelPendingRewrites(&block);
+    }
+  }
+
+  // In a second pass, erase all of the replaced operations in reverse. This
+  // allows processing nested operations before their parent region is
+  // destroyed.
+  for (auto &repl : llvm::reverse(replacements))
+    repl.op->erase();
+
+  argConverter.applyRewrites();
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+  // Check to see if this block should not be converted:
+  // * The block is invalid, or there is no type converter.
+  // * The block has already been converted.
+  // * This is an entry block, these are converted explicitly via patterns.
+  if (!block || !argConverter.typeConverter ||
+      argConverter.hasBeenConverted(block) || block->isEntryBlock())
+    return success();
+
+  // Otherwise, try to convert the block signature.
+  if (failed(argConverter.convertSignature(block, mapping)))
+    return failure();
+  blockActions.push_back(BlockAction::getTypeConversion(block));
+  return success();
+}
+
+void ConversionPatternRewriterImpl::applySignatureConversion(
+    Region *region, TypeConverter::SignatureConversion &conversion) {
+  if (!region->empty()) {
+    argConverter.applySignatureConversion(&region->front(), conversion,
+                                          mapping);
+    blockActions.push_back(BlockAction::getTypeConversion(&region->front()));
+  }
+}
+
+void ConversionPatternRewriterImpl::replaceOp(
+    Operation *op, ArrayRef<Value *> newValues,
+    ArrayRef<Value *> valuesToRemoveIfDead) {
+  assert(newValues.size() == op->getNumResults());
+
+  // Create mappings for each of the new result values.
+  for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
+    assert((newValues[i] || op->getResult(i)->use_empty()) &&
+           "result value has remaining uses that must be replaced");
+    if (newValues[i])
+      mapping.map(op->getResult(i), newValues[i]);
+  }
+
+  // Record the requested operation replacement.
+  replacements.emplace_back(op, newValues);
+}
+
+void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
+                                                     Block *continuation) {
+  blockActions.push_back(BlockAction::getSplit(continuation, block));
+}
+
+void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
+    Region &region, Region &parent, Region::iterator before) {
+  for (auto &pair : llvm::enumerate(region)) {
+    Block &block = pair.value();
+    unsigned position = pair.index();
+    blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
+  }
+}
+
+void ConversionPatternRewriterImpl::remapValues(
+    Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
+  remapped.reserve(llvm::size(operands));
+  for (Value *operand : operands)
+    remapped.push_back(mapping.lookupOrDefault(operand));
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
+                                                     TypeConverter *converter)
+    : PatternRewriter(ctx),
+      impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+ConversionPatternRewriter::~ConversionPatternRewriter() {}
+
+/// PatternRewriter hook for replacing the results of an operation.
+void ConversionPatternRewriter::replaceOp(
+    Operation *op, ArrayRef<Value *> newValues,
+    ArrayRef<Value *> valuesToRemoveIfDead) {
+  impl->replaceOp(op, newValues, valuesToRemoveIfDead);
+}
+
+/// Apply a signature conversion to the entry block of the given region.
+void ConversionPatternRewriter::applySignatureConversion(
+    Region *region, TypeConverter::SignatureConversion &conversion) {
+  impl->applySignatureConversion(region, conversion);
+}
+
+/// Clone the given operation without cloning its regions.
+Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
+  Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
+  impl->createdOps.push_back(newOp);
+  return newOp;
+}
+
+/// PatternRewriter hook for splitting a block into two parts.
+Block *ConversionPatternRewriter::splitBlock(Block *block,
+                                             Block::iterator before) {
+  auto *continuation = PatternRewriter::splitBlock(block, before);
+  impl->notifySplitBlock(block, continuation);
+  return continuation;
+}
+
+/// PatternRewriter hook for moving blocks out of a region.
+void ConversionPatternRewriter::inlineRegionBefore(Region &region,
+                                                   Region &parent,
+                                                   Region::iterator before) {
+  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
+  PatternRewriter::inlineRegionBefore(region, parent, before);
+}
+
+/// PatternRewriter hook for creating a new operation.
+Operation *
+ConversionPatternRewriter::createOperation(const OperationState &state) {
+  auto *result = OpBuilder::createOperation(state);
+  impl->createdOps.push_back(result);
+  return result;
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
+  // The rewriter caches changes to the IR to allow for operating in-place and
+  // backtracking. The rewriter is currently not capable of backtracking
+  // in-place modifications.
+  llvm_unreachable("in-place operation updates are not supported");
+}
+
+/// Return a reference to the internal implementation.
+detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
+  return *impl;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+/// Attempt to match and rewrite the IR root at the specified operation.
+PatternMatchResult
+ConversionPattern::matchAndRewrite(Operation *op,
+                                   PatternRewriter &rewriter) const {
+  SmallVector<Value *, 4> operands;
+  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
+  dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
+
+  // If this operation has no successors, invoke the rewrite directly.
+  if (op->getNumSuccessors() == 0)
+    return matchAndRewrite(op, operands, dialectRewriter);
+
+  // Otherwise, we need to remap the successors.
+  SmallVector<Block *, 2> destinations;
+  destinations.reserve(op->getNumSuccessors());
+
+  SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
+  unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
+  for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
+    destinations.push_back(op->getSuccessor(i));
+
+    // Lookup the successors operands.
+    unsigned n = op->getNumSuccessorOperands(i);
+    operandsPerDestination.push_back(
+        llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
+    seen += n;
+  }
+
+  // Rewrite the operation.
+  return matchAndRewrite(
+      op,
+      llvm::makeArrayRef(operands.data(),
+                         operands.data() + firstSuccessorOperand),
+      destinations, operandsPerDestination, dialectRewriter);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationLegalizer
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A set of rewrite patterns that can be used to legalize a given operation.
+using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
+
+/// This class defines a recursive operation legalizer.
+class OperationLegalizer {
+public:
+  using LegalizationAction = ConversionTarget::LegalizationAction;
+
+  OperationLegalizer(ConversionTarget &targetInfo,
+                     OwningRewritePatternList &patterns)
+      : target(targetInfo) {
+    buildLegalizationGraph(patterns);
+    computeLegalizationGraphBenefit();
+  }
+
+  /// Returns if the given operation is known to be illegal on the target.
+  bool isIllegal(Operation *op) const;
+
+  /// Attempt to legalize the given operation. Returns success if the operation
+  /// was legalized, failure otherwise.
+  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+
+private:
+  /// Attempt to legalize the given operation by applying the provided pattern.
+  /// Returns success if the operation was legalized, failure otherwise.
+  LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
+                                ConversionPatternRewriter &rewriter);
+
+  /// Build an optimistic legalization graph given the provided patterns. This
+  /// function populates 'legalizerPatterns' with the operations that are not
+  /// directly legal, but may be transitively legal for the current target given
+  /// the provided patterns.
+  void buildLegalizationGraph(OwningRewritePatternList &patterns);
+
+  /// Compute the benefit of each node within the computed legalization graph.
+  /// This orders the patterns within 'legalizerPatterns' based upon two
+  /// criteria:
+  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
+  ///     represent the more direct mapping to the target.
+  ///  2) When comparing patterns with the same legalization depth, prefer the
+  ///     pattern with the highest PatternBenefit. This allows for users to
+  ///     prefer specific legalizations over others.
+  void computeLegalizationGraphBenefit();
+
+  /// The current set of patterns that have been applied.
+  llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns;
+
+  /// The set of legality information for operations transitively supported by
+  /// the target.
+  DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+
+  /// The legalization information provided by the target.
+  ConversionTarget &target;
+};
+} // namespace
+
+bool OperationLegalizer::isIllegal(Operation *op) const {
+  // Check if the target explicitly marked this operation as illegal.
+  if (auto action = target.getOpAction(op->getName()))
+    return action == LegalizationAction::Illegal;
+  return false;
+}
+
+LogicalResult
+OperationLegalizer::legalize(Operation *op,
+                             ConversionPatternRewriter &rewriter) {
+  // Make sure that the signature of the parent block has been converted.
+  if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock())))
+    return failure();
+
+  LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
+                          << "\n");
+
+  // Check if this operation is legal on the target.
+  if (target.isLegal(op)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "-- Success : Operation marked legal by the target\n");
+    return success();
+  }
+
+  // Otherwise, we need to apply a legalization pattern to this operation.
+  auto it = legalizerPatterns.find(op->getName());
+  if (it == legalizerPatterns.end()) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
+    return failure();
+  }
+
+  // The patterns are sorted by expected benefit, so try to apply each in-order.
+  for (auto *pattern : it->second)
+    if (succeeded(legalizePattern(op, pattern, rewriter)))
+      return success();
+
+  LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
+  return failure();
+}
+
+LogicalResult
+OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
+                                    ConversionPatternRewriter &rewriter) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
+    interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
+    llvm::dbgs() << ")'.\n";
+  });
+
+  // Ensure that we don't cycle by not allowing the same pattern to be
+  // applied twice in the same recursion stack.
+  // TODO(riverriddle) We could eventually converge, but that requires more
+  // complicated analysis.
+  if (!appliedPatterns.insert(pattern).second) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
+    return failure();
+  }
+
+  auto &rewriterImpl = rewriter.getImpl();
+  RewriterState curState = rewriterImpl.getCurrentState();
+  auto cleanupFailure = [&] {
+    // Reset the rewriter state and pop this pattern.
+    rewriterImpl.resetState(curState);
+    appliedPatterns.erase(pattern);
+    return failure();
+  };
+
+  // Try to rewrite with the given pattern.
+  rewriter.setInsertionPoint(op);
+  if (!pattern->matchAndRewrite(op, rewriter)) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
+    return cleanupFailure();
+  }
+
+  // Recursively legalize each of the new operations.
+  for (unsigned i = curState.numCreatedOperations,
+                e = rewriterImpl.createdOps.size();
+       i != e; ++i) {
+    if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
+      return cleanupFailure();
+    }
+  }
+
+  appliedPatterns.erase(pattern);
+  return success();
+}
+
+void OperationLegalizer::buildLegalizationGraph(
+    OwningRewritePatternList &patterns) {
+  // A mapping between an operation and a set of operations that can be used to
+  // generate it.
+  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
+  // A mapping between an operation and any currently invalid patterns it has.
+  DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
+  // A worklist of patterns to consider for legality.
+  llvm::SetVector<RewritePattern *> patternWorklist;
+
+  // Build the mapping from operations to the parent ops that may generate them.
+  for (auto &pattern : patterns) {
+    auto root = pattern->getRootKind();
+
+    // Skip operations that are always known to be legal.
+    if (target.getOpAction(root) == LegalizationAction::Legal)
+      continue;
+
+    // Add this pattern to the invalid set for the root op and record this root
+    // as a parent for any generated operations.
+    invalidPatterns[root].insert(pattern.get());
+    for (auto op : pattern->getGeneratedOps())
+      parentOps[op].insert(root);
+
+    // Add this pattern to the worklist.
+    patternWorklist.insert(pattern.get());
+  }
+
+  while (!patternWorklist.empty()) {
+    auto *pattern = patternWorklist.pop_back_val();
+
+    // Check to see if any of the generated operations are invalid.
+    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
+          auto action = target.getOpAction(op);
+          return !legalizerPatterns.count(op) &&
+                 (!action || action == LegalizationAction::Illegal);
+        }))
+      continue;
+
+    // Otherwise, if all of the generated operation are valid, this op is now
+    // legal so add all of the child patterns to the worklist.
+    legalizerPatterns[pattern->getRootKind()].push_back(pattern);
+    invalidPatterns[pattern->getRootKind()].erase(pattern);
+
+    // Add any invalid patterns of the parent operations to see if they have now
+    // become legal.
+    for (auto op : parentOps[pattern->getRootKind()])
+      patternWorklist.set_union(invalidPatterns[op]);
+  }
+}
+
+void OperationLegalizer::computeLegalizationGraphBenefit() {
+  // The smallest pattern depth, when legalizing an operation.
+  DenseMap<OperationName, unsigned> minPatternDepth;
+
+  // Compute the minimum legalization depth for a given operation.
+  std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
+    // Check for existing depth.
+    auto depthIt = minPatternDepth.find(op);
+    if (depthIt != minPatternDepth.end())
+      return depthIt->second;
+
+    // If a mapping for this operation does not exist, then this operation
+    // is always legal. Return 0 as the depth for a directly legal operation.
+    auto opPatternsIt = legalizerPatterns.find(op);
+    if (opPatternsIt == legalizerPatterns.end())
+      return 0u;
+
+    auto &minDepth = minPatternDepth[op];
+    if (opPatternsIt->second.empty())
+      return minDepth;
+
+    // Initialize the depth to the maximum value.
+    minDepth = std::numeric_limits<unsigned>::max();
+
+    // Compute the depth for each pattern used to legalize this operation.
+    SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth;
+    patternsByDepth.reserve(opPatternsIt->second.size());
+    for (RewritePattern *pattern : opPatternsIt->second) {
+      unsigned depth = 0;
+      for (auto generatedOp : pattern->getGeneratedOps())
+        depth = std::max(depth, computeDepth(generatedOp) + 1);
+      patternsByDepth.emplace_back(pattern, depth);
+
+      // Update the min depth for this operation.
+      minDepth = std::min(minDepth, depth);
+    }
+
+    // If the operation only has one legalization pattern, there is no need to
+    // sort them.
+    if (patternsByDepth.size() == 1)
+      return minDepth;
+
+    // Sort the patterns by those likely to be the most beneficial.
+    llvm::array_pod_sort(
+        patternsByDepth.begin(), patternsByDepth.end(),
+        [](const std::pair<RewritePattern *, unsigned> *lhs,
+           const std::pair<RewritePattern *, unsigned> *rhs) {
+          // First sort by the smaller pattern legalization depth.
+          if (lhs->second != rhs->second)
+            return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
+                                                             &rhs->second);
+
+          // Then sort by the larger pattern benefit.
+          auto lhsBenefit = lhs->first->getBenefit();
+          auto rhsBenefit = rhs->first->getBenefit();
+          return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
+                                                                 &lhsBenefit);
+        });
+
+    // Update the legalization pattern to use the new sorted list.
+    opPatternsIt->second.clear();
+    for (auto &patternIt : patternsByDepth)
+      opPatternsIt->second.push_back(patternIt.first);
+
+    return minDepth;
+  };
+
+  // For each operation that is transitively legal, compute a cost for it.
+  for (auto &opIt : legalizerPatterns)
+    if (!minPatternDepth.count(opIt.first))
+      computeDepth(opIt.first);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationConverter
+//===----------------------------------------------------------------------===//
+namespace {
+enum OpConversionMode {
+  // In this mode, the conversion will ignore failed conversions to allow
+  // illegal operations to co-exist in the IR.
+  Partial,
+
+  // In this mode, all operations must be legal for the given target for the
+  // conversion to succeeed.
+  Full,
+
+  // In this mode, operations are analyzed for legality. No actual rewrites are
+  // applied to the operations on success.
+  Analysis,
+};
+
+// This class converts operations using the given pattern matcher. If a
+// TypeConverter object is provided, then the types of block arguments will be
+// converted using the appropriate 'convertType' calls.
+struct OperationConverter {
+  explicit OperationConverter(ConversionTarget &target,
+                              OwningRewritePatternList &patterns,
+                              OpConversionMode mode,
+                              DenseSet<Operation *> *legalizableOps = nullptr)
+      : opLegalizer(target, patterns), mode(mode),
+        legalizableOps(legalizableOps) {}
+
+  /// Converts the given operations to the conversion target.
+  LogicalResult convertOperations(ArrayRef<Operation *> ops,
+                                  TypeConverter *typeConverter);
+
+private:
+  /// Converts an operation with the given rewriter.
+  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+
+  /// Recursively collect all of the operations to convert from within 'region'.
+  LogicalResult computeConversionSet(Region &region,
+                                     std::vector<Operation *> &toConvert);
+
+  /// Converts the type signatures of the blocks nested within 'op' that have
+  /// yet to be converted.
+  LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
+                                       Operation *op);
+
+  /// The legalizer to use when converting operations.
+  OperationLegalizer opLegalizer;
+
+  /// The conversion mode to use when legalizing operations.
+  OpConversionMode mode;
+
+  /// A set of pre-existing operations that were found to be legalizable to the
+  /// target. This field is only used when mode == OpConversionMode::Analysis.
+  DenseSet<Operation *> *legalizableOps;
+};
+} // end anonymous namespace
+
+LogicalResult
+OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
+                                           Operation *op) {
+  SmallVector<Region *, 8> worklist;
+  for (auto &region : op->getRegions())
+    worklist.push_back(&region);
+
+  while (!worklist.empty()) {
+    for (auto &block : *worklist.pop_back_val()) {
+      if (failed(rewriter.getImpl().convertBlockSignature(&block)))
+        return failure();
+      for (auto &nestedOp : block)
+        for (auto &region : nestedOp.getRegions())
+          worklist.push_back(&region);
+    }
+  }
+  return success();
+}
+
+LogicalResult
+OperationConverter::computeConversionSet(Region &region,
+                                         std::vector<Operation *> &toConvert) {
+  if (region.empty())
+    return success();
+
+  // Traverse starting from the entry block.
+  SmallVector<Block *, 16> worklist(1, &region.front());
+  DenseSet<Block *> visitedBlocks;
+  visitedBlocks.insert(&region.front());
+  while (!worklist.empty()) {
+    auto *block = worklist.pop_back_val();
+
+    // Compute the conversion set of each of the nested operations.
+    for (auto &op : *block) {
+      toConvert.emplace_back(&op);
+      for (auto &region : op.getRegions())
+        computeConversionSet(region, toConvert);
+    }
+
+    // Recurse to children that haven't been visited.
+    for (Block *succ : block->getSuccessors())
+      if (visitedBlocks.insert(succ).second)
+        worklist.push_back(succ);
+  }
+
+  // Check that all blocks in the region were visited.
+  if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1),
+                   [&](Block &block) { return !visitedBlocks.count(&block); }))
+    return emitError(region.getLoc(), "unreachable blocks were not converted");
+  return success();
+}
+
+LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
+                                          Operation *op) {
+  // Legalize the given operation.
+  if (failed(opLegalizer.legalize(op, rewriter))) {
+    // Handle the case of a failed conversion for each of the different modes.
+    /// Full conversions expect all operations to be converted.
+    if (mode == OpConversionMode::Full)
+      return op->emitError()
+             << "failed to legalize operation '" << op->getName() << "'";
+    /// Partial conversions allow conversions to fail iff the operation was not
+    /// explicitly marked as illegal.
+    if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
+      return op->emitError()
+             << "failed to legalize operation '" << op->getName()
+             << "' that was explicitly marked illegal";
+  } else if (mode == OpConversionMode::Analysis) {
+    /// Analysis conversions don't fail if any operations fail to legalize, they
+    /// are only interested in the operations that were successfully legalized.
+    legalizableOps->insert(op);
+  }
+  return success();
+}
+
+LogicalResult
+OperationConverter::convertOperations(ArrayRef<Operation *> ops,
+                                      TypeConverter *typeConverter) {
+  if (ops.empty())
+    return success();
+
+  /// Compute the set of operations and blocks to convert.
+  std::vector<Operation *> toConvert;
+  for (auto *op : ops) {
+    toConvert.emplace_back(op);
+    for (auto &region : op->getRegions())
+      if (failed(computeConversionSet(region, toConvert)))
+        return failure();
+  }
+
+  // Convert each operation and discard rewrites on failure.
+  ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
+  for (auto *op : toConvert)
+    if (failed(convert(rewriter, op)))
+      return rewriter.getImpl().discardRewrites(), failure();
+
+  // If a type converter was provided, ensure that all blocks have had their
+  // signatures properly converted.
+  if (typeConverter) {
+    for (auto *op : ops)
+      if (failed(convertBlockSignatures(rewriter, op)))
+        return rewriter.getImpl().discardRewrites(), failure();
+  }
+
+  // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
+  // analysis conversion.
+  if (mode == OpConversionMode::Analysis)
+    rewriter.getImpl().discardRewrites();
+  else
+    rewriter.getImpl().applyRewrites();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+/// Remap an input of the original signature with a new set of types. The
+/// new types are appended to the new signature conversion.
+void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
+                                                   ArrayRef<Type> types) {
+  assert(!types.empty() && "expected valid types");
+  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
+  addInputs(types);
+}
+
+/// Append new input types to the signature conversion, this should only be
+/// used if the new types are not intended to remap an existing input.
+void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
+  assert(!types.empty() &&
+         "1->0 type remappings don't need to be added explicitly");
+  argTypes.append(types.begin(), types.end());
+}
+
+/// Remap an input of the original signature with a range of types in the
+/// new signature.
+void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
+                                                    unsigned newInputNo,
+                                                    unsigned newInputCount) {
+  assert(!remappedInputs[origInputNo] && "input has already been remapped");
+  assert(newInputCount != 0 && "expected valid input count");
+  remappedInputs[origInputNo] = InputMapping{newInputNo, newInputCount};
+}
+
+/// This hooks allows for converting a type.
+LogicalResult TypeConverter::convertType(Type t,
+                                         SmallVectorImpl<Type> &results) {
+  if (auto newT = convertType(t)) {
+    results.push_back(newT);
+    return success();
+  }
+  return failure();
+}
+
+/// Convert the given set of types, filling 'results' as necessary. This
+/// returns failure if the conversion of any of the types fails, success
+/// otherwise.
+LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
+                                          SmallVectorImpl<Type> &results) {
+  for (auto type : types)
+    if (failed(convertType(type, results)))
+      return failure();
+  return success();
+}
+
+/// Return true if the given type is legal for this type converter, i.e. the
+/// type converts to itself.
+bool TypeConverter::isLegal(Type type) {
+  SmallVector<Type, 1> results;
+  return succeeded(convertType(type, results)) && results.size() == 1 &&
+         results.front() == type;
+}
+
+/// Return true if the inputs and outputs of the given function type are
+/// legal.
+bool TypeConverter::isSignatureLegal(FunctionType funcType) {
+  return llvm::all_of(
+      llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
+      [this](Type type) { return isLegal(type); });
+}
+
+/// This hook allows for converting a specific argument of a signature.
+LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                                 SignatureConversion &result) {
+  // Try to convert the given input type.
+  SmallVector<Type, 1> convertedTypes;
+  if (failed(convertType(type, convertedTypes)))
+    return failure();
+
+  // If this argument is being dropped, there is nothing left to do.
+  if (convertedTypes.empty())
+    return success();
+
+  // Otherwise, add the new inputs.
+  result.addInputs(inputNo, convertedTypes);
+  return success();
+}
+
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FuncOp.
+namespace {
+struct FuncOpSignatureConversion : public ConversionPattern {
+  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(FuncOp::getOperationName(), 1, ctx),
+        converter(converter) {}
+
+  /// Hook for derived classes to implement combined matching and rewriting.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+    FunctionType type = funcOp.getType();
+
+    // Convert the original function arguments.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+      if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
+        return matchFailure();
+
+    // Convert the original function results.
+    SmallVector<Type, 1> convertedResults;
+    if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+      return matchFailure();
+
+    // Create a new function with an updated signature.
+    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
+                                        convertedResults, funcOp.getContext()));
+
+    // Tell the rewriter to convert the region signature.
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+
+  /// The type converter to use when rewriting the signature.
+  TypeConverter &converter;
+};
+} // end anonymous namespace
+
+void mlir::populateFuncOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
+}
+
+/// This function converts the type signature of the given block, by invoking
+/// 'convertSignatureArg' for each argument. This function should return a valid
+/// conversion for the signature on success, None otherwise.
+auto TypeConverter::convertBlockSignature(Block *block)
+    -> llvm::Optional<SignatureConversion> {
+  SignatureConversion conversion(block->getNumArguments());
+  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
+    if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
+                                   conversion)))
+      return llvm::None;
+  return conversion;
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionTarget
+//===----------------------------------------------------------------------===//
+
+/// Register a legality action for the given operation.
+void ConversionTarget::setOpAction(OperationName op,
+                                   LegalizationAction action) {
+  legalOperations[op] = action;
+}
+
+/// Register a legality action for the given dialects.
+void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
+                                        LegalizationAction action) {
+  for (StringRef dialect : dialectNames)
+    legalDialects[dialect] = action;
+}
+
+/// Get the legality action for the given operation.
+auto ConversionTarget::getOpAction(OperationName op) const
+    -> llvm::Optional<LegalizationAction> {
+  // Check for an action for this specific operation.
+  auto it = legalOperations.find(op);
+  if (it != legalOperations.end())
+    return it->second;
+  // Otherwise, default to checking for an action on the parent dialect.
+  auto dialectIt = legalDialects.find(op.getDialect());
+  if (dialectIt != legalDialects.end())
+    return dialectIt->second;
+  return llvm::None;
+}
+
+/// Return if the given operation instance is legal on this target.
+bool ConversionTarget::isLegal(Operation *op) const {
+  auto action = getOpAction(op->getName());
+
+  // Handle dynamic legality.
+  if (action == LegalizationAction::Dynamic) {
+    // Check for callbacks on the operation or dialect.
+    auto opFn = opLegalityFns.find(op->getName());
+    if (opFn != opLegalityFns.end())
+      return opFn->second(op);
+    auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
+    if (dialectFn != dialectLegalityFns.end())
+      return dialectFn->second(op);
+
+    // Otherwise, invoke the hook on the derived instance.
+    return isDynamicallyLegal(op);
+  }
+
+  // Otherwise, the operation is only legal if it was marked 'Legal'.
+  return action == LegalizationAction::Legal;
+}
+
+/// Set the dynamic legality callback for the given operation.
+void ConversionTarget::setLegalityCallback(
+    OperationName name, const DynamicLegalityCallbackFn &callback) {
+  assert(callback && "expected valid legality callback");
+  opLegalityFns[name] = callback;
+}
+
+/// Set the dynamic legality callback for the given dialects.
+void ConversionTarget::setLegalityCallback(
+    ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
+  assert(callback && "expected valid legality callback");
+  for (StringRef dialect : dialects)
+    dialectLegalityFns[dialect] = callback;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Conversion Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Apply a partial conversion on the given operations, and all nested
+/// operations. This method converts as many operations to the target as
+/// possible, ignoring operations that failed to legalize.
+LogicalResult mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                                           ConversionTarget &target,
+                                           OwningRewritePatternList &patterns,
+                                           TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
+  return opConverter.convertOperations(ops, converter);
+}
+LogicalResult mlir::applyPartialConversion(Operation *op,
+                                           ConversionTarget &target,
+                                           OwningRewritePatternList &patterns,
+                                           TypeConverter *converter) {
+  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
+                                converter);
+}
+
+/// Apply a complete conversion on the given operations, and all nested
+/// operations. This method will return failure if the conversion of any
+/// operation fails.
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                                        ConversionTarget &target,
+                                        OwningRewritePatternList &patterns,
+                                        TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+  return opConverter.convertOperations(ops, converter);
+}
+LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
+                                        OwningRewritePatternList &patterns,
+                                        TypeConverter *converter) {
+  return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
+                             converter);
+}
+
+/// Apply an analysis conversion on the given operations, and all nested
+/// operations. This method analyzes which operations would be successfully
+/// converted to the target if a conversion was applied. All operations that
+/// were found to be legalizable to the given 'target' are placed within the
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+LogicalResult mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+                                            ConversionTarget &target,
+                                            OwningRewritePatternList &patterns,
+                                            DenseSet<Operation *> &convertedOps,
+                                            TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+                                 &convertedOps);
+  return opConverter.convertOperations(ops, converter);
+}
+LogicalResult mlir::applyAnalysisConversion(Operation *op,
+                                            ConversionTarget &target,
+                                            OwningRewritePatternList &patterns,
+                                            DenseSet<Operation *> &convertedOps,
+                                            TypeConverter *converter) {
+  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
+                                 convertedOps, converter);
+}
diff --git a/third_party/mlir/lib/Transforms/LoopCoalescing.cpp b/third_party/mlir/lib/Transforms/LoopCoalescing.cpp
new file mode 100644
index 0000000..f47433c
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopCoalescing.cpp
@@ -0,0 +1,105 @@
+//===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define PASS_NAME "loop-coalescing"
+#define DEBUG_TYPE PASS_NAME
+
+using namespace mlir;
+
+namespace {
+class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+
+    func.walk<loop::ForOp>([](loop::ForOp op) {
+      // Ignore nested loops.
+      if (op.getParentOfType<loop::ForOp>())
+        return;
+
+      SmallVector<loop::ForOp, 4> loops;
+      getPerfectlyNestedLoops(loops, op);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "found a perfect nest of depth " << loops.size() << '\n');
+
+      // Look for a band of loops that can be coalesced, i.e. perfectly nested
+      // loops with bounds defined above some loop.
+      // 1. For each loop, find above which parent loop its operands are
+      // defined.
+      SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+      for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+        operandsDefinedAbove[i] = i;
+        for (unsigned j = 0; j < i; ++j) {
+          if (areValuesDefinedAbove(loops[i].getOperands(),
+                                    loops[j].region())) {
+            operandsDefinedAbove[i] = j;
+            break;
+          }
+        }
+        LLVM_DEBUG(llvm::dbgs()
+                   << "  bounds of loop " << i << " are known above depth "
+                   << operandsDefinedAbove[i] << '\n');
+      }
+
+      // 2. Identify bands of loops such that the operands of all of them are
+      // defined above the first loop in the band.  Traverse the nest bottom-up
+      // so that modifications don't invalidate the inner loops.
+      for (unsigned end = loops.size(); end > 0; --end) {
+        unsigned start = 0;
+        for (; start < end - 1; ++start) {
+          auto maxPos =
+              *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                                std::next(operandsDefinedAbove.begin(), end));
+          if (maxPos > start)
+            continue;
+
+          assert(maxPos == start &&
+                 "expected loop bounds to be known at the start of the band");
+          LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
+                                  << " to " << end << '\n');
+
+          auto band =
+              llvm::makeMutableArrayRef(loops.data() + start, end - start);
+          coalesceLoops(band);
+          break;
+        }
+        // If a band was found and transformed, keep looking at the loops above
+        // the outermost transformed loop.
+        if (start != end - 1)
+          end = start + 1;
+      }
+    });
+  }
+};
+
+} // namespace
+
+FunctionPassBase *mlir::createLoopCoalescingPass() {
+  return new LoopCoalescingPass;
+}
+
+static PassRegistration<LoopCoalescingPass>
+    reg(PASS_NAME,
+        "coalesce nested loops with independent bounds into a single loop");
diff --git a/third_party/mlir/lib/Transforms/LoopFusion.cpp b/third_party/mlir/lib/Transforms/LoopFusion.cpp
new file mode 100644
index 0000000..ea1a03f
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopFusion.cpp
@@ -0,0 +1,1901 @@
+//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop fusion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LoopFusionUtils.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iomanip>
+#include <sstream>
+#define DEBUG_TYPE "affine-loop-fusion"
+
+using llvm::SetVector;
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+/// Disables fusion profitability check and fuses if valid. Ignore any
+/// additional (redundant) computation tolerance threshold
+/// that would have prevented fusion.
+static llvm::cl::opt<bool>
+    clMaximalLoopFusion("fusion-maximal",
+                        llvm::cl::desc("Enables maximal loop fusion"),
+                        llvm::cl::cat(clOptionsCategory));
+
+/// A threshold in percent of additional computation allowed when fusing.
+static llvm::cl::opt<double> clFusionAddlComputeTolerance(
+    "fusion-compute-tolerance",
+    llvm::cl::desc("Fractional increase in additional "
+                   "computation tolerated while fusing"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
+    "fusion-fast-mem-space",
+    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
+    llvm::cl::cat(clOptionsCategory));
+
+// A local buffer of size less than or equal to this size is automatically
+// promoted to fast memory after producer-consumer fusion.
+static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold(
+    "fusion-local-buf-threshold",
+    llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast "
+                   "memory space"),
+    llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// Loop fusion pass. This pass currently supports a greedy fusion policy,
+/// which fuses loop nests with single-writer/single-reader memref dependences
+/// with the goal of improving locality.
+
+// TODO(andydavis) Support fusion of source loop nests which write to multiple
+// memrefs, where each memref can have multiple users (if profitable).
+// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
+// and add support for more general loop fusion algorithms.
+
+struct LoopFusion : public FunctionPass<LoopFusion> {
+  LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
+             bool maximalFusion = false)
+      : localBufSizeThreshold(localBufSizeThreshold),
+        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
+
+  void runOnFunction() override;
+
+  // Any local buffers smaller than this size (in bytes) will be created in
+  // `fastMemorySpace` if provided.
+  uint64_t localBufSizeThreshold;
+  Optional<unsigned> fastMemorySpace = None;
+  // If true, ignore any additional (redundant) computation tolerance threshold
+  // that would have prevented fusion.
+  bool maximalFusion;
+
+  // The amount of additional computation that is tolerated while fusing
+  // pair-wise as a fraction of the total computation.
+  constexpr static double kComputeToleranceThreshold = 0.30f;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace,
+                                             uint64_t localBufSizeThreshold,
+                                             bool maximalFusion) {
+  return new LoopFusion(fastMemorySpace, localBufSizeThreshold, maximalFusion);
+}
+
+namespace {
+
+// LoopNestStateCollector walks loop nests and collects load and store
+// operations, and whether or not an IfInst was encountered in the loop nest.
+struct LoopNestStateCollector {
+  SmallVector<AffineForOp, 4> forOps;
+  SmallVector<Operation *, 4> loadOpInsts;
+  SmallVector<Operation *, 4> storeOpInsts;
+  bool hasNonForRegion = false;
+
+  void collect(Operation *opToWalk) {
+    opToWalk->walk([&](Operation *op) {
+      if (isa<AffineForOp>(op))
+        forOps.push_back(cast<AffineForOp>(op));
+      else if (op->getNumRegions() != 0)
+        hasNonForRegion = true;
+      else if (isa<AffineLoadOp>(op))
+        loadOpInsts.push_back(op);
+      else if (isa<AffineStoreOp>(op))
+        storeOpInsts.push_back(op);
+    });
+  }
+};
+
+// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
+static bool isMemRefDereferencingOp(Operation &op) {
+  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+      isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
+    return true;
+  return false;
+}
+
+// MemRefDependenceGraph is a graph data structure where graph nodes are
+// top-level operations in a FuncOp which contain load/store ops, and edges
+// are memref dependences between the nodes.
+// TODO(andydavis) Add a more flexible dependece graph representation.
+// TODO(andydavis) Add a depth parameter to dependence graph construction.
+struct MemRefDependenceGraph {
+public:
+  // Node represents a node in the graph. A Node is either an entire loop nest
+  // rooted at the top level which contains loads/stores, or a top level
+  // load/store.
+  struct Node {
+    // The unique identifier of this node in the graph.
+    unsigned id;
+    // The top-level statement which is (or contains) a load/store.
+    Operation *op;
+    // List of load operations.
+    SmallVector<Operation *, 4> loads;
+    // List of store op insts.
+    SmallVector<Operation *, 4> stores;
+    Node(unsigned id, Operation *op) : id(id), op(op) {}
+
+    // Returns the load op count for 'memref'.
+    unsigned getLoadOpCount(Value *memref) {
+      unsigned loadOpCount = 0;
+      for (auto *loadOpInst : loads) {
+        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+          ++loadOpCount;
+      }
+      return loadOpCount;
+    }
+
+    // Returns the store op count for 'memref'.
+    unsigned getStoreOpCount(Value *memref) {
+      unsigned storeOpCount = 0;
+      for (auto *storeOpInst : stores) {
+        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+          ++storeOpCount;
+      }
+      return storeOpCount;
+    }
+
+    // Returns all store ops in 'storeOps' which access 'memref'.
+    void getStoreOpsForMemref(Value *memref,
+                              SmallVectorImpl<Operation *> *storeOps) {
+      for (auto *storeOpInst : stores) {
+        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+          storeOps->push_back(storeOpInst);
+      }
+    }
+
+    // Returns all load ops in 'loadOps' which access 'memref'.
+    void getLoadOpsForMemref(Value *memref,
+                             SmallVectorImpl<Operation *> *loadOps) {
+      for (auto *loadOpInst : loads) {
+        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+          loadOps->push_back(loadOpInst);
+      }
+    }
+
+    // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
+    // has at least one load and store operation.
+    void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
+      llvm::SmallDenseSet<Value *, 2> loadMemrefs;
+      for (auto *loadOpInst : loads) {
+        loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+      }
+      for (auto *storeOpInst : stores) {
+        auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+        if (loadMemrefs.count(memref) > 0)
+          loadAndStoreMemrefSet->insert(memref);
+      }
+    }
+  };
+
+  // Edge represents a data dependece between nodes in the graph.
+  struct Edge {
+    // The id of the node at the other end of the edge.
+    // If this edge is stored in Edge = Node.inEdges[i], then
+    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
+    // If this edge is stored in Edge = Node.outEdges[i], then
+    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
+    unsigned id;
+    // The SSA value on which this edge represents a dependence.
+    // If the value is a memref, then the dependence is between graph nodes
+    // which contain accesses to the same memref 'value'. If the value is a
+    // non-memref value, then the dependence is between a graph node which
+    // defines an SSA value and another graph node which uses the SSA value
+    // (e.g. a constant operation defining a value which is used inside a loop
+    // nest).
+    Value *value;
+  };
+
+  // Map from node id to Node.
+  DenseMap<unsigned, Node> nodes;
+  // Map from node id to list of input edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
+  // Map from node id to list of output edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
+  // Map from memref to a count on the dependence edges associated with that
+  // memref.
+  DenseMap<Value *, unsigned> memrefEdgeCount;
+  // The next unique identifier to use for newly created graph nodes.
+  unsigned nextNodeId = 0;
+
+  MemRefDependenceGraph() {}
+
+  // Initializes the dependence graph based on operations in 'f'.
+  // Returns true on success, false otherwise.
+  bool init(FuncOp f);
+
+  // Returns the graph node for 'id'.
+  Node *getNode(unsigned id) {
+    auto it = nodes.find(id);
+    assert(it != nodes.end());
+    return &it->second;
+  }
+
+  // Returns the graph node for 'forOp'.
+  Node *getForOpNode(AffineForOp forOp) {
+    for (auto &idAndNode : nodes)
+      if (idAndNode.second.op == forOp.getOperation())
+        return &idAndNode.second;
+    return nullptr;
+  }
+
+  // Adds a node with 'op' to the graph and returns its unique identifier.
+  unsigned addNode(Operation *op) {
+    Node node(nextNodeId++, op);
+    nodes.insert({node.id, node});
+    return node.id;
+  }
+
+  // Remove node 'id' (and its associated edges) from graph.
+  void removeNode(unsigned id) {
+    // Remove each edge in 'inEdges[id]'.
+    if (inEdges.count(id) > 0) {
+      SmallVector<Edge, 2> oldInEdges = inEdges[id];
+      for (auto &inEdge : oldInEdges) {
+        removeEdge(inEdge.id, id, inEdge.value);
+      }
+    }
+    // Remove each edge in 'outEdges[id]'.
+    if (outEdges.count(id) > 0) {
+      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
+      for (auto &outEdge : oldOutEdges) {
+        removeEdge(id, outEdge.id, outEdge.value);
+      }
+    }
+    // Erase remaining node state.
+    inEdges.erase(id);
+    outEdges.erase(id);
+    nodes.erase(id);
+  }
+
+  // Returns true if node 'id' writes to any memref which escapes (or is an
+  // argument to) the function/block. Returns false otherwise.
+  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
+    Node *node = getNode(id);
+    for (auto *storeOpInst : node->stores) {
+      auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+      auto *op = memref->getDefiningOp();
+      // Return true if 'memref' is a block argument.
+      if (!op)
+        return true;
+      // Return true if any use of 'memref' escapes the function.
+      for (auto *user : memref->getUsers())
+        if (!isMemRefDereferencingOp(*user))
+          return true;
+    }
+    return false;
+  }
+
+  // Returns true if node 'id' can be removed from the graph. Returns false
+  // otherwise. A node can be removed from the graph iff the following
+  // conditions are met:
+  // *) The node does not write to any memref which escapes (or is a
+  //    function/block argument).
+  // *) The node has no successors in the dependence graph.
+  bool canRemoveNode(unsigned id) {
+    if (writesToLiveInOrEscapingMemrefs(id))
+      return false;
+    Node *node = getNode(id);
+    for (auto *storeOpInst : node->stores) {
+      // Return false if there exist out edges from 'id' on 'memref'.
+      if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+        return false;
+    }
+    return true;
+  }
+
+  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
+  // is for 'value' if non-null, or for any value otherwise. Returns false
+  // otherwise.
+  bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
+    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+      return false;
+    }
+    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+      return edge.id == dstId && (!value || edge.value == value);
+    });
+    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+      return edge.id == srcId && (!value || edge.value == value);
+    });
+    return hasOutEdge && hasInEdge;
+  }
+
+  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
+  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
+    if (!hasEdge(srcId, dstId, value)) {
+      outEdges[srcId].push_back({dstId, value});
+      inEdges[dstId].push_back({srcId, value});
+      if (value->getType().isa<MemRefType>())
+        memrefEdgeCount[value]++;
+    }
+  }
+
+  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
+  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
+    assert(inEdges.count(dstId) > 0);
+    assert(outEdges.count(srcId) > 0);
+    if (value->getType().isa<MemRefType>()) {
+      assert(memrefEdgeCount.count(value) > 0);
+      memrefEdgeCount[value]--;
+    }
+    // Remove 'srcId' from 'inEdges[dstId]'.
+    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
+      if ((*it).id == srcId && (*it).value == value) {
+        inEdges[dstId].erase(it);
+        break;
+      }
+    }
+    // Remove 'dstId' from 'outEdges[srcId]'.
+    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
+      if ((*it).id == dstId && (*it).value == value) {
+        outEdges[srcId].erase(it);
+        break;
+      }
+    }
+  }
+
+  // Returns true if there is a path in the dependence graph from node 'srcId'
+  // to node 'dstId'. Returns false otherwise.
+  bool hasDependencePath(unsigned srcId, unsigned dstId) {
+    // Worklist state is: <node-id, next-output-edge-index-to-visit>
+    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
+    worklist.push_back({srcId, 0});
+    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
+    while (!worklist.empty()) {
+      auto &idAndIndex = worklist.back();
+      // Return true if we have reached 'dstId'.
+      if (idAndIndex.first == dstId)
+        return true;
+      // Pop and continue if node has no out edges, or if all out edges have
+      // already been visited.
+      if (outEdges.count(idAndIndex.first) == 0 ||
+          idAndIndex.second == outEdges[idAndIndex.first].size()) {
+        worklist.pop_back();
+        continue;
+      }
+      // Get graph edge to traverse.
+      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+      // Increment next output edge index for 'idAndIndex'.
+      ++idAndIndex.second;
+      // Add node at 'edge.id' to worklist.
+      worklist.push_back({edge.id, 0});
+    }
+    return false;
+  }
+
+  // Returns the input edge count for node 'id' and 'memref' from src nodes
+  // which access 'memref' with a store operation.
+  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
+    unsigned inEdgeCount = 0;
+    if (inEdges.count(id) > 0)
+      for (auto &inEdge : inEdges[id])
+        if (inEdge.value == memref) {
+          Node *srcNode = getNode(inEdge.id);
+          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+          if (srcNode->getStoreOpCount(memref) > 0)
+            ++inEdgeCount;
+        }
+    return inEdgeCount;
+  }
+
+  // Returns the output edge count for node 'id' and 'memref' (if non-null),
+  // otherwise returns the total output edge count from node 'id'.
+  unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
+    unsigned outEdgeCount = 0;
+    if (outEdges.count(id) > 0)
+      for (auto &outEdge : outEdges[id])
+        if (!memref || outEdge.value == memref)
+          ++outEdgeCount;
+    return outEdgeCount;
+  }
+
+  // Computes and returns an insertion point operation, before which the
+  // the fused <srcId, dstId> loop nest can be inserted while preserving
+  // dependences. Returns nullptr if no such insertion point is found.
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
+    if (outEdges.count(srcId) == 0)
+      return getNode(dstId)->op;
+
+    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
+    SmallPtrSet<Operation *, 2> srcDepInsts;
+    for (auto &outEdge : outEdges[srcId])
+      if (outEdge.id != dstId)
+        srcDepInsts.insert(getNode(outEdge.id)->op);
+
+    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
+    SmallPtrSet<Operation *, 2> dstDepInsts;
+    for (auto &inEdge : inEdges[dstId])
+      if (inEdge.id != srcId)
+        dstDepInsts.insert(getNode(inEdge.id)->op);
+
+    Operation *srcNodeInst = getNode(srcId)->op;
+    Operation *dstNodeInst = getNode(dstId)->op;
+
+    // Computing insertion point:
+    // *) Walk all operation positions in Block operation list in the
+    //    range (src, dst). For each operation 'op' visited in this search:
+    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
+    //      dependence edge from 'srcNode'.
+    //   *) Store in 'lastDstDepPost' the last position where 'op' has a
+    //      dependence edge to 'dstNode'.
+    // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
+    //    operation insertion point (or return null pointer if no such
+    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
+    SmallVector<Operation *, 2> depInsts;
+    Optional<unsigned> firstSrcDepPos;
+    Optional<unsigned> lastDstDepPos;
+    unsigned pos = 0;
+    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
+         it != Block::iterator(dstNodeInst); ++it) {
+      Operation *op = &(*it);
+      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
+        firstSrcDepPos = pos;
+      if (dstDepInsts.count(op) > 0)
+        lastDstDepPos = pos;
+      depInsts.push_back(op);
+      ++pos;
+    }
+
+    if (firstSrcDepPos.hasValue()) {
+      if (lastDstDepPos.hasValue()) {
+        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
+          // No valid insertion point exists which preserves dependences.
+          return nullptr;
+        }
+      }
+      // Return the insertion point at 'firstSrcDepPos'.
+      return depInsts[firstSrcDepPos.getValue()];
+    }
+    // No dependence targets in range (or only dst deps in range), return
+    // 'dstNodInst' insertion point.
+    return dstNodeInst;
+  }
+
+  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
+  // has been replaced in node at 'dstId' by a private memref.
+  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
+    // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
+    if (inEdges.count(srcId) > 0) {
+      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
+      for (auto &inEdge : oldInEdges) {
+        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
+        if (inEdge.value != oldMemRef)
+          addEdge(inEdge.id, dstId, inEdge.value);
+      }
+    }
+    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
+    if (outEdges.count(srcId) > 0) {
+      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
+      for (auto &outEdge : oldOutEdges) {
+        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
+        if (outEdge.id == dstId)
+          removeEdge(srcId, outEdge.id, outEdge.value);
+      }
+    }
+    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
+    // replaced by a private memref). These edges could come from nodes
+    // other than 'srcId' which were removed in the previous step.
+    if (inEdges.count(dstId) > 0) {
+      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
+      for (auto &inEdge : oldInEdges)
+        if (inEdge.value == oldMemRef)
+          removeEdge(inEdge.id, dstId, inEdge.value);
+    }
+  }
+
+  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+  // of sibling node 'sidId' into node 'dstId'.
+  void updateEdges(unsigned sibId, unsigned dstId) {
+    // For each edge in 'inEdges[sibId]':
+    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
+    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
+    if (inEdges.count(sibId) > 0) {
+      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
+      for (auto &inEdge : oldInEdges) {
+        addEdge(inEdge.id, dstId, inEdge.value);
+        removeEdge(inEdge.id, sibId, inEdge.value);
+      }
+    }
+
+    // For each edge in 'outEdges[sibId]' to node 'id'
+    // *) Add new edge from 'dstId' to 'outEdge.id'.
+    // *) Remove edge from 'sibId' to 'outEdge.id'.
+    if (outEdges.count(sibId) > 0) {
+      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
+      for (auto &outEdge : oldOutEdges) {
+        addEdge(dstId, outEdge.id, outEdge.value);
+        removeEdge(sibId, outEdge.id, outEdge.value);
+      }
+    }
+  }
+
+  // Adds ops in 'loads' and 'stores' to node at 'id'.
+  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
+                 const SmallVectorImpl<Operation *> &stores) {
+    Node *node = getNode(id);
+    for (auto *loadOpInst : loads)
+      node->loads.push_back(loadOpInst);
+    for (auto *storeOpInst : stores)
+      node->stores.push_back(storeOpInst);
+  }
+
+  void clearNodeLoadAndStores(unsigned id) {
+    Node *node = getNode(id);
+    node->loads.clear();
+    node->stores.clear();
+  }
+
+  // Calls 'callback' for each input edge incident to node 'id' which carries a
+  // memref dependence.
+  void forEachMemRefInputEdge(unsigned id,
+                              const std::function<void(Edge)> &callback) {
+    if (inEdges.count(id) > 0)
+      forEachMemRefEdge(inEdges[id], callback);
+  }
+
+  // Calls 'callback' for each output edge from node 'id' which carries a
+  // memref dependence.
+  void forEachMemRefOutputEdge(unsigned id,
+                               const std::function<void(Edge)> &callback) {
+    if (outEdges.count(id) > 0)
+      forEachMemRefEdge(outEdges[id], callback);
+  }
+
+  // Calls 'callback' for each edge in 'edges' which carries a memref
+  // dependence.
+  void forEachMemRefEdge(ArrayRef<Edge> edges,
+                         const std::function<void(Edge)> &callback) {
+    for (auto &edge : edges) {
+      // Skip if 'edge' is not a memref dependence edge.
+      if (!edge.value->getType().isa<MemRefType>())
+        continue;
+      assert(nodes.count(edge.id) > 0);
+      // Skip if 'edge.id' is not a loop nest.
+      if (!isa<AffineForOp>(getNode(edge.id)->op))
+        continue;
+      // Visit current input edge 'edge'.
+      callback(edge);
+    }
+  }
+
+  void print(raw_ostream &os) const {
+    os << "\nMemRefDependenceGraph\n";
+    os << "\nNodes:\n";
+    for (auto &idAndNode : nodes) {
+      os << "Node: " << idAndNode.first << "\n";
+      auto it = inEdges.find(idAndNode.first);
+      if (it != inEdges.end()) {
+        for (const auto &e : it->second)
+          os << "  InEdge: " << e.id << " " << e.value << "\n";
+      }
+      it = outEdges.find(idAndNode.first);
+      if (it != outEdges.end()) {
+        for (const auto &e : it->second)
+          os << "  OutEdge: " << e.id << " " << e.value << "\n";
+      }
+    }
+  }
+  void dump() const { print(llvm::errs()); }
+};
+
+// Intializes the data dependence graph by walking operations in 'f'.
+// Assigns each node in the graph a node id based on program order in 'f'.
+// TODO(andydavis) Add support for taking a Block arg to construct the
+// dependence graph at a different depth.
+bool MemRefDependenceGraph::init(FuncOp f) {
+  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
+
+  // TODO: support multi-block functions.
+  if (f.getBlocks().size() != 1)
+    return false;
+
+  DenseMap<Operation *, unsigned> forToNodeMap;
+  for (auto &op : f.front()) {
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      // Create graph node 'id' to represent top-level 'forOp' and record
+      // all loads and store accesses it contains.
+      LoopNestStateCollector collector;
+      collector.collect(&op);
+      // Return false if a non 'affine.for' region was found (not currently
+      // supported).
+      if (collector.hasNonForRegion)
+        return false;
+      Node node(nextNodeId++, &op);
+      for (auto *opInst : collector.loadOpInsts) {
+        node.loads.push_back(opInst);
+        auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
+        memrefAccesses[memref].insert(node.id);
+      }
+      for (auto *opInst : collector.storeOpInsts) {
+        node.stores.push_back(opInst);
+        auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
+        memrefAccesses[memref].insert(node.id);
+      }
+      forToNodeMap[&op] = node.id;
+      nodes.insert({node.id, node});
+    } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+      // Create graph node for top-level load op.
+      Node node(nextNodeId++, &op);
+      node.loads.push_back(&op);
+      auto *memref = cast<AffineLoadOp>(op).getMemRef();
+      memrefAccesses[memref].insert(node.id);
+      nodes.insert({node.id, node});
+    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+      // Create graph node for top-level store op.
+      Node node(nextNodeId++, &op);
+      node.stores.push_back(&op);
+      auto *memref = cast<AffineStoreOp>(op).getMemRef();
+      memrefAccesses[memref].insert(node.id);
+      nodes.insert({node.id, node});
+    } else if (op.getNumRegions() != 0) {
+      // Return false if another region is found (not currently supported).
+      return false;
+    } else if (op.getNumResults() > 0 && !op.use_empty()) {
+      // Create graph node for top-level producer of SSA values, which
+      // could be used by loop nest nodes.
+      Node node(nextNodeId++, &op);
+      nodes.insert({node.id, node});
+    }
+  }
+
+  // Add dependence edges between nodes which produce SSA values and their
+  // users.
+  for (auto &idAndNode : nodes) {
+    const Node &node = idAndNode.second;
+    if (!node.loads.empty() || !node.stores.empty())
+      continue;
+    auto *opInst = node.op;
+    for (auto *value : opInst->getResults()) {
+      for (auto *user : value->getUsers()) {
+        SmallVector<AffineForOp, 4> loops;
+        getLoopIVs(*user, &loops);
+        if (loops.empty())
+          continue;
+        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
+        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
+        addEdge(node.id, userLoopNestId, value);
+      }
+    }
+  }
+
+  // Walk memref access lists and add graph edges between dependent nodes.
+  for (auto &memrefAndList : memrefAccesses) {
+    unsigned n = memrefAndList.second.size();
+    for (unsigned i = 0; i < n; ++i) {
+      unsigned srcId = memrefAndList.second[i];
+      bool srcHasStore =
+          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
+      for (unsigned j = i + 1; j < n; ++j) {
+        unsigned dstId = memrefAndList.second[j];
+        bool dstHasStore =
+            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
+        if (srcHasStore || dstHasStore)
+          addEdge(srcId, dstId, memrefAndList.first);
+      }
+    }
+  }
+  return true;
+}
+
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void moveLoadsAccessingMemrefTo(Value *memref,
+                                       SmallVectorImpl<Operation *> *srcLoads,
+                                       SmallVectorImpl<Operation *> *dstLoads) {
+  dstLoads->clear();
+  SmallVector<Operation *, 4> srcLoadsToKeep;
+  for (auto *load : *srcLoads) {
+    if (cast<AffineLoadOp>(load).getMemRef() == memref)
+      dstLoads->push_back(load);
+    else
+      srcLoadsToKeep.push_back(load);
+  }
+  srcLoads->swap(srcLoadsToKeep);
+}
+
+// Returns the innermost common loop depth for the set of operations in 'ops'.
+static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
+  unsigned numOps = ops.size();
+  assert(numOps > 0);
+
+  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
+  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+  for (unsigned i = 0; i < numOps; ++i) {
+    getLoopIVs(*ops[i], &loops[i]);
+    loopDepthLimit =
+        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
+  }
+
+  unsigned loopDepth = 0;
+  for (unsigned d = 0; d < loopDepthLimit; ++d) {
+    unsigned i;
+    for (i = 1; i < numOps; ++i) {
+      if (loops[i - 1][d] != loops[i][d])
+        break;
+    }
+    if (i != numOps)
+      break;
+    ++loopDepth;
+  }
+  return loopDepth;
+}
+
+// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
+// and 'storeOpInsts' are satisfied.
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
+                                ArrayRef<Operation *> storeOpInsts) {
+  // Merge loads and stores into the same array.
+  SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
+  ops.append(storeOpInsts.begin(), storeOpInsts.end());
+
+  // Compute the innermost common loop depth for loads and stores.
+  unsigned loopDepth = getInnermostCommonLoopDepth(ops);
+
+  // Return common loop depth for loads if there are no store ops.
+  if (storeOpInsts.empty())
+    return loopDepth;
+
+  // Check dependences on all pairs of ops in 'ops' and store the minimum
+  // loop depth at which a dependence is satisfied.
+  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
+    auto *srcOpInst = ops[i];
+    MemRefAccess srcAccess(srcOpInst);
+    for (unsigned j = 0; j < e; ++j) {
+      auto *dstOpInst = ops[j];
+      MemRefAccess dstAccess(dstOpInst);
+
+      unsigned numCommonLoops =
+          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
+      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+        FlatAffineConstraints dependenceConstraints;
+        // TODO(andydavis) Cache dependence analysis results, check cache here.
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints,
+            /*dependenceComponents=*/nullptr);
+        if (hasDependence(result)) {
+          // Store minimum loop depth and break because we want the min 'd' at
+          // which there is a dependence.
+          loopDepth = std::min(loopDepth, d - 1);
+          break;
+        }
+      }
+    }
+  }
+  return loopDepth;
+}
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the
+// outermost (while again preserving relative order among them).
+// This can increase the loop depth at which we can fuse a slice, since we are
+// pushing loop carried dependence to a greater depth in the loop nest.
+static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
+  assert(isa<AffineForOp>(node->op));
+  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
+  node->op = newRootForOp.getOperation();
+}
+
+//  TODO(mlir-team): improve/complete this when we have target data.
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Creates and returns a private (single-user) memref for fused loop rooted
+// at 'forOp', with (potentially reduced) memref size based on the
+// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+// TODO(bondhugula): consider refactoring the common code from generateDma and
+// this one.
+static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+                                  unsigned dstLoopDepth,
+                                  Optional<unsigned> fastMemorySpace,
+                                  uint64_t localBufSizeThreshold) {
+  auto *forInst = forOp.getOperation();
+
+  // Create builder to insert alloc op just before 'forOp'.
+  OpBuilder b(forInst);
+  // Builder to create constants at the top level.
+  OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
+  // Create new memref type based on slice bounds.
+  auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+  unsigned rank = oldMemRefType.getRank();
+
+  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  MemRefRegion region(srcStoreOpInst->getLoc());
+  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+  (void)validRegion;
+  assert(validRegion && "unexpected memref region failure");
+  SmallVector<int64_t, 4> newShape;
+  std::vector<SmallVector<int64_t, 4>> lbs;
+  SmallVector<int64_t, 8> lbDivisors;
+  lbs.reserve(rank);
+  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
+  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  Optional<int64_t> numElements =
+      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
+  assert(numElements.hasValue() &&
+         "non-constant number of elts in local buffer");
+
+  const FlatAffineConstraints *cst = region.getConstraints();
+  // 'outerIVs' holds the values that this memory region is symbolic/paramteric
+  // on; this would correspond to loop IVs surrounding the level at which the
+  // slice is being materialized.
+  SmallVector<Value *, 8> outerIVs;
+  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
+
+  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
+  SmallVector<AffineExpr, 4> offsets;
+  offsets.reserve(rank);
+  for (unsigned d = 0; d < rank; ++d) {
+    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
+
+    AffineExpr offset = top.getAffineConstantExpr(0);
+    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
+      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
+    }
+    assert(lbDivisors[d] > 0);
+    offset =
+        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
+    offsets.push_back(offset);
+  }
+
+  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
+  // by 'srcStoreOpInst'.
+  uint64_t bufSize =
+      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
+  unsigned newMemSpace;
+  if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
+    newMemSpace = fastMemorySpace.getValue();
+  } else {
+    newMemSpace = oldMemRefType.getMemorySpace();
+  }
+  auto newMemRefType = top.getMemRefType(
+      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
+  // Gather alloc operands for the dynamic dimensions of the memref.
+  SmallVector<Value *, 4> allocOperands;
+  unsigned dynamicDimCount = 0;
+  for (auto dimSize : oldMemRefType.getShape()) {
+    if (dimSize == -1)
+      allocOperands.push_back(
+          top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
+  }
+
+  // Create new private memref for fused loop 'forOp'.
+  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
+  // consumer loop nests to reduce their live range. Currently they are added
+  // at the beginning of the function, because loop nests can be reordered
+  // during the fusion pass.
+  Value *newMemRef =
+      top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
+
+  // Build an AffineMap to remap access functions based on lower bound offsets.
+  SmallVector<AffineExpr, 4> remapExprs;
+  remapExprs.reserve(rank);
+  unsigned zeroOffsetCount = 0;
+  for (unsigned i = 0; i < rank; i++) {
+    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
+      if (constExpr.getValue() == 0)
+        ++zeroOffsetCount;
+    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
+
+    auto remapExpr =
+        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
+    remapExprs.push_back(remapExpr);
+  }
+  auto indexRemap = zeroOffsetCount == rank
+                        ? AffineMap()
+                        : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
+  // Replace all users of 'oldMemRef' with 'newMemRef'.
+  bool ret =
+      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
+                               /*extraOperands=*/outerIVs,
+                               /*domInstFilter=*/&*forOp.getBody()->begin());
+  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
+  (void)ret;
+  return newMemRef;
+}
+
+// Checks if node 'srcId' (which writes to a live out memref), can be safely
+// fused into node 'dstId'. Returns true if the following conditions are met:
+// *) 'srcNode' only writes to live out 'memref'.
+// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
+// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
+//    write region to 'memref'.
+// TODO(andydavis) Generalize this to handle more live in/out cases.
+static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+                                           Value *memref,
+                                           MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+
+  // Gather all memrefs from 'srcNode' store ops.
+  DenseSet<Value *> storeMemrefs;
+  for (auto *storeOpInst : srcNode->stores) {
+    storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+  }
+  // Return false if any of the following are true:
+  // *) 'srcNode' writes to a live in/out memref other than 'memref'.
+  // *) 'srcNode' has more than one output edge on 'memref'.
+  // Check that all stores are to the same memref.
+  if (storeMemrefs.size() != 1 ||
+      mdg->getOutEdgeCount(srcNode->id, memref) != 1)
+    return false;
+  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
+  auto *srcStoreOpInst = srcNode->stores.front();
+  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
+  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> srcShape;
+  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
+  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  Optional<int64_t> srcNumElements =
+      srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+  if (!srcNumElements.hasValue())
+    return false;
+
+  // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
+  // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
+  // each store op in 'dstStoreOps').
+  SmallVector<Operation *, 2> dstStoreOps;
+  dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+  SmallVector<Operation *, 2> dstLoadOps;
+  dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
+
+  auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
+  MemRefRegion dstRegion(dstOpInst->getLoc());
+  if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for dest operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> dstShape;
+  // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
+  // by 'dstOpInst' at depth 'dstLoopDepth'.
+  Optional<int64_t> dstNumElements =
+      dstRegion.getConstantBoundingSizeAndShape(&dstShape);
+  if (!dstNumElements.hasValue())
+    return false;
+
+  // Return false if write region is not a superset of 'srcNodes' write
+  // region to 'memref'.
+  // TODO(andydavis) Check the shape and lower bounds here too.
+  if (srcNumElements != dstNumElements)
+    return false;
+  return true;
+}
+
+// Checks the profitability of fusing a backwards slice of the loop nest
+// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
+// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
+// the memref being produced and consumed, which is an input to the cost model.
+// For producer-constumer fusion, 'srcStoreOpInst' will be the same as
+// 'srcOpInst', as we are slicing w.r.t to that producer.
+// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
+// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
+// will be the unique store op in the src node, which will be used to check
+// that the write region is the same after input-reuse fusion.
+// Returns true if it is profitable to fuse the candidate loop nests. Returns
+// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
+// to materialize the source loop nest slice.
+// The profitability model executes the following steps:
+// *) Computes the backward computation slice at 'srcOpInst'. This
+//    computation slice of the loop nest surrounding 'srcOpInst' is
+//    represented by modified src loop bounds in 'sliceState', which are
+//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
+// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
+//    loop nest is the total number of dynamic operation instances in the loop
+//    nest).
+// *) Computes the cost of fusing a slice of the src loop nest into the dst
+//    loop nest at various values of dst loop depth, attempting to fuse
+//    the largest compution slice at the maximal dst loop depth (closest to the
+//    load) to minimize reuse distance and potentially enable subsequent
+//    load/store forwarding.
+//    NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
+//    the same memref as is written by 'srcOpInst', then the union of slice
+//    loop bounds is used to compute the slice and associated slice cost.
+//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
+//    nest, at which the src computation slice is inserted/fused.
+//    NOTE: We attempt to maximize the dst loop depth, but there are cases
+//    where a particular setting for 'dstLoopNest' might fuse an unsliced
+//    loop (within the src computation slice) at a depth which results in
+//    execessive recomputation (see unit tests for examples).
+// *) Compares the total cost of the unfused loop nests to the min cost fused
+//    loop nest computed in the previous step, and returns true if the latter
+//    is lower.
+static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+                               ArrayRef<Operation *> dstLoadOpInsts,
+                               ArrayRef<Operation *> dstStoreOpInsts,
+                               ComputationSliceState *sliceState,
+                               unsigned *dstLoopDepth, bool maximalFusion) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
+    llvm::dbgs() << " " << *srcOpInst << " and \n";
+    for (auto dstOpInst : dstLoadOpInsts) {
+      llvm::dbgs() << " " << *dstOpInst << "\n";
+    };
+  });
+
+  // Compute cost of sliced and unsliced src loop nest.
+  SmallVector<AffineForOp, 4> srcLoopIVs;
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
+  unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+  // Walk src loop nest and collect stats.
+  LoopNestStats srcLoopNestStats;
+  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+    return false;
+
+  // Compute cost of dst loop nest.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+
+  LoopNestStats dstLoopNestStats;
+  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
+    return false;
+
+  // Compute the maximum loop depth at which we can can insert the src slice
+  // and still satisfy dest loop nest dependences, for producer-consumer fusion.
+  unsigned maxDstLoopDepth =
+      (srcOpInst == srcStoreOpInst)
+          ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
+          : dstLoopIVs.size();
+  if (maxDstLoopDepth == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n");
+    return false;
+  }
+
+  // Search for min cost value for 'dstLoopDepth'. At each value of
+  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+  // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
+  // of these bounds). Next the union slice bounds are used to calculate
+  // the cost of the slice and the cost of the slice inserted into the dst
+  // loop nest at 'dstLoopDepth'.
+  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
+  double maxStorageReduction = 0.0;
+  Optional<uint64_t> sliceMemEstimate = None;
+
+  SmallVector<ComputationSliceState, 4> sliceStates;
+  sliceStates.resize(maxDstLoopDepth);
+  // The best loop depth at which to materialize the slice.
+  Optional<unsigned> bestDstLoopDepth = None;
+
+  // Compute op instance count for the src loop nest without iteration slicing.
+  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+
+  // Compute src loop nest write region size.
+  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
+  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
+
+  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
+      srcWriteRegion.getRegionSize();
+  if (!maybeSrcWriteRegionSizeBytes.hasValue())
+    return false;
+  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
+
+  // Compute op instance count for the src loop nest.
+  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
+
+  // Evaluate all depth choices for materializing the slice in the destination
+  // loop nest.
+  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
+    // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
+    if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
+                                       /*loopDepth=*/i,
+                                       /*numCommonLoops=*/0,
+                                       /*isBackwardSlice=*/true,
+                                       &sliceStates[i - 1]))) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "computeSliceUnion failed for loopDepth: " << i << "\n");
+      continue;
+    }
+
+    int64_t fusedLoopNestComputeCost;
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+                              dstLoopNestStats, &sliceStates[i - 1],
+                              &fusedLoopNestComputeCost)) {
+      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
+      continue;
+    }
+
+    double additionalComputeFraction =
+        fusedLoopNestComputeCost /
+            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+        1;
+
+    // Determine what the slice write MemRefRegion would be, if the src loop
+    // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
+    // nest at loop depth 'i'
+    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
+    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
+                                        &sliceStates[i - 1]))) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Failed to compute slice write region at loopDepth: " << i
+                 << "\n");
+      continue;
+    }
+
+    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
+        sliceWriteRegion.getRegionSize();
+    if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
+        maybeSliceWriteRegionSizeBytes.getValue() == 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Failed to get slice write region size at loopDepth: " << i
+                 << "\n");
+      continue;
+    }
+    int64_t sliceWriteRegionSizeBytes =
+        maybeSliceWriteRegionSizeBytes.getValue();
+
+    // If we are fusing for reuse, check that write regions remain the same.
+    // TODO(andydavis) Write region check should check sizes and offsets in
+    // each dimension, so that we are sure they are covering the same memref
+    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
+    if (srcOpInst != srcStoreOpInst &&
+        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
+      continue;
+
+    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
+                              static_cast<double>(sliceWriteRegionSizeBytes);
+
+    LLVM_DEBUG({
+      std::stringstream msg;
+      msg << "  evaluating fusion profitability at depth : " << i << "\n"
+          << std::fixed << std::setprecision(2)
+          << "   additional compute fraction: "
+          << 100.0 * additionalComputeFraction << "%\n"
+          << "   storage reduction factor: " << storageReduction << "x\n"
+          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
+          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
+          << "   slice write region size: " << sliceWriteRegionSizeBytes
+          << "\n";
+      llvm::dbgs() << msg.str();
+    });
+
+    double computeToleranceThreshold =
+        clFusionAddlComputeTolerance.getNumOccurrences() > 0
+            ? clFusionAddlComputeTolerance
+            : LoopFusion::kComputeToleranceThreshold;
+
+    // TODO(b/123247369): This is a placeholder cost model.
+    // Among all choices that add an acceptable amount of redundant computation
+    // (as per computeToleranceThreshold), we will simply pick the one that
+    // reduces the intermediary size the most.
+    if ((storageReduction > maxStorageReduction) &&
+        (maximalFusion ||
+         (additionalComputeFraction < computeToleranceThreshold))) {
+      maxStorageReduction = storageReduction;
+      bestDstLoopDepth = i;
+      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+      sliceMemEstimate = sliceWriteRegionSizeBytes;
+    }
+  }
+
+  // A simple cost model: fuse if it reduces the memory footprint. If
+  // -maximal-fusion is set, fuse nevertheless.
+
+  if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "All fusion choices involve more than the threshold amount of "
+           "redundant computation; NOT fusing.\n");
+    return false;
+  }
+
+  if (!bestDstLoopDepth.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+    return false;
+  }
+
+  // Set dstLoopDepth based on best values from search.
+  *dstLoopDepth = bestDstLoopDepth.getValue();
+
+  LLVM_DEBUG(
+      llvm::dbgs() << " LoopFusion fusion stats:"
+                   << "\n  best loop depth: " << bestDstLoopDepth
+                   << "\n  src loop nest compute cost: " << srcLoopNestCost
+                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
+                   << "\n  fused loop nest compute cost: "
+                   << minFusedLoopNestComputeCost << "\n");
+
+  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+
+  Optional<double> storageReduction = None;
+
+  if (!maximalFusion) {
+    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
+      return false;
+    }
+
+    auto srcMemSizeVal = srcMemSize.getValue();
+    auto dstMemSizeVal = dstMemSize.getValue();
+
+    assert(sliceMemEstimate.hasValue() && "expected value");
+    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+
+    LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
+                            << "   dst mem: " << dstMemSizeVal << "\n"
+                            << "   fused mem: " << fusedMem << "\n"
+                            << "   slice mem: " << sliceMemEstimate << "\n");
+
+    if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
+      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+      return false;
+    }
+    storageReduction =
+        100.0 *
+        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+  }
+
+  double additionalComputeFraction =
+      100.0 * (minFusedLoopNestComputeCost /
+                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+               1);
+  (void)additionalComputeFraction;
+  LLVM_DEBUG({
+    std::stringstream msg;
+    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
+        << std::setprecision(2) << additionalComputeFraction
+        << "% redundant computation and a ";
+    msg << (storageReduction.hasValue()
+                ? std::to_string(storageReduction.getValue())
+                : "<unknown>");
+    msg << "% storage reduction.\n";
+    llvm::dbgs() << msg.str();
+  });
+
+  // Update return parameter 'sliceState' with 'bestSliceState'.
+  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
+  sliceState->lbs = bestSliceState->lbs;
+  sliceState->ubs = bestSliceState->ubs;
+  sliceState->lbOperands = bestSliceState->lbOperands;
+  sliceState->ubOperands = bestSliceState->ubOperands;
+
+  // Canonicalize slice bound affine maps.
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    if (sliceState->lbs[i] != AffineMap()) {
+      canonicalizeMapAndOperands(&sliceState->lbs[i],
+                                 &sliceState->lbOperands[i]);
+    }
+    if (sliceState->ubs[i] != AffineMap()) {
+      canonicalizeMapAndOperands(&sliceState->ubs[i],
+                                 &sliceState->ubOperands[i]);
+    }
+  }
+  return true;
+}
+
+// GreedyFusion greedily fuses loop nests which have a producer/consumer or
+// input-reuse relationship on a memref, with the goal of improving locality.
+//
+// The steps of the producer-consumer fusion algorithm are as follows:
+//
+// *) A worklist is initialized with node ids from the dependence graph.
+// *) For each node id in the worklist:
+//   *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
+//      candidate destination AffineForOp into which fusion will be attempted.
+//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
+//   *) For each LoadOp in 'dstLoadOps' do:
+//      *) Look up dependent loop nests which have a single store op to the same
+//         memref.
+//      *) Check if dependences would be violated by the fusion.
+//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
+//         bounds to be functions of 'dstLoopNest' IVs and symbols.
+//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
+//         at a loop depth determined by the cost model in 'isFusionProfitable'.
+//      *) Add the newly fused load/store operations to the state,
+//         and also add newly fused load ops to 'dstLoopOps' to be considered
+//         as fusion dst load ops in another iteration.
+//      *) Remove old src loop nest and its associated state.
+//
+// The steps of the input-reuse fusion algorithm are as follows:
+//
+// *) Initialize 'worklist' with node ids from the dependence graph.
+// *) For each 'dstNode' in the worklist:
+//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
+//      loads from the same memref, but which has no dependence paths to/from.
+//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
+//      bounds to be functions of 'dstLoopNest' IVs and symbols.
+//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
+//      at a loop depth determined by the cost model in 'isFusionProfitable'.
+//      This function also checks that the memref write region of 'sibLoopNest',
+//      is preserved in the fused loop nest.
+//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
+//
+// Given a graph where top-level operations are vertices in the set 'V' and
+// edges in the set 'E' are dependences between vertices, this algorithm
+// takes O(V) time for initialization, and has runtime O(V + E).
+//
+// This greedy algorithm is not 'maximal' due to the current restriction of
+// fusing along single producer consumer edges, but there is a TODO to fix this.
+//
+// TODO(andydavis) Experiment with other fusion policies.
+struct GreedyFusion {
+public:
+  // The data dependence graph to traverse during fusion.
+  MemRefDependenceGraph *mdg;
+  // Worklist of graph nodes visited during the fusion pass.
+  SmallVector<unsigned, 8> worklist;
+  // Set of graph nodes which are present on the worklist.
+  llvm::SmallDenseSet<unsigned, 16> worklistSet;
+  // Parameter for local buffer size threshold.
+  unsigned localBufSizeThreshold;
+  // Parameter for fast memory space.
+  Optional<unsigned> fastMemorySpace;
+  // If true, ignore any additional (redundant) computation tolerance threshold
+  // that would have prevented fusion.
+  bool maximalFusion;
+
+  using Node = MemRefDependenceGraph::Node;
+
+  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
+               Optional<unsigned> fastMemorySpace, bool maximalFusion)
+      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
+        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
+
+  // Initializes 'worklist' with nodes from 'mdg'
+  void init() {
+    // TODO(andydavis) Add a priority queue for prioritizing nodes by different
+    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
+    worklist.clear();
+    worklistSet.clear();
+    for (auto &idAndNode : mdg->nodes) {
+      const Node &node = idAndNode.second;
+      worklist.push_back(node.id);
+      worklistSet.insert(node.id);
+    }
+  }
+
+  // Run the GreedyFusion pass.
+  // *) First pass through the nodes fuses single-use producer nodes into their
+  //    unique consumer.
+  // *) Second pass fuses sibling nodes which share no dependence edges.
+  // *) Third pass fuses any remaining producer nodes into their users.
+  void run() {
+    // TODO(andydavis) Run this repeatedly until a fixed-point is reached.
+    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
+    fuseSiblingNodes();
+    fuseProducerConsumerNodes(
+        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
+    eraseUnusedMemRefAllocations();
+  }
+
+  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+    init();
+    while (!worklist.empty()) {
+      unsigned dstId = worklist.back();
+      worklist.pop_back();
+      worklistSet.erase(dstId);
+
+      // Skip if this node was removed (fused into another node).
+      if (mdg->nodes.count(dstId) == 0)
+        continue;
+      // Get 'dstNode' into which to attempt fusion.
+      auto *dstNode = mdg->getNode(dstId);
+      // Skip if 'dstNode' is not a loop nest.
+      if (!isa<AffineForOp>(dstNode->op))
+        continue;
+      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
+      // while preserving relative order. This can increase the maximum loop
+      // depth at which we can fuse a slice of a producer loop nest into a
+      // consumer loop nest.
+      sinkSequentialLoops(dstNode);
+
+      SmallVector<Operation *, 4> loads = dstNode->loads;
+      SmallVector<Operation *, 4> dstLoadOpInsts;
+      DenseSet<Value *> visitedMemrefs;
+      while (!loads.empty()) {
+        // Get memref of load on top of the stack.
+        auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+        if (visitedMemrefs.count(memref) > 0)
+          continue;
+        visitedMemrefs.insert(memref);
+        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
+        // Skip if no input edges along which to fuse.
+        if (mdg->inEdges.count(dstId) == 0)
+          continue;
+        // Iterate through in-edges for 'dstId' and src node id for any
+        // edges on 'memref'.
+        SmallVector<unsigned, 2> srcNodeIds;
+        for (auto &srcEdge : mdg->inEdges[dstId]) {
+          // Skip 'srcEdge' if not for 'memref'.
+          if (srcEdge.value != memref)
+            continue;
+          srcNodeIds.push_back(srcEdge.id);
+        }
+        for (unsigned srcId : srcNodeIds) {
+          // Skip if this node was removed (fused into another node).
+          if (mdg->nodes.count(srcId) == 0)
+            continue;
+          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
+          auto *srcNode = mdg->getNode(srcId);
+          // Skip if 'srcNode' is not a loop nest.
+          if (!isa<AffineForOp>(srcNode->op))
+            continue;
+          // Skip if 'srcNode' has more than one store to any memref.
+          // TODO(andydavis) Support fusing multi-output src loop nests.
+          if (srcNode->stores.size() != 1)
+            continue;
+
+          // Skip if 'srcNode' writes to any live in or escaping memrefs,
+          // and cannot be fused.
+          bool writesToLiveInOrOut =
+              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+          if (writesToLiveInOrOut &&
+              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
+            continue;
+
+          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
+          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
+            continue;
+
+          // Compute an operation list insertion point for the fused loop
+          // nest which preserves dependences.
+          Operation *insertPointInst =
+              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
+          if (insertPointInst == nullptr)
+            continue;
+
+          // Get unique 'srcNode' store op.
+          auto *srcStoreOpInst = srcNode->stores.front();
+          // Gather 'dstNode' store ops to 'memref'.
+          SmallVector<Operation *, 2> dstStoreOpInsts;
+          for (auto *storeOpInst : dstNode->stores)
+            if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+              dstStoreOpInsts.push_back(storeOpInst);
+
+          unsigned bestDstLoopDepth;
+          mlir::ComputationSliceState sliceState;
+          // Check if fusion would be profitable.
+          if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
+                                  dstLoadOpInsts, dstStoreOpInsts, &sliceState,
+                                  &bestDstLoopDepth, maximalFusion))
+            continue;
+          // TODO(andydavis) Remove the following test code when canFuseLoops
+          // is fully functional.
+          mlir::ComputationSliceState sliceUnion;
+          if (!maximalFusion) {
+            FusionResult result = mlir::canFuseLoops(
+                cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+                bestDstLoopDepth, &sliceUnion);
+            assert(result.value == FusionResult::Success);
+            (void)result;
+          }
+          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
+          auto sliceLoopNest = mlir::insertBackwardComputationSlice(
+              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+          if (sliceLoopNest) {
+            LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
+                                    << *sliceLoopNest.getOperation() << "\n");
+            // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+            auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+            if (insertPointInst != dstAffineForOp.getOperation()) {
+              dstAffineForOp.getOperation()->moveBefore(insertPointInst);
+            }
+            // Update edges between 'srcNode' and 'dstNode'.
+            mdg->updateEdges(srcNode->id, dstNode->id, memref);
+
+            // Collect slice loop stats.
+            LoopNestStateCollector sliceCollector;
+            sliceCollector.collect(sliceLoopNest.getOperation());
+            // Promote single iteration slice loops to single IV value.
+            for (auto forOp : sliceCollector.forOps) {
+              promoteIfSingleIteration(forOp);
+            }
+            if (!writesToLiveInOrOut) {
+              // Create private memref for 'memref' in 'dstAffineForOp'.
+              SmallVector<Operation *, 4> storesForMemref;
+              for (auto *storeOpInst : sliceCollector.storeOpInsts) {
+                if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+                  storesForMemref.push_back(storeOpInst);
+              }
+              assert(storesForMemref.size() == 1);
+              auto *newMemRef = createPrivateMemRef(
+                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+                  fastMemorySpace, localBufSizeThreshold);
+              visitedMemrefs.insert(newMemRef);
+              // Create new node in dependence graph for 'newMemRef' alloc op.
+              unsigned newMemRefNodeId =
+                  mdg->addNode(newMemRef->getDefiningOp());
+              // Add edge from 'newMemRef' node to dstNode.
+              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+            }
+
+            // Collect dst loop stats after memref privatizaton transformation.
+            LoopNestStateCollector dstLoopCollector;
+            dstLoopCollector.collect(dstAffineForOp.getOperation());
+
+            // Add new load ops to current Node load op list 'loads' to
+            // continue fusing based on new operands.
+            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+              auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+              if (visitedMemrefs.count(loadMemRef) == 0)
+                loads.push_back(loadOpInst);
+            }
+
+            // Clear and add back loads and stores.
+            mdg->clearNodeLoadAndStores(dstNode->id);
+            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+                           dstLoopCollector.storeOpInsts);
+            // Remove old src loop nest if it no longer has outgoing dependence
+            // edges, and if it does not write to a memref which escapes the
+            // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
+            // been fused into 'dstNode' and write region of 'dstNode' covers
+            // the write region of 'srcNode', and 'srcNode' has no other users
+            // so it is safe to remove.
+            if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+              mdg->removeNode(srcNode->id);
+              srcNode->op->erase();
+            } else {
+              // Add remaining users of 'oldMemRef' back on the worklist (if not
+              // already there), as its replacement with a local/private memref
+              // has reduced dependences on 'oldMemRef' which may have created
+              // new fusion opportunities.
+              if (mdg->outEdges.count(srcNode->id) > 0) {
+                SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                    mdg->outEdges[srcNode->id];
+                for (auto &outEdge : oldOutEdges) {
+                  if (outEdge.value == memref &&
+                      worklistSet.count(outEdge.id) == 0) {
+                    worklist.push_back(outEdge.id);
+                    worklistSet.insert(outEdge.id);
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Visits each node in the graph, and for each node, attempts to fuse it with
+  // its sibling nodes (nodes which share a parent, but no dependence edges).
+  void fuseSiblingNodes() {
+    init();
+    while (!worklist.empty()) {
+      unsigned dstId = worklist.back();
+      worklist.pop_back();
+      worklistSet.erase(dstId);
+
+      // Skip if this node was removed (fused into another node).
+      if (mdg->nodes.count(dstId) == 0)
+        continue;
+      // Get 'dstNode' into which to attempt fusion.
+      auto *dstNode = mdg->getNode(dstId);
+      // Skip if 'dstNode' is not a loop nest.
+      if (!isa<AffineForOp>(dstNode->op))
+        continue;
+      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
+      fuseWithSiblingNodes(dstNode);
+    }
+  }
+
+  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
+  void fuseWithSiblingNodes(Node *dstNode) {
+    DenseSet<unsigned> visitedSibNodeIds;
+    std::pair<unsigned, Value *> idAndMemref;
+    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
+      unsigned sibId = idAndMemref.first;
+      Value *memref = idAndMemref.second;
+      // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
+      // stores to the same memref in 'sibNode' loop nest.
+      auto *sibNode = mdg->getNode(sibId);
+      // Compute an operation list insertion point for the fused loop
+      // nest which preserves dependences.
+      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
+      Operation *insertPointInst =
+          sibNode->op->isBeforeInBlock(dstNode->op)
+              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
+              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
+      if (insertPointInst == nullptr)
+        continue;
+
+      // Check if fusion would be profitable and at what depth.
+
+      // Get unique 'sibNode' load op to 'memref'.
+      SmallVector<Operation *, 2> sibLoadOpInsts;
+      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
+      // Currently findSiblingNodeToFuse searches for siblings with one load.
+      assert(sibLoadOpInsts.size() == 1);
+      Operation *sibLoadOpInst = sibLoadOpInsts[0];
+      assert(!sibNode->stores.empty());
+      // TODO(andydavis) Choose the store which postdominates all other stores.
+      auto *sibStoreOpInst = sibNode->stores.back();
+
+      // Gather 'dstNode' load ops to 'memref'.
+      SmallVector<Operation *, 2> dstLoadOpInsts;
+      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
+
+      // Gather 'dstNode' store ops to 'memref'.
+      SmallVector<Operation *, 2> dstStoreOpInsts;
+      dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
+
+      unsigned bestDstLoopDepth;
+      mlir::ComputationSliceState sliceState;
+
+      // Check if fusion would be profitable.
+      if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
+                              dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
+                              maximalFusion))
+        continue;
+
+      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
+      auto sliceLoopNest = mlir::insertBackwardComputationSlice(
+          sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+      if (sliceLoopNest != nullptr) {
+        auto dstForInst = cast<AffineForOp>(dstNode->op);
+        // Update operation position of fused loop nest (if needed).
+        if (insertPointInst != dstForInst.getOperation()) {
+          dstForInst.getOperation()->moveBefore(insertPointInst);
+        }
+        // Update data dependence graph state post fusion.
+        updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
+      }
+    }
+  }
+
+  // Searches function argument uses and the graph from 'dstNode' looking for a
+  // fusion candidate sibling node which shares no dependences with 'dstNode'
+  // but which loads from the same memref. Returns true and sets
+  // 'idAndMemrefToFuse' on success. Returns false otherwise.
+  bool findSiblingNodeToFuse(Node *dstNode,
+                             DenseSet<unsigned> *visitedSibNodeIds,
+                             std::pair<unsigned, Value *> *idAndMemrefToFuse) {
+    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
+    // on 'memref'.
+    auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
+      // Skip if 'outEdge' is not a read-after-write dependence.
+      // TODO(andydavis) Remove restrict to single load op restriction.
+      if (sibNode->getLoadOpCount(memref) != 1)
+        return false;
+      // Skip if there exists a path of dependent edges between
+      // 'sibNode' and 'dstNode'.
+      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
+          mdg->hasDependencePath(dstNode->id, sibNode->id))
+        return false;
+      // Skip sib node if it loads to (and stores from) the same memref on
+      // which it also has an input dependence edge.
+      DenseSet<Value *> loadAndStoreMemrefSet;
+      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
+      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
+            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
+          }))
+        return false;
+
+      // Check that all stores are to the same memref.
+      DenseSet<Value *> storeMemrefs;
+      for (auto *storeOpInst : sibNode->stores) {
+        storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+      }
+      if (storeMemrefs.size() != 1)
+        return false;
+      return true;
+    };
+
+    // Search for siblings which load the same memref function argument.
+    auto fn = dstNode->op->getParentOfType<FuncOp>();
+    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+      for (auto *user : fn.getArgument(i)->getUsers()) {
+        if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+          // Gather loops surrounding 'use'.
+          SmallVector<AffineForOp, 4> loops;
+          getLoopIVs(*user, &loops);
+          // Skip 'use' if it is not within a loop nest.
+          if (loops.empty())
+            continue;
+          Node *sibNode = mdg->getForOpNode(loops[0]);
+          assert(sibNode != nullptr);
+          // Skip 'use' if it not a sibling to 'dstNode'.
+          if (sibNode->id == dstNode->id)
+            continue;
+          // Skip 'use' if it has been visited.
+          if (visitedSibNodeIds->count(sibNode->id) > 0)
+            continue;
+          // Skip 'use' if it does not load from the same memref as 'dstNode'.
+          auto *memref = loadOp.getMemRef();
+          if (dstNode->getLoadOpCount(memref) == 0)
+            continue;
+          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
+          if (canFuseWithSibNode(sibNode, memref)) {
+            visitedSibNodeIds->insert(sibNode->id);
+            idAndMemrefToFuse->first = sibNode->id;
+            idAndMemrefToFuse->second = memref;
+            return true;
+          }
+        }
+      }
+    }
+
+    // Search for siblings by following edges through an intermediate src node.
+    // Collect candidate 'dstNode' input edges in 'inEdges'.
+    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
+    mdg->forEachMemRefInputEdge(
+        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
+          // Add 'inEdge' if it is a read-after-write dependence.
+          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
+              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
+            inEdges.push_back(inEdge);
+        });
+
+    // Search for sibling nodes to fuse by visiting output edges from each input
+    // edge in 'inEdges'.
+    for (auto &inEdge : inEdges) {
+      // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
+      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
+      mdg->forEachMemRefOutputEdge(
+          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
+            unsigned sibNodeId = outEdge.id;
+            if (visitedSibNodeIds->count(sibNodeId) > 0)
+              return;
+            // Skip output edge if not a sibling using the same memref.
+            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
+              return;
+            auto *sibNode = mdg->getNode(sibNodeId);
+            if (!isa<AffineForOp>(sibNode->op))
+              return;
+            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
+            if (canFuseWithSibNode(sibNode, outEdge.value)) {
+              // Add candidate 'outEdge' to sibling node.
+              outEdges.push_back(outEdge);
+            }
+          });
+
+      // Add first candidate if any were returned.
+      if (!outEdges.empty()) {
+        visitedSibNodeIds->insert(outEdges[0].id);
+        idAndMemrefToFuse->first = outEdges[0].id;
+        idAndMemrefToFuse->second = outEdges[0].value;
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
+                                     Node *dstNode) {
+    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
+    mdg->updateEdges(sibNode->id, dstNode->id);
+
+    // Collect slice loop stats.
+    LoopNestStateCollector sliceCollector;
+    sliceCollector.collect(sliceLoopNest.getOperation());
+    // Promote single iteration slice loops to single IV value.
+    for (auto forOp : sliceCollector.forOps) {
+      promoteIfSingleIteration(forOp);
+    }
+
+    // Collect dst loop stats after memref privatizaton transformation.
+    auto dstForInst = cast<AffineForOp>(dstNode->op);
+    LoopNestStateCollector dstLoopCollector;
+    dstLoopCollector.collect(dstForInst.getOperation());
+    // Clear and add back loads and stores
+    mdg->clearNodeLoadAndStores(dstNode->id);
+    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
+                   dstLoopCollector.storeOpInsts);
+    // Remove old sibling loop nest if it no longer has outgoing dependence
+    // edges, and it does not write to a memref which escapes the
+    // function.
+    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
+      mdg->removeNode(sibNode->id);
+      sibNode->op->erase();
+    }
+  }
+
+  // Clean up any allocs with no users.
+  void eraseUnusedMemRefAllocations() {
+    for (auto &pair : mdg->memrefEdgeCount) {
+      if (pair.second > 0)
+        continue;
+      auto *memref = pair.first;
+      // Skip if there exist other uses (return operation or function calls).
+      if (!memref->use_empty())
+        continue;
+      // Use list expected to match the dep graph info.
+      auto *op = memref->getDefiningOp();
+      if (isa_and_nonnull<AllocOp>(op))
+        op->erase();
+    }
+  }
+};
+
+} // end anonymous namespace
+
+void LoopFusion::runOnFunction() {
+  // Override if a command line argument was provided.
+  if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
+    fastMemorySpace = clFusionFastMemorySpace.getValue();
+  }
+
+  // Override if a command line argument was provided.
+  if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
+    localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
+  }
+
+  if (clMaximalLoopFusion.getNumOccurrences() > 0)
+    maximalFusion = clMaximalLoopFusion;
+
+  MemRefDependenceGraph g;
+  if (g.init(getFunction()))
+    GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion)
+        .run();
+}
+
+static PassRegistration<LoopFusion> pass("affine-loop-fusion",
+                                         "Fuse loop nests");
diff --git a/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
new file mode 100644
index 0000000..d8b5b2d
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -0,0 +1,251 @@
+//===- LoopInvariantCodeMotion.cpp - Code to perform loop fusion-----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop invariant code motion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "licm"
+
+using namespace mlir;
+
+namespace {
+
+/// Loop invariant code motion (LICM) pass.
+/// TODO(asabne) : The pass is missing zero-trip tests.
+/// TODO(asabne) : Check for the presence of side effects before hoisting.
+struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> {
+  void runOnFunction() override;
+  void runOnAffineForOp(AffineForOp forOp);
+};
+} // end anonymous namespace
+
+static bool
+checkInvarianceOfNestedIfOps(Operation *op, Value *indVar,
+                             SmallPtrSetImpl<Operation *> &definedOps,
+                             SmallPtrSetImpl<Operation *> &opsToHoist);
+static bool isOpLoopInvariant(Operation &op, Value *indVar,
+                              SmallPtrSetImpl<Operation *> &definedOps,
+                              SmallPtrSetImpl<Operation *> &opsToHoist);
+
+static bool
+areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar,
+                                 SmallPtrSetImpl<Operation *> &definedOps,
+                                 SmallPtrSetImpl<Operation *> &opsToHoist);
+
+static bool isMemRefDereferencingOp(Operation &op) {
+  // TODO(asabne): Support DMA Ops.
+  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
+    return true;
+  }
+  return false;
+}
+
+FunctionPassBase *mlir::createLoopInvariantCodeMotionPass() {
+  return new LoopInvariantCodeMotion();
+}
+
+// Returns true if the individual op is loop invariant.
+bool isOpLoopInvariant(Operation &op, Value *indVar,
+                       SmallPtrSetImpl<Operation *> &definedOps,
+                       SmallPtrSetImpl<Operation *> &opsToHoist) {
+  LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
+
+  if (isa<AffineIfOp>(op)) {
+    if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) {
+      return false;
+    }
+  } else if (isa<AffineForOp>(op)) {
+    // If the body of a predicated region has a for loop, we don't hoist the
+    // 'affine.if'.
+    return false;
+  } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+    // TODO(asabne): Support DMA ops.
+    return false;
+  } else if (!isa<ConstantOp>(op)) {
+    if (isMemRefDereferencingOp(op)) {
+      Value *memref = isa<AffineLoadOp>(op)
+                          ? cast<AffineLoadOp>(op).getMemRef()
+                          : cast<AffineStoreOp>(op).getMemRef();
+      for (auto *user : memref->getUsers()) {
+        // If this memref has a user that is a DMA, give up because these
+        // operations write to this memref.
+        if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+          return false;
+        }
+        // If the memref used by the load/store is used in a store elsewhere in
+        // the loop nest, we do not hoist. Similarly, if the memref used in a
+        // load is also being stored too, we do not hoist the load.
+        if (isa<AffineStoreOp>(user) ||
+            (isa<AffineLoadOp>(user) && isa<AffineStoreOp>(op))) {
+          if (&op != user) {
+            SmallVector<AffineForOp, 8> userIVs;
+            getLoopIVs(*user, &userIVs);
+            // Check that userIVs don't contain the for loop around the op.
+            if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) {
+              return false;
+            }
+          }
+        }
+      }
+    }
+
+    // Insert this op in the defined ops list.
+    definedOps.insert(&op);
+
+    if (op.getNumOperands() == 0 && !isa<AffineTerminatorOp>(op)) {
+      LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n");
+      return false;
+    }
+    for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
+      auto *operandSrc = op.getOperand(i)->getDefiningOp();
+
+      LLVM_DEBUG(
+          op.getOperand(i)->print(llvm::dbgs() << "\nIterating on operand\n"));
+
+      // If the loop IV is the operand, this op isn't loop invariant.
+      if (indVar == op.getOperand(i)) {
+        LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n");
+        return false;
+      }
+
+      if (operandSrc != nullptr) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << *operandSrc << "\nIterating on operand src\n");
+
+        // If the value was defined in the loop (outside of the
+        // if/else region), and that operation itself wasn't meant to
+        // be hoisted, then mark this operation loop dependent.
+        if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) {
+          return false;
+        }
+      }
+    }
+  }
+
+  // If no operand was loop variant, mark this op for motion.
+  opsToHoist.insert(&op);
+  return true;
+}
+
+// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
+bool areAllOpsInTheBlockListInvariant(
+    Region &blockList, Value *indVar, SmallPtrSetImpl<Operation *> &definedOps,
+    SmallPtrSetImpl<Operation *> &opsToHoist) {
+
+  for (auto &b : blockList) {
+    for (auto &op : b) {
+      if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Returns true if the affine.if op can be hoisted.
+bool checkInvarianceOfNestedIfOps(Operation *op, Value *indVar,
+                                  SmallPtrSetImpl<Operation *> &definedOps,
+                                  SmallPtrSetImpl<Operation *> &opsToHoist) {
+  assert(isa<AffineIfOp>(op));
+  auto ifOp = cast<AffineIfOp>(op);
+
+  if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, definedOps,
+                                        opsToHoist)) {
+    return false;
+  }
+
+  if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, definedOps,
+                                        opsToHoist)) {
+    return false;
+  }
+
+  return true;
+}
+
+void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
+  auto *loopBody = forOp.getBody();
+  auto *indVar = forOp.getInductionVar();
+
+  SmallPtrSet<Operation *, 8> definedOps;
+  // This is the place where hoisted instructions would reside.
+  OpBuilder b(forOp.getOperation());
+
+  SmallPtrSet<Operation *, 8> opsToHoist;
+  SmallVector<Operation *, 8> opsToMove;
+
+  for (auto &op : *loopBody) {
+    // We don't hoist for loops.
+    if (!isa<AffineForOp>(op)) {
+      if (!isa<AffineTerminatorOp>(op)) {
+        if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) {
+          opsToMove.push_back(&op);
+        }
+      }
+    }
+  }
+
+  // For all instructions that we found to be invariant, place sequentially
+  // right before the for loop.
+  for (auto *op : opsToMove) {
+    op->moveBefore(forOp);
+  }
+
+  LLVM_DEBUG(forOp.getOperation()->print(llvm::dbgs() << "Modified loop\n"));
+
+  // If the for loop body has a single operation (the terminator), erase it.
+  if (forOp.getBody()->getOperations().size() == 1) {
+    assert(isa<AffineTerminatorOp>(forOp.getBody()->front()));
+    forOp.erase();
+  }
+}
+
+void LoopInvariantCodeMotion::runOnFunction() {
+  // Walk through all loops in a function in innermost-loop-first order.  This
+  // way, we first LICM from the inner loop, and place the ops in
+  // the outer loop, which in turn can be further LICM'ed.
+  getFunction().walk<AffineForOp>([&](AffineForOp op) {
+    LLVM_DEBUG(op.getOperation()->print(llvm::dbgs() << "\nOriginal loop\n"));
+    runOnAffineForOp(op);
+  });
+}
+
+static PassRegistration<LoopInvariantCodeMotion>
+    pass("affine-loop-invariant-code-motion",
+         "Hoist loop invariant instructions outside of the loop");
diff --git a/third_party/mlir/lib/Transforms/LoopTiling.cpp b/third_party/mlir/lib/Transforms/LoopTiling.cpp
new file mode 100644
index 0000000..0a331ca
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopTiling.cpp
@@ -0,0 +1,410 @@
+//===- LoopTiling.cpp --- Loop tiling pass ------------------------------*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to tile loop nests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-tile"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<unsigned long long>
+    clCacheSizeKiB("tile-cache-size",
+                   llvm::cl::desc("Set size of cache to tile for in KiB"),
+                   llvm::cl::cat(clOptionsCategory));
+
+// Tile size to use for all loops (overrides -tile-sizes if provided).
+static llvm::cl::opt<unsigned>
+    clTileSize("tile-size", llvm::cl::desc("Use this tile size for all loops"),
+               llvm::cl::cat(clOptionsCategory));
+
+// List of tile sizes. If any of them aren't provided, they are filled with
+// clTileSize / kDefaultTileSize.
+static llvm::cl::list<unsigned> clTileSizes(
+    "tile-sizes",
+    llvm::cl::desc(
+        "List of tile sizes for each perfect nest (overridden by -tile-size)"),
+    llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// A pass to perform loop tiling on all suitable loop nests of a Function.
+struct LoopTiling : public FunctionPass<LoopTiling> {
+  explicit LoopTiling(uint64_t cacheSizeBytes = kDefaultCacheMemCapacity,
+                      bool avoidMaxMinBounds = true)
+      : cacheSizeBytes(cacheSizeBytes), avoidMaxMinBounds(avoidMaxMinBounds) {}
+
+  void runOnFunction() override;
+  void getTileSizes(ArrayRef<AffineForOp> band,
+                    SmallVectorImpl<unsigned> *tileSizes);
+
+  // Default tile size if nothing is provided.
+  constexpr static unsigned kDefaultTileSize = 4;
+  constexpr static uint64_t kDefaultCacheMemCapacity = 512 * 1024UL;
+
+  // Capacity of the cache to tile for.
+  uint64_t cacheSizeBytes;
+  // If true, tile sizes are set to avoid max/min in bounds if possible.
+  bool avoidMaxMinBounds;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to perform loop tiling on all suitable loop nests of a
+/// Function.
+FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) {
+  return new LoopTiling(cacheSizeBytes);
+}
+
+// Move the loop body of AffineForOp 'src' from 'src' into the specified
+// location in destination's body, ignoring the terminator.
+static inline void moveLoopBody(AffineForOp src, AffineForOp dest,
+                                Block::iterator loc) {
+  auto &insts = src.getBody()->getOperations();
+  dest.getBody()->getOperations().splice(loc, insts, insts.begin(),
+                                         std::prev(insts.end()));
+}
+
+// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's
+// body.
+static inline void moveLoopBody(AffineForOp src, AffineForOp dest) {
+  moveLoopBody(src, dest, dest.getBody()->begin());
+}
+
+/// Constructs and sets new loop bounds after tiling for the case of
+/// hyper-rectangular index sets, where the bounds of one dimension do not
+/// depend on other dimensions. Bounds of each dimension can thus be treated
+/// independently, and deriving the new bounds is much simpler and faster
+/// than for the case of tiling arbitrary polyhedral shapes.
+static void
+constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
+                                MutableArrayRef<AffineForOp> newLoops,
+                                ArrayRef<unsigned> tileSizes) {
+  assert(!origLoops.empty());
+  assert(origLoops.size() == tileSizes.size());
+
+  OpBuilder b(origLoops[0].getOperation());
+  unsigned width = origLoops.size();
+
+  // Bounds for tile space loops.
+  for (unsigned i = 0; i < width; i++) {
+    auto lbOperands = origLoops[i].getLowerBoundOperands();
+    auto ubOperands = origLoops[i].getUpperBoundOperands();
+    SmallVector<Value *, 4> newLbOperands(lbOperands);
+    SmallVector<Value *, 4> newUbOperands(ubOperands);
+    newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
+    newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
+    newLoops[i].setStep(tileSizes[i]);
+  }
+  // Bounds for intra-tile loops.
+  for (unsigned i = 0; i < width; i++) {
+    int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
+    auto mayBeConstantCount = getConstantTripCount(origLoops[i]);
+    // The lower bound is just the tile-space loop.
+    AffineMap lbMap = b.getDimIdentityMap();
+    newLoops[width + i].setLowerBound(
+        /*operands=*/newLoops[i].getInductionVar(), lbMap);
+
+    // Set the upper bound.
+    if (mayBeConstantCount.hasValue() &&
+        mayBeConstantCount.getValue() < tileSizes[i]) {
+      // Trip count is less than tile size; upper bound is the trip count.
+      auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
+      newLoops[width + i].setUpperBoundMap(ubMap);
+    } else if (largestDiv % tileSizes[i] != 0) {
+      // Intra-tile loop ii goes from i to min(i + tileSize, ub_i).
+      // Construct the upper bound map; the operands are the original operands
+      // with 'i' (tile-space loop) appended to it. The new upper bound map is
+      // the original one with an additional expression i + tileSize appended.
+      auto ub = origLoops[i].getUpperBound();
+      SmallVector<Value *, 4> ubOperands;
+      ubOperands.reserve(ub.getNumOperands() + 1);
+      auto origUbMap = ub.getMap();
+      // Add dim operands from original upper bound.
+      for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) {
+        ubOperands.push_back(ub.getOperand(j));
+      }
+      // Add dim operand for new loop upper bound.
+      ubOperands.push_back(newLoops[i].getInductionVar());
+      // Add symbol operands from original upper bound.
+      for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) {
+        ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+      }
+      SmallVector<AffineExpr, 4> boundExprs;
+      boundExprs.reserve(1 + origUbMap.getNumResults());
+      auto dim = b.getAffineDimExpr(origUbMap.getNumDims());
+      // The new upper bound map is the original one with an additional
+      // expression i + tileSize appended.
+      boundExprs.push_back(dim + tileSizes[i]);
+      boundExprs.append(origUbMap.getResults().begin(),
+                        origUbMap.getResults().end());
+      auto ubMap = b.getAffineMap(origUbMap.getNumDims() + 1,
+                                  origUbMap.getNumSymbols(), boundExprs);
+      newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
+    } else {
+      // No need of the min expression.
+      auto dim = b.getAffineDimExpr(0);
+      auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i]);
+      newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
+    }
+  }
+}
+
+/// Tiles the specified band of perfectly nested loops creating tile-space loops
+/// and intra-tile loops. A band is a contiguous set of loops.
+//  TODO(bondhugula): handle non hyper-rectangular spaces.
+LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
+                                ArrayRef<unsigned> tileSizes) {
+  assert(!band.empty());
+  assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes");
+
+  // Check if the supplied for op's are all successively nested.
+  for (unsigned i = 1, e = band.size(); i < e; i++) {
+    assert(band[i].getOperation()->getParentOp() == band[i - 1].getOperation());
+  }
+
+  auto origLoops = band;
+
+  AffineForOp rootAffineForOp = origLoops[0];
+  auto loc = rootAffineForOp.getLoc();
+  // Note that width is at least one since band isn't empty.
+  unsigned width = band.size();
+
+  SmallVector<AffineForOp, 12> newLoops(2 * width);
+  AffineForOp innermostPointLoop;
+
+  // The outermost among the loops as we add more..
+  auto *topLoop = rootAffineForOp.getOperation();
+
+  // Add intra-tile (or point) loops.
+  for (unsigned i = 0; i < width; i++) {
+    OpBuilder b(topLoop);
+    // Loop bounds will be set later.
+    auto pointLoop = b.create<AffineForOp>(loc, 0, 0);
+    pointLoop.getBody()->getOperations().splice(
+        pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+        topLoop);
+    newLoops[2 * width - 1 - i] = pointLoop;
+    topLoop = pointLoop.getOperation();
+    if (i == 0)
+      innermostPointLoop = pointLoop;
+  }
+
+  // Add tile space loops;
+  for (unsigned i = width; i < 2 * width; i++) {
+    OpBuilder b(topLoop);
+    // Loop bounds will be set later.
+    auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+    tileSpaceLoop.getBody()->getOperations().splice(
+        tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+        topLoop);
+    newLoops[2 * width - i - 1] = tileSpaceLoop;
+    topLoop = tileSpaceLoop.getOperation();
+  }
+
+  // Move the loop body of the original nest to the new one.
+  moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
+
+  SmallVector<Value *, 8> origLoopIVs;
+  extractForInductionVars(band, &origLoopIVs);
+  SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
+  FlatAffineConstraints cst;
+  getIndexSet(band, &cst);
+
+  if (!cst.isHyperRectangular(0, width)) {
+    rootAffineForOp.emitError("tiled code generation unimplemented for the "
+                              "non-hyperrectangular case");
+    return failure();
+  }
+
+  constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
+  // In this case, the point loop IVs just replace the original ones.
+  for (unsigned i = 0; i < width; i++) {
+    origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width].getInductionVar());
+  }
+
+  // Erase the old loop nest.
+  rootAffineForOp.erase();
+
+  return success();
+}
+
+// Identify valid and profitable bands of loops to tile. This is currently just
+// a temporary placeholder to test the mechanics of tiled code generation.
+// Returns all maximal outermost perfect loop nests to tile.
+static void getTileableBands(FuncOp f,
+                             std::vector<SmallVector<AffineForOp, 6>> *bands) {
+  // Get maximal perfect nest of 'affine.for' insts starting from root
+  // (inclusive).
+  auto getMaximalPerfectLoopNest = [&](AffineForOp root) {
+    SmallVector<AffineForOp, 6> band;
+    getPerfectlyNestedLoops(band, root);
+    bands->push_back(band);
+  };
+
+  for (auto &block : f)
+    for (auto &op : block)
+      if (auto forOp = dyn_cast<AffineForOp>(op))
+        getMaximalPerfectLoopNest(forOp);
+}
+
+// Reduce each tile size to the largest divisor of the corresponding trip count
+// (if the trip count is known).
+static void adjustToDivisorsOfTripCounts(ArrayRef<AffineForOp> band,
+                                         SmallVectorImpl<unsigned> *tileSizes) {
+  assert(band.size() == tileSizes->size() && "invalid tile size count");
+  for (unsigned i = 0, e = band.size(); i < e; i++) {
+    unsigned &tSizeAdjusted = (*tileSizes)[i];
+    auto mayConst = getConstantTripCount(band[i]);
+    if (!mayConst.hasValue())
+      continue;
+    // Adjust the tile size to largest factor of the trip count less than
+    // tSize.
+    uint64_t constTripCount = mayConst.getValue();
+    if (constTripCount > 1 && tSizeAdjusted > constTripCount / 2)
+      tSizeAdjusted = constTripCount / 2;
+    while (constTripCount % tSizeAdjusted != 0)
+      tSizeAdjusted--;
+  }
+}
+
+// Returns tile sizes to use. Checks CL options; if none are specified, sets it
+// based on a simple model that looks at the memory footprint and determines
+// tile sizes assuming identity accesses / 1:1 tile size proportional footprint
+// along each of the dimensions being tiled.
+// TODO(mlir-team): evolve this model. Tile size determination is a large area
+// to play with in general.
+void LoopTiling::getTileSizes(ArrayRef<AffineForOp> band,
+                              SmallVectorImpl<unsigned> *tileSizes) {
+  if (band.empty())
+    return;
+
+  tileSizes->resize(band.size());
+
+  // Use clTileSize for all loops if specified.
+  if (clTileSize.getNumOccurrences() > 0) {
+    std::fill(tileSizes->begin(), tileSizes->end(), clTileSize);
+    return;
+  }
+
+  // Use clTileSizes and fill them with default tile size if it's short.
+  if (!clTileSizes.empty()) {
+    std::fill(tileSizes->begin(), tileSizes->end(),
+              LoopTiling::kDefaultTileSize);
+    std::copy(clTileSizes.begin(),
+              clTileSizes.begin() + std::min(clTileSizes.size(), band.size()),
+              tileSizes->begin());
+    return;
+  }
+
+  // The first loop in the band.
+  auto rootForOp = band[0];
+  (void)rootForOp;
+
+  // Obtain memory footprint and set tile sizes so that a tile fits in
+  // the cache size. This is an approximation with the assumption that the
+  // footprint increases with the tile size linearly in that dimension (i.e.,
+  // assumes one-to-one access function).
+  auto fp = getMemoryFootprintBytes(band[0], 0);
+  if (!fp.hasValue()) {
+    // Fill with default tile sizes if footprint is unknown.
+    std::fill(tileSizes->begin(), tileSizes->end(),
+              LoopTiling::kDefaultTileSize);
+    if (avoidMaxMinBounds)
+      adjustToDivisorsOfTripCounts(band, tileSizes);
+    LLVM_DEBUG(
+        rootForOp.emitWarning("memory footprint unknown: using default tile "
+                              "sizes adjusted to trip count divisors"));
+    return;
+  }
+
+  // Check how many times larger the cache size is when compared to footprint.
+  uint64_t excessFactor = llvm::divideCeil(fp.getValue(), cacheSizeBytes);
+  if (excessFactor <= 1) {
+    // No need of any tiling - set tile size to 1.
+    std::fill(tileSizes->begin(), tileSizes->end(), 1);
+    return;
+  }
+
+  // Divide all loops equally in an attempt to reduce footprint.
+  // TODO(bondhugula): this is approximate. Ideally, obtain reuse factor /
+  // profitability along each dimension and weight tile sizes based on that as
+  // one possible approach. Or compute a polynomial in tile sizes and solve for
+  // it.
+
+  // For an n-d tilable band, compute n^th root of the excess.
+  unsigned tSize =
+      static_cast<unsigned>(floorl(std::pow(excessFactor, 1.0 / band.size())));
+  // We'll keep a running product to determine the last tile size better.
+  unsigned cumulProductOfTileSizes = 1;
+  for (unsigned i = 0, e = band.size(); i < e; i++) {
+    if (i < e - 1)
+      (*tileSizes)[i] = tSize;
+    else
+      // Set last tile size to cover the balance.
+      (*tileSizes)[i] = std::max(
+          1U, static_cast<unsigned>(excessFactor / cumulProductOfTileSizes));
+    cumulProductOfTileSizes *= (*tileSizes)[i];
+  }
+  if (avoidMaxMinBounds)
+    adjustToDivisorsOfTripCounts(band, tileSizes);
+}
+
+void LoopTiling::runOnFunction() {
+  // Override cache size if provided on command line.
+  if (clCacheSizeKiB.getNumOccurrences() > 0)
+    cacheSizeBytes = clCacheSizeKiB * 1024;
+
+  // Bands of loops to tile.
+  std::vector<SmallVector<AffineForOp, 6>> bands;
+  getTileableBands(getFunction(), &bands);
+
+  for (auto &band : bands) {
+    // Set up tile sizes; fill missing tile sizes at the end with default tile
+    // size or clTileSize if one was provided.
+    SmallVector<unsigned, 6> tileSizes;
+    getTileSizes(band, &tileSizes);
+    if (llvm::DebugFlag) {
+      auto diag = band[0].emitRemark("using tile sizes [");
+      for (auto tSize : tileSizes)
+        diag << tSize << " ";
+      diag << "]\n";
+    }
+    if (failed(tileCodeGen(band, tileSizes)))
+      return signalPassFailure();
+  }
+}
+
+constexpr unsigned LoopTiling::kDefaultTileSize;
+constexpr uint64_t LoopTiling::kDefaultCacheMemCapacity;
+
+static PassRegistration<LoopTiling> pass("affine-loop-tile", "Tile loop nests");
diff --git a/third_party/mlir/lib/Transforms/LoopUnroll.cpp b/third_party/mlir/lib/Transforms/LoopUnroll.cpp
new file mode 100644
index 0000000..1c7f339
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopUnroll.cpp
@@ -0,0 +1,191 @@
+//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop unrolling.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-unroll"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+// Loop unrolling factor.
+static llvm::cl::opt<unsigned> clUnrollFactor(
+    "unroll-factor",
+    llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool> clUnrollFull("unroll-full",
+                                        llvm::cl::desc("Fully unroll loops"),
+                                        llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
+    "unroll-num-reps",
+    llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clUnrollFullThreshold(
+    "unroll-full-threshold", llvm::cl::Hidden,
+    llvm::cl::desc(
+        "Unroll all loops with trip count less than or equal to this"),
+    llvm::cl::cat(clOptionsCategory));
+
+namespace {
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold was specified, in which case, fully unrolls all loops
+/// with trip count less than the specified threshold. The latter is for testing
+/// purposes, especially for testing outer loop unrolling.
+struct LoopUnroll : public FunctionPass<LoopUnroll> {
+  const Optional<unsigned> unrollFactor;
+  const Optional<bool> unrollFull;
+  // Callback to obtain unroll factors; if this has a callable target, takes
+  // precedence over command-line argument or passed argument.
+  const std::function<unsigned(AffineForOp)> getUnrollFactor;
+
+  explicit LoopUnroll(
+      Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
+      const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
+      : unrollFactor(unrollFactor), unrollFull(unrollFull),
+        getUnrollFactor(getUnrollFactor) {}
+
+  void runOnFunction() override;
+
+  /// Unroll this for op. Returns failure if nothing was done.
+  LogicalResult runOnAffineForOp(AffineForOp forOp);
+
+  static const unsigned kDefaultUnrollFactor = 4;
+};
+} // end anonymous namespace
+
+void LoopUnroll::runOnFunction() {
+  // Gathers all innermost loops through a post order pruned walk.
+  struct InnermostLoopGatherer {
+    // Store innermost loops as we walk.
+    std::vector<AffineForOp> loops;
+
+    void walkPostOrder(FuncOp f) {
+      for (auto &b : f)
+        walkPostOrder(b.begin(), b.end());
+    }
+
+    bool walkPostOrder(Block::iterator Start, Block::iterator End) {
+      bool hasInnerLoops = false;
+      // We need to walk all elements since all innermost loops need to be
+      // gathered as opposed to determining whether this list has any inner
+      // loops or not.
+      while (Start != End)
+        hasInnerLoops |= walkPostOrder(&(*Start++));
+      return hasInnerLoops;
+    }
+    bool walkPostOrder(Operation *opInst) {
+      bool hasInnerLoops = false;
+      for (auto &region : opInst->getRegions())
+        for (auto &block : region)
+          hasInnerLoops |= walkPostOrder(block.begin(), block.end());
+      if (isa<AffineForOp>(opInst)) {
+        if (!hasInnerLoops)
+          loops.push_back(cast<AffineForOp>(opInst));
+        return true;
+      }
+      return hasInnerLoops;
+    }
+  };
+
+  if (clUnrollFull.getNumOccurrences() > 0 &&
+      clUnrollFullThreshold.getNumOccurrences() > 0) {
+    // Store short loops as we walk.
+    std::vector<AffineForOp> loops;
+
+    // Gathers all loops with trip count <= minTripCount. Do a post order walk
+    // so that loops are gathered from innermost to outermost (or else unrolling
+    // an outer one may delete gathered inner ones).
+    getFunction().walk<AffineForOp>([&](AffineForOp forOp) {
+      Optional<uint64_t> tripCount = getConstantTripCount(forOp);
+      if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
+        loops.push_back(forOp);
+    });
+    for (auto forOp : loops)
+      loopUnrollFull(forOp);
+    return;
+  }
+
+  unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
+                                ? clUnrollNumRepetitions
+                                : 1;
+  // If the call back is provided, we will recurse until no loops are found.
+  FuncOp func = getFunction();
+  for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
+    InnermostLoopGatherer ilg;
+    ilg.walkPostOrder(func);
+    auto &loops = ilg.loops;
+    if (loops.empty())
+      break;
+    bool unrolled = false;
+    for (auto forOp : loops)
+      unrolled |= succeeded(runOnAffineForOp(forOp));
+    if (!unrolled)
+      // Break out if nothing was unrolled.
+      break;
+  }
+}
+
+/// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
+/// failure otherwise. The default unroll factor is 4.
+LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
+  // Use the function callback if one was provided.
+  if (getUnrollFactor) {
+    return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
+  }
+  // Unroll by the factor passed, if any.
+  if (unrollFactor.hasValue())
+    return loopUnrollByFactor(forOp, unrollFactor.getValue());
+  // Unroll by the command line factor if one was specified.
+  if (clUnrollFactor.getNumOccurrences() > 0)
+    return loopUnrollByFactor(forOp, clUnrollFactor);
+  // Unroll completely if full loop unroll was specified.
+  if (clUnrollFull.getNumOccurrences() > 0 ||
+      (unrollFull.hasValue() && unrollFull.getValue()))
+    return loopUnrollFull(forOp);
+
+  // Unroll by four otherwise.
+  return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
+}
+
+FunctionPassBase *mlir::createLoopUnrollPass(
+    int unrollFactor, int unrollFull,
+    const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
+  return new LoopUnroll(
+      unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
+      unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
+}
+
+static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops");
diff --git a/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp
new file mode 100644
index 0000000..7650db1
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -0,0 +1,243 @@
+//===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop unroll and jam. Unroll and jam is a transformation
+// that improves locality, in particular, register reuse, while also improving
+// operation level parallelism. The example below shows what it does in nearly
+// the general case. Loop unroll and jam currently works if the bounds of the
+// loops inner to the loop being unroll-jammed do not depend on the latter.
+//
+// Before      After unroll and jam of i by factor 2:
+//
+//             for i, step = 2
+// for i         S1(i);
+//   S1;         S2(i);
+//   S2;         S1(i+1);
+//   for j       S2(i+1);
+//     S3;       for j
+//     S4;         S3(i, j);
+//   S5;           S4(i, j);
+//   S6;           S3(i+1, j)
+//                 S4(i+1, j)
+//               S5(i);
+//               S6(i);
+//               S5(i+1);
+//               S6(i+1);
+//
+// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
+// op's, bodies of those loops will not be jammed.
+//===----------------------------------------------------------------------===//
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-unroll-jam"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+// Loop unroll and jam factor.
+static llvm::cl::opt<unsigned>
+    clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden,
+                      llvm::cl::desc("Use this unroll jam factor for all loops"
+                                     " (default 4)"),
+                      llvm::cl::cat(clOptionsCategory));
+
+namespace {
+/// Loop unroll jam pass. Currently, this just unroll jams the first
+/// outer loop in a Function.
+struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> {
+  Optional<unsigned> unrollJamFactor;
+  static const unsigned kDefaultUnrollJamFactor = 4;
+
+  explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
+      : unrollJamFactor(unrollJamFactor) {}
+
+  void runOnFunction() override;
+  LogicalResult runOnAffineForOp(AffineForOp forOp);
+};
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
+  return new LoopUnrollAndJam(
+      unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
+}
+
+void LoopUnrollAndJam::runOnFunction() {
+  // Currently, just the outermost loop from the first loop nest is
+  // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
+  // any for operation.
+  auto &entryBlock = getFunction().front();
+  if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front()))
+    runOnAffineForOp(forOp);
+}
+
+/// Unroll and jam a 'affine.for' op. Default unroll jam factor is
+/// kDefaultUnrollJamFactor. Return failure if nothing was done.
+LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) {
+  // Unroll and jam by the factor that was passed if any.
+  if (unrollJamFactor.hasValue())
+    return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
+  // Otherwise, unroll jam by the command-line factor if one was specified.
+  if (clUnrollJamFactor.getNumOccurrences() > 0)
+    return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
+
+  // Unroll and jam by four otherwise.
+  return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
+}
+
+LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
+                                            uint64_t unrollJamFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollJamFactor)
+    return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
+  return loopUnrollJamByFactor(forOp, unrollJamFactor);
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  // Gathers all maximal sub-blocks of operations that do not themselves
+  // include a for op (a operation could have a descendant for op though
+  // in its tree).  Ignore the block terminators.
+  struct JamBlockGatherer {
+    // Store iterators to the first and last op of each sub-block found.
+    std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+    // This is a linear time walk.
+    void walk(Operation *op) {
+      for (auto &region : op->getRegions())
+        for (auto &block : region)
+          walk(block);
+    }
+    void walk(Block &block) {
+      for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+        auto subBlockStart = it;
+        while (it != e && !isa<AffineForOp>(&*it))
+          ++it;
+        if (it != subBlockStart)
+          subBlocks.push_back({subBlockStart, std::prev(it)});
+        // Process all for insts that appear next.
+        while (it != e && isa<AffineForOp>(&*it))
+          walk(&*it++);
+      }
+    }
+  };
+
+  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
+
+  if (unrollJamFactor == 1)
+    return promoteIfSingleIteration(forOp);
+
+  if (forOp.getBody()->empty() ||
+      forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+    return failure();
+
+  // Loops where both lower and upper bounds are multi-result maps won't be
+  // unrolled (since the trip can't be expressed as an affine function in
+  // general).
+  // TODO(mlir-team): this may not be common, but we could support the case
+  // where the lower bound is a multi-result map and the ub is a single result
+  // one.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  // If the trip count is lower than the unroll jam factor, no unroll jam.
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollJamFactor)
+    return failure();
+
+  auto *forInst = forOp.getOperation();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer jbg;
+  jbg.walk(forInst);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Generate the cleanup loop if trip count isn't a multiple of
+  // unrollJamFactor.
+  if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
+    // Insert the cleanup loop right after 'forOp'.
+    OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst)));
+    auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
+    // Adjust the lower bound of the cleanup loop; its upper bound is the same
+    // as the original loop's upper bound.
+    AffineMap cleanupMap;
+    SmallVector<Value *, 4> cleanupOperands;
+    getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap,
+                             &cleanupOperands, builder);
+    cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);
+
+    // Promote the cleanup loop if it has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupAffineForOp);
+
+    // Adjust the upper bound of the original loop - it will be the same as the
+    // cleanup loop's lower bound. Its lower bound remains unchanged.
+    forOp.setUpperBound(cleanupOperands, cleanupMap);
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  int64_t step = forOp.getStep();
+  forOp.setStep(step * unrollJamFactor);
+
+  auto *forOpIV = forOp.getInductionVar();
+  for (auto &subBlock : subBlocks) {
+    // Builder to insert unroll-jammed bodies. Insert right at the end of
+    // sub-block.
+    OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+    // Unroll and jam (appends unrollJamFactor-1 additional copies).
+    for (unsigned i = 1; i < unrollJamFactor; i++) {
+      BlockAndValueMapping operandMapping;
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV->use_empty()) {
+        // iv' = iv + i, i = 1 to unrollJamFactor-1.
+        auto d0 = builder.getAffineDimExpr(0);
+        auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step});
+        auto ivUnroll =
+            builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
+        operandMapping.map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
+        builder.clone(*it, operandMapping);
+      }
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forOp);
+  return success();
+}
+
+static PassRegistration<LoopUnrollAndJam> pass("affine-loop-unroll-jam",
+                                               "Unroll and jam loops");
diff --git a/third_party/mlir/lib/Transforms/LowerAffine.cpp b/third_party/mlir/lib/Transforms/LowerAffine.cpp
new file mode 100644
index 0000000..062134d
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LowerAffine.cpp
@@ -0,0 +1,538 @@
+//===- LowerAffine.cpp - Lower affine constructs to primitives ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file lowers affine constructs (If and For statements, AffineApply
+// operations) within a function into their standard If and For equivalent ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LowerAffine.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+// Visit affine expressions recursively and build the sequence of operations
+// that correspond to it.  Visitation functions return an Value of the
+// expression subtree they visited or `nullptr` on error.
+class AffineApplyExpander
+    : public AffineExprVisitor<AffineApplyExpander, Value *> {
+public:
+  // This internal class expects arguments to be non-null, checks must be
+  // performed at the call site.
+  AffineApplyExpander(OpBuilder &builder, ArrayRef<Value *> dimValues,
+                      ArrayRef<Value *> symbolValues, Location loc)
+      : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
+        loc(loc) {}
+
+  template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
+    auto lhs = visit(expr.getLHS());
+    auto rhs = visit(expr.getRHS());
+    if (!lhs || !rhs)
+      return nullptr;
+    auto op = builder.create<OpTy>(loc, lhs, rhs);
+    return op.getResult();
+  }
+
+  Value *visitAddExpr(AffineBinaryOpExpr expr) {
+    return buildBinaryExpr<AddIOp>(expr);
+  }
+
+  Value *visitMulExpr(AffineBinaryOpExpr expr) {
+    return buildBinaryExpr<MulIOp>(expr);
+  }
+
+  // Euclidean modulo operation: negative RHS is not allowed.
+  // Remainder of the euclidean integer division is always non-negative.
+  //
+  // Implemented as
+  //
+  //     a mod b =
+  //         let remainder = srem a, b;
+  //             negative = a < 0 in
+  //         select negative, remainder + b, remainder.
+  Value *visitModExpr(AffineBinaryOpExpr expr) {
+    auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+    if (!rhsConst) {
+      emitError(
+          loc,
+          "semi-affine expressions (modulo by non-const) are not supported");
+      return nullptr;
+    }
+    if (rhsConst.getValue() <= 0) {
+      emitError(loc, "modulo by non-positive value is not supported");
+      return nullptr;
+    }
+
+    auto lhs = visit(expr.getLHS());
+    auto rhs = visit(expr.getRHS());
+    assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+    Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
+    Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+    Value *isRemainderNegative =
+        builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
+    Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
+    Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
+                                             correctedRemainder, remainder);
+    return result;
+  }
+
+  // Floor division operation (rounds towards negative infinity).
+  //
+  // For positive divisors, it can be implemented without branching and with a
+  // single division operation as
+  //
+  //        a floordiv b =
+  //            let negative = a < 0 in
+  //            let absolute = negative ? -a - 1 : a in
+  //            let quotient = absolute / b in
+  //                negative ? -quotient - 1 : quotient
+  Value *visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+    if (!rhsConst) {
+      emitError(
+          loc,
+          "semi-affine expressions (division by non-const) are not supported");
+      return nullptr;
+    }
+    if (rhsConst.getValue() <= 0) {
+      emitError(loc, "division by non-positive value is not supported");
+      return nullptr;
+    }
+
+    auto lhs = visit(expr.getLHS());
+    auto rhs = visit(expr.getRHS());
+    assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+    Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+    Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
+    Value *negative =
+        builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
+    Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
+    Value *dividend =
+        builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
+    Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
+    Value *correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
+    Value *result =
+        builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
+    return result;
+  }
+
+  // Ceiling division operation (rounds towards positive infinity).
+  //
+  // For positive divisors, it can be implemented without branching and with a
+  // single division operation as
+  //
+  //     a ceildiv b =
+  //         let negative = a <= 0 in
+  //         let absolute = negative ? -a : a - 1 in
+  //         let quotient = absolute / b in
+  //             negative ? -quotient : quotient + 1
+  Value *visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+    if (!rhsConst) {
+      emitError(loc) << "semi-affine expressions (division by non-const) are "
+                        "not supported";
+      return nullptr;
+    }
+    if (rhsConst.getValue() <= 0) {
+      emitError(loc, "division by non-positive value is not supported");
+      return nullptr;
+    }
+    auto lhs = visit(expr.getLHS());
+    auto rhs = visit(expr.getRHS());
+    assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+    Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+    Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
+    Value *nonPositive =
+        builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
+    Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
+    Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
+    Value *dividend =
+        builder.create<SelectOp>(loc, nonPositive, negated, decremented);
+    Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
+    Value *negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
+    Value *incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
+    Value *result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
+                                             incrementedQuotient);
+    return result;
+  }
+
+  Value *visitConstantExpr(AffineConstantExpr expr) {
+    auto valueAttr =
+        builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
+    auto op =
+        builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr);
+    return op.getResult();
+  }
+
+  Value *visitDimExpr(AffineDimExpr expr) {
+    assert(expr.getPosition() < dimValues.size() &&
+           "affine dim position out of range");
+    return dimValues[expr.getPosition()];
+  }
+
+  Value *visitSymbolExpr(AffineSymbolExpr expr) {
+    assert(expr.getPosition() < symbolValues.size() &&
+           "symbol dim position out of range");
+    return symbolValues[expr.getPosition()];
+  }
+
+private:
+  OpBuilder &builder;
+  ArrayRef<Value *> dimValues;
+  ArrayRef<Value *> symbolValues;
+
+  Location loc;
+};
+} // namespace
+
+// Create a sequence of operations that implement the `expr` applied to the
+// given dimension and symbol values.
+mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc,
+                                    AffineExpr expr,
+                                    ArrayRef<Value *> dimValues,
+                                    ArrayRef<Value *> symbolValues) {
+  return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
+}
+
+// Create a sequence of operations that implement the `affineMap` applied to
+// the given `operands` (as it it were an AffineApplyOp).
+Optional<SmallVector<Value *, 8>> static expandAffineMap(
+    OpBuilder &builder, Location loc, AffineMap affineMap,
+    ArrayRef<Value *> operands) {
+  auto numDims = affineMap.getNumDims();
+  auto expanded = functional::map(
+      [numDims, &builder, loc, operands](AffineExpr expr) {
+        return expandAffineExpr(builder, loc, expr,
+                                operands.take_front(numDims),
+                                operands.drop_front(numDims));
+      },
+      affineMap.getResults());
+  if (llvm::all_of(expanded, [](Value *v) { return v; }))
+    return expanded;
+  return None;
+}
+
+// Given a range of values, emit the code that reduces them with "min" or "max"
+// depending on the provided comparison predicate.  The predicate defines which
+// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
+// `cmpi` operation followed by the `select` operation:
+//
+//   %cond   = cmpi "predicate" %v0, %v1
+//   %result = select %cond, %v0, %v1
+//
+// Multiple values are scanned in a linear sequence.  This creates a data
+// dependences that wouldn't exist in a tree reduction, but is easier to
+// recognize as a reduction by the subsequent passes.
+static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
+                                      ArrayRef<Value *> values,
+                                      OpBuilder &builder) {
+  assert(!llvm::empty(values) && "empty min/max chain");
+
+  auto valueIt = values.begin();
+  Value *value = *valueIt++;
+  for (; valueIt != values.end(); ++valueIt) {
+    auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
+    value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
+  }
+
+  return value;
+}
+
+// Emit instructions that correspond to the affine map in the lower bound
+// applied to the respective operands, and compute the maximum value across
+// the results.
+Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+  SmallVector<Value *, 8> boundOperands(op.getLowerBoundOperands());
+  auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
+                                  boundOperands);
+  if (!lbValues)
+    return nullptr;
+  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
+                                 builder);
+}
+
+// Emit instructions that correspond to the affine map in the upper bound
+// applied to the respective operands, and compute the minimum value across
+// the results.
+Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
+  SmallVector<Value *, 8> boundOperands(op.getUpperBoundOperands());
+  auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
+                                  boundOperands);
+  if (!ubValues)
+    return nullptr;
+  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
+                                 builder);
+}
+
+namespace {
+// Affine terminators are removed.
+class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
+public:
+  using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
+                                     PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<loop::TerminatorOp>(op);
+    return matchSuccess();
+  }
+};
+
+class AffineForLowering : public OpRewritePattern<AffineForOp> {
+public:
+  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineForOp op,
+                                     PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value *lowerBound = lowerAffineLowerBound(op, rewriter);
+    Value *upperBound = lowerAffineUpperBound(op, rewriter);
+    Value *step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
+    auto f = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
+    f.region().getBlocks().clear();
+    rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+};
+
+class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
+public:
+  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineIfOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    // Now we just have to handle the condition logic.
+    auto integerSet = op.getIntegerSet();
+    Value *zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
+    SmallVector<Value *, 8> operands(op.getOperation()->getOperands());
+    auto operandsRef = llvm::makeArrayRef(operands);
+
+    // Calculate cond as a conjunction without short-circuiting.
+    Value *cond = nullptr;
+    for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
+      AffineExpr constraintExpr = integerSet.getConstraint(i);
+      bool isEquality = integerSet.isEq(i);
+
+      // Build and apply an affine expression
+      auto numDims = integerSet.getNumDims();
+      Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr,
+                                          operandsRef.take_front(numDims),
+                                          operandsRef.drop_front(numDims));
+      if (!affResult)
+        return matchFailure();
+      auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE;
+      Value *cmpVal =
+          rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
+      cond =
+          cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal;
+    }
+    cond = cond ? cond
+                : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
+
+    bool hasElseRegion = !op.elseRegion().empty();
+    auto ifOp = rewriter.create<loop::IfOp>(loc, cond, hasElseRegion);
+    rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back());
+    ifOp.thenRegion().back().erase();
+    if (hasElseRegion) {
+      rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back());
+      ifOp.elseRegion().back().erase();
+    }
+
+    // Ok, we're done!
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+};
+
+// Convert an "affine.apply" operation into a sequence of arithmetic
+// operations using the StandardOps dialect.
+class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
+public:
+  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
+
+  virtual PatternMatchResult
+  matchAndRewrite(AffineApplyOp op, PatternRewriter &rewriter) const override {
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
+                        llvm::to_vector<8>(op.getOperands()));
+    if (!maybeExpandedMap)
+      return matchFailure();
+    rewriter.replaceOp(op, *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.load' operation to its operands, and
+// feed the results to a newly created 'std.load' operation (which replaces the
+// original 'affine.load').
+class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
+public:
+  using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
+
+  virtual PatternMatchResult
+  matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override {
+    // Expand affine map from 'affineLoadOp'.
+    SmallVector<Value *, 8> indices(op.getIndices());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return matchFailure();
+
+    // Build std.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.store' operation to its operands, and
+// feed the results to a newly created 'std.store' operation (which replaces the
+// original 'affine.store').
+class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
+public:
+  using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
+
+  virtual PatternMatchResult
+  matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override {
+    // Expand affine map from 'affineStoreOp'.
+    SmallVector<Value *, 8> indices(op.getIndices());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return matchFailure();
+
+    // Build std.store valutToStore, memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
+                                         op.getMemRef(), *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine maps from an 'affine.dma_start' operation to each of their
+// respective map operands, and feed the results to a newly created
+// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
+class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
+public:
+  using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
+
+  virtual PatternMatchResult
+  matchAndRewrite(AffineDmaStartOp op,
+                  PatternRewriter &rewriter) const override {
+    SmallVector<Value *, 8> operands(op.getOperands());
+    auto operandsRef = llvm::makeArrayRef(operands);
+
+    // Expand affine map for DMA source memref.
+    auto maybeExpandedSrcMap = expandAffineMap(
+        rewriter, op.getLoc(), op.getSrcMap(),
+        operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
+    if (!maybeExpandedSrcMap)
+      return matchFailure();
+    // Expand affine map for DMA destination memref.
+    auto maybeExpandedDstMap = expandAffineMap(
+        rewriter, op.getLoc(), op.getDstMap(),
+        operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
+    if (!maybeExpandedDstMap)
+      return matchFailure();
+    // Expand affine map for DMA tag memref.
+    auto maybeExpandedTagMap = expandAffineMap(
+        rewriter, op.getLoc(), op.getTagMap(),
+        operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_start operation with affine map results.
+    rewriter.replaceOpWithNewOp<DmaStartOp>(
+        op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
+        *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
+        *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.dma_wait' operation tag memref,
+// and feed the results to a newly created 'std.dma_wait' operation (which
+// replaces the original 'affine.dma_wait').
+class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
+public:
+  using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
+
+  virtual PatternMatchResult
+  matchAndRewrite(AffineDmaWaitOp op,
+                  PatternRewriter &rewriter) const override {
+    // Expand affine map for DMA tag memref.
+    SmallVector<Value *, 8> indices(op.getTagIndices());
+    auto maybeExpandedTagMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_wait operation with affine map results.
+    rewriter.replaceOpWithNewOp<DmaWaitOp>(
+        op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
+    return matchSuccess();
+  }
+};
+
+} // end namespace
+
+void mlir::populateAffineToStdConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns
+      .insert<AffineApplyLowering, AffineDmaStartLowering,
+              AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
+              AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
+          ctx);
+}
+
+namespace {
+class LowerAffinePass : public FunctionPass<LowerAffinePass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateAffineToStdConversionPatterns(patterns, &getContext());
+    ConversionTarget target(getContext());
+    target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Lowers If and For operations within a function into their lower level CFG
+/// equivalent blocks.
+FunctionPassBase *mlir::createLowerAffinePass() {
+  return new LowerAffinePass();
+}
+
+static PassRegistration<LowerAffinePass>
+    pass("lower-affine",
+         "Lower If, For, AffineApply operations to primitive equivalents");
diff --git a/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp
new file mode 100644
index 0000000..e2d5920
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -0,0 +1,383 @@
+//===- LowerVectorTransfers.cpp - LowerVectorTransfers Pass Impl ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements target-dependent lowering of vector transfer operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
+/// proper abstraction for the hardware.
+///
+/// For now, we only emit a simple loop nest that performs clipped pointwise
+/// copies from a remote to a locally allocated memory.
+///
+/// Consider the case:
+///
+/// ```mlir {.mlir}
+///    // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
+///    // vector<32x256xf32> and pad with %f0 to handle the boundary case:
+///    %f0 = constant 0.0f : f32
+///    affine.for %i0 = 0 to %0 {
+///      affine.for %i1 = 0 to %1 step 256 {
+///        affine.for %i2 = 0 to %2 step 32 {
+///          %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
+///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+///               memref<?x?x?xf32>, vector<32x256xf32>
+///    }}}
+/// ```
+///
+/// The rewriters construct loop and indices that access MemRef A in a pattern
+/// resembling the following (while guaranteeing an always full-tile
+/// abstraction):
+///
+/// ```mlir {.mlir}
+///    affine.for %d2 = 0 to 256 {
+///      affine.for %d1 = 0 to 32 {
+///        %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
+///        %tmp[%d2, %d1] = %s
+///      }
+///    }
+/// ```
+///
+/// In the current state, only a clipping transfer is implemented by `clip`,
+/// which creates individual indexing expressions of the form:
+///
+/// ```mlir-dsc
+///    SELECT(i + ii < zero, zero, SELECT(i + ii < N, i + ii, N - one))
+/// ```
+
+using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
+
+#define DEBUG_TYPE "affine-lower-vector-transfers"
+
+namespace {
+
+/// Lowers VectorTransferOp into a combination of:
+///   1. local memory allocation;
+///   2. perfect loop nest over:
+///      a. scalar load/stores from local buffers (viewed as a scalar memref);
+///      a. scalar store/load to original memref (with clipping).
+///   3. vector_load/store
+///   4. local memory deallocation.
+/// Minor variations occur depending on whether a VectorTransferReadOp or
+/// a VectorTransferWriteOp is rewritten.
+template <typename VectorTransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+  explicit VectorTransferRewriter(MLIRContext *context)
+      : RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
+
+  /// Used for staging the transfer in a local scalar buffer.
+  MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
+    auto vectorType = transfer.getVectorType();
+    return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
+                           {}, 0);
+  }
+
+  /// View of tmpMemRefType as one vector, used in vector load/store to tmp
+  /// buffer.
+  MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
+    return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
+  }
+
+  /// Performs the rewrite.
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
+/// Analyzes the `transfer` to find an access dimension along the fastest remote
+/// MemRef dimension. If such a dimension with coalescing properties is found,
+/// `pivs` and `vectorView` are swapped so that the invocation of
+/// LoopNestBuilder captures it in the innermost loop.
+template <typename VectorTransferOpTy>
+void coalesceCopy(VectorTransferOpTy transfer,
+                  SmallVectorImpl<edsc::ValueHandle *> *pivs,
+                  edsc::VectorView *vectorView) {
+  // rank of the remote memory access, coalescing behavior occurs on the
+  // innermost memory dimension.
+  auto remoteRank = transfer.getMemRefType().getRank();
+  // Iterate over the results expressions of the permutation map to determine
+  // the loop order for creating pointwise copies between remote and local
+  // memories.
+  int coalescedIdx = -1;
+  auto exprs = transfer.getPermutationMap().getResults();
+  for (auto en : llvm::enumerate(exprs)) {
+    auto dim = en.value().template dyn_cast<AffineDimExpr>();
+    if (!dim) {
+      continue;
+    }
+    auto memRefDim = dim.getPosition();
+    if (memRefDim == remoteRank - 1) {
+      // memRefDim has coalescing properties, it should be swapped in the last
+      // position.
+      assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices");
+      coalescedIdx = en.index();
+    }
+  }
+  if (coalescedIdx >= 0) {
+    std::swap(pivs->back(), (*pivs)[coalescedIdx]);
+    vectorView->swapRanges(pivs->size() - 1, coalescedIdx);
+  }
+}
+
+/// Emits remote memory accesses that are clipped to the boundaries of the
+/// MemRef.
+template <typename VectorTransferOpTy>
+llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer,
+                                             edsc::MemRefView &view,
+                                             ArrayRef<edsc::IndexHandle> ivs) {
+  using namespace mlir::edsc;
+  using namespace edsc::op;
+  using edsc::intrinsics::select;
+
+  IndexHandle zero(index_t(0)), one(index_t(1));
+  llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.getIndices());
+  llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
+      memRefAccess.size(), edsc::IndexHandle());
+
+  // Indices accessing to remote memory are clipped and their expressions are
+  // returned in clippedScalarAccessExprs.
+  for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
+       ++memRefDim) {
+    // Linear search on a small number of entries.
+    int loopIndex = -1;
+    auto exprs = transfer.getPermutationMap().getResults();
+    for (auto en : llvm::enumerate(exprs)) {
+      auto expr = en.value();
+      auto dim = expr.template dyn_cast<AffineDimExpr>();
+      // Sanity check.
+      assert(
+          (dim || expr.template cast<AffineConstantExpr>().getValue() == 0) &&
+          "Expected dim or 0 in permutationMap");
+      if (dim && memRefDim == dim.getPosition()) {
+        loopIndex = en.index();
+        break;
+      }
+    }
+
+    // We cannot distinguish atm between unrolled dimensions that implement
+    // the "always full" tile abstraction and need clipping from the other
+    // ones. So we conservatively clip everything.
+    auto N = view.ub(memRefDim);
+    auto i = memRefAccess[memRefDim];
+    if (loopIndex < 0) {
+      auto N_minus_1 = N - one;
+      auto select_1 = select(i < N, i, N_minus_1);
+      clippedScalarAccessExprs[memRefDim] = select(i < zero, zero, select_1);
+    } else {
+      auto ii = ivs[loopIndex];
+      auto i_plus_ii = i + ii;
+      auto N_minus_1 = N - one;
+      auto select_1 = select(i_plus_ii < N, i_plus_ii, N_minus_1);
+      clippedScalarAccessExprs[memRefDim] =
+          select(i_plus_ii < zero, zero, select_1);
+    }
+  }
+
+  return clippedScalarAccessExprs;
+}
+
+/// Lowers VectorTransferReadOp into a combination of:
+///   1. local memory allocation;
+///   2. perfect loop nest over:
+///      a. scalar load from local buffers (viewed as a scalar memref);
+///      a. scalar store to original memref (with clipping).
+///   3. vector_load from local buffer (viewed as a memref<1 x vector>);
+///   4. local memory deallocation.
+///
+/// Lowers the data transfer part of a VectorTransferReadOp while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given value in memory can be read multiple
+/// times and concurrently.
+///
+/// Important notes about clipping and "full-tiles only" abstraction:
+/// =================================================================
+/// When using clipping for dealing with boundary conditions, the same edge
+/// value will appear multiple times (a.k.a edge padding). This is fine if the
+/// subsequent vector operations are all data-parallel but **is generally
+/// incorrect** in the presence of reductions or extract operations.
+///
+/// More generally, clipping is a scalar abstraction that is expected to work
+/// fine as a baseline for CPUs and GPUs but not for vector_load and DMAs.
+/// To deal with real vector_load and DMAs, a "padded allocation + view"
+/// abstraction with the ability to read out-of-memref-bounds (but still within
+/// the allocated region) is necessary.
+///
+/// Whether using scalar loops or vector_load/DMAs to perform the transfer,
+/// junk values will be materialized in the vectors and generally need to be
+/// filtered out and replaced by the "neutral element". This neutral element is
+/// op-dependent so, in the future, we expect to create a vector filter and
+/// apply it to a splatted constant vector with the proper neutral element at
+/// each ssa-use. This filtering is not necessary for pure data-parallel
+/// operations.
+///
+/// In the case of vector_store/DMAs, Read-Modify-Write will be required, which
+/// also have concurrency implications. Note that by using clipped scalar stores
+/// in the presence of data-parallel only operations, we generate code that
+/// writes the same value multiple time on the edge locations.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+
+/// Performs the rewrite.
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  using namespace mlir::edsc;
+  using namespace mlir::edsc::op;
+  using namespace mlir::edsc::intrinsics;
+  using IndexedValue =
+      TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+  VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
+
+  // 1. Setup all the captures.
+  ScopedContext scope(rewriter, transfer.getLoc());
+  IndexedValue remote(transfer.getMemRef());
+  MemRefView view(transfer.getMemRef());
+  VectorView vectorView(transfer.getVector());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs =
+      makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+  coalesceCopy(transfer, &pivs, &vectorView);
+
+  auto lbs = vectorView.getLbs();
+  auto ubs = vectorView.getUbs();
+  auto steps = vectorView.getSteps();
+
+  // 2. Emit alloc-copy-load-dealloc.
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
+  IndexedValue local(tmp);
+  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
+  LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+    // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+    local(ivs) = remote(clip(transfer, view, ivs));
+  });
+  ValueHandle vectorValue = std_load(vec, {constant_index(0)});
+  (dealloc(tmp)); // vexing parse
+
+  // 3. Propagate.
+  rewriter.replaceOp(op, vectorValue.getValue());
+  return matchSuccess();
+}
+
+/// Lowers VectorTransferWriteOp into a combination of:
+///   1. local memory allocation;
+///   2. vector_store to local buffer (viewed as a memref<1 x vector>);
+///   3. perfect loop nest over:
+///      a. scalar load from local buffers (viewed as a scalar memref);
+///      a. scalar store to original memref (with clipping).
+///   4. local memory deallocation.
+///
+/// More specifically, lowers the data transfer part while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given value in memory can be written to multiple
+/// times and concurrently.
+///
+/// See `Important notes about clipping and full-tiles only abstraction` in the
+/// description of `readClipped` above.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  using namespace mlir::edsc;
+  using namespace mlir::edsc::op;
+  using namespace mlir::edsc::intrinsics;
+  using IndexedValue =
+      TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+  VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
+
+  // 1. Setup all the captures.
+  ScopedContext scope(rewriter, transfer.getLoc());
+  IndexedValue remote(transfer.getMemRef());
+  MemRefView view(transfer.getMemRef());
+  ValueHandle vectorValue(transfer.getVector());
+  VectorView vectorView(transfer.getVector());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
+  coalesceCopy(transfer, &pivs, &vectorView);
+
+  auto lbs = vectorView.getLbs();
+  auto ubs = vectorView.getUbs();
+  auto steps = vectorView.getSteps();
+
+  // 2. Emit alloc-store-copy-dealloc.
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
+  IndexedValue local(tmp);
+  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
+  std_store(vectorValue, vec, {constant_index(0)});
+  LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+    // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+    remote(clip(transfer, view, ivs)) = local(ivs);
+  });
+  (dealloc(tmp)); // vexing parse...
+
+  rewriter.replaceOp(op, llvm::None);
+  return matchSuccess();
+}
+
+struct LowerVectorTransfersPass
+    : public FunctionPass<LowerVectorTransfersPass> {
+  void runOnFunction() {
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    patterns.insert<VectorTransferRewriter<vector::VectorTransferReadOp>,
+                    VectorTransferRewriter<vector::VectorTransferWriteOp>>(
+        context);
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createLowerVectorTransfersPass() {
+  return new LowerVectorTransfersPass();
+}
+
+static PassRegistration<LowerVectorTransfersPass>
+    pass("affine-lower-vector-transfers",
+         "Materializes vector transfer ops to a "
+         "proper abstraction for the hardware");
diff --git a/third_party/mlir/lib/Transforms/MaterializeVectors.cpp b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp
new file mode 100644
index 0000000..17acc92
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -0,0 +1,779 @@
+//===- MaterializeVectors.cpp - MaterializeVectors Pass Impl --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements target-dependent materialization of super-vectors to
+// vectors of the proper size for the hardware.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+///
+/// Implements target-dependent materialization of virtual super-vectors to
+/// vectors of the proper size for the hardware.
+///
+/// While the physical vector size is target-dependent, the pass is written in
+/// a target-independent way: the target vector size is specified as a parameter
+/// to the pass. This pass is thus a partial lowering that opens the "greybox"
+/// that is the super-vector abstraction. In particular, this pass can turn the
+/// vector.transfer_read and vector.transfer_write ops in either:
+///   1. a loop nest with either scalar and vector load/store operations; or
+///   2. a loop-nest with DmaStartOp / DmaWaitOp; or
+///   3. a pre-existing blackbox library call that can be written manually or
+///      synthesized using search and superoptimization.
+/// An important feature that either of these 3 target lowering abstractions
+/// must handle is the handling of "non-effecting" padding with the proper
+/// neutral element in order to guarantee that all "partial tiles" are actually
+/// "full tiles" in practice.
+///
+/// In particular this pass is a MLIR-MLIR rewriting and does not concern itself
+/// with target-specific instruction-selection and register allocation. These
+/// will happen downstream in LLVM.
+///
+/// In this sense, despite performing lowering to a target-dependent size, this
+/// pass is still target-agnostic.
+///
+/// Implementation details
+/// ======================
+/// The current decisions made by the super-vectorization pass guarantee that
+/// use-def chains do not escape an enclosing vectorized AffineForOp. In other
+/// words, this pass operates on a scoped program slice. Furthermore, since we
+/// do not vectorize in the presence of conditionals for now, sliced chains are
+/// guaranteed not to escape the innermost scope, which has to be either the top
+/// Function scope or the innermost loop scope, by construction. As a
+/// consequence, the implementation just starts from vector.transfer_write
+/// operations and builds the slice scoped the innermost loop enclosing the
+/// current vector.transfer_write. These assumptions and the implementation
+/// details are subject to revision in the future.
+///
+/// Example
+/// ========
+/// In the following, the single vector.transfer_write op operates on a
+/// vector<4x4x4xf32>. Let's assume the HW supports vector<4x4xf32>.
+/// Materialization is achieved by instantiating each occurrence of the leading
+/// dimension of vector<4x4x4xf32> into a vector<4x4xf32>.
+/// The program transformation that implements this instantiation is a
+/// multi-loop unroll-and-jam (it can be partial or full depending on the ratio
+/// of super-vector shape to HW-vector shape).
+///
+/// As a simple case, the following:
+///
+/// ```mlir
+///    mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) {
+///      %A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32>
+///      %f1 = constant dense<vector<4x4x4xf32>, 1.000000e+00> :
+///      vector<4x4x4xf32> affine.for %i0 = 0 to %M step 4 {
+///        affine.for %i1 = 0 to %N step 4 {
+///          affine.for %i2 = 0 to %O {
+///            affine.for %i3 = 0 to %P step 4 {
+///              vector.transfer_write %f1, %A[%i0, %i1, %i2, %i3]
+///                {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} :
+///                 vector<4x4x4xf32>, memref<?x?x?x?xf32>
+///      }}}}
+///      return
+///    }
+/// ```
+///
+/// is instantiated by unroll-and-jam (just unroll in this case) into:
+///
+/// ```mlir
+///    mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) {
+///      %A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
+///      %f1 = constant dense<vector<4x4xf32>, 1.000000e+00> : vector<4x4x4xf32>
+///       affine.for %i0 = 0 to %arg0 step 4 {
+///         affine.for %i1 = 0 to %arg1 step 4 {
+///           affine.for %i2 = 0 to %arg2 {
+///             affine.for %i3 = 0 to %arg3 step 4 {
+///               vector.transfer_write f1, %0[%i0, %i1, %i2, %i3]
+///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
+///                 vector<4x4xf32>, memref<?x?x?x?xf32>
+///               %i3p1 = affine.apply (d0) -> (d0 + 1)(%i3)
+///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p1]
+///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
+///                 vector<4x4xf32>, memref<?x?x?x?xf32>
+///               %i3p2 = affine.apply (d0) -> (d0 + 2)(%i3)
+///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p2]
+///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
+///                 vector<4x4xf32>, memref<?x?x?x?xf32>
+///               %i3p3 = affine.apply (d0) -> (d0 + 3)(%i3)
+///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p3]
+///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
+///                 vector<4x4xf32>, memref<?x?x?x?xf32>
+///      }}}}
+///      return
+///    }
+/// ```
+
+using llvm::dbgs;
+using llvm::SetVector;
+
+using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
+
+using functional::makePtrDynCaster;
+using functional::map;
+
+static llvm::cl::list<int>
+    clVectorSize("vector-size",
+                 llvm::cl::desc("Specify the HW vector size for vectorization"),
+                 llvm::cl::ZeroOrMore);
+
+#define DEBUG_TYPE "materialize-vect"
+
+namespace {
+struct MaterializationState {
+  /// In practice, the determination of the HW-specific vector type to use when
+  /// lowering a super-vector type must be based on the elemental type. The
+  /// elemental type must be retrieved from the super-vector type. In the future
+  /// information about hardware vector type for a particular elemental type
+  /// will be part of the contract between MLIR and the backend.
+  ///
+  /// For example, 8xf32 has the same size as 16xf16 but the targeted HW itself
+  /// may exhibit the following property:
+  /// 1. have a special unit for a 128xf16 datapath;
+  /// 2. no F16 FPU support on the regular 8xf32/16xf16 vector datapath.
+  ///
+  /// For now, we just assume hwVectorSize has the proper information regardless
+  /// of the type and we assert everything is f32.
+  /// TODO(ntv): relax the assumptions on admissible element type once a
+  /// contract exists.
+  MaterializationState(SmallVector<int64_t, 8> sizes) : hwVectorSize(sizes) {}
+
+  SmallVector<int64_t, 8> hwVectorSize;
+  VectorType superVectorType;
+  VectorType hwVectorType;
+  SmallVector<unsigned, 8> hwVectorInstance;
+  DenseMap<Value *, Value *> *substitutionsMap;
+};
+
+/// Base state for the vector materialization pass.
+/// Command line arguments are preempted by non-empty pass arguments.
+struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
+  MaterializeVectorsPass()
+      : hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {}
+  MaterializeVectorsPass(ArrayRef<int64_t> hwVectorSize)
+      : MaterializeVectorsPass() {
+    if (!hwVectorSize.empty())
+      this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end());
+  }
+
+  SmallVector<int64_t, 8> hwVectorSize;
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Given a shape with sizes greater than 0 along all dimensions,
+/// returns the distance, in number of elements, between a slice in a dimension
+/// and the next slice in the same dimension.
+///   e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
+  SmallVector<unsigned, 8> tmp;
+  tmp.reserve(shape.size());
+  unsigned running = 1;
+  for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
+    assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
+                       "shape");
+    tmp.push_back(running);
+    running *= *rit;
+  }
+  return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
+}
+
+/// Given a shape with sizes greater than 0 along all dimensions, returns the
+/// delinearized components of linearIndex along shape.
+static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
+                                            ArrayRef<unsigned> shape) {
+  SmallVector<unsigned, 8> res;
+  res.reserve(shape.size());
+  auto strides = makeStrides(shape);
+  for (unsigned idx = 0; idx < strides.size(); ++idx) {
+    assert(strides[idx] > 0);
+    auto val = linearIndex / strides[idx];
+    res.push_back(val);
+    assert(val < shape[idx] && "delinearization is out of bounds");
+    linearIndex %= strides[idx];
+  }
+  // Sanity check.
+  assert(linearIndex == 0 && "linear index constructed from shape must "
+                             "have 0 remainder after delinearization");
+  return res;
+}
+
+static Operation *instantiate(OpBuilder b, Operation *opInst,
+                              VectorType hwVectorType,
+                              DenseMap<Value *, Value *> *substitutionsMap);
+
+/// Not all Values belong to a program slice scoped within the immediately
+/// enclosing loop.
+/// One simple example is constants defined outside the innermost loop scope.
+/// For such cases the substitutionsMap has no entry and we allow an additional
+/// insertion.
+/// For now, this is limited to ConstantOp because we do not vectorize loop
+/// indices and will need to be extended in the future.
+///
+/// If substitution fails, returns nullptr.
+static Value *substitute(Value *v, VectorType hwVectorType,
+                         DenseMap<Value *, Value *> *substitutionsMap) {
+  auto it = substitutionsMap->find(v);
+  if (it == substitutionsMap->end()) {
+    auto *opInst = v->getDefiningOp();
+    if (isa<ConstantOp>(opInst)) {
+      OpBuilder b(opInst);
+      auto *op = instantiate(b, opInst, hwVectorType, substitutionsMap);
+      auto res = substitutionsMap->insert(std::make_pair(v, op->getResult(0)));
+      assert(res.second && "Insertion failed");
+      return res.first->second;
+    }
+    v->getDefiningOp()->emitError("Missing substitution");
+    return nullptr;
+  }
+  return it->second;
+}
+
+/// Returns a list of single result AffineApplyOps that reindex the
+/// `memRefIndices` by the multi-dimensional `hwVectorInstance`. This is used by
+/// the function that materializes a vector.transfer operation to use hardware
+/// vector types instead of super-vector types.
+///
+/// The general problem this function solves is as follows:
+/// Assume a vector.transfer operation at the super-vector granularity that has
+/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation
+/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware
+/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2,
+/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector
+/// is vector<8xf32>. Assume the following MLIR snippet after
+/// super-vectorization has been applied:
+///
+/// ```mlir
+/// affine.for %i0 = 0 to %M {
+///   affine.for %i1 = 0 to %N step 3 {
+///     affine.for %i2 = 0 to %O {
+///       affine.for %i3 = 0 to %P step 32 {
+///         %r = vector.transfer_read(%A, map0(%i..), map1(%i..), map2(%i..)) :
+///              vector<3x32xf32>, memref<?x?x?xf32>
+///         ...
+/// }}}}
+/// ```
+///
+/// where map denotes an AffineMap operating on enclosing loops with properties
+/// compatible for vectorization (i.e. some contiguity left unspecified here).
+/// Note that the vectorized loops are %i1 and %i3.
+/// This function translates the vector.transfer_read operation to multiple
+/// instances of vector.transfer_read that operate on vector<8x32>.
+///
+/// Without loss of generality, we assume hwVectorInstance is: {2, 1}.
+/// The only constraints on hwVectorInstance is they belong to:
+///   [0, 2] x [0, 3], which is the span of ratio of super-vector shape to
+/// hardware vector shape in our example.
+///
+/// This function instantiates the iteration <2, 1> of vector.transfer_read
+/// into the set of operations in pseudo-MLIR:
+///
+/// ```mlir
+///   #map2 = (d0, d1, d2, d3) -> (d0, d1 + 2, d2, d3 + 1 * 8)
+///   #map3 = #map o #map2 // where o denotes composition
+///   aff0 = affine.apply #map3.0(%i..)
+///   aff1 = affine.apply #map3.1(%i..)
+///   aff2 = affine.apply #map3.2(%i..)
+///   %r = vector.transfer_read(%A, %aff0, %aff1, %aff2):
+//         vector<3x32xf32>, memref<?x?x?xf32>
+/// ```
+///
+/// Practical considerations
+/// ========================
+/// For now, `map` is assumed to be the identity map and the indices are
+/// specified just as vector.transfer_read%A[%i0, %i1, %i2, %i3]. This will be
+/// extended in the future once we have a proper Op for vector transfers.
+/// Additionally, the example above is specified in pseudo-MLIR form; once we
+/// have proper support for generic maps we can generate the code and show
+/// actual MLIR.
+///
+/// TODO(ntv): support a concrete AffineMap and compose with it.
+/// TODO(ntv): these implementation details should be captured in a
+/// vectorization trait at the op level directly.
+static SmallVector<mlir::Value *, 8>
+reindexAffineIndices(OpBuilder b, VectorType hwVectorType,
+                     ArrayRef<unsigned> hwVectorInstance,
+                     ArrayRef<Value *> memrefIndices) {
+  auto vectorShape = hwVectorType.getShape();
+  assert(hwVectorInstance.size() >= vectorShape.size());
+
+  unsigned numIndices = memrefIndices.size();
+  auto numMemRefIndices = numIndices - hwVectorInstance.size();
+  auto numVectorIndices = hwVectorInstance.size() - vectorShape.size();
+
+  SmallVector<AffineExpr, 8> affineExprs;
+  // TODO(ntv): support a concrete map and composition.
+  unsigned i = 0;
+  // The first numMemRefIndices correspond to AffineForOp that have not been
+  // vectorized, the transformation is the identity on those.
+  for (i = 0; i < numMemRefIndices; ++i) {
+    auto d_i = b.getAffineDimExpr(i);
+    affineExprs.push_back(d_i);
+  }
+  // The next numVectorIndices correspond to super-vector dimensions that
+  // do not have a hardware vector dimension counterpart. For those we only
+  // need to increment the index by the corresponding hwVectorInstance.
+  for (i = numMemRefIndices; i < numMemRefIndices + numVectorIndices; ++i) {
+    auto d_i = b.getAffineDimExpr(i);
+    auto offset = hwVectorInstance[i - numMemRefIndices];
+    affineExprs.push_back(d_i + offset);
+  }
+  // The remaining indices correspond to super-vector dimensions that
+  // have a hardware vector dimension counterpart. For those we to increment the
+  // index by "hwVectorInstance" multiples of the corresponding hardware
+  // vector size.
+  for (; i < numIndices; ++i) {
+    auto d_i = b.getAffineDimExpr(i);
+    auto offset = hwVectorInstance[i - numMemRefIndices];
+    auto stride = vectorShape[i - numMemRefIndices - numVectorIndices];
+    affineExprs.push_back(d_i + offset * stride);
+  }
+
+  // Create a bunch of single result AffineApplyOp.
+  SmallVector<mlir::Value *, 8> res;
+  res.reserve(affineExprs.size());
+  for (auto expr : affineExprs) {
+    auto map = AffineMap::get(numIndices, 0, expr);
+    res.push_back(makeComposedAffineApply(b, b.getInsertionPoint()->getLoc(),
+                                          map, memrefIndices));
+  }
+  return res;
+}
+
+/// Returns attributes with the following substitutions applied:
+///   - constant splat is replaced by constant splat of `hwVectorType`.
+/// TODO(ntv): add more substitutions on a per-need basis.
+static SmallVector<NamedAttribute, 1>
+materializeAttributes(Operation *opInst, VectorType hwVectorType) {
+  SmallVector<NamedAttribute, 1> res;
+  for (auto a : opInst->getAttrs()) {
+    if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
+      auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue());
+      res.push_back(NamedAttribute(a.first, attr));
+    } else {
+      res.push_back(a);
+    }
+  }
+  return res;
+}
+
+/// Creates an instantiated version of `opInst`.
+/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
+/// affine reindexing. Just substitute their Value operands and be done. For
+/// this case the actual instance is irrelevant. Just use the values in
+/// substitutionsMap.
+///
+/// If the underlying substitution fails, this fails too and returns nullptr.
+static Operation *instantiate(OpBuilder b, Operation *opInst,
+                              VectorType hwVectorType,
+                              DenseMap<Value *, Value *> *substitutionsMap) {
+  assert(!isa<VectorTransferReadOp>(opInst) &&
+         "Should call the function specialized for VectorTransferReadOp");
+  assert(!isa<VectorTransferWriteOp>(opInst) &&
+         "Should call the function specialized for VectorTransferWriteOp");
+  if (opInst->getNumRegions() != 0)
+    return nullptr;
+
+  bool fail = false;
+  auto operands = map(
+      [hwVectorType, substitutionsMap, &fail](Value *v) -> Value * {
+        auto *res =
+            fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
+        fail |= !res;
+        return res;
+      },
+      opInst->getOperands());
+  if (fail)
+    return nullptr;
+
+  auto attrs = materializeAttributes(opInst, hwVectorType);
+
+  OperationState state(opInst->getLoc(), opInst->getName().getStringRef(),
+                       operands, {hwVectorType}, attrs);
+  return b.createOperation(state);
+}
+
+/// Computes the permutationMap required for a VectorTransferOp from the memref
+/// to the `hwVectorType`.
+/// This is achieved by returning the projection of the permutationMap along the
+/// dimensions of the super-vector type that remain in the hwVectorType.
+/// In particular, if a dimension is fully instantiated (i.e. unrolled) then it
+/// is projected out in the final result.
+template <typename VectorTransferOpTy>
+static AffineMap projectedPermutationMap(VectorTransferOpTy transfer,
+                                         VectorType hwVectorType) {
+  static_assert(
+      std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value ||
+          std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value,
+      "Must be called on a VectorTransferOp");
+  auto superVectorType = transfer.getVectorType();
+  auto optionalRatio = shapeRatio(superVectorType, hwVectorType);
+  assert(optionalRatio &&
+         (optionalRatio->size() == superVectorType.getShape().size()) &&
+         "Shape and ratio not of the same size");
+  unsigned dim = 0;
+  SmallVector<AffineExpr, 4> keep;
+  MLIRContext *context = transfer.getContext();
+  functional::zipApply(
+      [&dim, &keep, context](int64_t shape, int64_t ratio) {
+        assert(shape >= ratio && "shape dim must be greater than ratio dim");
+        if (shape != ratio) {
+          // HW vector is not full instantiated along this dim, keep it.
+          keep.push_back(getAffineDimExpr(dim, context));
+        }
+        ++dim;
+      },
+      superVectorType.getShape(), *optionalRatio);
+  auto permutationMap = transfer.getPermutationMap();
+  LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
+  if (keep.empty()) {
+    return permutationMap;
+  }
+  auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep);
+  LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: "));
+  return simplifyAffineMap(projectionMap.compose(permutationMap));
+}
+
+/// Creates an instantiated version of `read` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` int the covering of the super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static Operation *instantiate(OpBuilder b, VectorTransferReadOp read,
+                              VectorType hwVectorType,
+                              ArrayRef<unsigned> hwVectorInstance,
+                              DenseMap<Value *, Value *> *substitutionsMap) {
+  SmallVector<Value *, 8> indices =
+      map(makePtrDynCaster<Value>(), read.getIndices());
+  auto affineIndices =
+      reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+  auto map = projectedPermutationMap(read, hwVectorType);
+  if (!map) {
+    return nullptr;
+  }
+  auto cloned = b.create<VectorTransferReadOp>(read.getLoc(), hwVectorType,
+                                               read.getMemRef(), affineIndices,
+                                               map, read.getPaddingValue());
+  return cloned.getOperation();
+}
+
+/// Creates an instantiated version of `write` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` int the covering of th3e super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static Operation *instantiate(OpBuilder b, VectorTransferWriteOp write,
+                              VectorType hwVectorType,
+                              ArrayRef<unsigned> hwVectorInstance,
+                              DenseMap<Value *, Value *> *substitutionsMap) {
+  SmallVector<Value *, 8> indices =
+      map(makePtrDynCaster<Value>(), write.getIndices());
+  auto affineIndices =
+      reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+  auto cloned = b.create<VectorTransferWriteOp>(
+      write.getLoc(),
+      substitute(write.getVector(), hwVectorType, substitutionsMap),
+      write.getMemRef(), affineIndices,
+      projectedPermutationMap(write, hwVectorType));
+  return cloned.getOperation();
+}
+
+/// Returns `true` if op instance is properly cloned and inserted, false
+/// otherwise.
+/// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
+/// super-vector type to hw vector type.
+/// A cloned instance of `op` is formed as follows:
+///   1. vector.transfer_read: the return `superVectorType` is replaced by
+///      `hwVectorType`. Additionally, affine indices are reindexed with
+///      `reindexAffineIndices` using `hwVectorInstance` and vector type
+///      information;
+///   2. vector.transfer_write: the `valueToStore` type is simply substituted.
+///      Since we operate on a topologically sorted slice, a substitution must
+///      have been registered for non-constant ops. Additionally, affine indices
+///      are reindexed in the same way as for vector.transfer_read;
+///   3. constant ops are splats of the super-vector type by construction.
+///      They are cloned to a splat on the hw vector type with the same value;
+///   4. remaining ops are cloned to version of the op that returns a hw vector
+///      type, all operands are substituted according to `substitutions`. Thanks
+///      to the topological order of a slice, the substitution is always
+///      possible.
+///
+/// Returns true on failure.
+static bool instantiateMaterialization(Operation *op,
+                                       MaterializationState *state) {
+  LLVM_DEBUG(dbgs() << "\ninstantiate: " << *op);
+
+  // Create a builder here for unroll-and-jam effects.
+  OpBuilder b(op);
+  // AffineApplyOp are ignored: instantiating the proper vector op will take
+  // care of AffineApplyOps by composing them properly.
+  if (isa<AffineApplyOp>(op)) {
+    return false;
+  }
+  if (op->getNumRegions() != 0)
+    return op->emitError("NYI path Op with region"), true;
+
+  if (auto write = dyn_cast<VectorTransferWriteOp>(op)) {
+    auto *clone = instantiate(b, write, state->hwVectorType,
+                              state->hwVectorInstance, state->substitutionsMap);
+    return clone == nullptr;
+  }
+  if (auto read = dyn_cast<VectorTransferReadOp>(op)) {
+    auto *clone = instantiate(b, read, state->hwVectorType,
+                              state->hwVectorInstance, state->substitutionsMap);
+    if (!clone) {
+      return true;
+    }
+    state->substitutionsMap->insert(
+        std::make_pair(read.getResult(), clone->getResult(0)));
+    return false;
+  }
+  // The only op with 0 results reaching this point must, by construction, be
+  // VectorTransferWriteOps and have been caught above. Ops with >= 2 results
+  // are not yet supported. So just support 1 result.
+  if (op->getNumResults() != 1) {
+    return op->emitError("NYI: ops with != 1 results"), true;
+  }
+  if (op->getResult(0)->getType() != state->superVectorType) {
+    return op->emitError("Op does not return a supervector."), true;
+  }
+  auto *clone =
+      instantiate(b, op, state->hwVectorType, state->substitutionsMap);
+  if (!clone) {
+    return true;
+  }
+  state->substitutionsMap->insert(
+      std::make_pair(op->getResult(0), clone->getResult(0)));
+  return false;
+}
+
+/// Takes a slice and rewrites the operations in it so that occurrences
+/// of `superVectorType` are replaced by `hwVectorType`.
+///
+/// Implementation
+/// ==============
+///   1. computes the shape ratio of super-vector to HW vector shapes. This
+///      gives for each op in the slice, how many instantiations are required
+///      in each dimension;
+///   2. performs the concrete materialization. Note that in a first
+///      implementation we use full unrolling because it pragmatically removes
+///      the need to explicitly materialize an AllocOp. Thanks to the properties
+///      of super-vectors, this unrolling is always possible and simple:
+///      vectorizing to a super-vector abstraction already achieved the
+///      equivalent of loop strip-mining + loop sinking and encoded this in the
+///      vector type.
+///
+/// Returns true on failure.
+///
+/// TODO(ntv): materialized allocs.
+/// TODO(ntv): full loops + materialized allocs.
+/// TODO(ntv): partial unrolling + materialized allocs.
+static bool emitSlice(MaterializationState *state,
+                      SetVector<Operation *> *slice) {
+  auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
+  assert(ratio.hasValue() &&
+         "ratio of super-vector to HW-vector shape is not integral");
+  // The number of integer points in a hyperrectangular region is:
+  // shape[0] * strides[0].
+  auto numValueToUnroll = (*ratio)[0] * makeStrides(*ratio)[0];
+  // Full unrolling to hardware vectors in a first approximation.
+  for (unsigned idx = 0; idx < numValueToUnroll; ++idx) {
+    // Fresh RAII instanceIndices and substitutionsMap.
+    MaterializationState scopedState = *state;
+    scopedState.hwVectorInstance = delinearize(idx, *ratio);
+    DenseMap<Value *, Value *> substitutionMap;
+    scopedState.substitutionsMap = &substitutionMap;
+    // slice are topologically sorted, we can just clone them in order.
+    for (auto *op : *slice) {
+      auto fail = instantiateMaterialization(op, &scopedState);
+      if (fail) {
+        op->emitError("Unhandled super-vector materialization failure");
+        return true;
+      }
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "\nFunction is now\n");
+  LLVM_DEBUG((*slice)[0]->getParentOfType<FuncOp>().print(dbgs()));
+
+  // slice are topologically sorted, we can just erase them in reverse
+  // order. Reverse iterator does not just work simply with an operator*
+  // dereference.
+  for (int idx = slice->size() - 1; idx >= 0; --idx) {
+    LLVM_DEBUG(dbgs() << "\nErase: ");
+    LLVM_DEBUG((*slice)[idx]->print(dbgs()));
+    (*slice)[idx]->erase();
+  }
+  return false;
+}
+
+/// Materializes super-vector types into concrete hw vector types as follows:
+///   1. start from super-vector terminators (current vector.transfer_write
+///      ops);
+///   2. collect all the operations that can be reached by transitive use-defs
+///      chains;
+///   3. get the superVectorType for this particular terminator and the
+///      corresponding hardware vector type (for now limited to F32)
+///      TODO(ntv): be more general than F32.
+///   4. emit the transitive useDef set to operate on the finer-grain vector
+///      types.
+///
+/// Notes
+/// =====
+/// The `slice` is sorted in topological order by construction.
+/// Additionally, this set is limited to operations in the same lexical scope
+/// because we currently disallow vectorization of defs that come from another
+/// scope.
+/// TODO(ntv): please document return value.
+static bool materialize(FuncOp f, const SetVector<Operation *> &terminators,
+                        MaterializationState *state) {
+  DenseSet<Operation *> seen;
+  DominanceInfo domInfo(f);
+  for (auto *term : terminators) {
+    // Short-circuit test, a given terminator may have been reached by some
+    // other previous transitive use-def chains.
+    if (seen.count(term) > 0) {
+      continue;
+    }
+
+    auto terminator = cast<VectorTransferWriteOp>(term);
+    LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
+
+    // Get the transitive use-defs starting from terminator, limited to the
+    // current enclosing scope of the terminator. See the top of the function
+    // Note for the justification of this restriction.
+    // TODO(ntv): relax scoping constraints.
+    auto *enclosingScope = term->getParentOp();
+    auto keepIfInSameScope = [enclosingScope, &domInfo](Operation *op) {
+      assert(op && "NULL op");
+      if (!enclosingScope) {
+        // by construction, everyone is always under the top scope (null scope).
+        return true;
+      }
+      return domInfo.properlyDominates(enclosingScope, op);
+    };
+    SetVector<Operation *> slice =
+        getSlice(term, keepIfInSameScope, keepIfInSameScope);
+    assert(!slice.empty());
+
+    // Sanity checks: transitive slice must be completely disjoint from
+    // what we have seen so far.
+    LLVM_DEBUG(dbgs() << "\nTransitive use-defs:");
+    for (auto *ud : slice) {
+      LLVM_DEBUG(dbgs() << "\nud:" << *ud);
+      assert(seen.count(ud) == 0 &&
+             "Transitive use-defs not disjoint from already seen");
+      seen.insert(ud);
+    }
+
+    // Emit the current slice.
+    // Set scoped super-vector and corresponding hw vector types.
+    state->superVectorType = terminator.getVectorType();
+    assert((state->superVectorType.getElementType() ==
+            FloatType::getF32(term->getContext())) &&
+           "Only f32 supported for now");
+    state->hwVectorType = VectorType::get(
+        state->hwVectorSize, state->superVectorType.getElementType());
+    auto fail = emitSlice(state, &slice);
+    if (fail) {
+      return true;
+    }
+    LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
+    LLVM_DEBUG(f.print(dbgs()));
+  }
+  return false;
+}
+
+void MaterializeVectorsPass::runOnFunction() {
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
+  // TODO(ntv): Check to see if this supports arbitrary top-level code.
+  FuncOp f = getFunction();
+  if (f.getBlocks().size() != 1)
+    return;
+
+  using matcher::Op;
+  LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
+  LLVM_DEBUG(f.print(dbgs()));
+
+  MaterializationState state(hwVectorSize);
+  // Get the hardware vector type.
+  // TODO(ntv): get elemental type from super-vector type rather than force f32.
+  auto subVectorType =
+      VectorType::get(hwVectorSize, FloatType::getF32(&getContext()));
+
+  // Capture terminators; i.e. vector.transfer_write ops involving a strict
+  // super-vector of subVectorType.
+  auto filter = [subVectorType](Operation &op) {
+    if (!isa<VectorTransferWriteOp>(op)) {
+      return false;
+    }
+    return matcher::operatesOnSuperVectorsOf(op, subVectorType);
+  };
+  auto pat = Op(filter);
+  SmallVector<NestedMatch, 8> matches;
+  pat.match(f, &matches);
+  SetVector<Operation *> terminators;
+  for (auto m : matches) {
+    terminators.insert(m.getMatchedOperation());
+  }
+
+  if (materialize(f, terminators, &state))
+    signalPassFailure();
+}
+
+FunctionPassBase *
+mlir::createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize) {
+  return new MaterializeVectorsPass(vectorSize);
+}
+
+static PassRegistration<MaterializeVectorsPass>
+    pass("affine-materialize-vectors",
+         "Materializes super-vectors to vectors of the "
+         "proper size for the hardware");
+
+#undef DEBUG_TYPE
diff --git a/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
new file mode 100644
index 0000000..4f8b1c6
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -0,0 +1,260 @@
+//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to forward memref stores to loads, thereby
+// potentially getting rid of intermediate memref's entirely.
+// TODO(mlir-team): In the future, similar techniques could be used to eliminate
+// dead memref store's and perform more complex forwarding when support for
+// SSA scalars live out of 'affine.for'/'affine.if' statements is available.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "memref-dataflow-opt"
+
+using namespace mlir;
+
+namespace {
+
+// The store to load forwarding relies on three conditions:
+//
+// 1) there has to be a dependence from the store to the load satisfied at the
+// block* immediately within the innermost loop enclosing both the load op and
+// the store op,
+//
+// 2) the store op should dominate the load op,
+//
+// 3) among all candidate store op's that satisfy (1) and (2), if there exists a
+// store op that postdominates all those that satisfy (1), such a store op is
+// provably the last writer to the particular memref location being loaded from
+// by the load op, and its store value can be forwarded to the load.
+//
+// 4) the load should touch a single location in the memref for a given
+// iteration of the innermost loop enclosing both the store op and the load op.
+//
+// (* A dependence being satisfied at a block: a dependence that is satisfied by
+// virtue of the destination operation appearing textually / lexically after
+// the source operation within the body of a 'affine.for' operation; thus, a
+// dependence is always either satisfied by a loop or by a block).
+//
+// The above conditions are simple to check, sufficient, and powerful for most
+// cases in practice - condition (1) and (3) are precise and necessary, while
+// condition (2) is a sufficient one but not necessary (since it doesn't reason
+// about loops that are guaranteed to execute at least once).
+//
+// TODO(mlir-team): more forwarding can be done when support for
+// loop/conditional live-out SSA values is available.
+// TODO(mlir-team): do general dead store elimination for memref's. This pass
+// currently only eliminates the stores only if no other loads/uses (other
+// than dealloc) remain.
+//
+struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> {
+  void runOnFunction() override;
+
+  void forwardStoreToLoad(AffineLoadOp loadOp);
+
+  // A list of memref's that are potentially dead / could be eliminated.
+  SmallPtrSet<Value *, 4> memrefsToErase;
+  // Load op's whose results were replaced by those forwarded from stores.
+  std::vector<Operation *> loadOpsToErase;
+
+  DominanceInfo *domInfo = nullptr;
+  PostDominanceInfo *postDomInfo = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to perform optimizations relying on memref dataflow such as
+/// store to load forwarding, elimination of dead stores, and dead allocs.
+FunctionPassBase *mlir::createMemRefDataFlowOptPass() {
+  return new MemRefDataFlowOpt();
+}
+
+// This is a straightforward implementation not optimized for speed. Optimize
+// this in the future if needed.
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+  Operation *lastWriteStoreOp = nullptr;
+  Operation *loadOpInst = loadOp.getOperation();
+
+  // First pass over the use list to get minimum number of surrounding
+  // loops common between the load op and the store op, with min taken across
+  // all store ops.
+  SmallVector<Operation *, 8> storeOps;
+  unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
+  for (auto *user : loadOp.getMemRef()->getUsers()) {
+    auto storeOp = dyn_cast<AffineStoreOp>(user);
+    if (!storeOp)
+      continue;
+    auto *storeOpInst = storeOp.getOperation();
+    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+    minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+    storeOps.push_back(storeOpInst);
+  }
+
+  unsigned loadOpDepth = getNestingDepth(*loadOpInst);
+
+  // 1. Check if there is a dependence satisfied at depth equal to the depth
+  // of the loop body of the innermost common surrounding loop of the storeOp
+  // and loadOp.
+  // The list of store op candidates for forwarding - need to satisfy the
+  // conditions listed at the top.
+  SmallVector<Operation *, 8> fwdingCandidates;
+  // Store ops that have a dependence into the load (even if they aren't
+  // forwarding candidates). Each forwarding candidate will be checked for a
+  // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
+  SmallVector<Operation *, 8> depSrcStores;
+  for (auto *storeOpInst : storeOps) {
+    MemRefAccess srcAccess(storeOpInst);
+    MemRefAccess destAccess(loadOpInst);
+    FlatAffineConstraints dependenceConstraints;
+    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, destAccess, d, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (!hasDependence(result))
+        continue;
+      depSrcStores.push_back(storeOpInst);
+      // Check if this store is a candidate for forwarding; we only forward if
+      // the dependence from the store is carried by the *body* of innermost
+      // common surrounding loop. As an example this filters out cases like:
+      // affine.for %i0
+      //   affine.for %i1
+      //     %idx = affine.apply (d0) -> (d0 + 1) (%i0)
+      //     store %A[%idx]
+      //     load %A[%i0]
+      //
+      if (d != nsLoops + 1)
+        break;
+
+      // 2. The store has to dominate the load op to be candidate. This is not
+      // strictly a necessary condition since dominance isn't a prerequisite for
+      // a memref element store to reach a load, but this is sufficient and
+      // reasonably powerful in practice.
+      if (!domInfo->dominates(storeOpInst, loadOpInst))
+        break;
+
+      // Finally, forwarding is only possible if the load touches a single
+      // location in the memref across the enclosing loops *not* common with the
+      // store. This is filtering out cases like:
+      // for (i ...)
+      //   a [i] = ...
+      //   for (j ...)
+      //      ... = a[j]
+      // If storeOpInst and loadOpDepth at the same nesting depth, the load Op
+      // is trivially loading from a single location at that depth; so there
+      // isn't a need to call isRangeOneToOne.
+      if (getNestingDepth(*storeOpInst) < loadOpDepth) {
+        MemRefRegion region(loadOpInst->getLoc());
+        region.compute(loadOpInst, nsLoops);
+        if (!region.getConstraints()->isRangeOneToOne(
+                /*start=*/0, /*limit=*/loadOp.getMemRefType().getRank()))
+          break;
+      }
+
+      // After all these conditions, we have a candidate for forwarding!
+      fwdingCandidates.push_back(storeOpInst);
+      break;
+    }
+  }
+
+  // Note: this can implemented in a cleaner way with postdominator tree
+  // traversals. Consider this for the future if needed.
+  for (auto *storeOpInst : fwdingCandidates) {
+    // 3. Of all the store op's that meet the above criteria, the store
+    // that postdominates all 'depSrcStores' (if such a store exists) is the
+    // unique store providing the value to the load, i.e., provably the last
+    // writer to that memref loc.
+    if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
+          return postDomInfo->postDominates(storeOpInst, depStore);
+        })) {
+      lastWriteStoreOp = storeOpInst;
+      break;
+    }
+  }
+  // TODO: optimization for future: those store op's that are determined to be
+  // postdominated above can actually be recorded and skipped on the 'i' loop
+  // iteration above --- since they can never post dominate everything.
+
+  if (!lastWriteStoreOp)
+    return;
+
+  // Perform the actual store to load forwarding.
+  Value *storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
+  loadOp.replaceAllUsesWith(storeVal);
+  // Record the memref for a later sweep to optimize away.
+  memrefsToErase.insert(loadOp.getMemRef());
+  // Record this to erase later.
+  loadOpsToErase.push_back(loadOpInst);
+}
+
+void MemRefDataFlowOpt::runOnFunction() {
+  // Only supports single block functions at the moment.
+  FuncOp f = getFunction();
+  if (f.getBlocks().size() != 1) {
+    markAllAnalysesPreserved();
+    return;
+  }
+
+  domInfo = &getAnalysis<DominanceInfo>();
+  postDomInfo = &getAnalysis<PostDominanceInfo>();
+
+  loadOpsToErase.clear();
+  memrefsToErase.clear();
+
+  // Walk all load's and perform load/store forwarding.
+  f.walk<AffineLoadOp>(
+      [&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+
+  // Erase all load op's whose results were replaced with store fwd'ed ones.
+  for (auto *loadOp : loadOpsToErase) {
+    loadOp->erase();
+  }
+
+  // Check if the store fwd'ed memrefs are now left with only stores and can
+  // thus be completely deleted. Note: the canononicalize pass should be able
+  // to do this as well, but we'll do it here since we collected these anyway.
+  for (auto *memref : memrefsToErase) {
+    // If the memref hasn't been alloc'ed in this function, skip.
+    Operation *defInst = memref->getDefiningOp();
+    if (!defInst || !isa<AllocOp>(defInst))
+      // TODO(mlir-team): if the memref was returned by a 'call' operation, we
+      // could still erase it if the call had no side-effects.
+      continue;
+    if (llvm::any_of(memref->getUsers(), [&](Operation *ownerInst) {
+          return (!isa<AffineStoreOp>(ownerInst) && !isa<DeallocOp>(ownerInst));
+        }))
+      continue;
+
+    // Erase all stores, the dealloc, and the alloc on the memref.
+    for (auto *user : llvm::make_early_inc_range(memref->getUsers()))
+      user->erase();
+    defInst->erase();
+  }
+}
+
+static PassRegistration<MemRefDataFlowOpt>
+    pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs");
diff --git a/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp b/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp
new file mode 100644
index 0000000..af456c3
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -0,0 +1,382 @@
+//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to pipeline data transfers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "affine-pipeline-data-transfer"
+
+using namespace mlir;
+
+namespace {
+
+struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
+  void runOnFunction() override;
+  void runOnAffineForOp(AffineForOp forOp);
+
+  std::vector<AffineForOp> forOps;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to pipeline explicit movement of data across levels of the
+/// memory hierarchy.
+FunctionPassBase *mlir::createPipelineDataTransferPass() {
+  return new PipelineDataTransfer();
+}
+
+// Returns the position of the tag memref operand given a DMA operation.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added.  TODO(b/117228571)
+static unsigned getTagMemRefPos(Operation &dmaInst) {
+  assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
+  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
+    return dmaStartOp.getTagMemRefOperandIndex();
+  }
+  // First operand for a dma finish operation.
+  return 0;
+}
+
+/// Doubles the buffer of the supplied memref on the specified 'affine.for'
+/// operation by adding a leading dimension of size two to the memref.
+/// Replaces all uses of the old memref by the new one while indexing the newly
+/// added dimension by the loop IV of the specified 'affine.for' operation
+/// modulo 2. Returns false if such a replacement cannot be performed.
+static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
+  auto *forBody = forOp.getBody();
+  OpBuilder bInner(forBody, forBody->begin());
+  bInner.setInsertionPoint(forBody, forBody->begin());
+
+  // Doubles the shape with a leading dimension extent of 2.
+  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
+    // Add the leading dimension in the shape for the double buffer.
+    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
+    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
+    newShape[0] = 2;
+    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
+    auto newMemRefType =
+        bInner.getMemRefType(newShape, oldMemRefType.getElementType(), {},
+                             oldMemRefType.getMemorySpace());
+    return newMemRefType;
+  };
+
+  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+  auto newMemRefType = doubleShape(oldMemRefType);
+
+  // The double buffer is allocated right before 'forInst'.
+  auto *forInst = forOp.getOperation();
+  OpBuilder bOuter(forInst);
+  // Put together alloc operands for any dynamic dimensions of the memref.
+  SmallVector<Value *, 4> allocOperands;
+  unsigned dynamicDimCount = 0;
+  for (auto dimSize : oldMemRefType.getShape()) {
+    if (dimSize == -1)
+      allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
+                                                   dynamicDimCount++));
+  }
+
+  // Create and place the alloc right before the 'affine.for' operation.
+  Value *newMemRef =
+      bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+
+  // Create 'iv mod 2' value to index the leading dimension.
+  auto d0 = bInner.getAffineDimExpr(0);
+  int64_t step = forOp.getStep();
+  auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
+                                       {d0.floorDiv(step) % 2});
+  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
+                                                 forOp.getInductionVar());
+
+  // replaceAllMemRefUsesWith will always succeed unless the forOp body has
+  // non-deferencing uses of the memref (dealloc's are fine though).
+  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
+                                /*extraIndices=*/{ivModTwoOp},
+                                /*indexRemap=*/AffineMap(),
+                                /*extraOperands=*/{},
+                                /*domInstFilter=*/&*forOp.getBody()->begin())) {
+    LLVM_DEBUG(
+        forOp.emitError("memref replacement for double buffering failed"));
+    ivModTwoOp.erase();
+    return false;
+  }
+  // Insert the dealloc op right after the for loop.
+  bOuter.setInsertionPoint(forInst->getBlock(),
+                           std::next(Block::iterator(forInst)));
+  bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);
+
+  return true;
+}
+
+/// Returns success if the IR is in a valid state.
+void PipelineDataTransfer::runOnFunction() {
+  // Do a post order walk so that inner loop DMAs are processed first. This is
+  // necessary since 'affine.for' operations nested within would otherwise
+  // become invalid (erased) when the outer loop is pipelined (the pipelined one
+  // gets deleted and replaced by a prologue, a new steady-state loop and an
+  // epilogue).
+  forOps.clear();
+  getFunction().walk<AffineForOp>(
+      [&](AffineForOp forOp) { forOps.push_back(forOp); });
+  for (auto forOp : forOps)
+    runOnAffineForOp(forOp);
+}
+
+// Check if tags of the dma start op and dma wait op match.
+static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
+  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
+    return false;
+  auto startIndices = startOp.getTagIndices();
+  auto waitIndices = waitOp.getTagIndices();
+  // Both of these have the same number of indices since they correspond to the
+  // same tag memref.
+  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+            e = startIndices.end();
+       it != e; ++it, ++wIt) {
+    // Keep it simple for now, just checking if indices match.
+    // TODO(mlir-team): this would in general need to check if there is no
+    // intervening write writing to the same tag location, i.e., memory last
+    // write/data flow analysis. This is however sufficient/powerful enough for
+    // now since the DMA generation pass or the input for it will always have
+    // start/wait with matching tags (same SSA operand indices).
+    if (*it != *wIt)
+      return false;
+  }
+  return true;
+}
+
+// Identify matching DMA start/finish operations to overlap computation with.
+static void findMatchingStartFinishInsts(
+    AffineForOp forOp,
+    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
+
+  // Collect outgoing DMA operations - needed to check for dependences below.
+  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
+  for (auto &op : *forOp.getBody()) {
+    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
+      outgoingDmaOps.push_back(dmaStartOp);
+  }
+
+  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
+  for (auto &op : *forOp.getBody()) {
+    // Collect DMA finish operations.
+    if (isa<AffineDmaWaitOp>(op)) {
+      dmaFinishInsts.push_back(&op);
+      continue;
+    }
+    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+    if (!dmaStartOp)
+      continue;
+
+    // Only DMAs incoming into higher memory spaces are pipelined for now.
+    // TODO(bondhugula): handle outgoing DMA pipelining.
+    if (!dmaStartOp.isDestMemorySpaceFaster())
+      continue;
+
+    // Check for dependence with outgoing DMAs. Doing this conservatively.
+    // TODO(andydavis,bondhugula): use the dependence analysis to check for
+    // dependences between an incoming and outgoing DMA in the same iteration.
+    auto it = outgoingDmaOps.begin();
+    for (; it != outgoingDmaOps.end(); ++it) {
+      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
+        break;
+    }
+    if (it != outgoingDmaOps.end())
+      continue;
+
+    // We only double buffer if the buffer is not live out of loop.
+    auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
+    bool escapingUses = false;
+    for (auto *user : memref->getUsers()) {
+      // We can double buffer regardless of dealloc's outside the loop.
+      if (isa<DeallocOp>(user))
+        continue;
+      if (!forOp.getBody()->findAncestorInstInBlock(*user)) {
+        LLVM_DEBUG(llvm::dbgs()
+                       << "can't pipeline: buffer is live out of loop\n";);
+        escapingUses = true;
+        break;
+      }
+    }
+    if (!escapingUses)
+      dmaStartInsts.push_back(&op);
+  }
+
+  // For each start operation, we look for a matching finish operation.
+  for (auto *dmaStartInst : dmaStartInsts) {
+    for (auto *dmaFinishInst : dmaFinishInsts) {
+      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
+                        cast<AffineDmaWaitOp>(dmaFinishInst))) {
+        startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
+        break;
+      }
+    }
+  }
+}
+
+/// Overlap DMA transfers with computation in this loop. If successful,
+/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// inserted right before where it was.
+void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
+  auto mayBeConstTripCount = getConstantTripCount(forOp);
+  if (!mayBeConstTripCount.hasValue()) {
+    LLVM_DEBUG(
+        forOp.emitRemark("won't pipeline due to unknown trip count loop"));
+    return;
+  }
+
+  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
+  findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+  if (startWaitPairs.empty()) {
+    LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
+    return;
+  }
+
+  // Double the buffers for the higher memory space memref's.
+  // Identify memref's to replace by scanning through all DMA start
+  // operations. A DMA start operation has two memref's - the one from the
+  // higher level of memory hierarchy is the one to double buffer.
+  // TODO(bondhugula): check whether double-buffering is even necessary.
+  // TODO(bondhugula): make this work with different layouts: assuming here that
+  // the dimension we are adding here for the double buffering is the outermost
+  // dimension.
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartInst = pair.first;
+    Value *oldMemRef = dmaStartInst->getOperand(
+        cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
+    if (!doubleBuffer(oldMemRef, forOp)) {
+      // Normally, double buffering should not fail because we already checked
+      // that there are no uses outside.
+      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
+      LLVM_DEBUG(dmaStartInst->dump());
+      // IR still in a valid state.
+      return;
+    }
+    // If the old memref has no more uses, remove its 'dead' alloc if it was
+    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
+    // operation could have been used on it if it was dynamically shaped in
+    // order to create the double buffer above.)
+    // '-canonicalize' does this in a more general way, but we'll anyway do the
+    // simple/common case so that the output / test cases looks clear.
+    if (auto *allocInst = oldMemRef->getDefiningOp()) {
+      if (oldMemRef->use_empty()) {
+        allocInst->erase();
+      } else if (oldMemRef->hasOneUse()) {
+        if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
+          dealloc.erase();
+          oldMemRef->getDefiningOp()->erase();
+        }
+      }
+    }
+  }
+
+  // Double the buffers for tag memrefs.
+  for (auto &pair : startWaitPairs) {
+    auto *dmaFinishInst = pair.second;
+    Value *oldTagMemRef =
+        dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+    if (!doubleBuffer(oldTagMemRef, forOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+      return;
+    }
+    // If the old tag has no more uses, remove its 'dead' alloc if it was
+    // alloc'ed.
+    if (oldTagMemRef->use_empty())
+      if (auto *allocInst = oldTagMemRef->getDefiningOp())
+        allocInst->erase();
+  }
+
+  // Double buffering would have invalidated all the old DMA start/wait insts.
+  startWaitPairs.clear();
+  findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+  // Store shift for operation for later lookup for AffineApplyOp's.
+  DenseMap<Operation *, unsigned> instShiftMap;
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartInst = pair.first;
+    assert(isa<AffineDmaStartOp>(dmaStartInst));
+    instShiftMap[dmaStartInst] = 0;
+    // Set shifts for DMA start op's affine operand computation slices to 0.
+    SmallVector<AffineApplyOp, 4> sliceOps;
+    mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
+    if (!sliceOps.empty()) {
+      for (auto sliceOp : sliceOps) {
+        instShiftMap[sliceOp.getOperation()] = 0;
+      }
+    } else {
+      // If a slice wasn't created, the reachable affine.apply op's from its
+      // operands are the ones that go with it.
+      SmallVector<Operation *, 4> affineApplyInsts;
+      SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
+      getReachableAffineApplyOps(operands, affineApplyInsts);
+      for (auto *op : affineApplyInsts) {
+        instShiftMap[op] = 0;
+      }
+    }
+  }
+  // Everything else (including compute ops and dma finish) are shifted by one.
+  for (auto &op : *forOp.getBody()) {
+    if (instShiftMap.find(&op) == instShiftMap.end()) {
+      instShiftMap[&op] = 1;
+    }
+  }
+
+  // Get shifts stored in map.
+  std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
+  unsigned s = 0;
+  for (auto &op : *forOp.getBody()) {
+    assert(instShiftMap.find(&op) != instShiftMap.end());
+    shifts[s++] = instShiftMap[&op];
+
+    // Tagging operations with shifts for debugging purposes.
+    LLVM_DEBUG({
+      OpBuilder b(&op);
+      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
+    });
+  }
+
+  if (!isInstwiseShiftValid(forOp, shifts)) {
+    // Violates dependences.
+    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
+    return;
+  }
+
+  if (failed(instBodySkew(forOp, shifts))) {
+    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
+    return;
+  }
+}
+
+static PassRegistration<PipelineDataTransfer> pass(
+    "affine-pipeline-data-transfer",
+    "Pipeline non-blocking data transfers between explicitly managed levels of "
+    "the memory hierarchy");
diff --git a/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp
new file mode 100644
index 0000000..3b6c231
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -0,0 +1,108 @@
+//===- SimplifyAffineStructures.cpp ---------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to simplify affine structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#define DEBUG_TYPE "simplify-affine-structure"
+
+using namespace mlir;
+
+namespace {
+
+/// Simplifies all affine expressions appearing in the operations of
+/// the Function. This is mainly to test the simplifyAffineExpr method.
+/// TODO(someone): This should just be defined as a canonicalization pattern
+/// on AffineMap and driven from the existing canonicalization pass.
+struct SimplifyAffineStructures
+    : public FunctionPass<SimplifyAffineStructures> {
+  void runOnFunction() override;
+
+  /// Utility to simplify an affine attribute and update its entry in the parent
+  /// operation if necessary.
+  template <typename AttributeT>
+  void simplifyAndUpdateAttribute(Operation *op, Identifier name,
+                                  AttributeT attr) {
+    auto &simplified = simplifiedAttributes[attr];
+    if (simplified == attr)
+      return;
+
+    // This is a newly encountered attribute.
+    if (!simplified) {
+      // Try to simplify the value of the attribute.
+      auto value = attr.getValue();
+      auto simplifiedValue = simplify(value);
+      if (simplifiedValue == value) {
+        simplified = attr;
+        return;
+      }
+      simplified = AttributeT::get(simplifiedValue);
+    }
+
+    // Simplification was successful, so update the attribute.
+    op->setAttr(name, simplified);
+  }
+
+  /// Performs basic integer set simplifications. Checks if it's empty, and
+  /// replaces it with the canonical empty set if it is.
+  IntegerSet simplify(IntegerSet set) {
+    FlatAffineConstraints fac(set);
+    if (fac.isEmpty())
+      return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
+                                     &getContext());
+    return set;
+  }
+
+  /// Performs basic affine map simplifications.
+  AffineMap simplify(AffineMap map) {
+    MutableAffineMap mMap(map);
+    mMap.simplify();
+    return mMap.getAffineMap();
+  }
+
+  DenseMap<Attribute, Attribute> simplifiedAttributes;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createSimplifyAffineStructuresPass() {
+  return new SimplifyAffineStructures();
+}
+
+void SimplifyAffineStructures::runOnFunction() {
+  simplifiedAttributes.clear();
+  getFunction().walk([&](Operation *opInst) {
+    for (auto attr : opInst->getAttrs()) {
+      if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
+        simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
+      else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
+        simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
+    }
+  });
+}
+
+static PassRegistration<SimplifyAffineStructures>
+    pass("simplify-affine-structures", "Simplify affine expressions");
diff --git a/third_party/mlir/lib/Transforms/StripDebugInfo.cpp b/third_party/mlir/lib/Transforms/StripDebugInfo.cpp
new file mode 100644
index 0000000..c82354e
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -0,0 +1,46 @@
+//===- StripDebugInfo.cpp - Pass to strip debug information ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void StripDebugInfo::runOnFunction() {
+  FuncOp func = getFunction();
+  auto unknownLoc = UnknownLoc::get(&getContext());
+
+  // Strip the debug info from the function and its operations.
+  func.setLoc(unknownLoc);
+  func.walk([&](Operation *op) { op->setLoc(unknownLoc); });
+}
+
+/// Creates a pass to strip debug information from a function.
+FunctionPassBase *mlir::createStripDebugInfoPass() {
+  return new StripDebugInfo();
+}
+
+static PassRegistration<StripDebugInfo>
+    pass("strip-debuginfo", "Strip debug info from functions and operations");
diff --git a/third_party/mlir/lib/Transforms/Utils/CMakeLists.txt b/third_party/mlir/lib/Transforms/Utils/CMakeLists.txt
new file mode 100644
index 0000000..3c08f45
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_llvm_library(MLIRTransformUtils
+  FoldUtils.cpp
+  GreedyPatternRewriteDriver.cpp
+  LoopFusionUtils.cpp
+  LoopUtils.cpp
+  RegionUtils.cpp
+  Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+  )
+
+add_dependencies(MLIRTransformUtils MLIRStandardOpsIncGen)
+target_link_libraries(MLIRTransformUtils
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIRLoopOps
+  MLIRPass
+  MLIRStandardOps
+  )
diff --git a/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp b/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp
new file mode 100644
index 0000000..435ea85
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -0,0 +1,248 @@
+//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines various operation fold utilities. These utilities are
+// intended to be used by passes to unify and simply their logic.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+/// Given an operation, find the parent region that folded constants should be
+/// inserted into.
+static Region *getInsertionRegion(Operation *op) {
+  while (Region *region = op->getParentRegion()) {
+    // Insert in this region for any of the following scenarios:
+    //  * The parent is unregistered, or is known to be isolated from above.
+    //  * The parent is a top-level operation.
+    auto *parentOp = region->getParentOp();
+    if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
+        !parentOp->getBlock())
+      return region;
+    // Traverse up the parent looking for an insertion region.
+    op = parentOp;
+  }
+  llvm_unreachable("expected valid insertion region");
+}
+
+/// A utility function used to materialize a constant for a given attribute and
+/// type. On success, a valid constant value is returned. Otherwise, null is
+/// returned
+static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
+                                      Attribute value, Type type,
+                                      Location loc) {
+  auto insertPt = builder.getInsertionPoint();
+  (void)insertPt;
+
+  // Ask the dialect to materialize a constant operation for this value.
+  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
+    assert(insertPt == builder.getInsertionPoint());
+    assert(matchPattern(constOp, m_Constant(&value)));
+    return constOp;
+  }
+
+  // If the dialect is unable to materialize a constant, check to see if the
+  // standard constant can be used.
+  if (ConstantOp::isBuildableWith(value, type))
+    return builder.create<ConstantOp>(loc, type, value);
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// OperationFolder
+//===----------------------------------------------------------------------===//
+
+LogicalResult OperationFolder::tryToFold(
+    Operation *op,
+    llvm::function_ref<void(Operation *)> processGeneratedConstants,
+    llvm::function_ref<void(Operation *)> preReplaceAction) {
+  // If this is a unique'd constant, return failure as we know that it has
+  // already been folded.
+  if (referencedDialects.count(op))
+    return failure();
+
+  // Try to fold the operation.
+  SmallVector<Value *, 8> results;
+  if (failed(tryToFold(op, results, processGeneratedConstants)))
+    return failure();
+
+  // Constant folding succeeded. We will start replacing this op's uses and
+  // eventually erase this op. Invoke the callback provided by the caller to
+  // perform any pre-replacement action.
+  if (preReplaceAction)
+    preReplaceAction(op);
+
+  // Check to see if the operation was just updated in place.
+  if (results.empty())
+    return success();
+
+  // Otherwise, replace all of the result values and erase the operation.
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    op->getResult(i)->replaceAllUsesWith(results[i]);
+  op->erase();
+  return success();
+}
+
+/// Notifies that the given constant `op` should be remove from this
+/// OperationFolder's internal bookkeeping.
+void OperationFolder::notifyRemoval(Operation *op) {
+  // Check to see if this operation is uniqued within the folder.
+  auto it = referencedDialects.find(op);
+  if (it == referencedDialects.end())
+    return;
+
+  // Get the constant value for this operation, this is the value that was used
+  // to unique the operation internally.
+  Attribute constValue;
+  matchPattern(op, m_Constant(&constValue));
+  assert(constValue);
+
+  // Get the constant map that this operation was uniqued in.
+  auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
+
+  // Erase all of the references to this operation.
+  auto type = op->getResult(0)->getType();
+  for (auto *dialect : it->second)
+    uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
+  referencedDialects.erase(it);
+}
+
+/// Tries to perform folding on the given `op`. If successful, populates
+/// `results` with the results of the folding.
+LogicalResult OperationFolder::tryToFold(
+    Operation *op, SmallVectorImpl<Value *> &results,
+    llvm::function_ref<void(Operation *)> processGeneratedConstants) {
+  SmallVector<Attribute, 8> operandConstants;
+  SmallVector<OpFoldResult, 8> foldResults;
+
+  // Check to see if any operands to the operation is constant and whether
+  // the operation knows how to constant fold itself.
+  operandConstants.assign(op->getNumOperands(), Attribute());
+  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+    matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
+
+  // If this is a commutative binary operation with a constant on the left
+  // side move it to the right side.
+  if (operandConstants.size() == 2 && operandConstants[0] &&
+      !operandConstants[1] && op->isCommutative()) {
+    std::swap(op->getOpOperand(0), op->getOpOperand(1));
+    std::swap(operandConstants[0], operandConstants[1]);
+  }
+
+  // Attempt to constant fold the operation.
+  if (failed(op->fold(operandConstants, foldResults)))
+    return failure();
+
+  // Check to see if the operation was just updated in place.
+  if (foldResults.empty())
+    return success();
+  assert(foldResults.size() == op->getNumResults());
+
+  // Create a builder to insert new operations into the entry block of the
+  // insertion region.
+  auto *insertionRegion = getInsertionRegion(op);
+  auto &entry = insertionRegion->front();
+  OpBuilder builder(&entry, entry.begin());
+
+  // Get the constant map for the insertion region of this operation.
+  auto &uniquedConstants = foldScopes[insertionRegion];
+
+  // Create the result constants and replace the results.
+  auto *dialect = op->getDialect();
+  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+    assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
+
+    // Check if the result was an SSA value.
+    if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
+      results.emplace_back(repl);
+      continue;
+    }
+
+    // Check to see if there is a canonicalized version of this constant.
+    auto *res = op->getResult(i);
+    Attribute attrRepl = foldResults[i].get<Attribute>();
+    if (auto *constOp =
+            tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
+                                   res->getType(), op->getLoc())) {
+      results.push_back(constOp->getResult(0));
+      continue;
+    }
+    // If materialization fails, cleanup any operations generated for the
+    // previous results and return failure.
+    for (Operation &op : llvm::make_early_inc_range(
+             llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
+      notifyRemoval(&op);
+      op.erase();
+    }
+    return failure();
+  }
+
+  // Process any newly generated operations.
+  if (processGeneratedConstants) {
+    for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
+      processGeneratedConstants(&*i);
+  }
+
+  return success();
+}
+
+/// Try to get or create a new constant entry. On success this returns the
+/// constant operation value, nullptr otherwise.
+Operation *OperationFolder::tryGetOrCreateConstant(
+    ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
+    Attribute value, Type type, Location loc) {
+  // Check if an existing mapping already exists.
+  auto constKey = std::make_tuple(dialect, value, type);
+  auto *&constInst = uniquedConstants[constKey];
+  if (constInst)
+    return constInst;
+
+  // If one doesn't exist, try to materialize one.
+  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+    return nullptr;
+
+  // Check to see if the generated constant is in the expected dialect.
+  auto *newDialect = constInst->getDialect();
+  if (newDialect == dialect) {
+    referencedDialects[constInst].push_back(dialect);
+    return constInst;
+  }
+
+  // If it isn't, then we also need to make sure that the mapping for the new
+  // dialect is valid.
+  auto newKey = std::make_tuple(newDialect, value, type);
+
+  // If an existing operation in the new dialect already exists, delete the
+  // materialized operation in favor of the existing one.
+  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
+    constInst->erase();
+    referencedDialects[existingOp].push_back(dialect);
+    return constInst = existingOp;
+  }
+
+  // Otherwise, update the new dialect to the materialized operation.
+  referencedDialects[constInst].assign({dialect, newDialect});
+  auto newIt = uniquedConstants.insert({newKey, constInst});
+  return newIt.first->second;
+}
diff --git a/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
new file mode 100644
index 0000000..d202a37
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -0,0 +1,238 @@
+//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements mlir::applyPatternsGreedily.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "pattern-matcher"
+
+static llvm::cl::opt<unsigned> maxPatternMatchIterations(
+    "mlir-max-pattern-match-iterations",
+    llvm::cl::desc("Max number of iterations scanning for pattern match"),
+    llvm::cl::init(10));
+
+namespace {
+
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver : public PatternRewriter {
+public:
+  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+                                      OwningRewritePatternList &patterns)
+      : PatternRewriter(ctx), matcher(patterns) {
+    worklist.reserve(64);
+  }
+
+  /// Perform the rewrites. Return true if the rewrite converges in
+  /// `maxIterations`.
+  bool simplify(Operation *op, int maxIterations);
+
+  void addToWorklist(Operation *op) {
+    // Check to see if the worklist already contains this op.
+    if (worklistMap.count(op))
+      return;
+
+    worklistMap[op] = worklist.size();
+    worklist.push_back(op);
+  }
+
+  Operation *popFromWorklist() {
+    auto *op = worklist.back();
+    worklist.pop_back();
+
+    // This operation is no longer in the worklist, keep worklistMap up to date.
+    if (op)
+      worklistMap.erase(op);
+    return op;
+  }
+
+  /// If the specified operation is in the worklist, remove it.  If not, this is
+  /// a no-op.
+  void removeFromWorklist(Operation *op) {
+    auto it = worklistMap.find(op);
+    if (it != worklistMap.end()) {
+      assert(worklist[it->second] == op && "malformed worklist data structure");
+      worklist[it->second] = nullptr;
+    }
+  }
+
+  // These are hooks implemented for PatternRewriter.
+protected:
+  // Implement the hook for creating operations, and make sure that newly
+  // created ops are added to the worklist for processing.
+  Operation *createOperation(const OperationState &state) override {
+    auto *result = OpBuilder::createOperation(state);
+    addToWorklist(result);
+    return result;
+  }
+
+  // If an operation is about to be removed, make sure it is not in our
+  // worklist anymore because we'd get dangling references to it.
+  void notifyOperationRemoved(Operation *op) override {
+    addToWorklist(op->getOperands());
+    removeFromWorklist(op);
+    folder.notifyRemoval(op);
+    op->walk([this](Operation *operation) {
+      removeFromWorklist(operation);
+      folder.notifyRemoval(operation);
+    });
+  }
+
+  // When the root of a pattern is about to be replaced, it can trigger
+  // simplifications to its users - make sure to add them to the worklist
+  // before the root is changed.
+  void notifyRootReplaced(Operation *op) override {
+    for (auto *result : op->getResults())
+      for (auto *user : result->getUsers())
+        addToWorklist(user);
+  }
+
+private:
+  // Look over the provided operands for any defining operations that should
+  // be re-added to the worklist. This function should be called when an
+  // operation is modified or removed, as it may trigger further
+  // simplifications.
+  template <typename Operands> void addToWorklist(Operands &&operands) {
+    for (Value *operand : operands) {
+      // If the use count of this operand is now < 2, we re-add the defining
+      // operation to the worklist.
+      // TODO(riverriddle) This is based on the fact that zero use operations
+      // may be deleted, and that single use values often have more
+      // canonicalization opportunities.
+      if (!operand->use_empty() && !operand->hasOneUse())
+        continue;
+      if (auto *defInst = operand->getDefiningOp())
+        addToWorklist(defInst);
+    }
+  }
+
+  /// The low-level pattern matcher.
+  RewritePatternMatcher matcher;
+
+  /// The worklist for this transformation keeps track of the operations that
+  /// need to be revisited, plus their index in the worklist.  This allows us to
+  /// efficiently remove operations from the worklist when they are erased, even
+  /// if they aren't the root of a pattern.
+  std::vector<Operation *> worklist;
+  DenseMap<Operation *, unsigned> worklistMap;
+
+  /// Non-pattern based folder for operations.
+  OperationFolder folder;
+};
+} // end anonymous namespace
+
+/// Perform the rewrites.
+bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
+  // Add the given operation to the worklist.
+  auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
+  bool changed = false;
+  int i = 0;
+  do {
+    // Add all nested operations to the worklist.
+    for (auto &region : op->getRegions())
+      region.walk(collectOps);
+
+    // These are scratch vectors used in the folding loop below.
+    SmallVector<Value *, 8> originalOperands, resultValues;
+
+    changed = false;
+    while (!worklist.empty()) {
+      auto *op = popFromWorklist();
+
+      // Nulls get added to the worklist when operations are removed, ignore
+      // them.
+      if (op == nullptr)
+        continue;
+
+      // If the operation has no side effects, and no users, then it is
+      // trivially dead - remove it.
+      if (op->hasNoSideEffect() && op->use_empty()) {
+        // Be careful to update bookkeeping in OperationFolder to keep
+        // consistency if this is a constant op.
+        folder.notifyRemoval(op);
+        op->erase();
+        continue;
+      }
+
+      // Collects all the operands and result uses of the given `op` into work
+      // list.
+      originalOperands.assign(op->operand_begin(), op->operand_end());
+      auto collectOperandsAndUses = [&](Operation *op) {
+        // Add the operands to the worklist for visitation.
+        addToWorklist(originalOperands);
+
+        // Add all the users of the result to the worklist so we make sure
+        // to revisit them.
+        for (auto *result : op->getResults())
+          for (auto *operand : result->getUsers())
+            addToWorklist(operand);
+      };
+
+      // Try to fold this op.
+      if (succeeded(folder.tryToFold(op, collectOps, collectOperandsAndUses))) {
+        changed |= true;
+        continue;
+      }
+
+      // Make sure that any new operations are inserted at this point.
+      setInsertionPoint(op);
+
+      // Try to match one of the canonicalization patterns. The rewriter is
+      // automatically notified of any necessary changes, so there is nothing
+      // else to do here.
+      changed |= matcher.matchAndRewrite(op, *this);
+    }
+  } while (changed && ++i < maxIterations);
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return !changed;
+}
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
+///
+bool mlir::applyPatternsGreedily(Operation *op,
+                                 OwningRewritePatternList &patterns) {
+  // The top-level operation must be known to be isolated from above to
+  // prevent performing canonicalizations on operations defined at or above
+  // the region containing 'op'.
+  if (!op->isKnownIsolatedFromAbove())
+    return false;
+
+  GreedyPatternRewriteDriver driver(op->getContext(), patterns);
+  bool converged = driver.simplify(op, maxPatternMatchIterations);
+  LLVM_DEBUG(if (!converged) {
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times";
+  });
+  return converged;
+}
diff --git a/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
new file mode 100644
index 0000000..4c079bd
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -0,0 +1,487 @@
+//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop fusion transformation utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopFusionUtils.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "loop-fusion-utils"
+
+using namespace mlir;
+
+// Gathers all load and store memref accesses in 'opA' into 'values', where
+// 'values[memref] == true' for each store operation.
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+                                          DenseMap<Value *, bool> &values) {
+  opA->walk([&](Operation *op) {
+    if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+      if (values.count(loadOp.getMemRef()) == 0)
+        values[loadOp.getMemRef()] = false;
+    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+      values[storeOp.getMemRef()] = true;
+    }
+  });
+}
+
+// Returns true if 'op' is a load or store operation which access an memref
+// accessed 'values' and at least one of the access is a store operation.
+// Returns false otherwise.
+static bool isDependentLoadOrStoreOp(Operation *op,
+                                     DenseMap<Value *, bool> &values) {
+  if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+    return values.count(loadOp.getMemRef()) > 0 &&
+           values[loadOp.getMemRef()] == true;
+  } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+    return values.count(storeOp.getMemRef()) > 0;
+  }
+  return false;
+}
+
+// Returns the first operation in range ('opA', 'opB') which has a data
+// dependence on 'opA'. Returns 'nullptr' of no dependence exists.
+static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
+  // Record memref values from all loads/store in loop nest rooted at 'opA'.
+  // Map from memref value to bool which is true if store, false otherwise.
+  DenseMap<Value *, bool> values;
+  getLoadAndStoreMemRefAccesses(opA, values);
+
+  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
+  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
+  // and at least one of the accesses is a store).
+  Operation *firstDepOp = nullptr;
+  for (Block::iterator it = std::next(Block::iterator(opA));
+       it != Block::iterator(opB); ++it) {
+    Operation *opX = &(*it);
+    opX->walk([&](Operation *op) {
+      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
+        firstDepOp = opX;
+    });
+    if (firstDepOp)
+      break;
+  }
+  return firstDepOp;
+}
+
+// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
+// exists a data dependence from 'opX' to 'opB'.
+// Returns 'nullptr' of no dependence exists.
+static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
+  // Record memref values from all loads/store in loop nest rooted at 'opB'.
+  // Map from memref value to bool which is true if store, false otherwise.
+  DenseMap<Value *, bool> values;
+  getLoadAndStoreMemRefAccesses(opB, values);
+
+  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
+  // check if there is a data dependence from 'opX' to 'opB':
+  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
+  //    is a store.
+  // *) 'opX' produces an SSA Value which is used by 'opB'.
+  Operation *lastDepOp = nullptr;
+  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
+       it != Block::reverse_iterator(opA); ++it) {
+    Operation *opX = &(*it);
+    opX->walk([&](Operation *op) {
+      if (lastDepOp)
+        return;
+      if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
+        if (isDependentLoadOrStoreOp(op, values))
+          lastDepOp = opX;
+        return;
+      }
+      for (auto *value : op->getResults()) {
+        for (auto *user : value->getUsers()) {
+          SmallVector<AffineForOp, 4> loops;
+          // Check if any loop in loop nest surrounding 'user' is 'opB'.
+          getLoopIVs(*user, &loops);
+          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
+            lastDepOp = opX;
+          }
+        }
+      }
+    });
+    if (lastDepOp)
+      break;
+  }
+  return lastDepOp;
+}
+
+// Computes and returns an insertion point operation, before which the
+// the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
+// dependences. Returns nullptr if no such insertion point is found.
+static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
+                                                 AffineForOp dstForOp) {
+  bool isSrcForOpBeforeDstForOp =
+      srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
+  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
+  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
+
+  auto *firstDepOpA =
+      getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+  auto *lastDepOpB =
+      getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+  // Block:
+  //      ...
+  //  |-- opA
+  //  |   ...
+  //  |   lastDepOpB --|
+  //  |   ...          |
+  //  |-> firstDepOpA  |
+  //      ...          |
+  //      opB <---------
+  //
+  // Valid insertion point range: (lastDepOpB, firstDepOpA)
+  //
+  if (firstDepOpA != nullptr) {
+    if (lastDepOpB != nullptr) {
+      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
+        // No valid insertion point exists which preserves dependences.
+        return nullptr;
+    }
+    // Return insertion point in valid range closest to 'opB'.
+    // TODO(andydavis) Consider other insertion points in valid range.
+    return firstDepOpA;
+  }
+  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
+  // 'opB' insertion point.
+  return forOpB.getOperation();
+}
+
+// Gathers all load and store ops in loop nest rooted at 'forOp' into
+// 'loadAndStoreOps'.
+static bool
+gatherLoadsAndStores(AffineForOp forOp,
+                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
+  bool hasIfOp = false;
+  forOp.getOperation()->walk([&](Operation *op) {
+    if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+      loadAndStoreOps.push_back(op);
+    else if (isa<AffineIfOp>(op))
+      hasIfOp = true;
+  });
+  return !hasIfOp;
+}
+
+// TODO(andydavis) Prevent fusion of loop nests with side-effecting operations.
+FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
+                                unsigned dstLoopDepth,
+                                ComputationSliceState *srcSlice) {
+  // Return 'failure' if 'dstLoopDepth == 0'.
+  if (dstLoopDepth == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+    return FusionResult::FailPrecondition;
+  }
+  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
+  auto *block = srcForOp.getOperation()->getBlock();
+  if (block != dstForOp.getOperation()->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
+    return FusionResult::FailPrecondition;
+  }
+
+  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
+  // exists which would preserve dependences.
+  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
+    return FusionResult::FailBlockDependence;
+  }
+
+  // Check if 'srcForOp' precedeces 'dstForOp' in 'block'.
+  bool isSrcForOpBeforeDstForOp =
+      srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
+  // 'forOpA' executes before 'forOpB' in 'block'.
+  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
+  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
+
+  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
+  SmallVector<Operation *, 4> opsA;
+  if (!gatherLoadsAndStores(forOpA, opsA)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    return FusionResult::FailPrecondition;
+  }
+
+  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
+  SmallVector<Operation *, 4> opsB;
+  if (!gatherLoadsAndStores(forOpB, opsB)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    return FusionResult::FailPrecondition;
+  }
+
+  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
+  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
+      *srcForOp.getOperation(), *dstForOp.getOperation());
+
+  // Compute union of computation slices computed between all pairs of ops
+  // from 'forOpA' and 'forOpB'.
+  if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
+                                     isSrcForOpBeforeDstForOp, srcSlice))) {
+    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+    return FusionResult::FailPrecondition;
+  }
+
+  return FusionResult::Success;
+}
+
+/// Collect loop nest statistics (eg. loop trip count and operation count)
+/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
+/// returns false otherwise.
+bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
+  bool ret = true;
+  forOpRoot.getOperation()->walk<AffineForOp>([&](AffineForOp forOp) {
+    auto *childForOp = forOp.getOperation();
+    auto *parentForOp = forOp.getOperation()->getParentOp();
+    if (!llvm::isa<FuncOp>(parentForOp)) {
+      if (!isa<AffineForOp>(parentForOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+        ret = false;
+        return;
+      }
+      // Add mapping to 'forOp' from its parent AffineForOp.
+      stats->loopMap[parentForOp].push_back(forOp);
+    }
+
+    // Record the number of op operations in the body of 'forOp'.
+    unsigned count = 0;
+    stats->opCountMap[childForOp] = 0;
+    for (auto &op : *forOp.getBody()) {
+      if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+        ++count;
+    }
+    stats->opCountMap[childForOp] = count;
+    // Record trip count for 'forOp'. Set flag if trip count is not
+    // constant.
+    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+    if (!maybeConstTripCount.hasValue()) {
+      // Currently only constant trip count loop nests are supported.
+      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+      ret = false;
+      return;
+    }
+    stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
+  });
+  return ret;
+}
+
+// Computes the total cost of the loop nest rooted at 'forOp'.
+// Currently, the total cost is computed by counting the total operation
+// instance count (i.e. total number of operations in the loop bodyloop
+// operation count * loop trip count) for the entire loop nest.
+// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
+// specified in the map when computing the total op instance count.
+// NOTEs: 1) This is used to compute the cost of computation slices, which are
+// sliced along the iteration dimension, and thus reduce the trip count.
+// If 'computeCostMap' is non-null, the total op count for forOps specified
+// in the map is increased (not overridden) by adding the op count from the
+// map to the existing op count for the for loop. This is done before
+// multiplying by the loop's trip count, and is used to model the cost of
+// inserting a sliced loop nest of known cost into the loop's body.
+// 2) This is also used to compute the cost of fusing a slice of some loop nest
+// within another loop.
+static int64_t getComputeCostHelper(
+    Operation *forOp, LoopNestStats &stats,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
+    DenseMap<Operation *, int64_t> *computeCostMap) {
+  // 'opCount' is the total number operations in one iteration of 'forOp' body,
+  // minus terminator op which is a no-op.
+  int64_t opCount = stats.opCountMap[forOp] - 1;
+  if (stats.loopMap.count(forOp) > 0) {
+    for (auto childForOp : stats.loopMap[forOp]) {
+      opCount += getComputeCostHelper(childForOp.getOperation(), stats,
+                                      tripCountOverrideMap, computeCostMap);
+    }
+  }
+  // Add in additional op instances from slice (if specified in map).
+  if (computeCostMap != nullptr) {
+    auto it = computeCostMap->find(forOp);
+    if (it != computeCostMap->end()) {
+      opCount += it->second;
+    }
+  }
+  // Override trip count (if specified in map).
+  int64_t tripCount = stats.tripCountMap[forOp];
+  if (tripCountOverrideMap != nullptr) {
+    auto it = tripCountOverrideMap->find(forOp);
+    if (it != tripCountOverrideMap->end()) {
+      tripCount = it->second;
+    }
+  }
+  // Returns the total number of dynamic instances of operations in loop body.
+  return tripCount * opCount;
+}
+
+// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  AffineExpr lbExpr(lbMap.getResult(0));
+  AffineExpr ubExpr(ubMap.getResult(0));
+  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+                                         lbMap.getNumSymbols());
+  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+  if (!cExpr)
+    return None;
+  return cExpr.getValue();
+}
+
+// Return the number of iterations in the given slice.
+static uint64_t getSliceIterationCount(
+    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
+  uint64_t iterCount = 1;
+  for (const auto &count : sliceTripCountMap) {
+    iterCount *= count.second;
+  }
+  return iterCount;
+}
+
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
+// nest surrounding represented by slice loop bounds in 'slice'.
+// Returns true on success, false otherwise (if a non-constant trip count
+// was encountered).
+// TODO(andydavis) Make this work with non-unit step loops.
+static bool buildSliceTripCountMap(
+    ComputationSliceState *slice,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
+  unsigned numSrcLoopIVs = slice->ivs.size();
+  // Populate map from AffineForOp -> trip count
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
+    auto *op = forOp.getOperation();
+    AffineMap lbMap = slice->lbs[i];
+    AffineMap ubMap = slice->ubs[i];
+    if (lbMap == AffineMap() || ubMap == AffineMap()) {
+      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
+      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
+        (*tripCountMap)[op] =
+            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
+        continue;
+      }
+      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+      if (maybeConstTripCount.hasValue()) {
+        (*tripCountMap)[op] = maybeConstTripCount.getValue();
+        continue;
+      }
+      return false;
+    }
+    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
+    // Slice bounds are created with a constant ub - lb difference.
+    if (!tripCount.hasValue())
+      return false;
+    (*tripCountMap)[op] = tripCount.getValue();
+  }
+  return true;
+}
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
+  return getComputeCostHelper(forOp.getOperation(), stats,
+                              /*tripCountOverrideMap=*/nullptr,
+                              /*computeCostMap=*/nullptr);
+}
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+                                AffineForOp dstForOp, LoopNestStats &dstStats,
+                                ComputationSliceState *slice,
+                                int64_t *computeCost) {
+  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+  DenseMap<Operation *, int64_t> computeCostMap;
+
+  // Build trip count map for computation slice.
+  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
+    return false;
+  // Checks whether a store to load forwarding will happen.
+  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+  assert(sliceIterationCount > 0);
+  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+  auto *insertPointParent = slice->insertPoint->getParentOp();
+
+  // The store and loads to this memref will disappear.
+  // TODO(andydavis) Add load coalescing to memref data flow opt pass.
+  if (storeLoadFwdGuaranteed) {
+    // Subtract from operation count the loads/store we expect load/store
+    // forwarding to remove.
+    unsigned storeCount = 0;
+    llvm::SmallDenseSet<Value *, 4> storeMemrefs;
+    srcForOp.getOperation()->walk([&](Operation *op) {
+      if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+        storeMemrefs.insert(storeOp.getMemRef());
+        ++storeCount;
+      }
+    });
+    // Subtract out any store ops in single-iteration src slice loop nest.
+    if (storeCount > 0)
+      computeCostMap[insertPointParent] = -storeCount;
+    // Subtract out any load users of 'storeMemrefs' nested below
+    // 'insertPointParent'.
+    for (auto *value : storeMemrefs) {
+      for (auto *user : value->getUsers()) {
+        if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+          SmallVector<AffineForOp, 4> loops;
+          // Check if any loop in loop nest surrounding 'user' is
+          // 'insertPointParent'.
+          getLoopIVs(*user, &loops);
+          if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
+            if (auto forOp =
+                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
+              if (computeCostMap.count(forOp) == 0)
+                computeCostMap[forOp] = 0;
+              computeCostMap[forOp] -= 1;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Compute op instance count for the src loop nest with iteration slicing.
+  int64_t sliceComputeCost = getComputeCostHelper(
+      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+
+  // Compute cost of fusion for this depth.
+  computeCostMap[insertPointParent] = sliceComputeCost;
+
+  *computeCost =
+      getComputeCostHelper(dstForOp.getOperation(), dstStats,
+                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+  return true;
+}
diff --git a/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp b/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp
new file mode 100644
index 0000000..a4717ad
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -0,0 +1,1133 @@
+//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopUtils.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "LoopUtils"
+
+using namespace mlir;
+using llvm::SetVector;
+
+/// Computes the cleanup loop lower bound of the loop being unrolled with
+/// the specified unroll factor; this bound will also be upper bound of the main
+/// part of the unrolled loop. Computes the bound as an AffineMap with its
+/// operands or a null map when the trip count can't be expressed as an affine
+/// expression.
+void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
+                                    AffineMap *map,
+                                    SmallVectorImpl<Value *> *operands,
+                                    OpBuilder &b) {
+  auto lbMap = forOp.getLowerBoundMap();
+
+  // Single result lower bound map only.
+  if (lbMap.getNumResults() != 1) {
+    *map = AffineMap();
+    return;
+  }
+
+  AffineMap tripCountMap;
+  SmallVector<Value *, 4> tripCountOperands;
+  buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
+
+  // Sometimes the trip count cannot be expressed as an affine expression.
+  if (!tripCountMap) {
+    *map = AffineMap();
+    return;
+  }
+
+  unsigned step = forOp.getStep();
+
+  SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+  auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+
+  // For each upper bound expr, get the range.
+  // Eg: affine.for %i = lb to min (ub1, ub2),
+  // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
+  // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
+  // these affine.apply's make up the cleanup loop lower bound.
+  SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
+  SmallVector<Value *, 4> bumpValues(tripCountMap.getNumResults());
+  for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
+    auto tripCountExpr = tripCountMap.getResult(i);
+    bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
+    auto bumpMap = b.getAffineMap(tripCountMap.getNumDims(),
+                                  tripCountMap.getNumSymbols(), bumpExprs[i]);
+    bumpValues[i] =
+        b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
+  }
+
+  SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
+  for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
+    newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
+
+  operands->clear();
+  operands->push_back(lb);
+  operands->append(bumpValues.begin(), bumpValues.end());
+  *map = b.getAffineMap(1 + tripCountMap.getNumResults(), 0, newUbExprs);
+  // Simplify the map + operands.
+  fullyComposeAffineMapAndOperands(map, operands);
+  *map = simplifyAffineMap(*map);
+  canonicalizeMapAndOperands(map, operands);
+  // Remove any affine.apply's that became dead from the simplification above.
+  for (auto *v : bumpValues) {
+    if (v->use_empty()) {
+      v->getDefiningOp()->erase();
+    }
+  }
+  if (lb.use_empty())
+    lb.erase();
+}
+
+/// Promotes the loop body of a forOp to its containing block if the forOp
+/// was known to have a single iteration.
+// TODO(bondhugula): extend this for arbitrary affine bounds.
+LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
+  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount.hasValue() || tripCount.getValue() != 1)
+    return failure();
+
+  // TODO(mlir-team): there is no builder for a max.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  auto *iv = forOp.getInductionVar();
+  Operation *op = forOp.getOperation();
+  if (!iv->use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
+      auto constOp = topBuilder.create<ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv->replaceAllUsesWith(constOp);
+    } else {
+      AffineBound lb = forOp.getLowerBound();
+      SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
+      OpBuilder builder(op->getBlock(), Block::iterator(op));
+      if (lb.getMap() == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv->replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp = builder.create<AffineApplyOp>(
+            op->getLoc(), lb.getMap(), lbOperands);
+        iv->replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+  // Move the loop body operations, except for terminator, to the loop's
+  // containing block.
+  auto *block = op->getBlock();
+  forOp.getBody()->getOperations().back().erase();
+  block->getOperations().splice(Block::iterator(op),
+                                forOp.getBody()->getOperations());
+  forOp.erase();
+  return success();
+}
+
+/// Promotes all single iteration for op's in the FuncOp, i.e., moves
+/// their body into the containing Block.
+void mlir::promoteSingleIterationLoops(FuncOp f) {
+  // Gathers all innermost loops through a post order pruned walk.
+  f.walk<AffineForOp>(
+      [](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
+}
+
+/// Generates a 'affine.for' op with the specified lower and upper bounds
+/// while generating the right IV remappings for the shifted operations. The
+/// operation blocks that go into the loop are specified in instGroupQueue
+/// starting from the specified offset, and in that order; the first element of
+/// the pair specifies the shift applied to that group of operations; note
+/// that the shift is multiplied by the loop step before being applied. Returns
+/// nullptr if the generated loop simplifies to a single iteration one.
+static AffineForOp
+generateLoop(AffineMap lbMap, AffineMap ubMap,
+             const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>>
+                 &instGroupQueue,
+             unsigned offset, AffineForOp srcForInst, OpBuilder b) {
+  SmallVector<Value *, 4> lbOperands(srcForInst.getLowerBoundOperands());
+  SmallVector<Value *, 4> ubOperands(srcForInst.getUpperBoundOperands());
+
+  assert(lbMap.getNumInputs() == lbOperands.size());
+  assert(ubMap.getNumInputs() == ubOperands.size());
+
+  auto loopChunk =
+      b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands,
+                            ubMap, srcForInst.getStep());
+  auto *loopChunkIV = loopChunk.getInductionVar();
+  auto *srcIV = srcForInst.getInductionVar();
+
+  BlockAndValueMapping operandMap;
+
+  OpBuilder bodyBuilder = loopChunk.getBodyBuilder();
+  for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
+       it != e; ++it) {
+    uint64_t shift = it->first;
+    auto insts = it->second;
+    // All 'same shift' operations get added with their operands being
+    // remapped to results of cloned operations, and their IV used remapped.
+    // Generate the remapping if the shift is not zero: remappedIV = newIV -
+    // shift.
+    if (!srcIV->use_empty() && shift != 0) {
+      auto ivRemap = bodyBuilder.create<AffineApplyOp>(
+          srcForInst.getLoc(),
+          bodyBuilder.getSingleDimShiftAffineMap(
+              -static_cast<int64_t>(srcForInst.getStep() * shift)),
+          loopChunkIV);
+      operandMap.map(srcIV, ivRemap);
+    } else {
+      operandMap.map(srcIV, loopChunkIV);
+    }
+    for (auto *op : insts) {
+      if (!isa<AffineTerminatorOp>(op))
+        bodyBuilder.clone(*op, operandMap);
+    }
+  };
+  if (succeeded(promoteIfSingleIteration(loopChunk)))
+    return AffineForOp();
+  return loopChunk;
+}
+
+/// Skew the operations in the body of a 'affine.for' operation with the
+/// specified operation-wise shifts. The shifts are with respect to the
+/// original execution order, and are multiplied by the loop 'step' before being
+/// applied. A shift of zero for each operation will lead to no change.
+// The skewing of operations with respect to one another can be used for
+// example to allow overlap of asynchronous operations (such as DMA
+// communication) with computation, or just relative shifting of operations
+// for better register reuse, locality or parallelism. As such, the shifts are
+// typically expected to be at most of the order of the number of operations.
+// This method should not be used as a substitute for loop distribution/fission.
+// This method uses an algorithm// in time linear in the number of operations
+// in the body of the for loop - (using the 'sweep line' paradigm). This method
+// asserts preservation of SSA dominance. A check for that as well as that for
+// memory-based depedence preservation check rests with the users of this
+// method.
+LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
+                                 bool unrollPrologueEpilogue) {
+  if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+    return success();
+
+  // If the trip counts aren't constant, we would need versioning and
+  // conditional guards (or context information to prevent such versioning). The
+  // better way to pipeline for such loops is to first tile them and extract
+  // constant trip count "full tiles" before applying this.
+  auto mayBeConstTripCount = getConstantTripCount(forOp);
+  if (!mayBeConstTripCount.hasValue()) {
+    LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
+    return success();
+  }
+  uint64_t tripCount = mayBeConstTripCount.getValue();
+
+  assert(isInstwiseShiftValid(forOp, shifts) &&
+         "shifts will lead to an invalid transformation\n");
+
+  int64_t step = forOp.getStep();
+
+  unsigned numChildInsts = forOp.getBody()->getOperations().size();
+
+  // Do a linear time (counting) sort for the shifts.
+  uint64_t maxShift = 0;
+  for (unsigned i = 0; i < numChildInsts; i++) {
+    maxShift = std::max(maxShift, shifts[i]);
+  }
+  // Such large shifts are not the typical use case.
+  if (maxShift >= numChildInsts) {
+    forOp.emitWarning("not shifting because shifts are unrealistically large");
+    return success();
+  }
+
+  // An array of operation groups sorted by shift amount; each group has all
+  // operations with the same shift in the order in which they appear in the
+  // body of the 'affine.for' op.
+  std::vector<std::vector<Operation *>> sortedInstGroups(maxShift + 1);
+  unsigned pos = 0;
+  for (auto &op : *forOp.getBody()) {
+    auto shift = shifts[pos++];
+    sortedInstGroups[shift].push_back(&op);
+  }
+
+  // Unless the shifts have a specific pattern (which actually would be the
+  // common use case), prologue and epilogue are not meaningfully defined.
+  // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
+  // loop generated as the prologue and the last as epilogue and unroll these
+  // fully.
+  AffineForOp prologue;
+  AffineForOp epilogue;
+
+  // Do a sweep over the sorted shifts while storing open groups in a
+  // vector, and generating loop portions as necessary during the sweep. A block
+  // of operations is paired with its shift.
+  std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> instGroupQueue;
+
+  auto origLbMap = forOp.getLowerBoundMap();
+  uint64_t lbShift = 0;
+  OpBuilder b(forOp.getOperation());
+  for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
+    // If nothing is shifted by d, continue.
+    if (sortedInstGroups[d].empty())
+      continue;
+    if (!instGroupQueue.empty()) {
+      assert(d >= 1 &&
+             "Queue expected to be empty when the first block is found");
+      // The interval for which the loop needs to be generated here is:
+      // [lbShift, min(lbShift + tripCount, d)) and the body of the
+      // loop needs to have all operations in instQueue in that order.
+      AffineForOp res;
+      if (lbShift + tripCount * step < d * step) {
+        res = generateLoop(
+            b.getShiftedAffineMap(origLbMap, lbShift),
+            b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
+            instGroupQueue, 0, forOp, b);
+        // Entire loop for the queued op groups generated, empty it.
+        instGroupQueue.clear();
+        lbShift += tripCount * step;
+      } else {
+        res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+                           b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
+                           0, forOp, b);
+        lbShift = d * step;
+      }
+      if (!prologue && res)
+        prologue = res;
+      epilogue = res;
+    } else {
+      // Start of first interval.
+      lbShift = d * step;
+    }
+    // Augment the list of operations that get into the current open interval.
+    instGroupQueue.push_back({d, sortedInstGroups[d]});
+  }
+
+  // Those operations groups left in the queue now need to be processed (FIFO)
+  // and their loops completed.
+  for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
+    uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
+    epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+                            b.getShiftedAffineMap(origLbMap, ubShift),
+                            instGroupQueue, i, forOp, b);
+    lbShift = ubShift;
+    if (!prologue)
+      prologue = epilogue;
+  }
+
+  // Erase the original for op.
+  forOp.erase();
+
+  if (unrollPrologueEpilogue && prologue)
+    loopUnrollFull(prologue);
+  if (unrollPrologueEpilogue && !epilogue &&
+      epilogue.getOperation() != prologue.getOperation())
+    loopUnrollFull(epilogue);
+
+  return success();
+}
+
+// Collect perfectly nested loops starting from `rootForOps`.  Loops are
+// perfectly nested if each loop is the first and only non-terminator operation
+// in the parent loop.  Collect at most `maxLoops` loops and append them to
+// `forOps`.
+template <typename T>
+void getPerfectlyNestedLoopsImpl(
+    SmallVectorImpl<T> &forOps, T rootForOp,
+    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
+  for (unsigned i = 0; i < maxLoops; ++i) {
+    forOps.push_back(rootForOp);
+    // FIXME: ForOp and AffineForOp currently provide different names to access
+    // the region ("region" and "getRegion").  Remove this generic access when
+    // AffineForOp moves to ODS and also gets "region".
+    Block &body = rootForOp.getOperation()->getRegion(0).front();
+    if (body.begin() != std::prev(body.end(), 2))
+      return;
+
+    rootForOp = dyn_cast<T>(&body.front());
+    if (!rootForOp)
+      return;
+  }
+}
+
+/// Get perfectly nested sequence of loops starting at root of loop nest
+/// (the first op being another AffineFor, and the second op - a terminator).
+/// A loop is perfectly nested iff: the first op in the loop's body is another
+/// AffineForOp, and the second op is a terminator).
+void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
+                                   AffineForOp root) {
+  getPerfectlyNestedLoopsImpl(nestedLoops, root);
+}
+
+void mlir::getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops,
+                                   loop::ForOp root) {
+  getPerfectlyNestedLoopsImpl(nestedLoops, root);
+}
+
+/// Unrolls this loop completely.
+LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  if (mayBeConstantTripCount.hasValue()) {
+    uint64_t tripCount = mayBeConstantTripCount.getValue();
+    if (tripCount == 1) {
+      return promoteIfSingleIteration(forOp);
+    }
+    return loopUnrollByFactor(forOp, tripCount);
+  }
+  return failure();
+}
+
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant) whichever is lower.
+LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
+                                         uint64_t unrollFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollFactor)
+    return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
+  return loopUnrollByFactor(forOp, unrollFactor);
+}
+
+/// Unrolls this loop by the specified factor. Returns success if the loop
+/// is successfully unrolled.
+LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
+                                       uint64_t unrollFactor) {
+  assert(unrollFactor >= 1 && "unroll factor should be >= 1");
+
+  if (unrollFactor == 1)
+    return promoteIfSingleIteration(forOp);
+
+  if (forOp.getBody()->empty() ||
+      forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+    return failure();
+
+  // Loops where the lower bound is a max expression isn't supported for
+  // unrolling since the trip count can be expressed as an affine function when
+  // both the lower bound and the upper bound are multi-result maps. However,
+  // one meaningful way to do such unrolling would be to specialize the loop for
+  // the 'hotspot' case and unroll that hotspot.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // If the trip count is lower than the unroll factor, no unrolled body.
+  // TODO(bondhugula): option to specify cleanup loop unrolling.
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollFactor)
+    return failure();
+
+  // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+  Operation *op = forOp.getOperation();
+  if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
+    OpBuilder builder(op->getBlock(), ++Block::iterator(op));
+    auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
+    AffineMap cleanupMap;
+    SmallVector<Value *, 4> cleanupOperands;
+    getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
+                             builder);
+    assert(cleanupMap &&
+           "cleanup loop lower bound map for single result lower bound maps "
+           "can always be determined");
+    cleanupForInst.setLowerBound(cleanupOperands, cleanupMap);
+    // Promote the loop body up if this has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupForInst);
+
+    // Adjust upper bound of the original loop; this is the same as the lower
+    // bound of the cleanup loop.
+    forOp.setUpperBound(cleanupOperands, cleanupMap);
+  }
+
+  // Scale the step of loop being unrolled by unroll factor.
+  int64_t step = forOp.getStep();
+  forOp.setStep(step * unrollFactor);
+
+  // Builder to insert unrolled bodies just before the terminator of the body of
+  // 'forOp'.
+  OpBuilder builder = forOp.getBodyBuilder();
+
+  // Keep a pointer to the last non-terminator operation in the original block
+  // so that we know what to clone (since we are doing this in-place).
+  Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2);
+
+  // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
+  auto *forOpIV = forOp.getInductionVar();
+  for (unsigned i = 1; i < unrollFactor; i++) {
+    BlockAndValueMapping operandMap;
+
+    // If the induction variable is used, create a remapping to the value for
+    // this unrolled instance.
+    if (!forOpIV->use_empty()) {
+      // iv' = iv + 1/2/3...unrollFactor-1;
+      auto d0 = builder.getAffineDimExpr(0);
+      auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step});
+      auto ivUnroll =
+          builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
+      operandMap.map(forOpIV, ivUnroll);
+    }
+
+    // Clone the original body of 'forOp'.
+    for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd);
+         it++) {
+      builder.clone(*it, operandMap);
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forOp);
+  return success();
+}
+
+/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
+/// nested within 'forOpA' as the only non-terminator operation in its block.
+void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
+  auto *forOpAInst = forOpA.getOperation();
+
+  assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
+  auto &forOpABody = forOpA.getBody()->getOperations();
+  auto &forOpBBody = forOpB.getBody()->getOperations();
+
+  // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
+  // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's
+  // body containing only the terminator.
+  forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst),
+                                                 forOpABody, forOpABody.begin(),
+                                                 std::prev(forOpABody.end()));
+  // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
+  // body (this leaves forOpB's body containing only the terminator).
+  forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
+                    std::prev(forOpBBody.end()));
+  // 3) Splice forOpA into the beginning of forOpB's body.
+  forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(),
+                    Block::iterator(forOpAInst));
+}
+
+// Checks each dependence component against the permutation to see if the
+// desired loop interchange would violate dependences by making the
+// dependence componenent lexicographically negative.
+static bool checkLoopInterchangeDependences(
+    const std::vector<llvm::SmallVector<DependenceComponent, 2>> &depCompsVec,
+    ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
+  // Invert permutation map.
+  unsigned maxLoopDepth = loops.size();
+  llvm::SmallVector<unsigned, 4> loopPermMapInv;
+  loopPermMapInv.resize(maxLoopDepth);
+  for (unsigned i = 0; i < maxLoopDepth; ++i)
+    loopPermMapInv[loopPermMap[i]] = i;
+
+  // Check each dependence component against the permutation to see if the
+  // desired loop interchange permutation would make the dependence vectors
+  // lexicographically negative.
+  // Example 1: [-1, 1][0, 0]
+  // Example 2: [0, 0][-1, 1]
+  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
+    const llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
+    assert(depComps.size() >= maxLoopDepth);
+    // Check if the first non-zero dependence component is positive.
+    // This iterates through loops in the desired order.
+    for (unsigned j = 0; j < maxLoopDepth; ++j) {
+      unsigned permIndex = loopPermMapInv[j];
+      assert(depComps[permIndex].lb.hasValue());
+      int64_t depCompLb = depComps[permIndex].lb.getValue();
+      if (depCompLb > 0)
+        break;
+      if (depCompLb < 0)
+        return false;
+    }
+  }
+  return true;
+}
+
+/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
+/// nested sequence of loops in 'loops' would violate dependences.
+bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
+                                             ArrayRef<unsigned> loopPermMap) {
+  // Gather dependence components for dependences between all ops in loop nest
+  // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
+  assert(loopPermMap.size() == loops.size());
+  unsigned maxLoopDepth = loops.size();
+  std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec;
+  getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
+  return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
+}
+
+/// Performs a sequence of loop interchanges of loops in perfectly nested
+/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'.
+unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops,
+                                ArrayRef<unsigned> loopPermMap) {
+  Optional<unsigned> loopNestRootIndex;
+  for (int i = loops.size() - 1; i >= 0; --i) {
+    int permIndex = static_cast<int>(loopPermMap[i]);
+    // Store the index of the for loop which will be the new loop nest root.
+    if (permIndex == 0)
+      loopNestRootIndex = i;
+    if (permIndex > i) {
+      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
+      sinkLoop(loops[i], permIndex - i);
+    }
+  }
+  assert(loopNestRootIndex.hasValue());
+  return loopNestRootIndex.getValue();
+}
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the
+// outermost (while again preserving relative order among them).
+AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
+  SmallVector<AffineForOp, 4> loops;
+  getPerfectlyNestedLoops(loops, forOp);
+  if (loops.size() < 2)
+    return forOp;
+
+  // Gather dependence components for dependences between all ops in loop nest
+  // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
+  unsigned maxLoopDepth = loops.size();
+  std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec;
+  getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
+
+  // Mark loops as either parallel or sequential.
+  llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
+  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
+    llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
+    assert(depComps.size() >= maxLoopDepth);
+    for (unsigned j = 0; j < maxLoopDepth; ++j) {
+      DependenceComponent &depComp = depComps[j];
+      assert(depComp.lb.hasValue() && depComp.ub.hasValue());
+      if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
+        isParallelLoop[j] = false;
+    }
+  }
+
+  // Count the number of parallel loops.
+  unsigned numParallelLoops = 0;
+  for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
+    if (isParallelLoop[i])
+      ++numParallelLoops;
+
+  // Compute permutation of loops that sinks sequential loops (and thus raises
+  // parallel loops) while preserving relative order.
+  llvm::SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
+  unsigned nextSequentialLoop = numParallelLoops;
+  unsigned nextParallelLoop = 0;
+  for (unsigned i = 0; i < maxLoopDepth; ++i) {
+    if (isParallelLoop[i]) {
+      loopPermMap[i] = nextParallelLoop++;
+    } else {
+      loopPermMap[i] = nextSequentialLoop++;
+    }
+  }
+
+  // Check if permutation 'loopPermMap' would violate dependences.
+  if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
+    return forOp;
+  // Perform loop interchange according to permutation 'loopPermMap'.
+  unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap);
+  return loops[loopNestRootIndex];
+}
+
+/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels
+/// deeper in the loop nest.
+void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
+  for (unsigned i = 0; i < loopDepth; ++i) {
+    AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
+    interchangeLoops(forOp, nextForOp);
+  }
+}
+
+// Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
+// lower (resp. upper) loop bound. When called for both the lower and upper
+// bounds, the resulting IR resembles:
+//
+// ```mlir
+//    affine.for %i = max (`iv, ...) to min (`iv` + `offset`) {
+//      ...
+//    }
+// ```
+static void augmentMapAndBounds(OpBuilder &b, Value *iv, AffineMap *map,
+                                SmallVector<Value *, 4> *operands,
+                                int64_t offset = 0) {
+  auto bounds = llvm::to_vector<4>(map->getResults());
+  bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
+  operands->insert(operands->begin() + map->getNumDims(), iv);
+  *map = b.getAffineMap(map->getNumDims() + 1, map->getNumSymbols(), bounds);
+  canonicalizeMapAndOperands(map, operands);
+}
+
+// Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
+// Stripmine-sink is a primitive building block for generalized tiling of
+// imperfectly nested loops.
+// This transformation is purely mechanical and does not check legality,
+// profitability or even structural correctness. It is the user's
+// responsibility to specify `targets` that are dominated by `forOp`.
+// Returns the new AffineForOps, one per `targets`, nested immediately under
+// each of the `targets`.
+static SmallVector<AffineForOp, 8>
+stripmineSink(AffineForOp forOp, uint64_t factor,
+              ArrayRef<AffineForOp> targets) {
+  auto originalStep = forOp.getStep();
+  auto scaledStep = originalStep * factor;
+  forOp.setStep(scaledStep);
+
+  auto *op = forOp.getOperation();
+  OpBuilder b(op->getBlock(), ++Block::iterator(op));
+
+  // Lower-bound map creation.
+  auto lbMap = forOp.getLowerBoundMap();
+  SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+  augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
+
+  // Upper-bound map creation.
+  auto ubMap = forOp.getUpperBoundMap();
+  SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+  augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
+                      /*offset=*/scaledStep);
+
+  auto *iv = forOp.getInductionVar();
+  SmallVector<AffineForOp, 8> innerLoops;
+  for (auto t : targets) {
+    // Insert newForOp before the terminator of `t`.
+    OpBuilder b = t.getBodyBuilder();
+    auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
+                                          ubOperands, ubMap, originalStep);
+    auto begin = t.getBody()->begin();
+    // Skip terminator and `newForOp` which is just before the terminator.
+    auto nOps = t.getBody()->getOperations().size() - 2;
+    newForOp.getBody()->getOperations().splice(
+        newForOp.getBody()->getOperations().begin(),
+        t.getBody()->getOperations(), begin, std::next(begin, nOps));
+    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
+                               newForOp.region());
+    innerLoops.push_back(newForOp);
+  }
+
+  return innerLoops;
+}
+
+static Loops stripmineSink(loop::ForOp forOp, Value *factor,
+                           ArrayRef<loop::ForOp> targets) {
+  auto *originalStep = forOp.step();
+  auto *iv = forOp.getInductionVar();
+
+  OpBuilder b(forOp);
+  forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));
+
+  Loops innerLoops;
+  for (auto t : targets) {
+    // Save information for splicing ops out of t when done
+    auto begin = t.getBody()->begin();
+    auto nOps = t.getBody()->getOperations().size();
+
+    // Insert newForOp before the terminator of `t`.
+    OpBuilder b(t.getBodyBuilder());
+    Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
+    Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::SLT,
+                                   forOp.upperBound(), stepped);
+    Value *ub =
+        b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
+
+    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
+    auto newForOp = b.create<loop::ForOp>(t.getLoc(), iv, ub, originalStep);
+    newForOp.getBody()->getOperations().splice(
+        newForOp.getBody()->getOperations().begin(),
+        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
+    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
+                               newForOp.region());
+
+    innerLoops.push_back(newForOp);
+  }
+
+  return innerLoops;
+}
+
+// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
+// Returns the new AffineForOps, nested immediately under `target`.
+template <typename ForType, typename SizeType>
+static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) {
+  // TODO(ntv): Use cheap structural assertions that targets are nested under
+  // forOp and that targets are not nested under each other when DominanceInfo
+  // exposes the capability. It seems overkill to construct a whole function
+  // dominance tree at this point.
+  auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target});
+  assert(res.size() == 1 && "Expected 1 inner forOp");
+  return res[0];
+}
+
+template <typename ForType, typename SizeType>
+static SmallVector<SmallVector<ForType, 8>, 8>
+tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes,
+         ArrayRef<ForType> targets) {
+  SmallVector<SmallVector<ForType, 8>, 8> res;
+  SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end());
+  for (auto it : llvm::zip(forOps, sizes)) {
+    auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
+    res.push_back(step);
+    currentTargets = step;
+  }
+  return res;
+}
+
+SmallVector<SmallVector<AffineForOp, 8>, 8>
+mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
+           ArrayRef<AffineForOp> targets) {
+  return tileImpl(forOps, sizes, targets);
+}
+
+SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps,
+                                 ArrayRef<Value *> sizes,
+                                 ArrayRef<loop::ForOp> targets) {
+  return tileImpl(forOps, sizes, targets);
+}
+
+template <typename ForType, typename SizeType>
+static SmallVector<ForType, 8>
+tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) {
+  SmallVector<ForType, 8> res;
+  for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) {
+    assert(loops.size() == 1);
+    res.push_back(loops[0]);
+  }
+  return res;
+}
+
+SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
+                                       ArrayRef<uint64_t> sizes,
+                                       AffineForOp target) {
+  return tileImpl(forOps, sizes, target);
+}
+
+Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value *> sizes,
+                 loop::ForOp target) {
+  return tileImpl(forOps, sizes, target);
+}
+
+Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp,
+                                ArrayRef<Value *> sizes) {
+  // Collect prefectly nested loops.  If more size values provided than nested
+  // loops available, truncate `sizes`.
+  SmallVector<loop::ForOp, 4> forOps;
+  forOps.reserve(sizes.size());
+  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
+  if (forOps.size() < sizes.size())
+    sizes = sizes.take_front(forOps.size());
+
+  return ::tile(forOps, sizes, forOps.back());
+}
+
+// Build the IR that performs ceil division of a positive value by a constant:
+//    ceildiv(a, B) = divis(a + (B-1), B)
+// where divis is roundning-to-zero division.
+static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend,
+                              int64_t divisor) {
+  assert(divisor > 0 && "expected positive divisor");
+  assert(dividend->getType().isIndex() && "expected index-typed value");
+
+  Value *divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
+  Value *divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
+  Value *sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
+  return builder.create<DivISOp>(loc, sum, divisorCst);
+}
+
+// Build the IR that performs ceil division of a positive value by another
+// positive value:
+//    ceildiv(a, b) = divis(a + (b - 1), b)
+// where divis is rounding-to-zero division.
+static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend,
+                              Value *divisor) {
+  assert(dividend->getType().isIndex() && "expected index-typed value");
+
+  Value *cstOne = builder.create<ConstantIndexOp>(loc, 1);
+  Value *divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
+  Value *sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
+  return builder.create<DivISOp>(loc, sum, divisor);
+}
+
+// Hoist the ops within `outer` that appear before `inner`.
+// Such ops include the ops that have been introduced by parametric tiling.
+// Ops that come from triangular loops (i.e. that belong to the program slice
+// rooted at `outer`) and ops that have side effects cannot be hoisted.
+// Return failure when any op fails to hoist.
+static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) {
+  SetVector<Operation *> forwardSlice;
+  getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
+    return op != inner.getOperation();
+  });
+  LogicalResult status = success();
+  SmallVector<Operation *, 8> toHoist;
+  for (auto &op : outer.getBody()->getOperations()) {
+    // Stop when encountering the inner loop.
+    if (&op == inner.getOperation())
+      break;
+    // Skip over non-hoistable ops.
+    if (forwardSlice.count(&op) > 0) {
+      status = failure();
+      continue;
+    }
+    // Skip loop::ForOp, these are not considered a failure.
+    if (op.getNumRegions() > 0)
+      continue;
+    // Skip other ops with regions.
+    if (op.getNumRegions() > 0) {
+      status = failure();
+      continue;
+    }
+    // Skip if op has side effects.
+    // TODO(ntv): loads to immutable memory regions are ok.
+    if (!op.hasNoSideEffect()) {
+      status = failure();
+      continue;
+    }
+    toHoist.push_back(&op);
+  }
+  auto *outerForOp = outer.getOperation();
+  for (auto *op : toHoist)
+    op->moveBefore(outerForOp);
+  return status;
+}
+
+// Traverse the interTile and intraTile loops and try to hoist ops such that
+// bands of perfectly nested loops are isolated.
+// Return failure if either perfect interTile or perfect intraTile bands cannot
+// be formed.
+static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
+  LogicalResult status = success();
+  auto &interTile = tileLoops.first;
+  auto &intraTile = tileLoops.second;
+  auto size = interTile.size();
+  assert(size == intraTile.size());
+  if (size <= 1)
+    return success();
+  for (unsigned s = 1; s < size; ++s)
+    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
+                               : failure();
+  for (unsigned s = 1; s < size; ++s)
+    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
+                               : failure();
+  return status;
+}
+
+TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp,
+                                       ArrayRef<int64_t> sizes) {
+  // Collect prefectly nested loops.  If more size values provided than nested
+  // loops available, truncate `sizes`.
+  SmallVector<loop::ForOp, 4> forOps;
+  forOps.reserve(sizes.size());
+  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
+  if (forOps.size() < sizes.size())
+    sizes = sizes.take_front(forOps.size());
+
+  // Compute the tile sizes such that i-th outer loop executes size[i]
+  // iterations.  Given that the loop current executes
+  //   numIterations = ceildiv((upperBound - lowerBound), step)
+  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
+  SmallVector<Value *, 4> tileSizes;
+  tileSizes.reserve(sizes.size());
+  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
+    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
+
+    auto forOp = forOps[i];
+    OpBuilder builder(forOp);
+    auto loc = forOp.getLoc();
+    Value *diff =
+        builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
+    Value *numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
+    Value *iterationsPerBlock =
+        ceilDivPositive(builder, loc, numIterations, sizes[i]);
+    tileSizes.push_back(iterationsPerBlock);
+  }
+
+  // Call parametric tiling with the given sizes.
+  auto intraTile = tile(forOps, tileSizes, forOps.back());
+  TileLoops tileLoops = std::make_pair(forOps, intraTile);
+
+  // TODO(ntv, zinenko) for now we just ignore the result of band isolation.
+  // In the future, mapping decisions may be impacted by the ability to
+  // isolate perfectly nested bands.
+  tryIsolateBands(tileLoops);
+
+  return tileLoops;
+}
+
+// Replaces all uses of `orig` with `replacement` except if the user is listed
+// in `exceptions`.
+static void
+replaceAllUsesExcept(Value *orig, Value *replacement,
+                     const SmallPtrSetImpl<Operation *> &exceptions) {
+  for (auto &use : orig->getUses()) {
+    if (exceptions.count(use.getOwner()) == 0)
+      use.set(replacement);
+  }
+}
+
+// Transform a loop with a strictly positive step
+//   for %i = %lb to %ub step %s
+// into a 0-based loop with step 1
+//   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+//     %i = %ii * %s + %lb
+// Insert the induction variable remapping in the body of `inner`, which is
+// expected to be either `loop` or another loop perfectly nested under `loop`.
+// Insert the definition of new bounds immediate before `outer`, which is
+// expected to be either `loop` or its parent in the loop nest.
+static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
+                          loop::ForOp inner) {
+  OpBuilder builder(outer);
+  Location loc = loop.getLoc();
+
+  // Check if the loop is already known to have a constant zero lower bound or
+  // a constant one step.
+  bool isZeroBased = false;
+  if (auto ubCst =
+          dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound()->getDefiningOp()))
+    isZeroBased = ubCst.getValue() == 0;
+
+  bool isStepOne = false;
+  if (auto stepCst =
+          dyn_cast_or_null<ConstantIndexOp>(loop.step()->getDefiningOp()))
+    isStepOne = stepCst.getValue() == 1;
+
+  if (isZeroBased && isStepOne)
+    return;
+
+  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
+  // assuming the step is strictly positive.  Update the bounds and the step
+  // of the loop to go from 0 to the number of iterations, if necessary.
+  // TODO(zinenko): introduce support for negative steps or emit dynamic asserts
+  // on step positivity, whatever gets implemented first.
+  Value *diff =
+      builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
+  Value *numIterations = ceilDivPositive(builder, loc, diff, loop.step());
+  loop.setUpperBound(numIterations);
+
+  Value *lb = loop.lowerBound();
+  if (!isZeroBased) {
+    Value *cst0 = builder.create<ConstantIndexOp>(loc, 0);
+    loop.setLowerBound(cst0);
+  }
+
+  Value *step = loop.step();
+  if (!isStepOne) {
+    Value *cst1 = builder.create<ConstantIndexOp>(loc, 1);
+    loop.setStep(cst1);
+  }
+
+  // Insert code computing the value of the original loop induction variable
+  // from the "normalized" one.
+  builder.setInsertionPointToStart(inner.getBody());
+  Value *scaled =
+      isStepOne ? loop.getInductionVar()
+                : builder.create<MulIOp>(loc, loop.getInductionVar(), step);
+  Value *shifted =
+      isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);
+
+  SmallPtrSet<Operation *, 2> preserve{scaled->getDefiningOp(),
+                                       shifted->getDefiningOp()};
+  replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve);
+}
+
+void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
+  if (loops.size() < 2)
+    return;
+
+  loop::ForOp innermost = loops.back();
+  loop::ForOp outermost = loops.front();
+
+  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
+  // allows the following code to assume upperBound is the number of iterations.
+  for (auto loop : loops)
+    normalizeLoop(loop, outermost, innermost);
+
+  // 2. Emit code computing the upper bound of the coalesced loop as product
+  // of the number of iterations of all loops.
+  OpBuilder builder(outermost);
+  Location loc = outermost.getLoc();
+  Value *upperBound = outermost.upperBound();
+  for (auto loop : loops.drop_front())
+    upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
+  outermost.setUpperBound(upperBound);
+
+  builder.setInsertionPointToStart(outermost.getBody());
+
+  // 3. Remap induction variables.  For each original loop, the value of the
+  // induction variable can be obtained by dividing the induction variable of
+  // the linearized loop by the total number of iterations of the loops nested
+  // in it modulo the number of iterations in this loop (remove the values
+  // related to the outer loops):
+  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+  // Compute these iteratively from the innermost loop by creating a "running
+  // quotient" of division by the range.
+  Value *previous = outermost.getInductionVar();
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    unsigned idx = loops.size() - i - 1;
+    if (i != 0)
+      previous =
+          builder.create<DivISOp>(loc, previous, loops[idx + 1].upperBound());
+
+    Value *iv = (i == e - 1) ? previous
+                             : builder.create<RemISOp>(loc, previous,
+                                                       loops[idx].upperBound());
+    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
+                               loops.back().region());
+  }
+
+  // 4. Move the operations from the innermost just above the second-outermost
+  // loop, delete the extra terminator and the second-outermost loop.
+  loop::ForOp second = loops[1];
+  innermost.getBody()->back().erase();
+  outermost.getBody()->getOperations().splice(
+      Block::iterator(second.getOperation()),
+      innermost.getBody()->getOperations());
+  second.erase();
+}
+
+void mlir::mapLoopToProcessorIds(loop::ForOp forOp,
+                                 ArrayRef<Value *> processorId,
+                                 ArrayRef<Value *> numProcessors) {
+  assert(processorId.size() == numProcessors.size());
+  if (processorId.empty())
+    return;
+
+  OpBuilder b(forOp);
+  Location loc(forOp.getLoc());
+  Value *mul = processorId.front();
+  for (unsigned i = 1, e = processorId.size(); i < e; ++i)
+    mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
+                           processorId[i]);
+  Value *lb = b.create<AddIOp>(loc, forOp.lowerBound(), mul);
+  forOp.setLowerBound(lb);
+
+  Value *step = numProcessors.front();
+  for (auto *numProcs : numProcessors.drop_front())
+    step = b.create<MulIOp>(loc, step, numProcs);
+  forOp.setStep(step);
+}
diff --git a/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp b/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp
new file mode 100644
index 0000000..a2b4fe3
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -0,0 +1,55 @@
+//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+
+void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
+                                      Region &region) {
+  for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) {
+    if (region.isAncestor(use.getOwner()->getParentRegion()))
+      use.set(replacement);
+  }
+}
+
+void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
+                                     llvm::SetVector<Value *> &values) {
+  assert(limit.isAncestor(&region) &&
+         "expected isolation limit to be an ancestor of the given region");
+
+  // Collect proper ancestors of `limit` upfront to avoid traversing the region
+  // tree for every value.
+  llvm::SmallPtrSet<Region *, 4> properAncestors;
+  for (auto *reg = limit.getParentRegion(); reg != nullptr;
+       reg = reg->getParentRegion()) {
+    properAncestors.insert(reg);
+  }
+
+  region.walk([&values, &properAncestors](Operation *op) {
+    for (Value *operand : op->getOperands())
+      // Collect values that are used by an operation and defined in a proper
+      // ancestor of region.
+      if (properAncestors.count(operand->getParentRegion()))
+        values.insert(operand);
+  });
+}
diff --git a/third_party/mlir/lib/Transforms/Utils/Utils.cpp b/third_party/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644
index 0000000..250c769
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Utils/Utils.cpp
@@ -0,0 +1,351 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(Operation &op) {
+  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+      isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
+    return true;
+  return false;
+}
+
+/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
+static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
+  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+    return loadOp.getAffineMapAttrForMemRef(memref);
+  else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+    return storeOp.getAffineMapAttrForMemRef(memref);
+  else if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op))
+    return dmaStart.getAffineMapAttrForMemRef(memref);
+  assert(isa<AffineDmaWaitOp>(op));
+  return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
+}
+
+bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                    ArrayRef<Value *> extraIndices,
+                                    AffineMap indexRemap,
+                                    ArrayRef<Value *> extraOperands,
+                                    Operation *domInstFilter,
+                                    Operation *postDomInstFilter) {
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank; // unused in opt mode
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank;
+  if (indexRemap) {
+    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
+    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+         newMemRef->getType().cast<MemRefType>().getElementType());
+
+  std::unique_ptr<DominanceInfo> domInfo;
+  std::unique_ptr<PostDominanceInfo> postDomInfo;
+  if (domInstFilter)
+    domInfo = llvm::make_unique<DominanceInfo>(
+        domInstFilter->getParentOfType<FuncOp>());
+
+  if (postDomInstFilter)
+    postDomInfo = llvm::make_unique<PostDominanceInfo>(
+        postDomInstFilter->getParentOfType<FuncOp>());
+
+  // The ops where memref replacement succeeds are replaced with new ones.
+  SmallVector<Operation *, 8> opsToErase;
+
+  // Walk all uses of old memref. Operation using the memref gets replaced.
+  for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) {
+    // Skip this use if it's not dominated by domInstFilter.
+    if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
+      continue;
+
+    // Skip this use if it's not post-dominated by postDomInstFilter.
+    if (postDomInstFilter &&
+        !postDomInfo->postDominates(postDomInstFilter, opInst))
+      continue;
+
+    // Skip dealloc's - no replacement is necessary, and a replacement doesn't
+    // hurt dealloc's.
+    if (isa<DeallocOp>(opInst))
+      continue;
+
+    // Check if the memref was used in a non-deferencing context. It is fine for
+    // the memref to be used in a non-deferencing way outside of the region
+    // where this replacement is happening.
+    if (!isMemRefDereferencingOp(*opInst))
+      // Failure: memref used in a non-deferencing op (potentially escapes); no
+      // replacement in these cases.
+      return false;
+
+    auto getMemRefOperandPos = [&]() -> unsigned {
+      unsigned i, e;
+      for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
+        if (opInst->getOperand(i) == oldMemRef)
+          break;
+      }
+      assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
+      return i;
+    };
+
+    OpBuilder builder(opInst);
+    unsigned memRefOperandPos = getMemRefOperandPos();
+    NamedAttribute oldMapAttrPair =
+        getAffineMapAttrForMemRef(opInst, oldMemRef);
+    AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
+    unsigned oldMapNumInputs = oldMap.getNumInputs();
+    SmallVector<Value *, 4> oldMapOperands(
+        opInst->operand_begin() + memRefOperandPos + 1,
+        opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+    SmallVector<Value *, 4> affineApplyOps;
+
+    // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
+    SmallVector<Value *, 4> oldMemRefOperands;
+    oldMemRefOperands.reserve(oldMemRefRank);
+    if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+      for (auto resultExpr : oldMap.getResults()) {
+        auto singleResMap = builder.getAffineMap(
+            oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
+        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
+                                                  singleResMap, oldMapOperands);
+        oldMemRefOperands.push_back(afOp);
+        affineApplyOps.push_back(afOp);
+      }
+    } else {
+      oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+    }
+
+    // Construct new indices as a remap of the old ones if a remapping has been
+    // provided. The indices of a memref come right after it, i.e.,
+    // at position memRefOperandPos + 1.
+    SmallVector<Value *, 4> remapOperands;
+    remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+    remapOperands.append(extraOperands.begin(), extraOperands.end());
+    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+
+    SmallVector<Value *, 4> remapOutputs;
+    remapOutputs.reserve(oldMemRefRank);
+
+    if (indexRemap &&
+        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+      // Remapped indices.
+      for (auto resultExpr : indexRemap.getResults()) {
+        auto singleResMap = builder.getAffineMap(
+            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
+                                                  singleResMap, remapOperands);
+        remapOutputs.push_back(afOp);
+        affineApplyOps.push_back(afOp);
+      }
+    } else {
+      // No remapping specified.
+      remapOutputs.append(remapOperands.begin(), remapOperands.end());
+    }
+
+    SmallVector<Value *, 4> newMapOperands;
+    newMapOperands.reserve(newMemRefRank);
+
+    // Prepend 'extraIndices' in 'newMapOperands'.
+    for (auto *extraIndex : extraIndices) {
+      assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
+             "single result op's expected to generate these indices");
+      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+             "invalid memory op index");
+      newMapOperands.push_back(extraIndex);
+    }
+
+    // Append 'remapOutputs' to 'newMapOperands'.
+    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+    // Create new fully composed AffineMap for new op to be created.
+    assert(newMapOperands.size() == newMemRefRank);
+    auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
+    // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
+    fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
+    newMap = simplifyAffineMap(newMap);
+    canonicalizeMapAndOperands(&newMap, &newMapOperands);
+    // Remove any affine.apply's that became dead as a result of composition.
+    for (auto *value : affineApplyOps)
+      if (value->use_empty())
+        value->getDefiningOp()->erase();
+
+    // Construct the new operation using this memref.
+    OperationState state(opInst->getLoc(), opInst->getName());
+    state.setOperandListToResizable(opInst->hasResizableOperandsList());
+    state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
+    // Insert the non-memref operands.
+    state.operands.append(opInst->operand_begin(),
+                          opInst->operand_begin() + memRefOperandPos);
+    // Insert the new memref value.
+    state.operands.push_back(newMemRef);
+
+    // Insert the new memref map operands.
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+    // Insert the remaining operands unmodified.
+    state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 +
+                              oldMapNumInputs,
+                          opInst->operand_end());
+
+    // Result types don't change. Both memref's are of the same elemental type.
+    state.types.reserve(opInst->getNumResults());
+    for (auto *result : opInst->getResults())
+      state.types.push_back(result->getType());
+
+    // Add attribute for 'newMap', other Attributes do not change.
+    auto newMapAttr = builder.getAffineMapAttr(newMap);
+    for (auto namedAttr : opInst->getAttrs()) {
+      if (namedAttr.first == oldMapAttrPair.first) {
+        state.attributes.push_back({namedAttr.first, newMapAttr});
+      } else {
+        state.attributes.push_back(namedAttr);
+      }
+    }
+
+    // Create the new operation.
+    auto *repOp = builder.createOperation(state);
+    opInst->replaceAllUsesWith(repOp);
+
+    // Collect and erase at the end since one of these op's could be
+    // domInstFilter or postDomInstFilter as well!
+    opsToErase.push_back(opInst);
+  }
+
+  for (auto *opInst : opsToErase)
+    opInst->erase();
+
+  return true;
+}
+
+/// Given an operation, inserts one or more single result affine
+/// apply operations, results of which are exclusively used by this operation
+/// operation. The operands of these newly created affine apply ops are
+/// guaranteed to be loop iterators or terminal symbols of a function.
+///
+/// Before
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   "compute"(%idx)
+///
+/// After
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (for eg.
+/// different shifts/delays).
+///
+/// Returns nullptr either if none of opInst's operands were the result of an
+/// affine.apply and thus there was no affine computation slice to create, or if
+/// all the affine.apply op's supplying operands to this opInst did not have any
+/// uses besides this opInst; otherwise returns the list of affine.apply
+/// operations created in output argument `sliceOps`.
+void mlir::createAffineComputationSlice(
+    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
+  // Collect all operands that are results of affine apply ops.
+  SmallVector<Value *, 4> subOperands;
+  subOperands.reserve(opInst->getNumOperands());
+  for (auto *operand : opInst->getOperands())
+    if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
+      subOperands.push_back(operand);
+
+  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+  SmallVector<Operation *, 4> affineApplyOps;
+  getReachableAffineApplyOps(subOperands, affineApplyOps);
+  // Skip transforming if there are no affine maps to compose.
+  if (affineApplyOps.empty())
+    return;
+
+  // Check if all uses of the affine apply op's lie only in this op op, in
+  // which case there would be nothing to do.
+  bool localized = true;
+  for (auto *op : affineApplyOps) {
+    for (auto *result : op->getResults()) {
+      for (auto *user : result->getUsers()) {
+        if (user != opInst) {
+          localized = false;
+          break;
+        }
+      }
+    }
+  }
+  if (localized)
+    return;
+
+  OpBuilder builder(opInst);
+  SmallVector<Value *, 4> composedOpOperands(subOperands);
+  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
+  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
+
+  // Create an affine.apply for each of the map results.
+  sliceOps->reserve(composedMap.getNumResults());
+  for (auto resultExpr : composedMap.getResults()) {
+    auto singleResMap = builder.getAffineMap(
+        composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr);
+    sliceOps->push_back(builder.create<AffineApplyOp>(
+        opInst->getLoc(), singleResMap, composedOpOperands));
+  }
+
+  // Construct the new operands that include the results from the composed
+  // affine apply op above instead of existing ones (subOperands). So, they
+  // differ from opInst's operands only for those operands in 'subOperands', for
+  // which they will be replaced by the corresponding one from 'sliceOps'.
+  SmallVector<Value *, 4> newOperands(opInst->getOperands());
+  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+    // Replace the subOperands from among the new operands.
+    unsigned j, f;
+    for (j = 0, f = subOperands.size(); j < f; j++) {
+      if (newOperands[i] == subOperands[j])
+        break;
+    }
+    if (j < subOperands.size()) {
+      newOperands[i] = (*sliceOps)[j];
+    }
+  }
+  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
+    opInst->setOperand(idx, newOperands[idx]);
+  }
+}
diff --git a/third_party/mlir/lib/Transforms/Vectorize.cpp b/third_party/mlir/lib/Transforms/Vectorize.cpp
new file mode 100644
index 0000000..ce25406
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/Vectorize.cpp
@@ -0,0 +1,1286 @@
+//===- Vectorize.cpp - Vectorize Pass Impl --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements vectorization of loops, operations and data types to
+// a target-independent, n-D super-vector abstraction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+///
+/// Implements a high-level vectorization strategy on a Function.
+/// The abstraction used is that of super-vectors, which provide a single,
+/// compact, representation in the vector types, information that is expected
+/// to reduce the impact of the phase ordering problem
+///
+/// Vector granularity:
+/// ===================
+/// This pass is designed to perform vectorization at a super-vector
+/// granularity. A super-vector is loosely defined as a vector type that is a
+/// multiple of a "good" vector size so the HW can efficiently implement a set
+/// of high-level primitives. Multiple is understood along any dimension; e.g.
+/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a
+/// vector<8xf32> HW vector. Note that a "good vector size so the HW can
+/// efficiently implement a set of high-level primitives" is not necessarily an
+/// integer multiple of actual hardware registers. We leave details of this
+/// distinction unspecified for now.
+///
+/// Some may prefer the terminology a "tile of HW vectors". In this case, one
+/// should note that super-vectors implement an "always full tile" abstraction.
+/// They guarantee no partial-tile separation is necessary by relying on a
+/// high-level copy-reshape abstraction that we call vector.transfer. This
+/// copy-reshape operations is also responsible for performing layout
+/// transposition if necessary. In the general case this will require a scoped
+/// allocation in some notional local memory.
+///
+/// Whatever the mental model one prefers to use for this abstraction, the key
+/// point is that we burn into a single, compact, representation in the vector
+/// types, information that is expected to reduce the impact of the phase
+/// ordering problem. Indeed, a vector type conveys information that:
+///   1. the associated loops have dependency semantics that do not prevent
+///      vectorization;
+///   2. the associate loops have been sliced in chunks of static sizes that are
+///      compatible with vector sizes (i.e. similar to unroll-and-jam);
+///   3. the inner loops, in the unroll-and-jam analogy of 2, are captured by
+///   the
+///      vector type and no vectorization hampering transformations can be
+///      applied to them anymore;
+///   4. the underlying memrefs are accessed in some notional contiguous way
+///      that allows loading into vectors with some amount of spatial locality;
+/// In other words, super-vectorization provides a level of separation of
+/// concern by way of opacity to subsequent passes. This has the effect of
+/// encapsulating and propagating vectorization constraints down the list of
+/// passes until we are ready to lower further.
+///
+/// For a particular target, a notion of minimal n-d vector size will be
+/// specified and vectorization targets a multiple of those. In the following
+/// paragraph, let "k ." represent "a multiple of", to be understood as a
+/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
+/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
+///
+/// Some non-exhaustive notable super-vector sizes of interest include:
+///   - CPU: vector<k . HW_vector_size>,
+///          vector<k' . core_count x k . HW_vector_size>,
+///          vector<socket_count x k' . core_count x k . HW_vector_size>;
+///   - GPU: vector<k . warp_size>,
+///          vector<k . warp_size x float2>,
+///          vector<k . warp_size x float4>,
+///          vector<k . warp_size x 4 x 4x 4> (for tensor_core sizes).
+///
+/// Loops and operations are emitted that operate on those super-vector shapes.
+/// Subsequent lowering passes will materialize to actual HW vector sizes. These
+/// passes are expected to be (gradually) more target-specific.
+///
+/// At a high level, a vectorized load in a loop will resemble:
+/// ```mlir
+///   affine.for %i = ? to ? step ? {
+///     %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
+///   }
+/// ```
+/// It is the responsibility of the implementation of vector.transfer_read to
+/// materialize vector registers from the original scalar memrefs. A later (more
+/// target-dependent) lowering pass will materialize to actual HW vector sizes.
+/// This lowering may be occur at different times:
+///   1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
+///      DmaWaitOp + vectorized operations for data transformations and shuffle;
+///      thus opening opportunities for unrolling and pipelining. This is an
+///      instance of library call "whiteboxing"; or
+///   2. later in the a target-specific lowering pass or hand-written library
+///      call; achieving full separation of concerns. This is an instance of
+///      library call; or
+///   3. a mix of both, e.g. based on a model.
+/// In the future, these operations will expose a contract to constrain the
+/// search on vectorization patterns and sizes.
+///
+/// Occurrence of super-vectorization in the compiler flow:
+/// =======================================================
+/// This is an active area of investigation. We start with 2 remarks to position
+/// super-vectorization in the context of existing ongoing work: LLVM VPLAN
+/// and LLVM SLP Vectorizer.
+///
+/// LLVM VPLAN:
+/// -----------
+/// The astute reader may have noticed that in the limit, super-vectorization
+/// can be applied at a similar time and with similar objectives than VPLAN.
+/// For instance, in the case of a traditional, polyhedral compilation-flow (for
+/// instance, the PPCG project uses ISL to provide dependence analysis,
+/// multi-level(scheduling + tiling), lifting footprint to fast memory,
+/// communication synthesis, mapping, register optimizations) and before
+/// unrolling. When vectorization is applied at this *late* level in a typical
+/// polyhedral flow, and is instantiated with actual hardware vector sizes,
+/// super-vectorization is expected to match (or subsume) the type of patterns
+/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
+/// is higher level and our implementation should be significantly simpler. Also
+/// note that in this mode, recursive patterns are probably a bit of an overkill
+/// although it is reasonable to expect that mixing a bit of outer loop and
+/// inner loop vectorization + unrolling will provide interesting choices to
+/// MLIR.
+///
+/// LLVM SLP Vectorizer:
+/// --------------------
+/// Super-vectorization however is not meant to be usable in a similar fashion
+/// to the SLP vectorizer. The main difference lies in the information that
+/// both vectorizers use: super-vectorization examines contiguity of memory
+/// references along fastest varying dimensions and loops with recursive nested
+/// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on
+/// the other hand, performs flat pattern matching inside a single unrolled loop
+/// body and stitches together pieces of load and store operations into full
+/// 1-D vectors. We envision that the SLP vectorizer is a good way to capture
+/// innermost loop, control-flow dependent patterns that super-vectorization may
+/// not be able to capture easily. In other words, super-vectorization does not
+/// aim at replacing the SLP vectorizer and the two solutions are complementary.
+///
+/// Ongoing investigations:
+/// -----------------------
+/// We discuss the following *early* places where super-vectorization is
+/// applicable and touch on the expected benefits and risks . We list the
+/// opportunities in the context of the traditional polyhedral compiler flow
+/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
+/// we expect to experiment with super-vectorization:
+/// 1. Right after language lowering to MLIR: this is the earliest time where
+///    super-vectorization is expected to be applied. At this level, all the
+///    language/user/library-level annotations are available and can be fully
+///    exploited. Examples include loop-type annotations (such as parallel,
+///    reduction, scan, dependence distance vector, vectorizable) as well as
+///    memory access annotations (such as non-aliasing writes guaranteed,
+///    indirect accesses that are permutations by construction) accesses or
+///    that a particular operation is prescribed atomic by the user. At this
+///    level, anything that enriches what dependence analysis can do should be
+///    aggressively exploited. At this level we are close to having explicit
+///    vector types in the language, except we do not impose that burden on the
+///    programmer/library: we derive information from scalar code + annotations.
+/// 2. After dependence analysis and before polyhedral scheduling: the
+///    information that supports vectorization does not need to be supplied by a
+///    higher level of abstraction. Traditional dependence anaysis is available
+///    in MLIR and will be used to drive vectorization and cost models.
+///
+/// Let's pause here and remark that applying super-vectorization as described
+/// in 1. and 2. presents clear opportunities and risks:
+///   - the opportunity is that vectorization is burned in the type system and
+///   is protected from the adverse effect of loop scheduling, tiling, loop
+///   interchange and all passes downstream. Provided that subsequent passes are
+///   able to operate on vector types; the vector shapes, associated loop
+///   iterator properties, alignment, and contiguity of fastest varying
+///   dimensions are preserved until we lower the super-vector types. We expect
+///   this to significantly rein in on the adverse effects of phase ordering.
+///   - the risks are that a. all passes after super-vectorization have to work
+///   on elemental vector types (not that this is always true, wherever
+///   vectorization is applied) and b. that imposing vectorization constraints
+///   too early may be overall detrimental to loop fusion, tiling and other
+///   transformations because the dependence distances are coarsened when
+///   operating on elemental vector types. For this reason, the pattern
+///   profitability analysis should include a component that also captures the
+///   maximal amount of fusion available under a particular pattern. This is
+///   still at the stage of rought ideas but in this context, search is our
+///   friend as the Tensor Comprehensions and auto-TVM contributions
+///   demonstrated previously.
+/// Bottom-line is we do not yet have good answers for the above but aim at
+/// making it easy to answer such questions.
+///
+/// Back to our listing, the last places where early super-vectorization makes
+/// sense are:
+/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known
+///    to improve locality, parallelism and be configurable (e.g. max-fuse,
+///    smart-fuse etc). They can also have adverse effects on contiguity
+///    properties that are required for vectorization but the vector.transfer
+///    copy-reshape-pad-transpose abstraction is expected to help recapture
+///    these properties.
+/// 4. right after polyhedral-style scheduling+tiling;
+/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
+///    probably the most promising places because applying tiling achieves a
+///    separation of concerns that allows rescheduling to worry less about
+///    locality and more about parallelism and distribution (e.g. min-fuse).
+///
+/// At these levels the risk-reward looks different: on one hand we probably
+/// lost a good deal of language/user/library-level annotation; on the other
+/// hand we gained parallelism and locality through scheduling and tiling.
+/// However we probably want to ensure tiling is compatible with the
+/// full-tile-only abstraction used in super-vectorization or suffer the
+/// consequences. It is too early to place bets on what will win but we expect
+/// super-vectorization to be the right abstraction to allow exploring at all
+/// these levels. And again, search is our friend.
+///
+/// Lastly, we mention it again here:
+/// 6. as a MLIR-based alternative to VPLAN.
+///
+/// Lowering, unrolling, pipelining:
+/// ================================
+/// TODO(ntv): point to the proper places.
+///
+/// Algorithm:
+/// ==========
+/// The algorithm proceeds in a few steps:
+///  1. defining super-vectorization patterns and matching them on the tree of
+///     AffineForOp. A super-vectorization pattern is defined as a recursive
+///     data structures that matches and captures nested, imperfectly-nested
+///     loops that have a. comformable loop annotations attached (e.g. parallel,
+///     reduction, vectoriable, ...) as well as b. all contiguous load/store
+///     operations along a specified minor dimension (not necessarily the
+///     fastest varying) ;
+///  2. analyzing those patterns for profitability (TODO(ntv): and
+///     interference);
+///  3. Then, for each pattern in order:
+///    a. applying iterative rewriting of the loop and the load operations in
+///       DFS postorder. Rewriting is implemented by coarsening the loops and
+///       turning load operations into opaque vector.transfer_read ops;
+///    b. keeping track of the load operations encountered as "roots" and the
+///       store operations as "terminals";
+///    c. traversing the use-def chains starting from the roots and iteratively
+///       propagating vectorized values. Scalar values that are encountered
+///       during this process must come from outside the scope of the current
+///       pattern (TODO(ntv): enforce this and generalize). Such a scalar value
+///       is vectorized only if it is a constant (into a vector splat). The
+///       non-constant case is not supported for now and results in the pattern
+///       failing to vectorize;
+///    d. performing a second traversal on the terminals (store ops) to
+///       rewriting the scalar value they write to memory into vector form.
+///       If the scalar value has been vectorized previously, we simply replace
+///       it by its vector form. Otherwise, if the scalar value is a constant,
+///       it is vectorized into a splat. In all other cases, vectorization for
+///       the pattern currently fails.
+///    e. if everything under the root AffineForOp in the current pattern
+///       vectorizes properly, we commit that loop to the IR. Otherwise we
+///       discard it and restore a previously cloned version of the loop. Thanks
+///       to the recursive scoping nature of matchers and captured patterns,
+///       this is transparently achieved by a simple RAII implementation.
+///    f. vectorization is applied on the next pattern in the list. Because
+///       pattern interference avoidance is not yet implemented and that we do
+///       not support further vectorizing an already vector load we need to
+///       re-verify that the pattern is still vectorizable. This is expected to
+///       make cost models more difficult to write and is subject to improvement
+///       in the future.
+///
+/// Points c. and d. above are worth additional comment. In most passes that
+/// do not change the type of operands, it is usually preferred to eagerly
+/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
+/// because during the use-def chain traversal, all the operands of an operation
+/// must be available in vector form. Trying to propagate eagerly makes the IR
+/// temporarily invalid and results in errors such as:
+///   `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
+///   operands and results
+///      %s5 = addf %a5, %b5 : f32`
+///
+/// Lastly, we show a minimal example for which use-def chains rooted in load /
+/// vector.transfer_read are not enough. This is what motivated splitting
+/// terminal processing out of the use-def chains starting from loads. In the
+/// following snippet, there is simply no load::
+/// ```mlir
+/// mlfunc @fill(%A : memref<128xf32>) -> () {
+///   %f1 = constant 1.0 : f32
+///   affine.for %i0 = 0 to 32 {
+///     store %f1, %A[%i0] : memref<128xf32, 0>
+///   }
+///   return
+/// }
+/// ```
+///
+/// Choice of loop transformation to support the algorithm:
+/// =======================================================
+/// The choice of loop transformation to apply for coarsening vectorized loops
+/// is still subject to exploratory tradeoffs. In particular, say we want to
+/// vectorize by a factor 128, we want to transform the following input:
+/// ```mlir
+///   affine.for %i = %M to %N {
+///     %a = load A[%i] : memref<?xf32>
+///   }
+/// ```
+///
+/// Traditionally, one would vectorize late (after scheduling, tiling,
+/// memory promotion etc) say after stripmining (and potentially unrolling in
+/// the case of LLVM's SLP vectorizer):
+/// ```mlir
+///   affine.for %i = floor(%M, 128) to ceil(%N, 128) {
+///     affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
+///       %a = load A[%ii] : memref<?xf32>
+///     }
+///   }
+/// ```
+///
+/// Instead, we seek to vectorize early and freeze vector types before
+/// scheduling, so we want to generate a pattern that resembles:
+/// ```mlir
+///   affine.for %i = ? to ? step ? {
+///     %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
+///   }
+/// ```
+///
+/// i. simply dividing the lower / upper bounds by 128 creates issues
+///    when representing expressions such as ii + 1 because now we only
+///    have access to original values that have been divided. Additional
+///    information is needed to specify accesses at below-128 granularity;
+/// ii. another alternative is to coarsen the loop step but this may have
+///    consequences on dependence analysis and fusability of loops: fusable
+///    loops probably need to have the same step (because we don't want to
+///    stripmine/unroll to enable fusion).
+/// As a consequence, we choose to represent the coarsening using the loop
+/// step for now and reevaluate in the future. Note that we can renormalize
+/// loop steps later if/when we have evidence that they are problematic.
+///
+/// For the simple strawman example above, vectorizing for a 1-D vector
+/// abstraction of size 128 returns code similar to:
+/// ```mlir
+///   affine.for %i = %M to %N step 128 {
+///     %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
+///   }
+/// ```
+///
+/// Unsupported cases, extensions, and work in progress (help welcome :-) ):
+/// ========================================================================
+///   1. lowering to concrete vector types for various HW;
+///   2. reduction support;
+///   3. non-effecting padding during vector.transfer_read and filter during
+///      vector.transfer_write;
+///   4. misalignment support vector.transfer_read / vector.transfer_write
+///      (hopefully without read-modify-writes);
+///   5. control-flow support;
+///   6. cost-models, heuristics and search;
+///   7. Op implementation, extensions and implication on memref views;
+///   8. many TODOs left around.
+///
+/// Examples:
+/// =========
+/// Consider the following Function:
+/// ```mlir
+/// mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
+///   %A = alloc (%M, %N) : memref<?x?xf32, 0>
+///   %B = alloc (%M, %N) : memref<?x?xf32, 0>
+///   %C = alloc (%M, %N) : memref<?x?xf32, 0>
+///   %f1 = constant 1.0 : f32
+///   %f2 = constant 2.0 : f32
+///   affine.for %i0 = 0 to %M {
+///     affine.for %i1 = 0 to %N {
+///       // non-scoped %f1
+///       store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
+///     }
+///   }
+///   affine.for %i2 = 0 to %M {
+///     affine.for %i3 = 0 to %N {
+///       // non-scoped %f2
+///       store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
+///     }
+///   }
+///   affine.for %i4 = 0 to %M {
+///     affine.for %i5 = 0 to %N {
+///       %a5 = load %A[%i4, %i5] : memref<?x?xf32, 0>
+///       %b5 = load %B[%i4, %i5] : memref<?x?xf32, 0>
+///       %s5 = addf %a5, %b5 : f32
+///       // non-scoped %f1
+///       %s6 = addf %s5, %f1 : f32
+///       // non-scoped %f2
+///       %s7 = addf %s5, %f2 : f32
+///       // diamond dependency.
+///       %s8 = addf %s7, %s6 : f32
+///       store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
+///     }
+///   }
+///   %c7 = constant 7 : index
+///   %c42 = constant 42 : index
+///   %res = load %C[%c7, %c42] : memref<?x?xf32, 0>
+///   return %res : f32
+/// }
+/// ```
+///
+/// TODO(ntv): update post b/119731251.
+/// The -vectorize pass with the following arguments:
+/// ```
+/// -vectorize -virtual-vector-size 256 --test-fastest-varying=0
+/// ```
+///
+/// produces this standard innermost-loop vectorized code:
+/// ```mlir
+/// mlfunc @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
+///   %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %cst = constant 1.0 : f32
+///   %cst_0 = constant 2.0 : f32
+///   affine.for %i0 = 0 to %arg0 {
+///     affine.for %i1 = 0 to %arg1 step 256 {
+///       %cst_1 = constant dense<vector<256xf32>, 1.0> :
+///                vector<256xf32>
+///       vector.transfer_write %cst_1, %0[%i0, %i1] :
+///                vector<256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   affine.for %i2 = 0 to %arg0 {
+///     affine.for %i3 = 0 to %arg1 step 256 {
+///       %cst_2 = constant dense<vector<256xf32>, 2.0> :
+///                vector<256xf32>
+///       vector.transfer_write %cst_2, %1[%i2, %i3] :
+///                vector<256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   affine.for %i4 = 0 to %arg0 {
+///     affine.for %i5 = 0 to %arg1 step 256 {
+///       %3 = vector.transfer_read %0[%i4, %i5] :
+///            memref<?x?xf32>, vector<256xf32>
+///       %4 = vector.transfer_read %1[%i4, %i5] :
+///            memref<?x?xf32>, vector<256xf32>
+///       %5 = addf %3, %4 : vector<256xf32>
+///       %cst_3 = constant dense<vector<256xf32>, 1.0> :
+///                vector<256xf32>
+///       %6 = addf %5, %cst_3 : vector<256xf32>
+///       %cst_4 = constant dense<vector<256xf32>, 2.0> :
+///                vector<256xf32>
+///       %7 = addf %5, %cst_4 : vector<256xf32>
+///       %8 = addf %7, %6 : vector<256xf32>
+///       vector.transfer_write %8, %2[%i4, %i5] :
+///                vector<256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   %c7 = constant 7 : index
+///   %c42 = constant 42 : index
+///   %9 = load %2[%c7, %c42] : memref<?x?xf32>
+///   return %9 : f32
+/// }
+/// ```
+///
+/// TODO(ntv): update post b/119731251.
+/// The -vectorize pass with the following arguments:
+/// ```
+/// -vectorize -virtual-vector-size 32 -virtual-vector-size 256
+/// --test-fastest-varying=1 --test-fastest-varying=0
+/// ```
+///
+/// produces this more insteresting mixed outer-innermost-loop vectorized code:
+/// ```mlir
+/// mlfunc @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
+///   %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+///   %cst = constant 1.0 : f32
+///   %cst_0 = constant 2.0 : f32
+///   affine.for %i0 = 0 to %arg0 step 32 {
+///     affine.for %i1 = 0 to %arg1 step 256 {
+///       %cst_1 = constant dense<vector<32x256xf32>, 1.0> :
+///                vector<32x256xf32>
+///       vector.transfer_write %cst_1, %0[%i0, %i1] :
+///                vector<32x256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   affine.for %i2 = 0 to %arg0 step 32 {
+///     affine.for %i3 = 0 to %arg1 step 256 {
+///       %cst_2 = constant dense<vector<32x256xf32>, 2.0> :
+///                vector<32x256xf32>
+///       vector.transfer_write %cst_2, %1[%i2, %i3] :
+///                vector<32x256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   affine.for %i4 = 0 to %arg0 step 32 {
+///     affine.for %i5 = 0 to %arg1 step 256 {
+///       %3 = vector.transfer_read %0[%i4, %i5] :
+///                memref<?x?xf32> vector<32x256xf32>
+///       %4 = vector.transfer_read %1[%i4, %i5] :
+///                memref<?x?xf32>, vector<32x256xf32>
+///       %5 = addf %3, %4 : vector<32x256xf32>
+///       %cst_3 = constant dense<vector<32x256xf32>, 1.0> :
+///                vector<32x256xf32>
+///       %6 = addf %5, %cst_3 : vector<32x256xf32>
+///       %cst_4 = constant dense<vector<32x256xf32>, 2.0> :
+///                vector<32x256xf32>
+///       %7 = addf %5, %cst_4 : vector<32x256xf32>
+///       %8 = addf %7, %6 : vector<32x256xf32>
+///       vector.transfer_write %8, %2[%i4, %i5] :
+///                vector<32x256xf32>, memref<?x?xf32>
+///     }
+///   }
+///   %c7 = constant 7 : index
+///   %c42 = constant 42 : index
+///   %9 = load %2[%c7, %c42] : memref<?x?xf32>
+///   return %9 : f32
+/// }
+/// ```
+///
+/// Of course, much more intricate n-D imperfectly-nested patterns can be
+/// vectorized too and specified in a fully declarative fashion.
+
+#define DEBUG_TYPE "early-vect"
+
+using functional::makePtrDynCaster;
+using functional::map;
+using llvm::dbgs;
+using llvm::SetVector;
+
+static llvm::cl::OptionCategory clOptionsCategory("vectorize options");
+
+static llvm::cl::list<int> clVirtualVectorSize(
+    "virtual-vector-size",
+    llvm::cl::desc("Specify an n-D virtual vector size for vectorization"),
+    llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::list<int> clFastestVaryingPattern(
+    "test-fastest-varying",
+    llvm::cl::desc(
+        "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory"
+        " dimensions to match. See defaultPatterns in Vectorize.cpp for a"
+        " description and examples. This is used for testing purposes"),
+    llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+/// Forward declaration.
+static FilterFunctionType
+isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
+                             int fastestVaryingMemRefDimension);
+
+/// Creates a vectorization pattern from the command line arguments.
+/// Up to 3-D patterns are supported.
+/// If the command line argument requests a pattern of higher order, returns an
+/// empty pattern list which will conservatively result in no vectorization.
+static std::vector<NestedPattern>
+makePatterns(const llvm::DenseSet<Operation *> &parallelLoops, int vectorRank,
+             ArrayRef<int64_t> fastestVaryingPattern) {
+  using matcher::For;
+  int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
+  int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
+  int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
+  switch (vectorRank) {
+  case 1:
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0))};
+  case 2:
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+                For(isVectorizableLoopPtrFactory(parallelLoops, d1)))};
+  case 3:
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+                For(isVectorizableLoopPtrFactory(parallelLoops, d1),
+                    For(isVectorizableLoopPtrFactory(parallelLoops, d2))))};
+  default: {
+    return std::vector<NestedPattern>();
+  }
+  }
+}
+
+namespace {
+
+/// Base state for the vectorize pass.
+/// Command line arguments are preempted by non-empty pass arguments.
+struct Vectorize : public FunctionPass<Vectorize> {
+  Vectorize();
+  Vectorize(ArrayRef<int64_t> virtualVectorSize);
+  void runOnFunction() override;
+
+  // The virtual vector size that we vectorize to.
+  SmallVector<int64_t, 4> vectorSizes;
+  // Optionally, the fixed mapping from loop to fastest varying MemRef dimension
+  // for all the MemRefs within a loop pattern:
+  //   the index represents the loop depth, the value represents the k^th
+  //   fastest varying memory dimension.
+  // This is voluntarily restrictive and is meant to precisely target a
+  // particular loop/op pair, for testing purposes.
+  SmallVector<int64_t, 4> fastestVaryingPattern;
+};
+
+} // end anonymous namespace
+
+Vectorize::Vectorize()
+    : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()),
+      fastestVaryingPattern(clFastestVaryingPattern.begin(),
+                            clFastestVaryingPattern.end()) {}
+
+Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
+  if (!virtualVectorSize.empty()) {
+    this->vectorSizes.assign(virtualVectorSize.begin(),
+                             virtualVectorSize.end());
+  }
+}
+
+/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate.
+/////////
+namespace {
+
+struct VectorizationStrategy {
+  SmallVector<int64_t, 8> vectorSizes;
+  DenseMap<Operation *, unsigned> loopToVectorDim;
+};
+
+} // end anonymous namespace
+
+static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
+                                      unsigned patternDepth,
+                                      VectorizationStrategy *strategy) {
+  assert(patternDepth > depthInPattern &&
+         "patternDepth is greater than depthInPattern");
+  if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
+    // Don't vectorize this loop
+    return;
+  }
+  strategy->loopToVectorDim[loop] =
+      strategy->vectorSizes.size() - (patternDepth - depthInPattern);
+}
+
+/// Implements a simple strawman strategy for vectorization.
+/// Given a matched pattern `matches` of depth `patternDepth`, this strategy
+/// greedily assigns the fastest varying dimension ** of the vector ** to the
+/// innermost loop in the pattern.
+/// When coupled with a pattern that looks for the fastest varying dimension in
+/// load/store MemRefs, this creates a generic vectorization strategy that works
+/// for any loop in a hierarchy (outermost, innermost or intermediate).
+///
+/// TODO(ntv): In the future we should additionally increase the power of the
+/// profitability analysis along 3 directions:
+///   1. account for loop extents (both static and parametric + annotations);
+///   2. account for data layout permutations;
+///   3. account for impact of vectorization on maximal loop fusion.
+/// Then we can quantify the above to build a cost model and search over
+/// strategies.
+static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
+                                          unsigned depthInPattern,
+                                          unsigned patternDepth,
+                                          VectorizationStrategy *strategy) {
+  for (auto m : matches) {
+    if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
+                                    patternDepth, strategy))) {
+      return failure();
+    }
+    vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern,
+                              patternDepth, strategy);
+  }
+  return success();
+}
+
+///// end TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate /////
+
+namespace {
+
+struct VectorizationState {
+  /// Adds an entry of pre/post vectorization operations in the state.
+  void registerReplacement(Operation *key, Operation *value);
+  /// When the current vectorization pattern is successful, this erases the
+  /// operations that were marked for erasure in the proper order and resets
+  /// the internal state for the next pattern.
+  void finishVectorizationPattern();
+
+  // In-order tracking of original Operation that have been vectorized.
+  // Erase in reverse order.
+  SmallVector<Operation *, 16> toErase;
+  // Set of Operation that have been vectorized (the values in the
+  // vectorizationMap for hashed access). The vectorizedSet is used in
+  // particular to filter the operations that have already been vectorized by
+  // this pattern, when iterating over nested loops in this pattern.
+  DenseSet<Operation *> vectorizedSet;
+  // Map of old scalar Operation to new vectorized Operation.
+  DenseMap<Operation *, Operation *> vectorizationMap;
+  // Map of old scalar Value to new vectorized Value.
+  DenseMap<Value *, Value *> replacementMap;
+  // The strategy drives which loop to vectorize by which amount.
+  const VectorizationStrategy *strategy;
+  // Use-def roots. These represent the starting points for the worklist in the
+  // vectorizeNonTerminals function. They consist of the subset of load
+  // operations that have been vectorized. They can be retrieved from
+  // `vectorizationMap` but it is convenient to keep track of them in a separate
+  // data structure.
+  DenseSet<Operation *> roots;
+  // Terminal operations for the worklist in the vectorizeNonTerminals
+  // function. They consist of the subset of store operations that have been
+  // vectorized. They can be retrieved from `vectorizationMap` but it is
+  // convenient to keep track of them in a separate data structure. Since they
+  // do not necessarily belong to use-def chains starting from loads (e.g
+  // storing a constant), we need to handle them in a post-pass.
+  DenseSet<Operation *> terminals;
+  // Checks that the type of `op` is AffineStoreOp and adds it to the terminals
+  // set.
+  void registerTerminal(Operation *op);
+
+private:
+  void registerReplacement(Value *key, Value *value);
+};
+
+} // end namespace
+
+void VectorizationState::registerReplacement(Operation *key, Operation *value) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
+  LLVM_DEBUG(key->print(dbgs()));
+  LLVM_DEBUG(dbgs() << "  into  ");
+  LLVM_DEBUG(value->print(dbgs()));
+  assert(key->getNumResults() == 1 && "already registered");
+  assert(value->getNumResults() == 1 && "already registered");
+  assert(vectorizedSet.count(value) == 0 && "already registered");
+  assert(vectorizationMap.count(key) == 0 && "already registered");
+  toErase.push_back(key);
+  vectorizedSet.insert(value);
+  vectorizationMap.insert(std::make_pair(key, value));
+  registerReplacement(key->getResult(0), value->getResult(0));
+  if (isa<AffineLoadOp>(key)) {
+    assert(roots.count(key) == 0 && "root was already inserted previously");
+    roots.insert(key);
+  }
+}
+
+void VectorizationState::registerTerminal(Operation *op) {
+  assert(isa<AffineStoreOp>(op) && "terminal must be a AffineStoreOp");
+  assert(terminals.count(op) == 0 &&
+         "terminal was already inserted previously");
+  terminals.insert(op);
+}
+
+void VectorizationState::finishVectorizationPattern() {
+  while (!toErase.empty()) {
+    auto *op = toErase.pop_back_val();
+    LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
+    LLVM_DEBUG(op->print(dbgs()));
+    op->erase();
+  }
+}
+
+void VectorizationState::registerReplacement(Value *key, Value *value) {
+  assert(replacementMap.count(key) == 0 && "replacement already registered");
+  replacementMap.insert(std::make_pair(key, value));
+}
+
+// Apply 'map' with 'mapOperands' returning resulting values in 'results'.
+static void computeMemoryOpIndices(Operation *op, AffineMap map,
+                                   ArrayRef<Value *> mapOperands,
+                                   SmallVectorImpl<Value *> &results) {
+  OpBuilder builder(op);
+  for (auto resultExpr : map.getResults()) {
+    auto singleResMap =
+        builder.getAffineMap(map.getNumDims(), map.getNumSymbols(), resultExpr);
+    auto afOp =
+        builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands);
+    results.push_back(afOp);
+  }
+}
+
+////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
+
+/// Handles the vectorization of load and store MLIR operations.
+///
+/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call.
+/// They are vectorized immediately. The resulting vector.transfer_read is
+/// immediately registered to replace all uses of the AffineLoadOp in this
+/// pattern's scope.
+///
+/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They
+/// need to be vectorized late once all the use-def chains have been traversed.
+/// Additionally, they may have ssa-values operands which come from outside the
+/// scope of the current pattern.
+/// Such special cases force us to delay the vectorization of the stores until
+/// the last step. Here we merely register the store operation.
+template <typename LoadOrStoreOpPointer>
+static LogicalResult vectorizeRootOrTerminal(Value *iv,
+                                             LoadOrStoreOpPointer memoryOp,
+                                             VectorizationState *state) {
+  auto memRefType = memoryOp.getMemRef()->getType().template cast<MemRefType>();
+
+  auto elementType = memRefType.getElementType();
+  // TODO(ntv): ponder whether we want to further vectorize a vector value.
+  assert(VectorType::isValidElementType(elementType) &&
+         "Not a valid vector element type");
+  auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
+
+  // Materialize a MemRef with 1 vector.
+  auto *opInst = memoryOp.getOperation();
+  // For now, vector.transfers must be aligned, operate only on indices with an
+  // identity subset of AffineMap and do not change layout.
+  // TODO(ntv): increase the expressiveness power of vector.transfer operations
+  // as needed by various targets.
+  if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
+    OpBuilder b(opInst);
+    SmallVector<Value *, 4> mapOperands(load.getIndices());
+    SmallVector<Value *, 8> indices;
+    indices.reserve(load.getMemRefType().getRank());
+    if (load.getAffineMap() !=
+        b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
+      computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
+    } else {
+      indices.append(load.getIndices().begin(), load.getIndices().end());
+    }
+    auto permutationMap =
+        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return LogicalResult::Failure;
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    auto transfer = b.create<vector::VectorTransferReadOp>(
+        opInst->getLoc(), vectorType, memoryOp.getMemRef(),
+        map(makePtrDynCaster<Value>(), indices), permutationMap);
+    state->registerReplacement(opInst, transfer.getOperation());
+  } else {
+    state->registerTerminal(opInst);
+  }
+  return success();
+}
+/// end TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ///
+
+/// Coarsens the loops bounds and transforms all remaining load and store
+/// operations into the appropriate vector.transfer.
+static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
+                                          VectorizationState *state) {
+  using namespace functional;
+  loop.setStep(step);
+
+  FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {
+    if (!matcher::isLoadOrStore(op)) {
+      return false;
+    }
+    return state->vectorizationMap.count(&op) == 0 &&
+           state->vectorizedSet.count(&op) == 0 &&
+           state->roots.count(&op) == 0 && state->terminals.count(&op) == 0;
+  };
+  auto loadAndStores = matcher::Op(notVectorizedThisPattern);
+  SmallVector<NestedMatch, 8> loadAndStoresMatches;
+  loadAndStores.match(loop.getOperation(), &loadAndStoresMatches);
+  for (auto ls : loadAndStoresMatches) {
+    auto *opInst = ls.getMatchedOperation();
+    auto load = dyn_cast<AffineLoadOp>(opInst);
+    auto store = dyn_cast<AffineStoreOp>(opInst);
+    LLVM_DEBUG(opInst->print(dbgs()));
+    LogicalResult result =
+        load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state)
+             : vectorizeRootOrTerminal(loop.getInductionVar(), store, state);
+    if (failed(result)) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Returns a FilterFunctionType that can be used in NestedPattern to match a
+/// loop whose underlying load/store accesses are either invariant or all
+// varying along the `fastestVaryingMemRefDimension`.
+static FilterFunctionType
+isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
+                             int fastestVaryingMemRefDimension) {
+  return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
+    auto loop = cast<AffineForOp>(forOp);
+    auto parallelIt = parallelLoops.find(loop);
+    if (parallelIt == parallelLoops.end())
+      return false;
+    int memRefDim = -1;
+    auto vectorizableBody = isVectorizableLoopBody(loop, &memRefDim);
+    if (!vectorizableBody)
+      return false;
+    return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
+           memRefDim == fastestVaryingMemRefDimension;
+  };
+}
+
+/// Apply vectorization of `loop` according to `state`. This is only triggered
+/// if all vectorizations in `childrenMatches` have already succeeded
+/// recursively in DFS post-order.
+static LogicalResult
+vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
+                                  VectorizationState *state) {
+  auto *loopInst = oneMatch.getMatchedOperation();
+  auto loop = cast<AffineForOp>(loopInst);
+  auto childrenMatches = oneMatch.getMatchedChildren();
+
+  // 1. DFS postorder recursion, if any of my children fails, I fail too.
+  for (auto m : childrenMatches) {
+    if (failed(vectorizeLoopsAndLoadsRecursively(m, state))) {
+      return failure();
+    }
+  }
+
+  // 2. This loop may have been omitted from vectorization for various reasons
+  // (e.g. due to the performance model or pattern depth > vector size).
+  auto it = state->strategy->loopToVectorDim.find(loopInst);
+  if (it == state->strategy->loopToVectorDim.end()) {
+    return success();
+  }
+
+  // 3. Actual post-order transformation.
+  auto vectorDim = it->second;
+  assert(vectorDim < state->strategy->vectorSizes.size() &&
+         "vector dim overflow");
+  //   a. get actual vector size
+  auto vectorSize = state->strategy->vectorSizes[vectorDim];
+  //   b. loop transformation for early vectorization is still subject to
+  //     exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
+  //        | ub -> ub
+  //        | step -> step * vectorSize
+  LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
+                    << " : ");
+  LLVM_DEBUG(loopInst->print(dbgs()));
+  return vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state);
+}
+
+/// Tries to transform a scalar constant into a vector splat of that constant.
+/// Returns the vectorized splat operation if the constant is a valid vector
+/// element type.
+/// If `type` is not a valid vector type or if the scalar constant is not a
+/// valid vector element type, returns nullptr.
+static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
+  if (!type || !type.isa<VectorType>() ||
+      !VectorType::isValidElementType(constant.getType())) {
+    return nullptr;
+  }
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+  auto vectorType = type.cast<VectorType>();
+  auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
+  auto *constantOpInst = constant.getOperation();
+
+  OperationState state(loc, constantOpInst->getName().getStringRef(), {},
+                       {vectorType}, {b.getNamedAttr("value", attr)});
+
+  return b.createOperation(state)->getResult(0);
+}
+
+/// Tries to vectorize a given operand `op` of Operation `op` during
+/// def-chain propagation or during terminal vectorization, by applying the
+/// following logic:
+/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized
+///    useby -def propagation), `op` is already in the proper vector form;
+/// 2. otherwise, the `op` may be in some other vector form that fails to
+///    vectorize atm (i.e. broadcasting required), returns nullptr to indicate
+///    failure;
+/// 3. if the `op` is a constant, returns the vectorized form of the constant;
+/// 4. non-constant scalars are currently non-vectorizable, in particular to
+///    guard against vectorizing an index which may be loop-variant and needs
+///    special handling.
+///
+/// In particular this logic captures some of the use cases where definitions
+/// that are not scoped under the current pattern are needed to vectorize.
+/// One such example is top level function constants that need to be splatted.
+///
+/// Returns an operand that has been vectorized to match `state`'s strategy if
+/// vectorization is possible with the above logic. Returns nullptr otherwise.
+///
+/// TODO(ntv): handle more complex cases.
+static Value *vectorizeOperand(Value *operand, Operation *op,
+                               VectorizationState *state) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
+  LLVM_DEBUG(operand->print(dbgs()));
+  // 1. If this value has already been vectorized this round, we are done.
+  if (state->vectorizedSet.count(operand->getDefiningOp()) > 0) {
+    LLVM_DEBUG(dbgs() << " -> already vector operand");
+    return operand;
+  }
+  // 1.b. Delayed on-demand replacement of a use.
+  //    Note that we cannot just call replaceAllUsesWith because it may result
+  //    in ops with mixed types, for ops whose operands have not all yet
+  //    been vectorized. This would be invalid IR.
+  auto it = state->replacementMap.find(operand);
+  if (it != state->replacementMap.end()) {
+    auto *res = it->second;
+    LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
+    LLVM_DEBUG(res->print(dbgs()));
+    return res;
+  }
+  // 2. TODO(ntv): broadcast needed.
+  if (operand->getType().isa<VectorType>()) {
+    LLVM_DEBUG(dbgs() << "-> non-vectorizable");
+    return nullptr;
+  }
+  // 3. vectorize constant.
+  if (auto constant = dyn_cast<ConstantOp>(operand->getDefiningOp())) {
+    return vectorizeConstant(
+        op, constant,
+        VectorType::get(state->strategy->vectorSizes, operand->getType()));
+  }
+  // 4. currently non-vectorizable.
+  LLVM_DEBUG(dbgs() << "-> non-vectorizable");
+  LLVM_DEBUG(operand->print(dbgs()));
+  return nullptr;
+}
+
+/// Encodes Operation-specific behavior for vectorization. In general we assume
+/// that all operands of an op must be vectorized but this is not always true.
+/// In the future, it would be nice to have a trait that describes how a
+/// particular operation vectorizes. For now we implement the case distinction
+/// here.
+/// Returns a vectorized form of an operation or nullptr if vectorization fails.
+// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
+// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
+// do one-off logic here; ideally it would be TableGen'd.
+static Operation *vectorizeOneOperation(Operation *opInst,
+                                        VectorizationState *state) {
+  // Sanity checks.
+  assert(!isa<AffineLoadOp>(opInst) &&
+         "all loads must have already been fully vectorized independently");
+  assert(!isa<vector::VectorTransferReadOp>(opInst) &&
+         "vector.transfer_read cannot be further vectorized");
+  assert(!isa<vector::VectorTransferWriteOp>(opInst) &&
+         "vector.transfer_write cannot be further vectorized");
+
+  if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
+    OpBuilder b(opInst);
+    auto *memRef = store.getMemRef();
+    auto *value = store.getValueToStore();
+    auto *vectorValue = vectorizeOperand(value, opInst, state);
+
+    SmallVector<Value *, 4> mapOperands(store.getIndices());
+    SmallVector<Value *, 8> indices;
+    indices.reserve(store.getMemRefType().getRank());
+    if (store.getAffineMap() !=
+        b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
+      computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
+                             indices);
+    } else {
+      indices.append(store.getIndices().begin(), store.getIndices().end());
+    }
+
+    auto permutationMap =
+        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return nullptr;
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    auto transfer = b.create<vector::VectorTransferWriteOp>(
+        opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
+    auto *res = transfer.getOperation();
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
+    // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
+    opInst->erase();
+    return res;
+  }
+  if (opInst->getNumRegions() != 0)
+    return nullptr;
+
+  SmallVector<Type, 8> vectorTypes;
+  for (auto *v : opInst->getResults()) {
+    vectorTypes.push_back(
+        VectorType::get(state->strategy->vectorSizes, v->getType()));
+  }
+  SmallVector<Value *, 8> vectorOperands;
+  for (auto *v : opInst->getOperands()) {
+    vectorOperands.push_back(vectorizeOperand(v, opInst, state));
+  }
+  // Check whether a single operand is null. If so, vectorization failed.
+  bool success = llvm::all_of(vectorOperands, [](Value *op) { return op; });
+  if (!success) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
+    return nullptr;
+  }
+
+  // Create a clone of the op with the proper operands and return types.
+  // TODO(ntv): The following assumes there is always an op with a fixed
+  // name that works both in scalar mode and vector mode.
+  // TODO(ntv): Is it worth considering an Operation.clone operation which
+  // changes the type so we can promote an Operation with less boilerplate?
+  OpBuilder b(opInst);
+  OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
+                       vectorOperands, vectorTypes, opInst->getAttrs(),
+                       /*successors=*/{},
+                       /*regions=*/{}, opInst->hasResizableOperandsList());
+  return b.createOperation(newOp);
+}
+
+/// Iterates over the forward slice from the loads in the vectorization pattern
+/// and rewrites them using their vectorized counterpart by:
+///   1. Create the forward slice starting from the laods in the vectorization
+///   pattern.
+///   2. Topologically sorts the forward slice.
+///   3. For each operation in the slice, create the vector form of this
+///   operation, replacing each operand by a replacement operands retrieved from
+///   replacementMap. If any such replacement is missing, vectorization fails.
+static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
+  // 1. create initial worklist with the uses of the roots.
+  SetVector<Operation *> worklist;
+  // Note: state->roots have already been vectorized and must not be vectorized
+  // again. This fits `getForwardSlice` which does not insert `op` in the
+  // result.
+  // Note: we have to exclude terminals because some of their defs may not be
+  // nested under the vectorization pattern (e.g. constants defined in an
+  // encompassing scope).
+  // TODO(ntv): Use a backward slice for terminals, avoid special casing and
+  // merge implementations.
+  for (auto *op : state->roots) {
+    getForwardSlice(op, &worklist, [state](Operation *op) {
+      return state->terminals.count(op) == 0; // propagate if not terminal
+    });
+  }
+  // We merged multiple slices, topological order may not hold anymore.
+  worklist = topologicalSort(worklist);
+
+  for (unsigned i = 0; i < worklist.size(); ++i) {
+    auto *op = worklist[i];
+    LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
+    LLVM_DEBUG(op->print(dbgs()));
+
+    // Create vector form of the operation.
+    // Insert it just before op, on success register op as replaced.
+    auto *vectorizedInst = vectorizeOneOperation(op, state);
+    if (!vectorizedInst) {
+      return failure();
+    }
+
+    // 3. Register replacement for future uses in the scope.
+    //    Note that we cannot just call replaceAllUsesWith because it may
+    //    result in ops with mixed types, for ops whose operands have not all
+    //    yet been vectorized. This would be invalid IR.
+    state->registerReplacement(op, vectorizedInst);
+  }
+  return success();
+}
+
+/// Vectorization is a recursive procedure where anything below can fail.
+/// The root match thus needs to maintain a clone for handling failure.
+/// Each root may succeed independently but will otherwise clean after itself if
+/// anything below it fails.
+static LogicalResult vectorizeRootMatch(NestedMatch m,
+                                        VectorizationStrategy *strategy) {
+  auto loop = cast<AffineForOp>(m.getMatchedOperation());
+  VectorizationState state;
+  state.strategy = strategy;
+
+  // Since patterns are recursive, they can very well intersect.
+  // Since we do not want a fully greedy strategy in general, we decouple
+  // pattern matching, from profitability analysis, from application.
+  // As a consequence we must check that each root pattern is still
+  // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
+  // TODO(ntv): implement a non-greedy profitability analysis that keeps only
+  // non-intersecting patterns.
+  if (!isVectorizableLoopBody(loop)) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
+    return failure();
+  }
+
+  /// Sets up error handling for this root loop. This is how the root match
+  /// maintains a clone for handling failure and restores the proper state via
+  /// RAII.
+  auto *loopInst = loop.getOperation();
+  OpBuilder builder(loopInst);
+  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
+  struct Guard {
+    LogicalResult failure() {
+      loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar());
+      loop.erase();
+      return mlir::failure();
+    }
+    LogicalResult success() {
+      clonedLoop.erase();
+      return mlir::success();
+    }
+    AffineForOp loop;
+    AffineForOp clonedLoop;
+  } guard{loop, clonedLoop};
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Start vectorizing.
+  // From now on, any error triggers the scope guard above.
+  //////////////////////////////////////////////////////////////////////////////
+  // 1. Vectorize all the loops matched by the pattern, recursively.
+  // This also vectorizes the roots (AffineLoadOp) as well as registers the
+  // terminals (AffineStoreOp) for post-processing vectorization (we need to
+  // wait for all use-def chains into them to be vectorized first).
+  if (failed(vectorizeLoopsAndLoadsRecursively(m, &state))) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop");
+    return guard.failure();
+  }
+
+  // 2. Vectorize operations reached by use-def chains from root except the
+  // terminals (store operations) that need to be post-processed separately.
+  // TODO(ntv): add more as we expand.
+  if (failed(vectorizeNonTerminals(&state))) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals");
+    return guard.failure();
+  }
+
+  // 3. Post-process terminals.
+  // Note: we have to post-process terminals because some of their defs may not
+  // be nested under the vectorization pattern (e.g. constants defined in an
+  // encompassing scope).
+  // TODO(ntv): Use a backward slice for terminals, avoid special casing and
+  // merge implementations.
+  for (auto *op : state.terminals) {
+    if (!vectorizeOneOperation(op, &state)) { // nullptr == failure
+      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals");
+      return guard.failure();
+    }
+  }
+
+  // 4. Finish this vectorization pattern.
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
+  state.finishVectorizationPattern();
+  return guard.success();
+}
+
+/// Applies vectorization to the current Function by searching over a bunch of
+/// predetermined patterns.
+void Vectorize::runOnFunction() {
+  FuncOp f = getFunction();
+  if (!fastestVaryingPattern.empty() &&
+      fastestVaryingPattern.size() != vectorSizes.size()) {
+    f.emitRemark("Fastest varying pattern specified with different size than "
+                 "the vector size.");
+    return signalPassFailure();
+  }
+
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
+  llvm::DenseSet<Operation *> parallelLoops;
+  f.walk<AffineForOp>([&parallelLoops](AffineForOp loop) {
+    if (isLoopParallel(loop))
+      parallelLoops.insert(loop);
+  });
+
+  for (auto &pat :
+       makePatterns(parallelLoops, vectorSizes.size(), fastestVaryingPattern)) {
+    LLVM_DEBUG(dbgs() << "\n******************************************");
+    LLVM_DEBUG(dbgs() << "\n******************************************");
+    LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
+    LLVM_DEBUG(f.print(dbgs()));
+    unsigned patternDepth = pat.getDepth();
+
+    SmallVector<NestedMatch, 8> matches;
+    pat.match(f, &matches);
+    // Iterate over all the top-level matches and vectorize eagerly.
+    // This automatically prunes intersecting matches.
+    for (auto m : matches) {
+      VectorizationStrategy strategy;
+      // TODO(ntv): depending on profitability, elect to reduce the vector size.
+      strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end());
+      if (failed(analyzeProfitability(m.getMatchedChildren(), 1, patternDepth,
+                                      &strategy))) {
+        continue;
+      }
+      vectorizeLoopIfProfitable(m.getMatchedOperation(), 0, patternDepth,
+                                &strategy);
+      // TODO(ntv): if pattern does not apply, report it; alter the
+      // cost/benefit.
+      vectorizeRootMatch(m, &strategy);
+      // TODO(ntv): some diagnostics if failure to vectorize occurs.
+    }
+  }
+  LLVM_DEBUG(dbgs() << "\n");
+}
+
+FunctionPassBase *
+mlir::createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize) {
+  return new Vectorize(virtualVectorSize);
+}
+
+static PassRegistration<Vectorize>
+    pass("affine-vectorize",
+         "Vectorize to a target independent n-D vector abstraction");
diff --git a/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp b/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp
new file mode 100644
index 0000000..5a0e8e5
--- /dev/null
+++ b/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp
@@ -0,0 +1,95 @@
+//===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Transforms/ViewRegionGraph.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace llvm {
+
+// Specialize DOTGraphTraits to produce more readable output.
+template <> struct DOTGraphTraits<Region *> : public DefaultDOTGraphTraits {
+  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
+
+  static std::string getNodeLabel(Block *Block, Region *);
+};
+
+std::string DOTGraphTraits<Region *>::getNodeLabel(Block *Block, Region *) {
+  // Reuse the print output for the node labels.
+  std::string outStreamStr;
+  raw_string_ostream os(outStreamStr);
+  Block->print(os);
+  std::string &outStr = os.str();
+
+  if (outStr[0] == '\n')
+    outStr.erase(outStr.begin());
+
+  // Process string output to left justify the block.
+  for (unsigned i = 0; i != outStr.length(); ++i) {
+    if (outStr[i] == '\n') {
+      outStr[i] = '\\';
+      outStr.insert(outStr.begin() + i + 1, 'l');
+    }
+  }
+
+  return outStr;
+}
+
+} // end namespace llvm
+
+void mlir::viewGraph(Region &region, const llvm::Twine &name, bool shortNames,
+                     const llvm::Twine &title,
+                     llvm::GraphProgram::Name program) {
+  llvm::ViewGraph(&region, name, shortNames, title, program);
+}
+
+llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Region &region,
+                                    bool shortNames, const llvm::Twine &title) {
+  return llvm::WriteGraph(os, &region, shortNames, title);
+}
+
+void mlir::Region::viewGraph(const llvm::Twine &regionName) {
+  ::mlir::viewGraph(*this, regionName);
+}
+void mlir::Region::viewGraph() { viewGraph("region"); }
+
+namespace {
+struct PrintCFGPass : public FunctionPass<PrintCFGPass> {
+  PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false,
+               const llvm::Twine &title = "")
+      : os(os), shortNames(shortNames), title(title.str()) {}
+  void runOnFunction() {
+    mlir::writeGraph(os, getFunction().getBody(), shortNames, title);
+  }
+
+private:
+  llvm::raw_ostream &os;
+  bool shortNames;
+  std::string title;
+};
+} // namespace
+
+FunctionPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os,
+                                                bool shortNames,
+                                                const llvm::Twine &title) {
+  return new PrintCFGPass(os, shortNames, title);
+}
+
+static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
+                                           "Print CFG graph per Function");
diff --git a/third_party/mlir/lib/Translation/CMakeLists.txt b/third_party/mlir/lib/Translation/CMakeLists.txt
new file mode 100644
index 0000000..122db2e
--- /dev/null
+++ b/third_party/mlir/lib/Translation/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_llvm_library(MLIRTranslation
+  Translation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Translation
+  )
+target_link_libraries(MLIRTranslation LLVMSupport)
diff --git a/third_party/mlir/lib/Translation/Translation.cpp b/third_party/mlir/lib/Translation/Translation.cpp
new file mode 100644
index 0000000..3025e9e
--- /dev/null
+++ b/third_party/mlir/lib/Translation/Translation.cpp
@@ -0,0 +1,77 @@
+//===- Translation.cpp - Translation registry -----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Definitions of the translation registry.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Translation.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/ManagedStatic.h"
+
+using namespace mlir;
+
+// Get the mutable static map between registered "to MLIR" translations and the
+// TranslateToMLIRFunctions that perform those translations.
+static llvm::StringMap<TranslateToMLIRFunction> &
+getMutableTranslationToMLIRRegistry() {
+  static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
+  return translationToMLIRRegistry;
+}
+// Get the mutable static map between registered "from MLIR" translations and
+// the TranslateFromMLIRFunctions that perform those translations.
+static llvm::StringMap<TranslateFromMLIRFunction> &
+getMutableTranslationFromMLIRRegistry() {
+  static llvm::StringMap<TranslateFromMLIRFunction> translationFromMLIRRegistry;
+  return translationFromMLIRRegistry;
+}
+
+TranslateToMLIRRegistration::TranslateToMLIRRegistration(
+    StringRef name, const TranslateToMLIRFunction &function) {
+  auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
+  if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
+    llvm::report_fatal_error(
+        "Attempting to overwrite an existing <to> function");
+  assert(function && "Attempting to register an empty translate <to> function");
+  translationToMLIRRegistry[name] = function;
+}
+
+TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
+    StringRef name, const TranslateFromMLIRFunction &function) {
+  auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
+  if (translationFromMLIRRegistry.find(name) !=
+      translationFromMLIRRegistry.end())
+    llvm::report_fatal_error(
+        "Attempting to overwrite an existing <from> function");
+  assert(function &&
+         "Attempting to register an empty translate <from> function");
+  translationFromMLIRRegistry[name] = function;
+}
+
+// Merely add the const qualifier to the mutable registry so that external users
+// cannot modify it.
+const llvm::StringMap<TranslateToMLIRFunction> &
+mlir::getTranslationToMLIRRegistry() {
+  return getMutableTranslationToMLIRRegistry();
+}
+
+const llvm::StringMap<TranslateFromMLIRFunction> &
+mlir::getTranslationFromMLIRRegistry() {
+  return getMutableTranslationFromMLIRRegistry();
+}
diff --git a/third_party/mlir/lib/VectorOps/CMakeLists.txt b/third_party/mlir/lib/VectorOps/CMakeLists.txt
new file mode 100644
index 0000000..0e76501
--- /dev/null
+++ b/third_party/mlir/lib/VectorOps/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_llvm_library(MLIRVectorOps
+  DialectRegistration.cpp
+  VectorOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/VectorOps
+  )
+
+add_dependencies(MLIRVectorOps MLIRVectorOpsIncGen)
+
+target_link_libraries(MLIRVectorOps MLIRIR)
diff --git a/third_party/mlir/lib/VectorOps/DialectRegistration.cpp b/third_party/mlir/lib/VectorOps/DialectRegistration.cpp
new file mode 100644
index 0000000..aedba31
--- /dev/null
+++ b/third_party/mlir/lib/VectorOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register super vectorization dialect -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/VectorOps/VectorOps.h"
+using namespace mlir;
+
+// Static initialization for VectorOps dialect registration.
+static DialectRegistration<vector::VectorOpsDialect> VectorOps;
diff --git a/third_party/mlir/lib/VectorOps/VectorOps.cpp b/third_party/mlir/lib/VectorOps/VectorOps.cpp
new file mode 100644
index 0000000..38267af
--- /dev/null
+++ b/third_party/mlir/lib/VectorOps/VectorOps.cpp
@@ -0,0 +1,546 @@
+//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements convenience types for working with super-vectorization
+// operations, in particular super-vector loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/VectorOps/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+//===----------------------------------------------------------------------===//
+// VectorOpsDialect
+//===----------------------------------------------------------------------===//
+
+mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<VectorTransferReadOp, VectorTransferWriteOp,
+                VectorTypeCastOp>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, ExtractElementOp op) {
+  *p << op.getOperationName() << " " << *op.vector() << op.position();
+  p->printOptionalAttrDict(op.getAttrs(), {"position"});
+  *p << " : " << op.vector()->getType();
+}
+
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+                                         OperationState *result) {
+  llvm::SMLoc attributeLoc, typeLoc;
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType vector;
+  Type type;
+  Attribute attr;
+  if (parser->parseOperand(vector) ||
+      parser->getCurrentLocation(&attributeLoc) ||
+      parser->parseAttribute(attr, "position", attrs) ||
+      parser->parseOptionalAttributeDict(attrs) ||
+      parser->getCurrentLocation(&typeLoc) || parser->parseColonType(type))
+    return failure();
+
+  auto vectorType = type.dyn_cast<VectorType>();
+  if (!vectorType)
+    return parser->emitError(typeLoc, "expected vector type");
+
+  auto positionAttr = attr.dyn_cast<ArrayAttr>();
+  if (!positionAttr ||
+      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
+    return parser->emitError(
+        attributeLoc,
+        "expected position attribute of rank smaller than vector");
+
+  Type resType =
+      (static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
+          ? vectorType.getElementType()
+          : VectorType::get(
+                vectorType.getShape().drop_front(positionAttr.size()),
+                vectorType.getElementType());
+
+  result->attributes = attrs;
+  return failure(parser->resolveOperand(vector, type, result->operands) ||
+                 parser->addTypeToList(resType, result->types));
+}
+
+static LogicalResult verify(ExtractElementOp op) {
+  auto positionAttr = op.position().getValue();
+  if (positionAttr.empty())
+    return op.emitOpError("expected non-empty position attribute");
+  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
+    return op.emitOpError(
+        "expected position attribute of rank smaller than vector");
+  for (auto en : llvm::enumerate(positionAttr)) {
+    auto attr = en.value().dyn_cast<IntegerAttr>();
+    if (!attr || attr.getInt() < 0 ||
+        attr.getInt() > op.getVectorType().getDimSize(en.index()))
+      return op.emitOpError("expected position attribute #")
+             << (en.index() + 1)
+             << " to be a positive integer smaller than the corresponding "
+                "vector dimension";
+  }
+  return success();
+}
+//===----------------------------------------------------------------------===//
+// OuterProductOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, OuterProductOp op) {
+  *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+  *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
+}
+
+static ParseResult parseOuterProductOp(OpAsmParser *parser,
+                                       OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
+  Type t0, t1;
+  if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
+      parser->parseComma() || parser->parseType(t1))
+    return failure();
+  VectorType v0 = t0.dyn_cast<VectorType>();
+  VectorType v1 = t1.dyn_cast<VectorType>();
+  if (!v0 || !v1)
+    return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
+  VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
+                                       v0.getElementType());
+  return failure(parser->resolveOperands(operandsInfo, {t0, t1},
+                                         parser->getCurrentLocation(),
+                                         result->operands) ||
+                 parser->addTypeToList(resType, result->types));
+}
+
+static LogicalResult verify(OuterProductOp op) {
+  VectorType v1 = op.getOperandVectorTypeLHS(),
+             v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
+  if (v1.getRank() != 1)
+    return op.emitOpError("expected 1-d vector for operand #1");
+  if (v2.getRank() != 1)
+    return op.emitOpError("expected 1-d vector for operand #2");
+  if (res.getRank() != 2)
+    return op.emitOpError("expected 2-d vector result");
+  if (v1.getDimSize(0) != res.getDimSize(0))
+    return op.emitOpError(
+        "expected first operand dim to match first result dim");
+  if (v2.getDimSize(0) != res.getDimSize(1))
+    return op.emitOpError(
+        "expected second operand dim to match second result dim");
+  return success();
+}
+//===----------------------------------------------------------------------===//
+// VectorTransferReadOp
+//===----------------------------------------------------------------------===//
+template <typename EmitFun>
+static LogicalResult verifyPermutationMap(AffineMap permutationMap,
+                                          EmitFun emitOpError) {
+  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
+  for (auto expr : permutationMap.getResults()) {
+    auto dim = expr.dyn_cast<AffineDimExpr>();
+    auto zero = expr.dyn_cast<AffineConstantExpr>();
+    if (zero) {
+      if (zero.getValue() != 0) {
+        return emitOpError(
+            "requires a projected permutation_map (at most one dim or the zero "
+            "constant can appear in each result)");
+      }
+      continue;
+    }
+    if (!dim) {
+      return emitOpError("requires a projected permutation_map (at most one "
+                         "dim or the zero constant can appear in each result)");
+    }
+    if (seen[dim.getPosition()]) {
+      return emitOpError(
+          "requires a permutation_map that is a permutation (found one dim "
+          "used more than once)");
+    }
+    seen[dim.getPosition()] = true;
+  }
+  return success();
+}
+
+void VectorTransferReadOp::build(Builder *builder, OperationState *result,
+                                 VectorType vectorType, Value *srcMemRef,
+                                 ArrayRef<Value *> srcIndices,
+                                 AffineMap permutationMap,
+                                 Optional<Value *> paddingValue) {
+  result->addOperands(srcMemRef);
+  result->addOperands(srcIndices);
+  if (paddingValue) {
+    result->addOperands({*paddingValue});
+  }
+  result->addAttribute(getPermutationMapAttrName(),
+                       builder->getAffineMapAttr(permutationMap));
+  result->addTypes(vectorType);
+}
+
+auto VectorTransferReadOp::getIndices() -> operand_range {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+Optional<Value *> VectorTransferReadOp::getPaddingValue() {
+  auto memRefRank = getMemRefType().getRank();
+  if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+    return None;
+  }
+  return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+AffineMap VectorTransferReadOp::getPermutationMap() {
+  return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferReadOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " ";
+  p->printOperand(getMemRef());
+  *p << "[";
+  p->printOperands(getIndices());
+  *p << "]";
+  auto optionalPaddingValue = getPaddingValue();
+  if (optionalPaddingValue) {
+    *p << ", (";
+    p->printOperand(*optionalPaddingValue);
+    *p << ")";
+  }
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getMemRefType();
+  *p << ", " << getResultType();
+}
+
+ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType memrefInfo;
+  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
+  SmallVector<OpAsmParser::OperandType, 8> paddingInfo;
+  SmallVector<Type, 2> types;
+
+  // Parsing with support for optional paddingValue.
+  if (parser->parseOperand(memrefInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseTrailingOperandList(paddingInfo,
+                                       OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(types))
+    return failure();
+
+  // Resolution.
+  if (types.size() != 2)
+    return parser->emitError(parser->getNameLoc(), "expected 2 types");
+  MemRefType memrefType = types[0].dyn_cast<MemRefType>();
+  if (!memrefType)
+    return parser->emitError(parser->getNameLoc(), "memRef type expected");
+  VectorType vectorType = types[1].dyn_cast<VectorType>();
+  if (!vectorType)
+    return parser->emitError(parser->getNameLoc(), "vector type expected");
+
+  // Extract optional paddingValue.
+  // At this point, indexInfo may contain the optional paddingValue, pop it
+  // out.
+  if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected " + Twine(memrefType.getRank()) +
+                                 " indices to the memref");
+  if (paddingInfo.size() > 1)
+    return parser->emitError(parser->getNameLoc(),
+                             "expected at most one padding value");
+  Type paddingType;
+  bool hasOptionalPaddingValue = !paddingInfo.empty();
+  if (hasOptionalPaddingValue) {
+    paddingType = vectorType.getElementType();
+  }
+  auto indexType = parser->getBuilder().getIndexType();
+  return failure(
+      parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+      parser->resolveOperands(indexInfo, indexType, result->operands) ||
+      (hasOptionalPaddingValue &&
+       parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) ||
+      parser->addTypeToList(vectorType, result->types));
+}
+
+LogicalResult VectorTransferReadOp::verify() {
+  // Consistency of memref type in function type.
+  if (llvm::empty(getOperands())) {
+    return emitOpError(
+        "requires at least a memref operand followed by 'rank' indices");
+  }
+  if (!getMemRef()->getType().isa<MemRefType>()) {
+    return emitOpError("requires a memref as first operand");
+  }
+  // Consistency of vector type in function type.
+  if (!getResult()->getType().isa<VectorType>()) {
+    return emitOpError("should have a vector result type in function type: "
+                       "memref_type<...xelemental_type>, vector_type");
+  }
+  // Consistency of elemental types in memref and vector.
+  MemRefType memrefType = getMemRefType();
+  VectorType vectorType = getResultType();
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return emitOpError(
+        "requires memref and vector types of the same elemental type");
+  // Consistency of number of input types.
+  auto optionalPaddingValue = getPaddingValue();
+  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+                                 memrefType.getRank() +
+                                 (optionalPaddingValue ? 1 : 0);
+  // Checks on the actual operands and their types.
+  if (getNumOperands() != expectedNumOperands) {
+    return emitOpError("expects ")
+           << expectedNumOperands << " operands (of which "
+           << memrefType.getRank() << " indices)";
+  }
+  // Consistency of padding value with vector type.
+  if (optionalPaddingValue) {
+    auto paddingValue = *optionalPaddingValue;
+    auto elementalType = paddingValue->getType();
+    if (!VectorType::isValidElementType(elementalType)) {
+      return emitOpError("requires valid padding vector elemental type");
+    }
+    if (elementalType != vectorType.getElementType()) {
+      return emitOpError(
+          "requires formal padding and vector of the same elemental type");
+    }
+  }
+  // Consistency of indices types.
+  unsigned numIndices = 0;
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex()) {
+      return emitOpError(
+          "index to vector.transfer_read must have 'index' type");
+    }
+    ++numIndices;
+  }
+  if (numIndices != memrefType.getRank()) {
+    return emitOpError("requires at least a memref operand followed by ")
+           << memrefType.getRank() << " indices";
+  }
+
+  // Consistency of AffineMap attribute.
+  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+  }
+  auto permutationMap = getPermutationMap();
+  if (permutationMap.getNumSymbols() != 0) {
+    return emitOpError("requires a permutation_map without symbols");
+  }
+  if (permutationMap.getNumInputs() != memrefType.getRank()) {
+    return emitOpError("requires a permutation_map with input dims of the "
+                       "same rank as the memref type");
+  }
+  if (permutationMap.getNumResults() != vectorType.getRank()) {
+    return emitOpError("requires a permutation_map with result dims of the "
+                       "same rank as the vector type (")
+           << permutationMap.getNumResults() << " vs " << vectorType.getRank();
+  }
+  return verifyPermutationMap(permutationMap,
+                              [this](Twine t) { return emitOpError(t); });
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferWriteOp
+//===----------------------------------------------------------------------===//
+void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
+                                  Value *srcVector, Value *dstMemRef,
+                                  ArrayRef<Value *> dstIndices,
+                                  AffineMap permutationMap) {
+  result->addOperands({srcVector, dstMemRef});
+  result->addOperands(dstIndices);
+  result->addAttribute(getPermutationMapAttrName(),
+                       builder->getAffineMapAttr(permutationMap));
+}
+
+auto VectorTransferWriteOp::getIndices() -> operand_range {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+AffineMap VectorTransferWriteOp::getPermutationMap() {
+  return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferWriteOp::print(OpAsmPrinter *p) {
+  *p << getOperationName();
+  *p << " " << *getVector();
+  *p << ", " << *getMemRef();
+  *p << "[";
+  p->printOperands(getIndices());
+  *p << "]";
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : ";
+  p->printType(getVectorType());
+  *p << ", ";
+  p->printType(getMemRefType());
+}
+
+ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  SmallVector<Type, 2> types;
+  auto indexType = parser->getBuilder().getIndexType();
+  if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(memrefInfo) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(types))
+    return failure();
+
+  if (types.size() != 2)
+    return parser->emitError(parser->getNameLoc(), "expected 2 types");
+  VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
+  if (!vectorType)
+    return parser->emitError(parser->getNameLoc(), "vector type expected");
+  MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
+  if (!memrefType)
+    return parser->emitError(parser->getNameLoc(), "memRef type expected");
+
+  return failure(
+      parser->resolveOperands(storeValueInfo, vectorType, result->operands) ||
+      parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
+      parser->resolveOperands(indexInfo, indexType, result->operands));
+}
+
+LogicalResult VectorTransferWriteOp::verify() {
+  // Consistency of memref type in function type.
+  if (llvm::empty(getOperands())) {
+    return emitOpError(
+        "requires at least a memref operand followed by 'rank' indices");
+  }
+  if (!getMemRef()->getType().isa<MemRefType>()) {
+    return emitOpError("requires a memref first operand");
+  }
+  // Consistency of vector type in function type.
+  if (!getVector()->getType().isa<VectorType>()) {
+    return emitOpError("should have a vector input type in function type: "
+                       "(vector_type, memref_type [, elemental_type]) -> ()");
+  }
+  // Consistency of elemental types in memref and vector.
+  MemRefType memrefType = getMemRefType();
+  VectorType vectorType = getVectorType();
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return emitOpError(
+        "requires memref and vector types of the same elemental type");
+  // Consistency of number of input types.
+  unsigned expectedNumOperands =
+      Offsets::FirstIndexOffset + memrefType.getRank();
+  // Checks on the actual operands and their types.
+  if (getNumOperands() != expectedNumOperands) {
+    return emitOpError() << "expects " << expectedNumOperands
+                         << " operands (of which " << memrefType.getRank()
+                         << " indices)";
+  }
+  // Consistency of indices types.
+  unsigned numIndices = 0;
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex()) {
+      return emitOpError(
+          "index to vector.transfer_write must have 'index' type");
+    }
+    numIndices++;
+  }
+  if (numIndices != memrefType.getRank()) {
+    return emitOpError("requires at least a memref operand followed by ")
+           << memrefType.getRank() << " indices";
+  }
+
+  // Consistency of AffineMap attribute.
+  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+  }
+  auto permutationMap = getPermutationMap();
+  if (permutationMap.getNumSymbols() != 0) {
+    return emitOpError("requires a permutation_map without symbols");
+  }
+  if (permutationMap.getNumInputs() != memrefType.getRank()) {
+    return emitOpError("requires a permutation_map with input dims of the "
+                       "same rank as the memref type");
+  }
+  if (permutationMap.getNumResults() != vectorType.getRank()) {
+    return emitOpError("requires a permutation_map with result dims of the "
+                       "same rank as the vector type (")
+           << permutationMap.getNumResults() << " vs " << vectorType.getRank();
+  }
+  return verifyPermutationMap(permutationMap,
+                              [this](Twine t) { return emitOpError(t); });
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTypeCastOp
+//===----------------------------------------------------------------------===//
+void VectorTypeCastOp::build(Builder *builder, OperationState *result,
+                             Value *srcVector, Type dstType) {
+  result->addOperands(srcVector);
+  result->addTypes(dstType);
+}
+
+ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
+                                    OperationState *result) {
+  OpAsmParser::OperandType operand;
+  Type srcType, dstType;
+  return failure(parser->parseOperand(operand) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(srcType) || parser->parseComma() ||
+                 parser->parseType(dstType) ||
+                 parser->addTypeToList(dstType, result->types) ||
+                 parser->resolveOperand(operand, srcType, result->operands));
+}
+
+void VectorTypeCastOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << ' ' << *getOperand() << " : "
+     << getOperand()->getType() << ", " << getType();
+}
+
+LogicalResult VectorTypeCastOp::verify() {
+  auto dstMemrefType = getType().dyn_cast<MemRefType>();
+  if (!dstMemrefType)
+    return emitOpError("expects target type to be a memref type");
+  auto dstVectorType = dstMemrefType.getElementType().dyn_cast<VectorType>();
+  if (!dstVectorType)
+    return emitOpError(
+        "expects vector as an element of the target memref type");
+  if (!dstMemrefType.hasStaticShape())
+    return emitOpError("does not support dynamic shapes");
+
+  if (!getOperand()->getType().isa<MemRefType>())
+    return emitOpError("expects source type to be a memref type");
+
+  return success();
+}
+
+namespace mlir {
+
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+
+} // namespace mlir
diff --git a/third_party/mlir/mlir_configure.bzl b/third_party/mlir/mlir_configure.bzl
deleted file mode 100644
index ade32db..0000000
--- a/third_party/mlir/mlir_configure.bzl
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Repository rule to setup the external MLIR repository."""
-
-_MLIR_REV = "83ff81bfd9d382852d0302ab2a234feb2e938fc7"
-_MLIR_SHA256 = "26979670616980014a823f88c1a057c28080763d9cb189fa67172a92c085d349"
-
-def _mlir_autoconf_impl(repository_ctx):
-    """Implementation of the mlir_configure repository rule."""
-    repository_ctx.download_and_extract(
-        [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/mlir/archive/{}.zip".format(_MLIR_REV),
-            "https://github.com/tensorflow/mlir/archive/{}.zip".format(_MLIR_REV),
-        ],
-        sha256 = _MLIR_SHA256,
-        stripPrefix = "mlir-{}".format(_MLIR_REV),
-    )
-
-    # Merge the checked-in BUILD files into the downloaded repo.
-    for file in ["BUILD", "tblgen.bzl", "test/BUILD"]:
-        repository_ctx.template(file, Label("//third_party/mlir:" + file))
-
-mlir_configure = repository_rule(
-    implementation = _mlir_autoconf_impl,
-)
-"""Configures the MLIR repository.
-
-Add the following to your WORKSPACE FILE:
-
-```python
-mlir_configure(name = "local_config_mlir")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/mlir/test/APITest.h b/third_party/mlir/test/APITest.h
new file mode 100644
index 0000000..6b02108
--- /dev/null
+++ b/third_party/mlir/test/APITest.h
@@ -0,0 +1,72 @@
+//===- Test.h - Simple macros for API unit tests ----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file define simple macros for declaring test functions and running them.
+// The actual checking must be performed on the outputs with FileCheck.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_TEST_H_
+#define MLIR_TEST_TEST_H_
+
+#include <functional>
+#include <vector>
+
+namespace test_detail {
+// Returns a mutable list of known test functions.  Used internally by test
+// macros to add and run tests.  This function is static to ensure it creates a
+// new list in each test file.
+static std::vector<std::function<void()>> &tests() {
+  static std::vector<std::function<void()>> list;
+  return list;
+}
+
+// Test registration class.  Used internally by test macros to register tests
+// during static allocation.
+struct TestRegistration {
+  explicit TestRegistration(std::function<void()> func) {
+    test_detail::tests().push_back(func);
+  }
+};
+} // end namespace test_detail
+
+/// Declares a test function with the given name and adds it to the list of
+/// known tets.  The body of the function must follow immediately.  Example:
+///
+/// TEST_FUNC(mytest) {
+///   // CHECK: expected-output-here
+///   emitSomethingToStdOut();
+/// }
+///
+#define TEST_FUNC(name)                                                        \
+  void name();                                                                 \
+  static test_detail::TestRegistration name##Registration(name);               \
+  void name()
+
+/// Runs all registered tests.  Example:
+///
+/// int main() {
+///   RUN_TESTS();
+///   return 0;
+/// }
+#define RUN_TESTS                                                              \
+  []() {                                                                       \
+    for (auto f : test_detail::tests())                                        \
+      f();                                                                     \
+  }
+
+#endif // MLIR_TEST_TEST_H_
diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD
index 90bdfa4..fa389f5 100644
--- a/third_party/mlir/test/BUILD
+++ b/third_party/mlir/test/BUILD
@@ -56,11 +56,11 @@
     deps = [
         ":TestOpsIncGen",
         "@llvm//:support",
+        "@local_config_mlir//:Dialect",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Pass",
         "@local_config_mlir//:Support",
         "@local_config_mlir//:Transforms",
-        "@local_config_mlir//:TypeUtilities",
     ],
     alwayslink = 1,
 )
@@ -70,6 +70,8 @@
     srcs = [
         "lib/Transforms/TestConstantFold.cpp",
         "lib/Transforms/TestLoopFusion.cpp",
+        "lib/Transforms/TestLoopMapping.cpp",
+        "lib/Transforms/TestLoopParametricTiling.cpp",
         "lib/Transforms/TestVectorizationUtils.cpp",
     ],
     deps = [
@@ -78,6 +80,7 @@
         "@local_config_mlir//:Analysis",
         "@local_config_mlir//:EDSC",
         "@local_config_mlir//:IR",
+        "@local_config_mlir//:LoopOps",
         "@local_config_mlir//:Pass",
         "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
diff --git a/third_party/mlir/test/CMakeLists.txt b/third_party/mlir/test/CMakeLists.txt
new file mode 100644
index 0000000..2e10239
--- /dev/null
+++ b/third_party/mlir/test/CMakeLists.txt
@@ -0,0 +1,69 @@
+add_subdirectory(EDSC)
+add_subdirectory(mlir-cpu-runner)
+add_subdirectory(SDBM)
+add_subdirectory(lib)
+
+llvm_canonicalize_cmake_booleans(
+  LLVM_BUILD_EXAMPLES
+  )
+
+# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
+# for linalg integration tests.
+set(MLIR_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+
+# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
+# for the mlir cuda runner tests.
+set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+
+configure_lit_site_cfg(
+  ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
+  ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
+  MAIN_CONFIG
+  ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
+  )
+configure_lit_site_cfg(
+  ${CMAKE_CURRENT_SOURCE_DIR}/Unit/lit.site.cfg.py.in
+  ${CMAKE_CURRENT_BINARY_DIR}/Unit/lit.site.cfg.py
+  MAIN_CONFIG
+  ${CMAKE_CURRENT_SOURCE_DIR}/Unit/lit.cfg.py
+  )
+
+set(MLIR_TEST_DEPENDS
+  FileCheck count not
+  MLIRUnitTests
+  mlir-cpu-runner
+  mlir-edsc-builder-api-test
+  mlir-opt
+  mlir-sdbm-api-test
+  mlir-tblgen
+  mlir-translate
+  cblas
+  cblas_interface
+  )
+
+if(LLVM_BUILD_EXAMPLES)
+  list(APPEND MLIR_TEST_DEPENDS
+    linalg1-opt
+    toyc-ch1
+    toyc-ch2
+    toyc-ch3
+    toyc-ch4
+    toyc-ch5
+    )
+endif()
+
+if(MLIR_CUDA_RUNNER_ENABLED)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-cuda-runner
+  )
+endif()
+
+add_lit_testsuite(check-mlir "Running the MLIR regression tests"
+  ${CMAKE_CURRENT_BINARY_DIR}
+  DEPENDS ${MLIR_TEST_DEPENDS}
+  )
+set_target_properties(check-mlir PROPERTIES FOLDER "Tests")
+
+add_lit_testsuites(MLIR ${CMAKE_CURRENT_SOURCE_DIR}
+  DEPENDS ${MLIR_TEST_DEPS}
+)
diff --git a/third_party/mlir/test/lib/CMakeLists.txt b/third_party/mlir/test/lib/CMakeLists.txt
new file mode 100644
index 0000000..860376b
--- /dev/null
+++ b/third_party/mlir/test/lib/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(TestDialect)
+add_subdirectory(Transforms)
diff --git a/third_party/mlir/test/lib/TestDialect/CMakeLists.txt b/third_party/mlir/test/lib/TestDialect/CMakeLists.txt
new file mode 100644
index 0000000..77bcd42
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/CMakeLists.txt
@@ -0,0 +1,25 @@
+set(LLVM_OPTIONAL_SOURCES
+  TestDialect.cpp
+  TestPatterns.cpp
+)
+
+set(LLVM_TARGET_DEFINITIONS TestOps.td)
+mlir_tablegen(TestOps.h.inc -gen-op-decls)
+mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TestPatterns.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestOpsIncGen)
+
+add_llvm_library(MLIRTestDialect
+  TestDialect.cpp
+  TestPatterns.cpp
+)
+add_dependencies(MLIRTestDialect
+  MLIRTestOpsIncGen
+  MLIRIR
+  LLVMSupport
+)
+target_link_libraries(MLIRTestDialect
+  MLIRDialect
+  MLIRIR
+  LLVMSupport
+)
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
new file mode 100644
index 0000000..f71eff9
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -0,0 +1,81 @@
+//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "TestDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+TestDialect::TestDialect(MLIRContext *context)
+    : Dialect(getDialectName(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "TestOps.cpp.inc"
+      >();
+  allowUnknownOperations();
+}
+
+//===----------------------------------------------------------------------===//
+// Test PolyForOp - parse list of region arguments.
+//===----------------------------------------------------------------------===//
+ParseResult parsePolyForOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
+  // Parse list of region arguments without a delimiter.
+  if (parser->parseRegionArgumentList(ivsInfo))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result->addRegion();
+  auto &builder = parser->getBuilder();
+  SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
+  if (parser->parseRegion(*body, ivsInfo, argTypes))
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test removing op with inner ops.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestRemoveOpWithInnerOps : public OpRewritePattern<TestOpWithRegion> {
+  using OpRewritePattern<TestOpWithRegion>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TestOpWithRegion op,
+                                     PatternRewriter &rewriter) const override {
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+void TestOpWithRegion::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TestRemoveOpWithInnerOps>(context);
+}
+
+// Static initialization for Test dialect registration.
+static mlir::DialectRegistration<mlir::TestDialect> testDialect;
+
+#define GET_OP_CLASSES
+#include "TestOps.cpp.inc"
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.h b/third_party/mlir/test/lib/TestDialect/TestDialect.h
new file mode 100644
index 0000000..8e3efa3
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.h
@@ -0,0 +1,48 @@
+//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a fake 'test' dialect that can be used for testing things
+// that do not have a respective counterpart in the main source directories.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTDIALECT_H
+#define MLIR_TESTDIALECT_H
+
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+
+class TestDialect : public Dialect {
+public:
+  /// Create the dialect in the given `context`.
+  TestDialect(MLIRContext *context);
+
+  /// Get the canonical string name of the dialect.
+  static StringRef getDialectName() { return "test"; }
+};
+
+#define GET_OP_CLASSES
+#include "TestOps.h.inc"
+
+} // end namespace mlir
+
+#endif // MLIR_TESTDIALECT_H
diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td
new file mode 100644
index 0000000..8a22adf
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/TestOps.td
@@ -0,0 +1,576 @@
+//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifdef TEST_OPS
+#else
+#define TEST_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def TEST_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "";
+}
+
+class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<TEST_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Test Types
+//===----------------------------------------------------------------------===//
+
+def AnyVectorOrTensor: AnyTypeOf<[AnyVector, AnyTensor]>;
+
+def TupleOp : TEST_Op<"tuple_32_bit"> {
+  let results = (outs TupleOf<[I32, F32]>);
+}
+
+def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
+  let results = (outs NestedTupleOf<[I32, F32]>);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Operands
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicOperandOp : TEST_Op<
+    "mixed_normal_variadic_operand", [SameVariadicOperandSize]> {
+  let arguments = (ins
+    Variadic<AnyTensor>:$input1,
+    AnyTensor:$input2,
+    Variadic<AnyTensor>:$input3
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Results
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicResults : TEST_Op<
+    "mixed_normal_variadic_result", [SameVariadicResultSize]> {
+  let results = (outs
+    Variadic<AnyTensor>:$output1,
+    AnyTensor:$output2,
+    Variadic<AnyTensor>:$output3
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Attributes
+//===----------------------------------------------------------------------===//
+
+def NonNegIntAttrOp : TEST_Op<"non_negative_int_attr"> {
+  let arguments = (ins
+      NonNegativeI32Attr:$i32attr,
+      NonNegativeI64Attr:$i64attr
+  );
+}
+
+def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
+  let arguments = (ins TypeArrayAttr:$attr);
+}
+def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
+  let arguments = (ins StrAttr:$attr);
+  let printer = [{ *p << getAttr("attr"); }];
+  let parser = [{
+    Attribute attr;
+    Type stringType = OpaqueType::get(Identifier::get("foo", result->context),
+                                      "string", result->context);
+    return parser->parseAttribute(attr, stringType, "attr", result->attributes);
+  }];
+}
+
+def StrCaseA: StrEnumAttrCase<"A">;
+def StrCaseB: StrEnumAttrCase<"B">;
+
+def SomeStrEnum: StrEnumAttr<
+  "SomeStrEnum", "", [StrCaseA, StrCaseB]>;
+
+def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
+  let arguments = (ins SomeStrEnum:$attr);
+  let results = (outs I32:$val);
+}
+
+def I32Case5:  I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+  "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
+def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
+  let arguments = (ins SomeI32Enum:$attr);
+  let results = (outs I32:$val);
+}
+
+def I64Case5:  I64EnumAttrCase<"case5", 5>;
+def I64Case10: I64EnumAttrCase<"case10", 10>;
+
+def SomeI64Enum: I64EnumAttr<
+  "SomeI64Enum", "", [I64Case5, I64Case10]>;
+
+def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
+  let arguments = (ins SomeI64Enum:$attr);
+  let results = (outs I32:$val);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Regions
+//===----------------------------------------------------------------------===//
+
+def TwoRegionOp : TEST_Op<"two_region_op", []> {
+  let regions = (region AnyRegion, AnyRegion);
+}
+
+def SizedRegionOp : TEST_Op<"sized_region_op", []> {
+  let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Traits
+//===----------------------------------------------------------------------===//
+
+def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
+    [SameOperandsAndResultElementType]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
+def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
+    [SameOperandsAndResultShape]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
+def ArgAndResHaveFixedElementTypesOp :
+    TEST_Op<"arg_and_res_have_fixed_element_types",
+      [PredOpTrait<"fixed type combination",
+         Or<[And<[ElementTypeIsPred<"x", I32>,
+                  ElementTypeIsPred<"y", F32>]>,
+             ElementTypeIsPred<"attr", I8>]>>,
+      ElementTypeIs<"res", I16>]> {
+  let arguments = (ins
+    AnyVectorOrTensor:$x, AnyVectorOrTensor:$y, AnyAttr:$attr);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
+def OperandsHaveSameElementType : TEST_Op<"operands_have_same_element_type", [
+    AllElementTypesMatch<["x", "y"]>]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+}
+
+def OperandOneAndResultHaveSameElementType : TEST_Op<
+    "operand_one_and_result_have_same_element_type",
+    [AllElementTypesMatch<["x", "res"]>]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+  let results = (outs AnyTensor:$res);
+}
+
+def OperandsHaveSameType :
+    TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+}
+
+def OperandOneAndResultHaveSameType :
+    TEST_Op<"operand_one_and_result_have_same_type",
+            [AllTypesMatch<["x", "res"]>]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+  let results = (outs AnyTensor:$res);
+}
+
+def IfFirstOperandIsNoneThenSoIsSecond :
+    TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
+    "has either both none type operands or first is not none",
+     Or<[
+        And<[TypeIsPred<"x", NoneType>, TypeIsPred<"y", NoneType>]>,
+        Neg<TypeIsPred<"x", NoneType>>]>>]> {
+  let arguments = (ins AnyType:$x, AnyType:$y);
+}
+
+def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+  let results = (outs AnyTensor:$res);
+}
+
+// There the "HasParent" trait.
+def ParentOp : TEST_Op<"parent">;
+def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
+
+
+def TerminatorOp : TEST_Op<"finish", [Terminator]> {
+}
+def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
+    [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let regions = (region SizedRegion<1>:$region);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns
+//===----------------------------------------------------------------------===//
+
+def OpA : TEST_Op<"op_a"> {
+  let arguments = (ins I32:$operand, I32Attr:$attr);
+  let results = (outs I32:$result);
+}
+
+def OpB : TEST_Op<"op_b"> {
+  let arguments = (ins I32:$operand, I32Attr:$attr);
+  let results = (outs I32:$result);
+}
+
+// Test named pattern.
+def TestNamedPatternRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+
+// Test with fused location.
+def : Pat<(OpA (OpA $input, $attr), $bttr), (OpB $input, $bttr)>;
+
+// Test added benefit.
+def OpD : TEST_Op<"op_d">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpE : TEST_Op<"op_e">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpF : TEST_Op<"op_f">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpG : TEST_Op<"op_g">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+// Verify that bumping benefit results in selecting different op.
+def : Pat<(OpD $input), (OpE $input)>;
+def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
+// Verify that patterns with more source nodes are selected before those with fewer.
+def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
+def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
+
+// Test patterns for zero-result op.
+def OpH : TEST_Op<"op_h">, Arguments<(ins I32:$arg)>, Results<(outs)>;
+def OpI : TEST_Op<"op_i">, Arguments<(ins I32:$arg)>, Results<(outs)>;
+def : Pat<(OpH $input), (OpI $input)>;
+
+// Test patterns for zero-input op.
+def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32:$res)>;
+def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32:$res)>;
+def : Pat<(OpJ), (OpK)>;
+
+// Test NativeCodeCall.
+def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
+  let arguments = (ins
+    I32:$input1, I32:$input2,
+    BoolAttr:$choice,
+    I64Attr:$attr1, I64Attr:$attr2
+  );
+  let results = (outs I32:$output);
+}
+def OpNativeCodeCall2 : TEST_Op<"native_code_call2"> {
+  let arguments = (ins I32:$input, I64ArrayAttr:$attr);
+  let results = (outs I32:$output);
+}
+// Native code call to invoke a C++ function
+def CreateOperand: NativeCodeCall<"chooseOperand($0, $1, $2)">;
+// Native code call to invoke a C++ expression
+def CreateArraryAttr: NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
+// Test that we can use NativeCodeCall to create operand and attribute.
+// This pattern chooses between $input1 and $input2 according to $choice and
+// it combines $attr1 and $attr2 into an array attribute.
+def : Pat<(OpNativeCodeCall1 $input1, $input2,
+                             ConstBoolAttrTrue:$choice, $attr1, $attr2),
+          (OpNativeCodeCall2 (CreateOperand $input1, $input2, $choice),
+                             (CreateArraryAttr $attr1, $attr2))>;
+// Note: the following is just for testing purpose.
+// Should use the replaceWithValue directive instead.
+def UseOpResult: NativeCodeCall<"$0">;
+// Test that we can use NativeCodeCall to create result.
+def : Pat<(OpNativeCodeCall1 $input1, $input2,
+                             ConstBoolAttrFalse, $attr1, $attr2),
+          (UseOpResult $input2)>;
+
+// Test AllAttrConstraintsOf.
+def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
+  let arguments = (ins I64ArrayAttr:$attr);
+  let results = (outs I32:$output);
+}
+def OpAllAttrConstraint2 : TEST_Op<"all_attr_constraint_of2"> {
+  let arguments = (ins I64ArrayAttr:$attr);
+  let results = (outs I32:$output);
+}
+def Constraint0 : AttrConstraint<
+    CPred<"$_self.cast<ArrayAttr>().getValue()[0]."
+          "cast<IntegerAttr>().getInt() == 0">,
+    "[0] == 0">;
+def Constraint1 : AttrConstraint<
+    CPred<"$_self.cast<ArrayAttr>().getValue()[1]."
+          "cast<IntegerAttr>().getInt() == 1">,
+    "[1] == 1">;
+def : Pat<(OpAllAttrConstraint1
+            AllAttrConstraintsOf<[Constraint0, Constraint1]>:$attr),
+          (OpAllAttrConstraint2 $attr)>;
+
+// Op for testing RewritePattern removing op with inner ops.
+def TestOpWithRegion : TEST_Op<"op_with_region"> {
+  let regions = (region SizedRegion<1>:$region);
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Symbol Binding)
+
+// Test symbol binding.
+def OpSymbolBindingA : TEST_Op<"symbol_binding_a", []> {
+  let arguments = (ins I32:$operand, I64Attr:$attr);
+  let results = (outs I32:$result);
+}
+def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> {
+  let arguments = (ins I32:$operand);
+  let results = (outs I32:$result);
+
+  let builders = [
+    OpBuilder<
+      "Builder *builder, OperationState *state, Value *operand",
+      [{
+        state->types.assign({builder->getIntegerType(32)});
+        state->addOperands({operand});
+      }]>
+  ];
+}
+def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> {
+  let arguments = (ins I32:$operand);
+  let results = (outs I32:$result);
+  let builders = OpSymbolBindingB.builders;
+}
+def OpSymbolBindingD : TEST_Op<"symbol_binding_d", []> {
+  let arguments = (ins I32:$input1, I32:$input2, I64Attr:$attr);
+  let results = (outs I32:$result);
+}
+def HasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
+def : Pattern<
+    // Bind to source pattern op operand/attribute/result
+    (OpSymbolBindingA:$res_a $operand, $attr), [
+        // Bind to auxiliary op result
+        (OpSymbolBindingC:$res_c (OpSymbolBindingB:$res_b $operand)),
+
+        // Use bound symbols in resultant ops
+        (OpSymbolBindingD $res_b, $res_c, $attr)],
+    // Use bound symbols in additional constraints
+    [(HasOneUse $res_a)]>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Attributes)
+
+// Test matching against op attributes.
+def OpAttrMatch1 : TEST_Op<"match_op_attribute1"> {
+  let arguments = (ins
+    I32Attr:$required_attr,
+    OptionalAttr<I32Attr>:$optional_attr,
+    DefaultValuedAttr<I32Attr, "42">:$default_valued_attr,
+    I32Attr:$more_attr
+  );
+  let results = (outs I32:$output);
+}
+def OpAttrMatch2 : TEST_Op<"match_op_attribute2"> {
+  let arguments = OpAttrMatch1.arguments;
+  let results = (outs I32:$output);
+}
+def MoreConstraint : AttrConstraint<
+    CPred<"$_self.cast<IntegerAttr>().getInt() == 4">, "more constraint">;
+def : Pat<(OpAttrMatch1 $required, $optional, $default_valued,
+                        MoreConstraint:$more),
+          (OpAttrMatch2 $required, $optional, $default_valued, $more)>;
+
+// Test unit attrs.
+def OpAttrMatch3 : TEST_Op<"match_op_attribute3"> {
+  let arguments = (ins UnitAttr:$attr);
+  let results = (outs I32);
+}
+def OpAttrMatch4 : TEST_Op<"match_op_attribute4"> {
+  let arguments = (ins UnitAttr:$attr1, UnitAttr:$attr2);
+  let results = (outs I32);
+}
+def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>;
+
+// Test with constant attr.
+def OpC : TEST_Op<"op_c">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
+
+// Test string enum attribute in rewrites.
+def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
+// Test integer enum attribute in rewrites.
+def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
+def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
+def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
+
+def MultiResultOpEnum: I64EnumAttr<
+  "Multi-result op kinds", "", [
+    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
+  ]>;
+
+def ThreeResultOp : TEST_Op<"three_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def TwoResultOp : TEST_Op<"two_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2);
+
+  let builders = [
+    OpBuilder<
+      "Builder *builder, OperationState *state, IntegerAttr kind",
+      [{
+        auto i32 = builder->getIntegerType(32);
+        auto f32 = builder->getF32Type();
+        state->types.assign({i32, f32});
+        state->addAttribute("kind", kind);
+      }]>
+  ];
+}
+
+def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs F32:$result1, F32:$result2);
+}
+
+def OneResultOp1 : TEST_Op<"one_result1"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs F32:$result1);
+}
+
+def OneResultOp2 : TEST_Op<"one_result2"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1);
+}
+
+def OneResultOp3 : TEST_Op<"one_result3"> {
+  let arguments = (ins F32:$input);
+  let results = (outs I32:$result1);
+}
+
+// Test using multi-result op as a whole
+def : Pat<(ThreeResultOp MultiResultOpKind1),
+          (AnotherThreeResultOp MultiResultOpKind1)>;
+
+// Test using multi-result op as a whole for partial replacement
+def : Pattern<(ThreeResultOp MultiResultOpKind2),
+              [(TwoResultOp MultiResultOpKind2),
+               (OneResultOp1 MultiResultOpKind2)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3),
+              [(OneResultOp2 MultiResultOpKind3),
+               (AnotherTwoResultOp MultiResultOpKind3)]>;
+
+// Test using results separately in a multi-result op
+def : Pattern<(ThreeResultOp MultiResultOpKind4),
+              [(TwoResultOp:$res1__0 MultiResultOpKind4),
+               (OneResultOp1 MultiResultOpKind4),
+               (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+
+// Test referencing a single value in the value pack
+// This rule only matches TwoResultOp if its second result has no use.
+def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
+              [(OneResultOp2 MultiResultOpKind5),
+               (OneResultOp1 MultiResultOpKind5)],
+              [(HasNoUseOf:$res__1)]>;
+
+// Test using auxiliary ops for replacing multi-result op
+def : Pattern<
+    (ThreeResultOp MultiResultOpKind6), [
+        // Auxiliary op generated to help building the final result but not
+        // directly used to replace the source op's results.
+        (TwoResultOp:$interm MultiResultOpKind6),
+
+        (OneResultOp3 $interm__1),
+        (AnotherTwoResultOp MultiResultOpKind6)
+    ]>;
+
+//===----------------------------------------------------------------------===//
+// Test Legalization
+//===----------------------------------------------------------------------===//
+
+def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
+def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;
+
+def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
+  [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;
+
+def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32:$res)>;
+def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32:$res)>;
+def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32:$res)>;
+def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32:$res)>;
+def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32:$res)>;
+def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32:$res)>;
+def LegalOpA : TEST_Op<"legal_op_a">,
+  Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32:$res)>;
+
+// Check that smaller pattern depths are chosen, i.e. prioritize more direct
+// mappings.
+def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>;
+
+def : Pat<(ILLegalOpA), (ILLegalOpB)>;
+def : Pat<(ILLegalOpB), (LegalOpA Test_LegalizerEnum_Failure)>;
+
+// Check that the higher benefit pattern is taken for multiple legalizations
+// with the same depth.
+def : Pat<(ILLegalOpC), (ILLegalOpD)>;
+def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;
+
+def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
+def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;
+
+// Check that patterns use the most up-to-date value when being replaced.
+def TestRewriteOp : TEST_Op<"rewrite">,
+  Arguments<(ins AnyType:$input)>, Results<(outs AnyType:$res)>;
+def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
+
+//===----------------------------------------------------------------------===//
+// Test Type Legalization
+//===----------------------------------------------------------------------===//
+
+def TestRegionBuilderOp : TEST_Op<"region_builder">;
+def TestReturnOp : TEST_Op<"return", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
+def TestCastOp : TEST_Op<"cast">,
+  Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
+def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
+def TestValidOp : TEST_Op<"valid", [Terminator]>,
+  Arguments<(ins Variadic<AnyType>:$inputs)>;
+
+//===----------------------------------------------------------------------===//
+// Test region argument list parsing.
+//===----------------------------------------------------------------------===//
+
+def PolyForOp : TEST_Op<"polyfor">
+{
+  let summary =  "polyfor operation";
+  let description = [{
+    Test op with multiple region arguments, each argument of index type.
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+#endif // TEST_OPS
diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
new file mode 100644
index 0000000..584ff99
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -0,0 +1,252 @@
+//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "TestDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+using namespace mlir;
+
+// Native function for testing NativeCodeCall
+static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) {
+  return choice.getValue() ? input1 : input2;
+}
+
+namespace {
+#include "TestPatterns.inc"
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Canonicalizer Driver.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    populateWithGenerated(&getContext(), &patterns);
+
+    // Verify named pattern is generated with expected name.
+    patterns.insert<TestNamedPatternRule>(&getContext());
+
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+} // end anonymous namespace
+
+static mlir::PassRegistration<TestPatternDriver>
+    pass("test-patterns", "Run test dialect patterns");
+
+//===----------------------------------------------------------------------===//
+// Legalization Driver.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pattern is a simple pattern that inlines the first region of a given
+/// operation into the parent region.
+struct TestRegionRewriteBlockMovement : public ConversionPattern {
+  TestRegionRewriteBlockMovement(MLIRContext *ctx)
+      : ConversionPattern("test.region", 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Inline this region into the parent region.
+    auto &parentRegion = *op->getParentRegion();
+    rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
+                                parentRegion.end());
+
+    // Drop this operation.
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+/// This pattern is a simple pattern that generates a region containing an
+/// illegal operation.
+struct TestRegionRewriteUndo : public RewritePattern {
+  TestRegionRewriteUndo(MLIRContext *ctx)
+      : RewritePattern("test.region_builder", 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    // Create the region operation with an entry block containing arguments.
+    OperationState newRegion(op->getLoc(), "test.region");
+    newRegion.addRegion();
+    auto *regionOp = rewriter.createOperation(newRegion);
+    auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
+    entryBlock->addArgument(rewriter.getIntegerType(64));
+
+    // Add an explicitly illegal operation to ensure the conversion fails.
+    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
+    rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value *>());
+
+    // Drop this operation.
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+/// This pattern simply erases the given operation.
+struct TestDropOp : public ConversionPattern {
+  TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+};
+/// This pattern simply updates the operands of the given operation.
+struct TestPassthroughInvalidOp : public ConversionPattern {
+  TestPassthroughInvalidOp(MLIRContext *ctx)
+      : ConversionPattern("test.invalid", 1, ctx) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
+                                             llvm::None);
+    return matchSuccess();
+  }
+};
+/// This pattern handles the case of a split return value.
+struct TestSplitReturnType : public ConversionPattern {
+  TestSplitReturnType(MLIRContext *ctx)
+      : ConversionPattern("test.return", 1, ctx) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Check for a return of F32.
+    if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
+      return matchFailure();
+
+    // Check if the first operation is a cast operation, if it is we use the
+    // results directly.
+    auto *defOp = operands[0]->getDefiningOp();
+    if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
+      SmallVector<Value *, 2> returnOperands(packerOp.getOperands());
+      rewriter.replaceOpWithNewOp<TestReturnOp>(op, returnOperands);
+      return matchSuccess();
+    }
+
+    // Otherwise, fail to match.
+    return matchFailure();
+  }
+};
+} // namespace
+
+namespace {
+struct TestTypeConverter : public TypeConverter {
+  using TypeConverter::TypeConverter;
+
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
+    // Drop I16 types.
+    if (t.isInteger(16))
+      return success();
+
+    // Convert I64 to F64.
+    if (t.isInteger(64)) {
+      results.push_back(FloatType::getF64(t.getContext()));
+      return success();
+    }
+
+    // Split F32 into F16,F16.
+    if (t.isF32()) {
+      results.assign(2, FloatType::getF16(t.getContext()));
+      return success();
+    }
+
+    // Otherwise, convert the type directly.
+    results.push_back(t);
+    return success();
+  }
+
+  /// Override the hook to materialize a conversion. This is necessary because
+  /// we generate 1->N type mappings.
+  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+                                   ArrayRef<Value *> inputs,
+                                   Location loc) override {
+    return rewriter.create<TestCastOp>(loc, resultType, inputs);
+  }
+};
+
+struct TestLegalizePatternDriver
+    : public ModulePass<TestLegalizePatternDriver> {
+  /// The mode of conversion to use with the driver.
+  enum class ConversionMode { Analysis, Partial };
+
+  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+
+  void runOnModule() override {
+    TestTypeConverter converter;
+    mlir::OwningRewritePatternList patterns;
+    populateWithGenerated(&getContext(), &patterns);
+    patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+                    TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
+        &getContext());
+    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
+                                              converter);
+
+    // Define the conversion target used for the test.
+    ConversionTarget target(getContext());
+    target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
+    target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
+    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
+      // Don't allow F32 operands.
+      return llvm::none_of(op.getOperandTypes(),
+                           [](Type type) { return type.isF32(); });
+    });
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+
+    // Handle a partial conversion.
+    if (mode == ConversionMode::Partial) {
+      (void)applyPartialConversion(getModule(), target, patterns, &converter);
+      return;
+    }
+
+    // Otherwise, handle an analysis conversion.
+    assert(mode == ConversionMode::Analysis);
+
+    // Analyze the convertible operations.
+    DenseSet<Operation *> legalizedOps;
+    if (failed(applyAnalysisConversion(getModule(), target, patterns,
+                                       legalizedOps, &converter)))
+      return signalPassFailure();
+
+    // Emit remarks for each legalizable operation.
+    for (auto *op : legalizedOps)
+      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
+  }
+
+  /// The mode of conversion to use.
+  ConversionMode mode;
+};
+} // end anonymous namespace
+
+static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
+    legalizerConversionMode(
+        "test-legalize-mode",
+        llvm::cl::desc("The legalization mode to use with the test driver"),
+        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+        llvm::cl::values(
+            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+                       "analysis", "Perform an analysis conversion"),
+            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+                       "partial", "Perform a partial conversion")));
+
+static mlir::PassRegistration<TestLegalizePatternDriver> legalizer_pass(
+    "test-legalize-patterns", "Run test dialect legalization patterns",
+    [] { return new TestLegalizePatternDriver(legalizerConversionMode); });
diff --git a/third_party/mlir/test/lib/TestDialect/lit.local.cfg b/third_party/mlir/test/lib/TestDialect/lit.local.cfg
new file mode 100644
index 0000000..edb5b44
--- /dev/null
+++ b/third_party/mlir/test/lib/TestDialect/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove('.td')
\ No newline at end of file
diff --git a/third_party/mlir/test/lib/Transforms/CMakeLists.txt b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..fa66eb3
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_llvm_library(MLIRTestTransforms
+  TestConstantFold.cpp
+  TestLoopFusion.cpp
+  TestLoopMapping.cpp
+  TestLoopParametricTiling.cpp
+  TestVectorizationUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+  )
+add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
+target_link_libraries(MLIRTestTransforms
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIRLoopOps
+  MLIRPass
+  MLIRVectorOps
+  )
diff --git a/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp b/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp
new file mode 100644
index 0000000..7d17f60
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -0,0 +1,82 @@
+//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple constant folding pass.
+struct TestConstantFold : public FunctionPass<TestConstantFold> {
+  // All constants in the function post folding.
+  SmallVector<Operation *, 8> existingConstants;
+
+  void foldOperation(Operation *op, OperationFolder &helper);
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
+  auto processGeneratedConstants = [this](Operation *op) {
+    existingConstants.push_back(op);
+  };
+
+  // Attempt to fold the specified operation, including handling unused or
+  // duplicated constants.
+  (void)helper.tryToFold(op, processGeneratedConstants);
+}
+
+// For now, we do a simple top-down pass over a function folding constants.  We
+// don't handle conditional control flow, block arguments, folding conditional
+// branches, or anything else fancy.
+void TestConstantFold::runOnFunction() {
+  existingConstants.clear();
+
+  // Collect and fold the operations within the function.
+  SmallVector<Operation *, 8> ops;
+  getFunction().walk([&](Operation *op) { ops.push_back(op); });
+
+  // Fold the constants in reverse so that the last generated constants from
+  // folding are at the beginning. This creates somewhat of a linear ordering to
+  // the newly generated constants that matches the operation order and improves
+  // the readability of test cases.
+  OperationFolder helper;
+  for (Operation *op : llvm::reverse(ops))
+    foldOperation(op, helper);
+
+  // By the time we are done, we may have simplified a bunch of code, leaving
+  // around dead constants.  Check for them now and remove them.
+  for (auto *cst : existingConstants) {
+    if (cst->use_empty())
+      cst->erase();
+  }
+}
+
+/// Creates a constant folding pass.
+FunctionPassBase *mlir::createTestConstantFoldPass() {
+  return new TestConstantFold();
+}
+
+static PassRegistration<TestConstantFold>
+    pass("test-constant-fold", "Test operation constant folding");
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp b/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp
new file mode 100644
index 0000000..3999096
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp
@@ -0,0 +1,175 @@
+//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to test various loop fusion utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LoopFusionUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "test-loop-fusion"
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<bool> clTestDependenceCheck(
+    "test-loop-fusion-dependence-check",
+    llvm::cl::desc("Enable testing of loop fusion dependence check"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool> clTestSliceComputation(
+    "test-loop-fusion-slice-computation",
+    llvm::cl::desc("Enable testing of loop fusion slice computation"),
+    llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+struct TestLoopFusion : public FunctionPass<TestLoopFusion> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+FunctionPassBase *mlir::createTestLoopFusionPass() {
+  return new TestLoopFusion;
+}
+
+// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
+static void
+gatherLoops(Block *block, unsigned currLoopDepth,
+            DenseMap<unsigned, SmallVector<AffineForOp, 2>> &depthToLoops) {
+  auto &loopsAtDepth = depthToLoops[currLoopDepth];
+  for (auto &op : *block) {
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      loopsAtDepth.push_back(forOp);
+      gatherLoops(forOp.getBody(), currLoopDepth + 1, depthToLoops);
+    }
+  }
+}
+
+// Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
+// in range ['loopDepth' + 1, 'maxLoopDepth'].
+// Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
+static void testDependenceCheck(SmallVector<AffineForOp, 2> &loops, unsigned i,
+                                unsigned j, unsigned loopDepth,
+                                unsigned maxLoopDepth) {
+  AffineForOp srcForOp = loops[i];
+  AffineForOp dstForOp = loops[j];
+  mlir::ComputationSliceState sliceUnion;
+  for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
+    FusionResult result =
+        mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
+    if (result.value == FusionResult::FailBlockDependence) {
+      srcForOp.getOperation()->emitRemark("block-level dependence preventing"
+                                          " fusion of loop nest ")
+          << i << " into loop nest " << j << " at depth " << loopDepth;
+    }
+  }
+}
+
+// Returns the index of 'op' in its block.
+static unsigned getBlockIndex(Operation &op) {
+  unsigned index = 0;
+  for (auto &opX : *op.getBlock()) {
+    if (&op == &opX)
+      break;
+    ++index;
+  }
+  return index;
+}
+
+// Returns a string representation of 'sliceUnion'.
+static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  // Slice insertion point format [loop-depth, operation-block-index]
+  unsigned ipd = getNestingDepth(*sliceUnion.insertPoint);
+  unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
+  os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
+     << ")";
+  assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
+  os << " loop bounds: ";
+  for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
+    os << '[';
+    sliceUnion.lbs[k].print(os);
+    os << ", ";
+    sliceUnion.ubs[k].print(os);
+    os << "] ";
+  }
+  return os.str();
+}
+
+// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
+// in range ['loopDepth' + 1, 'maxLoopDepth'].
+// Emits a string represention of the slice union as a remark on 'loops[j]'.
+static void testSliceComputation(SmallVector<AffineForOp, 2> &loops, unsigned i,
+                                 unsigned j, unsigned loopDepth,
+                                 unsigned maxLoopDepth) {
+  AffineForOp forOpA = loops[i];
+  AffineForOp forOpB = loops[j];
+  for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
+    mlir::ComputationSliceState sliceUnion;
+    FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
+    if (result.value == FusionResult::Success) {
+      forOpB.getOperation()->emitRemark("slice (")
+          << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
+          << " : " << getSliceStr(sliceUnion) << ")";
+    }
+  }
+}
+
+void TestLoopFusion::runOnFunction() {
+  // Gather all AffineForOps by loop depth.
+  DenseMap<unsigned, SmallVector<AffineForOp, 2>> depthToLoops;
+  for (auto &block : getFunction()) {
+    gatherLoops(&block, /*currLoopDepth=*/0, depthToLoops);
+  }
+
+  // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
+  for (auto &depthAndLoops : depthToLoops) {
+    unsigned loopDepth = depthAndLoops.first;
+    auto &loops = depthAndLoops.second;
+    unsigned numLoops = loops.size();
+    for (unsigned j = 0; j < numLoops; ++j) {
+      for (unsigned k = 0; k < numLoops; ++k) {
+        if (j == k)
+          continue;
+        if (clTestDependenceCheck)
+          testDependenceCheck(loops, j, k, loopDepth, depthToLoops.size());
+        if (clTestSliceComputation)
+          testSliceComputation(loops, j, k, loopDepth, depthToLoops.size());
+      }
+    }
+  }
+}
+
+static PassRegistration<TestLoopFusion>
+    pass("test-loop-fusion", "Tests loop fusion utility functions.");
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp b/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp
new file mode 100644
index 0000000..bf35467
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp
@@ -0,0 +1,65 @@
+//===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to parametrically map loop.for loops to virtual
+// processing element dimensions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+namespace {
+class TestLoopMappingPass : public FunctionPass<TestLoopMappingPass> {
+public:
+  explicit TestLoopMappingPass() {}
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+
+    // SSA values for the transformation are created out of thin air by
+    // unregistered "new_processor_id_and_range" operations. This is enough to
+    // emulate mapping conditions.
+    SmallVector<Value *, 8> processorIds, numProcessors;
+    func.walk([&processorIds, &numProcessors](Operation *op) {
+      if (op->getName().getStringRef() != "new_processor_id_and_range")
+        return;
+      processorIds.push_back(op->getResult(0));
+      numProcessors.push_back(op->getResult(1));
+    });
+
+    func.walk<loop::ForOp>([&processorIds, &numProcessors](loop::ForOp op) {
+      // Ignore nested loops.
+      if (op.getParentRegion()->getParentOfType<loop::ForOp>())
+        return;
+      mapLoopToProcessorIds(op, processorIds, numProcessors);
+    });
+  }
+};
+} // end namespace
+
+static PassRegistration<TestLoopMappingPass>
+    reg("test-mapping-to-processing-elements",
+        "test mapping a single loop on a virtual processor grid",
+        [] { return new TestLoopMappingPass(); });
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp
new file mode 100644
index 0000000..d30eacc
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp
@@ -0,0 +1,71 @@
+//===- TestLoopParametricTiling.cpp --- Parametric loop tiling pass -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to parametrically tile nests of standard loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+static llvm::cl::list<int> clOuterLoopSizes(
+    "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
+    llvm::cl::desc(
+        "fixed number of iterations that the outer loops should have"));
+
+namespace {
+// Extracts fixed-range loops for top-level loop nests with ranges defined in
+// the pass constructor.  Assumes loops are permutable.
+class SimpleParametricLoopTilingPass
+    : public FunctionPass<SimpleParametricLoopTilingPass> {
+public:
+  explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes)
+      : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {}
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    func.walk<loop::ForOp>([this](loop::ForOp op) {
+      // Ignore nested loops.
+      if (op.getParentRegion()->getParentOfType<loop::ForOp>())
+        return;
+      extractFixedOuterLoops(op, sizes);
+    });
+  }
+
+  SmallVector<int64_t, 4> sizes;
+};
+} // end namespace
+
+FunctionPassBase *
+mlir::createSimpleParametricTilingPass(ArrayRef<int64_t> outerLoopSizes) {
+  return new SimpleParametricLoopTilingPass(outerLoopSizes);
+}
+
+static PassRegistration<SimpleParametricLoopTilingPass>
+    reg("test-extract-fixed-outer-loops",
+        "test application of parametric tiling to the outer loops so that the "
+        "ranges of outer loops become static",
+        [] {
+          auto *pass = new SimpleParametricLoopTilingPass({});
+          pass->sizes.assign(clOuterLoopSizes.begin(), clOuterLoopSizes.end());
+          return pass;
+        });
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
new file mode 100644
index 0000000..b51de41
--- /dev/null
+++ b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
@@ -0,0 +1,301 @@
+//===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple testing pass for vectorization functionality.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-vectorizer-test"
+
+using namespace mlir;
+
+using llvm::SetVector;
+
+using functional::map;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::list<int> clTestVectorShapeRatio(
+    "vector-shape-ratio",
+    llvm::cl::desc("Specify the HW vector size for vectorization"),
+    llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clTestForwardSlicingAnalysis(
+    "forward-slicing",
+    llvm::cl::desc("Enable testing forward static slicing and topological sort "
+                   "functionalities"),
+    llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clTestBackwardSlicingAnalysis(
+    "backward-slicing",
+    llvm::cl::desc("Enable testing backward static slicing and "
+                   "topological sort functionalities"),
+    llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clTestSlicingAnalysis(
+    "slicing",
+    llvm::cl::desc("Enable testing static slicing and topological sort "
+                   "functionalities"),
+    llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clTestComposeMaps(
+    "compose-maps",
+    llvm::cl::desc(
+        "Enable testing the composition of AffineMap where each "
+        "AffineMap in the composition is specified as the affine_map attribute "
+        "in a constant op."),
+    llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clTestNormalizeMaps(
+    "normalize-maps",
+    llvm::cl::desc(
+        "Enable testing the normalization of AffineAffineApplyOp "
+        "where each AffineAffineApplyOp in the composition is a single output "
+        "operation."),
+    llvm::cl::cat(clOptionsCategory));
+
+namespace {
+struct VectorizerTestPass : public FunctionPass<VectorizerTestPass> {
+  static constexpr auto kTestAffineMapOpName = "test_affine_map";
+  static constexpr auto kTestAffineMapAttrName = "affine_map";
+
+  void runOnFunction() override;
+  void testVectorShapeRatio(llvm::raw_ostream &outs);
+  void testForwardSlicing(llvm::raw_ostream &outs);
+  void testBackwardSlicing(llvm::raw_ostream &outs);
+  void testSlicing(llvm::raw_ostream &outs);
+  void testComposeMaps(llvm::raw_ostream &outs);
+  void testNormalizeMaps();
+};
+
+} // end anonymous namespace
+
+void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
+  auto f = getFunction();
+  using matcher::Op;
+  SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
+                                clTestVectorShapeRatio.end());
+  auto subVectorType =
+      VectorType::get(shape, FloatType::getF32(f.getContext()));
+  // Only filter operations that operate on a strict super-vector and have one
+  // return. This makes testing easier.
+  auto filter = [&](Operation &op) {
+    assert(subVectorType.getElementType().isF32() &&
+           "Only f32 supported for now");
+    if (!matcher::operatesOnSuperVectorsOf(op, subVectorType)) {
+      return false;
+    }
+    if (op.getNumResults() != 1) {
+      return false;
+    }
+    return true;
+  };
+  auto pat = Op(filter);
+  SmallVector<NestedMatch, 8> matches;
+  pat.match(f, &matches);
+  for (auto m : matches) {
+    auto *opInst = m.getMatchedOperation();
+    // This is a unit test that only checks and prints shape ratio.
+    // As a consequence we write only Ops with a single return type for the
+    // purpose of this test. If we need to test more intricate behavior in the
+    // future we can always extend.
+    auto superVectorType = opInst->getResult(0)->getType().cast<VectorType>();
+    auto ratio = shapeRatio(superVectorType, subVectorType);
+    if (!ratio.hasValue()) {
+      opInst->emitRemark("NOT MATCHED");
+    } else {
+      outs << "\nmatched: " << *opInst << " with shape ratio: ";
+      interleaveComma(MutableArrayRef<unsigned>(*ratio), outs);
+    }
+  }
+}
+
+static NestedPattern patternTestSlicingOps() {
+  using functional::map;
+  using matcher::Op;
+  // Match all operations with the kTestSlicingOpName name.
+  auto filter = [](Operation &op) {
+    // Just use a custom op name for this test, it makes life easier.
+    return op.getName().getStringRef() == "slicing-test-op";
+  };
+  return Op(filter);
+}
+
+void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
+  auto f = getFunction();
+  outs << "\n" << f.getName();
+
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
+  for (auto m : matches) {
+    SetVector<Operation *> backwardSlice;
+    getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
+    outs << "\nmatched: " << *m.getMatchedOperation()
+         << " backward static slice: ";
+    for (auto *op : backwardSlice)
+      outs << "\n" << *op;
+  }
+}
+
+void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
+  auto f = getFunction();
+  outs << "\n" << f.getName();
+
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
+  for (auto m : matches) {
+    SetVector<Operation *> forwardSlice;
+    getForwardSlice(m.getMatchedOperation(), &forwardSlice);
+    outs << "\nmatched: " << *m.getMatchedOperation()
+         << " forward static slice: ";
+    for (auto *op : forwardSlice)
+      outs << "\n" << *op;
+  }
+}
+
+void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) {
+  auto f = getFunction();
+  outs << "\n" << f.getName();
+
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
+  for (auto m : matches) {
+    SetVector<Operation *> staticSlice = getSlice(m.getMatchedOperation());
+    outs << "\nmatched: " << *m.getMatchedOperation() << " static slice: ";
+    for (auto *op : staticSlice)
+      outs << "\n" << *op;
+  }
+}
+
+static bool customOpWithAffineMapAttribute(Operation &op) {
+  return op.getName().getStringRef() ==
+         VectorizerTestPass::kTestAffineMapOpName;
+}
+
+void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
+  auto f = getFunction();
+
+  using matcher::Op;
+  auto pattern = Op(customOpWithAffineMapAttribute);
+  SmallVector<NestedMatch, 8> matches;
+  pattern.match(f, &matches);
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(matches.size());
+  for (auto m : llvm::reverse(matches)) {
+    auto *opInst = m.getMatchedOperation();
+    auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
+                   .cast<AffineMapAttr>()
+                   .getValue();
+    maps.push_back(map);
+  }
+  AffineMap res;
+  for (auto m : maps) {
+    res = res ? res.compose(m) : m;
+  }
+  simplifyAffineMap(res).print(outs << "\nComposed map: ");
+}
+
+static bool affineApplyOp(Operation &op) { return isa<AffineApplyOp>(op); }
+
+static bool singleResultAffineApplyOpWithoutUses(Operation &op) {
+  auto app = dyn_cast<AffineApplyOp>(op);
+  return app && app.use_empty();
+}
+
+void VectorizerTestPass::testNormalizeMaps() {
+  using matcher::Op;
+
+  auto f = getFunction();
+
+  // Save matched AffineApplyOp that all need to be erased in the end.
+  auto pattern = Op(affineApplyOp);
+  SmallVector<NestedMatch, 8> toErase;
+  pattern.match(f, &toErase);
+  {
+    // Compose maps.
+    auto pattern = Op(singleResultAffineApplyOpWithoutUses);
+    SmallVector<NestedMatch, 8> matches;
+    pattern.match(f, &matches);
+    for (auto m : matches) {
+      auto app = cast<AffineApplyOp>(m.getMatchedOperation());
+      OpBuilder b(m.getMatchedOperation());
+      SmallVector<Value *, 8> operands(app.getOperands());
+      makeComposedAffineApply(b, app.getLoc(), app.getAffineMap(), operands);
+    }
+  }
+  // We should now be able to erase everything in reverse order in this test.
+  for (auto m : llvm::reverse(toErase)) {
+    m.getMatchedOperation()->erase();
+  }
+}
+
+void VectorizerTestPass::runOnFunction() {
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
+  // Only support single block functions at this point.
+  FuncOp f = getFunction();
+  if (f.getBlocks().size() != 1)
+    return;
+
+  std::string str;
+  llvm::raw_string_ostream outs(str);
+
+  if (!clTestVectorShapeRatio.empty())
+    testVectorShapeRatio(outs);
+
+  if (clTestForwardSlicingAnalysis)
+    testForwardSlicing(outs);
+
+  if (clTestBackwardSlicingAnalysis)
+    testBackwardSlicing(outs);
+
+  if (clTestSlicingAnalysis)
+    testSlicing(outs);
+
+  if (clTestComposeMaps)
+    testComposeMaps(outs);
+
+  if (clTestNormalizeMaps)
+    testNormalizeMaps();
+
+  if (!outs.str().empty()) {
+    emitRemark(UnknownLoc::get(&getContext()), outs.str());
+  }
+}
+
+FunctionPassBase *mlir::createVectorizerTestPass() {
+  return new VectorizerTestPass();
+}
+
+static PassRegistration<VectorizerTestPass>
+    pass("affine-vectorizer-test",
+         "Tests vectorizer standalone functionality.");
+
+#undef DEBUG_TYPE
diff --git a/third_party/mlir/test/lit.cfg.py b/third_party/mlir/test/lit.cfg.py
new file mode 100644
index 0000000..cf93894
--- /dev/null
+++ b/third_party/mlir/test/lit.cfg.py
@@ -0,0 +1,73 @@
+# -*- Python -*-
+
+import os
+import platform
+import re
+import subprocess
+import tempfile
+
+import lit.formats
+import lit.util
+
+from lit.llvm import llvm_config
+from lit.llvm.subst import ToolSubst
+from lit.llvm.subst import FindTool
+
+# Configuration file for the 'lit' test runner.
+
+# name: The name of this test suite.
+config.name = 'MLIR'
+
+config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
+
+# suffixes: A list of file extensions to treat as test files.
+config.suffixes = ['.td', '.mlir', '.toy']
+
+# test_source_root: The root path where tests are located.
+config.test_source_root = os.path.dirname(__file__)
+
+# test_exec_root: The root path where tests should be run.
+config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
+
+config.substitutions.append(('%PATH%', config.environment['PATH']))
+config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+
+llvm_config.with_system_environment(
+    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+
+llvm_config.use_default_substitutions()
+
+# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
+# subdirectories contain auxiliary inputs for various tests in their parent
+# directories.
+config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
+
+# test_source_root: The root path where tests are located.
+config.test_source_root = os.path.dirname(__file__)
+
+# test_exec_root: The root path where tests should be run.
+config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
+
+# Tweak the PATH to include the tools dir.
+llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+
+tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
+tools = [
+    'mlir-opt',
+    'mlir-tblgen',
+    'mlir-translate',
+    'mlir-edsc-builder-api-test',
+]
+
+# The following tools are optional
+tools.extend([
+    ToolSubst('toy-ch1', unresolved='ignore'),
+    ToolSubst('toy-ch2', unresolved='ignore'),
+    ToolSubst('toy-ch3', unresolved='ignore'),
+    ToolSubst('toy-ch4', unresolved='ignore'),
+    ToolSubst('toy-ch5', unresolved='ignore'),
+    ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
+    ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore')
+])
+
+llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/third_party/mlir/test/lit.site.cfg.py.in b/third_party/mlir/test/lit.site.cfg.py.in
new file mode 100644
index 0000000..830b65f
--- /dev/null
+++ b/third_party/mlir/test/lit.site.cfg.py.in
@@ -0,0 +1,53 @@
+@LIT_SITE_CFG_IN_HEADER@
+
+import sys
+
+config.host_triple = "@LLVM_HOST_TRIPLE@"
+config.target_triple = "@TARGET_TRIPLE@"
+config.llvm_src_root = "@LLVM_SOURCE_DIR@"
+config.llvm_obj_root = "@LLVM_BINARY_DIR@"
+config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
+config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@"
+config.llvm_shlib_dir = "@SHLIBDIR@"
+config.llvm_shlib_ext = "@SHLIBEXT@"
+config.llvm_exe_ext = "@EXEEXT@"
+config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
+config.python_executable = "@PYTHON_EXECUTABLE@"
+config.gold_executable = "@GOLD_EXECUTABLE@"
+config.ld64_executable = "@LD64_EXECUTABLE@"
+config.enable_shared = @ENABLE_SHARED@
+config.enable_assertions = @ENABLE_ASSERTIONS@
+config.targets_to_build = "@TARGETS_TO_BUILD@"
+config.native_target = "@LLVM_NATIVE_ARCH@"
+config.llvm_bindings = "@LLVM_BINDINGS@".split(' ')
+config.host_os = "@HOST_OS@"
+config.host_cc = "@HOST_CC@"
+config.host_cxx = "@HOST_CXX@"
+config.host_ldflags = "@HOST_LDFLAGS@"
+config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
+config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
+config.host_arch = "@HOST_ARCH@"
+config.mlir_src_root = "@MLIR_SOURCE_DIR@"
+config.mlir_obj_root = "@MLIR_BINARY_DIR@"
+config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
+config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@"
+config.build_examples = @LLVM_BUILD_EXAMPLES@
+config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
+config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
+config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
+
+# Support substitution of the tools_dir with user parameters. This is
+# used when we can't determine the tool dir at configuration time.
+try:
+    config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params
+    config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params
+except KeyError:
+    e = sys.exc_info()[1]
+    key, = e.args
+    lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key))
+
+import lit.llvm
+lit.llvm.initialize(lit_config, config)
+
+# Let the main config do the real work.
+lit_config.load_config(config, "@MLIR_SOURCE_DIR@/test/lit.cfg.py")
diff --git a/third_party/mlir/tools/CMakeLists.txt b/third_party/mlir/tools/CMakeLists.txt
new file mode 100644
index 0000000..2566dd8
--- /dev/null
+++ b/third_party/mlir/tools/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_subdirectory(mlir-cuda-runner)
+add_subdirectory(mlir-cpu-runner)
+add_subdirectory(mlir-opt)
+add_subdirectory(mlir-tblgen)
+add_subdirectory(mlir-translate)
diff --git a/third_party/mlir/tools/mlir-cpu-runner/CMakeLists.txt b/third_party/mlir/tools/mlir-cpu-runner/CMakeLists.txt
new file mode 100644
index 0000000..561fd9d
--- /dev/null
+++ b/third_party/mlir/tools/mlir-cpu-runner/CMakeLists.txt
@@ -0,0 +1,28 @@
+add_llvm_executable(mlir-cpu-runner
+  mlir-cpu-runner.cpp
+)
+llvm_update_compile_flags(mlir-cpu-runner)
+whole_archive_link(mlir-cpu-runner
+  MLIRLLVMIR
+  MLIRStandardOps
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRTranslation
+)
+target_link_libraries(mlir-cpu-runner PRIVATE
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIRControlFlowToCFG
+  MLIREDSC
+  MLIRExecutionEngine
+  MLIRIR
+  MLIRJitRunner
+  MLIRLLVMIR
+  MLIRParser
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRStandardToLLVM
+  MLIRSupport
+  LLVMCore
+  LLVMSupport
+)
diff --git a/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
new file mode 100644
index 0000000..f7023c4c
--- /dev/null
+++ b/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -0,0 +1,28 @@
+//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Main entry point to a command line utility that executes an MLIR file on the
+// CPU by  translating MLIR to LLVM IR before JIT-compiling and executing the
+// latter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/JitRunner.h"
+
+int main(int argc, char **argv) {
+  return mlir::JitRunnerMain(argc, argv, nullptr);
+}
diff --git a/third_party/mlir/tools/mlir-cuda-runner/CMakeLists.txt b/third_party/mlir/tools/mlir-cuda-runner/CMakeLists.txt
new file mode 100644
index 0000000..fda9122
--- /dev/null
+++ b/third_party/mlir/tools/mlir-cuda-runner/CMakeLists.txt
@@ -0,0 +1,74 @@
+set(LLVM_OPTIONAL_SOURCES
+  cuda-runtime-wrappers.cpp
+  mlir-cuda-runner.cpp
+  )
+
+if(MLIR_CUDA_RUNNER_ENABLED)
+  if (NOT ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD))
+    message(SEND_ERROR
+      "Building the mlir cuda runner requires the NVPTX backend")
+  endif()
+
+  # Configure CUDA runner support. Using check_language first allows us to give
+  # a custom error message.
+  include(CheckLanguage)
+  check_language(CUDA)
+  if (CMAKE_CUDA_COMPILER)
+    enable_language(CUDA)
+  else()
+    message(SEND_ERROR
+      "Building the mlir cuda runner requires a working CUDA install")
+  endif()
+
+  # We need the libcuda.so library.
+  find_library(CUDA_RUNTIME_LIBRARY cuda)
+
+  add_llvm_library(cuda-runtime-wrappers SHARED
+    cuda-runtime-wrappers.cpp
+  )
+  target_include_directories(cuda-runtime-wrappers
+    PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+    LLVMSupport
+  )
+  target_link_libraries(cuda-runtime-wrappers
+    LLVMSupport
+    ${CUDA_RUNTIME_LIBRARY}
+  )
+
+  set(FULL_LINK_LIBS
+    MLIRAffineOps
+    MLIRControlFlowToCFG
+    MLIRGPU
+    MLIRGPUtoCUDATransforms
+    MLIRGPUtoNVVMTransforms
+    MLIRLLVMIR
+    MLIRStandardOps
+    MLIRStandardToLLVM
+    MLIRTargetLLVMIR
+    MLIRTransforms
+    MLIRTranslation
+  )
+  set(LIBS
+    MLIRIR
+    MLIRParser
+    MLIREDSC
+    MLIRAnalysis
+    MLIRExecutionEngine
+    MLIRJitRunner
+    MLIRSupport
+    LLVMCore
+    LLVMSupport
+    ${CUDA_RUNTIME_LIBRARY}
+  )
+  add_llvm_executable(mlir-cuda-runner
+    mlir-cuda-runner.cpp
+  )
+  add_dependencies(mlir-cuda-runner cuda-runtime-wrappers)
+  target_include_directories(mlir-cuda-runner
+    PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+  )
+  llvm_update_compile_flags(mlir-cuda-runner)
+  whole_archive_link(mlir-cuda-runner ${FULL_LINK_LIBS})
+  target_link_libraries(mlir-cuda-runner PRIVATE ${FULL_LINK_LIBS} ${LIBS})
+
+endif()
diff --git a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
new file mode 100644
index 0000000..c394662
--- /dev/null
+++ b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -0,0 +1,108 @@
+//===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Implements C wrappers around the CUDA library for easy linking in ORC jit.
+// Also adds some debugging helpers that are helpful when writing MLIR code to
+// run on GPUs.
+//
+//===----------------------------------------------------------------------===//
+
+#include <assert.h>
+#include <memory.h>
+
+#include "llvm/Support/raw_ostream.h"
+
+#include "cuda.h"
+
+namespace {
+int32_t reportErrorIfAny(CUresult result, const char *where) {
+  if (result != CUDA_SUCCESS) {
+    llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
+  }
+  return result;
+}
+} // anonymous namespace
+
+extern "C" int32_t mcuModuleLoad(void **module, void *data) {
+  int32_t err = reportErrorIfAny(
+      cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
+      "ModuleLoad");
+  return err;
+}
+
+extern "C" int32_t mcuModuleGetFunction(void **function, void *module,
+                                        const char *name) {
+  return reportErrorIfAny(
+      cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
+                          reinterpret_cast<CUmodule>(module), name),
+      "GetFunction");
+}
+
+// The wrapper uses intptr_t instead of CUDA's unsigned int to match
+// the type of MLIR's index type. This avoids the need for casts in the
+// generated MLIR code.
+extern "C" int32_t mcuLaunchKernel(void *function, intptr_t gridX,
+                                   intptr_t gridY, intptr_t gridZ,
+                                   intptr_t blockX, intptr_t blockY,
+                                   intptr_t blockZ, int32_t smem, void *stream,
+                                   void **params, void **extra) {
+  return reportErrorIfAny(
+      cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
+                     gridZ, blockX, blockY, blockZ, smem,
+                     reinterpret_cast<CUstream>(stream), params, extra),
+      "LaunchKernel");
+}
+
+extern "C" void *mcuGetStreamHelper() {
+  CUstream stream;
+  reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
+  return stream;
+}
+
+extern "C" int32_t mcuStreamSynchronize(void *stream) {
+  return reportErrorIfAny(
+      cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
+}
+
+/// Helper functions for writing mlir example code
+
+// A struct that corresponds to how MLIR represents unknown-length 1d memrefs.
+struct memref_t {
+  float *values;
+  intptr_t length;
+};
+
+// Allows to register a pointer with the CUDA runtime. Helpful until
+// we have transfer functions implemented.
+extern "C" void mcuMemHostRegister(const memref_t arg, int32_t flags) {
+  reportErrorIfAny(
+      cuMemHostRegister(arg.values, arg.length * sizeof(float), flags),
+      "MemHostRegister");
+}
+
+/// Prints the given float array to stderr.
+extern "C" void mcuPrintFloat(const memref_t arg) {
+  if (arg.length == 0) {
+    llvm::outs() << "[]\n";
+    return;
+  }
+  llvm::outs() << "[" << arg.values[0];
+  for (int pos = 1; pos < arg.length; pos++) {
+    llvm::outs() << ", " << arg.values[pos];
+  }
+  llvm::outs() << "]\n";
+}
diff --git a/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
new file mode 100644
index 0000000..f75413f
--- /dev/null
+++ b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -0,0 +1,154 @@
+//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a command line utility that executes an MLIR file on the GPU by
+// translating MLIR to NVVM/LVVM IR before JIT-compiling and executing the
+// latter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/JitRunner.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "cuda.h"
+
+using namespace mlir;
+
+inline void emit_cuda_error(const llvm::Twine &message, const char *buffer,
+                            CUresult error, FuncOp &function) {
+  function.emitError(message.concat(" failed with error code ")
+                         .concat(llvm::Twine{error})
+                         .concat("[")
+                         .concat(buffer)
+                         .concat("]"));
+}
+
+#define RETURN_ON_CUDA_ERROR(expr, msg)                                        \
+  {                                                                            \
+    auto _cuda_error = (expr);                                                 \
+    if (_cuda_error != CUDA_SUCCESS) {                                         \
+      emit_cuda_error(msg, jitErrorBuffer, _cuda_error, function);             \
+      return {};                                                               \
+    }                                                                          \
+  }
+
+OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) {
+  char jitErrorBuffer[4096] = {0};
+
+  RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit");
+
+  // Linking requires a device context.
+  CUdevice device;
+  RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0), "cuDeviceGet");
+  CUcontext context;
+  RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device), "cuCtxCreate");
+  CUlinkState linkState;
+
+  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
+                               CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
+  void *jitOptionsVals[] = {jitErrorBuffer,
+                            reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
+
+  RETURN_ON_CUDA_ERROR(cuLinkCreate(2,              /* number of jit options */
+                                    jitOptions,     /* jit options */
+                                    jitOptionsVals, /* jit option values */
+                                    &linkState),
+                       "cuLinkCreate");
+
+  RETURN_ON_CUDA_ERROR(
+      cuLinkAddData(linkState, CUjitInputType::CU_JIT_INPUT_PTX,
+                    const_cast<void *>(static_cast<const void *>(ptx.c_str())),
+                    ptx.length(), function.getName().data(), /* kernel name */
+                    0,       /* number of jit options */
+                    nullptr, /* jit options */
+                    nullptr  /* jit option values */
+                    ),
+      "cuLinkAddData");
+
+  void *cubinData;
+  size_t cubinSize;
+  RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize),
+                       "cuLinkComplete");
+
+  char *cubinAsChar = static_cast<char *>(cubinData);
+  OwnedCubin result = llvm::make_unique<std::vector<char>>(
+      cubinAsChar, cubinAsChar + cubinSize);
+
+  // This will also destroy the cubin data.
+  RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState), "cuLinkDestroy");
+
+  return result;
+}
+
+namespace {
+struct GPULaunchFuncOpLowering : public LLVMOpLowering {
+public:
+  explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_) {}
+
+  // Convert the kernel arguments to an LLVM type, preserve the rest.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.clone(*op)->setOperands(operands);
+    return rewriter.replaceOp(op, llvm::None), matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+static LogicalResult runMLIRPasses(ModuleOp m) {
+  // As we gradually lower, the IR is inconsistent between passes. So do not
+  // verify inbetween.
+  PassManager pm(/*verifyPasses=*/false);
+
+  pm.addPass(createGpuKernelOutliningPass());
+  pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
+                                          OwningRewritePatternList &patterns) {
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    patterns.insert<GPULaunchFuncOpLowering>(converter);
+  }));
+  pm.addPass(createLowerGpuOpsToNVVMOpsPass());
+  pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
+  pm.addPass(createGenerateCubinAccessorPass());
+  pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());
+
+  if (failed(pm.run(m)))
+    return failure();
+
+  if (failed(m.verify()))
+    return failure();
+
+  return success();
+}
+
+int main(int argc, char **argv) {
+  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+}
diff --git a/third_party/mlir/tools/mlir-opt/CMakeLists.txt b/third_party/mlir/tools/mlir-opt/CMakeLists.txt
new file mode 100644
index 0000000..26f8885
--- /dev/null
+++ b/third_party/mlir/tools/mlir-opt/CMakeLists.txt
@@ -0,0 +1,57 @@
+set(LLVM_OPTIONAL_SOURCES
+  null.cpp
+)
+
+set(LIB_LIBS
+  MLIRAnalysis
+  MLIRLLVMIR
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRSupport
+)
+add_llvm_library(MLIRMlirOptLib
+  mlir-opt.cpp
+)
+target_link_libraries(MLIRMlirOptLib ${LIB_LIBS})
+
+set(LIBS
+  MLIRAffineOps
+  MLIRLoopsToGPU
+  MLIRAnalysis
+  MLIRControlFlowToCFG
+  MLIREDSC
+  MLIRFxpMathOps
+  MLIRGPU
+  MLIRGPUtoNVVMTransforms
+  MLIRGPUtoSPIRVTransforms
+  MLIRLinalg
+  MLIRLLVMIR
+  MLIRLoopOps
+  MLIRNVVMIR
+  MLIROptMain
+  MLIRParser
+  MLIRPass
+  MLIRQuantizerTransforms
+  MLIRQuantOps
+  MLIRSPIRV
+  MLIRSPIRVConversion
+  MLIRStandardOps
+  MLIRStandardToLLVM
+  MLIRTransforms
+  MLIRTestDialect
+  MLIRTestTransforms
+  MLIRSupport
+  MLIRVectorOps
+)
+if(MLIR_CUDA_CONVERSIONS_ENABLED)
+  list(APPEND LIBS
+    MLIRGPUtoCUDATransforms
+  )
+endif()
+add_llvm_executable(mlir-opt
+ mlir-opt.cpp
+)
+llvm_update_compile_flags(mlir-opt)
+whole_archive_link(mlir-opt ${LIBS})
+target_link_libraries(mlir-opt PRIVATE MLIRIR MLIRMlirOptLib ${LIBS} LLVMSupport)
diff --git a/third_party/mlir/tools/mlir-opt/mlir-opt.cpp b/third_party/mlir/tools/mlir-opt/mlir-opt.cpp
new file mode 100644
index 0000000..35bba1f
--- /dev/null
+++ b/third_party/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -0,0 +1,91 @@
+//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Main entry function for mlir-opt for when built as standalone binary.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/MlirOptMain.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace llvm;
+using namespace mlir;
+
+static cl::opt<std::string>
+    inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-"));
+
+static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+                                           cl::value_desc("filename"),
+                                           cl::init("-"));
+
+static cl::opt<bool>
+    splitInputFile("split-input-file",
+                   cl::desc("Split the input file into pieces and process each "
+                            "chunk independently"),
+                   cl::init(false));
+
+static cl::opt<bool>
+    verifyDiagnostics("verify-diagnostics",
+                      cl::desc("Check that emitted diagnostics match "
+                               "expected-* lines on the corresponding line"),
+                      cl::init(false));
+
+static cl::opt<bool>
+    verifyPasses("verify-each",
+                 cl::desc("Run the verifier after each transformation pass"),
+                 cl::init(true));
+
+static std::vector<const PassRegistryEntry *> *passList;
+
+int main(int argc, char **argv) {
+  llvm::PrettyStackTraceProgram x(argc, argv);
+  InitLLVM y(argc, argv);
+
+  // Register any pass manager command line options.
+  registerPassManagerCLOptions();
+
+  // Parse pass names in main to ensure static initialization completed.
+  llvm::cl::list<const PassRegistryEntry *, bool, PassNameParser> passList(
+      "", llvm::cl::desc("Compiler passes to run"));
+  ::passList = &passList;
+  cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
+
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    exit(1);
+  }
+
+  return failed(MlirOptMain(output->os(), std::move(file), passList,
+                            splitInputFile, verifyDiagnostics, verifyPasses));
+}
diff --git a/third_party/mlir/tools/mlir-tblgen/CMakeLists.txt b/third_party/mlir/tools/mlir-tblgen/CMakeLists.txt
new file mode 100644
index 0000000..b18b04a
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(LLVM_LINK_COMPONENTS
+  MLIRTableGen
+  Support
+  )
+
+add_tablegen(mlir-tblgen MLIR
+  EnumsGen.cpp
+  LLVMIRConversionGen.cpp
+  mlir-tblgen.cpp
+  OpDefinitionsGen.cpp
+  OpDocGen.cpp
+  ReferenceImplGen.cpp
+  RewriterGen.cpp
+  SPIRVUtilsGen.cpp
+  )
+set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
diff --git a/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp b/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp
new file mode 100644
index 0000000..36f2e04
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -0,0 +1,285 @@
+//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// EnumsGen generates common utility functions for enums.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using llvm::formatv;
+using llvm::isDigit;
+using llvm::raw_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringRef;
+using mlir::tblgen::EnumAttr;
+using mlir::tblgen::EnumAttrCase;
+
+static std::string makeIdentifier(StringRef str) {
+  if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
+    std::string newStr = std::string("_") + str.str();
+    return newStr;
+  }
+  return str.str();
+}
+
+static void emitEnumClass(const Record &enumDef, StringRef enumName,
+                          StringRef underlyingType, StringRef description,
+                          const std::vector<EnumAttrCase> &enumerants,
+                          raw_ostream &os) {
+  os << "// " << description << "\n";
+  os << "enum class " << enumName;
+
+  if (!underlyingType.empty())
+    os << " : " << underlyingType;
+  os << " {\n";
+
+  for (const auto &enumerant : enumerants) {
+    auto symbol = makeIdentifier(enumerant.getSymbol());
+    auto value = enumerant.getValue();
+    if (value >= 0) {
+      os << formatv("  {0} = {1},\n", symbol, value);
+    } else {
+      os << formatv("  {0},\n", symbol);
+    }
+  }
+  os << "};\n\n";
+}
+
+static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+                             StringRef cppNamespace, raw_ostream &os) {
+  std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
+  if (underlyingType.empty())
+    underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
+
+  const char *const mapInfo = R"(
+namespace llvm {
+template<> struct DenseMapInfo<{0}> {{
+  using StorageInfo = llvm::DenseMapInfo<{1}>;
+
+  static inline {0} getEmptyKey() {{
+    return static_cast<{0}>(StorageInfo::getEmptyKey());
+  }
+
+  static inline {0} getTombstoneKey() {{
+    return static_cast<{0}>(StorageInfo::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const {0} &val) {{
+    return StorageInfo::getHashValue(static_cast<{1}>(val));
+  }
+
+  static bool isEqual(const {0} &lhs, const {0} &rhs) {{
+    return lhs == rhs;
+  }
+};
+})";
+  os << formatv(mapInfo, qualName, underlyingType);
+  os << "\n\n";
+}
+
+static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  unsigned maxEnumVal = 0;
+  for (const auto &enumerant : enumerants) {
+    int64_t value = enumerant.getValue();
+    // Avoid generating the max value function if there is an enumerant without
+    // explicit value.
+    if (value < 0)
+      return;
+
+    maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
+  }
+
+  // Emit the function to return the max enum value
+  os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
+  os << formatv("  return {0};\n", maxEnumVal);
+  os << "}\n\n";
+}
+
+static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
+  os << "  switch (val) {\n";
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
+                  makeIdentifier(symbol), symbol);
+  }
+  os << "  }\n";
+  os << "  return \"\";\n";
+  os << "}\n\n";
+}
+
+static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
+                strToSymFnName);
+  os << formatv("  return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
+                enumName);
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
+                  makeIdentifier(symbol));
+  }
+  os << "      .Default(llvm::None);\n";
+  os << "}\n";
+}
+
+static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  std::string underlyingType = enumAttr.getUnderlyingType();
+  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  // Avoid generating the underlying value to symbol conversion function if
+  // there is an enumerant without explicit value.
+  if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
+        return enumerant.getValue() < 0;
+      }))
+    return;
+
+  os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
+                underlyingToSymFnName,
+                underlyingType.empty() ? std::string("unsigned")
+                                       : underlyingType)
+     << "  switch (value) {\n";
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    auto value = enumerant.getValue();
+    os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
+                  makeIdentifier(symbol));
+  }
+  os << "  default: return llvm::None;\n"
+     << "  }\n"
+     << "}\n\n";
+}
+
+static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+  std::string underlyingType = enumAttr.getUnderlyingType();
+  StringRef description = enumAttr.getDescription();
+  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(cppNamespace, namespaces, "::");
+
+  for (auto ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
+  // Emit the enum class definition
+  emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
+
+  // Emit coversion function declarations
+  if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
+        return enumerant.getValue() >= 0;
+      })) {
+    os << formatv(
+        "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
+        underlyingType.empty() ? std::string("unsigned") : underlyingType);
+  }
+  os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
+  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
+                strToSymFnName);
+
+  emitMaxValueFn(enumDef, os);
+
+  for (auto ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
+
+  // Emit DenseMapInfo for this enum class
+  emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+}
+
+static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  llvm::emitSourceFileHeader("Enum Utility Declarations", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
+  for (const auto *def : defs)
+    emitEnumDecl(*def, os);
+
+  return false;
+}
+
+static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(cppNamespace, namespaces, "::");
+
+  for (auto ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
+  emitSymToStrFn(enumDef, os);
+  emitStrToSymFn(enumDef, os);
+  emitUnderlyingToSymFn(enumDef, os);
+
+  for (auto ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
+  os << "\n";
+}
+
+static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  llvm::emitSourceFileHeader("Enum Utility Definitions", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
+  for (const auto *def : defs)
+    emitEnumDef(*def, os);
+
+  return false;
+}
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+    genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
+                 [](const RecordKeeper &records, raw_ostream &os) {
+                   return emitEnumDecls(records, os);
+                 });
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+    genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
+                [](const RecordKeeper &records, raw_ostream &os) {
+                  return emitEnumDefs(records, os);
+                });
diff --git a/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
new file mode 100644
index 0000000..150fb7c
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -0,0 +1,185 @@
+//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file uses tablegen definitions of the LLVM IR Dialect operations to
+// generate the code building the LLVM IR from it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+
+static bool emitError(const Twine &message) {
+  llvm::errs() << message << "\n";
+  return false;
+}
+
+namespace {
+// Helper structure to return a position of the substring in a string.
+struct StringLoc {
+  size_t pos;
+  size_t length;
+
+  // Take a substring identified by this location in the given string.
+  StringRef in(StringRef str) const { return str.substr(pos, length); }
+
+  // A location is invalid if its position is outside the string.
+  explicit operator bool() { return pos != std::string::npos; }
+};
+} // namespace
+
+// Find the next TableGen variable in the given pattern.  These variables start
+// with a `$` character and can contain alphannumeric characters or underscores.
+// Return the position of the variable in the pattern and its length, including
+// the `$` character.  The escape syntax `$$` is also detected and returned.
+static StringLoc findNextVariable(StringRef str) {
+  size_t startPos = str.find('$');
+  if (startPos == std::string::npos)
+    return {startPos, 0};
+
+  // If we see "$$", return immediately.
+  if (startPos != str.size() - 1 && str[startPos + 1] == '$')
+    return {startPos, 2};
+
+  // Otherwise, the symbol spans until the first character that is not
+  // alphanumeric or '_'.
+  size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; },
+                                  startPos + 1);
+  if (endPos == std::string::npos)
+    endPos = str.size();
+
+  return {startPos, endPos - startPos};
+}
+
+// Check if `name` is the name of the variadic operand of `op`.  The variadic
+// operand can only appear at the last position in the list of operands.
+static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
+  unsigned numOperands = op.getNumOperands();
+  if (numOperands == 0)
+    return false;
+  const auto &operand = op.getOperand(numOperands - 1);
+  return operand.isVariadic() && operand.name == name;
+}
+
+// Check if `result` is a known name of a result of `op`.
+static bool isResultName(const tblgen::Operator &op, StringRef name) {
+  for (int i = 0, e = op.getNumResults(); i < e; ++i)
+    if (op.getResultName(i) == name)
+      return true;
+  return false;
+}
+
+// Check if `name` is a known name of an attribute of `op`.
+static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
+  return llvm::any_of(
+      op.getAttributes(),
+      [name](const tblgen::NamedAttribute &attr) { return attr.name == name; });
+}
+
+// Check if `name` is a known name of an operand of `op`.
+static bool isOperandName(const tblgen::Operator &op, StringRef name) {
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i)
+    if (op.getOperand(i).name == name)
+      return true;
+  return false;
+}
+
+// Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
+// for one definition of a LLVM IR Dialect operation.  Return true on success.
+static bool emitOneBuilder(const Record &record, raw_ostream &os) {
+  auto op = tblgen::Operator(record);
+
+  if (!record.getValue("llvmBuilder"))
+    return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
+
+  // Return early if there is no builder specified.
+  auto builderStrRef = record.getValueAsString("llvmBuilder");
+  if (builderStrRef.empty())
+    return true;
+
+  // Progressively create the builder string by replacing $-variables with
+  // value lookups.  Keep only the not-yet-traversed part of the builder pattern
+  // to avoid re-traversing the string multiple times.
+  std::string builder;
+  llvm::raw_string_ostream bs(builder);
+  while (auto loc = findNextVariable(builderStrRef)) {
+    auto name = loc.in(builderStrRef).drop_front();
+    // First, insert the non-matched part as is.
+    bs << builderStrRef.substr(0, loc.pos);
+    // Then, rewrite the name based on its kind.
+    bool isVariadicOperand = isVariadicOperandName(op, name);
+    if (isOperandName(op, name)) {
+      auto result = isVariadicOperand
+                        ? formatv("lookupValues(op.{0}())", name)
+                        : formatv("valueMapping.lookup(op.{0}())", name);
+      bs << result;
+    } else if (isAttributeName(op, name)) {
+      bs << formatv("op.{0}()", name);
+    } else if (isResultName(op, name)) {
+      bs << formatv("valueMapping[op.{0}()]", name);
+    } else if (name == "_resultType") {
+      bs << "op.getResult()->getType().cast<LLVM::LLVMType>()."
+            "getUnderlyingType()";
+    } else if (name == "_hasResult") {
+      bs << "opInst.getNumResults() == 1";
+    } else if (name == "_location") {
+      bs << "opInst.getLoc()";
+    } else if (name == "_numOperands") {
+      bs << "opInst.getNumOperands()";
+    } else if (name == "$") {
+      bs << '$';
+    } else {
+      return emitError(name + " is neither an argument nor a result of " +
+                       op.getOperationName());
+    }
+    // Finally, only keep the untraversed part of the string.
+    builderStrRef = builderStrRef.substr(loc.pos + loc.length);
+  }
+
+  // Output the check and the rewritten builder string.
+  os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
+     << ">(opInst)) {\n";
+  os << bs.str() << builderStrRef << "\n";
+  os << "  return success();\n";
+  os << "}\n";
+
+  return true;
+}
+
+// Emit all builders.  Returns false on success because of the generator
+// registration requirements.
+static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
+    if (!emitOneBuilder(*def, os))
+      return true;
+  }
+  return false;
+}
+
+static mlir::GenRegistration
+    genLLVMIRConversions("gen-llvmir-conversions",
+                         "Generate LLVM IR conversions", emitBuilders);
diff --git a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
new file mode 100644
index 0000000..1a96ddd
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -0,0 +1,1385 @@
+//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpDefinitionsGen uses the description of operations to generate C++
+// definitions for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static const char *const tblgenNamePrefix = "tblgen_";
+static const char *const generatedArgName = "tblgen_arg";
+static const char *const builderOpState = "tblgen_state";
+
+// The logic to calculate the dynamic value range for an static operand/result
+// of an op with variadic operands/results. Note that this logic is not for
+// general use; it assumes all variadic operands/results must have the same
+// number of values.
+//
+// {0}: The list of whether each static operand/result is variadic.
+// {1}: The total number of non-variadic operands/results.
+// {2}: The total number of variadic operands/results.
+// {3}: The total number of dynamic values.
+// {4}: The begin iterator of the dynamic values.
+// {5}: "operand" or "result"
+const char *valueRangeCalcCode = R"(
+  bool isVariadic[] = {{{0}};
+  int prevVariadicCount = 0;
+  for (unsigned i = 0; i < index; ++i)
+    if (isVariadic[i]) ++prevVariadicCount;
+
+  // Calculate how many dynamic values a static variadic {5} corresponds to.
+  // This assumes all static variadic {5}s have the same dynamic value count.
+  int variadicSize = ({3} - {1}) / {2};
+  // `index` passed in as the parameter is the static index which counts each
+  // {5} (variadic or not) as size 1. So here for each previous static variadic
+  // {5}, we need to offset by (variadicSize - 1) to get where the dynamic
+  // value pack for this static {5} starts.
+  int offset = index + (variadicSize - 1) * prevVariadicCount;
+  int size = isVariadic[index] ? variadicSize : 1;
+
+  return {{std::next({4}, offset), std::next({4}, offset + size)};
+)";
+
+static const char *const opCommentHeader = R"(
+//===----------------------------------------------------------------------===//
+// {0} {1}
+//===----------------------------------------------------------------------===//
+
+)";
+
+//===----------------------------------------------------------------------===//
+// Utility structs and functions
+//===----------------------------------------------------------------------===//
+
+// Returns whether the record has a value of the given name that can be returned
+// via getValueAsString.
+static inline bool hasStringAttribute(const Record &record,
+                                      StringRef fieldName) {
+  auto valueInit = record.getValueInit(fieldName);
+  return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
+}
+
+static std::string getArgumentName(const Operator &op, int index) {
+  const auto &operand = op.getOperand(index);
+  if (!operand.name.empty())
+    return operand.name;
+  else
+    return formatv("{0}_{1}", generatedArgName, index);
+}
+
+namespace {
+// Simple RAII helper for defining ifdef-undef-endif scopes.
+class IfDefScope {
+public:
+  IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
+    os << "#ifdef " << name << "\n"
+       << "#undef " << name << "\n\n";
+  }
+
+  ~IfDefScope() { os << "\n#endif  // " << name << "\n\n"; }
+
+private:
+  StringRef name;
+  raw_ostream &os;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Classes for C++ code emission
+//===----------------------------------------------------------------------===//
+
+// We emit the op declaration and definition into separate files: *Ops.h.inc
+// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
+// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
+//
+// In order to do this split, we need to track method signature and
+// implementation logic separately. Signature information is used for both
+// declaration and definition, while implementation logic is only for
+// definition. So we have the following classes for C++ code emission.
+
+namespace {
+// Class for holding the signature of an op's method for C++ code emission
+class OpMethodSignature {
+public:
+  OpMethodSignature(StringRef retType, StringRef name, StringRef params);
+
+  // Writes the signature as a method declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the signature as the start of a method definition to the given `os`.
+  // `namePrefix` is the prefix to be prepended to the method name (typically
+  // namespaces for qualifying the method definition).
+  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+private:
+  // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
+  static bool elideSpaceAfterType(StringRef type);
+
+  std::string returnType;
+  std::string methodName;
+  std::string parameters;
+};
+
+// Class for holding the body of an op's method for C++ code emission
+class OpMethodBody {
+public:
+  explicit OpMethodBody(bool declOnly);
+
+  OpMethodBody &operator<<(Twine content);
+  OpMethodBody &operator<<(int content);
+  OpMethodBody &operator<<(const FmtObjectBase &content);
+
+  void writeTo(raw_ostream &os) const;
+
+private:
+  // Whether this class should record method body.
+  bool isEffective;
+  std::string body;
+};
+
+// Class for holding an op's method for C++ code emission
+class OpMethod {
+public:
+  // Properties (qualifiers) of class methods. Bitfield is used here to help
+  // querying properties.
+  enum Property {
+    MP_None = 0x0,
+    MP_Static = 0x1,      // Static method
+    MP_Constructor = 0x2, // Constructor
+    MP_Private = 0x4,     // Private method
+  };
+
+  OpMethod(StringRef retType, StringRef name, StringRef params,
+           Property property, bool declOnly);
+
+  OpMethodBody &body();
+
+  // Returns true if this is a static method.
+  bool isStatic() const;
+
+  // Returns true if this is a private method.
+  bool isPrivate() const;
+
+  // Writes the method as a declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the method as a definition to the given `os`. `namePrefix` is the
+  // prefix to be prepended to the method name (typically namespaces for
+  // qualifying the method definition).
+  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+private:
+  Property properties;
+  // Whether this method only contains a declaration.
+  bool isDeclOnly;
+  OpMethodSignature methodSignature;
+  OpMethodBody methodBody;
+};
+
+// A class used to emit C++ classes from Tablegen.  Contains a list of public
+// methods and a list of private fields to be emitted.
+class Class {
+public:
+  explicit Class(StringRef name);
+
+  // Creates a new method in this class.
+  OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
+                      OpMethod::Property = OpMethod::MP_None,
+                      bool declOnly = false);
+
+  OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+
+  // Creates a new field in this class.
+  void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+
+  // Writes this op's class as a declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the method definitions in this op's class to the given `os`.
+  void writeDefTo(raw_ostream &os) const;
+
+  // Returns the C++ class name of the op.
+  StringRef getClassName() const { return className; }
+
+protected:
+  std::string className;
+  SmallVector<OpMethod, 8> methods;
+  SmallVector<std::string, 4> fields;
+};
+
+// Class for holding an op for C++ code emission
+class OpClass : public Class {
+public:
+  explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
+
+  // Adds an op trait.
+  void addTrait(Twine trait);
+
+  // Writes this op's class as a declaration to the given `os`.  Redefines
+  // Class::writeDeclTo to also emit traits and extra class declarations.
+  void writeDeclTo(raw_ostream &os) const;
+
+private:
+  StringRef extraClassDeclaration;
+  SmallVector<std::string, 4> traits;
+};
+} // end anonymous namespace
+
+OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
+                                     StringRef params)
+    : returnType(retType), methodName(name), parameters(params) {}
+
+void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
+     << "(" << parameters << ")";
+}
+
+void OpMethodSignature::writeDefTo(raw_ostream &os,
+                                   StringRef namePrefix) const {
+  // We need to remove the default values for parameters in method definition.
+  // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
+  // initializers. This is incorrect for initializer list with more than one
+  // element. Change to a more robust approach.
+  auto removeParamDefaultValue = [](StringRef params) {
+    std::string result;
+    std::pair<StringRef, StringRef> parts;
+    while (!params.empty()) {
+      parts = params.split("=");
+      result.append(result.empty() ? "" : ", ");
+      result.append(parts.first);
+      params = parts.second.split(",").second;
+    }
+    return result;
+  };
+
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
+     << (namePrefix.empty() ? "" : "::") << methodName << "("
+     << removeParamDefaultValue(parameters) << ")";
+}
+
+bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
+  return type.empty() || type.endswith("&") || type.endswith("*");
+}
+
+OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
+
+OpMethodBody &OpMethodBody::operator<<(Twine content) {
+  if (isEffective)
+    body.append(content.str());
+  return *this;
+}
+
+OpMethodBody &OpMethodBody::operator<<(int content) {
+  if (isEffective)
+    body.append(std::to_string(content));
+  return *this;
+}
+
+OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
+  if (isEffective)
+    body.append(content.str());
+  return *this;
+}
+
+void OpMethodBody::writeTo(raw_ostream &os) const {
+  auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
+  os << bodyRef;
+  if (bodyRef.empty() || bodyRef.back() != '\n')
+    os << "\n";
+}
+
+OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
+                   OpMethod::Property property, bool declOnly)
+    : properties(property), isDeclOnly(declOnly),
+      methodSignature(retType, name, params), methodBody(declOnly) {}
+
+OpMethodBody &OpMethod::body() { return methodBody; }
+
+bool OpMethod::isStatic() const { return properties & MP_Static; }
+
+bool OpMethod::isPrivate() const { return properties & MP_Private; }
+
+void OpMethod::writeDeclTo(raw_ostream &os) const {
+  os.indent(2);
+  if (isStatic())
+    os << "static ";
+  methodSignature.writeDeclTo(os);
+  os << ";";
+}
+
+void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+  if (isDeclOnly)
+    return;
+
+  methodSignature.writeDefTo(os, namePrefix);
+  os << " {\n";
+  methodBody.writeTo(os);
+  os << "}";
+}
+
+Class::Class(StringRef name) : className(name) {}
+
+OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
+                           OpMethod::Property property, bool declOnly) {
+  methods.emplace_back(retType, name, params, property, declOnly);
+  return methods.back();
+}
+
+OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
+  return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
+                   declOnly);
+}
+
+void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
+  std::string varName = formatv("{0} {1}", type, name).str();
+  std::string field = defaultValue.empty()
+                          ? varName
+                          : formatv("{0} = {1}", varName, defaultValue).str();
+  fields.push_back(std::move(field));
+}
+
+void Class::writeDeclTo(raw_ostream &os) const {
+  bool hasPrivateMethod = false;
+  os << "class " << className << " {\n";
+  os << "public:\n";
+  for (const auto &method : methods) {
+    if (!method.isPrivate()) {
+      method.writeDeclTo(os);
+      os << '\n';
+    } else {
+      hasPrivateMethod = true;
+    }
+  }
+  os << '\n';
+  os << "private:\n";
+  if (hasPrivateMethod) {
+    for (const auto &method : methods) {
+      if (method.isPrivate()) {
+        method.writeDeclTo(os);
+        os << '\n';
+      }
+    }
+    os << '\n';
+  }
+  for (const auto &field : fields)
+    os.indent(2) << field << ";\n";
+  os << "};\n";
+}
+
+void Class::writeDefTo(raw_ostream &os) const {
+  for (const auto &method : methods) {
+    method.writeDefTo(os, className);
+    os << "\n\n";
+  }
+}
+
+OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
+    : Class(name), extraClassDeclaration(extraClassDeclaration) {}
+
+// Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly.
+void OpClass::addTrait(Twine trait) {
+  traits.push_back(("OpTrait::" + trait).str());
+}
+
+void OpClass::writeDeclTo(raw_ostream &os) const {
+  os << "class " << className << " : public Op<" << className;
+  for (const auto &trait : traits)
+    os << ", " << trait;
+  os << "> {\npublic:\n";
+  os << "  using Op::Op;\n";
+  os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
+
+  bool hasPrivateMethod = false;
+  for (const auto &method : methods) {
+    if (!method.isPrivate()) {
+      method.writeDeclTo(os);
+      os << "\n";
+    } else {
+      hasPrivateMethod = true;
+    }
+  }
+
+  // TODO: Add line control markers to make errors easier to debug.
+  if (!extraClassDeclaration.empty())
+    os << extraClassDeclaration << "\n";
+
+  if (hasPrivateMethod) {
+    os << '\n';
+    os << "private:\n";
+    for (const auto &method : methods) {
+      if (method.isPrivate()) {
+        method.writeDeclTo(os);
+        os << "\n";
+      }
+    }
+  }
+
+  os << "};\n";
+}
+
+//===----------------------------------------------------------------------===//
+// Op emitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Helper class to emit a record into the given output stream.
+class OpEmitter {
+public:
+  static void emitDecl(const Operator &op, raw_ostream &os);
+  static void emitDef(const Operator &op, raw_ostream &os);
+
+private:
+  OpEmitter(const Operator &op);
+
+  void emitDecl(raw_ostream &os);
+  void emitDef(raw_ostream &os);
+
+  // Generates the `getOperationName` method for this op.
+  void genOpNameGetter();
+
+  // Generates getters for the attributes.
+  void genAttrGetters();
+
+  // Generates getters for named operands.
+  void genNamedOperandGetters();
+
+  // Generates getters for named results.
+  void genNamedResultGetters();
+
+  // Generates getters for named regions.
+  void genNamedRegionGetters();
+
+  // Generates builder methods for the operation.
+  void genBuilder();
+
+  // Generates the build() method that takes each result-type/operand/attribute
+  // as a stand-alone parameter. This build() method also requires specifying
+  // result types for all results.
+  void genSeparateParamBuilder();
+
+  // Generates the build() method that takes each operand/attribute as a
+  // stand-alone parameter. This build() method uses first operand's type
+  // as all result's types.
+  void genUseOperandAsResultTypeBuilder();
+
+  // Generates the build() method that takes each operand/attribute as a
+  // stand-alone parameter. This build() method uses first attribute's type
+  // as all result's types.
+  void genUseAttrAsResultTypeBuilder();
+
+  // Generates the build() method that takes all result types collectively as
+  // one parameter. Similarly for operands and attributes.
+  void genCollectiveParamBuilder();
+
+  // Builds the parameter list for build() method of this op. This method writes
+  // to `paramList` the comma-separated parameter list. If `includeResultTypes`
+  // is true then `paramList` will also contain the parameters for all results
+  // and `resultTypeNames` will be populated with the parameter name for each
+  // result type.
+  void buildParamList(std::string &paramList,
+                      SmallVectorImpl<std::string> &resultTypeNames,
+                      bool includeResultTypes);
+
+  // Adds op arguments and regions into operation state for build() methods.
+  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
+
+  // Generates canonicalizer declaration for the operation.
+  void genCanonicalizerDecls();
+
+  // Generates the folder declaration for the operation.
+  void genFolderDecls();
+
+  // Generates the parser for the operation.
+  void genParser();
+
+  // Generates the printer for the operation.
+  void genPrinter();
+
+  // Generates verify method for the operation.
+  void genVerifier();
+
+  // Generates verify statements for operands and results in the operation.
+  // The generated code will be attached to `body`.
+  void genOperandResultVerifier(OpMethodBody &body,
+                                Operator::value_range values,
+                                StringRef valueKind);
+
+  // Generates verify statements for regions in the operation.
+  // The generated code will be attached to `body`.
+  void genRegionVerifier(OpMethodBody &body);
+
+  // Generates the traits used by the object.
+  void genTraits();
+
+private:
+  // The TableGen record for this op.
+  // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
+  // it should rather go through the Operator for better abstraction.
+  const Record &def;
+
+  // The wrapper operator class for querying information from this op.
+  Operator op;
+
+  // The C++ code builder for this op
+  OpClass opClass;
+
+  // The format context for verification code generation.
+  FmtContext verifyCtx;
+};
+} // end anonymous namespace
+
+OpEmitter::OpEmitter(const Operator &op)
+    : def(op.getDef()), op(op),
+      opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
+  verifyCtx.withOp("(*this->getOperation())");
+
+  genTraits();
+  // Generate C++ code for various op methods. The order here determines the
+  // methods in the generated file.
+  genOpNameGetter();
+  genNamedOperandGetters();
+  genNamedResultGetters();
+  genNamedRegionGetters();
+  genAttrGetters();
+  genBuilder();
+  genParser();
+  genPrinter();
+  genVerifier();
+  genCanonicalizerDecls();
+  genFolderDecls();
+}
+
+void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+  OpEmitter(op).emitDecl(os);
+}
+
+void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
+  OpEmitter(op).emitDef(os);
+}
+
+void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
+
+void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
+
+void OpEmitter::genAttrGetters() {
+  FmtContext fctx;
+  fctx.withBuilder("mlir::Builder(this->getContext())");
+  for (auto &namedAttr : op.getAttributes()) {
+    const auto &name = namedAttr.name;
+    const auto &attr = namedAttr.attr;
+
+    auto &method = opClass.newMethod(attr.getReturnType(), name);
+    auto &body = method.body();
+
+    // Emit the derived attribute body.
+    if (attr.isDerivedAttr()) {
+      body << "  " << attr.getDerivedCodeBody() << "\n";
+      continue;
+    }
+
+    // Emit normal emitter.
+
+    // Return the queried attribute with the correct return type.
+    auto attrVal =
+        (attr.hasDefaultValueInitializer() || attr.isOptional())
+            ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name,
+                      attr.getStorageType())
+            : formatv("this->getAttr(\"{0}\").cast<{1}>()", name,
+                      attr.getStorageType());
+    body << "  auto attr = " << attrVal << ";\n";
+    if (attr.hasDefaultValueInitializer()) {
+      // Returns the default value if not set.
+      // TODO: this is inefficient, we are recreating the attribute for every
+      // call. This should be set instead.
+      std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                                       attr.getDefaultValueInitializer());
+      body << "    if (!attr)\n      return "
+           << tgfmt(attr.getConvertFromStorageCall(),
+                    &fctx.withSelf(defaultValue))
+           << ";\n";
+    }
+    body << "  return "
+         << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
+         << ";\n";
+  }
+}
+
+// Generates the named operand getter methods for the given Operator `op` and
+// puts them in `opClass`.  Uses `rangeType` as the return type of getters that
+// return a range of operands (individual operands are `Value *` and each
+// element in the range must also be `Value *`); use `rangeBeginCall` to get an
+// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
+// the number of operands. `getOperandCallPattern` contains the code necessary
+// to obtain a single operand whose position will be substituted instead of
+// "{0}" marker in the pattern.  Note that the pattern should work for any kind
+// of ops, in particular for one-operand ops that may not have the
+// `getOperand(unsigned)` method.
+static void generateNamedOperandGetters(const Operator &op, Class &opClass,
+                                        StringRef rangeType,
+                                        StringRef rangeBeginCall,
+                                        StringRef rangeSizeCall,
+                                        StringRef getOperandCallPattern) {
+  const int numOperands = op.getNumOperands();
+  const int numVariadicOperands = op.getNumVariadicOperands();
+  const int numNormalOperands = numOperands - numVariadicOperands;
+
+  if (numVariadicOperands > 1 && !op.hasTrait("SameVariadicOperandSize")) {
+    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
+                                 "specification over their sizes");
+  }
+
+  // First emit a "sink" getter method upon which we layer all nicer named
+  // getter methods.
+  auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
+
+  if (numVariadicOperands == 0) {
+    // We still need to match the return type, which is a range.
+    m.body() << "return {std::next(" << rangeBeginCall << ", index), std::next("
+             << rangeBeginCall << ", index + 1)};";
+  } else {
+    // Because the op can have arbitrarily interleaved variadic and non-variadic
+    // operands, we need to embed a list in the "sink" getter method for
+    // calculation at run-time.
+    llvm::SmallVector<StringRef, 4> isVariadic;
+    isVariadic.reserve(numOperands);
+    for (int i = 0; i < numOperands; ++i) {
+      isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic()));
+    }
+    std::string isVariadicList = llvm::join(isVariadic, ", ");
+
+    m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalOperands,
+                        numVariadicOperands, rangeSizeCall, rangeBeginCall,
+                        "operand");
+  }
+
+  // Then we emit nicer named getter methods by redirecting to the "sink" getter
+  // method.
+
+  for (int i = 0; i != numOperands; ++i) {
+    const auto &operand = op.getOperand(i);
+    if (operand.name.empty())
+      continue;
+
+    if (operand.isVariadic()) {
+      auto &m = opClass.newMethod(rangeType, operand.name);
+      m.body() << "return getODSOperands(" << i << ");";
+    } else {
+      auto &m = opClass.newMethod("Value *", operand.name);
+      m.body() << "return *getODSOperands(" << i << ").begin();";
+    }
+  }
+}
+
+void OpEmitter::genNamedOperandGetters() {
+  generateNamedOperandGetters(
+      op, opClass, /*rangeType=*/"Operation::operand_range",
+      /*rangeBeginCall=*/"getOperation()->operand_begin()",
+      /*rangeSizeCall=*/"getOperation()->getNumOperands()",
+      /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
+}
+
+void OpEmitter::genNamedResultGetters() {
+  const int numResults = op.getNumResults();
+  const int numVariadicResults = op.getNumVariadicResults();
+  const int numNormalResults = numResults - numVariadicResults;
+
+  // If we have more than one variadic results, we need more complicated logic
+  // to calculate the value range for each result.
+
+  if (numVariadicResults > 1 && !op.hasTrait("SameVariadicResultSize")) {
+    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
+                                 "specification over their sizes");
+  }
+
+  auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
+                              "unsigned index");
+
+  if (numVariadicResults == 0) {
+    m.body() << "return {std::next(getOperation()->result_begin(), index), "
+                "std::next(getOperation()->result_begin(), index + 1)};";
+  } else {
+    llvm::SmallVector<StringRef, 4> isVariadic;
+    isVariadic.reserve(numResults);
+    for (int i = 0; i < numResults; ++i) {
+      isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic()));
+    }
+    std::string isVariadicList = llvm::join(isVariadic, ", ");
+
+    m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalResults,
+                        numVariadicResults, "getOperation()->getNumResults()",
+                        "getOperation()->result_begin()", "result");
+  }
+
+  for (int i = 0; i != numResults; ++i) {
+    const auto &result = op.getResult(i);
+    if (result.name.empty())
+      continue;
+
+    if (result.isVariadic()) {
+      auto &m = opClass.newMethod("Operation::result_range", result.name);
+      m.body() << "return getODSResults(" << i << ");";
+    } else {
+      auto &m = opClass.newMethod("Value *", result.name);
+      m.body() << "return *getODSResults(" << i << ").begin();";
+    }
+  }
+}
+
+void OpEmitter::genNamedRegionGetters() {
+  unsigned numRegions = op.getNumRegions();
+  for (unsigned i = 0; i < numRegions; ++i) {
+    const auto &region = op.getRegion(i);
+    if (!region.name.empty()) {
+      auto &m = opClass.newMethod("Region &", region.name);
+      m.body() << formatv("return this->getOperation()->getRegion({0});", i);
+    }
+  }
+}
+
+void OpEmitter::genSeparateParamBuilder() {
+  std::string paramList;
+  llvm::SmallVector<std::string, 4> resultNames;
+  buildParamList(paramList, resultNames, /*includeResultTypes=*/true);
+
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+  genCodeForAddingArgAndRegionForBuilder(m.body());
+
+  // Push all result types to the operation state
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    m.body() << "  " << builderOpState << "->addTypes(" << resultNames[i]
+             << ");\n";
+  }
+}
+
+void OpEmitter::genUseOperandAsResultTypeBuilder() {
+  std::string paramList;
+  llvm::SmallVector<std::string, 4> resultNames;
+  buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
+
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+  genCodeForAddingArgAndRegionForBuilder(m.body());
+
+  auto numResults = op.getNumResults();
+  if (numResults == 0)
+    return;
+
+  // Push all result types to the operation state
+  const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
+  std::string resultType =
+      formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
+  m.body() << "  " << builderOpState << "->addTypes({" << resultType;
+  for (int i = 1; i != numResults; ++i)
+    m.body() << ", " << resultType;
+  m.body() << "});\n\n";
+}
+
+void OpEmitter::genUseAttrAsResultTypeBuilder() {
+  std::string paramList;
+  llvm::SmallVector<std::string, 4> resultNames;
+  buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
+
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+  genCodeForAddingArgAndRegionForBuilder(m.body());
+
+  auto numResults = op.getNumResults();
+  if (numResults == 0)
+    return;
+
+  // Push all result types to the operation state
+  std::string resultType;
+  const auto &namedAttr = op.getAttribute(0);
+  if (namedAttr.attr.isTypeAttr()) {
+    resultType = formatv("{0}.getValue()", namedAttr.name);
+  } else {
+    resultType = formatv("{0}.getType()", namedAttr.name);
+  }
+  m.body() << "  " << builderOpState << "->addTypes({" << resultType;
+  for (int i = 1; i != numResults; ++i)
+    m.body() << ", " << resultType;
+  m.body() << "});\n\n";
+}
+
+void OpEmitter::genBuilder() {
+  // Handle custom builders if provided.
+  // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native
+  // TableGen API calls here.
+  {
+    auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
+    if (listInit) {
+      for (Init *init : listInit->getValues()) {
+        Record *builderDef = cast<DefInit>(init)->getDef();
+        StringRef params = builderDef->getValueAsString("params");
+        StringRef body = builderDef->getValueAsString("body");
+        bool hasBody = !body.empty();
+
+        auto &method =
+            opClass.newMethod("void", "build", params, OpMethod::MP_Static,
+                              /*declOnly=*/!hasBody);
+        if (hasBody)
+          method.body() << body;
+      }
+    }
+    if (op.skipDefaultBuilders()) {
+      if (!listInit || listInit->empty())
+        PrintFatalError(
+            op.getLoc(),
+            "default builders are skipped and no custom builders provided");
+      return;
+    }
+  }
+
+  // Generate default builders that requires all result type, operands, and
+  // attributes as parameters.
+
+  // We generate three builders here:
+  // 1. one having a stand-alone parameter for each result type / operand /
+  //    attribute, and
+  genSeparateParamBuilder();
+  // 2. one having an aggregated parameter for all result types / operands /
+  //    attributes, and
+  genCollectiveParamBuilder();
+  // 3. one having a stand-alone prameter for each operand and attribute,
+  //    use the first operand or attribute's type as all result types
+  // to facilitate different call patterns.
+  if (op.getNumVariadicResults() == 0) {
+    if (op.hasTrait("SameOperandsAndResultType"))
+      genUseOperandAsResultTypeBuilder();
+    if (op.hasTrait("FirstAttrDerivedResultType"))
+      genUseAttrAsResultTypeBuilder();
+  }
+}
+
+void OpEmitter::genCollectiveParamBuilder() {
+  int numResults = op.getNumResults();
+  int numVariadicResults = op.getNumVariadicResults();
+  int numNonVariadicResults = numResults - numVariadicResults;
+
+  int numOperands = op.getNumOperands();
+  int numVariadicOperands = op.getNumVariadicOperands();
+  int numNonVariadicOperands = numOperands - numVariadicOperands;
+  // Signature
+  std::string params =
+      std::string("Builder *, OperationState *") + builderOpState +
+      ", ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, "
+      "ArrayRef<NamedAttribute> attributes";
+  auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
+  auto &body = m.body();
+
+  // Result types
+  if (numVariadicResults == 0 || numNonVariadicResults != 0)
+    body << "  assert(resultTypes.size()"
+         << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
+         << "u && \"mismatched number of return types\");\n";
+  body << "  " << builderOpState << "->addTypes(resultTypes);\n";
+
+  // Operands
+  if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
+    body << "  assert(operands.size()"
+         << (numVariadicOperands != 0 ? " >= " : " == ")
+         << numNonVariadicOperands
+         << "u && \"mismatched number of parameters\");\n";
+  body << "  " << builderOpState << "->addOperands(operands);\n\n";
+
+  // Attributes
+  body << "  for (const auto& pair : attributes)\n"
+       << "    " << builderOpState
+       << "->addAttribute(pair.first, pair.second);\n";
+
+  // Create the correct number of regions
+  if (int numRegions = op.getNumRegions()) {
+    for (int i = 0; i < numRegions; ++i)
+      m.body() << "  (void)" << builderOpState << "->addRegion();\n";
+  }
+}
+
+void OpEmitter::buildParamList(std::string &paramList,
+                               SmallVectorImpl<std::string> &resultTypeNames,
+                               bool includeResultTypes) {
+
+  paramList = "Builder *, OperationState *";
+  paramList.append(builderOpState);
+
+  if (includeResultTypes) {
+    resultTypeNames.clear();
+    auto numResults = op.getNumResults();
+    resultTypeNames.reserve(numResults);
+
+    // Add parameters for all return types
+    for (int i = 0; i < numResults; ++i) {
+      const auto &result = op.getResult(i);
+      std::string resultName = result.name;
+      if (resultName.empty())
+        resultName = formatv("resultType{0}", i);
+
+      paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
+      paramList.append(resultName);
+
+      resultTypeNames.emplace_back(std::move(resultName));
+    }
+  }
+
+  int numOperands = 0;
+  int numAttrs = 0;
+
+  // Add parameters for all arguments (operands and attributes).
+  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+    auto argument = op.getArg(i);
+    if (argument.is<tblgen::NamedTypeConstraint *>()) {
+      const auto &operand = op.getOperand(numOperands);
+      paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
+                                            : ", Value *");
+      paramList.append(getArgumentName(op, numOperands));
+      ++numOperands;
+    } else {
+      // TODO(antiagainst): Support default initializer for attributes
+      const auto &namedAttr = op.getAttribute(numAttrs);
+      const auto &attr = namedAttr.attr;
+      paramList.append(", ");
+      if (attr.isOptional())
+        paramList.append("/*optional*/");
+      paramList.append(attr.getStorageType());
+      paramList.append(" ");
+      paramList.append(namedAttr.name);
+      ++numAttrs;
+    }
+  }
+
+  if (numOperands + numAttrs != op.getNumArgs())
+    PrintFatalError("op arguments must be either operands or attributes");
+}
+
+void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
+  // Push all operands to the result
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+    body << "  " << builderOpState << "->addOperands(" << getArgumentName(op, i)
+         << ");\n";
+  }
+
+  // Push all attributes to the result
+  for (const auto &namedAttr : op.getAttributes()) {
+    if (!namedAttr.attr.isDerivedAttr()) {
+      bool emitNotNullCheck = namedAttr.attr.isOptional();
+      if (emitNotNullCheck) {
+        body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
+      }
+      body << formatv("  {0}->addAttribute(\"{1}\", {1});\n", builderOpState,
+                      namedAttr.name);
+      if (emitNotNullCheck) {
+        body << "  }\n";
+      }
+    }
+  }
+
+  // Create the correct number of regions
+  if (int numRegions = op.getNumRegions()) {
+    for (int i = 0; i < numRegions; ++i)
+      body << "  (void)" << builderOpState << "->addRegion();\n";
+  }
+}
+
+void OpEmitter::genCanonicalizerDecls() {
+  if (!def.getValueAsBit("hasCanonicalizer"))
+    return;
+
+  const char *const params =
+      "OwningRewritePatternList &results, MLIRContext *context";
+  opClass.newMethod("void", "getCanonicalizationPatterns", params,
+                    OpMethod::MP_Static, /*declOnly=*/true);
+}
+
+void OpEmitter::genFolderDecls() {
+  bool hasSingleResult = op.getNumResults() == 1;
+
+  if (def.getValueAsBit("hasFolder")) {
+    if (hasSingleResult) {
+      const char *const params = "ArrayRef<Attribute> operands";
+      opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
+                        /*declOnly=*/true);
+    } else {
+      const char *const params = "ArrayRef<Attribute> operands, "
+                                 "SmallVectorImpl<OpFoldResult> &results";
+      opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
+                        /*declOnly=*/true);
+    }
+  }
+}
+
+void OpEmitter::genParser() {
+  if (!hasStringAttribute(def, "parser"))
+    return;
+
+  auto &method = opClass.newMethod(
+      "ParseResult", "parse", "OpAsmParser *parser, OperationState *result",
+      OpMethod::MP_Static);
+  FmtContext fctx;
+  fctx.addSubst("cppClass", opClass.getClassName());
+  auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
+  method.body() << "  " << tgfmt(parser, &fctx);
+}
+
+void OpEmitter::genPrinter() {
+  auto valueInit = def.getValueInit("printer");
+  CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
+  if (!codeInit)
+    return;
+
+  auto &method = opClass.newMethod("void", "print", "OpAsmPrinter *p");
+  FmtContext fctx;
+  fctx.addSubst("cppClass", opClass.getClassName());
+  auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
+  method.body() << "  " << tgfmt(printer, &fctx);
+}
+
+void OpEmitter::genVerifier() {
+  auto valueInit = def.getValueInit("verifier");
+  CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
+  bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
+
+  auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
+  auto &body = method.body();
+
+  // Populate substitutions for attributes and named operands and results.
+  for (const auto &namedAttr : op.getAttributes())
+    verifyCtx.addSubst(namedAttr.name,
+                       formatv("this->getAttr(\"{0}\")", namedAttr.name));
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+    auto &value = op.getOperand(i);
+    // Skip from from first variadic operands for now. Else getOperand index
+    // used below doesn't match.
+    if (value.isVariadic())
+      break;
+    if (!value.name.empty())
+      verifyCtx.addSubst(
+          value.name, formatv("(*this->getOperation()->getOperand({0}))", i));
+  }
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    auto &value = op.getResult(i);
+    // Skip from from first variadic results for now. Else getResult index used
+    // below doesn't match.
+    if (value.isVariadic())
+      break;
+    if (!value.name.empty())
+      verifyCtx.addSubst(value.name,
+                         formatv("(*this->getOperation()->getResult({0}))", i));
+  }
+
+  // Verify the attributes have the correct type.
+  for (const auto &namedAttr : op.getAttributes()) {
+    const auto &attr = namedAttr.attr;
+    if (attr.isDerivedAttr())
+      continue;
+
+    auto attrName = namedAttr.name;
+    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
+    auto varName = tblgenNamePrefix + attrName;
+    body << formatv("  auto {0} = this->getAttr(\"{1}\");\n", varName,
+                    attrName);
+
+    bool allowMissingAttr =
+        attr.hasDefaultValueInitializer() || attr.isOptional();
+    if (allowMissingAttr) {
+      // If the attribute has a default value, then only verify the predicate if
+      // set. This does effectively assume that the default value is valid.
+      // TODO: verify the debug value is valid (perhaps in debug mode only).
+      body << "  if (" << varName << ") {\n";
+    } else {
+      body << "  if (!" << varName
+           << ") return emitOpError(\"requires attribute '" << attrName
+           << "'\");\n  {\n";
+    }
+
+    auto attrPred = attr.getPredicate();
+    if (!attrPred.isNull()) {
+      body << tgfmt(
+          "    if (!($0)) return emitOpError(\"attribute '$1' "
+          "failed to satisfy constraint: $2\");\n",
+          /*ctx=*/nullptr,
+          tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
+          attrName, attr.getDescription());
+    }
+
+    body << "  }\n";
+  }
+
+  genOperandResultVerifier(body, op.getOperands(), "operand");
+  genOperandResultVerifier(body, op.getResults(), "result");
+
+  for (auto &trait : op.getTraits()) {
+    if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
+      body << tgfmt("  if (!($0)) {\n    "
+                    "return emitOpError(\"failed to verify that $1\");\n  }\n",
+                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
+                    t->getDescription());
+    }
+  }
+
+  genRegionVerifier(body);
+
+  if (hasCustomVerify)
+    body << codeInit->getValue() << "\n";
+  else
+    body << "  return mlir::success();\n";
+}
+
+void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
+                                         Operator::value_range values,
+                                         StringRef valueKind) {
+  FmtContext fctx;
+
+  body << "  {\n";
+  body << "    unsigned index = 0; (void)index;\n";
+
+  for (auto staticValue : llvm::enumerate(values)) {
+    if (!staticValue.value().hasPredicate())
+      continue;
+
+    // Emit a loop to check all the dynamic values in the pack.
+    body << formatv("    for (Value *v : getODS{0}{1}s({2})) {{\n",
+                    // Capitalize the first letter to match the function name
+                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
+                    staticValue.index());
+
+    auto constraint = staticValue.value().constraint;
+
+    body << "      (void)v;\n"
+         << "      if (!("
+         << tgfmt(constraint.getConditionTemplate(),
+                  &fctx.withSelf("v->getType()"))
+         << ")) {\n"
+         << formatv("        return emitOpError(\"{0} #\") << index "
+                    "<< \" must be {1}\";\n",
+                    valueKind, constraint.getDescription())
+         << "      }\n" // if
+         << "      ++index;\n"
+         << "    }\n"; // for
+  }
+
+  body << "  }\n";
+}
+
+void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+  unsigned numRegions = op.getNumRegions();
+
+  // Verify this op has the correct number of regions
+  body << formatv(
+      "  if (this->getOperation()->getNumRegions() != {0}) {\n    "
+      "return emitOpError(\"has incorrect number of regions: expected {0} but "
+      "found \") << this->getOperation()->getNumRegions();\n  }\n",
+      numRegions);
+
+  for (unsigned i = 0; i < numRegions; ++i) {
+    const auto &region = op.getRegion(i);
+
+    std::string name = formatv("#{0}", i);
+    if (!region.name.empty()) {
+      name += formatv(" ('{0}')", region.name);
+    }
+
+    auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
+    auto constraint = tgfmt(region.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf(getRegion))
+                          .str();
+
+    body << formatv("  if (!({0})) {\n    "
+                    "return emitOpError(\"region {1} failed to verify "
+                    "constraint: {2}\");\n  }\n",
+                    constraint, name, region.constraint.getDescription());
+  }
+}
+
+void OpEmitter::genTraits() {
+  int numResults = op.getNumResults();
+  int numVariadicResults = op.getNumVariadicResults();
+
+  // Add return size trait.
+  if (numVariadicResults != 0) {
+    if (numResults == numVariadicResults)
+      opClass.addTrait("VariadicResults");
+    else
+      opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
+  } else {
+    switch (numResults) {
+    case 0:
+      opClass.addTrait("ZeroResult");
+      break;
+    case 1:
+      opClass.addTrait("OneResult");
+      break;
+    default:
+      opClass.addTrait("NResults<" + Twine(numResults) + ">::Impl");
+      break;
+    }
+  }
+
+  for (const auto &trait : op.getTraits()) {
+    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
+      opClass.addTrait(opTrait->getTrait());
+  }
+
+  // Add variadic size trait and normal op traits.
+  int numOperands = op.getNumOperands();
+  int numVariadicOperands = op.getNumVariadicOperands();
+
+  // Add operand size trait.
+  if (numVariadicOperands != 0) {
+    if (numOperands == numVariadicOperands)
+      opClass.addTrait("VariadicOperands");
+    else
+      opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +
+                       ">::Impl");
+  } else {
+    switch (numOperands) {
+    case 0:
+      opClass.addTrait("ZeroOperands");
+      break;
+    case 1:
+      opClass.addTrait("OneOperand");
+      break;
+    default:
+      opClass.addTrait("NOperands<" + Twine(numOperands) + ">::Impl");
+      break;
+    }
+  }
+}
+
+void OpEmitter::genOpNameGetter() {
+  auto &method = opClass.newMethod("StringRef", "getOperationName",
+                                   /*params=*/"", OpMethod::MP_Static);
+  method.body() << "  return \"" << op.getOperationName() << "\";\n";
+}
+
+//===----------------------------------------------------------------------===//
+// OpOperandAdaptor emitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Helper class to emit Op operand adaptors to an output stream.  Operand
+// adaptors are wrappers around ArrayRef<Value *> that provide named operand
+// getters identical to those defined in the Op.
+class OpOperandAdaptorEmitter {
+public:
+  static void emitDecl(const Operator &op, raw_ostream &os);
+  static void emitDef(const Operator &op, raw_ostream &os);
+
+private:
+  explicit OpOperandAdaptorEmitter(const Operator &op);
+
+  Class adapterClass;
+};
+} // end namespace
+
+OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
+    : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
+  adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
+  auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
+  constructor.body() << "  tblgen_operands = values;\n";
+
+  generateNamedOperandGetters(op, adapterClass,
+                              /*rangeType=*/"ArrayRef<Value *>",
+                              /*rangeBeginCall=*/"tblgen_operands.begin()",
+                              /*rangeSizeCall=*/"tblgen_operands.size()",
+                              /*getOperandCallPattern=*/"tblgen_operands[{0}]");
+}
+
+void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
+}
+
+void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
+}
+
+// Emits the opcode enum and op classes.
+static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
+                          bool emitDecl) {
+  IfDefScope scope("GET_OP_CLASSES", os);
+  // First emit forward declaration for each class, this allows them to refer
+  // to each others in traits for example.
+  if (emitDecl) {
+    for (auto *def : defs) {
+      Operator op(*def);
+      os << "class " << op.getCppClassName() << ";\n";
+    }
+  }
+  for (auto *def : defs) {
+    Operator op(*def);
+    if (emitDecl) {
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
+      OpOperandAdaptorEmitter::emitDecl(op, os);
+      OpEmitter::emitDecl(op, os);
+    } else {
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
+      OpOperandAdaptorEmitter::emitDef(op, os);
+      OpEmitter::emitDef(op, os);
+    }
+  }
+}
+
+// Emits a comma-separated list of the ops.
+static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
+  IfDefScope scope("GET_OP_LIST", os);
+
+  interleave(
+      // TODO: We are constructing the Operator wrapper instance just for
+      // getting it's qualified class name here. Reduce the overhead by having a
+      // lightweight version of Operator class just for that purpose.
+      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
+      [&os]() { os << ",\n"; });
+}
+
+static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  emitSourceFileHeader("Op Declarations", os);
+
+  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  emitOpClasses(defs, os, /*emitDecl=*/true);
+
+  return false;
+}
+
+static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  emitSourceFileHeader("Op Definitions", os);
+
+  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  emitOpList(defs, os);
+  emitOpClasses(defs, os, /*emitDecl=*/false);
+
+  return false;
+}
+
+static mlir::GenRegistration
+    genOpDecls("gen-op-decls", "Generate op declarations",
+               [](const RecordKeeper &records, raw_ostream &os) {
+                 return emitOpDecls(records, os);
+               });
+
+static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
+                                       [](const RecordKeeper &records,
+                                          raw_ostream &os) {
+                                         return emitOpDefs(records, os);
+                                       });
diff --git a/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp
new file mode 100644
index 0000000..0a16c31
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -0,0 +1,146 @@
+//===- OpDocGen.cpp - MLIR operation documentation generator --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpDocGen uses the description of operations to generate documentation for the
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+
+using mlir::tblgen::Operator;
+
+// Emit the description by aligning the text to the left per line (e.g.,
+// removing the minimum indentation across the block).
+//
+// This expects that the description in the tablegen file is already formatted
+// in a way the user wanted but has some additional indenting due to being
+// nested in the op definition.
+static void emitDescription(StringRef description, raw_ostream &os) {
+  // Determine the minimum number of spaces in a line.
+  size_t min_indent = -1;
+  StringRef remaining = description;
+  while (!remaining.empty()) {
+    auto split = remaining.split('\n');
+    size_t indent = split.first.find_first_not_of(" \t");
+    if (indent != StringRef::npos)
+      min_indent = std::min(indent, min_indent);
+    remaining = split.second;
+  }
+
+  // Print out the description indented.
+  os << "\n";
+  remaining = description;
+  bool printed = false;
+  while (!remaining.empty()) {
+    auto split = remaining.split('\n');
+    if (split.second.empty()) {
+      // Skip last line with just spaces.
+      if (split.first.ltrim().empty())
+        break;
+    }
+    // Print empty new line without spaces if line only has spaces, unless no
+    // text has been emitted before.
+    if (split.first.ltrim().empty()) {
+      if (printed)
+        os << "\n";
+    } else {
+      os << split.first.substr(min_indent) << "\n";
+      printed = true;
+    }
+    remaining = split.second;
+  }
+}
+
+static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
+
+  // TODO: Group by dialect.
+  // TODO: Add docs for types used (maybe dialect specific ones?) and link
+  // between use and def.
+  os << "# Operation definition\n";
+  for (auto *def : defs) {
+    Operator op(def);
+    os << "## " << op.getOperationName() << " (" << op.getQualCppClassName()
+       << ")";
+
+    // Emit summary & description of operator.
+    if (op.hasSummary())
+      os << "\n" << op.getSummary() << "\n";
+    os << "\n### Description:\n";
+    if (op.hasDescription())
+      emitDescription(op.getDescription(), os);
+
+    // Emit operands & type of operand. All operands are numbered, some may be
+    // named too.
+    os << "\n### Operands:\n";
+    for (const auto &operand : op.getOperands()) {
+      os << "1. ";
+      if (!operand.name.empty())
+        os << "`" << operand.name << "`: ";
+      else
+        os << "&laquo;unnamed&raquo;: ";
+      os << operand.constraint.getDescription() << "\n";
+    }
+
+    // Emit attributes.
+    // TODO: Attributes are only documented by TableGen name, with no further
+    // info. This should be improved.
+    os << "\n### Attributes:\n";
+    if (op.getNumAttributes() > 0) {
+      os << "| Attribute | MLIR Type | Description |\n"
+         << "| :-------: | :-------: | ----------- |\n";
+    }
+    for (auto namedAttr : op.getAttributes()) {
+      os << "| `" << namedAttr.name << "` | `"
+         << namedAttr.attr.getStorageType() << "` | "
+         << namedAttr.attr.getDescription() << " attribute |\n";
+    }
+
+    // Emit results.
+    os << "\n### Results:\n";
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+      os << "1. ";
+      auto name = op.getResultName(i);
+      if (name.empty())
+        os << "&laquo;unnamed&raquo;: ";
+      else
+        os << "`" << name << "`: ";
+      os << op.getResultTypeConstraint(i).getDescription() << "\n";
+    }
+
+    os << "\n";
+  }
+}
+
+static mlir::GenRegistration
+    genRegister("gen-op-doc", "Generate operation documentation",
+                [](const RecordKeeper &records, raw_ostream &os) {
+                  emitOpDoc(records, os);
+                  return false;
+                });
diff --git a/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp b/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp
new file mode 100644
index 0000000..3e6893a
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp
@@ -0,0 +1,94 @@
+//===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// ReferenceImplGen uses the description of operations to generate reference
+// implementations for the ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+
+using mlir::tblgen::Operator;
+
+static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
+                                         raw_ostream &os) {
+  emitSourceFileHeader("Reference implementation file", os);
+  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+
+  os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n"
+     << "  using namespace ::mlir::edsc;\n"
+     << "if (false) {}";
+  for (auto *def : defs) {
+    Operator op(def);
+    auto referenceImplGenerator = def->getValueInit("referenceImplementation");
+    if (!referenceImplGenerator)
+      continue;
+    os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
+       << "  edsc::ScopedContext scope(f);\n";
+
+    for (auto en : llvm::enumerate(op.getOperands())) {
+      os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
+                              "(void)arg_{0};\n",
+                              en.value().name, en.index());
+      // TODO(jpienaar): this is generally incorrect, not all args are memref
+      // in the general case.
+      os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
+                              "(void)view_{0};\n",
+                              en.value().name, en.index());
+    }
+    unsigned numOperands = op.getNumOperands();
+    unsigned numResults = op.getNumResults();
+    for (unsigned idx = 0; idx < numResults; ++idx) {
+      os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
+                              "(void)arg_{0};\n",
+                              op.getResult(idx).name, numOperands + idx);
+      // TODO(jpienaar): this is generally incorrect, not all args are memref
+      // in the general case.
+      os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
+                              "(void)view_{0};\n",
+                              op.getResult(idx).name, numOperands + idx);
+    }
+
+    // Print the EDSC.
+    os << referenceImplGenerator->getAsUnquotedString() << "\n";
+    os.indent(2) << "f->print(llvm::outs());\n\n";
+    os << "}";
+  }
+  os << " else {\n";
+  os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
+  os.indent(2) << "return;\n";
+  os << "}\n";
+  os << "}\n";
+}
+
+static mlir::GenRegistration
+    genRegister("gen-reference-implementations",
+                "Generate reference implemenations",
+                [](const RecordKeeper &records, raw_ostream &os) {
+                  emitReferenceImplementations(records, os);
+                  return false;
+                });
diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
new file mode 100644
index 0000000..3487eda
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -0,0 +1,784 @@
+//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Pattern.h"
+#include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Type.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatAdapters.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Main.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+namespace llvm {
+template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
+  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
+                     raw_ostream &os, StringRef style) {
+    os << v.first << ":" << v.second;
+  }
+};
+} // end namespace llvm
+
+//===----------------------------------------------------------------------===//
+// PatternEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class PatternEmitter {
+public:
+  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
+
+  // Emits the mlir::RewritePattern struct named `rewriteName`.
+  void emit(StringRef rewriteName);
+
+private:
+  // Emits the code for matching ops.
+  void emitMatchLogic(DagNode tree);
+
+  // Emits the code for rewriting ops.
+  void emitRewriteLogic();
+
+  //===--------------------------------------------------------------------===//
+  // Match utilities
+  //===--------------------------------------------------------------------===//
+
+  // Emits C++ statements for matching the op constrained by the given DAG
+  // `tree`.
+  void emitOpMatch(DagNode tree, int depth);
+
+  // Emits C++ statements for matching the `index`-th argument of the given DAG
+  // `tree` as an operand.
+  void emitOperandMatch(DagNode tree, int index, int depth, int indent);
+
+  // Emits C++ statements for matching the `index`-th argument of the given DAG
+  // `tree` as an attribute.
+  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
+
+  //===--------------------------------------------------------------------===//
+  // Rewrite utilities
+  //===--------------------------------------------------------------------===//
+
+  // Entry point for handling a result pattern rooted at `resultTree` and
+  // dispatches to concrete handlers. The given tree is the `resultIndex`-th
+  // argument of the enclosing DAG.
+  std::string handleResultPattern(DagNode resultTree, int resultIndex,
+                                  int depth);
+
+  // Emits the C++ statement to replace the matched DAG with a value built via
+  // calling native C++ code.
+  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
+
+  // Returns the C++ expression referencing the old value serving as the
+  // replacement.
+  std::string handleReplaceWithValue(DagNode tree);
+
+  // Emits the C++ statement to build a new op out of the given DAG `tree` and
+  // returns the variable name that this op is assigned to. If the root op in
+  // DAG `tree` has a specified name, the created op will be assigned to a
+  // variable of the given name. Otherwise, a unique name will be used as the
+  // result value name.
+  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
+
+  // Returns the C++ expression to construct a constant attribute of the given
+  // `value` for the given attribute kind `attr`.
+  std::string handleConstantAttr(Attribute attr, StringRef value);
+
+  // Returns the C++ expression to build an argument from the given DAG `leaf`.
+  // `patArgName` is used to bound the argument to the source pattern.
+  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
+
+  //===--------------------------------------------------------------------===//
+  // General utilities
+  //===--------------------------------------------------------------------===//
+
+  // Collects all of the operations within the given dag tree.
+  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+
+  // Returns a unique symbol for a local variable of the given `op`.
+  std::string getUniqueSymbol(const Operator *op);
+
+  //===--------------------------------------------------------------------===//
+  // Symbol utilities
+  //===--------------------------------------------------------------------===//
+
+  // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
+  std::string resolveSymbol(StringRef symbol);
+
+  // Returns how many static values the given DAG `node` correspond to.
+  int getNodeValueCount(DagNode node);
+
+private:
+  // Pattern instantiation location followed by the location of multiclass
+  // prototypes used. This is intended to be used as a whole to
+  // PrintFatalError() on errors.
+  ArrayRef<llvm::SMLoc> loc;
+
+  // Op's TableGen Record to wrapper object.
+  RecordOperatorMap *opMap;
+
+  // Handy wrapper for pattern being emitted.
+  Pattern pattern;
+
+  // Map for all bound symbols' info.
+  SymbolInfoMap symbolInfoMap;
+
+  // The next unused ID for newly created values.
+  unsigned nextValueId;
+
+  raw_ostream &os;
+
+  // Format contexts containing placeholder substitutations.
+  FmtContext fmtCtx;
+
+  // Number of op processed.
+  int opCounter = 0;
+};
+} // end anonymous namespace
+
+PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
+                               raw_ostream &os)
+    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
+      symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
+  fmtCtx.withBuilder("rewriter");
+}
+
+std::string PatternEmitter::handleConstantAttr(Attribute attr,
+                                               StringRef value) {
+  if (!attr.isConstBuildable())
+    PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
+                             " does not have the 'constBuilderCall' field");
+
+  // TODO(jpienaar): Verify the constants here
+  return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
+  Operator &op = tree.getDialectOp(opMap);
+  if (op.isVariadic()) {
+    PrintFatalError(loc, formatv("matching op '{0}' with variadic "
+                                 "operands/results is unsupported right now",
+                                 op.getOperationName()));
+  }
+
+  int indent = 4 + 2 * depth;
+  os.indent(indent) << formatv(
+      "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
+      depth, op.getQualCppClassName());
+  // Skip the operand matching at depth 0 as the pattern rewriter already does.
+  if (depth != 0) {
+    // Skip if there is no defining operation (e.g., arguments to function).
+    os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
+                                 depth);
+  }
+  if (tree.getNumArgs() != op.getNumArgs()) {
+    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
+                                 "pattern vs. {2} in definition",
+                                 op.getOperationName(), tree.getNumArgs(),
+                                 op.getNumArgs()));
+  }
+
+  // If the operand's name is set, set to that variable.
+  auto name = tree.getSymbol();
+  if (!name.empty())
+    os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
+
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    auto opArg = op.getArg(i);
+
+    // Handle nested DAG construct first
+    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+      os.indent(indent) << "{\n";
+      os.indent(indent + 2)
+          << formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
+                     depth + 1, depth, i);
+      emitOpMatch(argTree, depth + 1);
+      os.indent(indent + 2)
+          << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+      os.indent(indent) << "}\n";
+      continue;
+    }
+
+    // Next handle DAG leaf: operand or attribute
+    if (opArg.is<NamedTypeConstraint *>()) {
+      emitOperandMatch(tree, i, depth, indent);
+    } else if (opArg.is<NamedAttribute *>()) {
+      emitAttributeMatch(tree, i, depth, indent);
+    } else {
+      PrintFatalError(loc, "unhandled case when matching op");
+    }
+  }
+}
+
+void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
+                                      int indent) {
+  Operator &op = tree.getDialectOp(opMap);
+  auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
+  auto matcher = tree.getArgAsLeaf(index);
+
+  // If a constraint is specified, we need to generate C++ statements to
+  // check the constraint.
+  if (!matcher.isUnspecified()) {
+    if (!matcher.isOperandMatcher()) {
+      PrintFatalError(
+          loc, formatv("the {1}-th argument of op '{0}' should be an operand",
+                       op.getOperationName(), index + 1));
+    }
+
+    // Only need to verify if the matcher's type is different from the one
+    // of op definition.
+    if (operand->constraint != matcher.getAsConstraint()) {
+      auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
+      os.indent(indent) << "if (!("
+                        << tgfmt(matcher.getConditionTemplate(),
+                                 &fmtCtx.withSelf(self))
+                        << ")) return matchFailure();\n";
+    }
+  }
+
+  // Capture the value
+  auto name = tree.getArgName(index);
+  if (!name.empty()) {
+    os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth,
+                                 index);
+  }
+}
+
+void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
+                                        int indent) {
+  Operator &op = tree.getDialectOp(opMap);
+  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
+  const auto &attr = namedAttr->attr;
+
+  os.indent(indent) << "{\n";
+  indent += 2;
+  os.indent(indent) << formatv(
+      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
+      attr.getStorageType(), namedAttr->name);
+
+  // TODO(antiagainst): This should use getter method to avoid duplication.
+  if (attr.hasDefaultValueInitializer()) {
+    os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
+                      << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
+                               attr.getDefaultValueInitializer())
+                      << ";\n";
+  } else if (attr.isOptional()) {
+    // For a missing attribute that is optional according to definition, we
+    // should just capature a mlir::Attribute() to signal the missing state.
+    // That is precisely what getAttr() returns on missing attributes.
+  } else {
+    os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
+  }
+
+  auto matcher = tree.getArgAsLeaf(index);
+  if (!matcher.isUnspecified()) {
+    if (!matcher.isAttrMatcher()) {
+      PrintFatalError(
+          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
+                       op.getOperationName(), index + 1));
+    }
+
+    // If a constraint is specified, we need to generate C++ statements to
+    // check the constraint.
+    os.indent(indent) << "if (!("
+                      << tgfmt(matcher.getConditionTemplate(),
+                               &fmtCtx.withSelf("tblgen_attr"))
+                      << ")) return matchFailure();\n";
+  }
+
+  // Capture the value
+  auto name = tree.getArgName(index);
+  if (!name.empty()) {
+    os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
+  }
+
+  indent -= 2;
+  os.indent(indent) << "}\n";
+}
+
+void PatternEmitter::emitMatchLogic(DagNode tree) {
+  emitOpMatch(tree, 0);
+
+  for (auto &appliedConstraint : pattern.getConstraints()) {
+    auto &constraint = appliedConstraint.constraint;
+    auto &entities = appliedConstraint.entities;
+
+    auto condition = constraint.getConditionTemplate();
+    auto cmd = "if (!({0})) return matchFailure();\n";
+
+    if (isa<TypeConstraint>(constraint)) {
+      auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
+      os.indent(4) << formatv(cmd,
+                              tgfmt(condition, &fmtCtx.withSelf(self.str())));
+    } else if (isa<AttrConstraint>(constraint)) {
+      PrintFatalError(
+          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
+    } else {
+      // TODO(b/138794486): replace formatv arguments with the exact specified
+      // args.
+      if (entities.size() > 4) {
+        PrintFatalError(loc, "only support up to 4-entity constraints now");
+      }
+      SmallVector<std::string, 4> names;
+      int i = 0;
+      for (int e = entities.size(); i < e; ++i)
+        names.push_back(resolveSymbol(entities[i]));
+      std::string self = appliedConstraint.self;
+      if (!self.empty())
+        self = resolveSymbol(self);
+      for (; i < 4; ++i)
+        names.push_back("<unused>");
+      os.indent(4) << formatv(cmd,
+                              tgfmt(condition, &fmtCtx.withSelf(self), names[0],
+                                    names[1], names[2], names[3]));
+    }
+  }
+}
+
+void PatternEmitter::collectOps(DagNode tree,
+                                llvm::SmallPtrSetImpl<const Operator *> &ops) {
+  // Check if this tree is an operation.
+  if (tree.isOperation())
+    ops.insert(&tree.getDialectOp(opMap));
+
+  // Recurse the arguments of the tree.
+  for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
+    if (auto child = tree.getArgAsNestedDag(i))
+      collectOps(child, ops);
+}
+
+void PatternEmitter::emit(StringRef rewriteName) {
+  // Get the DAG tree for the source pattern.
+  DagNode sourceTree = pattern.getSourcePattern();
+
+  const Operator &rootOp = pattern.getSourceRootOp();
+  auto rootName = rootOp.getOperationName();
+
+  // Collect the set of result operations.
+  llvm::SmallPtrSet<const Operator *, 4> resultOps;
+  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
+    collectOps(pattern.getResultPattern(i), resultOps);
+
+  // Emit RewritePattern for Pattern.
+  auto locs = pattern.getLocation();
+  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+                make_range(locs.rbegin(), locs.rend()));
+  os << formatv(R"(struct {0} : public RewritePattern {
+  {0}(MLIRContext *context)
+      : RewritePattern("{1}", {{)",
+                rewriteName, rootName);
+  interleaveComma(resultOps, os, [&](const Operator *op) {
+    os << '"' << op->getOperationName() << '"';
+  });
+  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
+
+  // Emit matchAndRewrite() function.
+  os << R"(
+  PatternMatchResult matchAndRewrite(Operation *op0,
+                                     PatternRewriter &rewriter) const override {
+)";
+
+  // Register all symbols bound in the source pattern.
+  pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
+  os.indent(4) << "// Variables for capturing values and attributes used for "
+                  "creating ops\n";
+  // Create local variables for storing the arguments and results bound
+  // to symbols.
+  for (const auto &symbolInfoPair : symbolInfoMap) {
+    StringRef symbol = symbolInfoPair.getKey();
+    auto &info = symbolInfoPair.getValue();
+    os.indent(4) << info.getVarDecl(symbol);
+  }
+  // TODO(jpienaar): capture ops with consistent numbering so that it can be
+  // reused for fused loc.
+  os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
+                          pattern.getSourcePattern().getNumOps());
+
+  os.indent(4) << "// Match\n";
+  os.indent(4) << "tblgen_ops[0] = op0;\n";
+  emitMatchLogic(sourceTree);
+  os << "\n";
+
+  os.indent(4) << "// Rewrite\n";
+  emitRewriteLogic();
+
+  os.indent(4) << "return matchSuccess();\n";
+  os << "  };\n";
+  os << "};\n";
+}
+
+void PatternEmitter::emitRewriteLogic() {
+  const Operator &rootOp = pattern.getSourceRootOp();
+  int numExpectedResults = rootOp.getNumResults();
+  int numResultPatterns = pattern.getNumResultPatterns();
+
+  // First register all symbols bound to ops generated in result patterns.
+  pattern.collectResultPatternBoundSymbols(symbolInfoMap);
+
+  // Only the last N static values generated are used to replace the matched
+  // root N-result op. We need to calculate the starting index (of the results
+  // of the matched op) each result pattern is to replace.
+  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
+  // If we don't need to replace any value at all, set the replacement starting
+  // index as the number of result patterns so we skip all of them when trying
+  // to replace the matched op's results.
+  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
+  for (int i = numResultPatterns - 1; i >= 0; --i) {
+    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
+    offsets[i] = offsets[i + 1] - numValues;
+    if (offsets[i] == 0) {
+      if (replStartIndex == -1)
+        replStartIndex = i;
+    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
+      auto error = formatv(
+          "cannot use the same multi-result op '{0}' to generate both "
+          "auxiliary values and values to be used for replacing the matched op",
+          pattern.getResultPattern(i).getSymbol());
+      PrintFatalError(loc, error);
+    }
+  }
+
+  if (offsets.front() > 0) {
+    const char error[] = "no enough values generated to replace the matched op";
+    PrintFatalError(loc, error);
+  }
+
+  os.indent(4) << "auto loc = rewriter.getFusedLoc({";
+  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
+    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
+  }
+  os << "}); (void)loc;\n";
+
+  // Collect the replacement value for each result
+  llvm::SmallVector<std::string, 2> resultValues;
+  for (int i = 0; i < numResultPatterns; ++i) {
+    DagNode resultTree = pattern.getResultPattern(i);
+    resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
+  }
+
+  // Emit the final replaceOp() statement
+  os.indent(4) << "rewriter.replaceOp(op0, {";
+  interleaveComma(
+      ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
+      [&](const std::string &symbol) { os << resolveSymbol(symbol); });
+  os << "});\n";
+}
+
+std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
+  return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
+}
+
+std::string PatternEmitter::handleResultPattern(DagNode resultTree,
+                                                int resultIndex, int depth) {
+  if (resultTree.isNativeCodeCall()) {
+    auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+    symbolInfoMap.bindValue(symbol);
+    return symbol;
+  }
+
+  if (resultTree.isReplaceWithValue()) {
+    return handleReplaceWithValue(resultTree);
+  }
+
+  // Normal op creation.
+  auto symbol = handleOpCreation(resultTree, resultIndex, depth);
+  if (resultTree.getSymbol().empty()) {
+    // This is an op not explicitly bound to a symbol in the rewrite rule.
+    // Register the auto-generated symbol for it.
+    symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
+  }
+  return symbol;
+}
+
+std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
+  assert(tree.isReplaceWithValue());
+
+  if (tree.getNumArgs() != 1) {
+    PrintFatalError(
+        loc, "replaceWithValue directive must take exactly one argument");
+  }
+
+  if (!tree.getSymbol().empty()) {
+    PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
+  }
+
+  return resolveSymbol(tree.getArgName(0));
+}
+
+std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
+  if (leaf.isConstantAttr()) {
+    auto constAttr = leaf.getAsConstantAttr();
+    return handleConstantAttr(constAttr.getAttribute(),
+                              constAttr.getConstantValue());
+  }
+  if (leaf.isEnumAttrCase()) {
+    auto enumCase = leaf.getAsEnumAttrCase();
+    if (enumCase.isStrCase())
+      return handleConstantAttr(enumCase, enumCase.getSymbol());
+    // This is an enum case backed by an IntegerAttr. We need to get its value
+    // to build the constant.
+    std::string val = std::to_string(enumCase.getValue());
+    return handleConstantAttr(enumCase, val);
+  }
+  if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
+    return argName;
+  }
+  if (leaf.isNativeCodeCall()) {
+    return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
+  }
+  PrintFatalError(loc, "unhandled case when rewriting op");
+}
+
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
+  auto fmt = tree.getNativeCodeTemplate();
+  // TODO(b/138794486): replace formatv arguments with the exact specified args.
+  SmallVector<std::string, 8> attrs(8);
+  if (tree.getNumArgs() > 8) {
+    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
+                             Twine(tree.getNumArgs()));
+  }
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+  }
+  return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
+               attrs[5], attrs[6], attrs[7]);
+}
+
+std::string PatternEmitter::resolveSymbol(StringRef symbol) {
+  auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
+  if (subst.empty()) {
+    PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
+  }
+  return subst;
+}
+
+int PatternEmitter::getNodeValueCount(DagNode node) {
+  if (node.isOperation()) {
+    // If the op is bound to a symbol in the rewrite rule, query its result
+    // count from the symbol info map.
+    auto symbol = node.getSymbol();
+    if (!symbol.empty()) {
+      return symbolInfoMap.getStaticValueCount(symbol);
+    }
+    // Otherwise this is an unbound op; we will use all its results.
+    return pattern.getDialectOp(node).getNumResults();
+  }
+  // TODO(antiagainst): This considers all NativeCodeCall as returning one
+  // value. Enhance if multi-value ones are needed.
+  return 1;
+}
+
+std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
+                                             int depth) {
+  Operator &resultOp = tree.getDialectOp(opMap);
+  auto numOpArgs = resultOp.getNumArgs();
+
+  if (resultOp.isVariadic()) {
+    PrintFatalError(loc, formatv("generating op '{0}' with variadic "
+                                 "operands/results is unsupported now",
+                                 resultOp.getOperationName()));
+  }
+
+  if (numOpArgs != tree.getNumArgs()) {
+    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
+                                 "{1} in pattern vs. {2} in definition",
+                                 resultOp.getOperationName(), tree.getNumArgs(),
+                                 numOpArgs));
+  }
+
+  // A map to collect all nested DAG child nodes' names, with operand index as
+  // the key. This includes both bound and unbound child nodes. Bound child
+  // nodes will additionally be tracked in `symbolResolver` so they can be
+  // referenced by other patterns. Unbound child nodes will only be used once
+  // to build this op.
+  llvm::DenseMap<unsigned, std::string> childNodeNames;
+
+  // First go through all the child nodes who are nested DAG constructs to
+  // create ops for them, so that we can use the results in the current node.
+  // This happens in a recursive manner.
+  for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
+    if (auto child = tree.getArgAsNestedDag(i)) {
+      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
+    }
+  }
+
+  // Use the specified name for this op if available. Generate one otherwise.
+  std::string resultValue = tree.getSymbol();
+  if (resultValue.empty())
+    resultValue = getUniqueSymbol(&resultOp);
+  // Strip the index to get the name for the value pack. This will be used to
+  // name the local variable for the op.
+  StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
+
+  // Then we build the new op corresponding to this DAG node.
+
+  // Right now we don't have general type inference in MLIR. Except a few
+  // special cases listed below, we need to supply types for all results
+  // when building an op.
+  bool isSameOperandsAndResultType =
+      resultOp.hasTrait("SameOperandsAndResultType");
+  bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
+  bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
+  bool usePartialResults = valuePackName != resultValue;
+
+  if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
+      usePartialResults || depth > 0 || resultIndex < 0) {
+    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
+                            valuePackName, resultOp.getQualCppClassName());
+  } else {
+    // If depth == 0 and resultIndex >= 0, it means we are replacing the values
+    // generated from the source pattern root op. Then we can use the source
+    // pattern's value types to determine the value type of the generated op
+    // here.
+
+    // We need to specify the types for all results.
+    SmallVector<std::string, 4> resultTypes;
+    int numResults = resultOp.getNumResults();
+    resultTypes.reserve(numResults);
+    for (int i = 0; i < numResults; ++i) {
+      resultTypes.push_back(
+          formatv("op0->getResult({0})->getType()", resultIndex + i));
+    }
+
+    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
+                            valuePackName, resultOp.getQualCppClassName())
+                 << (resultTypes.empty() ? "" : ", ")
+                 << llvm::join(resultTypes, ", ");
+  }
+
+  // Create the builder call for the result.
+  // Add operands.
+  int argIndex = 0;
+  for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
+    const auto &operand = resultOp.getOperand(argIndex);
+
+    // Start each operand on its own line.
+    (os << ",\n").indent(6);
+
+    if (!operand.name.empty())
+      os << "/*" << operand.name << "=*/";
+
+    if (tree.isNestedDagArg(argIndex)) {
+      os << childNodeNames[argIndex];
+    } else {
+      DagLeaf leaf = tree.getArgAsLeaf(argIndex);
+      auto symbol = resolveSymbol(tree.getArgName(argIndex));
+      if (leaf.isNativeCodeCall()) {
+        os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
+      } else {
+        os << symbol;
+      }
+    }
+    // TODO(jpienaar): verify types
+  }
+
+  // Add attributes.
+  for (; argIndex != numOpArgs; ++argIndex) {
+    // Start each attribute on its own line.
+    (os << ",\n").indent(6);
+    // The argument in the op definition.
+    auto opArgName = resultOp.getArgName(argIndex);
+    if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
+      if (!subTree.isNativeCodeCall())
+        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+                             "for creating attribute");
+      os << formatv("/*{0}=*/{1}", opArgName,
+                    handleReplaceWithNativeCodeCall(subTree));
+    } else {
+      auto leaf = tree.getArgAsLeaf(argIndex);
+      // The argument in the result DAG pattern.
+      auto patArgName = tree.getArgName(argIndex);
+      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
+        // TODO(jpienaar): Refactor out into map to avoid recomputing these.
+        auto argument = resultOp.getArg(argIndex);
+        if (!argument.is<NamedAttribute *>())
+          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
+        if (!patArgName.empty())
+          os << "/*" << patArgName << "=*/";
+      } else {
+        os << "/*" << opArgName << "=*/";
+      }
+      os << handleOpArgument(leaf, patArgName);
+    }
+  }
+  os << "\n    );\n";
+
+  return resultValue;
+}
+
+static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  emitSourceFileHeader("Rewriters", os);
+
+  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+  auto numPatterns = patterns.size();
+
+  // We put the map here because it can be shared among multiple patterns.
+  RecordOperatorMap recordOpMap;
+
+  std::vector<std::string> rewriterNames;
+  rewriterNames.reserve(numPatterns);
+
+  std::string baseRewriterName = "GeneratedConvert";
+  int rewriterIndex = 0;
+
+  for (Record *p : patterns) {
+    std::string name;
+    if (p->isAnonymous()) {
+      // If no name is provided, ensure unique rewriter names simply by
+      // appending unique suffix.
+      name = baseRewriterName + llvm::utostr(rewriterIndex++);
+    } else {
+      name = p->getName();
+    }
+    PatternEmitter(p, &recordOpMap, os).emit(name);
+    rewriterNames.push_back(std::move(name));
+  }
+
+  // Emit function to add the generated matchers to the pattern list.
+  os << "void populateWithGenerated(MLIRContext *context, "
+     << "OwningRewritePatternList *patterns) {\n";
+  for (const auto &name : rewriterNames) {
+    os << "  patterns->insert<" << name << ">(context);\n";
+  }
+  os << "}\n";
+}
+
+static mlir::GenRegistration
+    genRewriters("gen-rewriters", "Generate pattern rewriters",
+                 [](const RecordKeeper &records, raw_ostream &os) {
+                   emitRewriters(records, os);
+                   return false;
+                 });
diff --git a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
new file mode 100644
index 0000000..d948ec5
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -0,0 +1,465 @@
+//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// SPIRVSerializationGen generates common utility functions for SPIR-V
+// serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/StringExtras.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using llvm::ArrayRef;
+using llvm::formatv;
+using llvm::raw_ostream;
+using llvm::raw_string_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::SMLoc;
+using llvm::StringRef;
+using llvm::Twine;
+using mlir::tblgen::Attribute;
+using mlir::tblgen::EnumAttr;
+using mlir::tblgen::NamedAttribute;
+using mlir::tblgen::NamedTypeConstraint;
+using mlir::tblgen::Operator;
+
+// Writes the following function to `os`:
+//   inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
+static void emitGetOpcodeFunction(const Record *record, Operator const &op,
+                                  raw_ostream &os) {
+  os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
+                "getOpcode<{0}>()",
+                op.getQualCppClassName())
+     << " {\n  "
+     << formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
+                record->getValueAsString("spirvOpName"));
+}
+
+static void declareOpcodeFn(raw_ostream &os) {
+  os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
+        "getOpcode();\n";
+}
+
+static void emitAttributeSerialization(const Attribute &attr,
+                                       ArrayRef<SMLoc> loc, llvm::StringRef op,
+                                       llvm::StringRef operandList,
+                                       llvm::StringRef attrName,
+                                       raw_ostream &os) {
+  os << "    auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
+  os << "    if (attr) {\n";
+  if (attr.getAttrDefName() == "I32ArrayAttr") {
+    // Serialize all the elements of the array
+    os << "      for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
+    os << "        " << operandList
+       << ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
+          "getValue().getZExtValue()));\n";
+    os << "      }\n";
+  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+    os << "      " << operandList
+       << ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
+          ".getZExtValue()));\n";
+  } else {
+    PrintFatalError(
+        loc,
+        llvm::Twine(
+            "unhandled attribute type in SPIR-V serialization generation : '") +
+            attr.getAttrDefName() + llvm::Twine("'"));
+  }
+  os << "    }\n";
+}
+
+static void emitSerializationFunction(const Record *attrClass,
+                                      const Record *record, const Operator &op,
+                                      raw_ostream &os) {
+  // If the record has 'autogenSerialization' set to 0, nothing to do
+  if (!record->getValueAsBit("autogenSerialization")) {
+    return;
+  }
+  os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n"
+                "  {0} op)",
+                op.getQualCppClassName())
+     << " {\n";
+  os << "  SmallVector<uint32_t, 4> operands;\n";
+  os << "  SmallVector<StringRef, 2> elidedAttrs;\n";
+
+  // Serialize result information
+  if (op.getNumResults() == 1) {
+    os << "  uint32_t resultTypeID = 0;\n";
+    os << "  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
+          "{\n";
+    os << "    return failure();\n";
+    os << "  }\n";
+    os << "  operands.push_back(resultTypeID);\n";
+    // Create an SSA result <id> for the op
+    os << "  auto resultID = getNextID();\n";
+    os << "  valueIDMap[op.getResult()] = resultID;\n";
+    os << "  operands.push_back(resultID);\n";
+  } else if (op.getNumResults() != 0) {
+    PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
+  }
+
+  // Process arguments
+  auto operandNum = 0;
+  for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
+    auto argument = op.getArg(i);
+    os << "  {\n";
+    if (argument.is<NamedTypeConstraint *>()) {
+      os << "    for (auto arg : op.getODSOperands(" << operandNum << ")) {\n";
+      os << "      auto argID = findValueID(arg);\n";
+      os << "      if (!argID) {\n";
+      os << "        emitError(op.getLoc(), \"operand " << operandNum
+         << " has a use before def\");\n";
+      os << "      }\n";
+      os << "      operands.push_back(argID);\n";
+      os << "    }\n";
+      operandNum++;
+    } else {
+      auto attr = argument.get<NamedAttribute *>();
+      emitAttributeSerialization(
+          (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
+          record->getLoc(), "op", "operands", attr->name, os);
+      os << "    elidedAttrs.push_back(\"" << attr->name << "\");\n";
+    }
+    os << "  }\n";
+  }
+
+  os << formatv("  encodeInstructionInto("
+                "functions, spirv::getOpcode<{0}>(), operands);\n",
+                op.getQualCppClassName());
+
+  if (op.getNumResults() == 1) {
+    // All non-argument attributes translated into OpDecorate instruction
+    os << "  for (auto attr : op.getAttrs()) {\n";
+    os << "    if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
+          "attr.first.is(elided); })) {\n";
+    os << "      continue;\n";
+    os << "    }\n";
+    os << "    if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
+    os << "      return failure();";
+    os << "    }\n";
+    os << "  }\n";
+  }
+
+  os << "  return success();\n";
+  os << "}\n\n";
+}
+
+static void initDispatchSerializationFn(raw_ostream &os) {
+  os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
+        "*op) {\n ";
+}
+
+static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
+  os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
+  os << "    ";
+  os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
+                op.getQualCppClassName());
+  os << "  } else";
+}
+
+static void finalizeDispatchSerializationFn(raw_ostream &os) {
+  os << " {\n";
+  os << "    return op->emitError(\"unhandled operation serialization\");\n";
+  os << "  }\n";
+  os << "  return success();\n";
+  os << "}\n\n";
+}
+
+static void emitAttributeDeserialization(
+    const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
+    llvm::StringRef attrName, llvm::StringRef operandsList,
+    llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
+  if (attr.getAttrDefName() == "I32ArrayAttr") {
+    os << "    SmallVector<Attribute, 4> attrListElems;\n";
+    os << "    while (" << wordIndex << " < " << wordCount << ") {\n";
+    os << "      attrListElems.push_back(opBuilder.getI32IntegerAttr("
+       << operandsList << "[" << wordIndex << "++]));\n";
+    os << "    }\n";
+    os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
+       << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
+  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+    os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
+       << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
+       << wordIndex << "++])));\n";
+  } else {
+    PrintFatalError(
+        loc, llvm::Twine(
+                 "unhandled attribute type in deserialization generation : '") +
+                 attr.getAttrDefName() + llvm::Twine("'"));
+  }
+}
+
+static void emitDeserializationFunction(const Record *attrClass,
+                                        const Record *record,
+                                        const Operator &op, raw_ostream &os) {
+  // If the record has 'autogenSerialization' set to 0, nothing to do
+  if (!record->getValueAsBit("autogenSerialization")) {
+    return;
+  }
+  os << formatv("template <> "
+                "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
+                "uint32_t> words)",
+                op.getQualCppClassName());
+  os << " {\n";
+  os << "  SmallVector<Type, 1> resultTypes;\n";
+  os << "  size_t wordIndex = 0; (void)wordIndex;\n";
+
+  // Deserialize result information if it exists
+  bool hasResult = false;
+  if (op.getNumResults() == 1) {
+    os << "  {\n";
+    os << "    if (wordIndex >= words.size()) {\n";
+    os << "      "
+       << formatv("return emitError(unknownLoc, \"expected result type <id> "
+                  "while deserializing {0}\");\n",
+                  op.getQualCppClassName());
+    os << "    }\n";
+    os << "    auto ty = getType(words[wordIndex]);\n";
+    os << "    if (!ty) {\n";
+    os << "      return emitError(unknownLoc, \"unknown type result <id> : "
+          "\") << words[wordIndex];\n";
+    os << "    }\n";
+    os << "    resultTypes.push_back(ty);\n";
+    os << "    wordIndex++;\n";
+    os << "  }\n";
+    os << "  if (wordIndex >= words.size()) {\n";
+    os << "    "
+       << formatv("return emitError(unknownLoc, \"expected result <id> while "
+                  "deserializing {0}\");\n",
+                  op.getQualCppClassName());
+    os << "  }\n";
+    os << "  uint32_t valueID = words[wordIndex++];\n";
+    hasResult = true;
+  } else if (op.getNumResults() != 0) {
+    PrintFatalError(record->getLoc(),
+                    "SPIR-V ops can have only zero or one result");
+  }
+
+  // Process operands/attributes
+  os << "  SmallVector<Value *, 4> operands;\n";
+  os << "  SmallVector<NamedAttribute, 4> attributes;\n";
+  unsigned operandNum = 0;
+  for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
+    auto argument = op.getArg(i);
+    if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
+      if (valueArg->isVariadic()) {
+        if (i != e - 1) {
+          PrintFatalError(record->getLoc(),
+                          "SPIR-V ops can have Variadic<..> argument only if "
+                          "it's the last argument");
+        }
+        os << "  for (; wordIndex < words.size(); ++wordIndex)";
+      } else {
+        os << "  if (wordIndex < words.size())";
+      }
+      os << " {\n";
+      os << "    auto arg = getValue(words[wordIndex]);\n";
+      os << "    if (!arg) {\n";
+      os << "      return emitError(unknownLoc, \"unknown result <id> : \") << "
+            "words[wordIndex];\n";
+      os << "    }\n";
+      os << "    operands.push_back(arg);\n";
+      if (!valueArg->isVariadic()) {
+        os << "    wordIndex++;\n";
+      }
+      operandNum++;
+      os << "  }\n";
+    } else {
+      os << "  if (wordIndex < words.size()) {\n";
+      auto attr = argument.get<NamedAttribute *>();
+      emitAttributeDeserialization(
+          (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
+          record->getLoc(), "attributes", attr->name, "words", "wordIndex",
+          "words.size()", os);
+      os << "  }\n";
+    }
+  }
+
+  os << "  if (wordIndex != words.size()) {\n";
+  os << "    return emitError(unknownLoc, \"found more operands than expected "
+        "when deserializing "
+     << op.getQualCppClassName()
+     << ", only \") << wordIndex << \" of \" << words.size() << \" "
+        "processed\";\n";
+  os << "  }\n\n";
+
+  // Import decorations parsed
+  if (op.getNumResults() == 1) {
+    os << "  if (decorations.count(valueID)) {\n"
+       << "    auto attrs = decorations[valueID].getAttrs();\n"
+       << "    attributes.append(attrs.begin(), attrs.end());\n"
+       << "  }\n";
+  }
+
+  os << formatv("  auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
+                "operands, attributes); (void)op;\n",
+                op.getQualCppClassName());
+  if (hasResult) {
+    os << "  valueMap[valueID] = op.getResult();\n\n";
+  }
+
+  os << "  return success();\n";
+  os << "}\n\n";
+}
+
+static void initDispatchDeserializationFn(raw_ostream &os) {
+  os << "LogicalResult "
+        "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
+        "opcode, ArrayRef<uint32_t> words) {\n";
+  os << "  switch (opcode) {\n";
+}
+
+static void emitDeserializationDispatch(const Operator &op, const Record *def,
+                                        raw_ostream &os) {
+  os << formatv("  case spirv::Opcode::{0}:\n",
+                def->getValueAsString("spirvOpName"));
+  os << formatv("    return processOp<{0}>(words);\n",
+                op.getQualCppClassName());
+}
+
+static void finalizeDispatchDeserializationFn(raw_ostream &os) {
+  os << "  default:\n";
+  os << "    ;\n";
+  os << "  }\n";
+  os << "  return emitError(unknownLoc, \"unhandled deserialization of \") << "
+        "spirv::stringifyOpcode(opcode);\n";
+  os << "}\n";
+}
+
+static bool emitSerializationFns(const RecordKeeper &recordKeeper,
+                                 raw_ostream &os) {
+  llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
+
+  std::string dSerFnString, dDesFnString, serFnString, deserFnString,
+      utilsString;
+  raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
+      serFn(serFnString), deserFn(deserFnString), utils(utilsString);
+  auto attrClass = recordKeeper.getClass("Attr");
+
+  declareOpcodeFn(utils);
+  initDispatchSerializationFn(dSerFn);
+  initDispatchDeserializationFn(dDesFn);
+  auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
+  for (const auto *def : defs) {
+    if (!def->getValueAsBit("hasOpcode")) {
+      continue;
+    }
+    Operator op(def);
+    emitGetOpcodeFunction(def, op, utils);
+    emitSerializationFunction(attrClass, def, op, serFn);
+    emitSerializationDispatch(op, dSerFn);
+    emitDeserializationFunction(attrClass, def, op, deserFn);
+    emitDeserializationDispatch(op, def, dDesFn);
+  }
+  finalizeDispatchSerializationFn(dSerFn);
+  finalizeDispatchDeserializationFn(dDesFn);
+
+  os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
+  os << utils.str();
+  os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n";
+
+  os << "#ifdef GET_SERIALIZATION_FNS\n\n";
+  os << serFn.str();
+  os << dSerFn.str();
+  os << "#endif // GET_SERIALIZATION_FNS\n\n";
+
+  os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
+  os << deserFn.str();
+  os << dDesFn.str();
+  os << "#endif // GET_DESERIALIZATION_FNS\n\n";
+
+  return false;
+}
+
+static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
+  os << formatv("template <typename EnumClass> inline constexpr StringRef "
+                "attributeName();\n");
+}
+
+static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
+  os << "template <typename EnumClass> using SymbolizeFnTy = "
+        "llvm::Optional<EnumClass> (*)(StringRef);\n";
+  os << "template <typename EnumClass> inline constexpr "
+        "SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
+}
+
+static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
+                                      raw_ostream &os) {
+  auto enumName = enumAttr.getEnumClassName();
+  os << formatv("template <> inline StringRef attributeName<{0}>()", enumName)
+     << " {\n";
+  os << "  "
+     << formatv("static constexpr const char attrName[] = \"{0}\";\n",
+                mlir::convertToSnakeCase(enumName));
+  os << "  return attrName;\n";
+  os << "}\n";
+}
+
+static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr,
+                                       raw_ostream &os) {
+  auto enumName = enumAttr.getEnumClassName();
+  auto strToSymFnName = enumAttr.getStringToSymbolFnName();
+  os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()",
+                enumName)
+     << " {\n";
+  os << "  return " << strToSymFnName << ";\n";
+  os << "}\n";
+}
+
+static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  llvm::emitSourceFileHeader("SPIR-V Op Utilites", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr");
+  os << "#ifndef SPIRV_OP_UTILS_H_\n";
+  os << "#define SPIRV_OP_UTILS_H_\n";
+  emitEnumGetAttrNameFnDecl(os);
+  emitEnumGetSymbolizeFnDecl(os);
+  for (const auto *def : defs) {
+    EnumAttr enumAttr(*def);
+    emitEnumGetAttrNameFnDefn(enumAttr, os);
+    emitEnumGetSymbolizeFnDefn(enumAttr, os);
+  }
+  os << "#endif // SPIRV_OP_UTILS_H\n";
+  return false;
+}
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration genSerialization(
+    "gen-spirv-serialization",
+    "Generate SPIR-V (de)serialization utilities and functions",
+    [](const RecordKeeper &records, raw_ostream &os) {
+      return emitSerializationFns(records, os);
+    });
+
+static mlir::GenRegistration
+    genOpUtils("gen-spirv-op-utils",
+               "Generate SPIR-V operation utility definitions",
+               [](const RecordKeeper &records, raw_ostream &os) {
+                 return emitOpUtils(records, os);
+               });
diff --git a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
new file mode 100644
index 0000000..0bb5891
--- /dev/null
+++ b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
@@ -0,0 +1,91 @@
+//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the main function for MLIR's TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/GenNameParser.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Main.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+
+static llvm::ManagedStatic<std::vector<GenInfo>> generatorRegistry;
+
+mlir::GenRegistration::GenRegistration(StringRef arg, StringRef description,
+                                       GenFunction function) {
+  generatorRegistry->emplace_back(arg, description, function);
+}
+
+GenNameParser::GenNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const GenInfo *>(opt) {
+  for (const auto &kv : *generatorRegistry) {
+    addLiteralOption(kv.getGenArgument(), &kv, kv.getGenDescription());
+  }
+}
+
+void GenNameParser::printOptionInfo(const llvm::cl::Option &O,
+                                    size_t GlobalWidth) const {
+  GenNameParser *TP = const_cast<GenNameParser *>(this);
+  llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
+                       [](const GenNameParser::OptionInfo *VT1,
+                          const GenNameParser::OptionInfo *VT2) {
+                         return VT1->Name.compare(VT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const GenInfo *>::printOptionInfo(O, GlobalWidth);
+}
+
+// Generator that prints records.
+GenRegistration printRecords("print-records", "Print all records to stdout",
+                             [](const RecordKeeper &records, raw_ostream &os) {
+                               os << records;
+                               return false;
+                             });
+
+// Generator to invoke.
+const mlir::GenInfo *generator;
+
+// TableGenMain requires a function pointer so this function is passed in which
+// simply wraps the call to the generator.
+static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) {
+  assert(generator && "no generator specified");
+  return generator->invoke(records, os);
+}
+
+int main(int argc, char **argv) {
+  sys::PrintStackTraceOnErrorSignal(argv[0]);
+  PrettyStackTraceProgram X(argc, argv);
+  llvm::cl::opt<const mlir::GenInfo *, false, mlir::GenNameParser> generator(
+      "", llvm::cl::desc("Generator to run"));
+  cl::ParseCommandLineOptions(argc, argv);
+  ::generator = generator.getValue();
+
+  llvm_shutdown_obj Y;
+  return TableGenMain(argv[0], &MlirTableGenMain);
+}
diff --git a/third_party/mlir/tools/mlir-translate/CMakeLists.txt b/third_party/mlir/tools/mlir-translate/CMakeLists.txt
new file mode 100644
index 0000000..50df9de
--- /dev/null
+++ b/third_party/mlir/tools/mlir-translate/CMakeLists.txt
@@ -0,0 +1,22 @@
+set(LIBS
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIREDSC
+  MLIRParser
+  MLIRPass
+  MLIRSPIRV
+  MLIRSPIRVSerialization
+  MLIRStandardOps
+  MLIRTargetLLVMIR
+  MLIRTargetNVVMIR
+  MLIRTransforms
+  MLIRTranslation
+  MLIRSupport
+  MLIRVectorOps
+)
+add_llvm_executable(mlir-translate
+  mlir-translate.cpp
+)
+llvm_update_compile_flags(mlir-translate)
+whole_archive_link(mlir-translate ${LIBS})
+target_link_libraries(mlir-translate PRIVATE MLIRIR MLIRTranslateClParser ${LIBS} LLVMSupport)
diff --git a/third_party/mlir/tools/mlir-translate/mlir-translate.cpp b/third_party/mlir/tools/mlir-translate/mlir-translate.cpp
new file mode 100644
index 0000000..0ff5e6e
--- /dev/null
+++ b/third_party/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -0,0 +1,52 @@
+//===- mlir-translate.cpp - MLIR Translate Driver -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a command line utility that translates a file from/to MLIR using one
+// of the registered translations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TranslateClParser.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+
+using namespace mlir;
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc("<input file>"),
+                                                llvm::cl::init("-"));
+
+static llvm::cl::opt<std::string>
+    outputFilename("o", llvm::cl::desc("Output filename"),
+                   llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+int main(int argc, char **argv) {
+  llvm::PrettyStackTraceProgram x(argc, argv);
+  llvm::InitLLVM y(argc, argv);
+
+  // Add flags for all the registered translations.
+  llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
+      translationRequested("", llvm::cl::desc("Translation to perform"),
+                           llvm::cl::Required);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
+
+  MLIRContext context;
+  return failed(
+      (*translationRequested)(inputFilename, outputFilename, &context));
+}
diff --git a/third_party/mlir/utils/emacs/mlir-mode.el b/third_party/mlir/utils/emacs/mlir-mode.el
new file mode 100644
index 0000000..636c5db
--- /dev/null
+++ b/third_party/mlir/utils/emacs/mlir-mode.el
@@ -0,0 +1,79 @@
+;;; mlir-mode.el --- Major mode for the MLIR assembler language.
+
+;; Copyright (C) 2019 The MLIR Authors.
+;;
+;; Licensed under the Apache License, Version 2.0 (the "License");
+;; you may not use this file except in compliance with the License.
+;; You may obtain a copy of the License at
+;;
+;;      http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+
+;;; Commentary:
+
+;; Major mode for editing MLIR files.
+
+;;; Code:
+
+(defvar mlir-mode-syntax-table
+  (let ((table (make-syntax-table)))
+    (modify-syntax-entry ?% "_" table)
+    (modify-syntax-entry ?@ "_" table)
+    (modify-syntax-entry ?# "_" table)
+    (modify-syntax-entry ?. "_" table)
+    (modify-syntax-entry ?/ ". 12" table)
+    (modify-syntax-entry ?\n "> " table)
+    table)
+  "Syntax table used while in MLIR mode.")
+
+(defvar mlir-font-lock-keywords
+  (list
+   ;; Variables
+   '("%[-a-zA-Z$._0-9]*" . font-lock-variable-name-face)
+   ;; Functions
+   '("@[-a-zA-Z$._0-9]*" . font-lock-function-name-face)
+   ;; Affinemaps
+   '("#[-a-zA-Z$._0-9]*" . font-lock-variable-name-face)
+   ;; Types
+   '("\\b\\(f16\\|bf16\\|f32\\|f64\\|index\\|tf_control\\|i[1-9][0-9]*\\)\\b" . font-lock-type-face)
+   '("\\b\\(tensor\\|vector\\|memref\\)\\b" . font-lock-type-face)
+   ;; Dimension lists
+   '("\\b\\([0-9?]+x\\)*\\(f16\\|bf16\\|f32\\|f64\\|index\\|i[1-9][0-9]*\\)\\b" . font-lock-preprocessor-face)
+   ;; Integer literals
+   '("\\b[-]?[0-9]+\\b" . font-lock-preprocessor-face)
+   ;; Floating point constants
+   '("\\b[-+]?[0-9]+.[0-9]*\\([eE][-+]?[0-9]+\\)?\\b" . font-lock-preprocessor-face)
+   ;; Hex constants
+   '("\\b0x[0-9A-Fa-f]+\\b" . font-lock-preprocessor-face)
+   ;; Keywords
+   `(,(regexp-opt
+       '(;; Toplevel entities
+         "br" "ceildiv" "func" "cond_br" "else" "extfunc" "false" "floordiv" "for" "if" "mod" "return" "size" "step" "to" "true" "??" ) 'symbols) . font-lock-keyword-face))
+  "Syntax highlighting for MLIR.")
+
+;; Emacs 23 compatibility.
+(defalias 'mlir-mode-prog-mode
+  (if (fboundp 'prog-mode)
+      'prog-mode
+    'fundamental-mode))
+
+;;;###autoload
+(define-derived-mode mlir-mode mlir-mode-prog-mode "MLIR"
+  "Major mode for editing MLIR source files.
+\\{mlir-mode-map}
+  Runs `mlir-mode-hook' on startup."
+  (setq font-lock-defaults `(mlir-font-lock-keywords))
+  (setq-local comment-start "//"))
+
+;; Associate .mlir files with mlir-mode
+;;;###autoload
+(add-to-list 'auto-mode-alist (cons "\\.mlir\\'" 'mlir-mode))
+
+(provide 'mlir-mode)
+
+;;; mlir-mode.el ends here
diff --git a/third_party/mlir/utils/spirv/define_enum.sh b/third_party/mlir/utils/spirv/define_enum.sh
new file mode 100755
index 0000000..9da898f
--- /dev/null
+++ b/third_party/mlir/utils/spirv/define_enum.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script for defining a new enum attr using SPIR-V spec from the Internet.
+#
+# Run as:
+# ./define_enum.sh <enum-class-name>
+#
+# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
+# SPIR-V enum classes.
+#
+# If <enum-name> is missing, this script updates existing ones.
+
+set -e
+
+new_enum=$1
+
+current_file="$(readlink -f "$0")"
+current_dir="$(dirname "$current_file")"
+
+python3 ${current_dir}/gen_spirv_dialect.py \
+  --base-td-path ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRVBase.td \
+  --new-enum "${new_enum}"
diff --git a/third_party/mlir/utils/spirv/define_inst.sh b/third_party/mlir/utils/spirv/define_inst.sh
new file mode 100755
index 0000000..49b5e8d
--- /dev/null
+++ b/third_party/mlir/utils/spirv/define_inst.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script for defining a new op using SPIR-V spec from the Internet.
+#
+# Run as:
+# ./define_inst.sh <opname>
+
+# For example:
+# ./define_inst.sh OpIAdd
+#
+# If <opname> is missing, this script updates existing ones.
+
+set -e
+
+new_op=$1
+
+current_file="$(readlink -f "$0")"
+current_dir="$(dirname "$current_file")"
+
+python3 ${current_dir}/gen_spirv_dialect.py \
+  --op-td-path ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRVOps.td \
+  --new-inst "${new_op}"
diff --git a/third_party/mlir/utils/spirv/define_opcodes.sh b/third_party/mlir/utils/spirv/define_opcodes.sh
new file mode 100755
index 0000000..05c3657
--- /dev/null
+++ b/third_party/mlir/utils/spirv/define_opcodes.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script for defining map for opname to opcode using SPIR-V spec from the
+# Internet
+#
+# Run as:
+# ./define_opcode.sh (<op-name>)*
+#
+# For example:
+# ./define_opcode.sh OpTypeVoid OpTypeFunction
+#
+# If no op-name is specified, the existing opcodes are updated
+#
+# The 'instructions' list of spirv.core.grammar.json contains all instructions
+# in SPIR-V
+
+set -e
+
+current_file="$(readlink -f "$0")"
+current_dir="$(dirname "$current_file")"
+
+python3 ${current_dir}/gen_spirv_dialect.py \
+  --base-td-path ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRVBase.td \
+  --new-opcode $@
diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
new file mode 100755
index 0000000..ac00179
--- /dev/null
+++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
@@ -0,0 +1,616 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script for updating SPIR-V dialect by scraping information from SPIR-V
+# HTML and JSON specs from the Internet.
+#
+# For example, to define the enum attribute for SPIR-V memory model:
+#
+# ./gen_spirv_dialect.py --base_td_path /path/to/SPIRVBase.td \
+#                        --new-enum MemoryModel
+#
+# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
+# SPIR-V enum classes.
+
+import re
+import requests
+import textwrap
+
+SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
+SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'
+
+AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
+AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
+AUTOGEN_OPCODE_SECTION_MARKER = (
+    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
+
+
+def get_spirv_doc_from_html_spec():
+  """Extracts instruction documentation from SPIR-V HTML spec.
+
+  Returns:
+    - A dict mapping from instruction opcode to documentation.
+  """
+  response = requests.get(SPIRV_HTML_SPEC_URL)
+  spec = response.content
+
+  from bs4 import BeautifulSoup
+  spirv = BeautifulSoup(spec, 'html.parser')
+
+  section_anchor = spirv.find('h3', {'id': '_a_id_instructions_a_instructions'})
+
+  doc = {}
+
+  for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
+    for table in section.find_all('table'):
+      inst_html = table.tbody.tr.td.p
+      opname = inst_html.a['id']
+      # Ignore the first line, which is just the opname.
+      doc[opname] = inst_html.text.split('\n', 1)[1].strip()
+
+  return doc
+
+
+def get_spirv_grammar_from_json_spec():
+  """Extracts operand kind and instruction grammar from SPIR-V JSON spec.
+
+  Returns:
+    - A list containing all operand kinds' grammar
+    - A list containing all instructions' grammar
+  """
+  response = requests.get(SPIRV_JSON_SPEC_URL)
+  spec = response.content
+
+  import json
+  spirv = json.loads(spec)
+
+  return spirv['operand_kinds'], spirv['instructions']
+
+
+def split_list_into_sublists(items, offset):
+  """Split the list of items into multiple sublists.
+
+  This is to make sure the string composed from each sublist won't exceed
+  80 characters.
+
+  Arguments:
+    - items: a list of strings
+    - offset: the offset in calculating each sublist's length
+  """
+  chuncks = []
+  chunk = []
+  chunk_len = 0
+
+  for item in items:
+    chunk_len += len(item) + 2
+    if chunk_len > 80:
+      chuncks.append(chunk)
+      chunk = []
+      chunk_len = len(item) + 2
+    chunk.append(item)
+
+  if len(chunk) != 0:
+    chuncks.append(chunk)
+
+  return chuncks
+
+
+def uniquify(lst, equality_fn):
+  """Returns a list after pruning duplicate elements.
+
+  Arguments:
+   - lst: List whose elements are to be uniqued.
+   - equality_fn: Function used to compare equality between elements of the
+     list.
+
+  Returns:
+   - A list with all duplicated removed. The order of elements is same as the
+     original list, with only the first occurence of duplicates retained.
+  """
+  keys = set()
+  unique_lst = []
+  for elem in lst:
+    key = equality_fn(elem)
+    if equality_fn(key) not in keys:
+      unique_lst.append(elem)
+      keys.add(key)
+  return unique_lst
+
+
+def gen_operand_kind_enum_attr(operand_kind):
+  """Generates the TableGen I32EnumAttr definition for the given operand kind.
+
+  Returns:
+    - The operand kind's name
+    - A string containing the TableGen I32EnumAttr definition
+  """
+  if 'enumerants' not in operand_kind:
+    return '', ''
+
+  kind_name = operand_kind['kind']
+  kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
+  kind_cases = [(case['enumerant'], case['value'])
+                for case in operand_kind['enumerants']]
+  kind_cases = uniquify(kind_cases, lambda x: x[1])
+  max_len = max([len(symbol) for (symbol, _) in kind_cases])
+
+  # Generate the definition for each enum case
+  fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
+            'I32EnumAttrCase<"{symbol}", {value}>;'
+  case_defs = [
+      fmt_str.format(
+          acronym=kind_acronym,
+          symbol=case[0],
+          value=case[1],
+          colon=':',
+          offset=(max_len + 1 - len(case[0]))) for case in kind_cases
+  ]
+  case_defs = '\n'.join(case_defs)
+
+  # Generate the list of enum case names
+  fmt_str = 'SPV_{acronym}_{symbol}';
+  case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0])
+                for case in kind_cases]
+
+  # Split them into sublists and concatenate into multiple lines
+  case_names = split_list_into_sublists(case_names, 6)
+  case_names = ['{:6}'.format('') + ', '.join(sublist)
+                for sublist in case_names]
+  case_names = ',\n'.join(case_names)
+
+  # Generate the enum attribute definition
+  enum_attr = 'def SPV_{name}Attr :\n    '\
+      'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n    ]> {{\n'\
+      '  let returnType = "::mlir::spirv::{name}";\n'\
+      '  let convertFromStorage = '\
+            '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
+      '  let cppNamespace = "::mlir::spirv";\n}}'.format(
+          name=kind_name, cases=case_names)
+  return kind_name, case_defs + '\n\n' + enum_attr
+
+
+def gen_opcode(instructions):
+  """ Generates the TableGen definition to map opname to opcode
+
+  Returns:
+    - A string containing the TableGen SPV_OpCode definition
+  """
+
+  max_len = max([len(inst['opname']) for inst in instructions])
+  def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\
+            'I32EnumAttrCase<"{name}", {value}>;'
+  opcode_defs = [
+      def_fmt_str.format(
+          name=inst['opname'],
+          value=inst['opcode'],
+          colon=':',
+          offset=(max_len + 1 - len(inst['opname']))) for inst in instructions
+  ]
+  opcode_str = '\n'.join(opcode_defs)
+
+  decl_fmt_str = 'SPV_OC_{name}'
+  opcode_list = [
+      decl_fmt_str.format(name=inst['opname']) for inst in instructions
+  ]
+  opcode_list = split_list_into_sublists(opcode_list, 6)
+  opcode_list = [
+      '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
+  ]
+  opcode_list = ',\n'.join(opcode_list)
+  enum_attr = 'def SPV_OpcodeAttr :\n'\
+              '    I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
+              '{lst}\n'\
+              '      ]> {{\n'\
+              '    let returnType = "::mlir::spirv::{name}";\n'\
+              '    let convertFromStorage = '\
+              '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
+              '    let cppNamespace = "::mlir::spirv";\n}}'.format(
+                  name='Opcode', lst=opcode_list)
+  return opcode_str + '\n\n' + enum_attr
+
+
+def update_td_opcodes(path, instructions, filter_list):
+  """Updates SPIRBase.td with new generated opcode cases.
+
+  Arguments:
+    - path: the path to SPIRBase.td
+    - instructions: a list containing all SPIR-V instructions' grammar
+    - filter_list: a list containing new opnames to add
+  """
+
+  with open(path, 'r') as f:
+    content = f.read()
+
+  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
+  assert len(content) == 3
+
+  # Extend opcode list with existing list
+  existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])]
+  filter_list.extend(existing_opcodes)
+  filter_list = list(set(filter_list))
+
+  # Generate the opcode for all instructions in SPIR-V
+  filter_instrs = list(
+      filter(lambda inst: (inst['opname'] in filter_list), instructions))
+  # Sort instruction based on opcode
+  filter_instrs.sort(key=lambda inst: inst['opcode'])
+  opcode = gen_opcode(filter_instrs)
+
+  # Substitute the opcode
+  content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
+        opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
+        + content[2]
+
+  with open(path, 'w') as f:
+    f.write(content)
+
+
+def update_td_enum_attrs(path, operand_kinds, filter_list):
+  """Updates SPIRBase.td with new generated enum definitions.
+
+  Arguments:
+    - path: the path to SPIRBase.td
+    - operand_kinds: a list containing all operand kinds' grammar
+    - filter_list: a list containing new enums to add
+  """
+  with open(path, 'r') as f:
+    content = f.read()
+
+  content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
+  assert len(content) == 3
+
+  # Extend filter list with existing enum definitions
+  existing_kinds = [
+      k[8:-4] for k in re.findall('def SPV_\w+Attr', content[1])]
+  filter_list.extend(existing_kinds)
+
+  # Generate definitions for all enums in filter list
+  defs = [gen_operand_kind_enum_attr(kind)
+          for kind in operand_kinds if kind['kind'] in filter_list]
+  # Sort alphabetically according to enum name
+  defs.sort(key=lambda enum : enum[0])
+  # Only keep the definitions from now on
+  defs = [enum[1] for enum in defs]
+
+  # Substitute the old section
+  content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
+      '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER  \
+      + content[2];
+
+  with open(path, 'w') as f:
+    f.write(content)
+
+
+def snake_casify(name):
+  """Turns the given name to follow snake_case convension."""
+  name = re.sub('\W+', '', name).split()
+  name = [s.lower() for s in name]
+  return '_'.join(name)
+
+
+def map_spec_operand_to_ods_argument(operand):
+  """Maps a operand in SPIR-V JSON spec to an op argument in ODS.
+
+  Arguments:
+    - A dict containing the operand's kind, quantifier, and name
+
+  Returns:
+    - A string containing both the type and name for the argument
+  """
+  kind = operand['kind']
+  quantifier = operand.get('quantifier', '')
+
+  # These instruction "operands" are for encoding the results; they should
+  # not be handled here.
+  assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
+  assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'
+
+  if kind == 'IdRef':
+    if quantifier == '':
+      arg_type = 'SPV_Type'
+    elif quantifier == '?':
+      arg_type = 'SPV_Optional<SPV_Type>'
+    else:
+      arg_type = 'Variadic<SPV_Type>'
+  elif kind == 'IdMemorySemantics' or kind == 'IdScope':
+    # TODO(antiagainst): Need to further constrain 'IdMemorySemantics'
+    # and 'IdScope' given that they should be gernated from OpConstant.
+    assert quantifier == '', ('unexpected to have optional/variadic memory '
+                              'semantics or scope <id>')
+    arg_type = 'I32'
+  elif kind == 'LiteralInteger':
+    if quantifier == '':
+      arg_type = 'I32Attr'
+    elif quantifier == '?':
+      arg_type = 'OptionalAttr<I32Attr>'
+    else:
+      arg_type = 'OptionalAttr<I32ArrayAttr>'
+  elif kind == 'LiteralString' or \
+      kind == 'LiteralContextDependentNumber' or \
+      kind == 'LiteralExtInstInteger' or \
+      kind == 'LiteralSpecConstantOpInteger' or \
+      kind == 'PairLiteralIntegerIdRef' or \
+      kind == 'PairIdRefLiteralInteger' or \
+      kind == 'PairIdRefIdRef':
+    assert False, '"{}" kind unimplemented'.format(kind)
+  else:
+    # The rest are all enum operands that we represent with op attributes.
+    assert quantifier != '*', 'unexpected to have variadic enum attribute'
+    arg_type = 'SPV_{}Attr'.format(kind)
+    if quantifier == '?':
+      arg_type = 'OptionalAttr<{}>'.format(arg_type)
+
+  name = operand.get('name', '')
+  name = snake_casify(name) if name else kind.lower()
+
+  return '{}:${}'.format(arg_type, name)
+
+
+def get_op_definition(instruction, doc, existing_info):
+  """Generates the TableGen op definition for the given SPIR-V instruction.
+
+  Arguments:
+    - instruction: the instruction's SPIR-V JSON grammar
+    - doc: the instruction's SPIR-V HTML doc
+    - existing_info: a dict containing potential manually specified sections for
+      this instruction
+
+  Returns:
+    - A string containing the TableGen op definition
+  """
+  fmt_str = 'def SPV_{opname}Op : SPV_Op<"{opname}", [{traits}]> {{\n'\
+            '  let summary = {summary};\n\n'\
+            '  let description = [{{\n'\
+            '{description}\n\n'\
+            '    ### Custom assembly form\n'\
+            '{assembly}'\
+            '}}];\n\n'\
+            '  let arguments = (ins{args});\n\n'\
+            '  let results = (outs{results});\n'\
+            '{extras}'\
+            '}}\n'
+
+  opname = instruction['opname'][2:]
+
+  summary, description = doc.split('\n', 1)
+  wrapper = textwrap.TextWrapper(
+      width=76, initial_indent='    ', subsequent_indent='    ')
+
+  # Format summary. If the summary can fit in the same line, we print it out
+  # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
+  summary = summary.strip();
+  if len(summary) + len('  let summary = "";') <= 80:
+    summary = '"{}"'.format(summary)
+  else:
+    summary = '[{{\n{}\n  }}]'.format(wrapper.fill(summary))
+
+  # Wrap description
+  description = description.split('\n')
+  description = [wrapper.fill(line) for line in description if line]
+  description = '\n\n'.join(description)
+
+  operands = instruction.get('operands', [])
+
+  # Set op's result
+  results = ''
+  if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
+    results = '\n    SPV_Type:$result\n  '
+    operands = operands[1:]
+  if 'results' in existing_info:
+    results = existing_info['results']
+
+  # Ignore the operand standing for the result <id>
+  if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
+    operands = operands[1:]
+
+  # Set op' argument
+  arguments = existing_info.get('arguments', None)
+  if arguments is None:
+    arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
+    arguments = '\n    '.join(arguments)
+    if arguments:
+      # Prepend and append whitespace for formatting
+      arguments = '\n    {}\n  '.format(arguments)
+
+  assembly = existing_info.get('assembly', None)
+  if assembly is None:
+    assembly = '    ``` {.ebnf}\n'\
+               '    [TODO]\n'\
+               '    ```\n\n'\
+               '    For example:\n\n'\
+               '    ```\n'\
+               '    [TODO]\n'\
+               '    ```\n  '
+
+  return fmt_str.format(
+      opname=opname,
+      traits=existing_info.get('traits', ''),
+      summary=summary,
+      description=description,
+      assembly=assembly,
+      args=arguments,
+      results=results,
+      extras=existing_info.get('extras', ''))
+
+
+def extract_td_op_info(op_def):
+  """Extracts potentially manually specified sections in op's definition.
+
+  Arguments: - A string containing the op's TableGen definition
+    - doc: the instruction's SPIR-V HTML doc
+
+  Returns:
+    - A dict containing potential manually specified sections
+  """
+  # Get opname
+  opname = [o[8:-2] for o in re.findall('def SPV_\w+Op', op_def)]
+  assert len(opname) == 1, 'more than one ops in the same section!'
+  opname = opname[0]
+
+  # Get traits
+  op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0].split(', ', 1)
+  if len(op_tmpl_params) == 1:
+    traits = ''
+  else:
+    traits = op_tmpl_params[1].strip('[]')
+
+  # Get custom assembly form
+  rest = op_def.split('### Custom assembly form\n')
+  assert len(rest) == 2, \
+          '{}: cannot find "### Custom assembly form"'.format(opname)
+  rest = rest[1].split('  let arguments = (ins')
+  assert len(rest) == 2, '{}: cannot find arguments'.format(opname)
+  assembly = rest[0].rstrip('}];\n')
+
+  # Get arguments
+  rest = rest[1].split('  let results = (outs')
+  assert len(rest) == 2, '{}: cannot find results'.format(opname)
+  args = rest[0].rstrip(');\n')
+
+  # Get results
+  rest = rest[1].split(');', 1)
+  assert len(rest) == 2, \
+          '{}: cannot find ");" ending results'.format(opname)
+  results = rest[0]
+
+  extras = rest[1].strip(' }\n')
+  if extras:
+    extras = '\n  {}\n'.format(extras)
+
+  return {
+      # Prefix with 'Op' to make it consistent with SPIR-V spec
+      'opname': 'Op{}'.format(opname),
+      'traits': traits,
+      'assembly': assembly,
+      'arguments': args,
+      'results': results,
+      'extras': extras
+  }
+
+
+def update_td_op_definitions(path, instructions, docs, filter_list):
+  """Updates SPIRVOps.td with newly generated op definition.
+
+  Arguments:
+    - path: path to SPIRVOps.td
+    - instructions: SPIR-V JSON grammar for all instructions
+    - docs: SPIR-V HTML doc for all instructions
+    - filter_list: a list containing new opnames to include
+
+  Returns:
+    - A string containing all the TableGen op definitions
+  """
+  with open(path, 'r') as f:
+    content = f.read()
+
+  # Split the file into chuncks, each containing one op.
+  ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
+  header = ops[0]
+  footer = ops[-1]
+  ops = ops[1:-1]
+
+  # For each existing op, extract the manually-written sections out to retain
+  # them when re-generating the ops. Also append the existing ops to filter
+  # list.
+  op_info_dict = {}
+  for op in ops:
+    info_dict = extract_td_op_info(op)
+    opname = info_dict['opname']
+    op_info_dict[opname] = info_dict
+    filter_list.append(opname)
+  filter_list = sorted(list(set(filter_list)))
+
+  op_defs = []
+  for opname in filter_list:
+    # Find the grammar spec for this op
+    instruction = next(
+        inst for inst in instructions if inst['opname'] == opname)
+    op_defs.append(
+        get_op_definition(instruction, docs[opname],
+                          op_info_dict.get(opname, {})))
+
+  # Substitute the old op definitions
+  op_defs = [header] + op_defs + [footer]
+  content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)
+
+  with open(path, 'w') as f:
+    f.write(content)
+
+
+if __name__ == '__main__':
+  import argparse
+
+  cli_parser = argparse.ArgumentParser(
+      description='Update SPIR-V dialect definitions using SPIR-V spec')
+
+  cli_parser.add_argument(
+      '--base-td-path',
+      dest='base_td_path',
+      type=str,
+      default=None,
+      help='Path to SPIRVBase.td')
+  cli_parser.add_argument(
+      '--op-td-path',
+      dest='op_td_path',
+      type=str,
+      default=None,
+      help='Path to SPIRVOps.td')
+
+  cli_parser.add_argument(
+      '--new-enum',
+      dest='new_enum',
+      type=str,
+      default=None,
+      help='SPIR-V enum to be added to SPIRVBase.td')
+  cli_parser.add_argument(
+      '--new-opcodes',
+      dest='new_opcodes',
+      type=str,
+      default=None,
+      nargs='*',
+      help='update SPIR-V opcodes in SPIRVBase.td')
+  cli_parser.add_argument(
+      '--new-inst',
+      dest='new_inst',
+      type=str,
+      default=None,
+      help='SPIR-V instruction to be added to SPIRVOps.td')
+
+  args = cli_parser.parse_args()
+
+  operand_kinds, instructions = get_spirv_grammar_from_json_spec()
+
+  # Define new enum attr
+  if args.new_enum is not None:
+    assert args.base_td_path is not None
+    filter_list = [args.new_enum] if args.new_enum else []
+    update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)
+
+  # Define new opcode
+  if args.new_opcodes is not None:
+    assert args.base_td_path is not None
+    update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)
+
+  # Define new op
+  if args.new_inst is not None:
+    assert args.op_td_path is not None
+    filter_list = [args.new_inst] if args.new_inst else []
+    docs = get_spirv_doc_from_html_spec()
+    update_td_op_definitions(args.op_td_path, instructions, docs, filter_list)
+    print('Done. Note that this script just generates a template; ', end='')
+    print('please read the spec and update traits, arguments, and ', end='')
+    print('results accordingly.')
diff --git a/third_party/mlir/utils/vim/mlir.vim b/third_party/mlir/utils/vim/mlir.vim
new file mode 100644
index 0000000..18ff6fe
--- /dev/null
+++ b/third_party/mlir/utils/vim/mlir.vim
@@ -0,0 +1,51 @@
+" Copyright 2019 The MLIR Authors.
+"
+" Licensed under the Apache License, Version 2.0 (the "License");
+" you may not use this file except in compliance with the License.
+" You may obtain a copy of the License at
+"
+"     http://www.apache.org/licenses/LICENSE-2.0
+"
+" Unless required by applicable law or agreed to in writing, software
+" distributed under the License is distributed on an "AS IS" BASIS,
+" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+" See the License for the specific language governing permissions and
+" limitations under the License.
+
+" Vim syntax file
+" Language: MLIR
+
+" quit when a syntax file was already loaded
+if exists("b:current_syntax")
+  finish
+endif
+
+syn keyword mlirType index i1 i2 i4 i8 i13 i16 i32 i64
+      \ f16 f32 tf_control
+syn keyword mlirType memref tensor vector
+
+syntax keyword mlirKeywords extfunc cfgfunc mlfunc for to step return
+syntax keyword mlirConditional if else
+syntax keyword mlirCoreOps dim addf addi subf subi mulf muli cmpi select constant affine.apply call call_indirect extract_element getTensor memref_cast tensor_cast load store alloc dealloc dma_start dma_wait
+
+syn match mlirInt "-\=\<\d\+\>"
+syn match mlirFloat "-\=\<\d\+\.\d\+\>"
+syn match mlirMapOutline "#.*$"
+syn match mlirOperator      "[+\-*=]"
+
+syn region mlirComment start="//" skip="\\$" end="$"
+syn region mlirString matchgroup=mlirString start=+"+ end=+"+
+
+hi def link mlirComment      Comment
+hi def link mlirKeywords     Instruction
+hi def link mlirCoreOps      Instruction
+hi def link mlirInt          Constant
+hi def link mlirType         Type
+hi def link mlirMapOutline   PreProc
+hi def link mlirConditional  Conditional
+hi def link mlirString       String
+hi def link mlirOperator     Operator
+hi def link mlirInstruction  Operator
+hi def link mlirAffineOp     Operator
+
+let b:current_syntax = "mlir"
diff --git a/third_party/pybind11.BUILD b/third_party/pybind11.BUILD
index 2e82147..95f452c 100644
--- a/third_party/pybind11.BUILD
+++ b/third_party/pybind11.BUILD
@@ -14,9 +14,8 @@
     ),
     copts = [
         "-fexceptions",
-        "-Xclang-only=-Wno-undefined-inline",
-        "-Xclang-only=-Wno-pragma-once-outside-header",
-        "-Xgcc-only=-Wno-error",  # no way to just disable the pragma-once warning in gcc
+        "-Wno-undefined-inline",
+        "-Wno-pragma-once-outside-header",
     ],
     includes = ["include"],
     deps = [
diff --git a/third_party/systemlibs/enum34.BUILD b/third_party/systemlibs/enum34.BUILD
new file mode 100644
index 0000000..de14bd5
--- /dev/null
+++ b/third_party/systemlibs/enum34.BUILD
@@ -0,0 +1,14 @@
+# Description:
+#   enum34 provides a backport of the enum module for Python 2.
+
+licenses(["notice"])  # MIT
+
+filegroup(
+    name = "LICENSE",
+    visibility = ["//visibility:public"],
+)
+
+py_library(
+    name = "enum",
+    visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/jsoncpp.BUILD b/third_party/systemlibs/jsoncpp.BUILD
index 526fd0c..7d54f92 100644
--- a/third_party/systemlibs/jsoncpp.BUILD
+++ b/third_party/systemlibs/jsoncpp.BUILD
@@ -6,6 +6,8 @@
 )
 
 HEADERS = [
+    "include/json/allocator.h",
+    "include/json/assertions.h",
     "include/json/autolink.h",
     "include/json/config.h",
     "include/json/features.h",
diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl
index 8c411a7..f83c0dd 100644
--- a/third_party/systemlibs/syslibs_configure.bzl
+++ b/third_party/systemlibs/syslibs_configure.bzl
@@ -20,6 +20,7 @@
     "curl",
     "cython",
     "double_conversion",
+    "enum34_archive",
     "flatbuffers",
     "gast_archive",
     "gif_archive",
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index b02b96e..2df2c3c 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -144,6 +144,25 @@
         """ % container_digests["cuda10.0-cudnn7-centos6"],
 )
 
+# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2010.
+platform(
+    name = "rbe_ubuntu16.04-manylinux2010",
+    constraint_values = [
+        "@bazel_tools//platforms:x86_64",
+        "@bazel_tools//platforms:linux",
+    ],
+    remote_execution_properties = """
+        properties: {
+            name: "container-image"
+            value:"docker://gcr.io/tensorflow-testing/nosla-ubuntu16.04-manylinux2010@%s"
+        }
+        properties: {
+            name: "Pool"
+            value: "default"
+        }
+        """ % container_digests["ubuntu16.04-manylinux2010"],
+)
+
 # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.0-cudnn7-ubuntu16.04-manylinux2010.
 platform(
     name = "rbe_cuda10.0-cudnn7-ubuntu16.04-manylinux2010",
diff --git a/third_party/toolchains/preconfig/generate/BUILD b/third_party/toolchains/preconfig/generate/BUILD
index 2e6c670..9c25862 100644
--- a/third_party/toolchains/preconfig/generate/BUILD
+++ b/third_party/toolchains/preconfig/generate/BUILD
@@ -86,6 +86,22 @@
 )
 
 tensorflow_rbe_config(
+    name = "ubuntu16.04-py-gcc7_manylinux2010",
+    compiler = "/dt7/usr/bin/gcc",
+    compiler_prefix = "/usr/bin",
+    os = "ubuntu16.04-manylinux2010",
+    python_version = "2",
+)
+
+tensorflow_rbe_config(
+    name = "ubuntu16.04-py3-gcc7_manylinux2010",
+    compiler = "/dt7/usr/bin/gcc",
+    compiler_prefix = "/usr/bin",
+    os = "ubuntu16.04-manylinux2010",
+    python_version = "3.6",
+)
+
+tensorflow_rbe_config(
     name = "ubuntu16.04-py3-gcc7_manylinux2010-cuda10.0-cudnn7-tensorrt5.1",
     compiler = "/dt7/usr/bin/gcc",
     compiler_prefix = "/usr/bin",
diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl
index 6f69216..e2e125b 100644
--- a/third_party/toolchains/preconfig/generate/containers.bzl
+++ b/third_party/toolchains/preconfig/generate/containers.bzl
@@ -2,6 +2,7 @@
 container_digests = {
     "ubuntu16.04": "sha256:b90dcf2f35f3354909f4491bdf019c110b4b4d95ef0395ebf178bc5d523a4208",
     "centos6": "sha256:d09c12fb26fbbe8398b4973260c75172eb67d509dae9d6f4ad54279b7d6b0494",
+    "ubuntu16.04-manylinux2010": "sha256:3a9b4820021801b1fa7d0592c1738483ac7abc209fc6ee8c9ef06cf2eab2d170",
     "cuda10.0-cudnn7-ubuntu14.04": "sha256:d433e1221f802dac393bc8652fabcc63aa46896cd920bb888ae0e2002fe6b756",
     "cuda10.0-cudnn7-centos7": "sha256:a453b7147a60928a8345689eae48916a746b3578b5e831bfa151f0529d469c88",
     "cuda10.0-cudnn7-centos6": "sha256:a1909ba09c703340ee0074ce63dd94fe8fea48035a25264677907a609e2375e0",
diff --git a/third_party/toolchains/preconfig/generate/workspace.bzl b/third_party/toolchains/preconfig/generate/workspace.bzl
index 92f2abd..fb8a303 100644
--- a/third_party/toolchains/preconfig/generate/workspace.bzl
+++ b/third_party/toolchains/preconfig/generate/workspace.bzl
@@ -47,6 +47,13 @@
     )
 
     container_pull(
+        name = "ubuntu16.04-manylinux2010",
+        registry = "gcr.io",
+        repository = "tensorflow-testing/nosla-ubuntu16.04-manylinux2010",
+        digest = container_digests["ubuntu16.04-manylinux2010"],
+    )
+
+    container_pull(
         name = "cuda10.0-cudnn7-ubuntu16.04-manylinux2010",
         registry = "gcr.io",
         repository = "tensorflow-testing/nosla-cuda10.0-cudnn7-ubuntu16.04-manylinux2010",
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc
index 9800b76..1243dbb 100755
--- a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc
@@ -53,6 +53,11 @@
 PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
 NVCC_VERSION = '10.0'
 
+# Environment variable for supported TF CUDA Compute Capabilities
+# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
 
@@ -202,7 +207,7 @@
   srcs = ' '.join(src_files)
   out = ' -o ' + out_file[0]
 
-  supported_cuda_compute_capabilities = [ "3.0", "6.0" ]
+  supported_cuda_compute_capabilities = os.environ.get(CUDA_COMPUTE_ENV_VAR, DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
   nvccopts = '-D_FORCE_INLINES '
   for capability in supported_cuda_compute_capabilities:
     capability = capability.replace('.', '')
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py
index 79b98e5..a69d47f 100755
--- a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py
@@ -36,7 +36,14 @@
 NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
 NVCC_VERSION = '10.0'
 NVCC_TEMP_DIR = "C:\\Windows\\Temp\\nvcc_inter_files_tmp_dir"
-supported_cuda_compute_capabilities = [ "3.0", "6.0" ]
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
+# Taken from environment variable for supported TF CUDA Compute Capabilities
+# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+supported_cuda_compute_capabilities = os.environ.get(
+    'TF_CUDA_COMPUTE_CAPABILITIES',
+    DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
+
 
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
new file mode 100755
index 0000000..149a040
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
@@ -0,0 +1,121 @@
+# Copyright 2016 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This becomes the BUILD file for @local_config_cc// under non-FreeBSD unixes.
+
+package(default_visibility = ["//visibility:public"])
+
+load(":cc_toolchain_config.bzl", "cc_toolchain_config")
+
+licenses(["notice"])  # Apache 2.0
+
+cc_library(
+    name = "malloc",
+)
+
+filegroup(
+    name = "empty",
+    srcs = [],
+)
+
+filegroup(
+    name = "cc_wrapper",
+    srcs = ["cc_wrapper.sh"],
+)
+
+filegroup(
+    name = "compiler_deps",
+    srcs = glob(["extra_tools/**"]) + [":empty"],
+)
+
+# This is the entry point for --crosstool_top.  Toolchains are found
+# by lopping off the name of --crosstool_top and searching for
+# the "${CPU}" entry in the toolchains attribute.
+cc_toolchain_suite(
+    name = "toolchain",
+    toolchains = {
+        "k8|/dt7/usr/bin/gcc": ":cc-compiler-k8",
+        "k8": ":cc-compiler-k8",
+        "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a",
+        "armeabi-v7a": ":cc-compiler-armeabi-v7a",
+    },
+)
+
+cc_toolchain(
+    name = "cc-compiler-k8",
+    all_files = ":compiler_deps",
+    ar_files = ":empty",
+    as_files = ":empty",
+    compiler_files = ":compiler_deps",
+    dwp_files = ":empty",
+    linker_files = ":compiler_deps",
+    objcopy_files = ":empty",
+    strip_files = ":empty",
+    supports_param_files = 1,
+    toolchain_config = ":linux_gnu_x86",
+    toolchain_identifier = "linux_gnu_x86",
+)
+
+cc_toolchain_config(
+    name = "linux_gnu_x86",
+    compiler = "/dt7/usr/bin/gcc",
+    cpu = "k8",
+)
+
+toolchain(
+    name = "cc-toolchain-k8",
+    exec_compatible_with = [
+        # TODO(katre): add autodiscovered constraints for host CPU and OS.
+    ],
+    target_compatible_with = [
+        # TODO(katre): add autodiscovered constraints for host CPU and OS.
+    ],
+    toolchain = ":cc-compiler-k8",
+    toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
+# Android tooling requires a default toolchain for the armeabi-v7a cpu.
+cc_toolchain(
+    name = "cc-compiler-armeabi-v7a",
+    all_files = ":empty",
+    ar_files = ":empty",
+    as_files = ":empty",
+    compiler_files = ":empty",
+    dwp_files = ":empty",
+    linker_files = ":empty",
+    objcopy_files = ":empty",
+    strip_files = ":empty",
+    supports_param_files = 1,
+    toolchain_config = ":stub_armeabi-v7a",
+    toolchain_identifier = "stub_armeabi-v7a",
+)
+
+cc_toolchain_config(
+    name = "stub_armeabi-v7a",
+    compiler = "compiler",
+    cpu = "armeabi-v7a",
+)
+
+toolchain(
+    name = "cc-toolchain-armeabi-v7a",
+    exec_compatible_with = [
+        # TODO(katre): add autodiscovered constraints for host CPU and OS.
+    ],
+    target_compatible_with = [
+        "@bazel_tools//platforms:arm",
+        "@bazel_tools//platforms:android",
+    ],
+    toolchain = ":cc-compiler-armabi-v7a",
+    toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE
new file mode 100644
index 0000000..bc05b4c
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for cc_autoconf rule
+workspace(name = "local_config_cc")
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
new file mode 100755
index 0000000..12f087e
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
@@ -0,0 +1,1732 @@
+# Copyright 2019 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Starlark cc_toolchain configuration rule"""
+
+load(
+    "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
+    "action_config",
+    "artifact_name_pattern",
+    "env_entry",
+    "env_set",
+    "feature",
+    "feature_set",
+    "flag_group",
+    "flag_set",
+    "make_variable",  # @unused
+    "tool",
+    "tool_path",
+    "variable_with_value",
+    "with_feature_set",
+)
+load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")
+
+all_compile_actions = [
+    ACTION_NAMES.c_compile,
+    ACTION_NAMES.cpp_compile,
+    ACTION_NAMES.linkstamp_compile,
+    ACTION_NAMES.assemble,
+    ACTION_NAMES.preprocess_assemble,
+    ACTION_NAMES.cpp_header_parsing,
+    ACTION_NAMES.cpp_module_compile,
+    ACTION_NAMES.cpp_module_codegen,
+    ACTION_NAMES.clif_match,
+    ACTION_NAMES.lto_backend,
+]
+
+all_cpp_compile_actions = [
+    ACTION_NAMES.cpp_compile,
+    ACTION_NAMES.linkstamp_compile,
+    ACTION_NAMES.cpp_header_parsing,
+    ACTION_NAMES.cpp_module_compile,
+    ACTION_NAMES.cpp_module_codegen,
+    ACTION_NAMES.clif_match,
+]
+
+preprocessor_compile_actions = [
+    ACTION_NAMES.c_compile,
+    ACTION_NAMES.cpp_compile,
+    ACTION_NAMES.linkstamp_compile,
+    ACTION_NAMES.preprocess_assemble,
+    ACTION_NAMES.cpp_header_parsing,
+    ACTION_NAMES.cpp_module_compile,
+    ACTION_NAMES.clif_match,
+]
+
+codegen_compile_actions = [
+    ACTION_NAMES.c_compile,
+    ACTION_NAMES.cpp_compile,
+    ACTION_NAMES.linkstamp_compile,
+    ACTION_NAMES.assemble,
+    ACTION_NAMES.preprocess_assemble,
+    ACTION_NAMES.cpp_module_codegen,
+    ACTION_NAMES.lto_backend,
+]
+
+all_link_actions = [
+    ACTION_NAMES.cpp_link_executable,
+    ACTION_NAMES.cpp_link_dynamic_library,
+    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+]
+
+def _windows_msvc_impl(ctx):
+    toolchain_identifier = "msvc_x64"
+    host_system_name = "local"
+    target_system_name = "local"
+    target_cpu = "x64_windows"
+    target_libc = "msvcrt"
+    compiler = "msvc-cl"
+    abi_version = "local"
+    abi_libc_version = "local"
+    cc_target_os = None
+    builtin_sysroot = None
+
+    cxx_builtin_include_directories = [
+        "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include",
+        "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed",
+        "/dt7/usr/include",
+        "/dt7/usr/include/c++/7",
+        "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu",
+        "/dt7/usr/include/c++/7/backward",
+    ]
+
+    cpp_link_nodeps_dynamic_library_action = action_config(
+        action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+        implies = [
+            "nologo",
+            "shared_flag",
+            "linkstamps",
+            "output_execpath_flags",
+            "input_param_flags",
+            "user_link_flags",
+            "default_link_flags",
+            "linker_subsystem_flag",
+            "linker_param_file",
+            "msvc_env",
+            "no_stripping",
+            "has_configured_linker_path",
+            "def_file",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    cpp_link_static_library_action = action_config(
+        action_name = ACTION_NAMES.cpp_link_static_library,
+        implies = [
+            "nologo",
+            "archiver_flags",
+            "input_param_flags",
+            "linker_param_file",
+            "msvc_env",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    assemble_action = action_config(
+        action_name = ACTION_NAMES.assemble,
+        implies = [
+            "compiler_input_flags",
+            "compiler_output_flags",
+            "nologo",
+            "msvc_env",
+            "sysroot",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    preprocess_assemble_action = action_config(
+        action_name = ACTION_NAMES.preprocess_assemble,
+        implies = [
+            "compiler_input_flags",
+            "compiler_output_flags",
+            "nologo",
+            "msvc_env",
+            "sysroot",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    c_compile_action = action_config(
+        action_name = ACTION_NAMES.c_compile,
+        implies = [
+            "compiler_input_flags",
+            "compiler_output_flags",
+            "default_compile_flags",
+            "nologo",
+            "msvc_env",
+            "parse_showincludes",
+            "user_compile_flags",
+            "sysroot",
+            "unfiltered_compile_flags",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    cpp_compile_action = action_config(
+        action_name = ACTION_NAMES.cpp_compile,
+        implies = [
+            "compiler_input_flags",
+            "compiler_output_flags",
+            "default_compile_flags",
+            "nologo",
+            "msvc_env",
+            "parse_showincludes",
+            "user_compile_flags",
+            "sysroot",
+            "unfiltered_compile_flags",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    cpp_link_executable_action = action_config(
+        action_name = ACTION_NAMES.cpp_link_executable,
+        implies = [
+            "nologo",
+            "linkstamps",
+            "output_execpath_flags",
+            "input_param_flags",
+            "user_link_flags",
+            "default_link_flags",
+            "linker_subsystem_flag",
+            "linker_param_file",
+            "msvc_env",
+            "no_stripping",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    cpp_link_dynamic_library_action = action_config(
+        action_name = ACTION_NAMES.cpp_link_dynamic_library,
+        implies = [
+            "nologo",
+            "shared_flag",
+            "linkstamps",
+            "output_execpath_flags",
+            "input_param_flags",
+            "user_link_flags",
+            "default_link_flags",
+            "linker_subsystem_flag",
+            "linker_param_file",
+            "msvc_env",
+            "no_stripping",
+            "has_configured_linker_path",
+            "def_file",
+        ],
+        tools = [tool(path = "")],
+    )
+
+    action_configs = [
+        assemble_action,
+        preprocess_assemble_action,
+        c_compile_action,
+        cpp_compile_action,
+        cpp_link_executable_action,
+        cpp_link_dynamic_library_action,
+        cpp_link_nodeps_dynamic_library_action,
+        cpp_link_static_library_action,
+    ]
+
+    msvc_link_env_feature = feature(
+        name = "msvc_link_env",
+        env_sets = [
+            env_set(
+                actions = all_link_actions +
+                          [ACTION_NAMES.cpp_link_static_library],
+                env_entries = [env_entry(key = "LIB", value = "")],
+            ),
+        ],
+    )
+
+    shared_flag_feature = feature(
+        name = "shared_flag",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                ],
+                flag_groups = [flag_group(flags = ["/DLL"])],
+            ),
+        ],
+    )
+
+    determinism_feature = feature(
+        name = "determinism",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [
+                    flag_group(
+                        flags = [
+                            "/wd4117",
+                            "-D__DATE__=\"redacted\"",
+                            "-D__TIMESTAMP__=\"redacted\"",
+                            "-D__TIME__=\"redacted\"",
+                        ],
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    sysroot_feature = feature(
+        name = "sysroot",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["--sysroot=%{sysroot}"],
+                        iterate_over = "sysroot",
+                        expand_if_available = "sysroot",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    unfiltered_compile_flags_feature = feature(
+        name = "unfiltered_compile_flags",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{unfiltered_compile_flags}"],
+                        iterate_over = "unfiltered_compile_flags",
+                        expand_if_available = "unfiltered_compile_flags",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary")
+
+    input_param_flags_feature = feature(
+        name = "input_param_flags",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["/IMPLIB:%{interface_library_output_path}"],
+                        expand_if_available = "interface_library_output_path",
+                    ),
+                ],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{libopts}"],
+                        iterate_over = "libopts",
+                        expand_if_available = "libopts",
+                    ),
+                ],
+            ),
+            flag_set(
+                actions = all_link_actions +
+                          [ACTION_NAMES.cpp_link_static_library],
+                flag_groups = [
+                    flag_group(
+                        iterate_over = "libraries_to_link",
+                        flag_groups = [
+                            flag_group(
+                                iterate_over = "libraries_to_link.object_files",
+                                flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])],
+                                expand_if_equal = variable_with_value(
+                                    name = "libraries_to_link.type",
+                                    value = "object_file_group",
+                                ),
+                            ),
+                            flag_group(
+                                flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])],
+                                expand_if_equal = variable_with_value(
+                                    name = "libraries_to_link.type",
+                                    value = "object_file",
+                                ),
+                            ),
+                            flag_group(
+                                flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])],
+                                expand_if_equal = variable_with_value(
+                                    name = "libraries_to_link.type",
+                                    value = "interface_library",
+                                ),
+                            ),
+                            flag_group(
+                                flag_groups = [
+                                    flag_group(
+                                        flags = ["%{libraries_to_link.name}"],
+                                        expand_if_false = "libraries_to_link.is_whole_archive",
+                                    ),
+                                    flag_group(
+                                        flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"],
+                                        expand_if_true = "libraries_to_link.is_whole_archive",
+                                    ),
+                                ],
+                                expand_if_equal = variable_with_value(
+                                    name = "libraries_to_link.type",
+                                    value = "static_library",
+                                ),
+                            ),
+                        ],
+                        expand_if_available = "libraries_to_link",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    fastbuild_feature = feature(
+        name = "fastbuild",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/Od", "/Z7"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["", "/INCREMENTAL:NO"],
+                    ),
+                ],
+            ),
+        ],
+        implies = ["generate_pdb_file"],
+    )
+
+    user_compile_flags_feature = feature(
+        name = "user_compile_flags",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{user_compile_flags}"],
+                        iterate_over = "user_compile_flags",
+                        expand_if_available = "user_compile_flags",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    archiver_flags_feature = feature(
+        name = "archiver_flags",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.cpp_link_static_library],
+                flag_groups = [
+                    flag_group(
+                        flags = ["/OUT:%{output_execpath}"],
+                        expand_if_available = "output_execpath",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    default_link_flags_feature = feature(
+        name = "default_link_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/MACHINE:X64"])],
+            ),
+        ],
+    )
+
+    static_link_msvcrt_feature = feature(name = "static_link_msvcrt")
+
+    dynamic_link_msvcrt_debug_feature = feature(
+        name = "dynamic_link_msvcrt_debug",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/MDd"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])],
+            ),
+        ],
+        requires = [feature_set(features = ["dbg"])],
+    )
+
+    dbg_feature = feature(
+        name = "dbg",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/Od", "/Z7"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["", "/INCREMENTAL:NO"],
+                    ),
+                ],
+            ),
+        ],
+        implies = ["generate_pdb_file"],
+    )
+
+    opt_feature = feature(
+        name = "opt",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/O2"])],
+            ),
+        ],
+        implies = ["frame_pointer"],
+    )
+
+    supports_interface_shared_libraries_feature = feature(
+        name = "supports_interface_shared_libraries",
+        enabled = True,
+    )
+
+    user_link_flags_feature = feature(
+        name = "user_link_flags",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{user_link_flags}"],
+                        iterate_over = "user_link_flags",
+                        expand_if_available = "user_link_flags",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    default_compile_flags_feature = feature(
+        name = "default_compile_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = [
+                            "/DCOMPILER_MSVC",
+                            "/DNOMINMAX",
+                            "/D_WIN32_WINNT=0x0601",
+                            "/D_CRT_SECURE_NO_DEPRECATE",
+                            "/D_CRT_SECURE_NO_WARNINGS",
+                            "/bigobj",
+                            "/Zm500",
+                            "/EHsc",
+                            "/wd4351",
+                            "/wd4291",
+                            "/wd4250",
+                            "/wd4996",
+                        ],
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    msvc_compile_env_feature = feature(
+        name = "msvc_compile_env",
+        env_sets = [
+            env_set(
+                actions = [
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                ],
+                env_entries = [env_entry(key = "INCLUDE", value = "")],
+            ),
+        ],
+    )
+
+    preprocessor_defines_feature = feature(
+        name = "preprocessor_defines",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["/D%{preprocessor_defines}"],
+                        iterate_over = "preprocessor_defines",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    generate_pdb_file_feature = feature(
+        name = "generate_pdb_file",
+        requires = [
+            feature_set(features = ["dbg"]),
+            feature_set(features = ["fastbuild"]),
+        ],
+    )
+
+    output_execpath_flags_feature = feature(
+        name = "output_execpath_flags",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["/OUT:%{output_execpath}"],
+                        expand_if_available = "output_execpath",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    dynamic_link_msvcrt_no_debug_feature = feature(
+        name = "dynamic_link_msvcrt_no_debug",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/MD"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])],
+            ),
+        ],
+        requires = [
+            feature_set(features = ["fastbuild"]),
+            feature_set(features = ["opt"]),
+        ],
+    )
+
+    disable_assertions_feature = feature(
+        name = "disable_assertions",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/DNDEBUG"])],
+                with_features = [with_feature_set(features = ["opt"])],
+            ),
+        ],
+    )
+
+    has_configured_linker_path_feature = feature(name = "has_configured_linker_path")
+
+    supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
+
+    no_stripping_feature = feature(name = "no_stripping")
+
+    linker_param_file_feature = feature(
+        name = "linker_param_file",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions +
+                          [ACTION_NAMES.cpp_link_static_library],
+                flag_groups = [
+                    flag_group(
+                        flags = ["@%{linker_param_file}"],
+                        expand_if_available = "linker_param_file",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    ignore_noisy_warnings_feature = feature(
+        name = "ignore_noisy_warnings",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.cpp_link_static_library],
+                flag_groups = [flag_group(flags = ["/ignore:4221"])],
+            ),
+        ],
+    )
+
+    no_legacy_features_feature = feature(name = "no_legacy_features")
+
+    parse_showincludes_feature = feature(
+        name = "parse_showincludes",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                ],
+                flag_groups = [flag_group(flags = ["/showIncludes"])],
+            ),
+        ],
+    )
+
+    static_link_msvcrt_no_debug_feature = feature(
+        name = "static_link_msvcrt_no_debug",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/MT"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])],
+            ),
+        ],
+        requires = [
+            feature_set(features = ["fastbuild"]),
+            feature_set(features = ["opt"]),
+        ],
+    )
+
+    treat_warnings_as_errors_feature = feature(
+        name = "treat_warnings_as_errors",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/WX"])],
+            ),
+        ],
+    )
+
+    windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols")
+
+    no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols")
+
+    include_paths_feature = feature(
+        name = "include_paths",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["/I%{quote_include_paths}"],
+                        iterate_over = "quote_include_paths",
+                    ),
+                    flag_group(
+                        flags = ["/I%{include_paths}"],
+                        iterate_over = "include_paths",
+                    ),
+                    flag_group(
+                        flags = ["/I%{system_include_paths}"],
+                        iterate_over = "system_include_paths",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    linkstamps_feature = feature(
+        name = "linkstamps",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{linkstamp_paths}"],
+                        iterate_over = "linkstamp_paths",
+                        expand_if_available = "linkstamp_paths",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    targets_windows_feature = feature(
+        name = "targets_windows",
+        enabled = True,
+        implies = ["copy_dynamic_libraries_to_binary"],
+    )
+
+    linker_subsystem_flag_feature = feature(
+        name = "linker_subsystem_flag",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])],
+            ),
+        ],
+    )
+
+    static_link_msvcrt_debug_feature = feature(
+        name = "static_link_msvcrt_debug",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/MTd"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])],
+            ),
+        ],
+        requires = [feature_set(features = ["dbg"])],
+    )
+
+    frame_pointer_feature = feature(
+        name = "frame_pointer",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/Oy-"])],
+            ),
+        ],
+    )
+
+    compiler_output_flags_feature = feature(
+        name = "compiler_output_flags",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.assemble],
+                flag_groups = [
+                    flag_group(
+                        flag_groups = [
+                            flag_group(
+                                flags = ["/Fo%{output_file}", "/Zi"],
+                                expand_if_available = "output_file",
+                                expand_if_not_available = "output_assembly_file",
+                            ),
+                        ],
+                        expand_if_not_available = "output_preprocess_file",
+                    ),
+                ],
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flag_groups = [
+                            flag_group(
+                                flags = ["/Fo%{output_file}"],
+                                expand_if_not_available = "output_preprocess_file",
+                            ),
+                        ],
+                        expand_if_available = "output_file",
+                        expand_if_not_available = "output_assembly_file",
+                    ),
+                    flag_group(
+                        flag_groups = [
+                            flag_group(
+                                flags = ["/Fa%{output_file}"],
+                                expand_if_available = "output_assembly_file",
+                            ),
+                        ],
+                        expand_if_available = "output_file",
+                    ),
+                    flag_group(
+                        flag_groups = [
+                            flag_group(
+                                flags = ["/P", "/Fi%{output_file}"],
+                                expand_if_available = "output_preprocess_file",
+                            ),
+                        ],
+                        expand_if_available = "output_file",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    nologo_feature = feature(
+        name = "nologo",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                    ACTION_NAMES.cpp_link_static_library,
+                ],
+                flag_groups = [flag_group(flags = ["/nologo"])],
+            ),
+        ],
+    )
+
+    smaller_binary_feature = feature(
+        name = "smaller_binary",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [flag_group(flags = ["/Gy", "/Gw"])],
+                with_features = [with_feature_set(features = ["opt"])],
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [flag_group(flags = ["/OPT:ICF", "/OPT:REF"])],
+                with_features = [with_feature_set(features = ["opt"])],
+            ),
+        ],
+    )
+
+    compiler_input_flags_feature = feature(
+        name = "compiler_input_flags",
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["/c", "%{source_file}"],
+                        expand_if_available = "source_file",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    def_file_feature = feature(
+        name = "def_file",
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = [
+                    flag_group(
+                        flags = ["/DEF:%{def_file_path}", "/ignore:4070"],
+                        expand_if_available = "def_file_path",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    msvc_env_feature = feature(
+        name = "msvc_env",
+        env_sets = [
+            env_set(
+                actions = [
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                    ACTION_NAMES.cpp_link_static_library,
+                ],
+                env_entries = [
+                    env_entry(key = "PATH", value = ""),
+                    env_entry(key = "TMP", value = ""),
+                    env_entry(key = "TEMP", value = ""),
+                ],
+            ),
+        ],
+        implies = ["msvc_compile_env", "msvc_link_env"],
+    )
+
+    features = [
+        no_legacy_features_feature,
+        nologo_feature,
+        has_configured_linker_path_feature,
+        no_stripping_feature,
+        targets_windows_feature,
+        copy_dynamic_libraries_to_binary_feature,
+        default_compile_flags_feature,
+        msvc_env_feature,
+        msvc_compile_env_feature,
+        msvc_link_env_feature,
+        include_paths_feature,
+        preprocessor_defines_feature,
+        parse_showincludes_feature,
+        generate_pdb_file_feature,
+        shared_flag_feature,
+        linkstamps_feature,
+        output_execpath_flags_feature,
+        archiver_flags_feature,
+        input_param_flags_feature,
+        linker_subsystem_flag_feature,
+        user_link_flags_feature,
+        default_link_flags_feature,
+        linker_param_file_feature,
+        static_link_msvcrt_feature,
+        static_link_msvcrt_no_debug_feature,
+        dynamic_link_msvcrt_no_debug_feature,
+        static_link_msvcrt_debug_feature,
+        dynamic_link_msvcrt_debug_feature,
+        dbg_feature,
+        fastbuild_feature,
+        opt_feature,
+        frame_pointer_feature,
+        disable_assertions_feature,
+        determinism_feature,
+        treat_warnings_as_errors_feature,
+        smaller_binary_feature,
+        ignore_noisy_warnings_feature,
+        user_compile_flags_feature,
+        sysroot_feature,
+        unfiltered_compile_flags_feature,
+        compiler_output_flags_feature,
+        compiler_input_flags_feature,
+        def_file_feature,
+        windows_export_all_symbols_feature,
+        no_windows_export_all_symbols_feature,
+        supports_dynamic_linker_feature,
+        supports_interface_shared_libraries_feature,
+    ]
+
+    artifact_name_patterns = [
+        artifact_name_pattern(
+            category_name = "object_file",
+            prefix = "",
+            extension = ".obj",
+        ),
+        artifact_name_pattern(
+            category_name = "static_library",
+            prefix = "",
+            extension = ".lib",
+        ),
+        artifact_name_pattern(
+            category_name = "alwayslink_static_library",
+            prefix = "",
+            extension = ".lo.lib",
+        ),
+        artifact_name_pattern(
+            category_name = "executable",
+            prefix = "",
+            extension = ".exe",
+        ),
+        artifact_name_pattern(
+            category_name = "dynamic_library",
+            prefix = "",
+            extension = ".dll",
+        ),
+        artifact_name_pattern(
+            category_name = "interface_library",
+            prefix = "",
+            extension = ".if.lib",
+        ),
+    ]
+
+    make_variables = []
+
+    tool_paths = [
+        tool_path(name = "ar", path = ""),
+        tool_path(name = "ml", path = ""),
+        tool_path(name = "cpp", path = ""),
+        tool_path(name = "gcc", path = ""),
+        tool_path(name = "gcov", path = "wrapper/bin/msvc_nop.bat"),
+        tool_path(name = "ld", path = ""),
+        tool_path(name = "nm", path = "wrapper/bin/msvc_nop.bat"),
+        tool_path(
+            name = "objcopy",
+            path = "wrapper/bin/msvc_nop.bat",
+        ),
+        tool_path(
+            name = "objdump",
+            path = "wrapper/bin/msvc_nop.bat",
+        ),
+        tool_path(
+            name = "strip",
+            path = "wrapper/bin/msvc_nop.bat",
+        ),
+    ]
+
+    return cc_common.create_cc_toolchain_config_info(
+        ctx = ctx,
+        features = features,
+        action_configs = action_configs,
+        artifact_name_patterns = artifact_name_patterns,
+        cxx_builtin_include_directories = cxx_builtin_include_directories,
+        toolchain_identifier = toolchain_identifier,
+        host_system_name = host_system_name,
+        target_system_name = target_system_name,
+        target_cpu = target_cpu,
+        target_libc = target_libc,
+        compiler = compiler,
+        abi_version = abi_version,
+        abi_libc_version = abi_libc_version,
+        tool_paths = tool_paths,
+        make_variables = make_variables,
+        builtin_sysroot = builtin_sysroot,
+        cc_target_os = None,
+    )
+
+def _windows_msys_mingw_impl(ctx):
+    toolchain_identifier = "msys_x64_mingw"
+    host_system_name = "local"
+    target_system_name = "local"
+    target_cpu = "x64_windows"
+    target_libc = "mingw"
+    compiler = "mingw-gcc"
+    abi_version = "local"
+    abi_libc_version = "local"
+    cc_target_os = None
+    builtin_sysroot = None
+    action_configs = []
+
+    targets_windows_feature = feature(
+        name = "targets_windows",
+        implies = ["copy_dynamic_libraries_to_binary"],
+        enabled = True,
+    )
+
+    copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary")
+
+    gcc_env_feature = feature(
+        name = "gcc_env",
+        enabled = True,
+        env_sets = [
+            env_set(
+                actions = [
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                    ACTION_NAMES.cpp_link_static_library,
+                ],
+                env_entries = [
+                    env_entry(key = "PATH", value = "NOT_USED"),
+                ],
+            ),
+        ],
+    )
+
+    msys_mingw_flags = [
+    ]
+    msys_mingw_link_flags = [
+    ]
+
+    default_compile_flags_feature = feature(
+        name = "default_compile_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = msys_mingw_flags)] if msys_mingw_flags else []),
+            ),
+        ],
+    )
+
+    default_link_flags_feature = feature(
+        name = "default_link_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = ([flag_group(flags = msys_mingw_link_flags)] if msys_mingw_link_flags else []),
+            ),
+        ],
+    )
+
+    supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
+
+    features = [
+        targets_windows_feature,
+        copy_dynamic_libraries_to_binary_feature,
+        gcc_env_feature,
+        default_compile_flags_feature,
+        default_link_flags_feature,
+        supports_dynamic_linker_feature,
+    ]
+
+    cxx_builtin_include_directories = [
+    ]
+
+    artifact_name_patterns = [
+        artifact_name_pattern(
+            category_name = "executable",
+            prefix = "",
+            extension = ".exe",
+        ),
+    ]
+
+    make_variables = []
+    tool_paths = [
+    ]
+
+    return cc_common.create_cc_toolchain_config_info(
+        ctx = ctx,
+        features = features,
+        action_configs = action_configs,
+        artifact_name_patterns = artifact_name_patterns,
+        cxx_builtin_include_directories = cxx_builtin_include_directories,
+        toolchain_identifier = toolchain_identifier,
+        host_system_name = host_system_name,
+        target_system_name = target_system_name,
+        target_cpu = target_cpu,
+        target_libc = target_libc,
+        compiler = compiler,
+        abi_version = abi_version,
+        abi_libc_version = abi_libc_version,
+        tool_paths = tool_paths,
+        make_variables = make_variables,
+        builtin_sysroot = builtin_sysroot,
+        cc_target_os = cc_target_os,
+    )
+
+def _armeabi_impl(ctx):
+    toolchain_identifier = "stub_armeabi-v7a"
+    host_system_name = "armeabi-v7a"
+    target_system_name = "armeabi-v7a"
+    target_cpu = "armeabi-v7a"
+    target_libc = "armeabi-v7a"
+    compiler = "compiler"
+    abi_version = "armeabi-v7a"
+    abi_libc_version = "armeabi-v7a"
+    cc_target_os = None
+    builtin_sysroot = None
+    action_configs = []
+
+    supports_pic_feature = feature(name = "supports_pic", enabled = True)
+    supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
+    features = [supports_dynamic_linker_feature, supports_pic_feature]
+
+    cxx_builtin_include_directories = []
+    artifact_name_patterns = []
+    make_variables = []
+
+    tool_paths = [
+        tool_path(name = "ar", path = "/bin/false"),
+        tool_path(name = "compat-ld", path = "/bin/false"),
+        tool_path(name = "cpp", path = "/bin/false"),
+        tool_path(name = "dwp", path = "/bin/false"),
+        tool_path(name = "gcc", path = "/bin/false"),
+        tool_path(name = "gcov", path = "/bin/false"),
+        tool_path(name = "ld", path = "/bin/false"),
+        tool_path(name = "nm", path = "/bin/false"),
+        tool_path(name = "objcopy", path = "/bin/false"),
+        tool_path(name = "objdump", path = "/bin/false"),
+        tool_path(name = "strip", path = "/bin/false"),
+    ]
+
+    return cc_common.create_cc_toolchain_config_info(
+        ctx = ctx,
+        features = features,
+        action_configs = action_configs,
+        artifact_name_patterns = artifact_name_patterns,
+        cxx_builtin_include_directories = cxx_builtin_include_directories,
+        toolchain_identifier = toolchain_identifier,
+        host_system_name = host_system_name,
+        target_system_name = target_system_name,
+        target_cpu = target_cpu,
+        target_libc = target_libc,
+        compiler = compiler,
+        abi_version = abi_version,
+        abi_libc_version = abi_libc_version,
+        tool_paths = tool_paths,
+        make_variables = make_variables,
+        builtin_sysroot = builtin_sysroot,
+        cc_target_os = cc_target_os,
+    )
+
+def _impl(ctx):
+    if ctx.attr.cpu == "armeabi-v7a":
+        return _armeabi_impl(ctx)
+    elif ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "msvc-cl":
+        return _windows_msvc_impl(ctx)
+    elif ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "mingw-gcc":
+        return _windows_msys_mingw_impl(ctx)
+
+    tool_paths = [
+        tool_path(name = "ar", path = "/usr/bin/ar"),
+        tool_path(name = "ld", path = "/usr/bin/ld"),
+        tool_path(name = "cpp", path = "/usr/bin/cpp"),
+        tool_path(name = "gcc", path = "/dt7/usr/bin/gcc"),
+        tool_path(name = "dwp", path = "/usr/bin/dwp"),
+        tool_path(name = "gcov", path = "/usr/bin/gcov"),
+        tool_path(name = "nm", path = "/usr/bin/nm"),
+        tool_path(name = "objcopy", path = "/usr/bin/objcopy"),
+        tool_path(name = "objdump", path = "/usr/bin/objdump"),
+        tool_path(name = "strip", path = "/usr/bin/strip"),
+    ]
+
+    cxx_builtin_include_directories = [
+        "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include",
+        "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed",
+        "/dt7/usr/include",
+        "/dt7/usr/include/c++/7",
+        "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu",
+        "/dt7/usr/include/c++/7/backward",
+    ]
+
+    action_configs = []
+
+    compile_flags = [
+        "-U_FORTIFY_SOURCE",
+        "-fstack-protector",
+        "-Wall",
+        "-Wunused-but-set-parameter",
+        "-Wno-free-nonheap-object",
+        "-fno-omit-frame-pointer",
+    ]
+
+    dbg_compile_flags = [
+        "-g",
+    ]
+
+    opt_compile_flags = [
+        "-g0",
+        "-O2",
+        "-D_FORTIFY_SOURCE=1",
+        "-DNDEBUG",
+        "-ffunction-sections",
+        "-fdata-sections",
+    ]
+
+    cxx_flags = [
+        "-std=c++0x",
+    ]
+
+    link_flags = [
+        "-fuse-ld=gold",
+        "-Wl,-no-as-needed",
+        "-Wl,-z,relro,-z,now",
+        "-B/dt7/usr/bin",
+        "-pass-exit-codes",
+        "-lstdc++",
+        "-lm",
+    ]
+
+    opt_link_flags = [
+        "-Wl,--gc-sections",
+    ]
+
+    unfiltered_compile_flags = [
+        "-fno-canonical-system-headers",
+        "-Wno-builtin-macro-redefined",
+        "-D__DATE__=\"redacted\"",
+        "-D__TIMESTAMP__=\"redacted\"",
+        "-D__TIME__=\"redacted\"",
+    ]
+
+    targets_windows_feature = feature(
+        name = "targets_windows",
+        implies = ["copy_dynamic_libraries_to_binary"],
+        enabled = True,
+    )
+
+    copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary")
+
+    gcc_env_feature = feature(
+        name = "gcc_env",
+        enabled = True,
+        env_sets = [
+            env_set(
+                actions = [
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                    ACTION_NAMES.cpp_link_static_library,
+                ],
+                env_entries = [
+                    env_entry(key = "PATH", value = "NOT_USED"),
+                ],
+            ),
+        ],
+    )
+
+    windows_features = [
+        targets_windows_feature,
+        copy_dynamic_libraries_to_binary_feature,
+        gcc_env_feature,
+    ]
+
+    coverage_feature = feature(
+        name = "coverage",
+        provides = ["profile"],
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                ],
+                flag_groups = [
+                    flag_group(flags = ["--coverage"]),
+                ],
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                    ACTION_NAMES.cpp_link_executable,
+                ],
+                flag_groups = [
+                    flag_group(flags = ["--coverage"]),
+                ],
+            ),
+        ],
+    )
+
+    supports_pic_feature = feature(
+        name = "supports_pic",
+        enabled = True,
+    )
+    supports_start_end_lib_feature = feature(
+        name = "supports_start_end_lib",
+        enabled = True,
+    )
+
+    default_compile_flags_feature = feature(
+        name = "default_compile_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = compile_flags)] if compile_flags else []),
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = dbg_compile_flags)] if dbg_compile_flags else []),
+                with_features = [with_feature_set(features = ["dbg"])],
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = opt_compile_flags)] if opt_compile_flags else []),
+                with_features = [with_feature_set(features = ["opt"])],
+            ),
+            flag_set(
+                actions = [
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = cxx_flags)] if cxx_flags else []),
+            ),
+        ],
+    )
+
+    default_link_flags_feature = feature(
+        name = "default_link_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = ([flag_group(flags = link_flags)] if link_flags else []),
+            ),
+            flag_set(
+                actions = all_link_actions,
+                flag_groups = ([flag_group(flags = opt_link_flags)] if opt_link_flags else []),
+                with_features = [with_feature_set(features = ["opt"])],
+            ),
+        ],
+    )
+
+    dbg_feature = feature(name = "dbg")
+
+    opt_feature = feature(name = "opt")
+
+    sysroot_feature = feature(
+        name = "sysroot",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                    ACTION_NAMES.cpp_link_executable,
+                    ACTION_NAMES.cpp_link_dynamic_library,
+                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["--sysroot=%{sysroot}"],
+                        expand_if_available = "sysroot",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    fdo_optimize_feature = feature(
+        name = "fdo_optimize",
+        flag_sets = [
+            flag_set(
+                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
+                flag_groups = [
+                    flag_group(
+                        flags = [
+                            "-fprofile-use=%{fdo_profile_path}",
+                            "-fprofile-correction",
+                        ],
+                        expand_if_available = "fdo_profile_path",
+                    ),
+                ],
+            ),
+        ],
+        provides = ["profile"],
+    )
+
+    supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
+
+    user_compile_flags_feature = feature(
+        name = "user_compile_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = [
+                    flag_group(
+                        flags = ["%{user_compile_flags}"],
+                        iterate_over = "user_compile_flags",
+                        expand_if_available = "user_compile_flags",
+                    ),
+                ],
+            ),
+        ],
+    )
+
+    unfiltered_compile_flags_feature = feature(
+        name = "unfiltered_compile_flags",
+        enabled = True,
+        flag_sets = [
+            flag_set(
+                actions = [
+                    ACTION_NAMES.assemble,
+                    ACTION_NAMES.preprocess_assemble,
+                    ACTION_NAMES.linkstamp_compile,
+                    ACTION_NAMES.c_compile,
+                    ACTION_NAMES.cpp_compile,
+                    ACTION_NAMES.cpp_header_parsing,
+                    ACTION_NAMES.cpp_module_compile,
+                    ACTION_NAMES.cpp_module_codegen,
+                    ACTION_NAMES.lto_backend,
+                    ACTION_NAMES.clif_match,
+                ],
+                flag_groups = ([flag_group(flags = unfiltered_compile_flags)] if unfiltered_compile_flags else []),
+            ),
+        ],
+    )
+
+    features = [
+        supports_pic_feature,
+        supports_start_end_lib_feature,
+        coverage_feature,
+        default_compile_flags_feature,
+        default_link_flags_feature,
+        fdo_optimize_feature,
+        supports_dynamic_linker_feature,
+        dbg_feature,
+        opt_feature,
+        user_compile_flags_feature,
+        sysroot_feature,
+        unfiltered_compile_flags_feature,
+    ]
+
+    artifact_name_patterns = [
+    ]
+
+    make_variables = []
+
+    return cc_common.create_cc_toolchain_config_info(
+        ctx = ctx,
+        features = features,
+        action_configs = action_configs,
+        artifact_name_patterns = artifact_name_patterns,
+        cxx_builtin_include_directories = cxx_builtin_include_directories,
+        toolchain_identifier = "linux_gnu_x86",
+        host_system_name = "i686-unknown-linux-gnu",
+        target_system_name = "x86_64-unknown-linux-gnu",
+        target_cpu = "k8",
+        target_libc = "glibc_2.19",
+        compiler = "/dt7/usr/bin/gcc",
+        abi_version = "gcc",
+        abi_libc_version = "glibc_2.19",
+        tool_paths = tool_paths,
+        make_variables = make_variables,
+        builtin_sysroot = "",
+        cc_target_os = None,
+    )
+
+cc_toolchain_config = rule(
+    implementation = _impl,
+    attrs = {
+        "cpu": attr.string(mandatory = True),
+        "compiler": attr.string(),
+    },
+    provides = [CcToolchainConfigInfo],
+)
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh
new file mode 100755
index 0000000..898befb
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_wrapper.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+#
+# Copyright 2015 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Ship the environment to the C++ action
+#
+set -eu
+
+# Set-up the environment
+
+
+# Call the C++ compiler
+/dt7/usr/bin/gcc "$@"
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
new file mode 100755
index 0000000..45c0285
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
@@ -0,0 +1,23 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2017 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Skylark rule that stubs a toolchain."""
+
+def _dummy_toolchain_impl(ctx):
+    ctx = ctx  # unused argument
+    toolchain = platform_common.ToolchainInfo()
+    return [toolchain]
+
+dummy_toolchain = rule(_dummy_toolchain_impl, attrs = {})
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc
new file mode 100755
index 0000000..237c8ce
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/tools/cpp/empty.cc
@@ -0,0 +1 @@
+int main() {}
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD b/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
new file mode 100755
index 0000000..3cd5fdd
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
@@ -0,0 +1,173 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
+# See https://docs.python.org/3/extending/windows.html
+cc_import(
+    name = "python_lib",
+    interface_library = select({
+        ":windows": ":python_import_lib",
+        # A placeholder for Unix platforms which makes --no_build happy.
+        "//conditions:default": "not-existing.lib",
+    }),
+    system_provided = 1,
+)
+
+cc_library(
+    name = "python_headers",
+    hdrs = [":python_include"],
+    includes = ["python_include"],
+    deps = select({
+        ":windows": [":python_lib"],
+        "//conditions:default": [],
+    }),
+)
+
+cc_library(
+    name = "numpy_headers",
+    hdrs = [":numpy_include"],
+    includes = ["numpy_include"],
+)
+
+config_setting(
+    name = "windows",
+    values = {"cpu": "x64_windows"},
+    visibility = ["//visibility:public"],
+)
+
+genrule(
+    name = "python_include",
+    outs = [
+        "python_include/Python-ast.h",
+        "python_include/Python.h",
+        "python_include/abstract.h",
+        "python_include/asdl.h",
+        "python_include/ast.h",
+        "python_include/bitset.h",
+        "python_include/boolobject.h",
+        "python_include/bufferobject.h",
+        "python_include/bytearrayobject.h",
+        "python_include/bytes_methods.h",
+        "python_include/bytesobject.h",
+        "python_include/cStringIO.h",
+        "python_include/cellobject.h",
+        "python_include/ceval.h",
+        "python_include/classobject.h",
+        "python_include/cobject.h",
+        "python_include/code.h",
+        "python_include/codecs.h",
+        "python_include/compile.h",
+        "python_include/complexobject.h",
+        "python_include/datetime.h",
+        "python_include/descrobject.h",
+        "python_include/dictobject.h",
+        "python_include/dtoa.h",
+        "python_include/enumobject.h",
+        "python_include/errcode.h",
+        "python_include/eval.h",
+        "python_include/fileobject.h",
+        "python_include/floatobject.h",
+        "python_include/frameobject.h",
+        "python_include/funcobject.h",
+        "python_include/genobject.h",
+        "python_include/graminit.h",
+        "python_include/grammar.h",
+        "python_include/import.h",
+        "python_include/intobject.h",
+        "python_include/intrcheck.h",
+        "python_include/iterobject.h",
+        "python_include/listobject.h",
+        "python_include/longintrepr.h",
+        "python_include/longobject.h",
+        "python_include/marshal.h",
+        "python_include/memoryobject.h",
+        "python_include/metagrammar.h",
+        "python_include/methodobject.h",
+        "python_include/modsupport.h",
+        "python_include/moduleobject.h",
+        "python_include/node.h",
+        "python_include/object.h",
+        "python_include/objimpl.h",
+        "python_include/opcode.h",
+        "python_include/osdefs.h",
+        "python_include/parsetok.h",
+        "python_include/patchlevel.h",
+        "python_include/pgen.h",
+        "python_include/pgenheaders.h",
+        "python_include/py_curses.h",
+        "python_include/pyarena.h",
+        "python_include/pycapsule.h",
+        "python_include/pyconfig.h",
+        "python_include/pyctype.h",
+        "python_include/pydebug.h",
+        "python_include/pyerrors.h",
+        "python_include/pyexpat.h",
+        "python_include/pyfpe.h",
+        "python_include/pygetopt.h",
+        "python_include/pymacconfig.h",
+        "python_include/pymactoolbox.h",
+        "python_include/pymath.h",
+        "python_include/pymem.h",
+        "python_include/pyport.h",
+        "python_include/pystate.h",
+        "python_include/pystrcmp.h",
+        "python_include/pystrtod.h",
+        "python_include/pythonrun.h",
+        "python_include/pythread.h",
+        "python_include/rangeobject.h",
+        "python_include/setobject.h",
+        "python_include/sliceobject.h",
+        "python_include/stringobject.h",
+        "python_include/structmember.h",
+        "python_include/structseq.h",
+        "python_include/symtable.h",
+        "python_include/sysmodule.h",
+        "python_include/timefuncs.h",
+        "python_include/token.h",
+        "python_include/traceback.h",
+        "python_include/tupleobject.h",
+        "python_include/ucnhash.h",
+        "python_include/unicodeobject.h",
+        "python_include/warnings.h",
+        "python_include/weakrefobject.h",
+    ],
+    cmd = """
+cp -f "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h" && cp -f "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp -f "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp -f "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp -f "/usr/include/python2.7/ast.h" "$(@D)/python_include/ast.h" && cp -f "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp -f "/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp -f "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp -f "/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp -f "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp -f "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp -f "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp -f "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp -f "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp -f "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp -f "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp -f "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp -f "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp -f "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp -f "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp -f "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp -f "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp -f "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp -f "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp -f "/usr/include/python2.7/enumobject.h" "$(@D)/python_include/enumobject.h" && cp -f "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp -f "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp -f "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp -f "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp -f "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp -f "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp -f "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp -f "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp -f "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp -f "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp -f "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp -f "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp -f "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp -f "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp -f "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp -f "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp -f "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp -f "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp -f "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp -f "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp -f "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp -f "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp -f "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp -f "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp -f "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp -f "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp -f "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp -f "/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp -f "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp -f "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp -f "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp -f "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp -f "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp -f "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp -f "/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp -f "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp -f "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp -f "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp -f "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp -f "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp -f "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp -f "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp -f "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp -f "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp -f "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp -f "/usr/include/python2.7/pyport.h" "$(@D)/python_include/pyport.h" && cp -f "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp -f "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp -f "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp -f "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp -f "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp -f "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp -f "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp -f "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp -f "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp -f "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp -f "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp -f "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp -f "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp -f "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp -f "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp -f "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp -f "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp -f "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp -f "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp -f "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp -f "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h"
+   """,
+)
+
+genrule(
+    name = "numpy_include",
+    outs = [
+        "numpy_include/numpy/__multiarray_api.h",
+        "numpy_include/numpy/__ufunc_api.h",
+        "numpy_include/numpy/_neighborhood_iterator_imp.h",
+        "numpy_include/numpy/_numpyconfig.h",
+        "numpy_include/numpy/arrayobject.h",
+        "numpy_include/numpy/arrayscalars.h",
+        "numpy_include/numpy/halffloat.h",
+        "numpy_include/numpy/multiarray_api.txt",
+        "numpy_include/numpy/ndarrayobject.h",
+        "numpy_include/numpy/ndarraytypes.h",
+        "numpy_include/numpy/noprefix.h",
+        "numpy_include/numpy/npy_1_7_deprecated_api.h",
+        "numpy_include/numpy/npy_3kcompat.h",
+        "numpy_include/numpy/npy_common.h",
+        "numpy_include/numpy/npy_cpu.h",
+        "numpy_include/numpy/npy_endian.h",
+        "numpy_include/numpy/npy_interrupt.h",
+        "numpy_include/numpy/npy_math.h",
+        "numpy_include/numpy/npy_no_deprecated_api.h",
+        "numpy_include/numpy/npy_os.h",
+        "numpy_include/numpy/numpyconfig.h",
+        "numpy_include/numpy/old_defines.h",
+        "numpy_include/numpy/oldnumeric.h",
+        "numpy_include/numpy/ufunc_api.txt",
+        "numpy_include/numpy/ufuncobject.h",
+        "numpy_include/numpy/utils.h",
+    ],
+    cmd = """
+cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp -f "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h"
+   """,
+)
diff --git a/third_party/toolchains/preconfig/ubuntu16.04/py/WORKSPACE b/third_party/toolchains/preconfig/ubuntu16.04/py/WORKSPACE
new file mode 100644
index 0000000..1d298fe
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu16.04/py/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for python_configure rule
+workspace(name = "local_config_python")
